From a42a3e98a18217204fd07ef4e4dad59cbcd35796 Mon Sep 17 00:00:00 2001 From: Michal Liziciar <51174405+luborco@users.noreply.github.com> Date: Tue, 13 Aug 2024 13:51:58 +0200 Subject: [PATCH] Remove connection-specific headers from HTTP/2 message. (#7) * Remove connection-specific headers from HTTP/2 message. * Remove useless parameter, use write for sending response --- pkg/service/transformer/transformer.go | 51 ++++++++++++++++++++------ pkg/transport/transport.go | 20 ++++++++-- 2 files changed, 56 insertions(+), 15 deletions(-) diff --git a/pkg/service/transformer/transformer.go b/pkg/service/transformer/transformer.go index 323c003..4d899f2 100644 --- a/pkg/service/transformer/transformer.go +++ b/pkg/service/transformer/transformer.go @@ -17,8 +17,27 @@ import ( const ( InvalidGrpcMethodName = jErrors.ConstError("gRPC method name is invalid") + + headerAccept = "accept" + headerContentType = "content-type" + headerContentLength = "content-length" + + // connection specific headers http1.0/1.1 + headerConnection = "connection" + headerProxyConnection = "proxy-connection" + headerKeepAlive = "keep-alive" + headerTransferEncoding = "transfer-encoding" + headerUpgrade = "upgrade" ) +func isConnectionSpecificHeader(name string) bool { + switch name { + case headerConnection, headerProxyConnection, headerKeepAlive, headerTransferEncoding, headerUpgrade: + return true + } + return false +} + func GetRPCRequestContext(request *http.Request) context.Context { grpcMetadata := metadata.Pairs() @@ -26,33 +45,43 @@ func GetRPCRequestContext(request *http.Request) context.Context { name = strings.ToLower(name) // in case the client sends a content-length header it will be removed before proceeding - if name == "content-length" { + if name == headerContentLength { + continue + } + // RFC 9113 8.2.2.: endpoint MUST NOT generate an HTTP/2 message containing connection-specific header fields + if request.ProtoMajor > 1 && isConnectionSpecificHeader(name) { continue } grpcMetadata.Append(name, values...) } - grpcMetadata.Set("accept", "application/protobuf") - grpcMetadata.Set("content-type", "application/protobuf") + grpcMetadata.Set(headerAccept, "application/protobuf") + grpcMetadata.Set(headerContentType, "application/protobuf") return metadata.NewOutgoingContext(request.Context(), grpcMetadata) } -func SetRESTHeaders(headers http.Header, gRPCheader metadata.MD, gRPCTrailer metadata.MD) { +func setHeader(headers http.Header, protoMajor int, name string, values []string) { + // RFC 9113 8.2.2.: endpoint MUST NOT generate an HTTP/2 message containing connection-specific header fields + if protoMajor > 1 && isConnectionSpecificHeader(name) { + return + } + for _, value := range values { + headers.Add(name, value) + } +} + +func SetRESTHeaders(protoMajor int, headers http.Header, gRPCheader metadata.MD, gRPCTrailer metadata.MD) { // set headers for name, values := range gRPCheader { - for _, value := range values { - headers.Add(name, value) - } + setHeader(headers, protoMajor, name, values) } // append trailers as headers for name, values := range gRPCTrailer { - for _, value := range values { - headers.Add(name, value) - } + setHeader(headers, protoMajor, name, values) } - headers.Set("content-type", "application/json") + headers.Set(headerContentType, "application/json") } func GetRPCResponse(responseDesc protoreflect.MessageDescriptor) *dynamicpb.Message { diff --git a/pkg/transport/transport.go b/pkg/transport/transport.go index d3a5748..0f631cf 100644 --- a/pkg/transport/transport.go +++ b/pkg/transport/transport.go @@ -108,16 +108,28 @@ func createRoutingEndpoint(rc *Context, logger Logger) func(w http.ResponseWrite ) if err != nil { if e, ok := status.FromError(err); ok { - transformer.SetRESTHeaders(w.Header(), header, trailer) + transformer.SetRESTHeaders(r.ProtoMajor, w.Header(), header, trailer) w.WriteHeader(transformer.GetHTTPStatusCode(e.Code())) } logger.ErrorContext(r.Context(), jErrors.Details(jErrors.Trace(err))) return } - transformer.SetRESTHeaders(w.Header(), header, trailer) - fmt.Fprint(w, protojson.Format(rpcResponse)) - w.WriteHeader(http.StatusOK) + transformer.SetRESTHeaders(r.ProtoMajor, w.Header(), header, trailer) + + response, err := protojson.Marshal(rpcResponse) + if err != nil { + logger.ErrorContext(r.Context(), jErrors.Details(jErrors.Trace(err))) + w.WriteHeader(http.StatusInternalServerError) + return + } + + _, err = w.Write(response) + if err != nil { + logger.ErrorContext(r.Context(), jErrors.Details(jErrors.Trace(err))) + w.WriteHeader(http.StatusInternalServerError) + return + } } }