block cache fixes; refactor transaction signing by account

Tim
2024-07-16 21:40:25 -04:00
parent 43a2515a6d
commit e7ec80fdd8
6 changed files with 130 additions and 127 deletions

View File

@@ -13,9 +13,14 @@ from dexorder import NARG, config, current_w3
 # call but must instead use a factory :(
 class Account (LocalAccount):
+    @staticmethod
+    def get_named(account_name: str) -> Optional['Account']:
+        account = config.accounts.get(account_name)
+        return Account.get(account) if account else Account.get()
+
     @staticmethod
     # noinspection PyInitNewSignature
-    def get(account:[Union,str]=NARG) -> Optional[LocalAccount]:
+    def get(account:[Union,str]=NARG) -> Optional['Account']:
         if account is NARG:
             account = config.account
         if type(account) is not str:
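
Account.get_named resolves a signer by name from config.accounts and falls back to the default Account.get(); later in this commit the transaction handler's tag is passed as that name. A minimal sketch of the lookup, where 'vault_execution' is a hypothetical name rather than anything from this repo's config:

signer = Account.get_named('vault_execution')   # config.accounts.get(name) when configured, else the default account
if signer is None:
    raise RuntimeError('no default account configured either')
print(signer.address)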

View File

@@ -20,8 +20,6 @@ from datetime import timedelta
 from dexorder import config, blockchain, current_w3, now, ADDRESS_0
 from dexorder.bin.executable import execute
 from dexorder.blockchain.connection import create_w3
-from dexorder.blockstate import current_blockstate
-from dexorder.blockstate.state import FinalizedBlockState
 from dexorder.contract import get_deployment_address, ContractProxy, ERC20
 from dexorder.metadata import generate_metadata, init_generating_metadata
 from dexorder.pools import get_pool
@@ -148,14 +146,14 @@ async def main():
     log.debug(f'Mirroring tokens')
     txs = []
     for t in tokens:
-        info = await get_token_info(t)
         try:
+            info = await get_token_info(t)
             # anvil had trouble estimating the gas, so we hardcode it.
             tx = await mirrorenv.transact.mirrorToken(info, gas=1_000_000)
-            txs.append(tx.wait())
         except Exception:
             log.exception(f'Failed to mirror token {t}')
             exit(1)
+        txs.append(tx.wait())
     results = await asyncio.gather(*txs)
     if any(result['status'] != 1 for result in results):
         log.error('Mirroring a token reverted.')
@@ -196,17 +194,17 @@ async def main():
         while True:
             wake_up = now() + delay
             # log.debug(f'querying {pool}')
-            price = await get_pool_price(pool)
-            if price != last_prices.get(pool):
-                try:
+            try:
+                price = await get_pool_price(pool)
+                if price != last_prices.get(pool):
                     # anvil had trouble estimating the gas, so we hardcode it.
                     tx = await mirrorenv.transact.updatePool(pool, price, gas=1_000_000)  # this is a B.S. gas number
                     await tx.wait()
                     last_prices[pool] = price
                     log.debug(f'Mirrored {pool} {price}')
-                except Exception:
-                    log.debug(f'Could not update {pool}')
-                    continue
+            except Exception:
+                log.debug(f'Could not update {pool}')
+                continue
             try:
                 pool = next(pool_iter)
             except StopIteration:

View File

@@ -7,14 +7,13 @@ Use `await fetch_block()` to force an RPC query for the Block, adding that block
""" """
import logging import logging
from contextvars import ContextVar from contextvars import ContextVar
from typing import Union, Optional from typing import Union, Optional, Awaitable
from cachetools import LRUCache from cachetools import LRUCache
from dexorder import current_w3, NARG, config from dexorder import current_w3, config
from dexorder.base.block import Block, BlockInfo from dexorder.base.block import Block, BlockInfo
from dexorder.base.chain import current_chain from dexorder.base.chain import current_chain
from dexorder.util.async_dict import AsyncDict
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -29,28 +28,17 @@ async def get_block_timestamp(blockid: Union[bytes,int], block_number: int = Non
     return block.timestamp


-async def _cache_fetch(key: tuple[int, Union[int,bytes]], default: Union[Block, NARG]) -> Optional[Block]:
-    assert default is NARG
-    # try LRU cache first
-    try:
-        return _lru[key]
-    except KeyError:
-        pass
+async def _fetch(key: tuple[int, Union[int,bytes]]) -> Optional[Block]:
     # fetch from RPC
     chain_id, blockid = key
     # log.debug(f'block cache miss; fetching {chain_id} {blockid}')
     if type(blockid) is int:
-        result = await fetch_block_by_number(blockid, chain_id=chain_id)
+        return await fetch_block_by_number(blockid, chain_id=chain_id)
     else:
-        result = await fetch_block(blockid, chain_id=chain_id)
-    if result is None:
-        # log.debug(f'Could not lookup block {blockid}')
-        return None # do not cache
-    _lru[key] = result
-    return result
+        return await fetch_block(blockid, chain_id=chain_id)


 _lru = LRUCache[tuple[int, bytes], Block](maxsize=128)
-_cache = AsyncDict[tuple[int, bytes], Block](fetch=_cache_fetch)
+_fetches:dict[tuple[int, bytes], Awaitable[Block]] = {}


 def cache_block(block: Block):
@@ -60,7 +48,21 @@ def cache_block(block: Block):
 async def get_block(blockhash, *, chain_id=None) -> Block:
     if chain_id is None:
         chain_id = current_chain.get().id
-    return await _cache.get((chain_id, blockhash))
+    key = chain_id, blockhash
+    # try LRU cache first
+    try:
+        return _lru[key]
+    except KeyError:
+        pass
+    # check if another thread is already fetching
+    fetch = _fetches.get(key)
+    if fetch is not None:
+        return await fetch
+    # otherwise initiate our own fetch
+    fetch = _fetches[key] = _fetch(key)
+    result = await fetch
+    del _fetches[key]
+    return result


 async def fetch_block_by_number(height: int, *, chain_id=None) -> Block:
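
The new get_block collapses concurrent lookups of the same (chain_id, blockid) key into one RPC call by memoizing the in-flight awaitable in _fetches, falling back to the LRU for completed fetches. Below is a self-contained sketch of the same idea with generic names (fetch_remote, get, demo are illustrative, not from this module); unlike the hunk above it wraps the coroutine in a Task so any number of waiters can await it, and it writes the result back into the LRU:

import asyncio
from cachetools import LRUCache

_cache = LRUCache(maxsize=128)                 # completed results
_inflight: dict[str, asyncio.Task] = {}        # one Task per key currently being fetched

async def fetch_remote(key: str) -> str:
    await asyncio.sleep(0.1)                   # stand-in for the RPC round trip
    return f'value-for-{key}'

async def get(key: str) -> str:
    try:
        return _cache[key]                     # fast path: already cached
    except KeyError:
        pass
    task = _inflight.get(key)                  # is someone already fetching this key?
    if task is None:
        task = _inflight[key] = asyncio.ensure_future(fetch_remote(key))
    try:
        result = await task
    finally:
        _inflight.pop(key, None)               # whichever waiter finishes first clears the slot
    _cache[key] = result
    return result

async def demo():
    a, b = await asyncio.gather(get('x'), get('x'))   # two callers, one fetch_remote call
    assert a == b

asyncio.run(demo())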

View File

@@ -4,9 +4,9 @@ from typing import Optional
 import eth_account
 from web3.exceptions import Web3Exception
-from web3.types import TxReceipt
-from dexorder import current_w3
+from web3.types import TxReceipt, TxData
+from dexorder import current_w3, Account
 from dexorder.base.account import current_account
 from dexorder.blockstate.fork import current_fork
 from dexorder.util import hexstr
@@ -15,11 +15,17 @@ log = logging.getLogger(__name__)
 class ContractTransaction:
-    def __init__(self, id_bytes: bytes, rawtx: Optional[bytes] = None):
-        self.id_bytes = id_bytes
-        self.id = hexstr(self.id_bytes)
-        self.data = rawtx
-        self.receipt: Optional[TxReceipt] = None
+    def __init__(self, tx: TxData):
+        # This is the standard RPC transaction dictionary
+        self.tx = tx
+        # These three fields are populated only after signing
+        self.id_bytes: Optional[bytes] = None
+        self.id: Optional[str] = None
+        self.data: Optional[bytes] = None
+        # This field is populated only after the transaction has been mined
+        self.receipt: Optional[TxReceipt] = None  # todo could be multiple receipts for different branches!

     def __repr__(self):
         # todo this is from an old status system
@@ -31,6 +37,14 @@ class ContractTransaction:
         self.receipt = await current_w3.get().eth.wait_for_transaction_receipt(self.id)
         return self.receipt

+    async def sign(self, account: Account):
+        self.tx['from'] = account.address
+        self.tx['nonce'] = await account.next_nonce()
+        signed = eth_account.Account.sign_transaction(self.tx, private_key=account.key)
+        self.data = signed['rawTransaction']
+        self.id_bytes = signed['hash']
+        self.id = hexstr(self.id_bytes)
+

 class DeployTransaction (ContractTransaction):
     def __init__(self, contract: 'ContractProxy', id_bytes: bytes):
@@ -62,27 +76,25 @@ def call_wrapper(addr, name, func):
 def transact_wrapper(addr, name, func):
     async def f(*args, **kwargs):
         try:
-            tx_id = await func(*args).transact(kwargs)
+            tx = await func(*args).build_transaction(kwargs)
+            ct = ContractTransaction(tx)
+            account = Account.get()
+            if account is None:
+                raise ValueError(f'No account to sign transaction {addr}.{name}()')
+            await ct.sign(account)
+            tx_id = await current_w3.get().eth.send_raw_transaction(ct.data)
+            assert tx_id == ct.id_bytes
+            return ct
         except Web3Exception as e:
             e.args += addr, name
             raise e
-        return ContractTransaction(tx_id)
     return f


 def build_wrapper(addr, name, func):
     async def f(*args, **kwargs):
-        try:
-            account = current_account.get()
-        except LookupError:
-            account = None
-        if account is None:
-            raise RuntimeError(f'Cannot invoke transaction {addr}.{name}() without setting an Account.')
         tx = await func(*args).build_transaction(kwargs)
-        tx['from'] = account.address
-        tx['nonce'] = await account.next_nonce()
-        signed = eth_account.Account.sign_transaction(tx, private_key=account.key)
-        return ContractTransaction(signed['hash'], signed['rawTransaction'])
+        return ContractTransaction(tx)
     return f
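
With this change a `.transact.` call builds the transaction dict, signs it with the configured Account, broadcasts the raw bytes, and returns the ContractTransaction, while the build wrapper only wraps the unsigned dict for later signing. A usage sketch against these wrappers; the `vault` handle, its `execute` method, and the `.build.` attribute name are assumptions, only `.transact.` and `wait()` appear elsewhere in this commit:

# Hypothetical ContractProxy handle and method names.
ct = await vault.transact.execute(order_id, gas=1_000_000)   # built, signed with Account.get(), and sent
receipt = await ct.wait()                                     # wait_for_transaction_receipt on ct.id
if receipt['status'] != 1:
    log.error(f'transaction {ct.id} reverted')

unsigned = await vault.build.execute(order_id)                # ContractTransaction around the unsigned tx dict
await unsigned.sign(Account.get())                            # fills in id_bytes, id and data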

View File

@@ -1,11 +1,12 @@
 import logging
 from abc import abstractmethod
+from typing import Optional
 from uuid import uuid4
 from sqlalchemy import select
 from web3.exceptions import TransactionNotFound
-from dexorder import db, current_w3
+from dexorder import db, current_w3, Account
 from dexorder.base import TransactionReceiptDict
 from dexorder.base.chain import current_chain
 from dexorder.base.order import TransactionRequest
@@ -14,6 +15,7 @@ from dexorder.blockstate.diff import DiffEntryItem
 from dexorder.blockstate.fork import current_fork, Fork
 from dexorder.contract.contract_proxy import ContractTransaction
 from dexorder.database.model.transaction import TransactionJob, TransactionJobState
+from dexorder.util.shutdown import fatal

 log = logging.getLogger(__name__)
@@ -26,6 +28,7 @@ class TransactionHandler:
         return TransactionHandler.instances[tag]

     def __init__(self, tag):
+        self.tag = tag
         TransactionHandler.instances[tag] = self

     @abstractmethod
@@ -44,50 +47,45 @@ def submit_transaction_request(tr: TransactionRequest):
 async def create_and_send_transactions():
     """ called by the Runner after the events have all been processed and the db committed """
-    await create_transactions()
-    await send_transactions()
-
-
-async def create_transactions():
     for job in db.session.query(TransactionJob).filter(
             TransactionJob.chain == current_chain.get(),
             TransactionJob.state == TransactionJobState.Requested
     ):
-        await create_transaction(job)
-
-
-async def create_transaction(job: TransactionJob):
-    log.info(f'building transaction request for {job.request.__class__.__name__} {job.id}')
-    try:
-        handler = TransactionHandler.of(job.request.type)
-    except KeyError:
-        # todo remove bad request?
-        log.warning(f'ignoring transaction request with bad type "{job.request.type}": {",".join(TransactionHandler.instances.keys())}')
-    else:
-        ctx: ContractTransaction = await handler.build_transaction(job.id, job.request)
-        if ctx is None:
-            log.warning(f'unable to send transaction for job {job.id}')
-            return
-        job.state = TransactionJobState.Signed # todo lazy signing
-        job.tx_id = ctx.id_bytes
-        job.tx_data = ctx.data
-        db.session.add(job)
-        log.info(f'servicing transaction request {job.request.__class__.__name__} {job.id} with tx {ctx.id}')
-
-
-# todo sign-and-send should be a single phase. if the send fails due to lack of wallet gas, or because gas price went up suddenly,
-# we need to re-sign a new message with updated gas. so do not store signed messages but keep the unsigned state around until it
-# is signed and sent
-async def send_transactions():
-    w3 = current_w3.get()
-    for job in db.session.query(TransactionJob).filter(
-            TransactionJob.chain == current_chain.get(),
-            TransactionJob.state == TransactionJobState.Signed
-    ):
-        log.debug(f'sending transaction for job {job.id}')
-        sent = await w3.eth.send_raw_transaction(job.tx_data)
-        assert sent == job.tx_id
-        job.state = TransactionJobState.Sent
-        db.session.add(job)
+        log.info(f'building transaction request for {job.request.__class__.__name__} {job.id}')
+        try:
+            handler = TransactionHandler.of(job.request.type)
+        except KeyError:
+            # todo remove bad request?
+            log.warning('ignoring transaction request with bad type '
+                        f'"{job.request.type}": ' + ",".join(TransactionHandler.instances.keys()))
+        else:
+            ctx: ContractTransaction = await handler.build_transaction(job.id, job.request)
+            if ctx is None:
+                log.warning(f'unable to send transaction for job {job.id}')
+                return
+            w3 = current_w3.get()
+            account = Account.get_named(handler.tag)
+            if account is None:
+                account = Account.get()
+            if account is None:
+                log.error(f'No account available for transaction request type "{handler.tag}"')
+                continue
+            await ctx.sign(account)
+            job.state = TransactionJobState.Signed
+            job.tx_id = ctx.id_bytes
+            job.tx_data = ctx.data
+            db.session.add(job)
+            log.info(f'servicing transaction request {job.request.__class__.__name__} {job.id} with tx {ctx.id}')
+            try:
+                sent = await w3.eth.send_raw_transaction(job.tx_data)
+            except:
+                log.exception(f'Failure sending transaction for job {job.id}')
+                # todo pager
+                # todo send state unknown!
+            else:
+                assert sent == job.tx_id
+                job.state = TransactionJobState.Sent
+                db.session.add(job)


 async def handle_transaction_receipts():
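
Signing now happens inside this loop, and the handler's tag doubles as the account name: Account.get_named(handler.tag) is tried first, then the default Account.get(). A sketch of a handler registering under a tag; the class name, the 'vault_execution' tag, and the trivial body are hypothetical, and build_transaction is assumed to be the abstract hook implied by the handler.build_transaction(job.id, job.request) call above:

class VaultExecutionHandler (TransactionHandler):
    def __init__(self):
        super().__init__('vault_execution')    # stored as self.tag and registered in TransactionHandler.instances

    async def build_transaction(self, job_id, request: TransactionRequest) -> Optional[ContractTransaction]:
        # Return an unsigned ContractTransaction for the request; returning None makes
        # create_and_send_transactions() log a warning and skip the job.
        return None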

View File

@@ -2,7 +2,7 @@ import asyncio
 import logging
 from abc import abstractmethod
 from asyncio import Event
-from typing import TypeVar, Generic, Awaitable, Callable, Optional
+from typing import TypeVar, Generic, Awaitable, Callable, Optional, Any

 from dexorder import NARG
@@ -12,46 +12,47 @@ K = TypeVar('K')
 V = TypeVar('V')


-class _Query (Generic[V]):
-    def __init__ (self):
-        self.event = Event()
-        self.result: V = NARG
-        self.exception: Optional[Exception] = None
-
-    def __bool__(self):
-        return self.result is not NARG
+###
+### NOT TESTED AND NOT USED
+###


 class AsyncDict (Generic[K,V]):
     """
     Implements per-key locks around accessing dictionary values.
     Either supply fetch and store functions in the constructor, or override those methods in a subclass.
-    fetch(key,default) takes two arguments and when a key is missing, it may either return the default value explicitly
-    or raise KeyError, in which case the call wrapper will return the default value.
     """
     def __init__(self,
                  fetch: Callable[[K,V], Awaitable[V]] = None,
-                 store: Callable[[K,V], Awaitable[V]] = None,
+                 store: Callable[[K,V], Awaitable[Any]] = None,
                  ):
-        self._queries: dict[K,_Query[V]] = {}
+        self._queries: dict[K, tuple[bool,Awaitable]] = {}  # bool indicates if it's a write (True) or a read (False)
         if fetch is not None:
             self.fetch = fetch
         if store is not None:
             self.store = store

     async def get(self, key: K, default: V = NARG) -> V:
-        query = self._queries.get(key)
-        if query is None:
-            return await self._query(key, self.fetch(key, default))
-        else:
-            await query.event.wait()
-            if query.exception is not None:
-                raise query.exception
-            return query.result
+        found = self._queries.get(key)
+        if found is not None:
+            write, query = found
+            result = await query
+            if not write:
+                return result
+        # either there was no query or it was a write query that's over
+        query = self.fetch(key, default)
+        self._queries[key] = False, query
+        return await query

     async def set(self, key: K, value: V):
-        query = self._queries.get(K)
-        if query is not None:
-            await query.event.wait()
-        await self._query(key, self.store(key, value))
+        found = self._queries.get(key)
+        if found is not None:
+            write, query = found
+            await query
+        query = self.store(key, value)
+        self._queries[key] = True, query
+        await query

     # noinspection PyMethodMayBeStatic,PyUnusedLocal
     @abstractmethod
@@ -65,16 +66,3 @@ class AsyncDict (Generic[K,V]):
         Must return the value that was just set.
         """
         raise NotImplementedError
-
-    async def _query(self, key: K, coro: Awaitable[V]) -> V:
-        assert key not in self._queries
-        query = _Query()
-        self._queries[key] = query
-        try:
-            query.result = await coro
-        except Exception as e:
-            query.exception = e
-        finally:
-            del self._queries[key]
-            query.event.set()
-        return query.result