Commit

Refactoring (#79)

ohkinozomu authored Nov 21, 2023
1 parent 360db5c commit 42c91a0
Showing 8 changed files with 231 additions and 205 deletions.
1 change: 1 addition & 0 deletions go.mod
@@ -109,6 +109,7 @@ require (
github.com/spf13/afero v1.10.0 // indirect
github.com/spf13/cast v1.5.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.0 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tencentyun/cos-go-sdk-v5 v0.7.40 // indirect
go.opencensus.io v0.24.0 // indirect
23 changes: 7 additions & 16 deletions internal/agent/agent.go
@@ -50,7 +50,7 @@ type server struct {
payloadCh chan []byte
mergeCh chan mergeChPayload
processCh chan processChPayload
merger *data.Merger
merger *split.Merger
}

func createWillMessage(c AgentConfig) *paho.WillMessage {
@@ -239,7 +239,7 @@ func newServer(c AgentConfig) server {
encoder: encoder,
decoder: decoder,
bucket: bucket,
merger: data.NewMerger(c.Logger),
merger: split.NewMerger(),
}
}

@@ -360,25 +360,16 @@ func Start(c AgentConfig) {
}()
case mergeChPayload := <-s.mergeCh:
go func() {
chunk, err := data.DeserializeHTTPBodyChunk(mergeChPayload.httpRequestData.Body.Body, s.commonConfig.Networking.Format)
combined, completed, err := split.Merge(s.merger, mergeChPayload.httpRequestData.Body.Body, s.commonConfig.Networking.Format)
if err != nil {
s.logger.Error("Error deserializing HTTP body chunk", zap.Error(err))
s.logger.Info("Error merging message: " + err.Error())
return
}
s.logger.Debug("Received chunk")
s.merger.AddChunk(chunk)

if s.merger.IsComplete(chunk) {
s.logger.Debug("Received last chunk")
combined := s.merger.GetCombinedData(chunk)
s.logger.Debug("Combined data")
if completed {
mergeChPayload.httpRequestData.Body.Body = combined
processChPayload := processChPayload(mergeChPayload)
s.logger.Debug("Sending to processCh")
s.processCh <- processChPayload
s.logger.Debug("Sent to processCh")
// TODO: delete from merger
s.processCh <- processChPayload(mergeChPayload)
}
s.logger.Debug("Done")
}()
case processChPayload := <-s.processCh:
go func() {
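Read as plain code, the agent's mergeCh handler after this commit looks like the sketch below; this is a reconstruction from the hunk above, not a verbatim copy of the repository, and the hub's response path further down applies the same change.

case mergeChPayload := <-s.mergeCh:
	go func() {
		// One split.Merge call replaces the deserialize / AddChunk /
		// IsComplete / GetCombinedData sequence that lived here before.
		combined, completed, err := split.Merge(s.merger, mergeChPayload.httpRequestData.Body.Body, s.commonConfig.Networking.Format)
		if err != nil {
			s.logger.Info("Error merging message: " + err.Error())
			return
		}
		if completed {
			// Forward the request only once every chunk has arrived.
			mergeChPayload.httpRequestData.Body.Body = combined
			// TODO: delete from merger
			s.processCh <- processChPayload(mergeChPayload)
		}
	}()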
27 changes: 6 additions & 21 deletions pkg/data/merger.go → internal/common/split/merger.go
@@ -1,42 +1,39 @@
package data
package split

import (
"sort"
"sync"

"go.uber.org/zap"
"github.com/ohkinozomu/fuyuu-router/pkg/data"
)

type Merger struct {
chunks map[string]map[int][]byte
logger *zap.Logger
mu sync.Mutex
}

func NewMerger(logger *zap.Logger) *Merger {
func NewMerger() *Merger {
return &Merger{
chunks: make(map[string]map[int][]byte),
logger: logger,
mu: sync.Mutex{},
}
}

func (m *Merger) AddChunk(chunk *HTTPBodyChunk) {
func (m *Merger) AddChunk(chunk *data.HTTPBodyChunk) {
// Avoid concurrent map writes
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.chunks[chunk.RequestId]; !exists {
m.chunks[chunk.RequestId] = make(map[int][]byte)
}
m.chunks[chunk.RequestId][int(chunk.Sequence)] = chunk.Data
m.logger.Debug("Chunk added")
}

func (m *Merger) IsComplete(chunk *HTTPBodyChunk) bool {
func (m *Merger) IsComplete(chunk *data.HTTPBodyChunk) bool {
return len(m.chunks[chunk.RequestId]) == int(chunk.Total)
}

func (m *Merger) GetCombinedData(chunk *HTTPBodyChunk) []byte {
func (m *Merger) GetCombinedData(chunk *data.HTTPBodyChunk) []byte {
// Avoid concurrent map read and map write
m.mu.Lock()
defer m.mu.Unlock()
@@ -58,15 +55,3 @@ func (m *Merger) GetCombinedData(chunk *HTTPBodyChunk) []byte {

return combinedData
}

func SplitChunk(body []byte, chunkByte int) [][]byte {
var chunks [][]byte
for i := 0; i < len(body); i += chunkByte {
end := i + chunkByte
if end > len(body) {
end = len(body)
}
chunks = append(chunks, body[i:end])
}
return chunks
}
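A minimal usage sketch of the relocated Merger, written as if inside package split. The helper name mergeTwoParts, the request id, and the payloads are illustrative; the data.HTTPBodyChunk fields (RequestId, Sequence, Total, Data) are the ones exercised by the tests below.

package split

import "github.com/ohkinozomu/fuyuu-router/pkg/data"

// mergeTwoParts is an illustrative helper, not part of this commit.
func mergeTwoParts() []byte {
	m := NewMerger()
	first := &data.HTTPBodyChunk{RequestId: "req-1", Sequence: 1, Total: 2, Data: []byte("he")}
	second := &data.HTTPBodyChunk{RequestId: "req-1", Sequence: 2, Total: 2, Data: []byte("llo")}

	m.AddChunk(first)
	_ = m.IsComplete(first) // false: only 1 of 2 chunks stored for "req-1"

	m.AddChunk(second)
	if m.IsComplete(second) { // true: 2 of 2 chunks stored
		return m.GetCombinedData(second) // "hello", reassembled in Sequence order
	}
	return nil
}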
74 changes: 74 additions & 0 deletions internal/common/split/merger_test.go
@@ -0,0 +1,74 @@
package split

import (
"reflect"
"testing"

"github.com/ohkinozomu/fuyuu-router/pkg/data"
)

func TestNewMerger(t *testing.T) {
merger := NewMerger()
if merger == nil {
t.Error("NewMerger returned nil")
}
if merger.chunks == nil {
t.Error("NewMerger did not initialize chunks map")
}
}

func TestAddChunkAndIsComplete(t *testing.T) {
merger := NewMerger()
chunk := &data.HTTPBodyChunk{
RequestId: "test",
Sequence: 1,
Total: 2,
Data: []byte("part1"),
}

merger.AddChunk(chunk)
if !reflect.DeepEqual(merger.chunks[chunk.RequestId][int(chunk.Sequence)], chunk.Data) {
t.Errorf("AddChunk did not add the chunk data correctly")
}

if merger.IsComplete(chunk) {
t.Error("IsComplete should return false when the total number of chunks has not been reached")
}

chunk2 := &data.HTTPBodyChunk{
RequestId: "test",
Sequence: 2,
Total: 2,
Data: []byte("part2"),
}

merger.AddChunk(chunk2)
if !merger.IsComplete(chunk2) {
t.Error("IsComplete should return true when all chunks have been added")
}
}

func TestGetCombinedData(t *testing.T) {
merger := NewMerger()
chunk1 := &data.HTTPBodyChunk{
RequestId: "test",
Sequence: 1,
Total: 2,
Data: []byte("part1"),
}
chunk2 := &data.HTTPBodyChunk{
RequestId: "test",
Sequence: 2,
Total: 2,
Data: []byte("part2"),
}

merger.AddChunk(chunk1)
merger.AddChunk(chunk2)

combined := merger.GetCombinedData(chunk1)
expected := []byte("part1part2")
if !reflect.DeepEqual(combined, expected) {
t.Errorf("GetCombinedData returned %v, expected %v", combined, expected)
}
}
28 changes: 27 additions & 1 deletion internal/common/split/split.go
@@ -4,8 +4,20 @@ import (
"github.com/ohkinozomu/fuyuu-router/pkg/data"
)

func splitChunk(body []byte, chunkByte int) [][]byte {
var chunks [][]byte
for i := 0; i < len(body); i += chunkByte {
end := i + chunkByte
if end > len(body) {
end = len(body)
}
chunks = append(chunks, body[i:end])
}
return chunks
}

func Split(id string, bytes []byte, chunkSize int, format string, processFn func(int, []byte) (any, error), sendFn func(any) error) error {
chunks := data.SplitChunk(bytes, chunkSize)
chunks := splitChunk(bytes, chunkSize)

for sequence, chunk := range chunks {
httpBodyChunk := data.HTTPBodyChunk{
@@ -28,3 +40,17 @@ func Split(id string, bytes []byte, chunkSize int, format string, processFn func
}
return nil
}

func Merge(merger *Merger, body []byte, format string) (combined []byte, completed bool, err error) {
chunk, err := data.DeserializeHTTPBodyChunk(body, format)
if err != nil {
return nil, false, err
}
merger.AddChunk(chunk)

if merger.IsComplete(chunk) {
combined := merger.GetCombinedData(chunk)
return combined, true, nil
}
return nil, false, nil
}
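The two halves compose as in the sketch below, written as if inside package split: Split hands each serialized chunk to sendFn, which feeds it straight into Merge on a single Merger. The helper name roundTrip is illustrative, and the sketch assumes, as TestSplitAndMerge below does, that the bytes reaching processFn and sendFn are the chunk serialized in the given format.

package split

import "fmt"

// roundTrip is an illustrative helper, not part of this commit: it splits body
// into chunks and immediately merges them back, returning the reassembled bytes.
func roundTrip(body []byte, chunkSize int, format string) ([]byte, error) {
	merger := NewMerger()
	var combined []byte

	err := Split("req-1", body, chunkSize, format,
		// processFn: pass the serialized chunk through unchanged.
		func(_ int, b []byte) (any, error) { return b, nil },
		// sendFn: "deliver" the chunk by merging it right away.
		func(chunk any) error {
			b, ok := chunk.([]byte)
			if !ok {
				return fmt.Errorf("unexpected chunk type %T", chunk)
			}
			out, completed, err := Merge(merger, b, format)
			if err != nil {
				return err
			}
			if completed {
				combined = out
			}
			return nil
		})
	return combined, err
}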
110 changes: 110 additions & 0 deletions internal/common/split/split_test.go
@@ -0,0 +1,110 @@
package split

import (
"fmt"
"math/rand"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestSplitChunk(t *testing.T) {
tests := []struct {
name string
body []byte
chunkByte int
want [][]byte
}{
{
name: "Empty input",
body: []byte(""),
chunkByte: 4,
want: nil,
},
{
name: "Regular input, complete division",
body: []byte("abcdef"),
chunkByte: 2,
want: [][]byte{[]byte("ab"), []byte("cd"), []byte("ef")},
},
{
name: "Regular input, incomplete division",
body: []byte("abcdefgh"),
chunkByte: 3,
want: [][]byte{[]byte("abc"), []byte("def"), []byte("gh")},
},
{
name: "chunkByte larger than input",
body: []byte("abc"),
chunkByte: 10,
want: [][]byte{[]byte("abc")},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := splitChunk(tt.body, tt.chunkByte)
assert.Equal(t, tt.want, got)
})
}
}

func randomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
var seededRand = rand.New(rand.NewSource(time.Now().UnixNano()))

b := make([]byte, length)
for i := range b {
b[i] = charset[seededRand.Intn(len(charset))]
}
return string(b)
}

func TestSplitAndMerge(t *testing.T) {
originalData := []byte(randomString(10000))
id := "test-id"
chunkSize := 100
format := "json"

dataCh := make(chan []byte)
doneCh := make(chan bool)
merger := NewMerger()

mockProcessFn := func(sequence int, data []byte) (any, error) {
return data, nil
}

go func() {
err := Split(id, originalData, chunkSize, format, mockProcessFn, func(chunk any) error {
serializedChunk, ok := chunk.([]byte)
if !ok {
return fmt.Errorf("invalid chunk type")
}

dataCh <- serializedChunk
return nil
})
if err != nil {
t.Error(err)
}
close(dataCh)
}()

go func() {
for chunk := range dataCh {
combined, completed, err := Merge(merger, chunk, format)
if err != nil {
t.Error(err)
break
}
if completed {
assert.Equal(t, originalData, combined)
doneCh <- true
break
}
}
}()

<-doneCh
}
14 changes: 6 additions & 8 deletions internal/hub/server.go
@@ -52,7 +52,7 @@ type server struct {
decoder *zstd.Decoder
commonConfig common.CommonConfigV2
bucket objstore.Bucket
merger *data.Merger
merger *split.Merger
payloadCh chan []byte
mergeCh chan mergeChPayload
busCh chan busChPayload
@@ -249,7 +249,7 @@ func newServer(c HubConfig) server {
decoder: decoder,
commonConfig: c.CommonConfigV2,
bucket: bucket,
merger: data.NewMerger(c.Logger),
merger: split.NewMerger(),
payloadCh: payloadCh,
mergeCh: mergeCh,
busCh: busCh,
@@ -525,16 +525,14 @@ func (s *server) startHTTP1(c HubConfig) {
}()
case mergeChPayload := <-s.mergeCh:
go func() {
chunk, err := data.DeserializeHTTPBodyChunk(mergeChPayload.httpResponseData.Body.Body, s.commonConfig.Networking.Format)
combined, completed, err := split.Merge(s.merger, mergeChPayload.httpResponseData.Body.Body, s.commonConfig.Networking.Format)
if err != nil {
s.logger.Error("Error deserializing HTTP body chunk", zap.Error(err))
s.logger.Info("Error merging message: " + err.Error())
return
}
s.merger.AddChunk(chunk)

if s.merger.IsComplete(chunk) {
combined := s.merger.GetCombinedData(chunk)
if completed {
mergeChPayload.httpResponseData.Body.Body = combined
// TODO: delete from merger
s.busCh <- busChPayload(mergeChPayload)
}
}()