Skip to content

Commit

Permalink
Fix uint64 return types on ec ops (#717)
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonpaulos authored Jan 4, 2024
1 parent 7f30e21 commit 6860294
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
19 changes: 11 additions & 8 deletions pyteal/ast/ec.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,17 @@ def __init__(self, id: int, name: str, min_version: int) -> None:


class EcOperation(Expr):
def __init__(self, op: Op, curve: EllipticCurve, args: list[Expr]) -> None:
def __init__(
self, op: Op, curve: EllipticCurve, args: list[Expr], return_type: TealType
) -> None:
super().__init__()
self.op = op
assert curve in EllipticCurve
self.curve = curve
for arg in args:
require_type(arg, TealType.bytes)
self.args = args
self.return_type = return_type

def __teal__(self, options: "CompileOptions"):
verifyProgramVersion(
Expand All @@ -52,7 +55,7 @@ def __str__(self):
return f"(EcOperation {self.op} {self.curve} {self.args})"

def type_of(self):
return TealType.bytes
return self.return_type

def has_return(self):
return False
Expand All @@ -70,7 +73,7 @@ def EcAdd(curve: EllipticCurve, a: Expr, b: Expr) -> Expr:
An expression which evaluates to the sum of the two points on the given
curve.
"""
return EcOperation(Op.ec_add, curve, [a, b])
return EcOperation(Op.ec_add, curve, [a, b], TealType.bytes)


def EcScalarMul(curve: EllipticCurve, point: Expr, scalar: Expr) -> Expr:
Expand All @@ -86,7 +89,7 @@ def EcScalarMul(curve: EllipticCurve, point: Expr, scalar: Expr) -> Expr:
An expression which evaluates to the product of the point and scalar on
the given curve.
"""
return EcOperation(Op.ec_scalar_mul, curve, [point, scalar])
return EcOperation(Op.ec_scalar_mul, curve, [point, scalar], TealType.bytes)


def EcPairingCheck(curve: EllipticCurve, a: Expr, b: Expr) -> Expr:
Expand All @@ -102,7 +105,7 @@ def EcPairingCheck(curve: EllipticCurve, a: Expr, b: Expr) -> Expr:
point in `a` with its respective point in `b` is equal to the identity
element of the target group. Otherwise, evaluates to 0.
"""
return EcOperation(Op.ec_pairing_check, curve, [a, b])
return EcOperation(Op.ec_pairing_check, curve, [a, b], TealType.uint64)


def EcMultiScalarMul(curve: EllipticCurve, a: Expr, b: Expr) -> Expr:
Expand All @@ -117,7 +120,7 @@ def EcMultiScalarMul(curve: EllipticCurve, a: Expr, b: Expr) -> Expr:
Returns:
An expression that evaluates to curve point :code:`b_0a_0 + b_1a_1 + b_2a_2 + ... + b_Na_N`.
"""
return EcOperation(Op.ec_multi_scalar_mul, curve, [a, b])
return EcOperation(Op.ec_multi_scalar_mul, curve, [a, b], TealType.bytes)


def EcSubgroupCheck(curve: EllipticCurve, a: Expr) -> Expr:
Expand All @@ -132,7 +135,7 @@ def EcSubgroupCheck(curve: EllipticCurve, a: Expr) -> Expr:
subgroup of the curve (including the point at infinity) else 0. Program
fails if the point is not in the curve at all.
"""
return EcOperation(Op.ec_subgroup_check, curve, [a])
return EcOperation(Op.ec_subgroup_check, curve, [a], TealType.uint64)


def EcMapTo(curve: EllipticCurve, a: Expr) -> Expr:
Expand All @@ -145,4 +148,4 @@ def EcMapTo(curve: EllipticCurve, a: Expr) -> Expr:
Returns:
An expression that evaluates to the mapped point.
"""
return EcOperation(Op.ec_map_to, curve, [a])
return EcOperation(Op.ec_map_to, curve, [a], TealType.bytes)
17 changes: 9 additions & 8 deletions pyteal/ast/ec_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,24 @@
| Callable[[pt.EllipticCurve, pt.Expr, pt.Expr], pt.Expr],
pt.Op,
int,
pt.TealType,
]
] = [
(pt.EcAdd, pt.Op.ec_add, 2),
(pt.EcScalarMul, pt.Op.ec_scalar_mul, 2),
(pt.EcPairingCheck, pt.Op.ec_pairing_check, 2),
(pt.EcMultiScalarMul, pt.Op.ec_multi_scalar_mul, 2),
(pt.EcSubgroupCheck, pt.Op.ec_subgroup_check, 1),
(pt.EcMapTo, pt.Op.ec_map_to, 1),
(pt.EcAdd, pt.Op.ec_add, 2, pt.TealType.bytes),
(pt.EcScalarMul, pt.Op.ec_scalar_mul, 2, pt.TealType.bytes),
(pt.EcPairingCheck, pt.Op.ec_pairing_check, 2, pt.TealType.uint64),
(pt.EcMultiScalarMul, pt.Op.ec_multi_scalar_mul, 2, pt.TealType.bytes),
(pt.EcSubgroupCheck, pt.Op.ec_subgroup_check, 1, pt.TealType.uint64),
(pt.EcMapTo, pt.Op.ec_map_to, 1, pt.TealType.bytes),
]


def test_EcOperation():
for operation, expected_op, num_args in OPERATIONS:
for operation, expected_op, num_args, expected_return_type in OPERATIONS:
for curve in pt.EllipticCurve:
args = [pt.Bytes(f"arg{i}") for i in range(num_args)]
expr = operation(curve, *args)
assert expr.type_of() == pt.TealType.bytes
assert expr.type_of() == expected_return_type

expected = pt.TealSimpleBlock(
[pt.TealOp(arg, pt.Op.byte, f'"arg{i}"') for i, arg in enumerate(args)]
Expand Down

0 comments on commit 6860294

Please sign in to comment.