Start using external modules

This commit is contained in:
Thales
2024-08-05 19:58:10 -03:00
committed by tvinagre
parent 8ea02613a2
commit d893ab264c
9 changed files with 171 additions and 176 deletions

View File

@@ -6,29 +6,20 @@ def main() -> None:
parser = argparse.ArgumentParser(
description="Run indexer within a specified range of blocks"
)
parser.add_argument("--package", type=str, help="Name of the package to test.")
parser.add_argument(
"--package", type=str, help="Name of the package to test."
)
parser.add_argument(
"--tycho-logs",
action="store_true",
help="Flag to activate logs from Tycho.",
"--tycho-logs", action="store_true", help="Flag to activate logs from Tycho."
)
parser.add_argument(
"--db-url", type=str, help="Postgres database URL for the Tycho indexer."
)
parser.add_argument(
"--vm-traces",
action="store_true",
help="Enable tracing during vm simulations.",
"--vm-traces", action="store_true", help="Enable tracing during vm simulations."
)
args = parser.parse_args()
test_runner = TestRunner(
args.package,
args.tycho_logs,
db_url=args.db_url,
vm_traces=args.vm_traces,
args.package, args.tycho_logs, db_url=args.db_url, vm_traces=args.vm_traces
)
test_runner.run_tests()

View File

@@ -2,22 +2,31 @@ import itertools
import os
import shutil
import subprocess
import traceback
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from pathlib import Path
import traceback
import yaml
from pydantic import BaseModel
from tycho_client.dto import (
Chain,
ProtocolComponentsParams,
ProtocolStateParams,
ContractStateParams, ProtocolComponent, ResponseProtocolState, HexBytes,
)
from tycho_client.rpc_client import TychoRPCClient
from tycho_client.stream import TychoStream
from adapter_handler import AdapterContractHandler
from tycho_client.decoders import ThirdPartyPoolTychoDecoder
from tycho_client.models import Blockchain, EVMBlock
from tycho_client.tycho_adapter import TychoPoolStateStreamAdapter
from .adapter_handler import AdapterContractHandler
from .evm import get_token_balance, get_block_header
from .tycho import TychoRunner
from evm import get_token_balance, get_block_header
from tycho import TychoRunner, TychoRPCClient
# from tycho_client.decoders import ThirdPartyPoolTychoDecoder
# from tycho_client.models import Blockchain, EVMBlock
# from tycho_client.tycho_adapter import TychoPoolStateStreamAdapter
class TestResult:
@@ -64,7 +73,7 @@ class TestRunner:
self.tycho_rpc_client = TychoRPCClient()
self.db_url = db_url
self._vm_traces = vm_traces
self._chain = Blockchain.ethereum
self._chain = Chain.ethereum
def run_tests(self) -> None:
"""Run all tests specified in the configuration."""
@@ -96,22 +105,27 @@ class TestRunner:
def validate_state(self, expected_state: dict, stop_block: int) -> TestResult:
"""Validate the current protocol state against the expected state."""
protocol_components = self.tycho_rpc_client.get_protocol_components()
protocol_states = self.tycho_rpc_client.get_protocol_state()
components = {
component["id"]: component
for component in protocol_components["protocol_components"]
protocol_components: list[ProtocolComponent] = self.tycho_rpc_client.get_protocol_components(
ProtocolComponentsParams(protocol_system="test_protocol")
)
protocol_states: list[ResponseProtocolState] = self.tycho_rpc_client.get_protocol_state(
ProtocolStateParams(protocol_system="test_protocol")
)
components_by_id = {
component.id: component
for component in protocol_components
}
try:
for expected_component in expected_state.get("protocol_components", []):
comp_id = expected_component["id"].lower()
if comp_id not in components:
if comp_id not in components_by_id:
return TestResult.Failed(
f"'{comp_id}' not found in protocol components."
)
component = components[comp_id]
# TODO: Manipulate pydantic objects instead of dict
component = components_by_id[comp_id].dict()
for key, value in expected_component.items():
if key not in ["tokens", "static_attributes", "creation_tx"]:
continue
@@ -131,11 +145,10 @@ class TestRunner:
f"Value mismatch for key '{key}': {value} != {component[key]}"
)
token_balances: dict[str, dict[str, int]] = defaultdict(dict)
for component in protocol_components["protocol_components"]:
comp_id = component["id"].lower()
for token in component["tokens"]:
token_lower = token.lower()
token_balances: dict[str, dict[HexBytes, int]] = defaultdict(dict)
for component in protocol_components:
comp_id = component.id.lower()
for token in component.tokens:
state = next(
(
s
@@ -145,11 +158,11 @@ class TestRunner:
None,
)
if state:
balance_hex = state["balances"].get(token_lower, "0x0")
balance_hex = state["balances"].get(token, "0x0")
else:
balance_hex = "0x0"
tycho_balance = int(balance_hex, 16)
token_balances[comp_id][token_lower] = tycho_balance
token_balances[comp_id][token] = tycho_balance
if self.config["skip_balance_check"] is not True:
node_balance = get_token_balance(token, comp_id, stop_block)
@@ -158,13 +171,14 @@ class TestRunner:
f"Balance mismatch for {comp_id}:{token} at block {stop_block}: got {node_balance} "
f"from rpc call and {tycho_balance} from Substreams"
)
contract_states = self.tycho_rpc_client.get_contract_state()
contract_states = self.tycho_rpc_client.get_contract_state(
ContractStateParams()
)
filtered_components = {
"protocol_components": [
pc
for pc in protocol_components["protocol_components"]
if pc["id"]
in [
for pc in protocol_components
if pc.id in [
c["id"].lower()
for c in expected_state["protocol_components"]
if c.get("skip_simulation", False) is False
@@ -193,9 +207,9 @@ class TestRunner:
def simulate_get_amount_out(
self,
block_number: int,
protocol_states: dict,
protocol_components: dict,
contract_state: dict,
protocol_states: ResponseProtocolState,
protocol_components: list[ProtocolComponent],
contract_state: Contract,
) -> dict[str, list[SimulationFailure]]:
protocol_type_names = self.config["protocol_type_names"]
@@ -224,15 +238,22 @@ class TestRunner:
self.config["adapter_build_args"],
)
decoder = ThirdPartyPoolTychoDecoder(
adapter_contract, 0, trace=self._vm_traces
)
stream_adapter = TychoPoolStateStreamAdapter(
# decoder = ThirdPartyPoolTychoDecoder(
# adapter_contract, 0, trace=self._vm_traces
# )
# stream_adapter = TychoPoolStateStreamAdapter(
# tycho_url="0.0.0.0:4242",
# protocol=protocol,
# decoder=decoder,
# blockchain=self._chain,
# )
stream_adapter = TychoStream(
tycho_url="0.0.0.0:4242",
protocol=protocol,
decoder=decoder,
exchanges=[protocol],
min_tvl=Decimal("0"),
blockchain=self._chain,
)
snapshot_message = stream_adapter.build_snapshot_message(
protocol_components, protocol_states, contract_state
)

View File

@@ -39,40 +39,13 @@ def find_binary_file(file_name):
binary_path = find_binary_file("tycho-indexer")
class TychoRPCClient:
def __init__(self, rpc_url: str = "http://0.0.0.0:4242"):
self.rpc_url = rpc_url
def get_protocol_components(self) -> dict:
"""Retrieve protocol components from the RPC server."""
url = self.rpc_url + "/v1/ethereum/protocol_components"
headers = {"accept": "application/json", "Content-Type": "application/json"}
data = {"protocol_system": "test_protocol"}
response = requests.post(url, headers=headers, json=data)
return response.json()
def get_protocol_state(self) -> dict:
"""Retrieve protocol state from the RPC server."""
url = self.rpc_url + "/v1/ethereum/protocol_state"
headers = {"accept": "application/json", "Content-Type": "application/json"}
data = {}
response = requests.post(url, headers=headers, json=data)
return response.json()
def get_contract_state(self) -> dict:
"""Retrieve contract state from the RPC server."""
url = self.rpc_url + "/v1/ethereum/contract_state?include_balances=false"
headers = {"accept": "application/json", "Content-Type": "application/json"}
data = {}
response = requests.post(url, headers=headers, json=data)
return response.json()
class TychoRunner:
def __init__(self, db_url: str, with_binary_logs: bool = False, initialized_accounts: list[str] = None):
def __init__(
self,
db_url: str,
with_binary_logs: bool = False,
initialized_accounts: list[str] = None,
):
self.with_binary_logs = with_binary_logs
self._db_url = db_url
self._initialized_accounts = initialized_accounts or []
@@ -112,7 +85,12 @@ class TychoRunner:
str(end_block + 2),
"--initialization-block",
str(start_block),
] + (["--initialized-accounts", ",".join(all_accounts)] if all_accounts else []),
]
+ (
["--initialized-accounts", ",".join(all_accounts)]
if all_accounts
else []
),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
@@ -151,12 +129,7 @@ class TychoRunner:
env["RUST_LOG"] = "info"
process = subprocess.Popen(
[
binary_path,
"--database-url",
self._db_url,
"rpc"
],
[binary_path, "--database-url", self._db_url, "rpc"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
@@ -206,7 +179,7 @@ class TychoRunner:
def empty_database(db_url: str) -> None:
"""Drop and recreate the Tycho indexer database."""
try:
conn = psycopg2.connect(db_url[:db_url.rfind('/')])
conn = psycopg2.connect(db_url[: db_url.rfind("/")])
conn.autocommit = True
cursor = conn.cursor()

View File

@@ -7,6 +7,6 @@ TYCHO_CLIENT_LOG_FOLDER = TYCHO_CLIENT_FOLDER / "logs"
EXTERNAL_ACCOUNT: Final[str] = "0xf847a638E44186F3287ee9F8cAF73FF4d4B80784"
"""This is a dummy address used as a transaction sender"""
UINT256_MAX: Final[int] = 2 ** 256 - 1
UINT256_MAX: Final[int] = 2**256 - 1
MAX_BALANCE: Final[int] = UINT256_MAX // 2
"""0.5 of the maximal possible balance to avoid overflow errors"""

View File

@@ -85,20 +85,29 @@ class ThirdPartyPoolTychoDecoder:
while f"stateless_contract_addr_{index}" in static_attributes:
encoded_address = static_attributes[f"stateless_contract_addr_{index}"]
decoded = bytes.fromhex(
encoded_address[2:] if encoded_address.startswith('0x') else encoded_address).decode('utf-8')
encoded_address[2:]
if encoded_address.startswith("0x")
else encoded_address
).decode("utf-8")
if decoded.startswith("call"):
address = ThirdPartyPoolTychoDecoder.get_address_from_call(block_number, decoded)
address = ThirdPartyPoolTychoDecoder.get_address_from_call(
block_number, decoded
)
else:
address = decoded
code = static_attributes.get(f"stateless_contract_code_{index}") or get_code_for_address(address)
code = static_attributes.get(
f"stateless_contract_code_{index}"
) or get_code_for_address(address)
stateless_contracts[address] = code
index += 1
index = 0
while f"stateless_contract_addr_{index}" in attributes:
address = attributes[f"stateless_contract_addr_{index}"]
code = attributes.get(f"stateless_contract_code_{index}") or get_code_for_address(address)
code = attributes.get(
f"stateless_contract_code_{index}"
) or get_code_for_address(address)
stateless_contracts[address] = code
index += 1
return {
@@ -118,15 +127,17 @@ class ThirdPartyPoolTychoDecoder:
permanent_storage=None,
)
selector = keccak(text=decoded.split(":")[-1])[:4]
sim_result = engine.run_sim(SimulationParameters(
sim_result = engine.run_sim(
SimulationParameters(
data=bytearray(selector),
to=decoded.split(':')[1],
to=decoded.split(":")[1],
block_number=block_number,
timestamp=int(time.time()),
overrides={},
caller=EXTERNAL_ACCOUNT,
value=0,
))
)
)
address = eth_abi.decode(["address"], bytearray(sim_result.result))
return address[0]

View File

@@ -41,7 +41,7 @@ class EthereumToken(BaseModel):
log.warning(f"Expected variable of type Decimal. Got {type(amount)}.")
with localcontext(Context(rounding=ROUND_FLOOR, prec=256)):
amount = Decimal(str(amount)) * (10 ** self.decimals)
amount = Decimal(str(amount)) * (10**self.decimals)
try:
amount = amount.quantize(Decimal("1.0"))
except InvalidOperation:
@@ -68,17 +68,17 @@ class EthereumToken(BaseModel):
return (
Decimal(onchain_amount.numerator)
/ Decimal(onchain_amount.denominator)
/ Decimal(10 ** self.decimals)
/ Decimal(10**self.decimals)
).quantize(Decimal(f"{1 / 10 ** self.decimals}"))
if quantize is True:
try:
amount = (
Decimal(str(onchain_amount)) / 10 ** self.decimals
Decimal(str(onchain_amount)) / 10**self.decimals
).quantize(Decimal(f"{1 / 10 ** self.decimals}"))
except InvalidOperation:
amount = Decimal(str(onchain_amount)) / Decimal(10 ** self.decimals)
amount = Decimal(str(onchain_amount)) / Decimal(10**self.decimals)
else:
amount = Decimal(str(onchain_amount)) / Decimal(10 ** self.decimals)
amount = Decimal(str(onchain_amount)) / Decimal(10**self.decimals)
return amount
def __repr__(self):

View File

@@ -142,7 +142,7 @@ class ThirdPartyPool(BaseModel):
if Capability.ScaledPrice in self.capabilities:
self.spot_prices[(t0, t1)] = frac_to_decimal(frac)
else:
scaled = frac * Fraction(10 ** t0.decimals, 10 ** t1.decimals)
scaled = frac * Fraction(10**t0.decimals, 10**t1.decimals)
self.spot_prices[(t0, t1)] = frac_to_decimal(scaled)
def _ensure_capability(self, capability: Capability):

View File

@@ -181,7 +181,7 @@ class TychoPoolStateStreamAdapter:
log.debug(f"Starting tycho-client binary at {bin_path}. CMD: {cmd}")
self.tycho_client = await asyncio.create_subprocess_exec(
str(bin_path), *cmd, stdout=PIPE, stderr=STDOUT, limit=2 ** 64
str(bin_path), *cmd, stdout=PIPE, stderr=STDOUT, limit=2**64
)
@staticmethod

View File

@@ -121,8 +121,7 @@ class ERC20OverwriteFactory:
"""
self._overwrites[self._total_supply_slot] = supply
log.log(
5,
f"Override total supply: token={self._token.address} supply={supply}"
5, f"Override total supply: token={self._token.address} supply={supply}"
)
def get_protosim_overwrites(self) -> dict[Address, dict[int, int]]: