From 29866bdf165de566f1d71287b134f33923acba66 Mon Sep 17 00:00:00 2001 From: Ohki Nozomu Date: Thu, 16 Nov 2023 20:06:54 +0900 Subject: [PATCH] Compress before splitting (#47) --- go.sum | 2 - internal/agent/agent.go | 42 ++++++++++++------- internal/hub/router.go | 3 +- internal/hub/server.go | 53 +++++++++++------------- pkg/data/serde.go | 30 ++------------ pkg/data/serde_test.go | 92 ++++++++++++----------------------------- 6 files changed, 84 insertions(+), 138 deletions(-) diff --git a/go.sum b/go.sum index e773305..1cb8bc4 100644 --- a/go.sum +++ b/go.sum @@ -298,8 +298,6 @@ github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7V github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ixi4= -github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA= github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 854c130..e87fa3b 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -6,7 +6,6 @@ import ( "encoding/json" "net/http" "os" - "strings" "time" "github.com/eclipse/paho.golang/paho" @@ -105,14 +104,14 @@ func newServer(c AgentConfig) server { } } -func sendHTTP1Request(proxyHost string, data *data.HTTPRequestData) (string, int, http.Header, error) { +func sendHTTP1Request(proxyHost string, data *data.HTTPRequestData) ([]byte, int, http.Header, error) { var responseHeader http.Header url := "http://" + proxyHost + data.Path body := bytes.NewBuffer(data.Body.Body) req, err := http.NewRequest(data.Method, url, body) if err != nil { - return "", http.StatusInternalServerError, responseHeader, err + return nil, http.StatusInternalServerError, responseHeader, err } for key, values := range data.Headers.GetHeaders() { for _, value := range values.GetValues() { @@ -122,15 +121,15 @@ func sendHTTP1Request(proxyHost string, data *data.HTTPRequestData) (string, int client := &http.Client{} resp, err := client.Do(req) if err != nil { - return "", http.StatusInternalServerError, responseHeader, err + return nil, http.StatusInternalServerError, responseHeader, err } defer resp.Body.Close() responseBody := new(bytes.Buffer) _, err = responseBody.ReadFrom(resp.Body) if err != nil { - return "", resp.StatusCode, responseHeader, err + return nil, resp.StatusCode, responseHeader, err } - return responseBody.String(), resp.StatusCode, resp.Header, nil + return responseBody.Bytes(), resp.StatusCode, resp.Header, nil } func Start(c AgentConfig) { @@ -229,7 +228,7 @@ func Start(c AgentConfig) { return } - httpRequestData, err := data.DeserializeHTTPRequestData(requestPacket.HttpRequestData, requestPacket.Compress, s.commonConfig.Networking.Format, s.decoder, s.bucket) + httpRequestData, err := data.DeserializeHTTPRequestData(requestPacket.HttpRequestData, s.commonConfig.Networking.Format, s.bucket) if err != nil { s.logger.Error("Error deserializing request data", zap.Error(err)) return @@ -277,6 +276,14 @@ func Start(c AgentConfig) { return } + if processChPayload.requestPacket.Compress == "zstd" && s.decoder != nil { + processChPayload.httpRequestData.Body.Body, err = s.decoder.DecodeAll(processChPayload.httpRequestData.Body.Body, nil) + if err != nil { + s.logger.Error("Error decoding request body", zap.Error(err)) + return + } + } + var responseData data.HTTPResponseData var objectName string httpResponse, statusCode, responseHeader, err := sendHTTP1Request(s.proxyHost, processChPayload.httpRequestData) @@ -293,15 +300,18 @@ func Start(c AgentConfig) { Headers: &protoHeaders, } } else { - httpResponseBytes := []byte(httpResponse) - if s.commonConfig.Networking.LargeDataPolicy == "split" && len(httpResponseBytes) > s.commonConfig.Split.ChunkBytes { + if s.commonConfig.Networking.Compress == "zstd" && s.encoder != nil { + httpResponse = s.encoder.EncodeAll(httpResponse, nil) + } + + if s.commonConfig.Networking.LargeDataPolicy == "split" && len(httpResponse) > s.commonConfig.Split.ChunkBytes { var chunks [][]byte - for i := 0; i < len(httpResponseBytes); i += s.commonConfig.Split.ChunkBytes { + for i := 0; i < len(httpResponse); i += s.commonConfig.Split.ChunkBytes { end := i + s.commonConfig.Split.ChunkBytes - if end > len(httpResponseBytes) { - end = len(httpResponseBytes) + if end > len(httpResponse) { + end = len(httpResponse) } - chunks = append(chunks, httpResponseBytes[i:end]) + chunks = append(chunks, httpResponse[i:end]) } for sequence, c := range chunks { @@ -328,7 +338,7 @@ func Start(c AgentConfig) { Headers: &protoHeaders, } - b, err = data.SerializeHTTPResponseData(&responseData, s.commonConfig.Networking.Format, s.encoder) + b, err = data.SerializeHTTPResponseData(&responseData, s.commonConfig.Networking.Format) if err != nil { s.logger.Error("Error serializing response data", zap.Error(err)) return @@ -360,7 +370,7 @@ func Start(c AgentConfig) { } else { if s.commonConfig.Networking.LargeDataPolicy == "storage_relay" && len(httpResponse) > s.commonConfig.StorageRelay.ThresholdBytes { objectName = s.id + "/" + processChPayload.requestPacket.RequestId + "/response" - err := s.bucket.Upload(context.Background(), objectName, strings.NewReader(httpResponse)) + err := s.bucket.Upload(context.Background(), objectName, bytes.NewReader(httpResponse)) if err != nil { s.logger.Error("Error uploading object to object storage", zap.Error(err)) return @@ -388,7 +398,7 @@ func Start(c AgentConfig) { } } } - b, err := data.SerializeHTTPResponseData(&responseData, s.commonConfig.Networking.Format, s.encoder) + b, err := data.SerializeHTTPResponseData(&responseData, s.commonConfig.Networking.Format) if err != nil { s.logger.Error("Error serializing response data", zap.Error(err)) return diff --git a/internal/hub/router.go b/internal/hub/router.go index d9ed708..80814b4 100644 --- a/internal/hub/router.go +++ b/internal/hub/router.go @@ -41,7 +41,8 @@ func (r *Router) Route(p *packets.Publish) { r.logger.Info("Error deserializing response packet: " + err.Error()) return } - httpResponseData, err := data.DeserializeHTTPResponseData(httpResponsePacket.GetHttpResponseData(), httpResponsePacket.Compress, r.commonConfig.Networking.Format, r.decoder, nil) + + httpResponseData, err := data.DeserializeHTTPResponseData(httpResponsePacket.GetHttpResponseData(), r.commonConfig.Networking.Format, nil) if err != nil { r.logger.Info("Error deserializing HTTP response data: " + err.Error()) return diff --git a/internal/hub/server.go b/internal/hub/server.go index bd54766..c333693 100644 --- a/internal/hub/server.go +++ b/internal/hub/server.go @@ -198,6 +198,10 @@ func (s *server) handleRequest(w http.ResponseWriter, r *http.Request) { bodyBytes := make([]byte, r.ContentLength) r.Body.Read(bodyBytes) + if s.encoder != nil { + bodyBytes = s.encoder.EncodeAll(bodyBytes, nil) + } + dataHeaders := data.HTTPHeaderToProtoHeaders(r.Header) var body data.HTTPBody @@ -229,6 +233,7 @@ func (s *server) handleRequest(w http.ResponseWriter, r *http.Request) { Body: b, Type: "split", } + requestData := data.HTTPRequestData{ Method: r.Method, Path: r.URL.Path, @@ -236,7 +241,7 @@ func (s *server) handleRequest(w http.ResponseWriter, r *http.Request) { Body: &body, } - b, err = data.SerializeHTTPRequestData(&requestData, s.commonConfig.Networking.Format, s.encoder) + b, err = data.SerializeHTTPRequestData(&requestData, s.commonConfig.Networking.Format) if err != nil { s.logger.Error("Error serializing request data", zap.Error(err)) return @@ -284,7 +289,7 @@ func (s *server) handleRequest(w http.ResponseWriter, r *http.Request) { Body: &body, } - b, err := data.SerializeHTTPRequestData(&requestData, s.commonConfig.Networking.Format, s.encoder) + b, err := data.SerializeHTTPRequestData(&requestData, s.commonConfig.Networking.Format) if err != nil { s.logger.Error("Error serializing request data", zap.Error(err)) return @@ -318,26 +323,7 @@ func (s *server) handleRequest(w http.ResponseWriter, r *http.Request) { select { case value := <-s.dataCh: s.logger.Debug("Writing response...") - var compress string - err := s.db.View(func(txn *badger.Txn) error { - item, err := txn.Get([]byte("compress/" + uuid)) - if err != nil { - return err - } - err = item.Value(func(val []byte) error { - compress = string(val) - return nil - }) - if err != nil { - return err - } - return nil - }) - if err != nil { - s.logger.Error("Error getting compress value from database", zap.Error(err)) - return - } - httpResponseData, err := data.DeserializeHTTPResponseData(value, compress, s.commonConfig.Networking.Format, s.decoder, s.bucket) + httpResponseData, err := data.DeserializeHTTPResponseData(value, s.commonConfig.Networking.Format, s.bucket) if err != nil { s.logger.Error("Error deserializing response data", zap.Error(err)) return @@ -433,14 +419,25 @@ func (s *server) startHTTP1(c HubConfig) { } case updateDBChPayload := <-s.updateDBCh: s.logger.Debug("Writing response to database...") - err := s.db.Update(func(txn *badger.Txn) error { - e1 := badger.NewEntry([]byte(updateDBChPayload.responsePacket.RequestId), updateDBChPayload.responsePacket.GetHttpResponseData()).WithTTL(time.Minute * 5) - err := txn.SetEntry(e1) + + if updateDBChPayload.responsePacket.Compress == "zstd" && s.decoder != nil { + var err error + updateDBChPayload.httpResponseData.Body.Body, err = s.decoder.DecodeAll(updateDBChPayload.httpResponseData.Body.Body, nil) if err != nil { - return err + s.logger.Info("Error decompressing message: " + err.Error()) + return } - e2 := badger.NewEntry([]byte("compress/"+updateDBChPayload.responsePacket.RequestId), []byte(updateDBChPayload.responsePacket.Compress)).WithTTL(time.Minute * 5) - err = txn.SetEntry(e2) + } + + b, err := data.SerializeHTTPResponseData(updateDBChPayload.httpResponseData, s.commonConfig.Networking.Format) + if err != nil { + s.logger.Info("Error serializing HTTP response data: " + err.Error()) + return + } + + err = s.db.Update(func(txn *badger.Txn) error { + e := badger.NewEntry([]byte(updateDBChPayload.responsePacket.RequestId), b).WithTTL(time.Minute * 5) + err := txn.SetEntry(e) if err != nil { return err } diff --git a/pkg/data/serde.go b/pkg/data/serde.go index 668ada4..62f9f7a 100644 --- a/pkg/data/serde.go +++ b/pkg/data/serde.go @@ -7,7 +7,6 @@ import ( "io" "net/http" - "github.com/klauspost/compress/zstd" "github.com/thanos-io/objstore" "google.golang.org/protobuf/proto" ) @@ -86,7 +85,7 @@ func DeserializeResponsePacket(payload []byte, format string) (*HTTPResponsePack return &responsePacket, err } -func SerializeHTTPRequestData(httpRequestData *HTTPRequestData, format string, encoder *zstd.Encoder) ([]byte, error) { +func SerializeHTTPRequestData(httpRequestData *HTTPRequestData, format string) ([]byte, error) { var b []byte var err error switch format { @@ -103,21 +102,10 @@ func SerializeHTTPRequestData(httpRequestData *HTTPRequestData, format string, e default: return nil, fmt.Errorf("unknown format: %s", format) } - if encoder != nil { - b = encoder.EncodeAll(b, nil) - } return b, nil } -func DeserializeHTTPRequestData(b []byte, compress string, format string, decoder *zstd.Decoder, bucket objstore.Bucket) (*HTTPRequestData, error) { - var err error - if compress == "zstd" && decoder != nil { - b, err = decoder.DecodeAll(b, nil) - if err != nil { - return nil, err - } - } - +func DeserializeHTTPRequestData(b []byte, format string, bucket objstore.Bucket) (*HTTPRequestData, error) { var httpRequestData HTTPRequestData switch format { case "json": @@ -153,7 +141,7 @@ func DeserializeHTTPRequestData(b []byte, compress string, format string, decode return &httpRequestData, nil } -func SerializeHTTPResponseData(httpResponseData *HTTPResponseData, format string, encoder *zstd.Encoder) ([]byte, error) { +func SerializeHTTPResponseData(httpResponseData *HTTPResponseData, format string) ([]byte, error) { var b []byte var err error switch format { @@ -170,20 +158,10 @@ func SerializeHTTPResponseData(httpResponseData *HTTPResponseData, format string default: return nil, fmt.Errorf("unknown format: %s", format) } - if encoder != nil { - b = encoder.EncodeAll(b, nil) - } return b, nil } -func DeserializeHTTPResponseData(b []byte, compress string, format string, decoder *zstd.Decoder, bucket objstore.Bucket) (*HTTPResponseData, error) { - var err error - if compress == "zstd" && decoder != nil { - b, err = decoder.DecodeAll(b, nil) - if err != nil { - return nil, err - } - } +func DeserializeHTTPResponseData(b []byte, format string, bucket objstore.Bucket) (*HTTPResponseData, error) { var httpResponseData HTTPResponseData switch format { case "json": diff --git a/pkg/data/serde_test.go b/pkg/data/serde_test.go index 5966ca2..fdeba3e 100644 --- a/pkg/data/serde_test.go +++ b/pkg/data/serde_test.go @@ -17,44 +17,27 @@ func TestHTTPHeaderToProtoHeaders(t *testing.T) { } func TestSerializedRequestPacket(t *testing.T) { - encoder, err := zstd.NewWriter(nil) - if err != nil { - t.Fatal(err) - } - decoder, err := zstd.NewReader(nil) - if err != nil { - t.Fatal(err) - } - testCases := []struct { format string compress string encoder *zstd.Encoder }{ { - format: "json", - compress: "none", - encoder: nil, - }, - { - format: "protobuf", - compress: "none", - encoder: nil, - }, - { - format: "json", - compress: "zstd", - encoder: encoder, + format: "json", }, { - format: "protobuf", - compress: "zstd", - encoder: encoder, + format: "protobuf", }, } for _, testCase := range testCases { t.Run(testCase.format, func(t *testing.T) { + body := []byte("test") + + if testCase.compress == "zstd" && testCase.encoder != nil { + body = testCase.encoder.EncodeAll(body, nil) + } + httpRequestData := HTTPRequestData{ Method: "GET", Path: "/", @@ -66,11 +49,11 @@ func TestSerializedRequestPacket(t *testing.T) { }, }, Body: &HTTPBody{ - Body: []byte("test"), + Body: body, Type: "data", }, } - b, err := SerializeHTTPRequestData(&httpRequestData, testCase.format, testCase.encoder) + b, err := SerializeHTTPRequestData(&httpRequestData, testCase.format) if err != nil { t.Fatal(err) } @@ -88,7 +71,7 @@ func TestSerializedRequestPacket(t *testing.T) { if err != nil { t.Fatal(err) } - deserializedHTTPRequestData, err := DeserializeHTTPRequestData(deserializedRequestPacket.GetHttpRequestData(), deserializedRequestPacket.Compress, testCase.format, decoder, nil) + deserializedHTTPRequestData, err := DeserializeHTTPRequestData(deserializedRequestPacket.GetHttpRequestData(), testCase.format, nil) if err != nil { t.Fatal(err) } @@ -102,44 +85,15 @@ func TestSerializedRequestPacket(t *testing.T) { } func TestSerializedResponsePacket(t *testing.T) { - encoder, err := zstd.NewWriter(nil) - if err != nil { - t.Fatal(err) - } - decoder, err := zstd.NewReader(nil) - if err != nil { - t.Fatal(err) - } - testCases := []struct { compress string format string - encoder *zstd.Encoder - decoder *zstd.Decoder }{ { - compress: "none", - format: "json", - encoder: nil, - decoder: nil, + format: "json", }, { - compress: "none", - format: "protobuf", - encoder: nil, - decoder: nil, - }, - { - compress: "zstd", - format: "json", - encoder: encoder, - decoder: decoder, - }, - { - compress: "zstd", - format: "protobuf", - encoder: encoder, - decoder: decoder, + format: "protobuf", }, } @@ -151,15 +105,17 @@ func TestSerializedResponsePacket(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.format, func(t *testing.T) { + body := []byte("test") + httpReponseData := HTTPResponseData{ StatusCode: 200, Headers: &headers, Body: &HTTPBody{ - Body: []byte("test"), + Body: body, Type: "data", }, } - b, err := SerializeHTTPResponseData(&httpReponseData, testCase.format, testCase.encoder) + b, err := SerializeHTTPResponseData(&httpReponseData, testCase.format) if err != nil { t.Fatal(err) } @@ -177,7 +133,7 @@ func TestSerializedResponsePacket(t *testing.T) { if err != nil { t.Fatal(err) } - deserializedHTTPResponseData, err := DeserializeHTTPResponseData(deserializedResponsePacket.GetHttpResponseData(), testCase.compress, testCase.format, testCase.decoder, nil) + deserializedHTTPResponseData, err := DeserializeHTTPResponseData(deserializedResponsePacket.GetHttpResponseData(), testCase.format, nil) if err != nil { t.Fatal(err) } @@ -238,18 +194,24 @@ func TestHTTPResponseDataSerialize(t *testing.T) { "Content-Length": {"31"}, }) + body := []byte("test") + + if testCase.compress == "zstd" && testCase.encoder != nil { + body = testCase.encoder.EncodeAll(body, nil) + } + httpResponseData := HTTPResponseData{ StatusCode: 200, Headers: &headers, - Body: &HTTPBody{Body: []byte("test"), Type: "data"}, + Body: &HTTPBody{Body: body, Type: "data"}, } - serializedResponseData, err := SerializeHTTPResponseData(&httpResponseData, testCase.format, testCase.encoder) + serializedResponseData, err := SerializeHTTPResponseData(&httpResponseData, testCase.format) if err != nil { t.Fatal(err) } - deserializedResponseData, err := DeserializeHTTPResponseData(serializedResponseData, testCase.compress, testCase.format, testCase.decoder, nil) + deserializedResponseData, err := DeserializeHTTPResponseData(serializedResponseData, testCase.format, nil) if err != nil { t.Fatal(err) }