Skip to content

Commit

Permalink
fix(core/reflect): handle missing values in slice with multiple eleme…
Browse files Browse the repository at this point in the history
…nts (#3762)

Co-authored-by: Jules Casteran <[email protected]>
Co-authored-by: Jules Castéran <[email protected]>
  • Loading branch information
3 people authored Apr 11, 2024
1 parent 6ef60ed commit de67e68
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 15 deletions.
2 changes: 1 addition & 1 deletion internal/core/arg_file_content.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func loadArgsFileContent(cmd *Command, cmdArgs interface{}) error {
}

fieldName := strcase.ToPublicGoName(argSpec.Name)
fieldValues, err := getValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
fieldValues, err := GetValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
if err != nil {
continue
}
Expand Down
29 changes: 18 additions & 11 deletions internal/core/reflect.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package core

import (
"errors"
"fmt"
"reflect"
"sort"
Expand Down Expand Up @@ -34,26 +35,33 @@ func newObjectWithForcedJSONTags(t reflect.Type) interface{} {
return reflect.New(reflect.StructOf(structFieldsCopy)).Interface()
}

// getValuesForFieldByName recursively search for fields in a cmdArgs' value and returns its values if they exist.
// GetValuesForFieldByName recursively search for fields in a cmdArgs' value and returns its values if they exist.
// The search is based on the name of the field.
func getValuesForFieldByName(value reflect.Value, parts []string) (values []reflect.Value, err error) {
func GetValuesForFieldByName(value reflect.Value, parts []string) (values []reflect.Value, err error) {
if len(parts) == 0 {
return []reflect.Value{value}, nil
}

switch value.Kind() {
case reflect.Ptr:
return getValuesForFieldByName(value.Elem(), parts)
return GetValuesForFieldByName(value.Elem(), parts)

case reflect.Slice:
values := []reflect.Value(nil)
errs := []error(nil)

for i := 0; i < value.Len(); i++ {
newValues, err := getValuesForFieldByName(value.Index(i), parts[1:])
newValues, err := GetValuesForFieldByName(value.Index(i), parts[1:])
if err != nil {
return nil, err
errs = append(errs, err)
} else {
values = append(values, newValues...)
}
values = append(values, newValues...)
}

if len(values) == 0 && len(errs) != 0 {
return nil, errors.Join(errs...)
}

return values, nil

case reflect.Map:
Expand All @@ -70,7 +78,7 @@ func getValuesForFieldByName(value reflect.Value, parts []string) (values []refl

for _, mapKey := range mapKeys {
mapValue := value.MapIndex(mapKey)
newValues, err := getValuesForFieldByName(mapValue, parts[1:])
newValues, err := GetValuesForFieldByName(mapValue, parts[1:])
if err != nil {
return nil, err
}
Expand All @@ -93,19 +101,18 @@ func getValuesForFieldByName(value reflect.Value, parts []string) (values []refl

fieldName := strcase.ToPublicGoName(parts[0])
if fieldIndex, exist := fieldIndexByName[fieldName]; exist {
return getValuesForFieldByName(value.Field(fieldIndex), parts[1:])
return GetValuesForFieldByName(value.Field(fieldIndex), parts[1:])
}

// If it does not exist we try to find it in nested anonymous field
for _, fieldIndex := range anonymousFieldIndexes {
newValues, err := getValuesForFieldByName(value.Field(fieldIndex), parts)
newValues, err := GetValuesForFieldByName(value.Field(fieldIndex), parts)
if err == nil {
return newValues, nil
}
}

return nil, fmt.Errorf("field %v does not exist for %v", fieldName, value.Type().Name())
}

return nil, fmt.Errorf("case is not handled")
}
181 changes: 181 additions & 0 deletions internal/core/reflect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
package core_test

import (
"net"
"reflect"
"strings"
"testing"

"github.com/alecthomas/assert"
"github.com/scaleway/scaleway-cli/v2/internal/core"
"github.com/scaleway/scaleway-sdk-go/scw"
)

type RequestEmbedding struct {
EmbeddingField1 string
EmbeddingField2 int
}

type CreateRequest struct {
*RequestEmbedding
CreateField1 string
CreateField2 int
}

type ExtendedRequest struct {
*CreateRequest
ExtendedField1 string
ExtendedField2 int
}

type ArrowRequest struct {
PrivateNetwork *PrivateNetwork
}

type SpecialRequest struct {
*RequestEmbedding
TabRequest []*ArrowRequest
}

type EndpointSpecPrivateNetwork struct {
PrivateNetworkID string
ServiceIP *scw.IPNet
}

type PrivateNetwork struct {
*EndpointSpecPrivateNetwork
OtherValue string
}

func Test_getValuesForFieldByName(t *testing.T) {
type TestCase struct {
cmdArgs interface{}
fieldName string
expectedError string
expectedValues []reflect.Value
}

expectedServiceIP := &scw.IPNet{
IPNet: net.IPNet{
IP: net.ParseIP("192.0.2.1"),
Mask: net.CIDRMask(24, 32),
},
}

tests := []struct {
name string
testCase TestCase
testFunc func(*testing.T, TestCase)
}{
{
name: "Simple test",
testCase: TestCase{
cmdArgs: &ExtendedRequest{
CreateRequest: &CreateRequest{
RequestEmbedding: &RequestEmbedding{
EmbeddingField1: "value1",
EmbeddingField2: 2,
},
CreateField1: "value3",
CreateField2: 4,
},
ExtendedField1: "value5",
ExtendedField2: 6,
},
fieldName: "EmbeddingField1",
expectedError: "",
expectedValues: []reflect.Value{reflect.ValueOf("value1")},
},
testFunc: func(t *testing.T, tc TestCase) {
values, err := core.GetValuesForFieldByName(reflect.ValueOf(tc.cmdArgs), strings.Split(tc.fieldName, "."))
if err != nil {
assert.Equal(t, tc.expectedError, err.Error())
} else {
if tc.expectedValues != nil && !reflect.DeepEqual(tc.expectedValues[0].Interface(), values[0].Interface()) {
t.Errorf("Expected %v, got %v", tc.expectedValues[0].Interface(), values[0].Interface())
}
}
},
},
{
name: "Error test",
testCase: TestCase{
cmdArgs: &ExtendedRequest{
CreateRequest: &CreateRequest{
RequestEmbedding: &RequestEmbedding{
EmbeddingField1: "value1",
EmbeddingField2: 2,
},
CreateField1: "value3",
CreateField2: 4,
},
ExtendedField1: "value5",
ExtendedField2: 6,
},
fieldName: "NotExist",
expectedError: "field NotExist does not exist for ExtendedRequest",
expectedValues: []reflect.Value{reflect.ValueOf("value1")},
},
testFunc: func(t *testing.T, tc TestCase) {
values, err := core.GetValuesForFieldByName(reflect.ValueOf(tc.cmdArgs), strings.Split(tc.fieldName, "."))
if err != nil {
assert.Equal(t, tc.expectedError, err.Error())
} else {
if tc.expectedValues != nil && !reflect.DeepEqual(tc.expectedValues[0].Interface(), values[0].Interface()) {
t.Errorf("Expected %v, got %v", tc.expectedValues[0].Interface(), values[0].Interface())
}
}
},
},
{

name: "Special test",
testCase: TestCase{
cmdArgs: &SpecialRequest{
RequestEmbedding: &RequestEmbedding{
EmbeddingField1: "value1",
EmbeddingField2: 2,
},
TabRequest: []*ArrowRequest{
{
PrivateNetwork: &PrivateNetwork{
EndpointSpecPrivateNetwork: &EndpointSpecPrivateNetwork{
ServiceIP: &scw.IPNet{
IPNet: net.IPNet{
IP: net.ParseIP("192.0.2.1"),
Mask: net.CIDRMask(24, 32),
},
},
},
},
},
{
PrivateNetwork: &PrivateNetwork{
OtherValue: "hello",
},
},
},
},
fieldName: "tabRequest.{index}.privateNetwork.serviceIP",
expectedError: "",
expectedValues: []reflect.Value{reflect.ValueOf(expectedServiceIP)},
},
testFunc: func(t *testing.T, tc TestCase) {
values, err := core.GetValuesForFieldByName(reflect.ValueOf(tc.cmdArgs), strings.Split(tc.fieldName, "."))
if err != nil {
assert.Equal(t, nil, err.Error())
} else {
if tc.expectedValues != nil && !reflect.DeepEqual(tc.expectedValues[0].Interface(), values[0].Interface()) {
t.Errorf("Expected %v, got %v", tc.expectedValues[0].Interface(), values[0].Interface())
}
}
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.testFunc(t, tt.testCase)
})
}
}
6 changes: 3 additions & 3 deletions internal/core/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func DefaultCommandValidateFunc() CommandValidateFunc {
func validateArgValues(cmd *Command, cmdArgs interface{}) error {
for _, argSpec := range cmd.ArgSpecs {
fieldName := strcase.ToPublicGoName(argSpec.Name)
fieldValues, err := getValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
fieldValues, err := GetValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
if err != nil {
logger.Infof("could not validate arg value for '%v': invalid fieldName: %v: %v", argSpec.Name, fieldName, err.Error())
continue
Expand Down Expand Up @@ -75,7 +75,7 @@ func validateRequiredArgs(cmd *Command, cmdArgs interface{}, rawArgs args.RawArg
}

fieldName := strcase.ToPublicGoName(arg.Name)
fieldValues, err := getValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
fieldValues, err := GetValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
if err != nil {
validationErr := fmt.Errorf("could not validate arg value for '%v': invalid field name '%v': %v", arg.Name, fieldName, err.Error())
if !arg.Required {
Expand Down Expand Up @@ -117,7 +117,7 @@ func validateDeprecated(ctx context.Context, cmd *Command, cmdArgs interface{},
deprecatedArgs := cmd.ArgSpecs.GetDeprecated(true)
for _, arg := range deprecatedArgs {
fieldName := strcase.ToPublicGoName(arg.Name)
fieldValues, err := getValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
fieldValues, err := GetValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
if err != nil {
validationErr := fmt.Errorf("could not validate arg value for '%v': invalid field name '%v': %v", arg.Name, fieldName, err.Error())
if !arg.Required {
Expand Down

0 comments on commit de67e68

Please sign in to comment.