feat: update test suite to support DCI enabled protocols (#225)

* feat: update tycho-client and support testing DCI enabled protocols

* refactor: improve python Typing

* test: update test suite tycho-simulation dependency

The updated version includes an account_balances fix.

* refactor: use sets instead of lists

* feat: update tycho-client test dependency
This commit is contained in:
Louise Poole
2025-06-30 11:44:33 +02:00
committed by GitHub
parent 08e5794de0
commit ef6c826a8a
3 changed files with 51 additions and 34 deletions

View File

@@ -2,5 +2,5 @@ psycopg2==2.9.9
PyYAML==6.0.1 PyYAML==6.0.1
Requests==2.32.2 Requests==2.32.2
web3==5.31.3 web3==5.31.3
git+https://github.com/propeller-heads/tycho-indexer.git@0.60.0#subdirectory=tycho-client-py git+https://github.com/propeller-heads/tycho-indexer.git@0.74.0#subdirectory=tycho-client-py
git+https://github.com/propeller-heads/tycho-simulation.git@0.105.0#subdirectory=tycho_simulation_py git+https://github.com/propeller-heads/tycho-simulation.git@0.118.0#subdirectory=tycho_simulation_py

View File

@@ -3,11 +3,10 @@ import os
import shutil import shutil
import subprocess import subprocess
import traceback import traceback
from collections import defaultdict
from datetime import datetime from datetime import datetime
from decimal import Decimal from decimal import Decimal
from pathlib import Path from pathlib import Path
from typing import List from typing import Optional, Callable, Any
import yaml import yaml
from tycho_simulation_py.evm.decoders import ThirdPartyPoolTychoDecoder from tycho_simulation_py.evm.decoders import ThirdPartyPoolTychoDecoder
@@ -24,6 +23,7 @@ from tycho_indexer_client.dto import (
HexBytes, HexBytes,
ResponseAccount, ResponseAccount,
Snapshot, Snapshot,
TracedEntryPointParams,
) )
from tycho_indexer_client.rpc_client import TychoRPCClient from tycho_indexer_client.rpc_client import TychoRPCClient
@@ -39,7 +39,9 @@ from utils import build_snapshot_message, token_factory
class TestResult: class TestResult:
def __init__(self, success: bool, step: str = None, message: str = None): def __init__(
self, success: bool, step: Optional[str] = None, message: Optional[str] = None
):
self.success = success self.success = success
self.step = step self.step = step
self.message = message self.message = message
@@ -93,7 +95,7 @@ class TestRunner:
print(f"Running {len(self.config.tests)} tests ...\n") print(f"Running {len(self.config.tests)} tests ...\n")
print("--------------------------------\n") print("--------------------------------\n")
failed_tests = [] failed_tests: list[str] = []
count = 1 count = 1
for test in self.config.tests: for test in self.config.tests:
@@ -141,18 +143,18 @@ class TestRunner:
def validate_state( def validate_state(
self, self,
expected_components: List[ProtocolComponentWithTestConfig], expected_components: list[ProtocolComponentWithTestConfig],
stop_block: int, stop_block: int,
initialized_accounts: List[str], initialized_accounts: list[str],
) -> TestResult: ) -> TestResult:
"""Validate the current protocol state against the expected state.""" """Validate the current protocol state against the expected state."""
protocol_components = self.tycho_rpc_client.get_protocol_components( protocol_components = self.tycho_rpc_client.get_protocol_components(
ProtocolComponentsParams(protocol_system="test_protocol") ProtocolComponentsParams(protocol_system="test_protocol")
) ).protocol_components
protocol_states = self.tycho_rpc_client.get_protocol_state( protocol_states = self.tycho_rpc_client.get_protocol_state(
ProtocolStateParams(protocol_system="test_protocol") ProtocolStateParams(protocol_system="test_protocol")
) ).states
components_by_id = { components_by_id: dict[str, ProtocolComponent] = {
component.id: component for component in protocol_components component.id: component for component in protocol_components
} }
@@ -214,34 +216,49 @@ class TestRunner:
step = "Simulation validation" step = "Simulation validation"
# Loads from Tycho-Indexer the state of all the contracts that are related to the protocol components. # Loads from Tycho-Indexer the state of all the contracts that are related to the protocol components.
filtered_components = [] simulation_components: list[str] = [
simulation_components = [
c.id for c in expected_components if c.skip_simulation is False c.id for c in expected_components if c.skip_simulation is False
] ]
related_contracts = set() related_contracts: set[str] = set()
for account in self.config.initialized_accounts or []: for account in self.config.initialized_accounts or []:
related_contracts.add(HexBytes(account)) related_contracts.add(account)
for account in initialized_accounts or []: for account in initialized_accounts or []:
related_contracts.add(HexBytes(account)) related_contracts.add(account)
# Filter out components that are not set to be used for the simulation # Collect all contracts that are related to the simulation components
component_related_contracts = set() filtered_components: list[ProtocolComponent] = []
component_related_contracts: set[str] = set()
for component in protocol_components: for component in protocol_components:
# Filter out components that are not set to be used for the simulation
if component.id in simulation_components: if component.id in simulation_components:
# Collect component contracts
for a in component.contract_ids: for a in component.contract_ids:
component_related_contracts.add(a) component_related_contracts.add(a.hex())
# Collect DCI detected contracts
traces_results = self.tycho_rpc_client.get_traced_entry_points(
TracedEntryPointParams(
protocol_system="test_protocol",
component_ids=[component.id],
)
).traced_entry_points.values()
for traces in traces_results:
for _, trace in traces:
component_related_contracts.update(
trace["accessed_slots"].keys()
)
filtered_components.append(component) filtered_components.append(component)
# Check if any of the initialized contracts are not listed as component contract dependencies # Check if any of the initialized contracts are not listed as component contract dependencies
unspecified_contracts = related_contracts - component_related_contracts unspecified_contracts: list[str] = [
c for c in related_contracts if c not in component_related_contracts
]
related_contracts.update(component_related_contracts) related_contracts.update(component_related_contracts)
related_contracts = [a.hex() for a in related_contracts]
contract_states = self.tycho_rpc_client.get_contract_state( contract_states = self.tycho_rpc_client.get_contract_state(
ContractStateParams(contract_ids=related_contracts) ContractStateParams(contract_ids=list(related_contracts))
) ).accounts
if len(filtered_components): if len(filtered_components):
if len(unspecified_contracts): if len(unspecified_contracts):
@@ -254,16 +271,16 @@ class TestRunner:
stop_block, protocol_states, filtered_components, contract_states stop_block, protocol_states, filtered_components, contract_states
) )
if len(simulation_failures): if len(simulation_failures):
error_msgs = [] error_msgs: list[str] = []
for pool_id, failures in simulation_failures.items(): for pool_id, failures in simulation_failures.items():
failures_ = [ failures_formatted: list[str] = [
f"{f.sell_token} -> {f.buy_token}: {f.error}" f"{f.sell_token} -> {f.buy_token}: {f.error}"
for f in failures for f in failures
] ]
error_msgs.append( error_msgs.append(
f"Pool {pool_id} failed simulations: {', '.join(failures_)}" f"Pool {pool_id} failed simulations: {', '.join(failures_formatted)}"
) )
return TestResult.Failed(step=step, message="/n".join(error_msgs)) return TestResult.Failed(step=step, message="\n".join(error_msgs))
print(f"\n{step} passed.\n") print(f"\n{step} passed.\n")
else: else:
print(f"\n {step} skipped.\n") print(f"\n {step} skipped.\n")
@@ -280,7 +297,6 @@ class TestRunner:
contract_states: list[ResponseAccount], contract_states: list[ResponseAccount],
) -> dict[str, list[SimulationFailure]]: ) -> dict[str, list[SimulationFailure]]:
TychoDBSingleton.initialize() TychoDBSingleton.initialize()
protocol_type_names = self.config.protocol_type_names
block_header = get_block_header(block_number) block_header = get_block_header(block_number)
block: EVMBlock = EVMBlock( block: EVMBlock = EVMBlock(
@@ -289,7 +305,7 @@ class TestRunner:
hash_=block_header.hash.hex(), hash_=block_header.hash.hex(),
) )
failed_simulations: dict[str, list[SimulationFailure]] = dict() failed_simulations: dict[str, list[SimulationFailure]] = {}
try: try:
adapter_contract = self.adapter_contract_builder.find_contract( adapter_contract = self.adapter_contract_builder.find_contract(
@@ -365,7 +381,9 @@ class TestRunner:
return failed_simulations return failed_simulations
@staticmethod @staticmethod
def build_spkg(yaml_file_path: str, modify_func: callable) -> str: def build_spkg(
yaml_file_path: str, modify_func: Callable[[dict[str, Any]], None]
) -> str:
"""Build a Substreams package with modifications to the YAML file.""" """Build a Substreams package with modifications to the YAML file."""
backup_file_path = f"{yaml_file_path}.backup" backup_file_path = f"{yaml_file_path}.backup"
shutil.copy(yaml_file_path, backup_file_path) shutil.copy(yaml_file_path, backup_file_path)
@@ -394,7 +412,7 @@ class TestRunner:
return spkg_name return spkg_name
@staticmethod @staticmethod
def update_initial_block(data: dict, start_block: int) -> None: def update_initial_block(data: dict[str, Any], start_block: int) -> None:
"""Update the initial block for all modules in the configuration data.""" """Update the initial block for all modules in the configuration data."""
for module in data["modules"]: for module in data["modules"]:
module["initialBlock"] = start_block module["initialBlock"] = start_block

View File

@@ -32,7 +32,6 @@ def build_snapshot_message(
for state in protocol_states: for state in protocol_states:
pool_id = state.component_id pool_id = state.component_id
if pool_id not in states: if pool_id not in states:
log.warning(f"State for pool {pool_id} not found in components")
continue continue
states[pool_id]["state"] = state states[pool_id]["state"] = state
@@ -62,7 +61,7 @@ def token_factory(rpc_client: TychoRPCClient) -> callable(HexBytes):
if to_fetch: if to_fetch:
pagination = PaginationParams(page_size=len(to_fetch), page=0) pagination = PaginationParams(page_size=len(to_fetch), page=0)
params = TokensParams(token_addresses=to_fetch, pagination=pagination) params = TokensParams(token_addresses=to_fetch, pagination=pagination)
tokens = _client.get_tokens(params) tokens = _client.get_tokens(params).tokens
for token in tokens: for token in tokens:
address = to_checksum_address(token.address) address = to_checksum_address(token.address)
eth_token = EthereumToken( eth_token = EthereumToken(