blockstate touchups

This commit is contained in:
Tim Olson
2023-09-27 19:45:24 -04:00
parent 75c0233384
commit 2b72decf7b
2 changed files with 55 additions and 25 deletions

View File

@@ -38,12 +38,12 @@ class Fork(list[bytes]):
"""
@staticmethod
def cur():
return _path.get()
def cur() -> Optional['Fork']:
return _fork.get()
@staticmethod
def set_cur(value: Optional['Fork']):
_path.set(value)
_fork.set(value)
def __init__(self, ancestry, *, height: int):
super().__init__(ancestry)
@@ -93,10 +93,12 @@ class BlockState:
def __init__(self, root_block: Block):
self.root_block: Block = root_block
self.by_height: SortedList[tuple[int, Block]] = SortedList(key=lambda x: x[0])
self.by_height: SortedList[Block] = SortedList(key=lambda x: x.height)
self.by_hash: dict[bytes, Block] = {root_block.hash: root_block}
# diffs_by_series is the main data structure. Leaf nodes are lists of diffs sorted by block height
self.diffs_by_series: dict[Any, dict[Any, SortedList[DiffEntry]]] = defaultdict(lambda: defaultdict(lambda: SortedList(key=lambda x: x.height)))
self.diffs_by_block: dict[bytes, list[DiffItem]] = defaultdict(list)
# diffs_by_hash holds the diff items generated by each block
self.diffs_by_hash: dict[bytes, list[DiffItem]] = defaultdict(list)
self.ancestors: dict[bytes, Block] = {}
BlockState.by_chain[root_block.chain] = self
@@ -113,7 +115,7 @@ class BlockState:
return None
if block.hash not in self.by_hash:
self.by_hash[block.hash] = block
self.by_height.add((block.height, block))
self.by_height.add(block)
log.debug(f'new block state {block}')
parent = self.by_hash.get(block.parent)
if parent is None:
@@ -130,7 +132,7 @@ class BlockState:
return Fork(ancestors(), height=block.height)
def get(self, fork: Fork, series, key, default=NARG):
def get(self, fork: Optional[Fork], series, key, default=NARG):
series_diffs = self.diffs_by_series.get(series)
if series_diffs is None:
if default is NARG:
@@ -139,10 +141,12 @@ class BlockState:
return default
diffs: list[DiffEntry] = series_diffs.get(key, [])
for diff in reversed(diffs):
if diff.height <= self.root_block.height or fork[diff.height] == diff.hash:
if diff.height <= self.root_block.height or fork is not None and fork[diff.height] == diff.hash:
if diff.value is BlockState.DELETE:
break
else:
if fork[self.root_block.height] != self.root_block.hash: # todo move this assertion elsewhere so it runs once per task
raise RuntimeError
return diff.value
# value not found or was DELETE
if default is NARG:
@@ -154,7 +158,7 @@ class BlockState:
fork.height if fork is not None else self.root_block.height,
fork.hash if fork is not None else self.root_block.hash)
if fork is not None:
self.diffs_by_block[fork.hash].append(DiffItem(series, key, diff))
self.diffs_by_hash[fork.hash].append(DiffItem(series, key, diff))
self.diffs_by_series[series][key].add(diff)
def iteritems(self, fork: Optional[Fork], series):
@@ -167,29 +171,28 @@ class BlockState:
def promote_root(self, fork: Fork):
assert all(block in self.by_hash for block in fork)
block = self.by_hash[fork[0]]
block = self.by_hash[fork.hash]
diffs = self.collect_diffs(block)
# no application of diffs to the internal state is required, just clean up
updated_keys = set()
# walk the by_height list to delete any aged-out block data
while self.by_height and self.by_height[0][0] <= block.height:
height, dead = self.by_height.pop(0)
del self.by_hash[self.root_block.hash] # old root block
# in order to prune diffs_by_series, updated_keys remembers all the keys that were touched by any aged-out block
updated_keys = set()
while self.by_height and self.by_height[0].height <= block.height:
dead = self.by_height.pop(0)
if dead is not block:
try:
del self.by_hash[dead.hash]
except KeyError:
pass
block_diffs = self.diffs_by_block.get(dead.hash)
block_diffs = self.diffs_by_hash.get(dead.hash)
if block_diffs is not None:
updated_keys.update((s, k) for s, k, d in block_diffs)
del self.diffs_by_block[dead.hash]
updated_keys.update((d.series, d.key) for d in block_diffs)
del self.diffs_by_hash[dead.hash]
del self.ancestors[dead.hash]
# remove old series diffs that have been superseded by new diffs
# prune diffs_by_series by removing old series diffs that have been superseded by new diffs
for s, k in updated_keys:
difflist = self.diffs_by_series[s][k]
# remove old diffs on abandoned forks but keep old diffs on the root fork
@@ -207,9 +210,10 @@ class BlockState:
while len(difflist) >= 2 and difflist[1].height <= fork.height:
difflist.pop(0)
# if only one diff remains, and it's old, and it's a delete, then we can actually delete the diff list
if len(difflist) == 1 and difflist[0].value == BlockState.DELETE and difflist[0].height <= fork.height:
if not difflist or len(difflist) == 1 and difflist[0].value == BlockState.DELETE and difflist[0].height <= fork.height:
del self.diffs_by_series[s][k]
del self.by_hash[self.root_block.hash] # old root block
self.root_block = block
return diffs
@@ -220,7 +224,7 @@ class BlockState:
# first collect the exhaustive list of diffs along the ancestry path
diff_lists: list[list[DiffItem]] = []
while block.height > self.root_block.height:
diffs = self.diffs_by_block.get(block.hash)
diffs = self.diffs_by_hash.get(block.hash)
if diffs:
if series_key is not NARG:
diffs = [d for d in diffs if d.series == series_key]
@@ -241,7 +245,7 @@ class BlockState:
_blockstate = ContextVar[BlockState]('BlockState.cur')
_path = ContextVar[Fork]('fork.cur')
_fork = ContextVar[Optional[Fork]]('fork.cur', default=None)
T = TypeVar('T')
@@ -335,11 +339,15 @@ def _test():
Fork.set_cur(None) # Use None to set values on root
d['foo'] = 'bar'
d['test'] = 'failed'
Fork.set_cur(f11)
d['foo2'] = 'bar2'
del d['test']
Fork.set_cur(f11b)
d['foo2'] = 'bar2b'
del d['foo2']
d['foob'] = 'barb'
Fork.set_cur(f12)
d['test'] = 'ok'
@@ -352,6 +360,18 @@ def _test():
for i in state.collect_diffs(b12):
print(i)
print()
print('promoting b11')
state.promote_root(f11)
Fork.set_cur(f12)
dump()
print()
print('promoting b12')
state.promote_root(f12)
Fork.set_cur(f12)
dump()
if __name__ == '__main__':
_test()

View File

@@ -1,4 +1,5 @@
from contextvars import ContextVar
from typing import Optional
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
@@ -32,6 +33,15 @@ class Block(Base):
def set_latest(value: 'Block'):
_latest.set(value)
@staticmethod
def completed() -> Optional['Block']:
return _completed.get()
_cur = ContextVar[Block]('Block.cur')
_latest = ContextVar[Block]('Block.latest')
@staticmethod
def set_completed(value: Optional['Block']):
_completed.set(value)
_cur = ContextVar[Block]('Block.cur')  # block for the current thread
_latest = ContextVar[Block]('Block.latest')  # most recent discovered but may not be processed yet
# default=None lets Block.completed() honor its Optional return type instead of
# raising LookupError when nothing has been marked completed in this context yet.
_completed = ContextVar[Optional[Block]]('Block.completed', default=None)  # most recent fully-processed block