
Merge pull request #410 from matrix-org/kegan/race-tests
Fix race conditions in tests
kegsay authored Mar 11, 2024
2 parents cfff8bc + 05a82a4 commit 8750e1c
Showing 6 changed files with 56 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -59,7 +59,7 @@ jobs:
       - name: Test
         run: |
           set -euo pipefail
-          go test -count=1 -covermode=atomic -coverpkg ./... -p 1 -v -json $(go list ./... | grep -v tests-e2e) -coverprofile synccoverage.out 2>&1 | tee ./test-integration.log | gotestfmt -hide all
+          go test -count=1 -race -covermode=atomic -coverpkg ./... -p 1 -v -json $(go list ./... | grep -v tests-e2e) -coverprofile synccoverage.out 2>&1 | tee ./test-integration.log | gotestfmt -hide all
         shell: bash
         env:
           POSTGRES_HOST: localhost
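As context for the CI change: the -race flag builds the test binary with Go's race detector, which fails the run whenever two goroutines access the same memory without synchronization and at least one access is a write. A minimal sketch (hypothetical test, not from this repo) of the kind of bug it surfaces:

package example

import (
	"sync"
	"testing"
)

// TestRacyCounter increments a shared int from two goroutines with no
// synchronization; under `go test -race` the detector reports the
// concurrent read-modify-write and fails the test.
func TestRacyCounter(t *testing.T) {
	n := 0
	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			n++ // unsynchronized write: flagged by -race
		}()
	}
	wg.Wait()
	_ = n
}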
4 changes: 4 additions & 0 deletions pubsub/pubsub.go
@@ -84,7 +84,11 @@ func (ps *PubSub) Notify(chanName string, p Payload) error {
 		return fmt.Errorf("notify with payload %v timed out", p.Type())
 	}
 	if ps.bufferSize == 0 {
+		// for some reason go test -race flags this as racing with calls
+		// to close(ch), despite the fact that it _should_ be thread-safe :S
+		ps.mu.Lock()
 		ch <- &emptyPayload{}
+		ps.mu.Unlock()
 	}
 	return nil
 }
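Despite the comment's skepticism, the detector's report here is legitimate: Go gives no ordering guarantee between an unsynchronized send and a close on the same channel, and a send that loses that race panics with "send on closed channel". A minimal sketch (hypothetical names, simplified from the diff above) of serializing both operations under one mutex:

package example

import "sync"

type pubsub struct {
	mu sync.Mutex
	ch chan struct{}
}

// notify sends under the same lock that guards close, so the two can never
// race; the default case keeps us from blocking while holding the lock.
func (p *pubsub) notify() {
	p.mu.Lock()
	defer p.mu.Unlock()
	select {
	case p.ch <- struct{}{}:
	default:
	}
}

// closeAll closes the channel and nils it out; a send on a nil channel is
// never ready inside a select, so later notify calls take the default case.
func (p *pubsub) closeAll() {
	p.mu.Lock()
	defer p.mu.Unlock()
	close(p.ch)
	p.ch = nil
}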
12 changes: 7 additions & 5 deletions state/accumulator_test.go
@@ -4,12 +4,14 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	"github.com/matrix-org/sliding-sync/testutils"
 	"reflect"
 	"sort"
 	"sync"
+	"sync/atomic"
 	"testing"
 
+	"github.com/matrix-org/sliding-sync/testutils"
+
 	"github.com/jmoiron/sqlx"
 	"github.com/matrix-org/sliding-sync/sqlutil"
 	"github.com/matrix-org/sliding-sync/sync2"
@@ -680,7 +682,7 @@ func TestAccumulatorConcurrency(t *testing.T) {
 		[]byte(`{"event_id":"con_4", "type":"m.room.name", "state_key":"", "content":{"name":"4"}}`),
 		[]byte(`{"event_id":"con_5", "type":"m.room.name", "state_key":"", "content":{"name":"5"}}`),
 	}
-	totalNumNew := 0
+	var totalNumNew atomic.Int64
 	var wg sync.WaitGroup
 	wg.Add(len(newEvents))
 	for i := 0; i < len(newEvents); i++ {
@@ -689,7 +691,7 @@ func TestAccumulatorConcurrency(t *testing.T) {
 			subset := newEvents[:(i + 1)] // i=0 => [1], i=1 => [1,2], etc
 			err := sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
 				result, err := accumulator.Accumulate(txn, userID, roomID, sync2.TimelineResponse{Events: subset})
-				totalNumNew += result.NumNew
+				totalNumNew.Add(int64(result.NumNew))
 				return err
 			})
 			if err != nil {
@@ -698,8 +700,8 @@ func TestAccumulatorConcurrency(t *testing.T) {
 		}(i)
 	}
 	wg.Wait() // wait for all goroutines to finish
-	if totalNumNew != len(newEvents) {
-		t.Errorf("got %d total new events, want %d", totalNumNew, len(newEvents))
+	if int(totalNumNew.Load()) != len(newEvents) {
+		t.Errorf("got %d total new events, want %d", totalNumNew.Load(), len(newEvents))
 	}
 	// check that the name of the room is "5"
 	snapshot := currentSnapshotNIDs(t, accumulator.snapshotTable, roomID)
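The fix above replaces a plain int with sync/atomic's atomic.Int64 (available since Go 1.19), so the goroutines' increments and the final read are all synchronized. The same counting pattern in isolation, as a sketch:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	var total atomic.Int64 // zero value is 0 and ready to use
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			total.Add(1) // atomic read-modify-write
		}()
	}
	wg.Wait()
	fmt.Println(total.Load()) // always prints 5
}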
47 changes: 32 additions & 15 deletions sync2/poller_test.go
@@ -25,12 +25,23 @@ const initialSinceToken = "0"
 var (
 	timeSinceMu    sync.Mutex
 	timeSinceValue = time.Duration(0) // 0 means use the real impl
+	timeSleepMu    sync.Mutex
+	timeSleepValue = time.Duration(0)  // 0 means use the real impl
+	timeSleepCheck func(time.Duration) // called to check sleep values
 )
 
 func setTimeSinceValue(val time.Duration) {
 	timeSinceMu.Lock()
+	defer timeSinceMu.Unlock()
 	timeSinceValue = val
-	timeSinceMu.Unlock()
 }
+func setTimeSleepDelay(val time.Duration, fn ...func(d time.Duration)) {
+	timeSleepMu.Lock()
+	defer timeSleepMu.Unlock()
+	timeSleepValue = val
+	if len(fn) > 0 {
+		timeSleepCheck = fn[0]
+	}
+}
 func init() {
 	timeSince = func(t time.Time) time.Duration {
@@ -41,6 +52,18 @@ func init() {
 		}
 		return timeSinceValue
 	}
+	timeSleep = func(d time.Duration) {
+		timeSleepMu.Lock()
+		defer timeSleepMu.Unlock()
+		if timeSleepCheck != nil {
+			timeSleepCheck(d)
+		}
+		if timeSleepValue == 0 {
+			time.Sleep(d)
+			return
+		}
+		time.Sleep(timeSleepValue)
+	}
 }
 
 // Tests that EnsurePolling works in the happy case
@@ -583,12 +606,10 @@ func TestPollerGivesUpEventually(t *testing.T) {
 	accumulator, client := newMocks(func(authHeader, since string) (*SyncResponse, int, error) {
 		return nil, 524, fmt.Errorf("gateway timeout")
 	})
-	timeSleep = func(d time.Duration) {
-		// actually sleep to make sure async actions can happen if any
-		time.Sleep(1 * time.Microsecond)
-	}
+	// actually sleep to make sure async actions can happen if any
+	setTimeSleepDelay(time.Microsecond)
 	defer func() { // reset the value after the test runs
-		timeSleep = time.Sleep
+		setTimeSleepDelay(0)
 	}()
 	var wg sync.WaitGroup
 	wg.Add(1)
@@ -654,15 +675,13 @@ func TestPollerBackoff(t *testing.T) {
 		wantBackoffDuration = errorResponses[i].backoff
 		return nil, errorResponses[i].code, errorResponses[i].err
 	})
-	timeSleep = func(d time.Duration) {
+	setTimeSleepDelay(time.Millisecond, func(d time.Duration) {
 		if d != wantBackoffDuration {
 			t.Errorf("time.Sleep called incorrectly: got %v want %v", d, wantBackoffDuration)
 		}
-		// actually sleep to make sure async actions can happen if any
-		time.Sleep(1 * time.Millisecond)
-	}
+	})
 	defer func() { // reset the value after the test runs
-		timeSleep = time.Sleep
+		setTimeSleepDelay(0)
 	}()
 	var wg sync.WaitGroup
 	wg.Add(1)
@@ -727,12 +746,10 @@ func TestPollerResendsOnCallbackError(t *testing.T) {
 	pid := PollerID{UserID: "@TestPollerResendsOnCallbackError:localhost", DeviceID: "FOOBAR"}
 
 	defer func() { // reset the value after the test runs
-		timeSleep = time.Sleep
+		setTimeSleepDelay(0)
 	}()
 	// we don't actually want to wait 3s between retries, so monkey patch it out
-	timeSleep = func(d time.Duration) {
-		time.Sleep(time.Millisecond)
-	}
+	setTimeSleepDelay(time.Millisecond)
 
 	testCases := []struct {
 		name string
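The common thread in these poller tests: reassigning the package-level timeSleep function variable while a poller goroutine may be calling it is itself a data race, so all mutation now goes through a mutex-guarded setter. A minimal sketch of the general seam pattern behind this (hypothetical names, not the repo's exact code):

package example

import (
	"sync"
	"time"
)

var (
	sleepMu   sync.Mutex
	sleepImpl = time.Sleep // production default
)

// sleep is what production code calls: it snapshots the implementation
// under the lock, then invokes it outside the lock so concurrent sleepers
// do not serialize each other.
func sleep(d time.Duration) {
	sleepMu.Lock()
	impl := sleepImpl
	sleepMu.Unlock()
	impl(d)
}

// setSleepForTest swaps the implementation under the same lock; tests call
// it once at the start and again with time.Sleep to reset.
func setSleepForTest(impl func(time.Duration)) {
	sleepMu.Lock()
	defer sleepMu.Unlock()
	sleepImpl = impl
}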
9 changes: 5 additions & 4 deletions sync3/connmap_test.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"reflect"
 	"sort"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -201,15 +202,15 @@ func assertDestroyedConns(t *testing.T, cidToConn map[ConnID]*Conn, isDestroyedF
 	t.Helper()
 	for cid, conn := range cidToConn {
 		if isDestroyedFn(cid) {
-			mustEqual(t, conn.handler.(*mockConnHandler).isDestroyed, true, fmt.Sprintf("conn %+v was not destroyed", cid))
+			mustEqual(t, conn.handler.(*mockConnHandler).isDestroyed.Load(), true, fmt.Sprintf("conn %+v was not destroyed", cid))
 		} else {
-			mustEqual(t, conn.handler.(*mockConnHandler).isDestroyed, false, fmt.Sprintf("conn %+v was destroyed", cid))
+			mustEqual(t, conn.handler.(*mockConnHandler).isDestroyed.Load(), false, fmt.Sprintf("conn %+v was destroyed", cid))
 		}
 	}
 }
 
 type mockConnHandler struct {
-	isDestroyed bool
+	isDestroyed atomic.Bool
 	cancel      context.CancelFunc
 }
 
@@ -219,7 +220,7 @@ func (c *mockConnHandler) OnIncomingRequest(ctx context.Context, cid ConnID, req
 func (c *mockConnHandler) OnUpdate(ctx context.Context, update caches.Update) {}
 func (c *mockConnHandler) PublishEventsUpTo(roomID string, nid int64) {}
 func (c *mockConnHandler) Destroy() {
-	c.isDestroyed = true
+	c.isDestroyed.Store(true)
 }
 func (c *mockConnHandler) Alive() bool {
 	return true // buffer never fills up
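atomic.Bool (also Go 1.19+) gives the same guarantee for flags: its zero value is false and usable immediately, and Store/Load pair as a synchronized write and read. The pattern in isolation, as a sketch:

package example

import "sync/atomic"

type handler struct {
	destroyed atomic.Bool // zero value: false
}

// Destroy may be called from any goroutine.
func (h *handler) Destroy() { h.destroyed.Store(true) }

// IsDestroyed can be checked concurrently without a lock.
func (h *handler) IsDestroyed() bool { return h.destroyed.Load() }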
13 changes: 7 additions & 6 deletions tests-integration/poller_test.go
@@ -4,14 +4,15 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	"github.com/jmoiron/sqlx"
-	"github.com/matrix-org/sliding-sync/sqlutil"
 	"net/http"
 	"os"
+	"sync/atomic"
 	"testing"
 	"time"
 
+	"github.com/jmoiron/sqlx"
+	"github.com/matrix-org/sliding-sync/sqlutil"
+
 	"github.com/matrix-org/sliding-sync/sync2"
 	"github.com/matrix-org/sliding-sync/sync3"
 	"github.com/matrix-org/sliding-sync/sync3/extensions"
@@ -45,7 +46,7 @@ func TestSecondPollerFiltersToDevice(t *testing.T) {
 	// now sync with device B, and check we send the filter up
 	deviceBToken := "DEVICE_B_TOKEN"
 	v2.addAccountWithDeviceID(alice, "B", deviceBToken)
-	seenInitialRequest := false
+	var seenInitialRequest atomic.Bool
 	v2.SetCheckRequest(func(token string, req *http.Request) {
 		if token != deviceBToken {
 			return
@@ -62,7 +63,7 @@ func TestSecondPollerFiltersToDevice(t *testing.T) {
 		timelineLimit := filterJSON.Get("room.timeline.limit").Int()
 		roomsFilter := filterJSON.Get("room.rooms")
 
-		if !seenInitialRequest {
+		if !seenInitialRequest.Load() {
 			// First poll: should be an initial sync, limit 1, excluding all room timelines.
 			if since != "" {
 				t.Errorf("Expected no since token on first poll, but got %v", since)
@@ -89,7 +90,7 @@ func TestSecondPollerFiltersToDevice(t *testing.T) {
 			}
 		}
 
-		seenInitialRequest = true
+		seenInitialRequest.Store(true)
 	})
 
 	wantMsg := json.RawMessage(`{"type":"f","content":{"f":"b"}}`)
@@ -110,7 +111,7 @@ func TestSecondPollerFiltersToDevice(t *testing.T) {
 		},
 	})
 
-	if !seenInitialRequest {
+	if !seenInitialRequest.Load() {
 		t.Fatalf("did not see initial request for 2nd device")
 	}
 	// the first request will not wait for the response before returning due to device A. Poll again
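This is the same atomic.Bool pattern as in sync3/connmap_test.go: the SetCheckRequest callback runs on a poller goroutine while the test goroutine reads the flag, so seenInitialRequest must be an atomic.Bool rather than a plain bool.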
