Skip to content

Commit

Permalink
Merge pull request #278 from illia-li/il/add/marshal_corrupt_test_suite
Browse files Browse the repository at this point in the history
add marshal `corrupt test suite`
  • Loading branch information
sylwiaszunejko authored Oct 1, 2024
2 parents 0ee3dcb + 6a1a6be commit e5191bd
Show file tree
Hide file tree
Showing 28 changed files with 471 additions and 273 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ import (
"reflect"
)

// ErrFirstPtrChanged this error indicates that a double or single reference was passed to the Unmarshal function
// errFirstPtrChanged this error indicates that a double or single reference was passed to the Unmarshal function
// (example (**int)(**0) or (*int)(*0)) and Unmarshal overwritten first reference.
var ErrFirstPtrChanged = errors.New("unmarshal function rewrote first pointer")
var errFirstPtrChanged = errors.New("unmarshal function rewrote first pointer")

// ErrSecondPtrNotChanged this error indicates that a double reference was passed to the Unmarshal function
// errSecondPtrNotChanged this error indicates that a double reference was passed to the Unmarshal function
// (example (**int)(**0)) and the function did not overwrite the second reference.
// Of course, it's not friendly to the garbage collector, overwriting references to values all the time,
// but this is the current implementation `gocql` and changing it can lead to unexpected results in some cases.
var ErrSecondPtrNotChanged = errors.New("unmarshal function did not rewrite second pointer")
var errSecondPtrNotChanged = errors.New("unmarshal function did not rewrite second pointer")

func getPointers(i interface{}) *pointer {
rv := reflect.ValueOf(i)
Expand Down Expand Up @@ -45,10 +45,10 @@ func (p *pointer) NotNil() bool {
func (p *pointer) Valid(v interface{}) error {
p2 := getPointers(v)
if p.Fist != p2.Fist {
return ErrFirstPtrChanged
return errFirstPtrChanged
}
if p.Second != 0 && p2.Second != 0 && p2.Second == p.Second {
return ErrSecondPtrNotChanged
return errSecondPtrNotChanged
}
return nil
}
59 changes: 59 additions & 0 deletions internal/tests/serialization/set_negative_marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package serialization

import (
"errors"
"reflect"
"runtime/debug"
"testing"
)

// NegativeMarshalSet is a tool for marshal funcs testing for cases when the function should an error.
type NegativeMarshalSet struct {
Values []interface{}
BrokenTypes []reflect.Type
}

func (s NegativeMarshalSet) Run(name string, t *testing.T, marshal func(interface{}) ([]byte, error)) {
if name == "" {
t.Fatal("name should be provided")
}
if marshal == nil {
t.Fatal("marshal function should be provided")
}
t.Run(name, func(t *testing.T) {
for m := range s.Values {
val := s.Values[m]

t.Run(stringValue(val), func(t *testing.T) {
_, err := func() (d []byte, err error) {
defer func() {
if r := recover(); r != nil {
err = panicErr{err: r.(error), stack: debug.Stack()}
}
}()
return marshal(val)
}()

testFailed := false
wasPanic := errors.As(err, &panicErr{})
if err == nil || wasPanic {
testFailed = true
}

if isTypeOf(val, s.BrokenTypes) {
if testFailed {
t.Skipf("skipped bacause there is unsolved problem")
}
t.Fatalf("expected to panic or no error for (%T), but got an error", val)
}

if testFailed {
if wasPanic {
t.Fatalf("was panic %s", err)
}
t.Errorf("expected an error for (%T), but got no error", val)
}
})
}
})
}
77 changes: 77 additions & 0 deletions internal/tests/serialization/set_negative_unmarshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package serialization

import (
"bytes"
"errors"
"fmt"
"reflect"
"runtime/debug"
"testing"
)

// NegativeUnmarshalSet is a tool for unmarshal funcs testing for cases when the function should an error.
type NegativeUnmarshalSet struct {
Data []byte
Values []interface{}
BrokenTypes []reflect.Type
}

func (s NegativeUnmarshalSet) Run(name string, t *testing.T, unmarshal func([]byte, interface{}) error) {
if name == "" {
t.Fatal("name should be provided")
}
if unmarshal == nil {
t.Fatal("unmarshal function should be provided")
}
t.Run(name, func(t *testing.T) {
for m := range s.Values {
val := s.Values[m]

if rt := reflect.TypeOf(val); rt.Kind() != reflect.Ptr {
unmarshalIn := newRef(val)
s.run(fmt.Sprintf("%T", val), t, unmarshal, val, unmarshalIn)
} else {
// Test unmarshal to (*type)(nil)
unmarshalIn := newRef(val)
s.run(fmt.Sprintf("%T**nil", val), t, unmarshal, val, unmarshalIn)

// Test unmarshal to &type{}
unmarshalInZero := newRefToZero(val)
s.run(fmt.Sprintf("%T**zero", val), t, unmarshal, val, unmarshalInZero)
}
}
})
}

func (s NegativeUnmarshalSet) run(name string, t *testing.T, f func([]byte, interface{}) error, val, unmarshalIn interface{}) {
t.Run(name, func(t *testing.T) {
err := func() (err error) {
defer func() {
if r := recover(); r != nil {
err = panicErr{err: r.(error), stack: debug.Stack()}
}
}()
return f(bytes.Clone(s.Data), unmarshalIn)
}()

testFailed := false
wasPanic := errors.As(err, &panicErr{})
if err == nil || wasPanic {
testFailed = true
}

if isTypeOf(val, s.BrokenTypes) {
if testFailed {
t.Skipf("skipped bacause there is unsolved problem")
}
t.Fatalf("expected to panic or no error for (%T), but got an error", unmarshalIn)
}

if testFailed {
if wasPanic {
t.Fatalf("was panic %s", err)
}
t.Errorf("expected an error for (%T), but got no error", unmarshalIn)
}
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,20 @@ import (
"reflect"
"runtime/debug"
"testing"

"github.com/gocql/gocql/internal/tests/utils"
)

type Sets []*Set

// Set is a tool for generating test cases of marshal and unmarshall funcs.
// For cases when the function should no error,
// marshaled data from Set.Values should be equal with Set.Data,
// unmarshalled value from Set.Data should be equal with Set.Values.
type Set struct {
// PositiveSet is a tool for marshal and unmarshall funcs testing for cases when the function should no error,
// on marshal - marshaled data from PositiveSet.Values should be equal with PositiveSet.Data,
// on unmarshall - unmarshalled value from PositiveSet.Data should be equal with PositiveSet.Values.
type PositiveSet struct {
Data []byte
Values []interface{}

BrokenMarshalTypes []reflect.Type
BrokenUnmarshalTypes []reflect.Type
}

func (s Set) Run(name string, t *testing.T, marshal func(interface{}) ([]byte, error), unmarshal func([]byte, interface{}) error) {
func (s PositiveSet) Run(name string, t *testing.T, marshal func(interface{}) ([]byte, error), unmarshal func([]byte, interface{}) error) {
if name == "" {
t.Fatal("name should be provided")
}
Expand All @@ -41,15 +36,15 @@ func (s Set) Run(name string, t *testing.T, marshal func(interface{}) ([]byte, e

if unmarshal != nil {
if rt := reflect.TypeOf(val); rt.Kind() != reflect.Ptr {
unmarshalIn := utils.NewRef(val)
unmarshalIn := newRef(val)
s.runUnmarshalTest("unmarshal", t, unmarshal, val, unmarshalIn)
} else {
// Test unmarshal to (*type)(nil)
unmarshalIn := utils.NewRef(val)
unmarshalIn := newRef(val)
s.runUnmarshalTest("unmarshal**nil", t, unmarshal, val, unmarshalIn)

// Test unmarshal to &type{}
unmarshalInZero := utils.NewRefToZero(val)
unmarshalInZero := newRefToZero(val)
s.runUnmarshalTest("unmarshal**zero", t, unmarshal, val, unmarshalInZero)
}
}
Expand All @@ -58,25 +53,25 @@ func (s Set) Run(name string, t *testing.T, marshal func(interface{}) ([]byte, e
})
}

func (s Set) runMarshalTest(t *testing.T, f func(interface{}) ([]byte, error), val interface{}) {
func (s PositiveSet) runMarshalTest(t *testing.T, f func(interface{}) ([]byte, error), val interface{}) {
t.Run("marshal", func(t *testing.T) {

result, err := func() (d []byte, err error) {
defer func() {
if r := recover(); r != nil {
err = utils.PanicErr{Err: r.(error), Stack: debug.Stack()}
err = panicErr{err: r.(error), stack: debug.Stack()}
}
}()
return f(val)
}()

expected := bytes.Clone(s.Data)
if err != nil {
if !errors.As(err, &utils.PanicErr{}) {
err = errors.Join(MarshalErr, err)
if !errors.As(err, &panicErr{}) {
err = errors.Join(marshalErr, err)
}
} else if !utils.EqualData(expected, result) {
err = UnequalError{Expected: utils.StringData(s.Data), Got: utils.StringData(result)}
} else if !equalData(expected, result) {
err = unequalError{Expected: stringData(s.Data), Got: stringData(result)}
}

if isTypeOf(val, s.BrokenMarshalTypes) {
Expand All @@ -91,26 +86,26 @@ func (s Set) runMarshalTest(t *testing.T, f func(interface{}) ([]byte, error), v
})
}

func (s Set) runUnmarshalTest(name string, t *testing.T, f func([]byte, interface{}) error, expected, result interface{}) {
func (s PositiveSet) runUnmarshalTest(name string, t *testing.T, f func([]byte, interface{}) error, expected, result interface{}) {
t.Run(name, func(t *testing.T) {

expectedPtr := getPointers(result)

err := func() (err error) {
defer func() {
if r := recover(); r != nil {
err = utils.PanicErr{Err: fmt.Errorf("%s", r), Stack: debug.Stack()}
err = panicErr{err: fmt.Errorf("%s", r), stack: debug.Stack()}
}
}()
return f(bytes.Clone(s.Data), result)
}()

if err != nil {
if !errors.As(err, &utils.PanicErr{}) {
err = errors.Join(UnmarshalErr, err)
if !errors.As(err, &panicErr{}) {
err = errors.Join(unmarshalErr, err)
}
} else if !utils.EqualVals(expected, utils.DeReference(result)) {
err = UnequalError{Expected: utils.StringValue(expected), Got: utils.StringValue(result)}
} else if !equalVals(expected, deReference(result)) {
err = unequalError{Expected: stringValue(expected), Got: stringValue(result)}
} else {
err = expectedPtr.Valid(result)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
package utils
package serialization

import (
"reflect"
)

func DeReference(in interface{}) interface{} {
return reflect.Indirect(reflect.ValueOf(in)).Interface()
}

func Reference(val interface{}) interface{} {
out := reflect.New(reflect.TypeOf(val))
out.Elem().Set(reflect.ValueOf(val))
return out.Interface()
}

func GetTypes(values ...interface{}) []reflect.Type {
types := make([]reflect.Type, len(values))
for i, value := range values {
types[i] = reflect.TypeOf(value)
}
return types
}

func isTypeOf(value interface{}, types []reflect.Type) bool {
valueType := reflect.TypeOf(value)
for i := range types {
if types[i] == valueType {
return true
}
}
return false
}

func deReference(in interface{}) interface{} {
return reflect.Indirect(reflect.ValueOf(in)).Interface()
}
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
package utils
package serialization

import (
"bytes"
"fmt"
"github.com/gocql/gocql/internal/tests/serialization/mod"
"gopkg.in/inf.v0"
"math/big"
"reflect"
"unsafe"

"github.com/gocql/gocql/marshal/tests/mod"
)

func EqualData(in1, in2 []byte) bool {
func equalData(in1, in2 []byte) bool {
if in1 == nil || in2 == nil {
return in1 == nil && in2 == nil
}
return bytes.Equal(in1, in2)
}

func EqualVals(in1, in2 interface{}) bool {
func equalVals(in1, in2 interface{}) bool {
rin1 := reflect.ValueOf(in1)
rin2 := reflect.ValueOf(in2)
if rin1.Kind() != rin2.Kind() {
Expand Down
27 changes: 27 additions & 0 deletions internal/tests/serialization/utils_error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package serialization

import (
"errors"
"fmt"
)

var unmarshalErr = errors.New("unmarshal unexpectedly failed with error")
var marshalErr = errors.New("marshal unexpectedly failed with error")

type unequalError struct {
Expected string
Got string
}

func (e unequalError) Error() string {
return fmt.Sprintf("expect %s but got %s", e.Expected, e.Got)
}

type panicErr struct {
err error
stack []byte
}

func (e panicErr) Error() string {
return fmt.Sprintf("%v\n%s", e.err, e.stack)
}
Loading

0 comments on commit e5191bd

Please sign in to comment.