diff --git a/args/args.go b/args/args.go index 34cb1f8f..b50928a0 100644 --- a/args/args.go +++ b/args/args.go @@ -149,15 +149,13 @@ func (a *Arguments) MakeLogFunc() backend.LogFunc { if !a.Quiet { if a.Verbose { logger := log.New(os.Stderr, "[INFO] ", 0) - logs.Info = func(v ...interface{}) { - logger.Println(v...) - } + logs.Info = logger.Println + logs.Infof = logger.Printf } logger := log.New(os.Stderr, "[WARN] ", 0) - logs.Warn = func(v ...interface{}) { - logger.Println(v...) - } + logs.Warn = logger.Println + logs.Warnf = logger.Printf logs.MultiWarn = func(ws []string) { for _, w := range ws { logger.Println(w) diff --git a/generator/backend/backend.go b/generator/backend/backend.go index 896c6854..c18e95ca 100644 --- a/generator/backend/backend.go +++ b/generator/backend/backend.go @@ -20,7 +20,9 @@ import "github.com/cloudwego/thriftgo/plugin" // LogFunc defines a set of log functions. type LogFunc struct { Info func(v ...interface{}) + Infof func(fmt string, v ...interface{}) Warn func(v ...interface{}) + Warnf func(fmt string, v ...interface{}) MultiWarn func(warns []string) } @@ -28,7 +30,9 @@ type LogFunc struct { func DummyLogFunc() LogFunc { return LogFunc{ Info: func(v ...interface{}) {}, + Infof: func(fmt string, v ...interface{}) {}, Warn: func(v ...interface{}) {}, + Warnf: func(fmt string, v ...interface{}) {}, MultiWarn: func(warns []string) {}, } } diff --git a/generator/fastgo/bitset.go b/generator/fastgo/bitset.go new file mode 100644 index 00000000..a9f4791f --- /dev/null +++ b/generator/fastgo/bitset.go @@ -0,0 +1,156 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fastgo + +// bitsetCodeGen ... +// it's used by generate required fields bitset +type bitsetCodeGen struct { + varname string + typename string + varbits uint + + i uint + m map[interface{}]uint +} + +// newBitsetCodeGen ... +// varname - definition name of a bitset +// typename - it generates `var $varname [N]$typename` +func newBitsetCodeGen(varname, typename string) *bitsetCodeGen { + ret := &bitsetCodeGen{ + varname: varname, + typename: typename, + m: map[interface{}]uint{}, + } + switch typename { + case "byte", "uint8": + ret.varbits = 8 + case "uint16": + ret.varbits = 16 + case "uint32": + ret.varbits = 32 + case "uint64": + ret.varbits = 64 + default: + panic(typename) + } + return ret +} + +// Add adds a `v` to bitsetCodeGen. `v` must be uniq. +// it will be used by `GenSetbit` and `GenIfNotSet` +func (g *bitsetCodeGen) Add(v interface{}) { + _, ok := g.m[v] + if ok { + panic("duplicated") + } + g.m[v] = g.i + g.i++ +} + +// Len ... +func (g *bitsetCodeGen) Len() int { + return len(g.m) +} + +// GenVar generates the definition of a bitset +// if generates nothing if Add not called +func (g *bitsetCodeGen) GenVar(w *codewriter) { + if g.i == 0 { + return + } + bits := g.varbits + if g.i <= bits { + w.f("var %s %s", g.varname, g.typename) + return + } + w.f("var %s [%d]%s", g.varname, (g.i+bits-1)/bits, g.typename) +} + +func (g *bitsetCodeGen) bitvalue(i uint) uint64 { + i = i % g.varbits + return 1 << uint64(i) +} + +func (g *bitsetCodeGen) bitsvalue(n uint) uint64 { + if n > g.varbits { + panic(n) + } + ret := uint64(0) + for i := uint(0); i < n; i++ { + ret |= 1 << uint64(i) + } + return ret +} + +// GenSetbit generates setbit code for v, vmust be added to bitsetCodeGen +func (g *bitsetCodeGen) GenSetbit(w *codewriter, v interface{}) { + i, ok := g.m[v] + if !ok { + panic("[BUG] unknown v?") + } + if g.i <= g.varbits { + w.f("%s |= 0x%x", g.varname, g.bitvalue(i)) + } else { + w.f("%s[%d] |= 0x%x", g.varname, i/g.varbits, g.bitvalue(i)) + } +} + +// GenIfNotSet generates `if` code for each v +func (g *bitsetCodeGen) GenIfNotSet(w *codewriter, f func(w *codewriter, v interface{})) { + if len(g.m) == 0 { + return + } + m := make(map[uint]interface{}) + for k, v := range g.m { + m[v] = k + } + if g.i <= g.varbits { + if g.i > g.varbits/2 { + w.f("if %s != 0x%x {", g.varname, g.bitsvalue(g.i)) + defer w.f("}") + } + for i := uint(0); i < g.i; i++ { + w.f("if %s & 0x%x == 0 {", g.varname, g.bitvalue(i)) + f(w, m[i]) + w.f("}") + } + return + } + i := uint(0) + for i+g.varbits < g.i { + w.f("if %s[%d] != 0x%x {", g.varname, i/g.varbits, g.bitsvalue(g.varbits)) + end := i + g.varbits + for ; i < end; i++ { + w.f("if %s[%d] & 0x%x == 0 {", g.varname, i/g.varbits, g.bitvalue(i)) + f(w, m[i]) + w.f("}") + } + w.f("}") + } + if i < g.i { + if g.i%g.varbits > g.varbits/2 { + w.f("if %s[%d] != 0x%x {", g.varname, i/g.varbits, g.bitsvalue(g.i%g.varbits)) + defer w.f("}") + } + for ; i < g.i; i++ { + w.f("if %s[%d] & 0x%x == 0 {", g.varname, i/g.varbits, g.bitvalue(i)) + f(w, m[i]) + w.f("}") + } + } +} diff --git a/generator/fastgo/bitset_test.go b/generator/fastgo/bitset_test.go new file mode 100644 index 00000000..6dc32a6c --- /dev/null +++ b/generator/fastgo/bitset_test.go @@ -0,0 +1,106 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fastgo + +import ( + "fmt" + "strings" + "testing" +) + +func TestBitsetCodeGen(t *testing.T) { + g := newBitsetCodeGen("bitset", "uint8") + + w := newCodewriter() + + // case: less or equal than 64 elements + g.Add(1) + g.Add(2) + g.GenVar(w) + srcEqual(t, w.String(), "var bitset uint8") + w.Reset() + + g.GenSetbit(w, 2) + srcEqual(t, w.String(), "bitset |= 0x2") + w.Reset() + + g.GenIfNotSet(w, func(w *codewriter, id interface{}) { + w.f("_ = %d", id) + }) + srcEqual(t, w.String(), `if bitset & 0x1 == 0 { _ = 1 } + if bitset & 0x2 == 0 { _ = 2 }`) + w.Reset() + + g.Add(3) + g.Add(4) + g.Add(5) + g.Add(6) + g.Add(7) + g.Add(8) // case: g.i > g.varbits/2 + g.GenIfNotSet(w, func(w *codewriter, id interface{}) { + w.f("_ = %d", id) + }) + srcEqual(t, w.String(), `if bitset != 0xff { + if bitset & 0x1 == 0 { _ = 1 } + if bitset & 0x2 == 0 { _ = 2 } + if bitset & 0x4 == 0 { _ = 3 } + if bitset & 0x8 == 0 { _ = 4 } + if bitset & 0x10 == 0 { _ = 5 } + if bitset & 0x20 == 0 { _ = 6 } + if bitset & 0x40 == 0 { _ = 7 } + if bitset & 0x80 == 0 { _ = 8 } + }`) + w.Reset() + + // case: more than varbits elements + g = newBitsetCodeGen("bitset", "uint8") + for i := 0; i < 17; i++ { + g.Add(i + 100) + } + g.GenVar(w) + srcEqual(t, w.String(), "var bitset [3]uint8") + w.Reset() + + g.GenSetbit(w, 100) + srcEqual(t, w.String(), "bitset[0] |= 0x1") + w.Reset() + + g.GenSetbit(w, 115) + srcEqual(t, w.String(), "bitset[1] |= 0x80") + w.Reset() + + g.GenSetbit(w, 116) + srcEqual(t, w.String(), "bitset[2] |= 0x1") + w.Reset() + + g.GenIfNotSet(w, func(w *codewriter, id interface{}) { + w.f("_ = %d", id) + }) + sb := &strings.Builder{} + fmt.Fprintln(sb, "if bitset[0] != 0xff {") + for i := 0; i < 8; i++ { + fmt.Fprintf(sb, "if bitset[0]&0x%x == 0 { _ = %d }\n", 1< alias +} + +func newCodewriter() *codewriter { + return &codewriter{ + Buffer: &bytes.Buffer{}, + pkgs: make(map[string]string), + } +} + +func (w *codewriter) UsePkg(s, a string) { + if path.Base(s) == a { + w.pkgs[s] = "" + } else { + w.pkgs[s] = a + } +} + +func (w *codewriter) Imports() string { + pp0 := make([]string, 0, len(w.pkgs)) + pp1 := make([]string, 0, len(w.pkgs)) // for cloudwego + for pkg, _ := range w.pkgs { // grouping + if strings.HasPrefix(pkg, cloudwegoRepoPrefix) { + pp1 = append(pp1, pkg) + } else { + pp0 = append(pp0, pkg) + } + } + + // check if need an empty line between groups + if len(pp0) != 0 && len(pp1) > 0 { + pp0 = append(pp0, "") + } + + // no imports? + pp0 = append(pp0, pp1...) + if len(pp0) == 0 { + return "" + } + + // only imports one pkg? + if len(pp0) == 1 { + return fmt.Sprintf("import %s %q", w.pkgs[pp0[0]], pp0[0]) + } + + // more than one imports + s := &strings.Builder{} + fmt.Fprintln(s, "import (") + for _, p := range pp0 { + if p == "" { + fmt.Fprintln(s, "") + } else { + fmt.Fprintf(s, "%s %q\n", w.pkgs[p], p) + } + } + fmt.Fprintln(s, ")") + return s.String() +} + +func (w *codewriter) f(format string, a ...interface{}) { + fmt.Fprintf(w, format, a...) + + // always newline for each call + if len(format) == 0 || format[len(format)-1] != '\n' { + w.WriteByte('\n') + } +} diff --git a/generator/fastgo/consts.go b/generator/fastgo/consts.go new file mode 100644 index 00000000..169d1729 --- /dev/null +++ b/generator/fastgo/consts.go @@ -0,0 +1,103 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fastgo + +import ( + "strings" + + "github.com/cloudwego/thriftgo/parser" + "github.com/cloudwego/thriftgo/version" +) + +const cloudwegoRepoPrefix = "github.com/cloudwego/" + +var fixedFileHeader string + +func init() { + fixedFileHeader = strings.Replace( + `// Code generated by thriftgo ({{Version}}) (fastgo). DO NOT EDIT.`, + "{{Version}}", + version.ThriftgoVersion, 1) +} + +const ( // wiretypes + tSTOP = 0 + tVOID = 1 + tBOOL = 2 + tBYTE = 3 + tI08 = 3 + tDOUBLE = 4 + tI16 = 6 + tI32 = 8 + tI64 = 10 + tSTRING = 11 + tUTF7 = 11 + tSTRUCT = 12 + tMAP = 13 + tSET = 14 + tLIST = 15 + tUTF8 = 16 + tUTF16 = 17 +) + +var category2ThriftWireType = [16]int{ + // 0-15, panic if Category_Typedef or Category_Service + parser.Category_Bool: tBOOL, + parser.Category_Byte: tI08, + parser.Category_I16: tI16, + parser.Category_I32: tI32, + parser.Category_I64: tI64, + parser.Category_Double: tDOUBLE, + parser.Category_String: tSTRING, + parser.Category_Binary: tSTRING, + parser.Category_Map: tMAP, + parser.Category_List: tLIST, + parser.Category_Set: tSET, + parser.Category_Enum: tI32, + parser.Category_Struct: tSTRUCT, + parser.Category_Union: tSTRUCT, + parser.Category_Exception: tSTRUCT, +} + +var category2GopkgConsts = [16]string{ + // 0-15, panic if Category_Typedef or Category_Service + parser.Category_Bool: "thrift.BOOL", + parser.Category_Byte: "thrift.I08", + parser.Category_I16: "thrift.I16", + parser.Category_I32: "thrift.I32", + parser.Category_I64: "thrift.I64", + parser.Category_Double: "thrift.DOUBLE", + parser.Category_String: "thrift.STRING", + parser.Category_Binary: "thrift.STRING", + parser.Category_Map: "thrift.MAP", + parser.Category_List: "thrift.LIST", + parser.Category_Set: "thrift.SET", + parser.Category_Enum: "thrift.I32", + parser.Category_Struct: "thrift.STRUCT", + parser.Category_Union: "thrift.STRUCT", + parser.Category_Exception: "thrift.STRUCT", +} + +var category2WireSize = [16]int{ + parser.Category_Bool: 1, + parser.Category_Byte: 1, + parser.Category_I16: 2, + parser.Category_I32: 4, + parser.Category_Enum: 4, + parser.Category_I64: 8, + parser.Category_Double: 8, +} diff --git a/generator/fastgo/fastgo.go b/generator/fastgo/fastgo.go new file mode 100644 index 00000000..bd04ff96 --- /dev/null +++ b/generator/fastgo/fastgo.go @@ -0,0 +1,186 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fastgo + +import ( + "bytes" + "fmt" + "go/format" + "path" + "path/filepath" + + "github.com/cloudwego/thriftgo/generator/backend" + "github.com/cloudwego/thriftgo/generator/golang" + "github.com/cloudwego/thriftgo/parser" + "github.com/cloudwego/thriftgo/plugin" +) + +// FastGoBackend ... +type FastGoBackend struct { + golang.GoBackend + + req *plugin.Request + log backend.LogFunc + + utils *golang.CodeUtils +} + +var _ backend.Backend = &FastGoBackend{} + +// Name implements the Backend interface. +func (g *FastGoBackend) Name() string { return "fastgo" } + +// Lang implements the Backend interface. +func (g *FastGoBackend) Lang() string { return "FastGo" } + +// Generate implements the Backend interface. +func (g *FastGoBackend) Generate(req *plugin.Request, log backend.LogFunc) *plugin.Response { + ret := g.GoBackend.Generate(req, log) + if ret.Error != nil { + return ret + } + g.req = req + g.log = log + g.utils = g.GoBackend.GetCoreUtils() + var trees chan *parser.Thrift + if req.Recursive { + trees = req.AST.DepthFirstSearch() + } else { + trees = make(chan *parser.Thrift, 1) + trees <- req.AST + close(trees) + } + respErr := func(err error) *plugin.Response { + errstr := err.Error() + ret.Error = &errstr + return ret + } + processed := make(map[*parser.Thrift]bool) + for ast := range trees { + if processed[ast] { + continue + } + processed[ast] = true + log.Info("Processing", ast.Filename) + content, err := g.GenerateOne(ast) + if err != nil { + return respErr(err) + } + ret.Contents = append(ret.Contents, content) + } + return ret +} + +func (g *FastGoBackend) GenerateOne(ast *parser.Thrift) (*plugin.Generated, error) { + // the filename should differentiate the default code files, + // keep same as kitex, coz we're deprecating the old impl of fastcodec. + // it will overwrites the old k-xxx.go. + filename := "k-" + g.utils.GetFilename(ast) + filename = filepath.Join(g.utils.CombineOutputPath(g.req.OutputPath, ast), filename) + + // not generating ref code, see `code_ref` parameter + scope, _, err := golang.BuildRefScope(g.utils, ast) + if err != nil { + return nil, fmt.Errorf("golang.BuildRefScope: %w", err) + } + + w := newCodewriter() + + // TODO: only supports struct now, other dirty jobs will be done in golang.GoBackend + for _, s := range scope.Structs() { + g.generateStruct(w, scope, s) + } + for _, s := range scope.Unions() { + g.generateStruct(w, scope, s) + } + for _, s := range scope.Exceptions() { + g.generateStruct(w, scope, s) + } + for _, ss := range scope.Services() { + for _, f := range ss.Functions() { + if s := f.ArgType(); s != nil { + g.generateStruct(w, scope, s) + } + if s := f.ResType(); s != nil { + g.generateStruct(w, scope, s) + } + } + } + + ret := &plugin.Generated{} + ret.Name = &filename + + // for ret.Content + c := &bytes.Buffer{} + + // Headers: + // thriftgo version and package name + packageName := path.Base(golang.GetImportPath(g.utils, ast)) + fmt.Fprintf(c, "%s\npackage %s\n\n", fixedFileHeader, packageName) + + // Imports + unusedProtect := false + for _, incl := range scope.Includes() { + if incl == nil { // TODO(liyun.339): fix this + continue + } + unusedProtect = true + w.UsePkg(incl.ImportPath, incl.PackageName) + } + if len(w.pkgs) > 0 { + c.WriteString(w.Imports()) + } + c.WriteByte('\n') + + // Unused protects + if unusedProtect { + fmt.Fprintln(c, "var (") + for _, incl := range scope.Includes() { + if incl == nil { // TODO(liyun.339): fix this + continue + } + fmt.Fprintf(c, "_ = %s.KitexUnusedProtection\n", incl.PackageName) + } + fmt.Fprintln(c, ")") + } + + // Methods + c.Write(w.Bytes()) + + ret.Content = g.Format(filename, c.Bytes()) + return ret, nil +} + +func (g *FastGoBackend) Format(filename string, content []byte) string { + if g.utils.Features().NoFmt { + return string(content) + } + if formated, err := format.Source(content); err != nil { + g.log.Warnf("Failed to format %s: %s", filename, err.Error()) + } else { + content = formated + } + return string(content) +} + +func (g *FastGoBackend) generateStruct(w *codewriter, scope *golang.Scope, s *golang.StructLike) { + // TODO: This method doesn't generate struct definition for now. + // It only generates a better version of FastRead, FastWrite(Nocopy) methods which originally from Kitex. + g.genBLength(w, scope, s) + g.genFastWrite(w, scope, s) + g.genFastRead(w, scope, s) +} diff --git a/generator/fastgo/gen_blength.go b/generator/fastgo/gen_blength.go new file mode 100644 index 00000000..19b9c3eb --- /dev/null +++ b/generator/fastgo/gen_blength.go @@ -0,0 +1,181 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fastgo + +import ( + "strconv" + + "github.com/cloudwego/thriftgo/generator/golang" + "github.com/cloudwego/thriftgo/parser" +) + +// genBLength must be aligned with genBLength +// XXX: the code looks a bit redundant ... +func (g *FastGoBackend) genBLength(w *codewriter, scope *golang.Scope, s *golang.StructLike) { + // var conventions: + // - p is the var of pointer to the struct going to be generated + // - off is the counter of BLength + + // func definition + w.UsePkg("github.com/cloudwego/gopkg/protocol/thrift", "") + w.f("func (p *%s) BLength() int {", s.GoName()) + + // case nil, STOP + w.f("if p == nil { return 1; }") + + w.f("off := 0") + + // fields + ff := getSortedFields(s) + for _, f := range ff { + rwctx, err := g.utils.MkRWCtx(scope, f) + if err != nil { + // never goes here, should fail early in generator/golang pkg + panic(err) + } + genBLengthField(w, rwctx, f) + } + + // end of field encoding + w.f("return off + 1") // return including the STOP byte + + // end of func definition + w.f("}\n\n") +} + +func genBLengthField(w *codewriter, rwctx *golang.ReadWriteContext, f *golang.Field) { + // the real var name ref to the field + varname := string("p." + f.GoName()) + + // add comment like // ${FieldName} ${FieldID} ${FieldType} + w.f("\n// %s ID:%d %s", rwctx.Target, f.ID, category2GopkgConsts[f.Type.Category]) + + // check skip cases + // only for optional fields + if f.Requiredness == parser.FieldType_Optional { + if f.GoTypeName().IsPointer() || isContainerType(f.Type) { + // case 1: optional and nil + w.f("if %s != nil {", varname) + defer w.f("}") + } else if !f.GoTypeName().IsPointer() && f.Default != nil { + // case 2: optional and equals to default value + w.f("if %s != %v {", varname, f.DefaultValue()) + defer w.f("}") + } + } + + // field header + w.f("off += 3") // type + fid + + // field value + genBLengthAny(w, rwctx, varname, 0) + +} + +func genBLengthAny(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) { + t := rwctx.Type + if sz := category2WireSize[t.Category]; sz > 0 { + w.f("off += %d", sz) + return + } + pointer := rwctx.IsPointer + switch t.Category { + case parser.Category_String, parser.Category_Binary: + genBLengthString(w, pointer, varname) + case parser.Category_Map: + genBLengthMap(w, rwctx, varname, depth) + case parser.Category_List, parser.Category_Set: + genBLengthList(w, rwctx, varname, depth) + case parser.Category_Struct, parser.Category_Union, parser.Category_Exception: + genBLengthStruct(w, rwctx, varname) + } +} + +func genBLengthBinary(w *codewriter, pointer bool, varname string) { + varname = varnameVal(pointer, varname) + w.f("off += 4 + len(%s)", varname) +} + +func genBLengthString(w *codewriter, pointer bool, varname string) { + varname = varnameVal(pointer, varname) + w.f("off += 4 + len(%s)", varname) +} + +func genBLengthStruct(w *codewriter, _ *golang.ReadWriteContext, varname string) { + w.f("off += %s.BLength()", varname) +} + +func genBLengthList(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) { + t := rwctx.Type + // list header + w.f("off += 5") + + // if element is basic type like int32, we can speed up the calc by sizeof(int32) * len(l) + if sz := category2WireSize[t.ValueType.Category]; sz > 0 { // fast path for less code + w.f("off += len(%s) * %d", varnameVal(rwctx.IsPointer, varname), sz) + return + } + + // iteration tmp var + tmpv := "v" + if depth > 0 { // avoid redeclared vars + tmpv = "v" + strconv.Itoa(depth-1) + } + w.f("for _, %s := range %s {", tmpv, varname) + genBLengthAny(w, rwctx.ValCtx, tmpv, depth+1) + w.f("}") +} + +func genBLengthMap(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) { + t := rwctx.Type + kt := t.KeyType + vt := t.ValueType + + // map header + w.f("off += 6") + + // iteration tmp var + tmpk := "k" + tmpv := "v" + if depth > 0 { // avoid redeclared vars + tmpk = "k" + strconv.Itoa(depth-1) + tmpv = "v" + strconv.Itoa(depth-1) + } + + // if key or value is basic type like int32, we can speed up the calc by sizeof(int32) * len(m) + varname = varnameVal(rwctx.IsPointer, varname) + ksz := category2WireSize[kt.Category] + vsz := category2WireSize[vt.Category] + if ksz > 0 && vsz > 0 { + w.f("off += len(%s) * (%d+%d)", varname, ksz, vsz) + } else if ksz > 0 { + w.f("off += len(%s) * %d", varname, ksz) + w.f("for _, %s := range %s {", tmpv, varname) + genBLengthAny(w, rwctx.ValCtx, tmpv, depth+1) + w.f("}") + } else if vsz > 0 { + w.f("off += len(%s) * %d", varname, vsz) + w.f("for %s, _ := range %s {", tmpk, varname) + genBLengthAny(w, rwctx.KeyCtx, tmpk, depth+1) + w.f("}") + } else { + w.f("for %s, %s := range %s {", tmpk, tmpv, varname) + genBLengthAny(w, rwctx.KeyCtx, tmpk, depth+1) + genBLengthAny(w, rwctx.ValCtx, tmpv, depth+1) + w.f("}") + } +} diff --git a/generator/fastgo/gen_fastread.go b/generator/fastgo/gen_fastread.go new file mode 100644 index 00000000..7da4e1a0 --- /dev/null +++ b/generator/fastgo/gen_fastread.go @@ -0,0 +1,328 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fastgo + +import ( + "strconv" + + "github.com/cloudwego/thriftgo/generator/golang" + "github.com/cloudwego/thriftgo/parser" +) + +func (g *FastGoBackend) genFastRead(w *codewriter, scope *golang.Scope, s *golang.StructLike) { + // var conventions: + // - p is the var of pointer to the struct going to be generated + // - b is the buf to read from + // - off is the offset of b + // - err is the return err + // - ftyp, fid only used in this method + // - l must be increased after read + // - enum is the tmp var for enum, it's updated by ReadInt32, and then set to the enum field + // - x is the decoder of thrift.BinaryProtocol + // + // Please update the list if you'r going to add more vars + // Instead of using consts for vars above, would like to use the names directly making code clear + + // func definition + w.UsePkg("github.com/cloudwego/gopkg/protocol/thrift", "") + w.f("func (p *%s) FastRead(b []byte) (off int, err error) {", s.GoName()) + w.f("var ftyp thrift.TType") + w.f("var fid int16") + w.f("var l int") + + isset := newBitsetCodeGen("isset", "uint8") + hasEnum := false + ff := getSortedFields(s) + for _, f := range ff { + if f.Type.Category == parser.Category_Enum { + hasEnum = true + } + if f.Requiredness == parser.FieldType_Required { + isset.Add(f) + } + } + if hasEnum { + w.f("var enum int32") // tmp var for enum + } + isset.GenVar(w) + + w.f("x := thrift.BinaryProtocol{}") // empty struct, no stack needed, for shorten varname + + w.f("for {") + + w.f("ftyp, fid, l, err = x.ReadFieldBegin(b[off:])") + w.f("off += l") + w.f("if err != nil { goto ReadFieldBeginError }") + w.f("if ftyp == thrift.STOP { break }") + + // fields + w.f("switch uint32(fid)<<8| uint32(ftyp) {") + for _, f := range ff { + rwctx, err := g.utils.MkRWCtx(scope, f) + if err != nil { + // never goes here, should fail early in generator/golang pkg + panic(err) + } + w.f("case 0x%x: // %s ID:%d %s", + uint32(f.ID)<<8|uint32(category2ThriftWireType[f.Type.Category]), + rwctx.Target, f.ID, category2GopkgConsts[f.Type.Category]) + genFastReadAny(w, rwctx, rwctx.Target, 0) + if f.Requiredness == parser.FieldType_Required { + isset.GenSetbit(w, f) + } + } + w.f("default:") // default case, skip + w.f(" l, err = x.Skip(b[off:], ftyp)") + w.f(" off += l") + w.f(" if err != nil { goto SkipFieldError }") + w.f("}") // switch fid ends + w.f("}") // for ends + + isset.GenIfNotSet(w, func(w *codewriter, v interface{}) { + f := v.(*golang.Field) + w.f("fid = %d // %s", f.ID, f.GoName()) + w.f("goto RequiredFieldNotSetError") + }) + + w.f("return") // no error + + w.UsePkg("fmt", "") + w.f("ReadFieldBeginError:") + w.f(`return off, thrift.PrependError(fmt.Sprintf("%%T read field begin error: ", p), err)`) + + if len(ff) > 0 { // fix `label ReadFieldError defined and not used` + w.f("ReadFieldError:") + w.f(`return off, thrift.PrependError(fmt.Sprintf("%%T read field %%d '%%s' error: ", p, fid, fieldIDToName_%s[fid]), err)`, s.GoName()) + } + + w.f("SkipFieldError:") + w.f(`return off, thrift.PrependError(fmt.Sprintf("%%T skip field %%d type %%d error: ", p, fid, ftyp), err)`) + + if isset.Len() > 0 { + w.f("RequiredFieldNotSetError:") + w.f(`return off, thrift.NewProtocolException(thrift.INVALID_DATA, fmt.Sprintf("required field %%s is not set", fieldIDToName_%s[fid]))`, s.GoName()) + } + + // end of func definition + w.f("}\n\n") +} + +func genFastReadAny(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) { + t := rwctx.Type + pointer := rwctx.IsPointer + switch t.Category { + case parser.Category_Bool: + genFastReadBool(w, pointer, varname) + case parser.Category_Byte: + genFastReadByte(w, pointer, varname) + case parser.Category_I16: + genFastReadInt16(w, pointer, varname) + case parser.Category_I32: + genFastReadInt32(w, pointer, varname) + case parser.Category_Enum: + genFastReadEnum(w, rwctx, varname) + case parser.Category_I64: + genFastReadInt64(w, pointer, varname) + case parser.Category_Double: + genFastReadDouble(w, pointer, varname) + case parser.Category_String: + genFastReadString(w, pointer, varname) + case parser.Category_Binary: + genFastReadBinary(w, pointer, varname) + case parser.Category_Map: + genFastReadMap(w, rwctx, varname, depth) + case parser.Category_List: + genFastReadList(w, rwctx, varname, depth) + case parser.Category_Set: + genFastReadList(w, rwctx, varname, depth) + case parser.Category_Struct: + genFastReadStruct(w, rwctx, varname) + case parser.Category_Union: + genFastReadStruct(w, rwctx, varname) + case parser.Category_Exception: + genFastReadStruct(w, rwctx, varname) + } +} + +func genFastReadBool(w *codewriter, pointer bool, varname string) { + if pointer { + w.f("if %s == nil { %s = new(bool) }", varname, varname) + } + w.f("%s, l, err = x.ReadBool(b[off:])", varnameVal(pointer, varname)) + w.f("off += l") + w.f("if err != nil { goto ReadFieldError }") +} + +func genFastReadByte(w *codewriter, pointer bool, varname string) { + if pointer { + w.f("if %s == nil { %s = new(int8) }", varname, varname) + } + w.f("%s, l, err = x.ReadByte(b[off:])", varnameVal(pointer, varname)) + w.f("off += l") + w.f("if err != nil { goto ReadFieldError }") +} + +func genFastReadDouble(w *codewriter, pointer bool, varname string) { + if pointer { + w.f("if %s == nil { %s = new(float64) }", varname, varname) + } + w.f("%s, l, err = x.ReadDouble(b[off:])", varnameVal(pointer, varname)) + w.f("off += l") + w.f("if err != nil { goto ReadFieldError }") +} + +func genFastReadInt16(w *codewriter, pointer bool, varname string) { + if pointer { + w.f("if %s == nil { %s = new(int16) }", varname, varname) + } + w.f("%s, l, err = x.ReadI16(b[off:])", varnameVal(pointer, varname)) + w.f("off += l") + w.f("if err != nil { goto ReadFieldError }") +} + +func genFastReadInt32(w *codewriter, pointer bool, varname string) { + if pointer { + w.f("if %s == nil { %s = new(int32) }", varname, varname) + } + w.f("%s, l, err = x.ReadI32(b[off:])", varnameVal(pointer, varname)) + w.f("off += l") + w.f("if err != nil { goto ReadFieldError }") +} + +func genFastReadEnum(w *codewriter, rwctx *golang.ReadWriteContext, varname string) { + pointer := rwctx.IsPointer + if pointer { + w.f("if %s == nil { %s = new(%s) }", varname, varname, rwctx.TypeName.Deref()) + } + + w.f("enum, l, err = x.ReadI32(b[off:])") + w.f("off += l") + w.f("if err != nil { goto ReadFieldError }") + w.f("%s = %s(enum)", varnameVal(pointer, varname), rwctx.TypeName.Deref()) +} + +func genFastReadInt64(w *codewriter, pointer bool, varname string) { + if pointer { + w.f("if %s == nil { %s = new(int64) }", varname, varname) + } + w.f("%s, l, err = x.ReadI64(b[off:])", varnameVal(pointer, varname)) + w.f("off += l") + w.f("if err != nil { goto ReadFieldError }") +} + +func genFastReadBinary(w *codewriter, pointer bool, varname string) { + if pointer { // always false? + w.f("if %s == nil { %s = new([]byte) } ", varname, varname) + } + w.f("%s, l, err = x.ReadBinary(b[off:])", varnameVal(pointer, varname)) + w.f("off += l") + w.f("if err != nil { goto ReadFieldError }") +} + +func genFastReadString(w *codewriter, pointer bool, varname string) { + if pointer { + w.f("if %s == nil { %s = new(string) } ", varname, varname) + } + w.f("%s, l, err = x.ReadString(b[off:])", varnameVal(pointer, varname)) + w.f("off += l") + w.f("if err != nil { goto ReadFieldError }") +} + +func genFastReadStruct(w *codewriter, rwctx *golang.ReadWriteContext, varname string) { + w.f("%s = %s()", varname, rwctx.TypeName.Deref().NewFunc()) + w.f("l, err = %s.FastRead(b[off:])", varname) + w.f("off += l") + w.f("if err != nil { goto ReadFieldError }") +} + +func genFastReadList(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) { + // var conventions: + // - sz is the size of a list + // - i is unsed to interate for loop + // + // you must use the vars below instead of using literal above, + // coz we may have embedded structs like list> + if depth != 0 { + w.f("{") // new block to protect tmp vars + defer w.f("}") + } + tmpsize := "sz" // for ReadListBegin, size int + tmpi := "i" // loop var + if depth > 0 { // avoid redeclared vars + sub := strconv.Itoa(depth - 1) + tmpsize = tmpsize + sub + tmpi = tmpi + sub + } + + w.f("var %s int", tmpsize) + + // ??? thriftgo & kitex always ignore element type of a list/set? + w.f("_, %s, l, err = x.ReadListBegin(b[off:])", tmpsize) + w.f("off += l") + w.f("if err != nil { goto ReadFieldError }") + + w.f("%s = make(%s, %s)", varname, rwctx.TypeName.Deref(), tmpsize) + w.f("for %s := 0; %s < %s; %s++ {", tmpi, tmpi, tmpsize, tmpi) + genFastReadAny(w, rwctx.ValCtx, varname+"["+tmpi+"]", depth+1) + w.f("}") +} + +func genFastReadMap(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) { + // var conventions: + // - sz is the size of a map + // - i is the counter for decoding a map + // + // you must use the vars below instead of using literal above, + // coz we may have embedded structs like list> + if depth != 0 { + w.f("{") // new block to protect tmp vars + defer w.f("}") + } + tmpsize := "sz" // for ReadMapBegin, size int + tmpk := "k" // for reading keys + tmpv := "v" // for reading values + tmpi := "i" // loop var + if depth > 0 { // avoid redeclared vars + sub := strconv.Itoa(depth - 1) + tmpsize = tmpsize + sub + tmpk = tmpk + sub + tmpv = tmpv + sub + tmpi = tmpi + sub + } + + w.f("var %s int", tmpsize) + + // ??? thriftgo & kitex always ignore kv types of a map? + w.f("_, _, %s, l, err = x.ReadMapBegin(b[off:])", tmpsize) + w.f("off += l") + w.f("if err != nil { goto ReadFieldError }") + + w.f("%s = make(%s, %s)", varname, rwctx.TypeName, tmpsize) + w.f("for %s := 0; %s < %s; %s++ {", tmpi, tmpi, tmpsize, tmpi) + if rwctx.KeyCtx.TypeID == "Struct" && !rwctx.KeyCtx.IsPointer { + // hotfix for struct, it's always pointer for keys + // remove this check after generator/gopkg fix it + w.f("var %s *%s", tmpk, rwctx.KeyCtx.TypeName) + } else { + w.f("var %s %s", tmpk, rwctx.KeyCtx.TypeName) + } + w.f("var %s %s", tmpv, rwctx.ValCtx.TypeName) + genFastReadAny(w, rwctx.KeyCtx, tmpk, depth+1) + genFastReadAny(w, rwctx.ValCtx, tmpv, depth+1) + w.f("%s[%s] = %s", varname, tmpk, tmpv) + w.f("}") +} diff --git a/generator/fastgo/gen_fastwrite.go b/generator/fastgo/gen_fastwrite.go new file mode 100644 index 00000000..246d49f2 --- /dev/null +++ b/generator/fastgo/gen_fastwrite.go @@ -0,0 +1,218 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fastgo + +import ( + "strconv" + + "github.com/cloudwego/thriftgo/generator/golang" + "github.com/cloudwego/thriftgo/parser" +) + +func (g *FastGoBackend) genFastWrite(w *codewriter, scope *golang.Scope, s *golang.StructLike) { + // var conventions: + // - p is the var of pointer to the struct going to be generated + // - b is the buf to write into + // - w is the var of thrift.NocopyWriter + // - off is the offset of b + + // func definition + w.UsePkg("github.com/cloudwego/gopkg/protocol/thrift", "") + w.f("func (p *%s) FastWrite(b []byte) int { return p.FastWriteNocopy(b, nil) }\n\n", s.GoName()) + w.f("func (p *%s) FastWriteNocopy(b []byte, w thrift.NocopyWriter) int {", s.GoName()) + + // case nil, STOP and return + w.f("if p == nil { b[0] = 0; return 1; }") + + // `off` definition for buf cursor + w.f("off := 0") + + // fields + ff := getSortedFields(s) + for _, f := range ff { + rwctx, err := g.utils.MkRWCtx(scope, f) + if err != nil { + // never goes here, should fail early in generator/golang pkg + panic(err) + } + genFastWriteField(w, rwctx, f) + } + + // end of field encoding + w.f("") // empty line + w.f("b[off] = 0") // STOP + w.f("return off + 1") // return including the STOP byte + + // end of func definition + w.f("}\n\n") +} + +func genFastWriteField(w *codewriter, rwctx *golang.ReadWriteContext, f *golang.Field) { + // the real var name ref to the field + varname := string("p." + f.GoName()) + + // add comment like // ${FieldName} ${FieldID} ${FieldType} + w.f("\n// %s ID:%d %s", rwctx.Target, f.ID, category2GopkgConsts[f.Type.Category]) + + // check skip cases + // only for optional fields + if f.Requiredness == parser.FieldType_Optional { + if f.GoTypeName().IsPointer() || isContainerType(f.Type) { + // case 1: optional and nil + w.f("if %s != nil {", varname) + defer w.f("}") + } else if !f.GoTypeName().IsPointer() && f.Default != nil { + // case 2: optional and equals to default value + w.f("if %s != %v {", varname, f.DefaultValue()) + defer w.f("}") + } + } + + // field header + w.UsePkg("encoding/binary", "") + w.f("b[off] = %d", category2ThriftWireType[f.Type.Category]) + w.f("binary.BigEndian.PutUint16(b[off+1:], %d) ", f.ID) + w.f("off += 3") + + // field value + genFastWriteAny(w, rwctx, varname, 0) + +} + +func genFastWriteAny(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) { + t := rwctx.Type + pointer := rwctx.IsPointer + switch t.Category { + case parser.Category_Bool: + genFastWriteBool(w, pointer, varname) + case parser.Category_Byte: + genFastWriteByte(w, pointer, varname) + case parser.Category_I16: + genFastWriteInt16(w, pointer, varname) + case parser.Category_I32, parser.Category_Enum: + genFastWriteInt32(w, pointer, varname) + case parser.Category_I64: + genFastWriteInt64(w, pointer, varname) + case parser.Category_Double: + genFastWriteDouble(w, pointer, varname) + case parser.Category_String: + genFastWriteString(w, pointer, varname) + case parser.Category_Binary: + genFastWriteBinary(w, pointer, varname) + case parser.Category_Map: + genFastWriteMap(w, rwctx, varname, depth) + case parser.Category_List, parser.Category_Set: + genFastWriteList(w, rwctx, varname, depth) + case parser.Category_Struct, parser.Category_Union, parser.Category_Exception: + // TODO: fix for parser.Category_Union? must only one field set + genFastWriteStruct(w, rwctx, varname) + } +} + +func genFastWriteBool(w *codewriter, pointer bool, varname string) { + // for bool, the underlying byte of true is always 1, and 0 for false + // which is same as thrift binary protocol + w.UsePkg("unsafe", "") + w.f("b[off] = *((*byte)(unsafe.Pointer(%s)))", varnamePtr(pointer, varname)) + w.f("off++") +} + +func genFastWriteByte(w *codewriter, pointer bool, varname string) { + w.f("b[off] = byte(%s)", varnameVal(pointer, varname)) + w.f("off++") +} + +func genFastWriteDouble(w *codewriter, pointer bool, varname string) { + w.UsePkg("unsafe", "") + w.f("binary.BigEndian.PutUint64(b[off:], *(*uint64)(unsafe.Pointer(%s)))", varnamePtr(pointer, varname)) + w.f("off += 8") +} + +func genFastWriteInt16(w *codewriter, pointer bool, varname string) { + w.UsePkg("encoding/binary", "") + w.f("binary.BigEndian.PutUint16(b[off:], uint16(%s))", varnameVal(pointer, varname)) + w.f("off += 2") +} + +func genFastWriteInt32(w *codewriter, pointer bool, varname string) { + w.UsePkg("encoding/binary", "") + w.f("binary.BigEndian.PutUint32(b[off:], uint32(%s))", varnameVal(pointer, varname)) + w.f("off += 4") +} + +func genFastWriteInt64(w *codewriter, pointer bool, varname string) { + w.UsePkg("encoding/binary", "") + w.f("binary.BigEndian.PutUint64(b[off:], uint64(%s))", varnameVal(pointer, varname)) + w.f("off += 8") +} + +func genFastWriteBinary(w *codewriter, pointer bool, varname string) { + varname = varnameVal(pointer, varname) + w.f("off += thrift.Binary.WriteBinaryNocopy(b[off:], w, %s)", varname) +} + +func genFastWriteString(w *codewriter, pointer bool, varname string) { + varname = varnameVal(pointer, varname) + w.f("off += thrift.Binary.WriteStringNocopy(b[off:], w, %s)", varname) +} + +func genFastWriteStruct(w *codewriter, rwctx *golang.ReadWriteContext, varname string) { + w.f("off += %s.FastWriteNocopy(b[off:], w)", varname) +} + +func genFastWriteList(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) { + rwctx = rwctx.ValCtx + t := rwctx.Type + w.UsePkg("encoding/binary", "") + // list header + w.f("b[off] = %d", category2ThriftWireType[t.Category]) + w.f("binary.BigEndian.PutUint32(b[off+1:], uint32(len(%s)))", varname) + w.f("off += 5") + + // iteration tmp var + tmpv := "v" + if depth > 0 { // avoid redeclared vars + tmpv = "v" + strconv.Itoa(depth-1) + } + w.f("for _, %s := range %s {", tmpv, varname) + genFastWriteAny(w, rwctx, tmpv, depth+1) + w.f("}") +} + +func genFastWriteMap(w *codewriter, rwctx *golang.ReadWriteContext, varname string, depth int) { + t := rwctx.Type + kt := t.KeyType + vt := t.ValueType + // map header + w.UsePkg("encoding/binary", "") + w.f("b[off] = %d", category2ThriftWireType[kt.Category]) + w.f("b[off+1] = %d", category2ThriftWireType[vt.Category]) + w.f("binary.BigEndian.PutUint32(b[off+2:], uint32(len(%s)))", varname) + w.f("off += 6") + + // iteration tmp var + tmpk := "k" + tmpv := "v" + if depth > 0 { // avoid redeclared vars + tmpk = "k" + strconv.Itoa(depth-1) + tmpv = "v" + strconv.Itoa(depth-1) + } + w.f("for %s, %s := range %s {", tmpk, tmpv, varname) + genFastWriteAny(w, rwctx.KeyCtx, tmpk, depth+1) + genFastWriteAny(w, rwctx.ValCtx, tmpv, depth+1) + w.f("}") +} diff --git a/generator/fastgo/utils.go b/generator/fastgo/utils.go new file mode 100644 index 00000000..a98c5a6c --- /dev/null +++ b/generator/fastgo/utils.go @@ -0,0 +1,58 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fastgo + +import ( + "sort" + + "github.com/cloudwego/thriftgo/generator/golang" + "github.com/cloudwego/thriftgo/parser" +) + +func isContainerType(f *parser.Type) bool { + switch f.Category { + case parser.Category_Map, + parser.Category_List, + parser.Category_Set: + return true + case parser.Category_Binary: + return true // []byte, a byte list + } + return false +} + +func varnameVal(pointer bool, varname string) string { + if pointer { + return "*" + varname + } + return varname +} + +func varnamePtr(pointer bool, varname string) string { + if pointer { + return varname + } + return "&" + varname +} + +// getSortedFields returns fields sorted by field id. +// we don't want to see code changes due to field order. +func getSortedFields(s *golang.StructLike) []*golang.Field { + ff := append([]*golang.Field(nil), s.Fields()...) + sort.Slice(ff, func(i, j int) bool { return ff[i].ID < ff[j].ID }) + return ff +} diff --git a/generator/fastgo/utils_test.go b/generator/fastgo/utils_test.go new file mode 100644 index 00000000..da14deb2 --- /dev/null +++ b/generator/fastgo/utils_test.go @@ -0,0 +1,47 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fastgo + +import ( + "fmt" + "go/format" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/cloudwego/thriftgo/pkg/test" +) + +func srcEqual(t *testing.T, a, b string) { + _, file, line, _ := runtime.Caller(1) + location := fmt.Sprintf("@ %s:%d", filepath.Base(file), line) + b0, err := format.Source([]byte(a)) + if err != nil { + t.Log("syntax err", err, a, location) + t.FailNow() + } + b1, err := format.Source([]byte(b)) + if err != nil { + t.Log("syntax err", err, a, location) + t.FailNow() + } + s0 := strings.TrimSpace(string(b0)) + s1 := strings.TrimSpace(string(b1)) + + test.Assert(t, s0 == s1, fmt.Sprintf("\n%s\n != \n%s\n %s", s0, s1, location)) +} diff --git a/generator/golang/backend.go b/generator/golang/backend.go index caa29d2d..b949f300 100644 --- a/generator/golang/backend.go +++ b/generator/golang/backend.go @@ -107,6 +107,10 @@ func (g *GoBackend) Generate(req *plugin.Request, log backend.LogFunc) *plugin.R return g.buildResponse() } +func (g *GoBackend) GetCoreUtils() *CodeUtils { + return g.utils +} + func (g *GoBackend) prepareUtilities() { if g.err != nil { return diff --git a/main.go b/main.go index ef93d854..9966bf2e 100644 --- a/main.go +++ b/main.go @@ -16,17 +16,28 @@ package main import ( "fmt" - "github.com/cloudwego/thriftgo/sdk" "os" "runtime/debug" "runtime/pprof" - "time" + + "github.com/cloudwego/thriftgo/args" + "github.com/cloudwego/thriftgo/generator" + "github.com/cloudwego/thriftgo/generator/fastgo" + "github.com/cloudwego/thriftgo/generator/golang" + "github.com/cloudwego/thriftgo/sdk" +) + +var ( + a args.Arguments + g generator.Generator ) var debugMode bool func init() { + _ = g.RegisterBackend(new(golang.GoBackend)) + _ = g.RegisterBackend(new(fastgo.FastGoBackend)) // export THRIFTGO_DEBUG=1 debugMode = os.Getenv("THRIFTGO_DEBUG") == "1" }