diff --git a/command_run.go b/command_run.go index e5cfff8a57..676a14c676 100644 --- a/command_run.go +++ b/command_run.go @@ -237,14 +237,18 @@ func (cmd *Command) run(ctx context.Context, osArgs []string) (_ context.Context }() } - for _, grp := range cmd.MutuallyExclusiveFlags { - if err := grp.check(cmd); err != nil { - if cmd.OnUsageError != nil { - err = cmd.OnUsageError(ctx, cmd, err, cmd.parent != nil) - } else { - _ = ShowSubcommandHelp(cmd) + // Walk the parent chain to check mutually exclusive flag groups + // defined on ancestor commands, since persistent flags are inherited. + for pCmd := cmd; pCmd != nil; pCmd = pCmd.parent { + for _, grp := range pCmd.MutuallyExclusiveFlags { + if err := grp.check(cmd); err != nil { + if cmd.OnUsageError != nil { + err = cmd.OnUsageError(ctx, cmd, err, cmd.parent != nil) + } else { + _ = ShowSubcommandHelp(cmd) + } + return ctx, err } - return ctx, err } } diff --git a/command_test.go b/command_test.go index 08a410a6fe..b3cfbc4788 100644 --- a/command_test.go +++ b/command_test.go @@ -5445,39 +5445,100 @@ func TestCommand_ParallelRun(t *testing.T) { } } -func TestCommand_ExclusiveFlagsPersistentPropagation(t *testing.T) { - var subCmdAlphaValue string +func TestCommand_ExclusiveFlagsPersistent(t *testing.T) { + exclusiveGroup := func(flags ...string) []MutuallyExclusiveFlags { + grp := MutuallyExclusiveFlags{} + for _, name := range flags { + grp.Flags = append(grp.Flags, []Flag{&StringFlag{Name: name}}) + } + return []MutuallyExclusiveFlags{grp} + } - cmd := &Command{ - Name: "root", - MutuallyExclusiveFlags: []MutuallyExclusiveFlags{ - { - Flags: [][]Flag{ - { - &StringFlag{ - Name: "alpha", - }, - }, - { - &StringFlag{ - Name: "beta", - }, - }, - }, + noop := func(_ context.Context, _ *Command) error { return nil } + + newBaseCmd := func() *Command { + return &Command{ + Name: "root", + MutuallyExclusiveFlags: exclusiveGroup("alpha", "beta"), + Commands: []*Command{{Name: "sub", Action: noop}}, + } + } + + tests := []struct { + name string + setup func() *Command + args []string + wantErr string + }{ + { + name: "single flag propagated to subcommand", + setup: newBaseCmd, + args: []string{"root", "sub", "--alpha", "hello"}, + }, + { + name: "both exclusive flags on subcommand errors", + setup: newBaseCmd, + args: []string{"root", "sub", "--alpha", "hello", "--beta", "world"}, + wantErr: "cannot be set along with", + }, + { + name: "neither flag set without required is ok", + setup: newBaseCmd, + args: []string{"root", "sub"}, + }, + { + name: "exclusive flags checked on grandchild", + setup: func() *Command { + cmd := newBaseCmd() + sub := cmd.Commands[0] + sub.Name = "mid" + sub.Action = nil + sub.Commands = []*Command{{Name: "leaf", Action: noop}} + return cmd + }, + args: []string{"root", "mid", "leaf", "--alpha", "hello", "--beta", "world"}, + wantErr: "cannot be set along with", + }, + { + name: "subcommand own group checked alongside parent group", + setup: func() *Command { + cmd := newBaseCmd() + cmd.Commands[0].MutuallyExclusiveFlags = exclusiveGroup("gamma", "delta") + return cmd }, + args: []string{"root", "sub", "--gamma", "hello", "--delta", "world"}, + wantErr: "cannot be set along with", }, - Commands: []*Command{ - { - Name: "sub", - Action: func(_ context.Context, cmd *Command) error { - subCmdAlphaValue = cmd.String("alpha") - return nil - }, + { + name: "parent group violation detected when subcommand has own group", + setup: func() *Command { + cmd := newBaseCmd() + cmd.Commands[0].MutuallyExclusiveFlags = exclusiveGroup("gamma", "delta") + return cmd + }, + args: []string{"root", "sub", "--alpha", "hello", "--beta", "world"}, + wantErr: "cannot be set along with", + }, + { + name: "parent and subcommand groups both pass independently", + setup: func() *Command { + cmd := newBaseCmd() + cmd.Commands[0].MutuallyExclusiveFlags = exclusiveGroup("gamma", "delta") + return cmd }, + args: []string{"root", "sub", "--alpha", "hello", "--gamma", "world"}, }, } - err := cmd.Run(buildTestContext(t), []string{"root", "sub", "--alpha", "hello"}) - require.NoError(t, err) - assert.Equal(t, "hello", subCmdAlphaValue) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.setup().Run(buildTestContext(t), tt.args) + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + } + }) + } }