diff --git a/testing/requirements.txt b/testing/requirements.txt index 685f6d3..b57b8ed 100644 --- a/testing/requirements.txt +++ b/testing/requirements.txt @@ -2,5 +2,5 @@ psycopg2==2.9.9 PyYAML==6.0.1 Requests==2.32.2 web3==5.31.3 -git+https://github.com/propeller-heads/tycho-indexer.git@0.60.0#subdirectory=tycho-client-py -git+https://github.com/propeller-heads/tycho-simulation.git@0.105.0#subdirectory=tycho_simulation_py \ No newline at end of file +git+https://github.com/propeller-heads/tycho-indexer.git@0.74.0#subdirectory=tycho-client-py +git+https://github.com/propeller-heads/tycho-simulation.git@0.118.0#subdirectory=tycho_simulation_py \ No newline at end of file diff --git a/testing/src/runner/runner.py b/testing/src/runner/runner.py index 123bae6..3440d63 100644 --- a/testing/src/runner/runner.py +++ b/testing/src/runner/runner.py @@ -3,11 +3,10 @@ import os import shutil import subprocess import traceback -from collections import defaultdict from datetime import datetime from decimal import Decimal from pathlib import Path -from typing import List +from typing import Optional, Callable, Any import yaml from tycho_simulation_py.evm.decoders import ThirdPartyPoolTychoDecoder @@ -24,6 +23,7 @@ from tycho_indexer_client.dto import ( HexBytes, ResponseAccount, Snapshot, + TracedEntryPointParams, ) from tycho_indexer_client.rpc_client import TychoRPCClient @@ -39,7 +39,9 @@ from utils import build_snapshot_message, token_factory class TestResult: - def __init__(self, success: bool, step: str = None, message: str = None): + def __init__( + self, success: bool, step: Optional[str] = None, message: Optional[str] = None + ): self.success = success self.step = step self.message = message @@ -93,7 +95,7 @@ class TestRunner: print(f"Running {len(self.config.tests)} tests ...\n") print("--------------------------------\n") - failed_tests = [] + failed_tests: list[str] = [] count = 1 for test in self.config.tests: @@ -141,18 +143,18 @@ class TestRunner: def validate_state( self, - expected_components: List[ProtocolComponentWithTestConfig], + expected_components: list[ProtocolComponentWithTestConfig], stop_block: int, - initialized_accounts: List[str], + initialized_accounts: list[str], ) -> TestResult: """Validate the current protocol state against the expected state.""" protocol_components = self.tycho_rpc_client.get_protocol_components( ProtocolComponentsParams(protocol_system="test_protocol") - ) + ).protocol_components protocol_states = self.tycho_rpc_client.get_protocol_state( ProtocolStateParams(protocol_system="test_protocol") - ) - components_by_id = { + ).states + components_by_id: dict[str, ProtocolComponent] = { component.id: component for component in protocol_components } @@ -214,34 +216,49 @@ class TestRunner: step = "Simulation validation" # Loads from Tycho-Indexer the state of all the contracts that are related to the protocol components. - filtered_components = [] - simulation_components = [ + simulation_components: list[str] = [ c.id for c in expected_components if c.skip_simulation is False ] - related_contracts = set() + related_contracts: set[str] = set() for account in self.config.initialized_accounts or []: - related_contracts.add(HexBytes(account)) + related_contracts.add(account) for account in initialized_accounts or []: - related_contracts.add(HexBytes(account)) + related_contracts.add(account) - # Filter out components that are not set to be used for the simulation - component_related_contracts = set() + # Collect all contracts that are related to the simulation components + filtered_components: list[ProtocolComponent] = [] + component_related_contracts: set[str] = set() for component in protocol_components: + # Filter out components that are not set to be used for the simulation if component.id in simulation_components: + # Collect component contracts for a in component.contract_ids: - component_related_contracts.add(a) + component_related_contracts.add(a.hex()) + # Collect DCI detected contracts + traces_results = self.tycho_rpc_client.get_traced_entry_points( + TracedEntryPointParams( + protocol_system="test_protocol", + component_ids=[component.id], + ) + ).traced_entry_points.values() + for traces in traces_results: + for _, trace in traces: + component_related_contracts.update( + trace["accessed_slots"].keys() + ) filtered_components.append(component) # Check if any of the initialized contracts are not listed as component contract dependencies - unspecified_contracts = related_contracts - component_related_contracts + unspecified_contracts: list[str] = [ + c for c in related_contracts if c not in component_related_contracts + ] related_contracts.update(component_related_contracts) - related_contracts = [a.hex() for a in related_contracts] contract_states = self.tycho_rpc_client.get_contract_state( - ContractStateParams(contract_ids=related_contracts) - ) + ContractStateParams(contract_ids=list(related_contracts)) + ).accounts if len(filtered_components): if len(unspecified_contracts): @@ -254,16 +271,16 @@ class TestRunner: stop_block, protocol_states, filtered_components, contract_states ) if len(simulation_failures): - error_msgs = [] + error_msgs: list[str] = [] for pool_id, failures in simulation_failures.items(): - failures_ = [ + failures_formatted: list[str] = [ f"{f.sell_token} -> {f.buy_token}: {f.error}" for f in failures ] error_msgs.append( - f"Pool {pool_id} failed simulations: {', '.join(failures_)}" + f"Pool {pool_id} failed simulations: {', '.join(failures_formatted)}" ) - return TestResult.Failed(step=step, message="/n".join(error_msgs)) + return TestResult.Failed(step=step, message="\n".join(error_msgs)) print(f"\n✅ {step} passed.\n") else: print(f"\nℹ️ {step} skipped.\n") @@ -280,7 +297,6 @@ class TestRunner: contract_states: list[ResponseAccount], ) -> dict[str, list[SimulationFailure]]: TychoDBSingleton.initialize() - protocol_type_names = self.config.protocol_type_names block_header = get_block_header(block_number) block: EVMBlock = EVMBlock( @@ -289,7 +305,7 @@ class TestRunner: hash_=block_header.hash.hex(), ) - failed_simulations: dict[str, list[SimulationFailure]] = dict() + failed_simulations: dict[str, list[SimulationFailure]] = {} try: adapter_contract = self.adapter_contract_builder.find_contract( @@ -337,7 +353,7 @@ class TestRunner: # Try to sell 0.1% of the protocol balance try: sell_amount = ( - Decimal(prctg) * pool_state.balances[sell_token.address] + Decimal(prctg) * pool_state.balances[sell_token.address] ) amount_out, gas_used, _ = pool_state.get_amount_out( sell_token, sell_amount, buy_token @@ -365,7 +381,9 @@ class TestRunner: return failed_simulations @staticmethod - def build_spkg(yaml_file_path: str, modify_func: callable) -> str: + def build_spkg( + yaml_file_path: str, modify_func: Callable[[dict[str, Any]], None] + ) -> str: """Build a Substreams package with modifications to the YAML file.""" backup_file_path = f"{yaml_file_path}.backup" shutil.copy(yaml_file_path, backup_file_path) @@ -394,7 +412,7 @@ class TestRunner: return spkg_name @staticmethod - def update_initial_block(data: dict, start_block: int) -> None: + def update_initial_block(data: dict[str, Any], start_block: int) -> None: """Update the initial block for all modules in the configuration data.""" for module in data["modules"]: module["initialBlock"] = start_block diff --git a/testing/src/runner/utils.py b/testing/src/runner/utils.py index b4e09ff..6408226 100644 --- a/testing/src/runner/utils.py +++ b/testing/src/runner/utils.py @@ -32,7 +32,6 @@ def build_snapshot_message( for state in protocol_states: pool_id = state.component_id if pool_id not in states: - log.warning(f"State for pool {pool_id} not found in components") continue states[pool_id]["state"] = state @@ -62,7 +61,7 @@ def token_factory(rpc_client: TychoRPCClient) -> callable(HexBytes): if to_fetch: pagination = PaginationParams(page_size=len(to_fetch), page=0) params = TokensParams(token_addresses=to_fetch, pagination=pagination) - tokens = _client.get_tokens(params) + tokens = _client.get_tokens(params).tokens for token in tokens: address = to_checksum_address(token.address) eth_token = EthereumToken(