db state load fix; vault balance tracking with pubsub; db statedict saves strings not json
This commit is contained in:
@@ -34,7 +34,7 @@ def upgrade() -> None:
|
||||
sa.PrimaryKeyConstraint('key')
|
||||
)
|
||||
op.create_table('seriesdict',
|
||||
sa.Column('value', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('value', sa.String(), nullable=False),
|
||||
sa.Column('chain', sa.Integer(), nullable=False),
|
||||
sa.Column('series', sa.String(), nullable=False),
|
||||
sa.Column('key', sa.String(), nullable=False),
|
||||
|
||||
@@ -57,7 +57,7 @@ class DisjointFork:
|
||||
def __init__(self, block: Block, root: Block):
|
||||
self.height = block.height
|
||||
self.hash = block.hash
|
||||
self.parent = root.hash
|
||||
self.parent = root
|
||||
self.disjoint = True
|
||||
|
||||
def __contains__(self, item):
|
||||
@@ -65,10 +65,10 @@ class DisjointFork:
|
||||
return False # item is in the future
|
||||
if item.height < self.parent.height:
|
||||
return True # item is ancient
|
||||
return item.hash in (self.hash, self.parent)
|
||||
return item.hash in (self.hash, self.parent.hash)
|
||||
|
||||
def __str__(self):
|
||||
return f'{self.height}_[{self.hash.hex()}->{self.parent.hex()}]'
|
||||
return f'{self.height}_[{self.hash.hex()}->{self.parent.hash.hex()}]'
|
||||
|
||||
|
||||
current_fork = ContextVar[Optional[Fork]]('current_fork', default=None)
|
||||
|
||||
@@ -23,7 +23,7 @@ class Token (ContractProxy, FixedDecimals):
|
||||
@staticmethod
|
||||
def get(name_or_address:str, *, chain_id=None) -> 'Token':
|
||||
try:
|
||||
return tokens.get(name_or_address, default=NARG, chain_id=chain_id) # default=NARG will raise
|
||||
return tokens.get(name_or_address, default=NARG, chain_id=chain_id)
|
||||
except KeyError:
|
||||
try:
|
||||
# noinspection PyTypeChecker
|
||||
|
||||
@@ -19,8 +19,7 @@ async def main():
|
||||
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
|
||||
log.setLevel(logging.DEBUG)
|
||||
parse_args()
|
||||
current_chain.set(Blockchain.get(config.chain))
|
||||
blockchain.connect()
|
||||
await blockchain.connect()
|
||||
redis_state = None
|
||||
state = None
|
||||
if memcache:
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from hexbytes import HexBytes
|
||||
from web3 import WebsocketProviderV2, AsyncWeb3, AsyncHTTPProvider
|
||||
|
||||
from ..base.chain import current_chain
|
||||
from ..contract import get_contract_data
|
||||
from .. import current_w3
|
||||
from .. import current_w3, Blockchain
|
||||
from ..configuration import resolve_rpc_url
|
||||
from ..configuration.resolve import resolve_ws_url
|
||||
|
||||
|
||||
def connect(rpc_url=None):
|
||||
async def connect(rpc_url=None):
|
||||
"""
|
||||
connects to the rpc_url and configures the context
|
||||
if you don't want to use ctx.account for this w3, either set ctx.account first or
|
||||
@@ -15,6 +16,7 @@ def connect(rpc_url=None):
|
||||
"""
|
||||
w3 = create_w3(rpc_url)
|
||||
current_w3.set(w3)
|
||||
current_chain.set(Blockchain.get(await w3.eth.chain_id))
|
||||
return w3
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import TypeVar, Generic, Iterable, Union, Any, Iterator, Callable
|
||||
@@ -26,7 +27,7 @@ class BlockData:
|
||||
def __init__(self, data_type: DataType, series: Any, *,
|
||||
series2str=None, series2key=None, # defaults to key2str and str2key
|
||||
key2str=util_key2str, str2key=util_str2key,
|
||||
value2str=lambda x:x, str2value=lambda x:x, # serialize/deserialize value to something JSON-able
|
||||
value2str=json.dumps, str2value=json.loads, # serialize/deserialize value to something JSON-able
|
||||
**opts):
|
||||
assert series not in BlockData.registry
|
||||
BlockData.registry[series] = self
|
||||
@@ -82,6 +83,18 @@ class BlockData:
|
||||
fork = current_fork.get()
|
||||
return state.iteritems(fork, series_key)
|
||||
|
||||
@staticmethod
|
||||
def iter_keys(series_key):
|
||||
state = current_blockstate.get()
|
||||
fork = current_fork.get()
|
||||
return state.iterkeys(fork, series_key)
|
||||
|
||||
@staticmethod
|
||||
def iter_values(series_key):
|
||||
state = current_blockstate.get()
|
||||
fork = current_fork.get()
|
||||
return state.itervalues(fork, series_key)
|
||||
|
||||
@staticmethod
|
||||
def by_opt(key):
|
||||
yield from (s for s in BlockData.registry.values() if key in s.opts)
|
||||
@@ -131,6 +144,12 @@ class BlockDict(Generic[K,V], BlockData):
|
||||
def items(self) -> Iterable[tuple[K,V]]:
|
||||
return self.iter_items(self.series)
|
||||
|
||||
def keys(self) -> Iterable[K]:
|
||||
return self.iter_keys(self.series)
|
||||
|
||||
def values(self) -> Iterable[V]:
|
||||
return self.iter_values(self.series)
|
||||
|
||||
def get(self, item: K, default: V = None) -> V:
|
||||
return self.getitem(item, default)
|
||||
|
||||
|
||||
@@ -62,10 +62,11 @@ class DbState(SeriesCollection):
|
||||
db.session.add(SeriesSet(**key))
|
||||
elif t == DataType.DICT:
|
||||
found = db.session.get(SeriesDict, key)
|
||||
value = d.value2str(diff.value)
|
||||
if found is None:
|
||||
db.session.add(SeriesDict(**key, value=d.value2str(diff.value)))
|
||||
db.session.add(SeriesDict(**key, value=value))
|
||||
else:
|
||||
found.value = diff.value
|
||||
found.value = value
|
||||
else:
|
||||
raise NotImplementedError
|
||||
db.kv[f'root_block|{root_block.chain}'] = [root_block.height, root_block.hash]
|
||||
@@ -87,16 +88,23 @@ class DbState(SeriesCollection):
|
||||
current_fork.set(None) # root fork
|
||||
for series, data in self.datas.items():
|
||||
if data.opts.get('db') != 'lazy':
|
||||
log.debug(f'loading series {series}')
|
||||
t = data.type
|
||||
if t == DataType.SET:
|
||||
# noinspection PyTypeChecker
|
||||
var: BlockSet = BlockData.registry[series]
|
||||
for row in db.session.query(SeriesSet).where(SeriesSet.series == data.series2str(series)):
|
||||
var.add(data.str2key(row.key))
|
||||
for row in db.session.query(SeriesSet).where(SeriesSet.chain == chain_id, SeriesSet.series == data.series2str(series)):
|
||||
key = data.str2key(row.key)
|
||||
log.debug(f'load {series} {key}')
|
||||
var.add(key)
|
||||
elif t == DataType.DICT:
|
||||
# noinspection PyTypeChecker
|
||||
var: BlockDict = BlockData.registry[series]
|
||||
for row in db.session.query(SeriesDict).where(SeriesDict.series == data.series2str(series)):
|
||||
var[data.str2key(row.key)] = data.str2value(row.value)
|
||||
for row in db.session.query(SeriesDict).where(SeriesDict.chain == chain_id, SeriesDict.series == data.series2str(series)):
|
||||
key = data.str2key(row.key)
|
||||
value = data.str2value(row.value)
|
||||
log.debug(f'load {series} {key} {value}')
|
||||
var[key] = value
|
||||
completed_block.set(root_block)
|
||||
log.debug(f'loaded db state from block {root_block}')
|
||||
return state
|
||||
|
||||
@@ -129,7 +129,7 @@ class BlockState:
|
||||
if diff.value is DELETE:
|
||||
break
|
||||
else:
|
||||
if self.root_block not in fork: # todo move this assertion elsewhere so it runs once per task
|
||||
if fork and self.root_block not in fork: # todo move this assertion elsewhere so it runs once per task
|
||||
raise ValueError(f'Cannot get value for a non-root fork {hexstr(fork.hash)}')
|
||||
return diff.value
|
||||
return DELETE
|
||||
@@ -152,6 +152,22 @@ class BlockState:
|
||||
yield k, diff.value
|
||||
break
|
||||
|
||||
def iterkeys(self, fork: Optional[Fork], series):
|
||||
for k, difflist in self.diffs_by_series.get(series, {}).items():
|
||||
for diff in reversed(difflist):
|
||||
if diff.height <= self.root_block.height or fork is not None and diff in fork:
|
||||
if diff.value is not DELETE:
|
||||
yield k
|
||||
break
|
||||
|
||||
def itervalues(self, fork: Optional[Fork], series):
|
||||
for k, difflist in self.diffs_by_series.get(series, {}).items():
|
||||
for diff in reversed(difflist):
|
||||
if diff.height <= self.root_block.height or fork is not None and diff in fork:
|
||||
if diff.value is not DELETE:
|
||||
yield diff.value
|
||||
break
|
||||
|
||||
def promote_root(self, new_root_fork: Fork):
|
||||
block = self.by_hash[new_root_fork.hash]
|
||||
diffs = self.collect_diffs(block)
|
||||
|
||||
@@ -10,8 +10,6 @@ from typing import Optional, Union
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
chain: Union[int,str] = 'Arbitrum'
|
||||
|
||||
rpc_url: str = 'http://localhost:8545'
|
||||
ws_url: str = 'ws://localhost:8545'
|
||||
rpc_urls: Optional[dict[str,str]] = field(default_factory=dict)
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
from dexorder import dec
|
||||
from dexorder.blockstate import BlockSet, BlockDict
|
||||
from dexorder.util import defaultdictk, hexstr
|
||||
from dexorder.base.chain import current_chain
|
||||
from dexorder.blockstate import BlockDict
|
||||
from dexorder.util import json, defaultdictk
|
||||
|
||||
# pub=... publishes to a channel for web clients to consume. argument is (key,value) and return must be (event,room,args)
|
||||
# if pub is True, then event is the current series name, room is the key, and args is [value]
|
||||
# values of DELETE are serialized as nulls
|
||||
|
||||
vault_owners: BlockDict[str,str] = BlockDict('v', db=True, redis=True)
|
||||
vault_balances: dict[str, BlockDict[str,int]] = defaultdictk(lambda vault: BlockDict(f'vb|{vault}', db=True, redis=True,
|
||||
pub=lambda k,v: (vault_owners[vault], 'vb', (vault,k,v))))
|
||||
pool_prices: BlockDict[str,dec] = BlockDict('p', db=True, redis=True, value2str=lambda d:f'{d:f}', str2value=dec,
|
||||
pub=lambda k,v: (f'p|{k}', 'p', (k,str(v))))
|
||||
vault_owners: BlockDict[str, str] = BlockDict('v', db=True, redis=True)
|
||||
vault_balances: BlockDict[str, dict[str, int]] = BlockDict(
|
||||
f'vb', db=True, redis=True,
|
||||
value2str=lambda d: json.dumps({k: str(v) for k, v in d.items()}), # ints can be large so we need to stringify them in JSON
|
||||
str2value=lambda s: {k: int(v) for k, v in json.loads(s).items()},
|
||||
pub=lambda k, v: (f'{current_chain.get().chain_id}|{vault_owners[k]}', 'vb', (k,json.dumps({k2: str(v2) for k2, v2 in v.items()})))
|
||||
)
|
||||
pool_prices: BlockDict[str, dec] = BlockDict('p', db=True, redis=True, value2str=lambda d: f'{d:f}', str2value=dec,
|
||||
pub=lambda k, v: (f'{current_chain.get().chain_id}|{k}', 'p', (k, str(v))))
|
||||
|
||||
@@ -18,4 +18,4 @@ class SeriesSet (SeriesBase, Base):
|
||||
pass
|
||||
|
||||
class SeriesDict (SeriesBase, Base):
|
||||
value: Mapped[Json]
|
||||
value: Mapped[str]
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
import logging
|
||||
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from dexorder.database.column import Address
|
||||
from dexorder.database.model import Base
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
class VaultToken (Base):
|
||||
vault:Mapped[Address] = mapped_column(primary_key=True)
|
||||
token:Mapped[Address] = mapped_column(primary_key=True)
|
||||
@@ -3,9 +3,9 @@ from uuid import UUID
|
||||
|
||||
from web3.types import EventData
|
||||
|
||||
from dexorder import current_pub, db, dec
|
||||
from dexorder import current_pub, db
|
||||
from dexorder.base.chain import current_chain
|
||||
from dexorder.base.order import TrancheExecutionRequest, TrancheKey, ExecutionRequest, new_tranche_execution_request, OrderKey
|
||||
from dexorder.base.order import TrancheExecutionRequest, TrancheKey, ExecutionRequest, new_tranche_execution_request
|
||||
from dexorder.transaction import handle_create_transactions, submit_transaction_request, handle_transaction_receipts, handle_send_transactions
|
||||
from dexorder.blockchain.uniswap import uniswap_price
|
||||
from dexorder.contract.dexorder import get_factory_contract, vault_address, VaultContract, get_dexorder_contract
|
||||
@@ -141,13 +141,32 @@ def handle_transfer(transfer: EventData):
|
||||
from_address = transfer['args']['from']
|
||||
to_address = transfer['args']['to']
|
||||
amount = int(transfer['args']['value'])
|
||||
log.debug(f'transfer {to_address}')
|
||||
if to_address in vault_owners and to_address != from_address:
|
||||
log.debug(f'deposit {to_address} {amount}')
|
||||
vault = to_address
|
||||
token_address = transfer['address']
|
||||
vault_balances[to_address].add(token_address, amount, 0)
|
||||
def transfer_in(d):
|
||||
result = dict(d)
|
||||
result[token_address] = result.get(token_address, 0) + amount
|
||||
return result
|
||||
vault_balances.modify(vault, transfer_in, default={})
|
||||
if from_address in vault_owners and to_address != from_address:
|
||||
log.debug(f'withdraw {to_address} {amount}')
|
||||
vault = from_address
|
||||
token_address = transfer['address']
|
||||
vault_balances[to_address].add(token_address, -amount, 0)
|
||||
def transfer_out(d):
|
||||
result = dict(d)
|
||||
result[token_address] = new_value = result.get(token_address, 0) - amount
|
||||
if new_value < 0:
|
||||
log.warning(f'Negative balance in vault {vault}:\n{d} - {token_address} : {amount}')
|
||||
# value = await ContractProxy(from_address, 'ERC20').balanceOf(from_address)
|
||||
return result
|
||||
vault_balances.modify(vault, transfer_out, default={})
|
||||
# todo check for negative balances.
|
||||
|
||||
if to_address not in vault_owners and from_address not in vault_owners:
|
||||
vaults = vault_owners.keys()
|
||||
log.debug(f'vaults: {list(vaults)}')
|
||||
|
||||
|
||||
new_pool_prices: dict[str, int] = {}
|
||||
|
||||
@@ -106,4 +106,5 @@ async def publish_all(pubs: list[tuple[str,str,list[Any]]]):
|
||||
r: Pipeline
|
||||
io = Emitter(dict(client=r))
|
||||
for room, event, args in pubs:
|
||||
log.debug(f'publishing {room} {event} {args}')
|
||||
io.To(room).Emit(event, *args)
|
||||
|
||||
@@ -164,8 +164,8 @@ class BlockStateRunner:
|
||||
if fork.disjoint:
|
||||
# backfill batches
|
||||
for callback, event, log_filter in self.events:
|
||||
if event is None:
|
||||
batches.append(None)
|
||||
if log_filter is None:
|
||||
batches.append((None, callback, event, None))
|
||||
else:
|
||||
from_height = self.state.root_block.height + 1
|
||||
end_height = block.height
|
||||
|
||||
@@ -48,7 +48,8 @@ def key2str(key):
|
||||
return _keystr1(key) if type(key) not in (list, tuple) else '|'.join(_keystr1(v) for v in key)
|
||||
|
||||
def str2key(s,types=None):
|
||||
return tuple(s.split('|')) if types is None else tuple(t(v) for t,v in zip(types,s.split('|')))
|
||||
key = tuple(s.split('|')) if types is None else tuple(t(v) for t,v in zip(types,s.split('|')))
|
||||
return key[0] if len(key) == 1 else key
|
||||
|
||||
def topic(event_abi):
|
||||
event_name = f'{event_abi["name"]}(' + ','.join(i['type'] for i in event_abi['inputs']) + ')'
|
||||
|
||||
Reference in New Issue
Block a user