Skip to content

Commit

Permalink
Make encode method generalized for ComputedType (#629)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahangsu authored Jan 4, 2023
1 parent 93258bd commit 105e307
Show file tree
Hide file tree
Showing 7 changed files with 653 additions and 256 deletions.
63 changes: 41 additions & 22 deletions pyteal/ast/abi/array_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@

from pyteal.ast.abi.type import TypeSpec, BaseType, ComputedValue
from pyteal.ast.abi.tuple import _encode_tuple
from pyteal.ast.abi.bool import Bool, BoolTypeSpec
from pyteal.ast.abi.bool import BoolTypeSpec
from pyteal.ast.abi.uint import Uint16, Uint16TypeSpec
from pyteal.ast.abi.util import substring_for_decoding
from pyteal.ast.abi.util import (
substring_for_decoding,
_GetAgainstEncoding,
)

T = TypeVar("T", bound=BaseType)

Expand Down Expand Up @@ -207,20 +210,8 @@ def __init__(self, array: Array[T], index: Expr) -> None:
def produced_type_spec(self) -> TypeSpec:
return self.array.type_spec().value_type_spec()

def store_into(self, output: T) -> Expr:
"""Partitions the byte string of the given ABI array and stores the byte string of array
element in the ABI value output.
The function first checks if the output type matches with array element type, and throw
error if type-mismatch.
Args:
output: An ABI typed value that the array element byte string stores into.
Returns:
An expression that stores the byte string of the array element into value `output`.
"""
if output.type_spec() != self.produced_type_spec():
def __prototype_encoding_store_into(self, output: T | None = None) -> Expr:
if output is not None and output.type_spec() != self.produced_type_spec():
raise TealInputError("Output type does not match value type")

encodedArray = self.array.encode()
Expand All @@ -229,11 +220,13 @@ def store_into(self, output: T) -> Expr:
# If the array element type is Bool, we compute the bit index
# (if array is dynamic we add 16 to bit index for dynamic array length uint16 prefix)
# and decode bit with given array encoding and the bit index for boolean bit.
if output.type_spec() == BoolTypeSpec():
if self.array.type_spec().value_type_spec() == BoolTypeSpec():
bitIndex = self.index
if arrayType.is_dynamic():
bitIndex = bitIndex + Int(Uint16TypeSpec().bit_size())
return cast(Bool, output).decode_bit(encodedArray, bitIndex)
return _GetAgainstEncoding(
encodedArray, BoolTypeSpec(), start_index=bitIndex
).get_or_store(output)

# Compute the byteIndex (first byte indicating the element encoding)
# (If the array is dynamic, add 2 to byte index for dynamic array length uint16 prefix)
Expand Down Expand Up @@ -271,16 +264,42 @@ def store_into(self, output: T) -> Expr:
.Else(nextValueStart)
)

return output.decode(
encodedArray, start_index=valueStart, end_index=valueEnd
)
return _GetAgainstEncoding(
encodedArray,
arrayType.value_type_spec(),
start_index=valueStart,
end_index=valueEnd,
).get_or_store(output)

# Handling case for array elements are static:
# since array._stride() is element's static byte length
# we partition the substring for array element.
valueStart = byteIndex
valueLength = Int(arrayType._stride())
return output.decode(encodedArray, start_index=valueStart, length=valueLength)
return _GetAgainstEncoding(
encodedArray,
arrayType.value_type_spec(),
start_index=valueStart,
length=valueLength,
).get_or_store(output)

def store_into(self, output: T) -> Expr:
"""Partitions the byte string of the given ABI array and stores the byte string of array
element in the ABI value output.
The function first checks if the output type matches with array element type, and throw
error if type-mismatch.
Args:
output: An ABI typed value that the array element byte string stores into.
Returns:
An expression that stores the byte string of the array element into value `output`.
"""
return self.__prototype_encoding_store_into(output)

def encode(self) -> Expr:
return self.__prototype_encoding_store_into()


ArrayElement.__module__ = "pyteal.abi"
102 changes: 102 additions & 0 deletions pyteal/ast/abi/array_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,105 @@ def test_ArrayElement_store_into():

with pytest.raises(pt.TealInputError):
element.store_into(abi.Tuple(abi.TupleTypeSpec(elementType)))


def test_ArrayElement_encoding():
from pyteal.ast.abi.util import substring_for_decoding

for elementType in STATIC_TYPES + DYNAMIC_TYPES:
staticArrayType = abi.StaticArrayTypeSpec(elementType, 100)
staticArray = staticArrayType.new_instance()
index = pt.Int(9)

element = abi.ArrayElement(staticArray, index)
expr = element.encode()

encoded = staticArray.encode()
stride = pt.Int(staticArray.type_spec()._stride())
expectedLength = staticArray.length()
if elementType == abi.BoolTypeSpec():
expectedExpr = pt.SetBit(
pt.Bytes(b"\x00"),
pt.Int(0),
pt.GetBit(encoded, index),
)
elif not elementType.is_dynamic():
expectedExpr = substring_for_decoding(
encoded, start_index=stride * index, length=stride
)
else:
expectedExpr = substring_for_decoding(
encoded,
start_index=pt.ExtractUint16(encoded, stride * index),
end_index=pt.If(index + pt.Int(1) == expectedLength)
.Then(pt.Len(encoded))
.Else(pt.ExtractUint16(encoded, stride * index + pt.Int(2))),
)

expected, _ = expectedExpr.__teal__(options)
expected.addIncoming()
expected = pt.TealBlock.NormalizeBlocks(expected)

actual, _ = expr.__teal__(options)
actual.addIncoming()
actual = pt.TealBlock.NormalizeBlocks(actual)

with pt.TealComponent.Context.ignoreExprEquality():
assert actual == expected

with pytest.raises(pt.TealInputError):
element.store_into(abi.Tuple(abi.TupleTypeSpec(elementType)))

for elementType in STATIC_TYPES + DYNAMIC_TYPES:
dynamicArrayType = abi.DynamicArrayTypeSpec(elementType)
dynamicArray = dynamicArrayType.new_instance()
index = pt.Int(9)

element = abi.ArrayElement(dynamicArray, index)
expr = element.encode()

encoded = dynamicArray.encode()
stride = pt.Int(dynamicArray.type_spec()._stride())
expectedLength = dynamicArray.length()
if elementType == abi.BoolTypeSpec():
expectedExpr = pt.SetBit(
pt.Bytes(b"\x00"),
pt.Int(0),
pt.GetBit(encoded, index + pt.Int(16)),
)
elif not elementType.is_dynamic():
expectedExpr = substring_for_decoding(
encoded, start_index=stride * index + pt.Int(2), length=stride
)
else:
expectedExpr = substring_for_decoding(
encoded,
start_index=pt.ExtractUint16(encoded, stride * index + pt.Int(2))
+ pt.Int(2),
end_index=pt.If(index + pt.Int(1) == expectedLength)
.Then(pt.Len(encoded))
.Else(
pt.ExtractUint16(encoded, stride * index + pt.Int(2) + pt.Int(2))
+ pt.Int(2)
),
)

expected, _ = expectedExpr.__teal__(options)
expected.addIncoming()
expected = pt.TealBlock.NormalizeBlocks(expected)

actual, _ = expr.__teal__(options)
actual.addIncoming()
actual = pt.TealBlock.NormalizeBlocks(actual)

with pt.TealComponent.Context.ignoreExprEquality():
with pt.TealComponent.Context.ignoreScratchSlotEquality():
assert actual == expected

assert pt.TealBlock.MatchScratchSlotReferences(
pt.TealBlock.GetReferencedScratchSlots(actual),
pt.TealBlock.GetReferencedScratchSlots(expected),
)

with pytest.raises(pt.TealInputError):
element.store_into(abi.Tuple(abi.TupleTypeSpec(elementType)))
Loading

0 comments on commit 105e307

Please sign in to comment.