Skip to content

Commit

Permalink
Merge pull request #56 from strongdm/internalize-public-errors
Browse files Browse the repository at this point in the history
types: internalize various errors
  • Loading branch information
patjakdev authored Nov 8, 2024
2 parents 1bf4c4e + 2630949 commit c3dfe27
Show file tree
Hide file tree
Showing 13 changed files with 107 additions and 79 deletions.
13 changes: 13 additions & 0 deletions internal/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package internal

import "fmt"

// These errors are declared here in order to allow the tests outside of the
// types package to assert on the error type returned. One day, we could
// consider making them public.

var ErrDatetime = fmt.Errorf("error parsing datetime value")
var ErrDecimal = fmt.Errorf("error parsing decimal value")
var ErrDuration = fmt.Errorf("error parsing duration value")
var ErrIP = fmt.Errorf("error parsing ip value")
var ErrNotComparable = fmt.Errorf("incompatible types in comparison")
9 changes: 5 additions & 4 deletions internal/eval/evalers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"
"time"

"github.com/cedar-policy/cedar-go/internal"
"github.com/cedar-policy/cedar-go/internal/consts"
"github.com/cedar-policy/cedar-go/internal/parser"
"github.com/cedar-policy/cedar-go/internal/testutil"
Expand Down Expand Up @@ -2177,7 +2178,7 @@ func TestDecimalLiteralNode(t *testing.T) {
}{
{"Error", newErrorEval(errTest), zeroValue(), errTest},
{"TypeError", newLiteralEval(types.Long(1)), zeroValue(), ErrType},
{"DecimalError", newLiteralEval(types.String("frob")), zeroValue(), types.ErrDecimal},
{"DecimalError", newLiteralEval(types.String("frob")), zeroValue(), internal.ErrDecimal},
{"Success", newLiteralEval(types.String("1.0")), testutil.Must(types.NewDecimalFromInt(1)), nil},
}
for _, tt := range tests {
Expand All @@ -2204,7 +2205,7 @@ func TestIPLiteralNode(t *testing.T) {
}{
{"Error", newErrorEval(errTest), zeroValue(), errTest},
{"TypeError", newLiteralEval(types.Long(1)), zeroValue(), ErrType},
{"IPError", newLiteralEval(types.String("not-an-IP-address")), zeroValue(), types.ErrIP},
{"IPError", newLiteralEval(types.String("not-an-IP-address")), zeroValue(), internal.ErrIP},
{"Success", newLiteralEval(types.String("::1/128")), ipv6Loopback, nil},
}
for _, tt := range tests {
Expand Down Expand Up @@ -2335,7 +2336,7 @@ func TestDatetimeLiteralNode(t *testing.T) {
}{
{"Error", newErrorEval(errTest), zeroValue(), errTest},
{"TypeError", newLiteralEval(types.Long(1)), zeroValue(), ErrType},
{"DatetimeError", newLiteralEval(types.String("frob")), zeroValue(), types.ErrDatetime},
{"DatetimeError", newLiteralEval(types.String("frob")), zeroValue(), internal.ErrDatetime},
{"Success", newLiteralEval(types.String("1970-01-01")), types.FromStdTime(time.UnixMilli(0)), nil},
}
for _, tt := range tests {
Expand Down Expand Up @@ -2480,7 +2481,7 @@ func TestDurationLiteralNode(t *testing.T) {
}{
{"Error", newErrorEval(errTest), zeroValue(), errTest},
{"TypeError", newLiteralEval(types.Long(1)), zeroValue(), ErrType},
{"DurationError", newLiteralEval(types.String("frob")), zeroValue(), types.ErrDuration},
{"DurationError", newLiteralEval(types.String("frob")), zeroValue(), internal.ErrDuration},
{"Success", newLiteralEval(types.String("1h")), types.FromStdDuration(1 * time.Hour), nil},
}
for _, tt := range tests {
Expand Down
50 changes: 27 additions & 23 deletions types/datetime.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ import (
"strconv"
"time"
"unicode"

"github.com/cedar-policy/cedar-go/internal"
)

var errDatetime = internal.ErrDatetime

// Datetime represents a Cedar datetime value
type Datetime struct {
// value is a timestamp in milliseconds
Expand Down Expand Up @@ -44,7 +48,7 @@ func ParseDatetime(s string) (Datetime, error) {

length := len(s)
if length < 10 {
return Datetime{}, fmt.Errorf("%w: string too short", ErrDatetime)
return Datetime{}, fmt.Errorf("%w: string too short", errDatetime)
}

// Date: YYYY-MM-DD
Expand All @@ -57,32 +61,32 @@ func ParseDatetime(s string) (Datetime, error) {
unicode.IsDigit(rune(s[1])) &&
unicode.IsDigit(rune(s[2])) &&
unicode.IsDigit(rune(s[3]))) {
return Datetime{}, fmt.Errorf("%w: invalid year", ErrDatetime)
return Datetime{}, fmt.Errorf("%w: invalid year", errDatetime)
}
year = 1000*int(rune(s[0])-'0') +
100*int(rune(s[1])-'0') +
10*int(rune(s[2])-'0') +
int(rune(s[3])-'0')

if s[4] != '-' {
return Datetime{}, fmt.Errorf("%w: unexpected character %s", ErrDatetime, strconv.QuoteRune(rune(s[4])))
return Datetime{}, fmt.Errorf("%w: unexpected character %s", errDatetime, strconv.QuoteRune(rune(s[4])))
}

// MM
if !(unicode.IsDigit(rune(s[5])) &&
unicode.IsDigit(rune(s[6]))) {
return Datetime{}, fmt.Errorf("%w: invalid month", ErrDatetime)
return Datetime{}, fmt.Errorf("%w: invalid month", errDatetime)
}
month = 10*int(rune(s[5])-'0') + int(rune(s[6])-'0')

if s[7] != '-' {
return Datetime{}, fmt.Errorf("%w: unexpected character %s", ErrDatetime, strconv.QuoteRune(rune(s[7])))
return Datetime{}, fmt.Errorf("%w: unexpected character %s", errDatetime, strconv.QuoteRune(rune(s[7])))
}

// DD
if !(unicode.IsDigit(rune(s[8])) &&
unicode.IsDigit(rune(s[9]))) {
return Datetime{}, fmt.Errorf("%w: invalid day", ErrDatetime)
return Datetime{}, fmt.Errorf("%w: invalid day", errDatetime)
}
day = 10*int(rune(s[8])-'0') + int(rune(s[9])-'0')

Expand All @@ -94,7 +98,7 @@ func ParseDatetime(s string) (Datetime, error) {

// If the length is less than 20, we can't have a valid time.
if length < 20 {
return Datetime{}, fmt.Errorf("%w: invalid time", ErrDatetime)
return Datetime{}, fmt.Errorf("%w: invalid time", errDatetime)
}

// Time: Thh:mm:ss?
Expand All @@ -106,32 +110,32 @@ func ParseDatetime(s string) (Datetime, error) {
// ? is at 19, and... we'll skip to get back to that.

if s[10] != 'T' {
return Datetime{}, fmt.Errorf("%w: unexpected character %s", ErrDatetime, strconv.QuoteRune(rune(s[10])))
return Datetime{}, fmt.Errorf("%w: unexpected character %s", errDatetime, strconv.QuoteRune(rune(s[10])))
}

if !(unicode.IsDigit(rune(s[11])) &&
unicode.IsDigit(rune(s[12]))) {
return Datetime{}, fmt.Errorf("%w: invalid hour", ErrDatetime)
return Datetime{}, fmt.Errorf("%w: invalid hour", errDatetime)
}
hour = 10*int(rune(s[11])-'0') + int(rune(s[12])-'0')

if s[13] != ':' {
return Datetime{}, fmt.Errorf("%w: unexpected character %s", ErrDatetime, strconv.QuoteRune(rune(s[13])))
return Datetime{}, fmt.Errorf("%w: unexpected character %s", errDatetime, strconv.QuoteRune(rune(s[13])))
}

if !(unicode.IsDigit(rune(s[14])) &&
unicode.IsDigit(rune(s[15]))) {
return Datetime{}, fmt.Errorf("%w: invalid minute", ErrDatetime)
return Datetime{}, fmt.Errorf("%w: invalid minute", errDatetime)
}
minute = 10*int(rune(s[14])-'0') + int(rune(s[15])-'0')

if s[16] != ':' {
return Datetime{}, fmt.Errorf("%w: unexpected character %s", ErrDatetime, strconv.QuoteRune(rune(s[16])))
return Datetime{}, fmt.Errorf("%w: unexpected character %s", errDatetime, strconv.QuoteRune(rune(s[16])))
}

if !(unicode.IsDigit(rune(s[17])) &&
unicode.IsDigit(rune(s[18]))) {
return Datetime{}, fmt.Errorf("%w: invalid second", ErrDatetime)
return Datetime{}, fmt.Errorf("%w: invalid second", errDatetime)
}
second = 10*int(rune(s[17])-'0') + int(rune(s[18])-'0')

Expand All @@ -142,29 +146,29 @@ func ParseDatetime(s string) (Datetime, error) {
trailerOffset := 19
if s[19] == '.' {
if length < 23 {
return Datetime{}, fmt.Errorf("%w: invalid millisecond", ErrDatetime)
return Datetime{}, fmt.Errorf("%w: invalid millisecond", errDatetime)
}

if !(unicode.IsDigit(rune(s[20])) &&
unicode.IsDigit(rune(s[21])) &&
unicode.IsDigit(rune(s[22]))) {
return Datetime{}, fmt.Errorf("%w: invalid millisecond", ErrDatetime)
return Datetime{}, fmt.Errorf("%w: invalid millisecond", errDatetime)
}

milli = 100*int(rune(s[20])-'0') + 10*int(rune(s[21])-'0') + int(rune(s[22])-'0')
trailerOffset = 23
}

if length == trailerOffset {
return Datetime{}, fmt.Errorf("%w: expected time zone designator", ErrDatetime)
return Datetime{}, fmt.Errorf("%w: expected time zone designator", errDatetime)
}

// At this point, we can only have 2 possible lengths. Anything else is an error.
switch s[trailerOffset] {
case 'Z':
if length > trailerOffset+1 {
// If something comes after the Z, it's an error
return Datetime{}, fmt.Errorf("%w: unexpected trailer after time zone designator", ErrDatetime)
return Datetime{}, fmt.Errorf("%w: unexpected trailer after time zone designator", errDatetime)
}
case '+', '-':
sign := 1
Expand All @@ -173,25 +177,25 @@ func ParseDatetime(s string) (Datetime, error) {
}

if length > trailerOffset+5 {
return Datetime{}, fmt.Errorf("%w: unexpected trailer after time zone designator", ErrDatetime)
return Datetime{}, fmt.Errorf("%w: unexpected trailer after time zone designator", errDatetime)
} else if length != trailerOffset+5 {
return Datetime{}, fmt.Errorf("%w: invalid time zone offset", ErrDatetime)
return Datetime{}, fmt.Errorf("%w: invalid time zone offset", errDatetime)
}

// get the time zone offset hhmm.
if !(unicode.IsDigit(rune(s[trailerOffset+1])) &&
unicode.IsDigit(rune(s[trailerOffset+2])) &&
unicode.IsDigit(rune(s[trailerOffset+3])) &&
unicode.IsDigit(rune(s[trailerOffset+4]))) {
return Datetime{}, fmt.Errorf("%w: invalid time zone offset", ErrDatetime)
return Datetime{}, fmt.Errorf("%w: invalid time zone offset", errDatetime)
}

hh := time.Duration(10*int64(rune(s[trailerOffset+1])-'0')+int64(rune(s[trailerOffset+2])-'0')) * time.Hour
mm := time.Duration(10*int64(rune(s[trailerOffset+3])-'0')+int64(rune(s[trailerOffset+4])-'0')) * time.Minute
offset = time.Duration(sign) * (hh + mm)

default:
return Datetime{}, fmt.Errorf("%w: invalid time zone designator", ErrDatetime)
return Datetime{}, fmt.Errorf("%w: invalid time zone designator", errDatetime)
}

t := time.Date(year, time.Month(month), day,
Expand All @@ -213,7 +217,7 @@ func (a Datetime) Equal(bi Value) bool {
func (a Datetime) LessThan(bi Value) (bool, error) {
b, ok := bi.(Datetime)
if !ok {
return false, ErrNotComparable
return false, internal.ErrNotComparable
}
return a.value < b.value, nil
}
Expand All @@ -224,7 +228,7 @@ func (a Datetime) LessThan(bi Value) (bool, error) {
func (a Datetime) LessThanOrEqual(bi Value) (bool, error) {
b, ok := bi.(Datetime)
if !ok {
return false, ErrNotComparable
return false, internal.ErrNotComparable
}
return a.value <= b.value, nil
}
Expand Down
7 changes: 4 additions & 3 deletions types/datetime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"
"time"

"github.com/cedar-policy/cedar-go/internal"
"github.com/cedar-policy/cedar-go/internal/testutil"
"github.com/cedar-policy/cedar-go/types"
)
Expand Down Expand Up @@ -104,7 +105,7 @@ func TestDatetime(t *testing.T) {
t.Run(fmt.Sprintf("%d_%s->%s", ti, tt.in, tt.errStr), func(t *testing.T) {
t.Parallel()
_, err := types.ParseDatetime(tt.in)
testutil.ErrorIs(t, err, types.ErrDatetime)
testutil.ErrorIs(t, err, internal.ErrDatetime)
testutil.Equals(t, err.Error(), tt.errStr)
})
}
Expand Down Expand Up @@ -145,7 +146,7 @@ func TestDatetime(t *testing.T) {
{one, zero, false, nil},
{zero, one, true, nil},
{zero, zero, false, nil},
{zero, f, false, types.ErrNotComparable},
{zero, f, false, internal.ErrNotComparable},
}

for ti, tt := range tests {
Expand Down Expand Up @@ -175,7 +176,7 @@ func TestDatetime(t *testing.T) {
{one, zero, false, nil},
{zero, one, true, nil},
{zero, zero, true, nil},
{zero, f, false, types.ErrNotComparable},
{zero, f, false, internal.ErrNotComparable},
}

for ti, tt := range tests {
Expand Down
29 changes: 16 additions & 13 deletions types/decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ import (
"strconv"
"strings"

"github.com/cedar-policy/cedar-go/internal"
"golang.org/x/exp/constraints"
)

var errDecimal = internal.ErrDecimal

// decimalPrecision is the precision of a Decimal.
const decimalPrecision = 10000

Expand All @@ -26,9 +29,9 @@ type Decimal struct {
// sign of intPart and tenThousandths should match.
func newDecimal(intPart int64, tenThousandths int16) (Decimal, error) {
if intPart > 922337203685477 || (intPart == 922337203685477 && tenThousandths > 5807) {
return Decimal{}, fmt.Errorf("%w: value would overflow", ErrDecimal)
return Decimal{}, fmt.Errorf("%w: value would overflow", errDecimal)
} else if intPart < -922337203685477 || (intPart == -922337203685477 && tenThousandths < -5808) {
return Decimal{}, fmt.Errorf("%w: value would underflow", ErrDecimal)
return Decimal{}, fmt.Errorf("%w: value would underflow", errDecimal)
}

return Decimal{value: intPart*decimalPrecision + int64(tenThousandths)}, nil
Expand All @@ -37,7 +40,7 @@ func newDecimal(intPart int64, tenThousandths int16) (Decimal, error) {
// NewDecimal returns a Decimal value of i * 10^exponent.
func NewDecimal(i int64, exponent int) (Decimal, error) {
if exponent < -4 || exponent > 14 {
return Decimal{}, fmt.Errorf("%w: exponent value of %v exceeds maximum range of Decimal", ErrDecimal, exponent)
return Decimal{}, fmt.Errorf("%w: exponent value of %v exceeds maximum range of Decimal", errDecimal, exponent)
}

var intPart int64
Expand All @@ -48,9 +51,9 @@ func NewDecimal(i int64, exponent int) (Decimal, error) {
} else {
intPart = i * int64(math.Pow10(exponent))
if i > 0 && intPart < i {
return Decimal{}, fmt.Errorf("%w: value %ve%v would overflow", ErrDecimal, i, exponent)
return Decimal{}, fmt.Errorf("%w: value %ve%v would overflow", errDecimal, i, exponent)
} else if i < 0 && intPart > i {
return Decimal{}, fmt.Errorf("%w: value %ve%v would underflow", ErrDecimal, i, exponent)
return Decimal{}, fmt.Errorf("%w: value %ve%v would underflow", errDecimal, i, exponent)
}
}

Expand All @@ -73,9 +76,9 @@ func NewDecimalFromInt[T constraints.Signed](i T) (Decimal, error) {
func NewDecimalFromFloat[T constraints.Float](f T) (Decimal, error) {
f = f * decimalPrecision
if f > math.MaxInt64 {
return Decimal{}, fmt.Errorf("%w: value %v would overflow", ErrDecimal, f)
return Decimal{}, fmt.Errorf("%w: value %v would overflow", errDecimal, f)
} else if f < math.MinInt64 {
return Decimal{}, fmt.Errorf("%w: value %v would underflow", ErrDecimal, f)
return Decimal{}, fmt.Errorf("%w: value %v would underflow", errDecimal, f)
}

return Decimal{int64(f)}, nil
Expand All @@ -94,29 +97,29 @@ func (d Decimal) Compare(other Decimal) int {
func ParseDecimal(s string) (Decimal, error) {
decimalIndex := strings.Index(s, ".")
if decimalIndex < 0 {
return Decimal{}, fmt.Errorf("%w: missing decimal point", ErrDecimal)
return Decimal{}, fmt.Errorf("%w: missing decimal point", errDecimal)
}

intPart, err := strconv.ParseInt(s[0:decimalIndex], 10, 64)
if err != nil {
if errors.Is(err, strconv.ErrRange) {
return Decimal{}, fmt.Errorf("%w: value would overflow", ErrDecimal)
return Decimal{}, fmt.Errorf("%w: value would overflow", errDecimal)
}
return Decimal{}, fmt.Errorf("%w: %w", ErrDecimal, err)
return Decimal{}, fmt.Errorf("%w: %w", errDecimal, err)
}

fracPartStr := s[decimalIndex+1:]
fracPart, err := strconv.ParseUint(fracPartStr, 10, 16)
if err != nil {
if errors.Is(err, strconv.ErrRange) {
return Decimal{}, fmt.Errorf("%w: fractional part exceeds Decimal precision", ErrDecimal)
return Decimal{}, fmt.Errorf("%w: fractional part exceeds Decimal precision", errDecimal)
}
return Decimal{}, fmt.Errorf("%w: %w", ErrDecimal, err)
return Decimal{}, fmt.Errorf("%w: %w", errDecimal, err)
}

decimalPlaces := len(fracPartStr)
if decimalPlaces > 4 {
return Decimal{}, fmt.Errorf("%w: fractional part exceeds Decimal precision", ErrDecimal)
return Decimal{}, fmt.Errorf("%w: fractional part exceeds Decimal precision", errDecimal)
}

tenThousandths := int16(fracPart) * int16(math.Pow10(4-decimalPlaces))
Expand Down
Loading

0 comments on commit c3dfe27

Please sign in to comment.