"""Backend-side state synchronisation over a websocket.

Tracks pydantic models registered as named "stores", computes JSON-Patch
deltas between committed snapshots, and replays/broadcasts them to a
connected client with monotonically increasing sequence numbers.
"""

import logging
from collections import deque
from typing import Any, Deque, Dict, List, Optional, Tuple

import jsonpatch
from pydantic import BaseModel

from sync.protocol import PatchMessage, SnapshotMessage

logger = logging.getLogger(__name__)


class SyncEntry:
    """One synchronised store: a pydantic model plus its bounded patch history."""

    def __init__(self, model: BaseModel, store_name: str, history_size: int = 50):
        self.model = model
        self.store_name = store_name
        # Monotonically increasing version number of the committed state.
        self.seq = 0
        # Snapshot of the model at the last committed seq; the diff base.
        self.last_snapshot: Dict[str, Any] = model.model_dump(mode="json")
        # Ring buffer of (seq, patch) pairs kept for client catch-up.
        self.history: Deque[Tuple[int, List[Dict[str, Any]]]] = deque(maxlen=history_size)

    def compute_patch(self) -> Optional[List[Dict[str, Any]]]:
        """Return the JSON-Patch from the last committed snapshot to the current
        model state, or None when nothing changed."""
        current_state = self.model.model_dump(mode="json")
        diff = jsonpatch.make_patch(self.last_snapshot, current_state)
        return diff.patch or None

    def commit_patch(self, patch: List[Dict[str, Any]]) -> None:
        """Record *patch* as the next sequence step and refresh the diff base."""
        self.seq += 1
        self.history.append((self.seq, patch))
        self.last_snapshot = self.model.model_dump(mode="json")

    def catchup_patches(
        self, since_seq: int
    ) -> Optional[List[Tuple[int, List[Dict[str, Any]]]]]:
        """Return the (seq, patch) pairs needed to bring a client at *since_seq*
        up to date.

        Returns [] when the client is already current, and None when a full
        snapshot is required — either because the ring buffer no longer holds
        every patch in (since_seq, self.seq], or because the client claims a
        seq the backend never issued (e.g. after a backend restart).
        """
        if since_seq == self.seq:
            return []
        if since_seq > self.seq:
            # Fix: previously this fell through and could report the client as
            # in-sync; an impossible client seq must force a snapshot resync.
            return None
        # Every patch in (since_seq, self.seq] must still be in the buffer.
        if not self.history or self.history[0][0] > since_seq + 1:
            return None
        return [(seq, patch) for seq, patch in self.history if seq > since_seq]


class SyncRegistry:
    """Registry of synchronised stores bound to (at most) one websocket client."""

    def __init__(self):
        self.entries: Dict[str, SyncEntry] = {}
        # Expecting a FastAPI WebSocket or similar (must expose send_json).
        self.websocket: Optional[Any] = None

    def register(self, model: BaseModel, store_name: str) -> None:
        """Start tracking *model* under *store_name*."""
        self.entries[store_name] = SyncEntry(model, store_name)

    async def _send_snapshot(self, entry: SyncEntry) -> None:
        """Send the full authoritative state of *entry* to force a client resync."""
        if not self.websocket:
            return
        msg = SnapshotMessage(
            store=entry.store_name,
            seq=entry.seq,
            state=entry.model.model_dump(mode="json"),
        )
        await self.websocket.send_json(msg.model_dump(mode="json"))

    async def push_all(self) -> None:
        """Diff every store against its last snapshot and push any changes
        to the connected client as PatchMessages."""
        if not self.websocket:
            logger.warning("push_all: No websocket connected, cannot push updates")
            return
        logger.info("push_all: Processing %d store entries", len(self.entries))
        for entry in self.entries.values():
            patch = entry.compute_patch()
            if not patch:
                logger.debug("push_all: No changes detected for store '%s'", entry.store_name)
                continue
            logger.info("push_all: Found patch for store '%s': %s", entry.store_name, patch)
            entry.commit_patch(patch)
            msg = PatchMessage(store=entry.store_name, seq=entry.seq, patch=patch)
            logger.info(
                "push_all: Sending patch message for '%s' seq=%d",
                entry.store_name, entry.seq,
            )
            await self.websocket.send_json(msg.model_dump(mode="json"))
            logger.info("push_all: Patch sent successfully for '%s'", entry.store_name)

    async def sync_client(self, client_seqs: Dict[str, int]) -> None:
        """Bring a (re)connecting client up to date.

        *client_seqs* maps store name -> last seq the client has seen (missing
        stores default to -1). Replays history patches where possible and
        falls back to a full snapshot otherwise.
        """
        if not self.websocket:
            return
        for store_name, entry in self.entries.items():
            client_seq = client_seqs.get(store_name, -1)
            patches = entry.catchup_patches(client_seq)
            if patches is None:
                # History gap (or impossible client seq): full snapshot.
                await self._send_snapshot(entry)
                continue
            for seq, patch in patches:
                msg = PatchMessage(store=store_name, seq=seq, patch=patch)
                await self.websocket.send_json(msg.model_dump(mode="json"))

    async def apply_client_patch(
        self,
        store_name: str,
        client_base_seq: int,
        patch: List[Dict[str, Any]],
    ) -> None:
        """Apply a patch received from the client.

        Three cases:
        - client is current (base seq == backend seq): apply directly, commit,
          and do NOT echo back (the client already has the change; echoing
          would cause an infinite loop);
        - client is behind (base seq < backend seq): conflict — the frontend
          wins on overlapping paths, then the resolved state is broadcast as a
          snapshot so both sides converge;
        - client is ahead (base seq > backend seq): impossible state, force a
          snapshot resync.
        Any failure path falls back to sending a full snapshot.
        """
        logger.info(
            "apply_client_patch: store=%s, client_base_seq=%d, patch=%s",
            store_name, client_base_seq, patch,
        )
        entry = self.entries.get(store_name)
        if not entry:
            logger.warning("apply_client_patch: Store '%s' not found in registry", store_name)
            return
        logger.info("apply_client_patch: Current backend seq=%d", entry.seq)
        try:
            if client_base_seq == entry.seq:
                await self._apply_no_conflict(entry, patch)
            elif client_base_seq < entry.seq:
                await self._resolve_conflict(entry, client_base_seq, patch)
            else:
                # Fix: previously this case fell through silently, leaving the
                # client permanently desynced. Force a snapshot instead.
                logger.warning(
                    "apply_client_patch: client_base_seq=%d ahead of backend seq=%d; "
                    "sending snapshot to resync", client_base_seq, entry.seq,
                )
                await self._send_snapshot(entry)
        except Exception as e:
            # Last-resort boundary: log and resync rather than drop the client.
            logger.error(
                "apply_client_patch: Unexpected error: %s. Sending snapshot to resync.",
                e, exc_info=True,
            )
            await self._send_snapshot(entry)

    async def _apply_no_conflict(self, entry: SyncEntry, patch: List[Dict[str, Any]]) -> None:
        """Client was current: apply its patch directly and commit."""
        logger.info("apply_client_patch: No conflict - applying patch directly")
        current_state = entry.model.model_dump(mode="json")
        logger.info("apply_client_patch: Current state before patch: %s", current_state)
        try:
            new_state = jsonpatch.apply_patch(current_state, patch)
        except jsonpatch.JsonPatchConflict as e:
            logger.warning(
                "apply_client_patch: Patch conflict on no-conflict path: %s. "
                "Sending snapshot to resync.", e,
            )
            await self._send_snapshot(entry)
            return
        logger.info("apply_client_patch: New state after patch: %s", new_state)
        self._update_model(entry.model, new_state)
        logger.info(
            "apply_client_patch: Model state after _update_model: %s",
            entry.model.model_dump(mode="json"),
        )
        entry.commit_patch(patch)
        logger.info("apply_client_patch: Patch committed, new seq=%d", entry.seq)
        # Don't broadcast back to the originating client - they already have
        # this change; broadcasting would cause an infinite loop.
        logger.info("apply_client_patch: Not broadcasting back to originating client")

    async def _resolve_conflict(
        self,
        entry: SyncEntry,
        client_base_seq: int,
        patch: List[Dict[str, Any]],
    ) -> None:
        """Client was behind: apply its patch (frontend wins), then re-apply
        backend deltas that don't touch the frontend's changed paths, and
        broadcast the resolved state as a snapshot."""
        # 1. Backend patches committed since the client's base seq.
        backend_patches = [p for seq, p in entry.history if seq > client_base_seq]

        # 2. Apply the frontend patch to the current authoritative state
        #    (we only have the current model, not the state at client_base_seq).
        # NOTE(review): exact-path matching ignores 'move'/'copy' "from" fields
        # and nested-path overlap (e.g. '/a' vs '/a/b') — confirm acceptable.
        frontend_paths = {op['path'] for op in patch}
        current_state = entry.model.model_dump(mode="json")
        try:
            new_state = jsonpatch.apply_patch(current_state, patch)
        except jsonpatch.JsonPatchConflict as e:
            logger.warning(
                "apply_client_patch: Failed to apply client patch during conflict "
                "resolution: %s. Sending snapshot to resync.", e,
            )
            await self._send_snapshot(entry)
            return

        # 3. Re-apply backend deltas that don't overlap the frontend's paths.
        for b_patch in backend_patches:
            filtered = [op for op in b_patch if op['path'] not in frontend_paths]
            if not filtered:
                continue
            try:
                new_state = jsonpatch.apply_patch(new_state, filtered)
            except jsonpatch.JsonPatchConflict as e:
                logger.warning(
                    "apply_client_patch: Failed to apply backend patch during "
                    "conflict resolution: %s. Skipping this patch.", e,
                )

        self._update_model(entry.model, new_state)
        # Commit the merged result as a single new patch from the last snapshot.
        final_patch = jsonpatch.make_patch(entry.last_snapshot, new_state).patch
        if final_patch:
            entry.commit_patch(final_patch)
        # Broadcast the resolved state as a snapshot so both sides converge.
        await self._send_snapshot(entry)

    def _update_model(self, model: BaseModel, new_data: Dict[str, Any]) -> None:
        """Write *new_data* into *model* field by field, updating dict fields
        in place to preserve object identity for external references."""
        for field_name in model.model_fields:
            if field_name not in new_data:
                continue
            new_value = new_data[field_name]
            current_value = getattr(model, field_name)
            if isinstance(current_value, dict) and isinstance(new_value, dict):
                # In-place update keeps references held elsewhere valid.
                self._deep_update_dict(current_value, new_value)
            else:
                setattr(model, field_name, new_value)

    def _deep_update_dict(self, target: dict, source: dict) -> None:
        """Deep-update *target* with *source*, preserving nested dict references
        and deleting keys absent from *source*."""
        for key in set(target) - set(source):
            del target[key]
        for key, source_value in source.items():
            target_value = target.get(key)
            if isinstance(target_value, dict) and isinstance(source_value, dict):
                self._deep_update_dict(target_value, source_value)
            else:
                target[key] = source_value