reworked blockdict and triggerrunner

This commit is contained in:
Tim Olson
2023-09-28 18:29:39 -04:00
parent 2b72decf7b
commit f22841e93e
7 changed files with 280 additions and 210 deletions

View File

@@ -1,3 +1,5 @@
from decimal import Decimal as dec
# NARG is used in argument defaults to mean "not specified" rather than "specified as None"
class _NARG:
def __bool__(self): return False
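# For illustration (editorial note): a default of NARG lets a caller distinguish "no default supplied"
# from an explicit default of None, e.g. BlockState.get in this commit does
#     if default is NARG: raise KeyError((series, key))
#     return default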

View File

@@ -3,12 +3,14 @@ import logging
from collections import defaultdict
from contextvars import ContextVar
from dataclasses import dataclass
from typing import Union, TypeVar, Generic, Any, Optional
from enum import Enum
from typing import Union, TypeVar, Generic, Any, Optional, Iterable
from sortedcontainers import SortedList
from dexorder import NARG
from dexorder.database.model.block import Block
from dexorder.util import hexstr
log = logging.getLogger(__name__)
@@ -30,44 +32,62 @@ class DiffItem:
return f'{self.entry.hash} {self.series}.{self.key}={"[DEL]" if self.entry.value is BlockState.DELETE else self.entry.value}'
class Fork(list[bytes]):
class Fork:
"""
A Fork is an ancestor path, stored as block hashes in reverse-chronological order: the "current" block's hash comes first, followed by its ancestors.
Membership tests (`block in fork`) are resolved by height offset, i.e. a block is contained iff ancestry[fork.height - block.height] equals the block's hash.
"""
@staticmethod
def cur() -> Optional['Fork']:
return _fork.get()
@staticmethod
def set_cur(value: Optional['Fork']):
_fork.set(value)
def __init__(self, ancestry, *, height: int):
super().__init__(ancestry)
def __init__(self, ancestry: Iterable[bytes], *, height: int):
self.ancestry = list(ancestry)
self.height = height
self.disjoint = False
def __getitem__(self, height):
def __contains__(self, item):
index = self.height - item.height
if index < 0:
return False
try:
return super().__getitem__(self.height - height if height >= 0 else -height + 1)
return self.ancestry[index] == item.hash
except IndexError:
return None
return False
@property
def hash(self):
return super().__getitem__(0)
return self.ancestry[0]
@property
def parent(self):
return super().__getitem__(1)
return self.ancestry[1]
def for_child(self, blockhash: bytes):
return Fork([blockhash] + self.ancestry, height=self.height + 1)
def for_height(self, height):
""" returns a new Fork object for an older block along this fork. used for root promotion. """
assert self.height - len(self.ancestry) < height <= self.height
return Fork(self.ancestry[self.height-height:], height=height)
def __str__(self):
return f'{self.height}_[{"->".join(h.hex() for h in self)}]'
return f'{self.height}_[{"->".join(h.hex() for h in self.ancestry)}]'
current_fork = ContextVar[Optional[Fork]]('current_fork', default=None)
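# Illustrative usage sketch (editorial note, not part of this commit; h13/h12/h11 and block12 are hypothetical).
# Ancestry is stored newest-first, so ancestry[0] is the fork's own hash:
#     f = Fork([h13, h12, h11], height=13)
#     f.hash == h13 and f.parent == h12          # True
#     block12 in f                               # True iff ancestry[f.height - block12.height] == block12.hash
#     f.for_height(12)                           # Fork([h12, h11], height=12); used during root promotion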
class DisjointFork:
"""
A duck-typed stand-in for Fork, used when a block attaches directly to the root but with a gap of missing parent blocks in between.
"""
def __init__(self, block: Block, root: Block):
self.height = block.height
self.hash = block.hash
self.parent = root.hash
self.disjoint = True
def __contains__(self, item):
return item.hash in (self.hash, self.parent)
def __str__(self):
return f'{self.height}_[{self.hash.hex()}->{self.parent.hex()}]'
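# Sketch of the duck-typed contract (editorial illustration): a DisjointFork answers the same
# height/hash/parent/disjoint/`in` questions as Fork even though the intermediate blocks are unknown:
#     df = DisjointFork(block, root)             # block.parent is not root.hash; ancestors are missing
#     block in df and root in df                 # True; any other block is False
#     df.disjoint                                # True, which TriggerRunner uses to branch into the (not yet implemented) backfill path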
class BlockState:
@@ -83,14 +103,6 @@ class BlockState:
by applying any diffs along the block's fork path to the root data.
"""
@staticmethod
def cur() -> 'BlockState':
return _blockstate.get()
@staticmethod
def set_cur(value: 'BlockState'):
_blockstate.set(value)
def __init__(self, root_block: Block):
self.root_block: Block = root_block
self.by_height: SortedList[Block] = SortedList(key=lambda x: x.height)
@@ -102,7 +114,7 @@ class BlockState:
self.ancestors: dict[bytes, Block] = {}
BlockState.by_chain[root_block.chain] = self
def add_block(self, block: Block) -> Union[int, Fork, None]:
def add_block(self, block: Block) -> Optional[Fork]:
"""
If block is at the root block's height or older, it is ignored and None is returned. Otherwise, a Fork leading back to the root is returned.
The ancestor block is set in the ancestors dictionary, and any state updates to block are considered to have occurred between the registered ancestor
@@ -115,63 +127,96 @@ class BlockState:
return None
if block.hash not in self.by_hash:
self.by_hash[block.hash] = block
parent = self.by_hash.get(block.parent)
self.ancestors[block.hash] = parent or self.root_block
self.by_height.add(block)
log.debug(f'new block state {block}')
parent = self.by_hash.get(block.parent)
if parent is None:
self.ancestors[block.hash] = self.root_block
return Fork(block.hash, height=block.height)
else:
self.ancestors[block.hash] = parent
return self.fork(block)
def ancestors():
b = block
while b is not self.root_block:
yield b.hash
b = self.ancestors[b.hash]
return Fork(ancestors(), height=block.height)
def delete_block(self, block: Union[Block,Fork,bytes]):
""" if there was an error during block processing, we need to remove the incomplete block data """
try:
block = block.hash
except AttributeError:
pass
try:
del self.by_hash[block]
except KeyError:
pass
try:
del self.diffs_by_hash[block]
except KeyError:
pass
try:
del self.ancestors[block]
except KeyError:
pass
def fork(self, block: Block):
if block.height - self.ancestors[block.hash].height > 1:
# noinspection PyTypeChecker
return DisjointFork(block, self.root_block)
def ancestors():
bh = block.hash
while True:
yield bh
if bh == self.root_block.hash:
return
bh = self.ancestors[bh].hash
return Fork(ancestors(), height=block.height)
def get(self, fork: Optional[Fork], series, key, default=NARG):
series_diffs = self.diffs_by_series.get(series)
if series_diffs is None:
if default is NARG:
raise ValueError('series')
raise KeyError((series,key))
else:
return default
diffs: list[DiffEntry] = series_diffs.get(key, [])
for diff in reversed(diffs):
if diff.height <= self.root_block.height or fork is not None and fork[diff.height] == diff.hash:
if diff.value is BlockState.DELETE:
break
else:
if fork[self.root_block.height] != self.root_block.hash: # todo move this assertion elsewhere so it runs once per task
raise RuntimeError
return diff.value
value = self._get_from_diffs(fork, diffs)
if value is not BlockState.DELETE:
return value
# value not found or was DELETE
if default is NARG:
raise KeyError((series, key))
return default
def set(self, fork: Optional[Fork], series, key, value):
diff = DiffEntry(value,
fork.height if fork is not None else self.root_block.height,
fork.hash if fork is not None else self.root_block.hash)
if fork is not None:
self.diffs_by_hash[fork.hash].append(DiffItem(series, key, diff))
self.diffs_by_series[series][key].add(diff)
def _get_from_diffs(self, fork, diffs):
for diff in reversed(diffs):
if diff.height <= self.root_block.height or fork is not None and diff in fork:
if diff.value is BlockState.DELETE:
break
else:
if self.root_block not in fork: # todo move this assertion elsewhere so it runs once per task
raise ValueError(f'Cannot get value for a non-root fork {hexstr(fork.hash)}')
return diff.value
return BlockState.DELETE
def set(self, fork: Optional[Fork], series, key, value, overwrite=True):
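# overwrite=False records a diff only when the effective value along this fork would actually change;
# BlockSet.add/__delitem__ pass overwrite=False so repeated adds/removes do not pile up redundant DiffEntries.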
diffs = self.diffs_by_series[series][key]
if overwrite or self._get_from_diffs(fork, diffs) != value:
diff = DiffEntry(value,
fork.height if fork is not None else self.root_block.height,
fork.hash if fork is not None else self.root_block.hash)
if fork is not None:
self.diffs_by_hash[fork.hash].append(DiffItem(series, key, diff))
diffs.add(diff)
def iteritems(self, fork: Optional[Fork], series):
for k, difflist in self.diffs_by_series.get(series, {}).items():
for diff in reversed(difflist):
if diff.height <= self.root_block.height or fork is not None and fork[diff.height] == diff.hash:
if diff.height <= self.root_block.height or fork is not None and diff in fork:
if diff.value is not BlockState.DELETE:
yield k, diff.value
break
def promote_root(self, fork: Fork):
assert all(block in self.by_hash for block in fork)
block = self.by_hash[fork.hash]
def promote_root(self, new_root_fork: Fork):
block = self.by_hash[new_root_fork.hash]
diffs = self.collect_diffs(block)
# no application of diffs to the internal state is required, just clean up
@@ -198,7 +243,7 @@ class BlockState:
# remove old diffs on abandoned forks but keep old diffs on the root fork
removals = None
for d in difflist:
if d.height <= fork.height and d.hash != fork[d.height]:
if d.height <= new_root_fork.height and d not in new_root_fork:
if removals is None:
removals = [d]
else:
@@ -207,14 +252,15 @@ class BlockState:
for r in removals:
difflist.remove(r)
# while the second-oldest diff is still root-age, pop off the oldest diff
while len(difflist) >= 2 and difflist[1].height <= fork.height:
while len(difflist) >= 2 and difflist[1].height <= new_root_fork.height:
difflist.pop(0)
# if only one diff remains, and it's old, and it's a delete, then we can actually delete the diff list
if not difflist or len(difflist) == 1 and difflist[0].value == BlockState.DELETE and difflist[0].height <= fork.height:
if not difflist or len(difflist) == 1 and difflist[0].value == BlockState.DELETE and difflist[0].height <= new_root_fork.height:
del self.diffs_by_series[s][k]
del self.by_hash[self.root_block.hash] # old root block
self.root_block = block
log.debug(f'promoted root {self.root_block}')
return diffs
def collect_diffs(self, block: Block, series_key=NARG) -> list[DiffItem]:
@@ -244,70 +290,96 @@ class BlockState:
return result
_blockstate = ContextVar[BlockState]('BlockState.cur')
_fork = ContextVar[Optional[Fork]]('fork.cur', default=None)
current_blockstate = ContextVar[BlockState]('current_blockstate')
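# Minimal end-to-end sketch (editorial illustration; root_block, new_block and h are hypothetical):
#     state = BlockState(root_block)
#     current_blockstate.set(state)
#     fork = state.add_block(new_block)              # Fork (or DisjointFork) back to root, or None if too old
#     state.set(fork, 'series', 'key', 'value')      # records a DiffEntry visible along this fork
#     state.get(fork, 'series', 'key')               # walks diffs newest-first, honoring the fork path
#     state.promote_root(fork.for_height(h))         # folds sufficiently old diffs into the root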
T = TypeVar('T')
class BlockDict(Generic[T]):
class BlockData:
class Type(Enum):
SCALAR = 0
SET = 1
LIST = 2
DICT = 3
def __init__(self, series_key):
self.series_key = series_key
registry: dict[str,'BlockData'] = {} # series name and instance
def __setitem__(self, item, value):
BlockDict.setitem(self.series_key, item, value)
def __init__(self, series:str, data_type: Type):
assert series not in BlockData.registry
BlockData.registry[series] = self
self.series = series
self.type = data_type
def __getitem__(self, item):
return BlockDict.getitem(self.series_key, item)
def setitem(self, item, value, overwrite=True):
state = current_blockstate.get()
fork = current_fork.get()
state.set(fork, self.series, item, value, overwrite)
def __delitem__(self, item):
BlockDict.delitem(self.series_key, item)
def getitem(self, item, default=NARG):
state = current_blockstate.get()
fork = current_fork.get()
return state.get(fork, self.series, item, default)
def __contains__(self, item):
return BlockDict.contains(self.series_key, item)
def delitem(self, item, overwrite=True):
self.setitem(item, BlockState.DELETE, overwrite)
def add(self, item):
""" set-like semantics. the item key is added with a value of None. """
BlockDict.setitem(self.series_key, item, None)
def items(self):
return BlockDict.iter_items(self.series_key)
def get(self, item, default=None):
return BlockDict.getitem(self.series_key, item, default)
@staticmethod
def setitem(series_key, item, value):
state = BlockState.cur()
fork = Fork.cur()
state.set(fork, series_key, item, value)
@staticmethod
def getitem(series_key, item, default=NARG):
state = BlockState.cur()
fork = Fork.cur()
return state.get(fork, series_key, item, default)
@staticmethod
def delitem(series_key, item):
BlockDict.setitem(series_key, item, BlockState.DELETE)
@staticmethod
def contains(series_key, item):
def contains(self, item):
try:
BlockDict.getitem(series_key, item)
self.getitem(item)
return True
except KeyError: # getitem with no default will raise on a missing item
return False
@staticmethod
def iter_items(series_key):
state = BlockState.cur()
fork = Fork.cur()
state = current_blockstate.get()
fork = current_fork.get()
return state.iteritems(fork, series_key)
class BlockSet(Generic[T], Iterable[T], BlockData):
def __init__(self, series: str):
super().__init__(series, BlockData.Type.SET)
self.series = series
def add(self, item):
""" set-like semantics. the item key is added with a value of None. """
self.setitem(item, None, overwrite=False)
def __delitem__(self, item):
self.delitem(item, overwrite=False)
def __contains__(self, item):
return self.contains(item)
def __iter__(self):
yield from (k for k,v in self.iter_items(self.series))
class BlockDict(Generic[T], BlockData):
def __init__(self, series: str):
super().__init__(series, BlockData.Type.DICT)
def __setitem__(self, item, value):
self.setitem(item, value)
def __getitem__(self, item):
return self.getitem(item)
def __delitem__(self, item):
self.delitem(item)
def __contains__(self, item):
return self.contains(item)
def items(self):
return self.iter_items(self.series)
def get(self, item, default=None):
return self.getitem(item, default)
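# Usage sketch for the new wrappers (editorial illustration; series names and variables here are made up):
#     seen = BlockSet('seen')                        # set-like; members are stored with a value of None
#     prices = BlockDict[int]('prices')              # dict-like; full get/set/del/items
#     current_fork.set(fork)                         # both classes read current_blockstate and current_fork
#     seen.add(addr); addr in seen                   # True on this fork and on any fork containing this block
#     prices[pool] = 123; prices.get(pool)           # 123 here; the default on unrelated forks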
def _test():
def B(height, hash:str, parent):
@@ -315,7 +387,7 @@ def _test():
root_block = B(10, '#root', None )
state = BlockState(root_block)
BlockState.set_cur(state)
current_blockstate.set(state)
b11 = B(11, '#b11', parent=root_block)
f11: Fork = state.add_block(b11)
print('f11',f11)
@@ -325,34 +397,32 @@ def _test():
b12 = B(12, '#b12', parent=b11)
f12: Fork = state.add_block(b12)
print('f12',f12)
b13 = B(13, '#b13', parent=b12)
f13: Fork = state.add_block(b13)
d = BlockDict('ser')
def dump():
print()
print(Fork.cur().hash if Fork.cur() is not None else 'root')
print(current_fork.get().hash if current_fork.get() is not None else 'root')
for k,v in d.items():
print(f'{k} = {v}')
Fork.set_cur(None) # Use None to set values on root
current_fork.set(None) # Use None to set values on root
d['foo'] = 'bar'
d['test'] = 'failed'
Fork.set_cur(f11)
current_fork.set(f11)
d['foo2'] = 'bar2'
del d['test']
Fork.set_cur(f11b)
current_fork.set(f11b)
del d['foo2']
d['foob'] = 'barb'
Fork.set_cur(f12)
current_fork.set(f12)
d['test'] = 'ok'
for f in (None, f11, f11b, f12):
Fork.set_cur(f)
current_fork.set(f)
dump()
print()
@@ -363,13 +433,13 @@ def _test():
print()
print('promoting b11')
state.promote_root(f11)
Fork.set_cur(f12)
current_fork.set(f12)
dump()
print()
print('promoting b12')
state.promote_root(f12)
Fork.set_cur(f12)
current_fork.set(f12)
dump()

View File

@@ -1,5 +1,4 @@
import logging
from asyncio import CancelledError
from dexorder.bin.executable import execute
from dexorder.trigger_runner import TriggerRunner

8
src/dexorder/data.py Normal file
View File

@@ -0,0 +1,8 @@
from dexorder.base.blockstate import BlockSet, BlockDict
vault_addresses = BlockSet('v')
vault_tokens = BlockDict('vt')
underfunded_vaults = BlockSet('uv')
active_orders = BlockSet('a')
pool_prices = BlockDict('p')
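# These module-level series are consumed elsewhere in this commit, e.g. TriggerRunner.handle_swap
# assigns pool_prices[addr] = price and handle_transfer checks `to_address in vault_addresses`.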

View File

@@ -1,3 +1,4 @@
import logging
from contextvars import ContextVar
import sqlalchemy
@@ -7,6 +8,7 @@ from sqlalchemy.orm import Session, SessionTransaction
from .migrate import migrate_database
from .. import config
log = logging.getLogger(__name__)
_engine = ContextVar[Engine]('engine', default=None)
_session = ContextVar[Session]('session', default=None)
@@ -59,7 +61,7 @@ class Db:
connection.execute(sqlalchemy.text("SET TIME ZONE 'UTC'"))
result = connection.execute(sqlalchemy.text("select version_num from alembic_version"))
for row in result:
print(f'database revision {row[0]}')
log.info(f'database revision {row[0]}')
_engine.set(engine)
return db
raise Exception('database version not found')

View File

@@ -17,31 +17,7 @@ class Block(Base):
def __str__(self):
return f'{self.height}_{self.hash.hex()}'
@staticmethod
def cur() -> 'Block':
return _cur.get()
@staticmethod
def set_cur(value: 'Block'):
_cur.set(value)
@staticmethod
def latest() -> 'Block':
return _latest.get()
@staticmethod
def set_latest(value: 'Block'):
_latest.set(value)
@staticmethod
def completed() -> Optional['Block']:
return _completed.get()
@staticmethod
def set_completed(value: Optional['Block']):
_completed.set(value)
_cur = ContextVar[Block]('Block.cur') # block for the current thread
_latest = ContextVar[Block]('Block.latest') # most recent discovered but may not be processed yet
_completed = ContextVar[Block]('Block.completed') # most recent fully-processed block
current_block = ContextVar[Block]('Block.cur') # block for the current thread
latest_block = ContextVar[Block]('Block.latest') # most recent discovered but may not be processed yet
completed_block = ContextVar[Block]('Block.completed') # most recent fully-processed block

View File

@@ -1,35 +1,36 @@
import asyncio
import logging
from typing import Callable, Union
from web3 import AsyncWeb3
from web3.contract.contract import ContractEvents
from web3.exceptions import LogTopicError
from web3.types import EventData
from dexorder import Blockchain, db, blockchain
from dexorder.base.blockstate import BlockState, BlockDict
from dexorder.blockchain.connection import create_w3_ws, W3
from dexorder import Blockchain, db, blockchain, NARG, dec
from dexorder.base.blockstate import BlockState, BlockDict, Fork, DiffItem, BlockSet, current_blockstate, current_fork
from dexorder.blockchain.connection import create_w3_ws
from dexorder.blockchain.util import get_contract_data
from dexorder.data import pool_prices, vault_tokens, underfunded_vaults, vault_addresses
from dexorder.database.model import Block
from dexorder.database.model.vault_tokens import VaultToken
from dexorder.database.model.block import current_block, latest_block
from dexorder.util import hexstr, topic
log = logging.getLogger(__name__)
vault_addresses = BlockDict('v')
underfunded_vaults = BlockDict('ufv')
active_orders = BlockDict('a')
pool_prices = BlockDict('p')
wallets = BlockDict('wallets') # todo remove debug
# todo detect reorgs and generate correct onHeadUpdate set by unioning the changes along the two forks, not including their common ancestor deltas
class TriggerRunner:
def __init__(self):
self.root_age = 10 # todo set per chain
self.events:list[tuple[Callable[[dict],None],ContractEvents,dict]] = []
# onHeadUpdate callbacks are invoked with a list of DiffItems used to update the head state from either the previous head or the root
self.on_head_update: list[Callable[[Block,list[DiffItem]],None]] = []
# onPromotion callbacks are invoked with a list of DiffItems used to advance the root state
self.on_promotion: list[Callable[[Block,list[DiffItem]],None]] = []
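# Illustrative wiring (editorial note; pool_contract and persist_root_diffs are hypothetical):
#     runner = TriggerRunner()
#     runner.on_head_update.append(lambda block, diffs: log.info(f'{block}: {len(diffs)} head diffs'))
#     runner.on_promotion.append(persist_root_diffs)            # called as persist_root_diffs(root_block, diffs)
#     runner.add_event_trigger(TriggerRunner.handle_swap, pool_contract.events.Swap())
#     await runner.run()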
async def run(self):
"""
1. load root state
@@ -67,91 +68,103 @@ class TriggerRunner:
while True:
async for head in w3ws.listen_to_websocket():
session = None
fork = None
try:
log.debug('head', head)
log.debug(f'head {head["hash"]}')
# block_data = await w3.eth.get_block(head['hash'], True)
block_data = (await w3.provider.make_request('eth_getBlockByHash',[hexstr(head['hash']),False]))['result']
block = Block(chain=chain_id, height=int(block_data['number'],0),
hash=bytes.fromhex(block_data['hash'][2:]), parent=bytes.fromhex(block_data['parentHash'][2:]), data=block_data)
block.set_latest(block)
block.set_cur(block)
latest_block.set(block)
fork = NARG
if state is None:
state = BlockState(block, {})
BlockState.set_cur(state)
# initialize
state = BlockState(block)
current_blockstate.set(state)
self.setup_triggers(w3)
log.info('Created new empty root state')
else:
ancestor = BlockState.cur().add_block(block)
if ancestor is None:
fork = state.add_block(block)
if fork is None:
log.debug(f'discarded late-arriving head {block}')
elif type(ancestor) is int:
# todo backfill batches
log.error(f'backfill unimplemented for range {ancestor} to {block}')
else:
futures = []
for callback, event, log_filter in self.events:
log_filter['blockhash'] = w3.to_hex(block.hash)
futures.append(w3.eth.get_logs(log_filter))
results = await asyncio.gather(*futures)
if session is None:
session = db.session
session.begin()
session.add(block)
for result, (callback,event,filter_args) in zip(results,self.events):
for log_event in result:
callback(log_event)
# check for root promotion
if block.height - state.root_block.height > self.root_age:
b = block
try:
for _ in range(1, self.root_age):
# we walk backwards self.root_age and promote what's there
b = state.by_hash[b.parent]
except KeyError:
pass
if fork.disjoint:
# todo backfill batches
from_height = state.root_block.height + 1
log.error(f'backfill unimplemented for range {from_height} to {block}')
exit(1)
else:
log.debug(f'promoting root {b}')
state.promote_root(b)
# event callbacks are triggered in the order in which they're registered. the events passed to
# each callback are in block transaction order
for callback, event, log_filter in self.events:
log_filter['blockhash'] = w3.to_hex(block.hash)
futures.append(w3.eth.get_logs(log_filter))
# set up for callbacks
current_block.set(block)
current_fork.set(fork)
session = db.session # todo move session creation to here?
session.begin()
session.add(block)
# callbacks
for future, (callback,event,filter_args) in zip(futures,self.events):
for log_event in await future:
try:
parsed = event.process_log(log_event)
except LogTopicError:
pass
else:
# todo try/except for known retryable errors
callback(parsed)
# todo check for reorg and generate a reorg diff list
diff_items = state.diffs_by_hash[block.hash]
for callback in self.on_head_update:
callback(block, diff_items)
# check for root promotion
promotion_height = fork.height - self.root_age
if not fork.disjoint and promotion_height > state.root_block.height:
diff_items = state.promote_root(fork.for_height(promotion_height))
for callback in self.on_promotion:
# todo try/except for known retryable errors
callback(state.root_block, diff_items)
except:
if session is not None:
session.rollback()
if fork is not None:
state.delete_block(fork)
raise
else:
if session is not None:
session.commit()
def handle_transfer(self, event):
w3 = W3.cur()
try:
transfer = w3.eth.contract(abi=get_contract_data('ERC20')['abi']).events.Transfer().process_log(event)
except LogTopicError:
return
@staticmethod
def handle_transfer(transfer: EventData):
to_address = transfer['args']['to']
print('transfer', to_address)
if to_address in vault_addresses:
# todo publish event to vault watchers
db.session.add(VaultToken(vault=to_address, token=event.address))
token_address = transfer['address']
vault_tokens.add(token_address)
if to_address in underfunded_vaults:
# todo flag underfunded vault (check token type?)
pass
BlockDict('wallets').add(to_address)
def handle_swap(self, event):
w3 = W3.cur()
try:
swap = w3.eth.contract(abi=get_contract_data('IUniswapV3PoolEvents')['abi']).events.Swap().process_log(event)
except LogTopicError:
return
@staticmethod
def handle_swap(swap: EventData):
try:
sqrt_price = swap['args']['sqrtPriceX96']
except KeyError:
return
addr = event['address']
price = sqrt_price * sqrt_price / 2**(96*2)
addr = swap['address']
d = dec(sqrt_price)
price = d*d / dec(2**(96*2))
print(f'pool {addr} {price}')
# pool_prices[addr] =
pool_prices[addr] = price
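# Worked example of the conversion above (editorial note): Uniswap V3 encodes sqrt(price) as a
# Q64.96 fixed-point value, so price = (sqrtPriceX96 / 2**96) ** 2. For instance sqrtPriceX96 = 2**96
# gives dec(2**96) * dec(2**96) / dec(2**192) == 1, i.e. a 1:1 price in raw token units; adjusting
# for token decimals is not done here.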
def add_event_trigger(self, callback:Callable[[dict],None], event: ContractEvents, log_filter: Union[dict,str]=None):
if log_filter is None: