diff --git a/pubsub/pubsub.go b/pubsub/pubsub.go
index 8730b8b3..dea3efe1 100644
--- a/pubsub/pubsub.go
+++ b/pubsub/pubsub.go
@@ -20,6 +20,19 @@ type Payload interface {
 	Type() string
 }
 
+// emptyPayload is used internally to act as a synchronisation point with consumers when bufferSize==0.
+// When no buffer is used, pubsub should act synchronously, meaning we wait for the consumer to process
+// the message before sending the next one. This is used to stop race conditions in tests.
+// We need to know when the consumer has processed the message - an unbuffered channel isn't enough, as
+// that wakes the producer too early (the channel frees up as soon as the consumer receives, whereas we
+// need to wait for processing too). To ensure we wait for processing, we send this emptyPayload
+// immediately after messages. When that send returns, we know the previous payload was fully consumed.
+type emptyPayload struct{}
+
+func (p *emptyPayload) Type() string { return emptyPayloadType }
+
+const emptyPayloadType = "empty"
+
 // Listener represents the common functions required by all subscription listeners
 type Listener interface {
 	// Begin listening on this channel with this callback starting from this position. Blocks until Close() is called.
@@ -70,6 +83,9 @@ func (ps *PubSub) Notify(chanName string, p Payload) error {
 	case <-time.After(5 * time.Second):
 		return fmt.Errorf("notify with payload %v timed out", p.Type())
 	}
+	if ps.bufferSize == 0 {
+		ch <- &emptyPayload{}
+	}
 	return nil
 }
 
@@ -89,6 +105,9 @@ func (ps *PubSub) Close() error {
 func (ps *PubSub) Listen(chanName string, fn func(p Payload)) error {
 	ch := ps.getChan(chanName)
 	for payload := range ch {
+		if payload.Type() == emptyPayloadType {
+			continue
+		}
 		fn(payload)
 	}
 	return nil
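The handshake this relies on is worth seeing in isolation: on an unbuffered channel a send returns only once a receiver has taken the value, so a trailing sentinel send cannot complete until the consumer has finished processing the previous message and looped back to receive again. A minimal standalone sketch of the idea (illustrative names only, not the pubsub package's API):

```go
package main

import "fmt"

type payload interface{ Type() string }

type msg struct{ body string }

func (m *msg) Type() string { return "msg" }

type empty struct{}

func (e *empty) Type() string { return "empty" }

func main() {
	ch := make(chan payload) // unbuffered: each send blocks until a receive happens
	done := make(chan struct{})

	go func() {
		defer close(done)
		for p := range ch {
			if p.Type() == "empty" {
				continue // sentinel: received but never handed to the callback
			}
			fmt.Println("processing", p.(*msg).body) // slow work happens here
		}
	}()

	ch <- &msg{body: "hello"}
	// The consumer may still be processing "hello" at this point. This send
	// only returns once the consumer loops back to receive again, i.e. once
	// the previous payload has been fully processed.
	ch <- &empty{}

	close(ch)
	<-done
}
```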
diff --git a/sync3/handler/connstate.go b/sync3/handler/connstate.go
index 6d86392a..c4a9c280 100644
--- a/sync3/handler/connstate.go
+++ b/sync3/handler/connstate.go
@@ -49,6 +49,7 @@ type ConnState struct {
 func NewConnState(
 	userID, deviceID string, userCache *caches.UserCache, globalCache *caches.GlobalCache,
 	ex extensions.HandlerInterface, joinChecker JoinChecker, histVec *prometheus.HistogramVec,
+	maxPendingEventUpdates int,
 ) *ConnState {
 	cs := &ConnState{
 		globalCache: globalCache,
@@ -66,7 +67,7 @@ func NewConnState(
 	cs.live = &connStateLive{
 		ConnState:     cs,
 		loadPositions: make(map[string]int64),
-		updates:       make(chan caches.Update, MaxPendingEventUpdates), // TODO: customisable
+		updates:       make(chan caches.Update, maxPendingEventUpdates),
 	}
 	cs.userCacheID = cs.userCache.Subsribe(cs)
 	return cs
diff --git a/sync3/handler/connstate_live.go b/sync3/handler/connstate_live.go
index d670a13a..e0d3b425 100644
--- a/sync3/handler/connstate_live.go
+++ b/sync3/handler/connstate_live.go
@@ -13,12 +13,9 @@ import (
 	"github.com/tidwall/gjson"
 )
 
-var (
-	// The max number of events the client is eligible to read (unfiltered) which we are willing to
-	// buffer on this connection. Too large and we consume lots of memory. Too small and busy accounts
-	// will trip the connection knifing.
-	MaxPendingEventUpdates = 2000
-)
+// BufferWaitTime is the amount of time to try to insert into a full buffer before giving up.
+// Customisable for testing.
+var BufferWaitTime = time.Second * 5
 
 // Contains code for processing live updates. Split out from connstate because they concern different
 // code paths. Relies on ConnState for various list/sort/subscription operations.
@@ -44,7 +41,7 @@ func (s *connStateLive) onUpdate(up caches.Update) {
 	}
 	select {
 	case s.updates <- up:
-	case <-time.After(5 * time.Second):
+	case <-time.After(BufferWaitTime):
 		logger.Warn().Interface("update", up).Str("user", s.userID).Msg(
 			"cannot send update to connection, buffer exceeded. Destroying connection.",
 		)
@@ -203,7 +200,7 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update,
 	// - we then process the live events in turn which adds them again.
 	if !advancedPastEvent {
 		roomIDtoTimeline := s.userCache.AnnotateWithTransactionIDs(s.deviceID, map[string][]json.RawMessage{
-			roomEventUpdate.RoomID(): []json.RawMessage{roomEventUpdate.EventData.Event},
+			roomEventUpdate.RoomID(): {roomEventUpdate.EventData.Event},
 		})
 		r.Timeline = append(r.Timeline, roomIDtoTimeline[roomEventUpdate.RoomID()]...)
 		roomID := roomEventUpdate.RoomID()
diff --git a/sync3/handler/connstate_test.go b/sync3/handler/connstate_test.go
index 50a360de..883e2030 100644
--- a/sync3/handler/connstate_test.go
+++ b/sync3/handler/connstate_test.go
@@ -104,7 +104,7 @@ func TestConnStateInitial(t *testing.T) {
 		}
 		return result
 	}
-	cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil)
+	cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000)
 	if userID != cs.UserID() {
 		t.Fatalf("UserID returned wrong value, got %v want %v", cs.UserID(), userID)
 	}
@@ -268,7 +268,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
 	userCache.LazyRoomDataOverride = mockLazyRoomOverride
 	dispatcher.Register(userCache.UserID, userCache)
 	dispatcher.Register(sync3.DispatcherAllUsers, globalCache)
-	cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil)
+	cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000)
 
 	// request first page
 	res, err := cs.OnIncomingRequest(context.Background(), ConnID, &sync3.Request{
@@ -445,7 +445,7 @@ func TestBumpToOutsideRange(t *testing.T) {
 	userCache.LazyRoomDataOverride = mockLazyRoomOverride
 	dispatcher.Register(userCache.UserID, userCache)
 	dispatcher.Register(sync3.DispatcherAllUsers, globalCache)
-	cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil)
+	cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000)
 	// Ask for A,B
 	res, err := cs.OnIncomingRequest(context.Background(), ConnID, &sync3.Request{
 		Lists: map[string]sync3.RequestList{"a": {
@@ -553,7 +553,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) {
 	}
 	dispatcher.Register(userCache.UserID, userCache)
 	dispatcher.Register(sync3.DispatcherAllUsers, globalCache)
-	cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil)
+	cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000)
 	// subscribe to room D
 	res, err := cs.OnIncomingRequest(context.Background(), ConnID, &sync3.Request{
 		RoomSubscriptions: map[string]sync3.RoomSubscription{
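The onUpdate change keeps the existing send-with-timeout shape and only swaps the hard-coded five seconds for BufferWaitTime. The pattern itself, sketched standalone with illustrative names (a bounded channel absorbs bursts; a timer rather than a default case gives a slow consumer a grace period before the producer gives up):

```go
package main

import (
	"fmt"
	"time"
)

// bufferWaitTime mirrors BufferWaitTime above: how long to wait for space
// in the buffer before declaring the consumer dead. (Illustrative copy.)
var bufferWaitTime = 10 * time.Millisecond

// trySend attempts to enqueue an update. It returns false if the buffer
// stays full for the whole wait window, at which point the real code logs
// a warning and destroys the connection, forcing the client to resync.
func trySend(updates chan<- int, up int) bool {
	select {
	case updates <- up:
		return true
	case <-time.After(bufferWaitTime):
		return false
	}
}

func main() {
	updates := make(chan int, 2)
	fmt.Println(trySend(updates, 1)) // true: buffer has space
	fmt.Println(trySend(updates, 2)) // true: buffer now full
	fmt.Println(trySend(updates, 3)) // false: nobody is draining the channel
}
```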
diff --git a/sync3/handler/handler.go b/sync3/handler/handler.go
index 39a042de..72bab4ea 100644
--- a/sync3/handler/handler.go
+++ b/sync3/handler/handler.go
@@ -50,7 +50,8 @@ type SyncLiveHandler struct {
 	userCaches *sync.Map // map[user_id]*UserCache
 	Dispatcher *sync3.Dispatcher
 
-	GlobalCache *caches.GlobalCache
+	GlobalCache            *caches.GlobalCache
+	maxPendingEventUpdates int
 
 	numConns prometheus.Gauge
 	histVec  *prometheus.HistogramVec
@@ -58,7 +59,7 @@
 
 func NewSync3Handler(
 	store *state.Storage, storev2 *sync2.Storage, v2Client sync2.Client, postgresDBURI, secret string,
-	debug bool, pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool,
+	debug bool, pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, maxPendingEventUpdates int,
 ) (*SyncLiveHandler, error) {
 	logger.Info().Msg("creating handler")
 	if debug {
@@ -67,13 +68,14 @@ func NewSync3Handler(
 		zerolog.SetGlobalLevel(zerolog.InfoLevel)
 	}
 	sh := &SyncLiveHandler{
-		V2:          v2Client,
-		Storage:     store,
-		V2Store:     storev2,
-		ConnMap:     sync3.NewConnMap(),
-		userCaches:  &sync.Map{},
-		Dispatcher:  sync3.NewDispatcher(),
-		GlobalCache: caches.NewGlobalCache(store),
+		V2:                     v2Client,
+		Storage:                store,
+		V2Store:                storev2,
+		ConnMap:                sync3.NewConnMap(),
+		userCaches:             &sync.Map{},
+		Dispatcher:             sync3.NewDispatcher(),
+		GlobalCache:            caches.NewGlobalCache(store),
+		maxPendingEventUpdates: maxPendingEventUpdates,
 	}
 	sh.Extensions = &extensions.Handler{
 		Store: store,
@@ -359,7 +361,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
 	conn, created := h.ConnMap.CreateConn(sync3.ConnID{
 		DeviceID: deviceID,
 	}, func() sync3.ConnHandler {
-		return NewConnState(v2device.UserID, v2device.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.histVec)
+		return NewConnState(v2device.UserID, v2device.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.histVec, h.maxPendingEventUpdates)
 	})
 	if created {
 		log.Info().Str("user", v2device.UserID).Str("conn_id", conn.ConnID.String()).Msg("created new connection")
diff --git a/tests-integration/connection_test.go b/tests-integration/connection_test.go
index 1e77706b..66f2ca21 100644
--- a/tests-integration/connection_test.go
+++ b/tests-integration/connection_test.go
@@ -3,10 +3,12 @@ package syncv3
 import (
 	"context"
 	"encoding/json"
+	"fmt"
 	"sync"
 	"testing"
 	"time"
 
+	slidingsync "github.com/matrix-org/sliding-sync"
 	"github.com/matrix-org/sliding-sync/sync2"
 	"github.com/matrix-org/sliding-sync/sync3"
 	"github.com/matrix-org/sliding-sync/testutils"
@@ -560,3 +562,64 @@ func TestSessionExpiry(t *testing.T) {
 		t.Errorf("got %v want errcode=M_UNKNOWN_POS", string(body))
 	}
 }
+
+func TestSessionExpiryOnBufferFill(t *testing.T) {
+	roomID := "!doesnt:matter"
+	maxPendingEventUpdates := 3
+	pqString := testutils.PrepareDBConnectionString()
+	v2 := runTestV2Server(t)
+	v2.addAccount(alice, aliceToken)
+	v2.queueResponse(alice, sync2.SyncResponse{
+		Rooms: sync2.SyncRoomsResponse{
+			Join: v2JoinTimeline(roomEvents{
+				roomID: roomID,
+				state:  createRoomState(t, alice, time.Now()),
+				events: []json.RawMessage{
+					testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "B"}),
+				},
+			}),
+		},
+	})
+	v3 := runTestServer(t, v2, pqString, slidingsync.Opts{
+		MaxPendingEventUpdates: maxPendingEventUpdates,
+	})
+
+	res := v3.mustDoV3Request(t, aliceToken, sync3.Request{
+		RoomSubscriptions: map[string]sync3.RoomSubscription{
+			roomID: {
+				TimelineLimit: 1,
+			},
+		},
+	})
+	m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
+		roomID: {
+			m.MatchJoinCount(1),
+		},
+	}))
+
+	// inject maxPendingEventUpdates+1 events to expire the session
+	events := make([]json.RawMessage, maxPendingEventUpdates+1)
+	for i := range events {
+		events[i] = testutils.NewEvent(t, "m.room.message", alice, map[string]interface{}{
+			"msgtype": "m.text",
+			"body":    fmt.Sprintf("Test %d", i),
+		}, testutils.WithTimestamp(time.Now().Add(time.Duration(i)*time.Second)))
+	}
+	v2.queueResponse(alice, sync2.SyncResponse{
+		Rooms: sync2.SyncRoomsResponse{
+			Join: v2JoinTimeline(roomEvents{
+				roomID: roomID,
+				events: events,
+			}),
+		},
+	})
+	v2.waitUntilEmpty(t, aliceToken)
+
+	_, body, code := v3.doV3Request(t, context.Background(), aliceToken, res.Pos, sync3.Request{})
+	if code != 400 {
+		t.Errorf("got HTTP %d want 400", code)
+	}
+	if gjson.ParseBytes(body).Get("errcode").Str != "M_UNKNOWN_POS" {
+		t.Errorf("got %v want errcode=M_UNKNOWN_POS", string(body))
+	}
+}
+ "body": fmt.Sprintf("Test %d", i), + }, testutils.WithTimestamp(time.Now().Add(time.Duration(i)*time.Second))) + } + v2.queueResponse(alice, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: v2JoinTimeline(roomEvents{ + roomID: roomID, + events: events, + }), + }, + }) + v2.waitUntilEmpty(t, aliceToken) + + _, body, code := v3.doV3Request(t, context.Background(), aliceToken, res.Pos, sync3.Request{}) + if code != 400 { + t.Errorf("got HTTP %d want 400", code) + } + if gjson.ParseBytes(body).Get("errcode").Str != "M_UNKNOWN_POS" { + t.Errorf("got %v want errcode=M_UNKNOWN_POS", string(body)) + } +} diff --git a/tests-integration/metrics_test.go b/tests-integration/metrics_test.go index df07ba0d..f332f9b9 100644 --- a/tests-integration/metrics_test.go +++ b/tests-integration/metrics_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + slidingsync "github.com/matrix-org/sliding-sync" "github.com/matrix-org/sliding-sync/sync2" "github.com/matrix-org/sliding-sync/sync3" "github.com/matrix-org/sliding-sync/testutils" @@ -59,7 +60,9 @@ func TestMetricsNumPollers(t *testing.T) { pqString := testutils.PrepareDBConnectionString() // setup code v2 := runTestV2Server(t) - v3 := runTestServer(t, v2, pqString, true) + v3 := runTestServer(t, v2, pqString, slidingsync.Opts{ + AddPrometheusMetrics: true, + }) defer v2.close() defer v3.close() metricsServer := runMetricsServer(t) @@ -106,7 +109,9 @@ func TestMetricsNumConns(t *testing.T) { pqString := testutils.PrepareDBConnectionString() // setup code v2 := runTestV2Server(t) - v3 := runTestServer(t, v2, pqString, true) + v3 := runTestServer(t, v2, pqString, slidingsync.Opts{ + AddPrometheusMetrics: true, + }) defer v2.close() defer v3.close() metricsServer := runMetricsServer(t) diff --git a/tests-integration/v3_test.go b/tests-integration/v3_test.go index f0d9f3a5..82fa4ac3 100644 --- a/tests-integration/v3_test.go +++ b/tests-integration/v3_test.go @@ -309,18 +309,24 @@ func (s *testV3Server) doV3Request(t testutils.TestBenchInterface, ctx context.C return &r, respBytes, resp.StatusCode } -func runTestServer(t testutils.TestBenchInterface, v2Server *testV2Server, postgresConnectionString string, enableProm ...bool) *testV3Server { +func runTestServer(t testutils.TestBenchInterface, v2Server *testV2Server, postgresConnectionString string, opts ...syncv3.Opts) *testV3Server { t.Helper() if postgresConnectionString == "" { postgresConnectionString = testutils.PrepareDBConnectionString() } metricsEnabled := false - if len(enableProm) > 0 && enableProm[0] { - metricsEnabled = true + maxPendingEventUpdates := 200 + if len(opts) > 0 { + metricsEnabled = opts[0].AddPrometheusMetrics + if opts[0].MaxPendingEventUpdates > 0 { + maxPendingEventUpdates = opts[0].MaxPendingEventUpdates + handler.BufferWaitTime = 5 * time.Millisecond + } } h2, h3 := syncv3.Setup(v2Server.url(), postgresConnectionString, os.Getenv("SYNCV3_SECRET"), syncv3.Opts{ Debug: true, TestingSynchronousPubsub: true, // critical to avoid flakey tests + MaxPendingEventUpdates: maxPendingEventUpdates, AddPrometheusMetrics: metricsEnabled, }) // for ease of use we don't start v2 pollers at startup in tests diff --git a/v3.go b/v3.go index 089c3d1b..771898f1 100644 --- a/v3.go +++ b/v3.go @@ -28,6 +28,10 @@ var Version string type Opts struct { Debug bool AddPrometheusMetrics bool + // The max number of events the client is eligible to read (unfiltered) which we are willing to + // buffer on this connection. Too large and we consume lots of memory. 
diff --git a/v3.go b/v3.go
index 089c3d1b..771898f1 100644
--- a/v3.go
+++ b/v3.go
@@ -28,6 +28,10 @@ var Version string
 type Opts struct {
 	Debug                bool
 	AddPrometheusMetrics bool
+	// The max number of events the client is eligible to read (unfiltered) which we are willing to
+	// buffer on this connection. Too large and we consume lots of memory. Too small and busy accounts
+	// will trip the connection knifing. Customisable so that tests can deliberately fill the buffer.
+	MaxPendingEventUpdates int
 	// if true, publishing messages will block until the consumer has consumed it.
 	// Assumes a single producer and a single consumer.
 	TestingSynchronousPubsub bool
@@ -74,6 +78,9 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
 	if opts.TestingSynchronousPubsub {
 		bufferSize = 0
 	}
+	if opts.MaxPendingEventUpdates == 0 {
+		opts.MaxPendingEventUpdates = 2000
+	}
 	pubSub := pubsub.NewPubSub(bufferSize)
 
 	// create v2 handler
@@ -83,7 +90,7 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
 	}
 
 	// create v3 handler
-	h3, err := handler.NewSync3Handler(store, storev2, v2Client, postgresURI, secret, opts.Debug, pubSub, pubSub, opts.AddPrometheusMetrics)
+	h3, err := handler.NewSync3Handler(store, storev2, v2Client, postgresURI, secret, opts.Debug, pubSub, pubSub, opts.AddPrometheusMetrics, opts.MaxPendingEventUpdates)
 	if err != nil {
 		panic(err)
 	}
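Because Setup treats a zero MaxPendingEventUpdates as "use the old hard-coded 2000", existing callers keep their behaviour without touching the new field. A hedged usage sketch based on the Setup signature visible in this diff (the homeserver URL and the SYNCV3_DB env var name are placeholders; SYNCV3_SECRET is taken from the tests above):

```go
package main

import (
	"os"

	slidingsync "github.com/matrix-org/sliding-sync"
)

func main() {
	// MaxPendingEventUpdates is left at its zero value here, so Setup
	// bumps it to 2000 before wiring up the v3 handler. Only tests that
	// want to overflow the buffer need to set it explicitly.
	h2, h3 := slidingsync.Setup(
		"https://matrix.example.com", // placeholder upstream homeserver
		os.Getenv("SYNCV3_DB"),       // placeholder env var for the postgres URI
		os.Getenv("SYNCV3_SECRET"),
		slidingsync.Opts{AddPrometheusMetrics: true},
	)
	_, _ = h2, h3 // wire these into your HTTP server as usual
}
```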