indicators and plots

This commit is contained in:
2026-03-02 18:34:38 -04:00
parent 3b29096dab
commit 3ffce97b3e
43 changed files with 6690 additions and 878 deletions

View File

@@ -5,7 +5,7 @@ server_port: 8081
agent:
  model: "claude-sonnet-4-20250514"
  temperature: 0.7
-  context_docs_dir: "doc"
+  context_docs_dir: "memory"
  # Local memory configuration (free & sophisticated!)
  memory:

View File

@@ -14,15 +14,7 @@ You are a **strategy authoring assistant**, not a strategy executor. You help us
## Your Capabilities
### State Management
-You have read/write access to synchronized state stores:
+You have read/write access to synchronized state stores. Use your tools to read current state and update it as needed. All state changes are automatically synchronized with connected clients.
-- **OrderStore**: Active swap orders and order configurations
-- **ChartStore**: Current chart view state (symbol, time range, interval)
-  - `symbol`: Trading pair currently being viewed (e.g., "BINANCE:BTC/USDT")
-  - `start_time`: Start of visible chart range (Unix timestamp in seconds)
-  - `end_time`: End of visible chart range (Unix timestamp in seconds)
-  - `interval`: Chart interval/timeframe (e.g., "15", "60", "D")
-- Use your tools to read current state and update it as needed
-- All state changes are automatically synchronized with connected clients
### Strategy Authoring
- Help users express trading intent through conversation
@@ -32,10 +24,9 @@ You have read/write access to synchronized state stores:
- Validate strategy logic for correctness and safety
### Data & Analysis
-- Access to market data through abstract feed specifications
+- Access market data through abstract feed specifications
-- Can compute indicators and perform technical analysis
+- Compute indicators and perform technical analysis
- Understand OHLCV data, order books, and market microstructure
-- Interpret unstructured data (news, sentiment, on-chain metrics)
## Communication Style
@@ -48,7 +39,7 @@ You have read/write access to synchronized state stores:
## Key Principles
1. **Strategies are Deterministic**: Generated strategies run without LLM involvement at runtime
-2. **Local Execution**: The platform runs locally for security; you're design-time only
+2. **Local Execution**: The platform runs locally for security; you are a design-time tool only
3. **Schema Validation**: All outputs must conform to platform schemas
4. **Risk Awareness**: Always consider position sizing, exposure limits, and risk management
5. **Versioning**: Every strategy artifact is version-controlled with full auditability
@@ -56,7 +47,6 @@ You have read/write access to synchronized state stores:
## Your Limitations
- You **DO NOT** execute trades directly
-- You **DO NOT** have access to live market data in real-time (users provide it)
- You **CANNOT** modify the order kernel or execution layer
- You **SHOULD NOT** make assumptions about user risk tolerance without asking
- You **MUST NOT** provide trading or investment advice
@@ -69,53 +59,93 @@ You have access to:
- Past strategy discussions and decisions
- Relevant context retrieved automatically based on current conversation
## Tools Available
### State Management Tools
- `list_sync_stores()`: See available state stores
- `read_sync_state(store_name)`: Read current state
- `write_sync_state(store_name, updates)`: Update state
- `get_store_schema(store_name)`: Inspect state structure
### Data Source Tools
- `list_data_sources()`: List available data sources (exchanges)
- `search_symbols(query, type, exchange, limit)`: Search for trading symbols
- `get_symbol_info(source_name, symbol)`: Get metadata for a symbol
- `get_historical_data(source_name, symbol, resolution, from_time, to_time, countback)`: Get historical bars
- **`get_chart_data(countback)`**: Get data for the chart the user is currently viewing
- This is the **preferred** way to access chart data when analyzing what the user is looking at
- Automatically reads ChartStore to determine symbol, timeframe, and visible range
- Returns OHLCV data plus any custom columns for the visible chart range
- **`analyze_chart_data(python_script, countback)`**: Execute Python analysis on current chart data
- Automatically fetches current chart data and converts to pandas DataFrame
- Execute custom Python scripts with access to pandas, numpy, matplotlib
- Captures matplotlib plots as base64 images for display to user
- Returns result DataFrames and any printed output
- **Use this for technical analysis, indicator calculations, statistical analysis, and visualization**
## Important Behavioral Rules
### Chart Context Awareness
When a user asks about "this chart", "the chart", "what I'm viewing", or similar references to their current view:
-1. **ALWAYS** first use `read_sync_state("ChartStore")` to see what they're viewing
+1. **Chart info is automatically available** — The dynamic system prompt includes current chart state (symbol, interval, timeframe)
2. **NEVER** ask the user to upload an image or tell you what symbol they're looking at
-3. The user is viewing a live trading chart in the UI - you can access what they see via ChartStore
+3. **Just use `execute_python()`** — It automatically loads the chart data from what they're viewing
-4. After reading ChartStore, you can use `get_chart_data()` to get the actual candle data
+4. Inside your Python script, `df` contains the data and `chart_context` has the metadata
-5. For technical analysis questions, use `analyze_chart_data()` with Python scripts
+5. Use `plot_ohlc(df)` to create beautiful candlestick charts
-Examples of questions that require checking ChartStore first:
+This applies to questions like: "Can you see this chart?", "What are the swing highs and lows?", "Is this in an uptrend?", "What's the current price?", "Analyze this chart", "What am I looking at?"
- "Can you see this chart?"
- "What are the swing highs and lows?"
- "Is this in an uptrend?"
- "What's the current price?"
- "Analyze this chart"
- "What am I looking at?"
### Data Analysis Workflow
-1. **Check ChartStore** → Know what the user is viewing
+1. **Chart context is automatic** → Symbol, interval, and timeframe are in the dynamic system prompt
-2. **Get data** with `get_chart_data()` → Fetch the actual OHLCV bars
+2. **Use `execute_python()`** → This is your PRIMARY analysis tool
-3. **Analyze** with `analyze_chart_data()` → Run Python analysis if needed
+- Automatically loads chart data into a pandas DataFrame `df`
-4. **Respond** with insights based on the actual data
+- Pre-imports numpy (`np`), pandas (`pd`), matplotlib (`plt`), and talib
- Provides access to the indicator registry for computing indicators
- Use `plot_ohlc(df)` helper for beautiful candlestick charts
3. **Only use `get_chart_data()`** → For simple data inspection without analysis
### Python Analysis (`execute_python`) - Your Primary Tool
**ALWAYS use `execute_python()` when the user asks for:**
- Technical indicators (RSI, MACD, Bollinger Bands, moving averages, etc.)
- Chart visualizations or plots
- Statistical calculations or market analysis
- Pattern detection or trend analysis
- Any computational analysis of price data
**Why `execute_python()` is preferred:**
- Chart data (`df`) is automatically loaded from ChartStore (visible time range)
- Full pandas/numpy/talib stack pre-imported
- Use `plot_ohlc(df)` for instant professional candlestick charts
- Access to 150+ indicators via `indicator_registry`
- **Results include plots as image URLs** that are automatically displayed to the user
- Prints and return values are included in the response
**CRITICAL: Plots are automatically shown to the user**
When you create a matplotlib figure (via `plot_ohlc()` or `plt.figure()`), it is automatically:
1. Saved as a PNG image
2. Returned in the response as a URL (e.g., `/uploads/plot_abc123.png`)
3. **Displayed in the user's chat interface** - they see the image immediately
You MUST use `execute_python()` with `plot_ohlc()` or matplotlib whenever the user wants to see a chart or plot.
**IMPORTANT: Never use `get_historical_data()` for chart analysis**
- `get_historical_data()` requires manual timestamp calculation and is only for custom queries
- When analyzing what the user is viewing, ALWAYS use `execute_python()` which automatically loads the correct data
- The `df` DataFrame in `execute_python()` is pre-loaded with the exact time range the user is viewing
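For contrast, a minimal sketch of the manual timestamp bookkeeping that `get_historical_data()` pushes onto the caller (Unix timestamps in seconds, per its signature); the 24-hour window here is just an illustrative choice:

```python
import time

# Manual window computation that get_historical_data() would need:
# the last 24 hours as Unix timestamps in seconds, plus a bar count.
to_time = int(time.time())
from_time = to_time - 24 * 60 * 60  # 24 hours ago
countback = 24  # expected number of 1-hour bars in that window
print(from_time, to_time, countback)
```

`execute_python()` skips all of this because `df` already covers the visible range.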
**Example workflows:**
```python
# Computing an indicator and plotting
execute_python("""
df['RSI'] = talib.RSI(df['close'], 14)
fig = plot_ohlc(df, title='Price with RSI')
df[['close', 'RSI']].tail(10)
""")
# Multi-indicator analysis
execute_python("""
df['SMA20'] = df['close'].rolling(20).mean()
df['BB_upper'] = df['close'].rolling(20).mean() + 2 * df['close'].rolling(20).std()
df['BB_lower'] = df['close'].rolling(20).mean() - 2 * df['close'].rolling(20).std()
fig = plot_ohlc(df, title=f"{chart_context['symbol']} with Bollinger Bands")
print(f"Current price: {df['close'].iloc[-1]:.2f}")
print(f"20-period SMA: {df['SMA20'].iloc[-1]:.2f}")
""")
```
**Only use `get_chart_data()` for:**
- Quick inspection of raw bar data
- When you just need the data structure without analysis
### Quick Reference: Common Tasks
| User Request | Tool to Use | Example |
|--------------|-------------|---------|
| "Show me RSI" | `execute_python()` | `df['RSI'] = talib.RSI(df['close'], 14); plot_ohlc(df)` |
| "What's the current price?" | `execute_python()` | `print(f"Current: {df['close'].iloc[-1]}")` |
| "Is this bullish?" | `execute_python()` | Compute SMAs, trend, and analyze |
| "Add Bollinger Bands" | `execute_python()` | Compute bands, use `plot_ohlc(df, title='BB')` |
| "Find swing highs" | `execute_python()` | Use pandas logic to detect patterns |
| "What indicators exist?" | `search_indicators()` | Search by category or query |
| "What chart am I viewing?" | N/A - automatic | Chart info is in dynamic system prompt |
| "Read other stores" | `read_sync_state(store_name)` | For TraderState, StrategyState, etc. |
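The "Find swing highs" row above can be sketched with plain pandas; this is a hedged sketch assuming `df` carries `high`/`low` columns as described (a small hypothetical DataFrame stands in for the pre-loaded one):

```python
import pandas as pd

# Hypothetical bars; inside execute_python(), `df` would be pre-loaded.
df = pd.DataFrame({
    "high": [10, 12, 11, 13, 15, 14, 13, 16, 15, 14],
    "low":  [9, 10, 8, 11, 13, 12, 10, 14, 13, 12],
})

# A swing high (low) is a bar whose high (low) exceeds (undercuts) both neighbors.
is_swing_high = (df["high"] > df["high"].shift(1)) & (df["high"] > df["high"].shift(-1))
is_swing_low = (df["low"] < df["low"].shift(1)) & (df["low"] < df["low"].shift(-1))

print("swing highs at bars:", df.index[is_swing_high].tolist())
print("swing lows at bars:", df.index[is_swing_low].tolist())
```

A wider neighborhood (more than one bar on each side) would use rolling windows instead of single shifts; the one-bar version keeps the idea visible.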
## Working with Users
@@ -124,51 +154,3 @@ Examples of questions that require checking ChartStore first:
3. **Validate**: Ensure strategy makes sense before generating code
4. **Test**: Encourage backtesting and paper trading first
5. **Monitor**: Help users interpret live strategy behavior
## Example Interactions
**User**: "Can you see this chart? What are the swing highs and lows?"
**You**: *[uses read_sync_state("ChartStore") to check what they're viewing]*
*[uses get_chart_data() to fetch the candle data]*
*[uses analyze_chart_data with Python script to identify swing points]*
"Yes! You're viewing BINANCE:BTC/USDT on the 15-minute chart. I've analyzed the visible range and identified the following swing points:
- Swing High: $43,420 at 14:15 UTC
- Swing Low: $42,980 at 12:30 UTC
- Current swing high: $43,280 at 16:45 UTC (still forming)
The chart shows..."
**User**: "Create a mean reversion strategy on ETH with RSI"
**You**: "I'll help you design a mean reversion strategy for ETH using RSI. A few questions:
- What RSI levels indicate oversold/overbought? (e.g., 30/70)
- What timeframe? (e.g., 1h, 15m)
- Position sizing as % of balance?
- Stop loss and take profit levels?"
**User**: "What am I looking at?"
**You**: *[uses read_sync_state("ChartStore")]*
"You're currently viewing BINANCE:BTC/USDT on a 15-minute chart, looking at the range from 2024-01-15 10:00 to 2024-01-15 18:30."
**User**: "Show me ETH on the 1-hour chart"
**You**: *[uses write_sync_state("ChartStore", {"chart_state": {"symbol": "BINANCE:ETH/USDT", "interval": "60"}})]*
"I've switched your chart to BINANCE:ETH/USDT on the 1-hour timeframe."
**User**: "What's the current price?"
**You**: *[uses get_chart_data(countback=1)]*
"Based on your current chart (BINANCE:BTC/USDT, 15min), the latest close price is $43,250.50 as of 14:30 UTC."
**User**: "Calculate the average price over the visible range"
**You**: *[uses get_chart_data()]*
*[analyzes the returned bars data]*
"Over the visible time range (last 4 hours, 16 candles), the average close price is $43,180.25, with a high of $43,420 and low of $42,980."
**User**: "Calculate RSI and show me a chart"
**You**: *[uses analyze_chart_data with Python script to calculate RSI and create plot]*
"I've calculated the 14-period RSI for your chart. The current RSI is 58.3, indicating neutral momentum. Here's the chart showing price and RSI over the visible range." *[image displayed to user]*
**User**: "Is this in an uptrend?"
**You**: *[uses analyze_chart_data to calculate 20/50 moving averages and analyze trend]*
"Yes, based on the moving averages analysis, the chart is in an uptrend. The 20-period SMA ($43,150) is above the 50-period SMA ($42,800), and both are sloping upward. Price is currently trading above both averages."
---
Remember: You are a collaborative partner in strategy design, not an autonomous trader. Always prioritize safety, clarity, and user intent.

View File

@@ -4,6 +4,7 @@ pandas
numpy
scipy
matplotlib
mplfinance
fastapi
uvicorn
websockets
@@ -11,6 +12,7 @@ jsonpatch
python-multipart
ccxt>=4.0.0
pyyaml
TA-Lib>=0.4.0
# LangChain agent dependencies
langchain>=0.3.0
@@ -19,6 +21,11 @@ langgraph-checkpoint-sqlite>=1.0.0
langchain-anthropic>=0.3.0
langchain-community>=0.3.0
# Additional tools for research and web access
arxiv>=2.0.0
duckduckgo-search>=7.0.0
requests>=2.31.0
# Local memory system
chromadb>=0.4.0
sentence-transformers>=2.0.0

View File

@@ -1,3 +1,4 @@
-from agent.core import create_agent
+# Don't import at module level to avoid circular imports
+# Users should import directly: from agent.core import create_agent
-__all__ = ["create_agent"]
+__all__ = ["core", "tools"]

View File

@@ -7,7 +7,7 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import create_react_agent
-from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS
+from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS, INDICATOR_TOOLS, RESEARCH_TOOLS, CHART_TOOLS
from agent.memory import MemoryManager
from agent.session import SessionManager
from agent.prompts import build_system_prompt
@@ -60,17 +60,15 @@ class AgentExecutor:
"""Initialize the agent system."""
await self.memory_manager.initialize()
-# Create agent with tools and LangGraph checkpointing
+# Create agent with tools and LangGraph checkpointer
checkpointer = self.memory_manager.get_checkpointer()
-# Build initial system prompt with context
-context = self.memory_manager.get_context_prompt()
-system_prompt = build_system_prompt(context, [])
+# Create agent without a static system prompt
+# We'll pass the dynamic system prompt via state_modifier at runtime
+# Include all tool categories: sync, datasource, chart, indicator, and research
self.agent = create_react_agent(
    self.llm,
-    SYNC_TOOLS + DATASOURCE_TOOLS,
+    SYNC_TOOLS + DATASOURCE_TOOLS + CHART_TOOLS + INDICATOR_TOOLS + RESEARCH_TOOLS,
-    prompt=system_prompt,
    checkpointer=checkpointer
)
@@ -101,26 +99,6 @@ class AgentExecutor:
except Exception as e:
    logger.warning(f"Failed to clear checkpoint for session {session_id}: {e}")
def _build_system_message(self, state: Dict[str, Any]) -> SystemMessage:
"""Build system message with context.
Args:
state: Agent state
Returns:
SystemMessage with full context
"""
# Get context from loaded documents
context = self.memory_manager.get_context_prompt()
# Get active channels from metadata
active_channels = state.get("metadata", {}).get("active_channels", [])
# Build system prompt
system_prompt = build_system_prompt(context, active_channels)
return SystemMessage(content=system_prompt)
async def execute(
    self,
    session: UserSession,
@@ -143,7 +121,12 @@ class AgentExecutor:
async with lock:
    try:
-        # Build message history
+        # Build system prompt with current context
+        context = self.memory_manager.get_context_prompt()
+        system_prompt = build_system_prompt(context, session.active_channels)
+        # Build message history WITHOUT prepending system message
+        # The system prompt will be passed via state_modifier in the config
        messages = []
        history = session.get_history(limit=10)
        logger.info(f"Building message history, {len(history)} messages in history")
@@ -155,14 +138,18 @@ class AgentExecutor:
elif msg.role == "assistant":
    messages.append(AIMessage(content=msg.content))
-logger.info(f"Prepared {len(messages)} messages for agent")
+logger.info(f"Prepared {len(messages)} messages for agent (including system prompt)")
for i, msg in enumerate(messages):
-    logger.info(f"LangChain message {i}: type={type(msg).__name__}, content_len={len(msg.content)}, content='{msg.content[:100] if msg.content else 'EMPTY'}'")
+    msg_type = type(msg).__name__
+    content_preview = msg.content[:100] if msg.content else 'EMPTY'
+    logger.info(f"LangChain message {i}: type={msg_type}, content_len={len(msg.content)}, content='{content_preview}'")
-# Prepare config with metadata
+# Prepare config with metadata and dynamic system prompt
+# Pass system_prompt via state_modifier to avoid multiple system messages
config = RunnableConfig(
    configurable={
-        "thread_id": session.session_id
+        "thread_id": session.session_id,
+        "state_modifier": system_prompt  # Dynamic system prompt injection
    },
    metadata={
        "session_id": session.session_id,
@@ -178,6 +165,8 @@ class AgentExecutor:
event_count = 0
chunk_count = 0
+plot_urls = []  # Accumulate plot URLs from execute_python tool calls
async for event in self.agent.astream_events(
    {"messages": messages},
    config=config,
@@ -199,7 +188,35 @@ class AgentExecutor:
elif event["event"] == "on_tool_end":
    tool_name = event.get("name", "unknown")
    tool_output = event.get("data", {}).get("output")
-    logger.info(f"Tool call completed: {tool_name} with output: {tool_output}")
+    # LangChain may wrap the output in a ToolMessage with content field
+    # Try to extract the actual content from the ToolMessage
+    actual_output = tool_output
+    if hasattr(tool_output, "content"):
+        actual_output = tool_output.content
+    logger.info(f"Tool call completed: {tool_name} with output type: {type(actual_output)}")
+    # Extract plot_urls from execute_python tool results
+    if tool_name == "execute_python":
+        # Try to parse as JSON if it's a string
+        import json
+        if isinstance(actual_output, str):
+            try:
+                actual_output = json.loads(actual_output)
+            except (json.JSONDecodeError, ValueError):
+                logger.warning(f"Could not parse execute_python output as JSON: {actual_output[:200]}")
+        if isinstance(actual_output, dict):
+            tool_plot_urls = actual_output.get("plot_urls", [])
+            if tool_plot_urls:
+                logger.info(f"execute_python generated {len(tool_plot_urls)} plots: {tool_plot_urls}")
+                plot_urls.extend(tool_plot_urls)
+                # Yield metadata about plots immediately
+                yield {
+                    "content": "",
+                    "metadata": {"plot_urls": tool_plot_urls}
+                }
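The extraction logic in the hunk above can be exercised in isolation. A minimal sketch, assuming `execute_python` returns a JSON object with an optional `plot_urls` list as the diff suggests (the helper name is ours, not the codebase's):

```python
import json

def extract_plot_urls(tool_output):
    """Hypothetical standalone version of the plot-URL extraction above."""
    # LangChain may wrap output in a ToolMessage-like object with .content
    actual = getattr(tool_output, "content", tool_output)
    if isinstance(actual, str):
        try:
            actual = json.loads(actual)
        except (json.JSONDecodeError, ValueError):
            return []
    if isinstance(actual, dict):
        return actual.get("plot_urls", [])
    return []

print(extract_plot_urls('{"plot_urls": ["/uploads/plot_abc123.png"]}'))
print(extract_plot_urls("not json"))
```

Returning an empty list on unparseable output mirrors the diff's behavior of logging a warning and moving on rather than failing the stream.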
# Extract streaming tokens
elif event["event"] == "on_chat_model_stream":

View File

@@ -1,7 +1,54 @@
-from typing import List
+from typing import List, Dict, Any
from gateway.user_session import UserSession
def _get_chart_store_context() -> str:
    """Get current ChartStore state for context injection.

    Returns:
        Formatted string with ChartStore contents, or empty string if unavailable
    """
    try:
        from agent.tools import _registry
        if not _registry:
            return ""
        chart_store = _registry.entries.get("ChartStore")
        if not chart_store:
            return ""
        chart_state = chart_store.model.model_dump(mode="json")
        chart_data = chart_state.get("chart_state", {})
        # Only include if there's actual chart data
        if not chart_data or not chart_data.get("symbol"):
            return ""
        # Format the chart information
        symbol = chart_data.get("symbol", "N/A")
        interval = chart_data.get("interval", "N/A")
        start_time = chart_data.get("start_time")
        end_time = chart_data.get("end_time")
        chart_context = f"""
## Current Chart Context
The user is currently viewing a chart with the following settings:
- **Symbol**: {symbol}
- **Interval**: {interval}
- **Time Range**: {f"from {start_time} to {end_time}" if start_time and end_time else "not set"}
This information is automatically available because you're connected via websocket.
When the user refers to "the chart", "this chart", or "what I'm viewing", this is what they mean.
"""
        return chart_context
    except Exception:
        # Silently fail - chart context is optional enhancement
        return ""
def build_system_prompt(context: str, active_channels: List[str]) -> str:
    """Build the system prompt for the agent.
@@ -17,6 +64,15 @@ def build_system_prompt(context: str, active_channels: List[str]) -> str:
    """
    channels_str = ", ".join(active_channels) if active_channels else "none"
    # Check if user is connected via websocket - if so, inject chart context
    # Note: We check for websocket by looking for "websocket" in channel IDs
    # since WebSocketChannel uses channel_id like "websocket-{uuid}"
    has_websocket = any("websocket" in channel_id.lower() for channel_id in active_channels)
    chart_context = ""
    if has_websocket:
        chart_context = _get_chart_store_context()
    # Context already includes system_prompt.md and other docs
    # Just add current session information
    prompt = f"""{context}
@@ -28,7 +84,7 @@ def build_system_prompt(context: str, active_channels: List[str]) -> str:
Your responses will be sent to all active channels. Your responses are streamed back in real-time.
If the user sends a new message while you're responding, your current response will be interrupted
and you'll be re-invoked with the updated context.
-"""
+{chart_context}"""
    return prompt
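The websocket detection described in the comments above reduces to a substring check over channel IDs; a minimal sketch with hypothetical channel IDs:

```python
# WebSocketChannel uses channel_ids like "websocket-{uuid}" (per the comment
# above), so presence of a websocket client is a simple substring check.
active_channels = ["websocket-3f1c9a2b", "telegram-bot"]  # hypothetical IDs
has_websocket = any("websocket" in channel_id.lower() for channel_id in active_channels)
channels_str = ", ".join(active_channels) if active_channels else "none"
print(has_websocket, "|", channels_str)
```

A substring check is forgiving of ID format changes but would also match any channel that merely contains the word "websocket"; an exact prefix match would be stricter.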

View File

@@ -1,662 +0,0 @@
from typing import Dict, Any, List, Optional
import io
import base64
import sys
import uuid
import logging
from pathlib import Path
from contextlib import redirect_stdout, redirect_stderr
from langchain_core.tools import tool

logger = logging.getLogger(__name__)

# Global registry instance (will be set by main.py)
_registry = None
_datasource_registry = None


def set_registry(registry):
    """Set the global SyncRegistry instance for tools to use."""
    global _registry
    _registry = registry


def set_datasource_registry(datasource_registry):
    """Set the global DataSourceRegistry instance for tools to use."""
    global _datasource_registry
    _datasource_registry = datasource_registry
@tool
def list_sync_stores() -> List[str]:
    """List all available synchronization stores.

    Returns:
        List of store names that can be read/written
    """
    if not _registry:
        return []
    return list(_registry.entries.keys())
@tool
def read_sync_state(store_name: str) -> Dict[str, Any]:
    """Read the current state of a synchronization store.

    Args:
        store_name: Name of the store to read (e.g., "TraderState", "StrategyState")

    Returns:
        Dictionary containing the current state of the store

    Raises:
        ValueError: If store_name doesn't exist
    """
    if not _registry:
        raise ValueError("SyncRegistry not initialized")
    entry = _registry.entries.get(store_name)
    if not entry:
        available = list(_registry.entries.keys())
        raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
    return entry.model.model_dump(mode="json")
@tool
async def write_sync_state(store_name: str, updates: Dict[str, Any]) -> Dict[str, str]:
    """Update the state of a synchronization store.

    This will apply the updates to the store and trigger synchronization
    with all connected clients.

    Args:
        store_name: Name of the store to update
        updates: Dictionary of field updates (field_name: new_value)

    Returns:
        Dictionary with status and updated fields

    Raises:
        ValueError: If store_name doesn't exist or updates are invalid
    """
    if not _registry:
        raise ValueError("SyncRegistry not initialized")
    entry = _registry.entries.get(store_name)
    if not entry:
        available = list(_registry.entries.keys())
        raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
    try:
        # Get current state
        current_state = entry.model.model_dump(mode="json")
        # Apply updates
        new_state = {**current_state, **updates}
        # Update the model
        _registry._update_model(entry.model, new_state)
        # Trigger sync
        await _registry.push_all()
        return {
            "status": "success",
            "store": store_name,
            "updated_fields": list(updates.keys())
        }
    except Exception as e:
        raise ValueError(f"Failed to update store '{store_name}': {str(e)}")
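The `{**current_state, **updates}` line above is a shallow merge: top-level keys in `updates` replace entire values in `current_state`, including nested dicts. A small sketch of that behavior (store contents are hypothetical):

```python
current_state = {
    "chart_state": {"symbol": "BINANCE:BTC/USDT", "interval": "15"},
    "theme": "dark",
}
updates = {"chart_state": {"symbol": "BINANCE:ETH/USDT", "interval": "60"}}

# Shallow merge: the whole chart_state dict is replaced, theme is untouched.
new_state = {**current_state, **updates}
print(new_state)
```

Callers therefore need to send the complete nested object when updating any part of it; a partial nested dict would drop the sibling keys.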
@tool
def get_store_schema(store_name: str) -> Dict[str, Any]:
    """Get the schema/structure of a synchronization store.

    This shows what fields are available and their types.

    Args:
        store_name: Name of the store

    Returns:
        Dictionary describing the store's schema

    Raises:
        ValueError: If store_name doesn't exist
    """
    if not _registry:
        raise ValueError("SyncRegistry not initialized")
    entry = _registry.entries.get(store_name)
    if not entry:
        available = list(_registry.entries.keys())
        raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
    # Get model schema
    schema = entry.model.model_json_schema()
    return {
        "store_name": store_name,
        "schema": schema
    }
# DataSource tools
@tool
def list_data_sources() -> List[str]:
    """List all available data sources.

    Returns:
        List of data source names that can be queried for market data
    """
    if not _datasource_registry:
        return []
    return _datasource_registry.list_sources()
@tool
async def search_symbols(
    query: str,
    type: Optional[str] = None,
    exchange: Optional[str] = None,
    limit: int = 30,
) -> Dict[str, Any]:
    """Search for trading symbols across all data sources.

    Automatically searches all available data sources and returns aggregated results.
    Use this to find symbols before calling get_symbol_info or get_historical_data.

    Args:
        query: Search query (e.g., "BTC", "AAPL", "EUR")
        type: Optional filter by instrument type (e.g., "crypto", "stock", "forex")
        exchange: Optional filter by exchange (e.g., "binance", "nasdaq")
        limit: Maximum number of results per source (default: 30)

    Returns:
        Dictionary mapping source names to lists of matching symbols.
        Each symbol includes: symbol, full_name, description, exchange, type.
        Use the source name and symbol from results with get_symbol_info or get_historical_data.

    Example response:
        {
            "demo": [
                {
                    "symbol": "BTC/USDT",
                    "full_name": "Bitcoin / Tether USD",
                    "description": "Bitcoin perpetual futures",
                    "exchange": "demo",
                    "type": "crypto"
                }
            ]
        }
    """
    if not _datasource_registry:
        raise ValueError("DataSourceRegistry not initialized")
    # Always search all sources
    results = await _datasource_registry.search_all(query, type, exchange, limit)
    return {name: [r.model_dump() for r in matches] for name, matches in results.items()}
@tool
async def get_symbol_info(source_name: str, symbol: str) -> Dict[str, Any]:
    """Get complete metadata for a trading symbol.

    This retrieves full information about a symbol including:
    - Description and type
    - Supported time resolutions
    - Available data columns (OHLCV, volume, funding rates, etc.)
    - Trading session information
    - Price scale and precision

    Args:
        source_name: Name of the data source (use list_data_sources to see available)
        symbol: Symbol identifier (e.g., "BTC/USDT", "AAPL", "EUR/USD")

    Returns:
        Dictionary containing complete symbol metadata including column schema

    Raises:
        ValueError: If source_name or symbol is not found
    """
    if not _datasource_registry:
        raise ValueError("DataSourceRegistry not initialized")
    symbol_info = await _datasource_registry.resolve_symbol(source_name, symbol)
    return symbol_info.model_dump()
@tool
async def get_historical_data(
    source_name: str,
    symbol: str,
    resolution: str,
    from_time: int,
    to_time: int,
    countback: Optional[int] = None,
) -> Dict[str, Any]:
    """Get historical bar/candle data for a symbol.

    Retrieves time-series data between the specified timestamps. The data
    includes all columns defined for the symbol (OHLCV + any custom columns).

    Args:
        source_name: Name of the data source
        symbol: Symbol identifier
        resolution: Time resolution (e.g., "1" = 1min, "5" = 5min, "60" = 1hour, "1D" = 1day)
        from_time: Start time as Unix timestamp in seconds
        to_time: End time as Unix timestamp in seconds
        countback: Optional limit on number of bars to return

    Returns:
        Dictionary containing:
        - symbol: The requested symbol
        - resolution: The time resolution
        - bars: List of bar data with 'time' and 'data' fields
        - columns: Schema describing available data columns
        - nextTime: If present, indicates more data is available for pagination

    Raises:
        ValueError: If source, symbol, or resolution is invalid

    Example:
        # Get 1-hour BTC data for the last 24 hours
        import time
        to_time = int(time.time())
        from_time = to_time - 86400  # 24 hours ago
        data = get_historical_data("demo", "BTC/USDT", "60", from_time, to_time)
    """
    if not _datasource_registry:
        raise ValueError("DataSourceRegistry not initialized")
    source = _datasource_registry.get(source_name)
    if not source:
        available = _datasource_registry.list_sources()
        raise ValueError(f"Data source '{source_name}' not found. Available sources: {available}")
    result = await source.get_bars(symbol, resolution, from_time, to_time, countback)
    return result.model_dump()
async def _get_chart_data_impl(countback: Optional[int] = None):
"""Internal implementation for getting chart data.
This is a helper function that can be called by both get_chart_data tool
and analyze_chart_data tool.
Returns:
Tuple of (HistoryResult, chart_context dict, source_name)
"""
if not _registry:
raise ValueError("SyncRegistry not initialized - cannot read ChartStore")
if not _datasource_registry:
raise ValueError("DataSourceRegistry not initialized - cannot query data")
# Read current chart state
chart_store = _registry.entries.get("ChartStore")
if not chart_store:
raise ValueError("ChartStore not found in registry")
chart_state = chart_store.model.model_dump(mode="json")
chart_data = chart_state.get("chart_state", {})
symbol = chart_data.get("symbol", "")
interval = chart_data.get("interval", "15")
start_time = chart_data.get("start_time")
end_time = chart_data.get("end_time")
if not symbol:
raise ValueError("No symbol set in ChartStore - user may not have loaded a chart yet")
# Parse the symbol to extract exchange/source and symbol name
# Format is "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD")
if ":" not in symbol:
raise ValueError(
f"Invalid symbol format: '{symbol}'. Expected format is 'EXCHANGE:SYMBOL' "
f"(e.g., 'BINANCE:BTC/USDT' or 'DEMO:BTC/USD')"
)
exchange_prefix, symbol_name = symbol.split(":", 1)
source_name = exchange_prefix.lower()
# Get the data source
source = _datasource_registry.get(source_name)
if not source:
available = _datasource_registry.list_sources()
raise ValueError(
f"Data source '{source_name}' not found. Available sources: {available}. "
f"Make sure the exchange in the symbol '{symbol}' matches an available source."
)
# Determine time range - REQUIRE it to be set, no defaults
if start_time is None or end_time is None:
raise ValueError(
f"Chart time range not set in ChartStore. start_time={start_time}, end_time={end_time}. "
f"The user needs to load the chart first, or the frontend may not be sending the visible range. "
f"Wait for the chart to fully load before analyzing data."
)
    from_time = int(start_time)
    to_time = int(end_time)
    logger.info(
        f"Using ChartStore time range: from_time={from_time}, to_time={to_time}, "
        f"countback={countback}"
    )
    logger.info(
        f"Querying data source '{source_name}' for symbol '{symbol_name}', "
        f"resolution '{interval}'"
    )
    # Query the data source
    result = await source.get_bars(
        symbol=symbol_name,
        resolution=interval,
        from_time=from_time,
        to_time=to_time,
        countback=countback
    )
logger.info(
f"Received {len(result.bars)} bars from data source. "
f"First bar time: {result.bars[0].time if result.bars else 'N/A'}, "
f"Last bar time: {result.bars[-1].time if result.bars else 'N/A'}"
)
# Build chart context to return along with result
chart_context = {
"symbol": symbol,
"interval": interval,
"start_time": start_time,
"end_time": end_time
}
return result, chart_context, source_name
@tool
async def get_chart_data(countback: Optional[int] = None) -> Dict[str, Any]:
"""Get the candle/bar data for what the user is currently viewing on their chart.
This is a convenience tool that automatically:
1. Reads the ChartStore to see what chart the user is viewing
2. Parses the symbol to determine the data source (exchange prefix)
3. Queries the appropriate data source for that symbol's data
4. Returns the data for the visible time range and interval
This is the preferred way to access chart data when helping the user analyze
what they're looking at, since it automatically uses their current chart context.
Args:
countback: Optional limit on number of bars to return. If not specified,
returns all bars in the visible time range.
Returns:
Dictionary containing:
- chart_context: Current chart state (symbol, interval, time range)
- symbol: The trading pair being viewed
- resolution: The chart interval
- bars: List of bar data with 'time' and 'data' fields
- columns: Schema describing available data columns
- source: Which data source was used
Raises:
ValueError: If ChartStore or DataSourceRegistry is not initialized,
or if the symbol format is invalid
Example:
# User is viewing BINANCE:BTC/USDT on 15min chart
data = get_chart_data()
# Returns BTC/USDT data from binance source at 15min resolution
# for the currently visible time range
"""
result, chart_context, source_name = await _get_chart_data_impl(countback)
# Return enriched result with chart context
response = result.model_dump()
response["chart_context"] = chart_context
response["source"] = source_name
return response
@tool
async def analyze_chart_data(python_script: str, countback: Optional[int] = None) -> Dict[str, Any]:
"""Analyze the current chart data using a Python script with pandas and matplotlib.
This tool:
1. Gets the current chart data (same as get_chart_data)
2. Converts it to a pandas DataFrame with columns: time, open, high, low, close, volume
3. Executes your Python script with access to the DataFrame as 'df'
4. Saves any matplotlib plots to disk and returns URLs to access them
5. Returns any final DataFrame result and plot URLs
The script has access to:
- `df`: pandas DataFrame with OHLCV data indexed by datetime
- `pandas` (as `pd`): For data manipulation
- `numpy` (as `np`): For numerical operations
- `matplotlib.pyplot` (as `plt`): For plotting (use plt.figure() for each plot)
All matplotlib figures are automatically saved to disk and accessible via URLs.
The last expression in the script (if it's a DataFrame) is returned as the result.
Args:
python_script: Python code to execute. The DataFrame is available as 'df'.
Can use pandas, numpy, matplotlib. Return a DataFrame to include it in results.
countback: Optional limit on number of bars to analyze
Returns:
Dictionary containing:
- chart_context: Current chart state (symbol, interval, time range)
- source: Data source used
- script_output: Any printed output from the script
- result_dataframe: If script returns a DataFrame, it's included here as dict
- plot_urls: List of URLs to saved plot images (one per plt.figure())
- error: Error message if script execution failed
Example scripts:
# Calculate 20-period SMA and plot
```python
df['SMA20'] = df['close'].rolling(20).mean()
plt.figure(figsize=(12, 6))
plt.plot(df.index, df['close'], label='Close')
plt.plot(df.index, df['SMA20'], label='SMA20')
plt.legend()
plt.title('Price with SMA')
df[['close', 'SMA20']].tail(10) # Return last 10 rows
```
# Calculate RSI
```python
delta = df['close'].diff()
gain = (delta.where(delta > 0, 0)).rolling(14).mean()
loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
rs = gain / loss
df['RSI'] = 100 - (100 / (1 + rs))
df[['close', 'RSI']].tail(20)
```
# Multiple plots
```python
# Price chart
plt.figure(figsize=(12, 4))
plt.plot(df['close'])
plt.title('Price')
# Volume chart
plt.figure(figsize=(12, 3))
plt.bar(df.index, df['volume'])
plt.title('Volume')
df.describe() # Return statistics
```
"""
if not _registry:
raise ValueError("SyncRegistry not initialized - cannot read ChartStore")
if not _datasource_registry:
raise ValueError("DataSourceRegistry not initialized - cannot query data")
try:
# Import pandas and numpy here to allow lazy loading
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg') # Non-interactive backend
import matplotlib.pyplot as plt
except ImportError as e:
raise ValueError(
f"Required library not installed: {e}. "
"Please install pandas, numpy, and matplotlib: pip install pandas numpy matplotlib"
)
# Get chart data using the internal helper function
result, chart_context, source_name = await _get_chart_data_impl(countback)
# Build the same response format as get_chart_data
chart_data = result.model_dump()
chart_data["chart_context"] = chart_context
chart_data["source"] = source_name
# Convert bars to DataFrame
bars = chart_data.get('bars', [])
if not bars:
return {
"chart_context": chart_data.get('chart_context', {}),
"source": chart_data.get('source', ''),
"error": "No data available for the current chart"
}
# Build DataFrame
rows = []
for bar in bars:
row = {
'time': pd.to_datetime(bar['time'], unit='s'),
**bar['data'] # Includes open, high, low, close, volume, etc.
}
rows.append(row)
df = pd.DataFrame(rows)
df.set_index('time', inplace=True)
# Convert price columns to float for clean numeric operations
price_columns = ['open', 'high', 'low', 'close', 'volume']
for col in price_columns:
if col in df.columns:
df[col] = pd.to_numeric(df[col], errors='coerce')
logger.info(
f"Created DataFrame with {len(df)} rows, columns: {df.columns.tolist()}, "
f"time range: {df.index.min()} to {df.index.max()}, "
f"dtypes: {df.dtypes.to_dict()}"
)
# Prepare execution environment
script_globals = {
'df': df,
'pd': pd,
'np': np,
'plt': plt,
}
# Capture stdout/stderr
stdout_capture = io.StringIO()
stderr_capture = io.StringIO()
result_df = None
error_msg = None
plot_urls = []
# Determine uploads directory (relative to this file)
uploads_dir = Path(__file__).parent.parent.parent / "data" / "uploads"
uploads_dir.mkdir(parents=True, exist_ok=True)
try:
with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
# Execute the script
exec(python_script, script_globals)
# Check if the last line is an expression that returns a DataFrame
# We'll try to evaluate it separately
script_lines = python_script.strip().split('\n')
if script_lines:
last_line = script_lines[-1].strip()
# Only evaluate if it doesn't look like a statement
if last_line and not any(last_line.startswith(kw) for kw in ['if', 'for', 'while', 'def', 'class', 'import', 'from', 'with', 'try', 'return']):
try:
last_result = eval(last_line, script_globals)
if isinstance(last_result, pd.DataFrame):
result_df = last_result
                    except Exception:
# If eval fails, that's okay - might not be an expression
pass
# Save all matplotlib figures to disk
for fig_num in plt.get_fignums():
fig = plt.figure(fig_num)
# Generate unique filename
plot_id = str(uuid.uuid4())
filename = f"plot_{plot_id}.png"
filepath = uploads_dir / filename
# Save figure to file
fig.savefig(filepath, format='png', bbox_inches='tight', dpi=100)
# Generate URL that can be accessed via the web server
            plot_url = f"/uploads/{filename}"
plot_urls.append(plot_url)
plt.close(fig)
except Exception as e:
error_msg = f"{type(e).__name__}: {str(e)}"
import traceback
error_msg += f"\n{traceback.format_exc()}"
# Build response
response = {
"chart_context": chart_data.get('chart_context', {}),
"source": chart_data.get('source', ''),
"script_output": stdout_capture.getvalue(),
}
if error_msg:
response["error"] = error_msg
response["stderr"] = stderr_capture.getvalue()
if result_df is not None:
# Convert DataFrame to dict for JSON serialization
response["result_dataframe"] = {
"columns": result_df.columns.tolist(),
"index": result_df.index.astype(str).tolist() if hasattr(result_df.index, 'astype') else result_df.index.tolist(),
"data": result_df.values.tolist(),
"shape": result_df.shape,
}
if plot_urls:
response["plot_urls"] = plot_urls
return response
# Export all tools
SYNC_TOOLS = [
list_sync_stores,
read_sync_state,
write_sync_state,
get_store_schema
]
DATASOURCE_TOOLS = [
list_data_sources,
search_symbols,
get_symbol_info,
get_historical_data,
get_chart_data,
analyze_chart_data
]


@@ -0,0 +1,139 @@
# Chart Utilities - Standard OHLC Plotting
## Overview
The `chart_utils.py` module provides convenience functions for creating beautiful, professional OHLC candlestick charts with a consistent look and feel. It is designed to be used by the LLM in `execute_python` scripts (see `chart_tools.py`), eliminating the need to write custom matplotlib code for every chart.
## Key Features
- **Beautiful by default**: Uses mplfinance with seaborn-inspired aesthetics
- **Consistent styling**: Professional color scheme (teal green up, coral red down)
- **Easy to use**: Simple function calls instead of complex matplotlib code
- **Customizable**: Supports all mplfinance options via kwargs
- **Volume integration**: Optional volume subplot
## Installation
The required package `mplfinance` has been added to `requirements.txt`:
```bash
pip install mplfinance
```
## Available Functions
### 1. `plot_ohlc(df, title=None, volume=True, figsize=(14, 8), **kwargs)`
Main function for creating standard OHLC candlestick charts.
**Parameters:**
- `df`: pandas DataFrame with DatetimeIndex and OHLCV columns
- `title`: Optional chart title
- `volume`: Whether to include volume subplot (default: True)
- `figsize`: Figure size in inches (default: (14, 8))
- `**kwargs`: Additional mplfinance.plot() arguments
**Example:**
```python
fig = plot_ohlc(df, title='BTC/USDT 15min', volume=True)
```
### 2. `add_indicators_to_plot(df, indicators, **plot_kwargs)`
Creates OHLC chart with technical indicators overlaid.
**Parameters:**
- `df`: DataFrame with OHLCV data and indicator columns
- `indicators`: Dict mapping indicator column names to display parameters
- `**plot_kwargs`: Additional arguments for plot_ohlc()
**Example:**
```python
df['SMA_20'] = df['close'].rolling(20).mean()
df['SMA_50'] = df['close'].rolling(50).mean()
fig = add_indicators_to_plot(
df,
indicators={
'SMA_20': {'color': 'blue', 'width': 1.5},
'SMA_50': {'color': 'red', 'width': 1.5}
},
title='Price with Moving Averages'
)
```
### 3. Preset Functions
- `plot_price_volume(df, title=None)` - Standard price + volume chart
- `plot_price_only(df, title=None)` - Candlesticks without volume
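The presets can be tried against a small synthetic OHLCV frame. The import path below is an assumption; adjust it to wherever `chart_utils` lives on your `PYTHONPATH`:

```python
import numpy as np
import pandas as pd

# Build a tiny synthetic OHLCV frame for illustration
idx = pd.date_range("2024-01-01", periods=50, freq="15min")
rng = np.random.default_rng(42)
close = 100 + np.cumsum(rng.normal(0, 0.5, 50))
df = pd.DataFrame({
    "open": close + rng.normal(0, 0.1, 50),
    "high": close + np.abs(rng.normal(0, 0.3, 50)),
    "low": close - np.abs(rng.normal(0, 0.3, 50)),
    "close": close,
    "volume": rng.integers(100, 1000, 50),
}, index=idx)

try:
    # Assumed import path - adjust for your project layout
    from chart_utils import plot_price_volume, plot_price_only
    fig = plot_price_volume(df, title="Synthetic 15min bars")
    fig2 = plot_price_only(df, title="Synthetic 15min bars (no volume)")
except ImportError:
    pass  # chart_utils / mplfinance not installed in this environment
```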
## Integration with execute_python
These functions are automatically available in the `execute_python` tool's script environment (wired up in `chart_tools.py`):
```python
# In an execute_python script:
# df is already provided
# Simple usage
fig = plot_ohlc(df, title='Price Action')
# With indicators
df['SMA'] = df['close'].rolling(20).mean()
fig = add_indicators_to_plot(
df,
indicators={'SMA': {'color': 'blue', 'width': 1.5}},
title='Price with SMA'
)
# Return data for the assistant
df[['close', 'SMA']].tail(10)
```
## Styling
The default style includes:
- **Up candles**: Teal green (#26a69a)
- **Down candles**: Coral red (#ef5350)
- **Background**: Light gray with white axes
- **Grid**: Subtle dashed lines with 30% alpha
- **Professional fonts**: Clean, readable sizes
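If you need the same palette outside mplfinance (for example, in a custom matplotlib panel), the colors and grid settings can be reproduced directly. This is a standalone sketch, not part of `chart_utils.py`:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, as in the tools above
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt

# Hex values copied from the defaults listed above
UP, DOWN = "#26a69a", "#ef5350"  # teal green up / coral red down

fig, ax = plt.subplots(figsize=(4, 2))
ax.bar([0, 1], [1.0, -1.0], color=[UP, DOWN])
ax.grid(True, linestyle="--", alpha=0.3)   # subtle dashed grid at 30% alpha
ax.set_facecolor("#ffffff")                # white axes
fig.patch.set_facecolor("#f0f0f0")         # light gray background
face_hex = mcolors.to_hex(fig.patch.get_facecolor())
plt.close(fig)
```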
## Why This Matters
**Before:**
```python
# LLM had to write this every time
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(df.index, df['close'], label='Close')
# ... lots more code for styling, colors, etc.
```
**After:**
```python
# LLM can now just do this
fig = plot_ohlc(df, title='BTC/USDT')
```
Benefits:
- ✅ Less code to generate → faster response
- ✅ Consistent appearance across all charts
- ✅ Professional look out of the box
- ✅ Easier to maintain and customize
- ✅ Better use of mplfinance's candlestick rendering
## Example Output
See `chart_utils_example.py` for runnable examples demonstrating:
1. Basic OHLC chart with volume
2. OHLC chart with multiple indicators
3. Price-only chart
4. Custom styling options
## File Locations
- **Main module**: `backend/src/agent/tools/chart_utils.py`
- **Integration**: `backend/src/agent/tools/chart_tools.py` (lines 306-328)
- **Examples**: `backend/src/agent/tools/chart_utils_example.py`
- **Dependency**: `backend/requirements.txt` (mplfinance added)


@@ -0,0 +1,50 @@
"""Agent tools for trading operations.
This package provides tools for:
- Synchronization stores (sync_tools)
- Data sources and market data (datasource_tools)
- Chart data access and analysis (chart_tools)
- Technical indicators (indicator_tools)
"""
# Global registries that will be set by main.py
_registry = None
_datasource_registry = None
_indicator_registry = None
def set_registry(registry):
"""Set the global SyncRegistry instance for tools to use."""
global _registry
_registry = registry
def set_datasource_registry(datasource_registry):
"""Set the global DataSourceRegistry instance for tools to use."""
global _datasource_registry
_datasource_registry = datasource_registry
def set_indicator_registry(indicator_registry):
"""Set the global IndicatorRegistry instance for tools to use."""
global _indicator_registry
_indicator_registry = indicator_registry
# Import all tools from submodules
from .sync_tools import SYNC_TOOLS
from .datasource_tools import DATASOURCE_TOOLS
from .chart_tools import CHART_TOOLS
from .indicator_tools import INDICATOR_TOOLS
from .research_tools import RESEARCH_TOOLS
__all__ = [
"set_registry",
"set_datasource_registry",
"set_indicator_registry",
"SYNC_TOOLS",
"DATASOURCE_TOOLS",
"CHART_TOOLS",
"INDICATOR_TOOLS",
"RESEARCH_TOOLS",
]


@@ -0,0 +1,371 @@
"""Chart data access and analysis tools."""
from typing import Dict, Any, Optional, Tuple
import io
import uuid
import logging
from pathlib import Path
from contextlib import redirect_stdout, redirect_stderr
from langchain_core.tools import tool
logger = logging.getLogger(__name__)
def _get_registry():
"""Get the global registry instance."""
from . import _registry
return _registry
def _get_datasource_registry():
"""Get the global datasource registry instance."""
from . import _datasource_registry
return _datasource_registry
def _get_indicator_registry():
"""Get the global indicator registry instance."""
from . import _indicator_registry
return _indicator_registry
async def _get_chart_data_impl(countback: Optional[int] = None):
"""Internal implementation for getting chart data.
This is a helper function that can be called by both get_chart_data tool
and analyze_chart_data tool.
Returns:
Tuple of (HistoryResult, chart_context dict, source_name)
"""
registry = _get_registry()
datasource_registry = _get_datasource_registry()
if not registry:
raise ValueError("SyncRegistry not initialized - cannot read ChartStore")
if not datasource_registry:
raise ValueError("DataSourceRegistry not initialized - cannot query data")
# Read current chart state
chart_store = registry.entries.get("ChartStore")
if not chart_store:
raise ValueError("ChartStore not found in registry")
chart_state = chart_store.model.model_dump(mode="json")
chart_data = chart_state.get("chart_state", {})
symbol = chart_data.get("symbol", "")
interval = chart_data.get("interval", "15")
start_time = chart_data.get("start_time")
end_time = chart_data.get("end_time")
if not symbol:
raise ValueError("No symbol set in ChartStore - user may not have loaded a chart yet")
# Parse the symbol to extract exchange/source and symbol name
# Format is "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD")
if ":" not in symbol:
raise ValueError(
f"Invalid symbol format: '{symbol}'. Expected format is 'EXCHANGE:SYMBOL' "
f"(e.g., 'BINANCE:BTC/USDT' or 'DEMO:BTC/USD')"
)
exchange_prefix, symbol_name = symbol.split(":", 1)
source_name = exchange_prefix.lower()
# Get the data source
source = datasource_registry.get(source_name)
if not source:
available = datasource_registry.list_sources()
raise ValueError(
f"Data source '{source_name}' not found. Available sources: {available}. "
f"Make sure the exchange in the symbol '{symbol}' matches an available source."
)
# Determine time range - REQUIRE it to be set, no defaults
if start_time is None or end_time is None:
raise ValueError(
f"Chart time range not set in ChartStore. start_time={start_time}, end_time={end_time}. "
f"The user needs to load the chart first, or the frontend may not be sending the visible range. "
f"Wait for the chart to fully load before analyzing data."
)
    from_time = int(start_time)
    to_time = int(end_time)
    logger.info(
        f"Using ChartStore time range: from_time={from_time}, to_time={to_time}, "
        f"countback={countback}"
    )
    logger.info(
        f"Querying data source '{source_name}' for symbol '{symbol_name}', "
        f"resolution '{interval}'"
    )
    # Query the data source
    result = await source.get_bars(
        symbol=symbol_name,
        resolution=interval,
        from_time=from_time,
        to_time=to_time,
        countback=countback
    )
logger.info(
f"Received {len(result.bars)} bars from data source. "
f"First bar time: {result.bars[0].time if result.bars else 'N/A'}, "
f"Last bar time: {result.bars[-1].time if result.bars else 'N/A'}"
)
# Build chart context to return along with result
chart_context = {
"symbol": symbol,
"interval": interval,
"start_time": start_time,
"end_time": end_time
}
return result, chart_context, source_name
@tool
async def get_chart_data(countback: Optional[int] = None) -> Dict[str, Any]:
"""Get the candle/bar data for what the user is currently viewing on their chart.
This is a convenience tool that automatically:
1. Reads the ChartStore to see what chart the user is viewing
2. Parses the symbol to determine the data source (exchange prefix)
3. Queries the appropriate data source for that symbol's data
4. Returns the data for the visible time range and interval
This is the preferred way to access chart data when helping the user analyze
what they're looking at, since it automatically uses their current chart context.
Args:
countback: Optional limit on number of bars to return. If not specified,
returns all bars in the visible time range.
Returns:
Dictionary containing:
- chart_context: Current chart state (symbol, interval, time range)
- symbol: The trading pair being viewed
- resolution: The chart interval
- bars: List of bar data with 'time' and 'data' fields
- columns: Schema describing available data columns
- source: Which data source was used
Raises:
ValueError: If ChartStore or DataSourceRegistry is not initialized,
or if the symbol format is invalid
Example:
# User is viewing BINANCE:BTC/USDT on 15min chart
data = get_chart_data()
# Returns BTC/USDT data from binance source at 15min resolution
# for the currently visible time range
"""
result, chart_context, source_name = await _get_chart_data_impl(countback)
# Return enriched result with chart context
response = result.model_dump()
response["chart_context"] = chart_context
response["source"] = source_name
return response
@tool
async def execute_python(code: str, countback: Optional[int] = None) -> Dict[str, Any]:
"""Execute Python code for technical analysis with automatic chart data loading.
**PRIMARY TOOL for all technical analysis, indicator computation, and chart generation.**
This is your go-to tool whenever the user asks about indicators, wants to see
a chart, or needs any computational analysis of market data.
Pre-loaded Environment:
- `pd` : pandas
- `np` : numpy
- `plt` : matplotlib.pyplot (figures auto-saved to plot_urls)
- `talib` : TA-Lib technical analysis library
- `indicator_registry`: 150+ registered indicators
- `plot_ohlc(df)` : Helper function for beautiful candlestick charts
Auto-loaded when user has a chart open:
- `df` : pandas DataFrame with DatetimeIndex and columns:
open, high, low, close, volume (OHLCV data ready to use)
- `chart_context` : dict with symbol, interval, start_time, end_time
The `plot_ohlc()` Helper:
Create professional candlestick charts instantly:
- `plot_ohlc(df)` - basic OHLC chart with volume
- `plot_ohlc(df, title='BTC 15min')` - with custom title
- `plot_ohlc(df, volume=False)` - price only, no volume
- Returns a matplotlib Figure that's automatically saved to plot_urls
Args:
code: Python code to execute
countback: Optional limit on number of bars to load (default: all visible bars)
Returns:
Dictionary with:
- script_output : printed output + last expression result
- result_dataframe : serialized DataFrame if last expression is a DataFrame
- plot_urls : list of image URLs (e.g., ["/uploads/plot_abc123.png"])
- chart_context : {symbol, interval, start_time, end_time} or None
- error : traceback if execution failed
Examples:
# RSI indicator with chart
execute_python(\"\"\"
df['RSI'] = talib.RSI(df['close'], 14)
fig = plot_ohlc(df, title='BTC/USDT with RSI')
print(f"Current RSI: {df['RSI'].iloc[-1]:.2f}")
df[['close', 'RSI']].tail(5)
\"\"\")
# Multiple indicators
execute_python(\"\"\"
df['SMA_20'] = df['close'].rolling(20).mean()
df['SMA_50'] = df['close'].rolling(50).mean()
df['BB_upper'] = df['close'].rolling(20).mean() + 2*df['close'].rolling(20).std()
df['BB_lower'] = df['close'].rolling(20).mean() - 2*df['close'].rolling(20).std()
fig = plot_ohlc(df, title=f"{chart_context['symbol']} - Bollinger Bands")
current_price = df['close'].iloc[-1]
sma20 = df['SMA_20'].iloc[-1]
print(f"Price: {current_price:.2f}, SMA20: {sma20:.2f}")
df[['close', 'SMA_20', 'BB_upper', 'BB_lower']].tail(10)
\"\"\")
# Pattern detection
execute_python(\"\"\"
# Find swing highs
df['swing_high'] = (df['high'] > df['high'].shift(1)) & (df['high'] > df['high'].shift(-1))
swing_highs = df[df['swing_high']][['high']].tail(5)
fig = plot_ohlc(df, title='Swing High Detection')
print("Recent swing highs:")
print(swing_highs)
\"\"\")
"""
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
try:
import talib
except ImportError:
talib = None
logger.warning("TA-Lib not available in execute_python environment")
# --- Attempt to load chart data ---
df = None
chart_context = None
registry = _get_registry()
datasource_registry = _get_datasource_registry()
if registry and datasource_registry:
try:
result, chart_context, source_name = await _get_chart_data_impl(countback)
bars = result.bars
if bars:
rows = []
for bar in bars:
rows.append({'time': pd.to_datetime(bar.time, unit='s'), **bar.data})
df = pd.DataFrame(rows).set_index('time')
for col in ['open', 'high', 'low', 'close', 'volume']:
if col in df.columns:
df[col] = pd.to_numeric(df[col], errors='coerce')
logger.info(f"execute_python: loaded {len(df)} bars for {chart_context['symbol']}")
except Exception as e:
logger.info(f"execute_python: no chart data loaded ({e})")
# --- Import chart utilities ---
from .chart_utils import plot_ohlc
# --- Get indicator registry ---
indicator_registry = _get_indicator_registry()
# --- Build globals ---
script_globals: Dict[str, Any] = {
'pd': pd,
'np': np,
'plt': plt,
'talib': talib,
'indicator_registry': indicator_registry,
'df': df,
'chart_context': chart_context,
'plot_ohlc': plot_ohlc,
}
# --- Execute ---
uploads_dir = Path(__file__).parent.parent.parent.parent / "data" / "uploads"
uploads_dir.mkdir(parents=True, exist_ok=True)
stdout_capture = io.StringIO()
result_df = None
error_msg = None
plot_urls = []
try:
with redirect_stdout(stdout_capture), redirect_stderr(stdout_capture):
exec(code, script_globals)
# Capture last expression
lines = code.strip().splitlines()
if lines:
last = lines[-1].strip()
if last and not any(last.startswith(kw) for kw in (
'if', 'for', 'while', 'def', 'class', 'import',
'from', 'with', 'try', 'return', '#'
)):
try:
last_val = eval(last, script_globals)
if isinstance(last_val, pd.DataFrame):
result_df = last_val
elif last_val is not None:
stdout_capture.write(str(last_val))
except Exception:
pass
# Save plots
for fig_num in plt.get_fignums():
fig = plt.figure(fig_num)
filename = f"plot_{uuid.uuid4()}.png"
fig.savefig(uploads_dir / filename, format='png', bbox_inches='tight', dpi=100)
            plot_urls.append(f"/uploads/{filename}")
plt.close(fig)
except Exception as e:
import traceback
error_msg = f"{type(e).__name__}: {e}\n{traceback.format_exc()}"
# --- Build response ---
response: Dict[str, Any] = {
'script_output': stdout_capture.getvalue(),
'chart_context': chart_context,
'plot_urls': plot_urls,
}
if result_df is not None:
response['result_dataframe'] = {
'columns': result_df.columns.tolist(),
'index': result_df.index.astype(str).tolist(),
'data': result_df.values.tolist(),
'shape': result_df.shape,
}
if error_msg:
response['error'] = error_msg
return response
CHART_TOOLS = [
get_chart_data,
execute_python
]


@@ -0,0 +1,224 @@
"""Chart plotting utilities for creating standard, beautiful OHLC charts."""
import pandas as pd
import matplotlib.pyplot as plt
from typing import Optional, Tuple
import logging
logger = logging.getLogger(__name__)
def plot_ohlc(
    df: pd.DataFrame,
    title: Optional[str] = None,
    volume: bool = True,
    figsize: Tuple[int, int] = (14, 8),
    **kwargs
) -> plt.Figure:
    """Create a standard, polished OHLC candlestick chart.
    This convenience function generates a professional-looking candlestick
    chart with consistent styling across all generated charts. It uses
    mplfinance with seaborn-inspired aesthetics.
    Args:
        df: pandas DataFrame with DatetimeIndex and columns: open, high, low, close, volume
        title: Optional chart title. If None, uses symbol from chart context
        volume: Whether to include volume subplot (default: True)
        figsize: Figure size as (width, height) in inches (default: (14, 8))
        **kwargs: Additional arguments to pass to mplfinance.plot()
Returns:
matplotlib.figure.Figure: The created figure object
Example:
```python
# Basic usage in analyze_chart_data script
fig = plot_ohlc(df, title='BTC/USDT 15min')
# Customize with additional indicators
fig = plot_ohlc(df, volume=True, title='Price Action')
# Add custom overlays after calling plot_ohlc
df['SMA20'] = df['close'].rolling(20).mean()
fig = plot_ohlc(df, title='With SMA')
# Note: For mplfinance overlays, use the mav or addplot parameters
```
Note:
The DataFrame must have a DatetimeIndex and the standard OHLCV columns.
Column names should be lowercase: open, high, low, close, volume
"""
try:
import mplfinance as mpf
except ImportError:
raise ImportError(
"mplfinance is required for plot_ohlc(). "
"Install it with: pip install mplfinance"
)
# Validate DataFrame structure
required_cols = ['open', 'high', 'low', 'close']
missing_cols = [col for col in required_cols if col not in df.columns]
if missing_cols:
raise ValueError(
f"DataFrame missing required columns: {missing_cols}. "
f"Required: {required_cols}"
)
if not isinstance(df.index, pd.DatetimeIndex):
raise ValueError(
"DataFrame must have a DatetimeIndex. "
"Convert with: df.index = pd.to_datetime(df.index)"
)
# Ensure volume column exists for volume plot
if volume and 'volume' not in df.columns:
logger.warning("volume=True but 'volume' column not found in DataFrame. Disabling volume.")
volume = False
# Create custom style with seaborn aesthetics
# Using a professional color scheme: green for up candles, red for down candles
mc = mpf.make_marketcolors(
up='#26a69a', # Teal green (calmer than bright green)
down='#ef5350', # Coral red (softer than pure red)
edge='inherit', # Match candle color for edges
wick='inherit', # Match candle color for wicks
volume='in', # Volume bars colored by price direction
alpha=0.9 # Slight transparency for elegance
)
s = mpf.make_mpf_style(
base_mpf_style='charles', # Clean base style
marketcolors=mc,
rc={
'font.size': 10,
'axes.labelsize': 11,
'axes.titlesize': 12,
'xtick.labelsize': 9,
'ytick.labelsize': 9,
'legend.fontsize': 10,
'figure.facecolor': '#f0f0f0',
'axes.facecolor': '#ffffff',
'axes.grid': True,
'grid.alpha': 0.3,
'grid.linestyle': '--',
}
)
# Prepare plot parameters
plot_params = {
'type': 'candle',
'style': s,
'volume': volume,
'figsize': figsize,
'tight_layout': True,
'returnfig': True,
'warn_too_much_data': 1000, # Warn if > 1000 candles for performance
}
# Add title if provided
if title:
plot_params['title'] = title
# Merge any additional kwargs
plot_params.update(kwargs)
# Create the plot
logger.info(
f"Creating OHLC chart with {len(df)} candles, "
f"date range: {df.index.min()} to {df.index.max()}, "
f"volume: {volume}"
)
fig, axes = mpf.plot(df, **plot_params)
return fig
def add_indicators_to_plot(
df: pd.DataFrame,
indicators: dict,
**plot_kwargs
) -> plt.Figure:
"""Create an OHLC chart with technical indicators overlaid.
This extends plot_ohlc() to include common technical indicators using
mplfinance's addplot functionality for proper overlay on candlestick charts.
Args:
df: pandas DataFrame with OHLCV data and indicator columns
indicators: Dictionary mapping indicator names to parameters
Example: {
'SMA_20': {'color': 'blue', 'width': 1.5},
'EMA_50': {'color': 'orange', 'width': 1.5}
}
**plot_kwargs: Additional arguments for plot_ohlc()
Returns:
matplotlib.figure.Figure: The created figure object
Example:
```python
# Calculate indicators
df['SMA_20'] = df['close'].rolling(20).mean()
df['SMA_50'] = df['close'].rolling(50).mean()
# Plot with indicators
fig = add_indicators_to_plot(
df,
indicators={
'SMA_20': {'color': 'blue', 'width': 1.5, 'label': '20 SMA'},
'SMA_50': {'color': 'red', 'width': 1.5, 'label': '50 SMA'}
},
title='BTC/USDT with Moving Averages'
)
```
"""
try:
import mplfinance as mpf
except ImportError:
raise ImportError(
"mplfinance is required. Install it with: pip install mplfinance"
)
# Build addplot list for indicators
addplots = []
for indicator_col, params in indicators.items():
if indicator_col not in df.columns:
logger.warning(f"Indicator column '{indicator_col}' not found in DataFrame. Skipping.")
continue
color = params.get('color', 'blue')
width = params.get('width', 1.0)
panel = params.get('panel', 0)  # 0 = main panel with candles
ylabel = params.get('ylabel', '')
label = params.get('label')  # optional legend label (recent mplfinance versions support it)
extra = {'label': label} if label else {}
addplots.append(
mpf.make_addplot(
df[indicator_col],
color=color,
width=width,
panel=panel,
ylabel=ylabel,
**extra
)
)
# Pass addplot to plot_ohlc via kwargs
if addplots:
plot_kwargs['addplot'] = addplots
return plot_ohlc(df, **plot_kwargs)
# Convenience presets for common chart types
def plot_price_volume(df: pd.DataFrame, title: Optional[str] = None) -> plt.Figure:
"""Create a standard price + volume chart."""
return plot_ohlc(df, title=title, volume=True, figsize=(14, 8))
def plot_price_only(df: pd.DataFrame, title: Optional[str] = None) -> plt.Figure:
"""Create a price-only candlestick chart without volume."""
return plot_ohlc(df, title=title, volume=False, figsize=(14, 6))

View File

@@ -0,0 +1,154 @@
"""
Example usage of chart_utils.py plotting functions.
This demonstrates how the LLM can use the plot_ohlc() convenience function
in analyze_chart_data scripts to create beautiful, standard OHLC charts.
"""
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
def create_sample_data(days=30):
"""Create sample OHLCV data for testing."""
dates = pd.date_range(end=datetime.now(), periods=days * 24, freq='1h')  # lowercase 'h': uppercase 'H' is deprecated in pandas 2.2+
# Simulate price movement
np.random.seed(42)
close = 50000 + np.cumsum(np.random.randn(len(dates)) * 100)
data = {
'open': close + np.random.randn(len(dates)) * 50,
'high': close + np.abs(np.random.randn(len(dates))) * 100,
'low': close - np.abs(np.random.randn(len(dates))) * 100,
'close': close,
'volume': np.abs(np.random.randn(len(dates))) * 1000000
}
df = pd.DataFrame(data, index=dates)
# Ensure high is highest and low is lowest
df['high'] = df[['open', 'high', 'low', 'close']].max(axis=1)
df['low'] = df[['open', 'high', 'low', 'close']].min(axis=1)
return df
if __name__ == "__main__":
from chart_utils import plot_ohlc, add_indicators_to_plot, plot_price_volume
# Create sample data
df = create_sample_data(days=30)
print("=" * 60)
print("Example 1: Basic OHLC chart with volume")
print("=" * 60)
print("\nScript the LLM would generate:")
print("""
fig = plot_ohlc(df, title='BTC/USDT 1H', volume=True)
df.tail(5)
""")
# Execute it
fig = plot_ohlc(df, title='BTC/USDT 1H', volume=True)
print("\n✓ Chart created successfully!")
print(f" Figure size: {fig.get_size_inches()}")
print(f" Number of axes: {len(fig.axes)}")
print("\n" + "=" * 60)
print("Example 2: OHLC chart with indicators")
print("=" * 60)
print("\nScript the LLM would generate:")
print("""
# Calculate indicators
df['SMA_20'] = df['close'].rolling(20).mean()
df['SMA_50'] = df['close'].rolling(50).mean()
df['EMA_12'] = df['close'].ewm(span=12, adjust=False).mean()
# Plot with indicators
fig = add_indicators_to_plot(
df,
indicators={
'SMA_20': {'color': 'blue', 'width': 1.5},
'SMA_50': {'color': 'red', 'width': 1.5},
'EMA_12': {'color': 'green', 'width': 1.0}
},
title='BTC/USDT with Moving Averages',
volume=True
)
df[['close', 'SMA_20', 'SMA_50', 'EMA_12']].tail(5)
""")
# Execute it
df['SMA_20'] = df['close'].rolling(20).mean()
df['SMA_50'] = df['close'].rolling(50).mean()
df['EMA_12'] = df['close'].ewm(span=12, adjust=False).mean()
fig = add_indicators_to_plot(
df,
indicators={
'SMA_20': {'color': 'blue', 'width': 1.5},
'SMA_50': {'color': 'red', 'width': 1.5},
'EMA_12': {'color': 'green', 'width': 1.0}
},
title='BTC/USDT with Moving Averages',
volume=True
)
print("\n✓ Chart with indicators created successfully!")
print(f" Last close: ${df['close'].iloc[-1]:,.2f}")
print(f" SMA 20: ${df['SMA_20'].iloc[-1]:,.2f}")
print(f" SMA 50: ${df['SMA_50'].iloc[-1]:,.2f}")
print("\n" + "=" * 60)
print("Example 3: Price-only chart (no volume)")
print("=" * 60)
print("\nScript the LLM would generate:")
print("""
from chart_utils import plot_price_only
fig = plot_price_only(df, title='Clean Price Action')
""")
# Execute it
from chart_utils import plot_price_only
fig = plot_price_only(df, title='Clean Price Action')
print("\n✓ Price-only chart created successfully!")
print("\n" + "=" * 60)
print("Summary")
print("=" * 60)
print("""
The chart_utils module provides:
1. plot_ohlc() - Main function for beautiful candlestick charts
- Professional seaborn-inspired styling
- Consistent color scheme (teal up, coral down)
- Optional volume subplot
- Customizable figure size
2. add_indicators_to_plot() - OHLC charts with technical indicators
- Overlay multiple indicators
- Customizable colors and line widths
- Proper integration with mplfinance
3. Preset functions for common chart types:
- plot_price_volume() - Standard price + volume
- plot_price_only() - Candlesticks without volume
Benefits:
✓ Consistent look and feel across all charts
✓ Less code for the LLM to generate
✓ Professional appearance out of the box
✓ Easy to customize when needed
✓ Works seamlessly with analyze_chart_data tool
The LLM can now simply call plot_ohlc(df) instead of writing
custom matplotlib code for every chart request!
""")

View File

@@ -0,0 +1,158 @@
"""Data source and market data tools."""
from typing import Dict, Any, List, Optional
from langchain_core.tools import tool
def _get_datasource_registry():
"""Get the global datasource registry instance."""
from . import _datasource_registry
return _datasource_registry
@tool
def list_data_sources() -> List[str]:
"""List all available data sources.
Returns:
List of data source names that can be queried for market data
"""
registry = _get_datasource_registry()
if not registry:
return []
return registry.list_sources()
@tool
async def search_symbols(
query: str,
type: Optional[str] = None,
exchange: Optional[str] = None,
limit: int = 30,
) -> Dict[str, Any]:
"""Search for trading symbols across all data sources.
Automatically searches all available data sources and returns aggregated results.
Use this to find symbols before calling get_symbol_info or get_historical_data.
Args:
query: Search query (e.g., "BTC", "AAPL", "EUR")
type: Optional filter by instrument type (e.g., "crypto", "stock", "forex")
exchange: Optional filter by exchange (e.g., "binance", "nasdaq")
limit: Maximum number of results per source (default: 30)
Returns:
Dictionary mapping source names to lists of matching symbols.
Each symbol includes: symbol, full_name, description, exchange, type.
Use the source name and symbol from results with get_symbol_info or get_historical_data.
Example response:
{
"demo": [
{
"symbol": "BTC/USDT",
"full_name": "Bitcoin / Tether USD",
"description": "Bitcoin perpetual futures",
"exchange": "demo",
"type": "crypto"
}
]
}
"""
registry = _get_datasource_registry()
if not registry:
raise ValueError("DataSourceRegistry not initialized")
# Always search all sources
results = await registry.search_all(query, type, exchange, limit)
return {name: [r.model_dump() for r in matches] for name, matches in results.items()}
@tool
async def get_symbol_info(source_name: str, symbol: str) -> Dict[str, Any]:
"""Get complete metadata for a trading symbol.
This retrieves full information about a symbol including:
- Description and type
- Supported time resolutions
- Available data columns (OHLCV, volume, funding rates, etc.)
- Trading session information
- Price scale and precision
Args:
source_name: Name of the data source (use list_data_sources to see available)
symbol: Symbol identifier (e.g., "BTC/USDT", "AAPL", "EUR/USD")
Returns:
Dictionary containing complete symbol metadata including column schema
Raises:
ValueError: If source_name or symbol is not found
"""
registry = _get_datasource_registry()
if not registry:
raise ValueError("DataSourceRegistry not initialized")
symbol_info = await registry.resolve_symbol(source_name, symbol)
return symbol_info.model_dump()
@tool
async def get_historical_data(
source_name: str,
symbol: str,
resolution: str,
from_time: int,
to_time: int,
countback: Optional[int] = None,
) -> Dict[str, Any]:
"""Get historical bar/candle data for a symbol.
Retrieves time-series data between the specified timestamps. The data
includes all columns defined for the symbol (OHLCV + any custom columns).
Args:
source_name: Name of the data source
symbol: Symbol identifier
resolution: Time resolution (e.g., "1" = 1min, "5" = 5min, "60" = 1hour, "1D" = 1day)
from_time: Start time as Unix timestamp in seconds
to_time: End time as Unix timestamp in seconds
countback: Optional limit on number of bars to return
Returns:
Dictionary containing:
- symbol: The requested symbol
- resolution: The time resolution
- bars: List of bar data with 'time' and 'data' fields
- columns: Schema describing available data columns
- nextTime: If present, indicates more data is available for pagination
Raises:
ValueError: If source, symbol, or resolution is invalid
Example:
# Get 1-hour BTC data for the last 24 hours
import time
to_time = int(time.time())
from_time = to_time - 86400 # 24 hours ago
data = await get_historical_data("demo", "BTC/USDT", "60", from_time, to_time)
"""
registry = _get_datasource_registry()
if not registry:
raise ValueError("DataSourceRegistry not initialized")
source = registry.get(source_name)
if not source:
available = registry.list_sources()
raise ValueError(f"Data source '{source_name}' not found. Available sources: {available}")
result = await source.get_bars(symbol, resolution, from_time, to_time, countback)
return result.model_dump()
DATASOURCE_TOOLS = [
list_data_sources,
search_symbols,
get_symbol_info,
get_historical_data,
]
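The `get_historical_data` docstring above documents a `nextTime` field for pagination. A minimal sketch of following that cursor, using a hypothetical synchronous `get_bars` callable with the same shape as the tool (the real tool is async and would be awaited):

```python
def fetch_all_bars(get_bars, symbol, resolution, from_time, to_time):
    """Follow the documented `nextTime` cursor until the range is exhausted.

    `get_bars(symbol, resolution, from_time, to_time)` is an illustrative
    stand-in for get_historical_data, not the actual tool.
    """
    bars = []
    cursor = from_time
    while True:
        page = get_bars(symbol, resolution, cursor, to_time)
        bars.extend(page["bars"])
        next_time = page.get("nextTime")
        if next_time is None or next_time >= to_time:
            return bars
        cursor = next_time  # more data available: continue from the cursor
```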

View File

@@ -0,0 +1,169 @@
"""Technical indicator tools."""
from typing import Dict, Any, List, Optional
from langchain_core.tools import tool
def _get_indicator_registry():
"""Get the global indicator registry instance."""
from . import _indicator_registry
return _indicator_registry
@tool
def list_indicators() -> List[str]:
"""List all available technical indicators.
Returns:
List of indicator names that can be used in analysis and strategies
"""
registry = _get_indicator_registry()
if not registry:
return []
return registry.list_indicators()
@tool
def get_indicator_info(indicator_name: str) -> Dict[str, Any]:
"""Get detailed information about a specific indicator.
Retrieves metadata including description, parameters, category, use cases,
input/output schemas, and references.
Args:
indicator_name: Name of the indicator (e.g., "RSI", "SMA", "MACD")
Returns:
Dictionary containing:
- name: Indicator name
- display_name: Human-readable name
- description: What the indicator computes and why it's useful
- category: Category (momentum, trend, volatility, volume, etc.)
- parameters: List of configurable parameters with types and defaults
- use_cases: Common trading scenarios where this indicator helps
- tags: Searchable tags
- input_schema: Required input columns (e.g., OHLCV requirements)
- output_schema: Columns this indicator produces
Raises:
ValueError: If indicator_name is not found
"""
registry = _get_indicator_registry()
if not registry:
raise ValueError("IndicatorRegistry not initialized")
metadata = registry.get_metadata(indicator_name)
if not metadata:
total_count = len(registry.list_indicators())
raise ValueError(
f"Indicator '{indicator_name}' not found. "
f"Total available: {total_count} indicators. "
f"Use search_indicators() to find indicators by name, category, or tag."
)
input_schema = registry.get_input_schema(indicator_name)
output_schema = registry.get_output_schema(indicator_name)
result = metadata.model_dump()
result["input_schema"] = input_schema.model_dump() if input_schema else None
result["output_schema"] = output_schema.model_dump() if output_schema else None
return result
@tool
def search_indicators(
query: Optional[str] = None,
category: Optional[str] = None,
tag: Optional[str] = None
) -> List[Dict[str, Any]]:
"""Search for indicators by text query, category, or tag.
Returns lightweight summaries - use get_indicator_info() for full details on specific indicators.
Use this to discover relevant indicators for your trading strategy or analysis.
Can filter by category (momentum, trend, volatility, etc.) or search by keywords.
Args:
query: Optional text search across names, descriptions, and use cases
category: Optional category filter (momentum, trend, volatility, volume, pattern, etc.)
tag: Optional tag filter (e.g., "oscillator", "moving-average", "talib")
Returns:
List of lightweight indicator summaries. Each contains:
- name: Indicator name (use with get_indicator_info() for full details)
- display_name: Human-readable name
- description: Brief one-line description
- category: Category (momentum, trend, volatility, etc.)
Example:
# Find all momentum indicators
results = search_indicators(category="momentum")
# Returns [{name: "RSI", display_name: "RSI", description: "...", category: "momentum"}, ...]
# Then get details on interesting ones
rsi_details = get_indicator_info("RSI") # Full parameters, schemas, use cases
# Search for moving average indicators
search_indicators(query="moving average")
# Find all TA-Lib indicators
search_indicators(tag="talib")
"""
registry = _get_indicator_registry()
if not registry:
raise ValueError("IndicatorRegistry not initialized")
results = []
if query:
results = registry.search_by_text(query)
elif category:
results = registry.search_by_category(category)
elif tag:
results = registry.search_by_tag(tag)
else:
# Return all indicators if no filter
results = registry.get_all_metadata()
# Return lightweight summaries only
return [
{
"name": r.name,
"display_name": r.display_name,
"description": r.description,
"category": r.category
}
for r in results
]
@tool
def get_indicator_categories() -> Dict[str, int]:
"""Get all indicator categories and their counts.
Returns a summary of available indicator categories, useful for
exploring what types of indicators are available.
Returns:
Dictionary mapping category name to count of indicators in that category.
Example: {"momentum": 25, "trend": 15, "volatility": 8, ...}
"""
registry = _get_indicator_registry()
if not registry:
raise ValueError("IndicatorRegistry not initialized")
categories: Dict[str, int] = {}
for metadata in registry.get_all_metadata():
category = metadata.category
categories[category] = categories.get(category, 0) + 1
return categories
INDICATOR_TOOLS = [
list_indicators,
get_indicator_info,
search_indicators,
get_indicator_categories
]
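The counting loop in `get_indicator_categories` is a plain category aggregation; an equivalent sketch with `collections.Counter`, over hypothetical metadata dicts standing in for `registry.get_all_metadata()`:

```python
from collections import Counter

# Hypothetical metadata records; the real registry returns model objects.
catalog = [
    {"name": "RSI", "category": "momentum"},
    {"name": "MACD", "category": "momentum"},
    {"name": "SMA", "category": "trend"},
]

def categorize(metadata):
    # Same result as the manual dict-accumulation loop in the tool.
    return dict(Counter(m["category"] for m in metadata))
```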

View File

@@ -0,0 +1,171 @@
"""Research and external data tools for trading analysis."""
from typing import Dict, Any, Optional
from langchain_core.tools import tool
from langchain_community.tools import (
ArxivQueryRun,
WikipediaQueryRun,
DuckDuckGoSearchRun
)
from langchain_community.utilities import (
ArxivAPIWrapper,
WikipediaAPIWrapper,
DuckDuckGoSearchAPIWrapper
)
@tool
def search_arxiv(query: str, max_results: int = 5) -> str:
"""Search arXiv for academic papers on quantitative finance, trading strategies, and machine learning.
Use this to find research papers on topics like:
- Market microstructure and order flow
- Algorithmic trading strategies
- Machine learning for finance
- Time series forecasting
- Risk management
- Portfolio optimization
Args:
query: Search query (e.g., "machine learning algorithmic trading", "deep learning stock prediction")
max_results: Maximum number of results to return (default: 5)
Returns:
Summary of papers including titles, authors, abstracts, and links
Example:
search_arxiv("reinforcement learning trading", max_results=3)
"""
arxiv = ArxivQueryRun(api_wrapper=ArxivAPIWrapper(top_k_results=max_results))
return arxiv.run(query)
@tool
def search_wikipedia(query: str) -> str:
"""Search Wikipedia for information on finance, trading, and economics concepts.
Use this to get background information on:
- Financial instruments and markets
- Economic indicators
- Trading terminology
- Technical analysis concepts
- Historical market events
Args:
query: Search query (e.g., "Black-Scholes model", "technical analysis", "options trading")
Returns:
Wikipedia article summary with key information
Example:
search_wikipedia("Bollinger Bands")
"""
wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
return wikipedia.run(query)
@tool
def search_web(query: str, max_results: int = 5) -> str:
"""Search the web for current information on markets, news, and trading.
Use this to find:
- Latest market news and analysis
- Company announcements and earnings
- Economic events and indicators
- Cryptocurrency updates
- Exchange status and updates
- Trading strategy discussions
Args:
query: Search query (e.g., "Bitcoin price news", "Fed interest rate decision")
max_results: Maximum number of results to return (default: 5)
Returns:
Search results with titles, snippets, and links
Example:
search_web("Ethereum merge update", max_results=3)
"""
# Lazy initialization to avoid hanging during import
search = DuckDuckGoSearchRun(
api_wrapper=DuckDuckGoSearchAPIWrapper(max_results=max_results)
)
return search.run(query)
@tool
def http_get(url: str, params: Optional[Dict[str, str]] = None) -> str:
"""Make HTTP GET request to fetch data from APIs or web pages.
Use this to retrieve:
- Exchange API data (if public endpoints)
- Market data from external APIs
- Documentation and specifications
- News articles and blog posts
- JSON/XML data from web services
Args:
url: The URL to fetch
params: Optional query parameters as a dictionary
Returns:
Response text from the URL
Raises:
ValueError: If the request fails
Example:
http_get("https://api.coingecko.com/api/v3/simple/price",
params={"ids": "bitcoin", "vs_currencies": "usd"})
"""
import requests
try:
response = requests.get(url, params=params, timeout=10)
response.raise_for_status()
return response.text
except requests.RequestException as e:
raise ValueError(f"HTTP GET request failed: {e}") from e
@tool
def http_post(url: str, data: Dict[str, Any]) -> str:
"""Make HTTP POST request to send data to APIs.
Use this to:
- Submit data to external APIs
- Trigger webhooks
- Post analysis results
- Interact with exchange APIs (if authenticated)
Args:
url: The URL to post to
data: Dictionary of data to send in the request body
Returns:
Response text from the server
Raises:
ValueError: If the request fails
Example:
http_post("https://webhook.site/xxx", {"message": "Trade executed"})
"""
import requests
try:
response = requests.post(url, json=data, timeout=10)
response.raise_for_status()
return response.text
except requests.RequestException as e:
raise ValueError(f"HTTP POST request failed: {e}") from e
# Export tools list
RESEARCH_TOOLS = [
search_arxiv,
search_wikipedia,
search_web,
http_get,
http_post
]
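The HTTP tools above surface network failures as `ValueError`, so callers may want a retry layer. An illustrative exponential-backoff wrapper (not part of the tool set; the zero-argument `fn` would typically be a closure over `http_get`):

```python
import time

def with_retries(fn, attempts=3, base_delay=0.5):
    """Call `fn` up to `attempts` times, backing off exponentially on
    ValueError (the failure mode of http_get/http_post above).
    Re-raises the last error if all attempts fail."""
    for attempt in range(attempts):
        try:
            return fn()
        except ValueError:
            if attempt == attempts - 1:
                raise
            time.sleep(base_delay * (2 ** attempt))
```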

View File

@@ -0,0 +1,138 @@
"""Synchronization store tools."""
from typing import Dict, Any, List
from langchain_core.tools import tool
def _get_registry():
"""Get the global registry instance."""
from . import _registry
return _registry
@tool
def list_sync_stores() -> List[str]:
"""List all available synchronization stores.
Returns:
List of store names that can be read/written
"""
registry = _get_registry()
if not registry:
return []
return list(registry.entries.keys())
@tool
def read_sync_state(store_name: str) -> Dict[str, Any]:
"""Read the current state of a synchronization store.
Args:
store_name: Name of the store to read (e.g., "TraderState", "StrategyState")
Returns:
Dictionary containing the current state of the store
Raises:
ValueError: If store_name doesn't exist
"""
registry = _get_registry()
if not registry:
raise ValueError("SyncRegistry not initialized")
entry = registry.entries.get(store_name)
if not entry:
available = list(registry.entries.keys())
raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
return entry.model.model_dump(mode="json")
@tool
async def write_sync_state(store_name: str, updates: Dict[str, Any]) -> Dict[str, str]:
"""Update the state of a synchronization store.
This will apply the updates to the store and trigger synchronization
with all connected clients.
Args:
store_name: Name of the store to update
updates: Dictionary of field updates (field_name: new_value)
Returns:
Dictionary with status and updated fields
Raises:
ValueError: If store_name doesn't exist or updates are invalid
"""
registry = _get_registry()
if not registry:
raise ValueError("SyncRegistry not initialized")
entry = registry.entries.get(store_name)
if not entry:
available = list(registry.entries.keys())
raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
try:
# Get current state
current_state = entry.model.model_dump(mode="json")
# Apply updates
new_state = {**current_state, **updates}
# Update the model
registry._update_model(entry.model, new_state)
# Trigger sync
await registry.push_all()
return {
"status": "success",
"store": store_name,
"updated_fields": list(updates.keys())
}
except Exception as e:
raise ValueError(f"Failed to update store '{store_name}': {str(e)}")
@tool
def get_store_schema(store_name: str) -> Dict[str, Any]:
"""Get the schema/structure of a synchronization store.
This shows what fields are available and their types.
Args:
store_name: Name of the store
Returns:
Dictionary describing the store's schema
Raises:
ValueError: If store_name doesn't exist
"""
registry = _get_registry()
if not registry:
raise ValueError("SyncRegistry not initialized")
entry = registry.entries.get(store_name)
if not entry:
available = list(registry.entries.keys())
raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
# Get model schema
schema = entry.model.model_json_schema()
return {
"store_name": store_name,
"schema": schema
}
SYNC_TOOLS = [
list_sync_stores,
read_sync_state,
write_sync_state,
get_store_schema
]
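Note that `write_sync_state` applies a shallow merge (`{**current_state, **updates}`): top-level keys in `updates` replace whole values, so nested objects are not deep-merged and callers must send complete nested values. A small illustration (store contents hypothetical):

```python
def apply_updates(current, updates):
    # Mirrors the merge in write_sync_state: top-level keys only.
    return {**current, **updates}

state = {"chart": {"symbol": "BTC/USDT", "interval": "60"}, "theme": "dark"}
new_state = apply_updates(state, {"chart": {"symbol": "ETH/USDT"}})
# The nested "interval" key is gone: the whole "chart" value was replaced.
```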

View File

@@ -6,9 +6,10 @@ the free CCXT library (not ccxt.pro), supporting both historical data and
 polling-based subscriptions.
 
 Numerical Precision:
-- Uses Decimal for all monetary values (prices, volumes) to avoid floating-point errors
+- OHLCV data uses native floats for optimal DataFrame/analysis performance
+- Account balances and order data should use Decimal (via _to_decimal method)
 - CCXT returns numeric values as strings or floats depending on configuration
-- All financial values are converted to Decimal to maintain precision
+- Price data converted to float (_to_float), financial data to Decimal (_to_decimal)
 
 Real-time Updates:
 - Uses polling instead of WebSocket (free CCXT doesn't have WebSocket support)
@@ -72,6 +73,20 @@ class CCXTDataSource(DataSource):
         exchange_class = getattr(ccxt, exchange_id)
         self.exchange = exchange_class(self._config)
 
+        # Configure CCXT to use Decimal mode for precise financial calculations
+        # This ensures all numeric values from the exchange use Decimal internally
+        # We then convert OHLCV to float for DataFrame performance, but keep
+        # Decimal precision for account balances, order sizes, etc.
+        from decimal import Decimal as PythonDecimal
+        self.exchange.number = PythonDecimal
+
+        # Log the precision mode being used by this exchange
+        precision_mode = getattr(self.exchange, 'precisionMode', 'UNKNOWN')
+        logger.info(
+            f"CCXT {exchange_id}: Configured with Decimal mode. "
+            f"Exchange precision mode: {precision_mode}"
+        )
+
         if sandbox and hasattr(self.exchange, 'set_sandbox_mode'):
             self.exchange.set_sandbox_mode(True)
@@ -103,6 +118,33 @@ class CCXTDataSource(DataSource):
             return Decimal(str(value))
         return None
 
+    @staticmethod
+    def _to_float(value: Union[str, int, float, Decimal, None]) -> Optional[float]:
+        """
+        Convert a value to float for OHLCV data.
+
+        OHLCV data is used for charting and DataFrame analysis, where native
+        floats provide better performance and compatibility with pandas/numpy.
+        For financial precision (balances, order sizes), use _to_decimal() instead.
+
+        When CCXT is in Decimal mode (exchange.number = Decimal), it returns
+        Decimal objects. This method converts them to float for performance.
+
+        Handles CCXT's output in both modes:
+        - Decimal mode: receives Decimal objects
+        - Default mode: receives strings, floats, or ints
+        """
+        if value is None:
+            return None
+        if isinstance(value, float):
+            return value
+        if isinstance(value, Decimal):
+            # CCXT in Decimal mode - convert to float for OHLCV
+            return float(value)
+        if isinstance(value, (str, int)):
+            return float(value)
+        return None
+
     async def _ensure_markets_loaded(self):
         """Ensure markets are loaded from exchange"""
         if not self._markets_loaded:
@@ -241,31 +283,31 @@ class CCXTDataSource(DataSource):
             columns=[
                 ColumnInfo(
                     name="open",
-                    type="decimal",
+                    type="float",
                     description=f"Opening price in {quote}",
                     unit=quote,
                 ),
                 ColumnInfo(
                     name="high",
-                    type="decimal",
+                    type="float",
                     description=f"Highest price in {quote}",
                     unit=quote,
                 ),
                 ColumnInfo(
                     name="low",
-                    type="decimal",
+                    type="float",
                     description=f"Lowest price in {quote}",
                     unit=quote,
                 ),
                 ColumnInfo(
                     name="close",
-                    type="decimal",
+                    type="float",
                     description=f"Closing price in {quote}",
                     unit=quote,
                 ),
                 ColumnInfo(
                     name="volume",
-                    type="decimal",
+                    type="float",
                     description=f"Trading volume in {base}",
                     unit=base,
                 ),
@@ -370,7 +412,7 @@ class CCXTDataSource(DataSource):
                 all_ohlcv = all_ohlcv[:countback]
                 break
 
-        # Convert to our Bar format with Decimal precision
+        # Convert to our Bar format with float for OHLCV (used in DataFrames)
         bars = []
         for candle in all_ohlcv:
             timestamp_ms, open_price, high, low, close, volume = candle
@@ -384,11 +426,11 @@ class CCXTDataSource(DataSource):
                 Bar(
                     time=timestamp,
                     data={
-                        "open": self._to_decimal(open_price),
-                        "high": self._to_decimal(high),
-                        "low": self._to_decimal(low),
-                        "close": self._to_decimal(close),
-                        "volume": self._to_decimal(volume),
+                        "open": self._to_float(open_price),
+                        "high": self._to_float(high),
+                        "low": self._to_float(low),
+                        "close": self._to_float(close),
+                        "volume": self._to_float(volume),
                     },
                 )
             )
@@ -476,14 +518,14 @@ class CCXTDataSource(DataSource):
             if timestamp > last_timestamp:
                 self._last_bars[subscription_id] = timestamp
 
-                # Convert to our format with Decimal precision
+                # Convert to our format with float for OHLCV (used in DataFrames)
                 tick_data = {
                     "time": timestamp,
-                    "open": self._to_decimal(open_price),
-                    "high": self._to_decimal(high),
-                    "low": self._to_decimal(low),
-                    "close": self._to_decimal(close),
-                    "volume": self._to_decimal(volume),
+                    "open": self._to_float(open_price),
+                    "high": self._to_float(high),
+                    "low": self._to_float(low),
+                    "close": self._to_float(close),
+                    "volume": self._to_float(volume),
                 }
 
                 # Call the callback

View File

@@ -0,0 +1,179 @@
# Exchange Kernel API
A Kubernetes-style declarative API for managing orders across different exchanges.
## Architecture Overview
The Exchange Kernel maintains two separate views of order state:
1. **Desired State (Intent)**: What the strategy kernel wants
2. **Actual State (Reality)**: What currently exists on the exchange
A reconciliation loop continuously works to bring actual state into alignment with desired state, handling errors, retries, and edge cases automatically.
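A single reconciliation pass can be sketched as a diff between the two views. This is illustrative only: the names and fields below are simplified stand-ins for OrderIntent/OrderState, and a real engine adds retries, error handling, and ordering guarantees.

```python
from dataclasses import dataclass

@dataclass
class Intent:
    id: str
    want_open: bool  # desired: should this order exist on the exchange?

def reconcile(intents, actual_ids, submit, cancel):
    """One pass: compare desired state to actual state and emit
    corrective actions for every mismatch."""
    for intent in intents:
        exists = intent.id in actual_ids
        if intent.want_open and not exists:
            submit(intent)        # missing on exchange: create it
        elif not intent.want_open and exists:
            cancel(intent.id)     # should not exist: cancel it
```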
## Core Components
### Models (`models.py`)
- **OrderIntent**: Desired order state from strategy kernel
- **OrderState**: Actual current order state on exchange
- **Position**: Current position (spot, margin, perp, futures, options)
- **Asset**: Asset holdings with metadata
- **AccountState**: Complete account snapshot (balances, positions, margin)
- **AssetMetadata**: Asset type descriptions and trading parameters
### Events (`events.py`)
Order lifecycle events:
- `OrderSubmitted`, `OrderAccepted`, `OrderRejected`
- `OrderPartiallyFilled`, `OrderFilled`, `OrderCanceled`
- `OrderModified`, `OrderExpired`
Position events:
- `PositionOpened`, `PositionModified`, `PositionClosed`
Account events:
- `AccountBalanceUpdated`, `MarginCallWarning`
### Base Interface (`base.py`)
Abstract `ExchangeKernel` class defining:
**Command API**:
- `place_order()`, `place_order_group()` - Create order intents
- `cancel_order()`, `modify_order()` - Update intents
- `cancel_all_orders()` - Bulk cancellation
**Query API**:
- `get_order_intent()`, `get_order_state()` - Query single order
- `get_all_intents()`, `get_all_orders()` - Query all orders
- `get_positions()`, `get_account_state()` - Query positions/balances
- `get_symbol_metadata()`, `get_asset_metadata()` - Query market info
**Event API**:
- `subscribe_events()`, `unsubscribe_events()` - Event notifications
**Lifecycle**:
- `start()`, `stop()` - Kernel lifecycle
- `health_check()` - Connection status
- `force_reconciliation()` - Manual reconciliation trigger
### State Management (`state.py`)
- **IntentStateStore**: Storage for desired state (durable, survives restarts)
- **ActualStateStore**: Storage for actual exchange state (ephemeral cache)
- **ReconciliationEngine**: Framework for intent→reality reconciliation
- **InMemory implementations**: For testing/prototyping
## Standard Order Model
Defined in `schema/order_spec.py`:
```python
StandardOrder(
symbol_id="BTC/USD",
side=Side.BUY,
amount=1.0,
amount_type=AmountType.BASE, # or QUOTE for exact-out
limit_price=50000.0, # None for market orders
time_in_force=TimeInForce.GTC,
conditional_trigger=ConditionalTrigger(...), # Optional stop-loss/take-profit
conditional_mode=ConditionalOrderMode.UNIFIED_ADJUSTING,
reduce_only=False,
post_only=False,
iceberg_qty=None,
)
```
## Symbol Metadata
Markets describe their capabilities via `SymbolMetadata`:
- **AmountConstraints**: Min/max order size, step size
- **PriceConstraints**: Tick size, tick spacing mode (fixed/dynamic/continuous)
- **MarketCapabilities**:
- Supported sides (BUY, SELL)
- Supported amount types (BASE, QUOTE, or both)
- Market vs limit order support
- Time-in-force options (GTC, IOC, FOK, DAY, GTD)
- Conditional order support (stop-loss, take-profit, trailing stops)
- Advanced features (post-only, reduce-only, iceberg)
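These constraint objects imply a pre-trade validation step. A minimal sketch (the function and its string parameters are assumptions for illustration, not the actual AmountConstraints API): round the amount down to the step grid, then enforce the minimum.

```python
from decimal import Decimal

def snap_amount(amount: str, min_amount: str, step: str) -> Decimal:
    """Snap an order amount to a hypothetical AmountConstraints pair.
    Rounds DOWN to the step grid so the order never exceeds the
    requested size, then rejects amounts below the minimum."""
    snapped = (Decimal(amount) // Decimal(step)) * Decimal(step)
    if snapped < Decimal(min_amount):
        raise ValueError(f"amount {snapped} below minimum {min_amount}")
    return snapped
```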
## Asset Types
Comprehensive asset type system supporting:
- **SPOT**: Cash markets
- **MARGIN**: Margin trading
- **PERP**: Perpetual futures
- **FUTURE**: Dated futures
- **OPTION**: Options contracts
- **SYNTHETIC**: Derived instruments
Each asset has metadata describing contract specs, settlement, margin requirements, etc.
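As an illustration, the metadata for a hypothetical BTC perpetual might look like the following. Values are invented for the sketch; the keys follow the `AssetMetadata` model fields:

```python
# Illustrative PERP asset metadata (values are made up, not real
# exchange parameters).
btc_perp = {
    "asset_id": "BTC-PERP",
    "symbol": "BTC",
    "asset_type": "PERP",
    "name": "Bitcoin Perpetual",
    "contract_size": 1.0,          # 1 contract = 1 BTC
    "settlement_asset": "USDT",    # settled in the quote currency
    "expiry_timestamp": None,      # perpetuals never expire
    "tick_size": 0.5,
    "lot_size": 0.001,
    "initial_margin_rate": 0.02,
    "maintenance_margin_rate": 0.01,
}

# The margin rate implies a maximum leverage for the instrument.
max_leverage = 1 / btc_perp["initial_margin_rate"]
print(f"implied max leverage: {max_leverage:.0f}x")
```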
## Usage Pattern
```python
# Create exchange kernel for specific exchange
kernel = SomeExchangeKernel(exchange_id="binance_main")
# Subscribe to events
kernel.subscribe_events(my_event_handler)
# Start kernel
await kernel.start()
# Place order (creates intent, kernel handles execution)
intent_id = await kernel.place_order(
StandardOrder(
symbol_id="BTC/USD",
side=Side.BUY,
amount=1.0,
amount_type=AmountType.BASE,
limit_price=50000.0,
)
)
# Query desired state
intent = await kernel.get_order_intent(intent_id)
# Query actual state
state = await kernel.get_order_state(intent_id)
# Modify order (updates intent, kernel reconciles)
await kernel.modify_order(intent_id, new_order)
# Cancel order
await kernel.cancel_order(intent_id)
# Query positions
positions = await kernel.get_positions()
# Query account state
account = await kernel.get_account_state()
```
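The `my_event_handler` passed to `subscribe_events()` above is left undefined. A minimal handler might dispatch on event type like this, sketched with plain dicts standing in for the pydantic event models:

```python
def my_event_handler(event: dict) -> None:
    etype = event.get("event_type")
    if etype == "ORDER_FILLED":
        print(f"filled {event['intent_id']} @ {event['average_fill_price']}")
    elif etype == "ORDER_REJECTED":
        print(f"rejected {event['intent_id']}: {event['reason']}")
    # Ignore everything else. Real handlers should never raise:
    # an exception inside a callback can stall event delivery.

my_event_handler({"event_type": "ORDER_FILLED",
                  "intent_id": "intent-1",
                  "average_fill_price": 50000.0})
# → filled intent-1 @ 50000.0
```

For busier strategies, pass an `event_filter` to `subscribe_events()` so the kernel delivers only the event types or symbols the handler cares about.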
## Implementation Status
**Complete**:
- Data models and type definitions
- Event definitions
- Abstract interface
- State store framework
- In-memory stores for testing
**TODO** (Exchange-specific implementations):
- Concrete ExchangeKernel implementations per exchange
- Reconciliation engine implementation
- Exchange API adapters
- Persistent state storage (database)
- Error handling and retry logic
- Monitoring and observability
## Next Steps
1. Create concrete implementations for specific exchanges (Binance, Uniswap, etc.)
2. Implement reconciliation engine with proper error handling
3. Add persistent storage backend for intents
4. Build integration tests
5. Add monitoring/metrics collection


@@ -0,0 +1,75 @@
"""
Exchange Kernel API
The exchange kernel provides a Kubernetes-style declarative API for managing orders
across different exchanges. It maintains both desired state (intent) and actual state
(current orders on exchange) and reconciles them continuously.
Key concepts:
- OrderIntent: What the strategy kernel wants
- OrderState: What actually exists on the exchange
- Reconciliation: Bringing actual state into alignment with desired state
"""
from .base import ExchangeKernel
from .events import (
OrderEvent,
OrderSubmitted,
OrderAccepted,
OrderRejected,
OrderPartiallyFilled,
OrderFilled,
OrderCanceled,
OrderModified,
OrderExpired,
PositionEvent,
PositionOpened,
PositionModified,
PositionClosed,
AccountEvent,
AccountBalanceUpdated,
MarginCallWarning,
)
from .models import (
OrderIntent,
OrderState,
Position,
Asset,
AssetMetadata,
AccountState,
Balance,
)
from .state import IntentStateStore, ActualStateStore
__all__ = [
# Core interface
"ExchangeKernel",
# Events
"OrderEvent",
"OrderSubmitted",
"OrderAccepted",
"OrderRejected",
"OrderPartiallyFilled",
"OrderFilled",
"OrderCanceled",
"OrderModified",
"OrderExpired",
"PositionEvent",
"PositionOpened",
"PositionModified",
"PositionClosed",
"AccountEvent",
"AccountBalanceUpdated",
"MarginCallWarning",
# Models
"OrderIntent",
"OrderState",
"Position",
"Asset",
"AssetMetadata",
"AccountState",
"Balance",
# State management
"IntentStateStore",
"ActualStateStore",
]


@@ -0,0 +1,361 @@
"""
Base interface for Exchange Kernels.
Defines the abstract API that all exchange kernel implementations must support.
Each exchange (or exchange type) will have its own kernel implementation.
"""
from abc import ABC, abstractmethod
from typing import Callable, Any
from .models import (
OrderIntent,
OrderState,
Position,
AccountState,
AssetMetadata,
)
from .events import BaseEvent
from ..schema.order_spec import (
StandardOrder,
StandardOrderGroup,
SymbolMetadata,
)
class ExchangeKernel(ABC):
"""
Abstract base class for exchange kernels.
An exchange kernel manages the lifecycle of orders on a specific exchange,
maintaining both desired state (intents from strategy kernel) and actual
state (current orders on exchange), and continuously reconciling them.
Think of it as a Kubernetes-style controller for trading orders.
"""
def __init__(self, exchange_id: str):
"""
Initialize the exchange kernel.
Args:
exchange_id: Unique identifier for this exchange instance
"""
self.exchange_id = exchange_id
# -------------------------------------------------------------------------
# Command API - Strategy kernel sends intents
# -------------------------------------------------------------------------
@abstractmethod
async def place_order(self, order: StandardOrder, metadata: dict[str, Any] | None = None) -> str:
"""
Place a single order on the exchange.
This creates an OrderIntent and begins the reconciliation process to
get the order onto the exchange.
Args:
order: The order specification
metadata: Optional strategy-specific metadata
Returns:
intent_id: Unique identifier for this order intent
Raises:
ValidationError: If order violates market constraints
ExchangeError: If exchange rejects the order
"""
pass
@abstractmethod
async def place_order_group(
self,
group: StandardOrderGroup,
metadata: dict[str, Any] | None = None
) -> list[str]:
"""
Place a group of orders with OCO (One-Cancels-Other) relationship.
Args:
group: Group of orders with OCO mode
metadata: Optional strategy-specific metadata
Returns:
intent_ids: List of intent IDs for each order in the group
Raises:
ValidationError: If any order violates market constraints
ExchangeError: If exchange rejects the group
"""
pass
@abstractmethod
async def cancel_order(self, intent_id: str) -> None:
"""
Cancel an order by intent ID.
Updates the intent to indicate cancellation is desired, and the
reconciliation loop will handle the actual exchange cancellation.
Args:
intent_id: Intent ID of the order to cancel
Raises:
NotFoundError: If intent_id doesn't exist
ExchangeError: If exchange rejects cancellation
"""
pass
@abstractmethod
async def modify_order(
self,
intent_id: str,
new_order: StandardOrder,
) -> None:
"""
Modify an existing order.
Updates the order intent, and the reconciliation loop will update
the exchange order (via modify API if available, or cancel+replace).
Args:
intent_id: Intent ID of the order to modify
new_order: New order specification
Raises:
NotFoundError: If intent_id doesn't exist
ValidationError: If new order violates market constraints
ExchangeError: If exchange rejects modification
"""
pass
@abstractmethod
async def cancel_all_orders(self, symbol_id: str | None = None) -> int:
"""
Cancel all orders, optionally filtered by symbol.
Args:
symbol_id: If provided, only cancel orders for this symbol
Returns:
count: Number of orders canceled
"""
pass
# -------------------------------------------------------------------------
# Query API - Read desired and actual state
# -------------------------------------------------------------------------
@abstractmethod
async def get_order_intent(self, intent_id: str) -> OrderIntent:
"""
Get the desired order state (what strategy kernel wants).
Args:
intent_id: Intent ID to query
Returns:
The order intent
Raises:
NotFoundError: If intent_id doesn't exist
"""
pass
@abstractmethod
async def get_order_state(self, intent_id: str) -> OrderState:
"""
Get the actual order state (what's currently on exchange).
Args:
intent_id: Intent ID to query
Returns:
The current order state
Raises:
NotFoundError: If intent_id doesn't exist
"""
pass
@abstractmethod
async def get_all_intents(self, symbol_id: str | None = None) -> list[OrderIntent]:
"""
Get all order intents, optionally filtered by symbol.
Args:
symbol_id: If provided, only return intents for this symbol
Returns:
List of order intents
"""
pass
@abstractmethod
async def get_all_orders(self, symbol_id: str | None = None) -> list[OrderState]:
"""
Get all actual order states, optionally filtered by symbol.
Args:
symbol_id: If provided, only return orders for this symbol
Returns:
List of order states
"""
pass
@abstractmethod
async def get_positions(self, symbol_id: str | None = None) -> list[Position]:
"""
Get current positions, optionally filtered by symbol.
Args:
symbol_id: If provided, only return positions for this symbol
Returns:
List of positions
"""
pass
@abstractmethod
async def get_account_state(self) -> AccountState:
"""
Get current account state (balances, margin, etc.).
Returns:
Current account state
"""
pass
@abstractmethod
async def get_symbol_metadata(self, symbol_id: str) -> SymbolMetadata:
"""
Get metadata for a symbol (constraints, capabilities, etc.).
Args:
symbol_id: Symbol to query
Returns:
Symbol metadata
Raises:
NotFoundError: If symbol doesn't exist on this exchange
"""
pass
@abstractmethod
async def get_asset_metadata(self, asset_id: str) -> AssetMetadata:
"""
Get metadata for an asset.
Args:
asset_id: Asset to query
Returns:
Asset metadata
Raises:
NotFoundError: If asset doesn't exist
"""
pass
@abstractmethod
async def list_symbols(self) -> list[str]:
"""
List all available symbols on this exchange.
Returns:
List of symbol IDs
"""
pass
# -------------------------------------------------------------------------
# Event Subscription API
# -------------------------------------------------------------------------
@abstractmethod
def subscribe_events(
self,
callback: Callable[[BaseEvent], None],
event_filter: dict[str, Any] | None = None,
) -> str:
"""
Subscribe to events from this exchange kernel.
Args:
callback: Function to call when events occur
event_filter: Optional filter criteria (event_type, symbol_id, etc.)
Returns:
subscription_id: Unique ID for this subscription (for unsubscribe)
"""
pass
@abstractmethod
def unsubscribe_events(self, subscription_id: str) -> None:
"""
Unsubscribe from events.
Args:
subscription_id: Subscription ID returned from subscribe_events
"""
pass
# -------------------------------------------------------------------------
# Lifecycle Management
# -------------------------------------------------------------------------
@abstractmethod
async def start(self) -> None:
"""
Start the exchange kernel.
Initializes connections, starts reconciliation loops, etc.
"""
pass
@abstractmethod
async def stop(self) -> None:
"""
Stop the exchange kernel.
Closes connections, stops reconciliation loops, etc.
Does NOT cancel open orders - call cancel_all_orders() first if desired.
"""
pass
@abstractmethod
async def health_check(self) -> dict[str, Any]:
"""
Check health status of the exchange kernel.
Returns:
Health status dict with connection state, latency, error counts, etc.
"""
pass
# -------------------------------------------------------------------------
# Reconciliation Control (advanced)
# -------------------------------------------------------------------------
@abstractmethod
async def force_reconciliation(self, intent_id: str | None = None) -> None:
"""
Force immediate reconciliation.
Args:
intent_id: If provided, only reconcile this specific intent.
If None, reconcile all intents.
"""
pass
@abstractmethod
def get_reconciliation_metrics(self) -> dict[str, Any]:
"""
Get metrics about the reconciliation process.
Returns:
Metrics dict with reconciliation lag, error rates, retry counts, etc.
"""
pass


@@ -0,0 +1,250 @@
"""
Event definitions for the Exchange Kernel.
All events that can occur during the order lifecycle, position management,
and account updates.
"""
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
from ..schema.order_spec import Float, Uint64
# ---------------------------------------------------------------------------
# Base Event Classes
# ---------------------------------------------------------------------------
class EventType(StrEnum):
"""Types of events emitted by the exchange kernel"""
# Order lifecycle
ORDER_SUBMITTED = "ORDER_SUBMITTED"
ORDER_ACCEPTED = "ORDER_ACCEPTED"
ORDER_REJECTED = "ORDER_REJECTED"
ORDER_PARTIALLY_FILLED = "ORDER_PARTIALLY_FILLED"
ORDER_FILLED = "ORDER_FILLED"
ORDER_CANCELED = "ORDER_CANCELED"
ORDER_MODIFIED = "ORDER_MODIFIED"
ORDER_EXPIRED = "ORDER_EXPIRED"
# Position events
POSITION_OPENED = "POSITION_OPENED"
POSITION_MODIFIED = "POSITION_MODIFIED"
POSITION_CLOSED = "POSITION_CLOSED"
# Account events
ACCOUNT_BALANCE_UPDATED = "ACCOUNT_BALANCE_UPDATED"
MARGIN_CALL_WARNING = "MARGIN_CALL_WARNING"
# System events
RECONCILIATION_FAILED = "RECONCILIATION_FAILED"
CONNECTION_LOST = "CONNECTION_LOST"
CONNECTION_RESTORED = "CONNECTION_RESTORED"
class BaseEvent(BaseModel):
"""Base class for all exchange kernel events"""
model_config = {"extra": "forbid"}
event_type: EventType = Field(description="Type of event")
timestamp: Uint64 = Field(description="Event timestamp (Unix milliseconds)")
exchange: str = Field(description="Exchange identifier")
metadata: dict[str, Any] = Field(default_factory=dict, description="Additional event data")
# ---------------------------------------------------------------------------
# Order Events
# ---------------------------------------------------------------------------
class OrderEvent(BaseEvent):
"""Base class for order-related events"""
intent_id: str = Field(description="Order intent ID")
order_id: str | None = Field(default=None, description="Exchange order ID (if assigned)")
symbol_id: str = Field(description="Symbol being traded")
class OrderSubmitted(OrderEvent):
"""Order has been submitted to the exchange"""
event_type: EventType = Field(default=EventType.ORDER_SUBMITTED)
client_order_id: str | None = Field(default=None, description="Client-assigned order ID")
class OrderAccepted(OrderEvent):
"""Order has been accepted by the exchange"""
event_type: EventType = Field(default=EventType.ORDER_ACCEPTED)
order_id: str = Field(description="Exchange-assigned order ID")
accepted_at: Uint64 = Field(description="Exchange acceptance timestamp")
class OrderRejected(OrderEvent):
"""Order was rejected by the exchange"""
event_type: EventType = Field(default=EventType.ORDER_REJECTED)
reason: str = Field(description="Rejection reason")
error_code: str | None = Field(default=None, description="Exchange error code")
class OrderPartiallyFilled(OrderEvent):
"""Order was partially filled"""
event_type: EventType = Field(default=EventType.ORDER_PARTIALLY_FILLED)
order_id: str = Field(description="Exchange order ID")
fill_price: Float = Field(description="Fill price for this execution")
fill_quantity: Float = Field(description="Quantity filled in this execution")
total_filled: Float = Field(description="Total quantity filled so far")
remaining_quantity: Float = Field(description="Remaining quantity to fill")
commission: Float = Field(default=0.0, description="Commission/fee for this fill")
commission_asset: str | None = Field(default=None, description="Asset used for commission")
trade_id: str | None = Field(default=None, description="Exchange trade ID")
class OrderFilled(OrderEvent):
"""Order was completely filled"""
event_type: EventType = Field(default=EventType.ORDER_FILLED)
order_id: str = Field(description="Exchange order ID")
average_fill_price: Float = Field(description="Average execution price")
total_quantity: Float = Field(description="Total quantity filled")
total_commission: Float = Field(default=0.0, description="Total commission/fees")
commission_asset: str | None = Field(default=None, description="Asset used for commission")
completed_at: Uint64 = Field(description="Completion timestamp")
class OrderCanceled(OrderEvent):
"""Order was canceled"""
event_type: EventType = Field(default=EventType.ORDER_CANCELED)
order_id: str = Field(description="Exchange order ID")
reason: str = Field(description="Cancellation reason")
filled_quantity: Float = Field(default=0.0, description="Quantity filled before cancellation")
canceled_at: Uint64 = Field(description="Cancellation timestamp")
class OrderModified(OrderEvent):
"""Order was modified (price, quantity, etc.)"""
event_type: EventType = Field(default=EventType.ORDER_MODIFIED)
order_id: str = Field(description="Exchange order ID")
old_price: Float | None = Field(default=None, description="Previous price")
new_price: Float | None = Field(default=None, description="New price")
old_quantity: Float | None = Field(default=None, description="Previous quantity")
new_quantity: Float | None = Field(default=None, description="New quantity")
modified_at: Uint64 = Field(description="Modification timestamp")
class OrderExpired(OrderEvent):
"""Order expired (GTD, DAY orders)"""
event_type: EventType = Field(default=EventType.ORDER_EXPIRED)
order_id: str = Field(description="Exchange order ID")
filled_quantity: Float = Field(default=0.0, description="Quantity filled before expiration")
expired_at: Uint64 = Field(description="Expiration timestamp")
# ---------------------------------------------------------------------------
# Position Events
# ---------------------------------------------------------------------------
class PositionEvent(BaseEvent):
"""Base class for position-related events"""
position_id: str = Field(description="Position identifier")
symbol_id: str = Field(description="Symbol identifier")
asset_id: str = Field(description="Asset identifier")
class PositionOpened(PositionEvent):
"""New position was opened"""
event_type: EventType = Field(default=EventType.POSITION_OPENED)
quantity: Float = Field(description="Position quantity")
entry_price: Float = Field(description="Entry price")
side: str = Field(description="LONG or SHORT")
leverage: Float | None = Field(default=None, description="Leverage")
class PositionModified(PositionEvent):
"""Existing position was modified (size change, etc.)"""
event_type: EventType = Field(default=EventType.POSITION_MODIFIED)
old_quantity: Float = Field(description="Previous quantity")
new_quantity: Float = Field(description="New quantity")
average_entry_price: Float = Field(description="Updated average entry price")
unrealized_pnl: Float | None = Field(default=None, description="Current unrealized P&L")
class PositionClosed(PositionEvent):
"""Position was closed"""
event_type: EventType = Field(default=EventType.POSITION_CLOSED)
exit_price: Float = Field(description="Exit price")
realized_pnl: Float = Field(description="Realized profit/loss")
closed_at: Uint64 = Field(description="Closure timestamp")
# ---------------------------------------------------------------------------
# Account Events
# ---------------------------------------------------------------------------
class AccountEvent(BaseEvent):
"""Base class for account-related events"""
account_id: str = Field(description="Account identifier")
class AccountBalanceUpdated(AccountEvent):
"""Account balance was updated"""
event_type: EventType = Field(default=EventType.ACCOUNT_BALANCE_UPDATED)
asset_id: str = Field(description="Asset that changed")
old_balance: Float = Field(description="Previous balance")
new_balance: Float = Field(description="New balance")
old_available: Float = Field(description="Previous available")
new_available: Float = Field(description="New available")
change_reason: str = Field(description="Why balance changed (TRADE, DEPOSIT, WITHDRAWAL, etc.)")
class MarginCallWarning(AccountEvent):
"""Margin level is approaching liquidation threshold"""
event_type: EventType = Field(default=EventType.MARGIN_CALL_WARNING)
margin_level: Float = Field(description="Current margin level")
liquidation_threshold: Float = Field(description="Liquidation threshold")
required_action: str = Field(description="Required action to avoid liquidation")
estimated_liquidation_price: Float | None = Field(
default=None,
description="Estimated liquidation price for positions"
)
# ---------------------------------------------------------------------------
# System Events
# ---------------------------------------------------------------------------
class ReconciliationFailed(BaseEvent):
"""Failed to reconcile intent with actual state"""
event_type: EventType = Field(default=EventType.RECONCILIATION_FAILED)
intent_id: str = Field(description="Order intent ID")
error_message: str = Field(description="Error details")
retry_count: int = Field(description="Number of retry attempts")
class ConnectionLost(BaseEvent):
"""Connection to exchange was lost"""
event_type: EventType = Field(default=EventType.CONNECTION_LOST)
reason: str = Field(description="Disconnection reason")
class ConnectionRestored(BaseEvent):
"""Connection to exchange was restored"""
event_type: EventType = Field(default=EventType.CONNECTION_RESTORED)
downtime_duration: int = Field(description="Duration of downtime in milliseconds")


@@ -0,0 +1,194 @@
"""
Data models for the Exchange Kernel.
Defines order intents, order state, positions, assets, and account state.
"""
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
from ..schema.order_spec import (
StandardOrder,
StandardOrderStatus,
AssetType,
Float,
Uint64,
)
# ---------------------------------------------------------------------------
# Order Intent and State
# ---------------------------------------------------------------------------
class OrderIntent(BaseModel):
"""
Desired order state from the strategy kernel.
This represents what the strategy wants, not what currently exists.
The exchange kernel will work to reconcile actual state with this intent.
"""
model_config = {"extra": "forbid"}
intent_id: str = Field(description="Unique identifier for this intent (client-assigned)")
order: StandardOrder = Field(description="The desired order specification")
group_id: str | None = Field(default=None, description="Group ID for OCO relationships")
created_at: Uint64 = Field(description="When this intent was created")
updated_at: Uint64 = Field(description="When this intent was last modified")
metadata: dict[str, Any] = Field(default_factory=dict, description="Strategy-specific metadata")
class ReconciliationStatus(StrEnum):
"""Status of reconciliation between intent and actual state"""
PENDING = "PENDING" # Not yet submitted to exchange
SUBMITTING = "SUBMITTING" # Currently being submitted
ACTIVE = "ACTIVE" # Successfully placed on exchange
RECONCILING = "RECONCILING" # Intent changed, updating exchange order
FAILED = "FAILED" # Failed to submit or reconcile
COMPLETED = "COMPLETED" # Order fully filled
CANCELED = "CANCELED" # Order canceled
class OrderState(BaseModel):
"""
Actual current state of an order on the exchange.
This represents reality - what the exchange reports about the order.
May differ from OrderIntent during reconciliation.
"""
model_config = {"extra": "forbid"}
intent_id: str = Field(description="Links back to the OrderIntent")
exchange_order_id: str = Field(description="Exchange-assigned order ID")
status: StandardOrderStatus = Field(description="Current order status from exchange")
reconciliation_status: ReconciliationStatus = Field(description="Reconciliation state")
last_sync_at: Uint64 = Field(description="Last time we synced with exchange")
error_message: str | None = Field(default=None, description="Error details if FAILED")
# ---------------------------------------------------------------------------
# Position and Asset Models
# ---------------------------------------------------------------------------
class AssetMetadata(BaseModel):
"""
Metadata describing an asset type.
Provides context for positions, balances, and trading.
"""
model_config = {"extra": "forbid"}
asset_id: str = Field(description="Unique asset identifier")
symbol: str = Field(description="Asset symbol (e.g., 'BTC', 'ETH', 'USD')")
asset_type: AssetType = Field(description="Type of asset")
name: str = Field(description="Full name")
# Contract specifications (for derivatives)
contract_size: Float | None = Field(default=None, description="Contract multiplier")
settlement_asset: str | None = Field(default=None, description="Settlement currency")
expiry_timestamp: Uint64 | None = Field(default=None, description="Expiration timestamp")
# Trading parameters
tick_size: Float | None = Field(default=None, description="Minimum price increment")
lot_size: Float | None = Field(default=None, description="Minimum quantity increment")
# Margin requirements (for leveraged products)
initial_margin_rate: Float | None = Field(default=None, description="Initial margin requirement")
maintenance_margin_rate: Float | None = Field(default=None, description="Maintenance margin requirement")
# Additional metadata
metadata: dict[str, Any] = Field(default_factory=dict, description="Exchange-specific metadata")
class Asset(BaseModel):
"""
An asset holding (spot, margin, derivative position, etc.)
"""
model_config = {"extra": "forbid"}
asset_id: str = Field(description="References AssetMetadata")
quantity: Float = Field(description="Amount held (positive or negative for short positions)")
available: Float = Field(description="Amount available for trading (not locked in orders)")
locked: Float = Field(description="Amount locked in open orders")
# For derivative positions
entry_price: Float | None = Field(default=None, description="Average entry price")
mark_price: Float | None = Field(default=None, description="Current mark price")
liquidation_price: Float | None = Field(default=None, description="Estimated liquidation price")
unrealized_pnl: Float | None = Field(default=None, description="Unrealized profit/loss")
realized_pnl: Float | None = Field(default=None, description="Realized profit/loss")
# Margin info
margin_used: Float | None = Field(default=None, description="Margin allocated to this position")
updated_at: Uint64 = Field(description="Last update timestamp")
class Position(BaseModel):
"""
A trading position (spot, margin, perpetual, futures, etc.)
Tracks both the asset holdings and associated metadata.
"""
model_config = {"extra": "forbid"}
position_id: str = Field(description="Unique position identifier")
symbol_id: str = Field(description="Trading symbol")
asset: Asset = Field(description="Asset holding details")
metadata: AssetMetadata = Field(description="Asset metadata")
# Position-level info
leverage: Float | None = Field(default=None, description="Current leverage")
side: str | None = Field(default=None, description="LONG or SHORT (for derivatives)")
updated_at: Uint64 = Field(description="Last update timestamp")
class Balance(BaseModel):
"""Account balance for a single currency/asset"""
model_config = {"extra": "forbid"}
asset_id: str = Field(description="Asset identifier")
total: Float = Field(description="Total balance")
available: Float = Field(description="Available for trading")
locked: Float = Field(description="Locked in orders/positions")
# For margin accounts
borrowed: Float = Field(default=0.0, description="Borrowed amount (margin)")
interest: Float = Field(default=0.0, description="Accrued interest")
updated_at: Uint64 = Field(description="Last update timestamp")
class AccountState(BaseModel):
"""
Complete account state including balances, positions, and margin info.
"""
model_config = {"extra": "forbid"}
account_id: str = Field(description="Account identifier")
exchange: str = Field(description="Exchange identifier")
balances: list[Balance] = Field(default_factory=list, description="All asset balances")
positions: list[Position] = Field(default_factory=list, description="All open positions")
# Margin account info
total_equity: Float | None = Field(default=None, description="Total account equity")
total_margin_used: Float | None = Field(default=None, description="Total margin in use")
total_available_margin: Float | None = Field(default=None, description="Available margin")
margin_level: Float | None = Field(default=None, description="Margin level (equity/margin_used)")
# Risk metrics
total_unrealized_pnl: Float | None = Field(default=None, description="Total unrealized P&L")
total_realized_pnl: Float | None = Field(default=None, description="Total realized P&L")
updated_at: Uint64 = Field(description="Last update timestamp")
metadata: dict[str, Any] = Field(default_factory=dict, description="Exchange-specific data")


@@ -0,0 +1,472 @@
"""
State management for the Exchange Kernel.
Implements the storage and reconciliation logic for desired vs actual state.
This is the "Kubernetes for orders" concept - maintaining intent and continuously
reconciling reality to match intent.
"""
from abc import ABC, abstractmethod
from typing import Any
from collections import defaultdict
from .models import OrderIntent, OrderState, ReconciliationStatus
from ..schema.order_spec import Uint64
# ---------------------------------------------------------------------------
# Intent State Store - Desired State
# ---------------------------------------------------------------------------
class IntentStateStore(ABC):
"""
Storage for order intents (desired state).
This represents what the strategy kernel wants. Intents are durable and
persist across restarts. The reconciliation loop continuously works to
make actual state match these intents.
"""
@abstractmethod
async def create_intent(self, intent: OrderIntent) -> None:
"""
Store a new order intent.
Args:
intent: The order intent to store
Raises:
AlreadyExistsError: If intent_id already exists
"""
pass
@abstractmethod
async def get_intent(self, intent_id: str) -> OrderIntent:
"""
Retrieve an order intent.
Args:
intent_id: Intent ID to retrieve
Returns:
The order intent
Raises:
NotFoundError: If intent_id doesn't exist
"""
pass
@abstractmethod
async def update_intent(self, intent: OrderIntent) -> None:
"""
Update an existing order intent.
Args:
intent: Updated intent (intent_id must match existing)
Raises:
NotFoundError: If intent_id doesn't exist
"""
pass
@abstractmethod
async def delete_intent(self, intent_id: str) -> None:
"""
Delete an order intent.
Args:
intent_id: Intent ID to delete
Raises:
NotFoundError: If intent_id doesn't exist
"""
pass
@abstractmethod
async def list_intents(
self,
symbol_id: str | None = None,
group_id: str | None = None,
) -> list[OrderIntent]:
"""
List all order intents, optionally filtered.
Args:
symbol_id: Filter by symbol
group_id: Filter by OCO group
Returns:
List of matching intents
"""
pass
@abstractmethod
async def get_intents_by_group(self, group_id: str) -> list[OrderIntent]:
"""
Get all intents in an OCO group.
Args:
group_id: Group ID to query
Returns:
List of intents in the group
"""
pass
# ---------------------------------------------------------------------------
# Actual State Store - Current Reality
# ---------------------------------------------------------------------------
class ActualStateStore(ABC):
"""
Storage for actual order state (reality on exchange).
This represents what actually exists on the exchange right now.
Updated frequently from exchange feeds and order status queries.
"""
@abstractmethod
async def create_order_state(self, state: OrderState) -> None:
"""
Store a new order state.
Args:
state: The order state to store
Raises:
AlreadyExistsError: If order state for this intent_id already exists
"""
pass
@abstractmethod
async def get_order_state(self, intent_id: str) -> OrderState:
"""
Retrieve order state for an intent.
Args:
intent_id: Intent ID to query
Returns:
The current order state
Raises:
NotFoundError: If no state exists for this intent
"""
pass
@abstractmethod
async def get_order_state_by_exchange_id(self, exchange_order_id: str) -> OrderState:
"""
Retrieve order state by exchange order ID.
Useful for processing exchange callbacks that only provide exchange_order_id.
Args:
exchange_order_id: Exchange's order ID
Returns:
The order state
Raises:
NotFoundError: If no state exists for this exchange order ID
"""
pass
@abstractmethod
async def update_order_state(self, state: OrderState) -> None:
"""
Update an existing order state.
Args:
state: Updated state (intent_id must match existing)
Raises:
NotFoundError: If state doesn't exist
"""
pass
@abstractmethod
async def delete_order_state(self, intent_id: str) -> None:
"""
Delete an order state.
Args:
intent_id: Intent ID whose state to delete
Raises:
NotFoundError: If state doesn't exist
"""
pass
@abstractmethod
async def list_order_states(
self,
symbol_id: str | None = None,
reconciliation_status: ReconciliationStatus | None = None,
) -> list[OrderState]:
"""
List all order states, optionally filtered.
Args:
symbol_id: Filter by symbol
reconciliation_status: Filter by reconciliation status
Returns:
List of matching order states
"""
pass
@abstractmethod
async def get_stale_orders(self, max_age_seconds: int) -> list[OrderState]:
"""
Find orders that haven't been synced recently.
Used to identify orders that need status updates from exchange.
Args:
max_age_seconds: Maximum age since last sync
Returns:
List of order states that need refresh
"""
pass
# ---------------------------------------------------------------------------
# In-Memory Implementations (for testing/prototyping)
# ---------------------------------------------------------------------------
class InMemoryIntentStore(IntentStateStore):
"""Simple in-memory implementation of IntentStateStore"""
def __init__(self):
self._intents: dict[str, OrderIntent] = {}
self._by_symbol: dict[str, set[str]] = defaultdict(set)
self._by_group: dict[str, set[str]] = defaultdict(set)
async def create_intent(self, intent: OrderIntent) -> None:
if intent.intent_id in self._intents:
raise ValueError(f"Intent {intent.intent_id} already exists")
self._intents[intent.intent_id] = intent
self._by_symbol[intent.order.symbol_id].add(intent.intent_id)
if intent.group_id:
self._by_group[intent.group_id].add(intent.intent_id)
async def get_intent(self, intent_id: str) -> OrderIntent:
if intent_id not in self._intents:
raise KeyError(f"Intent {intent_id} not found")
return self._intents[intent_id]
async def update_intent(self, intent: OrderIntent) -> None:
if intent.intent_id not in self._intents:
raise KeyError(f"Intent {intent.intent_id} not found")
old_intent = self._intents[intent.intent_id]
# Update indices if symbol or group changed
if old_intent.order.symbol_id != intent.order.symbol_id:
self._by_symbol[old_intent.order.symbol_id].discard(intent.intent_id)
self._by_symbol[intent.order.symbol_id].add(intent.intent_id)
if old_intent.group_id != intent.group_id:
if old_intent.group_id:
self._by_group[old_intent.group_id].discard(intent.intent_id)
if intent.group_id:
self._by_group[intent.group_id].add(intent.intent_id)
self._intents[intent.intent_id] = intent
async def delete_intent(self, intent_id: str) -> None:
if intent_id not in self._intents:
raise KeyError(f"Intent {intent_id} not found")
intent = self._intents[intent_id]
self._by_symbol[intent.order.symbol_id].discard(intent_id)
if intent.group_id:
self._by_group[intent.group_id].discard(intent_id)
del self._intents[intent_id]
async def list_intents(
self,
symbol_id: str | None = None,
group_id: str | None = None,
) -> list[OrderIntent]:
if symbol_id and group_id:
# Intersection of both filters
symbol_ids = self._by_symbol.get(symbol_id, set())
group_ids = self._by_group.get(group_id, set())
intent_ids = symbol_ids & group_ids
elif symbol_id:
intent_ids = self._by_symbol.get(symbol_id, set())
elif group_id:
intent_ids = self._by_group.get(group_id, set())
else:
intent_ids = self._intents.keys()
return [self._intents[iid] for iid in intent_ids]
async def get_intents_by_group(self, group_id: str) -> list[OrderIntent]:
intent_ids = self._by_group.get(group_id, set())
return [self._intents[iid] for iid in intent_ids]
class InMemoryActualStateStore(ActualStateStore):
"""Simple in-memory implementation of ActualStateStore"""
def __init__(self):
self._states: dict[str, OrderState] = {}
self._by_exchange_id: dict[str, str] = {} # exchange_order_id -> intent_id
self._by_symbol: dict[str, set[str]] = defaultdict(set)
async def create_order_state(self, state: OrderState) -> None:
if state.intent_id in self._states:
raise ValueError(f"Order state for intent {state.intent_id} already exists")
self._states[state.intent_id] = state
self._by_exchange_id[state.exchange_order_id] = state.intent_id
self._by_symbol[state.status.order.symbol_id].add(state.intent_id)
async def get_order_state(self, intent_id: str) -> OrderState:
if intent_id not in self._states:
raise KeyError(f"Order state for intent {intent_id} not found")
return self._states[intent_id]
async def get_order_state_by_exchange_id(self, exchange_order_id: str) -> OrderState:
if exchange_order_id not in self._by_exchange_id:
raise KeyError(f"Order state for exchange order {exchange_order_id} not found")
intent_id = self._by_exchange_id[exchange_order_id]
return self._states[intent_id]
async def update_order_state(self, state: OrderState) -> None:
if state.intent_id not in self._states:
raise KeyError(f"Order state for intent {state.intent_id} not found")
old_state = self._states[state.intent_id]
# Update exchange_id index if it changed
if old_state.exchange_order_id != state.exchange_order_id:
del self._by_exchange_id[old_state.exchange_order_id]
self._by_exchange_id[state.exchange_order_id] = state.intent_id
# Update symbol index if it changed
old_symbol = old_state.status.order.symbol_id
new_symbol = state.status.order.symbol_id
if old_symbol != new_symbol:
self._by_symbol[old_symbol].discard(state.intent_id)
self._by_symbol[new_symbol].add(state.intent_id)
self._states[state.intent_id] = state
async def delete_order_state(self, intent_id: str) -> None:
if intent_id not in self._states:
raise KeyError(f"Order state for intent {intent_id} not found")
state = self._states[intent_id]
del self._by_exchange_id[state.exchange_order_id]
self._by_symbol[state.status.order.symbol_id].discard(intent_id)
del self._states[intent_id]
async def list_order_states(
self,
symbol_id: str | None = None,
reconciliation_status: ReconciliationStatus | None = None,
) -> list[OrderState]:
if symbol_id:
intent_ids = self._by_symbol.get(symbol_id, set())
states = [self._states[iid] for iid in intent_ids]
else:
states = list(self._states.values())
if reconciliation_status:
states = [s for s in states if s.reconciliation_status == reconciliation_status]
return states
async def get_stale_orders(self, max_age_seconds: int) -> list[OrderState]:
import time
current_time = int(time.time())
threshold = current_time - max_age_seconds
return [
state
for state in self._states.values()
if state.last_sync_at < threshold
]
# ---------------------------------------------------------------------------
# Reconciliation Engine (framework only, no implementation)
# ---------------------------------------------------------------------------
class ReconciliationEngine:
"""
Reconciliation engine that continuously works to make actual state match intent.
This is the heart of the "Kubernetes for orders" concept. It:
1. Compares desired state (intents) with actual state (exchange orders)
2. Computes necessary actions (place, modify, cancel)
3. Executes those actions via the exchange API
4. Handles retries, errors, and edge cases
This is a framework class - concrete implementations will be exchange-specific.
"""
def __init__(
self,
intent_store: IntentStateStore,
actual_store: ActualStateStore,
):
"""
Initialize the reconciliation engine.
Args:
intent_store: Store for desired state
actual_store: Store for actual state
"""
self.intent_store = intent_store
self.actual_store = actual_store
self._running = False
async def start(self) -> None:
"""Start the reconciliation loop"""
self._running = True
# Implementation would start async reconciliation loop here
pass
async def stop(self) -> None:
"""Stop the reconciliation loop"""
self._running = False
# Implementation would stop reconciliation loop here
pass
async def reconcile_intent(self, intent_id: str) -> None:
"""
Reconcile a specific intent.
Compares the intent with actual state and takes necessary actions.
Args:
intent_id: Intent to reconcile
"""
# Framework only - concrete implementation needed
pass
async def reconcile_all(self) -> None:
"""
Reconcile all intents.
Full reconciliation pass over all orders.
"""
# Framework only - concrete implementation needed
pass
def get_metrics(self) -> dict[str, Any]:
"""
Get reconciliation metrics.
Returns:
Metrics about reconciliation performance, errors, etc.
"""
return {
"running": self._running,
"reconciliation_lag_ms": 0, # Framework only
"pending_reconciliations": 0, # Framework only
"error_count": 0, # Framework only
"retry_count": 0, # Framework only
}
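The reconcile step the framework leaves open boils down to a diff between the two stores: desired intents vs. actual exchange state. A minimal sketch of that comparison, using hypothetical stand-in records (`MiniIntent`/`MiniState` carry only the fields the diff needs; the real `OrderIntent`/`OrderState` types are far richer):

```python
from dataclasses import dataclass

# Hypothetical stand-ins for OrderIntent / OrderState: only the fields
# the diff needs. The real types carry full order specs and statuses.
@dataclass
class MiniIntent:
    intent_id: str
    price: float

@dataclass
class MiniState:
    intent_id: str
    price: float

def diff_orders(intents: dict[str, MiniIntent],
                actuals: dict[str, MiniState]) -> dict[str, list[str]]:
    """Compute the actions a reconcile pass would take."""
    actions: dict[str, list[str]] = {"place": [], "modify": [], "cancel": []}
    for iid, intent in intents.items():
        actual = actuals.get(iid)
        if actual is None:
            actions["place"].append(iid)      # desired but absent on exchange
        elif actual.price != intent.price:
            actions["modify"].append(iid)     # exists but has drifted
    for iid in actuals:
        if iid not in intents:
            actions["cancel"].append(iid)     # exists but no longer desired
    return actions

actions = diff_orders(
    {"a": MiniIntent("a", 100.0), "b": MiniIntent("b", 101.0)},
    {"b": MiniState("b", 99.0), "c": MiniState("c", 50.0)},
)
print(actions)  # {'place': ['a'], 'modify': ['b'], 'cancel': ['c']}
```

A concrete engine would execute these actions against the exchange API with retries, then update the actual store from fills and callbacks.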


@@ -94,6 +94,11 @@ class Gateway:
            logger.info(f"Session is busy, interrupting existing task")
            await session.interrupt()
        # Check if this is a stop interrupt (empty message)
        if not message.content.strip() and not message.attachments:
            logger.info("Received stop interrupt (empty message), not starting new agent round")
            return
        # Add user message to history
        session.add_message("user", message.content, message.channel_id)
        logger.info(f"User message added to history, history length: {len(session.get_history())}")
@@ -134,33 +139,55 @@ class Gateway:
        # Stream chunks back to active channels
        full_response = ""
        chunk_count = 0
        accumulated_metadata = {}
        async for chunk in response_stream:
            # Handle dict response with metadata (from agent executor)
            if isinstance(chunk, dict):
                content = chunk.get("content", "")
                metadata = chunk.get("metadata", {})
                # Accumulate metadata (e.g., plot_urls)
                for key, value in metadata.items():
                    if key == "plot_urls" and value:
                        # Append to existing plot_urls
                        if "plot_urls" not in accumulated_metadata:
                            accumulated_metadata["plot_urls"] = []
                        accumulated_metadata["plot_urls"].extend(value)
                        logger.info(f"Accumulated plot_urls: {accumulated_metadata['plot_urls']}")
                    else:
                        accumulated_metadata[key] = value
                chunk = content
            # Only send non-empty chunks
            if chunk:
                chunk_count += 1
                full_response += chunk
                logger.debug(f"Received chunk #{chunk_count}, length: {len(chunk)}")
                # Send chunk to all active channels with accumulated metadata
                agent_msg = AgentMessage(
                    session_id=session.session_id,
                    target_channels=session.active_channels,
                    content=chunk,
                    stream_chunk=True,
                    done=False,
                    metadata=accumulated_metadata.copy()
                )
                await self._send_to_channels(agent_msg)
        logger.info(f"Agent streaming completed, total chunks: {chunk_count}, response length: {len(full_response)}")
        # Send final done message with all accumulated metadata
        agent_msg = AgentMessage(
            session_id=session.session_id,
            target_channels=session.active_channels,
            content="",
            stream_chunk=True,
            done=True,
            metadata=accumulated_metadata
        )
        await self._send_to_channels(agent_msg)
        logger.info(f"Sent final done message to channels with metadata: {accumulated_metadata}")
        # Add to history
        session.add_message("assistant", full_response)
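The chunk-handling branch in this hunk can be factored into a pure helper, which makes the metadata accumulation easy to test in isolation. A sketch under the same assumptions the hunk makes (dict chunks carry `content`/`metadata`, `plot_urls` lists append, other keys overwrite):

```python
from typing import Any, Union

def normalize_chunk(chunk: Union[str, dict],
                    accumulated: dict[str, Any]) -> str:
    """Extract text from a chunk, folding any metadata into `accumulated`."""
    if isinstance(chunk, dict):
        metadata = chunk.get("metadata", {})
        for key, value in metadata.items():
            if key == "plot_urls" and value:
                # plot_urls accumulate across chunks rather than overwrite
                accumulated.setdefault("plot_urls", []).extend(value)
            else:
                accumulated[key] = value
        return chunk.get("content", "")
    return chunk

acc: dict[str, Any] = {}
parts = [
    "Hello ",
    {"content": "world", "metadata": {"plot_urls": ["/p/1.png"]}},
    {"content": "", "metadata": {"plot_urls": ["/p/2.png"], "model": "x"}},
]
text = "".join(normalize_chunk(c, acc) for c in parts)
print(text)  # Hello world
print(acc)   # {'plot_urls': ['/p/1.png', '/p/2.png'], 'model': 'x'}
```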


@@ -0,0 +1,172 @@
"""
Composable Indicator System.
Provides a framework for building DAGs of data transformation pipelines
that process time-series data incrementally. Indicators can consume
DataSources or other Indicators as inputs, composing into arbitrarily
complex processing graphs.
Key Components:
---------------
Indicator (base.py):
Abstract base class for all indicator implementations.
Declares input/output schemas and implements synchronous compute().
IndicatorRegistry (registry.py):
Central catalog of available indicators with rich metadata
for AI agent discovery and tool generation.
Pipeline (pipeline.py):
Execution engine that builds DAGs, resolves dependencies,
and orchestrates incremental data flow through indicator chains.
Schema Types (schema.py):
Type definitions for input/output schemas, computation context,
and metadata for AI-native documentation.
Usage Example:
--------------
from indicator import Indicator, IndicatorRegistry, Pipeline
from indicator.schema import (
    InputSchema, OutputSchema, ComputeContext, ComputeResult,
    IndicatorMetadata, IndicatorParameter
)
from datasource.schema import ColumnInfo
# Define an indicator
class SimpleMovingAverage(Indicator):
@classmethod
def get_metadata(cls):
return IndicatorMetadata(
name="SMA",
display_name="Simple Moving Average",
description="Arithmetic mean of prices over N periods",
category="trend",
parameters=[
IndicatorParameter(
name="period",
type="int",
description="Number of periods to average",
default=20,
min_value=1
)
],
tags=["moving-average", "trend-following"]
)
@classmethod
def get_input_schema(cls):
return InputSchema(
required_columns=[
ColumnInfo(name="close", type="float", description="Closing price")
]
)
@classmethod
def get_output_schema(cls, **params):
return OutputSchema(
columns=[
ColumnInfo(
name="sma",
type="float",
description=f"Simple moving average over {params.get('period', 20)} periods"
)
]
)
def compute(self, context: ComputeContext) -> ComputeResult:
period = self.params["period"]
closes = context.get_column("close")
times = context.get_times()
sma_values = []
for i in range(len(closes)):
if i < period - 1:
sma_values.append(None)
else:
window = closes[i - period + 1 : i + 1]
sma_values.append(sum(window) / period)
return ComputeResult(
data=[
{"time": times[i], "sma": sma_values[i]}
for i in range(len(times))
]
)
# Register the indicator
registry = IndicatorRegistry()
registry.register(SimpleMovingAverage)
# Create a pipeline
pipeline = Pipeline(datasource_registry)
pipeline.add_datasource("price_data", "ccxt", "BTC/USD", "1D")
sma_indicator = registry.create_instance("SMA", "sma_20", period=20)
pipeline.add_indicator("sma_20", sma_indicator, input_node_ids=["price_data"])
# Execute
results = pipeline.execute(datasource_data={"price_data": price_bars})
sma_output = results["sma_20"] # Contains columns: time, close, sma_20_sma
Design Philosophy:
------------------
1. **Schema-based composition**: Indicators declare inputs/outputs via schemas,
enabling automatic validation and flexible composition.
2. **Synchronous execution**: All computation is synchronous for simplicity.
Async handling happens at the event/strategy layer.
3. **Incremental updates**: Indicators receive context about what changed,
allowing optimized recomputation of only affected values.
4. **AI-native metadata**: Rich descriptions, use cases, and parameter specs
make indicators discoverable and usable by AI agents.
5. **Generic data flow**: Indicators work with any data source that matches
their input schema, not specific DataSource instances.
6. **Event-driven**: Designed to react to DataSource updates and propagate
changes through the DAG efficiently.
"""
from .base import DataSourceAdapter, Indicator
from .pipeline import Pipeline, PipelineNode
from .registry import IndicatorRegistry
from .schema import (
ComputeContext,
ComputeResult,
IndicatorMetadata,
IndicatorParameter,
InputSchema,
OutputSchema,
)
from .talib_adapter import (
TALibIndicator,
register_all_talib_indicators,
is_talib_available,
get_talib_version,
)
__all__ = [
# Core classes
"Indicator",
"IndicatorRegistry",
"Pipeline",
"PipelineNode",
"DataSourceAdapter",
# Schema types
"InputSchema",
"OutputSchema",
"ComputeContext",
"ComputeResult",
"IndicatorMetadata",
"IndicatorParameter",
# TA-Lib integration
"TALibIndicator",
"register_all_talib_indicators",
"is_talib_available",
"get_talib_version",
]
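The SMA loop from the docstring example above is easy to sanity-check as a pure function, None-padded until a full window of `period` values is available:

```python
def sma(closes: list[float], period: int) -> list:
    """Simple moving average; None until `period` values are available."""
    out = []
    for i in range(len(closes)):
        if i < period - 1:
            out.append(None)  # not enough history yet
        else:
            window = closes[i - period + 1 : i + 1]
            out.append(sum(window) / period)
    return out

print(sma([1.0, 2.0, 3.0, 4.0], 2))  # [None, 1.5, 2.5, 3.5]
```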


@@ -0,0 +1,230 @@
"""
Abstract Indicator interface.
Provides the base class for all technical indicators and derived data transformations.
Indicators compose into DAGs, processing data incrementally as updates arrive.
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from .schema import (
ComputeContext,
ComputeResult,
IndicatorMetadata,
InputSchema,
OutputSchema,
)
class Indicator(ABC):
"""
Abstract base class for all indicators.
Indicators are composable transformation nodes that:
- Declare input schema (columns they need)
- Declare output schema (columns they produce)
- Compute outputs synchronously from inputs
- Support incremental updates (process only what changed)
- Provide rich metadata for AI agent discovery
Indicators are stateless at the instance level - all state is managed
by the pipeline execution engine. This allows the same indicator class
to be reused with different parameters.
"""
def __init__(self, instance_name: str, **params):
"""
Initialize an indicator instance.
Args:
instance_name: Unique name for this instance (used for output column prefixing)
**params: Configuration parameters (validated against metadata.parameters)
"""
self.instance_name = instance_name
self.params = params
self._validate_params()
@classmethod
@abstractmethod
def get_metadata(cls) -> IndicatorMetadata:
"""
Get metadata for this indicator class.
Called by the registry for AI agent discovery and documentation.
Should return comprehensive information about the indicator's purpose,
parameters, and use cases.
Returns:
IndicatorMetadata describing this indicator class
"""
pass
@classmethod
@abstractmethod
def get_input_schema(cls) -> InputSchema:
"""
Get the input schema required by this indicator.
Declares what columns must be present in the input data.
The pipeline will match this against available data sources.
Returns:
InputSchema describing required and optional input columns
"""
pass
@classmethod
@abstractmethod
def get_output_schema(cls, **params) -> OutputSchema:
"""
Get the output schema produced by this indicator.
Output column names will be automatically prefixed with the instance name
by the pipeline engine.
Args:
**params: Configuration parameters (may affect output schema)
Returns:
OutputSchema describing the columns this indicator produces
"""
pass
@abstractmethod
def compute(self, context: ComputeContext) -> ComputeResult:
"""
Compute indicator values from input data.
This method is called synchronously by the pipeline engine whenever
input data changes. Implementations should:
1. Extract needed columns from context.data
2. Perform calculations
3. Return results with proper time alignment
For incremental updates (context.is_incremental == True):
- context.data contains only new/updated rows
- Implementations MAY optimize by computing only these rows
- OR implementations MAY recompute everything (simpler but slower)
Args:
context: Input data and update metadata
Returns:
ComputeResult with calculated indicator values
Raises:
ValueError: If input data doesn't match expected schema
"""
pass
def _validate_params(self) -> None:
"""
Validate that provided parameters match the metadata specification.
Raises:
ValueError: If required parameters are missing or invalid
"""
metadata = self.get_metadata()
# Check for required parameters
for param_def in metadata.parameters:
if param_def.required and param_def.name not in self.params:
raise ValueError(
f"Indicator '{metadata.name}' requires parameter '{param_def.name}'"
)
# Validate parameter types and ranges
for name, value in self.params.items():
# Find parameter definition
param_def = next(
(p for p in metadata.parameters if p.name == name),
None
)
if param_def is None:
raise ValueError(
f"Unknown parameter '{name}' for indicator '{metadata.name}'"
)
# Type checking
if param_def.type == "int" and not isinstance(value, int):
raise ValueError(
f"Parameter '{name}' must be int, got {type(value).__name__}"
)
elif param_def.type == "float" and not isinstance(value, (int, float)):
raise ValueError(
f"Parameter '{name}' must be float, got {type(value).__name__}"
)
elif param_def.type == "bool" and not isinstance(value, bool):
raise ValueError(
f"Parameter '{name}' must be bool, got {type(value).__name__}"
)
elif param_def.type == "string" and not isinstance(value, str):
raise ValueError(
f"Parameter '{name}' must be string, got {type(value).__name__}"
)
# Range checking for numeric types
if param_def.type in ("int", "float"):
if param_def.min_value is not None and value < param_def.min_value:
raise ValueError(
f"Parameter '{name}' must be >= {param_def.min_value}, got {value}"
)
if param_def.max_value is not None and value > param_def.max_value:
raise ValueError(
f"Parameter '{name}' must be <= {param_def.max_value}, got {value}"
)
def get_output_columns(self) -> List[str]:
"""
Get the output column names with instance name prefix.
Returns:
List of prefixed output column names
"""
output_schema = self.get_output_schema(**self.params)
prefixed = output_schema.with_prefix(self.instance_name)
return [col.name for col in prefixed.columns if col.name != output_schema.time_column]
def __repr__(self) -> str:
return f"{self.__class__.__name__}(instance_name='{self.instance_name}', params={self.params})"
class DataSourceAdapter:
"""
Adapter to make a DataSource look like an Indicator for pipeline composition.
This allows DataSources to be inputs to indicators in a unified way.
"""
def __init__(self, datasource_id: str, symbol: str, resolution: str):
"""
Create a DataSource adapter.
Args:
datasource_id: Identifier for the datasource (e.g., 'ccxt', 'demo')
symbol: Symbol to query (e.g., 'BTC/USD')
resolution: Time resolution (e.g., '1', '5', '1D')
"""
self.datasource_id = datasource_id
self.symbol = symbol
self.resolution = resolution
self.instance_name = f"ds_{datasource_id}_{symbol}_{resolution}".replace("/", "_").replace(":", "_")
def get_output_columns(self) -> List[str]:
"""
Get the columns provided by this datasource.
Note: This requires runtime resolution - the pipeline engine
will need to query the actual DataSource to get the schema.
Returns:
List of column names (placeholder - needs runtime resolution)
"""
# This will be resolved at runtime by the pipeline engine
return []
def __repr__(self) -> str:
return f"DataSourceAdapter(datasource='{self.datasource_id}', symbol='{self.symbol}', resolution='{self.resolution}')"
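One Python gotcha worth noting for the type checks in `_validate_params`: `bool` is a subclass of `int`, so a bare `isinstance(value, int)` accepts `True`/`False` as integer parameters. A minimal sketch of a stricter check (the helper name is illustrative, not part of the module):

```python
def check_int_param(name: str, value: object, min_value=None) -> int:
    """Strict int validation: rejects bool, which isinstance(..., int) accepts."""
    if isinstance(value, bool) or not isinstance(value, int):
        raise ValueError(f"Parameter '{name}' must be int, got {type(value).__name__}")
    if min_value is not None and value < min_value:
        raise ValueError(f"Parameter '{name}' must be >= {min_value}, got {value}")
    return value

print(check_int_param("period", 20, min_value=1))  # 20
try:
    check_int_param("period", True)
except ValueError as e:
    print(e)  # Parameter 'period' must be int, got bool
```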


@@ -0,0 +1,439 @@
"""
Pipeline execution engine for composable indicators.
Manages DAG construction, dependency resolution, incremental updates,
and efficient data flow through indicator chains.
"""
import logging
from collections import defaultdict, deque
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from datasource.base import DataSource
from datasource.schema import ColumnInfo
from .base import DataSourceAdapter, Indicator
from .schema import ComputeContext, ComputeResult
logger = logging.getLogger(__name__)
class PipelineNode:
"""
A node in the pipeline DAG.
Can be either a DataSource adapter or an Indicator instance.
"""
def __init__(
self,
node_id: str,
node: Union[DataSourceAdapter, Indicator],
dependencies: List[str]
):
"""
Create a pipeline node.
Args:
node_id: Unique identifier for this node
node: The DataSourceAdapter or Indicator instance
dependencies: List of node_ids this node depends on
"""
self.node_id = node_id
self.node = node
self.dependencies = dependencies
self.output_columns: List[str] = []
self.cached_data: List[Dict[str, Any]] = []
def is_datasource(self) -> bool:
"""Check if this node is a DataSource adapter."""
return isinstance(self.node, DataSourceAdapter)
def is_indicator(self) -> bool:
"""Check if this node is an Indicator."""
return isinstance(self.node, Indicator)
def __repr__(self) -> str:
return f"PipelineNode(id='{self.node_id}', node={self.node}, deps={self.dependencies})"
class Pipeline:
"""
Execution engine for indicator DAGs.
Manages:
- DAG construction and validation
- Topological sorting for execution order
- Data flow and caching
- Incremental updates (only recompute what changed)
- Schema validation
"""
def __init__(self, datasource_registry):
"""
Initialize a pipeline.
Args:
datasource_registry: DataSourceRegistry for resolving data sources
"""
self.datasource_registry = datasource_registry
self.nodes: Dict[str, PipelineNode] = {}
self.execution_order: List[str] = []
self._dirty_nodes: Set[str] = set()
def add_datasource(
self,
node_id: str,
datasource_name: str,
symbol: str,
resolution: str
) -> None:
"""
Add a DataSource to the pipeline.
Args:
node_id: Unique identifier for this node
datasource_name: Name of the datasource in the registry
symbol: Symbol to query
resolution: Time resolution
Raises:
ValueError: If node_id already exists or datasource not found
"""
if node_id in self.nodes:
raise ValueError(f"Node '{node_id}' already exists in pipeline")
datasource = self.datasource_registry.get(datasource_name)
if not datasource:
raise ValueError(f"DataSource '{datasource_name}' not found in registry")
adapter = DataSourceAdapter(datasource_name, symbol, resolution)
node = PipelineNode(node_id, adapter, dependencies=[])
self.nodes[node_id] = node
self._invalidate_execution_order()
logger.info(f"Added DataSource node '{node_id}': {datasource_name}/{symbol}@{resolution}")
def add_indicator(
self,
node_id: str,
indicator: Indicator,
input_node_ids: List[str]
) -> None:
"""
Add an Indicator to the pipeline.
Args:
node_id: Unique identifier for this node
indicator: Indicator instance
input_node_ids: List of node IDs providing input data
Raises:
ValueError: If node_id already exists, dependencies not found, or schema mismatch
"""
if node_id in self.nodes:
raise ValueError(f"Node '{node_id}' already exists in pipeline")
# Validate dependencies exist
for dep_id in input_node_ids:
if dep_id not in self.nodes:
raise ValueError(f"Dependency node '{dep_id}' not found in pipeline")
# TODO: Validate input schema matches available columns from dependencies
# This requires merging output schemas from all input nodes
node = PipelineNode(node_id, indicator, dependencies=input_node_ids)
self.nodes[node_id] = node
self._invalidate_execution_order()
logger.info(f"Added Indicator node '{node_id}': {indicator} with inputs {input_node_ids}")
def remove_node(self, node_id: str) -> None:
"""
Remove a node from the pipeline.
Args:
node_id: Node to remove
Raises:
ValueError: If other nodes depend on this node
"""
if node_id not in self.nodes:
return
# Check for dependent nodes
dependents = [
n.node_id for n in self.nodes.values()
if node_id in n.dependencies
]
if dependents:
raise ValueError(
f"Cannot remove node '{node_id}': nodes {dependents} depend on it"
)
del self.nodes[node_id]
self._invalidate_execution_order()
logger.info(f"Removed node '{node_id}' from pipeline")
def _invalidate_execution_order(self) -> None:
"""Mark execution order as needing recomputation."""
self.execution_order = []
def _compute_execution_order(self) -> List[str]:
"""
Compute topological sort of the DAG.
Returns:
List of node IDs in execution order
Raises:
ValueError: If DAG contains cycles
"""
if self.execution_order:
return self.execution_order
# Kahn's algorithm for topological sort
in_degree = {node_id: 0 for node_id in self.nodes}
for node in self.nodes.values():
for dep in node.dependencies:
in_degree[node.node_id] += 1
queue = deque([node_id for node_id, degree in in_degree.items() if degree == 0])
result = []
while queue:
node_id = queue.popleft()
result.append(node_id)
# Find all nodes that depend on this one
for other_node in self.nodes.values():
if node_id in other_node.dependencies:
in_degree[other_node.node_id] -= 1
if in_degree[other_node.node_id] == 0:
queue.append(other_node.node_id)
if len(result) != len(self.nodes):
raise ValueError("Pipeline contains cycles")
self.execution_order = result
logger.debug(f"Computed execution order: {result}")
return result
def execute(
self,
datasource_data: Dict[str, List[Dict[str, Any]]],
incremental: bool = False,
updated_from_time: Optional[int] = None
) -> Dict[str, List[Dict[str, Any]]]:
"""
Execute the pipeline.
Args:
datasource_data: Mapping of DataSource node_id to input data
incremental: Whether this is an incremental update
updated_from_time: Timestamp of earliest updated row (for incremental)
Returns:
Dictionary mapping node_id to output data (all nodes)
Raises:
ValueError: If required datasource data is missing
"""
execution_order = self._compute_execution_order()
results: Dict[str, List[Dict[str, Any]]] = {}
logger.info(
f"Executing pipeline with {len(execution_order)} nodes "
f"(incremental={incremental})"
)
for node_id in execution_order:
node = self.nodes[node_id]
if node.is_datasource():
# DataSource node - get data from input
if node_id not in datasource_data:
raise ValueError(
f"DataSource node '{node_id}' has no input data"
)
results[node_id] = datasource_data[node_id]
node.cached_data = results[node_id]
logger.debug(f"DataSource node '{node_id}': {len(results[node_id])} rows")
elif node.is_indicator():
# Indicator node - compute from dependencies
indicator = node.node
# Merge input data from all dependencies
input_data = self._merge_dependency_data(node.dependencies, results)
# Create compute context
context = ComputeContext(
data=input_data,
is_incremental=incremental,
updated_from_time=updated_from_time
)
# Execute indicator
logger.debug(
f"Computing indicator '{node_id}' with {len(input_data)} input rows"
)
compute_result = indicator.compute(context)
# Merge result with input data (adding prefixed columns)
output_data = compute_result.merge_with_prefix(
indicator.instance_name,
input_data
)
results[node_id] = output_data
node.cached_data = output_data
logger.debug(f"Indicator node '{node_id}': {len(output_data)} rows")
logger.info(f"Pipeline execution complete: {len(results)} nodes processed")
return results
def _merge_dependency_data(
self,
dependency_ids: List[str],
results: Dict[str, List[Dict[str, Any]]]
) -> List[Dict[str, Any]]:
"""
Merge data from multiple dependency nodes.
Data is merged by time, with later dependencies overwriting earlier ones
for conflicting column names.
Args:
dependency_ids: List of node IDs to merge
results: Current execution results
Returns:
Merged data rows
"""
if not dependency_ids:
return []
if len(dependency_ids) == 1:
return results[dependency_ids[0]]
# Build time-indexed data from first dependency
merged: Dict[int, Dict[str, Any]] = {}
for row in results[dependency_ids[0]]:
merged[row["time"]] = row.copy()
# Merge in additional dependencies
for dep_id in dependency_ids[1:]:
for row in results[dep_id]:
time_key = row["time"]
if time_key in merged:
# Merge columns (later dependencies win)
merged[time_key].update(row)
else:
# New timestamp
merged[time_key] = row.copy()
# Sort by time and return
sorted_times = sorted(merged.keys())
return [merged[t] for t in sorted_times]
def get_node_output(self, node_id: str) -> Optional[List[Dict[str, Any]]]:
"""
Get cached output data for a specific node.
Args:
node_id: Node identifier
Returns:
Cached data or None if not available
"""
node = self.nodes.get(node_id)
return node.cached_data if node else None
def get_output_schema(self, node_id: str) -> List[ColumnInfo]:
"""
Get the output schema for a specific node.
Args:
node_id: Node identifier
Returns:
List of ColumnInfo describing output columns
Raises:
ValueError: If node not found
"""
node = self.nodes.get(node_id)
if not node:
raise ValueError(f"Node '{node_id}' not found")
if node.is_datasource():
# Would need to query the actual datasource at runtime
# For now, return empty - this requires integration with DataSource
return []
elif node.is_indicator():
indicator = node.node
output_schema = indicator.get_output_schema(**indicator.params)
prefixed_schema = output_schema.with_prefix(indicator.instance_name)
return prefixed_schema.columns
return []
def validate_pipeline(self) -> Tuple[bool, Optional[str]]:
"""
Validate the entire pipeline for correctness.
Checks:
- No cycles (already checked in execution order)
- All dependencies exist (already checked in add_indicator)
- Input schemas match output schemas (TODO)
Returns:
Tuple of (is_valid, error_message)
"""
try:
self._compute_execution_order()
return True, None
except ValueError as e:
return False, str(e)
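Since `validate_pipeline` delegates cycle detection to `_compute_execution_order`, the check it relies on can be sketched as a minimal Kahn's topological sort (node names here are illustrative; the real method operates on the pipeline's node objects): if the sort cannot consume every node, a cycle exists.

```python
def execution_order(deps):
    # deps maps node id -> list of dependency ids
    indegree = {n: len(d) for n, d in deps.items()}
    ready = [n for n, d in indegree.items() if d == 0]
    order = []
    while ready:
        n = ready.pop()
        order.append(n)
        for m in deps:
            if n in deps[m]:
                indegree[m] -= 1
                if indegree[m] == 0:
                    ready.append(m)
    if len(order) != len(deps):
        # some nodes were never freed: the graph has a cycle
        raise ValueError("Pipeline contains a cycle")
    return order

print(execution_order({"ohlcv": [], "sma": ["ohlcv"], "cross": ["sma", "ohlcv"]}))
```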
def get_node_count(self) -> int:
"""Get the number of nodes in the pipeline."""
return len(self.nodes)
def get_indicator_count(self) -> int:
"""Get the number of indicator nodes in the pipeline."""
return sum(1 for node in self.nodes.values() if node.is_indicator())
def get_datasource_count(self) -> int:
"""Get the number of datasource nodes in the pipeline."""
return sum(1 for node in self.nodes.values() if node.is_datasource())
def describe(self) -> Dict[str, Any]:
"""
Get a detailed description of the pipeline structure.
Returns:
Dictionary with pipeline metadata and structure
"""
return {
"node_count": self.get_node_count(),
"datasource_count": self.get_datasource_count(),
"indicator_count": self.get_indicator_count(),
"nodes": [
{
"id": node.node_id,
"type": "datasource" if node.is_datasource() else "indicator",
"node": str(node.node),
"dependencies": node.dependencies,
"cached_rows": len(node.cached_data)
}
for node in self.nodes.values()
],
"execution_order": self.execution_order or self._compute_execution_order(),
"is_valid": self.validate_pipeline()[0]
}

View File

@@ -0,0 +1,349 @@
"""
Indicator registry for managing and discovering indicators.
Provides AI agents with a queryable catalog of available indicators,
their capabilities, and metadata.
"""
from typing import Dict, List, Optional, Type
from .base import Indicator
from .schema import IndicatorMetadata, InputSchema, OutputSchema
class IndicatorRegistry:
"""
Central registry for indicator classes.
Enables:
- Registration of indicator implementations
- Discovery by name, category, or tags
- Schema validation
- AI agent tool generation
"""
def __init__(self):
self._indicators: Dict[str, Type[Indicator]] = {}
def register(self, indicator_class: Type[Indicator]) -> None:
"""
Register an indicator class.
Args:
indicator_class: Indicator class to register
Raises:
ValueError: If an indicator with this name is already registered
"""
metadata = indicator_class.get_metadata()
if metadata.name in self._indicators:
raise ValueError(
f"Indicator '{metadata.name}' is already registered"
)
self._indicators[metadata.name] = indicator_class
def unregister(self, name: str) -> None:
"""
Unregister an indicator class.
Args:
name: Indicator class name
"""
self._indicators.pop(name, None)
def get(self, name: str) -> Optional[Type[Indicator]]:
"""
Get an indicator class by name.
Args:
name: Indicator class name
Returns:
Indicator class or None if not found
"""
return self._indicators.get(name)
def list_indicators(self) -> List[str]:
"""
Get names of all registered indicators.
Returns:
List of indicator class names
"""
return list(self._indicators.keys())
def get_metadata(self, name: str) -> Optional[IndicatorMetadata]:
"""
Get metadata for a specific indicator.
Args:
name: Indicator class name
Returns:
IndicatorMetadata or None if not found
"""
indicator_class = self.get(name)
if indicator_class:
return indicator_class.get_metadata()
return None
def get_all_metadata(self) -> List[IndicatorMetadata]:
"""
Get metadata for all registered indicators.
Useful for AI agent tool generation and discovery.
Returns:
List of IndicatorMetadata for all registered indicators
"""
return [cls.get_metadata() for cls in self._indicators.values()]
def search_by_category(self, category: str) -> List[IndicatorMetadata]:
"""
Find indicators by category.
Args:
category: Category name (e.g., 'momentum', 'trend', 'volatility')
Returns:
List of matching indicator metadata
"""
results = []
for indicator_class in self._indicators.values():
metadata = indicator_class.get_metadata()
if metadata.category.lower() == category.lower():
results.append(metadata)
return results
def search_by_tag(self, tag: str) -> List[IndicatorMetadata]:
"""
Find indicators by tag.
Args:
tag: Tag to search for (case-insensitive)
Returns:
List of matching indicator metadata
"""
tag_lower = tag.lower()
results = []
for indicator_class in self._indicators.values():
metadata = indicator_class.get_metadata()
if any(t.lower() == tag_lower for t in metadata.tags):
results.append(metadata)
return results
def search_by_text(self, query: str) -> List[IndicatorMetadata]:
"""
Full-text search across indicator names, descriptions, and use cases.
Args:
query: Search query (case-insensitive)
Returns:
List of matching indicator metadata, ranked by relevance
"""
query_lower = query.lower()
results = []
for indicator_class in self._indicators.values():
metadata = indicator_class.get_metadata()
score = 0
# Check name (highest weight)
if query_lower in metadata.name.lower():
score += 10
if query_lower in metadata.display_name.lower():
score += 8
# Check description
if query_lower in metadata.description.lower():
score += 5
# Check use cases
for use_case in metadata.use_cases:
if query_lower in use_case.lower():
score += 3
# Check tags
for tag in metadata.tags:
if query_lower in tag.lower():
score += 2
if score > 0:
results.append((score, metadata))
# Sort by score descending
results.sort(key=lambda x: x[0], reverse=True)
return [metadata for _, metadata in results]
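The weighted scoring can be shown with plain dicts standing in for `IndicatorMetadata` (the two catalog entries below are made-up samples): name matches weigh most, then description, use cases, and tags.

```python
catalog = [
    {"name": "RSI", "description": "Relative Strength Index momentum oscillator",
     "use_cases": ["overbought/oversold detection"], "tags": ["momentum", "oscillator"]},
    {"name": "SMA", "description": "Simple moving average trend filter",
     "use_cases": ["trend following"], "tags": ["trend"]},
]

def search(query):
    q = query.lower()
    scored = []
    for meta in catalog:
        score = 0
        if q in meta["name"].lower():
            score += 10                      # name match weighs most
        if q in meta["description"].lower():
            score += 5
        score += sum(3 for u in meta["use_cases"] if q in u.lower())
        score += sum(2 for t in meta["tags"] if q in t.lower())
        if score:
            scored.append((score, meta["name"]))
    scored.sort(key=lambda x: x[0], reverse=True)
    return [name for _, name in scored]

print(search("momentum"))
print(search("trend"))
```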
def find_compatible_indicators(
self,
available_columns: List[str],
column_types: Dict[str, str]
) -> List[IndicatorMetadata]:
"""
Find indicators that can be computed from available columns.
Args:
available_columns: List of column names available
column_types: Mapping of column name to type
Returns:
List of indicators whose input schema is satisfied
"""
from datasource.schema import ColumnInfo
# Build ColumnInfo list from available data
available_schema = [
ColumnInfo(
name=name,
type=column_types.get(name, "float"),
description=f"Column {name}"
)
for name in available_columns
]
results = []
for indicator_class in self._indicators.values():
input_schema = indicator_class.get_input_schema()
if input_schema.matches(available_schema):
results.append(indicator_class.get_metadata())
return results
def validate_indicator_chain(
self,
indicator_chain: List[tuple[str, Dict]]
) -> tuple[bool, Optional[str]]:
"""
Validate that a chain of indicators can be connected.
Args:
indicator_chain: List of (indicator_name, params) tuples in execution order
Returns:
Tuple of (is_valid, error_message)
"""
if not indicator_chain:
return True, None
# For now, just check that all indicators exist
# More sophisticated DAG validation happens in the pipeline engine
for indicator_name, params in indicator_chain:
if indicator_name not in self._indicators:
return False, f"Indicator '{indicator_name}' not found in registry"
return True, None
def get_input_schema(self, name: str) -> Optional[InputSchema]:
"""
Get input schema for a specific indicator.
Args:
name: Indicator class name
Returns:
InputSchema or None if not found
"""
indicator_class = self.get(name)
if indicator_class:
return indicator_class.get_input_schema()
return None
def get_output_schema(self, name: str, **params) -> Optional[OutputSchema]:
"""
Get output schema for a specific indicator with given parameters.
Args:
name: Indicator class name
**params: Indicator parameters
Returns:
OutputSchema or None if not found
"""
indicator_class = self.get(name)
if indicator_class:
return indicator_class.get_output_schema(**params)
return None
def create_instance(self, name: str, instance_name: str, **params) -> Optional[Indicator]:
"""
Create an indicator instance with validation.
Args:
name: Indicator class name
instance_name: Unique instance name (for output column prefixing)
**params: Indicator configuration parameters
Returns:
Indicator instance or None if class not found
Raises:
ValueError: If parameters are invalid
"""
indicator_class = self.get(name)
if not indicator_class:
return None
return indicator_class(instance_name=instance_name, **params)
def generate_ai_tool_spec(self) -> Dict:
"""
Generate a JSON specification for AI agent tools.
Creates a structured representation of all indicators that can be
used to build agent tools for indicator selection and composition.
Returns:
Dict suitable for AI agent tool registration
"""
tools = []
for indicator_class in self._indicators.values():
metadata = indicator_class.get_metadata()
# Build parameter spec
parameters = {
"type": "object",
"properties": {},
"required": []
}
for param in metadata.parameters:
param_spec = {
"type": param.type,
"description": param.description
}
if param.default is not None:
param_spec["default"] = param.default
if param.min_value is not None:
param_spec["minimum"] = param.min_value
if param.max_value is not None:
param_spec["maximum"] = param.max_value
parameters["properties"][param.name] = param_spec
if param.required:
parameters["required"].append(param.name)
tool = {
"name": f"indicator_{metadata.name.lower()}",
"description": f"{metadata.display_name}: {metadata.description}",
"category": metadata.category,
"use_cases": metadata.use_cases,
"tags": metadata.tags,
"parameters": parameters,
"input_schema": indicator_class.get_input_schema().model_dump(),
"output_schema": indicator_class.get_output_schema().model_dump()
}
tools.append(tool)
return {
"indicator_tools": tools,
"total_count": len(tools)
}
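The parameter-spec portion of the generated tool JSON can be sketched for a single hypothetical `timeperiod` parameter (names and bounds are illustrative): optional fields are only emitted when the metadata provides them, and required parameters are collected into the JSON-Schema `required` array.

```python
param = {"name": "timeperiod", "type": "int",
         "description": "Lookback window", "default": 14,
         "min_value": 2, "max_value": 100000, "required": False}

spec = {"type": param["type"], "description": param["description"]}
if param["default"] is not None:
    spec["default"] = param["default"]
if param["min_value"] is not None:
    spec["minimum"] = param["min_value"]   # JSON-Schema numeric bound
if param["max_value"] is not None:
    spec["maximum"] = param["max_value"]

parameters = {
    "type": "object",
    "properties": {param["name"]: spec},
    "required": [param["name"]] if param["required"] else [],
}
print(parameters)
```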

View File

@@ -0,0 +1,269 @@
"""
Data models for the Indicator system.
Defines schemas for input/output specifications, computation context,
and metadata for AI agent discovery.
"""
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
from datasource.schema import ColumnInfo
class InputSchema(BaseModel):
"""
Declares the required input columns for an Indicator.
Indicators match against any data source (DataSource or other Indicator)
that provides columns satisfying this schema.
"""
model_config = {"extra": "forbid"}
required_columns: List[ColumnInfo] = Field(
description="Columns that must be present in the input data"
)
optional_columns: List[ColumnInfo] = Field(
default_factory=list,
description="Columns that may be used if present but are not required"
)
time_column: str = Field(
default="time",
description="Name of the timestamp column (must be present)"
)
def matches(self, available_columns: List[ColumnInfo]) -> bool:
"""
Check if available columns satisfy this input schema.
Args:
available_columns: Columns provided by a data source
Returns:
True if all required columns are present with compatible types
"""
available_map = {col.name: col for col in available_columns}
# Check time column exists
if self.time_column not in available_map:
return False
# Check all required columns exist with compatible types
for required in self.required_columns:
if required.name not in available_map:
return False
available = available_map[required.name]
if available.type != required.type:
return False
return True
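The matching rule reduces to a name-and-type lookup plus the mandatory time column. A sketch with `(name, type)` tuples standing in for `ColumnInfo` objects:

```python
def matches(required, available, time_column="time"):
    avail = {name: typ for name, typ in available}
    if time_column not in avail:
        return False
    # every required column must exist with an identical type
    return all(name in avail and avail[name] == typ for name, typ in required)

ohlcv = [("time", "int"), ("open", "float"), ("high", "float"),
         ("low", "float"), ("close", "float"), ("volume", "float")]

print(matches([("close", "float")], ohlcv))                 # satisfied
print(matches([("close", "int")], ohlcv))                   # type mismatch
print(matches([("close", "float")], [("close", "float")]))  # missing time column
```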
def get_missing_columns(self, available_columns: List[ColumnInfo]) -> List[str]:
"""
Get list of missing required column names.
Args:
available_columns: Columns provided by a data source
Returns:
List of missing column names
"""
available_names = {col.name for col in available_columns}
missing = []
if self.time_column not in available_names:
missing.append(self.time_column)
for required in self.required_columns:
if required.name not in available_names:
missing.append(required.name)
return missing
class OutputSchema(BaseModel):
"""
Declares the output columns produced by an Indicator.
Column names will be automatically prefixed with the indicator instance name
to avoid collisions in the pipeline.
"""
model_config = {"extra": "forbid"}
columns: List[ColumnInfo] = Field(
description="Output columns produced by this indicator"
)
time_column: str = Field(
default="time",
description="Name of the timestamp column (passed through from input)"
)
def with_prefix(self, prefix: str) -> "OutputSchema":
"""
Create a new OutputSchema with all column names prefixed.
Args:
prefix: Prefix to add (e.g., indicator instance name)
Returns:
New OutputSchema with prefixed column names
"""
prefixed_columns = [
ColumnInfo(
name=f"{prefix}_{col.name}" if col.name != self.time_column else col.name,
type=col.type,
description=col.description,
unit=col.unit,
nullable=col.nullable
)
for col in self.columns
]
return OutputSchema(
columns=prefixed_columns,
time_column=self.time_column
)
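The prefixing rule itself is simple enough to sketch on bare column names (`bb_20` is a hypothetical instance name): every column except the time column gains the instance prefix, which is what keeps two instances of the same indicator from colliding in a pipeline.

```python
def with_prefix(columns, prefix, time_column="time"):
    # time passes through un-prefixed; everything else gets "<prefix>_"
    return [name if name == time_column else f"{prefix}_{name}" for name in columns]

print(with_prefix(["time", "upper", "middle", "lower"], "bb_20"))
```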
class IndicatorParameter(BaseModel):
"""
Metadata for a configurable indicator parameter.
Used for AI agent discovery and dynamic indicator instantiation.
"""
model_config = {"extra": "forbid"}
name: str = Field(description="Parameter name")
type: Literal["int", "float", "string", "bool"] = Field(description="Parameter type")
description: str = Field(description="Human and LLM-readable description")
default: Optional[Any] = Field(default=None, description="Default value if not specified")
required: bool = Field(default=False, description="Whether this parameter is required")
min_value: Optional[float] = Field(default=None, description="Minimum value (for numeric types)")
max_value: Optional[float] = Field(default=None, description="Maximum value (for numeric types)")
class IndicatorMetadata(BaseModel):
"""
Rich metadata for an Indicator class.
Enables AI agents to discover, understand, and instantiate indicators.
"""
model_config = {"extra": "forbid"}
name: str = Field(description="Unique indicator class name (e.g., 'RSI', 'SMA', 'BollingerBands')")
display_name: str = Field(description="Human-readable display name")
description: str = Field(description="Detailed description of what this indicator computes and why it's useful")
category: str = Field(
description="Indicator category (e.g., 'momentum', 'trend', 'volatility', 'volume', 'custom')"
)
parameters: List[IndicatorParameter] = Field(
default_factory=list,
description="Configurable parameters for this indicator"
)
use_cases: List[str] = Field(
default_factory=list,
description="Common use cases and trading scenarios where this indicator is helpful"
)
references: List[str] = Field(
default_factory=list,
description="URLs or citations for indicator methodology"
)
tags: List[str] = Field(
default_factory=list,
description="Searchable tags (e.g., 'oscillator', 'mean-reversion', 'price-based')"
)
class ComputeContext(BaseModel):
"""
Context passed to an Indicator's compute() method.
Contains the input data and metadata about what changed (for incremental updates).
"""
model_config = {"extra": "forbid"}
data: List[Dict[str, Any]] = Field(
description="Input data rows (time-ordered). Each dict is {column_name: value, time: timestamp}"
)
is_incremental: bool = Field(
default=False,
description="True if this is an incremental update (only new/changed rows), False for full recompute"
)
updated_from_time: Optional[int] = Field(
default=None,
description="Unix timestamp (ms) of the earliest updated row (for incremental updates)"
)
def get_column(self, name: str) -> List[Any]:
"""
Extract a single column as a list of values.
Args:
name: Column name
Returns:
List of values in time order
"""
return [row.get(name) for row in self.data]
def get_times(self) -> List[int]:
"""
Get the time column as a list.
Returns:
List of timestamps in order
"""
return [row["time"] for row in self.data]
class ComputeResult(BaseModel):
"""
Result from an Indicator's compute() method.
Contains the computed output data with proper column naming.
"""
model_config = {"extra": "forbid"}
data: List[Dict[str, Any]] = Field(
description="Output data rows (time-ordered). Must include time column."
)
is_partial: bool = Field(
default=False,
description="True if this result only contains updates (for incremental computation)"
)
def merge_with_prefix(self, prefix: str, existing_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Merge this result into existing data with column name prefixing.
Args:
prefix: Prefix to add to all column names except time
existing_data: Existing data to merge with (matched by time)
Returns:
Merged data with prefixed columns added
"""
# Build a time index for new data
time_index = {row["time"]: row for row in self.data}
# Merge into existing data
result = []
for existing_row in existing_data:
row_time = existing_row["time"]
merged_row = existing_row.copy()
if row_time in time_index:
new_row = time_index[row_time]
for key, value in new_row.items():
if key != "time":
merged_row[f"{prefix}_{key}"] = value
result.append(merged_row)
return result
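A standalone sketch of that merge (sample rows and the `rsi_14` instance name are hypothetical): computed rows are indexed by timestamp, then joined onto the existing rows, with every non-time column prefixed. Existing rows without a matching computed timestamp pass through unchanged.

```python
existing = [{"time": 1, "close": 100.0}, {"time": 2, "close": 101.0}]
computed = [{"time": 2, "value": 55.0}]

time_index = {row["time"]: row for row in computed}
merged = []
for row in existing:
    out = row.copy()
    new = time_index.get(row["time"])
    if new:
        for key, val in new.items():
            if key != "time":
                out[f"rsi_14_{key}"] = val  # "rsi_14" is a hypothetical instance name
    merged.append(out)
print(merged)
```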

View File

@@ -0,0 +1,436 @@
"""
TA-Lib indicator adapter.
Provides automatic registration of all TA-Lib technical indicators
as composable Indicator instances.
Installation Requirements:
--------------------------
TA-Lib requires both the C library and Python wrapper:
1. Install TA-Lib C library:
- Ubuntu/Debian: sudo apt-get install libta-lib-dev
- macOS: brew install ta-lib
- From source: https://ta-lib.org/install.html
2. Install Python wrapper (already in requirements.txt):
pip install TA-Lib
Usage:
------
from indicator.talib_adapter import register_all_talib_indicators
# Auto-register all TA-Lib indicators
registry = IndicatorRegistry()
register_all_talib_indicators(registry)
# Now you can use any TA-Lib indicator
sma = registry.create_instance("SMA", "sma_20", period=20)
rsi = registry.create_instance("RSI", "rsi_14", timeperiod=14)
"""
import logging
from typing import Any, Dict, List, Optional
import numpy as np
try:
import talib
from talib import abstract
TALIB_AVAILABLE = True
except ImportError:
TALIB_AVAILABLE = False
talib = None
abstract = None
from datasource.schema import ColumnInfo
from .base import Indicator
from .schema import (
ComputeContext,
ComputeResult,
IndicatorMetadata,
IndicatorParameter,
InputSchema,
OutputSchema,
)
logger = logging.getLogger(__name__)
# Mapping of TA-Lib parameter types to our schema types
TALIB_TYPE_MAP = {
"double": "float",
"double[]": "float",
"int": "int",
"str": "string",
}
# Categorization of TA-Lib functions
TALIB_CATEGORIES = {
"overlap": ["SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA", "KAMA", "MAMA", "T3",
"BBANDS", "MIDPOINT", "MIDPRICE", "SAR", "SAREXT", "HT_TRENDLINE"],
"momentum": ["RSI", "MOM", "ROC", "ROCP", "ROCR", "ROCR100", "TRIX", "CMO", "DX",
"ADX", "ADXR", "APO", "PPO", "MACD", "MACDEXT", "MACDFIX", "MFI",
"STOCH", "STOCHF", "STOCHRSI", "WILLR", "CCI", "AROON", "AROONOSC",
"BOP", "MINUS_DI", "MINUS_DM", "PLUS_DI", "PLUS_DM", "ULTOSC"],
"volume": ["AD", "ADOSC", "OBV"],
"volatility": ["ATR", "NATR", "TRANGE"],
"price": ["AVGPRICE", "MEDPRICE", "TYPPRICE", "WCLPRICE"],
"cycle": ["HT_DCPERIOD", "HT_DCPHASE", "HT_PHASOR", "HT_SINE", "HT_TRENDMODE"],
"pattern": ["CDL2CROWS", "CDL3BLACKCROWS", "CDL3INSIDE", "CDL3LINESTRIKE",
"CDL3OUTSIDE", "CDL3STARSINSOUTH", "CDL3WHITESOLDIERS", "CDLABANDONEDBABY",
"CDLADVANCEBLOCK", "CDLBELTHOLD", "CDLBREAKAWAY", "CDLCLOSINGMARUBOZU",
"CDLCONCEALBABYSWALL", "CDLCOUNTERATTACK", "CDLDARKCLOUDCOVER", "CDLDOJI",
"CDLDOJISTAR", "CDLDRAGONFLYDOJI", "CDLENGULFING", "CDLEVENINGDOJISTAR",
"CDLEVENINGSTAR", "CDLGAPSIDESIDEWHITE", "CDLGRAVESTONEDOJI", "CDLHAMMER",
"CDLHANGINGMAN", "CDLHARAMI", "CDLHARAMICROSS", "CDLHIGHWAVE", "CDLHIKKAKE",
"CDLHIKKAKEMOD", "CDLHOMINGPIGEON", "CDLIDENTICAL3CROWS", "CDLINNECK",
"CDLINVERTEDHAMMER", "CDLKICKING", "CDLKICKINGBYLENGTH", "CDLLADDERBOTTOM",
"CDLLONGLEGGEDDOJI", "CDLLONGLINE", "CDLMARUBOZU", "CDLMATCHINGLOW",
"CDLMATHOLD", "CDLMORNINGDOJISTAR", "CDLMORNINGSTAR", "CDLONNECK",
"CDLPIERCING", "CDLRICKSHAWMAN", "CDLRISEFALL3METHODS", "CDLSEPARATINGLINES",
"CDLSHOOTINGSTAR", "CDLSHORTLINE", "CDLSPINNINGTOP", "CDLSTALLEDPATTERN",
"CDLSTICKSANDWICH", "CDLTAKURI", "CDLTASUKIGAP", "CDLTHRUSTING", "CDLTRISTAR",
"CDLUNIQUE3RIVER", "CDLUPSIDEGAP2CROWS", "CDLXSIDEGAP3METHODS"],
"statistic": ["BETA", "CORREL", "LINEARREG", "LINEARREG_ANGLE", "LINEARREG_INTERCEPT",
"LINEARREG_SLOPE", "STDDEV", "TSF", "VAR"],
"math": ["ADD", "DIV", "MAX", "MAXINDEX", "MIN", "MININDEX", "MINMAX", "MINMAXINDEX",
"MULT", "SUB", "SUM"],
}
def _get_function_category(func_name: str) -> str:
"""Determine the category of a TA-Lib function."""
for category, functions in TALIB_CATEGORIES.items():
if func_name in functions:
return category
return "other"
class TALibIndicator(Indicator):
"""
Generic adapter for TA-Lib technical indicators.
Wraps any TA-Lib function to work within the composable indicator framework.
Handles parameter mapping, input validation, and output formatting.
"""
# Class variable to store the TA-Lib function name
talib_function_name: Optional[str] = None
def __init__(self, instance_name: str, **params):
"""
Initialize a TA-Lib indicator.
Args:
instance_name: Unique name for this instance
**params: TA-Lib function parameters
"""
if not TALIB_AVAILABLE:
raise ImportError(
"TA-Lib is not installed. Please install the TA-Lib C library "
"and Python wrapper. See indicator/talib_adapter.py for instructions."
)
super().__init__(instance_name, **params)
self._talib_func = abstract.Function(self.talib_function_name)
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
"""Get metadata from TA-Lib function info."""
if not TALIB_AVAILABLE:
raise ImportError("TA-Lib is not installed")
func = abstract.Function(cls.talib_function_name)
info = func.info
# Build parameters list from TA-Lib function info
parameters = []
for param_name, param_info in info.get("parameters", {}).items():
# Handle case where param_info is a simple value (int/float) instead of a dict
if isinstance(param_info, dict):
param_type = TALIB_TYPE_MAP.get(param_info.get("type", "double"), "float")
default_value = param_info.get("default_value")
else:
# param_info is a simple value (default), infer type from the value
if isinstance(param_info, int):
param_type = "int"
elif isinstance(param_info, float):
param_type = "float"
else:
param_type = "float" # Default to float
default_value = param_info
parameters.append(
IndicatorParameter(
name=param_name,
type=param_type,
description=f"TA-Lib parameter: {param_name}",
default=default_value,
required=False
)
)
# Get function group/category
category = _get_function_category(cls.talib_function_name)
# Build display name (strip the CDL prefix from candlestick patterns)
display_name = cls.talib_function_name
if display_name.startswith("CDL"):
display_name = display_name[3:] # Remove CDL prefix for patterns
return IndicatorMetadata(
name=cls.talib_function_name,
display_name=display_name,
description=info.get("display_name", f"TA-Lib {cls.talib_function_name} indicator"),
category=category,
parameters=parameters,
use_cases=[f"Technical analysis using {cls.talib_function_name}"],
references=["https://ta-lib.org/function.html"],
tags=["talib", category, cls.talib_function_name.lower()]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
"""
Get input schema from TA-Lib function requirements.
Most TA-Lib functions use OHLCV data, but some use subsets.
"""
if not TALIB_AVAILABLE:
raise ImportError("TA-Lib is not installed")
func = abstract.Function(cls.talib_function_name)
info = func.info
input_names = info.get("input_names", {})
required_columns = []
# Map TA-Lib input names to our schema
if "prices" in input_names:
price_inputs = input_names["prices"]
if "open" in price_inputs:
required_columns.append(
ColumnInfo(name="open", type="float", description="Opening price")
)
if "high" in price_inputs:
required_columns.append(
ColumnInfo(name="high", type="float", description="High price")
)
if "low" in price_inputs:
required_columns.append(
ColumnInfo(name="low", type="float", description="Low price")
)
if "close" in price_inputs:
required_columns.append(
ColumnInfo(name="close", type="float", description="Closing price")
)
if "volume" in price_inputs:
required_columns.append(
ColumnInfo(name="volume", type="float", description="Trading volume")
)
# Handle functions that take generic price arrays
if "price" in input_names:
required_columns.append(
ColumnInfo(name="close", type="float", description="Price (typically close)")
)
# If no specific inputs found, assume close price
if not required_columns:
required_columns.append(
ColumnInfo(name="close", type="float", description="Closing price")
)
return InputSchema(required_columns=required_columns)
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
"""Get output schema from TA-Lib function outputs."""
if not TALIB_AVAILABLE:
raise ImportError("TA-Lib is not installed")
func = abstract.Function(cls.talib_function_name)
info = func.info
output_names = info.get("output_names", [])
columns = []
# Most TA-Lib functions output one or more float arrays
if isinstance(output_names, list):
for output_name in output_names:
columns.append(
ColumnInfo(
name=output_name.lower(),
type="float",
description=f"{cls.talib_function_name} output: {output_name}",
nullable=True # TA-Lib often has NaN for initial periods
)
)
else:
# Single output, use function name
columns.append(
ColumnInfo(
name=cls.talib_function_name.lower(),
type="float",
description=f"{cls.talib_function_name} indicator value",
nullable=True
)
)
return OutputSchema(columns=columns)
def compute(self, context: ComputeContext) -> ComputeResult:
"""Compute indicator using TA-Lib."""
# Extract input columns
input_data = {}
# Get the function's expected inputs
info = self._talib_func.info
input_names = info.get("input_names", {})
# Prepare input arrays
if "prices" in input_names:
price_inputs = input_names["prices"]
for price_type in price_inputs:
column_data = context.get_column(price_type)
# Convert to numpy array, replacing None with NaN
input_data[price_type] = np.array(
[float(v) if v is not None else np.nan for v in column_data]
)
elif "price" in input_names:
# Generic price input, use close
column_data = context.get_column("close")
input_data["price"] = np.array(
[float(v) if v is not None else np.nan for v in column_data]
)
else:
# Default to close if no inputs specified
column_data = context.get_column("close")
input_data["close"] = np.array(
[float(v) if v is not None else np.nan for v in column_data]
)
# Set parameters on the function
self._talib_func.parameters = self.params
# Execute TA-Lib function
try:
output = self._talib_func(input_data)
except Exception as e:
logger.error(f"TA-Lib function {self.talib_function_name} failed: {e}")
raise ValueError(f"TA-Lib computation failed: {e}") from e
# Format output
times = context.get_times()
output_names = info.get("output_names", [])
# Handle single vs multiple outputs
if isinstance(output, np.ndarray):
# Single output
output_name = output_names[0].lower() if output_names else self.talib_function_name.lower()
result_data = [
{
"time": times[i],
output_name: float(output[i]) if not np.isnan(output[i]) else None
}
for i in range(len(times))
]
elif isinstance(output, tuple):
# Multiple outputs
result_data = []
for i in range(len(times)):
row = {"time": times[i]}
for j, output_array in enumerate(output):
output_name = output_names[j].lower() if j < len(output_names) else f"output_{j}"
row[output_name] = float(output_array[i]) if not np.isnan(output_array[i]) else None
result_data.append(row)
else:
raise ValueError(f"Unexpected TA-Lib output type: {type(output)}")
return ComputeResult(
data=result_data,
is_partial=context.is_incremental
)
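The NaN handling at the end of `compute()` can be shown in pure Python (the `sma_20` column name and values are illustrative): TA-Lib pads the warm-up window with NaN, and those entries are stored as `None` so the rows serialize cleanly.

```python
import math

times = [1, 2, 3]
output = [math.nan, 1.5, 2.0]  # hypothetical single-output array with a warm-up NaN
rows = [
    {"time": t, "sma_20": None if math.isnan(v) else float(v)}
    for t, v in zip(times, output)
]
print(rows)
```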
def create_talib_indicator_class(func_name: str) -> type:
"""
Dynamically create an Indicator class for a TA-Lib function.
Args:
func_name: TA-Lib function name (e.g., 'SMA', 'RSI')
Returns:
Indicator class for this function
"""
return type(
f"TALib_{func_name}",
(TALibIndicator,),
{"talib_function_name": func_name}
)
def register_all_talib_indicators(registry) -> int:
"""
Auto-register all available TA-Lib indicators with the registry.
Args:
registry: IndicatorRegistry instance
Returns:
Number of indicators registered
Raises:
ImportError: If TA-Lib is not installed
"""
if not TALIB_AVAILABLE:
logger.warning(
"TA-Lib is not installed. Skipping TA-Lib indicator registration. "
"Install TA-Lib C library and Python wrapper to enable TA-Lib indicators."
)
return 0
# Get all TA-Lib functions
func_groups = talib.get_function_groups()
all_functions = []
for group, functions in func_groups.items():
all_functions.extend(functions)
# Remove duplicates
all_functions = sorted(set(all_functions))
registered_count = 0
for func_name in all_functions:
try:
# Create indicator class for this function
indicator_class = create_talib_indicator_class(func_name)
# Register with the registry
registry.register(indicator_class)
registered_count += 1
except Exception as e:
logger.warning(f"Failed to register TA-Lib function {func_name}: {e}")
continue
logger.info(f"Registered {registered_count} TA-Lib indicators")
return registered_count
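The group-flattening step above can be sketched with a made-up stand-in for `talib.get_function_groups()` (real TA-Lib returns its full catalog keyed by group name); duplicates across groups collapse via the set, and sorting makes registration order deterministic.

```python
func_groups = {
    "Overlap Studies": ["SMA", "EMA", "BBANDS"],
    "Momentum Indicators": ["RSI", "MACD", "SMA"],  # "SMA" duplicated on purpose
}

all_functions = []
for group, functions in func_groups.items():
    all_functions.extend(functions)

# Deduplicate and sort for a stable registration order
all_functions = sorted(set(all_functions))
print(all_functions)  # ['BBANDS', 'EMA', 'MACD', 'RSI', 'SMA']
```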
def get_talib_version() -> Optional[str]:
"""
Get the installed TA-Lib version.
Returns:
Version string or None if not installed
"""
if TALIB_AVAILABLE:
return talib.__version__
return None
def is_talib_available() -> bool:
"""Check if TA-Lib is available."""
return TALIB_AVAILABLE

View File

@@ -20,13 +20,14 @@ from gateway.hub import Gateway
from gateway.channels.websocket import WebSocketChannel
from gateway.protocol import WebSocketAgentUserMessage
from agent.core import create_agent
from agent.tools import set_registry, set_datasource_registry, set_indicator_registry
from schema.order_spec import SwapOrder
from schema.chart_state import ChartState
from datasource.registry import DataSourceRegistry
from datasource.subscription_manager import SubscriptionManager
from datasource.websocket_handler import DatafeedWebSocketHandler
from secrets_manager import SecretsStore, InvalidMasterPassword
from indicator import IndicatorRegistry, register_all_talib_indicators
# Configure logging
logging.basicConfig(
@@ -53,6 +54,9 @@ agent_executor = None
datasource_registry = DataSourceRegistry()
subscription_manager = SubscriptionManager()
# Indicator infrastructure
indicator_registry = IndicatorRegistry()
# Global secrets store
secrets_store = SecretsStore()
@@ -80,6 +84,14 @@ async def lifespan(app: FastAPI):
logger.warning(f"CCXT not available: {e}. Only demo source will be available.")
logger.info("To use real exchange data, install ccxt: pip install ccxt>=4.0.0")
# Initialize indicator registry with all TA-Lib indicators
try:
indicator_count = register_all_talib_indicators(indicator_registry)
logger.info(f"Indicator registry initialized with {indicator_count} TA-Lib indicators")
except Exception as e:
logger.warning(f"Failed to register TA-Lib indicators: {e}")
logger.info("TA-Lib indicators will not be available. Install TA-Lib C library and Python wrapper to enable.")
# Get API keys from secrets store if unlocked, otherwise fall back to environment
anthropic_api_key = None
@@ -101,6 +113,7 @@ async def lifespan(app: FastAPI):
# Set the registries for agent tools
set_registry(registry)
set_datasource_registry(datasource_registry)
set_indicator_registry(indicator_registry)
# Create and initialize agent
agent_executor = create_agent(

View File

@@ -40,6 +40,58 @@ class Exchange(StrEnum):
UNISWAP_V3 = "UniswapV3"
class Side(StrEnum):
"""Order side: buy or sell"""
BUY = "BUY"
SELL = "SELL"
class AmountType(StrEnum):
"""Whether the order amount refers to base or quote currency"""
BASE = "BASE" # Amount is in base currency (e.g., BTC in BTC/USD)
QUOTE = "QUOTE" # Amount is in quote currency (e.g., USD in BTC/USD)
class TimeInForce(StrEnum):
"""Order lifetime specification"""
GTC = "GTC" # Good Till Cancel
IOC = "IOC" # Immediate or Cancel
FOK = "FOK" # Fill or Kill
DAY = "DAY" # Good for trading day
GTD = "GTD" # Good Till Date
class ConditionalOrderMode(StrEnum):
"""How conditional orders behave on partial fills"""
NEW_PER_FILL = "NEW_PER_FILL" # Create new conditional order per each fill
UNIFIED_ADJUSTING = "UNIFIED_ADJUSTING" # Single conditional order that adjusts amount
class TriggerType(StrEnum):
"""Type of conditional trigger"""
STOP_LOSS = "STOP_LOSS"
TAKE_PROFIT = "TAKE_PROFIT"
STOP_LIMIT = "STOP_LIMIT"
TRAILING_STOP = "TRAILING_STOP"
class TickSpacingMode(StrEnum):
"""How price tick spacing is determined"""
FIXED = "FIXED" # Fixed tick size
DYNAMIC = "DYNAMIC" # Tick size varies by price level
CONTINUOUS = "CONTINUOUS" # No tick restrictions
class AssetType(StrEnum):
"""Type of tradeable asset"""
SPOT = "SPOT" # Spot/cash market
MARGIN = "MARGIN" # Margin trading
PERP = "PERP" # Perpetual futures
FUTURE = "FUTURE" # Dated futures
OPTION = "OPTION" # Options
SYNTHETIC = "SYNTHETIC" # Synthetic/derived instruments
class OcoMode(StrEnum): class OcoMode(StrEnum):
NO_OCO = "NO_OCO" NO_OCO = "NO_OCO"
CANCEL_ON_PARTIAL_FILL = "CANCEL_ON_PARTIAL_FILL" CANCEL_ON_PARTIAL_FILL = "CANCEL_ON_PARTIAL_FILL"
@@ -96,6 +148,126 @@ class TrancheStatus(BaseModel):
endTime: Uint32 = Field(description="Concrete end timestamp") endTime: Uint32 = Field(description="Concrete end timestamp")
# ---------------------------------------------------------------------------
# Standard Order Models
# ---------------------------------------------------------------------------
class ConditionalTrigger(BaseModel):
"""Conditional order trigger (stop-loss, take-profit, etc.)"""
model_config = {"extra": "forbid"}
trigger_type: TriggerType
trigger_price: Float = Field(description="Price at which conditional order activates")
trailing_delta: Float | None = Field(default=None, description="For trailing stops: delta from peak/trough")
class AmountConstraints(BaseModel):
"""Constraints on order amounts for a symbol"""
model_config = {"extra": "forbid"}
min_amount: Float = Field(description="Minimum order amount")
max_amount: Float = Field(description="Maximum order amount")
step_size: Float = Field(description="Amount increment granularity")
class PriceConstraints(BaseModel):
"""Constraints on order pricing for a symbol"""
model_config = {"extra": "forbid"}
tick_spacing_mode: TickSpacingMode
tick_size: Float | None = Field(default=None, description="Fixed tick size (if FIXED mode)")
min_price: Float | None = Field(default=None, description="Minimum allowed price")
max_price: Float | None = Field(default=None, description="Maximum allowed price")
class MarketCapabilities(BaseModel):
"""Describes what order features a market supports"""
model_config = {"extra": "forbid"}
supported_sides: list[Side] = Field(description="Supported order sides (usually both)")
supported_amount_types: list[AmountType] = Field(description="Whether BASE, QUOTE, or both amounts are supported")
supports_market_orders: bool = Field(description="Whether market orders are supported")
supports_limit_orders: bool = Field(description="Whether limit orders are supported")
supported_time_in_force: list[TimeInForce] = Field(description="Supported order lifetimes")
supports_conditional_orders: bool = Field(description="Whether stop-loss/take-profit are supported")
supported_trigger_types: list[TriggerType] = Field(default_factory=list, description="Supported trigger types")
supports_post_only: bool = Field(default=False, description="Whether post-only orders are supported")
supports_reduce_only: bool = Field(default=False, description="Whether reduce-only orders are supported")
supports_iceberg: bool = Field(default=False, description="Whether iceberg orders are supported")
market_order_amount_type: AmountType | None = Field(
default=None,
description="Required amount type for market orders (some DEXs require exact-in)"
)
class SymbolMetadata(BaseModel):
"""Complete metadata describing a tradeable symbol/market"""
model_config = {"extra": "forbid"}
symbol_id: str = Field(description="Unique symbol identifier")
base_asset: str = Field(description="Base asset (e.g., 'BTC')")
quote_asset: str = Field(description="Quote asset (e.g., 'USD')")
asset_type: AssetType = Field(description="Type of market")
exchange: str = Field(description="Exchange identifier")
amount_constraints: AmountConstraints
price_constraints: PriceConstraints
capabilities: MarketCapabilities
contract_size: Float | None = Field(default=None, description="For futures/options: contract multiplier")
settlement_asset: str | None = Field(default=None, description="For derivatives: settlement currency")
expiry_timestamp: Uint64 | None = Field(default=None, description="For dated futures/options: expiration")
class StandardOrder(BaseModel):
"""Standard order specification for exchange kernels"""
model_config = {"extra": "forbid"}
symbol_id: str = Field(description="Symbol to trade")
side: Side = Field(description="Buy or sell")
amount: Float = Field(description="Order amount")
amount_type: AmountType = Field(description="Whether amount is BASE or QUOTE currency")
limit_price: Float | None = Field(default=None, description="Limit price (None = market order)")
time_in_force: TimeInForce = Field(default=TimeInForce.GTC, description="Order lifetime")
good_till_date: Uint64 | None = Field(default=None, description="Expiry timestamp for GTD orders")
conditional_trigger: ConditionalTrigger | None = Field(
default=None,
description="Stop-loss/take-profit trigger"
)
conditional_mode: ConditionalOrderMode | None = Field(
default=None,
description="How conditional orders behave on partial fills"
)
reduce_only: bool = Field(default=False, description="Only reduce existing position")
post_only: bool = Field(default=False, description="Only make, never take")
iceberg_qty: Float | None = Field(default=None, description="Visible amount for iceberg orders")
client_order_id: str | None = Field(default=None, description="Client-specified order ID")
class StandardOrderStatus(BaseModel):
"""Current status of a standard order"""
model_config = {"extra": "forbid"}
order: StandardOrder
order_id: str = Field(description="Exchange-assigned order ID")
status: str = Field(description="Order status: NEW, PARTIALLY_FILLED, FILLED, CANCELED, REJECTED, EXPIRED")
filled_amount: Float = Field(description="Amount filled so far")
average_fill_price: Float = Field(description="Average execution price")
created_at: Uint64 = Field(description="Order creation timestamp")
updated_at: Uint64 = Field(description="Last update timestamp")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Order models # Order models
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -117,7 +289,22 @@ class SwapOrder(BaseModel):
tranches: list[Tranche] = Field(min_length=1) tranches: list[Tranche] = Field(min_length=1)
class StandardOrderGroup(BaseModel):
"""Group of orders with OCO (One-Cancels-Other) relationship"""
model_config = {"extra": "forbid"}
mode: OcoMode
orders: list[StandardOrder] = Field(min_length=1)
# ---------------------------------------------------------------------------
# Legacy swap order models (kept for backward compatibility)
# ---------------------------------------------------------------------------
class OcoGroup(BaseModel): class OcoGroup(BaseModel):
"""DEPRECATED: Use StandardOrderGroup instead"""
model_config = {"extra": "forbid"} model_config = {"extra": "forbid"}
mode: OcoMode mode: OcoMode

View File

@@ -0,0 +1,40 @@
"""
Encrypted secrets management with master password protection.
This module provides secure storage for sensitive configuration like API keys,
using Argon2id for password-based key derivation and Fernet (AES-128-CBC with HMAC-SHA256) for encryption.
Basic usage:
from secrets_manager import SecretsStore
# First time setup
store = SecretsStore()
store.initialize("my-master-password")
store.set("ANTHROPIC_API_KEY", "sk-ant-...")
# Later usage
store = SecretsStore()
store.unlock("my-master-password")
api_key = store.get("ANTHROPIC_API_KEY")
Command-line interface:
python -m secrets_manager.cli init
python -m secrets_manager.cli set KEY VALUE
python -m secrets_manager.cli get KEY
python -m secrets_manager.cli list
python -m secrets_manager.cli change-password
"""
from .store import (
SecretsStore,
SecretsStoreError,
SecretsStoreLocked,
InvalidMasterPassword,
)
__all__ = [
"SecretsStore",
"SecretsStoreError",
"SecretsStoreLocked",
"InvalidMasterPassword",
]

View File

@@ -0,0 +1,374 @@
#!/usr/bin/env python3
"""
Command-line interface for managing the encrypted secrets store.
Usage:
python -m secrets_manager.cli init              # Initialize new secrets store
python -m secrets_manager.cli set KEY VALUE     # Set a secret
python -m secrets_manager.cli get KEY           # Get a secret
python -m secrets_manager.cli delete KEY        # Delete a secret
python -m secrets_manager.cli list              # List all secret keys
python -m secrets_manager.cli change-password   # Change master password
python -m secrets_manager.cli export FILE       # Export encrypted backup
python -m secrets_manager.cli import FILE       # Import encrypted backup
python -m secrets_manager.cli migrate-from-env  # Migrate secrets from .env file
"""
import sys
import argparse
import getpass
from pathlib import Path
from .store import SecretsStore, SecretsStoreError, InvalidMasterPassword
def get_password(prompt: str = "Master password: ", confirm: bool = False) -> str:
"""
Securely get password from user.
Args:
prompt: Password prompt
confirm: If True, ask for confirmation
Returns:
Password string
"""
password = getpass.getpass(prompt)
if confirm:
confirm_password = getpass.getpass("Confirm password: ")
if password != confirm_password:
print("Error: Passwords do not match", file=sys.stderr)
sys.exit(1)
return password
def cmd_init(args):
"""Initialize a new secrets store."""
store = SecretsStore()
if store.is_initialized:
print("Error: Secrets store is already initialized", file=sys.stderr)
print(f"Location: {store.secrets_file}", file=sys.stderr)
sys.exit(1)
password = get_password("Create master password: ", confirm=True)
if len(password) < 8:
print("Error: Password must be at least 8 characters", file=sys.stderr)
sys.exit(1)
store.initialize(password)
print(f"Secrets store initialized at {store.secrets_file}")
def cmd_set(args):
"""Set a secret value."""
store = SecretsStore()
if not store.is_initialized:
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
sys.exit(1)
password = get_password()
try:
store.unlock(password)
except InvalidMasterPassword:
print("Error: Invalid master password", file=sys.stderr)
sys.exit(1)
store.set(args.key, args.value)
print(f"✓ Secret '{args.key}' saved")
def cmd_get(args):
"""Get a secret value."""
store = SecretsStore()
if not store.is_initialized:
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
sys.exit(1)
password = get_password()
try:
store.unlock(password)
except InvalidMasterPassword:
print("Error: Invalid master password", file=sys.stderr)
sys.exit(1)
value = store.get(args.key)
if value is None:
print(f"Error: Secret '{args.key}' not found", file=sys.stderr)
sys.exit(1)
# Print to stdout (can be captured)
print(value)
def cmd_delete(args):
"""Delete a secret."""
store = SecretsStore()
if not store.is_initialized:
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
sys.exit(1)
password = get_password()
try:
store.unlock(password)
except InvalidMasterPassword:
print("Error: Invalid master password", file=sys.stderr)
sys.exit(1)
if store.delete(args.key):
print(f"✓ Secret '{args.key}' deleted")
else:
print(f"Error: Secret '{args.key}' not found", file=sys.stderr)
sys.exit(1)
def cmd_list(args):
"""List all secret keys."""
store = SecretsStore()
if not store.is_initialized:
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
sys.exit(1)
password = get_password()
try:
store.unlock(password)
except InvalidMasterPassword:
print("Error: Invalid master password", file=sys.stderr)
sys.exit(1)
keys = store.list_keys()
if not keys:
print("No secrets stored")
else:
print(f"Stored secrets ({len(keys)}):")
for key in sorted(keys):
# Show key and value length only (avoid echoing secret material to the terminal)
value = store.get(key)
print(f"  {key}: ({len(str(value))} chars)")
def cmd_change_password(args):
"""Change the master password."""
store = SecretsStore()
if not store.is_initialized:
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
sys.exit(1)
current_password = get_password("Current master password: ")
new_password = get_password("New master password: ", confirm=True)
if len(new_password) < 8:
print("Error: Password must be at least 8 characters", file=sys.stderr)
sys.exit(1)
try:
store.change_master_password(current_password, new_password)
except InvalidMasterPassword:
print("Error: Invalid current password", file=sys.stderr)
sys.exit(1)
def cmd_export(args):
"""Export encrypted secrets to a backup file."""
store = SecretsStore()
if not store.is_initialized:
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
sys.exit(1)
output_path = Path(args.file)
if output_path.exists() and not args.force:
print(f"Error: File {output_path} already exists. Use --force to overwrite.", file=sys.stderr)
sys.exit(1)
try:
store.export_encrypted(output_path)
except SecretsStoreError as e:
print(f"Error: {e}", file=sys.stderr)
sys.exit(1)
def cmd_import(args):
"""Import encrypted secrets from a backup file."""
store = SecretsStore()
if not store.is_initialized:
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
sys.exit(1)
input_path = Path(args.file)
if not input_path.exists():
print(f"Error: File {input_path} does not exist", file=sys.stderr)
sys.exit(1)
password = get_password()
try:
store.import_encrypted(input_path, password)
except InvalidMasterPassword:
print("Error: Invalid master password or incompatible backup", file=sys.stderr)
sys.exit(1)
except SecretsStoreError as e:
print(f"Error: {e}", file=sys.stderr)
sys.exit(1)
def cmd_migrate_from_env(args):
"""Migrate secrets from .env file to encrypted store."""
store = SecretsStore()
if not store.is_initialized:
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
sys.exit(1)
# Look for .env file
backend_root = Path(__file__).parent.parent.parent
env_file = backend_root / ".env"
if not env_file.exists():
print(f"Error: .env file not found at {env_file}", file=sys.stderr)
sys.exit(1)
password = get_password()
try:
store.unlock(password)
except InvalidMasterPassword:
print("Error: Invalid master password", file=sys.stderr)
sys.exit(1)
# Parse .env file (simple parser - doesn't handle all edge cases)
migrated = 0
skipped = 0
with open(env_file) as f:
for line_num, line in enumerate(f, 1):
line = line.strip()
# Skip empty lines and comments
if not line or line.startswith('#'):
continue
# Parse KEY=VALUE format
if '=' not in line:
print(f"Warning: Skipping invalid line {line_num}: {line}", file=sys.stderr)
skipped += 1
continue
key, value = line.split('=', 1)
key = key.strip()
value = value.strip()
# Remove quotes if present
if value.startswith('"') and value.endswith('"'):
value = value[1:-1]
elif value.startswith("'") and value.endswith("'"):
value = value[1:-1]
# Check if key already exists
existing = store.get(key)
if existing is not None:
print(f"Warning: Secret '{key}' already exists, skipping", file=sys.stderr)
skipped += 1
continue
store.set(key, value)
print(f"✓ Migrated: {key}")
migrated += 1
print(f"\nMigration complete: {migrated} secrets migrated, {skipped} skipped")
if not args.keep_env:
# Ask for confirmation before deleting .env
confirm = input(f"\nDelete {env_file}? [y/N]: ").strip().lower()
if confirm == 'y':
env_file.unlink()
print(f"✓ Deleted {env_file}")
else:
print(f"Kept {env_file} (consider deleting it manually)")
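The KEY=VALUE parsing loop above can be sketched as a standalone helper (stdlib only; same behavior as the loop, plus a length guard so a lone quote character is not stripped down to an empty string):

```python
def parse_env_line(line: str):
    """Parse one KEY=VALUE line; return None for blanks, comments, invalid lines."""
    line = line.strip()
    if not line or line.startswith("#") or "=" not in line:
        return None
    key, value = line.split("=", 1)
    key, value = key.strip(), value.strip()
    # Strip one layer of matching quotes (length guard: a lone quote stays put)
    if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
        value = value[1:-1]
    return key, value

assert parse_env_line('API_KEY="sk-123"') == ("API_KEY", "sk-123")
assert parse_env_line("# comment") is None
assert parse_env_line("A=b=c") == ("A", "b=c")  # split on first '=' only
```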
def main():
"""Main CLI entry point."""
parser = argparse.ArgumentParser(
description="Manage encrypted secrets store",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
subparsers = parser.add_subparsers(dest='command', help='Command to run')
subparsers.required = True
# init
parser_init = subparsers.add_parser('init', help='Initialize new secrets store')
parser_init.set_defaults(func=cmd_init)
# set
parser_set = subparsers.add_parser('set', help='Set a secret value')
parser_set.add_argument('key', help='Secret key name')
parser_set.add_argument('value', help='Secret value')
parser_set.set_defaults(func=cmd_set)
# get
parser_get = subparsers.add_parser('get', help='Get a secret value')
parser_get.add_argument('key', help='Secret key name')
parser_get.set_defaults(func=cmd_get)
# delete
parser_delete = subparsers.add_parser('delete', help='Delete a secret')
parser_delete.add_argument('key', help='Secret key name')
parser_delete.set_defaults(func=cmd_delete)
# list
parser_list = subparsers.add_parser('list', help='List all secret keys')
parser_list.set_defaults(func=cmd_list)
# change-password
parser_change = subparsers.add_parser('change-password', help='Change master password')
parser_change.set_defaults(func=cmd_change_password)
# export
parser_export = subparsers.add_parser('export', help='Export encrypted backup')
parser_export.add_argument('file', help='Output file path')
parser_export.add_argument('--force', action='store_true', help='Overwrite existing file')
parser_export.set_defaults(func=cmd_export)
# import
parser_import = subparsers.add_parser('import', help='Import encrypted backup')
parser_import.add_argument('file', help='Input file path')
parser_import.set_defaults(func=cmd_import)
# migrate-from-env
parser_migrate = subparsers.add_parser('migrate-from-env', help='Migrate from .env file')
parser_migrate.add_argument('--keep-env', action='store_true', help='Keep .env file after migration')
parser_migrate.set_defaults(func=cmd_migrate_from_env)
args = parser.parse_args()
try:
args.func(args)
except KeyboardInterrupt:
print("\nAborted", file=sys.stderr)
sys.exit(130)
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
sys.exit(1)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,144 @@
"""
Cryptographic utilities for secrets management.
Uses Argon2id for password-based key derivation and Fernet for encryption.
"""
import base64
import secrets as secrets_module
from argon2.low_level import hash_secret_raw, Type
from cryptography.fernet import Fernet
# Argon2id parameters (aligned with RFC 9106 / OWASP guidance for password KDFs)
# These provide strong defense against GPU/ASIC attacks
ARGON2_TIME_COST = 3 # iterations
ARGON2_MEMORY_COST = 65536 # 64 MB
ARGON2_PARALLELISM = 4 # threads
ARGON2_HASH_LENGTH = 32 # bytes (256 bits for Fernet key)
ARGON2_SALT_LENGTH = 16 # bytes (128 bits)
def generate_salt() -> bytes:
"""Generate a cryptographically secure random salt."""
return secrets_module.token_bytes(ARGON2_SALT_LENGTH)
def derive_key_from_password(password: str, salt: bytes) -> bytes:
"""
Derive an encryption key from a password using Argon2id.
Args:
password: The master password
salt: The salt (must be consistent for the same password to work)
Returns:
32-byte key suitable for Fernet encryption
"""
password_bytes = password.encode('utf-8')
# Use Argon2id (hybrid mode - best of Argon2i and Argon2d)
raw_hash = hash_secret_raw(
secret=password_bytes,
salt=salt,
time_cost=ARGON2_TIME_COST,
memory_cost=ARGON2_MEMORY_COST,
parallelism=ARGON2_PARALLELISM,
hash_len=ARGON2_HASH_LENGTH,
type=Type.ID # Argon2id
)
return raw_hash
def create_fernet(key: bytes) -> Fernet:
"""
Create a Fernet cipher instance from a raw key.
Args:
key: 32-byte raw key from Argon2id
Returns:
Fernet instance for encryption/decryption
"""
# Fernet requires a URL-safe base64-encoded 32-byte key
fernet_key = base64.urlsafe_b64encode(key)
return Fernet(fernet_key)
def encrypt_data(data: bytes, key: bytes) -> bytes:
"""
Encrypt data using Fernet (AES-128-CBC with an HMAC-SHA256 authentication tag).
Args:
data: Raw bytes to encrypt
key: 32-byte encryption key
Returns:
Encrypted data (includes IV and auth tag)
"""
fernet = create_fernet(key)
return fernet.encrypt(data)
def decrypt_data(encrypted_data: bytes, key: bytes) -> bytes:
"""
Decrypt data using Fernet.
Args:
encrypted_data: Encrypted bytes from encrypt_data
key: 32-byte encryption key (must match encryption key)
Returns:
Decrypted raw bytes
Raises:
cryptography.fernet.InvalidToken: If decryption fails (wrong key/corrupted data)
"""
fernet = create_fernet(key)
return fernet.decrypt(encrypted_data)
def create_verification_hash(password: str, salt: bytes) -> str:
"""
Create a verification hash to check if a password is correct.
This is NOT for storing the password - it's for verifying the password
unlocks the correct key without trying to decrypt the entire secrets file.
Args:
password: The master password
salt: The salt used for key derivation
Returns:
Base64-encoded hash for verification
"""
# Derive the key, then persist only a SHA-256 digest of it. Never store
# raw key bytes: writing key[:16] to disk would leak half the Fernet key.
import hashlib
key = derive_key_from_password(password, salt)
verification = base64.b64encode(hashlib.sha256(key).digest()[:16]).decode('ascii')
return verification
def verify_password(password: str, salt: bytes, verification_hash: str) -> bool:
"""
Verify a password against a verification hash.
Args:
password: Password to verify
salt: Salt used for key derivation
verification_hash: Expected verification hash
Returns:
True if password is correct, False otherwise
"""
computed_hash = create_verification_hash(password, salt)
# Constant-time comparison to prevent timing attacks
return secrets_module.compare_digest(computed_hash, verification_hash)
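To make the derive → verify → Fernet-key flow concrete without the third-party dependencies, here is a stdlib-only sketch; PBKDF2-HMAC-SHA256 stands in for Argon2id (an assumption made only so the sketch runs on the standard library — the module itself uses `hash_secret_raw`), and the stored tag is a digest of the key rather than raw key bytes:

```python
import base64
import hashlib
import secrets

def derive_key(password: str, salt: bytes) -> bytes:
    # Stand-in KDF: 32-byte key, same shape as the Argon2id output
    return hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt,
                               600_000, dklen=32)

def verification_tag(key: bytes) -> str:
    # Persist a digest of the key, never the key itself
    return base64.b64encode(hashlib.sha256(key).digest()[:16]).decode("ascii")

salt = secrets.token_bytes(16)
key = derive_key("correct horse battery staple", salt)
stored_tag = verification_tag(key)

# Later: verify a candidate password without touching the ciphertext,
# using a constant-time comparison
candidate_key = derive_key("correct horse battery staple", salt)
assert secrets.compare_digest(verification_tag(candidate_key), stored_tag)

# Fernet expects the 32-byte key URL-safe base64 encoded (44 chars with padding)
fernet_key = base64.urlsafe_b64encode(key)
assert len(fernet_key) == 44
```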

View File

@@ -0,0 +1,406 @@
"""
Encrypted secrets store with master password protection.
The secrets are stored in an encrypted file, with the encryption key derived
from a master password using Argon2id. Changing the master password
re-encrypts the secrets with a key derived from the new password.
"""
import json
import os
import stat
from pathlib import Path
from typing import Dict, Optional, Any
from cryptography.fernet import InvalidToken
from .crypto import (
generate_salt,
derive_key_from_password,
encrypt_data,
decrypt_data,
create_verification_hash,
verify_password,
)
class SecretsStoreError(Exception):
"""Base exception for secrets store errors."""
pass
class SecretsStoreLocked(SecretsStoreError):
"""Raised when trying to access secrets while store is locked."""
pass
class InvalidMasterPassword(SecretsStoreError):
"""Raised when master password is incorrect."""
pass
class SecretsStore:
"""
Encrypted secrets store with master password protection.
Usage:
# Initialize (first time)
store = SecretsStore()
store.initialize("my-secure-password")
# Unlock
store = SecretsStore()
store.unlock("my-secure-password")
# Access secrets
api_key = store.get("ANTHROPIC_API_KEY")
store.set("NEW_SECRET", "secret-value")
# Change master password
store.change_master_password("my-secure-password", "new-password")
# Lock when done
store.lock()
"""
def __init__(self, data_dir: Optional[Path] = None):
"""
Initialize secrets store.
Args:
data_dir: Directory for secrets files (defaults to backend/data)
"""
if data_dir is None:
# Default to backend/data
backend_root = Path(__file__).parent.parent.parent
data_dir = backend_root / "data"
self.data_dir = Path(data_dir)
self.data_dir.mkdir(parents=True, exist_ok=True)
self.master_key_file = self.data_dir / ".master.key"
self.secrets_file = self.data_dir / "secrets.enc"
# Runtime state
self._encryption_key: Optional[bytes] = None
self._secrets: Optional[Dict[str, Any]] = None
@property
def is_initialized(self) -> bool:
"""Check if the secrets store has been initialized."""
return self.master_key_file.exists()
@property
def is_unlocked(self) -> bool:
"""Check if the secrets store is currently unlocked."""
return self._encryption_key is not None
def initialize(self, master_password: str) -> None:
"""
Initialize the secrets store with a master password.
This should only be called once when setting up the store.
Args:
master_password: The master password to protect the secrets
Raises:
SecretsStoreError: If store is already initialized
"""
if self.is_initialized:
raise SecretsStoreError(
"Secrets store is already initialized. "
"Use unlock() to access it or change_master_password() to change the password."
)
# Generate a new random salt
salt = generate_salt()
# Derive encryption key
encryption_key = derive_key_from_password(master_password, salt)
# Create verification hash
verification_hash = create_verification_hash(master_password, salt)
# Store salt and verification hash
master_key_data = {
"salt": salt.hex(),
"verification": verification_hash,
}
self.master_key_file.write_text(json.dumps(master_key_data, indent=2))
# Set restrictive permissions (owner read/write only)
os.chmod(self.master_key_file, stat.S_IRUSR | stat.S_IWUSR)
# Initialize empty secrets
self._encryption_key = encryption_key
self._secrets = {}
self._save_secrets()
print(f"✓ Secrets store initialized at {self.secrets_file}")
def unlock(self, master_password: str) -> None:
"""
Unlock the secrets store with the master password.
Args:
master_password: The master password
Raises:
SecretsStoreError: If store is not initialized
InvalidMasterPassword: If password is incorrect
"""
if not self.is_initialized:
raise SecretsStoreError(
"Secrets store is not initialized. Call initialize() first."
)
# Load salt and verification hash
master_key_data = json.loads(self.master_key_file.read_text())
salt = bytes.fromhex(master_key_data["salt"])
verification_hash = master_key_data["verification"]
# Verify password
if not verify_password(master_password, salt, verification_hash):
raise InvalidMasterPassword("Invalid master password")
# Derive encryption key
encryption_key = derive_key_from_password(master_password, salt)
# Load and decrypt secrets
if self.secrets_file.exists():
try:
encrypted_data = self.secrets_file.read_bytes()
decrypted_data = decrypt_data(encrypted_data, encryption_key)
self._secrets = json.loads(decrypted_data.decode('utf-8'))
except InvalidToken:
raise InvalidMasterPassword("Failed to decrypt secrets (invalid password)")
except json.JSONDecodeError as e:
raise SecretsStoreError(f"Corrupted secrets file: {e}")
else:
# No secrets file yet (fresh initialization)
self._secrets = {}
self._encryption_key = encryption_key
print(f"✓ Secrets store unlocked ({len(self._secrets)} secrets)")
def lock(self) -> None:
"""Lock the secrets store (clear decrypted data from memory)."""
self._encryption_key = None
self._secrets = None
def get(self, key: str, default: Any = None) -> Any:
"""
Get a secret value.
Args:
key: Secret key name
default: Default value if key doesn't exist
Returns:
Secret value or default
Raises:
SecretsStoreLocked: If store is locked
"""
if not self.is_unlocked:
raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")
return self._secrets.get(key, default)
def set(self, key: str, value: Any) -> None:
"""
Set a secret value.
Args:
key: Secret key name
value: Secret value (must be JSON-serializable)
Raises:
SecretsStoreLocked: If store is locked
"""
if not self.is_unlocked:
raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")
self._secrets[key] = value
self._save_secrets()
def delete(self, key: str) -> bool:
"""
Delete a secret.
Args:
key: Secret key name
Returns:
True if secret existed and was deleted, False otherwise
Raises:
SecretsStoreLocked: If store is locked
"""
if not self.is_unlocked:
raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")
if key in self._secrets:
del self._secrets[key]
self._save_secrets()
return True
return False
def list_keys(self) -> list[str]:
"""
List all secret keys.
Returns:
List of secret keys
Raises:
SecretsStoreLocked: If store is locked
"""
if not self.is_unlocked:
raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")
return list(self._secrets.keys())
def change_master_password(self, current_password: str, new_password: str) -> None:
"""
Change the master password.
This re-encrypts the secrets with a new key derived from the new password.
Args:
current_password: Current master password
new_password: New master password
Raises:
InvalidMasterPassword: If current password is incorrect
"""
# ALWAYS verify current password before changing
# Load salt and verification hash
if not self.is_initialized:
raise SecretsStoreError(
"Secrets store is not initialized. Call initialize() first."
)
master_key_data = json.loads(self.master_key_file.read_text())
salt = bytes.fromhex(master_key_data["salt"])
verification_hash = master_key_data["verification"]
# Verify current password is correct
if not verify_password(current_password, salt, verification_hash):
raise InvalidMasterPassword("Invalid current password")
# Unlock if needed to access secrets
was_unlocked = self.is_unlocked
if not was_unlocked:
# Store is locked, so unlock with current password
# (we already verified it above, so this will succeed)
encryption_key = derive_key_from_password(current_password, salt)
# Load and decrypt secrets
if self.secrets_file.exists():
encrypted_data = self.secrets_file.read_bytes()
decrypted_data = decrypt_data(encrypted_data, encryption_key)
self._secrets = json.loads(decrypted_data.decode('utf-8'))
else:
self._secrets = {}
self._encryption_key = encryption_key
# Generate new salt
new_salt = generate_salt()
# Derive new encryption key
new_encryption_key = derive_key_from_password(new_password, new_salt)
# Create new verification hash
new_verification_hash = create_verification_hash(new_password, new_salt)
# Update master key file
master_key_data = {
"salt": new_salt.hex(),
"verification": new_verification_hash,
}
self.master_key_file.write_text(json.dumps(master_key_data, indent=2))
os.chmod(self.master_key_file, stat.S_IRUSR | stat.S_IWUSR)
# Re-encrypt secrets with new key
self._encryption_key = new_encryption_key
self._save_secrets()
print("✓ Master password changed successfully")
# Lock if it wasn't unlocked before
if not was_unlocked:
self.lock()
def _save_secrets(self) -> None:
"""Save secrets to encrypted file."""
if not self.is_unlocked:
raise SecretsStoreLocked("Cannot save while locked")
# Serialize secrets to JSON
secrets_json = json.dumps(self._secrets, indent=2)
secrets_bytes = secrets_json.encode('utf-8')
# Encrypt
encrypted_data = encrypt_data(secrets_bytes, self._encryption_key)
# Write to file
self.secrets_file.write_bytes(encrypted_data)
# Set restrictive permissions
os.chmod(self.secrets_file, stat.S_IRUSR | stat.S_IWUSR)
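The write-then-restrict pattern used by `_save_secrets` can be demonstrated in isolation (POSIX permission bits assumed; on Windows, `chmod` only toggles the read-only flag):

```python
import os
import stat
import tempfile

# Write the (already encrypted) payload, then drop group/other permissions
path = os.path.join(tempfile.mkdtemp(), "secrets.enc")
with open(path, "wb") as f:
    f.write(b"\x80fake-ciphertext")
os.chmod(path, stat.S_IRUSR | stat.S_IWUSR)  # 0o600: owner read/write only

mode = stat.S_IMODE(os.stat(path).st_mode)
assert mode == 0o600
```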
def export_encrypted(self, output_path: Path) -> None:
"""
Export encrypted secrets to a file (for backup).
Args:
output_path: Path to export file
Raises:
SecretsStoreError: If secrets file doesn't exist
"""
if not self.secrets_file.exists():
raise SecretsStoreError("No secrets to export")
import shutil
shutil.copy2(self.secrets_file, output_path)
print(f"✓ Encrypted secrets exported to {output_path}")
def import_encrypted(self, input_path: Path, master_password: str) -> None:
"""
Import encrypted secrets from a file.
This will verify the password can decrypt the import before replacing
the current secrets.
Args:
input_path: Path to import file
master_password: Master password for the current store
Raises:
InvalidMasterPassword: If password doesn't work with import
"""
# Always verify the supplied password, even if the store is already unlocked
self.unlock(master_password)
# Try to decrypt the imported file with current key
try:
encrypted_data = Path(input_path).read_bytes()
decrypted_data = decrypt_data(encrypted_data, self._encryption_key)
imported_secrets = json.loads(decrypted_data.decode('utf-8'))
except InvalidToken:
raise InvalidMasterPassword(
"Cannot decrypt imported secrets with current master password"
)
except json.JSONDecodeError as e:
raise SecretsStoreError(f"Corrupted import file: {e}")
# Replace secrets
self._secrets = imported_secrets
self._save_secrets()
print(f"✓ Imported {len(self._secrets)} secrets from {input_path}")
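The store above leans on helper functions (`derive_key_from_password`, `create_verification_hash`) whose definitions are outside this section. A minimal sketch of what such helpers might look like, assuming PBKDF2-HMAC-SHA256 key derivation — the real implementations, iteration count, and domain-separation scheme may differ:

```python
import hashlib
import hmac
import os

# Hypothetical sketches of the helpers used by the store above; the
# actual implementations are defined elsewhere and may differ.

PBKDF2_ITERATIONS = 200_000  # assumed work factor


def derive_key_from_password(password: str, salt: bytes) -> bytes:
    """Derive a 32-byte encryption key from a password and salt."""
    return hashlib.pbkdf2_hmac(
        'sha256', password.encode('utf-8'), salt, PBKDF2_ITERATIONS
    )


def create_verification_hash(password: str, salt: bytes) -> str:
    """Hash stored on disk to check the password without exposing the key."""
    # Domain-separate from the encryption key so the stored hash
    # cannot double as key material.
    return hashlib.pbkdf2_hmac(
        'sha256', password.encode('utf-8'), salt + b'verify', PBKDF2_ITERATIONS
    ).hex()


def verify_password(password: str, salt: bytes, stored_hash: str) -> bool:
    """Constant-time comparison against the stored verification hash."""
    candidate = create_verification_hash(password, salt)
    return hmac.compare_digest(candidate, stored_hash)


salt = os.urandom(16)
key = derive_key_from_password("example-password", salt)
print(len(key))  # 32
```

Deriving the verification hash with a different salt suffix than the encryption key means the master-key file never stores anything that can decrypt the secrets on its own.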

@@ -2,8 +2,10 @@
 Test script for CCXT DataSource adapter (Free Version).
 
 This demonstrates how to use the free CCXT adapter (not ccxt.pro) with various
-exchanges. It uses polling instead of WebSocket for real-time updates and
-verifies that Decimal precision is maintained throughout.
+exchanges. It uses polling instead of WebSocket for real-time updates.
+
+CCXT is configured to use Decimal mode internally for precision, but OHLCV data
+is converted to float for optimal DataFrame/analysis performance.
 """
 
 import asyncio
@@ -73,10 +75,10 @@ async def test_binance_datasource():
     print(f" Close: {latest.data['close']} (type: {type(latest.data['close']).__name__})")
     print(f" Volume: {latest.data['volume']} (type: {type(latest.data['volume']).__name__})")
 
-    # Verify Decimal precision
-    assert isinstance(latest.data['close'], Decimal), "Price should be Decimal type!"
-    assert isinstance(latest.data['volume'], Decimal), "Volume should be Decimal type!"
-    print(f"Numerical precision verified: using Decimal types")
+    # Verify OHLCV uses float (converted from Decimal for DataFrame performance)
+    assert isinstance(latest.data['close'], float), "OHLCV price should be float type!"
+    assert isinstance(latest.data['volume'], float), "OHLCV volume should be float type!"
+    print(f"OHLCV data type verified: using native float (CCXT uses Decimal internally)")
 
     # Test 5: Polling subscription (brief test)
     print("\n5. Testing polling-based subscription...")
@@ -87,7 +89,7 @@ async def test_binance_datasource():
         tick_count[0] += 1
         if tick_count[0] == 1:
             print(f" Received tick: close={data['close']} (type: {type(data['close']).__name__})")
-            assert isinstance(data['close'], Decimal), "Polled data should use Decimal!"
+            assert isinstance(data['close'], float), "Polled OHLCV data should use float!"
 
     subscription_id = await binance.subscribe_bars(
         symbol="BTC/USDT",
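The hunks above describe converting CCXT's Decimal values to native floats before OHLCV rows reach DataFrames. A minimal sketch of that conversion — the function and field layout are illustrative, not the adapter's actual API:

```python
from decimal import Decimal

# Illustrative sketch: coerce a CCXT-style OHLCV row whose values may be
# Decimal (ccxt's precise mode) into plain floats for DataFrame work.
# The helper and field names are assumptions, not the adapter's API.

OHLCV_FIELDS = ('open', 'high', 'low', 'close', 'volume')


def ohlcv_to_float(bar: dict) -> dict:
    """Return a copy of the bar with numeric OHLCV fields coerced to float."""
    out = dict(bar)
    for field in OHLCV_FIELDS:
        value = out.get(field)
        if isinstance(value, Decimal):
            out[field] = float(value)
    return out


bar = {
    'timestamp': 1700000000000,
    'open': Decimal('42000.5'),
    'close': Decimal('42100.25'),
    'volume': Decimal('1.5'),
}
converted = ohlcv_to_float(bar)
assert isinstance(converted['close'], float)
```

Keeping Decimal confined to order sizing and price formatting, while analysis code sees floats, avoids the per-element overhead Decimal imposes on vectorized operations.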

@@ -1,3 +1,27 @@
 FROM python:3.14-alpine
 
-COPY python/src /app/src
+# Install TA-Lib C library and build dependencies
+RUN apk add --no-cache --virtual .build-deps \
+    gcc \
+    g++ \
+    make \
+    musl-dev \
+    wget \
+    && apk add --no-cache \
+    ta-lib \
+    && rm -rf /var/cache/apk/*
+
+# Set working directory
+WORKDIR /app
+
+# Copy requirements first for better caching
+COPY backend/requirements.txt /app/requirements.txt
+
+# Install Python dependencies
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Clean up build dependencies
+RUN apk del .build-deps
+
+# Copy application code
+COPY backend/src /app/src

@@ -13,6 +13,7 @@ import { wsManager } from './composables/useWebSocket'
 const isAuthenticated = ref(false)
 const needsConfirmation = ref(false)
 const authError = ref<string>()
+const isDragging = ref(false)
 let stateSyncCleanup: (() => void) | null = null
 
 // Check if we need password confirmation on first load
@@ -58,6 +59,21 @@ const handleAuthenticate = async (
   }
 }
 
+onMounted(() => {
+  // Listen for splitter drag events
+  document.addEventListener('mousedown', (e) => {
+    // Check if the mousedown is on a splitter gutter
+    const target = e.target as HTMLElement
+    if (target.closest('.p-splitter-gutter')) {
+      isDragging.value = true
+    }
+  })
+
+  document.addEventListener('mouseup', () => {
+    isDragging.value = false
+  })
+})
+
 onBeforeUnmount(() => {
   if (stateSyncCleanup) {
     stateSyncCleanup()
@@ -82,6 +98,8 @@ onBeforeUnmount(() => {
       <ChatPanel />
     </SplitterPanel>
   </Splitter>
+  <!-- Transparent overlay to prevent iframe from capturing mouse events during drag -->
+  <div v-if="isDragging" class="drag-overlay"></div>
   </div>
 </template>
 
@@ -90,19 +108,24 @@ onBeforeUnmount(() => {
   width: 100vw !important;
   height: 100vh !important;
   overflow: hidden;
-  background: var(--p-surface-0);
+  background: var(--p-surface-0) !important;
+}
+
+.app-container.dark {
+  background: var(--p-surface-0) !important;
 }
 
 .main-splitter {
   height: 100vh !important;
+  background: var(--p-surface-0) !important;
 }
 
 .main-splitter :deep(.p-splitter-gutter) {
-  background: var(--p-surface-100);
+  background: var(--p-surface-0) !important;
 }
 
 .main-splitter :deep(.p-splitter-gutter-handle) {
-  background: var(--p-primary-color);
+  background: var(--p-surface-400) !important;
 }
 
 .chart-panel,
@@ -119,4 +142,15 @@ onBeforeUnmount(() => {
   display: flex;
   flex-direction: column;
 }
+
+.drag-overlay {
+  position: fixed;
+  top: 0;
+  left: 0;
+  right: 0;
+  bottom: 0;
+  z-index: 9999;
+  cursor: col-resize;
+  background: transparent;
+}
 </style>

@@ -26,3 +26,16 @@ html, body, #app {
   overflow: hidden;
   background-color: var(--p-surface-0) !important;
 }
+
+.dark {
+  background-color: var(--p-surface-0) !important;
+  color: var(--p-surface-900) !important;
+}
+
+/* Ensure dark background for main containers */
+.app-container,
+.main-splitter,
+.p-splitter,
+.p-splitter-panel {
+  background-color: var(--p-surface-0) !important;
+}

@@ -191,9 +191,14 @@ onBeforeUnmount(() => {
   flex-direction: column;
   overflow: hidden;
   border: none;
+  border-radius: 0 !important;
   background: var(--p-surface-0);
 }
 
+.chart-card :deep(.p-card) {
+  border-radius: 0 !important;
+}
+
 .chart-card :deep(.p-card-body) {
   flex: 1;
   display: flex;

@@ -2,6 +2,7 @@
 import { ref, onMounted, onUnmounted, computed } from 'vue'
 import { register } from 'vue-advanced-chat'
 import Badge from 'primevue/badge'
+import Button from 'primevue/button'
 import { wsManager } from '../composables/useWebSocket'
 import type { WebSocketMessage } from '../composables/useWebSocket'
 
@@ -17,7 +18,7 @@ const messages = ref<any[]>([])
 const messagesLoaded = ref(false)
 const isConnected = wsManager.isConnected
 
-// Reactive rooms that update based on WebSocket connection
+// Reactive rooms that update based on WebSocket connection and agent processing state
 const rooms = computed(() => [{
   roomId: SESSION_ID,
   roomName: 'AI Agent',
@@ -26,23 +27,29 @@ const rooms = computed(() => [{
     { _id: CURRENT_USER_ID, username: 'You' },
     { _id: AGENT_ID, username: 'AI Agent', status: { state: isConnected.value ? 'online' : 'offline' } }
   ],
-  unreadCount: 0
+  unreadCount: 0,
+  typingUsers: isAgentProcessing.value ? [AGENT_ID] : []
 }])
 
 // Streaming state
 let currentStreamingMessageId: string | null = null
 let streamingBuffer = ''
+const isAgentProcessing = ref(false)
 
 // Generate message ID
 const generateMessageId = () => `msg-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`
 
 // Handle WebSocket messages
 const handleMessage = (data: WebSocketMessage) => {
+  console.log('[ChatPanel] Received message:', data)
 
   if (data.type === 'agent_chunk') {
+    console.log('[ChatPanel] Processing agent_chunk, content:', data.content, 'done:', data.done)
     const timestamp = new Date().toTimeString().split(' ')[0].slice(0, 5)
 
     if (!currentStreamingMessageId) {
+      console.log('[ChatPanel] Starting new streaming message')
       // Start new streaming message
+      isAgentProcessing.value = true
       currentStreamingMessageId = generateMessageId()
       streamingBuffer = data.content
@@ -54,7 +61,8 @@ const handleMessage = (data: WebSocketMessage) => {
         date: new Date().toLocaleDateString(),
         saved: false,
         distributed: false,
-        seen: false
+        seen: false,
+        files: []
       }]
     } else {
       // Update existing streaming message
@@ -62,10 +70,24 @@ const handleMessage = (data: WebSocketMessage) => {
       const msgIndex = messages.value.findIndex(m => m._id === currentStreamingMessageId)
       if (msgIndex !== -1) {
-        messages.value[msgIndex] = {
+        const updatedMessage: any = {
           ...messages.value[msgIndex],
           content: streamingBuffer
         }
+
+        // Add plot images if present in metadata
+        if (data.metadata && data.metadata.plot_urls && Array.isArray(data.metadata.plot_urls)) {
+          const plotFiles = data.metadata.plot_urls.map((url: string, idx: number) => ({
+            name: `plot_${idx + 1}.png`,
+            size: 0,
+            type: 'png',
+            url: `${BACKEND_URL}${url}`,
+            preview: `${BACKEND_URL}${url}`
+          }))
+          updatedMessage.files = plotFiles
+        }
+
+        messages.value[msgIndex] = updatedMessage
         messages.value = [...messages.value]
       }
     }
@@ -74,21 +96,49 @@ const handleMessage = (data: WebSocketMessage) => {
       // Mark message as complete
       const msgIndex = messages.value.findIndex(m => m._id === currentStreamingMessageId)
       if (msgIndex !== -1) {
-        messages.value[msgIndex] = {
+        const finalMessage: any = {
           ...messages.value[msgIndex],
           saved: true,
           distributed: true,
           seen: true
         }
+
+        // Ensure plot images are included in final message
+        if (data.metadata && data.metadata.plot_urls && Array.isArray(data.metadata.plot_urls)) {
+          const plotFiles = data.metadata.plot_urls.map((url: string, idx: number) => ({
+            name: `plot_${idx + 1}.png`,
+            size: 0,
+            type: 'png',
+            url: `${BACKEND_URL}${url}`,
+            preview: `${BACKEND_URL}${url}`
+          }))
+          finalMessage.files = plotFiles
+        }
+
+        messages.value[msgIndex] = finalMessage
         messages.value = [...messages.value]
       }
 
       currentStreamingMessageId = null
       streamingBuffer = ''
+      isAgentProcessing.value = false
     }
   }
 }
 
+// Stop agent processing
+const stopAgent = () => {
+  // Send empty message to trigger interrupt without new agent round
+  const wsMessage = {
+    type: 'agent_user_message',
+    session_id: SESSION_ID,
+    content: '',
+    attachments: []
+  }
+  wsManager.send(wsMessage)
+  isAgentProcessing.value = false
+}
+
 // Send message handler
 const sendMessage = async (event: any) => {
   // Extract data from CustomEvent.detail[0]
@@ -191,39 +241,39 @@ const openFile = ({ file }: any) => {
 }
 
 // Theme configuration for dark mode
-const chatTheme = 'light'
+const chatTheme = 'dark'
 
 // Styles to match PrimeVue theme
 const chatStyles = computed(() => JSON.stringify({
   general: {
-    color: 'var(--p-surface-900)',
-    colorSpinner: 'var(--p-primary-color)',
-    borderStyle: '1px solid var(--p-surface-200)'
+    color: '#cdd6e8',
+    colorSpinner: '#00d4aa',
+    borderStyle: '1px solid #263452'
   },
   container: {
-    background: 'var(--p-surface-0)'
+    background: '#0a0e1a'
   },
   header: {
-    background: 'var(--p-surface-50)',
-    colorRoomName: 'var(--p-surface-900)',
-    colorRoomInfo: 'var(--p-surface-700)'
+    background: '#0f1629',
+    colorRoomName: '#cdd6e8',
+    colorRoomInfo: '#8892a4'
   },
   footer: {
-    background: 'var(--p-surface-50)',
-    borderStyleInput: '1px solid var(--p-surface-300)',
-    backgroundInput: 'var(--p-surface-200)',
-    colorInput: 'var(--p-surface-900)',
-    colorPlaceholder: 'var(--p-surface-400)',
-    colorIcons: 'var(--p-surface-400)'
+    background: '#0f1629',
+    borderStyleInput: '1px solid #263452',
+    backgroundInput: '#161e35',
+    colorInput: '#cdd6e8',
+    colorPlaceholder: '#8892a4',
+    colorIcons: '#8892a4'
   },
   content: {
-    background: 'var(--p-surface-0)'
+    background: '#0a0e1a'
   },
   message: {
-    background: 'var(--p-surface-100)',
-    backgroundMe: 'var(--p-primary-color)',
-    color: 'var(--p-surface-900)',
-    colorMe: 'var(--p-primary-contrast-color)'
+    background: '#161e35',
+    backgroundMe: '#00d4aa',
+    color: '#cdd6e8',
+    colorMe: '#0a0e1a'
   }
 }))
@@ -231,6 +281,14 @@ onMounted(() => {
   wsManager.addHandler(handleMessage)
   // Mark messages as loaded after initialization
   messagesLoaded.value = true
+
+  // Focus on the chat input when component mounts
+  setTimeout(() => {
+    const chatInput = document.querySelector('.vac-textarea') as HTMLTextAreaElement
+    if (chatInput) {
+      chatInput.focus()
+    }
+  }, 300)
 })
 
 onUnmounted(() => {
@@ -251,7 +309,7 @@ onUnmounted(() => {
 
   -->
   <vue-advanced-chat
-    height="calc(100vh - 60px)"
+    height="100vh"
     :current-user-id="CURRENT_USER_ID"
     :rooms="JSON.stringify(rooms)"
     :messages="JSON.stringify(messages)"
@@ -267,10 +325,22 @@ onUnmounted(() => {
     :show-emojis="true"
     :show-reaction-emojis="false"
     :accepted-files="'image/*,video/*,application/pdf'"
+    :message-images="true"
     @send-message="sendMessage"
     @fetch-messages="fetchMessages"
     @open-file="openFile"
   />
+
+  <!-- Stop button overlay -->
+  <div v-if="isAgentProcessing" class="stop-button-container">
+    <Button
+      icon="pi pi-stop-circle"
+      label="Stop"
+      severity="danger"
+      @click="stopAgent"
+      class="stop-button"
+    />
+  </div>
   </div>
 </template>
@@ -279,8 +349,9 @@ onUnmounted(() => {
   height: 100% !important;
   display: flex;
   flex-direction: column;
-  background: var(--p-surface-0);
+  background: var(--p-surface-0) !important;
   overflow: hidden;
+  position: relative;
 }
 
 .chat-container :deep(.vac-container) {
@@ -306,4 +377,25 @@ onUnmounted(() => {
   font-weight: 600;
   color: var(--p-surface-900);
 }
+
+.stop-button-container {
+  position: absolute;
+  bottom: 80px;
+  right: 20px;
+  z-index: 1000;
+}
+
+.stop-button {
+  box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
+  animation: pulse 2s infinite;
+}
+
+@keyframes pulse {
+  0%, 100% {
+    opacity: 1;
+  }
+  50% {
+    opacity: 0.8;
+  }
+}
 </style>
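The plot handling above expects `plot_urls` inside the `metadata` of an `agent_chunk` message. For orientation, a backend might shape that payload as below — a hypothetical sketch whose field names mirror the frontend code; the real server serializer is defined elsewhere and may differ:

```python
import json

# Hypothetical sketch of the agent_chunk WebSocket payload consumed by the
# ChatPanel handler above; field names mirror the frontend, but the actual
# backend serializer may differ.

def make_agent_chunk(content: str, done: bool, plot_urls=None) -> str:
    """Serialize one streaming chunk as a WebSocket JSON payload."""
    message = {
        'type': 'agent_chunk',
        'content': content,
        'done': done,
        'metadata': {},
    }
    if plot_urls:
        # Relative URLs; the frontend prefixes each with BACKEND_URL.
        message['metadata']['plot_urls'] = list(plot_urls)
    return json.dumps(message)


payload = make_agent_chunk('Here is the chart.', True, ['/plots/abc.png'])
```

Sending plot URLs in metadata rather than inlining image bytes keeps chunks small and lets the browser fetch and cache the PNGs over plain HTTP.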

@@ -1,5 +1,5 @@
 <script setup lang="ts">
-import { ref, computed } from 'vue'
+import { ref, computed, onMounted } from 'vue'
 import Card from 'primevue/card'
 import InputText from 'primevue/inputtext'
 import Password from 'primevue/password'
@@ -66,6 +66,14 @@ const togglePasswordChange = () => {
   newPassword.value = ''
   confirmNewPassword.value = ''
 }
+
+onMounted(() => {
+  // Focus on the password input when component mounts
+  const passwordInput = document.querySelector('#password input') as HTMLInputElement
+  if (passwordInput) {
+    passwordInput.focus()
+  }
+})
 </script>
 
 <template>
@@ -81,13 +89,13 @@ const togglePasswordChange = () => {
       <template #content>
         <div class="login-content">
           <p v-if="needsConfirmation" class="welcome-message">
-            This is your first time connecting. Please create a master password to secure your workspace.
+            This is your first time connecting. Please create a password to secure your workspace.
           </p>
           <p v-else-if="isChangingPassword" class="welcome-message">
             Enter your current password and choose a new one.
           </p>
           <p v-else class="welcome-message">
-            Enter your master password to connect.
+            Enter your password to connect.
           </p>
 
           <Message v-if="errorMessage" severity="error" :closable="false">