Skip to content

Commit

Permalink
Merge pull request #221 from Yamashou/unmarshalgql
Browse files Browse the repository at this point in the history
UnmarshalGQL
  • Loading branch information
Yamashou authored May 1, 2024
2 parents fa26c9a + f2cccda commit 5692918
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 7 deletions.
54 changes: 47 additions & 7 deletions graphqljson/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import (
"io"
"reflect"
"strings"

"github.com/99designs/gqlgen/graphql"
)

// Reference: https://blog.gopheracademy.com/advent-2017/custom-json-unmarshaler-for-graphql-client/
Expand Down Expand Up @@ -104,7 +106,7 @@ func (d *Decoder) Decode(v interface{}) error {
}

// decode decodes a single JSON value from d.tokenizer into d.vs.
func (d *Decoder) decode() error {
func (d *Decoder) decode() error { //nolint:maintidx
// The loop invariant is that the top of each d.vs stack
// is where we try to unmarshal the next JSON value we see.
for len(d.vs) > 0 {
Expand Down Expand Up @@ -187,17 +189,55 @@ func (d *Decoder) decode() error {
}

switch tok := tok.(type) {
case string, json.Number, bool, nil, json.RawMessage, map[string]interface{}:
// Value.

case nil: // Handle null values correctly.
for i := range d.vs {
v := d.vs[i][len(d.vs[i])-1]
if v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice {
// Set the pointer or slice to nil.
v.Set(reflect.Zero(v.Type()))
} else {
// For other types that cannot directly handle nil, continue to use default zero values.
v.Set(reflect.Zero(v.Type()))
}
}
d.popAllVs()
continue
case string, json.Number, bool, json.RawMessage, map[string]interface{}:
for i := range d.vs {
v := d.vs[i][len(d.vs[i])-1]
if !v.IsValid() {
continue
}
err := unmarshalValue(tok, v)
if err != nil {
return fmt.Errorf(": %w", err)

// Initialize the pointer if it is nil
if v.Kind() == reflect.Ptr && v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}

// Handle both pointer and non-pointer types
target := v
if v.Kind() == reflect.Ptr {
target = v.Elem()
}

// Check if the type of target (or its address) implements graphql.Unmarshaler
var unmarshaler graphql.Unmarshaler
var ok bool
if target.CanAddr() {
unmarshaler, ok = target.Addr().Interface().(graphql.Unmarshaler)
} else if target.CanInterface() {
unmarshaler, ok = target.Interface().(graphql.Unmarshaler)
}

if ok {
if err := unmarshaler.UnmarshalGQL(tok); err != nil {
return fmt.Errorf("unmarshal gql error: %w", err)
}
} else {
// Use the standard unmarshal method for non-custom types
if err := unmarshalValue(tok, target); err != nil {
return fmt.Errorf(": %w", err)
}
}
}
d.popAllVs()
Expand Down
125 changes: 125 additions & 0 deletions graphqljson/graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package graphqljson_test

import (
"encoding/json"
"fmt"
"testing"
"time"

Expand Down Expand Up @@ -589,3 +590,127 @@ func TestUnmarshalGraphQL_map(t *testing.T) {
t.Error(diff)
}
}

type Number int64

const (
NumberOne Number = 1
NumberTwo Number = 2
)

func (n *Number) UnmarshalGQL(v any) error {
str, ok := v.(string)
if !ok {
return fmt.Errorf("enums must be strings")
}

switch str {
case "ONE":
*n = NumberOne
case "TWO":
*n = NumberTwo
default:

return fmt.Errorf("Number not found Type: %d", n)
}

return nil
}

func TestUnmarshalGQL(t *testing.T) {
t.Parallel()
type query struct {
Enum Number
}
var got query
err := graphqljson.UnmarshalData([]byte(`{
"enum": "ONE"
}`), &got)
if err != nil {
t.Fatal(err)
}
want := query{
Enum: NumberOne,
}
if diff := cmp.Diff(got, want); diff != "" {
t.Error(diff)
}
}

func TestUnmarshalGQL_array(t *testing.T) {
t.Parallel()
type query struct {
Enums []Number
}
var got query
err := graphqljson.UnmarshalData([]byte(`{
"enums": ["ONE", "TWO"]
}`), &got)
if err != nil {
t.Fatal(err)
}
want := query{
Enums: []Number{NumberOne, NumberTwo},
}
if diff := cmp.Diff(got, want); diff != "" {
t.Error(diff)
}
}

func TestUnmarshalGQL_pointer(t *testing.T) {
t.Parallel()
type query struct {
Enum *Number
}
var got query
err := graphqljson.UnmarshalData([]byte(`{
"enum": "ONE"
}`), &got)
if err != nil {
t.Fatal(err)
}

v := NumberOne
want := query{
Enum: &v,
}
if diff := cmp.Diff(got, want); diff != "" {
t.Error(diff)
}
}

func TestUnmarshalGQL_pointerArray(t *testing.T) {
t.Parallel()
type query struct {
Enums []*Number
}
var got query
err := graphqljson.UnmarshalData([]byte(`{
"enums": ["ONE", "TWO"]
}`), &got)
if err != nil {
t.Fatal(err)
}
one := NumberOne
two := NumberTwo
want := query{
Enums: []*Number{&one, &two},
}
if diff := cmp.Diff(got, want); diff != "" {
t.Error(diff)
}
}

func TestUnmarshalGQL_pointerArrayReset(t *testing.T) {
t.Parallel()
got := []*Number{new(Number)}
err := graphqljson.UnmarshalData([]byte(`["TWO"]`), &got)
if err != nil {
t.Fatal(err)
}
want := []*Number{new(Number)}
*want[0] = NumberTwo
if diff := cmp.Diff(got, want); diff != "" {
t.Error(diff)
}
}

0 comments on commit 5692918

Please sign in to comment.