import json import os from decimal import Decimal from fractions import Fraction from functools import lru_cache from logging import getLogger from pathlib import Path from typing import Final, Any import eth_abi from eth_typing import HexStr from hexbytes import HexBytes from protosim_py import SimulationEngine, AccountInfo import requests from web3 import Web3 from .constants import EXTERNAL_ACCOUNT, MAX_BALANCE, ASSETS_FOLDER from .exceptions import OutOfGas from .models import Address, EthereumToken from .tycho_db import TychoDBSingleton log = getLogger(__name__) def decode_tycho_exchange(exchange: str) -> (str, bool): # removes vm prefix if present, returns True if vm prefix was present (vm protocol) or False if native protocol return (exchange.split(":")[1], False) if "vm:" in exchange else (exchange, True) def create_engine( mocked_tokens: list[Address], trace: bool = False ) -> SimulationEngine: """Create a simulation engine with a mocked ERC20 contract at given addresses. Parameters ---------- mocked_tokens A list of addresses at which a mocked ERC20 contract should be inserted. trace Whether to trace calls, only meant for debugging purposes, might print a lot of data to stdout. """ db = TychoDBSingleton.get_instance() engine = SimulationEngine.new_with_tycho_db(db=db, trace=trace) for t in mocked_tokens: info = AccountInfo( balance=0, nonce=0, code=get_contract_bytecode(ASSETS_FOLDER / "ERC20.bin") ) engine.init_account( address=t, account=info, mocked=True, permanent_storage=None ) engine.init_account( address=EXTERNAL_ACCOUNT, account=AccountInfo(balance=MAX_BALANCE, nonce=0, code=None), mocked=False, permanent_storage=None, ) return engine class ERC20OverwriteFactory: def __init__(self, token: EthereumToken): """ Initialize the ERC20OverwriteFactory. Parameters: token: The token object. """ self._token = token self._overwrites = dict() self._balance_slot: Final[int] = 0 self._allowance_slot: Final[int] = 1 def set_balance(self, balance: int, owner: Address): """ Set the balance for a given owner. Parameters: balance: The balance value. owner: The owner's address. """ storage_index = get_storage_slot_at_key(HexStr(owner), self._balance_slot) self._overwrites[storage_index] = balance log.log( 5, f"Override balance: token={self._token.address} owner={owner}" f"value={balance} slot={storage_index}", ) def set_allowance(self, allowance: int, spender: Address, owner: Address): """ Set the allowance for a given spender and owner. Parameters: allowance: The allowance value. spender: The spender's address. owner: The owner's address. """ storage_index = get_storage_slot_at_key( HexStr(spender), get_storage_slot_at_key(HexStr(owner), self._allowance_slot), ) self._overwrites[storage_index] = allowance log.log( 5, f"Override allowance: token={self._token.address} owner={owner}" f"spender={spender} value={allowance} slot={storage_index}", ) def get_protosim_overwrites(self) -> dict[Address, dict[int, int]]: """ Get the overwrites dictionary of previously collected values. Returns: dict[Address, dict]: A dictionary containing the token's address and the overwrites. """ # Protosim returns lowercase addresses in state updates returned from simulation return {self._token.address.lower(): self._overwrites} def get_geth_overwrites(self) -> dict[Address, dict[int, int]]: """ Get the overwrites dictionary of previously collected values. Returns: dict[Address, dict]: A dictionary containing the token's address and the overwrites. """ formatted_overwrites = { HexBytes(key).hex(): "0x" + HexBytes(val).hex().lstrip("0x").zfill(64) for key, val in self._overwrites.items() } code = "0x" + get_contract_bytecode(ASSETS_FOLDER / "ERC20.bin").hex() return {self._token.address: {"stateDiff": formatted_overwrites, "code": code}} def get_storage_slot_at_key(key: Address, mapping_slot: int) -> int: """Get storage slot index of a value stored at a certain key in a mapping Parameters ---------- key Key in a mapping. This function is meant to work with ethereum addresses and accepts only strings. mapping_slot Storage slot at which the mapping itself is stored. See the examples for more explanation. Returns ------- slot An index of a storage slot where the value at the given key is stored. Examples -------- If a mapping is declared as a first variable in solidity code, its storage slot is 0 (e.g. ``balances`` in our mocked ERC20 contract). Here's how to compute a storage slot where balance of a given account is stored:: get_storage_slot_at_key("0xC63135E4bF73F637AF616DFd64cf701866BB2628", 0) For nested mappings, we need to apply the function twice. An example of this is ``allowances`` in ERC20. It is a mapping of form: ``dict[owner, dict[spender, value]]``. In our mocked ERC20 contract, ``allowances`` is a second variable, so it is stored at slot 1. Here's how to get a storage slot where an allowance of ``0xspender`` to spend ``0xowner``'s money is stored:: get_storage_slot_at_key("0xspender", get_storage_slot_at_key("0xowner", 1))) See Also -------- `Solidity Storage Layout documentation `_ """ key_bytes = bytes.fromhex(key[2:]).rjust(32, b"\0") mapping_slot_bytes = int.to_bytes(mapping_slot, 32, "big") slot_bytes = Web3.keccak(key_bytes + mapping_slot_bytes) return int.from_bytes(slot_bytes, "big") @lru_cache def get_contract_bytecode(path: str) -> bytes: """Load contract bytecode from a file given an absolute path""" with open(path, "rb") as fh: code = fh.read() return code def frac_to_decimal(frac: Fraction) -> Decimal: return Decimal(frac.numerator) / Decimal(frac.denominator) def load_abi(name_or_path: str) -> dict: if os.path.exists(abspath := os.path.abspath(name_or_path)): path = abspath else: path = f"{os.path.dirname(os.path.abspath(__file__))}/assets/{name_or_path}.abi" try: with open(os.path.abspath(path)) as f: abi: dict = json.load(f) except FileNotFoundError: search_dir = f"{os.path.dirname(os.path.abspath(__file__))}/assets/" # List all files in search dir and subdirs suggest them to the user in an error message available_files = [] for dirpath, dirnames, filenames in os.walk(search_dir): for filename in filenames: # Make paths relative to search_dir relative_path = os.path.relpath( os.path.join(dirpath, filename), search_dir ) available_files.append(relative_path.replace(".abi", "")) raise FileNotFoundError( f"File {name_or_path} not found. " f"Did you mean one of these? {', '.join(available_files)}" ) return abi # https://docs.soliditylang.org/en/latest/control-structures.html#panic-via-assert-and-error-via-require solidity_panic_codes = { 0: "GenericCompilerPanic", 1: "AssertionError", 17: "ArithmeticOver/Underflow", 18: "ZeroDivisionError", 33: "UnkownEnumMember", 34: "BadStorageByteArrayEncoding", 51: "EmptyArray", 0x32: "OutOfBounds", 0x41: "OutOfMemory", 0x51: "BadFunctionPointer", } def parse_solidity_error_message(data) -> str: data_bytes = HexBytes(data) error_string = f"Failed to decode: {data}" # data is encoded as Error(string) if data_bytes[:4] == HexBytes("0x08c379a0"): (error_string,) = eth_abi.decode(["string"], data_bytes[4:]) return error_string elif data_bytes[:4] == HexBytes("0x4e487b71"): (error_code,) = eth_abi.decode(["uint256"], data_bytes[4:]) return solidity_panic_codes.get(error_code, f"Panic({error_code})") # old solidity: revert 'some string' case try: (error_string,) = eth_abi.decode(["string"], data_bytes) return error_string except Exception: pass # some custom error maybe it is with string? try: (error_string,) = eth_abi.decode(["string"], data_bytes[4:]) return error_string except Exception: pass try: (error_string,) = eth_abi.decode(["string"], data_bytes[4:]) return error_string except Exception: pass return error_string def maybe_coerce_error( err: RuntimeError, pool_state: Any, gas_limit: int = None ) -> Exception: details = err.args[0] # we got bytes as data, so this was a revert if details.data.startswith("0x"): err = RuntimeError( f"Revert! Reason: {parse_solidity_error_message(details.data)}" ) # we have gas information, check if this likely an out of gas err. if gas_limit is not None and details.gas_used is not None: # if we used up 97% or more issue a OutOfGas error. usage = details.gas_used / gas_limit if usage >= 0.97: return OutOfGas( f"SimulationError: Likely out-of-gas. " f"Used: {usage * 100:.2f}% of gas limit. " f"Original error: {err}", repr(pool_state), ) elif "OutOfGas" in details.data: if gas_limit is not None: usage = details.gas_used / gas_limit usage_msg = f"Used: {usage * 100:.2f}% of gas limit. " else: usage_msg = "" return OutOfGas( f"SimulationError: out-of-gas. {usage_msg}Original error: {details.data}", repr(pool_state), ) return err def exec_rpc_method(url, method, params, timeout=240) -> dict: payload = {"jsonrpc": "2.0", "method": method, "params": params, "id": 1} headers = {"Content-Type": "application/json"} r = requests.post(url, data=json.dumps(payload), headers=headers, timeout=timeout) if r.status_code >= 400: raise RuntimeError( "RPC failed: status_code not ok. (method {}: {})".format( method, r.status_code ) ) data = r.json() if "result" in data: return data["result"] elif "error" in data: raise RuntimeError( "RPC failed with Error {} - {}".format(data["error"], method) ) def get_code_for_address(address: str, connection_string: str = None): if connection_string is None: connection_string = os.getenv("RPC_URL") if connection_string is None: raise EnvironmentError("RPC_URL environment variable is not set") method = "eth_getCode" params = [address, "latest"] try: code = exec_rpc_method(connection_string, method, params) return bytes.fromhex(code[2:]) except RuntimeError as e: print(f"Error fetching code for address {address}: {e}") return None