OHLCs and datamain.py; update DB and package requirements.

tim
2024-01-21 21:04:52 -04:00
parent 8b113563b3
commit c154f13f7c
20 changed files with 499 additions and 179 deletions

View File

@@ -1,10 +1,18 @@
# noinspection PyPackageRequirements
from contextvars import ContextVar
from datetime import datetime
from decimal import Decimal
from typing import Callable, Any
from web3 import AsyncWeb3
dec = Decimal
def now():
return datetime.utcnow() # we use naive datetimes that are always UTC
def timestamp():
return datetime.now().timestamp()
# NARG is used in argument defaults to mean "not specified" rather than "specified as None"
class _Token:

View File

@@ -1,6 +1,8 @@
import math
# noinspection PyPackageRequirements
from contextvars import ContextVar
from datetime import datetime
import dexorder
class Blockchain:
@@ -59,10 +61,10 @@ class BlockClock:
def set(self, timestamp):
self.timestamp = timestamp
self.adjustment = timestamp - datetime.now().timestamp()
self.adjustment = timestamp - dexorder.timestamp()
def now(self):
return math.ceil(datetime.now().timestamp() + self.adjustment)
def timestamp(self):
return math.ceil(dexorder.timestamp() + self.adjustment)
current_clock = ContextVar[BlockClock]('clock')  # current estimated timestamp of the blockchain; will differ from current_block.get().timestamp when evaluating time triggers in between blocks
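# Worked example (illustrative values, not from this commit): if the latest block's timestamp is
# 1_700_000_000 while the local dexorder.timestamp() reads 1_699_999_990, set() stores
# adjustment = +10, so timestamp() estimates chain time as ceil(local time + 10) until the next block arrives.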

View File

@@ -1,25 +1,72 @@
import json
import logging
import os
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional
from dexorder import dec
from cachetools import LFUCache
from dexorder import dec, config
from dexorder.blockstate import BlockDict
log = logging.getLogger(__name__)
OHLC_PERIODS = [
timedelta(minutes=1), timedelta(minutes=3), timedelta(minutes=5), timedelta(minutes=10), timedelta(minutes=15), timedelta(minutes=30),
timedelta(hours=1), timedelta(hours=2), timedelta(hours=4), timedelta(hours=8), timedelta(hours=12),
timedelta(days=1), timedelta(days=2), timedelta(days=3), timedelta(days=7)
]
OHLC_DATE_ROOT = datetime(2009, 1, 4) # Sunday before Bitcoin Genesis
# OHLCs are stored as [time, open, high, low, close] string values. If there was no data during the interval,
# then open, high, and low are None and the close value is carried over from the previous interval.
OHLC = list[str] # typedef
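# Illustrative rows (hypothetical values): a bar with trades followed by an empty carry-over bar
# would be stored as
#   ['2024-01-21T21:00', '0.06112', '0.06125', '0.06098', '0.06120']
#   ['2024-01-21T21:01', None, None, None, '0.06120']  # no trades: close carried forward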
def opt_dec(v):
return None if v is None else dec(v)
def dt(v):
return v if isinstance(v, datetime) else datetime.fromisoformat(v)
@dataclass
class NativeOHLC:
@staticmethod
def from_ohlc(ohlc: OHLC) -> 'NativeOHLC':
return NativeOHLC(*[cast(value) for value, cast in zip(ohlc,(dt, opt_dec, opt_dec, opt_dec, dec))])
start: datetime
open: Optional[dec]
high: Optional[dec]
low: Optional[dec]
close: dec
@property
def ohlc(self) -> OHLC:
return [
self.start.isoformat(timespec='minutes'),
None if self.open is None else str(self.open),
None if self.high is None else str(self.high),
None if self.low is None else str(self.low),
str(self.close)
]
def ohlc_name(period: timedelta) -> str:
return f'{period // timedelta(minutes=1)}m' if period < timedelta(hours=1) \
else f'{period // timedelta(hours=1)}H' if period < timedelta(days=1) \
else f'{period // timedelta(days=7)}W' if period == timedelta(days=7) \
else f'{period // timedelta(days=1)}D'
def period_from_name(name: str) -> timedelta:
value = int(name[:-1])
unit = name[-1:]
factor = {'m':timedelta(minutes=1), 'H':timedelta(hours=1), 'D':timedelta(days=1), 'W':timedelta(days=7)}[unit]
return value * factor
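# Round-trip sketch: ohlc_name(timedelta(minutes=15)) == '15m', ohlc_name(timedelta(hours=4)) == '4H',
# ohlc_name(timedelta(days=3)) == '3D', ohlc_name(timedelta(days=7)) == '1W',
# and period_from_name('4H') == timedelta(hours=4) recovers the period.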
def ohlc_start_time(time, period: timedelta):
@@ -29,57 +76,140 @@ def ohlc_start_time(time, period: timedelta):
return OHLC_DATE_ROOT + timedelta(seconds=period_sec * period_count)
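# Alignment sketch (the elided body presumably floors time - OHLC_DATE_ROOT to whole periods):
# with a 1H period, 2024-01-21 21:04 aligns to 2024-01-21 21:00, and weekly bars align to
# Sundays because OHLC_DATE_ROOT is a Sunday.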
@dataclass(frozen=True)
class OHLC:
start_time: datetime # first datetime included in this range
period: timedelta # the interval covered by this range, starting from start_time
# if no swaps happen during the interval, heights are set to prev_ohlc.last_height
first_height: int = None # blockchain height of the first trade in this range.
last_height: int = None # last_height == first_height if there's zero or one trades during this interval
# if no swaps happen during the interval, prices are set to prev_ohlc.close
open: dec = None
high: dec = None
low: dec = None
close: dec = None
has_data: bool = False # True iff any trade has happened this period
def update(self, height: int, time: datetime, price: dec) -> list['OHLC']:
""" returns an ordered list of OHLCs that have been created/modified by the new price """
assert time >= self.start_time
result = []
cur = self
start = self.start_time
while True:
end = start + self.period
if time < end:
break
result.append(cur)
start = end
cur = OHLC(start, self.period, cur.last_height, cur.last_height, cur.close, cur.close, cur.close, cur.close)
if not cur.has_data:
cur = OHLC(cur.start_time, self.period, height, height, price, price, price, price, True)
def update_ohlc(prev: OHLC, period: timedelta, time: datetime, price: Optional[dec]) -> list[OHLC]:
"""
returns an ordered list of OHLCs that have been created/modified by the new time/price
if price is None, then bars are advanced based on the time but no new price is added to the series.
"""
cur = NativeOHLC.from_ohlc(prev)
assert time >= cur.start
result = []
# advance time and finalize any past OHLCs into the result array
while True:
end = cur.start + period
if time < end:
break
result.append(cur.ohlc)
cur = NativeOHLC(end, None, None, None, cur.close)
# if we are setting a price, update the current bar
if price is not None:
if cur.open is None:
cur.open = price
cur.high = price
cur.low = price
else:
cur = OHLC(cur.start_time, self.period, cur.first_height, height, cur.open, max(cur.high,price), min(cur.low,price), price, True)
result.append(cur)
return result
cur.high = max(cur.high, price)
cur.low = min(cur.low, price)
cur.close = price
result.append(cur.ohlc)
return result
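# Worked example (illustrative values): starting from a 1m bar that closed at dec('100'), calling
# update_ohlc(prev, timedelta(minutes=1), two_minutes_later, dec('101')) yields three bars:
# the finalized previous bar, one empty carry-over bar closing at 100, and a new current bar
# opened and closed at 101.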
# The most recent OHLCs are stored as block data. We store a list of at least the two latest bars, which provides clients with
# the latest finalized bar as well as the current open bar.
recent_ohlcs = BlockDict('ohlc', db=True, redis=True)
class OHLCRepository:
def __init__(self, base_dir: str):
def __init__(self, base_dir: str = None):
""" can't actually make more than one of these because there's a global recent_ohlcs BlockDict """
if base_dir is None:
base_dir = config.ohlc_dir
self.dir = base_dir
self.cache = LFUCache(len(OHLC_PERIODS) * 128) # enough for the top 128 pools
@staticmethod
def add_symbol(symbol: str, period: timedelta = None):
if period is not None:
if (symbol, period) not in recent_ohlcs:
recent_ohlcs[(symbol, period)] = [] # setting an empty value will initiate price capture
else:
for period in OHLC_PERIODS:
if (symbol, period) not in recent_ohlcs:
recent_ohlcs[(symbol, period)] = []
def update_all(self, symbol: str, time: datetime, price: dec, *, create: bool = False):
for period in OHLC_PERIODS:
self.update(symbol, period, time, price, create=create)
def update(self, symbol: str, period: timedelta, time: datetime, price: Optional[dec], *, create: bool = False) -> Optional[list[OHLC]]:
"""
if price is None, then bars are advanced based on the time but no new price is added to the series.
"""
key = (symbol, period)
bars: Optional[list[OHLC]] = recent_ohlcs.get(key)
if bars is None:
if create is False:
return  # do not track symbols which have not been explicitly set up
bars = [OHLC((ohlc_start_time(time, period).isoformat(timespec='minutes'), price, price, price, price))]
updated = update_ohlc(bars[-1], period, time, price)
if len(updated) == 1:
updated = [*bars[:-1], updated[0]] # return the previous finalized bars along with the updated current bar
recent_ohlcs.setitem(key, updated)
if len(updated) > 1:
self.save_all(symbol, period, updated[:-1])
return updated
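# Usage sketch (hypothetical symbol and values): ohlcs.update('0xPoolAddr', timedelta(minutes=1), block_time, dec('1.23'))
# keeps recent_ohlcs[('0xPoolAddr', timedelta(minutes=1))] holding the latest finalized bar plus the current
# open bar, persists any older finalized bars to disk via save_all(), and ignores symbols that were never
# added via add_symbol() unless create=True.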
def save_all(self, symbol: str, period: timedelta, ohlc_list: list[OHLC]) -> None:
for ohlc in ohlc_list:
self.save(symbol, period, ohlc)
def save(self, symbol: str, period: timedelta, ohlc: OHLC) -> None:
time = dt(ohlc[0])
chunk = self.get_chunk(symbol, period, time)
if not chunk:
chunk = [ohlc]
else:
start = datetime.fromisoformat(chunk[0][0])
index = (time - start) // period
if index == len(chunk):
assert datetime.fromisoformat(chunk[-1][0]) + period == time
chunk.append(ohlc)
else:
assert datetime.fromisoformat(chunk[index][0]) == time
chunk[index] = ohlc
self.save_chunk(symbol, period, chunk)
def get_chunk(self, symbol: str, period: timedelta, start_time: datetime) -> list[OHLC]:
path = self.chunk_path(symbol, period, start_time)
found = self.cache.get(path)
if found is None:
found = self.load_chunk(path)
if found is None:
found = []
self.cache[path] = found
return found
@staticmethod
def load_chunk(path: str) -> Optional[list[OHLC]]:
try:
with open(path, 'r') as file:
return json.load(file)
except FileNotFoundError:
return []
def save_chunk(self, symbol: str, period: timedelta, chunk: list[OHLC]):
if not chunk:
return
path = self.chunk_path(symbol, period, datetime.fromisoformat(chunk[0][0]))
try:
with open(path, 'w') as file:
json.dump(chunk, file)
return
except FileNotFoundError:
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, 'w') as file:
json.dump(chunk, file)
def chunk_path(self, symbol: str, period: timedelta, time: datetime) -> str:
start = ohlc_start_time(time, period)
name = ohlc_name(period)
return f'{self.dir}/{symbol}/{name}/' + (
f'{start.year}/{symbol}-{name}-{start:%Y%m%d}.json' if period < timedelta(hours=1) else  # <1H data has a file per day
f'{start.year}/{symbol}-{name}-{start:%Y%m}.json' if period < timedelta(days=7) else  # <1W data has a file per month
f'{symbol}-{name}.json'  # long periods are a single file for all of history
)
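# Path layout sketch (hypothetical symbol, default ohlc_dir='./ohlc'):
#   5m bars -> ./ohlc/0xPoolAddr/5m/2024/0xPoolAddr-5m-20240121.json   (one file per day)
#   4H bars -> ./ohlc/0xPoolAddr/4H/2024/0xPoolAddr-4H-202401.json     (one file per month)
#   1W bars -> ./ohlc/0xPoolAddr/1W/0xPoolAddr-1W.json                 (all history in one file)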
if __name__ == '__main__':
r = OHLCRepository('')
for p in OHLC_PERIODS:
print(f'{ohlc_name(p)}\t{r.chunk_path("symbol",p, datetime.utcnow())}')
ohlcs = OHLCRepository()

View File

@@ -19,6 +19,7 @@ class SwapOrderState (Enum):
Filled = 5
class Exchange (Enum):
Unknown = -1
UniswapV2 = 0
UniswapV3 = 1

View File

@@ -0,0 +1,97 @@
import logging
import sys
from asyncio import CancelledError
from datetime import datetime
from async_lru import alru_cache
from web3.types import EventData
from dexorder import blockchain, config, dec, current_w3
from dexorder.base.ohlc import ohlcs
from dexorder.base.orderlib import Exchange
from dexorder.bin.executable import execute
from dexorder.blockstate.blockdata import BlockData
from dexorder.blockstate.db_state import DbState
from dexorder.configuration import parse_args
from dexorder.contract import get_contract_event
from dexorder.database import db
from dexorder.memcache.memcache_state import RedisState, publish_all
from dexorder.memcache import memcache
from dexorder.pools import uniswap_price, Pools
from dexorder.runner import BlockStateRunner
from dexorder.util import hexint
log = logging.getLogger('dexorder')
@alru_cache
async def get_block_timestamp(blockhash) -> int:
response = await current_w3.get().provider.make_request('eth_getBlockByHash', [blockhash, False])
raw = hexint(response['result']['timestamp'])
# noinspection PyTypeChecker
return raw if type(raw) is int else hexint(raw)
async def handle_uniswap_swap(swap: EventData):
try:
sqrt_price = swap['args']['sqrtPriceX96']
except KeyError:
return
addr = swap['address']
pool = await Pools.get(addr)
if pool is None:
return
if pool.exchange != Exchange.UniswapV3:
log.debug(f'Ignoring {pool.exchange} pool {addr}')
return
price: dec = await uniswap_price(pool, sqrt_price)
timestamp = await get_block_timestamp(swap['blockHash'])
dt = datetime.fromtimestamp(timestamp)
log.debug(f'pool {addr} {dt} {price}')
ohlcs.update_all(addr, dt, price, create=True)
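# Price derivation sketch (assumption: uniswap_price follows the usual Uniswap V3 convention; its body
# is not part of this diff): price ~ (sqrtPriceX96 / 2**96) ** 2, scaled by
# 10 ** (decimals(token0) - decimals(token1)), then recorded into every OHLC period for this pool.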
async def main():
# noinspection DuplicatedCode
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
log.setLevel(logging.DEBUG)
parse_args()
await blockchain.connect()
redis_state = None
state = None
if memcache:
await memcache.connect()
redis_state = RedisState(BlockData.by_opt('redis'))
if db:
db.connect(url=config.datadb_url) # our main database is the data db
# noinspection DuplicatedCode
db_state = DbState(BlockData.by_opt('db'))
with db.session:
state = db_state.load()
if state is None:
log.info('no state in database')
else:
if redis_state:
await redis_state.init(state)
log.info(f'loaded state from db for root block {state.root_block}')
runner = BlockStateRunner(state, publish_all=publish_all if redis_state else None, timer_period=0)
# noinspection PyTypeChecker
runner.add_event_trigger(handle_uniswap_swap, get_contract_event('IUniswapV3PoolEvents', 'Swap'))
if db:
# noinspection PyUnboundLocalVariable,PyTypeChecker
runner.on_promotion.append(db_state.save)
if redis_state:
# noinspection PyTypeChecker
runner.on_head_update.append(redis_state.save)
try:
await runner.run()
except CancelledError:
pass
log.info('exiting')
if __name__ == '__main__':
execute(main())

View File

@@ -84,7 +84,9 @@ async def main():
db_state = DbState(BlockData.by_opt('db'))
with db.session:
state = db_state.load()
if state is not None:
if state is None:
log.info('no state in database')
else:
if redis_state:
await redis_state.init(state)
log.info(f'loaded state from db for root block {state.root_block}')

View File

@@ -13,11 +13,14 @@ class Config:
ws_url: str = 'ws://localhost:8545'
rpc_urls: Optional[dict[str,str]] = field(default_factory=dict)
db_url: str = 'postgresql://dexorder:redroxed@localhost/dexorder'
datadb_url: str = 'postgresql://dexorder:redroxed@localhost/dexorderdata'
ohlc_dir: str = './ohlc'
dump_sql: bool = False
redis_url: str = 'redis://localhost:6379'
parallel_logevent_queries: bool = True
polling: float = 0 # seconds between queries for a new block. 0 disables polling and uses a websocket subscription on ws_url instead
backfill: int = 0 # if not 0, then runner will initialize an empty database by backfilling from the given block height
tokens: list['TokenConfig'] = field(default_factory=list)

View File

@@ -1,6 +1,9 @@
import logging
from dexorder import db
from eth_abi.exceptions import InsufficientDataBytes
from web3.exceptions import ContractLogicError, BadFunctionCallOutput
from dexorder import db, dec
from dexorder.contract import ERC20
log = logging.getLogger(__name__)
@@ -11,6 +14,13 @@ async def token_decimals(addr):
try:
return db.kv[key]
except KeyError:
decimals = await ERC20(addr).decimals()
try:
decimals = await ERC20(addr).decimals()
except (InsufficientDataBytes, ContractLogicError, BadFunctionCallOutput):
log.warning(f'token {addr} has no decimals()')
decimals = 0
except Exception:
log.debug(f'could not get token decimals for {addr}')
return None
db.kv[key] = decimals
return decimals

View File

@@ -42,8 +42,9 @@ class Kv:
class Db:
def __init__(self):
def __init__(self, db_url_config_key='db_url'):
self.kv = Kv()
self.db_url_config_key = db_url_config_key
def __bool__(self):
return bool(config.db_url)
@@ -79,24 +80,24 @@ class Db:
return s
# noinspection PyShadowingNames
@staticmethod
def connect(url=None, migrate=True, reconnect=False, dump_sql=None):
def connect(self, url=None, migrate=True, reconnect=False, dump_sql=None):
if _engine.get() is not None and not reconnect:
return
if url is None:
url = config.db_url
url = config[self.db_url_config_key]
if dump_sql is None:
dump_sql = config.dump_sql
engine = sqlalchemy.create_engine(url, echo=dump_sql, json_serializer=json.dumps, json_deserializer=json.loads)
if migrate:
migrate_database()
migrate_database(url)
with engine.connect() as connection:
connection.execute(sqlalchemy.text("SET TIME ZONE 'UTC'"))
result = connection.execute(sqlalchemy.text("select version_num from alembic_version"))
for row in result:
log.info(f'database revision {row[0]}')
log.info(f'{url} database revision {row[0]}')
_engine.set(engine)
return db
raise Exception('database version not found')
raise Exception(f'{url} database version not found')
db = Db()
datadb = Db('datadb_url')
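# Usage sketch: db.connect() still resolves config['db_url'], while datadb.connect() (or, as in
# datamain.py, db.connect(url=config.datadb_url)) points the same machinery at the data database.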

View File

@@ -1,10 +1,15 @@
import subprocess
import sys
from traceback import print_exception
from alembic import command
from alembic.config import Config
def migrate_database():
completed = subprocess.run('alembic upgrade head', shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
if completed.returncode != 0:
print(completed.stdout.decode(), file=sys.stderr)
def migrate_database(db_url):
alembic_config = Config("alembic.ini")
alembic_config.set_main_option('sqlalchemy.url', db_url)
try:
command.upgrade(alembic_config, "head")
except Exception as e:
print('FATAL: database migration failed!', file=sys.stderr)
print_exception(e, file=sys.stderr)
exit(1)
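# Note (sketch): each connect() call now runs the Alembic upgrade in-process against its own
# sqlalchemy.url, so the main database and the data database can be migrated independently.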

View File

@@ -8,15 +8,14 @@ from web3.types import EventData
from dexorder import current_pub, db, dec
from dexorder.base.chain import current_chain, current_clock
from dexorder.base.order import TrancheExecutionRequest, TrancheKey, ExecutionRequest, new_tranche_execution_request, OrderKey
from dexorder.transaction import create_transactions, submit_transaction_request, handle_transaction_receipts, send_transactions
from dexorder.pools import uniswap_price
from dexorder.contract.dexorder import get_factory_contract, vault_address, VaultContract, get_dexorder_contract
from dexorder.contract import get_contract_event, ERC20
from dexorder.transaction import submit_transaction_request
from dexorder.pools import uniswap_price, new_pool_prices, pool_prices, Pools
from dexorder.contract.dexorder import vault_address, VaultContract
from dexorder.contract import ERC20
from dexorder.data import vault_owners, vault_balances
from dexorder.pools import new_pool_prices, pool_prices, pool_decimals, Pools
from dexorder.database.model.block import current_block
from dexorder.database.model.transaction import TransactionJob
from dexorder.base.orderlib import SwapOrderStatus, SwapOrderState
from dexorder.base.orderlib import SwapOrderState
from dexorder.order.orderstate import Order
from dexorder.order.triggers import OrderTriggers, price_triggers, time_triggers, \
unconstrained_price_triggers, execution_requests, inflight_execution_requests, TrancheStatus, active_tranches, new_price_triggers, activate_order
@@ -187,7 +186,7 @@ def handle_vault_created(created: EventData):
async def activate_time_triggers():
now = current_clock.get().now()
now = current_clock.get().timestamp()
# log.debug(f'activating time triggers at {now}')
# time triggers
for tt in tuple(time_triggers):

View File

@@ -54,7 +54,7 @@ async def line_passes(lc: tuple[float,float], is_min: bool, price: dec) -> bool:
b, m = lc
if b == 0 and m == 0:
return True
limit = m * current_clock.get().now() + b
limit = m * current_clock.get().timestamp() + b
log.debug(f'line passes {limit} {"<" if is_min else ">"} {price}')
# todo ratios
# prices AT the limit get zero volume, so we only trigger on >, not >=
@@ -98,7 +98,7 @@ class TrancheTrigger:
if tranche_remaining == 0 or tranche_remaining < self.order.min_fill_amount: # min_fill_amount could be 0 (disabled) so we also check for the 0 case separately
self._status = TrancheStatus.Filled
return
timestamp = current_clock.get().now()
timestamp = current_clock.get().timestamp()
self._status = \
TrancheStatus.Pricing if self.time_constraint is None else \
TrancheStatus.Early if timestamp < self.time_constraint[0] else \

View File

@@ -2,14 +2,16 @@ import asyncio
import logging
from typing import Optional
from dexorder import dec, db
from web3.exceptions import ContractLogicError
from dexorder import dec, db, ADDRESS_0
from dexorder.base.chain import current_chain
from dexorder.base.orderlib import Exchange
from dexorder.blockstate import BlockDict
from dexorder.blockstate.blockdata import K, V
from dexorder.contract.decimals import token_decimals
from dexorder.database.model.pool import Pool
from dexorder.uniswap import UniswapV3Pool
from dexorder.uniswap import UniswapV3Pool, uniswapV3_pool_address
log = logging.getLogger(__name__)
@@ -26,11 +28,22 @@ class Pools:
found = db.session.get(Pool, key)
if not found:
# todo other exchanges
t0, t1, fee = await asyncio.gather(UniswapV3Pool(address).token0(), UniswapV3Pool(address).token1(), UniswapV3Pool(address).fee())
found = Pool(chain=chain, address=address, exchange=Exchange.UniswapV3, base=t0, quote=t1, fee=fee)
try:
v3 = UniswapV3Pool(address)
t0, t1, fee = await asyncio.gather(v3.token0(), v3.token1(), v3.fee())
if uniswapV3_pool_address(t0, t1, fee) == address: # VALIDATE don't just trust
log.debug(f'new UniswapV3 pool at {address}')
found = Pool(chain=chain, address=address, exchange=Exchange.UniswapV3, base=t0, quote=t1, fee=fee)
db.session.add(found)
else: # NOT a genuine Uniswap V3 pool if the address test doesn't pass
log.debug(f'new Unknown pool at {address}')
found = Pool(chain=chain, address=address, exchange=Exchange.Unknown, base=ADDRESS_0, quote=ADDRESS_0, fee=0)
except ContractLogicError:
log.debug(f'new Unknown pool at {address}')
found = Pool(chain=chain, address=address, exchange=Exchange.Unknown, base=ADDRESS_0, quote=ADDRESS_0, fee=0)
db.session.add(found)
Pools.instances[key] = found
return found
return None if found.exchange == Exchange.Unknown else found
class PoolPrices (BlockDict[str, dec]):
@@ -71,18 +84,15 @@ async def ensure_pool_price(pool):
_pool_decimals = {}
async def pool_decimals(pool):
if pool.exchange != Exchange.UniswapV3:
raise ValueError
found = _pool_decimals.get(pool)
if found is None:
key = f'pd|{pool.address}'
try:
found = db.kv[key]
log.debug('got decimals from db')
except KeyError:
found = _pool_decimals[pool] = await token_decimals(pool.base) - await token_decimals(pool.quote)
decimals0 = await token_decimals(pool.base)
decimals1 = await token_decimals(pool.quote)
decimals = decimals0 - decimals1
db.kv[key] = decimals
log.debug(f'pool decimals: {decimals0} - {decimals1}')
log.debug(f'pool decimals {pool.address} {found}')
found = _pool_decimals[pool] = db.kv[key] = decimals0 - decimals1
return found
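# Example (well-known token decimals): for a WETH/USDC pool (18 and 6 decimals) pool_decimals returns 12,
# which presumably scales the raw sqrt-price ratio into a human-readable price.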

View File

@@ -1,10 +1,11 @@
import asyncio
import logging
from asyncio import Queue
from typing import Callable, Union, Any, Iterable
from typing import Union, Any, Iterable
from web3.contract.contract import ContractEvents
from web3.exceptions import LogTopicError, MismatchedABI
from web3.types import EventData
# noinspection PyPackageRequirements
from websockets.exceptions import ConnectionClosedError
@@ -17,44 +18,74 @@ from dexorder.blockstate.diff import DiffEntryItem
from dexorder.database.model import Block
from dexorder.database.model.block import current_block, latest_block
from dexorder.util import hexstr, topic
from dexorder.util.async_util import maywait
from dexorder.util.async_util import maywait, Maywaitable
log = logging.getLogger(__name__)
# todo detect reorgs and generate correct onHeadUpdate set by unioning the changes along the two forks, not including their common ancestor deltas
class BlockStateRunner:
def __init__(self, state: BlockState = None, *, publish_all=None):
"""
NOTE: This doc is old and not strictly true but still has the basic idea
1. load root state from blockchain
a. if no root, init from head
b. if root is old, batch forward by height
2. discover new heads
2b. find in-state parent block else use root
3. set the current fork = ancestor->head diff state
4. query blockchain event logs
5. process new vaults
6. process new orders and cancels
a. new pools
7. process Swap events and generate pool prices
8. process price horizons
9. process token movement
10. process swap triggers (zero constraint tranches)
11. process price tranche triggers
12. process horizon tranche triggers
13. filter by time tranche triggers
14. bundle execution requests and send tx. tx has require(block<deadline) todo execute deadlines
15. on tx confirmation, the block height of all executed trigger requests is set to the tx block
Most of these steps (the event-handling ones) are set up in main.py, so that datamain.py can also use the Runner for its own purposes.
"""
def __init__(self, state: BlockState = None, *, publish_all=None, timer_period: float = 1):
"""
If state is None, it is initialized as empty using the first block seen as the root block; log event handling then begins with the second block.
"""
self.state = state
# items are (callback, event, log_filter). The callback is invoked with web3 EventData for every detected event
self.events:list[tuple[Callable[[dict],None],ContractEvents,dict]] = []
self.events:list[tuple[Maywaitable[[EventData],None],ContractEvents,dict]] = []
# these callbacks are invoked after every block, and also every second if there wasn't a block
self.postprocess_cbs:list[Callable[[],None]] = []
self.postprocess_cbs:list[Maywaitable[[],None]] = []
# onStateInit callbacks are invoked after the initial state is loaded or created
self.on_state_init: list[Callable[[],None]] = []
self.on_state_init: list[Maywaitable[[],None]] = []
self.state_initialized = False
# onHeadUpdate callbacks are invoked with a list of DiffItems used to update the head state from either the previous head or the root
self.on_head_update: list[Callable[[Block,list[DiffEntryItem]],None]] = []
self.on_head_update: list[Maywaitable[[Block,list[DiffEntryItem]],None]] = []
# onPromotion callbacks are invoked with a list of DiffItems used to advance the root state
self.on_promotion: list[Callable[[Block,list[DiffEntryItem]],None]] = []
self.on_promotion: list[Maywaitable[[Block,list[DiffEntryItem]],None]] = []
self.publish_all: Callable[[Iterable[tuple[str,str,Any]]],None] = publish_all
self.publish_all: Maywaitable[[Iterable[tuple[str,str,Any]]],None] = publish_all
self.timer_period = timer_period
self.queue: Queue = Queue()
self.max_height_seen = config.backfill
self.running = False
def add_event_trigger(self, callback: Callable[[dict], None], event: ContractEvents = None, log_filter: Union[dict, str] = None):
def add_event_trigger(self, callback: Maywaitable[[EventData], None], event: ContractEvents = None, log_filter: Union[dict, str] = None):
"""
if event is None, the callback is still invoked in the series of log handlers but with no argument instead of logs
"""
@@ -62,42 +93,18 @@ class BlockStateRunner:
log_filter = {'topics': [topic(event.abi)]}
self.events.append((callback, event, log_filter))
def add_postprocess_trigger(self, callback: Callable[[dict], None]):
def add_postprocess_trigger(self, callback: Maywaitable[[], None]):
# noinspection PyTypeChecker
self.postprocess_cbs.append(callback)
async def run(self):
# this run() process discovers new heads and puts them on a queue for the worker to process; discovery uses either websockets or polling
if self.state:
self.max_height_seen = max(self.max_height_seen, self.state.root_block.height)
self.running = True
return await (self.run_polling() if config.polling > 0 else self.run_ws())
async def run_ws(self):
"""
NOTE: This doc is old and not strictly true but still the basic idea
1. load root stateBlockchain
a. if no root, init from head
b. if root is old, batch forward by height
2. discover new heads
2b. find in-state parent block else use root
3. fork = ancestor->head diff
4. query global log filter
5. process new vaults
6. process new orders and cancels
a. new pools
7. process Swap events and generate pool prices
8. process price horizons
9. process token movement
10. process swap triggers (zero constraint tranches)
11. process price tranche triggers
12. process horizon tranche triggers
13. filter by time tranche triggers
14. bundle execution requests and send tx. tx has require(block<deadline)
15. on tx confirmation, the block height of all executed trigger requests is set to the tx block
"""
self.running = True
w3ws = await create_w3_ws()
chain_id = await w3ws.eth.chain_id
chain = Blockchain.for_id(chain_id)
@@ -117,7 +124,7 @@ class BlockStateRunner:
async for message in w3ws.ws.listen_to_websocket():
head = message['result']
log.debug(f'detected new block {head["number"]} {hexstr(head["hash"])}')
await self.queue.put(head["hash"])
await self.add_head(head["hash"])
if not self.running:
break
await async_yield()
@@ -138,11 +145,9 @@ class BlockStateRunner:
"""
Hardhat websocket stops sending messages after about 5 minutes.
https://github.com/NomicFoundation/hardhat/issues/2053
So we must implement polling to work around their incompetence.
So we implement polling as a workaround.
"""
self.running = True
w3 = create_w3()
chain_id = await w3.eth.chain_id
chain = Blockchain.for_id(chain_id)
@@ -162,8 +167,8 @@ class BlockStateRunner:
head = block['hash']
if head != prev_blockhash:
prev_blockhash = head
await self.queue.put(block)
log.debug(f'polled new block {hexstr(head)}')
await self.add_head(block)
if not self.running:
break
await asyncio.sleep(config.polling)
@@ -180,6 +185,48 @@ class BlockStateRunner:
log.debug('runner run_polling() exiting')
async def add_head(self, head):
"""
head can be either a full block-data struct or simply a block hash. This method converts it to a Block
and pushes that Block onto the worker queue.
"""
chain = current_chain.get()
w3 = current_w3.get()
try:
block_data = head
blockhash = block_data['hash']
parent = block_data['parentHash']
height = block_data['number']
except TypeError:
blockhash = head
response = await w3.provider.make_request('eth_getBlockByHash', [blockhash, False])
block_data:dict = response['result']
parent = bytes.fromhex(block_data['parentHash'][2:])
height = int(block_data['number'], 0)
head = Block(chain=chain.chain_id, height=height, hash=blockhash, parent=parent, data=block_data)
if self.state or config.backfill:
# backfill batches
start_height = self.max_height_seen
batch_height = start_height + chain.batch_size - 1
while batch_height < head.height:
# the backfill is larger than a single batch, so we push intermediate head blocks onto the queue
response = await w3.provider.make_request('eth_getBlockByNumber', [hex(batch_height), False])
block_data: dict = response['result']
blockhash = bytes.fromhex(block_data['hash'][2:])
parent = bytes.fromhex(block_data['parentHash'][2:])
height = int(block_data['number'], 0)
assert height == batch_height
block = Block(chain=chain.chain_id, height=height, hash=blockhash, parent=parent, data=block_data)
log.debug(f'enqueueing batch backfill from {start_height} through {batch_height}')
await self.queue.put(block) # add an intermediate block
self.max_height_seen = height
start_height += chain.batch_size
batch_height += chain.batch_size
await self.queue.put(head) # add the head block
self.max_height_seen = head.height
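# Backfill walkthrough (illustrative numbers): with chain.batch_size = 1000, max_height_seen = 5000 and a
# new head at height 7500, intermediate blocks at heights 5999 and 6999 are enqueued first, then the head
# itself, so each handle_head() call only spans at most one batch of logs.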
async def worker(self):
try:
log.debug(f'runner worker started')
@@ -190,7 +237,10 @@ class BlockStateRunner:
prev_head = None
while self.running:
try:
async with asyncio.timeout(1): # check running flag every second
if self.timer_period:
async with asyncio.timeout(self.timer_period):
head = await self.queue.get()
else:
head = await self.queue.get()
except TimeoutError:
# the timer period has elapsed without a new head. Run the postprocess callbacks to check for activated time-based triggers
@@ -202,34 +252,22 @@ class BlockStateRunner:
prev_head = head
except Exception as x:
log.exception(x)
except:
except Exception:
log.exception('exception in runner worker')
raise
finally:
log.debug('runner worker exiting')
async def handle_head(self, chain, blockhash, w3):
# check blockhash type and convert
try:
block_data = blockhash
blockhash = block_data['hash']
parent = block_data['parentHash']
height = block_data['number']
except TypeError:
response = await w3.provider.make_request('eth_getBlockByHash', [blockhash, False])
block_data:dict = response['result']
parent = bytes.fromhex(block_data['parentHash'][2:])
height = int(block_data['number'], 0)
log.debug(f'processing block {blockhash}')
chain_id = chain.chain_id
async def handle_head(self, chain, block, w3):
print(f'logger {log} {log.name} level {log.level} {logging.DEBUG} {logging.FATAL}')
log.debug(f'handle_head {block.height} {block.hash}')
session = None
batches = []
try:
if self.state is not None and blockhash in self.state.by_hash:
log.debug(f'block {blockhash} was already processed')
if self.state is not None and block.hash in self.state.by_hash:
log.debug(f'block {block.hash} was already processed')
return
assert block_data is not None
block = Block(chain=chain_id, height=height, hash=blockhash, parent=parent, data=block_data)
latest_block.set(block)
current_clock.get().set(block.timestamp)
if self.state is None:
@@ -250,19 +288,16 @@ class BlockStateRunner:
if log_filter is None:
batches.append((None, callback, event, None))
else:
from_height = self.state.root_block.height + 1
end_height = block.height
while from_height <= end_height:
to_height = min(end_height, from_height + chain.batch_size - 1)
lf = dict(log_filter)
lf['fromBlock'] = from_height
lf['toBlock'] = to_height
log.debug(f'batch backfill {from_height} - {to_height}')
get_logs = w3.eth.get_logs(lf)
if not config.parallel_logevent_queries:
get_logs = await get_logs
batches.append((get_logs, callback, event, lf))
from_height += chain.batch_size
from_height = fork.parent.height
to_height = fork.height
lf = dict(log_filter)
lf['fromBlock'] = from_height
lf['toBlock'] = to_height
log.debug(f'querying backfill {from_height} through {to_height}')
get_logs = w3.eth.get_logs(lf)
if not config.parallel_logevent_queries:
get_logs = await get_logs
batches.append((get_logs, callback, event, lf))
for callback in self.postprocess_cbs:
batches.append((None, callback, None, None))
else:
@@ -307,7 +342,10 @@ class BlockStateRunner:
# todo try/except for known retryable errors
await maywait(callback(parsed))
# todo check for reorg and generate a reorg diff list
# todo
# IMPORTANT! check for a reorg and generate a reorg diff list. The diff list we need is the union of the keys touched by either
# branch. Then we query all the values for those keys and apply that kv list to redis, so any orphaned data that
# isn't updated by the new fork is re-read from the root state and overwrites stale data from the abandoned branch.
diff_items = self.state.diffs_by_hash[block.hash]
for callback in self.on_head_update:
await maywait(callback(block, diff_items))
@@ -319,7 +357,7 @@ class BlockStateRunner:
diff_items = self.state.promote_root(new_root_fork)
for callback in self.on_promotion:
# todo try/except for known retryable errors
callback(self.state.root_block, diff_items)
await maywait(callback(self.state.root_block, diff_items))
if pubs and self.publish_all:
await maywait(self.publish_all(pubs))
@@ -327,8 +365,14 @@ class BlockStateRunner:
log.debug('rolling back session')
if session is not None:
session.rollback()
if blockhash is not None and self.state is not None:
self.state.delete_block(blockhash)
if block.hash is not None and self.state is not None:
self.state.delete_block(block.hash)
if config.parallel_logevent_queries:
for get_logs, *_ in batches:
try:
await get_logs
except Exception:
log.exception('exception while querying logs')
raise
else:
if session is not None:

View File

@@ -22,11 +22,6 @@ class UniswapV3Pool (ContractProxy):
def __init__(self, address: str = None):
super().__init__(address, 'IUniswapV3Pool')
async def price(self):
if not self.address:
raise ValueError
return await uniswap_price(self.address, (await self.slot0())[0])
def ordered_addresses(addr_a:str, addr_b:str):
return (addr_a, addr_b) if addr_a.lower() <= addr_b.lower() else (addr_b, addr_a)

View File

@@ -1,5 +1,7 @@
import asyncio
import inspect
from abc import ABC
from typing import Union, Callable, Awaitable, TypeVar, Generic
async def async_yield():
@@ -7,7 +9,14 @@ async def async_yield():
await asyncio.sleep(1e-9)
async def maywait(obj):
Args = TypeVar('Args')
Return = TypeVar('Return')
class Maywaitable (Generic[Args, Return], Callable[[Args],Return], Awaitable[Return], ABC):
pass
async def maywait(obj: Maywaitable):
if inspect.isawaitable(obj):
obj = await obj
return obj
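# Usage sketch: maywait() lets runner callbacks be either plain functions or coroutines, e.g.
#   await maywait(callback(block, diff_items))  # awaits only if the callback returned an awaitable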