diff --git a/broadcaster/broadcaster.go b/broadcaster/broadcaster.go index bde80c93d1..c3f4c62ce0 100644 --- a/broadcaster/broadcaster.go +++ b/broadcaster/broadcaster.go @@ -61,7 +61,7 @@ type ConfirmedSequenceNumberMessage struct { } func NewBroadcaster(config wsbroadcastserver.BroadcasterConfigFetcher, chainId uint64, feedErrChan chan error, dataSigner signature.DataSignerFunc) *Broadcaster { - catchupBuffer := NewSequenceNumberCatchupBuffer(func() bool { return config().LimitCatchup }) + catchupBuffer := NewSequenceNumberCatchupBuffer(func() bool { return config().LimitCatchup }, func() int { return config().MaxCatchup }) return &Broadcaster{ server: wsbroadcastserver.NewWSBroadcastServer(config, catchupBuffer, chainId, feedErrChan), catchupBuffer: catchupBuffer, diff --git a/broadcaster/sequencenumbercatchupbuffer.go b/broadcaster/sequencenumbercatchupbuffer.go index 7664f1b8da..bdd3e60c5b 100644 --- a/broadcaster/sequencenumbercatchupbuffer.go +++ b/broadcaster/sequencenumbercatchupbuffer.go @@ -29,11 +29,13 @@ type SequenceNumberCatchupBuffer struct { messages []*BroadcastFeedMessage messageCount int32 limitCatchup func() bool + maxCatchup func() int } -func NewSequenceNumberCatchupBuffer(limitCatchup func() bool) *SequenceNumberCatchupBuffer { +func NewSequenceNumberCatchupBuffer(limitCatchup func() bool, maxCatchup func() int) *SequenceNumberCatchupBuffer { return &SequenceNumberCatchupBuffer{ limitCatchup: limitCatchup, + maxCatchup: maxCatchup, } } @@ -98,6 +100,15 @@ func (b *SequenceNumberCatchupBuffer) OnRegisterClient(clientConnection *wsbroad return nil, bmCount, time.Since(start) } +// Takes as input an index into the messages array, not a message index +func (b *SequenceNumberCatchupBuffer) pruneBufferToIndex(idx int) { + b.messages = b.messages[idx:] + if len(b.messages) > 10 && cap(b.messages) > len(b.messages)*10 { + // Too much spare capacity, copy to fresh slice to reset memory usage + b.messages = append([]*BroadcastFeedMessage(nil), b.messages[:len(b.messages)]...) + } +} + func (b *SequenceNumberCatchupBuffer) deleteConfirmed(confirmedSequenceNumber arbutil.MessageIndex) { if len(b.messages) == 0 { return @@ -126,11 +137,7 @@ func (b *SequenceNumberCatchupBuffer) deleteConfirmed(confirmedSequenceNumber ar return } - b.messages = b.messages[confirmedIndex+1:] - if len(b.messages) > 10 && cap(b.messages) > len(b.messages)*10 { - // Too much spare capacity, copy to fresh slice to reset memory usage - b.messages = append([]*BroadcastFeedMessage(nil), b.messages[:len(b.messages)]...) - } + b.pruneBufferToIndex(int(confirmedIndex) + 1) } func (b *SequenceNumberCatchupBuffer) OnDoBroadcast(bmi interface{}) error { @@ -147,6 +154,12 @@ func (b *SequenceNumberCatchupBuffer) OnDoBroadcast(bmi interface{}) error { confirmedSequenceNumberGauge.Update(int64(confirmMsg.SequenceNumber)) } + maxCatchup := b.maxCatchup() + if maxCatchup == 0 { + b.messages = nil + return nil + } + for _, newMsg := range broadcastMessage.Messages { if len(b.messages) == 0 { // Add to empty list @@ -167,6 +180,10 @@ func (b *SequenceNumberCatchupBuffer) OnDoBroadcast(bmi interface{}) error { } } + if maxCatchup >= 0 && len(b.messages) > maxCatchup { + b.pruneBufferToIndex(len(b.messages) - maxCatchup) + } + return nil } diff --git a/broadcaster/sequencenumbercatchupbuffer_test.go b/broadcaster/sequencenumbercatchupbuffer_test.go index 40fae9875f..fc6655057e 100644 --- a/broadcaster/sequencenumbercatchupbuffer_test.go +++ b/broadcaster/sequencenumbercatchupbuffer_test.go @@ -22,6 +22,7 @@ import ( "github.com/offchainlabs/nitro/arbos/arbostypes" "github.com/offchainlabs/nitro/arbutil" + "github.com/offchainlabs/nitro/util/arbmath" ) func TestGetEmptyCacheMessages(t *testing.T) { @@ -29,6 +30,7 @@ func TestGetEmptyCacheMessages(t *testing.T) { messages: nil, messageCount: 0, limitCatchup: func() bool { return false }, + maxCatchup: func() int { return -1 }, } // Get everything @@ -60,6 +62,7 @@ func TestGetCacheMessages(t *testing.T) { messages: createDummyBroadcastMessages(indexes), messageCount: int32(len(indexes)), limitCatchup: func() bool { return false }, + maxCatchup: func() int { return -1 }, } // Get everything @@ -110,6 +113,7 @@ func TestDeleteConfirmedNil(t *testing.T) { messages: nil, messageCount: 0, limitCatchup: func() bool { return false }, + maxCatchup: func() int { return -1 }, } buffer.deleteConfirmed(0) @@ -124,6 +128,7 @@ func TestDeleteConfirmInvalidOrder(t *testing.T) { messages: createDummyBroadcastMessages(indexes), messageCount: int32(len(indexes)), limitCatchup: func() bool { return false }, + maxCatchup: func() int { return -1 }, } // Confirm before cache @@ -139,6 +144,7 @@ func TestDeleteConfirmed(t *testing.T) { messages: createDummyBroadcastMessages(indexes), messageCount: int32(len(indexes)), limitCatchup: func() bool { return false }, + maxCatchup: func() int { return -1 }, } // Confirm older than cache @@ -154,6 +160,7 @@ func TestDeleteFreeMem(t *testing.T) { messages: createDummyBroadcastMessagesImpl(indexes, len(indexes)*10+1), messageCount: int32(len(indexes)), limitCatchup: func() bool { return false }, + maxCatchup: func() int { return -1 }, } // Confirm older than cache @@ -169,6 +176,7 @@ func TestBroadcastBadMessage(t *testing.T) { messages: nil, messageCount: 0, limitCatchup: func() bool { return false }, + maxCatchup: func() int { return -1 }, } var foo int @@ -187,6 +195,7 @@ func TestBroadcastPastSeqNum(t *testing.T) { messages: createDummyBroadcastMessagesImpl(indexes, len(indexes)*10+1), messageCount: int32(len(indexes)), limitCatchup: func() bool { return false }, + maxCatchup: func() int { return -1 }, } bm := BroadcastMessage{ @@ -208,6 +217,8 @@ func TestBroadcastFutureSeqNum(t *testing.T) { buffer := SequenceNumberCatchupBuffer{ messages: createDummyBroadcastMessagesImpl(indexes, len(indexes)*10+1), messageCount: int32(len(indexes)), + limitCatchup: func() bool { return false }, + maxCatchup: func() int { return -1 }, } bm := BroadcastMessage{ @@ -223,3 +234,38 @@ func TestBroadcastFutureSeqNum(t *testing.T) { } } + +func TestMaxCatchupBufferSize(t *testing.T) { + limit := 5 + buffer := SequenceNumberCatchupBuffer{ + messages: nil, + messageCount: 0, + limitCatchup: func() bool { return false }, + maxCatchup: func() int { return limit }, + } + + firstMessage := 10 + for i := firstMessage; i <= 20; i += 2 { + bm := BroadcastMessage{ + Messages: []*BroadcastFeedMessage{ + { + SequenceNumber: arbutil.MessageIndex(i), + }, + { + SequenceNumber: arbutil.MessageIndex(i + 1), + }, + }, + } + err := buffer.OnDoBroadcast(bm) + Require(t, err) + haveMessages := buffer.getCacheMessages(0) + expectedCount := arbmath.MinInt(i+len(bm.Messages)-firstMessage, limit) + if len(haveMessages.Messages) != expectedCount { + t.Errorf("after broadcasting messages %v and %v, expected to have %v messages but got %v", i, i+1, expectedCount, len(haveMessages.Messages)) + } + expectedFirstMessage := arbutil.MessageIndex(arbmath.MaxInt(firstMessage, i+len(bm.Messages)-limit)) + if haveMessages.Messages[0].SequenceNumber != expectedFirstMessage { + t.Errorf("after broadcasting messages %v and %v, expected the first message to be %v but got %v", i, i+1, expectedFirstMessage, haveMessages.Messages[0].SequenceNumber) + } + } +} diff --git a/wsbroadcastserver/wsbroadcastserver.go b/wsbroadcastserver/wsbroadcastserver.go index 014995cee0..cd277387a0 100644 --- a/wsbroadcastserver/wsbroadcastserver.go +++ b/wsbroadcastserver/wsbroadcastserver.go @@ -60,6 +60,7 @@ type BroadcasterConfig struct { EnableCompression bool `koanf:"enable-compression" reload:"hot"` // if reloaded to false will cause disconnection of clients with enabled compression on next broadcast RequireCompression bool `koanf:"require-compression" reload:"hot"` // if reloaded to true will cause disconnection of clients with disabled compression on next broadcast LimitCatchup bool `koanf:"limit-catchup" reload:"hot"` + MaxCatchup int `koanf:"max-catchup" reload:"hot"` ConnectionLimits ConnectionLimiterConfig `koanf:"connection-limits" reload:"hot"` ClientDelay time.Duration `koanf:"client-delay" reload:"hot"` } @@ -93,6 +94,7 @@ func BroadcasterConfigAddOptions(prefix string, f *flag.FlagSet) { f.Bool(prefix+".enable-compression", DefaultBroadcasterConfig.EnableCompression, "enable per message deflate compression support") f.Bool(prefix+".require-compression", DefaultBroadcasterConfig.RequireCompression, "require clients to use compression") f.Bool(prefix+".limit-catchup", DefaultBroadcasterConfig.LimitCatchup, "only supply catchup buffer if requested sequence number is reasonable") + f.Int(prefix+".max-catchup", DefaultBroadcasterConfig.MaxCatchup, "the maximum size of the catchup buffer (-1 means unlimited)") ConnectionLimiterConfigAddOptions(prefix+".connection-limits", f) f.Duration(prefix+".client-delay", DefaultBroadcasterConfig.ClientDelay, "delay the first messages sent to each client by this amount") } @@ -117,6 +119,7 @@ var DefaultBroadcasterConfig = BroadcasterConfig{ EnableCompression: true, RequireCompression: false, LimitCatchup: false, + MaxCatchup: -1, ConnectionLimits: DefaultConnectionLimiterConfig, ClientDelay: 0, } @@ -141,6 +144,7 @@ var DefaultTestBroadcasterConfig = BroadcasterConfig{ EnableCompression: true, RequireCompression: false, LimitCatchup: false, + MaxCatchup: -1, ConnectionLimits: DefaultConnectionLimiterConfig, ClientDelay: 0, }