shape editing
This commit is contained in:
@@ -7,7 +7,7 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS, INDICATOR_TOOLS, RESEARCH_TOOLS, CHART_TOOLS
|
||||
from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS, INDICATOR_TOOLS, RESEARCH_TOOLS, CHART_TOOLS, SHAPE_TOOLS
|
||||
from agent.memory import MemoryManager
|
||||
from agent.session import SessionManager
|
||||
from agent.prompts import build_system_prompt
|
||||
@@ -65,10 +65,10 @@ class AgentExecutor:
|
||||
|
||||
# Create agent without a static system prompt
|
||||
# We'll pass the dynamic system prompt via state_modifier at runtime
|
||||
# Include all tool categories: sync, datasource, chart, indicator, and research
|
||||
# Include all tool categories: sync, datasource, chart, indicator, shape, and research
|
||||
self.agent = create_react_agent(
|
||||
self.llm,
|
||||
SYNC_TOOLS + DATASOURCE_TOOLS + CHART_TOOLS + INDICATOR_TOOLS + RESEARCH_TOOLS,
|
||||
SYNC_TOOLS + DATASOURCE_TOOLS + CHART_TOOLS + INDICATOR_TOOLS + SHAPE_TOOLS + RESEARCH_TOOLS,
|
||||
checkpointer=checkpointer
|
||||
)
|
||||
|
||||
|
||||
@@ -30,6 +30,11 @@ def _get_chart_store_context() -> str:
|
||||
interval = chart_data.get("interval", "N/A")
|
||||
start_time = chart_data.get("start_time")
|
||||
end_time = chart_data.get("end_time")
|
||||
selected_shapes = chart_data.get("selected_shapes", [])
|
||||
|
||||
selected_info = ""
|
||||
if selected_shapes:
|
||||
selected_info = f"\n- **Selected Shapes**: {len(selected_shapes)} shape(s) selected (IDs: {', '.join(selected_shapes)})"
|
||||
|
||||
chart_context = f"""
|
||||
## Current Chart Context
|
||||
@@ -37,7 +42,7 @@ def _get_chart_store_context() -> str:
|
||||
The user is currently viewing a chart with the following settings:
|
||||
- **Symbol**: {symbol}
|
||||
- **Interval**: {interval}
|
||||
- **Time Range**: {f"from {start_time} to {end_time}" if start_time and end_time else "not set"}
|
||||
- **Time Range**: {f"from {start_time} to {end_time}" if start_time and end_time else "not set"}{selected_info}
|
||||
|
||||
This information is automatically available because you're connected via websocket.
|
||||
When the user refers to "the chart", "this chart", or "what I'm viewing", this is what they mean.
|
||||
|
||||
@@ -5,6 +5,7 @@ This package provides tools for:
|
||||
- Data sources and market data (datasource_tools)
|
||||
- Chart data access and analysis (chart_tools)
|
||||
- Technical indicators (indicator_tools)
|
||||
- Shape/drawing management (shape_tools)
|
||||
"""
|
||||
|
||||
# Global registries that will be set by main.py
|
||||
@@ -37,6 +38,7 @@ from .datasource_tools import DATASOURCE_TOOLS
|
||||
from .chart_tools import CHART_TOOLS
|
||||
from .indicator_tools import INDICATOR_TOOLS
|
||||
from .research_tools import RESEARCH_TOOLS
|
||||
from .shape_tools import SHAPE_TOOLS
|
||||
|
||||
__all__ = [
|
||||
"set_registry",
|
||||
@@ -47,4 +49,5 @@ __all__ = [
|
||||
"CHART_TOOLS",
|
||||
"INDICATOR_TOOLS",
|
||||
"RESEARCH_TOOLS",
|
||||
"SHAPE_TOOLS",
|
||||
]
|
||||
|
||||
475
backend/src/agent/tools/shape_tools.py
Normal file
475
backend/src/agent/tools/shape_tools.py
Normal file
@@ -0,0 +1,475 @@
|
||||
"""Shape/drawing tools for chart analysis."""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from langchain_core.tools import tool
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Map legacy/common shape type names to TradingView's native names
|
||||
SHAPE_TYPE_ALIASES: Dict[str, str] = {
|
||||
'trendline': 'trend_line',
|
||||
'fibonacci': 'fib_retracement',
|
||||
'fibonacci_extension': 'fib_trend_ext',
|
||||
'gann_fan': 'gannbox_fan',
|
||||
}
|
||||
|
||||
|
||||
def _get_registry():
|
||||
"""Get the global registry instance."""
|
||||
from . import _registry
|
||||
return _registry
|
||||
|
||||
|
||||
def _get_shape_store():
|
||||
"""Get the global ShapeStore instance."""
|
||||
registry = _get_registry()
|
||||
if registry and "ShapeStore" in registry.entries:
|
||||
return registry.entries["ShapeStore"].model
|
||||
return None
|
||||
|
||||
|
||||
@tool
|
||||
def search_shapes(
|
||||
start_time: Optional[int] = None,
|
||||
end_time: Optional[int] = None,
|
||||
shape_type: Optional[str] = None,
|
||||
symbol: Optional[str] = None,
|
||||
shape_ids: Optional[List[str]] = None,
|
||||
original_ids: Optional[List[str]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search for shapes/drawings using flexible filters.
|
||||
|
||||
This tool can search shapes by:
|
||||
- Time range (finds shapes that overlap the range)
|
||||
- Shape type (e.g., 'trendline', 'horizontal_line')
|
||||
- Symbol (e.g., 'BINANCE:BTC/USDT')
|
||||
- Specific shape IDs (TradingView's assigned IDs)
|
||||
- Original IDs (the IDs you specified when creating shapes)
|
||||
|
||||
Args:
|
||||
start_time: Optional start of time range (Unix timestamp in seconds)
|
||||
end_time: Optional end of time range (Unix timestamp in seconds)
|
||||
shape_type: Optional filter by shape type (e.g., 'trend_line', 'horizontal_line', 'rectangle')
|
||||
symbol: Optional filter by symbol (e.g., 'BINANCE:BTC/USDT')
|
||||
shape_ids: Optional list of specific shape IDs to retrieve (searches both id and original_id fields)
|
||||
original_ids: Optional list of original IDs to search for (the IDs you specified when creating)
|
||||
|
||||
Returns:
|
||||
List of matching shapes, each as a dictionary with:
|
||||
- id: Shape identifier (TradingView's assigned ID)
|
||||
- original_id: The ID you specified when creating the shape (if applicable)
|
||||
- type: Shape type
|
||||
- points: List of control points with time and price
|
||||
- color, line_width, line_style: Visual properties
|
||||
- properties: Additional shape-specific properties
|
||||
- symbol: Symbol the shape is drawn on
|
||||
- created_at, modified_at: Timestamps
|
||||
|
||||
Examples:
|
||||
# Find all shapes in the currently visible chart range
|
||||
shapes = search_shapes(
|
||||
start_time=chart_state.start_time,
|
||||
end_time=chart_state.end_time
|
||||
)
|
||||
|
||||
# Find only trendlines in a specific time range
|
||||
trendlines = search_shapes(
|
||||
start_time=1640000000,
|
||||
end_time=1650000000,
|
||||
shape_type='trend_line'
|
||||
)
|
||||
|
||||
# Find shapes for a specific symbol
|
||||
btc_shapes = search_shapes(
|
||||
start_time=1640000000,
|
||||
end_time=1650000000,
|
||||
symbol='BINANCE:BTC/USDT'
|
||||
)
|
||||
|
||||
# Get specific shapes by TradingView ID or original ID
|
||||
# This searches both the 'id' and 'original_id' fields
|
||||
selected = search_shapes(
|
||||
shape_ids=['trendline-1', 'support-42k', 'fib-retracement-1']
|
||||
)
|
||||
|
||||
# Get shapes by the original IDs you specified when creating them
|
||||
my_shapes = search_shapes(
|
||||
original_ids=['my-support-line', 'my-resistance-line']
|
||||
)
|
||||
|
||||
# Get all trendlines (no time filter)
|
||||
all_trendlines = search_shapes(shape_type='trend_line')
|
||||
"""
|
||||
shape_store = _get_shape_store()
|
||||
if not shape_store:
|
||||
raise ValueError("ShapeStore not initialized")
|
||||
|
||||
shapes_dict = shape_store.shapes
|
||||
matching_shapes = []
|
||||
|
||||
# If specific shape IDs are requested, search by both id and original_id
|
||||
if shape_ids:
|
||||
for requested_id in shape_ids:
|
||||
# First try direct ID lookup
|
||||
shape = shapes_dict.get(requested_id)
|
||||
if shape:
|
||||
# Still apply other filters if specified
|
||||
if symbol and shape.get('symbol') != symbol:
|
||||
continue
|
||||
if shape_type and shape.get('type') != shape_type:
|
||||
continue
|
||||
matching_shapes.append(shape)
|
||||
else:
|
||||
# If not found by ID, search by original_id
|
||||
for shape_id, shape in shapes_dict.items():
|
||||
if shape.get('original_id') == requested_id:
|
||||
# Still apply other filters if specified
|
||||
if symbol and shape.get('symbol') != symbol:
|
||||
continue
|
||||
if shape_type and shape.get('type') != shape_type:
|
||||
continue
|
||||
matching_shapes.append(shape)
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Found {len(matching_shapes)} shapes by ID filter (requested {len(shape_ids)} IDs)"
|
||||
+ (f" for type '{shape_type}'" if shape_type else "")
|
||||
+ (f" on symbol '{symbol}'" if symbol else "")
|
||||
)
|
||||
return matching_shapes
|
||||
|
||||
# If specific original IDs are requested, search by original_id only
|
||||
if original_ids:
|
||||
for original_id in original_ids:
|
||||
for shape_id, shape in shapes_dict.items():
|
||||
if shape.get('original_id') == original_id:
|
||||
# Still apply other filters if specified
|
||||
if symbol and shape.get('symbol') != symbol:
|
||||
continue
|
||||
if shape_type and shape.get('type') != shape_type:
|
||||
continue
|
||||
matching_shapes.append(shape)
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Found {len(matching_shapes)} shapes by original_id filter (requested {len(original_ids)} IDs)"
|
||||
+ (f" for type '{shape_type}'" if shape_type else "")
|
||||
+ (f" on symbol '{symbol}'" if symbol else "")
|
||||
)
|
||||
return matching_shapes
|
||||
|
||||
# Otherwise, search all shapes with filters
|
||||
for shape_id, shape in shapes_dict.items():
|
||||
# Filter by symbol if specified
|
||||
if symbol and shape.get('symbol') != symbol:
|
||||
continue
|
||||
|
||||
# Filter by type if specified
|
||||
if shape_type and shape.get('type') != shape_type:
|
||||
continue
|
||||
|
||||
# Filter by time range if specified
|
||||
if start_time is not None and end_time is not None:
|
||||
# Check if any control point falls within the time range
|
||||
# or if the shape spans across the time range
|
||||
points = shape.get('points', [])
|
||||
if not points:
|
||||
continue
|
||||
|
||||
# Get min and max times from shape's control points
|
||||
shape_times = [point['time'] for point in points]
|
||||
shape_min_time = min(shape_times)
|
||||
shape_max_time = max(shape_times)
|
||||
|
||||
# Check for overlap: shape overlaps if its range intersects with query range
|
||||
if not (shape_max_time >= start_time and shape_min_time <= end_time):
|
||||
continue
|
||||
|
||||
matching_shapes.append(shape)
|
||||
|
||||
logger.info(
|
||||
f"Found {len(matching_shapes)} shapes"
|
||||
+ (f" in time range {start_time}-{end_time}" if start_time and end_time else "")
|
||||
+ (f" for type '{shape_type}'" if shape_type else "")
|
||||
+ (f" on symbol '{symbol}'" if symbol else "")
|
||||
)
|
||||
|
||||
return matching_shapes
|
||||
|
||||
|
||||
@tool
|
||||
async def create_or_update_shape(
|
||||
shape_id: str,
|
||||
shape_type: str,
|
||||
points: List[Dict[str, Any]],
|
||||
color: Optional[str] = None,
|
||||
line_width: Optional[int] = None,
|
||||
line_style: Optional[str] = None,
|
||||
properties: Optional[Dict[str, Any]] = None,
|
||||
symbol: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a new shape or update an existing shape on the chart.
|
||||
|
||||
This tool allows the agent to draw shapes on the user's chart or modify
|
||||
existing shapes. Shapes are synchronized to the frontend in real-time.
|
||||
|
||||
IMPORTANT - Shape ID Mapping:
|
||||
When you create a shape, TradingView will assign its own internal ID that differs
|
||||
from the shape_id you provide. The shape will be updated in the store with:
|
||||
- id: TradingView's assigned ID
|
||||
- original_id: The shape_id you provided
|
||||
|
||||
To find your shape later, use search_shapes() and filter by original_id field.
|
||||
|
||||
Example:
|
||||
# Create a shape
|
||||
await create_or_update_shape(shape_id='my-support', ...)
|
||||
|
||||
# Later, find it by original_id
|
||||
shapes = search_shapes(symbol='BINANCE:BTC/USDT')
|
||||
my_shape = next((s for s in shapes if s.get('original_id') == 'my-support'), None)
|
||||
|
||||
Args:
|
||||
shape_id: Unique identifier for the shape (use existing ID to update, new ID to create)
|
||||
Note: TradingView will assign its own ID; your ID will be stored in original_id
|
||||
shape_type: Type of shape using TradingView's native names.
|
||||
|
||||
Single-point shapes (use 1 point):
|
||||
- 'horizontal_line': Horizontal support/resistance line
|
||||
- 'vertical_line': Vertical time marker
|
||||
- 'text': Text label
|
||||
- 'anchored_text': Anchored text annotation
|
||||
- 'anchored_note': Anchored note
|
||||
- 'note': Note annotation
|
||||
- 'emoji': Emoji marker
|
||||
- 'icon': Icon marker
|
||||
- 'sticker': Sticker marker
|
||||
- 'arrow_up': Upward arrow marker
|
||||
- 'arrow_down': Downward arrow marker
|
||||
- 'flag': Flag marker
|
||||
- 'long_position': Long position marker
|
||||
- 'short_position': Short position marker
|
||||
|
||||
Multi-point shapes (use 2+ points):
|
||||
- 'trend_line': Trendline (2 points)
|
||||
- 'rectangle': Rectangle (2 points: top-left, bottom-right)
|
||||
- 'fib_retracement': Fibonacci retracement (2 points)
|
||||
- 'fib_trend_ext': Fibonacci extension (3 points)
|
||||
- 'parallel_channel': Parallel channel (3 points)
|
||||
- 'arrow': Arrow (2 points)
|
||||
- 'circle': Circle/ellipse (2-3 points)
|
||||
- 'path': Free drawing path (3+ points)
|
||||
- 'pitchfork': Andrew's pitchfork (3 points)
|
||||
- 'gannbox_fan': Gann fan (2 points)
|
||||
- 'head_and_shoulders': Head and shoulders pattern (5 points)
|
||||
|
||||
points: List of control points, each with 'time' (Unix seconds) and 'price' fields
|
||||
color: Optional color (hex like '#FF0000' or name like 'red')
|
||||
line_width: Optional line width in pixels (default: 1)
|
||||
line_style: Optional line style: 'solid', 'dashed', 'dotted' (default: 'solid')
|
||||
properties: Optional dict of additional shape-specific properties
|
||||
symbol: Optional symbol to associate with the shape (defaults to current chart symbol)
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: 'created' or 'updated'
|
||||
- shape: The complete shape object (initially with your ID, will be updated to TV ID)
|
||||
|
||||
Examples:
|
||||
# Draw a trendline between two points
|
||||
await create_or_update_shape(
|
||||
shape_id='my-trendline-1',
|
||||
shape_type='trend_line',
|
||||
points=[
|
||||
{'time': 1640000000, 'price': 45000.0},
|
||||
{'time': 1650000000, 'price': 50000.0}
|
||||
],
|
||||
color='#00FF00',
|
||||
line_width=2
|
||||
)
|
||||
|
||||
# Draw a horizontal support line
|
||||
await create_or_update_shape(
|
||||
shape_id='support-1',
|
||||
shape_type='horizontal_line',
|
||||
points=[{'time': 1640000000, 'price': 42000.0}],
|
||||
color='blue',
|
||||
line_style='dashed'
|
||||
)
|
||||
|
||||
# Find your shape after creation using original_id
|
||||
shapes = search_shapes(symbol='BINANCE:BTC/USDT')
|
||||
my_shape = next((s for s in shapes if s.get('original_id') == 'support-1'), None)
|
||||
if my_shape:
|
||||
print(f"TradingView assigned ID: {my_shape['id']}")
|
||||
"""
|
||||
from schema.shape import Shape, ControlPoint
|
||||
import time as time_module
|
||||
|
||||
registry = _get_registry()
|
||||
if not registry:
|
||||
raise ValueError("SyncRegistry not initialized")
|
||||
|
||||
shape_store = _get_shape_store()
|
||||
if not shape_store:
|
||||
raise ValueError("ShapeStore not initialized")
|
||||
|
||||
# Normalize shape type (handle legacy names)
|
||||
normalized_type = SHAPE_TYPE_ALIASES.get(shape_type, shape_type)
|
||||
if normalized_type != shape_type:
|
||||
logger.info(f"Normalized shape type '{shape_type}' -> '{normalized_type}'")
|
||||
|
||||
# Convert points to ControlPoint objects
|
||||
control_points = []
|
||||
for p in points:
|
||||
point_data = {
|
||||
'time': p['time'],
|
||||
'price': p['price']
|
||||
}
|
||||
# Only include channel if it's actually provided
|
||||
if 'channel' in p and p['channel'] is not None:
|
||||
point_data['channel'] = p['channel']
|
||||
control_points.append(ControlPoint(**point_data))
|
||||
|
||||
# Check if updating existing shape
|
||||
existing_shape = shape_store.shapes.get(shape_id)
|
||||
is_update = existing_shape is not None
|
||||
|
||||
# If symbol is not provided, try to get it from ChartStore
|
||||
if symbol is None and "ChartStore" in registry.entries:
|
||||
chart_store = registry.entries["ChartStore"].model
|
||||
if hasattr(chart_store, 'chart_state') and hasattr(chart_store.chart_state, 'symbol'):
|
||||
symbol = chart_store.chart_state.symbol
|
||||
logger.info(f"Using current chart symbol for shape: {symbol}")
|
||||
|
||||
now = int(time_module.time())
|
||||
|
||||
# Create shape object
|
||||
shape = Shape(
|
||||
id=shape_id,
|
||||
type=normalized_type,
|
||||
points=control_points,
|
||||
color=color,
|
||||
line_width=line_width,
|
||||
line_style=line_style,
|
||||
properties=properties or {},
|
||||
symbol=symbol,
|
||||
created_at=existing_shape.get('created_at') if existing_shape else now,
|
||||
modified_at=now
|
||||
)
|
||||
|
||||
# Update the store
|
||||
shape_store.shapes[shape_id] = shape.model_dump(mode="json")
|
||||
|
||||
# Trigger sync
|
||||
await registry.push_all()
|
||||
|
||||
logger.info(
|
||||
f"{'Updated' if is_update else 'Created'} shape '{shape_id}' "
|
||||
f"of type '{shape_type}' with {len(points)} points"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "updated" if is_update else "created",
|
||||
"shape": shape.model_dump(mode="json")
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_shape(shape_id: str) -> Dict[str, str]:
|
||||
"""Delete a shape from the chart.
|
||||
|
||||
Args:
|
||||
shape_id: ID of the shape to delete
|
||||
|
||||
Returns:
|
||||
Dictionary with status message
|
||||
|
||||
Raises:
|
||||
ValueError: If shape doesn't exist
|
||||
|
||||
Example:
|
||||
await delete_shape('my-trendline-1')
|
||||
"""
|
||||
registry = _get_registry()
|
||||
if not registry:
|
||||
raise ValueError("SyncRegistry not initialized")
|
||||
|
||||
shape_store = _get_shape_store()
|
||||
if not shape_store:
|
||||
raise ValueError("ShapeStore not initialized")
|
||||
|
||||
if shape_id not in shape_store.shapes:
|
||||
raise ValueError(f"Shape '{shape_id}' not found")
|
||||
|
||||
# Delete the shape
|
||||
del shape_store.shapes[shape_id]
|
||||
|
||||
# Trigger sync
|
||||
await registry.push_all()
|
||||
|
||||
logger.info(f"Deleted shape '{shape_id}'")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Shape '{shape_id}' deleted"
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
def get_shape(shape_id: str) -> Dict[str, Any]:
|
||||
"""Get details of a specific shape by ID.
|
||||
|
||||
Args:
|
||||
shape_id: ID of the shape to retrieve
|
||||
|
||||
Returns:
|
||||
Dictionary containing the shape data
|
||||
|
||||
Raises:
|
||||
ValueError: If shape doesn't exist
|
||||
|
||||
Example:
|
||||
shape = get_shape('my-trendline-1')
|
||||
print(f"Shape type: {shape['type']}")
|
||||
print(f"Points: {shape['points']}")
|
||||
"""
|
||||
shape_store = _get_shape_store()
|
||||
if not shape_store:
|
||||
raise ValueError("ShapeStore not initialized")
|
||||
|
||||
shape = shape_store.shapes.get(shape_id)
|
||||
if not shape:
|
||||
raise ValueError(f"Shape '{shape_id}' not found")
|
||||
|
||||
return shape
|
||||
|
||||
|
||||
@tool
|
||||
def list_all_shapes() -> List[Dict[str, Any]]:
|
||||
"""List all shapes currently on the chart.
|
||||
|
||||
Returns:
|
||||
List of all shapes as dictionaries
|
||||
|
||||
Example:
|
||||
shapes = list_all_shapes()
|
||||
print(f"Total shapes: {len(shapes)}")
|
||||
for shape in shapes:
|
||||
print(f" - {shape['id']}: {shape['type']}")
|
||||
"""
|
||||
shape_store = _get_shape_store()
|
||||
if not shape_store:
|
||||
raise ValueError("ShapeStore not initialized")
|
||||
|
||||
return list(shape_store.shapes.values())
|
||||
|
||||
|
||||
SHAPE_TOOLS = [
|
||||
search_shapes,
|
||||
create_or_update_shape,
|
||||
delete_shape,
|
||||
get_shape,
|
||||
list_all_shapes
|
||||
]
|
||||
@@ -23,6 +23,7 @@ from agent.core import create_agent
|
||||
from agent.tools import set_registry, set_datasource_registry, set_indicator_registry
|
||||
from schema.order_spec import SwapOrder
|
||||
from schema.chart_state import ChartState
|
||||
from schema.shape import ShapeCollection
|
||||
from datasource.registry import DataSourceRegistry
|
||||
from datasource.subscription_manager import SubscriptionManager
|
||||
from datasource.websocket_handler import DatafeedWebSocketHandler
|
||||
@@ -124,7 +125,7 @@ async def lifespan(app: FastAPI):
|
||||
chroma_db_path=config["memory"]["chroma_db"],
|
||||
embedding_model=config["memory"]["embedding_model"],
|
||||
context_docs_dir=config["agent"]["context_docs_dir"],
|
||||
base_dir=".." # Point to project root from backend/src
|
||||
base_dir="." # backend/src is the working directory, so . goes to backend, where memory/ lives
|
||||
)
|
||||
|
||||
await agent_executor.initialize()
|
||||
@@ -159,13 +160,19 @@ class OrderStore(BaseModel):
|
||||
class ChartStore(BaseModel):
|
||||
chart_state: ChartState = ChartState()
|
||||
|
||||
# ShapeStore model for synchronization
|
||||
class ShapeStore(BaseModel):
|
||||
shapes: dict[str, dict] = {} # Dictionary of shapes keyed by ID
|
||||
|
||||
# Initialize stores
|
||||
order_store = OrderStore()
|
||||
chart_store = ChartStore()
|
||||
shape_store = ShapeStore()
|
||||
|
||||
# Register with SyncRegistry
|
||||
registry.register(order_store, store_name="OrderStore")
|
||||
registry.register(chart_store, store_name="ChartStore")
|
||||
registry.register(shape_store, store_name="ShapeStore")
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
@@ -361,11 +368,14 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
elif msg_type == "patch":
|
||||
patch_msg = PatchMessage(**message_json)
|
||||
logger.info(f"Patch message received for store: {patch_msg.store}, seq: {patch_msg.seq}")
|
||||
await registry.apply_client_patch(
|
||||
store_name=patch_msg.store,
|
||||
client_base_seq=patch_msg.seq,
|
||||
patch=patch_msg.patch
|
||||
)
|
||||
try:
|
||||
await registry.apply_client_patch(
|
||||
store_name=patch_msg.store,
|
||||
client_base_seq=patch_msg.seq,
|
||||
patch=patch_msg.patch
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error applying client patch: {e}. Client will receive snapshot to resync.", exc_info=True)
|
||||
elif msg_type == "agent_user_message":
|
||||
# Handle agent messages directly here
|
||||
print(f"[DEBUG] Raw message_json: {message_json}")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -23,3 +23,6 @@ class ChartState(BaseModel):
|
||||
# Optional: Chart interval/resolution
|
||||
# None when chart is not visible
|
||||
interval: Optional[str] = Field(default="15", description="Chart interval (e.g., '1', '5', '15', '60', 'D'), or None if no chart visible")
|
||||
|
||||
# Selected shapes/drawings on the chart
|
||||
selected_shapes: List[str] = Field(default_factory=list, description="Array of selected shape IDs")
|
||||
|
||||
44
backend/src/schema/shape.py
Normal file
44
backend/src/schema/shape.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ControlPoint(BaseModel):
|
||||
"""A control point for a drawing shape.
|
||||
|
||||
Control points define the position and properties of a shape.
|
||||
Different shapes have different numbers of control points.
|
||||
"""
|
||||
time: int = Field(..., description="Unix timestamp in seconds")
|
||||
price: float = Field(..., description="Price level")
|
||||
# Optional channel for multi-point shapes (e.g., parallel channels)
|
||||
channel: Optional[str] = Field(default=None, description="Channel identifier for multi-point shapes")
|
||||
|
||||
|
||||
class Shape(BaseModel):
|
||||
"""A TradingView drawing shape/study.
|
||||
|
||||
Represents any drawing the user creates on the chart (trendlines,
|
||||
horizontal lines, rectangles, Fibonacci retracements, etc.)
|
||||
"""
|
||||
id: str = Field(..., description="Unique identifier for the shape")
|
||||
type: str = Field(..., description="Shape type (e.g., 'trendline', 'horizontal_line', 'rectangle', 'fibonacci')")
|
||||
points: List[ControlPoint] = Field(default_factory=list, description="Control points that define the shape")
|
||||
|
||||
# Visual properties
|
||||
color: Optional[str] = Field(default=None, description="Shape color (hex or color name)")
|
||||
line_width: Optional[int] = Field(default=1, description="Line width in pixels")
|
||||
line_style: Optional[str] = Field(default="solid", description="Line style: 'solid', 'dashed', 'dotted'")
|
||||
|
||||
# Shape-specific properties stored as flexible dict
|
||||
properties: Dict[str, Any] = Field(default_factory=dict, description="Additional shape-specific properties")
|
||||
|
||||
# Metadata
|
||||
symbol: Optional[str] = Field(default=None, description="Symbol this shape is drawn on")
|
||||
created_at: Optional[int] = Field(default=None, description="Creation timestamp (Unix seconds)")
|
||||
modified_at: Optional[int] = Field(default=None, description="Last modification timestamp (Unix seconds)")
|
||||
original_id: Optional[str] = Field(default=None, description="Original ID from backend/agent before TradingView assigns its own ID")
|
||||
|
||||
|
||||
class ShapeCollection(BaseModel):
|
||||
"""Collection of all shapes/drawings on the chart."""
|
||||
shapes: Dict[str, Shape] = Field(default_factory=dict, description="Dictionary of shapes keyed by ID")
|
||||
@@ -105,67 +105,105 @@ class SyncRegistry:
|
||||
|
||||
logger.info(f"apply_client_patch: Current backend seq={entry.seq}")
|
||||
|
||||
if client_base_seq == entry.seq:
|
||||
# No conflict
|
||||
logger.info("apply_client_patch: No conflict - applying patch directly")
|
||||
current_state = entry.model.model_dump(mode="json")
|
||||
logger.info(f"apply_client_patch: Current state before patch: {current_state}")
|
||||
new_state = jsonpatch.apply_patch(current_state, patch)
|
||||
logger.info(f"apply_client_patch: New state after patch: {new_state}")
|
||||
self._update_model(entry.model, new_state)
|
||||
try:
|
||||
if client_base_seq == entry.seq:
|
||||
# No conflict
|
||||
logger.info("apply_client_patch: No conflict - applying patch directly")
|
||||
current_state = entry.model.model_dump(mode="json")
|
||||
logger.info(f"apply_client_patch: Current state before patch: {current_state}")
|
||||
try:
|
||||
new_state = jsonpatch.apply_patch(current_state, patch)
|
||||
logger.info(f"apply_client_patch: New state after patch: {new_state}")
|
||||
self._update_model(entry.model, new_state)
|
||||
|
||||
entry.commit_patch(patch)
|
||||
logger.info(f"apply_client_patch: Patch committed, new seq={entry.seq}")
|
||||
# Don't broadcast back to client - they already have this change
|
||||
# Broadcasting would cause an infinite loop
|
||||
logger.info("apply_client_patch: Not broadcasting back to originating client")
|
||||
|
||||
elif client_base_seq < entry.seq:
|
||||
# Conflict! Frontend wins.
|
||||
# 1. Get backend patches since client_base_seq
|
||||
backend_patches = []
|
||||
for seq, p in entry.history:
|
||||
if seq > client_base_seq:
|
||||
backend_patches.append(p)
|
||||
|
||||
# 2. Apply frontend patch first to the state at client_base_seq
|
||||
# But we only have the current authoritative model.
|
||||
# "Apply the frontend patch first to the model (frontend wins)"
|
||||
# "Re-apply the backend deltas that do not overlap the frontend's changed paths on top"
|
||||
|
||||
# Let's get the state as it was at client_base_seq if possible?
|
||||
# No, history only has patches.
|
||||
|
||||
# Alternative: Apply frontend patch to current model.
|
||||
# Then re-apply backend patches, but discard parts that overlap.
|
||||
|
||||
frontend_paths = {p['path'] for p in patch}
|
||||
|
||||
current_state = entry.model.model_dump(mode="json")
|
||||
# Apply frontend patch
|
||||
new_state = jsonpatch.apply_patch(current_state, patch)
|
||||
|
||||
# Re-apply backend patches that don't overlap
|
||||
for b_patch in backend_patches:
|
||||
filtered_b_patch = [op for op in b_patch if op['path'] not in frontend_paths]
|
||||
if filtered_b_patch:
|
||||
new_state = jsonpatch.apply_patch(new_state, filtered_b_patch)
|
||||
|
||||
self._update_model(entry.model, new_state)
|
||||
|
||||
# Commit the result as a single new patch
|
||||
# We need to compute what changed from last_snapshot to new_state
|
||||
final_patch = jsonpatch.make_patch(entry.last_snapshot, new_state).patch
|
||||
if final_patch:
|
||||
entry.commit_patch(final_patch)
|
||||
# Broadcast resolved state as snapshot to converge
|
||||
if self.websocket:
|
||||
msg = SnapshotMessage(
|
||||
store=entry.store_name,
|
||||
seq=entry.seq,
|
||||
state=entry.model.model_dump(mode="json")
|
||||
)
|
||||
await self.websocket.send_json(msg.model_dump(mode="json"))
|
||||
entry.commit_patch(patch)
|
||||
logger.info(f"apply_client_patch: Patch committed, new seq={entry.seq}")
|
||||
# Don't broadcast back to client - they already have this change
|
||||
# Broadcasting would cause an infinite loop
|
||||
logger.info("apply_client_patch: Not broadcasting back to originating client")
|
||||
except jsonpatch.JsonPatchConflict as e:
|
||||
logger.warning(f"apply_client_patch: Patch conflict on no-conflict path: {e}. Sending snapshot to resync.")
|
||||
# Send snapshot to force resync
|
||||
if self.websocket:
|
||||
msg = SnapshotMessage(
|
||||
store=entry.store_name,
|
||||
seq=entry.seq,
|
||||
state=entry.model.model_dump(mode="json")
|
||||
)
|
||||
await self.websocket.send_json(msg.model_dump(mode="json"))
|
||||
|
||||
elif client_base_seq < entry.seq:
|
||||
# Conflict! Frontend wins.
|
||||
# 1. Get backend patches since client_base_seq
|
||||
backend_patches = []
|
||||
for seq, p in entry.history:
|
||||
if seq > client_base_seq:
|
||||
backend_patches.append(p)
|
||||
|
||||
# 2. Apply frontend patch first to the state at client_base_seq
|
||||
# But we only have the current authoritative model.
|
||||
# "Apply the frontend patch first to the model (frontend wins)"
|
||||
# "Re-apply the backend deltas that do not overlap the frontend's changed paths on top"
|
||||
|
||||
# Let's get the state as it was at client_base_seq if possible?
|
||||
# No, history only has patches.
|
||||
|
||||
# Alternative: Apply frontend patch to current model.
|
||||
# Then re-apply backend patches, but discard parts that overlap.
|
||||
|
||||
frontend_paths = {p['path'] for p in patch}
|
||||
|
||||
current_state = entry.model.model_dump(mode="json")
|
||||
# Apply frontend patch
|
||||
try:
|
||||
new_state = jsonpatch.apply_patch(current_state, patch)
|
||||
except jsonpatch.JsonPatchConflict as e:
|
||||
logger.warning(f"apply_client_patch: Failed to apply client patch during conflict resolution: {e}. Sending snapshot to resync.")
|
||||
# Send snapshot to force resync
|
||||
if self.websocket:
|
||||
msg = SnapshotMessage(
|
||||
store=entry.store_name,
|
||||
seq=entry.seq,
|
||||
state=entry.model.model_dump(mode="json")
|
||||
)
|
||||
await self.websocket.send_json(msg.model_dump(mode="json"))
|
||||
return
|
||||
|
||||
# Re-apply backend patches that don't overlap
|
||||
for b_patch in backend_patches:
|
||||
filtered_b_patch = [op for op in b_patch if op['path'] not in frontend_paths]
|
||||
if filtered_b_patch:
|
||||
try:
|
||||
new_state = jsonpatch.apply_patch(new_state, filtered_b_patch)
|
||||
except jsonpatch.JsonPatchConflict as e:
|
||||
logger.warning(f"apply_client_patch: Failed to apply backend patch during conflict resolution: {e}. Skipping this patch.")
|
||||
continue
|
||||
|
||||
self._update_model(entry.model, new_state)
|
||||
|
||||
# Commit the result as a single new patch
|
||||
# We need to compute what changed from last_snapshot to new_state
|
||||
final_patch = jsonpatch.make_patch(entry.last_snapshot, new_state).patch
|
||||
if final_patch:
|
||||
entry.commit_patch(final_patch)
|
||||
# Broadcast resolved state as snapshot to converge
|
||||
if self.websocket:
|
||||
msg = SnapshotMessage(
|
||||
store=entry.store_name,
|
||||
seq=entry.seq,
|
||||
state=entry.model.model_dump(mode="json")
|
||||
)
|
||||
await self.websocket.send_json(msg.model_dump(mode="json"))
|
||||
except Exception as e:
|
||||
logger.error(f"apply_client_patch: Unexpected error: {e}. Sending snapshot to resync.", exc_info=True)
|
||||
# Send snapshot to force resync
|
||||
if self.websocket:
|
||||
msg = SnapshotMessage(
|
||||
store=entry.store_name,
|
||||
seq=entry.seq,
|
||||
state=entry.model.model_dump(mode="json")
|
||||
)
|
||||
await self.websocket.send_json(msg.model_dump(mode="json"))
|
||||
|
||||
def _update_model(self, model: BaseModel, new_data: Dict[str, Any]):
|
||||
# Update model using model_validate for potentially nested models
|
||||
|
||||
Reference in New Issue
Block a user