diff --git a/src/dexorder/bin/main.py b/src/dexorder/bin/main.py
index 1933516..d757109 100644
--- a/src/dexorder/bin/main.py
+++ b/src/dexorder/bin/main.py
@@ -117,7 +117,7 @@ async def main():
         if redis_state:
             # load initial state
             log.info('initializing redis with root state')
-            await redis_state.save(state.root_fork, state.diffs_by_branch[state.root_branch.id])
+            await redis_state.init(state, state.root_fork)
 
     await initialize_accounting_runner()
 
diff --git a/src/dexorder/memcache/__init__.py b/src/dexorder/memcache/__init__.py
index 9f13988..56bc6c3 100644
--- a/src/dexorder/memcache/__init__.py
+++ b/src/dexorder/memcache/__init__.py
@@ -4,45 +4,18 @@ from contextvars import ContextVar
 import redis.asyncio as redis_async
 from redis.asyncio import Redis
-from redis.asyncio.client import Pipeline
 
 from dexorder import config
 
 log = logging.getLogger(__name__)
 
 
-class FlushingPipeline:
-    def __init__(self, redis: Redis):
-        self.redis = redis
-        self.pipe: Pipeline = redis.pipeline()
-        self.full_pipes: list[Pipeline] = []
-        self.count = 0
-        self.flush_at = 10_000
-
-    def __getattr__(self, item):
-        if item in ('sadd', 'srem', 'hset', 'hdel', 'json'):
-            self.count += 1
-            if self.count >= self.flush_at:
-                self.full_pipes.append(self.pipe)
-                self.pipe = self.redis.pipeline()
-                self.count = 0
-        return getattr(self.pipe, item)
-
-    async def execute(self):
-        for pipe in self.full_pipes:
-            await pipe.execute()
-        await self.pipe.execute()
-        self.pipe = None
-        self.full_pipes.clear()
-        self.count = 0
-
-
 class Memcache:
     @staticmethod
     @asynccontextmanager
-    async def batch():
+    async def batch(transaction=True):
         old_redis: Redis = current_redis.get()
-        pipe = FlushingPipeline(old_redis)
+        pipe = old_redis.pipeline(transaction=transaction)
         # noinspection PyTypeChecker
         current_redis.set(pipe)
         try:
diff --git a/src/dexorder/memcache/memcache_state.py b/src/dexorder/memcache/memcache_state.py
index 9aa0bdb..b3b2faf 100644
--- a/src/dexorder/memcache/memcache_state.py
+++ b/src/dexorder/memcache/memcache_state.py
@@ -2,7 +2,7 @@ import logging
 from collections import defaultdict
 from typing import Iterable, Union, Reversible, Any
 
-from redis.asyncio.client import Pipeline, Redis
+from redis.asyncio.client import Pipeline
 from socket_io_emitter import Emitter
 
 from dexorder import DELETE
@@ -40,11 +40,12 @@ class RedisState (SeriesCollection):
         for series in self.datas.keys():
             for k, v in state.iteritems(fork, series):
                 diffs.append(DiffItem(series, k, v))
-        await self.save(fork, diffs)
+        # todo tim fix pubs
+        await self.save(fork, diffs, use_transaction=True, skip_pubs=True)  # use_transaction=False if the data is too big
 
     # noinspection PyAsyncCall
-    async def save(self, fork: Fork, diffs: Reversible[Union[DiffItem, DiffEntryItem]]):
+    async def save(self, fork: Fork, diffs: Reversible[Union[DiffItem, DiffEntryItem]], *, use_transaction=True, skip_pubs=False):
         # the diffs must be already compressed such that there is only one action per key
         chain = current_chain.get()
         chain_id = chain.id
@@ -91,7 +92,9 @@ class RedisState (SeriesCollection):
                     hsets[series][key] = value
                 else:
                     raise NotImplementedError
-        async with memcache.batch() as r:
+        async with memcache.batch(use_transaction) as r:
+            # Redis pipelines fill up before our state can be sent, so we cannot do this atomically.
+            # However, sending many individual calls is super slow, so we still batch everything into one pipeline.
             r: Pipeline
             for series, keys in sadds.items():
                 r.sadd(series, *keys)
@@ -106,7 +109,7 @@ class RedisState (SeriesCollection):
             r.json(json_encoder).set(block_series,'$',[fork.height, headstr])
             pubs.append((str(chain_id), 'head', [fork.height, headstr]))
         # separate batch for pubs
-        if pubs:
+        if pubs and not skip_pubs:
             await publish_all(pubs)
 
 
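Reviewer note: the memcache/__init__.py hunk cuts off at `try:`, so the rest of `Memcache.batch` is not shown here. Below is a minimal, self-contained sketch of the pattern this change adopts (swapping a pipeline into the module's `current_redis` ContextVar for the duration of a batch), assuming a yield/execute/restore body for the context manager. The standalone `batch` function, the `demo:*` keys, and the default localhost connection are illustrative assumptions, not code from this patch.

    import asyncio
    from contextlib import asynccontextmanager
    from contextvars import ContextVar

    import redis.asyncio as redis_async
    from redis.asyncio import Redis

    current_redis: ContextVar = ContextVar('current_redis')


    @asynccontextmanager
    async def batch(transaction=True):
        # Swap the ambient Redis client for a pipeline; every command issued
        # through current_redis inside the block is buffered client-side.
        old_redis: Redis = current_redis.get()
        pipe = old_redis.pipeline(transaction=transaction)
        current_redis.set(pipe)
        try:
            yield pipe
            # Flush in one round trip; transaction=True also wraps it in MULTI/EXEC.
            await pipe.execute()
        finally:
            current_redis.set(old_redis)


    async def main():
        current_redis.set(redis_async.Redis())  # assumes a local Redis on the default port
        # transaction=False skips MULTI/EXEC, matching the patch's
        # "use_transaction=False if the data is too big" escape hatch.
        async with batch(transaction=False) as r:
            r.sadd('demo:set', 'a', 'b')        # queued on the pipeline, not awaited
            r.hset('demo:hash', 'key', 'value')


    asyncio.run(main())

Compared with the removed FlushingPipeline, which rotated to a fresh pipeline every 10,000 buffered commands, a single pipeline flushes everything in one execute(), and the new transaction flag lets callers drop MULTI/EXEC atomicity when the payload is too large for one transaction.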