redis initial state push fix

This commit is contained in:
tim
2025-04-01 13:52:49 -04:00
parent 52b406ba17
commit 0bb670b356
3 changed files with 11 additions and 35 deletions

View File

@@ -117,7 +117,7 @@ async def main():
if redis_state: if redis_state:
# load initial state # load initial state
log.info('initializing redis with root state') log.info('initializing redis with root state')
await redis_state.save(state.root_fork, state.diffs_by_branch[state.root_branch.id]) await redis_state.init(state, state.root_fork)
await initialize_accounting_runner() await initialize_accounting_runner()

View File

@@ -4,45 +4,18 @@ from contextvars import ContextVar
import redis.asyncio as redis_async import redis.asyncio as redis_async
from redis.asyncio import Redis from redis.asyncio import Redis
from redis.asyncio.client import Pipeline
from dexorder import config from dexorder import config
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class FlushingPipeline:
def __init__(self, redis: Redis):
self.redis = redis
self.pipe: Pipeline = redis.pipeline()
self.full_pipes: list[Pipeline] = []
self.count = 0
self.flush_at = 10_000
def __getattr__(self, item):
if item in ('sadd', 'srem', 'hset', 'hdel', 'json'):
self.count += 1
if self.count >= self.flush_at:
self.full_pipes.append(self.pipe)
self.pipe = self.redis.pipeline()
self.count = 0
return getattr(self.pipe, item)
async def execute(self):
for pipe in self.full_pipes:
await pipe.execute()
await self.pipe.execute()
self.pipe = None
self.full_pipes.clear()
self.count = 0
class Memcache: class Memcache:
@staticmethod @staticmethod
@asynccontextmanager @asynccontextmanager
async def batch(): async def batch(transaction=True):
old_redis: Redis = current_redis.get() old_redis: Redis = current_redis.get()
pipe = FlushingPipeline(old_redis) pipe = old_redis.pipeline(transaction=transaction)
# noinspection PyTypeChecker # noinspection PyTypeChecker
current_redis.set(pipe) current_redis.set(pipe)
try: try:

View File

@@ -2,7 +2,7 @@ import logging
from collections import defaultdict from collections import defaultdict
from typing import Iterable, Union, Reversible, Any from typing import Iterable, Union, Reversible, Any
from redis.asyncio.client import Pipeline, Redis from redis.asyncio.client import Pipeline
from socket_io_emitter import Emitter from socket_io_emitter import Emitter
from dexorder import DELETE from dexorder import DELETE
@@ -40,11 +40,12 @@ class RedisState (SeriesCollection):
for series in self.datas.keys(): for series in self.datas.keys():
for k, v in state.iteritems(fork, series): for k, v in state.iteritems(fork, series):
diffs.append(DiffItem(series, k, v)) diffs.append(DiffItem(series, k, v))
await self.save(fork, diffs) # todo tim fix pubs
await self.save(fork, diffs, use_transaction=True, skip_pubs=True) # use_transaction=False if the data is too big
# noinspection PyAsyncCall # noinspection PyAsyncCall
async def save(self, fork: Fork, diffs: Reversible[Union[DiffItem, DiffEntryItem]]): async def save(self, fork: Fork, diffs: Reversible[Union[DiffItem, DiffEntryItem]], *, use_transaction=True, skip_pubs=False):
# the diffs must be already compressed such that there is only one action per key # the diffs must be already compressed such that there is only one action per key
chain = current_chain.get() chain = current_chain.get()
chain_id = chain.id chain_id = chain.id
@@ -91,7 +92,9 @@ class RedisState (SeriesCollection):
hsets[series][key] = value hsets[series][key] = value
else: else:
raise NotImplementedError raise NotImplementedError
async with memcache.batch() as r: async with memcache.batch(use_transaction) as r:
# Redis pipelines fill up before our state can be sent, so we cannot do this atomically.
# However, sending many individual calls is super slow, so we
r: Pipeline r: Pipeline
for series, keys in sadds.items(): for series, keys in sadds.items():
r.sadd(series, *keys) r.sadd(series, *keys)
@@ -106,7 +109,7 @@ class RedisState (SeriesCollection):
r.json(json_encoder).set(block_series,'$',[fork.height, headstr]) r.json(json_encoder).set(block_series,'$',[fork.height, headstr])
pubs.append((str(chain_id), 'head', [fork.height, headstr])) pubs.append((str(chain_id), 'head', [fork.height, headstr]))
# separate batch for pubs # separate batch for pubs
if pubs: if pubs and not skip_pubs:
await publish_all(pubs) await publish_all(pubs)