Skip to content

Commit bcf23de

Browse files
authored
Allow tagging of the field as having a scalar GraphQL type (shurcooL#24)
* Allow tagging of the field as having a scalar GraphQL type This change allows us to specify that a field in a struct has a scalar GraphQL type associated with it. This is done by adding the tag `scalar:"true"` to the field. This would allow us to - Avoid expansion of the field during request query generation, even when the golang type of the field is a struct - When the response is decoded, the value is simply JSON decoded, instead of the much stricter GraphQL decode. - Types like map[string]interface{}, json.RawMessage, etc should work as far as the corresponding fields are marked as scalar * Handle the error message where we are decoding as json.rawMessage Co-authored-by: Nizar Malangadan <nizar-m@users.noreply.github.com>
1 parent 8c8fa4d commit bcf23de

File tree

4 files changed

+89
-6
lines changed

4 files changed

+89
-6
lines changed

internal/jsonutil/graphql.go

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"fmt"
1010
"io"
1111
"reflect"
12+
"strconv"
1213
"strings"
1314
)
1415

@@ -104,6 +105,7 @@ func (d *decoder) decode() error {
104105
someFieldExist := false
105106
// If one field is raw all must be treated as raw
106107
rawMessage := false
108+
isScalar := false
107109
for i := range d.vs {
108110
v := d.vs[i].Top()
109111
for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
@@ -112,7 +114,7 @@ func (d *decoder) decode() error {
112114
var f reflect.Value
113115
switch v.Kind() {
114116
case reflect.Struct:
115-
f = fieldByGraphQLName(v, key)
117+
f, isScalar = fieldByGraphQLName(v, key)
116118
if f.IsValid() {
117119
someFieldExist = true
118120
// Check for special embedded json
@@ -132,10 +134,13 @@ func (d *decoder) decode() error {
132134
return fmt.Errorf("struct field for %q doesn't exist in any of %v places to unmarshal", key, len(d.vs))
133135
}
134136

135-
if rawMessage {
137+
if rawMessage || isScalar {
136138
// Read the next complete object from the json stream
137139
var data json.RawMessage
138-
d.tokenizer.Decode(&data)
140+
err = d.tokenizer.Decode(&data)
141+
if err != nil {
142+
return err
143+
}
139144
tok = data
140145
} else {
141146
// We've just consumed the current token, which was the key.
@@ -361,17 +366,17 @@ func (d *decoder) popLeftArrayTemplates() {
361366

362367
// fieldByGraphQLName returns an exported struct field of struct v
363368
// that matches GraphQL name, or invalid reflect.Value if none found.
364-
func fieldByGraphQLName(v reflect.Value, name string) reflect.Value {
369+
func fieldByGraphQLName(v reflect.Value, name string) (val reflect.Value, taggedAsScalar bool) {
365370
for i := 0; i < v.NumField(); i++ {
366371
if v.Type().Field(i).PkgPath != "" {
367372
// Skip unexported field.
368373
continue
369374
}
370375
if hasGraphQLName(v.Type().Field(i), name) {
371-
return v.Field(i)
376+
return v.Field(i), hasScalarTag(v.Type().Field(i))
372377
}
373378
}
374-
return reflect.Value{}
379+
return reflect.Value{}, false
375380
}
376381

377382
// orderedMapValueByGraphQLName takes [][2]string, interprets it as an ordered map
@@ -387,6 +392,15 @@ func orderedMapValueByGraphQLName(v reflect.Value, name string) reflect.Value {
387392
return reflect.Value{}
388393
}
389394

395+
func hasScalarTag(f reflect.StructField) bool {
396+
return isTrue(f.Tag.Get("scalar"))
397+
}
398+
399+
func isTrue(s string) bool {
400+
b, _ := strconv.ParseBool(s)
401+
return b
402+
}
403+
390404
// hasGraphQLName reports whether struct field f has GraphQL name.
391405
func hasGraphQLName(f reflect.StructField, name string) bool {
392406
value, ok := f.Tag.Lookup("graphql")

internal/jsonutil/graphql_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,42 @@ func TestUnmarshalGraphQL_jsonRawTag(t *testing.T) {
104104
}
105105
}
106106

107+
func TestUnmarshalGraphQL_fieldAsScalar(t *testing.T) {
108+
type query struct {
109+
Data json.RawMessage `scalar:"true"`
110+
DataPtr *json.RawMessage `scalar:"true"`
111+
Another string
112+
Tags map[string]int `scalar:"true"`
113+
}
114+
var got query
115+
err := jsonutil.UnmarshalGraphQL([]byte(`{
116+
"Data" : {"ValA":1,"ValB":"foo"},
117+
"DataPtr" : {"ValC":3,"ValD":false},
118+
"Another" : "stuff",
119+
"Tags": {
120+
"keyA": 2,
121+
"keyB": 3
122+
}
123+
}`), &got)
124+
125+
if err != nil {
126+
t.Fatal(err)
127+
}
128+
dataPtr := json.RawMessage(`{"ValC":3,"ValD":false}`)
129+
want := query{
130+
Data: json.RawMessage(`{"ValA":1,"ValB":"foo"}`),
131+
DataPtr: &dataPtr,
132+
Another: "stuff",
133+
Tags: map[string]int{
134+
"keyA": 2,
135+
"keyB": 3,
136+
},
137+
}
138+
if !reflect.DeepEqual(got, want) {
139+
t.Errorf("not equal: %v %v", want, got)
140+
}
141+
}
142+
107143
func TestUnmarshalGraphQL_orderedMap(t *testing.T) {
108144
type query [][2]interface{}
109145
got := query{

query.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"io"
88
"reflect"
99
"sort"
10+
"strconv"
1011
"strings"
1112

1213
"github.com/hasura/go-graphql-client/ident"
@@ -187,6 +188,10 @@ func writeQuery(w io.Writer, t reflect.Type, v reflect.Value, inline bool) {
187188
io.WriteString(w, ident.ParseMixedCaps(f.Name).ToLowerCamelCase())
188189
}
189190
}
191+
// Skip writeQuery if the GraphQL type associated with the filed is scalar
192+
if isTrue(f.Tag.Get("scalar")) {
193+
continue
194+
}
190195
writeQuery(w, f.Type, FieldSafe(v, i), inlineField)
191196
}
192197
if !inline {
@@ -241,3 +246,8 @@ func FieldSafe(valStruct reflect.Value, i int) reflect.Value {
241246
}
242247

243248
var jsonUnmarshaler = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
249+
250+
func isTrue(s string) bool {
251+
b, _ := strconv.ParseBool(s)
252+
return b
253+
}

query_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,29 @@ func TestConstructQuery(t *testing.T) {
305305
}{},
306306
want: `{viewer{login,createdAt,id,databaseId}}`,
307307
},
308+
{
309+
inV: struct {
310+
Viewer struct {
311+
ID interface{}
312+
Login string
313+
CreatedAt time.Time
314+
DatabaseID int
315+
}
316+
Tags map[string]interface{} `scalar:"true"`
317+
}{},
318+
want: `{viewer{id,login,createdAt,databaseId},tags}`,
319+
},
320+
{
321+
inV: struct {
322+
Viewer struct {
323+
ID interface{}
324+
Login string
325+
CreatedAt time.Time
326+
DatabaseID int
327+
} `scalar:"true"`
328+
}{},
329+
want: `{viewer}`,
330+
},
308331
}
309332
for _, tc := range tests {
310333
got, err := constructQuery(tc.inV, tc.inVariables, tc.options...)

0 commit comments

Comments
 (0)