Rename adapter_handler to adapter_builder and move responsibilities

This commit is contained in:
Thales Lima
2024-08-08 22:55:18 +02:00
committed by tvinagre
parent c0382fefdf
commit 5eea31db40
2 changed files with 83 additions and 68 deletions

View File

@@ -1,15 +1,33 @@
import os
import subprocess import subprocess
from typing import Optional from typing import Optional
class AdapterContractHandler: class AdapterContractBuilder:
@staticmethod def __init__(self, src_path: str):
self.src_path = src_path
def find_contract(self, adapter_contract: str):
"""
Finds the contract file in the provided source path.
:param adapter_contract: The contract name to be found.
:return: The path to the contract file.
"""
contract_path = os.path.join(
self.src_path,
"out",
f"{adapter_contract}.sol",
f"{adapter_contract}.evm.runtime",
)
if not os.path.exists(contract_path):
raise FileNotFoundError(f"Contract {adapter_contract} not found.")
return contract_path
def build_target( def build_target(
src_path: str, self, adapter_contract: str, signature: Optional[str], args: Optional[str]
adapter_contract: str, ) -> str:
signature: Optional[str],
args: Optional[str],
):
""" """
Runs the buildRuntime Bash script in a subprocess with the provided arguments. Runs the buildRuntime Bash script in a subprocess with the provided arguments.
@@ -17,6 +35,8 @@ class AdapterContractHandler:
:param adapter_contract: The contract name to be passed to the script. :param adapter_contract: The contract name to be passed to the script.
:param signature: The constructor signature to be passed to the script. :param signature: The constructor signature to be passed to the script.
:param args: The constructor arguments to be passed to the script. :param args: The constructor arguments to be passed to the script.
:return: The path to the contract file.
""" """
script_path = "scripts/buildRuntime.sh" script_path = "scripts/buildRuntime.sh"
@@ -27,7 +47,7 @@ class AdapterContractHandler:
# Running the bash script with the provided arguments # Running the bash script with the provided arguments
result = subprocess.run( result = subprocess.run(
[script_path, "-c", adapter_contract, "-s", signature, "-a", args], [script_path, "-c", adapter_contract, "-s", signature, "-a", args],
cwd=src_path, cwd=self.src_path,
capture_output=True, capture_output=True,
text=True, text=True,
check=True, check=True,
@@ -38,6 +58,8 @@ class AdapterContractHandler:
if result.stderr: if result.stderr:
print("Errors:\n", result.stderr) print("Errors:\n", result.stderr)
return self.find_contract(adapter_contract)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
print("Error Output:\n", e.stderr) print("Error Output:\n", e.stderr)

View File

@@ -33,7 +33,7 @@ from models import (
ProtocolComponentWithTestConfig, ProtocolComponentWithTestConfig,
ProtocolComponentExpectation, ProtocolComponentExpectation,
) )
from adapter_handler import AdapterContractHandler from adapter_builder import AdapterContractBuilder
from evm import get_token_balance, get_block_header from evm import get_token_balance, get_block_header
from tycho import TychoRunner from tycho import TychoRunner
from utils import build_snapshot_message, token_factory from utils import build_snapshot_message, token_factory
@@ -76,7 +76,9 @@ class TestRunner:
) )
self.config: IntegrationTestsConfig = parse_config(config_path) self.config: IntegrationTestsConfig = parse_config(config_path)
self.spkg_src = os.path.join(self.repo_root, "substreams", package) self.spkg_src = os.path.join(self.repo_root, "substreams", package)
self.adapters_src = os.path.join(self.repo_root, "evm") self.adapter_contract_builder = AdapterContractBuilder(
os.path.join(self.repo_root, "evm")
)
self.tycho_runner = TychoRunner( self.tycho_runner = TychoRunner(
db_url, with_binary_logs, self.config.initialized_accounts db_url, with_binary_logs, self.config.initialized_accounts
) )
@@ -233,71 +235,62 @@ class TestRunner:
) )
failed_simulations: dict[str, list[SimulationFailure]] = dict() failed_simulations: dict[str, list[SimulationFailure]] = dict()
for _ in protocol_type_names:
adapter_contract = os.path.join( try:
self.adapters_src, adapter_contract = self.adapter_contract_builder.find_contract(
"out", self.config.adapter_contract
f"{self.config.adapter_contract}.sol",
f"{self.config.adapter_contract}.evm.runtime",
) )
if not os.path.exists(adapter_contract): except FileNotFoundError:
print("Adapter contract not found. Building it ...") adapter_contract = self.adapter_contract_builder.build_target(
self.config.adapter_contract,
AdapterContractHandler.build_target( self.config.adapter_build_signature,
self.adapters_src, self.config.adapter_build_args,
self.config.adapter_contract,
self.config.adapter_build_signature,
self.config.adapter_build_args,
)
decoder = ThirdPartyPoolTychoDecoder(
token_factory_func=self._token_factory_func,
adapter_contract=adapter_contract,
minimum_gas=0,
trace=self._vm_traces,
) )
snapshot_message: Snapshot = build_snapshot_message( decoder = ThirdPartyPoolTychoDecoder(
protocol_states, protocol_components, contract_states token_factory_func=self._token_factory_func,
) adapter_contract=adapter_contract,
minimum_gas=0,
trace=self._vm_traces,
)
decoded = decoder.decode_snapshot(snapshot_message, block) snapshot_message: Snapshot = build_snapshot_message(
protocol_states, protocol_components, contract_states
)
for pool_state in decoded.values(): decoded = decoder.decode_snapshot(snapshot_message, block)
pool_id = pool_state.id_
if not pool_state.balances: for pool_state in decoded.values():
raise ValueError(f"Missing balances for pool {pool_id}") pool_id = pool_state.id_
for sell_token, buy_token in itertools.permutations( if not pool_state.balances:
pool_state.tokens, 2 raise ValueError(f"Missing balances for pool {pool_id}")
): for sell_token, buy_token in itertools.permutations(pool_state.tokens, 2):
# Try to sell 0.1% of the protocol balance # Try to sell 0.1% of the protocol balance
sell_amount = ( sell_amount = Decimal("0.001") * pool_state.balances[sell_token.address]
Decimal("0.001") * pool_state.balances[sell_token.address] try:
amount_out, gas_used, _ = pool_state.get_amount_out(
sell_token, sell_amount, buy_token
) )
try: print(
amount_out, gas_used, _ = pool_state.get_amount_out( f"Amount out for {pool_id}: {sell_amount} {sell_token} -> {amount_out} {buy_token} - "
sell_token, sell_amount, buy_token f"Gas used: {gas_used}"
)
except Exception as e:
print(
f"Error simulating get_amount_out for {pool_id}: {sell_token} -> {buy_token}. "
f"Error: {e}"
)
if pool_id not in failed_simulations:
failed_simulations[pool_id] = []
failed_simulations[pool_id].append(
SimulationFailure(
pool_id=pool_id,
sell_token=str(sell_token),
buy_token=str(buy_token),
error=str(e),
) )
print( )
f"Amount out for {pool_id}: {sell_amount} {sell_token} -> {amount_out} {buy_token} - " continue
f"Gas used: {gas_used}"
)
except Exception as e:
print(
f"Error simulating get_amount_out for {pool_id}: {sell_token} -> {buy_token}. "
f"Error: {e}"
)
if pool_id not in failed_simulations:
failed_simulations[pool_id] = []
failed_simulations[pool_id].append(
SimulationFailure(
pool_id=pool_id,
sell_token=str(sell_token),
buy_token=str(buy_token),
error=str(e),
)
)
continue
return failed_simulations return failed_simulations
@staticmethod @staticmethod