feat: update test suite to support DCI enabled protocols (#225)
* feat: update tycho-client and support testing DCI enabled protocols * refactor: improve python Typing * test: update test suite tycho-simulation dependency The updated version includes an account_balances fix. * refactor: use sets instead of lists * feat: update tycho-client test dependency
This commit is contained in:
@@ -2,5 +2,5 @@ psycopg2==2.9.9
|
||||
PyYAML==6.0.1
|
||||
Requests==2.32.2
|
||||
web3==5.31.3
|
||||
git+https://github.com/propeller-heads/tycho-indexer.git@0.60.0#subdirectory=tycho-client-py
|
||||
git+https://github.com/propeller-heads/tycho-simulation.git@0.105.0#subdirectory=tycho_simulation_py
|
||||
git+https://github.com/propeller-heads/tycho-indexer.git@0.74.0#subdirectory=tycho-client-py
|
||||
git+https://github.com/propeller-heads/tycho-simulation.git@0.118.0#subdirectory=tycho_simulation_py
|
||||
@@ -3,11 +3,10 @@ import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import Optional, Callable, Any
|
||||
|
||||
import yaml
|
||||
from tycho_simulation_py.evm.decoders import ThirdPartyPoolTychoDecoder
|
||||
@@ -24,6 +23,7 @@ from tycho_indexer_client.dto import (
|
||||
HexBytes,
|
||||
ResponseAccount,
|
||||
Snapshot,
|
||||
TracedEntryPointParams,
|
||||
)
|
||||
from tycho_indexer_client.rpc_client import TychoRPCClient
|
||||
|
||||
@@ -39,7 +39,9 @@ from utils import build_snapshot_message, token_factory
|
||||
|
||||
|
||||
class TestResult:
|
||||
def __init__(self, success: bool, step: str = None, message: str = None):
|
||||
def __init__(
|
||||
self, success: bool, step: Optional[str] = None, message: Optional[str] = None
|
||||
):
|
||||
self.success = success
|
||||
self.step = step
|
||||
self.message = message
|
||||
@@ -93,7 +95,7 @@ class TestRunner:
|
||||
print(f"Running {len(self.config.tests)} tests ...\n")
|
||||
print("--------------------------------\n")
|
||||
|
||||
failed_tests = []
|
||||
failed_tests: list[str] = []
|
||||
count = 1
|
||||
|
||||
for test in self.config.tests:
|
||||
@@ -141,18 +143,18 @@ class TestRunner:
|
||||
|
||||
def validate_state(
|
||||
self,
|
||||
expected_components: List[ProtocolComponentWithTestConfig],
|
||||
expected_components: list[ProtocolComponentWithTestConfig],
|
||||
stop_block: int,
|
||||
initialized_accounts: List[str],
|
||||
initialized_accounts: list[str],
|
||||
) -> TestResult:
|
||||
"""Validate the current protocol state against the expected state."""
|
||||
protocol_components = self.tycho_rpc_client.get_protocol_components(
|
||||
ProtocolComponentsParams(protocol_system="test_protocol")
|
||||
)
|
||||
).protocol_components
|
||||
protocol_states = self.tycho_rpc_client.get_protocol_state(
|
||||
ProtocolStateParams(protocol_system="test_protocol")
|
||||
)
|
||||
components_by_id = {
|
||||
).states
|
||||
components_by_id: dict[str, ProtocolComponent] = {
|
||||
component.id: component for component in protocol_components
|
||||
}
|
||||
|
||||
@@ -214,34 +216,49 @@ class TestRunner:
|
||||
step = "Simulation validation"
|
||||
|
||||
# Loads from Tycho-Indexer the state of all the contracts that are related to the protocol components.
|
||||
filtered_components = []
|
||||
simulation_components = [
|
||||
simulation_components: list[str] = [
|
||||
c.id for c in expected_components if c.skip_simulation is False
|
||||
]
|
||||
|
||||
related_contracts = set()
|
||||
related_contracts: set[str] = set()
|
||||
for account in self.config.initialized_accounts or []:
|
||||
related_contracts.add(HexBytes(account))
|
||||
related_contracts.add(account)
|
||||
for account in initialized_accounts or []:
|
||||
related_contracts.add(HexBytes(account))
|
||||
related_contracts.add(account)
|
||||
|
||||
# Filter out components that are not set to be used for the simulation
|
||||
component_related_contracts = set()
|
||||
# Collect all contracts that are related to the simulation components
|
||||
filtered_components: list[ProtocolComponent] = []
|
||||
component_related_contracts: set[str] = set()
|
||||
for component in protocol_components:
|
||||
# Filter out components that are not set to be used for the simulation
|
||||
if component.id in simulation_components:
|
||||
# Collect component contracts
|
||||
for a in component.contract_ids:
|
||||
component_related_contracts.add(a)
|
||||
component_related_contracts.add(a.hex())
|
||||
# Collect DCI detected contracts
|
||||
traces_results = self.tycho_rpc_client.get_traced_entry_points(
|
||||
TracedEntryPointParams(
|
||||
protocol_system="test_protocol",
|
||||
component_ids=[component.id],
|
||||
)
|
||||
).traced_entry_points.values()
|
||||
for traces in traces_results:
|
||||
for _, trace in traces:
|
||||
component_related_contracts.update(
|
||||
trace["accessed_slots"].keys()
|
||||
)
|
||||
filtered_components.append(component)
|
||||
|
||||
# Check if any of the initialized contracts are not listed as component contract dependencies
|
||||
unspecified_contracts = related_contracts - component_related_contracts
|
||||
unspecified_contracts: list[str] = [
|
||||
c for c in related_contracts if c not in component_related_contracts
|
||||
]
|
||||
|
||||
related_contracts.update(component_related_contracts)
|
||||
related_contracts = [a.hex() for a in related_contracts]
|
||||
|
||||
contract_states = self.tycho_rpc_client.get_contract_state(
|
||||
ContractStateParams(contract_ids=related_contracts)
|
||||
)
|
||||
ContractStateParams(contract_ids=list(related_contracts))
|
||||
).accounts
|
||||
if len(filtered_components):
|
||||
|
||||
if len(unspecified_contracts):
|
||||
@@ -254,16 +271,16 @@ class TestRunner:
|
||||
stop_block, protocol_states, filtered_components, contract_states
|
||||
)
|
||||
if len(simulation_failures):
|
||||
error_msgs = []
|
||||
error_msgs: list[str] = []
|
||||
for pool_id, failures in simulation_failures.items():
|
||||
failures_ = [
|
||||
failures_formatted: list[str] = [
|
||||
f"{f.sell_token} -> {f.buy_token}: {f.error}"
|
||||
for f in failures
|
||||
]
|
||||
error_msgs.append(
|
||||
f"Pool {pool_id} failed simulations: {', '.join(failures_)}"
|
||||
f"Pool {pool_id} failed simulations: {', '.join(failures_formatted)}"
|
||||
)
|
||||
return TestResult.Failed(step=step, message="/n".join(error_msgs))
|
||||
return TestResult.Failed(step=step, message="\n".join(error_msgs))
|
||||
print(f"\n✅ {step} passed.\n")
|
||||
else:
|
||||
print(f"\nℹ️ {step} skipped.\n")
|
||||
@@ -280,7 +297,6 @@ class TestRunner:
|
||||
contract_states: list[ResponseAccount],
|
||||
) -> dict[str, list[SimulationFailure]]:
|
||||
TychoDBSingleton.initialize()
|
||||
protocol_type_names = self.config.protocol_type_names
|
||||
|
||||
block_header = get_block_header(block_number)
|
||||
block: EVMBlock = EVMBlock(
|
||||
@@ -289,7 +305,7 @@ class TestRunner:
|
||||
hash_=block_header.hash.hex(),
|
||||
)
|
||||
|
||||
failed_simulations: dict[str, list[SimulationFailure]] = dict()
|
||||
failed_simulations: dict[str, list[SimulationFailure]] = {}
|
||||
|
||||
try:
|
||||
adapter_contract = self.adapter_contract_builder.find_contract(
|
||||
@@ -337,7 +353,7 @@ class TestRunner:
|
||||
# Try to sell 0.1% of the protocol balance
|
||||
try:
|
||||
sell_amount = (
|
||||
Decimal(prctg) * pool_state.balances[sell_token.address]
|
||||
Decimal(prctg) * pool_state.balances[sell_token.address]
|
||||
)
|
||||
amount_out, gas_used, _ = pool_state.get_amount_out(
|
||||
sell_token, sell_amount, buy_token
|
||||
@@ -365,7 +381,9 @@ class TestRunner:
|
||||
return failed_simulations
|
||||
|
||||
@staticmethod
|
||||
def build_spkg(yaml_file_path: str, modify_func: callable) -> str:
|
||||
def build_spkg(
|
||||
yaml_file_path: str, modify_func: Callable[[dict[str, Any]], None]
|
||||
) -> str:
|
||||
"""Build a Substreams package with modifications to the YAML file."""
|
||||
backup_file_path = f"{yaml_file_path}.backup"
|
||||
shutil.copy(yaml_file_path, backup_file_path)
|
||||
@@ -394,7 +412,7 @@ class TestRunner:
|
||||
return spkg_name
|
||||
|
||||
@staticmethod
|
||||
def update_initial_block(data: dict, start_block: int) -> None:
|
||||
def update_initial_block(data: dict[str, Any], start_block: int) -> None:
|
||||
"""Update the initial block for all modules in the configuration data."""
|
||||
for module in data["modules"]:
|
||||
module["initialBlock"] = start_block
|
||||
|
||||
@@ -32,7 +32,6 @@ def build_snapshot_message(
|
||||
for state in protocol_states:
|
||||
pool_id = state.component_id
|
||||
if pool_id not in states:
|
||||
log.warning(f"State for pool {pool_id} not found in components")
|
||||
continue
|
||||
states[pool_id]["state"] = state
|
||||
|
||||
@@ -62,7 +61,7 @@ def token_factory(rpc_client: TychoRPCClient) -> callable(HexBytes):
|
||||
if to_fetch:
|
||||
pagination = PaginationParams(page_size=len(to_fetch), page=0)
|
||||
params = TokensParams(token_addresses=to_fetch, pagination=pagination)
|
||||
tokens = _client.get_tokens(params)
|
||||
tokens = _client.get_tokens(params).tokens
|
||||
for token in tokens:
|
||||
address = to_checksum_address(token.address)
|
||||
eth_token = EthereumToken(
|
||||
|
||||
Reference in New Issue
Block a user