Skip to content

Commit 8e73011

Browse files
[ADD] flag override functionalit
1 parent 156fcff commit 8e73011

File tree

2 files changed

+216
-2
lines changed

2 files changed

+216
-2
lines changed

flag/flag.go

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
//
1717
// Additional flags can be registered using `RegisterFlag`, which accepts the flag
1818
// name, a pointer to the variable to populate, and a usage description.
19+
// Existing flags can be overridden using `Override`, which allows changing the
20+
// variable, default value, and description of an already registered flag.
1921
// Supported types include strings, booleans, integers, unsigned integers, and floats.
2022
//
2123
// Example:
@@ -31,14 +33,21 @@
3133
//
3234
// func main() {
3335
// flag.RegisterFlag("custom", &CustomFlag, "A custom flag for demonstration")
36+
//
37+
// // Override the default path flag
38+
// flag.Path = "/new/default/path"
39+
// flag.Override("path", &flag.Path, "Updated application working directory")
40+
//
3441
// flag.Init()
3542
//
3643
// fmt.Println("Custom Flag Value:", CustomFlag)
44+
// fmt.Println("Path:", flag.Path)
3745
// }
3846
package flag
3947

4048
import (
4149
"fmt"
50+
"os"
4251
"reflect"
4352

4453
"github.com/spf13/pflag"
@@ -56,7 +65,7 @@ var (
5665
)
5766

5867
func init() {
59-
pflag.StringVar(&Path, "path", "./", "Sets the application working directory")
68+
pflag.StringVar(&Path, "path", "./data", "Sets the application working directory")
6069
pflag.BoolVar(&Help, "help", false, "Prints the help page")
6170
pflag.BoolVar(&Version, "version", false, "Prints the software version")
6271
pflag.BoolVar(&Debug, "debug", false, "Enables debug mode")
@@ -68,8 +77,9 @@ func Init() {
6877
pflag.Parse()
6978
}
7079

80+
// PrintHelp prints the help message to standard error output
7181
func PrintHelp() {
72-
fmt.Println("Usage:")
82+
fmt.Fprintln(os.Stderr, "Usage:")
7383
pflag.PrintDefaults()
7484
}
7585

@@ -122,3 +132,69 @@ func RegisterFlag(name string, value interface{}, usage string) {
122132
panic(fmt.Sprintf("unsupported type %T", v))
123133
}
124134
}
135+
136+
// Override allows changing an existing flag's variable, default value and description
137+
// It panics if the flag is not already registered or if the value is not a pointer
138+
// Note: The flag must not have been parsed yet for this to work properly
139+
func Override(name string, value interface{}, usage string) {
140+
if pflag.Lookup(name) == nil {
141+
panic(fmt.Sprintf("flag %s is not registered", name))
142+
}
143+
144+
val := reflect.ValueOf(value)
145+
if val.Kind() != reflect.Ptr {
146+
panic(fmt.Sprintf("flag %s value must be a pointer", name))
147+
}
148+
149+
if val.IsNil() {
150+
panic(fmt.Sprintf("flag %s value must not be nil", name))
151+
}
152+
153+
if pflag.Parsed() {
154+
panic(fmt.Sprintf("cannot override flag %s after flags have been parsed", name))
155+
}
156+
157+
newCommandLine := pflag.NewFlagSet("", pflag.ContinueOnError)
158+
159+
// Copy all flags except the one we're overriding
160+
pflag.CommandLine.VisitAll(func(flag *pflag.Flag) {
161+
if flag.Name != name {
162+
newCommandLine.AddFlag(flag)
163+
}
164+
})
165+
166+
switch v := value.(type) {
167+
case *string:
168+
newCommandLine.StringVar(v, name, *v, usage)
169+
case *bool:
170+
newCommandLine.BoolVar(v, name, *v, usage)
171+
case *int:
172+
newCommandLine.IntVar(v, name, *v, usage)
173+
case *int8:
174+
newCommandLine.Int8Var(v, name, *v, usage)
175+
case *int16:
176+
newCommandLine.Int16Var(v, name, *v, usage)
177+
case *int32:
178+
newCommandLine.Int32Var(v, name, *v, usage)
179+
case *int64:
180+
newCommandLine.Int64Var(v, name, *v, usage)
181+
case *uint:
182+
newCommandLine.UintVar(v, name, *v, usage)
183+
case *uint8:
184+
newCommandLine.Uint8Var(v, name, *v, usage)
185+
case *uint16:
186+
newCommandLine.Uint16Var(v, name, *v, usage)
187+
case *uint32:
188+
newCommandLine.Uint32Var(v, name, *v, usage)
189+
case *uint64:
190+
newCommandLine.Uint64Var(v, name, *v, usage)
191+
case *float32:
192+
newCommandLine.Float32Var(v, name, *v, usage)
193+
case *float64:
194+
newCommandLine.Float64Var(v, name, *v, usage)
195+
default:
196+
panic(fmt.Sprintf("unsupported type %T", v))
197+
}
198+
199+
pflag.CommandLine = newCommandLine
200+
}

flag/flag_test.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"os"
55
"testing"
66

7+
"github.com/spf13/pflag"
78
"github.com/valentin-kaiser/go-core/flag"
89
)
910

@@ -185,3 +186,140 @@ func TestFlagBinding(_ *testing.T) {
185186
// This mainly tests that the function completes without error
186187
// Actual binding verification would require more complex setup
187188
}
189+
190+
func TestOverrideFlag(t *testing.T) {
191+
// Save and restore the original command line
192+
originalCommandLine := pflag.CommandLine
193+
defer func() { pflag.CommandLine = originalCommandLine }()
194+
195+
// Create a fresh command line for this test
196+
pflag.CommandLine = pflag.NewFlagSet("", pflag.ContinueOnError)
197+
198+
// First register a flag
199+
var originalFlag string = "original"
200+
flag.RegisterFlag("override-test", &originalFlag, "Original description")
201+
202+
// Override it with new variable, value and description
203+
var newFlag string = "new_default"
204+
flag.Override("override-test", &newFlag, "New description")
205+
206+
// Test that the flag was overridden
207+
overriddenFlag := pflag.Lookup("override-test")
208+
if overriddenFlag == nil {
209+
t.Error("Expected overridden flag to exist")
210+
return
211+
}
212+
213+
if overriddenFlag.Usage != "New description" {
214+
t.Errorf("Expected usage to be 'New description', got '%s'", overriddenFlag.Usage)
215+
}
216+
217+
if overriddenFlag.DefValue != "new_default" {
218+
t.Errorf("Expected default value to be 'new_default', got '%s'", overriddenFlag.DefValue)
219+
}
220+
}
221+
222+
func TestOverrideFlagPanics(t *testing.T) {
223+
// Save and restore the original command line
224+
originalCommandLine := pflag.CommandLine
225+
defer func() { pflag.CommandLine = originalCommandLine }()
226+
227+
// Create a fresh command line for this test
228+
pflag.CommandLine = pflag.NewFlagSet("", pflag.ContinueOnError)
229+
230+
// Test that overriding a non-existent flag panics
231+
defer func() {
232+
if r := recover(); r == nil {
233+
t.Error("Expected panic when overriding non-existent flag")
234+
}
235+
}()
236+
var testFlag string
237+
flag.Override("non-existent-flag", &testFlag, "A non-existent flag")
238+
}
239+
240+
func TestOverrideFlagNonPointer(t *testing.T) {
241+
// Save and restore the original command line
242+
originalCommandLine := pflag.CommandLine
243+
defer func() { pflag.CommandLine = originalCommandLine }()
244+
245+
// Create a fresh command line for this test
246+
pflag.CommandLine = pflag.NewFlagSet("", pflag.ContinueOnError)
247+
248+
// First register a flag to override
249+
var originalFlag string
250+
flag.RegisterFlag("override-non-pointer", &originalFlag, "Original flag")
251+
252+
defer func() {
253+
if r := recover(); r == nil {
254+
t.Error("Expected panic when overriding with non-pointer value")
255+
}
256+
}()
257+
var testFlag string
258+
flag.Override("override-non-pointer", testFlag, "A non-pointer override")
259+
}
260+
261+
func TestOverrideFlagNilPointer(t *testing.T) {
262+
// Save and restore the original command line
263+
originalCommandLine := pflag.CommandLine
264+
defer func() { pflag.CommandLine = originalCommandLine }()
265+
266+
// Create a fresh command line for this test
267+
pflag.CommandLine = pflag.NewFlagSet("", pflag.ContinueOnError)
268+
269+
// First register a flag to override
270+
var originalFlag string
271+
flag.RegisterFlag("override-nil-pointer", &originalFlag, "Original flag")
272+
273+
defer func() {
274+
if r := recover(); r == nil {
275+
t.Error("Expected panic when overriding with nil pointer")
276+
}
277+
}()
278+
var testFlag *string
279+
flag.Override("override-nil-pointer", testFlag, "A nil pointer override")
280+
}
281+
282+
func TestOverrideFlagAfterParse(t *testing.T) {
283+
// Save and restore the original command line
284+
originalCommandLine := pflag.CommandLine
285+
defer func() { pflag.CommandLine = originalCommandLine }()
286+
287+
// Create a fresh command line for this test
288+
pflag.CommandLine = pflag.NewFlagSet("", pflag.ContinueOnError)
289+
290+
// Register a flag and parse
291+
var originalFlag string
292+
flag.RegisterFlag("override-after-parse", &originalFlag, "Original flag")
293+
flag.Init() // This will parse the flags
294+
295+
defer func() {
296+
if r := recover(); r == nil {
297+
t.Error("Expected panic when overriding after parsing")
298+
}
299+
}()
300+
301+
// Try to override after parsing - should panic
302+
var newFlag string = "new"
303+
flag.Override("override-after-parse", &newFlag, "New description")
304+
}
305+
306+
func TestOverrideUnsupportedType(t *testing.T) {
307+
// Save and restore the original command line
308+
originalCommandLine := pflag.CommandLine
309+
defer func() { pflag.CommandLine = originalCommandLine }()
310+
311+
// Create a fresh command line for this test
312+
pflag.CommandLine = pflag.NewFlagSet("", pflag.ContinueOnError)
313+
314+
// First register a flag to override
315+
var originalFlag string
316+
flag.RegisterFlag("override-unsupported", &originalFlag, "Original flag")
317+
318+
defer func() {
319+
if r := recover(); r == nil {
320+
t.Error("Expected panic when overriding with unsupported type")
321+
}
322+
}()
323+
var testSlice []string
324+
flag.Override("override-unsupported", &testSlice, "An unsupported type override")
325+
}

0 commit comments

Comments
 (0)