Make tycho_client a python package, small bugfixes
This commit is contained in:
355
testing/tycho-client/tycho_client/pool_state.py
Normal file
355
testing/tycho-client/tycho_client/pool_state.py
Normal file
@@ -0,0 +1,355 @@
|
||||
import functools
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from decimal import Decimal
|
||||
from fractions import Fraction
|
||||
from logging import getLogger
|
||||
from typing import Optional, cast, TypeVar, Annotated
|
||||
|
||||
from eth_typing import HexStr
|
||||
from protosim_py import SimulationEngine, AccountInfo
|
||||
from pydantic import BaseModel, PrivateAttr, Field
|
||||
|
||||
from .adapter_contract import AdapterContract
|
||||
from .constants import MAX_BALANCE, EXTERNAL_ACCOUNT
|
||||
from .exceptions import RecoverableSimulationException
|
||||
from .models import EVMBlock, Capability, Address, EthereumToken
|
||||
from .utils import (
|
||||
create_engine,
|
||||
get_contract_bytecode,
|
||||
frac_to_decimal,
|
||||
ERC20OverwriteFactory,
|
||||
)
|
||||
|
||||
ADAPTER_ADDRESS = "0xA2C5C98A892fD6656a7F39A2f63228C0Bc846270"
|
||||
|
||||
log = getLogger(__name__)
|
||||
TPoolState = TypeVar("TPoolState", bound="ThirdPartyPool")
|
||||
|
||||
|
||||
class ThirdPartyPool(BaseModel):
|
||||
id_: str
|
||||
tokens: tuple[EthereumToken, ...]
|
||||
balances: dict[Address, Decimal]
|
||||
block: EVMBlock
|
||||
spot_prices: dict[tuple[EthereumToken, EthereumToken], Decimal]
|
||||
trading_fee: Decimal
|
||||
exchange: str
|
||||
minimum_gas: int
|
||||
|
||||
_engine: SimulationEngine = PrivateAttr(default=None)
|
||||
|
||||
adapter_contract_name: str
|
||||
"""The adapters contract name. Used to look up the byte code for the adapter."""
|
||||
_adapter_contract: AdapterContract = PrivateAttr(default=None)
|
||||
|
||||
stateless_contracts: dict[str, bytes] = {}
|
||||
"""The address to bytecode map of all stateless contracts used by the protocol for simulations."""
|
||||
|
||||
capabilities: set[Capability] = Field(default_factory=lambda: {Capability.SellSide})
|
||||
"""The supported capabilities of this pool."""
|
||||
|
||||
balance_owner: Optional[str] = None
|
||||
"""The contract address for where protocol balances are stored (i.e. a vault contract).
|
||||
If given, balances will be overwritten here instead of on the pool contract during simulations."""
|
||||
|
||||
block_lasting_overwrites: defaultdict[
|
||||
Address,
|
||||
Annotated[dict[int, int], Field(default_factory=lambda: defaultdict[dict])],
|
||||
] = Field(default_factory=lambda: defaultdict(dict))
|
||||
|
||||
"""Storage overwrites that will be applied to all simulations. They will be cleared
|
||||
when ``clear_all_cache`` is called, i.e. usually at each block. Hence the name."""
|
||||
|
||||
trace: bool = False
|
||||
|
||||
hard_sell_limit: bool = False
|
||||
"""
|
||||
Whether the pool will revert if you attempt to sell more than the limit. Defaults to
|
||||
False where it is assumed that exceeding the limit will provide a bad price but will
|
||||
still succeed.
|
||||
"""
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self._set_engine(data.get("engine", None))
|
||||
self.balance_owner = data.get("balance_owner", None)
|
||||
self._adapter_contract = AdapterContract(ADAPTER_ADDRESS, self._engine)
|
||||
self._set_capabilities()
|
||||
if len(self.spot_prices) == 0:
|
||||
self._set_spot_prices()
|
||||
|
||||
def _set_engine(self, engine: Optional[SimulationEngine]):
|
||||
"""Set instance's simulation engine. If no engine given, make a default one.
|
||||
|
||||
If engine is already set, this is a noop.
|
||||
|
||||
The engine will have the specified adapter contract mocked, as well as the
|
||||
tokens used by the pool.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
engine
|
||||
Optional simulation engine instance.
|
||||
"""
|
||||
if self._engine is not None:
|
||||
return
|
||||
else:
|
||||
engine = create_engine([t.address for t in self.tokens], trace=self.trace)
|
||||
engine.init_account(
|
||||
address="0x0000000000000000000000000000000000000000",
|
||||
account=AccountInfo(balance=0, nonce=0),
|
||||
mocked=False,
|
||||
permanent_storage=None,
|
||||
)
|
||||
engine.init_account(
|
||||
address="0x0000000000000000000000000000000000000004",
|
||||
account=AccountInfo(balance=0, nonce=0),
|
||||
mocked=False,
|
||||
permanent_storage=None,
|
||||
)
|
||||
engine.init_account(
|
||||
address=ADAPTER_ADDRESS,
|
||||
account=AccountInfo(
|
||||
balance=0,
|
||||
nonce=0,
|
||||
code=get_contract_bytecode(self.adapter_contract_name),
|
||||
),
|
||||
mocked=False,
|
||||
permanent_storage=None,
|
||||
)
|
||||
for addr, bytecode in self.stateless_contracts.items():
|
||||
engine.init_account(
|
||||
address=addr,
|
||||
account=AccountInfo(balance=0, nonce=0, code=bytecode),
|
||||
mocked=False,
|
||||
permanent_storage=None,
|
||||
)
|
||||
self._engine = engine
|
||||
|
||||
def _set_spot_prices(self):
|
||||
"""Set the spot prices for this pool.
|
||||
|
||||
We currently require the price function capability for now.
|
||||
"""
|
||||
self._ensure_capability(Capability.PriceFunction)
|
||||
for t0, t1 in itertools.permutations(self.tokens, 2):
|
||||
sell_amount = t0.to_onchain_amount(
|
||||
self.get_sell_amount_limit(t0, t1) * Decimal("0.01")
|
||||
)
|
||||
frac = self._adapter_contract.price(
|
||||
cast(HexStr, self.id_),
|
||||
t0,
|
||||
t1,
|
||||
[sell_amount],
|
||||
block=self.block,
|
||||
overwrites=self.block_lasting_overwrites,
|
||||
)[0]
|
||||
if Capability.ScaledPrice in self.capabilities:
|
||||
self.spot_prices[(t0, t1)] = frac_to_decimal(frac)
|
||||
else:
|
||||
scaled = frac * Fraction(10 ** t0.decimals, 10 ** t1.decimals)
|
||||
self.spot_prices[(t0, t1)] = frac_to_decimal(scaled)
|
||||
|
||||
def _ensure_capability(self, capability: Capability):
|
||||
"""Ensures the protocol/adapter implement a certain capability."""
|
||||
if capability not in self.capabilities:
|
||||
raise NotImplemented(f"{capability} not available!")
|
||||
|
||||
def _set_capabilities(self):
|
||||
"""Sets capabilities of the pool."""
|
||||
capabilities = []
|
||||
for t0, t1 in itertools.permutations(self.tokens, 2):
|
||||
capabilities.append(
|
||||
self._adapter_contract.get_capabilities(cast(HexStr, self.id_), t0, t1)
|
||||
)
|
||||
max_capabilities = max(map(len, capabilities))
|
||||
self.capabilities = functools.reduce(set.intersection, capabilities)
|
||||
if len(self.capabilities) < max_capabilities:
|
||||
log.warning(
|
||||
f"Pool {self.id_} hash different capabilities depending on the token pair!"
|
||||
)
|
||||
|
||||
def get_amount_out(
|
||||
self: TPoolState,
|
||||
sell_token: EthereumToken,
|
||||
sell_amount: Decimal,
|
||||
buy_token: EthereumToken,
|
||||
) -> tuple[Decimal, int, TPoolState]:
|
||||
# if the pool has a hard limit and the sell amount exceeds that, simulate and
|
||||
# raise a partial trade
|
||||
if self.hard_sell_limit:
|
||||
sell_limit = self.get_sell_amount_limit(sell_token, buy_token)
|
||||
if sell_amount > sell_limit:
|
||||
partial_trade = self._get_amount_out(sell_token, sell_limit, buy_token)
|
||||
raise RecoverableSimulationException(
|
||||
"Sell amount exceeds sell limit",
|
||||
repr(self),
|
||||
partial_trade + (sell_limit,),
|
||||
)
|
||||
|
||||
return self._get_amount_out(sell_token, sell_amount, buy_token)
|
||||
|
||||
def _get_amount_out(
|
||||
self: TPoolState,
|
||||
sell_token: EthereumToken,
|
||||
sell_amount: Decimal,
|
||||
buy_token: EthereumToken,
|
||||
) -> tuple[Decimal, int, TPoolState]:
|
||||
trade, state_changes = self._adapter_contract.swap(
|
||||
cast(HexStr, self.id_),
|
||||
sell_token,
|
||||
buy_token,
|
||||
False,
|
||||
sell_token.to_onchain_amount(sell_amount),
|
||||
block=self.block,
|
||||
overwrites=self._get_overwrites(sell_token, buy_token),
|
||||
)
|
||||
new_state = self._duplicate()
|
||||
for address, state_update in state_changes.items():
|
||||
for slot, value in state_update.storage.items():
|
||||
new_state.block_lasting_overwrites[address][slot] = value
|
||||
|
||||
new_price = frac_to_decimal(trade.price)
|
||||
if new_price != Decimal(0):
|
||||
new_state.spot_prices = {
|
||||
(sell_token, buy_token): new_price,
|
||||
(buy_token, sell_token): Decimal(1) / new_price,
|
||||
}
|
||||
|
||||
buy_amount = buy_token.from_onchain_amount(trade.received_amount)
|
||||
|
||||
return buy_amount, trade.gas_used, new_state
|
||||
|
||||
def _get_overwrites(
|
||||
self, sell_token: EthereumToken, buy_token: EthereumToken, **kwargs
|
||||
) -> dict[Address, dict[int, int]]:
|
||||
"""Get an overwrites dictionary to use in a simulation.
|
||||
|
||||
The returned overwrites include block-lasting overwrites set on the instance
|
||||
level, and token-specific overwrites that depend on passed tokens.
|
||||
"""
|
||||
token_overwrites = self._get_token_overwrites(sell_token, buy_token, **kwargs)
|
||||
return _merge(self.block_lasting_overwrites, token_overwrites)
|
||||
|
||||
def _get_token_overwrites(
|
||||
self, sell_token: EthereumToken, buy_token: EthereumToken, max_amount=None
|
||||
) -> dict[Address, dict[int, int]]:
|
||||
"""Creates overwrites for a token.
|
||||
|
||||
Funds external account with enough tokens to execute swaps. Also creates a
|
||||
corresponding approval to the adapter contract.
|
||||
|
||||
If the protocol reads its own token balances, the balances for the underlying
|
||||
pool contract will also be overwritten.
|
||||
"""
|
||||
res = []
|
||||
if Capability.TokenBalanceIndependent not in self.capabilities:
|
||||
res = [self._get_balance_overwrites()]
|
||||
|
||||
# avoids recursion if using this method with get_sell_amount_limit
|
||||
if max_amount is None:
|
||||
max_amount = sell_token.to_onchain_amount(
|
||||
self.get_sell_amount_limit(sell_token, buy_token)
|
||||
)
|
||||
overwrites = ERC20OverwriteFactory(sell_token)
|
||||
overwrites.set_balance(max_amount, EXTERNAL_ACCOUNT)
|
||||
overwrites.set_allowance(
|
||||
allowance=max_amount, owner=EXTERNAL_ACCOUNT, spender=ADAPTER_ADDRESS
|
||||
)
|
||||
res.append(overwrites.get_protosim_overwrites())
|
||||
|
||||
# we need to merge the dictionaries because balance overwrites may target
|
||||
# the same token address.
|
||||
res = functools.reduce(_merge, res)
|
||||
return res
|
||||
|
||||
def _get_balance_overwrites(self) -> dict[Address, dict[int, int]]:
|
||||
balance_overwrites = {}
|
||||
address = self.balance_owner or self.id_
|
||||
for t in self.tokens:
|
||||
overwrites = ERC20OverwriteFactory(t)
|
||||
overwrites.set_balance(
|
||||
t.to_onchain_amount(self.balances[t.address]), address
|
||||
)
|
||||
balance_overwrites.update(overwrites.get_protosim_overwrites())
|
||||
return balance_overwrites
|
||||
|
||||
def _duplicate(self: type["ThirdPartyPool"]) -> "ThirdPartyPool":
|
||||
"""Make a new instance identical to self that shares the same simulation engine.
|
||||
|
||||
Note that the new and current state become coupled in a way that they must
|
||||
simulate the same block. This is fine, see
|
||||
https://datarevenue.atlassian.net/browse/ROC-1301
|
||||
|
||||
Not naming this method _copy to not confuse with Pydantic's .copy method.
|
||||
"""
|
||||
return type(self)(
|
||||
exchange=self.exchange,
|
||||
adapter_contract_name=self.adapter_contract_name,
|
||||
block=self.block,
|
||||
id_=self.id_,
|
||||
tokens=self.tokens,
|
||||
spot_prices=self.spot_prices.copy(),
|
||||
trading_fee=self.trading_fee,
|
||||
block_lasting_overwrites=deepcopy(self.block_lasting_overwrites),
|
||||
engine=self._engine,
|
||||
balances=self.balances,
|
||||
minimum_gas=self.minimum_gas,
|
||||
hard_sell_limit=self.hard_sell_limit,
|
||||
balance_owner=self.balance_owner,
|
||||
stateless_contracts=self.stateless_contracts,
|
||||
)
|
||||
|
||||
def get_sell_amount_limit(
|
||||
self, sell_token: EthereumToken, buy_token: EthereumToken
|
||||
) -> Decimal:
|
||||
"""
|
||||
Retrieves the sell amount of the given token.
|
||||
|
||||
For pools with more than 2 tokens, the sell limit is obtain for all possible buy token
|
||||
combinations and the minimum is returned.
|
||||
"""
|
||||
limit = self._adapter_contract.get_limits(
|
||||
cast(HexStr, self.id_),
|
||||
sell_token,
|
||||
buy_token,
|
||||
block=self.block,
|
||||
overwrites=self._get_overwrites(
|
||||
sell_token, buy_token, max_amount=MAX_BALANCE // 100
|
||||
),
|
||||
)[0]
|
||||
return sell_token.from_onchain_amount(limit)
|
||||
|
||||
def clear_all_cache(self):
|
||||
self._engine.clear_temp_storage()
|
||||
self.block_lasting_overwrites = defaultdict(dict)
|
||||
self._set_spot_prices()
|
||||
|
||||
|
||||
def _merge(a: dict, b: dict, path=None):
|
||||
"""
|
||||
Merges two dictionaries (a and b) deeply. This means it will traverse and combine
|
||||
their nested dictionaries too if present.
|
||||
|
||||
Parameters:
|
||||
a (dict): The first dictionary to merge.
|
||||
b (dict): The second dictionary to merge into the first one.
|
||||
path (list, optional): An internal parameter used during recursion
|
||||
to keep track of the ancestry of nested dictionaries.
|
||||
|
||||
Returns:
|
||||
a (dict): The merged dictionary which includes all key-value pairs from `b`
|
||||
added into `a`. If they have nested dictionaries with same keys, those are also merged.
|
||||
On key conflicts, preference is given to values from b.
|
||||
"""
|
||||
if path is None:
|
||||
path = []
|
||||
for key in b:
|
||||
if key in a:
|
||||
if isinstance(a[key], dict) and isinstance(b[key], dict):
|
||||
_merge(a[key], b[key], path + [str(key)])
|
||||
else:
|
||||
a[key] = b[key]
|
||||
return a
|
||||
Reference in New Issue
Block a user