order trigger recovery; state load bugfixes

Tim Olson
2023-11-10 15:43:33 -04:00
parent 6c9e28b00d
commit 82be8d94e8
9 changed files with 99 additions and 48 deletions

View File

@@ -7,6 +7,7 @@ from dexorder.bin.executable import execute
 from dexorder.blockstate.blockdata import BlockData
 from dexorder.blockstate.db_state import DbState
 from dexorder.configuration import parse_args
+from dexorder.event_handler import init_order_triggers
 from dexorder.memcache.memcache_state import RedisState, publish_all
 from dexorder.memcache import memcache
 from dexorder.runner import BlockStateRunner
@@ -30,12 +31,14 @@ async def main():
         with db.session:
             state = db_state.load()
             if state is not None:
                 if redis_state:
                     await redis_state.init(state)
                 log.info(f'loaded state from db for root block {state.root_block}')
     runner = BlockStateRunner(state, publish_all=publish_all if redis_state else None)
     if db:
+        runner.on_state_init.append(init_order_triggers)
         # noinspection PyUnboundLocalVariable
         runner.on_promotion.append(db_state.save)
     if redis_state:

View File

@@ -36,7 +36,7 @@ class DbState(SeriesCollection):
             return None
         fork = Fork([hexbytes(blockhash)], height=height)
         value = db.session.get(Entity, (chain_id, series, key))
-        return fork, value
+        return fork, var.str2value(value.value)

     def save(self, root_block: Block, diffs: Iterable[Union[DiffItem,DiffEntryItem]] ):
         chain_id = current_chain.get().chain_id
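
The load-path fix matters because rows are written through each series' value2str serializer, so the raw Entity.value string has to go back through the matching str2value before it re-enters the block state. A minimal sketch of that round trip; SeriesVar and price_var are illustrative stand-ins, only the value2str/str2value pairing is taken from the diff.

    # Minimal sketch of the save/load round trip this hunk fixes.
    # SeriesVar is a stand-in, not the project's class.
    from decimal import Decimal

    class SeriesVar:
        def __init__(self, value2str, str2value):
            self.value2str = value2str    # applied when the row is written
            self.str2value = str2value    # must be applied when the row is read back

    price_var = SeriesVar(value2str=lambda d: f'{d:f}', str2value=Decimal)

    stored = price_var.value2str(Decimal('1234.5'))   # what Entity.value holds in the db
    loaded = price_var.str2value(stored)              # what load() should now return
    assert loaded == Decimal('1234.5')                # previously the raw string came back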

View File

@@ -4,8 +4,11 @@ from dexorder import dec
 from dexorder.base.chain import current_chain
 from dexorder.blockstate import BlockDict
 from dexorder.blockstate.blockdata import K, V
+from dexorder.uniswap import UniswapV3Pool
 from dexorder.util import json

+log = logging.getLogger(__name__)
+
 # pub=... publishes to a channel for web clients to consume. argument is (key,value) and return must be (event,room,args)
 # if pub is True, then event is the current series name, room is the key, and args is [value]
 # values of DELETE are serialized as nulls
@@ -40,3 +43,9 @@ def pub_pool_price(k,v):
 new_pool_prices: dict[str, dec] = {} # tracks which prices were set during the current block. cleared every block.
 pool_prices: PoolPrices = PoolPrices('p', db=True, redis=True, pub=pub_pool_price, value2str=lambda d: f'{d:f}', str2value=dec)
+
+async def ensure_pool_price(pool_addr):
+    if pool_addr not in pool_prices:
+        log.debug(f'querying price for pool {pool_addr}')
+        pool_prices[pool_addr] = await UniswapV3Pool(pool_addr).price()
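
Moving ensure_pool_price into dexorder.data lets the event handler and the trigger module share it. It is idempotent: the pool is only queried on-chain the first time, after which the cached entry in pool_prices is reused. A hedged usage sketch (the caller is hypothetical):

    # Hypothetical caller; safe to invoke repeatedly for the same pool.
    from dexorder.data import ensure_pool_price, pool_prices

    async def price_of(pool_addr: str):
        await ensure_pool_price(pool_addr)   # no-op if the price is already cached
        return pool_prices[pool_addr]        # dec value tracked in block state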

View File

@@ -1,3 +1,4 @@
+import asyncio
 import functools
 import logging
 from uuid import UUID
@@ -8,16 +9,16 @@ from dexorder import current_pub, db, dec
 from dexorder.base.chain import current_chain
 from dexorder.base.order import TrancheExecutionRequest, TrancheKey, ExecutionRequest, new_tranche_execution_request, OrderKey
 from dexorder.transaction import create_transactions, submit_transaction_request, handle_transaction_receipts, send_transactions
-from dexorder.uniswap import UniswapV3Pool, uniswap_price
+from dexorder.uniswap import uniswap_price
 from dexorder.contract.dexorder import get_factory_contract, vault_address, VaultContract, get_dexorder_contract
 from dexorder.contract import get_contract_event, ERC20
 from dexorder.data import pool_prices, vault_owners, vault_balances, new_pool_prices
 from dexorder.database.model.block import current_block
 from dexorder.database.model.transaction import TransactionJob
-from dexorder.order.orderlib import SwapOrderState, SwapOrderStatus
+from dexorder.order.orderlib import SwapOrderStatus
 from dexorder.order.orderstate import Order
-from dexorder.order.triggers import OrderTriggers, close_order_and_disable_triggers, price_triggers, time_triggers, \
-    unconstrained_price_triggers, execution_requests, inflight_execution_requests, TrancheStatus, active_tranches, new_price_triggers
+from dexorder.order.triggers import OrderTriggers, price_triggers, time_triggers, \
+    unconstrained_price_triggers, execution_requests, inflight_execution_requests, TrancheStatus, active_tranches, new_price_triggers, activate_order
 from dexorder.util.async_util import maywait

 log = logging.getLogger(__name__)
@@ -25,14 +26,6 @@ log = logging.getLogger(__name__)
 LOG_ALL_EVENTS = True # for debug

-async def ensure_pool_price(pool_addr):
-    if pool_addr not in pool_prices:
-        log.debug(f'querying price for pool {pool_addr}')
-        pool_prices[pool_addr] = await UniswapV3Pool(pool_addr).price()
-
-def dump_log(eventlog):
-    log.debug(f'\t{eventlog}')

 def setup_logevent_triggers(runner):
     runner.events.clear()
@@ -80,6 +73,18 @@ def setup_logevent_triggers(runner):
     runner.add_event_trigger(send_transactions)

+def dump_log(eventlog):
+    log.debug(f'\t{eventlog}')
+
+async def init_order_triggers():
+    log.debug('activating orders')
+    # this is a state init callback, called only once after the state has been loaded from the db or created fresh
+    orders = [Order.of(key) for key in Order.open_orders]
+    futures = [activate_order(order) for order in orders]
+    await asyncio.gather(*futures, return_exceptions=True)
+    log.debug(f'activated {len(futures)} orders')

 def init():
     new_pool_prices.clear()
     new_price_triggers.clear()
@@ -107,12 +112,9 @@ async def handle_order_placed(event: EventData):
     log.debug(f'raw order status {obj}')
     order_status = SwapOrderStatus.load(obj)
     order = Order.create(vault.address, index, order_status)
-    await ensure_pool_price(order.pool_address)
-    triggers = OrderTriggers(order)
-    log.debug(f'created order {order_status}')
-    if triggers.closed:
-        log.warning(f'order {order.key} was immediately closed')
-        close_order_and_disable_triggers(order, SwapOrderState.Filled if order.remaining <= 0 else SwapOrderState.Expired)
+    await activate_order(order)
+    log.debug(f'new order {order_status}')

 def handle_swap_filled(event: EventData):
     # event DexorderSwapFilled (uint64 orderIndex, uint8 trancheIndex, uint256 amountIn, uint256 amountOut);
@@ -132,7 +134,7 @@ def handle_swap_filled(event: EventData):
         triggers = OrderTriggers.instances[order.key]
         triggers.fill(tranche_index, amount_in, amount_out)
     except KeyError:
-        log.warning(f'No order triggers for fill of {TrancheKey(*order.key,tranche_index)}')
+        log.warning(f'No order triggers for fill of {TrancheKey(order.key.vault, order.key.order_index, tranche_index)}')

 async def handle_order_completed(event: EventData):
     # event DexorderCompleted (uint64 orderIndex); // todo remove?
@@ -322,6 +324,10 @@ def finish_execution_request(req: TrancheExecutionRequest, error: str):
         # todo dont keep trying
     else:
         log.error(f'Unhandled execution error for transaction request {req} ERROR: "{error}"')
-    er = execution_requests[tk]
-    if er.height < current_block.get().height:
-        del execution_requests[tk]
+    try:
+        er = execution_requests[tk]
+    except KeyError:
+        pass
+    else:
+        if er.height < current_block.get().height:
+            del execution_requests[tk]
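
The finish_execution_request change guards against the tranche's request having already been removed from execution_requests before the error path runs. try/except/else keeps the height check outside the except clause, so a KeyError raised by the cleanup itself is not silently swallowed. The same pattern in isolation:

    # Stand-alone illustration of the try/except/else lookup pattern used above.
    from dataclasses import dataclass

    @dataclass
    class Req:
        height: int  # block height when the request was created

    def expire_stale(requests: dict, key, current_height: int) -> None:
        try:
            req = requests[key]
        except KeyError:
            pass                       # already gone; nothing to clean up
        else:
            if req.height < current_height:
                del requests[key]      # stale request from an earlier block

    reqs = {'tk1': Req(height=5)}
    expire_stale(reqs, 'tk1', current_height=7)   # removed
    expire_stale(reqs, 'tk2', current_height=7)   # missing key is tolerated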

View File

@@ -3,11 +3,12 @@ import logging
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional
+from typing import Optional, Union

 from dexorder import dec
 from dexorder.uniswap import uniswapV3_pool_address, uniswap_price
 from dexorder.contract import abi_decoder, abi_encoder
+from dexorder.util import hexbytes

 log = logging.getLogger(__name__)
@@ -158,8 +159,8 @@ class TimeConstraint (Constraint):
     TYPES = ['uint8', 'uint32', 'uint8', 'uint32']

     @staticmethod
-    def load(obj: bytes):
-        earliest_mode, earliest_time, latest_mode, latest_time = abi_decoder.decode(TimeConstraint.TYPES, obj)
+    def load(obj: Union[bytes|str]):
+        earliest_mode, earliest_time, latest_mode, latest_time = abi_decoder.decode(TimeConstraint.TYPES, hexbytes(obj))
         return TimeConstraint(ConstraintMode.Time, Time(TimeMode(earliest_mode),earliest_time), Time(TimeMode(latest_mode),latest_time))

     def dump(self):
@@ -178,7 +179,7 @@ class LineConstraint (Constraint):
     @staticmethod
     def load(obj):
-        return LineConstraint(ConstraintMode.Line, *abi_decoder.decode(LineConstraint.TYPES, obj))
+        return LineConstraint(ConstraintMode.Line, *abi_decoder.decode(LineConstraint.TYPES, hexbytes(obj)))

     def dump(self):
         return self._dump(LineConstraint.TYPES, (self.isAbove, self.isRatio, self.time, self.valueSqrtX96, self.slopeSqrtX96))

View File

@@ -20,8 +20,8 @@ class Filled:
     filled_out: int

     @staticmethod
-    def load(string):
-        return Filled(*map(int,string[1:-1].split(',')))
+    def load(obj: tuple[str,str]):
+        return Filled(*map(int,obj))

     def dump(self):
         return str(self.filled_in), str(self.filled_out)
@@ -33,9 +33,9 @@ class OrderFilled:
     tranche_filled: list[Filled]

     @staticmethod
-    def load(string):
-        f, tfs = json.loads(string)
-        return OrderFilled(Filled(*f), [Filled(*tf) for tf in tfs])
+    def load(obj):
+        f, tfs = obj
+        return OrderFilled(Filled.load(f), [Filled.load(tf) for tf in tfs])

     def dump(self):
         return [self.filled.dump(), [tf.dump() for tf in self.tranche_filled]]
@@ -47,6 +47,9 @@ class Order:
     """
     represents the canonical internal representation of an order. some members are immutable like the order spec, and some are
    represented in various blockstate structures. this class hides that complexity to provide a clean interface to orders.
+    Orders are therefore just references to state data, and may be accessed using Order.of(key). If there is a new order placed
+    in the system, invoke Order.create(...) instead.
     """

     instances: dict[OrderKey, 'Order'] = {}
@@ -62,7 +65,10 @@ class Order:
     @staticmethod
     def of(a, b=None) -> 'Order':
         key = a if b is None else OrderKey(a, b)
-        return Order.instances[key]
+        try:
+            return Order.instances[key]
+        except KeyError:
+            return Order(key)

     @staticmethod
@@ -92,7 +98,7 @@ class Order:
         self.status: SwapOrderStatus = Order.order_statuses[key].copy()
         self.pool_address: str = self.status.order.pool_address
         self.tranche_keys = [TrancheKey(key.vault, key.order_index, i) for i in range(len(self.status.trancheFilledIn))]
-        # various flattenings
+        # flattenings of various static data
         self.order = self.status.order
         self.amount = self.status.order.amount
         self.amount_is_input = self.status.order.amountIsInput
@@ -207,7 +213,7 @@ class Order:
     # the filled amount fields for active orders are maintained in the order_remainings and tranche_remainings series.
     order_statuses: BlockDict[OrderKey, SwapOrderStatus] = BlockDict(
         'o', db='lazy', redis=True, pub=pub_order_status,
-        str2key=OrderKey.str2key, value2str=lambda v: json.dumps(v.dump()), str2value=SwapOrderStatus.load,
+        str2key=OrderKey.str2key, value2str=lambda v: json.dumps(v.dump()), str2value=lambda s:SwapOrderStatus.load(json.loads(s)),
     )

     # open orders = the set of unfilled, not-canceled orders
@@ -222,7 +228,7 @@ class Order:
     # is removed from open_orders, the order_status directly contains the final fill values.
     order_filled: BlockDict[OrderKey, OrderFilled] = BlockDict(
         'of', db=True, redis=True, pub=pub_order_fills,
-        str2key=OrderKey.str2key, value2str=lambda v: json.dumps(v.dump()), str2value=OrderFilled.load)
+        str2key=OrderKey.str2key, value2str=lambda v: json.dumps(v.dump()), str2value=lambda s:OrderFilled.load(json.loads(s)))

 # "active" means the order wants to be executed now. this is not BlockData because it's cleared every block
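
Filled.load and OrderFilled.load now take already-decoded JSON structures rather than raw strings; the BlockDict str2value lambdas perform the json.loads step, and dump() emits nested lists of decimal strings so large integers survive the JSON round trip. A self-contained round-trip check using the same dump/load shapes (the class is trimmed to the fields shown above):

    # Round-trip sketch matching the dump()/load() shapes in this file.
    import json
    from dataclasses import dataclass

    @dataclass
    class Filled:
        filled_in: int
        filled_out: int

        @staticmethod
        def load(obj):
            return Filled(*map(int, obj))

        def dump(self):
            return str(self.filled_in), str(self.filled_out)

    f = Filled(10, 7)
    assert Filled.load(json.loads(json.dumps(f.dump()))) == f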

View File

@@ -10,6 +10,7 @@ from dexorder.util import defaultdictk
 from .orderstate import Order
 from .. import dec
 from ..base.order import OrderKey, TrancheKey, ExecutionRequest
+from ..data import ensure_pool_price
 from ..database.model.block import current_block

 log = logging.getLogger(__name__)
@@ -28,6 +29,18 @@ execution_requests:BlockDict[TrancheKey, ExecutionRequest] = BlockDict('e') # g
 # todo should this really be blockdata?
 inflight_execution_requests:BlockDict[TrancheKey, int] = BlockDict('ei') # value is block height when the request was sent

+async def activate_order(order):
+    """
+    Call this to enable triggers on an order which is already in the state.
+    """
+    await ensure_pool_price(order.pool_address)
+    triggers = OrderTriggers(order)
+    if triggers.closed:
+        log.debug(f'order {order.key} was immediately closed')
+        close_order_and_disable_triggers(order, SwapOrderState.Filled if order.remaining <= 0 else SwapOrderState.Expired)

 def intersect_ranges( a_low, a_high, b_low, b_high):
     low, high = max(a_low,b_low), min(a_high,b_high)
     if high <= low:
@@ -46,6 +59,9 @@ class TrancheTrigger:
         self.order = order
         self.tk = tranche_key
         self.status = TrancheStatus.Early
+        self.time_constraint = None
+        self.line_constraints: list[LineConstraint] = []

         start = self.order.status.start
         tranche = order.order.tranches[self.tk.tranche_index]
@@ -57,25 +73,22 @@
             self.status = TrancheStatus.Filled
             return

-        time_constraint = None # stored as a tuple of two ints for earliest and latest absolute timestamps
-        self.line_constraints: list[LineConstraint] = []
         for c in tranche.constraints:
             if c.mode == ConstraintMode.Time:
                 c: TimeConstraint
                 earliest = c.earliest.timestamp(start)
                 latest = c.latest.timestamp(start)
-                time_constraint = (earliest, latest) if time_constraint is None else intersect_ranges(*time_constraint, earliest, latest)
+                self.time_constraint = (earliest, latest) if self.time_constraint is None else intersect_ranges(*self.time_constraint, earliest, latest)
             elif c.mode == ConstraintMode.Line:
                 c: LineConstraint
                 self.line_constraints.append(c)
             else:
                 raise NotImplementedError
-        self.time_constraint = time_constraint
-        if time_constraint is None:
+        if self.time_constraint is None:
             self.status = TrancheStatus.Pricing
         else:
             timestamp = current_block.get().timestamp
-            earliest, latest = time_constraint
+            earliest, latest = self.time_constraint
             self.status = TrancheStatus.Early if timestamp < earliest else TrancheStatus.Expired if timestamp > latest else TrancheStatus.Pricing
             self.enable_time_trigger()
         if self.status == TrancheStatus.Pricing:
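
All TimeConstraints on a tranche are folded into one (earliest, latest) window by repeated range intersection, now kept on self.time_constraint so it survives past the constructor. A worked example of the folding; the assumption that an empty intersection yields None is mine, since the diff above is cut off before the non-overlapping branch.

    # Worked example of folding two time constraints into a single window.
    # Assumption: intersect_ranges returns None when the ranges do not overlap.
    def intersect_ranges(a_low, a_high, b_low, b_high):
        low, high = max(a_low, b_low), min(a_high, b_high)
        if high <= low:
            return None
        return low, high

    window = (1_700_000_000, 1_700_003_600)                            # first TimeConstraint
    window = intersect_ranges(*window, 1_700_001_800, 1_700_007_200)   # fold in the second one
    assert window == (1_700_001_800, 1_700_003_600)                    # overlap of the two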

View File

@@ -35,6 +35,10 @@ class BlockStateRunner:
         # items are (callback, event, log_filter). The callback is invoked with web3 EventData for every detected event
         self.events:list[tuple[Callable[[dict],None],ContractEvents,dict]] = []

+        # onStateInit callbacks are invoked after the initial state is loaded or created
+        self.on_state_init: list[Callable[[],None]] = []
+        self.state_initialized = False

         # onHeadUpdate callbacks are invoked with a list of DiffItems used to update the head state from either the previous head or the root
         self.on_head_update: list[Callable[[Block,list[DiffEntryItem]],None]] = []
@@ -156,8 +160,8 @@
             # initialize
             self.state = BlockState(block)
             current_blockstate.set(self.state)
-            log.info('Created new empty root state')
             fork = Fork([block.hash], height=block.height)
+            log.info('Created new empty root state')
         else:
             fork = self.state.add_block(block)
             if fork is None:
@@ -194,12 +198,14 @@
         # set up for callbacks
         current_block.set(block)
-        current_fork.set(fork)
+        current_fork.set(fork)  # this is set earlier
         session = db.session
         session.begin()
         session.add(block)
         pubs = []
-        current_pub.set(lambda room, evnt, *args: pubs.append((room, evnt, args)))
+        current_pub.set(lambda room, evnt, *args: pubs.append((room, evnt, args)))  # used by handle_vault_created
+        if not self.state_initialized:
+            await self.do_state_init_cbs()

         # logevent callbacks
         for future, callback, event, filter_args in batches:
             if future is None:
@@ -245,3 +251,10 @@
         finally:
             if session is not None:
                 session.close()

+    async def do_state_init_cbs(self):
+        if self.state_initialized:
+            return
+        for cb in self.on_state_init:
+            await maywait(cb())
+        self.state_initialized = True
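
on_state_init callbacks run exactly once, just before the first block's log-event callbacks are dispatched, and maywait lets the list hold plain functions as well as coroutines. A reduced sketch of the one-shot hook; maywait is emulated with inspect.isawaitable here, so only on_state_init/state_initialized mirror the real class.

    # Reduced, self-contained sketch of the one-shot state-init hook.
    import asyncio, inspect

    class Runner:
        def __init__(self):
            self.on_state_init = []        # run once after the state is loaded or created
            self.state_initialized = False

        async def do_state_init_cbs(self):
            if self.state_initialized:
                return                     # later blocks skip straight to event handling
            for cb in self.on_state_init:
                result = cb()
                if inspect.isawaitable(result):
                    await result           # supports async callbacks like init_order_triggers
            self.state_initialized = True

    async def demo():
        r = Runner()
        r.on_state_init.append(lambda: print('init ran'))
        await r.do_state_init_cbs()
        await r.do_state_init_cbs()        # second call is a no-op

    asyncio.run(demo())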

View File

@@ -1,5 +1,5 @@
 import re
-from typing import Callable, TypeVar, Generic
+from typing import Callable, TypeVar, Generic, Union

 from eth_utils import keccak
 from hexbytes import HexBytes
@@ -30,9 +30,9 @@ def hexstr(value: bytes):
     raise ValueError

-def hexbytes(value: str):
+def hexbytes(value: Union[str|bytes]):
     """ converts an optionally 0x-prefixed hex string into bytes """
-    return bytes.fromhex(value[2:] if value.startswith('0x') else value)
+    return value if type(value) is bytes else bytes.fromhex(value[2:] if value.startswith('0x') else value)

 def hexint(value: str):
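
With bytes now passed through unchanged, callers such as TimeConstraint.load can hand hexbytes either raw ABI bytes or the 0x-prefixed hex string that comes back from JSON or Redis. The widened helper, copied here standalone for illustration:

    def hexbytes(value):
        """ converts an optionally 0x-prefixed hex string into bytes; bytes pass through """
        return value if type(value) is bytes else bytes.fromhex(value[2:] if value.startswith('0x') else value)

    assert hexbytes('0xdeadbeef') == b'\xde\xad\xbe\xef'
    assert hexbytes('deadbeef') == b'\xde\xad\xbe\xef'
    assert hexbytes(b'\xde\xad\xbe\xef') == b'\xde\xad\xbe\xef'   # unchanged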