order trigger recovery; state load bugfixes

This commit is contained in:
Tim Olson
2023-11-10 15:43:33 -04:00
parent 6c9e28b00d
commit 82be8d94e8
9 changed files with 99 additions and 48 deletions

View File

@@ -7,6 +7,7 @@ from dexorder.bin.executable import execute
from dexorder.blockstate.blockdata import BlockData
from dexorder.blockstate.db_state import DbState
from dexorder.configuration import parse_args
from dexorder.event_handler import init_order_triggers
from dexorder.memcache.memcache_state import RedisState, publish_all
from dexorder.memcache import memcache
from dexorder.runner import BlockStateRunner
@@ -30,12 +31,14 @@ async def main():
with db.session:
state = db_state.load()
if state is not None:
if redis_state:
await redis_state.init(state)
log.info(f'loaded state from db for root block {state.root_block}')
runner = BlockStateRunner(state, publish_all=publish_all if redis_state else None)
if db:
runner.on_state_init.append(init_order_triggers)
# noinspection PyUnboundLocalVariable
runner.on_promotion.append(db_state.save)
if redis_state:

View File

@@ -36,7 +36,7 @@ class DbState(SeriesCollection):
return None
fork = Fork([hexbytes(blockhash)], height=height)
value = db.session.get(Entity, (chain_id, series, key))
return fork, value
return fork, var.str2value(value.value)
def save(self, root_block: Block, diffs: Iterable[Union[DiffItem,DiffEntryItem]] ):
chain_id = current_chain.get().chain_id

View File

@@ -4,8 +4,11 @@ from dexorder import dec
from dexorder.base.chain import current_chain
from dexorder.blockstate import BlockDict
from dexorder.blockstate.blockdata import K, V
from dexorder.uniswap import UniswapV3Pool
from dexorder.util import json
log = logging.getLogger(__name__)
# pub=... publishes to a channel for web clients to consume. argument is (key,value) and return must be (event,room,args)
# if pub is True, then event is the current series name, room is the key, and args is [value]
# values of DELETE are serialized as nulls
@@ -40,3 +43,9 @@ def pub_pool_price(k,v):
new_pool_prices: dict[str, dec] = {} # tracks which prices were set during the current block. cleared every block.
pool_prices: PoolPrices = PoolPrices('p', db=True, redis=True, pub=pub_pool_price, value2str=lambda d: f'{d:f}', str2value=dec)
async def ensure_pool_price(pool_addr):
if pool_addr not in pool_prices:
log.debug(f'querying price for pool {pool_addr}')
pool_prices[pool_addr] = await UniswapV3Pool(pool_addr).price()

View File

@@ -1,3 +1,4 @@
import asyncio
import functools
import logging
from uuid import UUID
@@ -8,16 +9,16 @@ from dexorder import current_pub, db, dec
from dexorder.base.chain import current_chain
from dexorder.base.order import TrancheExecutionRequest, TrancheKey, ExecutionRequest, new_tranche_execution_request, OrderKey
from dexorder.transaction import create_transactions, submit_transaction_request, handle_transaction_receipts, send_transactions
from dexorder.uniswap import UniswapV3Pool, uniswap_price
from dexorder.uniswap import uniswap_price
from dexorder.contract.dexorder import get_factory_contract, vault_address, VaultContract, get_dexorder_contract
from dexorder.contract import get_contract_event, ERC20
from dexorder.data import pool_prices, vault_owners, vault_balances, new_pool_prices
from dexorder.database.model.block import current_block
from dexorder.database.model.transaction import TransactionJob
from dexorder.order.orderlib import SwapOrderState, SwapOrderStatus
from dexorder.order.orderlib import SwapOrderStatus
from dexorder.order.orderstate import Order
from dexorder.order.triggers import OrderTriggers, close_order_and_disable_triggers, price_triggers, time_triggers, \
unconstrained_price_triggers, execution_requests, inflight_execution_requests, TrancheStatus, active_tranches, new_price_triggers
from dexorder.order.triggers import OrderTriggers, price_triggers, time_triggers, \
unconstrained_price_triggers, execution_requests, inflight_execution_requests, TrancheStatus, active_tranches, new_price_triggers, activate_order
from dexorder.util.async_util import maywait
log = logging.getLogger(__name__)
@@ -25,14 +26,6 @@ log = logging.getLogger(__name__)
LOG_ALL_EVENTS = True # for debug
async def ensure_pool_price(pool_addr):
if pool_addr not in pool_prices:
log.debug(f'querying price for pool {pool_addr}')
pool_prices[pool_addr] = await UniswapV3Pool(pool_addr).price()
def dump_log(eventlog):
log.debug(f'\t{eventlog}')
def setup_logevent_triggers(runner):
runner.events.clear()
@@ -80,6 +73,18 @@ def setup_logevent_triggers(runner):
runner.add_event_trigger(send_transactions)
def dump_log(eventlog):
log.debug(f'\t{eventlog}')
async def init_order_triggers():
log.debug('activating orders')
# this is a state init callback, called only once after the state has been loaded from the db or created fresh
orders = [Order.of(key) for key in Order.open_orders]
futures = [activate_order(order) for order in orders]
await asyncio.gather(*futures, return_exceptions=True)
log.debug(f'activated {len(futures)} orders')
def init():
new_pool_prices.clear()
new_price_triggers.clear()
@@ -107,12 +112,9 @@ async def handle_order_placed(event: EventData):
log.debug(f'raw order status {obj}')
order_status = SwapOrderStatus.load(obj)
order = Order.create(vault.address, index, order_status)
await ensure_pool_price(order.pool_address)
triggers = OrderTriggers(order)
log.debug(f'created order {order_status}')
if triggers.closed:
log.warning(f'order {order.key} was immediately closed')
close_order_and_disable_triggers(order, SwapOrderState.Filled if order.remaining <= 0 else SwapOrderState.Expired)
await activate_order(order)
log.debug(f'new order {order_status}')
def handle_swap_filled(event: EventData):
# event DexorderSwapFilled (uint64 orderIndex, uint8 trancheIndex, uint256 amountIn, uint256 amountOut);
@@ -132,7 +134,7 @@ def handle_swap_filled(event: EventData):
triggers = OrderTriggers.instances[order.key]
triggers.fill(tranche_index, amount_in, amount_out)
except KeyError:
log.warning(f'No order triggers for fill of {TrancheKey(*order.key,tranche_index)}')
log.warning(f'No order triggers for fill of {TrancheKey(order.key.vault, order.key.order_index, tranche_index)}')
async def handle_order_completed(event: EventData):
# event DexorderCompleted (uint64 orderIndex); // todo remove?
@@ -322,6 +324,10 @@ def finish_execution_request(req: TrancheExecutionRequest, error: str):
# todo dont keep trying
else:
log.error(f'Unhandled execution error for transaction request {req} ERROR: "{error}"')
er = execution_requests[tk]
if er.height < current_block.get().height:
del execution_requests[tk]
try:
er = execution_requests[tk]
except KeyError:
pass
else:
if er.height < current_block.get().height:
del execution_requests[tk]

View File

@@ -3,11 +3,12 @@ import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from typing import Optional, Union
from dexorder import dec
from dexorder.uniswap import uniswapV3_pool_address, uniswap_price
from dexorder.contract import abi_decoder, abi_encoder
from dexorder.util import hexbytes
log = logging.getLogger(__name__)
@@ -158,8 +159,8 @@ class TimeConstraint (Constraint):
TYPES = ['uint8', 'uint32', 'uint8', 'uint32']
@staticmethod
def load(obj: bytes):
earliest_mode, earliest_time, latest_mode, latest_time = abi_decoder.decode(TimeConstraint.TYPES, obj)
def load(obj: Union[bytes|str]):
earliest_mode, earliest_time, latest_mode, latest_time = abi_decoder.decode(TimeConstraint.TYPES, hexbytes(obj))
return TimeConstraint(ConstraintMode.Time, Time(TimeMode(earliest_mode),earliest_time), Time(TimeMode(latest_mode),latest_time))
def dump(self):
@@ -178,7 +179,7 @@ class LineConstraint (Constraint):
@staticmethod
def load(obj):
return LineConstraint(ConstraintMode.Line, *abi_decoder.decode(LineConstraint.TYPES, obj))
return LineConstraint(ConstraintMode.Line, *abi_decoder.decode(LineConstraint.TYPES, hexbytes(obj)))
def dump(self):
return self._dump(LineConstraint.TYPES, (self.isAbove, self.isRatio, self.time, self.valueSqrtX96, self.slopeSqrtX96))

View File

@@ -20,8 +20,8 @@ class Filled:
filled_out: int
@staticmethod
def load(string):
return Filled(*map(int,string[1:-1].split(',')))
def load(obj: tuple[str,str]):
return Filled(*map(int,obj))
def dump(self):
return str(self.filled_in), str(self.filled_out)
@@ -33,9 +33,9 @@ class OrderFilled:
tranche_filled: list[Filled]
@staticmethod
def load(string):
f, tfs = json.loads(string)
return OrderFilled(Filled(*f), [Filled(*tf) for tf in tfs])
def load(obj):
f, tfs = obj
return OrderFilled(Filled.load(f), [Filled.load(tf) for tf in tfs])
def dump(self):
return [self.filled.dump(), [tf.dump() for tf in self.tranche_filled]]
@@ -47,6 +47,9 @@ class Order:
"""
represents the canonical internal representation of an order. some members are immutable like the order spec, and some are
represented in various blockstate structures. this class hides that complexity to provide a clean interface to orders.
Orders are therefore just references to state data, and may be accessed using Order.of(key). If there is a new order placed
in the system, invoke Order.create(...) instead.
"""
instances: dict[OrderKey, 'Order'] = {}
@@ -62,7 +65,10 @@ class Order:
@staticmethod
def of(a, b=None) -> 'Order':
key = a if b is None else OrderKey(a, b)
return Order.instances[key]
try:
return Order.instances[key]
except KeyError:
return Order(key)
@staticmethod
@@ -92,7 +98,7 @@ class Order:
self.status: SwapOrderStatus = Order.order_statuses[key].copy()
self.pool_address: str = self.status.order.pool_address
self.tranche_keys = [TrancheKey(key.vault, key.order_index, i) for i in range(len(self.status.trancheFilledIn))]
# various flattenings
# flattenings of various static data
self.order = self.status.order
self.amount = self.status.order.amount
self.amount_is_input = self.status.order.amountIsInput
@@ -207,7 +213,7 @@ class Order:
# the filled amount fields for active orders are maintained in the order_remainings and tranche_remainings series.
order_statuses: BlockDict[OrderKey, SwapOrderStatus] = BlockDict(
'o', db='lazy', redis=True, pub=pub_order_status,
str2key=OrderKey.str2key, value2str=lambda v: json.dumps(v.dump()), str2value=SwapOrderStatus.load,
str2key=OrderKey.str2key, value2str=lambda v: json.dumps(v.dump()), str2value=lambda s:SwapOrderStatus.load(json.loads(s)),
)
# open orders = the set of unfilled, not-canceled orders
@@ -222,7 +228,7 @@ class Order:
# is removed from open_orders, the order_status directly contains the final fill values.
order_filled: BlockDict[OrderKey, OrderFilled] = BlockDict(
'of', db=True, redis=True, pub=pub_order_fills,
str2key=OrderKey.str2key, value2str=lambda v: json.dumps(v.dump()), str2value=OrderFilled.load)
str2key=OrderKey.str2key, value2str=lambda v: json.dumps(v.dump()), str2value=lambda s:OrderFilled.load(json.loads(s)))
# "active" means the order wants to be executed now. this is not BlockData because it's cleared every block

View File

@@ -10,6 +10,7 @@ from dexorder.util import defaultdictk
from .orderstate import Order
from .. import dec
from ..base.order import OrderKey, TrancheKey, ExecutionRequest
from ..data import ensure_pool_price
from ..database.model.block import current_block
log = logging.getLogger(__name__)
@@ -28,6 +29,18 @@ execution_requests:BlockDict[TrancheKey, ExecutionRequest] = BlockDict('e') # g
# todo should this really be blockdata?
inflight_execution_requests:BlockDict[TrancheKey, int] = BlockDict('ei') # value is block height when the request was sent
async def activate_order(order):
"""
Call this to enable triggers on an order which is already in the state.
"""
await ensure_pool_price(order.pool_address)
triggers = OrderTriggers(order)
if triggers.closed:
log.debug(f'order {order.key} was immediately closed')
close_order_and_disable_triggers(order, SwapOrderState.Filled if order.remaining <= 0 else SwapOrderState.Expired)
def intersect_ranges( a_low, a_high, b_low, b_high):
low, high = max(a_low,b_low), min(a_high,b_high)
if high <= low:
@@ -46,6 +59,9 @@ class TrancheTrigger:
self.order = order
self.tk = tranche_key
self.status = TrancheStatus.Early
self.time_constraint = None
self.line_constraints: list[LineConstraint] = []
start = self.order.status.start
tranche = order.order.tranches[self.tk.tranche_index]
@@ -57,25 +73,22 @@ class TrancheTrigger:
self.status = TrancheStatus.Filled
return
time_constraint = None # stored as a tuple of two ints for earliest and latest absolute timestamps
self.line_constraints: list[LineConstraint] = []
for c in tranche.constraints:
if c.mode == ConstraintMode.Time:
c: TimeConstraint
earliest = c.earliest.timestamp(start)
latest = c.latest.timestamp(start)
time_constraint = (earliest, latest) if time_constraint is None else intersect_ranges(*time_constraint, earliest, latest)
self.time_constraint = (earliest, latest) if self.time_constraint is None else intersect_ranges(*self.time_constraint, earliest, latest)
elif c.mode == ConstraintMode.Line:
c: LineConstraint
self.line_constraints.append(c)
else:
raise NotImplementedError
self.time_constraint = time_constraint
if time_constraint is None:
if self.time_constraint is None:
self.status = TrancheStatus.Pricing
else:
timestamp = current_block.get().timestamp
earliest, latest = time_constraint
earliest, latest = self.time_constraint
self.status = TrancheStatus.Early if timestamp < earliest else TrancheStatus.Expired if timestamp > latest else TrancheStatus.Pricing
self.enable_time_trigger()
if self.status == TrancheStatus.Pricing:

View File

@@ -35,6 +35,10 @@ class BlockStateRunner:
# items are (callback, event, log_filter). The callback is invoked with web3 EventData for every detected event
self.events:list[tuple[Callable[[dict],None],ContractEvents,dict]] = []
# onStateInit callbacks are invoked after the initial state is loaded or created
self.on_state_init: list[Callable[[],None]] = []
self.state_initialized = False
# onHeadUpdate callbacks are invoked with a list of DiffItems used to update the head state from either the previous head or the root
self.on_head_update: list[Callable[[Block,list[DiffEntryItem]],None]] = []
@@ -156,8 +160,8 @@ class BlockStateRunner:
# initialize
self.state = BlockState(block)
current_blockstate.set(self.state)
log.info('Created new empty root state')
fork = Fork([block.hash], height=block.height)
log.info('Created new empty root state')
else:
fork = self.state.add_block(block)
if fork is None:
@@ -194,12 +198,14 @@ class BlockStateRunner:
# set up for callbacks
current_block.set(block)
current_fork.set(fork)
current_fork.set(fork) # this is set earlier
session = db.session
session.begin()
session.add(block)
pubs = []
current_pub.set(lambda room, evnt, *args: pubs.append((room, evnt, args)))
current_pub.set(lambda room, evnt, *args: pubs.append((room, evnt, args))) # used by handle_vault_created
if not self.state_initialized:
await self.do_state_init_cbs()
# logevent callbacks
for future, callback, event, filter_args in batches:
if future is None:
@@ -245,3 +251,10 @@ class BlockStateRunner:
finally:
if session is not None:
session.close()
async def do_state_init_cbs(self):
if self.state_initialized:
return
for cb in self.on_state_init:
await maywait(cb())
self.state_initialized = True

View File

@@ -1,5 +1,5 @@
import re
from typing import Callable, TypeVar, Generic
from typing import Callable, TypeVar, Generic, Union
from eth_utils import keccak
from hexbytes import HexBytes
@@ -30,9 +30,9 @@ def hexstr(value: bytes):
raise ValueError
def hexbytes(value: str):
def hexbytes(value: Union[str|bytes]):
""" converts an optionally 0x-prefixed hex string into bytes """
return bytes.fromhex(value[2:] if value.startswith('0x') else value)
return value if type(value) is bytes else bytes.fromhex(value[2:] if value.startswith('0x') else value)
def hexint(value: str):