From 67cb7acfa7460df43b87c809f087d4623c308827 Mon Sep 17 00:00:00 2001 From: Eduard S Date: Mon, 23 May 2022 16:12:12 +0200 Subject: [PATCH] Add MPT table lookup to state circuit (#200) * Add MPT table lookup to state circuit * Update following discussion * Update tables layout in EVM spec * Revert AccountFieldTag value changes --- specs/tables.md | 114 ++++++------ src/zkevm_specs/evm/instruction.py | 24 ++- src/zkevm_specs/evm/table.py | 34 +++- src/zkevm_specs/evm/typing.py | 34 ++-- src/zkevm_specs/state.py | 288 ++++++++++++++++++++++------- tests/test_state_circuit.py | 236 ++++++++++++++--------- 6 files changed, 481 insertions(+), 249 deletions(-) diff --git a/specs/tables.md b/specs/tables.md index 7f0d6da98..3a0f90c18 100644 --- a/specs/tables.md +++ b/specs/tables.md @@ -57,63 +57,63 @@ Type sizes: - **TxReceipt -> CumulativeGasUsed**, 8 byte -| 0 *rwc* | 1 *isWrite* | 2 *Key0 (Tag)* | 3 *Key1* | 4 *Key2* | 5 *Key3* | 6 *Key4* | 7 *Value0* | 8 *Value1* | 9 *Aux0* | 10 *Aux1* | -| -------- | ----------- | -------------------------- | -------- | -------- | -------------------------- | ----------- | --------- | ---------- | -------- | --------------- | -| | | *RwTableTag* | | | | | | | | | -| $counter | true | TxAccessListAccount | $txID | $address | | | $value | $valuePrev | 0 | 0 | -| $counter | true | TxAccessListAccountStorage | $txID | $address | | $storageKey | $value | $valuePrev | | 0 | -| $counter | $isWrite | TxRefund | $txID | | | | $value | $valuePrev | 0 | 0 | -| | | | | | | | | | | | -| | | | | | *AccountFieldTag* | | | | | | -| $counter | $isWrite | Account | | $address | Nonce | | $value | $valuePrev | 0 | 0 | -| $counter | $isWrite | Account | | $address | Balance | | $value | $valuePrev | 0 | 0 | -| $counter | $isWrite | Account | | $address | CodeHash | | $value | $valuePrev | 0 | 0 | -| $counter | true | AccountDestructed | | $address | | | $value | $valuePrev | 0 | 0 | -| | | | | | | | | | | | -| | | *CallContext constant* | | | *CallContextFieldTag* (ro) | | | | | | -| $counter | false | CallContext | $callID | | RwCounterEndOfReversion | | $value | 0 | 0 | 0 | -| $counter | false | CallContext | $callID | | CallerId | | $value | 0 | 0 | 0 | -| $counter | false | CallContext | $callID | | TxId | | $value | 0 | 0 | 0 | -| $counter | false | CallContext | $callID | | Depth | | $value | 0 | 0 | 0 | -| $counter | false | CallContext | $callID | | CallerAddress | | $value | 0 | 0 | 0 | -| $counter | false | CallContext | $callID | | CalleeAddress | | $value | 0 | 0 | 0 | -| $counter | false | CallContext | $callID | | CallDataOffset | | $value | 0 | 0 | 0 | -| $counter | false | CallContext | $callID | | CallDataLength | | $value | 0 | 0 | 0 | -| $counter | false | CallContext | $callID | | ReturnDataOffset | | $value | 0 | 0 | 0 | -| $counter | false | CallContext | $callID | | ReturnDataLength | | $value | 0 | 0 | 0 | -| $counter | false | CallContext | $callID | | Value | | $value | 0 | 0 | 0 | -| $counter | false | CallContext | $callID | | IsSuccess | | $value | 0 | 0 | 0 | -| $counter | false | CallContext | $callID | | IsPersistent | | $value | 0 | 0 | 0 | -| $counter | false | CallContext | $callID | | IsStatic | | $value | 0 | 0 | 0 | -| | | | | | | | | | | | -| | | *CallContext last callee* | | | *CallContextFieldTag* (rw) | | | | | | -| $counter | $isWrite | CallContext | $callID | | LastCalleeId | | $value | 0 | 0 | 0 | -| $counter | $isWrite | CallContext | $callID | | LastCalleeReturnDataOffset | | $value | 0 | 0 | 0 | -| $counter | $isWrite | CallContext | $callID | | LastCalleeReturnDataLength | | $value | 0 | 0 | 0 | -| | | | | | | | | | | | -| | | *CallContext state* | | | *CallContextFieldTag* (rw) | | | | | | -| $counter | $isWrite | CallContext | $callID | | IsRoot | | $value | 0 | 0 | 0 | -| $counter | $isWrite | CallContext | $callID | | IsCreate | | $value | 0 | 0 | 0 | -| $counter | $isWrite | CallContext | $callID | | CodeSource | | $value | 0 | 0 | 0 | -| $counter | $isWrite | CallContext | $callID | | ProgramCounter | | $value | 0 | 0 | 0 | -| $counter | $isWrite | CallContext | $callID | | StackPointer | | $value | 0 | 0 | 0 | -| $counter | $isWrite | CallContext | $callID | | GasLeft | | $value | 0 | 0 | 0 | -| $counter | $isWrite | CallContext | $callID | | MemorySize | | $value | 0 | 0 | 0 | -| $counter | $isWrite | CallContext | $callID | | ReversibleWriteCounter | | $value | 0 | 0 | 0 | -| | | | | | | | | | | | -| $counter | $isWrite | Stack | $callID | $stackPointer | | | $value | 0 | 0 | 0 | -| $counter | $isWrite | Memory | $callID | $memoryAddress | | | $value | 0 | 0 | 0 | -| $counter | $isWrite | AccountStorage | | $address | | $storageKey | $value | $valuePrev | $txID | $CommittedValue | -| | | | | | | | | | | | -| $counter | true | TxLog |$txID | $logID | Address | 0 | $value | 0 | 0 | 0 | -| $counter | true | TxLog |$txID | $logID | Topic | $topicIndex | $value | 0 | 0 | 0 | -| $counter | true | TxLog |$txID | $logID | Data | $byteIndex | $value | 0 | 0 | 0 | -| $counter | true | TxLog |$txID | $logID | TopicLength | 0 | $value | 0 | 0 | 0 | -| $counter | true | TxLog |$txID | $logID | DataLength | 0 | $value | 0 | 0 | 0 | -| | | | | | | | | | | | -| $counter | false | TxReceipt |$txID | 0 | PostStateOrStatus | 0 | $value | 0 | 0 | 0 | -| $counter | false | TxReceipt |$txID | 0 | CumulativeGasUsed | 0 | $value | 0 | 0 | 0 | -| $counter | false | TxReceipt |$txID | 0 | LogLength | 0 | $value | 0 | 0 | 0 | +| 0 *rwc* | 1 *isWrite* | 2 *Key0 (Tag)* | 3 *Key1* | 4 *Key2* | 5 *Key3* | 6 *Key4* | 7 *Value0* | 8 *Value1* | 9 *Aux0* | +| -------- | ----------- | -------------------------- | -------- | -------- | -------------------------- | ----------- | --------- | ---------- | --------------- | +| | | *RwTableTag* | | | | | | | | +| $counter | true | TxAccessListAccount | $txID | $address | | | $value | $valuePrev | 0 | +| $counter | true | TxAccessListAccountStorage | $txID | $address | | $storageKey | $value | $valuePrev | 0 | +| $counter | $isWrite | TxRefund | $txID | | | | $value | $valuePrev | 0 | +| | | | | | | | | | | +| | | | | | *AccountFieldTag* | | | | | +| $counter | $isWrite | Account | | $address | Nonce | | $value | $valuePrev | $committedValue | +| $counter | $isWrite | Account | | $address | Balance | | $value | $valuePrev | $committedValue | +| $counter | $isWrite | Account | | $address | CodeHash | | $value | $valuePrev | $committedValue | +| $counter | true | AccountDestructed | | $address | | | $value | $valuePrev | 0 | +| | | | | | | | | | | +| | | *CallContext constant* | | | *CallContextFieldTag* (ro) | | | | | +| $counter | false | CallContext | $callID | | RwCounterEndOfReversion | | $value | 0 | 0 | +| $counter | false | CallContext | $callID | | CallerId | | $value | 0 | 0 | +| $counter | false | CallContext | $callID | | TxId | | $value | 0 | 0 | +| $counter | false | CallContext | $callID | | Depth | | $value | 0 | 0 | +| $counter | false | CallContext | $callID | | CallerAddress | | $value | 0 | 0 | +| $counter | false | CallContext | $callID | | CalleeAddress | | $value | 0 | 0 | +| $counter | false | CallContext | $callID | | CallDataOffset | | $value | 0 | 0 | +| $counter | false | CallContext | $callID | | CallDataLength | | $value | 0 | 0 | +| $counter | false | CallContext | $callID | | ReturnDataOffset | | $value | 0 | 0 | +| $counter | false | CallContext | $callID | | ReturnDataLength | | $value | 0 | 0 | +| $counter | false | CallContext | $callID | | Value | | $value | 0 | 0 | +| $counter | false | CallContext | $callID | | IsSuccess | | $value | 0 | 0 | +| $counter | false | CallContext | $callID | | IsPersistent | | $value | 0 | 0 | +| $counter | false | CallContext | $callID | | IsStatic | | $value | 0 | 0 | +| | | | | | | | | | | +| | | *CallContext last callee* | | | *CallContextFieldTag* (rw) | | | | | +| $counter | $isWrite | CallContext | $callID | | LastCalleeId | | $value | 0 | 0 | +| $counter | $isWrite | CallContext | $callID | | LastCalleeReturnDataOffset | | $value | 0 | 0 | +| $counter | $isWrite | CallContext | $callID | | LastCalleeReturnDataLength | | $value | 0 | 0 | +| | | | | | | | | | | +| | | *CallContext state* | | | *CallContextFieldTag* (rw) | | | | | +| $counter | $isWrite | CallContext | $callID | | IsRoot | | $value | 0 | 0 | +| $counter | $isWrite | CallContext | $callID | | IsCreate | | $value | 0 | 0 | +| $counter | $isWrite | CallContext | $callID | | CodeSource | | $value | 0 | 0 | +| $counter | $isWrite | CallContext | $callID | | ProgramCounter | | $value | 0 | 0 | +| $counter | $isWrite | CallContext | $callID | | StackPointer | | $value | 0 | 0 | +| $counter | $isWrite | CallContext | $callID | | GasLeft | | $value | 0 | 0 | +| $counter | $isWrite | CallContext | $callID | | MemorySize | | $value | 0 | 0 | +| $counter | $isWrite | CallContext | $callID | | ReversibleWriteCounter | | $value | 0 | 0 | +| | | | | | | | | | | +| $counter | $isWrite | Stack | $callID | $stackPointer | | | $value | 0 | 0 | +| $counter | $isWrite | Memory | $callID | $memoryAddress | | | $value | 0 | 0 | +| $counter | $isWrite | AccountStorage | $txID | $address | | $storageKey | $value | $valuePrev | $committedValue | +| | | | | | | | | | | +| $counter | true | TxLog | $txID | $logID | Address | 0 | $value | 0 | 0 | +| $counter | true | TxLog | $txID | $logID | Topic | $topicIndex | $value | 0 | 0 | +| $counter | true | TxLog | $txID | $logID | Data | $byteIndex | $value | 0 | 0 | +| $counter | true | TxLog | $txID | $logID | TopicLength | 0 | $value | 0 | 0 | +| $counter | true | TxLog | $txID | $logID | DataLength | 0 | $value | 0 | 0 | +| | | | | | | | | | | +| $counter | false | TxReceipt | $txID | 0 | PostStateOrStatus | 0 | $value | 0 | 0 | +| $counter | false | TxReceipt | $txID | 0 | CumulativeGasUsed | 0 | $value | 0 | 0 | +| $counter | false | TxReceipt | $txID | 0 | LogLength | 0 | $value | 0 | 0 | ## `bytecode_table` diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index 9ae39bea8..a611f3055 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -522,7 +522,6 @@ def rw_lookup( value: Expression = None, value_prev: Expression = None, aux0: Expression = None, - aux1: Expression = None, rw_counter: Expression = None, ) -> RWTableRow: if rw_counter is None: @@ -540,7 +539,6 @@ def rw_lookup( value, value_prev, aux0, - aux1, ) def state_write( @@ -553,12 +551,11 @@ def state_write( value: Expression = None, value_prev: Expression = None, aux0: Expression = None, - aux1: Expression = None, reversion_info: ReversionInfo = None, ) -> RWTableRow: assert tag.write_with_reversion() - row = self.rw_lookup(RW.Write, tag, key1, key2, key3, key4, value, value_prev, aux0, aux1) + row = self.rw_lookup(RW.Write, tag, key1, key2, key3, key4, value, value_prev, aux0) if reversion_info is not None and reversion_info.is_persistent == FQ(0): self.tables.rw_lookup( @@ -573,7 +570,6 @@ def state_write( value=row.value_prev, value_prev=row.value, aux0=row.aux0, - aux1=row.aux1, ) return row @@ -645,7 +641,7 @@ def tx_refund_write( def account_read(self, account_address: Expression, account_field_tag: AccountFieldTag) -> RLC: return cast_expr( self.rw_lookup( - RW.Read, RWTableTag.Account, account_address, FQ(account_field_tag) + RW.Read, RWTableTag.Account, key2=account_address, key3=FQ(account_field_tag) ).value, RLC, ) @@ -658,8 +654,8 @@ def account_write( ) -> Tuple[Expression, Expression]: row = self.state_write( RWTableTag.Account, - account_address, - FQ(account_field_tag), + key2=account_address, + key3=FQ(account_field_tag), reversion_info=reversion_info, ) return row.value, row.value_prev @@ -700,9 +696,10 @@ def account_storage_read( row = self.rw_lookup( RW.Read, RWTableTag.AccountStorage, + tx_id, account_address, - storage_key, - aux0=tx_id, + key3=None, + key4=storage_key, ) return cast_expr(row.value, RLC) @@ -715,12 +712,13 @@ def account_storage_write( ) -> Tuple[RLC, RLC, RLC]: row = self.state_write( RWTableTag.AccountStorage, + tx_id, account_address, - storage_key, - aux0=tx_id, + key3=None, + key4=storage_key, reversion_info=reversion_info, ) - return cast_expr(row.value, RLC), cast_expr(row.value_prev, RLC), cast_expr(row.aux1, RLC) + return cast_expr(row.value, RLC), cast_expr(row.value_prev, RLC), cast_expr(row.aux0, RLC) def add_account_to_access_list( self, tx_id: Expression, account_address: Expression, reversion_info: ReversionInfo = None diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index 3e0b692e9..8a7d71c9f 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -144,6 +144,17 @@ class RW(IntEnum): Write = 1 +class MPTTableTag(IntEnum): + """ + Tag for MPTTable lookup + """ + + Nonce = 1 + Balance = 2 + CodeHash = 4 + Storage = 8 + + class RWTableTag(IntEnum): """ Tag for RWTable lookup, where the RWTable an advice-column table built by @@ -340,7 +351,16 @@ class RWTableRow(TableRow): value: Expression = field(default=FQ(0)) value_prev: Expression = field(default=FQ(0)) aux0: Expression = field(default=FQ(0)) - aux1: Expression = field(default=FQ(0)) + + +@dataclass(frozen=True) +class MPTTableRow(TableRow): + counter: Expression + target: Expression # MPTTableTag + address: Expression + key: Expression + value: Expression + value_prev: Expression class Tables: @@ -391,7 +411,7 @@ def block_lookup( self, field_tag: Expression, block_number: Expression = FQ(0) ) -> BlockTableRow: query = {"field_tag": field_tag, "block_number_or_zero": block_number} - return _lookup(BlockTableRow, self.block_table, query) + return lookup(BlockTableRow, self.block_table, query) def tx_lookup( self, tx_id: Expression, field_tag: Expression, call_data_index: Expression = FQ(0) @@ -401,7 +421,7 @@ def tx_lookup( "field_tag": field_tag, "call_data_index_or_zero": call_data_index, } - return _lookup(TxTableRow, self.tx_table, query) + return lookup(TxTableRow, self.tx_table, query) def bytecode_lookup( self, @@ -416,7 +436,7 @@ def bytecode_lookup( "index": index, "is_code": is_code, } - return _lookup(BytecodeTableRow, self.bytecode_table, query) + return lookup(BytecodeTableRow, self.bytecode_table, query) def rw_lookup( self, @@ -430,7 +450,6 @@ def rw_lookup( value: Expression = None, value_prev: Expression = None, aux0: Expression = None, - aux1: Expression = None, ) -> RWTableRow: query = { "rw_counter": rw_counter, @@ -443,15 +462,14 @@ def rw_lookup( "value": value, "value_prev": value_prev, "aux0": aux0, - "aux1": aux1, } - return _lookup(RWTableRow, self.rw_table, query) + return lookup(RWTableRow, self.rw_table, query) T = TypeVar("T", bound=TableRow) -def _lookup( +def lookup( table_cls: Type[T], table: Set[T], query: Mapping[str, Optional[Expression]], diff --git a/src/zkevm_specs/evm/typing.py b/src/zkevm_specs/evm/typing.py index f09373525..9e0615e4f 100644 --- a/src/zkevm_specs/evm/typing.py +++ b/src/zkevm_specs/evm/typing.py @@ -403,9 +403,7 @@ def tx_receipt_read( RW.Read, RWTableTag.TxReceipt, key1=FQ(tx_id), - key2=FQ(0), key3=FQ(field_tag), - key4=FQ(0), value=value, ) @@ -473,8 +471,8 @@ def account_read( return self._append( RW.Read, RWTableTag.Account, - key1=FQ(account_address), - key2=FQ(field_tag), + key2=FQ(account_address), + key3=FQ(field_tag), value=value, value_prev=value, ) @@ -493,8 +491,8 @@ def account_write( value_prev = FQ(value_prev) return self._state_write( RWTableTag.Account, - key1=FQ(account_address), - key2=FQ(field_tag), + key2=FQ(account_address), + key3=FQ(field_tag), value=value, value_prev=value_prev, rw_counter_of_reversion=rw_counter_of_reversion, @@ -513,12 +511,12 @@ def account_storage_read( return self._append( RW.Read, RWTableTag.AccountStorage, - key1=FQ(account_address), - key2=storage_key, + key1=tx_id, + key2=FQ(account_address), + key4=storage_key, value=value, value_prev=value, - aux0=tx_id, - aux1=value_committed, + aux0=value_committed, ) def account_storage_write( @@ -535,12 +533,12 @@ def account_storage_write( tx_id = FQ(tx_id) return self._state_write( RWTableTag.AccountStorage, - key1=FQ(account_address), - key2=storage_key, + key1=tx_id, + key2=FQ(account_address), + key4=storage_key, value=value, value_prev=value_prev, - aux0=tx_id, - aux1=value_committed, + aux0=value_committed, rw_counter_of_reversion=rw_counter_of_reversion, ) @@ -550,10 +548,10 @@ def _state_write( key1: Expression = FQ(0), key2: Expression = FQ(0), key3: Expression = FQ(0), + key4: Expression = FQ(0), value: Expression = FQ(0), value_prev: Expression = FQ(0), aux0: Expression = FQ(0), - aux1: Expression = FQ(0), rw_counter_of_reversion: int = None, ) -> RWDictionary: self._append( @@ -562,10 +560,10 @@ def _state_write( key1=key1, key2=key2, key3=key3, + key4=key4, value=value, value_prev=value_prev, aux0=aux0, - aux1=aux1, ) if rw_counter_of_reversion is None: @@ -577,10 +575,10 @@ def _state_write( key1=key1, key2=key2, key3=key3, + key4=key4, value=value_prev, value_prev=value, aux0=aux0, - aux1=aux1, rw_counter=rw_counter_of_reversion, ) @@ -595,7 +593,6 @@ def _append( value: Expression = FQ(0), value_prev: Expression = FQ(0), aux0: Expression = FQ(0), - aux1: Expression = FQ(0), rw_counter: int = None, ) -> RWDictionary: if rw_counter is None: @@ -614,7 +611,6 @@ def _append( value, value_prev, aux0, - aux1, ) ) diff --git a/src/zkevm_specs/state.py b/src/zkevm_specs/state.py index ad2b925d9..6f235b069 100644 --- a/src/zkevm_specs/state.py +++ b/src/zkevm_specs/state.py @@ -1,10 +1,19 @@ -from typing import NamedTuple, Tuple, List, Sequence +from typing import NamedTuple, Tuple, List, Sequence, Set, Union, cast from enum import IntEnum from math import log, ceil -from .util import FQ, RLC, U160, U256 +from .util import FQ, RLC, U160, U256, Expression from .encoding import U8, is_circuit_code -from .evm import RW, AccountFieldTag, CallContextFieldTag, TxLogFieldTag, TxReceiptFieldTag +from .evm import ( + RW, + AccountFieldTag, + CallContextFieldTag, + TxLogFieldTag, + TxReceiptFieldTag, + MPTTableRow, + MPTTableTag, + lookup, +) MAX_MEMORY_ADDRESS = 2**32 - 1 MAX_KEY_DIFF = 2**32 - 1 @@ -61,13 +70,61 @@ class Row(NamedTuple): FQ,FQ,FQ,FQ,FQ,FQ,FQ,FQ, FQ,FQ,FQ,FQ,FQ,FQ,FQ,FQ] value: FQ - auxs: Tuple[FQ, FQ] + auxs: Tuple[FQ] + mpt_counter: FQ # fmt: on def tag(self): return self.keys[0] +class Tables: + """ + Tables used for lookup from the state circuit. + """ + + mpt_table: Set[MPTTableRow] + + def __init__(self, mpt_table: Set[MPTTableRow]): + self.mpt_table = mpt_table + + def mpt_account_lookup( + self, + counter: Expression, + target: Expression, + address: Expression, + value: Expression, + value_prev: Expression, + ) -> MPTTableRow: + query = { + "counter": counter, + "target": target, + "address": address, + "key": FQ(0), + "value": value, + "value_prev": value_prev, + } + return lookup(MPTTableRow, self.mpt_table, query) + + def mpt_storage_lookup( + self, + counter: Expression, + address: Expression, + key: Expression, + value: Expression, + value_prev: Expression, + ) -> MPTTableRow: + query = { + "counter": counter, + "target": FQ(MPTTableTag.Storage), + "address": address, + "key": key, + "value": value, + "value_prev": value_prev, + } + return lookup(MPTTableRow, self.mpt_table, query) + + def linear_combine(limbs: Sequence[FQ], base: FQ) -> FQ: ret = FQ.zero() for limb in reversed(limbs): @@ -108,6 +165,9 @@ def check_start(row: Row, row_prev: Row): # 0. rw_counter is 0 assert row.rw_counter == 0 + # 1. mpt_counter is 0 + assert row.mpt_counter == 0 + @is_circuit_code def check_memory(row: Row, row_prev: Row): @@ -162,25 +222,31 @@ def check_stack(row: Row, row_prev: Row): @is_circuit_code -def check_storage(row: Row, row_prev: Row): +def check_storage(row: Row, row_prev: Row, tables: Tables): get_addr = lambda row: row.keys[2] get_storage_key = lambda row: row.keys[4] - - # TODO: cold VS warm - # TODO: connection to MPT on first and last access for each (address, key) + get_committed_value = lambda row: row.auxs[0] # 0. Unused keys are 0 - assert row.keys[1] == 0 assert row.keys[3] == 0 - # 1. First access for a set of all keys - # - # We add an extra write to set the value of the state in previous block, with rwc=0. + # 1. When keys don't change, committed_value must be kept equal + if all_keys_eq(row, row_prev): + assert get_committed_value(row) == get_committed_value(row_prev) + + # TODO: The current spec does an MPT lookup for every storage update. The + # next optimization consists on doing a single lookup merging all updates + # for a given key, using the first and last access values. + + # 2. MPT storage lookup with incremental counter # - # When the set of all keys changes (first access of storage (address, key)) - # - It must be a WRITE - if not all_keys_eq(row, row_prev): - assert row.is_write == 1 and row.rw_counter == 0 + # When the keys are equal in the previous row, the value_prev must be the + # value in previous row. When the keys change, value_prev is loaded from + # committed_value, which holds the storage value before the tx began. + value_prev = row_prev.value if all_keys_eq(row, row_prev) else get_committed_value(row) + tables.mpt_storage_lookup( + row.mpt_counter, get_addr(row), get_storage_key(row), row.value, value_prev + ) @is_circuit_code @@ -196,22 +262,32 @@ def check_call_context(row: Row, row_prev: Row): @is_circuit_code -def check_account(row: Row, row_prev: Row): +def check_account(row: Row, row_prev: Row, tables: Tables): get_addr = lambda row: row.keys[2] get_field_tag = lambda row: row.keys[3] + get_committed_value = lambda row: row.auxs[0] # 0. Unused keys are 0 assert row.keys[1] == 0 assert row.keys[4] == 0 - # 1. First access for a set of all keys - # - # We add an extra write to setup the value of the previous block, with rwc=0. + # 1. When keys don't change, committed_value must be kept equal + if all_keys_eq(row, row_prev): + assert get_committed_value(row) == get_committed_value(row_prev) + + # TODO: The current spec does an MPT lookup for every storage update. The + # next optimization consists on doing a single lookup merging all updates + # for a given key, using the first and last access values. + + # 2. MPT storage lookup with incremental counter # - # When the set of all keys changes (first access of storage (address, AccountFieldTag)) - # - It must be a WRITE - if not all_keys_eq(row, row_prev): - assert row.is_write == 1 and row.rw_counter == 0 + # When the keys are equal in the previous row, the value_prev must be the + # value in previous row. When the keys change, value_prev is loaded from + # committed_value, which holds the account value before the block began. + value_prev = row_prev.value if all_keys_eq(row, row_prev) else get_committed_value(row) + tables.mpt_account_lookup( + row.mpt_counter, get_field_tag(row), get_addr(row), row.value, value_prev + ) # NOTE: Value transition rules are constrained via the EVM circuit: for example, # Nonce only increases by 1 or decreases by 1 (on revert). @@ -227,6 +303,7 @@ def check_tx_refund(row: Row, row_prev: Row): assert row.keys[4] == 0 # TODO: Missing constraints + # - When keys change, value must be 0 @is_circuit_code @@ -239,6 +316,7 @@ def check_tx_access_list_account(row: Row, row_prev: Row): assert row.keys[4] == 0 # TODO: Missing constraints + # - When keys change, value must be 0 @is_circuit_code @@ -251,6 +329,7 @@ def check_tx_access_list_account_storage(row: Row, row_prev: Row): assert row.keys[3] == 0 # TODO: Missing constraints + # - When keys change, value must be 0 @is_circuit_code @@ -263,6 +342,7 @@ def check_account_destructed(row: Row, row_prev: Row): assert row.keys[4] == 0 # TODO: Missing constraints + # - When keys change, value must be 0 @is_circuit_code @@ -329,7 +409,7 @@ def check_tx_receipt(row: Row, row_prev: Row): @is_circuit_code -def check_state_row(row: Row, row_prev: Row, randomness: FQ): +def check_state_row(row: Row, row_prev: Row, tables: Tables, randomness: FQ): # # Constraints that affect all rows, no matter which Tag they use # @@ -400,6 +480,20 @@ def get_keys_compressed_in_order(row: Row) -> List[FQ]: if row.is_write == 0 and all_keys_eq(row, row_prev): assert row.value == row_prev.value + # 7. Increment mpt_counter + # + # When row is Storage or Account, increment the mpt_counter by + # one, otherwise maintain the same value + if row.tag() != Tag.Start: + if row.tag() == Tag.Storage or row.tag() == Tag.Account: + assert row.mpt_counter == row_prev.mpt_counter + 1 + else: + assert row.mpt_counter == row_prev.mpt_counter + + # 8. RWC !=0 except for Tag.Start + if row.tag() != Tag.Start: + assert row.rw_counter != 0 + # # Constraints specific to each Tag # @@ -410,11 +504,11 @@ def get_keys_compressed_in_order(row: Row) -> List[FQ]: elif row.tag() == Tag.Stack: check_stack(row, row_prev) elif row.tag() == Tag.Storage: - check_storage(row, row_prev) + check_storage(row, row_prev, tables) elif row.tag() == Tag.CallContext: check_call_context(row, row_prev) elif row.tag() == Tag.Account: - check_account(row, row_prev) + check_account(row, row_prev, tables) elif row.tag() == Tag.TxRefund: check_tx_refund(row, row_prev) elif row.tag() == Tag.TxAccessListAccountStorage: @@ -446,7 +540,6 @@ class Operation(NamedTuple): key4: U256 value: FQ aux0: FQ - aux1: FQ class StartOp(Operation): @@ -458,7 +551,7 @@ def __new__(self): # fmt: off return super().__new__(self, 0, 0, U256(Tag.Start), U256(0), U256(0), U256(0), U256(0), # keys - FQ(0), FQ(0), FQ(0)) # values + FQ(0), FQ(0)) # values # fmt: on @@ -476,7 +569,7 @@ def __new__(self, rw_counter: int, rw: RW, call_id: int, mem_addr: U160, value: # fmt: off return super().__new__(self, rw_counter, rw, U256(Tag.Memory), U256(call_id), U256(mem_addr), U256(0), U256(0), # keys - FQ(value), FQ(0), FQ(0)) # values + FQ(value), FQ(0)) # values # fmt: on @@ -489,7 +582,7 @@ def __new__(self, rw_counter: int, rw: RW, call_id: int, stack_ptr: int, value: # fmt: off return super().__new__(self, rw_counter, rw, U256(Tag.Stack), U256(call_id), U256(stack_ptr), U256(0), U256(0), # keys - value, FQ(0), FQ(0)) # values + value, FQ(0)) # values # fmt: on @@ -498,11 +591,20 @@ class StorageOp(Operation): Storage Operation """ - def __new__(self, rw_counter: int, rw: RW, addr: U160, key: U256, value: FQ): + def __new__( + self, + rw_counter: int, + rw: RW, + tx_id: int, + addr: U160, + key: U256, + value: FQ, + committed_value: FQ, + ): # fmt: off return super().__new__(self, rw_counter, rw, - U256(Tag.Storage), U256(0), U256(addr), U256(0), U256(key), # keys - value, FQ(0), FQ(0)) # values + U256(Tag.Storage), U256(tx_id), U256(addr), U256(0), U256(key), # keys + value, committed_value) # values # fmt: on @@ -517,7 +619,7 @@ def __new__( # fmt: off return super().__new__(self, rw_counter, rw, U256(Tag.CallContext), U256(call_id), U256(0), U256(field_tag), U256(0), # keys - value, FQ(0), FQ(0)) # values + value, FQ(0)) # values # fmt: on @@ -526,11 +628,19 @@ class AccountOp(Operation): Account Operation """ - def __new__(self, rw_counter: int, rw: RW, addr: U160, field_tag: AccountFieldTag, value: FQ): + def __new__( + self, + rw_counter: int, + rw: RW, + addr: U160, + field_tag: AccountFieldTag, + value: FQ, + committed_value: FQ, + ): # fmt: off return super().__new__(self, rw_counter, rw, U256(Tag.Account), U256(0), U256(addr), U256(field_tag), U256(0), # keys - value, FQ(0), FQ(0)) # values + value, committed_value) # values # fmt: on @@ -543,7 +653,7 @@ def __new__(self, rw_counter: int, rw: RW, tx_id: int, value: FQ): # fmt: off return super().__new__(self, rw_counter, rw, U256(Tag.TxRefund), U256(tx_id), U256(0), U256(0), U256(0), # keys - value, FQ(0), FQ(0)) # values + value, FQ(0)) # values # fmt: on @@ -556,7 +666,7 @@ def __new__(self, rw_counter: int, rw: RW, tx_id: int, addr: U160, value: FQ): # fmt: off return super().__new__(self, rw_counter, rw, U256(Tag.TxAccessListAccount), U256(tx_id), U256(addr), U256(0), U256(0), # keys - value, FQ(0), FQ(0)) # values + value, FQ(0)) # values # fmt: on @@ -570,7 +680,7 @@ def __new__(self, rw_counter: int, rw: RW, tx_id: int, addr: U160, key: U256, va return super().__new__(self, rw_counter, rw, U256(Tag.TxAccessListAccountStorage), U256(tx_id), U256(addr), U256(0), U256(key), # keys - value, FQ(0), FQ(0)) # values + value, FQ(0)) # values # fmt: on @@ -583,7 +693,7 @@ def __new__(self, rw_counter: int, rw: RW, addr: U160, value: FQ): # fmt: off return super().__new__(self, rw_counter, rw, U256(Tag.AccountDestructed), U256(0), U256(addr), U256(0), U256(0), # keys - value, FQ(0), FQ(0)) # values + value, FQ(0)) # values # fmt: on @@ -605,7 +715,7 @@ def __new__( # fmt: off return super().__new__(self, rw_counter, rw, U256(Tag.TxLog),U256(tx_id), U256(log_id), U256(field_tag), U256(index), # keys - value, FQ(0), FQ(0)) # values + value, FQ(0)) # values # fmt: on @@ -618,32 +728,43 @@ def __new__(self, rw_counter: int, rw: RW, tx_id: int, field_tag: TxReceiptField # fmt: off return super().__new__(self, rw_counter, rw, U256(Tag.TxReceipt), U256(tx_id), U256(0), U256(field_tag), U256(0), # keys - value, FQ(0), FQ(0)) # values + value, FQ(0)) # values # fmt: on -def op2row(op: Operation, randomness: FQ) -> Row: - rw_counter = FQ(op.rw_counter) - is_write = FQ(0) if op.rw == RW.Read else FQ(1) - key0 = FQ(op.key0) - key1 = FQ(op.key1) - key2 = FQ(op.key2) - key2_bytes = op.key2.to_bytes(20, "little") - key2_limbs = tuple([FQ(key2_bytes[i] + 2**8 * key2_bytes[i + 1]) for i in range(0, 20, 2)]) - key3 = FQ(op.key3) - key4_rlc = RLC(op.key4, randomness) - key4 = key4_rlc.expr() - key4_bytes = tuple([FQ(x) for x in key4_rlc.le_bytes]) - value = FQ(op.value) - aux0 = FQ(op.aux0) - aux1 = FQ(op.aux1) +class Assigner: + mpt_counter: FQ + + def __init__(self): + self.mpt_counter = FQ(0) + + def op2row(self, op: Operation, randomness: FQ) -> Row: + rw_counter = FQ(op.rw_counter) + is_write = FQ(0) if op.rw == RW.Read else FQ(1) + key0 = FQ(op.key0) + key1 = FQ(op.key1) + key2 = FQ(op.key2) + key2_bytes = op.key2.to_bytes(20, "little") + key2_limbs = tuple( + [FQ(key2_bytes[i] + 2**8 * key2_bytes[i + 1]) for i in range(0, 20, 2)] + ) + key3 = FQ(op.key3) + key4_rlc = RLC(op.key4, randomness) + key4 = key4_rlc.expr() + key4_bytes = tuple([FQ(x) for x in key4_rlc.le_bytes]) + value = FQ(op.value) + aux0 = FQ(op.aux0) + + if key0 == FQ(Tag.Storage) or key0 == FQ(Tag.Account): + self.mpt_counter += 1 - # fmt: off - return Row(rw_counter, is_write, - # keys - (key0, key1, key2, key3, key4), key2_limbs, key4_bytes, # type: ignore - value, (aux0, aux1)) # values - # fmt: on + # fmt: off + return Row(rw_counter, is_write, + # keys + (key0, key1, key2, key3, key4), key2_limbs, key4_bytes, # type: ignore + value, (aux0,), # values + self.mpt_counter) + # fmt: on # def rw_table_tag2tag(tag: RWTableTag) -> FQ: @@ -673,5 +794,42 @@ def op2row(op: Operation, randomness: FQ) -> Row: # Generate the advice Rows from a list of Operations def assign_state_circuit(ops: List[Operation], randomness: FQ) -> List[Row]: - rows = [op2row(op, randomness) for op in ops] + assigner = Assigner() + rows = [assigner.op2row(op, randomness) for op in ops] return rows + + +def mpt_table_from_ops( + ops_or_rows: Union[List[Operation], List[Row]], randomness: FQ +) -> Set[MPTTableRow]: + if isinstance(ops_or_rows[0], Operation): + rows = assign_state_circuit(cast(List[Operation], ops_or_rows), randomness) + else: + rows = cast(List[Row], ops_or_rows) + + mpt_rows = [] + for (idx, row) in enumerate(rows): + value_prev = row.auxs[0] + if idx > 0: + row_prev = rows[idx - 1] + if all_keys_eq(row, row_prev): + value_prev = row_prev.value + + if row.keys[0] == FQ(Tag.Storage): + mpt_rows.append( + MPTTableRow( + row.mpt_counter, + FQ(MPTTableTag.Storage), + row.keys[2], + row.keys[4], + row.value, + value_prev, + ) + ) + elif row.keys[0] == FQ(Tag.Account): + mpt_rows.append( + MPTTableRow( + row.mpt_counter, row.keys[3], row.keys[2], row.keys[4], row.value, value_prev + ) + ) + return set(mpt_rows) diff --git a/tests/test_state_circuit.py b/tests/test_state_circuit.py index 062ea4477..ddd8d7811 100644 --- a/tests/test_state_circuit.py +++ b/tests/test_state_circuit.py @@ -7,8 +7,18 @@ randomness = rand_fq() r = randomness + +def rlc(v: int) -> FQ: + return RLC(v, r).expr() + + # Verify the state circuit with the given data -def verify(ops_or_rows: Union[List[Operation], List[Row]], randomness: FQ, success: bool = True): +def verify( + ops_or_rows: Union[List[Operation], List[Row]], + tables: Tables, + randomness: FQ, + success: bool = True, +): rows = ops_or_rows if isinstance(ops_or_rows[0], Operation): rows = assign_state_circuit(ops_or_rows, randomness) @@ -16,7 +26,7 @@ def verify(ops_or_rows: Union[List[Operation], List[Row]], randomness: FQ, succe for (idx, row) in enumerate(rows): row_prev = rows[(idx - 1) % len(rows)] try: - check_state_row(row, row_prev, randomness) + check_state_row(row, row_prev, tables, randomness) except AssertionError as e: if success: traceback.print_exc() @@ -38,21 +48,18 @@ def test_state_ok(): MemoryOp(rw_counter=2, rw=RW.Write, call_id=1, mem_addr=0, value=42), MemoryOp(rw_counter=3, rw=RW.Read, call_id=1, mem_addr=0, value=42), - StackOp(rw_counter=4, rw=RW.Write, call_id=1, stack_ptr=1022, value=RLC(4321 ,r).expr()), - StackOp(rw_counter=5, rw=RW.Write, call_id=1, stack_ptr=1023, value=RLC(533 ,r).expr()), - StackOp(rw_counter=6, rw=RW.Read, call_id=1, stack_ptr=1023, value=RLC(533 ,r).expr()), + StackOp(rw_counter=4, rw=RW.Write, call_id=1, stack_ptr=1022, value=rlc(4321)), + StackOp(rw_counter=5, rw=RW.Write, call_id=1, stack_ptr=1023, value=rlc(533)), + StackOp(rw_counter=6, rw=RW.Read, call_id=1, stack_ptr=1023, value=rlc(533)), - StorageOp(rw_counter=0, rw=RW.Write, addr=0x12345678, key=0x1516, value=RLC(789, r).expr()), - StorageOp(rw_counter=7, rw=RW.Read, addr=0x12345678, key=0x1516, value=RLC(789, r).expr()), - StorageOp(rw_counter=0, rw=RW.Write, addr=0x12345678, key=0x4959, value=RLC(98765, r).expr()), - StorageOp(rw_counter=8, rw=RW.Write, addr=0x12345678, key=0x4959, value=RLC(38491, r).expr()), + StorageOp(rw_counter=7, rw=RW.Read, tx_id=1, addr=0x12345678, key=0x1516, value=rlc(789), committed_value=rlc(789)), + StorageOp(rw_counter=8, rw=RW.Write, tx_id=1, addr=0x12345678, key=0x4959, value=rlc(38491), committed_value=rlc(98765)), CallContextOp(rw_counter= 9, rw=RW.Read, call_id=1, field_tag=CallContextFieldTag.IsStatic, value=FQ(0)), CallContextOp(rw_counter=10, rw=RW.Read, call_id=2, field_tag=CallContextFieldTag.IsStatic, value=FQ(0)), - AccountOp(rw_counter= 0, rw=RW.Write, addr=0x12345678, field_tag=AccountFieldTag.Nonce, value=FQ(0)), - AccountOp(rw_counter=12, rw=RW.Write, addr=0x12345678, field_tag=AccountFieldTag.Nonce, value=FQ(1)), - AccountOp(rw_counter=13, rw=RW.Read, addr=0x12345678, field_tag=AccountFieldTag.Nonce, value=FQ(1)), + AccountOp(rw_counter=12, rw=RW.Write, addr=0x12345678, field_tag=AccountFieldTag.Nonce, value=FQ(1), committed_value=FQ(0)), + AccountOp(rw_counter=13, rw=RW.Read, addr=0x12345678, field_tag=AccountFieldTag.Nonce, value=FQ(1), committed_value=FQ(0)), TxRefundOp(rw_counter=14, rw=RW.Write, tx_id=1, value=FQ(1)), TxRefundOp(rw_counter=15, rw=RW.Write, tx_id=1, value=FQ(1)), @@ -84,7 +91,8 @@ def test_state_ok(): TxReceiptOp(rw_counter=37, rw=RW.Read, tx_id=2, field_tag=TxReceiptFieldTag.CumulativeGasUsed, value=FQ(500)), ] # fmt: on - verify(ops, randomness) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness) def test_state_bad_key2(): @@ -95,62 +103,71 @@ def test_state_bad_key2(): ] # fmt: on rows = assign_state_circuit(ops, r) + # key2 doesn't match its limbs rows[1] = rows[1]._replace(key2_limbs=(FQ(1),) * 10) - verify(rows, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(rows, tables, randomness, success=False) def test_state_bad_key4(): # fmt: off ops = [ StartOp(), - StorageOp(rw_counter=0, rw=RW.Write, addr=0x12345678, key=0x15161718, value=RLC(789, r).expr()), + StorageOp(rw_counter=1, rw=RW.Write, tx_id=1, addr=0x12345678, key=0x15161718, value=rlc(789), committed_value=rlc(789)), ] # fmt: on rows = assign_state_circuit(ops, r) + # key4 doesn't match its bytes rows[1] = rows[1]._replace(key4_bytes=(FQ(1),) * 10) - verify(rows, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(rows, tables, randomness, success=False) def test_state_bad_is_write(): # fmt: off ops = [ StartOp(), - StorageOp(rw_counter=0, rw=RW.Write, addr=0x12345678, key=0x15161718, value=RLC(789, r).expr()), + StorageOp(rw_counter=1, rw=RW.Write, tx_id=1, addr=0x12345678, key=0x15161718, value=rlc(789), committed_value=rlc(789)), ] # fmt: on rows = assign_state_circuit(ops, r) + # is_write not boolean rows[1] = rows[1]._replace(is_write=FQ(2)) - verify(rows, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(rows, tables, randomness, success=False) def test_state_keys_non_lexicographic_order(): # fmt: off ops = [ StartOp(), - StorageOp(rw_counter=0, rw=RW.Write, addr=0x12345678, key=0x1112, value=RLC(98765, r).expr()), - StorageOp(rw_counter=0, rw=RW.Write, addr=0x12345678, key=0x1111, value=RLC(789, r).expr()), + StorageOp(rw_counter=1, rw=RW.Write, tx_id=1, addr=0x12345678, key=0x1112, value=rlc(98765), committed_value=rlc(98765)), + StorageOp(rw_counter=1, rw=RW.Write, tx_id=1, addr=0x12345678, key=0x1111, value=rlc(789), committed_value=rlc(98765)), ] # fmt: on - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) # fmt: off ops = [ StartOp(), - StorageOp(rw_counter=0, rw=RW.Write, addr=0x12345678, key=2 << 250, value=RLC(98765, r).expr()), - StorageOp(rw_counter=0, rw=RW.Write, addr=0x12345678, key=1 << 250, value=RLC(789, r).expr()), + StorageOp(rw_counter=1, rw=RW.Write, tx_id=1, addr=0x12345678, key=2 << 250, value=rlc(98765), committed_value=rlc(98765)), + StorageOp(rw_counter=1, rw=RW.Write, tx_id=1, addr=0x12345678, key=1 << 250, value=rlc(789), committed_value=rlc(98765)), ] # fmt: on - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) # fmt: off ops = [ StartOp(), - StorageOp(rw_counter=0, rw=RW.Write, addr=0x12345678, key=123, value=RLC(98765, r).expr()), - StorageOp(rw_counter=1, rw=RW.Write, addr=0x12345678, key=123, value=RLC(789, r).expr()), + StorageOp(rw_counter=1, rw=RW.Write, tx_id=1, addr=0x12345678, key=123, value=rlc(98765), committed_value=rlc(98765)), + StorageOp(rw_counter=1, rw=RW.Write, tx_id=1, addr=0x12345678, key=123, value=rlc(789), committed_value=rlc(98765)), MemoryOp(rw_counter=2, rw=RW.Read, call_id=1, mem_addr=0, value=0), ] # fmt: on - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) # fmt: off ops = [ @@ -159,22 +176,26 @@ def test_state_keys_non_lexicographic_order(): MemoryOp(rw_counter=2, rw=RW.Read, call_id=1, mem_addr=0, value=0), ] # fmt: on - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) def test_state_bad_rwc(): # fmt: off + # rwc decreases ops = [ StartOp(), MemoryOp(rw_counter=2, rw=RW.Read, call_id=2, mem_addr=123, value=0), MemoryOp(rw_counter=1, rw=RW.Read, call_id=2, mem_addr=123, value=0), ] # fmt: on - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) def test_state_bad_read_consistency(): # fmt: off + # Read a 0 after writing a 8 ops = [ StartOp(), MemoryOp(rw_counter=1, rw=RW.Read, call_id=2, mem_addr=123, value=0), @@ -182,7 +203,8 @@ def test_state_bad_read_consistency(): MemoryOp(rw_counter=3, rw=RW.Read, call_id=2, mem_addr=123, value=0), ] # fmt: on - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) def test_start_bad(): @@ -193,8 +215,10 @@ def test_start_bad(): ] # fmt: on rows = assign_state_circuit(ops, r) + # rw_counter is 1 on Tag.Start rows[0] = rows[0]._replace(rw_counter=FQ(1)) - verify(rows, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(rows, tables, randomness, success=False) def first_memory_op(rw_counter=1, rw=RW.Write, call_id=1, mem_addr=2**32 - 1, value=3): @@ -203,89 +227,66 @@ def first_memory_op(rw_counter=1, rw=RW.Write, call_id=1, mem_addr=2**32 - 1, va def test_first_memory_op_ok(): ops = [StartOp(), first_memory_op()] - verify(ops, randomness, success=True) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=True) def test_memory_bad_address(): + # memory address too big ops = [StartOp(), first_memory_op(mem_addr=2**32)] - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) def test_memory_bad_first_access(): + # first access is a read but value != 0 ops = [StartOp(), first_memory_op(rw=RW.Read)] - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) def test_memory_bad_value_range(): + # memory value too big ops = [StartOp(), first_memory_op(value=2**8)] - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) def test_stack_bad_first_access(): # fmt: off + # first stack operation is read ops = [ StartOp(), - StackOp(rw_counter=1, rw=RW.Read, call_id=1, stack_ptr=1023, value=RLC(4321 ,r).expr()), + StackOp(rw_counter=1, rw=RW.Read, call_id=1, stack_ptr=1023, value=rlc(4321)), ] # fmt: on - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) def test_stack_bad_stack_ptr_range(): # fmt: off + # stack pointer is too big ops = [ StartOp(), - StackOp(rw_counter=1, rw=RW.Write, call_id=1, stack_ptr=1024, value=RLC(4321 ,r).expr()), + StackOp(rw_counter=1, rw=RW.Write, call_id=1, stack_ptr=1024, value=rlc(4321)), ] # fmt: on - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) def test_stack_bad_stack_ptr_inc(): # fmt: off + # stack pointer increases by 2 ops = [ StartOp(), - StackOp(rw_counter=1, rw=RW.Write, call_id=1, stack_ptr=1021, value=RLC(4321 ,r).expr()), - StackOp(rw_counter=2, rw=RW.Write, call_id=1, stack_ptr=1023, value=RLC(4321 ,r).expr()), - ] - # fmt: on - verify(ops, randomness, success=False) - - -def test_storage_bad_first_access(): - # fmt: off - ops = [ - StartOp(), - StorageOp(rw_counter=0, rw=RW.Read, addr=0x12345678, key=0x1516, value=RLC(789, r).expr()), - ] - # fmt: on - verify(ops, randomness, success=False) - - # fmt: off - ops = [ - StartOp(), - StorageOp(rw_counter=1, rw=RW.Write, addr=0x12345678, key=0x1516, value=RLC(789, r).expr()), - ] - # fmt: on - verify(ops, randomness, success=False) - - -def test_account_bad_first_access(): - # fmt: off - ops = [ - StartOp(), - AccountOp(rw_counter= 0, rw=RW.Read, addr=0x12345678, field_tag=AccountFieldTag.Nonce, value=FQ(0)), - ] - # fmt: on - verify(ops, randomness, success=False) - - # fmt: off - ops = [ - StartOp(), - AccountOp(rw_counter=1, rw=RW.Write, addr=0x12345678, field_tag=AccountFieldTag.Nonce, value=FQ(0)), + StackOp(rw_counter=1, rw=RW.Write, call_id=1, stack_ptr=1021, value=rlc(4321)), + StackOp(rw_counter=2, rw=RW.Write, call_id=1, stack_ptr=1023, value=rlc(4321)), ] # fmt: on - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) def test_tx_log_bad(): @@ -298,7 +299,8 @@ def test_tx_log_bad(): TxLogOp(rw_counter=3, rw=RW.Write, tx_id=1, log_id=0, field_tag=TxLogFieldTag.Topic, index=0, value=FQ(5)), ] # fmt: on - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) # fmt: off # topic index out of range >= 4 @@ -312,7 +314,8 @@ def test_tx_log_bad(): TxLogOp(rw_counter=6, rw=RW.Write, tx_id=1, log_id=0, field_tag=TxLogFieldTag.Topic, index=4, value=FQ(5)), ] # fmt: on - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) # fmt: off # Data index is not increasing @@ -323,7 +326,8 @@ def test_tx_log_bad(): TxLogOp(rw_counter=3, rw=RW.Write, tx_id=1, log_id=0, field_tag=TxLogFieldTag.Data, index=0, value=FQ(255)), ] # fmt: on - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) # fmt: off # log id is decreasing @@ -334,7 +338,8 @@ def test_tx_log_bad(): TxLogOp(rw_counter=3, rw=RW.Write, tx_id=1, log_id=0, field_tag=TxLogFieldTag.Data, index=0, value=FQ(255)), ] # fmt: on - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) # fmt: off # TxLogFieldTag is decreasing @@ -345,7 +350,8 @@ def test_tx_log_bad(): TxLogOp(rw_counter=3, rw=RW.Write, tx_id=1, log_id=0, field_tag=TxLogFieldTag.Data, index=0, value=FQ(255)), ] # fmt: on - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) # fmt: off # when tx_id change, log_id is not reset @@ -356,7 +362,8 @@ def test_tx_log_bad(): TxLogOp(rw_counter=3, rw=RW.Write, tx_id=2, log_id=1, field_tag=TxLogFieldTag.Data, index=0, value=FQ(255)), ] # fmt: on - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) def test_tx_receipt_bad(): @@ -367,7 +374,8 @@ def test_tx_receipt_bad(): TxReceiptOp(rw_counter=1, rw=RW.Read, tx_id=1, field_tag=TxReceiptFieldTag.PostStateOrStatus, value=FQ(3)), ] # fmt: on - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) # fmt: off # tx_id is decreasing when changes @@ -377,7 +385,8 @@ def test_tx_receipt_bad(): TxReceiptOp(rw_counter=2, rw=RW.Read, tx_id=1, field_tag=TxReceiptFieldTag.CumulativeGasUsed, value=FQ(200)), ] # fmt: on - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) # fmt: off # tx_id is not increasing by one @@ -387,4 +396,57 @@ def test_tx_receipt_bad(): TxReceiptOp(rw_counter=2, rw=RW.Read, tx_id=5, field_tag=TxReceiptFieldTag.CumulativeGasUsed, value=FQ(200)), ] # fmt: on - verify(ops, randomness, success=False) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) + + +def test_rw_counter_zero_bad(): + # fmt: off + # rw_counter is 0 but tag is not Start + ops = [ + StartOp(), + MemoryOp(rw_counter=0, rw=RW.Read, call_id=2, mem_addr=123, value=0), + ] + # fmt: on + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) + + +def test_storage_committed_value_bad(): + # fmt: off + # Committed value changes but keys don't + ops = [ + StartOp(), + StorageOp(rw_counter=1, rw=RW.Write, tx_id=1, addr=0x12345678, key=0x15161718, value=rlc(789), committed_value=rlc(789)), + StorageOp(rw_counter=2, rw=RW.Write, tx_id=1, addr=0x12345678, key=0x15161718, value=rlc(123), committed_value=rlc(123)), + ] + # fmt: on + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness, success=False) + + +def test_mpt_counter_bad(): + # fmt: off + ops = [ + StartOp(), + StorageOp(rw_counter=1, rw=RW.Write, tx_id=1, addr=0x12345678, key=0x15161718, value=rlc(789), committed_value=rlc(789)), + StorageOp(rw_counter=2, rw=RW.Write, tx_id=1, addr=0x12345678, key=0x15161718, value=rlc(123), committed_value=rlc(789)), + ] + # fmt: on + rows = assign_state_circuit(ops, r) + # mpt_counter goes from 1 to 3 + rows[2] = rows[2]._replace(mpt_counter=FQ(3)) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(rows, tables, randomness, success=False) + + # fmt: off + ops = [ + StartOp(), + StackOp(rw_counter=1, rw=RW.Write, call_id=1, stack_ptr=1021, value=rlc(4321)), + ] + # fmt: on + rows = assign_state_circuit(ops, r) + # mpt_counter increases when tag is not Account or Storage + rows[1] = rows[1]._replace(mpt_counter=FQ(1)) + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(rows, tables, randomness, success=False)