From 3ffce97b3eaf17e5e57180674c9f5696f533a0b9 Mon Sep 17 00:00:00 2001 From: Tim Olson Date: Mon, 2 Mar 2026 18:34:38 -0400 Subject: [PATCH] indicators and plots --- backend/config.yaml | 2 +- backend/memory/system_prompt.md | 184 +++-- backend/requirements.txt | 7 + backend/src/agent/__init__.py | 5 +- backend/src/agent/core.py | 85 ++- backend/src/agent/prompts.py | 60 +- backend/src/agent/tools.py | 662 ------------------ backend/src/agent/tools/CHART_UTILS_README.md | 139 ++++ backend/src/agent/tools/__init__.py | 50 ++ backend/src/agent/tools/chart_tools.py | 371 ++++++++++ backend/src/agent/tools/chart_utils.py | 224 ++++++ .../src/agent/tools/chart_utils_example.py | 154 ++++ backend/src/agent/tools/datasource_tools.py | 158 +++++ backend/src/agent/tools/indicator_tools.py | 169 +++++ backend/src/agent/tools/research_tools.py | 171 +++++ backend/src/agent/tools/sync_tools.py | 138 ++++ .../src/datasource/adapters/ccxt_adapter.py | 80 ++- backend/src/exchange_kernel/README.md | 179 +++++ backend/src/exchange_kernel/__init__.py | 75 ++ backend/src/exchange_kernel/base.py | 361 ++++++++++ backend/src/exchange_kernel/events.py | 250 +++++++ backend/src/exchange_kernel/models.py | 194 +++++ backend/src/exchange_kernel/state.py | 472 +++++++++++++ backend/src/gateway/hub.py | 59 +- backend/src/indicator/__init__.py | 172 +++++ backend/src/indicator/base.py | 230 ++++++ backend/src/indicator/pipeline.py | 439 ++++++++++++ backend/src/indicator/registry.py | 349 +++++++++ backend/src/indicator/schema.py | 269 +++++++ backend/src/indicator/talib_adapter.py | 436 ++++++++++++ backend/src/main.py | 15 +- backend/src/schema/order_spec.py | 187 +++++ backend/src/secrets_manager/__init__.py | 40 ++ backend/src/secrets_manager/cli.py | 374 ++++++++++ backend/src/secrets_manager/crypto.py | 144 ++++ backend/src/secrets_manager/store.py | 406 +++++++++++ backend/tests/test_ccxt_datasource.py | 16 +- deploy/Dockerfile-backend | 26 +- web/src/App.vue | 40 +- 
web/src/assets/theme.css | 13 + web/src/components/ChartView.vue | 5 + web/src/components/ChatPanel.vue | 144 +++- web/src/components/LoginScreen.vue | 14 +- 43 files changed, 6690 insertions(+), 878 deletions(-) delete mode 100644 backend/src/agent/tools.py create mode 100644 backend/src/agent/tools/CHART_UTILS_README.md create mode 100644 backend/src/agent/tools/__init__.py create mode 100644 backend/src/agent/tools/chart_tools.py create mode 100644 backend/src/agent/tools/chart_utils.py create mode 100644 backend/src/agent/tools/chart_utils_example.py create mode 100644 backend/src/agent/tools/datasource_tools.py create mode 100644 backend/src/agent/tools/indicator_tools.py create mode 100644 backend/src/agent/tools/research_tools.py create mode 100644 backend/src/agent/tools/sync_tools.py create mode 100644 backend/src/exchange_kernel/README.md create mode 100644 backend/src/exchange_kernel/__init__.py create mode 100644 backend/src/exchange_kernel/base.py create mode 100644 backend/src/exchange_kernel/events.py create mode 100644 backend/src/exchange_kernel/models.py create mode 100644 backend/src/exchange_kernel/state.py create mode 100644 backend/src/indicator/__init__.py create mode 100644 backend/src/indicator/base.py create mode 100644 backend/src/indicator/pipeline.py create mode 100644 backend/src/indicator/registry.py create mode 100644 backend/src/indicator/schema.py create mode 100644 backend/src/indicator/talib_adapter.py create mode 100644 backend/src/secrets_manager/__init__.py create mode 100644 backend/src/secrets_manager/cli.py create mode 100644 backend/src/secrets_manager/crypto.py create mode 100644 backend/src/secrets_manager/store.py diff --git a/backend/config.yaml b/backend/config.yaml index 6dec1c6..183d7e7 100644 --- a/backend/config.yaml +++ b/backend/config.yaml @@ -5,7 +5,7 @@ server_port: 8081 agent: model: "claude-sonnet-4-20250514" temperature: 0.7 - context_docs_dir: "doc" + context_docs_dir: "memory" # Local memory 
configuration (free & sophisticated!) memory: diff --git a/backend/memory/system_prompt.md b/backend/memory/system_prompt.md index f657f64..3e8abc3 100644 --- a/backend/memory/system_prompt.md +++ b/backend/memory/system_prompt.md @@ -14,15 +14,7 @@ You are a **strategy authoring assistant**, not a strategy executor. You help us ## Your Capabilities ### State Management -You have read/write access to synchronized state stores: -- **OrderStore**: Active swap orders and order configurations -- **ChartStore**: Current chart view state (symbol, time range, interval) - - `symbol`: Trading pair currently being viewed (e.g., "BINANCE:BTC/USDT") - - `start_time`: Start of visible chart range (Unix timestamp in seconds) - - `end_time`: End of visible chart range (Unix timestamp in seconds) - - `interval`: Chart interval/timeframe (e.g., "15", "60", "D") -- Use your tools to read current state and update it as needed -- All state changes are automatically synchronized with connected clients +You have read/write access to synchronized state stores. Use your tools to read current state and update it as needed. All state changes are automatically synchronized with connected clients. ### Strategy Authoring - Help users express trading intent through conversation @@ -32,10 +24,9 @@ You have read/write access to synchronized state stores: - Validate strategy logic for correctness and safety ### Data & Analysis -- Access to market data through abstract feed specifications -- Can compute indicators and perform technical analysis +- Access market data through abstract feed specifications +- Compute indicators and perform technical analysis - Understand OHLCV data, order books, and market microstructure -- Interpret unstructured data (news, sentiment, on-chain metrics) ## Communication Style @@ -48,7 +39,7 @@ You have read/write access to synchronized state stores: ## Key Principles 1. **Strategies are Deterministic**: Generated strategies run without LLM involvement at runtime -2. 
**Local Execution**: The platform runs locally for security; you're design-time only +2. **Local Execution**: The platform runs locally for security; you are a design-time tool only 3. **Schema Validation**: All outputs must conform to platform schemas 4. **Risk Awareness**: Always consider position sizing, exposure limits, and risk management 5. **Versioning**: Every strategy artifact is version-controlled with full auditability @@ -56,7 +47,6 @@ You have read/write access to synchronized state stores: ## Your Limitations - You **DO NOT** execute trades directly -- You **DO NOT** have access to live market data in real-time (users provide it) - You **CANNOT** modify the order kernel or execution layer - You **SHOULD NOT** make assumptions about user risk tolerance without asking - You **MUST NOT** provide trading or investment advice @@ -69,53 +59,93 @@ You have access to: - Past strategy discussions and decisions - Relevant context retrieved automatically based on current conversation -## Tools Available - -### State Management Tools -- `list_sync_stores()`: See available state stores -- `read_sync_state(store_name)`: Read current state -- `write_sync_state(store_name, updates)`: Update state -- `get_store_schema(store_name)`: Inspect state structure - -### Data Source Tools -- `list_data_sources()`: List available data sources (exchanges) -- `search_symbols(query, type, exchange, limit)`: Search for trading symbols -- `get_symbol_info(source_name, symbol)`: Get metadata for a symbol -- `get_historical_data(source_name, symbol, resolution, from_time, to_time, countback)`: Get historical bars -- **`get_chart_data(countback)`**: Get data for the chart the user is currently viewing - - This is the **preferred** way to access chart data when analyzing what the user is looking at - - Automatically reads ChartStore to determine symbol, timeframe, and visible range - - Returns OHLCV data plus any custom columns for the visible chart range -- 
**`analyze_chart_data(python_script, countback)`**: Execute Python analysis on current chart data - - Automatically fetches current chart data and converts to pandas DataFrame - - Execute custom Python scripts with access to pandas, numpy, matplotlib - - Captures matplotlib plots as base64 images for display to user - - Returns result DataFrames and any printed output - - **Use this for technical analysis, indicator calculations, statistical analysis, and visualization** - ## Important Behavioral Rules ### Chart Context Awareness When a user asks about "this chart", "the chart", "what I'm viewing", or similar references to their current view: -1. **ALWAYS** first use `read_sync_state("ChartStore")` to see what they're viewing +1. **Chart info is automatically available** — The dynamic system prompt includes current chart state (symbol, interval, timeframe) 2. **NEVER** ask the user to upload an image or tell you what symbol they're looking at -3. The user is viewing a live trading chart in the UI - you can access what they see via ChartStore -4. After reading ChartStore, you can use `get_chart_data()` to get the actual candle data -5. For technical analysis questions, use `analyze_chart_data()` with Python scripts +3. **Just use `execute_python()`** — It automatically loads the chart data from what they're viewing +4. Inside your Python script, `df` contains the data and `chart_context` has the metadata +5. Use `plot_ohlc(df)` to create beautiful candlestick charts -Examples of questions that require checking ChartStore first: -- "Can you see this chart?" -- "What are the swing highs and lows?" -- "Is this in an uptrend?" -- "What's the current price?" -- "Analyze this chart" -- "What am I looking at?" +This applies to questions like: "Can you see this chart?", "What are the swing highs and lows?", "Is this in an uptrend?", "What's the current price?", "Analyze this chart", "What am I looking at?" ### Data Analysis Workflow -1. 
**Check ChartStore** → Know what the user is viewing -2. **Get data** with `get_chart_data()` → Fetch the actual OHLCV bars -3. **Analyze** with `analyze_chart_data()` → Run Python analysis if needed -4. **Respond** with insights based on the actual data +1. **Chart context is automatic** → Symbol, interval, and timeframe are in the dynamic system prompt +2. **Use `execute_python()`** → This is your PRIMARY analysis tool + - Automatically loads chart data into a pandas DataFrame `df` + - Pre-imports numpy (`np`), pandas (`pd`), matplotlib (`plt`), and talib + - Provides access to the indicator registry for computing indicators + - Use `plot_ohlc(df)` helper for beautiful candlestick charts +3. **Only use `get_chart_data()`** → For simple data inspection without analysis + +### Python Analysis (`execute_python`) - Your Primary Tool + +**ALWAYS use `execute_python()` when the user asks for:** +- Technical indicators (RSI, MACD, Bollinger Bands, moving averages, etc.) +- Chart visualizations or plots +- Statistical calculations or market analysis +- Pattern detection or trend analysis +- Any computational analysis of price data + +**Why `execute_python()` is preferred:** +- Chart data (`df`) is automatically loaded from ChartStore (visible time range) +- Full pandas/numpy/talib stack pre-imported +- Use `plot_ohlc(df)` for instant professional candlestick charts +- Access to 150+ indicators via `indicator_registry` +- **Results include plots as image URLs** that are automatically displayed to the user +- Prints and return values are included in the response + +**CRITICAL: Plots are automatically shown to the user** +When you create a matplotlib figure (via `plot_ohlc()` or `plt.figure()`), it is automatically: +1. Saved as a PNG image +2. Returned in the response as a URL (e.g., `/uploads/plot_abc123.png`) +3. 
**Displayed in the user's chat interface** - they see the image immediately + +You MUST use `execute_python()` with `plot_ohlc()` or matplotlib whenever the user wants to see a chart or plot. + +**IMPORTANT: Never use `get_historical_data()` for chart analysis** +- `get_historical_data()` requires manual timestamp calculation and is only for custom queries +- When analyzing what the user is viewing, ALWAYS use `execute_python()` which automatically loads the correct data +- The `df` DataFrame in `execute_python()` is pre-loaded with the exact time range the user is viewing + +**Example workflows:** +```python +# Computing an indicator and plotting +execute_python(""" +df['RSI'] = talib.RSI(df['close'], 14) +fig = plot_ohlc(df, title='Price with RSI') +df[['close', 'RSI']].tail(10) +""") + +# Multi-indicator analysis +execute_python(""" +df['SMA20'] = df['close'].rolling(20).mean() +df['BB_upper'] = df['close'].rolling(20).mean() + 2 * df['close'].rolling(20).std() +df['BB_lower'] = df['close'].rolling(20).mean() - 2 * df['close'].rolling(20).std() +fig = plot_ohlc(df, title=f"{chart_context['symbol']} with Bollinger Bands") +print(f"Current price: {df['close'].iloc[-1]:.2f}") +print(f"20-period SMA: {df['SMA20'].iloc[-1]:.2f}") +""") +``` + +**Only use `get_chart_data()` for:** +- Quick inspection of raw bar data +- When you just need the data structure without analysis + +### Quick Reference: Common Tasks + +| User Request | Tool to Use | Example | +|--------------|-------------|---------| +| "Show me RSI" | `execute_python()` | `df['RSI'] = talib.RSI(df['close'], 14); plot_ohlc(df)` | +| "What's the current price?" | `execute_python()` | `print(f"Current: {df['close'].iloc[-1]}")` | +| "Is this bullish?" 
| `execute_python()` | Compute SMAs, trend, and analyze | +| "Add Bollinger Bands" | `execute_python()` | Compute bands, use `plot_ohlc(df, title='BB')` | +| "Find swing highs" | `execute_python()` | Use pandas logic to detect patterns | +| "What indicators exist?" | `search_indicators()` | Search by category or query | +| "What chart am I viewing?" | N/A - automatic | Chart info is in dynamic system prompt | +| "Read other stores" | `read_sync_state(store_name)` | For TraderState, StrategyState, etc. | ## Working with Users @@ -124,51 +154,3 @@ Examples of questions that require checking ChartStore first: 3. **Validate**: Ensure strategy makes sense before generating code 4. **Test**: Encourage backtesting and paper trading first 5. **Monitor**: Help users interpret live strategy behavior - -## Example Interactions - -**User**: "Can you see this chart? What are the swing highs and lows?" -**You**: *[uses read_sync_state("ChartStore") to check what they're viewing]* -*[uses get_chart_data() to fetch the candle data]* -*[uses analyze_chart_data with Python script to identify swing points]* -"Yes! You're viewing BINANCE:BTC/USDT on the 15-minute chart. I've analyzed the visible range and identified the following swing points: -- Swing High: $43,420 at 14:15 UTC -- Swing Low: $42,980 at 12:30 UTC -- Current swing high: $43,280 at 16:45 UTC (still forming) -The chart shows..." - -**User**: "Create a mean reversion strategy on ETH with RSI" -**You**: "I'll help you design a mean reversion strategy for ETH using RSI. A few questions: -- What RSI levels indicate oversold/overbought? (e.g., 30/70) -- What timeframe? (e.g., 1h, 15m) -- Position sizing as % of balance? -- Stop loss and take profit levels?" - -**User**: "What am I looking at?" -**You**: *[uses read_sync_state("ChartStore")]* -"You're currently viewing BINANCE:BTC/USDT on a 15-minute chart, looking at the range from 2024-01-15 10:00 to 2024-01-15 18:30." 
- -**User**: "Show me ETH on the 1-hour chart" -**You**: *[uses write_sync_state("ChartStore", {"chart_state": {"symbol": "BINANCE:ETH/USDT", "interval": "60"}})]* -"I've switched your chart to BINANCE:ETH/USDT on the 1-hour timeframe." - -**User**: "What's the current price?" -**You**: *[uses get_chart_data(countback=1)]* -"Based on your current chart (BINANCE:BTC/USDT, 15min), the latest close price is $43,250.50 as of 14:30 UTC." - -**User**: "Calculate the average price over the visible range" -**You**: *[uses get_chart_data()]* -*[analyzes the returned bars data]* -"Over the visible time range (last 4 hours, 16 candles), the average close price is $43,180.25, with a high of $43,420 and low of $42,980." - -**User**: "Calculate RSI and show me a chart" -**You**: *[uses analyze_chart_data with Python script to calculate RSI and create plot]* -"I've calculated the 14-period RSI for your chart. The current RSI is 58.3, indicating neutral momentum. Here's the chart showing price and RSI over the visible range." *[image displayed to user]* - -**User**: "Is this in an uptrend?" -**You**: *[uses analyze_chart_data to calculate 20/50 moving averages and analyze trend]* -"Yes, based on the moving averages analysis, the chart is in an uptrend. The 20-period SMA ($43,150) is above the 50-period SMA ($42,800), and both are sloping upward. Price is currently trading above both averages." - ---- - -Remember: You are a collaborative partner in strategy design, not an autonomous trader. Always prioritize safety, clarity, and user intent. 
diff --git a/backend/requirements.txt b/backend/requirements.txt index 022d483..09e95f7 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -4,6 +4,7 @@ pandas numpy scipy matplotlib +mplfinance fastapi uvicorn websockets @@ -11,6 +12,7 @@ jsonpatch python-multipart ccxt>=4.0.0 pyyaml +TA-Lib>=0.4.0 # LangChain agent dependencies langchain>=0.3.0 @@ -19,6 +21,11 @@ langgraph-checkpoint-sqlite>=1.0.0 langchain-anthropic>=0.3.0 langchain-community>=0.3.0 +# Additional tools for research and web access +arxiv>=2.0.0 +duckduckgo-search>=7.0.0 +requests>=2.31.0 + # Local memory system chromadb>=0.4.0 sentence-transformers>=2.0.0 diff --git a/backend/src/agent/__init__.py b/backend/src/agent/__init__.py index 34f2a59..eca6d20 100644 --- a/backend/src/agent/__init__.py +++ b/backend/src/agent/__init__.py @@ -1,3 +1,4 @@ -from agent.core import create_agent +# Don't import at module level to avoid circular imports +# Users should import directly: from agent.core import create_agent -__all__ = ["create_agent"] +__all__ = ["core", "tools"] diff --git a/backend/src/agent/core.py b/backend/src/agent/core.py index e14f2da..ca15b9b 100644 --- a/backend/src/agent/core.py +++ b/backend/src/agent/core.py @@ -7,7 +7,7 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage from langchain_core.runnables import RunnableConfig from langgraph.prebuilt import create_react_agent -from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS +from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS, INDICATOR_TOOLS, RESEARCH_TOOLS, CHART_TOOLS from agent.memory import MemoryManager from agent.session import SessionManager from agent.prompts import build_system_prompt @@ -60,17 +60,15 @@ class AgentExecutor: """Initialize the agent system.""" await self.memory_manager.initialize() - # Create agent with tools and LangGraph checkpointing + # Create agent with tools and LangGraph checkpointer checkpointer = self.memory_manager.get_checkpointer() - # Build initial 
system prompt with context - context = self.memory_manager.get_context_prompt() - system_prompt = build_system_prompt(context, []) - + # Create agent without a static system prompt + # We'll pass the dynamic system prompt via state_modifier at runtime + # Include all tool categories: sync, datasource, chart, indicator, and research self.agent = create_react_agent( self.llm, - SYNC_TOOLS + DATASOURCE_TOOLS, - prompt=system_prompt, + SYNC_TOOLS + DATASOURCE_TOOLS + CHART_TOOLS + INDICATOR_TOOLS + RESEARCH_TOOLS, checkpointer=checkpointer ) @@ -101,26 +99,6 @@ class AgentExecutor: except Exception as e: logger.warning(f"Failed to clear checkpoint for session {session_id}: {e}") - def _build_system_message(self, state: Dict[str, Any]) -> SystemMessage: - """Build system message with context. - - Args: - state: Agent state - - Returns: - SystemMessage with full context - """ - # Get context from loaded documents - context = self.memory_manager.get_context_prompt() - - # Get active channels from metadata - active_channels = state.get("metadata", {}).get("active_channels", []) - - # Build system prompt - system_prompt = build_system_prompt(context, active_channels) - - return SystemMessage(content=system_prompt) - async def execute( self, session: UserSession, @@ -143,7 +121,12 @@ class AgentExecutor: async with lock: try: - # Build message history + # Build system prompt with current context + context = self.memory_manager.get_context_prompt() + system_prompt = build_system_prompt(context, session.active_channels) + + # Build message history WITHOUT prepending system message + # The system prompt will be passed via state_modifier in the config messages = [] history = session.get_history(limit=10) logger.info(f"Building message history, {len(history)} messages in history") @@ -155,14 +138,18 @@ class AgentExecutor: elif msg.role == "assistant": messages.append(AIMessage(content=msg.content)) - logger.info(f"Prepared {len(messages)} messages for agent") + 
logger.info(f"Prepared {len(messages)} messages for agent (including system prompt)") for i, msg in enumerate(messages): - logger.info(f"LangChain message {i}: type={type(msg).__name__}, content_len={len(msg.content)}, content='{msg.content[:100] if msg.content else 'EMPTY'}'") + msg_type = type(msg).__name__ + content_preview = msg.content[:100] if msg.content else 'EMPTY' + logger.info(f"LangChain message {i}: type={msg_type}, content_len={len(msg.content)}, content='{content_preview}'") - # Prepare config with metadata + # Prepare config with metadata and dynamic system prompt + # Pass system_prompt via state_modifier to avoid multiple system messages config = RunnableConfig( configurable={ - "thread_id": session.session_id + "thread_id": session.session_id, + "state_modifier": system_prompt # Dynamic system prompt injection }, metadata={ "session_id": session.session_id, @@ -178,6 +165,8 @@ class AgentExecutor: event_count = 0 chunk_count = 0 + plot_urls = [] # Accumulate plot URLs from execute_python tool calls + async for event in self.agent.astream_events( {"messages": messages}, config=config, @@ -199,7 +188,35 @@ class AgentExecutor: elif event["event"] == "on_tool_end": tool_name = event.get("name", "unknown") tool_output = event.get("data", {}).get("output") - logger.info(f"Tool call completed: {tool_name} with output: {tool_output}") + + # LangChain may wrap the output in a ToolMessage with content field + # Try to extract the actual content from the ToolMessage + actual_output = tool_output + if hasattr(tool_output, "content"): + actual_output = tool_output.content + + logger.info(f"Tool call completed: {tool_name} with output type: {type(actual_output)}") + + # Extract plot_urls from execute_python tool results + if tool_name == "execute_python": + # Try to parse as JSON if it's a string + import json + if isinstance(actual_output, str): + try: + actual_output = json.loads(actual_output) + except (json.JSONDecodeError, ValueError): + 
logger.warning(f"Could not parse execute_python output as JSON: {actual_output[:200]}") + + if isinstance(actual_output, dict): + tool_plot_urls = actual_output.get("plot_urls", []) + if tool_plot_urls: + logger.info(f"execute_python generated {len(tool_plot_urls)} plots: {tool_plot_urls}") + plot_urls.extend(tool_plot_urls) + # Yield metadata about plots immediately + yield { + "content": "", + "metadata": {"plot_urls": tool_plot_urls} + } # Extract streaming tokens elif event["event"] == "on_chat_model_stream": diff --git a/backend/src/agent/prompts.py b/backend/src/agent/prompts.py index 29a2ed8..3821bd6 100644 --- a/backend/src/agent/prompts.py +++ b/backend/src/agent/prompts.py @@ -1,7 +1,54 @@ -from typing import List +from typing import List, Dict, Any from gateway.user_session import UserSession +def _get_chart_store_context() -> str: + """Get current ChartStore state for context injection. + + Returns: + Formatted string with ChartStore contents, or empty string if unavailable + """ + try: + from agent.tools import _registry + + if not _registry: + return "" + + chart_store = _registry.entries.get("ChartStore") + if not chart_store: + return "" + + chart_state = chart_store.model.model_dump(mode="json") + chart_data = chart_state.get("chart_state", {}) + + # Only include if there's actual chart data + if not chart_data or not chart_data.get("symbol"): + return "" + + # Format the chart information + symbol = chart_data.get("symbol", "N/A") + interval = chart_data.get("interval", "N/A") + start_time = chart_data.get("start_time") + end_time = chart_data.get("end_time") + + chart_context = f""" +## Current Chart Context + +The user is currently viewing a chart with the following settings: +- **Symbol**: {symbol} +- **Interval**: {interval} +- **Time Range**: {f"from {start_time} to {end_time}" if start_time and end_time else "not set"} + +This information is automatically available because you're connected via websocket. 
+When the user refers to "the chart", "this chart", or "what I'm viewing", this is what they mean. +""" + return chart_context + + except Exception: + # Silently fail - chart context is optional enhancement + return "" + + def build_system_prompt(context: str, active_channels: List[str]) -> str: """Build the system prompt for the agent. @@ -17,6 +64,15 @@ def build_system_prompt(context: str, active_channels: List[str]) -> str: """ channels_str = ", ".join(active_channels) if active_channels else "none" + # Check if user is connected via websocket - if so, inject chart context + # Note: We check for websocket by looking for "websocket" in channel IDs + # since WebSocketChannel uses channel_id like "websocket-{uuid}" + has_websocket = any("websocket" in channel_id.lower() for channel_id in active_channels) + + chart_context = "" + if has_websocket: + chart_context = _get_chart_store_context() + # Context already includes system_prompt.md and other docs # Just add current session information prompt = f"""{context} @@ -28,7 +84,7 @@ def build_system_prompt(context: str, active_channels: List[str]) -> str: Your responses will be sent to all active channels. Your responses are streamed back in real-time. If the user sends a new message while you're responding, your current response will be interrupted and you'll be re-invoked with the updated context. 
-""" +{chart_context}""" return prompt diff --git a/backend/src/agent/tools.py b/backend/src/agent/tools.py deleted file mode 100644 index 5603c40..0000000 --- a/backend/src/agent/tools.py +++ /dev/null @@ -1,662 +0,0 @@ -from typing import Dict, Any, List, Optional -import io -import base64 -import sys -import uuid -import logging -from pathlib import Path -from contextlib import redirect_stdout, redirect_stderr -from langchain_core.tools import tool - -logger = logging.getLogger(__name__) - -# Global registry instance (will be set by main.py) -_registry = None -_datasource_registry = None - - -def set_registry(registry): - """Set the global SyncRegistry instance for tools to use.""" - global _registry - _registry = registry - - -def set_datasource_registry(datasource_registry): - """Set the global DataSourceRegistry instance for tools to use.""" - global _datasource_registry - _datasource_registry = datasource_registry - - -@tool -def list_sync_stores() -> List[str]: - """List all available synchronization stores. - - Returns: - List of store names that can be read/written - """ - if not _registry: - return [] - return list(_registry.entries.keys()) - - -@tool -def read_sync_state(store_name: str) -> Dict[str, Any]: - """Read the current state of a synchronization store. - - Args: - store_name: Name of the store to read (e.g., "TraderState", "StrategyState") - - Returns: - Dictionary containing the current state of the store - - Raises: - ValueError: If store_name doesn't exist - """ - if not _registry: - raise ValueError("SyncRegistry not initialized") - - entry = _registry.entries.get(store_name) - if not entry: - available = list(_registry.entries.keys()) - raise ValueError(f"Store '{store_name}' not found. Available stores: {available}") - - return entry.model.model_dump(mode="json") - - -@tool -async def write_sync_state(store_name: str, updates: Dict[str, Any]) -> Dict[str, str]: - """Update the state of a synchronization store. 
- - This will apply the updates to the store and trigger synchronization - with all connected clients. - - Args: - store_name: Name of the store to update - updates: Dictionary of field updates (field_name: new_value) - - Returns: - Dictionary with status and updated fields - - Raises: - ValueError: If store_name doesn't exist or updates are invalid - """ - if not _registry: - raise ValueError("SyncRegistry not initialized") - - entry = _registry.entries.get(store_name) - if not entry: - available = list(_registry.entries.keys()) - raise ValueError(f"Store '{store_name}' not found. Available stores: {available}") - - try: - # Get current state - current_state = entry.model.model_dump(mode="json") - - # Apply updates - new_state = {**current_state, **updates} - - # Update the model - _registry._update_model(entry.model, new_state) - - # Trigger sync - await _registry.push_all() - - return { - "status": "success", - "store": store_name, - "updated_fields": list(updates.keys()) - } - - except Exception as e: - raise ValueError(f"Failed to update store '{store_name}': {str(e)}") - - -@tool -def get_store_schema(store_name: str) -> Dict[str, Any]: - """Get the schema/structure of a synchronization store. - - This shows what fields are available and their types. - - Args: - store_name: Name of the store - - Returns: - Dictionary describing the store's schema - - Raises: - ValueError: If store_name doesn't exist - """ - if not _registry: - raise ValueError("SyncRegistry not initialized") - - entry = _registry.entries.get(store_name) - if not entry: - available = list(_registry.entries.keys()) - raise ValueError(f"Store '{store_name}' not found. Available stores: {available}") - - # Get model schema - schema = entry.model.model_json_schema() - - return { - "store_name": store_name, - "schema": schema - } - - -# DataSource tools - -@tool -def list_data_sources() -> List[str]: - """List all available data sources. 
- - Returns: - List of data source names that can be queried for market data - """ - if not _datasource_registry: - return [] - return _datasource_registry.list_sources() - - -@tool -async def search_symbols( - query: str, - type: Optional[str] = None, - exchange: Optional[str] = None, - limit: int = 30, -) -> Dict[str, Any]: - """Search for trading symbols across all data sources. - - Automatically searches all available data sources and returns aggregated results. - Use this to find symbols before calling get_symbol_info or get_historical_data. - - Args: - query: Search query (e.g., "BTC", "AAPL", "EUR") - type: Optional filter by instrument type (e.g., "crypto", "stock", "forex") - exchange: Optional filter by exchange (e.g., "binance", "nasdaq") - limit: Maximum number of results per source (default: 30) - - Returns: - Dictionary mapping source names to lists of matching symbols. - Each symbol includes: symbol, full_name, description, exchange, type. - Use the source name and symbol from results with get_symbol_info or get_historical_data. - - Example response: - { - "demo": [ - { - "symbol": "BTC/USDT", - "full_name": "Bitcoin / Tether USD", - "description": "Bitcoin perpetual futures", - "exchange": "demo", - "type": "crypto" - } - ] - } - """ - if not _datasource_registry: - raise ValueError("DataSourceRegistry not initialized") - - # Always search all sources - results = await _datasource_registry.search_all(query, type, exchange, limit) - return {name: [r.model_dump() for r in matches] for name, matches in results.items()} - - -@tool -async def get_symbol_info(source_name: str, symbol: str) -> Dict[str, Any]: - """Get complete metadata for a trading symbol. - - This retrieves full information about a symbol including: - - Description and type - - Supported time resolutions - - Available data columns (OHLCV, volume, funding rates, etc.) 
- - Trading session information - - Price scale and precision - - Args: - source_name: Name of the data source (use list_data_sources to see available) - symbol: Symbol identifier (e.g., "BTC/USDT", "AAPL", "EUR/USD") - - Returns: - Dictionary containing complete symbol metadata including column schema - - Raises: - ValueError: If source_name or symbol is not found - """ - if not _datasource_registry: - raise ValueError("DataSourceRegistry not initialized") - - symbol_info = await _datasource_registry.resolve_symbol(source_name, symbol) - return symbol_info.model_dump() - - -@tool -async def get_historical_data( - source_name: str, - symbol: str, - resolution: str, - from_time: int, - to_time: int, - countback: Optional[int] = None, -) -> Dict[str, Any]: - """Get historical bar/candle data for a symbol. - - Retrieves time-series data between the specified timestamps. The data - includes all columns defined for the symbol (OHLCV + any custom columns). - - Args: - source_name: Name of the data source - symbol: Symbol identifier - resolution: Time resolution (e.g., "1" = 1min, "5" = 5min, "60" = 1hour, "1D" = 1day) - from_time: Start time as Unix timestamp in seconds - to_time: End time as Unix timestamp in seconds - countback: Optional limit on number of bars to return - - Returns: - Dictionary containing: - - symbol: The requested symbol - - resolution: The time resolution - - bars: List of bar data with 'time' and 'data' fields - - columns: Schema describing available data columns - - nextTime: If present, indicates more data is available for pagination - - Raises: - ValueError: If source, symbol, or resolution is invalid - - Example: - # Get 1-hour BTC data for the last 24 hours - import time - to_time = int(time.time()) - from_time = to_time - 86400 # 24 hours ago - data = get_historical_data("demo", "BTC/USDT", "60", from_time, to_time) - """ - if not _datasource_registry: - raise ValueError("DataSourceRegistry not initialized") - - source = 
_datasource_registry.get(source_name) - if not source: - available = _datasource_registry.list_sources() - raise ValueError(f"Data source '{source_name}' not found. Available sources: {available}") - - result = await source.get_bars(symbol, resolution, from_time, to_time, countback) - return result.model_dump() - - -async def _get_chart_data_impl(countback: Optional[int] = None): - """Internal implementation for getting chart data. - - This is a helper function that can be called by both get_chart_data tool - and analyze_chart_data tool. - - Returns: - Tuple of (HistoryResult, chart_context dict, source_name) - """ - if not _registry: - raise ValueError("SyncRegistry not initialized - cannot read ChartStore") - - if not _datasource_registry: - raise ValueError("DataSourceRegistry not initialized - cannot query data") - - # Read current chart state - chart_store = _registry.entries.get("ChartStore") - if not chart_store: - raise ValueError("ChartStore not found in registry") - - chart_state = chart_store.model.model_dump(mode="json") - chart_data = chart_state.get("chart_state", {}) - - symbol = chart_data.get("symbol", "") - interval = chart_data.get("interval", "15") - start_time = chart_data.get("start_time") - end_time = chart_data.get("end_time") - - if not symbol: - raise ValueError("No symbol set in ChartStore - user may not have loaded a chart yet") - - # Parse the symbol to extract exchange/source and symbol name - # Format is "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD") - if ":" not in symbol: - raise ValueError( - f"Invalid symbol format: '{symbol}'. Expected format is 'EXCHANGE:SYMBOL' " - f"(e.g., 'BINANCE:BTC/USDT' or 'DEMO:BTC/USD')" - ) - - exchange_prefix, symbol_name = symbol.split(":", 1) - source_name = exchange_prefix.lower() - - # Get the data source - source = _datasource_registry.get(source_name) - if not source: - available = _datasource_registry.list_sources() - raise ValueError( - f"Data source '{source_name}' not found. 
Available sources: {available}. " - f"Make sure the exchange in the symbol '{symbol}' matches an available source." - ) - - # Determine time range - REQUIRE it to be set, no defaults - if start_time is None or end_time is None: - raise ValueError( - f"Chart time range not set in ChartStore. start_time={start_time}, end_time={end_time}. " - f"The user needs to load the chart first, or the frontend may not be sending the visible range. " - f"Wait for the chart to fully load before analyzing data." - ) - - from_time = int(start_time) - end_time = int(end_time) - logger.info( - f"Using ChartStore time range: from_time={from_time}, end_time={end_time}, " - f"countback={countback}" - ) - - logger.info( - f"Querying data source '{source_name}' for symbol '{symbol_name}', " - f"resolution '{interval}'" - ) - - # Query the data source - result = await source.get_bars( - symbol=symbol_name, - resolution=interval, - from_time=from_time, - to_time=end_time, - countback=countback - ) - - logger.info( - f"Received {len(result.bars)} bars from data source. " - f"First bar time: {result.bars[0].time if result.bars else 'N/A'}, " - f"Last bar time: {result.bars[-1].time if result.bars else 'N/A'}" - ) - - # Build chart context to return along with result - chart_context = { - "symbol": symbol, - "interval": interval, - "start_time": start_time, - "end_time": end_time - } - - return result, chart_context, source_name - - -@tool -async def get_chart_data(countback: Optional[int] = None) -> Dict[str, Any]: - """Get the candle/bar data for what the user is currently viewing on their chart. - - This is a convenience tool that automatically: - 1. Reads the ChartStore to see what chart the user is viewing - 2. Parses the symbol to determine the data source (exchange prefix) - 3. Queries the appropriate data source for that symbol's data - 4. 
Returns the data for the visible time range and interval - - This is the preferred way to access chart data when helping the user analyze - what they're looking at, since it automatically uses their current chart context. - - Args: - countback: Optional limit on number of bars to return. If not specified, - returns all bars in the visible time range. - - Returns: - Dictionary containing: - - chart_context: Current chart state (symbol, interval, time range) - - symbol: The trading pair being viewed - - resolution: The chart interval - - bars: List of bar data with 'time' and 'data' fields - - columns: Schema describing available data columns - - source: Which data source was used - - Raises: - ValueError: If ChartStore or DataSourceRegistry is not initialized, - or if the symbol format is invalid - - Example: - # User is viewing BINANCE:BTC/USDT on 15min chart - data = get_chart_data() - # Returns BTC/USDT data from binance source at 15min resolution - # for the currently visible time range - """ - result, chart_context, source_name = await _get_chart_data_impl(countback) - - # Return enriched result with chart context - response = result.model_dump() - response["chart_context"] = chart_context - response["source"] = source_name - - return response - - -@tool -async def analyze_chart_data(python_script: str, countback: Optional[int] = None) -> Dict[str, Any]: - """Analyze the current chart data using a Python script with pandas and matplotlib. - - This tool: - 1. Gets the current chart data (same as get_chart_data) - 2. Converts it to a pandas DataFrame with columns: time, open, high, low, close, volume - 3. Executes your Python script with access to the DataFrame as 'df' - 4. Saves any matplotlib plots to disk and returns URLs to access them - 5. 
Returns any final DataFrame result and plot URLs - - The script has access to: - - `df`: pandas DataFrame with OHLCV data indexed by datetime - - `pandas` (as `pd`): For data manipulation - - `numpy` (as `np`): For numerical operations - - `matplotlib.pyplot` (as `plt`): For plotting (use plt.figure() for each plot) - - All matplotlib figures are automatically saved to disk and accessible via URLs. - The last expression in the script (if it's a DataFrame) is returned as the result. - - Args: - python_script: Python code to execute. The DataFrame is available as 'df'. - Can use pandas, numpy, matplotlib. Return a DataFrame to include it in results. - countback: Optional limit on number of bars to analyze - - Returns: - Dictionary containing: - - chart_context: Current chart state (symbol, interval, time range) - - source: Data source used - - script_output: Any printed output from the script - - result_dataframe: If script returns a DataFrame, it's included here as dict - - plot_urls: List of URLs to saved plot images (one per plt.figure()) - - error: Error message if script execution failed - - Example scripts: - # Calculate 20-period SMA and plot - ```python - df['SMA20'] = df['close'].rolling(20).mean() - plt.figure(figsize=(12, 6)) - plt.plot(df.index, df['close'], label='Close') - plt.plot(df.index, df['SMA20'], label='SMA20') - plt.legend() - plt.title('Price with SMA') - df[['close', 'SMA20']].tail(10) # Return last 10 rows - ``` - - # Calculate RSI - ```python - delta = df['close'].diff() - gain = (delta.where(delta > 0, 0)).rolling(14).mean() - loss = (-delta.where(delta < 0, 0)).rolling(14).mean() - rs = gain / loss - df['RSI'] = 100 - (100 / (1 + rs)) - df[['close', 'RSI']].tail(20) - ``` - - # Multiple plots - ```python - # Price chart - plt.figure(figsize=(12, 4)) - plt.plot(df['close']) - plt.title('Price') - - # Volume chart - plt.figure(figsize=(12, 3)) - plt.bar(df.index, df['volume']) - plt.title('Volume') - - df.describe() # Return statistics - 
``` - """ - if not _registry: - raise ValueError("SyncRegistry not initialized - cannot read ChartStore") - - if not _datasource_registry: - raise ValueError("DataSourceRegistry not initialized - cannot query data") - - try: - # Import pandas and numpy here to allow lazy loading - import pandas as pd - import numpy as np - import matplotlib - matplotlib.use('Agg') # Non-interactive backend - import matplotlib.pyplot as plt - except ImportError as e: - raise ValueError( - f"Required library not installed: {e}. " - "Please install pandas, numpy, and matplotlib: pip install pandas numpy matplotlib" - ) - - # Get chart data using the internal helper function - result, chart_context, source_name = await _get_chart_data_impl(countback) - - # Build the same response format as get_chart_data - chart_data = result.model_dump() - chart_data["chart_context"] = chart_context - chart_data["source"] = source_name - - # Convert bars to DataFrame - bars = chart_data.get('bars', []) - if not bars: - return { - "chart_context": chart_data.get('chart_context', {}), - "source": chart_data.get('source', ''), - "error": "No data available for the current chart" - } - - # Build DataFrame - rows = [] - for bar in bars: - row = { - 'time': pd.to_datetime(bar['time'], unit='s'), - **bar['data'] # Includes open, high, low, close, volume, etc. 
- } - rows.append(row) - - df = pd.DataFrame(rows) - df.set_index('time', inplace=True) - - # Convert price columns to float for clean numeric operations - price_columns = ['open', 'high', 'low', 'close', 'volume'] - for col in price_columns: - if col in df.columns: - df[col] = pd.to_numeric(df[col], errors='coerce') - - logger.info( - f"Created DataFrame with {len(df)} rows, columns: {df.columns.tolist()}, " - f"time range: {df.index.min()} to {df.index.max()}, " - f"dtypes: {df.dtypes.to_dict()}" - ) - - # Prepare execution environment - script_globals = { - 'df': df, - 'pd': pd, - 'np': np, - 'plt': plt, - } - - # Capture stdout/stderr - stdout_capture = io.StringIO() - stderr_capture = io.StringIO() - - result_df = None - error_msg = None - plot_urls = [] - - # Determine uploads directory (relative to this file) - uploads_dir = Path(__file__).parent.parent.parent / "data" / "uploads" - uploads_dir.mkdir(parents=True, exist_ok=True) - - try: - with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture): - # Execute the script - exec(python_script, script_globals) - - # Check if the last line is an expression that returns a DataFrame - # We'll try to evaluate it separately - script_lines = python_script.strip().split('\n') - if script_lines: - last_line = script_lines[-1].strip() - # Only evaluate if it doesn't look like a statement - if last_line and not any(last_line.startswith(kw) for kw in ['if', 'for', 'while', 'def', 'class', 'import', 'from', 'with', 'try', 'return']): - try: - last_result = eval(last_line, script_globals) - if isinstance(last_result, pd.DataFrame): - result_df = last_result - except: - # If eval fails, that's okay - might not be an expression - pass - - # Save all matplotlib figures to disk - for fig_num in plt.get_fignums(): - fig = plt.figure(fig_num) - - # Generate unique filename - plot_id = str(uuid.uuid4()) - filename = f"plot_{plot_id}.png" - filepath = uploads_dir / filename - - # Save figure to file - 
fig.savefig(filepath, format='png', bbox_inches='tight', dpi=100) - - # Generate URL that can be accessed via the web server - plot_url = f"/uploads/{filename}" - plot_urls.append(plot_url) - - plt.close(fig) - - except Exception as e: - error_msg = f"{type(e).__name__}: {str(e)}" - import traceback - error_msg += f"\n{traceback.format_exc()}" - - # Build response - response = { - "chart_context": chart_data.get('chart_context', {}), - "source": chart_data.get('source', ''), - "script_output": stdout_capture.getvalue(), - } - - if error_msg: - response["error"] = error_msg - response["stderr"] = stderr_capture.getvalue() - - if result_df is not None: - # Convert DataFrame to dict for JSON serialization - response["result_dataframe"] = { - "columns": result_df.columns.tolist(), - "index": result_df.index.astype(str).tolist() if hasattr(result_df.index, 'astype') else result_df.index.tolist(), - "data": result_df.values.tolist(), - "shape": result_df.shape, - } - - if plot_urls: - response["plot_urls"] = plot_urls - - return response - - -# Export all tools -SYNC_TOOLS = [ - list_sync_stores, - read_sync_state, - write_sync_state, - get_store_schema -] - -DATASOURCE_TOOLS = [ - list_data_sources, - search_symbols, - get_symbol_info, - get_historical_data, - get_chart_data, - analyze_chart_data -] diff --git a/backend/src/agent/tools/CHART_UTILS_README.md b/backend/src/agent/tools/CHART_UTILS_README.md new file mode 100644 index 0000000..6853f87 --- /dev/null +++ b/backend/src/agent/tools/CHART_UTILS_README.md @@ -0,0 +1,139 @@ +# Chart Utilities - Standard OHLC Plotting + +## Overview + +The `chart_utils.py` module provides convenience functions for creating beautiful, professional OHLC candlestick charts with a consistent look and feel. This is designed to be used by the LLM in `analyze_chart_data` scripts, eliminating the need to write custom matplotlib code for every chart.
+ +## Key Features + +- **Beautiful by default**: Uses mplfinance with seaborn-inspired aesthetics +- **Consistent styling**: Professional color scheme (teal green up, coral red down) +- **Easy to use**: Simple function calls instead of complex matplotlib code +- **Customizable**: Supports all mplfinance options via kwargs +- **Volume integration**: Optional volume subplot + +## Installation + +The required package `mplfinance` has been added to `requirements.txt`: + +```bash +pip install mplfinance +``` + +## Available Functions + +### 1. `plot_ohlc(df, title=None, volume=True, figsize=(14, 8), **kwargs)` + +Main function for creating standard OHLC candlestick charts. + +**Parameters:** +- `df`: pandas DataFrame with DatetimeIndex and OHLCV columns +- `title`: Optional chart title +- `volume`: Whether to include volume subplot (default: True) +- `figsize`: Figure size in inches (default: (14, 8)) +- `**kwargs`: Additional mplfinance.plot() arguments + +**Example:** +```python +fig = plot_ohlc(df, title='BTC/USDT 15min', volume=True) +``` + +### 2. `add_indicators_to_plot(df, indicators, **plot_kwargs)` + +Creates OHLC chart with technical indicators overlaid. + +**Parameters:** +- `df`: DataFrame with OHLCV data and indicator columns +- `indicators`: Dict mapping indicator column names to display parameters +- `**plot_kwargs`: Additional arguments for plot_ohlc() + +**Example:** +```python +df['SMA_20'] = df['close'].rolling(20).mean() +df['SMA_50'] = df['close'].rolling(50).mean() + +fig = add_indicators_to_plot( + df, + indicators={ + 'SMA_20': {'color': 'blue', 'width': 1.5}, + 'SMA_50': {'color': 'red', 'width': 1.5} + }, + title='Price with Moving Averages' +) +``` + +### 3. 
Preset Functions + +- `plot_price_volume(df, title=None)` - Standard price + volume chart +- `plot_price_only(df, title=None)` - Candlesticks without volume + +## Integration with analyze_chart_data + +These functions are automatically available in the `analyze_chart_data` tool's script environment: + +```python +# In an analyze_chart_data script: +# df is already provided + +# Simple usage +fig = plot_ohlc(df, title='Price Action') + +# With indicators +df['SMA'] = df['close'].rolling(20).mean() +fig = add_indicators_to_plot( + df, + indicators={'SMA': {'color': 'blue', 'width': 1.5}}, + title='Price with SMA' +) + +# Return data for the assistant +df[['close', 'SMA']].tail(10) +``` + +## Styling + +The default style includes: +- **Up candles**: Teal green (#26a69a) +- **Down candles**: Coral red (#ef5350) +- **Background**: Light gray with white axes +- **Grid**: Subtle dashed lines with 30% alpha +- **Professional fonts**: Clean, readable sizes + +## Why This Matters + +**Before:** +```python +# LLM had to write this every time +import matplotlib.pyplot as plt +fig, ax = plt.subplots(figsize=(12, 6)) +ax.plot(df.index, df['close'], label='Close') +# ... lots more code for styling, colors, etc. +``` + +**After:** +```python +# LLM can now just do this +fig = plot_ohlc(df, title='BTC/USDT') +``` + +Benefits: +- ✅ Less code to generate → faster response +- ✅ Consistent appearance across all charts +- ✅ Professional look out of the box +- ✅ Easier to maintain and customize +- ✅ Better use of mplfinance's candlestick rendering + +## Example Output + +See `chart_utils_example.py` for runnable examples demonstrating: +1. Basic OHLC chart with volume +2. OHLC chart with multiple indicators +3. Price-only chart +4. 
Custom styling options + +## File Locations + +- **Main module**: `backend/src/agent/tools/chart_utils.py` +- **Integration**: `backend/src/agent/tools/chart_tools.py` (lines 306-328) +- **Examples**: `backend/src/agent/tools/chart_utils_example.py` +- **Dependency**: `backend/requirements.txt` (mplfinance added) diff --git a/backend/src/agent/tools/__init__.py b/backend/src/agent/tools/__init__.py new file mode 100644 index 0000000..13f3877 --- /dev/null +++ b/backend/src/agent/tools/__init__.py @@ -0,0 +1,50 @@ +"""Agent tools for trading operations. + +This package provides tools for: +- Synchronization stores (sync_tools) +- Data sources and market data (datasource_tools) +- Chart data access and analysis (chart_tools) +- Technical indicators (indicator_tools) +""" + +# Global registries that will be set by main.py +_registry = None +_datasource_registry = None +_indicator_registry = None + + +def set_registry(registry): + """Set the global SyncRegistry instance for tools to use.""" + global _registry + _registry = registry + + +def set_datasource_registry(datasource_registry): + """Set the global DataSourceRegistry instance for tools to use.""" + global _datasource_registry + _datasource_registry = datasource_registry + + +def set_indicator_registry(indicator_registry): + """Set the global IndicatorRegistry instance for tools to use.""" + global _indicator_registry + _indicator_registry = indicator_registry + + +# Import all tools from submodules +from .sync_tools import SYNC_TOOLS +from .datasource_tools import DATASOURCE_TOOLS +from .chart_tools import CHART_TOOLS +from .indicator_tools import INDICATOR_TOOLS +from .research_tools import RESEARCH_TOOLS + +__all__ = [ + "set_registry", + "set_datasource_registry", + "set_indicator_registry", + "SYNC_TOOLS", + "DATASOURCE_TOOLS", + "CHART_TOOLS", + "INDICATOR_TOOLS", + "RESEARCH_TOOLS", +] diff --git a/backend/src/agent/tools/chart_tools.py b/backend/src/agent/tools/chart_tools.py new file mode 100644 index 
0000000..311593b --- /dev/null +++ b/backend/src/agent/tools/chart_tools.py @@ -0,0 +1,371 @@ +"""Chart data access and analysis tools.""" + +from typing import Dict, Any, Optional, Tuple +import io +import uuid +import logging +from pathlib import Path +from contextlib import redirect_stdout, redirect_stderr +from langchain_core.tools import tool + +logger = logging.getLogger(__name__) + + +def _get_registry(): + """Get the global registry instance.""" + from . import _registry + return _registry + + +def _get_datasource_registry(): + """Get the global datasource registry instance.""" + from . import _datasource_registry + return _datasource_registry + + +def _get_indicator_registry(): + """Get the global indicator registry instance.""" + from . import _indicator_registry + return _indicator_registry + + +async def _get_chart_data_impl(countback: Optional[int] = None): + """Internal implementation for getting chart data. + + This is a helper function that can be called by both get_chart_data tool + and analyze_chart_data tool. 
+ + Returns: + Tuple of (HistoryResult, chart_context dict, source_name) + """ + registry = _get_registry() + datasource_registry = _get_datasource_registry() + + if not registry: + raise ValueError("SyncRegistry not initialized - cannot read ChartStore") + + if not datasource_registry: + raise ValueError("DataSourceRegistry not initialized - cannot query data") + + # Read current chart state + chart_store = registry.entries.get("ChartStore") + if not chart_store: + raise ValueError("ChartStore not found in registry") + + chart_state = chart_store.model.model_dump(mode="json") + chart_data = chart_state.get("chart_state", {}) + + symbol = chart_data.get("symbol", "") + interval = chart_data.get("interval", "15") + start_time = chart_data.get("start_time") + end_time = chart_data.get("end_time") + + if not symbol: + raise ValueError("No symbol set in ChartStore - user may not have loaded a chart yet") + + # Parse the symbol to extract exchange/source and symbol name + # Format is "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD") + if ":" not in symbol: + raise ValueError( + f"Invalid symbol format: '{symbol}'. Expected format is 'EXCHANGE:SYMBOL' " + f"(e.g., 'BINANCE:BTC/USDT' or 'DEMO:BTC/USD')" + ) + + exchange_prefix, symbol_name = symbol.split(":", 1) + source_name = exchange_prefix.lower() + + # Get the data source + source = datasource_registry.get(source_name) + if not source: + available = datasource_registry.list_sources() + raise ValueError( + f"Data source '{source_name}' not found. Available sources: {available}. " + f"Make sure the exchange in the symbol '{symbol}' matches an available source." + ) + + # Determine time range - REQUIRE it to be set, no defaults + if start_time is None or end_time is None: + raise ValueError( + f"Chart time range not set in ChartStore. start_time={start_time}, end_time={end_time}. " + f"The user needs to load the chart first, or the frontend may not be sending the visible range. 
" + f"Wait for the chart to fully load before analyzing data." + ) + + from_time = int(start_time) + end_time = int(end_time) + logger.info( + f"Using ChartStore time range: from_time={from_time}, end_time={end_time}, " + f"countback={countback}" + ) + + logger.info( + f"Querying data source '{source_name}' for symbol '{symbol_name}', " + f"resolution '{interval}'" + ) + + # Query the data source + result = await source.get_bars( + symbol=symbol_name, + resolution=interval, + from_time=from_time, + to_time=end_time, + countback=countback + ) + + logger.info( + f"Received {len(result.bars)} bars from data source. " + f"First bar time: {result.bars[0].time if result.bars else 'N/A'}, " + f"Last bar time: {result.bars[-1].time if result.bars else 'N/A'}" + ) + + # Build chart context to return along with result + chart_context = { + "symbol": symbol, + "interval": interval, + "start_time": start_time, + "end_time": end_time + } + + return result, chart_context, source_name + + +@tool +async def get_chart_data(countback: Optional[int] = None) -> Dict[str, Any]: + """Get the candle/bar data for what the user is currently viewing on their chart. + + This is a convenience tool that automatically: + 1. Reads the ChartStore to see what chart the user is viewing + 2. Parses the symbol to determine the data source (exchange prefix) + 3. Queries the appropriate data source for that symbol's data + 4. Returns the data for the visible time range and interval + + This is the preferred way to access chart data when helping the user analyze + what they're looking at, since it automatically uses their current chart context. + + Args: + countback: Optional limit on number of bars to return. If not specified, + returns all bars in the visible time range. 
+ + Returns: + Dictionary containing: + - chart_context: Current chart state (symbol, interval, time range) + - symbol: The trading pair being viewed + - resolution: The chart interval + - bars: List of bar data with 'time' and 'data' fields + - columns: Schema describing available data columns + - source: Which data source was used + + Raises: + ValueError: If ChartStore or DataSourceRegistry is not initialized, + or if the symbol format is invalid + + Example: + # User is viewing BINANCE:BTC/USDT on 15min chart + data = get_chart_data() + # Returns BTC/USDT data from binance source at 15min resolution + # for the currently visible time range + """ + result, chart_context, source_name = await _get_chart_data_impl(countback) + + # Return enriched result with chart context + response = result.model_dump() + response["chart_context"] = chart_context + response["source"] = source_name + + return response + + +@tool +async def execute_python(code: str, countback: Optional[int] = None) -> Dict[str, Any]: + """Execute Python code for technical analysis with automatic chart data loading. + + **PRIMARY TOOL for all technical analysis, indicator computation, and chart generation.** + + This is your go-to tool whenever the user asks about indicators, wants to see + a chart, or needs any computational analysis of market data. 
+ + Pre-loaded Environment: + - `pd` : pandas + - `np` : numpy + - `plt` : matplotlib.pyplot (figures auto-saved to plot_urls) + - `talib` : TA-Lib technical analysis library + - `indicator_registry`: 150+ registered indicators + - `plot_ohlc(df)` : Helper function for beautiful candlestick charts + + Auto-loaded when user has a chart open: + - `df` : pandas DataFrame with DatetimeIndex and columns: + open, high, low, close, volume (OHLCV data ready to use) + - `chart_context` : dict with symbol, interval, start_time, end_time + + The `plot_ohlc()` Helper: + Create professional candlestick charts instantly: + - `plot_ohlc(df)` - basic OHLC chart with volume + - `plot_ohlc(df, title='BTC 15min')` - with custom title + - `plot_ohlc(df, volume=False)` - price only, no volume + - Returns a matplotlib Figure that's automatically saved to plot_urls + + Args: + code: Python code to execute + countback: Optional limit on number of bars to load (default: all visible bars) + + Returns: + Dictionary with: + - script_output : printed output + last expression result + - result_dataframe : serialized DataFrame if last expression is a DataFrame + - plot_urls : list of image URLs (e.g., ["/uploads/plot_abc123.png"]) + - chart_context : {symbol, interval, start_time, end_time} or None + - error : traceback if execution failed + + Examples: + # RSI indicator with chart + execute_python(\"\"\" + df['RSI'] = talib.RSI(df['close'], 14) + fig = plot_ohlc(df, title='BTC/USDT with RSI') + print(f"Current RSI: {df['RSI'].iloc[-1]:.2f}") + df[['close', 'RSI']].tail(5) + \"\"\") + + # Multiple indicators + execute_python(\"\"\" + df['SMA_20'] = df['close'].rolling(20).mean() + df['SMA_50'] = df['close'].rolling(50).mean() + df['BB_upper'] = df['close'].rolling(20).mean() + 2*df['close'].rolling(20).std() + df['BB_lower'] = df['close'].rolling(20).mean() - 2*df['close'].rolling(20).std() + + fig = plot_ohlc(df, title=f"{chart_context['symbol']} - Bollinger Bands") + + current_price = 
df['close'].iloc[-1] + sma20 = df['SMA_20'].iloc[-1] + print(f"Price: {current_price:.2f}, SMA20: {sma20:.2f}") + df[['close', 'SMA_20', 'BB_upper', 'BB_lower']].tail(10) + \"\"\") + + # Pattern detection + execute_python(\"\"\" + # Find swing highs + df['swing_high'] = (df['high'] > df['high'].shift(1)) & (df['high'] > df['high'].shift(-1)) + swing_highs = df[df['swing_high']][['high']].tail(5) + + fig = plot_ohlc(df, title='Swing High Detection') + print("Recent swing highs:") + print(swing_highs) + \"\"\") + """ + import pandas as pd + import numpy as np + import matplotlib + matplotlib.use('Agg') + import matplotlib.pyplot as plt + + try: + import talib + except ImportError: + talib = None + logger.warning("TA-Lib not available in execute_python environment") + + # --- Attempt to load chart data --- + df = None + chart_context = None + + registry = _get_registry() + datasource_registry = _get_datasource_registry() + + if registry and datasource_registry: + try: + result, chart_context, source_name = await _get_chart_data_impl(countback) + bars = result.bars + if bars: + rows = [] + for bar in bars: + rows.append({'time': pd.to_datetime(bar.time, unit='s'), **bar.data}) + df = pd.DataFrame(rows).set_index('time') + for col in ['open', 'high', 'low', 'close', 'volume']: + if col in df.columns: + df[col] = pd.to_numeric(df[col], errors='coerce') + logger.info(f"execute_python: loaded {len(df)} bars for {chart_context['symbol']}") + except Exception as e: + logger.info(f"execute_python: no chart data loaded ({e})") + + # --- Import chart utilities --- + from .chart_utils import plot_ohlc + + # --- Get indicator registry --- + indicator_registry = _get_indicator_registry() + + # --- Build globals --- + script_globals: Dict[str, Any] = { + 'pd': pd, + 'np': np, + 'plt': plt, + 'talib': talib, + 'indicator_registry': indicator_registry, + 'df': df, + 'chart_context': chart_context, + 'plot_ohlc': plot_ohlc, + } + + # --- Execute --- + uploads_dir = 
Path(__file__).parent.parent.parent.parent / "data" / "uploads" + uploads_dir.mkdir(parents=True, exist_ok=True) + + stdout_capture = io.StringIO() + result_df = None + error_msg = None + plot_urls = [] + + try: + with redirect_stdout(stdout_capture), redirect_stderr(stdout_capture): + exec(code, script_globals) + + # Capture last expression + lines = code.strip().splitlines() + if lines: + last = lines[-1].strip() + if last and not any(last.startswith(kw) for kw in ( + 'if', 'for', 'while', 'def', 'class', 'import', + 'from', 'with', 'try', 'return', '#' + )): + try: + last_val = eval(last, script_globals) + if isinstance(last_val, pd.DataFrame): + result_df = last_val + elif last_val is not None: + stdout_capture.write(str(last_val)) + except Exception: + pass + + # Save plots + for fig_num in plt.get_fignums(): + fig = plt.figure(fig_num) + filename = f"plot_{uuid.uuid4()}.png" + fig.savefig(uploads_dir / filename, format='png', bbox_inches='tight', dpi=100) + plot_urls.append(f"/uploads/{filename}") + plt.close(fig) + + except Exception as e: + import traceback + error_msg = f"{type(e).__name__}: {e}\n{traceback.format_exc()}" + + # --- Build response --- + response: Dict[str, Any] = { + 'script_output': stdout_capture.getvalue(), + 'chart_context': chart_context, + 'plot_urls': plot_urls, + } + if result_df is not None: + response['result_dataframe'] = { + 'columns': result_df.columns.tolist(), + 'index': result_df.index.astype(str).tolist(), + 'data': result_df.values.tolist(), + 'shape': result_df.shape, + } + if error_msg: + response['error'] = error_msg + + return response + + +CHART_TOOLS = [ + get_chart_data, + execute_python ] diff --git a/backend/src/agent/tools/chart_utils.py b/backend/src/agent/tools/chart_utils.py new file mode 100644 index 0000000..1dafbc8 --- /dev/null +++ b/backend/src/agent/tools/chart_utils.py @@ -0,0 +1,224 @@ +"""Chart plotting utilities for creating standard, beautiful OHLC charts.""" + +import pandas as pd +import
matplotlib.pyplot as plt +from typing import Optional, Tuple +import logging + +logger = logging.getLogger(__name__) + + +def plot_ohlc( + df: pd.DataFrame, + title: Optional[str] = None, + volume: bool = True, + figsize: Tuple[int, int] = (14, 8), + style: str = 'seaborn-v0_8-darkgrid', + **kwargs +) -> plt.Figure: + """Create a beautiful standard OHLC candlestick chart. + + This is a convenience function that generates a professional-looking candlestick + chart with consistent styling across all generated charts. It uses mplfinance + with seaborn aesthetics for a polished appearance. + + Args: + df: pandas DataFrame with DatetimeIndex and columns: open, high, low, close, volume + title: Optional chart title. If None, uses symbol from chart context + volume: Whether to include volume subplot (default: True) + figsize: Figure size as (width, height) in inches (default: (14, 8)) + style: Base matplotlib style to use (default: 'seaborn-v0_8-darkgrid') + **kwargs: Additional arguments to pass to mplfinance.plot() + + Returns: + matplotlib.figure.Figure: The created figure object + + Example: + ```python + # Basic usage in analyze_chart_data script + fig = plot_ohlc(df, title='BTC/USDT 15min') + + # Customize with additional indicators + fig = plot_ohlc(df, volume=True, title='Price Action') + + # Add custom overlays after calling plot_ohlc + df['SMA20'] = df['close'].rolling(20).mean() + fig = plot_ohlc(df, title='With SMA') + # Note: For mplfinance overlays, use the mav or addplot parameters + ``` + + Note: + The DataFrame must have a DatetimeIndex and the standard OHLCV columns. + Column names should be lowercase: open, high, low, close, volume + """ + try: + import mplfinance as mpf + except ImportError: + raise ImportError( + "mplfinance is required for plot_ohlc(). 
" + "Install it with: pip install mplfinance" + ) + + # Validate DataFrame structure + required_cols = ['open', 'high', 'low', 'close'] + missing_cols = [col for col in required_cols if col not in df.columns] + if missing_cols: + raise ValueError( + f"DataFrame missing required columns: {missing_cols}. " + f"Required: {required_cols}" + ) + + if not isinstance(df.index, pd.DatetimeIndex): + raise ValueError( + "DataFrame must have a DatetimeIndex. " + "Convert with: df.index = pd.to_datetime(df.index)" + ) + + # Ensure volume column exists for volume plot + if volume and 'volume' not in df.columns: + logger.warning("volume=True but 'volume' column not found in DataFrame. Disabling volume.") + volume = False + + # Create custom style with seaborn aesthetics + # Using a professional color scheme: green for up candles, red for down candles + mc = mpf.make_marketcolors( + up='#26a69a', # Teal green (calmer than bright green) + down='#ef5350', # Coral red (softer than pure red) + edge='inherit', # Match candle color for edges + wick='inherit', # Match candle color for wicks + volume='in', # Volume bars colored by price direction + alpha=0.9 # Slight transparency for elegance + ) + + s = mpf.make_mpf_style( + base_mpf_style='charles', # Clean base style + marketcolors=mc, + rc={ + 'font.size': 10, + 'axes.labelsize': 11, + 'axes.titlesize': 12, + 'xtick.labelsize': 9, + 'ytick.labelsize': 9, + 'legend.fontsize': 10, + 'figure.facecolor': '#f0f0f0', + 'axes.facecolor': '#ffffff', + 'axes.grid': True, + 'grid.alpha': 0.3, + 'grid.linestyle': '--', + } + ) + + # Prepare plot parameters + plot_params = { + 'type': 'candle', + 'style': s, + 'volume': volume, + 'figsize': figsize, + 'tight_layout': True, + 'returnfig': True, + 'warn_too_much_data': 1000, # Warn if > 1000 candles for performance + } + + # Add title if provided + if title: + plot_params['title'] = title + + # Merge any additional kwargs + plot_params.update(kwargs) + + # Create the plot + logger.info( + 
f"Creating OHLC chart with {len(df)} candles, " + f"date range: {df.index.min()} to {df.index.max()}, " + f"volume: {volume}" + ) + + fig, axes = mpf.plot(df, **plot_params) + + return fig + + +def add_indicators_to_plot( + df: pd.DataFrame, + indicators: dict, + **plot_kwargs +) -> plt.Figure: + """Create an OHLC chart with technical indicators overlaid. + + This extends plot_ohlc() to include common technical indicators using + mplfinance's addplot functionality for proper overlay on candlestick charts. + + Args: + df: pandas DataFrame with OHLCV data and indicator columns + indicators: Dictionary mapping indicator names to parameters + Example: { + 'SMA_20': {'color': 'blue', 'width': 1.5}, + 'EMA_50': {'color': 'orange', 'width': 1.5} + } + **plot_kwargs: Additional arguments for plot_ohlc() + + Returns: + matplotlib.figure.Figure: The created figure object + + Example: + ```python + # Calculate indicators + df['SMA_20'] = df['close'].rolling(20).mean() + df['SMA_50'] = df['close'].rolling(50).mean() + + # Plot with indicators + fig = add_indicators_to_plot( + df, + indicators={ + 'SMA_20': {'color': 'blue', 'width': 1.5, 'label': '20 SMA'}, + 'SMA_50': {'color': 'red', 'width': 1.5, 'label': '50 SMA'} + }, + title='BTC/USDT with Moving Averages' + ) + ``` + """ + try: + import mplfinance as mpf + except ImportError: + raise ImportError( + "mplfinance is required. Install it with: pip install mplfinance" + ) + + # Build addplot list for indicators + addplots = [] + for indicator_col, params in indicators.items(): + if indicator_col not in df.columns: + logger.warning(f"Indicator column '{indicator_col}' not found in DataFrame. 
Skipping.") + continue + + color = params.get('color', 'blue') + width = params.get('width', 1.0) + panel = params.get('panel', 0) # 0 = main panel with candles + ylabel = params.get('ylabel', '') + + addplots.append( + mpf.make_addplot( + df[indicator_col], + color=color, + width=width, + panel=panel, + ylabel=ylabel + ) + ) + + # Pass addplot to plot_ohlc via kwargs + if addplots: + plot_kwargs['addplot'] = addplots + + return plot_ohlc(df, **plot_kwargs) + + +# Convenience presets for common chart types +def plot_price_volume(df: pd.DataFrame, title: Optional[str] = None) -> plt.Figure: + """Create a standard price + volume chart.""" + return plot_ohlc(df, title=title, volume=True, figsize=(14, 8)) + + +def plot_price_only(df: pd.DataFrame, title: Optional[str] = None) -> plt.Figure: + """Create a price-only candlestick chart without volume.""" + return plot_ohlc(df, title=title, volume=False, figsize=(14, 6)) diff --git a/backend/src/agent/tools/chart_utils_example.py b/backend/src/agent/tools/chart_utils_example.py new file mode 100644 index 0000000..d05776e --- /dev/null +++ b/backend/src/agent/tools/chart_utils_example.py @@ -0,0 +1,154 @@ +""" +Example usage of chart_utils.py plotting functions. + +This demonstrates how the LLM can use the plot_ohlc() convenience function +in analyze_chart_data scripts to create beautiful, standard OHLC charts. 
+""" + +import pandas as pd +import numpy as np +from datetime import datetime, timedelta + + +def create_sample_data(days=30): + """Create sample OHLCV data for testing.""" + dates = pd.date_range(end=datetime.now(), periods=days * 24, freq='1H') + + # Simulate price movement + np.random.seed(42) + close = 50000 + np.cumsum(np.random.randn(len(dates)) * 100) + + data = { + 'open': close + np.random.randn(len(dates)) * 50, + 'high': close + np.abs(np.random.randn(len(dates))) * 100, + 'low': close - np.abs(np.random.randn(len(dates))) * 100, + 'close': close, + 'volume': np.abs(np.random.randn(len(dates))) * 1000000 + } + + df = pd.DataFrame(data, index=dates) + + # Ensure high is highest and low is lowest + df['high'] = df[['open', 'high', 'low', 'close']].max(axis=1) + df['low'] = df[['open', 'high', 'low', 'close']].min(axis=1) + + return df + + +if __name__ == "__main__": + from chart_utils import plot_ohlc, add_indicators_to_plot, plot_price_volume + + # Create sample data + df = create_sample_data(days=30) + + print("=" * 60) + print("Example 1: Basic OHLC chart with volume") + print("=" * 60) + print("\nScript the LLM would generate:") + print(""" +fig = plot_ohlc(df, title='BTC/USDT 1H', volume=True) +df.tail(5) + """) + + # Execute it + fig = plot_ohlc(df, title='BTC/USDT 1H', volume=True) + print("\n✓ Chart created successfully!") + print(f" Figure size: {fig.get_size_inches()}") + print(f" Number of axes: {len(fig.axes)}") + + + print("\n" + "=" * 60) + print("Example 2: OHLC chart with indicators") + print("=" * 60) + print("\nScript the LLM would generate:") + print(""" +# Calculate indicators +df['SMA_20'] = df['close'].rolling(20).mean() +df['SMA_50'] = df['close'].rolling(50).mean() +df['EMA_12'] = df['close'].ewm(span=12, adjust=False).mean() + +# Plot with indicators +fig = add_indicators_to_plot( + df, + indicators={ + 'SMA_20': {'color': 'blue', 'width': 1.5}, + 'SMA_50': {'color': 'red', 'width': 1.5}, + 'EMA_12': {'color': 'green', 'width': 
1.0} + }, + title='BTC/USDT with Moving Averages', + volume=True +) + +df[['close', 'SMA_20', 'SMA_50', 'EMA_12']].tail(5) + """) + + # Execute it + df['SMA_20'] = df['close'].rolling(20).mean() + df['SMA_50'] = df['close'].rolling(50).mean() + df['EMA_12'] = df['close'].ewm(span=12, adjust=False).mean() + + fig = add_indicators_to_plot( + df, + indicators={ + 'SMA_20': {'color': 'blue', 'width': 1.5}, + 'SMA_50': {'color': 'red', 'width': 1.5}, + 'EMA_12': {'color': 'green', 'width': 1.0} + }, + title='BTC/USDT with Moving Averages', + volume=True + ) + + print("\n✓ Chart with indicators created successfully!") + print(f" Last close: ${df['close'].iloc[-1]:,.2f}") + print(f" SMA 20: ${df['SMA_20'].iloc[-1]:,.2f}") + print(f" SMA 50: ${df['SMA_50'].iloc[-1]:,.2f}") + + + print("\n" + "=" * 60) + print("Example 3: Price-only chart (no volume)") + print("=" * 60) + print("\nScript the LLM would generate:") + print(""" +from chart_utils import plot_price_only + +fig = plot_price_only(df, title='Clean Price Action') + """) + + # Execute it + from chart_utils import plot_price_only + fig = plot_price_only(df, title='Clean Price Action') + + print("\n✓ Price-only chart created successfully!") + + + print("\n" + "=" * 60) + print("Summary") + print("=" * 60) + print(""" +The chart_utils module provides: + +1. plot_ohlc() - Main function for beautiful candlestick charts + - Professional seaborn-inspired styling + - Consistent color scheme (teal up, coral down) + - Optional volume subplot + - Customizable figure size + +2. add_indicators_to_plot() - OHLC charts with technical indicators + - Overlay multiple indicators + - Customizable colors and line widths + - Proper integration with mplfinance + +3. 
Preset functions for common chart types: + - plot_price_volume() - Standard price + volume + - plot_price_only() - Candlesticks without volume + +Benefits: +✓ Consistent look and feel across all charts +✓ Less code for the LLM to generate +✓ Professional appearance out of the box +✓ Easy to customize when needed +✓ Works seamlessly with analyze_chart_data tool + +The LLM can now simply call plot_ohlc(df) instead of writing +custom matplotlib code for every chart request! + """) diff --git a/backend/src/agent/tools/datasource_tools.py b/backend/src/agent/tools/datasource_tools.py new file mode 100644 index 0000000..f4fa361 --- /dev/null +++ b/backend/src/agent/tools/datasource_tools.py @@ -0,0 +1,158 @@ +"""Data source and market data tools.""" + +from typing import Dict, Any, List, Optional +from langchain_core.tools import tool + + +def _get_datasource_registry(): + """Get the global datasource registry instance.""" + from . import _datasource_registry + return _datasource_registry + + +@tool +def list_data_sources() -> List[str]: + """List all available data sources. + + Returns: + List of data source names that can be queried for market data + """ + registry = _get_datasource_registry() + if not registry: + return [] + return registry.list_sources() + + +@tool +async def search_symbols( + query: str, + type: Optional[str] = None, + exchange: Optional[str] = None, + limit: int = 30, +) -> Dict[str, Any]: + """Search for trading symbols across all data sources. + + Automatically searches all available data sources and returns aggregated results. + Use this to find symbols before calling get_symbol_info or get_historical_data. + + Args: + query: Search query (e.g., "BTC", "AAPL", "EUR") + type: Optional filter by instrument type (e.g., "crypto", "stock", "forex") + exchange: Optional filter by exchange (e.g., "binance", "nasdaq") + limit: Maximum number of results per source (default: 30) + + Returns: + Dictionary mapping source names to lists of matching symbols. 
+ Each symbol includes: symbol, full_name, description, exchange, type. + Use the source name and symbol from results with get_symbol_info or get_historical_data. + + Example response: + { + "demo": [ + { + "symbol": "BTC/USDT", + "full_name": "Bitcoin / Tether USD", + "description": "Bitcoin perpetual futures", + "exchange": "demo", + "type": "crypto" + } + ] + } + """ + registry = _get_datasource_registry() + if not registry: + raise ValueError("DataSourceRegistry not initialized") + + # Always search all sources + results = await registry.search_all(query, type, exchange, limit) + return {name: [r.model_dump() for r in matches] for name, matches in results.items()} + + +@tool +async def get_symbol_info(source_name: str, symbol: str) -> Dict[str, Any]: + """Get complete metadata for a trading symbol. + + This retrieves full information about a symbol including: + - Description and type + - Supported time resolutions + - Available data columns (OHLCV, volume, funding rates, etc.) + - Trading session information + - Price scale and precision + + Args: + source_name: Name of the data source (use list_data_sources to see available) + symbol: Symbol identifier (e.g., "BTC/USDT", "AAPL", "EUR/USD") + + Returns: + Dictionary containing complete symbol metadata including column schema + + Raises: + ValueError: If source_name or symbol is not found + """ + registry = _get_datasource_registry() + if not registry: + raise ValueError("DataSourceRegistry not initialized") + + symbol_info = await registry.resolve_symbol(source_name, symbol) + return symbol_info.model_dump() + + +@tool +async def get_historical_data( + source_name: str, + symbol: str, + resolution: str, + from_time: int, + to_time: int, + countback: Optional[int] = None, +) -> Dict[str, Any]: + """Get historical bar/candle data for a symbol. + + Retrieves time-series data between the specified timestamps. The data + includes all columns defined for the symbol (OHLCV + any custom columns). 
+ + Args: + source_name: Name of the data source + symbol: Symbol identifier + resolution: Time resolution (e.g., "1" = 1min, "5" = 5min, "60" = 1hour, "1D" = 1day) + from_time: Start time as Unix timestamp in seconds + to_time: End time as Unix timestamp in seconds + countback: Optional limit on number of bars to return + + Returns: + Dictionary containing: + - symbol: The requested symbol + - resolution: The time resolution + - bars: List of bar data with 'time' and 'data' fields + - columns: Schema describing available data columns + - nextTime: If present, indicates more data is available for pagination + + Raises: + ValueError: If source, symbol, or resolution is invalid + + Example: + # Get 1-hour BTC data for the last 24 hours + import time + to_time = int(time.time()) + from_time = to_time - 86400 # 24 hours ago + data = get_historical_data("demo", "BTC/USDT", "60", from_time, to_time) + """ + registry = _get_datasource_registry() + if not registry: + raise ValueError("DataSourceRegistry not initialized") + + source = registry.get(source_name) + if not source: + available = registry.list_sources() + raise ValueError(f"Data source '{source_name}' not found. Available sources: {available}") + + result = await source.get_bars(symbol, resolution, from_time, to_time, countback) + return result.model_dump() + + +DATASOURCE_TOOLS = [ + list_data_sources, + search_symbols, + get_symbol_info, + get_historical_data, +] diff --git a/backend/src/agent/tools/indicator_tools.py b/backend/src/agent/tools/indicator_tools.py new file mode 100644 index 0000000..939876c --- /dev/null +++ b/backend/src/agent/tools/indicator_tools.py @@ -0,0 +1,169 @@ +"""Technical indicator tools.""" + +from typing import Dict, Any, List, Optional +from langchain_core.tools import tool + + +def _get_indicator_registry(): + """Get the global indicator registry instance.""" + from . 
import _indicator_registry + return _indicator_registry + + +@tool +def list_indicators() -> List[str]: + """List all available technical indicators. + + Returns: + List of indicator names that can be used in analysis and strategies + """ + registry = _get_indicator_registry() + if not registry: + return [] + return registry.list_indicators() + + +@tool +def get_indicator_info(indicator_name: str) -> Dict[str, Any]: + """Get detailed information about a specific indicator. + + Retrieves metadata including description, parameters, category, use cases, + input/output schemas, and references. + + Args: + indicator_name: Name of the indicator (e.g., "RSI", "SMA", "MACD") + + Returns: + Dictionary containing: + - name: Indicator name + - display_name: Human-readable name + - description: What the indicator computes and why it's useful + - category: Category (momentum, trend, volatility, volume, etc.) + - parameters: List of configurable parameters with types and defaults + - use_cases: Common trading scenarios where this indicator helps + - tags: Searchable tags + - input_schema: Required input columns (e.g., OHLCV requirements) + - output_schema: Columns this indicator produces + + Raises: + ValueError: If indicator_name is not found + """ + registry = _get_indicator_registry() + if not registry: + raise ValueError("IndicatorRegistry not initialized") + + metadata = registry.get_metadata(indicator_name) + if not metadata: + total_count = len(registry.list_indicators()) + raise ValueError( + f"Indicator '{indicator_name}' not found. " + f"Total available: {total_count} indicators. " + f"Use search_indicators() to find indicators by name, category, or tag." 
+ ) + + input_schema = registry.get_input_schema(indicator_name) + output_schema = registry.get_output_schema(indicator_name) + + result = metadata.model_dump() + result["input_schema"] = input_schema.model_dump() if input_schema else None + result["output_schema"] = output_schema.model_dump() if output_schema else None + + return result + + +@tool +def search_indicators( + query: Optional[str] = None, + category: Optional[str] = None, + tag: Optional[str] = None +) -> List[Dict[str, Any]]: + """Search for indicators by text query, category, or tag. + + Returns lightweight summaries - use get_indicator_info() for full details on specific indicators. + + Use this to discover relevant indicators for your trading strategy or analysis. + Can filter by category (momentum, trend, volatility, etc.) or search by keywords. + + Args: + query: Optional text search across names, descriptions, and use cases + category: Optional category filter (momentum, trend, volatility, volume, pattern, etc.) + tag: Optional tag filter (e.g., "oscillator", "moving-average", "talib") + + Returns: + List of lightweight indicator summaries. Each contains: + - name: Indicator name (use with get_indicator_info() for full details) + - display_name: Human-readable name + - description: Brief one-line description + - category: Category (momentum, trend, volatility, etc.) + + Example: + # Find all momentum indicators + results = search_indicators(category="momentum") + # Returns [{name: "RSI", display_name: "RSI", description: "...", category: "momentum"}, ...] 
+ + # Then get details on interesting ones + rsi_details = get_indicator_info("RSI") # Full parameters, schemas, use cases + + # Search for moving average indicators + search_indicators(query="moving average") + + # Find all TA-Lib indicators + search_indicators(tag="talib") + """ + registry = _get_indicator_registry() + if not registry: + raise ValueError("IndicatorRegistry not initialized") + + results = [] + + if query: + results = registry.search_by_text(query) + elif category: + results = registry.search_by_category(category) + elif tag: + results = registry.search_by_tag(tag) + else: + # Return all indicators if no filter + results = registry.get_all_metadata() + + # Return lightweight summaries only + return [ + { + "name": r.name, + "display_name": r.display_name, + "description": r.description, + "category": r.category + } + for r in results + ] + + +@tool +def get_indicator_categories() -> Dict[str, int]: + """Get all indicator categories and their counts. + + Returns a summary of available indicator categories, useful for + exploring what types of indicators are available. + + Returns: + Dictionary mapping category name to count of indicators in that category. 
+ Example: {"momentum": 25, "trend": 15, "volatility": 8, ...} + """ + registry = _get_indicator_registry() + if not registry: + raise ValueError("IndicatorRegistry not initialized") + + categories: Dict[str, int] = {} + for metadata in registry.get_all_metadata(): + category = metadata.category + categories[category] = categories.get(category, 0) + 1 + + return categories + + +INDICATOR_TOOLS = [ + list_indicators, + get_indicator_info, + search_indicators, + get_indicator_categories +] diff --git a/backend/src/agent/tools/research_tools.py b/backend/src/agent/tools/research_tools.py new file mode 100644 index 0000000..52b971b --- /dev/null +++ b/backend/src/agent/tools/research_tools.py @@ -0,0 +1,171 @@ +"""Research and external data tools for trading analysis.""" + +from typing import Dict, Any, Optional +from langchain_core.tools import tool +from langchain_community.tools import ( + ArxivQueryRun, + WikipediaQueryRun, + DuckDuckGoSearchRun +) +from langchain_community.utilities import ( + ArxivAPIWrapper, + WikipediaAPIWrapper, + DuckDuckGoSearchAPIWrapper +) + + +@tool +def search_arxiv(query: str, max_results: int = 5) -> str: + """Search arXiv for academic papers on quantitative finance, trading strategies, and machine learning. 
+ + Use this to find research papers on topics like: + - Market microstructure and order flow + - Algorithmic trading strategies + - Machine learning for finance + - Time series forecasting + - Risk management + - Portfolio optimization + + Args: + query: Search query (e.g., "machine learning algorithmic trading", "deep learning stock prediction") + max_results: Maximum number of results to return (default: 5) + + Returns: + Summary of papers including titles, authors, abstracts, and links + + Example: + search_arxiv("reinforcement learning trading", max_results=3) + """ + arxiv = ArxivQueryRun(api_wrapper=ArxivAPIWrapper(top_k_results=max_results)) + return arxiv.run(query) + + +@tool +def search_wikipedia(query: str) -> str: + """Search Wikipedia for information on finance, trading, and economics concepts. + + Use this to get background information on: + - Financial instruments and markets + - Economic indicators + - Trading terminology + - Technical analysis concepts + - Historical market events + + Args: + query: Search query (e.g., "Black-Scholes model", "technical analysis", "options trading") + + Returns: + Wikipedia article summary with key information + + Example: + search_wikipedia("Bollinger Bands") + """ + wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()) + return wikipedia.run(query) + + +@tool +def search_web(query: str, max_results: int = 5) -> str: + """Search the web for current information on markets, news, and trading. 
+ + Use this to find: + - Latest market news and analysis + - Company announcements and earnings + - Economic events and indicators + - Cryptocurrency updates + - Exchange status and updates + - Trading strategy discussions + + Args: + query: Search query (e.g., "Bitcoin price news", "Fed interest rate decision") + max_results: Maximum number of results to return (default: 5) + + Returns: + Search results with titles, snippets, and links + + Example: + search_web("Ethereum merge update", max_results=3) + """ + # Lazy initialization to avoid hanging during import + search = DuckDuckGoSearchRun(api_wrapper=DuckDuckGoSearchAPIWrapper()) + # Note: max_results parameter doesn't work properly with current wrapper + return search.run(query) + + +@tool +def http_get(url: str, params: Optional[Dict[str, str]] = None) -> str: + """Make HTTP GET request to fetch data from APIs or web pages. + + Use this to retrieve: + - Exchange API data (if public endpoints) + - Market data from external APIs + - Documentation and specifications + - News articles and blog posts + - JSON/XML data from web services + + Args: + url: The URL to fetch + params: Optional query parameters as a dictionary + + Returns: + Response text from the URL + + Raises: + ValueError: If the request fails + + Example: + http_get("https://api.coingecko.com/api/v3/simple/price", + params={"ids": "bitcoin", "vs_currencies": "usd"}) + """ + import requests + + try: + response = requests.get(url, params=params, timeout=10) + response.raise_for_status() + return response.text + except requests.RequestException as e: + raise ValueError(f"HTTP GET request failed: {str(e)}") + + +@tool +def http_post(url: str, data: Dict[str, Any]) -> str: + """Make HTTP POST request to send data to APIs. 
+ + Use this to: + - Submit data to external APIs + - Trigger webhooks + - Post analysis results + - Interact with exchange APIs (if authenticated) + + Args: + url: The URL to post to + data: Dictionary of data to send in the request body + + Returns: + Response text from the server + + Raises: + ValueError: If the request fails + + Example: + http_post("https://webhook.site/xxx", {"message": "Trade executed"}) + """ + import requests + import json + + try: + response = requests.post(url, json=data, timeout=10) + response.raise_for_status() + return response.text + except requests.RequestException as e: + raise ValueError(f"HTTP POST request failed: {str(e)}") + + +# Export tools list +RESEARCH_TOOLS = [ + search_arxiv, + search_wikipedia, + search_web, + http_get, + http_post +] diff --git a/backend/src/agent/tools/sync_tools.py b/backend/src/agent/tools/sync_tools.py new file mode 100644 index 0000000..1829f31 --- /dev/null +++ b/backend/src/agent/tools/sync_tools.py @@ -0,0 +1,138 @@ +"""Synchronization store tools.""" + +from typing import Dict, Any, List +from langchain_core.tools import tool + + +def _get_registry(): + """Get the global registry instance.""" + from . import _registry + return _registry + + +@tool +def list_sync_stores() -> List[str]: + """List all available synchronization stores. + + Returns: + List of store names that can be read/written + """ + registry = _get_registry() + if not registry: + return [] + return list(registry.entries.keys()) + + +@tool +def read_sync_state(store_name: str) -> Dict[str, Any]: + """Read the current state of a synchronization store. 
+ + Args: + store_name: Name of the store to read (e.g., "TraderState", "StrategyState") + + Returns: + Dictionary containing the current state of the store + + Raises: + ValueError: If store_name doesn't exist + """ + registry = _get_registry() + if not registry: + raise ValueError("SyncRegistry not initialized") + + entry = registry.entries.get(store_name) + if not entry: + available = list(registry.entries.keys()) + raise ValueError(f"Store '{store_name}' not found. Available stores: {available}") + + return entry.model.model_dump(mode="json") + + +@tool +async def write_sync_state(store_name: str, updates: Dict[str, Any]) -> Dict[str, str]: + """Update the state of a synchronization store. + + This will apply the updates to the store and trigger synchronization + with all connected clients. + + Args: + store_name: Name of the store to update + updates: Dictionary of field updates (field_name: new_value) + + Returns: + Dictionary with status and updated fields + + Raises: + ValueError: If store_name doesn't exist or updates are invalid + """ + registry = _get_registry() + if not registry: + raise ValueError("SyncRegistry not initialized") + + entry = registry.entries.get(store_name) + if not entry: + available = list(registry.entries.keys()) + raise ValueError(f"Store '{store_name}' not found. Available stores: {available}") + + try: + # Get current state + current_state = entry.model.model_dump(mode="json") + + # Apply updates + new_state = {**current_state, **updates} + + # Update the model + registry._update_model(entry.model, new_state) + + # Trigger sync + await registry.push_all() + + return { + "status": "success", + "store": store_name, + "updated_fields": list(updates.keys()) + } + + except Exception as e: + raise ValueError(f"Failed to update store '{store_name}': {str(e)}") + + +@tool +def get_store_schema(store_name: str) -> Dict[str, Any]: + """Get the schema/structure of a synchronization store. 
+ + This shows what fields are available and their types. + + Args: + store_name: Name of the store + + Returns: + Dictionary describing the store's schema + + Raises: + ValueError: If store_name doesn't exist + """ + registry = _get_registry() + if not registry: + raise ValueError("SyncRegistry not initialized") + + entry = registry.entries.get(store_name) + if not entry: + available = list(registry.entries.keys()) + raise ValueError(f"Store '{store_name}' not found. Available stores: {available}") + + # Get model schema + schema = entry.model.model_json_schema() + + return { + "store_name": store_name, + "schema": schema + } + + +SYNC_TOOLS = [ + list_sync_stores, + read_sync_state, + write_sync_state, + get_store_schema +] diff --git a/backend/src/datasource/adapters/ccxt_adapter.py b/backend/src/datasource/adapters/ccxt_adapter.py index 7f89f81..0d813c3 100644 --- a/backend/src/datasource/adapters/ccxt_adapter.py +++ b/backend/src/datasource/adapters/ccxt_adapter.py @@ -6,9 +6,10 @@ the free CCXT library (not ccxt.pro), supporting both historical data and polling-based subscriptions. 
Numerical Precision: -- Uses Decimal for all monetary values (prices, volumes) to avoid floating-point errors +- OHLCV data uses native floats for optimal DataFrame/analysis performance +- Account balances and order data should use Decimal (via _to_decimal method) - CCXT returns numeric values as strings or floats depending on configuration -- All financial values are converted to Decimal to maintain precision +- Price data converted to float (_to_float), financial data to Decimal (_to_decimal) Real-time Updates: - Uses polling instead of WebSocket (free CCXT doesn't have WebSocket support) @@ -72,6 +73,20 @@ class CCXTDataSource(DataSource): exchange_class = getattr(ccxt, exchange_id) self.exchange = exchange_class(self._config) + # Configure CCXT to use Decimal mode for precise financial calculations + # This ensures all numeric values from the exchange use Decimal internally + # We then convert OHLCV to float for DataFrame performance, but keep + # Decimal precision for account balances, order sizes, etc. + from decimal import Decimal as PythonDecimal + self.exchange.number = PythonDecimal + + # Log the precision mode being used by this exchange + precision_mode = getattr(self.exchange, 'precisionMode', 'UNKNOWN') + logger.info( + f"CCXT {exchange_id}: Configured with Decimal mode. " + f"Exchange precision mode: {precision_mode}" + ) + if sandbox and hasattr(self.exchange, 'set_sandbox_mode'): self.exchange.set_sandbox_mode(True) @@ -103,6 +118,33 @@ class CCXTDataSource(DataSource): return Decimal(str(value)) return None + @staticmethod + def _to_float(value: Union[str, int, float, Decimal, None]) -> Optional[float]: + """ + Convert a value to float for OHLCV data. + + OHLCV data is used for charting and DataFrame analysis, where native + floats provide better performance and compatibility with pandas/numpy. + For financial precision (balances, order sizes), use _to_decimal() instead. 
+ + When CCXT is in Decimal mode (exchange.number = Decimal), it returns + Decimal objects. This method converts them to float for performance. + + Handles CCXT's output in both modes: + - Decimal mode: receives Decimal objects + - Default mode: receives strings, floats, or ints + """ + if value is None: + return None + if isinstance(value, float): + return value + if isinstance(value, Decimal): + # CCXT in Decimal mode - convert to float for OHLCV + return float(value) + if isinstance(value, (str, int)): + return float(value) + return None + async def _ensure_markets_loaded(self): """Ensure markets are loaded from exchange""" if not self._markets_loaded: @@ -241,31 +283,31 @@ class CCXTDataSource(DataSource): columns=[ ColumnInfo( name="open", - type="decimal", + type="float", description=f"Opening price in {quote}", unit=quote, ), ColumnInfo( name="high", - type="decimal", + type="float", description=f"Highest price in {quote}", unit=quote, ), ColumnInfo( name="low", - type="decimal", + type="float", description=f"Lowest price in {quote}", unit=quote, ), ColumnInfo( name="close", - type="decimal", + type="float", description=f"Closing price in {quote}", unit=quote, ), ColumnInfo( name="volume", - type="decimal", + type="float", description=f"Trading volume in {base}", unit=base, ), @@ -370,7 +412,7 @@ class CCXTDataSource(DataSource): all_ohlcv = all_ohlcv[:countback] break - # Convert to our Bar format with Decimal precision + # Convert to our Bar format with float for OHLCV (used in DataFrames) bars = [] for candle in all_ohlcv: timestamp_ms, open_price, high, low, close, volume = candle @@ -384,11 +426,11 @@ class CCXTDataSource(DataSource): Bar( time=timestamp, data={ - "open": self._to_decimal(open_price), - "high": self._to_decimal(high), - "low": self._to_decimal(low), - "close": self._to_decimal(close), - "volume": self._to_decimal(volume), + "open": self._to_float(open_price), + "high": self._to_float(high), + "low": self._to_float(low), + "close": 
self._to_float(close), + "volume": self._to_float(volume), }, ) ) @@ -476,14 +518,14 @@ class CCXTDataSource(DataSource): if timestamp > last_timestamp: self._last_bars[subscription_id] = timestamp - # Convert to our format with Decimal precision + # Convert to our format with float for OHLCV (used in DataFrames) tick_data = { "time": timestamp, - "open": self._to_decimal(open_price), - "high": self._to_decimal(high), - "low": self._to_decimal(low), - "close": self._to_decimal(close), - "volume": self._to_decimal(volume), + "open": self._to_float(open_price), + "high": self._to_float(high), + "low": self._to_float(low), + "close": self._to_float(close), + "volume": self._to_float(volume), } # Call the callback diff --git a/backend/src/exchange_kernel/README.md b/backend/src/exchange_kernel/README.md new file mode 100644 index 0000000..c120683 --- /dev/null +++ b/backend/src/exchange_kernel/README.md @@ -0,0 +1,179 @@ +# Exchange Kernel API + +A Kubernetes-style declarative API for managing orders across different exchanges. + +## Architecture Overview + +The Exchange Kernel maintains two separate views of order state: + +1. **Desired State (Intent)**: What the strategy kernel wants +2. **Actual State (Reality)**: What currently exists on the exchange + +A reconciliation loop continuously works to bring actual state into alignment with desired state, handling errors, retries, and edge cases automatically. 
+ +## Core Components + +### Models (`models.py`) + +- **OrderIntent**: Desired order state from strategy kernel +- **OrderState**: Actual current order state on exchange +- **Position**: Current position (spot, margin, perp, futures, options) +- **Asset**: Asset holdings with metadata +- **AccountState**: Complete account snapshot (balances, positions, margin) +- **AssetMetadata**: Asset type descriptions and trading parameters + +### Events (`events.py`) + +Order lifecycle events: +- `OrderSubmitted`, `OrderAccepted`, `OrderRejected` +- `OrderPartiallyFilled`, `OrderFilled`, `OrderCanceled` +- `OrderModified`, `OrderExpired` + +Position events: +- `PositionOpened`, `PositionModified`, `PositionClosed` + +Account events: +- `AccountBalanceUpdated`, `MarginCallWarning` + +### Base Interface (`base.py`) + +Abstract `ExchangeKernel` class defining: + +**Command API**: +- `place_order()`, `place_order_group()` - Create order intents +- `cancel_order()`, `modify_order()` - Update intents +- `cancel_all_orders()` - Bulk cancellation + +**Query API**: +- `get_order_intent()`, `get_order_state()` - Query single order +- `get_all_intents()`, `get_all_orders()` - Query all orders +- `get_positions()`, `get_account_state()` - Query positions/balances +- `get_symbol_metadata()`, `get_asset_metadata()` - Query market info + +**Event API**: +- `subscribe_events()`, `unsubscribe_events()` - Event notifications + +**Lifecycle**: +- `start()`, `stop()` - Kernel lifecycle +- `health_check()` - Connection status +- `force_reconciliation()` - Manual reconciliation trigger + +### State Management (`state.py`) + +- **IntentStateStore**: Storage for desired state (durable, survives restarts) +- **ActualStateStore**: Storage for actual exchange state (ephemeral cache) +- **ReconciliationEngine**: Framework for intent→reality reconciliation +- **InMemory implementations**: For testing/prototyping + +## Standard Order Model + +Defined in `schema/order_spec.py`: + +```python +StandardOrder( 
+ symbol_id="BTC/USD", + side=Side.BUY, + amount=1.0, + amount_type=AmountType.BASE, # or QUOTE for exact-out + limit_price=50000.0, # None for market orders + time_in_force=TimeInForce.GTC, + conditional_trigger=ConditionalTrigger(...), # Optional stop-loss/take-profit + conditional_mode=ConditionalOrderMode.UNIFIED_ADJUSTING, + reduce_only=False, + post_only=False, + iceberg_qty=None, +) +``` + +## Symbol Metadata + +Markets describe their capabilities via `SymbolMetadata`: + +- **AmountConstraints**: Min/max order size, step size +- **PriceConstraints**: Tick size, tick spacing mode (fixed/dynamic/continuous) +- **MarketCapabilities**: + - Supported sides (BUY, SELL) + - Supported amount types (BASE, QUOTE, or both) + - Market vs limit order support + - Time-in-force options (GTC, IOC, FOK, DAY, GTD) + - Conditional order support (stop-loss, take-profit, trailing stops) + - Advanced features (post-only, reduce-only, iceberg) + +## Asset Types + +Comprehensive asset type system supporting: +- **SPOT**: Cash markets +- **MARGIN**: Margin trading +- **PERP**: Perpetual futures +- **FUTURE**: Dated futures +- **OPTION**: Options contracts +- **SYNTHETIC**: Derived instruments + +Each asset has metadata describing contract specs, settlement, margin requirements, etc. 
+ +## Usage Pattern + +```python +# Create exchange kernel for specific exchange +kernel = SomeExchangeKernel(exchange_id="binance_main") + +# Subscribe to events +kernel.subscribe_events(my_event_handler) + +# Start kernel +await kernel.start() + +# Place order (creates intent, kernel handles execution) +intent_id = await kernel.place_order( + StandardOrder( + symbol_id="BTC/USD", + side=Side.BUY, + amount=1.0, + amount_type=AmountType.BASE, + limit_price=50000.0, + ) +) + +# Query desired state +intent = await kernel.get_order_intent(intent_id) + +# Query actual state +state = await kernel.get_order_state(intent_id) + +# Modify order (updates intent, kernel reconciles) +await kernel.modify_order(intent_id, new_order) + +# Cancel order +await kernel.cancel_order(intent_id) + +# Query positions +positions = await kernel.get_positions() + +# Query account state +account = await kernel.get_account_state() +``` + +## Implementation Status + +✅ **Complete**: +- Data models and type definitions +- Event definitions +- Abstract interface +- State store framework +- In-memory stores for testing + +⏳ **TODO** (Exchange-specific implementations): +- Concrete ExchangeKernel implementations per exchange +- Reconciliation engine implementation +- Exchange API adapters +- Persistent state storage (database) +- Error handling and retry logic +- Monitoring and observability + +## Next Steps + +1. Create concrete implementations for specific exchanges (Binance, Uniswap, etc.) +2. Implement reconciliation engine with proper error handling +3. Add persistent storage backend for intents +4. Build integration tests +5. 
Add monitoring/metrics collection diff --git a/backend/src/exchange_kernel/__init__.py b/backend/src/exchange_kernel/__init__.py new file mode 100644 index 0000000..70d9152 --- /dev/null +++ b/backend/src/exchange_kernel/__init__.py @@ -0,0 +1,75 @@ +""" +Exchange Kernel API + +The exchange kernel provides a Kubernetes-style declarative API for managing orders +across different exchanges. It maintains both desired state (intent) and actual state +(current orders on exchange) and reconciles them continuously. + +Key concepts: +- OrderIntent: What the strategy kernel wants +- OrderState: What actually exists on the exchange +- Reconciliation: Bringing actual state into alignment with desired state +""" + +from .base import ExchangeKernel +from .events import ( + OrderEvent, + OrderSubmitted, + OrderAccepted, + OrderRejected, + OrderPartiallyFilled, + OrderFilled, + OrderCanceled, + OrderModified, + OrderExpired, + PositionEvent, + PositionOpened, + PositionModified, + PositionClosed, + AccountEvent, + AccountBalanceUpdated, + MarginCallWarning, +) +from .models import ( + OrderIntent, + OrderState, + Position, + Asset, + AssetMetadata, + AccountState, + Balance, +) +from .state import IntentStateStore, ActualStateStore + +__all__ = [ + # Core interface + "ExchangeKernel", + # Events + "OrderEvent", + "OrderSubmitted", + "OrderAccepted", + "OrderRejected", + "OrderPartiallyFilled", + "OrderFilled", + "OrderCanceled", + "OrderModified", + "OrderExpired", + "PositionEvent", + "PositionOpened", + "PositionModified", + "PositionClosed", + "AccountEvent", + "AccountBalanceUpdated", + "MarginCallWarning", + # Models + "OrderIntent", + "OrderState", + "Position", + "Asset", + "AssetMetadata", + "AccountState", + "Balance", + # State management + "IntentStateStore", + "ActualStateStore", +] diff --git a/backend/src/exchange_kernel/base.py b/backend/src/exchange_kernel/base.py new file mode 100644 index 0000000..2d814dd --- /dev/null +++ b/backend/src/exchange_kernel/base.py 
@@ -0,0 +1,361 @@ +""" +Base interface for Exchange Kernels. + +Defines the abstract API that all exchange kernel implementations must support. +Each exchange (or exchange type) will have its own kernel implementation. +""" + +from abc import ABC, abstractmethod +from typing import Callable, Any + +from .models import ( + OrderIntent, + OrderState, + Position, + AccountState, + AssetMetadata, +) +from .events import BaseEvent +from ..schema.order_spec import ( + StandardOrder, + StandardOrderGroup, + SymbolMetadata, +) + + +class ExchangeKernel(ABC): + """ + Abstract base class for exchange kernels. + + An exchange kernel manages the lifecycle of orders on a specific exchange, + maintaining both desired state (intents from strategy kernel) and actual + state (current orders on exchange), and continuously reconciling them. + + Think of it as a Kubernetes-style controller for trading orders. + """ + + def __init__(self, exchange_id: str): + """ + Initialize the exchange kernel. + + Args: + exchange_id: Unique identifier for this exchange instance + """ + self.exchange_id = exchange_id + + # ------------------------------------------------------------------------- + # Command API - Strategy kernel sends intents + # ------------------------------------------------------------------------- + + @abstractmethod + async def place_order(self, order: StandardOrder, metadata: dict[str, Any] | None = None) -> str: + """ + Place a single order on the exchange. + + This creates an OrderIntent and begins the reconciliation process to + get the order onto the exchange. 
+ + Args: + order: The order specification + metadata: Optional strategy-specific metadata + + Returns: + intent_id: Unique identifier for this order intent + + Raises: + ValidationError: If order violates market constraints + ExchangeError: If exchange rejects the order + """ + pass + + @abstractmethod + async def place_order_group( + self, + group: StandardOrderGroup, + metadata: dict[str, Any] | None = None + ) -> list[str]: + """ + Place a group of orders with OCO (One-Cancels-Other) relationship. + + Args: + group: Group of orders with OCO mode + metadata: Optional strategy-specific metadata + + Returns: + intent_ids: List of intent IDs for each order in the group + + Raises: + ValidationError: If any order violates market constraints + ExchangeError: If exchange rejects the group + """ + pass + + @abstractmethod + async def cancel_order(self, intent_id: str) -> None: + """ + Cancel an order by intent ID. + + Updates the intent to indicate cancellation is desired, and the + reconciliation loop will handle the actual exchange cancellation. + + Args: + intent_id: Intent ID of the order to cancel + + Raises: + NotFoundError: If intent_id doesn't exist + ExchangeError: If exchange rejects cancellation + """ + pass + + @abstractmethod + async def modify_order( + self, + intent_id: str, + new_order: StandardOrder, + ) -> None: + """ + Modify an existing order. + + Updates the order intent, and the reconciliation loop will update + the exchange order (via modify API if available, or cancel+replace). + + Args: + intent_id: Intent ID of the order to modify + new_order: New order specification + + Raises: + NotFoundError: If intent_id doesn't exist + ValidationError: If new order violates market constraints + ExchangeError: If exchange rejects modification + """ + pass + + @abstractmethod + async def cancel_all_orders(self, symbol_id: str | None = None) -> int: + """ + Cancel all orders, optionally filtered by symbol. 
+ + Args: + symbol_id: If provided, only cancel orders for this symbol + + Returns: + count: Number of orders canceled + """ + pass + + # ------------------------------------------------------------------------- + # Query API - Read desired and actual state + # ------------------------------------------------------------------------- + + @abstractmethod + async def get_order_intent(self, intent_id: str) -> OrderIntent: + """ + Get the desired order state (what strategy kernel wants). + + Args: + intent_id: Intent ID to query + + Returns: + The order intent + + Raises: + NotFoundError: If intent_id doesn't exist + """ + pass + + @abstractmethod + async def get_order_state(self, intent_id: str) -> OrderState: + """ + Get the actual order state (what's currently on exchange). + + Args: + intent_id: Intent ID to query + + Returns: + The current order state + + Raises: + NotFoundError: If intent_id doesn't exist + """ + pass + + @abstractmethod + async def get_all_intents(self, symbol_id: str | None = None) -> list[OrderIntent]: + """ + Get all order intents, optionally filtered by symbol. + + Args: + symbol_id: If provided, only return intents for this symbol + + Returns: + List of order intents + """ + pass + + @abstractmethod + async def get_all_orders(self, symbol_id: str | None = None) -> list[OrderState]: + """ + Get all actual order states, optionally filtered by symbol. + + Args: + symbol_id: If provided, only return orders for this symbol + + Returns: + List of order states + """ + pass + + @abstractmethod + async def get_positions(self, symbol_id: str | None = None) -> list[Position]: + """ + Get current positions, optionally filtered by symbol. + + Args: + symbol_id: If provided, only return positions for this symbol + + Returns: + List of positions + """ + pass + + @abstractmethod + async def get_account_state(self) -> AccountState: + """ + Get current account state (balances, margin, etc.). 
+ + Returns: + Current account state + """ + pass + + @abstractmethod + async def get_symbol_metadata(self, symbol_id: str) -> SymbolMetadata: + """ + Get metadata for a symbol (constraints, capabilities, etc.). + + Args: + symbol_id: Symbol to query + + Returns: + Symbol metadata + + Raises: + NotFoundError: If symbol doesn't exist on this exchange + """ + pass + + @abstractmethod + async def get_asset_metadata(self, asset_id: str) -> AssetMetadata: + """ + Get metadata for an asset. + + Args: + asset_id: Asset to query + + Returns: + Asset metadata + + Raises: + NotFoundError: If asset doesn't exist + """ + pass + + @abstractmethod + async def list_symbols(self) -> list[str]: + """ + List all available symbols on this exchange. + + Returns: + List of symbol IDs + """ + pass + + # ------------------------------------------------------------------------- + # Event Subscription API + # ------------------------------------------------------------------------- + + @abstractmethod + def subscribe_events( + self, + callback: Callable[[BaseEvent], None], + event_filter: dict[str, Any] | None = None, + ) -> str: + """ + Subscribe to events from this exchange kernel. + + Args: + callback: Function to call when events occur + event_filter: Optional filter criteria (event_type, symbol_id, etc.) + + Returns: + subscription_id: Unique ID for this subscription (for unsubscribe) + """ + pass + + @abstractmethod + def unsubscribe_events(self, subscription_id: str) -> None: + """ + Unsubscribe from events. + + Args: + subscription_id: Subscription ID returned from subscribe_events + """ + pass + + # ------------------------------------------------------------------------- + # Lifecycle Management + # ------------------------------------------------------------------------- + + @abstractmethod + async def start(self) -> None: + """ + Start the exchange kernel. + + Initializes connections, starts reconciliation loops, etc. 
+ """ + pass + + @abstractmethod + async def stop(self) -> None: + """ + Stop the exchange kernel. + + Closes connections, stops reconciliation loops, etc. + Does NOT cancel open orders - call cancel_all_orders() first if desired. + """ + pass + + @abstractmethod + async def health_check(self) -> dict[str, Any]: + """ + Check health status of the exchange kernel. + + Returns: + Health status dict with connection state, latency, error counts, etc. + """ + pass + + # ------------------------------------------------------------------------- + # Reconciliation Control (advanced) + # ------------------------------------------------------------------------- + + @abstractmethod + async def force_reconciliation(self, intent_id: str | None = None) -> None: + """ + Force immediate reconciliation. + + Args: + intent_id: If provided, only reconcile this specific intent. + If None, reconcile all intents. + """ + pass + + @abstractmethod + def get_reconciliation_metrics(self) -> dict[str, Any]: + """ + Get metrics about the reconciliation process. + + Returns: + Metrics dict with reconciliation lag, error rates, retry counts, etc. + """ + pass diff --git a/backend/src/exchange_kernel/events.py b/backend/src/exchange_kernel/events.py new file mode 100644 index 0000000..4c561bb --- /dev/null +++ b/backend/src/exchange_kernel/events.py @@ -0,0 +1,250 @@ +""" +Event definitions for the Exchange Kernel. + +All events that can occur during the order lifecycle, position management, +and account updates. 
+""" + +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, Field + +from ..schema.order_spec import Float, Uint64 + + +# --------------------------------------------------------------------------- +# Base Event Classes +# --------------------------------------------------------------------------- + +class EventType(StrEnum): + """Types of events emitted by the exchange kernel""" + # Order lifecycle + ORDER_SUBMITTED = "ORDER_SUBMITTED" + ORDER_ACCEPTED = "ORDER_ACCEPTED" + ORDER_REJECTED = "ORDER_REJECTED" + ORDER_PARTIALLY_FILLED = "ORDER_PARTIALLY_FILLED" + ORDER_FILLED = "ORDER_FILLED" + ORDER_CANCELED = "ORDER_CANCELED" + ORDER_MODIFIED = "ORDER_MODIFIED" + ORDER_EXPIRED = "ORDER_EXPIRED" + + # Position events + POSITION_OPENED = "POSITION_OPENED" + POSITION_MODIFIED = "POSITION_MODIFIED" + POSITION_CLOSED = "POSITION_CLOSED" + + # Account events + ACCOUNT_BALANCE_UPDATED = "ACCOUNT_BALANCE_UPDATED" + MARGIN_CALL_WARNING = "MARGIN_CALL_WARNING" + + # System events + RECONCILIATION_FAILED = "RECONCILIATION_FAILED" + CONNECTION_LOST = "CONNECTION_LOST" + CONNECTION_RESTORED = "CONNECTION_RESTORED" + + +class BaseEvent(BaseModel): + """Base class for all exchange kernel events""" + + model_config = {"extra": "forbid"} + + event_type: EventType = Field(description="Type of event") + timestamp: Uint64 = Field(description="Event timestamp (Unix milliseconds)") + exchange: str = Field(description="Exchange identifier") + metadata: dict[str, Any] = Field(default_factory=dict, description="Additional event data") + + +# --------------------------------------------------------------------------- +# Order Events +# --------------------------------------------------------------------------- + +class OrderEvent(BaseEvent): + """Base class for order-related events""" + + intent_id: str = Field(description="Order intent ID") + order_id: str | None = Field(default=None, description="Exchange order ID (if assigned)") + symbol_id: str = 
Field(description="Symbol being traded") + + +class OrderSubmitted(OrderEvent): + """Order has been submitted to the exchange""" + + event_type: EventType = Field(default=EventType.ORDER_SUBMITTED) + client_order_id: str | None = Field(default=None, description="Client-assigned order ID") + + +class OrderAccepted(OrderEvent): + """Order has been accepted by the exchange""" + + event_type: EventType = Field(default=EventType.ORDER_ACCEPTED) + order_id: str = Field(description="Exchange-assigned order ID") + accepted_at: Uint64 = Field(description="Exchange acceptance timestamp") + + +class OrderRejected(OrderEvent): + """Order was rejected by the exchange""" + + event_type: EventType = Field(default=EventType.ORDER_REJECTED) + reason: str = Field(description="Rejection reason") + error_code: str | None = Field(default=None, description="Exchange error code") + + +class OrderPartiallyFilled(OrderEvent): + """Order was partially filled""" + + event_type: EventType = Field(default=EventType.ORDER_PARTIALLY_FILLED) + order_id: str = Field(description="Exchange order ID") + fill_price: Float = Field(description="Fill price for this execution") + fill_quantity: Float = Field(description="Quantity filled in this execution") + total_filled: Float = Field(description="Total quantity filled so far") + remaining_quantity: Float = Field(description="Remaining quantity to fill") + commission: Float = Field(default=0.0, description="Commission/fee for this fill") + commission_asset: str | None = Field(default=None, description="Asset used for commission") + trade_id: str | None = Field(default=None, description="Exchange trade ID") + + +class OrderFilled(OrderEvent): + """Order was completely filled""" + + event_type: EventType = Field(default=EventType.ORDER_FILLED) + order_id: str = Field(description="Exchange order ID") + average_fill_price: Float = Field(description="Average execution price") + total_quantity: Float = Field(description="Total quantity filled") + 
total_commission: Float = Field(default=0.0, description="Total commission/fees") + commission_asset: str | None = Field(default=None, description="Asset used for commission") + completed_at: Uint64 = Field(description="Completion timestamp") + + +class OrderCanceled(OrderEvent): + """Order was canceled""" + + event_type: EventType = Field(default=EventType.ORDER_CANCELED) + order_id: str = Field(description="Exchange order ID") + reason: str = Field(description="Cancellation reason") + filled_quantity: Float = Field(default=0.0, description="Quantity filled before cancellation") + canceled_at: Uint64 = Field(description="Cancellation timestamp") + + +class OrderModified(OrderEvent): + """Order was modified (price, quantity, etc.)""" + + event_type: EventType = Field(default=EventType.ORDER_MODIFIED) + order_id: str = Field(description="Exchange order ID") + old_price: Float | None = Field(default=None, description="Previous price") + new_price: Float | None = Field(default=None, description="New price") + old_quantity: Float | None = Field(default=None, description="Previous quantity") + new_quantity: Float | None = Field(default=None, description="New quantity") + modified_at: Uint64 = Field(description="Modification timestamp") + + +class OrderExpired(OrderEvent): + """Order expired (GTD, DAY orders)""" + + event_type: EventType = Field(default=EventType.ORDER_EXPIRED) + order_id: str = Field(description="Exchange order ID") + filled_quantity: Float = Field(default=0.0, description="Quantity filled before expiration") + expired_at: Uint64 = Field(description="Expiration timestamp") + + +# --------------------------------------------------------------------------- +# Position Events +# --------------------------------------------------------------------------- + +class PositionEvent(BaseEvent): + """Base class for position-related events""" + + position_id: str = Field(description="Position identifier") + symbol_id: str = Field(description="Symbol identifier") + 
asset_id: str = Field(description="Asset identifier") + + +class PositionOpened(PositionEvent): + """New position was opened""" + + event_type: EventType = Field(default=EventType.POSITION_OPENED) + quantity: Float = Field(description="Position quantity") + entry_price: Float = Field(description="Entry price") + side: str = Field(description="LONG or SHORT") + leverage: Float | None = Field(default=None, description="Leverage") + + +class PositionModified(PositionEvent): + """Existing position was modified (size change, etc.)""" + + event_type: EventType = Field(default=EventType.POSITION_MODIFIED) + old_quantity: Float = Field(description="Previous quantity") + new_quantity: Float = Field(description="New quantity") + average_entry_price: Float = Field(description="Updated average entry price") + unrealized_pnl: Float | None = Field(default=None, description="Current unrealized P&L") + + +class PositionClosed(PositionEvent): + """Position was closed""" + + event_type: EventType = Field(default=EventType.POSITION_CLOSED) + exit_price: Float = Field(description="Exit price") + realized_pnl: Float = Field(description="Realized profit/loss") + closed_at: Uint64 = Field(description="Closure timestamp") + + +# --------------------------------------------------------------------------- +# Account Events +# --------------------------------------------------------------------------- + +class AccountEvent(BaseEvent): + """Base class for account-related events""" + + account_id: str = Field(description="Account identifier") + + +class AccountBalanceUpdated(AccountEvent): + """Account balance was updated""" + + event_type: EventType = Field(default=EventType.ACCOUNT_BALANCE_UPDATED) + asset_id: str = Field(description="Asset that changed") + old_balance: Float = Field(description="Previous balance") + new_balance: Float = Field(description="New balance") + old_available: Float = Field(description="Previous available") + new_available: Float = Field(description="New 
available") + change_reason: str = Field(description="Why balance changed (TRADE, DEPOSIT, WITHDRAWAL, etc.)") + + +class MarginCallWarning(AccountEvent): + """Margin level is approaching liquidation threshold""" + + event_type: EventType = Field(default=EventType.MARGIN_CALL_WARNING) + margin_level: Float = Field(description="Current margin level") + liquidation_threshold: Float = Field(description="Liquidation threshold") + required_action: str = Field(description="Required action to avoid liquidation") + estimated_liquidation_price: Float | None = Field( + default=None, + description="Estimated liquidation price for positions" + ) + + +# --------------------------------------------------------------------------- +# System Events +# --------------------------------------------------------------------------- + +class ReconciliationFailed(BaseEvent): + """Failed to reconcile intent with actual state""" + + event_type: EventType = Field(default=EventType.RECONCILIATION_FAILED) + intent_id: str = Field(description="Order intent ID") + error_message: str = Field(description="Error details") + retry_count: int = Field(description="Number of retry attempts") + + +class ConnectionLost(BaseEvent): + """Connection to exchange was lost""" + + event_type: EventType = Field(default=EventType.CONNECTION_LOST) + reason: str = Field(description="Disconnection reason") + + +class ConnectionRestored(BaseEvent): + """Connection to exchange was restored""" + + event_type: EventType = Field(default=EventType.CONNECTION_RESTORED) + downtime_duration: int = Field(description="Duration of downtime in milliseconds") diff --git a/backend/src/exchange_kernel/models.py b/backend/src/exchange_kernel/models.py new file mode 100644 index 0000000..c4a4c30 --- /dev/null +++ b/backend/src/exchange_kernel/models.py @@ -0,0 +1,194 @@ +""" +Data models for the Exchange Kernel. + +Defines order intents, order state, positions, assets, and account state. 
+""" + +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, Field + +from ..schema.order_spec import ( + StandardOrder, + StandardOrderStatus, + AssetType, + Float, + Uint64, +) + + +# --------------------------------------------------------------------------- +# Order Intent and State +# --------------------------------------------------------------------------- + +class OrderIntent(BaseModel): + """ + Desired order state from the strategy kernel. + + This represents what the strategy wants, not what currently exists. + The exchange kernel will work to reconcile actual state with this intent. + """ + + model_config = {"extra": "forbid"} + + intent_id: str = Field(description="Unique identifier for this intent (client-assigned)") + order: StandardOrder = Field(description="The desired order specification") + group_id: str | None = Field(default=None, description="Group ID for OCO relationships") + created_at: Uint64 = Field(description="When this intent was created") + updated_at: Uint64 = Field(description="When this intent was last modified") + metadata: dict[str, Any] = Field(default_factory=dict, description="Strategy-specific metadata") + + +class ReconciliationStatus(StrEnum): + """Status of reconciliation between intent and actual state""" + PENDING = "PENDING" # Not yet submitted to exchange + SUBMITTING = "SUBMITTING" # Currently being submitted + ACTIVE = "ACTIVE" # Successfully placed on exchange + RECONCILING = "RECONCILING" # Intent changed, updating exchange order + FAILED = "FAILED" # Failed to submit or reconcile + COMPLETED = "COMPLETED" # Order fully filled + CANCELED = "CANCELED" # Order canceled + + +class OrderState(BaseModel): + """ + Actual current state of an order on the exchange. + + This represents reality - what the exchange reports about the order. + May differ from OrderIntent during reconciliation. 
+ """ + + model_config = {"extra": "forbid"} + + intent_id: str = Field(description="Links back to the OrderIntent") + exchange_order_id: str = Field(description="Exchange-assigned order ID") + status: StandardOrderStatus = Field(description="Current order status from exchange") + reconciliation_status: ReconciliationStatus = Field(description="Reconciliation state") + last_sync_at: Uint64 = Field(description="Last time we synced with exchange") + error_message: str | None = Field(default=None, description="Error details if FAILED") + + +# --------------------------------------------------------------------------- +# Position and Asset Models +# --------------------------------------------------------------------------- + +class AssetMetadata(BaseModel): + """ + Metadata describing an asset type. + + Provides context for positions, balances, and trading. + """ + + model_config = {"extra": "forbid"} + + asset_id: str = Field(description="Unique asset identifier") + symbol: str = Field(description="Asset symbol (e.g., 'BTC', 'ETH', 'USD')") + asset_type: AssetType = Field(description="Type of asset") + name: str = Field(description="Full name") + + # Contract specifications (for derivatives) + contract_size: Float | None = Field(default=None, description="Contract multiplier") + settlement_asset: str | None = Field(default=None, description="Settlement currency") + expiry_timestamp: Uint64 | None = Field(default=None, description="Expiration timestamp") + + # Trading parameters + tick_size: Float | None = Field(default=None, description="Minimum price increment") + lot_size: Float | None = Field(default=None, description="Minimum quantity increment") + + # Margin requirements (for leveraged products) + initial_margin_rate: Float | None = Field(default=None, description="Initial margin requirement") + maintenance_margin_rate: Float | None = Field(default=None, description="Maintenance margin requirement") + + # Additional metadata + metadata: dict[str, Any] = 
Field(default_factory=dict, description="Exchange-specific metadata") + + +class Asset(BaseModel): + """ + An asset holding (spot, margin, derivative position, etc.) + """ + + model_config = {"extra": "forbid"} + + asset_id: str = Field(description="References AssetMetadata") + quantity: Float = Field(description="Amount held (positive or negative for short positions)") + available: Float = Field(description="Amount available for trading (not locked in orders)") + locked: Float = Field(description="Amount locked in open orders") + + # For derivative positions + entry_price: Float | None = Field(default=None, description="Average entry price") + mark_price: Float | None = Field(default=None, description="Current mark price") + liquidation_price: Float | None = Field(default=None, description="Estimated liquidation price") + unrealized_pnl: Float | None = Field(default=None, description="Unrealized profit/loss") + realized_pnl: Float | None = Field(default=None, description="Realized profit/loss") + + # Margin info + margin_used: Float | None = Field(default=None, description="Margin allocated to this position") + + updated_at: Uint64 = Field(description="Last update timestamp") + + +class Position(BaseModel): + """ + A trading position (spot, margin, perpetual, futures, etc.) + + Tracks both the asset holdings and associated metadata. 
+ """ + + model_config = {"extra": "forbid"} + + position_id: str = Field(description="Unique position identifier") + symbol_id: str = Field(description="Trading symbol") + asset: Asset = Field(description="Asset holding details") + metadata: AssetMetadata = Field(description="Asset metadata") + + # Position-level info + leverage: Float | None = Field(default=None, description="Current leverage") + side: str | None = Field(default=None, description="LONG or SHORT (for derivatives)") + + updated_at: Uint64 = Field(description="Last update timestamp") + + +class Balance(BaseModel): + """Account balance for a single currency/asset""" + + model_config = {"extra": "forbid"} + + asset_id: str = Field(description="Asset identifier") + total: Float = Field(description="Total balance") + available: Float = Field(description="Available for trading") + locked: Float = Field(description="Locked in orders/positions") + + # For margin accounts + borrowed: Float = Field(default=0.0, description="Borrowed amount (margin)") + interest: Float = Field(default=0.0, description="Accrued interest") + + updated_at: Uint64 = Field(description="Last update timestamp") + + +class AccountState(BaseModel): + """ + Complete account state including balances, positions, and margin info. 
+ """ + + model_config = {"extra": "forbid"} + + account_id: str = Field(description="Account identifier") + exchange: str = Field(description="Exchange identifier") + + balances: list[Balance] = Field(default_factory=list, description="All asset balances") + positions: list[Position] = Field(default_factory=list, description="All open positions") + + # Margin account info + total_equity: Float | None = Field(default=None, description="Total account equity") + total_margin_used: Float | None = Field(default=None, description="Total margin in use") + total_available_margin: Float | None = Field(default=None, description="Available margin") + margin_level: Float | None = Field(default=None, description="Margin level (equity/margin_used)") + + # Risk metrics + total_unrealized_pnl: Float | None = Field(default=None, description="Total unrealized P&L") + total_realized_pnl: Float | None = Field(default=None, description="Total realized P&L") + + updated_at: Uint64 = Field(description="Last update timestamp") + metadata: dict[str, Any] = Field(default_factory=dict, description="Exchange-specific data") diff --git a/backend/src/exchange_kernel/state.py b/backend/src/exchange_kernel/state.py new file mode 100644 index 0000000..355e6bd --- /dev/null +++ b/backend/src/exchange_kernel/state.py @@ -0,0 +1,472 @@ +""" +State management for the Exchange Kernel. + +Implements the storage and reconciliation logic for desired vs actual state. +This is the "Kubernetes for orders" concept - maintaining intent and continuously +reconciling reality to match intent. 
+""" + +from abc import ABC, abstractmethod +from typing import Any +from collections import defaultdict + +from .models import OrderIntent, OrderState, ReconciliationStatus +from ..schema.order_spec import Uint64 + + +# --------------------------------------------------------------------------- +# Intent State Store - Desired State +# --------------------------------------------------------------------------- + +class IntentStateStore(ABC): + """ + Storage for order intents (desired state). + + This represents what the strategy kernel wants. Intents are durable and + persist across restarts. The reconciliation loop continuously works to + make actual state match these intents. + """ + + @abstractmethod + async def create_intent(self, intent: OrderIntent) -> None: + """ + Store a new order intent. + + Args: + intent: The order intent to store + + Raises: + AlreadyExistsError: If intent_id already exists + """ + pass + + @abstractmethod + async def get_intent(self, intent_id: str) -> OrderIntent: + """ + Retrieve an order intent. + + Args: + intent_id: Intent ID to retrieve + + Returns: + The order intent + + Raises: + NotFoundError: If intent_id doesn't exist + """ + pass + + @abstractmethod + async def update_intent(self, intent: OrderIntent) -> None: + """ + Update an existing order intent. + + Args: + intent: Updated intent (intent_id must match existing) + + Raises: + NotFoundError: If intent_id doesn't exist + """ + pass + + @abstractmethod + async def delete_intent(self, intent_id: str) -> None: + """ + Delete an order intent. + + Args: + intent_id: Intent ID to delete + + Raises: + NotFoundError: If intent_id doesn't exist + """ + pass + + @abstractmethod + async def list_intents( + self, + symbol_id: str | None = None, + group_id: str | None = None, + ) -> list[OrderIntent]: + """ + List all order intents, optionally filtered. 
+ + Args: + symbol_id: Filter by symbol + group_id: Filter by OCO group + + Returns: + List of matching intents + """ + pass + + @abstractmethod + async def get_intents_by_group(self, group_id: str) -> list[OrderIntent]: + """ + Get all intents in an OCO group. + + Args: + group_id: Group ID to query + + Returns: + List of intents in the group + """ + pass + + +# --------------------------------------------------------------------------- +# Actual State Store - Current Reality +# --------------------------------------------------------------------------- + +class ActualStateStore(ABC): + """ + Storage for actual order state (reality on exchange). + + This represents what actually exists on the exchange right now. + Updated frequently from exchange feeds and order status queries. + """ + + @abstractmethod + async def create_order_state(self, state: OrderState) -> None: + """ + Store a new order state. + + Args: + state: The order state to store + + Raises: + AlreadyExistsError: If order state for this intent_id already exists + """ + pass + + @abstractmethod + async def get_order_state(self, intent_id: str) -> OrderState: + """ + Retrieve order state for an intent. + + Args: + intent_id: Intent ID to query + + Returns: + The current order state + + Raises: + NotFoundError: If no state exists for this intent + """ + pass + + @abstractmethod + async def get_order_state_by_exchange_id(self, exchange_order_id: str) -> OrderState: + """ + Retrieve order state by exchange order ID. + + Useful for processing exchange callbacks that only provide exchange_order_id. + + Args: + exchange_order_id: Exchange's order ID + + Returns: + The order state + + Raises: + NotFoundError: If no state exists for this exchange order ID + """ + pass + + @abstractmethod + async def update_order_state(self, state: OrderState) -> None: + """ + Update an existing order state. 
+ + Args: + state: Updated state (intent_id must match existing) + + Raises: + NotFoundError: If state doesn't exist + """ + pass + + @abstractmethod + async def delete_order_state(self, intent_id: str) -> None: + """ + Delete an order state. + + Args: + intent_id: Intent ID whose state to delete + + Raises: + NotFoundError: If state doesn't exist + """ + pass + + @abstractmethod + async def list_order_states( + self, + symbol_id: str | None = None, + reconciliation_status: ReconciliationStatus | None = None, + ) -> list[OrderState]: + """ + List all order states, optionally filtered. + + Args: + symbol_id: Filter by symbol + reconciliation_status: Filter by reconciliation status + + Returns: + List of matching order states + """ + pass + + @abstractmethod + async def get_stale_orders(self, max_age_seconds: int) -> list[OrderState]: + """ + Find orders that haven't been synced recently. + + Used to identify orders that need status updates from exchange. + + Args: + max_age_seconds: Maximum age since last sync + + Returns: + List of order states that need refresh + """ + pass + + +# --------------------------------------------------------------------------- +# In-Memory Implementations (for testing/prototyping) +# --------------------------------------------------------------------------- + +class InMemoryIntentStore(IntentStateStore): + """Simple in-memory implementation of IntentStateStore""" + + def __init__(self): + self._intents: dict[str, OrderIntent] = {} + self._by_symbol: dict[str, set[str]] = defaultdict(set) + self._by_group: dict[str, set[str]] = defaultdict(set) + + async def create_intent(self, intent: OrderIntent) -> None: + if intent.intent_id in self._intents: + raise ValueError(f"Intent {intent.intent_id} already exists") + self._intents[intent.intent_id] = intent + self._by_symbol[intent.order.symbol_id].add(intent.intent_id) + if intent.group_id: + self._by_group[intent.group_id].add(intent.intent_id) + + async def get_intent(self, intent_id: 
str) -> OrderIntent: + if intent_id not in self._intents: + raise KeyError(f"Intent {intent_id} not found") + return self._intents[intent_id] + + async def update_intent(self, intent: OrderIntent) -> None: + if intent.intent_id not in self._intents: + raise KeyError(f"Intent {intent.intent_id} not found") + old_intent = self._intents[intent.intent_id] + + # Update indices if symbol or group changed + if old_intent.order.symbol_id != intent.order.symbol_id: + self._by_symbol[old_intent.order.symbol_id].discard(intent.intent_id) + self._by_symbol[intent.order.symbol_id].add(intent.intent_id) + + if old_intent.group_id != intent.group_id: + if old_intent.group_id: + self._by_group[old_intent.group_id].discard(intent.intent_id) + if intent.group_id: + self._by_group[intent.group_id].add(intent.intent_id) + + self._intents[intent.intent_id] = intent + + async def delete_intent(self, intent_id: str) -> None: + if intent_id not in self._intents: + raise KeyError(f"Intent {intent_id} not found") + intent = self._intents[intent_id] + self._by_symbol[intent.order.symbol_id].discard(intent_id) + if intent.group_id: + self._by_group[intent.group_id].discard(intent_id) + del self._intents[intent_id] + + async def list_intents( + self, + symbol_id: str | None = None, + group_id: str | None = None, + ) -> list[OrderIntent]: + if symbol_id and group_id: + # Intersection of both filters + symbol_ids = self._by_symbol.get(symbol_id, set()) + group_ids = self._by_group.get(group_id, set()) + intent_ids = symbol_ids & group_ids + elif symbol_id: + intent_ids = self._by_symbol.get(symbol_id, set()) + elif group_id: + intent_ids = self._by_group.get(group_id, set()) + else: + intent_ids = self._intents.keys() + + return [self._intents[iid] for iid in intent_ids] + + async def get_intents_by_group(self, group_id: str) -> list[OrderIntent]: + intent_ids = self._by_group.get(group_id, set()) + return [self._intents[iid] for iid in intent_ids] + + +class 
InMemoryActualStateStore(ActualStateStore): + """Simple in-memory implementation of ActualStateStore""" + + def __init__(self): + self._states: dict[str, OrderState] = {} + self._by_exchange_id: dict[str, str] = {} # exchange_order_id -> intent_id + self._by_symbol: dict[str, set[str]] = defaultdict(set) + + async def create_order_state(self, state: OrderState) -> None: + if state.intent_id in self._states: + raise ValueError(f"Order state for intent {state.intent_id} already exists") + self._states[state.intent_id] = state + self._by_exchange_id[state.exchange_order_id] = state.intent_id + self._by_symbol[state.status.order.symbol_id].add(state.intent_id) + + async def get_order_state(self, intent_id: str) -> OrderState: + if intent_id not in self._states: + raise KeyError(f"Order state for intent {intent_id} not found") + return self._states[intent_id] + + async def get_order_state_by_exchange_id(self, exchange_order_id: str) -> OrderState: + if exchange_order_id not in self._by_exchange_id: + raise KeyError(f"Order state for exchange order {exchange_order_id} not found") + intent_id = self._by_exchange_id[exchange_order_id] + return self._states[intent_id] + + async def update_order_state(self, state: OrderState) -> None: + if state.intent_id not in self._states: + raise KeyError(f"Order state for intent {state.intent_id} not found") + old_state = self._states[state.intent_id] + + # Update exchange_id index if it changed + if old_state.exchange_order_id != state.exchange_order_id: + del self._by_exchange_id[old_state.exchange_order_id] + self._by_exchange_id[state.exchange_order_id] = state.intent_id + + # Update symbol index if it changed + old_symbol = old_state.status.order.symbol_id + new_symbol = state.status.order.symbol_id + if old_symbol != new_symbol: + self._by_symbol[old_symbol].discard(state.intent_id) + self._by_symbol[new_symbol].add(state.intent_id) + + self._states[state.intent_id] = state + + async def delete_order_state(self, intent_id: str) -> 
None: + if intent_id not in self._states: + raise KeyError(f"Order state for intent {intent_id} not found") + state = self._states[intent_id] + del self._by_exchange_id[state.exchange_order_id] + self._by_symbol[state.status.order.symbol_id].discard(intent_id) + del self._states[intent_id] + + async def list_order_states( + self, + symbol_id: str | None = None, + reconciliation_status: ReconciliationStatus | None = None, + ) -> list[OrderState]: + if symbol_id: + intent_ids = self._by_symbol.get(symbol_id, set()) + states = [self._states[iid] for iid in intent_ids] + else: + states = list(self._states.values()) + + if reconciliation_status: + states = [s for s in states if s.reconciliation_status == reconciliation_status] + + return states + + async def get_stale_orders(self, max_age_seconds: int) -> list[OrderState]: + import time + current_time = int(time.time()) + threshold = current_time - max_age_seconds + + return [ + state + for state in self._states.values() + if state.last_sync_at < threshold + ] + + +# --------------------------------------------------------------------------- +# Reconciliation Engine (framework only, no implementation) +# --------------------------------------------------------------------------- + +class ReconciliationEngine: + """ + Reconciliation engine that continuously works to make actual state match intent. + + This is the heart of the "Kubernetes for orders" concept. It: + 1. Compares desired state (intents) with actual state (exchange orders) + 2. Computes necessary actions (place, modify, cancel) + 3. Executes those actions via the exchange API + 4. Handles retries, errors, and edge cases + + This is a framework class - concrete implementations will be exchange-specific. + """ + + def __init__( + self, + intent_store: IntentStateStore, + actual_store: ActualStateStore, + ): + """ + Initialize the reconciliation engine. 
+ + Args: + intent_store: Store for desired state + actual_store: Store for actual state + """ + self.intent_store = intent_store + self.actual_store = actual_store + self._running = False + + async def start(self) -> None: + """Start the reconciliation loop""" + self._running = True + # Implementation would start async reconciliation loop here + pass + + async def stop(self) -> None: + """Stop the reconciliation loop""" + self._running = False + # Implementation would stop reconciliation loop here + pass + + async def reconcile_intent(self, intent_id: str) -> None: + """ + Reconcile a specific intent. + + Compares the intent with actual state and takes necessary actions. + + Args: + intent_id: Intent to reconcile + """ + # Framework only - concrete implementation needed + pass + + async def reconcile_all(self) -> None: + """ + Reconcile all intents. + + Full reconciliation pass over all orders. + """ + # Framework only - concrete implementation needed + pass + + def get_metrics(self) -> dict[str, Any]: + """ + Get reconciliation metrics. + + Returns: + Metrics about reconciliation performance, errors, etc. 
+ """ + return { + "running": self._running, + "reconciliation_lag_ms": 0, # Framework only + "pending_reconciliations": 0, # Framework only + "error_count": 0, # Framework only + "retry_count": 0, # Framework only + } diff --git a/backend/src/gateway/hub.py b/backend/src/gateway/hub.py index 92152af..e3db48d 100644 --- a/backend/src/gateway/hub.py +++ b/backend/src/gateway/hub.py @@ -94,6 +94,11 @@ class Gateway: logger.info(f"Session is busy, interrupting existing task") await session.interrupt() + # Check if this is a stop interrupt (empty message) + if not message.content.strip() and not message.attachments: + logger.info("Received stop interrupt (empty message), not starting new agent round") + return + # Add user message to history session.add_message("user", message.content, message.channel_id) logger.info(f"User message added to history, history length: {len(session.get_history())}") @@ -134,33 +139,55 @@ class Gateway: # Stream chunks back to active channels full_response = "" chunk_count = 0 - async for chunk in response_stream: - chunk_count += 1 - full_response += chunk - logger.debug(f"Received chunk #{chunk_count}, length: {len(chunk)}") + accumulated_metadata = {} - # Send chunk to all active channels - agent_msg = AgentMessage( - session_id=session.session_id, - target_channels=session.active_channels, - content=chunk, - stream_chunk=True, - done=False - ) - await self._send_to_channels(agent_msg) + async for chunk in response_stream: + # Handle dict response with metadata (from agent executor) + if isinstance(chunk, dict): + content = chunk.get("content", "") + metadata = chunk.get("metadata", {}) + # Accumulate metadata (e.g., plot_urls) + for key, value in metadata.items(): + if key == "plot_urls" and value: + # Append to existing plot_urls + if "plot_urls" not in accumulated_metadata: + accumulated_metadata["plot_urls"] = [] + accumulated_metadata["plot_urls"].extend(value) + logger.info(f"Accumulated plot_urls: 
{accumulated_metadata['plot_urls']}") + else: + accumulated_metadata[key] = value + chunk = content + + # Only send non-empty chunks + if chunk: + chunk_count += 1 + full_response += chunk + logger.debug(f"Received chunk #{chunk_count}, length: {len(chunk)}") + + # Send chunk to all active channels with accumulated metadata + agent_msg = AgentMessage( + session_id=session.session_id, + target_channels=session.active_channels, + content=chunk, + stream_chunk=True, + done=False, + metadata=accumulated_metadata.copy() + ) + await self._send_to_channels(agent_msg) logger.info(f"Agent streaming completed, total chunks: {chunk_count}, response length: {len(full_response)}") - # Send final done message + # Send final done message with all accumulated metadata agent_msg = AgentMessage( session_id=session.session_id, target_channels=session.active_channels, content="", stream_chunk=True, - done=True + done=True, + metadata=accumulated_metadata ) await self._send_to_channels(agent_msg) - logger.info("Sent final done message to channels") + logger.info(f"Sent final done message to channels with metadata: {accumulated_metadata}") # Add to history session.add_message("assistant", full_response) diff --git a/backend/src/indicator/__init__.py b/backend/src/indicator/__init__.py new file mode 100644 index 0000000..57a5fd3 --- /dev/null +++ b/backend/src/indicator/__init__.py @@ -0,0 +1,172 @@ +""" +Composable Indicator System. + +Provides a framework for building DAGs of data transformation pipelines +that process time-series data incrementally. Indicators can consume +DataSources or other Indicators as inputs, composing into arbitrarily +complex processing graphs. + +Key Components: +--------------- + +Indicator (base.py): + Abstract base class for all indicator implementations. + Declares input/output schemas and implements synchronous compute(). 
+ +IndicatorRegistry (registry.py): + Central catalog of available indicators with rich metadata + for AI agent discovery and tool generation. + +Pipeline (pipeline.py): + Execution engine that builds DAGs, resolves dependencies, + and orchestrates incremental data flow through indicator chains. + +Schema Types (schema.py): + Type definitions for input/output schemas, computation context, + and metadata for AI-native documentation. + +Usage Example: +-------------- + + from indicator import Indicator, IndicatorRegistry, Pipeline + from indicator.schema import ( + InputSchema, OutputSchema, ComputeContext, ComputeResult, + IndicatorMetadata, IndicatorParameter + ) + + # Define an indicator + class SimpleMovingAverage(Indicator): + @classmethod + def get_metadata(cls): + return IndicatorMetadata( + name="SMA", + display_name="Simple Moving Average", + description="Arithmetic mean of prices over N periods", + category="trend", + parameters=[ + IndicatorParameter( + name="period", + type="int", + description="Number of periods to average", + default=20, + min_value=1 + ) + ], + tags=["moving-average", "trend-following"] + ) + + @classmethod + def get_input_schema(cls): + return InputSchema( + required_columns=[ + ColumnInfo(name="close", type="float", description="Closing price") + ] + ) + + @classmethod + def get_output_schema(cls, **params): + return OutputSchema( + columns=[ + ColumnInfo( + name="sma", + type="float", + description=f"Simple moving average over {params.get('period', 20)} periods" + ) + ] + ) + + def compute(self, context: ComputeContext) -> ComputeResult: + period = self.params["period"] + closes = context.get_column("close") + times = context.get_times() + + sma_values = [] + for i in range(len(closes)): + if i < period - 1: + sma_values.append(None) + else: + window = closes[i - period + 1 : i + 1] + sma_values.append(sum(window) / period) + + return ComputeResult( + data=[ + {"time": times[i], "sma": sma_values[i]} + for i in range(len(times)) + ] 
+ ) + + # Register the indicator + registry = IndicatorRegistry() + registry.register(SimpleMovingAverage) + + # Create a pipeline + pipeline = Pipeline(datasource_registry) + pipeline.add_datasource("price_data", "ccxt", "BTC/USD", "1D") + + sma_indicator = registry.create_instance("SMA", "sma_20", period=20) + pipeline.add_indicator("sma_20", sma_indicator, input_node_ids=["price_data"]) + + # Execute + results = pipeline.execute(datasource_data={"price_data": price_bars}) + sma_output = results["sma_20"] # Contains columns: time, close, sma_20_sma + +Design Philosophy: +------------------ + +1. **Schema-based composition**: Indicators declare inputs/outputs via schemas, + enabling automatic validation and flexible composition. + +2. **Synchronous execution**: All computation is synchronous for simplicity. + Async handling happens at the event/strategy layer. + +3. **Incremental updates**: Indicators receive context about what changed, + allowing optimized recomputation of only affected values. + +4. **AI-native metadata**: Rich descriptions, use cases, and parameter specs + make indicators discoverable and usable by AI agents. + +5. **Generic data flow**: Indicators work with any data source that matches + their input schema, not specific DataSource instances. + +6. **Event-driven**: Designed to react to DataSource updates and propagate + changes through the DAG efficiently. 
+""" + +from .base import DataSourceAdapter, Indicator +from .pipeline import Pipeline, PipelineNode +from .registry import IndicatorRegistry +from .schema import ( + ComputeContext, + ComputeResult, + IndicatorMetadata, + IndicatorParameter, + InputSchema, + OutputSchema, +) +from .talib_adapter import ( + TALibIndicator, + register_all_talib_indicators, + is_talib_available, + get_talib_version, +) + +__all__ = [ + # Core classes + "Indicator", + "IndicatorRegistry", + "Pipeline", + "PipelineNode", + "DataSourceAdapter", + # Schema types + "InputSchema", + "OutputSchema", + "ComputeContext", + "ComputeResult", + "IndicatorMetadata", + "IndicatorParameter", + # TA-Lib integration + "TALibIndicator", + "register_all_talib_indicators", + "is_talib_available", + "get_talib_version", +] diff --git a/backend/src/indicator/base.py b/backend/src/indicator/base.py new file mode 100644 index 0000000..b7bed28 --- /dev/null +++ b/backend/src/indicator/base.py @@ -0,0 +1,230 @@ +""" +Abstract Indicator interface. + +Provides the base class for all technical indicators and derived data transformations. +Indicators compose into DAGs, processing data incrementally as updates arrive. +""" + +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + +from .schema import ( + ComputeContext, + ComputeResult, + IndicatorMetadata, + InputSchema, + OutputSchema, +) + + +class Indicator(ABC): + """ + Abstract base class for all indicators. + + Indicators are composable transformation nodes that: + - Declare input schema (columns they need) + - Declare output schema (columns they produce) + - Compute outputs synchronously from inputs + - Support incremental updates (process only what changed) + - Provide rich metadata for AI agent discovery + + Indicators are stateless at the instance level - all state is managed + by the pipeline execution engine. This allows the same indicator class + to be reused with different parameters. 
+ """ + + def __init__(self, instance_name: str, **params): + """ + Initialize an indicator instance. + + Args: + instance_name: Unique name for this instance (used for output column prefixing) + **params: Configuration parameters (validated against metadata.parameters) + """ + self.instance_name = instance_name + self.params = params + self._validate_params() + + @classmethod + @abstractmethod + def get_metadata(cls) -> IndicatorMetadata: + """ + Get metadata for this indicator class. + + Called by the registry for AI agent discovery and documentation. + Should return comprehensive information about the indicator's purpose, + parameters, and use cases. + + Returns: + IndicatorMetadata describing this indicator class + """ + pass + + @classmethod + @abstractmethod + def get_input_schema(cls) -> InputSchema: + """ + Get the input schema required by this indicator. + + Declares what columns must be present in the input data. + The pipeline will match this against available data sources. + + Returns: + InputSchema describing required and optional input columns + """ + pass + + @classmethod + @abstractmethod + def get_output_schema(cls, **params) -> OutputSchema: + """ + Get the output schema produced by this indicator. + + Output column names will be automatically prefixed with the instance name + by the pipeline engine. + + Args: + **params: Configuration parameters (may affect output schema) + + Returns: + OutputSchema describing the columns this indicator produces + """ + pass + + @abstractmethod + def compute(self, context: ComputeContext) -> ComputeResult: + """ + Compute indicator values from input data. + + This method is called synchronously by the pipeline engine whenever + input data changes. Implementations should: + + 1. Extract needed columns from context.data + 2. Perform calculations + 3. 
Return results with proper time alignment + + For incremental updates (context.is_incremental == True): + - context.data contains only new/updated rows + - Implementations MAY optimize by computing only these rows + - OR implementations MAY recompute everything (simpler but slower) + + Args: + context: Input data and update metadata + + Returns: + ComputeResult with calculated indicator values + + Raises: + ValueError: If input data doesn't match expected schema + """ + pass + + def _validate_params(self) -> None: + """ + Validate that provided parameters match the metadata specification. + + Raises: + ValueError: If required parameters are missing or invalid + """ + metadata = self.get_metadata() + + # Check for required parameters + for param_def in metadata.parameters: + if param_def.required and param_def.name not in self.params: + raise ValueError( + f"Indicator '{metadata.name}' requires parameter '{param_def.name}'" + ) + + # Validate parameter types and ranges + for name, value in self.params.items(): + # Find parameter definition + param_def = next( + (p for p in metadata.parameters if p.name == name), + None + ) + + if param_def is None: + raise ValueError( + f"Unknown parameter '{name}' for indicator '{metadata.name}'" + ) + + # Type checking + if param_def.type == "int" and not isinstance(value, int): + raise ValueError( + f"Parameter '{name}' must be int, got {type(value).__name__}" + ) + elif param_def.type == "float" and not isinstance(value, (int, float)): + raise ValueError( + f"Parameter '{name}' must be float, got {type(value).__name__}" + ) + elif param_def.type == "bool" and not isinstance(value, bool): + raise ValueError( + f"Parameter '{name}' must be bool, got {type(value).__name__}" + ) + elif param_def.type == "string" and not isinstance(value, str): + raise ValueError( + f"Parameter '{name}' must be string, got {type(value).__name__}" + ) + + # Range checking for numeric types + if param_def.type in ("int", "float"): + if 
param_def.min_value is not None and value < param_def.min_value: + raise ValueError( + f"Parameter '{name}' must be >= {param_def.min_value}, got {value}" + ) + if param_def.max_value is not None and value > param_def.max_value: + raise ValueError( + f"Parameter '{name}' must be <= {param_def.max_value}, got {value}" + ) + + def get_output_columns(self) -> List[str]: + """ + Get the output column names with instance name prefix. + + Returns: + List of prefixed output column names + """ + output_schema = self.get_output_schema(**self.params) + prefixed = output_schema.with_prefix(self.instance_name) + return [col.name for col in prefixed.columns if col.name != output_schema.time_column] + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(instance_name='{self.instance_name}', params={self.params})" + + +class DataSourceAdapter: + """ + Adapter to make a DataSource look like an Indicator for pipeline composition. + + This allows DataSources to be inputs to indicators in a unified way. + """ + + def __init__(self, datasource_id: str, symbol: str, resolution: str): + """ + Create a DataSource adapter. + + Args: + datasource_id: Identifier for the datasource (e.g., 'ccxt', 'demo') + symbol: Symbol to query (e.g., 'BTC/USD') + resolution: Time resolution (e.g., '1', '5', '1D') + """ + self.datasource_id = datasource_id + self.symbol = symbol + self.resolution = resolution + self.instance_name = f"ds_{datasource_id}_{symbol}_{resolution}".replace("/", "_").replace(":", "_") + + def get_output_columns(self) -> List[str]: + """ + Get the columns provided by this datasource. + + Note: This requires runtime resolution - the pipeline engine + will need to query the actual DataSource to get the schema. 
+ + Returns: + List of column names (placeholder - needs runtime resolution) + """ + # This will be resolved at runtime by the pipeline engine + return [] + + def __repr__(self) -> str: + return f"DataSourceAdapter(datasource='{self.datasource_id}', symbol='{self.symbol}', resolution='{self.resolution}')" diff --git a/backend/src/indicator/pipeline.py b/backend/src/indicator/pipeline.py new file mode 100644 index 0000000..3f6adcd --- /dev/null +++ b/backend/src/indicator/pipeline.py @@ -0,0 +1,439 @@ +""" +Pipeline execution engine for composable indicators. + +Manages DAG construction, dependency resolution, incremental updates, +and efficient data flow through indicator chains. +""" + +import logging +from collections import defaultdict, deque +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +from datasource.base import DataSource +from datasource.schema import ColumnInfo + +from .base import DataSourceAdapter, Indicator +from .schema import ComputeContext, ComputeResult + +logger = logging.getLogger(__name__) + + +class PipelineNode: + """ + A node in the pipeline DAG. + + Can be either a DataSource adapter or an Indicator instance. + """ + + def __init__( + self, + node_id: str, + node: Union[DataSourceAdapter, Indicator], + dependencies: List[str] + ): + """ + Create a pipeline node. 
+ + Args: + node_id: Unique identifier for this node + node: The DataSourceAdapter or Indicator instance + dependencies: List of node_ids this node depends on + """ + self.node_id = node_id + self.node = node + self.dependencies = dependencies + self.output_columns: List[str] = [] + self.cached_data: List[Dict[str, Any]] = [] + + def is_datasource(self) -> bool: + """Check if this node is a DataSource adapter.""" + return isinstance(self.node, DataSourceAdapter) + + def is_indicator(self) -> bool: + """Check if this node is an Indicator.""" + return isinstance(self.node, Indicator) + + def __repr__(self) -> str: + return f"PipelineNode(id='{self.node_id}', node={self.node}, deps={self.dependencies})" + + +class Pipeline: + """ + Execution engine for indicator DAGs. + + Manages: + - DAG construction and validation + - Topological sorting for execution order + - Data flow and caching + - Incremental updates (only recompute what changed) + - Schema validation + """ + + def __init__(self, datasource_registry): + """ + Initialize a pipeline. + + Args: + datasource_registry: DataSourceRegistry for resolving data sources + """ + self.datasource_registry = datasource_registry + self.nodes: Dict[str, PipelineNode] = {} + self.execution_order: List[str] = [] + self._dirty_nodes: Set[str] = set() + + def add_datasource( + self, + node_id: str, + datasource_name: str, + symbol: str, + resolution: str + ) -> None: + """ + Add a DataSource to the pipeline. 
+ + Args: + node_id: Unique identifier for this node + datasource_name: Name of the datasource in the registry + symbol: Symbol to query + resolution: Time resolution + + Raises: + ValueError: If node_id already exists or datasource not found + """ + if node_id in self.nodes: + raise ValueError(f"Node '{node_id}' already exists in pipeline") + + datasource = self.datasource_registry.get(datasource_name) + if not datasource: + raise ValueError(f"DataSource '{datasource_name}' not found in registry") + + adapter = DataSourceAdapter(datasource_name, symbol, resolution) + node = PipelineNode(node_id, adapter, dependencies=[]) + + self.nodes[node_id] = node + self._invalidate_execution_order() + + logger.info(f"Added DataSource node '{node_id}': {datasource_name}/{symbol}@{resolution}") + + def add_indicator( + self, + node_id: str, + indicator: Indicator, + input_node_ids: List[str] + ) -> None: + """ + Add an Indicator to the pipeline. + + Args: + node_id: Unique identifier for this node + indicator: Indicator instance + input_node_ids: List of node IDs providing input data + + Raises: + ValueError: If node_id already exists, dependencies not found, or schema mismatch + """ + if node_id in self.nodes: + raise ValueError(f"Node '{node_id}' already exists in pipeline") + + # Validate dependencies exist + for dep_id in input_node_ids: + if dep_id not in self.nodes: + raise ValueError(f"Dependency node '{dep_id}' not found in pipeline") + + # TODO: Validate input schema matches available columns from dependencies + # This requires merging output schemas from all input nodes + + node = PipelineNode(node_id, indicator, dependencies=input_node_ids) + self.nodes[node_id] = node + self._invalidate_execution_order() + + logger.info(f"Added Indicator node '{node_id}': {indicator} with inputs {input_node_ids}") + + def remove_node(self, node_id: str) -> None: + """ + Remove a node from the pipeline. 
+ + Args: + node_id: Node to remove + + Raises: + ValueError: If other nodes depend on this node + """ + if node_id not in self.nodes: + return + + # Check for dependent nodes + dependents = [ + n.node_id for n in self.nodes.values() + if node_id in n.dependencies + ] + + if dependents: + raise ValueError( + f"Cannot remove node '{node_id}': nodes {dependents} depend on it" + ) + + del self.nodes[node_id] + self._invalidate_execution_order() + + logger.info(f"Removed node '{node_id}' from pipeline") + + def _invalidate_execution_order(self) -> None: + """Mark execution order as needing recomputation.""" + self.execution_order = [] + + def _compute_execution_order(self) -> List[str]: + """ + Compute topological sort of the DAG. + + Returns: + List of node IDs in execution order + + Raises: + ValueError: If DAG contains cycles + """ + if self.execution_order: + return self.execution_order + + # Kahn's algorithm for topological sort + in_degree = {node_id: 0 for node_id in self.nodes} + for node in self.nodes.values(): + for dep in node.dependencies: + in_degree[node.node_id] += 1 + + queue = deque([node_id for node_id, degree in in_degree.items() if degree == 0]) + result = [] + + while queue: + node_id = queue.popleft() + result.append(node_id) + + # Find all nodes that depend on this one + for other_node in self.nodes.values(): + if node_id in other_node.dependencies: + in_degree[other_node.node_id] -= 1 + if in_degree[other_node.node_id] == 0: + queue.append(other_node.node_id) + + if len(result) != len(self.nodes): + raise ValueError("Pipeline contains cycles") + + self.execution_order = result + logger.debug(f"Computed execution order: {result}") + return result + + def execute( + self, + datasource_data: Dict[str, List[Dict[str, Any]]], + incremental: bool = False, + updated_from_time: Optional[int] = None + ) -> Dict[str, List[Dict[str, Any]]]: + """ + Execute the pipeline. 
+ + Args: + datasource_data: Mapping of DataSource node_id to input data + incremental: Whether this is an incremental update + updated_from_time: Timestamp of earliest updated row (for incremental) + + Returns: + Dictionary mapping node_id to output data (all nodes) + + Raises: + ValueError: If required datasource data is missing + """ + execution_order = self._compute_execution_order() + results: Dict[str, List[Dict[str, Any]]] = {} + + logger.info( + f"Executing pipeline with {len(execution_order)} nodes " + f"(incremental={incremental})" + ) + + for node_id in execution_order: + node = self.nodes[node_id] + + if node.is_datasource(): + # DataSource node - get data from input + if node_id not in datasource_data: + raise ValueError( + f"DataSource node '{node_id}' has no input data" + ) + results[node_id] = datasource_data[node_id] + node.cached_data = results[node_id] + logger.debug(f"DataSource node '{node_id}': {len(results[node_id])} rows") + + elif node.is_indicator(): + # Indicator node - compute from dependencies + indicator = node.node + + # Merge input data from all dependencies + input_data = self._merge_dependency_data(node.dependencies, results) + + # Create compute context + context = ComputeContext( + data=input_data, + is_incremental=incremental, + updated_from_time=updated_from_time + ) + + # Execute indicator + logger.debug( + f"Computing indicator '{node_id}' with {len(input_data)} input rows" + ) + compute_result = indicator.compute(context) + + # Merge result with input data (adding prefixed columns) + output_data = compute_result.merge_with_prefix( + indicator.instance_name, + input_data + ) + + results[node_id] = output_data + node.cached_data = output_data + logger.debug(f"Indicator node '{node_id}': {len(output_data)} rows") + + logger.info(f"Pipeline execution complete: {len(results)} nodes processed") + return results + + def _merge_dependency_data( + self, + dependency_ids: List[str], + results: Dict[str, List[Dict[str, Any]]] + ) -> 
List[Dict[str, Any]]: + """ + Merge data from multiple dependency nodes. + + Data is merged by time, with later dependencies overwriting earlier ones + for conflicting column names. + + Args: + dependency_ids: List of node IDs to merge + results: Current execution results + + Returns: + Merged data rows + """ + if not dependency_ids: + return [] + + if len(dependency_ids) == 1: + return results[dependency_ids[0]] + + # Build time-indexed data from first dependency + merged: Dict[int, Dict[str, Any]] = {} + for row in results[dependency_ids[0]]: + merged[row["time"]] = row.copy() + + # Merge in additional dependencies + for dep_id in dependency_ids[1:]: + for row in results[dep_id]: + time_key = row["time"] + if time_key in merged: + # Merge columns (later dependencies win) + merged[time_key].update(row) + else: + # New timestamp + merged[time_key] = row.copy() + + # Sort by time and return + sorted_times = sorted(merged.keys()) + return [merged[t] for t in sorted_times] + + def get_node_output(self, node_id: str) -> Optional[List[Dict[str, Any]]]: + """ + Get cached output data for a specific node. + + Args: + node_id: Node identifier + + Returns: + Cached data or None if not available + """ + node = self.nodes.get(node_id) + return node.cached_data if node else None + + def get_output_schema(self, node_id: str) -> List[ColumnInfo]: + """ + Get the output schema for a specific node. 
+ + Args: + node_id: Node identifier + + Returns: + List of ColumnInfo describing output columns + + Raises: + ValueError: If node not found + """ + node = self.nodes.get(node_id) + if not node: + raise ValueError(f"Node '{node_id}' not found") + + if node.is_datasource(): + # Would need to query the actual datasource at runtime + # For now, return empty - this requires integration with DataSource + return [] + + elif node.is_indicator(): + indicator = node.node + output_schema = indicator.get_output_schema(**indicator.params) + prefixed_schema = output_schema.with_prefix(indicator.instance_name) + return prefixed_schema.columns + + return [] + + def validate_pipeline(self) -> Tuple[bool, Optional[str]]: + """ + Validate the entire pipeline for correctness. + + Checks: + - No cycles (already checked in execution order) + - All dependencies exist (already checked in add_indicator) + - Input schemas match output schemas (TODO) + + Returns: + Tuple of (is_valid, error_message) + """ + try: + self._compute_execution_order() + return True, None + except ValueError as e: + return False, str(e) + + def get_node_count(self) -> int: + """Get the number of nodes in the pipeline.""" + return len(self.nodes) + + def get_indicator_count(self) -> int: + """Get the number of indicator nodes in the pipeline.""" + return sum(1 for node in self.nodes.values() if node.is_indicator()) + + def get_datasource_count(self) -> int: + """Get the number of datasource nodes in the pipeline.""" + return sum(1 for node in self.nodes.values() if node.is_datasource()) + + def describe(self) -> Dict[str, Any]: + """ + Get a detailed description of the pipeline structure. 
+ + Returns: + Dictionary with pipeline metadata and structure + """ + return { + "node_count": self.get_node_count(), + "datasource_count": self.get_datasource_count(), + "indicator_count": self.get_indicator_count(), + "nodes": [ + { + "id": node.node_id, + "type": "datasource" if node.is_datasource() else "indicator", + "node": str(node.node), + "dependencies": node.dependencies, + "cached_rows": len(node.cached_data) + } + for node in self.nodes.values() + ], + "execution_order": self.execution_order or self._compute_execution_order(), + "is_valid": self.validate_pipeline()[0] + } diff --git a/backend/src/indicator/registry.py b/backend/src/indicator/registry.py new file mode 100644 index 0000000..74f8aa6 --- /dev/null +++ b/backend/src/indicator/registry.py @@ -0,0 +1,349 @@ +""" +Indicator registry for managing and discovering indicators. + +Provides AI agents with a queryable catalog of available indicators, +their capabilities, and metadata. +""" + +from typing import Dict, List, Optional, Type + +from .base import Indicator +from .schema import IndicatorMetadata, InputSchema, OutputSchema + + +class IndicatorRegistry: + """ + Central registry for indicator classes. + + Enables: + - Registration of indicator implementations + - Discovery by name, category, or tags + - Schema validation + - AI agent tool generation + """ + + def __init__(self): + self._indicators: Dict[str, Type[Indicator]] = {} + + def register(self, indicator_class: Type[Indicator]) -> None: + """ + Register an indicator class. + + Args: + indicator_class: Indicator class to register + + Raises: + ValueError: If an indicator with this name is already registered + """ + metadata = indicator_class.get_metadata() + + if metadata.name in self._indicators: + raise ValueError( + f"Indicator '{metadata.name}' is already registered" + ) + + self._indicators[metadata.name] = indicator_class + + def unregister(self, name: str) -> None: + """ + Unregister an indicator class. 
+ + Args: + name: Indicator class name + """ + self._indicators.pop(name, None) + + def get(self, name: str) -> Optional[Type[Indicator]]: + """ + Get an indicator class by name. + + Args: + name: Indicator class name + + Returns: + Indicator class or None if not found + """ + return self._indicators.get(name) + + def list_indicators(self) -> List[str]: + """ + Get names of all registered indicators. + + Returns: + List of indicator class names + """ + return list(self._indicators.keys()) + + def get_metadata(self, name: str) -> Optional[IndicatorMetadata]: + """ + Get metadata for a specific indicator. + + Args: + name: Indicator class name + + Returns: + IndicatorMetadata or None if not found + """ + indicator_class = self.get(name) + if indicator_class: + return indicator_class.get_metadata() + return None + + def get_all_metadata(self) -> List[IndicatorMetadata]: + """ + Get metadata for all registered indicators. + + Useful for AI agent tool generation and discovery. + + Returns: + List of IndicatorMetadata for all registered indicators + """ + return [cls.get_metadata() for cls in self._indicators.values()] + + def search_by_category(self, category: str) -> List[IndicatorMetadata]: + """ + Find indicators by category. + + Args: + category: Category name (e.g., 'momentum', 'trend', 'volatility') + + Returns: + List of matching indicator metadata + """ + results = [] + for indicator_class in self._indicators.values(): + metadata = indicator_class.get_metadata() + if metadata.category.lower() == category.lower(): + results.append(metadata) + return results + + def search_by_tag(self, tag: str) -> List[IndicatorMetadata]: + """ + Find indicators by tag. 
+ + Args: + tag: Tag to search for (case-insensitive) + + Returns: + List of matching indicator metadata + """ + tag_lower = tag.lower() + results = [] + for indicator_class in self._indicators.values(): + metadata = indicator_class.get_metadata() + if any(t.lower() == tag_lower for t in metadata.tags): + results.append(metadata) + return results + + def search_by_text(self, query: str) -> List[IndicatorMetadata]: + """ + Full-text search across indicator names, descriptions, and use cases. + + Args: + query: Search query (case-insensitive) + + Returns: + List of matching indicator metadata, ranked by relevance + """ + query_lower = query.lower() + results = [] + + for indicator_class in self._indicators.values(): + metadata = indicator_class.get_metadata() + score = 0 + + # Check name (highest weight) + if query_lower in metadata.name.lower(): + score += 10 + if query_lower in metadata.display_name.lower(): + score += 8 + + # Check description + if query_lower in metadata.description.lower(): + score += 5 + + # Check use cases + for use_case in metadata.use_cases: + if query_lower in use_case.lower(): + score += 3 + + # Check tags + for tag in metadata.tags: + if query_lower in tag.lower(): + score += 2 + + if score > 0: + results.append((score, metadata)) + + # Sort by score descending + results.sort(key=lambda x: x[0], reverse=True) + return [metadata for _, metadata in results] + + def find_compatible_indicators( + self, + available_columns: List[str], + column_types: Dict[str, str] + ) -> List[IndicatorMetadata]: + """ + Find indicators that can be computed from available columns. 
+ + Args: + available_columns: List of column names available + column_types: Mapping of column name to type + + Returns: + List of indicators whose input schema is satisfied + """ + from datasource.schema import ColumnInfo + + # Build ColumnInfo list from available data + available_schema = [ + ColumnInfo( + name=name, + type=column_types.get(name, "float"), + description=f"Column {name}" + ) + for name in available_columns + ] + + results = [] + for indicator_class in self._indicators.values(): + input_schema = indicator_class.get_input_schema() + if input_schema.matches(available_schema): + results.append(indicator_class.get_metadata()) + + return results + + def validate_indicator_chain( + self, + indicator_chain: List[tuple[str, Dict]] + ) -> tuple[bool, Optional[str]]: + """ + Validate that a chain of indicators can be connected. + + Args: + indicator_chain: List of (indicator_name, params) tuples in execution order + + Returns: + Tuple of (is_valid, error_message) + """ + if not indicator_chain: + return True, None + + # For now, just check that all indicators exist + # More sophisticated DAG validation happens in the pipeline engine + for indicator_name, params in indicator_chain: + if indicator_name not in self._indicators: + return False, f"Indicator '{indicator_name}' not found in registry" + + return True, None + + def get_input_schema(self, name: str) -> Optional[InputSchema]: + """ + Get input schema for a specific indicator. + + Args: + name: Indicator class name + + Returns: + InputSchema or None if not found + """ + indicator_class = self.get(name) + if indicator_class: + return indicator_class.get_input_schema() + return None + + def get_output_schema(self, name: str, **params) -> Optional[OutputSchema]: + """ + Get output schema for a specific indicator with given parameters. 
+ + Args: + name: Indicator class name + **params: Indicator parameters + + Returns: + OutputSchema or None if not found + """ + indicator_class = self.get(name) + if indicator_class: + return indicator_class.get_output_schema(**params) + return None + + def create_instance(self, name: str, instance_name: str, **params) -> Optional[Indicator]: + """ + Create an indicator instance with validation. + + Args: + name: Indicator class name + instance_name: Unique instance name (for output column prefixing) + **params: Indicator configuration parameters + + Returns: + Indicator instance or None if class not found + + Raises: + ValueError: If parameters are invalid + """ + indicator_class = self.get(name) + if not indicator_class: + return None + + return indicator_class(instance_name=instance_name, **params) + + def generate_ai_tool_spec(self) -> Dict: + """ + Generate a JSON specification for AI agent tools. + + Creates a structured representation of all indicators that can be + used to build agent tools for indicator selection and composition. 
+ + Returns: + Dict suitable for AI agent tool registration + """ + tools = [] + + for indicator_class in self._indicators.values(): + metadata = indicator_class.get_metadata() + + # Build parameter spec + parameters = { + "type": "object", + "properties": {}, + "required": [] + } + + for param in metadata.parameters: + param_spec = { + "type": param.type, + "description": param.description + } + + if param.default is not None: + param_spec["default"] = param.default + if param.min_value is not None: + param_spec["minimum"] = param.min_value + if param.max_value is not None: + param_spec["maximum"] = param.max_value + + parameters["properties"][param.name] = param_spec + + if param.required: + parameters["required"].append(param.name) + + tool = { + "name": f"indicator_{metadata.name.lower()}", + "description": f"{metadata.display_name}: {metadata.description}", + "category": metadata.category, + "use_cases": metadata.use_cases, + "tags": metadata.tags, + "parameters": parameters, + "input_schema": indicator_class.get_input_schema().model_dump(), + "output_schema": indicator_class.get_output_schema().model_dump() + } + + tools.append(tool) + + return { + "indicator_tools": tools, + "total_count": len(tools) + } diff --git a/backend/src/indicator/schema.py b/backend/src/indicator/schema.py new file mode 100644 index 0000000..1f558f2 --- /dev/null +++ b/backend/src/indicator/schema.py @@ -0,0 +1,269 @@ +""" +Data models for the Indicator system. + +Defines schemas for input/output specifications, computation context, +and metadata for AI agent discovery. +""" + +from typing import Any, Dict, List, Literal, Optional + +from pydantic import BaseModel, Field + +from datasource.schema import ColumnInfo + + +class InputSchema(BaseModel): + """ + Declares the required input columns for an Indicator. + + Indicators match against any data source (DataSource or other Indicator) + that provides columns satisfying this schema. 
+ """ + + model_config = {"extra": "forbid"} + + required_columns: List[ColumnInfo] = Field( + description="Columns that must be present in the input data" + ) + optional_columns: List[ColumnInfo] = Field( + default_factory=list, + description="Columns that may be used if present but are not required" + ) + time_column: str = Field( + default="time", + description="Name of the timestamp column (must be present)" + ) + + def matches(self, available_columns: List[ColumnInfo]) -> bool: + """ + Check if available columns satisfy this input schema. + + Args: + available_columns: Columns provided by a data source + + Returns: + True if all required columns are present with compatible types + """ + available_map = {col.name: col for col in available_columns} + + # Check time column exists + if self.time_column not in available_map: + return False + + # Check all required columns exist with compatible types + for required in self.required_columns: + if required.name not in available_map: + return False + available = available_map[required.name] + if available.type != required.type: + return False + + return True + + def get_missing_columns(self, available_columns: List[ColumnInfo]) -> List[str]: + """ + Get list of missing required column names. + + Args: + available_columns: Columns provided by a data source + + Returns: + List of missing column names + """ + available_names = {col.name for col in available_columns} + missing = [] + + if self.time_column not in available_names: + missing.append(self.time_column) + + for required in self.required_columns: + if required.name not in available_names: + missing.append(required.name) + + return missing + + +class OutputSchema(BaseModel): + """ + Declares the output columns produced by an Indicator. + + Column names will be automatically prefixed with the indicator instance name + to avoid collisions in the pipeline. 
+ """ + + model_config = {"extra": "forbid"} + + columns: List[ColumnInfo] = Field( + description="Output columns produced by this indicator" + ) + time_column: str = Field( + default="time", + description="Name of the timestamp column (passed through from input)" + ) + + def with_prefix(self, prefix: str) -> "OutputSchema": + """ + Create a new OutputSchema with all column names prefixed. + + Args: + prefix: Prefix to add (e.g., indicator instance name) + + Returns: + New OutputSchema with prefixed column names + """ + prefixed_columns = [ + ColumnInfo( + name=f"{prefix}_{col.name}" if col.name != self.time_column else col.name, + type=col.type, + description=col.description, + unit=col.unit, + nullable=col.nullable + ) + for col in self.columns + ] + return OutputSchema( + columns=prefixed_columns, + time_column=self.time_column + ) + + +class IndicatorParameter(BaseModel): + """ + Metadata for a configurable indicator parameter. + + Used for AI agent discovery and dynamic indicator instantiation. + """ + + model_config = {"extra": "forbid"} + + name: str = Field(description="Parameter name") + type: Literal["int", "float", "string", "bool"] = Field(description="Parameter type") + description: str = Field(description="Human and LLM-readable description") + default: Optional[Any] = Field(default=None, description="Default value if not specified") + required: bool = Field(default=False, description="Whether this parameter is required") + min_value: Optional[float] = Field(default=None, description="Minimum value (for numeric types)") + max_value: Optional[float] = Field(default=None, description="Maximum value (for numeric types)") + + +class IndicatorMetadata(BaseModel): + """ + Rich metadata for an Indicator class. + + Enables AI agents to discover, understand, and instantiate indicators. 
+ """ + + model_config = {"extra": "forbid"} + + name: str = Field(description="Unique indicator class name (e.g., 'RSI', 'SMA', 'BollingerBands')") + display_name: str = Field(description="Human-readable display name") + description: str = Field(description="Detailed description of what this indicator computes and why it's useful") + category: str = Field( + description="Indicator category (e.g., 'momentum', 'trend', 'volatility', 'volume', 'custom')" + ) + parameters: List[IndicatorParameter] = Field( + default_factory=list, + description="Configurable parameters for this indicator" + ) + use_cases: List[str] = Field( + default_factory=list, + description="Common use cases and trading scenarios where this indicator is helpful" + ) + references: List[str] = Field( + default_factory=list, + description="URLs or citations for indicator methodology" + ) + tags: List[str] = Field( + default_factory=list, + description="Searchable tags (e.g., 'oscillator', 'mean-reversion', 'price-based')" + ) + + +class ComputeContext(BaseModel): + """ + Context passed to an Indicator's compute() method. + + Contains the input data and metadata about what changed (for incremental updates). + """ + + model_config = {"extra": "forbid"} + + data: List[Dict[str, Any]] = Field( + description="Input data rows (time-ordered). Each dict is {column_name: value, time: timestamp}" + ) + is_incremental: bool = Field( + default=False, + description="True if this is an incremental update (only new/changed rows), False for full recompute" + ) + updated_from_time: Optional[int] = Field( + default=None, + description="Unix timestamp (ms) of the earliest updated row (for incremental updates)" + ) + + def get_column(self, name: str) -> List[Any]: + """ + Extract a single column as a list of values. + + Args: + name: Column name + + Returns: + List of values in time order + """ + return [row.get(name) for row in self.data] + + def get_times(self) -> List[int]: + """ + Get the time column as a list. 
+ + Returns: + List of timestamps in order + """ + return [row["time"] for row in self.data] + + +class ComputeResult(BaseModel): + """ + Result from an Indicator's compute() method. + + Contains the computed output data with proper column naming. + """ + + model_config = {"extra": "forbid"} + + data: List[Dict[str, Any]] = Field( + description="Output data rows (time-ordered). Must include time column." + ) + is_partial: bool = Field( + default=False, + description="True if this result only contains updates (for incremental computation)" + ) + + def merge_with_prefix(self, prefix: str, existing_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Merge this result into existing data with column name prefixing. + + Args: + prefix: Prefix to add to all column names except time + existing_data: Existing data to merge with (matched by time) + + Returns: + Merged data with prefixed columns added + """ + # Build a time index for new data + time_index = {row["time"]: row for row in self.data} + + # Merge into existing data + result = [] + for existing_row in existing_data: + row_time = existing_row["time"] + merged_row = existing_row.copy() + + if row_time in time_index: + new_row = time_index[row_time] + for key, value in new_row.items(): + if key != "time": + merged_row[f"{prefix}_{key}"] = value + + result.append(merged_row) + + return result diff --git a/backend/src/indicator/talib_adapter.py b/backend/src/indicator/talib_adapter.py new file mode 100644 index 0000000..61a4238 --- /dev/null +++ b/backend/src/indicator/talib_adapter.py @@ -0,0 +1,436 @@ +""" +TA-Lib indicator adapter. + +Provides automatic registration of all TA-Lib technical indicators +as composable Indicator instances. + +Installation Requirements: +-------------------------- +TA-Lib requires both the C library and Python wrapper: + +1. 
Install TA-Lib C library: + - Ubuntu/Debian: sudo apt-get install libta-lib-dev + - macOS: brew install ta-lib + - From source: https://ta-lib.org/install.html + +2. Install Python wrapper (already in requirements.txt): + pip install TA-Lib + +Usage: +------ + from indicator.talib_adapter import register_all_talib_indicators + + # Auto-register all TA-Lib indicators + registry = IndicatorRegistry() + register_all_talib_indicators(registry) + + # Now you can use any TA-Lib indicator + sma = registry.create_instance("SMA", "sma_20", period=20) + rsi = registry.create_instance("RSI", "rsi_14", timeperiod=14) +""" + +import logging +from typing import Any, Dict, List, Optional + +import numpy as np + +try: + import talib + from talib import abstract + TALIB_AVAILABLE = True +except ImportError: + TALIB_AVAILABLE = False + talib = None + abstract = None + +from datasource.schema import ColumnInfo + +from .base import Indicator +from .schema import ( + ComputeContext, + ComputeResult, + IndicatorMetadata, + IndicatorParameter, + InputSchema, + OutputSchema, +) + +logger = logging.getLogger(__name__) + + +# Mapping of TA-Lib parameter types to our schema types +TALIB_TYPE_MAP = { + "double": "float", + "double[]": "float", + "int": "int", + "str": "string", +} + +# Categorization of TA-Lib functions +TALIB_CATEGORIES = { + "overlap": ["SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA", "KAMA", "MAMA", "T3", + "BBANDS", "MIDPOINT", "MIDPRICE", "SAR", "SAREXT", "HT_TRENDLINE"], + "momentum": ["RSI", "MOM", "ROC", "ROCP", "ROCR", "ROCR100", "TRIX", "CMO", "DX", + "ADX", "ADXR", "APO", "PPO", "MACD", "MACDEXT", "MACDFIX", "MFI", + "STOCH", "STOCHF", "STOCHRSI", "WILLR", "CCI", "AROON", "AROONOSC", + "BOP", "MINUS_DI", "MINUS_DM", "PLUS_DI", "PLUS_DM", "ULTOSC"], + "volume": ["AD", "ADOSC", "OBV"], + "volatility": ["ATR", "NATR", "TRANGE"], + "price": ["AVGPRICE", "MEDPRICE", "TYPPRICE", "WCLPRICE"], + "cycle": ["HT_DCPERIOD", "HT_DCPHASE", "HT_PHASOR", "HT_SINE", "HT_TRENDMODE"], + 
"pattern": ["CDL2CROWS", "CDL3BLACKCROWS", "CDL3INSIDE", "CDL3LINESTRIKE", + "CDL3OUTSIDE", "CDL3STARSINSOUTH", "CDL3WHITESOLDIERS", "CDLABANDONEDBABY", + "CDLADVANCEBLOCK", "CDLBELTHOLD", "CDLBREAKAWAY", "CDLCLOSINGMARUBOZU", + "CDLCONCEALBABYSWALL", "CDLCOUNTERATTACK", "CDLDARKCLOUDCOVER", "CDLDOJI", + "CDLDOJISTAR", "CDLDRAGONFLYDOJI", "CDLENGULFING", "CDLEVENINGDOJISTAR", + "CDLEVENINGSTAR", "CDLGAPSIDESIDEWHITE", "CDLGRAVESTONEDOJI", "CDLHAMMER", + "CDLHANGINGMAN", "CDLHARAMI", "CDLHARAMICROSS", "CDLHIGHWAVE", "CDLHIKKAKE", + "CDLHIKKAKEMOD", "CDLHOMINGPIGEON", "CDLIDENTICAL3CROWS", "CDLINNECK", + "CDLINVERTEDHAMMER", "CDLKICKING", "CDLKICKINGBYLENGTH", "CDLLADDERBOTTOM", + "CDLLONGLEGGEDDOJI", "CDLLONGLINE", "CDLMARUBOZU", "CDLMATCHINGLOW", + "CDLMATHOLD", "CDLMORNINGDOJISTAR", "CDLMORNINGSTAR", "CDLONNECK", + "CDLPIERCING", "CDLRICKSHAWMAN", "CDLRISEFALL3METHODS", "CDLSEPARATINGLINES", + "CDLSHOOTINGSTAR", "CDLSHORTLINE", "CDLSPINNINGTOP", "CDLSTALLEDPATTERN", + "CDLSTICKSANDWICH", "CDLTAKURI", "CDLTASUKIGAP", "CDLTHRUSTING", "CDLTRISTAR", + "CDLUNIQUE3RIVER", "CDLUPSIDEGAP2CROWS", "CDLXSIDEGAP3METHODS"], + "statistic": ["BETA", "CORREL", "LINEARREG", "LINEARREG_ANGLE", "LINEARREG_INTERCEPT", + "LINEARREG_SLOPE", "STDDEV", "TSF", "VAR"], + "math": ["ADD", "DIV", "MAX", "MAXINDEX", "MIN", "MININDEX", "MINMAX", "MINMAXINDEX", + "MULT", "SUB", "SUM"], +} + + +def _get_function_category(func_name: str) -> str: + """Determine the category of a TA-Lib function.""" + for category, functions in TALIB_CATEGORIES.items(): + if func_name in functions: + return category + return "other" + + +class TALibIndicator(Indicator): + """ + Generic adapter for TA-Lib technical indicators. + + Wraps any TA-Lib function to work within the composable indicator framework. + Handles parameter mapping, input validation, and output formatting. 
+ """ + + # Class variable to store the TA-Lib function name + talib_function_name: str = None + + def __init__(self, instance_name: str, **params): + """ + Initialize a TA-Lib indicator. + + Args: + instance_name: Unique name for this instance + **params: TA-Lib function parameters + """ + if not TALIB_AVAILABLE: + raise ImportError( + "TA-Lib is not installed. Please install the TA-Lib C library " + "and Python wrapper. See indicator/talib_adapter.py for instructions." + ) + + super().__init__(instance_name, **params) + self._talib_func = abstract.Function(self.talib_function_name) + + @classmethod + def get_metadata(cls) -> IndicatorMetadata: + """Get metadata from TA-Lib function info.""" + if not TALIB_AVAILABLE: + raise ImportError("TA-Lib is not installed") + + func = abstract.Function(cls.talib_function_name) + info = func.info + + # Build parameters list from TA-Lib function info + parameters = [] + for param_name, param_info in info.get("parameters", {}).items(): + # Handle case where param_info is a simple value (int/float) instead of a dict + if isinstance(param_info, dict): + param_type = TALIB_TYPE_MAP.get(param_info.get("type", "double"), "float") + default_value = param_info.get("default_value") + else: + # param_info is a simple value (default), infer type from the value + if isinstance(param_info, int): + param_type = "int" + elif isinstance(param_info, float): + param_type = "float" + else: + param_type = "float" # Default to float + default_value = param_info + + parameters.append( + IndicatorParameter( + name=param_name, + type=param_type, + description=f"TA-Lib parameter: {param_name}", + default=default_value, + required=False + ) + ) + + # Get function group/category + category = _get_function_category(cls.talib_function_name) + + # Build display name (split camelCase or handle CDL prefix) + display_name = cls.talib_function_name + if display_name.startswith("CDL"): + display_name = display_name[3:] # Remove CDL prefix for patterns + + 
return IndicatorMetadata( + name=cls.talib_function_name, + display_name=display_name, + description=info.get("display_name", f"TA-Lib {cls.talib_function_name} indicator"), + category=category, + parameters=parameters, + use_cases=[f"Technical analysis using {cls.talib_function_name}"], + references=["https://ta-lib.org/function.html"], + tags=["talib", category, cls.talib_function_name.lower()] + ) + + @classmethod + def get_input_schema(cls) -> InputSchema: + """ + Get input schema from TA-Lib function requirements. + + Most TA-Lib functions use OHLCV data, but some use subsets. + """ + if not TALIB_AVAILABLE: + raise ImportError("TA-Lib is not installed") + + func = abstract.Function(cls.talib_function_name) + info = func.info + input_names = info.get("input_names", {}) + + required_columns = [] + + # Map TA-Lib input names to our schema + if "prices" in input_names: + price_inputs = input_names["prices"] + if "open" in price_inputs: + required_columns.append( + ColumnInfo(name="open", type="float", description="Opening price") + ) + if "high" in price_inputs: + required_columns.append( + ColumnInfo(name="high", type="float", description="High price") + ) + if "low" in price_inputs: + required_columns.append( + ColumnInfo(name="low", type="float", description="Low price") + ) + if "close" in price_inputs: + required_columns.append( + ColumnInfo(name="close", type="float", description="Closing price") + ) + if "volume" in price_inputs: + required_columns.append( + ColumnInfo(name="volume", type="float", description="Trading volume") + ) + + # Handle functions that take generic price arrays + if "price" in input_names: + required_columns.append( + ColumnInfo(name="close", type="float", description="Price (typically close)") + ) + + # If no specific inputs found, assume close price + if not required_columns: + required_columns.append( + ColumnInfo(name="close", type="float", description="Closing price") + ) + + return InputSchema(required_columns=required_columns) 
+ + @classmethod + def get_output_schema(cls, **params) -> OutputSchema: + """Get output schema from TA-Lib function outputs.""" + if not TALIB_AVAILABLE: + raise ImportError("TA-Lib is not installed") + + func = abstract.Function(cls.talib_function_name) + info = func.info + output_names = info.get("output_names", []) + + columns = [] + + # Most TA-Lib functions output one or more float arrays + if isinstance(output_names, list): + for output_name in output_names: + columns.append( + ColumnInfo( + name=output_name.lower(), + type="float", + description=f"{cls.talib_function_name} output: {output_name}", + nullable=True # TA-Lib often has NaN for initial periods + ) + ) + else: + # Single output, use function name + columns.append( + ColumnInfo( + name=cls.talib_function_name.lower(), + type="float", + description=f"{cls.talib_function_name} indicator value", + nullable=True + ) + ) + + return OutputSchema(columns=columns) + + def compute(self, context: ComputeContext) -> ComputeResult: + """Compute indicator using TA-Lib.""" + # Extract input columns + input_data = {} + + # Get the function's expected inputs + info = self._talib_func.info + input_names = info.get("input_names", {}) + + # Prepare input arrays + if "prices" in input_names: + price_inputs = input_names["prices"] + for price_type in price_inputs: + column_data = context.get_column(price_type) + # Convert to numpy array, replacing None with NaN + input_data[price_type] = np.array( + [float(v) if v is not None else np.nan for v in column_data] + ) + elif "price" in input_names: + # Generic price input, use close + column_data = context.get_column("close") + input_data["price"] = np.array( + [float(v) if v is not None else np.nan for v in column_data] + ) + else: + # Default to close if no inputs specified + column_data = context.get_column("close") + input_data["close"] = np.array( + [float(v) if v is not None else np.nan for v in column_data] + ) + + # Set parameters on the function + 
self._talib_func.parameters = self.params + + # Execute TA-Lib function + try: + output = self._talib_func(input_data) + except Exception as e: + logger.error(f"TA-Lib function {self.talib_function_name} failed: {e}") + raise ValueError(f"TA-Lib computation failed: {e}") + + # Format output + times = context.get_times() + output_names = info.get("output_names", []) + + # Handle single vs multiple outputs + if isinstance(output, np.ndarray): + # Single output + output_name = output_names[0].lower() if output_names else self.talib_function_name.lower() + result_data = [ + { + "time": times[i], + output_name: float(output[i]) if not np.isnan(output[i]) else None + } + for i in range(len(times)) + ] + elif isinstance(output, tuple): + # Multiple outputs + result_data = [] + for i in range(len(times)): + row = {"time": times[i]} + for j, output_array in enumerate(output): + output_name = output_names[j].lower() if j < len(output_names) else f"output_{j}" + row[output_name] = float(output_array[i]) if not np.isnan(output_array[i]) else None + result_data.append(row) + else: + raise ValueError(f"Unexpected TA-Lib output type: {type(output)}") + + return ComputeResult( + data=result_data, + is_partial=context.is_incremental + ) + + +def create_talib_indicator_class(func_name: str) -> type: + """ + Dynamically create an Indicator class for a TA-Lib function. + + Args: + func_name: TA-Lib function name (e.g., 'SMA', 'RSI') + + Returns: + Indicator class for this function + """ + return type( + f"TALib_{func_name}", + (TALibIndicator,), + {"talib_function_name": func_name} + ) + + +def register_all_talib_indicators(registry) -> int: + """ + Auto-register all available TA-Lib indicators with the registry. + + Args: + registry: IndicatorRegistry instance + + Returns: + Number of indicators registered + + Raises: + ImportError: If TA-Lib is not installed + """ + if not TALIB_AVAILABLE: + logger.warning( + "TA-Lib is not installed. Skipping TA-Lib indicator registration. 
" + "Install TA-Lib C library and Python wrapper to enable TA-Lib indicators." + ) + return 0 + + # Get all TA-Lib functions + func_groups = talib.get_function_groups() + all_functions = [] + for group, functions in func_groups.items(): + all_functions.extend(functions) + + # Remove duplicates + all_functions = sorted(set(all_functions)) + + registered_count = 0 + for func_name in all_functions: + try: + # Create indicator class for this function + indicator_class = create_talib_indicator_class(func_name) + + # Register with the registry + registry.register(indicator_class) + registered_count += 1 + + except Exception as e: + logger.warning(f"Failed to register TA-Lib function {func_name}: {e}") + continue + + logger.info(f"Registered {registered_count} TA-Lib indicators") + return registered_count + + +def get_talib_version() -> Optional[str]: + """ + Get the installed TA-Lib version. + + Returns: + Version string or None if not installed + """ + if TALIB_AVAILABLE: + return talib.__version__ + return None + + +def is_talib_available() -> bool: + """Check if TA-Lib is available.""" + return TALIB_AVAILABLE diff --git a/backend/src/main.py b/backend/src/main.py index 1cd868f..01317f0 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -20,13 +20,14 @@ from gateway.hub import Gateway from gateway.channels.websocket import WebSocketChannel from gateway.protocol import WebSocketAgentUserMessage from agent.core import create_agent -from agent.tools import set_registry, set_datasource_registry +from agent.tools import set_registry, set_datasource_registry, set_indicator_registry from schema.order_spec import SwapOrder from schema.chart_state import ChartState from datasource.registry import DataSourceRegistry from datasource.subscription_manager import SubscriptionManager from datasource.websocket_handler import DatafeedWebSocketHandler from secrets_manager import SecretsStore, InvalidMasterPassword +from indicator import IndicatorRegistry, 
register_all_talib_indicators # Configure logging logging.basicConfig( @@ -53,6 +54,9 @@ agent_executor = None datasource_registry = DataSourceRegistry() subscription_manager = SubscriptionManager() +# Indicator infrastructure +indicator_registry = IndicatorRegistry() + # Global secrets store secrets_store = SecretsStore() @@ -80,6 +84,14 @@ async def lifespan(app: FastAPI): logger.warning(f"CCXT not available: {e}. Only demo source will be available.") logger.info("To use real exchange data, install ccxt: pip install ccxt>=4.0.0") + # Initialize indicator registry with all TA-Lib indicators + try: + indicator_count = register_all_talib_indicators(indicator_registry) + logger.info(f"Indicator registry initialized with {indicator_count} TA-Lib indicators") + except Exception as e: + logger.warning(f"Failed to register TA-Lib indicators: {e}") + logger.info("TA-Lib indicators will not be available. Install TA-Lib C library and Python wrapper to enable.") + # Get API keys from secrets store if unlocked, otherwise fall back to environment anthropic_api_key = None @@ -101,6 +113,7 @@ async def lifespan(app: FastAPI): # Set the registries for agent tools set_registry(registry) set_datasource_registry(datasource_registry) + set_indicator_registry(indicator_registry) # Create and initialize agent agent_executor = create_agent( diff --git a/backend/src/schema/order_spec.py b/backend/src/schema/order_spec.py index b13be5e..78dd563 100644 --- a/backend/src/schema/order_spec.py +++ b/backend/src/schema/order_spec.py @@ -40,6 +40,58 @@ class Exchange(StrEnum): UNISWAP_V3 = "UniswapV3" +class Side(StrEnum): + """Order side: buy or sell""" + BUY = "BUY" + SELL = "SELL" + + +class AmountType(StrEnum): + """Whether the order amount refers to base or quote currency""" + BASE = "BASE" # Amount is in base currency (e.g., BTC in BTC/USD) + QUOTE = "QUOTE" # Amount is in quote currency (e.g., USD in BTC/USD) + + +class TimeInForce(StrEnum): + """Order lifetime specification""" + GTC = 
"GTC" # Good Till Cancel + IOC = "IOC" # Immediate or Cancel + FOK = "FOK" # Fill or Kill + DAY = "DAY" # Good for trading day + GTD = "GTD" # Good Till Date + + +class ConditionalOrderMode(StrEnum): + """How conditional orders behave on partial fills""" + NEW_PER_FILL = "NEW_PER_FILL" # Create new conditional order per each fill + UNIFIED_ADJUSTING = "UNIFIED_ADJUSTING" # Single conditional order that adjusts amount + + +class TriggerType(StrEnum): + """Type of conditional trigger""" + STOP_LOSS = "STOP_LOSS" + TAKE_PROFIT = "TAKE_PROFIT" + STOP_LIMIT = "STOP_LIMIT" + TRAILING_STOP = "TRAILING_STOP" + + +class TickSpacingMode(StrEnum): + """How price tick spacing is determined""" + FIXED = "FIXED" # Fixed tick size + DYNAMIC = "DYNAMIC" # Tick size varies by price level + CONTINUOUS = "CONTINUOUS" # No tick restrictions + + +class AssetType(StrEnum): + """Type of tradeable asset""" + SPOT = "SPOT" # Spot/cash market + MARGIN = "MARGIN" # Margin trading + PERP = "PERP" # Perpetual futures + FUTURE = "FUTURE" # Dated futures + OPTION = "OPTION" # Options + SYNTHETIC = "SYNTHETIC" # Synthetic/derived instruments + + class OcoMode(StrEnum): NO_OCO = "NO_OCO" CANCEL_ON_PARTIAL_FILL = "CANCEL_ON_PARTIAL_FILL" @@ -96,6 +148,126 @@ class TrancheStatus(BaseModel): endTime: Uint32 = Field(description="Concrete end timestamp") +# --------------------------------------------------------------------------- +# Standard Order Models +# --------------------------------------------------------------------------- + +class ConditionalTrigger(BaseModel): + """Conditional order trigger (stop-loss, take-profit, etc.)""" + + model_config = {"extra": "forbid"} + + trigger_type: TriggerType + trigger_price: Float = Field(description="Price at which conditional order activates") + trailing_delta: Float | None = Field(default=None, description="For trailing stops: delta from peak/trough") + + +class AmountConstraints(BaseModel): + """Constraints on order amounts for a symbol""" + + 
model_config = {"extra": "forbid"} + + min_amount: Float = Field(description="Minimum order amount") + max_amount: Float = Field(description="Maximum order amount") + step_size: Float = Field(description="Amount increment granularity") + + +class PriceConstraints(BaseModel): + """Constraints on order pricing for a symbol""" + + model_config = {"extra": "forbid"} + + tick_spacing_mode: TickSpacingMode + tick_size: Float | None = Field(default=None, description="Fixed tick size (if FIXED mode)") + min_price: Float | None = Field(default=None, description="Minimum allowed price") + max_price: Float | None = Field(default=None, description="Maximum allowed price") + + +class MarketCapabilities(BaseModel): + """Describes what order features a market supports""" + + model_config = {"extra": "forbid"} + + supported_sides: list[Side] = Field(description="Supported order sides (usually both)") + supported_amount_types: list[AmountType] = Field(description="Whether BASE, QUOTE, or both amounts are supported") + supports_market_orders: bool = Field(description="Whether market orders are supported") + supports_limit_orders: bool = Field(description="Whether limit orders are supported") + supported_time_in_force: list[TimeInForce] = Field(description="Supported order lifetimes") + supports_conditional_orders: bool = Field(description="Whether stop-loss/take-profit are supported") + supported_trigger_types: list[TriggerType] = Field(default_factory=list, description="Supported trigger types") + supports_post_only: bool = Field(default=False, description="Whether post-only orders are supported") + supports_reduce_only: bool = Field(default=False, description="Whether reduce-only orders are supported") + supports_iceberg: bool = Field(default=False, description="Whether iceberg orders are supported") + market_order_amount_type: AmountType | None = Field( + default=None, + description="Required amount type for market orders (some DEXs require exact-in)" + ) + + +class 
SymbolMetadata(BaseModel): + """Complete metadata describing a tradeable symbol/market""" + + model_config = {"extra": "forbid"} + + symbol_id: str = Field(description="Unique symbol identifier") + base_asset: str = Field(description="Base asset (e.g., 'BTC')") + quote_asset: str = Field(description="Quote asset (e.g., 'USD')") + asset_type: AssetType = Field(description="Type of market") + exchange: str = Field(description="Exchange identifier") + + amount_constraints: AmountConstraints + price_constraints: PriceConstraints + capabilities: MarketCapabilities + + contract_size: Float | None = Field(default=None, description="For futures/options: contract multiplier") + settlement_asset: str | None = Field(default=None, description="For derivatives: settlement currency") + expiry_timestamp: Uint64 | None = Field(default=None, description="For dated futures/options: expiration") + + +class StandardOrder(BaseModel): + """Standard order specification for exchange kernels""" + + model_config = {"extra": "forbid"} + + symbol_id: str = Field(description="Symbol to trade") + side: Side = Field(description="Buy or sell") + amount: Float = Field(description="Order amount") + amount_type: AmountType = Field(description="Whether amount is BASE or QUOTE currency") + + limit_price: Float | None = Field(default=None, description="Limit price (None = market order)") + time_in_force: TimeInForce = Field(default=TimeInForce.GTC, description="Order lifetime") + good_till_date: Uint64 | None = Field(default=None, description="Expiry timestamp for GTD orders") + + conditional_trigger: ConditionalTrigger | None = Field( + default=None, + description="Stop-loss/take-profit trigger" + ) + conditional_mode: ConditionalOrderMode | None = Field( + default=None, + description="How conditional orders behave on partial fills" + ) + + reduce_only: bool = Field(default=False, description="Only reduce existing position") + post_only: bool = Field(default=False, description="Only make, never take") 
+ iceberg_qty: Float | None = Field(default=None, description="Visible amount for iceberg orders") + + client_order_id: str | None = Field(default=None, description="Client-specified order ID") + + +class StandardOrderStatus(BaseModel): + """Current status of a standard order""" + + model_config = {"extra": "forbid"} + + order: StandardOrder + order_id: str = Field(description="Exchange-assigned order ID") + status: str = Field(description="Order status: NEW, PARTIALLY_FILLED, FILLED, CANCELED, REJECTED, EXPIRED") + filled_amount: Float = Field(description="Amount filled so far") + average_fill_price: Float = Field(description="Average execution price") + created_at: Uint64 = Field(description="Order creation timestamp") + updated_at: Uint64 = Field(description="Last update timestamp") + + # --------------------------------------------------------------------------- # Order models # --------------------------------------------------------------------------- @@ -117,7 +289,22 @@ class SwapOrder(BaseModel): tranches: list[Tranche] = Field(min_length=1) +class StandardOrderGroup(BaseModel): + """Group of orders with OCO (One-Cancels-Other) relationship""" + + model_config = {"extra": "forbid"} + + mode: OcoMode + orders: list[StandardOrder] = Field(min_length=1) + + +# --------------------------------------------------------------------------- +# Legacy swap order models (kept for backward compatibility) +# --------------------------------------------------------------------------- + class OcoGroup(BaseModel): + """DEPRECATED: Use StandardOrderGroup instead""" + model_config = {"extra": "forbid"} mode: OcoMode diff --git a/backend/src/secrets_manager/__init__.py b/backend/src/secrets_manager/__init__.py new file mode 100644 index 0000000..bc52d33 --- /dev/null +++ b/backend/src/secrets_manager/__init__.py @@ -0,0 +1,40 @@ +""" +Encrypted secrets management with master password protection. 
+
+This module provides secure storage for sensitive configuration like API keys,
+using Argon2id for password-based key derivation and Fernet (AES-128-CBC with HMAC) for encryption.
+
+Basic usage:
+    from secrets_manager import SecretsStore
+
+    # First time setup
+    store = SecretsStore()
+    store.initialize("my-master-password")
+    store.set("ANTHROPIC_API_KEY", "sk-ant-...")
+
+    # Later usage
+    store = SecretsStore()
+    store.unlock("my-master-password")
+    api_key = store.get("ANTHROPIC_API_KEY")
+
+Command-line interface:
+    python -m secrets_manager.cli init
+    python -m secrets_manager.cli set KEY VALUE
+    python -m secrets_manager.cli get KEY
+    python -m secrets_manager.cli list
+    python -m secrets_manager.cli change-password
+"""
+
+from .store import (
+    SecretsStore,
+    SecretsStoreError,
+    SecretsStoreLocked,
+    InvalidMasterPassword,
+)
+
+__all__ = [
+    "SecretsStore",
+    "SecretsStoreError",
+    "SecretsStoreLocked",
+    "InvalidMasterPassword",
+]
diff --git a/backend/src/secrets_manager/cli.py b/backend/src/secrets_manager/cli.py
new file mode 100644
index 0000000..03738bc
--- /dev/null
+++ b/backend/src/secrets_manager/cli.py
@@ -0,0 +1,374 @@
+#!/usr/bin/env python3
+"""
+Command-line interface for managing the encrypted secrets store. 
+
+Usage:
+    python -m secrets_manager.cli init              # Initialize new secrets store
+    python -m secrets_manager.cli set KEY VALUE     # Set a secret
+    python -m secrets_manager.cli get KEY           # Get a secret
+    python -m secrets_manager.cli delete KEY        # Delete a secret
+    python -m secrets_manager.cli list              # List all secret keys
+    python -m secrets_manager.cli change-password   # Change master password
+    python -m secrets_manager.cli export FILE       # Export encrypted backup
+    python -m secrets_manager.cli import FILE       # Import encrypted backup
+    python -m secrets_manager.cli migrate-from-env  # Migrate secrets from .env file
+"""
+import sys
+import argparse
+import getpass
+from pathlib import Path
+
+from .store import SecretsStore, SecretsStoreError, InvalidMasterPassword
+
+
+def get_password(prompt: str = "Master password: ", confirm: bool = False) -> str:
+    """
+    Securely get password from user.
+
+    Args:
+        prompt: Password prompt
+        confirm: If True, ask for confirmation
+
+    Returns:
+        Password string
+    """
+    password = getpass.getpass(prompt)
+
+    if confirm:
+        confirm_password = getpass.getpass("Confirm password: ")
+        if password != confirm_password:
+            print("Error: Passwords do not match", file=sys.stderr)
+            sys.exit(1)
+
+    return password
+
+
+def cmd_init(args):
+    """Initialize a new secrets store."""
+    store = SecretsStore()
+
+    if store.is_initialized:
+        print("Error: Secrets store is already initialized", file=sys.stderr)
+        print(f"Location: {store.secrets_file}", file=sys.stderr)
+        sys.exit(1)
+
+    password = get_password("Create master password: ", confirm=True)
+
+    if len(password) < 8:
+        print("Error: Password must be at least 8 characters", file=sys.stderr)
+        sys.exit(1)
+
+    store.initialize(password)
+    print(f"Secrets store initialized at {store.secrets_file}")
+
+
+def cmd_set(args):
+    """Set a secret value."""
+    store = SecretsStore()
+
+    if not store.is_initialized:
+        print("Error: Secrets store is not initialized. 
Run 'init' first.", file=sys.stderr) + sys.exit(1) + + password = get_password() + + try: + store.unlock(password) + except InvalidMasterPassword: + print("Error: Invalid master password", file=sys.stderr) + sys.exit(1) + + store.set(args.key, args.value) + print(f"✓ Secret '{args.key}' saved") + + +def cmd_get(args): + """Get a secret value.""" + store = SecretsStore() + + if not store.is_initialized: + print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr) + sys.exit(1) + + password = get_password() + + try: + store.unlock(password) + except InvalidMasterPassword: + print("Error: Invalid master password", file=sys.stderr) + sys.exit(1) + + value = store.get(args.key) + if value is None: + print(f"Error: Secret '{args.key}' not found", file=sys.stderr) + sys.exit(1) + + # Print to stdout (can be captured) + print(value) + + +def cmd_delete(args): + """Delete a secret.""" + store = SecretsStore() + + if not store.is_initialized: + print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr) + sys.exit(1) + + password = get_password() + + try: + store.unlock(password) + except InvalidMasterPassword: + print("Error: Invalid master password", file=sys.stderr) + sys.exit(1) + + if store.delete(args.key): + print(f"✓ Secret '{args.key}' deleted") + else: + print(f"Error: Secret '{args.key}' not found", file=sys.stderr) + sys.exit(1) + + +def cmd_list(args): + """List all secret keys.""" + store = SecretsStore() + + if not store.is_initialized: + print("Error: Secrets store is not initialized. 
Run 'init' first.", file=sys.stderr) + sys.exit(1) + + password = get_password() + + try: + store.unlock(password) + except InvalidMasterPassword: + print("Error: Invalid master password", file=sys.stderr) + sys.exit(1) + + keys = store.list_keys() + + if not keys: + print("No secrets stored") + else: + print(f"Stored secrets ({len(keys)}):") + for key in sorted(keys): + # Show key and value length for verification + value = store.get(key) + value_str = str(value) + value_preview = value_str[:50] + "..." if len(value_str) > 50 else value_str + print(f" {key}: {value_preview}") + + +def cmd_change_password(args): + """Change the master password.""" + store = SecretsStore() + + if not store.is_initialized: + print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr) + sys.exit(1) + + current_password = get_password("Current master password: ") + new_password = get_password("New master password: ", confirm=True) + + if len(new_password) < 8: + print("Error: Password must be at least 8 characters", file=sys.stderr) + sys.exit(1) + + try: + store.change_master_password(current_password, new_password) + except InvalidMasterPassword: + print("Error: Invalid current password", file=sys.stderr) + sys.exit(1) + + +def cmd_export(args): + """Export encrypted secrets to a backup file.""" + store = SecretsStore() + + if not store.is_initialized: + print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr) + sys.exit(1) + + output_path = Path(args.file) + + if output_path.exists() and not args.force: + print(f"Error: File {output_path} already exists. Use --force to overwrite.", file=sys.stderr) + sys.exit(1) + + try: + store.export_encrypted(output_path) + except SecretsStoreError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +def cmd_import(args): + """Import encrypted secrets from a backup file.""" + store = SecretsStore() + + if not store.is_initialized: + print("Error: Secrets store is not initialized. 
Run 'init' first.", file=sys.stderr) + sys.exit(1) + + input_path = Path(args.file) + + if not input_path.exists(): + print(f"Error: File {input_path} does not exist", file=sys.stderr) + sys.exit(1) + + password = get_password() + + try: + store.import_encrypted(input_path, password) + except InvalidMasterPassword: + print("Error: Invalid master password or incompatible backup", file=sys.stderr) + sys.exit(1) + except SecretsStoreError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +def cmd_migrate_from_env(args): + """Migrate secrets from .env file to encrypted store.""" + store = SecretsStore() + + if not store.is_initialized: + print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr) + sys.exit(1) + + # Look for .env file + backend_root = Path(__file__).parent.parent.parent + env_file = backend_root / ".env" + + if not env_file.exists(): + print(f"Error: .env file not found at {env_file}", file=sys.stderr) + sys.exit(1) + + password = get_password() + + try: + store.unlock(password) + except InvalidMasterPassword: + print("Error: Invalid master password", file=sys.stderr) + sys.exit(1) + + # Parse .env file (simple parser - doesn't handle all edge cases) + migrated = 0 + skipped = 0 + + with open(env_file) as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + + # Skip empty lines and comments + if not line or line.startswith('#'): + continue + + # Parse KEY=VALUE format + if '=' not in line: + print(f"Warning: Skipping invalid line {line_num}: {line}", file=sys.stderr) + skipped += 1 + continue + + key, value = line.split('=', 1) + key = key.strip() + value = value.strip() + + # Remove quotes if present + if value.startswith('"') and value.endswith('"'): + value = value[1:-1] + elif value.startswith("'") and value.endswith("'"): + value = value[1:-1] + + # Check if key already exists + existing = store.get(key) + if existing is not None: + print(f"Warning: Secret '{key}' already exists, skipping", 
file=sys.stderr) + skipped += 1 + continue + + store.set(key, value) + print(f"✓ Migrated: {key}") + migrated += 1 + + print(f"\nMigration complete: {migrated} secrets migrated, {skipped} skipped") + + if not args.keep_env: + # Ask for confirmation before deleting .env + confirm = input(f"\nDelete {env_file}? [y/N]: ").strip().lower() + if confirm == 'y': + env_file.unlink() + print(f"✓ Deleted {env_file}") + else: + print(f"Kept {env_file} (consider deleting it manually)") + + +def main(): + """Main CLI entry point.""" + parser = argparse.ArgumentParser( + description="Manage encrypted secrets store", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + subparsers = parser.add_subparsers(dest='command', help='Command to run') + subparsers.required = True + + # init + parser_init = subparsers.add_parser('init', help='Initialize new secrets store') + parser_init.set_defaults(func=cmd_init) + + # set + parser_set = subparsers.add_parser('set', help='Set a secret value') + parser_set.add_argument('key', help='Secret key name') + parser_set.add_argument('value', help='Secret value') + parser_set.set_defaults(func=cmd_set) + + # get + parser_get = subparsers.add_parser('get', help='Get a secret value') + parser_get.add_argument('key', help='Secret key name') + parser_get.set_defaults(func=cmd_get) + + # delete + parser_delete = subparsers.add_parser('delete', help='Delete a secret') + parser_delete.add_argument('key', help='Secret key name') + parser_delete.set_defaults(func=cmd_delete) + + # list + parser_list = subparsers.add_parser('list', help='List all secret keys') + parser_list.set_defaults(func=cmd_list) + + # change-password + parser_change = subparsers.add_parser('change-password', help='Change master password') + parser_change.set_defaults(func=cmd_change_password) + + # export + parser_export = subparsers.add_parser('export', help='Export encrypted backup') + parser_export.add_argument('file', help='Output file path') + 
parser_export.add_argument('--force', action='store_true', help='Overwrite existing file') + parser_export.set_defaults(func=cmd_export) + + # import + parser_import = subparsers.add_parser('import', help='Import encrypted backup') + parser_import.add_argument('file', help='Input file path') + parser_import.set_defaults(func=cmd_import) + + # migrate-from-env + parser_migrate = subparsers.add_parser('migrate-from-env', help='Migrate from .env file') + parser_migrate.add_argument('--keep-env', action='store_true', help='Keep .env file after migration') + parser_migrate.set_defaults(func=cmd_migrate_from_env) + + args = parser.parse_args() + + try: + args.func(args) + except KeyboardInterrupt: + print("\nAborted", file=sys.stderr) + sys.exit(130) + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == '__main__': + main() diff --git a/backend/src/secrets_manager/crypto.py b/backend/src/secrets_manager/crypto.py new file mode 100644 index 0000000..45cd5fc --- /dev/null +++ b/backend/src/secrets_manager/crypto.py @@ -0,0 +1,144 @@ +""" +Cryptographic utilities for secrets management. + +Uses Argon2id for password-based key derivation and Fernet for encryption. 
+"""
+import os
+import secrets as secrets_module
+from typing import Tuple
+
+from argon2 import PasswordHasher
+from argon2.low_level import hash_secret_raw, Type
+from cryptography.fernet import Fernet
+import base64
+
+
+# Argon2id parameters (OWASP recommended for password-based KDF)
+# These provide strong defense against GPU/ASIC attacks
+ARGON2_TIME_COST = 3        # iterations
+ARGON2_MEMORY_COST = 65536  # 64 MB
+ARGON2_PARALLELISM = 4      # threads
+ARGON2_HASH_LENGTH = 32     # bytes (256 bits for Fernet key)
+ARGON2_SALT_LENGTH = 16     # bytes (128 bits)
+
+
+def generate_salt() -> bytes:
+    """Generate a cryptographically secure random salt."""
+    return secrets_module.token_bytes(ARGON2_SALT_LENGTH)
+
+
+def derive_key_from_password(password: str, salt: bytes) -> bytes:
+    """
+    Derive an encryption key from a password using Argon2id.
+
+    Args:
+        password: The master password
+        salt: The salt (must be consistent for the same password to work)
+
+    Returns:
+        32-byte key suitable for Fernet encryption
+    """
+    password_bytes = password.encode('utf-8')
+
+    # Use Argon2id (hybrid mode - best of Argon2i and Argon2d)
+    raw_hash = hash_secret_raw(
+        secret=password_bytes,
+        salt=salt,
+        time_cost=ARGON2_TIME_COST,
+        memory_cost=ARGON2_MEMORY_COST,
+        parallelism=ARGON2_PARALLELISM,
+        hash_len=ARGON2_HASH_LENGTH,
+        type=Type.ID  # Argon2id
+    )
+
+    return raw_hash
+
+
+def create_fernet(key: bytes) -> Fernet:
+    """
+    Create a Fernet cipher instance from a raw key.
+
+    Args:
+        key: 32-byte raw key from Argon2id
+
+    Returns:
+        Fernet instance for encryption/decryption
+    """
+    # Fernet requires a URL-safe base64-encoded 32-byte key
+    fernet_key = base64.urlsafe_b64encode(key)
+    return Fernet(fernet_key)
+
+
+def encrypt_data(data: bytes, key: bytes) -> bytes:
+    """
+    Encrypt data using Fernet (AES-128-CBC with HMAC-SHA256). 
+ + Args: + data: Raw bytes to encrypt + key: 32-byte encryption key + + Returns: + Encrypted data (includes IV and auth tag) + """ + fernet = create_fernet(key) + return fernet.encrypt(data) + + +def decrypt_data(encrypted_data: bytes, key: bytes) -> bytes: + """ + Decrypt data using Fernet. + + Args: + encrypted_data: Encrypted bytes from encrypt_data + key: 32-byte encryption key (must match encryption key) + + Returns: + Decrypted raw bytes + + Raises: + cryptography.fernet.InvalidToken: If decryption fails (wrong key/corrupted data) + """ + fernet = create_fernet(key) + return fernet.decrypt(encrypted_data) + + +def create_verification_hash(password: str, salt: bytes) -> str: + """ + Create a verification hash to check if a password is correct. + + This is NOT for storing the password - it's for verifying the password + unlocks the correct key without trying to decrypt the entire secrets file. + + Args: + password: The master password + salt: The salt used for key derivation + + Returns: + Base64-encoded hash for verification + """ + # Derive key and hash it again for verification + key = derive_key_from_password(password, salt) + + # Simple hash of the key for verification (not security critical since + # the key itself is already derived from Argon2id) + verification = base64.b64encode(key[:16]).decode('ascii') + + return verification + + +def verify_password(password: str, salt: bytes, verification_hash: str) -> bool: + """ + Verify a password against a verification hash. 
+ + Args: + password: Password to verify + salt: Salt used for key derivation + verification_hash: Expected verification hash + + Returns: + True if password is correct, False otherwise + """ + computed_hash = create_verification_hash(password, salt) + + # Constant-time comparison to prevent timing attacks + return secrets_module.compare_digest(computed_hash, verification_hash) diff --git a/backend/src/secrets_manager/store.py b/backend/src/secrets_manager/store.py new file mode 100644 index 0000000..bc2f734 --- /dev/null +++ b/backend/src/secrets_manager/store.py @@ -0,0 +1,406 @@ +""" +Encrypted secrets store with master password protection. + +The secrets are stored in an encrypted file, with the encryption key derived +from a master password using Argon2id. The master password can be changed +without re-encrypting all secrets. +""" +import json +import os +import stat +from pathlib import Path +from typing import Dict, Optional, Any + +from cryptography.fernet import InvalidToken + +from .crypto import ( + generate_salt, + derive_key_from_password, + encrypt_data, + decrypt_data, + create_verification_hash, + verify_password, +) + + +class SecretsStoreError(Exception): + """Base exception for secrets store errors.""" + pass + + +class SecretsStoreLocked(SecretsStoreError): + """Raised when trying to access secrets while store is locked.""" + pass + + +class InvalidMasterPassword(SecretsStoreError): + """Raised when master password is incorrect.""" + pass + + +class SecretsStore: + """ + Encrypted secrets store with master password protection. 
+ + Usage: + # Initialize (first time) + store = SecretsStore() + store.initialize("my-secure-password") + + # Unlock + store = SecretsStore() + store.unlock("my-secure-password") + + # Access secrets + api_key = store.get("ANTHROPIC_API_KEY") + store.set("NEW_SECRET", "secret-value") + + # Change master password + store.change_master_password("my-secure-password", "new-password") + + # Lock when done + store.lock() + """ + + def __init__(self, data_dir: Optional[Path] = None): + """ + Initialize secrets store. + + Args: + data_dir: Directory for secrets files (defaults to backend/data) + """ + if data_dir is None: + # Default to backend/data + backend_root = Path(__file__).parent.parent.parent + data_dir = backend_root / "data" + + self.data_dir = Path(data_dir) + self.data_dir.mkdir(parents=True, exist_ok=True) + + self.master_key_file = self.data_dir / ".master.key" + self.secrets_file = self.data_dir / "secrets.enc" + + # Runtime state + self._encryption_key: Optional[bytes] = None + self._secrets: Optional[Dict[str, Any]] = None + + @property + def is_initialized(self) -> bool: + """Check if the secrets store has been initialized.""" + return self.master_key_file.exists() + + @property + def is_unlocked(self) -> bool: + """Check if the secrets store is currently unlocked.""" + return self._encryption_key is not None + + def initialize(self, master_password: str) -> None: + """ + Initialize the secrets store with a master password. + + This should only be called once when setting up the store. + + Args: + master_password: The master password to protect the secrets + + Raises: + SecretsStoreError: If store is already initialized + """ + if self.is_initialized: + raise SecretsStoreError( + "Secrets store is already initialized. " + "Use unlock() to access it or change_master_password() to change the password." 
+ ) + + # Generate a new random salt + salt = generate_salt() + + # Derive encryption key + encryption_key = derive_key_from_password(master_password, salt) + + # Create verification hash + verification_hash = create_verification_hash(master_password, salt) + + # Store salt and verification hash + master_key_data = { + "salt": salt.hex(), + "verification": verification_hash, + } + + self.master_key_file.write_text(json.dumps(master_key_data, indent=2)) + + # Set restrictive permissions (owner read/write only) + os.chmod(self.master_key_file, stat.S_IRUSR | stat.S_IWUSR) + + # Initialize empty secrets + self._encryption_key = encryption_key + self._secrets = {} + self._save_secrets() + + print(f"✓ Secrets store initialized at {self.secrets_file}") + + def unlock(self, master_password: str) -> None: + """ + Unlock the secrets store with the master password. + + Args: + master_password: The master password + + Raises: + SecretsStoreError: If store is not initialized + InvalidMasterPassword: If password is incorrect + """ + if not self.is_initialized: + raise SecretsStoreError( + "Secrets store is not initialized. Call initialize() first." 
+ ) + + # Load salt and verification hash + master_key_data = json.loads(self.master_key_file.read_text()) + salt = bytes.fromhex(master_key_data["salt"]) + verification_hash = master_key_data["verification"] + + # Verify password + if not verify_password(master_password, salt, verification_hash): + raise InvalidMasterPassword("Invalid master password") + + # Derive encryption key + encryption_key = derive_key_from_password(master_password, salt) + + # Load and decrypt secrets + if self.secrets_file.exists(): + try: + encrypted_data = self.secrets_file.read_bytes() + decrypted_data = decrypt_data(encrypted_data, encryption_key) + self._secrets = json.loads(decrypted_data.decode('utf-8')) + except InvalidToken: + raise InvalidMasterPassword("Failed to decrypt secrets (invalid password)") + except json.JSONDecodeError as e: + raise SecretsStoreError(f"Corrupted secrets file: {e}") + else: + # No secrets file yet (fresh initialization) + self._secrets = {} + + self._encryption_key = encryption_key + print(f"✓ Secrets store unlocked ({len(self._secrets)} secrets)") + + def lock(self) -> None: + """Lock the secrets store (clear decrypted data from memory).""" + self._encryption_key = None + self._secrets = None + + def get(self, key: str, default: Any = None) -> Any: + """ + Get a secret value. + + Args: + key: Secret key name + default: Default value if key doesn't exist + + Returns: + Secret value or default + + Raises: + SecretsStoreLocked: If store is locked + """ + if not self.is_unlocked: + raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.") + + return self._secrets.get(key, default) + + def set(self, key: str, value: Any) -> None: + """ + Set a secret value. + + Args: + key: Secret key name + value: Secret value (must be JSON-serializable) + + Raises: + SecretsStoreLocked: If store is locked + """ + if not self.is_unlocked: + raise SecretsStoreLocked("Secrets store is locked. 
Call unlock() first.") + + self._secrets[key] = value + self._save_secrets() + + def delete(self, key: str) -> bool: + """ + Delete a secret. + + Args: + key: Secret key name + + Returns: + True if secret existed and was deleted, False otherwise + + Raises: + SecretsStoreLocked: If store is locked + """ + if not self.is_unlocked: + raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.") + + if key in self._secrets: + del self._secrets[key] + self._save_secrets() + return True + return False + + def list_keys(self) -> list[str]: + """ + List all secret keys. + + Returns: + List of secret keys + + Raises: + SecretsStoreLocked: If store is locked + """ + if not self.is_unlocked: + raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.") + + return list(self._secrets.keys()) + + def change_master_password(self, current_password: str, new_password: str) -> None: + """ + Change the master password. + + This re-encrypts the secrets with a new key derived from the new password. + + Args: + current_password: Current master password + new_password: New master password + + Raises: + InvalidMasterPassword: If current password is incorrect + """ + # ALWAYS verify current password before changing + # Load salt and verification hash + if not self.is_initialized: + raise SecretsStoreError( + "Secrets store is not initialized. Call initialize() first." 
+ ) + + master_key_data = json.loads(self.master_key_file.read_text()) + salt = bytes.fromhex(master_key_data["salt"]) + verification_hash = master_key_data["verification"] + + # Verify current password is correct + if not verify_password(current_password, salt, verification_hash): + raise InvalidMasterPassword("Invalid current password") + + # Unlock if needed to access secrets + was_unlocked = self.is_unlocked + if not was_unlocked: + # Store is locked, so unlock with current password + # (we already verified it above, so this will succeed) + encryption_key = derive_key_from_password(current_password, salt) + + # Load and decrypt secrets + if self.secrets_file.exists(): + encrypted_data = self.secrets_file.read_bytes() + decrypted_data = decrypt_data(encrypted_data, encryption_key) + self._secrets = json.loads(decrypted_data.decode('utf-8')) + else: + self._secrets = {} + + self._encryption_key = encryption_key + + # Generate new salt + new_salt = generate_salt() + + # Derive new encryption key + new_encryption_key = derive_key_from_password(new_password, new_salt) + + # Create new verification hash + new_verification_hash = create_verification_hash(new_password, new_salt) + + # Update master key file + master_key_data = { + "salt": new_salt.hex(), + "verification": new_verification_hash, + } + self.master_key_file.write_text(json.dumps(master_key_data, indent=2)) + os.chmod(self.master_key_file, stat.S_IRUSR | stat.S_IWUSR) + + # Re-encrypt secrets with new key + old_key = self._encryption_key + self._encryption_key = new_encryption_key + self._save_secrets() + + print("✓ Master password changed successfully") + + # Lock if it wasn't unlocked before + if not was_unlocked: + self.lock() + + def _save_secrets(self) -> None: + """Save secrets to encrypted file.""" + if not self.is_unlocked: + raise SecretsStoreLocked("Cannot save while locked") + + # Serialize secrets to JSON + secrets_json = json.dumps(self._secrets, indent=2) + secrets_bytes = 
secrets_json.encode('utf-8') + + # Encrypt + encrypted_data = encrypt_data(secrets_bytes, self._encryption_key) + + # Write to file + self.secrets_file.write_bytes(encrypted_data) + + # Set restrictive permissions + os.chmod(self.secrets_file, stat.S_IRUSR | stat.S_IWUSR) + + def export_encrypted(self, output_path: Path) -> None: + """ + Export encrypted secrets to a file (for backup). + + Args: + output_path: Path to export file + + Raises: + SecretsStoreError: If secrets file doesn't exist + """ + if not self.secrets_file.exists(): + raise SecretsStoreError("No secrets to export") + + import shutil + shutil.copy2(self.secrets_file, output_path) + print(f"✓ Encrypted secrets exported to {output_path}") + + def import_encrypted(self, input_path: Path, master_password: str) -> None: + """ + Import encrypted secrets from a file. + + This will verify the password can decrypt the import before replacing + the current secrets. + + Args: + input_path: Path to import file + master_password: Master password for the current store + + Raises: + InvalidMasterPassword: If password doesn't work with import + """ + if not self.is_unlocked: + self.unlock(master_password) + + # Try to decrypt the imported file with current key + try: + encrypted_data = Path(input_path).read_bytes() + decrypted_data = decrypt_data(encrypted_data, self._encryption_key) + imported_secrets = json.loads(decrypted_data.decode('utf-8')) + except InvalidToken: + raise InvalidMasterPassword( + "Cannot decrypt imported secrets with current master password" + ) + except json.JSONDecodeError as e: + raise SecretsStoreError(f"Corrupted import file: {e}") + + # Replace secrets + self._secrets = imported_secrets + self._save_secrets() + + print(f"✓ Imported {len(self._secrets)} secrets from {input_path}") diff --git a/backend/tests/test_ccxt_datasource.py b/backend/tests/test_ccxt_datasource.py index ffaf335..95f532a 100644 --- a/backend/tests/test_ccxt_datasource.py +++ b/backend/tests/test_ccxt_datasource.py 
@@ -2,8 +2,10 @@ Test script for CCXT DataSource adapter (Free Version). This demonstrates how to use the free CCXT adapter (not ccxt.pro) with various -exchanges. It uses polling instead of WebSocket for real-time updates and -verifies that Decimal precision is maintained throughout. +exchanges. It uses polling instead of WebSocket for real-time updates. + +CCXT is configured to use Decimal mode internally for precision, but OHLCV data +is converted to float for optimal DataFrame/analysis performance. """ import asyncio @@ -73,10 +75,10 @@ async def test_binance_datasource(): print(f" Close: {latest.data['close']} (type: {type(latest.data['close']).__name__})") print(f" Volume: {latest.data['volume']} (type: {type(latest.data['volume']).__name__})") - # Verify Decimal precision - assert isinstance(latest.data['close'], Decimal), "Price should be Decimal type!" - assert isinstance(latest.data['volume'], Decimal), "Volume should be Decimal type!" - print(f" ✓ Numerical precision verified: using Decimal types") + # Verify OHLCV uses float (converted from Decimal for DataFrame performance) + assert isinstance(latest.data['close'], float), "OHLCV price should be float type!" + assert isinstance(latest.data['volume'], float), "OHLCV volume should be float type!" + print(f" ✓ OHLCV data type verified: using native float (CCXT uses Decimal internally)") # Test 5: Polling subscription (brief test) print("\n5. Testing polling-based subscription...") @@ -87,7 +89,7 @@ async def test_binance_datasource(): tick_count[0] += 1 if tick_count[0] == 1: print(f" Received tick: close={data['close']} (type: {type(data['close']).__name__})") - assert isinstance(data['close'], Decimal), "Polled data should use Decimal!" + assert isinstance(data['close'], float), "Polled OHLCV data should use float!" 
subscription_id = await binance.subscribe_bars( symbol="BTC/USDT", diff --git a/deploy/Dockerfile-backend b/deploy/Dockerfile-backend index 2275c3d..950ee0b 100644 --- a/deploy/Dockerfile-backend +++ b/deploy/Dockerfile-backend @@ -1,3 +1,27 @@ FROM python:3.14-alpine -COPY python/src /app/src +# Install TA-Lib C library and build dependencies +RUN apk add --no-cache --virtual .build-deps \ + gcc \ + g++ \ + make \ + musl-dev \ + wget \ + && apk add --no-cache \ + ta-lib \ + && rm -rf /var/cache/apk/* + +# Set working directory +WORKDIR /app + +# Copy requirements first for better caching +COPY backend/requirements.txt /app/requirements.txt + +# Install Python dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Clean up build dependencies +RUN apk del .build-deps + +# Copy application code +COPY backend/src /app/src diff --git a/web/src/App.vue b/web/src/App.vue index 56aa29c..81c32d0 100644 --- a/web/src/App.vue +++ b/web/src/App.vue @@ -13,6 +13,7 @@ import { wsManager } from './composables/useWebSocket' const isAuthenticated = ref(false) const needsConfirmation = ref(false) const authError = ref() +const isDragging = ref(false) let stateSyncCleanup: (() => void) | null = null // Check if we need password confirmation on first load @@ -58,6 +59,21 @@ const handleAuthenticate = async ( } } +onMounted(() => { + // Listen for splitter drag events + document.addEventListener('mousedown', (e) => { + // Check if the mousedown is on a splitter gutter + const target = e.target as HTMLElement + if (target.closest('.p-splitter-gutter')) { + isDragging.value = true + } + }) + + document.addEventListener('mouseup', () => { + isDragging.value = false + }) +}) + onBeforeUnmount(() => { if (stateSyncCleanup) { stateSyncCleanup() @@ -82,6 +98,8 @@ onBeforeUnmount(() => { + +
@@ -90,19 +108,24 @@ onBeforeUnmount(() => { width: 100vw !important; height: 100vh !important; overflow: hidden; - background: var(--p-surface-0); + background: var(--p-surface-0) !important; +} + +.app-container.dark { + background: var(--p-surface-0) !important; } .main-splitter { height: 100vh !important; + background: var(--p-surface-0) !important; } .main-splitter :deep(.p-splitter-gutter) { - background: var(--p-surface-100); + background: var(--p-surface-0) !important; } .main-splitter :deep(.p-splitter-gutter-handle) { - background: var(--p-primary-color); + background: var(--p-surface-400) !important; } .chart-panel, @@ -119,4 +142,15 @@ onBeforeUnmount(() => { display: flex; flex-direction: column; } + +.drag-overlay { + position: fixed; + top: 0; + left: 0; + right: 0; + bottom: 0; + z-index: 9999; + cursor: col-resize; + background: transparent; +} diff --git a/web/src/assets/theme.css b/web/src/assets/theme.css index 3476bb3..bd12e43 100644 --- a/web/src/assets/theme.css +++ b/web/src/assets/theme.css @@ -26,3 +26,16 @@ html, body, #app { overflow: hidden; background-color: var(--p-surface-0) !important; } + +.dark { + background-color: var(--p-surface-0) !important; + color: var(--p-surface-900) !important; +} + +/* Ensure dark background for main containers */ +.app-container, +.main-splitter, +.p-splitter, +.p-splitter-panel { + background-color: var(--p-surface-0) !important; +} diff --git a/web/src/components/ChartView.vue b/web/src/components/ChartView.vue index fe55cfb..e8386b4 100644 --- a/web/src/components/ChartView.vue +++ b/web/src/components/ChartView.vue @@ -191,9 +191,14 @@ onBeforeUnmount(() => { flex-direction: column; overflow: hidden; border: none; + border-radius: 0 !important; background: var(--p-surface-0); } +.chart-card :deep(.p-card) { + border-radius: 0 !important; +} + .chart-card :deep(.p-card-body) { flex: 1; display: flex; diff --git a/web/src/components/ChatPanel.vue b/web/src/components/ChatPanel.vue index 
ab4772d..c450631 100644 --- a/web/src/components/ChatPanel.vue +++ b/web/src/components/ChatPanel.vue @@ -2,6 +2,7 @@ import { ref, onMounted, onUnmounted, computed } from 'vue' import { register } from 'vue-advanced-chat' import Badge from 'primevue/badge' +import Button from 'primevue/button' import { wsManager } from '../composables/useWebSocket' import type { WebSocketMessage } from '../composables/useWebSocket' @@ -17,7 +18,7 @@ const messages = ref([]) const messagesLoaded = ref(false) const isConnected = wsManager.isConnected -// Reactive rooms that update based on WebSocket connection +// Reactive rooms that update based on WebSocket connection and agent processing state const rooms = computed(() => [{ roomId: SESSION_ID, roomName: 'AI Agent', @@ -26,23 +27,29 @@ const rooms = computed(() => [{ { _id: CURRENT_USER_ID, username: 'You' }, { _id: AGENT_ID, username: 'AI Agent', status: { state: isConnected.value ? 'online' : 'offline' } } ], - unreadCount: 0 + unreadCount: 0, + typingUsers: isAgentProcessing.value ? 
[AGENT_ID] : [] }]) // Streaming state let currentStreamingMessageId: string | null = null let streamingBuffer = '' +const isAgentProcessing = ref(false) // Generate message ID const generateMessageId = () => `msg-${Date.now()}-${Math.random().toString(36).substr(2, 9)}` // Handle WebSocket messages const handleMessage = (data: WebSocketMessage) => { + console.log('[ChatPanel] Received message:', data) if (data.type === 'agent_chunk') { + console.log('[ChatPanel] Processing agent_chunk, content:', data.content, 'done:', data.done) const timestamp = new Date().toTimeString().split(' ')[0].slice(0, 5) if (!currentStreamingMessageId) { + console.log('[ChatPanel] Starting new streaming message') // Start new streaming message + isAgentProcessing.value = true currentStreamingMessageId = generateMessageId() streamingBuffer = data.content @@ -54,7 +61,8 @@ const handleMessage = (data: WebSocketMessage) => { date: new Date().toLocaleDateString(), saved: false, distributed: false, - seen: false + seen: false, + files: [] }] } else { // Update existing streaming message @@ -62,10 +70,24 @@ const handleMessage = (data: WebSocketMessage) => { const msgIndex = messages.value.findIndex(m => m._id === currentStreamingMessageId) if (msgIndex !== -1) { - messages.value[msgIndex] = { + const updatedMessage: any = { ...messages.value[msgIndex], content: streamingBuffer } + + // Add plot images if present in metadata + if (data.metadata && data.metadata.plot_urls && Array.isArray(data.metadata.plot_urls)) { + const plotFiles = data.metadata.plot_urls.map((url: string, idx: number) => ({ + name: `plot_${idx + 1}.png`, + size: 0, + type: 'png', + url: `${BACKEND_URL}${url}`, + preview: `${BACKEND_URL}${url}` + })) + updatedMessage.files = plotFiles + } + + messages.value[msgIndex] = updatedMessage messages.value = [...messages.value] } } @@ -74,21 +96,49 @@ const handleMessage = (data: WebSocketMessage) => { // Mark message as complete const msgIndex = messages.value.findIndex(m => 
m._id === currentStreamingMessageId) if (msgIndex !== -1) { - messages.value[msgIndex] = { + const finalMessage: any = { ...messages.value[msgIndex], saved: true, distributed: true, seen: true } + + // Ensure plot images are included in final message + if (data.metadata && data.metadata.plot_urls && Array.isArray(data.metadata.plot_urls)) { + const plotFiles = data.metadata.plot_urls.map((url: string, idx: number) => ({ + name: `plot_${idx + 1}.png`, + size: 0, + type: 'png', + url: `${BACKEND_URL}${url}`, + preview: `${BACKEND_URL}${url}` + })) + finalMessage.files = plotFiles + } + + messages.value[msgIndex] = finalMessage messages.value = [...messages.value] } currentStreamingMessageId = null streamingBuffer = '' + isAgentProcessing.value = false } } } +// Stop agent processing +const stopAgent = () => { + // Send empty message to trigger interrupt without new agent round + const wsMessage = { + type: 'agent_user_message', + session_id: SESSION_ID, + content: '', + attachments: [] + } + wsManager.send(wsMessage) + isAgentProcessing.value = false +} + // Send message handler const sendMessage = async (event: any) => { // Extract data from CustomEvent.detail[0] @@ -191,39 +241,39 @@ const openFile = ({ file }: any) => { } // Theme configuration for dark mode -const chatTheme = 'light' +const chatTheme = 'dark' // Styles to match PrimeVue theme const chatStyles = computed(() => JSON.stringify({ general: { - color: 'var(--p-surface-900)', - colorSpinner: 'var(--p-primary-color)', - borderStyle: '1px solid var(--p-surface-200)' + color: '#cdd6e8', + colorSpinner: '#00d4aa', + borderStyle: '1px solid #263452' }, container: { - background: 'var(--p-surface-0)' + background: '#0a0e1a' }, header: { - background: 'var(--p-surface-50)', - colorRoomName: 'var(--p-surface-900)', - colorRoomInfo: 'var(--p-surface-700)' + background: '#0f1629', + colorRoomName: '#cdd6e8', + colorRoomInfo: '#8892a4' }, footer: { - background: 'var(--p-surface-50)', - borderStyleInput: '1px 
solid var(--p-surface-300)', - backgroundInput: 'var(--p-surface-200)', - colorInput: 'var(--p-surface-900)', - colorPlaceholder: 'var(--p-surface-400)', - colorIcons: 'var(--p-surface-400)' + background: '#0f1629', + borderStyleInput: '1px solid #263452', + backgroundInput: '#161e35', + colorInput: '#cdd6e8', + colorPlaceholder: '#8892a4', + colorIcons: '#8892a4' }, content: { - background: 'var(--p-surface-0)' + background: '#0a0e1a' }, message: { - background: 'var(--p-surface-100)', - backgroundMe: 'var(--p-primary-color)', - color: 'var(--p-surface-900)', - colorMe: 'var(--p-primary-contrast-color)' + background: '#161e35', + backgroundMe: '#00d4aa', + color: '#cdd6e8', + colorMe: '#0a0e1a' } })) @@ -231,6 +281,14 @@ onMounted(() => { wsManager.addHandler(handleMessage) // Mark messages as loaded after initialization messagesLoaded.value = true + + // Focus on the chat input when component mounts + setTimeout(() => { + const chatInput = document.querySelector('.vac-textarea') as HTMLTextAreaElement + if (chatInput) { + chatInput.focus() + } + }, 300) }) onUnmounted(() => { @@ -251,7 +309,7 @@ onUnmounted(() => { --> { :show-emojis="true" :show-reaction-emojis="false" :accepted-files="'image/*,video/*,application/pdf'" + :message-images="true" @send-message="sendMessage" @fetch-messages="fetchMessages" @open-file="openFile" /> + + +
+
@@ -279,8 +349,9 @@ onUnmounted(() => { height: 100% !important; display: flex; flex-direction: column; - background: var(--p-surface-0); + background: var(--p-surface-0) !important; overflow: hidden; + position: relative; } .chat-container :deep(.vac-container) { @@ -306,4 +377,25 @@ onUnmounted(() => { font-weight: 600; color: var(--p-surface-900); } + +.stop-button-container { + position: absolute; + bottom: 80px; + right: 20px; + z-index: 1000; +} + +.stop-button { + box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15); + animation: pulse 2s infinite; +} + +@keyframes pulse { + 0%, 100% { + opacity: 1; + } + 50% { + opacity: 0.8; + } +} diff --git a/web/src/components/LoginScreen.vue b/web/src/components/LoginScreen.vue index 54c1c19..180a2ea 100644 --- a/web/src/components/LoginScreen.vue +++ b/web/src/components/LoginScreen.vue @@ -1,5 +1,5 @@