diff --git a/requirements.txt b/requirements.txt index 1163dbc..62acf63 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,6 @@ omegaconf~=2.3.0 web3==6.9.0 psycopg2-binary orjson~=3.9.7 -sortedcontainers +sortedcontainers~=2.4.0 hexbytes~=0.3.1 defaultlist~=1.0.0 diff --git a/src/dexorder/base/blockstate.py b/src/dexorder/base/blockstate.py index dd04680..f2a68bc 100644 --- a/src/dexorder/base/blockstate.py +++ b/src/dexorder/base/blockstate.py @@ -1,7 +1,9 @@ +import itertools import logging from collections import defaultdict from contextvars import ContextVar -from typing import Union, TypeVar, Generic, Any +from dataclasses import dataclass +from typing import Union, TypeVar, Generic, Any, Optional from sortedcontainers import SortedList @@ -11,6 +13,63 @@ from dexorder.database.model.block import Block log = logging.getLogger(__name__) +@dataclass +class DiffEntry: + value: Union[Any, 'BlockState.DELETE'] + height: int + hash: bytes + + +@dataclass +class DiffItem: + series: Any + key: Any + entry: DiffEntry + + def __str__(self): + return f'{self.entry.hash} {self.series}.{self.key}={"[DEL]" if self.entry.value is BlockState.DELETE else self.entry.value}' + + +class Fork(list[bytes]): + """ + A Fork is an ancestor path, stored as block hashes in reverse-chronological order from the "current" block at the start to ancestors at the end. The + getitem [] operator indexes by block height for positive values, while negative value are relative to the latest block, so [-1] is the latest + block and [-2] is its parent, etc. + """ + + @staticmethod + def cur(): + return _path.get() + + @staticmethod + def set_cur(value: Optional['Fork']): + _path.set(value) + + def __init__(self, ancestry, *, height: int): + super().__init__(ancestry) + self.height = height + + def __getitem__(self, height): + try: + return super().__getitem__(self.height - height if height >= 0 else -height + 1) + except IndexError: + return None + + @property + def hash(self): + return super().__getitem__(0) + + @property + def parent(self): + return super().__getitem__(1) + + def for_child(self, blockhash: bytes): + return Fork(self + [blockhash], height=self.height + 1) + + def __str__(self): + return f'{self.height}_[{"->".join(h.hex() for h in self)}]' + + class BlockState: DELETE = object() @@ -19,32 +78,31 @@ class BlockState: """ Since recent blocks can be part of temporary forks, we need to be able to undo certain operations if they were part of a reorg. Instead of implementing undo, we recover state via snapshot plus replay of recent diffs. When old blocks become low enough in the blockheight they may be considered canonical - at which point the deltas may be reliably incorporated into a new snapshot or rolling permanent collection. BlockState manages separate memory areas + at which point the deltas may be reliably incorporated into a rolling permanent collection. BlockState manages separate memory areas for every block, per-block state that defaults to its parent's state, up the ancestry tree to the root. State clients may read the state for their block, - applying any diffs from the root state to the target block. + by applying any diffs along the block's fork path to the root data. """ @staticmethod def cur() -> 'BlockState': - return _cur.get() + return _blockstate.get() @staticmethod def set_cur(value: 'BlockState'): - _cur.set(value) + _blockstate.set(value) - def __init__(self, root_block: Block, root_state: dict): + def __init__(self, root_block: Block): self.root_block: Block = root_block - self.root_state: dict = root_state self.by_height: SortedList[tuple[int, Block]] = SortedList(key=lambda x: x[0]) self.by_hash: dict[bytes, Block] = {root_block.hash: root_block} - self.diffs: dict[bytes, dict[Any, dict[Any, Union[Any, BlockState.DELETE]]]] = defaultdict(dict) # by series + self.diffs_by_series: dict[Any, dict[Any, SortedList[DiffEntry]]] = defaultdict(lambda: defaultdict(lambda: SortedList(key=lambda x: x.height))) + self.diffs_by_block: dict[bytes, list[DiffItem]] = defaultdict(list) self.ancestors: dict[bytes, Block] = {} BlockState.by_chain[root_block.chain] = self - def add_block(self, block: Block) -> Union[int, Block, None]: + def add_block(self, block: Block) -> Union[int, Fork, None]: """ - If block is the same age as root_height or older, it is ignored and None is returned. Otherwise, returns the found parent block if available - or else self.root_height. + If block is the same age as root_height or older, it is ignored and None is returned. Otherwise, returns a Fork leading to root. The ancestor block is set in the ancestors dictionary and any state updates to block are considered to have occured between the registered ancestor block and the given block. This could be an interval of many blocks, and the ancestor does not need to be the block's immediate parent. """ @@ -60,73 +118,131 @@ class BlockState: parent = self.by_hash.get(block.parent) if parent is None: self.ancestors[block.hash] = self.root_block - return self.root_block.height + return Fork(block.hash, height=block.height) else: self.ancestors[block.hash] = parent - return parent - def promote_root(self, block): - assert block.hash in self.by_hash + def ancestors(): + b = block + while b is not self.root_block: + yield b.hash + b = self.ancestors[b.hash] + + return Fork(ancestors(), height=block.height) + + def get(self, fork: Fork, series, key, default=NARG): + series_diffs = self.diffs_by_series.get(series) + if series_diffs is None: + if default is NARG: + raise ValueError('series') + else: + return default + diffs: list[DiffEntry] = series_diffs.get(key, []) + for diff in reversed(diffs): + if diff.height <= self.root_block.height or fork[diff.height] == diff.hash: + if diff.value is BlockState.DELETE: + break + else: + return diff.value + # value not found or was DELETE + if default is NARG: + raise KeyError((series, key)) + return default + + def set(self, fork: Optional[Fork], series, key, value): + diff = DiffEntry(value, + fork.height if fork is not None else self.root_block.height, + fork.hash if fork is not None else self.root_block.hash) + if fork is not None: + self.diffs_by_block[fork.hash].append(DiffItem(series, key, diff)) + self.diffs_by_series[series][key].add(diff) + + def iteritems(self, fork: Optional[Fork], series): + for k, difflist in self.diffs_by_series.get(series, {}).items(): + for diff in reversed(difflist): + if diff.height <= self.root_block.height or fork is not None and fork[diff.height] == diff.hash: + if diff.value is not BlockState.DELETE: + yield k, diff.value + break + + def promote_root(self, fork: Fork): + assert all(block in self.by_hash for block in fork) + block = self.by_hash[fork[0]] diffs = self.collect_diffs(block) - BlockState.apply_diffs(self.root_state, diffs) - del self.by_hash[self.root_block.hash] + + # no application of diffs to the internal state is required, just clean up + + updated_keys = set() + + # walk the by_height list to delete any aged-out block data while self.by_height and self.by_height[0][0] <= block.height: height, dead = self.by_height.pop(0) + del self.by_hash[self.root_block.hash] # old root block if dead is not block: try: del self.by_hash[dead.hash] except KeyError: pass - try: - del self.diffs[dead.hash] - except KeyError: - pass - try: - del self.ancestors[dead.hash] - except KeyError: - pass - self.root_block = block - - @staticmethod - def apply_diffs(obj, diffs): - for series_key, series in diffs.items(): - for key, value in series.items(): - if value is BlockState.DELETE: - try: - del obj[series_key][key] - except KeyError: - pass - else: - series_obj = obj.get(series_key) - if series_obj is None: - obj[series_key] = series_obj = {} - series_obj[key] = value - - - def collect_diffs(self, block, series_key=NARG): - diffs = {} - while block is not self.root_block: - block_diffs = self.diffs.get(block.hash) + block_diffs = self.diffs_by_block.get(dead.hash) if block_diffs is not None: - if series_key is NARG: - for s_key, series in block_diffs.items(): - series_diffs = diffs.get(s_key) - if series_diffs is None: - series_diffs = diffs[s_key] = {} - for k, v in series.items(): - series_diffs.setdefault(k, v) - else: - series = block_diffs.get(series_key) - if series is not None: - for k, v in series.items(): - diffs.setdefault(k, v) - block = self.ancestors[block.hash] + updated_keys.update((s, k) for s, k, d in block_diffs) + del self.diffs_by_block[dead.hash] + del self.ancestors[dead.hash] + + # remove old series diffs that have been superceded by new diffs + for s, k in updated_keys: + difflist = self.diffs_by_series[s][k] + # remove old diffs on abandoned forks but keep old diffs on the root fork + removals = None + for d in difflist: + if d.height <= fork.height and d.hash != fork[d.height]: + if removals is None: + removals = [d] + else: + removals.append(d) + if removals is not None: + for r in removals: + difflist.remove(r) + # while the second-oldest diff is still root-age, pop off the oldest diff + while len(difflist) >= 2 and difflist[1].height <= fork.height: + difflist.pop(0) + # if only one diff remains, and it's old, and it's a delete, then we can actually delete the diff list + if len(difflist) == 1 and difflist[0].value == BlockState.DELETE and difflist[0].height <= fork.height: + del self.diffs_by_series[s][k] + + self.root_block = block return diffs + def collect_diffs(self, block: Block, series_key=NARG) -> list[DiffItem]: + """ + returns a list of the latest DiffItem for each key change along the ancestor path from block to root + """ + # first collect the exhaustive list of diffs along the ancestry path + diff_lists: list[list[DiffItem]] = [] + while block.height > self.root_block.height: + diffs = self.diffs_by_block.get(block.hash) + if diffs: + if series_key is not NARG: + diffs = [d for d in diffs if d.series == series_key] + diff_lists.append(diffs) + block = self.ancestors[block.hash] -_cur = ContextVar[BlockState]('BlockState.cur') + # now keep only the latest values for keys that were set multiple times + sk = set() # seen keys + result: list[DiffItem] = [] + # iterate through all diffs in -reverse- chronological order keeping only the first item we see for each key + for i in itertools.chain(*(reversed(l) for l in diff_lists)): + k = i.series, i.key + if k not in sk: + sk.add(k) + result.append(i) + result.reverse() # forward chronological order + return result +_blockstate = ContextVar[BlockState]('BlockState.cur') +_path = ContextVar[Fork]('fork.cur') + T = TypeVar('T') @@ -154,47 +270,20 @@ class BlockDict(Generic[T]): def items(self): return BlockDict.iter_items(self.series_key) + def get(self, item, default=None): + return BlockDict.getitem(self.series_key, item, default) @staticmethod def setitem(series_key, item, value): state = BlockState.cur() - block = Block.cur() - if block.height > state.root_block.height: - diffs = state.diffs[block.hash] - series = diffs.get(series_key) - if series is None: - series = diffs[series_key] = {} - else: - series = state.root_state.get(series_key) - if series is None: - series = state.root_state[series_key] = {} - series[item] = value + fork = Fork.cur() + state.set(fork, series_key, item, value) @staticmethod - def getitem(series_key, item): + def getitem(series_key, item, default=NARG): state = BlockState.cur() - block = Block.cur() - while block.height > state.root_block.height: - diffs = state.diffs.get(block.hash) - if diffs is not None: - series = diffs.get(series_key) - if series is not None: - value = series.get(item, NARG) - if value is BlockState.DELETE: - raise KeyError - if value is not NARG: - return value - block = state.ancestors[block.hash] - if block is not state.root_block: - raise ValueError('Orphaned block is invalid',Block.cur().hash) - root_series = state.root_state.get(series_key) - if root_series is not None: - value = root_series.get(item, NARG) - if value is BlockState.DELETE: - raise KeyError - if value is not NARG: - return value - raise KeyError + fork = Fork.cur() + return state.get(fork, series_key, item, default) @staticmethod def delitem(series_key, item): @@ -205,16 +294,64 @@ class BlockDict(Generic[T]): try: BlockDict.getitem(series_key, item) return True - except KeyError: + except KeyError: # getitem with no default will raise on a missing item return False @staticmethod def iter_items(series_key): state = BlockState.cur() - block = Block.cur() - root = state.root_state.get(series_key,{}) - diffs = state.collect_diffs(block, series_key) - # first output recent changes in the diff obj - yield from ((k,v) for k,v in diffs.items() if v is not BlockState.DELETE) - # then all the items not diffed - yield from ((k,v) for k,v in root.items() if k not in diffs) + fork = Fork.cur() + return state.iteritems(fork, series_key) + + +def _test(): + + def B(height, hash:str, parent): + return Block(chain=1337, height=height, hash=hash.encode('utf8'), parent=None if parent is None else parent.hash, data=None) + + root_block = B(10, '#root', None ) + state = BlockState(root_block) + BlockState.set_cur(state) + b11 = B(11, '#b11', parent=root_block) + f11: Fork = state.add_block(b11) + print('f11',f11) + b11b = B(11, '#b11b', parent=root_block) + f11b: Fork = state.add_block(b11b) + print('f11b',f11b) + b12 = B(12, '#b12', parent=b11) + f12: Fork = state.add_block(b12) + print('f12',f12) + b13 = B(13, '#b13', parent=b12) + f13: Fork = state.add_block(b13) + + d = BlockDict('ser') + + def dump(): + print() + print(Fork.cur().hash if Fork.cur() is not None else 'root') + for k,v in d.items(): + print(f'{k} = {v}') + + Fork.set_cur(None) # Use None to set values on root + d['foo'] = 'bar' + d['test'] = 'failed' + Fork.set_cur(f11) + d['foo2'] = 'bar2' + del d['test'] + Fork.set_cur(f11b) + d['foo2'] = 'bar2b' + Fork.set_cur(f12) + d['test'] = 'ok' + + for f in (None, f11, f11b, f12): + Fork.set_cur(f) + dump() + + print() + print('all b12 diffs') + for i in state.collect_diffs(b12): + print(i) + + +if __name__ == '__main__': + _test() diff --git a/src/dexorder/blockchain/__init__.py b/src/dexorder/blockchain/__init__.py index 7017c84..0458fd4 100644 --- a/src/dexorder/blockchain/__init__.py +++ b/src/dexorder/blockchain/__init__.py @@ -1,4 +1,3 @@ -from .old_dispatch import OldDispatcher from .by_blockchain import ByBlockchainDict, ByBlockchainList, ByBlockchainCollection from .connection import connect from dexorder.base.chain import Ethereum, Polygon, Goerli, Mumbai, ArbitrumOne, BSC diff --git a/src/dexorder/database/model/block.py b/src/dexorder/database/model/block.py index f7aa75c..22d2320 100644 --- a/src/dexorder/database/model/block.py +++ b/src/dexorder/database/model/block.py @@ -14,7 +14,7 @@ class Block(Base): data: Mapped[dict] = mapped_column('data',JSONB) def __str__(self): - return f'{self.height}_{self.hash}' + return f'{self.height}_{self.hash.hex()}' @staticmethod def cur() -> 'Block': diff --git a/src/dexorder/util/__init__.py b/src/dexorder/util/__init__.py index 7c490f8..6e9235b 100644 --- a/src/dexorder/util/__init__.py +++ b/src/dexorder/util/__init__.py @@ -12,21 +12,27 @@ def align_decimal(value, left_columns) -> str: returns a string where the decimal point in value is aligned to have left_columns of characters before it """ s = str(value) - pad = max(left_columns - len(re.sub(r'[^0-9]*$','',s.split('.')[0])), 0) + pad = max(left_columns - len(re.sub(r'[^0-9]*$', '', s.split('.')[0])), 0) return ' ' * pad + s -def hexstr(value): + +def hexstr(value: bytes): """ returns an 0x-prefixed hex string """ if type(value) is HexBytes: return value.hex() elif type(value) is bytes: - return '0x'+bytes.hex() + return '0x' + value.hex() elif type(value) is str: return value if value.startswith('0x') else '0x' + value else: raise ValueError +def hexbytes(value: str): + """ converts an optionally 0x-prefixed hex string into bytes """ + return bytes.fromhex(value[2:] if value.startswith('0x') else value) + + def topic(event_abi): event_name = f'{event_abi["name"]}(' + ','.join(i['type'] for i in event_abi['inputs']) + ')' result = '0x' + keccak(text=event_name).hex() diff --git a/src/dexorder/util/json.py b/src/dexorder/util/json.py index 4eb239f..b452a20 100644 --- a/src/dexorder/util/json.py +++ b/src/dexorder/util/json.py @@ -1,16 +1,19 @@ from hexbytes import HexBytes from orjson import orjson -from web3.datastructures import ReadableAttributeDict +from web3.datastructures import AttributeDict def _serialize(v): # todo wrap json.dumps() - if isinstance(v,HexBytes): + if type(v) is HexBytes: return v.hex() - if isinstance(v,ReadableAttributeDict): + if type(v) is AttributeDict: return v.__dict__ raise ValueError(v) +def loads(s): + return orjson.loads(s) + def dumps(obj): return orjson.dumps(obj, default=_serialize)