feat: update test suite to support DCI enabled protocols (#225)
* feat: update tycho-client and support testing DCI enabled protocols * refactor: improve python Typing * test: update test suite tycho-simulation dependency The updated version includes an account_balances fix. * refactor: use sets instead of lists * feat: update tycho-client test dependency
This commit is contained in:
@@ -2,5 +2,5 @@ psycopg2==2.9.9
|
|||||||
PyYAML==6.0.1
|
PyYAML==6.0.1
|
||||||
Requests==2.32.2
|
Requests==2.32.2
|
||||||
web3==5.31.3
|
web3==5.31.3
|
||||||
git+https://github.com/propeller-heads/tycho-indexer.git@0.60.0#subdirectory=tycho-client-py
|
git+https://github.com/propeller-heads/tycho-indexer.git@0.74.0#subdirectory=tycho-client-py
|
||||||
git+https://github.com/propeller-heads/tycho-simulation.git@0.105.0#subdirectory=tycho_simulation_py
|
git+https://github.com/propeller-heads/tycho-simulation.git@0.118.0#subdirectory=tycho_simulation_py
|
||||||
@@ -3,11 +3,10 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import traceback
|
import traceback
|
||||||
from collections import defaultdict
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import Optional, Callable, Any
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from tycho_simulation_py.evm.decoders import ThirdPartyPoolTychoDecoder
|
from tycho_simulation_py.evm.decoders import ThirdPartyPoolTychoDecoder
|
||||||
@@ -24,6 +23,7 @@ from tycho_indexer_client.dto import (
|
|||||||
HexBytes,
|
HexBytes,
|
||||||
ResponseAccount,
|
ResponseAccount,
|
||||||
Snapshot,
|
Snapshot,
|
||||||
|
TracedEntryPointParams,
|
||||||
)
|
)
|
||||||
from tycho_indexer_client.rpc_client import TychoRPCClient
|
from tycho_indexer_client.rpc_client import TychoRPCClient
|
||||||
|
|
||||||
@@ -39,7 +39,9 @@ from utils import build_snapshot_message, token_factory
|
|||||||
|
|
||||||
|
|
||||||
class TestResult:
|
class TestResult:
|
||||||
def __init__(self, success: bool, step: str = None, message: str = None):
|
def __init__(
|
||||||
|
self, success: bool, step: Optional[str] = None, message: Optional[str] = None
|
||||||
|
):
|
||||||
self.success = success
|
self.success = success
|
||||||
self.step = step
|
self.step = step
|
||||||
self.message = message
|
self.message = message
|
||||||
@@ -93,7 +95,7 @@ class TestRunner:
|
|||||||
print(f"Running {len(self.config.tests)} tests ...\n")
|
print(f"Running {len(self.config.tests)} tests ...\n")
|
||||||
print("--------------------------------\n")
|
print("--------------------------------\n")
|
||||||
|
|
||||||
failed_tests = []
|
failed_tests: list[str] = []
|
||||||
count = 1
|
count = 1
|
||||||
|
|
||||||
for test in self.config.tests:
|
for test in self.config.tests:
|
||||||
@@ -141,18 +143,18 @@ class TestRunner:
|
|||||||
|
|
||||||
def validate_state(
|
def validate_state(
|
||||||
self,
|
self,
|
||||||
expected_components: List[ProtocolComponentWithTestConfig],
|
expected_components: list[ProtocolComponentWithTestConfig],
|
||||||
stop_block: int,
|
stop_block: int,
|
||||||
initialized_accounts: List[str],
|
initialized_accounts: list[str],
|
||||||
) -> TestResult:
|
) -> TestResult:
|
||||||
"""Validate the current protocol state against the expected state."""
|
"""Validate the current protocol state against the expected state."""
|
||||||
protocol_components = self.tycho_rpc_client.get_protocol_components(
|
protocol_components = self.tycho_rpc_client.get_protocol_components(
|
||||||
ProtocolComponentsParams(protocol_system="test_protocol")
|
ProtocolComponentsParams(protocol_system="test_protocol")
|
||||||
)
|
).protocol_components
|
||||||
protocol_states = self.tycho_rpc_client.get_protocol_state(
|
protocol_states = self.tycho_rpc_client.get_protocol_state(
|
||||||
ProtocolStateParams(protocol_system="test_protocol")
|
ProtocolStateParams(protocol_system="test_protocol")
|
||||||
)
|
).states
|
||||||
components_by_id = {
|
components_by_id: dict[str, ProtocolComponent] = {
|
||||||
component.id: component for component in protocol_components
|
component.id: component for component in protocol_components
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -214,34 +216,49 @@ class TestRunner:
|
|||||||
step = "Simulation validation"
|
step = "Simulation validation"
|
||||||
|
|
||||||
# Loads from Tycho-Indexer the state of all the contracts that are related to the protocol components.
|
# Loads from Tycho-Indexer the state of all the contracts that are related to the protocol components.
|
||||||
filtered_components = []
|
simulation_components: list[str] = [
|
||||||
simulation_components = [
|
|
||||||
c.id for c in expected_components if c.skip_simulation is False
|
c.id for c in expected_components if c.skip_simulation is False
|
||||||
]
|
]
|
||||||
|
|
||||||
related_contracts = set()
|
related_contracts: set[str] = set()
|
||||||
for account in self.config.initialized_accounts or []:
|
for account in self.config.initialized_accounts or []:
|
||||||
related_contracts.add(HexBytes(account))
|
related_contracts.add(account)
|
||||||
for account in initialized_accounts or []:
|
for account in initialized_accounts or []:
|
||||||
related_contracts.add(HexBytes(account))
|
related_contracts.add(account)
|
||||||
|
|
||||||
# Filter out components that are not set to be used for the simulation
|
# Collect all contracts that are related to the simulation components
|
||||||
component_related_contracts = set()
|
filtered_components: list[ProtocolComponent] = []
|
||||||
|
component_related_contracts: set[str] = set()
|
||||||
for component in protocol_components:
|
for component in protocol_components:
|
||||||
|
# Filter out components that are not set to be used for the simulation
|
||||||
if component.id in simulation_components:
|
if component.id in simulation_components:
|
||||||
|
# Collect component contracts
|
||||||
for a in component.contract_ids:
|
for a in component.contract_ids:
|
||||||
component_related_contracts.add(a)
|
component_related_contracts.add(a.hex())
|
||||||
|
# Collect DCI detected contracts
|
||||||
|
traces_results = self.tycho_rpc_client.get_traced_entry_points(
|
||||||
|
TracedEntryPointParams(
|
||||||
|
protocol_system="test_protocol",
|
||||||
|
component_ids=[component.id],
|
||||||
|
)
|
||||||
|
).traced_entry_points.values()
|
||||||
|
for traces in traces_results:
|
||||||
|
for _, trace in traces:
|
||||||
|
component_related_contracts.update(
|
||||||
|
trace["accessed_slots"].keys()
|
||||||
|
)
|
||||||
filtered_components.append(component)
|
filtered_components.append(component)
|
||||||
|
|
||||||
# Check if any of the initialized contracts are not listed as component contract dependencies
|
# Check if any of the initialized contracts are not listed as component contract dependencies
|
||||||
unspecified_contracts = related_contracts - component_related_contracts
|
unspecified_contracts: list[str] = [
|
||||||
|
c for c in related_contracts if c not in component_related_contracts
|
||||||
|
]
|
||||||
|
|
||||||
related_contracts.update(component_related_contracts)
|
related_contracts.update(component_related_contracts)
|
||||||
related_contracts = [a.hex() for a in related_contracts]
|
|
||||||
|
|
||||||
contract_states = self.tycho_rpc_client.get_contract_state(
|
contract_states = self.tycho_rpc_client.get_contract_state(
|
||||||
ContractStateParams(contract_ids=related_contracts)
|
ContractStateParams(contract_ids=list(related_contracts))
|
||||||
)
|
).accounts
|
||||||
if len(filtered_components):
|
if len(filtered_components):
|
||||||
|
|
||||||
if len(unspecified_contracts):
|
if len(unspecified_contracts):
|
||||||
@@ -254,16 +271,16 @@ class TestRunner:
|
|||||||
stop_block, protocol_states, filtered_components, contract_states
|
stop_block, protocol_states, filtered_components, contract_states
|
||||||
)
|
)
|
||||||
if len(simulation_failures):
|
if len(simulation_failures):
|
||||||
error_msgs = []
|
error_msgs: list[str] = []
|
||||||
for pool_id, failures in simulation_failures.items():
|
for pool_id, failures in simulation_failures.items():
|
||||||
failures_ = [
|
failures_formatted: list[str] = [
|
||||||
f"{f.sell_token} -> {f.buy_token}: {f.error}"
|
f"{f.sell_token} -> {f.buy_token}: {f.error}"
|
||||||
for f in failures
|
for f in failures
|
||||||
]
|
]
|
||||||
error_msgs.append(
|
error_msgs.append(
|
||||||
f"Pool {pool_id} failed simulations: {', '.join(failures_)}"
|
f"Pool {pool_id} failed simulations: {', '.join(failures_formatted)}"
|
||||||
)
|
)
|
||||||
return TestResult.Failed(step=step, message="/n".join(error_msgs))
|
return TestResult.Failed(step=step, message="\n".join(error_msgs))
|
||||||
print(f"\n✅ {step} passed.\n")
|
print(f"\n✅ {step} passed.\n")
|
||||||
else:
|
else:
|
||||||
print(f"\nℹ️ {step} skipped.\n")
|
print(f"\nℹ️ {step} skipped.\n")
|
||||||
@@ -280,7 +297,6 @@ class TestRunner:
|
|||||||
contract_states: list[ResponseAccount],
|
contract_states: list[ResponseAccount],
|
||||||
) -> dict[str, list[SimulationFailure]]:
|
) -> dict[str, list[SimulationFailure]]:
|
||||||
TychoDBSingleton.initialize()
|
TychoDBSingleton.initialize()
|
||||||
protocol_type_names = self.config.protocol_type_names
|
|
||||||
|
|
||||||
block_header = get_block_header(block_number)
|
block_header = get_block_header(block_number)
|
||||||
block: EVMBlock = EVMBlock(
|
block: EVMBlock = EVMBlock(
|
||||||
@@ -289,7 +305,7 @@ class TestRunner:
|
|||||||
hash_=block_header.hash.hex(),
|
hash_=block_header.hash.hex(),
|
||||||
)
|
)
|
||||||
|
|
||||||
failed_simulations: dict[str, list[SimulationFailure]] = dict()
|
failed_simulations: dict[str, list[SimulationFailure]] = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
adapter_contract = self.adapter_contract_builder.find_contract(
|
adapter_contract = self.adapter_contract_builder.find_contract(
|
||||||
@@ -365,7 +381,9 @@ class TestRunner:
|
|||||||
return failed_simulations
|
return failed_simulations
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def build_spkg(yaml_file_path: str, modify_func: callable) -> str:
|
def build_spkg(
|
||||||
|
yaml_file_path: str, modify_func: Callable[[dict[str, Any]], None]
|
||||||
|
) -> str:
|
||||||
"""Build a Substreams package with modifications to the YAML file."""
|
"""Build a Substreams package with modifications to the YAML file."""
|
||||||
backup_file_path = f"{yaml_file_path}.backup"
|
backup_file_path = f"{yaml_file_path}.backup"
|
||||||
shutil.copy(yaml_file_path, backup_file_path)
|
shutil.copy(yaml_file_path, backup_file_path)
|
||||||
@@ -394,7 +412,7 @@ class TestRunner:
|
|||||||
return spkg_name
|
return spkg_name
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_initial_block(data: dict, start_block: int) -> None:
|
def update_initial_block(data: dict[str, Any], start_block: int) -> None:
|
||||||
"""Update the initial block for all modules in the configuration data."""
|
"""Update the initial block for all modules in the configuration data."""
|
||||||
for module in data["modules"]:
|
for module in data["modules"]:
|
||||||
module["initialBlock"] = start_block
|
module["initialBlock"] = start_block
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ def build_snapshot_message(
|
|||||||
for state in protocol_states:
|
for state in protocol_states:
|
||||||
pool_id = state.component_id
|
pool_id = state.component_id
|
||||||
if pool_id not in states:
|
if pool_id not in states:
|
||||||
log.warning(f"State for pool {pool_id} not found in components")
|
|
||||||
continue
|
continue
|
||||||
states[pool_id]["state"] = state
|
states[pool_id]["state"] = state
|
||||||
|
|
||||||
@@ -62,7 +61,7 @@ def token_factory(rpc_client: TychoRPCClient) -> callable(HexBytes):
|
|||||||
if to_fetch:
|
if to_fetch:
|
||||||
pagination = PaginationParams(page_size=len(to_fetch), page=0)
|
pagination = PaginationParams(page_size=len(to_fetch), page=0)
|
||||||
params = TokensParams(token_addresses=to_fetch, pagination=pagination)
|
params = TokensParams(token_addresses=to_fetch, pagination=pagination)
|
||||||
tokens = _client.get_tokens(params)
|
tokens = _client.get_tokens(params).tokens
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
address = to_checksum_address(token.address)
|
address = to_checksum_address(token.address)
|
||||||
eth_token = EthereumToken(
|
eth_token = EthereumToken(
|
||||||
|
|||||||
Reference in New Issue
Block a user