Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: protoc-gen-rpc-swagger #6

Merged
merged 20 commits into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions common/consts/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,14 @@ const (

const (
CodeGenerationCommentPbHttp = "// Code generated by protoc-gen-http-swagger."
CodeGenerationCommentPbRpc = "// Code generated by protoc-gen-rpc-swagger."
CodeGenerationCommentThriftHttp = "// Code generated by thrift-gen-http-swagger."
CodeGenerationCommentThriftRpc = "// Code generated by thrift-gen-rpc-swagger."
)

const (
PluginNameProtocHttpSwagger = "protoc-gen-http-swagger"
PluginNameProtocRpcSwagger = "protoc-gen-rpc-swagger"
PluginNameThriftHttpSwagger = "thrift-gen-http-swagger"
PluginNameThriftRpcSwagger = "thrift-gen-rpc-swagger"
)
Expand All @@ -82,6 +84,7 @@ const (
SchemaObjectType = "object"
ComponentSchemaPrefix = "#/components/schemas/"
ComponentSchemaSuffixBody = "Body"
ComponentSchemaSuffixForm = "Form"
ComponentSchemaSuffixRawBody = "RawBody"

ContentTypeJSON = "application/json"
Expand All @@ -106,4 +109,7 @@ const (

CommentPatternRegexp = `//\s*(.*)|/\*([\s\S]*?)\*/`
LinterRulePatternRegexp = `\(-- .* --\)`

ProtobufValueName = "GoogleProtobufValue"
ProtobufAnyName = "GoogleProtobufAny"
)
293 changes: 289 additions & 4 deletions common/tpl/tpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,23 @@ func findThriftFile(fileName string) (string, error) {
}

foundPath := ""
relativePath := fileName

err = filepath.Walk(workingDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() && info.Name() == fileName {
foundPath = path
return filepath.SkipDir

if !info.IsDir() {
relative, err := filepath.Rel(workingDir, path)
if err != nil {
return err
}

if relative == relativePath {
foundPath = path
return filepath.SkipDir
}
}
return nil
})
Expand Down Expand Up @@ -210,7 +220,282 @@ func initializeGenericClient() genericclient.Client {

g, err := generic.JSONThriftGeneric(p)
if err != nil {
hlog.Fatal("Failed to create HTTPThriftGeneric:", err)
hlog.Fatal("Failed to create JsonThriftGeneric:", err)
}
var opts []client.Option
opts = append(opts, client.WithTransportProtocol(transport.TTHeader))
opts = append(opts, client.WithMetaHandler(transmeta.ClientTTHeaderHandler))
opts = append(opts, client.WithHostPorts(kitexAddr))
cli, err := genericclient.NewClient("swagger", g, opts...)
if err != nil {
hlog.Fatal("Failed to create generic client:", err)
}

return cli
}

func setupSwaggerRoutes(h *server.Hertz) {
h.GET("swagger/*any", swagger.WrapHandler(swaggerFiles.Handler, swagger.URL("/openapi.yaml")))

h.GET("/openapi.yaml", func(c context.Context, ctx *app.RequestContext) {
ctx.Header("Content-Type", "application/x-yaml")
ctx.Write(openapiYAML)
})
}

func setupProxyRoutes(h *server.Hertz, cli genericclient.Client) {
h.Any("/*ServiceMethod", func(c context.Context, ctx *app.RequestContext) {
serviceMethod := ctx.Param("ServiceMethod")
if serviceMethod == "" {
handleError(ctx, "ServiceMethod not provided", http.StatusBadRequest)
return
}

bodyBytes := ctx.Request.Body()

queryMap := formatQueryParams(ctx)

for k, v := range queryMap {
if strings.HasPrefix(k, "p_") {
c = metainfo.WithPersistentValue(c, k, v)
} else {
c = metainfo.WithValue(c, k, v)
}
}

c = metainfo.WithBackwardValues(c)

jReq := string(bodyBytes)

jRsp, err := cli.GenericCall(c, serviceMethod, jReq)
if err != nil {
hlog.Errorf("GenericCall error: %v", err)
ctx.JSON(500, map[string]interface{}{
"error": err.Error(),
})
return
}

result := make(map[string]interface{})
if err := json.Unmarshal([]byte(jRsp.(string)), &result); err != nil {
hlog.Errorf("Failed to unmarshal response body: %v", err)
ctx.JSON(500, map[string]interface{}{
"error": "Failed to unmarshal response body",
})
return
}

m := metainfo.RecvAllBackwardValues(c)

for key, value := range m {
result[key] = value
}

respBody, err := json.Marshal(result)
if err != nil {
hlog.Errorf("Failed to marshal response body: %v", err)
ctx.JSON(500, map[string]interface{}{
"error": "Failed to marshal response body",
})
return
}

ctx.Data(http.StatusOK, "application/json", respBody)

})
}

func formatQueryParams(ctx *app.RequestContext) map[string]string {
var QueryParams = make(map[string]string)
ctx.Request.URI().QueryArgs().VisitAll(func(key, value []byte) {
QueryParams[string(key)] = string(value)
})
return QueryParams
}

func handleError(ctx *app.RequestContext, errMsg string, statusCode int) {
hlog.Errorf("Error: %s", errMsg)
ctx.JSON(statusCode, map[string]interface{}{
"error": errMsg,
})
}
`

const ServerTemplateRpcPb = `package swagger

import (
"context"
_ "embed"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"

"github.com/bytedance/gopkg/cloud/metainfo"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/cloudwego/hertz/pkg/common/hlog"
"github.com/cloudwego/hertz/pkg/network"
"github.com/cloudwego/hertz/pkg/route"
"github.com/cloudwego/kitex/client"
"github.com/cloudwego/kitex/client/genericclient"
"github.com/cloudwego/kitex/pkg/endpoint"
"github.com/cloudwego/kitex/pkg/generic"
"github.com/cloudwego/kitex/pkg/klog"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/trans/detection"
"github.com/cloudwego/kitex/pkg/remote/trans/netpoll"
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2"
"github.com/cloudwego/kitex/pkg/transmeta"
"github.com/cloudwego/kitex/transport"
"github.com/hertz-contrib/cors"
"github.com/hertz-contrib/swagger"
swaggerFiles "github.com/swaggo/files"
)

var (
//go:embed openapi.yaml
openapiYAML []byte
hertzEngine *route.Engine
httpReg = regexp.MustCompile("^(?:GET |POST|PUT|DELE|HEAD|OPTI|CONN|TRAC|PATC)$")
)

const (
kitexAddr = "{{.KitexAddr}}"
idlFile = "{{.IdlPath}}"
)

type MixTransHandlerFactory struct {
OriginFactory remote.ServerTransHandlerFactory
}

type transHandler struct {
remote.ServerTransHandler
}

func (t *transHandler) SetInvokeHandleFunc(inkHdlFunc endpoint.Endpoint) {
t.ServerTransHandler.(remote.InvokeHandleFuncSetter).SetInvokeHandleFunc(inkHdlFunc)
}

func (m MixTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remote.ServerTransHandler, error) {

if hertzEngine == nil {
StartServer()
}

var kitexOrigin remote.ServerTransHandler
var err error

if m.OriginFactory != nil {
kitexOrigin, err = m.OriginFactory.NewTransHandler(opt)
} else {
kitexOrigin, err = detection.NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()).NewTransHandler(opt)
}
if err != nil {
return nil, err
}
return &transHandler{ServerTransHandler: kitexOrigin}, nil
}

func (t *transHandler) OnRead(ctx context.Context, conn net.Conn) error {
c, ok := conn.(network.Conn)
if ok {
pre, _ := c.Peek(4)
if httpReg.Match(pre) {
klog.Info("using Hertz to process request")
err := hertzEngine.Serve(ctx, c)
if err != nil {
err = errors.New(fmt.Sprintf("HERTZ: %s", err.Error()))
}
return err
}
}

return t.ServerTransHandler.OnRead(ctx, conn)
}

func StartServer() {
h := server.Default()
h.Use(cors.Default())

cli := initializeGenericClient()
setupSwaggerRoutes(h)
setupProxyRoutes(h, cli)

hlog.Info("Swagger UI is available at: http://" + kitexAddr + "/swagger/index.html")
err := h.Engine.Init()
if err != nil {
panic(err)
}

hertzEngine = h.Engine
}

func findPbFile(fileName string) (string, error) {
workingDir, err := os.Getwd()
if err != nil {
return "", err
}

foundPath := ""
relativePath := fileName

err = filepath.Walk(workingDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}

if !info.IsDir() {
relative, err := filepath.Rel(workingDir, path)
if err != nil {
return err
}

if relative == relativePath {
foundPath = path
return filepath.SkipDir
}
}
return nil
})

if err == nil && foundPath != "" {
return foundPath, nil
}

parentDir := filepath.Dir(workingDir)
for parentDir != "/" && parentDir != "." && parentDir != workingDir {
filePath := filepath.Join(parentDir, fileName)
if _, err := os.Stat(filePath); err == nil {
return filePath, nil
}
workingDir = parentDir
parentDir = filepath.Dir(parentDir)
}

return "", errors.New("proto file not found: " + fileName)
}

func initializeGenericClient() genericclient.Client {
pbFile, err := findPbFile(idlFile)
if err != nil {
hlog.Fatal("Failed to locate Proto file:", err)
}

dOpts := proto.Options{}
p, err := generic.NewPbFileProviderWithDynamicGo(pbFile, context.Background(), dOpts)
if err != nil {
hlog.Fatal("Failed to create PbFileProvider:", err)
}

g, err := generic.JSONPbGeneric(p)
if err != nil {
hlog.Fatal("Failed to create JsonPbGeneric:", err)
}
var opts []client.Option
opts = append(opts, client.WithTransportProtocol(transport.TTHeader))
Expand Down
Loading
Loading