diff --git a/go.sum b/go.sum index e773305..1cb8bc4 100644 --- a/go.sum +++ b/go.sum @@ -298,8 +298,6 @@ github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7V github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ixi4= -github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA= github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 854c130..c460353 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -6,7 +6,6 @@ import ( "encoding/json" "net/http" "os" - "strings" "time" "github.com/eclipse/paho.golang/paho" @@ -105,14 +104,14 @@ func newServer(c AgentConfig) server { } } -func sendHTTP1Request(proxyHost string, data *data.HTTPRequestData) (string, int, http.Header, error) { +func sendHTTP1Request(proxyHost string, data *data.HTTPRequestData) ([]byte, int, http.Header, error) { var responseHeader http.Header url := "http://" + proxyHost + data.Path body := bytes.NewBuffer(data.Body.Body) req, err := http.NewRequest(data.Method, url, body) if err != nil { - return "", http.StatusInternalServerError, responseHeader, err + return nil, http.StatusInternalServerError, responseHeader, err } for key, values := range data.Headers.GetHeaders() { for _, value := range values.GetValues() { @@ -122,15 +121,15 @@ func sendHTTP1Request(proxyHost string, data *data.HTTPRequestData) (string, int client := &http.Client{} resp, err := client.Do(req) if err != nil { - return "", http.StatusInternalServerError, responseHeader, err + return nil, http.StatusInternalServerError, responseHeader, err } defer resp.Body.Close() responseBody := new(bytes.Buffer) _, err = responseBody.ReadFrom(resp.Body) if err != nil { - return "", resp.StatusCode, responseHeader, err + return nil, resp.StatusCode, responseHeader, err } - return responseBody.String(), resp.StatusCode, resp.Header, nil + return responseBody.Bytes(), resp.StatusCode, resp.Header, nil } func Start(c AgentConfig) { @@ -293,15 +292,18 @@ func Start(c AgentConfig) { Headers: &protoHeaders, } } else { - httpResponseBytes := []byte(httpResponse) - if s.commonConfig.Networking.LargeDataPolicy == "split" && len(httpResponseBytes) > s.commonConfig.Split.ChunkBytes { + if s.commonConfig.Networking.Compress == "zstd" && s.encoder != nil { + httpResponse = s.encoder.EncodeAll(httpResponse, nil) + } + + if s.commonConfig.Networking.LargeDataPolicy == "split" && len(httpResponse) > s.commonConfig.Split.ChunkBytes { var chunks [][]byte - for i := 0; i < len(httpResponseBytes); i += s.commonConfig.Split.ChunkBytes { + for i := 0; i < len(httpResponse); i += s.commonConfig.Split.ChunkBytes { end := i + s.commonConfig.Split.ChunkBytes - if end > len(httpResponseBytes) { - end = len(httpResponseBytes) + if end > len(httpResponse) { + end = len(httpResponse) } - chunks = append(chunks, httpResponseBytes[i:end]) + chunks = append(chunks, httpResponse[i:end]) } for sequence, c := range chunks { @@ -360,7 +362,7 @@ func Start(c AgentConfig) { } else { if s.commonConfig.Networking.LargeDataPolicy == "storage_relay" && len(httpResponse) > s.commonConfig.StorageRelay.ThresholdBytes { objectName = s.id + "/" + processChPayload.requestPacket.RequestId + "/response" - err := s.bucket.Upload(context.Background(), objectName, strings.NewReader(httpResponse)) + err := s.bucket.Upload(context.Background(), objectName, bytes.NewReader(httpResponse)) if err != nil { s.logger.Error("Error uploading object to object storage", zap.Error(err)) return diff --git a/internal/hub/server.go b/internal/hub/server.go index bd54766..482588c 100644 --- a/internal/hub/server.go +++ b/internal/hub/server.go @@ -198,6 +198,10 @@ func (s *server) handleRequest(w http.ResponseWriter, r *http.Request) { bodyBytes := make([]byte, r.ContentLength) r.Body.Read(bodyBytes) + if s.encoder != nil { + bodyBytes = s.encoder.EncodeAll(bodyBytes, nil) + } + dataHeaders := data.HTTPHeaderToProtoHeaders(r.Header) var body data.HTTPBody @@ -229,6 +233,7 @@ func (s *server) handleRequest(w http.ResponseWriter, r *http.Request) { Body: b, Type: "split", } + requestData := data.HTTPRequestData{ Method: r.Method, Path: r.URL.Path, @@ -236,7 +241,7 @@ func (s *server) handleRequest(w http.ResponseWriter, r *http.Request) { Body: &body, } - b, err = data.SerializeHTTPRequestData(&requestData, s.commonConfig.Networking.Format, s.encoder) + b, err = data.SerializeHTTPRequestData(&requestData, s.commonConfig.Networking.Format) if err != nil { s.logger.Error("Error serializing request data", zap.Error(err)) return @@ -284,7 +289,7 @@ func (s *server) handleRequest(w http.ResponseWriter, r *http.Request) { Body: &body, } - b, err := data.SerializeHTTPRequestData(&requestData, s.commonConfig.Networking.Format, s.encoder) + b, err := data.SerializeHTTPRequestData(&requestData, s.commonConfig.Networking.Format) if err != nil { s.logger.Error("Error serializing request data", zap.Error(err)) return diff --git a/pkg/data/serde.go b/pkg/data/serde.go index 668ada4..d8b62d0 100644 --- a/pkg/data/serde.go +++ b/pkg/data/serde.go @@ -86,7 +86,7 @@ func DeserializeResponsePacket(payload []byte, format string) (*HTTPResponsePack return &responsePacket, err } -func SerializeHTTPRequestData(httpRequestData *HTTPRequestData, format string, encoder *zstd.Encoder) ([]byte, error) { +func SerializeHTTPRequestData(httpRequestData *HTTPRequestData, format string) ([]byte, error) { var b []byte var err error switch format { @@ -103,9 +103,6 @@ func SerializeHTTPRequestData(httpRequestData *HTTPRequestData, format string, e default: return nil, fmt.Errorf("unknown format: %s", format) } - if encoder != nil { - b = encoder.EncodeAll(b, nil) - } return b, nil } @@ -170,9 +167,6 @@ func SerializeHTTPResponseData(httpResponseData *HTTPResponseData, format string default: return nil, fmt.Errorf("unknown format: %s", format) } - if encoder != nil { - b = encoder.EncodeAll(b, nil) - } return b, nil } diff --git a/pkg/data/serde_test.go b/pkg/data/serde_test.go index 5966ca2..4105d94 100644 --- a/pkg/data/serde_test.go +++ b/pkg/data/serde_test.go @@ -55,6 +55,12 @@ func TestSerializedRequestPacket(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.format, func(t *testing.T) { + body := []byte("test") + + if testCase.compress == "zstd" && testCase.encoder != nil { + body = testCase.encoder.EncodeAll(body, nil) + } + httpRequestData := HTTPRequestData{ Method: "GET", Path: "/", @@ -66,11 +72,11 @@ func TestSerializedRequestPacket(t *testing.T) { }, }, Body: &HTTPBody{ - Body: []byte("test"), + Body: body, Type: "data", }, } - b, err := SerializeHTTPRequestData(&httpRequestData, testCase.format, testCase.encoder) + b, err := SerializeHTTPRequestData(&httpRequestData, testCase.format) if err != nil { t.Fatal(err) } @@ -151,11 +157,17 @@ func TestSerializedResponsePacket(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.format, func(t *testing.T) { + body := []byte("test") + + if testCase.compress == "zstd" && testCase.encoder != nil { + body = testCase.encoder.EncodeAll(body, nil) + } + httpReponseData := HTTPResponseData{ StatusCode: 200, Headers: &headers, Body: &HTTPBody{ - Body: []byte("test"), + Body: body, Type: "data", }, } @@ -238,10 +250,16 @@ func TestHTTPResponseDataSerialize(t *testing.T) { "Content-Length": {"31"}, }) + body := []byte("test") + + if testCase.compress == "zstd" && testCase.encoder != nil { + body = testCase.encoder.EncodeAll(body, nil) + } + httpResponseData := HTTPResponseData{ StatusCode: 200, Headers: &headers, - Body: &HTTPBody{Body: []byte("test"), Type: "data"}, + Body: &HTTPBody{Body: body, Type: "data"}, } serializedResponseData, err := SerializeHTTPResponseData(&httpResponseData, testCase.format, testCase.encoder)