402 lines
15 KiB
Python
402 lines
15 KiB
Python
import functools
|
|
import itertools
|
|
from collections import defaultdict
|
|
from copy import deepcopy
|
|
from decimal import Decimal
|
|
from fractions import Fraction
|
|
from logging import getLogger
|
|
from typing import Optional, cast, TypeVar, Annotated
|
|
|
|
from eth_typing import HexStr
|
|
from protosim_py import SimulationEngine, AccountInfo
|
|
from pydantic import BaseModel, PrivateAttr, Field
|
|
|
|
from .adapter_contract import AdapterContract
|
|
from .constants import MAX_BALANCE, EXTERNAL_ACCOUNT
|
|
from .exceptions import RecoverableSimulationException
|
|
from .models import EVMBlock, Capability, Address, EthereumToken
|
|
from .utils import (
|
|
create_engine,
|
|
get_contract_bytecode,
|
|
frac_to_decimal,
|
|
ERC20OverwriteFactory,
|
|
)
|
|
|
|
# Address at which the adapter contract's bytecode is loaded in simulation engines.
ADAPTER_ADDRESS = "0xA2C5C98A892fD6656a7F39A2f63228C0Bc846270"

# Module-level logger.
log = getLogger(name=__name__)

# Type variable bound to ThirdPartyPool so methods returning a new pool state keep
# the concrete subclass type for type checkers.
TPoolState = TypeVar("TPoolState", bound="ThirdPartyPool")
|
|
|
|
|
|
class ThirdPartyPool(BaseModel):
    """Pool state whose prices, limits and swaps are simulated by running a
    protocol adapter contract inside a ``SimulationEngine``.
    """

    id_: str
    tokens: tuple[EthereumToken, ...]
    balances: dict[Address, Decimal]
    block: EVMBlock
    spot_prices: dict[tuple[EthereumToken, EthereumToken], Decimal]
    trading_fee: Decimal
    exchange: str
    minimum_gas: int

    _engine: SimulationEngine = PrivateAttr(default=None)

    adapter_contract_name: str
    """The adapter contract's name. Used to look up the bytecode for the adapter."""

    _adapter_contract: AdapterContract = PrivateAttr(default=None)

    stateless_contracts: dict[str, bytes] = {}
    """The address to bytecode map of all stateless contracts used by the protocol for simulations."""

    capabilities: set[Capability] = Field(default_factory=lambda: {Capability.SellSide})
    """The supported capabilities of this pool."""

    balance_owner: Optional[str] = None
    """The contract address where protocol balances are stored (e.g. a vault contract).

    If given, balances will be overwritten here instead of on the pool contract during simulations."""

    # NOTE(review): ``defaultdict[dict]`` in the inner default factory evaluates to a
    # generic alias, not a dict instance. If pydantic ever calls that factory it will
    # not yield a dict -- confirm before relying on it.
    block_lasting_overwrites: defaultdict[
        Address,
        Annotated[dict[int, int], Field(default_factory=lambda: defaultdict[dict])],
    ] = Field(default_factory=lambda: defaultdict(dict))
    """Storage overwrites that will be applied to all simulations. They will be cleared
    when ``clear_all_cache`` is called, i.e. usually at each block. Hence the name."""

    # Forwarded to ``create_engine`` when a default engine is built.
    trace: bool = False

    hard_sell_limit: bool = False
    """
    Whether the pool will revert if you attempt to sell more than the limit. Defaults to
    False where it is assumed that exceeding the limit will provide a bad price but will
    still succeed.
    """

    def __init__(self, **data):
        """Validate fields, attach a simulation engine and the adapter contract,
        query capabilities and, if none were given, compute spot prices.

        An optional ``engine`` keyword (not a declared field) lets callers share an
        existing ``SimulationEngine``; otherwise a default one is created.
        """
        super().__init__(**data)
        # ``engine`` is read from the raw kwargs rather than from a validated field.
        self._set_engine(data.get("engine"))
        self._adapter_contract = AdapterContract(ADAPTER_ADDRESS, self._engine)
        self._set_capabilities()
        if not self.spot_prices:
            self._set_spot_prices()

    def _set_engine(self, engine: Optional[SimulationEngine]):
        """Set instance's simulation engine. If no engine given, make a default one.

        If engine is already set, this is a noop. A given engine is used as-is, so
        several states (see ``_duplicate``) can share one engine.

        The default engine will have the specified adapter contract mocked, as well
        as the tokens used by the pool.

        Parameters
        ----------
        engine
            Optional simulation engine instance.
        """
        if self._engine is not None:
            return
        if engine is None:
            engine = self._create_default_engine()
        self._engine = engine

    def _create_default_engine(self) -> SimulationEngine:
        """Build a fresh engine preloaded with the pool's tokens, the adapter
        contract and all stateless contracts."""
        engine = create_engine([t.address for t in self.tokens], trace=self.trace)
        # Plain empty accounts at the zero address and 0x...04 (presumably the
        # identity precompile -- confirm why it needs explicit initialisation).
        for empty_address in (
            "0x0000000000000000000000000000000000000000",
            "0x0000000000000000000000000000000000000004",
        ):
            engine.init_account(
                address=empty_address,
                account=AccountInfo(balance=0, nonce=0),
                mocked=False,
                permanent_storage=None,
            )
        engine.init_account(
            address=ADAPTER_ADDRESS,
            account=AccountInfo(
                balance=0,
                nonce=0,
                code=get_contract_bytecode(self.adapter_contract_name),
            ),
            mocked=False,
            permanent_storage=None,
        )
        for addr, bytecode in self.stateless_contracts.items():
            engine.init_account(
                address=addr,
                account=AccountInfo(balance=0, nonce=0, code=bytecode),
                mocked=False,
                permanent_storage=None,
            )
        return engine

    def _scale_price(
        self, price: Fraction, sell_token: EthereumToken, buy_token: EthereumToken
    ) -> Decimal:
        """Convert a price Fraction returned by the adapter into a Decimal.

        Unless the pool has ``Capability.ScaledPrice``, the price is multiplied by
        ``10**sell_token.decimals / 10**buy_token.decimals`` (presumably turning a
        ratio of on-chain amounts into a ratio of token amounts).
        """
        if Capability.ScaledPrice not in self.capabilities:
            price = price * Fraction(10**sell_token.decimals, 10**buy_token.decimals)
        return frac_to_decimal(price)

    def _set_spot_prices(self):
        """Set the spot prices for this pool.

        We currently require the price function capability for now. For each
        ordered token pair, the price is queried at 1% of the pair's sell limit.

        Raises
        ------
        NotImplementedError
            If the pool lacks ``Capability.PriceFunction``.
        """
        self._ensure_capability(Capability.PriceFunction)
        for t0, t1 in itertools.permutations(self.tokens, 2):
            sell_amount = t0.to_onchain_amount(
                self.get_sell_amount_limit(t0, t1) * Decimal("0.01")
            )
            # Use the full simulation overwrites (including pool balance overwrites
            # for pools that are not TokenBalanceIndependent), with the same generous
            # funding that get_sell_amount_limit uses.
            frac = self._adapter_contract.price(
                cast(HexStr, self.id_),
                t0,
                t1,
                [sell_amount],
                block=self.block,
                overwrites=self._get_overwrites(t0, t1, max_amount=MAX_BALANCE // 100),
            )[0]
            self.spot_prices[(t0, t1)] = self._scale_price(frac, t0, t1)

    def _ensure_capability(self, capability: Capability):
        """Ensures the protocol/adapter implement a certain capability.

        Raises
        ------
        NotImplementedError
            If ``capability`` is not in ``self.capabilities``.
        """
        if capability not in self.capabilities:
            raise NotImplementedError(f"{capability} not available!")

    def _set_capabilities(self):
        """Sets capabilities of the pool to those shared by every ordered token pair.

        Logs a warning when some pair reports more capabilities than the common set.

        Raises
        ------
        ValueError
            If the pool has fewer than two tokens (no pairs to query).
        """
        pair_capabilities = [
            self._adapter_contract.get_capabilities(cast(HexStr, self.id_), t0, t1)
            for t0, t1 in itertools.permutations(self.tokens, 2)
        ]
        if not pair_capabilities:
            raise ValueError(
                f"Pool {self.id_} needs at least two tokens to query capabilities"
            )
        max_capabilities = max(map(len, pair_capabilities))
        self.capabilities = functools.reduce(set.intersection, pair_capabilities)
        if len(self.capabilities) < max_capabilities:
            log.warning(
                f"Pool {self.id_} has different capabilities depending on the token pair!"
            )

    def get_amount_out(
        self: TPoolState,
        sell_token: EthereumToken,
        sell_amount: Decimal,
        buy_token: EthereumToken,
    ) -> tuple[Decimal, int, TPoolState]:
        """Simulate selling ``sell_amount`` of ``sell_token`` for ``buy_token``.

        Returns
        -------
        tuple
            ``(buy_amount, gas_used, new_state)``.

        Raises
        ------
        RecoverableSimulationException
            If ``hard_sell_limit`` is set and ``sell_amount`` exceeds the sell limit.
            The exception carries the trade simulated at the limit, with the limit
            appended.
        """
        if self.hard_sell_limit:
            sell_limit = self.get_sell_amount_limit(sell_token, buy_token)
            if sell_amount > sell_limit:
                partial_trade = self._get_amount_out(sell_token, sell_limit, buy_token)
                raise RecoverableSimulationException(
                    "Sell amount exceeds sell limit",
                    repr(self),
                    partial_trade + (sell_limit,),
                )

        return self._get_amount_out(sell_token, sell_amount, buy_token)

    def _get_amount_out(
        self: TPoolState,
        sell_token: EthereumToken,
        sell_amount: Decimal,
        buy_token: EthereumToken,
    ) -> tuple[Decimal, int, TPoolState]:
        """Run the adapter's swap and return ``(buy_amount, gas_used, new_state)``.

        The new state shares this state's engine and carries the swap's storage
        changes as block-lasting overwrites. If the adapter reports a non-zero
        post-trade price, the new state's spot prices are replaced by that price
        for the traded pair (both directions).
        """
        trade, state_changes = self._adapter_contract.swap(
            cast(HexStr, self.id_),
            sell_token,
            buy_token,
            False,
            sell_token.to_onchain_amount(sell_amount),
            block=self.block,
            overwrites=self._get_overwrites(sell_token, buy_token),
        )
        new_state = self._duplicate()
        for address, state_update in state_changes.items():
            for slot, value in state_update.storage.items():
                new_state.block_lasting_overwrites[address][slot] = value

        # Scale exactly like _set_spot_prices so prices stay in the same units.
        new_price = self._scale_price(trade.price, sell_token, buy_token)
        if new_price != Decimal(0):
            new_state.spot_prices = {
                (sell_token, buy_token): new_price,
                (buy_token, sell_token): Decimal(1) / new_price,
            }

        buy_amount = buy_token.from_onchain_amount(trade.received_amount)
        return buy_amount, trade.gas_used, new_state

    def _get_overwrites(
        self, sell_token: EthereumToken, buy_token: EthereumToken, **kwargs
    ) -> dict[Address, dict[int, int]]:
        """Get an overwrites dictionary to use in a simulation.

        The returned overwrites include block-lasting overwrites set on the instance
        level, and token-specific overwrites that depend on passed tokens. Token
        overwrites win on conflicting slots. ``self.block_lasting_overwrites`` is not
        modified.
        """
        token_overwrites = self._get_token_overwrites(sell_token, buy_token, **kwargs)
        # _merge mutates its first argument; merge into a copy so per-call token
        # overwrites never leak into the block-lasting state.
        return _merge(deepcopy(self.block_lasting_overwrites), token_overwrites)

    def _get_token_overwrites(
        self, sell_token: EthereumToken, buy_token: EthereumToken, max_amount=None
    ) -> dict[Address, dict[int, int]]:
        """Creates overwrites for a token.

        Funds external account with enough tokens to execute swaps. Also creates a
        corresponding approval to the adapter contract.

        If the protocol reads its own token balances, the balances for the underlying
        pool contract will also be overwritten.

        Parameters
        ----------
        max_amount
            On-chain amount to fund and approve. Defaults to the pair's sell limit;
            ``get_sell_amount_limit`` passes it explicitly to avoid recursion.
        """
        overwrite_sets = []
        if Capability.TokenBalanceIndependent not in self.capabilities:
            overwrite_sets.append(self._get_balance_overwrites())

        if max_amount is None:
            max_amount = sell_token.to_onchain_amount(
                self.get_sell_amount_limit(sell_token, buy_token)
            )
        sell_overwrites = ERC20OverwriteFactory(sell_token)
        sell_overwrites.set_balance(max_amount, EXTERNAL_ACCOUNT)
        sell_overwrites.set_allowance(
            allowance=max_amount, owner=EXTERNAL_ACCOUNT, spender=ADAPTER_ADDRESS
        )
        overwrite_sets.append(sell_overwrites.get_protosim_overwrites())

        # Merge rather than update: balance overwrites may target the same token
        # address as the sell-token overwrites.
        return functools.reduce(_merge, overwrite_sets)

    def _get_balance_overwrites(self) -> dict[Address, dict[int, int]]:
        """Overwrite each token's balance of the balance owner (or the pool itself)
        with the amount recorded in ``self.balances``."""
        balance_overwrites = {}
        address = self.balance_owner or self.id_
        for t in self.tokens:
            overwrites = ERC20OverwriteFactory(t)
            overwrites.set_balance(
                t.to_onchain_amount(self.balances[t.address]), address
            )
            balance_overwrites.update(overwrites.get_protosim_overwrites())
        return balance_overwrites

    def _duplicate(self: TPoolState) -> TPoolState:
        """Make a new instance identical to self that shares the same simulation engine.

        Note that the new and current state become coupled in a way that they must
        simulate the same block. This is fine, see
        https://datarevenue.atlassian.net/browse/ROC-1301

        Not naming this method _copy to not confuse with Pydantic's .copy method.
        """
        return type(self)(
            exchange=self.exchange,
            adapter_contract_name=self.adapter_contract_name,
            block=self.block,
            id_=self.id_,
            tokens=self.tokens,
            spot_prices=self.spot_prices.copy(),
            trading_fee=self.trading_fee,
            block_lasting_overwrites=deepcopy(self.block_lasting_overwrites),
            engine=self._engine,
            balances=self.balances,
            minimum_gas=self.minimum_gas,
            hard_sell_limit=self.hard_sell_limit,
            balance_owner=self.balance_owner,
            stateless_contracts=self.stateless_contracts,
            trace=self.trace,
        )

    def get_sell_amount_limit(
        self, sell_token: EthereumToken, buy_token: EthereumToken
    ) -> Decimal:
        """
        Retrieves the sell amount limit of ``sell_token`` for the given token pair.

        Only the single (``sell_token``, ``buy_token``) pair is queried. The external
        account is funded with ``MAX_BALANCE // 100`` for the query.
        """
        limit = self._adapter_contract.get_limits(
            cast(HexStr, self.id_),
            sell_token,
            buy_token,
            block=self.block,
            overwrites=self._get_overwrites(
                sell_token, buy_token, max_amount=MAX_BALANCE // 100
            ),
        )[0]
        return sell_token.from_onchain_amount(limit)

    def simulate_transition(
        self: TPoolState,
        sell_token: EthereumToken,
        buy_token: EthereumToken,
        target_price: Decimal,
        max_sell_amount: Decimal,
    ) -> tuple[Decimal, Decimal, int, TPoolState]:
        """Not implemented: currently a stub that returns ``None``."""
        return None

    def update_inplace(self, new: "ThirdPartyPool"):
        """
        Updates the current `ThirdPartyPool` in-place.

        If the block attribute of the `new` state differs from the block attribute of
        the current state, the temporary storage of the simulation engine is cleared.

        Parameters
        ----------
        new
            The new state object to update from.
        """
        old_block = self.block
        # NOTE(review): pydantic's BaseModel defines no ``update_inplace``. This only
        # works if a class later in the MRO (e.g. a mixin on a subclass) provides it.
        # Confirm.
        super(ThirdPartyPool, self).update_inplace(new)
        # Deep copy so the two states never share inner slot dicts.
        self.block_lasting_overwrites = deepcopy(new.block_lasting_overwrites)
        if new.block != old_block:
            self.clear_all_cache()

    def clear_all_cache(self):
        """Clear the engine's temporary storage and the block-lasting overwrites,
        then recompute spot prices."""
        self._engine.clear_temp_storage()
        self.block_lasting_overwrites = defaultdict(dict)
        self._set_spot_prices()

    # def transition(self, event: ProtocolEvent) -> "ThirdPartyPool":
    #     """Make a new pool state so that everything's initialised from scratch.
    #
    #     Instead of interpreting the event and applying the changes in signals, we just
    #     create a fresh instance of the pool. This way all dynamic parameters will be
    #     set again using the on-chain data.
    #     """
    #     new = type(self)(
    #         block=self.block,
    #         id_=self.id_,
    #         tokens=self.tokens,
    #         balances=self.balances.copy(),
    #     )
    #     if isinstance(event, ERC20Transfer):
    #         transition_balances_inplace(self.id_, event, new.balances)
    #     return new
    #
|
|
|
|
|
|
def _merge(a: dict, b: dict, path=None):
|
|
"""
|
|
Merges two dictionaries (a and b) deeply. This means it will traverse and combine
|
|
their nested dictionaries too if present.
|
|
|
|
Parameters:
|
|
a (dict): The first dictionary to merge.
|
|
b (dict): The second dictionary to merge into the first one.
|
|
path (list, optional): An internal parameter used during recursion
|
|
to keep track of the ancestry of nested dictionaries.
|
|
|
|
Returns:
|
|
a (dict): The merged dictionary which includes all key-value pairs from `b`
|
|
added into `a`. If they have nested dictionaries with same keys, those are also merged.
|
|
On key conflicts, preference is given to values from b.
|
|
"""
|
|
if path is None:
|
|
path = []
|
|
for key in b:
|
|
if key in a:
|
|
if isinstance(a[key], dict) and isinstance(b[key], dict):
|
|
_merge(a[key], b[key], path + [str(key)])
|
|
else:
|
|
a[key] = b[key]
|
|
return a
|