refactor(substreams-testing): Use Pydantic validators, Hexbytes and improve description

This commit is contained in:
Florian Pellissier
2024-08-07 23:53:34 +02:00
committed by tvinagre
parent 95efda0423
commit 1f9fe8d583
2 changed files with 24 additions and 19 deletions

View File

@@ -1,4 +1,5 @@
from pydantic import BaseModel, Field
from hexbytes import HexBytes
from pydantic import BaseModel, Field, validator
from typing import List, Dict, Optional
from tycho_client.dto import ProtocolComponent
@@ -8,21 +9,32 @@ class ProtocolComponentExpectation(BaseModel):
"""Represents a ProtocolComponent with its main attributes."""
id: str = Field(..., description="Identifier of the protocol component")
tokens: List[str] = Field(
tokens: List[HexBytes] = Field(
...,
description="List of token addresses associated with the protocol component",
)
static_attributes: Optional[Dict[str, Optional[str]]] = Field(
static_attributes: Optional[Dict[str, HexBytes]] = Field(
default_factory=dict, description="Static attributes of the protocol component"
)
creation_tx: str = Field(
creation_tx: HexBytes = Field(
..., description="Hash of the transaction that created the protocol component"
)
def __init__(self, **data):
super().__init__(**data)
self.id = self.id.lower()
self.tokens = sorted([t.lower() for t in self.tokens])
@validator("id", pre=True, always=True)
def lower_id(cls, v):
return v.lower()
@validator("tokens", pre=True, always=True)
def convert_tokens_to_hexbytes(cls, v):
return sorted(HexBytes(t.lower()) for t in v)
@validator("static_attributes", pre=True, always=True)
def convert_static_attributes_to_hexbytes(cls, v):
return {k: HexBytes(v[k].lower()) for k in v} if v else {}
@validator("creation_tx", pre=True, always=True)
def convert_creation_tx_to_hexbytes(cls, v):
return HexBytes(v.lower())
def compare(self, other: "ProtocolComponentExpectation") -> Optional[str]:
"""Compares the current instance with another ProtocolComponent instance and returns a message with the differences or None if there are no differences."""
@@ -40,14 +52,7 @@ class ProtocolComponentExpectation(BaseModel):
@staticmethod
def from_dto(dto: ProtocolComponent) -> "ProtocolComponentExpectation":
return ProtocolComponentExpectation(
id=dto.id,
tokens=[t.hex() for t in dto.tokens],
static_attributes={
key: value.hex() for key, value in dto.static_attributes.items()
},
creation_tx=dto.creation_tx.hex(),
)
return ProtocolComponentExpectation(**dto.dict())
class ProtocolComponentWithTestConfig(ProtocolComponentExpectation):
@@ -86,10 +91,10 @@ class IntegrationTestsConfig(BaseModel):
..., description="Name of the SwapAdapter contract for this protocol"
)
adapter_build_signature: Optional[str] = Field(
None, description="Signatre of the SwapAdapter constructor for this protocol"
None, description="SwapAdapter's constructor signature"
)
adapter_build_args: Optional[str] = Field(
None, description="Arguments for the SwapAdapter constructor for this protocol"
None, description="Arguments for the SwapAdapter constructor"
)
initialized_accounts: Optional[List[str]] = Field(
None,