refactor(substreams-testing): Use Pydantic validators, Hexbytes and improve description

This commit is contained in:
Florian Pellissier
2024-08-07 23:53:34 +02:00
committed by tvinagre
parent 95efda0423
commit 1f9fe8d583
2 changed files with 24 additions and 19 deletions

View File

@@ -1,4 +1,5 @@
from pydantic import BaseModel, Field from hexbytes import HexBytes
from pydantic import BaseModel, Field, validator
from typing import List, Dict, Optional from typing import List, Dict, Optional
from tycho_client.dto import ProtocolComponent from tycho_client.dto import ProtocolComponent
@@ -8,21 +9,32 @@ class ProtocolComponentExpectation(BaseModel):
"""Represents a ProtocolComponent with its main attributes.""" """Represents a ProtocolComponent with its main attributes."""
id: str = Field(..., description="Identifier of the protocol component") id: str = Field(..., description="Identifier of the protocol component")
tokens: List[str] = Field( tokens: List[HexBytes] = Field(
..., ...,
description="List of token addresses associated with the protocol component", description="List of token addresses associated with the protocol component",
) )
static_attributes: Optional[Dict[str, Optional[str]]] = Field( static_attributes: Optional[Dict[str, HexBytes]] = Field(
default_factory=dict, description="Static attributes of the protocol component" default_factory=dict, description="Static attributes of the protocol component"
) )
creation_tx: str = Field( creation_tx: HexBytes = Field(
..., description="Hash of the transaction that created the protocol component" ..., description="Hash of the transaction that created the protocol component"
) )
def __init__(self, **data): @validator("id", pre=True, always=True)
super().__init__(**data) def lower_id(cls, v):
self.id = self.id.lower() return v.lower()
self.tokens = sorted([t.lower() for t in self.tokens])
@validator("tokens", pre=True, always=True)
def convert_tokens_to_hexbytes(cls, v):
return sorted(HexBytes(t.lower()) for t in v)
@validator("static_attributes", pre=True, always=True)
def convert_static_attributes_to_hexbytes(cls, v):
return {k: HexBytes(v[k].lower()) for k in v} if v else {}
@validator("creation_tx", pre=True, always=True)
def convert_creation_tx_to_hexbytes(cls, v):
return HexBytes(v.lower())
def compare(self, other: "ProtocolComponentExpectation") -> Optional[str]: def compare(self, other: "ProtocolComponentExpectation") -> Optional[str]:
"""Compares the current instance with another ProtocolComponent instance and returns a message with the differences or None if there are no differences.""" """Compares the current instance with another ProtocolComponent instance and returns a message with the differences or None if there are no differences."""
@@ -40,14 +52,7 @@ class ProtocolComponentExpectation(BaseModel):
@staticmethod @staticmethod
def from_dto(dto: ProtocolComponent) -> "ProtocolComponentExpectation": def from_dto(dto: ProtocolComponent) -> "ProtocolComponentExpectation":
return ProtocolComponentExpectation( return ProtocolComponentExpectation(**dto.dict())
id=dto.id,
tokens=[t.hex() for t in dto.tokens],
static_attributes={
key: value.hex() for key, value in dto.static_attributes.items()
},
creation_tx=dto.creation_tx.hex(),
)
class ProtocolComponentWithTestConfig(ProtocolComponentExpectation): class ProtocolComponentWithTestConfig(ProtocolComponentExpectation):
@@ -86,10 +91,10 @@ class IntegrationTestsConfig(BaseModel):
..., description="Name of the SwapAdapter contract for this protocol" ..., description="Name of the SwapAdapter contract for this protocol"
) )
adapter_build_signature: Optional[str] = Field( adapter_build_signature: Optional[str] = Field(
None, description="Signatre of the SwapAdapter constructor for this protocol" None, description="SwapAdapter's constructor signature"
) )
adapter_build_args: Optional[str] = Field( adapter_build_args: Optional[str] = Field(
None, description="Arguments for the SwapAdapter constructor for this protocol" None, description="Arguments for the SwapAdapter constructor"
) )
initialized_accounts: Optional[List[str]] = Field( initialized_accounts: Optional[List[str]] = Field(
None, None,

View File

@@ -162,7 +162,7 @@ class TestRunner:
tycho_balance = int(balance_hex) tycho_balance = int(balance_hex)
token_balances[comp_id][token] = tycho_balance token_balances[comp_id][token] = tycho_balance
if self.config.skip_balance_check is not True: if not self.config.skip_balance_check:
node_balance = get_token_balance(token, comp_id, stop_block) node_balance = get_token_balance(token, comp_id, stop_block)
if node_balance != tycho_balance: if node_balance != tycho_balance:
return TestResult.Failed( return TestResult.Failed(