Skip to content

Commit 4e4c7b0

Browse files
committed
fix #213 support typed nil in interpolation
1 parent 51f40ba commit 4e4c7b0

File tree

2 files changed

+119
-3
lines changed

2 files changed

+119
-3
lines changed

interpolate.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -618,12 +618,22 @@ func encodeValue(buf []byte, arg interface{}, flavor Flavor) ([]byte, error) {
618618
buf = append(buf, "', 'YYYY-MM-DD HH24:MI:SS.FF')"...)
619619
}
620620

621-
case fmt.Stringer:
622-
buf = quoteStringValue(buf, v.String(), flavor)
623-
624621
default:
625622
primative := reflect.ValueOf(arg)
626623

624+
// Handle typed nil values (e.g. (*string)(nil), (*time.Time)(nil))
625+
// This check must come before fmt.Stringer check since nil pointers may implement interfaces
626+
if !primative.IsValid() || (primative.Kind() == reflect.Ptr && primative.IsNil()) {
627+
buf = append(buf, "NULL"...)
628+
return buf, nil
629+
}
630+
631+
// Check for fmt.Stringer after nil pointer check
632+
if stringer, ok := arg.(fmt.Stringer); ok {
633+
buf = quoteStringValue(buf, stringer.String(), flavor)
634+
return buf, nil
635+
}
636+
627637
switch k := primative.Kind(); k {
628638
case reflect.Bool:
629639
switch flavor {

interpolate_test.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,3 +389,109 @@ func TestFlavorInterpolate(t *testing.T) {
389389
})
390390
}
391391
}
392+
393+
func TestTypedNilInterpolation(t *testing.T) {
394+
a := assert.New(t)
395+
396+
// Test various typed nil pointers
397+
cases := []struct {
398+
name string
399+
flavor Flavor
400+
sql string
401+
args []interface{}
402+
expected string
403+
}{
404+
{
405+
name: "string pointer nil",
406+
flavor: MySQL,
407+
sql: "SELECT ?",
408+
args: []interface{}{(*string)(nil)},
409+
expected: "SELECT NULL",
410+
},
411+
{
412+
name: "int pointer nil",
413+
flavor: MySQL,
414+
sql: "SELECT ?",
415+
args: []interface{}{(*int)(nil)},
416+
expected: "SELECT NULL",
417+
},
418+
{
419+
name: "float64 pointer nil",
420+
flavor: MySQL,
421+
sql: "SELECT ?",
422+
args: []interface{}{(*float64)(nil)},
423+
expected: "SELECT NULL",
424+
},
425+
{
426+
name: "time.Time pointer nil",
427+
flavor: MySQL,
428+
sql: "SELECT ?",
429+
args: []interface{}{(*time.Time)(nil)},
430+
expected: "SELECT NULL",
431+
},
432+
{
433+
name: "byte slice pointer nil",
434+
flavor: MySQL,
435+
sql: "SELECT ?",
436+
args: []interface{}{(*[]byte)(nil)},
437+
expected: "SELECT NULL",
438+
},
439+
{
440+
name: "PostgreSQL string pointer nil",
441+
flavor: PostgreSQL,
442+
sql: "SELECT $1",
443+
args: []interface{}{(*string)(nil)},
444+
expected: "SELECT NULL",
445+
},
446+
{
447+
name: "SQLite int pointer nil",
448+
flavor: SQLite,
449+
sql: "SELECT ?",
450+
args: []interface{}{(*int)(nil)},
451+
expected: "SELECT NULL",
452+
},
453+
{
454+
name: "Oracle float pointer nil",
455+
flavor: Oracle,
456+
sql: "SELECT :1",
457+
args: []interface{}{(*float64)(nil)},
458+
expected: "SELECT NULL",
459+
},
460+
{
461+
name: "mixed nil types",
462+
flavor: MySQL,
463+
sql: "SELECT ?, ?, ?, ?",
464+
args: []interface{}{(*string)(nil), nil, (*int)(nil), (*[]byte)(nil)},
465+
expected: "SELECT NULL, NULL, NULL, NULL",
466+
},
467+
{
468+
name: "interface pointer nil",
469+
flavor: MySQL,
470+
sql: "SELECT ?",
471+
args: []interface{}{(*fmt.Stringer)(nil)},
472+
expected: "SELECT NULL",
473+
},
474+
{
475+
name: "slice pointer nil",
476+
flavor: MySQL,
477+
sql: "SELECT ?",
478+
args: []interface{}{(*[]int)(nil)},
479+
expected: "SELECT NULL",
480+
},
481+
{
482+
name: "map pointer nil",
483+
flavor: MySQL,
484+
sql: "SELECT ?",
485+
args: []interface{}{(*map[string]int)(nil)},
486+
expected: "SELECT NULL",
487+
},
488+
}
489+
490+
for _, tc := range cases {
491+
t.Run(tc.name, func(t *testing.T) {
492+
query, err := tc.flavor.Interpolate(tc.sql, tc.args)
493+
a.NilError(err)
494+
a.Equal(tc.expected, query)
495+
})
496+
}
497+
}

0 commit comments

Comments
 (0)