From d0c308ee12986ae9a8db098b93ac48e150c4d122 Mon Sep 17 00:00:00 2001 From: Peter Turi Date: Tue, 29 Oct 2024 15:53:36 +0100 Subject: [PATCH] feat: create invoice backend logic (#1763) This patch adds the required backend code for the create invoice now endpoint. The overall logic is, that we take all of the pending invoice lines and whatever is due, we assign it to a new invoice. Further invoice changes will be handled by a proper state machine, but this transition requires cross-invoice state coordination. --- openmeter/billing/adapter.go | 10 + openmeter/billing/adapter/customeroverride.go | 40 +++- openmeter/billing/adapter/invoice.go | 99 ++++++++ openmeter/billing/adapter/invoicelines.go | 64 ++++- openmeter/billing/customeroverride.go | 5 + openmeter/billing/entity/invoice.go | 6 + openmeter/billing/errors.go | 1 + openmeter/billing/httpdriver/invoiceline.go | 8 +- openmeter/billing/invoice.go | 50 +++- openmeter/billing/invoiceline.go | 45 +++- openmeter/billing/service.go | 2 + openmeter/billing/service/invoice.go | 225 ++++++++++++++++++ openmeter/billing/service/invoiceline.go | 97 +++++--- openmeter/billing/service/service.go | 36 ++- test/billing/invoice_test.go | 217 ++++++++++++++++- 15 files changed, 840 insertions(+), 65 deletions(-) diff --git a/openmeter/billing/adapter.go b/openmeter/billing/adapter.go index 200d5159c..173080a99 100644 --- a/openmeter/billing/adapter.go +++ b/openmeter/billing/adapter.go @@ -35,15 +35,25 @@ type CustomerOverrideAdapter interface { UpdateCustomerOverride(ctx context.Context, input UpdateCustomerOverrideAdapterInput) (*billingentity.CustomerOverride, error) DeleteCustomerOverride(ctx context.Context, input DeleteCustomerOverrideInput) error + // UpsertCustomerOverride upserts a customer override ignoring the transactional context, the override + // will be empty. + UpsertCustomerOverride(ctx context.Context, input UpsertCustomerOverrideAdapterInput) error + LockCustomerForUpdate(ctx context.Context, input LockCustomerForUpdateAdapterInput) error + GetCustomerOverrideReferencingProfile(ctx context.Context, input HasCustomerOverrideReferencingProfileAdapterInput) ([]customerentity.CustomerID, error) } type InvoiceLineAdapter interface { CreateInvoiceLines(ctx context.Context, input CreateInvoiceLinesAdapterInput) (*CreateInvoiceLinesResponse, error) + ListInvoiceLines(ctx context.Context, input ListInvoiceLinesAdapterInput) ([]billingentity.Line, error) + AssociateLinesToInvoice(ctx context.Context, input AssociateLinesToInvoiceAdapterInput) error } type InvoiceAdapter interface { CreateInvoice(ctx context.Context, input CreateInvoiceAdapterInput) (CreateInvoiceAdapterRespone, error) GetInvoiceById(ctx context.Context, input GetInvoiceByIdInput) (billingentity.Invoice, error) + LockInvoicesForUpdate(ctx context.Context, input LockInvoicesForUpdateInput) error + DeleteInvoices(ctx context.Context, input DeleteInvoicesAdapterInput) error ListInvoices(ctx context.Context, input ListInvoicesInput) (ListInvoicesResponse, error) + AssociatedLineCounts(ctx context.Context, input AssociatedLineCountsAdapterInput) (AssociatedLineCountsAdapterResponse, error) } diff --git a/openmeter/billing/adapter/customeroverride.go b/openmeter/billing/adapter/customeroverride.go index 60f7f6da8..edea7ea62 100644 --- a/openmeter/billing/adapter/customeroverride.go +++ b/openmeter/billing/adapter/customeroverride.go @@ -2,8 +2,10 @@ package billingadapter import ( "context" + "database/sql" "fmt" + entsql "entgo.io/ent/dialect/sql" "github.com/samber/lo" "github.com/openmeterio/openmeter/openmeter/billing" @@ -176,6 +178,33 @@ func (r *adapter) GetCustomerOverrideReferencingProfile(ctx context.Context, inp return customerIDs, nil } +func (r *adapter) UpsertCustomerOverride(ctx context.Context, input billing.UpsertCustomerOverrideAdapterInput) error { + err := r.db.BillingCustomerOverride.Create(). + SetNamespace(input.Namespace). + SetCustomerID(input.ID). + OnConflict( + entsql.DoNothing(), + ). + Exec(ctx) + if err != nil { + // The do nothing returns no lines, so we have the record ready + if err == sql.ErrNoRows { + return nil + } + } + return nil +} + +func (r *adapter) LockCustomerForUpdate(ctx context.Context, input billing.LockCustomerForUpdateAdapterInput) error { + _, err := r.db.BillingCustomerOverride.Query(). + Where(billingcustomeroverride.CustomerID(input.ID)). + Where(billingcustomeroverride.Namespace(input.Namespace)). + ForUpdate(). + First(ctx) + + return err +} + func mapCustomerOverrideFromDB(dbOverride *db.BillingCustomerOverride) (*billingentity.CustomerOverride, error) { collectionInterval, err := dbOverride.LineCollectionPeriod.ParsePtrOrNil() if err != nil { @@ -197,6 +226,13 @@ func mapCustomerOverrideFromDB(dbOverride *db.BillingCustomerOverride) (*billing return nil, fmt.Errorf("cannot map profile: %w", err) } + var profile *billingentity.Profile + if baseProfile != nil { + profile = &billingentity.Profile{ + BaseProfile: *baseProfile, + } + } + return &billingentity.CustomerOverride{ ID: dbOverride.ID, Namespace: dbOverride.Namespace, @@ -220,8 +256,6 @@ func mapCustomerOverrideFromDB(dbOverride *db.BillingCustomerOverride) (*billing CollectionMethod: dbOverride.InvoiceCollectionMethod, }, - Profile: &billingentity.Profile{ - BaseProfile: *baseProfile, - }, + Profile: profile, }, nil } diff --git a/openmeter/billing/adapter/invoice.go b/openmeter/billing/adapter/invoice.go index 77c1471dd..c629f5b79 100644 --- a/openmeter/billing/adapter/invoice.go +++ b/openmeter/billing/adapter/invoice.go @@ -2,7 +2,9 @@ package billingadapter import ( "context" + "errors" "fmt" + "strings" "time" "github.com/samber/lo" @@ -12,6 +14,8 @@ import ( billingentity "github.com/openmeterio/openmeter/openmeter/billing/entity" "github.com/openmeterio/openmeter/openmeter/ent/db" "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoice" + "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoiceline" + "github.com/openmeterio/openmeter/pkg/clock" "github.com/openmeterio/openmeter/pkg/framework/entutils" "github.com/openmeterio/openmeter/pkg/models" "github.com/openmeterio/openmeter/pkg/pagination" @@ -55,6 +59,57 @@ func (r *adapter) GetInvoiceById(ctx context.Context, in billing.GetInvoiceByIdI return mapInvoiceFromDB(*invoice, in.Expand) } +func (r *adapter) LockInvoicesForUpdate(ctx context.Context, input billing.LockInvoicesForUpdateInput) error { + if err := input.Validate(); err != nil { + return billing.ValidationError{ + Err: err, + } + } + + ids, err := r.db.BillingInvoice.Query(). + Where(billinginvoice.IDIn(input.InvoiceIDs...)). + Where(billinginvoice.Namespace(input.Namespace)). + ForUpdate(). + Select(billinginvoice.FieldID). + Strings(ctx) + if err != nil { + return err + } + + missingIds := lo.Without(input.InvoiceIDs, ids...) + if len(missingIds) > 0 { + return billing.NotFoundError{ + Entity: billing.EntityInvoice, + ID: strings.Join(missingIds, ","), + Err: fmt.Errorf("cannot select invoices for update"), + } + } + + return nil +} + +func (r *adapter) DeleteInvoices(ctx context.Context, input billing.DeleteInvoicesAdapterInput) error { + if err := input.Validate(); err != nil { + return billing.ValidationError{ + Err: err, + } + } + + nAffected, err := r.db.BillingInvoice.Update(). + Where(billinginvoice.IDIn(input.InvoiceIDs...)). + Where(billinginvoice.Namespace(input.Namespace)). + SetDeletedAt(clock.Now()). + Save(ctx) + + if nAffected != len(input.InvoiceIDs) { + return billing.ValidationError{ + Err: errors.New("invoices failed to delete"), + } + } + + return err +} + // expandLineItems adds the required edges to the query so that line items can be properly mapped func (r *adapter) expandLineItems(query *db.BillingInvoiceQuery) *db.BillingInvoiceQuery { return query.WithBillingInvoiceLines(func(bilq *db.BillingInvoiceLineQuery) { @@ -217,6 +272,50 @@ func (r *adapter) CreateInvoice(ctx context.Context, input billing.CreateInvoice return mapInvoiceFromDB(*newInvoice, billing.InvoiceExpandAll) } +type lineCountQueryOut struct { + InvoiceID string `json:"invoice_id"` + Count int64 `json:"count"` +} + +func (r *adapter) AssociatedLineCounts(ctx context.Context, input billing.AssociatedLineCountsAdapterInput) (billing.AssociatedLineCountsAdapterResponse, error) { + queryOut := []lineCountQueryOut{} + + err := r.db.BillingInvoiceLine.Query(). + Where(billinginvoiceline.DeletedAtIsNil()). + Where(billinginvoiceline.Namespace(input.Namespace)). + Where(billinginvoiceline.InvoiceIDIn(input.InvoiceIDs...)). + Where(billinginvoiceline.StatusIn(billingentity.InvoiceLineStatusValid)). + GroupBy(billinginvoiceline.FieldInvoiceID). + Aggregate( + db.Count(), + ). + Scan(ctx, &queryOut) + if err != nil { + return billing.AssociatedLineCountsAdapterResponse{}, err + } + + res := lo.Associate(queryOut, func(q lineCountQueryOut) (billingentity.InvoiceID, int64) { + return billingentity.InvoiceID{ + Namespace: input.Namespace, + ID: q.InvoiceID, + }, q.Count + }) + + for _, invoiceID := range input.InvoiceIDs { + id := billingentity.InvoiceID{ + Namespace: input.Namespace, + ID: invoiceID, + } + if _, found := res[id]; !found { + res[id] = 0 + } + } + + return billing.AssociatedLineCountsAdapterResponse{ + Counts: res, + }, nil +} + func mapInvoiceFromDB(invoice db.BillingInvoice, expand billing.InvoiceExpand) (billingentity.Invoice, error) { res := billingentity.Invoice{ ID: invoice.ID, diff --git a/openmeter/billing/adapter/invoicelines.go b/openmeter/billing/adapter/invoicelines.go index 804f1a754..59f608af5 100644 --- a/openmeter/billing/adapter/invoicelines.go +++ b/openmeter/billing/adapter/invoicelines.go @@ -10,8 +10,8 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing" billingentity "github.com/openmeterio/openmeter/openmeter/billing/entity" "github.com/openmeterio/openmeter/openmeter/ent/db" + "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoice" "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoiceline" - "github.com/openmeterio/openmeter/pkg/models" ) var _ billing.InvoiceLineAdapter = (*adapter)(nil) @@ -69,17 +69,65 @@ func (r *adapter) CreateInvoiceLines(ctx context.Context, input billing.CreateIn return result, nil } -func (r *adapter) GetInvoiceLineByID(ctx context.Context, id models.NamespacedID) (billingentity.Line, error) { - dbLine, err := r.db.BillingInvoiceLine.Query(). - Where(billinginvoiceline.ID(id.ID)). - Where(billinginvoiceline.Namespace(id.Namespace)). +func (r *adapter) ListInvoiceLines(ctx context.Context, input billing.ListInvoiceLinesAdapterInput) ([]billingentity.Line, error) { + if err := input.Validate(); err != nil { + return nil, err + } + + query := r.db.BillingInvoiceLine.Query(). + Where(billinginvoiceline.Namespace(input.Namespace)) + + if len(input.LineIDs) > 0 { + query = query.Where(billinginvoiceline.IDIn(input.LineIDs...)) + } + + if input.InvoiceAtBefore != nil { + query = query.Where(billinginvoiceline.InvoiceAtLT(*input.InvoiceAtBefore)) + } + + query = query.WithBillingInvoice(func(biq *db.BillingInvoiceQuery) { + biq.Where(billinginvoice.Namespace(input.Namespace)) + + if input.CustomerID != "" { + biq.Where(billinginvoice.CustomerID(input.CustomerID)) + } + + if len(input.InvoiceStatuses) > 0 { + biq.Where(billinginvoice.StatusIn(input.InvoiceStatuses...)) + } + }) + + dbLines, err := query. WithBillingInvoiceManualLines(). - Only(ctx) + All(ctx) + if err != nil { + return nil, err + } + + return lo.Map(dbLines, func(line *db.BillingInvoiceLine, _ int) billingentity.Line { + return mapInvoiceLineFromDB(line) + }), nil +} + +func (r *adapter) AssociateLinesToInvoice(ctx context.Context, input billing.AssociateLinesToInvoiceAdapterInput) error { + if err := input.Validate(); err != nil { + return err + } + + nAffected, err := r.db.BillingInvoiceLine.Update(). + SetInvoiceID(input.Invoice.ID). + Where(billinginvoiceline.Namespace(input.Invoice.Namespace)). + Where(billinginvoiceline.IDIn(input.LineIDs...)). + Save(ctx) if err != nil { - return billingentity.Line{}, err + return fmt.Errorf("associating lines: %w", err) + } + + if nAffected != len(input.LineIDs) { + return fmt.Errorf("fewer lines were associated (%d) than expected (%d)", nAffected, len(input.LineIDs)) } - return mapInvoiceLineFromDB(dbLine), nil + return nil } func mapInvoiceLineFromDB(dbLine *db.BillingInvoiceLine) billingentity.Line { diff --git a/openmeter/billing/customeroverride.go b/openmeter/billing/customeroverride.go index 2e1f067c1..69efef244 100644 --- a/openmeter/billing/customeroverride.go +++ b/openmeter/billing/customeroverride.go @@ -152,3 +152,8 @@ type HasCustomerOverrideReferencingProfileAdapterInput genericNamespaceID func (i HasCustomerOverrideReferencingProfileAdapterInput) Validate() error { return genericNamespaceID(i).Validate() } + +type ( + UpsertCustomerOverrideAdapterInput = customerentity.CustomerID + LockCustomerForUpdateAdapterInput = customerentity.CustomerID +) diff --git a/openmeter/billing/entity/invoice.go b/openmeter/billing/entity/invoice.go index 80965482d..fbc26f312 100644 --- a/openmeter/billing/entity/invoice.go +++ b/openmeter/billing/entity/invoice.go @@ -116,6 +116,12 @@ func (s InvoiceStatus) IsMutable() bool { return true } +type InvoiceID models.NamespacedID + +func (i InvoiceID) Validate() error { + return models.NamespacedID(i).Validate() +} + type Invoice struct { Namespace string `json:"namespace"` ID string `json:"id"` diff --git a/openmeter/billing/errors.go b/openmeter/billing/errors.go index 05dd61446..c95c70c16 100644 --- a/openmeter/billing/errors.go +++ b/openmeter/billing/errors.go @@ -45,6 +45,7 @@ const ( EntityCustomer = "Customer" EntityDefaultProfile = "DefaultBillingProfile" EntityInvoice = "Invoice" + EntityInvoiceLine = "InvoiceLine" ) type NotFoundError struct { diff --git a/openmeter/billing/httpdriver/invoiceline.go b/openmeter/billing/httpdriver/invoiceline.go index 838b45aae..69d8e368f 100644 --- a/openmeter/billing/httpdriver/invoiceline.go +++ b/openmeter/billing/httpdriver/invoiceline.go @@ -24,7 +24,7 @@ type ( func (h *handler) CreateLineByCustomer() CreateLineByCustomerHandler { return httptransport.NewHandlerWithArgs( - func(ctx context.Context, r *http.Request, customerKeyOrId string) (CreateLineByCustomerRequest, error) { + func(ctx context.Context, r *http.Request, customerID string) (CreateLineByCustomerRequest, error) { body := api.BillingCreateLineByCustomerJSONRequestBody{} if err := commonhttp.JSONRequestBodyDecoder(r, &body); err != nil { @@ -46,9 +46,9 @@ func (h *handler) CreateLineByCustomer() CreateLineByCustomerHandler { } return CreateLineByCustomerRequest{ - CustomerKeyOrID: customerKeyOrId, - Namespace: ns, - Lines: lines, + CustomerID: customerID, + Namespace: ns, + Lines: lines, }, nil }, func(ctx context.Context, request CreateLineByCustomerRequest) (CreateLineByCustomerResponse, error) { diff --git a/openmeter/billing/invoice.go b/openmeter/billing/invoice.go index f2ad5a21b..3536b62a2 100644 --- a/openmeter/billing/invoice.go +++ b/openmeter/billing/invoice.go @@ -8,8 +8,8 @@ import ( "github.com/openmeterio/openmeter/api" billingentity "github.com/openmeterio/openmeter/openmeter/billing/entity" customerentity "github.com/openmeterio/openmeter/openmeter/customer/entity" + "github.com/openmeterio/openmeter/pkg/clock" "github.com/openmeterio/openmeter/pkg/currencyx" - "github.com/openmeterio/openmeter/pkg/models" "github.com/openmeterio/openmeter/pkg/pagination" "github.com/openmeterio/openmeter/pkg/sortx" ) @@ -37,7 +37,7 @@ func (e InvoiceExpand) Validate() error { } type GetInvoiceByIdInput struct { - Invoice models.NamespacedID + Invoice billingentity.InvoiceID Expand InvoiceExpand } @@ -53,6 +53,29 @@ func (i GetInvoiceByIdInput) Validate() error { return nil } +type genericMultiInvoiceInput struct { + Namespace string + InvoiceIDs []string +} + +func (i genericMultiInvoiceInput) Validate() error { + if i.Namespace == "" { + return errors.New("namespace is required") + } + + if len(i.InvoiceIDs) == 0 { + return errors.New("invoice IDs are required") + } + + return nil +} + +type ( + DeleteInvoicesAdapterInput = genericMultiInvoiceInput + LockInvoicesForUpdateInput = genericMultiInvoiceInput + AssociatedLineCountsAdapterInput = genericMultiInvoiceInput +) + type ListInvoicesInput struct { pagination.Page @@ -131,3 +154,26 @@ func (c CreateInvoiceAdapterInput) Validate() error { } type CreateInvoiceAdapterRespone = billingentity.Invoice + +type CreateInvoiceInput struct { + Customer customerentity.CustomerID + + IncludePendingLines []string + AsOf *time.Time +} + +func (i CreateInvoiceInput) Validate() error { + if err := i.Customer.Validate(); err != nil { + return fmt.Errorf("customer: %w", err) + } + + if i.AsOf != nil && i.AsOf.After(clock.Now()) { + return errors.New("asOf must be in the past") + } + + return nil +} + +type AssociatedLineCountsAdapterResponse struct { + Counts map[billingentity.InvoiceID]int64 +} diff --git a/openmeter/billing/invoiceline.go b/openmeter/billing/invoiceline.go index 1529a527c..facb49a67 100644 --- a/openmeter/billing/invoiceline.go +++ b/openmeter/billing/invoiceline.go @@ -3,14 +3,15 @@ package billing import ( "errors" "fmt" + "time" billingentity "github.com/openmeterio/openmeter/openmeter/billing/entity" ) type CreateInvoiceLinesInput struct { - CustomerKeyOrID string - Namespace string - Lines []billingentity.Line + CustomerID string + Namespace string + Lines []billingentity.Line } func (c CreateInvoiceLinesInput) Validate() error { @@ -18,7 +19,7 @@ func (c CreateInvoiceLinesInput) Validate() error { return errors.New("namespace is required") } - if c.CustomerKeyOrID == "" { + if c.CustomerID == "" { return errors.New("customer key or ID is required") } @@ -57,3 +58,39 @@ func (c CreateInvoiceLinesAdapterInput) Validate() error { type CreateInvoiceLinesResponse struct { Lines []billingentity.Line } + +type ListInvoiceLinesAdapterInput struct { + Namespace string + + CustomerID string + InvoiceStatuses []billingentity.InvoiceStatus + InvoiceAtBefore *time.Time + + LineIDs []string +} + +func (g ListInvoiceLinesAdapterInput) Validate() error { + if g.Namespace == "" { + return errors.New("namespace is required") + } + + return nil +} + +type AssociateLinesToInvoiceAdapterInput struct { + Invoice billingentity.InvoiceID + + LineIDs []string +} + +func (i AssociateLinesToInvoiceAdapterInput) Validate() error { + if err := i.Invoice.Validate(); err != nil { + return fmt.Errorf("invoice: %w", err) + } + + if len(i.LineIDs) == 0 { + return errors.New("line ids are required") + } + + return nil +} diff --git a/openmeter/billing/service.go b/openmeter/billing/service.go index 55994ec35..ffe2edddc 100644 --- a/openmeter/billing/service.go +++ b/openmeter/billing/service.go @@ -37,4 +37,6 @@ type InvoiceLineService interface { type InvoiceService interface { ListInvoices(ctx context.Context, input ListInvoicesInput) (ListInvoicesResponse, error) + GetInvoiceByID(ctx context.Context, input GetInvoiceByIdInput) (billingentity.Invoice, error) + CreateInvoice(ctx context.Context, input CreateInvoiceInput) ([]billingentity.Invoice, error) } diff --git a/openmeter/billing/service/invoice.go b/openmeter/billing/service/invoice.go index 05437671d..a4502b47a 100644 --- a/openmeter/billing/service/invoice.go +++ b/openmeter/billing/service/invoice.go @@ -3,9 +3,15 @@ package billingservice import ( "context" "fmt" + "strings" + "time" + + "github.com/samber/lo" "github.com/openmeterio/openmeter/openmeter/billing" billingentity "github.com/openmeterio/openmeter/openmeter/billing/entity" + "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/framework/entutils" ) @@ -37,3 +43,222 @@ func (s *Service) ListInvoices(ctx context.Context, input billing.ListInvoicesIn return invoices, nil }) } + +func (s *Service) GetInvoiceByID(ctx context.Context, input billing.GetInvoiceByIdInput) (billingentity.Invoice, error) { + return entutils.TransactingRepo(ctx, s.adapter, func(ctx context.Context, txAdapter billing.Adapter) (billingentity.Invoice, error) { + invoice, err := txAdapter.GetInvoiceById(ctx, input) + if err != nil { + return billingentity.Invoice{}, err + } + + if input.Expand.WorkflowApps { + resolvedApps, err := s.resolveApps(ctx, input.Invoice.Namespace, invoice.Workflow.AppReferences) + if err != nil { + return billingentity.Invoice{}, fmt.Errorf("error resolving apps for invoice [%s]: %w", invoice.ID, err) + } + + invoice.Workflow.Apps = &billingentity.ProfileApps{ + Tax: resolvedApps.Tax.App, + Invoicing: resolvedApps.Invoicing.App, + Payment: resolvedApps.Payment.App, + } + } + + return invoice, nil + }) +} + +func (s *Service) CreateInvoice(ctx context.Context, input billing.CreateInvoiceInput) ([]billingentity.Invoice, error) { + if err := input.Validate(); err != nil { + return nil, billing.ValidationError{ + Err: err, + } + } + + return TransactingRepoForGatheringInvoiceManipulation( + ctx, + s.adapter, + input.Customer, + func(ctx context.Context, txAdapter billing.Adapter) ([]billingentity.Invoice, error) { + // let's resolve the customer's settings + customerProfile, err := s.GetProfileWithCustomerOverride(ctx, billing.GetProfileWithCustomerOverrideInput{ + Namespace: input.Customer.Namespace, + CustomerID: input.Customer.ID, + }) + if err != nil { + return nil, fmt.Errorf("fetching customer profile: %w", err) + } + + asof := lo.FromPtrOr(input.AsOf, clock.Now()) + + // let's gather the in-scope lines and validate it + inScopeLines, err := s.gatherInscopeLines(ctx, input, txAdapter, asof) + if err != nil { + return nil, err + } + + sourceInvoiceIDs := lo.Uniq(lo.Map(inScopeLines, func(l billingentity.Line, _ int) string { + return l.InvoiceID + })) + + if len(sourceInvoiceIDs) == 0 { + return nil, billing.ValidationError{ + Err: fmt.Errorf("no source lines found"), + } + } + + // let's lock the source gathering invoices, so that no other invoice operation can interfere + err = txAdapter.LockInvoicesForUpdate(ctx, billing.LockInvoicesForUpdateInput{ + Namespace: input.Customer.Namespace, + InvoiceIDs: sourceInvoiceIDs, + }) + if err != nil { + return nil, fmt.Errorf("locking gathering invoices: %w", err) + } + + linesByCurrency := lo.GroupBy(inScopeLines, func(line billingentity.Line) currencyx.Code { + return line.Currency + }) + + createdInvoices := make([]billingentity.InvoiceID, 0, len(linesByCurrency)) + + for currency, lines := range linesByCurrency { + // let's create the invoice + invoice, err := txAdapter.CreateInvoice(ctx, billing.CreateInvoiceAdapterInput{ + Namespace: input.Customer.Namespace, + Customer: customerProfile.Customer, + Profile: customerProfile.Profile, + + Currency: currency, + Status: billingentity.InvoiceStatusDraft, + + Type: billingentity.InvoiceTypeStandard, + }) + if err != nil { + return nil, fmt.Errorf("creating invoice: %w", err) + } + + createdInvoices = append(createdInvoices, billingentity.InvoiceID{ + Namespace: invoice.Namespace, + ID: invoice.ID, + }) + + // let's associate the invoice lines to the invoice + err = s.associateLinesToInvoice(ctx, txAdapter, invoice, lines) + if err != nil { + return nil, fmt.Errorf("associating lines to invoice: %w", err) + } + } + + // Let's check if we need to remove any empty gathering invoices (e.g. if they don't have any line items) + // This typically should happen when a subscription has ended. + + invoiceLineCounts, err := txAdapter.AssociatedLineCounts(ctx, billing.AssociatedLineCountsAdapterInput{ + Namespace: input.Customer.Namespace, + InvoiceIDs: sourceInvoiceIDs, + }) + if err != nil { + return nil, fmt.Errorf("cleanup: line counts check: %w", err) + } + + invoicesWithoutLines := lo.Filter(sourceInvoiceIDs, func(id string, _ int) bool { + return invoiceLineCounts.Counts[billingentity.InvoiceID{ + Namespace: input.Customer.Namespace, + ID: id, + }] == 0 + }) + + if len(invoicesWithoutLines) > 0 { + err = txAdapter.DeleteInvoices(ctx, billing.DeleteInvoicesAdapterInput{ + Namespace: input.Customer.Namespace, + InvoiceIDs: invoicesWithoutLines, + }) + if err != nil { + return nil, fmt.Errorf("cleanup invoices: %w", err) + } + } + + // Assemble output: we need to refetch as the association call will have side-effects of updating + // invoice objects (e.g. totals, period, etc.) + out := make([]billingentity.Invoice, 0, len(createdInvoices)) + for _, invoiceID := range createdInvoices { + invoiceWithLines, err := s.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ + Invoice: invoiceID, + Expand: billing.InvoiceExpandAll, + }) + if err != nil { + return nil, fmt.Errorf("cannot get invoice[%s]: %w", invoiceWithLines.ID, err) + } + + out = append(out, invoiceWithLines) + } + return out, nil + }) +} + +func (s *Service) gatherInscopeLines(ctx context.Context, input billing.CreateInvoiceInput, txAdapter billing.Adapter, asOf time.Time) ([]billingentity.Line, error) { + if input.IncludePendingLines != nil { + inScopeLines, err := txAdapter.ListInvoiceLines(ctx, + billing.ListInvoiceLinesAdapterInput{ + Namespace: input.Customer.Namespace, + CustomerID: input.Customer.ID, + + LineIDs: input.IncludePendingLines, + }) + if err != nil { + return nil, fmt.Errorf("resolving in scope lines: %w", err) + } + + // output validation + + // asOf validity + for _, line := range inScopeLines { + if line.InvoiceAt.After(asOf) { + return nil, billing.ValidationError{ + Err: fmt.Errorf("line [%s] has invoiceAt [%s] after asOf [%s]", line.ID, line.InvoiceAt, asOf), + } + } + } + + // all lines must be found + if len(inScopeLines) != len(input.IncludePendingLines) { + includedLines := lo.Map(inScopeLines, func(l billingentity.Line, _ int) string { + return l.ID + }) + + missingIDs := lo.Without(input.IncludePendingLines, includedLines...) + + return nil, billing.NotFoundError{ + ID: strings.Join(missingIDs, ","), + Entity: billing.EntityInvoiceLine, + Err: fmt.Errorf("some invoice lines are not found"), + } + } + + return inScopeLines, nil + } + + lines, err := txAdapter.ListInvoiceLines(ctx, + billing.ListInvoiceLinesAdapterInput{ + Namespace: input.Customer.Namespace, + CustomerID: input.Customer.ID, + + InvoiceStatuses: []billingentity.InvoiceStatus{ + billingentity.InvoiceStatusGathering, + }, + + InvoiceAtBefore: lo.ToPtr(asOf), + }) + if err != nil { + return nil, err + } + + if len(lines) == 0 { + // We haven't requested explicit empty invoice, so we should have some pending lines + return nil, billing.ValidationError{ + Err: fmt.Errorf("no pending lines found"), + } + } + + return lines, nil +} diff --git a/openmeter/billing/service/invoiceline.go b/openmeter/billing/service/invoiceline.go index 7f604d772..e9cb6fa8a 100644 --- a/openmeter/billing/service/invoiceline.go +++ b/openmeter/billing/service/invoiceline.go @@ -4,12 +4,13 @@ import ( "context" "fmt" + "github.com/samber/lo" + "github.com/openmeterio/openmeter/api" "github.com/openmeterio/openmeter/openmeter/billing" billingentity "github.com/openmeterio/openmeter/openmeter/billing/entity" + customerentity "github.com/openmeterio/openmeter/openmeter/customer/entity" "github.com/openmeterio/openmeter/pkg/currencyx" - "github.com/openmeterio/openmeter/pkg/framework/entutils" - "github.com/openmeterio/openmeter/pkg/models" "github.com/openmeterio/openmeter/pkg/pagination" "github.com/openmeterio/openmeter/pkg/sortx" ) @@ -23,36 +24,44 @@ func (s *Service) CreateInvoiceLines(ctx context.Context, input billing.CreateIn } } - return entutils.TransactingRepo(ctx, s.adapter, func(ctx context.Context, txAdapter billing.Adapter) (*billing.CreateInvoiceLinesResponse, error) { - // let's resolve the customer's settings - customerProfile, err := s.GetProfileWithCustomerOverride(ctx, billing.GetProfileWithCustomerOverrideInput{ - Namespace: input.Namespace, - CustomerID: input.CustomerKeyOrID, - }) - if err != nil { - return nil, fmt.Errorf("fetching customer profile: %w", err) - } - - for i, line := range input.Lines { - updatedLine, err := s.upsertLineInvoice(ctx, txAdapter, line, input, customerProfile) + return TransactingRepoForGatheringInvoiceManipulation( + ctx, + s.adapter, + customerentity.CustomerID{ + Namespace: input.Namespace, + ID: input.CustomerID, + }, + func(ctx context.Context, txAdapter billing.Adapter) (*billing.CreateInvoiceLinesResponse, error) { + // let's resolve the customer's settings + customerProfile, err := s.GetProfileWithCustomerOverride(ctx, billing.GetProfileWithCustomerOverrideInput{ + Namespace: input.Namespace, + CustomerID: input.CustomerID, + }) if err != nil { - return nil, fmt.Errorf("upserting line[%d]: %w", i, err) + return nil, fmt.Errorf("fetching customer profile: %w", err) } - input.Lines[i] = updatedLine - } + // TODO: we should optimize this as this does O(n) queries for invoices per line + for i, line := range input.Lines { + updatedLine, err := s.upsertLineInvoice(ctx, txAdapter, line, input, customerProfile) + if err != nil { + return nil, fmt.Errorf("upserting line[%d]: %w", i, err) + } - // Create the invoice Lines - lines, err := txAdapter.CreateInvoiceLines(ctx, billing.CreateInvoiceLinesAdapterInput{ - Namespace: input.Namespace, - Lines: input.Lines, - }) - if err != nil { - return nil, fmt.Errorf("creating invoice Line: %w", err) - } + input.Lines[i] = updatedLine + } - return lines, nil - }) + // Create the invoice Lines + lines, err := txAdapter.CreateInvoiceLines(ctx, billing.CreateInvoiceLinesAdapterInput{ + Namespace: input.Namespace, + Lines: input.Lines, + }) + if err != nil { + return nil, fmt.Errorf("creating invoice Line: %w", err) + } + + return lines, nil + }) } func (s *Service) upsertLineInvoice(ctx context.Context, txAdapter billing.Adapter, line billingentity.Line, input billing.CreateInvoiceLinesInput, customerProfile *billingentity.ProfileWithCustomerDetails) (billingentity.Line, error) { @@ -62,7 +71,7 @@ func (s *Service) upsertLineInvoice(ctx context.Context, txAdapter billing.Adapt if line.InvoiceID != "" { // We would want to attach the line to an existing invoice invoice, err := txAdapter.GetInvoiceById(ctx, billing.GetInvoiceByIdInput{ - Invoice: models.NamespacedID{ + Invoice: billingentity.InvoiceID{ ID: line.InvoiceID, Namespace: input.Namespace, }, @@ -95,7 +104,7 @@ func (s *Service) upsertLineInvoice(ctx context.Context, txAdapter billing.Adapt PageNumber: 1, PageSize: 10, }, - Customers: []string{input.CustomerKeyOrID}, + Customers: []string{input.CustomerID}, Namespace: input.Namespace, Statuses: []billingentity.InvoiceStatus{billingentity.InvoiceStatusGathering}, Currencies: []currencyx.Code{line.Currency}, @@ -130,9 +139,37 @@ func (s *Service) upsertLineInvoice(ctx context.Context, txAdapter billing.Adapt // have multiple gathering invoices for the same customer. // This is a rare case, but we should log it at least, later we can implement a call that // merges these invoices (it's fine to just move the Lines to the first invoice) - s.logger.Warn("more than one pending invoice found", "customer", input.CustomerKeyOrID, "namespace", input.Namespace) + s.logger.Warn("more than one pending invoice found", "customer", input.CustomerID, "namespace", input.Namespace) } } return line, nil } + +func (s *Service) associateLinesToInvoice(ctx context.Context, txAdapter billing.Adapter, invoice billingentity.Invoice, lines []billingentity.Line) error { + for _, line := range lines { + if line.InvoiceID == invoice.ID { + return billing.ValidationError{ + Err: fmt.Errorf("line[%s]: line already associated with invoice[%s]", line.ID, invoice.ID), + } + } + } + + // Associate the lines to the invoice + err := txAdapter.AssociateLinesToInvoice(ctx, billing.AssociateLinesToInvoiceAdapterInput{ + Invoice: billingentity.InvoiceID{ + ID: invoice.ID, + Namespace: invoice.Namespace, + }, + + LineIDs: lo.Map(lines, func(l billingentity.Line, _ int) string { + return l.ID + }), + }) + if err != nil { + return err + } + + // TODO[later]: Here we need to recalculate any line specific fields for both invoices + return nil +} diff --git a/openmeter/billing/service/service.go b/openmeter/billing/service/service.go index bdf6fa01d..2fcd540b4 100644 --- a/openmeter/billing/service/service.go +++ b/openmeter/billing/service/service.go @@ -3,13 +3,14 @@ package billingservice import ( "context" "errors" + "fmt" "log/slog" "github.com/openmeterio/openmeter/openmeter/app" "github.com/openmeterio/openmeter/openmeter/billing" "github.com/openmeterio/openmeter/openmeter/customer" + customerentity "github.com/openmeterio/openmeter/openmeter/customer/entity" "github.com/openmeterio/openmeter/pkg/framework/entutils" - "github.com/openmeterio/openmeter/pkg/framework/transaction" ) var _ billing.Service = (*Service)(nil) @@ -61,17 +62,28 @@ func New(config Config) (*Service, error) { }, nil } -func Transaction[R any](ctx context.Context, creator billing.Adapter, cb func(ctx context.Context, tx billing.Adapter) (R, error)) (R, error) { - return transaction.Run(ctx, creator, func(ctx context.Context) (R, error) { - return entutils.TransactingRepo[R, billing.Adapter](ctx, creator, cb) - }) -} +// TransactingRepoForGatheringInvoiceManipulation is a helper function that wraps the given function in a transaction and ensures that +// an update lock is held on the customer record. This is useful when you need to manipulate the gathering invoices, as we cannot lock an +// invoice, that doesn't exist yet. +func TransactingRepoForGatheringInvoiceManipulation[T any](ctx context.Context, adapter billing.Adapter, customer customerentity.CustomerID, fn func(ctx context.Context, txAdapter billing.Adapter) (T, error)) (T, error) { + if err := customer.Validate(); err != nil { + var empty T + return empty, fmt.Errorf("validating customer: %w", err) + } + + // NOTE: This should not be in transaction, or we can get a conflict for parallel writes + err := adapter.UpsertCustomerOverride(ctx, customer) + if err != nil { + var empty T + return empty, fmt.Errorf("upserting customer override: %w", err) + } + + return entutils.TransactingRepo(ctx, adapter, func(ctx context.Context, txAdapter billing.Adapter) (T, error) { + if err := txAdapter.LockCustomerForUpdate(ctx, customer); err != nil { + var empty T + return empty, fmt.Errorf("locking customer for update: %w", err) + } -func TransactionWithNoValue(ctx context.Context, creator billing.Adapter, cb func(ctx context.Context, tx billing.Adapter) error) error { - return transaction.RunWithNoValue(ctx, creator, func(ctx context.Context) error { - _, err := entutils.TransactingRepo[interface{}, billing.Adapter](ctx, creator, func(ctx context.Context, rep billing.Adapter) (interface{}, error) { - return nil, cb(ctx, rep) - }) - return err + return fn(ctx, txAdapter) }) } diff --git a/test/billing/invoice_test.go b/test/billing/invoice_test.go index 1e2b19a26..70cdcc8a2 100644 --- a/test/billing/invoice_test.go +++ b/test/billing/invoice_test.go @@ -86,8 +86,8 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { res, err := s.BillingService.CreateInvoiceLines(ctx, billing.CreateInvoiceLinesInput{ - Namespace: namespace, - CustomerKeyOrID: customerEntity.ID, + Namespace: namespace, + CustomerID: customerEntity.ID, Lines: []billingentity.Line{ { LineBase: billingentity.LineBase{ @@ -326,3 +326,216 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { require.NotNil(s.T(), invoice.Workflow.Apps.Payment, "apps should be resolved") }) } + +func (s *InvoicingTestSuite) TestCreateInvoice() { + namespace := "ns-create-invoice-gathering-to-draft" + now := time.Now().Truncate(time.Microsecond) + periodEnd := now.Add(-time.Hour) + periodStart := periodEnd.Add(-time.Hour * 24 * 30) + line1IssueAt := now.Add(-2 * time.Hour) + line2IssueAt := now.Add(-time.Hour) + + _ = s.installSandboxApp(s.T(), namespace) + + ctx := context.Background() + + // Given we have a test customer + + customerEntity, err := s.CustomerService.CreateCustomer(ctx, customerentity.CreateCustomerInput{ + Namespace: namespace, + + Customer: customerentity.Customer{ + ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ + Name: "Test Customer", + }), + PrimaryEmail: lo.ToPtr("test@test.com"), + BillingAddress: &models.Address{ + Country: lo.ToPtr(models.CountryCode("US")), + }, + Currency: lo.ToPtr(currencyx.Code(currency.USD)), + }, + }) + require.NoError(s.T(), err) + require.NotNil(s.T(), customerEntity) + require.NotEmpty(s.T(), customerEntity.ID) + + // Given we have a default profile for the namespace + + minimalCreateProfileInput := minimalCreateProfileInputTemplate + minimalCreateProfileInput.Namespace = namespace + + profile, err := s.BillingService.CreateProfile(ctx, minimalCreateProfileInput) + + require.NoError(s.T(), err) + require.NotNil(s.T(), profile) + + res, err := s.BillingService.CreateInvoiceLines(ctx, + billing.CreateInvoiceLinesInput{ + Namespace: namespace, + CustomerID: customerEntity.ID, + Lines: []billingentity.Line{ + { + LineBase: billingentity.LineBase{ + Namespace: namespace, + Period: billingentity.Period{Start: periodStart, End: periodEnd}, + + InvoiceAt: line1IssueAt, + + Type: billingentity.InvoiceLineTypeManualFee, + + Name: "Test item1", + Currency: currencyx.Code(currency.USD), + + Metadata: map[string]string{ + "key": "value", + }, + }, + ManualFee: &billingentity.ManualFeeLine{ + Price: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(1), + }, + }, + { + LineBase: billingentity.LineBase{ + Namespace: namespace, + Period: billingentity.Period{Start: periodStart, End: periodEnd}, + + InvoiceAt: line2IssueAt, + + Type: billingentity.InvoiceLineTypeManualFee, + + Name: "Test item2", + Currency: currencyx.Code(currency.USD), + }, + ManualFee: &billingentity.ManualFeeLine{ + Price: alpacadecimal.NewFromFloat(200), + Quantity: alpacadecimal.NewFromFloat(3), + }, + }, + }, + }) + + // Then we should have the items created + require.NoError(s.T(), err) + require.Len(s.T(), res.Lines, 2) + line1ID := res.Lines[0].ID + line2ID := res.Lines[1].ID + require.NotEmpty(s.T(), line1ID) + require.NotEmpty(s.T(), line2ID) + + // Expect that a single gathering invoice has been created + require.Equal(s.T(), res.Lines[0].InvoiceID, res.Lines[1].InvoiceID) + gatheringInvoiceID := billingentity.InvoiceID{ + Namespace: namespace, + ID: res.Lines[0].InvoiceID, + } + + s.Run("Creating invoice in the future fails", func() { + _, err := s.BillingService.CreateInvoice(ctx, billing.CreateInvoiceInput{ + Customer: customerentity.CustomerID{ + ID: customerEntity.ID, + Namespace: customerEntity.Namespace, + }, + AsOf: lo.ToPtr(now.Add(time.Hour)), + }) + + require.Error(s.T(), err) + require.ErrorAs(s.T(), err, &billing.ValidationError{}) + }) + + s.Run("Creating invoice without any pending lines being available fails", func() { + _, err := s.BillingService.CreateInvoice(ctx, billing.CreateInvoiceInput{ + Customer: customerentity.CustomerID{ + ID: customerEntity.ID, + Namespace: customerEntity.Namespace, + }, + + AsOf: lo.ToPtr(line1IssueAt.Add(-time.Minute)), + }) + + require.Error(s.T(), err) + require.ErrorAs(s.T(), err, &billing.ValidationError{}) + }) + + s.Run("Number of pending invoice lines is reported correctly by the adapter", func() { + res, err := s.BillingAdapter.AssociatedLineCounts(ctx, billing.AssociatedLineCountsAdapterInput{ + Namespace: namespace, + InvoiceIDs: []string{gatheringInvoiceID.ID}, + }) + + require.NoError(s.T(), err) + require.Len(s.T(), res.Counts, 1) + require.Equal(s.T(), int64(2), res.Counts[gatheringInvoiceID]) + }) + + s.Run("When creating an invoice with only item1 included", func() { + invoice, err := s.BillingService.CreateInvoice(ctx, billing.CreateInvoiceInput{ + Customer: customerentity.CustomerID{ + ID: customerEntity.ID, + Namespace: customerEntity.Namespace, + }, + AsOf: lo.ToPtr(line1IssueAt.Add(time.Minute)), + }) + + // Then we should have the invoice created + require.NoError(s.T(), err) + require.Len(s.T(), invoice, 1) + + // Then we should have item1 added to the invoice + require.Len(s.T(), invoice[0].Lines, 1) + require.Equal(s.T(), line1ID, invoice[0].Lines[0].ID) + + // Then we expect that the gathering invoice is still present, with item2 + gatheringInvoice, err := s.BillingService.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ + Invoice: gatheringInvoiceID, + Expand: billing.InvoiceExpandAll, + }) + require.NoError(s.T(), err) + require.Nil(s.T(), gatheringInvoice.DeletedAt, "gathering invoice should be present") + require.Len(s.T(), gatheringInvoice.Lines, 1) + require.Equal(s.T(), line2ID, gatheringInvoice.Lines[0].ID) + }) + + s.Run("When creating an invoice with only item2 included, but bad asof", func() { + _, err := s.BillingService.CreateInvoice(ctx, billing.CreateInvoiceInput{ + Customer: customerentity.CustomerID{ + ID: customerEntity.ID, + Namespace: customerEntity.Namespace, + }, + IncludePendingLines: []string{line2ID}, + AsOf: lo.ToPtr(line1IssueAt.Add(time.Minute)), + }) + + // Then we should receive a validation error + require.Error(s.T(), err) + require.ErrorAs(s.T(), err, &billing.ValidationError{}) + }) + + s.Run("When creating an invoice with only item2 included", func() { + invoice, err := s.BillingService.CreateInvoice(ctx, billing.CreateInvoiceInput{ + Customer: customerentity.CustomerID{ + ID: customerEntity.ID, + Namespace: customerEntity.Namespace, + }, + IncludePendingLines: []string{line2ID}, + AsOf: lo.ToPtr(now), + }) + + // Then we should have the invoice created + require.NoError(s.T(), err) + require.Len(s.T(), invoice, 1) + + // Then we should have item2 added to the invoice + require.Len(s.T(), invoice[0].Lines, 1) + require.Equal(s.T(), line2ID, invoice[0].Lines[0].ID) + + // Then we expect that the gathering invoice is deleted and empty + gatheringInvoice, err := s.BillingService.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ + Invoice: gatheringInvoiceID, + Expand: billing.InvoiceExpandAll, + }) + require.NoError(s.T(), err) + require.NotNil(s.T(), gatheringInvoice.DeletedAt, "gathering invoice should be present") + require.Len(s.T(), gatheringInvoice.Lines, 0, "deleted gathering invoice is empty") + }) +}