diff --git a/discovery/api/server/api.go b/discovery/api/server/api.go index dd519d443..fe8b3d2b7 100644 --- a/discovery/api/server/api.go +++ b/discovery/api/server/api.go @@ -71,11 +71,12 @@ func (w *Wrapper) GetPresentations(ctx context.Context, request GetPresentations timestamp = *request.Params.Timestamp } - presentations, newTimestamp, err := w.Server.Get(contextWithForwardedHost(ctx), request.ServiceID, timestamp) + presentations, seed, newTimestamp, err := w.Server.Get(contextWithForwardedHost(ctx), request.ServiceID, timestamp) if err != nil { return nil, err } return GetPresentations200JSONResponse{ + Seed: seed, Entries: presentations, Timestamp: newTimestamp, }, nil diff --git a/discovery/api/server/api_test.go b/discovery/api/server/api_test.go index 9092aef46..e04bd02f2 100644 --- a/discovery/api/server/api_test.go +++ b/discovery/api/server/api_test.go @@ -35,10 +35,11 @@ const serviceID = "wonderland" func TestWrapper_GetPresentations(t *testing.T) { lastTimestamp := 1 presentations := map[string]vc.VerifiablePresentation{} + seed := "seed" ctx := context.Background() t.Run("no timestamp", func(t *testing.T) { test := newMockContext(t) - test.server.EXPECT().Get(gomock.Any(), serviceID, 0).Return(presentations, lastTimestamp, nil) + test.server.EXPECT().Get(gomock.Any(), serviceID, 0).Return(presentations, seed, lastTimestamp, nil) response, err := test.wrapper.GetPresentations(ctx, GetPresentationsRequestObject{ServiceID: serviceID}) @@ -46,11 +47,12 @@ func TestWrapper_GetPresentations(t *testing.T) { require.IsType(t, GetPresentations200JSONResponse{}, response) assert.Equal(t, lastTimestamp, response.(GetPresentations200JSONResponse).Timestamp) assert.Equal(t, presentations, response.(GetPresentations200JSONResponse).Entries) + assert.Equal(t, seed, response.(GetPresentations200JSONResponse).Seed) }) t.Run("with timestamp", func(t *testing.T) { givenTimestamp := 1 test := newMockContext(t) - test.server.EXPECT().Get(gomock.Any(), serviceID, 1).Return(presentations, lastTimestamp, nil) + test.server.EXPECT().Get(gomock.Any(), serviceID, 1).Return(presentations, seed, lastTimestamp, nil) response, err := test.wrapper.GetPresentations(ctx, GetPresentationsRequestObject{ ServiceID: serviceID, @@ -66,7 +68,7 @@ func TestWrapper_GetPresentations(t *testing.T) { }) t.Run("error", func(t *testing.T) { test := newMockContext(t) - test.server.EXPECT().Get(gomock.Any(), serviceID, 0).Return(nil, 0, errors.New("foo")) + test.server.EXPECT().Get(gomock.Any(), serviceID, 0).Return(nil, "", 0, errors.New("foo")) _, err := test.wrapper.GetPresentations(ctx, GetPresentationsRequestObject{ServiceID: serviceID}) diff --git a/discovery/api/server/client/http.go b/discovery/api/server/client/http.go index 84695dde1..6623b642a 100644 --- a/discovery/api/server/client/http.go +++ b/discovery/api/server/client/http.go @@ -71,31 +71,31 @@ func (h DefaultHTTPClient) Register(ctx context.Context, serviceEndpointURL stri return nil } -func (h DefaultHTTPClient) Get(ctx context.Context, serviceEndpointURL string, timestamp int) (map[string]vc.VerifiablePresentation, int, error) { +func (h DefaultHTTPClient) Get(ctx context.Context, serviceEndpointURL string, timestamp int) (map[string]vc.VerifiablePresentation, string, int, error) { httpRequest, err := http.NewRequestWithContext(ctx, http.MethodGet, serviceEndpointURL, nil) httpRequest.URL.RawQuery = url.Values{"timestamp": []string{fmt.Sprintf("%d", timestamp)}}.Encode() if err != nil { - return nil, 0, err + return nil, "", 0, err } httpRequest.Header.Set("X-Forwarded-Host", httpRequest.Host) // prevent cycles httpResponse, err := h.client.Do(httpRequest) if err != nil { - return nil, 0, fmt.Errorf("failed to invoke remote Discovery Service (url=%s): %w", serviceEndpointURL, err) + return nil, "", 0, fmt.Errorf("failed to invoke remote Discovery Service (url=%s): %w", serviceEndpointURL, err) } defer httpResponse.Body.Close() if err := core.TestResponseCode(200, httpResponse); err != nil { httpErr := err.(core.HttpError) // TestResponseCodeWithLog always returns an HttpError - return nil, 0, fmt.Errorf("non-OK response from remote Discovery Service (url=%s): %s", serviceEndpointURL, problemResponseToError(httpErr)) + return nil, "", 0, fmt.Errorf("non-OK response from remote Discovery Service (url=%s): %s", serviceEndpointURL, problemResponseToError(httpErr)) } responseData, err := io.ReadAll(httpResponse.Body) if err != nil { - return nil, 0, fmt.Errorf("failed to read response from remote Discovery Service (url=%s): %w", serviceEndpointURL, err) + return nil, "", 0, fmt.Errorf("failed to read response from remote Discovery Service (url=%s): %w", serviceEndpointURL, err) } var result PresentationsResponse if err := json.Unmarshal(responseData, &result); err != nil { - return nil, 0, fmt.Errorf("failed to unmarshal response from remote Discovery Service (url=%s): %w", serviceEndpointURL, err) + return nil, "", 0, fmt.Errorf("failed to unmarshal response from remote Discovery Service (url=%s): %w", serviceEndpointURL, err) } - return result.Entries, result.Timestamp, nil + return result.Entries, result.Seed, result.Timestamp, nil } // problemResponseToError converts a Problem Details response to an error. diff --git a/discovery/api/server/client/http_test.go b/discovery/api/server/client/http_test.go index c83262180..bfe0717ce 100644 --- a/discovery/api/server/client/http_test.go +++ b/discovery/api/server/client/http_test.go @@ -79,29 +79,32 @@ func TestHTTPInvoker_Get(t *testing.T) { t.Run("no timestamp from client", func(t *testing.T) { handler := &testHTTP.Handler{StatusCode: http.StatusOK} handler.ResponseData = map[string]interface{}{ + "seed": "seed", "entries": map[string]interface{}{"1": vp}, "timestamp": 1, } server := httptest.NewServer(handler) client := New(false, time.Minute, server.TLS) - presentations, timestamp, err := client.Get(context.Background(), server.URL, 0) + presentations, seed, timestamp, err := client.Get(context.Background(), server.URL, 0) assert.NoError(t, err) assert.Len(t, presentations, 1) assert.Equal(t, "0", handler.RequestQuery.Get("timestamp")) assert.Equal(t, 1, timestamp) + assert.Equal(t, "seed", seed) }) t.Run("timestamp provided by client", func(t *testing.T) { handler := &testHTTP.Handler{StatusCode: http.StatusOK} handler.ResponseData = map[string]interface{}{ + "seed": "seed", "entries": map[string]interface{}{"1": vp}, "timestamp": 1, } server := httptest.NewServer(handler) client := New(false, time.Minute, server.TLS) - presentations, timestamp, err := client.Get(context.Background(), server.URL, 1) + presentations, _, timestamp, err := client.Get(context.Background(), server.URL, 1) assert.NoError(t, err) assert.Len(t, presentations, 1) @@ -119,7 +122,7 @@ func TestHTTPInvoker_Get(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(handler)) client := New(false, time.Minute, server.TLS) - _, _, err := client.Get(context.Background(), server.URL, 0) + _, _, _, err := client.Get(context.Background(), server.URL, 0) require.NoError(t, err) assert.True(t, strings.HasPrefix(capturedRequest.Header.Get("X-Forwarded-Host"), "127.0.0.1")) @@ -129,7 +132,7 @@ func TestHTTPInvoker_Get(t *testing.T) { server := httptest.NewServer(handler) client := New(false, time.Minute, server.TLS) - _, _, err := client.Get(context.Background(), server.URL, 0) + _, _, _, err := client.Get(context.Background(), server.URL, 0) assert.ErrorContains(t, err, "non-OK response from remote Discovery Service") assert.ErrorContains(t, err, "server returned HTTP status code 500") @@ -141,7 +144,7 @@ func TestHTTPInvoker_Get(t *testing.T) { server := httptest.NewServer(handler) client := New(false, time.Minute, server.TLS) - _, _, err := client.Get(context.Background(), server.URL, 0) + _, _, _, err := client.Get(context.Background(), server.URL, 0) assert.ErrorContains(t, err, "failed to unmarshal response from remote Discovery Service") }) diff --git a/discovery/api/server/client/interface.go b/discovery/api/server/client/interface.go index 24087f718..10722ea8c 100644 --- a/discovery/api/server/client/interface.go +++ b/discovery/api/server/client/interface.go @@ -31,5 +31,5 @@ type HTTPClient interface { // Get retrieves Verifiable Presentations from the remote Discovery Service, that were added since the given timestamp. // If the call succeeds it returns the Verifiable Presentations and the timestamp that was returned by the server. // If the given timestamp is 0, all Verifiable Presentations are retrieved. - Get(ctx context.Context, serviceEndpointURL string, timestamp int) (map[string]vc.VerifiablePresentation, int, error) + Get(ctx context.Context, serviceEndpointURL string, timestamp int) (map[string]vc.VerifiablePresentation, string, int, error) } diff --git a/discovery/api/server/client/mock.go b/discovery/api/server/client/mock.go index 2fe595a28..895718a0d 100644 --- a/discovery/api/server/client/mock.go +++ b/discovery/api/server/client/mock.go @@ -41,13 +41,14 @@ func (m *MockHTTPClient) EXPECT() *MockHTTPClientMockRecorder { } // Get mocks base method. -func (m *MockHTTPClient) Get(ctx context.Context, serviceEndpointURL string, timestamp int) (map[string]vc.VerifiablePresentation, int, error) { +func (m *MockHTTPClient) Get(ctx context.Context, serviceEndpointURL string, timestamp int) (map[string]vc.VerifiablePresentation, string, int, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Get", ctx, serviceEndpointURL, timestamp) ret0, _ := ret[0].(map[string]vc.VerifiablePresentation) - ret1, _ := ret[1].(int) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret1, _ := ret[1].(string) + ret2, _ := ret[2].(int) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 } // Get indicates an expected call of Get. diff --git a/discovery/api/server/client/types.go b/discovery/api/server/client/types.go index 89c17609e..184a4a0a8 100644 --- a/discovery/api/server/client/types.go +++ b/discovery/api/server/client/types.go @@ -24,6 +24,8 @@ import "github.com/nuts-foundation/go-did/vc" type PresentationsResponse struct { // Entries contains mappings from timestamp (as string) to a VerifiablePresentation. Entries map[string]vc.VerifiablePresentation `json:"entries"` + // Seed is a unique value for the combination of serviceID and a server instance. + Seed string `json:"seed"` // Timestamp is the timestamp of the latest entry. It's not a unix timestamp but a Lamport Clock. Timestamp int `json:"timestamp"` } diff --git a/discovery/client.go b/discovery/client.go index 11a4665ce..e22d4b351 100644 --- a/discovery/client.go +++ b/discovery/client.go @@ -367,16 +367,21 @@ func (u *clientUpdater) updateService(ctx context.Context, service ServiceDefini log.Logger(). WithField("discoveryService", service.ID). Tracef("Checking for new Verifiable Presentations from Discovery Service (timestamp: %d)", currentTimestamp) - presentations, serverTimestamp, err := u.client.Get(ctx, service.Endpoint, currentTimestamp) + presentations, seed, serverTimestamp, err := u.client.Get(ctx, service.Endpoint, currentTimestamp) if err != nil { return fmt.Errorf("failed to get presentations from discovery service (id=%s): %w", service.ID, err) } + // check testSeed in store, wipe if it's different. Done by the store for transaction safety. + err = u.store.wipeOnSeedChange(service.ID, seed) + if err != nil { + return fmt.Errorf("failed to wipe on testSeed change (service=%s, testSeed=%s): %w", service.ID, seed, err) + } for _, presentation := range presentations { if err := u.verifier(service, presentation); err != nil { log.Logger().WithError(err).Warnf("Presentation verification failed, not adding it (service=%s, id=%s)", service.ID, presentation.ID) continue } - if err := u.store.add(service.ID, presentation, serverTimestamp); err != nil { + if err := u.store.add(service.ID, presentation, seed, serverTimestamp); err != nil { return fmt.Errorf("failed to store presentation (service=%s, id=%s): %w", service.ID, presentation.ID, err) } log.Logger(). diff --git a/discovery/client_test.go b/discovery/client_test.go index 4098df696..0ab14a8c4 100644 --- a/discovery/client_test.go +++ b/discovery/client_test.go @@ -221,7 +221,7 @@ func Test_defaultClientRegistrationManager_deactivate(t *testing.T) { ctx.invoker.EXPECT().Register(gomock.Any(), gomock.Any(), gomock.Any()) ctx.wallet.EXPECT().BuildPresentation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), false).Return(&vpAlice, nil) ctx.subjectManager.EXPECT().ListDIDs(gomock.Any(), aliceSubject).Return([]did.DID{aliceDID}, nil) - require.NoError(t, ctx.store.add(testServiceID, vpAlice, 1)) + require.NoError(t, ctx.store.add(testServiceID, vpAlice, testSeed, 1)) err := ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject) @@ -236,7 +236,7 @@ func Test_defaultClientRegistrationManager_deactivate(t *testing.T) { claims["retract_jti"] = vpAlice.ID.String() vp.Type = append(vp.Type, retractionPresentationType) }, vcAlice) - require.NoError(t, ctx.store.add(testServiceID, vpAliceDeactivated, 1)) + require.NoError(t, ctx.store.add(testServiceID, vpAliceDeactivated, testSeed, 1)) err := ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject) @@ -255,7 +255,7 @@ func Test_defaultClientRegistrationManager_deactivate(t *testing.T) { ctx.invoker.EXPECT().Register(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("remote error")) ctx.wallet.EXPECT().BuildPresentation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), false).Return(&vpAlice, nil) ctx.subjectManager.EXPECT().ListDIDs(gomock.Any(), aliceSubject).Return([]did.DID{aliceDID}, nil) - require.NoError(t, ctx.store.add(testServiceID, vpAlice, 1)) + require.NoError(t, ctx.store.add(testServiceID, vpAlice, testSeed, 1)) err := ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject) @@ -266,7 +266,7 @@ func Test_defaultClientRegistrationManager_deactivate(t *testing.T) { ctx := newTestContext(t) ctx.wallet.EXPECT().BuildPresentation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), false).Return(nil, assert.AnError) ctx.subjectManager.EXPECT().ListDIDs(gomock.Any(), aliceSubject).Return([]did.DID{aliceDID}, nil) - require.NoError(t, ctx.store.add(testServiceID, vpAlice, 1)) + require.NoError(t, ctx.store.add(testServiceID, vpAlice, testSeed, 1)) err := ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject) @@ -394,7 +394,7 @@ func Test_clientUpdater_updateService(t *testing.T) { httpClient := client.NewMockHTTPClient(ctrl) updater := newClientUpdater(testDefinitions(), store, alwaysOkVerifier, httpClient) - httpClient.EXPECT().Get(ctx, testDefinitions()[testServiceID].Endpoint, 0).Return(map[string]vc.VerifiablePresentation{}, 0, nil) + httpClient.EXPECT().Get(ctx, testDefinitions()[testServiceID].Endpoint, 0).Return(map[string]vc.VerifiablePresentation{}, testSeed, 0, nil) err := updater.updateService(ctx, testDefinitions()[testServiceID]) @@ -406,7 +406,7 @@ func Test_clientUpdater_updateService(t *testing.T) { httpClient := client.NewMockHTTPClient(ctrl) updater := newClientUpdater(testDefinitions(), store, alwaysOkVerifier, httpClient) - httpClient.EXPECT().Get(ctx, serviceDefinition.Endpoint, 0).Return(map[string]vc.VerifiablePresentation{"1": vpAlice}, 1, nil) + httpClient.EXPECT().Get(ctx, serviceDefinition.Endpoint, 0).Return(map[string]vc.VerifiablePresentation{"1": vpAlice}, testSeed, 1, nil) err := updater.updateService(ctx, testDefinitions()[testServiceID]) @@ -423,7 +423,7 @@ func Test_clientUpdater_updateService(t *testing.T) { return nil }, httpClient) - httpClient.EXPECT().Get(ctx, serviceDefinition.Endpoint, 0).Return(map[string]vc.VerifiablePresentation{"1": vpAlice, "2": vpBob}, 2, nil) + httpClient.EXPECT().Get(ctx, serviceDefinition.Endpoint, 0).Return(map[string]vc.VerifiablePresentation{"1": vpAlice, "2": vpBob}, testSeed, 2, nil) err := updater.updateService(ctx, testDefinitions()[testServiceID]) @@ -440,28 +440,49 @@ func Test_clientUpdater_updateService(t *testing.T) { resetStore(t, storageEngine.GetSQLDatabase()) ctrl := gomock.NewController(t) httpClient := client.NewMockHTTPClient(ctrl) - err := store.setTimestamp(store.db, testServiceID, 1) + err := store.setTimestamp(store.db, testServiceID, testSeed, 1) require.NoError(t, err) updater := newClientUpdater(testDefinitions(), store, alwaysOkVerifier, httpClient) - httpClient.EXPECT().Get(ctx, serviceDefinition.Endpoint, 1).Return(map[string]vc.VerifiablePresentation{"1": vpAlice}, 1, nil) + httpClient.EXPECT().Get(ctx, serviceDefinition.Endpoint, 1).Return(map[string]vc.VerifiablePresentation{"1": vpAlice}, testSeed, 1, nil) err = updater.updateService(ctx, testDefinitions()[testServiceID]) require.NoError(t, err) }) + t.Run("seed change wipes entries", func(t *testing.T) { + resetStore(t, storageEngine.GetSQLDatabase()) + ctrl := gomock.NewController(t) + httpClient := client.NewMockHTTPClient(ctrl) + updater := newClientUpdater(testDefinitions(), store, alwaysOkVerifier, httpClient) + store.add(testServiceID, vpAlice, testSeed, 0) + + exists, err := store.exists(testServiceID, aliceDID.String(), vpAlice.ID.String()) + require.NoError(t, err) + require.True(t, exists) + + httpClient.EXPECT().Get(ctx, testDefinitions()[testServiceID].Endpoint, 1).Return(map[string]vc.VerifiablePresentation{}, "other", 0, nil) + + err = updater.updateService(ctx, testDefinitions()[testServiceID]) + + require.NoError(t, err) + exists, err = store.exists(testServiceID, aliceDID.String(), vpAlice.ID.String()) + require.NoError(t, err) + require.False(t, exists) + }) } func Test_clientUpdater_update(t *testing.T) { + seed := "seed" t.Run("proceeds when service update fails", func(t *testing.T) { storageEngine := storage.NewTestStorageEngine(t) require.NoError(t, storageEngine.Start()) store := setupStore(t, storageEngine.GetSQLDatabase()) ctrl := gomock.NewController(t) httpClient := client.NewMockHTTPClient(ctrl) - httpClient.EXPECT().Get(gomock.Any(), "http://example.com/usecase", gomock.Any()).Return(map[string]vc.VerifiablePresentation{}, 0, nil) - httpClient.EXPECT().Get(gomock.Any(), "http://example.com/other", gomock.Any()).Return(nil, 0, errors.New("test")) - httpClient.EXPECT().Get(gomock.Any(), "http://example.com/unsupported", gomock.Any()).Return(map[string]vc.VerifiablePresentation{}, 0, nil) + httpClient.EXPECT().Get(gomock.Any(), "http://example.com/usecase", gomock.Any()).Return(map[string]vc.VerifiablePresentation{}, seed, 0, nil) + httpClient.EXPECT().Get(gomock.Any(), "http://example.com/other", gomock.Any()).Return(nil, "", 0, errors.New("test")) + httpClient.EXPECT().Get(gomock.Any(), "http://example.com/unsupported", gomock.Any()).Return(map[string]vc.VerifiablePresentation{}, seed, 0, nil) updater := newClientUpdater(testDefinitions(), store, alwaysOkVerifier, httpClient) err := updater.update(context.Background()) @@ -474,7 +495,7 @@ func Test_clientUpdater_update(t *testing.T) { store := setupStore(t, storageEngine.GetSQLDatabase()) ctrl := gomock.NewController(t) httpClient := client.NewMockHTTPClient(ctrl) - httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]vc.VerifiablePresentation{}, 0, nil).MinTimes(2) + httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]vc.VerifiablePresentation{}, seed, 0, nil).MinTimes(2) updater := newClientUpdater(testDefinitions(), store, alwaysOkVerifier, httpClient) err := updater.update(context.Background()) diff --git a/discovery/interface.go b/discovery/interface.go index f2522d261..7a9cf5e92 100644 --- a/discovery/interface.go +++ b/discovery/interface.go @@ -48,7 +48,7 @@ type Server interface { Register(context context.Context, serviceID string, presentation vc.VerifiablePresentation) error // Get retrieves the presentations for the given service, starting from the given timestamp. // If the node is not configured as server for the given serviceID, the call will be forwarded to the configured server. - Get(context context.Context, serviceID string, startAfter int) (map[string]vc.VerifiablePresentation, int, error) + Get(context context.Context, serviceID string, startAfter int) (map[string]vc.VerifiablePresentation, string, int, error) } // Client defines the API for Discovery Clients. diff --git a/discovery/mock.go b/discovery/mock.go index 5466dd198..3d231317e 100644 --- a/discovery/mock.go +++ b/discovery/mock.go @@ -41,13 +41,14 @@ func (m *MockServer) EXPECT() *MockServerMockRecorder { } // Get mocks base method. -func (m *MockServer) Get(context context.Context, serviceID string, startAfter int) (map[string]vc.VerifiablePresentation, int, error) { +func (m *MockServer) Get(context context.Context, serviceID string, startAfter int) (map[string]vc.VerifiablePresentation, string, int, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Get", context, serviceID, startAfter) ret0, _ := ret[0].(map[string]vc.VerifiablePresentation) - ret1, _ := ret[1].(int) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret1, _ := ret[1].(string) + ret2, _ := ret[2].(int) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 } // Get indicates an expected call of Get. diff --git a/discovery/module.go b/discovery/module.go index 0ad5f38e5..a34f45fed 100644 --- a/discovery/module.go +++ b/discovery/module.go @@ -203,7 +203,7 @@ func (m *Module) Register(context context.Context, serviceID string, presentatio return err } - return m.store.add(serviceID, presentation, 0) + return m.store.add(serviceID, presentation, "", 0) } func (m *Module) verifyRegistration(definition ServiceDefinition, presentation vc.VerifiablePresentation) error { @@ -327,18 +327,18 @@ func (m *Module) validateRetraction(serviceID string, presentation vc.Verifiable // Get is a Discovery Server function that retrieves the presentations for the given service, starting at timestamp+1. // See interface.go for more information. -func (m *Module) Get(context context.Context, serviceID string, startAfter int) (map[string]vc.VerifiablePresentation, int, error) { +func (m *Module) Get(context context.Context, serviceID string, startAfter int) (map[string]vc.VerifiablePresentation, string, int, error) { _, exists := m.serverDefinitions[serviceID] if !exists { // forward to configured server service, exists := m.allDefinitions[serviceID] if !exists { - return nil, 0, ErrServiceNotFound + return nil, "", 0, ErrServiceNotFound } // check If X-Forwarded-Host header is set, if set it must not be the same as service.Endpoint if cycleDetected(context, service) { - return nil, 0, errCyclicForwardingDetected + return nil, "", 0, errCyclicForwardingDetected } log.Logger().Infof("Forwarding Get request to configured server (service=%s)", serviceID) diff --git a/discovery/module_test.go b/discovery/module_test.go index 8e7d317bc..23f5f436e 100644 --- a/discovery/module_test.go +++ b/discovery/module_test.go @@ -39,6 +39,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "gorm.io/gorm" + "os" "testing" "time" ) @@ -65,9 +67,16 @@ func Test_Module_Register(t *testing.T) { err := m.Register(ctx, testServiceID, vpAlice) require.NoError(t, err) - _, timestamp, err := m.Get(ctx, testServiceID, 0) + _, seed, timestamp, err := m.Get(ctx, testServiceID, 0) require.NoError(t, err) assert.Equal(t, 1, timestamp) + assert.NotEmpty(t, seed) + + t.Run("already exists", func(t *testing.T) { + err = m.Register(ctx, testServiceID, vpAlice) + + assert.ErrorIs(t, err, ErrPresentationAlreadyExists) + }) }) t.Run("not a server", func(t *testing.T) { m, _ := setupModule(t, storageEngine, func(module *Module) { @@ -76,7 +85,7 @@ func Test_Module_Register(t *testing.T) { Endpoint: "https://example.com/someother", } mockhttpclient := module.httpClient.(*client.MockHTTPClient) - mockhttpclient.EXPECT().Get(gomock.Any(), "https://example.com/someother", gomock.Any()).Return(nil, 0, nil).AnyTimes() + mockhttpclient.EXPECT().Get(gomock.Any(), "https://example.com/someother", gomock.Any()).Return(nil, testSeed, 0, nil).AnyTimes() mockhttpclient.EXPECT().Register(gomock.Any(), "https://example.com/someother", vpAlice).Return(nil) }) @@ -91,19 +100,10 @@ func Test_Module_Register(t *testing.T) { err := m.Register(ctx, testServiceID, vpAlice) require.EqualError(t, err, "presentation is invalid for registration\npresentation verification failed: failed") - _, timestamp, err := m.Get(ctx, testServiceID, 0) + _, _, timestamp, err := m.Get(ctx, testServiceID, 0) require.NoError(t, err) assert.Equal(t, 0, timestamp) }) - t.Run("already exists", func(t *testing.T) { - m, testContext := setupModule(t, storageEngine) - testContext.verifier.EXPECT().VerifyVP(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - - err := m.Register(ctx, testServiceID, vpAlice) - assert.NoError(t, err) - err = m.Register(ctx, testServiceID, vpAlice) - assert.ErrorIs(t, err, ErrPresentationAlreadyExists) - }) t.Run("valid for too long", func(t *testing.T) { m, _ := setupModule(t, storageEngine, func(module *Module) { def := module.allDefinitions[testServiceID] @@ -160,7 +160,7 @@ func Test_Module_Register(t *testing.T) { err := m.Register(ctx, testServiceID, otherVP) assert.ErrorIs(t, err, pe.ErrNoCredentials) - _, timestamp, _ := m.Get(ctx, testServiceID, 0) + _, _, timestamp, _ := m.Get(ctx, testServiceID, 0) assert.Equal(t, 0, timestamp) }) t.Run("unsupported DID method", func(t *testing.T) { @@ -184,7 +184,7 @@ func Test_Module_Register(t *testing.T) { Endpoint: "https://example.com/someother", } mockhttpclient := module.httpClient.(*client.MockHTTPClient) - mockhttpclient.EXPECT().Get(gomock.Any(), "https://example.com/someother", gomock.Any()).Return(nil, 0, nil).AnyTimes() + mockhttpclient.EXPECT().Get(gomock.Any(), "https://example.com/someother", gomock.Any()).Return(nil, testSeed, 0, nil).AnyTimes() }) ctx := context.WithValue(ctx, XForwardedHostContextKey{}, "https://example.com") @@ -200,7 +200,10 @@ func Test_Module_Register(t *testing.T) { claims[jwt.AudienceKey] = []string{testServiceID} }) t.Run("ok", func(t *testing.T) { - m, testContext := setupModule(t, storageEngine) + m, testContext := setupModule(t, storageEngine, func(module *Module) { + // disable updater + module.config.Client.RefreshInterval = 0 + }) testContext.verifier.EXPECT().VerifyVP(gomock.Any(), true, true, nil).Times(2) err := m.Register(ctx, testServiceID, vpAlice) @@ -260,18 +263,18 @@ func Test_Module_Get(t *testing.T) { ctx := context.Background() t.Run("ok", func(t *testing.T) { m, _ := setupModule(t, storageEngine) - require.NoError(t, m.store.add(testServiceID, vpAlice, 0)) - presentations, timestamp, err := m.Get(ctx, testServiceID, 0) + require.NoError(t, m.store.add(testServiceID, vpAlice, testSeed, 0)) + presentations, seed, timestamp, err := m.Get(ctx, testServiceID, 0) assert.NoError(t, err) assert.Equal(t, map[string]vc.VerifiablePresentation{"1": vpAlice}, presentations) assert.Equal(t, 1, timestamp) - }) - t.Run("ok - retrieve delta", func(t *testing.T) { - m, _ := setupModule(t, storageEngine) - require.NoError(t, m.store.add(testServiceID, vpAlice, 0)) - presentations, _, err := m.Get(ctx, testServiceID, 0) - require.NoError(t, err) - require.Len(t, presentations, 1) + assert.NotEmpty(t, seed) + + t.Run("ok - retrieve delta", func(t *testing.T) { + presentations, _, _, err := m.Get(ctx, testServiceID, 1) + require.NoError(t, err) + require.Len(t, presentations, 0) + }) }) t.Run("not a server for this service ID, call forwarded", func(t *testing.T) { m, _ := setupModule(t, storageEngine, func(module *Module) { @@ -280,14 +283,15 @@ func Test_Module_Get(t *testing.T) { Endpoint: "https://example.com/someother", } mockhttpclient := module.httpClient.(*client.MockHTTPClient) - mockhttpclient.EXPECT().Get(gomock.Any(), "https://example.com/someother", 0).Return(map[string]vc.VerifiablePresentation{"1": vpAlice}, 1, nil).AnyTimes() + mockhttpclient.EXPECT().Get(gomock.Any(), "https://example.com/someother", 0).Return(map[string]vc.VerifiablePresentation{"1": vpAlice}, "otherSeed", 1, nil).AnyTimes() }) - presentations, timestamp, err := m.Get(ctx, "someother", 0) + presentations, seed, timestamp, err := m.Get(ctx, "someother", 0) require.NoError(t, err) assert.Equal(t, 1, timestamp) assert.Len(t, presentations, 1) + assert.Equal(t, "otherSeed", seed) }) t.Run("not a server for this service ID, call forwarded, cycle detected", func(t *testing.T) { m, _ := setupModule(t, storageEngine, func(module *Module) { @@ -296,11 +300,11 @@ func Test_Module_Get(t *testing.T) { Endpoint: "https://example.com/someother", } mockhttpclient := module.httpClient.(*client.MockHTTPClient) - mockhttpclient.EXPECT().Get(gomock.Any(), "https://example.com/someother", 0).Return(nil, 0, nil).AnyTimes() + mockhttpclient.EXPECT().Get(gomock.Any(), "https://example.com/someother", 0).Return(nil, "", 0, nil).AnyTimes() }) ctx := context.WithValue(ctx, XForwardedHostContextKey{}, "https://example.com") - _, _, err := m.Get(ctx, "someother", 0) + _, _, _, err := m.Get(ctx, "someother", 0) assert.ErrorIs(t, err, errCyclicForwardingDetected) }) @@ -325,10 +329,23 @@ func setupModule(t *testing.T, storageInstance storage.Engine, visitors ...func( m.config = DefaultConfig() m.publicURL = test.MustParseURL("https://example.com") require.NoError(t, m.Configure(core.TestServerConfig())) + httpClient := client.NewMockHTTPClient(ctrl) - httpClient.EXPECT().Get(gomock.Any(), "http://example.com/other", gomock.Any()).Return(nil, 0, nil).AnyTimes() - httpClient.EXPECT().Get(gomock.Any(), "http://example.com/usecase", gomock.Any()).Return(nil, 0, nil).AnyTimes() - httpClient.EXPECT().Get(gomock.Any(), "http://example.com/unsupported", gomock.Any()).Return(nil, 0, nil).AnyTimes() + httpClient.EXPECT().Get(gomock.Any(), "http://example.com/other", gomock.Any()).Return(nil, testSeed, 0, nil).AnyTimes() + httpClient.EXPECT().Get(gomock.Any(), "http://example.com/usecase", gomock.Any()).Return(nil, testSeed, 0, nil).AnyTimes() + httpClient.EXPECT().Get(gomock.Any(), "http://example.com/unsupported", gomock.Any()).Return(nil, testSeed, 0, nil).AnyTimes() + // set seed in DB otherwise behaviour is unpredictable due to background processes + if m.store != nil { + require.NoError(t, m.store.db.Transaction(func(tx *gorm.DB) error { + service := serviceRecord{ + ID: testServiceID, + Seed: testSeed, + LastLamportTimestamp: 0, + } + return tx.Save(&service).Error + })) + } + m.httpClient = httpClient m.allDefinitions = testDefinitions() m.serverDefinitions = map[string]ServiceDefinition{ @@ -426,7 +443,7 @@ func TestModule_Search(t *testing.T) { t.Run("ok", func(t *testing.T) { m, _ := setupModule(t, storageEngine) - require.NoError(t, m.store.add(testServiceID, vpAlice, 0)) + require.NoError(t, m.store.add(testServiceID, vpAlice, testSeed, 0)) results, err := m.Search(testServiceID, map[string]string{ "credentialSubject.person.givenName": "Alice", @@ -461,7 +478,7 @@ func TestModule_update(t *testing.T) { // overwrite httpClient mock for custom behavior assertions (we want to know how often HttpClient.Get() was called) httpClient := client.NewMockHTTPClient(gomock.NewController(t)) // Get() should be called at least twice (times the number of Service Definitions), once for the initial run on startup, then again after the refresh interval - httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, 0, nil).MinTimes(2 * len(module.allDefinitions)) + httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, "", 0, nil).MinTimes(2 * len(module.allDefinitions)) module.httpClient = httpClient }) time.Sleep(10 * time.Millisecond) @@ -473,7 +490,7 @@ func TestModule_update(t *testing.T) { // overwrite httpClient mock for custom behavior assertions (we want to know how often HttpClient.Get() was called) httpClient := client.NewMockHTTPClient(gomock.NewController(t)) // update causes call to HttpClient.Get(), once for each Service Definition - httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, 0, nil).Times(len(module.allDefinitions)) + httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, "", 0, nil).Times(len(module.allDefinitions)) module.httpClient = httpClient }) }) @@ -487,7 +504,7 @@ func TestModule_ActivateServiceForSubject(t *testing.T) { // overwrite httpClient mock for custom behavior assertions (we want to know how often HttpClient.Get() was called) httpClient := client.NewMockHTTPClient(gomock.NewController(t)) httpClient.EXPECT().Register(gomock.Any(), gomock.Any(), vpAlice).Return(nil) - httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, 0, nil) + httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, "", 0, nil) module.httpClient = httpClient // disable auto-refresh job to have deterministic assertions module.config.Client.RefreshInterval = 0 @@ -511,7 +528,7 @@ func TestModule_ActivateServiceForSubject(t *testing.T) { // overwrite httpClient mock for custom behavior assertions (we want to know how often HttpClient.Get() was called) httpClient := client.NewMockHTTPClient(gomock.NewController(t)) httpClient.EXPECT().Register(gomock.Any(), gomock.Any(), vpAlice).Return(nil) - httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, 0, nil) + httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, "", 0, nil) module.httpClient = httpClient // disable auto-refresh job to have deterministic assertions module.config.Client.RefreshInterval = 0 @@ -614,7 +631,7 @@ func TestModule_GetServiceActivation(t *testing.T) { testContext.subjectManager.EXPECT().ListDIDs(gomock.Any(), aliceSubject).Return([]did.DID{aliceDID}, nil).AnyTimes() next := time.Now() _ = m.store.updatePresentationRefreshTime(testServiceID, aliceSubject, nil, &next) - _ = m.store.add(testServiceID, vpAlice, 0) + _ = m.store.add(testServiceID, vpAlice, testSeed, 0) activated, presentation, err := m.GetServiceActivation(context.Background(), testServiceID, aliceSubject) @@ -633,3 +650,13 @@ func TestModule_GetServiceActivation(t *testing.T) { }) }) } + +func checkWriteAccess(dir string) bool { + info, err := os.Stat(dir) + if err != nil { + return false + } + + // Check if the directory is writable by the current user + return info.Mode().Perm()&(1<<(uint(7))) != 0 +} diff --git a/discovery/store.go b/discovery/store.go index d09629185..9859d0b67 100644 --- a/discovery/store.go +++ b/discovery/store.go @@ -38,6 +38,7 @@ import ( type serviceRecord struct { ID string `gorm:"primaryKey"` + Seed string LastLamportTimestamp int } @@ -135,7 +136,7 @@ func newSQLStore(db *gorm.DB, clientDefinitions map[string]ServiceDefinition) (* // add adds a presentation to the list of presentations. // If the given timestamp is 0, the server will assign a timestamp. -func (s *sqlStore) add(serviceID string, presentation vc.VerifiablePresentation, timestamp int) error { +func (s *sqlStore) add(serviceID string, presentation vc.VerifiablePresentation, seed string, timestamp int) error { credentialSubjectID, err := credential.PresentationSigner(presentation) if err != nil { return err @@ -146,13 +147,16 @@ func (s *sqlStore) add(serviceID string, presentation vc.VerifiablePresentation, return s.db.Transaction(func(tx *gorm.DB) error { if timestamp == 0 { var newTs *int - newTs, err = s.incrementTimestamp(tx, serviceID) + if len(seed) == 0 { // default for server + seed = uuid.NewString() + } + newTs, err = s.incrementTimestamp(tx, serviceID, seed) if err != nil { return err } timestamp = *newTs } else { - err = s.setTimestamp(tx, serviceID, timestamp) + err = s.setTimestamp(tx, serviceID, seed, timestamp) if err != nil { return err } @@ -202,26 +206,26 @@ func storePresentation(tx *gorm.DB, serviceID string, timestamp int, presentatio // get returns all presentations, registered on the given service, starting after the given timestamp. // It also returns the latest timestamp of the returned presentations. -func (s *sqlStore) get(serviceID string, startAfter int) (map[string]vc.VerifiablePresentation, int, error) { +func (s *sqlStore) get(serviceID string, startAfter int) (map[string]vc.VerifiablePresentation, string, int, error) { var service serviceRecord if err := s.db.Find(&service, "id = ?", serviceID).Error; err != nil { - return nil, 0, fmt.Errorf("query service '%s': %w", serviceID, err) + return nil, "", 0, fmt.Errorf("query service '%s': %w", serviceID, err) } var rows []presentationRecord err := s.db.Order("lamport_timestamp ASC").Find(&rows, "service_id = ? AND lamport_timestamp > ?", serviceID, startAfter).Error if err != nil { - return nil, 0, fmt.Errorf("query service '%s': %w", serviceID, err) + return nil, "", 0, fmt.Errorf("query service '%s': %w", serviceID, err) } presentations := make(map[string]vc.VerifiablePresentation, len(rows)) for _, row := range rows { presentation, err := vc.ParseVerifiablePresentation(row.PresentationRaw) if err != nil { - return nil, 0, fmt.Errorf("parse presentation '%s' of service '%s': %w", row.PresentationID, serviceID, err) + return nil, "", 0, fmt.Errorf("parse presentation '%s' of service '%s': %w", row.PresentationID, serviceID, err) } presentations[fmt.Sprintf("%d", row.LamportTimestamp)] = *presentation } - return presentations, service.LastLamportTimestamp, nil + return presentations, service.Seed, service.LastLamportTimestamp, nil } // search searches for presentations, registered on the given service, matching the given query. @@ -256,14 +260,17 @@ func (s *sqlStore) search(serviceID string, query map[string]string) ([]vc.Verif return results, nil } -// incrementTimestamp increments the last_timestamp of the given service. -func (s *sqlStore) incrementTimestamp(tx *gorm.DB, serviceID string) (*int, error) { +// incrementTimestamp increments the last_timestamp of the given service. USed by server. +func (s *sqlStore) incrementTimestamp(tx *gorm.DB, serviceID string, seed string) (*int, error) { service, err := s.findAndLockService(tx, serviceID) if err != nil { return nil, err } service.ID = serviceID service.LastLamportTimestamp = service.LastLamportTimestamp + 1 + if len(service.Seed) == 0 { // first time this service is used, generate a new testSeed + service.Seed = seed + } if err := tx.Save(service).Error; err != nil { return nil, err @@ -271,14 +278,15 @@ func (s *sqlStore) incrementTimestamp(tx *gorm.DB, serviceID string) (*int, erro return &service.LastLamportTimestamp, nil } -// setTimestamp sets the last_timestamp of the given service. -func (s *sqlStore) setTimestamp(tx *gorm.DB, serviceID string, timestamp int) error { +// setTimestamp sets the last_timestamp of the given service. Used by clients. +func (s *sqlStore) setTimestamp(tx *gorm.DB, serviceID string, seed string, timestamp int) error { service, err := s.findAndLockService(tx, serviceID) if err != nil { return err } service.ID = serviceID service.LastLamportTimestamp = timestamp + service.Seed = seed return tx.Save(service).Error } @@ -496,3 +504,30 @@ func (s *sqlStore) getSubjectVPsOnService(serviceID string, subjectDIDs []did.DI } return result, nil } + +// wipeOnSeedChange wipes the store on a testSeed change. +func (s *sqlStore) wipeOnSeedChange(serviceID string, seed string) error { + return s.db.Transaction(func(tx *gorm.DB) error { + // get the service + service, err := s.findAndLockService(tx, serviceID) + if err != nil { + return err + } + if service.Seed != seed && len(service.Seed) > 0 { + log.Logger(). + WithField("serviceID", serviceID). + Warnf("Seed changed, wiping store (old: %s, new: %s)", service.Seed, seed) + + // wipe the store + if err = tx.Where("service_id = ?", serviceID).Delete(&presentationRecord{}).Error; err != nil { + return err + } + + // reset the testSeed and timestamp + service.Seed = seed + service.LastLamportTimestamp = 0 + return tx.Save(service).Error + } + return nil + }) +} diff --git a/discovery/store_test.go b/discovery/store_test.go index b782cb0a4..812bba212 100644 --- a/discovery/store_test.go +++ b/discovery/store_test.go @@ -45,21 +45,21 @@ func Test_sqlStore_exists(t *testing.T) { }) t.Run("non-empty list, no match (other subject and ID)", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpBob, 0)) + require.NoError(t, m.add(testServiceID, vpBob, testSeed, 0)) exists, err := m.exists(testServiceID, aliceDID.String(), vpAlice.ID.String()) assert.NoError(t, err) assert.False(t, exists) }) t.Run("non-empty list, no match (other list)", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpAlice, 0)) + require.NoError(t, m.add(testServiceID, vpAlice, testSeed, 0)) exists, err := m.exists("other", aliceDID.String(), vpAlice.ID.String()) assert.NoError(t, err) assert.False(t, exists) }) t.Run("non-empty list, match", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpAlice, 0)) + require.NoError(t, m.add(testServiceID, vpAlice, testSeed, 0)) exists, err := m.exists(testServiceID, aliceDID.String(), vpAlice.ID.String()) assert.NoError(t, err) assert.True(t, exists) @@ -72,13 +72,34 @@ func Test_sqlStore_add(t *testing.T) { t.Run("no credentials in presentation", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - err := m.add(testServiceID, createPresentation(aliceDID), 0) + err := m.add(testServiceID, createPresentation(aliceDID), testSeed, 0) assert.NoError(t, err) }) + t.Run("seed", func(t *testing.T) { + t.Run("passing seed updates last_seed", func(t *testing.T) { + m := setupStore(t, storageEngine.GetSQLDatabase()) + require.NoError(t, m.add(testServiceID, createPresentation(aliceDID), testSeed, 0)) + + _, seed, _, err := m.get(testServiceID, 0) + + require.NoError(t, err) + assert.Equal(t, testSeed, seed) + }) + t.Run("generated seed", func(t *testing.T) { + m := setupStore(t, storageEngine.GetSQLDatabase()) + require.NoError(t, m.add(testServiceID, createPresentation(aliceDID), "", 0)) + + _, seed, _, err := m.get(testServiceID, 0) + + require.NoError(t, err) + assert.Len(t, seed, 36) // uuid v4 + }) + }) + t.Run("passing timestamp updates last_timestamp", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - err := m.add(testServiceID, createPresentation(aliceDID), 1) + err := m.add(testServiceID, createPresentation(aliceDID), testSeed, 1) require.NoError(t, err) timestamp, err := m.getTimestamp(testServiceID) @@ -91,8 +112,8 @@ func Test_sqlStore_add(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) secondVP := createPresentation(aliceDID, vcAlice) - require.NoError(t, m.add(testServiceID, vpAlice, 0)) - require.NoError(t, m.add(testServiceID, secondVP, 0)) + require.NoError(t, m.add(testServiceID, vpAlice, testSeed, 0)) + require.NoError(t, m.add(testServiceID, secondVP, testSeed, 0)) // First VP should not exist exists, err := m.exists(testServiceID, aliceDID.String(), vpAlice.ID.String()) @@ -112,42 +133,44 @@ func Test_sqlStore_get(t *testing.T) { t.Run("empty list, 0 timestamp", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - presentations, timestamp, err := m.get(testServiceID, 0) + presentations, seed, timestamp, err := m.get(testServiceID, 0) assert.NoError(t, err) assert.Empty(t, presentations) assert.Equal(t, 0, timestamp) + assert.Empty(t, seed) }) t.Run("1 entry, 0 timestamp", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpAlice, 0)) - presentations, timestamp, err := m.get(testServiceID, 0) + require.NoError(t, m.add(testServiceID, vpAlice, testSeed, 0)) + presentations, seed, timestamp, err := m.get(testServiceID, 0) assert.NoError(t, err) assert.Equal(t, map[string]vc.VerifiablePresentation{"1": vpAlice}, presentations) assert.Equal(t, 1, timestamp) + assert.Equal(t, testSeed, seed) }) t.Run("2 entries, 0 timestamp", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpAlice, 0)) - require.NoError(t, m.add(testServiceID, vpBob, 0)) - presentations, timestamp, err := m.get(testServiceID, 0) + require.NoError(t, m.add(testServiceID, vpAlice, testSeed, 0)) + require.NoError(t, m.add(testServiceID, vpBob, testSeed, 0)) + presentations, _, timestamp, err := m.get(testServiceID, 0) assert.NoError(t, err) assert.Equal(t, map[string]vc.VerifiablePresentation{"1": vpAlice, "2": vpBob}, presentations) assert.Equal(t, 2, timestamp) }) t.Run("2 entries, start after first", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpAlice, 0)) - require.NoError(t, m.add(testServiceID, vpBob, 0)) - presentations, timestamp, err := m.get(testServiceID, 1) + require.NoError(t, m.add(testServiceID, vpAlice, testSeed, 0)) + require.NoError(t, m.add(testServiceID, vpBob, testSeed, 0)) + presentations, _, timestamp, err := m.get(testServiceID, 1) assert.NoError(t, err) assert.Equal(t, map[string]vc.VerifiablePresentation{"2": vpBob}, presentations) assert.Equal(t, 2, timestamp) }) t.Run("2 entries, start at end", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpAlice, 0)) - require.NoError(t, m.add(testServiceID, vpBob, 0)) - presentations, timestamp, err := m.get(testServiceID, 2) + require.NoError(t, m.add(testServiceID, vpAlice, testSeed, 0)) + require.NoError(t, m.add(testServiceID, vpBob, testSeed, 0)) + presentations, _, timestamp, err := m.get(testServiceID, 2) assert.NoError(t, err) assert.Equal(t, map[string]vc.VerifiablePresentation{}, presentations) assert.Equal(t, 2, timestamp) @@ -159,7 +182,7 @@ func Test_sqlStore_get(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - err := c.add(testServiceID, createPresentation(aliceDID, vcAlice), 0) + err := c.add(testServiceID, createPresentation(aliceDID, vcAlice), testSeed, 0) require.NoError(t, err) }() } @@ -185,7 +208,7 @@ func Test_sqlStore_search(t *testing.T) { vps := []vc.VerifiablePresentation{vpAlice} c := setupStore(t, storageEngine.GetSQLDatabase()) for _, vp := range vps { - err := c.add(testServiceID, vp, 0) + err := c.add(testServiceID, vp, testSeed, 0) require.NoError(t, err) } @@ -200,7 +223,7 @@ func Test_sqlStore_search(t *testing.T) { vps := []vc.VerifiablePresentation{vpAlice, vpBob} c := setupStore(t, storageEngine.GetSQLDatabase()) for _, vp := range vps { - err := c.add(testServiceID, vp, 0) + err := c.add(testServiceID, vp, testSeed, 0) require.NoError(t, err) } @@ -212,7 +235,7 @@ func Test_sqlStore_search(t *testing.T) { vps := []vc.VerifiablePresentation{vpAlice, vpBob} c := setupStore(t, storageEngine.GetSQLDatabase()) for _, vp := range vps { - err := c.add(testServiceID, vp, 0) + err := c.add(testServiceID, vp, testSeed, 0) require.NoError(t, err) } actualVPs, err := c.search(testServiceID, map[string]string{ @@ -350,8 +373,8 @@ func Test_sqlStore_getSubjectVPsOnService(t *testing.T) { _ = storageEngine.Shutdown() }) c := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, c.add(testServiceID, vpAlice2, 0)) - require.NoError(t, c.add(testServiceID, vpBob2, 0)) + require.NoError(t, c.add(testServiceID, vpAlice2, testSeed, 0)) + require.NoError(t, c.add(testServiceID, vpBob2, testSeed, 0)) t.Run("ok - single", func(t *testing.T) { vps, err := c.getSubjectVPsOnService(testServiceID, []did.DID{aliceDID}) @@ -365,6 +388,36 @@ func Test_sqlStore_getSubjectVPsOnService(t *testing.T) { }) } +func Test_sqlStore_wipeOnSeedChange(t *testing.T) { + logrus.SetLevel(logrus.DebugLevel) + storageEngine := storage.NewTestStorageEngine(t) + require.NoError(t, storageEngine.Start()) + t.Cleanup(func() { + _ = storageEngine.Shutdown() + }) + + t.Run("empty database", func(t *testing.T) { + c := setupStore(t, storageEngine.GetSQLDatabase()) + err := c.wipeOnSeedChange(testServiceID, "other") + require.NoError(t, err) + }) + t.Run("1 entry wiped, 1 remains", func(t *testing.T) { + c := setupStore(t, storageEngine.GetSQLDatabase()) + require.NoError(t, c.add(testServiceID, vpAlice, testSeed, 0)) + require.NoError(t, c.add("other", vpAlice, testSeed, 0)) + + err := c.wipeOnSeedChange(testServiceID, "other") + require.NoError(t, err) + + vps, err := c.search(testServiceID, map[string]string{}) + require.NoError(t, err) + require.Len(t, vps, 0) + vps, err = c.search("other", map[string]string{}) + require.NoError(t, err) + require.Len(t, vps, 1) + }) +} + func setupStore(t *testing.T, db *gorm.DB) *sqlStore { resetStore(t, db) defs := testDefinitions() diff --git a/discovery/test.go b/discovery/test.go index 222f2f9e2..3c927495f 100644 --- a/discovery/test.go +++ b/discovery/test.go @@ -40,6 +40,8 @@ import ( "time" ) +const testSeed = "1234567890" + var keyPairs map[string]*ecdsa.PrivateKey var authorityDID did.DID var aliceSubject string diff --git a/docs/_static/discovery/server.yaml b/docs/_static/discovery/server.yaml index 4bf854784..9e8b8ccbc 100644 --- a/docs/_static/discovery/server.yaml +++ b/docs/_static/discovery/server.yaml @@ -76,9 +76,13 @@ components: PresentationsResponse: type: object required: + - seed - timestamp - entries properties: + seed: + description: unique value for the combination of serviceID and a server instance. + type: string timestamp: description: highest timestamp of the returned presentations, should be used as the timestamp for the next query type: integer diff --git a/storage/sql_migrations/009_discoveryservice_seed.sql b/storage/sql_migrations/009_discoveryservice_seed.sql new file mode 100644 index 000000000..4d0e48a9f --- /dev/null +++ b/storage/sql_migrations/009_discoveryservice_seed.sql @@ -0,0 +1,6 @@ +-- +goose Up +-- discovery_service: add seed column +alter table discovery_service add seed varchar(36); + +-- +goose Down +alter table discovery_service drop column seed;