# backend/src/dexorder/runner.py
import asyncio
import logging
from asyncio import Queue
from typing import Any, Iterable, Callable
from web3.exceptions import LogTopicError, MismatchedABI
# noinspection PyPackageRequirements
from websockets.exceptions import ConnectionClosedError
from dexorder import Blockchain, db, current_pub, async_yield, current_w3, config, NARG
from dexorder.base.chain import current_chain, current_clock, BlockClock
from dexorder.base.fork import current_fork, Fork, DisjointFork
from dexorder.blockchain.connection import create_w3_ws, create_w3
from dexorder.blockstate import BlockState, current_blockstate
from dexorder.blockstate.diff import DiffEntryItem
from dexorder.database.model import Block
from dexorder.database.model.block import current_block, latest_block
from dexorder.progressor import BlockProgressor
from dexorder.util import hexstr
from dexorder.util.async_util import maywait, Maywaitable
from dexorder.util.shutdown import fatal

log = logging.getLogger(__name__)


class Retry(Exception): ...


# todo: detect reorgs and generate the correct onHeadUpdate set by unioning the changes along the two forks, excluding their common ancestor's deltas
class BlockStateRunner(BlockProgressor):
"""
NOTE: This doc is old and not strictly true but still has the basic idea
1. load root stateBlockchain
a. if no root, init from head
b. if root is old, batch forward by height
2. discover new heads
2b. find in-state parent block else use root
3. set the current fork = ancestor->head diff state
4. query blockchain eventlogs
5. process new vaults
6. process new orders and cancels
a. new pools
7. process Swap events and generate pool prices
8. process price horizons
9. process token movement
10. process swap triggers (zero constraint tranches)
11. process price tranche triggers
12. process horizon tranche triggers
13. filter by time tranche triggers
14. bundle execution requests and send tx. tx has require(block<deadline) todo execute deadlines
15. on tx confirmation, the block height of all executed trigger requests is set to the tx block
Most of these steps, the ones handling events, are set up in main.py so that datamain.py can also use Runner for its own purposes
"""
    def __init__(self, state: BlockState | None = None, *, publish_all=None, timer_period: float = 1):
        """
        If state is None, it is initialized empty using the first block seen as the root block; log event
        handling then begins with the second block.
        """
super().__init__()
self.state = state
# onStateInit callbacks are invoked after the initial state is loaded or created
        self.on_state_init: list[Callable[[], Maywaitable[None]]] = []
self.state_initialized = False
# onHeadUpdate callbacks are invoked with a list of DiffItems used to update the head state from either the previous head or the root
        self.on_head_update: list[Callable[[Block, list[DiffEntryItem]], Maywaitable[None]]] = []
# onPromotion callbacks are invoked with a list of DiffItems used to advance the root state
        self.on_promotion: list[Callable[[Block, list[DiffEntryItem]], Maywaitable[None]]] = []
        self.publish_all: Callable[[Iterable[tuple[str, str, Any]]], Maywaitable[None]] = publish_all
self.timer_period = timer_period
self.queue: Queue = Queue()
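        # highest block height enqueued so far; seeded from config.backfill so backfilling starts above that height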
self.max_height_seen = config.backfill
self.running = False

    async def run(self):
        # this run() process discovers new heads and puts them on a queue for the worker to process. discovery is either websockets or polling
if self.state:
self.max_height_seen = max(self.max_height_seen, self.state.root_block.height)
self.running = True
return await (self.run_polling() if config.polling > 0 or not config.ws_url else self.run_ws())

    async def run_ws(self):
w3ws = await create_w3_ws()
chain_id = await w3ws.eth.chain_id
chain = Blockchain.for_id(chain_id)
current_chain.set(chain)
# this run() process discovers new heads and puts them on a queue for the worker to process
_worker_task = asyncio.create_task(self.worker())
while self.running:
try:
async with w3ws as w3ws:
log.debug('connecting to ws provider')
await w3ws.provider.connect()
subscription = await w3ws.eth.subscribe('newHeads') # the return value of this call is not consistent between anvil/hardhat/rpc. do not use it.
log.debug(f'subscribed to newHeads {subscription}')
while self.running:
async for message in w3ws.ws.process_subscriptions():
head = message['result']
log.debug(f'detected new block {head["number"]} {hexstr(head["hash"])}')
await self.add_head(head["hash"])
if not self.running:
break
await async_yield()
except (ConnectionClosedError, TimeoutError, asyncio.TimeoutError) as e:
log.debug(f'runner timeout {e}')
finally:
# noinspection PyBroadException
try:
# noinspection PyUnresolvedReferences
await w3ws.provider.disconnect()
except Exception:
pass
log.debug('yield')
log.debug('runner run_ws() exiting')

    async def run_polling(self):
"""
Hardhat websocket stops sending messages after about 5 minutes.
https://github.com/NomicFoundation/hardhat/issues/2053
So we implement polling as a workaround.
"""
w3 = await create_w3()
chain_id = await w3.eth.chain_id
chain = Blockchain.for_id(chain_id)
current_chain.set(chain)
_worker_task = asyncio.create_task(self.worker())
prev_blockhash = None
while self.running:
try:
# polling mode is used primarily because Hardhat fails to deliver newHeads events after about an hour
# unfortunately, hardhat also stops responding to eth_getBlockByHash. so instead, we use the standard (stupid)
# 'latest' polling for blocks, and we push the entire block to the queue since apparently this is the only
# rpc call Hardhat seems to consistently support. The worker must then detect the type of object pushed to the
# work queue and either use the block directly or query for the block if the queue object is a hashcode.
block = await w3.eth.get_block('latest')
head = block['hash']
if head != prev_blockhash:
prev_blockhash = head
log.debug(f'polled new block {hexstr(head)}')
await self.add_head(block)
if not self.running:
break
await asyncio.sleep(config.polling)
except (ConnectionClosedError, TimeoutError, asyncio.TimeoutError) as e:
log.debug(f'runner timeout {e}')
finally:
# noinspection PyBroadException
try:
# noinspection PyUnresolvedReferences
await w3.provider.disconnect()
except Exception:
pass
await async_yield()
log.debug('runner run_polling() exiting')

    async def add_head(self, head):
        """
        head can be either a full block-data struct or simply a block hash. this method converts it to a Block
        and pushes that Block onto the worker queue.
        """
chain = current_chain.get()
w3 = current_w3.get()
try:
block_data = head
blockhash = block_data['hash']
parent = block_data['parentHash']
height = block_data['number']
except TypeError:
blockhash = head
response = await w3.provider.make_request('eth_getBlockByHash', [blockhash, False])
            block_data: dict = response['result']
parent = bytes.fromhex(block_data['parentHash'][2:])
height = int(block_data['number'], 0)
head = Block(chain=chain.chain_id, height=height, hash=blockhash, parent=parent, data=block_data)
latest_block.set(head)
if self.state or config.backfill:
# backfill batches
start_height = self.max_height_seen
batch_size = config.batch_size if config.batch_size is not None else chain.batch_size
batch_height = start_height + batch_size - 1
while batch_height < head.height:
# the backfill is larger than a single batch, so we push intermediate head blocks onto the queue
response = await w3.provider.make_request('eth_getBlockByNumber', [hex(batch_height), False])
block_data: dict = response['result']
blockhash = bytes.fromhex(block_data['hash'][2:])
parent = bytes.fromhex(block_data['parentHash'][2:])
height = int(block_data['number'], 0)
assert height == batch_height
block = Block(chain=chain.chain_id, height=height, hash=blockhash, parent=parent, data=block_data)
log.debug(f'enqueueing batch backfill from {start_height} through {batch_height}')
await self.queue.put(block) # add an intermediate block
self.max_height_seen = height
                start_height += batch_size
                batch_height += batch_size
if self.queue.qsize() > 2:
await asyncio.sleep(1)
else:
await async_yield()
await self.queue.put(head) # add the head block
self.max_height_seen = head.height

    async def worker(self):
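        # consume Blocks from the queue; if timer_period elapses with no new head, fall through to the time-tick path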
try:
            log.debug('runner worker started')
w3 = current_w3.get()
chain = current_chain.get()
assert chain.chain_id == await w3.eth.chain_id
current_clock.set(BlockClock())
prev_head = None
while self.running:
try:
if self.timer_period:
async with asyncio.timeout(self.timer_period):
head = await self.queue.get()
else:
head = await self.queue.get()
                except TimeoutError:
                    # timer_period elapsed without a new head; run the postprocess callbacks to check for newly activated time-based triggers
if prev_head is not None:
await self.handle_time_tick(prev_head)
else:
try:
await self.handle_head(chain, head, w3)
prev_head = head
except Retry:
pass
except Exception as x:
log.exception(x)
except Exception:
log.exception('exception in runner worker')
raise
finally:
log.debug('runner worker exiting')

    async def handle_head(self, chain, block, w3):
log.debug(f'handle_head {block.height} {hexstr(block.hash)}')
session = None
batches = []
try:
if self.state is not None and block.hash in self.state.by_hash:
log.debug(f'block {block.hash} was already processed')
return
            if self.state is None:
                # no prior state: the first block seen becomes the root of a brand-new empty state
                self.state = BlockState(block)
current_blockstate.set(self.state)
fork: Fork = Fork([block.hash], height=block.height)
log.info('Created new empty root state')
else:
fork = self.state.add_block(block)
if fork is None:
log.debug(f'discarded late-arriving head {block}')
else:
from_height = self.state.by_hash[fork.parent].height if fork.parent is not None else fork.height
to_height = fork.height
if fork.disjoint:
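                        # a disjoint fork skips over intermediate blocks, so collect event logs for the whole
                        # from_height..to_height range in batches rather than querying by this block's hash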
batches = await self.get_backfill_batches(from_height, to_height, w3)
else:
# event callbacks are triggered in the order in which they're registered. the events passed to
# each callback are in block transaction order
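                        # each batches entry is (get_logs result or awaitable, callback, event, log_filter);
                        # entries with no logs query use (None, callback, event-or-None, None)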
for callback, event, log_filter in self.events:
if log_filter is None:
batches.append((None, callback, event, None))
else:
# todo use head['logsBloom'] to skip unnecessary log queries
lf = dict(log_filter)
lf['blockHash'] = hexstr(block.hash)
get_logs = w3.eth.get_logs(lf)
if not config.parallel_logevent_queries:
get_logs = await get_logs
batches.append((get_logs, callback, event, log_filter))
for callback in self.postprocess_cbs:
batches.append((None, callback, None, None))
# set up for callbacks
current_block.set(block)
current_fork.set(fork)
session = db.session
session.begin()
session.add(block)
pubs = []
current_pub.set(lambda room, evnt, *args: pubs.append((room, evnt, args))) # used by handle_vault_created
if not self.state_initialized:
await self.do_state_init_cbs()
await self.invoke_callbacks(batches)
# todo
# IMPORTANT! check for a reorg and generate a reorg diff list. the diff list we need is the union of the set of keys touched by either
# branch. Then we query all the values for those keys and apply that kv list to redis. This will make sure that any orphaned data that
# isn't updated by the new fork is still queried from the root state to overwrite any stale data from the abandoned branch.
diff_items = self.state.diffs_by_hash[block.hash]
for callback in self.on_head_update:
# noinspection PyCallingNonCallable
await maywait(callback(block, diff_items))
# check for root promotion
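            # with N confirms required, a block at height h is promotable once the latest head reaches
            # height h + N - 1 (the block itself counts as its first confirmation)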
confirm_offset = (config.confirms if config.confirms is not None else chain.confirms) - 1
promotion_height = latest_block.get().height - confirm_offset
new_root_fork = None
if fork.disjoint:
fork: DisjointFork
# individually check the fork's head and ancestor
if fork.height <= promotion_height:
new_root_fork = fork
else:
state = current_blockstate.get()
parent_block = fork.root
if parent_block.height <= promotion_height:
new_root_fork = state.fork(parent_block)
else:
fork: Fork
# non-disjoint, contiguous fork
if fork.height <= promotion_height:
new_root_fork = fork
else:
new_root_fork = fork.for_height(promotion_height)
if new_root_fork:
log.debug(f'promoting root {new_root_fork.height} {hexstr(new_root_fork.hash)}')
diff_items = self.state.promote_root(new_root_fork)
for callback in self.on_promotion:
# todo try/except for known retryable errors
# noinspection PyCallingNonCallable
await maywait(callback(self.state.root_block, diff_items))
# publish messages
if pubs and self.publish_all:
# noinspection PyCallingNonCallable
await maywait(self.publish_all(pubs))
except: # legitimately catch EVERYTHING because we re-raise
log.debug('rolling back session')
if session is not None:
session.rollback()
if block.hash is not None and self.state is not None:
self.state.delete_block(block.hash)
if config.parallel_logevent_queries:
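                # drain any in-flight get_logs awaitables so they don't linger as un-awaited tasks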
for get_logs, *_ in batches:
if get_logs is not None:
# noinspection PyBroadException
try:
await get_logs
except Exception:
log.exception('exception while querying logs')
raise
else:
if session is not None:
session.commit()
log.info(f'completed block {block}')
finally:
if session is not None:
session.close()

    async def handle_time_tick(self, block):
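        # no new head arrived within timer_period: re-run the postprocess callbacks against the previous head
        # so that time-based triggers can fire between blocks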
if current_blockstate.get() is None:
return
fork = self.state.fork(block)
current_block.set(block)
current_fork.set(fork)
session = db.session
session.begin()
try:
for callback in self.postprocess_cbs:
# noinspection PyCallingNonCallable
await maywait(callback())
except:
session.rollback()
raise
else:
session.commit()
finally:
if session is not None:
session.close()

    async def do_state_init_cbs(self):
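        # run the on_state_init callbacks exactly once, right after the initial state is loaded or created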
if self.state_initialized:
return
for cb in self.on_state_init:
# noinspection PyCallingNonCallable
await maywait(cb())
self.state_initialized = True