tests: add test for full connection buffers and expiry
Fixed a bug in the notification code which could cause integration
tests to be less deterministic than intended; should fix flaky
tests.
kegsay committed Feb 3, 2023
1 parent fdc53dd commit 2139eda
Showing 9 changed files with 129 additions and 29 deletions.
19 changes: 19 additions & 0 deletions pubsub/pubsub.go
@@ -20,6 +20,19 @@ type Payload interface {
Type() string
}

// emptyPayload is used internally to act as a synchronisation point with consumers when bufferSize==0.
// When no buffer is used, pubsub should act synchronously, meaning we wait for the consumer to process
// the message before sending the next one. This is used in tests to prevent race conditions.
// We need to know when the consumer has finished processing - an unbuffered channel (make(chan Payload, 0))
// isn't enough, as the producer wakes up too early (the send unblocks as soon as the consumer receives,
// whereas we need to wait for processing too). To ensure we wait for processing, we send this emptyPayload
// immediately after each message. When that send returns, we know the previous payload was fully consumed.
type emptyPayload struct{}

func (p *emptyPayload) Type() string { return emptyPayloadType }

const emptyPayloadType = "empty"

// Listener represents the common functions required by all subscription listeners
type Listener interface {
// Begin listening on this channel with this callback starting from this position. Blocks until Close() is called.
@@ -70,6 +83,9 @@ func (ps *PubSub) Notify(chanName string, p Payload) error {
case <-time.After(5 * time.Second):
return fmt.Errorf("notify with payload %v timed out", p.Type())
}
if ps.bufferSize == 0 {
ch <- &emptyPayload{}
}
return nil
}

@@ -89,6 +105,9 @@ func (ps *PubSub) Close() error {
func (ps *PubSub) Listen(chanName string, fn func(p Payload)) error {
ch := ps.getChan(chanName)
for payload := range ch {
if payload.Type() == emptyPayloadType {
continue
}
fn(payload)
}
return nil
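The synchronisation trick above is easiest to see in isolation. The following is a minimal, self-contained sketch of the same pattern with simplified types (not the library's actual API): the producer follows every payload with a sentinel on an unbuffered channel, so the producer only unblocks once the consumer has finished processing the real payload.

package main

import "fmt"

type payload string

const sentinel payload = "" // plays the role of emptyPayload

// notify sends p on the unbuffered channel, then sends the sentinel.
// The sentinel send can only complete once the consumer has finished
// processing p and looped back to receive again, so notify returns only
// after p has been fully handled.
func notify(ch chan payload, p payload) {
    ch <- p
    ch <- sentinel
}

// listen processes payloads and silently discards the sentinel,
// mirroring the emptyPayloadType check in Listen above.
func listen(ch chan payload, fn func(payload)) {
    for p := range ch {
        if p == sentinel {
            continue
        }
        fn(p)
    }
}

func main() {
    ch := make(chan payload) // unbuffered, i.e. bufferSize == 0
    done := make(chan struct{})
    go func() {
        listen(ch, func(p payload) { fmt.Println("processed:", p) })
        close(done)
    }()
    notify(ch, "hello") // returns only after "processed: hello" has been printed
    close(ch)
    <-done
}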
3 changes: 2 additions & 1 deletion sync3/handler/connstate.go
@@ -49,6 +49,7 @@ type ConnState struct {
func NewConnState(
userID, deviceID string, userCache *caches.UserCache, globalCache *caches.GlobalCache,
ex extensions.HandlerInterface, joinChecker JoinChecker, histVec *prometheus.HistogramVec,
maxPendingEventUpdates int,
) *ConnState {
cs := &ConnState{
globalCache: globalCache,
@@ -66,7 +67,7 @@ func NewConnState(
cs.live = &connStateLive{
ConnState: cs,
loadPositions: make(map[string]int64),
updates: make(chan caches.Update, MaxPendingEventUpdates), // TODO: customisable
updates: make(chan caches.Update, maxPendingEventUpdates),
}
cs.userCacheID = cs.userCache.Subsribe(cs)
return cs
13 changes: 5 additions & 8 deletions sync3/handler/connstate_live.go
@@ -13,12 +13,9 @@ import (
"github.com/tidwall/gjson"
)

var (
// The max number of events the client is eligible to read (unfiltered) which we are willing to
// buffer on this connection. Too large and we consume lots of memory. Too small and busy accounts
// will trip the connection knifing.
MaxPendingEventUpdates = 2000
)
// BufferWaitTime is the amount of time to try to insert into a full buffer before giving up.
// Customisable for testing.
var BufferWaitTime = time.Second * 5

// Contains code for processing live updates. Split out from connstate because they concern different
// code paths. Relies on ConnState for various list/sort/subscription operations.
@@ -44,7 +41,7 @@ func (s *connStateLive) onUpdate(up caches.Update) {
}
select {
case s.updates <- up:
case <-time.After(5 * time.Second):
case <-time.After(BufferWaitTime):
logger.Warn().Interface("update", up).Str("user", s.userID).Msg(
"cannot send update to connection, buffer exceeded. Destroying connection.",
)
@@ -203,7 +200,7 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update,
// - we then process the live events in turn which adds them again.
if !advancedPastEvent {
roomIDtoTimeline := s.userCache.AnnotateWithTransactionIDs(s.deviceID, map[string][]json.RawMessage{
roomEventUpdate.RoomID(): []json.RawMessage{roomEventUpdate.EventData.Event},
roomEventUpdate.RoomID(): {roomEventUpdate.EventData.Event},
})
r.Timeline = append(r.Timeline, roomIDtoTimeline[roomEventUpdate.RoomID()]...)
roomID := roomEventUpdate.RoomID()
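The select/time.After pattern above, together with the now-configurable wait time and buffer size, is what the new integration test exercises. A minimal sketch of the overflow behaviour under the same assumptions (a buffer of 3 and a short wait time; names here are illustrative, not the real connection code):

package main

import (
    "fmt"
    "time"
)

// bufferWaitTime mirrors handler.BufferWaitTime: how long the producer will
// block on a full update buffer before giving up.
var bufferWaitTime = 5 * time.Millisecond

// trySend attempts to queue an update; if the buffer stays full for longer
// than bufferWaitTime it reports failure, which in the real code causes the
// connection to be destroyed so the client restarts with a fresh session.
func trySend(updates chan int, up int) bool {
    select {
    case updates <- up:
        return true
    case <-time.After(bufferWaitTime):
        return false
    }
}

func main() {
    updates := make(chan int, 3) // like maxPendingEventUpdates = 3 in the new test
    for i := 0; i < 4; i++ {     // inject one more update than the buffer can hold
        if !trySend(updates, i) {
            fmt.Printf("update %d rejected: buffer exceeded, connection would be destroyed\n", i)
        }
    }
}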
8 changes: 4 additions & 4 deletions sync3/handler/connstate_test.go
@@ -104,7 +104,7 @@ func TestConnStateInitial(t *testing.T) {
}
return result
}
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil)
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000)
if userID != cs.UserID() {
t.Fatalf("UserID returned wrong value, got %v want %v", cs.UserID(), userID)
}
@@ -268,7 +268,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
userCache.LazyRoomDataOverride = mockLazyRoomOverride
dispatcher.Register(userCache.UserID, userCache)
dispatcher.Register(sync3.DispatcherAllUsers, globalCache)
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil)
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000)

// request first page
res, err := cs.OnIncomingRequest(context.Background(), ConnID, &sync3.Request{
@@ -445,7 +445,7 @@ func TestBumpToOutsideRange(t *testing.T) {
userCache.LazyRoomDataOverride = mockLazyRoomOverride
dispatcher.Register(userCache.UserID, userCache)
dispatcher.Register(sync3.DispatcherAllUsers, globalCache)
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil)
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000)
// Ask for A,B
res, err := cs.OnIncomingRequest(context.Background(), ConnID, &sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
@@ -553,7 +553,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) {
}
dispatcher.Register(userCache.UserID, userCache)
dispatcher.Register(sync3.DispatcherAllUsers, globalCache)
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil)
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000)
// subscribe to room D
res, err := cs.OnIncomingRequest(context.Background(), ConnID, &sync3.Request{
RoomSubscriptions: map[string]sync3.RoomSubscription{
22 changes: 12 additions & 10 deletions sync3/handler/handler.go
@@ -50,15 +50,16 @@ type SyncLiveHandler struct {
userCaches *sync.Map // map[user_id]*UserCache
Dispatcher *sync3.Dispatcher

GlobalCache *caches.GlobalCache
GlobalCache *caches.GlobalCache
maxPendingEventUpdates int

numConns prometheus.Gauge
histVec *prometheus.HistogramVec
}

func NewSync3Handler(
store *state.Storage, storev2 *sync2.Storage, v2Client sync2.Client, postgresDBURI, secret string,
debug bool, pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool,
debug bool, pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, maxPendingEventUpdates int,
) (*SyncLiveHandler, error) {
logger.Info().Msg("creating handler")
if debug {
@@ -67,13 +68,14 @@
zerolog.SetGlobalLevel(zerolog.InfoLevel)
}
sh := &SyncLiveHandler{
V2: v2Client,
Storage: store,
V2Store: storev2,
ConnMap: sync3.NewConnMap(),
userCaches: &sync.Map{},
Dispatcher: sync3.NewDispatcher(),
GlobalCache: caches.NewGlobalCache(store),
V2: v2Client,
Storage: store,
V2Store: storev2,
ConnMap: sync3.NewConnMap(),
userCaches: &sync.Map{},
Dispatcher: sync3.NewDispatcher(),
GlobalCache: caches.NewGlobalCache(store),
maxPendingEventUpdates: maxPendingEventUpdates,
}
sh.Extensions = &extensions.Handler{
Store: store,
@@ -359,7 +361,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
conn, created := h.ConnMap.CreateConn(sync3.ConnID{
DeviceID: deviceID,
}, func() sync3.ConnHandler {
return NewConnState(v2device.UserID, v2device.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.histVec)
return NewConnState(v2device.UserID, v2device.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.histVec, h.maxPendingEventUpdates)
})
if created {
log.Info().Str("user", v2device.UserID).Str("conn_id", conn.ConnID.String()).Msg("created new connection")
63 changes: 63 additions & 0 deletions tests-integration/connection_test.go
@@ -3,10 +3,12 @@ package syncv3
import (
"context"
"encoding/json"
"fmt"
"sync"
"testing"
"time"

slidingsync "github.com/matrix-org/sliding-sync"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/testutils"
@@ -560,3 +562,64 @@ func TestSessionExpiry(t *testing.T) {
t.Errorf("got %v want errcode=M_UNKNOWN_POS", string(body))
}
}

func TestSessionExpiryOnBufferFill(t *testing.T) {
roomID := "!doesnt:matter"
maxPendingEventUpdates := 3
pqString := testutils.PrepareDBConnectionString()
v2 := runTestV2Server(t)
v2.addAccount(alice, aliceToken)
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
state: createRoomState(t, alice, time.Now()),
events: []json.RawMessage{
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "B"}),
},
}),
},
})
v3 := runTestServer(t, v2, pqString, slidingsync.Opts{
MaxPendingEventUpdates: maxPendingEventUpdates,
})

res := v3.mustDoV3Request(t, aliceToken, sync3.Request{
RoomSubscriptions: map[string]sync3.RoomSubscription{
roomID: {
TimelineLimit: 1,
},
},
})
m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
roomID: {
m.MatchJoinCount(1),
},
}))

// inject maxPendingEventUpdates+1 events to expire the session
events := make([]json.RawMessage, maxPendingEventUpdates+1)
for i := range events {
events[i] = testutils.NewEvent(t, "m.room.message", alice, map[string]interface{}{
"msgtype": "m.text",
"body": fmt.Sprintf("Test %d", i),
}, testutils.WithTimestamp(time.Now().Add(time.Duration(i)*time.Second)))
}
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
events: events,
}),
},
})
v2.waitUntilEmpty(t, aliceToken)

_, body, code := v3.doV3Request(t, context.Background(), aliceToken, res.Pos, sync3.Request{})
if code != 400 {
t.Errorf("got HTTP %d want 400", code)
}
if gjson.ParseBytes(body).Get("errcode").Str != "M_UNKNOWN_POS" {
t.Errorf("got %v want errcode=M_UNKNOWN_POS", string(body))
}
}
9 changes: 7 additions & 2 deletions tests-integration/metrics_test.go
@@ -8,6 +8,7 @@ import (
"testing"
"time"

slidingsync "github.com/matrix-org/sliding-sync"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/testutils"
@@ -59,7 +60,9 @@ func TestMetricsNumPollers(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
// setup code
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString, true)
v3 := runTestServer(t, v2, pqString, slidingsync.Opts{
AddPrometheusMetrics: true,
})
defer v2.close()
defer v3.close()
metricsServer := runMetricsServer(t)
@@ -106,7 +109,9 @@ func TestMetricsNumConns(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
// setup code
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString, true)
v3 := runTestServer(t, v2, pqString, slidingsync.Opts{
AddPrometheusMetrics: true,
})
defer v2.close()
defer v3.close()
metricsServer := runMetricsServer(t)
12 changes: 9 additions & 3 deletions tests-integration/v3_test.go
@@ -309,18 +309,24 @@ func (s *testV3Server) doV3Request(t testutils.TestBenchInterface, ctx context.C
return &r, respBytes, resp.StatusCode
}

func runTestServer(t testutils.TestBenchInterface, v2Server *testV2Server, postgresConnectionString string, enableProm ...bool) *testV3Server {
func runTestServer(t testutils.TestBenchInterface, v2Server *testV2Server, postgresConnectionString string, opts ...syncv3.Opts) *testV3Server {
t.Helper()
if postgresConnectionString == "" {
postgresConnectionString = testutils.PrepareDBConnectionString()
}
metricsEnabled := false
if len(enableProm) > 0 && enableProm[0] {
metricsEnabled = true
maxPendingEventUpdates := 200
if len(opts) > 0 {
metricsEnabled = opts[0].AddPrometheusMetrics
if opts[0].MaxPendingEventUpdates > 0 {
maxPendingEventUpdates = opts[0].MaxPendingEventUpdates
handler.BufferWaitTime = 5 * time.Millisecond
}
}
h2, h3 := syncv3.Setup(v2Server.url(), postgresConnectionString, os.Getenv("SYNCV3_SECRET"), syncv3.Opts{
Debug: true,
TestingSynchronousPubsub: true, // critical to avoid flakey tests
MaxPendingEventUpdates: maxPendingEventUpdates,
AddPrometheusMetrics: metricsEnabled,
})
// for ease of use we don't start v2 pollers at startup in tests
9 changes: 8 additions & 1 deletion v3.go
@@ -28,6 +28,10 @@ var Version string
type Opts struct {
Debug bool
AddPrometheusMetrics bool
// The max number of events the client is eligible to read (unfiltered) which we are willing to
// buffer on this connection. Too large and we consume lots of memory. Too small and busy accounts
// will trip the connection knifing. Customisable as tests might want to test filling the buffer.
MaxPendingEventUpdates int
// if true, publishing messages will block until the consumer has consumed it.
// Assumes a single producer and a single consumer.
TestingSynchronousPubsub bool
@@ -74,6 +78,9 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
if opts.TestingSynchronousPubsub {
bufferSize = 0
}
if opts.MaxPendingEventUpdates == 0 {
opts.MaxPendingEventUpdates = 2000
}
pubSub := pubsub.NewPubSub(bufferSize)

// create v2 handler
Expand All @@ -83,7 +90,7 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
}

// create v3 handler
h3, err := handler.NewSync3Handler(store, storev2, v2Client, postgresURI, secret, opts.Debug, pubSub, pubSub, opts.AddPrometheusMetrics)
h3, err := handler.NewSync3Handler(store, storev2, v2Client, postgresURI, secret, opts.Debug, pubSub, pubSub, opts.AddPrometheusMetrics, opts.MaxPendingEventUpdates)
if err != nil {
panic(err)
}
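For code embedding the proxy, the new field is set at startup via Opts; leaving it at zero falls back to the 2000 default added above. A hypothetical caller sketch (placeholder homeserver, database URI and secret; other Opts fields omitted; the import alias follows the one used in the integration tests):

package main

import (
    slidingsync "github.com/matrix-org/sliding-sync"
)

func main() {
    // Hypothetical values - swap in real connection details.
    h2, h3 := slidingsync.Setup(
        "https://matrix.example.com",            // destHomeserver
        "postgres://user:pass@localhost/syncv3", // postgresURI
        "some-secret",                           // secret
        slidingsync.Opts{
            MaxPendingEventUpdates: 500, // 0 falls back to the default of 2000
        },
    )
    _, _ = h2, h3 // the binary would normally mount these handlers on an HTTP server
}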
