diff --git a/flag/flag.go b/flag/flag.go index d7fd3ea..8f9d210 100644 --- a/flag/flag.go +++ b/flag/flag.go @@ -14,10 +14,12 @@ // typically in the `main` function of the application. If the `--help` flag is // set, it prints usage information and exits. // -// Additional flags can be registered using `RegisterFlag`, which accepts the flag +// Additional flags can be registered using `Register`, which accepts the flag // name, a pointer to the variable to populate, and a usage description. // Existing flags can be overridden using `Override`, which allows changing the // variable, default value, and description of an already registered flag. +// Flags can be removed using `Unregister`, which removes a previously +// registered flag from the command line. // Supported types include strings, booleans, integers, unsigned integers, and floats. // // Example: @@ -32,12 +34,15 @@ // var CustomFlag string // // func main() { -// flag.RegisterFlag("custom", &CustomFlag, "A custom flag for demonstration") +// flag.Register("custom", &CustomFlag, "A custom flag for demonstration") // // // Override the default path flag // flag.Path = "/new/default/path" // flag.Override("path", &flag.Path, "Updated application working directory") // +// // Unregister a flag if no longer needed +// flag.Unregister("custom") +// // flag.Init() // // fmt.Println("Custom Flag Value:", CustomFlag) @@ -83,9 +88,9 @@ func PrintHelp() { pflag.PrintDefaults() } -// RegisterFlag registers a new flag with the given name, value and usage +// Register registers a new flag with the given name, value and usage // It panics if the flag is already registered or if the value is not a pointer -func RegisterFlag(name string, value interface{}, usage string) { +func Register(name string, value interface{}, usage string) { if pflag.Lookup(name) != nil { panic(fmt.Sprintf("flag %s already registered", name)) } @@ -198,3 +203,26 @@ func Override(name string, value interface{}, usage string) { pflag.CommandLine = newCommandLine } + +// Unregister removes a previously registered flag +// It panics if the flag is not registered or if flags have already been parsed +func Unregister(name string) { + if pflag.Lookup(name) == nil { + panic(fmt.Sprintf("flag %s is not registered", name)) + } + + if pflag.Parsed() { + panic(fmt.Sprintf("cannot unregister flag %s after flags have been parsed", name)) + } + + newCommandLine := pflag.NewFlagSet("", pflag.ContinueOnError) + + // Copy all flags except the one we're unregistering + pflag.CommandLine.VisitAll(func(flag *pflag.Flag) { + if flag.Name != name { + newCommandLine.AddFlag(flag) + } + }) + + pflag.CommandLine = newCommandLine +} diff --git a/flag/flag_test.go b/flag/flag_test.go index 5d2348e..62e9028 100644 --- a/flag/flag_test.go +++ b/flag/flag_test.go @@ -30,62 +30,62 @@ func TestDefaultFlags(t *testing.T) { func TestRegisterFlag(_ *testing.T) { // Test registering a string flag var stringFlag string - flag.RegisterFlag("test-string", &stringFlag, "A test string flag") + flag.Register("test-string", &stringFlag, "A test string flag") // Test registering a bool flag var boolFlag bool - flag.RegisterFlag("test-bool", &boolFlag, "A test bool flag") + flag.Register("test-bool", &boolFlag, "A test bool flag") // Test registering an int flag var intFlag int - flag.RegisterFlag("test-int", &intFlag, "A test int flag") + flag.Register("test-int", &intFlag, "A test int flag") // Test registering various numeric types var int8Flag int8 - flag.RegisterFlag("test-int8", &int8Flag, "A test int8 flag") + flag.Register("test-int8", &int8Flag, "A test int8 flag") var int16Flag int16 - flag.RegisterFlag("test-int16", &int16Flag, "A test int16 flag") + flag.Register("test-int16", &int16Flag, "A test int16 flag") var int32Flag int32 - flag.RegisterFlag("test-int32", &int32Flag, "A test int32 flag") + flag.Register("test-int32", &int32Flag, "A test int32 flag") var int64Flag int64 - flag.RegisterFlag("test-int64", &int64Flag, "A test int64 flag") + flag.Register("test-int64", &int64Flag, "A test int64 flag") var uintFlag uint - flag.RegisterFlag("test-uint", &uintFlag, "A test uint flag") + flag.Register("test-uint", &uintFlag, "A test uint flag") var uint8Flag uint8 - flag.RegisterFlag("test-uint8", &uint8Flag, "A test uint8 flag") + flag.Register("test-uint8", &uint8Flag, "A test uint8 flag") var uint16Flag uint16 - flag.RegisterFlag("test-uint16", &uint16Flag, "A test uint16 flag") + flag.Register("test-uint16", &uint16Flag, "A test uint16 flag") var uint32Flag uint32 - flag.RegisterFlag("test-uint32", &uint32Flag, "A test uint32 flag") + flag.Register("test-uint32", &uint32Flag, "A test uint32 flag") var uint64Flag uint64 - flag.RegisterFlag("test-uint64", &uint64Flag, "A test uint64 flag") + flag.Register("test-uint64", &uint64Flag, "A test uint64 flag") var float32Flag float32 - flag.RegisterFlag("test-float32", &float32Flag, "A test float32 flag") + flag.Register("test-float32", &float32Flag, "A test float32 flag") var float64Flag float64 - flag.RegisterFlag("test-float64", &float64Flag, "A test float64 flag") + flag.Register("test-float64", &float64Flag, "A test float64 flag") } func TestRegisterFlagPanics(t *testing.T) { // Test that registering a duplicate flag panics var testFlag string - flag.RegisterFlag("unique-flag", &testFlag, "A unique flag") + flag.Register("unique-flag", &testFlag, "A unique flag") defer func() { if r := recover(); r == nil { t.Error("Expected panic when registering duplicate flag") } }() - flag.RegisterFlag("unique-flag", &testFlag, "A duplicate flag") + flag.Register("unique-flag", &testFlag, "A duplicate flag") } func TestRegisterFlagNonPointer(t *testing.T) { @@ -95,7 +95,7 @@ func TestRegisterFlagNonPointer(t *testing.T) { } }() var testFlag string - flag.RegisterFlag("non-pointer", testFlag, "A non-pointer flag") + flag.Register("non-pointer", testFlag, "A non-pointer flag") } func TestRegisterFlagNilPointer(t *testing.T) { @@ -105,7 +105,7 @@ func TestRegisterFlagNilPointer(t *testing.T) { } }() var testFlag *string - flag.RegisterFlag("nil-pointer", testFlag, "A nil pointer flag") + flag.Register("nil-pointer", testFlag, "A nil pointer flag") } func TestRegisterFlagUnsupportedType(t *testing.T) { @@ -115,7 +115,7 @@ func TestRegisterFlagUnsupportedType(t *testing.T) { } }() var testFlag []string - flag.RegisterFlag("unsupported", &testFlag, "An unsupported type flag") + flag.Register("unsupported", &testFlag, "An unsupported type flag") } func TestInit(_ *testing.T) { @@ -137,16 +137,16 @@ func TestInit(_ *testing.T) { // Test flag registration with default values func TestRegisterFlagWithDefaults(_ *testing.T) { var stringFlag = "default" - flag.RegisterFlag("default-string", &stringFlag, "A string flag with default") + flag.Register("default-string", &stringFlag, "A string flag with default") var intFlag = 42 - flag.RegisterFlag("default-int", &intFlag, "An int flag with default") + flag.Register("default-int", &intFlag, "An int flag with default") var boolFlag = true - flag.RegisterFlag("default-bool", &boolFlag, "A bool flag with default") + flag.Register("default-bool", &boolFlag, "A bool flag with default") var float64Flag = 3.14 - flag.RegisterFlag("default-float64", &float64Flag, "A float64 flag with default") + flag.Register("default-float64", &float64Flag, "A float64 flag with default") } // Test integration with actual command line parsing @@ -160,9 +160,9 @@ func TestCommandLineIntegration(_ *testing.T) { var testInt int var testBool bool - flag.RegisterFlag("integration-string", &testString, "Integration test string") - flag.RegisterFlag("integration-int", &testInt, "Integration test int") - flag.RegisterFlag("integration-bool", &testBool, "Integration test bool") + flag.Register("integration-string", &testString, "Integration test string") + flag.Register("integration-int", &testInt, "Integration test int") + flag.Register("integration-bool", &testBool, "Integration test bool") // Simulate command line arguments os.Args = []string{ @@ -181,7 +181,7 @@ func TestCommandLineIntegration(_ *testing.T) { // Test that flags are properly bound to pflag func TestFlagBinding(_ *testing.T) { var testFlag string - flag.RegisterFlag("binding-test", &testFlag, "A binding test flag") + flag.Register("binding-test", &testFlag, "A binding test flag") // This mainly tests that the function completes without error // Actual binding verification would require more complex setup @@ -197,7 +197,7 @@ func TestOverrideFlag(t *testing.T) { // First register a flag var originalFlag string = "original" - flag.RegisterFlag("override-test", &originalFlag, "Original description") + flag.Register("override-test", &originalFlag, "Original description") // Override it with new variable, value and description var newFlag string = "new_default" @@ -247,7 +247,7 @@ func TestOverrideFlagNonPointer(t *testing.T) { // First register a flag to override var originalFlag string - flag.RegisterFlag("override-non-pointer", &originalFlag, "Original flag") + flag.Register("override-non-pointer", &originalFlag, "Original flag") defer func() { if r := recover(); r == nil { @@ -268,7 +268,7 @@ func TestOverrideFlagNilPointer(t *testing.T) { // First register a flag to override var originalFlag string - flag.RegisterFlag("override-nil-pointer", &originalFlag, "Original flag") + flag.Register("override-nil-pointer", &originalFlag, "Original flag") defer func() { if r := recover(); r == nil { @@ -289,7 +289,7 @@ func TestOverrideFlagAfterParse(t *testing.T) { // Register a flag and parse var originalFlag string - flag.RegisterFlag("override-after-parse", &originalFlag, "Original flag") + flag.Register("override-after-parse", &originalFlag, "Original flag") flag.Init() // This will parse the flags defer func() { @@ -313,7 +313,7 @@ func TestOverrideUnsupportedType(t *testing.T) { // First register a flag to override var originalFlag string - flag.RegisterFlag("override-unsupported", &originalFlag, "Original flag") + flag.Register("override-unsupported", &originalFlag, "Original flag") defer func() { if r := recover(); r == nil { @@ -323,3 +323,110 @@ func TestOverrideUnsupportedType(t *testing.T) { var testSlice []string flag.Override("override-unsupported", &testSlice, "An unsupported type override") } + +func TestUnregisterFlag(t *testing.T) { + // Save and restore the original command line + originalCommandLine := pflag.CommandLine + defer func() { pflag.CommandLine = originalCommandLine }() + + // Create a fresh command line for this test + pflag.CommandLine = pflag.NewFlagSet("", pflag.ContinueOnError) + + // Register a flag to unregister + var testFlag string + flag.Register("unregister-test", &testFlag, "A flag to unregister") + + // Verify the flag exists + if pflag.Lookup("unregister-test") == nil { + t.Error("Expected flag to be registered") + return + } + + // Unregister the flag + flag.Unregister("unregister-test") + + // Verify the flag no longer exists + if pflag.Lookup("unregister-test") != nil { + t.Error("Expected flag to be unregistered") + } +} + +func TestUnregisterFlagPanics(t *testing.T) { + // Save and restore the original command line + originalCommandLine := pflag.CommandLine + defer func() { pflag.CommandLine = originalCommandLine }() + + // Create a fresh command line for this test + pflag.CommandLine = pflag.NewFlagSet("", pflag.ContinueOnError) + + // Test that unregistering a non-existent flag panics + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic when unregistering non-existent flag") + } + }() + flag.Unregister("non-existent-flag") +} + +func TestUnregisterFlagAfterParse(t *testing.T) { + // Save and restore the original command line + originalCommandLine := pflag.CommandLine + defer func() { pflag.CommandLine = originalCommandLine }() + + // Create a fresh command line for this test + pflag.CommandLine = pflag.NewFlagSet("", pflag.ContinueOnError) + + // Register a flag and parse + var testFlag string + flag.Register("unregister-after-parse", &testFlag, "A flag to unregister after parse") + flag.Init() // This will parse the flags + + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic when unregistering after parsing") + } + }() + + // Try to unregister after parsing - should panic + flag.Unregister("unregister-after-parse") +} + +func TestUnregisterMultipleFlags(t *testing.T) { + // Save and restore the original command line + originalCommandLine := pflag.CommandLine + defer func() { pflag.CommandLine = originalCommandLine }() + + // Create a fresh command line for this test + pflag.CommandLine = pflag.NewFlagSet("", pflag.ContinueOnError) + + // Register multiple flags + var flag1, flag2, flag3 string + flag.Register("unregister-multi-1", &flag1, "First flag") + flag.Register("unregister-multi-2", &flag2, "Second flag") + flag.Register("unregister-multi-3", &flag3, "Third flag") + + // Verify all flags exist + if pflag.Lookup("unregister-multi-1") == nil { + t.Error("Expected flag1 to be registered") + } + if pflag.Lookup("unregister-multi-2") == nil { + t.Error("Expected flag2 to be registered") + } + if pflag.Lookup("unregister-multi-3") == nil { + t.Error("Expected flag3 to be registered") + } + + // Unregister the middle flag + flag.Unregister("unregister-multi-2") + + // Verify only the middle flag is unregistered + if pflag.Lookup("unregister-multi-1") == nil { + t.Error("Expected flag1 to still be registered") + } + if pflag.Lookup("unregister-multi-2") != nil { + t.Error("Expected flag2 to be unregistered") + } + if pflag.Lookup("unregister-multi-3") == nil { + t.Error("Expected flag3 to still be registered") + } +}