Start using external modules

This commit is contained in:
Thales
2024-08-05 19:58:10 -03:00
committed by tvinagre
parent 8ea02613a2
commit d893ab264c
9 changed files with 171 additions and 176 deletions

View File

@@ -7,6 +7,6 @@ TYCHO_CLIENT_LOG_FOLDER = TYCHO_CLIENT_FOLDER / "logs"
EXTERNAL_ACCOUNT: Final[str] = "0xf847a638E44186F3287ee9F8cAF73FF4d4B80784"
"""This is a dummy address used as a transaction sender"""
UINT256_MAX: Final[int] = 2 ** 256 - 1
UINT256_MAX: Final[int] = 2**256 - 1
MAX_BALANCE: Final[int] = UINT256_MAX // 2
"""0.5 of the maximal possible balance to avoid overflow errors"""

View File

@@ -26,10 +26,10 @@ class ThirdPartyPoolTychoDecoder:
self.trace = trace
def decode_snapshot(
self,
snapshot: dict[str, Any],
block: EVMBlock,
tokens: dict[str, EthereumToken],
self,
snapshot: dict[str, Any],
block: EVMBlock,
tokens: dict[str, EthereumToken],
) -> tuple[dict[str, ThirdPartyPool], list[str]]:
pools = {}
failed_pools = []
@@ -45,7 +45,7 @@ class ThirdPartyPoolTychoDecoder:
return pools, failed_pools
def decode_pool_state(
self, snap: dict, block: EVMBlock, tokens: dict[str, EthereumToken]
self, snap: dict, block: EVMBlock, tokens: dict[str, EthereumToken]
) -> ThirdPartyPool:
component = snap["component"]
exchange, _ = decode_tycho_exchange(component["protocol_system"])
@@ -85,20 +85,29 @@ class ThirdPartyPoolTychoDecoder:
while f"stateless_contract_addr_{index}" in static_attributes:
encoded_address = static_attributes[f"stateless_contract_addr_{index}"]
decoded = bytes.fromhex(
encoded_address[2:] if encoded_address.startswith('0x') else encoded_address).decode('utf-8')
encoded_address[2:]
if encoded_address.startswith("0x")
else encoded_address
).decode("utf-8")
if decoded.startswith("call"):
address = ThirdPartyPoolTychoDecoder.get_address_from_call(block_number, decoded)
address = ThirdPartyPoolTychoDecoder.get_address_from_call(
block_number, decoded
)
else:
address = decoded
code = static_attributes.get(f"stateless_contract_code_{index}") or get_code_for_address(address)
code = static_attributes.get(
f"stateless_contract_code_{index}"
) or get_code_for_address(address)
stateless_contracts[address] = code
index += 1
index = 0
while f"stateless_contract_addr_{index}" in attributes:
address = attributes[f"stateless_contract_addr_{index}"]
code = attributes.get(f"stateless_contract_code_{index}") or get_code_for_address(address)
code = attributes.get(
f"stateless_contract_code_{index}"
) or get_code_for_address(address)
stateless_contracts[address] = code
index += 1
return {
@@ -118,15 +127,17 @@ class ThirdPartyPoolTychoDecoder:
permanent_storage=None,
)
selector = keccak(text=decoded.split(":")[-1])[:4]
sim_result = engine.run_sim(SimulationParameters(
data=bytearray(selector),
to=decoded.split(':')[1],
block_number=block_number,
timestamp=int(time.time()),
overrides={},
caller=EXTERNAL_ACCOUNT,
value=0,
))
sim_result = engine.run_sim(
SimulationParameters(
data=bytearray(selector),
to=decoded.split(":")[1],
block_number=block_number,
timestamp=int(time.time()),
overrides={},
caller=EXTERNAL_ACCOUNT,
value=0,
)
)
address = eth_abi.decode(["address"], bytearray(sim_result.result))
return address[0]
@@ -143,10 +154,10 @@ class ThirdPartyPoolTychoDecoder:
@staticmethod
def apply_update(
pool: ThirdPartyPool,
pool_update: dict[str, Any],
balance_updates: dict[str, Any],
block: EVMBlock,
pool: ThirdPartyPool,
pool_update: dict[str, Any],
balance_updates: dict[str, Any],
block: EVMBlock,
) -> ThirdPartyPool:
# check for and apply optional state attributes
attributes = pool_update.get("updated_attributes")

View File

@@ -41,7 +41,7 @@ class EthereumToken(BaseModel):
log.warning(f"Expected variable of type Decimal. Got {type(amount)}.")
with localcontext(Context(rounding=ROUND_FLOOR, prec=256)):
amount = Decimal(str(amount)) * (10 ** self.decimals)
amount = Decimal(str(amount)) * (10**self.decimals)
try:
amount = amount.quantize(Decimal("1.0"))
except InvalidOperation:
@@ -51,7 +51,7 @@ class EthereumToken(BaseModel):
return int(amount)
def from_onchain_amount(
self, onchain_amount: Union[int, Fraction], quantize: bool = True
self, onchain_amount: Union[int, Fraction], quantize: bool = True
) -> Decimal:
"""Converts an Integer to a quantized decimal, by shifting left by the token's
maximum amount of decimals (e.g.: 1000000 becomes 1.000000 for a 6-decimal token
@@ -66,19 +66,19 @@ class EthereumToken(BaseModel):
with localcontext(Context(rounding=ROUND_FLOOR, prec=256)):
if isinstance(onchain_amount, Fraction):
return (
Decimal(onchain_amount.numerator)
/ Decimal(onchain_amount.denominator)
/ Decimal(10 ** self.decimals)
Decimal(onchain_amount.numerator)
/ Decimal(onchain_amount.denominator)
/ Decimal(10**self.decimals)
).quantize(Decimal(f"{1 / 10 ** self.decimals}"))
if quantize is True:
try:
amount = (
Decimal(str(onchain_amount)) / 10 ** self.decimals
Decimal(str(onchain_amount)) / 10**self.decimals
).quantize(Decimal(f"{1 / 10 ** self.decimals}"))
except InvalidOperation:
amount = Decimal(str(onchain_amount)) / Decimal(10 ** self.decimals)
amount = Decimal(str(onchain_amount)) / Decimal(10**self.decimals)
else:
amount = Decimal(str(onchain_amount)) / Decimal(10 ** self.decimals)
amount = Decimal(str(onchain_amount)) / Decimal(10**self.decimals)
return amount
def __repr__(self):

View File

@@ -142,7 +142,7 @@ class ThirdPartyPool(BaseModel):
if Capability.ScaledPrice in self.capabilities:
self.spot_prices[(t0, t1)] = frac_to_decimal(frac)
else:
scaled = frac * Fraction(10 ** t0.decimals, 10 ** t1.decimals)
scaled = frac * Fraction(10**t0.decimals, 10**t1.decimals)
self.spot_prices[(t0, t1)] = frac_to_decimal(scaled)
def _ensure_capability(self, capability: Capability):
@@ -165,10 +165,10 @@ class ThirdPartyPool(BaseModel):
)
def get_amount_out(
self: TPoolState,
sell_token: EthereumToken,
sell_amount: Decimal,
buy_token: EthereumToken,
self: TPoolState,
sell_token: EthereumToken,
sell_amount: Decimal,
buy_token: EthereumToken,
) -> tuple[Decimal, int, TPoolState]:
# if the pool has a hard limit and the sell amount exceeds that, simulate and
# raise a partial trade
@@ -185,10 +185,10 @@ class ThirdPartyPool(BaseModel):
return self._get_amount_out(sell_token, sell_amount, buy_token)
def _get_amount_out(
self: TPoolState,
sell_token: EthereumToken,
sell_amount: Decimal,
buy_token: EthereumToken,
self: TPoolState,
sell_token: EthereumToken,
sell_amount: Decimal,
buy_token: EthereumToken,
) -> tuple[Decimal, int, TPoolState]:
trade, state_changes = self._adapter_contract.swap(
cast(HexStr, self.id_),
@@ -216,7 +216,7 @@ class ThirdPartyPool(BaseModel):
return buy_amount, trade.gas_used, new_state
def _get_overwrites(
self, sell_token: EthereumToken, buy_token: EthereumToken, **kwargs
self, sell_token: EthereumToken, buy_token: EthereumToken, **kwargs
) -> dict[Address, dict[int, int]]:
"""Get an overwrites dictionary to use in a simulation.
@@ -227,7 +227,7 @@ class ThirdPartyPool(BaseModel):
return _merge(self.block_lasting_overwrites, token_overwrites)
def _get_token_overwrites(
self, sell_token: EthereumToken, buy_token: EthereumToken, max_amount=None
self, sell_token: EthereumToken, buy_token: EthereumToken, max_amount=None
) -> dict[Address, dict[int, int]]:
"""Creates overwrites for a token.
@@ -295,7 +295,7 @@ class ThirdPartyPool(BaseModel):
)
def get_sell_amount_limit(
self, sell_token: EthereumToken, buy_token: EthereumToken
self, sell_token: EthereumToken, buy_token: EthereumToken
) -> Decimal:
"""
Retrieves the sell amount of the given token.

View File

@@ -26,10 +26,10 @@ log = getLogger(__name__)
class TokenLoader:
def __init__(
self,
tycho_url: str,
blockchain: Blockchain,
min_token_quality: Optional[int] = 0,
self,
tycho_url: str,
blockchain: Blockchain,
min_token_quality: Optional[int] = 0,
):
self.tycho_url = tycho_url
self.blockchain = blockchain
@@ -45,10 +45,10 @@ class TokenLoader:
start = time.monotonic()
all_tokens = []
while data := self._get_all_with_pagination(
url=url,
page=page,
limit=self._token_limit,
params={"min_quality": self.min_token_quality},
url=url,
page=page,
limit=self._token_limit,
params={"min_quality": self.min_token_quality},
):
all_tokens.extend(data)
page += 1
@@ -73,10 +73,10 @@ class TokenLoader:
start = time.monotonic()
all_tokens = []
while data := self._get_all_with_pagination(
url=url,
page=page,
limit=self._token_limit,
params={"min_quality": self.min_token_quality, "addresses": addresses},
url=url,
page=page,
limit=self._token_limit,
params={"min_quality": self.min_token_quality, "addresses": addresses},
):
all_tokens.extend(data)
page += 1
@@ -95,7 +95,7 @@ class TokenLoader:
@staticmethod
def _get_all_with_pagination(
url: str, params: Optional[Dict] = None, page: int = 0, limit: int = 50
url: str, params: Optional[Dict] = None, page: int = 0, limit: int = 50
) -> Dict:
if params is None:
params = {}
@@ -122,14 +122,14 @@ class BlockProtocolChanges:
class TychoPoolStateStreamAdapter:
def __init__(
self,
tycho_url: str,
protocol: str,
decoder: ThirdPartyPoolTychoDecoder,
blockchain: Blockchain,
min_tvl: Optional[Decimal] = 10,
min_token_quality: Optional[int] = 0,
include_state=True,
self,
tycho_url: str,
protocol: str,
decoder: ThirdPartyPoolTychoDecoder,
blockchain: Blockchain,
min_tvl: Optional[Decimal] = 10,
min_token_quality: Optional[int] = 0,
include_state=True,
):
"""
:param tycho_url: URL to connect to Tycho DB
@@ -181,7 +181,7 @@ class TychoPoolStateStreamAdapter:
log.debug(f"Starting tycho-client binary at {bin_path}. CMD: {cmd}")
self.tycho_client = await asyncio.create_subprocess_exec(
str(bin_path), *cmd, stdout=PIPE, stderr=STDOUT, limit=2 ** 64
str(bin_path), *cmd, stdout=PIPE, stderr=STDOUT, limit=2**64
)
@staticmethod
@@ -238,7 +238,7 @@ class TychoPoolStateStreamAdapter:
@staticmethod
def build_snapshot_message(
protocol_components: dict, protocol_states: dict, contract_states: dict
protocol_components: dict, protocol_states: dict, contract_states: dict
) -> dict[str, ThirdPartyPool]:
vm_states = {state["address"]: state for state in contract_states["accounts"]}
states = {}
@@ -269,7 +269,7 @@ class TychoPoolStateStreamAdapter:
return self.process_snapshot(block, state_msg["snapshot"])
def process_snapshot(
self, block: EVMBlock, state_msg: dict
self, block: EVMBlock, state_msg: dict
) -> BlockProtocolChanges:
start = time.monotonic()
removed_pools = set()

View File

@@ -121,8 +121,7 @@ class ERC20OverwriteFactory:
"""
self._overwrites[self._total_supply_slot] = supply
log.log(
5,
f"Override total supply: token={self._token.address} supply={supply}"
5, f"Override total supply: token={self._token.address} supply={supply}"
)
def get_protosim_overwrites(self) -> dict[Address, dict[int, int]]:
@@ -352,4 +351,4 @@ def get_code_for_address(address: str, connection_string: str = None):
return bytes.fromhex(code[2:])
except RuntimeError as e:
print(f"Error fetching code for address {address}: {e}")
return None
return None