diff --git a/openmeter/billing/adapter/invoicelines.go b/openmeter/billing/adapter/invoicelines.go index 093e53d5b..194973fed 100644 --- a/openmeter/billing/adapter/invoicelines.go +++ b/openmeter/billing/adapter/invoicelines.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/alpacahq/alpacadecimal" + "github.com/oklog/ulid/v2" "github.com/samber/lo" "github.com/openmeterio/openmeter/openmeter/billing" @@ -44,6 +45,14 @@ func (r *adapter) CreateInvoiceLines(ctx context.Context, input billing.CreateIn newEnt = newEnt.SetTaxConfig(*line.TaxConfig) } + if line.ChildUniqueReferenceID != nil { + newEnt = newEnt.SetChildUniqueReferenceID(*line.ChildUniqueReferenceID) + } else { + id := ulid.Make().String() + newEnt = newEnt.SetChildUniqueReferenceID(id). + SetID(id) + } + edges := db.BillingInvoiceLineEdges{} switch line.Type { @@ -230,6 +239,12 @@ func (r *adapter) UpdateInvoiceLine(ctx context.Context, input billing.UpdateInv SetStatus(input.Status). SetOrClearTaxConfig(input.TaxConfig) + if input.ChildUniqueReferenceID != nil { + updateLine = updateLine.SetChildUniqueReferenceID(*input.ChildUniqueReferenceID) + } else { + updateLine = updateLine.SetChildUniqueReferenceID(existingLine.ID) + } + edges := db.BillingInvoiceLineEdges{} // Let's update the line based on the type @@ -370,6 +385,10 @@ func mapInvoiceLineFromDB(dbLine *db.BillingInvoiceLine) (billingentity.Line, er }, ParentLineID: dbLine.ParentLineID, + ChildUniqueReferenceID: lo.If( + dbLine.ChildUniqueReferenceID != dbLine.ID, + lo.ToPtr(dbLine.ChildUniqueReferenceID), + ).Else(nil), InvoiceAt: dbLine.InvoiceAt, diff --git a/openmeter/billing/entity/errors.go b/openmeter/billing/entity/errors.go index 964a5736b..8a966f285 100644 --- a/openmeter/billing/entity/errors.go +++ b/openmeter/billing/entity/errors.go @@ -28,6 +28,8 @@ var ( ErrInvoiceActionNotAvailable = NewValidationError("invoice_action_not_available", "invoice action not available") ErrInvoiceLineFeatureHasNoMeters = NewValidationError("invoice_line_feature_has_no_meters", "usage based invoice line: feature has no meters") + ErrInvoiceLineGraduatedSplitNotSupported = NewValidationError("invoice_line_graduated_split_not_supported", "graduated tiered pricing is not supported for split periods") + ErrInvoiceLineNoTiers = NewValidationError("invoice_line_no_tiers", "usage based invoice line: no tiers found") ErrInvoiceCreateNoLines = NewValidationError("invoice_create_no_lines", "the new invoice would have no lines") ErrInvoiceCreateUBPLineCustomerHasNoSubjects = NewValidationError("invoice_create_ubp_line_customer_has_no_subjects", "creating an usage based line: customer has no subjects") ErrInvoiceCreateUBPLinePeriodIsEmpty = NewValidationError("invoice_create_ubp_line_period_is_empty", "creating an usage based line: truncated period is empty") diff --git a/openmeter/billing/entity/invoiceline.go b/openmeter/billing/entity/invoiceline.go index 515536c08..951c6a666 100644 --- a/openmeter/billing/entity/invoiceline.go +++ b/openmeter/billing/entity/invoiceline.go @@ -34,6 +34,8 @@ const ( InvoiceLineStatusValid InvoiceLineStatus = "valid" // InvoiceLineStatusSplit is a split invoice line (the child lines will have this set as parent). InvoiceLineStatusSplit InvoiceLineStatus = "split" + // InvoiceLineStatusDetailed is a detailed invoice line. + InvoiceLineStatusDetailed InvoiceLineStatus = "detailed" ) func (InvoiceLineStatus) Values() []string { @@ -109,12 +111,14 @@ type LineBase struct { // TODO: Add discounts etc // Relationships - ParentLineID *string `json:"parentLine,omitempty"` - ParentLine *Line `json:"parent,omitempty"` - RelatedLines []string `json:"relatedLine,omitempty"` - Status InvoiceLineStatus `json:"status"` + ParentLineID *string `json:"parentLine,omitempty"` + ParentLine *Line `json:"parent,omitempty"` + DetailedLines []Line `json:"detailedLines,omitempty"` + Status InvoiceLineStatus `json:"status"` + ChildUniqueReferenceID *string `json:"childUniqueReferenceID,omitempty"` - TaxConfig *TaxConfig `json:"taxOverrides,omitempty"` + TaxConfig *TaxConfig `json:"taxOverrides,omitempty"` + Discounts []LineDiscount `json:"discounts,omitempty"` Total alpacadecimal.Decimal `json:"total"` } @@ -228,3 +232,29 @@ func (i UsageBasedLine) Validate() error { return nil } + +type LineDiscountSource string + +const ( + // ManualLineDiscountSource is a manually added discount. + ManualLineDiscountSource LineDiscountSource = "manual" + // CalculatedLineDiscountSource is a discount applied due to maximum spend. + CalculatedLineDiscountSource LineDiscountSource = "calculated" +) + +type LineDiscountType string + +const ( + // MaximumSpendLineDiscountType is a discount applied due to maximum spend. + MaximumSpendLineDiscountType LineDiscountType = "maximum_spend" + // CappedTierLineDiscountType is a discount applied due to capped tier (e.g. we are over the biggest tier and the tier structure is not open ended). + CappedTierLineDiscountType LineDiscountType = "capped_tier" +) + +type LineDiscount struct { + ID string `json:"id"` + Amount alpacadecimal.Decimal `json:"amount"` + Description *string `json:"description,omitempty"` + Type *LineDiscountType `json:"type,omitempty"` + Source LineDiscountSource `json:"source"` +} diff --git a/openmeter/billing/service/lineservice/linebase.go b/openmeter/billing/service/lineservice/linebase.go index b6d84fe7b..4992b4f75 100644 --- a/openmeter/billing/service/lineservice/linebase.go +++ b/openmeter/billing/service/lineservice/linebase.go @@ -51,6 +51,8 @@ type LineBase interface { Period() billingentity.Period Status() billingentity.InvoiceLineStatus HasParent() bool + // IsLastInPeriod returns true if the line is the last line in the period that is going to be invoiced. + IsLastInPeriod() bool CloneForCreate(in UpdateInput) Line Update(in UpdateInput) Line @@ -112,6 +114,18 @@ func (l lineBase) Validate(ctx context.Context, invoice *billingentity.Invoice) return nil } +func (l lineBase) IsLastInPeriod() bool { + return (l.line.Status == billingentity.InvoiceLineStatusValid && // We only care about valid lines + (l.line.ParentLineID == nil || // Either we haven't split the line + l.line.Period.End.Equal(l.line.ParentLine.Period.End))) // Or we have split the line and this is the last split +} + +func (l lineBase) IsFirstInPeriod() bool { + return (l.line.Status == billingentity.InvoiceLineStatusValid && // We only care about valid lines + (l.line.ParentLineID == nil || // Either we haven't split the line + l.line.Period.Start.Equal(l.line.ParentLine.Period.Start))) // Or we have split the line and this is the last split +} + func (l lineBase) Save(ctx context.Context) (Line, error) { line, err := l.service.BillingAdapter.UpdateInvoiceLine(ctx, billing.UpdateInvoiceLineAdapterInput(l.line)) if err != nil { diff --git a/openmeter/billing/service/lineservice/service.go b/openmeter/billing/service/lineservice/service.go index 66365757c..3ed832ff6 100644 --- a/openmeter/billing/service/lineservice/service.go +++ b/openmeter/billing/service/lineservice/service.go @@ -169,8 +169,8 @@ func (s *Service) AssociateLinesToInvoice(ctx context.Context, invoice *billinge } type snapshotQuantityResult struct { - Line Line - // TODO[OM-980]: Detailed lines should be returned here, that we are upserting based on the qty as described in README.md (see `Detailed Lines vs Splitting`) + Line Line + DetailedLines []Line } type Line interface { diff --git a/openmeter/billing/service/lineservice/usagebasedline.go b/openmeter/billing/service/lineservice/usagebasedline.go index b2589f7a4..78ca5d2e0 100644 --- a/openmeter/billing/service/lineservice/usagebasedline.go +++ b/openmeter/billing/service/lineservice/usagebasedline.go @@ -2,17 +2,35 @@ package lineservice import ( "context" + "fmt" "time" + "github.com/alpacahq/alpacadecimal" "github.com/samber/lo" billingentity "github.com/openmeterio/openmeter/openmeter/billing/entity" "github.com/openmeterio/openmeter/openmeter/productcatalog/plan" "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/slicesx" ) var _ Line = usageBasedLine{} +const ( + FlatPriceChildUniqueReferenceID = "flat-price" + UnitPriceUsageChildUniqueReferenceID = "unit-price-usage" + UnitPriceMinSpendChildUniqueReferenceID = "unit-price-min-spend" + UnitPriceMaxSpendChildUniqueReferenceID = "unit-price-max-spend" + + GraduatedFlatPriceChildUniqueReferenceID = "graduated-flat-price" + GraduatedUnitPriceChildUniqueReferenceID = "graduated-tiered-price" + GraduatedMinSpendChildUniqueReferenceID = "graduated-min-spend" + + VolumeTieredPriceUsageChildUniqueReferenceID = "volume-tiered-%d-price-usage" + VolumeTieredFlatPriceChildUniqueReferenceID = "volume-tiered-%d-flat-price" + VolumeMinSpendChildUniqueReferenceID = "volume-tiered-min-spend" +) + type usageBasedLine struct { lineBase } @@ -135,3 +153,621 @@ func (l usageBasedLine) SnapshotQuantity(ctx context.Context, invoice *billingen Line: updatedLine, }, nil } + +func (l usageBasedLine) calculateDetailedLines(usage *featureUsageResponse) (newDetailedLinesInput, error) { + switch l.line.UsageBased.Price.Type() { + case plan.FlatPriceType: + flatPrice, err := l.line.UsageBased.Price.AsFlat() + if err != nil { + return nil, fmt.Errorf("converting price to flat price: %w", err) + } + return l.calculateFlatPriceDetailedLines(usage, flatPrice) + + case plan.UnitPriceType: + unitPrice, err := l.line.UsageBased.Price.AsUnit() + if err != nil { + return nil, fmt.Errorf("converting price to unit price: %w", err) + } + + return l.calculateUnitPriceDetailedLines(usage, unitPrice) + case plan.TieredPriceType: + tieredPrice, err := l.line.UsageBased.Price.AsTiered() + if err != nil { + return nil, fmt.Errorf("converting price to tiered price: %w", err) + } + + switch tieredPrice.Mode { + case plan.GraduatedTieredPrice: + return l.calculateGraduatedTieredPriceDetailedLines(usage, tieredPrice) + + case plan.VolumeTieredPrice: + return l.calculateVolumeTieredPriceDetailedLines(usage, tieredPrice) + default: + return nil, fmt.Errorf("unsupported tiered price mode: %s", tieredPrice.Mode) + } + default: + return nil, fmt.Errorf("unsupported price type: %s", l.line.UsageBased.Price.Type()) + } +} + +func (l usageBasedLine) calculateFlatPriceDetailedLines(_ *featureUsageResponse, flatPrice plan.FlatPrice) (newDetailedLinesInput, error) { + // Flat price is the same as the non-metered version, we just allow attaching entitlements to it + switch { + case flatPrice.PaymentTerm == plan.InAdvancePaymentTerm && l.IsFirstInPeriod(): + return newDetailedLinesInput{ + { + Name: l.line.Name, + Quantity: alpacadecimal.NewFromInt(1), + Amount: flatPrice.Amount, + ChildUniqueReferenceID: FlatPriceChildUniqueReferenceID, + PaymentTerm: plan.InAdvancePaymentTerm, + }, + }, nil + case flatPrice.PaymentTerm != plan.InAdvancePaymentTerm && l.IsLastInPeriod(): + return newDetailedLinesInput{ + { + Name: l.line.Name, + Quantity: alpacadecimal.NewFromInt(1), + Amount: flatPrice.Amount, + ChildUniqueReferenceID: FlatPriceChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, nil + } + + return nil, nil +} + +func (l usageBasedLine) calculateUnitPriceDetailedLines(usage *featureUsageResponse, unitPrice plan.UnitPrice) (newDetailedLinesInput, error) { + out := make(newDetailedLinesInput, 0, 3) + totalPreUsageAmount := usage.PreLinePeriodQty.Mul(unitPrice.Amount) + + if usage.LinePeriodQty.IsPositive() { + usageLine := newDetailedLineInput{ + Name: fmt.Sprintf("%s: usage in period", l.line.Name), + Quantity: usage.LinePeriodQty, + Amount: unitPrice.Amount, + ChildUniqueReferenceID: UnitPriceUsageChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + } + + if unitPrice.MaximumAmount != nil { + // We need to apply the discount for the usage that is over the maximum spend + usageLine = usageLine.AddDiscountForOverage(addDiscountInput{ + BilledAmountBeforeLine: totalPreUsageAmount, + MaxSpend: *unitPrice.MaximumAmount, + }) + } + + out = append(out, usageLine) + } + + // Minimum spend is always billed arrears + if l.IsLastInPeriod() && unitPrice.MinimumAmount != nil { + totalUsageAmount := totalPreUsageAmount.Add(out.Sum()) + if totalUsageAmount.LessThan(*unitPrice.MinimumAmount) { + period := l.line.Period + if l.line.ParentLine != nil { + period = l.line.ParentLine.Period + } + + out = append(out, newDetailedLineInput{ + Name: fmt.Sprintf("%s: minimum spend", l.line.Name), + Quantity: alpacadecimal.NewFromInt(1), + Amount: unitPrice.MinimumAmount.Sub(totalUsageAmount), + // Min spend is always billed for the whole period + Period: &period, + ChildUniqueReferenceID: UnitPriceMinSpendChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }) + } + } + + return out, nil +} + +func (l usageBasedLine) calculateGraduatedTieredPriceDetailedLines(usage *featureUsageResponse, price plan.TieredPrice) (newDetailedLinesInput, error) { + if !usage.PreLinePeriodQty.IsZero() { + return nil, billingentity.ErrInvoiceLineGraduatedSplitNotSupported + } + + if !l.IsLastInPeriod() { + return nil, nil + } + + out := make(newDetailedLinesInput, 0, 4) + + // No usage => we are not billing any tiers + if !usage.LinePeriodQty.IsZero() { + tier, tierIndex := findTierForQuantity(price, usage.LinePeriodQty) + if tier == nil { + return nil, fmt.Errorf("could not find tier for quantity %s (most probably tier is not open ended, thus invalid)", usage.LinePeriodQty) + } + + if tier.FlatPrice != nil { + line := newDetailedLineInput{ + Name: fmt.Sprintf("%s: flat price for tier %d", l.line.Name, tierIndex+1), + Quantity: alpacadecimal.NewFromInt(1), + Amount: tier.FlatPrice.Amount, + ChildUniqueReferenceID: GraduatedFlatPriceChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + } + + if price.MaximumAmount != nil { + line = line.AddDiscountForOverage(addDiscountInput{ + BilledAmountBeforeLine: out.Sum(), + MaxSpend: *price.MaximumAmount, + }) + } + out = append(out, line) + } + + if tier.UnitPrice != nil { + line := newDetailedLineInput{ + Name: fmt.Sprintf("%s: unit price for tier %d", l.line.Name, tierIndex+1), + Quantity: usage.LinePeriodQty, + Amount: tier.UnitPrice.Amount, + ChildUniqueReferenceID: GraduatedUnitPriceChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + } + + if price.MaximumAmount != nil { + line = line.AddDiscountForOverage(addDiscountInput{ + BilledAmountBeforeLine: out.Sum(), + MaxSpend: *price.MaximumAmount, + }) + } + + out = append(out, line) + } + } + + total := out.Sum() + + if price.MinimumAmount != nil && total.LessThan(*price.MinimumAmount) { + out = append(out, newDetailedLineInput{ + Name: fmt.Sprintf("%s: minimum spend", l.line.Name), + Quantity: alpacadecimal.NewFromInt(1), + Amount: price.MinimumAmount.Sub(total), + ChildUniqueReferenceID: GraduatedMinSpendChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }) + } + + return out, nil +} + +func findTierForQuantity(price plan.TieredPrice, quantity alpacadecimal.Decimal) (*plan.PriceTier, int) { + for i, tier := range price.WithSortedTiers().Tiers { + if tier.UpToAmount == nil || quantity.LessThanOrEqual(*tier.UpToAmount) { + return &price.Tiers[i], i + } + } + + // Technically this should not happen, as the last tier should have an upper limit of infinity + return nil, 0 +} + +func (l usageBasedLine) calculateVolumeTieredPriceDetailedLines(usage *featureUsageResponse, price plan.TieredPrice) (newDetailedLinesInput, error) { + out := make(newDetailedLinesInput, 0, len(price.Tiers)) + + err := tieredPriceCalculator(tieredPriceCalculatorInput{ + TieredPrice: price, + FromQty: usage.PreLinePeriodQty, + ToQty: usage.LinePeriodQty.Add(usage.PreLinePeriodQty), + TierCallbackFn: func(in tierCallbackInput) error { + billedAmount := in.PreviousTotalAmount + + tierIndex := in.TierIndex + 1 + + if in.Tier.UnitPrice != nil { + newLine := newDetailedLineInput{ + Name: fmt.Sprintf("%s: usage price for tier %d", l.line.Name, tierIndex), + Quantity: in.Quantity, + Amount: in.Tier.UnitPrice.Amount, + ChildUniqueReferenceID: fmt.Sprintf(VolumeTieredPriceUsageChildUniqueReferenceID, tierIndex), + PaymentTerm: plan.InArrearsPaymentTerm, + } + + if price.MaximumAmount != nil { + newLine = newLine.AddDiscountForOverage(addDiscountInput{ + BilledAmountBeforeLine: billedAmount, + MaxSpend: *price.MaximumAmount, + }) + } + + billedAmount = billedAmount.Add(in.Quantity.Mul(in.Tier.UnitPrice.Amount)) + + out = append(out, newLine) + } + + // Flat price is always billed for the whole tier when we are crossing the tier boundary + if in.Tier.FlatPrice != nil && in.AtTierBoundary { + newLine := newDetailedLineInput{ + Name: fmt.Sprintf("%s: flat price for tier %d", l.line.Name, tierIndex), + Quantity: alpacadecimal.NewFromInt(1), + Amount: in.Tier.FlatPrice.Amount, + ChildUniqueReferenceID: fmt.Sprintf(VolumeTieredFlatPriceChildUniqueReferenceID, tierIndex), + PaymentTerm: plan.InArrearsPaymentTerm, + } + + if price.MaximumAmount != nil { + newLine = newLine.AddDiscountForOverage(addDiscountInput{ + BilledAmountBeforeLine: billedAmount, + MaxSpend: *price.MaximumAmount, + }) + } + + out = append(out, newLine) + } + return nil + }, + FinalizerFn: func(periodTotal alpacadecimal.Decimal) error { + if l.IsLastInPeriod() && price.MinimumAmount != nil && periodTotal.LessThan(*price.MinimumAmount) { + out = append(out, newDetailedLineInput{ + Name: fmt.Sprintf("%s: minimum spend", l.line.Name), + Quantity: alpacadecimal.NewFromInt(1), + Amount: price.MinimumAmount.Sub(periodTotal), + ChildUniqueReferenceID: VolumeMinSpendChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }) + } + + return nil + }, + }) + if err != nil { + return nil, fmt.Errorf("calculating tiered price: %w", err) + } + + return out, nil +} + +type tierRange struct { + Tier plan.PriceTier + TierIndex int + + FromQty alpacadecimal.Decimal // exclusive + ToQty alpacadecimal.Decimal // inclusive + AtTierBoundary bool +} + +type tierCallbackInput struct { + Tier plan.PriceTier + TierIndex int + Quantity alpacadecimal.Decimal + AtTierBoundary bool + PreviousTotalAmount alpacadecimal.Decimal +} + +type tieredPriceCalculatorInput struct { + TieredPrice plan.TieredPrice + // FromQty is the quantity that was already billed for the previous tiers (exclusive) + FromQty alpacadecimal.Decimal + // ToQty is the quantity that we are going to bill for this tiered price (inclusive) + ToQty alpacadecimal.Decimal + + TierCallbackFn func(tierCallbackInput) error + FinalizerFn func(total alpacadecimal.Decimal) error + IntrospectRangesFn func(ranges []tierRange) +} + +func (i tieredPriceCalculatorInput) Validate() error { + if err := i.TieredPrice.Validate(); err != nil { + return err + } + + if i.TieredPrice.Mode != plan.VolumeTieredPrice { + return fmt.Errorf("only volume tiered prices are supported") + } + + if i.FromQty.IsNegative() { + return fmt.Errorf("from quantity must be zero or positive") + } + + if i.ToQty.IsNegative() { + return fmt.Errorf("to quantity must be zero or positive") + } + + if i.ToQty.LessThan(i.FromQty) { + return fmt.Errorf("to quantity must be greater or equal to from quantity") + } + + return nil +} + +func splitTierRangeAtBoundary(from, to alpacadecimal.Decimal, qtyRange tierRange) []tierRange { + res := make([]tierRange, 0, 3) + + // Pending line is always the last line, as we might need to split it + pendingLine := qtyRange + + // If from == in.FromQty we don't need to split the range, as the range is already at some boundary + if pendingLine.FromQty.LessThan(from) && pendingLine.ToQty.GreaterThan(from) { + // We need to split the range at the from boundary + res = append(res, tierRange{ + Tier: pendingLine.Tier, + TierIndex: pendingLine.TierIndex, + + FromQty: pendingLine.FromQty, + ToQty: from, + + AtTierBoundary: pendingLine.AtTierBoundary, + }) + + pendingLine = tierRange{ + Tier: pendingLine.Tier, + TierIndex: pendingLine.TierIndex, + + FromQty: from, + ToQty: pendingLine.ToQty, + } + } + + // If to == in.ToQty we don't need to split the range, as the range is already at some boundary + if pendingLine.FromQty.LessThan(to) && pendingLine.ToQty.GreaterThan(to) { + res = append(res, tierRange{ + Tier: pendingLine.Tier, + TierIndex: pendingLine.TierIndex, + + FromQty: pendingLine.FromQty, + ToQty: to, + + AtTierBoundary: pendingLine.AtTierBoundary, + }) + pendingLine = tierRange{ + Tier: pendingLine.Tier, + TierIndex: pendingLine.TierIndex, + + FromQty: to, + ToQty: pendingLine.ToQty, + } + } + + return append(res, pendingLine) +} + +// getTotalAmountForVolumeTieredPrice calculates the total amount for a volume tiered price for a given quantity +// without considering any discounts +func tieredPriceCalculator(in tieredPriceCalculatorInput) error { + // Note: this is not the most efficient algorithm, but it is at least pseudo-readable + if err := in.Validate(); err != nil { + return err + } + + // Let's break up the tiers and the input data into a sequence of periods, for easier processing + // Invariant of the qtyRanges: + // - Non overlapping ranges + // - The ranges are sorted by the from quantity + // - There is always one range for which range.From == in.FromQty + // - There is always one range for which range.ToQty == in.ToQty + qtyRanges := make([]tierRange, 0, len(in.TieredPrice.Tiers)+2) + + previousTierQty := alpacadecimal.Zero + for idx, tier := range in.TieredPrice.WithSortedTiers().Tiers { + if previousTierQty.GreaterThanOrEqual(in.ToQty) { + // We already have enough data to bill for this tiered price + break + } + + // Given that the previous tier's max qty was less than then in.ToQty, toQty will fall into the + // open ended tier, so we can safely use it as the upper bound + tierUpperBound := in.ToQty + if tier.UpToAmount != nil { + tierUpperBound = *tier.UpToAmount + } + + input := tierRange{ + Tier: tier, + TierIndex: idx, + AtTierBoundary: true, + FromQty: previousTierQty, + ToQty: tierUpperBound, + } + + qtyRanges = append(qtyRanges, splitTierRangeAtBoundary(in.FromQty, in.ToQty, input)...) + + previousTierQty = tierUpperBound + } + + if in.IntrospectRangesFn != nil { + in.IntrospectRangesFn(qtyRanges) + } + + // Now that we have the ranges, let's iterate over the ranges and calculate the cummulative total amount + // and call the callback for each in-scope range + total := alpacadecimal.Zero + shouldEmitCallbacks := false + for _, qtyRange := range qtyRanges { + if qtyRange.FromQty.Equal(in.FromQty) { + shouldEmitCallbacks = true + } + + if shouldEmitCallbacks && in.TierCallbackFn != nil { + err := in.TierCallbackFn(tierCallbackInput{ + Tier: qtyRange.Tier, + TierIndex: qtyRange.TierIndex, + Quantity: qtyRange.ToQty.Sub(qtyRange.FromQty), + PreviousTotalAmount: total, + AtTierBoundary: qtyRange.AtTierBoundary, + }) + if err != nil { + return err + } + } + + // Let's update totals + if qtyRange.Tier.FlatPrice != nil && qtyRange.AtTierBoundary { + total = total.Add(qtyRange.Tier.FlatPrice.Amount) + } + + if qtyRange.Tier.UnitPrice != nil { + total = total.Add(qtyRange.ToQty.Sub(qtyRange.FromQty).Mul(qtyRange.Tier.UnitPrice.Amount)) + } + + // We should only calculate totals up to in.ToQty (given tiers are open-ended we cannot have a full upper bound + // either ways) + if qtyRange.ToQty.GreaterThanOrEqual(in.ToQty) { + break + } + } + + if in.FinalizerFn != nil { + if err := in.FinalizerFn(total); err != nil { + return err + } + } + + return nil +} + +type newDetailedLinesInput []newDetailedLineInput + +func (i newDetailedLinesInput) Sum() alpacadecimal.Decimal { + sum := alpacadecimal.Zero + + for _, in := range i { + sum = sum.Add(in.Amount.Mul(in.Quantity)) + + for _, discount := range in.Discounts { + sum = sum.Sub(discount.Amount) + } + } + + return sum +} + +func (i newDetailedLinesInput) ApplyDiscount(amount alpacadecimal.Decimal, discountType billingentity.LineDiscountType) alpacadecimal.Decimal { + remaining := amount + for idx := range i { + if remaining.IsZero() { + break + } + + line := &i[idx] + lineTotal := line.Amount.Mul(line.Quantity) + + if lineTotal.LessThan(remaining) { + line.Discounts = append(line.Discounts, billingentity.LineDiscount{ + Amount: lineTotal, + Description: formatMaximumSpendDiscountDescription(lineTotal), + Type: lo.ToPtr(discountType), + Source: billingentity.CalculatedLineDiscountSource, + }) + + remaining = remaining.Sub(lineTotal) + continue + } + + line.Discounts = append(line.Discounts, billingentity.LineDiscount{ + Amount: remaining, + Description: formatMaximumSpendDiscountDescription(remaining), + Type: lo.ToPtr(discountType), + Source: billingentity.CalculatedLineDiscountSource, + }) + } + return remaining +} + +type newDetailedLineInput struct { + Name string `json:"name"` + Quantity alpacadecimal.Decimal `json:"quantity"` + Amount alpacadecimal.Decimal `json:"amount"` + ChildUniqueReferenceID string `json:"childUniqueReferenceID"` + Period *billingentity.Period `json:"period,omitempty"` + // PaymentTerm is the payment term for the detailed line, defaults to arrears + PaymentTerm plan.PaymentTermType `json:"paymentTerm,omitempty"` + + Discounts []billingentity.LineDiscount `json:"discounts,omitempty"` +} + +func (i newDetailedLineInput) Validate() error { + if i.Quantity.IsNegative() { + return fmt.Errorf("quantity must be zero or positive") + } + + if i.Amount.IsNegative() { + return fmt.Errorf("amount must be zero or positive") + } + + if i.ChildUniqueReferenceID == "" { + return fmt.Errorf("child unique ID is required") + } + + if i.Name == "" { + return fmt.Errorf("name is required") + } + + return nil +} + +type addDiscountInput struct { + BilledAmountBeforeLine alpacadecimal.Decimal + MaxSpend alpacadecimal.Decimal +} + +func (i newDetailedLineInput) AddDiscountForOverage(in addDiscountInput) newDetailedLineInput { + lineTotal := i.Amount.Mul(i.Quantity) + currentBillableAmount := in.BilledAmountBeforeLine.Add(lineTotal) + + if currentBillableAmount.LessThanOrEqual(in.MaxSpend) { + // Nothing to do here + return i + } + + if currentBillableAmount.GreaterThanOrEqual(in.MaxSpend) && in.BilledAmountBeforeLine.GreaterThanOrEqual(in.MaxSpend) { + // 100% discount + i.Discounts = append(i.Discounts, billingentity.LineDiscount{ + Amount: lineTotal, + Description: formatMaximumSpendDiscountDescription(in.MaxSpend), + Type: lo.ToPtr(billingentity.MaximumSpendLineDiscountType), + Source: billingentity.CalculatedLineDiscountSource, + }) + return i + } + + discountAmount := currentBillableAmount.Sub(in.MaxSpend) + i.Discounts = append(i.Discounts, billingentity.LineDiscount{ + Amount: discountAmount, + Description: formatMaximumSpendDiscountDescription(in.MaxSpend), + Type: lo.ToPtr(billingentity.MaximumSpendLineDiscountType), + Source: billingentity.CalculatedLineDiscountSource, + }) + + return i +} + +func (l usageBasedLine) newDetailedLines(inputs ...newDetailedLineInput) ([]Line, error) { + return slicesx.MapWithErr(inputs, func(in newDetailedLineInput) (Line, error) { + if err := in.Validate(); err != nil { + return nil, err + } + + return l.service.FromEntity(billingentity.Line{ + LineBase: billingentity.LineBase{ + Namespace: l.line.Namespace, + Type: billingentity.InvoiceLineTypeFee, + Status: billingentity.InvoiceLineStatusDetailed, + Period: lo.If(in.Period != nil, *in.Period).Else(l.line.Period), + Name: in.Name, + InvoiceAt: l.line.InvoiceAt, + InvoiceID: l.line.InvoiceID, + Currency: l.line.Currency, + ChildUniqueReferenceID: &in.ChildUniqueReferenceID, + ParentLineID: lo.ToPtr(l.line.ID), + // TODO: Parent line? + TaxConfig: l.line.TaxConfig, + }, + FlatFee: billingentity.FlatFeeLine{ + PaymentTerm: lo.CoalesceOrEmpty(in.PaymentTerm, plan.InArrearsPaymentTerm), + Amount: in.Amount, + Quantity: in.Quantity, + }, + }) + }) +} + +func formatMaximumSpendDiscountDescription(amount alpacadecimal.Decimal) *string { + // TODO[later]: currency formatting + return lo.ToPtr(fmt.Sprintf("Maximum spend discount for charges over %s", amount)) +} diff --git a/openmeter/billing/service/lineservice/usagebasedline_test.go b/openmeter/billing/service/lineservice/usagebasedline_test.go new file mode 100644 index 000000000..27102bf45 --- /dev/null +++ b/openmeter/billing/service/lineservice/usagebasedline_test.go @@ -0,0 +1,1366 @@ +package lineservice + +import ( + "encoding/json" + "testing" + "time" + + "github.com/alpacahq/alpacadecimal" + "github.com/samber/lo" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + billingentity "github.com/openmeterio/openmeter/openmeter/billing/entity" + "github.com/openmeterio/openmeter/openmeter/productcatalog/plan" +) + +type testLineMode string + +const ( + singlePerPeriodLineMode testLineMode = "single_per_period" + midPeriodSplitLineMode testLineMode = "mid_period_split" + lastInPeriodSplitLineMode testLineMode = "last_in_period_split" +) + +var ubpTestFullPeriod = billingentity.Period{ + Start: lo.Must(time.Parse(time.RFC3339, "2021-01-01T00:00:00Z")), + End: lo.Must(time.Parse(time.RFC3339, "2021-01-02T00:00:00Z")), +} + +type ubpCalculationTestCase struct { + price plan.Price + lineMode testLineMode + usage featureUsageResponse + expect newDetailedLinesInput +} + +func runUBPTest(t *testing.T, tc ubpCalculationTestCase) { + t.Helper() + l := usageBasedLine{ + lineBase: lineBase{ + line: billingentity.Line{ + LineBase: billingentity.LineBase{ + ID: "fake-line", + Type: billingentity.InvoiceLineTypeUsageBased, + Status: billingentity.InvoiceLineStatusValid, + Name: "feature", + }, + UsageBased: billingentity.UsageBasedLine{ + Price: tc.price, + }, + }, + }, + } + + fakeParentLine := billingentity.Line{ + LineBase: billingentity.LineBase{ + ID: "fake-parent-line", + Period: ubpTestFullPeriod, + Status: billingentity.InvoiceLineStatusSplit, + }, + } + + switch tc.lineMode { + case singlePerPeriodLineMode: + l.line.Period = ubpTestFullPeriod + case midPeriodSplitLineMode: + l.line.Period = billingentity.Period{ + Start: ubpTestFullPeriod.Start.Add(time.Hour * 12), + End: ubpTestFullPeriod.End.Add(-time.Hour), + } + l.line.ParentLine = &fakeParentLine + l.line.ParentLineID = &fakeParentLine.ID + + case lastInPeriodSplitLineMode: + l.line.Period = billingentity.Period{ + Start: ubpTestFullPeriod.Start.Add(time.Hour * 12), + End: ubpTestFullPeriod.End, + } + + l.line.ParentLine = &fakeParentLine + l.line.ParentLineID = &fakeParentLine.ID + } + + res, err := l.calculateDetailedLines(&tc.usage) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // let's get around nil slices + if len(tc.expect) == 0 && len(res) == 0 { + return + } + + expectJSON, err := json.Marshal(tc.expect) + require.NoError(t, err) + + resJSON, err := json.Marshal(res) + require.NoError(t, err) + + require.JSONEq(t, string(expectJSON), string(resJSON)) +} + +func TestFlatLineCalculation(t *testing.T) { + // Flat price tests + t.Run("flat price no usage", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.FlatPrice{ + Amount: alpacadecimal.NewFromFloat(100), + PaymentTerm: plan.InAdvancePaymentTerm, + }), + lineMode: singlePerPeriodLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(0), + }, + expect: newDetailedLinesInput{ + { + Name: "feature", + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: FlatPriceChildUniqueReferenceID, + PaymentTerm: plan.InAdvancePaymentTerm, + }, + }, + }) + }) + + t.Run("flat price, in advance, usage present", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.FlatPrice{ + Amount: alpacadecimal.NewFromFloat(100), + PaymentTerm: plan.InAdvancePaymentTerm, + }), + lineMode: singlePerPeriodLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(10), + }, + expect: newDetailedLinesInput{ + { + Name: "feature", + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: FlatPriceChildUniqueReferenceID, + PaymentTerm: plan.InAdvancePaymentTerm, + }, + }, + }) + }) + + t.Run("flat price, in advance, usage present, mid period", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.FlatPrice{ + Amount: alpacadecimal.NewFromFloat(100), + PaymentTerm: plan.InAdvancePaymentTerm, + }), + lineMode: midPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(10), + }, + expect: newDetailedLinesInput{}, + }) + }) + + t.Run("flat price, in arrears, usage present, single period line", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.FlatPrice{ + Amount: alpacadecimal.NewFromFloat(100), + PaymentTerm: plan.InArrearsPaymentTerm, + }), + lineMode: singlePerPeriodLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(10), + }, + expect: newDetailedLinesInput{ + { + Name: "feature", + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: FlatPriceChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + t.Run("flat price, in arrears, usage present, mid period line", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.FlatPrice{ + Amount: alpacadecimal.NewFromFloat(100), + PaymentTerm: plan.InArrearsPaymentTerm, + }), + lineMode: midPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(10), + }, + expect: newDetailedLinesInput{}, // It will be billed in the last period + }) + }) + + t.Run("flat price, in arrears, usage present, last period line", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.FlatPrice{ + Amount: alpacadecimal.NewFromFloat(100), + PaymentTerm: plan.InArrearsPaymentTerm, + }), + lineMode: lastInPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(10), + }, + expect: newDetailedLinesInput{ + { + Name: "feature", + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: FlatPriceChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) +} + +func TestUnitPriceCalculation(t *testing.T) { + t.Run("unit price, no usage", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.UnitPrice{ + Amount: alpacadecimal.NewFromFloat(10), + }), + lineMode: singlePerPeriodLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(0), + }, + expect: newDetailedLinesInput{}, + }) + }) + + // When there is no usage, we are still honoring the minimum spend + t.Run("unit price, no usage, min spend set", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.UnitPrice{ + Amount: alpacadecimal.NewFromFloat(10), + MinimumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(100)), + }), + lineMode: singlePerPeriodLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(0), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: minimum spend", + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: UnitPriceMinSpendChildUniqueReferenceID, + Period: &ubpTestFullPeriod, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + // Min spend is always billed in arrears => we are not billing it in advance + t.Run("no usage, not the last line in period, min spend set", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.UnitPrice{ + Amount: alpacadecimal.NewFromFloat(10), + MinimumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(100)), + }), + lineMode: midPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(0), + }, + expect: newDetailedLinesInput{}, + }) + }) + + // Min spend is always billed in arrears => we are billing it for the last line + t.Run("no usage, last line in period, min spend set", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.UnitPrice{ + Amount: alpacadecimal.NewFromFloat(10), + MinimumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(100)), + }), + lineMode: lastInPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(0), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: minimum spend", + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: UnitPriceMinSpendChildUniqueReferenceID, + Period: &ubpTestFullPeriod, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + // Usage is billed regardless of line position + t.Run("usage present", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.UnitPrice{ + Amount: alpacadecimal.NewFromFloat(100), + }), + lineMode: singlePerPeriodLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(10), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: usage in period", + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(10), + ChildUniqueReferenceID: UnitPriceUsageChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + t.Run("usage present, mid line", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.UnitPrice{ + Amount: alpacadecimal.NewFromFloat(100), + }), + lineMode: midPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(10), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: usage in period", + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(10), + ChildUniqueReferenceID: UnitPriceUsageChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + // Max spend is always honored + t.Run("usage present, max spend set, but not hit", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.UnitPrice{ + Amount: alpacadecimal.NewFromFloat(10), + MaximumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(100)), + }), + lineMode: midPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(10), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: usage in period", + Amount: alpacadecimal.NewFromFloat(10), + Quantity: alpacadecimal.NewFromFloat(10), + ChildUniqueReferenceID: UnitPriceUsageChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + t.Run("usage present, max spend set, but not hit", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.UnitPrice{ + Amount: alpacadecimal.NewFromFloat(10), + MaximumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(100)), + }), + lineMode: midPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(5), + PreLinePeriodQty: alpacadecimal.NewFromFloat(7), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: usage in period", + Amount: alpacadecimal.NewFromFloat(10), + Quantity: alpacadecimal.NewFromFloat(5), + ChildUniqueReferenceID: UnitPriceUsageChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + Discounts: []billingentity.LineDiscount{ + { + Description: lo.ToPtr("Maximum spend discount for charges over 100"), + Amount: alpacadecimal.NewFromFloat(20), + Type: lo.ToPtr(billingentity.MaximumSpendLineDiscountType), + Source: billingentity.CalculatedLineDiscountSource, + }, + }, + }, + }, + }) + }) +} + +func TestTieredGraduatedCalculation(t *testing.T) { + testTiers := []plan.PriceTier{ + { + UpToAmount: lo.ToPtr(alpacadecimal.NewFromFloat(5)), + FlatPrice: &plan.PriceTierFlatPrice{ + // 20/unit + Amount: alpacadecimal.NewFromFloat(100), + }, + }, + { + UpToAmount: lo.ToPtr(alpacadecimal.NewFromFloat(10)), + FlatPrice: &plan.PriceTierFlatPrice{ + // 10/unit + Amount: alpacadecimal.NewFromFloat(150), + }, + }, + { + UpToAmount: lo.ToPtr(alpacadecimal.NewFromFloat(15)), + UnitPrice: &plan.PriceTierUnitPrice{ + Amount: alpacadecimal.NewFromFloat(10), + }, + }, + { + UnitPrice: &plan.PriceTierUnitPrice{ + Amount: alpacadecimal.NewFromFloat(5), + }, + }, + } + + t.Run("tiered graduated, mid price", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.GraduatedTieredPrice, + Tiers: testTiers, + }), + lineMode: midPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(10), + }, + expect: newDetailedLinesInput{}, + }) + }) + + t.Run("tiered graduated, last price, no usage", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.GraduatedTieredPrice, + Tiers: testTiers, + }), + lineMode: lastInPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(0), + }, + expect: newDetailedLinesInput{}, + }) + }) + + t.Run("tiered graduated, last price, usage present, tier1 mid", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.GraduatedTieredPrice, + Tiers: testTiers, + }), + lineMode: lastInPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(3), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: flat price for tier 1", + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: GraduatedFlatPriceChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + t.Run("tiered graduated, last price, usage present, tier1 top", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.GraduatedTieredPrice, + Tiers: testTiers, + }), + lineMode: lastInPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(5), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: flat price for tier 1", + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: GraduatedFlatPriceChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + t.Run("tiered graduated, last price, usage present, tier4", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.GraduatedTieredPrice, + Tiers: testTiers, + }), + lineMode: lastInPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(100), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: unit price for tier 4", + Amount: alpacadecimal.NewFromFloat(5), + Quantity: alpacadecimal.NewFromFloat(100), + ChildUniqueReferenceID: GraduatedUnitPriceChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + // Minimum spend + + t.Run("tiered graduated, last price, no usage, min spend", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.GraduatedTieredPrice, + Tiers: testTiers, + MinimumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(100)), + }), + lineMode: lastInPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(0), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: minimum spend", + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: GraduatedMinSpendChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + t.Run("tiered graduated, last price, usage over, min spend", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.GraduatedTieredPrice, + Tiers: testTiers, + MinimumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(100)), + }), + lineMode: lastInPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(100), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: unit price for tier 4", + Amount: alpacadecimal.NewFromFloat(5), + Quantity: alpacadecimal.NewFromFloat(100), + ChildUniqueReferenceID: GraduatedUnitPriceChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + t.Run("tiered graduated, last price, usage less than min spend", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.GraduatedTieredPrice, + Tiers: testTiers, + MinimumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(150)), + }), + lineMode: lastInPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(5), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: flat price for tier 1", + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: GraduatedFlatPriceChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + { + Name: "feature: minimum spend", + Amount: alpacadecimal.NewFromFloat(50), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: GraduatedMinSpendChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + t.Run("tiered graduated, last price, usage less equals min spend", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.GraduatedTieredPrice, + Tiers: testTiers, + MinimumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(100)), + }), + lineMode: lastInPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(5), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: flat price for tier 1", + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: GraduatedFlatPriceChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + t.Run("tiered graduated, no usage, min spend should be returned", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.GraduatedTieredPrice, + Tiers: testTiers, + MinimumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(100)), + }), + lineMode: lastInPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(0), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: minimum spend", + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: GraduatedMinSpendChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + // Maximum spend + t.Run("tiered graduated, first price, usage eq max spend", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.GraduatedTieredPrice, + Tiers: testTiers, + MaximumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(100)), + }), + lineMode: lastInPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(5), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: flat price for tier 1", + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: GraduatedFlatPriceChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + t.Run("tiered graduated, first price, usage above max spend, max spend is not at tier boundary ", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.GraduatedTieredPrice, + Tiers: testTiers, + MaximumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(125)), + }), + lineMode: lastInPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(7), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: flat price for tier 2", + Amount: alpacadecimal.NewFromFloat(150), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: GraduatedFlatPriceChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + Discounts: []billingentity.LineDiscount{ + { + Description: lo.ToPtr("Maximum spend discount for charges over 125"), + Amount: alpacadecimal.NewFromFloat(25), + Type: lo.ToPtr(billingentity.MaximumSpendLineDiscountType), + Source: billingentity.CalculatedLineDiscountSource, + }, + }, + }, + }, + }) + }) +} + +func TestTieredVolumeCalculation(t *testing.T) { + testTiers := []plan.PriceTier{ + { + UpToAmount: lo.ToPtr(alpacadecimal.NewFromFloat(5)), + FlatPrice: &plan.PriceTierFlatPrice{ + // 20/unit + Amount: alpacadecimal.NewFromFloat(100), + }, + }, + { + UpToAmount: lo.ToPtr(alpacadecimal.NewFromFloat(10)), + FlatPrice: &plan.PriceTierFlatPrice{ + // 10/unit + Amount: alpacadecimal.NewFromFloat(50), + }, + }, + { + UpToAmount: lo.ToPtr(alpacadecimal.NewFromFloat(15)), + UnitPrice: &plan.PriceTierUnitPrice{ + Amount: alpacadecimal.NewFromFloat(5), + }, + }, + { + UnitPrice: &plan.PriceTierUnitPrice{ + Amount: alpacadecimal.NewFromFloat(1), + }, + }, + } + + t.Run("tiered volume, mid price, flat only => no lines are output", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.VolumeTieredPrice, + Tiers: testTiers, + }), + lineMode: midPeriodSplitLineMode, + usage: featureUsageResponse{ + PreLinePeriodQty: alpacadecimal.NewFromFloat(7), + LinePeriodQty: alpacadecimal.NewFromFloat(1), + }, + expect: newDetailedLinesInput{}, + }) + }) + + t.Run("tiered volume, last price, no usage", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.VolumeTieredPrice, + Tiers: testTiers, + }), + lineMode: lastInPeriodSplitLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(0), + }, + expect: newDetailedLinesInput{}, + }) + }) + + t.Run("tiered volume, single period multiple tier usage", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.VolumeTieredPrice, + Tiers: testTiers, + }), + lineMode: singlePerPeriodLineMode, + usage: featureUsageResponse{ + LinePeriodQty: alpacadecimal.NewFromFloat(22), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: flat price for tier 1", + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: "volume-tiered-1-flat-price", + PaymentTerm: plan.InArrearsPaymentTerm, + }, + { + Name: "feature: flat price for tier 2", + Amount: alpacadecimal.NewFromFloat(50), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: "volume-tiered-2-flat-price", + PaymentTerm: plan.InArrearsPaymentTerm, + }, + { + Name: "feature: usage price for tier 3", + Amount: alpacadecimal.NewFromFloat(5), + Quantity: alpacadecimal.NewFromFloat(5), + ChildUniqueReferenceID: "volume-tiered-3-price-usage", + PaymentTerm: plan.InArrearsPaymentTerm, + }, + { + Name: "feature: usage price for tier 4", + Amount: alpacadecimal.NewFromFloat(1), + Quantity: alpacadecimal.NewFromFloat(7), + ChildUniqueReferenceID: "volume-tiered-4-price-usage", + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + t.Run("tiered volume, mid period, multiple tier usage", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.VolumeTieredPrice, + Tiers: testTiers, + }), + lineMode: singlePerPeriodLineMode, + usage: featureUsageResponse{ + PreLinePeriodQty: alpacadecimal.NewFromFloat(12), + LinePeriodQty: alpacadecimal.NewFromFloat(10), // total usage is at 22 + }, + expect: newDetailedLinesInput{ + { + Name: "feature: usage price for tier 3", + Amount: alpacadecimal.NewFromFloat(5), + Quantity: alpacadecimal.NewFromFloat(3), + ChildUniqueReferenceID: "volume-tiered-3-price-usage", + PaymentTerm: plan.InArrearsPaymentTerm, + }, + { + Name: "feature: usage price for tier 4", + Amount: alpacadecimal.NewFromFloat(1), + Quantity: alpacadecimal.NewFromFloat(7), + ChildUniqueReferenceID: "volume-tiered-4-price-usage", + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + // Minimum spend + + t.Run("tiered volume, last line, no usage, minimum price set", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.VolumeTieredPrice, + Tiers: testTiers, + MinimumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(1000)), + }), + lineMode: lastInPeriodSplitLineMode, + usage: featureUsageResponse{ + PreLinePeriodQty: alpacadecimal.NewFromFloat(0), + LinePeriodQty: alpacadecimal.NewFromFloat(0), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: minimum spend", + Amount: alpacadecimal.NewFromFloat(1000), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: VolumeMinSpendChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + t.Run("tiered volume, last line, no usage, minimum price set", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.VolumeTieredPrice, + Tiers: testTiers, + MinimumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(1000)), + }), + lineMode: lastInPeriodSplitLineMode, + usage: featureUsageResponse{ + PreLinePeriodQty: alpacadecimal.NewFromFloat(2), + LinePeriodQty: alpacadecimal.NewFromFloat(0), + }, + expect: newDetailedLinesInput{ + { + Name: "feature: minimum spend", + Amount: alpacadecimal.NewFromFloat(900), + Quantity: alpacadecimal.NewFromFloat(1), + ChildUniqueReferenceID: VolumeMinSpendChildUniqueReferenceID, + PaymentTerm: plan.InArrearsPaymentTerm, + }, + }, + }) + }) + + t.Run("tiered volume, mid line, no usage, minimum price set", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.VolumeTieredPrice, + Tiers: testTiers, + MinimumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(1000)), + }), + lineMode: midPeriodSplitLineMode, + usage: featureUsageResponse{ + PreLinePeriodQty: alpacadecimal.NewFromFloat(2), + LinePeriodQty: alpacadecimal.NewFromFloat(0), + }, + expect: newDetailedLinesInput{}, + }) + }) + + // Maximum spend + t.Run("tiered volume, mid period, multiple tier usage, maximum spend set mid tier 2/3", func(t *testing.T) { + runUBPTest(t, ubpCalculationTestCase{ + price: plan.NewPriceFrom(plan.TieredPrice{ + Mode: plan.VolumeTieredPrice, + Tiers: testTiers, + MaximumAmount: lo.ToPtr(alpacadecimal.NewFromFloat(170)), + }), + lineMode: singlePerPeriodLineMode, + usage: featureUsageResponse{ + PreLinePeriodQty: alpacadecimal.NewFromFloat(12), + LinePeriodQty: alpacadecimal.NewFromFloat(10), // total usage is at 22 + }, + + // Total previous usage due to the PreLinePeriodQty: + // tier 1: $100 flat + // tier 2: $50 flat + // tier 3: 2*$5 = $10 usage + // total: $160 + + expect: newDetailedLinesInput{ + { + Name: "feature: usage price for tier 3", + Amount: alpacadecimal.NewFromFloat(5), + Quantity: alpacadecimal.NewFromFloat(3), + ChildUniqueReferenceID: "volume-tiered-3-price-usage", + PaymentTerm: plan.InArrearsPaymentTerm, + Discounts: []billingentity.LineDiscount{ + { + Description: lo.ToPtr("Maximum spend discount for charges over 170"), + Amount: alpacadecimal.NewFromFloat(5), + Type: lo.ToPtr(billingentity.MaximumSpendLineDiscountType), + Source: billingentity.CalculatedLineDiscountSource, + }, + }, + }, + { + Name: "feature: usage price for tier 4", + Amount: alpacadecimal.NewFromFloat(1), + Quantity: alpacadecimal.NewFromFloat(7), + ChildUniqueReferenceID: "volume-tiered-4-price-usage", + PaymentTerm: plan.InArrearsPaymentTerm, + Discounts: []billingentity.LineDiscount{ + { + Description: lo.ToPtr("Maximum spend discount for charges over 170"), + Amount: alpacadecimal.NewFromFloat(7), + Type: lo.ToPtr(billingentity.MaximumSpendLineDiscountType), + Source: billingentity.CalculatedLineDiscountSource, + }, + }, + }, + }, + }) + }) +} + +func TestAddDiscountForOverage(t *testing.T) { + l := newDetailedLineInput{ + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(10), + } + + t.Run("no overage", func(t *testing.T) { + lineWithDiscount := l.AddDiscountForOverage(addDiscountInput{ + MaxSpend: alpacadecimal.NewFromFloat(10000), + BilledAmountBeforeLine: alpacadecimal.NewFromFloat(9000), + // Total $10000 => No max spend is reached + }) + + require.Equal(t, l, lineWithDiscount) + }) + + t.Run("overage and some valid charges", func(t *testing.T) { + lineWithDiscount := l.AddDiscountForOverage(addDiscountInput{ + MaxSpend: alpacadecimal.NewFromFloat(10000), + BilledAmountBeforeLine: alpacadecimal.NewFromFloat(9600), + // Total $10000 => $500 discount + }) + + require.Equal(t, newDetailedLineInput{ + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(10), + Discounts: []billingentity.LineDiscount{ + { + Description: lo.ToPtr("Maximum spend discount for charges over 10000"), + Amount: alpacadecimal.NewFromFloat(600), + Type: lo.ToPtr(billingentity.MaximumSpendLineDiscountType), + Source: billingentity.CalculatedLineDiscountSource, + }, + }, + }, lineWithDiscount) + }) + + t.Run("overage 100% discount", func(t *testing.T) { + lineWithDiscount := l.AddDiscountForOverage(addDiscountInput{ + MaxSpend: alpacadecimal.NewFromFloat(10000), + BilledAmountBeforeLine: alpacadecimal.NewFromFloat(10000), + // Total $10000 => $1000 discount + }) + + require.Equal(t, newDetailedLineInput{ + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(10), + Discounts: []billingentity.LineDiscount{ + { + Description: lo.ToPtr("Maximum spend discount for charges over 10000"), + Amount: alpacadecimal.NewFromFloat(1000), + Type: lo.ToPtr(billingentity.MaximumSpendLineDiscountType), + Source: billingentity.CalculatedLineDiscountSource, + }, + }, + }, lineWithDiscount) + }) + + t.Run("overage and 100% discount when hugely over the max spend", func(t *testing.T) { + lineWithDiscount := l.AddDiscountForOverage(addDiscountInput{ + MaxSpend: alpacadecimal.NewFromFloat(10000), + BilledAmountBeforeLine: alpacadecimal.NewFromFloat(20000), + // Total $10000 => $1000 discount + }) + + require.Equal(t, newDetailedLineInput{ + Amount: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(10), + Discounts: []billingentity.LineDiscount{ + { + Description: lo.ToPtr("Maximum spend discount for charges over 10000"), + Amount: alpacadecimal.NewFromFloat(1000), + Type: lo.ToPtr(billingentity.MaximumSpendLineDiscountType), + Source: billingentity.CalculatedLineDiscountSource, + }, + }, + }, lineWithDiscount) + }) +} + +func TestFindTierForQuantity(t *testing.T) { + testIn := plan.TieredPrice{ + Tiers: []plan.PriceTier{ + { + UpToAmount: lo.ToPtr(alpacadecimal.NewFromFloat(5)), + FlatPrice: &plan.PriceTierFlatPrice{ + // 20/unit + Amount: alpacadecimal.NewFromFloat(100), + }, + }, + { + UpToAmount: lo.ToPtr(alpacadecimal.NewFromFloat(10)), + FlatPrice: &plan.PriceTierFlatPrice{ + // 10/unit + Amount: alpacadecimal.NewFromFloat(150), + }, + }, + { + UpToAmount: lo.ToPtr(alpacadecimal.NewFromFloat(15)), + UnitPrice: &plan.PriceTierUnitPrice{ + Amount: alpacadecimal.NewFromFloat(10), + }, + }, + { + UnitPrice: &plan.PriceTierUnitPrice{ + Amount: alpacadecimal.NewFromFloat(5), + }, + }, + }, + } + + tier, index := findTierForQuantity(testIn, alpacadecimal.NewFromFloat(3)) + require.Equal(t, 0, index) + require.Equal(t, testIn.Tiers[0], *tier) + + tier, index = findTierForQuantity(testIn, alpacadecimal.NewFromFloat(5)) + require.Equal(t, 0, index) + require.Equal(t, testIn.Tiers[0], *tier) + + tier, index = findTierForQuantity(testIn, alpacadecimal.NewFromFloat(6)) + require.Equal(t, 1, index) + require.Equal(t, testIn.Tiers[1], *tier) + + tier, index = findTierForQuantity(testIn, alpacadecimal.NewFromFloat(100)) + require.Equal(t, 3, index) + require.Equal(t, testIn.Tiers[3], *tier) +} + +func getTotalAmountForVolumeTieredPrice(t *testing.T, qty alpacadecimal.Decimal, price plan.TieredPrice) alpacadecimal.Decimal { + t.Helper() + + total := alpacadecimal.Zero + err := tieredPriceCalculator(tieredPriceCalculatorInput{ + TieredPrice: price, + ToQty: qty, + + FinalizerFn: func(t alpacadecimal.Decimal) error { + total = t + return nil + }, + IntrospectRangesFn: introspectTieredPriceRangesFn(t), + }) + + require.NoError(t, err) + + return total +} + +func introspectTieredPriceRangesFn(t *testing.T) func([]tierRange) { + return func(qtyRanges []tierRange) { + for _, qtyRange := range qtyRanges { + t.Logf("From: %s, To: %s, AtBoundary: %t, Tier[idx=%d]: %+v", qtyRange.FromQty.String(), qtyRange.ToQty.String(), qtyRange.AtTierBoundary, qtyRange.TierIndex, qtyRange.Tier) + } + } +} + +type mockableTieredPriceCalculator struct { + mock.Mock +} + +func (m *mockableTieredPriceCalculator) TierCallbackFn(i tierCallbackInput) error { + args := m.Called(i) + return args.Error(0) +} + +func (m *mockableTieredPriceCalculator) FinalizerFn(t alpacadecimal.Decimal) error { + args := m.Called(t) + return args.Error(0) +} + +func TestTieredPriceCalculator(t *testing.T) { + testIn := plan.TieredPrice{ + Mode: plan.VolumeTieredPrice, + Tiers: []plan.PriceTier{ + { + UpToAmount: lo.ToPtr(alpacadecimal.NewFromFloat(5)), + FlatPrice: &plan.PriceTierFlatPrice{ + // 20/unit + Amount: alpacadecimal.NewFromFloat(100), + }, + }, + { + UpToAmount: lo.ToPtr(alpacadecimal.NewFromFloat(10)), + FlatPrice: &plan.PriceTierFlatPrice{ + // 10/unit + Amount: alpacadecimal.NewFromFloat(50), + }, + }, + { + UpToAmount: lo.ToPtr(alpacadecimal.NewFromFloat(15)), + UnitPrice: &plan.PriceTierUnitPrice{ + Amount: alpacadecimal.NewFromFloat(10), + }, + }, + { + UnitPrice: &plan.PriceTierUnitPrice{ + Amount: alpacadecimal.NewFromFloat(5), + }, + }, + }, + } + + t.Run("totals, no usage", func(t *testing.T) { + totalAmount := getTotalAmountForVolumeTieredPrice(t, alpacadecimal.NewFromFloat(0), testIn) + require.Equal(t, alpacadecimal.NewFromFloat(0), totalAmount) + }) + + t.Run("totals, usage in tier 1", func(t *testing.T) { + totalAmount := getTotalAmountForVolumeTieredPrice(t, alpacadecimal.NewFromFloat(3), testIn) + require.Equal(t, alpacadecimal.NewFromFloat(100), totalAmount) + + totalAmount = getTotalAmountForVolumeTieredPrice(t, alpacadecimal.NewFromFloat(5), testIn) + require.Equal(t, alpacadecimal.NewFromFloat(100), totalAmount) + }) + + t.Run("totals, usage in tier 2", func(t *testing.T) { + totalAmount := getTotalAmountForVolumeTieredPrice(t, alpacadecimal.NewFromFloat(7), testIn) + require.Equal(t, alpacadecimal.NewFromFloat(100+50), totalAmount) + }) + + t.Run("totals, usage in tier 3", func(t *testing.T) { + totalAmount := getTotalAmountForVolumeTieredPrice(t, alpacadecimal.NewFromFloat(12), testIn) + require.Equal(t, alpacadecimal.NewFromFloat(170 /* = 100+50+2*10 */), totalAmount) + }) + + t.Run("totals, usage in tier 4", func(t *testing.T) { + totalAmount := getTotalAmountForVolumeTieredPrice(t, alpacadecimal.NewFromFloat(22), testIn) + require.Equal(t, alpacadecimal.NewFromFloat(235 /* = 100+50+10*5+5*7 */), totalAmount) + }) + + t.Run("tier callback, mid tier invocation", func(t *testing.T) { + callback := mockableTieredPriceCalculator{} + + callback.On("TierCallbackFn", tierCallbackInput{ + Tier: testIn.Tiers[0], + TierIndex: 0, + + AtTierBoundary: false, + Quantity: alpacadecimal.NewFromFloat(2), + // The flat price has been already billed for + PreviousTotalAmount: alpacadecimal.NewFromFloat(100), + }).Return(nil).Once() + + callback.On("TierCallbackFn", tierCallbackInput{ + Tier: testIn.Tiers[1], + TierIndex: 1, + + AtTierBoundary: true, + Quantity: alpacadecimal.NewFromFloat(2), + PreviousTotalAmount: alpacadecimal.NewFromFloat(100), + }).Return(nil).Once() + + callback.On("FinalizerFn", alpacadecimal.NewFromFloat(150)).Return(nil).Once() + + require.NoError(t, tieredPriceCalculator( + tieredPriceCalculatorInput{ + TieredPrice: testIn, + FromQty: alpacadecimal.NewFromFloat(3), // exclusive + ToQty: alpacadecimal.NewFromFloat(7), // inclusive + TierCallbackFn: callback.TierCallbackFn, + FinalizerFn: callback.FinalizerFn, + IntrospectRangesFn: introspectTieredPriceRangesFn(t), + }, + ), + ) + + callback.AssertExpectations(t) + }) + + t.Run("tier callback, open ended invocation", func(t *testing.T) { + callback := mockableTieredPriceCalculator{} + + callback.On("TierCallbackFn", tierCallbackInput{ + Tier: testIn.Tiers[2], + TierIndex: 2, + + AtTierBoundary: false, + Quantity: alpacadecimal.NewFromFloat(3), + PreviousTotalAmount: alpacadecimal.Sum( + testIn.Tiers[0].FlatPrice.Amount, + testIn.Tiers[1].FlatPrice.Amount, + testIn.Tiers[2].UnitPrice.Amount.Mul(alpacadecimal.NewFromFloat(2)), + ), + }).Return(nil).Once() + + callback.On("TierCallbackFn", tierCallbackInput{ + Tier: testIn.Tiers[3], + TierIndex: 3, + + AtTierBoundary: true, + Quantity: alpacadecimal.NewFromFloat(5), + PreviousTotalAmount: alpacadecimal.Sum( + testIn.Tiers[0].FlatPrice.Amount, + testIn.Tiers[1].FlatPrice.Amount, + testIn.Tiers[2].UnitPrice.Amount.Mul(alpacadecimal.NewFromFloat(5)), + ), + }).Return(nil).Once() + + callback.On("FinalizerFn", + alpacadecimal.Sum( + testIn.Tiers[0].FlatPrice.Amount, + testIn.Tiers[1].FlatPrice.Amount, + testIn.Tiers[2].UnitPrice.Amount.Mul(alpacadecimal.NewFromFloat(5)), + testIn.Tiers[3].UnitPrice.Amount.Mul(alpacadecimal.NewFromFloat(5)), + )).Return(nil).Once() + + require.NoError(t, tieredPriceCalculator( + tieredPriceCalculatorInput{ + TieredPrice: testIn, + FromQty: alpacadecimal.NewFromFloat(12), // exclusive + ToQty: alpacadecimal.NewFromFloat(20), // inclusive + TierCallbackFn: callback.TierCallbackFn, + FinalizerFn: callback.FinalizerFn, + IntrospectRangesFn: introspectTieredPriceRangesFn(t), + }, + ), + ) + + callback.AssertExpectations(t) + }) + + t.Run("tier callback, callback on boundary", func(t *testing.T) { + callback := mockableTieredPriceCalculator{} + + callback.On("TierCallbackFn", tierCallbackInput{ + Tier: testIn.Tiers[1], + TierIndex: 1, + + AtTierBoundary: true, + Quantity: alpacadecimal.NewFromFloat(5), + PreviousTotalAmount: testIn.Tiers[0].FlatPrice.Amount, + }).Return(nil).Once() + + callback.On("FinalizerFn", + alpacadecimal.Sum( + testIn.Tiers[0].FlatPrice.Amount, + testIn.Tiers[1].FlatPrice.Amount, + )).Return(nil).Once() + + require.NoError(t, tieredPriceCalculator( + tieredPriceCalculatorInput{ + TieredPrice: testIn, + FromQty: alpacadecimal.NewFromFloat(5), // exclusive + ToQty: alpacadecimal.NewFromFloat(10), // inclusive + TierCallbackFn: callback.TierCallbackFn, + FinalizerFn: callback.FinalizerFn, + IntrospectRangesFn: introspectTieredPriceRangesFn(t), + }, + ), + ) + + callback.AssertExpectations(t) + }) + + t.Run("tier callback, from/to in same tier", func(t *testing.T) { + callback := mockableTieredPriceCalculator{} + + callback.On("TierCallbackFn", tierCallbackInput{ + Tier: testIn.Tiers[1], + TierIndex: 1, + + AtTierBoundary: false, + Quantity: alpacadecimal.NewFromFloat(1), + PreviousTotalAmount: alpacadecimal.Sum( + testIn.Tiers[0].FlatPrice.Amount, + testIn.Tiers[1].FlatPrice.Amount, + ), + }).Return(nil).Once() + + callback.On("FinalizerFn", + alpacadecimal.Sum( + testIn.Tiers[0].FlatPrice.Amount, + testIn.Tiers[1].FlatPrice.Amount, + )).Return(nil).Once() + + require.NoError(t, tieredPriceCalculator( + tieredPriceCalculatorInput{ + TieredPrice: testIn, + FromQty: alpacadecimal.NewFromFloat(6), // exclusive + ToQty: alpacadecimal.NewFromFloat(7), // inclusive + TierCallbackFn: callback.TierCallbackFn, + FinalizerFn: callback.FinalizerFn, + IntrospectRangesFn: introspectTieredPriceRangesFn(t), + }, + ), + ) + + callback.AssertExpectations(t) + }) + + t.Run("tier callback, from == to, only finalizer is called ", func(t *testing.T) { + callback := mockableTieredPriceCalculator{} + + callback.On("FinalizerFn", alpacadecimal.Sum( + testIn.Tiers[0].FlatPrice.Amount, + testIn.Tiers[1].FlatPrice.Amount, + )).Return(nil).Once() + + require.NoError(t, tieredPriceCalculator( + tieredPriceCalculatorInput{ + TieredPrice: testIn, + FromQty: alpacadecimal.NewFromFloat(6), // exclusive + ToQty: alpacadecimal.NewFromFloat(6), // inclusive + TierCallbackFn: callback.TierCallbackFn, + FinalizerFn: callback.FinalizerFn, + IntrospectRangesFn: introspectTieredPriceRangesFn(t), + }, + ), + ) + + // Nothing should be called + callback.AssertExpectations(t) + }) +} diff --git a/openmeter/ent/db/billinginvoiceline.go b/openmeter/ent/db/billinginvoiceline.go index 2ae6a686e..0b82e8224 100644 --- a/openmeter/ent/db/billinginvoiceline.go +++ b/openmeter/ent/db/billinginvoiceline.go @@ -59,6 +59,8 @@ type BillingInvoiceLine struct { Quantity *alpacadecimal.Decimal `json:"quantity,omitempty"` // TaxConfig holds the value of the "tax_config" field. TaxConfig plan.TaxConfig `json:"tax_config,omitempty"` + // ChildUniqueReferenceID holds the value of the "child_unique_reference_id" field. + ChildUniqueReferenceID string `json:"child_unique_reference_id,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the BillingInvoiceLineQuery when eager-loading is set. Edges BillingInvoiceLineEdges `json:"edges"` @@ -146,7 +148,7 @@ func (*BillingInvoiceLine) scanValues(columns []string) ([]any, error) { values[i] = &sql.NullScanner{S: new(alpacadecimal.Decimal)} case billinginvoiceline.FieldMetadata, billinginvoiceline.FieldTaxConfig: values[i] = new([]byte) - case billinginvoiceline.FieldID, billinginvoiceline.FieldNamespace, billinginvoiceline.FieldName, billinginvoiceline.FieldDescription, billinginvoiceline.FieldInvoiceID, billinginvoiceline.FieldParentLineID, billinginvoiceline.FieldType, billinginvoiceline.FieldStatus, billinginvoiceline.FieldCurrency: + case billinginvoiceline.FieldID, billinginvoiceline.FieldNamespace, billinginvoiceline.FieldName, billinginvoiceline.FieldDescription, billinginvoiceline.FieldInvoiceID, billinginvoiceline.FieldParentLineID, billinginvoiceline.FieldType, billinginvoiceline.FieldStatus, billinginvoiceline.FieldCurrency, billinginvoiceline.FieldChildUniqueReferenceID: values[i] = new(sql.NullString) case billinginvoiceline.FieldCreatedAt, billinginvoiceline.FieldUpdatedAt, billinginvoiceline.FieldDeletedAt, billinginvoiceline.FieldPeriodStart, billinginvoiceline.FieldPeriodEnd, billinginvoiceline.FieldInvoiceAt: values[i] = new(sql.NullTime) @@ -285,6 +287,12 @@ func (bil *BillingInvoiceLine) assignValues(columns []string, values []any) erro return fmt.Errorf("unmarshal field tax_config: %w", err) } } + case billinginvoiceline.FieldChildUniqueReferenceID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field child_unique_reference_id", values[i]) + } else if value.Valid { + bil.ChildUniqueReferenceID = value.String + } case billinginvoiceline.ForeignKeys[0]: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field fee_line_config_id", values[i]) @@ -418,6 +426,9 @@ func (bil *BillingInvoiceLine) String() string { builder.WriteString(", ") builder.WriteString("tax_config=") builder.WriteString(fmt.Sprintf("%v", bil.TaxConfig)) + builder.WriteString(", ") + builder.WriteString("child_unique_reference_id=") + builder.WriteString(bil.ChildUniqueReferenceID) builder.WriteByte(')') return builder.String() } diff --git a/openmeter/ent/db/billinginvoiceline/billinginvoiceline.go b/openmeter/ent/db/billinginvoiceline/billinginvoiceline.go index c45481819..22cd2f4b4 100644 --- a/openmeter/ent/db/billinginvoiceline/billinginvoiceline.go +++ b/openmeter/ent/db/billinginvoiceline/billinginvoiceline.go @@ -50,6 +50,8 @@ const ( FieldQuantity = "quantity" // FieldTaxConfig holds the string denoting the tax_config field in the database. FieldTaxConfig = "tax_config" + // FieldChildUniqueReferenceID holds the string denoting the child_unique_reference_id field in the database. + FieldChildUniqueReferenceID = "child_unique_reference_id" // EdgeBillingInvoice holds the string denoting the billing_invoice edge name in mutations. EdgeBillingInvoice = "billing_invoice" // EdgeFlatFeeLine holds the string denoting the flat_fee_line edge name in mutations. @@ -113,6 +115,7 @@ var Columns = []string{ FieldCurrency, FieldQuantity, FieldTaxConfig, + FieldChildUniqueReferenceID, } // ForeignKeys holds the SQL foreign-keys that are owned by the "billing_invoice_lines" @@ -148,6 +151,8 @@ var ( UpdateDefaultUpdatedAt func() time.Time // CurrencyValidator is a validator for the "currency" field. It is called by the builders before save. CurrencyValidator func(string) error + // ChildUniqueReferenceIDValidator is a validator for the "child_unique_reference_id" field. It is called by the builders before save. + ChildUniqueReferenceIDValidator func(string) error // DefaultID holds the default value on creation for the "id" field. DefaultID func() string ) @@ -255,6 +260,11 @@ func ByQuantity(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldQuantity, opts...).ToFunc() } +// ByChildUniqueReferenceID orders the results by the child_unique_reference_id field. +func ByChildUniqueReferenceID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChildUniqueReferenceID, opts...).ToFunc() +} + // ByBillingInvoiceField orders the results by billing_invoice field. func ByBillingInvoiceField(field string, opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/openmeter/ent/db/billinginvoiceline/where.go b/openmeter/ent/db/billinginvoiceline/where.go index 9154b5953..3ea79e939 100644 --- a/openmeter/ent/db/billinginvoiceline/where.go +++ b/openmeter/ent/db/billinginvoiceline/where.go @@ -134,6 +134,11 @@ func Quantity(v alpacadecimal.Decimal) predicate.BillingInvoiceLine { return predicate.BillingInvoiceLine(sql.FieldEQ(FieldQuantity, v)) } +// ChildUniqueReferenceID applies equality check predicate on the "child_unique_reference_id" field. It's identical to ChildUniqueReferenceIDEQ. +func ChildUniqueReferenceID(v string) predicate.BillingInvoiceLine { + return predicate.BillingInvoiceLine(sql.FieldEQ(FieldChildUniqueReferenceID, v)) +} + // NamespaceEQ applies the EQ predicate on the "namespace" field. func NamespaceEQ(v string) predicate.BillingInvoiceLine { return predicate.BillingInvoiceLine(sql.FieldEQ(FieldNamespace, v)) @@ -943,6 +948,71 @@ func TaxConfigNotNil() predicate.BillingInvoiceLine { return predicate.BillingInvoiceLine(sql.FieldNotNull(FieldTaxConfig)) } +// ChildUniqueReferenceIDEQ applies the EQ predicate on the "child_unique_reference_id" field. +func ChildUniqueReferenceIDEQ(v string) predicate.BillingInvoiceLine { + return predicate.BillingInvoiceLine(sql.FieldEQ(FieldChildUniqueReferenceID, v)) +} + +// ChildUniqueReferenceIDNEQ applies the NEQ predicate on the "child_unique_reference_id" field. +func ChildUniqueReferenceIDNEQ(v string) predicate.BillingInvoiceLine { + return predicate.BillingInvoiceLine(sql.FieldNEQ(FieldChildUniqueReferenceID, v)) +} + +// ChildUniqueReferenceIDIn applies the In predicate on the "child_unique_reference_id" field. +func ChildUniqueReferenceIDIn(vs ...string) predicate.BillingInvoiceLine { + return predicate.BillingInvoiceLine(sql.FieldIn(FieldChildUniqueReferenceID, vs...)) +} + +// ChildUniqueReferenceIDNotIn applies the NotIn predicate on the "child_unique_reference_id" field. +func ChildUniqueReferenceIDNotIn(vs ...string) predicate.BillingInvoiceLine { + return predicate.BillingInvoiceLine(sql.FieldNotIn(FieldChildUniqueReferenceID, vs...)) +} + +// ChildUniqueReferenceIDGT applies the GT predicate on the "child_unique_reference_id" field. +func ChildUniqueReferenceIDGT(v string) predicate.BillingInvoiceLine { + return predicate.BillingInvoiceLine(sql.FieldGT(FieldChildUniqueReferenceID, v)) +} + +// ChildUniqueReferenceIDGTE applies the GTE predicate on the "child_unique_reference_id" field. +func ChildUniqueReferenceIDGTE(v string) predicate.BillingInvoiceLine { + return predicate.BillingInvoiceLine(sql.FieldGTE(FieldChildUniqueReferenceID, v)) +} + +// ChildUniqueReferenceIDLT applies the LT predicate on the "child_unique_reference_id" field. +func ChildUniqueReferenceIDLT(v string) predicate.BillingInvoiceLine { + return predicate.BillingInvoiceLine(sql.FieldLT(FieldChildUniqueReferenceID, v)) +} + +// ChildUniqueReferenceIDLTE applies the LTE predicate on the "child_unique_reference_id" field. +func ChildUniqueReferenceIDLTE(v string) predicate.BillingInvoiceLine { + return predicate.BillingInvoiceLine(sql.FieldLTE(FieldChildUniqueReferenceID, v)) +} + +// ChildUniqueReferenceIDContains applies the Contains predicate on the "child_unique_reference_id" field. +func ChildUniqueReferenceIDContains(v string) predicate.BillingInvoiceLine { + return predicate.BillingInvoiceLine(sql.FieldContains(FieldChildUniqueReferenceID, v)) +} + +// ChildUniqueReferenceIDHasPrefix applies the HasPrefix predicate on the "child_unique_reference_id" field. +func ChildUniqueReferenceIDHasPrefix(v string) predicate.BillingInvoiceLine { + return predicate.BillingInvoiceLine(sql.FieldHasPrefix(FieldChildUniqueReferenceID, v)) +} + +// ChildUniqueReferenceIDHasSuffix applies the HasSuffix predicate on the "child_unique_reference_id" field. +func ChildUniqueReferenceIDHasSuffix(v string) predicate.BillingInvoiceLine { + return predicate.BillingInvoiceLine(sql.FieldHasSuffix(FieldChildUniqueReferenceID, v)) +} + +// ChildUniqueReferenceIDEqualFold applies the EqualFold predicate on the "child_unique_reference_id" field. +func ChildUniqueReferenceIDEqualFold(v string) predicate.BillingInvoiceLine { + return predicate.BillingInvoiceLine(sql.FieldEqualFold(FieldChildUniqueReferenceID, v)) +} + +// ChildUniqueReferenceIDContainsFold applies the ContainsFold predicate on the "child_unique_reference_id" field. +func ChildUniqueReferenceIDContainsFold(v string) predicate.BillingInvoiceLine { + return predicate.BillingInvoiceLine(sql.FieldContainsFold(FieldChildUniqueReferenceID, v)) +} + // HasBillingInvoice applies the HasEdge predicate on the "billing_invoice" edge. func HasBillingInvoice() predicate.BillingInvoiceLine { return predicate.BillingInvoiceLine(func(s *sql.Selector) { diff --git a/openmeter/ent/db/billinginvoiceline_create.go b/openmeter/ent/db/billinginvoiceline_create.go index 0cd17cb22..e780246b7 100644 --- a/openmeter/ent/db/billinginvoiceline_create.go +++ b/openmeter/ent/db/billinginvoiceline_create.go @@ -188,6 +188,12 @@ func (bilc *BillingInvoiceLineCreate) SetNillableTaxConfig(pc *plan.TaxConfig) * return bilc } +// SetChildUniqueReferenceID sets the "child_unique_reference_id" field. +func (bilc *BillingInvoiceLineCreate) SetChildUniqueReferenceID(s string) *BillingInvoiceLineCreate { + bilc.mutation.SetChildUniqueReferenceID(s) + return bilc +} + // SetID sets the "id" field. func (bilc *BillingInvoiceLineCreate) SetID(s string) *BillingInvoiceLineCreate { bilc.mutation.SetID(s) @@ -380,6 +386,14 @@ func (bilc *BillingInvoiceLineCreate) check() error { return &ValidationError{Name: "tax_config", err: fmt.Errorf(`db: validator failed for field "BillingInvoiceLine.tax_config": %w`, err)} } } + if _, ok := bilc.mutation.ChildUniqueReferenceID(); !ok { + return &ValidationError{Name: "child_unique_reference_id", err: errors.New(`db: missing required field "BillingInvoiceLine.child_unique_reference_id"`)} + } + if v, ok := bilc.mutation.ChildUniqueReferenceID(); ok { + if err := billinginvoiceline.ChildUniqueReferenceIDValidator(v); err != nil { + return &ValidationError{Name: "child_unique_reference_id", err: fmt.Errorf(`db: validator failed for field "BillingInvoiceLine.child_unique_reference_id": %w`, err)} + } + } if len(bilc.mutation.BillingInvoiceIDs()) == 0 { return &ValidationError{Name: "billing_invoice", err: errors.New(`db: missing required edge "BillingInvoiceLine.billing_invoice"`)} } @@ -479,6 +493,10 @@ func (bilc *BillingInvoiceLineCreate) createSpec() (*BillingInvoiceLine, *sqlgra _spec.SetField(billinginvoiceline.FieldTaxConfig, field.TypeJSON, value) _node.TaxConfig = value } + if value, ok := bilc.mutation.ChildUniqueReferenceID(); ok { + _spec.SetField(billinginvoiceline.FieldChildUniqueReferenceID, field.TypeString, value) + _node.ChildUniqueReferenceID = value + } if nodes := bilc.mutation.BillingInvoiceIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -807,6 +825,18 @@ func (u *BillingInvoiceLineUpsert) ClearTaxConfig() *BillingInvoiceLineUpsert { return u } +// SetChildUniqueReferenceID sets the "child_unique_reference_id" field. +func (u *BillingInvoiceLineUpsert) SetChildUniqueReferenceID(v string) *BillingInvoiceLineUpsert { + u.Set(billinginvoiceline.FieldChildUniqueReferenceID, v) + return u +} + +// UpdateChildUniqueReferenceID sets the "child_unique_reference_id" field to the value that was provided on create. +func (u *BillingInvoiceLineUpsert) UpdateChildUniqueReferenceID() *BillingInvoiceLineUpsert { + u.SetExcluded(billinginvoiceline.FieldChildUniqueReferenceID) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create except the ID field. // Using this option is equivalent to using: // @@ -1091,6 +1121,20 @@ func (u *BillingInvoiceLineUpsertOne) ClearTaxConfig() *BillingInvoiceLineUpsert }) } +// SetChildUniqueReferenceID sets the "child_unique_reference_id" field. +func (u *BillingInvoiceLineUpsertOne) SetChildUniqueReferenceID(v string) *BillingInvoiceLineUpsertOne { + return u.Update(func(s *BillingInvoiceLineUpsert) { + s.SetChildUniqueReferenceID(v) + }) +} + +// UpdateChildUniqueReferenceID sets the "child_unique_reference_id" field to the value that was provided on create. +func (u *BillingInvoiceLineUpsertOne) UpdateChildUniqueReferenceID() *BillingInvoiceLineUpsertOne { + return u.Update(func(s *BillingInvoiceLineUpsert) { + s.UpdateChildUniqueReferenceID() + }) +} + // Exec executes the query. func (u *BillingInvoiceLineUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1542,6 +1586,20 @@ func (u *BillingInvoiceLineUpsertBulk) ClearTaxConfig() *BillingInvoiceLineUpser }) } +// SetChildUniqueReferenceID sets the "child_unique_reference_id" field. +func (u *BillingInvoiceLineUpsertBulk) SetChildUniqueReferenceID(v string) *BillingInvoiceLineUpsertBulk { + return u.Update(func(s *BillingInvoiceLineUpsert) { + s.SetChildUniqueReferenceID(v) + }) +} + +// UpdateChildUniqueReferenceID sets the "child_unique_reference_id" field to the value that was provided on create. +func (u *BillingInvoiceLineUpsertBulk) UpdateChildUniqueReferenceID() *BillingInvoiceLineUpsertBulk { + return u.Update(func(s *BillingInvoiceLineUpsert) { + s.UpdateChildUniqueReferenceID() + }) +} + // Exec executes the query. func (u *BillingInvoiceLineUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/openmeter/ent/db/billinginvoiceline_update.go b/openmeter/ent/db/billinginvoiceline_update.go index 371f14cdf..aa39b8d56 100644 --- a/openmeter/ent/db/billinginvoiceline_update.go +++ b/openmeter/ent/db/billinginvoiceline_update.go @@ -236,6 +236,20 @@ func (bilu *BillingInvoiceLineUpdate) ClearTaxConfig() *BillingInvoiceLineUpdate return bilu } +// SetChildUniqueReferenceID sets the "child_unique_reference_id" field. +func (bilu *BillingInvoiceLineUpdate) SetChildUniqueReferenceID(s string) *BillingInvoiceLineUpdate { + bilu.mutation.SetChildUniqueReferenceID(s) + return bilu +} + +// SetNillableChildUniqueReferenceID sets the "child_unique_reference_id" field if the given value is not nil. +func (bilu *BillingInvoiceLineUpdate) SetNillableChildUniqueReferenceID(s *string) *BillingInvoiceLineUpdate { + if s != nil { + bilu.SetChildUniqueReferenceID(*s) + } + return bilu +} + // SetBillingInvoiceID sets the "billing_invoice" edge to the BillingInvoice entity by ID. func (bilu *BillingInvoiceLineUpdate) SetBillingInvoiceID(id string) *BillingInvoiceLineUpdate { bilu.mutation.SetBillingInvoiceID(id) @@ -403,6 +417,11 @@ func (bilu *BillingInvoiceLineUpdate) check() error { return &ValidationError{Name: "tax_config", err: fmt.Errorf(`db: validator failed for field "BillingInvoiceLine.tax_config": %w`, err)} } } + if v, ok := bilu.mutation.ChildUniqueReferenceID(); ok { + if err := billinginvoiceline.ChildUniqueReferenceIDValidator(v); err != nil { + return &ValidationError{Name: "child_unique_reference_id", err: fmt.Errorf(`db: validator failed for field "BillingInvoiceLine.child_unique_reference_id": %w`, err)} + } + } if bilu.mutation.BillingInvoiceCleared() && len(bilu.mutation.BillingInvoiceIDs()) > 0 { return errors.New(`db: clearing a required unique edge "BillingInvoiceLine.billing_invoice"`) } @@ -469,6 +488,9 @@ func (bilu *BillingInvoiceLineUpdate) sqlSave(ctx context.Context) (n int, err e if bilu.mutation.TaxConfigCleared() { _spec.ClearField(billinginvoiceline.FieldTaxConfig, field.TypeJSON) } + if value, ok := bilu.mutation.ChildUniqueReferenceID(); ok { + _spec.SetField(billinginvoiceline.FieldChildUniqueReferenceID, field.TypeString, value) + } if bilu.mutation.BillingInvoiceCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -852,6 +874,20 @@ func (biluo *BillingInvoiceLineUpdateOne) ClearTaxConfig() *BillingInvoiceLineUp return biluo } +// SetChildUniqueReferenceID sets the "child_unique_reference_id" field. +func (biluo *BillingInvoiceLineUpdateOne) SetChildUniqueReferenceID(s string) *BillingInvoiceLineUpdateOne { + biluo.mutation.SetChildUniqueReferenceID(s) + return biluo +} + +// SetNillableChildUniqueReferenceID sets the "child_unique_reference_id" field if the given value is not nil. +func (biluo *BillingInvoiceLineUpdateOne) SetNillableChildUniqueReferenceID(s *string) *BillingInvoiceLineUpdateOne { + if s != nil { + biluo.SetChildUniqueReferenceID(*s) + } + return biluo +} + // SetBillingInvoiceID sets the "billing_invoice" edge to the BillingInvoice entity by ID. func (biluo *BillingInvoiceLineUpdateOne) SetBillingInvoiceID(id string) *BillingInvoiceLineUpdateOne { biluo.mutation.SetBillingInvoiceID(id) @@ -1032,6 +1068,11 @@ func (biluo *BillingInvoiceLineUpdateOne) check() error { return &ValidationError{Name: "tax_config", err: fmt.Errorf(`db: validator failed for field "BillingInvoiceLine.tax_config": %w`, err)} } } + if v, ok := biluo.mutation.ChildUniqueReferenceID(); ok { + if err := billinginvoiceline.ChildUniqueReferenceIDValidator(v); err != nil { + return &ValidationError{Name: "child_unique_reference_id", err: fmt.Errorf(`db: validator failed for field "BillingInvoiceLine.child_unique_reference_id": %w`, err)} + } + } if biluo.mutation.BillingInvoiceCleared() && len(biluo.mutation.BillingInvoiceIDs()) > 0 { return errors.New(`db: clearing a required unique edge "BillingInvoiceLine.billing_invoice"`) } @@ -1115,6 +1156,9 @@ func (biluo *BillingInvoiceLineUpdateOne) sqlSave(ctx context.Context) (_node *B if biluo.mutation.TaxConfigCleared() { _spec.ClearField(billinginvoiceline.FieldTaxConfig, field.TypeJSON) } + if value, ok := biluo.mutation.ChildUniqueReferenceID(); ok { + _spec.SetField(billinginvoiceline.FieldChildUniqueReferenceID, field.TypeString, value) + } if biluo.mutation.BillingInvoiceCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/openmeter/ent/db/migrate/schema.go b/openmeter/ent/db/migrate/schema.go index 4c9446738..9719219ad 100644 --- a/openmeter/ent/db/migrate/schema.go +++ b/openmeter/ent/db/migrate/schema.go @@ -449,6 +449,7 @@ var ( {Name: "currency", Type: field.TypeString, SchemaType: map[string]string{"postgres": "varchar(3)"}}, {Name: "quantity", Type: field.TypeOther, Nullable: true, SchemaType: map[string]string{"postgres": "numeric"}}, {Name: "tax_config", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "child_unique_reference_id", Type: field.TypeString}, {Name: "invoice_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "char(26)"}}, {Name: "fee_line_config_id", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "char(26)"}}, {Name: "usage_based_line_config_id", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "char(26)"}}, @@ -462,25 +463,25 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "billing_invoice_lines_billing_invoices_billing_invoice_lines", - Columns: []*schema.Column{BillingInvoiceLinesColumns[16]}, + Columns: []*schema.Column{BillingInvoiceLinesColumns[17]}, RefColumns: []*schema.Column{BillingInvoicesColumns[0]}, OnDelete: schema.Cascade, }, { Symbol: "billing_invoice_lines_billing_invoice_flat_fee_line_configs_flat_fee_line", - Columns: []*schema.Column{BillingInvoiceLinesColumns[17]}, + Columns: []*schema.Column{BillingInvoiceLinesColumns[18]}, RefColumns: []*schema.Column{BillingInvoiceFlatFeeLineConfigsColumns[0]}, OnDelete: schema.Cascade, }, { Symbol: "billing_invoice_lines_billing_invoice_usage_based_line_configs_usage_based_line", - Columns: []*schema.Column{BillingInvoiceLinesColumns[18]}, + Columns: []*schema.Column{BillingInvoiceLinesColumns[19]}, RefColumns: []*schema.Column{BillingInvoiceUsageBasedLineConfigsColumns[0]}, OnDelete: schema.Cascade, }, { Symbol: "billing_invoice_lines_billing_invoice_lines_child_lines", - Columns: []*schema.Column{BillingInvoiceLinesColumns[19]}, + Columns: []*schema.Column{BillingInvoiceLinesColumns[20]}, RefColumns: []*schema.Column{BillingInvoiceLinesColumns[0]}, OnDelete: schema.SetNull, }, @@ -504,12 +505,17 @@ var ( { Name: "billinginvoiceline_namespace_invoice_id", Unique: false, - Columns: []*schema.Column{BillingInvoiceLinesColumns[1], BillingInvoiceLinesColumns[16]}, + Columns: []*schema.Column{BillingInvoiceLinesColumns[1], BillingInvoiceLinesColumns[17]}, }, { Name: "billinginvoiceline_namespace_parent_line_id", Unique: false, - Columns: []*schema.Column{BillingInvoiceLinesColumns[1], BillingInvoiceLinesColumns[19]}, + Columns: []*schema.Column{BillingInvoiceLinesColumns[1], BillingInvoiceLinesColumns[20]}, + }, + { + Name: "billinginvoiceline_namespace_parent_line_id_child_unique_reference_id", + Unique: true, + Columns: []*schema.Column{BillingInvoiceLinesColumns[1], BillingInvoiceLinesColumns[20], BillingInvoiceLinesColumns[16]}, }, }, } diff --git a/openmeter/ent/db/mutation.go b/openmeter/ent/db/mutation.go index 8f297a27a..382c3872d 100644 --- a/openmeter/ent/db/mutation.go +++ b/openmeter/ent/db/mutation.go @@ -10025,39 +10025,40 @@ func (m *BillingInvoiceFlatFeeLineConfigMutation) ResetEdge(name string) error { // BillingInvoiceLineMutation represents an operation that mutates the BillingInvoiceLine nodes in the graph. type BillingInvoiceLineMutation struct { config - op Op - typ string - id *string - namespace *string - metadata *map[string]string - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - description *string - period_start *time.Time - period_end *time.Time - invoice_at *time.Time - _type *billingentity.InvoiceLineType - status *billingentity.InvoiceLineStatus - currency *currencyx.Code - quantity *alpacadecimal.Decimal - tax_config *plan.TaxConfig - clearedFields map[string]struct{} - billing_invoice *string - clearedbilling_invoice bool - flat_fee_line *string - clearedflat_fee_line bool - usage_based_line *string - clearedusage_based_line bool - parent_line *string - clearedparent_line bool - child_lines map[string]struct{} - removedchild_lines map[string]struct{} - clearedchild_lines bool - done bool - oldValue func(context.Context) (*BillingInvoiceLine, error) - predicates []predicate.BillingInvoiceLine + op Op + typ string + id *string + namespace *string + metadata *map[string]string + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + description *string + period_start *time.Time + period_end *time.Time + invoice_at *time.Time + _type *billingentity.InvoiceLineType + status *billingentity.InvoiceLineStatus + currency *currencyx.Code + quantity *alpacadecimal.Decimal + tax_config *plan.TaxConfig + child_unique_reference_id *string + clearedFields map[string]struct{} + billing_invoice *string + clearedbilling_invoice bool + flat_fee_line *string + clearedflat_fee_line bool + usage_based_line *string + clearedusage_based_line bool + parent_line *string + clearedparent_line bool + child_lines map[string]struct{} + removedchild_lines map[string]struct{} + clearedchild_lines bool + done bool + oldValue func(context.Context) (*BillingInvoiceLine, error) + predicates []predicate.BillingInvoiceLine } var _ ent.Mutation = (*BillingInvoiceLineMutation)(nil) @@ -10854,6 +10855,42 @@ func (m *BillingInvoiceLineMutation) ResetTaxConfig() { delete(m.clearedFields, billinginvoiceline.FieldTaxConfig) } +// SetChildUniqueReferenceID sets the "child_unique_reference_id" field. +func (m *BillingInvoiceLineMutation) SetChildUniqueReferenceID(s string) { + m.child_unique_reference_id = &s +} + +// ChildUniqueReferenceID returns the value of the "child_unique_reference_id" field in the mutation. +func (m *BillingInvoiceLineMutation) ChildUniqueReferenceID() (r string, exists bool) { + v := m.child_unique_reference_id + if v == nil { + return + } + return *v, true +} + +// OldChildUniqueReferenceID returns the old "child_unique_reference_id" field's value of the BillingInvoiceLine entity. +// If the BillingInvoiceLine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *BillingInvoiceLineMutation) OldChildUniqueReferenceID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldChildUniqueReferenceID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldChildUniqueReferenceID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldChildUniqueReferenceID: %w", err) + } + return oldValue.ChildUniqueReferenceID, nil +} + +// ResetChildUniqueReferenceID resets all changes to the "child_unique_reference_id" field. +func (m *BillingInvoiceLineMutation) ResetChildUniqueReferenceID() { + m.child_unique_reference_id = nil +} + // SetBillingInvoiceID sets the "billing_invoice" edge to the BillingInvoice entity by id. func (m *BillingInvoiceLineMutation) SetBillingInvoiceID(id string) { m.billing_invoice = &id @@ -11087,7 +11124,7 @@ func (m *BillingInvoiceLineMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *BillingInvoiceLineMutation) Fields() []string { - fields := make([]string, 0, 17) + fields := make([]string, 0, 18) if m.namespace != nil { fields = append(fields, billinginvoiceline.FieldNamespace) } @@ -11139,6 +11176,9 @@ func (m *BillingInvoiceLineMutation) Fields() []string { if m.tax_config != nil { fields = append(fields, billinginvoiceline.FieldTaxConfig) } + if m.child_unique_reference_id != nil { + fields = append(fields, billinginvoiceline.FieldChildUniqueReferenceID) + } return fields } @@ -11181,6 +11221,8 @@ func (m *BillingInvoiceLineMutation) Field(name string) (ent.Value, bool) { return m.Quantity() case billinginvoiceline.FieldTaxConfig: return m.TaxConfig() + case billinginvoiceline.FieldChildUniqueReferenceID: + return m.ChildUniqueReferenceID() } return nil, false } @@ -11224,6 +11266,8 @@ func (m *BillingInvoiceLineMutation) OldField(ctx context.Context, name string) return m.OldQuantity(ctx) case billinginvoiceline.FieldTaxConfig: return m.OldTaxConfig(ctx) + case billinginvoiceline.FieldChildUniqueReferenceID: + return m.OldChildUniqueReferenceID(ctx) } return nil, fmt.Errorf("unknown BillingInvoiceLine field %s", name) } @@ -11352,6 +11396,13 @@ func (m *BillingInvoiceLineMutation) SetField(name string, value ent.Value) erro } m.SetTaxConfig(v) return nil + case billinginvoiceline.FieldChildUniqueReferenceID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetChildUniqueReferenceID(v) + return nil } return fmt.Errorf("unknown BillingInvoiceLine field %s", name) } @@ -11491,6 +11542,9 @@ func (m *BillingInvoiceLineMutation) ResetField(name string) error { case billinginvoiceline.FieldTaxConfig: m.ResetTaxConfig() return nil + case billinginvoiceline.FieldChildUniqueReferenceID: + m.ResetChildUniqueReferenceID() + return nil } return fmt.Errorf("unknown BillingInvoiceLine field %s", name) } diff --git a/openmeter/ent/db/runtime.go b/openmeter/ent/db/runtime.go index cc6c5abf8..1c28aa6b7 100644 --- a/openmeter/ent/db/runtime.go +++ b/openmeter/ent/db/runtime.go @@ -344,6 +344,10 @@ func init() { billinginvoicelineDescCurrency := billinginvoicelineFields[7].Descriptor() // billinginvoiceline.CurrencyValidator is a validator for the "currency" field. It is called by the builders before save. billinginvoiceline.CurrencyValidator = billinginvoicelineDescCurrency.Validators[0].(func(string) error) + // billinginvoicelineDescChildUniqueReferenceID is the schema descriptor for child_unique_reference_id field. + billinginvoicelineDescChildUniqueReferenceID := billinginvoicelineFields[10].Descriptor() + // billinginvoiceline.ChildUniqueReferenceIDValidator is a validator for the "child_unique_reference_id" field. It is called by the builders before save. + billinginvoiceline.ChildUniqueReferenceIDValidator = billinginvoicelineDescChildUniqueReferenceID.Validators[0].(func(string) error) // billinginvoicelineDescID is the schema descriptor for id field. billinginvoicelineDescID := billinginvoicelineMixinFields0[0].Descriptor() // billinginvoiceline.DefaultID holds the default value on creation for the id field. diff --git a/openmeter/ent/schema/billing.go b/openmeter/ent/schema/billing.go index f3cff3d60..8019f285f 100644 --- a/openmeter/ent/schema/billing.go +++ b/openmeter/ent/schema/billing.go @@ -285,6 +285,16 @@ func (BillingInvoiceLine) Fields() []ent.Field { "postgres": "jsonb", }). Optional(), + + // child_unique_reference_id is uniqe per parent line, can be used for upserting + // and identifying lines created for the same reason (e.g. tiered price tier) + // between different invoices. + // + // As entgo doesn't support conditional unique indexes, defaults to ID of the + // line. + // TODO: add hooks + field.String("child_unique_reference_id"). + NotEmpty(), } } @@ -292,6 +302,7 @@ func (BillingInvoiceLine) Indexes() []ent.Index { return []ent.Index{ index.Fields("namespace", "invoice_id"), index.Fields("namespace", "parent_line_id"), + index.Fields("namespace", "parent_line_id", "child_unique_reference_id").Unique(), } } diff --git a/openmeter/productcatalog/plan/price.go b/openmeter/productcatalog/plan/price.go index 5d9ba8e40..930fab8c3 100644 --- a/openmeter/productcatalog/plan/price.go +++ b/openmeter/productcatalog/plan/price.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "slices" "strings" decimal "github.com/alpacahq/alpacadecimal" @@ -361,8 +362,17 @@ func (t TieredPrice) Validate() error { errs = append(errs, fmt.Errorf("invalid TieredPrice mode: %s", t.Mode)) } + if len(t.Tiers) == 0 { + errs = append(errs, errors.New("at least one PriceTier must be provided")) + } + upToAmounts := make(map[string]struct{}, len(t.Tiers)) + tierOpenEndedPresent := false for _, tier := range t.Tiers { + if tier.UpToAmount == nil { + tierOpenEndedPresent = true + } + uta := lo.FromPtrOr(tier.UpToAmount, decimal.Zero) if !uta.IsZero() { if _, ok := upToAmounts[uta.String()]; ok { @@ -378,6 +388,10 @@ func (t TieredPrice) Validate() error { } } + if !tierOpenEndedPresent { + errs = append(errs, errors.New("at least one PriceTier must be open-ended")) + } + minAmount := lo.FromPtrOr(t.MinimumAmount, decimal.Zero) if minAmount.IsNegative() { errs = append(errs, errors.New("the MinimumAmount must not be negative")) @@ -401,6 +415,31 @@ func (t TieredPrice) Validate() error { return nil } +func (t TieredPrice) WithSortedTiers() TieredPrice { + out := t + out.Tiers = make([]PriceTier, len(t.Tiers)) + copy(out.Tiers, t.Tiers) + + // Sort tiers by UpToAmount in ascending order + slices.SortFunc(out.Tiers, func(a, b PriceTier) int { + if a.UpToAmount == nil && b.UpToAmount == nil { + return 0 + } + + if a.UpToAmount == nil { + return 1 + } + + if b.UpToAmount == nil { + return -1 + } + + return a.UpToAmount.Cmp(*b.UpToAmount) + }) + + return out +} + var _ Validator = (*PriceTier)(nil) // PriceTier describes a tier of price(s). diff --git a/openmeter/productcatalog/plan/price_test.go b/openmeter/productcatalog/plan/price_test.go index dddda3cba..4f66136e3 100644 --- a/openmeter/productcatalog/plan/price_test.go +++ b/openmeter/productcatalog/plan/price_test.go @@ -274,3 +274,29 @@ func TestTieredPrice(t *testing.T) { } }) } + +func TestTieredPriceSorting(t *testing.T) { + in := TieredPrice{ + Tiers: []PriceTier{ + { + UpToAmount: lo.ToPtr(decimal.NewFromInt(1000)), + }, + { + UpToAmount: nil, + }, + { + UpToAmount: lo.ToPtr(decimal.NewFromInt(500)), + }, + }, + } + + out := in.WithSortedTiers() + + assert.Equal(t, []*decimal.Decimal{ + lo.ToPtr(decimal.NewFromInt(500)), + lo.ToPtr(decimal.NewFromInt(1000)), + nil, + }, lo.Map(out.Tiers, func(t PriceTier, _ int) *decimal.Decimal { + return t.UpToAmount + })) +}