diff --git a/requirements.txt b/requirements.txt index 62acf63..5ed5b4c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ orjson~=3.9.7 sortedcontainers~=2.4.0 hexbytes~=0.3.1 defaultlist~=1.0.0 +redis[hiredis] diff --git a/src/dexorder/bin/main.py b/src/dexorder/bin/main.py index 53437b7..08cbac3 100644 --- a/src/dexorder/bin/main.py +++ b/src/dexorder/bin/main.py @@ -1,32 +1,53 @@ import logging +from asyncio import CancelledError from dexorder import db, config, Blockchain from dexorder.base.chain import current_chain from dexorder.bin.executable import execute +from dexorder.blockstate import DiffItem, DiffEntry from dexorder.blockstate.blockdata import BlockData from dexorder.blockstate.db_state import DbState from dexorder.configuration import parse_args +from dexorder.memcache.memcache_state import RedisState +from dexorder.memcache import memcache from dexorder.runner import BlockStateRunner log = logging.getLogger('dexorder') -if __name__ == '__main__': +async def main(): logging.basicConfig(level=logging.INFO) - log = logging.getLogger('dexorder') log.setLevel(logging.DEBUG) parse_args() current_chain.set(Blockchain.get(config.chain)) + redis_state = None state = None + if memcache: + await memcache.connect() + redis_state = RedisState(BlockData.by_tag['redis']) if db: db.connect() db_state = DbState(BlockData.by_tag['db']) with db.session: state = db_state.load() - log.info(f'loaded state from db for root block {state.root_block}') + if state is not None: + if redis_state: + await redis_state.init(state) + log.info(f'loaded state from db for root block {state.root_block}') + runner = BlockStateRunner(state) if db: # noinspection PyUnboundLocalVariable runner.on_promotion.append(db_state.save) - execute(runner.run()) # single task + if redis_state: + # noinspection PyTypeChecker + runner.on_head_update.append(redis_state.save) + try: + await runner.run() + except CancelledError: + pass log.info('exiting') + + +if __name__ == '__main__': + execute(main()) diff --git a/src/dexorder/blockstate/__init__.py b/src/dexorder/blockstate/__init__.py index 1dafa73..ac23173 100644 --- a/src/dexorder/blockstate/__init__.py +++ b/src/dexorder/blockstate/__init__.py @@ -1,70 +1,70 @@ from .diff import DiffEntry, DiffItem, DELETE from .state import BlockState, current_blockstate -from .blockdata import BlockDict, BlockSet +from .blockdata import DataType, BlockDict, BlockSet -def _test(): - - def B(height, hash:str, parent): - return Block(chain=1337, height=height, hash=hash.encode('utf8'), parent=None if parent is None else parent.hash, data=None) - - root_block = B(10, '#root', None ) - state = BlockState(root_block) - current_blockstate.set(state) - b11 = B(11, '#b11', parent=root_block) - f11: Fork = state.add_block(b11) - print('f11',f11) - b11b = B(11, '#b11b', parent=root_block) - f11b: Fork = state.add_block(b11b) - print('f11b',f11b) - b12 = B(12, '#b12', parent=b11) - f12: Fork = state.add_block(b12) - print('f12',f12) - - d = BlockDict('ser') - - def dump(): - print() - print(current_fork.get().hash if current_fork.get() is not None else 'root') - for k,v in d.items(): - print(f'{k} = {v}') - - current_fork.set(None) # Use None to set values on root - d['foo'] = 'bar' - d['test'] = 'failed' - - current_fork.set(f11) - d['foo2'] = 'bar2' - del d['test'] - - current_fork.set(f11b) - del d['foo2'] - d['foob'] = 'barb' - - current_fork.set(f12) - d['test'] = 'ok' - - for f in (None, f11, f11b, f12): - current_fork.set(f) - dump() - - print() - print('all b12 diffs') - for i in state.collect_diffs(b12): - print(i) - - print() - print('promoting b11') - state.promote_root(f11) - current_fork.set(f12) - dump() - - print() - print('promoting b12') - state.promote_root(f12) - current_fork.set(f12) - dump() - - -if __name__ == '__main__': - _test() +# def _test(): +# +# def B(height, hash:str, parent): +# return Block(chain=1337, height=height, hash=hash.encode('utf8'), parent=None if parent is None else parent.hash, data=None) +# +# root_block = B(10, '#root', None ) +# state = BlockState(root_block) +# current_blockstate.set(state) +# b11 = B(11, '#b11', parent=root_block) +# f11: Fork = state.add_block(b11) +# print('f11',f11) +# b11b = B(11, '#b11b', parent=root_block) +# f11b: Fork = state.add_block(b11b) +# print('f11b',f11b) +# b12 = B(12, '#b12', parent=b11) +# f12: Fork = state.add_block(b12) +# print('f12',f12) +# +# d = BlockDict('ser') +# +# def dump(): +# print() +# print(current_fork.get().hash if current_fork.get() is not None else 'root') +# for k,v in d.items(): +# print(f'{k} = {v}') +# +# current_fork.set(None) # Use None to set values on root +# d['foo'] = 'bar' +# d['test'] = 'failed' +# +# current_fork.set(f11) +# d['foo2'] = 'bar2' +# del d['test'] +# +# current_fork.set(f11b) +# del d['foo2'] +# d['foob'] = 'barb' +# +# current_fork.set(f12) +# d['test'] = 'ok' +# +# for f in (None, f11, f11b, f12): +# current_fork.set(f) +# dump() +# +# print() +# print('all b12 diffs') +# for i in state.collect_diffs(b12): +# print(i) +# +# print() +# print('promoting b11') +# state.promote_root(f11) +# current_fork.set(f12) +# dump() +# +# print() +# print('promoting b12') +# state.promote_root(f12) +# current_fork.set(f12) +# dump() +# +# +# if __name__ == '__main__': +# _test() diff --git a/src/dexorder/blockstate/blockdata.py b/src/dexorder/blockstate/blockdata.py index 24cace1..b9f22d5 100644 --- a/src/dexorder/blockstate/blockdata.py +++ b/src/dexorder/blockstate/blockdata.py @@ -1,7 +1,7 @@ import logging from collections import defaultdict from enum import Enum -from typing import TypeVar, Generic, Iterable +from typing import TypeVar, Generic, Iterable, Union from dexorder import NARG from dexorder.base.fork import current_fork @@ -12,17 +12,18 @@ log = logging.getLogger(__name__) T = TypeVar('T') -class BlockData: - class Type (Enum): - SCALAR:int = 0 - SET:int = 1 - LIST:int = 2 - DICT:int = 3 +class DataType(Enum): + SCALAR: int = 0 + SET: int = 1 + LIST: int = 2 + DICT: int = 3 + +class BlockData: registry: dict[str,'BlockData'] = {} # series name and instance by_tag: dict[str, list['BlockData']] = defaultdict(list) - def __init__(self, series:str, data_type: Type, **tags): + def __init__(self, series:str, data_type: DataType, **tags): assert series not in BlockData.registry BlockData.registry[series] = self self.series = series @@ -60,7 +61,7 @@ class BlockData: class BlockSet(Generic[T], Iterable[T], BlockData): def __init__(self, series: str, **tags): - super().__init__(series, BlockData.Type.SET, **tags) + super().__init__(series, DataType.SET, **tags) self.series = series def add(self, item): @@ -80,7 +81,7 @@ class BlockSet(Generic[T], Iterable[T], BlockData): class BlockDict(Generic[T], BlockData): def __init__(self, series: str, **tags): - super().__init__(series, BlockData.Type.DICT, **tags) + super().__init__(series, DataType.DICT, **tags) def __setitem__(self, item, value): self.setitem(item, value) @@ -99,3 +100,11 @@ class BlockDict(Generic[T], BlockData): def get(self, item, default=None): return self.getitem(item, default) + + +class SeriesCollection: + def __init__(self, series_or_datavars: Iterable[Union[str,BlockData]]): + self.types = { + (d:=BlockData.registry[x] if type(x) is str else x).series:d.type + for x in series_or_datavars + } diff --git a/src/dexorder/blockstate/db_state.py b/src/dexorder/blockstate/db_state.py index 23149b3..385d468 100644 --- a/src/dexorder/blockstate/db_state.py +++ b/src/dexorder/blockstate/db_state.py @@ -1,8 +1,9 @@ import logging from typing import Iterable, Optional, Union -from . import DiffItem, BlockSet, BlockDict, DELETE, BlockState, current_blockstate -from .blockdata import BlockData +from . import DiffItem, BlockSet, BlockDict, DELETE, BlockState, current_blockstate, DataType +from .blockdata import BlockData, SeriesCollection +from .diff import DiffEntryItem from .. import db from ..base.chain import current_chain from ..base.fork import current_fork @@ -13,14 +14,9 @@ from ..util import keystr, strkey, hexbytes log = logging.getLogger(__name__) -class DbState: - def __init__(self, series_or_datavars: Iterable[Union[str,BlockData]]): - self.types = { - (d:=BlockData.registry[x] if type(x) is str else x).series:d.type - for x in series_or_datavars - } +class DbState(SeriesCollection): - def save(self, root_block: Block, diffs: Iterable[DiffItem] ): + def save(self, root_block: Block, diffs: Iterable[Union[DiffItem,DiffEntryItem]] ): chain_id = current_chain.get().chain_id for diff in diffs: try: @@ -30,21 +26,21 @@ class DbState: diffseries = keystr(diff.series) diffkey = keystr(diff.key) key = dict(chain=chain_id, series=diffseries, key=diffkey) - if diff.entry.value is DELETE: - Entity = SeriesSet if t == BlockData.Type.SET else SeriesDict if t == BlockData.Type.DICT else None + if diff.value is DELETE: + Entity = SeriesSet if t == DataType.SET else SeriesDict if t == DataType.DICT else None db.session.query(Entity).filter(Entity.chain==chain_id, Entity.series==diffseries, Entity.key==diffkey).delete() else: # upsert - if t == BlockData.Type.SET: + if t == DataType.SET: found = db.session.get(SeriesSet, key) if found is None: db.session.add(SeriesSet(**key)) - elif t == BlockData.Type.DICT: + elif t == DataType.DICT: found = db.session.get(SeriesDict, key) if found is None: - db.session.add(SeriesDict(**key, value=diff.entry.value)) + db.session.add(SeriesDict(**key, value=diff.value)) else: - found.value = diff.entry.value + found.value = diff.value else: raise NotImplementedError db.kv[f'root_block.{root_block.chain}'] = [root_block.height, root_block.hash] @@ -65,12 +61,12 @@ class DbState: current_blockstate.set(state) current_fork.set(None) # root fork for series, t in self.types.items(): - if t == BlockData.Type.SET: + if t == DataType.SET: # noinspection PyTypeChecker var: BlockSet = BlockData.registry[series] for row in db.session.query(SeriesSet).where(SeriesSet.series==keystr(series)): var.add(strkey(row.key)) - elif t == BlockData.Type.DICT: + elif t == DataType.DICT: # noinspection PyTypeChecker var: BlockDict = BlockData.registry[series] for row in db.session.query(SeriesDict).where(SeriesDict.series==keystr(series)): diff --git a/src/dexorder/blockstate/diff.py b/src/dexorder/blockstate/diff.py index 280902b..05fb48f 100644 --- a/src/dexorder/blockstate/diff.py +++ b/src/dexorder/blockstate/diff.py @@ -14,9 +14,22 @@ class DiffEntry: @dataclass class DiffItem: + series: Any + key: Any + value: Any + + def __str__(self): + return f'{self.series}.{self.key}={"[DEL]" if self.value is DELETE else self.value}' + +@dataclass +class DiffEntryItem: series: Any key: Any entry: DiffEntry + @property + def value(self): + return self.entry.value + def __str__(self): return f'{self.entry.hash.hex()} {self.series}.{self.key}={"[DEL]" if self.entry.value is DELETE else self.entry.value}' diff --git a/src/dexorder/blockstate/state.py b/src/dexorder/blockstate/state.py index 381e85b..d044b0b 100644 --- a/src/dexorder/blockstate/state.py +++ b/src/dexorder/blockstate/state.py @@ -2,7 +2,7 @@ import itertools import logging from collections import defaultdict from contextvars import ContextVar -from typing import Any, Optional, Union +from typing import Any, Optional, Union, Sequence, Reversible from sortedcontainers import SortedList @@ -10,11 +10,25 @@ from dexorder import NARG from dexorder.base.fork import Fork, DisjointFork from dexorder.database.model import Block from dexorder.util import hexstr -from .diff import DiffEntry, DiffItem, DELETE +from .diff import DiffEntry, DiffItem, DELETE, DiffEntryItem log = logging.getLogger(__name__) +def compress_diffs(difflist: Reversible): + """ diff items must be in chronological order """ + sk = set() # seen keys + result: list = [] + # iterate through all diffs in -reverse- chronological order keeping only the first item we see for each key + for i in reversed(difflist): + k = i.series, i.key + if k not in sk: + sk.add(k) + result.append(i) + result.reverse() # forward chronological order + return result + + class BlockState: by_chain: dict[int, 'BlockState'] = {} @@ -34,7 +48,7 @@ class BlockState: # diffs_by_series is the main data structure. leaf nodes are list of diffs sorted by blockheight self.diffs_by_series: dict[Any, dict[Any, SortedList[DiffEntry]]] = defaultdict(lambda: defaultdict(lambda: SortedList(key=lambda x: x.height))) # diffs_by_hash holds the diff items generated by each block - self.diffs_by_hash: dict[bytes, list[DiffItem]] = defaultdict(list) + self.diffs_by_hash: dict[bytes, list[DiffEntryItem]] = defaultdict(list) self.ancestors: dict[bytes, Block] = {} BlockState.by_chain[root_block.chain] = self @@ -128,7 +142,7 @@ class BlockState: fork.height if fork is not None else self.root_block.height, fork.hash if fork is not None else self.root_block.hash) if fork is not None: - self.diffs_by_hash[fork.hash].append(DiffItem(series, key, diff)) + self.diffs_by_hash[fork.hash].append(DiffEntryItem(series, key, diff)) diffs.add(diff) def iteritems(self, fork: Optional[Fork], series): @@ -186,12 +200,12 @@ class BlockState: log.debug(f'promoted root {self.root_block}') return diffs - def collect_diffs(self, block: Block, series_key=NARG) -> list[DiffItem]: + def collect_diffs(self, block: Block, series_key=NARG) -> list[DiffEntryItem]: """ returns a list of the latest DiffItem for each key change along the ancestor path from block to root """ # first collect the exhaustive list of diffs along the ancestry path - diff_lists: list[list[DiffItem]] = [] + diff_lists: list[list[DiffEntryItem]] = [] while block.height > self.root_block.height: diffs = self.diffs_by_hash.get(block.hash) if diffs: @@ -199,18 +213,8 @@ class BlockState: diffs = [d for d in diffs if d.series == series_key] diff_lists.append(diffs) block = self.ancestors[block.hash] - - # now keep only the latest values for keys that were set multiple times - sk = set() # seen keys - result: list[DiffItem] = [] - # iterate through all diffs in -reverse- chronological order keeping only the first item we see for each key - for i in itertools.chain(*(reversed(l) for l in diff_lists)): - k = i.series, i.key - if k not in sk: - sk.add(k) - result.append(i) - result.reverse() # forward chronological order - return result + difflist = list(itertools.chain(*reversed(diff_lists))) + return compress_diffs(difflist) current_blockstate = ContextVar[BlockState]('current_blockstate') diff --git a/src/dexorder/configuration/schema.py b/src/dexorder/configuration/schema.py index 7aee0a4..3195c78 100644 --- a/src/dexorder/configuration/schema.py +++ b/src/dexorder/configuration/schema.py @@ -3,38 +3,27 @@ from typing import Optional, Union # SCHEMA NOTES: -# - avoid using int keys since (1) they are hard to decipher by a human and (2) the Python TOML parser mistypes int keys +# - avoid using int keys since (a) they are hard to decipher by a human and (b) the Python TOML parser mistypes int keys # as strings in certain situations -# - do not nest structured types more than one level deep. it confuses the config's typing system +# - do not nest structured types more than one level deep. it confuses the config's typing system https://github.com/omry/omegaconf/issues/1058 @dataclass class Config: - db_url: str = 'postgresql://dexorder:redroxed@localhost/dexorder' - dump_sql: bool = False chain: Union[int,str] = 'Arbitrum' - # rpc_url may also reference the aliases from the foundry.toml's rpc_endpoints section rpc_url: str = 'http://localhost:8545' ws_url: str = 'ws://localhost:8545' rpc_urls: Optional[dict[str,str]] = field(default_factory=dict) + db_url: str = 'postgresql://dexorder:redroxed@localhost/dexorder' + dump_sql: bool = False + redis_url: str = 'redis://localhost:6379' + + tokens: list['TokenConfig'] = field(default_factory=list) account: Optional[str] = None # may be a private key or an account alias accounts: Optional[dict[str,str]] = field(default_factory=dict) # account aliases min_gas: str = '0' - tokens: list['TokenConfig'] = field(default_factory=list) - dexorders: list['DexorderConfig'] = field(default_factory=list) - pools: list['PoolConfig'] = field(default_factory=list) - query_helpers: dict[str,str] = field(default_factory=dict) - - # Dispatcher - polling_interval: float = 0.2 - backoff_factor: float = 1.5 - max_interval: float = 10 - - # positive numbers are absolute block numbers and negative numbers are relative to the latest block - backfill: int = 0 - @dataclass class TokenConfig: @@ -44,30 +33,3 @@ class TokenConfig: chain: str address: str abi: Optional[str] = None - - -@dataclass -class PoolConfig: - chain: str - address: str - token_a: str - token_b: str - fee: int - enabled: bool = False - - -@dataclass -class DexorderConfig: - chain: str - address: str - pool: str - owner: str - name: Optional[str] = None - width: Optional[int] = None # in bps aka ticks - width_above: Optional[int] = None # defaults to width - width_below: Optional[int] = None # defaults to width - offset: Optional[int] = None # in bps aka ticks - offset_above: Optional[int] = None # defaults to offset - offset_below: Optional[int] = None # defaults to offset - ema: Optional[int] = None - diff --git a/src/dexorder/memcache/__init__.py b/src/dexorder/memcache/__init__.py new file mode 100644 index 0000000..c1d419e --- /dev/null +++ b/src/dexorder/memcache/__init__.py @@ -0,0 +1,45 @@ +import logging +from contextlib import asynccontextmanager +from contextvars import ContextVar + +import redis.asyncio as redis +from redis.asyncio import Redis +from redis.asyncio.client import Pipeline + +from dexorder import config + +log = logging.getLogger(__name__) + + +class Memcache: + @staticmethod + @asynccontextmanager + async def batch(): + old_redis: Redis = current_redis.get() + pipe: Pipeline = old_redis.pipeline() + current_redis.set(pipe) + try: + yield pipe + await pipe.execute() + finally: + current_redis.set(old_redis) + + + @staticmethod + async def connect(redis_url=None): + if redis_url is None: + redis_url = config.redis_url + r = await redis.from_url(redis_url, decode_responses=True, protocol=3) + current_redis.set(r) + return r + + + @staticmethod + def __bool__(): + return bool(config.redis_url) + + +memcache = Memcache() + +current_redis = ContextVar[Redis]('current_redis') + diff --git a/src/dexorder/memcache/memcache_state.py b/src/dexorder/memcache/memcache_state.py new file mode 100644 index 0000000..4eb77e0 --- /dev/null +++ b/src/dexorder/memcache/memcache_state.py @@ -0,0 +1,84 @@ +import logging +from collections import defaultdict +from typing import Iterable, Union, Reversible + +from redis.asyncio.client import Pipeline + +from dexorder.base.chain import current_chain +from dexorder.base.fork import current_fork +from dexorder.blockstate import DiffItem, DataType, DELETE, BlockState +from dexorder.blockstate.blockdata import SeriesCollection, BlockData +from dexorder.blockstate.diff import DiffEntryItem +from dexorder.blockstate.state import compress_diffs +from dexorder.database.model import Block +from dexorder.memcache import current_redis, memcache +from dexorder.util import keystr +from dexorder.util.json import json_encoder + +log = logging.getLogger(__name__) + + +class RedisState (SeriesCollection): + + def __init__(self, series_or_datavars: Iterable[Union[str, BlockData]]): + super().__init__(series_or_datavars) + self.exists:set[str] = set() + + async def clear(self): + r = current_redis.get() + await r.delete(f'{current_chain.get().chain_id}|latest_block', *self.types.keys()) + + + async def init(self, state: BlockState): + fork = current_fork.get() + await self.clear() + diffs = [] + for series, t in self.types.items(): + for k, v in state.iteritems(fork, series): + diffs.append(DiffItem(series, k, v)) + await self.save(state.root_block, diffs) + + + # noinspection PyAsyncCall + async def save(self, block: Block, diffs: Reversible[Union[DiffItem, DiffEntryItem]] ): + # the diffs must be already compressed such that there is only one action per key + chain = current_chain.get() + assert block.chain == chain.chain_id + chain_id = chain.chain_id + sadds: dict[str,set[str]] = defaultdict(set) + sdels: dict[str,set[str]] = defaultdict(set) + hsets: dict[str,dict[str,str]] = defaultdict(dict) + hdels: dict[str,set[str]] = defaultdict(set) + for diff in compress_diffs(diffs): + try: + t = self.types[diff.series] + except KeyError: + continue + series = f'{chain_id}|{keystr(diff.series)}' + key = keystr(diff.key) + if diff.value is DELETE: + if t == DataType.SET: + sdels[series].add(key) + elif t == DataType.DICT: + hdels[series].add(key) + else: + raise NotImplementedError + else: + if t == DataType.SET: + sadds[series].add(key) + elif t == DataType.DICT: + hsets[series][key] = keystr(diff.value) + else: + raise NotImplementedError + async with memcache.batch() as r: + r: Pipeline + for series, keys in sadds.items(): + r.sadd(series, *keys) + for series, keys in sdels.items(): + r.srem(series, *keys) + for series, kvs in hsets.items(): + r.hset(series, mapping=kvs) + for series, keys in hdels.items(): + r.hdel(series, *keys) + r.json(json_encoder).set(f'{current_chain.get().chain_id}|latest_block','$',block.data) + diff --git a/src/dexorder/runner.py b/src/dexorder/runner.py index 9b7ea4e..dd37516 100644 --- a/src/dexorder/runner.py +++ b/src/dexorder/runner.py @@ -12,10 +12,13 @@ from dexorder.base.fork import Fork, current_fork from dexorder.blockchain.connection import create_w3_ws from dexorder.blockchain.util import get_contract_data from dexorder.blockstate import DiffItem, BlockState, current_blockstate +from dexorder.blockstate.diff import DiffEntryItem +from dexorder.blockstate.state import compress_diffs from dexorder.data import pool_prices, vault_tokens, underfunded_vaults, vault_addresses from dexorder.database.model import Block from dexorder.database.model.block import current_block, latest_block from dexorder.util import hexstr, topic +from dexorder.util.async_util import maywait log = logging.getLogger(__name__) @@ -33,10 +36,10 @@ class BlockStateRunner: self.events:list[tuple[Callable[[dict],None],ContractEvents,dict]] = [] # onHeadUpdate callbacks are invoked with a list of DiffItems used to update the head state from either the previous head or the root - self.on_head_update: list[Callable[[Block,list[DiffItem]],None]] = [] + self.on_head_update: list[Callable[[Block,list[DiffEntryItem]],None]] = [] # onPromotion callbacks are invoked with a list of DiffItems used to advance the root state - self.on_promotion: list[Callable[[Block,list[DiffItem]],None]] = [] + self.on_promotion: list[Callable[[Block,list[DiffEntryItem]],None]] = [] async def run(self): @@ -139,7 +142,7 @@ class BlockStateRunner: # todo check for reorg and generate a reorg diff list diff_items = state.diffs_by_hash[block.hash] for callback in self.on_head_update: - callback(block, diff_items) + await maywait(callback(block, diff_items)) # check for root promotion promotion_height = fork.height - chain.confirms diff --git a/src/dexorder/util/__init__.py b/src/dexorder/util/__init__.py index d01b67e..15c02a5 100644 --- a/src/dexorder/util/__init__.py +++ b/src/dexorder/util/__init__.py @@ -3,7 +3,7 @@ import re from eth_utils import keccak from hexbytes import HexBytes -from .async_yield import async_yield +from .async_util import async_yield from .tick_math import nearest_available_ticks, round_tick, spans_tick, spans_range diff --git a/src/dexorder/util/async_yield.py b/src/dexorder/util/async_util.py similarity index 56% rename from src/dexorder/util/async_yield.py rename to src/dexorder/util/async_util.py index 9bbe0ae..2d53ee7 100644 --- a/src/dexorder/util/async_yield.py +++ b/src/dexorder/util/async_util.py @@ -1,5 +1,13 @@ import asyncio +import inspect + async def async_yield(): # a value of exactly 0 doesn't seem to work as well, so we set 1 nanosecond await asyncio.sleep(1e-9) + + +async def maywait(obj): + if inspect.isawaitable(obj): + obj = await obj + return obj diff --git a/src/dexorder/util/json.py b/src/dexorder/util/json.py index 2b46334..9a688b9 100644 --- a/src/dexorder/util/json.py +++ b/src/dexorder/util/json.py @@ -1,4 +1,6 @@ from decimal import Decimal +from json import JSONEncoder +from typing import Any from hexbytes import HexBytes from orjson import orjson @@ -22,3 +24,11 @@ def loads(s): def dumps(obj): return orjson.dumps(obj, default=_serialize, option=orjson.OPT_PASSTHROUGH_SUBCLASS).decode('utf8') + + +class JsonEncoder (JSONEncoder): + def default(self, o: Any) -> Any: + return _serialize(o) + + +json_encoder = JsonEncoder()