From d49f142fe320d45646356ea018060320c422221c Mon Sep 17 00:00:00 2001
From: tim
Date: Tue, 1 Apr 2025 10:54:25 -0400
Subject: [PATCH] redis pipeline autoflush after 10000 entries

---
NOTE(review): this patch was reconstructed from a whitespace-mangled copy in
which every newline had been collapsed to a space. Line breaks were recovered
from the Python syntax of the hunks; the result is consistent with the hunk
headers (-11,12 +11,39 and -91,19 +91,20) and the 42(+)/14(-) diffstat.
Text in this section is ignored by `git am`. Exact blank-line placement inside
hunks is inferred from PEP 8 convention — confirm against the original commit.

 src/dexorder/memcache/__init__.py | 29 ++++++++++++++++++++++++-
 src/dexorder/memcache/memcache_state.py | 27 ++++++++++++-----------
 2 files changed, 42 insertions(+), 14 deletions(-)

diff --git a/src/dexorder/memcache/__init__.py b/src/dexorder/memcache/__init__.py
index 27d7839..9f13988 100644
--- a/src/dexorder/memcache/__init__.py
+++ b/src/dexorder/memcache/__init__.py
@@ -11,12 +11,39 @@ from dexorder import config
 log = logging.getLogger(__name__)
 
 
+class FlushingPipeline:
+    def __init__(self, redis: Redis):
+        self.redis = redis
+        self.pipe: Pipeline = redis.pipeline()
+        self.full_pipes: list[Pipeline] = []
+        self.count = 0
+        self.flush_at = 10_000
+
+    def __getattr__(self, item):
+        if item in ('sadd', 'srem', 'hset', 'hdel', 'json'):
+            self.count += 1
+            if self.count >= self.flush_at:
+                self.full_pipes.append(self.pipe)
+                self.pipe = self.redis.pipeline()
+                self.count = 0
+        return getattr(self.pipe, item)
+
+    async def execute(self):
+        for pipe in self.full_pipes:
+            await pipe.execute()
+        await self.pipe.execute()
+        self.pipe = None
+        self.full_pipes.clear()
+        self.count = 0
+
+
 class Memcache:
     @staticmethod
     @asynccontextmanager
     async def batch():
         old_redis: Redis = current_redis.get()
-        pipe: Pipeline = old_redis.pipeline()
+        pipe = FlushingPipeline(old_redis)
+        # noinspection PyTypeChecker
         current_redis.set(pipe)
         try:
             yield pipe
diff --git a/src/dexorder/memcache/memcache_state.py b/src/dexorder/memcache/memcache_state.py
index 596d9ff..9aa0bdb 100644
--- a/src/dexorder/memcache/memcache_state.py
+++ b/src/dexorder/memcache/memcache_state.py
@@ -91,19 +91,20 @@ class RedisState (SeriesCollection):
                 hsets[series][key] = value
             else:
                 raise NotImplementedError
-        r: Redis = current_redis.get()
-        for series, keys in sadds.items():
-            await r.sadd(series, *keys)
-        for series, keys in sdels.items():
-            await r.srem(series, *keys)
-        for series, kvs in hsets.items():
-            await r.hset(series, mapping=kvs)
-        for series, keys in hdels.items():
-            await r.hdel(series, *keys)
-        block_series = f'{chain_id}|head'
-        headstr = hexstr(fork.head)
-        await r.json(json_encoder).set(block_series,'$',[fork.height, headstr])
-        pubs.append((str(chain_id), 'head', [fork.height, headstr]))
+        async with memcache.batch() as r:
+            r: Pipeline
+            for series, keys in sadds.items():
+                r.sadd(series, *keys)
+            for series, keys in sdels.items():
+                r.srem(series, *keys)
+            for series, kvs in hsets.items():
+                r.hset(series, mapping=kvs)
+            for series, keys in hdels.items():
+                r.hdel(series, *keys)
+            block_series = f'{chain_id}|head'
+            headstr = hexstr(fork.head)
+            r.json(json_encoder).set(block_series,'$',[fork.height, headstr])
+            pubs.append((str(chain_id), 'head', [fork.height, headstr]))
         # separate batch for pubs
         if pubs:
             await publish_all(pubs)