db state load fix; vault balance tracking with pubsub; db statedict saves strings not json

This commit is contained in:
Tim Olson
2023-10-30 18:45:46 -04:00
parent 064f1a4d82
commit 6af695d345
16 changed files with 103 additions and 47 deletions

View File

@@ -34,7 +34,7 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint('key') sa.PrimaryKeyConstraint('key')
) )
op.create_table('seriesdict', op.create_table('seriesdict',
sa.Column('value', postgresql.JSONB(astext_type=sa.Text()), nullable=True), sa.Column('value', sa.String(), nullable=False),
sa.Column('chain', sa.Integer(), nullable=False), sa.Column('chain', sa.Integer(), nullable=False),
sa.Column('series', sa.String(), nullable=False), sa.Column('series', sa.String(), nullable=False),
sa.Column('key', sa.String(), nullable=False), sa.Column('key', sa.String(), nullable=False),

View File

@@ -57,7 +57,7 @@ class DisjointFork:
def __init__(self, block: Block, root: Block): def __init__(self, block: Block, root: Block):
self.height = block.height self.height = block.height
self.hash = block.hash self.hash = block.hash
self.parent = root.hash self.parent = root
self.disjoint = True self.disjoint = True
def __contains__(self, item): def __contains__(self, item):
@@ -65,10 +65,10 @@ class DisjointFork:
return False # item is in the future return False # item is in the future
if item.height < self.parent.height: if item.height < self.parent.height:
return True # item is ancient return True # item is ancient
return item.hash in (self.hash, self.parent) return item.hash in (self.hash, self.parent.hash)
def __str__(self): def __str__(self):
return f'{self.height}_[{self.hash.hex()}->{self.parent.hex()}]' return f'{self.height}_[{self.hash.hex()}->{self.parent.hash.hex()}]'
current_fork = ContextVar[Optional[Fork]]('current_fork', default=None) current_fork = ContextVar[Optional[Fork]]('current_fork', default=None)

View File

@@ -23,7 +23,7 @@ class Token (ContractProxy, FixedDecimals):
@staticmethod @staticmethod
def get(name_or_address:str, *, chain_id=None) -> 'Token': def get(name_or_address:str, *, chain_id=None) -> 'Token':
try: try:
return tokens.get(name_or_address, default=NARG, chain_id=chain_id) # default=NARG will raise return tokens.get(name_or_address, default=NARG, chain_id=chain_id)
except KeyError: except KeyError:
try: try:
# noinspection PyTypeChecker # noinspection PyTypeChecker

View File

@@ -19,8 +19,7 @@ async def main():
logging.basicConfig(level=logging.INFO, stream=sys.stdout) logging.basicConfig(level=logging.INFO, stream=sys.stdout)
log.setLevel(logging.DEBUG) log.setLevel(logging.DEBUG)
parse_args() parse_args()
current_chain.set(Blockchain.get(config.chain)) await blockchain.connect()
blockchain.connect()
redis_state = None redis_state = None
state = None state = None
if memcache: if memcache:

View File

@@ -1,13 +1,14 @@
from hexbytes import HexBytes from hexbytes import HexBytes
from web3 import WebsocketProviderV2, AsyncWeb3, AsyncHTTPProvider from web3 import WebsocketProviderV2, AsyncWeb3, AsyncHTTPProvider
from ..base.chain import current_chain
from ..contract import get_contract_data from ..contract import get_contract_data
from .. import current_w3 from .. import current_w3, Blockchain
from ..configuration import resolve_rpc_url from ..configuration import resolve_rpc_url
from ..configuration.resolve import resolve_ws_url from ..configuration.resolve import resolve_ws_url
def connect(rpc_url=None): async def connect(rpc_url=None):
""" """
connects to the rpc_url and configures the context connects to the rpc_url and configures the context
if you don't want to use ctx.account for this w3, either set ctx.account first or if you don't want to use ctx.account for this w3, either set ctx.account first or
@@ -15,6 +16,7 @@ def connect(rpc_url=None):
""" """
w3 = create_w3(rpc_url) w3 = create_w3(rpc_url)
current_w3.set(w3) current_w3.set(w3)
current_chain.set(Blockchain.get(await w3.eth.chain_id))
return w3 return w3

View File

@@ -1,3 +1,4 @@
import json
import logging import logging
from enum import Enum from enum import Enum
from typing import TypeVar, Generic, Iterable, Union, Any, Iterator, Callable from typing import TypeVar, Generic, Iterable, Union, Any, Iterator, Callable
@@ -26,7 +27,7 @@ class BlockData:
def __init__(self, data_type: DataType, series: Any, *, def __init__(self, data_type: DataType, series: Any, *,
series2str=None, series2key=None, # defaults to key2str and str2key series2str=None, series2key=None, # defaults to key2str and str2key
key2str=util_key2str, str2key=util_str2key, key2str=util_key2str, str2key=util_str2key,
value2str=lambda x:x, str2value=lambda x:x, # serialize/deserialize value to something JSON-able value2str=json.dumps, str2value=json.loads, # serialize/deserialize value to something JSON-able
**opts): **opts):
assert series not in BlockData.registry assert series not in BlockData.registry
BlockData.registry[series] = self BlockData.registry[series] = self
@@ -82,6 +83,18 @@ class BlockData:
fork = current_fork.get() fork = current_fork.get()
return state.iteritems(fork, series_key) return state.iteritems(fork, series_key)
@staticmethod
def iter_keys(series_key):
state = current_blockstate.get()
fork = current_fork.get()
return state.iterkeys(fork, series_key)
@staticmethod
def iter_values(series_key):
state = current_blockstate.get()
fork = current_fork.get()
return state.itervalues(fork, series_key)
@staticmethod @staticmethod
def by_opt(key): def by_opt(key):
yield from (s for s in BlockData.registry.values() if key in s.opts) yield from (s for s in BlockData.registry.values() if key in s.opts)
@@ -131,6 +144,12 @@ class BlockDict(Generic[K,V], BlockData):
def items(self) -> Iterable[tuple[K,V]]: def items(self) -> Iterable[tuple[K,V]]:
return self.iter_items(self.series) return self.iter_items(self.series)
def keys(self) -> Iterable[K]:
return self.iter_keys(self.series)
def values(self) -> Iterable[V]:
return self.iter_values(self.series)
def get(self, item: K, default: V = None) -> V: def get(self, item: K, default: V = None) -> V:
return self.getitem(item, default) return self.getitem(item, default)

View File

@@ -62,10 +62,11 @@ class DbState(SeriesCollection):
db.session.add(SeriesSet(**key)) db.session.add(SeriesSet(**key))
elif t == DataType.DICT: elif t == DataType.DICT:
found = db.session.get(SeriesDict, key) found = db.session.get(SeriesDict, key)
value = d.value2str(diff.value)
if found is None: if found is None:
db.session.add(SeriesDict(**key, value=d.value2str(diff.value))) db.session.add(SeriesDict(**key, value=value))
else: else:
found.value = diff.value found.value = value
else: else:
raise NotImplementedError raise NotImplementedError
db.kv[f'root_block|{root_block.chain}'] = [root_block.height, root_block.hash] db.kv[f'root_block|{root_block.chain}'] = [root_block.height, root_block.hash]
@@ -87,16 +88,23 @@ class DbState(SeriesCollection):
current_fork.set(None) # root fork current_fork.set(None) # root fork
for series, data in self.datas.items(): for series, data in self.datas.items():
if data.opts.get('db') != 'lazy': if data.opts.get('db') != 'lazy':
log.debug(f'loading series {series}')
t = data.type t = data.type
if t == DataType.SET: if t == DataType.SET:
# noinspection PyTypeChecker # noinspection PyTypeChecker
var: BlockSet = BlockData.registry[series] var: BlockSet = BlockData.registry[series]
for row in db.session.query(SeriesSet).where(SeriesSet.series == data.series2str(series)): for row in db.session.query(SeriesSet).where(SeriesSet.chain == chain_id, SeriesSet.series == data.series2str(series)):
var.add(data.str2key(row.key)) key = data.str2key(row.key)
log.debug(f'load {series} {key}')
var.add(key)
elif t == DataType.DICT: elif t == DataType.DICT:
# noinspection PyTypeChecker # noinspection PyTypeChecker
var: BlockDict = BlockData.registry[series] var: BlockDict = BlockData.registry[series]
for row in db.session.query(SeriesDict).where(SeriesDict.series == data.series2str(series)): for row in db.session.query(SeriesDict).where(SeriesDict.chain == chain_id, SeriesDict.series == data.series2str(series)):
var[data.str2key(row.key)] = data.str2value(row.value) key = data.str2key(row.key)
value = data.str2value(row.value)
log.debug(f'load {series} {key} {value}')
var[key] = value
completed_block.set(root_block) completed_block.set(root_block)
log.debug(f'loaded db state from block {root_block}')
return state return state

View File

@@ -129,7 +129,7 @@ class BlockState:
if diff.value is DELETE: if diff.value is DELETE:
break break
else: else:
if self.root_block not in fork: # todo move this assertion elsewhere so it runs once per task if fork and self.root_block not in fork: # todo move this assertion elsewhere so it runs once per task
raise ValueError(f'Cannot get value for a non-root fork {hexstr(fork.hash)}') raise ValueError(f'Cannot get value for a non-root fork {hexstr(fork.hash)}')
return diff.value return diff.value
return DELETE return DELETE
@@ -152,6 +152,22 @@ class BlockState:
yield k, diff.value yield k, diff.value
break break
def iterkeys(self, fork: Optional[Fork], series):
for k, difflist in self.diffs_by_series.get(series, {}).items():
for diff in reversed(difflist):
if diff.height <= self.root_block.height or fork is not None and diff in fork:
if diff.value is not DELETE:
yield k
break
def itervalues(self, fork: Optional[Fork], series):
for k, difflist in self.diffs_by_series.get(series, {}).items():
for diff in reversed(difflist):
if diff.height <= self.root_block.height or fork is not None and diff in fork:
if diff.value is not DELETE:
yield diff.value
break
def promote_root(self, new_root_fork: Fork): def promote_root(self, new_root_fork: Fork):
block = self.by_hash[new_root_fork.hash] block = self.by_hash[new_root_fork.hash]
diffs = self.collect_diffs(block) diffs = self.collect_diffs(block)

View File

@@ -10,8 +10,6 @@ from typing import Optional, Union
@dataclass @dataclass
class Config: class Config:
chain: Union[int,str] = 'Arbitrum'
rpc_url: str = 'http://localhost:8545' rpc_url: str = 'http://localhost:8545'
ws_url: str = 'ws://localhost:8545' ws_url: str = 'ws://localhost:8545'
rpc_urls: Optional[dict[str,str]] = field(default_factory=dict) rpc_urls: Optional[dict[str,str]] = field(default_factory=dict)

View File

@@ -1,13 +1,18 @@
from dexorder import dec from dexorder import dec
from dexorder.blockstate import BlockSet, BlockDict from dexorder.base.chain import current_chain
from dexorder.util import defaultdictk, hexstr from dexorder.blockstate import BlockDict
from dexorder.util import json, defaultdictk
# pub=... publishes to a channel for web clients to consume. argument is (key,value) and return must be (event,room,args) # pub=... publishes to a channel for web clients to consume. argument is (key,value) and return must be (event,room,args)
# if pub is True, then event is the current series name, room is the key, and args is [value] # if pub is True, then event is the current series name, room is the key, and args is [value]
# values of DELETE are serialized as nulls # values of DELETE are serialized as nulls
vault_owners: BlockDict[str,str] = BlockDict('v', db=True, redis=True) vault_owners: BlockDict[str, str] = BlockDict('v', db=True, redis=True)
vault_balances: dict[str, BlockDict[str,int]] = defaultdictk(lambda vault: BlockDict(f'vb|{vault}', db=True, redis=True, vault_balances: BlockDict[str, dict[str, int]] = BlockDict(
pub=lambda k,v: (vault_owners[vault], 'vb', (vault,k,v)))) f'vb', db=True, redis=True,
pool_prices: BlockDict[str,dec] = BlockDict('p', db=True, redis=True, value2str=lambda d:f'{d:f}', str2value=dec, value2str=lambda d: json.dumps({k: str(v) for k, v in d.items()}), # ints can be large so we need to stringify them in JSON
pub=lambda k,v: (f'p|{k}', 'p', (k,str(v)))) str2value=lambda s: {k: int(v) for k, v in json.loads(s).items()},
pub=lambda k, v: (f'{current_chain.get().chain_id}|{vault_owners[k]}', 'vb', (k,json.dumps({k2: str(v2) for k2, v2 in v.items()})))
)
pool_prices: BlockDict[str, dec] = BlockDict('p', db=True, redis=True, value2str=lambda d: f'{d:f}', str2value=dec,
pub=lambda k, v: (f'{current_chain.get().chain_id}|{k}', 'p', (k, str(v))))

View File

@@ -18,4 +18,4 @@ class SeriesSet (SeriesBase, Base):
pass pass
class SeriesDict (SeriesBase, Base): class SeriesDict (SeriesBase, Base):
value: Mapped[Json] value: Mapped[str]

View File

@@ -1,12 +0,0 @@
import logging
from sqlalchemy.orm import Mapped, mapped_column
from dexorder.database.column import Address
from dexorder.database.model import Base
log = logging.getLogger(__name__)
class VaultToken (Base):
vault:Mapped[Address] = mapped_column(primary_key=True)
token:Mapped[Address] = mapped_column(primary_key=True)

View File

@@ -3,9 +3,9 @@ from uuid import UUID
from web3.types import EventData from web3.types import EventData
from dexorder import current_pub, db, dec from dexorder import current_pub, db
from dexorder.base.chain import current_chain from dexorder.base.chain import current_chain
from dexorder.base.order import TrancheExecutionRequest, TrancheKey, ExecutionRequest, new_tranche_execution_request, OrderKey from dexorder.base.order import TrancheExecutionRequest, TrancheKey, ExecutionRequest, new_tranche_execution_request
from dexorder.transaction import handle_create_transactions, submit_transaction_request, handle_transaction_receipts, handle_send_transactions from dexorder.transaction import handle_create_transactions, submit_transaction_request, handle_transaction_receipts, handle_send_transactions
from dexorder.blockchain.uniswap import uniswap_price from dexorder.blockchain.uniswap import uniswap_price
from dexorder.contract.dexorder import get_factory_contract, vault_address, VaultContract, get_dexorder_contract from dexorder.contract.dexorder import get_factory_contract, vault_address, VaultContract, get_dexorder_contract
@@ -141,13 +141,32 @@ def handle_transfer(transfer: EventData):
from_address = transfer['args']['from'] from_address = transfer['args']['from']
to_address = transfer['args']['to'] to_address = transfer['args']['to']
amount = int(transfer['args']['value']) amount = int(transfer['args']['value'])
log.debug(f'transfer {to_address}')
if to_address in vault_owners and to_address != from_address: if to_address in vault_owners and to_address != from_address:
log.debug(f'deposit {to_address} {amount}')
vault = to_address
token_address = transfer['address'] token_address = transfer['address']
vault_balances[to_address].add(token_address, amount, 0) def transfer_in(d):
result = dict(d)
result[token_address] = result.get(token_address, 0) + amount
return result
vault_balances.modify(vault, transfer_in, default={})
if from_address in vault_owners and to_address != from_address: if from_address in vault_owners and to_address != from_address:
log.debug(f'withdraw {to_address} {amount}')
vault = from_address
token_address = transfer['address'] token_address = transfer['address']
vault_balances[to_address].add(token_address, -amount, 0) def transfer_out(d):
result = dict(d)
result[token_address] = new_value = result.get(token_address, 0) - amount
if new_value < 0:
log.warning(f'Negative balance in vault {vault}:\n{d} - {token_address} : {amount}')
# value = await ContractProxy(from_address, 'ERC20').balanceOf(from_address)
return result
vault_balances.modify(vault, transfer_out, default={})
# todo check for negative balances.
if to_address not in vault_owners and from_address not in vault_owners:
vaults = vault_owners.keys()
log.debug(f'vaults: {list(vaults)}')
new_pool_prices: dict[str, int] = {} new_pool_prices: dict[str, int] = {}

View File

@@ -106,4 +106,5 @@ async def publish_all(pubs: list[tuple[str,str,list[Any]]]):
r: Pipeline r: Pipeline
io = Emitter(dict(client=r)) io = Emitter(dict(client=r))
for room, event, args in pubs: for room, event, args in pubs:
log.debug(f'publishing {room} {event} {args}')
io.To(room).Emit(event, *args) io.To(room).Emit(event, *args)

View File

@@ -164,8 +164,8 @@ class BlockStateRunner:
if fork.disjoint: if fork.disjoint:
# backfill batches # backfill batches
for callback, event, log_filter in self.events: for callback, event, log_filter in self.events:
if event is None: if log_filter is None:
batches.append(None) batches.append((None, callback, event, None))
else: else:
from_height = self.state.root_block.height + 1 from_height = self.state.root_block.height + 1
end_height = block.height end_height = block.height

View File

@@ -48,7 +48,8 @@ def key2str(key):
return _keystr1(key) if type(key) not in (list, tuple) else '|'.join(_keystr1(v) for v in key) return _keystr1(key) if type(key) not in (list, tuple) else '|'.join(_keystr1(v) for v in key)
def str2key(s,types=None): def str2key(s,types=None):
return tuple(s.split('|')) if types is None else tuple(t(v) for t,v in zip(types,s.split('|'))) key = tuple(s.split('|')) if types is None else tuple(t(v) for t,v in zip(types,s.split('|')))
return key[0] if len(key) == 1 else key
def topic(event_abi): def topic(event_abi):
event_name = f'{event_abi["name"]}(' + ','.join(i['type'] for i in event_abi['inputs']) + ')' event_name = f'{event_abi["name"]}(' + ','.join(i['type'] for i in event_abi['inputs']) + ')'