302 lines
10 KiB
Python
302 lines
10 KiB
Python
import json
|
|
import os
|
|
from decimal import Decimal
|
|
from fractions import Fraction
|
|
from functools import lru_cache
|
|
from logging import getLogger
|
|
from pathlib import Path
|
|
from typing import Final, Any
|
|
|
|
import eth_abi
|
|
from eth_typing import HexStr
|
|
from hexbytes import HexBytes
|
|
from protosim_py import SimulationEngine, AccountInfo
|
|
from web3 import Web3
|
|
|
|
from .constants import EXTERNAL_ACCOUNT, MAX_BALANCE, ASSETS_FOLDER
|
|
from .exceptions import OutOfGas
|
|
from .models import Address, EthereumToken
|
|
from .tycho_db import TychoDBSingleton
|
|
|
|
# Module-level logger, named after this module's import path.
log = getLogger(__name__)
|
|
|
|
|
|
def decode_tycho_exchange(exchange: str) -> tuple[str, bool]:
    """Split a Tycho exchange identifier into a name and a protocol flag.

    Parameters
    ----------
    exchange
        Exchange identifier, optionally carrying a ``vm:`` marker
        (e.g. ``"vm:balancer"``).

    Returns
    -------
    tuple[str, bool]
        ``(name, flag)``. When ``"vm:"`` occurs in ``exchange``, ``name`` is the
        segment between the first and second ``":"`` and ``flag`` is ``False``.
        Otherwise ``exchange`` is returned unchanged and ``flag`` is ``True``.

    Notes
    -----
    NOTE(review): the comment this docstring replaces stated that the flag is
    ``True`` when the vm prefix was present, which is the opposite of what the
    code returns. The runtime behaviour is deliberately left untouched here;
    confirm against the callers which of the two is intended.
    """
    if "vm:" in exchange:
        return exchange.split(":")[1], False
    return exchange, True
|
|
|
|
|
|
def create_engine(
    mocked_tokens: list[Address], trace: bool = False
) -> SimulationEngine:
    """Build a simulation engine backed by the shared Tycho DB.

    A mocked ERC20 contract is installed at every address in ``mocked_tokens``
    and the external account is initialised with ``MAX_BALANCE``.

    Parameters
    ----------
    mocked_tokens
        Addresses at which the mocked ERC20 bytecode should be inserted.
    trace
        Forwarded to the engine constructor. Enables call tracing; intended
        for debugging only because it may print a lot of data to stdout.

    Returns
    -------
    SimulationEngine
        The initialised engine.
    """
    simulation_engine = SimulationEngine.new_with_tycho_db(
        db=TychoDBSingleton.get_instance(), trace=trace
    )

    for token_address in mocked_tokens:
        # The bytecode loader is cached, so repeated lookups are cheap.
        erc20_account = AccountInfo(
            balance=0,
            nonce=0,
            code=get_contract_bytecode(ASSETS_FOLDER / "ERC20.bin"),
        )
        simulation_engine.init_account(
            address=token_address,
            account=erc20_account,
            mocked=True,
            permanent_storage=None,
        )

    # The external (caller) account carries no code, only a balance.
    external_account = AccountInfo(balance=MAX_BALANCE, nonce=0, code=None)
    simulation_engine.init_account(
        address=EXTERNAL_ACCOUNT,
        account=external_account,
        mocked=False,
        permanent_storage=None,
    )

    return simulation_engine
|
|
|
|
|
|
class ERC20OverwriteFactory:
    """Collect storage overwrites (balances and allowances) for one ERC20 token.

    Storage slots are derived with ``get_storage_slot_at_key`` using mapping
    slot 0 for balances and mapping slot 1 for allowances. The collected
    values can be exported in two shapes, see ``get_protosim_overwrites`` and
    ``get_geth_overwrites``.
    """

    def __init__(self, token: EthereumToken):
        """
        Initialize the ERC20OverwriteFactory.

        Parameters:
            token: The token object whose storage is being overwritten.
        """
        self._token = token
        # Maps storage slot index -> value to write at that slot.
        self._overwrites: dict[int, int] = {}
        self._balance_slot: Final[int] = 0
        self._allowance_slot: Final[int] = 1

    def set_balance(self, balance: int, owner: Address):
        """
        Record a balance overwrite for a given owner.

        Parameters:
            balance: The balance value.
            owner: The owner's address.
        """
        storage_index = get_storage_slot_at_key(HexStr(owner), self._balance_slot)
        self._overwrites[storage_index] = balance
        # Numeric level 5 sits below logging.DEBUG (10); %-args keep the
        # formatting lazy when that level is disabled.
        log.log(
            5,
            "Override balance: token=%s owner=%s value=%s slot=%s",
            self._token.address,
            owner,
            balance,
            storage_index,
        )

    def set_allowance(self, allowance: int, spender: Address, owner: Address):
        """
        Record an allowance overwrite for a given spender and owner.

        Parameters:
            allowance: The allowance value.
            spender: The spender's address.
            owner: The owner's address.
        """
        # Nested mapping: first resolve the owner's inner mapping slot, then
        # the spender's slot inside it.
        owner_slot = get_storage_slot_at_key(HexStr(owner), self._allowance_slot)
        storage_index = get_storage_slot_at_key(HexStr(spender), owner_slot)
        self._overwrites[storage_index] = allowance
        log.log(
            5,
            "Override allowance: token=%s owner=%s spender=%s value=%s slot=%s",
            self._token.address,
            owner,
            spender,
            allowance,
            storage_index,
        )

    def get_protosim_overwrites(self) -> dict[Address, dict[int, int]]:
        """
        Get the overwrites dictionary of previously collected values.

        Returns:
            dict[Address, dict]: A single-entry dictionary mapping the token's
                lowercased address to ``{slot: value}``.
        """
        # Protosim returns lowercase addresses in state updates returned from simulation
        return {self._token.address.lower(): self._overwrites}

    def get_geth_overwrites(self) -> dict[Address, dict[str, Any]]:
        """
        Get the collected overwrites as hex strings together with the mocked
        ERC20 bytecode.

        Returns:
            dict[Address, dict]: A single-entry dictionary mapping the token's
                address (as stored on the token, not lowercased) to
                ``{"stateDiff": {slot_hex: value_hex}, "code": "0x..."}``.
                Each value is ``"0x"`` followed by the value left-padded with
                zeros to 64 hex digits.

        NOTE(review): presumably this is the state-override shape expected by
        geth's ``eth_call`` - confirm against the caller.
        """
        formatted_overwrites = {
            # ``removeprefix`` tolerates HexBytes versions whose ``hex()`` does
            # or does not include the leading "0x".
            HexBytes(key).hex(): "0x"
            + HexBytes(val).hex().removeprefix("0x").zfill(64)
            for key, val in self._overwrites.items()
        }

        code = "0x" + get_contract_bytecode(ASSETS_FOLDER / "ERC20.bin").hex()
        return {self._token.address: {"stateDiff": formatted_overwrites, "code": code}}
|
|
|
|
|
|
def get_storage_slot_at_key(key: Address, mapping_slot: int) -> int:
    """Compute the storage slot holding ``mapping[key]`` for a Solidity mapping.

    The slot is the keccak hash of the key (left-padded to 32 bytes) followed
    by the mapping's own slot (32 bytes, big-endian), read back as an integer.

    Parameters
    ----------
    key
        Key in the mapping. Meant for ethereum addresses; only strings are
        accepted, and the first two characters are dropped before hex
        decoding (i.e. a ``0x`` prefix is expected).
    mapping_slot
        Storage slot at which the mapping itself is stored. See the examples
        for more explanation.

    Returns
    -------
    slot
        Index of the storage slot where the value at the given key is stored.

    Examples
    --------
    If a mapping is declared as the first variable in solidity code, its
    storage slot is 0 (e.g. ``balances`` in our mocked ERC20 contract). The
    slot where the balance of a given account is stored is then::

        get_storage_slot_at_key("0xC63135E4bF73F637AF616DFd64cf701866BB2628", 0)

    For nested mappings the function is applied twice. An example is
    ``allowances`` in ERC20, a mapping of the form
    ``dict[owner, dict[spender, value]]``. In our mocked ERC20 contract,
    ``allowances`` is the second variable, so it is stored at slot 1. The slot
    where the allowance of ``0xspender`` to spend ``0xowner``'s money is
    stored is::

        get_storage_slot_at_key("0xspender", get_storage_slot_at_key("0xowner", 1))

    See Also
    --------
    `Solidity Storage Layout documentation
    <https://docs.soliditylang.org/en/v0.8.13/internals/layout_in_storage.html#mappings-and-dynamic-arrays>`_
    """
    padded_key = bytes.fromhex(key[2:]).rjust(32, b"\x00")
    padded_slot = mapping_slot.to_bytes(length=32, byteorder="big")
    digest = Web3.keccak(padded_key + padded_slot)
    return int.from_bytes(digest, byteorder="big")
|
|
|
|
|
|
@lru_cache
def get_contract_bytecode(path: str) -> bytes:
    """Return the raw bytes of the contract bytecode file at ``path``.

    Results are memoised by ``lru_cache``, so a cached path is not read from
    disk again.
    """
    with open(path, mode="rb") as bytecode_file:
        return bytecode_file.read()
|
|
|
|
|
|
def frac_to_decimal(frac: Fraction) -> Decimal:
    """Convert a ``Fraction`` to a ``Decimal``.

    The numerator and denominator are converted separately and divided in
    ``Decimal`` arithmetic, so the result is rounded according to the active
    decimal context.
    """
    numerator = Decimal(frac.numerator)
    denominator = Decimal(frac.denominator)
    return numerator / denominator
|
|
|
|
|
|
def load_abi(name_or_path: str) -> dict:
    """Load a JSON ABI either from an existing path or from the bundled assets.

    Parameters
    ----------
    name_or_path
        If this resolves to an existing path, that file is loaded. Otherwise it
        is treated as a name and ``<this module's dir>/assets/<name>.abi`` is
        loaded instead.

    Returns
    -------
    The parsed JSON content of the ABI file.

    Raises
    ------
    FileNotFoundError
        If the resolved file cannot be opened. The message lists the ABI names
        found under the assets directory, and the original error is chained as
        the cause.
    """
    assets_dir = f"{os.path.dirname(os.path.abspath(__file__))}/assets/"
    if os.path.exists(abspath := os.path.abspath(name_or_path)):
        path = abspath
    else:
        path = f"{assets_dir}{name_or_path}.abi"

    try:
        with open(os.path.abspath(path), encoding="utf-8") as f:
            abi: dict = json.load(f)
    except FileNotFoundError as err:
        # Suggest only names that this function could actually load: ``.abi``
        # files below the assets directory, relative to it, without the suffix.
        available_files = []
        for dirpath, _dirnames, filenames in os.walk(assets_dir):
            for filename in filenames:
                if not filename.endswith(".abi"):
                    continue
                relative_path = os.path.relpath(
                    os.path.join(dirpath, filename), assets_dir
                )
                available_files.append(relative_path.removesuffix(".abi"))
        # os.walk order is filesystem dependent; sort for a stable message.
        available_files.sort()

        raise FileNotFoundError(
            f"File {name_or_path} not found. "
            f"Did you mean one of these? {', '.join(available_files)}"
        ) from err
    return abi
|
|
|
|
|
|
# Human-readable names for the codes carried by Solidity's ``Panic(uint256)``
# error. Keys are written in hex to match the table in the Solidity docs:
# https://docs.soliditylang.org/en/latest/control-structures.html#panic-via-assert-and-error-via-require
solidity_panic_codes = {
    0x00: "GenericCompilerPanic",
    0x01: "AssertionError",
    0x11: "ArithmeticOver/Underflow",
    0x12: "ZeroDivisionError",
    0x21: "UnknownEnumMember",
    0x22: "BadStorageByteArrayEncoding",
    # ``.pop()`` on an empty array is 0x31 (decimal 49), not decimal 51.
    0x31: "EmptyArray",
    0x32: "OutOfBounds",
    0x41: "OutOfMemory",
    0x51: "BadFunctionPointer",
}
|
|
|
|
|
|
def parse_solidity_error_message(data) -> str:
    """Turn revert data returned by a contract call into a readable message.

    Parameters
    ----------
    data
        Revert payload; anything ``HexBytes`` accepts.

    Returns
    -------
    str
        - For the ``0x08c379a0`` selector (``Error(string)``): the decoded string.
        - For the ``0x4e487b71`` selector (panic): the name registered in
          ``solidity_panic_codes``, or ``"Panic(<code>)"`` for unknown codes.
        - Otherwise the payload is tried as a bare ABI string, then as an ABI
          string following a 4-byte selector.
        - If nothing decodes: ``"Failed to decode: <data>"``.

    Raises
    ------
    Exception
        Whatever ``eth_abi.decode`` raises when a payload carrying one of the
        two known selectors is malformed; only the fallback attempts are
        best-effort.
    """
    data_bytes = HexBytes(data)
    selector, payload = data_bytes[:4], data_bytes[4:]

    # data is encoded as Error(string)
    if selector == HexBytes("0x08c379a0"):
        (error_string,) = eth_abi.decode(["string"], payload)
        return error_string
    if selector == HexBytes("0x4e487b71"):
        (error_code,) = eth_abi.decode(["uint256"], payload)
        return solidity_panic_codes.get(error_code, f"Panic({error_code})")

    # Fallbacks, in order:
    #   1. old solidity: revert 'some string' (no selector at all)
    #   2. a custom error whose only argument might be a string
    for candidate in (data_bytes, payload):
        try:
            (error_string,) = eth_abi.decode(["string"], candidate)
        except Exception:
            # Deliberately broad: these are speculative decodes and any
            # failure just means "this was not the right shape".
            continue
        return error_string

    return f"Failed to decode: {data}"
|
|
|
|
|
|
def _gas_usage_ratio(details: Any, gas_limit: Any):
    """Return ``details.gas_used / gas_limit``, or ``None`` when unavailable."""
    # A missing or zero limit, or a missing gas_used, gives no usable ratio.
    if not gas_limit:
        return None
    gas_used = details.gas_used
    if gas_used is None:
        return None
    return gas_used / gas_limit


def maybe_coerce_error(
    err: RuntimeError, pool_state: Any, gas_limit: int = None
) -> Exception:
    """Translate a raw simulation error into a more descriptive exception.

    Parameters
    ----------
    err
        The error raised by the simulation. ``err.args[0]`` must expose a
        string attribute ``data`` and an attribute ``gas_used`` (may be
        ``None``). NOTE(review): presumably the engine's error-details object
        - confirm against the caller.
    pool_state
        Its ``repr`` is attached as second argument to any ``OutOfGas`` built.
    gas_limit
        Gas limit the simulation ran with, or ``None`` if unknown.

    Returns
    -------
    Exception
        - ``OutOfGas`` if ``data`` is hex revert data and at least 97% of
          ``gas_limit`` was used.
        - ``OutOfGas`` if ``data`` is not hex and contains ``"OutOfGas"``.
        - A new ``RuntimeError("Revert! Reason: ...")`` for other hex revert data.
        - ``err`` itself otherwise.

    The exception is returned, not raised.
    """
    details = err.args[0]
    # we got bytes as data, so this was a revert
    if details.data.startswith("0x"):
        err = RuntimeError(
            f"Revert! Reason: {parse_solidity_error_message(details.data)}"
        )
        # we have gas information, check if this likely an out of gas err.
        usage = _gas_usage_ratio(details, gas_limit)
        # if we used up 97% or more issue a OutOfGas error.
        if usage is not None and usage >= 0.97:
            return OutOfGas(
                f"SimulationError: Likely out-of-gas. "
                f"Used: {usage * 100:.2f}% of gas limit. "
                f"Original error: {err}",
                repr(pool_state),
            )
    elif "OutOfGas" in details.data:
        usage = _gas_usage_ratio(details, gas_limit)
        # Only mention usage when it could actually be computed.
        usage_msg = (
            "" if usage is None else f"Used: {usage * 100:.2f}% of gas limit. "
        )
        return OutOfGas(
            f"SimulationError: out-of-gas. {usage_msg}Original error: {details.data}",
            repr(pool_state),
        )
    return err
|