diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 277fb32..b8dd79a 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -333,7 +333,7 @@ func Start(c AgentConfig) { select { case payload := <-s.payloadCh: go func() { - requestPacket, err := data.DeserializeRequestPacket(payload, s.commonConfig.Networking.Format) + requestPacket, err := data.Deserialize[*data.HTTPRequestPacket](payload, s.commonConfig.Networking.Format) if err != nil { s.logger.Error("Error deserializing request packet", zap.Error(err)) return diff --git a/internal/common/split/split.go b/internal/common/split/split.go index aa314f3..f7de865 100644 --- a/internal/common/split/split.go +++ b/internal/common/split/split.go @@ -42,7 +42,7 @@ func Split(id string, bytes []byte, chunkSize int, format string, processFn func } func Merge(merger *Merger, body []byte, format string) (combined []byte, completed bool, err error) { - chunk, err := data.DeserializeHTTPBodyChunk(body, format) + chunk, err := data.Deserialize[*data.HTTPBodyChunk](body, format) if err != nil { return nil, false, err } diff --git a/internal/hub/server.go b/internal/hub/server.go index 9d024f5..2d1941c 100644 --- a/internal/hub/server.go +++ b/internal/hub/server.go @@ -498,7 +498,7 @@ func (s *server) startHTTP1(c HubConfig) { case payload := <-s.payloadCh: go func() { s.logger.Debug("Received message") - httpResponsePacket, err := data.DeserializeResponsePacket(payload, s.commonConfig.Networking.Format) + httpResponsePacket, err := data.Deserialize[*data.HTTPResponsePacket](payload, s.commonConfig.Networking.Format) if err != nil { s.logger.Info("Error deserializing response packet: " + err.Error()) return diff --git a/pkg/data/serde.go b/pkg/data/serde.go index ebe714f..c421279 100644 --- a/pkg/data/serde.go +++ b/pkg/data/serde.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + reflect "reflect" "github.com/thanos-io/objstore" "google.golang.org/protobuf/proto" @@ -39,116 +40,77 @@ func Serialize[T proto.Message](value T, format string) ([]byte, error) { return payload, nil } -func DeserializeRequestPacket(payload []byte, format string) (*HTTPRequestPacket, error) { +func Deserialize[T proto.Message](payload []byte, format string) (T, error) { var err error - requestPacket := HTTPRequestPacket{} + var result T + + tType := reflect.TypeOf(result) + if tType.Kind() == reflect.Ptr { + tType = tType.Elem() + } + tValue := reflect.New(tType) + result = tValue.Interface().(T) + switch format { case "json": - err = json.Unmarshal(payload, &requestPacket) + err = json.Unmarshal(payload, result) case "protobuf": - err = proto.Unmarshal(payload, &requestPacket) + err = proto.Unmarshal(payload, result) default: - return nil, fmt.Errorf("unknown format: %s", format) + return result, fmt.Errorf("unknown format: %s", format) } - return &requestPacket, err + + return result, err } -func DeserializeResponsePacket(payload []byte, format string) (*HTTPResponsePacket, error) { - var err error - responsePacket := HTTPResponsePacket{} - switch format { - case "json": - err = json.Unmarshal(payload, &responsePacket) - case "protobuf": - err = proto.Unmarshal(payload, &responsePacket) - default: - return nil, fmt.Errorf("unknown format: %s", format) +func decodeStorageRelay(body []byte, bucket objstore.Bucket) ([]byte, error) { + rc, err := bucket.Get(context.Background(), string(body)) + if err != nil { + return nil, err } - return &responsePacket, err + + data, err := io.ReadAll(rc) + if err != nil { + return nil, err + } + + err = rc.Close() + if err != nil { + return nil, err + } + return data, nil } func DeserializeHTTPRequestData(b []byte, format string, bucket objstore.Bucket) (*HTTPRequestData, error) { - var httpRequestData HTTPRequestData - switch format { - case "json": - if err := json.Unmarshal(b, &httpRequestData); err != nil { - return nil, fmt.Errorf("error unmarshalling message: %v", err) - } - case "protobuf": - if err := proto.Unmarshal(b, &httpRequestData); err != nil { - return nil, fmt.Errorf("error unmarshalling message: %v", err) - } - default: - return nil, fmt.Errorf("unknown format: %v", format) + httpRequestData, err := Deserialize[*HTTPRequestData](b, format) + if err != nil { + return nil, err } if httpRequestData.Body.Type == "storage_relay" { - rc, err := bucket.Get(context.Background(), string(httpRequestData.Body.Body)) - if err != nil { - return nil, err - } - - data, err := io.ReadAll(rc) - if err != nil { - return nil, err - } - - err = rc.Close() + data, err := decodeStorageRelay(httpRequestData.Body.Body, bucket) if err != nil { return nil, err } httpRequestData.Body.Body = data } - return &httpRequestData, nil + return httpRequestData, nil } func DeserializeHTTPResponseData(b []byte, format string, bucket objstore.Bucket) (*HTTPResponseData, error) { - var httpResponseData HTTPResponseData - switch format { - case "json": - if err := json.Unmarshal(b, &httpResponseData); err != nil { - return nil, fmt.Errorf("error unmarshalling message: %v", err) - } - case "protobuf": - if err := proto.Unmarshal(b, &httpResponseData); err != nil { - return nil, fmt.Errorf("error unmarshalling message: %v", err) - } - default: - return nil, fmt.Errorf("unknown format: %v", format) + httpResponseData, err := Deserialize[*HTTPResponseData](b, format) + if err != nil { + return nil, err } if httpResponseData.Body.Type == "storage_relay" { - rc, err := bucket.Get(context.Background(), string(httpResponseData.Body.Body)) - if err != nil { - return nil, err - } - - data, err := io.ReadAll(rc) - if err != nil { - return nil, err - } - - err = rc.Close() + data, err := decodeStorageRelay(httpResponseData.Body.Body, bucket) if err != nil { return nil, err } httpResponseData.Body.Body = data } - return &httpResponseData, nil -} - -func DeserializeHTTPBodyChunk(payload []byte, format string) (*HTTPBodyChunk, error) { - var err error - httpBodyChunk := HTTPBodyChunk{} - switch format { - case "json": - err = json.Unmarshal(payload, &httpBodyChunk) - case "protobuf": - err = proto.Unmarshal(payload, &httpBodyChunk) - default: - return nil, fmt.Errorf("unknown format: %s", format) - } - return &httpBodyChunk, err + return httpResponseData, nil }