blockstate rework

This commit is contained in:
Tim Olson
2023-09-27 17:36:58 -04:00
parent 79cd65a289
commit 75c0233384
6 changed files with 257 additions and 112 deletions

View File

@@ -4,6 +4,6 @@ omegaconf~=2.3.0
web3==6.9.0
psycopg2-binary
orjson~=3.9.7
sortedcontainers
sortedcontainers~=2.4.0
hexbytes~=0.3.1
defaultlist~=1.0.0

View File

@@ -1,7 +1,9 @@
import itertools
import logging
from collections import defaultdict
from contextvars import ContextVar
from typing import Union, TypeVar, Generic, Any
from dataclasses import dataclass
from typing import Union, TypeVar, Generic, Any, Optional
from sortedcontainers import SortedList
@@ -11,6 +13,63 @@ from dexorder.database.model.block import Block
log = logging.getLogger(__name__)
@dataclass
class DiffEntry:
value: Union[Any, 'BlockState.DELETE']
height: int
hash: bytes
@dataclass
class DiffItem:
series: Any
key: Any
entry: DiffEntry
def __str__(self):
return f'{self.entry.hash} {self.series}.{self.key}={"[DEL]" if self.entry.value is BlockState.DELETE else self.entry.value}'
class Fork(list[bytes]):
"""
A Fork is an ancestor path, stored as block hashes in reverse-chronological order from the "current" block at the start to ancestors at the end. The
getitem [] operator indexes by block height for positive values, while negative value are relative to the latest block, so [-1] is the latest
block and [-2] is its parent, etc.
"""
@staticmethod
def cur():
return _path.get()
@staticmethod
def set_cur(value: Optional['Fork']):
_path.set(value)
def __init__(self, ancestry, *, height: int):
super().__init__(ancestry)
self.height = height
def __getitem__(self, height):
try:
return super().__getitem__(self.height - height if height >= 0 else -height + 1)
except IndexError:
return None
@property
def hash(self):
return super().__getitem__(0)
@property
def parent(self):
return super().__getitem__(1)
def for_child(self, blockhash: bytes):
return Fork(self + [blockhash], height=self.height + 1)
def __str__(self):
return f'{self.height}_[{"->".join(h.hex() for h in self)}]'
class BlockState:
DELETE = object()
@@ -19,32 +78,31 @@ class BlockState:
"""
Since recent blocks can be part of temporary forks, we need to be able to undo certain operations if they were part of a reorg. Instead of implementing
undo, we recover state via snapshot plus replay of recent diffs. When old blocks become low enough in the blockheight they may be considered canonical
at which point the deltas may be reliably incorporated into a new snapshot or rolling permanent collection. BlockState manages separate memory areas
at which point the deltas may be reliably incorporated into a rolling permanent collection. BlockState manages separate memory areas
for every block, per-block state that defaults to its parent's state, up the ancestry tree to the root. State clients may read the state for their block,
applying any diffs from the root state to the target block.
by applying any diffs along the block's fork path to the root data.
"""
@staticmethod
def cur() -> 'BlockState':
return _cur.get()
return _blockstate.get()
@staticmethod
def set_cur(value: 'BlockState'):
_cur.set(value)
_blockstate.set(value)
def __init__(self, root_block: Block, root_state: dict):
def __init__(self, root_block: Block):
self.root_block: Block = root_block
self.root_state: dict = root_state
self.by_height: SortedList[tuple[int, Block]] = SortedList(key=lambda x: x[0])
self.by_hash: dict[bytes, Block] = {root_block.hash: root_block}
self.diffs: dict[bytes, dict[Any, dict[Any, Union[Any, BlockState.DELETE]]]] = defaultdict(dict) # by series
self.diffs_by_series: dict[Any, dict[Any, SortedList[DiffEntry]]] = defaultdict(lambda: defaultdict(lambda: SortedList(key=lambda x: x.height)))
self.diffs_by_block: dict[bytes, list[DiffItem]] = defaultdict(list)
self.ancestors: dict[bytes, Block] = {}
BlockState.by_chain[root_block.chain] = self
def add_block(self, block: Block) -> Union[int, Block, None]:
def add_block(self, block: Block) -> Union[int, Fork, None]:
"""
If block is the same age as root_height or older, it is ignored and None is returned. Otherwise, returns the found parent block if available
or else self.root_height.
If block is the same age as root_height or older, it is ignored and None is returned. Otherwise, returns a Fork leading to root.
The ancestor block is set in the ancestors dictionary and any state updates to block are considered to have occured between the registered ancestor
block and the given block. This could be an interval of many blocks, and the ancestor does not need to be the block's immediate parent.
"""
@@ -60,73 +118,131 @@ class BlockState:
parent = self.by_hash.get(block.parent)
if parent is None:
self.ancestors[block.hash] = self.root_block
return self.root_block.height
return Fork(block.hash, height=block.height)
else:
self.ancestors[block.hash] = parent
return parent
def promote_root(self, block):
assert block.hash in self.by_hash
def ancestors():
b = block
while b is not self.root_block:
yield b.hash
b = self.ancestors[b.hash]
return Fork(ancestors(), height=block.height)
def get(self, fork: Fork, series, key, default=NARG):
series_diffs = self.diffs_by_series.get(series)
if series_diffs is None:
if default is NARG:
raise ValueError('series')
else:
return default
diffs: list[DiffEntry] = series_diffs.get(key, [])
for diff in reversed(diffs):
if diff.height <= self.root_block.height or fork[diff.height] == diff.hash:
if diff.value is BlockState.DELETE:
break
else:
return diff.value
# value not found or was DELETE
if default is NARG:
raise KeyError((series, key))
return default
def set(self, fork: Optional[Fork], series, key, value):
diff = DiffEntry(value,
fork.height if fork is not None else self.root_block.height,
fork.hash if fork is not None else self.root_block.hash)
if fork is not None:
self.diffs_by_block[fork.hash].append(DiffItem(series, key, diff))
self.diffs_by_series[series][key].add(diff)
def iteritems(self, fork: Optional[Fork], series):
for k, difflist in self.diffs_by_series.get(series, {}).items():
for diff in reversed(difflist):
if diff.height <= self.root_block.height or fork is not None and fork[diff.height] == diff.hash:
if diff.value is not BlockState.DELETE:
yield k, diff.value
break
def promote_root(self, fork: Fork):
assert all(block in self.by_hash for block in fork)
block = self.by_hash[fork[0]]
diffs = self.collect_diffs(block)
BlockState.apply_diffs(self.root_state, diffs)
del self.by_hash[self.root_block.hash]
# no application of diffs to the internal state is required, just clean up
updated_keys = set()
# walk the by_height list to delete any aged-out block data
while self.by_height and self.by_height[0][0] <= block.height:
height, dead = self.by_height.pop(0)
del self.by_hash[self.root_block.hash] # old root block
if dead is not block:
try:
del self.by_hash[dead.hash]
except KeyError:
pass
try:
del self.diffs[dead.hash]
except KeyError:
pass
try:
del self.ancestors[dead.hash]
except KeyError:
pass
self.root_block = block
@staticmethod
def apply_diffs(obj, diffs):
for series_key, series in diffs.items():
for key, value in series.items():
if value is BlockState.DELETE:
try:
del obj[series_key][key]
except KeyError:
pass
else:
series_obj = obj.get(series_key)
if series_obj is None:
obj[series_key] = series_obj = {}
series_obj[key] = value
def collect_diffs(self, block, series_key=NARG):
diffs = {}
while block is not self.root_block:
block_diffs = self.diffs.get(block.hash)
block_diffs = self.diffs_by_block.get(dead.hash)
if block_diffs is not None:
if series_key is NARG:
for s_key, series in block_diffs.items():
series_diffs = diffs.get(s_key)
if series_diffs is None:
series_diffs = diffs[s_key] = {}
for k, v in series.items():
series_diffs.setdefault(k, v)
else:
series = block_diffs.get(series_key)
if series is not None:
for k, v in series.items():
diffs.setdefault(k, v)
block = self.ancestors[block.hash]
updated_keys.update((s, k) for s, k, d in block_diffs)
del self.diffs_by_block[dead.hash]
del self.ancestors[dead.hash]
# remove old series diffs that have been superceded by new diffs
for s, k in updated_keys:
difflist = self.diffs_by_series[s][k]
# remove old diffs on abandoned forks but keep old diffs on the root fork
removals = None
for d in difflist:
if d.height <= fork.height and d.hash != fork[d.height]:
if removals is None:
removals = [d]
else:
removals.append(d)
if removals is not None:
for r in removals:
difflist.remove(r)
# while the second-oldest diff is still root-age, pop off the oldest diff
while len(difflist) >= 2 and difflist[1].height <= fork.height:
difflist.pop(0)
# if only one diff remains, and it's old, and it's a delete, then we can actually delete the diff list
if len(difflist) == 1 and difflist[0].value == BlockState.DELETE and difflist[0].height <= fork.height:
del self.diffs_by_series[s][k]
self.root_block = block
return diffs
def collect_diffs(self, block: Block, series_key=NARG) -> list[DiffItem]:
"""
returns a list of the latest DiffItem for each key change along the ancestor path from block to root
"""
# first collect the exhaustive list of diffs along the ancestry path
diff_lists: list[list[DiffItem]] = []
while block.height > self.root_block.height:
diffs = self.diffs_by_block.get(block.hash)
if diffs:
if series_key is not NARG:
diffs = [d for d in diffs if d.series == series_key]
diff_lists.append(diffs)
block = self.ancestors[block.hash]
_cur = ContextVar[BlockState]('BlockState.cur')
# now keep only the latest values for keys that were set multiple times
sk = set() # seen keys
result: list[DiffItem] = []
# iterate through all diffs in -reverse- chronological order keeping only the first item we see for each key
for i in itertools.chain(*(reversed(l) for l in diff_lists)):
k = i.series, i.key
if k not in sk:
sk.add(k)
result.append(i)
result.reverse() # forward chronological order
return result
_blockstate = ContextVar[BlockState]('BlockState.cur')
_path = ContextVar[Fork]('fork.cur')
T = TypeVar('T')
@@ -154,47 +270,20 @@ class BlockDict(Generic[T]):
def items(self):
return BlockDict.iter_items(self.series_key)
def get(self, item, default=None):
return BlockDict.getitem(self.series_key, item, default)
@staticmethod
def setitem(series_key, item, value):
state = BlockState.cur()
block = Block.cur()
if block.height > state.root_block.height:
diffs = state.diffs[block.hash]
series = diffs.get(series_key)
if series is None:
series = diffs[series_key] = {}
else:
series = state.root_state.get(series_key)
if series is None:
series = state.root_state[series_key] = {}
series[item] = value
fork = Fork.cur()
state.set(fork, series_key, item, value)
@staticmethod
def getitem(series_key, item):
def getitem(series_key, item, default=NARG):
state = BlockState.cur()
block = Block.cur()
while block.height > state.root_block.height:
diffs = state.diffs.get(block.hash)
if diffs is not None:
series = diffs.get(series_key)
if series is not None:
value = series.get(item, NARG)
if value is BlockState.DELETE:
raise KeyError
if value is not NARG:
return value
block = state.ancestors[block.hash]
if block is not state.root_block:
raise ValueError('Orphaned block is invalid',Block.cur().hash)
root_series = state.root_state.get(series_key)
if root_series is not None:
value = root_series.get(item, NARG)
if value is BlockState.DELETE:
raise KeyError
if value is not NARG:
return value
raise KeyError
fork = Fork.cur()
return state.get(fork, series_key, item, default)
@staticmethod
def delitem(series_key, item):
@@ -205,16 +294,64 @@ class BlockDict(Generic[T]):
try:
BlockDict.getitem(series_key, item)
return True
except KeyError:
except KeyError: # getitem with no default will raise on a missing item
return False
@staticmethod
def iter_items(series_key):
state = BlockState.cur()
block = Block.cur()
root = state.root_state.get(series_key,{})
diffs = state.collect_diffs(block, series_key)
# first output recent changes in the diff obj
yield from ((k,v) for k,v in diffs.items() if v is not BlockState.DELETE)
# then all the items not diffed
yield from ((k,v) for k,v in root.items() if k not in diffs)
fork = Fork.cur()
return state.iteritems(fork, series_key)
def _test():
def B(height, hash:str, parent):
return Block(chain=1337, height=height, hash=hash.encode('utf8'), parent=None if parent is None else parent.hash, data=None)
root_block = B(10, '#root', None )
state = BlockState(root_block)
BlockState.set_cur(state)
b11 = B(11, '#b11', parent=root_block)
f11: Fork = state.add_block(b11)
print('f11',f11)
b11b = B(11, '#b11b', parent=root_block)
f11b: Fork = state.add_block(b11b)
print('f11b',f11b)
b12 = B(12, '#b12', parent=b11)
f12: Fork = state.add_block(b12)
print('f12',f12)
b13 = B(13, '#b13', parent=b12)
f13: Fork = state.add_block(b13)
d = BlockDict('ser')
def dump():
print()
print(Fork.cur().hash if Fork.cur() is not None else 'root')
for k,v in d.items():
print(f'{k} = {v}')
Fork.set_cur(None) # Use None to set values on root
d['foo'] = 'bar'
d['test'] = 'failed'
Fork.set_cur(f11)
d['foo2'] = 'bar2'
del d['test']
Fork.set_cur(f11b)
d['foo2'] = 'bar2b'
Fork.set_cur(f12)
d['test'] = 'ok'
for f in (None, f11, f11b, f12):
Fork.set_cur(f)
dump()
print()
print('all b12 diffs')
for i in state.collect_diffs(b12):
print(i)
if __name__ == '__main__':
_test()

View File

@@ -1,4 +1,3 @@
from .old_dispatch import OldDispatcher
from .by_blockchain import ByBlockchainDict, ByBlockchainList, ByBlockchainCollection
from .connection import connect
from dexorder.base.chain import Ethereum, Polygon, Goerli, Mumbai, ArbitrumOne, BSC

View File

@@ -14,7 +14,7 @@ class Block(Base):
data: Mapped[dict] = mapped_column('data',JSONB)
def __str__(self):
return f'{self.height}_{self.hash}'
return f'{self.height}_{self.hash.hex()}'
@staticmethod
def cur() -> 'Block':

View File

@@ -12,21 +12,27 @@ def align_decimal(value, left_columns) -> str:
returns a string where the decimal point in value is aligned to have left_columns of characters before it
"""
s = str(value)
pad = max(left_columns - len(re.sub(r'[^0-9]*$','',s.split('.')[0])), 0)
pad = max(left_columns - len(re.sub(r'[^0-9]*$', '', s.split('.')[0])), 0)
return ' ' * pad + s
def hexstr(value):
def hexstr(value: bytes):
""" returns an 0x-prefixed hex string """
if type(value) is HexBytes:
return value.hex()
elif type(value) is bytes:
return '0x'+bytes.hex()
return '0x' + value.hex()
elif type(value) is str:
return value if value.startswith('0x') else '0x' + value
else:
raise ValueError
def hexbytes(value: str):
""" converts an optionally 0x-prefixed hex string into bytes """
return bytes.fromhex(value[2:] if value.startswith('0x') else value)
def topic(event_abi):
event_name = f'{event_abi["name"]}(' + ','.join(i['type'] for i in event_abi['inputs']) + ')'
result = '0x' + keccak(text=event_name).hex()

View File

@@ -1,16 +1,19 @@
from hexbytes import HexBytes
from orjson import orjson
from web3.datastructures import ReadableAttributeDict
from web3.datastructures import AttributeDict
def _serialize(v):
# todo wrap json.dumps()
if isinstance(v,HexBytes):
if type(v) is HexBytes:
return v.hex()
if isinstance(v,ReadableAttributeDict):
if type(v) is AttributeDict:
return v.__dict__
raise ValueError(v)
def loads(s):
return orjson.loads(s)
def dumps(obj):
return orjson.dumps(obj, default=_serialize)