Add ThirdPartyPoolState and necessary auxiliary code
This commit is contained in:
211
tycho/tycho/adapter_contract.py
Normal file
211
tycho/tycho/adapter_contract.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import logging
|
||||
import time
|
||||
from decimal import Decimal
|
||||
from fractions import Fraction
|
||||
from typing import Any, Union, NamedTuple
|
||||
|
||||
import eth_abi
|
||||
from eth_abi.exceptions import DecodingError
|
||||
from eth_typing import HexStr
|
||||
from eth_utils import keccak
|
||||
from eth_utils.abi import collapse_if_tuple
|
||||
from hexbytes import HexBytes
|
||||
from protosim_py import (
|
||||
SimulationEngine,
|
||||
SimulationParameters,
|
||||
SimulationResult,
|
||||
StateUpdate,
|
||||
)
|
||||
|
||||
from tycho.tycho.constants import EXTERNAL_ACCOUNT
|
||||
from tycho.tycho.models import Address, EthereumToken, EVMBlock, Capability
|
||||
from tycho.tycho.utils import load_abi, maybe_coerce_error
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
TStateOverwrites = dict[Address, dict[int, int]]
|
||||
|
||||
|
||||
class Trade(NamedTuple):
|
||||
"""
|
||||
Trade represents a simple trading operation with fields:
|
||||
received_amount: Amount received from the trade
|
||||
gas_used: Amount of gas used in the transaction
|
||||
price: Price at which the trade was executed
|
||||
"""
|
||||
|
||||
received_amount: float
|
||||
gas_used: float
|
||||
price: float
|
||||
|
||||
|
||||
class ProtoSimResponse:
|
||||
def __init__(self, return_value: Any, simulation_result: "SimulationResult"):
|
||||
self.return_value = return_value
|
||||
self.simulation_result = simulation_result
|
||||
|
||||
|
||||
class ProtoSimContract:
|
||||
def __init__(self, address: Address, abi_name: str, engine: SimulationEngine):
|
||||
self.abi = load_abi(abi_name)
|
||||
self.address = address
|
||||
self.engine = engine
|
||||
self._default_tx_env = dict(
|
||||
caller=EXTERNAL_ACCOUNT, to=self.address, value=0, overrides={}
|
||||
)
|
||||
functions = [f for f in self.abi if f["type"] == "function"]
|
||||
self._functions = {f["name"]: f for f in functions}
|
||||
if len(self._functions) != len(functions):
|
||||
raise ValueError(
|
||||
f"ProtoSimContract does not support overloaded function names! "
|
||||
f"Encountered while loading {abi_name}."
|
||||
)
|
||||
|
||||
def _encode_input(self, fname: str, args: list) -> bytearray:
|
||||
func = self._functions[fname]
|
||||
types = [collapse_if_tuple(t) for t in func["inputs"]]
|
||||
selector = keccak(text=f"{fname}({','.join(types)})")[:4]
|
||||
return bytearray(selector + eth_abi.encode(types, args))
|
||||
|
||||
def _decode_output(self, fname: str, encoded: list[int]) -> Any:
|
||||
func = self._functions[fname]
|
||||
types = [collapse_if_tuple(t) for t in func["outputs"]]
|
||||
return eth_abi.decode(types, bytearray(encoded))
|
||||
|
||||
def call(
|
||||
self,
|
||||
fname: str,
|
||||
*args: list[Union[int, str, bool, bytes]],
|
||||
block_number,
|
||||
timestamp: int = None,
|
||||
overrides: TStateOverwrites = None,
|
||||
caller: Address = EXTERNAL_ACCOUNT,
|
||||
value: int = 0,
|
||||
) -> ProtoSimResponse:
|
||||
call_data = self._encode_input(fname, *args)
|
||||
params = SimulationParameters(
|
||||
data=call_data,
|
||||
to=self.address,
|
||||
block_number=block_number,
|
||||
timestamp=timestamp or int(time.time()),
|
||||
overrides=overrides or {},
|
||||
caller=caller,
|
||||
value=value,
|
||||
)
|
||||
sim_result = self._simulate(params)
|
||||
try:
|
||||
output = self._decode_output(fname, sim_result.result)
|
||||
except DecodingError:
|
||||
log.warning("Failed to decode output")
|
||||
output = None
|
||||
return ProtoSimResponse(output, sim_result)
|
||||
|
||||
def _simulate(self, params: SimulationParameters) -> "SimulationResult":
|
||||
"""Run simulation and handle errors.
|
||||
|
||||
It catches a RuntimeError:
|
||||
|
||||
- if it's ``Execution reverted``, re-raises a RuntimeError
|
||||
with a Tenderly link added
|
||||
- if it's ``Out of gas``, re-raises a RecoverableSimulationException
|
||||
- otherwise it just re-raises the original error.
|
||||
"""
|
||||
try:
|
||||
simulation_result = self.engine.run_sim(params)
|
||||
return simulation_result
|
||||
except RuntimeError as err:
|
||||
try:
|
||||
coerced_err = maybe_coerce_error(err, self, params.gas_limit)
|
||||
except Exception:
|
||||
log.exception("Couldn't coerce error. Re-raising the original one.")
|
||||
raise err
|
||||
msg = str(coerced_err)
|
||||
if "Revert!" in msg:
|
||||
raise type(coerced_err)(msg, repr(self)) from err
|
||||
else:
|
||||
raise coerced_err
|
||||
|
||||
|
||||
class AdapterContract(ProtoSimContract):
|
||||
"""
|
||||
The AdapterContract provides an interface to interact with the protocols implemented
|
||||
by third parties using the `propeller-protocol-lib`.
|
||||
"""
|
||||
|
||||
def __init__(self, address: Address, engine: SimulationEngine):
|
||||
super().__init__(address, "ISwapAdapter", engine)
|
||||
|
||||
def price(
|
||||
self,
|
||||
pair_id: HexStr,
|
||||
sell_token: EthereumToken,
|
||||
buy_token: EthereumToken,
|
||||
amounts: list[int],
|
||||
block: EVMBlock,
|
||||
overwrites: TStateOverwrites = None,
|
||||
) -> list[Fraction]:
|
||||
args = [HexBytes(pair_id), sell_token.address, buy_token.address, amounts]
|
||||
res = self.call(
|
||||
"price",
|
||||
args,
|
||||
block_number=block.id,
|
||||
timestamp=int(block.ts.timestamp()),
|
||||
overrides=overwrites,
|
||||
)
|
||||
return list(map(lambda x: Fraction(*x), res.return_value[0]))
|
||||
|
||||
def swap(
|
||||
self,
|
||||
pair_id: HexStr,
|
||||
sell_token: EthereumToken,
|
||||
buy_token: EthereumToken,
|
||||
is_buy: bool,
|
||||
amount: Decimal,
|
||||
block: EVMBlock,
|
||||
overwrites: TStateOverwrites = None,
|
||||
) -> tuple[Trade, dict[str, StateUpdate]]:
|
||||
args = [
|
||||
HexBytes(pair_id),
|
||||
sell_token.address,
|
||||
buy_token.address,
|
||||
int(is_buy),
|
||||
amount,
|
||||
]
|
||||
res = self.call(
|
||||
"swap",
|
||||
args,
|
||||
block_number=block.id,
|
||||
timestamp=int(block.ts.timestamp()),
|
||||
overrides=overwrites,
|
||||
)
|
||||
amount, gas, price = res.return_value[0]
|
||||
return Trade(amount, gas, Fraction(*price)), res.simulation_result.state_updates
|
||||
|
||||
def get_limits(
|
||||
self,
|
||||
pair_id: HexStr,
|
||||
sell_token: EthereumToken,
|
||||
buy_token: EthereumToken,
|
||||
block: EVMBlock,
|
||||
overwrites: TStateOverwrites = None,
|
||||
) -> tuple[int, int]:
|
||||
args = [HexBytes(pair_id), sell_token.address, buy_token.address]
|
||||
res = self.call(
|
||||
"getLimits",
|
||||
args,
|
||||
block_number=block.id,
|
||||
timestamp=int(block.ts.timestamp()),
|
||||
overrides=overwrites,
|
||||
)
|
||||
return res.return_value[0]
|
||||
|
||||
def get_capabilities(
|
||||
self, pair_id: HexStr, sell_token: EthereumToken, buy_token: EthereumToken
|
||||
) -> set[Capability]:
|
||||
args = [HexBytes(pair_id), sell_token.address, buy_token.address]
|
||||
res = self.call("getCapabilities", args, block_number=1)
|
||||
return set(map(Capability, res.return_value[0]))
|
||||
|
||||
def min_gas_usage(self) -> int:
|
||||
res = self.call("minGasUsage", [], block_number=1)
|
||||
return res.return_value[0]
|
||||
@@ -1,8 +1,12 @@
|
||||
from decimal import Decimal
|
||||
from logging import getLogger
|
||||
from typing import Any
|
||||
|
||||
from protosim_py import SimulationEngine
|
||||
|
||||
from tycho.tycho.exceptions import TychoDecodeError
|
||||
from tycho.tycho.models import EVMBlock, ThirdPartyPool, EthereumToken
|
||||
from tycho.tycho.models import EVMBlock, EthereumToken, DatabaseType
|
||||
from tycho.tycho.pool_state import ThirdPartyPool
|
||||
from tycho.tycho.utils import decode_tycho_exchange
|
||||
|
||||
log = getLogger(__name__)
|
||||
@@ -60,7 +64,6 @@ class ThirdPartyPoolTychoDecoder:
|
||||
adapter_contract_name=self.adapter_contract,
|
||||
minimum_gas=self.minimum_gas,
|
||||
hard_sell_limit=self.hard_limit,
|
||||
db_type=DatabaseType.tycho,
|
||||
trace=True,
|
||||
**optional_attributes,
|
||||
)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
class TychoDecodeError(Exception):
|
||||
def __init__(self, msg: str, pool_id: str):
|
||||
super().__init__(msg)
|
||||
@@ -6,3 +9,47 @@ class TychoDecodeError(Exception):
|
||||
|
||||
class APIRequestError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class TradeSimulationException(Exception):
|
||||
def __init__(self, message, pool_id: str):
|
||||
self.pool_id = pool_id
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class RecoverableSimulationException(TradeSimulationException):
|
||||
"""Marks that the simulation could not fully fulfill the requested order.
|
||||
|
||||
Provides a partial trade that is valid but does not fully fulfill the conditions
|
||||
requested.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
message
|
||||
Error message
|
||||
pool_id
|
||||
ID of a pool that caused the error
|
||||
partial_trade
|
||||
A tuple of (bought_amount, gas_used, new_pool_state, sold_amount)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
pool_id: str,
|
||||
partial_trade: tuple[Decimal, int, "ThirdPartyPool", Decimal] = None,
|
||||
):
|
||||
super().__init__(message, pool_id)
|
||||
self.partial_trade = partial_trade
|
||||
|
||||
|
||||
class OutOfGas(RecoverableSimulationException):
|
||||
"""This exception indicates that the underlying VM **likely** ran out of gas.
|
||||
|
||||
It is not easy to judge whether it was really due to out of gas, as the details
|
||||
of the SC being called might be hiding this. E.g. out of gas may happen while
|
||||
calling an external contract, which might show as the external call failing, although
|
||||
it was due to a lack of gas.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import datetime
|
||||
from enum import Enum
|
||||
from enum import Enum, IntEnum, auto
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
Address = str
|
||||
|
||||
|
||||
class Blockchain(Enum):
|
||||
ethereum = "ethereum"
|
||||
@@ -18,12 +20,26 @@ class EVMBlock(BaseModel):
|
||||
hash_: str
|
||||
|
||||
|
||||
class ThirdPartyPool:
|
||||
pass
|
||||
|
||||
|
||||
class EthereumToken(BaseModel):
|
||||
symbol: str
|
||||
address: str
|
||||
decimals: int
|
||||
gas: Union[int, list[int]] = 29000
|
||||
|
||||
|
||||
class DatabaseType(Enum):
|
||||
# Make call to the node each time it needs a storage (unless cached from a previous call).
|
||||
rpc_reader = "rpc_reader"
|
||||
# Connect to Tycho and cache the whole state of a target contract, the state is continuously updated by Tycho.
|
||||
# To use this we need Tycho to be configured to index the target contract state.
|
||||
tycho = "tycho"
|
||||
|
||||
|
||||
class Capability(IntEnum):
|
||||
SellSide = auto()
|
||||
BuySide = auto()
|
||||
PriceFunction = auto()
|
||||
FeeOnTransfer = auto()
|
||||
ConstantPrice = auto()
|
||||
TokenBalanceIndependent = auto()
|
||||
ScaledPrice = auto()
|
||||
|
||||
394
tycho/tycho/pool_state.py
Normal file
394
tycho/tycho/pool_state.py
Normal file
@@ -0,0 +1,394 @@
|
||||
import functools
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from decimal import Decimal
|
||||
from fractions import Fraction
|
||||
from logging import getLogger
|
||||
from typing import Optional, cast, TypeVar
|
||||
|
||||
from protosim_py import SimulationEngine, AccountInfo
|
||||
from pydantic import BaseModel, PrivateAttr, Field
|
||||
|
||||
from tycho.tycho.adapter_contract import AdapterContract
|
||||
from tycho.tycho.constants import MAX_BALANCE, EXTERNAL_ACCOUNT
|
||||
from tycho.tycho.exceptions import RecoverableSimulationException
|
||||
from tycho.tycho.models import (
|
||||
EVMBlock,
|
||||
DatabaseType,
|
||||
Capability,
|
||||
Address,
|
||||
EthereumToken,
|
||||
)
|
||||
from tycho.tycho.utils import (
|
||||
create_engine,
|
||||
get_contract_bytecode,
|
||||
frac_to_decimal,
|
||||
ERC20OverwriteFactory,
|
||||
)
|
||||
from eth_typing import HexStr
|
||||
|
||||
ADAPTER_ADDRESS = "0xA2C5C98A892fD6656a7F39A2f63228C0Bc846270"
|
||||
|
||||
log = getLogger(__name__)
|
||||
TPoolState = TypeVar("TPoolState", bound="ThirdPartyPool")
|
||||
|
||||
|
||||
class ThirdPartyPool(BaseModel):
|
||||
id_: str
|
||||
tokens: tuple[EthereumToken, ...]
|
||||
balances: dict[Address, Decimal]
|
||||
block: EVMBlock
|
||||
spot_prices: dict[tuple[EthereumToken, EthereumToken], Decimal]
|
||||
trading_fee: Decimal
|
||||
exchange: str
|
||||
minimum_gas: int
|
||||
|
||||
_engine: SimulationEngine = PrivateAttr(default=None)
|
||||
|
||||
adapter_contract_name: str
|
||||
"""The adapters contract name. Used to look up the byte code for the adapter."""
|
||||
_adapter_contract: AdapterContract = PrivateAttr(default=None)
|
||||
|
||||
stateless_contracts: dict[str, bytes] = {}
|
||||
"""The address to bytecode map of all stateless contracts used by the protocol for simulations."""
|
||||
|
||||
capabilities: set[Capability] = Field(default_factory=lambda: {Capability.SellSide})
|
||||
"""The supported capabilities of this pool."""
|
||||
|
||||
balance_owner: Optional[str] = None
|
||||
"""The contract address for where protocol balances are stored (i.e. a vault contract).
|
||||
If given, balances will be overwritten here instead of on the pool contract during simulations."""
|
||||
|
||||
block_lasting_overwrites: defaultdict[Address, dict[int, int]] = Field(
|
||||
default_factory=lambda: defaultdict(dict)
|
||||
)
|
||||
"""Storage overwrites that will be applied to all simulations. They will be cleared
|
||||
when ``clear_all_cache`` is called, i.e. usually at each block. Hence the name."""
|
||||
|
||||
trace: bool = False
|
||||
|
||||
hard_sell_limit: bool = False
|
||||
"""
|
||||
Whether the pool will revert if you attempt to sell more than the limit. Defaults to
|
||||
False where it is assumed that exceeding the limit will provide a bad price but will
|
||||
still succeed.
|
||||
"""
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self._set_engine(data.get("engine", None))
|
||||
self.balance_owner = data.get("balance_owner", None)
|
||||
self._adapter_contract = AdapterContract(ADAPTER_ADDRESS, self._engine)
|
||||
self._set_capabilities()
|
||||
if len(self.spot_prices) == 0:
|
||||
self._set_spot_prices()
|
||||
|
||||
def _set_engine(self, engine: Optional[SimulationEngine]):
|
||||
"""Set instance's simulation engine. If no engine given, make a default one.
|
||||
|
||||
If engine is already set, this is a noop.
|
||||
|
||||
The engine will have the specified adapter contract mocked, as well as the
|
||||
tokens used by the pool.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
engine
|
||||
Optional simulation engine instance.
|
||||
"""
|
||||
if self._engine is not None:
|
||||
return
|
||||
else:
|
||||
engine = create_engine([t.address for t in self.tokens], trace=self.trace)
|
||||
engine.init_account(
|
||||
address=ADAPTER_ADDRESS,
|
||||
account=AccountInfo(
|
||||
balance=0,
|
||||
nonce=0,
|
||||
code=get_contract_bytecode(self.adapter_contract_name),
|
||||
),
|
||||
mocked=False,
|
||||
permanent_storage=None,
|
||||
)
|
||||
for addr, bytecode in self.stateless_contracts.items():
|
||||
engine.init_account(
|
||||
address=addr,
|
||||
account=AccountInfo(balance=0, nonce=0, code=bytecode),
|
||||
mocked=False,
|
||||
permanent_storage=None,
|
||||
)
|
||||
self._engine = engine
|
||||
|
||||
"""Set the spot prices for this pool.
|
||||
|
||||
We currently require the price function capability for now.
|
||||
"""
|
||||
self._ensure_capability(Capability.PriceFunction)
|
||||
for t0, t1 in itertools.permutations(self.tokens, 2):
|
||||
sell_amount = t0.to_onchain_amount(
|
||||
self.get_sell_amount_limit(t0, t1) * Decimal("0.01")
|
||||
)
|
||||
frac = self._adapter_contract.price(
|
||||
cast(HexStr, self.id_),
|
||||
t0,
|
||||
t1,
|
||||
[sell_amount],
|
||||
block=self.block,
|
||||
overwrites=self.block_lasting_overwrites,
|
||||
)[0]
|
||||
if Capability.ScaledPrice in self.capabilities:
|
||||
self.spot_prices[(t0, t1)] = frac_to_decimal(frac)
|
||||
else:
|
||||
scaled = frac * Fraction(10 ** t0.decimals, 10 ** t1.decimals)
|
||||
self.spot_prices[(t0, t1)] = frac_to_decimal(scaled)
|
||||
|
||||
def _ensure_capability(self, capability: Capability):
|
||||
"""Ensures the protocol/adapter implement a certain capability."""
|
||||
if capability not in self.capabilities:
|
||||
raise NotImplemented(f"{capability} not available!")
|
||||
|
||||
def _set_capabilities(self):
|
||||
"""Sets capabilities of the pool."""
|
||||
capabilities = []
|
||||
for t0, t1 in itertools.permutations(self.tokens, 2):
|
||||
capabilities.append(
|
||||
self._adapter_contract.get_capabilities(cast(HexStr, self.id_), t0, t1)
|
||||
)
|
||||
max_capabilities = max(map(len, capabilities))
|
||||
self.capabilities = functools.reduce(set.intersection, capabilities)
|
||||
if len(self.capabilities) < max_capabilities:
|
||||
log.warning(
|
||||
f"Pool {self.id_} hash different capabilities depending on the token pair!"
|
||||
)
|
||||
|
||||
def get_amount_out(
|
||||
self: TPoolState,
|
||||
sell_token: EthereumToken,
|
||||
sell_amount: Decimal,
|
||||
buy_token: EthereumToken,
|
||||
slippage: Decimal = Decimal(0),
|
||||
create_new_pool: bool = True,
|
||||
) -> tuple[Decimal, int, TPoolState]:
|
||||
# if the pool has a hard limit and the sell amount exceeds that, simulate and
|
||||
# raise a partial trade
|
||||
if self.hard_sell_limit:
|
||||
sell_limit = self.get_sell_amount_limit(sell_token, buy_token)
|
||||
if sell_amount > sell_limit:
|
||||
partial_trade = self._get_amount_out(sell_token, sell_limit, buy_token)
|
||||
raise RecoverableSimulationException(
|
||||
"Sell amount exceeds sell limit",
|
||||
repr(self),
|
||||
partial_trade + (sell_limit,),
|
||||
)
|
||||
|
||||
return self._get_amount_out(sell_token, sell_amount, buy_token)
|
||||
|
||||
def _get_amount_out(
|
||||
self: TPoolState,
|
||||
sell_token: EthereumToken,
|
||||
sell_amount: Decimal,
|
||||
buy_token: EthereumToken,
|
||||
) -> tuple[Decimal, int, TPoolState]:
|
||||
trade, state_changes = self._adapter_contract.swap(
|
||||
cast(HexStr, self.id_),
|
||||
sell_token,
|
||||
buy_token,
|
||||
False,
|
||||
sell_token.to_onchain_amount(sell_amount),
|
||||
block=self.block,
|
||||
overwrites=self._get_overwrites(sell_token, buy_token),
|
||||
)
|
||||
new_state = self._duplicate()
|
||||
for address, state_update in state_changes.items():
|
||||
for slot, value in state_update.storage.items():
|
||||
new_state.block_lasting_overwrites[address][slot] = value
|
||||
|
||||
new_price = frac_to_decimal(trade.price)
|
||||
if new_price != Decimal(0):
|
||||
new_state.spot_prices = {
|
||||
(sell_token, buy_token): new_price,
|
||||
(buy_token, sell_token): Decimal(1) / new_price,
|
||||
}
|
||||
|
||||
buy_amount = buy_token.from_onchain_amount(trade.received_amount)
|
||||
|
||||
return buy_amount, trade.gas_used, new_state
|
||||
|
||||
def _get_overwrites(
|
||||
self, sell_token: EthereumToken, buy_token: EthereumToken, **kwargs
|
||||
) -> dict[Address, dict[int, int]]:
|
||||
"""Get an overwrites dictionary to use in a simulation.
|
||||
|
||||
The returned overwrites include block-lasting overwrites set on the instance
|
||||
level, and token-specific overwrites that depend on passed tokens.
|
||||
"""
|
||||
token_overwrites = self._get_token_overwrites(sell_token, buy_token, **kwargs)
|
||||
return _merge(self.block_lasting_overwrites, token_overwrites)
|
||||
|
||||
def _get_token_overwrites(
|
||||
self, sell_token: EthereumToken, buy_token: EthereumToken, max_amount=None
|
||||
) -> dict[Address, dict[int, int]]:
|
||||
"""Creates overwrites for a token.
|
||||
|
||||
Funds external account with enough tokens to execute swaps. Also creates a
|
||||
corresponding approval to the adapter contract.
|
||||
|
||||
If the protocol reads its own token balances, the balances for the underlying
|
||||
pool contract will also be overwritten.
|
||||
"""
|
||||
res = []
|
||||
if Capability.TokenBalanceIndependent not in self.capabilities:
|
||||
res = [self._get_balance_overwrites()]
|
||||
|
||||
# avoids recursion if using this method with get_sell_amount_limit
|
||||
if max_amount is None:
|
||||
max_amount = sell_token.to_onchain_amount(
|
||||
self.get_sell_amount_limit(sell_token, buy_token)
|
||||
)
|
||||
overwrites = ERC20OverwriteFactory(sell_token)
|
||||
overwrites.set_balance(max_amount, EXTERNAL_ACCOUNT)
|
||||
overwrites.set_allowance(
|
||||
allowance=max_amount, owner=EXTERNAL_ACCOUNT, spender=ADAPTER_ADDRESS
|
||||
)
|
||||
res.append(overwrites.get_protosim_overwrites())
|
||||
|
||||
# we need to merge the dictionaries because balance overwrites may target
|
||||
# the same token address.
|
||||
res = functools.reduce(_merge, res)
|
||||
return res
|
||||
|
||||
def _get_balance_overwrites(self) -> dict[Address, dict[int, int]]:
|
||||
balance_overwrites = {}
|
||||
address = self.balance_owner or self.id_
|
||||
for t in self.tokens:
|
||||
overwrites = ERC20OverwriteFactory(t)
|
||||
overwrites.set_balance(
|
||||
t.to_onchain_amount(self.balances[t.address]), address
|
||||
)
|
||||
balance_overwrites.update(overwrites.get_protosim_overwrites())
|
||||
return balance_overwrites
|
||||
|
||||
def _duplicate(self: type["ThirdPartyPool"]) -> "ThirdPartyPool":
|
||||
"""Make a new instance identical to self that shares the same simulation engine.
|
||||
|
||||
Note that the new and current state become coupled in a way that they must
|
||||
simulate the same block. This is fine, see
|
||||
https://datarevenue.atlassian.net/browse/ROC-1301
|
||||
|
||||
Not naming this method _copy to not confuse with pydantic's .copy method.
|
||||
"""
|
||||
return type(self)(
|
||||
exchange=self.exchange,
|
||||
adapter_contract_name=self.adapter_contract_name,
|
||||
block=self.block,
|
||||
id_=self.id_,
|
||||
tokens=self.tokens,
|
||||
spot_prices=self.spot_prices.copy(),
|
||||
trading_fee=self.trading_fee,
|
||||
block_lasting_overwrites=deepcopy(self.block_lasting_overwrites),
|
||||
engine=self._engine,
|
||||
balances=self.balances,
|
||||
minimum_gas=self.minimum_gas,
|
||||
hard_sell_limit=self.hard_sell_limit,
|
||||
balance_owner=self.balance_owner,
|
||||
stateless_contracts=self.stateless_contracts,
|
||||
)
|
||||
|
||||
def get_sell_amount_limit(
|
||||
self, sell_token: EthereumToken, buy_token: EthereumToken
|
||||
) -> Decimal:
|
||||
"""
|
||||
Retrieves the sell amount of the given token.
|
||||
|
||||
For pools with more than 2 tokens, the sell limit is obtain for all possible buy token
|
||||
combinations and the minimum is returned.
|
||||
"""
|
||||
limit = self._adapter_contract.get_limits(
|
||||
cast(HexStr, self.id_),
|
||||
sell_token,
|
||||
buy_token,
|
||||
block=self.block,
|
||||
overwrites=self._get_overwrites(
|
||||
sell_token, buy_token, max_amount=MAX_BALANCE // 100
|
||||
),
|
||||
)[0]
|
||||
return sell_token.from_onchain_amount(limit)
|
||||
|
||||
def simulate_transition(
|
||||
self: TPoolState,
|
||||
sell_token: EthereumToken,
|
||||
buy_token: EthereumToken,
|
||||
target_price: Decimal,
|
||||
max_sell_amount: Decimal,
|
||||
) -> tuple[Decimal, Decimal, int, TPoolState]:
|
||||
pass
|
||||
|
||||
def update_inplace(self, new: "ThirdPartyPool"):
|
||||
"""
|
||||
Updates the current `ThirdPartyPool` in-place.
|
||||
|
||||
If the block attribute of the `new` state differs from the block attribute of
|
||||
the current state, the temporary storage of the simulation engine is cleared.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
new
|
||||
The new state object to update from.
|
||||
|
||||
"""
|
||||
old_block = self.block
|
||||
super(ThirdPartyPool, self).update_inplace(new)
|
||||
self.block_lasting_overwrites = new.block_lasting_overwrites.copy()
|
||||
if new.block != old_block:
|
||||
self.clear_all_cache()
|
||||
|
||||
def clear_all_cache(self):
|
||||
self._engine.clear_temp_storage()
|
||||
self.block_lasting_overwrites = defaultdict(dict)
|
||||
self._set_spot_prices()
|
||||
|
||||
# def transition(self, event: ProtocolEvent) -> "ThirdPartyPool":
|
||||
# """Make a new pool state so that everything's initialised from scratch.
|
||||
#
|
||||
# Instead of interpreting the event and applying the changes in signals, we just
|
||||
# create a fresh instance of the pool. This way all dynamic parameters will be
|
||||
# set again using the on-chain data.
|
||||
# """
|
||||
# new = type(self)(
|
||||
# block=self.block,
|
||||
# id_=self.id_,
|
||||
# tokens=self.tokens,
|
||||
# balances=self.balances.copy(),
|
||||
# )
|
||||
# if isinstance(event, ERC20Transfer):
|
||||
# transition_balances_inplace(self.id_, event, new.balances)
|
||||
# return new
|
||||
#
|
||||
|
||||
|
||||
def _merge(a: dict, b: dict, path=None):
|
||||
"""
|
||||
Merges two dictionaries (a and b) deeply. This means it will traverse and combine
|
||||
their nested dictionaries too if present.
|
||||
|
||||
Parameters:
|
||||
a (dict): The first dictionary to merge.
|
||||
b (dict): The second dictionary to merge into the first one.
|
||||
path (list, optional): An internal parameter used during recursion
|
||||
to keep track of the ancestry of nested dictionaries.
|
||||
|
||||
Returns:
|
||||
a (dict): The merged dictionary which includes all key-value pairs from b
|
||||
added into a. If they have nested dictionaries with same keys, those are also merged.
|
||||
On key conflicts, preference is given to values from b.
|
||||
"""
|
||||
if path is None:
|
||||
path = []
|
||||
for key in b:
|
||||
if key in a:
|
||||
if isinstance(a[key], dict) and isinstance(b[key], dict):
|
||||
_merge(a[key], b[key], path + [str(key)])
|
||||
else:
|
||||
a[key] = b[key]
|
||||
return a
|
||||
@@ -28,7 +28,9 @@ from tycho.tycho.constants import (
|
||||
)
|
||||
from tycho.tycho.decoders import ThirdPartyPoolTychoDecoder
|
||||
from tycho.tycho.exceptions import APIRequestError
|
||||
from tycho.tycho.models import Blockchain, EVMBlock, ThirdPartyPool, EthereumToken
|
||||
from tycho.tycho.models import Blockchain, EVMBlock, EthereumToken
|
||||
from tycho.tycho.pool_state import ThirdPartyPool
|
||||
from tycho.tycho.utils import create_engine, TychoDBSingleton
|
||||
|
||||
log = getLogger(__name__)
|
||||
|
||||
@@ -124,14 +126,9 @@ class TychoPoolStateStreamAdapter:
|
||||
self._decoder = decoder
|
||||
|
||||
# Create engine
|
||||
self._db = TychoDB(tycho_http_url=self.tycho_url)
|
||||
self._engine = SimulationEngine.new_with_tycho_db(db=self._db, trace=True)
|
||||
self._engine.init_account(
|
||||
address=EXTERNAL_ACCOUNT,
|
||||
account=AccountInfo(balance=MAX_BALANCE, nonce=0, code=None),
|
||||
mocked=False,
|
||||
permanent_storage=None,
|
||||
)
|
||||
# TODO: This should be initialized outside the adapter?
|
||||
TychoDBSingleton.initialize(tycho_http_url=self.tycho_url)
|
||||
self._engine = create_engine([], state_block=None, trace=True)
|
||||
|
||||
# Loads tokens from Tycho
|
||||
self._tokens: dict[str, EthereumToken] = TokenLoader(
|
||||
|
||||
0
tycho/tycho/tycho_db.py
Normal file
0
tycho/tycho/tycho_db.py
Normal file
@@ -1,3 +1,343 @@
|
||||
import json
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from fractions import Fraction
|
||||
from functools import lru_cache
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Final, Any
|
||||
|
||||
import eth_abi
|
||||
from eth_typing import HexStr
|
||||
from hexbytes import HexBytes
|
||||
from protosim_py import SimulationEngine, TychoDB, AccountInfo
|
||||
from web3 import Web3
|
||||
|
||||
from tycho.tycho.constants import EXTERNAL_ACCOUNT, MAX_BALANCE
|
||||
from tycho.tycho.exceptions import OutOfGas
|
||||
from tycho.tycho.models import Address, EthereumToken
|
||||
|
||||
log = getLogger(__name__)
|
||||
|
||||
|
||||
def decode_tycho_exchange(exchange: str) -> (str, bool):
|
||||
# removes vm prefix if present, returns True if vm prefix was present (vm protocol) or False if native protocol
|
||||
return (exchange.split(":")[1], False) if "vm:" in exchange else (exchange, True)
|
||||
|
||||
|
||||
class TychoDBSingleton:
|
||||
"""
|
||||
A singleton wrapper around the TychoDB class.
|
||||
|
||||
This class ensures that there is only one instance of TychoDB throughout the lifetime of the program,
|
||||
avoiding the overhead of creating multiple instances.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, tycho_http_url: str):
|
||||
"""
|
||||
Initialize the TychoDB instance with the given URLs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tycho_http_url : str
|
||||
The URL of the Tycho HTTP server.
|
||||
|
||||
"""
|
||||
cls._instance = TychoDB(tycho_http_url=tycho_http_url)
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> TychoDB:
|
||||
"""
|
||||
Retrieve the singleton instance of TychoDB.
|
||||
|
||||
If the TychoDB instance does not exist, it creates a new one.
|
||||
If it already exists, it returns the existing instance.
|
||||
|
||||
Returns
|
||||
-------
|
||||
TychoDB
|
||||
The singleton instance of TychoDB.
|
||||
"""
|
||||
if cls._instance is None:
|
||||
raise ValueError(
|
||||
"TychoDB instance not initialized. Call initialize() first."
|
||||
)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def clear_instance(cls):
|
||||
cls._instance = None
|
||||
|
||||
|
||||
def create_engine(
|
||||
mocked_tokens: list[Address], trace: bool = False
|
||||
) -> SimulationEngine:
|
||||
"""Create a simulation engine with a mocked ERC20 contract at given addresses.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mocked_tokens
|
||||
A list of addresses at which a mocked ERC20 contract should be inserted.
|
||||
|
||||
trace
|
||||
Whether to trace calls, only meant for debugging purposes, might print a lot of
|
||||
data to stdout.
|
||||
"""
|
||||
|
||||
db = TychoDBSingleton.get_instance()
|
||||
engine = SimulationEngine.new_with_tycho_db(db=db, trace=trace)
|
||||
|
||||
for t in mocked_tokens:
|
||||
info = AccountInfo(balance=0, nonce=0, code=get_contract_bytecode("ERC20.bin"))
|
||||
engine.init_account(
|
||||
address=t, account=info, mocked=True, permanent_storage=None
|
||||
)
|
||||
engine.init_account(
|
||||
address=EXTERNAL_ACCOUNT,
|
||||
account=AccountInfo(balance=MAX_BALANCE, nonce=0, code=None),
|
||||
mocked=False,
|
||||
permanent_storage=None,
|
||||
)
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
class ERC20OverwriteFactory:
|
||||
def __init__(self, token: EthereumToken):
|
||||
"""
|
||||
Initialize the ERC20OverwriteFactory.
|
||||
|
||||
Parameters:
|
||||
token: The token object.
|
||||
"""
|
||||
self._token = token
|
||||
self._overwrites = dict()
|
||||
self._balance_slot: Final[int] = 0
|
||||
self._allowance_slot: Final[int] = 1
|
||||
|
||||
def set_balance(self, balance: int, owner: Address):
|
||||
"""
|
||||
Set the balance for a given owner.
|
||||
|
||||
Parameters:
|
||||
balance: The balance value.
|
||||
owner: The owner's address.
|
||||
"""
|
||||
storage_index = get_storage_slot_at_key(HexStr(owner), self._balance_slot)
|
||||
self._overwrites[storage_index] = balance
|
||||
log.log(
|
||||
5,
|
||||
f"Override balance: token={self._token.address} owner={owner}"
|
||||
f"value={balance} slot={storage_index}",
|
||||
)
|
||||
|
||||
def set_allowance(self, allowance: int, spender: Address, owner: Address):
|
||||
"""
|
||||
Set the allowance for a given spender and owner.
|
||||
|
||||
Parameters:
|
||||
allowance: The allowance value.
|
||||
spender: The spender's address.
|
||||
owner: The owner's address.
|
||||
"""
|
||||
storage_index = get_storage_slot_at_key(
|
||||
HexStr(spender),
|
||||
get_storage_slot_at_key(HexStr(owner), self._allowance_slot),
|
||||
)
|
||||
self._overwrites[storage_index] = allowance
|
||||
log.log(
|
||||
5,
|
||||
f"Override allowance: token={self._token.address} owner={owner}"
|
||||
f"spender={spender} value={allowance} slot={storage_index}",
|
||||
)
|
||||
|
||||
def get_protosim_overwrites(self) -> dict[Address, dict[int, int]]:
|
||||
"""
|
||||
Get the overwrites dictionary of previously collected values.
|
||||
|
||||
Returns:
|
||||
dict[Address, dict]: A dictionary containing the token's address
|
||||
and the overwrites.
|
||||
"""
|
||||
# Protosim returns lowercase addresses in state updates returned from simulation
|
||||
|
||||
return {self._token.address.lower(): self._overwrites}
|
||||
|
||||
def get_geth_overwrites(self) -> dict[Address, dict[int, int]]:
|
||||
"""
|
||||
Get the overwrites dictionary of previously collected values.
|
||||
|
||||
Returns:
|
||||
dict[Address, dict]: A dictionary containing the token's address
|
||||
and the overwrites.
|
||||
"""
|
||||
formatted_overwrites = {
|
||||
HexBytes(key).hex(): "0x" + HexBytes(val).hex().lstrip("0x").zfill(64)
|
||||
for key, val in self._overwrites.items()
|
||||
}
|
||||
code = "0x" + get_contract_bytecode("ERC20.bin").hex()
|
||||
return {self._token.address: {"stateDiff": formatted_overwrites, "code": code}}
|
||||
|
||||
|
||||
def get_storage_slot_at_key(key: Address, mapping_slot: int) -> int:
|
||||
"""Get storage slot index of a value stored at a certain key in a mapping
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key
|
||||
Key in a mapping. This function is meant to work with ethereum addresses
|
||||
and accepts only strings.
|
||||
mapping_slot
|
||||
Storage slot at which the mapping itself is stored. See the examples for more
|
||||
explanation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
slot
|
||||
An index of a storage slot where the value at the given key is stored.
|
||||
|
||||
Examples
|
||||
--------
|
||||
If a mapping is declared as a first variable in solidity code, its storage slot
|
||||
is 0 (e.g. ``balances`` in our mocked ERC20 contract). Here's how to compute
|
||||
a storage slot where balance of a given account is stored::
|
||||
|
||||
get_storage_slot_at_key("0xC63135E4bF73F637AF616DFd64cf701866BB2628", 0)
|
||||
|
||||
For nested mappings, we need to apply the function twice. An example of this is
|
||||
``allowances`` in ERC20. It is a mapping of form:
|
||||
``dict[owner, dict[spender, value]]``. In our mocked ERC20 contract, ``allowances``
|
||||
is a second variable, so it is stored at slot 1. Here's how to get a storage slot
|
||||
where an allowance of ``0xspender`` to spend ``0xowner``'s money is stored::
|
||||
|
||||
get_storage_slot_at_key("0xspender", get_storage_slot_at_key("0xowner", 1)))
|
||||
|
||||
See Also
|
||||
--------
|
||||
`Solidity Storage Layout documentation
|
||||
<https://docs.soliditylang.org/en/v0.8.13/internals/layout_in_storage.html#mappings-and-dynamic-arrays>`_
|
||||
"""
|
||||
key_bytes = bytes.fromhex(key[2:]).rjust(32, b"\0")
|
||||
mapping_slot_bytes = int.to_bytes(mapping_slot, 32, "big")
|
||||
slot_bytes = Web3.keccak(key_bytes + mapping_slot_bytes)
|
||||
return int.from_bytes(slot_bytes, "big")
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_contract_bytecode(name: str) -> bytes:
|
||||
with open(Path(__file__).parent / "assets" / name, "rb") as fh:
|
||||
code = fh.read()
|
||||
return code
|
||||
|
||||
|
||||
def frac_to_decimal(frac: Fraction) -> Decimal:
|
||||
return Decimal(frac.numerator) / Decimal(frac.denominator)
|
||||
|
||||
|
||||
def load_abi(name_or_path: str) -> dict:
|
||||
if os.path.exists(abspath := os.path.abspath(name_or_path)):
|
||||
path = abspath
|
||||
else:
|
||||
path = f"{os.path.dirname(os.path.abspath(__file__))}/assets/{name_or_path}.abi"
|
||||
try:
|
||||
with open(os.path.abspath(path)) as f:
|
||||
abi: dict = json.load(f)
|
||||
except FileNotFoundError:
|
||||
search_dir = f"{os.path.dirname(os.path.abspath(__file__))}/assets/"
|
||||
|
||||
# List all files in search dir and subdirs suggest them to the user in an error message
|
||||
available_files = []
|
||||
for dirpath, dirnames, filenames in os.walk(search_dir):
|
||||
for filename in filenames:
|
||||
# Make paths relative to search_dir
|
||||
relative_path = os.path.relpath(
|
||||
os.path.join(dirpath, filename), search_dir
|
||||
)
|
||||
available_files.append(relative_path.replace(".abi", ""))
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"File {name_or_path} not found. "
|
||||
f"Did you mean one of these? {', '.join(available_files)}"
|
||||
)
|
||||
return abi
|
||||
|
||||
|
||||
# https://docs.soliditylang.org/en/latest/control-structures.html#panic-via-assert-and-error-via-require
|
||||
solidity_panic_codes = {
|
||||
0: "GenericCompilerPanic",
|
||||
1: "AssertionError",
|
||||
17: "ArithmeticOver/Underflow",
|
||||
18: "ZeroDivisionError",
|
||||
33: "UnkownEnumMember",
|
||||
34: "BadStorageByteArrayEncoding",
|
||||
51: "EmptyArray",
|
||||
0x32: "OutOfBounds",
|
||||
0x41: "OutOfMemory",
|
||||
0x51: "BadFunctionPointer",
|
||||
}
|
||||
|
||||
|
||||
def parse_solidity_error_message(data) -> str:
|
||||
data_bytes = HexBytes(data)
|
||||
error_string = f"Failed to decode: {data}"
|
||||
# data is encoded as Error(string)
|
||||
if data_bytes[:4] == HexBytes("0x08c379a0"):
|
||||
(error_string,) = eth_abi.decode(["string"], data_bytes[4:])
|
||||
return error_string
|
||||
elif data_bytes[:4] == HexBytes("0x4e487b71"):
|
||||
(error_code,) = eth_abi.decode(["uint256"], data_bytes[4:])
|
||||
return solidity_panic_codes.get(error_code, f"Panic({error_code})")
|
||||
# old solidity: revert 'some string' case
|
||||
try:
|
||||
(error_string,) = eth_abi.decode(["string"], data_bytes)
|
||||
return error_string
|
||||
except Exception:
|
||||
pass
|
||||
# some custom error maybe it is with string?
|
||||
try:
|
||||
(error_string,) = eth_abi.decode(["string"], data_bytes[4:])
|
||||
return error_string
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
(error_string,) = eth_abi.decode(["string"], data_bytes[4:])
|
||||
return error_string
|
||||
except Exception:
|
||||
pass
|
||||
return error_string
|
||||
|
||||
|
||||
def maybe_coerce_error(
|
||||
err: RuntimeError, pool_state: Any, gas_limit: int = None
|
||||
) -> Exception:
|
||||
details = err.args[0]
|
||||
# we got bytes as data, so this was a revert
|
||||
if details.data.startswith("0x"):
|
||||
err = RuntimeError(
|
||||
f"Revert! Reason: {parse_solidity_error_message(details.data)}"
|
||||
)
|
||||
# we have gas information, check if this likely an out of gas err.
|
||||
if gas_limit is not None and details.gas_used is not None:
|
||||
# if we used up 97% or more issue a OutOfGas error.
|
||||
usage = details.gas_used / gas_limit
|
||||
if usage >= 0.97:
|
||||
return OutOfGas(
|
||||
f"SimulationError: Likely out-of-gas. "
|
||||
f"Used: {usage * 100:.2f}% of gas limit. "
|
||||
f"Original error: {err}",
|
||||
repr(pool_state),
|
||||
)
|
||||
elif "OutOfGas" in details.data:
|
||||
if gas_limit is not None:
|
||||
usage = details.gas_used / gas_limit
|
||||
usage_msg = f"Used: {usage * 100:.2f}% of gas limit. "
|
||||
else:
|
||||
usage_msg = ""
|
||||
return OutOfGas(
|
||||
f"SimulationError: out-of-gas. {usage_msg}Original error: {details.data}",
|
||||
repr(pool_state),
|
||||
)
|
||||
return err
|
||||
|
||||
Reference in New Issue
Block a user