Files
ai/backend/src/sync/registry.py
2026-03-04 03:28:09 -04:00

247 lines
11 KiB
Python

from __future__ import annotations

import logging
from collections import deque
from typing import Any, Deque, Dict, List, Optional, Tuple

import jsonpatch
from pydantic import BaseModel

from sync.protocol import PatchMessage, SnapshotMessage
class SyncEntry:
    """Sync state for a single pydantic model ("store").

    Tracks a monotonically increasing sequence number, the JSON snapshot of
    the model at that sequence, and a bounded ring buffer of (seq, patch)
    pairs used to catch up clients that fell behind.
    """

    def __init__(self, model: BaseModel, store_name: str, history_size: int = 50):
        self.model = model
        self.store_name = store_name
        # Sequence number of the last committed patch; 0 == initial state.
        self.seq = 0
        # JSON-serializable snapshot of the model as of `seq`.
        self.last_snapshot = model.model_dump(mode="json")
        # Ring buffer of (seq, patch); oldest entries are evicted automatically.
        self.history: Deque[Tuple[int, List[Dict[str, Any]]]] = deque(maxlen=history_size)

    def compute_patch(self) -> Optional[List[Dict[str, Any]]]:
        """Return the RFC 6902 patch from `last_snapshot` to the current model
        state, or None when nothing changed."""
        current_state = self.model.model_dump(mode="json")
        patch = jsonpatch.make_patch(self.last_snapshot, current_state)
        # Inspect .patch (the raw op list): an empty JsonPatch is still truthy.
        if not patch.patch:
            return None
        return patch.patch

    def commit_patch(self, patch: List[Dict[str, Any]]):
        """Record `patch` as the next sequence step and refresh the snapshot."""
        self.seq += 1
        self.history.append((self.seq, patch))
        self.last_snapshot = self.model.model_dump(mode="json")

    def catchup_patches(self, since_seq: int) -> Optional[List[Tuple[int, List[Dict[str, Any]]]]]:
        """Return the (seq, patch) pairs a client at `since_seq` is missing.

        Returns an empty list when the client is already up to date, and
        None when the gap cannot be bridged from history — the caller
        should fall back to sending a full snapshot.
        """
        if since_seq == self.seq:
            return []
        # A client claiming a seq ahead of ours is out of sync (e.g. the
        # server restarted and reset its counter).  Previously this fell
        # through and returned [], leaving the client stuck on stale state;
        # force a snapshot resync instead.
        if since_seq > self.seq:
            return None
        # The oldest retained patch must cover since_seq + 1, otherwise part
        # of the gap has been evicted from the ring buffer.
        if not self.history or self.history[0][0] > since_seq + 1:
            return None
        return [(seq, patch) for seq, patch in self.history if seq > since_seq]
class SyncRegistry:
    """Synchronizes registered pydantic models with one websocket client.

    Each registered model ("store") is wrapped in a SyncEntry.  The registry
    diffs models on demand (`push_all`), replays missed patches or sends full
    snapshots for reconnecting clients (`sync_client`), and merges client
    edits with a frontend-wins conflict policy (`apply_client_patch`).
    """

    def __init__(self):
        self.entries: Dict[str, SyncEntry] = {}
        # Expecting a FastAPI WebSocket or similar (anything with an async
        # send_json method).
        self.websocket: Optional[Any] = None

    def register(self, model: BaseModel, store_name: str, history_size: int = 50):
        """Start tracking `model` under `store_name`.

        `history_size` bounds the per-store patch history used for client
        catch-up; gaps older than that fall back to full snapshots.
        """
        self.entries[store_name] = SyncEntry(model, store_name, history_size=history_size)

    async def _send_snapshot(self, entry: SyncEntry):
        """Send a full-state snapshot for `entry`, forcing the client to resync."""
        if not self.websocket:
            return
        msg = SnapshotMessage(
            store=entry.store_name,
            seq=entry.seq,
            state=entry.model.model_dump(mode="json"),
        )
        await self.websocket.send_json(msg.model_dump(mode="json"))

    async def push_all(self):
        """Diff every store against its last snapshot and push a patch for
        each one that changed."""
        logger = logging.getLogger(__name__)
        if not self.websocket:
            logger.warning("push_all: No websocket connected, cannot push updates")
            return
        logger.info(f"push_all: Processing {len(self.entries)} store entries")
        for entry in self.entries.values():
            patch = entry.compute_patch()
            if patch:
                logger.info(f"push_all: Found patch for store '{entry.store_name}': {patch}")
                entry.commit_patch(patch)
                msg = PatchMessage(store=entry.store_name, seq=entry.seq, patch=patch)
                logger.info(f"push_all: Sending patch message for '{entry.store_name}' seq={entry.seq}")
                await self.websocket.send_json(msg.model_dump(mode="json"))
                logger.info(f"push_all: Patch sent successfully for '{entry.store_name}'")
            else:
                logger.debug(f"push_all: No changes detected for store '{entry.store_name}'")

    async def sync_client(self, client_seqs: Dict[str, int]):
        """Bring a (re)connecting client up to date.

        `client_seqs` maps store name -> last seq the client saw; stores the
        client has never seen default to -1, which always yields a snapshot.
        """
        if not self.websocket:
            return
        for store_name, entry in self.entries.items():
            client_seq = client_seqs.get(store_name, -1)
            patches = entry.catchup_patches(client_seq)
            if patches is not None:
                # Replay only the patches the client missed.
                for seq, patch in patches:
                    msg = PatchMessage(store=store_name, seq=seq, patch=patch)
                    await self.websocket.send_json(msg.model_dump(mode="json"))
            else:
                # Gap not bridgeable from history: send the full state.
                await self._send_snapshot(entry)

    async def apply_client_patch(self, store_name: str, client_base_seq: int, patch: List[Dict[str, Any]]):
        """Apply a JSON patch received from the client.

        Three cases:
        * ``client_base_seq == entry.seq`` — no conflict; apply directly.
        * ``client_base_seq <  entry.seq`` — conflict; the frontend wins on
          the paths it touched, non-overlapping backend ops are replayed on
          top, and the resolved state is broadcast as a snapshot.
        * ``client_base_seq >  entry.seq`` — the client is out of sync (e.g.
          after a server restart); send a snapshot to resync.  Previously
          this case was silently ignored, leaving the client stuck.
        Any unexpected error also falls back to a snapshot resync.
        """
        logger = logging.getLogger(__name__)
        logger.info(f"apply_client_patch: store={store_name}, client_base_seq={client_base_seq}, patch={patch}")
        entry = self.entries.get(store_name)
        if not entry:
            logger.warning(f"apply_client_patch: Store '{store_name}' not found in registry")
            return
        logger.info(f"apply_client_patch: Current backend seq={entry.seq}")
        try:
            if client_base_seq == entry.seq:
                await self._apply_without_conflict(entry, patch, logger)
            elif client_base_seq < entry.seq:
                await self._apply_with_conflict(entry, client_base_seq, patch, logger)
            else:
                # Client claims a seq we never issued -- force a resync.
                logger.warning(
                    f"apply_client_patch: client_base_seq {client_base_seq} is ahead of backend "
                    f"seq {entry.seq} for store '{store_name}'. Sending snapshot to resync."
                )
                await self._send_snapshot(entry)
        except Exception as e:
            logger.error(f"apply_client_patch: Unexpected error: {e}. Sending snapshot to resync.", exc_info=True)
            # Send snapshot to force resync
            await self._send_snapshot(entry)

    async def _apply_without_conflict(self, entry: SyncEntry, patch: List[Dict[str, Any]], logger):
        """Fast path: the client patched exactly the state we hold."""
        logger.info("apply_client_patch: No conflict - applying patch directly")
        current_state = entry.model.model_dump(mode="json")
        logger.info(f"apply_client_patch: Current state before patch: {current_state}")
        try:
            new_state = jsonpatch.apply_patch(current_state, patch)
            logger.info(f"apply_client_patch: New state after patch: {new_state}")
            self._update_model(entry.model, new_state)
            # Verify the model was actually updated
            updated_state = entry.model.model_dump(mode="json")
            logger.info(f"apply_client_patch: Model state after _update_model: {updated_state}")
            entry.commit_patch(patch)
            logger.info(f"apply_client_patch: Patch committed, new seq={entry.seq}")
            # Don't broadcast back to client - they already have this change
            # Broadcasting would cause an infinite loop
            logger.info("apply_client_patch: Not broadcasting back to originating client")
        except jsonpatch.JsonPatchConflict as e:
            logger.warning(f"apply_client_patch: Patch conflict on no-conflict path: {e}. Sending snapshot to resync.")
            await self._send_snapshot(entry)

    async def _apply_with_conflict(self, entry: SyncEntry, client_base_seq: int, patch: List[Dict[str, Any]], logger):
        """Frontend-wins merge for a client patch based on a stale seq.

        The client patch is applied to the current authoritative state (we
        cannot reconstruct the state at `client_base_seq` — history holds
        only patches), then backend ops since `client_base_seq` whose paths
        the client did not touch are replayed on top.  The resolved state is
        broadcast as a snapshot so both sides converge.
        """
        # Backend patches the client has not seen yet.
        backend_patches = [p for seq, p in entry.history if seq > client_base_seq]
        # Paths changed by the frontend win outright.  NOTE(review): 'move'
        # and 'copy' ops also carry a 'from' path that is not considered
        # here — confirm the frontend never emits those ops.
        frontend_paths = {op['path'] for op in patch}
        current_state = entry.model.model_dump(mode="json")
        try:
            new_state = jsonpatch.apply_patch(current_state, patch)
        except jsonpatch.JsonPatchConflict as e:
            logger.warning(f"apply_client_patch: Failed to apply client patch during conflict resolution: {e}. Sending snapshot to resync.")
            await self._send_snapshot(entry)
            return
        # Re-apply backend patches that don't overlap the frontend's paths.
        for b_patch in backend_patches:
            filtered_b_patch = [op for op in b_patch if op['path'] not in frontend_paths]
            if filtered_b_patch:
                try:
                    new_state = jsonpatch.apply_patch(new_state, filtered_b_patch)
                except jsonpatch.JsonPatchConflict as e:
                    logger.warning(f"apply_client_patch: Failed to apply backend patch during conflict resolution: {e}. Skipping this patch.")
                    continue
        self._update_model(entry.model, new_state)
        # Commit the net change since the last broadcast snapshot as one patch.
        final_patch = jsonpatch.make_patch(entry.last_snapshot, new_state).patch
        if final_patch:
            entry.commit_patch(final_patch)
        # Broadcast resolved state as snapshot to converge (sent even when the
        # merge produced no net change, so the client's view is corrected).
        await self._send_snapshot(entry)

    def _update_model(self, model: BaseModel, new_data: Dict[str, Any]):
        """Copy `new_data` onto `model`'s fields in place.

        Dict-valued fields are deep-updated rather than replaced so that
        external references to those dicts keep observing the new state.
        """
        for field_name in model.model_fields:
            if field_name in new_data:
                new_value = new_data[field_name]
                current_value = getattr(model, field_name)
                if isinstance(current_value, dict) and isinstance(new_value, dict):
                    self._deep_update_dict(current_value, new_value)
                else:
                    # Non-dict fields are simply reassigned.
                    setattr(model, field_name, new_value)

    def _deep_update_dict(self, target: dict, source: dict):
        """Make `target` equal to `source` in place, preserving nested dict
        object identity wherever both sides hold a dict."""
        # Drop keys that no longer exist in source.
        for key in set(target) - set(source):
            del target[key]
        # Update or add keys from source.
        for key, source_value in source.items():
            if key in target and isinstance(target[key], dict) and isinstance(source_value, dict):
                self._deep_update_dict(target[key], source_value)
            else:
                target[key] = source_value