From 2d736717dd4e0cd36b55d037fa5f5da5e98fbf43 Mon Sep 17 00:00:00 2001 From: PavelBrm Date: Mon, 2 Sep 2024 21:38:24 +1200 Subject: [PATCH] fix: implement truncating logic --- services/skus/credentials.go | 62 ++++++++++++++--- services/skus/service.go | 2 +- services/skus/service_nonint_test.go | 66 +++++++++---------- services/skus/storage/repository/mock.go | 6 +- services/skus/storage/repository/tlv2.go | 8 +-- services/skus/storage/repository/tlv2_test.go | 37 ++++------- 6 files changed, 105 insertions(+), 76 deletions(-) diff --git a/services/skus/credentials.go b/services/skus/credentials.go index aaf67be6e..159354653 100644 --- a/services/skus/credentials.go +++ b/services/skus/credentials.go @@ -245,7 +245,12 @@ func (s *Service) CreateOrderItemCredentials(ctx context.Context, orderID, itemI return errItemDoesNotExist } - if err := s.doCredentialsExist(ctx, requestID, item, blindedCreds); err != nil { + nbcreds := len(blindedCreds) + if nbcreds == 0 { + return model.ErrTLV2InvalidCredNum + } + + if err := s.doCredentialsExist(ctx, requestID, item, blindedCreds[0]); err != nil { if errors.Is(err, errCredsAlreadySubmitted) { return nil } @@ -255,8 +260,9 @@ func (s *Service) CreateOrderItemCredentials(ctx context.Context, orderID, itemI // Check if the order is for Leo and numIntervals is 8. // If yes, then truncate credentials to the desired number 576. + creds := truncateTLV2BCreds(order, item, nbcreds, blindedCreds) - if err := checkNumBlindedCreds(order, item, len(blindedCreds)); err != nil { + if err := checkNumBlindedCreds(order, item, len(creds)); err != nil { return err } @@ -288,7 +294,7 @@ func (s *Service) CreateOrderItemCredentials(ctx context.Context, orderID, itemI { IssuerType: issuerID, IssuerCohort: defaultCohort, - BlindedTokens: blindedCreds, + BlindedTokens: creds, AssociatedData: associatedData, }, }, @@ -301,7 +307,7 @@ func (s *Service) CreateOrderItemCredentials(ctx context.Context, orderID, itemI return nil } -func (s *Service) doCredentialsExist(ctx context.Context, requestID uuid.UUID, item *model.OrderItem, blindedCreds []string) error { +func (s *Service) doCredentialsExist(ctx context.Context, requestID uuid.UUID, item *model.OrderItem, firstBCred string) error { switch item.CredentialType { case timeLimitedV2: // NOTE: There was a possible race condition that would allow exceeding limits on the number of cred batches. @@ -309,25 +315,25 @@ func (s *Service) doCredentialsExist(ctx context.Context, requestID uuid.UUID, i // - checking the number of active batches before accepting a request to create creds; // - checking the number of active batches before inserting the signed creds. - return s.doTLV2Exist(ctx, requestID, item, blindedCreds) + return s.doTLV2Exist(ctx, requestID, item, firstBCred) default: return s.doCredsExist(ctx, item) } } -func (s *Service) doTLV2Exist(ctx context.Context, reqID uuid.UUID, item *model.OrderItem, bcreds []string) error { +func (s *Service) doTLV2Exist(ctx context.Context, reqID uuid.UUID, item *model.OrderItem, firstBCred string) error { now := time.Now() - return s.doTLV2ExistTxTime(ctx, s.Datastore.RawDB(), reqID, item, bcreds, now, now) + return s.doTLV2ExistTxTime(ctx, s.Datastore.RawDB(), reqID, item, firstBCred, now, now) } -func (s *Service) doTLV2ExistTxTime(ctx context.Context, dbi sqlx.QueryerContext, reqID uuid.UUID, item *model.OrderItem, bcreds []string, from, to time.Time) error { +func (s *Service) doTLV2ExistTxTime(ctx context.Context, dbi sqlx.QueryerContext, reqID uuid.UUID, item *model.OrderItem, firstBCred string, from, to time.Time) error { if item.CredentialType != timeLimitedV2 { return model.ErrUnsupportedCredType } // Check TLV2 to see if we have credentials signed that match incoming blinded tokens. - report, err := s.tlv2Repo.GetCredSubmissionReport(ctx, dbi, item.OrderID, item.ID, reqID, bcreds...) + report, err := s.tlv2Repo.GetCredSubmissionReport(ctx, dbi, item.OrderID, item.ID, reqID, firstBCred) if err != nil { return err } @@ -801,3 +807,41 @@ func checkTLV2BatchLimit(lim, nact int) error { return nil } + +func truncateTLV2BCreds(ord *model.Order, item *model.OrderItem, ncreds int, srcCreds []string) []string { + result := srcCreds + if targetn, ok := shouldTruncateTLV2Creds(ord, item, ncreds); ok { + result = srcCreds[:targetn] + } + + return result +} + +func shouldTruncateTLV2Creds(ord *model.Order, item *model.OrderItem, ncreds int) (int, bool) { + if !item.IsLeo() { + return 0, false + } + + numi, err := ord.NumIntervals() + if err != nil { + // Safe fallback. + return 0, false + } + + if numi == 3 { + return 0, false + } + + numpi, err := ord.NumPerInterval() + if err != nil { + // Safe fallback. + return 0, false + } + + target := numi * numpi + if ncreds <= target { + return 0, false + } + + return target, true +} diff --git a/services/skus/service.go b/services/skus/service.go index 54328a9a5..ce9f7ac33 100644 --- a/services/skus/service.go +++ b/services/skus/service.go @@ -105,7 +105,7 @@ type orderStoreSvc interface { } type tlv2Store interface { - GetCredSubmissionReport(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, creds ...string) (model.TLV2CredSubmissionReport, error) + GetCredSubmissionReport(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, firstBCred string) (model.TLV2CredSubmissionReport, error) UniqBatches(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID uuid.UUID, from, to time.Time) (int, error) DeleteLegacy(ctx context.Context, dbi sqlx.ExecerContext, orderID uuid.UUID) error } diff --git a/services/skus/service_nonint_test.go b/services/skus/service_nonint_test.go index 3826ff11a..df7e8cfc4 100644 --- a/services/skus/service_nonint_test.go +++ b/services/skus/service_nonint_test.go @@ -1838,12 +1838,12 @@ func TestService_checkOrderReceipt(t *testing.T) { func TestService_doTLV2ExistTxTime(t *testing.T) { type tcGiven struct { - reqID uuid.UUID - item *model.OrderItem - creds []string - from time.Time - to time.Time - repo *repository.MockTLV2 + reqID uuid.UUID + item *model.OrderItem + firstBCred string + from time.Time + to time.Time + repo *repository.MockTLV2 } type testCase struct { @@ -1862,10 +1862,10 @@ func TestService_doTLV2ExistTxTime(t *testing.T) { OrderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")), CredentialType: "time-limited", }, - creds: []string{"cred_01", "cred_02"}, - from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), - to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), - repo: &repository.MockTLV2{}, + firstBCred: "cred_01", + from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + repo: &repository.MockTLV2{}, }, exp: model.ErrUnsupportedCredType, }, @@ -1879,11 +1879,11 @@ func TestService_doTLV2ExistTxTime(t *testing.T) { OrderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")), CredentialType: "time-limited-v2", }, - creds: []string{"cred_01", "cred_02"}, - from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), - to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + firstBCred: "cred_01", + from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), repo: &repository.MockTLV2{ - FnGetCredSubmissionReport: func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, creds ...string) (model.TLV2CredSubmissionReport, error) { + FnGetCredSubmissionReport: func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, firstBCred string) (model.TLV2CredSubmissionReport, error) { return model.TLV2CredSubmissionReport{}, model.Error("something_went_wrong") }, }, @@ -1900,11 +1900,11 @@ func TestService_doTLV2ExistTxTime(t *testing.T) { OrderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")), CredentialType: "time-limited-v2", }, - creds: []string{"cred_01", "cred_02"}, - from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), - to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + firstBCred: "cred_01", + from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), repo: &repository.MockTLV2{ - FnGetCredSubmissionReport: func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, creds ...string) (model.TLV2CredSubmissionReport, error) { + FnGetCredSubmissionReport: func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, firstBCred string) (model.TLV2CredSubmissionReport, error) { return model.TLV2CredSubmissionReport{Submitted: true}, nil }, }, @@ -1921,11 +1921,11 @@ func TestService_doTLV2ExistTxTime(t *testing.T) { OrderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")), CredentialType: "time-limited-v2", }, - creds: []string{"cred_01", "cred_02"}, - from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), - to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + firstBCred: "cred_01", + from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), repo: &repository.MockTLV2{ - FnGetCredSubmissionReport: func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, creds ...string) (model.TLV2CredSubmissionReport, error) { + FnGetCredSubmissionReport: func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, firstBCred string) (model.TLV2CredSubmissionReport, error) { return model.TLV2CredSubmissionReport{ReqIDMismatch: true}, nil }, }, @@ -1942,9 +1942,9 @@ func TestService_doTLV2ExistTxTime(t *testing.T) { OrderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")), CredentialType: "time-limited-v2", }, - creds: []string{"cred_01", "cred_02"}, - from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), - to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + firstBCred: "cred_01", + from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), repo: &repository.MockTLV2{ FnUniqBatches: func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID uuid.UUID, from, to time.Time) (int, error) { return 0, model.Error("something_went_wrong") @@ -1963,9 +1963,9 @@ func TestService_doTLV2ExistTxTime(t *testing.T) { OrderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")), CredentialType: "time-limited-v2", }, - creds: []string{"cred_01", "cred_02"}, - from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), - to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + firstBCred: "cred_01", + from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), repo: &repository.MockTLV2{ FnUniqBatches: func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID uuid.UUID, from, to time.Time) (int, error) { return 10, nil @@ -1984,10 +1984,10 @@ func TestService_doTLV2ExistTxTime(t *testing.T) { OrderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")), CredentialType: "time-limited-v2", }, - creds: []string{"cred_01", "cred_02"}, - from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), - to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), - repo: &repository.MockTLV2{}, + firstBCred: "cred_01", + from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + repo: &repository.MockTLV2{}, }, }, } @@ -2000,7 +2000,7 @@ func TestService_doTLV2ExistTxTime(t *testing.T) { ctx := context.Background() - actual := svc.doTLV2ExistTxTime(ctx, nil, tc.given.reqID, tc.given.item, tc.given.creds, tc.given.from, tc.given.to) + actual := svc.doTLV2ExistTxTime(ctx, nil, tc.given.reqID, tc.given.item, tc.given.firstBCred, tc.given.from, tc.given.to) should.Equal(t, tc.exp, actual) }) } diff --git a/services/skus/storage/repository/mock.go b/services/skus/storage/repository/mock.go index 6605de639..59cefaa7b 100644 --- a/services/skus/storage/repository/mock.go +++ b/services/skus/storage/repository/mock.go @@ -218,17 +218,17 @@ func (r *MockOrderPayHistory) Insert(ctx context.Context, dbi sqlx.ExecerContext } type MockTLV2 struct { - FnGetCredSubmissionReport func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, creds ...string) (model.TLV2CredSubmissionReport, error) + FnGetCredSubmissionReport func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, firstBCred string) (model.TLV2CredSubmissionReport, error) FnUniqBatches func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID uuid.UUID, from, to time.Time) (int, error) FnDeleteLegacy func(ctx context.Context, dbi sqlx.ExecerContext, orderID uuid.UUID) error } -func (r *MockTLV2) GetCredSubmissionReport(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, creds ...string) (model.TLV2CredSubmissionReport, error) { +func (r *MockTLV2) GetCredSubmissionReport(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, firstBCred string) (model.TLV2CredSubmissionReport, error) { if r.FnGetCredSubmissionReport == nil { return model.TLV2CredSubmissionReport{}, nil } - return r.FnGetCredSubmissionReport(ctx, dbi, orderID, itemID, reqID, creds...) + return r.FnGetCredSubmissionReport(ctx, dbi, orderID, itemID, reqID, firstBCred) } func (r *MockTLV2) UniqBatches(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID uuid.UUID, from, to time.Time) (int, error) { diff --git a/services/skus/storage/repository/tlv2.go b/services/skus/storage/repository/tlv2.go index f7172d28a..b29dae215 100644 --- a/services/skus/storage/repository/tlv2.go +++ b/services/skus/storage/repository/tlv2.go @@ -14,11 +14,7 @@ type TLV2 struct{} func NewTLV2() *TLV2 { return &TLV2{} } -func (r *TLV2) GetCredSubmissionReport(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, creds ...string) (model.TLV2CredSubmissionReport, error) { - if len(creds) == 0 { - return model.TLV2CredSubmissionReport{}, model.ErrTLV2InvalidCredNum - } - +func (r *TLV2) GetCredSubmissionReport(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, firstBCred string) (model.TLV2CredSubmissionReport, error) { const q = `SELECT EXISTS( SELECT 1 FROM time_limited_v2_order_creds WHERE order_id=$1 AND item_id=$2 AND blinded_creds->>0 = $4 ) AS submitted, EXISTS( @@ -26,7 +22,7 @@ func (r *TLV2) GetCredSubmissionReport(ctx context.Context, dbi sqlx.QueryerCont ) AS req_id_mismatch` result := model.TLV2CredSubmissionReport{} - if err := sqlx.GetContext(ctx, dbi, &result, q, orderID, itemID, reqID, creds[0]); err != nil { + if err := sqlx.GetContext(ctx, dbi, &result, q, orderID, itemID, reqID, firstBCred); err != nil { return model.TLV2CredSubmissionReport{}, err } diff --git a/services/skus/storage/repository/tlv2_test.go b/services/skus/storage/repository/tlv2_test.go index 90449d15d..77207bb2d 100644 --- a/services/skus/storage/repository/tlv2_test.go +++ b/services/skus/storage/repository/tlv2_test.go @@ -26,10 +26,10 @@ func TestTLV2_GetCredSubmissionReport(t *testing.T) { }() type tcGiven struct { - orderID uuid.UUID - itemID uuid.UUID - reqID uuid.UUID - creds []string + orderID uuid.UUID + itemID uuid.UUID + reqID uuid.UUID + firstBCred string fnBefore func(ctx context.Context, dbi sqlx.ExtContext) error } @@ -46,24 +46,13 @@ func TestTLV2_GetCredSubmissionReport(t *testing.T) { } tests := []testCase{ - { - name: "invalid_param", - given: tcGiven{ - orderID: uuid.Must(uuid.FromString("facade00-0000-4000-a000-000000000000")), - itemID: uuid.Must(uuid.FromString("decade00-0000-4000-a000-000000000000")), - reqID: uuid.Must(uuid.FromString("f100ded0-0000-4000-a000-000000000000")), - fnBefore: func(ctx context.Context, dbi sqlx.ExtContext) error { return nil }, - }, - exp: tcExpected{err: model.ErrTLV2InvalidCredNum}, - }, - { name: "submitted", given: tcGiven{ - orderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")), - itemID: uuid.Must(uuid.FromString("ad0be000-0000-4000-a000-000000000000")), - reqID: uuid.Must(uuid.FromString("f100ded0-0000-4000-a000-000000000000")), - creds: []string{"cred_01", "cred_02", "cred_03"}, + orderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")), + itemID: uuid.Must(uuid.FromString("ad0be000-0000-4000-a000-000000000000")), + reqID: uuid.Must(uuid.FromString("f100ded0-0000-4000-a000-000000000000")), + firstBCred: "cred_01", fnBefore: func(ctx context.Context, dbi sqlx.ExtContext) error { qs := []string{ @@ -97,10 +86,10 @@ func TestTLV2_GetCredSubmissionReport(t *testing.T) { { name: "mismatch", given: tcGiven{ - orderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")), - itemID: uuid.Must(uuid.FromString("ad0be000-0000-4000-a000-000000000000")), - reqID: uuid.Must(uuid.FromString("f100ded0-0000-4000-a000-000000000000")), - creds: []string{"cred_01", "cred_02", "cred_03"}, + orderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")), + itemID: uuid.Must(uuid.FromString("ad0be000-0000-4000-a000-000000000000")), + reqID: uuid.Must(uuid.FromString("f100ded0-0000-4000-a000-000000000000")), + firstBCred: "cred_01", fnBefore: func(ctx context.Context, dbi sqlx.ExtContext) error { qs := []string{ @@ -151,7 +140,7 @@ func TestTLV2_GetCredSubmissionReport(t *testing.T) { must.Equal(t, nil, err) } - actual, err := repo.GetCredSubmissionReport(ctx, tx, tc.given.orderID, tc.given.itemID, tc.given.reqID, tc.given.creds...) + actual, err := repo.GetCredSubmissionReport(ctx, tx, tc.given.orderID, tc.given.itemID, tc.given.reqID, tc.given.firstBCred) must.Equal(t, tc.exp.err, err) if tc.exp.err != nil {