From 1f9fe8d58390ec5ba9487a954d662b6cb3c520a2 Mon Sep 17 00:00:00 2001 From: Florian Pellissier <111426680+flopell@users.noreply.github.com> Date: Wed, 7 Aug 2024 23:53:34 +0200 Subject: [PATCH] refactor(substreams-testing): Use Pydantic validators, Hexbytes and improve description --- testing/src/runner/models.py | 41 ++++++++++++++++++++---------------- testing/src/runner/runner.py | 2 +- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/testing/src/runner/models.py b/testing/src/runner/models.py index dab75d4..aa79559 100644 --- a/testing/src/runner/models.py +++ b/testing/src/runner/models.py @@ -1,4 +1,5 @@ -from pydantic import BaseModel, Field +from hexbytes import HexBytes +from pydantic import BaseModel, Field, validator from typing import List, Dict, Optional from tycho_client.dto import ProtocolComponent @@ -8,21 +9,32 @@ class ProtocolComponentExpectation(BaseModel): """Represents a ProtocolComponent with its main attributes.""" id: str = Field(..., description="Identifier of the protocol component") - tokens: List[str] = Field( + tokens: List[HexBytes] = Field( ..., description="List of token addresses associated with the protocol component", ) - static_attributes: Optional[Dict[str, Optional[str]]] = Field( + static_attributes: Optional[Dict[str, HexBytes]] = Field( default_factory=dict, description="Static attributes of the protocol component" ) - creation_tx: str = Field( + creation_tx: HexBytes = Field( ..., description="Hash of the transaction that created the protocol component" ) - def __init__(self, **data): - super().__init__(**data) - self.id = self.id.lower() - self.tokens = sorted([t.lower() for t in self.tokens]) + @validator("id", pre=True, always=True) + def lower_id(cls, v): + return v.lower() + + @validator("tokens", pre=True, always=True) + def convert_tokens_to_hexbytes(cls, v): + return sorted(HexBytes(t.lower()) for t in v) + + @validator("static_attributes", pre=True, always=True) + def convert_static_attributes_to_hexbytes(cls, v): + return {k: HexBytes(v[k].lower()) for k in v} if v else {} + + @validator("creation_tx", pre=True, always=True) + def convert_creation_tx_to_hexbytes(cls, v): + return HexBytes(v.lower()) def compare(self, other: "ProtocolComponentExpectation") -> Optional[str]: """Compares the current instance with another ProtocolComponent instance and returns a message with the differences or None if there are no differences.""" @@ -40,14 +52,7 @@ class ProtocolComponentExpectation(BaseModel): @staticmethod def from_dto(dto: ProtocolComponent) -> "ProtocolComponentExpectation": - return ProtocolComponentExpectation( - id=dto.id, - tokens=[t.hex() for t in dto.tokens], - static_attributes={ - key: value.hex() for key, value in dto.static_attributes.items() - }, - creation_tx=dto.creation_tx.hex(), - ) + return ProtocolComponentExpectation(**dto.dict()) class ProtocolComponentWithTestConfig(ProtocolComponentExpectation): @@ -86,10 +91,10 @@ class IntegrationTestsConfig(BaseModel): ..., description="Name of the SwapAdapter contract for this protocol" ) adapter_build_signature: Optional[str] = Field( - None, description="Signatre of the SwapAdapter constructor for this protocol" + None, description="SwapAdapter's constructor signature" ) adapter_build_args: Optional[str] = Field( - None, description="Arguments for the SwapAdapter constructor for this protocol" + None, description="Arguments for the SwapAdapter constructor" ) initialized_accounts: Optional[List[str]] = Field( None, diff --git a/testing/src/runner/runner.py b/testing/src/runner/runner.py index 8740f12..c25292d 100644 --- a/testing/src/runner/runner.py +++ b/testing/src/runner/runner.py @@ -162,7 +162,7 @@ class TestRunner: tycho_balance = int(balance_hex) token_balances[comp_id][token] = tycho_balance - if self.config.skip_balance_check is not True: + if not self.config.skip_balance_check: node_balance = get_token_balance(token, comp_id, stop_block) if node_balance != tycho_balance: return TestResult.Failed(