Skip to content

Commit

Permalink
Merge pull request #9 from caiorcferreira/feat/expose-extension-const…
Browse files Browse the repository at this point in the history
…ructors

feat: expose decimal and ipaddr constructors
  • Loading branch information
philhassey authored Apr 23, 2024
2 parents 470d1fe + 5940a63 commit f3d8620
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 36 deletions.
4 changes: 2 additions & 2 deletions eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -1082,7 +1082,7 @@ func (n *decimalLiteralEval) Eval(ctx *evalContext) (Value, error) {
return zeroValue(), err
}

d, err := newDecimalValue(string(literal))
d, err := ParseDecimal(string(literal))
if err != nil {
return zeroValue(), err
}
Expand All @@ -1104,7 +1104,7 @@ func (n *ipLiteralEval) Eval(ctx *evalContext) (Value, error) {
return zeroValue(), err
}

i, err := newIPValue(string(literal))
i, err := ParseIPAddr(string(literal))
if err != nil {
return zeroValue(), err
}
Expand Down
30 changes: 15 additions & 15 deletions eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,10 +713,10 @@ func TestDecimalLessThanNode(t *testing.T) {
tt := tt
t.Run(fmt.Sprintf("%s<%s", tt.lhs, tt.rhs), func(t *testing.T) {
t.Parallel()
lhsd, err := newDecimalValue(tt.lhs)
lhsd, err := ParseDecimal(tt.lhs)
testutilOK(t, err)
lhsv := lhsd
rhsd, err := newDecimalValue(tt.rhs)
rhsd, err := ParseDecimal(tt.rhs)
testutilOK(t, err)
rhsv := rhsd
n := newDecimalLessThanEval(newLiteralEval(lhsv), newLiteralEval(rhsv))
Expand Down Expand Up @@ -770,10 +770,10 @@ func TestDecimalLessThanOrEqualNode(t *testing.T) {
tt := tt
t.Run(fmt.Sprintf("%s<=%s", tt.lhs, tt.rhs), func(t *testing.T) {
t.Parallel()
lhsd, err := newDecimalValue(tt.lhs)
lhsd, err := ParseDecimal(tt.lhs)
testutilOK(t, err)
lhsv := lhsd
rhsd, err := newDecimalValue(tt.rhs)
rhsd, err := ParseDecimal(tt.rhs)
testutilOK(t, err)
rhsv := rhsd
n := newDecimalLessThanOrEqualEval(newLiteralEval(lhsv), newLiteralEval(rhsv))
Expand Down Expand Up @@ -827,10 +827,10 @@ func TestDecimalGreaterThanNode(t *testing.T) {
tt := tt
t.Run(fmt.Sprintf("%s>%s", tt.lhs, tt.rhs), func(t *testing.T) {
t.Parallel()
lhsd, err := newDecimalValue(tt.lhs)
lhsd, err := ParseDecimal(tt.lhs)
testutilOK(t, err)
lhsv := lhsd
rhsd, err := newDecimalValue(tt.rhs)
rhsd, err := ParseDecimal(tt.rhs)
testutilOK(t, err)
rhsv := rhsd
n := newDecimalGreaterThanEval(newLiteralEval(lhsv), newLiteralEval(rhsv))
Expand Down Expand Up @@ -884,10 +884,10 @@ func TestDecimalGreaterThanOrEqualNode(t *testing.T) {
tt := tt
t.Run(fmt.Sprintf("%s>=%s", tt.lhs, tt.rhs), func(t *testing.T) {
t.Parallel()
lhsd, err := newDecimalValue(tt.lhs)
lhsd, err := ParseDecimal(tt.lhs)
testutilOK(t, err)
lhsv := lhsd
rhsd, err := newDecimalValue(tt.rhs)
rhsd, err := ParseDecimal(tt.rhs)
testutilOK(t, err)
rhsv := rhsd
n := newDecimalGreaterThanOrEqualEval(newLiteralEval(lhsv), newLiteralEval(rhsv))
Expand Down Expand Up @@ -1748,7 +1748,7 @@ func TestDecimalLiteralNode(t *testing.T) {

func TestIPLiteralNode(t *testing.T) {
t.Parallel()
ipv6Loopback, err := newIPValue("::1")
ipv6Loopback, err := ParseIPAddr("::1")
testutilOK(t, err)
tests := []struct {
name string
Expand All @@ -1775,11 +1775,11 @@ func TestIPLiteralNode(t *testing.T) {

func TestIPTestNode(t *testing.T) {
t.Parallel()
ipv4Loopback, err := newIPValue("127.0.0.1")
ipv4Loopback, err := ParseIPAddr("127.0.0.1")
testutilOK(t, err)
ipv6Loopback, err := newIPValue("::1")
ipv6Loopback, err := ParseIPAddr("::1")
testutilOK(t, err)
ipv4Multicast, err := newIPValue("224.0.0.1")
ipv4Multicast, err := ParseIPAddr("224.0.0.1")
testutilOK(t, err)
tests := []struct {
name string
Expand Down Expand Up @@ -1813,11 +1813,11 @@ func TestIPTestNode(t *testing.T) {

func TestIPIsInRangeNode(t *testing.T) {
t.Parallel()
ipv4A, err := newIPValue("1.2.3.4")
ipv4A, err := ParseIPAddr("1.2.3.4")
testutilOK(t, err)
ipv4B, err := newIPValue("1.2.3.0/24")
ipv4B, err := ParseIPAddr("1.2.3.0/24")
testutilOK(t, err)
ipv4C, err := newIPValue("1.2.4.0/24")
ipv4C, err := ParseIPAddr("1.2.4.0/24")
testutilOK(t, err)
tests := []struct {
name string
Expand Down
4 changes: 2 additions & 2 deletions json.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ func unmarshalJSON(b []byte, v *Value) error {
if err := json.Unmarshal(b, &res); err == nil && res.Extn != nil {
switch res.Extn.Fn {
case "ip":
val, err := newIPValue(res.Extn.Arg)
val, err := ParseIPAddr(res.Extn.Arg)
if err != nil {
return err
}
*v = val
return nil
case "decimal":
val, err := newDecimalValue(res.Extn.Arg)
val, err := ParseDecimal(res.Extn.Arg)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ import (
)

func mustDecimalValue(v string) Decimal {
r, _ := newDecimalValue(v)
r, _ := ParseDecimal(v)
return r
}

func mustIPValue(v string) IPAddr {
r, _ := newIPValue(v)
r, _ := ParseIPAddr(v)
return r
}

Expand Down
10 changes: 6 additions & 4 deletions value.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,8 @@ type Decimal int64
// DecimalPrecision is the precision of a Decimal.
const DecimalPrecision = 10000

func newDecimalValue(s string) (Decimal, error) {
// ParseDecimal takes a string representation of a decimal number and converts it into a Decimal type.
func ParseDecimal(s string) (Decimal, error) {
// Check for empty string.
if len(s) == 0 {
return Decimal(0), fmt.Errorf("%w: string too short", errDecimal)
Expand Down Expand Up @@ -550,7 +551,7 @@ func (v *Decimal) UnmarshalJSON(b []byte) error {
}
arg = res.Extn.Arg
}
vv, err := newDecimalValue(arg)
vv, err := ParseDecimal(arg)
if err != nil {
return err
}
Expand All @@ -576,7 +577,8 @@ func (v Decimal) deepClone() Value { return v }
// The value can represent an individual address or a range of addresses.
type IPAddr netip.Prefix

func newIPValue(s string) (IPAddr, error) {
// ParseIPAddr takes a string representation of an IP address and converts it into an IPAddr type.
func ParseIPAddr(s string) (IPAddr, error) {
// We disallow IPv4-mapped IPv6 addresses in dotted notation because Cedar does.
if strings.Count(s, ":") >= 2 && strings.Count(s, ".") >= 2 {
return IPAddr{}, fmt.Errorf("%w: cannot parse IPv4 addresses embedded in IPv6 addresses", errIP)
Expand Down Expand Up @@ -687,7 +689,7 @@ func (v *IPAddr) UnmarshalJSON(b []byte) error {
}
arg = res.Extn.Arg
}
vv, err := newIPValue(arg)
vv, err := ParseIPAddr(arg)
if err != nil {
return err
}
Expand Down
22 changes: 11 additions & 11 deletions value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ func TestDecimal(t *testing.T) {
tt := tt
t.Run(fmt.Sprintf("%s->%s", tt.in, tt.out), func(t *testing.T) {
t.Parallel()
d, err := newDecimalValue(tt.in)
d, err := ParseDecimal(tt.in)
testutilOK(t, err)
testutilEquals(t, d.String(), tt.out)
})
Expand Down Expand Up @@ -413,7 +413,7 @@ func TestDecimal(t *testing.T) {
tt := tt
t.Run(fmt.Sprintf("%s->%s", tt.in, tt.errStr), func(t *testing.T) {
t.Parallel()
_, err := newDecimalValue(tt.in)
_, err := ParseDecimal(tt.in)
assertError(t, err, errDecimal)
testutilEquals(t, err.Error(), tt.errStr)
})
Expand All @@ -422,7 +422,7 @@ func TestDecimal(t *testing.T) {

t.Run("roundTrip", func(t *testing.T) {
t.Parallel()
dv, err := newDecimalValue("1.20")
dv, err := ParseDecimal("1.20")
testutilOK(t, err)
v, err := valueToDecimal(dv)
testutilOK(t, err)
Expand Down Expand Up @@ -498,7 +498,7 @@ func TestIP(t *testing.T) {
}
t.Run(testName, func(t *testing.T) {
t.Parallel()
i, err := newIPValue(tt.in)
i, err := ParseIPAddr(tt.in)
if tt.parses {
testutilOK(t, err)
testutilEquals(t, i.String(), tt.out)
Expand Down Expand Up @@ -549,9 +549,9 @@ func TestIP(t *testing.T) {
tt := tt
t.Run(fmt.Sprintf("ip(%v).equal(ip(%v))", tt.lhs, tt.rhs), func(t *testing.T) {
t.Parallel()
lhs, err := newIPValue(tt.lhs)
lhs, err := ParseIPAddr(tt.lhs)
testutilOK(t, err)
rhs, err := newIPValue(tt.rhs)
rhs, err := ParseIPAddr(tt.rhs)
testutilOK(t, err)
equal := lhs.equal(rhs)
if equal != tt.equal {
Expand Down Expand Up @@ -597,7 +597,7 @@ func TestIP(t *testing.T) {
tt := tt
t.Run(fmt.Sprintf("ip(%v).isIPv{4,6}()", tt.val), func(t *testing.T) {
t.Parallel()
val, err := newIPValue(tt.val)
val, err := ParseIPAddr(tt.val)
testutilOK(t, err)
isIPv4 := val.isIPv4()
if isIPv4 != tt.isIPv4 {
Expand Down Expand Up @@ -646,7 +646,7 @@ func TestIP(t *testing.T) {
tt := tt
t.Run(fmt.Sprintf("ip(%v).isLoopback()", tt.val), func(t *testing.T) {
t.Parallel()
val, err := newIPValue(tt.val)
val, err := ParseIPAddr(tt.val)
testutilOK(t, err)
isLoopback := val.isLoopback()
if isLoopback != tt.isLoopback {
Expand Down Expand Up @@ -680,7 +680,7 @@ func TestIP(t *testing.T) {
tt := tt
t.Run(fmt.Sprintf("ip(%v).isMulticast()", tt.val), func(t *testing.T) {
t.Parallel()
val, err := newIPValue(tt.val)
val, err := ParseIPAddr(tt.val)
testutilOK(t, err)
isMulticast := val.isMulticast()
if isMulticast != tt.isMulticast {
Expand Down Expand Up @@ -713,9 +713,9 @@ func TestIP(t *testing.T) {
tt := tt
t.Run(fmt.Sprintf("ip(%v).contains(ip(%v))", tt.lhs, tt.rhs), func(t *testing.T) {
t.Parallel()
lhs, err := newIPValue(tt.lhs)
lhs, err := ParseIPAddr(tt.lhs)
testutilOK(t, err)
rhs, err := newIPValue(tt.rhs)
rhs, err := ParseIPAddr(tt.rhs)
testutilOK(t, err)
contains := lhs.contains(rhs)
if contains != tt.contains {
Expand Down

0 comments on commit f3d8620

Please sign in to comment.