Skip to content

Commit

Permalink
fix: implement truncating logic
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelbrm committed Sep 2, 2024
1 parent dd30e6c commit 2d73671
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 76 deletions.
62 changes: 53 additions & 9 deletions services/skus/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,12 @@ func (s *Service) CreateOrderItemCredentials(ctx context.Context, orderID, itemI
return errItemDoesNotExist
}

if err := s.doCredentialsExist(ctx, requestID, item, blindedCreds); err != nil {
nbcreds := len(blindedCreds)
if nbcreds == 0 {
return model.ErrTLV2InvalidCredNum
}

if err := s.doCredentialsExist(ctx, requestID, item, blindedCreds[0]); err != nil {
if errors.Is(err, errCredsAlreadySubmitted) {
return nil
}
Expand All @@ -255,8 +260,9 @@ func (s *Service) CreateOrderItemCredentials(ctx context.Context, orderID, itemI

// Check if the order is for Leo and numIntervals is 8.
// If yes, then truncate credentials to the desired number 576.
creds := truncateTLV2BCreds(order, item, nbcreds, blindedCreds)

if err := checkNumBlindedCreds(order, item, len(blindedCreds)); err != nil {
if err := checkNumBlindedCreds(order, item, len(creds)); err != nil {
return err
}

Expand Down Expand Up @@ -288,7 +294,7 @@ func (s *Service) CreateOrderItemCredentials(ctx context.Context, orderID, itemI
{
IssuerType: issuerID,
IssuerCohort: defaultCohort,
BlindedTokens: blindedCreds,
BlindedTokens: creds,
AssociatedData: associatedData,
},
},
Expand All @@ -301,33 +307,33 @@ func (s *Service) CreateOrderItemCredentials(ctx context.Context, orderID, itemI
return nil
}

func (s *Service) doCredentialsExist(ctx context.Context, requestID uuid.UUID, item *model.OrderItem, blindedCreds []string) error {
func (s *Service) doCredentialsExist(ctx context.Context, requestID uuid.UUID, item *model.OrderItem, firstBCred string) error {
switch item.CredentialType {
case timeLimitedV2:
// NOTE: There was a possible race condition that would allow exceeding limits on the number of cred batches.
// The condition is currently mitigated by:
// - checking the number of active batches before accepting a request to create creds;
// - checking the number of active batches before inserting the signed creds.

return s.doTLV2Exist(ctx, requestID, item, blindedCreds)
return s.doTLV2Exist(ctx, requestID, item, firstBCred)
default:
return s.doCredsExist(ctx, item)
}
}

func (s *Service) doTLV2Exist(ctx context.Context, reqID uuid.UUID, item *model.OrderItem, bcreds []string) error {
func (s *Service) doTLV2Exist(ctx context.Context, reqID uuid.UUID, item *model.OrderItem, firstBCred string) error {
now := time.Now()

return s.doTLV2ExistTxTime(ctx, s.Datastore.RawDB(), reqID, item, bcreds, now, now)
return s.doTLV2ExistTxTime(ctx, s.Datastore.RawDB(), reqID, item, firstBCred, now, now)
}

func (s *Service) doTLV2ExistTxTime(ctx context.Context, dbi sqlx.QueryerContext, reqID uuid.UUID, item *model.OrderItem, bcreds []string, from, to time.Time) error {
func (s *Service) doTLV2ExistTxTime(ctx context.Context, dbi sqlx.QueryerContext, reqID uuid.UUID, item *model.OrderItem, firstBCred string, from, to time.Time) error {
if item.CredentialType != timeLimitedV2 {
return model.ErrUnsupportedCredType
}

// Check TLV2 to see if we have credentials signed that match incoming blinded tokens.
report, err := s.tlv2Repo.GetCredSubmissionReport(ctx, dbi, item.OrderID, item.ID, reqID, bcreds...)
report, err := s.tlv2Repo.GetCredSubmissionReport(ctx, dbi, item.OrderID, item.ID, reqID, firstBCred)
if err != nil {
return err
}
Expand Down Expand Up @@ -801,3 +807,41 @@ func checkTLV2BatchLimit(lim, nact int) error {

return nil
}

func truncateTLV2BCreds(ord *model.Order, item *model.OrderItem, ncreds int, srcCreds []string) []string {
result := srcCreds
if targetn, ok := shouldTruncateTLV2Creds(ord, item, ncreds); ok {
result = srcCreds[:targetn]
}

return result
}

func shouldTruncateTLV2Creds(ord *model.Order, item *model.OrderItem, ncreds int) (int, bool) {
if !item.IsLeo() {
return 0, false
}

numi, err := ord.NumIntervals()
if err != nil {
// Safe fallback.
return 0, false
}

if numi == 3 {
return 0, false
}

numpi, err := ord.NumPerInterval()
if err != nil {
// Safe fallback.
return 0, false
}

target := numi * numpi
if ncreds <= target {
return 0, false
}

return target, true
}
2 changes: 1 addition & 1 deletion services/skus/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ type orderStoreSvc interface {
}

type tlv2Store interface {
GetCredSubmissionReport(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, creds ...string) (model.TLV2CredSubmissionReport, error)
GetCredSubmissionReport(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, firstBCred string) (model.TLV2CredSubmissionReport, error)
UniqBatches(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID uuid.UUID, from, to time.Time) (int, error)
DeleteLegacy(ctx context.Context, dbi sqlx.ExecerContext, orderID uuid.UUID) error
}
Expand Down
66 changes: 33 additions & 33 deletions services/skus/service_nonint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1838,12 +1838,12 @@ func TestService_checkOrderReceipt(t *testing.T) {

func TestService_doTLV2ExistTxTime(t *testing.T) {
type tcGiven struct {
reqID uuid.UUID
item *model.OrderItem
creds []string
from time.Time
to time.Time
repo *repository.MockTLV2
reqID uuid.UUID
item *model.OrderItem
firstBCred string
from time.Time
to time.Time
repo *repository.MockTLV2
}

type testCase struct {
Expand All @@ -1862,10 +1862,10 @@ func TestService_doTLV2ExistTxTime(t *testing.T) {
OrderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")),
CredentialType: "time-limited",
},
creds: []string{"cred_01", "cred_02"},
from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
repo: &repository.MockTLV2{},
firstBCred: "cred_01",
from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
repo: &repository.MockTLV2{},
},
exp: model.ErrUnsupportedCredType,
},
Expand All @@ -1879,11 +1879,11 @@ func TestService_doTLV2ExistTxTime(t *testing.T) {
OrderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")),
CredentialType: "time-limited-v2",
},
creds: []string{"cred_01", "cred_02"},
from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
firstBCred: "cred_01",
from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
repo: &repository.MockTLV2{
FnGetCredSubmissionReport: func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, creds ...string) (model.TLV2CredSubmissionReport, error) {
FnGetCredSubmissionReport: func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, firstBCred string) (model.TLV2CredSubmissionReport, error) {
return model.TLV2CredSubmissionReport{}, model.Error("something_went_wrong")
},
},
Expand All @@ -1900,11 +1900,11 @@ func TestService_doTLV2ExistTxTime(t *testing.T) {
OrderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")),
CredentialType: "time-limited-v2",
},
creds: []string{"cred_01", "cred_02"},
from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
firstBCred: "cred_01",
from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
repo: &repository.MockTLV2{
FnGetCredSubmissionReport: func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, creds ...string) (model.TLV2CredSubmissionReport, error) {
FnGetCredSubmissionReport: func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, firstBCred string) (model.TLV2CredSubmissionReport, error) {
return model.TLV2CredSubmissionReport{Submitted: true}, nil
},
},
Expand All @@ -1921,11 +1921,11 @@ func TestService_doTLV2ExistTxTime(t *testing.T) {
OrderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")),
CredentialType: "time-limited-v2",
},
creds: []string{"cred_01", "cred_02"},
from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
firstBCred: "cred_01",
from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
repo: &repository.MockTLV2{
FnGetCredSubmissionReport: func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, creds ...string) (model.TLV2CredSubmissionReport, error) {
FnGetCredSubmissionReport: func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, firstBCred string) (model.TLV2CredSubmissionReport, error) {
return model.TLV2CredSubmissionReport{ReqIDMismatch: true}, nil
},
},
Expand All @@ -1942,9 +1942,9 @@ func TestService_doTLV2ExistTxTime(t *testing.T) {
OrderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")),
CredentialType: "time-limited-v2",
},
creds: []string{"cred_01", "cred_02"},
from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
firstBCred: "cred_01",
from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
repo: &repository.MockTLV2{
FnUniqBatches: func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID uuid.UUID, from, to time.Time) (int, error) {
return 0, model.Error("something_went_wrong")
Expand All @@ -1963,9 +1963,9 @@ func TestService_doTLV2ExistTxTime(t *testing.T) {
OrderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")),
CredentialType: "time-limited-v2",
},
creds: []string{"cred_01", "cred_02"},
from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
firstBCred: "cred_01",
from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
repo: &repository.MockTLV2{
FnUniqBatches: func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID uuid.UUID, from, to time.Time) (int, error) {
return 10, nil
Expand All @@ -1984,10 +1984,10 @@ func TestService_doTLV2ExistTxTime(t *testing.T) {
OrderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")),
CredentialType: "time-limited-v2",
},
creds: []string{"cred_01", "cred_02"},
from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
repo: &repository.MockTLV2{},
firstBCred: "cred_01",
from: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
to: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC),
repo: &repository.MockTLV2{},
},
},
}
Expand All @@ -2000,7 +2000,7 @@ func TestService_doTLV2ExistTxTime(t *testing.T) {

ctx := context.Background()

actual := svc.doTLV2ExistTxTime(ctx, nil, tc.given.reqID, tc.given.item, tc.given.creds, tc.given.from, tc.given.to)
actual := svc.doTLV2ExistTxTime(ctx, nil, tc.given.reqID, tc.given.item, tc.given.firstBCred, tc.given.from, tc.given.to)
should.Equal(t, tc.exp, actual)
})
}
Expand Down
6 changes: 3 additions & 3 deletions services/skus/storage/repository/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,17 +218,17 @@ func (r *MockOrderPayHistory) Insert(ctx context.Context, dbi sqlx.ExecerContext
}

type MockTLV2 struct {
FnGetCredSubmissionReport func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, creds ...string) (model.TLV2CredSubmissionReport, error)
FnGetCredSubmissionReport func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, firstBCred string) (model.TLV2CredSubmissionReport, error)
FnUniqBatches func(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID uuid.UUID, from, to time.Time) (int, error)
FnDeleteLegacy func(ctx context.Context, dbi sqlx.ExecerContext, orderID uuid.UUID) error
}

func (r *MockTLV2) GetCredSubmissionReport(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, creds ...string) (model.TLV2CredSubmissionReport, error) {
func (r *MockTLV2) GetCredSubmissionReport(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, firstBCred string) (model.TLV2CredSubmissionReport, error) {
if r.FnGetCredSubmissionReport == nil {
return model.TLV2CredSubmissionReport{}, nil
}

return r.FnGetCredSubmissionReport(ctx, dbi, orderID, itemID, reqID, creds...)
return r.FnGetCredSubmissionReport(ctx, dbi, orderID, itemID, reqID, firstBCred)
}

func (r *MockTLV2) UniqBatches(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID uuid.UUID, from, to time.Time) (int, error) {
Expand Down
8 changes: 2 additions & 6 deletions services/skus/storage/repository/tlv2.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,15 @@ type TLV2 struct{}

func NewTLV2() *TLV2 { return &TLV2{} }

func (r *TLV2) GetCredSubmissionReport(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, creds ...string) (model.TLV2CredSubmissionReport, error) {
if len(creds) == 0 {
return model.TLV2CredSubmissionReport{}, model.ErrTLV2InvalidCredNum
}

func (r *TLV2) GetCredSubmissionReport(ctx context.Context, dbi sqlx.QueryerContext, orderID, itemID, reqID uuid.UUID, firstBCred string) (model.TLV2CredSubmissionReport, error) {
const q = `SELECT EXISTS(
SELECT 1 FROM time_limited_v2_order_creds WHERE order_id=$1 AND item_id=$2 AND blinded_creds->>0 = $4
) AS submitted, EXISTS(
SELECT 1 FROM time_limited_v2_order_creds WHERE order_id=$1 AND item_id=$2 AND request_id = $3 AND blinded_creds->>0 != $4
) AS req_id_mismatch`

result := model.TLV2CredSubmissionReport{}
if err := sqlx.GetContext(ctx, dbi, &result, q, orderID, itemID, reqID, creds[0]); err != nil {
if err := sqlx.GetContext(ctx, dbi, &result, q, orderID, itemID, reqID, firstBCred); err != nil {
return model.TLV2CredSubmissionReport{}, err
}

Expand Down
37 changes: 13 additions & 24 deletions services/skus/storage/repository/tlv2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ func TestTLV2_GetCredSubmissionReport(t *testing.T) {
}()

type tcGiven struct {
orderID uuid.UUID
itemID uuid.UUID
reqID uuid.UUID
creds []string
orderID uuid.UUID
itemID uuid.UUID
reqID uuid.UUID
firstBCred string

fnBefore func(ctx context.Context, dbi sqlx.ExtContext) error
}
Expand All @@ -46,24 +46,13 @@ func TestTLV2_GetCredSubmissionReport(t *testing.T) {
}

tests := []testCase{
{
name: "invalid_param",
given: tcGiven{
orderID: uuid.Must(uuid.FromString("facade00-0000-4000-a000-000000000000")),
itemID: uuid.Must(uuid.FromString("decade00-0000-4000-a000-000000000000")),
reqID: uuid.Must(uuid.FromString("f100ded0-0000-4000-a000-000000000000")),
fnBefore: func(ctx context.Context, dbi sqlx.ExtContext) error { return nil },
},
exp: tcExpected{err: model.ErrTLV2InvalidCredNum},
},

{
name: "submitted",
given: tcGiven{
orderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")),
itemID: uuid.Must(uuid.FromString("ad0be000-0000-4000-a000-000000000000")),
reqID: uuid.Must(uuid.FromString("f100ded0-0000-4000-a000-000000000000")),
creds: []string{"cred_01", "cred_02", "cred_03"},
orderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")),
itemID: uuid.Must(uuid.FromString("ad0be000-0000-4000-a000-000000000000")),
reqID: uuid.Must(uuid.FromString("f100ded0-0000-4000-a000-000000000000")),
firstBCred: "cred_01",

fnBefore: func(ctx context.Context, dbi sqlx.ExtContext) error {
qs := []string{
Expand Down Expand Up @@ -97,10 +86,10 @@ func TestTLV2_GetCredSubmissionReport(t *testing.T) {
{
name: "mismatch",
given: tcGiven{
orderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")),
itemID: uuid.Must(uuid.FromString("ad0be000-0000-4000-a000-000000000000")),
reqID: uuid.Must(uuid.FromString("f100ded0-0000-4000-a000-000000000000")),
creds: []string{"cred_01", "cred_02", "cred_03"},
orderID: uuid.Must(uuid.FromString("c0c0a000-0000-4000-a000-000000000000")),
itemID: uuid.Must(uuid.FromString("ad0be000-0000-4000-a000-000000000000")),
reqID: uuid.Must(uuid.FromString("f100ded0-0000-4000-a000-000000000000")),
firstBCred: "cred_01",

fnBefore: func(ctx context.Context, dbi sqlx.ExtContext) error {
qs := []string{
Expand Down Expand Up @@ -151,7 +140,7 @@ func TestTLV2_GetCredSubmissionReport(t *testing.T) {
must.Equal(t, nil, err)
}

actual, err := repo.GetCredSubmissionReport(ctx, tx, tc.given.orderID, tc.given.itemID, tc.given.reqID, tc.given.creds...)
actual, err := repo.GetCredSubmissionReport(ctx, tx, tc.given.orderID, tc.given.itemID, tc.given.reqID, tc.given.firstBCred)
must.Equal(t, tc.exp.err, err)

if tc.exp.err != nil {
Expand Down

0 comments on commit 2d73671

Please sign in to comment.