Skip to content

Commit

Permalink
Compress before splitting (#47)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ohkinozomu authored Nov 16, 2023
1 parent 78e2ca7 commit 29866bd
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 138 deletions.
2 changes: 0 additions & 2 deletions go.sum
Original file line number | Diff line number | Diff line change
Expand Up @@ -298,8 +298,6 @@ github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7V
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ixi4=
github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA=
github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
Expand Down
42 changes: 26 additions & 16 deletions internal/agent/agent.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/json"
"net/http"
"os"
"strings"
"time"

"github.com/eclipse/paho.golang/paho"
Expand Down Expand Up @@ -105,14 +104,14 @@ func newServer(c AgentConfig) server {
}
}

func sendHTTP1Request(proxyHost string, data *data.HTTPRequestData) (string, int, http.Header, error) {
func sendHTTP1Request(proxyHost string, data *data.HTTPRequestData) ([]byte, int, http.Header, error) {
var responseHeader http.Header

url := "http://" + proxyHost + data.Path
body := bytes.NewBuffer(data.Body.Body)
req, err := http.NewRequest(data.Method, url, body)
if err != nil {
return "", http.StatusInternalServerError, responseHeader, err
return nil, http.StatusInternalServerError, responseHeader, err
}
for key, values := range data.Headers.GetHeaders() {
for _, value := range values.GetValues() {
Expand All @@ -122,15 +121,15 @@ func sendHTTP1Request(proxyHost string, data *data.HTTPRequestData) (string, int
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return "", http.StatusInternalServerError, responseHeader, err
return nil, http.StatusInternalServerError, responseHeader, err
}
defer resp.Body.Close()
responseBody := new(bytes.Buffer)
_, err = responseBody.ReadFrom(resp.Body)
if err != nil {
return "", resp.StatusCode, responseHeader, err
return nil, resp.StatusCode, responseHeader, err
}
return responseBody.String(), resp.StatusCode, resp.Header, nil
return responseBody.Bytes(), resp.StatusCode, resp.Header, nil
}

func Start(c AgentConfig) {
Expand Down Expand Up @@ -229,7 +228,7 @@ func Start(c AgentConfig) {
return
}

httpRequestData, err := data.DeserializeHTTPRequestData(requestPacket.HttpRequestData, requestPacket.Compress, s.commonConfig.Networking.Format, s.decoder, s.bucket)
httpRequestData, err := data.DeserializeHTTPRequestData(requestPacket.HttpRequestData, s.commonConfig.Networking.Format, s.bucket)
if err != nil {
s.logger.Error("Error deserializing request data", zap.Error(err))
return
Expand Down Expand Up @@ -277,6 +276,14 @@ func Start(c AgentConfig) {
return
}

if processChPayload.requestPacket.Compress == "zstd" && s.decoder != nil {
processChPayload.httpRequestData.Body.Body, err = s.decoder.DecodeAll(processChPayload.httpRequestData.Body.Body, nil)
if err != nil {
s.logger.Error("Error decoding request body", zap.Error(err))
return
}
}

var responseData data.HTTPResponseData
var objectName string
httpResponse, statusCode, responseHeader, err := sendHTTP1Request(s.proxyHost, processChPayload.httpRequestData)
Expand All @@ -293,15 +300,18 @@ func Start(c AgentConfig) {
Headers: &protoHeaders,
}
} else {
httpResponseBytes := []byte(httpResponse)
if s.commonConfig.Networking.LargeDataPolicy == "split" && len(httpResponseBytes) > s.commonConfig.Split.ChunkBytes {
if s.commonConfig.Networking.Compress == "zstd" && s.encoder != nil {
httpResponse = s.encoder.EncodeAll(httpResponse, nil)
}

if s.commonConfig.Networking.LargeDataPolicy == "split" && len(httpResponse) > s.commonConfig.Split.ChunkBytes {
var chunks [][]byte
for i := 0; i < len(httpResponseBytes); i += s.commonConfig.Split.ChunkBytes {
for i := 0; i < len(httpResponse); i += s.commonConfig.Split.ChunkBytes {
end := i + s.commonConfig.Split.ChunkBytes
if end > len(httpResponseBytes) {
end = len(httpResponseBytes)
if end > len(httpResponse) {
end = len(httpResponse)
}
chunks = append(chunks, httpResponseBytes[i:end])
chunks = append(chunks, httpResponse[i:end])
}

for sequence, c := range chunks {
Expand All @@ -328,7 +338,7 @@ func Start(c AgentConfig) {
Headers: &protoHeaders,
}

b, err = data.SerializeHTTPResponseData(&responseData, s.commonConfig.Networking.Format, s.encoder)
b, err = data.SerializeHTTPResponseData(&responseData, s.commonConfig.Networking.Format)
if err != nil {
s.logger.Error("Error serializing response data", zap.Error(err))
return
Expand Down Expand Up @@ -360,7 +370,7 @@ func Start(c AgentConfig) {
} else {
if s.commonConfig.Networking.LargeDataPolicy == "storage_relay" && len(httpResponse) > s.commonConfig.StorageRelay.ThresholdBytes {
objectName = s.id + "/" + processChPayload.requestPacket.RequestId + "/response"
err := s.bucket.Upload(context.Background(), objectName, strings.NewReader(httpResponse))
err := s.bucket.Upload(context.Background(), objectName, bytes.NewReader(httpResponse))
if err != nil {
s.logger.Error("Error uploading object to object storage", zap.Error(err))
return
Expand Down Expand Up @@ -388,7 +398,7 @@ func Start(c AgentConfig) {
}
}
}
b, err := data.SerializeHTTPResponseData(&responseData, s.commonConfig.Networking.Format, s.encoder)
b, err := data.SerializeHTTPResponseData(&responseData, s.commonConfig.Networking.Format)
if err != nil {
s.logger.Error("Error serializing response data", zap.Error(err))
return
Expand Down
3 changes: 2 additions & 1 deletion internal/hub/router.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -41,7 +41,8 @@ func (r *Router) Route(p *packets.Publish) {
r.logger.Info("Error deserializing response packet: " + err.Error())
return
}
httpResponseData, err := data.DeserializeHTTPResponseData(httpResponsePacket.GetHttpResponseData(), httpResponsePacket.Compress, r.commonConfig.Networking.Format, r.decoder, nil)

httpResponseData, err := data.DeserializeHTTPResponseData(httpResponsePacket.GetHttpResponseData(), r.commonConfig.Networking.Format, nil)
if err != nil {
r.logger.Info("Error deserializing HTTP response data: " + err.Error())
return
Expand Down
53 changes: 25 additions & 28 deletions internal/hub/server.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -198,6 +198,10 @@ func (s *server) handleRequest(w http.ResponseWriter, r *http.Request) {
bodyBytes := make([]byte, r.ContentLength)
r.Body.Read(bodyBytes)

if s.encoder != nil {
bodyBytes = s.encoder.EncodeAll(bodyBytes, nil)
}

dataHeaders := data.HTTPHeaderToProtoHeaders(r.Header)

var body data.HTTPBody
Expand Down Expand Up @@ -229,14 +233,15 @@ func (s *server) handleRequest(w http.ResponseWriter, r *http.Request) {
Body: b,
Type: "split",
}

requestData := data.HTTPRequestData{
Method: r.Method,
Path: r.URL.Path,
Headers: &dataHeaders,
Body: &body,
}

b, err = data.SerializeHTTPRequestData(&requestData, s.commonConfig.Networking.Format, s.encoder)
b, err = data.SerializeHTTPRequestData(&requestData, s.commonConfig.Networking.Format)
if err != nil {
s.logger.Error("Error serializing request data", zap.Error(err))
return
Expand Down Expand Up @@ -284,7 +289,7 @@ func (s *server) handleRequest(w http.ResponseWriter, r *http.Request) {
Body: &body,
}

b, err := data.SerializeHTTPRequestData(&requestData, s.commonConfig.Networking.Format, s.encoder)
b, err := data.SerializeHTTPRequestData(&requestData, s.commonConfig.Networking.Format)
if err != nil {
s.logger.Error("Error serializing request data", zap.Error(err))
return
Expand Down Expand Up @@ -318,26 +323,7 @@ func (s *server) handleRequest(w http.ResponseWriter, r *http.Request) {
select {
case value := <-s.dataCh:
s.logger.Debug("Writing response...")
var compress string
err := s.db.View(func(txn *badger.Txn) error {
item, err := txn.Get([]byte("compress/" + uuid))
if err != nil {
return err
}
err = item.Value(func(val []byte) error {
compress = string(val)
return nil
})
if err != nil {
return err
}
return nil
})
if err != nil {
s.logger.Error("Error getting compress value from database", zap.Error(err))
return
}
httpResponseData, err := data.DeserializeHTTPResponseData(value, compress, s.commonConfig.Networking.Format, s.decoder, s.bucket)
httpResponseData, err := data.DeserializeHTTPResponseData(value, s.commonConfig.Networking.Format, s.bucket)
if err != nil {
s.logger.Error("Error deserializing response data", zap.Error(err))
return
Expand Down Expand Up @@ -433,14 +419,25 @@ func (s *server) startHTTP1(c HubConfig) {
}
case updateDBChPayload := <-s.updateDBCh:
s.logger.Debug("Writing response to database...")
err := s.db.Update(func(txn *badger.Txn) error {
e1 := badger.NewEntry([]byte(updateDBChPayload.responsePacket.RequestId), updateDBChPayload.responsePacket.GetHttpResponseData()).WithTTL(time.Minute * 5)
err := txn.SetEntry(e1)

if updateDBChPayload.responsePacket.Compress == "zstd" && s.decoder != nil {
var err error
updateDBChPayload.httpResponseData.Body.Body, err = s.decoder.DecodeAll(updateDBChPayload.httpResponseData.Body.Body, nil)
if err != nil {
return err
s.logger.Info("Error decompressing message: " + err.Error())
return
}
e2 := badger.NewEntry([]byte("compress/"+updateDBChPayload.responsePacket.RequestId), []byte(updateDBChPayload.responsePacket.Compress)).WithTTL(time.Minute * 5)
err = txn.SetEntry(e2)
}

b, err := data.SerializeHTTPResponseData(updateDBChPayload.httpResponseData, s.commonConfig.Networking.Format)
if err != nil {
s.logger.Info("Error serializing HTTP response data: " + err.Error())
return
}

err = s.db.Update(func(txn *badger.Txn) error {
e := badger.NewEntry([]byte(updateDBChPayload.responsePacket.RequestId), b).WithTTL(time.Minute * 5)
err := txn.SetEntry(e)
if err != nil {
return err
}
Expand Down
30 changes: 4 additions & 26 deletions pkg/data/serde.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"io"
"net/http"

"github.com/klauspost/compress/zstd"
"github.com/thanos-io/objstore"
"google.golang.org/protobuf/proto"
)
Expand Down Expand Up @@ -86,7 +85,7 @@ func DeserializeResponsePacket(payload []byte, format string) (*HTTPResponsePack
return &responsePacket, err
}

func SerializeHTTPRequestData(httpRequestData *HTTPRequestData, format string, encoder *zstd.Encoder) ([]byte, error) {
func SerializeHTTPRequestData(httpRequestData *HTTPRequestData, format string) ([]byte, error) {
var b []byte
var err error
switch format {
Expand All @@ -103,21 +102,10 @@ func SerializeHTTPRequestData(httpRequestData *HTTPRequestData, format string, e
default:
return nil, fmt.Errorf("unknown format: %s", format)
}
if encoder != nil {
b = encoder.EncodeAll(b, nil)
}
return b, nil
}

func DeserializeHTTPRequestData(b []byte, compress string, format string, decoder *zstd.Decoder, bucket objstore.Bucket) (*HTTPRequestData, error) {
var err error
if compress == "zstd" && decoder != nil {
b, err = decoder.DecodeAll(b, nil)
if err != nil {
return nil, err
}
}

func DeserializeHTTPRequestData(b []byte, format string, bucket objstore.Bucket) (*HTTPRequestData, error) {
var httpRequestData HTTPRequestData
switch format {
case "json":
Expand Down Expand Up @@ -153,7 +141,7 @@ func DeserializeHTTPRequestData(b []byte, compress string, format string, decode
return &httpRequestData, nil
}

func SerializeHTTPResponseData(httpResponseData *HTTPResponseData, format string, encoder *zstd.Encoder) ([]byte, error) {
func SerializeHTTPResponseData(httpResponseData *HTTPResponseData, format string) ([]byte, error) {
var b []byte
var err error
switch format {
Expand All @@ -170,20 +158,10 @@ func SerializeHTTPResponseData(httpResponseData *HTTPResponseData, format string
default:
return nil, fmt.Errorf("unknown format: %s", format)
}
if encoder != nil {
b = encoder.EncodeAll(b, nil)
}
return b, nil
}

func DeserializeHTTPResponseData(b []byte, compress string, format string, decoder *zstd.Decoder, bucket objstore.Bucket) (*HTTPResponseData, error) {
var err error
if compress == "zstd" && decoder != nil {
b, err = decoder.DecodeAll(b, nil)
if err != nil {
return nil, err
}
}
func DeserializeHTTPResponseData(b []byte, format string, bucket objstore.Bucket) (*HTTPResponseData, error) {
var httpResponseData HTTPResponseData
switch format {
case "json":
Expand Down
Loading

0 comments on commit 29866bd

Please sign in to comment.