Load Snapshot from RPC request

This commit is contained in:
Thales Lima
2024-07-18 04:54:02 +02:00
committed by tvinagre
parent 5e6c7d4647
commit 183868e536
8 changed files with 500 additions and 59 deletions

View File

@@ -1,7 +1,11 @@
import os
from web3 import Web3
native_aliases = ["0x0000000000000000000000000000000000000000","0xeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"]
native_aliases = [
"0x0000000000000000000000000000000000000000",
"0xeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee",
]
erc20_abi = [
{
@@ -13,6 +17,7 @@ erc20_abi = [
}
]
def get_token_balance(token_address, wallet_address, block_number):
rpc_url = os.getenv("RPC_URL")
@@ -23,14 +28,34 @@ def get_token_balance(token_address, wallet_address, block_number):
if not web3.isConnected():
raise ConnectionError("Failed to connect to the Ethereum node")
# Check if the token_address is a native token alias
if token_address.lower() in native_aliases:
balance = web3.eth.get_balance(Web3.toChecksumAddress(wallet_address), block_identifier=block_number)
else:
contract = web3.eth.contract(address=Web3.toChecksumAddress(token_address), abi=erc20_abi)
balance = contract.functions.balanceOf(Web3.toChecksumAddress(wallet_address)).call(
block_identifier=block_number
balance = web3.eth.get_balance(
Web3.toChecksumAddress(wallet_address), block_identifier=block_number
)
else:
contract = web3.eth.contract(
address=Web3.toChecksumAddress(token_address), abi=erc20_abi
)
balance = contract.functions.balanceOf(
Web3.toChecksumAddress(wallet_address)
).call(block_identifier=block_number)
return balance
def get_block_header(block_number):
rpc_url = os.getenv("RPC_URL")
if rpc_url is None:
raise EnvironmentError("RPC_URL environment variable not set")
web3 = Web3(Web3.HTTPProvider(rpc_url))
if not web3.isConnected():
raise ConnectionError("Failed to connect to the Ethereum node")
block = web3.eth.get_block(block_number)
return block

View File

@@ -1,12 +1,23 @@
import itertools
import itertools
import os
from pathlib import Path
import shutil
import subprocess
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from pathlib import Path
import yaml
from pydantic import BaseModel
from evm import get_token_balance
from evm import get_token_balance, get_block_header
from tycho import TychoRunner
from tycho_client.tycho.decoders import ThirdPartyPoolTychoDecoder
from tycho_client.tycho.models import Blockchain, EVMBlock
from tycho_client.tycho.tycho_adapter import (
TychoPoolStateStreamAdapter,
)
class TestResult:
@@ -29,12 +40,20 @@ def load_config(yaml_path: str) -> dict:
return yaml.safe_load(file)
class SimulationFailure(BaseModel):
pool_id: str
sell_token: str
buy_token: str
error: str
class TestRunner:
def __init__(self, config_path: str, with_binary_logs: bool, db_url: str):
self.config = load_config(config_path)
self.base_dir = os.path.dirname(config_path)
self.tycho_runner = TychoRunner(with_binary_logs)
self.db_url = db_url
self._chain = Blockchain.ethereum
def run_tests(self) -> None:
"""Run all tests specified in the configuration."""
@@ -58,12 +77,11 @@ class TestRunner:
if result.success:
print(f"{test['name']} passed.")
else:
print(f"❗️ {test['name']} failed: {result.message}")
self.tycho_runner.empty_database(
self.db_url
)
self.tycho_runner.empty_database(self.db_url)
def validate_state(self, expected_state: dict, stop_block: int) -> TestResult:
"""Validate the current protocol state against the expected state."""
@@ -90,7 +108,7 @@ class TestRunner:
)
if isinstance(value, list):
if set(map(str.lower, value)) != set(
map(str.lower, component[key])
map(str.lower, component[key])
):
return TestResult.Failed(
f"List mismatch for key '{key}': {value} != {component[key]}"
@@ -100,25 +118,116 @@ class TestRunner:
f"Value mismatch for key '{key}': {value} != {component[key]}"
)
token_balances: dict[str, dict[str, int]] = defaultdict(dict)
for component in protocol_components["protocol_components"]:
comp_id = component["id"].lower()
for token in component["tokens"]:
token_lower = token.lower()
state = next((s for s in protocol_states["states"] if s["component_id"].lower() == comp_id), None)
state = next(
(
s
for s in protocol_states["states"]
if s["component_id"].lower() == comp_id
),
None,
)
if state:
balance_hex = state["balances"].get(token_lower, "0x0")
else:
balance_hex = "0x0"
tycho_balance = int(balance_hex, 16)
token_balances[comp_id][token_lower] = tycho_balance
node_balance = get_token_balance(token, comp_id, stop_block)
tycho_balance = int(balance_hex, 16)
if node_balance != tycho_balance:
return TestResult.Failed(
f"Balance mismatch for {comp_id}:{token} at block {stop_block}: got {node_balance} from rpc call and {tycho_balance} from Substreams")
f"Balance mismatch for {comp_id}:{token} at block {stop_block}: got {node_balance} from rpc call and {tycho_balance} from Substreams"
)
contract_states = self.tycho_runner.get_contract_state()
self.simulate_get_amount_out(
token_balances,
stop_block,
protocol_states,
protocol_components,
contract_states,
)
return TestResult.Passed()
except Exception as e:
return TestResult.Failed(str(e))
def simulate_get_amount_out(
self,
token_balances: dict[str, dict[str, int]],
block_number: int,
protocol_states: dict,
protocol_components: dict,
contract_state: dict,
) -> TestResult:
protocol_type_names = self.config["protocol_type_names"]
block_header = get_block_header(block_number)
block: EVMBlock = EVMBlock(
id=block_number,
ts=datetime.fromtimestamp(block_header.timestamp),
hash_=block_header.hash.hex(),
)
failed_simulations = dict[str, list[SimulationFailure]]
for protocol in protocol_type_names:
# TODO: Parametrize this
decoder = ThirdPartyPoolTychoDecoder(
"CurveSwapAdapter.evm.runtime", 0, False
)
stream_adapter = TychoPoolStateStreamAdapter(
tycho_url="0.0.0.0:4242",
protocol=protocol,
decoder=decoder,
blockchain=self._chain,
)
snapshot_message = stream_adapter.build_snapshot_message(
protocol_components, protocol_states, contract_state
)
decoded = stream_adapter.process_snapshot(block, snapshot_message)
for pool_state in decoded.pool_states.values():
pool_id = pool_state.id_
protocol_balances = token_balances.get(pool_id)
if not protocol_balances:
raise ValueError(f"Missing balances for pool {pool_id}")
for sell_token, buy_token in itertools.permutations(
pool_state.tokens, 2
):
try:
# Try to sell 0.1% of the protocol balance
sell_amount = Decimal("0.001") * sell_token.from_onchain_amount(
protocol_balances[sell_token.address]
)
amount_out, gas_used, _ = pool_state.get_amount_out(
sell_token, sell_amount, buy_token
)
# TODO: Should we validate this with an archive node or RPC reader?
print(
f"Amount out for {pool_id}: {sell_amount} {sell_token} -> {amount_out} {buy_token} - "
f"Gas used: {gas_used}"
)
except Exception as e:
print(
f"Error simulating get_amount_out for {pool_id}: {sell_token} -> {buy_token}. "
f"Error: {e}"
)
if pool_id not in failed_simulations:
failed_simulations[pool_id] = []
failed_simulations[pool_id].append(
SimulationFailure(
pool_id=pool_id,
sell_token=sell_token,
buy_token=buy_token,
error=str(e),
)
)
continue
@staticmethod
def build_spkg(yaml_file_path: str, modify_func: callable) -> str:
"""Build a Substreams package with modifications to the YAML file."""

View File

@@ -1,14 +1,14 @@
import os
import signal
import subprocess
import threading
import time
import requests
import subprocess
import os
import psycopg2
from psycopg2 import sql
from pathlib import Path
import psycopg2
import requests
from psycopg2 import sql
binary_path = Path(__file__).parent / "tycho-indexer"
@@ -148,6 +148,16 @@ class TychoRunner:
response = requests.post(url, headers=headers, json=data)
return response.json()
@staticmethod
def get_contract_state() -> dict:
"""Retrieve contract state from the RPC server."""
url = "http://0.0.0.0:4242/v1/ethereum/contract_state"
headers = {"accept": "application/json", "Content-Type": "application/json"}
data = {}
response = requests.post(url, headers=headers, json=data)
return response.json()
@staticmethod
def empty_database(db_url: str) -> None:
"""Drop and recreate the Tycho indexer database."""

View File

@@ -0,0 +1,250 @@
[
{
"inputs": [
{
"internalType": "uint256",
"name": "limit",
"type": "uint256"
}
],
"name": "LimitExceeded",
"type": "error"
},
{
"inputs": [
{
"internalType": "string",
"name": "reason",
"type": "string"
}
],
"name": "NotImplemented",
"type": "error"
},
{
"inputs": [
{
"internalType": "string",
"name": "reason",
"type": "string"
}
],
"name": "Unavailable",
"type": "error"
},
{
"inputs": [
{
"internalType": "bytes32",
"name": "poolId",
"type": "bytes32"
},
{
"internalType": "contract IERC20",
"name": "sellToken",
"type": "address"
},
{
"internalType": "contract IERC20",
"name": "buyToken",
"type": "address"
}
],
"name": "getCapabilities",
"outputs": [
{
"internalType": "enum ISwapAdapterTypes.Capability[]",
"name": "capabilities",
"type": "uint8[]"
}
],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "bytes32",
"name": "poolId",
"type": "bytes32"
},
{
"internalType": "contract IERC20",
"name": "sellToken",
"type": "address"
},
{
"internalType": "contract IERC20",
"name": "buyToken",
"type": "address"
}
],
"name": "getLimits",
"outputs": [
{
"internalType": "uint256[]",
"name": "limits",
"type": "uint256[]"
}
],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "uint256",
"name": "offset",
"type": "uint256"
},
{
"internalType": "uint256",
"name": "limit",
"type": "uint256"
}
],
"name": "getPoolIds",
"outputs": [
{
"internalType": "bytes32[]",
"name": "ids",
"type": "bytes32[]"
}
],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "bytes32",
"name": "poolId",
"type": "bytes32"
}
],
"name": "getTokens",
"outputs": [
{
"internalType": "contract IERC20[]",
"name": "tokens",
"type": "address[]"
}
],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "bytes32",
"name": "poolId",
"type": "bytes32"
},
{
"internalType": "contract IERC20",
"name": "sellToken",
"type": "address"
},
{
"internalType": "contract IERC20",
"name": "buyToken",
"type": "address"
},
{
"internalType": "uint256[]",
"name": "specifiedAmounts",
"type": "uint256[]"
}
],
"name": "price",
"outputs": [
{
"components": [
{
"internalType": "uint256",
"name": "numerator",
"type": "uint256"
},
{
"internalType": "uint256",
"name": "denominator",
"type": "uint256"
}
],
"internalType": "struct ISwapAdapterTypes.Fraction[]",
"name": "prices",
"type": "tuple[]"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "bytes32",
"name": "poolId",
"type": "bytes32"
},
{
"internalType": "contract IERC20",
"name": "sellToken",
"type": "address"
},
{
"internalType": "contract IERC20",
"name": "buyToken",
"type": "address"
},
{
"internalType": "enum ISwapAdapterTypes.OrderSide",
"name": "side",
"type": "uint8"
},
{
"internalType": "uint256",
"name": "specifiedAmount",
"type": "uint256"
}
],
"name": "swap",
"outputs": [
{
"components": [
{
"internalType": "uint256",
"name": "calculatedAmount",
"type": "uint256"
},
{
"internalType": "uint256",
"name": "gasUsed",
"type": "uint256"
},
{
"components": [
{
"internalType": "uint256",
"name": "numerator",
"type": "uint256"
},
{
"internalType": "uint256",
"name": "denominator",
"type": "uint256"
}
],
"internalType": "struct ISwapAdapterTypes.Fraction",
"name": "price",
"type": "tuple"
}
],
"internalType": "struct ISwapAdapterTypes.Trade",
"name": "trade",
"type": "tuple"
}
],
"stateMutability": "nonpayable",
"type": "function"
}
]

View File

@@ -62,7 +62,7 @@ class ThirdPartyPoolTychoDecoder:
adapter_contract_name=self.adapter_contract,
minimum_gas=self.minimum_gas,
hard_sell_limit=self.hard_limit,
trace=True,
trace=False,
**optional_attributes,
)

View File

@@ -5,7 +5,7 @@ from fractions import Fraction
from logging import getLogger
from typing import Union
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr
Address = str
@@ -30,6 +30,7 @@ class EthereumToken(BaseModel):
address: str
decimals: int
gas: Union[int, list[int]] = 29000
_hash: int = PrivateAttr(default=None)
def to_onchain_amount(self, amount: Union[float, Decimal, str]) -> int:
"""Converts floating-point numerals to an integer, by shifting right by the
@@ -62,7 +63,7 @@ class EthereumToken(BaseModel):
Quantize is needed for UniswapV2.
"""
with localcontext(self._dec_context):
with localcontext(Context(rounding=ROUND_FLOOR, prec=256)):
if isinstance(onchain_amount, Fraction):
return (
Decimal(onchain_amount.numerator)
@@ -80,6 +81,22 @@ class EthereumToken(BaseModel):
amount = Decimal(str(onchain_amount)) / Decimal(10 ** self.decimals)
return amount
def __repr__(self):
return self.symbol
def __str__(self):
return self.symbol
def __eq__(self, other) -> bool:
# this is faster than calling custom __hash__, due to cache check
return other.address == self.address
def __hash__(self) -> int:
if self._hash is None:
# caching the hash saves time during graph search
self._hash = hash(self.address)
return self._hash
class DatabaseType(Enum):
# Make call to the node each time it needs a storage (unless cached from a previous call).

View File

@@ -5,7 +5,7 @@ from copy import deepcopy
from decimal import Decimal
from fractions import Fraction
from logging import getLogger
from typing import Optional, cast, TypeVar, Annotated, DefaultDict
from typing import Optional, cast, TypeVar, Annotated
from eth_typing import HexStr
from protosim_py import SimulationEngine, AccountInfo
@@ -54,10 +54,10 @@ class ThirdPartyPool(BaseModel):
"""The contract address for where protocol balances are stored (i.e. a vault contract).
If given, balances will be overwritten here instead of on the pool contract during simulations."""
block_lasting_overwrites: DefaultDict[
block_lasting_overwrites: defaultdict[
Address,
Annotated[dict[int, int], Field(default_factory=lambda: defaultdict[dict])],
]
] = Field(default_factory=lambda: defaultdict(dict))
"""Storage overwrites that will be applied to all simulations. They will be cleared
when ``clear_all_cache`` is called, i.e. usually at each block. Hence the name."""
@@ -97,6 +97,18 @@ class ThirdPartyPool(BaseModel):
return
else:
engine = create_engine([t.address for t in self.tokens], trace=self.trace)
engine.init_account(
address="0x0000000000000000000000000000000000000000",
account=AccountInfo(balance=0, nonce=0),
mocked=False,
permanent_storage=None,
)
engine.init_account(
address="0x0000000000000000000000000000000000000004",
account=AccountInfo(balance=0, nonce=0),
mocked=False,
permanent_storage=None,
)
engine.init_account(
address=ADAPTER_ADDRESS,
account=AccountInfo(
@@ -116,6 +128,7 @@ class ThirdPartyPool(BaseModel):
)
self._engine = engine
def _set_spot_prices(self):
"""Set the spot prices for this pool.
We currently require the price function capability for now.

View File

@@ -3,7 +3,6 @@ import json
import platform
import time
from asyncio.subprocess import STDOUT, PIPE
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal
@@ -14,11 +13,11 @@ from typing import Any, Optional, Dict
import requests
from protosim_py import AccountUpdate, AccountInfo, BlockHeader
from .pool_state import ThirdPartyPool
from .constants import TYCHO_CLIENT_LOG_FOLDER, TYCHO_CLIENT_FOLDER
from .decoders import ThirdPartyPoolTychoDecoder
from .exceptions import APIRequestError, TychoClientException
from .models import Blockchain, EVMBlock, EthereumToken, SynchronizerState, Address
from .pool_state import ThirdPartyPool
from .tycho_db import TychoDBSingleton
from .utils import create_engine
@@ -117,7 +116,6 @@ class BlockProtocolChanges:
pool_states: dict[Address, ThirdPartyPool]
"""All updated pools"""
removed_pools: set[Address]
sync_states: dict[str, SynchronizerState]
deserialization_time: float
"""The time it took to deserialize the pool states from the tycho feed message"""
@@ -153,7 +151,7 @@ class TychoPoolStateStreamAdapter:
# Create engine
# TODO: This should be initialized outside the adapter?
TychoDBSingleton.initialize(tycho_http_url=self.tycho_url)
self._engine = create_engine([], trace=True)
self._engine = create_engine([], trace=False)
# Loads tokens from Tycho
self._tokens: dict[str, EthereumToken] = TokenLoader(
@@ -162,10 +160,6 @@ class TychoPoolStateStreamAdapter:
min_token_quality=self.min_token_quality,
).get_tokens()
# TODO: Check if it's necessary
self.ignored_pools = []
self.vm_contracts = defaultdict(list)
async def start(self):
"""Start the tycho-client Rust binary through subprocess"""
# stdout=PIPE means that the output is piped directly to this Python process
@@ -240,23 +234,31 @@ class TychoPoolStateStreamAdapter:
error_msg += f" Tycho logs: {last_lines}"
log.exception(error_msg)
raise Exception("Tycho-client failed.")
return self._process_message(msg)
return self.process_tycho_message(msg)
def _process_message(self, msg) -> BlockProtocolChanges:
try:
sync_state = msg["sync_states"][self.protocol]
state_msg = msg["state_msgs"][self.protocol]
log.info(f"Received sync state for {self.protocol}: {sync_state}")
if not sync_state["status"] != SynchronizerState.ready.value:
raise ValueError("Tycho-indexer is not synced")
except KeyError:
raise ValueError("Invalid message received from tycho-client.")
@staticmethod
def build_snapshot_message(
protocol_components: dict, protocol_states: dict, contract_states: dict
) -> dict[str, ThirdPartyPool]:
vm_states = {state["address"]: state for state in contract_states["accounts"]}
states = {}
for component in protocol_components["protocol_components"]:
pool_id = component["id"]
states[pool_id] = {"component": component}
for state in protocol_states["states"]:
pool_id = state["component_id"]
if pool_id not in states:
log.warning(f"State for pool {pool_id} not found in components")
continue
states[pool_id]["state"] = state
snapshot = {"vm_storage": vm_states, "states": states}
start = time.monotonic()
return snapshot
removed_pools = set()
decoded_count = 0
failed_count = 0
def process_tycho_message(self, msg) -> BlockProtocolChanges:
self._validate_sync_states(msg)
state_msg = msg["state_msgs"][self.protocol]
block = EVMBlock(
id=msg["block"]["id"],
@@ -264,24 +266,30 @@ class TychoPoolStateStreamAdapter:
hash_=msg["block"]["hash"],
)
self._process_vm_storage(state_msg["snapshots"]["vm_storage"], block)
return self.process_snapshot(block, state_msg["snapshot"])
# decode new pools
def process_snapshot(
self, block: EVMBlock, state_msg: dict
) -> BlockProtocolChanges:
start = time.monotonic()
removed_pools = set()
decoded_count = 0
failed_count = 0
self._process_vm_storage(state_msg["vm_storage"], block)
# decode new components
decoded_pools, failed_pools = self._decoder.decode_snapshot(
state_msg["snapshots"]["states"], block, self._tokens
state_msg["states"], block, self._tokens
)
decoded_count += len(decoded_pools)
failed_count += len(failed_pools)
for addr, p in decoded_pools.items():
self.vm_contracts[addr].append(p.id_)
decoded_pools = {
p.id_: p for p in decoded_pools.values()
} # remap pools to their pool ids
deserialization_time = time.monotonic() - start
total = decoded_count + failed_count
log.debug(
f"Received {total} snapshots. n_decoded: {decoded_count}, n_failed: {failed_count}"
@@ -296,6 +304,15 @@ class TychoPoolStateStreamAdapter:
deserialization_time=round(deserialization_time, 3),
)
def _validate_sync_states(self, msg):
try:
sync_state = msg["sync_states"][self.protocol]
log.info(f"Received sync state for {self.protocol}: {sync_state}")
if not sync_state["status"] != SynchronizerState.ready.value:
raise ValueError("Tycho-indexer is not synced")
except KeyError:
raise ValueError("Invalid message received from tycho-client.")
def _process_vm_storage(self, storage: dict[str, Any], block: EVMBlock):
vm_updates = []
for storage_update in storage.values():
@@ -325,4 +342,4 @@ class TychoPoolStateStreamAdapter:
)
block_header = BlockHeader(block.id, block.hash_, int(block.ts.timestamp()))
self._db.update(vm_updates, block_header)
TychoDBSingleton.get_instance().update(vm_updates, block_header)