From d246fc07da3d5cd2ba3b49f54ff1f9abe1a78557 Mon Sep 17 00:00:00 2001 From: Krisztian Gacsal Date: Wed, 13 Nov 2024 16:21:46 +0100 Subject: [PATCH] fix: Plan discount JSON serialization --- openmeter/productcatalog/plan/discount.go | 40 +++++++++-------- .../productcatalog/plan/discount_test.go | 44 +++++++++++++++++++ .../productcatalog/plan/httpdriver/mapping.go | 3 -- 3 files changed, 66 insertions(+), 21 deletions(-) create mode 100644 openmeter/productcatalog/plan/discount_test.go diff --git a/openmeter/productcatalog/plan/discount.go b/openmeter/productcatalog/plan/discount.go index e85393141..97068fbbc 100644 --- a/openmeter/productcatalog/plan/discount.go +++ b/openmeter/productcatalog/plan/discount.go @@ -56,38 +56,49 @@ func (d *Discount) RateCardKeys() []string { func (d *Discount) MarshalJSON() ([]byte, error) { var b []byte var err error + var serde interface{} switch d.t { case PercentageDiscountType: - b, err = json.Marshal(d.percentage) - if err != nil { - return nil, fmt.Errorf("failed to json marshal percentage discount: %w", err) + serde = struct { + Type DiscountType `json:"type"` + *PercentageDiscount + }{ + Type: PercentageDiscountType, + PercentageDiscount: d.percentage, } default: - return nil, fmt.Errorf("invalid discount type: %s", d.t) + return nil, fmt.Errorf("invalid Discount type: %s", d.t) + } + + b, err = json.Marshal(serde) + if err != nil { + return nil, fmt.Errorf("failed to JSON serialize Discount: %w", err) } return b, nil } func (d *Discount) UnmarshalJSON(bytes []byte) error { - meta := &DiscountMeta{} + serde := &struct { + Type DiscountType `json:"type"` + }{} - if err := json.Unmarshal(bytes, meta); err != nil { - return fmt.Errorf("failed to json unmarshal discount type: %w", err) + if err := json.Unmarshal(bytes, serde); err != nil { + return fmt.Errorf("failed to JSON deserialize Discount type: %w", err) } - switch meta.Type { + switch serde.Type { case PercentageDiscountType: v := &PercentageDiscount{} if err := json.Unmarshal(bytes, v); err != nil { - return fmt.Errorf("failed to json unmarshal percentage discount: %w", err) + return fmt.Errorf("failed to JSON deserialize Discount: %w", err) } d.percentage = v d.t = PercentageDiscountType default: - return fmt.Errorf("invalid discount type: %s", meta.Type) + return fmt.Errorf("invalid Discount type: %s", serde.Type) } return nil @@ -127,7 +138,7 @@ func NewDiscountFrom[T PercentageDiscount](v T) Discount { d := Discount{} switch any(v).(type) { - case FlatPrice: + case PercentageDiscount: percentage := any(v).(PercentageDiscount) d.FromPercentage(percentage) } @@ -135,16 +146,9 @@ func NewDiscountFrom[T PercentageDiscount](v T) Discount { return d } -type DiscountMeta struct { - // Type of the Discount. - Type DiscountType `json:"type"` -} - var _ Validator = (*PercentageDiscount)(nil) type PercentageDiscount struct { - DiscountMeta - // Percentage defines percentage of the discount. Percentage decimal.Decimal `json:"percentage"` diff --git a/openmeter/productcatalog/plan/discount_test.go b/openmeter/productcatalog/plan/discount_test.go new file mode 100644 index 000000000..58f233c55 --- /dev/null +++ b/openmeter/productcatalog/plan/discount_test.go @@ -0,0 +1,44 @@ +package plan + +import ( + "testing" + + decimal "github.com/alpacahq/alpacadecimal" + json "github.com/json-iterator/go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDiscount_JSON(t *testing.T) { + tests := []struct { + Name string + Discount Discount + ExpectedError bool + }{ + { + Name: "Valid", + Discount: NewDiscountFrom(PercentageDiscount{ + Percentage: decimal.NewFromFloat(99.9), + RateCards: []string{ + "ratecard-1", + "ratecard-2", + }, + }), + }, + } + + for _, test := range tests { + t.Run(test.Name, func(t *testing.T) { + b, err := json.Marshal(&test.Discount) + require.NoError(t, err) + + t.Logf("Serialized Discount: %s", string(b)) + + d := Discount{} + err = json.Unmarshal(b, &d) + require.NoError(t, err) + + assert.Equal(t, test.Discount, d) + }) + } +} diff --git a/openmeter/productcatalog/plan/httpdriver/mapping.go b/openmeter/productcatalog/plan/httpdriver/mapping.go index 32a3217ef..613a8a556 100644 --- a/openmeter/productcatalog/plan/httpdriver/mapping.go +++ b/openmeter/productcatalog/plan/httpdriver/mapping.go @@ -322,9 +322,6 @@ func AsPlanPhase(a api.PlanPhase, namespace, phaseID string) (plan.Phase, error) switch discount.Type { case api.DiscountPercentageTypePercentage: percentageDiscount := plan.PercentageDiscount{ - DiscountMeta: plan.DiscountMeta{ - Type: plan.PercentageDiscountType, - }, Percentage: decimal.NewFromFloat(float64(discount.Percentage)), RateCards: lo.FromPtrOr(discount.RateCards, nil), }