diff --git a/pkg/codec.go b/pkg/codec.go index 5209761b..8af73f4a 100644 --- a/pkg/codec.go +++ b/pkg/codec.go @@ -23,7 +23,6 @@ import ( "context" "errors" "fmt" - "reflect" hessian "github.com/apache/dubbo-go-hessian2" "github.com/apache/dubbo-go-hessian2/java_exception" @@ -217,7 +216,7 @@ func (m *Hessian2Codec) messageData(message remote.Message, e iface.Encoder) err if !ok { return fmt.Errorf("invalid data: not hessian2.MessageWriter") } - types, err := getTypes(data) + types, err := dubbo.GetTypes(data) if err != nil { return err } @@ -375,10 +374,6 @@ func processAttachments(decoder iface.Decoder, message remote.Message) error { if err != nil { return err } - // - //if attachmentsRaw == nil || attachmentsRaw == "" { - // attachmentsRaw = map[interface{}]interface{}{"interface": service.} - //} if attachments, ok := attachmentsRaw.(map[interface{}]interface{}); ok { for keyRaw, val := range attachments { @@ -392,16 +387,6 @@ func processAttachments(decoder iface.Decoder, message remote.Message) error { return fmt.Errorf("unsupported attachments: %v", attachmentsRaw) } -func getTypes(data interface{}) (string, error) { - elem := reflect.ValueOf(data).Elem() - numField := elem.NumField() - fields := make([]interface{}, numField) - for i := 0; i < numField; i++ { - fields[i] = elem.Field(i).Interface() - } - return dubbo.GetParamsTypeList(fields) -} - func readBody(header *dubbo.DubboHeader, in remote.ByteBuffer) ([]byte, error) { length := int(header.DataLength) if in.ReadableLen() < length { diff --git a/pkg/dubbo/parameter.go b/pkg/dubbo/parameter.go index ddb2854f..02e05067 100644 --- a/pkg/dubbo/parameter.go +++ b/pkg/dubbo/parameter.go @@ -15,6 +15,10 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. + * + * This source file has been replicated from the original dubbo-go project + * repository, and we extend our sincere appreciation to the dubbo-go + * development team for their valuable contribution. */ package dubbo @@ -23,13 +27,46 @@ import ( "fmt" "reflect" "strings" + "sync" "time" hessian "github.com/apache/dubbo-go-hessian2" ) +var ( + typesMapLock sync.Mutex + typesMap = make(map[reflect.Type]string) +) + +func GetTypes(data interface{}) (string, error) { + val := reflect.ValueOf(data) + typ := val.Type() + typesMapLock.Lock() + types, ok := typesMap[typ] + if ok { + typesMapLock.Unlock() + return types, nil + } + + elem := val.Elem() + numField := elem.NumField() + fields := make([]interface{}, numField) + for i := 0; i < numField; i++ { + fields[i] = elem.Field(i).Interface() + } + + types, err := getParamsTypeList(fields) + if err != nil { + typesMapLock.Unlock() + return "", err + } + typesMap[typ] = types + typesMapLock.Unlock() + return types, nil +} + // GetParamsTypeList is copied from dubbo-go, it should be rewritten -func GetParamsTypeList(params []interface{}) (string, error) { +func getParamsTypeList(params []interface{}) (string, error) { var ( typ string types string diff --git a/pkg/dubbo/parameter_test.go b/pkg/dubbo/parameter_test.go new file mode 100644 index 00000000..af390316 --- /dev/null +++ b/pkg/dubbo/parameter_test.go @@ -0,0 +1,132 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dubbo + +import ( + "fmt" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +type testInternalStruct struct { + Field int +} + +type testStructA struct { + Field int +} + +type testStructB struct { + Internal *testInternalStruct +} + +func TestGetTypes(t *testing.T) { + tests := []struct { + desc string + datum []interface{} + expected func(t *testing.T, typesMap map[reflect.Type]string) + }{ + { + desc: "same structs with basic Type", + datum: []interface{}{ + &testStructA{Field: 1}, + &testStructA{Field: 2}, + }, + expected: func(t *testing.T, typesMap map[reflect.Type]string) { + data := &testStructA{Field: 3} + typ := reflect.ValueOf(data).Type() + assert.Equal(t, 1, len(typesMap)) + assert.Equal(t, "J", typesMap[typ]) + }, + }, + { + desc: "same structs with embedded Type", + datum: []interface{}{ + &testStructB{ + Internal: &testInternalStruct{ + Field: 1, + }, + }, + &testStructB{ + Internal: &testInternalStruct{ + Field: 2, + }, + }, + }, + expected: func(t *testing.T, typesMap map[reflect.Type]string) { + data := &testStructB{ + Internal: &testInternalStruct{ + Field: 3, + }, + } + typ := reflect.ValueOf(data).Type() + assert.Equal(t, 1, len(typesMap)) + assert.Equal(t, "Ljava/lang/Object;", typesMap[typ]) + }, + }, + { + desc: "different structs", + datum: []interface{}{ + &testStructA{Field: 1}, + &testStructB{ + Internal: &testInternalStruct{ + Field: 2, + }, + }, + }, + expected: func(t *testing.T, typesMap map[reflect.Type]string) { + dataA := &testStructA{Field: 3} + dataB := &testStructB{ + Internal: &testInternalStruct{ + Field: 3, + }, + } + typA := reflect.ValueOf(dataA).Type() + typB := reflect.ValueOf(dataB).Type() + assert.Equal(t, 2, len(typesMap)) + assert.Equal(t, "J", typesMap[typA]) + assert.Equal(t, "Ljava/lang/Object;", typesMap[typB]) + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + // run GetTypes concurrently + for i, data := range test.datum { + testData := data + t.Run(fmt.Sprintf("struct%d", i), func(t *testing.T) { + t.Parallel() + _, err := GetTypes(testData) + if err != nil { + t.Fatal(err) + } + }) + } + t.Cleanup(func() { + test.expected(t, typesMap) + // reset + typesMap = make(map[reflect.Type]string) + }) + }) + } +}