diff --git a/internal/command/wireguard/others.go b/internal/command/wireguard/others.go index 36d03ae6f6..5656652766 100644 --- a/internal/command/wireguard/others.go +++ b/internal/command/wireguard/others.go @@ -17,12 +17,11 @@ import ( ) func argOrPrompt(ctx context.Context, nth int, prompt string) (string, error) { - args := flag.Args(ctx) - if len(args) >= (nth + 1) { - return args[nth], nil + val := flag.GetArg(ctx, nth) + if val != "" { + return val, nil } - val := "" err := survey.AskOne( &survey.Input{Message: prompt}, &val, @@ -32,48 +31,42 @@ func argOrPrompt(ctx context.Context, nth int, prompt string) (string, error) { } func orgByArg(ctx context.Context) (*fly.Organization, error) { - args := flag.Args(ctx) - - if len(args) == 0 { - org, err := prompt.Org(ctx) - if err != nil { - return nil, err - } - - return org, nil + org := flag.FirstArg(ctx) + if org != "" { + apiClient := flyutil.ClientFromContext(ctx) + return apiClient.GetOrganizationBySlug(ctx, org) } - apiClient := flyutil.ClientFromContext(ctx) - return apiClient.GetOrganizationBySlug(ctx, args[0]) + return prompt.Org(ctx) } func resolveOutputWriter(ctx context.Context, idx int, prompt string) (w io.WriteCloser, mustClose bool, err error) { io := iostreams.FromContext(ctx) var f *os.File - var filename string + filename := flag.GetArg(ctx, idx) for { - filename, err = argOrPrompt(ctx, idx, prompt) - if err != nil { - return nil, false, err - } - - if filename == "" { - fmt.Fprintln(io.Out, "Provide a filename (or 'stdout')") - continue - } - if filename == "stdout" { return os.Stdout, false, nil } - f, err = os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0o600) - if err == nil { - return f, true, nil + if filename != "" { + f, err = os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0o600) + if err == nil { + return f, true, nil + } + fmt.Fprintf(io.Out, "Can't create '%s': %s\n", filename, err) } - fmt.Fprintf(io.Out, "Can't create '%s': %s\n", filename, err) + err := survey.AskOne( + &survey.Input{Message: prompt}, + &filename, + ) + if err != nil { + return nil, false, err + } } + } func generateWgConf(peer *fly.CreatedWireGuardPeer, privkey string, w io.Writer) { diff --git a/internal/command/wireguard/wireguard.go b/internal/command/wireguard/wireguard.go index 76cb97eb8d..1901f138b7 100644 --- a/internal/command/wireguard/wireguard.go +++ b/internal/command/wireguard/wireguard.go @@ -129,17 +129,8 @@ func runWireguardCreate(ctx context.Context) error { return err } - args := flag.Args(ctx) - var region string - var name string - - if len(args) > 1 && args[1] != "" { - region = args[1] - } - - if len(args) > 2 && args[2] != "" { - name = args[2] - } + region := flag.GetArg(ctx, 1) + name := flag.GetArg(ctx, 2) network := flag.GetString(ctx, "network") @@ -183,11 +174,8 @@ func runWireguardRemove(ctx context.Context) error { return err } - args := flag.Args(ctx) - var name string - if len(args) >= 2 { - name = args[1] - } else { + name := flag.GetArg(ctx, 1) + if name == "" { name, err = selectWireGuardPeer(ctx, apiClient, org.Slug) if err != nil { return err diff --git a/internal/flag/context.go b/internal/flag/context.go index 2d038f709f..cc615d4828 100644 --- a/internal/flag/context.go +++ b/internal/flag/context.go @@ -28,14 +28,18 @@ func Args(ctx context.Context) []string { return FromContext(ctx).Args() } +// GetArg returns argument specified by zero-based idx or an empty string. +func GetArg(ctx context.Context, idx int) string { + if args := Args(ctx); len(args) > idx { + return args[idx] + } + return "" +} + // FirstArg returns the first arg ctx carries or an empty string in case ctx // carries an empty argument set. It panics in case ctx carries no FlagSet. func FirstArg(ctx context.Context) string { - if args := Args(ctx); len(args) > 0 { - return args[0] - } - - return "" + return GetArg(ctx, 0) } // GetString returns the value of the named string flag ctx carries.