shape editing

This commit is contained in:
2026-03-02 22:49:45 -04:00
parent f4da40706c
commit bf7af2b426
18 changed files with 2236 additions and 209 deletions

View File

@@ -7,7 +7,7 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import create_react_agent
from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS, INDICATOR_TOOLS, RESEARCH_TOOLS, CHART_TOOLS
from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS, INDICATOR_TOOLS, RESEARCH_TOOLS, CHART_TOOLS, SHAPE_TOOLS
from agent.memory import MemoryManager
from agent.session import SessionManager
from agent.prompts import build_system_prompt
@@ -65,10 +65,10 @@ class AgentExecutor:
# Create agent without a static system prompt
# We'll pass the dynamic system prompt via state_modifier at runtime
# Include all tool categories: sync, datasource, chart, indicator, and research
# Include all tool categories: sync, datasource, chart, indicator, shape, and research
self.agent = create_react_agent(
self.llm,
SYNC_TOOLS + DATASOURCE_TOOLS + CHART_TOOLS + INDICATOR_TOOLS + RESEARCH_TOOLS,
SYNC_TOOLS + DATASOURCE_TOOLS + CHART_TOOLS + INDICATOR_TOOLS + SHAPE_TOOLS + RESEARCH_TOOLS,
checkpointer=checkpointer
)

View File

@@ -30,6 +30,11 @@ def _get_chart_store_context() -> str:
interval = chart_data.get("interval", "N/A")
start_time = chart_data.get("start_time")
end_time = chart_data.get("end_time")
selected_shapes = chart_data.get("selected_shapes", [])
selected_info = ""
if selected_shapes:
selected_info = f"\n- **Selected Shapes**: {len(selected_shapes)} shape(s) selected (IDs: {', '.join(selected_shapes)})"
chart_context = f"""
## Current Chart Context
@@ -37,7 +42,7 @@ def _get_chart_store_context() -> str:
The user is currently viewing a chart with the following settings:
- **Symbol**: {symbol}
- **Interval**: {interval}
- **Time Range**: {f"from {start_time} to {end_time}" if start_time and end_time else "not set"}
- **Time Range**: {f"from {start_time} to {end_time}" if start_time and end_time else "not set"}{selected_info}
This information is automatically available because you're connected via websocket.
When the user refers to "the chart", "this chart", or "what I'm viewing", this is what they mean.

View File

@@ -5,6 +5,7 @@ This package provides tools for:
- Data sources and market data (datasource_tools)
- Chart data access and analysis (chart_tools)
- Technical indicators (indicator_tools)
- Shape/drawing management (shape_tools)
"""
# Global registries that will be set by main.py
@@ -37,6 +38,7 @@ from .datasource_tools import DATASOURCE_TOOLS
from .chart_tools import CHART_TOOLS
from .indicator_tools import INDICATOR_TOOLS
from .research_tools import RESEARCH_TOOLS
from .shape_tools import SHAPE_TOOLS
__all__ = [
"set_registry",
@@ -47,4 +49,5 @@ __all__ = [
"CHART_TOOLS",
"INDICATOR_TOOLS",
"RESEARCH_TOOLS",
"SHAPE_TOOLS",
]

View File

@@ -0,0 +1,475 @@
"""Shape/drawing tools for chart analysis."""
from typing import Dict, Any, List, Optional
from langchain_core.tools import tool
import logging
logger = logging.getLogger(__name__)
# Map legacy/common shape type names to TradingView's native names
SHAPE_TYPE_ALIASES: Dict[str, str] = {
'trendline': 'trend_line',
'fibonacci': 'fib_retracement',
'fibonacci_extension': 'fib_trend_ext',
'gann_fan': 'gannbox_fan',
}
def _get_registry():
"""Get the global registry instance."""
from . import _registry
return _registry
def _get_shape_store():
"""Get the global ShapeStore instance."""
registry = _get_registry()
if registry and "ShapeStore" in registry.entries:
return registry.entries["ShapeStore"].model
return None
@tool
def search_shapes(
start_time: Optional[int] = None,
end_time: Optional[int] = None,
shape_type: Optional[str] = None,
symbol: Optional[str] = None,
shape_ids: Optional[List[str]] = None,
original_ids: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
"""Search for shapes/drawings using flexible filters.
This tool can search shapes by:
- Time range (finds shapes that overlap the range)
- Shape type (e.g., 'trendline', 'horizontal_line')
- Symbol (e.g., 'BINANCE:BTC/USDT')
- Specific shape IDs (TradingView's assigned IDs)
- Original IDs (the IDs you specified when creating shapes)
Args:
start_time: Optional start of time range (Unix timestamp in seconds)
end_time: Optional end of time range (Unix timestamp in seconds)
shape_type: Optional filter by shape type (e.g., 'trend_line', 'horizontal_line', 'rectangle')
symbol: Optional filter by symbol (e.g., 'BINANCE:BTC/USDT')
shape_ids: Optional list of specific shape IDs to retrieve (searches both id and original_id fields)
original_ids: Optional list of original IDs to search for (the IDs you specified when creating)
Returns:
List of matching shapes, each as a dictionary with:
- id: Shape identifier (TradingView's assigned ID)
- original_id: The ID you specified when creating the shape (if applicable)
- type: Shape type
- points: List of control points with time and price
- color, line_width, line_style: Visual properties
- properties: Additional shape-specific properties
- symbol: Symbol the shape is drawn on
- created_at, modified_at: Timestamps
Examples:
# Find all shapes in the currently visible chart range
shapes = search_shapes(
start_time=chart_state.start_time,
end_time=chart_state.end_time
)
# Find only trendlines in a specific time range
trendlines = search_shapes(
start_time=1640000000,
end_time=1650000000,
shape_type='trend_line'
)
# Find shapes for a specific symbol
btc_shapes = search_shapes(
start_time=1640000000,
end_time=1650000000,
symbol='BINANCE:BTC/USDT'
)
# Get specific shapes by TradingView ID or original ID
# This searches both the 'id' and 'original_id' fields
selected = search_shapes(
shape_ids=['trendline-1', 'support-42k', 'fib-retracement-1']
)
# Get shapes by the original IDs you specified when creating them
my_shapes = search_shapes(
original_ids=['my-support-line', 'my-resistance-line']
)
# Get all trendlines (no time filter)
all_trendlines = search_shapes(shape_type='trend_line')
"""
shape_store = _get_shape_store()
if not shape_store:
raise ValueError("ShapeStore not initialized")
shapes_dict = shape_store.shapes
matching_shapes = []
# If specific shape IDs are requested, search by both id and original_id
if shape_ids:
for requested_id in shape_ids:
# First try direct ID lookup
shape = shapes_dict.get(requested_id)
if shape:
# Still apply other filters if specified
if symbol and shape.get('symbol') != symbol:
continue
if shape_type and shape.get('type') != shape_type:
continue
matching_shapes.append(shape)
else:
# If not found by ID, search by original_id
for shape_id, shape in shapes_dict.items():
if shape.get('original_id') == requested_id:
# Still apply other filters if specified
if symbol and shape.get('symbol') != symbol:
continue
if shape_type and shape.get('type') != shape_type:
continue
matching_shapes.append(shape)
break
logger.info(
f"Found {len(matching_shapes)} shapes by ID filter (requested {len(shape_ids)} IDs)"
+ (f" for type '{shape_type}'" if shape_type else "")
+ (f" on symbol '{symbol}'" if symbol else "")
)
return matching_shapes
# If specific original IDs are requested, search by original_id only
if original_ids:
for original_id in original_ids:
for shape_id, shape in shapes_dict.items():
if shape.get('original_id') == original_id:
# Still apply other filters if specified
if symbol and shape.get('symbol') != symbol:
continue
if shape_type and shape.get('type') != shape_type:
continue
matching_shapes.append(shape)
break
logger.info(
f"Found {len(matching_shapes)} shapes by original_id filter (requested {len(original_ids)} IDs)"
+ (f" for type '{shape_type}'" if shape_type else "")
+ (f" on symbol '{symbol}'" if symbol else "")
)
return matching_shapes
# Otherwise, search all shapes with filters
for shape_id, shape in shapes_dict.items():
# Filter by symbol if specified
if symbol and shape.get('symbol') != symbol:
continue
# Filter by type if specified
if shape_type and shape.get('type') != shape_type:
continue
# Filter by time range if specified
if start_time is not None and end_time is not None:
# Check if any control point falls within the time range
# or if the shape spans across the time range
points = shape.get('points', [])
if not points:
continue
# Get min and max times from shape's control points
shape_times = [point['time'] for point in points]
shape_min_time = min(shape_times)
shape_max_time = max(shape_times)
# Check for overlap: shape overlaps if its range intersects with query range
if not (shape_max_time >= start_time and shape_min_time <= end_time):
continue
matching_shapes.append(shape)
logger.info(
f"Found {len(matching_shapes)} shapes"
+ (f" in time range {start_time}-{end_time}" if start_time and end_time else "")
+ (f" for type '{shape_type}'" if shape_type else "")
+ (f" on symbol '{symbol}'" if symbol else "")
)
return matching_shapes
@tool
async def create_or_update_shape(
shape_id: str,
shape_type: str,
points: List[Dict[str, Any]],
color: Optional[str] = None,
line_width: Optional[int] = None,
line_style: Optional[str] = None,
properties: Optional[Dict[str, Any]] = None,
symbol: Optional[str] = None
) -> Dict[str, Any]:
"""Create a new shape or update an existing shape on the chart.
This tool allows the agent to draw shapes on the user's chart or modify
existing shapes. Shapes are synchronized to the frontend in real-time.
IMPORTANT - Shape ID Mapping:
When you create a shape, TradingView will assign its own internal ID that differs
from the shape_id you provide. The shape will be updated in the store with:
- id: TradingView's assigned ID
- original_id: The shape_id you provided
To find your shape later, use search_shapes() and filter by original_id field.
Example:
# Create a shape
await create_or_update_shape(shape_id='my-support', ...)
# Later, find it by original_id
shapes = search_shapes(symbol='BINANCE:BTC/USDT')
my_shape = next((s for s in shapes if s.get('original_id') == 'my-support'), None)
Args:
shape_id: Unique identifier for the shape (use existing ID to update, new ID to create)
Note: TradingView will assign its own ID; your ID will be stored in original_id
shape_type: Type of shape using TradingView's native names.
Single-point shapes (use 1 point):
- 'horizontal_line': Horizontal support/resistance line
- 'vertical_line': Vertical time marker
- 'text': Text label
- 'anchored_text': Anchored text annotation
- 'anchored_note': Anchored note
- 'note': Note annotation
- 'emoji': Emoji marker
- 'icon': Icon marker
- 'sticker': Sticker marker
- 'arrow_up': Upward arrow marker
- 'arrow_down': Downward arrow marker
- 'flag': Flag marker
- 'long_position': Long position marker
- 'short_position': Short position marker
Multi-point shapes (use 2+ points):
- 'trend_line': Trendline (2 points)
- 'rectangle': Rectangle (2 points: top-left, bottom-right)
- 'fib_retracement': Fibonacci retracement (2 points)
- 'fib_trend_ext': Fibonacci extension (3 points)
- 'parallel_channel': Parallel channel (3 points)
- 'arrow': Arrow (2 points)
- 'circle': Circle/ellipse (2-3 points)
- 'path': Free drawing path (3+ points)
- 'pitchfork': Andrew's pitchfork (3 points)
- 'gannbox_fan': Gann fan (2 points)
- 'head_and_shoulders': Head and shoulders pattern (5 points)
points: List of control points, each with 'time' (Unix seconds) and 'price' fields
color: Optional color (hex like '#FF0000' or name like 'red')
line_width: Optional line width in pixels (default: 1)
line_style: Optional line style: 'solid', 'dashed', 'dotted' (default: 'solid')
properties: Optional dict of additional shape-specific properties
symbol: Optional symbol to associate with the shape (defaults to current chart symbol)
Returns:
Dictionary with:
- status: 'created' or 'updated'
- shape: The complete shape object (initially with your ID, will be updated to TV ID)
Examples:
# Draw a trendline between two points
await create_or_update_shape(
shape_id='my-trendline-1',
shape_type='trend_line',
points=[
{'time': 1640000000, 'price': 45000.0},
{'time': 1650000000, 'price': 50000.0}
],
color='#00FF00',
line_width=2
)
# Draw a horizontal support line
await create_or_update_shape(
shape_id='support-1',
shape_type='horizontal_line',
points=[{'time': 1640000000, 'price': 42000.0}],
color='blue',
line_style='dashed'
)
# Find your shape after creation using original_id
shapes = search_shapes(symbol='BINANCE:BTC/USDT')
my_shape = next((s for s in shapes if s.get('original_id') == 'support-1'), None)
if my_shape:
print(f"TradingView assigned ID: {my_shape['id']}")
"""
from schema.shape import Shape, ControlPoint
import time as time_module
registry = _get_registry()
if not registry:
raise ValueError("SyncRegistry not initialized")
shape_store = _get_shape_store()
if not shape_store:
raise ValueError("ShapeStore not initialized")
# Normalize shape type (handle legacy names)
normalized_type = SHAPE_TYPE_ALIASES.get(shape_type, shape_type)
if normalized_type != shape_type:
logger.info(f"Normalized shape type '{shape_type}' -> '{normalized_type}'")
# Convert points to ControlPoint objects
control_points = []
for p in points:
point_data = {
'time': p['time'],
'price': p['price']
}
# Only include channel if it's actually provided
if 'channel' in p and p['channel'] is not None:
point_data['channel'] = p['channel']
control_points.append(ControlPoint(**point_data))
# Check if updating existing shape
existing_shape = shape_store.shapes.get(shape_id)
is_update = existing_shape is not None
# If symbol is not provided, try to get it from ChartStore
if symbol is None and "ChartStore" in registry.entries:
chart_store = registry.entries["ChartStore"].model
if hasattr(chart_store, 'chart_state') and hasattr(chart_store.chart_state, 'symbol'):
symbol = chart_store.chart_state.symbol
logger.info(f"Using current chart symbol for shape: {symbol}")
now = int(time_module.time())
# Create shape object
shape = Shape(
id=shape_id,
type=normalized_type,
points=control_points,
color=color,
line_width=line_width,
line_style=line_style,
properties=properties or {},
symbol=symbol,
created_at=existing_shape.get('created_at') if existing_shape else now,
modified_at=now
)
# Update the store
shape_store.shapes[shape_id] = shape.model_dump(mode="json")
# Trigger sync
await registry.push_all()
logger.info(
f"{'Updated' if is_update else 'Created'} shape '{shape_id}' "
f"of type '{shape_type}' with {len(points)} points"
)
return {
"status": "updated" if is_update else "created",
"shape": shape.model_dump(mode="json")
}
@tool
async def delete_shape(shape_id: str) -> Dict[str, str]:
"""Delete a shape from the chart.
Args:
shape_id: ID of the shape to delete
Returns:
Dictionary with status message
Raises:
ValueError: If shape doesn't exist
Example:
await delete_shape('my-trendline-1')
"""
registry = _get_registry()
if not registry:
raise ValueError("SyncRegistry not initialized")
shape_store = _get_shape_store()
if not shape_store:
raise ValueError("ShapeStore not initialized")
if shape_id not in shape_store.shapes:
raise ValueError(f"Shape '{shape_id}' not found")
# Delete the shape
del shape_store.shapes[shape_id]
# Trigger sync
await registry.push_all()
logger.info(f"Deleted shape '{shape_id}'")
return {
"status": "success",
"message": f"Shape '{shape_id}' deleted"
}
@tool
def get_shape(shape_id: str) -> Dict[str, Any]:
"""Get details of a specific shape by ID.
Args:
shape_id: ID of the shape to retrieve
Returns:
Dictionary containing the shape data
Raises:
ValueError: If shape doesn't exist
Example:
shape = get_shape('my-trendline-1')
print(f"Shape type: {shape['type']}")
print(f"Points: {shape['points']}")
"""
shape_store = _get_shape_store()
if not shape_store:
raise ValueError("ShapeStore not initialized")
shape = shape_store.shapes.get(shape_id)
if not shape:
raise ValueError(f"Shape '{shape_id}' not found")
return shape
@tool
def list_all_shapes() -> List[Dict[str, Any]]:
"""List all shapes currently on the chart.
Returns:
List of all shapes as dictionaries
Example:
shapes = list_all_shapes()
print(f"Total shapes: {len(shapes)}")
for shape in shapes:
print(f" - {shape['id']}: {shape['type']}")
"""
shape_store = _get_shape_store()
if not shape_store:
raise ValueError("ShapeStore not initialized")
return list(shape_store.shapes.values())
SHAPE_TOOLS = [
search_shapes,
create_or_update_shape,
delete_shape,
get_shape,
list_all_shapes
]

View File

@@ -23,6 +23,7 @@ from agent.core import create_agent
from agent.tools import set_registry, set_datasource_registry, set_indicator_registry
from schema.order_spec import SwapOrder
from schema.chart_state import ChartState
from schema.shape import ShapeCollection
from datasource.registry import DataSourceRegistry
from datasource.subscription_manager import SubscriptionManager
from datasource.websocket_handler import DatafeedWebSocketHandler
@@ -124,7 +125,7 @@ async def lifespan(app: FastAPI):
chroma_db_path=config["memory"]["chroma_db"],
embedding_model=config["memory"]["embedding_model"],
context_docs_dir=config["agent"]["context_docs_dir"],
base_dir=".." # Point to project root from backend/src
base_dir="." # backend/src is the working directory, so . goes to backend, where memory/ lives
)
await agent_executor.initialize()
@@ -159,13 +160,19 @@ class OrderStore(BaseModel):
class ChartStore(BaseModel):
chart_state: ChartState = ChartState()
# ShapeStore model for synchronization
class ShapeStore(BaseModel):
shapes: dict[str, dict] = {} # Dictionary of shapes keyed by ID
# Initialize stores
order_store = OrderStore()
chart_store = ChartStore()
shape_store = ShapeStore()
# Register with SyncRegistry
registry.register(order_store, store_name="OrderStore")
registry.register(chart_store, store_name="ChartStore")
registry.register(shape_store, store_name="ShapeStore")
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
@@ -361,11 +368,14 @@ async def websocket_endpoint(websocket: WebSocket):
elif msg_type == "patch":
patch_msg = PatchMessage(**message_json)
logger.info(f"Patch message received for store: {patch_msg.store}, seq: {patch_msg.seq}")
await registry.apply_client_patch(
store_name=patch_msg.store,
client_base_seq=patch_msg.seq,
patch=patch_msg.patch
)
try:
await registry.apply_client_patch(
store_name=patch_msg.store,
client_base_seq=patch_msg.seq,
patch=patch_msg.patch
)
except Exception as e:
logger.error(f"Error applying client patch: {e}. Client will receive snapshot to resync.", exc_info=True)
elif msg_type == "agent_user_message":
# Handle agent messages directly here
print(f"[DEBUG] Raw message_json: {message_json}")

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, List
from pydantic import BaseModel, Field
@@ -23,3 +23,6 @@ class ChartState(BaseModel):
# Optional: Chart interval/resolution
# None when chart is not visible
interval: Optional[str] = Field(default="15", description="Chart interval (e.g., '1', '5', '15', '60', 'D'), or None if no chart visible")
# Selected shapes/drawings on the chart
selected_shapes: List[str] = Field(default_factory=list, description="Array of selected shape IDs")

View File

@@ -0,0 +1,44 @@
from typing import List, Dict, Any, Optional
from pydantic import BaseModel, Field
class ControlPoint(BaseModel):
"""A control point for a drawing shape.
Control points define the position and properties of a shape.
Different shapes have different numbers of control points.
"""
time: int = Field(..., description="Unix timestamp in seconds")
price: float = Field(..., description="Price level")
# Optional channel for multi-point shapes (e.g., parallel channels)
channel: Optional[str] = Field(default=None, description="Channel identifier for multi-point shapes")
class Shape(BaseModel):
"""A TradingView drawing shape/study.
Represents any drawing the user creates on the chart (trendlines,
horizontal lines, rectangles, Fibonacci retracements, etc.)
"""
id: str = Field(..., description="Unique identifier for the shape")
type: str = Field(..., description="Shape type (e.g., 'trendline', 'horizontal_line', 'rectangle', 'fibonacci')")
points: List[ControlPoint] = Field(default_factory=list, description="Control points that define the shape")
# Visual properties
color: Optional[str] = Field(default=None, description="Shape color (hex or color name)")
line_width: Optional[int] = Field(default=1, description="Line width in pixels")
line_style: Optional[str] = Field(default="solid", description="Line style: 'solid', 'dashed', 'dotted'")
# Shape-specific properties stored as flexible dict
properties: Dict[str, Any] = Field(default_factory=dict, description="Additional shape-specific properties")
# Metadata
symbol: Optional[str] = Field(default=None, description="Symbol this shape is drawn on")
created_at: Optional[int] = Field(default=None, description="Creation timestamp (Unix seconds)")
modified_at: Optional[int] = Field(default=None, description="Last modification timestamp (Unix seconds)")
original_id: Optional[str] = Field(default=None, description="Original ID from backend/agent before TradingView assigns its own ID")
class ShapeCollection(BaseModel):
"""Collection of all shapes/drawings on the chart."""
shapes: Dict[str, Shape] = Field(default_factory=dict, description="Dictionary of shapes keyed by ID")

View File

@@ -105,67 +105,105 @@ class SyncRegistry:
logger.info(f"apply_client_patch: Current backend seq={entry.seq}")
if client_base_seq == entry.seq:
# No conflict
logger.info("apply_client_patch: No conflict - applying patch directly")
current_state = entry.model.model_dump(mode="json")
logger.info(f"apply_client_patch: Current state before patch: {current_state}")
new_state = jsonpatch.apply_patch(current_state, patch)
logger.info(f"apply_client_patch: New state after patch: {new_state}")
self._update_model(entry.model, new_state)
try:
if client_base_seq == entry.seq:
# No conflict
logger.info("apply_client_patch: No conflict - applying patch directly")
current_state = entry.model.model_dump(mode="json")
logger.info(f"apply_client_patch: Current state before patch: {current_state}")
try:
new_state = jsonpatch.apply_patch(current_state, patch)
logger.info(f"apply_client_patch: New state after patch: {new_state}")
self._update_model(entry.model, new_state)
entry.commit_patch(patch)
logger.info(f"apply_client_patch: Patch committed, new seq={entry.seq}")
# Don't broadcast back to client - they already have this change
# Broadcasting would cause an infinite loop
logger.info("apply_client_patch: Not broadcasting back to originating client")
elif client_base_seq < entry.seq:
# Conflict! Frontend wins.
# 1. Get backend patches since client_base_seq
backend_patches = []
for seq, p in entry.history:
if seq > client_base_seq:
backend_patches.append(p)
# 2. Apply frontend patch first to the state at client_base_seq
# But we only have the current authoritative model.
# "Apply the frontend patch first to the model (frontend wins)"
# "Re-apply the backend deltas that do not overlap the frontend's changed paths on top"
# Let's get the state as it was at client_base_seq if possible?
# No, history only has patches.
# Alternative: Apply frontend patch to current model.
# Then re-apply backend patches, but discard parts that overlap.
frontend_paths = {p['path'] for p in patch}
current_state = entry.model.model_dump(mode="json")
# Apply frontend patch
new_state = jsonpatch.apply_patch(current_state, patch)
# Re-apply backend patches that don't overlap
for b_patch in backend_patches:
filtered_b_patch = [op for op in b_patch if op['path'] not in frontend_paths]
if filtered_b_patch:
new_state = jsonpatch.apply_patch(new_state, filtered_b_patch)
self._update_model(entry.model, new_state)
# Commit the result as a single new patch
# We need to compute what changed from last_snapshot to new_state
final_patch = jsonpatch.make_patch(entry.last_snapshot, new_state).patch
if final_patch:
entry.commit_patch(final_patch)
# Broadcast resolved state as snapshot to converge
if self.websocket:
msg = SnapshotMessage(
store=entry.store_name,
seq=entry.seq,
state=entry.model.model_dump(mode="json")
)
await self.websocket.send_json(msg.model_dump(mode="json"))
entry.commit_patch(patch)
logger.info(f"apply_client_patch: Patch committed, new seq={entry.seq}")
# Don't broadcast back to client - they already have this change
# Broadcasting would cause an infinite loop
logger.info("apply_client_patch: Not broadcasting back to originating client")
except jsonpatch.JsonPatchConflict as e:
logger.warning(f"apply_client_patch: Patch conflict on no-conflict path: {e}. Sending snapshot to resync.")
# Send snapshot to force resync
if self.websocket:
msg = SnapshotMessage(
store=entry.store_name,
seq=entry.seq,
state=entry.model.model_dump(mode="json")
)
await self.websocket.send_json(msg.model_dump(mode="json"))
elif client_base_seq < entry.seq:
# Conflict! Frontend wins.
# 1. Get backend patches since client_base_seq
backend_patches = []
for seq, p in entry.history:
if seq > client_base_seq:
backend_patches.append(p)
# 2. Apply frontend patch first to the state at client_base_seq
# But we only have the current authoritative model.
# "Apply the frontend patch first to the model (frontend wins)"
# "Re-apply the backend deltas that do not overlap the frontend's changed paths on top"
# Let's get the state as it was at client_base_seq if possible?
# No, history only has patches.
# Alternative: Apply frontend patch to current model.
# Then re-apply backend patches, but discard parts that overlap.
frontend_paths = {p['path'] for p in patch}
current_state = entry.model.model_dump(mode="json")
# Apply frontend patch
try:
new_state = jsonpatch.apply_patch(current_state, patch)
except jsonpatch.JsonPatchConflict as e:
logger.warning(f"apply_client_patch: Failed to apply client patch during conflict resolution: {e}. Sending snapshot to resync.")
# Send snapshot to force resync
if self.websocket:
msg = SnapshotMessage(
store=entry.store_name,
seq=entry.seq,
state=entry.model.model_dump(mode="json")
)
await self.websocket.send_json(msg.model_dump(mode="json"))
return
# Re-apply backend patches that don't overlap
for b_patch in backend_patches:
filtered_b_patch = [op for op in b_patch if op['path'] not in frontend_paths]
if filtered_b_patch:
try:
new_state = jsonpatch.apply_patch(new_state, filtered_b_patch)
except jsonpatch.JsonPatchConflict as e:
logger.warning(f"apply_client_patch: Failed to apply backend patch during conflict resolution: {e}. Skipping this patch.")
continue
self._update_model(entry.model, new_state)
# Commit the result as a single new patch
# We need to compute what changed from last_snapshot to new_state
final_patch = jsonpatch.make_patch(entry.last_snapshot, new_state).patch
if final_patch:
entry.commit_patch(final_patch)
# Broadcast resolved state as snapshot to converge
if self.websocket:
msg = SnapshotMessage(
store=entry.store_name,
seq=entry.seq,
state=entry.model.model_dump(mode="json")
)
await self.websocket.send_json(msg.model_dump(mode="json"))
except Exception as e:
logger.error(f"apply_client_patch: Unexpected error: {e}. Sending snapshot to resync.", exc_info=True)
# Send snapshot to force resync
if self.websocket:
msg = SnapshotMessage(
store=entry.store_name,
seq=entry.seq,
state=entry.model.model_dump(mode="json")
)
await self.websocket.send_json(msg.model_dump(mode="json"))
def _update_model(self, model: BaseModel, new_data: Dict[str, Any]):
# Update model using model_validate for potentially nested models