* refactor(substreams): add `rust-toolchain` to the templates and modules. This freezes the wasm output for each module, allowing us to have full control over each packages and when we break the modules hashes. * fix(testing): allow no `initialized_accounts` in tests --------- Co-authored-by: zizou <111426680+flopell@users.noreply.github.com>
389 lines
15 KiB
Python
389 lines
15 KiB
Python
import itertools
|
||
import os
|
||
import shutil
|
||
import subprocess
|
||
import traceback
|
||
from collections import defaultdict
|
||
from datetime import datetime
|
||
from decimal import Decimal
|
||
from pathlib import Path
|
||
from typing import List
|
||
|
||
import yaml
|
||
from tycho_simulation_py.evm.decoders import ThirdPartyPoolTychoDecoder
|
||
from tycho_simulation_py.evm.storage import TychoDBSingleton
|
||
from tycho_simulation_py.models import EVMBlock
|
||
from pydantic import BaseModel
|
||
from tycho_indexer_client.dto import (
|
||
Chain,
|
||
ProtocolComponentsParams,
|
||
ProtocolStateParams,
|
||
ContractStateParams,
|
||
ProtocolComponent,
|
||
ResponseProtocolState,
|
||
HexBytes,
|
||
ResponseAccount,
|
||
Snapshot,
|
||
)
|
||
from tycho_indexer_client.rpc_client import TychoRPCClient
|
||
|
||
from models import (
|
||
IntegrationTestsConfig,
|
||
ProtocolComponentWithTestConfig,
|
||
ProtocolComponentExpectation,
|
||
)
|
||
from adapter_builder import AdapterContractBuilder
|
||
from evm import get_token_balance, get_block_header
|
||
from tycho import TychoRunner
|
||
from utils import build_snapshot_message, token_factory
|
||
|
||
|
||
class TestResult:
|
||
def __init__(self, success: bool, step: str = None, message: str = None):
|
||
self.success = success
|
||
self.step = step
|
||
self.message = message
|
||
|
||
@classmethod
|
||
def Passed(cls):
|
||
return cls(success=True)
|
||
|
||
@classmethod
|
||
def Failed(cls, step: str, message: str):
|
||
return cls(success=False, step=step, message=message)
|
||
|
||
|
||
def parse_config(yaml_path: str) -> IntegrationTestsConfig:
|
||
with open(yaml_path, "r") as file:
|
||
yaml_content = yaml.safe_load(file)
|
||
return IntegrationTestsConfig(**yaml_content)
|
||
|
||
|
||
class SimulationFailure(BaseModel):
|
||
pool_id: str
|
||
sell_token: str
|
||
buy_token: str
|
||
error: str
|
||
|
||
|
||
class TestRunner:
|
||
def __init__(
|
||
self, package: str, with_binary_logs: bool, db_url: str, vm_traces: bool
|
||
):
|
||
self.repo_root = os.getcwd()
|
||
config_path = os.path.join(
|
||
self.repo_root, "substreams", package, "integration_test.tycho.yaml"
|
||
)
|
||
self.config: IntegrationTestsConfig = parse_config(config_path)
|
||
self.spkg_src = os.path.join(self.repo_root, "substreams", package)
|
||
self.adapter_contract_builder = AdapterContractBuilder(
|
||
os.path.join(self.repo_root, "evm")
|
||
)
|
||
self.tycho_runner = TychoRunner(
|
||
db_url, with_binary_logs, self.config.initialized_accounts
|
||
)
|
||
self.tycho_rpc_client = TychoRPCClient()
|
||
self._token_factory_func = token_factory(self.tycho_rpc_client)
|
||
self.db_url = db_url
|
||
self._vm_traces = vm_traces
|
||
self._chain = Chain.ethereum
|
||
|
||
def run_tests(self) -> None:
|
||
"""Run all tests specified in the configuration."""
|
||
print(f"Running tests ...")
|
||
|
||
failed_tests = []
|
||
|
||
for test in self.config.tests:
|
||
self.tycho_runner.empty_database(self.db_url)
|
||
|
||
spkg_path = self.build_spkg(
|
||
os.path.join(self.spkg_src, self.config.substreams_yaml_path),
|
||
lambda data: self.update_initial_block(data, test.start_block),
|
||
)
|
||
self.tycho_runner.run_tycho(
|
||
spkg_path,
|
||
test.start_block,
|
||
test.stop_block,
|
||
self.config.protocol_type_names,
|
||
test.initialized_accounts or [],
|
||
)
|
||
|
||
result: TestResult = self.tycho_runner.run_with_rpc_server(
|
||
self.validate_state,
|
||
test.expected_components,
|
||
test.stop_block,
|
||
test.initialized_accounts or [],
|
||
)
|
||
|
||
if result.success:
|
||
print(f"\n✅ {test.name} passed.\n")
|
||
else:
|
||
print(f"\n❗️ {test.name} failed on {result.step}: {result.message}\n")
|
||
|
||
print(
|
||
"\nTest finished! \n"
|
||
f"Passed: {len(self.config.tests) - len(failed_tests)}/{len(self.config.tests)}\n"
|
||
)
|
||
if failed_tests:
|
||
print("Failed tests:")
|
||
for failed_test in failed_tests:
|
||
print(failed_test)
|
||
|
||
def validate_state(
|
||
self,
|
||
expected_components: List[ProtocolComponentWithTestConfig],
|
||
stop_block: int,
|
||
initialized_accounts: List[str],
|
||
) -> TestResult:
|
||
"""Validate the current protocol state against the expected state."""
|
||
protocol_components = self.tycho_rpc_client.get_protocol_components(
|
||
ProtocolComponentsParams(protocol_system="test_protocol")
|
||
)
|
||
protocol_states = self.tycho_rpc_client.get_protocol_state(
|
||
ProtocolStateParams(protocol_system="test_protocol")
|
||
)
|
||
components_by_id = {
|
||
component.id: component for component in protocol_components
|
||
}
|
||
|
||
step = "Protocol component validation"
|
||
try:
|
||
# Step 1: Validate the protocol components
|
||
for expected_component in expected_components:
|
||
comp_id = expected_component.id.lower()
|
||
if comp_id not in components_by_id:
|
||
return TestResult.Failed(
|
||
step=step,
|
||
message=f"'{comp_id}' not found in protocol components. "
|
||
f"Available components: {set(components_by_id.keys())}",
|
||
)
|
||
|
||
diff = ProtocolComponentExpectation(
|
||
**components_by_id[comp_id].dict()
|
||
).compare(ProtocolComponentExpectation(**expected_component.dict()))
|
||
if diff is not None:
|
||
return TestResult.Failed(step=step, message=diff)
|
||
|
||
print(f"\n✅ {step} passed.\n")
|
||
|
||
step = "Token balance validation"
|
||
token_balances: dict[str, dict[HexBytes, int]] = defaultdict(dict)
|
||
for component in protocol_components:
|
||
comp_id = component.id.lower()
|
||
for token in component.tokens:
|
||
state = next(
|
||
(
|
||
s
|
||
for s in protocol_states
|
||
if s.component_id.lower() == comp_id
|
||
),
|
||
None,
|
||
)
|
||
if state:
|
||
balance_hex = state.balances.get(token, HexBytes("0x00"))
|
||
else:
|
||
balance_hex = HexBytes("0x00")
|
||
tycho_balance = int(balance_hex)
|
||
token_balances[comp_id][token] = tycho_balance
|
||
|
||
if not self.config.skip_balance_check:
|
||
node_balance = get_token_balance(token, comp_id, stop_block)
|
||
if node_balance != tycho_balance:
|
||
return TestResult.Failed(
|
||
step=step,
|
||
message=f"Balance mismatch for {comp_id}:{token} at block {stop_block}: got {node_balance} "
|
||
f"from rpc call and {tycho_balance} from Substreams",
|
||
)
|
||
|
||
if not self.config.skip_balance_check:
|
||
print(f"\n✅ {step} passed.\n")
|
||
else:
|
||
print(f"\nℹ️ {step} skipped")
|
||
|
||
step = "Simulation validation"
|
||
|
||
# Loads from Tycho-Indexer the state of all the contracts that are related to the protocol components.
|
||
filtered_components = []
|
||
simulation_components = [
|
||
c.id for c in expected_components if c.skip_simulation is False
|
||
]
|
||
|
||
related_contracts = set()
|
||
for account in self.config.initialized_accounts or []:
|
||
related_contracts.add(HexBytes(account))
|
||
for account in initialized_accounts or []:
|
||
related_contracts.add(HexBytes(account))
|
||
|
||
# Filter out components that are not set to be used for the simulation
|
||
component_related_contracts = set()
|
||
for component in protocol_components:
|
||
if component.id in simulation_components:
|
||
for a in component.contract_ids:
|
||
component_related_contracts.add(a)
|
||
filtered_components.append(component)
|
||
|
||
# Check if any of the initialized contracts are not listed as component contract dependencies
|
||
unspecified_contracts = related_contracts - component_related_contracts
|
||
if len(unspecified_contracts):
|
||
print(
|
||
f"⚠️ The following initialized contracts are not listed as component contract dependencies: {unspecified_contracts}. "
|
||
f"Please ensure that, if they are required for this component's simulation, they are specified under the Protocol Component's contract field."
|
||
)
|
||
|
||
related_contracts.update(component_related_contracts)
|
||
related_contracts = [a.hex() for a in related_contracts]
|
||
|
||
contract_states = self.tycho_rpc_client.get_contract_state(
|
||
ContractStateParams(contract_ids=related_contracts)
|
||
)
|
||
if len(filtered_components):
|
||
simulation_failures = self.simulate_get_amount_out(
|
||
stop_block, protocol_states, filtered_components, contract_states
|
||
)
|
||
if len(simulation_failures):
|
||
error_msgs = []
|
||
for pool_id, failures in simulation_failures.items():
|
||
failures_ = [
|
||
f"{f.sell_token} -> {f.buy_token}: {f.error}"
|
||
for f in failures
|
||
]
|
||
error_msgs.append(
|
||
f"Pool {pool_id} failed simulations: {', '.join(failures_)}"
|
||
)
|
||
return TestResult.Failed(step=step, message="/n".join(error_msgs))
|
||
print(f"\n✅ {step} passed.\n")
|
||
else:
|
||
print(f"\nℹ️ {step} skipped")
|
||
return TestResult.Passed()
|
||
except Exception as e:
|
||
error_message = f"An error occurred: {str(e)}\n" + traceback.format_exc()
|
||
return TestResult.Failed(step=step, message=error_message)
|
||
|
||
def simulate_get_amount_out(
|
||
self,
|
||
block_number: int,
|
||
protocol_states: list[ResponseProtocolState],
|
||
protocol_components: list[ProtocolComponent],
|
||
contract_states: list[ResponseAccount],
|
||
) -> dict[str, list[SimulationFailure]]:
|
||
TychoDBSingleton.initialize()
|
||
protocol_type_names = self.config.protocol_type_names
|
||
|
||
block_header = get_block_header(block_number)
|
||
block: EVMBlock = EVMBlock(
|
||
id=block_number,
|
||
ts=datetime.fromtimestamp(block_header.timestamp),
|
||
hash_=block_header.hash.hex(),
|
||
)
|
||
|
||
failed_simulations: dict[str, list[SimulationFailure]] = dict()
|
||
|
||
try:
|
||
adapter_contract = self.adapter_contract_builder.find_contract(
|
||
self.config.adapter_contract
|
||
)
|
||
except FileNotFoundError:
|
||
adapter_contract = self.adapter_contract_builder.build_target(
|
||
self.config.adapter_contract,
|
||
self.config.adapter_build_signature,
|
||
self.config.adapter_build_args,
|
||
)
|
||
|
||
TychoDBSingleton.clear_instance()
|
||
|
||
decoder = ThirdPartyPoolTychoDecoder(
|
||
token_factory_func=self._token_factory_func,
|
||
adapter_contract=adapter_contract,
|
||
minimum_gas=0,
|
||
trace=self._vm_traces,
|
||
)
|
||
|
||
snapshot_message: Snapshot = build_snapshot_message(
|
||
protocol_states, protocol_components, contract_states
|
||
)
|
||
|
||
decoded = decoder.decode_snapshot(snapshot_message, block)
|
||
|
||
for component in protocol_components:
|
||
if component.id not in decoded:
|
||
failed_simulations[component.id] = [
|
||
SimulationFailure(
|
||
pool_id=component.id,
|
||
sell_token=component.tokens[0].hex(),
|
||
buy_token=component.tokens[1].hex(),
|
||
error="Pool not found in decoded state.",
|
||
)
|
||
]
|
||
|
||
for pool_state in decoded.values():
|
||
pool_id = pool_state.id_
|
||
if not pool_state.balances:
|
||
raise ValueError(f"Missing balances for pool {pool_id}")
|
||
for sell_token, buy_token in itertools.permutations(pool_state.tokens, 2):
|
||
for prctg in ["0.001", "0.01", "0.1"]:
|
||
# Try to sell 0.1% of the protocol balance
|
||
sell_amount = (
|
||
Decimal(prctg) * pool_state.balances[sell_token.address]
|
||
)
|
||
try:
|
||
amount_out, gas_used, _ = pool_state.get_amount_out(
|
||
sell_token, sell_amount, buy_token
|
||
)
|
||
print(
|
||
f"Amount out for {pool_id}: {sell_amount} {sell_token} -> {amount_out} {buy_token} - "
|
||
f"Gas used: {gas_used}"
|
||
)
|
||
except Exception as e:
|
||
print(
|
||
f"Error simulating get_amount_out for {pool_id}: {sell_token} -> {buy_token} at block {block_number}. "
|
||
f"Error: {e}"
|
||
)
|
||
if pool_id not in failed_simulations:
|
||
failed_simulations[pool_id] = []
|
||
failed_simulations[pool_id].append(
|
||
SimulationFailure(
|
||
pool_id=pool_id,
|
||
sell_token=str(sell_token),
|
||
buy_token=str(buy_token),
|
||
error=str(e),
|
||
)
|
||
)
|
||
continue
|
||
return failed_simulations
|
||
|
||
@staticmethod
|
||
def build_spkg(yaml_file_path: str, modify_func: callable) -> str:
|
||
"""Build a Substreams package with modifications to the YAML file."""
|
||
backup_file_path = f"{yaml_file_path}.backup"
|
||
shutil.copy(yaml_file_path, backup_file_path)
|
||
|
||
with open(yaml_file_path, "r") as file:
|
||
data = yaml.safe_load(file)
|
||
|
||
modify_func(data)
|
||
spkg_name = f"{yaml_file_path.rsplit('/', 1)[0]}/{data['package']['name'].replace('_', '-')}-{data['package']['version']}.spkg"
|
||
|
||
with open(yaml_file_path, "w") as file:
|
||
yaml.dump(data, file, default_flow_style=False)
|
||
|
||
try:
|
||
result = subprocess.run(
|
||
["substreams", "pack", yaml_file_path], capture_output=True, text=True
|
||
)
|
||
if result.returncode != 0:
|
||
print("Substreams pack command failed:", result.stderr)
|
||
except Exception as e:
|
||
print(f"Error running substreams pack command: {e}")
|
||
|
||
shutil.copy(backup_file_path, yaml_file_path)
|
||
Path(backup_file_path).unlink()
|
||
|
||
return spkg_name
|
||
|
||
@staticmethod
|
||
def update_initial_block(data: dict, start_block: int) -> None:
|
||
"""Update the initial block for all modules in the configuration data."""
|
||
for module in data["modules"]:
|
||
module["initialBlock"] = start_block
|