refactor(substreams-testing): Use Pydantic to deserialize test_assets.yaml

This commit is contained in:
Florian Pellissier
2024-08-06 13:32:58 +02:00
committed by tvinagre
parent 09d266a810
commit 95efda0423
10 changed files with 365 additions and 233 deletions

View File

@@ -0,0 +1,104 @@
from pydantic import BaseModel, Field
from typing import List, Dict, Optional
from tycho_client.dto import ProtocolComponent
class ProtocolComponentExpectation(BaseModel):
"""Represents a ProtocolComponent with its main attributes."""
id: str = Field(..., description="Identifier of the protocol component")
tokens: List[str] = Field(
...,
description="List of token addresses associated with the protocol component",
)
static_attributes: Optional[Dict[str, Optional[str]]] = Field(
default_factory=dict, description="Static attributes of the protocol component"
)
creation_tx: str = Field(
..., description="Hash of the transaction that created the protocol component"
)
def __init__(self, **data):
super().__init__(**data)
self.id = self.id.lower()
self.tokens = sorted([t.lower() for t in self.tokens])
def compare(self, other: "ProtocolComponentExpectation") -> Optional[str]:
"""Compares the current instance with another ProtocolComponent instance and returns a message with the differences or None if there are no differences."""
differences = []
for field_name, field_value in self.__dict__.items():
other_value = getattr(other, field_name, None)
if field_value != other_value:
differences.append(
f"Field '{field_name}' mismatch: '{field_value}' != '{other_value}'"
)
if not differences:
return None
return "\n".join(differences)
@staticmethod
def from_dto(dto: ProtocolComponent) -> "ProtocolComponentExpectation":
return ProtocolComponentExpectation(
id=dto.id,
tokens=[t.hex() for t in dto.tokens],
static_attributes={
key: value.hex() for key, value in dto.static_attributes.items()
},
creation_tx=dto.creation_tx.hex(),
)
class ProtocolComponentWithTestConfig(ProtocolComponentExpectation):
"""Represents a ProtocolComponent with its main attributes and test configuration."""
skip_simulation: Optional[bool] = Field(
False,
description="Flag indicating whether to skip simulation for this component",
)
def into_protocol_component(self) -> ProtocolComponentExpectation:
return ProtocolComponentExpectation(**self.dict())
class IntegrationTest(BaseModel):
"""Configuration for an individual test."""
name: str = Field(..., description="Name of the test")
start_block: int = Field(..., description="Starting block number for the test")
stop_block: int = Field(..., description="Stopping block number for the test")
initialized_accounts: Optional[List[str]] = Field(
None, description="List of initialized account addresses"
)
expected_components: List[ProtocolComponentWithTestConfig] = Field(
..., description="List of protocol components expected in the indexed state"
)
class IntegrationTestsConfig(BaseModel):
"""Main integration test configuration."""
substreams_yaml_path: str = Field(
"./substreams.yaml", description="Path of the Substreams YAML file"
)
adapter_contract: str = Field(
..., description="Name of the SwapAdapter contract for this protocol"
)
adapter_build_signature: Optional[str] = Field(
None, description="Signatre of the SwapAdapter constructor for this protocol"
)
adapter_build_args: Optional[str] = Field(
None, description="Arguments for the SwapAdapter constructor for this protocol"
)
initialized_accounts: Optional[List[str]] = Field(
None,
description="List of initialized account addresses. These accounts will be initialized for every tests",
)
skip_balance_check: bool = Field(
..., description="Flag to skip balance check for all tests"
)
protocol_type_names: List[str] = Field(
..., description="List of protocol type names for the tested protocol"
)
tests: List[IntegrationTest] = Field(..., description="List of integration tests")

View File

@@ -7,6 +7,7 @@ from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from pathlib import Path
from typing import List
import yaml
from protosim_py.evm.decoders import ThirdPartyPoolTychoDecoder
@@ -27,6 +28,11 @@ from tycho_client.dto import (
)
from tycho_client.rpc_client import TychoRPCClient
from models import (
IntegrationTestsConfig,
ProtocolComponentWithTestConfig,
ProtocolComponentExpectation,
)
from adapter_handler import AdapterContractHandler
from evm import get_token_balance, get_block_header
from tycho import TychoRunner
@@ -47,10 +53,10 @@ class TestResult:
return cls(success=False, message=message)
def load_config(yaml_path: str) -> dict:
"""Load YAML configuration from a specified file path."""
def parse_config(yaml_path: str) -> IntegrationTestsConfig:
with open(yaml_path, "r") as file:
return yaml.safe_load(file)
yaml_content = yaml.safe_load(file)
return IntegrationTestsConfig(**yaml_content)
class SimulationFailure(BaseModel):
@@ -66,13 +72,13 @@ class TestRunner:
):
self.repo_root = os.getcwd()
config_path = os.path.join(
self.repo_root, "substreams", package, "test_assets.yaml"
self.repo_root, "substreams", package, "integration_test.tycho.yaml"
)
self.config = load_config(config_path)
self.config: IntegrationTestsConfig = parse_config(config_path)
self.spkg_src = os.path.join(self.repo_root, "substreams", package)
self.adapters_src = os.path.join(self.repo_root, "evm")
self.tycho_runner = TychoRunner(
db_url, with_binary_logs, self.config["initialized_accounts"]
db_url, with_binary_logs, self.config.initialized_accounts
)
self.tycho_rpc_client = TychoRPCClient()
self._token_factory_func = token_factory(self.tycho_rpc_client)
@@ -83,32 +89,35 @@ class TestRunner:
def run_tests(self) -> None:
"""Run all tests specified in the configuration."""
print(f"Running tests ...")
for test in self.config["tests"]:
for test in self.config.tests:
self.tycho_runner.empty_database(self.db_url)
spkg_path = self.build_spkg(
os.path.join(self.spkg_src, self.config["substreams_yaml_path"]),
lambda data: self.update_initial_block(data, test["start_block"]),
os.path.join(self.spkg_src, self.config.substreams_yaml_path),
lambda data: self.update_initial_block(data, test.start_block),
)
self.tycho_runner.run_tycho(
spkg_path,
test["start_block"],
test["stop_block"],
self.config["protocol_type_names"],
test.get("initialized_accounts", []),
test.start_block,
test.stop_block,
self.config.protocol_type_names,
test.initialized_accounts or [],
)
result = self.tycho_runner.run_with_rpc_server(
self.validate_state, test["expected_state"], test["stop_block"]
self.validate_state, test.expected_components, test.stop_block
)
if result.success:
print(f"{test['name']} passed.")
print(f"{test.name} passed.")
else:
print(f"❗️ {test['name']} failed: {result.message}")
print(f"❗️ {test.name} failed: {result.message}")
def validate_state(self, expected_state: dict, stop_block: int) -> TestResult:
def validate_state(
self,
expected_components: List[ProtocolComponentWithTestConfig],
stop_block: int,
) -> TestResult:
"""Validate the current protocol state against the expected state."""
protocol_components = self.tycho_rpc_client.get_protocol_components(
ProtocolComponentsParams(protocol_system="test_protocol")
@@ -121,43 +130,18 @@ class TestRunner:
}
try:
for expected_component in expected_state.get("protocol_components", []):
comp_id = expected_component["id"].lower()
for expected_component in expected_components:
comp_id = expected_component.id.lower()
if comp_id not in components_by_id:
return TestResult.Failed(
f"'{comp_id}' not found in protocol components."
)
# TODO: Manipulate pydantic objects instead of dict
component = components_by_id[comp_id].dict()
for key, value in expected_component.items():
if key not in ["tokens", "static_attributes", "creation_tx"]:
continue
if key not in component:
return TestResult.Failed(
f"Missing '{key}' in component '{comp_id}'."
)
if key == "tokens":
if set(map(HexBytes, value)) != set(component[key]):
return TestResult.Failed(
f"Token mismatch for key '{key}': {value} != {component[key]}"
)
elif key == "creation_tx":
if HexBytes(value) != component[key]:
return TestResult.Failed(
f"Creation tx mismatch for key '{key}': {value} != {component[key]}"
)
elif isinstance(value, list):
if set(map(str.lower, value)) != set(
map(str.lower, component[key])
):
return TestResult.Failed(
f"List mismatch for key '{key}': {value} != {component[key]}"
)
elif value is not None and value.lower() != component[key]:
return TestResult.Failed(
f"Value mismatch for key '{key}': {value} != {component[key]}"
)
diff = ProtocolComponentExpectation.from_dto(
components_by_id[comp_id]
).compare(expected_component.into_protocol_component())
if diff is not None:
return TestResult.Failed(diff)
token_balances: dict[str, dict[HexBytes, int]] = defaultdict(dict)
for component in protocol_components:
@@ -178,7 +162,7 @@ class TestRunner:
tycho_balance = int(balance_hex)
token_balances[comp_id][token] = tycho_balance
if self.config["skip_balance_check"] is not True:
if self.config.skip_balance_check is not True:
node_balance = get_token_balance(token, comp_id, stop_block)
if node_balance != tycho_balance:
return TestResult.Failed(
@@ -198,11 +182,7 @@ class TestRunner:
pc
for pc in protocol_components
if pc.id
in [
c["id"].lower()
for c in expected_state["protocol_components"]
if c.get("skip_simulation", False) is False
]
in [c.id for c in expected_components if c.skip_simulation is False]
]
simulation_failures = self.simulate_get_amount_out(
stop_block, protocol_states, filtered_components, contract_states
@@ -231,7 +211,7 @@ class TestRunner:
contract_states: list[ResponseAccount],
) -> dict[str, list[SimulationFailure]]:
TychoDBSingleton.initialize()
protocol_type_names = self.config["protocol_type_names"]
protocol_type_names = self.config.protocol_type_names
block_header = get_block_header(block_number)
block: EVMBlock = EVMBlock(
@@ -245,17 +225,17 @@ class TestRunner:
adapter_contract = os.path.join(
self.adapters_src,
"out",
f"{self.config['adapter_contract']}.sol",
f"{self.config['adapter_contract']}.evm.runtime",
f"{self.config.adapter_contract}.sol",
f"{self.config.adapter_contract}.evm.runtime",
)
if not os.path.exists(adapter_contract):
print("Adapter contract not found. Building it ...")
AdapterContractHandler.build_target(
self.adapters_src,
self.config["adapter_contract"],
self.config["adapter_build_signature"],
self.config["adapter_build_args"],
self.config.adapter_contract,
self.config.adapter_build_signature,
self.config.adapter_build_args,
)
decoder = ThirdPartyPoolTychoDecoder(

View File

@@ -82,11 +82,14 @@ class TychoRunner:
"--stop-block",
# +2 is to make up for the cache in the index side.
str(end_block + 2),
"--initialization-block",
str(start_block),
]
+ (
["--initialized-accounts", ",".join(all_accounts)]
[
"--initialized-accounts",
",".join(all_accounts),
"--initialization-block",
str(start_block),
]
if all_accounts
else []
),