Skip to content
This repository has been archived by the owner on Jul 5, 2024. It is now read-only.

Add MPT table lookup to state circuit #200

Merged
merged 4 commits into from
May 23, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions src/zkevm_specs/evm/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,17 @@ class RW(IntEnum):
Write = 1


class MPTTableTag(IntEnum):
"""
Tag for MPTTable lookup
"""

Nonce = 1
Balance = 2
CodeHash = 4
Storage = 8


class RWTableTag(IntEnum):
"""
Tag for RWTable lookup, where the RWTable an advice-column table built by
Expand Down Expand Up @@ -179,9 +190,9 @@ def write_with_reversion(self) -> bool:


class AccountFieldTag(IntEnum):
Nonce = auto()
Balance = auto()
CodeHash = auto()
Nonce = 1
ed255 marked this conversation as resolved.
Show resolved Hide resolved
Balance = 2
CodeHash = 4


class CallContextFieldTag(IntEnum):
Expand Down Expand Up @@ -343,6 +354,16 @@ class RWTableRow(TableRow):
aux1: Expression = field(default=FQ(0))


@dataclass(frozen=True)
class MPTTableRow(TableRow):
counter: Expression
target: Expression # MPTTableTag
address: Expression
key: Expression
value: Expression
value_prev: Expression


class Tables:
"""
A collection of lookup tables used in EVM circuit.
Expand Down Expand Up @@ -391,7 +412,7 @@ def block_lookup(
self, field_tag: Expression, block_number: Expression = FQ(0)
) -> BlockTableRow:
query = {"field_tag": field_tag, "block_number_or_zero": block_number}
return _lookup(BlockTableRow, self.block_table, query)
return lookup(BlockTableRow, self.block_table, query)

def tx_lookup(
self, tx_id: Expression, field_tag: Expression, call_data_index: Expression = FQ(0)
Expand All @@ -401,7 +422,7 @@ def tx_lookup(
"field_tag": field_tag,
"call_data_index_or_zero": call_data_index,
}
return _lookup(TxTableRow, self.tx_table, query)
return lookup(TxTableRow, self.tx_table, query)

def bytecode_lookup(
self,
Expand All @@ -416,7 +437,7 @@ def bytecode_lookup(
"index": index,
"is_code": is_code,
}
return _lookup(BytecodeTableRow, self.bytecode_table, query)
return lookup(BytecodeTableRow, self.bytecode_table, query)

def rw_lookup(
self,
Expand Down Expand Up @@ -445,13 +466,13 @@ def rw_lookup(
"aux0": aux0,
"aux1": aux1,
}
return _lookup(RWTableRow, self.rw_table, query)
return lookup(RWTableRow, self.rw_table, query)


T = TypeVar("T", bound=TableRow)


def _lookup(
def lookup(
table_cls: Type[T],
table: Set[T],
query: Mapping[str, Optional[Expression]],
Expand Down
185 changes: 155 additions & 30 deletions src/zkevm_specs/state.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from typing import NamedTuple, Tuple, List, Sequence
from typing import NamedTuple, Tuple, List, Sequence, Set, Union, cast
from enum import IntEnum
from math import log, ceil

from .util import FQ, RLC, U160, U256
from .util import FQ, RLC, U160, U256, Expression
from .encoding import U8, is_circuit_code
from .evm import RW, AccountFieldTag, CallContextFieldTag, TxLogFieldTag, TxReceiptFieldTag
from .evm import (
RW,
AccountFieldTag,
CallContextFieldTag,
TxLogFieldTag,
TxReceiptFieldTag,
MPTTableRow,
MPTTableTag,
lookup,
)

MAX_MEMORY_ADDRESS = 2**32 - 1
MAX_KEY_DIFF = 2**32 - 1
Expand Down Expand Up @@ -62,12 +71,60 @@ class Row(NamedTuple):
FQ,FQ,FQ,FQ,FQ,FQ,FQ,FQ]
value: FQ
auxs: Tuple[FQ, FQ]
mpt_counter: FQ
# fmt: on

def tag(self):
return self.keys[0]


class Tables:
"""
Tables used for lookup from the state circuit.
"""

mpt_table: Set[MPTTableRow]

def __init__(self, mpt_table: Set[MPTTableRow]):
self.mpt_table = mpt_table

def mpt_account_lookup(
self,
counter: Expression,
target: Expression,
address: Expression,
value: Expression,
value_prev: Expression,
) -> MPTTableRow:
query = {
"counter": counter,
"target": target,
"address": address,
"key": FQ(0),
"value": value,
"value_prev": value_prev,
}
return lookup(MPTTableRow, self.mpt_table, query)

def mpt_storage_lookup(
self,
counter: Expression,
address: Expression,
key: Expression,
value: Expression,
value_prev: Expression,
) -> MPTTableRow:
query = {
"counter": counter,
"target": FQ(MPTTableTag.Storage),
"address": address,
"key": key,
"value": value,
"value_prev": value_prev,
}
return lookup(MPTTableRow, self.mpt_table, query)


def linear_combine(limbs: Sequence[FQ], base: FQ) -> FQ:
ret = FQ.zero()
for limb in reversed(limbs):
Expand Down Expand Up @@ -162,7 +219,7 @@ def check_stack(row: Row, row_prev: Row):


@is_circuit_code
def check_storage(row: Row, row_prev: Row):
def check_storage(row: Row, row_prev: Row, tables: Tables):
get_addr = lambda row: row.keys[2]
get_storage_key = lambda row: row.keys[4]

Expand All @@ -182,6 +239,12 @@ def check_storage(row: Row, row_prev: Row):
if not all_keys_eq(row, row_prev):
assert row.is_write == 1 and row.rw_counter == 0

# 2. MPT storage lookup with incremental counter
value_prev = row_prev.value if all_keys_eq(row, row_prev) else FQ(0)
han0110 marked this conversation as resolved.
Show resolved Hide resolved
tables.mpt_storage_lookup(
row.mpt_counter, get_addr(row), get_storage_key(row), row.value, value_prev
)


@is_circuit_code
def check_call_context(row: Row, row_prev: Row):
Expand All @@ -196,7 +259,7 @@ def check_call_context(row: Row, row_prev: Row):


@is_circuit_code
def check_account(row: Row, row_prev: Row):
def check_account(row: Row, row_prev: Row, tables: Tables):
get_addr = lambda row: row.keys[2]
get_field_tag = lambda row: row.keys[3]

Expand All @@ -213,6 +276,12 @@ def check_account(row: Row, row_prev: Row):
if not all_keys_eq(row, row_prev):
assert row.is_write == 1 and row.rw_counter == 0

# 2. MPT storage lookup with incremental counter
value_prev = row_prev.value if all_keys_eq(row, row_prev) else FQ(0)
tables.mpt_account_lookup(
row.mpt_counter, get_field_tag(row), get_addr(row), row.value, value_prev
)

# NOTE: Value transition rules are constrained via the EVM circuit: for example,
# Nonce only increases by 1 or decreases by 1 (on revert).

Expand Down Expand Up @@ -329,7 +398,7 @@ def check_tx_receipt(row: Row, row_prev: Row):


@is_circuit_code
def check_state_row(row: Row, row_prev: Row, randomness: FQ):
def check_state_row(row: Row, row_prev: Row, tables: Tables, randomness: FQ):
#
# Constraints that affect all rows, no matter which Tag they use
#
Expand Down Expand Up @@ -400,6 +469,12 @@ def get_keys_compressed_in_order(row: Row) -> List[FQ]:
if row.is_write == 0 and all_keys_eq(row, row_prev):
assert row.value == row_prev.value

# 7. Increment mpt_counter
#
# When previous row is Storage or Account, increment the mpt_counter by one
if row_prev.tag() == Tag.Storage or row_prev.tag() == Tag.Account:
assert row.mpt_counter == row_prev.mpt_counter + 1

han0110 marked this conversation as resolved.
Show resolved Hide resolved
#
# Constraints specific to each Tag
#
Expand All @@ -410,11 +485,11 @@ def get_keys_compressed_in_order(row: Row) -> List[FQ]:
elif row.tag() == Tag.Stack:
check_stack(row, row_prev)
elif row.tag() == Tag.Storage:
check_storage(row, row_prev)
check_storage(row, row_prev, tables)
elif row.tag() == Tag.CallContext:
check_call_context(row, row_prev)
elif row.tag() == Tag.Account:
check_account(row, row_prev)
check_account(row, row_prev, tables)
elif row.tag() == Tag.TxRefund:
check_tx_refund(row, row_prev)
elif row.tag() == Tag.TxAccessListAccountStorage:
Expand Down Expand Up @@ -622,28 +697,41 @@ def __new__(self, rw_counter: int, rw: RW, tx_id: int, field_tag: TxReceiptField
# fmt: on


def op2row(op: Operation, randomness: FQ) -> Row:
rw_counter = FQ(op.rw_counter)
is_write = FQ(0) if op.rw == RW.Read else FQ(1)
key0 = FQ(op.key0)
key1 = FQ(op.key1)
key2 = FQ(op.key2)
key2_bytes = op.key2.to_bytes(20, "little")
key2_limbs = tuple([FQ(key2_bytes[i] + 2**8 * key2_bytes[i + 1]) for i in range(0, 20, 2)])
key3 = FQ(op.key3)
key4_rlc = RLC(op.key4, randomness)
key4 = key4_rlc.expr()
key4_bytes = tuple([FQ(x) for x in key4_rlc.le_bytes])
value = FQ(op.value)
aux0 = FQ(op.aux0)
aux1 = FQ(op.aux1)
class Assigner:
mpt_counter: FQ

def __init__(self):
self.mpt_counter = FQ(1)

def op2row(self, op: Operation, randomness: FQ) -> Row:
rw_counter = FQ(op.rw_counter)
is_write = FQ(0) if op.rw == RW.Read else FQ(1)
key0 = FQ(op.key0)
key1 = FQ(op.key1)
key2 = FQ(op.key2)
key2_bytes = op.key2.to_bytes(20, "little")
key2_limbs = tuple(
[FQ(key2_bytes[i] + 2**8 * key2_bytes[i + 1]) for i in range(0, 20, 2)]
)
key3 = FQ(op.key3)
key4_rlc = RLC(op.key4, randomness)
key4 = key4_rlc.expr()
key4_bytes = tuple([FQ(x) for x in key4_rlc.le_bytes])
value = FQ(op.value)
aux0 = FQ(op.aux0)
aux1 = FQ(op.aux1)
mpt_counter = self.mpt_counter

if key0 == FQ(Tag.Storage) or key0 == FQ(Tag.Account):
self.mpt_counter += 1

# fmt: off
return Row(rw_counter, is_write,
# keys
(key0, key1, key2, key3, key4), key2_limbs, key4_bytes, # type: ignore
value, (aux0, aux1)) # values
# fmt: on
# fmt: off
return Row(rw_counter, is_write,
# keys
(key0, key1, key2, key3, key4), key2_limbs, key4_bytes, # type: ignore
value, (aux0, aux1), # values
mpt_counter)
# fmt: on


# def rw_table_tag2tag(tag: RWTableTag) -> FQ:
Expand Down Expand Up @@ -673,5 +761,42 @@ def op2row(op: Operation, randomness: FQ) -> Row:

# Generate the advice Rows from a list of Operations
def assign_state_circuit(ops: List[Operation], randomness: FQ) -> List[Row]:
rows = [op2row(op, randomness) for op in ops]
assigner = Assigner()
rows = [assigner.op2row(op, randomness) for op in ops]
return rows


def mpt_table_from_ops(
ops_or_rows: Union[List[Operation], List[Row]], randomness: FQ
) -> Set[MPTTableRow]:
if isinstance(ops_or_rows[0], Operation):
rows = assign_state_circuit(cast(List[Operation], ops_or_rows), randomness)
else:
rows = cast(List[Row], ops_or_rows)

mpt_rows = []
for (idx, row) in enumerate(rows):
value_prev = FQ(0)
if idx > 0:
row_prev = rows[idx - 1]
if all_keys_eq(row, row_prev):
value_prev = row_prev.value

if row.keys[0] == FQ(Tag.Storage):
mpt_rows.append(
MPTTableRow(
row.mpt_counter,
FQ(MPTTableTag.Storage),
row.keys[2],
row.keys[4],
row.value,
value_prev,
)
)
elif row.keys[0] == FQ(Tag.Account):
mpt_rows.append(
MPTTableRow(
row.mpt_counter, row.keys[3], row.keys[2], row.keys[4], row.value, value_prev
)
)
return set(mpt_rows)
Loading