Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions command_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,14 +237,18 @@ func (cmd *Command) run(ctx context.Context, osArgs []string) (_ context.Context
}()
}

for _, grp := range cmd.MutuallyExclusiveFlags {
if err := grp.check(cmd); err != nil {
if cmd.OnUsageError != nil {
err = cmd.OnUsageError(ctx, cmd, err, cmd.parent != nil)
} else {
_ = ShowSubcommandHelp(cmd)
// Walk the parent chain to check mutually exclusive flag groups
// defined on ancestor commands, since persistent flags are inherited.
for pCmd := cmd; pCmd != nil; pCmd = pCmd.parent {
for _, grp := range pCmd.MutuallyExclusiveFlags {
if err := grp.check(cmd); err != nil {
if cmd.OnUsageError != nil {
err = cmd.OnUsageError(ctx, cmd, err, cmd.parent != nil)
} else {
_ = ShowSubcommandHelp(cmd)
}
return ctx, err
}
return ctx, err
}
}

Expand Down
117 changes: 89 additions & 28 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5445,39 +5445,100 @@ func TestCommand_ParallelRun(t *testing.T) {
}
}

func TestCommand_ExclusiveFlagsPersistentPropagation(t *testing.T) {
var subCmdAlphaValue string
func TestCommand_ExclusiveFlagsPersistent(t *testing.T) {
exclusiveGroup := func(flags ...string) []MutuallyExclusiveFlags {
grp := MutuallyExclusiveFlags{}
for _, name := range flags {
grp.Flags = append(grp.Flags, []Flag{&StringFlag{Name: name}})
}
return []MutuallyExclusiveFlags{grp}
}

cmd := &Command{
Name: "root",
MutuallyExclusiveFlags: []MutuallyExclusiveFlags{
{
Flags: [][]Flag{
{
&StringFlag{
Name: "alpha",
},
},
{
&StringFlag{
Name: "beta",
},
},
},
noop := func(_ context.Context, _ *Command) error { return nil }

newBaseCmd := func() *Command {
return &Command{
Name: "root",
MutuallyExclusiveFlags: exclusiveGroup("alpha", "beta"),
Commands: []*Command{{Name: "sub", Action: noop}},
}
}

tests := []struct {
name string
setup func() *Command
args []string
wantErr string
}{
{
name: "single flag propagated to subcommand",
setup: newBaseCmd,
args: []string{"root", "sub", "--alpha", "hello"},
},
{
name: "both exclusive flags on subcommand errors",
setup: newBaseCmd,
args: []string{"root", "sub", "--alpha", "hello", "--beta", "world"},
wantErr: "cannot be set along with",
},
{
name: "neither flag set without required is ok",
setup: newBaseCmd,
args: []string{"root", "sub"},
},
{
name: "exclusive flags checked on grandchild",
setup: func() *Command {
cmd := newBaseCmd()
sub := cmd.Commands[0]
sub.Name = "mid"
sub.Action = nil
sub.Commands = []*Command{{Name: "leaf", Action: noop}}
return cmd
},
args: []string{"root", "mid", "leaf", "--alpha", "hello", "--beta", "world"},
wantErr: "cannot be set along with",
},
{
name: "subcommand own group checked alongside parent group",
setup: func() *Command {
cmd := newBaseCmd()
cmd.Commands[0].MutuallyExclusiveFlags = exclusiveGroup("gamma", "delta")
return cmd
},
args: []string{"root", "sub", "--gamma", "hello", "--delta", "world"},
wantErr: "cannot be set along with",
},
Commands: []*Command{
{
Name: "sub",
Action: func(_ context.Context, cmd *Command) error {
subCmdAlphaValue = cmd.String("alpha")
return nil
},
{
name: "parent group violation detected when subcommand has own group",
setup: func() *Command {
cmd := newBaseCmd()
cmd.Commands[0].MutuallyExclusiveFlags = exclusiveGroup("gamma", "delta")
return cmd
},
args: []string{"root", "sub", "--alpha", "hello", "--beta", "world"},
wantErr: "cannot be set along with",
},
{
name: "parent and subcommand groups both pass independently",
setup: func() *Command {
cmd := newBaseCmd()
cmd.Commands[0].MutuallyExclusiveFlags = exclusiveGroup("gamma", "delta")
return cmd
},
args: []string{"root", "sub", "--alpha", "hello", "--gamma", "world"},
},
}

err := cmd.Run(buildTestContext(t), []string{"root", "sub", "--alpha", "hello"})
require.NoError(t, err)
assert.Equal(t, "hello", subCmdAlphaValue)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.setup().Run(buildTestContext(t), tt.args)
if tt.wantErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErr)
} else {
require.NoError(t, err)
}
})
}
}