refactor(substreams-testing): Use Pydantic to deserialize test_assets.yaml
This commit is contained in:
committed by
tvinagre
parent
09d266a810
commit
95efda0423
@@ -14,7 +14,8 @@ The testing suite builds the `.spkg` for your Substreams module, indexes a speci
|
||||
|
||||
## Test Configuration
|
||||
|
||||
Tests are defined in a `yaml` file. A template can be found at `substreams/ethereum-template/test_assets.yaml`. The configuration file should include:
|
||||
Tests are defined in a `yaml` file. A template can be found at
|
||||
`substreams/ethereum-template/integration_test.tycho.yaml`. The configuration file should include:
|
||||
|
||||
- The target Substreams config file.
|
||||
- The expected protocol types.
|
||||
|
||||
104
testing/src/runner/models.py
Normal file
104
testing/src/runner/models.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from tycho_client.dto import ProtocolComponent
|
||||
|
||||
|
||||
class ProtocolComponentExpectation(BaseModel):
|
||||
"""Represents a ProtocolComponent with its main attributes."""
|
||||
|
||||
id: str = Field(..., description="Identifier of the protocol component")
|
||||
tokens: List[str] = Field(
|
||||
...,
|
||||
description="List of token addresses associated with the protocol component",
|
||||
)
|
||||
static_attributes: Optional[Dict[str, Optional[str]]] = Field(
|
||||
default_factory=dict, description="Static attributes of the protocol component"
|
||||
)
|
||||
creation_tx: str = Field(
|
||||
..., description="Hash of the transaction that created the protocol component"
|
||||
)
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self.id = self.id.lower()
|
||||
self.tokens = sorted([t.lower() for t in self.tokens])
|
||||
|
||||
def compare(self, other: "ProtocolComponentExpectation") -> Optional[str]:
|
||||
"""Compares the current instance with another ProtocolComponent instance and returns a message with the differences or None if there are no differences."""
|
||||
differences = []
|
||||
for field_name, field_value in self.__dict__.items():
|
||||
other_value = getattr(other, field_name, None)
|
||||
if field_value != other_value:
|
||||
differences.append(
|
||||
f"Field '{field_name}' mismatch: '{field_value}' != '{other_value}'"
|
||||
)
|
||||
if not differences:
|
||||
return None
|
||||
|
||||
return "\n".join(differences)
|
||||
|
||||
@staticmethod
|
||||
def from_dto(dto: ProtocolComponent) -> "ProtocolComponentExpectation":
|
||||
return ProtocolComponentExpectation(
|
||||
id=dto.id,
|
||||
tokens=[t.hex() for t in dto.tokens],
|
||||
static_attributes={
|
||||
key: value.hex() for key, value in dto.static_attributes.items()
|
||||
},
|
||||
creation_tx=dto.creation_tx.hex(),
|
||||
)
|
||||
|
||||
|
||||
class ProtocolComponentWithTestConfig(ProtocolComponentExpectation):
|
||||
"""Represents a ProtocolComponent with its main attributes and test configuration."""
|
||||
|
||||
skip_simulation: Optional[bool] = Field(
|
||||
False,
|
||||
description="Flag indicating whether to skip simulation for this component",
|
||||
)
|
||||
|
||||
def into_protocol_component(self) -> ProtocolComponentExpectation:
|
||||
return ProtocolComponentExpectation(**self.dict())
|
||||
|
||||
|
||||
class IntegrationTest(BaseModel):
|
||||
"""Configuration for an individual test."""
|
||||
|
||||
name: str = Field(..., description="Name of the test")
|
||||
start_block: int = Field(..., description="Starting block number for the test")
|
||||
stop_block: int = Field(..., description="Stopping block number for the test")
|
||||
initialized_accounts: Optional[List[str]] = Field(
|
||||
None, description="List of initialized account addresses"
|
||||
)
|
||||
expected_components: List[ProtocolComponentWithTestConfig] = Field(
|
||||
..., description="List of protocol components expected in the indexed state"
|
||||
)
|
||||
|
||||
|
||||
class IntegrationTestsConfig(BaseModel):
|
||||
"""Main integration test configuration."""
|
||||
|
||||
substreams_yaml_path: str = Field(
|
||||
"./substreams.yaml", description="Path of the Substreams YAML file"
|
||||
)
|
||||
adapter_contract: str = Field(
|
||||
..., description="Name of the SwapAdapter contract for this protocol"
|
||||
)
|
||||
adapter_build_signature: Optional[str] = Field(
|
||||
None, description="Signatre of the SwapAdapter constructor for this protocol"
|
||||
)
|
||||
adapter_build_args: Optional[str] = Field(
|
||||
None, description="Arguments for the SwapAdapter constructor for this protocol"
|
||||
)
|
||||
initialized_accounts: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="List of initialized account addresses. These accounts will be initialized for every tests",
|
||||
)
|
||||
skip_balance_check: bool = Field(
|
||||
..., description="Flag to skip balance check for all tests"
|
||||
)
|
||||
protocol_type_names: List[str] = Field(
|
||||
..., description="List of protocol type names for the tested protocol"
|
||||
)
|
||||
tests: List[IntegrationTest] = Field(..., description="List of integration tests")
|
||||
@@ -7,6 +7,7 @@ from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import yaml
|
||||
from protosim_py.evm.decoders import ThirdPartyPoolTychoDecoder
|
||||
@@ -27,6 +28,11 @@ from tycho_client.dto import (
|
||||
)
|
||||
from tycho_client.rpc_client import TychoRPCClient
|
||||
|
||||
from models import (
|
||||
IntegrationTestsConfig,
|
||||
ProtocolComponentWithTestConfig,
|
||||
ProtocolComponentExpectation,
|
||||
)
|
||||
from adapter_handler import AdapterContractHandler
|
||||
from evm import get_token_balance, get_block_header
|
||||
from tycho import TychoRunner
|
||||
@@ -47,10 +53,10 @@ class TestResult:
|
||||
return cls(success=False, message=message)
|
||||
|
||||
|
||||
def load_config(yaml_path: str) -> dict:
|
||||
"""Load YAML configuration from a specified file path."""
|
||||
def parse_config(yaml_path: str) -> IntegrationTestsConfig:
|
||||
with open(yaml_path, "r") as file:
|
||||
return yaml.safe_load(file)
|
||||
yaml_content = yaml.safe_load(file)
|
||||
return IntegrationTestsConfig(**yaml_content)
|
||||
|
||||
|
||||
class SimulationFailure(BaseModel):
|
||||
@@ -66,13 +72,13 @@ class TestRunner:
|
||||
):
|
||||
self.repo_root = os.getcwd()
|
||||
config_path = os.path.join(
|
||||
self.repo_root, "substreams", package, "test_assets.yaml"
|
||||
self.repo_root, "substreams", package, "integration_test.tycho.yaml"
|
||||
)
|
||||
self.config = load_config(config_path)
|
||||
self.config: IntegrationTestsConfig = parse_config(config_path)
|
||||
self.spkg_src = os.path.join(self.repo_root, "substreams", package)
|
||||
self.adapters_src = os.path.join(self.repo_root, "evm")
|
||||
self.tycho_runner = TychoRunner(
|
||||
db_url, with_binary_logs, self.config["initialized_accounts"]
|
||||
db_url, with_binary_logs, self.config.initialized_accounts
|
||||
)
|
||||
self.tycho_rpc_client = TychoRPCClient()
|
||||
self._token_factory_func = token_factory(self.tycho_rpc_client)
|
||||
@@ -83,32 +89,35 @@ class TestRunner:
|
||||
def run_tests(self) -> None:
|
||||
"""Run all tests specified in the configuration."""
|
||||
print(f"Running tests ...")
|
||||
for test in self.config["tests"]:
|
||||
for test in self.config.tests:
|
||||
self.tycho_runner.empty_database(self.db_url)
|
||||
|
||||
spkg_path = self.build_spkg(
|
||||
os.path.join(self.spkg_src, self.config["substreams_yaml_path"]),
|
||||
lambda data: self.update_initial_block(data, test["start_block"]),
|
||||
os.path.join(self.spkg_src, self.config.substreams_yaml_path),
|
||||
lambda data: self.update_initial_block(data, test.start_block),
|
||||
)
|
||||
self.tycho_runner.run_tycho(
|
||||
spkg_path,
|
||||
test["start_block"],
|
||||
test["stop_block"],
|
||||
self.config["protocol_type_names"],
|
||||
test.get("initialized_accounts", []),
|
||||
test.start_block,
|
||||
test.stop_block,
|
||||
self.config.protocol_type_names,
|
||||
test.initialized_accounts or [],
|
||||
)
|
||||
|
||||
result = self.tycho_runner.run_with_rpc_server(
|
||||
self.validate_state, test["expected_state"], test["stop_block"]
|
||||
self.validate_state, test.expected_components, test.stop_block
|
||||
)
|
||||
|
||||
if result.success:
|
||||
print(f"✅ {test['name']} passed.")
|
||||
|
||||
print(f"✅ {test.name} passed.")
|
||||
else:
|
||||
print(f"❗️ {test['name']} failed: {result.message}")
|
||||
print(f"❗️ {test.name} failed: {result.message}")
|
||||
|
||||
def validate_state(self, expected_state: dict, stop_block: int) -> TestResult:
|
||||
def validate_state(
|
||||
self,
|
||||
expected_components: List[ProtocolComponentWithTestConfig],
|
||||
stop_block: int,
|
||||
) -> TestResult:
|
||||
"""Validate the current protocol state against the expected state."""
|
||||
protocol_components = self.tycho_rpc_client.get_protocol_components(
|
||||
ProtocolComponentsParams(protocol_system="test_protocol")
|
||||
@@ -121,43 +130,18 @@ class TestRunner:
|
||||
}
|
||||
|
||||
try:
|
||||
for expected_component in expected_state.get("protocol_components", []):
|
||||
comp_id = expected_component["id"].lower()
|
||||
for expected_component in expected_components:
|
||||
comp_id = expected_component.id.lower()
|
||||
if comp_id not in components_by_id:
|
||||
return TestResult.Failed(
|
||||
f"'{comp_id}' not found in protocol components."
|
||||
)
|
||||
|
||||
# TODO: Manipulate pydantic objects instead of dict
|
||||
component = components_by_id[comp_id].dict()
|
||||
for key, value in expected_component.items():
|
||||
if key not in ["tokens", "static_attributes", "creation_tx"]:
|
||||
continue
|
||||
if key not in component:
|
||||
return TestResult.Failed(
|
||||
f"Missing '{key}' in component '{comp_id}'."
|
||||
)
|
||||
if key == "tokens":
|
||||
if set(map(HexBytes, value)) != set(component[key]):
|
||||
return TestResult.Failed(
|
||||
f"Token mismatch for key '{key}': {value} != {component[key]}"
|
||||
)
|
||||
elif key == "creation_tx":
|
||||
if HexBytes(value) != component[key]:
|
||||
return TestResult.Failed(
|
||||
f"Creation tx mismatch for key '{key}': {value} != {component[key]}"
|
||||
)
|
||||
elif isinstance(value, list):
|
||||
if set(map(str.lower, value)) != set(
|
||||
map(str.lower, component[key])
|
||||
):
|
||||
return TestResult.Failed(
|
||||
f"List mismatch for key '{key}': {value} != {component[key]}"
|
||||
)
|
||||
elif value is not None and value.lower() != component[key]:
|
||||
return TestResult.Failed(
|
||||
f"Value mismatch for key '{key}': {value} != {component[key]}"
|
||||
)
|
||||
diff = ProtocolComponentExpectation.from_dto(
|
||||
components_by_id[comp_id]
|
||||
).compare(expected_component.into_protocol_component())
|
||||
if diff is not None:
|
||||
return TestResult.Failed(diff)
|
||||
|
||||
token_balances: dict[str, dict[HexBytes, int]] = defaultdict(dict)
|
||||
for component in protocol_components:
|
||||
@@ -178,7 +162,7 @@ class TestRunner:
|
||||
tycho_balance = int(balance_hex)
|
||||
token_balances[comp_id][token] = tycho_balance
|
||||
|
||||
if self.config["skip_balance_check"] is not True:
|
||||
if self.config.skip_balance_check is not True:
|
||||
node_balance = get_token_balance(token, comp_id, stop_block)
|
||||
if node_balance != tycho_balance:
|
||||
return TestResult.Failed(
|
||||
@@ -198,11 +182,7 @@ class TestRunner:
|
||||
pc
|
||||
for pc in protocol_components
|
||||
if pc.id
|
||||
in [
|
||||
c["id"].lower()
|
||||
for c in expected_state["protocol_components"]
|
||||
if c.get("skip_simulation", False) is False
|
||||
]
|
||||
in [c.id for c in expected_components if c.skip_simulation is False]
|
||||
]
|
||||
simulation_failures = self.simulate_get_amount_out(
|
||||
stop_block, protocol_states, filtered_components, contract_states
|
||||
@@ -231,7 +211,7 @@ class TestRunner:
|
||||
contract_states: list[ResponseAccount],
|
||||
) -> dict[str, list[SimulationFailure]]:
|
||||
TychoDBSingleton.initialize()
|
||||
protocol_type_names = self.config["protocol_type_names"]
|
||||
protocol_type_names = self.config.protocol_type_names
|
||||
|
||||
block_header = get_block_header(block_number)
|
||||
block: EVMBlock = EVMBlock(
|
||||
@@ -245,17 +225,17 @@ class TestRunner:
|
||||
adapter_contract = os.path.join(
|
||||
self.adapters_src,
|
||||
"out",
|
||||
f"{self.config['adapter_contract']}.sol",
|
||||
f"{self.config['adapter_contract']}.evm.runtime",
|
||||
f"{self.config.adapter_contract}.sol",
|
||||
f"{self.config.adapter_contract}.evm.runtime",
|
||||
)
|
||||
if not os.path.exists(adapter_contract):
|
||||
print("Adapter contract not found. Building it ...")
|
||||
|
||||
AdapterContractHandler.build_target(
|
||||
self.adapters_src,
|
||||
self.config["adapter_contract"],
|
||||
self.config["adapter_build_signature"],
|
||||
self.config["adapter_build_args"],
|
||||
self.config.adapter_contract,
|
||||
self.config.adapter_build_signature,
|
||||
self.config.adapter_build_args,
|
||||
)
|
||||
|
||||
decoder = ThirdPartyPoolTychoDecoder(
|
||||
|
||||
@@ -82,11 +82,14 @@ class TychoRunner:
|
||||
"--stop-block",
|
||||
# +2 is to make up for the cache in the index side.
|
||||
str(end_block + 2),
|
||||
"--initialization-block",
|
||||
str(start_block),
|
||||
]
|
||||
+ (
|
||||
["--initialized-accounts", ",".join(all_accounts)]
|
||||
[
|
||||
"--initialized-accounts",
|
||||
",".join(all_accounts),
|
||||
"--initialization-block",
|
||||
str(start_block),
|
||||
]
|
||||
if all_accounts
|
||||
else []
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user