diff --git a/sync3/extensions/todevice.go b/sync3/extensions/todevice.go index 951177b3..430c3da5 100644 --- a/sync3/extensions/todevice.go +++ b/sync3/extensions/todevice.go @@ -65,6 +65,30 @@ func (r *ToDeviceRequest) ProcessInitial(ctx context.Context, res *Response, ext r.Limit = 100 // default to 100 } l := logger.With().Str("user", extCtx.UserID).Str("device", extCtx.DeviceID).Logger() + + mapMu.Lock() + lastSentPos, exists := deviceIDToSinceDebugOnly[extCtx.DeviceID] + internal.Logf(ctx, "to_device", "since=%v limit=%v last_sent=%v", r.Since, r.Limit, lastSentPos) + isFirstRequest := !exists + mapMu.Unlock() + + // If this is the first time we've seen this device ID since starting up, ignore the client-provided 'since' + // value. This is done to protect against dropped postgres sequences. Consider: + // - 5 to-device messages arrive for Alice + // - Alice requests all messages, gets them and acks them so since=5, and the nextval() sequence is 6. + // - the server admin drops the DB and starts over again. The DB sequence starts back at 1. + // - 2 to-device messages arrive for Alice + // - Alice requests messages from since=5. No messages are returned as the 2 new messages have a lower sequence number. + // - Even worse, those 2 messages are deleted because sending since=5 ACKNOWLEDGES all messages <=5. + // By ignoring the first `since` on startup, we effectively force the client into sending since=0. In this scenario, + // it will then A) not delete anything as since=0 acknowledges nothing, B) return the 2 to-device events. + // + // The cost to this is that it is possible to send duplicate to-device events if the server restarts before the client + // has time to send the ACK to the server. This isn't fatal as clients do suppress duplicate to-device events. + if isFirstRequest { + r.Since = "" + } + var from int64 var err error if r.Since != "" { @@ -82,10 +106,7 @@ func (r *ToDeviceRequest) ProcessInitial(ctx context.Context, res *Response, ext internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err) } } - mapMu.Lock() - lastSentPos := deviceIDToSinceDebugOnly[extCtx.DeviceID] - internal.Logf(ctx, "to_device", "since=%v limit=%v last_sent=%v", r.Since, r.Limit, lastSentPos) - mapMu.Unlock() + if from < lastSentPos { // we told the client about a newer position, but yet they are using an older position, yell loudly l.Warn().Int64("last_sent", lastSentPos).Int64("recv", from).Bool("initial", extCtx.IsInitial).Msg( diff --git a/tests-integration/extensions_test.go b/tests-integration/extensions_test.go index ed25a074..26a5480f 100644 --- a/tests-integration/extensions_test.go +++ b/tests-integration/extensions_test.go @@ -2,6 +2,8 @@ package syncv3 import ( "encoding/json" + "fmt" + "strconv" "testing" "time" @@ -443,6 +445,57 @@ func TestExtensionToDevice(t *testing.T) { m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(0)), m.MatchToDeviceMessages(newToDeviceMsgs)) } +// Test that if you sync with a very very high numbered since value, we return lower numbered entries. +// This guards against dropped databases. +func TestExtensionToDeviceSequence(t *testing.T) { + pqString := testutils.PrepareDBConnectionString() + // setup code + v2 := runTestV2Server(t) + v3 := runTestServer(t, v2, pqString) + defer v2.close() + defer v3.close() + alice := "@TestExtensionToDeviceSequence_alice:localhost" + aliceToken := "ALICE_BEARER_TOKEN_TestExtensionToDeviceSequence" + v2.addAccount(t, alice, aliceToken) + toDeviceMsgs := []json.RawMessage{ + json.RawMessage(`{"sender":"alice","type":"something","content":{"foo":"1"}}`), + json.RawMessage(`{"sender":"alice","type":"something","content":{"foo":"2"}}`), + json.RawMessage(`{"sender":"alice","type":"something","content":{"foo":"3"}}`), + json.RawMessage(`{"sender":"alice","type":"something","content":{"foo":"4"}}`), + } + v2.queueResponse(alice, sync2.SyncResponse{ + ToDevice: sync2.EventsResponse{ + Events: toDeviceMsgs, + }, + }) + + hiSince := 999999 + res := v3.mustDoV3Request(t, aliceToken, sync3.Request{ + Lists: map[string]sync3.RequestList{"a": { + Ranges: sync3.SliceRanges{ + [2]int64{0, 10}, // doesn't matter + }, + }}, + Extensions: extensions.Request{ + ToDevice: &extensions.ToDeviceRequest{ + Core: extensions.Core{Enabled: &boolTrue}, + Since: fmt.Sprintf("%d", hiSince), + }, + }, + }) + m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(0)), m.MatchToDeviceMessages(toDeviceMsgs), func(res *sync3.Response) error { + // ensure that we return a lower numbered since token + got, err := strconv.Atoi(res.Extensions.ToDevice.NextBatch) + if err != nil { + return err + } + if got >= hiSince { + return fmt.Errorf("next_batch got %v wanted lower than %v", got, hiSince) + } + return nil + }) +} + // tests that the account data extension works: // 1- check global account data is sent on first connection // 2- check global account data updates are proxied through