Start using external modules

This commit is contained in:
Thales
2024-08-05 19:58:10 -03:00
committed by tvinagre
parent 8ea02613a2
commit d893ab264c
9 changed files with 171 additions and 176 deletions

View File

@@ -6,29 +6,20 @@ def main() -> None:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Run indexer within a specified range of blocks" description="Run indexer within a specified range of blocks"
) )
parser.add_argument("--package", type=str, help="Name of the package to test.")
parser.add_argument( parser.add_argument(
"--package", type=str, help="Name of the package to test." "--tycho-logs", action="store_true", help="Flag to activate logs from Tycho."
)
parser.add_argument(
"--tycho-logs",
action="store_true",
help="Flag to activate logs from Tycho.",
) )
parser.add_argument( parser.add_argument(
"--db-url", type=str, help="Postgres database URL for the Tycho indexer." "--db-url", type=str, help="Postgres database URL for the Tycho indexer."
) )
parser.add_argument( parser.add_argument(
"--vm-traces", "--vm-traces", action="store_true", help="Enable tracing during vm simulations."
action="store_true",
help="Enable tracing during vm simulations.",
) )
args = parser.parse_args() args = parser.parse_args()
test_runner = TestRunner( test_runner = TestRunner(
args.package, args.package, args.tycho_logs, db_url=args.db_url, vm_traces=args.vm_traces
args.tycho_logs,
db_url=args.db_url,
vm_traces=args.vm_traces,
) )
test_runner.run_tests() test_runner.run_tests()

View File

@@ -2,22 +2,31 @@ import itertools
import os import os
import shutil import shutil
import subprocess import subprocess
import traceback
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from decimal import Decimal from decimal import Decimal
from pathlib import Path from pathlib import Path
import traceback
import yaml import yaml
from pydantic import BaseModel from pydantic import BaseModel
from tycho_client.dto import (
Chain,
ProtocolComponentsParams,
ProtocolStateParams,
ContractStateParams, ProtocolComponent, ResponseProtocolState, HexBytes,
)
from tycho_client.rpc_client import TychoRPCClient
from tycho_client.stream import TychoStream
from adapter_handler import AdapterContractHandler from .adapter_handler import AdapterContractHandler
from tycho_client.decoders import ThirdPartyPoolTychoDecoder from .evm import get_token_balance, get_block_header
from tycho_client.models import Blockchain, EVMBlock from .tycho import TychoRunner
from tycho_client.tycho_adapter import TychoPoolStateStreamAdapter
from evm import get_token_balance, get_block_header
from tycho import TychoRunner, TychoRPCClient # from tycho_client.decoders import ThirdPartyPoolTychoDecoder
# from tycho_client.models import Blockchain, EVMBlock
# from tycho_client.tycho_adapter import TychoPoolStateStreamAdapter
class TestResult: class TestResult:
@@ -49,7 +58,7 @@ class SimulationFailure(BaseModel):
class TestRunner: class TestRunner:
def __init__( def __init__(
self, package: str, with_binary_logs: bool, db_url: str, vm_traces: bool self, package: str, with_binary_logs: bool, db_url: str, vm_traces: bool
): ):
self.repo_root = os.getcwd() self.repo_root = os.getcwd()
config_path = os.path.join( config_path = os.path.join(
@@ -64,7 +73,7 @@ class TestRunner:
self.tycho_rpc_client = TychoRPCClient() self.tycho_rpc_client = TychoRPCClient()
self.db_url = db_url self.db_url = db_url
self._vm_traces = vm_traces self._vm_traces = vm_traces
self._chain = Blockchain.ethereum self._chain = Chain.ethereum
def run_tests(self) -> None: def run_tests(self) -> None:
"""Run all tests specified in the configuration.""" """Run all tests specified in the configuration."""
@@ -96,22 +105,27 @@ class TestRunner:
def validate_state(self, expected_state: dict, stop_block: int) -> TestResult: def validate_state(self, expected_state: dict, stop_block: int) -> TestResult:
"""Validate the current protocol state against the expected state.""" """Validate the current protocol state against the expected state."""
protocol_components = self.tycho_rpc_client.get_protocol_components() protocol_components: list[ProtocolComponent] = self.tycho_rpc_client.get_protocol_components(
protocol_states = self.tycho_rpc_client.get_protocol_state() ProtocolComponentsParams(protocol_system="test_protocol")
components = { )
component["id"]: component protocol_states: list[ResponseProtocolState] = self.tycho_rpc_client.get_protocol_state(
for component in protocol_components["protocol_components"] ProtocolStateParams(protocol_system="test_protocol")
)
components_by_id = {
component.id: component
for component in protocol_components
} }
try: try:
for expected_component in expected_state.get("protocol_components", []): for expected_component in expected_state.get("protocol_components", []):
comp_id = expected_component["id"].lower() comp_id = expected_component["id"].lower()
if comp_id not in components: if comp_id not in components_by_id:
return TestResult.Failed( return TestResult.Failed(
f"'{comp_id}' not found in protocol components." f"'{comp_id}' not found in protocol components."
) )
component = components[comp_id] # TODO: Manipulate pydantic objects instead of dict
component = components_by_id[comp_id].dict()
for key, value in expected_component.items(): for key, value in expected_component.items():
if key not in ["tokens", "static_attributes", "creation_tx"]: if key not in ["tokens", "static_attributes", "creation_tx"]:
continue continue
@@ -121,7 +135,7 @@ class TestRunner:
) )
if isinstance(value, list): if isinstance(value, list):
if set(map(str.lower, value)) != set( if set(map(str.lower, value)) != set(
map(str.lower, component[key]) map(str.lower, component[key])
): ):
return TestResult.Failed( return TestResult.Failed(
f"List mismatch for key '{key}': {value} != {component[key]}" f"List mismatch for key '{key}': {value} != {component[key]}"
@@ -131,11 +145,10 @@ class TestRunner:
f"Value mismatch for key '{key}': {value} != {component[key]}" f"Value mismatch for key '{key}': {value} != {component[key]}"
) )
token_balances: dict[str, dict[str, int]] = defaultdict(dict) token_balances: dict[str, dict[HexBytes, int]] = defaultdict(dict)
for component in protocol_components["protocol_components"]: for component in protocol_components:
comp_id = component["id"].lower() comp_id = component.id.lower()
for token in component["tokens"]: for token in component.tokens:
token_lower = token.lower()
state = next( state = next(
( (
s s
@@ -145,11 +158,11 @@ class TestRunner:
None, None,
) )
if state: if state:
balance_hex = state["balances"].get(token_lower, "0x0") balance_hex = state["balances"].get(token, "0x0")
else: else:
balance_hex = "0x0" balance_hex = "0x0"
tycho_balance = int(balance_hex, 16) tycho_balance = int(balance_hex, 16)
token_balances[comp_id][token_lower] = tycho_balance token_balances[comp_id][token] = tycho_balance
if self.config["skip_balance_check"] is not True: if self.config["skip_balance_check"] is not True:
node_balance = get_token_balance(token, comp_id, stop_block) node_balance = get_token_balance(token, comp_id, stop_block)
@@ -158,17 +171,18 @@ class TestRunner:
f"Balance mismatch for {comp_id}:{token} at block {stop_block}: got {node_balance} " f"Balance mismatch for {comp_id}:{token} at block {stop_block}: got {node_balance} "
f"from rpc call and {tycho_balance} from Substreams" f"from rpc call and {tycho_balance} from Substreams"
) )
contract_states = self.tycho_rpc_client.get_contract_state() contract_states = self.tycho_rpc_client.get_contract_state(
ContractStateParams()
)
filtered_components = { filtered_components = {
"protocol_components": [ "protocol_components": [
pc pc
for pc in protocol_components["protocol_components"] for pc in protocol_components
if pc["id"] if pc.id in [
in [ c["id"].lower()
c["id"].lower() for c in expected_state["protocol_components"]
for c in expected_state["protocol_components"] if c.get("skip_simulation", False) is False
if c.get("skip_simulation", False) is False ]
]
] ]
} }
simulation_failures = self.simulate_get_amount_out( simulation_failures = self.simulate_get_amount_out(
@@ -191,11 +205,11 @@ class TestRunner:
return TestResult.Failed(error_message) return TestResult.Failed(error_message)
def simulate_get_amount_out( def simulate_get_amount_out(
self, self,
block_number: int, block_number: int,
protocol_states: dict, protocol_states: ResponseProtocolState,
protocol_components: dict, protocol_components: list[ProtocolComponent],
contract_state: dict, contract_state: Contract,
) -> dict[str, list[SimulationFailure]]: ) -> dict[str, list[SimulationFailure]]:
protocol_type_names = self.config["protocol_type_names"] protocol_type_names = self.config["protocol_type_names"]
@@ -224,15 +238,22 @@ class TestRunner:
self.config["adapter_build_args"], self.config["adapter_build_args"],
) )
decoder = ThirdPartyPoolTychoDecoder( # decoder = ThirdPartyPoolTychoDecoder(
adapter_contract, 0, trace=self._vm_traces # adapter_contract, 0, trace=self._vm_traces
) # )
stream_adapter = TychoPoolStateStreamAdapter( # stream_adapter = TychoPoolStateStreamAdapter(
# tycho_url="0.0.0.0:4242",
# protocol=protocol,
# decoder=decoder,
# blockchain=self._chain,
# )
stream_adapter = TychoStream(
tycho_url="0.0.0.0:4242", tycho_url="0.0.0.0:4242",
protocol=protocol, exchanges=[protocol],
decoder=decoder, min_tvl=Decimal("0"),
blockchain=self._chain, blockchain=self._chain,
) )
snapshot_message = stream_adapter.build_snapshot_message( snapshot_message = stream_adapter.build_snapshot_message(
protocol_components, protocol_states, contract_state protocol_components, protocol_states, contract_state
) )
@@ -243,11 +264,11 @@ class TestRunner:
if not pool_state.balances: if not pool_state.balances:
raise ValueError(f"Missing balances for pool {pool_id}") raise ValueError(f"Missing balances for pool {pool_id}")
for sell_token, buy_token in itertools.permutations( for sell_token, buy_token in itertools.permutations(
pool_state.tokens, 2 pool_state.tokens, 2
): ):
# Try to sell 0.1% of the protocol balance # Try to sell 0.1% of the protocol balance
sell_amount = ( sell_amount = (
Decimal("0.001") * pool_state.balances[sell_token.address] Decimal("0.001") * pool_state.balances[sell_token.address]
) )
try: try:
amount_out, gas_used, _ = pool_state.get_amount_out( amount_out, gas_used, _ = pool_state.get_amount_out(

View File

@@ -39,51 +39,24 @@ def find_binary_file(file_name):
binary_path = find_binary_file("tycho-indexer") binary_path = find_binary_file("tycho-indexer")
class TychoRPCClient:
def __init__(self, rpc_url: str = "http://0.0.0.0:4242"):
self.rpc_url = rpc_url
def get_protocol_components(self) -> dict:
"""Retrieve protocol components from the RPC server."""
url = self.rpc_url + "/v1/ethereum/protocol_components"
headers = {"accept": "application/json", "Content-Type": "application/json"}
data = {"protocol_system": "test_protocol"}
response = requests.post(url, headers=headers, json=data)
return response.json()
def get_protocol_state(self) -> dict:
"""Retrieve protocol state from the RPC server."""
url = self.rpc_url + "/v1/ethereum/protocol_state"
headers = {"accept": "application/json", "Content-Type": "application/json"}
data = {}
response = requests.post(url, headers=headers, json=data)
return response.json()
def get_contract_state(self) -> dict:
"""Retrieve contract state from the RPC server."""
url = self.rpc_url + "/v1/ethereum/contract_state?include_balances=false"
headers = {"accept": "application/json", "Content-Type": "application/json"}
data = {}
response = requests.post(url, headers=headers, json=data)
return response.json()
class TychoRunner: class TychoRunner:
def __init__(self, db_url: str, with_binary_logs: bool = False, initialized_accounts: list[str] = None): def __init__(
self,
db_url: str,
with_binary_logs: bool = False,
initialized_accounts: list[str] = None,
):
self.with_binary_logs = with_binary_logs self.with_binary_logs = with_binary_logs
self._db_url = db_url self._db_url = db_url
self._initialized_accounts = initialized_accounts or [] self._initialized_accounts = initialized_accounts or []
def run_tycho( def run_tycho(
self, self,
spkg_path: str, spkg_path: str,
start_block: int, start_block: int,
end_block: int, end_block: int,
protocol_type_names: list, protocol_type_names: list,
initialized_accounts: list, initialized_accounts: list,
) -> None: ) -> None:
"""Run the Tycho indexer with the specified SPKG and block range.""" """Run the Tycho indexer with the specified SPKG and block range."""
@@ -112,13 +85,18 @@ class TychoRunner:
str(end_block + 2), str(end_block + 2),
"--initialization-block", "--initialization-block",
str(start_block), str(start_block),
] + (["--initialized-accounts", ",".join(all_accounts)] if all_accounts else []), ]
+ (
["--initialized-accounts", ",".join(all_accounts)]
if all_accounts
else []
),
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
text=True, text=True,
bufsize=1, bufsize=1,
env=env, env=env,
) )
with process.stdout: with process.stdout:
for line in iter(process.stdout.readline, ""): for line in iter(process.stdout.readline, ""):
@@ -151,12 +129,7 @@ class TychoRunner:
env["RUST_LOG"] = "info" env["RUST_LOG"] = "info"
process = subprocess.Popen( process = subprocess.Popen(
[ [binary_path, "--database-url", self._db_url, "rpc"],
binary_path,
"--database-url",
self._db_url,
"rpc"
],
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
text=True, text=True,
@@ -206,7 +179,7 @@ class TychoRunner:
def empty_database(db_url: str) -> None: def empty_database(db_url: str) -> None:
"""Drop and recreate the Tycho indexer database.""" """Drop and recreate the Tycho indexer database."""
try: try:
conn = psycopg2.connect(db_url[:db_url.rfind('/')]) conn = psycopg2.connect(db_url[: db_url.rfind("/")])
conn.autocommit = True conn.autocommit = True
cursor = conn.cursor() cursor = conn.cursor()

View File

@@ -7,6 +7,6 @@ TYCHO_CLIENT_LOG_FOLDER = TYCHO_CLIENT_FOLDER / "logs"
EXTERNAL_ACCOUNT: Final[str] = "0xf847a638E44186F3287ee9F8cAF73FF4d4B80784" EXTERNAL_ACCOUNT: Final[str] = "0xf847a638E44186F3287ee9F8cAF73FF4d4B80784"
"""This is a dummy address used as a transaction sender""" """This is a dummy address used as a transaction sender"""
UINT256_MAX: Final[int] = 2 ** 256 - 1 UINT256_MAX: Final[int] = 2**256 - 1
MAX_BALANCE: Final[int] = UINT256_MAX // 2 MAX_BALANCE: Final[int] = UINT256_MAX // 2
"""0.5 of the maximal possible balance to avoid overflow errors""" """0.5 of the maximal possible balance to avoid overflow errors"""

View File

@@ -26,10 +26,10 @@ class ThirdPartyPoolTychoDecoder:
self.trace = trace self.trace = trace
def decode_snapshot( def decode_snapshot(
self, self,
snapshot: dict[str, Any], snapshot: dict[str, Any],
block: EVMBlock, block: EVMBlock,
tokens: dict[str, EthereumToken], tokens: dict[str, EthereumToken],
) -> tuple[dict[str, ThirdPartyPool], list[str]]: ) -> tuple[dict[str, ThirdPartyPool], list[str]]:
pools = {} pools = {}
failed_pools = [] failed_pools = []
@@ -45,7 +45,7 @@ class ThirdPartyPoolTychoDecoder:
return pools, failed_pools return pools, failed_pools
def decode_pool_state( def decode_pool_state(
self, snap: dict, block: EVMBlock, tokens: dict[str, EthereumToken] self, snap: dict, block: EVMBlock, tokens: dict[str, EthereumToken]
) -> ThirdPartyPool: ) -> ThirdPartyPool:
component = snap["component"] component = snap["component"]
exchange, _ = decode_tycho_exchange(component["protocol_system"]) exchange, _ = decode_tycho_exchange(component["protocol_system"])
@@ -85,20 +85,29 @@ class ThirdPartyPoolTychoDecoder:
while f"stateless_contract_addr_{index}" in static_attributes: while f"stateless_contract_addr_{index}" in static_attributes:
encoded_address = static_attributes[f"stateless_contract_addr_{index}"] encoded_address = static_attributes[f"stateless_contract_addr_{index}"]
decoded = bytes.fromhex( decoded = bytes.fromhex(
encoded_address[2:] if encoded_address.startswith('0x') else encoded_address).decode('utf-8') encoded_address[2:]
if encoded_address.startswith("0x")
else encoded_address
).decode("utf-8")
if decoded.startswith("call"): if decoded.startswith("call"):
address = ThirdPartyPoolTychoDecoder.get_address_from_call(block_number, decoded) address = ThirdPartyPoolTychoDecoder.get_address_from_call(
block_number, decoded
)
else: else:
address = decoded address = decoded
code = static_attributes.get(f"stateless_contract_code_{index}") or get_code_for_address(address) code = static_attributes.get(
f"stateless_contract_code_{index}"
) or get_code_for_address(address)
stateless_contracts[address] = code stateless_contracts[address] = code
index += 1 index += 1
index = 0 index = 0
while f"stateless_contract_addr_{index}" in attributes: while f"stateless_contract_addr_{index}" in attributes:
address = attributes[f"stateless_contract_addr_{index}"] address = attributes[f"stateless_contract_addr_{index}"]
code = attributes.get(f"stateless_contract_code_{index}") or get_code_for_address(address) code = attributes.get(
f"stateless_contract_code_{index}"
) or get_code_for_address(address)
stateless_contracts[address] = code stateless_contracts[address] = code
index += 1 index += 1
return { return {
@@ -118,15 +127,17 @@ class ThirdPartyPoolTychoDecoder:
permanent_storage=None, permanent_storage=None,
) )
selector = keccak(text=decoded.split(":")[-1])[:4] selector = keccak(text=decoded.split(":")[-1])[:4]
sim_result = engine.run_sim(SimulationParameters( sim_result = engine.run_sim(
data=bytearray(selector), SimulationParameters(
to=decoded.split(':')[1], data=bytearray(selector),
block_number=block_number, to=decoded.split(":")[1],
timestamp=int(time.time()), block_number=block_number,
overrides={}, timestamp=int(time.time()),
caller=EXTERNAL_ACCOUNT, overrides={},
value=0, caller=EXTERNAL_ACCOUNT,
)) value=0,
)
)
address = eth_abi.decode(["address"], bytearray(sim_result.result)) address = eth_abi.decode(["address"], bytearray(sim_result.result))
return address[0] return address[0]
@@ -143,10 +154,10 @@ class ThirdPartyPoolTychoDecoder:
@staticmethod @staticmethod
def apply_update( def apply_update(
pool: ThirdPartyPool, pool: ThirdPartyPool,
pool_update: dict[str, Any], pool_update: dict[str, Any],
balance_updates: dict[str, Any], balance_updates: dict[str, Any],
block: EVMBlock, block: EVMBlock,
) -> ThirdPartyPool: ) -> ThirdPartyPool:
# check for and apply optional state attributes # check for and apply optional state attributes
attributes = pool_update.get("updated_attributes") attributes = pool_update.get("updated_attributes")

View File

@@ -41,7 +41,7 @@ class EthereumToken(BaseModel):
log.warning(f"Expected variable of type Decimal. Got {type(amount)}.") log.warning(f"Expected variable of type Decimal. Got {type(amount)}.")
with localcontext(Context(rounding=ROUND_FLOOR, prec=256)): with localcontext(Context(rounding=ROUND_FLOOR, prec=256)):
amount = Decimal(str(amount)) * (10 ** self.decimals) amount = Decimal(str(amount)) * (10**self.decimals)
try: try:
amount = amount.quantize(Decimal("1.0")) amount = amount.quantize(Decimal("1.0"))
except InvalidOperation: except InvalidOperation:
@@ -51,7 +51,7 @@ class EthereumToken(BaseModel):
return int(amount) return int(amount)
def from_onchain_amount( def from_onchain_amount(
self, onchain_amount: Union[int, Fraction], quantize: bool = True self, onchain_amount: Union[int, Fraction], quantize: bool = True
) -> Decimal: ) -> Decimal:
"""Converts an Integer to a quantized decimal, by shifting left by the token's """Converts an Integer to a quantized decimal, by shifting left by the token's
maximum amount of decimals (e.g.: 1000000 becomes 1.000000 for a 6-decimal token maximum amount of decimals (e.g.: 1000000 becomes 1.000000 for a 6-decimal token
@@ -66,19 +66,19 @@ class EthereumToken(BaseModel):
with localcontext(Context(rounding=ROUND_FLOOR, prec=256)): with localcontext(Context(rounding=ROUND_FLOOR, prec=256)):
if isinstance(onchain_amount, Fraction): if isinstance(onchain_amount, Fraction):
return ( return (
Decimal(onchain_amount.numerator) Decimal(onchain_amount.numerator)
/ Decimal(onchain_amount.denominator) / Decimal(onchain_amount.denominator)
/ Decimal(10 ** self.decimals) / Decimal(10**self.decimals)
).quantize(Decimal(f"{1 / 10 ** self.decimals}")) ).quantize(Decimal(f"{1 / 10 ** self.decimals}"))
if quantize is True: if quantize is True:
try: try:
amount = ( amount = (
Decimal(str(onchain_amount)) / 10 ** self.decimals Decimal(str(onchain_amount)) / 10**self.decimals
).quantize(Decimal(f"{1 / 10 ** self.decimals}")) ).quantize(Decimal(f"{1 / 10 ** self.decimals}"))
except InvalidOperation: except InvalidOperation:
amount = Decimal(str(onchain_amount)) / Decimal(10 ** self.decimals) amount = Decimal(str(onchain_amount)) / Decimal(10**self.decimals)
else: else:
amount = Decimal(str(onchain_amount)) / Decimal(10 ** self.decimals) amount = Decimal(str(onchain_amount)) / Decimal(10**self.decimals)
return amount return amount
def __repr__(self): def __repr__(self):

View File

@@ -142,7 +142,7 @@ class ThirdPartyPool(BaseModel):
if Capability.ScaledPrice in self.capabilities: if Capability.ScaledPrice in self.capabilities:
self.spot_prices[(t0, t1)] = frac_to_decimal(frac) self.spot_prices[(t0, t1)] = frac_to_decimal(frac)
else: else:
scaled = frac * Fraction(10 ** t0.decimals, 10 ** t1.decimals) scaled = frac * Fraction(10**t0.decimals, 10**t1.decimals)
self.spot_prices[(t0, t1)] = frac_to_decimal(scaled) self.spot_prices[(t0, t1)] = frac_to_decimal(scaled)
def _ensure_capability(self, capability: Capability): def _ensure_capability(self, capability: Capability):
@@ -165,10 +165,10 @@ class ThirdPartyPool(BaseModel):
) )
def get_amount_out( def get_amount_out(
self: TPoolState, self: TPoolState,
sell_token: EthereumToken, sell_token: EthereumToken,
sell_amount: Decimal, sell_amount: Decimal,
buy_token: EthereumToken, buy_token: EthereumToken,
) -> tuple[Decimal, int, TPoolState]: ) -> tuple[Decimal, int, TPoolState]:
# if the pool has a hard limit and the sell amount exceeds that, simulate and # if the pool has a hard limit and the sell amount exceeds that, simulate and
# raise a partial trade # raise a partial trade
@@ -185,10 +185,10 @@ class ThirdPartyPool(BaseModel):
return self._get_amount_out(sell_token, sell_amount, buy_token) return self._get_amount_out(sell_token, sell_amount, buy_token)
def _get_amount_out( def _get_amount_out(
self: TPoolState, self: TPoolState,
sell_token: EthereumToken, sell_token: EthereumToken,
sell_amount: Decimal, sell_amount: Decimal,
buy_token: EthereumToken, buy_token: EthereumToken,
) -> tuple[Decimal, int, TPoolState]: ) -> tuple[Decimal, int, TPoolState]:
trade, state_changes = self._adapter_contract.swap( trade, state_changes = self._adapter_contract.swap(
cast(HexStr, self.id_), cast(HexStr, self.id_),
@@ -216,7 +216,7 @@ class ThirdPartyPool(BaseModel):
return buy_amount, trade.gas_used, new_state return buy_amount, trade.gas_used, new_state
def _get_overwrites( def _get_overwrites(
self, sell_token: EthereumToken, buy_token: EthereumToken, **kwargs self, sell_token: EthereumToken, buy_token: EthereumToken, **kwargs
) -> dict[Address, dict[int, int]]: ) -> dict[Address, dict[int, int]]:
"""Get an overwrites dictionary to use in a simulation. """Get an overwrites dictionary to use in a simulation.
@@ -227,7 +227,7 @@ class ThirdPartyPool(BaseModel):
return _merge(self.block_lasting_overwrites, token_overwrites) return _merge(self.block_lasting_overwrites, token_overwrites)
def _get_token_overwrites( def _get_token_overwrites(
self, sell_token: EthereumToken, buy_token: EthereumToken, max_amount=None self, sell_token: EthereumToken, buy_token: EthereumToken, max_amount=None
) -> dict[Address, dict[int, int]]: ) -> dict[Address, dict[int, int]]:
"""Creates overwrites for a token. """Creates overwrites for a token.
@@ -295,7 +295,7 @@ class ThirdPartyPool(BaseModel):
) )
def get_sell_amount_limit( def get_sell_amount_limit(
self, sell_token: EthereumToken, buy_token: EthereumToken self, sell_token: EthereumToken, buy_token: EthereumToken
) -> Decimal: ) -> Decimal:
""" """
Retrieves the sell amount of the given token. Retrieves the sell amount of the given token.

View File

@@ -26,10 +26,10 @@ log = getLogger(__name__)
class TokenLoader: class TokenLoader:
def __init__( def __init__(
self, self,
tycho_url: str, tycho_url: str,
blockchain: Blockchain, blockchain: Blockchain,
min_token_quality: Optional[int] = 0, min_token_quality: Optional[int] = 0,
): ):
self.tycho_url = tycho_url self.tycho_url = tycho_url
self.blockchain = blockchain self.blockchain = blockchain
@@ -45,10 +45,10 @@ class TokenLoader:
start = time.monotonic() start = time.monotonic()
all_tokens = [] all_tokens = []
while data := self._get_all_with_pagination( while data := self._get_all_with_pagination(
url=url, url=url,
page=page, page=page,
limit=self._token_limit, limit=self._token_limit,
params={"min_quality": self.min_token_quality}, params={"min_quality": self.min_token_quality},
): ):
all_tokens.extend(data) all_tokens.extend(data)
page += 1 page += 1
@@ -73,10 +73,10 @@ class TokenLoader:
start = time.monotonic() start = time.monotonic()
all_tokens = [] all_tokens = []
while data := self._get_all_with_pagination( while data := self._get_all_with_pagination(
url=url, url=url,
page=page, page=page,
limit=self._token_limit, limit=self._token_limit,
params={"min_quality": self.min_token_quality, "addresses": addresses}, params={"min_quality": self.min_token_quality, "addresses": addresses},
): ):
all_tokens.extend(data) all_tokens.extend(data)
page += 1 page += 1
@@ -95,7 +95,7 @@ class TokenLoader:
@staticmethod @staticmethod
def _get_all_with_pagination( def _get_all_with_pagination(
url: str, params: Optional[Dict] = None, page: int = 0, limit: int = 50 url: str, params: Optional[Dict] = None, page: int = 0, limit: int = 50
) -> Dict: ) -> Dict:
if params is None: if params is None:
params = {} params = {}
@@ -122,14 +122,14 @@ class BlockProtocolChanges:
class TychoPoolStateStreamAdapter: class TychoPoolStateStreamAdapter:
def __init__( def __init__(
self, self,
tycho_url: str, tycho_url: str,
protocol: str, protocol: str,
decoder: ThirdPartyPoolTychoDecoder, decoder: ThirdPartyPoolTychoDecoder,
blockchain: Blockchain, blockchain: Blockchain,
min_tvl: Optional[Decimal] = 10, min_tvl: Optional[Decimal] = 10,
min_token_quality: Optional[int] = 0, min_token_quality: Optional[int] = 0,
include_state=True, include_state=True,
): ):
""" """
:param tycho_url: URL to connect to Tycho DB :param tycho_url: URL to connect to Tycho DB
@@ -181,7 +181,7 @@ class TychoPoolStateStreamAdapter:
log.debug(f"Starting tycho-client binary at {bin_path}. CMD: {cmd}") log.debug(f"Starting tycho-client binary at {bin_path}. CMD: {cmd}")
self.tycho_client = await asyncio.create_subprocess_exec( self.tycho_client = await asyncio.create_subprocess_exec(
str(bin_path), *cmd, stdout=PIPE, stderr=STDOUT, limit=2 ** 64 str(bin_path), *cmd, stdout=PIPE, stderr=STDOUT, limit=2**64
) )
@staticmethod @staticmethod
@@ -238,7 +238,7 @@ class TychoPoolStateStreamAdapter:
@staticmethod @staticmethod
def build_snapshot_message( def build_snapshot_message(
protocol_components: dict, protocol_states: dict, contract_states: dict protocol_components: dict, protocol_states: dict, contract_states: dict
) -> dict[str, ThirdPartyPool]: ) -> dict[str, ThirdPartyPool]:
vm_states = {state["address"]: state for state in contract_states["accounts"]} vm_states = {state["address"]: state for state in contract_states["accounts"]}
states = {} states = {}
@@ -269,7 +269,7 @@ class TychoPoolStateStreamAdapter:
return self.process_snapshot(block, state_msg["snapshot"]) return self.process_snapshot(block, state_msg["snapshot"])
def process_snapshot( def process_snapshot(
self, block: EVMBlock, state_msg: dict self, block: EVMBlock, state_msg: dict
) -> BlockProtocolChanges: ) -> BlockProtocolChanges:
start = time.monotonic() start = time.monotonic()
removed_pools = set() removed_pools = set()

View File

@@ -121,8 +121,7 @@ class ERC20OverwriteFactory:
""" """
self._overwrites[self._total_supply_slot] = supply self._overwrites[self._total_supply_slot] = supply
log.log( log.log(
5, 5, f"Override total supply: token={self._token.address} supply={supply}"
f"Override total supply: token={self._token.address} supply={supply}"
) )
def get_protosim_overwrites(self) -> dict[Address, dict[int, int]]: def get_protosim_overwrites(self) -> dict[Address, dict[int, int]]:
@@ -352,4 +351,4 @@ def get_code_for_address(address: str, connection_string: str = None):
return bytes.fromhex(code[2:]) return bytes.fromhex(code[2:])
except RuntimeError as e: except RuntimeError as e:
print(f"Error fetching code for address {address}: {e}") print(f"Error fetching code for address {address}: {e}")
return None return None