325 lines
12 KiB
Python
325 lines
12 KiB
Python
import itertools
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import traceback
|
|
from collections import defaultdict
|
|
from datetime import datetime
|
|
from decimal import Decimal
|
|
from pathlib import Path
|
|
from typing import List
|
|
|
|
import yaml
|
|
from protosim_py.evm.decoders import ThirdPartyPoolTychoDecoder
|
|
from protosim_py.evm.storage import TychoDBSingleton
|
|
from protosim_py.models import EVMBlock
|
|
from pydantic import BaseModel
|
|
from tycho_client.dto import (
|
|
Chain,
|
|
ProtocolComponentsParams,
|
|
ProtocolStateParams,
|
|
ContractStateParams,
|
|
ProtocolComponent,
|
|
ResponseProtocolState,
|
|
HexBytes,
|
|
ResponseAccount,
|
|
Snapshot,
|
|
ContractId,
|
|
)
|
|
from tycho_client.rpc_client import TychoRPCClient
|
|
|
|
from models import (
|
|
IntegrationTestsConfig,
|
|
ProtocolComponentWithTestConfig,
|
|
ProtocolComponentExpectation,
|
|
)
|
|
from adapter_handler import AdapterContractHandler
|
|
from evm import get_token_balance, get_block_header
|
|
from tycho import TychoRunner
|
|
from utils import build_snapshot_message, token_factory
|
|
|
|
|
|
class TestResult:
    """Outcome of a single integration test: a success flag plus an optional message."""

    def __init__(self, success: bool, message: "str | None" = None):
        # ``message`` is None for passing results and carries the failure reason otherwise.
        self.success = success
        self.message = message

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(success={self.success!r}, message={self.message!r})"
        )

    @classmethod
    def Passed(cls):
        """Build a successful result with no message."""
        return cls(success=True)

    @classmethod
    def Failed(cls, message: str):
        """Build a failed result carrying ``message`` as the failure reason."""
        return cls(success=False, message=message)
|
|
|
|
|
|
def parse_config(yaml_path: str) -> IntegrationTestsConfig:
    """Load the integration-test configuration from a YAML file.

    :param yaml_path: Path to the YAML configuration file.
    :return: The parsed ``IntegrationTestsConfig``.
    :raises ValueError: If the file is empty or its top level is not a mapping.
    """
    with open(yaml_path, "r", encoding="utf-8") as file:
        yaml_content = yaml.safe_load(file)
    # safe_load returns None for an empty document (and lists/scalars for other
    # shapes); fail with a clear message instead of the opaque TypeError that
    # ``**yaml_content`` would raise.
    if not isinstance(yaml_content, dict):
        raise ValueError(
            f"Expected a YAML mapping at the top level of {yaml_path}, "
            f"got {type(yaml_content).__name__}"
        )
    return IntegrationTestsConfig(**yaml_content)
|
|
|
|
|
|
class SimulationFailure(BaseModel):
    """One failed ``get_amount_out`` simulation for a single token pair of a pool."""

    # Identifier of the pool whose simulation failed.
    pool_id: str
    # String form of the token that was being sold.
    sell_token: str
    # String form of the token that was being bought.
    buy_token: str
    # String form of the exception raised by the simulation.
    error: str
|
|
|
|
|
|
class TestRunner:
    """Builds a Substreams package, runs Tycho over each configured test range and
    validates the indexed state (components, balances and swap simulations)."""

    def __init__(
        self, package: str, with_binary_logs: bool, db_url: str, vm_traces: bool
    ):
        """Load the package's integration-test config and set up the Tycho helpers.

        :param package: Name of the package directory under ``substreams/``.
        :param with_binary_logs: Forwarded to ``TychoRunner``.
        :param db_url: Database URL handed to ``TychoRunner``; emptied before each test.
        :param vm_traces: Passed as ``trace`` to the simulation decoder.
        """
        self.repo_root = os.getcwd()
        self.spkg_src = os.path.join(self.repo_root, "substreams", package)
        config_path = os.path.join(self.spkg_src, "integration_test.tycho.yaml")
        self.config: IntegrationTestsConfig = parse_config(config_path)
        self.adapters_src = os.path.join(self.repo_root, "evm")
        self.tycho_runner = TychoRunner(
            db_url, with_binary_logs, self.config.initialized_accounts
        )
        self.tycho_rpc_client = TychoRPCClient()
        self._token_factory_func = token_factory(self.tycho_rpc_client)
        self.db_url = db_url
        self._vm_traces = vm_traces
        self._chain = Chain.ethereum

    def run_tests(self) -> None:
        """Run all tests specified in the configuration.

        For each test: empty the database, build the Substreams package with every
        module's initial block set to the test's start block, run Tycho over
        ``[start_block, stop_block]``, then validate the resulting state through the
        RPC server. Outcomes are printed; nothing is returned.
        """
        print("Running tests ...")
        # The manifest path does not depend on the individual test.
        substreams_yaml = os.path.join(self.spkg_src, self.config.substreams_yaml_path)
        for test in self.config.tests:
            self.tycho_runner.empty_database(self.db_url)

            spkg_path = self.build_spkg(
                substreams_yaml,
                # Bind the start block explicitly instead of closing over the loop var.
                lambda data, start_block=test.start_block: self.update_initial_block(
                    data, start_block
                ),
            )
            self.tycho_runner.run_tycho(
                spkg_path,
                test.start_block,
                test.stop_block,
                self.config.protocol_type_names,
                test.initialized_accounts or [],
            )

            result = self.tycho_runner.run_with_rpc_server(
                self.validate_state, test.expected_components, test.stop_block
            )

            if result.success:
                print(f"✅ {test.name} passed.")
            else:
                print(f"❗️ {test.name} failed: {result.message}")

    def validate_state(
        self,
        expected_components: List[ProtocolComponentWithTestConfig],
        stop_block: int,
    ) -> TestResult:
        """Validate the current protocol state against the expected state.

        Checks, in order: every expected component exists and matches its
        expectation; (unless ``skip_balance_check``) each component token balance
        reported by Tycho equals the on-chain balance at ``stop_block``; and
        ``get_amount_out`` simulations succeed for expected components that do not
        skip simulation.

        :param expected_components: Components the test expects Tycho to have indexed.
        :param stop_block: Block at which balances and simulations are evaluated.
        :return: ``TestResult.Passed()`` or ``TestResult.Failed(<reason>)``; any
            exception raised during validation is converted into a failure.
        """
        protocol_components = self.tycho_rpc_client.get_protocol_components(
            ProtocolComponentsParams(protocol_system="test_protocol")
        )
        protocol_states = self.tycho_rpc_client.get_protocol_state(
            ProtocolStateParams(protocol_system="test_protocol")
        )

        try:
            # Expected ids are lower-cased before lookup, so the keys must be too.
            components_by_id = {
                component.id.lower(): component for component in protocol_components
            }
            # Index states once; keep the first state per component id.
            states_by_component_id = {}
            for state in protocol_states:
                states_by_component_id.setdefault(state.component_id.lower(), state)

            for expected_component in expected_components:
                comp_id = expected_component.id.lower()
                if comp_id not in components_by_id:
                    return TestResult.Failed(
                        f"'{comp_id}' not found in protocol components."
                    )

                diff = ProtocolComponentExpectation(
                    **components_by_id[comp_id].dict()
                ).compare(ProtocolComponentExpectation(**expected_component.dict()))
                if diff is not None:
                    return TestResult.Failed(diff)

            if not self.config.skip_balance_check:
                for component in protocol_components:
                    comp_id = component.id.lower()
                    state = states_by_component_id.get(comp_id)
                    balances = state.balances if state is not None else {}
                    for token in component.tokens:
                        # Tokens without a reported balance count as zero.
                        tycho_balance = int(balances.get(token, HexBytes("0x00")))
                        node_balance = get_token_balance(token, comp_id, stop_block)
                        if node_balance != tycho_balance:
                            return TestResult.Failed(
                                f"Balance mismatch for {comp_id}:{token} at block {stop_block}: got {node_balance} "
                                f"from rpc call and {tycho_balance} from Substreams"
                            )

            contract_states = self.tycho_rpc_client.get_contract_state(
                ContractStateParams(
                    contract_ids=[
                        ContractId(chain=self._chain, address=a)
                        for component in protocol_components
                        for a in component.contract_ids
                    ]
                )
            )

            # Built once as a set; compared case-insensitively like the lookups above.
            simulated_ids = {
                c.id.lower() for c in expected_components if c.skip_simulation is False
            }
            filtered_components = [
                pc for pc in protocol_components if pc.id.lower() in simulated_ids
            ]
            simulation_failures = self.simulate_get_amount_out(
                stop_block, protocol_states, filtered_components, contract_states
            )
            if simulation_failures:
                raise ValueError(self._format_simulation_failures(simulation_failures))

            return TestResult.Passed()
        except Exception as e:
            error_message = f"An error occurred: {str(e)}\n" + traceback.format_exc()
            return TestResult.Failed(error_message)

    @staticmethod
    def _format_simulation_failures(
        simulation_failures: dict[str, list[SimulationFailure]],
    ) -> str:
        """Render failed simulations as one message, one sentence per pool."""
        error_msgs = []
        for pool_id, failures in simulation_failures.items():
            pairs = [f"{f.sell_token} -> {f.buy_token}: {f.error}" for f in failures]
            error_msgs.append(f"Pool {pool_id} failed simulations: {', '.join(pairs)}")
        return ". ".join(error_msgs)

    def simulate_get_amount_out(
        self,
        block_number: int,
        protocol_states: list[ResponseProtocolState],
        protocol_components: list[ProtocolComponent],
        contract_states: list[ResponseAccount],
    ) -> dict[str, list[SimulationFailure]]:
        """Simulate swaps of 0.1% of each token's pool balance for every token pair.

        Builds the adapter contract if its runtime artifact is missing, decodes a
        snapshot of the given states into pool states and calls ``get_amount_out``
        for every ordered token pair of every decoded pool.

        :param block_number: Block whose header is used for the simulation block.
        :param protocol_states: Protocol states returned by Tycho.
        :param protocol_components: Components to simulate.
        :param contract_states: Contract states returned by Tycho.
        :return: Failed simulations grouped by pool id (empty when all succeed or
            when no protocol type names are configured).
        :raises ValueError: If a decoded pool has no balances.
        """
        TychoDBSingleton.initialize()

        block_header = get_block_header(block_number)
        # NOTE(review): fromtimestamp() yields a naive datetime in the host's local
        # timezone; confirm that is what EVMBlock expects.
        block: EVMBlock = EVMBlock(
            id=block_number,
            ts=datetime.fromtimestamp(block_header.timestamp),
            hash_=block_header.hash.hex(),
        )

        failed_simulations: dict[str, list[SimulationFailure]] = dict()
        # Nothing below depends on the individual protocol type, so the simulation
        # runs once (repeating it only duplicated work and failure entries). With
        # no configured protocol types there is nothing to simulate.
        if not self.config.protocol_type_names:
            return failed_simulations

        adapter_contract = os.path.join(
            self.adapters_src,
            "out",
            f"{self.config.adapter_contract}.sol",
            f"{self.config.adapter_contract}.evm.runtime",
        )
        if not os.path.exists(adapter_contract):
            print("Adapter contract not found. Building it ...")

            AdapterContractHandler.build_target(
                self.adapters_src,
                self.config.adapter_contract,
                self.config.adapter_build_signature,
                self.config.adapter_build_args,
            )

        decoder = ThirdPartyPoolTychoDecoder(
            token_factory_func=self._token_factory_func,
            adapter_contract=adapter_contract,
            minimum_gas=0,
            trace=self._vm_traces,
        )

        snapshot_message: Snapshot = build_snapshot_message(
            protocol_states, protocol_components, contract_states
        )

        decoded = decoder.decode_snapshot(snapshot_message, block)

        for pool_state in decoded.values():
            pool_id = pool_state.id_
            if not pool_state.balances:
                raise ValueError(f"Missing balances for pool {pool_id}")
            for sell_token, buy_token in itertools.permutations(pool_state.tokens, 2):
                # Try to sell 0.1% of the protocol balance
                sell_amount = Decimal("0.001") * pool_state.balances[sell_token.address]
                try:
                    amount_out, gas_used, _ = pool_state.get_amount_out(
                        sell_token, sell_amount, buy_token
                    )
                except Exception as e:
                    print(
                        f"Error simulating get_amount_out for {pool_id}: {sell_token} -> {buy_token}. "
                        f"Error: {e}"
                    )
                    failed_simulations.setdefault(pool_id, []).append(
                        SimulationFailure(
                            pool_id=pool_id,
                            sell_token=str(sell_token),
                            buy_token=str(buy_token),
                            error=str(e),
                        )
                    )
                else:
                    print(
                        f"Amount out for {pool_id}: {sell_amount} {sell_token} -> {amount_out} {buy_token} - "
                        f"Gas used: {gas_used}"
                    )
        return failed_simulations

    @staticmethod
    def build_spkg(yaml_file_path: str, modify_func: callable) -> str:
        """Build a Substreams package with modifications to the YAML file.

        The manifest is backed up, loaded, passed to ``modify_func`` (which mutates
        it in place), written back and packed with ``substreams pack``. The original
        manifest is always restored afterwards, even if a step fails. Packing
        failures are printed, not raised.

        :param yaml_file_path: Path to the Substreams manifest.
        :param modify_func: Callable receiving the loaded manifest dict.
        :return: Expected path of the packed ``.spkg`` next to the manifest.
        """
        backup_file_path = f"{yaml_file_path}.backup"
        shutil.copy(yaml_file_path, backup_file_path)

        try:
            with open(yaml_file_path, "r", encoding="utf-8") as file:
                data = yaml.safe_load(file)

            modify_func(data)
            package = data["package"]
            spkg_file = (
                f"{package['name'].replace('_', '-', 1)}-{package['version']}.spkg"
            )
            spkg_name = os.path.join(os.path.dirname(yaml_file_path), spkg_file)

            with open(yaml_file_path, "w", encoding="utf-8") as file:
                yaml.dump(data, file, default_flow_style=False)

            try:
                result = subprocess.run(
                    ["substreams", "pack", yaml_file_path],
                    capture_output=True,
                    text=True,
                )
                if result.returncode != 0:
                    print("Substreams pack command failed:", result.stderr)
            except Exception as e:
                print(f"Error running substreams pack command: {e}")
        finally:
            # Restore the untouched manifest and drop the backup no matter what.
            shutil.copy(backup_file_path, yaml_file_path)
            Path(backup_file_path).unlink()

        return spkg_name

    @staticmethod
    def update_initial_block(data: dict, start_block: int) -> None:
        """Update the initial block for all modules in the configuration data."""
        for module in data["modules"]:
            module["initialBlock"] = start_block
|