Skip to content

Commit ca539be

Browse files
[ADD] unregister function to remove registered flags
1 parent 1f2edde commit ca539be

File tree

2 files changed

+171
-36
lines changed

2 files changed

+171
-36
lines changed

flag/flag.go

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
// typically in the `main` function of the application. If the `--help` flag is
1515
// set, it prints usage information and exits.
1616
//
17-
// Additional flags can be registered using `RegisterFlag`, which accepts the flag
17+
// Additional flags can be registered using `Register`, which accepts the flag
1818
// name, a pointer to the variable to populate, and a usage description.
1919
// Existing flags can be overridden using `Override`, which allows changing the
2020
// variable, default value, and description of an already registered flag.
21+
// Flags can be removed using `Unregister`, which removes a previously
22+
// registered flag from the command line.
2123
// Supported types include strings, booleans, integers, unsigned integers, and floats.
2224
//
2325
// Example:
@@ -32,12 +34,15 @@
3234
// var CustomFlag string
3335
//
3436
// func main() {
35-
// flag.RegisterFlag("custom", &CustomFlag, "A custom flag for demonstration")
37+
// flag.Register("custom", &CustomFlag, "A custom flag for demonstration")
3638
//
3739
// // Override the default path flag
3840
// flag.Path = "/new/default/path"
3941
// flag.Override("path", &flag.Path, "Updated application working directory")
4042
//
43+
// // Unregister a flag if no longer needed
44+
// flag.Unregister("custom")
45+
//
4146
// flag.Init()
4247
//
4348
// fmt.Println("Custom Flag Value:", CustomFlag)
@@ -83,9 +88,9 @@ func PrintHelp() {
8388
pflag.PrintDefaults()
8489
}
8590

86-
// RegisterFlag registers a new flag with the given name, value and usage
91+
// Register registers a new flag with the given name, value and usage
8792
// It panics if the flag is already registered or if the value is not a pointer
88-
func RegisterFlag(name string, value interface{}, usage string) {
93+
func Register(name string, value interface{}, usage string) {
8994
if pflag.Lookup(name) != nil {
9095
panic(fmt.Sprintf("flag %s already registered", name))
9196
}
@@ -198,3 +203,26 @@ func Override(name string, value interface{}, usage string) {
198203

199204
pflag.CommandLine = newCommandLine
200205
}
206+
207+
// Unregister removes a previously registered flag
208+
// It panics if the flag is not registered or if flags have already been parsed
209+
func Unregister(name string) {
210+
if pflag.Lookup(name) == nil {
211+
panic(fmt.Sprintf("flag %s is not registered", name))
212+
}
213+
214+
if pflag.Parsed() {
215+
panic(fmt.Sprintf("cannot unregister flag %s after flags have been parsed", name))
216+
}
217+
218+
newCommandLine := pflag.NewFlagSet("", pflag.ContinueOnError)
219+
220+
// Copy all flags except the one we're unregistering
221+
pflag.CommandLine.VisitAll(func(flag *pflag.Flag) {
222+
if flag.Name != name {
223+
newCommandLine.AddFlag(flag)
224+
}
225+
})
226+
227+
pflag.CommandLine = newCommandLine
228+
}

flag/flag_test.go

Lines changed: 139 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -30,62 +30,62 @@ func TestDefaultFlags(t *testing.T) {
3030
func TestRegisterFlag(_ *testing.T) {
3131
// Test registering a string flag
3232
var stringFlag string
33-
flag.RegisterFlag("test-string", &stringFlag, "A test string flag")
33+
flag.Register("test-string", &stringFlag, "A test string flag")
3434

3535
// Test registering a bool flag
3636
var boolFlag bool
37-
flag.RegisterFlag("test-bool", &boolFlag, "A test bool flag")
37+
flag.Register("test-bool", &boolFlag, "A test bool flag")
3838

3939
// Test registering an int flag
4040
var intFlag int
41-
flag.RegisterFlag("test-int", &intFlag, "A test int flag")
41+
flag.Register("test-int", &intFlag, "A test int flag")
4242

4343
// Test registering various numeric types
4444
var int8Flag int8
45-
flag.RegisterFlag("test-int8", &int8Flag, "A test int8 flag")
45+
flag.Register("test-int8", &int8Flag, "A test int8 flag")
4646

4747
var int16Flag int16
48-
flag.RegisterFlag("test-int16", &int16Flag, "A test int16 flag")
48+
flag.Register("test-int16", &int16Flag, "A test int16 flag")
4949

5050
var int32Flag int32
51-
flag.RegisterFlag("test-int32", &int32Flag, "A test int32 flag")
51+
flag.Register("test-int32", &int32Flag, "A test int32 flag")
5252

5353
var int64Flag int64
54-
flag.RegisterFlag("test-int64", &int64Flag, "A test int64 flag")
54+
flag.Register("test-int64", &int64Flag, "A test int64 flag")
5555

5656
var uintFlag uint
57-
flag.RegisterFlag("test-uint", &uintFlag, "A test uint flag")
57+
flag.Register("test-uint", &uintFlag, "A test uint flag")
5858

5959
var uint8Flag uint8
60-
flag.RegisterFlag("test-uint8", &uint8Flag, "A test uint8 flag")
60+
flag.Register("test-uint8", &uint8Flag, "A test uint8 flag")
6161

6262
var uint16Flag uint16
63-
flag.RegisterFlag("test-uint16", &uint16Flag, "A test uint16 flag")
63+
flag.Register("test-uint16", &uint16Flag, "A test uint16 flag")
6464

6565
var uint32Flag uint32
66-
flag.RegisterFlag("test-uint32", &uint32Flag, "A test uint32 flag")
66+
flag.Register("test-uint32", &uint32Flag, "A test uint32 flag")
6767

6868
var uint64Flag uint64
69-
flag.RegisterFlag("test-uint64", &uint64Flag, "A test uint64 flag")
69+
flag.Register("test-uint64", &uint64Flag, "A test uint64 flag")
7070

7171
var float32Flag float32
72-
flag.RegisterFlag("test-float32", &float32Flag, "A test float32 flag")
72+
flag.Register("test-float32", &float32Flag, "A test float32 flag")
7373

7474
var float64Flag float64
75-
flag.RegisterFlag("test-float64", &float64Flag, "A test float64 flag")
75+
flag.Register("test-float64", &float64Flag, "A test float64 flag")
7676
}
7777

7878
func TestRegisterFlagPanics(t *testing.T) {
7979
// Test that registering a duplicate flag panics
8080
var testFlag string
81-
flag.RegisterFlag("unique-flag", &testFlag, "A unique flag")
81+
flag.Register("unique-flag", &testFlag, "A unique flag")
8282

8383
defer func() {
8484
if r := recover(); r == nil {
8585
t.Error("Expected panic when registering duplicate flag")
8686
}
8787
}()
88-
flag.RegisterFlag("unique-flag", &testFlag, "A duplicate flag")
88+
flag.Register("unique-flag", &testFlag, "A duplicate flag")
8989
}
9090

9191
func TestRegisterFlagNonPointer(t *testing.T) {
@@ -95,7 +95,7 @@ func TestRegisterFlagNonPointer(t *testing.T) {
9595
}
9696
}()
9797
var testFlag string
98-
flag.RegisterFlag("non-pointer", testFlag, "A non-pointer flag")
98+
flag.Register("non-pointer", testFlag, "A non-pointer flag")
9999
}
100100

101101
func TestRegisterFlagNilPointer(t *testing.T) {
@@ -105,7 +105,7 @@ func TestRegisterFlagNilPointer(t *testing.T) {
105105
}
106106
}()
107107
var testFlag *string
108-
flag.RegisterFlag("nil-pointer", testFlag, "A nil pointer flag")
108+
flag.Register("nil-pointer", testFlag, "A nil pointer flag")
109109
}
110110

111111
func TestRegisterFlagUnsupportedType(t *testing.T) {
@@ -115,7 +115,7 @@ func TestRegisterFlagUnsupportedType(t *testing.T) {
115115
}
116116
}()
117117
var testFlag []string
118-
flag.RegisterFlag("unsupported", &testFlag, "An unsupported type flag")
118+
flag.Register("unsupported", &testFlag, "An unsupported type flag")
119119
}
120120

121121
func TestInit(_ *testing.T) {
@@ -137,16 +137,16 @@ func TestInit(_ *testing.T) {
137137
// Test flag registration with default values
138138
func TestRegisterFlagWithDefaults(_ *testing.T) {
139139
var stringFlag = "default"
140-
flag.RegisterFlag("default-string", &stringFlag, "A string flag with default")
140+
flag.Register("default-string", &stringFlag, "A string flag with default")
141141

142142
var intFlag = 42
143-
flag.RegisterFlag("default-int", &intFlag, "An int flag with default")
143+
flag.Register("default-int", &intFlag, "An int flag with default")
144144

145145
var boolFlag = true
146-
flag.RegisterFlag("default-bool", &boolFlag, "A bool flag with default")
146+
flag.Register("default-bool", &boolFlag, "A bool flag with default")
147147

148148
var float64Flag = 3.14
149-
flag.RegisterFlag("default-float64", &float64Flag, "A float64 flag with default")
149+
flag.Register("default-float64", &float64Flag, "A float64 flag with default")
150150
}
151151

152152
// Test integration with actual command line parsing
@@ -160,9 +160,9 @@ func TestCommandLineIntegration(_ *testing.T) {
160160
var testInt int
161161
var testBool bool
162162

163-
flag.RegisterFlag("integration-string", &testString, "Integration test string")
164-
flag.RegisterFlag("integration-int", &testInt, "Integration test int")
165-
flag.RegisterFlag("integration-bool", &testBool, "Integration test bool")
163+
flag.Register("integration-string", &testString, "Integration test string")
164+
flag.Register("integration-int", &testInt, "Integration test int")
165+
flag.Register("integration-bool", &testBool, "Integration test bool")
166166

167167
// Simulate command line arguments
168168
os.Args = []string{
@@ -181,7 +181,7 @@ func TestCommandLineIntegration(_ *testing.T) {
181181
// Test that flags are properly bound to pflag
182182
func TestFlagBinding(_ *testing.T) {
183183
var testFlag string
184-
flag.RegisterFlag("binding-test", &testFlag, "A binding test flag")
184+
flag.Register("binding-test", &testFlag, "A binding test flag")
185185

186186
// This mainly tests that the function completes without error
187187
// Actual binding verification would require more complex setup
@@ -197,7 +197,7 @@ func TestOverrideFlag(t *testing.T) {
197197

198198
// First register a flag
199199
var originalFlag string = "original"
200-
flag.RegisterFlag("override-test", &originalFlag, "Original description")
200+
flag.Register("override-test", &originalFlag, "Original description")
201201

202202
// Override it with new variable, value and description
203203
var newFlag string = "new_default"
@@ -247,7 +247,7 @@ func TestOverrideFlagNonPointer(t *testing.T) {
247247

248248
// First register a flag to override
249249
var originalFlag string
250-
flag.RegisterFlag("override-non-pointer", &originalFlag, "Original flag")
250+
flag.Register("override-non-pointer", &originalFlag, "Original flag")
251251

252252
defer func() {
253253
if r := recover(); r == nil {
@@ -268,7 +268,7 @@ func TestOverrideFlagNilPointer(t *testing.T) {
268268

269269
// First register a flag to override
270270
var originalFlag string
271-
flag.RegisterFlag("override-nil-pointer", &originalFlag, "Original flag")
271+
flag.Register("override-nil-pointer", &originalFlag, "Original flag")
272272

273273
defer func() {
274274
if r := recover(); r == nil {
@@ -289,7 +289,7 @@ func TestOverrideFlagAfterParse(t *testing.T) {
289289

290290
// Register a flag and parse
291291
var originalFlag string
292-
flag.RegisterFlag("override-after-parse", &originalFlag, "Original flag")
292+
flag.Register("override-after-parse", &originalFlag, "Original flag")
293293
flag.Init() // This will parse the flags
294294

295295
defer func() {
@@ -313,7 +313,7 @@ func TestOverrideUnsupportedType(t *testing.T) {
313313

314314
// First register a flag to override
315315
var originalFlag string
316-
flag.RegisterFlag("override-unsupported", &originalFlag, "Original flag")
316+
flag.Register("override-unsupported", &originalFlag, "Original flag")
317317

318318
defer func() {
319319
if r := recover(); r == nil {
@@ -323,3 +323,110 @@ func TestOverrideUnsupportedType(t *testing.T) {
323323
var testSlice []string
324324
flag.Override("override-unsupported", &testSlice, "An unsupported type override")
325325
}
326+
327+
func TestUnregisterFlag(t *testing.T) {
328+
// Save and restore the original command line
329+
originalCommandLine := pflag.CommandLine
330+
defer func() { pflag.CommandLine = originalCommandLine }()
331+
332+
// Create a fresh command line for this test
333+
pflag.CommandLine = pflag.NewFlagSet("", pflag.ContinueOnError)
334+
335+
// Register a flag to unregister
336+
var testFlag string
337+
flag.Register("unregister-test", &testFlag, "A flag to unregister")
338+
339+
// Verify the flag exists
340+
if pflag.Lookup("unregister-test") == nil {
341+
t.Error("Expected flag to be registered")
342+
return
343+
}
344+
345+
// Unregister the flag
346+
flag.Unregister("unregister-test")
347+
348+
// Verify the flag no longer exists
349+
if pflag.Lookup("unregister-test") != nil {
350+
t.Error("Expected flag to be unregistered")
351+
}
352+
}
353+
354+
func TestUnregisterFlagPanics(t *testing.T) {
355+
// Save and restore the original command line
356+
originalCommandLine := pflag.CommandLine
357+
defer func() { pflag.CommandLine = originalCommandLine }()
358+
359+
// Create a fresh command line for this test
360+
pflag.CommandLine = pflag.NewFlagSet("", pflag.ContinueOnError)
361+
362+
// Test that unregistering a non-existent flag panics
363+
defer func() {
364+
if r := recover(); r == nil {
365+
t.Error("Expected panic when unregistering non-existent flag")
366+
}
367+
}()
368+
flag.Unregister("non-existent-flag")
369+
}
370+
371+
func TestUnregisterFlagAfterParse(t *testing.T) {
372+
// Save and restore the original command line
373+
originalCommandLine := pflag.CommandLine
374+
defer func() { pflag.CommandLine = originalCommandLine }()
375+
376+
// Create a fresh command line for this test
377+
pflag.CommandLine = pflag.NewFlagSet("", pflag.ContinueOnError)
378+
379+
// Register a flag and parse
380+
var testFlag string
381+
flag.Register("unregister-after-parse", &testFlag, "A flag to unregister after parse")
382+
flag.Init() // This will parse the flags
383+
384+
defer func() {
385+
if r := recover(); r == nil {
386+
t.Error("Expected panic when unregistering after parsing")
387+
}
388+
}()
389+
390+
// Try to unregister after parsing - should panic
391+
flag.Unregister("unregister-after-parse")
392+
}
393+
394+
func TestUnregisterMultipleFlags(t *testing.T) {
395+
// Save and restore the original command line
396+
originalCommandLine := pflag.CommandLine
397+
defer func() { pflag.CommandLine = originalCommandLine }()
398+
399+
// Create a fresh command line for this test
400+
pflag.CommandLine = pflag.NewFlagSet("", pflag.ContinueOnError)
401+
402+
// Register multiple flags
403+
var flag1, flag2, flag3 string
404+
flag.Register("unregister-multi-1", &flag1, "First flag")
405+
flag.Register("unregister-multi-2", &flag2, "Second flag")
406+
flag.Register("unregister-multi-3", &flag3, "Third flag")
407+
408+
// Verify all flags exist
409+
if pflag.Lookup("unregister-multi-1") == nil {
410+
t.Error("Expected flag1 to be registered")
411+
}
412+
if pflag.Lookup("unregister-multi-2") == nil {
413+
t.Error("Expected flag2 to be registered")
414+
}
415+
if pflag.Lookup("unregister-multi-3") == nil {
416+
t.Error("Expected flag3 to be registered")
417+
}
418+
419+
// Unregister the middle flag
420+
flag.Unregister("unregister-multi-2")
421+
422+
// Verify only the middle flag is unregistered
423+
if pflag.Lookup("unregister-multi-1") == nil {
424+
t.Error("Expected flag1 to still be registered")
425+
}
426+
if pflag.Lookup("unregister-multi-2") != nil {
427+
t.Error("Expected flag2 to be unregistered")
428+
}
429+
if pflag.Lookup("unregister-multi-3") == nil {
430+
t.Error("Expected flag3 to still be registered")
431+
}
432+
}

0 commit comments

Comments
 (0)