diff --git a/main.go b/main.go index 1bac82e7..ef93d854 100644 --- a/main.go +++ b/main.go @@ -16,30 +16,17 @@ package main import ( "fmt" + "github.com/cloudwego/thriftgo/sdk" "os" "runtime/debug" "runtime/pprof" "time" - - "github.com/cloudwego/thriftgo/args" - "github.com/cloudwego/thriftgo/generator" - "github.com/cloudwego/thriftgo/generator/golang" - "github.com/cloudwego/thriftgo/parser" - "github.com/cloudwego/thriftgo/plugin" - "github.com/cloudwego/thriftgo/semantic" - "github.com/cloudwego/thriftgo/version" -) - -var ( - a args.Arguments - g generator.Generator ) var debugMode bool func init() { - _ = g.RegisterBackend(new(golang.GoBackend)) // export THRIFTGO_DEBUG=1 debugMode = os.Getenv("THRIFTGO_DEBUG") == "1" } @@ -67,65 +54,8 @@ func main() { } defer handlePanic() - check(a.Parse(os.Args)) - if a.AskVersion { - println("thriftgo", version.ThriftgoVersion) - os.Exit(0) - } - - log := a.MakeLogFunc() - - ast, err := parser.ParseFile(a.IDL, a.Includes, true) - check(err) - - if path := parser.CircleDetect(ast); len(path) > 0 { - check(fmt.Errorf("found include circle:\n\t%s", path)) - } - - if a.CheckKeyword { - if warns := parser.DetectKeyword(ast); len(warns) > 0 { - log.MultiWarn(warns) - } - } - - checker := semantic.NewChecker(semantic.Options{FixWarnings: true}) - warns, err := checker.CheckAll(ast) - log.MultiWarn(warns) - check(err) - check(semantic.ResolveSymbols(ast)) - - req := &plugin.Request{ - Version: version.ThriftgoVersion, - OutputPath: a.OutputPath, - Recursive: a.Recursive, - AST: ast, - } - - plugin.MaxExecutionTime = a.PluginTimeLimit - plugins, err := a.UsedPlugins() - check(err) - - langs, err := a.Targets() - check(err) - - if len(langs) == 0 { - println("No output language(s) specified") - os.Exit(2) - } - - for _, out := range langs { - out.UsedPlugins = plugins - req.Language = out.Language - req.OutputPath = a.Output(out.Language) - - arg := &generator.Arguments{Out: out, Req: req, Log: log} - res := g.Generate(arg) - log.MultiWarn(res.Warnings) - - err = g.Persist(res) - check(err) - } + check(sdk.InvokeThriftgo(nil, os.Args...)) } func handlePanic() { diff --git a/sdk/invoke.go b/sdk/invoke.go new file mode 100644 index 00000000..e933e993 --- /dev/null +++ b/sdk/invoke.go @@ -0,0 +1,117 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sdk + +import ( + "fmt" + "github.com/cloudwego/thriftgo/generator/golang" + + targs "github.com/cloudwego/thriftgo/args" + "github.com/cloudwego/thriftgo/generator" + "github.com/cloudwego/thriftgo/parser" + "github.com/cloudwego/thriftgo/plugin" + "github.com/cloudwego/thriftgo/semantic" + "github.com/cloudwego/thriftgo/version" +) + +func init() { + _ = g.RegisterBackend(new(golang.GoBackend)) +} + +var ( + g generator.Generator +) + +// InvokeThriftgo is the core logic of thriftgo, from parse idl to generate code. +func InvokeThriftgo(SDKPlugins []plugin.SDKPlugin, args ...string) (err error) { + + var a targs.Arguments + + err = a.Parse(args) + if err != nil { + if err.Error() == "flag: help requested" { + return nil + } + return err + } + + if a.AskVersion { + println("thriftgo", version.ThriftgoVersion) + return nil + } + + // todo check log + log := a.MakeLogFunc() + + ast, err := parser.ParseFile(a.IDL, a.Includes, true) + if err != nil { + return err + } + + if path := parser.CircleDetect(ast); len(path) > 0 { + return fmt.Errorf("found include circle:\n\t%s", path) + } + + checker := semantic.NewChecker(semantic.Options{FixWarnings: true}) + // todo no warnings when sdk? + warns, err := checker.CheckAll(ast) + log.MultiWarn(warns) + if err != nil { + return err + } + + err = semantic.ResolveSymbols(ast) + if err != nil { + return err + } + + req := &plugin.Request{ + Version: version.ThriftgoVersion, + OutputPath: a.OutputPath, + Recursive: a.Recursive, + AST: ast, + } + + plugin.MaxExecutionTime = a.PluginTimeLimit + plugins, err := a.UsedPlugins() + if err != nil { + return err + } + + langs, err := a.Targets() + if err != nil { + return err + } + + if len(langs) == 0 { + return fmt.Errorf("No output language(s) specified") + } + + for _, out := range langs { + out.UsedPlugins = plugins + out.SDKPlugins = SDKPlugins + req.Language = out.Language + req.OutputPath = a.Output(out.Language) + + arg := &generator.Arguments{Out: out, Req: req, Log: log} + res := g.Generate(arg) + + err = g.Persist(res) + if err != nil { + return err + } + } + return nil +} diff --git a/sdk/sdk.go b/sdk/sdk.go index 7bb07174..70385026 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -15,104 +15,23 @@ package sdk import ( - "fmt" - "github.com/cloudwego/thriftgo/utils/dir_utils" "github.com/cloudwego/thriftgo/config" - targs "github.com/cloudwego/thriftgo/args" - "github.com/cloudwego/thriftgo/generator" - "github.com/cloudwego/thriftgo/generator/backend" - "github.com/cloudwego/thriftgo/generator/golang" - "github.com/cloudwego/thriftgo/parser" "github.com/cloudwego/thriftgo/plugin" - "github.com/cloudwego/thriftgo/semantic" - "github.com/cloudwego/thriftgo/version" -) - -func init() { - _ = g.RegisterBackend(new(golang.GoBackend)) -} - -var ( - g generator.Generator ) -func RunThriftgoAsSDK(wd string, plugins []plugin.SDKPlugin, args ...string) error { +func RunThriftgoAsSDK(wd string, SDKPlugins []plugin.SDKPlugin, args ...string) error { // this should execute at the first line! dir_utils.SetGlobalwd(wd) + // for sdk mode, every time when function is invoked, config file should be reload err := config.LoadConfig() if err != nil { return err } - var a targs.Arguments - - err = a.Parse(append([]string{"thriftgo"}, args...)) - if err != nil { - if err.Error() == "flag: help requested" { - return nil - } - return err - } - - if a.AskVersion { - println("thriftgo", version.ThriftgoVersion) - return nil - } - - ast, err := parser.ParseFile(a.IDL, a.Includes, true) - if err != nil { - return err - } - - if path := parser.CircleDetect(ast); len(path) > 0 { - return fmt.Errorf("found include circle:\n\t%s", path) - } - - checker := semantic.NewChecker(semantic.Options{FixWarnings: true}) - _, err = checker.CheckAll(ast) - if err != nil { - return err - } - - err = semantic.ResolveSymbols(ast) - if err != nil { - return err - } - - req := &plugin.Request{ - Version: version.ThriftgoVersion, - OutputPath: a.OutputPath, - Recursive: a.Recursive, - AST: ast, - } - - langs, err := a.Targets() - if err != nil { - return err - } - - if len(langs) == 0 { - return fmt.Errorf("No output language(s) specified") - } - - log := backend.DummyLogFunc() - for _, out := range langs { - out.SDKPlugins = plugins - req.Language = out.Language - req.OutputPath = a.Output(out.Language) - - arg := &generator.Arguments{Out: out, Req: req, Log: log} - res := g.Generate(arg) - - err = g.Persist(res) - if err != nil { - return err - } - } - return nil + return InvokeThriftgo(SDKPlugins, append([]string{"thriftgo"}, args...)...) }