diff --git a/clientv2/client.go b/clientv2/client.go index 1b08563..d0182fc 100644 --- a/clientv2/client.go +++ b/clientv2/client.go @@ -396,12 +396,12 @@ func (c *Client) unmarshal(data []byte, res interface{}) error { func MarshalJSON(v interface{}) ([]byte, error) { if v == nil { - return []byte("null"), nil // Directly return "null" for nil interface{} + return []byte("null"), nil } val := reflect.ValueOf(v) if !val.IsValid() || (val.Kind() == reflect.Ptr && val.IsNil()) { - return []byte("null"), nil // Return "null" for nil pointer or invalid reflect value + return []byte("null"), nil } return encode(val) @@ -417,6 +417,10 @@ func checkImplements[I any](v reflect.Value) bool { // encode returns an appropriate encoder function for the provided value. func encode(v reflect.Value) ([]byte, error) { + if !v.IsValid() || (v.Kind() == reflect.Ptr && v.IsNil()) { + return []byte("null"), nil + } + if checkImplements[graphql.Marshaler](v) { return encodeGQLMarshaler(v.Interface()) } @@ -457,6 +461,10 @@ func encode(v reflect.Value) ([]byte, error) { } func encodeGQLMarshaler(v any) ([]byte, error) { + if v == nil { + return []byte("null"), nil + } + var buf bytes.Buffer if val, ok := v.(graphql.Marshaler); ok { val.MarshalGQL(&buf) diff --git a/clientv2/client_test.go b/clientv2/client_test.go index 124d659..23415b9 100644 --- a/clientv2/client_test.go +++ b/clientv2/client_test.go @@ -579,6 +579,8 @@ func TestMarshalJSON(t *testing.T) { Number Number `json:"number"` } + var b *Number + // example nested struct type WhereInput struct { Not *WhereInput `json:"not,omitempty"` @@ -640,6 +642,18 @@ func TestMarshalJSON(t *testing.T) { }, want: []byte(`{"operationName":"query", "query":"query ($input: Number!) { input }","variables":{"where":{"not":{"id":"1"}}}}`), }, + { + name: "marshal nil", + args: args{ + v: Request{ + OperationName: "query", + Variables: map[string]any{ + "v": b, + }, + }, + }, + want: []byte(`{"operationName":"query", "query":"","variables":{"v":null}}`), + }, { name: "marshal a struct with custom marshaler", args: args{ @@ -738,7 +752,7 @@ func TestMarshalJSON(t *testing.T) { return } if err := json.Unmarshal(tt.want, &wantMap); err != nil { - t.Errorf("Failed to unmarshal 'want': %s", tt.want) + t.Errorf("Failed to unmarshal err: %s", err) return }