typing cleanup

This commit is contained in:
Tim
2024-01-25 18:20:23 -04:00
parent 96d54360b6
commit 5931dd5647
12 changed files with 33 additions and 41 deletions

View File

@@ -1,14 +1,18 @@
import logging import logging
import sys import sys
from asyncio import CancelledError from asyncio import CancelledError
from typing import Iterable, Union
from dexorder import blockchain, config from dexorder import blockchain, config
from dexorder.bin.executable import execute from dexorder.bin.executable import execute
from dexorder.blockstate import DiffItem
from dexorder.blockstate.blockdata import BlockData from dexorder.blockstate.blockdata import BlockData
from dexorder.blockstate.db_state import DbState from dexorder.blockstate.db_state import DbState
from dexorder.blockstate.diff import DiffEntryItem
from dexorder.configuration import parse_args from dexorder.configuration import parse_args
from dexorder.contract import get_contract_event from dexorder.contract import get_contract_event
from dexorder.database import db from dexorder.database import db
from dexorder.database.model import Block
from dexorder.event_handler import handle_uniswap_swap from dexorder.event_handler import handle_uniswap_swap
from dexorder.memcache.memcache_state import RedisState, publish_all from dexorder.memcache.memcache_state import RedisState, publish_all
from dexorder.memcache import memcache from dexorder.memcache import memcache
@@ -18,6 +22,10 @@ from dexorder.runner import BlockStateRunner
log = logging.getLogger('dexorder') log = logging.getLogger('dexorder')
def finalize_callback(block: Block, _diffs: Iterable[Union[DiffItem, DiffEntryItem]]):
log.info(f'backfill completed through block {block.height} {block.timestamp:%Y-%m-%d %H:%M:%S} {block.hash}')
# noinspection DuplicatedCode # noinspection DuplicatedCode
async def main(): async def main():
# noinspection DuplicatedCode # noinspection DuplicatedCode
@@ -47,16 +55,14 @@ async def main():
log.info(f'loaded state from db for root block {state.root_block}') log.info(f'loaded state from db for root block {state.root_block}')
runner = BlockStateRunner(state, publish_all=publish_all if redis_state else None, timer_period=0) runner = BlockStateRunner(state, publish_all=publish_all if redis_state else None, timer_period=0)
# noinspection PyTypeChecker
runner.add_event_trigger(handle_uniswap_swap, get_contract_event('IUniswapV3PoolEvents', 'Swap')) runner.add_event_trigger(handle_uniswap_swap, get_contract_event('IUniswapV3PoolEvents', 'Swap'))
# noinspection PyTypeChecker
runner.on_promotion.append(ohlc_finalize) runner.on_promotion.append(ohlc_finalize)
if db: if db:
# noinspection PyUnboundLocalVariable,PyTypeChecker # noinspection PyUnboundLocalVariable
runner.on_promotion.append(db_state.save) runner.on_promotion.append(db_state.save)
runner.on_promotion.append(finalize_callback)
if redis_state: if redis_state:
# noinspection PyTypeChecker
runner.on_head_update.append(redis_state.save) runner.on_head_update.append(redis_state.save)
try: try:

View File

@@ -97,15 +97,12 @@ async def main():
runner = BlockStateRunner(state, publish_all=publish_all if redis_state else None) runner = BlockStateRunner(state, publish_all=publish_all if redis_state else None)
setup_logevent_triggers(runner) setup_logevent_triggers(runner)
if config.ohlc_dir: if config.ohlc_dir:
# noinspection PyTypeChecker
runner.on_promotion.append(ohlc_finalize) runner.on_promotion.append(ohlc_finalize)
if db: if db:
# noinspection PyTypeChecker
runner.on_state_init.append(init_order_triggers) runner.on_state_init.append(init_order_triggers)
# noinspection PyUnboundLocalVariable,PyTypeChecker # noinspection PyUnboundLocalVariable
runner.on_promotion.append(db_state.save) runner.on_promotion.append(db_state.save)
if redis_state: if redis_state:
# noinspection PyTypeChecker
runner.on_head_update.append(redis_state.save) runner.on_head_update.append(redis_state.save)
try: try:

View File

@@ -97,7 +97,6 @@ class BlockState:
return Fork([block.hash], height=block.height) return Fork([block.hash], height=block.height)
if block.height - self.ancestors[block.hash].height > 1: if block.height - self.ancestors[block.hash].height > 1:
# noinspection PyTypeChecker
return DisjointFork(block, self.root_block) return DisjointFork(block, self.root_block)
def ancestors(): def ancestors():

View File

@@ -80,7 +80,6 @@ def from_toml(filename):
def parse_args(args=None): def parse_args(args=None):
""" should be called from binaries to parse args as command-line config settings """ """ should be called from binaries to parse args as command-line config settings """
# noinspection PyTypeChecker
try: try:
config.merge_with(OmegaConf.from_cli(args)) # updates config in-place. THANK YOU OmegaConf! config.merge_with(OmegaConf.from_cli(args)) # updates config in-place. THANK YOU OmegaConf!
except OmegaConfBaseException as x: except OmegaConfBaseException as x:

View File

@@ -69,24 +69,22 @@ class ContractProxy:
def events(self): def events(self):
return self.contract.events return self.contract.events
def deploy(self, *args): # def deploy(self, *args):
""" # """
Calls the contract constructor transaction and waits to receive a transaction receipt. # Calls the contract constructor transaction and waits to receive a transaction receipt.
""" # """
tx: ContractTransaction = self.transact.constructor(*args) # tx: ContractTransaction = self.transact.constructor(*args)
receipt = tx.wait() # receipt = tx.wait()
self.address = receipt.contractAddress # self.address = receipt.contractAddress
self._contracts.clear() # self._contracts.clear()
return receipt # return receipt
@property @property
def transact(self): def transact(self):
# noinspection PyTypeChecker
return ContractProxy(self.address, self._interface_name, _contracts=self._contracts, _wrapper=transact_wrapper, abi=self._abi) return ContractProxy(self.address, self._interface_name, _contracts=self._contracts, _wrapper=transact_wrapper, abi=self._abi)
@property @property
def build(self): def build(self):
# noinspection PyTypeChecker
return ContractProxy(self.address, self._interface_name, _contracts=self._contracts, _wrapper=build_wrapper, abi=self._abi) return ContractProxy(self.address, self._interface_name, _contracts=self._contracts, _wrapper=build_wrapper, abi=self._abi)
def __getattr__(self, item): def __getattr__(self, item):

View File

@@ -75,7 +75,6 @@ class Db:
if engine is None: if engine is None:
raise RuntimeError('Cannot create session: no database engine set. Use dexorder.db.connect() first') raise RuntimeError('Cannot create session: no database engine set. Use dexorder.db.connect() first')
s = Session(engine, expire_on_commit=False) s = Session(engine, expire_on_commit=False)
# noinspection PyTypeChecker
_session.set(s) _session.set(s)
return s return s

View File

@@ -210,7 +210,6 @@ async def activate_time_triggers():
# log.debug(f'activating time triggers at {now}') # log.debug(f'activating time triggers at {now}')
# time triggers # time triggers
for tt in tuple(time_triggers): for tt in tuple(time_triggers):
# noinspection PyTypeChecker
await maywait(tt(now)) await maywait(tt(now))
@@ -220,16 +219,13 @@ async def activate_price_triggers():
for pool, price in new_pool_prices.items(): for pool, price in new_pool_prices.items():
pools_triggered.add(pool) pools_triggered.add(pool)
for pt in tuple(price_triggers[pool]): for pt in tuple(price_triggers[pool]):
# noinspection PyTypeChecker
await maywait(pt(price)) await maywait(pt(price))
for pool, triggers in new_price_triggers.items(): for pool, triggers in new_price_triggers.items():
if pool not in pools_triggered: if pool not in pools_triggered:
price = pool_prices[pool] price = pool_prices[pool]
for pt in triggers: for pt in triggers:
# noinspection PyTypeChecker
await maywait(pt(price)) await maywait(pt(price))
for t in tuple(unconstrained_price_triggers): for t in tuple(unconstrained_price_triggers):
# noinspection PyTypeChecker
await maywait(t(None)) await maywait(t(None))

View File

@@ -54,7 +54,7 @@ class RedisState (SeriesCollection):
sdels: dict[str,set[str]] = defaultdict(set) sdels: dict[str,set[str]] = defaultdict(set)
hsets: dict[str,dict[str,str]] = defaultdict(dict) hsets: dict[str,dict[str,str]] = defaultdict(dict)
hdels: dict[str,set[str]] = defaultdict(set) hdels: dict[str,set[str]] = defaultdict(set)
pubs: list[tuple[str,str,list[Any]]] = [] # series, key, value => room, event, value pubs: list[tuple[str,str,Any]] = [] # series, key, value => room, event, value
for diff in compress_diffs(diffs): for diff in compress_diffs(diffs):
try: try:
d = self.datas[diff.series] d = self.datas[diff.series]

View File

@@ -1,7 +1,7 @@
import asyncio import asyncio
import logging import logging
from asyncio import Queue from asyncio import Queue
from typing import Union, Any, Iterable from typing import Union, Any, Iterable, Callable
from web3.contract.contract import ContractEvents from web3.contract.contract import ContractEvents
from web3.exceptions import LogTopicError, MismatchedABI from web3.exceptions import LogTopicError, MismatchedABI
@@ -60,22 +60,22 @@ class BlockStateRunner:
self.state = state self.state = state
# items are (callback, event, log_filter). The callback is invoked with web3 EventData for every detected event # items are (callback, event, log_filter). The callback is invoked with web3 EventData for every detected event
self.events:list[tuple[Maywaitable[[EventData],None],ContractEvents,dict]] = [] self.events:list[tuple[Callable[[EventData],Maywaitable[None]],ContractEvents,dict]] = []
# these callbacks are invoked after every block and also every second if there wasnt a block # these callbacks are invoked after every block and also every second if there wasnt a block
self.postprocess_cbs:list[Maywaitable[[],None]] = [] self.postprocess_cbs:list[Callable[[],Maywaitable[None]]] = []
# onStateInit callbacks are invoked after the initial state is loaded or created # onStateInit callbacks are invoked after the initial state is loaded or created
self.on_state_init: list[Maywaitable[[],None]] = [] self.on_state_init: list[Callable[[],Maywaitable[None]]] = []
self.state_initialized = False self.state_initialized = False
# onHeadUpdate callbacks are invoked with a list of DiffItems used to update the head state from either the previous head or the root # onHeadUpdate callbacks are invoked with a list of DiffItems used to update the head state from either the previous head or the root
self.on_head_update: list[Maywaitable[[Block,list[DiffEntryItem]],None]] = [] self.on_head_update: list[Callable[[Block,list[DiffEntryItem]],Maywaitable[None]]] = []
# onPromotion callbacks are invoked with a list of DiffItems used to advance the root state # onPromotion callbacks are invoked with a list of DiffItems used to advance the root state
self.on_promotion: list[Maywaitable[[Block,list[DiffEntryItem]],None]] = [] self.on_promotion: list[Callable[[Block,list[DiffEntryItem]],Maywaitable[None]]] = []
self.publish_all: Maywaitable[[Iterable[tuple[str,str,Any]]],None] = publish_all self.publish_all: Callable[[Iterable[tuple[str,str,Any]]],Maywaitable[None]] = publish_all
self.timer_period = timer_period self.timer_period = timer_period
@@ -85,7 +85,7 @@ class BlockStateRunner:
self.running = False self.running = False
def add_event_trigger(self, callback: Maywaitable[[EventData], None], event: ContractEvents = None, log_filter: Union[dict, str] = None): def add_event_trigger(self, callback: Callable[[EventData], Maywaitable[None]], event: ContractEvents = None, log_filter: Union[dict, str] = None):
""" """
if event is None, the callback is still invoked in the series of log handlers but with no argument instead of logs if event is None, the callback is still invoked in the series of log handlers but with no argument instead of logs
""" """

View File

@@ -17,14 +17,13 @@ def align_decimal(value, left_columns) -> str:
return ' ' * pad + s return ' ' * pad + s
def hexstr(value: bytes): def hexstr(value: Union[HexBytes, bytes, str]):
""" returns an 0x-prefixed hex string """ """ returns an 0x-prefixed hex string """
if type(value) is HexBytes: if type(value) is HexBytes:
return value.hex() return value.hex()
elif type(value) is bytes: elif type(value) is bytes:
return '0x' + value.hex() return '0x' + value.hex()
elif type(value) is str: elif type(value) is str:
# noinspection PyTypeChecker
return value if value.startswith('0x') else '0x' + value return value if value.startswith('0x') else '0x' + value
else: else:
raise ValueError raise ValueError

View File

@@ -12,9 +12,7 @@ async def async_yield():
Args = TypeVar('Args') Args = TypeVar('Args')
Return = TypeVar('Return') Return = TypeVar('Return')
class Maywaitable (Generic[Args, Return], Callable[[Args],Return], Awaitable[Return], ABC): Maywaitable = Union[Return, Awaitable[Return]]
pass
async def maywait(obj: Maywaitable): async def maywait(obj: Maywaitable):
if inspect.isawaitable(obj): if inspect.isawaitable(obj):

View File

@@ -1,8 +1,9 @@
import logging import logging
from typing import Never
log = logging.getLogger('dexorder') log = logging.getLogger('dexorder')
def fatal(message, exception=None): def fatal(message, exception=None) -> Never:
if exception is None and isinstance(message, (BaseException,RuntimeError)): if exception is None and isinstance(message, (BaseException,RuntimeError)):
exception = message exception = message
log.exception(message, exc_info=exception) log.exception(message, exc_info=exception)