Skip to content

Commit

Permalink
feat: create invoice backend logic (#1763)
Browse files Browse the repository at this point in the history
This patch adds the required backend code for the create invoice
now endpoint.

The overall logic is, that we take all of the pending invoice lines and 
whatever is due, we assign it to a new invoice.

Further invoice changes will be handled by a proper state machine,
but this transition requires cross-invoice state coordination.
  • Loading branch information
turip authored Oct 29, 2024
1 parent af1dcce commit d0c308e
Show file tree
Hide file tree
Showing 15 changed files with 840 additions and 65 deletions.
10 changes: 10 additions & 0 deletions openmeter/billing/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,25 @@ type CustomerOverrideAdapter interface {
UpdateCustomerOverride(ctx context.Context, input UpdateCustomerOverrideAdapterInput) (*billingentity.CustomerOverride, error)
DeleteCustomerOverride(ctx context.Context, input DeleteCustomerOverrideInput) error

// UpsertCustomerOverride upserts a customer override ignoring the transactional context, the override
// will be empty.
UpsertCustomerOverride(ctx context.Context, input UpsertCustomerOverrideAdapterInput) error
LockCustomerForUpdate(ctx context.Context, input LockCustomerForUpdateAdapterInput) error

GetCustomerOverrideReferencingProfile(ctx context.Context, input HasCustomerOverrideReferencingProfileAdapterInput) ([]customerentity.CustomerID, error)
}

type InvoiceLineAdapter interface {
CreateInvoiceLines(ctx context.Context, input CreateInvoiceLinesAdapterInput) (*CreateInvoiceLinesResponse, error)
ListInvoiceLines(ctx context.Context, input ListInvoiceLinesAdapterInput) ([]billingentity.Line, error)
AssociateLinesToInvoice(ctx context.Context, input AssociateLinesToInvoiceAdapterInput) error
}

type InvoiceAdapter interface {
CreateInvoice(ctx context.Context, input CreateInvoiceAdapterInput) (CreateInvoiceAdapterRespone, error)
GetInvoiceById(ctx context.Context, input GetInvoiceByIdInput) (billingentity.Invoice, error)
LockInvoicesForUpdate(ctx context.Context, input LockInvoicesForUpdateInput) error
DeleteInvoices(ctx context.Context, input DeleteInvoicesAdapterInput) error
ListInvoices(ctx context.Context, input ListInvoicesInput) (ListInvoicesResponse, error)
AssociatedLineCounts(ctx context.Context, input AssociatedLineCountsAdapterInput) (AssociatedLineCountsAdapterResponse, error)
}
40 changes: 37 additions & 3 deletions openmeter/billing/adapter/customeroverride.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package billingadapter

import (
"context"
"database/sql"
"fmt"

entsql "entgo.io/ent/dialect/sql"
"github.com/samber/lo"

"github.com/openmeterio/openmeter/openmeter/billing"
Expand Down Expand Up @@ -176,6 +178,33 @@ func (r *adapter) GetCustomerOverrideReferencingProfile(ctx context.Context, inp
return customerIDs, nil
}

func (r *adapter) UpsertCustomerOverride(ctx context.Context, input billing.UpsertCustomerOverrideAdapterInput) error {
err := r.db.BillingCustomerOverride.Create().
SetNamespace(input.Namespace).
SetCustomerID(input.ID).
OnConflict(
entsql.DoNothing(),
).
Exec(ctx)
if err != nil {
// The do nothing returns no lines, so we have the record ready
if err == sql.ErrNoRows {
return nil
}
}
return nil
}

func (r *adapter) LockCustomerForUpdate(ctx context.Context, input billing.LockCustomerForUpdateAdapterInput) error {
_, err := r.db.BillingCustomerOverride.Query().
Where(billingcustomeroverride.CustomerID(input.ID)).
Where(billingcustomeroverride.Namespace(input.Namespace)).
ForUpdate().
First(ctx)

return err
}

func mapCustomerOverrideFromDB(dbOverride *db.BillingCustomerOverride) (*billingentity.CustomerOverride, error) {
collectionInterval, err := dbOverride.LineCollectionPeriod.ParsePtrOrNil()
if err != nil {
Expand All @@ -197,6 +226,13 @@ func mapCustomerOverrideFromDB(dbOverride *db.BillingCustomerOverride) (*billing
return nil, fmt.Errorf("cannot map profile: %w", err)
}

var profile *billingentity.Profile
if baseProfile != nil {
profile = &billingentity.Profile{
BaseProfile: *baseProfile,
}
}

return &billingentity.CustomerOverride{
ID: dbOverride.ID,
Namespace: dbOverride.Namespace,
Expand All @@ -220,8 +256,6 @@ func mapCustomerOverrideFromDB(dbOverride *db.BillingCustomerOverride) (*billing
CollectionMethod: dbOverride.InvoiceCollectionMethod,
},

Profile: &billingentity.Profile{
BaseProfile: *baseProfile,
},
Profile: profile,
}, nil
}
99 changes: 99 additions & 0 deletions openmeter/billing/adapter/invoice.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package billingadapter

import (
"context"
"errors"
"fmt"
"strings"
"time"

"github.com/samber/lo"
Expand All @@ -12,6 +14,8 @@ import (
billingentity "github.com/openmeterio/openmeter/openmeter/billing/entity"
"github.com/openmeterio/openmeter/openmeter/ent/db"
"github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoice"
"github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoiceline"
"github.com/openmeterio/openmeter/pkg/clock"
"github.com/openmeterio/openmeter/pkg/framework/entutils"
"github.com/openmeterio/openmeter/pkg/models"
"github.com/openmeterio/openmeter/pkg/pagination"
Expand Down Expand Up @@ -55,6 +59,57 @@ func (r *adapter) GetInvoiceById(ctx context.Context, in billing.GetInvoiceByIdI
return mapInvoiceFromDB(*invoice, in.Expand)
}

func (r *adapter) LockInvoicesForUpdate(ctx context.Context, input billing.LockInvoicesForUpdateInput) error {
if err := input.Validate(); err != nil {
return billing.ValidationError{
Err: err,
}
}

ids, err := r.db.BillingInvoice.Query().
Where(billinginvoice.IDIn(input.InvoiceIDs...)).
Where(billinginvoice.Namespace(input.Namespace)).
ForUpdate().
Select(billinginvoice.FieldID).
Strings(ctx)
if err != nil {
return err
}

missingIds := lo.Without(input.InvoiceIDs, ids...)
if len(missingIds) > 0 {
return billing.NotFoundError{
Entity: billing.EntityInvoice,
ID: strings.Join(missingIds, ","),
Err: fmt.Errorf("cannot select invoices for update"),
}
}

return nil
}

func (r *adapter) DeleteInvoices(ctx context.Context, input billing.DeleteInvoicesAdapterInput) error {
if err := input.Validate(); err != nil {
return billing.ValidationError{
Err: err,
}
}

nAffected, err := r.db.BillingInvoice.Update().
Where(billinginvoice.IDIn(input.InvoiceIDs...)).
Where(billinginvoice.Namespace(input.Namespace)).
SetDeletedAt(clock.Now()).
Save(ctx)

if nAffected != len(input.InvoiceIDs) {
return billing.ValidationError{
Err: errors.New("invoices failed to delete"),
}
}

return err
}

// expandLineItems adds the required edges to the query so that line items can be properly mapped
func (r *adapter) expandLineItems(query *db.BillingInvoiceQuery) *db.BillingInvoiceQuery {
return query.WithBillingInvoiceLines(func(bilq *db.BillingInvoiceLineQuery) {
Expand Down Expand Up @@ -217,6 +272,50 @@ func (r *adapter) CreateInvoice(ctx context.Context, input billing.CreateInvoice
return mapInvoiceFromDB(*newInvoice, billing.InvoiceExpandAll)
}

type lineCountQueryOut struct {
InvoiceID string `json:"invoice_id"`
Count int64 `json:"count"`
}

func (r *adapter) AssociatedLineCounts(ctx context.Context, input billing.AssociatedLineCountsAdapterInput) (billing.AssociatedLineCountsAdapterResponse, error) {
queryOut := []lineCountQueryOut{}

err := r.db.BillingInvoiceLine.Query().
Where(billinginvoiceline.DeletedAtIsNil()).
Where(billinginvoiceline.Namespace(input.Namespace)).
Where(billinginvoiceline.InvoiceIDIn(input.InvoiceIDs...)).
Where(billinginvoiceline.StatusIn(billingentity.InvoiceLineStatusValid)).
GroupBy(billinginvoiceline.FieldInvoiceID).
Aggregate(
db.Count(),
).
Scan(ctx, &queryOut)
if err != nil {
return billing.AssociatedLineCountsAdapterResponse{}, err
}

res := lo.Associate(queryOut, func(q lineCountQueryOut) (billingentity.InvoiceID, int64) {
return billingentity.InvoiceID{
Namespace: input.Namespace,
ID: q.InvoiceID,
}, q.Count
})

for _, invoiceID := range input.InvoiceIDs {
id := billingentity.InvoiceID{
Namespace: input.Namespace,
ID: invoiceID,
}
if _, found := res[id]; !found {
res[id] = 0
}
}

return billing.AssociatedLineCountsAdapterResponse{
Counts: res,
}, nil
}

func mapInvoiceFromDB(invoice db.BillingInvoice, expand billing.InvoiceExpand) (billingentity.Invoice, error) {
res := billingentity.Invoice{
ID: invoice.ID,
Expand Down
64 changes: 56 additions & 8 deletions openmeter/billing/adapter/invoicelines.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
"github.com/openmeterio/openmeter/openmeter/billing"
billingentity "github.com/openmeterio/openmeter/openmeter/billing/entity"
"github.com/openmeterio/openmeter/openmeter/ent/db"
"github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoice"
"github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoiceline"
"github.com/openmeterio/openmeter/pkg/models"
)

var _ billing.InvoiceLineAdapter = (*adapter)(nil)
Expand Down Expand Up @@ -69,17 +69,65 @@ func (r *adapter) CreateInvoiceLines(ctx context.Context, input billing.CreateIn
return result, nil
}

func (r *adapter) GetInvoiceLineByID(ctx context.Context, id models.NamespacedID) (billingentity.Line, error) {
dbLine, err := r.db.BillingInvoiceLine.Query().
Where(billinginvoiceline.ID(id.ID)).
Where(billinginvoiceline.Namespace(id.Namespace)).
func (r *adapter) ListInvoiceLines(ctx context.Context, input billing.ListInvoiceLinesAdapterInput) ([]billingentity.Line, error) {
if err := input.Validate(); err != nil {
return nil, err
}

query := r.db.BillingInvoiceLine.Query().
Where(billinginvoiceline.Namespace(input.Namespace))

if len(input.LineIDs) > 0 {
query = query.Where(billinginvoiceline.IDIn(input.LineIDs...))
}

if input.InvoiceAtBefore != nil {
query = query.Where(billinginvoiceline.InvoiceAtLT(*input.InvoiceAtBefore))
}

query = query.WithBillingInvoice(func(biq *db.BillingInvoiceQuery) {
biq.Where(billinginvoice.Namespace(input.Namespace))

if input.CustomerID != "" {
biq.Where(billinginvoice.CustomerID(input.CustomerID))
}

if len(input.InvoiceStatuses) > 0 {
biq.Where(billinginvoice.StatusIn(input.InvoiceStatuses...))
}
})

dbLines, err := query.
WithBillingInvoiceManualLines().
Only(ctx)
All(ctx)
if err != nil {
return nil, err
}

return lo.Map(dbLines, func(line *db.BillingInvoiceLine, _ int) billingentity.Line {
return mapInvoiceLineFromDB(line)
}), nil
}

func (r *adapter) AssociateLinesToInvoice(ctx context.Context, input billing.AssociateLinesToInvoiceAdapterInput) error {
if err := input.Validate(); err != nil {
return err
}

nAffected, err := r.db.BillingInvoiceLine.Update().
SetInvoiceID(input.Invoice.ID).
Where(billinginvoiceline.Namespace(input.Invoice.Namespace)).
Where(billinginvoiceline.IDIn(input.LineIDs...)).
Save(ctx)
if err != nil {
return billingentity.Line{}, err
return fmt.Errorf("associating lines: %w", err)
}

if nAffected != len(input.LineIDs) {
return fmt.Errorf("fewer lines were associated (%d) than expected (%d)", nAffected, len(input.LineIDs))
}

return mapInvoiceLineFromDB(dbLine), nil
return nil
}

func mapInvoiceLineFromDB(dbLine *db.BillingInvoiceLine) billingentity.Line {
Expand Down
5 changes: 5 additions & 0 deletions openmeter/billing/customeroverride.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,8 @@ type HasCustomerOverrideReferencingProfileAdapterInput genericNamespaceID
func (i HasCustomerOverrideReferencingProfileAdapterInput) Validate() error {
return genericNamespaceID(i).Validate()
}

type (
UpsertCustomerOverrideAdapterInput = customerentity.CustomerID
LockCustomerForUpdateAdapterInput = customerentity.CustomerID
)
6 changes: 6 additions & 0 deletions openmeter/billing/entity/invoice.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ func (s InvoiceStatus) IsMutable() bool {
return true
}

type InvoiceID models.NamespacedID

func (i InvoiceID) Validate() error {
return models.NamespacedID(i).Validate()
}

type Invoice struct {
Namespace string `json:"namespace"`
ID string `json:"id"`
Expand Down
1 change: 1 addition & 0 deletions openmeter/billing/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ const (
EntityCustomer = "Customer"
EntityDefaultProfile = "DefaultBillingProfile"
EntityInvoice = "Invoice"
EntityInvoiceLine = "InvoiceLine"
)

type NotFoundError struct {
Expand Down
8 changes: 4 additions & 4 deletions openmeter/billing/httpdriver/invoiceline.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type (

func (h *handler) CreateLineByCustomer() CreateLineByCustomerHandler {
return httptransport.NewHandlerWithArgs(
func(ctx context.Context, r *http.Request, customerKeyOrId string) (CreateLineByCustomerRequest, error) {
func(ctx context.Context, r *http.Request, customerID string) (CreateLineByCustomerRequest, error) {
body := api.BillingCreateLineByCustomerJSONRequestBody{}

if err := commonhttp.JSONRequestBodyDecoder(r, &body); err != nil {
Expand All @@ -46,9 +46,9 @@ func (h *handler) CreateLineByCustomer() CreateLineByCustomerHandler {
}

return CreateLineByCustomerRequest{
CustomerKeyOrID: customerKeyOrId,
Namespace: ns,
Lines: lines,
CustomerID: customerID,
Namespace: ns,
Lines: lines,
}, nil
},
func(ctx context.Context, request CreateLineByCustomerRequest) (CreateLineByCustomerResponse, error) {
Expand Down
Loading

0 comments on commit d0c308e

Please sign in to comment.