# File: tycho-protocol-sdk/testing/src/runner/runner.py
# Timestamp: 2024-08-08 17:58:20 -03:00
# Size: 325 lines, 12 KiB (Python)
import itertools
import os
import shutil
import subprocess
import traceback
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from pathlib import Path
from typing import Callable, List, Optional

import yaml
from protosim_py.evm.decoders import ThirdPartyPoolTychoDecoder
from protosim_py.evm.storage import TychoDBSingleton
from protosim_py.models import EVMBlock
from pydantic import BaseModel
from tycho_client.dto import (
    Chain,
    ProtocolComponentsParams,
    ProtocolStateParams,
    ContractStateParams,
    ProtocolComponent,
    ResponseProtocolState,
    HexBytes,
    ResponseAccount,
    Snapshot,
    ContractId,
)
from tycho_client.rpc_client import TychoRPCClient

from models import (
    IntegrationTestsConfig,
    ProtocolComponentWithTestConfig,
    ProtocolComponentExpectation,
)
from adapter_handler import AdapterContractHandler
from evm import get_token_balance, get_block_header
from tycho import TychoRunner
from utils import build_snapshot_message, token_factory
class TestResult:
    """Outcome of one integration test: a success flag plus an optional message."""

    def __init__(self, success: bool, message: Optional[str] = None):
        """
        :param success: Whether the test passed.
        :param message: Description of the outcome (set by ``Failed``);
            ``None`` when not given.
        """
        self.success = success
        self.message = message

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}"
            f"(success={self.success!r}, message={self.message!r})"
        )

    @classmethod
    def Passed(cls):
        """Create a successful result without a message."""
        return cls(success=True)

    @classmethod
    def Failed(cls, message: str):
        """Create a failed result carrying ``message``."""
        return cls(success=False, message=message)
def parse_config(yaml_path: str) -> IntegrationTestsConfig:
    """Load an integration-test configuration from a YAML file.

    :param yaml_path: Path to the integration test YAML file.
    :return: The file's top-level mapping passed as keyword arguments to
        ``IntegrationTestsConfig``.
    :raises ValueError: If the file does not contain a YAML mapping at the
        top level (for example when it is empty).
    """
    with open(yaml_path, "r", encoding="utf-8") as file:
        # safe_load: the config is plain data, never arbitrary Python objects.
        yaml_content = yaml.safe_load(file)
    # An empty file loads as None; unpacking it would fail with an opaque
    # TypeError, so report the actual problem instead.
    if not isinstance(yaml_content, dict):
        raise ValueError(
            f"Expected a YAML mapping in {yaml_path}, "
            f"got {type(yaml_content).__name__}"
        )
    return IntegrationTestsConfig(**yaml_content)
class SimulationFailure(BaseModel):
    """Record of one ``get_amount_out`` simulation that raised, for a pool and token pair."""

    # Id of the pool whose simulation failed.
    pool_id: str
    # String form (str()) of the token that was being sold.
    sell_token: str
    # String form (str()) of the token that was being bought.
    buy_token: str
    # String form (str()) of the exception raised by the simulation.
    error: str
class TestRunner:
    """Run the Tycho integration tests declared for one Substreams package.

    The tests come from ``substreams/<package>/integration_test.tycho.yaml``.
    For each one, the runner:

    1. rebuilds the package with the test's start block;
    2. runs Tycho over the test's block range;
    3. validates the indexed state: components, token balances, and
       ``get_amount_out`` simulations.
    """

    def __init__(
        self, package: str, with_binary_logs: bool, db_url: str, vm_traces: bool
    ):
        """Load the package's integration-test config and set up the clients.

        :param package: Name of the package directory under ``substreams/``.
        :param with_binary_logs: Forwarded to ``TychoRunner``.
        :param db_url: Database URL for Tycho; the database is emptied before
            each test.
        :param vm_traces: Forwarded as ``trace`` to the pool decoder.
        """
        # All paths are resolved against the current working directory,
        # presumably the repository root.
        self.repo_root = os.getcwd()
        self.spkg_src = os.path.join(self.repo_root, "substreams", package)
        config_path = os.path.join(self.spkg_src, "integration_test.tycho.yaml")
        self.config: IntegrationTestsConfig = parse_config(config_path)
        self.adapters_src = os.path.join(self.repo_root, "evm")
        self.tycho_runner = TychoRunner(
            db_url, with_binary_logs, self.config.initialized_accounts
        )
        self.tycho_rpc_client = TychoRPCClient()
        self._token_factory_func = token_factory(self.tycho_rpc_client)
        self.db_url = db_url
        self._vm_traces = vm_traces
        self._chain = Chain.ethereum

    def run_tests(self) -> None:
        """Run every configured test and print one pass/fail line per test.

        Before each test, the runner:

        1. empties the Tycho database;
        2. re-packs the Substreams package with every module's
           ``initialBlock`` set to the test's ``start_block``;
        3. runs Tycho from ``start_block`` to ``stop_block``.
        """
        print("Running tests ...")
        for test in self.config.tests:
            self.tycho_runner.empty_database(self.db_url)

            spkg_path = self.build_spkg(
                os.path.join(self.spkg_src, self.config.substreams_yaml_path),
                lambda data: self.update_initial_block(data, test.start_block),
            )
            self.tycho_runner.run_tycho(
                spkg_path,
                test.start_block,
                test.stop_block,
                self.config.protocol_type_names,
                test.initialized_accounts or [],
            )

            result = self.tycho_runner.run_with_rpc_server(
                self.validate_state, test.expected_components, test.stop_block
            )
            if result.success:
                print(f"{test.name} passed.")
            else:
                print(f"❗️ {test.name} failed: {result.message}")

    def validate_state(
        self,
        expected_components: List[ProtocolComponentWithTestConfig],
        stop_block: int,
    ) -> TestResult:
        """Validate the state indexed by Tycho against the expected state.

        Checks, in order:

        1. every expected component exists and matches its expectation;
        2. unless ``skip_balance_check`` is set, Tycho's balance of every
           component token equals the node balance at ``stop_block``;
        3. ``get_amount_out`` simulations succeed for every expected component
           that does not set ``skip_simulation``.

        Component ids are compared case-insensitively throughout.

        :param expected_components: Components the test expects to be indexed.
        :param stop_block: Block used for node balances and for simulation.
        :return: ``TestResult.Passed()``, or ``TestResult.Failed`` describing
            the first problem found. Unexpected exceptions, including RPC
            errors, are reported as a failure with their traceback.
        """
        try:
            protocol_components = self.tycho_rpc_client.get_protocol_components(
                ProtocolComponentsParams(protocol_system="test_protocol")
            )
            protocol_states = self.tycho_rpc_client.get_protocol_state(
                ProtocolStateParams(protocol_system="test_protocol")
            )

            # Expected ids are lower-cased before lookup, so the keys must be too.
            components_by_id = {
                component.id.lower(): component for component in protocol_components
            }
            states_by_id: dict[str, ResponseProtocolState] = {}
            for state in protocol_states:
                # Keep the first state reported for a component.
                states_by_id.setdefault(state.component_id.lower(), state)

            for expected_component in expected_components:
                comp_id = expected_component.id.lower()
                if comp_id not in components_by_id:
                    return TestResult.Failed(
                        f"'{comp_id}' not found in protocol components."
                    )

                diff = ProtocolComponentExpectation(
                    **components_by_id[comp_id].dict()
                ).compare(ProtocolComponentExpectation(**expected_component.dict()))
                if diff is not None:
                    return TestResult.Failed(diff)

            if not self.config.skip_balance_check:
                for component in protocol_components:
                    comp_id = component.id.lower()
                    state = states_by_id.get(comp_id)
                    for token in component.tokens:
                        # A missing state or a token without a reported
                        # balance counts as a zero balance.
                        if state is not None:
                            balance_hex = state.balances.get(token, HexBytes("0x00"))
                        else:
                            balance_hex = HexBytes("0x00")
                        tycho_balance = int(balance_hex)

                        node_balance = get_token_balance(token, comp_id, stop_block)
                        if node_balance != tycho_balance:
                            return TestResult.Failed(
                                f"Balance mismatch for {comp_id}:{token} at block {stop_block}: got {node_balance} "
                                f"from rpc call and {tycho_balance} from Substreams"
                            )

            contract_states = self.tycho_rpc_client.get_contract_state(
                ContractStateParams(
                    contract_ids=[
                        ContractId(chain=self._chain, address=a)
                        for component in protocol_components
                        for a in component.contract_ids
                    ]
                )
            )

            # Build the id set once; compare case-insensitively like above.
            simulated_ids = {
                c.id.lower() for c in expected_components if not c.skip_simulation
            }
            filtered_components = [
                pc for pc in protocol_components if pc.id.lower() in simulated_ids
            ]

            simulation_failures = self.simulate_get_amount_out(
                stop_block, protocol_states, filtered_components, contract_states
            )
            if simulation_failures:
                error_msgs = []
                for pool_id, failures in simulation_failures.items():
                    failures_ = [
                        f"{f.sell_token} -> {f.buy_token}: {f.error}" for f in failures
                    ]
                    error_msgs.append(
                        f"Pool {pool_id} failed simulations: {', '.join(failures_)}"
                    )
                return TestResult.Failed(". ".join(error_msgs))

            return TestResult.Passed()
        except Exception as e:
            # Reporting boundary: any unexpected error becomes a failed test.
            error_message = f"An error occurred: {str(e)}\n" + traceback.format_exc()
            return TestResult.Failed(error_message)

    def simulate_get_amount_out(
        self,
        block_number: int,
        protocol_states: list[ResponseProtocolState],
        protocol_components: list[ProtocolComponent],
        contract_states: list[ResponseAccount],
    ) -> dict[str, list[SimulationFailure]]:
        """Simulate swaps on every decoded pool and collect the failures.

        A snapshot is built from the given states, components and contracts,
        then decoded with the configured adapter contract. The adapter is
        built first if its runtime artifact is missing. For each pool and each
        ordered token pair, 0.1% of the pool's sell-token balance is sold via
        ``get_amount_out``.

        :param block_number: Block whose header defines the simulation block.
        :return: Failures grouped by pool id; empty when every simulation
            succeeded or no protocol type is configured.
        :raises ValueError: If a decoded pool has no balances.
        """
        TychoDBSingleton.initialize()
        block_header = get_block_header(block_number)
        block: EVMBlock = EVMBlock(
            id=block_number,
            # NOTE(review): naive local-time datetime; confirm EVMBlock expects that.
            ts=datetime.fromtimestamp(block_header.timestamp),
            hash_=block_header.hash.hex(),
        )
        failed_simulations: dict[str, list[SimulationFailure]] = {}

        # The adapter, the snapshot and the decoding do not depend on the
        # protocol type, so the simulation runs once for all configured
        # types. With no type configured, nothing is simulated.
        if not self.config.protocol_type_names:
            return failed_simulations

        adapter_contract = os.path.join(
            self.adapters_src,
            "out",
            f"{self.config.adapter_contract}.sol",
            f"{self.config.adapter_contract}.evm.runtime",
        )
        if not os.path.exists(adapter_contract):
            print("Adapter contract not found. Building it ...")
            AdapterContractHandler.build_target(
                self.adapters_src,
                self.config.adapter_contract,
                self.config.adapter_build_signature,
                self.config.adapter_build_args,
            )

        decoder = ThirdPartyPoolTychoDecoder(
            token_factory_func=self._token_factory_func,
            adapter_contract=adapter_contract,
            minimum_gas=0,
            trace=self._vm_traces,
        )
        snapshot_message: Snapshot = build_snapshot_message(
            protocol_states, protocol_components, contract_states
        )
        decoded = decoder.decode_snapshot(snapshot_message, block)

        for pool_state in decoded.values():
            pool_id = pool_state.id_
            if not pool_state.balances:
                raise ValueError(f"Missing balances for pool {pool_id}")
            for sell_token, buy_token in itertools.permutations(pool_state.tokens, 2):
                # Try to sell 0.1% of the protocol balance
                sell_amount = Decimal("0.001") * pool_state.balances[sell_token.address]
                try:
                    amount_out, gas_used, _ = pool_state.get_amount_out(
                        sell_token, sell_amount, buy_token
                    )
                except Exception as e:
                    # Record the failure and keep going so that every token
                    # pair of every pool gets simulated.
                    print(
                        f"Error simulating get_amount_out for {pool_id}: {sell_token} -> {buy_token}. "
                        f"Error: {e}"
                    )
                    failed_simulations.setdefault(pool_id, []).append(
                        SimulationFailure(
                            pool_id=pool_id,
                            sell_token=str(sell_token),
                            buy_token=str(buy_token),
                            error=str(e),
                        )
                    )
                else:
                    print(
                        f"Amount out for {pool_id}: {sell_amount} {sell_token} -> {amount_out} {buy_token} - "
                        f"Gas used: {gas_used}"
                    )
        return failed_simulations

    @staticmethod
    def build_spkg(yaml_file_path: str, modify_func: Callable[[dict], None]) -> str:
        """Pack a Substreams package after temporarily modifying its manifest.

        The manifest is handled as follows:

        1. back it up;
        2. rewrite it with ``modify_func`` applied to its parsed content;
        3. pack it with ``substreams pack``;
        4. restore it from the backup, even when an earlier step raised.

        :param yaml_file_path: Path to the Substreams manifest YAML.
        :param modify_func: Called with the parsed manifest; mutates it in place.
        :return: Expected path of the ``.spkg`` next to the manifest. A failing
            ``substreams pack`` is only reported on stdout, not raised.
        """
        backup_file_path = f"{yaml_file_path}.backup"
        shutil.copy(yaml_file_path, backup_file_path)
        try:
            with open(yaml_file_path, "r", encoding="utf-8") as file:
                data = yaml.safe_load(file)
            modify_func(data)

            package = data["package"]
            spkg_file = (
                f"{package['name'].replace('_', '-', 1)}-{package['version']}.spkg"
            )
            spkg_name = os.path.join(os.path.dirname(yaml_file_path), spkg_file)

            with open(yaml_file_path, "w", encoding="utf-8") as file:
                yaml.dump(data, file, default_flow_style=False)

            try:
                result = subprocess.run(
                    ["substreams", "pack", yaml_file_path],
                    capture_output=True,
                    text=True,
                )
                if result.returncode != 0:
                    print("Substreams pack command failed:", result.stderr)
            except Exception as e:
                # Best effort: report and let the caller continue.
                print(f"Error running substreams pack command: {e}")
        finally:
            # Always put the original manifest back and drop the backup.
            shutil.copy(backup_file_path, yaml_file_path)
            Path(backup_file_path).unlink()
        return spkg_name

    @staticmethod
    def update_initial_block(data: dict, start_block: int) -> None:
        """Set ``initialBlock`` of every module in a parsed Substreams manifest.

        :param data: Parsed manifest; its ``modules`` entries are mutated in place.
        :param start_block: Block number to use as every module's initial block.
        """
        for module in data["modules"]:
            module["initialBlock"] = start_block