diff --git a/tycho/tycho/adapter_contract.py b/tycho/tycho/adapter_contract.py new file mode 100644 index 0000000..bcf2d74 --- /dev/null +++ b/tycho/tycho/adapter_contract.py @@ -0,0 +1,211 @@ +import logging +import time +from decimal import Decimal +from fractions import Fraction +from typing import Any, Union, NamedTuple + +import eth_abi +from eth_abi.exceptions import DecodingError +from eth_typing import HexStr +from eth_utils import keccak +from eth_utils.abi import collapse_if_tuple +from hexbytes import HexBytes +from protosim_py import ( + SimulationEngine, + SimulationParameters, + SimulationResult, + StateUpdate, +) + +from tycho.tycho.constants import EXTERNAL_ACCOUNT +from tycho.tycho.models import Address, EthereumToken, EVMBlock, Capability +from tycho.tycho.utils import load_abi, maybe_coerce_error + +log = logging.getLogger(__name__) + +TStateOverwrites = dict[Address, dict[int, int]] + + +class Trade(NamedTuple): + """ + Trade represents a simple trading operation with fields: + received_amount: Amount received from the trade + gas_used: Amount of gas used in the transaction + price: Price at which the trade was executed + """ + + received_amount: float + gas_used: float + price: float + + +class ProtoSimResponse: + def __init__(self, return_value: Any, simulation_result: "SimulationResult"): + self.return_value = return_value + self.simulation_result = simulation_result + + +class ProtoSimContract: + def __init__(self, address: Address, abi_name: str, engine: SimulationEngine): + self.abi = load_abi(abi_name) + self.address = address + self.engine = engine + self._default_tx_env = dict( + caller=EXTERNAL_ACCOUNT, to=self.address, value=0, overrides={} + ) + functions = [f for f in self.abi if f["type"] == "function"] + self._functions = {f["name"]: f for f in functions} + if len(self._functions) != len(functions): + raise ValueError( + f"ProtoSimContract does not support overloaded function names! " + f"Encountered while loading {abi_name}." + ) + + def _encode_input(self, fname: str, args: list) -> bytearray: + func = self._functions[fname] + types = [collapse_if_tuple(t) for t in func["inputs"]] + selector = keccak(text=f"{fname}({','.join(types)})")[:4] + return bytearray(selector + eth_abi.encode(types, args)) + + def _decode_output(self, fname: str, encoded: list[int]) -> Any: + func = self._functions[fname] + types = [collapse_if_tuple(t) for t in func["outputs"]] + return eth_abi.decode(types, bytearray(encoded)) + + def call( + self, + fname: str, + *args: list[Union[int, str, bool, bytes]], + block_number, + timestamp: int = None, + overrides: TStateOverwrites = None, + caller: Address = EXTERNAL_ACCOUNT, + value: int = 0, + ) -> ProtoSimResponse: + call_data = self._encode_input(fname, *args) + params = SimulationParameters( + data=call_data, + to=self.address, + block_number=block_number, + timestamp=timestamp or int(time.time()), + overrides=overrides or {}, + caller=caller, + value=value, + ) + sim_result = self._simulate(params) + try: + output = self._decode_output(fname, sim_result.result) + except DecodingError: + log.warning("Failed to decode output") + output = None + return ProtoSimResponse(output, sim_result) + + def _simulate(self, params: SimulationParameters) -> "SimulationResult": + """Run simulation and handle errors. + + It catches a RuntimeError: + + - if it's ``Execution reverted``, re-raises a RuntimeError + with a Tenderly link added + - if it's ``Out of gas``, re-raises a RecoverableSimulationException + - otherwise it just re-raises the original error. + """ + try: + simulation_result = self.engine.run_sim(params) + return simulation_result + except RuntimeError as err: + try: + coerced_err = maybe_coerce_error(err, self, params.gas_limit) + except Exception: + log.exception("Couldn't coerce error. Re-raising the original one.") + raise err + msg = str(coerced_err) + if "Revert!" in msg: + raise type(coerced_err)(msg, repr(self)) from err + else: + raise coerced_err + + +class AdapterContract(ProtoSimContract): + """ + The AdapterContract provides an interface to interact with the protocols implemented + by third parties using the `propeller-protocol-lib`. + """ + + def __init__(self, address: Address, engine: SimulationEngine): + super().__init__(address, "ISwapAdapter", engine) + + def price( + self, + pair_id: HexStr, + sell_token: EthereumToken, + buy_token: EthereumToken, + amounts: list[int], + block: EVMBlock, + overwrites: TStateOverwrites = None, + ) -> list[Fraction]: + args = [HexBytes(pair_id), sell_token.address, buy_token.address, amounts] + res = self.call( + "price", + args, + block_number=block.id, + timestamp=int(block.ts.timestamp()), + overrides=overwrites, + ) + return list(map(lambda x: Fraction(*x), res.return_value[0])) + + def swap( + self, + pair_id: HexStr, + sell_token: EthereumToken, + buy_token: EthereumToken, + is_buy: bool, + amount: Decimal, + block: EVMBlock, + overwrites: TStateOverwrites = None, + ) -> tuple[Trade, dict[str, StateUpdate]]: + args = [ + HexBytes(pair_id), + sell_token.address, + buy_token.address, + int(is_buy), + amount, + ] + res = self.call( + "swap", + args, + block_number=block.id, + timestamp=int(block.ts.timestamp()), + overrides=overwrites, + ) + amount, gas, price = res.return_value[0] + return Trade(amount, gas, Fraction(*price)), res.simulation_result.state_updates + + def get_limits( + self, + pair_id: HexStr, + sell_token: EthereumToken, + buy_token: EthereumToken, + block: EVMBlock, + overwrites: TStateOverwrites = None, + ) -> tuple[int, int]: + args = [HexBytes(pair_id), sell_token.address, buy_token.address] + res = self.call( + "getLimits", + args, + block_number=block.id, + timestamp=int(block.ts.timestamp()), + overrides=overwrites, + ) + return res.return_value[0] + + def get_capabilities( + self, pair_id: HexStr, sell_token: EthereumToken, buy_token: EthereumToken + ) -> set[Capability]: + args = [HexBytes(pair_id), sell_token.address, buy_token.address] + res = self.call("getCapabilities", args, block_number=1) + return set(map(Capability, res.return_value[0])) + + def min_gas_usage(self) -> int: + res = self.call("minGasUsage", [], block_number=1) + return res.return_value[0] diff --git a/tycho/tycho/decoders.py b/tycho/tycho/decoders.py index 66646f3..625ebf8 100644 --- a/tycho/tycho/decoders.py +++ b/tycho/tycho/decoders.py @@ -1,8 +1,12 @@ +from decimal import Decimal from logging import getLogger from typing import Any +from protosim_py import SimulationEngine + from tycho.tycho.exceptions import TychoDecodeError -from tycho.tycho.models import EVMBlock, ThirdPartyPool, EthereumToken +from tycho.tycho.models import EVMBlock, EthereumToken, DatabaseType +from tycho.tycho.pool_state import ThirdPartyPool from tycho.tycho.utils import decode_tycho_exchange log = getLogger(__name__) @@ -60,7 +64,6 @@ class ThirdPartyPoolTychoDecoder: adapter_contract_name=self.adapter_contract, minimum_gas=self.minimum_gas, hard_sell_limit=self.hard_limit, - db_type=DatabaseType.tycho, trace=True, **optional_attributes, ) diff --git a/tycho/tycho/exceptions.py b/tycho/tycho/exceptions.py index 8d88057..6a4b9a4 100644 --- a/tycho/tycho/exceptions.py +++ b/tycho/tycho/exceptions.py @@ -1,3 +1,6 @@ +from decimal import Decimal + + class TychoDecodeError(Exception): def __init__(self, msg: str, pool_id: str): super().__init__(msg) @@ -6,3 +9,47 @@ class TychoDecodeError(Exception): class APIRequestError(Exception): pass + + +class TradeSimulationException(Exception): + def __init__(self, message, pool_id: str): + self.pool_id = pool_id + super().__init__(message) + + +class RecoverableSimulationException(TradeSimulationException): + """Marks that the simulation could not fully fulfill the requested order. + + Provides a partial trade that is valid but does not fully fulfill the conditions + requested. + + Parameters + ---------- + message + Error message + pool_id + ID of a pool that caused the error + partial_trade + A tuple of (bought_amount, gas_used, new_pool_state, sold_amount) + """ + + def __init__( + self, + message, + pool_id: str, + partial_trade: tuple[Decimal, int, "ThirdPartyPool", Decimal] = None, + ): + super().__init__(message, pool_id) + self.partial_trade = partial_trade + + +class OutOfGas(RecoverableSimulationException): + """This exception indicates that the underlying VM **likely** ran out of gas. + + It is not easy to judge whether it was really due to out of gas, as the details + of the SC being called might be hiding this. E.g. out of gas may happen while + calling an external contract, which might show as the external call failing, although + it was due to a lack of gas. + """ + + pass diff --git a/tycho/tycho/models.py b/tycho/tycho/models.py index 4b7529d..94c74f0 100644 --- a/tycho/tycho/models.py +++ b/tycho/tycho/models.py @@ -1,9 +1,11 @@ import datetime -from enum import Enum +from enum import Enum, IntEnum, auto from typing import Union from pydantic import BaseModel, Field +Address = str + class Blockchain(Enum): ethereum = "ethereum" @@ -18,12 +20,26 @@ class EVMBlock(BaseModel): hash_: str -class ThirdPartyPool: - pass - - class EthereumToken(BaseModel): symbol: str address: str decimals: int gas: Union[int, list[int]] = 29000 + + +class DatabaseType(Enum): + # Make call to the node each time it needs a storage (unless cached from a previous call). + rpc_reader = "rpc_reader" + # Connect to Tycho and cache the whole state of a target contract, the state is continuously updated by Tycho. + # To use this we need Tycho to be configured to index the target contract state. + tycho = "tycho" + + +class Capability(IntEnum): + SellSide = auto() + BuySide = auto() + PriceFunction = auto() + FeeOnTransfer = auto() + ConstantPrice = auto() + TokenBalanceIndependent = auto() + ScaledPrice = auto() diff --git a/tycho/tycho/pool_state.py b/tycho/tycho/pool_state.py new file mode 100644 index 0000000..710ce72 --- /dev/null +++ b/tycho/tycho/pool_state.py @@ -0,0 +1,394 @@ +import functools +import itertools +from collections import defaultdict +from copy import deepcopy +from decimal import Decimal +from fractions import Fraction +from logging import getLogger +from typing import Optional, cast, TypeVar + +from protosim_py import SimulationEngine, AccountInfo +from pydantic import BaseModel, PrivateAttr, Field + +from tycho.tycho.adapter_contract import AdapterContract +from tycho.tycho.constants import MAX_BALANCE, EXTERNAL_ACCOUNT +from tycho.tycho.exceptions import RecoverableSimulationException +from tycho.tycho.models import ( + EVMBlock, + DatabaseType, + Capability, + Address, + EthereumToken, +) +from tycho.tycho.utils import ( + create_engine, + get_contract_bytecode, + frac_to_decimal, + ERC20OverwriteFactory, +) +from eth_typing import HexStr + +ADAPTER_ADDRESS = "0xA2C5C98A892fD6656a7F39A2f63228C0Bc846270" + +log = getLogger(__name__) +TPoolState = TypeVar("TPoolState", bound="ThirdPartyPool") + + +class ThirdPartyPool(BaseModel): + id_: str + tokens: tuple[EthereumToken, ...] + balances: dict[Address, Decimal] + block: EVMBlock + spot_prices: dict[tuple[EthereumToken, EthereumToken], Decimal] + trading_fee: Decimal + exchange: str + minimum_gas: int + + _engine: SimulationEngine = PrivateAttr(default=None) + + adapter_contract_name: str + """The adapters contract name. Used to look up the byte code for the adapter.""" + _adapter_contract: AdapterContract = PrivateAttr(default=None) + + stateless_contracts: dict[str, bytes] = {} + """The address to bytecode map of all stateless contracts used by the protocol for simulations.""" + + capabilities: set[Capability] = Field(default_factory=lambda: {Capability.SellSide}) + """The supported capabilities of this pool.""" + + balance_owner: Optional[str] = None + """The contract address for where protocol balances are stored (i.e. a vault contract). + If given, balances will be overwritten here instead of on the pool contract during simulations.""" + + block_lasting_overwrites: defaultdict[Address, dict[int, int]] = Field( + default_factory=lambda: defaultdict(dict) + ) + """Storage overwrites that will be applied to all simulations. They will be cleared + when ``clear_all_cache`` is called, i.e. usually at each block. Hence the name.""" + + trace: bool = False + + hard_sell_limit: bool = False + """ + Whether the pool will revert if you attempt to sell more than the limit. Defaults to + False where it is assumed that exceeding the limit will provide a bad price but will + still succeed. + """ + + def __init__(self, **data): + super().__init__(**data) + self._set_engine(data.get("engine", None)) + self.balance_owner = data.get("balance_owner", None) + self._adapter_contract = AdapterContract(ADAPTER_ADDRESS, self._engine) + self._set_capabilities() + if len(self.spot_prices) == 0: + self._set_spot_prices() + + def _set_engine(self, engine: Optional[SimulationEngine]): + """Set instance's simulation engine. If no engine given, make a default one. + + If engine is already set, this is a noop. + + The engine will have the specified adapter contract mocked, as well as the + tokens used by the pool. + + Parameters + ---------- + engine + Optional simulation engine instance. + """ + if self._engine is not None: + return + else: + engine = create_engine([t.address for t in self.tokens], trace=self.trace) + engine.init_account( + address=ADAPTER_ADDRESS, + account=AccountInfo( + balance=0, + nonce=0, + code=get_contract_bytecode(self.adapter_contract_name), + ), + mocked=False, + permanent_storage=None, + ) + for addr, bytecode in self.stateless_contracts.items(): + engine.init_account( + address=addr, + account=AccountInfo(balance=0, nonce=0, code=bytecode), + mocked=False, + permanent_storage=None, + ) + self._engine = engine + + """Set the spot prices for this pool. + + We currently require the price function capability for now. + """ + self._ensure_capability(Capability.PriceFunction) + for t0, t1 in itertools.permutations(self.tokens, 2): + sell_amount = t0.to_onchain_amount( + self.get_sell_amount_limit(t0, t1) * Decimal("0.01") + ) + frac = self._adapter_contract.price( + cast(HexStr, self.id_), + t0, + t1, + [sell_amount], + block=self.block, + overwrites=self.block_lasting_overwrites, + )[0] + if Capability.ScaledPrice in self.capabilities: + self.spot_prices[(t0, t1)] = frac_to_decimal(frac) + else: + scaled = frac * Fraction(10 ** t0.decimals, 10 ** t1.decimals) + self.spot_prices[(t0, t1)] = frac_to_decimal(scaled) + + def _ensure_capability(self, capability: Capability): + """Ensures the protocol/adapter implement a certain capability.""" + if capability not in self.capabilities: + raise NotImplemented(f"{capability} not available!") + + def _set_capabilities(self): + """Sets capabilities of the pool.""" + capabilities = [] + for t0, t1 in itertools.permutations(self.tokens, 2): + capabilities.append( + self._adapter_contract.get_capabilities(cast(HexStr, self.id_), t0, t1) + ) + max_capabilities = max(map(len, capabilities)) + self.capabilities = functools.reduce(set.intersection, capabilities) + if len(self.capabilities) < max_capabilities: + log.warning( + f"Pool {self.id_} hash different capabilities depending on the token pair!" + ) + + def get_amount_out( + self: TPoolState, + sell_token: EthereumToken, + sell_amount: Decimal, + buy_token: EthereumToken, + slippage: Decimal = Decimal(0), + create_new_pool: bool = True, + ) -> tuple[Decimal, int, TPoolState]: + # if the pool has a hard limit and the sell amount exceeds that, simulate and + # raise a partial trade + if self.hard_sell_limit: + sell_limit = self.get_sell_amount_limit(sell_token, buy_token) + if sell_amount > sell_limit: + partial_trade = self._get_amount_out(sell_token, sell_limit, buy_token) + raise RecoverableSimulationException( + "Sell amount exceeds sell limit", + repr(self), + partial_trade + (sell_limit,), + ) + + return self._get_amount_out(sell_token, sell_amount, buy_token) + + def _get_amount_out( + self: TPoolState, + sell_token: EthereumToken, + sell_amount: Decimal, + buy_token: EthereumToken, + ) -> tuple[Decimal, int, TPoolState]: + trade, state_changes = self._adapter_contract.swap( + cast(HexStr, self.id_), + sell_token, + buy_token, + False, + sell_token.to_onchain_amount(sell_amount), + block=self.block, + overwrites=self._get_overwrites(sell_token, buy_token), + ) + new_state = self._duplicate() + for address, state_update in state_changes.items(): + for slot, value in state_update.storage.items(): + new_state.block_lasting_overwrites[address][slot] = value + + new_price = frac_to_decimal(trade.price) + if new_price != Decimal(0): + new_state.spot_prices = { + (sell_token, buy_token): new_price, + (buy_token, sell_token): Decimal(1) / new_price, + } + + buy_amount = buy_token.from_onchain_amount(trade.received_amount) + + return buy_amount, trade.gas_used, new_state + + def _get_overwrites( + self, sell_token: EthereumToken, buy_token: EthereumToken, **kwargs + ) -> dict[Address, dict[int, int]]: + """Get an overwrites dictionary to use in a simulation. + + The returned overwrites include block-lasting overwrites set on the instance + level, and token-specific overwrites that depend on passed tokens. + """ + token_overwrites = self._get_token_overwrites(sell_token, buy_token, **kwargs) + return _merge(self.block_lasting_overwrites, token_overwrites) + + def _get_token_overwrites( + self, sell_token: EthereumToken, buy_token: EthereumToken, max_amount=None + ) -> dict[Address, dict[int, int]]: + """Creates overwrites for a token. + + Funds external account with enough tokens to execute swaps. Also creates a + corresponding approval to the adapter contract. + + If the protocol reads its own token balances, the balances for the underlying + pool contract will also be overwritten. + """ + res = [] + if Capability.TokenBalanceIndependent not in self.capabilities: + res = [self._get_balance_overwrites()] + + # avoids recursion if using this method with get_sell_amount_limit + if max_amount is None: + max_amount = sell_token.to_onchain_amount( + self.get_sell_amount_limit(sell_token, buy_token) + ) + overwrites = ERC20OverwriteFactory(sell_token) + overwrites.set_balance(max_amount, EXTERNAL_ACCOUNT) + overwrites.set_allowance( + allowance=max_amount, owner=EXTERNAL_ACCOUNT, spender=ADAPTER_ADDRESS + ) + res.append(overwrites.get_protosim_overwrites()) + + # we need to merge the dictionaries because balance overwrites may target + # the same token address. + res = functools.reduce(_merge, res) + return res + + def _get_balance_overwrites(self) -> dict[Address, dict[int, int]]: + balance_overwrites = {} + address = self.balance_owner or self.id_ + for t in self.tokens: + overwrites = ERC20OverwriteFactory(t) + overwrites.set_balance( + t.to_onchain_amount(self.balances[t.address]), address + ) + balance_overwrites.update(overwrites.get_protosim_overwrites()) + return balance_overwrites + + def _duplicate(self: type["ThirdPartyPool"]) -> "ThirdPartyPool": + """Make a new instance identical to self that shares the same simulation engine. + + Note that the new and current state become coupled in a way that they must + simulate the same block. This is fine, see + https://datarevenue.atlassian.net/browse/ROC-1301 + + Not naming this method _copy to not confuse with pydantic's .copy method. + """ + return type(self)( + exchange=self.exchange, + adapter_contract_name=self.adapter_contract_name, + block=self.block, + id_=self.id_, + tokens=self.tokens, + spot_prices=self.spot_prices.copy(), + trading_fee=self.trading_fee, + block_lasting_overwrites=deepcopy(self.block_lasting_overwrites), + engine=self._engine, + balances=self.balances, + minimum_gas=self.minimum_gas, + hard_sell_limit=self.hard_sell_limit, + balance_owner=self.balance_owner, + stateless_contracts=self.stateless_contracts, + ) + + def get_sell_amount_limit( + self, sell_token: EthereumToken, buy_token: EthereumToken + ) -> Decimal: + """ + Retrieves the sell amount of the given token. + + For pools with more than 2 tokens, the sell limit is obtain for all possible buy token + combinations and the minimum is returned. + """ + limit = self._adapter_contract.get_limits( + cast(HexStr, self.id_), + sell_token, + buy_token, + block=self.block, + overwrites=self._get_overwrites( + sell_token, buy_token, max_amount=MAX_BALANCE // 100 + ), + )[0] + return sell_token.from_onchain_amount(limit) + + def simulate_transition( + self: TPoolState, + sell_token: EthereumToken, + buy_token: EthereumToken, + target_price: Decimal, + max_sell_amount: Decimal, + ) -> tuple[Decimal, Decimal, int, TPoolState]: + pass + + def update_inplace(self, new: "ThirdPartyPool"): + """ + Updates the current `ThirdPartyPool` in-place. + + If the block attribute of the `new` state differs from the block attribute of + the current state, the temporary storage of the simulation engine is cleared. + + Parameters + ---------- + new + The new state object to update from. + + """ + old_block = self.block + super(ThirdPartyPool, self).update_inplace(new) + self.block_lasting_overwrites = new.block_lasting_overwrites.copy() + if new.block != old_block: + self.clear_all_cache() + + def clear_all_cache(self): + self._engine.clear_temp_storage() + self.block_lasting_overwrites = defaultdict(dict) + self._set_spot_prices() + + # def transition(self, event: ProtocolEvent) -> "ThirdPartyPool": + # """Make a new pool state so that everything's initialised from scratch. + # + # Instead of interpreting the event and applying the changes in signals, we just + # create a fresh instance of the pool. This way all dynamic parameters will be + # set again using the on-chain data. + # """ + # new = type(self)( + # block=self.block, + # id_=self.id_, + # tokens=self.tokens, + # balances=self.balances.copy(), + # ) + # if isinstance(event, ERC20Transfer): + # transition_balances_inplace(self.id_, event, new.balances) + # return new + # + + +def _merge(a: dict, b: dict, path=None): + """ + Merges two dictionaries (a and b) deeply. This means it will traverse and combine + their nested dictionaries too if present. + + Parameters: + a (dict): The first dictionary to merge. + b (dict): The second dictionary to merge into the first one. + path (list, optional): An internal parameter used during recursion + to keep track of the ancestry of nested dictionaries. + + Returns: + a (dict): The merged dictionary which includes all key-value pairs from b + added into a. If they have nested dictionaries with same keys, those are also merged. + On key conflicts, preference is given to values from b. + """ + if path is None: + path = [] + for key in b: + if key in a: + if isinstance(a[key], dict) and isinstance(b[key], dict): + _merge(a[key], b[key], path + [str(key)]) + else: + a[key] = b[key] + return a diff --git a/tycho/tycho/tycho_adapter.py b/tycho/tycho/tycho_adapter.py index ff4319e..36d197d 100644 --- a/tycho/tycho/tycho_adapter.py +++ b/tycho/tycho/tycho_adapter.py @@ -28,7 +28,9 @@ from tycho.tycho.constants import ( ) from tycho.tycho.decoders import ThirdPartyPoolTychoDecoder from tycho.tycho.exceptions import APIRequestError -from tycho.tycho.models import Blockchain, EVMBlock, ThirdPartyPool, EthereumToken +from tycho.tycho.models import Blockchain, EVMBlock, EthereumToken +from tycho.tycho.pool_state import ThirdPartyPool +from tycho.tycho.utils import create_engine, TychoDBSingleton log = getLogger(__name__) @@ -124,14 +126,9 @@ class TychoPoolStateStreamAdapter: self._decoder = decoder # Create engine - self._db = TychoDB(tycho_http_url=self.tycho_url) - self._engine = SimulationEngine.new_with_tycho_db(db=self._db, trace=True) - self._engine.init_account( - address=EXTERNAL_ACCOUNT, - account=AccountInfo(balance=MAX_BALANCE, nonce=0, code=None), - mocked=False, - permanent_storage=None, - ) + # TODO: This should be initialized outside the adapter? + TychoDBSingleton.initialize(tycho_http_url=self.tycho_url) + self._engine = create_engine([], state_block=None, trace=True) # Loads tokens from Tycho self._tokens: dict[str, EthereumToken] = TokenLoader( diff --git a/tycho/tycho/tycho_db.py b/tycho/tycho/tycho_db.py new file mode 100644 index 0000000..e69de29 diff --git a/tycho/tycho/utils.py b/tycho/tycho/utils.py index 72b4ff5..002389a 100644 --- a/tycho/tycho/utils.py +++ b/tycho/tycho/utils.py @@ -1,3 +1,343 @@ +import json +import os +from decimal import Decimal +from fractions import Fraction +from functools import lru_cache +from logging import getLogger +from pathlib import Path +from typing import Final, Any + +import eth_abi +from eth_typing import HexStr +from hexbytes import HexBytes +from protosim_py import SimulationEngine, TychoDB, AccountInfo +from web3 import Web3 + +from tycho.tycho.constants import EXTERNAL_ACCOUNT, MAX_BALANCE +from tycho.tycho.exceptions import OutOfGas +from tycho.tycho.models import Address, EthereumToken + +log = getLogger(__name__) + + def decode_tycho_exchange(exchange: str) -> (str, bool): # removes vm prefix if present, returns True if vm prefix was present (vm protocol) or False if native protocol return (exchange.split(":")[1], False) if "vm:" in exchange else (exchange, True) + + +class TychoDBSingleton: + """ + A singleton wrapper around the TychoDB class. + + This class ensures that there is only one instance of TychoDB throughout the lifetime of the program, + avoiding the overhead of creating multiple instances. + """ + + _instance = None + + @classmethod + def initialize(cls, tycho_http_url: str): + """ + Initialize the TychoDB instance with the given URLs. + + Parameters + ---------- + tycho_http_url : str + The URL of the Tycho HTTP server. + + """ + cls._instance = TychoDB(tycho_http_url=tycho_http_url) + + @classmethod + def get_instance(cls) -> TychoDB: + """ + Retrieve the singleton instance of TychoDB. + + If the TychoDB instance does not exist, it creates a new one. + If it already exists, it returns the existing instance. + + Returns + ------- + TychoDB + The singleton instance of TychoDB. + """ + if cls._instance is None: + raise ValueError( + "TychoDB instance not initialized. Call initialize() first." + ) + return cls._instance + + @classmethod + def clear_instance(cls): + cls._instance = None + + +def create_engine( + mocked_tokens: list[Address], trace: bool = False +) -> SimulationEngine: + """Create a simulation engine with a mocked ERC20 contract at given addresses. + + Parameters + ---------- + mocked_tokens + A list of addresses at which a mocked ERC20 contract should be inserted. + + trace + Whether to trace calls, only meant for debugging purposes, might print a lot of + data to stdout. + """ + + db = TychoDBSingleton.get_instance() + engine = SimulationEngine.new_with_tycho_db(db=db, trace=trace) + + for t in mocked_tokens: + info = AccountInfo(balance=0, nonce=0, code=get_contract_bytecode("ERC20.bin")) + engine.init_account( + address=t, account=info, mocked=True, permanent_storage=None + ) + engine.init_account( + address=EXTERNAL_ACCOUNT, + account=AccountInfo(balance=MAX_BALANCE, nonce=0, code=None), + mocked=False, + permanent_storage=None, + ) + + return engine + + +class ERC20OverwriteFactory: + def __init__(self, token: EthereumToken): + """ + Initialize the ERC20OverwriteFactory. + + Parameters: + token: The token object. + """ + self._token = token + self._overwrites = dict() + self._balance_slot: Final[int] = 0 + self._allowance_slot: Final[int] = 1 + + def set_balance(self, balance: int, owner: Address): + """ + Set the balance for a given owner. + + Parameters: + balance: The balance value. + owner: The owner's address. + """ + storage_index = get_storage_slot_at_key(HexStr(owner), self._balance_slot) + self._overwrites[storage_index] = balance + log.log( + 5, + f"Override balance: token={self._token.address} owner={owner}" + f"value={balance} slot={storage_index}", + ) + + def set_allowance(self, allowance: int, spender: Address, owner: Address): + """ + Set the allowance for a given spender and owner. + + Parameters: + allowance: The allowance value. + spender: The spender's address. + owner: The owner's address. + """ + storage_index = get_storage_slot_at_key( + HexStr(spender), + get_storage_slot_at_key(HexStr(owner), self._allowance_slot), + ) + self._overwrites[storage_index] = allowance + log.log( + 5, + f"Override allowance: token={self._token.address} owner={owner}" + f"spender={spender} value={allowance} slot={storage_index}", + ) + + def get_protosim_overwrites(self) -> dict[Address, dict[int, int]]: + """ + Get the overwrites dictionary of previously collected values. + + Returns: + dict[Address, dict]: A dictionary containing the token's address + and the overwrites. + """ + # Protosim returns lowercase addresses in state updates returned from simulation + + return {self._token.address.lower(): self._overwrites} + + def get_geth_overwrites(self) -> dict[Address, dict[int, int]]: + """ + Get the overwrites dictionary of previously collected values. + + Returns: + dict[Address, dict]: A dictionary containing the token's address + and the overwrites. + """ + formatted_overwrites = { + HexBytes(key).hex(): "0x" + HexBytes(val).hex().lstrip("0x").zfill(64) + for key, val in self._overwrites.items() + } + code = "0x" + get_contract_bytecode("ERC20.bin").hex() + return {self._token.address: {"stateDiff": formatted_overwrites, "code": code}} + + +def get_storage_slot_at_key(key: Address, mapping_slot: int) -> int: + """Get storage slot index of a value stored at a certain key in a mapping + + Parameters + ---------- + key + Key in a mapping. This function is meant to work with ethereum addresses + and accepts only strings. + mapping_slot + Storage slot at which the mapping itself is stored. See the examples for more + explanation. + + Returns + ------- + slot + An index of a storage slot where the value at the given key is stored. + + Examples + -------- + If a mapping is declared as a first variable in solidity code, its storage slot + is 0 (e.g. ``balances`` in our mocked ERC20 contract). Here's how to compute + a storage slot where balance of a given account is stored:: + + get_storage_slot_at_key("0xC63135E4bF73F637AF616DFd64cf701866BB2628", 0) + + For nested mappings, we need to apply the function twice. An example of this is + ``allowances`` in ERC20. It is a mapping of form: + ``dict[owner, dict[spender, value]]``. In our mocked ERC20 contract, ``allowances`` + is a second variable, so it is stored at slot 1. Here's how to get a storage slot + where an allowance of ``0xspender`` to spend ``0xowner``'s money is stored:: + + get_storage_slot_at_key("0xspender", get_storage_slot_at_key("0xowner", 1))) + + See Also + -------- + `Solidity Storage Layout documentation + `_ + """ + key_bytes = bytes.fromhex(key[2:]).rjust(32, b"\0") + mapping_slot_bytes = int.to_bytes(mapping_slot, 32, "big") + slot_bytes = Web3.keccak(key_bytes + mapping_slot_bytes) + return int.from_bytes(slot_bytes, "big") + + +@lru_cache +def get_contract_bytecode(name: str) -> bytes: + with open(Path(__file__).parent / "assets" / name, "rb") as fh: + code = fh.read() + return code + + +def frac_to_decimal(frac: Fraction) -> Decimal: + return Decimal(frac.numerator) / Decimal(frac.denominator) + + +def load_abi(name_or_path: str) -> dict: + if os.path.exists(abspath := os.path.abspath(name_or_path)): + path = abspath + else: + path = f"{os.path.dirname(os.path.abspath(__file__))}/assets/{name_or_path}.abi" + try: + with open(os.path.abspath(path)) as f: + abi: dict = json.load(f) + except FileNotFoundError: + search_dir = f"{os.path.dirname(os.path.abspath(__file__))}/assets/" + + # List all files in search dir and subdirs suggest them to the user in an error message + available_files = [] + for dirpath, dirnames, filenames in os.walk(search_dir): + for filename in filenames: + # Make paths relative to search_dir + relative_path = os.path.relpath( + os.path.join(dirpath, filename), search_dir + ) + available_files.append(relative_path.replace(".abi", "")) + + raise FileNotFoundError( + f"File {name_or_path} not found. " + f"Did you mean one of these? {', '.join(available_files)}" + ) + return abi + + +# https://docs.soliditylang.org/en/latest/control-structures.html#panic-via-assert-and-error-via-require +solidity_panic_codes = { + 0: "GenericCompilerPanic", + 1: "AssertionError", + 17: "ArithmeticOver/Underflow", + 18: "ZeroDivisionError", + 33: "UnkownEnumMember", + 34: "BadStorageByteArrayEncoding", + 51: "EmptyArray", + 0x32: "OutOfBounds", + 0x41: "OutOfMemory", + 0x51: "BadFunctionPointer", +} + + +def parse_solidity_error_message(data) -> str: + data_bytes = HexBytes(data) + error_string = f"Failed to decode: {data}" + # data is encoded as Error(string) + if data_bytes[:4] == HexBytes("0x08c379a0"): + (error_string,) = eth_abi.decode(["string"], data_bytes[4:]) + return error_string + elif data_bytes[:4] == HexBytes("0x4e487b71"): + (error_code,) = eth_abi.decode(["uint256"], data_bytes[4:]) + return solidity_panic_codes.get(error_code, f"Panic({error_code})") + # old solidity: revert 'some string' case + try: + (error_string,) = eth_abi.decode(["string"], data_bytes) + return error_string + except Exception: + pass + # some custom error maybe it is with string? + try: + (error_string,) = eth_abi.decode(["string"], data_bytes[4:]) + return error_string + except Exception: + pass + try: + (error_string,) = eth_abi.decode(["string"], data_bytes[4:]) + return error_string + except Exception: + pass + return error_string + + +def maybe_coerce_error( + err: RuntimeError, pool_state: Any, gas_limit: int = None +) -> Exception: + details = err.args[0] + # we got bytes as data, so this was a revert + if details.data.startswith("0x"): + err = RuntimeError( + f"Revert! Reason: {parse_solidity_error_message(details.data)}" + ) + # we have gas information, check if this likely an out of gas err. + if gas_limit is not None and details.gas_used is not None: + # if we used up 97% or more issue a OutOfGas error. + usage = details.gas_used / gas_limit + if usage >= 0.97: + return OutOfGas( + f"SimulationError: Likely out-of-gas. " + f"Used: {usage * 100:.2f}% of gas limit. " + f"Original error: {err}", + repr(pool_state), + ) + elif "OutOfGas" in details.data: + if gas_limit is not None: + usage = details.gas_used / gas_limit + usage_msg = f"Used: {usage * 100:.2f}% of gas limit. " + else: + usage_msg = "" + return OutOfGas( + f"SimulationError: out-of-gas. {usage_msg}Original error: {details.data}", + repr(pool_state), + ) + return err