Skip to content

Commit

Permalink
Use generics for deserialize (#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
ohkinozomu authored Nov 21, 2023
1 parent e5fb616 commit bdcf060
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 82 deletions.
2 changes: 1 addition & 1 deletion internal/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ func Start(c AgentConfig) {
select {
case payload := <-s.payloadCh:
go func() {
requestPacket, err := data.DeserializeRequestPacket(payload, s.commonConfig.Networking.Format)
requestPacket, err := data.Deserialize[*data.HTTPRequestPacket](payload, s.commonConfig.Networking.Format)
if err != nil {
s.logger.Error("Error deserializing request packet", zap.Error(err))
return
Expand Down
2 changes: 1 addition & 1 deletion internal/common/split/split.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func Split(id string, bytes []byte, chunkSize int, format string, processFn func
}

func Merge(merger *Merger, body []byte, format string) (combined []byte, completed bool, err error) {
chunk, err := data.DeserializeHTTPBodyChunk(body, format)
chunk, err := data.Deserialize[*data.HTTPBodyChunk](body, format)
if err != nil {
return nil, false, err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/hub/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ func (s *server) startHTTP1(c HubConfig) {
case payload := <-s.payloadCh:
go func() {
s.logger.Debug("Received message")
httpResponsePacket, err := data.DeserializeResponsePacket(payload, s.commonConfig.Networking.Format)
httpResponsePacket, err := data.Deserialize[*data.HTTPResponsePacket](payload, s.commonConfig.Networking.Format)
if err != nil {
s.logger.Info("Error deserializing response packet: " + err.Error())
return
Expand Down
120 changes: 41 additions & 79 deletions pkg/data/serde.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
reflect "reflect"

"github.com/thanos-io/objstore"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -39,116 +40,77 @@ func Serialize[T proto.Message](value T, format string) ([]byte, error) {
return payload, nil
}

func DeserializeRequestPacket(payload []byte, format string) (*HTTPRequestPacket, error) {
func Deserialize[T proto.Message](payload []byte, format string) (T, error) {
var err error
requestPacket := HTTPRequestPacket{}
var result T

tType := reflect.TypeOf(result)
if tType.Kind() == reflect.Ptr {
tType = tType.Elem()
}
tValue := reflect.New(tType)
result = tValue.Interface().(T)

switch format {
case "json":
err = json.Unmarshal(payload, &requestPacket)
err = json.Unmarshal(payload, result)
case "protobuf":
err = proto.Unmarshal(payload, &requestPacket)
err = proto.Unmarshal(payload, result)
default:
return nil, fmt.Errorf("unknown format: %s", format)
return result, fmt.Errorf("unknown format: %s", format)
}
return &requestPacket, err

return result, err
}

func DeserializeResponsePacket(payload []byte, format string) (*HTTPResponsePacket, error) {
var err error
responsePacket := HTTPResponsePacket{}
switch format {
case "json":
err = json.Unmarshal(payload, &responsePacket)
case "protobuf":
err = proto.Unmarshal(payload, &responsePacket)
default:
return nil, fmt.Errorf("unknown format: %s", format)
func decodeStorageRelay(body []byte, bucket objstore.Bucket) ([]byte, error) {
rc, err := bucket.Get(context.Background(), string(body))
if err != nil {
return nil, err
}
return &responsePacket, err

data, err := io.ReadAll(rc)
if err != nil {
return nil, err
}

err = rc.Close()
if err != nil {
return nil, err
}
return data, nil
}

func DeserializeHTTPRequestData(b []byte, format string, bucket objstore.Bucket) (*HTTPRequestData, error) {
var httpRequestData HTTPRequestData
switch format {
case "json":
if err := json.Unmarshal(b, &httpRequestData); err != nil {
return nil, fmt.Errorf("error unmarshalling message: %v", err)
}
case "protobuf":
if err := proto.Unmarshal(b, &httpRequestData); err != nil {
return nil, fmt.Errorf("error unmarshalling message: %v", err)
}
default:
return nil, fmt.Errorf("unknown format: %v", format)
httpRequestData, err := Deserialize[*HTTPRequestData](b, format)
if err != nil {
return nil, err
}

if httpRequestData.Body.Type == "storage_relay" {
rc, err := bucket.Get(context.Background(), string(httpRequestData.Body.Body))
if err != nil {
return nil, err
}

data, err := io.ReadAll(rc)
if err != nil {
return nil, err
}

err = rc.Close()
data, err := decodeStorageRelay(httpRequestData.Body.Body, bucket)
if err != nil {
return nil, err
}
httpRequestData.Body.Body = data
}

return &httpRequestData, nil
return httpRequestData, nil
}

func DeserializeHTTPResponseData(b []byte, format string, bucket objstore.Bucket) (*HTTPResponseData, error) {
var httpResponseData HTTPResponseData
switch format {
case "json":
if err := json.Unmarshal(b, &httpResponseData); err != nil {
return nil, fmt.Errorf("error unmarshalling message: %v", err)
}
case "protobuf":
if err := proto.Unmarshal(b, &httpResponseData); err != nil {
return nil, fmt.Errorf("error unmarshalling message: %v", err)
}
default:
return nil, fmt.Errorf("unknown format: %v", format)
httpResponseData, err := Deserialize[*HTTPResponseData](b, format)
if err != nil {
return nil, err
}

if httpResponseData.Body.Type == "storage_relay" {
rc, err := bucket.Get(context.Background(), string(httpResponseData.Body.Body))
if err != nil {
return nil, err
}

data, err := io.ReadAll(rc)
if err != nil {
return nil, err
}

err = rc.Close()
data, err := decodeStorageRelay(httpResponseData.Body.Body, bucket)
if err != nil {
return nil, err
}
httpResponseData.Body.Body = data
}

return &httpResponseData, nil
}

func DeserializeHTTPBodyChunk(payload []byte, format string) (*HTTPBodyChunk, error) {
var err error
httpBodyChunk := HTTPBodyChunk{}
switch format {
case "json":
err = json.Unmarshal(payload, &httpBodyChunk)
case "protobuf":
err = proto.Unmarshal(payload, &httpBodyChunk)
default:
return nil, fmt.Errorf("unknown format: %s", format)
}
return &httpBodyChunk, err
return httpResponseData, nil
}

0 comments on commit bdcf060

Please sign in to comment.