From 9d40f5a3b28a57aba71cf68fffd18b69f9f72f83 Mon Sep 17 00:00:00 2001 From: PavelBrm Date: Wed, 31 Jul 2024 23:00:57 +1200 Subject: [PATCH] refactor: use proper expiry time when fixing up order while getting --- services/skus/controllers.go | 36 +- services/skus/controllers_test.go | 4 +- services/skus/model/model.go | 110 +-- services/skus/model/model_test.go | 193 ++++- services/skus/order.go | 16 - services/skus/service.go | 364 ++++++---- services/skus/service_nonint_test.go | 671 +++++++++++++++++- services/skus/storage/repository/mock.go | 27 +- services/skus/xstripe/mock.go | 73 ++ services/skus/xstripe/xstripe.go | 60 ++ .../xstripe_test.go} | 6 +- 11 files changed, 1266 insertions(+), 294 deletions(-) create mode 100644 services/skus/xstripe/mock.go create mode 100644 services/skus/xstripe/xstripe.go rename services/skus/{order_noint_test.go => xstripe/xstripe_test.go} (90%) diff --git a/services/skus/controllers.go b/services/skus/controllers.go index 774f29a39..c02ac4097 100644 --- a/services/skus/controllers.go +++ b/services/skus/controllers.go @@ -72,7 +72,7 @@ func Router( { corsMwrGet := NewCORSMwr(copts, http.MethodGet) r.Method(http.MethodOptions, "/{orderID}", metricsMwr("GetOrderOptions", corsMwrGet(nil))) - r.Method(http.MethodGet, "/{orderID}", metricsMwr("GetOrder", corsMwrGet(GetOrder(svc)))) + r.Method(http.MethodGet, "/{orderID}", metricsMwr("GetOrder", corsMwrGet(handleGetOrder(svc)))) } r.Method( @@ -370,30 +370,30 @@ func CancelOrder(service *Service) handlers.AppHandler { }) } -// GetOrder is the handler for getting an order -func GetOrder(service *Service) handlers.AppHandler { +func handleGetOrder(svc *Service) handlers.AppHandler { return handlers.AppHandler(func(w http.ResponseWriter, r *http.Request) *handlers.AppError { - var orderID = new(inputs.ID) - if err := inputs.DecodeAndValidateString(context.Background(), orderID, chi.URLParam(r, "orderID")); err != nil { - return handlers.ValidationError( - "Error validating request url parameter", - map[string]interface{}{ - "orderID": err.Error(), - }, - ) - } + ctx := r.Context() - order, err := service.GetOrder(*orderID.UUID()) + orderID, err := uuid.FromString(chi.URLParamFromCtx(ctx, "orderID")) if err != nil { - return handlers.WrapError(err, "Error retrieving the order", http.StatusInternalServerError) + return handlers.ValidationError("request", map[string]interface{}{"orderID": err.Error()}) } - status := http.StatusOK - if order == nil { - status = http.StatusNotFound + order, err := svc.getTransformOrder(ctx, orderID) + if err != nil { + switch { + case errors.Is(err, context.Canceled): + return handlers.WrapError(model.ErrSomethingWentWrong, "request has been cancelled", model.StatusClientClosedConn) + + case errors.Is(err, model.ErrOrderNotFound): + return handlers.WrapError(err, "order not found", http.StatusNotFound) + + default: + return handlers.WrapError(err, "Error retrieving the order", http.StatusInternalServerError) + } } - return handlers.RenderContent(r.Context(), order, w, status) + return handlers.RenderContent(ctx, order, w, http.StatusOK) }) } diff --git a/services/skus/controllers_test.go b/services/skus/controllers_test.go index d8efc75d2..28a02d460 100644 --- a/services/skus/controllers_test.go +++ b/services/skus/controllers_test.go @@ -407,7 +407,7 @@ func (suite *ControllersTestSuite) TestGetOrder() { req, err := http.NewRequest("GET", "/v1/orders/{orderID}", nil) suite.Require().NoError(err) - getOrderHandler := GetOrder(suite.service) + getOrderHandler := handleGetOrder(suite.service) rctx := chi.NewRouteContext() rctx.URLParams.Add("orderID", order.ID.String()) getReq := req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) @@ -435,7 +435,7 @@ func (suite *ControllersTestSuite) TestGetMissingOrder() { req, err := http.NewRequest("GET", "/v1/orders/{orderID}", nil) suite.Require().NoError(err) - getOrderHandler := GetOrder(suite.service) + getOrderHandler := handleGetOrder(suite.service) rctx := chi.NewRouteContext() rctx.URLParams.Add("orderID", "9645ca16-bc93-4e37-8edf-cb35b1763216") getReq := req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) diff --git a/services/skus/model/model.go b/services/skus/model/model.go index 77856e7f0..c486701ab 100644 --- a/services/skus/model/model.go +++ b/services/skus/model/model.go @@ -13,9 +13,6 @@ import ( "github.com/lib/pq" uuid "github.com/satori/go.uuid" "github.com/shopspring/decimal" - "github.com/stripe/stripe-go/v72" - "github.com/stripe/stripe-go/v72/checkout/session" - "github.com/stripe/stripe-go/v72/customer" "github.com/brave-intl/bat-go/libs/clients/radom" "github.com/brave-intl/bat-go/libs/datastore" @@ -138,71 +135,6 @@ func (o *Order) ShouldSetTrialDays() bool { return !o.IsPaid() && o.IsStripePayable() } -// CreateStripeCheckoutSession creates a Stripe checkout session for the order. -// -// Deprecated: Use CreateStripeCheckoutSession function instead of this method. -func (o *Order) CreateStripeCheckoutSession( - email, successURI, cancelURI string, - freeTrialDays int64, -) (CreateCheckoutSessionResponse, error) { - return CreateStripeCheckoutSession(o.ID.String(), email, successURI, cancelURI, freeTrialDays, o.Items) -} - -// CreateStripeCheckoutSession creates a Stripe checkout session for the order. -func CreateStripeCheckoutSession( - oid, email, successURI, cancelURI string, - trialDays int64, - items []OrderItem, -) (CreateCheckoutSessionResponse, error) { - var custID string - if email != "" { - // Find the existing customer by email to use the customer id instead email. - l := customer.List(&stripe.CustomerListParams{ - Email: stripe.String(email), - }) - - for l.Next() { - custID = l.Customer().ID - } - } - - params := &stripe.CheckoutSessionParams{ - // TODO: Get rid of this stripe.* nonsense, and use ptrTo instead. - PaymentMethodTypes: stripe.StringSlice([]string{"card"}), - Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), - SuccessURL: stripe.String(successURI), - CancelURL: stripe.String(cancelURI), - ClientReferenceID: stripe.String(oid), - SubscriptionData: &stripe.CheckoutSessionSubscriptionDataParams{}, - LineItems: OrderItemList(items).stripeLineItems(), - } - - // If a free trial is set, apply it. - if trialDays > 0 { - params.SubscriptionData.TrialPeriodDays = &trialDays - } - - if custID != "" { - // Use existing customer if found. - params.Customer = stripe.String(custID) - } else if email != "" { - // Otherwise, create a new using email. - params.CustomerEmail = stripe.String(email) - } - // Otherwise, we have no record of this email for this checkout session. - // ? The user will be asked for the email, we cannot send an empty customer email as a param. - - params.SubscriptionData.AddMetadata("orderID", oid) - params.AddExtra("allow_promotion_codes", "true") - - session, err := session.New(params) - if err != nil { - return EmptyCreateCheckoutSessionResponse(), fmt.Errorf("failed to create stripe session: %w", err) - } - - return CreateCheckoutSessionResponse{SessionID: session.ID}, nil -} - // CreateRadomCheckoutSession creates a Radom checkout session for o. func (o *Order) CreateRadomCheckoutSession( ctx context.Context, @@ -341,6 +273,12 @@ func (o *Order) StripeSubID() (string, bool) { return sid, ok } +func (o *Order) StripeSessID() (string, bool) { + sessID, ok := o.Metadata["stripeCheckoutSessionId"].(string) + + return sessID, ok +} + func (o *Order) IsIOS() bool { pp, ok := o.PaymentProc() if !ok { @@ -369,6 +307,15 @@ func (o *Order) IsAndroid() bool { return pp == "android" && vn == VendorGoogle } +func (o *Order) IsStripe() bool { + pp, ok := o.PaymentProc() + if !ok { + return false + } + + return pp == StripePaymentMethod +} + func (o *Order) PaymentProc() (string, bool) { pp, ok := o.Metadata["paymentProcessor"].(string) @@ -417,6 +364,12 @@ func (x *OrderItem) IsLeo() bool { return x.SKU == "brave-leo-premium" } +func (x *OrderItem) StripeItemID() (string, bool) { + itemID, ok := x.Metadata["stripe_item_id"].(string) + + return itemID, ok +} + // OrderNew represents a request to create an order in the database. type OrderNew struct { MerchantID string `db:"merchant_id"` @@ -466,27 +419,6 @@ func (l OrderItemList) HasItem(id uuid.UUID) (*OrderItem, bool) { } -func (l OrderItemList) stripeLineItems() []*stripe.CheckoutSessionLineItemParams { - result := make([]*stripe.CheckoutSessionLineItemParams, 0, len(l)) - - for _, item := range l { - // Obtain the item id from the metadata. - priceID, ok := item.Metadata["stripe_item_id"].(string) - if !ok { - continue - } - - // Assume that the stripe product is embedded in macaroon as metadata - // because a stripe line item is being created. - result = append(result, &stripe.CheckoutSessionLineItemParams{ - Price: stripe.String(priceID), - Quantity: stripe.Int64(int64(item.Quantity)), - }) - } - - return result -} - type Error string func (e Error) Error() string { diff --git a/services/skus/model/model_test.go b/services/skus/model/model_test.go index 94320f73b..0f028699c 100644 --- a/services/skus/model/model_test.go +++ b/services/skus/model/model_test.go @@ -317,7 +317,7 @@ func TestOrderItemRequestNew_Unmarshal(t *testing.T) { "price": "1" }`), exp: &model.OrderItemRequestNew{ - Price: mustDecimalFromString("1"), + Price: decimal.RequireFromString("1"), }, }, @@ -379,7 +379,7 @@ func TestOrderItemRequestNew_Unmarshal(t *testing.T) { } }`), exp: &model.OrderItemRequestNew{ - Price: mustDecimalFromString("1"), + Price: decimal.RequireFromString("1"), CredentialValidDurationEach: ptrTo("P1D"), IssuanceInterval: ptrTo("P1M"), StripeMetadata: &model.ItemStripeMetadata{ @@ -764,6 +764,71 @@ func TestOrder_StripeSubID(t *testing.T) { } } +func TestOrder_StripeSessID(t *testing.T) { + type tcExpected struct { + val string + ok bool + } + + type testCase struct { + name string + given model.Order + exp tcExpected + } + + tests := []testCase{ + { + name: "no_metadata", + }, + + { + name: "no_field", + given: model.Order{ + Metadata: datastore.Metadata{"key": "value"}, + }, + }, + + { + name: "not_string", + given: model.Order{ + Metadata: datastore.Metadata{ + "stripeCheckoutSessionId": 42, + }, + }, + }, + + { + name: "empty_string", + given: model.Order{ + Metadata: datastore.Metadata{ + "stripeCheckoutSessionId": "", + }, + }, + exp: tcExpected{ok: true}, + }, + + { + name: "sess_id", + given: model.Order{ + Metadata: datastore.Metadata{ + "stripeCheckoutSessionId": "sess_id", + }, + }, + exp: tcExpected{val: "sess_id", ok: true}, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual, ok := tc.given.StripeSessID() + should.Equal(t, tc.exp.ok, ok) + should.Equal(t, tc.exp.val, actual) + }) + } +} + func TestOrder_IsIOS(t *testing.T) { type testCase struct { name string @@ -924,6 +989,64 @@ func TestOrder_IsAndroid(t *testing.T) { } } +func TestOrder_IsStripe(t *testing.T) { + type testCase struct { + name string + given model.Order + exp bool + } + + tests := []testCase{ + { + name: "no_metadata", + }, + + { + name: "no_pp", + given: model.Order{ + Metadata: datastore.Metadata{"key": "value"}, + }, + }, + + { + name: "false_pp_ios", + given: model.Order{ + Metadata: datastore.Metadata{ + "paymentProcessor": "ios", + }, + }, + }, + + { + name: "false_pp_android", + given: model.Order{ + Metadata: datastore.Metadata{ + "paymentProcessor": "android", + }, + }, + }, + + { + name: "pp_stripe", + given: model.Order{ + Metadata: datastore.Metadata{ + "paymentProcessor": "stripe", + }, + }, + exp: true, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := tc.given.IsStripe() + should.Equal(t, tc.exp, actual) + }) + } +} + func TestOrder_PaymentProc(t *testing.T) { type tcExpected struct { val string @@ -1143,13 +1266,69 @@ func TestOrder_ShouldSetTrialDays(t *testing.T) { } } -func mustDecimalFromString(v string) decimal.Decimal { - result, err := decimal.NewFromString(v) - if err != nil { - panic(err) +func TestOrderItem_StripeItemID(t *testing.T) { + type tcExpected struct { + val string + ok bool + } + + type testCase struct { + name string + given model.OrderItem + exp tcExpected + } + + tests := []testCase{ + { + name: "no_metadata", + }, + + { + name: "no_field", + given: model.OrderItem{ + Metadata: datastore.Metadata{"key": "value"}, + }, + }, + + { + name: "not_string", + given: model.OrderItem{ + Metadata: datastore.Metadata{ + "stripe_item_id": 42, + }, + }, + }, + + { + name: "empty_string", + given: model.OrderItem{ + Metadata: datastore.Metadata{ + "stripe_item_id": "", + }, + }, + exp: tcExpected{ok: true}, + }, + + { + name: "stripe_item_id", + given: model.OrderItem{ + Metadata: datastore.Metadata{ + "stripe_item_id": "stripe_item_id", + }, + }, + exp: tcExpected{val: "stripe_item_id", ok: true}, + }, } - return result + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual, ok := tc.given.StripeItemID() + should.Equal(t, tc.exp.ok, ok) + should.Equal(t, tc.exp.val, actual) + }) + } } func ptrTo[T any](v T) *T { diff --git a/services/skus/order.go b/services/skus/order.go index 0d68a3427..e73d38fa8 100644 --- a/services/skus/order.go +++ b/services/skus/order.go @@ -9,7 +9,6 @@ import ( "time" "github.com/shopspring/decimal" - "github.com/stripe/stripe-go/v72" "gopkg.in/macaroon.v2" "github.com/brave-intl/bat-go/libs/logging" @@ -155,18 +154,3 @@ func (s *Service) CreateOrderItemFromMacaroon(ctx context.Context, sku string, q return &orderItem, allowedPaymentMethods, issuerConfig, nil } - -func getCustEmailFromStripeCheckout(sess *stripe.CheckoutSession) string { - // Use the customer email if the customer has completed the payment flow. - if sess.Customer != nil && sess.Customer.Email != "" { - return sess.Customer.Email - } - - // This is unlikely to be set, but in case it is, use it. - if sess.CustomerEmail != "" { - return sess.CustomerEmail - } - - // Default to empty, Stripe will ask the customer. - return "" -} diff --git a/services/skus/service.go b/services/skus/service.go index 49d972f7d..fe7c8b0de 100644 --- a/services/skus/service.go +++ b/services/skus/service.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "net/http" - "net/url" "os" "strconv" "strings" @@ -23,7 +22,6 @@ import ( "github.com/segmentio/kafka-go" "github.com/shopspring/decimal" "github.com/stripe/stripe-go/v72" - "github.com/stripe/stripe-go/v72/checkout/session" "github.com/stripe/stripe-go/v72/client" "github.com/stripe/stripe-go/v72/sub" "google.golang.org/api/idtoken" @@ -48,6 +46,7 @@ import ( "github.com/brave-intl/bat-go/services/wallet" "github.com/brave-intl/bat-go/services/skus/model" + "github.com/brave-intl/bat-go/services/skus/xstripe" ) var ( @@ -102,6 +101,7 @@ type orderStoreSvc interface { AppendMetadata(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error AppendMetadataInt(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key string, val int) error AppendMetadataInt64(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key string, val int64) error + GetExpiredStripeCheckoutSessionID(ctx context.Context, dbi sqlx.QueryerContext, orderID uuid.UUID) (string, error) } type tlv2Store interface { @@ -120,6 +120,13 @@ type gpsMessageAuthenticator interface { authenticate(ctx context.Context, token string) error } +type stripeClient interface { + Session(ctx context.Context, id string, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) + CreateSession(ctx context.Context, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) + Subscription(ctx context.Context, id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) + FindCustomer(ctx context.Context, email string) (*stripe.Customer, bool) +} + // Service contains datastore type Service struct { orderRepo orderStoreSvc @@ -136,6 +143,7 @@ type Service struct { geminiClient gemini.Client geminiConf *gemini.Conf scClient *client.API + stpClient stripeClient codecs map[string]*goavro.Codec kafkaWriter *kafka.Writer kafkaDialer *kafka.Dialer @@ -211,13 +219,14 @@ func InitService( // setup stripe if exists in context and enabled scClient := &client.API{} if enabled, ok := ctx.Value(appctx.StripeEnabledCTXKey).(bool); ok && enabled { - sublogger.Debug().Msg("stripe enabled") + stripe.EnableTelemetry = false + var err error stripe.Key, err = appctx.GetStringFromContext(ctx, appctx.StripeSecretCTXKey) if err != nil { sublogger.Panic().Err(err).Msg("failed to get Stripe secret from context, and Stripe enabled") } - // initialize stripe client + scClient.Init(stripe.Key, nil) } @@ -344,6 +353,7 @@ func InitService( geminiConf: geminiConf, cbClient: cbClient, scClient: scClient, + stpClient: xstripe.NewClient(scClient), pauseVoteUntilMu: sync.RWMutex{}, retry: backoff.Retry, radomClient: radomClient, @@ -532,25 +542,6 @@ func (s *Service) CreateOrderFromRequest(ctx context.Context, req model.CreateOr } defer func() { _ = tx2.Rollback() }() - if !order.IsPaid() { - // TODO: Remove this after confirming no calls are made to this for Premium orders. - if order.IsStripePayable() { - session, err := order.CreateStripeCheckoutSession( - req.Email, - parseURLAddOrderIDParam(stripeSuccessURI, order.ID), - parseURLAddOrderIDParam(stripeCancelURI, order.ID), - order.GetTrialDays(), - ) - if err != nil { - return nil, fmt.Errorf("failed to create checkout session: %w", err) - } - - if err := s.orderRepo.AppendMetadata(ctx, tx2, order.ID, "stripeCheckoutSessionId", session.SessionID); err != nil { - return nil, fmt.Errorf("failed to update order metadata: %w", err) - } - } - } - if numIntervals > 0 { if err := s.orderRepo.AppendMetadataInt(ctx, tx2, order.ID, "numIntervals", numIntervals); err != nil { return nil, fmt.Errorf("failed to update order metadata: %w", err) @@ -570,93 +561,98 @@ func (s *Service) CreateOrderFromRequest(ctx context.Context, req model.CreateOr return order, nil } -// GetOrder - business logic for getting an order, needs to validate the checkout session is not expired -func (s *Service) GetOrder(orderID uuid.UUID) (*Order, error) { - // get the order - order, err := s.Datastore.GetOrder(orderID) +func (s *Service) getTransformOrder(ctx context.Context, orderID uuid.UUID) (*Order, error) { + tx, err := s.Datastore.RawDB().BeginTxx(ctx, nil) if err != nil { - return nil, fmt.Errorf("failed to get order (%s): %w", orderID.String(), err) + return nil, err } + defer func() { _ = tx.Rollback() }() - if order != nil { - if !order.IsPaid() && order.IsStripePayable() { - order, err = s.TransformStripeOrder(order) - if err != nil { - return nil, fmt.Errorf("failed to transform stripe order (%s): %w", orderID.String(), err) - } - } + result, err := s.getTransformOrderTx(ctx, tx, orderID) + if err != nil { + return nil, err } - return order, nil + if err := tx.Commit(); err != nil { + return nil, err + } + return result, nil } -// TransformStripeOrder updates checkout session if expired, checks the status of the checkout session. -func (s *Service) TransformStripeOrder(order *Order) (*Order, error) { - ctx := context.Background() - - // check if this order has an expired checkout session - expired, cs, err := s.Datastore.CheckExpiredCheckoutSession(order.ID) +func (s *Service) getTransformOrderTx(ctx context.Context, dbi sqlx.ExtContext, orderID uuid.UUID) (*Order, error) { + ord, err := s.getOrderFullTx(ctx, dbi, orderID) if err != nil { - return nil, fmt.Errorf("failed to check for expired stripe checkout session: %w", err) + return nil, fmt.Errorf("failed to get order (%s): %w", orderID.String(), err) } - if expired { - // get old checkout session from stripe by id - stripeSession, err := session.Get(cs, nil) - if err != nil { - return nil, fmt.Errorf("failed to get stripe checkout session: %w", err) - } + // Nothing more to do for orders with a Stripe subscription. + if _, ok := ord.StripeSubID(); ok { + return ord, nil + } - checkoutSession, err := order.CreateStripeCheckoutSession( - getCustEmailFromStripeCheckout(stripeSession), - stripeSession.SuccessURL, stripeSession.CancelURL, - order.GetTrialDays(), - ) - if err != nil { - return nil, fmt.Errorf("failed to create checkout session: %w", err) - } + if !shouldTransformStripeOrder(ord) { + return ord, nil + } - err = s.Datastore.AppendOrderMetadata(ctx, &order.ID, "stripeCheckoutSessionId", checkoutSession.SessionID) - if err != nil { - return nil, fmt.Errorf("failed to update order metadata: %w", err) - } + if err := s.updateOrderStripeSession(ctx, dbi, ord); err != nil { + return nil, fmt.Errorf("failed to transform stripe order (%s): %w", orderID.String(), err) } - // if this is a stripe order, and there is a checkout session, we actually need to check it with - // stripe, as the redirect flow sometimes is too fast for the webhook to be delivered. - // exclude any order with a subscription identifier from stripe - if _, sOK := order.Metadata["stripeSubscriptionId"]; !sOK { - if cs, ok := order.Metadata["stripeCheckoutSessionId"].(string); ok && cs != "" { - // get old checkout session from stripe by id - sess, err := session.Get(cs, nil) - if err != nil { - return nil, fmt.Errorf("failed to get stripe checkout session: %w", err) - } + return s.getOrderFullTx(ctx, dbi, orderID) +} - // Set status to paid and the subscription id and if the session is actually paid. - if sess.PaymentStatus == "paid" { - if err = s.Datastore.UpdateOrder(order.ID, "paid"); err != nil { - return nil, fmt.Errorf("failed to update order to paid status: %w", err) - } +// updateOrderStripeSession checks the status of the checkout session, updates it if expired. +func (s *Service) updateOrderStripeSession(ctx context.Context, dbi sqlx.ExtContext, ord *Order) error { + expSessID, err := s.orderRepo.GetExpiredStripeCheckoutSessionID(ctx, dbi, ord.ID) + if err != nil && !errors.Is(err, model.ErrExpiredStripeCheckoutSessionIDNotFound) { + return fmt.Errorf("failed to check for expired stripe checkout session: %w", err) + } - if err := s.Datastore.AppendOrderMetadata(ctx, &order.ID, "stripeSubscriptionId", sess.Subscription.ID); err != nil { - return nil, fmt.Errorf("failed to update order to add the subscription id") - } + var sessID string - if err := s.Datastore.AppendOrderMetadata(ctx, &order.ID, "paymentProcessor", model.StripePaymentMethod); err != nil { - return nil, fmt.Errorf("failed to update order to add the payment processor") - } - } + if expSessID != "" { + nsessID, err := s.recreateStripeSession(ctx, dbi, ord, expSessID) + if err != nil { + return fmt.Errorf("failed to create checkout session: %w", err) } + + sessID = nsessID + } + + // Below goes some leagcy stuff. + // There was also a bug where the old subscription would be tested for payment. + // The code below did not take into account that the session could have been updated just above. + // + // If this is a stripe order, and there is a checkout session, check it with Stripe. + // The redirect flow sometimes is too fast for the webhook to be delivered. + sessID, ok := chooseStripeSessID(ord, sessID) + if !ok || sessID == "" { + // Nothing to do here. + return nil } - result, err := s.Datastore.GetOrder(order.ID) + sess, err := s.stpClient.Session(ctx, sessID, nil) if err != nil { - return nil, fmt.Errorf("failed to get order: %w", err) + return fmt.Errorf("failed to get stripe checkout session: %w", err) } - return result, nil + // Skip unpaid sessions. + if sess.PaymentStatus != "paid" { + return nil + } + + // Need to update the order as paid. + // This requires fetching the subscription as the expiry time is needed. + sub, err := s.stpClient.Subscription(ctx, sess.Subscription.ID, nil) + if err != nil { + return err + } + + expt := time.Unix(sub.CurrentPeriodEnd, 0).UTC() + paidt := time.Unix(sub.CurrentPeriodStart, 0).UTC() + + return s.renewOrderStripe(ctx, dbi, ord, sub.ID, expt, paidt) } // CancelOrder cancels an order, propagates to stripe if needed. @@ -727,28 +723,14 @@ func (s *Service) SetOrderTrialDays(ctx context.Context, orderID *uuid.UUID, day return nil } - // Recreate the stripe checkout session. - oldSessID, ok := ord.Metadata["stripeCheckoutSessionId"].(string) + oldSessID, ok := ord.StripeSessID() if !ok { return model.ErrNoStripeCheckoutSessID } - sess, err := session.Get(oldSessID, nil) - if err != nil { - return fmt.Errorf("failed to get stripe checkout session: %w", err) - } + _, err = s.recreateStripeSession(ctx, s.Datastore.RawDB(), ord, oldSessID) - cs, err := ord.CreateStripeCheckoutSession(getCustEmailFromStripeCheckout(sess), sess.SuccessURL, sess.CancelURL, ord.GetTrialDays()) - if err != nil { - return fmt.Errorf("failed to create checkout session: %w", err) - } - - // Overwrite the old checkout session. - if err := s.Datastore.AppendOrderMetadata(ctx, &ord.ID, "stripeCheckoutSessionId", cs.SessionID); err != nil { - return fmt.Errorf("failed to update order metadata: %w", err) - } - - return nil + return err } // UpdateOrderStatus checks to see if an order has been paid and updates it if so @@ -1071,19 +1053,6 @@ func (s *Service) IsOrderPaid(orderID uuid.UUID) (bool, error) { return sum.GreaterThanOrEqual(order.TotalPrice), nil } -func parseURLAddOrderIDParam(u string, orderID uuid.UUID) string { - // add order id to the stripe success and cancel urls - surl, err := url.Parse(u) - if err == nil { - surlv := surl.Query() - surlv.Add("order_id", orderID.String()) - surl.RawQuery = surlv.Encode() - return surl.String() - } - // there was a parse error, return whatever was given - return u -} - // UniqBatches returns the limit for active batches and the current number of active batches. func (s *Service) UniqBatches(ctx context.Context, orderID, itemID uuid.UUID) (int, int, error) { now := time.Now() @@ -1686,27 +1655,14 @@ func (s *Service) processStripeNotificationTx(ctx context.Context, dbi sqlx.ExtC return err } - if shouldUpdateOrderStripeSubID(ord, subID) { - if err := s.orderRepo.AppendMetadata(ctx, dbi, oid, "stripeSubscriptionId", subID); err != nil { - return err - } - } - expt, err := ntf.expiresTime() if err != nil { return err } - // Add 1-day leeway in case next billing cycle's webhook gets delayed. - expt = expt.Add(24 * time.Hour) - paidt := time.Now() - if err := s.renewOrderWithExpPaidTimeTx(ctx, dbi, ord.ID, expt, paidt); err != nil { - return err - } - - return s.orderRepo.AppendMetadata(ctx, dbi, ord.ID, "paymentProcessor", model.StripePaymentMethod) + return s.renewOrderStripe(ctx, dbi, ord, subID, expt, paidt) case ntf.shouldCancel(): oid, err := ntf.orderID() @@ -1884,7 +1840,7 @@ func (s *Service) createOrderPremium(ctx context.Context, req *model.CreateOrder if !order.IsPaid() { switch { case order.IsStripePayable(): - ssid, err := s.createStripeSessID(ctx, req, order) + ssid, err := s.createStripeSession(ctx, req, order) if err != nil { return nil, err } @@ -1964,7 +1920,7 @@ func (s *Service) createOrderIssuers(ctx context.Context, dbi sqlx.QueryerContex return numIntervals, nil } -func (s *Service) createStripeSessID(ctx context.Context, req *model.CreateOrderRequestNew, order *model.Order) (string, error) { +func (s *Service) createStripeSession(ctx context.Context, req *model.CreateOrderRequestNew, order *model.Order) (string, error) { oid := order.ID.String() surl, err := req.StripeMetadata.SuccessURL(oid) @@ -1977,12 +1933,16 @@ func (s *Service) createStripeSessID(ctx context.Context, req *model.CreateOrder return "", err } - sess, err := model.CreateStripeCheckoutSession(oid, req.Email, surl, curl, order.GetTrialDays(), order.Items) - if err != nil { - return "", fmt.Errorf("failed to create checkout session: %w", err) + sreq := createStripeSessionRequest{ + orderID: oid, + email: req.Email, + successURL: surl, + cancelURL: curl, + trialDays: order.GetTrialDays(), + items: buildStripeLineItems(order.Items), } - return sess.SessionID, nil + return createStripeSession(ctx, s.stpClient, sreq) } // TODO: Refactor the Radom-related logic. @@ -2067,7 +2027,7 @@ func (s *Service) renewOrderWithExpPaidTime(ctx context.Context, id uuid.UUID, e // renewOrderWithExpPaidTimeTx performs updates relevant to advancing a paid order forward after renewal. // // TODO: Add a repo method to update all three fields at once. -func (s *Service) renewOrderWithExpPaidTimeTx(ctx context.Context, dbi sqlx.ExtContext, id uuid.UUID, expt, paidt time.Time) error { +func (s *Service) renewOrderWithExpPaidTimeTx(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, expt, paidt time.Time) error { if err := s.orderRepo.SetStatus(ctx, dbi, id, model.OrderStatusPaid); err != nil { return err } @@ -2318,6 +2278,55 @@ func createOrderWithReceipt( return order, nil } +func (s *Service) renewOrderStripe(ctx context.Context, dbi sqlx.ExecerContext, ord *model.Order, subID string, expt, paidt time.Time) error { + if shouldUpdateOrderStripeSubID(ord, subID) { + if err := s.orderRepo.AppendMetadata(ctx, dbi, ord.ID, "stripeSubscriptionId", subID); err != nil { + return err + } + } + + // Add 1-day leeway in case next billing cycle's webhook gets delayed. + expt = expt.Add(24 * time.Hour) + + if err := s.renewOrderWithExpPaidTimeTx(ctx, dbi, ord.ID, expt, paidt); err != nil { + return err + } + + // Skip updating payment processor if it's already Stripe. + if ord.IsStripe() { + return nil + } + + return s.orderRepo.AppendMetadata(ctx, dbi, ord.ID, "paymentProcessor", model.StripePaymentMethod) +} + +func (s *Service) recreateStripeSession(ctx context.Context, dbi sqlx.ExecerContext, ord *model.Order, oldSessID string) (string, error) { + oldSess, err := s.stpClient.Session(ctx, oldSessID, nil) + if err != nil { + return "", err + } + + req := createStripeSessionRequest{ + orderID: ord.ID.String(), + email: xstripe.CustomerEmailFromSession(oldSess), + successURL: oldSess.SuccessURL, + cancelURL: oldSess.CancelURL, + trialDays: ord.GetTrialDays(), + items: buildStripeLineItems(ord.Items), + } + + sessID, err := createStripeSession(ctx, s.stpClient, req) + if err != nil { + return "", err + } + + if err := s.orderRepo.AppendMetadata(ctx, dbi, ord.ID, "stripeCheckoutSessionId", sessID); err != nil { + return "", err + } + + return sessID, nil +} + func newOrderNewForReq(req *model.CreateOrderRequestNew, items []model.OrderItem, merchID, status string) (*model.OrderNew, error) { // Check for number of items to be above 0. // @@ -2497,3 +2506,84 @@ func shouldUpdateOrderStripeSubID(ord *model.Order, subID string) bool { return false } + +func shouldTransformStripeOrder(ord *model.Order) bool { + if ord.IsIOS() { + return false + } + + if ord.IsAndroid() { + return false + } + + return !ord.IsPaid() && ord.IsStripePayable() +} + +func chooseStripeSessID(ord *model.Order, canBeNewSessID string) (string, bool) { + if canBeNewSessID != "" { + return canBeNewSessID, true + } + + return ord.StripeSessID() +} + +type createStripeSessionRequest struct { + orderID string + email string + successURL string + cancelURL string + trialDays int64 + items []*stripe.CheckoutSessionLineItemParams +} + +func createStripeSession(ctx context.Context, cl stripeClient, req createStripeSessionRequest) (string, error) { + params := &stripe.CheckoutSessionParams{ + PaymentMethodTypes: []*string{ptrTo("card")}, + Mode: ptrTo(string(stripe.CheckoutSessionModeSubscription)), + SuccessURL: &req.successURL, + CancelURL: &req.cancelURL, + ClientReferenceID: &req.orderID, + SubscriptionData: &stripe.CheckoutSessionSubscriptionDataParams{}, + LineItems: req.items, + } + + if custID, ok := cl.FindCustomer(ctx, req.email); ok { + params.Customer = &custID.ID + } else { + if req.email != "" { + params.CustomerEmail = &req.email + } + } + + if req.trialDays > 0 { + params.SubscriptionData.TrialPeriodDays = &req.trialDays + } + + params.SubscriptionData.AddMetadata("orderID", req.orderID) + params.AddExtra("allow_promotion_codes", "true") + + sess, err := cl.CreateSession(ctx, params) + if err != nil { + return "", err + } + + return sess.ID, nil +} + +func buildStripeLineItems(items []model.OrderItem) []*stripe.CheckoutSessionLineItemParams { + var result []*stripe.CheckoutSessionLineItemParams + + for i := range items { + priceID, ok := items[i].StripeItemID() + if !ok { + continue + } + + result = append(result, &stripe.CheckoutSessionLineItemParams{ + Price: ptrTo(priceID), + Quantity: ptrTo(int64(items[i].Quantity)), + }) + } + + return result +} diff --git a/services/skus/service_nonint_test.go b/services/skus/service_nonint_test.go index fde41b141..63ef0a52b 100644 --- a/services/skus/service_nonint_test.go +++ b/services/skus/service_nonint_test.go @@ -24,6 +24,7 @@ import ( "github.com/brave-intl/bat-go/services/skus/model" "github.com/brave-intl/bat-go/services/skus/storage/repository" + "github.com/brave-intl/bat-go/services/skus/xstripe" ) func TestService_uniqBatchesTxTime(t *testing.T) { @@ -2327,7 +2328,7 @@ func TestService_processStripeNotificationTx(t *testing.T) { }, { - name: "renew_should_update_sub_id_error", + name: "renew_expires_time_error", given: tcGiven{ ntf: &stripeNotification{ raw: &stripe.Event{Type: "invoice.paid"}, @@ -2347,25 +2348,20 @@ func TestService_processStripeNotificationTx(t *testing.T) { ordRepo: &repository.MockOrder{ FnGet: func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) { result := &model.Order{ - ID: uuid.Must(uuid.FromString("facade00-0000-4000-a000-000000000000")), - Metadata: datastore.Metadata{ - "stripeSubscriptionId": "wrong_sub_id", - }, + ID: uuid.Must(uuid.FromString("facade00-0000-4000-a000-000000000000")), + Metadata: datastore.Metadata{}, } return result, nil }, - FnAppendMetadata: func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error { - return model.Error("something_went_wrong") - }, }, phRepo: &repository.MockOrderPayHistory{}, }, - exp: model.Error("something_went_wrong"), + exp: errStripeInvalidSubPeriod, }, { - name: "renew_expires_time_error", + name: "renew_should_update_sub_id_error", given: tcGiven{ ntf: &stripeNotification{ raw: &stripe.Event{Type: "invoice.paid"}, @@ -2377,6 +2373,10 @@ func TestService_processStripeNotificationTx(t *testing.T) { Metadata: map[string]string{ "orderID": "facade00-0000-4000-a000-000000000000", }, + Period: &stripe.Period{ + Start: 1719792001, + End: 1722470400, + }, }, }, }, @@ -2385,16 +2385,21 @@ func TestService_processStripeNotificationTx(t *testing.T) { ordRepo: &repository.MockOrder{ FnGet: func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) { result := &model.Order{ - ID: uuid.Must(uuid.FromString("facade00-0000-4000-a000-000000000000")), - Metadata: datastore.Metadata{}, + ID: uuid.Must(uuid.FromString("facade00-0000-4000-a000-000000000000")), + Metadata: datastore.Metadata{ + "stripeSubscriptionId": "wrong_sub_id", + }, } return result, nil }, + FnAppendMetadata: func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error { + return model.Error("something_went_wrong") + }, }, phRepo: &repository.MockOrderPayHistory{}, }, - exp: errStripeInvalidSubPeriod, + exp: model.Error("something_went_wrong"), }, { @@ -2624,6 +2629,246 @@ func TestService_processStripeNotificationTx(t *testing.T) { } } +func TestService_renewOrderStripe(t *testing.T) { + type tcGiven struct { + ordRepo *repository.MockOrder + payRepo *repository.MockOrderPayHistory + ord *model.Order + subID string + expt time.Time + paidt time.Time + } + + type testCase struct { + name string + given tcGiven + exp error + } + + tests := []testCase{ + { + name: "should_update_sub_id_error", + given: tcGiven{ + ordRepo: &repository.MockOrder{ + FnAppendMetadata: func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error { + return model.Error("something_went_wrong") + }, + }, + payRepo: &repository.MockOrderPayHistory{}, + ord: &model.Order{ + Metadata: datastore.Metadata{ + "stripeSubscriptionId": "old_sub_id", + }, + }, + subID: "sub_id", + expt: time.Date(2024, time.July, 1, 0, 0, 0, 0, time.UTC), + paidt: time.Date(2024, time.June, 1, 0, 0, 1, 0, time.UTC), + }, + exp: model.Error("something_went_wrong"), + }, + + { + name: "renew_order_error", + given: tcGiven{ + ordRepo: &repository.MockOrder{ + FnSetStatus: func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, status string) error { + return model.Error("something_went_wrong") + }, + }, + payRepo: &repository.MockOrderPayHistory{}, + ord: &model.Order{ + Metadata: datastore.Metadata{ + "stripeSubscriptionId": "old_sub_id", + }, + }, + subID: "sub_id", + expt: time.Date(2024, time.July, 1, 0, 0, 0, 0, time.UTC), + paidt: time.Date(2024, time.June, 1, 0, 0, 1, 0, time.UTC), + }, + exp: model.Error("something_went_wrong"), + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + svc := &Service{orderRepo: tc.given.ordRepo, payHistRepo: tc.given.payRepo} + + ctx := context.Background() + + actual := svc.renewOrderStripe(ctx, nil, tc.given.ord, tc.given.subID, tc.given.expt, tc.given.paidt) + should.Equal(t, tc.exp, actual) + }) + } +} + +func TestService_recreateStripeSession(t *testing.T) { + type tcGiven struct { + ordRepo *repository.MockOrder + cl *xstripe.MockClient + ord *model.Order + oldSessID string + } + + type tcExpected struct { + val string + err error + } + + type testCase struct { + name string + given tcGiven + exp tcExpected + } + + tests := []testCase{ + { + name: "unable_fetch_old_session", + given: tcGiven{ + ordRepo: &repository.MockOrder{}, + cl: &xstripe.MockClient{ + FnSession: func(ctx context.Context, id string, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + return nil, model.Error("something_went_wrong") + }, + }, + ord: &model.Order{}, + oldSessID: "cs_test_id_old", + }, + exp: tcExpected{ + err: model.Error("something_went_wrong"), + }, + }, + + { + name: "unable_create_session", + given: tcGiven{ + ordRepo: &repository.MockOrder{}, + cl: &xstripe.MockClient{ + FnSession: func(ctx context.Context, id string, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + result := &stripe.CheckoutSession{ + ID: "cs_test_id_old", + SuccessURL: "https://example.com/success", + CancelURL: "https://example.com/cancel", + Customer: &stripe.Customer{Email: "you@example.com"}, + } + + return result, nil + }, + + FnCreateSession: func(ctx context.Context, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + return nil, model.Error("something_went_wrong") + }, + }, + ord: &model.Order{ + ID: uuid.Must(uuid.FromString("facade00-0000-4000-a000-000000000000")), + Items: []model.OrderItem{ + { + Quantity: 1, + Metadata: datastore.Metadata{"stripe_item_id": "stripe_item_id"}, + }, + }, + }, + oldSessID: "cs_test_id_old", + }, + exp: tcExpected{ + err: model.Error("something_went_wrong"), + }, + }, + + { + name: "unable_append_metadata", + given: tcGiven{ + ordRepo: &repository.MockOrder{ + FnAppendMetadata: func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error { + return model.Error("something_went_wrong") + }, + }, + cl: &xstripe.MockClient{ + FnSession: func(ctx context.Context, id string, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + result := &stripe.CheckoutSession{ + ID: "cs_test_id_old", + SuccessURL: "https://example.com/success", + CancelURL: "https://example.com/cancel", + Customer: &stripe.Customer{Email: "you@example.com"}, + } + + return result, nil + }, + }, + ord: &model.Order{ + ID: uuid.Must(uuid.FromString("facade00-0000-4000-a000-000000000000")), + Items: []model.OrderItem{ + { + Quantity: 1, + Metadata: datastore.Metadata{"stripe_item_id": "stripe_item_id"}, + }, + }, + }, + oldSessID: "cs_test_id_old", + }, + exp: tcExpected{ + err: model.Error("something_went_wrong"), + }, + }, + + { + name: "success", + given: tcGiven{ + ordRepo: &repository.MockOrder{ + FnAppendMetadata: func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error { + if key == "stripeCheckoutSessionId" && val == "cs_test_id" { + return nil + } + + return model.Error("unexpected") + }, + }, + cl: &xstripe.MockClient{ + FnSession: func(ctx context.Context, id string, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + result := &stripe.CheckoutSession{ + ID: "cs_test_id_old", + SuccessURL: "https://example.com/success", + CancelURL: "https://example.com/cancel", + Customer: &stripe.Customer{Email: "you@example.com"}, + } + + return result, nil + }, + }, + ord: &model.Order{ + ID: uuid.Must(uuid.FromString("facade00-0000-4000-a000-000000000000")), + Items: []model.OrderItem{ + { + Quantity: 1, + Metadata: datastore.Metadata{"stripe_item_id": "stripe_item_id"}, + }, + }, + }, + oldSessID: "cs_test_id_old", + }, + exp: tcExpected{ + val: "cs_test_id", + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + svc := &Service{orderRepo: tc.given.ordRepo, stpClient: tc.given.cl} + + ctx := context.Background() + + actual, err := svc.recreateStripeSession(ctx, nil, tc.given.ord, tc.given.oldSessID) + must.Equal(t, tc.exp.err, err) + + should.Equal(t, tc.exp.val, actual) + }) + } +} + func TestShouldUpdateOrderStripeSubID(t *testing.T) { type tcGiven struct { ord *model.Order @@ -2695,6 +2940,406 @@ func TestShouldUpdateOrderStripeSubID(t *testing.T) { } } +func TestShouldTransformStripeOrder(t *testing.T) { + type testCase struct { + name string + given *model.Order + exp bool + } + + tests := []testCase{ + { + name: "false_ios", + given: &model.Order{ + Metadata: datastore.Metadata{ + "paymentProcessor": "ios", + "vendor": "ios", + }, + }, + }, + + { + name: "false_android", + given: &model.Order{ + Metadata: datastore.Metadata{ + "paymentProcessor": "android", + "vendor": "android", + }, + }, + }, + + { + name: "false_paid", + given: &model.Order{Status: model.OrderStatusPaid}, + }, + + { + name: "false_non_stripe", + given: &model.Order{Status: model.OrderStatusPending}, + }, + + { + name: "true_unpaid_stripe", + given: &model.Order{ + Status: model.OrderStatusPending, + AllowedPaymentMethods: pq.StringArray([]string{"stripe"}), + }, + exp: true, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := shouldTransformStripeOrder(tc.given) + should.Equal(t, tc.exp, actual) + }) + } +} + +func TestChooseStripeSessID(t *testing.T) { + type tcGiven struct { + ord *model.Order + newSessID string + } + + type tcExpected struct { + val string + ok bool + } + + type testCase struct { + name string + given tcGiven + exp tcExpected + } + + tests := []testCase{ + { + name: "new_sess_id_no_old_sess_id", + given: tcGiven{ + ord: &model.Order{}, + newSessID: "new_sess_id", + }, + exp: tcExpected{ + val: "new_sess_id", + ok: true, + }, + }, + + { + name: "new_sess_id_old_sess_id", + given: tcGiven{ + ord: &model.Order{ + Metadata: datastore.Metadata{ + "stripeCheckoutSessionId": "sess_id", + }, + }, + newSessID: "new_sess_id", + }, + exp: tcExpected{ + val: "new_sess_id", + ok: true, + }, + }, + + { + name: "no_new_sess_id_no_old_sess_id", + given: tcGiven{ + ord: &model.Order{}, + }, + exp: tcExpected{}, + }, + + { + name: "new_sess_id_old_sess_id", + given: tcGiven{ + ord: &model.Order{ + Metadata: datastore.Metadata{ + "stripeCheckoutSessionId": "sess_id", + }, + }, + }, + exp: tcExpected{ + val: "sess_id", + ok: true, + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual, ok := chooseStripeSessID(tc.given.ord, tc.given.newSessID) + should.Equal(t, tc.exp.ok, ok) + + should.Equal(t, tc.exp.val, actual) + }) + } +} + +func TestCreateStripeSession(t *testing.T) { + type tcGiven struct { + cl *xstripe.MockClient + req createStripeSessionRequest + } + + type tcExpected struct { + val string + err error + } + + type testCase struct { + name string + given tcGiven + exp tcExpected + } + + tests := []testCase{ + { + name: "success_found_customer", + given: tcGiven{ + cl: &xstripe.MockClient{ + FnCreateSession: func(ctx context.Context, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + if params.Customer == nil || *params.Customer != "cus_id" { + return nil, model.Error("unexpected") + } + + result := &stripe.CheckoutSession{ID: "cs_test_id"} + + return result, nil + }, + }, + + req: createStripeSessionRequest{ + orderID: "facade00-0000-4000-a000-000000000000", + email: "you@example.com", + successURL: "https://example.com/success", + cancelURL: "https://example.com/cancel", + trialDays: 7, + items: []*stripe.CheckoutSessionLineItemParams{ + { + Quantity: ptrTo[int64](1), + Price: ptrTo("stripe_item_id"), + }, + }, + }, + }, + exp: tcExpected{ + val: "cs_test_id", + }, + }, + + { + name: "success_customer_not_found", + given: tcGiven{ + cl: &xstripe.MockClient{ + FnFindCustomer: func(ctx context.Context, email string) (*stripe.Customer, bool) { + return nil, false + }, + + FnCreateSession: func(ctx context.Context, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + if params.CustomerEmail == nil || *params.CustomerEmail != "you@example.com" { + return nil, model.Error("unexpected") + } + + result := &stripe.CheckoutSession{ID: "cs_test_id"} + + return result, nil + }, + }, + + req: createStripeSessionRequest{ + orderID: "facade00-0000-4000-a000-000000000000", + email: "you@example.com", + successURL: "https://example.com/success", + cancelURL: "https://example.com/cancel", + trialDays: 7, + items: []*stripe.CheckoutSessionLineItemParams{ + { + Quantity: ptrTo[int64](1), + Price: ptrTo("stripe_item_id"), + }, + }, + }, + }, + exp: tcExpected{ + val: "cs_test_id", + }, + }, + + { + name: "success_no_customer_email", + given: tcGiven{ + cl: &xstripe.MockClient{ + FnFindCustomer: func(ctx context.Context, email string) (*stripe.Customer, bool) { + return nil, false + }, + }, + + req: createStripeSessionRequest{ + orderID: "facade00-0000-4000-a000-000000000000", + successURL: "https://example.com/success", + cancelURL: "https://example.com/cancel", + trialDays: 7, + items: []*stripe.CheckoutSessionLineItemParams{ + { + Quantity: ptrTo[int64](1), + Price: ptrTo("stripe_item_id"), + }, + }, + }, + }, + exp: tcExpected{ + val: "cs_test_id", + }, + }, + + { + name: "success_no_trial_days", + given: tcGiven{ + cl: &xstripe.MockClient{}, + + req: createStripeSessionRequest{ + orderID: "facade00-0000-4000-a000-000000000000", + email: "you@example.com", + successURL: "https://example.com/success", + cancelURL: "https://example.com/cancel", + items: []*stripe.CheckoutSessionLineItemParams{ + { + Quantity: ptrTo[int64](1), + Price: ptrTo("stripe_item_id"), + }, + }, + }, + }, + exp: tcExpected{ + val: "cs_test_id", + }, + }, + + { + name: "create_error", + given: tcGiven{ + cl: &xstripe.MockClient{ + FnCreateSession: func(ctx context.Context, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + return nil, model.Error("something_went_wrong") + }, + }, + + req: createStripeSessionRequest{ + orderID: "facade00-0000-4000-a000-000000000000", + email: "you@example.com", + successURL: "https://example.com/success", + cancelURL: "https://example.com/cancel", + trialDays: 7, + items: []*stripe.CheckoutSessionLineItemParams{ + { + Quantity: ptrTo[int64](1), + Price: ptrTo("stripe_item_id"), + }, + }, + }, + }, + exp: tcExpected{ + err: model.Error("something_went_wrong"), + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + actual, err := createStripeSession(ctx, tc.given.cl, tc.given.req) + must.Equal(t, tc.exp.err, err) + + should.Equal(t, tc.exp.val, actual) + }) + } +} + +func TestBuildStripeLineItems(t *testing.T) { + tests := []struct { + name string + given []model.OrderItem + exp []*stripe.CheckoutSessionLineItemParams + }{ + { + name: "nil", + }, + + { + name: "empty_nil", + given: []model.OrderItem{}, + }, + + { + name: "empty_no_price_id", + given: []model.OrderItem{ + { + Metadata: datastore.Metadata{"key": "value"}, + }, + }, + }, + + { + name: "one_item", + given: []model.OrderItem{ + { + Quantity: 1, + Metadata: datastore.Metadata{"stripe_item_id": "stripe_item_id"}, + }, + }, + exp: []*stripe.CheckoutSessionLineItemParams{ + { + Price: ptrTo("stripe_item_id"), + Quantity: ptrTo[int64](1), + }, + }, + }, + + { + name: "two_items", + given: []model.OrderItem{ + { + Quantity: 1, + Metadata: datastore.Metadata{"stripe_item_id": "stripe_item_id_01"}, + }, + + { + Quantity: 1, + Metadata: datastore.Metadata{"stripe_item_id": "stripe_item_id_02"}, + }, + }, + exp: []*stripe.CheckoutSessionLineItemParams{ + { + Price: ptrTo("stripe_item_id_01"), + Quantity: ptrTo[int64](1), + }, + + { + Price: ptrTo("stripe_item_id_02"), + Quantity: ptrTo[int64](1), + }, + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := buildStripeLineItems(tc.given) + should.Equal(t, tc.exp, actual) + }) + } +} + type mockPaidOrderCreator struct { fnCreateOrderPremium func(ctx context.Context, req *model.CreateOrderRequestNew, ordNew *model.OrderNew, items []model.OrderItem) (*model.Order, error) fnRenewOrderWithExpPaidTime func(ctx context.Context, id uuid.UUID, expt, paidt time.Time) error diff --git a/services/skus/storage/repository/mock.go b/services/skus/storage/repository/mock.go index cc6438ef5..3bd687b49 100644 --- a/services/skus/storage/repository/mock.go +++ b/services/skus/storage/repository/mock.go @@ -13,15 +13,16 @@ import ( ) type MockOrder struct { - FnGet func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) - FnGetByExternalID func(ctx context.Context, dbi sqlx.QueryerContext, extID string) (*model.Order, error) - FnCreate func(ctx context.Context, dbi sqlx.QueryerContext, oreq *model.OrderNew) (*model.Order, error) - FnSetStatus func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, status string) error - FnSetExpiresAt func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, when time.Time) error - FnSetLastPaidAt func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, when time.Time) error - FnAppendMetadata func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error - FnAppendMetadataInt func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key string, val int) error - FnAppendMetadataInt64 func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key string, val int64) error + FnGet func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) + FnGetByExternalID func(ctx context.Context, dbi sqlx.QueryerContext, extID string) (*model.Order, error) + FnCreate func(ctx context.Context, dbi sqlx.QueryerContext, oreq *model.OrderNew) (*model.Order, error) + FnSetStatus func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, status string) error + FnSetExpiresAt func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, when time.Time) error + FnSetLastPaidAt func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, when time.Time) error + FnAppendMetadata func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error + FnAppendMetadataInt func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key string, val int) error + FnAppendMetadataInt64 func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key string, val int64) error + FnGetExpiredStripeCheckoutSessionID func(ctx context.Context, dbi sqlx.QueryerContext, orderID uuid.UUID) (string, error) } func (r *MockOrder) Get(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) { @@ -115,6 +116,14 @@ func (r *MockOrder) AppendMetadataInt64(ctx context.Context, dbi sqlx.ExecerCont return r.FnAppendMetadataInt64(ctx, dbi, id, key, val) } +func (r *MockOrder) GetExpiredStripeCheckoutSessionID(ctx context.Context, dbi sqlx.QueryerContext, orderID uuid.UUID) (string, error) { + if r.FnGetExpiredStripeCheckoutSessionID == nil { + return "sub_id", nil + } + + return r.FnGetExpiredStripeCheckoutSessionID(ctx, dbi, orderID) +} + type MockOrderItem struct { FnGet func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.OrderItem, error) FnFindByOrderID func(ctx context.Context, dbi sqlx.QueryerContext, orderID uuid.UUID) ([]model.OrderItem, error) diff --git a/services/skus/xstripe/mock.go b/services/skus/xstripe/mock.go new file mode 100644 index 000000000..442ff3f57 --- /dev/null +++ b/services/skus/xstripe/mock.go @@ -0,0 +1,73 @@ +package xstripe + +import ( + "context" + + "github.com/stripe/stripe-go/v72" +) + +type MockClient struct { + FnSession func(ctx context.Context, id string, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) + FnCreateSession func(ctx context.Context, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) + FnSubscription func(ctx context.Context, id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) + FnFindCustomer func(ctx context.Context, email string) (*stripe.Customer, bool) +} + +func (c *MockClient) Session(ctx context.Context, id string, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + if c.FnSession == nil { + result := &stripe.CheckoutSession{ID: id} + + return result, nil + } + + return c.FnSession(ctx, id, params) +} + +func (c *MockClient) CreateSession(ctx context.Context, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + if c.FnCreateSession == nil { + result := &stripe.CheckoutSession{ + ID: "cs_test_id", + PaymentMethodTypes: []string{"card"}, + Mode: stripe.CheckoutSessionModeSubscription, + SuccessURL: *params.SuccessURL, + CancelURL: *params.CancelURL, + ClientReferenceID: *params.ClientReferenceID, + Subscription: &stripe.Subscription{ + ID: "sub_id", + Metadata: map[string]string{ + "orderID": *params.ClientReferenceID, + }, + }, + AllowPromotionCodes: true, + } + + return result, nil + } + + return c.FnCreateSession(ctx, params) +} + +func (c *MockClient) Subscription(ctx context.Context, id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) { + if c.FnSubscription == nil { + result := &stripe.Subscription{ + ID: id, + } + + return result, nil + } + + return c.FnSubscription(ctx, id, params) +} + +func (c *MockClient) FindCustomer(ctx context.Context, email string) (*stripe.Customer, bool) { + if c.FnFindCustomer == nil { + result := &stripe.Customer{ + ID: "cus_id", + Email: email, + } + + return result, true + } + + return c.FnFindCustomer(ctx, email) +} diff --git a/services/skus/xstripe/xstripe.go b/services/skus/xstripe/xstripe.go new file mode 100644 index 000000000..78a78f821 --- /dev/null +++ b/services/skus/xstripe/xstripe.go @@ -0,0 +1,60 @@ +package xstripe + +import ( + "context" + + "github.com/stripe/stripe-go/v72" + "github.com/stripe/stripe-go/v72/client" + "github.com/stripe/stripe-go/v72/customer" +) + +type Client struct { + cl *client.API +} + +func NewClient(cl *client.API) *Client { + return &Client{cl: cl} +} + +func (c *Client) Session(_ context.Context, id string, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + return c.cl.CheckoutSessions.Get(id, params) +} + +func (c *Client) CreateSession(_ context.Context, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + return c.cl.CheckoutSessions.New(params) +} + +func (c *Client) Subscription(_ context.Context, id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) { + return c.cl.Subscriptions.Get(id, params) +} + +func (c *Client) FindCustomer(ctx context.Context, email string) (*stripe.Customer, bool) { + iter := c.Customers(ctx, &stripe.CustomerListParams{ + Email: stripe.String(email), + }) + + for iter.Next() { + return iter.Customer(), true + } + + return nil, false +} + +func (c *Client) Customers(_ context.Context, params *stripe.CustomerListParams) *customer.Iter { + return c.cl.Customers.List(params) +} + +func CustomerEmailFromSession(sess *stripe.CheckoutSession) string { + // Use the customer email if the customer has completed the payment flow. + if sess.Customer != nil && sess.Customer.Email != "" { + return sess.Customer.Email + } + + // This is unlikely to be set, but in case it is, use it. + if sess.CustomerEmail != "" { + return sess.CustomerEmail + } + + // Default to empty, Stripe will ask the customer. + return "" +} diff --git a/services/skus/order_noint_test.go b/services/skus/xstripe/xstripe_test.go similarity index 90% rename from services/skus/order_noint_test.go rename to services/skus/xstripe/xstripe_test.go index 94b87d01e..6af57ac4f 100644 --- a/services/skus/order_noint_test.go +++ b/services/skus/xstripe/xstripe_test.go @@ -1,4 +1,4 @@ -package skus +package xstripe import ( "testing" @@ -7,7 +7,7 @@ import ( "github.com/stripe/stripe-go/v72" ) -func TestGetCustEmailFromStripeCheckout(t *testing.T) { +func TestCustomerEmailFromSession(t *testing.T) { tests := []struct { name string exp string @@ -60,7 +60,7 @@ func TestGetCustEmailFromStripeCheckout(t *testing.T) { tc := tests[i] t.Run(tc.name, func(t *testing.T) { - actual := getCustEmailFromStripeCheckout(tc.given) + actual := CustomerEmailFromSession(tc.given) should.Equal(t, tc.exp, actual) }) }