From 6860294bc28ce52a410550bcbbdfcb87ab832941 Mon Sep 17 00:00:00 2001 From: Jason Paulos Date: Thu, 4 Jan 2024 10:29:19 -0500 Subject: [PATCH] Fix uint64 return types on ec ops (#717) --- pyteal/ast/ec.py | 19 +++++++++++-------- pyteal/ast/ec_test.py | 17 +++++++++-------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/pyteal/ast/ec.py b/pyteal/ast/ec.py index b16252a3e..f78313ab3 100644 --- a/pyteal/ast/ec.py +++ b/pyteal/ast/ec.py @@ -27,7 +27,9 @@ def __init__(self, id: int, name: str, min_version: int) -> None: class EcOperation(Expr): - def __init__(self, op: Op, curve: EllipticCurve, args: list[Expr]) -> None: + def __init__( + self, op: Op, curve: EllipticCurve, args: list[Expr], return_type: TealType + ) -> None: super().__init__() self.op = op assert curve in EllipticCurve @@ -35,6 +37,7 @@ def __init__(self, op: Op, curve: EllipticCurve, args: list[Expr]) -> None: for arg in args: require_type(arg, TealType.bytes) self.args = args + self.return_type = return_type def __teal__(self, options: "CompileOptions"): verifyProgramVersion( @@ -52,7 +55,7 @@ def __str__(self): return f"(EcOperation {self.op} {self.curve} {self.args})" def type_of(self): - return TealType.bytes + return self.return_type def has_return(self): return False @@ -70,7 +73,7 @@ def EcAdd(curve: EllipticCurve, a: Expr, b: Expr) -> Expr: An expression which evaluates to the sum of the two points on the given curve. """ - return EcOperation(Op.ec_add, curve, [a, b]) + return EcOperation(Op.ec_add, curve, [a, b], TealType.bytes) def EcScalarMul(curve: EllipticCurve, point: Expr, scalar: Expr) -> Expr: @@ -86,7 +89,7 @@ def EcScalarMul(curve: EllipticCurve, point: Expr, scalar: Expr) -> Expr: An expression which evaluates to the product of the point and scalar on the given curve. """ - return EcOperation(Op.ec_scalar_mul, curve, [point, scalar]) + return EcOperation(Op.ec_scalar_mul, curve, [point, scalar], TealType.bytes) def EcPairingCheck(curve: EllipticCurve, a: Expr, b: Expr) -> Expr: @@ -102,7 +105,7 @@ def EcPairingCheck(curve: EllipticCurve, a: Expr, b: Expr) -> Expr: point in `a` with its respective point in `b` is equal to the identity element of the target group. Otherwise, evaluates to 0. """ - return EcOperation(Op.ec_pairing_check, curve, [a, b]) + return EcOperation(Op.ec_pairing_check, curve, [a, b], TealType.uint64) def EcMultiScalarMul(curve: EllipticCurve, a: Expr, b: Expr) -> Expr: @@ -117,7 +120,7 @@ def EcMultiScalarMul(curve: EllipticCurve, a: Expr, b: Expr) -> Expr: Returns: An expression that evaluates to curve point :code:`b_0a_0 + b_1a_1 + b_2a_2 + ... + b_Na_N`. """ - return EcOperation(Op.ec_multi_scalar_mul, curve, [a, b]) + return EcOperation(Op.ec_multi_scalar_mul, curve, [a, b], TealType.bytes) def EcSubgroupCheck(curve: EllipticCurve, a: Expr) -> Expr: @@ -132,7 +135,7 @@ def EcSubgroupCheck(curve: EllipticCurve, a: Expr) -> Expr: subgroup of the curve (including the point at infinity) else 0. Program fails if the point is not in the curve at all. """ - return EcOperation(Op.ec_subgroup_check, curve, [a]) + return EcOperation(Op.ec_subgroup_check, curve, [a], TealType.uint64) def EcMapTo(curve: EllipticCurve, a: Expr) -> Expr: @@ -145,4 +148,4 @@ def EcMapTo(curve: EllipticCurve, a: Expr) -> Expr: Returns: An expression that evaluates to the mapped point. """ - return EcOperation(Op.ec_map_to, curve, [a]) + return EcOperation(Op.ec_map_to, curve, [a], TealType.bytes) diff --git a/pyteal/ast/ec_test.py b/pyteal/ast/ec_test.py index 499dcf4b3..925503659 100644 --- a/pyteal/ast/ec_test.py +++ b/pyteal/ast/ec_test.py @@ -10,23 +10,24 @@ | Callable[[pt.EllipticCurve, pt.Expr, pt.Expr], pt.Expr], pt.Op, int, + pt.TealType, ] ] = [ - (pt.EcAdd, pt.Op.ec_add, 2), - (pt.EcScalarMul, pt.Op.ec_scalar_mul, 2), - (pt.EcPairingCheck, pt.Op.ec_pairing_check, 2), - (pt.EcMultiScalarMul, pt.Op.ec_multi_scalar_mul, 2), - (pt.EcSubgroupCheck, pt.Op.ec_subgroup_check, 1), - (pt.EcMapTo, pt.Op.ec_map_to, 1), + (pt.EcAdd, pt.Op.ec_add, 2, pt.TealType.bytes), + (pt.EcScalarMul, pt.Op.ec_scalar_mul, 2, pt.TealType.bytes), + (pt.EcPairingCheck, pt.Op.ec_pairing_check, 2, pt.TealType.uint64), + (pt.EcMultiScalarMul, pt.Op.ec_multi_scalar_mul, 2, pt.TealType.bytes), + (pt.EcSubgroupCheck, pt.Op.ec_subgroup_check, 1, pt.TealType.uint64), + (pt.EcMapTo, pt.Op.ec_map_to, 1, pt.TealType.bytes), ] def test_EcOperation(): - for operation, expected_op, num_args in OPERATIONS: + for operation, expected_op, num_args, expected_return_type in OPERATIONS: for curve in pt.EllipticCurve: args = [pt.Bytes(f"arg{i}") for i in range(num_args)] expr = operation(curve, *args) - assert expr.type_of() == pt.TealType.bytes + assert expr.type_of() == expected_return_type expected = pt.TealSimpleBlock( [pt.TealOp(arg, pt.Op.byte, f'"arg{i}"') for i, arg in enumerate(args)]