db state load fix; vault balance tracking with pubsub; db statedict saves strings not json

This commit is contained in:
Tim Olson
2023-10-30 18:45:46 -04:00
parent 064f1a4d82
commit 6af695d345
16 changed files with 103 additions and 47 deletions

View File

@@ -34,7 +34,7 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint('key')
)
op.create_table('seriesdict',
sa.Column('value', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('value', sa.String(), nullable=False),
sa.Column('chain', sa.Integer(), nullable=False),
sa.Column('series', sa.String(), nullable=False),
sa.Column('key', sa.String(), nullable=False),

View File

@@ -57,7 +57,7 @@ class DisjointFork:
def __init__(self, block: Block, root: Block):
self.height = block.height
self.hash = block.hash
self.parent = root.hash
self.parent = root
self.disjoint = True
def __contains__(self, item):
@@ -65,10 +65,10 @@ class DisjointFork:
return False # item is in the future
if item.height < self.parent.height:
return True # item is ancient
return item.hash in (self.hash, self.parent)
return item.hash in (self.hash, self.parent.hash)
def __str__(self):
return f'{self.height}_[{self.hash.hex()}->{self.parent.hex()}]'
return f'{self.height}_[{self.hash.hex()}->{self.parent.hash.hex()}]'
current_fork = ContextVar[Optional[Fork]]('current_fork', default=None)

View File

@@ -23,7 +23,7 @@ class Token (ContractProxy, FixedDecimals):
@staticmethod
def get(name_or_address:str, *, chain_id=None) -> 'Token':
try:
return tokens.get(name_or_address, default=NARG, chain_id=chain_id) # default=NARG will raise
return tokens.get(name_or_address, default=NARG, chain_id=chain_id)
except KeyError:
try:
# noinspection PyTypeChecker

View File

@@ -19,8 +19,7 @@ async def main():
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
log.setLevel(logging.DEBUG)
parse_args()
current_chain.set(Blockchain.get(config.chain))
blockchain.connect()
await blockchain.connect()
redis_state = None
state = None
if memcache:

View File

@@ -1,13 +1,14 @@
from hexbytes import HexBytes
from web3 import WebsocketProviderV2, AsyncWeb3, AsyncHTTPProvider
from ..base.chain import current_chain
from ..contract import get_contract_data
from .. import current_w3
from .. import current_w3, Blockchain
from ..configuration import resolve_rpc_url
from ..configuration.resolve import resolve_ws_url
def connect(rpc_url=None):
async def connect(rpc_url=None):
"""
connects to the rpc_url and configures the context
if you don't want to use ctx.account for this w3, either set ctx.account first or
@@ -15,6 +16,7 @@ def connect(rpc_url=None):
"""
w3 = create_w3(rpc_url)
current_w3.set(w3)
current_chain.set(Blockchain.get(await w3.eth.chain_id))
return w3

View File

@@ -1,3 +1,4 @@
import json
import logging
from enum import Enum
from typing import TypeVar, Generic, Iterable, Union, Any, Iterator, Callable
@@ -26,7 +27,7 @@ class BlockData:
def __init__(self, data_type: DataType, series: Any, *,
series2str=None, series2key=None, # defaults to key2str and str2key
key2str=util_key2str, str2key=util_str2key,
value2str=lambda x:x, str2value=lambda x:x, # serialize/deserialize value to something JSON-able
value2str=json.dumps, str2value=json.loads, # serialize/deserialize value to something JSON-able
**opts):
assert series not in BlockData.registry
BlockData.registry[series] = self
@@ -82,6 +83,18 @@ class BlockData:
fork = current_fork.get()
return state.iteritems(fork, series_key)
@staticmethod
def iter_keys(series_key):
state = current_blockstate.get()
fork = current_fork.get()
return state.iterkeys(fork, series_key)
@staticmethod
def iter_values(series_key):
state = current_blockstate.get()
fork = current_fork.get()
return state.itervalues(fork, series_key)
@staticmethod
def by_opt(key):
yield from (s for s in BlockData.registry.values() if key in s.opts)
@@ -131,6 +144,12 @@ class BlockDict(Generic[K,V], BlockData):
def items(self) -> Iterable[tuple[K,V]]:
return self.iter_items(self.series)
def keys(self) -> Iterable[K]:
return self.iter_keys(self.series)
def values(self) -> Iterable[V]:
return self.iter_values(self.series)
def get(self, item: K, default: V = None) -> V:
return self.getitem(item, default)

View File

@@ -62,10 +62,11 @@ class DbState(SeriesCollection):
db.session.add(SeriesSet(**key))
elif t == DataType.DICT:
found = db.session.get(SeriesDict, key)
value = d.value2str(diff.value)
if found is None:
db.session.add(SeriesDict(**key, value=d.value2str(diff.value)))
db.session.add(SeriesDict(**key, value=value))
else:
found.value = diff.value
found.value = value
else:
raise NotImplementedError
db.kv[f'root_block|{root_block.chain}'] = [root_block.height, root_block.hash]
@@ -87,16 +88,23 @@ class DbState(SeriesCollection):
current_fork.set(None) # root fork
for series, data in self.datas.items():
if data.opts.get('db') != 'lazy':
log.debug(f'loading series {series}')
t = data.type
if t == DataType.SET:
# noinspection PyTypeChecker
var: BlockSet = BlockData.registry[series]
for row in db.session.query(SeriesSet).where(SeriesSet.series == data.series2str(series)):
var.add(data.str2key(row.key))
for row in db.session.query(SeriesSet).where(SeriesSet.chain == chain_id, SeriesSet.series == data.series2str(series)):
key = data.str2key(row.key)
log.debug(f'load {series} {key}')
var.add(key)
elif t == DataType.DICT:
# noinspection PyTypeChecker
var: BlockDict = BlockData.registry[series]
for row in db.session.query(SeriesDict).where(SeriesDict.series == data.series2str(series)):
var[data.str2key(row.key)] = data.str2value(row.value)
for row in db.session.query(SeriesDict).where(SeriesDict.chain == chain_id, SeriesDict.series == data.series2str(series)):
key = data.str2key(row.key)
value = data.str2value(row.value)
log.debug(f'load {series} {key} {value}')
var[key] = value
completed_block.set(root_block)
log.debug(f'loaded db state from block {root_block}')
return state

View File

@@ -129,7 +129,7 @@ class BlockState:
if diff.value is DELETE:
break
else:
if self.root_block not in fork: # todo move this assertion elsewhere so it runs once per task
if fork and self.root_block not in fork: # todo move this assertion elsewhere so it runs once per task
raise ValueError(f'Cannot get value for a non-root fork {hexstr(fork.hash)}')
return diff.value
return DELETE
@@ -152,6 +152,22 @@ class BlockState:
yield k, diff.value
break
def iterkeys(self, fork: Optional[Fork], series):
for k, difflist in self.diffs_by_series.get(series, {}).items():
for diff in reversed(difflist):
if diff.height <= self.root_block.height or fork is not None and diff in fork:
if diff.value is not DELETE:
yield k
break
def itervalues(self, fork: Optional[Fork], series):
for k, difflist in self.diffs_by_series.get(series, {}).items():
for diff in reversed(difflist):
if diff.height <= self.root_block.height or fork is not None and diff in fork:
if diff.value is not DELETE:
yield diff.value
break
def promote_root(self, new_root_fork: Fork):
block = self.by_hash[new_root_fork.hash]
diffs = self.collect_diffs(block)

View File

@@ -10,8 +10,6 @@ from typing import Optional, Union
@dataclass
class Config:
chain: Union[int,str] = 'Arbitrum'
rpc_url: str = 'http://localhost:8545'
ws_url: str = 'ws://localhost:8545'
rpc_urls: Optional[dict[str,str]] = field(default_factory=dict)

View File

@@ -1,13 +1,18 @@
from dexorder import dec
from dexorder.blockstate import BlockSet, BlockDict
from dexorder.util import defaultdictk, hexstr
from dexorder.base.chain import current_chain
from dexorder.blockstate import BlockDict
from dexorder.util import json, defaultdictk
# pub=... publishes to a channel for web clients to consume. argument is (key,value) and return must be (event,room,args)
# if pub is True, then event is the current series name, room is the key, and args is [value]
# values of DELETE are serialized as nulls
vault_owners: BlockDict[str,str] = BlockDict('v', db=True, redis=True)
vault_balances: dict[str, BlockDict[str,int]] = defaultdictk(lambda vault: BlockDict(f'vb|{vault}', db=True, redis=True,
pub=lambda k,v: (vault_owners[vault], 'vb', (vault,k,v))))
pool_prices: BlockDict[str,dec] = BlockDict('p', db=True, redis=True, value2str=lambda d:f'{d:f}', str2value=dec,
pub=lambda k,v: (f'p|{k}', 'p', (k,str(v))))
vault_owners: BlockDict[str, str] = BlockDict('v', db=True, redis=True)
vault_balances: BlockDict[str, dict[str, int]] = BlockDict(
f'vb', db=True, redis=True,
value2str=lambda d: json.dumps({k: str(v) for k, v in d.items()}), # ints can be large so we need to stringify them in JSON
str2value=lambda s: {k: int(v) for k, v in json.loads(s).items()},
pub=lambda k, v: (f'{current_chain.get().chain_id}|{vault_owners[k]}', 'vb', (k,json.dumps({k2: str(v2) for k2, v2 in v.items()})))
)
pool_prices: BlockDict[str, dec] = BlockDict('p', db=True, redis=True, value2str=lambda d: f'{d:f}', str2value=dec,
pub=lambda k, v: (f'{current_chain.get().chain_id}|{k}', 'p', (k, str(v))))

View File

@@ -18,4 +18,4 @@ class SeriesSet (SeriesBase, Base):
pass
class SeriesDict (SeriesBase, Base):
value: Mapped[Json]
value: Mapped[str]

View File

@@ -1,12 +0,0 @@
import logging
from sqlalchemy.orm import Mapped, mapped_column
from dexorder.database.column import Address
from dexorder.database.model import Base
log = logging.getLogger(__name__)
class VaultToken (Base):
vault:Mapped[Address] = mapped_column(primary_key=True)
token:Mapped[Address] = mapped_column(primary_key=True)

View File

@@ -3,9 +3,9 @@ from uuid import UUID
from web3.types import EventData
from dexorder import current_pub, db, dec
from dexorder import current_pub, db
from dexorder.base.chain import current_chain
from dexorder.base.order import TrancheExecutionRequest, TrancheKey, ExecutionRequest, new_tranche_execution_request, OrderKey
from dexorder.base.order import TrancheExecutionRequest, TrancheKey, ExecutionRequest, new_tranche_execution_request
from dexorder.transaction import handle_create_transactions, submit_transaction_request, handle_transaction_receipts, handle_send_transactions
from dexorder.blockchain.uniswap import uniswap_price
from dexorder.contract.dexorder import get_factory_contract, vault_address, VaultContract, get_dexorder_contract
@@ -141,13 +141,32 @@ def handle_transfer(transfer: EventData):
from_address = transfer['args']['from']
to_address = transfer['args']['to']
amount = int(transfer['args']['value'])
log.debug(f'transfer {to_address}')
if to_address in vault_owners and to_address != from_address:
log.debug(f'deposit {to_address} {amount}')
vault = to_address
token_address = transfer['address']
vault_balances[to_address].add(token_address, amount, 0)
def transfer_in(d):
result = dict(d)
result[token_address] = result.get(token_address, 0) + amount
return result
vault_balances.modify(vault, transfer_in, default={})
if from_address in vault_owners and to_address != from_address:
log.debug(f'withdraw {to_address} {amount}')
vault = from_address
token_address = transfer['address']
vault_balances[to_address].add(token_address, -amount, 0)
def transfer_out(d):
result = dict(d)
result[token_address] = new_value = result.get(token_address, 0) - amount
if new_value < 0:
log.warning(f'Negative balance in vault {vault}:\n{d} - {token_address} : {amount}')
# value = await ContractProxy(from_address, 'ERC20').balanceOf(from_address)
return result
vault_balances.modify(vault, transfer_out, default={})
# todo check for negative balances.
if to_address not in vault_owners and from_address not in vault_owners:
vaults = vault_owners.keys()
log.debug(f'vaults: {list(vaults)}')
new_pool_prices: dict[str, int] = {}

View File

@@ -106,4 +106,5 @@ async def publish_all(pubs: list[tuple[str,str,list[Any]]]):
r: Pipeline
io = Emitter(dict(client=r))
for room, event, args in pubs:
log.debug(f'publishing {room} {event} {args}')
io.To(room).Emit(event, *args)

View File

@@ -164,8 +164,8 @@ class BlockStateRunner:
if fork.disjoint:
# backfill batches
for callback, event, log_filter in self.events:
if event is None:
batches.append(None)
if log_filter is None:
batches.append((None, callback, event, None))
else:
from_height = self.state.root_block.height + 1
end_height = block.height

View File

@@ -48,7 +48,8 @@ def key2str(key):
return _keystr1(key) if type(key) not in (list, tuple) else '|'.join(_keystr1(v) for v in key)
def str2key(s,types=None):
return tuple(s.split('|')) if types is None else tuple(t(v) for t,v in zip(types,s.split('|')))
key = tuple(s.split('|')) if types is None else tuple(t(v) for t,v in zip(types,s.split('|')))
return key[0] if len(key) == 1 else key
def topic(event_abi):
event_name = f'{event_abi["name"]}(' + ','.join(i['type'] for i in event_abi['inputs']) + ')'