Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions flag/flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
// typically in the `main` function of the application. If the `--help` flag is
// set, it prints usage information and exits.
//
// Additional flags can be registered using `RegisterFlag`, which accepts the flag
// Additional flags can be registered using `Register`, which accepts the flag
// name, a pointer to the variable to populate, and a usage description.
// Existing flags can be overridden using `Override`, which allows changing the
// variable, default value, and description of an already registered flag.
// Flags can be removed using `Unregister`, which removes a previously
// registered flag from the command line.
// Supported types include strings, booleans, integers, unsigned integers, and floats.
//
// Example:
Expand All @@ -32,12 +34,15 @@
// var CustomFlag string
//
// func main() {
// flag.RegisterFlag("custom", &CustomFlag, "A custom flag for demonstration")
// flag.Register("custom", &CustomFlag, "A custom flag for demonstration")
//
// // Override the default path flag
// flag.Path = "/new/default/path"
// flag.Override("path", &flag.Path, "Updated application working directory")
//
// // Unregister a flag if no longer needed
// flag.Unregister("custom")
//
// flag.Init()
//
// fmt.Println("Custom Flag Value:", CustomFlag)
Expand Down Expand Up @@ -83,9 +88,9 @@ func PrintHelp() {
pflag.PrintDefaults()
}

// RegisterFlag registers a new flag with the given name, value and usage
// Register registers a new flag with the given name, value and usage
// It panics if the flag is already registered or if the value is not a pointer
func RegisterFlag(name string, value interface{}, usage string) {
func Register(name string, value interface{}, usage string) {
if pflag.Lookup(name) != nil {
panic(fmt.Sprintf("flag %s already registered", name))
}
Expand Down Expand Up @@ -198,3 +203,26 @@ func Override(name string, value interface{}, usage string) {

pflag.CommandLine = newCommandLine
}

// Unregister removes a previously registered flag
// It panics if the flag is not registered or if flags have already been parsed
func Unregister(name string) {
if pflag.Lookup(name) == nil {
panic(fmt.Sprintf("flag %s is not registered", name))
}

if pflag.Parsed() {
panic(fmt.Sprintf("cannot unregister flag %s after flags have been parsed", name))
}

newCommandLine := pflag.NewFlagSet("", pflag.ContinueOnError)

// Copy all flags except the one we're unregistering
pflag.CommandLine.VisitAll(func(flag *pflag.Flag) {
if flag.Name != name {
newCommandLine.AddFlag(flag)
}
})

pflag.CommandLine = newCommandLine
}
171 changes: 139 additions & 32 deletions flag/flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,62 +30,62 @@ func TestDefaultFlags(t *testing.T) {
func TestRegisterFlag(_ *testing.T) {
// Test registering a string flag
var stringFlag string
flag.RegisterFlag("test-string", &stringFlag, "A test string flag")
flag.Register("test-string", &stringFlag, "A test string flag")

// Test registering a bool flag
var boolFlag bool
flag.RegisterFlag("test-bool", &boolFlag, "A test bool flag")
flag.Register("test-bool", &boolFlag, "A test bool flag")

// Test registering an int flag
var intFlag int
flag.RegisterFlag("test-int", &intFlag, "A test int flag")
flag.Register("test-int", &intFlag, "A test int flag")

// Test registering various numeric types
var int8Flag int8
flag.RegisterFlag("test-int8", &int8Flag, "A test int8 flag")
flag.Register("test-int8", &int8Flag, "A test int8 flag")

var int16Flag int16
flag.RegisterFlag("test-int16", &int16Flag, "A test int16 flag")
flag.Register("test-int16", &int16Flag, "A test int16 flag")

var int32Flag int32
flag.RegisterFlag("test-int32", &int32Flag, "A test int32 flag")
flag.Register("test-int32", &int32Flag, "A test int32 flag")

var int64Flag int64
flag.RegisterFlag("test-int64", &int64Flag, "A test int64 flag")
flag.Register("test-int64", &int64Flag, "A test int64 flag")

var uintFlag uint
flag.RegisterFlag("test-uint", &uintFlag, "A test uint flag")
flag.Register("test-uint", &uintFlag, "A test uint flag")

var uint8Flag uint8
flag.RegisterFlag("test-uint8", &uint8Flag, "A test uint8 flag")
flag.Register("test-uint8", &uint8Flag, "A test uint8 flag")

var uint16Flag uint16
flag.RegisterFlag("test-uint16", &uint16Flag, "A test uint16 flag")
flag.Register("test-uint16", &uint16Flag, "A test uint16 flag")

var uint32Flag uint32
flag.RegisterFlag("test-uint32", &uint32Flag, "A test uint32 flag")
flag.Register("test-uint32", &uint32Flag, "A test uint32 flag")

var uint64Flag uint64
flag.RegisterFlag("test-uint64", &uint64Flag, "A test uint64 flag")
flag.Register("test-uint64", &uint64Flag, "A test uint64 flag")

var float32Flag float32
flag.RegisterFlag("test-float32", &float32Flag, "A test float32 flag")
flag.Register("test-float32", &float32Flag, "A test float32 flag")

var float64Flag float64
flag.RegisterFlag("test-float64", &float64Flag, "A test float64 flag")
flag.Register("test-float64", &float64Flag, "A test float64 flag")
}

func TestRegisterFlagPanics(t *testing.T) {
// Test that registering a duplicate flag panics
var testFlag string
flag.RegisterFlag("unique-flag", &testFlag, "A unique flag")
flag.Register("unique-flag", &testFlag, "A unique flag")

defer func() {
if r := recover(); r == nil {
t.Error("Expected panic when registering duplicate flag")
}
}()
flag.RegisterFlag("unique-flag", &testFlag, "A duplicate flag")
flag.Register("unique-flag", &testFlag, "A duplicate flag")
}

func TestRegisterFlagNonPointer(t *testing.T) {
Expand All @@ -95,7 +95,7 @@ func TestRegisterFlagNonPointer(t *testing.T) {
}
}()
var testFlag string
flag.RegisterFlag("non-pointer", testFlag, "A non-pointer flag")
flag.Register("non-pointer", testFlag, "A non-pointer flag")
}

func TestRegisterFlagNilPointer(t *testing.T) {
Expand All @@ -105,7 +105,7 @@ func TestRegisterFlagNilPointer(t *testing.T) {
}
}()
var testFlag *string
flag.RegisterFlag("nil-pointer", testFlag, "A nil pointer flag")
flag.Register("nil-pointer", testFlag, "A nil pointer flag")
}

func TestRegisterFlagUnsupportedType(t *testing.T) {
Expand All @@ -115,7 +115,7 @@ func TestRegisterFlagUnsupportedType(t *testing.T) {
}
}()
var testFlag []string
flag.RegisterFlag("unsupported", &testFlag, "An unsupported type flag")
flag.Register("unsupported", &testFlag, "An unsupported type flag")
}

func TestInit(_ *testing.T) {
Expand All @@ -137,16 +137,16 @@ func TestInit(_ *testing.T) {
// Test flag registration with default values
func TestRegisterFlagWithDefaults(_ *testing.T) {
var stringFlag = "default"
flag.RegisterFlag("default-string", &stringFlag, "A string flag with default")
flag.Register("default-string", &stringFlag, "A string flag with default")

var intFlag = 42
flag.RegisterFlag("default-int", &intFlag, "An int flag with default")
flag.Register("default-int", &intFlag, "An int flag with default")

var boolFlag = true
flag.RegisterFlag("default-bool", &boolFlag, "A bool flag with default")
flag.Register("default-bool", &boolFlag, "A bool flag with default")

var float64Flag = 3.14
flag.RegisterFlag("default-float64", &float64Flag, "A float64 flag with default")
flag.Register("default-float64", &float64Flag, "A float64 flag with default")
}

// Test integration with actual command line parsing
Expand All @@ -160,9 +160,9 @@ func TestCommandLineIntegration(_ *testing.T) {
var testInt int
var testBool bool

flag.RegisterFlag("integration-string", &testString, "Integration test string")
flag.RegisterFlag("integration-int", &testInt, "Integration test int")
flag.RegisterFlag("integration-bool", &testBool, "Integration test bool")
flag.Register("integration-string", &testString, "Integration test string")
flag.Register("integration-int", &testInt, "Integration test int")
flag.Register("integration-bool", &testBool, "Integration test bool")

// Simulate command line arguments
os.Args = []string{
Expand All @@ -181,7 +181,7 @@ func TestCommandLineIntegration(_ *testing.T) {
// Test that flags are properly bound to pflag
func TestFlagBinding(_ *testing.T) {
var testFlag string
flag.RegisterFlag("binding-test", &testFlag, "A binding test flag")
flag.Register("binding-test", &testFlag, "A binding test flag")

// This mainly tests that the function completes without error
// Actual binding verification would require more complex setup
Expand All @@ -197,7 +197,7 @@ func TestOverrideFlag(t *testing.T) {

// First register a flag
var originalFlag string = "original"
flag.RegisterFlag("override-test", &originalFlag, "Original description")
flag.Register("override-test", &originalFlag, "Original description")

// Override it with new variable, value and description
var newFlag string = "new_default"
Expand Down Expand Up @@ -247,7 +247,7 @@ func TestOverrideFlagNonPointer(t *testing.T) {

// First register a flag to override
var originalFlag string
flag.RegisterFlag("override-non-pointer", &originalFlag, "Original flag")
flag.Register("override-non-pointer", &originalFlag, "Original flag")

defer func() {
if r := recover(); r == nil {
Expand All @@ -268,7 +268,7 @@ func TestOverrideFlagNilPointer(t *testing.T) {

// First register a flag to override
var originalFlag string
flag.RegisterFlag("override-nil-pointer", &originalFlag, "Original flag")
flag.Register("override-nil-pointer", &originalFlag, "Original flag")

defer func() {
if r := recover(); r == nil {
Expand All @@ -289,7 +289,7 @@ func TestOverrideFlagAfterParse(t *testing.T) {

// Register a flag and parse
var originalFlag string
flag.RegisterFlag("override-after-parse", &originalFlag, "Original flag")
flag.Register("override-after-parse", &originalFlag, "Original flag")
flag.Init() // This will parse the flags

defer func() {
Expand All @@ -313,7 +313,7 @@ func TestOverrideUnsupportedType(t *testing.T) {

// First register a flag to override
var originalFlag string
flag.RegisterFlag("override-unsupported", &originalFlag, "Original flag")
flag.Register("override-unsupported", &originalFlag, "Original flag")

defer func() {
if r := recover(); r == nil {
Expand All @@ -323,3 +323,110 @@ func TestOverrideUnsupportedType(t *testing.T) {
var testSlice []string
flag.Override("override-unsupported", &testSlice, "An unsupported type override")
}

func TestUnregisterFlag(t *testing.T) {
// Save and restore the original command line
originalCommandLine := pflag.CommandLine
defer func() { pflag.CommandLine = originalCommandLine }()

// Create a fresh command line for this test
pflag.CommandLine = pflag.NewFlagSet("", pflag.ContinueOnError)

// Register a flag to unregister
var testFlag string
flag.Register("unregister-test", &testFlag, "A flag to unregister")

// Verify the flag exists
if pflag.Lookup("unregister-test") == nil {
t.Error("Expected flag to be registered")
return
}

// Unregister the flag
flag.Unregister("unregister-test")

// Verify the flag no longer exists
if pflag.Lookup("unregister-test") != nil {
t.Error("Expected flag to be unregistered")
}
}

func TestUnregisterFlagPanics(t *testing.T) {
// Save and restore the original command line
originalCommandLine := pflag.CommandLine
defer func() { pflag.CommandLine = originalCommandLine }()

// Create a fresh command line for this test
pflag.CommandLine = pflag.NewFlagSet("", pflag.ContinueOnError)

// Test that unregistering a non-existent flag panics
defer func() {
if r := recover(); r == nil {
t.Error("Expected panic when unregistering non-existent flag")
}
}()
flag.Unregister("non-existent-flag")
}

func TestUnregisterFlagAfterParse(t *testing.T) {
// Save and restore the original command line
originalCommandLine := pflag.CommandLine
defer func() { pflag.CommandLine = originalCommandLine }()

// Create a fresh command line for this test
pflag.CommandLine = pflag.NewFlagSet("", pflag.ContinueOnError)

// Register a flag and parse
var testFlag string
flag.Register("unregister-after-parse", &testFlag, "A flag to unregister after parse")
flag.Init() // This will parse the flags

defer func() {
if r := recover(); r == nil {
t.Error("Expected panic when unregistering after parsing")
}
}()

// Try to unregister after parsing - should panic
flag.Unregister("unregister-after-parse")
}

func TestUnregisterMultipleFlags(t *testing.T) {
// Save and restore the original command line
originalCommandLine := pflag.CommandLine
defer func() { pflag.CommandLine = originalCommandLine }()

// Create a fresh command line for this test
pflag.CommandLine = pflag.NewFlagSet("", pflag.ContinueOnError)

// Register multiple flags
var flag1, flag2, flag3 string
flag.Register("unregister-multi-1", &flag1, "First flag")
flag.Register("unregister-multi-2", &flag2, "Second flag")
flag.Register("unregister-multi-3", &flag3, "Third flag")

// Verify all flags exist
if pflag.Lookup("unregister-multi-1") == nil {
t.Error("Expected flag1 to be registered")
}
if pflag.Lookup("unregister-multi-2") == nil {
t.Error("Expected flag2 to be registered")
}
if pflag.Lookup("unregister-multi-3") == nil {
t.Error("Expected flag3 to be registered")
}

// Unregister the middle flag
flag.Unregister("unregister-multi-2")

// Verify only the middle flag is unregistered
if pflag.Lookup("unregister-multi-1") == nil {
t.Error("Expected flag1 to still be registered")
}
if pflag.Lookup("unregister-multi-2") != nil {
t.Error("Expected flag2 to be unregistered")
}
if pflag.Lookup("unregister-multi-3") == nil {
t.Error("Expected flag3 to still be registered")
}
}