diff --git a/pyteal/ast/abi/array_base.py b/pyteal/ast/abi/array_base.py index 7fdf51dc6..e88558c58 100644 --- a/pyteal/ast/abi/array_base.py +++ b/pyteal/ast/abi/array_base.py @@ -20,9 +20,12 @@ from pyteal.ast.abi.type import TypeSpec, BaseType, ComputedValue from pyteal.ast.abi.tuple import _encode_tuple -from pyteal.ast.abi.bool import Bool, BoolTypeSpec +from pyteal.ast.abi.bool import BoolTypeSpec from pyteal.ast.abi.uint import Uint16, Uint16TypeSpec -from pyteal.ast.abi.util import substring_for_decoding +from pyteal.ast.abi.util import ( + substring_for_decoding, + _GetAgainstEncoding, +) T = TypeVar("T", bound=BaseType) @@ -207,20 +210,8 @@ def __init__(self, array: Array[T], index: Expr) -> None: def produced_type_spec(self) -> TypeSpec: return self.array.type_spec().value_type_spec() - def store_into(self, output: T) -> Expr: - """Partitions the byte string of the given ABI array and stores the byte string of array - element in the ABI value output. - - The function first checks if the output type matches with array element type, and throw - error if type-mismatch. - - Args: - output: An ABI typed value that the array element byte string stores into. - - Returns: - An expression that stores the byte string of the array element into value `output`. - """ - if output.type_spec() != self.produced_type_spec(): + def __prototype_encoding_store_into(self, output: T | None = None) -> Expr: + if output is not None and output.type_spec() != self.produced_type_spec(): raise TealInputError("Output type does not match value type") encodedArray = self.array.encode() @@ -229,11 +220,13 @@ def store_into(self, output: T) -> Expr: # If the array element type is Bool, we compute the bit index # (if array is dynamic we add 16 to bit index for dynamic array length uint16 prefix) # and decode bit with given array encoding and the bit index for boolean bit. - if output.type_spec() == BoolTypeSpec(): + if self.array.type_spec().value_type_spec() == BoolTypeSpec(): bitIndex = self.index if arrayType.is_dynamic(): bitIndex = bitIndex + Int(Uint16TypeSpec().bit_size()) - return cast(Bool, output).decode_bit(encodedArray, bitIndex) + return _GetAgainstEncoding( + encodedArray, BoolTypeSpec(), start_index=bitIndex + ).get_or_store(output) # Compute the byteIndex (first byte indicating the element encoding) # (If the array is dynamic, add 2 to byte index for dynamic array length uint16 prefix) @@ -271,16 +264,42 @@ def store_into(self, output: T) -> Expr: .Else(nextValueStart) ) - return output.decode( - encodedArray, start_index=valueStart, end_index=valueEnd - ) + return _GetAgainstEncoding( + encodedArray, + arrayType.value_type_spec(), + start_index=valueStart, + end_index=valueEnd, + ).get_or_store(output) # Handling case for array elements are static: # since array._stride() is element's static byte length # we partition the substring for array element. valueStart = byteIndex valueLength = Int(arrayType._stride()) - return output.decode(encodedArray, start_index=valueStart, length=valueLength) + return _GetAgainstEncoding( + encodedArray, + arrayType.value_type_spec(), + start_index=valueStart, + length=valueLength, + ).get_or_store(output) + + def store_into(self, output: T) -> Expr: + """Partitions the byte string of the given ABI array and stores the byte string of array + element in the ABI value output. + + The function first checks if the output type matches with array element type, and throw + error if type-mismatch. + + Args: + output: An ABI typed value that the array element byte string stores into. + + Returns: + An expression that stores the byte string of the array element into value `output`. + """ + return self.__prototype_encoding_store_into(output) + + def encode(self) -> Expr: + return self.__prototype_encoding_store_into() ArrayElement.__module__ = "pyteal.abi" diff --git a/pyteal/ast/abi/array_base_test.py b/pyteal/ast/abi/array_base_test.py index 74dfa74ec..e48703648 100644 --- a/pyteal/ast/abi/array_base_test.py +++ b/pyteal/ast/abi/array_base_test.py @@ -161,3 +161,105 @@ def test_ArrayElement_store_into(): with pytest.raises(pt.TealInputError): element.store_into(abi.Tuple(abi.TupleTypeSpec(elementType))) + + +def test_ArrayElement_encoding(): + from pyteal.ast.abi.util import substring_for_decoding + + for elementType in STATIC_TYPES + DYNAMIC_TYPES: + staticArrayType = abi.StaticArrayTypeSpec(elementType, 100) + staticArray = staticArrayType.new_instance() + index = pt.Int(9) + + element = abi.ArrayElement(staticArray, index) + expr = element.encode() + + encoded = staticArray.encode() + stride = pt.Int(staticArray.type_spec()._stride()) + expectedLength = staticArray.length() + if elementType == abi.BoolTypeSpec(): + expectedExpr = pt.SetBit( + pt.Bytes(b"\x00"), + pt.Int(0), + pt.GetBit(encoded, index), + ) + elif not elementType.is_dynamic(): + expectedExpr = substring_for_decoding( + encoded, start_index=stride * index, length=stride + ) + else: + expectedExpr = substring_for_decoding( + encoded, + start_index=pt.ExtractUint16(encoded, stride * index), + end_index=pt.If(index + pt.Int(1) == expectedLength) + .Then(pt.Len(encoded)) + .Else(pt.ExtractUint16(encoded, stride * index + pt.Int(2))), + ) + + expected, _ = expectedExpr.__teal__(options) + expected.addIncoming() + expected = pt.TealBlock.NormalizeBlocks(expected) + + actual, _ = expr.__teal__(options) + actual.addIncoming() + actual = pt.TealBlock.NormalizeBlocks(actual) + + with pt.TealComponent.Context.ignoreExprEquality(): + assert actual == expected + + with pytest.raises(pt.TealInputError): + element.store_into(abi.Tuple(abi.TupleTypeSpec(elementType))) + + for elementType in STATIC_TYPES + DYNAMIC_TYPES: + dynamicArrayType = abi.DynamicArrayTypeSpec(elementType) + dynamicArray = dynamicArrayType.new_instance() + index = pt.Int(9) + + element = abi.ArrayElement(dynamicArray, index) + expr = element.encode() + + encoded = dynamicArray.encode() + stride = pt.Int(dynamicArray.type_spec()._stride()) + expectedLength = dynamicArray.length() + if elementType == abi.BoolTypeSpec(): + expectedExpr = pt.SetBit( + pt.Bytes(b"\x00"), + pt.Int(0), + pt.GetBit(encoded, index + pt.Int(16)), + ) + elif not elementType.is_dynamic(): + expectedExpr = substring_for_decoding( + encoded, start_index=stride * index + pt.Int(2), length=stride + ) + else: + expectedExpr = substring_for_decoding( + encoded, + start_index=pt.ExtractUint16(encoded, stride * index + pt.Int(2)) + + pt.Int(2), + end_index=pt.If(index + pt.Int(1) == expectedLength) + .Then(pt.Len(encoded)) + .Else( + pt.ExtractUint16(encoded, stride * index + pt.Int(2) + pt.Int(2)) + + pt.Int(2) + ), + ) + + expected, _ = expectedExpr.__teal__(options) + expected.addIncoming() + expected = pt.TealBlock.NormalizeBlocks(expected) + + actual, _ = expr.__teal__(options) + actual.addIncoming() + actual = pt.TealBlock.NormalizeBlocks(actual) + + with pt.TealComponent.Context.ignoreExprEquality(): + with pt.TealComponent.Context.ignoreScratchSlotEquality(): + assert actual == expected + + assert pt.TealBlock.MatchScratchSlotReferences( + pt.TealBlock.GetReferencedScratchSlots(actual), + pt.TealBlock.GetReferencedScratchSlots(expected), + ) + + with pytest.raises(pt.TealInputError): + element.store_into(abi.Tuple(abi.TupleTypeSpec(elementType))) diff --git a/pyteal/ast/abi/tuple.py b/pyteal/ast/abi/tuple.py index ff0f9f190..fa0f78683 100644 --- a/pyteal/ast/abi/tuple.py +++ b/pyteal/ast/abi/tuple.py @@ -1,4 +1,5 @@ from inspect import get_annotations +from dataclasses import dataclass from typing import ( List, Sequence, @@ -14,15 +15,14 @@ ) from collections import OrderedDict -from pyteal.types import TealType +from pyteal.types import TealType, require_type from pyteal.errors import TealInputError, TealInternalError from pyteal.ast.expr import Expr from pyteal.ast.seq import Seq from pyteal.ast.int import Int from pyteal.ast.bytes import Bytes from pyteal.ast.unaryexpr import Len -from pyteal.ast.binaryexpr import ExtractUint16, GetBit -from pyteal.ast.ternaryexpr import SetBit +from pyteal.ast.binaryexpr import ExtractUint16 from pyteal.ast.naryexpr import Concat from pyteal.ast.abstractvar import alloc_abstract_var @@ -37,7 +37,11 @@ _bool_aware_static_byte_length, ) from pyteal.ast.abi.uint import NUM_BITS_IN_BYTE, Uint16 -from pyteal.ast.abi.util import substring_for_decoding, type_spec_from_annotation +from pyteal.ast.abi.util import ( + substring_for_decoding, + type_spec_from_annotation, + _GetAgainstEncoding, +) def _encode_tuple(values: Sequence[BaseType]) -> Expr: @@ -118,190 +122,128 @@ def _encode_tuple(values: Sequence[BaseType]) -> Expr: return Concat(*toConcat) -def _index_tuple_bytes( - value_types: Sequence[TypeSpec], encoded: Expr, index: int -) -> Expr: - if not (0 <= index < len(value_types)): - raise ValueError("Index outside of range") +@dataclass +class _IndexTuple: + value_types: Sequence[TypeSpec] + encoded: Expr - offset = 0 - ignoreNext = 0 - lastBoolStart = 0 - lastBoolLength = 0 - for i, typeBefore in enumerate(value_types[:index]): - if ignoreNext > 0: - ignoreNext -= 1 - continue + def __post_init__(self): + require_type(self.encoded, TealType.bytes) - if typeBefore == BoolTypeSpec(): - lastBoolStart = offset - lastBoolLength = _consecutive_bool_type_spec_num(value_types, i) - offset += _bool_sequence_length(lastBoolLength) - ignoreNext = lastBoolLength - 1 - continue + def get_or_store(self, index: int, output: BaseType | None = None) -> Expr: + if index not in range(len(self.value_types)): + raise ValueError("Index outside of range") - if typeBefore.is_dynamic(): - offset += 2 - continue - - offset += typeBefore.byte_length_static() - - valueType = value_types[index] - - if type(valueType) is Bool: - if ignoreNext > 0: - # value is in the middle of a bool sequence - bitOffsetInBoolSeq = lastBoolLength - ignoreNext - bitOffsetInEncoded = lastBoolStart * NUM_BITS_IN_BYTE + bitOffsetInBoolSeq - else: - # value is the beginning of a bool sequence (or a single bool) - bitOffsetInEncoded = offset * NUM_BITS_IN_BYTE - return SetBit(Bytes(b"\x00"), Int(0), GetBit(encoded, Int(bitOffsetInEncoded))) - - if valueType.is_dynamic(): - hasNextDynamicValue = False - nextDynamicValueOffset = offset + 2 + offset = 0 ignoreNext = 0 - for i, typeAfter in enumerate(value_types[index + 1 :], start=index + 1): + lastBoolStart = 0 + lastBoolLength = 0 + for i, typeBefore in enumerate(self.value_types[:index]): if ignoreNext > 0: ignoreNext -= 1 continue - if type(typeAfter) is BoolTypeSpec: - boolLength = _consecutive_bool_type_spec_num(value_types, i) - nextDynamicValueOffset += _bool_sequence_length(boolLength) - ignoreNext = boolLength - 1 + if typeBefore == BoolTypeSpec(): + lastBoolStart = offset + lastBoolLength = _consecutive_bool_type_spec_num(self.value_types, i) + offset += _bool_sequence_length(lastBoolLength) + ignoreNext = lastBoolLength - 1 continue - if typeAfter.is_dynamic(): - hasNextDynamicValue = True - break - - nextDynamicValueOffset += typeAfter.byte_length_static() - - start_index = ExtractUint16(encoded, Int(offset)) - if not hasNextDynamicValue: - # This is the final dynamic value, so decode the substring from start_index to the end of - # encoded - return substring_for_decoding(encoded, start_index=start_index) - - # There is a dynamic value after this one, and end_index is where its tail starts, so decode - # the substring from start_index to end_index - end_index = ExtractUint16(encoded, Int(nextDynamicValueOffset)) - return substring_for_decoding( - encoded, start_index=start_index, end_index=end_index - ) - - start_index = Int(offset) - length = Int(valueType.byte_length_static()) - - if index + 1 == len(value_types): - if offset == 0: - # This is the first and only value in the tuple, so decode all of encoded - return encoded - # This is the last value in the tuple, so decode the substring from start_index to the end of - # encoded - return substring_for_decoding(encoded, start_index=start_index) - - if offset == 0: - # This is the first value in the tuple, so decode the substring from 0 with length length - return substring_for_decoding(encoded, length=length) - - # This is not the first or last value, so decode the substring from start_index with length length - return substring_for_decoding(encoded, start_index=start_index, length=length) - - -def _index_tuple( - value_types: Sequence[TypeSpec], encoded: Expr, index: int, output: BaseType -) -> Expr: - if not (0 <= index < len(value_types)): - raise ValueError("Index outside of range") - - offset = 0 - ignoreNext = 0 - lastBoolStart = 0 - lastBoolLength = 0 - for i, typeBefore in enumerate(value_types[:index]): - if ignoreNext > 0: - ignoreNext -= 1 - continue - - if typeBefore == BoolTypeSpec(): - lastBoolStart = offset - lastBoolLength = _consecutive_bool_type_spec_num(value_types, i) - offset += _bool_sequence_length(lastBoolLength) - ignoreNext = lastBoolLength - 1 - continue - - if typeBefore.is_dynamic(): - offset += 2 - continue - - offset += typeBefore.byte_length_static() - - valueType = value_types[index] - if output.type_spec() != valueType: - raise TypeError("Output type does not match value type") - - if type(output) is Bool: - if ignoreNext > 0: - # value is in the middle of a bool sequence - bitOffsetInBoolSeq = lastBoolLength - ignoreNext - bitOffsetInEncoded = lastBoolStart * NUM_BITS_IN_BYTE + bitOffsetInBoolSeq - else: - # value is the beginning of a bool sequence (or a single bool) - bitOffsetInEncoded = offset * NUM_BITS_IN_BYTE - return output.decode_bit(encoded, Int(bitOffsetInEncoded)) - - if valueType.is_dynamic(): - hasNextDynamicValue = False - nextDynamicValueOffset = offset + 2 - ignoreNext = 0 - for i, typeAfter in enumerate(value_types[index + 1 :], start=index + 1): - if ignoreNext > 0: - ignoreNext -= 1 - continue - - if type(typeAfter) is BoolTypeSpec: - boolLength = _consecutive_bool_type_spec_num(value_types, i) - nextDynamicValueOffset += _bool_sequence_length(boolLength) - ignoreNext = boolLength - 1 + if typeBefore.is_dynamic(): + offset += 2 continue - if typeAfter.is_dynamic(): - hasNextDynamicValue = True - break + offset += typeBefore.byte_length_static() - nextDynamicValueOffset += typeAfter.byte_length_static() + valueType = self.value_types[index] + if output is not None and output.type_spec() != valueType: + raise TypeError("Output type does not match value type") - start_index = ExtractUint16(encoded, Int(offset)) - if not hasNextDynamicValue: - # This is the final dynamic value, so decode the substring from start_index to the end of + if type(valueType) is BoolTypeSpec: + if ignoreNext > 0: + # value is in the middle of a bool sequence + bitOffsetInBoolSeq = lastBoolLength - ignoreNext + bitOffsetInEncoded = ( + lastBoolStart * NUM_BITS_IN_BYTE + bitOffsetInBoolSeq + ) + else: + # value is the beginning of a bool sequence (or a single bool) + bitOffsetInEncoded = offset * NUM_BITS_IN_BYTE + + return _GetAgainstEncoding( + self.encoded, + BoolTypeSpec(), + start_index=Int(bitOffsetInEncoded), + ).get_or_store(output) + + if valueType.is_dynamic(): + hasNextDynamicValue = False + nextDynamicValueOffset = offset + 2 + ignoreNext = 0 + for i, typeAfter in enumerate( + self.value_types[index + 1 :], start=index + 1 + ): + if ignoreNext > 0: + ignoreNext -= 1 + continue + + if type(typeAfter) is BoolTypeSpec: + boolLength = _consecutive_bool_type_spec_num(self.value_types, i) + nextDynamicValueOffset += _bool_sequence_length(boolLength) + ignoreNext = boolLength - 1 + continue + + if typeAfter.is_dynamic(): + hasNextDynamicValue = True + break + + nextDynamicValueOffset += typeAfter.byte_length_static() + + start_index = ExtractUint16(self.encoded, Int(offset)) + if not hasNextDynamicValue: + # This is the final dynamic value, so decode the substring from start_index to the end of + # encoded + return _GetAgainstEncoding( + self.encoded, valueType, start_index=start_index + ).get_or_store(output) + + # There is a dynamic value after this one, and end_index is where its tail starts, so decode + # the substring from start_index to end_index + end_index = ExtractUint16(self.encoded, Int(nextDynamicValueOffset)) + return _GetAgainstEncoding( + self.encoded, + valueType, + start_index=start_index, + end_index=end_index, + ).get_or_store(output) + + start_index = Int(offset) + length = Int(valueType.byte_length_static()) + + if index + 1 == len(self.value_types): + if offset == 0: + # This is the first and only value in the tuple, so decode all of encoded + if output is None: + return self.encoded + else: + return output.decode(self.encoded) + # This is the last value in the tuple, so decode the substring from start_index to the end of # encoded - return output.decode(encoded, start_index=start_index) - - # There is a dynamic value after this one, and end_index is where its tail starts, so decode - # the substring from start_index to end_index - end_index = ExtractUint16(encoded, Int(nextDynamicValueOffset)) - return output.decode(encoded, start_index=start_index, end_index=end_index) - - start_index = Int(offset) - length = Int(valueType.byte_length_static()) + return _GetAgainstEncoding( + self.encoded, valueType, start_index=start_index + ).get_or_store(output) - if index + 1 == len(value_types): if offset == 0: - # This is the first and only value in the tuple, so decode all of encoded - return output.decode(encoded) - # This is the last value in the tuple, so decode the substring from start_index to the end of - # encoded - return output.decode(encoded, start_index=start_index) + # This is the first value in the tuple, so decode the substring from 0 with length length + return _GetAgainstEncoding( + self.encoded, valueType, length=length + ).get_or_store(output) - if offset == 0: - # This is the first value in the tuple, so decode the substring from 0 with length length - return output.decode(encoded, length=length) - - # This is not the first or last value, so decode the substring from start_index with length length - return output.decode(encoded, start_index=start_index, length=length) + # This is not the first or last value, so decode the substring from start_index with length length + return _GetAgainstEncoding( + self.encoded, valueType, start_index=start_index, length=length + ).get_or_store(output) class TupleTypeSpec(TypeSpec): @@ -492,19 +434,17 @@ def produced_type_spec(self) -> TypeSpec: return self.tuple.type_spec().value_type_specs()[self.index] def store_into(self, output: T) -> Expr: - return _index_tuple( - self.tuple.type_spec().value_type_specs(), - self.tuple.encode(), + return _IndexTuple( + self.tuple.type_spec().value_type_specs(), self.tuple.encode() + ).get_or_store( self.index, output, ) def encode(self) -> Expr: - return _index_tuple_bytes( - self.tuple.type_spec().value_type_specs(), - self.tuple.encode(), - self.index, - ) + return _IndexTuple( + self.tuple.type_spec().value_type_specs(), self.tuple.encode() + ).get_or_store(self.index) TupleElement.__module__ = "pyteal.abi" diff --git a/pyteal/ast/abi/tuple_test.py b/pyteal/ast/abi/tuple_test.py index b175ce7bd..09c7c4dde 100644 --- a/pyteal/ast/abi/tuple_test.py +++ b/pyteal/ast/abi/tuple_test.py @@ -1,9 +1,9 @@ -from typing import NamedTuple, List, Callable, Literal, cast +from typing import NamedTuple, Callable, Literal, cast import pytest import pyteal as pt from pyteal import abi -from pyteal.ast.abi.tuple import _encode_tuple, _index_tuple, TupleElement +from pyteal.ast.abi.tuple import _encode_tuple, _IndexTuple, TupleElement from pyteal.ast.abi.bool import _encode_bool_sequence from pyteal.ast.abi.util import substring_for_decoding from pyteal.ast.abi.type_test import ContainerType @@ -13,7 +13,7 @@ def test_encodeTuple(): class EncodeTest(NamedTuple): - types: List[abi.BaseType] + types: list[abi.BaseType] expected: pt.Expr # variables used to construct the tests @@ -30,7 +30,7 @@ class EncodeTest(NamedTuple): tail_holder = pt.ScratchVar() encoded_tail = pt.ScratchVar() - tests: List[EncodeTest] = [ + tests: list[EncodeTest] = [ EncodeTest(types=[], expected=pt.Bytes("")), EncodeTest(types=[uint64_a], expected=uint64_a.encode()), EncodeTest( @@ -225,9 +225,10 @@ class EncodeTest(NamedTuple): def test_indexTuple(): class IndexTest(NamedTuple): - types: List[abi.TypeSpec] + types: list[abi.TypeSpec] typeIndex: int - expected: Callable[[abi.BaseType], pt.Expr] + expected_store: Callable[[abi.BaseType], pt.Expr] + expected_encode: pt.Expr # variables used to construct the tests uint64_t = abi.Uint64TypeSpec() @@ -239,103 +240,166 @@ class IndexTest(NamedTuple): encoded = pt.Bytes("encoded") - tests: List[IndexTest] = [ + tests: list[IndexTest] = [ IndexTest( types=[uint64_t], typeIndex=0, - expected=lambda output: output.decode(encoded), + expected_store=lambda output: output.decode(encoded), + expected_encode=encoded, ), IndexTest( types=[uint64_t, uint64_t], typeIndex=0, - expected=lambda output: output.decode(encoded, length=pt.Int(8)), + expected_store=lambda output: output.decode(encoded, length=pt.Int(8)), + expected_encode=substring_for_decoding(encoded, length=pt.Int(8)), ), IndexTest( types=[uint64_t, uint64_t], typeIndex=1, - expected=lambda output: output.decode(encoded, start_index=pt.Int(8)), + expected_store=lambda output: output.decode(encoded, start_index=pt.Int(8)), + expected_encode=substring_for_decoding(encoded, start_index=pt.Int(8)), ), IndexTest( types=[uint64_t, byte_t, uint64_t], typeIndex=1, - expected=lambda output: output.decode( + expected_store=lambda output: output.decode( + encoded, start_index=pt.Int(8), length=pt.Int(1) + ), + expected_encode=substring_for_decoding( encoded, start_index=pt.Int(8), length=pt.Int(1) ), ), IndexTest( types=[uint64_t, byte_t, uint64_t], typeIndex=2, - expected=lambda output: output.decode( + expected_store=lambda output: output.decode( encoded, start_index=pt.Int(9), length=pt.Int(8) ), + expected_encode=substring_for_decoding(encoded, start_index=pt.Int(9)), ), IndexTest( types=[bool_t], typeIndex=0, - expected=lambda output: output.decode_bit(encoded, pt.Int(0)), + expected_store=lambda output: output.decode_bit(encoded, pt.Int(0)), + expected_encode=pt.SetBit( + pt.Bytes(b"\x00"), + pt.Int(0), + pt.GetBit(encoded, pt.Int(0)), + ), ), IndexTest( types=[bool_t, bool_t], typeIndex=0, - expected=lambda output: output.decode_bit(encoded, pt.Int(0)), + expected_store=lambda output: output.decode_bit(encoded, pt.Int(0)), + expected_encode=pt.SetBit( + pt.Bytes(b"\x00"), + pt.Int(0), + pt.GetBit(encoded, pt.Int(0)), + ), ), IndexTest( types=[bool_t, bool_t], typeIndex=1, - expected=lambda output: output.decode_bit(encoded, pt.Int(1)), + expected_store=lambda output: output.decode_bit(encoded, pt.Int(1)), + expected_encode=pt.SetBit( + pt.Bytes(b"\x00"), + pt.Int(0), + pt.GetBit(encoded, pt.Int(1)), + ), ), IndexTest( types=[uint64_t, bool_t], typeIndex=1, - expected=lambda output: output.decode_bit(encoded, pt.Int(8 * 8)), + expected_store=lambda output: output.decode_bit(encoded, pt.Int(8 * 8)), + expected_encode=pt.SetBit( + pt.Bytes(b"\x00"), + pt.Int(0), + pt.GetBit(encoded, pt.Int(64)), + ), ), IndexTest( types=[uint64_t, bool_t, bool_t], typeIndex=1, - expected=lambda output: output.decode_bit(encoded, pt.Int(8 * 8)), + expected_store=lambda output: output.decode_bit(encoded, pt.Int(8 * 8)), + expected_encode=pt.SetBit( + pt.Bytes(b"\x00"), + pt.Int(0), + pt.GetBit(encoded, pt.Int(64)), + ), ), IndexTest( types=[uint64_t, bool_t, bool_t], typeIndex=2, - expected=lambda output: output.decode_bit(encoded, pt.Int(8 * 8 + 1)), + expected_store=lambda output: output.decode_bit(encoded, pt.Int(8 * 8 + 1)), + expected_encode=pt.SetBit( + pt.Bytes(b"\x00"), + pt.Int(0), + pt.GetBit(encoded, pt.Int(65)), + ), ), IndexTest( types=[bool_t, uint64_t], typeIndex=0, - expected=lambda output: output.decode_bit(encoded, pt.Int(0)), + expected_store=lambda output: output.decode_bit(encoded, pt.Int(0)), + expected_encode=pt.SetBit( + pt.Bytes(b"\x00"), + pt.Int(0), + pt.GetBit(encoded, pt.Int(0)), + ), ), IndexTest( types=[bool_t, uint64_t], typeIndex=1, - expected=lambda output: output.decode(encoded, start_index=pt.Int(1)), + expected_store=lambda output: output.decode(encoded, start_index=pt.Int(1)), + expected_encode=substring_for_decoding(encoded, start_index=pt.Int(1)), ), IndexTest( types=[bool_t, bool_t, uint64_t], typeIndex=0, - expected=lambda output: output.decode_bit(encoded, pt.Int(0)), + expected_store=lambda output: output.decode_bit(encoded, pt.Int(0)), + expected_encode=pt.SetBit( + pt.Bytes(b"\x00"), + pt.Int(0), + pt.GetBit(encoded, pt.Int(0)), + ), ), IndexTest( types=[bool_t, bool_t, uint64_t], typeIndex=1, - expected=lambda output: output.decode_bit(encoded, pt.Int(1)), + expected_store=lambda output: output.decode_bit(encoded, pt.Int(1)), + expected_encode=pt.SetBit( + pt.Bytes(b"\x00"), + pt.Int(0), + pt.GetBit(encoded, pt.Int(1)), + ), ), IndexTest( types=[bool_t, bool_t, uint64_t], typeIndex=2, - expected=lambda output: output.decode(encoded, start_index=pt.Int(1)), + expected_store=lambda output: output.decode(encoded, start_index=pt.Int(1)), + expected_encode=substring_for_decoding(encoded, start_index=pt.Int(1)), ), IndexTest( - types=[tuple_t], typeIndex=0, expected=lambda output: output.decode(encoded) + types=[tuple_t], + typeIndex=0, + expected_store=lambda output: output.decode(encoded), + expected_encode=encoded, ), IndexTest( types=[byte_t, tuple_t], typeIndex=1, - expected=lambda output: output.decode(encoded, start_index=pt.Int(1)), + expected_store=lambda output: output.decode(encoded, start_index=pt.Int(1)), + expected_encode=substring_for_decoding(encoded, start_index=pt.Int(1)), ), IndexTest( types=[tuple_t, byte_t], typeIndex=0, - expected=lambda output: output.decode( + expected_store=lambda output: output.decode( + encoded, + start_index=pt.Int(0), + length=pt.Int(tuple_t.byte_length_static()), + ), + expected_encode=substring_for_decoding( encoded, start_index=pt.Int(0), length=pt.Int(tuple_t.byte_length_static()), @@ -344,7 +408,12 @@ class IndexTest(NamedTuple): IndexTest( types=[byte_t, tuple_t, byte_t], typeIndex=1, - expected=lambda output: output.decode( + expected_store=lambda output: output.decode( + encoded, + start_index=pt.Int(1), + length=pt.Int(tuple_t.byte_length_static()), + ), + expected_encode=substring_for_decoding( encoded, start_index=pt.Int(1), length=pt.Int(tuple_t.byte_length_static()), @@ -353,35 +422,52 @@ class IndexTest(NamedTuple): IndexTest( types=[dynamic_array_t1], typeIndex=0, - expected=lambda output: output.decode( + expected_store=lambda output: output.decode( + encoded, start_index=pt.ExtractUint16(encoded, pt.Int(0)) + ), + expected_encode=substring_for_decoding( encoded, start_index=pt.ExtractUint16(encoded, pt.Int(0)) ), ), IndexTest( types=[byte_t, dynamic_array_t1], typeIndex=1, - expected=lambda output: output.decode( + expected_store=lambda output: output.decode( + encoded, start_index=pt.ExtractUint16(encoded, pt.Int(1)) + ), + expected_encode=substring_for_decoding( encoded, start_index=pt.ExtractUint16(encoded, pt.Int(1)) ), ), IndexTest( types=[dynamic_array_t1, byte_t], typeIndex=0, - expected=lambda output: output.decode( + expected_store=lambda output: output.decode( + encoded, start_index=pt.ExtractUint16(encoded, pt.Int(0)) + ), + expected_encode=substring_for_decoding( encoded, start_index=pt.ExtractUint16(encoded, pt.Int(0)) ), ), IndexTest( types=[byte_t, dynamic_array_t1, byte_t], typeIndex=1, - expected=lambda output: output.decode( + expected_store=lambda output: output.decode( + encoded, start_index=pt.ExtractUint16(encoded, pt.Int(1)) + ), + expected_encode=substring_for_decoding( encoded, start_index=pt.ExtractUint16(encoded, pt.Int(1)) ), ), IndexTest( types=[byte_t, dynamic_array_t1, byte_t, dynamic_array_t2], typeIndex=1, - expected=lambda output: output.decode( + expected_store=lambda output: output.decode( + encoded, + start_index=pt.ExtractUint16(encoded, pt.Int(1)), + end_index=pt.ExtractUint16(encoded, pt.Int(4)), + ), + expected_encode=substring_for_decoding( encoded, start_index=pt.ExtractUint16(encoded, pt.Int(1)), end_index=pt.ExtractUint16(encoded, pt.Int(4)), @@ -390,14 +476,22 @@ class IndexTest(NamedTuple): IndexTest( types=[byte_t, dynamic_array_t1, byte_t, dynamic_array_t2], typeIndex=3, - expected=lambda output: output.decode( + expected_store=lambda output: output.decode( + encoded, start_index=pt.ExtractUint16(encoded, pt.Int(4)) + ), + expected_encode=substring_for_decoding( encoded, start_index=pt.ExtractUint16(encoded, pt.Int(4)) ), ), IndexTest( types=[byte_t, dynamic_array_t1, tuple_t, dynamic_array_t2], typeIndex=1, - expected=lambda output: output.decode( + expected_store=lambda output: output.decode( + encoded, + start_index=pt.ExtractUint16(encoded, pt.Int(1)), + end_index=pt.ExtractUint16(encoded, pt.Int(4)), + ), + expected_encode=substring_for_decoding( encoded, start_index=pt.ExtractUint16(encoded, pt.Int(1)), end_index=pt.ExtractUint16(encoded, pt.Int(4)), @@ -406,14 +500,22 @@ class IndexTest(NamedTuple): IndexTest( types=[byte_t, dynamic_array_t1, tuple_t, dynamic_array_t2], typeIndex=3, - expected=lambda output: output.decode( + expected_store=lambda output: output.decode( + encoded, start_index=pt.ExtractUint16(encoded, pt.Int(4)) + ), + expected_encode=substring_for_decoding( encoded, start_index=pt.ExtractUint16(encoded, pt.Int(4)) ), ), IndexTest( types=[byte_t, dynamic_array_t2, bool_t, bool_t, dynamic_array_t2], typeIndex=1, - expected=lambda output: output.decode( + expected_store=lambda output: output.decode( + encoded, + start_index=pt.ExtractUint16(encoded, pt.Int(1)), + end_index=pt.ExtractUint16(encoded, pt.Int(4)), + ), + expected_encode=substring_for_decoding( encoded, start_index=pt.ExtractUint16(encoded, pt.Int(1)), end_index=pt.ExtractUint16(encoded, pt.Int(4)), @@ -422,41 +524,68 @@ class IndexTest(NamedTuple): IndexTest( types=[byte_t, dynamic_array_t1, bool_t, bool_t, dynamic_array_t2], typeIndex=4, - expected=lambda output: output.decode( + expected_store=lambda output: output.decode( encoded, start_index=pt.ExtractUint16(encoded, pt.Int(4)) ), + expected_encode=substring_for_decoding( + encoded, + start_index=pt.ExtractUint16(encoded, pt.Int(4)), + ), ), ] for i, test in enumerate(tests): output = test.types[test.typeIndex].new_instance() - expr = _index_tuple(test.types, encoded, test.typeIndex, output) - assert expr.type_of() == pt.TealType.none - assert not expr.has_return() + expr_store = _IndexTuple(test.types, encoded).get_or_store( + test.typeIndex, output + ) + assert expr_store.type_of() == pt.TealType.none + assert not expr_store.has_return() - expected, _ = test.expected(output).__teal__(options) - expected.addIncoming() - expected = pt.TealBlock.NormalizeBlocks(expected) + expected_store, _ = test.expected_store(output).__teal__(options) + expected_store.addIncoming() + expected_store = pt.TealBlock.NormalizeBlocks(expected_store) - actual, _ = expr.__teal__(options) - actual.addIncoming() - actual = pt.TealBlock.NormalizeBlocks(actual) + actual_store, _ = expr_store.__teal__(options) + actual_store.addIncoming() + actual_store = pt.TealBlock.NormalizeBlocks(actual_store) with pt.TealComponent.Context.ignoreExprEquality(): - assert actual == expected, "Test at index {} failed".format(i) + assert actual_store == expected_store, f"Test at index {i} failed" with pytest.raises(ValueError): - _index_tuple(test.types, encoded, len(test.types), output) + _IndexTuple(test.types, encoded).get_or_store(len(test.types), output) with pytest.raises(ValueError): - _index_tuple(test.types, encoded, -1, output) + _IndexTuple(test.types, encoded).get_or_store(-1, output) otherType = abi.Uint64() if output.type_spec() == otherType.type_spec(): otherType = abi.Uint16() with pytest.raises(TypeError): - _index_tuple(test.types, encoded, test.typeIndex, otherType) + _IndexTuple(test.types, encoded).get_or_store(test.typeIndex, otherType) + + expr_encode = _IndexTuple(test.types, encoded).get_or_store(test.typeIndex) + assert expr_encode.type_of() == pt.TealType.bytes + assert not expr_encode.has_return() + + expected_encode, _ = test.expected_encode.__teal__(options) + expected_encode.addIncoming() + expected_encode = pt.TealBlock.NormalizeBlocks(expected_encode) + + actual_encode, _ = expr_encode.__teal__(options) + actual_encode.addIncoming() + actual_encode = pt.TealBlock.NormalizeBlocks(actual_encode) + + with pt.TealComponent.Context.ignoreExprEquality(): + assert actual_encode == expected_encode, f"Test at index {i} failed" + + with pytest.raises(ValueError): + _IndexTuple(test.types, encoded).get_or_store(len(test.types)) + + with pytest.raises(ValueError): + _IndexTuple(test.types, encoded).get_or_store(-1) def test_TupleTypeSpec_eq(): @@ -485,7 +614,7 @@ def test_TupleTypeSpec_value_type_specs(): def test_TupleTypeSpec_length_static(): - tests: List[List[abi.TypeSpec]] = [ + tests: list[list[abi.TypeSpec]] = [ [], [abi.Uint64TypeSpec()], [ @@ -751,7 +880,7 @@ def test_Tuple_encode(): def test_Tuple_length(): - tests: List[List[abi.TypeSpec]] = [ + tests: list[list[abi.TypeSpec]] = [ [], [abi.Uint64TypeSpec()], [ @@ -779,7 +908,7 @@ def test_Tuple_length(): def test_Tuple_getitem(): - tests: List[List[abi.TypeSpec]] = [ + tests: list[list[abi.TypeSpec]] = [ [], [abi.Uint64TypeSpec()], [ @@ -805,7 +934,7 @@ def test_Tuple_getitem(): def test_TupleElement_store_into(): - tests: List[List[abi.TypeSpec]] = [ + tests: list[list[abi.TypeSpec]] = [ [], [abi.Uint64TypeSpec()], [ @@ -825,7 +954,9 @@ def test_TupleElement_store_into(): assert expr.type_of() == pt.TealType.none assert not expr.has_return() - expectedExpr = _index_tuple(test, tupleValue.encode(), j, output) + expectedExpr = _IndexTuple(test, tupleValue.encode()).get_or_store( + j, output + ) expected, _ = expectedExpr.__teal__(options) expected.addIncoming() expected = pt.TealBlock.NormalizeBlocks(expected) diff --git a/pyteal/ast/abi/type.py b/pyteal/ast/abi/type.py index 3720154ab..6e4d7bbae 100644 --- a/pyteal/ast/abi/type.py +++ b/pyteal/ast/abi/type.py @@ -177,6 +177,14 @@ def store_into(self, output: T_co) -> Expr: # type: ignore[misc] """ pass + def encode(self) -> Expr: + """Get the encoding bytes of the value. + + Returns: + An expression which represents the computed value. + """ + return self.use(lambda value: value.encode()) # type: ignore + def use(self, action: Callable[[T_co], Expr]) -> Expr: """Compute the value and pass it to a callable expression. diff --git a/pyteal/ast/abi/util.py b/pyteal/ast/abi/util.py index 36b828390..127c3d2d0 100644 --- a/pyteal/ast/abi/util.py +++ b/pyteal/ast/abi/util.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field from typing import ( Any, Literal, @@ -13,7 +14,11 @@ import algosdk.abi from pyteal.errors import TealInputError +from pyteal.types import require_type, TealType from pyteal.ast.expr import Expr +from pyteal.ast.bytes import Bytes +from pyteal.ast.binaryexpr import GetBit +from pyteal.ast.ternaryexpr import SetBit from pyteal.ast.int import Int from pyteal.ast.substring import Extract, Substring, Suffix from pyteal.ast.abi.type import TypeSpec, BaseType @@ -593,3 +598,50 @@ def type_spec_is_assignable_to(a: TypeSpec, b: TypeSpec) -> bool: return True return False + + +@dataclass +class _GetAgainstEncoding: + full_encoding: Expr + type_spec: TypeSpec + start_index: Expr | None = field(kw_only=True, default=None) + end_index: Expr | None = field(kw_only=True, default=None) + length: Expr | None = field(kw_only=True, default=None) + + def __post_init__(self): + from pyteal.ast.abi import BoolTypeSpec + + require_type(self.full_encoding, TealType.bytes) + if self.type_spec == BoolTypeSpec(): + require_type(self.start_index, TealType.uint64) + + def get_or_store(self, output: BaseType | None = None) -> Expr: + from pyteal.ast.abi import BoolTypeSpec, Bool + + match self.type_spec: + case BoolTypeSpec(): + if output is None: + return SetBit( + Bytes(b"\x00"), + Int(0), + GetBit(self.full_encoding, cast(Expr, self.start_index)), + ) + else: + return cast(Bool, output).decode_bit( + self.full_encoding, cast(Expr, self.start_index) + ) + case _: + if output is None: + return substring_for_decoding( + encoded=self.full_encoding, + start_index=self.start_index, + end_index=self.end_index, + length=self.length, + ) + else: + return output.decode( + self.full_encoding, + start_index=self.start_index, + end_index=self.end_index, + length=self.length, + ) diff --git a/pyteal/ast/abi/util_test.py b/pyteal/ast/abi/util_test.py index 2b9dfd73b..9190cc560 100644 --- a/pyteal/ast/abi/util_test.py +++ b/pyteal/ast/abi/util_test.py @@ -1,4 +1,5 @@ -from typing import Callable, NamedTuple, Literal, Optional, Any, get_origin +from dataclasses import dataclass, field +from typing import Callable, NamedTuple, Literal, Optional, Any, get_origin, cast from inspect import isabstract import pytest @@ -11,6 +12,7 @@ int_literal_from_annotation, type_spec_from_algosdk, type_spec_is_assignable_to, + _GetAgainstEncoding, ) options = pt.CompileOptions(version=5) @@ -1057,3 +1059,146 @@ def exists_in_unsafe_bidirectional(_ts: type): assert not exists_in_unsafe_bidirectional(ts) else: assert exists_in_unsafe_bidirectional(ts) + + +@dataclass +class GetAgainstEncodingTestcase: + type_spec: abi.TypeSpec + expected_store: Callable[[abi.BaseType, pt.Expr], pt.Expr] + expected_encode: Callable[[pt.Expr], pt.Expr] + start_index: Optional[pt.Int] = field(kw_only=True, default=None) + end_index: Optional[pt.Int] = field(kw_only=True, default=None) + length: Optional[pt.Int] = field(kw_only=True, default=None) + + +@pytest.mark.parametrize( + "testcase", + [ + GetAgainstEncodingTestcase( + type_spec=abi.Uint64TypeSpec(), + length=pt.Int(8), + expected_store=lambda output, encoded: output.decode( + encoded, length=pt.Int(8) + ), + expected_encode=lambda encoded: substring_for_decoding( + encoded, length=pt.Int(8) + ), + ), + GetAgainstEncodingTestcase( + type_spec=abi.Uint64TypeSpec(), + start_index=pt.Int(8), + expected_store=lambda output, encoded: output.decode( + encoded, start_index=pt.Int(8) + ), + expected_encode=lambda encoded: substring_for_decoding( + encoded, start_index=pt.Int(8) + ), + ), + GetAgainstEncodingTestcase( + type_spec=abi.ByteTypeSpec(), + start_index=pt.Int(8), + length=pt.Int(1), + expected_store=lambda output, encoded: output.decode( + encoded, start_index=pt.Int(8), length=pt.Int(1) + ), + expected_encode=lambda encoded: substring_for_decoding( + encoded, start_index=pt.Int(8), length=pt.Int(1) + ), + ), + GetAgainstEncodingTestcase( + type_spec=abi.BoolTypeSpec(), + start_index=pt.Int(0), + expected_store=lambda output, encoded: cast(abi.Bool, output).decode_bit( + encoded, pt.Int(0) + ), + expected_encode=lambda encoded: pt.SetBit( + pt.Bytes(b"\x00"), + pt.Int(0), + pt.GetBit(encoded, pt.Int(0)), + ), + ), + GetAgainstEncodingTestcase( + type_spec=abi.BoolTypeSpec(), + start_index=pt.Int(1), + expected_store=lambda output, encoded: cast(abi.Bool, output).decode_bit( + encoded, pt.Int(1) + ), + expected_encode=lambda encoded: pt.SetBit( + pt.Bytes(b"\x00"), + pt.Int(0), + pt.GetBit(encoded, pt.Int(1)), + ), + ), + GetAgainstEncodingTestcase( + type_spec=abi.BoolTypeSpec(), + start_index=pt.Int(64), + expected_store=lambda output, encoded: cast(abi.Bool, output).decode_bit( + encoded, pt.Int(64) + ), + expected_encode=lambda encoded: pt.SetBit( + pt.Bytes(b"\x00"), + pt.Int(0), + pt.GetBit(encoded, pt.Int(64)), + ), + ), + GetAgainstEncodingTestcase( + type_spec=abi.BoolTypeSpec(), + start_index=pt.Int(65), + expected_store=lambda output, encoded: cast(abi.Bool, output).decode_bit( + encoded, pt.Int(8 * 8 + 1) + ), + expected_encode=lambda encoded: pt.SetBit( + pt.Bytes(b"\x00"), + pt.Int(0), + pt.GetBit(encoded, pt.Int(65)), + ), + ), + ], +) +def test_get_against_encoding(testcase: GetAgainstEncodingTestcase): + encoded = pt.Bytes("encoded") + + output_expr = testcase.type_spec.new_instance() + expected_store_expr = testcase.expected_store(output_expr, encoded) + expected_encode_expr = testcase.expected_encode(encoded) + + get_against_encoding = _GetAgainstEncoding( + encoded, + testcase.type_spec, + start_index=testcase.start_index, + end_index=testcase.end_index, + length=testcase.length, + ) + actual_store_expr = get_against_encoding.get_or_store(output_expr) + actual_encode_expr = get_against_encoding.get_or_store() + + def expected_actual_assert(expected: pt.Expr, actual: pt.Expr): + expected_blocks, _ = expected.__teal__(options) + expected_blocks.addIncoming() + expected_blocks = pt.TealBlock.NormalizeBlocks(expected_blocks) + + actual_blocks, _ = actual.__teal__(options) + actual_blocks.addIncoming() + actual_blocks = pt.TealBlock.NormalizeBlocks(actual_blocks) + + with pt.TealComponent.Context.ignoreExprEquality(): + assert actual_blocks == expected_blocks + + expected_actual_assert(expected_store_expr, actual_store_expr) + expected_actual_assert(expected_encode_expr, actual_encode_expr) + + +def test_get_against_encoding_negative_cases(): + with pytest.raises(TypeError) as te: + _GetAgainstEncoding( + pt.Bytes("encoded"), + abi.BoolTypeSpec(), + ) + + assert "Expected a TealType.uint64 object" in str(te) + + with pytest.raises(pt.TealTypeError): + _GetAgainstEncoding( + pt.Int(123), + abi.BoolTypeSpec(), + )