Skip to content

Commit

Permalink
Code refactor, change description for new config options
Browse files Browse the repository at this point in the history
  • Loading branch information
luborco committed Aug 13, 2024
2 parents 15f96b0 + a42a3e9 commit 5d44025
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 30 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ Usage of ./grpc-rest-proxy:
--transport.http.server.gracefulTimeout duration graceful timeout (default 5s)
--transport.http.server.readHeaderTimeout duration read header timeout (default 5s)
--transport.http.server.readTimeout duration read timeout (default 10s)
--service.jsonencoder.emitUnpopulated emit unpopulated fields
--service.jsonencoder.emitDefaultValues emit default values
--service.jsonencoder.emitUnpopulated emit unpopulated fields in JSON response for empty gRPC values
--service.jsonencoder.emitDefaultValues include default values in JSON response for empty gRPC values
-v, --version print version
```

Expand Down
6 changes: 3 additions & 3 deletions cmd/service/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ func createGateways(conf *Config) (*gateways, error) {

func (app *App) createHTTPServer() {
routerContext := &transport.Context{
Router: app.router,
GrcpClient: app.gateways.grpcClient,
Encoder: jsonencoder.NewOptions(app.conf.Service.JSONEncoder),
Router: app.router,
GrcpClient: app.gateways.grpcClient,
JSONEncoder: jsonencoder.New(app.conf.Service.JSONEncoder),
}
handler := transport.NewHandler(routerContext, logging.Default())
app.serverHTTP = http.NewServer(app.conf.Transport.HTTP.Server, handler)
Expand Down
4 changes: 2 additions & 2 deletions cmd/service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ func main() {
pflag.Bool("gateways.grpc.client.tls", tls, "use TLS for gRPC connection")
pflag.Bool("gateways.grpc.client.tlsSkipverify", tlsSkipverify, "skip TLS verification")

pflag.Bool("service.jsonencoder.emitUnpopulated", defaultEmitUnpopulated, "emit unpopulated fields")
pflag.Bool("service.jsonencoder.emitDefaultValues", defaultEmitDefaultValues, "emit default values")
pflag.Bool("service.jsonencoder.emitUnpopulated", defaultEmitUnpopulated, "emit unpopulated fields in JSON response for empty gRPC values")
pflag.Bool("service.jsonencoder.emitDefaultValues", defaultEmitDefaultValues, "include default values in JSON response for empty gRPC values") //nolint:lll

pflag.BoolP("version", "v", false, "print version")
configFile := pflag.StringP("config", "c", "", "path to config file")
Expand Down
8 changes: 4 additions & 4 deletions pkg/service/jsonencoder/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ type Encoder struct {
opts protojson.MarshalOptions
}

func NewOptions(cfg *Config) Encoder {
func New(cfg *Config) Encoder {
return Encoder{
opts: protojson.MarshalOptions{EmitUnpopulated: cfg.EmitUnpopulated, EmitDefaultValues: cfg.EmitDefaultValues},
}
}

func (e Encoder) Format(m proto.Message) (string, error) {
func (e Encoder) Encode(m proto.Message) ([]byte, error) {
response, err := e.opts.Marshal(m)
if err != nil {
return "", jErrors.Trace(err)
return nil, jErrors.Trace(err)
}

return string(response), nil
return response, nil
}
51 changes: 40 additions & 11 deletions pkg/service/transformer/transformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,42 +17,61 @@ import (

const (
InvalidGrpcMethodName = jErrors.ConstError("gRPC method name is invalid")

headerAccept = "accept"
headerContentType = "content-type"
headerContentLength = "content-length"

// connection specific headers http1.0/1.1
headerConnection = "connection"
headerProxyConnection = "proxy-connection"
headerKeepAlive = "keep-alive"
headerTransferEncoding = "transfer-encoding"
headerUpgrade = "upgrade"
)

func isConnectionSpecificHeader(name string) bool {
switch name {
case headerConnection, headerProxyConnection, headerKeepAlive, headerTransferEncoding, headerUpgrade:
return true
}
return false
}

func GetRPCRequestContext(request *http.Request) context.Context {
grpcMetadata := metadata.Pairs()

for name, values := range request.Header {
name = strings.ToLower(name)

// in case the client sends a content-length header it will be removed before proceeding
if name == "content-length" {
if name == headerContentLength {
continue
}
// RFC 9113 8.2.2.: endpoint MUST NOT generate an HTTP/2 message containing connection-specific header fields
if request.ProtoMajor > 1 && isConnectionSpecificHeader(name) {
continue
}
grpcMetadata.Append(name, values...)
}

grpcMetadata.Set("accept", "application/protobuf")
grpcMetadata.Set("content-type", "application/protobuf")
grpcMetadata.Set(headerAccept, "application/protobuf")
grpcMetadata.Set(headerContentType, "application/protobuf")

return metadata.NewOutgoingContext(request.Context(), grpcMetadata)
}

func SetRESTHeaders(headers http.Header, gRPCheader metadata.MD, gRPCTrailer metadata.MD) {
func SetRESTHeaders(protoMajor int, headers http.Header, gRPCheader metadata.MD, gRPCTrailer metadata.MD) {
// set headers
for name, values := range gRPCheader {
for _, value := range values {
headers.Add(name, value)
}
setHeader(headers, protoMajor, name, values)
}
// append trailers as headers
for name, values := range gRPCTrailer {
for _, value := range values {
headers.Add(name, value)
}
setHeader(headers, protoMajor, name, values)
}

headers.Set("content-type", "application/json")
headers.Set(headerContentType, "application/json")
}

func GetRPCResponse(responseDesc protoreflect.MessageDescriptor) *dynamicpb.Message {
Expand Down Expand Up @@ -102,3 +121,13 @@ func GetHTTPStatusCode(code codes.Code) int {
return http.StatusInternalServerError
}
}

func setHeader(headers http.Header, protoMajor int, name string, values []string) {
// RFC 9113 8.2.2.: endpoint MUST NOT generate an HTTP/2 message containing connection-specific header fields
if protoMajor > 1 && isConnectionSpecificHeader(name) {
return
}
for _, value := range values {
headers.Add(name, value)
}
}
22 changes: 14 additions & 8 deletions pkg/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ import (
)

type Context struct {
Router *routerPkg.ReloadableRouter
GrcpClient grpcClient.ClientInterface
Encoder jsonencoder.Encoder
Router *routerPkg.ReloadableRouter
GrcpClient grpcClient.ClientInterface
JSONEncoder jsonencoder.Encoder
}

type Logger interface {
Expand Down Expand Up @@ -109,22 +109,28 @@ func createRoutingEndpoint(rc *Context, logger Logger) func(w http.ResponseWrite
)
if err != nil {
if e, ok := status.FromError(err); ok {
transformer.SetRESTHeaders(w.Header(), header, trailer)
transformer.SetRESTHeaders(r.ProtoMajor, w.Header(), header, trailer)
w.WriteHeader(transformer.GetHTTPStatusCode(e.Code()))
}
logger.ErrorContext(r.Context(), jErrors.Details(jErrors.Trace(err)))
return
}

transformer.SetRESTHeaders(w.Header(), header, trailer)
response, err := rc.Encoder.Format(rpcResponse)
transformer.SetRESTHeaders(r.ProtoMajor, w.Header(), header, trailer)

response, err := rc.JSONEncoder.Encode(rpcResponse)
if err != nil {
logger.ErrorContext(r.Context(), jErrors.Details(jErrors.Trace(err)))
w.WriteHeader(http.StatusInternalServerError)
return
}

fmt.Fprint(w, response)
w.WriteHeader(http.StatusOK)
_, err = w.Write(response)
if err != nil {
logger.ErrorContext(r.Context(), jErrors.Details(jErrors.Trace(err)))
w.WriteHeader(http.StatusInternalServerError)
return
}
}
}

Expand Down

0 comments on commit 5d44025

Please sign in to comment.