Skip to content

Commit

Permalink
Merge pull request #1832 from openmeterio/feat/plan-api-impl
Browse files Browse the repository at this point in the history
fix: updating Feature in Plan RateCard
  • Loading branch information
chrisgacsal authored Nov 12, 2024
2 parents 33cfc85 + 271fed5 commit 4d54451
Show file tree
Hide file tree
Showing 7 changed files with 351 additions and 312 deletions.
120 changes: 52 additions & 68 deletions openmeter/productcatalog/plan/adapter/phase.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,23 +154,21 @@ func (a *adapter) CreatePhase(ctx context.Context, params plan.CreatePhaseInput)

bulk, bulkFn := newRateCardBulkCreate(rateCardInputs, planPhase.ID, params.Namespace)

rateCardRows, err := a.db.PlanRateCard.MapCreateBulk(bulk, bulkFn).Save(ctx)
if err != nil {
if err = a.db.PlanRateCard.MapCreateBulk(bulk, bulkFn).Exec(ctx); err != nil {
return nil, fmt.Errorf("failed to bulk create RateCards for PlanPhase %s: %w", planPhase.ID, err)
}

planPhase.RateCards = make([]plan.RateCard, 0, len(rateCardRows))
for _, rateCardRow := range rateCardRows {
if rateCardRow == nil {
return nil, errors.New("invalid query result: nil RateCard received after bulk create")
}

rateCard, err := fromPlanRateCardRow(*rateCardRow)
if err != nil {
return nil, fmt.Errorf("failed to cast RateCard: %w", err)
}
planPhaseRow, err = a.db.PlanPhase.Query().
Where(phasedb.Namespace(params.Namespace), phasedb.ID(planPhase.ID)).
WithRatecards(rateCardEagerLoadFeaturesFn).
First(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get PlanPhase: %w", err)
}

planPhase.RateCards = append(planPhase.RateCards, *rateCard)
planPhase, err = fromPlanPhaseRow(*planPhaseRow)
if err != nil {
return nil, fmt.Errorf("failed to cast PlanPhase %w", err)
}

return planPhase, nil
Expand All @@ -191,9 +189,15 @@ func newRateCardBulkCreate(r []entdb.PlanRateCard, phaseID string, ns string) ([
SetNillableFeatureKey(r[i].FeatureKey).
SetNillableFeaturesID(r[i].FeatureID).
SetEntitlementTemplate(r[i].EntitlementTemplate).
SetTaxConfig(r[i].TaxConfig).
SetNillableBillingCadence(r[i].BillingCadence).
SetPrice(r[i].Price)
SetNillableBillingCadence(r[i].BillingCadence)

if r[i].TaxConfig != nil {
q.SetTaxConfig(r[i].TaxConfig)
}

if r[i].Price != nil {
q.SetPrice(r[i].Price)
}
}
}

Expand Down Expand Up @@ -305,7 +309,7 @@ func (a *adapter) GetPhase(ctx context.Context, params plan.GetPhaseInput) (*pla
return nil, errors.New("invalid get PlanPhase parameters")
}

query = query.WithRatecards()
query = query.WithRatecards(rateCardEagerLoadFeaturesFn)

phaseRow, err := query.First(ctx)
if err != nil {
Expand Down Expand Up @@ -382,34 +386,22 @@ func (a *adapter) UpdatePhase(ctx context.Context, params plan.UpdatePhaseInput)
}
}

if params.RateCards != nil {
rateCards := make([]plan.RateCard, 0, len(p.RateCards))

if params.RateCards != nil && len(*params.RateCards) > 0 {
diffResult, err := rateCardsDiff(*params.RateCards, p.RateCards)
if err != nil {
return nil, fmt.Errorf("failed to generate RateCard diff for PlanPhase update: %w", err)
}

if !diffResult.IsDiff() {
return p, nil
}

if len(diffResult.Add) > 0 {
bulk, bulkFn := newRateCardBulkCreate(diffResult.Add, p.ID, params.Namespace)

rateCardRows, err := a.db.PlanRateCard.MapCreateBulk(bulk, bulkFn).Save(ctx)
if err != nil {
if err = a.db.PlanRateCard.MapCreateBulk(bulk, bulkFn).Exec(ctx); err != nil {
return nil, fmt.Errorf("failed to bulk create RateCards: %w", err)
}

for _, rateCardRow := range rateCardRows {
if rateCardRow == nil {
return nil, errors.New("invalid query result: nil RateCard received after bulk create")
}

rateCard, err := fromPlanRateCardRow(*rateCardRow)
if err != nil {
return nil, fmt.Errorf("failed to cast RateCard: %w", err)
}

rateCards = append(rateCards, *rateCard)
}
}

if len(diffResult.Remove) > 0 {
Expand All @@ -423,47 +415,39 @@ func (a *adapter) UpdatePhase(ctx context.Context, params plan.UpdatePhaseInput)

if len(diffResult.Update) > 0 {
for _, rateCardInput := range diffResult.Update {
rateCardRow, err := a.db.PlanRateCard.UpdateOneID(rateCardInput.ID).
q := a.db.PlanRateCard.UpdateOneID(rateCardInput.ID).
Where(ratecarddb.Namespace(params.Namespace)).
SetMetadata(rateCardInput.Metadata).
SetOrClearMetadata(&rateCardInput.Metadata).
SetName(rateCardInput.Name).
SetNillableDescription(rateCardInput.Description).
SetNillableFeatureKey(rateCardInput.FeatureKey).
SetNillableFeaturesID(rateCardInput.FeatureID).
SetOrClearDescription(rateCardInput.Description).
SetOrClearFeatureKey(rateCardInput.FeatureKey).
SetEntitlementTemplate(rateCardInput.EntitlementTemplate).
SetTaxConfig(rateCardInput.TaxConfig).
SetNillableBillingCadence(rateCardInput.BillingCadence).
SetPrice(rateCardInput.Price).
Save(ctx)
if err != nil {
return nil, fmt.Errorf("failed to update RateCard: %w", err)
}
SetOrClearBillingCadence(rateCardInput.BillingCadence).
SetPrice(rateCardInput.Price)

if rateCardRow == nil {
return nil, errors.New("invalid query result: nil RateCard received update")
if rateCardInput.FeatureID == nil {
q.ClearFeatureID()
}

rateCard, err := fromPlanRateCardRow(*rateCardRow)
err = q.Exec(ctx)
if err != nil {
return nil, fmt.Errorf("failed to cast RateCard: %w", err)
return nil, fmt.Errorf("failed to update RateCard: %w", err)
}

rateCards = append(rateCards, *rateCard)
}
}

if len(diffResult.Keep) > 0 {
for _, rateCardRow := range diffResult.Keep {
rateCard, err := fromPlanRateCardRow(rateCardRow)
if err != nil {
return nil, fmt.Errorf("failed to cast RateCard: %w", err)
}

rateCards = append(rateCards, *rateCard)
}
p, err = a.GetPhase(ctx, plan.GetPhaseInput{
NamespacedID: models.NamespacedID{
Namespace: params.Namespace,
ID: params.ID,
},
Key: params.Key,
PlanID: params.PlanID,
})
if err != nil {
return nil, fmt.Errorf("failed to get updated PlanPhase: %w", err)
}

p.RateCards = rateCards
}

return p, nil
Expand Down Expand Up @@ -542,7 +526,7 @@ func rateCardsDiff(inputs, rateCards []plan.RateCard) (rateCardsDiffResult, erro
}
}

// Collect phases to be deleted
// Collect RateCards to be deleted
if len(rateCardsVisited) != len(rateCardsMap) {
for rateCardKey, rateCard := range rateCardsMap {
if _, ok := rateCardsVisited[rateCardKey]; !ok {
Expand All @@ -554,11 +538,11 @@ func rateCardsDiff(inputs, rateCards []plan.RateCard) (rateCardsDiffResult, erro
return result, nil
}

func rateCardCmp(r1, r2 entdb.PlanRateCard) (bool, error) {
if r1.ID != r2.ID {
return false, nil
}
func (r rateCardsDiffResult) IsDiff() bool {
return len(r.Add) > 0 || len(r.Update) > 0 || len(r.Remove) > 0
}

func rateCardCmp(r1, r2 entdb.PlanRateCard) (bool, error) {
if r1.Namespace != r2.Namespace {
return false, nil
}
Expand Down
109 changes: 70 additions & 39 deletions openmeter/productcatalog/plan/adapter/plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,56 +414,65 @@ func (a *adapter) UpdatePlan(ctx context.Context, params plan.UpdatePlanInput) (
}
}

if params.Phases != nil {
phases := make([]plan.Phase, 0, len(p.Phases))
// Return early if there are no PlanPhases set in Plan.
if params.Phases == nil || len(*params.Phases) == 0 {
return p, nil
}

// Return early if there are no changes in PlanPhases.
diffResult, err := planPhasesDiff(*params.Phases, p.Phases)
if err != nil {
return nil, fmt.Errorf("failed to calculate Plan Phases diff: %w", err)
}

diffResult := planPhasesDiff(*params.Phases, p.Phases)
if !diffResult.IsDiff() {
return p, nil
}

if len(diffResult.Add) > 0 {
for _, createInput := range diffResult.Add {
createInput.Namespace = params.Namespace
phases := make([]plan.Phase, 0, len(p.Phases))

phase, err := a.CreatePhase(ctx, createInput)
if err != nil {
return nil, fmt.Errorf("failed to create PlanPhase: %w", err)
}
if len(diffResult.Keep) > 0 {
phases = append(phases, diffResult.Keep...)
}

if len(diffResult.Add) > 0 {
for _, createInput := range diffResult.Add {
createInput.Namespace = params.Namespace

phases = append(phases, *phase)
phase, err := a.CreatePhase(ctx, createInput)
if err != nil {
return nil, fmt.Errorf("failed to create PlanPhase: %w", err)
}

phases = append(phases, *phase)
}
}

if len(diffResult.Remove) > 0 {
for _, deleteInput := range diffResult.Remove {
err = a.DeletePhase(ctx, deleteInput)
if err != nil {
return nil, fmt.Errorf("failed to delete PlanPhase: %w", err)
}
if len(diffResult.Remove) > 0 {
for _, deleteInput := range diffResult.Remove {
err = a.DeletePhase(ctx, deleteInput)
if err != nil {
return nil, fmt.Errorf("failed to delete PlanPhase: %w", err)
}
}
}

if len(diffResult.Update) > 0 {
for _, updateInput := range diffResult.Update {
updateInput.Namespace = params.Namespace

phase, err := a.UpdatePhase(ctx, updateInput)
if err != nil {
return nil, fmt.Errorf("failed to update PlanPhase: %w", err)
}
if len(diffResult.Update) > 0 {
for _, updateInput := range diffResult.Update {
updateInput.Namespace = params.Namespace

phases = append(phases, *phase)
phase, err := a.UpdatePhase(ctx, updateInput)
if err != nil {
return nil, fmt.Errorf("failed to update PlanPhase: %w", err)
}
}

if len(diffResult.Keep) > 0 {
phases = append(phases, diffResult.Keep...)
phases = append(phases, *phase)
}
}

if len(phases) > 0 {
plan.SortPhases(p.Phases, plan.SortPhasesByStartAfter)
}
plan.SortPhases(p.Phases, plan.SortPhasesByStartAfter)

p.Phases = phases
}
p.Phases = phases

return p, nil
}
Expand All @@ -476,7 +485,11 @@ var planPhaseAscOrderingByStartAfterFn = func(q *entdb.PlanPhaseQuery) {
}

var planPhaseEagerLoadRateCardsFn = func(q *entdb.PlanPhaseQuery) {
q.WithRatecards()
q.WithRatecards(rateCardEagerLoadFeaturesFn)
}

var rateCardEagerLoadFeaturesFn = func(q *entdb.PlanRateCardQuery) {
q.WithFeatures()
}

type planPhasesDiffResult struct {
Expand All @@ -493,7 +506,11 @@ type planPhasesDiffResult struct {
Keep []plan.Phase
}

func planPhasesDiff(requested, actual []plan.Phase) planPhasesDiffResult {
func (d planPhasesDiffResult) IsDiff() bool {
return len(d.Add) > 0 || len(d.Update) > 0 || len(d.Remove) > 0
}

func planPhasesDiff(requested, actual []plan.Phase) (planPhasesDiffResult, error) {
result := planPhasesDiffResult{}

inputsMap := make(map[string]plan.UpdatePhaseInput, len(requested))
Expand Down Expand Up @@ -545,10 +562,24 @@ func planPhasesDiff(requested, actual []plan.Phase) planPhasesDiffResult {
if !input.Equal(phase) {
result.Update = append(result.Update, input)
phasesVisited[phaseKey] = struct{}{}
} else {
result.Keep = append(result.Keep, phase)

continue
}

diffResult, err := rateCardsDiff(lo.FromPtr(input.RateCards), phase.RateCards)
if err != nil {
return result, err
}

if diffResult.IsDiff() {
result.Update = append(result.Update, input)
phasesVisited[phaseKey] = struct{}{}

continue
}

result.Keep = append(result.Keep, phase)
phasesVisited[phaseKey] = struct{}{}
}

// Collect phases to be deleted
Expand All @@ -565,5 +596,5 @@ func planPhasesDiff(requested, actual []plan.Phase) planPhasesDiffResult {
}
}

return result
return result, nil
}
4 changes: 0 additions & 4 deletions openmeter/productcatalog/plan/ratecard.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,6 @@ type RateCardMeta struct {
func (r *RateCardMeta) Validate() error {
var errs []error

if r.Feature != nil && r.EntitlementTemplate == nil {
errs = append(errs, errors.New("invalid EntitlementTemplate: must be provided if Feature is set"))
}

if r.EntitlementTemplate != nil {
if err := r.EntitlementTemplate.Validate(); err != nil {
errs = append(errs, fmt.Errorf("invalid EntitlementTemplate: %w", err))
Expand Down
4 changes: 2 additions & 2 deletions openmeter/productcatalog/plan/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,11 @@ func (i UpdatePlanInput) Equal(p Plan) bool {
return false
}

if lo.FromPtrOr(i.Description, "") != lo.FromPtrOr(p.Description, "") {
if i.Description != nil && lo.FromPtrOr(i.Description, "") != lo.FromPtrOr(p.Description, "") {
return false
}

if !MetadataEqual(lo.FromPtrOr(i.Metadata, nil), p.Metadata) {
if i.Metadata != nil && !MetadataEqual(*i.Metadata, p.Metadata) {
return false
}

Expand Down
6 changes: 6 additions & 0 deletions openmeter/productcatalog/plan/service/phase.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ func (s service) CreatePhase(ctx context.Context, params plan.CreatePhaseInput)

logger.Debug("creating PlanPhase")

if len(params.RateCards) > 0 {
if err := s.expandFeatures(ctx, params.Namespace, &params.RateCards); err != nil {
return nil, fmt.Errorf("failed to expand Features for RateCards in PlanPhase: %w", err)
}
}

phase, err := s.adapter.CreatePhase(ctx, params)
if err != nil {
return nil, fmt.Errorf("failed to create PlanPhase: %w", err)
Expand Down
Loading

0 comments on commit 4d54451

Please sign in to comment.