feat: update test suite to support DCI enabled protocols (#225)

* feat: update tycho-client and support testing DCI enabled protocols

* refactor: improve python Typing

* test: update test suite tycho-simulation dependency

The updated version includes an account_balances fix.

* refactor: use sets instead of lists

* feat: update tycho-client test dependency
This commit is contained in:
Louise Poole
2025-06-30 11:44:33 +02:00
committed by GitHub
parent 08e5794de0
commit ef6c826a8a
3 changed files with 51 additions and 34 deletions

View File

@@ -2,5 +2,5 @@ psycopg2==2.9.9
PyYAML==6.0.1
Requests==2.32.2
web3==5.31.3
git+https://github.com/propeller-heads/tycho-indexer.git@0.60.0#subdirectory=tycho-client-py
git+https://github.com/propeller-heads/tycho-simulation.git@0.105.0#subdirectory=tycho_simulation_py
git+https://github.com/propeller-heads/tycho-indexer.git@0.74.0#subdirectory=tycho-client-py
git+https://github.com/propeller-heads/tycho-simulation.git@0.118.0#subdirectory=tycho_simulation_py

View File

@@ -3,11 +3,10 @@ import os
import shutil
import subprocess
import traceback
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from pathlib import Path
from typing import List
from typing import Optional, Callable, Any
import yaml
from tycho_simulation_py.evm.decoders import ThirdPartyPoolTychoDecoder
@@ -24,6 +23,7 @@ from tycho_indexer_client.dto import (
HexBytes,
ResponseAccount,
Snapshot,
TracedEntryPointParams,
)
from tycho_indexer_client.rpc_client import TychoRPCClient
@@ -39,7 +39,9 @@ from utils import build_snapshot_message, token_factory
class TestResult:
def __init__(self, success: bool, step: str = None, message: str = None):
def __init__(
self, success: bool, step: Optional[str] = None, message: Optional[str] = None
):
self.success = success
self.step = step
self.message = message
@@ -93,7 +95,7 @@ class TestRunner:
print(f"Running {len(self.config.tests)} tests ...\n")
print("--------------------------------\n")
failed_tests = []
failed_tests: list[str] = []
count = 1
for test in self.config.tests:
@@ -141,18 +143,18 @@ class TestRunner:
def validate_state(
self,
expected_components: List[ProtocolComponentWithTestConfig],
expected_components: list[ProtocolComponentWithTestConfig],
stop_block: int,
initialized_accounts: List[str],
initialized_accounts: list[str],
) -> TestResult:
"""Validate the current protocol state against the expected state."""
protocol_components = self.tycho_rpc_client.get_protocol_components(
ProtocolComponentsParams(protocol_system="test_protocol")
)
).protocol_components
protocol_states = self.tycho_rpc_client.get_protocol_state(
ProtocolStateParams(protocol_system="test_protocol")
)
components_by_id = {
).states
components_by_id: dict[str, ProtocolComponent] = {
component.id: component for component in protocol_components
}
@@ -214,34 +216,49 @@ class TestRunner:
step = "Simulation validation"
# Loads from Tycho-Indexer the state of all the contracts that are related to the protocol components.
filtered_components = []
simulation_components = [
simulation_components: list[str] = [
c.id for c in expected_components if c.skip_simulation is False
]
related_contracts = set()
related_contracts: set[str] = set()
for account in self.config.initialized_accounts or []:
related_contracts.add(HexBytes(account))
related_contracts.add(account)
for account in initialized_accounts or []:
related_contracts.add(HexBytes(account))
related_contracts.add(account)
# Filter out components that are not set to be used for the simulation
component_related_contracts = set()
# Collect all contracts that are related to the simulation components
filtered_components: list[ProtocolComponent] = []
component_related_contracts: set[str] = set()
for component in protocol_components:
# Filter out components that are not set to be used for the simulation
if component.id in simulation_components:
# Collect component contracts
for a in component.contract_ids:
component_related_contracts.add(a)
component_related_contracts.add(a.hex())
# Collect DCI detected contracts
traces_results = self.tycho_rpc_client.get_traced_entry_points(
TracedEntryPointParams(
protocol_system="test_protocol",
component_ids=[component.id],
)
).traced_entry_points.values()
for traces in traces_results:
for _, trace in traces:
component_related_contracts.update(
trace["accessed_slots"].keys()
)
filtered_components.append(component)
# Check if any of the initialized contracts are not listed as component contract dependencies
unspecified_contracts = related_contracts - component_related_contracts
unspecified_contracts: list[str] = [
c for c in related_contracts if c not in component_related_contracts
]
related_contracts.update(component_related_contracts)
related_contracts = [a.hex() for a in related_contracts]
contract_states = self.tycho_rpc_client.get_contract_state(
ContractStateParams(contract_ids=related_contracts)
)
ContractStateParams(contract_ids=list(related_contracts))
).accounts
if len(filtered_components):
if len(unspecified_contracts):
@@ -254,16 +271,16 @@ class TestRunner:
stop_block, protocol_states, filtered_components, contract_states
)
if len(simulation_failures):
error_msgs = []
error_msgs: list[str] = []
for pool_id, failures in simulation_failures.items():
failures_ = [
failures_formatted: list[str] = [
f"{f.sell_token} -> {f.buy_token}: {f.error}"
for f in failures
]
error_msgs.append(
f"Pool {pool_id} failed simulations: {', '.join(failures_)}"
f"Pool {pool_id} failed simulations: {', '.join(failures_formatted)}"
)
return TestResult.Failed(step=step, message="/n".join(error_msgs))
return TestResult.Failed(step=step, message="\n".join(error_msgs))
print(f"\n{step} passed.\n")
else:
print(f"\n {step} skipped.\n")
@@ -280,7 +297,6 @@ class TestRunner:
contract_states: list[ResponseAccount],
) -> dict[str, list[SimulationFailure]]:
TychoDBSingleton.initialize()
protocol_type_names = self.config.protocol_type_names
block_header = get_block_header(block_number)
block: EVMBlock = EVMBlock(
@@ -289,7 +305,7 @@ class TestRunner:
hash_=block_header.hash.hex(),
)
failed_simulations: dict[str, list[SimulationFailure]] = dict()
failed_simulations: dict[str, list[SimulationFailure]] = {}
try:
adapter_contract = self.adapter_contract_builder.find_contract(
@@ -337,7 +353,7 @@ class TestRunner:
# Try to sell 0.1% of the protocol balance
try:
sell_amount = (
Decimal(prctg) * pool_state.balances[sell_token.address]
Decimal(prctg) * pool_state.balances[sell_token.address]
)
amount_out, gas_used, _ = pool_state.get_amount_out(
sell_token, sell_amount, buy_token
@@ -365,7 +381,9 @@ class TestRunner:
return failed_simulations
@staticmethod
def build_spkg(yaml_file_path: str, modify_func: callable) -> str:
def build_spkg(
yaml_file_path: str, modify_func: Callable[[dict[str, Any]], None]
) -> str:
"""Build a Substreams package with modifications to the YAML file."""
backup_file_path = f"{yaml_file_path}.backup"
shutil.copy(yaml_file_path, backup_file_path)
@@ -394,7 +412,7 @@ class TestRunner:
return spkg_name
@staticmethod
def update_initial_block(data: dict, start_block: int) -> None:
def update_initial_block(data: dict[str, Any], start_block: int) -> None:
"""Update the initial block for all modules in the configuration data."""
for module in data["modules"]:
module["initialBlock"] = start_block

View File

@@ -32,7 +32,6 @@ def build_snapshot_message(
for state in protocol_states:
pool_id = state.component_id
if pool_id not in states:
log.warning(f"State for pool {pool_id} not found in components")
continue
states[pool_id]["state"] = state
@@ -62,7 +61,7 @@ def token_factory(rpc_client: TychoRPCClient) -> callable(HexBytes):
if to_fetch:
pagination = PaginationParams(page_size=len(to_fetch), page=0)
params = TokensParams(token_addresses=to_fetch, pagination=pagination)
tokens = _client.get_tokens(params)
tokens = _client.get_tokens(params).tokens
for token in tokens:
address = to_checksum_address(token.address)
eth_token = EthereumToken(