diff --git a/.github/workflows/deploy.yaml b/.github/workflows/deploy.yaml new file mode 100644 index 0000000..b5c1a13 --- /dev/null +++ b/.github/workflows/deploy.yaml @@ -0,0 +1,45 @@ +name: deploy + +on: + push: + branches: [main] + tags: [v*] + +jobs: + deploy: + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + + steps: + - uses: actions/checkout@v3 + + - uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - id: meta + uses: docker/metadata-action@v4 + with: + images: ghcr.io/${{ github.repository }} + labels: | + org.opencontainers.image.licenses=MIT OR Apache-2.0 + - name: Push Project Image + uses: docker/build-push-action@v3 + with: + context: . + file: Dockerfile + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + + - uses: cowprotocol/autodeploy-action@v2 + if: ${{ github.ref == 'refs/heads/main' }} + with: + images: ghcr.io/cowprotocol/token-imbalances:main + url: ${{ secrets.AUTODEPLOY_URL }} + token: ${{ secrets.AUTODEPLOY_TOKEN }} + timeout: 600000 # 10 minutes \ No newline at end of file diff --git a/.github/workflows/pull_request.yaml b/.github/workflows/pull_request.yaml new file mode 100644 index 0000000..378bd13 --- /dev/null +++ b/.github/workflows/pull_request.yaml @@ -0,0 +1,28 @@ +name: pull request +on: + pull_request: + push: + branches: [ main ] +jobs: + python: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Setup Python 3.12 + uses: actions/setup-python@v3 + with: + python-version: '3.12' + - name: Install Requirements + run: + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Set PYTHONPATH + run: echo "PYTHONPATH=src" >> $GITHUB_ENV + - name: Pylint + run: + pylint --fail-under=8 $(git ls-files '**/*.py') + - name: Black + run: + black --check ./ + - name: Type Check (mypy) + run: mypy src diff --git a/contracts/erc20_abi.py b/contracts/erc20_abi.py index 40e7d65..03a62f2 100644 --- a/contracts/erc20_abi.py +++ b/contracts/erc20_abi.py @@ -1,222 +1,117 @@ erc20_abi = [ - { - "constant": True, - "inputs": [], - "name": "name", - "outputs": [ - { - "name": "", - "type": "string" - } - ], - "payable": False, - "stateMutability": "view", - "type": "function" - }, - { - "constant": False, - "inputs": [ - { - "name": "_spender", - "type": "address" - }, - { - "name": "_value", - "type": "uint256" - } - ], - "name": "approve", - "outputs": [ - { - "name": "", - "type": "bool" - } - ], - "payable": False, - "stateMutability": "nonpayable", - "type": "function" - }, - { - "constant": True, - "inputs": [], - "name": "totalSupply", - "outputs": [ - { - "name": "", - "type": "uint256" - } - ], - "payable": False, - "stateMutability": "view", - "type": "function" - }, - { - "constant": False, - "inputs": [ - { - "name": "_from", - "type": "address" - }, - { - "name": "_to", - "type": "address" - }, - { - "name": "_value", - "type": "uint256" - } - ], - "name": "transferFrom", - "outputs": [ - { - "name": "", - "type": "bool" - } - ], - "payable": False, - "stateMutability": "nonpayable", - "type": "function" - }, - { - "constant": True, - "inputs": [], - "name": "decimals", - "outputs": [ - { - "name": "", - "type": "uint8" - } - ], - "payable": False, - "stateMutability": "view", - "type": "function" - }, - { - "constant": True, - "inputs": [ - { - "name": "_owner", - "type": "address" - } - ], - "name": "balanceOf", - "outputs": [ - { - "name": "balance", - "type": "uint256" - } - ], - "payable": False, - "stateMutability": "view", - "type": "function" - }, - { - "constant": True, - "inputs": [], - "name": "symbol", - "outputs": [ - { - "name": "", - "type": "string" - } - ], - "payable": False, - "stateMutability": "view", - "type": "function" - }, - { - "constant": False, - "inputs": [ - { - "name": "_to", - "type": "address" - }, - { - "name": "_value", - "type": "uint256" - } - ], - "name": "transfer", - "outputs": [ - { - "name": "", - "type": "bool" - } - ], - "payable": False, - "stateMutability": "nonpayable", - "type": "function" - }, - { - "constant": True, - "inputs": [ - { - "name": "_owner", - "type": "address" - }, - { - "name": "_spender", - "type": "address" - } - ], - "name": "allowance", - "outputs": [ - { - "name": "", - "type": "uint256" - } - ], - "payable": False, - "stateMutability": "view", - "type": "function" - }, - { - "payable": True, - "stateMutability": "payable", - "type": "fallback" - }, - { - "anonymous": False, - "inputs": [ - { - "indexed": True, - "name": "owner", - "type": "address" - }, - { - "indexed": True, - "name": "spender", - "type": "address" - }, - { - "indexed": False, - "name": "value", - "type": "uint256" - } - ], - "name": "Approval", - "type": "event" - }, - { - "anonymous": False, - "inputs": [ - { - "indexed": True, - "name": "from", - "type": "address" - }, - { - "indexed": True, - "name": "to", - "type": "address" - }, - { - "indexed": False, - "name": "value", - "type": "uint256" - } - ], - "name": "Transfer", - "type": "event" - } -] \ No newline at end of file + { + "constant": True, + "inputs": [], + "name": "name", + "outputs": [{"name": "", "type": "string"}], + "payable": False, + "stateMutability": "view", + "type": "function", + }, + { + "constant": False, + "inputs": [ + {"name": "_spender", "type": "address"}, + {"name": "_value", "type": "uint256"}, + ], + "name": "approve", + "outputs": [{"name": "", "type": "bool"}], + "payable": False, + "stateMutability": "nonpayable", + "type": "function", + }, + { + "constant": True, + "inputs": [], + "name": "totalSupply", + "outputs": [{"name": "", "type": "uint256"}], + "payable": False, + "stateMutability": "view", + "type": "function", + }, + { + "constant": False, + "inputs": [ + {"name": "_from", "type": "address"}, + {"name": "_to", "type": "address"}, + {"name": "_value", "type": "uint256"}, + ], + "name": "transferFrom", + "outputs": [{"name": "", "type": "bool"}], + "payable": False, + "stateMutability": "nonpayable", + "type": "function", + }, + { + "constant": True, + "inputs": [], + "name": "decimals", + "outputs": [{"name": "", "type": "uint8"}], + "payable": False, + "stateMutability": "view", + "type": "function", + }, + { + "constant": True, + "inputs": [{"name": "_owner", "type": "address"}], + "name": "balanceOf", + "outputs": [{"name": "balance", "type": "uint256"}], + "payable": False, + "stateMutability": "view", + "type": "function", + }, + { + "constant": True, + "inputs": [], + "name": "symbol", + "outputs": [{"name": "", "type": "string"}], + "payable": False, + "stateMutability": "view", + "type": "function", + }, + { + "constant": False, + "inputs": [ + {"name": "_to", "type": "address"}, + {"name": "_value", "type": "uint256"}, + ], + "name": "transfer", + "outputs": [{"name": "", "type": "bool"}], + "payable": False, + "stateMutability": "nonpayable", + "type": "function", + }, + { + "constant": True, + "inputs": [ + {"name": "_owner", "type": "address"}, + {"name": "_spender", "type": "address"}, + ], + "name": "allowance", + "outputs": [{"name": "", "type": "uint256"}], + "payable": False, + "stateMutability": "view", + "type": "function", + }, + {"payable": True, "stateMutability": "payable", "type": "fallback"}, + { + "anonymous": False, + "inputs": [ + {"indexed": True, "name": "owner", "type": "address"}, + {"indexed": True, "name": "spender", "type": "address"}, + {"indexed": False, "name": "value", "type": "uint256"}, + ], + "name": "Approval", + "type": "event", + }, + { + "anonymous": False, + "inputs": [ + {"indexed": True, "name": "from", "type": "address"}, + {"indexed": True, "name": "to", "type": "address"}, + {"indexed": False, "name": "value", "type": "uint256"}, + ], + "name": "Transfer", + "type": "event", + }, +] diff --git a/requirements.txt b/requirements.txt index 88a8084..247a178 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,11 @@ -web3==6.0.0 +requests==2.31.0 +web3>=6.0.0 pandas==2.2.1 SQLAlchemy==2.0.28 psycopg2==2.9.9 -python-dotenv==1.0.0 \ No newline at end of file +python-dotenv==1.0.0 +black==23.3.0 +mypy==1.4.1 +pylint==3.2.5 +pytest==7.4.0 +setuptools \ No newline at end of file diff --git a/src/balanceof_imbalances.py b/src/balanceof_imbalances.py index 617fc85..f317f8d 100644 --- a/src/balanceof_imbalances.py +++ b/src/balanceof_imbalances.py @@ -1,9 +1,13 @@ +# mypy: disable-error-code="call-overload, arg-type, operator" import sys import os -# for debugging purposes -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +# for debugging purposes +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from web3 import Web3 +from web3.types import TxReceipt +from eth_typing import ChecksumAddress from typing import Dict, Optional, Set from src.config import ETHEREUM_NODE_URL from src.constants import SETTLEMENT_CONTRACT_ADDRESS, NATIVE_ETH_TOKEN_ADDRESS @@ -11,15 +15,20 @@ # conducting sanity test only for ethereum mainnet transactions + class BalanceOfImbalances: def __init__(self, ETHEREUM_NODE_URL: str): self.web3 = Web3(Web3.HTTPProvider(ETHEREUM_NODE_URL)) - def get_token_balance(self, token_address: str, account: str, block_identifier: int) -> Optional[int]: - """ Retrieve the ERC-20 token balance of an account at a given block. """ + def get_token_balance( + self, token_address: str, account: str, block_identifier: int + ) -> Optional[int]: + """Retrieve the ERC-20 token balance of an account at a given block.""" token_contract = self.web3.eth.contract(address=token_address, abi=erc20_abi) try: - return token_contract.functions.balanceOf(account).call(block_identifier=block_identifier) + return token_contract.functions.balanceOf(account).call( + block_identifier=block_identifier + ) except Exception as e: print(f"Error fetching balance for token {token_address}: {e}") return None @@ -38,14 +47,14 @@ def extract_token_addresses(self, tx_receipt: Dict) -> Set[str]: transfer_topics = { self.web3.keccak(text="Transfer(address,address,uint256)").hex(), self.web3.keccak(text="ERC20Transfer(address,address,uint256)").hex(), - self.web3.keccak(text="Withdrawal(address,uint256)").hex() + self.web3.keccak(text="Withdrawal(address,uint256)").hex(), } - for log in tx_receipt['logs']: - if log['topics'][0].hex() in transfer_topics: - token_addresses.add(log['address']) + for log in tx_receipt["logs"]: + if log["topics"][0].hex() in transfer_topics: + token_addresses.add(log["address"]) return token_addresses - def get_transaction_receipt(self, tx_hash: str) -> Optional[Dict]: + def get_transaction_receipt(self, tx_hash: str) -> Optional[TxReceipt]: """Fetch the transaction receipt for the given hash.""" try: return self.web3.eth.get_transaction_receipt(tx_hash) @@ -53,21 +62,34 @@ def get_transaction_receipt(self, tx_hash: str) -> Optional[Dict]: print(f"Error fetching transaction receipt for hash {tx_hash}: {e}") return None - def get_balances(self, token_addresses: Set[str], block_number: int) -> Dict[str, Optional[int]]: + def get_balances( + self, token_addresses: Set[ChecksumAddress], block_number: int + ) -> Dict[ChecksumAddress, Optional[int]]: """Get balances for all tokens at the given block number.""" balances = {} - balances[NATIVE_ETH_TOKEN_ADDRESS] = self.get_eth_balance(SETTLEMENT_CONTRACT_ADDRESS, block_number) + balances[NATIVE_ETH_TOKEN_ADDRESS] = self.get_eth_balance( + SETTLEMENT_CONTRACT_ADDRESS, block_number + ) for token_address in token_addresses: - balances[token_address] = self.get_token_balance(token_address, SETTLEMENT_CONTRACT_ADDRESS, block_number) + balances[token_address] = self.get_token_balance( + token_address, SETTLEMENT_CONTRACT_ADDRESS, block_number + ) return balances - def calculate_imbalances(self, prev_balances: Dict[str, Optional[int]], final_balances: Dict[str, Optional[int]]) -> Dict[str, int]: + def calculate_imbalances( + self, + prev_balances: Dict[str, Optional[int]], + final_balances: Dict[str, Optional[int]], + ) -> Dict[str, int]: """Calculate imbalances between previous and final balances.""" imbalances = {} for token_address in prev_balances: - if prev_balances[token_address] is not None and final_balances[token_address] is not None: + if ( + prev_balances[token_address] is not None + and final_balances[token_address] is not None + ): imbalance = final_balances[token_address] - prev_balances[token_address] imbalances[token_address] = imbalance return imbalances @@ -83,14 +105,15 @@ def compute_imbalances(self, tx_hash: str) -> Dict[str, int]: print("No tokens involved in this transaction.") return {} - prev_block = tx_receipt['blockNumber'] - 1 - final_block = tx_receipt['blockNumber'] + prev_block = tx_receipt["blockNumber"] - 1 + final_block = tx_receipt["blockNumber"] prev_balances = self.get_balances(token_addresses, prev_block) final_balances = self.get_balances(token_addresses, final_block) return self.calculate_imbalances(prev_balances, final_balances) + def main(): tx_hash = input("Enter transaction hash: ") bo = BalanceOfImbalances(ETHEREUM_NODE_URL) @@ -99,5 +122,6 @@ def main(): for token_address, imbalance in imbalances.items(): print(f"Token: {token_address}, Imbalance: {imbalance}") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/config.py b/src/config.py index d668d68..3d2e71b 100644 --- a/src/config.py +++ b/src/config.py @@ -2,15 +2,9 @@ from dotenv import load_dotenv load_dotenv() -ETHEREUM_NODE_URL = os.getenv('ETHEREUM_NODE_URL') -GNOSIS_NODE_URL = os.getenv('GNOSIS_NODE_URL') +ETHEREUM_NODE_URL = os.getenv("ETHEREUM_NODE_URL") +GNOSIS_NODE_URL = os.getenv("GNOSIS_NODE_URL") -CHAIN_RPC_ENDPOINTS = { - 'Ethereum': ETHEREUM_NODE_URL, - 'Gnosis': GNOSIS_NODE_URL -} +CHAIN_RPC_ENDPOINTS = {"Ethereum": ETHEREUM_NODE_URL, "Gnosis": GNOSIS_NODE_URL} -CHAIN_SLEEP_TIMES = { - 'Ethereum': 60, - 'Gnosis': 120 -} \ No newline at end of file +CHAIN_SLEEP_TIMES = {"Ethereum": 60, "Gnosis": 120} diff --git a/src/constants.py b/src/constants.py index 5a0e951..bbe179e 100644 --- a/src/constants.py +++ b/src/constants.py @@ -1,6 +1,14 @@ from web3 import Web3 -SETTLEMENT_CONTRACT_ADDRESS = Web3.to_checksum_address('0x9008D19f58AAbD9eD0D60971565AA8510560ab41') -NATIVE_ETH_TOKEN_ADDRESS = Web3.to_checksum_address('0xeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee') -WETH_TOKEN_ADDRESS = Web3.to_checksum_address('0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2') -SDAI_TOKEN_ADDRESS = Web3.to_checksum_address('0x83F20F44975D03b1b09e64809B757c47f942BEeA') +SETTLEMENT_CONTRACT_ADDRESS = Web3.to_checksum_address( + "0x9008D19f58AAbD9eD0D60971565AA8510560ab41" +) +NATIVE_ETH_TOKEN_ADDRESS = Web3.to_checksum_address( + "0xeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee" +) +WETH_TOKEN_ADDRESS = Web3.to_checksum_address( + "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2" +) +SDAI_TOKEN_ADDRESS = Web3.to_checksum_address( + "0x83F20F44975D03b1b09e64809B757c47f942BEeA" +) diff --git a/src/daemon.py b/src/daemon.py index ad7333a..d1ff665 100644 --- a/src/daemon.py +++ b/src/daemon.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="import, arg-type" import os import time import pandas as pd @@ -9,22 +10,28 @@ from src.imbalances_script import RawTokenImbalances from src.config import CHAIN_RPC_ENDPOINTS, CHAIN_SLEEP_TIMES + def get_web3_instance(chain_name: str) -> Web3: return Web3(Web3.HTTPProvider(CHAIN_RPC_ENDPOINTS[chain_name])) + def get_finalized_block_number(web3: Web3) -> int: return web3.eth.block_number - 64 + def create_db_connection(chain_name: str): - """function that creates a connection to the CoW db.""" - if chain_name == 'Ethereum': + """function that creates a connection to the CoW db.""" + if chain_name == "Ethereum": db_url = os.getenv("ETHEREUM_DB_URL") - elif chain_name == 'Gnosis': + elif chain_name == "Gnosis": db_url = os.getenv("GNOSIS_DB_URL") return create_engine(f"postgresql+psycopg2://{db_url}") -def fetch_transaction_hashes(db_connection: Engine, start_block: int, end_block: int) -> List[str]: + +def fetch_transaction_hashes( + db_connection: Engine, start_block: int, end_block: int +) -> List[str]: """Fetch transaction hashes beginning start_block.""" query = f""" SELECT tx_hash @@ -34,10 +41,11 @@ def fetch_transaction_hashes(db_connection: Engine, start_block: int, end_block: """ db_hashes = pd.read_sql(query, db_connection) - # converts hashes at memory location to hex - db_hashes['tx_hash'] = db_hashes['tx_hash'].apply(lambda x: f"0x{x.hex()}") - - return db_hashes['tx_hash'].tolist() + # converts hashes at memory location to hex + db_hashes["tx_hash"] = db_hashes["tx_hash"].apply(lambda x: f"0x{x.hex()}") + + return db_hashes["tx_hash"].tolist() + def process_transactions(chain_name: str) -> None: web3 = get_web3_instance(chain_name) @@ -46,20 +54,22 @@ def process_transactions(chain_name: str) -> None: db_connection = create_db_connection(chain_name) previous_block = get_finalized_block_number(web3) - unprocessed_txs = [] + unprocessed_txs = [] # type: List print(f"{chain_name} Daemon started.") - + while True: try: latest_block = get_finalized_block_number(web3) - new_txs = fetch_transaction_hashes(db_connection, previous_block, latest_block) + new_txs = fetch_transaction_hashes( + db_connection, previous_block, latest_block + ) # add any unprocessed hashes for processing, then clear list of unprocessed all_txs = new_txs + unprocessed_txs unprocessed_txs.clear() for tx in all_txs: - print(f'Processing transaction on {chain_name}: {tx}') + print(f"Processing transaction on {chain_name}: {tx}") try: imbalances = rt.compute_imbalances(tx) print(f"Token Imbalances on {chain_name}:") @@ -68,7 +78,7 @@ def process_transactions(chain_name: str) -> None: except ValueError as e: print(e) unprocessed_txs.append(tx) - + print("Done checks..") previous_block = latest_block + 1 except ConnectionError as e: @@ -78,16 +88,18 @@ def process_transactions(chain_name: str) -> None: time.sleep(sleep_time) + def main() -> None: threads = [] - + for chain_name in CHAIN_RPC_ENDPOINTS.keys(): thread = Thread(target=process_transactions, args=(chain_name,), daemon=True) thread.start() threads.append(thread) - + for thread in threads: thread.join() + if __name__ == "__main__": main() diff --git a/src/imbalances_script.py b/src/imbalances_script.py index 6d2d3ae..cfd5d5b 100644 --- a/src/imbalances_script.py +++ b/src/imbalances_script.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="arg-type, operator, return, attr-defined" """ Steps for computing token imbalances: @@ -19,26 +20,32 @@ 9. update_sdai_imbalance() is called in each iteration and only completes if there is an SDAI transfer involved which has special handling for its events. """ - from web3.datastructures import AttributeDict from typing import Dict, List, Optional, Tuple from web3 import Web3 +from web3.types import TxReceipt from src.config import CHAIN_RPC_ENDPOINTS -from src.constants import (SETTLEMENT_CONTRACT_ADDRESS, NATIVE_ETH_TOKEN_ADDRESS, - WETH_TOKEN_ADDRESS, SDAI_TOKEN_ADDRESS) +from src.constants import ( + SETTLEMENT_CONTRACT_ADDRESS, + NATIVE_ETH_TOKEN_ADDRESS, + WETH_TOKEN_ADDRESS, + SDAI_TOKEN_ADDRESS, +) EVENT_TOPICS = { - 'Transfer': 'Transfer(address,address,uint256)', - 'ERC20Transfer': 'ERC20Transfer(address,address,uint256)', - 'WithdrawalWETH': 'Withdrawal(address,uint256)', - 'DepositSDAI': 'Deposit(address,address,uint256,uint256)', - 'WithdrawSDAI': 'Withdraw(address,address,address,uint256,uint256)', + "Transfer": "Transfer(address,address,uint256)", + "ERC20Transfer": "ERC20Transfer(address,address,uint256)", + "WithdrawalWETH": "Withdrawal(address,uint256)", + "DepositSDAI": "Deposit(address,address,uint256,uint256)", + "WithdrawSDAI": "Withdraw(address,address,address,uint256,uint256)", } + def compute_event_topics(web3: Web3) -> Dict[str, str]: """Compute the event topics for all relevant events.""" return {name: web3.keccak(text=text).hex() for name, text in EVENT_TOPICS.items()} + def find_chain_with_tx(tx_hash: str) -> Tuple[str, Web3]: """ Find the chain where the transaction is present. @@ -57,19 +64,25 @@ def find_chain_with_tx(tx_hash: str) -> Tuple[str, Web3]: print(f"Transaction not found on {chain_name}: {e}") raise ValueError(f"Transaction hash {tx_hash} not found on any chain.") + def _to_int(value: str | int) -> int: """Convert hex string or integer to integer.""" try: - return int(value, 16) if isinstance(value, str) and value.startswith('0x') else int(value) + return ( + int(value, 16) + if isinstance(value, str) and value.startswith("0x") + else int(value) + ) except ValueError: print(f"Error converting value {value} to integer.") + class RawTokenImbalances: def __init__(self, web3: Web3, chain_name: str): self.web3 = web3 self.chain_name = chain_name - def get_transaction_receipt(self, tx_hash: str) -> Optional[Dict]: + def get_transaction_receipt(self, tx_hash: str) -> Optional[TxReceipt]: """ Get the transaction receipt from the provided web3 instance. """ @@ -80,7 +93,7 @@ def get_transaction_receipt(self, tx_hash: str) -> Optional[Dict]: return None def get_transaction_trace(self, tx_hash: str) -> Optional[List[Dict]]: - """ Function used for retreiving trace to identify ETH transfers. """ + """Function used for retreiving trace to identify ETH transfers.""" try: res = self.web3.tracing.trace_transaction(tx_hash) return res @@ -89,19 +102,21 @@ def get_transaction_trace(self, tx_hash: str) -> Optional[List[Dict]]: return None def extract_actions(self, traces: List[AttributeDict], address: str) -> List[Dict]: - """ Identify transfer events in trace involving the specified contract. """ + """Identify transfer events in trace involving the specified contract.""" normalized_address = Web3.to_checksum_address(address) actions = [] # input_field = '0x' denotes a native ETH transfer event, which we want to filter for - input_field: str = '0x' + input_field: str = "0x" for trace in traces: if isinstance(trace, AttributeDict): - action = trace.get('action', {}) - input_value = action.get('input', b"").hex() + action = trace.get("action", {}) + input_value = action.get("input", b"").hex() # filter out action if involved in an ETH transfer event if input_value == input_field and ( - Web3.to_checksum_address(action.get('from', '')) == normalized_address or - Web3.to_checksum_address(action.get('to', '')) == normalized_address + Web3.to_checksum_address(action.get("from", "")) + == normalized_address + or Web3.to_checksum_address(action.get("to", "")) + == normalized_address ): actions.append(dict(action)) return actions @@ -110,15 +125,15 @@ def calculate_native_eth_imbalance(self, actions: List[Dict], address: str) -> i """Extract ETH imbalance from transfer actions.""" # inflow is the total value transferred to address param inflow = sum( - _to_int(action['value']) + _to_int(action["value"]) for action in actions - if Web3.to_checksum_address(action.get('to', '')) == address + if Web3.to_checksum_address(action.get("to", "")) == address ) # outflow is the total value transferred out of address param outflow = sum( - _to_int(action['value']) + _to_int(action["value"]) for action in actions - if Web3.to_checksum_address(action.get('from', '')) == address + if Web3.to_checksum_address(action.get("from", "")) == address ) return inflow - outflow @@ -126,15 +141,19 @@ def extract_events(self, tx_receipt: Dict) -> Dict[str, List[Dict]]: """Extract relevant events from the transaction receipt.""" event_topics = compute_event_topics(self.web3) # transfer_topics are filtered to find imbalances for most ERC-20 tokens - transfer_topics = {k: v for k, v in event_topics.items() if k in ['Transfer', 'ERC20Transfer']} + transfer_topics = { + k: v for k, v in event_topics.items() if k in ["Transfer", "ERC20Transfer"] + } # other_topics is used to find imbalances for SDAI, ETH txss - other_topics = {k: v for k, v in event_topics.items() if k not in transfer_topics} + other_topics = { + k: v for k, v in event_topics.items() if k not in transfer_topics + } - events = {name: [] for name in EVENT_TOPICS} - for log in tx_receipt['logs']: - log_topic = log['topics'][0].hex() + events = {name: [] for name in EVENT_TOPICS} # type: dict + for log in tx_receipt["logs"]: + log_topic = log["topics"][0].hex() if log_topic in transfer_topics.values(): - events['Transfer'].append(log) + events["Transfer"].append(log) else: for event_name, topic in other_topics.items(): if log_topic == topic: @@ -142,22 +161,28 @@ def extract_events(self, tx_receipt: Dict) -> Dict[str, List[Dict]]: break return events - def decode_event(self, event: Dict) -> Tuple[Optional[str], Optional[str], Optional[int]]: + def decode_event( + self, event: Dict + ) -> Tuple[Optional[str], Optional[str], Optional[int]]: """ Decode transfer and withdrawal events. Returns from_address, to_address (for transfer), and value. """ try: - from_address = Web3.to_checksum_address("0x" + event['topics'][1].hex()[-40:]) - value_hex = event['data'] - + from_address = Web3.to_checksum_address( + "0x" + event["topics"][1].hex()[-40:] + ) + value_hex = event["data"] + if isinstance(value_hex, bytes): - value = int.from_bytes(value_hex, byteorder='big') + value = int.from_bytes(value_hex, byteorder="big") else: value = int(value_hex, 16) - if len(event['topics']) > 2: # Transfer event - to_address = Web3.to_checksum_address("0x" + event['topics'][2].hex()[-40:]) + if len(event["topics"]) > 2: # Transfer event + to_address = Web3.to_checksum_address( + "0x" + event["topics"][2].hex()[-40:] + ) return from_address, to_address, value else: # Withdrawal event return from_address, None, value @@ -165,40 +190,57 @@ def decode_event(self, event: Dict) -> Tuple[Optional[str], Optional[str], Optio print(f"Error decoding event: {str(e)}") return None, None, None - def process_event(self, event: Dict, inflows: Dict[str, int], outflows: Dict[str, int], address: str) -> None: + def process_event( + self, + event: Dict, + inflows: Dict[str, int], + outflows: Dict[str, int], + address: str, + ) -> None: """Process a single event to update inflows and outflows.""" from_address, to_address, value = self.decode_event(event) if from_address is None or to_address is None: return if to_address == address: - inflows[event['address']] = inflows.get(event['address'], 0) + value + inflows[event["address"]] = inflows.get(event["address"], 0) + value if from_address == address: - outflows[event['address']] = outflows.get(event['address'], 0) + value + outflows[event["address"]] = outflows.get(event["address"], 0) + value - def calculate_imbalances(self, events: Dict[str, List[Dict]], address: str) -> Dict[str, int]: + def calculate_imbalances( + self, events: Dict[str, List[Dict]], address: str + ) -> Dict[str, int]: """Calculate token imbalances from events.""" - inflows, outflows = {}, {} - for event in events['Transfer']: + inflows, outflows = {}, {} # type: (dict, dict) + for event in events["Transfer"]: self.process_event(event, inflows, outflows, address) imbalances = { - token_address: inflows.get(token_address, 0) - outflows.get(token_address, 0) + token_address: inflows.get(token_address, 0) + - outflows.get(token_address, 0) for token_address in set(inflows.keys()).union(outflows.keys()) } return imbalances - def update_weth_imbalance(self, events: Dict[str, List[Dict]], actions: List[Dict], imbalances: Dict[str, int], address: str) -> None: + def update_weth_imbalance( + self, + events: Dict[str, List[Dict]], + actions: List[Dict], + imbalances: Dict[str, int], + address: str, + ) -> None: """Update the WETH imbalance in imbalances.""" weth_inflow = imbalances.get(WETH_TOKEN_ADDRESS, 0) weth_outflow = 0 weth_withdrawals = 0 - for event in events['WithdrawalWETH']: + for event in events["WithdrawalWETH"]: from_address, _, value = self.decode_event(event) if from_address == address: weth_withdrawals += value imbalances[WETH_TOKEN_ADDRESS] = weth_inflow - weth_outflow - weth_withdrawals - def update_native_eth_imbalance(self, imbalances: Dict[str, int], native_eth_imbalance: Optional[int]) -> None: + def update_native_eth_imbalance( + self, imbalances: Dict[str, int], native_eth_imbalance: Optional[int] + ) -> None: """Update the native ETH imbalance in imbalances.""" if native_eth_imbalance is not None: imbalances[NATIVE_ETH_TOKEN_ADDRESS] = native_eth_imbalance @@ -207,9 +249,9 @@ def decode_sdai_event(self, event: Dict) -> int | None: """Decode sDAI event.""" try: # SDAI event has hex value at the end, which needs to be extracted - value_hex = event['data'][-30:] + value_hex = event["data"][-30:] if isinstance(value_hex, bytes): - value = int.from_bytes(value_hex, byteorder='big') + value = int.from_bytes(value_hex, byteorder="big") else: value = int(value_hex, 16) return value @@ -217,49 +259,64 @@ def decode_sdai_event(self, event: Dict) -> int | None: print(f"Error decoding sDAI event: {str(e)}") return None - def process_sdai_event(self, event: Dict, imbalances: Dict[str, int], is_deposit: bool) -> None: + def process_sdai_event( + self, event: Dict, imbalances: Dict[str, int], is_deposit: bool + ) -> None: """Process an sDAI deposit or withdrawal event to update imbalances.""" decoded_event_value = self.decode_sdai_event(event) if decoded_event_value is None: return if is_deposit: - imbalances[SDAI_TOKEN_ADDRESS] = imbalances.get(SDAI_TOKEN_ADDRESS, 0) + decoded_event_value + imbalances[SDAI_TOKEN_ADDRESS] = ( + imbalances.get(SDAI_TOKEN_ADDRESS, 0) + decoded_event_value + ) else: - imbalances[SDAI_TOKEN_ADDRESS] = imbalances.get(SDAI_TOKEN_ADDRESS, 0) - decoded_event_value + imbalances[SDAI_TOKEN_ADDRESS] = ( + imbalances.get(SDAI_TOKEN_ADDRESS, 0) - decoded_event_value + ) - def update_sdai_imbalance(self, events: Dict[str, List[Dict]], imbalances: Dict[str, int]) -> None: + def update_sdai_imbalance( + self, events: Dict[str, List[Dict]], imbalances: Dict[str, int] + ) -> None: """Update the sDAI imbalance in imbalances.""" - for event in events['DepositSDAI']: - if event['address'] == SDAI_TOKEN_ADDRESS: + for event in events["DepositSDAI"]: + if event["address"] == SDAI_TOKEN_ADDRESS: self.process_sdai_event(event, imbalances, is_deposit=True) - for event in events['WithdrawSDAI']: - if event['address'] == SDAI_TOKEN_ADDRESS: + for event in events["WithdrawSDAI"]: + if event["address"] == SDAI_TOKEN_ADDRESS: self.process_sdai_event(event, imbalances, is_deposit=False) def compute_imbalances(self, tx_hash: str) -> Dict[str, int]: """Compute token imbalances for a given transaction hash.""" tx_receipt = self.get_transaction_receipt(tx_hash) if tx_receipt is None: - raise ValueError(f"Transaction hash {tx_hash} not found on chain {self.chain_name}.") + raise ValueError( + f"Transaction hash {tx_hash} not found on chain {self.chain_name}." + ) # find trace and actions from trace to track native ETH events traces = self.get_transaction_trace(tx_hash) native_eth_imbalance = None actions = [] if traces is not None: actions = self.extract_actions(traces, SETTLEMENT_CONTRACT_ADDRESS) - native_eth_imbalance = self.calculate_native_eth_imbalance(actions, SETTLEMENT_CONTRACT_ADDRESS) + native_eth_imbalance = self.calculate_native_eth_imbalance( + actions, SETTLEMENT_CONTRACT_ADDRESS + ) events = self.extract_events(tx_receipt) imbalances = self.calculate_imbalances(events, SETTLEMENT_CONTRACT_ADDRESS) if actions: - self.update_weth_imbalance(events, actions, imbalances, SETTLEMENT_CONTRACT_ADDRESS) + self.update_weth_imbalance( + events, actions, imbalances, SETTLEMENT_CONTRACT_ADDRESS + ) self.update_native_eth_imbalance(imbalances, native_eth_imbalance) self.update_sdai_imbalance(events, imbalances) return imbalances + # main method for finding imbalance for a single tx hash def main() -> None: tx_hash = input("Enter transaction hash: ") @@ -273,5 +330,6 @@ def main() -> None: except ValueError as e: print(e) + if __name__ == "__main__": main() diff --git a/tests/basic_test.py b/tests/basic_test.py index 378756c..4598d5e 100644 --- a/tests/basic_test.py +++ b/tests/basic_test.py @@ -1,28 +1,37 @@ import pytest from src.imbalances_script import RawTokenImbalances -@pytest.mark.parametrize("tx_hash, expected_imbalances", [ - # Native ETH buy - ("0x749b557872d7d1f857719f619300df9621631f87338caa706154a3d7040fac9f", - { - "0x6B175474E89094C44Da98b954EedeAC495271d0F": 6286775129763176601, - "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": 12147750061816, - "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE": 221116798827683 - }), - # SDAI sell - ("0xdae82500c69c66db4e4a8c64e1d6a95f3cdc5cb81a5a00228ce6f247b9b8cefd", - { - "0x83F20F44975D03b1b09e64809B757c47f942BEeA": 90419674604117409792, - "0x6B175474E89094C44Da98b954EedeAC495271d0F": 360948092321672598, - }), - # ERC404 Token Buy - ("0xfcb1d20df8a90f5b4646a5d1818da407b3a78cfcb8291f477291f5c01115ca7a", - { - "0x9E9FbDE7C7a83c43913BddC8779158F1368F0413": -11207351687745217, - "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": 64641750602289665, - }), -]) +@pytest.mark.parametrize( + "tx_hash, expected_imbalances", + [ + # Native ETH buy + ( + "0x749b557872d7d1f857719f619300df9621631f87338caa706154a3d7040fac9f", + { + "0x6B175474E89094C44Da98b954EedeAC495271d0F": 6286775129763176601, + "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": 12147750061816, + "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE": 221116798827683, + }, + ), + # SDAI sell + ( + "0xdae82500c69c66db4e4a8c64e1d6a95f3cdc5cb81a5a00228ce6f247b9b8cefd", + { + "0x83F20F44975D03b1b09e64809B757c47f942BEeA": 90419674604117409792, + "0x6B175474E89094C44Da98b954EedeAC495271d0F": 360948092321672598, + }, + ), + # ERC404 Token Buy + ( + "0xfcb1d20df8a90f5b4646a5d1818da407b3a78cfcb8291f477291f5c01115ca7a", + { + "0x9E9FbDE7C7a83c43913BddC8779158F1368F0413": -11207351687745217, + "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": 64641750602289665, + }, + ), + ], +) def test_imbalances(tx_hash, expected_imbalances): rt = RawTokenImbalances() imbalances, _ = rt.compute_imbalances(tx_hash) diff --git a/tests/compare_imbalances.py b/tests/compare_imbalances.py index 43900d3..ea336e2 100644 --- a/tests/compare_imbalances.py +++ b/tests/compare_imbalances.py @@ -2,6 +2,7 @@ Script can be used as a sanity test to compare raw imbalances via RawTokenImbalances class and the BalanceOfImbalances class. """ + import time from web3 import Web3 from src.config import ETHEREUM_NODE_URL @@ -12,13 +13,15 @@ RED_COLOR = "\033[91m" RESET_COLOR = "\033[0m" + def remove_zero_balances(balances: dict) -> dict: """Remove entries with zero balance for all tokens.""" return {token: balance for token, balance in balances.items() if balance != 0} + def compare_imbalances(tx_hash: str, web3: Web3) -> None: """Compare imbalances computed by RawTokenImbalances and BalanceOfImbalances.""" - raw_imbalances = RawTokenImbalances(web3, 'Ethereum') + raw_imbalances = RawTokenImbalances(web3, "Ethereum") balanceof_imbalances = BalanceOfImbalances(ETHEREUM_NODE_URL) raw_result = raw_imbalances.compute_imbalances(tx_hash) @@ -29,10 +32,13 @@ def compare_imbalances(tx_hash: str, web3: Web3) -> None: balanceof_result = remove_zero_balances(balanceof_result) if raw_result != balanceof_result: - print(f"{RED_COLOR}Imbalances do not match for tx: {tx_hash}.\nRaw: {raw_result}\nBalanceOf: {balanceof_result}{RESET_COLOR}") + print( + f"{RED_COLOR}Imbalances do not match for tx: {tx_hash}.\nRaw: {raw_result}\nBalanceOf: {balanceof_result}{RESET_COLOR}" + ) else: print(f"Imbalances match for transaction {tx_hash}.") + def main() -> None: start_block = int(input("Enter start block number: ")) end_block = int(input("Enter end block number: ")) @@ -49,5 +55,6 @@ def main() -> None: except Exception as e: print(f"Error comparing imbalances for tx {tx_hash}: {e}") + if __name__ == "__main__": main()