Add ThirdPartyPoolState and necessary auxiliary code

This commit is contained in:
Thales Lima
2024-07-12 18:23:11 +02:00
committed by tvinagre
parent 94d4ab568a
commit f6fba2805a
8 changed files with 1024 additions and 16 deletions

View File

@@ -0,0 +1,211 @@
import logging
import time
from decimal import Decimal
from fractions import Fraction
from typing import Any, Union, NamedTuple
import eth_abi
from eth_abi.exceptions import DecodingError
from eth_typing import HexStr
from eth_utils import keccak
from eth_utils.abi import collapse_if_tuple
from hexbytes import HexBytes
from protosim_py import (
SimulationEngine,
SimulationParameters,
SimulationResult,
StateUpdate,
)
from tycho.tycho.constants import EXTERNAL_ACCOUNT
from tycho.tycho.models import Address, EthereumToken, EVMBlock, Capability
from tycho.tycho.utils import load_abi, maybe_coerce_error
log = logging.getLogger(__name__)
TStateOverwrites = dict[Address, dict[int, int]]
class Trade(NamedTuple):
"""
Trade represents a simple trading operation with fields:
received_amount: Amount received from the trade
gas_used: Amount of gas used in the transaction
price: Price at which the trade was executed
"""
received_amount: float
gas_used: float
price: float
class ProtoSimResponse:
def __init__(self, return_value: Any, simulation_result: "SimulationResult"):
self.return_value = return_value
self.simulation_result = simulation_result
class ProtoSimContract:
def __init__(self, address: Address, abi_name: str, engine: SimulationEngine):
self.abi = load_abi(abi_name)
self.address = address
self.engine = engine
self._default_tx_env = dict(
caller=EXTERNAL_ACCOUNT, to=self.address, value=0, overrides={}
)
functions = [f for f in self.abi if f["type"] == "function"]
self._functions = {f["name"]: f for f in functions}
if len(self._functions) != len(functions):
raise ValueError(
f"ProtoSimContract does not support overloaded function names! "
f"Encountered while loading {abi_name}."
)
def _encode_input(self, fname: str, args: list) -> bytearray:
func = self._functions[fname]
types = [collapse_if_tuple(t) for t in func["inputs"]]
selector = keccak(text=f"{fname}({','.join(types)})")[:4]
return bytearray(selector + eth_abi.encode(types, args))
def _decode_output(self, fname: str, encoded: list[int]) -> Any:
func = self._functions[fname]
types = [collapse_if_tuple(t) for t in func["outputs"]]
return eth_abi.decode(types, bytearray(encoded))
def call(
self,
fname: str,
*args: list[Union[int, str, bool, bytes]],
block_number,
timestamp: int = None,
overrides: TStateOverwrites = None,
caller: Address = EXTERNAL_ACCOUNT,
value: int = 0,
) -> ProtoSimResponse:
call_data = self._encode_input(fname, *args)
params = SimulationParameters(
data=call_data,
to=self.address,
block_number=block_number,
timestamp=timestamp or int(time.time()),
overrides=overrides or {},
caller=caller,
value=value,
)
sim_result = self._simulate(params)
try:
output = self._decode_output(fname, sim_result.result)
except DecodingError:
log.warning("Failed to decode output")
output = None
return ProtoSimResponse(output, sim_result)
def _simulate(self, params: SimulationParameters) -> "SimulationResult":
"""Run simulation and handle errors.
It catches a RuntimeError:
- if it's ``Execution reverted``, re-raises a RuntimeError
with a Tenderly link added
- if it's ``Out of gas``, re-raises a RecoverableSimulationException
- otherwise it just re-raises the original error.
"""
try:
simulation_result = self.engine.run_sim(params)
return simulation_result
except RuntimeError as err:
try:
coerced_err = maybe_coerce_error(err, self, params.gas_limit)
except Exception:
log.exception("Couldn't coerce error. Re-raising the original one.")
raise err
msg = str(coerced_err)
if "Revert!" in msg:
raise type(coerced_err)(msg, repr(self)) from err
else:
raise coerced_err
class AdapterContract(ProtoSimContract):
"""
The AdapterContract provides an interface to interact with the protocols implemented
by third parties using the `propeller-protocol-lib`.
"""
def __init__(self, address: Address, engine: SimulationEngine):
super().__init__(address, "ISwapAdapter", engine)
def price(
self,
pair_id: HexStr,
sell_token: EthereumToken,
buy_token: EthereumToken,
amounts: list[int],
block: EVMBlock,
overwrites: TStateOverwrites = None,
) -> list[Fraction]:
args = [HexBytes(pair_id), sell_token.address, buy_token.address, amounts]
res = self.call(
"price",
args,
block_number=block.id,
timestamp=int(block.ts.timestamp()),
overrides=overwrites,
)
return list(map(lambda x: Fraction(*x), res.return_value[0]))
def swap(
self,
pair_id: HexStr,
sell_token: EthereumToken,
buy_token: EthereumToken,
is_buy: bool,
amount: Decimal,
block: EVMBlock,
overwrites: TStateOverwrites = None,
) -> tuple[Trade, dict[str, StateUpdate]]:
args = [
HexBytes(pair_id),
sell_token.address,
buy_token.address,
int(is_buy),
amount,
]
res = self.call(
"swap",
args,
block_number=block.id,
timestamp=int(block.ts.timestamp()),
overrides=overwrites,
)
amount, gas, price = res.return_value[0]
return Trade(amount, gas, Fraction(*price)), res.simulation_result.state_updates
def get_limits(
self,
pair_id: HexStr,
sell_token: EthereumToken,
buy_token: EthereumToken,
block: EVMBlock,
overwrites: TStateOverwrites = None,
) -> tuple[int, int]:
args = [HexBytes(pair_id), sell_token.address, buy_token.address]
res = self.call(
"getLimits",
args,
block_number=block.id,
timestamp=int(block.ts.timestamp()),
overrides=overwrites,
)
return res.return_value[0]
def get_capabilities(
self, pair_id: HexStr, sell_token: EthereumToken, buy_token: EthereumToken
) -> set[Capability]:
args = [HexBytes(pair_id), sell_token.address, buy_token.address]
res = self.call("getCapabilities", args, block_number=1)
return set(map(Capability, res.return_value[0]))
def min_gas_usage(self) -> int:
res = self.call("minGasUsage", [], block_number=1)
return res.return_value[0]

View File

@@ -1,8 +1,12 @@
from decimal import Decimal
from logging import getLogger
from typing import Any
from protosim_py import SimulationEngine
from tycho.tycho.exceptions import TychoDecodeError
from tycho.tycho.models import EVMBlock, ThirdPartyPool, EthereumToken
from tycho.tycho.models import EVMBlock, EthereumToken, DatabaseType
from tycho.tycho.pool_state import ThirdPartyPool
from tycho.tycho.utils import decode_tycho_exchange
log = getLogger(__name__)
@@ -60,7 +64,6 @@ class ThirdPartyPoolTychoDecoder:
adapter_contract_name=self.adapter_contract,
minimum_gas=self.minimum_gas,
hard_sell_limit=self.hard_limit,
db_type=DatabaseType.tycho,
trace=True,
**optional_attributes,
)

View File

@@ -1,3 +1,6 @@
from decimal import Decimal
class TychoDecodeError(Exception):
def __init__(self, msg: str, pool_id: str):
super().__init__(msg)
@@ -6,3 +9,47 @@ class TychoDecodeError(Exception):
class APIRequestError(Exception):
pass
class TradeSimulationException(Exception):
def __init__(self, message, pool_id: str):
self.pool_id = pool_id
super().__init__(message)
class RecoverableSimulationException(TradeSimulationException):
"""Marks that the simulation could not fully fulfill the requested order.
Provides a partial trade that is valid but does not fully fulfill the conditions
requested.
Parameters
----------
message
Error message
pool_id
ID of a pool that caused the error
partial_trade
A tuple of (bought_amount, gas_used, new_pool_state, sold_amount)
"""
def __init__(
self,
message,
pool_id: str,
partial_trade: tuple[Decimal, int, "ThirdPartyPool", Decimal] = None,
):
super().__init__(message, pool_id)
self.partial_trade = partial_trade
class OutOfGas(RecoverableSimulationException):
"""This exception indicates that the underlying VM **likely** ran out of gas.
It is not easy to judge whether it was really due to out of gas, as the details
of the SC being called might be hiding this. E.g. out of gas may happen while
calling an external contract, which might show as the external call failing, although
it was due to a lack of gas.
"""
pass

View File

@@ -1,9 +1,11 @@
import datetime
from enum import Enum
from enum import Enum, IntEnum, auto
from typing import Union
from pydantic import BaseModel, Field
Address = str
class Blockchain(Enum):
ethereum = "ethereum"
@@ -18,12 +20,26 @@ class EVMBlock(BaseModel):
hash_: str
class ThirdPartyPool:
pass
class EthereumToken(BaseModel):
symbol: str
address: str
decimals: int
gas: Union[int, list[int]] = 29000
class DatabaseType(Enum):
# Make call to the node each time it needs a storage (unless cached from a previous call).
rpc_reader = "rpc_reader"
# Connect to Tycho and cache the whole state of a target contract, the state is continuously updated by Tycho.
# To use this we need Tycho to be configured to index the target contract state.
tycho = "tycho"
class Capability(IntEnum):
SellSide = auto()
BuySide = auto()
PriceFunction = auto()
FeeOnTransfer = auto()
ConstantPrice = auto()
TokenBalanceIndependent = auto()
ScaledPrice = auto()

394
tycho/tycho/pool_state.py Normal file
View File

@@ -0,0 +1,394 @@
import functools
import itertools
from collections import defaultdict
from copy import deepcopy
from decimal import Decimal
from fractions import Fraction
from logging import getLogger
from typing import Optional, cast, TypeVar
from protosim_py import SimulationEngine, AccountInfo
from pydantic import BaseModel, PrivateAttr, Field
from tycho.tycho.adapter_contract import AdapterContract
from tycho.tycho.constants import MAX_BALANCE, EXTERNAL_ACCOUNT
from tycho.tycho.exceptions import RecoverableSimulationException
from tycho.tycho.models import (
EVMBlock,
DatabaseType,
Capability,
Address,
EthereumToken,
)
from tycho.tycho.utils import (
create_engine,
get_contract_bytecode,
frac_to_decimal,
ERC20OverwriteFactory,
)
from eth_typing import HexStr
ADAPTER_ADDRESS = "0xA2C5C98A892fD6656a7F39A2f63228C0Bc846270"
log = getLogger(__name__)
TPoolState = TypeVar("TPoolState", bound="ThirdPartyPool")
class ThirdPartyPool(BaseModel):
id_: str
tokens: tuple[EthereumToken, ...]
balances: dict[Address, Decimal]
block: EVMBlock
spot_prices: dict[tuple[EthereumToken, EthereumToken], Decimal]
trading_fee: Decimal
exchange: str
minimum_gas: int
_engine: SimulationEngine = PrivateAttr(default=None)
adapter_contract_name: str
"""The adapters contract name. Used to look up the byte code for the adapter."""
_adapter_contract: AdapterContract = PrivateAttr(default=None)
stateless_contracts: dict[str, bytes] = {}
"""The address to bytecode map of all stateless contracts used by the protocol for simulations."""
capabilities: set[Capability] = Field(default_factory=lambda: {Capability.SellSide})
"""The supported capabilities of this pool."""
balance_owner: Optional[str] = None
"""The contract address for where protocol balances are stored (i.e. a vault contract).
If given, balances will be overwritten here instead of on the pool contract during simulations."""
block_lasting_overwrites: defaultdict[Address, dict[int, int]] = Field(
default_factory=lambda: defaultdict(dict)
)
"""Storage overwrites that will be applied to all simulations. They will be cleared
when ``clear_all_cache`` is called, i.e. usually at each block. Hence the name."""
trace: bool = False
hard_sell_limit: bool = False
"""
Whether the pool will revert if you attempt to sell more than the limit. Defaults to
False where it is assumed that exceeding the limit will provide a bad price but will
still succeed.
"""
def __init__(self, **data):
super().__init__(**data)
self._set_engine(data.get("engine", None))
self.balance_owner = data.get("balance_owner", None)
self._adapter_contract = AdapterContract(ADAPTER_ADDRESS, self._engine)
self._set_capabilities()
if len(self.spot_prices) == 0:
self._set_spot_prices()
def _set_engine(self, engine: Optional[SimulationEngine]):
"""Set instance's simulation engine. If no engine given, make a default one.
If engine is already set, this is a noop.
The engine will have the specified adapter contract mocked, as well as the
tokens used by the pool.
Parameters
----------
engine
Optional simulation engine instance.
"""
if self._engine is not None:
return
else:
engine = create_engine([t.address for t in self.tokens], trace=self.trace)
engine.init_account(
address=ADAPTER_ADDRESS,
account=AccountInfo(
balance=0,
nonce=0,
code=get_contract_bytecode(self.adapter_contract_name),
),
mocked=False,
permanent_storage=None,
)
for addr, bytecode in self.stateless_contracts.items():
engine.init_account(
address=addr,
account=AccountInfo(balance=0, nonce=0, code=bytecode),
mocked=False,
permanent_storage=None,
)
self._engine = engine
"""Set the spot prices for this pool.
We currently require the price function capability for now.
"""
self._ensure_capability(Capability.PriceFunction)
for t0, t1 in itertools.permutations(self.tokens, 2):
sell_amount = t0.to_onchain_amount(
self.get_sell_amount_limit(t0, t1) * Decimal("0.01")
)
frac = self._adapter_contract.price(
cast(HexStr, self.id_),
t0,
t1,
[sell_amount],
block=self.block,
overwrites=self.block_lasting_overwrites,
)[0]
if Capability.ScaledPrice in self.capabilities:
self.spot_prices[(t0, t1)] = frac_to_decimal(frac)
else:
scaled = frac * Fraction(10 ** t0.decimals, 10 ** t1.decimals)
self.spot_prices[(t0, t1)] = frac_to_decimal(scaled)
def _ensure_capability(self, capability: Capability):
"""Ensures the protocol/adapter implement a certain capability."""
if capability not in self.capabilities:
raise NotImplemented(f"{capability} not available!")
def _set_capabilities(self):
"""Sets capabilities of the pool."""
capabilities = []
for t0, t1 in itertools.permutations(self.tokens, 2):
capabilities.append(
self._adapter_contract.get_capabilities(cast(HexStr, self.id_), t0, t1)
)
max_capabilities = max(map(len, capabilities))
self.capabilities = functools.reduce(set.intersection, capabilities)
if len(self.capabilities) < max_capabilities:
log.warning(
f"Pool {self.id_} hash different capabilities depending on the token pair!"
)
def get_amount_out(
self: TPoolState,
sell_token: EthereumToken,
sell_amount: Decimal,
buy_token: EthereumToken,
slippage: Decimal = Decimal(0),
create_new_pool: bool = True,
) -> tuple[Decimal, int, TPoolState]:
# if the pool has a hard limit and the sell amount exceeds that, simulate and
# raise a partial trade
if self.hard_sell_limit:
sell_limit = self.get_sell_amount_limit(sell_token, buy_token)
if sell_amount > sell_limit:
partial_trade = self._get_amount_out(sell_token, sell_limit, buy_token)
raise RecoverableSimulationException(
"Sell amount exceeds sell limit",
repr(self),
partial_trade + (sell_limit,),
)
return self._get_amount_out(sell_token, sell_amount, buy_token)
def _get_amount_out(
self: TPoolState,
sell_token: EthereumToken,
sell_amount: Decimal,
buy_token: EthereumToken,
) -> tuple[Decimal, int, TPoolState]:
trade, state_changes = self._adapter_contract.swap(
cast(HexStr, self.id_),
sell_token,
buy_token,
False,
sell_token.to_onchain_amount(sell_amount),
block=self.block,
overwrites=self._get_overwrites(sell_token, buy_token),
)
new_state = self._duplicate()
for address, state_update in state_changes.items():
for slot, value in state_update.storage.items():
new_state.block_lasting_overwrites[address][slot] = value
new_price = frac_to_decimal(trade.price)
if new_price != Decimal(0):
new_state.spot_prices = {
(sell_token, buy_token): new_price,
(buy_token, sell_token): Decimal(1) / new_price,
}
buy_amount = buy_token.from_onchain_amount(trade.received_amount)
return buy_amount, trade.gas_used, new_state
def _get_overwrites(
self, sell_token: EthereumToken, buy_token: EthereumToken, **kwargs
) -> dict[Address, dict[int, int]]:
"""Get an overwrites dictionary to use in a simulation.
The returned overwrites include block-lasting overwrites set on the instance
level, and token-specific overwrites that depend on passed tokens.
"""
token_overwrites = self._get_token_overwrites(sell_token, buy_token, **kwargs)
return _merge(self.block_lasting_overwrites, token_overwrites)
def _get_token_overwrites(
self, sell_token: EthereumToken, buy_token: EthereumToken, max_amount=None
) -> dict[Address, dict[int, int]]:
"""Creates overwrites for a token.
Funds external account with enough tokens to execute swaps. Also creates a
corresponding approval to the adapter contract.
If the protocol reads its own token balances, the balances for the underlying
pool contract will also be overwritten.
"""
res = []
if Capability.TokenBalanceIndependent not in self.capabilities:
res = [self._get_balance_overwrites()]
# avoids recursion if using this method with get_sell_amount_limit
if max_amount is None:
max_amount = sell_token.to_onchain_amount(
self.get_sell_amount_limit(sell_token, buy_token)
)
overwrites = ERC20OverwriteFactory(sell_token)
overwrites.set_balance(max_amount, EXTERNAL_ACCOUNT)
overwrites.set_allowance(
allowance=max_amount, owner=EXTERNAL_ACCOUNT, spender=ADAPTER_ADDRESS
)
res.append(overwrites.get_protosim_overwrites())
# we need to merge the dictionaries because balance overwrites may target
# the same token address.
res = functools.reduce(_merge, res)
return res
def _get_balance_overwrites(self) -> dict[Address, dict[int, int]]:
balance_overwrites = {}
address = self.balance_owner or self.id_
for t in self.tokens:
overwrites = ERC20OverwriteFactory(t)
overwrites.set_balance(
t.to_onchain_amount(self.balances[t.address]), address
)
balance_overwrites.update(overwrites.get_protosim_overwrites())
return balance_overwrites
def _duplicate(self: type["ThirdPartyPool"]) -> "ThirdPartyPool":
"""Make a new instance identical to self that shares the same simulation engine.
Note that the new and current state become coupled in a way that they must
simulate the same block. This is fine, see
https://datarevenue.atlassian.net/browse/ROC-1301
Not naming this method _copy to not confuse with pydantic's .copy method.
"""
return type(self)(
exchange=self.exchange,
adapter_contract_name=self.adapter_contract_name,
block=self.block,
id_=self.id_,
tokens=self.tokens,
spot_prices=self.spot_prices.copy(),
trading_fee=self.trading_fee,
block_lasting_overwrites=deepcopy(self.block_lasting_overwrites),
engine=self._engine,
balances=self.balances,
minimum_gas=self.minimum_gas,
hard_sell_limit=self.hard_sell_limit,
balance_owner=self.balance_owner,
stateless_contracts=self.stateless_contracts,
)
def get_sell_amount_limit(
self, sell_token: EthereumToken, buy_token: EthereumToken
) -> Decimal:
"""
Retrieves the sell amount of the given token.
For pools with more than 2 tokens, the sell limit is obtain for all possible buy token
combinations and the minimum is returned.
"""
limit = self._adapter_contract.get_limits(
cast(HexStr, self.id_),
sell_token,
buy_token,
block=self.block,
overwrites=self._get_overwrites(
sell_token, buy_token, max_amount=MAX_BALANCE // 100
),
)[0]
return sell_token.from_onchain_amount(limit)
def simulate_transition(
self: TPoolState,
sell_token: EthereumToken,
buy_token: EthereumToken,
target_price: Decimal,
max_sell_amount: Decimal,
) -> tuple[Decimal, Decimal, int, TPoolState]:
pass
def update_inplace(self, new: "ThirdPartyPool"):
"""
Updates the current `ThirdPartyPool` in-place.
If the block attribute of the `new` state differs from the block attribute of
the current state, the temporary storage of the simulation engine is cleared.
Parameters
----------
new
The new state object to update from.
"""
old_block = self.block
super(ThirdPartyPool, self).update_inplace(new)
self.block_lasting_overwrites = new.block_lasting_overwrites.copy()
if new.block != old_block:
self.clear_all_cache()
def clear_all_cache(self):
self._engine.clear_temp_storage()
self.block_lasting_overwrites = defaultdict(dict)
self._set_spot_prices()
# def transition(self, event: ProtocolEvent) -> "ThirdPartyPool":
# """Make a new pool state so that everything's initialised from scratch.
#
# Instead of interpreting the event and applying the changes in signals, we just
# create a fresh instance of the pool. This way all dynamic parameters will be
# set again using the on-chain data.
# """
# new = type(self)(
# block=self.block,
# id_=self.id_,
# tokens=self.tokens,
# balances=self.balances.copy(),
# )
# if isinstance(event, ERC20Transfer):
# transition_balances_inplace(self.id_, event, new.balances)
# return new
#
def _merge(a: dict, b: dict, path=None):
"""
Merges two dictionaries (a and b) deeply. This means it will traverse and combine
their nested dictionaries too if present.
Parameters:
a (dict): The first dictionary to merge.
b (dict): The second dictionary to merge into the first one.
path (list, optional): An internal parameter used during recursion
to keep track of the ancestry of nested dictionaries.
Returns:
a (dict): The merged dictionary which includes all key-value pairs from b
added into a. If they have nested dictionaries with same keys, those are also merged.
On key conflicts, preference is given to values from b.
"""
if path is None:
path = []
for key in b:
if key in a:
if isinstance(a[key], dict) and isinstance(b[key], dict):
_merge(a[key], b[key], path + [str(key)])
else:
a[key] = b[key]
return a

View File

@@ -28,7 +28,9 @@ from tycho.tycho.constants import (
)
from tycho.tycho.decoders import ThirdPartyPoolTychoDecoder
from tycho.tycho.exceptions import APIRequestError
from tycho.tycho.models import Blockchain, EVMBlock, ThirdPartyPool, EthereumToken
from tycho.tycho.models import Blockchain, EVMBlock, EthereumToken
from tycho.tycho.pool_state import ThirdPartyPool
from tycho.tycho.utils import create_engine, TychoDBSingleton
log = getLogger(__name__)
@@ -124,14 +126,9 @@ class TychoPoolStateStreamAdapter:
self._decoder = decoder
# Create engine
self._db = TychoDB(tycho_http_url=self.tycho_url)
self._engine = SimulationEngine.new_with_tycho_db(db=self._db, trace=True)
self._engine.init_account(
address=EXTERNAL_ACCOUNT,
account=AccountInfo(balance=MAX_BALANCE, nonce=0, code=None),
mocked=False,
permanent_storage=None,
)
# TODO: This should be initialized outside the adapter?
TychoDBSingleton.initialize(tycho_http_url=self.tycho_url)
self._engine = create_engine([], state_block=None, trace=True)
# Loads tokens from Tycho
self._tokens: dict[str, EthereumToken] = TokenLoader(

0
tycho/tycho/tycho_db.py Normal file
View File

View File

@@ -1,3 +1,343 @@
import json
import os
from decimal import Decimal
from fractions import Fraction
from functools import lru_cache
from logging import getLogger
from pathlib import Path
from typing import Final, Any
import eth_abi
from eth_typing import HexStr
from hexbytes import HexBytes
from protosim_py import SimulationEngine, TychoDB, AccountInfo
from web3 import Web3
from tycho.tycho.constants import EXTERNAL_ACCOUNT, MAX_BALANCE
from tycho.tycho.exceptions import OutOfGas
from tycho.tycho.models import Address, EthereumToken
log = getLogger(__name__)
def decode_tycho_exchange(exchange: str) -> (str, bool):
# removes vm prefix if present, returns True if vm prefix was present (vm protocol) or False if native protocol
return (exchange.split(":")[1], False) if "vm:" in exchange else (exchange, True)
class TychoDBSingleton:
"""
A singleton wrapper around the TychoDB class.
This class ensures that there is only one instance of TychoDB throughout the lifetime of the program,
avoiding the overhead of creating multiple instances.
"""
_instance = None
@classmethod
def initialize(cls, tycho_http_url: str):
"""
Initialize the TychoDB instance with the given URLs.
Parameters
----------
tycho_http_url : str
The URL of the Tycho HTTP server.
"""
cls._instance = TychoDB(tycho_http_url=tycho_http_url)
@classmethod
def get_instance(cls) -> TychoDB:
"""
Retrieve the singleton instance of TychoDB.
If the TychoDB instance does not exist, it creates a new one.
If it already exists, it returns the existing instance.
Returns
-------
TychoDB
The singleton instance of TychoDB.
"""
if cls._instance is None:
raise ValueError(
"TychoDB instance not initialized. Call initialize() first."
)
return cls._instance
@classmethod
def clear_instance(cls):
cls._instance = None
def create_engine(
mocked_tokens: list[Address], trace: bool = False
) -> SimulationEngine:
"""Create a simulation engine with a mocked ERC20 contract at given addresses.
Parameters
----------
mocked_tokens
A list of addresses at which a mocked ERC20 contract should be inserted.
trace
Whether to trace calls, only meant for debugging purposes, might print a lot of
data to stdout.
"""
db = TychoDBSingleton.get_instance()
engine = SimulationEngine.new_with_tycho_db(db=db, trace=trace)
for t in mocked_tokens:
info = AccountInfo(balance=0, nonce=0, code=get_contract_bytecode("ERC20.bin"))
engine.init_account(
address=t, account=info, mocked=True, permanent_storage=None
)
engine.init_account(
address=EXTERNAL_ACCOUNT,
account=AccountInfo(balance=MAX_BALANCE, nonce=0, code=None),
mocked=False,
permanent_storage=None,
)
return engine
class ERC20OverwriteFactory:
def __init__(self, token: EthereumToken):
"""
Initialize the ERC20OverwriteFactory.
Parameters:
token: The token object.
"""
self._token = token
self._overwrites = dict()
self._balance_slot: Final[int] = 0
self._allowance_slot: Final[int] = 1
def set_balance(self, balance: int, owner: Address):
"""
Set the balance for a given owner.
Parameters:
balance: The balance value.
owner: The owner's address.
"""
storage_index = get_storage_slot_at_key(HexStr(owner), self._balance_slot)
self._overwrites[storage_index] = balance
log.log(
5,
f"Override balance: token={self._token.address} owner={owner}"
f"value={balance} slot={storage_index}",
)
def set_allowance(self, allowance: int, spender: Address, owner: Address):
"""
Set the allowance for a given spender and owner.
Parameters:
allowance: The allowance value.
spender: The spender's address.
owner: The owner's address.
"""
storage_index = get_storage_slot_at_key(
HexStr(spender),
get_storage_slot_at_key(HexStr(owner), self._allowance_slot),
)
self._overwrites[storage_index] = allowance
log.log(
5,
f"Override allowance: token={self._token.address} owner={owner}"
f"spender={spender} value={allowance} slot={storage_index}",
)
def get_protosim_overwrites(self) -> dict[Address, dict[int, int]]:
"""
Get the overwrites dictionary of previously collected values.
Returns:
dict[Address, dict]: A dictionary containing the token's address
and the overwrites.
"""
# Protosim returns lowercase addresses in state updates returned from simulation
return {self._token.address.lower(): self._overwrites}
def get_geth_overwrites(self) -> dict[Address, dict[int, int]]:
"""
Get the overwrites dictionary of previously collected values.
Returns:
dict[Address, dict]: A dictionary containing the token's address
and the overwrites.
"""
formatted_overwrites = {
HexBytes(key).hex(): "0x" + HexBytes(val).hex().lstrip("0x").zfill(64)
for key, val in self._overwrites.items()
}
code = "0x" + get_contract_bytecode("ERC20.bin").hex()
return {self._token.address: {"stateDiff": formatted_overwrites, "code": code}}
def get_storage_slot_at_key(key: Address, mapping_slot: int) -> int:
"""Get storage slot index of a value stored at a certain key in a mapping
Parameters
----------
key
Key in a mapping. This function is meant to work with ethereum addresses
and accepts only strings.
mapping_slot
Storage slot at which the mapping itself is stored. See the examples for more
explanation.
Returns
-------
slot
An index of a storage slot where the value at the given key is stored.
Examples
--------
If a mapping is declared as a first variable in solidity code, its storage slot
is 0 (e.g. ``balances`` in our mocked ERC20 contract). Here's how to compute
a storage slot where balance of a given account is stored::
get_storage_slot_at_key("0xC63135E4bF73F637AF616DFd64cf701866BB2628", 0)
For nested mappings, we need to apply the function twice. An example of this is
``allowances`` in ERC20. It is a mapping of form:
``dict[owner, dict[spender, value]]``. In our mocked ERC20 contract, ``allowances``
is a second variable, so it is stored at slot 1. Here's how to get a storage slot
where an allowance of ``0xspender`` to spend ``0xowner``'s money is stored::
get_storage_slot_at_key("0xspender", get_storage_slot_at_key("0xowner", 1)))
See Also
--------
`Solidity Storage Layout documentation
<https://docs.soliditylang.org/en/v0.8.13/internals/layout_in_storage.html#mappings-and-dynamic-arrays>`_
"""
key_bytes = bytes.fromhex(key[2:]).rjust(32, b"\0")
mapping_slot_bytes = int.to_bytes(mapping_slot, 32, "big")
slot_bytes = Web3.keccak(key_bytes + mapping_slot_bytes)
return int.from_bytes(slot_bytes, "big")
@lru_cache
def get_contract_bytecode(name: str) -> bytes:
with open(Path(__file__).parent / "assets" / name, "rb") as fh:
code = fh.read()
return code
def frac_to_decimal(frac: Fraction) -> Decimal:
return Decimal(frac.numerator) / Decimal(frac.denominator)
def load_abi(name_or_path: str) -> dict:
if os.path.exists(abspath := os.path.abspath(name_or_path)):
path = abspath
else:
path = f"{os.path.dirname(os.path.abspath(__file__))}/assets/{name_or_path}.abi"
try:
with open(os.path.abspath(path)) as f:
abi: dict = json.load(f)
except FileNotFoundError:
search_dir = f"{os.path.dirname(os.path.abspath(__file__))}/assets/"
# List all files in search dir and subdirs suggest them to the user in an error message
available_files = []
for dirpath, dirnames, filenames in os.walk(search_dir):
for filename in filenames:
# Make paths relative to search_dir
relative_path = os.path.relpath(
os.path.join(dirpath, filename), search_dir
)
available_files.append(relative_path.replace(".abi", ""))
raise FileNotFoundError(
f"File {name_or_path} not found. "
f"Did you mean one of these? {', '.join(available_files)}"
)
return abi
# https://docs.soliditylang.org/en/latest/control-structures.html#panic-via-assert-and-error-via-require
solidity_panic_codes = {
0: "GenericCompilerPanic",
1: "AssertionError",
17: "ArithmeticOver/Underflow",
18: "ZeroDivisionError",
33: "UnkownEnumMember",
34: "BadStorageByteArrayEncoding",
51: "EmptyArray",
0x32: "OutOfBounds",
0x41: "OutOfMemory",
0x51: "BadFunctionPointer",
}
def parse_solidity_error_message(data) -> str:
data_bytes = HexBytes(data)
error_string = f"Failed to decode: {data}"
# data is encoded as Error(string)
if data_bytes[:4] == HexBytes("0x08c379a0"):
(error_string,) = eth_abi.decode(["string"], data_bytes[4:])
return error_string
elif data_bytes[:4] == HexBytes("0x4e487b71"):
(error_code,) = eth_abi.decode(["uint256"], data_bytes[4:])
return solidity_panic_codes.get(error_code, f"Panic({error_code})")
# old solidity: revert 'some string' case
try:
(error_string,) = eth_abi.decode(["string"], data_bytes)
return error_string
except Exception:
pass
# some custom error maybe it is with string?
try:
(error_string,) = eth_abi.decode(["string"], data_bytes[4:])
return error_string
except Exception:
pass
try:
(error_string,) = eth_abi.decode(["string"], data_bytes[4:])
return error_string
except Exception:
pass
return error_string
def maybe_coerce_error(
err: RuntimeError, pool_state: Any, gas_limit: int = None
) -> Exception:
details = err.args[0]
# we got bytes as data, so this was a revert
if details.data.startswith("0x"):
err = RuntimeError(
f"Revert! Reason: {parse_solidity_error_message(details.data)}"
)
# we have gas information, check if this likely an out of gas err.
if gas_limit is not None and details.gas_used is not None:
# if we used up 97% or more issue a OutOfGas error.
usage = details.gas_used / gas_limit
if usage >= 0.97:
return OutOfGas(
f"SimulationError: Likely out-of-gas. "
f"Used: {usage * 100:.2f}% of gas limit. "
f"Original error: {err}",
repr(pool_state),
)
elif "OutOfGas" in details.data:
if gas_limit is not None:
usage = details.gas_used / gas_limit
usage_msg = f"Used: {usage * 100:.2f}% of gas limit. "
else:
usage_msg = ""
return OutOfGas(
f"SimulationError: out-of-gas. {usage_msg}Original error: {details.data}",
repr(pool_state),
)
return err