diff --git a/pyteal/ast/abi/array_base.py b/pyteal/ast/abi/array_base.py index c7faef272..638fb83a2 100644 --- a/pyteal/ast/abi/array_base.py +++ b/pyteal/ast/abi/array_base.py @@ -24,7 +24,7 @@ from pyteal.ast.abi.uint import Uint16, Uint16TypeSpec from pyteal.ast.abi.util import ( substring_for_decoding, - _get_or_store_encoded_bytes, + _GetAgainstEncoding, ) T = TypeVar("T", bound=BaseType) @@ -224,9 +224,9 @@ def __prototype_encoding_store_into(self, output: T | None = None) -> Expr: bitIndex = self.index if arrayType.is_dynamic(): bitIndex = bitIndex + Int(Uint16TypeSpec().bit_size()) - return _get_or_store_encoded_bytes( - BoolTypeSpec(), encodedArray, output, start_index=bitIndex - ) + return _GetAgainstEncoding( + BoolTypeSpec(), encodedArray, start_index=bitIndex + ).get_or_store(output) # Compute the byteIndex (first byte indicating the element encoding) # (If the array is dynamic, add 2 to byte index for dynamic array length uint16 prefix) @@ -264,26 +264,24 @@ def __prototype_encoding_store_into(self, output: T | None = None) -> Expr: .Else(nextValueStart) ) - return _get_or_store_encoded_bytes( + return _GetAgainstEncoding( arrayType.value_type_spec(), encodedArray, - output, start_index=valueStart, end_index=valueEnd, - ) + ).get_or_store(output) # Handling case for array elements are static: # since array._stride() is element's static byte length # we partition the substring for array element. valueStart = byteIndex valueLength = Int(arrayType._stride()) - return _get_or_store_encoded_bytes( + return _GetAgainstEncoding( arrayType.value_type_spec(), encodedArray, - output, start_index=valueStart, length=valueLength, - ) + ).get_or_store(output) def store_into(self, output: T) -> Expr: """Partitions the byte string of the given ABI array and stores the byte string of array diff --git a/pyteal/ast/abi/tuple.py b/pyteal/ast/abi/tuple.py index 7091088bd..142e98ab0 100644 --- a/pyteal/ast/abi/tuple.py +++ b/pyteal/ast/abi/tuple.py @@ -40,7 +40,7 @@ from pyteal.ast.abi.util import ( substring_for_decoding, type_spec_from_annotation, - _get_or_store_encoded_bytes, + _GetAgainstEncoding, ) @@ -171,12 +171,11 @@ def __call__(self, index: int, output: BaseType | None = None) -> Expr: # value is the beginning of a bool sequence (or a single bool) bitOffsetInEncoded = offset * NUM_BITS_IN_BYTE - return _get_or_store_encoded_bytes( + return _GetAgainstEncoding( BoolTypeSpec(), self.encoded, - output, start_index=Int(bitOffsetInEncoded), - ) + ).get_or_store(output) if valueType.is_dynamic(): hasNextDynamicValue = False @@ -205,20 +204,19 @@ def __call__(self, index: int, output: BaseType | None = None) -> Expr: if not hasNextDynamicValue: # This is the final dynamic value, so decode the substring from start_index to the end of # encoded - return _get_or_store_encoded_bytes( - valueType, self.encoded, output, start_index=start_index - ) + return _GetAgainstEncoding( + valueType, self.encoded, start_index=start_index + ).get_or_store(output) # There is a dynamic value after this one, and end_index is where its tail starts, so decode # the substring from start_index to end_index end_index = ExtractUint16(self.encoded, Int(nextDynamicValueOffset)) - return _get_or_store_encoded_bytes( + return _GetAgainstEncoding( valueType, self.encoded, - output, start_index=start_index, end_index=end_index, - ) + ).get_or_store(output) start_index = Int(offset) length = Int(valueType.byte_length_static()) @@ -232,20 +230,20 @@ def __call__(self, index: int, output: BaseType | None = None) -> Expr: return output.decode(self.encoded) # This is the last value in the tuple, so decode the substring from start_index to the end of # encoded - return _get_or_store_encoded_bytes( - valueType, self.encoded, output, start_index=start_index - ) + return _GetAgainstEncoding( + valueType, self.encoded, start_index=start_index + ).get_or_store(output) if offset == 0: # This is the first value in the tuple, so decode the substring from 0 with length length - return _get_or_store_encoded_bytes( - valueType, self.encoded, output, length=length - ) + return _GetAgainstEncoding( + valueType, self.encoded, length=length + ).get_or_store(output) # This is not the first or last value, so decode the substring from start_index with length length - return _get_or_store_encoded_bytes( - valueType, self.encoded, output, start_index=start_index, length=length - ) + return _GetAgainstEncoding( + valueType, self.encoded, start_index=start_index, length=length + ).get_or_store(output) class TupleTypeSpec(TypeSpec): diff --git a/pyteal/ast/abi/util.py b/pyteal/ast/abi/util.py index 7cf893377..a85370698 100644 --- a/pyteal/ast/abi/util.py +++ b/pyteal/ast/abi/util.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field from typing import ( Any, Literal, @@ -599,46 +600,50 @@ def type_spec_is_assignable_to(a: TypeSpec, b: TypeSpec) -> bool: return False -def _get_or_store_encoded_bytes( - encoding_type: TypeSpec, - full_encoding: Expr, - output: BaseType | None = None, - *, - start_index: Expr | None = None, - end_index: Expr | None = None, - length: Expr | None = None, -) -> Expr: - from pyteal.ast.abi import BoolTypeSpec, Bool +@dataclass +class _GetAgainstEncoding: + type_spec: TypeSpec + full_encoding: Expr + start_index: Expr | None = field(kw_only=True, default=None) + end_index: Expr | None = field(kw_only=True, default=None) + length: Expr | None = field(kw_only=True, default=None) - require_type(full_encoding, TealType.bytes) + def __post_init__(self): + from pyteal.ast.abi import BoolTypeSpec - match encoding_type: - case BoolTypeSpec(): - if start_index is None: - raise TealInputError( - "on BoolTypeSpec, requiring start index to be not None." - ) + require_type(self.full_encoding, TealType.bytes) + if self.type_spec == BoolTypeSpec() and self.start_index is None: + raise TealInputError( + "on BoolTypeSpec, requiring start index to be not None." + ) - if output is None: - return SetBit( - Bytes(b"\x00"), - Int(0), - GetBit(full_encoding, start_index), - ) - else: - return cast(Bool, output).decode_bit(full_encoding, start_index) - case _: - if output is None: - return substring_for_decoding( - encoded=full_encoding, - start_index=start_index, - end_index=end_index, - length=length, - ) - else: - return output.decode( - full_encoding, - start_index=start_index, - end_index=end_index, - length=length, - ) + def get_or_store(self, output: BaseType | None = None) -> Expr: + from pyteal.ast.abi import BoolTypeSpec, Bool + + match self.type_spec: + case BoolTypeSpec(): + if output is None: + return SetBit( + Bytes(b"\x00"), + Int(0), + GetBit(self.full_encoding, cast(Expr, self.start_index)), + ) + else: + return cast(Bool, output).decode_bit( + self.full_encoding, cast(Expr, self.start_index) + ) + case _: + if output is None: + return substring_for_decoding( + encoded=self.full_encoding, + start_index=self.start_index, + end_index=self.end_index, + length=self.length, + ) + else: + return output.decode( + self.full_encoding, + start_index=self.start_index, + end_index=self.end_index, + length=self.length, + )