From d0c248fcb68d80dd7d023efd0420bfc360a9afa0 Mon Sep 17 00:00:00 2001 From: Thales Lima Date: Tue, 6 Aug 2024 06:03:29 +0200 Subject: [PATCH] Add build_snapshot_message method --- testing/src/runner/runner.py | 95 ++++++++++++++++++------------------ testing/src/runner/utils.py | 34 +++++++++++++ 2 files changed, 82 insertions(+), 47 deletions(-) create mode 100644 testing/src/runner/utils.py diff --git a/testing/src/runner/runner.py b/testing/src/runner/runner.py index a313d2b..40f7688 100644 --- a/testing/src/runner/runner.py +++ b/testing/src/runner/runner.py @@ -9,12 +9,19 @@ from decimal import Decimal from pathlib import Path import yaml +from protosim_py.evm.decoders import ThirdPartyPoolTychoDecoder +from protosim_py.models import EVMBlock from pydantic import BaseModel from tycho_client.dto import ( Chain, ProtocolComponentsParams, ProtocolStateParams, - ContractStateParams, ProtocolComponent, ResponseProtocolState, HexBytes, + ContractStateParams, + ProtocolComponent, + ResponseProtocolState, + HexBytes, + ResponseAccount, + Snapshot, ) from tycho_client.rpc_client import TychoRPCClient from tycho_client.stream import TychoStream @@ -22,11 +29,7 @@ from tycho_client.stream import TychoStream from .adapter_handler import AdapterContractHandler from .evm import get_token_balance, get_block_header from .tycho import TychoRunner - - -# from tycho_client.decoders import ThirdPartyPoolTychoDecoder -# from tycho_client.models import Blockchain, EVMBlock -# from tycho_client.tycho_adapter import TychoPoolStateStreamAdapter +from .utils import build_snapshot_message class TestResult: @@ -58,7 +61,7 @@ class SimulationFailure(BaseModel): class TestRunner: def __init__( - self, package: str, with_binary_logs: bool, db_url: str, vm_traces: bool + self, package: str, with_binary_logs: bool, db_url: str, vm_traces: bool ): self.repo_root = os.getcwd() config_path = os.path.join( @@ -105,15 +108,18 @@ class TestRunner: def validate_state(self, expected_state: dict, stop_block: int) -> TestResult: """Validate the current protocol state against the expected state.""" - protocol_components: list[ProtocolComponent] = self.tycho_rpc_client.get_protocol_components( + protocol_components: list[ + ProtocolComponent + ] = self.tycho_rpc_client.get_protocol_components( ProtocolComponentsParams(protocol_system="test_protocol") ) - protocol_states: list[ResponseProtocolState] = self.tycho_rpc_client.get_protocol_state( + protocol_states: list[ + ResponseProtocolState + ] = self.tycho_rpc_client.get_protocol_state( ProtocolStateParams(protocol_system="test_protocol") ) components_by_id = { - component.id: component - for component in protocol_components + component.id: component for component in protocol_components } try: @@ -135,7 +141,7 @@ class TestRunner: ) if isinstance(value, list): if set(map(str.lower, value)) != set( - map(str.lower, component[key]) + map(str.lower, component[key]) ): return TestResult.Failed( f"List mismatch for key '{key}': {value} != {component[key]}" @@ -152,8 +158,8 @@ class TestRunner: state = next( ( s - for s in protocol_states["states"] - if s["component_id"].lower() == comp_id + for s in protocol_states + if s.component_id.lower() == comp_id ), None, ) @@ -171,20 +177,19 @@ class TestRunner: f"Balance mismatch for {comp_id}:{token} at block {stop_block}: got {node_balance} " f"from rpc call and {tycho_balance} from Substreams" ) - contract_states = self.tycho_rpc_client.get_contract_state( - ContractStateParams() - ) - filtered_components = { - "protocol_components": [ - pc - for pc in protocol_components - if pc.id in [ - c["id"].lower() - for c in expected_state["protocol_components"] - if c.get("skip_simulation", False) is False - ] + contract_states: list[ + ResponseAccount + ] = self.tycho_rpc_client.get_contract_state(ContractStateParams()) + filtered_components = [ + pc + for pc in protocol_components + if pc.id + in [ + c["id"].lower() + for c in expected_state["protocol_components"] + if c.get("skip_simulation", False) is False ] - } + ] simulation_failures = self.simulate_get_amount_out( stop_block, protocol_states, filtered_components, contract_states ) @@ -205,11 +210,11 @@ class TestRunner: return TestResult.Failed(error_message) def simulate_get_amount_out( - self, - block_number: int, - protocol_states: ResponseProtocolState, - protocol_components: list[ProtocolComponent], - contract_state: Contract, + self, + block_number: int, + protocol_states: list[ResponseProtocolState], + protocol_components: list[ProtocolComponent], + contract_states: list[ResponseAccount], ) -> dict[str, list[SimulationFailure]]: protocol_type_names = self.config["protocol_type_names"] @@ -238,15 +243,10 @@ class TestRunner: self.config["adapter_build_args"], ) - # decoder = ThirdPartyPoolTychoDecoder( - # adapter_contract, 0, trace=self._vm_traces - # ) - # stream_adapter = TychoPoolStateStreamAdapter( - # tycho_url="0.0.0.0:4242", - # protocol=protocol, - # decoder=decoder, - # blockchain=self._chain, - # ) + decoder = ThirdPartyPoolTychoDecoder( + adapter_contract=adapter_contract, minimum_gas=0, trace=self._vm_traces + ) + stream_adapter = TychoStream( tycho_url="0.0.0.0:4242", exchanges=[protocol], @@ -254,21 +254,22 @@ class TestRunner: blockchain=self._chain, ) - snapshot_message = stream_adapter.build_snapshot_message( - protocol_components, protocol_states, contract_state + snapshot_message: Snapshot = build_snapshot_message( + protocol_states, protocol_components, contract_states ) - decoded = stream_adapter.process_snapshot(block, snapshot_message) - for pool_state in decoded.pool_states.values(): + decoded = decoder.decode_snapshot(snapshot_message, block) + + for pool_state in decoded.values(): pool_id = pool_state.id_ if not pool_state.balances: raise ValueError(f"Missing balances for pool {pool_id}") for sell_token, buy_token in itertools.permutations( - pool_state.tokens, 2 + pool_state.tokens, 2 ): # Try to sell 0.1% of the protocol balance sell_amount = ( - Decimal("0.001") * pool_state.balances[sell_token.address] + Decimal("0.001") * pool_state.balances[sell_token.address] ) try: amount_out, gas_used, _ = pool_state.get_amount_out( diff --git a/testing/src/runner/utils.py b/testing/src/runner/utils.py new file mode 100644 index 0000000..78cc9fc --- /dev/null +++ b/testing/src/runner/utils.py @@ -0,0 +1,34 @@ +from logging import getLogger + +from protosim_py.evm.pool_state import ThirdPartyPool +from tycho_client.dto import ( + ResponseProtocolState, + ProtocolComponent, + ResponseAccount, + ComponentWithState, + Snapshot, +) + +log = getLogger(__name__) + + +def build_snapshot_message( + protocol_states: list[ResponseProtocolState], + protocol_components: list[ProtocolComponent], + account_states: list[ResponseAccount], +) -> Snapshot: + vm_storage = {state.address: state for state in account_states} + + states = {} + for component in protocol_components: + pool_id = component.id + states[pool_id] = {"component": component} + for state in protocol_states: + pool_id = state.component_id + if pool_id not in states: + log.warning(f"State for pool {pool_id} not found in components") + continue + states[pool_id]["state"] = state + + states = {id_: ComponentWithState(**state) for id_, state in states.items()} + return Snapshot(states=states, vm_storage=vm_storage)