Add build_snapshot_message method

This commit is contained in:
Thales Lima
2024-08-06 06:03:29 +02:00
committed by tvinagre
parent d893ab264c
commit d0c248fcb6
2 changed files with 82 additions and 47 deletions

View File

@@ -9,12 +9,19 @@ from decimal import Decimal
from pathlib import Path
import yaml
from protosim_py.evm.decoders import ThirdPartyPoolTychoDecoder
from protosim_py.models import EVMBlock
from pydantic import BaseModel
from tycho_client.dto import (
Chain,
ProtocolComponentsParams,
ProtocolStateParams,
ContractStateParams, ProtocolComponent, ResponseProtocolState, HexBytes,
ContractStateParams,
ProtocolComponent,
ResponseProtocolState,
HexBytes,
ResponseAccount,
Snapshot,
)
from tycho_client.rpc_client import TychoRPCClient
from tycho_client.stream import TychoStream
@@ -22,11 +29,7 @@ from tycho_client.stream import TychoStream
from .adapter_handler import AdapterContractHandler
from .evm import get_token_balance, get_block_header
from .tycho import TychoRunner
# from tycho_client.decoders import ThirdPartyPoolTychoDecoder
# from tycho_client.models import Blockchain, EVMBlock
# from tycho_client.tycho_adapter import TychoPoolStateStreamAdapter
from .utils import build_snapshot_message
class TestResult:
@@ -58,7 +61,7 @@ class SimulationFailure(BaseModel):
class TestRunner:
def __init__(
self, package: str, with_binary_logs: bool, db_url: str, vm_traces: bool
self, package: str, with_binary_logs: bool, db_url: str, vm_traces: bool
):
self.repo_root = os.getcwd()
config_path = os.path.join(
@@ -105,15 +108,18 @@ class TestRunner:
def validate_state(self, expected_state: dict, stop_block: int) -> TestResult:
"""Validate the current protocol state against the expected state."""
protocol_components: list[ProtocolComponent] = self.tycho_rpc_client.get_protocol_components(
protocol_components: list[
ProtocolComponent
] = self.tycho_rpc_client.get_protocol_components(
ProtocolComponentsParams(protocol_system="test_protocol")
)
protocol_states: list[ResponseProtocolState] = self.tycho_rpc_client.get_protocol_state(
protocol_states: list[
ResponseProtocolState
] = self.tycho_rpc_client.get_protocol_state(
ProtocolStateParams(protocol_system="test_protocol")
)
components_by_id = {
component.id: component
for component in protocol_components
component.id: component for component in protocol_components
}
try:
@@ -135,7 +141,7 @@ class TestRunner:
)
if isinstance(value, list):
if set(map(str.lower, value)) != set(
map(str.lower, component[key])
map(str.lower, component[key])
):
return TestResult.Failed(
f"List mismatch for key '{key}': {value} != {component[key]}"
@@ -152,8 +158,8 @@ class TestRunner:
state = next(
(
s
for s in protocol_states["states"]
if s["component_id"].lower() == comp_id
for s in protocol_states
if s.component_id.lower() == comp_id
),
None,
)
@@ -171,20 +177,19 @@ class TestRunner:
f"Balance mismatch for {comp_id}:{token} at block {stop_block}: got {node_balance} "
f"from rpc call and {tycho_balance} from Substreams"
)
contract_states = self.tycho_rpc_client.get_contract_state(
ContractStateParams()
)
filtered_components = {
"protocol_components": [
pc
for pc in protocol_components
if pc.id in [
c["id"].lower()
for c in expected_state["protocol_components"]
if c.get("skip_simulation", False) is False
]
contract_states: list[
ResponseAccount
] = self.tycho_rpc_client.get_contract_state(ContractStateParams())
filtered_components = [
pc
for pc in protocol_components
if pc.id
in [
c["id"].lower()
for c in expected_state["protocol_components"]
if c.get("skip_simulation", False) is False
]
}
]
simulation_failures = self.simulate_get_amount_out(
stop_block, protocol_states, filtered_components, contract_states
)
@@ -205,11 +210,11 @@ class TestRunner:
return TestResult.Failed(error_message)
def simulate_get_amount_out(
self,
block_number: int,
protocol_states: ResponseProtocolState,
protocol_components: list[ProtocolComponent],
contract_state: Contract,
self,
block_number: int,
protocol_states: list[ResponseProtocolState],
protocol_components: list[ProtocolComponent],
contract_states: list[ResponseAccount],
) -> dict[str, list[SimulationFailure]]:
protocol_type_names = self.config["protocol_type_names"]
@@ -238,15 +243,10 @@ class TestRunner:
self.config["adapter_build_args"],
)
# decoder = ThirdPartyPoolTychoDecoder(
# adapter_contract, 0, trace=self._vm_traces
# )
# stream_adapter = TychoPoolStateStreamAdapter(
# tycho_url="0.0.0.0:4242",
# protocol=protocol,
# decoder=decoder,
# blockchain=self._chain,
# )
decoder = ThirdPartyPoolTychoDecoder(
adapter_contract=adapter_contract, minimum_gas=0, trace=self._vm_traces
)
stream_adapter = TychoStream(
tycho_url="0.0.0.0:4242",
exchanges=[protocol],
@@ -254,21 +254,22 @@ class TestRunner:
blockchain=self._chain,
)
snapshot_message = stream_adapter.build_snapshot_message(
protocol_components, protocol_states, contract_state
snapshot_message: Snapshot = build_snapshot_message(
protocol_states, protocol_components, contract_states
)
decoded = stream_adapter.process_snapshot(block, snapshot_message)
for pool_state in decoded.pool_states.values():
decoded = decoder.decode_snapshot(snapshot_message, block)
for pool_state in decoded.values():
pool_id = pool_state.id_
if not pool_state.balances:
raise ValueError(f"Missing balances for pool {pool_id}")
for sell_token, buy_token in itertools.permutations(
pool_state.tokens, 2
pool_state.tokens, 2
):
# Try to sell 0.1% of the protocol balance
sell_amount = (
Decimal("0.001") * pool_state.balances[sell_token.address]
Decimal("0.001") * pool_state.balances[sell_token.address]
)
try:
amount_out, gas_used, _ = pool_state.get_amount_out(

View File

@@ -0,0 +1,34 @@
from logging import getLogger
from protosim_py.evm.pool_state import ThirdPartyPool
from tycho_client.dto import (
ResponseProtocolState,
ProtocolComponent,
ResponseAccount,
ComponentWithState,
Snapshot,
)
log = getLogger(__name__)
def build_snapshot_message(
protocol_states: list[ResponseProtocolState],
protocol_components: list[ProtocolComponent],
account_states: list[ResponseAccount],
) -> Snapshot:
vm_storage = {state.address: state for state in account_states}
states = {}
for component in protocol_components:
pool_id = component.id
states[pool_id] = {"component": component}
for state in protocol_states:
pool_id = state.component_id
if pool_id not in states:
log.warning(f"State for pool {pool_id} not found in components")
continue
states[pool_id]["state"] = state
states = {id_: ComponentWithState(**state) for id_, state in states.items()}
return Snapshot(states=states, vm_storage=vm_storage)