refactor(substreams-testing): Use Pydantic to deserialize test_assets.yaml

This commit is contained in:
Florian Pellissier
2024-08-06 13:32:58 +02:00
committed by tvinagre
parent 09d266a810
commit 95efda0423
10 changed files with 365 additions and 233 deletions

View File

@@ -7,6 +7,7 @@ from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from pathlib import Path
from typing import List
import yaml
from protosim_py.evm.decoders import ThirdPartyPoolTychoDecoder
@@ -27,6 +28,11 @@ from tycho_client.dto import (
)
from tycho_client.rpc_client import TychoRPCClient
from models import (
IntegrationTestsConfig,
ProtocolComponentWithTestConfig,
ProtocolComponentExpectation,
)
from adapter_handler import AdapterContractHandler
from evm import get_token_balance, get_block_header
from tycho import TychoRunner
@@ -47,10 +53,10 @@ class TestResult:
return cls(success=False, message=message)
def load_config(yaml_path: str) -> dict:
"""Load YAML configuration from a specified file path."""
def parse_config(yaml_path: str) -> IntegrationTestsConfig:
with open(yaml_path, "r") as file:
return yaml.safe_load(file)
yaml_content = yaml.safe_load(file)
return IntegrationTestsConfig(**yaml_content)
class SimulationFailure(BaseModel):
@@ -66,13 +72,13 @@ class TestRunner:
):
self.repo_root = os.getcwd()
config_path = os.path.join(
self.repo_root, "substreams", package, "test_assets.yaml"
self.repo_root, "substreams", package, "integration_test.tycho.yaml"
)
self.config = load_config(config_path)
self.config: IntegrationTestsConfig = parse_config(config_path)
self.spkg_src = os.path.join(self.repo_root, "substreams", package)
self.adapters_src = os.path.join(self.repo_root, "evm")
self.tycho_runner = TychoRunner(
db_url, with_binary_logs, self.config["initialized_accounts"]
db_url, with_binary_logs, self.config.initialized_accounts
)
self.tycho_rpc_client = TychoRPCClient()
self._token_factory_func = token_factory(self.tycho_rpc_client)
@@ -83,32 +89,35 @@ class TestRunner:
def run_tests(self) -> None:
"""Run all tests specified in the configuration."""
print(f"Running tests ...")
for test in self.config["tests"]:
for test in self.config.tests:
self.tycho_runner.empty_database(self.db_url)
spkg_path = self.build_spkg(
os.path.join(self.spkg_src, self.config["substreams_yaml_path"]),
lambda data: self.update_initial_block(data, test["start_block"]),
os.path.join(self.spkg_src, self.config.substreams_yaml_path),
lambda data: self.update_initial_block(data, test.start_block),
)
self.tycho_runner.run_tycho(
spkg_path,
test["start_block"],
test["stop_block"],
self.config["protocol_type_names"],
test.get("initialized_accounts", []),
test.start_block,
test.stop_block,
self.config.protocol_type_names,
test.initialized_accounts or [],
)
result = self.tycho_runner.run_with_rpc_server(
self.validate_state, test["expected_state"], test["stop_block"]
self.validate_state, test.expected_components, test.stop_block
)
if result.success:
print(f"{test['name']} passed.")
print(f"{test.name} passed.")
else:
print(f"❗️ {test['name']} failed: {result.message}")
print(f"❗️ {test.name} failed: {result.message}")
def validate_state(self, expected_state: dict, stop_block: int) -> TestResult:
def validate_state(
self,
expected_components: List[ProtocolComponentWithTestConfig],
stop_block: int,
) -> TestResult:
"""Validate the current protocol state against the expected state."""
protocol_components = self.tycho_rpc_client.get_protocol_components(
ProtocolComponentsParams(protocol_system="test_protocol")
@@ -121,43 +130,18 @@ class TestRunner:
}
try:
for expected_component in expected_state.get("protocol_components", []):
comp_id = expected_component["id"].lower()
for expected_component in expected_components:
comp_id = expected_component.id.lower()
if comp_id not in components_by_id:
return TestResult.Failed(
f"'{comp_id}' not found in protocol components."
)
# TODO: Manipulate pydantic objects instead of dict
component = components_by_id[comp_id].dict()
for key, value in expected_component.items():
if key not in ["tokens", "static_attributes", "creation_tx"]:
continue
if key not in component:
return TestResult.Failed(
f"Missing '{key}' in component '{comp_id}'."
)
if key == "tokens":
if set(map(HexBytes, value)) != set(component[key]):
return TestResult.Failed(
f"Token mismatch for key '{key}': {value} != {component[key]}"
)
elif key == "creation_tx":
if HexBytes(value) != component[key]:
return TestResult.Failed(
f"Creation tx mismatch for key '{key}': {value} != {component[key]}"
)
elif isinstance(value, list):
if set(map(str.lower, value)) != set(
map(str.lower, component[key])
):
return TestResult.Failed(
f"List mismatch for key '{key}': {value} != {component[key]}"
)
elif value is not None and value.lower() != component[key]:
return TestResult.Failed(
f"Value mismatch for key '{key}': {value} != {component[key]}"
)
diff = ProtocolComponentExpectation.from_dto(
components_by_id[comp_id]
).compare(expected_component.into_protocol_component())
if diff is not None:
return TestResult.Failed(diff)
token_balances: dict[str, dict[HexBytes, int]] = defaultdict(dict)
for component in protocol_components:
@@ -178,7 +162,7 @@ class TestRunner:
tycho_balance = int(balance_hex)
token_balances[comp_id][token] = tycho_balance
if self.config["skip_balance_check"] is not True:
if self.config.skip_balance_check is not True:
node_balance = get_token_balance(token, comp_id, stop_block)
if node_balance != tycho_balance:
return TestResult.Failed(
@@ -198,11 +182,7 @@ class TestRunner:
pc
for pc in protocol_components
if pc.id
in [
c["id"].lower()
for c in expected_state["protocol_components"]
if c.get("skip_simulation", False) is False
]
in [c.id for c in expected_components if c.skip_simulation is False]
]
simulation_failures = self.simulate_get_amount_out(
stop_block, protocol_states, filtered_components, contract_states
@@ -231,7 +211,7 @@ class TestRunner:
contract_states: list[ResponseAccount],
) -> dict[str, list[SimulationFailure]]:
TychoDBSingleton.initialize()
protocol_type_names = self.config["protocol_type_names"]
protocol_type_names = self.config.protocol_type_names
block_header = get_block_header(block_number)
block: EVMBlock = EVMBlock(
@@ -245,17 +225,17 @@ class TestRunner:
adapter_contract = os.path.join(
self.adapters_src,
"out",
f"{self.config['adapter_contract']}.sol",
f"{self.config['adapter_contract']}.evm.runtime",
f"{self.config.adapter_contract}.sol",
f"{self.config.adapter_contract}.evm.runtime",
)
if not os.path.exists(adapter_contract):
print("Adapter contract not found. Building it ...")
AdapterContractHandler.build_target(
self.adapters_src,
self.config["adapter_contract"],
self.config["adapter_build_signature"],
self.config["adapter_build_args"],
self.config.adapter_contract,
self.config.adapter_build_signature,
self.config.adapter_build_args,
)
decoder = ThirdPartyPoolTychoDecoder(