Add build_snapshot_message method
This commit is contained in:
@@ -9,12 +9,19 @@ from decimal import Decimal
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
from protosim_py.evm.decoders import ThirdPartyPoolTychoDecoder
|
||||||
|
from protosim_py.models import EVMBlock
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from tycho_client.dto import (
|
from tycho_client.dto import (
|
||||||
Chain,
|
Chain,
|
||||||
ProtocolComponentsParams,
|
ProtocolComponentsParams,
|
||||||
ProtocolStateParams,
|
ProtocolStateParams,
|
||||||
ContractStateParams, ProtocolComponent, ResponseProtocolState, HexBytes,
|
ContractStateParams,
|
||||||
|
ProtocolComponent,
|
||||||
|
ResponseProtocolState,
|
||||||
|
HexBytes,
|
||||||
|
ResponseAccount,
|
||||||
|
Snapshot,
|
||||||
)
|
)
|
||||||
from tycho_client.rpc_client import TychoRPCClient
|
from tycho_client.rpc_client import TychoRPCClient
|
||||||
from tycho_client.stream import TychoStream
|
from tycho_client.stream import TychoStream
|
||||||
@@ -22,11 +29,7 @@ from tycho_client.stream import TychoStream
|
|||||||
from .adapter_handler import AdapterContractHandler
|
from .adapter_handler import AdapterContractHandler
|
||||||
from .evm import get_token_balance, get_block_header
|
from .evm import get_token_balance, get_block_header
|
||||||
from .tycho import TychoRunner
|
from .tycho import TychoRunner
|
||||||
|
from .utils import build_snapshot_message
|
||||||
|
|
||||||
# from tycho_client.decoders import ThirdPartyPoolTychoDecoder
|
|
||||||
# from tycho_client.models import Blockchain, EVMBlock
|
|
||||||
# from tycho_client.tycho_adapter import TychoPoolStateStreamAdapter
|
|
||||||
|
|
||||||
|
|
||||||
class TestResult:
|
class TestResult:
|
||||||
@@ -58,7 +61,7 @@ class SimulationFailure(BaseModel):
|
|||||||
|
|
||||||
class TestRunner:
|
class TestRunner:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, package: str, with_binary_logs: bool, db_url: str, vm_traces: bool
|
self, package: str, with_binary_logs: bool, db_url: str, vm_traces: bool
|
||||||
):
|
):
|
||||||
self.repo_root = os.getcwd()
|
self.repo_root = os.getcwd()
|
||||||
config_path = os.path.join(
|
config_path = os.path.join(
|
||||||
@@ -105,15 +108,18 @@ class TestRunner:
|
|||||||
|
|
||||||
def validate_state(self, expected_state: dict, stop_block: int) -> TestResult:
|
def validate_state(self, expected_state: dict, stop_block: int) -> TestResult:
|
||||||
"""Validate the current protocol state against the expected state."""
|
"""Validate the current protocol state against the expected state."""
|
||||||
protocol_components: list[ProtocolComponent] = self.tycho_rpc_client.get_protocol_components(
|
protocol_components: list[
|
||||||
|
ProtocolComponent
|
||||||
|
] = self.tycho_rpc_client.get_protocol_components(
|
||||||
ProtocolComponentsParams(protocol_system="test_protocol")
|
ProtocolComponentsParams(protocol_system="test_protocol")
|
||||||
)
|
)
|
||||||
protocol_states: list[ResponseProtocolState] = self.tycho_rpc_client.get_protocol_state(
|
protocol_states: list[
|
||||||
|
ResponseProtocolState
|
||||||
|
] = self.tycho_rpc_client.get_protocol_state(
|
||||||
ProtocolStateParams(protocol_system="test_protocol")
|
ProtocolStateParams(protocol_system="test_protocol")
|
||||||
)
|
)
|
||||||
components_by_id = {
|
components_by_id = {
|
||||||
component.id: component
|
component.id: component for component in protocol_components
|
||||||
for component in protocol_components
|
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -135,7 +141,7 @@ class TestRunner:
|
|||||||
)
|
)
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
if set(map(str.lower, value)) != set(
|
if set(map(str.lower, value)) != set(
|
||||||
map(str.lower, component[key])
|
map(str.lower, component[key])
|
||||||
):
|
):
|
||||||
return TestResult.Failed(
|
return TestResult.Failed(
|
||||||
f"List mismatch for key '{key}': {value} != {component[key]}"
|
f"List mismatch for key '{key}': {value} != {component[key]}"
|
||||||
@@ -152,8 +158,8 @@ class TestRunner:
|
|||||||
state = next(
|
state = next(
|
||||||
(
|
(
|
||||||
s
|
s
|
||||||
for s in protocol_states["states"]
|
for s in protocol_states
|
||||||
if s["component_id"].lower() == comp_id
|
if s.component_id.lower() == comp_id
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
@@ -171,20 +177,19 @@ class TestRunner:
|
|||||||
f"Balance mismatch for {comp_id}:{token} at block {stop_block}: got {node_balance} "
|
f"Balance mismatch for {comp_id}:{token} at block {stop_block}: got {node_balance} "
|
||||||
f"from rpc call and {tycho_balance} from Substreams"
|
f"from rpc call and {tycho_balance} from Substreams"
|
||||||
)
|
)
|
||||||
contract_states = self.tycho_rpc_client.get_contract_state(
|
contract_states: list[
|
||||||
ContractStateParams()
|
ResponseAccount
|
||||||
)
|
] = self.tycho_rpc_client.get_contract_state(ContractStateParams())
|
||||||
filtered_components = {
|
filtered_components = [
|
||||||
"protocol_components": [
|
pc
|
||||||
pc
|
for pc in protocol_components
|
||||||
for pc in protocol_components
|
if pc.id
|
||||||
if pc.id in [
|
in [
|
||||||
c["id"].lower()
|
c["id"].lower()
|
||||||
for c in expected_state["protocol_components"]
|
for c in expected_state["protocol_components"]
|
||||||
if c.get("skip_simulation", False) is False
|
if c.get("skip_simulation", False) is False
|
||||||
]
|
|
||||||
]
|
]
|
||||||
}
|
]
|
||||||
simulation_failures = self.simulate_get_amount_out(
|
simulation_failures = self.simulate_get_amount_out(
|
||||||
stop_block, protocol_states, filtered_components, contract_states
|
stop_block, protocol_states, filtered_components, contract_states
|
||||||
)
|
)
|
||||||
@@ -205,11 +210,11 @@ class TestRunner:
|
|||||||
return TestResult.Failed(error_message)
|
return TestResult.Failed(error_message)
|
||||||
|
|
||||||
def simulate_get_amount_out(
|
def simulate_get_amount_out(
|
||||||
self,
|
self,
|
||||||
block_number: int,
|
block_number: int,
|
||||||
protocol_states: ResponseProtocolState,
|
protocol_states: list[ResponseProtocolState],
|
||||||
protocol_components: list[ProtocolComponent],
|
protocol_components: list[ProtocolComponent],
|
||||||
contract_state: Contract,
|
contract_states: list[ResponseAccount],
|
||||||
) -> dict[str, list[SimulationFailure]]:
|
) -> dict[str, list[SimulationFailure]]:
|
||||||
protocol_type_names = self.config["protocol_type_names"]
|
protocol_type_names = self.config["protocol_type_names"]
|
||||||
|
|
||||||
@@ -238,15 +243,10 @@ class TestRunner:
|
|||||||
self.config["adapter_build_args"],
|
self.config["adapter_build_args"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# decoder = ThirdPartyPoolTychoDecoder(
|
decoder = ThirdPartyPoolTychoDecoder(
|
||||||
# adapter_contract, 0, trace=self._vm_traces
|
adapter_contract=adapter_contract, minimum_gas=0, trace=self._vm_traces
|
||||||
# )
|
)
|
||||||
# stream_adapter = TychoPoolStateStreamAdapter(
|
|
||||||
# tycho_url="0.0.0.0:4242",
|
|
||||||
# protocol=protocol,
|
|
||||||
# decoder=decoder,
|
|
||||||
# blockchain=self._chain,
|
|
||||||
# )
|
|
||||||
stream_adapter = TychoStream(
|
stream_adapter = TychoStream(
|
||||||
tycho_url="0.0.0.0:4242",
|
tycho_url="0.0.0.0:4242",
|
||||||
exchanges=[protocol],
|
exchanges=[protocol],
|
||||||
@@ -254,21 +254,22 @@ class TestRunner:
|
|||||||
blockchain=self._chain,
|
blockchain=self._chain,
|
||||||
)
|
)
|
||||||
|
|
||||||
snapshot_message = stream_adapter.build_snapshot_message(
|
snapshot_message: Snapshot = build_snapshot_message(
|
||||||
protocol_components, protocol_states, contract_state
|
protocol_states, protocol_components, contract_states
|
||||||
)
|
)
|
||||||
decoded = stream_adapter.process_snapshot(block, snapshot_message)
|
|
||||||
|
|
||||||
for pool_state in decoded.pool_states.values():
|
decoded = decoder.decode_snapshot(snapshot_message, block)
|
||||||
|
|
||||||
|
for pool_state in decoded.values():
|
||||||
pool_id = pool_state.id_
|
pool_id = pool_state.id_
|
||||||
if not pool_state.balances:
|
if not pool_state.balances:
|
||||||
raise ValueError(f"Missing balances for pool {pool_id}")
|
raise ValueError(f"Missing balances for pool {pool_id}")
|
||||||
for sell_token, buy_token in itertools.permutations(
|
for sell_token, buy_token in itertools.permutations(
|
||||||
pool_state.tokens, 2
|
pool_state.tokens, 2
|
||||||
):
|
):
|
||||||
# Try to sell 0.1% of the protocol balance
|
# Try to sell 0.1% of the protocol balance
|
||||||
sell_amount = (
|
sell_amount = (
|
||||||
Decimal("0.001") * pool_state.balances[sell_token.address]
|
Decimal("0.001") * pool_state.balances[sell_token.address]
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
amount_out, gas_used, _ = pool_state.get_amount_out(
|
amount_out, gas_used, _ = pool_state.get_amount_out(
|
||||||
|
|||||||
34
testing/src/runner/utils.py
Normal file
34
testing/src/runner/utils.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from logging import getLogger
|
||||||
|
|
||||||
|
from protosim_py.evm.pool_state import ThirdPartyPool
|
||||||
|
from tycho_client.dto import (
|
||||||
|
ResponseProtocolState,
|
||||||
|
ProtocolComponent,
|
||||||
|
ResponseAccount,
|
||||||
|
ComponentWithState,
|
||||||
|
Snapshot,
|
||||||
|
)
|
||||||
|
|
||||||
|
log = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def build_snapshot_message(
|
||||||
|
protocol_states: list[ResponseProtocolState],
|
||||||
|
protocol_components: list[ProtocolComponent],
|
||||||
|
account_states: list[ResponseAccount],
|
||||||
|
) -> Snapshot:
|
||||||
|
vm_storage = {state.address: state for state in account_states}
|
||||||
|
|
||||||
|
states = {}
|
||||||
|
for component in protocol_components:
|
||||||
|
pool_id = component.id
|
||||||
|
states[pool_id] = {"component": component}
|
||||||
|
for state in protocol_states:
|
||||||
|
pool_id = state.component_id
|
||||||
|
if pool_id not in states:
|
||||||
|
log.warning(f"State for pool {pool_id} not found in components")
|
||||||
|
continue
|
||||||
|
states[pool_id]["state"] = state
|
||||||
|
|
||||||
|
states = {id_: ComponentWithState(**state) for id_, state in states.items()}
|
||||||
|
return Snapshot(states=states, vm_storage=vm_storage)
|
||||||
Reference in New Issue
Block a user