diff --git a/services/skus/controllers.go b/services/skus/controllers.go index 48957e830..dfac0d33b 100644 --- a/services/skus/controllers.go +++ b/services/skus/controllers.go @@ -141,10 +141,12 @@ func Router( func CredentialRouter(svc *Service, authMwr middlewareFn) chi.Router { r := chi.NewRouter() + valid := validator.New() + r.Method( http.MethodPost, "/subscription/verifications", - middleware.InstrumentHandler("VerifyCredentialV1", authMwr(VerifyCredentialV1(svc))), + middleware.InstrumentHandler("handleVerifyCredV1", authMwr(handleVerifyCredV1(svc, valid))), ) return r @@ -154,10 +156,12 @@ func CredentialRouter(svc *Service, authMwr middlewareFn) chi.Router { func CredentialV2Router(svc *Service, authMwr middlewareFn) chi.Router { r := chi.NewRouter() + valid := validator.New() + r.Method( http.MethodPost, "/subscription/verifications", - middleware.InstrumentHandler("VerifyCredentialV2", authMwr(VerifyCredentialV2(svc))), + middleware.InstrumentHandler("handleVerifyCredV2", authMwr(handleVerifyCredV2(svc, valid))), ) return r @@ -943,54 +947,79 @@ func MerchantTransactions(service *Service) handlers.AppHandler { }) } -func VerifyCredentialV2(service *Service) handlers.AppHandler { +func handleVerifyCredV2(svc *Service, valid *validator.Validate) handlers.AppHandler { return func(w http.ResponseWriter, r *http.Request) *handlers.AppError { ctx := r.Context() - l := logging.Logger(ctx, "skus").With().Str("func", "VerifyCredentialV2").Logger() + lg := logging.Logger(ctx, "skus").With().Str("func", "handleVerifyCredV2").Logger() + + data, err := io.ReadAll(io.LimitReader(r.Body, reqBodyLimit10MB)) + if err != nil { + lg.Warn().Err(err).Msg("failed to read body") + + return handlers.WrapError(err, "Error in request body", http.StatusBadRequest) + } + + req, err := parseVerifyCredRequestV2(data) + if err != nil { + lg.Warn().Err(err).Msg("failed to deserialize request") - req := &VerifyCredentialRequestV2{} - if err := inputs.DecodeAndValidateReader(ctx, req, r.Body); err != nil { - l.Error().Err(err).Msg("failed to read request") return handlers.WrapError(err, "Error in request body", http.StatusBadRequest) } - appErr := service.verifyCredential(ctx, req, w) - if appErr != nil { - l.Error().Err(appErr).Msg("failed to verify credential") + if err := validateVerifyCredRequestV2(valid, req); err != nil { + verrs, ok := collectValidationErrors(err) + if !ok { + return handlers.ValidationError("request", map[string]interface{}{"request-body": err.Error()}) + } + + return handlers.ValidationError("request", verrs) + } + + aerr := svc.verifyCredential(ctx, req, w) + if aerr != nil { + lg.Err(aerr).Msg("failed to verify credential") } - return appErr + return aerr } } -// VerifyCredentialV1 is the handler for verifying subscription credentials -func VerifyCredentialV1(service *Service) handlers.AppHandler { +func handleVerifyCredV1(svc *Service, valid *validator.Validate) handlers.AppHandler { return func(w http.ResponseWriter, r *http.Request) *handlers.AppError { ctx := r.Context() - l := logging.Logger(r.Context(), "VerifyCredentialV1") - var req = new(VerifyCredentialRequestV1) + lg := logging.Logger(ctx, "skus").With().Str("func", "handleVerifyCredV1").Logger() - err := requestutils.ReadJSON(r.Context(), r.Body, &req) + data, err := io.ReadAll(io.LimitReader(r.Body, reqBodyLimit10MB)) if err != nil { - l.Error().Err(err).Msg("failed to read request") + lg.Warn().Err(err).Msg("failed to read body") + return handlers.WrapError(err, "Error in request body", http.StatusBadRequest) } - l.Debug().Msg("read verify credential post body") - _, err = govalidator.ValidateStruct(req) - if err != nil { - l.Error().Err(err).Msg("failed to validate request") - return handlers.WrapError(err, "Error in request validation", http.StatusBadRequest) + req := &model.VerifyCredentialRequestV1{} + if err := json.Unmarshal(data, req); err != nil { + lg.Warn().Err(err).Msg("failed to deserialize request") + + return handlers.WrapError(err, "Error in request body", http.StatusBadRequest) + } + + if err := valid.StructCtx(ctx, req); err != nil { + verrs, ok := collectValidationErrors(err) + if !ok { + return handlers.ValidationError("request", map[string]interface{}{"request-body": err.Error()}) + } + + return handlers.ValidationError("request", verrs) } - appErr := service.verifyCredential(ctx, req, w) - if appErr != nil { - l.Error().Err(appErr).Msg("failed to verify credential") + aerr := svc.verifyCredential(ctx, req, w) + if aerr != nil { + lg.Err(aerr).Msg("failed to verify credential") } - return appErr + return aerr } } @@ -1639,3 +1668,42 @@ func collectValidationErrors(err error) (map[string]string, bool) { return result, true } + +func parseVerifyCredRequestV2(raw []byte) (*model.VerifyCredentialRequestV2, error) { + result := &model.VerifyCredentialRequestV2{} + + if err := json.Unmarshal(raw, result); err != nil { + return nil, err + } + + copaque, err := parseVerifyCredeOpaque(result.Credential) + if err != nil { + return nil, err + } + + result.CredentialOpaque = copaque + + return result, nil +} + +func parseVerifyCredeOpaque(raw string) (*model.VerifyCredentialOpaque, error) { + data, err := base64.StdEncoding.DecodeString(raw) + if err != nil { + return nil, err + } + + result := &model.VerifyCredentialOpaque{} + if err = json.Unmarshal(data, result); err != nil { + return nil, err + } + + return result, nil +} + +func validateVerifyCredRequestV2(valid *validator.Validate, req *model.VerifyCredentialRequestV2) error { + if err := valid.Struct(req); err != nil { + return err + } + + return valid.Struct(req.CredentialOpaque) +} diff --git a/services/skus/controllers_noint_test.go b/services/skus/controllers_noint_test.go index e40bbb485..1317d2b08 100644 --- a/services/skus/controllers_noint_test.go +++ b/services/skus/controllers_noint_test.go @@ -3,6 +3,7 @@ package skus import ( "context" "net/http" + "reflect" "testing" "github.com/go-playground/validator/v10" @@ -222,3 +223,151 @@ func TestHandleReceiptErr(t *testing.T) { }) } } + +func TestParseVerifyCredRequestV2(t *testing.T) { + type tcExpected struct { + val *model.VerifyCredentialRequestV2 + errFn must.ErrorAssertionFunc + } + + type testCase struct { + name string + given []byte + exp tcExpected + } + + tests := []testCase{ + { + name: "error_malformed_payload", + given: []byte(`nonsense`), + exp: tcExpected{ + errFn: func(tt must.TestingT, err error, i ...interface{}) { + must.Equal(tt, true, err != nil) + }, + }, + }, + + { + name: "error_malformed_credential", + given: []byte(`{"sku":"sku","merchantId":"merchantId"}`), + exp: tcExpected{ + errFn: func(tt must.TestingT, err error, i ...interface{}) { + must.Equal(tt, true, err != nil) + }, + }, + }, + + { + name: "success_complete", + given: []byte(`{"sku": "sku","merchantId": "merchantId","credential":"eyJ0eXBlIjoidGltZS1saW1pdGVkLXYyIiwicHJlc2VudGF0aW9uIjoiVG1GMGRYSmxJR0ZpYUc5eWN5QmhJSFpoWTNWMWJTNEsifQo="}`), + exp: tcExpected{ + val: &model.VerifyCredentialRequestV2{ + SKU: "sku", + MerchantID: "merchantId", + Credential: "eyJ0eXBlIjoidGltZS1saW1pdGVkLXYyIiwicHJlc2VudGF0aW9uIjoiVG1GMGRYSmxJR0ZpYUc5eWN5QmhJSFpoWTNWMWJTNEsifQo=", + CredentialOpaque: &model.VerifyCredentialOpaque{ + Type: "time-limited-v2", + Presentation: "TmF0dXJlIGFiaG9ycyBhIHZhY3V1bS4K", + }, + }, + errFn: func(tt must.TestingT, err error, i ...interface{}) { + must.Equal(tt, true, err == nil) + }, + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual, err := parseVerifyCredRequestV2(tc.given) + tc.exp.errFn(t, err) + + should.Equal(t, tc.exp.val, actual) + }) + } +} + +func TestValidateVerifyCredRequestV2(t *testing.T) { + type tcGiven struct { + valid *validator.Validate + req *model.VerifyCredentialRequestV2 + } + + tests := []struct { + name string + given tcGiven + exp error + }{ + { + name: "error_credential_opaque_nil", + given: tcGiven{ + valid: validator.New(), + req: &model.VerifyCredentialRequestV2{ + SKU: "sku", + MerchantID: "merchantId", + Credential: "eyJ0eXBlIjoic2luZ2xlLXVzZSIsInByZXNlbnRhdGlvbiI6IlRtRjBkWEpsSUdGaWFHOXljeUJoSUhaaFkzVjFiUzRLIn0K", + }, + }, + exp: &validator.InvalidValidationError{Type: reflect.TypeOf((*model.VerifyCredentialOpaque)(nil))}, + }, + + { + name: "valid_single_use", + given: tcGiven{ + valid: validator.New(), + req: &model.VerifyCredentialRequestV2{ + SKU: "sku", + MerchantID: "merchantId", + Credential: "eyJ0eXBlIjoic2luZ2xlLXVzZSIsInByZXNlbnRhdGlvbiI6IlRtRjBkWEpsSUdGaWFHOXljeUJoSUhaaFkzVjFiUzRLIn0K", + CredentialOpaque: &model.VerifyCredentialOpaque{ + Type: "single-use", + Presentation: "TmF0dXJlIGFiaG9ycyBhIHZhY3V1bS4K", + }, + }, + }, + }, + + { + name: "valid_time_limited", + given: tcGiven{ + valid: validator.New(), + req: &model.VerifyCredentialRequestV2{ + SKU: "sku", + MerchantID: "merchantId", + Credential: "eyJ0eXBlIjoidGltZS1saW1pdGVkIiwicHJlc2VudGF0aW9uIjoiVG1GMGRYSmxJR0ZpYUc5eWN5QmhJSFpoWTNWMWJTNEsifQo=", + CredentialOpaque: &model.VerifyCredentialOpaque{ + Type: "time-limited", + Presentation: "TmF0dXJlIGFiaG9ycyBhIHZhY3V1bS4K", + }, + }, + }, + }, + + { + name: "valid_time_limited_v2", + given: tcGiven{ + valid: validator.New(), + req: &model.VerifyCredentialRequestV2{ + SKU: "sku", + MerchantID: "merchantId", + Credential: "eyJ0eXBlIjoidGltZS1saW1pdGVkLXYyIiwicHJlc2VudGF0aW9uIjoiVG1GMGRYSmxJR0ZpYUc5eWN5QmhJSFpoWTNWMWJTNEsifQo=", + CredentialOpaque: &model.VerifyCredentialOpaque{ + Type: "time-limited-v2", + Presentation: "TmF0dXJlIGFiaG9ycyBhIHZhY3V1bS4K", + }, + }, + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := validateVerifyCredRequestV2(tc.given.valid, tc.given.req) + should.Equal(t, tc.exp, actual) + }) + } +} diff --git a/services/skus/credentials.go b/services/skus/credentials.go index 60ac2fa65..71d2c7701 100644 --- a/services/skus/credentials.go +++ b/services/skus/credentials.go @@ -61,7 +61,7 @@ var ( // // This only happens in the event of a new sku being created. func (s *Service) CreateIssuer(ctx context.Context, dbi sqlx.QueryerContext, merchID string, item *OrderItem) error { - encMerchID, err := encodeIssuerID(merchID, item.SKU) + encMerchID, err := encodeIssuerID(merchID, item.SKUForIssuer()) if err != nil { return errorutils.Wrap(err, "error encoding issuer name") } @@ -114,7 +114,7 @@ func (s *Service) CreateIssuer(ctx context.Context, dbi sqlx.QueryerContext, mer // // This only happens in the event of a new sku being created. func (s *Service) CreateIssuerV3(ctx context.Context, dbi sqlx.QueryerContext, merchID string, item *OrderItem, issuerCfg model.IssuerConfig) error { - encMerchID, err := encodeIssuerID(merchID, item.SKU) + encMerchID, err := encodeIssuerID(merchID, item.SKUForIssuer()) if err != nil { return errorutils.Wrap(err, "error encoding issuer name") } @@ -266,7 +266,7 @@ func (s *Service) CreateOrderItemCredentials(ctx context.Context, orderID, itemI return err } - issuerID, err := encodeIssuerID(order.MerchantID, item.SKU) + issuerID, err := encodeIssuerID(order.MerchantID, item.SKUForIssuer()) if err != nil { return errorutils.Wrap(err, "error encoding issuer name") } diff --git a/services/skus/credentials_test.go b/services/skus/credentials_test.go index 48d2c9c47..60062a40f 100644 --- a/services/skus/credentials_test.go +++ b/services/skus/credentials_test.go @@ -237,7 +237,7 @@ func TestCreateIssuer_NewIssuer(t *testing.T) { EachCredentialValidForISO: ptr.FromString("P1D"), } - issuerID, err := encodeIssuerID(merchantID, orderItem.SKU) + issuerID, err := encodeIssuerID(merchantID, orderItem.SKUForIssuer()) must.Equal(t, nil, err) cbrClient := mock_cbr.NewMockClient(ctrl) @@ -291,7 +291,7 @@ func TestCreateIssuerV3_NewIssuer(t *testing.T) { EachCredentialValidForISO: ptr.FromString("P1D"), } - issuerID, err := encodeIssuerID(merchantID, orderItem.SKU) + issuerID, err := encodeIssuerID(merchantID, orderItem.SKUForIssuer()) must.Equal(t, nil, err) issuerConfig := model.IssuerConfig{ @@ -360,7 +360,7 @@ func TestCreateIssuer_AlreadyExists(t *testing.T) { EachCredentialValidForISO: ptr.FromString("P1D"), } - issuerID, err := encodeIssuerID(merchantID, orderItem.SKU) + issuerID, err := encodeIssuerID(merchantID, orderItem.SKUForIssuer()) must.Equal(t, nil, err) issuer := &Issuer{ @@ -402,7 +402,7 @@ func TestCreateIssuerV3_AlreadyExists(t *testing.T) { EachCredentialValidForISO: ptr.FromString("P1D"), } - issuerID, err := encodeIssuerID(merchantID, orderItem.SKU) + issuerID, err := encodeIssuerID(merchantID, orderItem.SKUForIssuer()) must.Equal(t, nil, err) issuer := &Issuer{ @@ -477,7 +477,7 @@ func TestCreateOrderCredentials(t *testing.T) { EachCredentialValidForISO: ptr.FromString("P1D"), } - issuerID, err := encodeIssuerID(merchantID, orderItem.SKU) + issuerID, err := encodeIssuerID(merchantID, orderItem.SKUForIssuer()) must.Equal(t, nil, err) issuer := &Issuer{ @@ -543,7 +543,6 @@ func TestDeduplicateCredentialBindings(t *testing.T) { } func TestIssuerID(t *testing.T) { - cases := []struct { MerchantID string SKU string diff --git a/services/skus/input.go b/services/skus/input.go deleted file mode 100644 index 8a3bedc21..000000000 --- a/services/skus/input.go +++ /dev/null @@ -1,124 +0,0 @@ -package skus - -import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" - - "github.com/asaskevich/govalidator" - "github.com/brave-intl/bat-go/libs/logging" -) - -// VerifyCredentialRequestV1 includes an opaque subscription credential blob -type VerifyCredentialRequestV1 struct { - Version float64 `json:"version" valid:"-"` - Type string `json:"type" valid:"in(single-use|time-limited|time-limited-v2)"` - SKU string `json:"sku" valid:"-"` - MerchantID string `json:"merchantId" valid:"-"` - Presentation string `json:"presentation" valid:"base64"` -} - -// GetSku - implement credential interface -func (vcr *VerifyCredentialRequestV1) GetSku(ctx context.Context) string { - return vcr.SKU -} - -// GetType - implement credential interface -func (vcr *VerifyCredentialRequestV1) GetType(ctx context.Context) string { - return vcr.Type -} - -// GetMerchantID - implement credential interface -func (vcr *VerifyCredentialRequestV1) GetMerchantID(ctx context.Context) string { - return vcr.MerchantID -} - -// GetPresentation - implement credential interface -func (vcr *VerifyCredentialRequestV1) GetPresentation(ctx context.Context) string { - return vcr.Presentation -} - -// VerifyCredentialRequestV2 includes an opaque subscription credential blob -type VerifyCredentialRequestV2 struct { - SKU string `json:"sku" valid:"-"` - MerchantID string `json:"merchantId" valid:"-"` - Credential string `json:"credential" valid:"base64"` - CredentialOpaque *VerifyCredentialOpaque `json:"-" valid:"-"` -} - -// GetSku - implement credential interface -func (vcr *VerifyCredentialRequestV2) GetSku(ctx context.Context) string { - return vcr.SKU -} - -// GetType - implement credential interface -func (vcr *VerifyCredentialRequestV2) GetType(ctx context.Context) string { - if vcr.CredentialOpaque == nil { - return "" - } - return vcr.CredentialOpaque.Type -} - -// GetMerchantID - implement credential interface -func (vcr *VerifyCredentialRequestV2) GetMerchantID(ctx context.Context) string { - return vcr.MerchantID -} - -// GetPresentation - implement credential interface -func (vcr *VerifyCredentialRequestV2) GetPresentation(ctx context.Context) string { - if vcr.CredentialOpaque == nil { - return "" - } - return vcr.CredentialOpaque.Presentation -} - -// Decode - implement Decodable interface -func (vcr *VerifyCredentialRequestV2) Decode(ctx context.Context, data []byte) error { - logger := logging.Logger(ctx, "VerifyCredentialRequestV2.Decode") - logger.Debug().Msg("starting VerifyCredentialRequestV2.Decode") - var err error - - if err := json.Unmarshal(data, vcr); err != nil { - return fmt.Errorf("failed to json decode credential request payload: %w", err) - } - // decode the opaque credential - if vcr.CredentialOpaque, err = credentialOpaqueFromString(vcr.Credential); err != nil { - return fmt.Errorf("failed to decode opaque credential payload: %w", err) - } - return nil -} - -// Validate - implement Validable interface -func (vcr *VerifyCredentialRequestV2) Validate(ctx context.Context) error { - logger := logging.Logger(ctx, "VerifyCredentialRequestV2.Validate") - var err error - for _, v := range []interface{}{vcr, vcr.CredentialOpaque} { - _, err = govalidator.ValidateStruct(v) - if err != nil { - logger.Error().Err(err).Msg("failed to validate request") - return fmt.Errorf("failed to validate verify credential request: %w", err) - } - } - return nil -} - -// VerifyCredentialOpaque includes an opaque presentation blob -type VerifyCredentialOpaque struct { - Type string `json:"type" valid:"in(single-use|time-limited|time-limited-v2)"` - Version float64 `json:"version" valid:"-"` - Presentation string `json:"presentation" valid:"base64"` -} - -// credentialOpaqueFromString - given a base64 encoded "credential" unmarshal into a VerifyCredentialOpaque -func credentialOpaqueFromString(s string) (*VerifyCredentialOpaque, error) { - d, err := base64.StdEncoding.DecodeString(s) - if err != nil { - return nil, fmt.Errorf("failed to base64 decode credential payload: %w", err) - } - var vcp = new(VerifyCredentialOpaque) - if err = json.Unmarshal(d, vcp); err != nil { - return nil, fmt.Errorf("failed to json decode credential payload: %w", err) - } - return vcp, nil -} diff --git a/services/skus/model/model.go b/services/skus/model/model.go index 8676ffd76..a053a9b46 100644 --- a/services/skus/model/model.go +++ b/services/skus/model/model.go @@ -8,6 +8,7 @@ import ( "net/url" "sort" "strconv" + "strings" "time" "github.com/lib/pq" @@ -370,6 +371,10 @@ func (x *OrderItem) StripeItemID() (string, bool) { return itemID, ok } +func (x *OrderItem) SKUForIssuer() string { + return fixPremiumSKUForIssuer(x.SKU) +} + // OrderNew represents a request to create an order in the database. type OrderNew struct { MerchantID string `db:"merchant_id"` @@ -677,6 +682,71 @@ type CreateOrderWithReceiptResponse struct { ID string `json:"orderId"` } +type VerifyCredentialRequestV1 struct { + Type string `json:"type" validate:"oneof=single-use time-limited time-limited-v2"` + SKU string `json:"sku" validate:"-"` + MerchantID string `json:"merchantId" validate:"-"` + Presentation string `json:"presentation" validate:"base64"` + Version float64 `json:"version" validate:"-"` +} + +func (r *VerifyCredentialRequestV1) GetSKU() string { + return fixPremiumSKUForIssuer(r.SKU) +} + +func (r *VerifyCredentialRequestV1) GetType() string { + return r.Type +} + +func (r *VerifyCredentialRequestV1) GetMerchantID() string { + return r.MerchantID +} + +func (r *VerifyCredentialRequestV1) GetPresentation() string { + return r.Presentation +} + +type VerifyCredentialRequestV2 struct { + SKU string `json:"sku" validate:"-"` + MerchantID string `json:"merchantId" validate:"-"` + Credential string `json:"credential" validate:"base64"` + CredentialOpaque *VerifyCredentialOpaque `json:"-" validate:"-"` +} + +func (r *VerifyCredentialRequestV2) GetSKU() string { + return fixPremiumSKUForIssuer(r.SKU) +} + +func (r *VerifyCredentialRequestV2) GetType() string { + if r.CredentialOpaque == nil { + return "" + } + + return r.CredentialOpaque.Type +} + +func (r *VerifyCredentialRequestV2) GetMerchantID() string { + return r.MerchantID +} + +func (r *VerifyCredentialRequestV2) GetPresentation() string { + if r.CredentialOpaque == nil { + return "" + } + + return r.CredentialOpaque.Presentation +} + +type VerifyCredentialOpaque struct { + Type string `json:"type" validate:"oneof=single-use time-limited time-limited-v2"` + Presentation string `json:"presentation" validate:"base64"` + Version float64 `json:"version" validate:"-"` +} + +func fixPremiumSKUForIssuer(val string) string { + return strings.TrimSuffix(val, "-year") +} + func addURLParam(src, name, val string) (string, error) { raw, err := url.Parse(src) if err != nil { diff --git a/services/skus/model/model_pvt_test.go b/services/skus/model/model_pvt_test.go index 84bb72e23..32b29484e 100644 --- a/services/skus/model/model_pvt_test.go +++ b/services/skus/model/model_pvt_test.go @@ -83,3 +83,41 @@ func TestAddURLParam(t *testing.T) { }) } } + +func TestFixPremiumSKUForIssuer(t *testing.T) { + tests := []struct { + name string + given string + exp string + }{ + { + name: "empty", + }, + + { + name: "trimmed_empty", + given: "-year", + }, + + { + name: "untouched", + given: "anything", + exp: "anything", + }, + + { + name: "trimmed", + given: "anything-year", + exp: "anything", + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := fixPremiumSKUForIssuer(tc.given) + should.Equal(t, tc.exp, actual) + }) + } +} diff --git a/services/skus/model/model_test.go b/services/skus/model/model_test.go index 8bc44525c..436a09e06 100644 --- a/services/skus/model/model_test.go +++ b/services/skus/model/model_test.go @@ -1333,6 +1333,93 @@ func TestOrderItem_StripeItemID(t *testing.T) { } } +func TestOrderItem_SKUForIssuer(t *testing.T) { + type testCase struct { + name string + given model.OrderItem + exp string + } + + tests := []testCase{ + { + name: "empty", + }, + + { + name: "talk", + given: model.OrderItem{ + SKU: "brave-talk-premium", + }, + exp: "brave-talk-premium", + }, + + { + name: "talka", + given: model.OrderItem{ + SKU: "brave-talk-premium-year", + }, + exp: "brave-talk-premium", + }, + + { + name: "search", + given: model.OrderItem{ + SKU: "brave-search-premium", + }, + exp: "brave-search-premium", + }, + + { + name: "searcha", + given: model.OrderItem{ + SKU: "brave-search-premium-year", + }, + exp: "brave-search-premium", + }, + + { + name: "vpn", + given: model.OrderItem{ + SKU: "brave-vpn-premium", + }, + exp: "brave-vpn-premium", + }, + + { + name: "vpna", + given: model.OrderItem{ + SKU: "brave-vpn-premium-year", + }, + exp: "brave-vpn-premium", + }, + + { + name: "leo", + given: model.OrderItem{ + SKU: "brave-leo-premium", + }, + exp: "brave-leo-premium", + }, + + { + name: "leoa", + given: model.OrderItem{ + SKU: "brave-leo-premium-year", + }, + exp: "brave-leo-premium", + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := tc.given.SKUForIssuer() + should.Equal(t, tc.exp, actual) + }) + } +} + func TestOrderItemRequestNew_TokenBufferOrDefault(t *testing.T) { type testCase struct { name string @@ -1482,6 +1569,76 @@ func TestOrderItemRequestNew_IsTLV2(t *testing.T) { } } +func TestVerifyCredentialRequestV1_GetSKU(t *testing.T) { + type testCase struct { + name string + given model.VerifyCredentialRequestV1 + exp string + } + + tests := []testCase{ + { + name: "anything", + given: model.VerifyCredentialRequestV1{ + SKU: "anything", + }, + exp: "anything", + }, + + { + name: "leoa", + given: model.VerifyCredentialRequestV1{ + SKU: "brave-leo-premium-year", + }, + exp: "brave-leo-premium", + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := tc.given.GetSKU() + should.Equal(t, tc.exp, actual) + }) + } +} + +func TestVerifyCredentialRequestV2_GetSKU(t *testing.T) { + type testCase struct { + name string + given model.VerifyCredentialRequestV2 + exp string + } + + tests := []testCase{ + { + name: "anything", + given: model.VerifyCredentialRequestV2{ + SKU: "anything", + }, + exp: "anything", + }, + + { + name: "leoa", + given: model.VerifyCredentialRequestV2{ + SKU: "brave-leo-premium-year", + }, + exp: "brave-leo-premium", + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := tc.given.GetSKU() + should.Equal(t, tc.exp, actual) + }) + } +} + func ptrTo[T any](v T) *T { return &v } diff --git a/services/skus/service.go b/services/skus/service.go index ce9f7ac33..0afb42552 100644 --- a/services/skus/service.go +++ b/services/skus/service.go @@ -1381,7 +1381,7 @@ func (s *Service) GetTimeLimitedCreds(ctx context.Context, order *Order, itemID, return nil, http.StatusInternalServerError, model.Error("unable to parse issuance interval for credentials") } - issuerID, err := encodeIssuerID(order.MerchantID, item.SKU) + issuerID, err := encodeIssuerID(order.MerchantID, item.SKUForIssuer()) if err != nil { return nil, http.StatusInternalServerError, fmt.Errorf("error encoding issuer: %w", err) } @@ -1399,10 +1399,10 @@ func (s *Service) GetTimeLimitedCreds(ctx context.Context, order *Order, itemID, } type credential interface { - GetSku(context.Context) string - GetType(context.Context) string - GetMerchantID(context.Context) string - GetPresentation(context.Context) string + GetSKU() string + GetType() string + GetMerchantID() string + GetPresentation() string } // verifyCredential - given a credential, verify it. @@ -1417,21 +1417,21 @@ func (s *Service) verifyCredential(ctx context.Context, cred credential, w http. caveats := caveatsFromCtx(ctx) - if merchID := cred.GetMerchantID(ctx); merchID != merchant { + if merchID := cred.GetMerchantID(); merchID != merchant { logger.Warn().Str("req.MerchantID", merchID).Str("merchant", merchant).Msg("merchant does not match the key's merchant") return handlers.WrapError(nil, "Verify request merchant does not match authentication", http.StatusForbidden) } if caveats != nil { if sku, ok := caveats["sku"]; ok { - if csku := cred.GetSku(ctx); csku != sku { + if csku := cred.GetSKU(); csku != sku { logger.Warn().Str("req.SKU", csku).Str("sku", sku).Msg("sku caveat does not match") return handlers.WrapError(nil, "Verify request sku does not match authentication", http.StatusForbidden) } } } - kind := cred.GetType(ctx) + kind := cred.GetType() switch kind { case singleUse, timeLimitedV2: return s.verifyBlindedTokenCredential(ctx, cred, w) @@ -1444,7 +1444,7 @@ func (s *Service) verifyCredential(ctx context.Context, cred credential, w http. // verifyBlindedTokenCredential verifies a single use or time limited v2 credential. func (s *Service) verifyBlindedTokenCredential(ctx context.Context, req credential, w http.ResponseWriter) *handlers.AppError { - bytes, err := base64.StdEncoding.DecodeString(req.GetPresentation(ctx)) + bytes, err := base64.StdEncoding.DecodeString(req.GetPresentation()) if err != nil { return handlers.WrapError(err, "Error in decoding presentation", http.StatusBadRequest) } @@ -1455,7 +1455,7 @@ func (s *Service) verifyBlindedTokenCredential(ctx context.Context, req credenti } // Ensure that the credential being redeemed (opaque to merchant) matches the outer credential details. - issuerID, err := encodeIssuerID(req.GetMerchantID(ctx), req.GetSku(ctx)) + issuerID, err := encodeIssuerID(req.GetMerchantID(), req.GetSKU()) if err != nil { return handlers.WrapError(err, "Error in outer merchantId or sku", http.StatusBadRequest) } @@ -1464,12 +1464,12 @@ func (s *Service) verifyBlindedTokenCredential(ctx context.Context, req credenti return handlers.WrapError(nil, "Error, outer merchant and sku don't match issuer", http.StatusBadRequest) } - return s.redeemBlindedCred(ctx, w, req.GetType(ctx), decodedCred) + return s.redeemBlindedCred(ctx, w, req.GetType(), decodedCred) } // verifyTimeLimitedV1Credential verifies a time limited v1 credential. func (s *Service) verifyTimeLimitedV1Credential(ctx context.Context, req credential, w http.ResponseWriter) *handlers.AppError { - data, err := base64.StdEncoding.DecodeString(req.GetPresentation(ctx)) + data, err := base64.StdEncoding.DecodeString(req.GetPresentation()) if err != nil { return handlers.WrapError(err, "Error in decoding presentation", http.StatusBadRequest) } @@ -1479,10 +1479,10 @@ func (s *Service) verifyTimeLimitedV1Credential(ctx context.Context, req credent return handlers.WrapError(err, "Error in presentation formatting", http.StatusBadRequest) } - merchID := req.GetMerchantID(ctx) + merchID := req.GetMerchantID() // Ensure that the credential being redeemed (opaque to merchant) matches the outer credential details. - issuerID, err := encodeIssuerID(merchID, req.GetSku(ctx)) + issuerID, err := encodeIssuerID(merchID, req.GetSKU()) if err != nil { return handlers.WrapError(err, "Error in outer merchantId or sku", http.StatusBadRequest) }