refactor into TriggerRunner

This commit is contained in:
Tim Olson
2023-09-19 16:05:04 -04:00
parent 68647364cd
commit 0ff6a1ae0b
23 changed files with 379 additions and 171 deletions

View File

@@ -17,8 +17,8 @@ if config.config_file_name is not None:
# DEXORDER SETUP
from sys import path
path.append('src')
import dexorder.db.model
target_metadata = dexorder.db.model.Base.metadata
import dexorder.database.model
target_metadata = dexorder.database.model.Base.metadata
config.set_main_option('sqlalchemy.url', dexorder.config.db_url)
# other values from the config, defined by the needs of env.py,

View File

@@ -9,7 +9,7 @@ from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import dexorder.db
import dexorder.database
${imports if imports else ""}
# revision identifiers, used by Alembic.

View File

@@ -17,5 +17,5 @@ from .util import async_yield
from .base.fixed import Fixed2, FixedDecimals, Dec18
from .configuration import config
from .base.account import Account # must come before context
from .base.context import ctx
from .base.token import Token, tokens
from .database import db

View File

@@ -1,14 +1,14 @@
import logging
from collections import defaultdict
from contextvars import ContextVar
from logging import Logger
from typing import Union, TypeVar, Generic, Any
from sortedcontainers import SortedList
from dexorder import NARG
from dexorder.db.model.block import Block
from dexorder.database.model.block import Block
log = Logger('dexorder.blockstate')
log = logging.getLogger(__name__)
class BlockState:

View File

@@ -1,3 +1,6 @@
from typing import Union
from defaultlist import defaultlist
from eth_utils import keccak
from dexorder.base.blockstate import BlockDict
@@ -7,18 +10,17 @@ class EventManager:
def __init__(self):
self.all_topics = set()
self.triggers:dict[str,BlockDict] = {}
self.rings = defaultlist(list)
def add_handler(self, topic: str, callback):
if not topic.startswith('0x'):
topic = '0x'+keccak(text=topic).hex().lower()
def add_handler(self, topic: Union[bytes,str], callback):
if type(topic) is str:
topic = bytes.fromhex(topic[2:]) if topic.startswith('0x') else keccak(text=topic)
triggers = self.triggers.get(topic)
if triggers is None:
triggers = self.triggers[topic] = BlockDict(topic)
triggers.add(callback)
self.all_topics.add(topic)
def handle_logs(self, logs):
for log in logs:
for callback, _ in self.triggers.get(log.topics[0].hex(), []).items():
callback(log)
def publish_topic(self, topic, data):
for callback, _ in self.triggers.get(topic, {}).items():
callback(data)

View File

@@ -4,11 +4,11 @@ from decimal import Decimal
from sqlalchemy.orm import Mapped
from web3 import Web3
from dexorder import config, ctx, Blockchain, NARG, FixedDecimals, ADDRESS_0
from dexorder import config, Blockchain, NARG, FixedDecimals, ADDRESS_0
from dexorder.blockchain import ByBlockchainDict
from dexorder.base.chain import Polygon, ArbitrumOne, Ethereum
from dexorder.contract import ContractProxy, abis
import dexorder.db.column as col
import dexorder.database.column as col
class Token (ContractProxy, FixedDecimals):

View File

@@ -41,7 +41,8 @@ def execute(main:Coroutine, shutdown=None, parse_args=True):
loop.run_until_complete(task)
x = task.exception()
if x is not None:
print_exception(x)
if x.__class__ not in ignorable_exceptions:
print_exception(x)
for t in asyncio.all_tasks():
t.cancel()
else:

View File

@@ -1,122 +1,17 @@
import logging
from asyncio import CancelledError
from hexbytes import HexBytes
from web3 import AsyncWeb3, WebsocketProviderV2, AsyncHTTPProvider
from web3.types import FilterParams
from dexorder import config, Blockchain
from dexorder.base.blockstate import BlockState, BlockDict
from dexorder.base.event_manager import EventManager
from dexorder.bin.executable import execute
from dexorder.configuration import resolve_rpc_url
from dexorder.db.model import Block
from dexorder.trigger_runner import TriggerRunner
log = logging.getLogger('dexorder')
ROOT_AGE = 10 # todo set per chain
wallets = BlockDict('wallets')
def handle_transfer(event):
to_address = event.topics[2].hex()
wallets.add(to_address)
def setup_triggers(event_manager: EventManager):
event_manager.add_handler('Transfer(address,address,uint256)', handle_transfer)
async def main():
"""
1. load root stateBlockchain
a. if no root, init from head
b. if root is old, batch forward by height
2. discover new heads
2b. find in-memory ancestor else use root
3. context = ancestor->head diff
4. query global log filter
5. process new vaults
6. process new orders and cancels
a. new pools
7. process Swap events and generate pool prices
8. process price horizons
9. process token movement
10. process swap triggers (zero constraint tranches)
11. process price tranche triggers
12. process horizon tranche triggers
13. filter by time tranche triggers
14. bundle execution requests and send tx. tx has require(block<deadline)
15. on tx confirmation, the block height of all executed trigger requests is set to the tx block
"""
# db.connect()
# blockchain.connect()
ws_provider = WebsocketProviderV2(resolve_rpc_url(config.ws_url))
w3ws = AsyncWeb3.persistent_websocket(ws_provider)
http_provider = AsyncHTTPProvider(resolve_rpc_url(config.rpc_url))
w3 = AsyncWeb3(http_provider)
# w3.middleware_onion.remove('attrdict')
try:
chain_id = await w3ws.eth.chain_id
Blockchain.set_cur(Blockchain.for_id(chain_id))
event_manager = EventManager()
# todo load root
state = None
async with w3ws as w3ws:
await w3ws.eth.subscribe('newHeads')
while True:
async for head in w3ws.listen_to_websocket():
log.debug('head', head)
block_data = await w3.eth.get_block(head.hash.hex(), True)
block = Block(chain=chain_id,height=block_data.number,hash=block_data.hash,parent=block_data.parentHash,data=block_data)
block.set_latest(block)
block.set_cur(block)
if state is None:
state = BlockState(block,{})
BlockState.set_cur(state)
setup_triggers(event_manager)
log.info('Created new empty root state')
else:
ancestor = BlockState.cur().add_block(block)
if ancestor is None:
log.debug(f'discarded late-arriving head {block}')
elif type(ancestor) is int:
# todo backfill batches
log.error(f'backfill unimplemented for range {ancestor} to {block}')
else:
logs_filter = FilterParams(topics=list(event_manager.all_topics), blockhash=HexBytes(block.hash).hex())
log.debug(f'get logs {logs_filter}')
logs = await w3.eth.get_logs(logs_filter)
if logs:
log.debug('handle logs')
event_manager.handle_logs(logs)
# check for root promotion
if block.height - state.root_block.height > ROOT_AGE:
b = block
try:
for _ in range(1,ROOT_AGE):
# we walk backwards ROOT_AGE and promote what's there
b = state.by_hash[b.parent]
except KeyError:
pass
else:
log.debug(f'promoting root {b}')
state.promote_root(b)
log.debug('wallets: '+' '.join(k for k,_ in wallets.items()))
except CancelledError:
pass
finally:
if ws_provider.is_connected():
await ws_provider.disconnect()
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
log = logging.getLogger('dexorder')
log.setLevel(logging.DEBUG)
execute(main())
execute(TriggerRunner().run())
log.info('exiting')

View File

@@ -1,6 +1,6 @@
from typing import Generic, TypeVar, Any, Iterator
from dexorder import ctx, NARG
from dexorder import NARG
_T = TypeVar('_T')

View File

@@ -1,9 +1,24 @@
from web3 import HTTPProvider, Web3
from web3.middleware import geth_poa_middleware, simple_cache_middleware
from contextvars import ContextVar
from hexbytes import HexBytes
from web3 import WebsocketProviderV2, AsyncWeb3, AsyncHTTPProvider
from dexorder import ctx
from dexorder.blockchain.util import get_contract_data
from ..configuration import resolve_rpc_url
from ..configuration.resolve import resolve_ws_url
_w3 = ContextVar('w3')
class W3:
@staticmethod
def cur() -> AsyncWeb3:
return _w3.get()
@staticmethod
def set_cur(value:AsyncWeb3):
_w3.set(value)
def connect(rpc_url=None):
@@ -13,11 +28,28 @@ def connect(rpc_url=None):
use create_w3() and set w3.eth.default_account separately
"""
w3 = create_w3(rpc_url)
ctx.w3 = w3
W3.set_cur(w3)
return w3
def create_w3(rpc_url=None):
# todo create a proxy w3 that rotates among rpc urls
# self.w3s = tuple(create_w3(url) for url in rpc_url_or_tag)
# chain_id = self.w3s[0].eth.chain_id
# assert all(w3.eth.chain_id == chain_id for w3 in self.w3s) # all rpc urls must be the same blockchain
# self.w3iter = itertools.cycle(self.w3s)
url = resolve_rpc_url(rpc_url)
w3 = AsyncWeb3(AsyncHTTPProvider(url))
# w3.middleware_onion.inject(geth_poa_middleware, layer=0) # todo is this line needed?
# w3.middleware_onion.add(simple_cache_middleware)
w3.middleware_onion.remove('attrdict')
w3.middleware_onion.add(clean_input_async, 'clean_input')
w3.eth.Contract = _make_contract(w3.eth)
return w3
def create_w3_ws(ws_url=None):
"""
this constructs a Web3 object but does NOT attach it to the context. consider using connect(...) instead
this does *not* attach any signer to the w3. make sure to inject the proper middleware with Account.attach(w3)
@@ -27,15 +59,44 @@ def create_w3(rpc_url=None):
# chain_id = self.w3s[0].eth.chain_id
# assert all(w3.eth.chain_id == chain_id for w3 in self.w3s) # all rpc urls must be the same blockchain
# self.w3iter = itertools.cycle(self.w3s)
url = resolve_rpc_url(rpc_url)
w3 = Web3(HTTPProvider(url))
w3.middleware_onion.inject(geth_poa_middleware, layer=0)
w3.middleware_onion.add(simple_cache_middleware)
ws_provider = WebsocketProviderV2(resolve_ws_url(ws_url))
w3 = AsyncWeb3.persistent_websocket(ws_provider)
w3.middleware_onion.remove('attrdict')
# w3.middleware_onion.add(clean_input, 'clean_input')
w3.eth.Contract = _make_contract(w3.eth)
return w3
def _clean(obj):
if type(obj) is HexBytes:
return bytes(obj)
elif type(obj) is list:
return [_clean(v) for v in obj]
elif type(obj) is dict:
return {k:_clean(v) for k,v in obj.items()}
return obj
def _make_clean_input_middleware(make_request,_w3):
def _clean_input(method, params):
# do pre-processing here
# perform the RPC request, getting the response
response = make_request(method, params)
# do post-processing here
response = _clean(response)
# finally return the response
return response
return _clean_input
async def clean_input_async(make_request, w3):
# do one-time setup operations here
return _make_clean_input_middleware(make_request, w3)
def clean_input(make_request, w3):
# do one-time setup operations here
return _make_clean_input_middleware(make_request, w3)
def _make_contract(w3_eth):
def f(address, abi_or_name): # if abi, then it must already be in native object format, not a string
if type(abi_or_name) is str:

View File

@@ -11,3 +11,15 @@ def resolve_rpc_url(rpc_url=None):
except KeyError:
pass
return rpc_url
def resolve_ws_url(ws_url=None):
if ws_url is None:
ws_url = config.ws_url
if ws_url == 'test':
return 'ws://localhost:8545'
try:
return config.rpc_urls[ws_url] # look up aliases
except KeyError:
pass
return ws_url

View File

@@ -0,0 +1,67 @@
from contextvars import ContextVar
import sqlalchemy
from sqlalchemy import Engine
from sqlalchemy.orm import Session, SessionTransaction
from .migrate import migrate_database
from .. import config
_engine = ContextVar[Engine]('engine', default=None)
_session = ContextVar[Session]('session', default=None)
class Db:
def transaction(self) -> SessionTransaction:
"""
this type of block should be at the top-level of any group of db operations. it will automatically commit
and close the session at the end of the scope
```
with db.transaction():
do_database_stuff()
```
if you want to do manual commits, use:
```
with db.session:
do_database_stuff()
```
"""
return self.session.begin()
@property
def session(self) -> Session:
s = _session.get()
if s is None:
engine = _engine.get()
if engine is None:
raise RuntimeError('Cannot create session: no database engine set. Use dexorder.db.connect() first')
s = Session(engine, expire_on_commit=False)
# noinspection PyTypeChecker
_session.set(s)
return s
# noinspection PyShadowingNames
@staticmethod
def connect(url=None, migrate=True, reconnect=False, dump_sql=None):
if _engine.get() is not None and not reconnect:
return
if url is None:
url = config.db_url
if dump_sql is None:
dump_sql = config.dump_sql
engine = sqlalchemy.create_engine(url, echo=dump_sql)
if migrate:
migrate_database()
with engine.connect() as connection:
connection.execute(sqlalchemy.text("SET TIME ZONE 'UTC'"))
result = connection.execute(sqlalchemy.text("select version_num from alembic_version"))
for row in result:
print(f'database revision {row[0]}')
_engine.set(engine)
return db
raise Exception('database version not found')
db = Db()

View File

@@ -1,4 +1,6 @@
from hexbytes import HexBytes
from sqlalchemy import SMALLINT, INTEGER, BIGINT
from sqlalchemy.dialects.postgresql import BYTEA
from sqlalchemy.orm import mapped_column
from typing_extensions import Annotated
@@ -75,6 +77,8 @@ Int256 = Annotated[int, mapped_column(t.IntBits(256, True))]
Address = Annotated[str, mapped_column(t.Address())]
Bytes = Annotated[bytes, mapped_column(BYTEA)]
BlockCol = Annotated[int, mapped_column(BIGINT)]
Blockchain = Annotated[NativeBlockchain, mapped_column(t.Blockchain)]

View File

@@ -1,7 +1,5 @@
from sqlalchemy.orm import DeclarativeBase, declared_attr
from dexorder import ctx
# add Base as the -last- class inherited on classes which should get tables
class Base(DeclarativeBase):
@@ -9,8 +7,3 @@ class Base(DeclarativeBase):
@declared_attr.directive
def __tablename__(cls) -> str:
return cls.__name__.lower()
@classmethod
def get(cls, **kwargs):
return ctx.session.get(cls, kwargs)

View File

@@ -3,7 +3,7 @@ from contextvars import ContextVar
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from dexorder.db.model import Base
from dexorder.database.model import Base
class Block(Base):
@@ -11,10 +11,10 @@ class Block(Base):
height: Mapped[int] = mapped_column(primary_key=True) # timescaledb index
hash: Mapped[bytes] = mapped_column(primary_key=True)
parent: Mapped[bytes]
data: Mapped[dict] = mapped_column(JSONB)
data: Mapped[dict] = mapped_column('data',JSONB)
def __str__(self):
return f'{self.height}_{self.hash.hex()}'
return f'{self.height}_{self.hash}'
@staticmethod
def cur() -> 'Block':

View File

@@ -0,0 +1,12 @@
import logging
from sqlalchemy.orm import Mapped, mapped_column
from dexorder.database.column import Address
from dexorder.database.model import Base
log = logging.getLogger(__name__)
class VaultToken (Base):
vault:Mapped[Address] = mapped_column(primary_key=True)
token:Mapped[Address] = mapped_column(primary_key=True)

View File

@@ -1,25 +0,0 @@
import sqlalchemy
from .migrate import migrate_database
from .. import config, ctx
# noinspection PyShadowingNames
def connect(url=None, migrate=True, reconnect=False, dump_sql=None):
if ctx.engine is not None and not reconnect:
return
if url is None:
url = config.db_url
if dump_sql is None:
dump_sql = config.dump_sql
engine = sqlalchemy.create_engine(url, echo=dump_sql)
if migrate:
migrate_database()
with engine.connect() as connection:
connection.execute(sqlalchemy.text("SET TIME ZONE 'UTC'"))
result = connection.execute(sqlalchemy.text("select version_num from alembic_version"))
for row in result:
print(f'database revision {row[0]}')
ctx.engine = engine
return
raise Exception('database version not found')

View File

@@ -0,0 +1,166 @@
import asyncio
import logging
from typing import Callable, Union
from web3 import AsyncWeb3
from web3.contract.contract import ContractEvents
from web3.exceptions import LogTopicError
from dexorder import Blockchain, db, blockchain
from dexorder.base.blockstate import BlockState, BlockDict
from dexorder.blockchain.connection import create_w3_ws, W3
from dexorder.blockchain.util import get_contract_data
from dexorder.database.model import Block
from dexorder.database.model.vault_tokens import VaultToken
from dexorder.util import hexstr, topic
log = logging.getLogger(__name__)
vault_addresses = BlockDict('v')
underfunded_vaults = BlockDict('ufv')
active_orders = BlockDict('a')
pool_prices = BlockDict('p')
wallets = BlockDict('wallets') # todo remove debug
class TriggerRunner:
def __init__(self):
self.root_age = 10 # todo set per chain
self.events:list[tuple[Callable[[dict],None],ContractEvents,dict]] = []
async def run(self):
"""
1. load root stateBlockchain
a. if no root, init from head
b. if root is old, batch forward by height
2. discover new heads
2b. find in-memory ancestor else use root
3. context = ancestor->head diff
4. query global log filter
5. process new vaults
6. process new orders and cancels
a. new pools
7. process Swap events and generate pool prices
8. process price horizons
9. process token movement
10. process swap triggers (zero constraint tranches)
11. process price tranche triggers
12. process horizon tranche triggers
13. filter by time tranche triggers
14. bundle execution requests and send tx. tx has require(block<deadline)
15. on tx confirmation, the block height of all executed trigger requests is set to the tx block
"""
db.connect()
w3 = blockchain.connect()
w3ws = create_w3_ws()
chain_id = await w3ws.eth.chain_id
Blockchain.set_cur(Blockchain.for_id(chain_id))
# todo load root
state = None
async with w3ws as w3ws:
await w3ws.eth.subscribe('newHeads')
while True:
async for head in w3ws.listen_to_websocket():
session = None
try:
log.debug('head', head)
# block_data = await w3.eth.get_block(head['hash'], True)
block_data = (await w3.provider.make_request('eth_getBlockByHash',[hexstr(head['hash']),False]))['result']
block = Block(chain=chain_id, height=int(block_data['number'],0),
hash=bytes.fromhex(block_data['hash'][2:]), parent=bytes.fromhex(block_data['parentHash'][2:]), data=block_data)
block.set_latest(block)
block.set_cur(block)
if state is None:
state = BlockState(block, {})
BlockState.set_cur(state)
self.setup_triggers(w3)
log.info('Created new empty root state')
else:
ancestor = BlockState.cur().add_block(block)
if ancestor is None:
log.debug(f'discarded late-arriving head {block}')
elif type(ancestor) is int:
# todo backfill batches
log.error(f'backfill unimplemented for range {ancestor} to {block}')
else:
futures = []
for callback, event, log_filter in self.events:
log_filter['blockhash'] = w3.to_hex(block.hash)
futures.append(w3.eth.get_logs(log_filter))
results = await asyncio.gather(*futures)
if session is None:
session = db.session
session.begin()
session.add(block)
for result, (callback,event,filter_args) in zip(results,self.events):
for log_event in result:
callback(log_event)
# check for root promotion
if block.height - state.root_block.height > self.root_age:
b = block
try:
for _ in range(1, self.root_age):
# we walk backwards self.root_age and promote what's there
b = state.by_hash[b.parent]
except KeyError:
pass
else:
log.debug(f'promoting root {b}')
state.promote_root(b)
except:
if session is not None:
session.rollback()
raise
else:
if session is not None:
session.commit()
def handle_transfer(self, event):
w3 = W3.cur()
try:
transfer = w3.eth.contract(abi=get_contract_data('ERC20')['abi']).events.Transfer().process_log(event)
except LogTopicError:
return
to_address = transfer['args']['to']
print('transfer', to_address)
if to_address in vault_addresses:
# todo publish event to vault watchers
db.session.add(VaultToken(vault=to_address, token=event.address))
if to_address in underfunded_vaults:
# todo flag underfunded vault (check token type?)
pass
BlockDict('wallets').add(to_address)
def handle_swap(self, event):
w3 = W3.cur()
try:
swap = w3.eth.contract(abi=get_contract_data('IUniswapV3PoolEvents')['abi']).events.Swap().process_log(event)
except LogTopicError:
return
try:
sqrt_price = swap['args']['sqrtPriceX96']
except KeyError:
return
addr = event['address']
price = sqrt_price * sqrt_price / 2**(96*2)
print(f'pool {addr} {price}')
# pool_prices[addr] =
def add_event_trigger(self, callback:Callable[[dict],None], event: ContractEvents, log_filter: Union[dict,str]=None):
if log_filter is None:
log_filter = {'topics':[topic(event.abi)]}
self.events.append((callback, event, log_filter))
def setup_triggers(self, w3: AsyncWeb3):
transfer = w3.eth.contract(abi=get_contract_data('ERC20')['abi']).events.Transfer()
self.add_event_trigger(self.handle_transfer, transfer)
swap = w3.eth.contract(abi=get_contract_data('IUniswapV3PoolEvents')['abi']).events.Swap()
self.add_event_trigger(self.handle_swap, swap)

View File

@@ -1,5 +1,8 @@
import re
from eth_utils import keccak
from hexbytes import HexBytes
from .async_yield import async_yield
from .tick_math import nearest_available_ticks, round_tick, spans_tick, spans_range
@@ -12,3 +15,20 @@ def align_decimal(value, left_columns) -> str:
pad = max(left_columns - len(re.sub(r'[^0-9]*$','',s.split('.')[0])), 0)
return ' ' * pad + s
def hexstr(value):
""" returns an 0x-prefixed hex string """
if type(value) is HexBytes:
return value.hex()
elif type(value) is bytes:
return '0x'+bytes.hex()
elif type(value) is str:
return value if value.startswith('0x') else '0x' + value
else:
raise ValueError
def topic(event_abi):
event_name = f'{event_abi["name"]}(' + ','.join(i['type'] for i in event_abi['inputs']) + ')'
result = '0x' + keccak(text=event_name).hex()
print(event_name, result)
return result

View File

@@ -1,5 +1,5 @@
from dexorder.base.blockstate import BlockState, BlockDict
from dexorder.db.model.block import Block
from dexorder.database.model.block import Block
block_10 = Block(chain=1, height=10, hash=bytes.fromhex('10'), parent=bytes.fromhex('09'), data=None)
block_11a = Block(chain=1, height=11, hash=bytes.fromhex('1a'), parent=block_10.hash, data=None)