From a2d6740bf9c70ac25ef614e50e87927c4e51e372 Mon Sep 17 00:00:00 2001 From: PavelBrm Date: Tue, 23 Jul 2024 01:12:57 +1200 Subject: [PATCH] fix: use proper expiry time when fixing up order while getting --- services/skus/controllers.go | 125 +++-------------- services/skus/controllers_test.go | 4 +- services/skus/model/model.go | 6 + services/skus/service.go | 171 +++++++++++++++-------- services/skus/storage/repository/mock.go | 27 ++-- 5 files changed, 156 insertions(+), 177 deletions(-) diff --git a/services/skus/controllers.go b/services/skus/controllers.go index c2d55cef9..65d6adcab 100644 --- a/services/skus/controllers.go +++ b/services/skus/controllers.go @@ -2,7 +2,6 @@ package skus import ( "context" - "crypto/subtle" "encoding/base64" "encoding/json" "errors" @@ -19,7 +18,6 @@ import ( uuid "github.com/satori/go.uuid" "github.com/stripe/stripe-go/v72/webhook" - "github.com/brave-intl/bat-go/libs/clients/radom" appctx "github.com/brave-intl/bat-go/libs/context" "github.com/brave-intl/bat-go/libs/handlers" "github.com/brave-intl/bat-go/libs/inputs" @@ -72,7 +70,7 @@ func Router( { corsMwrGet := NewCORSMwr(copts, http.MethodGet) r.Method(http.MethodOptions, "/{orderID}", metricsMwr("GetOrderOptions", corsMwrGet(nil))) - r.Method(http.MethodGet, "/{orderID}", metricsMwr("GetOrder", corsMwrGet(GetOrder(svc)))) + r.Method(http.MethodGet, "/{orderID}", metricsMwr("GetOrder", corsMwrGet(handleGetOrder(svc)))) } r.Method( @@ -370,30 +368,30 @@ func CancelOrder(service *Service) handlers.AppHandler { }) } -// GetOrder is the handler for getting an order -func GetOrder(service *Service) handlers.AppHandler { +func handleGetOrder(svc *Service) handlers.AppHandler { return handlers.AppHandler(func(w http.ResponseWriter, r *http.Request) *handlers.AppError { - var orderID = new(inputs.ID) - if err := inputs.DecodeAndValidateString(context.Background(), orderID, chi.URLParam(r, "orderID")); err != nil { - return handlers.ValidationError( - "Error validating request url parameter", - map[string]interface{}{ - "orderID": err.Error(), - }, - ) - } + ctx := r.Context() - order, err := service.GetOrder(*orderID.UUID()) + orderID, err := uuid.FromString(chi.URLParamFromCtx(ctx, "orderID")) if err != nil { - return handlers.WrapError(err, "Error retrieving the order", http.StatusInternalServerError) + return handlers.ValidationError("request", map[string]interface{}{"orderID": err.Error()}) } - status := http.StatusOK - if order == nil { - status = http.StatusNotFound + order, err := svc.getTransformOrder(ctx, orderID) + if err != nil { + switch { + case errors.Is(err, context.Canceled): + return handlers.WrapError(model.ErrSomethingWentWrong, "request has been cancelled", model.StatusClientClosedConn) + + case errors.Is(err, model.ErrOrderNotFound): + return handlers.WrapError(err, "order not found", http.StatusNotFound) + + default: + return handlers.WrapError(err, "Error retrieving the order", http.StatusInternalServerError) + } } - return handlers.RenderContent(r.Context(), order, w, status) + return handlers.RenderContent(ctx, order, w, http.StatusOK) }) } @@ -993,7 +991,7 @@ func WebhookRouter(svc *Service) chi.Router { r := chi.NewRouter() r.Method(http.MethodPost, "/stripe", middleware.InstrumentHandler("HandleStripeWebhook", handleStripeWebhook(svc))) - r.Method(http.MethodPost, "/radom", middleware.InstrumentHandler("HandleRadomWebhook", HandleRadomWebhook(svc))) + // r.Method(http.MethodPost, "/radom", middleware.InstrumentHandler("HandleRadomWebhook", HandleRadomWebhook(svc))) r.Method(http.MethodPost, "/android", middleware.InstrumentHandler("handleWebhookPlayStore", handleWebhookPlayStore(svc))) r.Method(http.MethodPost, "/ios", middleware.InstrumentHandler("handleWebhookAppStore", handleWebhookAppStore(svc))) @@ -1183,91 +1181,6 @@ func handleWebhookAppStoreH(w http.ResponseWriter, r *http.Request, svc *Service return handlers.RenderContent(ctx, struct{}{}, w, http.StatusOK) } -// HandleRadomWebhook handles Radom checkout session webhooks. -func HandleRadomWebhook(service *Service) handlers.AppHandler { - return func(w http.ResponseWriter, r *http.Request) *handlers.AppError { - ctx := r.Context() - - lg := logging.Logger(ctx, "payments").With().Str("func", "HandleRadomWebhook").Logger() - - // Get webhook secret. - endpointSecret, err := appctx.GetStringFromContext(ctx, appctx.RadomWebhookSecretCTXKey) - if err != nil { - lg.Error().Err(err).Msg("failed to get radom_webhook_secret from context") - return handlers.WrapError(err, "error getting radom_webhook_secret from context", http.StatusInternalServerError) - } - - // Check verification key. - if subtle.ConstantTimeCompare([]byte(r.Header.Get("radom-verification-key")), []byte(endpointSecret)) != 1 { - lg.Error().Err(err).Msg("invalid verification key from webhook") - return handlers.WrapError(err, "invalid verification key", http.StatusBadRequest) - } - - req := radom.WebhookRequest{} - if err := requestutils.ReadJSON(ctx, r.Body, &req); err != nil { - lg.Error().Err(err).Msg("failed to read request body") - return handlers.WrapError(err, "error reading request body", http.StatusServiceUnavailable) - } - - lg.Debug().Str("event_type", req.EventType).Str("data", fmt.Sprintf("%+v", req)).Msg("webhook event captured") - - // Handle only successful payment events. - if req.EventType != "managedRecurringPayment" && req.EventType != "newSubscription" { - return handlers.WrapError(err, "event type not implemented", http.StatusBadRequest) - } - - // Lookup the order, the checkout session was created with orderId in metadata. - rawOrderID, err := req.Data.CheckoutSession.Metadata.Get("braveOrderId") - if err != nil || rawOrderID == "" { - return handlers.WrapError(err, "brave metadata not found in webhook", http.StatusBadRequest) - } - - orderID, err := uuid.FromString(rawOrderID) - if err != nil { - return handlers.WrapError(err, "invalid braveOrderId in request", http.StatusBadRequest) - } - - // Set order id to paid, and update metadata values. - if err := service.Datastore.UpdateOrder(orderID, OrderStatusPaid); err != nil { - lg.Error().Err(err).Msg("failed to update order status") - return handlers.WrapError(err, "error updating order status", http.StatusInternalServerError) - } - - if err := service.Datastore.AppendOrderMetadata( - ctx, &orderID, "radomCheckoutSession", req.Data.CheckoutSession.CheckoutSessionID); err != nil { - lg.Error().Err(err).Msg("failed to update order metadata") - return handlers.WrapError(err, "error updating order metadata", http.StatusInternalServerError) - } - - if req.EventType == "newSubscription" { - - if err := service.Datastore.AppendOrderMetadata( - ctx, &orderID, "subscriptionId", req.EventData.NewSubscription.SubscriptionID); err != nil { - lg.Error().Err(err).Msg("failed to update order metadata") - return handlers.WrapError(err, "error updating order metadata", http.StatusInternalServerError) - } - - if err := service.Datastore.AppendOrderMetadata( - ctx, &orderID, "subscriptionContractAddress", - req.EventData.NewSubscription.Subscription.AutomatedEVMSubscription.SubscriptionContractAddress); err != nil { - - lg.Error().Err(err).Msg("failed to update order metadata") - return handlers.WrapError(err, "error updating order metadata", http.StatusInternalServerError) - } - - } - - // Set paymentProcessor to Radom. - if err := service.Datastore.AppendOrderMetadata(ctx, &orderID, "paymentProcessor", model.RadomPaymentMethod); err != nil { - lg.Error().Err(err).Msg("failed to update order to add the payment processor") - return handlers.WrapError(err, "failed to update order to add the payment processor", http.StatusInternalServerError) - } - - lg.Debug().Str("orderID", orderID.String()).Msg("order is now paid") - return handlers.RenderContent(ctx, "payment successful", w, http.StatusOK) - } -} - func handleStripeWebhook(svc *Service) handlers.AppHandler { return func(w http.ResponseWriter, r *http.Request) *handlers.AppError { ctx := r.Context() diff --git a/services/skus/controllers_test.go b/services/skus/controllers_test.go index d8efc75d2..28a02d460 100644 --- a/services/skus/controllers_test.go +++ b/services/skus/controllers_test.go @@ -407,7 +407,7 @@ func (suite *ControllersTestSuite) TestGetOrder() { req, err := http.NewRequest("GET", "/v1/orders/{orderID}", nil) suite.Require().NoError(err) - getOrderHandler := GetOrder(suite.service) + getOrderHandler := handleGetOrder(suite.service) rctx := chi.NewRouteContext() rctx.URLParams.Add("orderID", order.ID.String()) getReq := req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) @@ -435,7 +435,7 @@ func (suite *ControllersTestSuite) TestGetMissingOrder() { req, err := http.NewRequest("GET", "/v1/orders/{orderID}", nil) suite.Require().NoError(err) - getOrderHandler := GetOrder(suite.service) + getOrderHandler := handleGetOrder(suite.service) rctx := chi.NewRouteContext() rctx.URLParams.Add("orderID", "9645ca16-bc93-4e37-8edf-cb35b1763216") getReq := req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) diff --git a/services/skus/model/model.go b/services/skus/model/model.go index 77856e7f0..ee997dbd2 100644 --- a/services/skus/model/model.go +++ b/services/skus/model/model.go @@ -341,6 +341,12 @@ func (o *Order) StripeSubID() (string, bool) { return sid, ok } +func (o *Order) StripeSessID() (string, bool) { + sessID, ok := o.Metadata["stripeCheckoutSessionId"].(string) + + return sessID, ok +} + func (o *Order) IsIOS() bool { pp, ok := o.PaymentProc() if !ok { diff --git a/services/skus/service.go b/services/skus/service.go index 49d972f7d..869dc43a3 100644 --- a/services/skus/service.go +++ b/services/skus/service.go @@ -102,6 +102,7 @@ type orderStoreSvc interface { AppendMetadata(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error AppendMetadataInt(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key string, val int) error AppendMetadataInt64(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key string, val int64) error + GetExpiredStripeCheckoutSessionID(ctx context.Context, dbi sqlx.QueryerContext, orderID uuid.UUID) (string, error) } type tlv2Store interface { @@ -211,13 +212,14 @@ func InitService( // setup stripe if exists in context and enabled scClient := &client.API{} if enabled, ok := ctx.Value(appctx.StripeEnabledCTXKey).(bool); ok && enabled { - sublogger.Debug().Msg("stripe enabled") + stripe.EnableTelemetry = false + var err error stripe.Key, err = appctx.GetStringFromContext(ctx, appctx.StripeSecretCTXKey) if err != nil { sublogger.Panic().Err(err).Msg("failed to get Stripe secret from context, and Stripe enabled") } - // initialize stripe client + scClient.Init(stripe.Key, nil) } @@ -570,93 +572,122 @@ func (s *Service) CreateOrderFromRequest(ctx context.Context, req model.CreateOr return order, nil } -// GetOrder - business logic for getting an order, needs to validate the checkout session is not expired -func (s *Service) GetOrder(orderID uuid.UUID) (*Order, error) { - // get the order - order, err := s.Datastore.GetOrder(orderID) +func (s *Service) getTransformOrder(ctx context.Context, orderID uuid.UUID) (*Order, error) { + tx, err := s.Datastore.RawDB().BeginTxx(ctx, nil) if err != nil { - return nil, fmt.Errorf("failed to get order (%s): %w", orderID.String(), err) + return nil, err } + defer func() { _ = tx.Rollback() }() - if order != nil { - if !order.IsPaid() && order.IsStripePayable() { - order, err = s.TransformStripeOrder(order) - if err != nil { - return nil, fmt.Errorf("failed to transform stripe order (%s): %w", orderID.String(), err) - } - } + result, err := s.getTransformOrderTx(ctx, tx, orderID) + if err != nil { + return nil, err } - return order, nil + if err := tx.Commit(); err != nil { + return nil, err + } + return result, nil } -// TransformStripeOrder updates checkout session if expired, checks the status of the checkout session. -func (s *Service) TransformStripeOrder(order *Order) (*Order, error) { - ctx := context.Background() - - // check if this order has an expired checkout session - expired, cs, err := s.Datastore.CheckExpiredCheckoutSession(order.ID) +func (s *Service) getTransformOrderTx(ctx context.Context, dbi sqlx.ExtContext, orderID uuid.UUID) (*Order, error) { + ord, err := s.getOrderFullTx(ctx, dbi, orderID) if err != nil { - return nil, fmt.Errorf("failed to check for expired stripe checkout session: %w", err) + return nil, fmt.Errorf("failed to get order (%s): %w", orderID.String(), err) + } + + if !shouldTransformStripeOrder(ord) { + return ord, nil } - if expired { - // get old checkout session from stripe by id - stripeSession, err := session.Get(cs, nil) + if err := s.updateOrderStripeSession(ctx, dbi, ord); err != nil { + return nil, fmt.Errorf("failed to transform stripe order (%s): %w", orderID.String(), err) + } + + return s.getOrderFullTx(ctx, dbi, orderID) +} + +// updateOrderStripeSession checks the status of the checkout session, updates it if expired. +func (s *Service) updateOrderStripeSession(ctx context.Context, dbi sqlx.ExtContext, ord *Order) error { + expSessID, err := s.orderRepo.GetExpiredStripeCheckoutSessionID(ctx, dbi, ord.ID) + if err != nil && !errors.Is(err, model.ErrExpiredStripeCheckoutSessionIDNotFound) { + return fmt.Errorf("failed to check for expired stripe checkout session: %w", err) + } + + var sessID string + + if expSessID != "" { + expSess, err := session.Get(expSessID, nil) if err != nil { - return nil, fmt.Errorf("failed to get stripe checkout session: %w", err) + return fmt.Errorf("failed to get stripe checkout session: %w", err) } - checkoutSession, err := order.CreateStripeCheckoutSession( - getCustEmailFromStripeCheckout(stripeSession), - stripeSession.SuccessURL, stripeSession.CancelURL, - order.GetTrialDays(), + sess, err := model.CreateStripeCheckoutSession( + ord.ID.String(), + getCustEmailFromStripeCheckout(expSess), + expSess.SuccessURL, + expSess.CancelURL, + ord.GetTrialDays(), + ord.Items, ) if err != nil { - return nil, fmt.Errorf("failed to create checkout session: %w", err) + return fmt.Errorf("failed to create checkout session: %w", err) } - err = s.Datastore.AppendOrderMetadata(ctx, &order.ID, "stripeCheckoutSessionId", checkoutSession.SessionID) - if err != nil { - return nil, fmt.Errorf("failed to update order metadata: %w", err) + if err := s.orderRepo.AppendMetadata(ctx, dbi, ord.ID, "stripeCheckoutSessionId", sess.SessionID); err != nil { + return fmt.Errorf("failed to update order metadata: %w", err) } + + sessID = sess.SessionID } - // if this is a stripe order, and there is a checkout session, we actually need to check it with - // stripe, as the redirect flow sometimes is too fast for the webhook to be delivered. - // exclude any order with a subscription identifier from stripe - if _, sOK := order.Metadata["stripeSubscriptionId"]; !sOK { - if cs, ok := order.Metadata["stripeCheckoutSessionId"].(string); ok && cs != "" { - // get old checkout session from stripe by id - sess, err := session.Get(cs, nil) - if err != nil { - return nil, fmt.Errorf("failed to get stripe checkout session: %w", err) - } + // Nothing more to do for orders with a Stripe subscription. + if _, ok := ord.StripeSubID(); ok { + return nil + } - // Set status to paid and the subscription id and if the session is actually paid. - if sess.PaymentStatus == "paid" { - if err = s.Datastore.UpdateOrder(order.ID, "paid"); err != nil { - return nil, fmt.Errorf("failed to update order to paid status: %w", err) - } + // Below goes some leagcy stuff. + // There was also a bug where the old subscription would be tested for payment. + // The code below did not take into account that the session could have been updated just above. + // + // If this is a stripe order, and there is a checkout session, check it with Stripe. + // The redirect flow sometimes is too fast for the webhook to be delivered. + sessID, ok := chooseStripeSessID(ord, sessID) + if !ok || sessID == "" { + // Nothing to do here. + return nil + } - if err := s.Datastore.AppendOrderMetadata(ctx, &order.ID, "stripeSubscriptionId", sess.Subscription.ID); err != nil { - return nil, fmt.Errorf("failed to update order to add the subscription id") - } + sess, err := session.Get(sessID, nil) + if err != nil { + return fmt.Errorf("failed to get stripe checkout session: %w", err) + } - if err := s.Datastore.AppendOrderMetadata(ctx, &order.ID, "paymentProcessor", model.StripePaymentMethod); err != nil { - return nil, fmt.Errorf("failed to update order to add the payment processor") - } - } - } + // Skip unpaid sessions. + if sess.PaymentStatus != "paid" { + return nil } - result, err := s.Datastore.GetOrder(order.ID) + // Need to update the order as paid. + // This requires fetching the subscription as the expiry time is needed. + sub, err := s.scClient.Subscriptions.Get(sess.Subscription.ID, nil) if err != nil { - return nil, fmt.Errorf("failed to get order: %w", err) + return err } - return result, nil + expt := time.Unix(sub.CurrentPeriodEnd, 0).UTC().Add(24 * time.Hour) + paidt := time.Unix(sub.CurrentPeriodStart, 0).UTC() + + if err := s.renewOrderWithExpPaidTimeTx(ctx, dbi, ord.ID, expt, paidt); err != nil { + return err + } + + if err := s.orderRepo.AppendMetadata(ctx, dbi, ord.ID, "stripeSubscriptionId", sub.ID); err != nil { + return err + } + + return s.orderRepo.AppendMetadata(ctx, dbi, ord.ID, "paymentProcessor", model.StripePaymentMethod) } // CancelOrder cancels an order, propagates to stripe if needed. @@ -2497,3 +2528,23 @@ func shouldUpdateOrderStripeSubID(ord *model.Order, subID string) bool { return false } + +func shouldTransformStripeOrder(ord *model.Order) bool { + if ord.IsIOS() { + return false + } + + if ord.IsAndroid() { + return false + } + + return !ord.IsPaid() && ord.IsStripePayable() +} + +func chooseStripeSessID(ord *model.Order, canBeNewSessID string) (string, bool) { + if canBeNewSessID != "" { + return canBeNewSessID, true + } + + return ord.StripeSessID() +} diff --git a/services/skus/storage/repository/mock.go b/services/skus/storage/repository/mock.go index cc6438ef5..3bd687b49 100644 --- a/services/skus/storage/repository/mock.go +++ b/services/skus/storage/repository/mock.go @@ -13,15 +13,16 @@ import ( ) type MockOrder struct { - FnGet func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) - FnGetByExternalID func(ctx context.Context, dbi sqlx.QueryerContext, extID string) (*model.Order, error) - FnCreate func(ctx context.Context, dbi sqlx.QueryerContext, oreq *model.OrderNew) (*model.Order, error) - FnSetStatus func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, status string) error - FnSetExpiresAt func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, when time.Time) error - FnSetLastPaidAt func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, when time.Time) error - FnAppendMetadata func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error - FnAppendMetadataInt func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key string, val int) error - FnAppendMetadataInt64 func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key string, val int64) error + FnGet func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) + FnGetByExternalID func(ctx context.Context, dbi sqlx.QueryerContext, extID string) (*model.Order, error) + FnCreate func(ctx context.Context, dbi sqlx.QueryerContext, oreq *model.OrderNew) (*model.Order, error) + FnSetStatus func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, status string) error + FnSetExpiresAt func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, when time.Time) error + FnSetLastPaidAt func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, when time.Time) error + FnAppendMetadata func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error + FnAppendMetadataInt func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key string, val int) error + FnAppendMetadataInt64 func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key string, val int64) error + FnGetExpiredStripeCheckoutSessionID func(ctx context.Context, dbi sqlx.QueryerContext, orderID uuid.UUID) (string, error) } func (r *MockOrder) Get(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) { @@ -115,6 +116,14 @@ func (r *MockOrder) AppendMetadataInt64(ctx context.Context, dbi sqlx.ExecerCont return r.FnAppendMetadataInt64(ctx, dbi, id, key, val) } +func (r *MockOrder) GetExpiredStripeCheckoutSessionID(ctx context.Context, dbi sqlx.QueryerContext, orderID uuid.UUID) (string, error) { + if r.FnGetExpiredStripeCheckoutSessionID == nil { + return "sub_id", nil + } + + return r.FnGetExpiredStripeCheckoutSessionID(ctx, dbi, orderID) +} + type MockOrderItem struct { FnGet func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.OrderItem, error) FnFindByOrderID func(ctx context.Context, dbi sqlx.QueryerContext, orderID uuid.UUID) ([]model.OrderItem, error)