refactor(substreams-testing): Use Pydantic to deserialize test_assets.yaml
This commit is contained in:
committed by
tvinagre
parent
09d266a810
commit
95efda0423
@@ -7,6 +7,7 @@ from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import yaml
|
||||
from protosim_py.evm.decoders import ThirdPartyPoolTychoDecoder
|
||||
@@ -27,6 +28,11 @@ from tycho_client.dto import (
|
||||
)
|
||||
from tycho_client.rpc_client import TychoRPCClient
|
||||
|
||||
from models import (
|
||||
IntegrationTestsConfig,
|
||||
ProtocolComponentWithTestConfig,
|
||||
ProtocolComponentExpectation,
|
||||
)
|
||||
from adapter_handler import AdapterContractHandler
|
||||
from evm import get_token_balance, get_block_header
|
||||
from tycho import TychoRunner
|
||||
@@ -47,10 +53,10 @@ class TestResult:
|
||||
return cls(success=False, message=message)
|
||||
|
||||
|
||||
def load_config(yaml_path: str) -> dict:
|
||||
"""Load YAML configuration from a specified file path."""
|
||||
def parse_config(yaml_path: str) -> IntegrationTestsConfig:
|
||||
with open(yaml_path, "r") as file:
|
||||
return yaml.safe_load(file)
|
||||
yaml_content = yaml.safe_load(file)
|
||||
return IntegrationTestsConfig(**yaml_content)
|
||||
|
||||
|
||||
class SimulationFailure(BaseModel):
|
||||
@@ -66,13 +72,13 @@ class TestRunner:
|
||||
):
|
||||
self.repo_root = os.getcwd()
|
||||
config_path = os.path.join(
|
||||
self.repo_root, "substreams", package, "test_assets.yaml"
|
||||
self.repo_root, "substreams", package, "integration_test.tycho.yaml"
|
||||
)
|
||||
self.config = load_config(config_path)
|
||||
self.config: IntegrationTestsConfig = parse_config(config_path)
|
||||
self.spkg_src = os.path.join(self.repo_root, "substreams", package)
|
||||
self.adapters_src = os.path.join(self.repo_root, "evm")
|
||||
self.tycho_runner = TychoRunner(
|
||||
db_url, with_binary_logs, self.config["initialized_accounts"]
|
||||
db_url, with_binary_logs, self.config.initialized_accounts
|
||||
)
|
||||
self.tycho_rpc_client = TychoRPCClient()
|
||||
self._token_factory_func = token_factory(self.tycho_rpc_client)
|
||||
@@ -83,32 +89,35 @@ class TestRunner:
|
||||
def run_tests(self) -> None:
|
||||
"""Run all tests specified in the configuration."""
|
||||
print(f"Running tests ...")
|
||||
for test in self.config["tests"]:
|
||||
for test in self.config.tests:
|
||||
self.tycho_runner.empty_database(self.db_url)
|
||||
|
||||
spkg_path = self.build_spkg(
|
||||
os.path.join(self.spkg_src, self.config["substreams_yaml_path"]),
|
||||
lambda data: self.update_initial_block(data, test["start_block"]),
|
||||
os.path.join(self.spkg_src, self.config.substreams_yaml_path),
|
||||
lambda data: self.update_initial_block(data, test.start_block),
|
||||
)
|
||||
self.tycho_runner.run_tycho(
|
||||
spkg_path,
|
||||
test["start_block"],
|
||||
test["stop_block"],
|
||||
self.config["protocol_type_names"],
|
||||
test.get("initialized_accounts", []),
|
||||
test.start_block,
|
||||
test.stop_block,
|
||||
self.config.protocol_type_names,
|
||||
test.initialized_accounts or [],
|
||||
)
|
||||
|
||||
result = self.tycho_runner.run_with_rpc_server(
|
||||
self.validate_state, test["expected_state"], test["stop_block"]
|
||||
self.validate_state, test.expected_components, test.stop_block
|
||||
)
|
||||
|
||||
if result.success:
|
||||
print(f"✅ {test['name']} passed.")
|
||||
|
||||
print(f"✅ {test.name} passed.")
|
||||
else:
|
||||
print(f"❗️ {test['name']} failed: {result.message}")
|
||||
print(f"❗️ {test.name} failed: {result.message}")
|
||||
|
||||
def validate_state(self, expected_state: dict, stop_block: int) -> TestResult:
|
||||
def validate_state(
|
||||
self,
|
||||
expected_components: List[ProtocolComponentWithTestConfig],
|
||||
stop_block: int,
|
||||
) -> TestResult:
|
||||
"""Validate the current protocol state against the expected state."""
|
||||
protocol_components = self.tycho_rpc_client.get_protocol_components(
|
||||
ProtocolComponentsParams(protocol_system="test_protocol")
|
||||
@@ -121,43 +130,18 @@ class TestRunner:
|
||||
}
|
||||
|
||||
try:
|
||||
for expected_component in expected_state.get("protocol_components", []):
|
||||
comp_id = expected_component["id"].lower()
|
||||
for expected_component in expected_components:
|
||||
comp_id = expected_component.id.lower()
|
||||
if comp_id not in components_by_id:
|
||||
return TestResult.Failed(
|
||||
f"'{comp_id}' not found in protocol components."
|
||||
)
|
||||
|
||||
# TODO: Manipulate pydantic objects instead of dict
|
||||
component = components_by_id[comp_id].dict()
|
||||
for key, value in expected_component.items():
|
||||
if key not in ["tokens", "static_attributes", "creation_tx"]:
|
||||
continue
|
||||
if key not in component:
|
||||
return TestResult.Failed(
|
||||
f"Missing '{key}' in component '{comp_id}'."
|
||||
)
|
||||
if key == "tokens":
|
||||
if set(map(HexBytes, value)) != set(component[key]):
|
||||
return TestResult.Failed(
|
||||
f"Token mismatch for key '{key}': {value} != {component[key]}"
|
||||
)
|
||||
elif key == "creation_tx":
|
||||
if HexBytes(value) != component[key]:
|
||||
return TestResult.Failed(
|
||||
f"Creation tx mismatch for key '{key}': {value} != {component[key]}"
|
||||
)
|
||||
elif isinstance(value, list):
|
||||
if set(map(str.lower, value)) != set(
|
||||
map(str.lower, component[key])
|
||||
):
|
||||
return TestResult.Failed(
|
||||
f"List mismatch for key '{key}': {value} != {component[key]}"
|
||||
)
|
||||
elif value is not None and value.lower() != component[key]:
|
||||
return TestResult.Failed(
|
||||
f"Value mismatch for key '{key}': {value} != {component[key]}"
|
||||
)
|
||||
diff = ProtocolComponentExpectation.from_dto(
|
||||
components_by_id[comp_id]
|
||||
).compare(expected_component.into_protocol_component())
|
||||
if diff is not None:
|
||||
return TestResult.Failed(diff)
|
||||
|
||||
token_balances: dict[str, dict[HexBytes, int]] = defaultdict(dict)
|
||||
for component in protocol_components:
|
||||
@@ -178,7 +162,7 @@ class TestRunner:
|
||||
tycho_balance = int(balance_hex)
|
||||
token_balances[comp_id][token] = tycho_balance
|
||||
|
||||
if self.config["skip_balance_check"] is not True:
|
||||
if self.config.skip_balance_check is not True:
|
||||
node_balance = get_token_balance(token, comp_id, stop_block)
|
||||
if node_balance != tycho_balance:
|
||||
return TestResult.Failed(
|
||||
@@ -198,11 +182,7 @@ class TestRunner:
|
||||
pc
|
||||
for pc in protocol_components
|
||||
if pc.id
|
||||
in [
|
||||
c["id"].lower()
|
||||
for c in expected_state["protocol_components"]
|
||||
if c.get("skip_simulation", False) is False
|
||||
]
|
||||
in [c.id for c in expected_components if c.skip_simulation is False]
|
||||
]
|
||||
simulation_failures = self.simulate_get_amount_out(
|
||||
stop_block, protocol_states, filtered_components, contract_states
|
||||
@@ -231,7 +211,7 @@ class TestRunner:
|
||||
contract_states: list[ResponseAccount],
|
||||
) -> dict[str, list[SimulationFailure]]:
|
||||
TychoDBSingleton.initialize()
|
||||
protocol_type_names = self.config["protocol_type_names"]
|
||||
protocol_type_names = self.config.protocol_type_names
|
||||
|
||||
block_header = get_block_header(block_number)
|
||||
block: EVMBlock = EVMBlock(
|
||||
@@ -245,17 +225,17 @@ class TestRunner:
|
||||
adapter_contract = os.path.join(
|
||||
self.adapters_src,
|
||||
"out",
|
||||
f"{self.config['adapter_contract']}.sol",
|
||||
f"{self.config['adapter_contract']}.evm.runtime",
|
||||
f"{self.config.adapter_contract}.sol",
|
||||
f"{self.config.adapter_contract}.evm.runtime",
|
||||
)
|
||||
if not os.path.exists(adapter_contract):
|
||||
print("Adapter contract not found. Building it ...")
|
||||
|
||||
AdapterContractHandler.build_target(
|
||||
self.adapters_src,
|
||||
self.config["adapter_contract"],
|
||||
self.config["adapter_build_signature"],
|
||||
self.config["adapter_build_args"],
|
||||
self.config.adapter_contract,
|
||||
self.config.adapter_build_signature,
|
||||
self.config.adapter_build_args,
|
||||
)
|
||||
|
||||
decoder = ThirdPartyPoolTychoDecoder(
|
||||
|
||||
Reference in New Issue
Block a user