"""Websocket state synchronization.

Registry of pydantic-backed stores mirrored to a single websocket client
via JSON Patch deltas, with full-snapshot fallback whenever a client
cannot be converged incrementally.
"""
import logging
from collections import deque
from typing import Any, Dict, List, Optional, Tuple, Deque

import jsonpatch
from pydantic import BaseModel

from sync.protocol import SnapshotMessage, PatchMessage


class SyncEntry:
    """Tracks one synced pydantic model: its last broadcast state, a
    monotonically increasing sequence number, and a bounded history of
    JSON Patch deltas used to catch up lagging clients."""

    def __init__(self, model: BaseModel, store_name: str, history_size: int = 50):
        self.model = model
        self.store_name = store_name
        # Sequence number of the last committed patch; 0 means "initial state".
        self.seq = 0
        # JSON-serializable snapshot the next patch will be diffed against.
        self.last_snapshot = model.model_dump(mode="json")
        # Bounded ring of (seq, patch) pairs; oldest entries fall off automatically.
        self.history: Deque[Tuple[int, List[Dict[str, Any]]]] = deque(maxlen=history_size)

    def compute_patch(self) -> Optional[List[Dict[str, Any]]]:
        """Return the JSON Patch from the last snapshot to the model's current
        state, or None when nothing changed. Does not advance ``seq``."""
        current_state = self.model.model_dump(mode="json")
        patch = jsonpatch.make_patch(self.last_snapshot, current_state)
        if not patch.patch:
            return None
        return patch.patch

    def commit_patch(self, patch: List[Dict[str, Any]]):
        """Record *patch* in history under a new sequence number and refresh
        the baseline snapshot from the model's current state."""
        self.seq += 1
        self.history.append((self.seq, patch))
        self.last_snapshot = self.model.model_dump(mode="json")

    def catchup_patches(self, since_seq: int) -> Optional[List[Tuple[int, List[Dict[str, Any]]]]]:
        """Return the (seq, patch) pairs a client at *since_seq* is missing.

        Returns ``[]`` when the client is already current, and ``None`` when
        the client cannot be caught up with patches alone — either its seq is
        ahead of ours (out of sync, e.g. after a server restart) or the
        needed patches have been evicted from history — so the caller should
        send a full snapshot instead.
        """
        if since_seq == self.seq:
            return []

        # A client claiming a sequence number we have not issued is out of
        # sync; force a snapshot rather than silently returning an empty
        # patch list that would leave it stuck.
        if since_seq > self.seq:
            return None

        # History must contain every patch from since_seq + 1 to self.seq.
        if not self.history or self.history[0][0] > since_seq + 1:
            return None

        return [(seq, patch) for seq, patch in self.history if seq > since_seq]
class SyncRegistry:
    """Registry of synced stores, mirrored to one websocket client.

    Each registered pydantic model is wrapped in a :class:`SyncEntry`;
    local changes are pushed as JSON Patch deltas, and a full snapshot is
    sent whenever a client cannot be converged incrementally.
    """

    def __init__(self):
        self.entries: Dict[str, SyncEntry] = {}
        self.websocket: Optional[Any] = None  # Expecting a FastAPI WebSocket or similar

    def register(self, model: BaseModel, store_name: str, history_size: int = 50):
        """Start syncing *model* under *store_name*.

        *history_size* bounds the per-store patch history kept for client
        catch-up (default matches SyncEntry's previous hard-coded size).
        """
        self.entries[store_name] = SyncEntry(model, store_name, history_size)

    async def _send_snapshot(self, entry: SyncEntry):
        """Send *entry*'s full authoritative state to the client; no-op when
        no websocket is connected."""
        if not self.websocket:
            return
        msg = SnapshotMessage(
            store=entry.store_name,
            seq=entry.seq,
            state=entry.model.model_dump(mode="json")
        )
        await self.websocket.send_json(msg.model_dump(mode="json"))

    async def push_all(self):
        """Diff every store against its last broadcast snapshot and push the
        resulting patches (if any) to the connected client."""
        logger = logging.getLogger(__name__)

        if not self.websocket:
            logger.warning("push_all: No websocket connected, cannot push updates")
            return

        logger.info(f"push_all: Processing {len(self.entries)} store entries")
        for entry in self.entries.values():
            patch = entry.compute_patch()
            if patch:
                logger.info(f"push_all: Found patch for store '{entry.store_name}': {patch}")
                entry.commit_patch(patch)
                msg = PatchMessage(store=entry.store_name, seq=entry.seq, patch=patch)
                logger.info(f"push_all: Sending patch message for '{entry.store_name}' seq={entry.seq}")
                await self.websocket.send_json(msg.model_dump(mode="json"))
                logger.info(f"push_all: Patch sent successfully for '{entry.store_name}'")
            else:
                logger.debug(f"push_all: No changes detected for store '{entry.store_name}'")

    async def sync_client(self, client_seqs: Dict[str, int]):
        """Bring the client up to date: replay missing patches per store, or
        send a full snapshot when catch-up patches are unavailable.

        *client_seqs* maps store name -> last seq the client has applied;
        missing stores default to -1 (client has nothing).
        """
        if not self.websocket:
            return

        for store_name, entry in self.entries.items():
            client_seq = client_seqs.get(store_name, -1)
            patches = entry.catchup_patches(client_seq)

            if patches is not None:
                # Replay the patches the client is missing, in order.
                for seq, patch in patches:
                    msg = PatchMessage(store=store_name, seq=seq, patch=patch)
                    await self.websocket.send_json(msg.model_dump(mode="json"))
            else:
                # History cannot bridge the gap - send the full state.
                await self._send_snapshot(entry)

    async def apply_client_patch(self, store_name: str, client_base_seq: int, patch: List[Dict[str, Any]]):
        """Apply a patch received from the client against *store_name*.

        *client_base_seq* is the backend seq the client based its patch on.
        When it matches our seq the patch applies directly; when it lags,
        the conflict is resolved frontend-wins; when it is ahead of us the
        client is out of sync (e.g. backend restart) and gets a snapshot.
        Any unexpected failure also falls back to a snapshot so the client
        re-converges.
        """
        logger = logging.getLogger(__name__)

        logger.info(f"apply_client_patch: store={store_name}, client_base_seq={client_base_seq}, patch={patch}")

        entry = self.entries.get(store_name)
        if not entry:
            logger.warning(f"apply_client_patch: Store '{store_name}' not found in registry")
            return

        logger.info(f"apply_client_patch: Current backend seq={entry.seq}")

        try:
            if client_base_seq == entry.seq:
                await self._apply_without_conflict(entry, patch, logger)
            elif client_base_seq < entry.seq:
                await self._apply_with_conflict(entry, patch, client_base_seq, logger)
            else:
                # Previously this case was silently ignored, leaving the
                # client permanently out of sync; force a snapshot instead.
                logger.warning(
                    f"apply_client_patch: client_base_seq={client_base_seq} is ahead of "
                    f"backend seq={entry.seq}. Sending snapshot to resync."
                )
                await self._send_snapshot(entry)
        except Exception as e:
            logger.error(f"apply_client_patch: Unexpected error: {e}. Sending snapshot to resync.", exc_info=True)
            # Send snapshot to force resync
            await self._send_snapshot(entry)

    async def _apply_without_conflict(self, entry: SyncEntry, patch: List[Dict[str, Any]], logger):
        """Fast path: the client patched our latest state; apply and commit."""
        logger.info("apply_client_patch: No conflict - applying patch directly")
        current_state = entry.model.model_dump(mode="json")
        logger.info(f"apply_client_patch: Current state before patch: {current_state}")
        try:
            new_state = jsonpatch.apply_patch(current_state, patch)
        except jsonpatch.JsonPatchConflict as e:
            logger.warning(f"apply_client_patch: Patch conflict on no-conflict path: {e}. Sending snapshot to resync.")
            # Send snapshot to force resync
            await self._send_snapshot(entry)
            return

        logger.info(f"apply_client_patch: New state after patch: {new_state}")
        self._update_model(entry.model, new_state)

        # Verify the model was actually updated
        updated_state = entry.model.model_dump(mode="json")
        logger.info(f"apply_client_patch: Model state after _update_model: {updated_state}")

        entry.commit_patch(patch)
        logger.info(f"apply_client_patch: Patch committed, new seq={entry.seq}")
        # Don't broadcast back to the originating client - they already have
        # this change, and echoing it would cause an infinite loop.
        logger.info("apply_client_patch: Not broadcasting back to originating client")

    async def _apply_with_conflict(self, entry: SyncEntry, patch: List[Dict[str, Any]], client_base_seq: int, logger):
        """Frontend-wins merge for a patch based on an older backend seq.

        We only hold the current authoritative state (history stores patches,
        not snapshots), so: apply the client's patch to the current state,
        then replay the backend patches issued since *client_base_seq* on
        top, minus any ops whose path the client also touched. The merged
        result is committed and broadcast as a snapshot so both sides
        converge.

        NOTE(review): overlap detection compares exact paths only - a client
        change at '/a' does not suppress a backend op at '/a/b'. TODO confirm
        whether prefix overlap should also count.
        """
        # 1. Backend patches the client has not seen yet.
        backend_patches = [p for seq, p in entry.history if seq > client_base_seq]

        # 2. Paths the frontend changed; backend ops on these paths lose.
        frontend_paths = {op['path'] for op in patch}

        current_state = entry.model.model_dump(mode="json")
        # Apply frontend patch first (frontend wins).
        try:
            new_state = jsonpatch.apply_patch(current_state, patch)
        except jsonpatch.JsonPatchConflict as e:
            logger.warning(f"apply_client_patch: Failed to apply client patch during conflict resolution: {e}. Sending snapshot to resync.")
            # Send snapshot to force resync
            await self._send_snapshot(entry)
            return

        # 3. Re-apply backend patches that don't overlap the client's changes.
        for b_patch in backend_patches:
            filtered_b_patch = [op for op in b_patch if op['path'] not in frontend_paths]
            if filtered_b_patch:
                try:
                    new_state = jsonpatch.apply_patch(new_state, filtered_b_patch)
                except jsonpatch.JsonPatchConflict as e:
                    logger.warning(f"apply_client_patch: Failed to apply backend patch during conflict resolution: {e}. Skipping this patch.")
                    continue

        self._update_model(entry.model, new_state)

        # Commit whatever changed relative to the last broadcast snapshot as
        # a single new patch.
        final_patch = jsonpatch.make_patch(entry.last_snapshot, new_state).patch
        if final_patch:
            entry.commit_patch(final_patch)
        # Broadcast resolved state as snapshot to converge
        await self._send_snapshot(entry)

    def _update_model(self, model: BaseModel, new_data: Dict[str, Any]):
        """Copy *new_data* onto *model* field by field.

        Dict-valued fields are updated in place (rather than replaced) so
        that references to those dicts held elsewhere keep seeing current
        data.
        """
        # model_fields is a class-level attribute in pydantic v2; going via
        # the class avoids the instance-access deprecation warning.
        for field_name in type(model).model_fields:
            if field_name in new_data:
                new_value = new_data[field_name]
                current_value = getattr(model, field_name)

                if isinstance(current_value, dict) and isinstance(new_value, dict):
                    # Preserve the existing dict object's identity.
                    self._deep_update_dict(current_value, new_value)
                else:
                    # For other types, just set the new value.
                    setattr(model, field_name, new_value)

    def _deep_update_dict(self, target: dict, source: dict):
        """Deep update target dict with source dict, preserving nested dict references."""
        # Remove keys that are in target but not in source.
        for key in set(target.keys()) - set(source.keys()):
            del target[key]

        # Update or add keys from source.
        for key, source_value in source.items():
            target_value = target.get(key)
            if isinstance(target_value, dict) and isinstance(source_value, dict):
                # Both sides are dicts: recurse to keep nested objects alive.
                self._deep_update_dict(target_value, source_value)
            else:
                # New key, or a plain value: replace outright.
                target[key] = source_value