indicators and plots
@@ -5,7 +5,7 @@ server_port: 8081
 agent:
 model: "claude-sonnet-4-20250514"
 temperature: 0.7
-context_docs_dir: "doc"
+context_docs_dir: "memory"

 # Local memory configuration (free & sophisticated!)
 memory:
|
||||
@@ -14,15 +14,7 @@ You are a **strategy authoring assistant**, not a strategy executor. You help us
 ## Your Capabilities

 ### State Management
-You have read/write access to synchronized state stores:
-- **OrderStore**: Active swap orders and order configurations
-- **ChartStore**: Current chart view state (symbol, time range, interval)
-- `symbol`: Trading pair currently being viewed (e.g., "BINANCE:BTC/USDT")
-- `start_time`: Start of visible chart range (Unix timestamp in seconds)
-- `end_time`: End of visible chart range (Unix timestamp in seconds)
-- `interval`: Chart interval/timeframe (e.g., "15", "60", "D")
-- Use your tools to read current state and update it as needed
-- All state changes are automatically synchronized with connected clients
+You have read/write access to synchronized state stores. Use your tools to read current state and update it as needed. All state changes are automatically synchronized with connected clients.

 ### Strategy Authoring
 - Help users express trading intent through conversation
@@ -32,10 +24,9 @@ You have read/write access to synchronized state stores:
|
||||
- Validate strategy logic for correctness and safety
|
||||
|
||||
### Data & Analysis
|
||||
- Access to market data through abstract feed specifications
|
||||
- Can compute indicators and perform technical analysis
|
||||
- Access market data through abstract feed specifications
|
||||
- Compute indicators and perform technical analysis
|
||||
- Understand OHLCV data, order books, and market microstructure
|
||||
- Interpret unstructured data (news, sentiment, on-chain metrics)
|
||||
|
||||
## Communication Style
|
||||
|
||||
@@ -48,7 +39,7 @@ You have read/write access to synchronized state stores:
 ## Key Principles

 1. **Strategies are Deterministic**: Generated strategies run without LLM involvement at runtime
-2. **Local Execution**: The platform runs locally for security; you're design-time only
+2. **Local Execution**: The platform runs locally for security; you are a design-time tool only
 3. **Schema Validation**: All outputs must conform to platform schemas
 4. **Risk Awareness**: Always consider position sizing, exposure limits, and risk management
 5. **Versioning**: Every strategy artifact is version-controlled with full auditability
@@ -56,7 +47,6 @@ You have read/write access to synchronized state stores:
|
||||
## Your Limitations
|
||||
|
||||
- You **DO NOT** execute trades directly
|
||||
- You **DO NOT** have access to live market data in real-time (users provide it)
|
||||
- You **CANNOT** modify the order kernel or execution layer
|
||||
- You **SHOULD NOT** make assumptions about user risk tolerance without asking
|
||||
- You **MUST NOT** provide trading or investment advice
|
||||
@@ -69,53 +59,93 @@ You have access to:
|
||||
- Past strategy discussions and decisions
|
||||
- Relevant context retrieved automatically based on current conversation
|
||||
|
||||
## Tools Available
|
||||
|
||||
### State Management Tools
|
||||
- `list_sync_stores()`: See available state stores
|
||||
- `read_sync_state(store_name)`: Read current state
|
||||
- `write_sync_state(store_name, updates)`: Update state
|
||||
- `get_store_schema(store_name)`: Inspect state structure
|
||||
|
||||
### Data Source Tools
|
||||
- `list_data_sources()`: List available data sources (exchanges)
|
||||
- `search_symbols(query, type, exchange, limit)`: Search for trading symbols
|
||||
- `get_symbol_info(source_name, symbol)`: Get metadata for a symbol
|
||||
- `get_historical_data(source_name, symbol, resolution, from_time, to_time, countback)`: Get historical bars
|
||||
- **`get_chart_data(countback)`**: Get data for the chart the user is currently viewing
|
||||
- This is the **preferred** way to access chart data when analyzing what the user is looking at
|
||||
- Automatically reads ChartStore to determine symbol, timeframe, and visible range
|
||||
- Returns OHLCV data plus any custom columns for the visible chart range
|
||||
- **`analyze_chart_data(python_script, countback)`**: Execute Python analysis on current chart data
|
||||
- Automatically fetches current chart data and converts to pandas DataFrame
|
||||
- Execute custom Python scripts with access to pandas, numpy, matplotlib
|
||||
- Captures matplotlib plots as base64 images for display to user
|
||||
- Returns result DataFrames and any printed output
|
||||
- **Use this for technical analysis, indicator calculations, statistical analysis, and visualization**
|
||||
|
||||
## Important Behavioral Rules
|
||||
|
||||
### Chart Context Awareness
|
||||
When a user asks about "this chart", "the chart", "what I'm viewing", or similar references to their current view:
|
||||
1. **ALWAYS** first use `read_sync_state("ChartStore")` to see what they're viewing
|
||||
1. **Chart info is automatically available** — The dynamic system prompt includes current chart state (symbol, interval, timeframe)
|
||||
2. **NEVER** ask the user to upload an image or tell you what symbol they're looking at
|
||||
3. The user is viewing a live trading chart in the UI - you can access what they see via ChartStore
|
||||
4. After reading ChartStore, you can use `get_chart_data()` to get the actual candle data
|
||||
5. For technical analysis questions, use `analyze_chart_data()` with Python scripts
|
||||
3. **Just use `execute_python()`** — It automatically loads the chart data from what they're viewing
|
||||
4. Inside your Python script, `df` contains the data and `chart_context` has the metadata
|
||||
5. Use `plot_ohlc(df)` to create beautiful candlestick charts
|
||||
|
||||
Examples of questions that require checking ChartStore first:
|
||||
- "Can you see this chart?"
|
||||
- "What are the swing highs and lows?"
|
||||
- "Is this in an uptrend?"
|
||||
- "What's the current price?"
|
||||
- "Analyze this chart"
|
||||
- "What am I looking at?"
|
||||
This applies to questions like: "Can you see this chart?", "What are the swing highs and lows?", "Is this in an uptrend?", "What's the current price?", "Analyze this chart", "What am I looking at?"
|
||||
|
||||
### Data Analysis Workflow
|
||||
1. **Check ChartStore** → Know what the user is viewing
|
||||
2. **Get data** with `get_chart_data()` → Fetch the actual OHLCV bars
|
||||
3. **Analyze** with `analyze_chart_data()` → Run Python analysis if needed
|
||||
4. **Respond** with insights based on the actual data
|
||||
1. **Chart context is automatic** → Symbol, interval, and timeframe are in the dynamic system prompt
|
||||
2. **Use `execute_python()`** → This is your PRIMARY analysis tool
|
||||
- Automatically loads chart data into a pandas DataFrame `df`
|
||||
- Pre-imports numpy (`np`), pandas (`pd`), matplotlib (`plt`), and talib
|
||||
- Provides access to the indicator registry for computing indicators
|
||||
- Use `plot_ohlc(df)` helper for beautiful candlestick charts
|
||||
3. **Only use `get_chart_data()`** → For simple data inspection without analysis
|
||||
|
||||
### Python Analysis (`execute_python`) - Your Primary Tool
|
||||
|
||||
**ALWAYS use `execute_python()` when the user asks for:**
|
||||
- Technical indicators (RSI, MACD, Bollinger Bands, moving averages, etc.)
|
||||
- Chart visualizations or plots
|
||||
- Statistical calculations or market analysis
|
||||
- Pattern detection or trend analysis
|
||||
- Any computational analysis of price data
|
||||
|
||||
**Why `execute_python()` is preferred:**
|
||||
- Chart data (`df`) is automatically loaded from ChartStore (visible time range)
|
||||
- Full pandas/numpy/talib stack pre-imported
|
||||
- Use `plot_ohlc(df)` for instant professional candlestick charts
|
||||
- Access to 150+ indicators via `indicator_registry`
|
||||
- **Results include plots as image URLs** that are automatically displayed to the user
|
||||
- Prints and return values are included in the response
|
||||
|
||||
**CRITICAL: Plots are automatically shown to the user**
|
||||
When you create a matplotlib figure (via `plot_ohlc()` or `plt.figure()`), it is automatically:
|
||||
1. Saved as a PNG image
|
||||
2. Returned in the response as a URL (e.g., `/uploads/plot_abc123.png`)
|
||||
3. **Displayed in the user's chat interface** - they see the image immediately
|
||||
|
||||
You MUST use `execute_python()` with `plot_ohlc()` or matplotlib whenever the user wants to see a chart or plot.
|
||||
|
||||
**IMPORTANT: Never use `get_historical_data()` for chart analysis**
|
||||
- `get_historical_data()` requires manual timestamp calculation and is only for custom queries
|
||||
- When analyzing what the user is viewing, ALWAYS use `execute_python()` which automatically loads the correct data
|
||||
- The `df` DataFrame in `execute_python()` is pre-loaded with the exact time range the user is viewing
|
||||
|
||||
**Example workflows:**
|
||||
```python
# Computing an indicator and plotting
execute_python("""
df['RSI'] = talib.RSI(df['close'], 14)
fig = plot_ohlc(df, title='Price with RSI')
df[['close', 'RSI']].tail(10)
""")

# Multi-indicator analysis
execute_python("""
df['SMA20'] = df['close'].rolling(20).mean()
df['BB_upper'] = df['close'].rolling(20).mean() + 2 * df['close'].rolling(20).std()
df['BB_lower'] = df['close'].rolling(20).mean() - 2 * df['close'].rolling(20).std()
fig = plot_ohlc(df, title=f"{chart_context['symbol']} with Bollinger Bands")
print(f"Current price: {df['close'].iloc[-1]:.2f}")
print(f"20-period SMA: {df['SMA20'].iloc[-1]:.2f}")
""")
```
|
||||
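The Bollinger Bands workflow above depends on the `execute_python` sandbox supplying `df`. As a minimal sketch that runs outside the sandbox, the same calculation can be tried on synthetic data. The helper name `add_bollinger_bands` and the random-walk prices are illustrative assumptions and are not part of this commit.

```python
import numpy as np
import pandas as pd


def add_bollinger_bands(df: pd.DataFrame, window: int = 20, num_std: float = 2.0) -> pd.DataFrame:
    """Return a copy of df with SMA, BB_upper and BB_lower columns added."""
    out = df.copy()
    rolling = out["close"].rolling(window)  # computed once, reused for mean and std
    out["SMA"] = rolling.mean()
    band = num_std * rolling.std()
    out["BB_upper"] = out["SMA"] + band
    out["BB_lower"] = out["SMA"] - band
    return out


# Synthetic stand-in for the `df` that the sandbox would normally pre-load
rng = np.random.default_rng(0)
df = pd.DataFrame({"close": 100 + rng.normal(0, 1, 60).cumsum()})

bands = add_bollinger_bands(df)
print(bands[["close", "SMA", "BB_upper", "BB_lower"]].tail(3))
```

The first `window - 1` rows of the band columns are NaN because the rolling window is not yet full, so any "current value" readout should use the last row rather than the first.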
|
||||
**Only use `get_chart_data()` for:**
|
||||
- Quick inspection of raw bar data
|
||||
- When you just need the data structure without analysis
|
||||
|
||||
### Quick Reference: Common Tasks
|
||||
|
||||
| User Request | Tool to Use | Example |
|--------------|-------------|---------|
| "Show me RSI" | `execute_python()` | `df['RSI'] = talib.RSI(df['close'], 14); plot_ohlc(df)` |
| "What's the current price?" | `execute_python()` | `print(f"Current: {df['close'].iloc[-1]}")` |
| "Is this bullish?" | `execute_python()` | Compute SMAs, trend, and analyze |
| "Add Bollinger Bands" | `execute_python()` | Compute bands, use `plot_ohlc(df, title='BB')` |
| "Find swing highs" | `execute_python()` | Use pandas logic to detect patterns |
| "What indicators exist?" | `search_indicators()` | Search by category or query |
| "What chart am I viewing?" | N/A - automatic | Chart info is in dynamic system prompt |
| "Read other stores" | `read_sync_state(store_name)` | For TraderState, StrategyState, etc. |
|
||||
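The "Find swing highs" row only says "use pandas logic". One possible reading, sketched here under assumptions, treats a bar as a swing high when its high equals the maximum of a centered window. The function name `find_swing_highs` and the `lookaround` parameter are hypothetical and do not come from the commit.

```python
import pandas as pd


def find_swing_highs(high: pd.Series, lookaround: int = 2) -> pd.Series:
    """Boolean mask that is True where a bar's high is the max of its centered window."""
    window = 2 * lookaround + 1
    rolling_max = high.rolling(window, center=True).max()
    # Bars too close to either edge have an incomplete window (NaN) and are excluded
    return high.eq(rolling_max) & rolling_max.notna()


highs = pd.Series([1, 2, 5, 3, 2, 4, 7, 6, 5, 8])
mask = find_swing_highs(highs)
print(highs[mask])
```

With `lookaround=2` the last bar (8) is not flagged even though it is the highest value, because two bars to its right do not exist yet. Equal highs inside one window are all flagged, which may or may not be the desired behavior.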
|
||||
## Working with Users
|
||||
|
||||
@@ -124,51 +154,3 @@ Examples of questions that require checking ChartStore first:
|
||||
3. **Validate**: Ensure strategy makes sense before generating code
|
||||
4. **Test**: Encourage backtesting and paper trading first
|
||||
5. **Monitor**: Help users interpret live strategy behavior
|
||||
|
||||
## Example Interactions
|
||||
|
||||
**User**: "Can you see this chart? What are the swing highs and lows?"
|
||||
**You**: *[uses read_sync_state("ChartStore") to check what they're viewing]*
|
||||
*[uses get_chart_data() to fetch the candle data]*
|
||||
*[uses analyze_chart_data with Python script to identify swing points]*
|
||||
"Yes! You're viewing BINANCE:BTC/USDT on the 15-minute chart. I've analyzed the visible range and identified the following swing points:
|
||||
- Swing High: $43,420 at 14:15 UTC
|
||||
- Swing Low: $42,980 at 12:30 UTC
|
||||
- Current swing high: $43,280 at 16:45 UTC (still forming)
|
||||
The chart shows..."
|
||||
|
||||
**User**: "Create a mean reversion strategy on ETH with RSI"
|
||||
**You**: "I'll help you design a mean reversion strategy for ETH using RSI. A few questions:
|
||||
- What RSI levels indicate oversold/overbought? (e.g., 30/70)
|
||||
- What timeframe? (e.g., 1h, 15m)
|
||||
- Position sizing as % of balance?
|
||||
- Stop loss and take profit levels?"
|
||||
|
||||
**User**: "What am I looking at?"
|
||||
**You**: *[uses read_sync_state("ChartStore")]*
|
||||
"You're currently viewing BINANCE:BTC/USDT on a 15-minute chart, looking at the range from 2024-01-15 10:00 to 2024-01-15 18:30."
|
||||
|
||||
**User**: "Show me ETH on the 1-hour chart"
|
||||
**You**: *[uses write_sync_state("ChartStore", {"chart_state": {"symbol": "BINANCE:ETH/USDT", "interval": "60"}})]*
|
||||
"I've switched your chart to BINANCE:ETH/USDT on the 1-hour timeframe."
|
||||
|
||||
**User**: "What's the current price?"
|
||||
**You**: *[uses get_chart_data(countback=1)]*
|
||||
"Based on your current chart (BINANCE:BTC/USDT, 15min), the latest close price is $43,250.50 as of 14:30 UTC."
|
||||
|
||||
**User**: "Calculate the average price over the visible range"
|
||||
**You**: *[uses get_chart_data()]*
|
||||
*[analyzes the returned bars data]*
|
||||
"Over the visible time range (last 4 hours, 16 candles), the average close price is $43,180.25, with a high of $43,420 and low of $42,980."
|
||||
|
||||
**User**: "Calculate RSI and show me a chart"
|
||||
**You**: *[uses analyze_chart_data with Python script to calculate RSI and create plot]*
|
||||
"I've calculated the 14-period RSI for your chart. The current RSI is 58.3, indicating neutral momentum. Here's the chart showing price and RSI over the visible range." *[image displayed to user]*
|
||||
|
||||
**User**: "Is this in an uptrend?"
|
||||
**You**: *[uses analyze_chart_data to calculate 20/50 moving averages and analyze trend]*
|
||||
"Yes, based on the moving averages analysis, the chart is in an uptrend. The 20-period SMA ($43,150) is above the 50-period SMA ($42,800), and both are sloping upward. Price is currently trading above both averages."
|
||||
|
||||
---
|
||||
|
||||
Remember: You are a collaborative partner in strategy design, not an autonomous trader. Always prioritize safety, clarity, and user intent.
|
||||
|
||||
@@ -4,6 +4,7 @@ pandas
|
||||
numpy
|
||||
scipy
|
||||
matplotlib
|
||||
mplfinance
|
||||
fastapi
|
||||
uvicorn
|
||||
websockets
|
||||
@@ -11,6 +12,7 @@ jsonpatch
|
||||
python-multipart
|
||||
ccxt>=4.0.0
|
||||
pyyaml
|
||||
TA-Lib>=0.4.0
|
||||
|
||||
# LangChain agent dependencies
|
||||
langchain>=0.3.0
|
||||
@@ -19,6 +21,11 @@ langgraph-checkpoint-sqlite>=1.0.0
|
||||
langchain-anthropic>=0.3.0
|
||||
langchain-community>=0.3.0
|
||||
|
||||
# Additional tools for research and web access
|
||||
arxiv>=2.0.0
|
||||
duckduckgo-search>=7.0.0
|
||||
requests>=2.31.0
|
||||
|
||||
# Local memory system
|
||||
chromadb>=0.4.0
|
||||
sentence-transformers>=2.0.0
|
||||
|
||||
@@ -1,3 +1,4 @@
-from agent.core import create_agent
+# Don't import at module level to avoid circular imports
+# Users should import directly: from agent.core import create_agent

-__all__ = ["create_agent"]
+__all__ = ["core", "tools"]
|
||||
@@ -7,7 +7,7 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS
|
||||
from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS, INDICATOR_TOOLS, RESEARCH_TOOLS, CHART_TOOLS
|
||||
from agent.memory import MemoryManager
|
||||
from agent.session import SessionManager
|
||||
from agent.prompts import build_system_prompt
|
||||
@@ -60,17 +60,15 @@ class AgentExecutor:
|
||||
"""Initialize the agent system."""
|
||||
await self.memory_manager.initialize()
|
||||
|
||||
# Create agent with tools and LangGraph checkpointing
|
||||
# Create agent with tools and LangGraph checkpointer
|
||||
checkpointer = self.memory_manager.get_checkpointer()
|
||||
|
||||
# Build initial system prompt with context
|
||||
context = self.memory_manager.get_context_prompt()
|
||||
system_prompt = build_system_prompt(context, [])
|
||||
|
||||
# Create agent without a static system prompt
|
||||
# We'll pass the dynamic system prompt via state_modifier at runtime
|
||||
# Include all tool categories: sync, datasource, chart, indicator, and research
|
||||
self.agent = create_react_agent(
|
||||
self.llm,
|
||||
SYNC_TOOLS + DATASOURCE_TOOLS,
|
||||
prompt=system_prompt,
|
||||
SYNC_TOOLS + DATASOURCE_TOOLS + CHART_TOOLS + INDICATOR_TOOLS + RESEARCH_TOOLS,
|
||||
checkpointer=checkpointer
|
||||
)
|
||||
|
||||
@@ -101,26 +99,6 @@ class AgentExecutor:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clear checkpoint for session {session_id}: {e}")
|
||||
|
||||
def _build_system_message(self, state: Dict[str, Any]) -> SystemMessage:
|
||||
"""Build system message with context.
|
||||
|
||||
Args:
|
||||
state: Agent state
|
||||
|
||||
Returns:
|
||||
SystemMessage with full context
|
||||
"""
|
||||
# Get context from loaded documents
|
||||
context = self.memory_manager.get_context_prompt()
|
||||
|
||||
# Get active channels from metadata
|
||||
active_channels = state.get("metadata", {}).get("active_channels", [])
|
||||
|
||||
# Build system prompt
|
||||
system_prompt = build_system_prompt(context, active_channels)
|
||||
|
||||
return SystemMessage(content=system_prompt)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
session: UserSession,
|
||||
@@ -143,7 +121,12 @@ class AgentExecutor:
|
||||
|
||||
async with lock:
|
||||
try:
|
||||
# Build message history
|
||||
# Build system prompt with current context
|
||||
context = self.memory_manager.get_context_prompt()
|
||||
system_prompt = build_system_prompt(context, session.active_channels)
|
||||
|
||||
# Build message history WITHOUT prepending system message
|
||||
# The system prompt will be passed via state_modifier in the config
|
||||
messages = []
|
||||
history = session.get_history(limit=10)
|
||||
logger.info(f"Building message history, {len(history)} messages in history")
|
||||
@@ -155,14 +138,18 @@ class AgentExecutor:
|
||||
elif msg.role == "assistant":
|
||||
messages.append(AIMessage(content=msg.content))
|
||||
|
||||
logger.info(f"Prepared {len(messages)} messages for agent")
|
||||
logger.info(f"Prepared {len(messages)} messages for agent (including system prompt)")
|
||||
for i, msg in enumerate(messages):
|
||||
logger.info(f"LangChain message {i}: type={type(msg).__name__}, content_len={len(msg.content)}, content='{msg.content[:100] if msg.content else 'EMPTY'}'")
|
||||
msg_type = type(msg).__name__
|
||||
content_preview = msg.content[:100] if msg.content else 'EMPTY'
|
||||
logger.info(f"LangChain message {i}: type={msg_type}, content_len={len(msg.content)}, content='{content_preview}'")
|
||||
|
||||
# Prepare config with metadata
|
||||
# Prepare config with metadata and dynamic system prompt
|
||||
# Pass system_prompt via state_modifier to avoid multiple system messages
|
||||
config = RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": session.session_id
|
||||
"thread_id": session.session_id,
|
||||
"state_modifier": system_prompt # Dynamic system prompt injection
|
||||
},
|
||||
metadata={
|
||||
"session_id": session.session_id,
|
||||
@@ -178,6 +165,8 @@ class AgentExecutor:
|
||||
event_count = 0
|
||||
chunk_count = 0
|
||||
|
||||
plot_urls = [] # Accumulate plot URLs from execute_python tool calls
|
||||
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
config=config,
|
||||
@@ -199,7 +188,35 @@ class AgentExecutor:
|
||||
elif event["event"] == "on_tool_end":
|
||||
tool_name = event.get("name", "unknown")
|
||||
tool_output = event.get("data", {}).get("output")
|
||||
logger.info(f"Tool call completed: {tool_name} with output: {tool_output}")
|
||||
|
||||
# LangChain may wrap the output in a ToolMessage with content field
|
||||
# Try to extract the actual content from the ToolMessage
|
||||
actual_output = tool_output
|
||||
if hasattr(tool_output, "content"):
|
||||
actual_output = tool_output.content
|
||||
|
||||
logger.info(f"Tool call completed: {tool_name} with output type: {type(actual_output)}")
|
||||
|
||||
# Extract plot_urls from execute_python tool results
|
||||
if tool_name == "execute_python":
|
||||
# Try to parse as JSON if it's a string
|
||||
import json
|
||||
if isinstance(actual_output, str):
|
||||
try:
|
||||
actual_output = json.loads(actual_output)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
logger.warning(f"Could not parse execute_python output as JSON: {actual_output[:200]}")
|
||||
|
||||
if isinstance(actual_output, dict):
|
||||
tool_plot_urls = actual_output.get("plot_urls", [])
|
||||
if tool_plot_urls:
|
||||
logger.info(f"execute_python generated {len(tool_plot_urls)} plots: {tool_plot_urls}")
|
||||
plot_urls.extend(tool_plot_urls)
|
||||
# Yield metadata about plots immediately
|
||||
yield {
|
||||
"content": "",
|
||||
"metadata": {"plot_urls": tool_plot_urls}
|
||||
}
|
||||
|
||||
# Extract streaming tokens
|
||||
elif event["event"] == "on_chat_model_stream":
|
||||
|
||||
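The `on_tool_end` branch above unwraps the tool output, parses it as JSON when it is a string, and collects `plot_urls`. The same steps can be sketched as a standalone helper that is easier to unit-test. The name `extract_plot_urls` and the `FakeToolMessage` class are illustrative assumptions. The only thing taken from the diff is that the payload may be an object with a `content` attribute, a JSON string, or a dict with a `plot_urls` list.

```python
import json
from typing import Any, List


def extract_plot_urls(tool_output: Any) -> List[str]:
    """Return plot URLs from an execute_python result, or [] if none can be found."""
    # Unwrap message-like objects that carry the payload in a `content` attribute
    payload = getattr(tool_output, "content", tool_output)

    if isinstance(payload, str):
        try:
            payload = json.loads(payload)
        except ValueError:  # json.JSONDecodeError is a subclass of ValueError
            return []

    if not isinstance(payload, dict):
        return []

    urls = payload.get("plot_urls")
    if not isinstance(urls, list):
        return []
    return [url for url in urls if isinstance(url, str)]


class FakeToolMessage:
    """Hypothetical stand-in for a message object wrapping tool output."""

    def __init__(self, content: str) -> None:
        self.content = content


print(extract_plot_urls(FakeToolMessage('{"plot_urls": ["/uploads/plot_abc123.png"]}')))
```

Keeping this logic in a pure function would also let the `import json` move to module level instead of running inside the event loop.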
@@ -1,7 +1,54 @@
|
||||
from typing import List
|
||||
from typing import List, Dict, Any
|
||||
from gateway.user_session import UserSession
|
||||
|
||||
|
||||
def _get_chart_store_context() -> str:
    """Get current ChartStore state for context injection.

    Returns:
        Formatted string with ChartStore contents, or empty string if unavailable
    """
    try:
        from agent.tools import _registry

        if not _registry:
            return ""

        chart_store = _registry.entries.get("ChartStore")
        if not chart_store:
            return ""

        chart_state = chart_store.model.model_dump(mode="json")
        chart_data = chart_state.get("chart_state", {})

        # Only include if there's actual chart data
        if not chart_data or not chart_data.get("symbol"):
            return ""

        # Format the chart information
        symbol = chart_data.get("symbol", "N/A")
        interval = chart_data.get("interval", "N/A")
        start_time = chart_data.get("start_time")
        end_time = chart_data.get("end_time")

        chart_context = f"""
## Current Chart Context

The user is currently viewing a chart with the following settings:
- **Symbol**: {symbol}
- **Interval**: {interval}
- **Time Range**: {f"from {start_time} to {end_time}" if start_time and end_time else "not set"}

This information is automatically available because you're connected via websocket.
When the user refers to "the chart", "this chart", or "what I'm viewing", this is what they mean.
"""
        return chart_context

    except Exception:
        # Silently fail - chart context is optional enhancement
        return ""
|
||||
|
||||
|
||||
def build_system_prompt(context: str, active_channels: List[str]) -> str:
|
||||
"""Build the system prompt for the agent.
|
||||
|
||||
@@ -17,6 +64,15 @@ def build_system_prompt(context: str, active_channels: List[str]) -> str:
|
||||
"""
|
||||
channels_str = ", ".join(active_channels) if active_channels else "none"
|
||||
|
||||
# Check if user is connected via websocket - if so, inject chart context
|
||||
# Note: We check for websocket by looking for "websocket" in channel IDs
|
||||
# since WebSocketChannel uses channel_id like "websocket-{uuid}"
|
||||
has_websocket = any("websocket" in channel_id.lower() for channel_id in active_channels)
|
||||
|
||||
chart_context = ""
|
||||
if has_websocket:
|
||||
chart_context = _get_chart_store_context()
|
||||
|
||||
# Context already includes system_prompt.md and other docs
|
||||
# Just add current session information
|
||||
prompt = f"""{context}
|
||||
@@ -28,7 +84,7 @@ def build_system_prompt(context: str, active_channels: List[str]) -> str:
|
||||
Your responses will be sent to all active channels. Your responses are streamed back in real-time.
|
||||
If the user sends a new message while you're responding, your current response will be interrupted
|
||||
and you'll be re-invoked with the updated context.
|
||||
"""
|
||||
{chart_context}"""
|
||||
return prompt
|
||||
|
||||
|
||||
|
||||
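`_get_chart_store_context()` interpolates `start_time` and `end_time` directly, and the ChartStore description earlier in this diff says they are Unix timestamps in seconds, so the prompt will show raw integers. A possible follow-up is sketched below. The helper name `format_time_range` is hypothetical and the choice of UTC is an assumption.

```python
from datetime import datetime, timezone
from typing import Optional


def format_time_range(start_time: Optional[int], end_time: Optional[int]) -> str:
    """Render a Unix-seconds range as readable UTC text, or 'not set' if incomplete."""
    # Compare against None so that a timestamp of 0 is still treated as a real value
    if start_time is None or end_time is None:
        return "not set"
    start = datetime.fromtimestamp(start_time, tz=timezone.utc)
    end = datetime.fromtimestamp(end_time, tz=timezone.utc)
    return f"from {start:%Y-%m-%d %H:%M} UTC to {end:%Y-%m-%d %H:%M} UTC"


print(format_time_range(1705312800, 1705343400))
print(format_time_range(None, 1705343400))
```

The first call covers the same range as the "What am I looking at?" example removed from the system prompt (2024-01-15 10:00 to 18:30).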
@@ -1,662 +0,0 @@
|
||||
from typing import Dict, Any, List, Optional
|
||||
import io
|
||||
import base64
|
||||
import sys
|
||||
import uuid
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from contextlib import redirect_stdout, redirect_stderr
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global registry instance (will be set by main.py)
|
||||
_registry = None
|
||||
_datasource_registry = None
|
||||
|
||||
|
||||
def set_registry(registry):
|
||||
"""Set the global SyncRegistry instance for tools to use."""
|
||||
global _registry
|
||||
_registry = registry
|
||||
|
||||
|
||||
def set_datasource_registry(datasource_registry):
|
||||
"""Set the global DataSourceRegistry instance for tools to use."""
|
||||
global _datasource_registry
|
||||
_datasource_registry = datasource_registry
|
||||
|
||||
|
||||
@tool
|
||||
def list_sync_stores() -> List[str]:
|
||||
"""List all available synchronization stores.
|
||||
|
||||
Returns:
|
||||
List of store names that can be read/written
|
||||
"""
|
||||
if not _registry:
|
||||
return []
|
||||
return list(_registry.entries.keys())
|
||||
|
||||
|
||||
@tool
|
||||
def read_sync_state(store_name: str) -> Dict[str, Any]:
|
||||
"""Read the current state of a synchronization store.
|
||||
|
||||
Args:
|
||||
store_name: Name of the store to read (e.g., "TraderState", "StrategyState")
|
||||
|
||||
Returns:
|
||||
Dictionary containing the current state of the store
|
||||
|
||||
Raises:
|
||||
ValueError: If store_name doesn't exist
|
||||
"""
|
||||
if not _registry:
|
||||
raise ValueError("SyncRegistry not initialized")
|
||||
|
||||
entry = _registry.entries.get(store_name)
|
||||
if not entry:
|
||||
available = list(_registry.entries.keys())
|
||||
raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
|
||||
|
||||
return entry.model.model_dump(mode="json")
|
||||
|
||||
|
||||
@tool
|
||||
async def write_sync_state(store_name: str, updates: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Update the state of a synchronization store.
|
||||
|
||||
This will apply the updates to the store and trigger synchronization
|
||||
with all connected clients.
|
||||
|
||||
Args:
|
||||
store_name: Name of the store to update
|
||||
updates: Dictionary of field updates (field_name: new_value)
|
||||
|
||||
Returns:
|
||||
Dictionary with status and updated fields
|
||||
|
||||
Raises:
|
||||
ValueError: If store_name doesn't exist or updates are invalid
|
||||
"""
|
||||
if not _registry:
|
||||
raise ValueError("SyncRegistry not initialized")
|
||||
|
||||
entry = _registry.entries.get(store_name)
|
||||
if not entry:
|
||||
available = list(_registry.entries.keys())
|
||||
raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
|
||||
|
||||
try:
|
||||
# Get current state
|
||||
current_state = entry.model.model_dump(mode="json")
|
||||
|
||||
# Apply updates
|
||||
new_state = {**current_state, **updates}
|
||||
|
||||
# Update the model
|
||||
_registry._update_model(entry.model, new_state)
|
||||
|
||||
# Trigger sync
|
||||
await _registry.push_all()
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"store": store_name,
|
||||
"updated_fields": list(updates.keys())
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to update store '{store_name}': {str(e)}")
|
||||
|
||||
|
||||
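`write_sync_state` builds `new_state` with `{**current_state, **updates}`, which is a shallow merge. With a nested update such as `{"chart_state": {"symbol": ..., "interval": ...}}`, the whole `chart_state` dict is replaced at that step. Whether `_registry._update_model` later restores the missing nested fields is not visible in this diff, so the following is only a sketch of a recursive merge that could be used if it does not. `deep_merge` is a hypothetical helper.

```python
from typing import Any, Dict


def deep_merge(base: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge updates into a copy of base, leaving base unmodified."""
    merged = dict(base)
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


current = {
    "chart_state": {"symbol": "BINANCE:BTC/USDT", "interval": "15", "start_time": 1705312800}
}
updates = {"chart_state": {"symbol": "BINANCE:ETH/USDT", "interval": "60"}}

shallow = {**current, **updates}      # start_time is dropped
deep = deep_merge(current, updates)   # start_time is preserved

print(shallow["chart_state"])
print(deep["chart_state"])
```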
@tool
|
||||
def get_store_schema(store_name: str) -> Dict[str, Any]:
|
||||
"""Get the schema/structure of a synchronization store.
|
||||
|
||||
This shows what fields are available and their types.
|
||||
|
||||
Args:
|
||||
store_name: Name of the store
|
||||
|
||||
Returns:
|
||||
Dictionary describing the store's schema
|
||||
|
||||
Raises:
|
||||
ValueError: If store_name doesn't exist
|
||||
"""
|
||||
if not _registry:
|
||||
raise ValueError("SyncRegistry not initialized")
|
||||
|
||||
entry = _registry.entries.get(store_name)
|
||||
if not entry:
|
||||
available = list(_registry.entries.keys())
|
||||
raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
|
||||
|
||||
# Get model schema
|
||||
schema = entry.model.model_json_schema()
|
||||
|
||||
return {
|
||||
"store_name": store_name,
|
||||
"schema": schema
|
||||
}
|
||||
|
||||
|
||||
# DataSource tools
|
||||
|
||||
@tool
|
||||
def list_data_sources() -> List[str]:
|
||||
"""List all available data sources.
|
||||
|
||||
Returns:
|
||||
List of data source names that can be queried for market data
|
||||
"""
|
||||
if not _datasource_registry:
|
||||
return []
|
||||
return _datasource_registry.list_sources()
|
||||
|
||||
|
||||
@tool
|
||||
async def search_symbols(
|
||||
query: str,
|
||||
type: Optional[str] = None,
|
||||
exchange: Optional[str] = None,
|
||||
limit: int = 30,
|
||||
) -> Dict[str, Any]:
|
||||
"""Search for trading symbols across all data sources.
|
||||
|
||||
Automatically searches all available data sources and returns aggregated results.
|
||||
Use this to find symbols before calling get_symbol_info or get_historical_data.
|
||||
|
||||
Args:
|
||||
query: Search query (e.g., "BTC", "AAPL", "EUR")
|
||||
type: Optional filter by instrument type (e.g., "crypto", "stock", "forex")
|
||||
exchange: Optional filter by exchange (e.g., "binance", "nasdaq")
|
||||
limit: Maximum number of results per source (default: 30)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping source names to lists of matching symbols.
|
||||
Each symbol includes: symbol, full_name, description, exchange, type.
|
||||
Use the source name and symbol from results with get_symbol_info or get_historical_data.
|
||||
|
||||
Example response:
|
||||
{
|
||||
"demo": [
|
||||
{
|
||||
"symbol": "BTC/USDT",
|
||||
"full_name": "Bitcoin / Tether USD",
|
||||
"description": "Bitcoin perpetual futures",
|
||||
"exchange": "demo",
|
||||
"type": "crypto"
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
if not _datasource_registry:
|
||||
raise ValueError("DataSourceRegistry not initialized")
|
||||
|
||||
# Always search all sources
|
||||
results = await _datasource_registry.search_all(query, type, exchange, limit)
|
||||
return {name: [r.model_dump() for r in matches] for name, matches in results.items()}
|
||||
|
||||
|
||||
@tool
|
||||
async def get_symbol_info(source_name: str, symbol: str) -> Dict[str, Any]:
|
||||
"""Get complete metadata for a trading symbol.
|
||||
|
||||
This retrieves full information about a symbol including:
|
||||
- Description and type
|
||||
- Supported time resolutions
|
||||
- Available data columns (OHLCV, volume, funding rates, etc.)
|
||||
- Trading session information
|
||||
- Price scale and precision
|
||||
|
||||
Args:
|
||||
source_name: Name of the data source (use list_data_sources to see available)
|
||||
symbol: Symbol identifier (e.g., "BTC/USDT", "AAPL", "EUR/USD")
|
||||
|
||||
Returns:
|
||||
Dictionary containing complete symbol metadata including column schema
|
||||
|
||||
Raises:
|
||||
ValueError: If source_name or symbol is not found
|
||||
"""
|
||||
if not _datasource_registry:
|
||||
raise ValueError("DataSourceRegistry not initialized")
|
||||
|
||||
symbol_info = await _datasource_registry.resolve_symbol(source_name, symbol)
|
||||
return symbol_info.model_dump()
|
||||
|
||||
|
||||
@tool
|
||||
async def get_historical_data(
|
||||
source_name: str,
|
||||
symbol: str,
|
||||
resolution: str,
|
||||
from_time: int,
|
||||
to_time: int,
|
||||
countback: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get historical bar/candle data for a symbol.
|
||||
|
||||
Retrieves time-series data between the specified timestamps. The data
|
||||
includes all columns defined for the symbol (OHLCV + any custom columns).
|
||||
|
||||
Args:
|
||||
source_name: Name of the data source
|
||||
symbol: Symbol identifier
|
||||
resolution: Time resolution (e.g., "1" = 1min, "5" = 5min, "60" = 1hour, "1D" = 1day)
|
||||
from_time: Start time as Unix timestamp in seconds
|
||||
to_time: End time as Unix timestamp in seconds
|
||||
countback: Optional limit on number of bars to return
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- symbol: The requested symbol
|
||||
- resolution: The time resolution
|
||||
- bars: List of bar data with 'time' and 'data' fields
|
||||
- columns: Schema describing available data columns
|
||||
- nextTime: If present, indicates more data is available for pagination
|
||||
|
||||
Raises:
|
||||
ValueError: If source, symbol, or resolution is invalid
|
||||
|
||||
Example:
|
||||
# Get 1-hour BTC data for the last 24 hours
|
||||
import time
|
||||
to_time = int(time.time())
|
||||
from_time = to_time - 86400 # 24 hours ago
|
||||
data = get_historical_data("demo", "BTC/USDT", "60", from_time, to_time)
|
||||
"""
|
||||
if not _datasource_registry:
|
||||
raise ValueError("DataSourceRegistry not initialized")
|
||||
|
||||
source = _datasource_registry.get(source_name)
|
||||
if not source:
|
||||
available = _datasource_registry.list_sources()
|
||||
raise ValueError(f"Data source '{source_name}' not found. Available sources: {available}")
|
||||
|
||||
result = await source.get_bars(symbol, resolution, from_time, to_time, countback)
|
||||
return result.model_dump()
|
||||
|
||||
|
||||
async def _get_chart_data_impl(countback: Optional[int] = None):
|
||||
"""Internal implementation for getting chart data.
|
||||
|
||||
This is a helper function that can be called by both get_chart_data tool
|
||||
and analyze_chart_data tool.
|
||||
|
||||
Returns:
|
||||
Tuple of (HistoryResult, chart_context dict, source_name)
|
||||
"""
|
||||
if not _registry:
|
||||
raise ValueError("SyncRegistry not initialized - cannot read ChartStore")
|
||||
|
||||
if not _datasource_registry:
|
||||
raise ValueError("DataSourceRegistry not initialized - cannot query data")
|
||||
|
||||
# Read current chart state
|
||||
chart_store = _registry.entries.get("ChartStore")
|
||||
if not chart_store:
|
||||
raise ValueError("ChartStore not found in registry")
|
||||
|
||||
chart_state = chart_store.model.model_dump(mode="json")
|
||||
chart_data = chart_state.get("chart_state", {})
|
||||
|
||||
symbol = chart_data.get("symbol", "")
|
||||
interval = chart_data.get("interval", "15")
|
||||
start_time = chart_data.get("start_time")
|
||||
end_time = chart_data.get("end_time")
|
||||
|
||||
if not symbol:
|
||||
raise ValueError("No symbol set in ChartStore - user may not have loaded a chart yet")
|
||||
|
||||
# Parse the symbol to extract exchange/source and symbol name
|
||||
# Format is "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD")
|
||||
if ":" not in symbol:
|
||||
raise ValueError(
|
||||
f"Invalid symbol format: '{symbol}'. Expected format is 'EXCHANGE:SYMBOL' "
|
||||
f"(e.g., 'BINANCE:BTC/USDT' or 'DEMO:BTC/USD')"
|
||||
)
|
||||
|
||||
exchange_prefix, symbol_name = symbol.split(":", 1)
|
||||
source_name = exchange_prefix.lower()
|
||||
|
||||
# Get the data source
|
||||
source = _datasource_registry.get(source_name)
|
||||
if not source:
|
||||
available = _datasource_registry.list_sources()
|
||||
raise ValueError(
|
||||
f"Data source '{source_name}' not found. Available sources: {available}. "
|
||||
f"Make sure the exchange in the symbol '{symbol}' matches an available source."
|
||||
)
|
||||
|
||||
# Determine time range - REQUIRE it to be set, no defaults
|
||||
if start_time is None or end_time is None:
|
||||
raise ValueError(
|
||||
f"Chart time range not set in ChartStore. start_time={start_time}, end_time={end_time}. "
|
||||
f"The user needs to load the chart first, or the frontend may not be sending the visible range. "
|
||||
f"Wait for the chart to fully load before analyzing data."
|
||||
)
|
||||
|
||||
from_time = int(start_time)
|
||||
end_time = int(end_time)
|
||||
logger.info(
|
||||
f"Using ChartStore time range: from_time={from_time}, end_time={end_time}, "
|
||||
f"countback={countback}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Querying data source '{source_name}' for symbol '{symbol_name}', "
|
||||
f"resolution '{interval}'"
|
||||
)
|
||||
|
||||
# Query the data source
|
||||
result = await source.get_bars(
|
||||
symbol=symbol_name,
|
||||
resolution=interval,
|
||||
from_time=from_time,
|
||||
to_time=end_time,
|
||||
countback=countback
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Received {len(result.bars)} bars from data source. "
|
||||
f"First bar time: {result.bars[0].time if result.bars else 'N/A'}, "
|
||||
f"Last bar time: {result.bars[-1].time if result.bars else 'N/A'}"
|
||||
)
|
||||
|
||||
# Build chart context to return along with result
|
||||
chart_context = {
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time
|
||||
}
|
||||
|
||||
return result, chart_context, source_name
|
||||
|
||||
|
||||
@tool
|
||||
async def get_chart_data(countback: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""Get the candle/bar data for what the user is currently viewing on their chart.
|
||||
|
||||
This is a convenience tool that automatically:
|
||||
1. Reads the ChartStore to see what chart the user is viewing
|
||||
2. Parses the symbol to determine the data source (exchange prefix)
|
||||
3. Queries the appropriate data source for that symbol's data
|
||||
4. Returns the data for the visible time range and interval
|
||||
|
||||
This is the preferred way to access chart data when helping the user analyze
|
||||
what they're looking at, since it automatically uses their current chart context.
|
||||
|
||||
Args:
|
||||
countback: Optional limit on number of bars to return. If not specified,
|
||||
returns all bars in the visible time range.
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- chart_context: Current chart state (symbol, interval, time range)
|
||||
- symbol: The trading pair being viewed
|
||||
- resolution: The chart interval
|
||||
- bars: List of bar data with 'time' and 'data' fields
|
||||
- columns: Schema describing available data columns
|
||||
- source: Which data source was used
|
||||
|
||||
Raises:
|
||||
ValueError: If ChartStore or DataSourceRegistry is not initialized,
|
||||
or if the symbol format is invalid
|
||||
|
||||
Example:
|
||||
# User is viewing BINANCE:BTC/USDT on 15min chart
|
||||
data = get_chart_data()
|
||||
# Returns BTC/USDT data from binance source at 15min resolution
|
||||
# for the currently visible time range
|
||||
"""
|
||||
result, chart_context, source_name = await _get_chart_data_impl(countback)
|
||||
|
||||
# Return enriched result with chart context
|
||||
response = result.model_dump()
|
||||
response["chart_context"] = chart_context
|
||||
response["source"] = source_name
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@tool
|
||||
async def analyze_chart_data(python_script: str, countback: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""Analyze the current chart data using a Python script with pandas and matplotlib.
|
||||
|
||||
This tool:
|
||||
1. Gets the current chart data (same as get_chart_data)
|
||||
2. Converts it to a pandas DataFrame indexed by datetime (`time`) with columns: open, high, low, close, volume
|
||||
3. Executes your Python script with access to the DataFrame as 'df'
|
||||
4. Saves any matplotlib plots to disk and returns URLs to access them
|
||||
5. Returns any final DataFrame result and plot URLs
|
||||
|
||||
The script has access to:
|
||||
- `df`: pandas DataFrame with OHLCV data indexed by datetime
|
||||
- `pandas` (as `pd`): For data manipulation
|
||||
- `numpy` (as `np`): For numerical operations
|
||||
- `matplotlib.pyplot` (as `plt`): For plotting (use plt.figure() for each plot)
|
||||
|
||||
All matplotlib figures are automatically saved to disk and accessible via URLs.
|
||||
The last expression in the script (if it's a DataFrame) is returned as the result.
|
||||
|
||||
Args:
|
||||
python_script: Python code to execute. The DataFrame is available as 'df'.
|
||||
Can use pandas, numpy, matplotlib. Return a DataFrame to include it in results.
|
||||
countback: Optional limit on number of bars to analyze
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- chart_context: Current chart state (symbol, interval, time range)
|
||||
- source: Data source used
|
||||
- script_output: Any printed output from the script
|
||||
- result_dataframe: If script returns a DataFrame, it's included here as dict
|
||||
- plot_urls: List of URLs to saved plot images (one per plt.figure())
|
||||
- error: Error message if script execution failed
|
||||
|
||||
Example scripts:
|
||||
# Calculate 20-period SMA and plot
|
||||
```python
|
||||
df['SMA20'] = df['close'].rolling(20).mean()
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.plot(df.index, df['close'], label='Close')
|
||||
plt.plot(df.index, df['SMA20'], label='SMA20')
|
||||
plt.legend()
|
||||
plt.title('Price with SMA')
|
||||
df[['close', 'SMA20']].tail(10) # Return last 10 rows
|
||||
```
|
||||
|
||||
# Calculate RSI
|
||||
```python
|
||||
delta = df['close'].diff()
|
||||
gain = (delta.where(delta > 0, 0)).rolling(14).mean()
|
||||
loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
|
||||
rs = gain / loss
|
||||
df['RSI'] = 100 - (100 / (1 + rs))
|
||||
df[['close', 'RSI']].tail(20)
|
||||
```
|
||||
|
||||
# Multiple plots
|
||||
```python
|
||||
# Price chart
|
||||
plt.figure(figsize=(12, 4))
|
||||
plt.plot(df['close'])
|
||||
plt.title('Price')
|
||||
|
||||
# Volume chart
|
||||
plt.figure(figsize=(12, 3))
|
||||
plt.bar(df.index, df['volume'])
|
||||
plt.title('Volume')
|
||||
|
||||
df.describe() # Return statistics
|
||||
```
|
||||
"""
|
||||
if not _registry:
|
||||
raise ValueError("SyncRegistry not initialized - cannot read ChartStore")
|
||||
|
||||
if not _datasource_registry:
|
||||
raise ValueError("DataSourceRegistry not initialized - cannot query data")
|
||||
|
||||
try:
|
||||
# Import pandas and numpy here to allow lazy loading
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use('Agg') # Non-interactive backend
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError as e:
|
||||
raise ValueError(
|
||||
f"Required library not installed: {e}. "
|
||||
"Please install pandas, numpy, and matplotlib: pip install pandas numpy matplotlib"
|
||||
)
|
||||
|
||||
# Get chart data using the internal helper function
|
||||
result, chart_context, source_name = await _get_chart_data_impl(countback)
|
||||
|
||||
# Build the same response format as get_chart_data
|
||||
chart_data = result.model_dump()
|
||||
chart_data["chart_context"] = chart_context
|
||||
chart_data["source"] = source_name
|
||||
|
||||
# Convert bars to DataFrame
|
||||
bars = chart_data.get('bars', [])
|
||||
if not bars:
|
||||
return {
|
||||
"chart_context": chart_data.get('chart_context', {}),
|
||||
"source": chart_data.get('source', ''),
|
||||
"error": "No data available for the current chart"
|
||||
}
|
||||
|
||||
# Build DataFrame
|
||||
rows = []
|
||||
for bar in bars:
|
||||
row = {
|
||||
'time': pd.to_datetime(bar['time'], unit='s'),
|
||||
**bar['data'] # Includes open, high, low, close, volume, etc.
|
||||
}
|
||||
rows.append(row)
|
||||
|
||||
df = pd.DataFrame(rows)
|
||||
df.set_index('time', inplace=True)
|
||||
|
||||
# Convert price columns to float for clean numeric operations
|
||||
price_columns = ['open', 'high', 'low', 'close', 'volume']
|
||||
for col in price_columns:
|
||||
if col in df.columns:
|
||||
df[col] = pd.to_numeric(df[col], errors='coerce')
|
||||
|
||||
logger.info(
|
||||
f"Created DataFrame with {len(df)} rows, columns: {df.columns.tolist()}, "
|
||||
f"time range: {df.index.min()} to {df.index.max()}, "
|
||||
f"dtypes: {df.dtypes.to_dict()}"
|
||||
)
|
||||
|
||||
# Prepare execution environment
|
||||
script_globals = {
|
||||
'df': df,
|
||||
'pd': pd,
|
||||
'np': np,
|
||||
'plt': plt,
|
||||
}
|
||||
|
||||
# Capture stdout/stderr
|
||||
stdout_capture = io.StringIO()
|
||||
stderr_capture = io.StringIO()
|
||||
|
||||
result_df = None
|
||||
error_msg = None
|
||||
plot_urls = []
|
||||
|
||||
# Determine uploads directory (relative to this file)
|
||||
uploads_dir = Path(__file__).parent.parent.parent / "data" / "uploads"
|
||||
uploads_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
|
||||
# Execute the script
|
||||
exec(python_script, script_globals)
|
||||
|
||||
# Check if the last line is an expression that returns a DataFrame
|
||||
# We'll try to evaluate it separately
|
||||
script_lines = python_script.strip().split('\n')
|
||||
if script_lines:
|
||||
last_line = script_lines[-1].strip()
|
||||
# Only evaluate if it doesn't look like a statement
|
||||
if last_line and last_line.split(None, 1)[0].rstrip(':') not in ('if', 'for', 'while', 'def', 'class', 'import', 'from', 'with', 'try', 'return'):
|
||||
try:
|
||||
last_result = eval(last_line, script_globals)
|
||||
if isinstance(last_result, pd.DataFrame):
|
||||
result_df = last_result
|
||||
except Exception:
|
||||
# If eval fails, that's okay - might not be an expression
|
||||
pass
|
||||
|
||||
# Save all matplotlib figures to disk
|
||||
for fig_num in plt.get_fignums():
|
||||
fig = plt.figure(fig_num)
|
||||
|
||||
# Generate unique filename
|
||||
plot_id = str(uuid.uuid4())
|
||||
filename = f"plot_{plot_id}.png"
|
||||
filepath = uploads_dir / filename
|
||||
|
||||
# Save figure to file
|
||||
fig.savefig(filepath, format='png', bbox_inches='tight', dpi=100)
|
||||
|
||||
# Generate URL that can be accessed via the web server
|
||||
plot_url = f"/uploads/{filename}"
|
||||
plot_urls.append(plot_url)
|
||||
|
||||
plt.close(fig)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"{type(e).__name__}: {str(e)}"
|
||||
import traceback
|
||||
error_msg += f"\n{traceback.format_exc()}"
|
||||
|
||||
# Build response
|
||||
response = {
|
||||
"chart_context": chart_data.get('chart_context', {}),
|
||||
"source": chart_data.get('source', ''),
|
||||
"script_output": stdout_capture.getvalue(),
|
||||
}
|
||||
|
||||
if error_msg:
|
||||
response["error"] = error_msg
|
||||
response["stderr"] = stderr_capture.getvalue()
|
||||
|
||||
if result_df is not None:
|
||||
# Convert DataFrame to dict for JSON serialization
|
||||
response["result_dataframe"] = {
|
||||
"columns": result_df.columns.tolist(),
|
||||
"index": result_df.index.astype(str).tolist() if hasattr(result_df.index, 'astype') else result_df.index.tolist(),
|
||||
"data": result_df.values.tolist(),
|
||||
"shape": result_df.shape,
|
||||
}
|
||||
|
||||
if plot_urls:
|
||||
response["plot_urls"] = plot_urls
|
||||
|
||||
return response
|
||||
|
||||
|
||||
# Export all tools
|
||||
SYNC_TOOLS = [
|
||||
list_sync_stores,
|
||||
read_sync_state,
|
||||
write_sync_state,
|
||||
get_store_schema
|
||||
]
|
||||
|
||||
DATASOURCE_TOOLS = [
|
||||
list_data_sources,
|
||||
search_symbols,
|
||||
get_symbol_info,
|
||||
get_historical_data,
|
||||
get_chart_data,
|
||||
analyze_chart_data
|
||||
]
|
||||
139
backend/src/agent/tools/CHART_UTILS_README.md
Normal file
@@ -0,0 +1,139 @@
|
||||
# Chart Utilities - Standard OHLC Plotting
|
||||
|
||||
## Overview
|
||||
|
||||
The `chart_utils.py` module provides convenience functions for creating OHLC candlestick charts with a consistent look and feel. It is intended for use by the LLM in `analyze_chart_data` scripts, so that custom matplotlib code does not have to be written for every chart.
|
||||
|
||||
## Key Features
|
||||
|
||||
- **Beautiful by default**: Uses mplfinance with seaborn-inspired aesthetics
|
||||
- **Consistent styling**: Professional color scheme (teal green up, coral red down)
|
||||
- **Easy to use**: Simple function calls instead of complex matplotlib code
|
||||
- **Customizable**: Supports all mplfinance options via kwargs
|
||||
- **Volume integration**: Optional volume subplot
|
||||
|
||||
## Installation
|
||||
|
||||
The required package `mplfinance` has been added to `requirements.txt`:
|
||||
|
||||
```bash
|
||||
pip install mplfinance
|
||||
```
|
||||
|
||||
## Available Functions
|
||||
|
||||
### 1. `plot_ohlc(df, title=None, volume=True, figsize=(14, 8), **kwargs)`
|
||||
|
||||
Main function for creating standard OHLC candlestick charts.
|
||||
|
||||
**Parameters:**
|
||||
- `df`: pandas DataFrame with DatetimeIndex and OHLCV columns
|
||||
- `title`: Optional chart title
|
||||
- `volume`: Whether to include volume subplot (default: True)
|
||||
- `figsize`: Figure size in inches (default: (14, 8))
|
||||
- `**kwargs`: Additional mplfinance.plot() arguments
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
fig = plot_ohlc(df, title='BTC/USDT 15min', volume=True)
|
||||
```
|
||||
|
||||
### 2. `add_indicators_to_plot(df, indicators, **plot_kwargs)`
|
||||
|
||||
Creates OHLC chart with technical indicators overlaid.
|
||||
|
||||
**Parameters:**
|
||||
- `df`: DataFrame with OHLCV data and indicator columns
|
||||
- `indicators`: Dict mapping indicator column names to display parameters
|
||||
- `**plot_kwargs`: Additional arguments for plot_ohlc()
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
df['SMA_20'] = df['close'].rolling(20).mean()
|
||||
df['SMA_50'] = df['close'].rolling(50).mean()
|
||||
|
||||
fig = add_indicators_to_plot(
|
||||
df,
|
||||
indicators={
|
||||
'SMA_20': {'color': 'blue', 'width': 1.5},
|
||||
'SMA_50': {'color': 'red', 'width': 1.5}
|
||||
},
|
||||
title='Price with Moving Averages'
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Preset Functions
|
||||
|
||||
- `plot_price_volume(df, title=None)` - Standard price + volume chart
|
||||
- `plot_price_only(df, title=None)` - Candlesticks without volume
|
||||
|
||||
## Integration with analyze_chart_data
|
||||
|
||||
These functions are automatically available in the `analyze_chart_data` tool's script environment:
|
||||
|
||||
```python
|
||||
# In an analyze_chart_data script:
|
||||
# df is already provided
|
||||
|
||||
# Simple usage
|
||||
fig = plot_ohlc(df, title='Price Action')
|
||||
|
||||
# With indicators
|
||||
df['SMA'] = df['close'].rolling(20).mean()
|
||||
fig = add_indicators_to_plot(
|
||||
df,
|
||||
indicators={'SMA': {'color': 'blue', 'width': 1.5}},
|
||||
title='Price with SMA'
|
||||
)
|
||||
|
||||
# Return data for the assistant
|
||||
df[['close', 'SMA']].tail(10)
|
||||
```
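
As a rough illustration of what "automatically available" could mean, the sketch below injects helper functions into a script's globals, captures printed output, and keeps the value of the script's final expression. It uses only the standard library. `run_script` and the `double` helper are hypothetical names, not part of `chart_tools.py`, and the `ast`-based handling of the last expression is one possible approach, not a description of the actual implementation.

```python
import ast
import io
from contextlib import redirect_stdout


def run_script(script, helpers):
    """Run `script` with `helpers` as its globals (hypothetical sketch).

    Returns (captured stdout, value of the final expression or None).
    """
    env = dict(helpers)  # copy so the caller's dict is not mutated
    tree = ast.parse(script, mode="exec")

    # If the script ends with a bare expression, split it off so it can be
    # evaluated separately and its value kept.
    last_expr = None
    if tree.body and isinstance(tree.body[-1], ast.Expr):
        last_expr = ast.Expression(tree.body.pop().value)

    buf = io.StringIO()
    value = None
    with redirect_stdout(buf):
        exec(compile(tree, "<script>", "exec"), env)
        if last_expr is not None:
            value = eval(compile(last_expr, "<script>", "eval"), env)
    return buf.getvalue(), value


# `double` stands in for an injected helper such as plot_ohlc.
output, value = run_script(
    "x = double(4)\nprint('computed')\nx + 1",
    {"double": lambda v: v * 2},
)
```

Splitting the final expression off the parsed tree means it runs exactly once and may span several lines, which a line-based check cannot guarantee.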
|
||||
|
||||
## Styling
|
||||
|
||||
The default style includes:
|
||||
- **Up candles**: Teal green (#26a69a)
|
||||
- **Down candles**: Coral red (#ef5350)
|
||||
- **Background**: Light gray with white axes
|
||||
- **Grid**: Subtle dashed lines with 30% alpha
|
||||
- **Professional fonts**: Clean, readable sizes
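
The defaults listed above might be expressed with mplfinance roughly as follows. This is a sketch, not the contents of `chart_utils.py`: the names `MARKET_COLORS`, `STYLE_OPTIONS`, and `build_style` are hypothetical, and the light-gray hex value is an assumed placeholder because this README does not specify one. mplfinance is imported lazily inside `build_style()` so the settings can be inspected without it.

```python
# Candle colors taken from the list above; edge, wick, and volume bars
# inherit the up/down color of each candle.
MARKET_COLORS = {
    "up": "#26a69a",
    "down": "#ef5350",
    "edge": "inherit",
    "wick": "inherit",
    "volume": "inherit",
}

# Figure-level options. "#f5f5f5" is an assumed light gray.
STYLE_OPTIONS = {
    "figcolor": "#f5f5f5",     # area around the axes
    "facecolor": "white",      # axes background
    "gridstyle": "--",         # dashed grid lines
    "rc": {"grid.alpha": 0.3}  # 30% grid opacity
}


def build_style():
    """Build an mplfinance style object from the settings above."""
    import mplfinance as mpf  # lazy import: only needed when plotting

    market_colors = mpf.make_marketcolors(**MARKET_COLORS)
    return mpf.make_mpf_style(marketcolors=market_colors, **STYLE_OPTIONS)
```

A style built this way could then be passed as `style=build_style()` to `mplfinance.plot()`.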
|
||||
|
||||
## Why This Matters
|
||||
|
||||
**Before:**
|
||||
```python
|
||||
# LLM had to write this every time
|
||||
import matplotlib.pyplot as plt
|
||||
fig, ax = plt.subplots(figsize=(12, 6))
|
||||
ax.plot(df.index, df['close'], label='Close')
|
||||
# ... lots more code for styling, colors, etc.
|
||||
```
|
||||
|
||||
**After:**
|
||||
```python
|
||||
# LLM can now just do this
|
||||
fig = plot_ohlc(df, title='BTC/USDT')
|
||||
```
|
||||
|
||||
Benefits:
|
||||
- ✅ Less code to generate → faster response
|
||||
- ✅ Consistent appearance across all charts
|
||||
- ✅ Professional look out of the box
|
||||
- ✅ Easier to maintain and customize
|
||||
- ✅ Better use of mplfinance's candlestick rendering
|
||||
|
||||
## Example Output
|
||||
|
||||
See `chart_utils_example.py` for runnable examples demonstrating:
|
||||
1. Basic OHLC chart with volume
|
||||
2. OHLC chart with multiple indicators
|
||||
3. Price-only chart
|
||||
4. Custom styling options
|
||||
|
||||
## File Locations
|
||||
|
||||
- **Main module**: `backend/src/agent/tools/chart_utils.py`
|
||||
- **Integration**: `backend/src/agent/tools/chart_tools.py` (lines 306-328)
|
||||
- **Examples**: `backend/src/agent/tools/chart_utils_example.py`
|
||||
- **Dependency**: `backend/requirements.txt` (mplfinance added)
|
||||
50
backend/src/agent/tools/__init__.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Agent tools for trading operations.
|
||||
|
||||
This package provides tools for:
|
||||
- Synchronization stores (sync_tools)
|
||||
- Data sources and market data (datasource_tools)
|
||||
- Chart data access and analysis (chart_tools)
|
||||
- Technical indicators (indicator_tools)
|
||||
"""
|
||||
|
||||
# Global registries that will be set by main.py
|
||||
_registry = None
|
||||
_datasource_registry = None
|
||||
_indicator_registry = None
|
||||
|
||||
|
||||
def set_registry(registry):
|
||||
"""Set the global SyncRegistry instance for tools to use."""
|
||||
global _registry
|
||||
_registry = registry
|
||||
|
||||
|
||||
def set_datasource_registry(datasource_registry):
|
||||
"""Set the global DataSourceRegistry instance for tools to use."""
|
||||
global _datasource_registry
|
||||
_datasource_registry = datasource_registry
|
||||
|
||||
|
||||
def set_indicator_registry(indicator_registry):
|
||||
"""Set the global IndicatorRegistry instance for tools to use."""
|
||||
global _indicator_registry
|
||||
_indicator_registry = indicator_registry
|
||||
|
||||
|
||||
# Import all tools from submodules
|
||||
from .sync_tools import SYNC_TOOLS
|
||||
from .datasource_tools import DATASOURCE_TOOLS
|
||||
from .chart_tools import CHART_TOOLS
|
||||
from .indicator_tools import INDICATOR_TOOLS
|
||||
from .research_tools import RESEARCH_TOOLS
|
||||
|
||||
__all__ = [
|
||||
"set_registry",
|
||||
"set_datasource_registry",
|
||||
"set_indicator_registry",
|
||||
"SYNC_TOOLS",
|
||||
"DATASOURCE_TOOLS",
|
||||
"CHART_TOOLS",
|
||||
"INDICATOR_TOOLS",
|
||||
"RESEARCH_TOOLS",
|
||||
]
|
||||
371
backend/src/agent/tools/chart_tools.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""Chart data access and analysis tools."""
|
||||
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
import io
|
||||
import uuid
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from contextlib import redirect_stdout, redirect_stderr
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_registry():
|
||||
"""Get the global registry instance."""
|
||||
from . import _registry
|
||||
return _registry
|
||||
|
||||
|
||||
def _get_datasource_registry():
|
||||
"""Get the global datasource registry instance."""
|
||||
from . import _datasource_registry
|
||||
return _datasource_registry
|
||||
|
||||
|
||||
def _get_indicator_registry():
|
||||
"""Get the global indicator registry instance."""
|
||||
from . import _indicator_registry
|
||||
return _indicator_registry
|
||||
|
||||
|
||||
async def _get_chart_data_impl(countback: Optional[int] = None):
|
||||
"""Internal implementation for getting chart data.
|
||||
|
||||
This is a helper function that can be called by both get_chart_data tool
|
||||
and analyze_chart_data tool.
|
||||
|
||||
Returns:
|
||||
Tuple of (HistoryResult, chart_context dict, source_name)
|
||||
"""
|
||||
registry = _get_registry()
|
||||
datasource_registry = _get_datasource_registry()
|
||||
|
||||
if not registry:
|
||||
raise ValueError("SyncRegistry not initialized - cannot read ChartStore")
|
||||
|
||||
if not datasource_registry:
|
||||
raise ValueError("DataSourceRegistry not initialized - cannot query data")
|
||||
|
||||
# Read current chart state
|
||||
chart_store = registry.entries.get("ChartStore")
|
||||
if not chart_store:
|
||||
raise ValueError("ChartStore not found in registry")
|
||||
|
||||
chart_state = chart_store.model.model_dump(mode="json")
|
||||
chart_data = chart_state.get("chart_state", {})
|
||||
|
||||
symbol = chart_data.get("symbol", "")
|
||||
interval = chart_data.get("interval", "15")
|
||||
start_time = chart_data.get("start_time")
|
||||
end_time = chart_data.get("end_time")
|
||||
|
||||
if not symbol:
|
||||
raise ValueError("No symbol set in ChartStore - user may not have loaded a chart yet")
|
||||
|
||||
# Parse the symbol to extract exchange/source and symbol name
|
||||
# Format is "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD")
|
||||
if ":" not in symbol:
|
||||
raise ValueError(
|
||||
f"Invalid symbol format: '{symbol}'. Expected format is 'EXCHANGE:SYMBOL' "
|
||||
f"(e.g., 'BINANCE:BTC/USDT' or 'DEMO:BTC/USD')"
|
||||
)
|
||||
|
||||
exchange_prefix, symbol_name = symbol.split(":", 1)
|
||||
source_name = exchange_prefix.lower()
|
||||
|
||||
# Get the data source
|
||||
source = datasource_registry.get(source_name)
|
||||
if not source:
|
||||
available = datasource_registry.list_sources()
|
||||
raise ValueError(
|
||||
f"Data source '{source_name}' not found. Available sources: {available}. "
|
||||
f"Make sure the exchange in the symbol '{symbol}' matches an available source."
|
||||
)
|
||||
|
||||
# Determine time range - REQUIRE it to be set, no defaults
|
||||
if start_time is None or end_time is None:
|
||||
raise ValueError(
|
||||
f"Chart time range not set in ChartStore. start_time={start_time}, end_time={end_time}. "
|
||||
f"The user needs to load the chart first, or the frontend may not be sending the visible range. "
|
||||
f"Wait for the chart to fully load before analyzing data."
|
||||
)
|
||||
|
||||
from_time = int(start_time)
|
||||
end_time = int(end_time)
|
||||
logger.info(
|
||||
f"Using ChartStore time range: from_time={from_time}, end_time={end_time}, "
|
||||
f"countback={countback}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Querying data source '{source_name}' for symbol '{symbol_name}', "
|
||||
f"resolution '{interval}'"
|
||||
)
|
||||
|
||||
# Query the data source
|
||||
result = await source.get_bars(
|
||||
symbol=symbol_name,
|
||||
resolution=interval,
|
||||
from_time=from_time,
|
||||
to_time=end_time,
|
||||
countback=countback
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Received {len(result.bars)} bars from data source. "
|
||||
f"First bar time: {result.bars[0].time if result.bars else 'N/A'}, "
|
||||
f"Last bar time: {result.bars[-1].time if result.bars else 'N/A'}"
|
||||
)
|
||||
|
||||
# Build chart context to return along with result
|
||||
chart_context = {
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time
|
||||
}
|
||||
|
||||
return result, chart_context, source_name
|
||||
|
||||
|
||||
@tool
|
||||
async def get_chart_data(countback: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""Get the candle/bar data for what the user is currently viewing on their chart.
|
||||
|
||||
This is a convenience tool that automatically:
|
||||
1. Reads the ChartStore to see what chart the user is viewing
|
||||
2. Parses the symbol to determine the data source (exchange prefix)
|
||||
3. Queries the appropriate data source for that symbol's data
|
||||
4. Returns the data for the visible time range and interval
|
||||
|
||||
This is the preferred way to access chart data when helping the user analyze
|
||||
what they're looking at, since it automatically uses their current chart context.
|
||||
|
||||
Args:
|
||||
countback: Optional limit on number of bars to return. If not specified,
|
||||
returns all bars in the visible time range.
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- chart_context: Current chart state (symbol, interval, time range)
|
||||
- symbol: The trading pair being viewed
|
||||
- resolution: The chart interval
|
||||
- bars: List of bar data with 'time' and 'data' fields
|
||||
- columns: Schema describing available data columns
|
||||
- source: Which data source was used
|
||||
|
||||
Raises:
|
||||
ValueError: If ChartStore or DataSourceRegistry is not initialized,
|
||||
or if the symbol format is invalid
|
||||
|
||||
Example:
|
||||
# User is viewing BINANCE:BTC/USDT on 15min chart
|
||||
data = get_chart_data()
|
||||
# Returns BTC/USDT data from binance source at 15min resolution
|
||||
# for the currently visible time range
|
||||
"""
|
||||
result, chart_context, source_name = await _get_chart_data_impl(countback)
|
||||
|
||||
# Return enriched result with chart context
|
||||
response = result.model_dump()
|
||||
response["chart_context"] = chart_context
|
||||
response["source"] = source_name
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@tool
|
||||
async def execute_python(code: str, countback: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""Execute Python code for technical analysis with automatic chart data loading.
|
||||
|
||||
**PRIMARY TOOL for all technical analysis, indicator computation, and chart generation.**
|
||||
|
||||
This is your go-to tool whenever the user asks about indicators, wants to see
|
||||
a chart, or needs any computational analysis of market data.
|
||||
|
||||
Pre-loaded Environment:
|
||||
- `pd` : pandas
|
||||
- `np` : numpy
|
||||
- `plt` : matplotlib.pyplot (figures auto-saved to plot_urls)
|
||||
- `talib` : TA-Lib technical analysis library
|
||||
    - `indicator_registry`: 150+ registered indicators
    - `plot_ohlc(df)` : Helper function for beautiful candlestick charts

    Auto-loaded when user has a chart open:
    - `df` : pandas DataFrame with DatetimeIndex and columns:
      open, high, low, close, volume (OHLCV data ready to use)
    - `chart_context` : dict with symbol, interval, start_time, end_time

    The `plot_ohlc()` Helper:
    Create professional candlestick charts instantly:
    - `plot_ohlc(df)` - basic OHLC chart with volume
    - `plot_ohlc(df, title='BTC 15min')` - with custom title
    - `plot_ohlc(df, volume=False)` - price only, no volume
    - Returns a matplotlib Figure that's automatically saved to plot_urls

    Args:
        code: Python code to execute
        countback: Optional limit on number of bars to load (default: all visible bars)

    Returns:
        Dictionary with:
        - script_output : printed output + last expression result
        - result_dataframe : serialized DataFrame if last expression is a DataFrame
        - plot_urls : list of image URLs (e.g., ["/uploads/plot_abc123.png"])
        - chart_context : {symbol, interval, start_time, end_time} or None
        - error : traceback if execution failed

    Examples:
        # RSI indicator with chart
        execute_python(\"\"\"
        df['RSI'] = talib.RSI(df['close'], 14)
        fig = plot_ohlc(df, title='BTC/USDT with RSI')
        print(f"Current RSI: {df['RSI'].iloc[-1]:.2f}")
        df[['close', 'RSI']].tail(5)
        \"\"\")

        # Multiple indicators
        execute_python(\"\"\"
        df['SMA_20'] = df['close'].rolling(20).mean()
        df['SMA_50'] = df['close'].rolling(50).mean()
        df['BB_upper'] = df['close'].rolling(20).mean() + 2*df['close'].rolling(20).std()
        df['BB_lower'] = df['close'].rolling(20).mean() - 2*df['close'].rolling(20).std()

        fig = plot_ohlc(df, title=f"{chart_context['symbol']} - Bollinger Bands")

        current_price = df['close'].iloc[-1]
        sma20 = df['SMA_20'].iloc[-1]
        print(f"Price: {current_price:.2f}, SMA20: {sma20:.2f}")
        df[['close', 'SMA_20', 'BB_upper', 'BB_lower']].tail(10)
        \"\"\")

        # Pattern detection
        execute_python(\"\"\"
        # Find swing highs
        df['swing_high'] = (df['high'] > df['high'].shift(1)) & (df['high'] > df['high'].shift(-1))
        swing_highs = df[df['swing_high']][['high']].tail(5)

        fig = plot_ohlc(df, title='Swing High Detection')
        print("Recent swing highs:")
        print(swing_highs)
        \"\"\")
    """
    import pandas as pd
    import numpy as np
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    try:
        import talib
    except ImportError:
        talib = None
        logger.warning("TA-Lib not available in execute_python environment")

    # --- Attempt to load chart data ---
    df = None
    chart_context = None

    registry = _get_registry()
    datasource_registry = _get_datasource_registry()

    if registry and datasource_registry:
        try:
            result, chart_context, source_name = await _get_chart_data_impl(countback)
            bars = result.bars
            if bars:
                rows = []
                for bar in bars:
                    rows.append({'time': pd.to_datetime(bar.time, unit='s'), **bar.data})
                df = pd.DataFrame(rows).set_index('time')
                for col in ['open', 'high', 'low', 'close', 'volume']:
                    if col in df.columns:
                        df[col] = pd.to_numeric(df[col], errors='coerce')
                logger.info(f"execute_python: loaded {len(df)} bars for {chart_context['symbol']}")
        except Exception as e:
            logger.info(f"execute_python: no chart data loaded ({e})")

    # --- Import chart utilities ---
    from .chart_utils import plot_ohlc

    # --- Get indicator registry ---
    indicator_registry = _get_indicator_registry()

    # --- Build globals ---
    script_globals: Dict[str, Any] = {
        'pd': pd,
        'np': np,
        'plt': plt,
        'talib': talib,
        'indicator_registry': indicator_registry,
        'df': df,
        'chart_context': chart_context,
        'plot_ohlc': plot_ohlc,
    }

    # --- Execute ---
    uploads_dir = Path(__file__).parent.parent.parent.parent / "data" / "uploads"
    uploads_dir.mkdir(parents=True, exist_ok=True)

    stdout_capture = io.StringIO()
    result_df = None
    error_msg = None
    plot_urls = []

    try:
        with redirect_stdout(stdout_capture), redirect_stderr(stdout_capture):
            exec(code, script_globals)

            # Capture last expression
            lines = code.strip().splitlines()
            if lines:
                last = lines[-1].strip()
                if last and not any(last.startswith(kw) for kw in (
                    'if', 'for', 'while', 'def', 'class', 'import',
                    'from', 'with', 'try', 'return', '#'
                )):
                    try:
                        last_val = eval(last, script_globals)
                        if isinstance(last_val, pd.DataFrame):
                            result_df = last_val
                        elif last_val is not None:
                            stdout_capture.write(str(last_val))
                    except Exception:
                        pass

        # Save plots
        for fig_num in plt.get_fignums():
            fig = plt.figure(fig_num)
            filename = f"plot_{uuid.uuid4()}.png"
            fig.savefig(uploads_dir / filename, format='png', bbox_inches='tight', dpi=100)
            plot_urls.append(f"/uploads/{filename}")
            plt.close(fig)

    except Exception as e:
        import traceback
        error_msg = f"{type(e).__name__}: {e}\n{traceback.format_exc()}"

    # --- Build response ---
    response: Dict[str, Any] = {
        'script_output': stdout_capture.getvalue(),
        'chart_context': chart_context,
        'plot_urls': plot_urls,
    }
    if result_df is not None:
        response['result_dataframe'] = {
            'columns': result_df.columns.tolist(),
            'index': result_df.index.astype(str).tolist(),
            'data': result_df.values.tolist(),
            'shape': result_df.shape,
        }
    if error_msg:
        response['error'] = error_msg

    return response


CHART_TOOLS = [
    get_chart_data,
    execute_python
]
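The last-expression capture in `execute_python` re-evaluates the script's final line with `eval` after `exec` has already run it, so a trailing call such as `print(swing_highs)` runs twice, and the keyword prefix test also skips names like `forecast`. A minimal sketch of one alternative, assuming the standard-library `ast` module is acceptable here: detach a trailing expression from the parsed script and evaluate it exactly once. `run_with_last_value` is a hypothetical helper name, not part of this commit.

```python
import ast
import io
from contextlib import redirect_stdout
from typing import Any, Dict, Tuple


def run_with_last_value(code: str, script_globals: Dict[str, Any]) -> Tuple[str, Any]:
    """Run `code` once; return (captured stdout, value of a trailing expression or None)."""
    tree = ast.parse(code, mode="exec")
    last_expr = None
    # Detach the final statement only when it is a bare expression.
    if tree.body and isinstance(tree.body[-1], ast.Expr):
        last_expr = ast.Expression(body=tree.body.pop().value)

    buffer = io.StringIO()
    last_val = None
    with redirect_stdout(buffer):
        # Everything except the trailing expression runs here.
        exec(compile(tree, "<script>", "exec"), script_globals)
        if last_expr is not None:
            # The trailing expression is evaluated a single time.
            last_val = eval(compile(last_expr, "<script>", "eval"), script_globals)
    return buffer.getvalue(), last_val


output, value = run_with_last_value("x = 2\nprint('hi')\nx * 21", {})
print_output, print_value = run_with_last_value("print('once')", {})
```

Because the split happens on the syntax tree rather than on text lines, a final expression that spans several lines is handled the same way as a one-line expression.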
224
backend/src/agent/tools/chart_utils.py
Normal file
@@ -0,0 +1,224 @@
"""Chart plotting utilities for creating standard, beautiful OHLC charts."""

import pandas as pd
import matplotlib.pyplot as plt
from typing import Optional, Tuple
import logging

logger = logging.getLogger(__name__)


def plot_ohlc(
    df: pd.DataFrame,
    title: Optional[str] = None,
    volume: bool = True,
    figsize: Tuple[int, int] = (14, 8),
    style: str = 'seaborn-v0_8-darkgrid',
    **kwargs
) -> plt.Figure:
    """Create a beautiful standard OHLC candlestick chart.

    This is a convenience function that generates a professional-looking candlestick
    chart with consistent styling across all generated charts. It uses mplfinance
    with seaborn aesthetics for a polished appearance.

    Args:
        df: pandas DataFrame with DatetimeIndex and columns: open, high, low, close, volume
        title: Optional chart title. If None, the chart is drawn without a title
        volume: Whether to include volume subplot (default: True)
        figsize: Figure size as (width, height) in inches (default: (14, 8))
        style: Base matplotlib style name (default: 'seaborn-v0_8-darkgrid').
            Currently unused; the chart is styled from mplfinance's 'charles' base style
        **kwargs: Additional arguments to pass to mplfinance.plot()

    Returns:
        matplotlib.figure.Figure: The created figure object

    Example:
        ```python
        # Basic usage in analyze_chart_data script
        fig = plot_ohlc(df, title='BTC/USDT 15min')

        # Explicitly include the volume subplot
        fig = plot_ohlc(df, volume=True, title='Price Action')

        # Compute an indicator column before plotting
        df['SMA20'] = df['close'].rolling(20).mean()
        fig = plot_ohlc(df, title='With SMA')
        # Note: For mplfinance overlays, use the mav or addplot parameters
        ```

    Note:
        The DataFrame must have a DatetimeIndex and the standard OHLCV columns.
        Column names should be lowercase: open, high, low, close, volume
    """
    try:
        import mplfinance as mpf
    except ImportError:
        raise ImportError(
            "mplfinance is required for plot_ohlc(). "
            "Install it with: pip install mplfinance"
        )

    # Validate DataFrame structure
    required_cols = ['open', 'high', 'low', 'close']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(
            f"DataFrame missing required columns: {missing_cols}. "
            f"Required: {required_cols}"
        )

    if not isinstance(df.index, pd.DatetimeIndex):
        raise ValueError(
            "DataFrame must have a DatetimeIndex. "
            "Convert with: df.index = pd.to_datetime(df.index)"
        )

    # Ensure volume column exists for volume plot
    if volume and 'volume' not in df.columns:
        logger.warning("volume=True but 'volume' column not found in DataFrame. Disabling volume.")
        volume = False

    # Create custom style with seaborn aesthetics
    # Using a professional color scheme: green for up candles, red for down candles
    mc = mpf.make_marketcolors(
        up='#26a69a',      # Teal green (calmer than bright green)
        down='#ef5350',    # Coral red (softer than pure red)
        edge='inherit',    # Match candle color for edges
        wick='inherit',    # Match candle color for wicks
        volume='in',       # Volume bars colored by price direction
        alpha=0.9          # Slight transparency for elegance
    )

    s = mpf.make_mpf_style(
        base_mpf_style='charles',  # Clean base style
        marketcolors=mc,
        rc={
            'font.size': 10,
            'axes.labelsize': 11,
            'axes.titlesize': 12,
            'xtick.labelsize': 9,
            'ytick.labelsize': 9,
            'legend.fontsize': 10,
            'figure.facecolor': '#f0f0f0',
            'axes.facecolor': '#ffffff',
            'axes.grid': True,
            'grid.alpha': 0.3,
            'grid.linestyle': '--',
        }
    )

    # Prepare plot parameters
    plot_params = {
        'type': 'candle',
        'style': s,
        'volume': volume,
        'figsize': figsize,
        'tight_layout': True,
        'returnfig': True,
        'warn_too_much_data': 1000,  # Warn if > 1000 candles for performance
    }

    # Add title if provided
    if title:
        plot_params['title'] = title

    # Merge any additional kwargs
    plot_params.update(kwargs)

    # Create the plot
    logger.info(
        f"Creating OHLC chart with {len(df)} candles, "
        f"date range: {df.index.min()} to {df.index.max()}, "
        f"volume: {volume}"
    )

    fig, axes = mpf.plot(df, **plot_params)

    return fig


def add_indicators_to_plot(
    df: pd.DataFrame,
    indicators: dict,
    **plot_kwargs
) -> plt.Figure:
    """Create an OHLC chart with technical indicators overlaid.

    This extends plot_ohlc() to include common technical indicators using
    mplfinance's addplot functionality for proper overlay on candlestick charts.

    Args:
        df: pandas DataFrame with OHLCV data and indicator columns
        indicators: Dictionary mapping indicator names to parameters
            Example: {
                'SMA_20': {'color': 'blue', 'width': 1.5},
                'EMA_50': {'color': 'orange', 'width': 1.5}
            }
        **plot_kwargs: Additional arguments for plot_ohlc()

    Returns:
        matplotlib.figure.Figure: The created figure object

    Example:
        ```python
        # Calculate indicators
        df['SMA_20'] = df['close'].rolling(20).mean()
        df['SMA_50'] = df['close'].rolling(50).mean()

        # Plot with indicators
        fig = add_indicators_to_plot(
            df,
            indicators={
                'SMA_20': {'color': 'blue', 'width': 1.5, 'label': '20 SMA'},
                'SMA_50': {'color': 'red', 'width': 1.5, 'label': '50 SMA'}
            },
            title='BTC/USDT with Moving Averages'
        )
        ```
    """
    try:
        import mplfinance as mpf
    except ImportError:
        raise ImportError(
            "mplfinance is required. Install it with: pip install mplfinance"
        )

    # Build addplot list for indicators
    addplots = []
    for indicator_col, params in indicators.items():
        if indicator_col not in df.columns:
            logger.warning(f"Indicator column '{indicator_col}' not found in DataFrame. Skipping.")
            continue

        color = params.get('color', 'blue')
        width = params.get('width', 1.0)
        panel = params.get('panel', 0)  # 0 = main panel with candles
        ylabel = params.get('ylabel', '')

        addplots.append(
            mpf.make_addplot(
                df[indicator_col],
                color=color,
                width=width,
                panel=panel,
                ylabel=ylabel
            )
        )

    # Pass addplot to plot_ohlc via kwargs
    if addplots:
        plot_kwargs['addplot'] = addplots

    return plot_ohlc(df, **plot_kwargs)


# Convenience presets for common chart types
def plot_price_volume(df: pd.DataFrame, title: Optional[str] = None) -> plt.Figure:
    """Create a standard price + volume chart."""
    return plot_ohlc(df, title=title, volume=True, figsize=(14, 8))


def plot_price_only(df: pd.DataFrame, title: Optional[str] = None) -> plt.Figure:
    """Create a price-only candlestick chart without volume."""
    return plot_ohlc(df, title=title, volume=False, figsize=(14, 6))
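`add_indicators_to_plot()` reads four keys from each indicator's parameter dict (`color`, `width`, `panel`, `ylabel`) and ignores any others. The sketch below mirrors those defaults without importing mplfinance, to show how an oscillator could be routed to its own panel. `resolve_indicator_params` is a hypothetical helper, and the choice of `panel=2` assumes mplfinance draws volume in panel 1 when `volume=True`.

```python
from typing import Any, Dict

# Defaults copied from the params.get(...) calls in add_indicators_to_plot().
DEFAULTS: Dict[str, Any] = {"color": "blue", "width": 1.0, "panel": 0, "ylabel": ""}


def resolve_indicator_params(
    indicators: Dict[str, Dict[str, Any]],
) -> Dict[str, Dict[str, Any]]:
    """Return the effective settings per indicator; keys outside DEFAULTS are dropped."""
    return {
        name: {key: params.get(key, default) for key, default in DEFAULTS.items()}
        for name, params in indicators.items()
    }


specs = resolve_indicator_params({
    # 'label' is accepted in the dict but never read by add_indicators_to_plot().
    "SMA_20": {"color": "orange", "label": "20 SMA"},
    # Assumption: panel 2 sits below the volume panel.
    "RSI": {"color": "purple", "panel": 2, "ylabel": "RSI"},
})
```

Passing the same dictionaries to `add_indicators_to_plot(df, indicators=...)` should overlay `SMA_20` on the candles and place `RSI` in a separate panel, provided both columns exist in `df`.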
154
backend/src/agent/tools/chart_utils_example.py
Normal file
@@ -0,0 +1,154 @@
"""
Example usage of chart_utils.py plotting functions.

This demonstrates how the LLM can use the plot_ohlc() convenience function
in analyze_chart_data scripts to create beautiful, standard OHLC charts.
"""

import pandas as pd
import numpy as np
from datetime import datetime, timedelta


def create_sample_data(days=30):
    """Create sample OHLCV data for testing."""
    # Lowercase 'h' is the hourly alias; uppercase 'H' is deprecated in recent pandas
    dates = pd.date_range(end=datetime.now(), periods=days * 24, freq='1h')

    # Simulate price movement
    np.random.seed(42)
    close = 50000 + np.cumsum(np.random.randn(len(dates)) * 100)

    data = {
        'open': close + np.random.randn(len(dates)) * 50,
        'high': close + np.abs(np.random.randn(len(dates))) * 100,
        'low': close - np.abs(np.random.randn(len(dates))) * 100,
        'close': close,
        'volume': np.abs(np.random.randn(len(dates))) * 1000000
    }

    df = pd.DataFrame(data, index=dates)

    # Ensure high is highest and low is lowest
    df['high'] = df[['open', 'high', 'low', 'close']].max(axis=1)
    df['low'] = df[['open', 'high', 'low', 'close']].min(axis=1)

    return df


if __name__ == "__main__":
    from chart_utils import plot_ohlc, add_indicators_to_plot, plot_price_volume

    # Create sample data
    df = create_sample_data(days=30)

    print("=" * 60)
    print("Example 1: Basic OHLC chart with volume")
    print("=" * 60)
    print("\nScript the LLM would generate:")
    print("""
fig = plot_ohlc(df, title='BTC/USDT 1H', volume=True)
df.tail(5)
""")

    # Execute it
    fig = plot_ohlc(df, title='BTC/USDT 1H', volume=True)
    print("\n✓ Chart created successfully!")
    print(f" Figure size: {fig.get_size_inches()}")
    print(f" Number of axes: {len(fig.axes)}")

    print("\n" + "=" * 60)
    print("Example 2: OHLC chart with indicators")
    print("=" * 60)
    print("\nScript the LLM would generate:")
    print("""
# Calculate indicators
df['SMA_20'] = df['close'].rolling(20).mean()
df['SMA_50'] = df['close'].rolling(50).mean()
df['EMA_12'] = df['close'].ewm(span=12, adjust=False).mean()

# Plot with indicators
fig = add_indicators_to_plot(
    df,
    indicators={
        'SMA_20': {'color': 'blue', 'width': 1.5},
        'SMA_50': {'color': 'red', 'width': 1.5},
        'EMA_12': {'color': 'green', 'width': 1.0}
    },
    title='BTC/USDT with Moving Averages',
    volume=True
)

df[['close', 'SMA_20', 'SMA_50', 'EMA_12']].tail(5)
""")

    # Execute it
    df['SMA_20'] = df['close'].rolling(20).mean()
    df['SMA_50'] = df['close'].rolling(50).mean()
    df['EMA_12'] = df['close'].ewm(span=12, adjust=False).mean()

    fig = add_indicators_to_plot(
        df,
        indicators={
            'SMA_20': {'color': 'blue', 'width': 1.5},
            'SMA_50': {'color': 'red', 'width': 1.5},
            'EMA_12': {'color': 'green', 'width': 1.0}
        },
        title='BTC/USDT with Moving Averages',
        volume=True
    )

    print("\n✓ Chart with indicators created successfully!")
    print(f" Last close: ${df['close'].iloc[-1]:,.2f}")
    print(f" SMA 20: ${df['SMA_20'].iloc[-1]:,.2f}")
    print(f" SMA 50: ${df['SMA_50'].iloc[-1]:,.2f}")

    print("\n" + "=" * 60)
    print("Example 3: Price-only chart (no volume)")
    print("=" * 60)
    print("\nScript the LLM would generate:")
    print("""
from chart_utils import plot_price_only

fig = plot_price_only(df, title='Clean Price Action')
""")

    # Execute it
    from chart_utils import plot_price_only
    fig = plot_price_only(df, title='Clean Price Action')

    print("\n✓ Price-only chart created successfully!")

    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)
    print("""
The chart_utils module provides:

1. plot_ohlc() - Main function for beautiful candlestick charts
   - Professional seaborn-inspired styling
   - Consistent color scheme (teal up, coral down)
   - Optional volume subplot
   - Customizable figure size

2. add_indicators_to_plot() - OHLC charts with technical indicators
   - Overlay multiple indicators
   - Customizable colors and line widths
   - Proper integration with mplfinance

3. Preset functions for common chart types:
   - plot_price_volume() - Standard price + volume
   - plot_price_only() - Candlesticks without volume

Benefits:
✓ Consistent look and feel across all charts
✓ Less code for the LLM to generate
✓ Professional appearance out of the box
✓ Easy to customize when needed
✓ Works seamlessly with analyze_chart_data tool

The LLM can now simply call plot_ohlc(df) instead of writing
custom matplotlib code for every chart request!
""")
158
backend/src/agent/tools/datasource_tools.py
Normal file
@@ -0,0 +1,158 @@
"""Data source and market data tools."""

from typing import Dict, Any, List, Optional
from langchain_core.tools import tool


def _get_datasource_registry():
    """Get the global datasource registry instance."""
    from . import _datasource_registry
    return _datasource_registry


@tool
def list_data_sources() -> List[str]:
    """List all available data sources.

    Returns:
        List of data source names that can be queried for market data
    """
    registry = _get_datasource_registry()
    if not registry:
        return []
    return registry.list_sources()


@tool
async def search_symbols(
    query: str,
    type: Optional[str] = None,
    exchange: Optional[str] = None,
    limit: int = 30,
) -> Dict[str, Any]:
    """Search for trading symbols across all data sources.

    Automatically searches all available data sources and returns aggregated results.
    Use this to find symbols before calling get_symbol_info or get_historical_data.

    Args:
        query: Search query (e.g., "BTC", "AAPL", "EUR")
        type: Optional filter by instrument type (e.g., "crypto", "stock", "forex")
        exchange: Optional filter by exchange (e.g., "binance", "nasdaq")
        limit: Maximum number of results per source (default: 30)

    Returns:
        Dictionary mapping source names to lists of matching symbols.
        Each symbol includes: symbol, full_name, description, exchange, type.
        Use the source name and symbol from results with get_symbol_info or get_historical_data.

    Example response:
        {
            "demo": [
                {
                    "symbol": "BTC/USDT",
                    "full_name": "Bitcoin / Tether USD",
                    "description": "Bitcoin perpetual futures",
                    "exchange": "demo",
                    "type": "crypto"
                }
            ]
        }
    """
    registry = _get_datasource_registry()
    if not registry:
        raise ValueError("DataSourceRegistry not initialized")

    # Always search all sources
    results = await registry.search_all(query, type, exchange, limit)
    return {name: [r.model_dump() for r in matches] for name, matches in results.items()}


@tool
async def get_symbol_info(source_name: str, symbol: str) -> Dict[str, Any]:
    """Get complete metadata for a trading symbol.

    This retrieves full information about a symbol including:
    - Description and type
    - Supported time resolutions
    - Available data columns (OHLCV, funding rates, etc.)
    - Trading session information
    - Price scale and precision

    Args:
        source_name: Name of the data source (use list_data_sources to see available)
        symbol: Symbol identifier (e.g., "BTC/USDT", "AAPL", "EUR/USD")

    Returns:
        Dictionary containing complete symbol metadata including column schema

    Raises:
        ValueError: If source_name or symbol is not found
    """
    registry = _get_datasource_registry()
    if not registry:
        raise ValueError("DataSourceRegistry not initialized")

    symbol_info = await registry.resolve_symbol(source_name, symbol)
    return symbol_info.model_dump()


@tool
async def get_historical_data(
    source_name: str,
    symbol: str,
    resolution: str,
    from_time: int,
    to_time: int,
    countback: Optional[int] = None,
) -> Dict[str, Any]:
    """Get historical bar/candle data for a symbol.

    Retrieves time-series data between the specified timestamps. The data
    includes all columns defined for the symbol (OHLCV + any custom columns).

    Args:
        source_name: Name of the data source
        symbol: Symbol identifier
        resolution: Time resolution (e.g., "1" = 1min, "5" = 5min, "60" = 1hour, "1D" = 1day)
        from_time: Start time as Unix timestamp in seconds
        to_time: End time as Unix timestamp in seconds
        countback: Optional limit on number of bars to return

    Returns:
        Dictionary containing:
        - symbol: The requested symbol
        - resolution: The time resolution
        - bars: List of bar data with 'time' and 'data' fields
        - columns: Schema describing available data columns
        - nextTime: If present, indicates more data is available for pagination

    Raises:
        ValueError: If source, symbol, or resolution is invalid

    Example:
        # Get 1-hour BTC data for the last 24 hours
        import time
        to_time = int(time.time())
        from_time = to_time - 86400  # 24 hours ago
        data = get_historical_data("demo", "BTC/USDT", "60", from_time, to_time)
    """
    registry = _get_datasource_registry()
    if not registry:
        raise ValueError("DataSourceRegistry not initialized")

    source = registry.get(source_name)
    if not source:
        available = registry.list_sources()
        raise ValueError(f"Data source '{source_name}' not found. Available sources: {available}")

    result = await source.get_bars(symbol, resolution, from_time, to_time, countback)
    return result.model_dump()


DATASOURCE_TOOLS = [
    list_data_sources,
    search_symbols,
    get_symbol_info,
    get_historical_data,
]
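`get_historical_data` takes `from_time` and `to_time` as Unix seconds together with a resolution string. A small sketch of deriving a window that covers a given number of bars, using only the resolution examples from the docstring ("1", "5", "60", "1D"). `resolution_to_seconds` and `window_for_bars` are hypothetical helpers, and any other resolution formats a data source may accept are not handled.

```python
import time
from typing import Optional, Tuple


def resolution_to_seconds(resolution: str) -> int:
    """Convert a resolution such as "5" (minutes) or "1D" (days) to seconds."""
    if resolution.upper().endswith("D"):
        days = resolution[:-1] or "1"  # assumption: a bare "D" means one day
        return int(days) * 86400
    return int(resolution) * 60  # plain numbers are minutes


def window_for_bars(
    resolution: str, bars: int, to_time: Optional[int] = None
) -> Tuple[int, int]:
    """Return (from_time, to_time) in Unix seconds spanning `bars` bars."""
    end = int(time.time()) if to_time is None else to_time
    return end - bars * resolution_to_seconds(resolution), end


# 24 one-hour bars ending at a fixed timestamp.
from_time, to_time = window_for_bars("60", 24, to_time=1_700_000_000)
```

The resulting pair can be passed straight to `get_historical_data` as `from_time` and `to_time`; markets with closed sessions may return fewer bars than the window suggests.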
169
backend/src/agent/tools/indicator_tools.py
Normal file
@@ -0,0 +1,169 @@
"""Technical indicator tools."""

from typing import Dict, Any, List, Optional
from langchain_core.tools import tool


def _get_indicator_registry():
    """Get the global indicator registry instance."""
    from . import _indicator_registry
    return _indicator_registry


@tool
def list_indicators() -> List[str]:
    """List all available technical indicators.

    Returns:
        List of indicator names that can be used in analysis and strategies
    """
    registry = _get_indicator_registry()
    if not registry:
        return []
    return registry.list_indicators()


@tool
def get_indicator_info(indicator_name: str) -> Dict[str, Any]:
    """Get detailed information about a specific indicator.

    Retrieves metadata including description, parameters, category, use cases,
    input/output schemas, and references.

    Args:
        indicator_name: Name of the indicator (e.g., "RSI", "SMA", "MACD")

    Returns:
        Dictionary containing:
        - name: Indicator name
        - display_name: Human-readable name
        - description: What the indicator computes and why it's useful
        - category: Category (momentum, trend, volatility, volume, etc.)
        - parameters: List of configurable parameters with types and defaults
        - use_cases: Common trading scenarios where this indicator helps
        - tags: Searchable tags
        - input_schema: Required input columns (e.g., OHLCV requirements)
        - output_schema: Columns this indicator produces

    Raises:
        ValueError: If indicator_name is not found
    """
    registry = _get_indicator_registry()
    if not registry:
        raise ValueError("IndicatorRegistry not initialized")

    metadata = registry.get_metadata(indicator_name)
    if not metadata:
        total_count = len(registry.list_indicators())
        raise ValueError(
            f"Indicator '{indicator_name}' not found. "
            f"Total available: {total_count} indicators. "
            f"Use search_indicators() to find indicators by name, category, or tag."
        )

    input_schema = registry.get_input_schema(indicator_name)
    output_schema = registry.get_output_schema(indicator_name)

    result = metadata.model_dump()
    result["input_schema"] = input_schema.model_dump() if input_schema else None
    result["output_schema"] = output_schema.model_dump() if output_schema else None

    return result


@tool
def search_indicators(
    query: Optional[str] = None,
    category: Optional[str] = None,
    tag: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Search for indicators by text query, category, or tag.

    Returns lightweight summaries - use get_indicator_info() for full details on specific indicators.

    Use this to discover relevant indicators for your trading strategy or analysis.
    Can filter by category (momentum, trend, volatility, etc.) or search by keywords.

    Args:
        query: Optional text search across names, descriptions, and use cases
        category: Optional category filter (momentum, trend, volatility, volume, pattern, etc.)
        tag: Optional tag filter (e.g., "oscillator", "moving-average", "talib")

    Returns:
        List of lightweight indicator summaries. Each contains:
        - name: Indicator name (use with get_indicator_info() for full details)
        - display_name: Human-readable name
        - description: Brief one-line description
        - category: Category (momentum, trend, volatility, etc.)

    Example:
        # Find all momentum indicators
        results = search_indicators(category="momentum")
        # Returns [{name: "RSI", display_name: "RSI", description: "...", category: "momentum"}, ...]

        # Then get details on interesting ones
        rsi_details = get_indicator_info("RSI")  # Full parameters, schemas, use cases

        # Search for moving average indicators
        search_indicators(query="moving average")

        # Find all TA-Lib indicators
        search_indicators(tag="talib")
    """
    registry = _get_indicator_registry()
    if not registry:
        raise ValueError("IndicatorRegistry not initialized")

    results = []

    if query:
        results = registry.search_by_text(query)
    elif category:
        results = registry.search_by_category(category)
    elif tag:
        results = registry.search_by_tag(tag)
    else:
        # Return all indicators if no filter
        results = registry.get_all_metadata()

    # Return lightweight summaries only
    return [
        {
            "name": r.name,
            "display_name": r.display_name,
            "description": r.description,
            "category": r.category
        }
        for r in results
    ]


@tool
def get_indicator_categories() -> Dict[str, int]:
    """Get all indicator categories and their counts.

    Returns a summary of available indicator categories, useful for
    exploring what types of indicators are available.

    Returns:
        Dictionary mapping category name to count of indicators in that category.
        Example: {"momentum": 25, "trend": 15, "volatility": 8, ...}
    """
    registry = _get_indicator_registry()
    if not registry:
        raise ValueError("IndicatorRegistry not initialized")

    categories: Dict[str, int] = {}
    for metadata in registry.get_all_metadata():
        category = metadata.category
        categories[category] = categories.get(category, 0) + 1

    return categories


INDICATOR_TOOLS = [
    list_indicators,
    get_indicator_info,
    search_indicators,
    get_indicator_categories
]
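`search_indicators` applies only the first filter supplied, in the order `query`, `category`, `tag`, so a call that passes both `query` and `category` ignores the category. If combined filtering is wanted, the summaries it returns can be narrowed on the caller's side. A minimal sketch over plain dicts shaped like those summaries; `filter_summaries` is a hypothetical helper and the sample entries are made up for illustration.

```python
from collections import Counter
from typing import Dict, List, Optional


def filter_summaries(
    summaries: List[Dict[str, str]],
    query: Optional[str] = None,
    category: Optional[str] = None,
) -> List[Dict[str, str]]:
    """Keep the summaries that match every filter given (case-insensitive)."""
    def matches(item: Dict[str, str]) -> bool:
        if category and item["category"].lower() != category.lower():
            return False
        if query:
            haystack = f"{item['name']} {item['display_name']} {item['description']}"
            return query.lower() in haystack.lower()
        return True

    return [item for item in summaries if matches(item)]


# Illustrative data only; real summaries come from search_indicators().
sample = [
    {"name": "RSI", "display_name": "RSI", "description": "Relative Strength Index", "category": "momentum"},
    {"name": "SMA", "display_name": "SMA", "description": "Simple moving average", "category": "trend"},
    {"name": "EMA", "display_name": "EMA", "description": "Exponential moving average", "category": "trend"},
]

trend_averages = filter_summaries(sample, query="moving average", category="trend")
# Same tally as get_indicator_categories(), computed from the summaries.
category_counts = Counter(item["category"] for item in sample)
```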
171
backend/src/agent/tools/research_tools.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""Research and external data tools for trading analysis."""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from langchain_core.tools import tool
|
||||
from langchain_community.tools import (
|
||||
ArxivQueryRun,
|
||||
WikipediaQueryRun,
|
||||
DuckDuckGoSearchRun
|
||||
)
|
||||
from langchain_community.utilities import (
|
||||
ArxivAPIWrapper,
|
||||
WikipediaAPIWrapper,
|
||||
DuckDuckGoSearchAPIWrapper
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
def search_arxiv(query: str, max_results: int = 5) -> str:
|
||||
"""Search arXiv for academic papers on quantitative finance, trading strategies, and machine learning.
|
||||
|
||||
Use this to find research papers on topics like:
|
||||
- Market microstructure and order flow
|
||||
- Algorithmic trading strategies
|
||||
- Machine learning for finance
|
||||
- Time series forecasting
|
||||
- Risk management
|
||||
- Portfolio optimization
|
||||
|
||||
Args:
|
||||
query: Search query (e.g., "machine learning algorithmic trading", "deep learning stock prediction")
|
||||
max_results: Maximum number of results to return (default: 5)
|
||||
|
||||
Returns:
|
||||
Summary of papers including titles, authors, abstracts, and links
|
||||
|
||||
Example:
|
||||
search_arxiv("reinforcement learning trading", max_results=3)
|
||||
"""
|
||||
arxiv = ArxivQueryRun(api_wrapper=ArxivAPIWrapper(top_k_results=max_results))
|
||||
return arxiv.run(query)
|
||||
|
||||
|
||||
@tool
|
||||
def search_wikipedia(query: str) -> str:
|
||||
"""Search Wikipedia for information on finance, trading, and economics concepts.
|
||||
|
||||
Use this to get background information on:
|
||||
- Financial instruments and markets
|
||||
- Economic indicators
|
||||
- Trading terminology
|
||||
- Technical analysis concepts
|
||||
- Historical market events
|
||||
|
||||
Args:
|
||||
query: Search query (e.g., "Black-Scholes model", "technical analysis", "options trading")
|
||||
|
||||
Returns:
|
||||
Wikipedia article summary with key information
|
||||
|
||||
Example:
|
||||
search_wikipedia("Bollinger Bands")
|
||||
"""
|
||||
wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
|
||||
return wikipedia.run(query)
|
||||
|
||||
|
||||
@tool
|
||||
def search_web(query: str, max_results: int = 5) -> str:
|
||||
"""Search the web for current information on markets, news, and trading.
|
||||
|
||||
Use this to find:
|
||||
- Latest market news and analysis
|
||||
- Company announcements and earnings
|
||||
- Economic events and indicators
|
||||
- Cryptocurrency updates
|
||||
- Exchange status and updates
|
||||
- Trading strategy discussions
|
||||
|
||||
Args:
|
||||
query: Search query (e.g., "Bitcoin price news", "Fed interest rate decision")
|
||||
max_results: Maximum number of results to return (default: 5)
|
||||
|
||||
Returns:
|
||||
Search results with titles, snippets, and links
|
||||
|
||||
Example:
|
||||
search_web("Ethereum merge update", max_results=3)
|
||||
"""
|
||||
# Lazy initialization to avoid hanging during import
|
||||
search = DuckDuckGoSearchRun(api_wrapper=DuckDuckGoSearchAPIWrapper())
|
||||
# Note: max_results parameter doesn't work properly with current wrapper
|
||||
return search.run(query)
|
||||
|
||||
|
||||
@tool
|
||||
def http_get(url: str, params: Optional[Dict[str, str]] = None) -> str:
|
||||
"""Make HTTP GET request to fetch data from APIs or web pages.
|
||||
|
||||
Use this to retrieve:
|
||||
- Exchange API data (public endpoints only)
|
||||
- Market data from external APIs
|
||||
- Documentation and specifications
|
||||
- News articles and blog posts
|
||||
- JSON/XML data from web services
|
||||
|
||||
Args:
|
||||
url: The URL to fetch
|
||||
params: Optional query parameters as a dictionary
|
||||
|
||||
Returns:
|
||||
Response text from the URL
|
||||
|
||||
Raises:
|
||||
ValueError: If the request fails
|
||||
|
||||
Example:
|
||||
http_get("https://api.coingecko.com/api/v3/simple/price",
|
||||
params={"ids": "bitcoin", "vs_currencies": "usd"})
|
||||
"""
|
||||
import requests
|
||||
|
||||
try:
|
||||
response = requests.get(url, params=params, timeout=10)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except requests.RequestException as e:
|
||||
raise ValueError(f"HTTP GET request failed: {e}") from e
|
||||
|
||||
|
||||
@tool
|
||||
def http_post(url: str, data: Dict[str, Any]) -> str:
|
||||
"""Make HTTP POST request to send data to APIs.
|
||||
|
||||
Use this to:
|
||||
- Submit data to external APIs
|
||||
- Trigger webhooks
|
||||
- Post analysis results
|
||||
- Interact with exchange APIs (if authenticated)
|
||||
|
||||
Args:
|
||||
url: The URL to post to
|
||||
data: Dictionary of data to send in the request body
|
||||
|
||||
Returns:
|
||||
Response text from the server
|
||||
|
||||
Raises:
|
||||
ValueError: If the request fails
|
||||
|
||||
Example:
|
||||
http_post("https://webhook.site/xxx", {"message": "Trade executed"})
|
||||
"""
|
||||
import requests
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=data, timeout=10)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except requests.RequestException as e:
|
||||
raise ValueError(f"HTTP POST request failed: {e}") from e
|
||||
|
||||
|
||||
# Export tools list
|
||||
RESEARCH_TOOLS = [
|
||||
search_arxiv,
|
||||
search_wikipedia,
|
||||
search_web,
|
||||
http_get,
|
||||
http_post
|
||||
]
|
||||
138
backend/src/agent/tools/sync_tools.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Synchronization store tools."""
|
||||
|
||||
from typing import Dict, Any, List
|
||||
from langchain_core.tools import tool
|
||||
|
||||
|
||||
def _get_registry():
|
||||
"""Get the global registry instance."""
|
||||
from . import _registry
|
||||
return _registry
|
||||
|
||||
|
||||
@tool
|
||||
def list_sync_stores() -> List[str]:
|
||||
"""List all available synchronization stores.
|
||||
|
||||
Returns:
|
||||
List of store names that can be read/written
|
||||
"""
|
||||
registry = _get_registry()
|
||||
if not registry:
|
||||
return []
|
||||
return list(registry.entries.keys())
|
||||
|
||||
|
||||
@tool
|
||||
def read_sync_state(store_name: str) -> Dict[str, Any]:
|
||||
"""Read the current state of a synchronization store.
|
||||
|
||||
Args:
|
||||
store_name: Name of the store to read (e.g., "TraderState", "StrategyState")
|
||||
|
||||
Returns:
|
||||
Dictionary containing the current state of the store
|
||||
|
||||
Raises:
|
||||
ValueError: If store_name doesn't exist
|
||||
"""
|
||||
registry = _get_registry()
|
||||
if not registry:
|
||||
raise ValueError("SyncRegistry not initialized")
|
||||
|
||||
entry = registry.entries.get(store_name)
|
||||
if not entry:
|
||||
available = list(registry.entries.keys())
|
||||
raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
|
||||
|
||||
return entry.model.model_dump(mode="json")
|
||||
|
||||
|
||||
@tool
|
||||
async def write_sync_state(store_name: str, updates: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Update the state of a synchronization store.
|
||||
|
||||
This will apply the updates to the store and trigger synchronization
|
||||
with all connected clients.
|
||||
|
||||
Args:
|
||||
store_name: Name of the store to update
|
||||
updates: Dictionary of field updates (field_name: new_value)
|
||||
|
||||
Returns:
|
||||
Dictionary with status and updated fields
|
||||
|
||||
Raises:
|
||||
ValueError: If store_name doesn't exist or updates are invalid
|
||||
"""
|
||||
registry = _get_registry()
|
||||
if not registry:
|
||||
raise ValueError("SyncRegistry not initialized")
|
||||
|
||||
entry = registry.entries.get(store_name)
|
||||
if not entry:
|
||||
available = list(registry.entries.keys())
|
||||
raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
|
||||
|
||||
try:
|
||||
# Get current state
|
||||
current_state = entry.model.model_dump(mode="json")
|
||||
|
||||
# Apply updates
|
||||
new_state = {**current_state, **updates}
|
||||
|
||||
# Update the model
|
||||
registry._update_model(entry.model, new_state)
|
||||
|
||||
# Trigger sync
|
||||
await registry.push_all()
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"store": store_name,
|
||||
"updated_fields": list(updates.keys())
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to update store '{store_name}': {e}") from e
|
||||
|
||||
|
||||
@tool
|
||||
def get_store_schema(store_name: str) -> Dict[str, Any]:
|
||||
"""Get the schema/structure of a synchronization store.
|
||||
|
||||
This shows what fields are available and their types.
|
||||
|
||||
Args:
|
||||
store_name: Name of the store
|
||||
|
||||
Returns:
|
||||
Dictionary describing the store's schema
|
||||
|
||||
Raises:
|
||||
ValueError: If store_name doesn't exist
|
||||
"""
|
||||
registry = _get_registry()
|
||||
if not registry:
|
||||
raise ValueError("SyncRegistry not initialized")
|
||||
|
||||
entry = registry.entries.get(store_name)
|
||||
if not entry:
|
||||
available = list(registry.entries.keys())
|
||||
raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
|
||||
|
||||
# Get model schema
|
||||
schema = entry.model.model_json_schema()
|
||||
|
||||
return {
|
||||
"store_name": store_name,
|
||||
"schema": schema
|
||||
}
|
||||
|
||||
|
||||
SYNC_TOOLS = [
|
||||
list_sync_stores,
|
||||
read_sync_state,
|
||||
write_sync_state,
|
||||
get_store_schema
|
||||
]
|
||||
@@ -6,9 +6,10 @@ the free CCXT library (not ccxt.pro), supporting both historical data and
|
||||
polling-based subscriptions.
|
||||
|
||||
Numerical Precision:
|
||||
- Uses Decimal for all monetary values (prices, volumes) to avoid floating-point errors
|
||||
- OHLCV data uses native floats for optimal DataFrame/analysis performance
|
||||
- Account balances and order data should use Decimal (via _to_decimal method)
|
||||
- CCXT returns numeric values as strings or floats depending on configuration
|
||||
- All financial values are converted to Decimal to maintain precision
|
||||
- Price data converted to float (_to_float), financial data to Decimal (_to_decimal)
|
||||
|
||||
Real-time Updates:
|
||||
- Uses polling instead of WebSocket (free CCXT doesn't have WebSocket support)
|
||||
@@ -72,6 +73,20 @@ class CCXTDataSource(DataSource):
|
||||
exchange_class = getattr(ccxt, exchange_id)
|
||||
self.exchange = exchange_class(self._config)
|
||||
|
||||
# Configure CCXT to use Decimal mode for precise financial calculations
|
||||
# This ensures all numeric values from the exchange use Decimal internally
|
||||
# We then convert OHLCV to float for DataFrame performance, but keep
|
||||
# Decimal precision for account balances, order sizes, etc.
|
||||
from decimal import Decimal as PythonDecimal
|
||||
self.exchange.number = PythonDecimal
|
||||
|
||||
# Log the precision mode being used by this exchange
|
||||
precision_mode = getattr(self.exchange, 'precisionMode', 'UNKNOWN')
|
||||
logger.info(
|
||||
f"CCXT {exchange_id}: Configured with Decimal mode. "
|
||||
f"Exchange precision mode: {precision_mode}"
|
||||
)
|
||||
|
||||
if sandbox and hasattr(self.exchange, 'set_sandbox_mode'):
|
||||
self.exchange.set_sandbox_mode(True)
|
||||
|
||||
@@ -103,6 +118,33 @@ class CCXTDataSource(DataSource):
|
||||
return Decimal(str(value))
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _to_float(value: Union[str, int, float, Decimal, None]) -> Optional[float]:
|
||||
"""
|
||||
Convert a value to float for OHLCV data.
|
||||
|
||||
OHLCV data is used for charting and DataFrame analysis, where native
|
||||
floats provide better performance and compatibility with pandas/numpy.
|
||||
For financial precision (balances, order sizes), use _to_decimal() instead.
|
||||
|
||||
When CCXT is in Decimal mode (exchange.number = Decimal), it returns
|
||||
Decimal objects. This method converts them to float for performance.
|
||||
|
||||
Handles CCXT's output in both modes:
|
||||
- Decimal mode: receives Decimal objects
|
||||
- Default mode: receives strings, floats, or ints
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, float):
|
||||
return value
|
||||
if isinstance(value, Decimal):
|
||||
# CCXT in Decimal mode - convert to float for OHLCV
|
||||
return float(value)
|
||||
if isinstance(value, (str, int)):
|
||||
return float(value)
|
||||
return None
|
||||
|
||||
async def _ensure_markets_loaded(self):
|
||||
"""Ensure markets are loaded from exchange"""
|
||||
if not self._markets_loaded:
|
||||
@@ -241,31 +283,31 @@ class CCXTDataSource(DataSource):
|
||||
columns=[
|
||||
ColumnInfo(
|
||||
name="open",
|
||||
-type="decimal",
+type="float",
|
||||
description=f"Opening price in {quote}",
|
||||
unit=quote,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="high",
|
||||
-type="decimal",
+type="float",
|
||||
description=f"Highest price in {quote}",
|
||||
unit=quote,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="low",
|
||||
-type="decimal",
+type="float",
|
||||
description=f"Lowest price in {quote}",
|
||||
unit=quote,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="close",
|
||||
-type="decimal",
+type="float",
|
||||
description=f"Closing price in {quote}",
|
||||
unit=quote,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="volume",
|
||||
-type="decimal",
+type="float",
|
||||
description=f"Trading volume in {base}",
|
||||
unit=base,
|
||||
),
|
||||
@@ -370,7 +412,7 @@ class CCXTDataSource(DataSource):
|
||||
all_ohlcv = all_ohlcv[:countback]
|
||||
break
|
||||
|
||||
-# Convert to our Bar format with Decimal precision
+# Convert to our Bar format with float for OHLCV (used in DataFrames)
|
||||
bars = []
|
||||
for candle in all_ohlcv:
|
||||
timestamp_ms, open_price, high, low, close, volume = candle
|
||||
@@ -384,11 +426,11 @@ class CCXTDataSource(DataSource):
|
||||
Bar(
|
||||
time=timestamp,
|
||||
data={
-"open": self._to_decimal(open_price),
-"high": self._to_decimal(high),
-"low": self._to_decimal(low),
-"close": self._to_decimal(close),
-"volume": self._to_decimal(volume),
+"open": self._to_float(open_price),
+"high": self._to_float(high),
+"low": self._to_float(low),
+"close": self._to_float(close),
+"volume": self._to_float(volume),
|
||||
},
|
||||
)
|
||||
)
|
||||
@@ -476,14 +518,14 @@ class CCXTDataSource(DataSource):
|
||||
if timestamp > last_timestamp:
|
||||
self._last_bars[subscription_id] = timestamp
|
||||
|
||||
-# Convert to our format with Decimal precision
+# Convert to our format with float for OHLCV (used in DataFrames)
tick_data = {
"time": timestamp,
-"open": self._to_decimal(open_price),
-"high": self._to_decimal(high),
-"low": self._to_decimal(low),
-"close": self._to_decimal(close),
-"volume": self._to_decimal(volume),
+"open": self._to_float(open_price),
+"high": self._to_float(high),
+"low": self._to_float(low),
+"close": self._to_float(close),
+"volume": self._to_float(volume),
|
||||
}
|
||||
|
||||
# Call the callback
|
||||
|
||||
179
backend/src/exchange_kernel/README.md
Normal file
@@ -0,0 +1,179 @@
|
||||
# Exchange Kernel API
|
||||
|
||||
A Kubernetes-style declarative API for managing orders across different exchanges.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
The Exchange Kernel maintains two separate views of order state:
|
||||
|
||||
1. **Desired State (Intent)**: What the strategy kernel wants
|
||||
2. **Actual State (Reality)**: What currently exists on the exchange
|
||||
|
||||
A reconciliation loop continuously works to bring actual state into alignment with desired state, handling errors, retries, and edge cases automatically.
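As a rough illustration of the reconciliation idea, the sketch below diffs a desired-state map against an actual-state map and returns the actions needed to converge. The `plan_actions` helper and the dict-based state (intent ID mapped to a limit price) are hypothetical simplifications, not part of the kernel's API; a real loop would also need retries and error handling.

```python
from typing import Dict, List, Tuple

# Hypothetical simplification: each state maps intent_id -> limit price.
Action = Tuple[str, str]  # (verb, intent_id)


def plan_actions(desired: Dict[str, float], actual: Dict[str, float]) -> List[Action]:
    """Return the actions that would bring `actual` in line with `desired`."""
    actions: List[Action] = []
    for intent_id, price in desired.items():
        if intent_id not in actual:
            actions.append(("place", intent_id))   # wanted but missing on the exchange
        elif actual[intent_id] != price:
            actions.append(("modify", intent_id))  # present but out of date
    for intent_id in actual:
        if intent_id not in desired:
            actions.append(("cancel", intent_id))  # on the exchange but no longer wanted
    return actions


desired = {"a": 100.0, "b": 101.0}
actual = {"b": 99.0, "c": 50.0}
print(plan_actions(desired, actual))
```

Running the planner repeatedly until it returns an empty list is one way to picture the "continuous" part of the loop.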
|
||||
|
||||
## Core Components
|
||||
|
||||
### Models (`models.py`)
|
||||
|
||||
- **OrderIntent**: Desired order state from strategy kernel
|
||||
- **OrderState**: Actual current order state on exchange
|
||||
- **Position**: Current position (spot, margin, perp, futures, options)
|
||||
- **Asset**: Asset holdings with metadata
|
||||
- **AccountState**: Complete account snapshot (balances, positions, margin)
|
||||
- **AssetMetadata**: Asset type descriptions and trading parameters
|
||||
|
||||
### Events (`events.py`)
|
||||
|
||||
Order lifecycle events:
|
||||
- `OrderSubmitted`, `OrderAccepted`, `OrderRejected`
|
||||
- `OrderPartiallyFilled`, `OrderFilled`, `OrderCanceled`
|
||||
- `OrderModified`, `OrderExpired`
|
||||
|
||||
Position events:
|
||||
- `PositionOpened`, `PositionModified`, `PositionClosed`
|
||||
|
||||
Account events:
|
||||
- `AccountBalanceUpdated`, `MarginCallWarning`
|
||||
|
||||
### Base Interface (`base.py`)
|
||||
|
||||
Abstract `ExchangeKernel` class defining:
|
||||
|
||||
**Command API**:
|
||||
- `place_order()`, `place_order_group()` - Create order intents
|
||||
- `cancel_order()`, `modify_order()` - Update intents
|
||||
- `cancel_all_orders()` - Bulk cancellation
|
||||
|
||||
**Query API**:
|
||||
- `get_order_intent()`, `get_order_state()` - Query single order
|
||||
- `get_all_intents()`, `get_all_orders()` - Query all orders
|
||||
- `get_positions()`, `get_account_state()` - Query positions/balances
|
||||
- `get_symbol_metadata()`, `get_asset_metadata()` - Query market info
|
||||
|
||||
**Event API**:
|
||||
- `subscribe_events()`, `unsubscribe_events()` - Event notifications
|
||||
|
||||
**Lifecycle**:
|
||||
- `start()`, `stop()` - Kernel lifecycle
|
||||
- `health_check()` - Connection status
|
||||
- `force_reconciliation()` - Manual reconciliation trigger
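One possible way to back `subscribe_events()` and `unsubscribe_events()` is a small in-process registry keyed by subscription ID. The `EventBus` class below is a hypothetical sketch under that assumption: events are modeled as plain dicts rather than the kernel's event models, and the filter is an exact match on the given keys.

```python
import uuid
from typing import Any, Callable, Dict, Optional, Tuple

Event = Dict[str, Any]  # stand-in for the kernel's event models
Callback = Callable[[Event], None]


class EventBus:
    """Hypothetical in-process registry; not the kernel's actual implementation."""

    def __init__(self) -> None:
        self._subs: Dict[str, Tuple[Callback, Dict[str, Any]]] = {}

    def subscribe(self, callback: Callback, event_filter: Optional[Dict[str, Any]] = None) -> str:
        subscription_id = uuid.uuid4().hex
        self._subs[subscription_id] = (callback, dict(event_filter or {}))
        return subscription_id

    def unsubscribe(self, subscription_id: str) -> None:
        # Unknown IDs are ignored so that unsubscribing twice is harmless.
        self._subs.pop(subscription_id, None)

    def publish(self, event: Event) -> int:
        """Deliver the event to every matching subscriber; return the delivery count."""
        delivered = 0
        for callback, event_filter in list(self._subs.values()):
            if all(event.get(key) == value for key, value in event_filter.items()):
                callback(event)
                delivered += 1
        return delivered


bus = EventBus()
seen = []
sub_id = bus.subscribe(seen.append, {"event_type": "ORDER_FILLED"})
bus.publish({"event_type": "ORDER_FILLED", "symbol_id": "BTC/USD"})
bus.publish({"event_type": "ORDER_CANCELED", "symbol_id": "BTC/USD"})
bus.unsubscribe(sub_id)
bus.publish({"event_type": "ORDER_FILLED", "symbol_id": "BTC/USD"})
print(len(seen))
```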
|
||||
|
||||
### State Management (`state.py`)
|
||||
|
||||
- **IntentStateStore**: Storage for desired state (durable, survives restarts)
|
||||
- **ActualStateStore**: Storage for actual exchange state (ephemeral cache)
|
||||
- **ReconciliationEngine**: Framework for intent→reality reconciliation
|
||||
- **InMemory implementations**: For testing/prototyping
|
||||
|
||||
## Standard Order Model
|
||||
|
||||
Defined in `schema/order_spec.py`:
|
||||
|
||||
```python
StandardOrder(
    symbol_id="BTC/USD",
    side=Side.BUY,
    amount=1.0,
    amount_type=AmountType.BASE,  # or QUOTE for exact-out
    limit_price=50000.0,  # None for market orders
    time_in_force=TimeInForce.GTC,
    conditional_trigger=ConditionalTrigger(...),  # Optional stop-loss/take-profit
    conditional_mode=ConditionalOrderMode.UNIFIED_ADJUSTING,
    reduce_only=False,
    post_only=False,
    iceberg_qty=None,
)
```
|
||||
|
||||
## Symbol Metadata
|
||||
|
||||
Markets describe their capabilities via `SymbolMetadata`:
|
||||
|
||||
- **AmountConstraints**: Min/max order size, step size
|
||||
- **PriceConstraints**: Tick size, tick spacing mode (fixed/dynamic/continuous)
|
||||
- **MarketCapabilities**:
  - Supported sides (BUY, SELL)
  - Supported amount types (BASE, QUOTE, or both)
  - Market vs limit order support
  - Time-in-force options (GTC, IOC, FOK, DAY, GTD)
  - Conditional order support (stop-loss, take-profit, trailing stops)
  - Advanced features (post-only, reduce-only, iceberg)
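To make the constraint idea concrete, here is a minimal sketch of checking an order against amount and price constraints. The `check_order` function and its parameter names are assumptions for illustration, not the actual `SymbolMetadata` fields, and it only covers the fixed tick spacing case. It uses `Decimal` so that step and tick checks are exact.

```python
from decimal import Decimal
from typing import List, Optional


def check_order(
    amount: Decimal,
    price: Optional[Decimal],
    min_amount: Decimal,
    max_amount: Decimal,
    step_size: Decimal,
    tick_size: Decimal,
) -> List[str]:
    """Return human-readable violations; an empty list means the order passes."""
    problems: List[str] = []
    if amount < min_amount or amount > max_amount:
        problems.append(f"amount {amount} outside [{min_amount}, {max_amount}]")
    if amount % step_size != 0:
        problems.append(f"amount {amount} is not a multiple of step {step_size}")
    # Market orders carry no limit price, so the tick check is skipped for them.
    if price is not None and price % tick_size != 0:
        problems.append(f"price {price} is not a multiple of tick {tick_size}")
    return problems


limits = dict(
    min_amount=Decimal("0.001"),
    max_amount=Decimal("100"),
    step_size=Decimal("0.001"),
    tick_size=Decimal("0.01"),
)
print(check_order(Decimal("0.015"), Decimal("50000.05"), **limits))
print(check_order(Decimal("0.0015"), Decimal("50000.005"), **limits))
```

The same checks done with binary floats could misreport valid values, since quantities such as `0.015` are not exactly representable.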
|
||||
|
||||
## Asset Types
|
||||
|
||||
Comprehensive asset type system supporting:
|
||||
- **SPOT**: Cash markets
|
||||
- **MARGIN**: Margin trading
|
||||
- **PERP**: Perpetual futures
|
||||
- **FUTURE**: Dated futures
|
||||
- **OPTION**: Options contracts
|
||||
- **SYNTHETIC**: Derived instruments
|
||||
|
||||
Each asset has metadata describing contract specs, settlement, margin requirements, etc.
|
||||
|
||||
## Usage Pattern
|
||||
|
||||
```python
# Create exchange kernel for specific exchange
kernel = SomeExchangeKernel(exchange_id="binance_main")

# Subscribe to events
kernel.subscribe_events(my_event_handler)

# Start kernel
await kernel.start()

# Place order (creates intent, kernel handles execution)
intent_id = await kernel.place_order(
    StandardOrder(
        symbol_id="BTC/USD",
        side=Side.BUY,
        amount=1.0,
        amount_type=AmountType.BASE,
        limit_price=50000.0,
    )
)

# Query desired state
intent = await kernel.get_order_intent(intent_id)

# Query actual state
state = await kernel.get_order_state(intent_id)

# Modify order (updates intent, kernel reconciles)
await kernel.modify_order(intent_id, new_order)

# Cancel order
await kernel.cancel_order(intent_id)

# Query positions
positions = await kernel.get_positions()

# Query account state
account = await kernel.get_account_state()
```
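The `my_event_handler` callback is not defined in this README. A minimal sketch of what such a handler might look like follows, assuming events expose an `event_type` field as the kernel's base event does; the `FakeEvent` dataclass is a hypothetical stand-in used only to keep the example self-contained.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class FakeEvent:
    """Stand-in for a kernel event; only the fields used here are modeled."""
    event_type: str
    metadata: Dict[str, Any] = field(default_factory=dict)


fills: List[FakeEvent] = []
rejections: List[FakeEvent] = []


def my_event_handler(event: FakeEvent) -> None:
    """Route events by type; unknown types are ignored."""
    if event.event_type in ("ORDER_FILLED", "ORDER_PARTIALLY_FILLED"):
        fills.append(event)
    elif event.event_type == "ORDER_REJECTED":
        rejections.append(event)


for evt in (FakeEvent("ORDER_FILLED"), FakeEvent("ORDER_REJECTED"), FakeEvent("POSITION_OPENED")):
    my_event_handler(evt)
print(len(fills), len(rejections))
```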
|
||||
|
||||
## Implementation Status
|
||||
|
||||
✅ **Complete**:
|
||||
- Data models and type definitions
|
||||
- Event definitions
|
||||
- Abstract interface
|
||||
- State store framework
|
||||
- In-memory stores for testing
|
||||
|
||||
⏳ **TODO** (Exchange-specific implementations):
|
||||
- Concrete ExchangeKernel implementations per exchange
|
||||
- Reconciliation engine implementation
|
||||
- Exchange API adapters
|
||||
- Persistent state storage (database)
|
||||
- Error handling and retry logic
|
||||
- Monitoring and observability
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Create concrete implementations for specific exchanges (Binance, Uniswap, etc.)
|
||||
2. Implement reconciliation engine with proper error handling
|
||||
3. Add persistent storage backend for intents
|
||||
4. Build integration tests
|
||||
5. Add monitoring/metrics collection
|
||||
75
backend/src/exchange_kernel/__init__.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
Exchange Kernel API
|
||||
|
||||
The exchange kernel provides a Kubernetes-style declarative API for managing orders
|
||||
across different exchanges. It maintains both desired state (intent) and actual state
|
||||
(current orders on exchange) and reconciles them continuously.
|
||||
|
||||
Key concepts:
|
||||
- OrderIntent: What the strategy kernel wants
|
||||
- OrderState: What actually exists on the exchange
|
||||
- Reconciliation: Bringing actual state into alignment with desired state
|
||||
"""
|
||||
|
||||
from .base import ExchangeKernel
|
||||
from .events import (
|
||||
OrderEvent,
|
||||
OrderSubmitted,
|
||||
OrderAccepted,
|
||||
OrderRejected,
|
||||
OrderPartiallyFilled,
|
||||
OrderFilled,
|
||||
OrderCanceled,
|
||||
OrderModified,
|
||||
OrderExpired,
|
||||
PositionEvent,
|
||||
PositionOpened,
|
||||
PositionModified,
|
||||
PositionClosed,
|
||||
AccountEvent,
|
||||
AccountBalanceUpdated,
|
||||
MarginCallWarning,
|
||||
)
|
||||
from .models import (
|
||||
OrderIntent,
|
||||
OrderState,
|
||||
Position,
|
||||
Asset,
|
||||
AssetMetadata,
|
||||
AccountState,
|
||||
Balance,
|
||||
)
|
||||
from .state import IntentStateStore, ActualStateStore
|
||||
|
||||
__all__ = [
|
||||
# Core interface
|
||||
"ExchangeKernel",
|
||||
# Events
|
||||
"OrderEvent",
|
||||
"OrderSubmitted",
|
||||
"OrderAccepted",
|
||||
"OrderRejected",
|
||||
"OrderPartiallyFilled",
|
||||
"OrderFilled",
|
||||
"OrderCanceled",
|
||||
"OrderModified",
|
||||
"OrderExpired",
|
||||
"PositionEvent",
|
||||
"PositionOpened",
|
||||
"PositionModified",
|
||||
"PositionClosed",
|
||||
"AccountEvent",
|
||||
"AccountBalanceUpdated",
|
||||
"MarginCallWarning",
|
||||
# Models
|
||||
"OrderIntent",
|
||||
"OrderState",
|
||||
"Position",
|
||||
"Asset",
|
||||
"AssetMetadata",
|
||||
"AccountState",
|
||||
"Balance",
|
||||
# State management
|
||||
"IntentStateStore",
|
||||
"ActualStateStore",
|
||||
]
|
||||
361
backend/src/exchange_kernel/base.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
Base interface for Exchange Kernels.
|
||||
|
||||
Defines the abstract API that all exchange kernel implementations must support.
|
||||
Each exchange (or exchange type) will have its own kernel implementation.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Any
|
||||
|
||||
from .models import (
|
||||
OrderIntent,
|
||||
OrderState,
|
||||
Position,
|
||||
AccountState,
|
||||
AssetMetadata,
|
||||
)
|
||||
from .events import BaseEvent
|
||||
from ..schema.order_spec import (
|
||||
StandardOrder,
|
||||
StandardOrderGroup,
|
||||
SymbolMetadata,
|
||||
)
|
||||
|
||||
|
||||
class ExchangeKernel(ABC):
|
||||
"""
|
||||
Abstract base class for exchange kernels.
|
||||
|
||||
An exchange kernel manages the lifecycle of orders on a specific exchange,
|
||||
maintaining both desired state (intents from strategy kernel) and actual
|
||||
state (current orders on exchange), and continuously reconciling them.
|
||||
|
||||
Think of it as a Kubernetes-style controller for trading orders.
|
||||
"""
|
||||
|
||||
def __init__(self, exchange_id: str):
|
||||
"""
|
||||
Initialize the exchange kernel.
|
||||
|
||||
Args:
|
||||
exchange_id: Unique identifier for this exchange instance
|
||||
"""
|
||||
self.exchange_id = exchange_id
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Command API - Strategy kernel sends intents
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@abstractmethod
|
||||
async def place_order(self, order: StandardOrder, metadata: dict[str, Any] | None = None) -> str:
|
||||
"""
|
||||
Place a single order on the exchange.
|
||||
|
||||
This creates an OrderIntent and begins the reconciliation process to
|
||||
get the order onto the exchange.
|
||||
|
||||
Args:
|
||||
order: The order specification
|
||||
metadata: Optional strategy-specific metadata
|
||||
|
||||
Returns:
|
||||
intent_id: Unique identifier for this order intent
|
||||
|
||||
Raises:
|
||||
ValidationError: If order violates market constraints
|
||||
ExchangeError: If exchange rejects the order
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def place_order_group(
|
||||
self,
|
||||
group: StandardOrderGroup,
|
||||
metadata: dict[str, Any] | None = None
|
||||
) -> list[str]:
|
||||
"""
|
||||
Place a group of orders with an OCO (One-Cancels-the-Other) relationship.
|
||||
|
||||
Args:
|
||||
group: Group of orders with OCO mode
|
||||
metadata: Optional strategy-specific metadata
|
||||
|
||||
Returns:
|
||||
intent_ids: List of intent IDs for each order in the group
|
||||
|
||||
Raises:
|
||||
ValidationError: If any order violates market constraints
|
||||
ExchangeError: If exchange rejects the group
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def cancel_order(self, intent_id: str) -> None:
|
||||
"""
|
||||
Cancel an order by intent ID.
|
||||
|
||||
Updates the intent to indicate cancellation is desired, and the
|
||||
reconciliation loop will handle the actual exchange cancellation.
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID of the order to cancel
|
||||
|
||||
Raises:
|
||||
NotFoundError: If intent_id doesn't exist
|
||||
ExchangeError: If exchange rejects cancellation
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def modify_order(
|
||||
self,
|
||||
intent_id: str,
|
||||
new_order: StandardOrder,
|
||||
) -> None:
|
||||
"""
|
||||
Modify an existing order.
|
||||
|
||||
Updates the order intent, and the reconciliation loop will update
|
||||
the exchange order (via modify API if available, or cancel+replace).
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID of the order to modify
|
||||
new_order: New order specification
|
||||
|
||||
Raises:
|
||||
NotFoundError: If intent_id doesn't exist
|
||||
ValidationError: If new order violates market constraints
|
||||
ExchangeError: If exchange rejects modification
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def cancel_all_orders(self, symbol_id: str | None = None) -> int:
|
||||
"""
|
||||
Cancel all orders, optionally filtered by symbol.
|
||||
|
||||
Args:
|
||||
symbol_id: If provided, only cancel orders for this symbol
|
||||
|
||||
Returns:
|
||||
count: Number of orders canceled
|
||||
"""
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Query API - Read desired and actual state
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@abstractmethod
|
||||
async def get_order_intent(self, intent_id: str) -> OrderIntent:
|
||||
"""
|
||||
Get the desired order state (what strategy kernel wants).
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID to query
|
||||
|
||||
Returns:
|
||||
The order intent
|
||||
|
||||
Raises:
|
||||
NotFoundError: If intent_id doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_order_state(self, intent_id: str) -> OrderState:
|
||||
"""
|
||||
Get the actual order state (what's currently on exchange).
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID to query
|
||||
|
||||
Returns:
|
||||
The current order state
|
||||
|
||||
Raises:
|
||||
NotFoundError: If intent_id doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_all_intents(self, symbol_id: str | None = None) -> list[OrderIntent]:
|
||||
"""
|
||||
Get all order intents, optionally filtered by symbol.
|
||||
|
||||
Args:
|
||||
symbol_id: If provided, only return intents for this symbol
|
||||
|
||||
Returns:
|
||||
List of order intents
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_all_orders(self, symbol_id: str | None = None) -> list[OrderState]:
|
||||
"""
|
||||
Get all actual order states, optionally filtered by symbol.
|
||||
|
||||
Args:
|
||||
symbol_id: If provided, only return orders for this symbol
|
||||
|
||||
Returns:
|
||||
List of order states
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_positions(self, symbol_id: str | None = None) -> list[Position]:
|
||||
"""
|
||||
Get current positions, optionally filtered by symbol.
|
||||
|
||||
Args:
|
||||
symbol_id: If provided, only return positions for this symbol
|
||||
|
||||
Returns:
|
||||
List of positions
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_account_state(self) -> AccountState:
|
||||
"""
|
||||
Get current account state (balances, margin, etc.).
|
||||
|
||||
Returns:
|
||||
Current account state
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_symbol_metadata(self, symbol_id: str) -> SymbolMetadata:
|
||||
"""
|
||||
Get metadata for a symbol (constraints, capabilities, etc.).
|
||||
|
||||
Args:
|
||||
symbol_id: Symbol to query
|
||||
|
||||
Returns:
|
||||
Symbol metadata
|
||||
|
||||
Raises:
|
||||
NotFoundError: If symbol doesn't exist on this exchange
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_asset_metadata(self, asset_id: str) -> AssetMetadata:
|
||||
"""
|
||||
Get metadata for an asset.
|
||||
|
||||
Args:
|
||||
asset_id: Asset to query
|
||||
|
||||
Returns:
|
||||
Asset metadata
|
||||
|
||||
Raises:
|
||||
NotFoundError: If asset doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def list_symbols(self) -> list[str]:
|
||||
"""
|
||||
List all available symbols on this exchange.
|
||||
|
||||
Returns:
|
||||
List of symbol IDs
|
||||
"""
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Event Subscription API
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@abstractmethod
|
||||
def subscribe_events(
|
||||
self,
|
||||
callback: Callable[[BaseEvent], None],
|
||||
event_filter: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Subscribe to events from this exchange kernel.
|
||||
|
||||
Args:
|
||||
callback: Function to call when events occur
|
||||
event_filter: Optional filter criteria (event_type, symbol_id, etc.)
|
||||
|
||||
Returns:
|
||||
subscription_id: Unique ID for this subscription (for unsubscribe)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unsubscribe_events(self, subscription_id: str) -> None:
|
||||
"""
|
||||
Unsubscribe from events.
|
||||
|
||||
Args:
|
||||
subscription_id: Subscription ID returned from subscribe_events
|
||||
"""
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Lifecycle Management
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Start the exchange kernel.
|
||||
|
||||
Initializes connections, starts reconciliation loops, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""
|
||||
Stop the exchange kernel.
|
||||
|
||||
Closes connections, stops reconciliation loops, etc.
|
||||
Does NOT cancel open orders - call cancel_all_orders() first if desired.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def health_check(self) -> dict[str, Any]:
|
||||
"""
|
||||
Check health status of the exchange kernel.
|
||||
|
||||
Returns:
|
||||
Health status dict with connection state, latency, error counts, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Reconciliation Control (advanced)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@abstractmethod
|
||||
async def force_reconciliation(self, intent_id: str | None = None) -> None:
|
||||
"""
|
||||
Force immediate reconciliation.
|
||||
|
||||
Args:
|
||||
intent_id: If provided, only reconcile this specific intent.
|
||||
If None, reconcile all intents.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_reconciliation_metrics(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get metrics about the reconciliation process.
|
||||
|
||||
Returns:
|
||||
Metrics dict with reconciliation lag, error rates, retry counts, etc.
|
||||
"""
|
||||
pass
|
||||
250
backend/src/exchange_kernel/events.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
Event definitions for the Exchange Kernel.
|
||||
|
||||
All events that can occur during the order lifecycle, position management,
|
||||
and account updates.
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..schema.order_spec import Float, Uint64
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Base Event Classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class EventType(StrEnum):
|
||||
"""Types of events emitted by the exchange kernel"""
|
||||
# Order lifecycle
|
||||
ORDER_SUBMITTED = "ORDER_SUBMITTED"
|
||||
ORDER_ACCEPTED = "ORDER_ACCEPTED"
|
||||
ORDER_REJECTED = "ORDER_REJECTED"
|
||||
ORDER_PARTIALLY_FILLED = "ORDER_PARTIALLY_FILLED"
|
||||
ORDER_FILLED = "ORDER_FILLED"
|
||||
ORDER_CANCELED = "ORDER_CANCELED"
|
||||
ORDER_MODIFIED = "ORDER_MODIFIED"
|
||||
ORDER_EXPIRED = "ORDER_EXPIRED"
|
||||
|
||||
# Position events
|
||||
POSITION_OPENED = "POSITION_OPENED"
|
||||
POSITION_MODIFIED = "POSITION_MODIFIED"
|
||||
POSITION_CLOSED = "POSITION_CLOSED"
|
||||
|
||||
# Account events
|
||||
ACCOUNT_BALANCE_UPDATED = "ACCOUNT_BALANCE_UPDATED"
|
||||
MARGIN_CALL_WARNING = "MARGIN_CALL_WARNING"
|
||||
|
||||
# System events
|
||||
RECONCILIATION_FAILED = "RECONCILIATION_FAILED"
|
||||
CONNECTION_LOST = "CONNECTION_LOST"
|
||||
CONNECTION_RESTORED = "CONNECTION_RESTORED"
|
||||
|
||||
|
||||
class BaseEvent(BaseModel):
|
||||
"""Base class for all exchange kernel events"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
event_type: EventType = Field(description="Type of event")
|
||||
timestamp: Uint64 = Field(description="Event timestamp (Unix milliseconds)")
|
||||
exchange: str = Field(description="Exchange identifier")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Additional event data")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Order Events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class OrderEvent(BaseEvent):
|
||||
"""Base class for order-related events"""
|
||||
|
||||
intent_id: str = Field(description="Order intent ID")
|
||||
order_id: str | None = Field(default=None, description="Exchange order ID (if assigned)")
|
||||
symbol_id: str = Field(description="Symbol being traded")
|
||||
|
||||
|
||||
class OrderSubmitted(OrderEvent):
|
||||
"""Order has been submitted to the exchange"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_SUBMITTED)
|
||||
client_order_id: str | None = Field(default=None, description="Client-assigned order ID")
|
||||
|
||||
|
||||
class OrderAccepted(OrderEvent):
|
||||
"""Order has been accepted by the exchange"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_ACCEPTED)
|
||||
order_id: str = Field(description="Exchange-assigned order ID")
|
||||
accepted_at: Uint64 = Field(description="Exchange acceptance timestamp")
|
||||
|
||||
|
||||
class OrderRejected(OrderEvent):
|
||||
"""Order was rejected by the exchange"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_REJECTED)
|
||||
reason: str = Field(description="Rejection reason")
|
||||
error_code: str | None = Field(default=None, description="Exchange error code")
|
||||
|
||||
|
||||
class OrderPartiallyFilled(OrderEvent):
|
||||
"""Order was partially filled"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_PARTIALLY_FILLED)
|
||||
order_id: str = Field(description="Exchange order ID")
|
||||
fill_price: Float = Field(description="Fill price for this execution")
|
||||
fill_quantity: Float = Field(description="Quantity filled in this execution")
|
||||
total_filled: Float = Field(description="Total quantity filled so far")
|
||||
remaining_quantity: Float = Field(description="Remaining quantity to fill")
|
||||
commission: Float = Field(default=0.0, description="Commission/fee for this fill")
|
||||
commission_asset: str | None = Field(default=None, description="Asset used for commission")
|
||||
trade_id: str | None = Field(default=None, description="Exchange trade ID")
|
||||
|
||||
|
||||
class OrderFilled(OrderEvent):
|
||||
"""Order was completely filled"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_FILLED)
|
||||
order_id: str = Field(description="Exchange order ID")
|
||||
average_fill_price: Float = Field(description="Average execution price")
|
||||
total_quantity: Float = Field(description="Total quantity filled")
|
||||
total_commission: Float = Field(default=0.0, description="Total commission/fees")
|
||||
commission_asset: str | None = Field(default=None, description="Asset used for commission")
|
||||
completed_at: Uint64 = Field(description="Completion timestamp")
|
||||
|
||||
|
||||
class OrderCanceled(OrderEvent):
|
||||
"""Order was canceled"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_CANCELED)
|
||||
order_id: str = Field(description="Exchange order ID")
|
||||
reason: str = Field(description="Cancellation reason")
|
||||
filled_quantity: Float = Field(default=0.0, description="Quantity filled before cancellation")
|
||||
canceled_at: Uint64 = Field(description="Cancellation timestamp")
|
||||
|
||||
|
||||
class OrderModified(OrderEvent):
|
||||
"""Order was modified (price, quantity, etc.)"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_MODIFIED)
|
||||
order_id: str = Field(description="Exchange order ID")
|
||||
old_price: Float | None = Field(default=None, description="Previous price")
|
||||
new_price: Float | None = Field(default=None, description="New price")
|
||||
old_quantity: Float | None = Field(default=None, description="Previous quantity")
|
||||
new_quantity: Float | None = Field(default=None, description="New quantity")
|
||||
modified_at: Uint64 = Field(description="Modification timestamp")
|
||||
|
||||
|
||||
class OrderExpired(OrderEvent):
|
||||
"""Order expired (GTD, DAY orders)"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_EXPIRED)
|
||||
order_id: str = Field(description="Exchange order ID")
|
||||
filled_quantity: Float = Field(default=0.0, description="Quantity filled before expiration")
|
||||
expired_at: Uint64 = Field(description="Expiration timestamp")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Position Events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class PositionEvent(BaseEvent):
|
||||
"""Base class for position-related events"""
|
||||
|
||||
position_id: str = Field(description="Position identifier")
|
||||
symbol_id: str = Field(description="Symbol identifier")
|
||||
asset_id: str = Field(description="Asset identifier")
|
||||
|
||||
|
||||
class PositionOpened(PositionEvent):
|
||||
"""New position was opened"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.POSITION_OPENED)
|
||||
quantity: Float = Field(description="Position quantity")
|
||||
entry_price: Float = Field(description="Entry price")
|
||||
side: str = Field(description="LONG or SHORT")
|
||||
leverage: Float | None = Field(default=None, description="Leverage")
|
||||
|
||||
|
||||
class PositionModified(PositionEvent):
|
||||
"""Existing position was modified (size change, etc.)"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.POSITION_MODIFIED)
|
||||
old_quantity: Float = Field(description="Previous quantity")
|
||||
new_quantity: Float = Field(description="New quantity")
|
||||
average_entry_price: Float = Field(description="Updated average entry price")
|
||||
unrealized_pnl: Float | None = Field(default=None, description="Current unrealized P&L")
|
||||
|
||||
|
||||
class PositionClosed(PositionEvent):
|
||||
"""Position was closed"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.POSITION_CLOSED)
|
||||
exit_price: Float = Field(description="Exit price")
|
||||
realized_pnl: Float = Field(description="Realized profit/loss")
|
||||
closed_at: Uint64 = Field(description="Closure timestamp")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Account Events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class AccountEvent(BaseEvent):
|
||||
"""Base class for account-related events"""
|
||||
|
||||
account_id: str = Field(description="Account identifier")
|
||||
|
||||
|
||||
class AccountBalanceUpdated(AccountEvent):
|
||||
"""Account balance was updated"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ACCOUNT_BALANCE_UPDATED)
|
||||
asset_id: str = Field(description="Asset that changed")
|
||||
old_balance: Float = Field(description="Previous balance")
|
||||
new_balance: Float = Field(description="New balance")
|
||||
old_available: Float = Field(description="Previous available")
|
||||
new_available: Float = Field(description="New available")
|
||||
change_reason: str = Field(description="Why balance changed (TRADE, DEPOSIT, WITHDRAWAL, etc.)")
|
||||
|
||||
|
||||
class MarginCallWarning(AccountEvent):
|
||||
"""Margin level is approaching liquidation threshold"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.MARGIN_CALL_WARNING)
|
||||
margin_level: Float = Field(description="Current margin level")
|
||||
liquidation_threshold: Float = Field(description="Liquidation threshold")
|
||||
required_action: str = Field(description="Required action to avoid liquidation")
|
||||
estimated_liquidation_price: Float | None = Field(
|
||||
default=None,
|
||||
description="Estimated liquidation price for positions"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# System Events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ReconciliationFailed(BaseEvent):
|
||||
"""Failed to reconcile intent with actual state"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.RECONCILIATION_FAILED)
|
||||
intent_id: str = Field(description="Order intent ID")
|
||||
error_message: str = Field(description="Error details")
|
||||
retry_count: int = Field(description="Number of retry attempts")
|
||||
|
||||
|
||||
class ConnectionLost(BaseEvent):
|
||||
"""Connection to exchange was lost"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.CONNECTION_LOST)
|
||||
reason: str = Field(description="Disconnection reason")
|
||||
|
||||
|
||||
class ConnectionRestored(BaseEvent):
|
||||
"""Connection to exchange was restored"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.CONNECTION_RESTORED)
|
||||
downtime_duration: int = Field(description="Duration of downtime in milliseconds")
|
||||
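One way a consumer might route these events is a small dispatcher keyed on `event_type`. The sketch below uses plain dataclasses and a `str`-based `Enum` instead of the pydantic models so that it runs standalone; `EventBus`, `FillEvent`, and `record_fill` are hypothetical names and are not defined in this commit.

```python
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from typing import Callable


class EventType(str, Enum):
    # Subset of the event types defined in events.py.
    ORDER_PARTIALLY_FILLED = "ORDER_PARTIALLY_FILLED"
    ORDER_FILLED = "ORDER_FILLED"


@dataclass
class FillEvent:
    # Simplified stand-in for OrderPartiallyFilled / OrderFilled.
    event_type: EventType
    intent_id: str
    fill_quantity: float


class EventBus:
    """Hypothetical dispatcher: routes each event to handlers for its type."""

    def __init__(self) -> None:
        self._handlers = defaultdict(list)

    def subscribe(self, event_type: EventType, handler: Callable[[FillEvent], None]) -> None:
        self._handlers[event_type].append(handler)

    def publish(self, event: FillEvent) -> int:
        # Returns the number of handlers that were invoked.
        handlers = self._handlers.get(event.event_type, [])
        for handler in handlers:
            handler(event)
        return len(handlers)


filled_by_intent = defaultdict(float)


def record_fill(event: FillEvent) -> None:
    # Accumulate the filled quantity per intent.
    filled_by_intent[event.intent_id] += event.fill_quantity


bus = EventBus()
bus.subscribe(EventType.ORDER_PARTIALLY_FILLED, record_fill)
bus.publish(FillEvent(EventType.ORDER_PARTIALLY_FILLED, "intent-1", 0.25))
bus.publish(FillEvent(EventType.ORDER_PARTIALLY_FILLED, "intent-1", 0.5))
unhandled = bus.publish(FillEvent(EventType.ORDER_FILLED, "intent-1", 0.25))
```

Because no handler is subscribed to `ORDER_FILLED` here, the last `publish` call invokes nothing; a real consumer would likely register one handler per event class.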
194
backend/src/exchange_kernel/models.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
Data models for the Exchange Kernel.
|
||||
|
||||
Defines order intents, order state, positions, assets, and account state.
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..schema.order_spec import (
|
||||
StandardOrder,
|
||||
StandardOrderStatus,
|
||||
AssetType,
|
||||
Float,
|
||||
Uint64,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Order Intent and State
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class OrderIntent(BaseModel):
|
||||
"""
|
||||
Desired order state from the strategy kernel.
|
||||
|
||||
This represents what the strategy wants, not what currently exists.
|
||||
The exchange kernel will work to reconcile actual state with this intent.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
intent_id: str = Field(description="Unique identifier for this intent (client-assigned)")
|
||||
order: StandardOrder = Field(description="The desired order specification")
|
||||
group_id: str | None = Field(default=None, description="Group ID for OCO relationships")
|
||||
created_at: Uint64 = Field(description="When this intent was created")
|
||||
updated_at: Uint64 = Field(description="When this intent was last modified")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Strategy-specific metadata")
|
||||
|
||||
|
||||
class ReconciliationStatus(StrEnum):
|
||||
"""Status of reconciliation between intent and actual state"""
|
||||
PENDING = "PENDING" # Not yet submitted to exchange
|
||||
SUBMITTING = "SUBMITTING" # Currently being submitted
|
||||
ACTIVE = "ACTIVE" # Successfully placed on exchange
|
||||
RECONCILING = "RECONCILING" # Intent changed, updating exchange order
|
||||
FAILED = "FAILED" # Failed to submit or reconcile
|
||||
COMPLETED = "COMPLETED" # Order fully filled
|
||||
CANCELED = "CANCELED" # Order canceled
|
||||
|
||||
|
||||
class OrderState(BaseModel):
|
||||
"""
|
||||
Actual current state of an order on the exchange.
|
||||
|
||||
This represents reality - what the exchange reports about the order.
|
||||
May differ from OrderIntent during reconciliation.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
intent_id: str = Field(description="Links back to the OrderIntent")
|
||||
exchange_order_id: str = Field(description="Exchange-assigned order ID")
|
||||
status: StandardOrderStatus = Field(description="Current order status from exchange")
|
||||
reconciliation_status: ReconciliationStatus = Field(description="Reconciliation state")
|
||||
last_sync_at: Uint64 = Field(description="Last time we synced with exchange")
|
||||
error_message: str | None = Field(default=None, description="Error details if FAILED")
|
||||
|
||||
|
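The split between OrderIntent (desired) and OrderState (actual) suggests a diff step that decides what to do per intent. The following is a minimal sketch under simplifying assumptions: `DesiredOrder`, `LiveOrder`, and `plan_actions` are hypothetical, and orders are reduced to price and quantity. The commit itself leaves the reconciliation logic unimplemented.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DesiredOrder:
    # Simplified stand-in for OrderIntent (price and quantity only).
    intent_id: str
    price: float
    quantity: float


@dataclass(frozen=True)
class LiveOrder:
    # Simplified stand-in for OrderState.
    intent_id: str
    exchange_order_id: str
    price: float
    quantity: float


def plan_actions(desired: dict, live: dict) -> list:
    """Return sorted (action, intent_id) pairs that move live state toward intent."""
    actions = []
    for intent_id, want in desired.items():
        have = live.get(intent_id)
        if have is None:
            # Intent exists but nothing is on the exchange yet.
            actions.append(("place", intent_id))
        elif (have.price, have.quantity) != (want.price, want.quantity):
            # Order exists but no longer matches the intent.
            actions.append(("modify", intent_id))
    for intent_id in live:
        if intent_id not in desired:
            # Order is live but its intent was removed.
            actions.append(("cancel", intent_id))
    return sorted(actions)


desired = {
    "a": DesiredOrder("a", 100.0, 1.0),
    "b": DesiredOrder("b", 101.0, 2.0),
}
live = {
    "b": LiveOrder("b", "ex-1", 101.0, 1.0),
    "c": LiveOrder("c", "ex-2", 99.0, 1.0),
}
plan = plan_actions(desired, live)
```

Here `a` has no live order, `b` differs in quantity, and `c` has no intent, so the plan contains one action of each kind.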
||||
# ---------------------------------------------------------------------------
|
||||
# Position and Asset Models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class AssetMetadata(BaseModel):
|
||||
"""
|
||||
Metadata describing an asset type.
|
||||
|
||||
Provides context for positions, balances, and trading.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
asset_id: str = Field(description="Unique asset identifier")
|
||||
symbol: str = Field(description="Asset symbol (e.g., 'BTC', 'ETH', 'USD')")
|
||||
asset_type: AssetType = Field(description="Type of asset")
|
||||
name: str = Field(description="Full name")
|
||||
|
||||
# Contract specifications (for derivatives)
|
||||
contract_size: Float | None = Field(default=None, description="Contract multiplier")
|
||||
settlement_asset: str | None = Field(default=None, description="Settlement currency")
|
||||
expiry_timestamp: Uint64 | None = Field(default=None, description="Expiration timestamp")
|
||||
|
||||
# Trading parameters
|
||||
tick_size: Float | None = Field(default=None, description="Minimum price increment")
|
||||
lot_size: Float | None = Field(default=None, description="Minimum quantity increment")
|
||||
|
||||
# Margin requirements (for leveraged products)
|
||||
initial_margin_rate: Float | None = Field(default=None, description="Initial margin requirement")
|
||||
maintenance_margin_rate: Float | None = Field(default=None, description="Maintenance margin requirement")
|
||||
|
||||
# Additional metadata
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Exchange-specific metadata")
|
||||
|
||||
|
||||
class Asset(BaseModel):
|
||||
"""
|
||||
An asset holding (spot, margin, derivative position, etc.)
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
asset_id: str = Field(description="References AssetMetadata")
|
||||
quantity: Float = Field(description="Amount held (positive for long, negative for short positions)")
|
||||
available: Float = Field(description="Amount available for trading (not locked in orders)")
|
||||
locked: Float = Field(description="Amount locked in open orders")
|
||||
|
||||
# For derivative positions
|
||||
entry_price: Float | None = Field(default=None, description="Average entry price")
|
||||
mark_price: Float | None = Field(default=None, description="Current mark price")
|
||||
liquidation_price: Float | None = Field(default=None, description="Estimated liquidation price")
|
||||
unrealized_pnl: Float | None = Field(default=None, description="Unrealized profit/loss")
|
||||
realized_pnl: Float | None = Field(default=None, description="Realized profit/loss")
|
||||
|
||||
# Margin info
|
||||
margin_used: Float | None = Field(default=None, description="Margin allocated to this position")
|
||||
|
||||
updated_at: Uint64 = Field(description="Last update timestamp")
|
||||
|
||||
|
||||
class Position(BaseModel):
|
||||
"""
|
||||
A trading position (spot, margin, perpetual, futures, etc.)
|
||||
|
||||
Tracks both the asset holdings and associated metadata.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
position_id: str = Field(description="Unique position identifier")
|
||||
symbol_id: str = Field(description="Trading symbol")
|
||||
asset: Asset = Field(description="Asset holding details")
|
||||
metadata: AssetMetadata = Field(description="Asset metadata")
|
||||
|
||||
# Position-level info
|
||||
leverage: Float | None = Field(default=None, description="Current leverage")
|
||||
side: str | None = Field(default=None, description="LONG or SHORT (for derivatives)")
|
||||
|
||||
updated_at: Uint64 = Field(description="Last update timestamp")
|
||||
|
||||
|
||||
class Balance(BaseModel):
|
||||
"""Account balance for a single currency/asset"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
asset_id: str = Field(description="Asset identifier")
|
||||
total: Float = Field(description="Total balance")
|
||||
available: Float = Field(description="Available for trading")
|
||||
locked: Float = Field(description="Locked in orders/positions")
|
||||
|
||||
# For margin accounts
|
||||
borrowed: Float = Field(default=0.0, description="Borrowed amount (margin)")
|
||||
interest: Float = Field(default=0.0, description="Accrued interest")
|
||||
|
||||
updated_at: Uint64 = Field(description="Last update timestamp")
|
||||
|
||||
|
||||
class AccountState(BaseModel):
|
||||
"""
|
||||
Complete account state including balances, positions, and margin info.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
account_id: str = Field(description="Account identifier")
|
||||
exchange: str = Field(description="Exchange identifier")
|
||||
|
||||
balances: list[Balance] = Field(default_factory=list, description="All asset balances")
|
||||
positions: list[Position] = Field(default_factory=list, description="All open positions")
|
||||
|
||||
# Margin account info
|
||||
total_equity: Float | None = Field(default=None, description="Total account equity")
|
||||
total_margin_used: Float | None = Field(default=None, description="Total margin in use")
|
||||
total_available_margin: Float | None = Field(default=None, description="Available margin")
|
||||
margin_level: Float | None = Field(default=None, description="Margin level (equity/margin_used)")
|
||||
|
||||
# Risk metrics
|
||||
total_unrealized_pnl: Float | None = Field(default=None, description="Total unrealized P&L")
|
||||
total_realized_pnl: Float | None = Field(default=None, description="Total realized P&L")
|
||||
|
||||
updated_at: Uint64 = Field(description="Last update timestamp")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Exchange-specific data")
|
||||
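The `margin_level` field is described as equity divided by margin used. A minimal sketch of that ratio follows; the helper name `margin_level` as a free function and the choice to return `None` when no margin is in use are assumptions, since the commit does not say how that case is reported.

```python
from typing import Optional


def margin_level(
    total_equity: Optional[float],
    total_margin_used: Optional[float],
) -> Optional[float]:
    """Compute equity / margin_used, or None when it is undefined."""
    if total_equity is None or total_margin_used is None:
        # Non-margin accounts leave these fields unset.
        return None
    if total_margin_used == 0:
        # No margin in use: the ratio is undefined (assumed to map to None).
        return None
    return total_equity / total_margin_used


level = margin_level(1500.0, 500.0)
```

A consumer could compare the result against `liquidation_threshold` from MarginCallWarning, though the commit does not define that comparison.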
472
backend/src/exchange_kernel/state.py
Normal file
@@ -0,0 +1,472 @@
|
||||
"""
|
||||
State management for the Exchange Kernel.
|
||||
|
||||
Implements the storage and reconciliation logic for desired vs actual state.
|
||||
This is the "Kubernetes for orders" concept - maintaining intent and continuously
|
||||
reconciling reality to match intent.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from collections import defaultdict
|
||||
|
||||
from .models import OrderIntent, OrderState, ReconciliationStatus
|
||||
from ..schema.order_spec import Uint64
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Intent State Store - Desired State
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class IntentStateStore(ABC):
|
||||
"""
|
||||
Storage for order intents (desired state).
|
||||
|
||||
This represents what the strategy kernel wants. Intents are durable and
|
||||
persist across restarts. The reconciliation loop continuously works to
|
||||
make actual state match these intents.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def create_intent(self, intent: OrderIntent) -> None:
|
||||
"""
|
||||
Store a new order intent.
|
||||
|
||||
Args:
|
||||
intent: The order intent to store
|
||||
|
||||
Raises:
|
||||
AlreadyExistsError: If intent_id already exists
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_intent(self, intent_id: str) -> OrderIntent:
|
||||
"""
|
||||
Retrieve an order intent.
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID to retrieve
|
||||
|
||||
Returns:
|
||||
The order intent
|
||||
|
||||
Raises:
|
||||
NotFoundError: If intent_id doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def update_intent(self, intent: OrderIntent) -> None:
|
||||
"""
|
||||
Update an existing order intent.
|
||||
|
||||
Args:
|
||||
intent: Updated intent (intent_id must match existing)
|
||||
|
||||
Raises:
|
||||
NotFoundError: If intent_id doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete_intent(self, intent_id: str) -> None:
|
||||
"""
|
||||
Delete an order intent.
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID to delete
|
||||
|
||||
Raises:
|
||||
NotFoundError: If intent_id doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def list_intents(
|
||||
self,
|
||||
symbol_id: str | None = None,
|
||||
group_id: str | None = None,
|
||||
) -> list[OrderIntent]:
|
||||
"""
|
||||
List all order intents, optionally filtered.
|
||||
|
||||
Args:
|
||||
symbol_id: Filter by symbol
|
||||
group_id: Filter by OCO group
|
||||
|
||||
Returns:
|
||||
List of matching intents
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_intents_by_group(self, group_id: str) -> list[OrderIntent]:
|
||||
"""
|
||||
Get all intents in an OCO group.
|
||||
|
||||
Args:
|
||||
group_id: Group ID to query
|
||||
|
||||
Returns:
|
||||
List of intents in the group
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Actual State Store - Current Reality
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ActualStateStore(ABC):
|
||||
"""
|
||||
Storage for actual order state (reality on exchange).
|
||||
|
||||
This represents what actually exists on the exchange right now.
|
||||
Updated frequently from exchange feeds and order status queries.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def create_order_state(self, state: OrderState) -> None:
|
||||
"""
|
||||
Store a new order state.
|
||||
|
||||
Args:
|
||||
state: The order state to store
|
||||
|
||||
Raises:
|
||||
AlreadyExistsError: If order state for this intent_id already exists
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_order_state(self, intent_id: str) -> OrderState:
|
||||
"""
|
||||
Retrieve order state for an intent.
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID to query
|
||||
|
||||
Returns:
|
||||
The current order state
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no state exists for this intent
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_order_state_by_exchange_id(self, exchange_order_id: str) -> OrderState:
|
||||
"""
|
||||
Retrieve order state by exchange order ID.
|
||||
|
||||
Useful for processing exchange callbacks that only provide exchange_order_id.
|
||||
|
||||
Args:
|
||||
exchange_order_id: Exchange's order ID
|
||||
|
||||
Returns:
|
||||
The order state
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no state exists for this exchange order ID
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def update_order_state(self, state: OrderState) -> None:
|
||||
"""
|
||||
Update an existing order state.
|
||||
|
||||
Args:
|
||||
state: Updated state (intent_id must match existing)
|
||||
|
||||
Raises:
|
||||
NotFoundError: If state doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete_order_state(self, intent_id: str) -> None:
|
||||
"""
|
||||
Delete an order state.
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID whose state to delete
|
||||
|
||||
Raises:
|
||||
NotFoundError: If state doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def list_order_states(
|
||||
self,
|
||||
symbol_id: str | None = None,
|
||||
reconciliation_status: ReconciliationStatus | None = None,
|
||||
) -> list[OrderState]:
|
||||
"""
|
||||
List all order states, optionally filtered.
|
||||
|
||||
Args:
|
||||
symbol_id: Filter by symbol
|
||||
reconciliation_status: Filter by reconciliation status
|
||||
|
||||
Returns:
|
||||
List of matching order states
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_stale_orders(self, max_age_seconds: int) -> list[OrderState]:
|
||||
"""
|
||||
Find orders that haven't been synced recently.
|
||||
|
||||
Used to identify orders that need status updates from exchange.
|
||||
|
||||
Args:
|
||||
max_age_seconds: Maximum age since last sync
|
||||
|
||||
Returns:
|
||||
List of order states that need refresh
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# In-Memory Implementations (for testing/prototyping)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class InMemoryIntentStore(IntentStateStore):
|
||||
"""Simple in-memory implementation of IntentStateStore (raises ValueError/KeyError where the interface documents AlreadyExistsError/NotFoundError)"""
|
||||
|
||||
def __init__(self):
|
||||
self._intents: dict[str, OrderIntent] = {}
|
||||
self._by_symbol: dict[str, set[str]] = defaultdict(set)
|
||||
self._by_group: dict[str, set[str]] = defaultdict(set)
|
||||
|
||||
async def create_intent(self, intent: OrderIntent) -> None:
|
||||
if intent.intent_id in self._intents:
|
||||
raise ValueError(f"Intent {intent.intent_id} already exists")
|
||||
self._intents[intent.intent_id] = intent
|
||||
self._by_symbol[intent.order.symbol_id].add(intent.intent_id)
|
||||
if intent.group_id:
|
||||
self._by_group[intent.group_id].add(intent.intent_id)
|
||||
|
||||
async def get_intent(self, intent_id: str) -> OrderIntent:
|
||||
if intent_id not in self._intents:
|
||||
raise KeyError(f"Intent {intent_id} not found")
|
||||
return self._intents[intent_id]
|
||||
|
||||
async def update_intent(self, intent: OrderIntent) -> None:
|
||||
if intent.intent_id not in self._intents:
|
||||
raise KeyError(f"Intent {intent.intent_id} not found")
|
||||
old_intent = self._intents[intent.intent_id]
|
||||
|
||||
# Update indices if symbol or group changed
|
||||
if old_intent.order.symbol_id != intent.order.symbol_id:
|
||||
self._by_symbol[old_intent.order.symbol_id].discard(intent.intent_id)
|
||||
self._by_symbol[intent.order.symbol_id].add(intent.intent_id)
|
||||
|
||||
if old_intent.group_id != intent.group_id:
|
||||
if old_intent.group_id:
|
||||
self._by_group[old_intent.group_id].discard(intent.intent_id)
|
||||
if intent.group_id:
|
||||
self._by_group[intent.group_id].add(intent.intent_id)
|
||||
|
||||
self._intents[intent.intent_id] = intent
|
||||
|
||||
async def delete_intent(self, intent_id: str) -> None:
|
||||
if intent_id not in self._intents:
|
||||
raise KeyError(f"Intent {intent_id} not found")
|
||||
intent = self._intents[intent_id]
|
||||
self._by_symbol[intent.order.symbol_id].discard(intent_id)
|
||||
if intent.group_id:
|
||||
self._by_group[intent.group_id].discard(intent_id)
|
||||
del self._intents[intent_id]
|
||||
|
||||
async def list_intents(
|
||||
self,
|
||||
symbol_id: str | None = None,
|
||||
group_id: str | None = None,
|
||||
) -> list[OrderIntent]:
|
||||
if symbol_id and group_id:
|
||||
# Intersection of both filters
|
||||
symbol_ids = self._by_symbol.get(symbol_id, set())
|
||||
group_ids = self._by_group.get(group_id, set())
|
||||
intent_ids = symbol_ids & group_ids
|
||||
elif symbol_id:
|
||||
intent_ids = self._by_symbol.get(symbol_id, set())
|
||||
elif group_id:
|
||||
intent_ids = self._by_group.get(group_id, set())
|
||||
else:
|
||||
intent_ids = self._intents.keys()
|
||||
|
||||
return [self._intents[iid] for iid in intent_ids]
|
||||
|
||||
async def get_intents_by_group(self, group_id: str) -> list[OrderIntent]:
|
||||
intent_ids = self._by_group.get(group_id, set())
|
||||
return [self._intents[iid] for iid in intent_ids]
|
||||
|
||||
|
||||
class InMemoryActualStateStore(ActualStateStore):
|
||||
"""Simple in-memory implementation of ActualStateStore (raises ValueError/KeyError where the interface documents AlreadyExistsError/NotFoundError)"""
|
||||
|
||||
def __init__(self):
|
||||
self._states: dict[str, OrderState] = {}
|
||||
self._by_exchange_id: dict[str, str] = {} # exchange_order_id -> intent_id
|
||||
self._by_symbol: dict[str, set[str]] = defaultdict(set)
|
||||
|
||||
async def create_order_state(self, state: OrderState) -> None:
|
||||
if state.intent_id in self._states:
|
||||
raise ValueError(f"Order state for intent {state.intent_id} already exists")
|
||||
self._states[state.intent_id] = state
|
||||
self._by_exchange_id[state.exchange_order_id] = state.intent_id
|
||||
self._by_symbol[state.status.order.symbol_id].add(state.intent_id)
|
||||
|
||||
async def get_order_state(self, intent_id: str) -> OrderState:
|
||||
if intent_id not in self._states:
|
||||
raise KeyError(f"Order state for intent {intent_id} not found")
|
||||
return self._states[intent_id]
|
||||
|
||||
async def get_order_state_by_exchange_id(self, exchange_order_id: str) -> OrderState:
|
||||
if exchange_order_id not in self._by_exchange_id:
|
||||
raise KeyError(f"Order state for exchange order {exchange_order_id} not found")
|
||||
intent_id = self._by_exchange_id[exchange_order_id]
|
||||
return self._states[intent_id]
|
||||
|
||||
async def update_order_state(self, state: OrderState) -> None:
|
||||
if state.intent_id not in self._states:
|
||||
raise KeyError(f"Order state for intent {state.intent_id} not found")
|
||||
old_state = self._states[state.intent_id]
|
||||
|
||||
# Update exchange_id index if it changed
|
||||
if old_state.exchange_order_id != state.exchange_order_id:
|
||||
del self._by_exchange_id[old_state.exchange_order_id]
|
||||
self._by_exchange_id[state.exchange_order_id] = state.intent_id
|
||||
|
||||
# Update symbol index if it changed
|
||||
old_symbol = old_state.status.order.symbol_id
|
||||
new_symbol = state.status.order.symbol_id
|
||||
if old_symbol != new_symbol:
|
||||
self._by_symbol[old_symbol].discard(state.intent_id)
|
||||
self._by_symbol[new_symbol].add(state.intent_id)
|
||||
|
||||
self._states[state.intent_id] = state
|
||||
|
||||
async def delete_order_state(self, intent_id: str) -> None:
|
||||
if intent_id not in self._states:
|
||||
raise KeyError(f"Order state for intent {intent_id} not found")
|
||||
state = self._states[intent_id]
|
||||
del self._by_exchange_id[state.exchange_order_id]
|
||||
self._by_symbol[state.status.order.symbol_id].discard(intent_id)
|
||||
del self._states[intent_id]
|
||||
|
||||
async def list_order_states(
|
||||
self,
|
||||
symbol_id: str | None = None,
|
||||
reconciliation_status: ReconciliationStatus | None = None,
|
||||
) -> list[OrderState]:
|
||||
if symbol_id:
|
||||
intent_ids = self._by_symbol.get(symbol_id, set())
|
||||
states = [self._states[iid] for iid in intent_ids]
|
||||
else:
|
||||
states = list(self._states.values())
|
||||
|
||||
if reconciliation_status:
|
||||
states = [s for s in states if s.reconciliation_status == reconciliation_status]
|
||||
|
||||
return states
|
||||
|
||||
async def get_stale_orders(self, max_age_seconds: int) -> list[OrderState]:
|
||||
import time
|
||||
current_time = int(time.time())
|
||||
threshold = current_time - max_age_seconds
|
||||
|
||||
return [
|
||||
state
|
||||
for state in self._states.values()
|
||||
if state.last_sync_at < threshold
|
||||
]
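The store methods above keep two secondary indexes (exchange order ID and symbol) in step with the primary map. A minimal, self-contained sketch of that pattern follows. `Record` and `TinyOrderStore` are hypothetical, simplified stand-ins for the real `OrderState` model and store, which are not fully shown in this diff.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class Record:
    """Hypothetical, simplified stand-in for OrderState."""
    intent_id: str
    exchange_order_id: str
    symbol_id: str
    last_sync_at: int


class TinyOrderStore:
    """Primary dict plus two secondary indexes that are updated together."""

    def __init__(self) -> None:
        self._states: dict[str, Record] = {}
        self._by_exchange_id: dict[str, str] = {}
        self._by_symbol: dict[str, set[str]] = defaultdict(set)

    def create(self, rec: Record) -> None:
        if rec.intent_id in self._states:
            raise ValueError(f"{rec.intent_id} already exists")
        self._states[rec.intent_id] = rec
        self._by_exchange_id[rec.exchange_order_id] = rec.intent_id
        self._by_symbol[rec.symbol_id].add(rec.intent_id)

    def update(self, rec: Record) -> None:
        old = self._states[rec.intent_id]  # raises KeyError if unknown
        if old.exchange_order_id != rec.exchange_order_id:
            # pop() tolerates an index entry that is already gone
            self._by_exchange_id.pop(old.exchange_order_id, None)
            self._by_exchange_id[rec.exchange_order_id] = rec.intent_id
        if old.symbol_id != rec.symbol_id:
            self._by_symbol[old.symbol_id].discard(rec.intent_id)
            self._by_symbol[rec.symbol_id].add(rec.intent_id)
        self._states[rec.intent_id] = rec

    def by_symbol(self, symbol_id: str) -> list[Record]:
        return [self._states[i] for i in self._by_symbol.get(symbol_id, set())]

    def stale(self, now: int, max_age_seconds: int) -> list[Record]:
        threshold = now - max_age_seconds
        return [r for r in self._states.values() if r.last_sync_at < threshold]


store = TinyOrderStore()
store.create(Record("i1", "x1", "BTC/USD", last_sync_at=100))
store.create(Record("i2", "x2", "ETH/USD", last_sync_at=190))
store.update(Record("i1", "x9", "ETH/USD", last_sync_at=100))
```

Passing `now` into `stale()` instead of calling `time.time()` inside it is a choice made for this sketch only; it keeps the staleness check deterministic under test.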
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reconciliation Engine (framework only, no implementation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ReconciliationEngine:
|
||||
"""
|
||||
Reconciliation engine that continuously works to make actual state match intent.
|
||||
|
||||
This is the heart of the "Kubernetes for orders" concept. It:
|
||||
1. Compares desired state (intents) with actual state (exchange orders)
|
||||
2. Computes necessary actions (place, modify, cancel)
|
||||
3. Executes those actions via the exchange API
|
||||
4. Handles retries, errors, and edge cases
|
||||
|
||||
This is a framework class - concrete implementations will be exchange-specific.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
intent_store: IntentStateStore,
|
||||
actual_store: ActualStateStore,
|
||||
):
|
||||
"""
|
||||
Initialize the reconciliation engine.
|
||||
|
||||
Args:
|
||||
intent_store: Store for desired state
|
||||
actual_store: Store for actual state
|
||||
"""
|
||||
self.intent_store = intent_store
|
||||
self.actual_store = actual_store
|
||||
self._running = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the reconciliation loop"""
|
||||
self._running = True
|
||||
# Implementation would start async reconciliation loop here
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the reconciliation loop"""
|
||||
self._running = False
|
||||
# Implementation would stop reconciliation loop here
|
||||
pass
|
||||
|
||||
async def reconcile_intent(self, intent_id: str) -> None:
|
||||
"""
|
||||
Reconcile a specific intent.
|
||||
|
||||
Compares the intent with actual state and takes necessary actions.
|
||||
|
||||
Args:
|
||||
intent_id: Intent to reconcile
|
||||
"""
|
||||
# Framework only - concrete implementation needed
|
||||
pass
|
||||
|
||||
async def reconcile_all(self) -> None:
|
||||
"""
|
||||
Reconcile all intents.
|
||||
|
||||
Full reconciliation pass over all orders.
|
||||
"""
|
||||
# Framework only - concrete implementation needed
|
||||
pass
|
||||
|
||||
def get_metrics(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get reconciliation metrics.
|
||||
|
||||
Returns:
|
||||
Metrics about reconciliation performance, errors, etc.
|
||||
"""
|
||||
return {
|
||||
"running": self._running,
|
||||
"reconciliation_lag_ms": 0, # Framework only
|
||||
"pending_reconciliations": 0, # Framework only
|
||||
"error_count": 0, # Framework only
|
||||
"retry_count": 0, # Framework only
|
||||
}
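The `ReconciliationEngine` docstring describes comparing desired state with actual state and deriving place, modify, and cancel actions. Since the class is framework-only, the following is just one possible sketch of that diff step. `Action`, `plan_actions`, and the `(price, qty)` tuples are assumptions made for illustration and do not appear in the codebase.

```python
from typing import NamedTuple


class Action(NamedTuple):
    kind: str       # "place", "modify", or "cancel"
    order_id: str


def plan_actions(
    desired: dict[str, tuple[float, float]],
    actual: dict[str, tuple[float, float]],
) -> list[Action]:
    """Diff desired vs. actual orders; values are hypothetical (price, qty) pairs."""
    actions: list[Action] = []
    for order_id in sorted(desired.keys() | actual.keys()):
        want, have = desired.get(order_id), actual.get(order_id)
        if have is None:
            actions.append(Action("place", order_id))    # wanted but not on the exchange
        elif want is None:
            actions.append(Action("cancel", order_id))   # on the exchange but no longer wanted
        elif want != have:
            actions.append(Action("modify", order_id))   # present on both sides but different
    return actions


plan = plan_actions(
    desired={"a": (100.0, 1.0), "b": (101.0, 2.0)},
    actual={"b": (101.0, 1.5), "c": (99.0, 1.0)},
)
```

A real engine would additionally execute these actions through the exchange API and handle retries, which this sketch deliberately leaves out.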
|
||||
@@ -94,6 +94,11 @@ class Gateway:
|
||||
logger.info("Session is busy, interrupting existing task")
|
||||
await session.interrupt()
|
||||
|
||||
# Check if this is a stop interrupt (empty message)
|
||||
if not message.content.strip() and not message.attachments:
|
||||
logger.info("Received stop interrupt (empty message), not starting new agent round")
|
||||
return
|
||||
|
||||
# Add user message to history
|
||||
session.add_message("user", message.content, message.channel_id)
|
||||
logger.info(f"User message added to history, history length: {len(session.get_history())}")
|
||||
@@ -134,33 +139,55 @@ class Gateway:
|
||||
# Stream chunks back to active channels
|
||||
full_response = ""
|
||||
chunk_count = 0
|
||||
async for chunk in response_stream:
|
||||
chunk_count += 1
|
||||
full_response += chunk
|
||||
logger.debug(f"Received chunk #{chunk_count}, length: {len(chunk)}")
|
||||
accumulated_metadata = {}
|
||||
|
||||
# Send chunk to all active channels
|
||||
agent_msg = AgentMessage(
|
||||
session_id=session.session_id,
|
||||
target_channels=session.active_channels,
|
||||
content=chunk,
|
||||
stream_chunk=True,
|
||||
done=False
|
||||
)
|
||||
await self._send_to_channels(agent_msg)
|
||||
async for chunk in response_stream:
|
||||
# Handle dict response with metadata (from agent executor)
|
||||
if isinstance(chunk, dict):
|
||||
content = chunk.get("content", "")
|
||||
metadata = chunk.get("metadata", {})
|
||||
# Accumulate metadata (e.g., plot_urls)
|
||||
for key, value in metadata.items():
|
||||
if key == "plot_urls" and value:
|
||||
# Append to existing plot_urls
|
||||
if "plot_urls" not in accumulated_metadata:
|
||||
accumulated_metadata["plot_urls"] = []
|
||||
accumulated_metadata["plot_urls"].extend(value)
|
||||
logger.info(f"Accumulated plot_urls: {accumulated_metadata['plot_urls']}")
|
||||
elif key != "plot_urls":  # an empty plot_urls must not overwrite URLs collected so far
|
||||
accumulated_metadata[key] = value
|
||||
chunk = content
|
||||
|
||||
# Only send non-empty chunks
|
||||
if chunk:
|
||||
chunk_count += 1
|
||||
full_response += chunk
|
||||
logger.debug(f"Received chunk #{chunk_count}, length: {len(chunk)}")
|
||||
|
||||
# Send chunk to all active channels with accumulated metadata
|
||||
agent_msg = AgentMessage(
|
||||
session_id=session.session_id,
|
||||
target_channels=session.active_channels,
|
||||
content=chunk,
|
||||
stream_chunk=True,
|
||||
done=False,
|
||||
metadata=accumulated_metadata.copy()
|
||||
)
|
||||
await self._send_to_channels(agent_msg)
|
||||
|
||||
logger.info(f"Agent streaming completed, total chunks: {chunk_count}, response length: {len(full_response)}")
|
||||
|
||||
# Send final done message
|
||||
# Send final done message with all accumulated metadata
|
||||
agent_msg = AgentMessage(
|
||||
session_id=session.session_id,
|
||||
target_channels=session.active_channels,
|
||||
content="",
|
||||
stream_chunk=True,
|
||||
done=True
|
||||
done=True,
|
||||
metadata=accumulated_metadata
|
||||
)
|
||||
await self._send_to_channels(agent_msg)
|
||||
logger.info("Sent final done message to channels")
|
||||
logger.info(f"Sent final done message to channels with metadata: {accumulated_metadata}")
|
||||
|
||||
# Add to history
|
||||
session.add_message("assistant", full_response)
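The streaming loop in this hunk accepts either plain string chunks or dicts carrying `content` and `metadata`, and accumulates `plot_urls` across chunks. A standalone sketch of that accumulation rule is given below. `accumulate` is a hypothetical helper and the sample URLs are made up. In this sketch, an empty `plot_urls` value is skipped so it cannot clear URLs collected earlier.

```python
from typing import Any, Iterable, Union

Chunk = Union[str, dict[str, Any]]


def accumulate(chunks: Iterable[Chunk]) -> tuple[str, dict[str, Any]]:
    """Return (full_text, merged_metadata) for a stream of chunks."""
    text_parts: list[str] = []
    merged: dict[str, Any] = {}
    for chunk in chunks:
        if isinstance(chunk, dict):
            for key, value in chunk.get("metadata", {}).items():
                if key == "plot_urls":
                    if value:
                        # List-valued key: append rather than overwrite
                        merged.setdefault("plot_urls", []).extend(value)
                else:
                    merged[key] = value  # scalar keys: last writer wins
            chunk = chunk.get("content", "")
        if chunk:  # only non-empty text contributes to the response
            text_parts.append(chunk)
    return "".join(text_parts), merged


text, meta = accumulate([
    "Here is ",
    {"content": "the chart.", "metadata": {"plot_urls": ["/plots/a.png"], "model": "m1"}},
    {"content": "", "metadata": {"plot_urls": ["/plots/b.png"]}},
    {"content": "", "metadata": {"plot_urls": []}},
])
```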
|
||||
|
||||
172
backend/src/indicator/__init__.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
Composable Indicator System.
|
||||
|
||||
Provides a framework for building DAGs of data transformation pipelines
|
||||
that process time-series data incrementally. Indicators can consume
|
||||
DataSources or other Indicators as inputs, composing into arbitrarily
|
||||
complex processing graphs.
|
||||
|
||||
Key Components:
|
||||
---------------
|
||||
|
||||
Indicator (base.py):
|
||||
Abstract base class for all indicator implementations.
|
||||
Declares input/output schemas and implements synchronous compute().
|
||||
|
||||
IndicatorRegistry (registry.py):
|
||||
Central catalog of available indicators with rich metadata
|
||||
for AI agent discovery and tool generation.
|
||||
|
||||
Pipeline (pipeline.py):
|
||||
Execution engine that builds DAGs, resolves dependencies,
|
||||
and orchestrates incremental data flow through indicator chains.
|
||||
|
||||
Schema Types (schema.py):
|
||||
Type definitions for input/output schemas, computation context,
|
||||
and metadata for AI-native documentation.
|
||||
|
||||
Usage Example:
|
||||
--------------
|
||||
|
||||
from indicator import Indicator, IndicatorRegistry, Pipeline
|
||||
from indicator.schema import (
|
||||
InputSchema, OutputSchema, ComputeContext, ComputeResult,
|
||||
IndicatorMetadata, IndicatorParameter
|
||||
)
from datasource.schema import ColumnInfo  # used by the input/output schemas below
|
||||
|
||||
# Define an indicator
|
||||
class SimpleMovingAverage(Indicator):
|
||||
@classmethod
|
||||
def get_metadata(cls):
|
||||
return IndicatorMetadata(
|
||||
name="SMA",
|
||||
display_name="Simple Moving Average",
|
||||
description="Arithmetic mean of prices over N periods",
|
||||
category="trend",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="period",
|
||||
type="int",
|
||||
description="Number of periods to average",
|
||||
default=20,
|
||||
min_value=1
|
||||
)
|
||||
],
|
||||
tags=["moving-average", "trend-following"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls):
|
||||
return InputSchema(
|
||||
required_columns=[
|
||||
ColumnInfo(name="close", type="float", description="Closing price")
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params):
|
||||
return OutputSchema(
|
||||
columns=[
|
||||
ColumnInfo(
|
||||
name="sma",
|
||||
type="float",
|
||||
description=f"Simple moving average over {params.get('period', 20)} periods"
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
period = self.params["period"]
|
||||
closes = context.get_column("close")
|
||||
times = context.get_times()
|
||||
|
||||
sma_values = []
|
||||
for i in range(len(closes)):
|
||||
if i < period - 1:
|
||||
sma_values.append(None)
|
||||
else:
|
||||
window = closes[i - period + 1 : i + 1]
|
||||
sma_values.append(sum(window) / period)
|
||||
|
||||
return ComputeResult(
|
||||
data=[
|
||||
{"time": times[i], "sma": sma_values[i]}
|
||||
for i in range(len(times))
|
||||
]
|
||||
)
|
||||
|
||||
# Register the indicator
|
||||
registry = IndicatorRegistry()
|
||||
registry.register(SimpleMovingAverage)
|
||||
|
||||
# Create a pipeline
|
||||
pipeline = Pipeline(datasource_registry)
|
||||
pipeline.add_datasource("price_data", "ccxt", "BTC/USD", "1D")
|
||||
|
||||
sma_indicator = registry.create_instance("SMA", "sma_20", period=20)
|
||||
pipeline.add_indicator("sma_20", sma_indicator, input_node_ids=["price_data"])
|
||||
|
||||
# Execute
|
||||
results = pipeline.execute(datasource_data={"price_data": price_bars})
|
||||
sma_output = results["sma_20"] # Contains columns: time, close, sma_20_sma
|
||||
|
||||
Design Philosophy:
|
||||
------------------
|
||||
|
||||
1. **Schema-based composition**: Indicators declare inputs/outputs via schemas,
|
||||
enabling automatic validation and flexible composition.
|
||||
|
||||
2. **Synchronous execution**: All computation is synchronous for simplicity.
|
||||
Async handling happens at the event/strategy layer.
|
||||
|
||||
3. **Incremental updates**: Indicators receive context about what changed,
|
||||
allowing optimized recomputation of only affected values.
|
||||
|
||||
4. **AI-native metadata**: Rich descriptions, use cases, and parameter specs
|
||||
make indicators discoverable and usable by AI agents.
|
||||
|
||||
5. **Generic data flow**: Indicators work with any data source that matches
|
||||
their input schema, not specific DataSource instances.
|
||||
|
||||
6. **Event-driven**: Designed to react to DataSource updates and propagate
|
||||
changes through the DAG efficiently.
|
||||
"""
|
||||
|
||||
from .base import DataSourceAdapter, Indicator
|
||||
from .pipeline import Pipeline, PipelineNode
|
||||
from .registry import IndicatorRegistry
|
||||
from .schema import (
|
||||
ComputeContext,
|
||||
ComputeResult,
|
||||
IndicatorMetadata,
|
||||
IndicatorParameter,
|
||||
InputSchema,
|
||||
OutputSchema,
|
||||
)
|
||||
from .talib_adapter import (
|
||||
TALibIndicator,
|
||||
register_all_talib_indicators,
|
||||
is_talib_available,
|
||||
get_talib_version,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Core classes
|
||||
"Indicator",
|
||||
"IndicatorRegistry",
|
||||
"Pipeline",
|
||||
"PipelineNode",
|
||||
"DataSourceAdapter",
|
||||
# Schema types
|
||||
"InputSchema",
|
||||
"OutputSchema",
|
||||
"ComputeContext",
|
||||
"ComputeResult",
|
||||
"IndicatorMetadata",
|
||||
"IndicatorParameter",
|
||||
# TA-Lib integration
|
||||
"TALibIndicator",
|
||||
"register_all_talib_indicators",
|
||||
"is_talib_available",
|
||||
"get_talib_version",
|
||||
]
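The SMA in the module docstring re-sums every window, which costs O(n * period). As a hedged alternative, the sketch below keeps a running window sum instead. `sma` is a hypothetical standalone function that is not tied to the `Indicator` API, and it assumes the input contains no `None` values.

```python
from typing import Optional


def sma(closes: list[float], period: int) -> list[Optional[float]]:
    """Simple moving average using a running window sum (O(n))."""
    if period < 1:
        raise ValueError("period must be >= 1")
    out: list[Optional[float]] = []
    window_sum = 0.0
    for i, price in enumerate(closes):
        window_sum += price
        if i >= period:
            window_sum -= closes[i - period]  # drop the value leaving the window
        # The first period - 1 positions have no full window yet
        out.append(window_sum / period if i >= period - 1 else None)
    return out


values = sma([1.0, 2.0, 3.0, 4.0, 5.0], period=3)
```

A running sum can accumulate floating-point error over very long series, so periodically re-summing the window may be worthwhile in practice.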
|
||||
230
backend/src/indicator/base.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
Abstract Indicator interface.
|
||||
|
||||
Provides the base class for all technical indicators and derived data transformations.
|
||||
Indicators compose into DAGs, processing data incrementally as updates arrive.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .schema import (
|
||||
ComputeContext,
|
||||
ComputeResult,
|
||||
IndicatorMetadata,
|
||||
InputSchema,
|
||||
OutputSchema,
|
||||
)
|
||||
|
||||
|
||||
class Indicator(ABC):
|
||||
"""
|
||||
Abstract base class for all indicators.
|
||||
|
||||
Indicators are composable transformation nodes that:
|
||||
- Declare input schema (columns they need)
|
||||
- Declare output schema (columns they produce)
|
||||
- Compute outputs synchronously from inputs
|
||||
- Support incremental updates (process only what changed)
|
||||
- Provide rich metadata for AI agent discovery
|
||||
|
||||
Indicators are stateless at the instance level - all state is managed
|
||||
by the pipeline execution engine. This allows the same indicator class
|
||||
to be reused with different parameters.
|
||||
"""
|
||||
|
||||
def __init__(self, instance_name: str, **params):
|
||||
"""
|
||||
Initialize an indicator instance.
|
||||
|
||||
Args:
|
||||
instance_name: Unique name for this instance (used for output column prefixing)
|
||||
**params: Configuration parameters (validated against metadata.parameters)
|
||||
"""
|
||||
self.instance_name = instance_name
|
||||
self.params = params
|
||||
self._validate_params()
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
"""
|
||||
Get metadata for this indicator class.
|
||||
|
||||
Called by the registry for AI agent discovery and documentation.
|
||||
Should return comprehensive information about the indicator's purpose,
|
||||
parameters, and use cases.
|
||||
|
||||
Returns:
|
||||
IndicatorMetadata describing this indicator class
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
"""
|
||||
Get the input schema required by this indicator.
|
||||
|
||||
Declares what columns must be present in the input data.
|
||||
The pipeline will match this against available data sources.
|
||||
|
||||
Returns:
|
||||
InputSchema describing required and optional input columns
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
"""
|
||||
Get the output schema produced by this indicator.
|
||||
|
||||
Output column names will be automatically prefixed with the instance name
|
||||
by the pipeline engine.
|
||||
|
||||
Args:
|
||||
**params: Configuration parameters (may affect output schema)
|
||||
|
||||
Returns:
|
||||
OutputSchema describing the columns this indicator produces
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
"""
|
||||
Compute indicator values from input data.
|
||||
|
||||
This method is called synchronously by the pipeline engine whenever
|
||||
input data changes. Implementations should:
|
||||
|
||||
1. Extract needed columns from context.data
|
||||
2. Perform calculations
|
||||
3. Return results with proper time alignment
|
||||
|
||||
For incremental updates (context.is_incremental == True):
|
||||
- context.data contains only new/updated rows
|
||||
- Implementations MAY optimize by computing only these rows
|
||||
- OR implementations MAY recompute everything (simpler but slower)
|
||||
|
||||
Args:
|
||||
context: Input data and update metadata
|
||||
|
||||
Returns:
|
||||
ComputeResult with calculated indicator values
|
||||
|
||||
Raises:
|
||||
ValueError: If input data doesn't match expected schema
|
||||
"""
|
||||
pass
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
"""
|
||||
Validate that provided parameters match the metadata specification.
|
||||
|
||||
Raises:
|
||||
ValueError: If required parameters are missing or invalid
|
||||
"""
|
||||
metadata = self.get_metadata()
|
||||
|
||||
# Check for required parameters
|
||||
for param_def in metadata.parameters:
|
||||
if param_def.required and param_def.name not in self.params:
|
||||
raise ValueError(
|
||||
f"Indicator '{metadata.name}' requires parameter '{param_def.name}'"
|
||||
)
|
||||
|
||||
# Validate parameter types and ranges
|
||||
for name, value in self.params.items():
|
||||
# Find parameter definition
|
||||
param_def = next(
|
||||
(p for p in metadata.parameters if p.name == name),
|
||||
None
|
||||
)
|
||||
|
||||
if param_def is None:
|
||||
raise ValueError(
|
||||
f"Unknown parameter '{name}' for indicator '{metadata.name}'"
|
||||
)
|
||||
|
||||
# Type checking
|
||||
if param_def.type == "int" and (isinstance(value, bool) or not isinstance(value, int)):
|
||||
raise ValueError(
|
||||
f"Parameter '{name}' must be int, got {type(value).__name__}"
|
||||
)
|
||||
elif param_def.type == "float" and (isinstance(value, bool) or not isinstance(value, (int, float))):
|
||||
raise ValueError(
|
||||
f"Parameter '{name}' must be float, got {type(value).__name__}"
|
||||
)
|
||||
elif param_def.type == "bool" and not isinstance(value, bool):
|
||||
raise ValueError(
|
||||
f"Parameter '{name}' must be bool, got {type(value).__name__}"
|
||||
)
|
||||
elif param_def.type == "string" and not isinstance(value, str):
|
||||
raise ValueError(
|
||||
f"Parameter '{name}' must be string, got {type(value).__name__}"
|
||||
)
|
||||
|
||||
# Range checking for numeric types
|
||||
if param_def.type in ("int", "float"):
|
||||
if param_def.min_value is not None and value < param_def.min_value:
|
||||
raise ValueError(
|
||||
f"Parameter '{name}' must be >= {param_def.min_value}, got {value}"
|
||||
)
|
||||
if param_def.max_value is not None and value > param_def.max_value:
|
||||
raise ValueError(
|
||||
f"Parameter '{name}' must be <= {param_def.max_value}, got {value}"
|
||||
)
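Type checks built on `isinstance` have a Python subtlety worth keeping in mind: `bool` is a subclass of `int`, so `True` satisfies an `int` check unless it is excluded explicitly. The sketch below shows a type-and-range check that treats the two as distinct. `ParamSpec`, `check_param`, and `is_valid` are hypothetical and much smaller than `IndicatorParameter`. Unlike the method above, this sketch rejects unknown type names.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ParamSpec:
    """Hypothetical, reduced parameter definition."""
    name: str
    type: str                          # "int", "float", "bool", or "string"
    min_value: Optional[float] = None
    max_value: Optional[float] = None


def check_param(spec: ParamSpec, value: Any) -> None:
    """Raise ValueError if value does not satisfy spec."""
    is_bool = isinstance(value, bool)
    ok = {
        "int": isinstance(value, int) and not is_bool,
        "float": isinstance(value, (int, float)) and not is_bool,
        "bool": is_bool,
        "string": isinstance(value, str),
    }.get(spec.type, False)  # unknown type names fail the check
    if not ok:
        raise ValueError(
            f"Parameter '{spec.name}' must be {spec.type}, got {type(value).__name__}"
        )
    if spec.type in ("int", "float"):
        if spec.min_value is not None and value < spec.min_value:
            raise ValueError(f"Parameter '{spec.name}' must be >= {spec.min_value}")
        if spec.max_value is not None and value > spec.max_value:
            raise ValueError(f"Parameter '{spec.name}' must be <= {spec.max_value}")


def is_valid(spec: ParamSpec, value: Any) -> bool:
    """Convenience wrapper that turns the exception into a boolean."""
    try:
        check_param(spec, value)
        return True
    except ValueError:
        return False
```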
|
||||
|
||||
def get_output_columns(self) -> List[str]:
|
||||
"""
|
||||
Get the output column names with instance name prefix.
|
||||
|
||||
Returns:
|
||||
List of prefixed output column names
|
||||
"""
|
||||
output_schema = self.get_output_schema(**self.params)
|
||||
prefixed = output_schema.with_prefix(self.instance_name)
|
||||
return [col.name for col in prefixed.columns if col.name != output_schema.time_column]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(instance_name='{self.instance_name}', params={self.params})"
|
||||
|
||||
|
||||
class DataSourceAdapter:
|
||||
"""
|
||||
Adapter to make a DataSource look like an Indicator for pipeline composition.
|
||||
|
||||
This allows DataSources to be inputs to indicators in a unified way.
|
||||
"""
|
||||
|
||||
def __init__(self, datasource_id: str, symbol: str, resolution: str):
|
||||
"""
|
||||
Create a DataSource adapter.
|
||||
|
||||
Args:
|
||||
datasource_id: Identifier for the datasource (e.g., 'ccxt', 'demo')
|
||||
symbol: Symbol to query (e.g., 'BTC/USD')
|
||||
resolution: Time resolution (e.g., '1', '5', '1D')
|
||||
"""
|
||||
self.datasource_id = datasource_id
|
||||
self.symbol = symbol
|
||||
self.resolution = resolution
|
||||
self.instance_name = f"ds_{datasource_id}_{symbol}_{resolution}".replace("/", "_").replace(":", "_")
|
||||
|
||||
def get_output_columns(self) -> List[str]:
|
||||
"""
|
||||
Get the columns provided by this datasource.
|
||||
|
||||
Note: This requires runtime resolution - the pipeline engine
|
||||
will need to query the actual DataSource to get the schema.
|
||||
|
||||
Returns:
|
||||
List of column names (placeholder - needs runtime resolution)
|
||||
"""
|
||||
# This will be resolved at runtime by the pipeline engine
|
||||
return []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DataSourceAdapter(datasource='{self.datasource_id}', symbol='{self.symbol}', resolution='{self.resolution}')"
|
||||
439
backend/src/indicator/pipeline.py
Normal file
@@ -0,0 +1,439 @@
|
||||
"""
|
||||
Pipeline execution engine for composable indicators.
|
||||
|
||||
Manages DAG construction, dependency resolution, incremental updates,
|
||||
and efficient data flow through indicator chains.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict, deque
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from datasource.base import DataSource
|
||||
from datasource.schema import ColumnInfo
|
||||
|
||||
from .base import DataSourceAdapter, Indicator
|
||||
from .schema import ComputeContext, ComputeResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PipelineNode:
|
||||
"""
|
||||
A node in the pipeline DAG.
|
||||
|
||||
Can be either a DataSource adapter or an Indicator instance.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
node: Union[DataSourceAdapter, Indicator],
|
||||
dependencies: List[str]
|
||||
):
|
||||
"""
|
||||
Create a pipeline node.
|
||||
|
||||
Args:
|
||||
node_id: Unique identifier for this node
|
||||
node: The DataSourceAdapter or Indicator instance
|
||||
dependencies: List of node_ids this node depends on
|
||||
"""
|
||||
self.node_id = node_id
|
||||
self.node = node
|
||||
self.dependencies = dependencies
|
||||
self.output_columns: List[str] = []
|
||||
self.cached_data: List[Dict[str, Any]] = []
|
||||
|
||||
def is_datasource(self) -> bool:
|
||||
"""Check if this node is a DataSource adapter."""
|
||||
return isinstance(self.node, DataSourceAdapter)
|
||||
|
||||
def is_indicator(self) -> bool:
|
||||
"""Check if this node is an Indicator."""
|
||||
return isinstance(self.node, Indicator)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"PipelineNode(id='{self.node_id}', node={self.node}, deps={self.dependencies})"
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""
|
||||
Execution engine for indicator DAGs.
|
||||
|
||||
Manages:
|
||||
- DAG construction and validation
|
||||
- Topological sorting for execution order
|
||||
- Data flow and caching
|
||||
- Incremental updates (only recompute what changed)
|
||||
- Schema validation
|
||||
"""
|
||||
|
||||
def __init__(self, datasource_registry):
|
||||
"""
|
||||
Initialize a pipeline.
|
||||
|
||||
Args:
|
||||
datasource_registry: DataSourceRegistry for resolving data sources
|
||||
"""
|
||||
self.datasource_registry = datasource_registry
|
||||
self.nodes: Dict[str, PipelineNode] = {}
|
||||
self.execution_order: List[str] = []
|
||||
self._dirty_nodes: Set[str] = set()
|
||||
|
||||
def add_datasource(
|
||||
self,
|
||||
node_id: str,
|
||||
datasource_name: str,
|
||||
symbol: str,
|
||||
resolution: str
|
||||
) -> None:
|
||||
"""
|
||||
Add a DataSource to the pipeline.
|
||||
|
||||
Args:
|
||||
node_id: Unique identifier for this node
|
||||
datasource_name: Name of the datasource in the registry
|
||||
symbol: Symbol to query
|
||||
resolution: Time resolution
|
||||
|
||||
Raises:
|
||||
ValueError: If node_id already exists or datasource not found
|
||||
"""
|
||||
if node_id in self.nodes:
|
||||
raise ValueError(f"Node '{node_id}' already exists in pipeline")
|
||||
|
||||
datasource = self.datasource_registry.get(datasource_name)
|
||||
if not datasource:
|
||||
raise ValueError(f"DataSource '{datasource_name}' not found in registry")
|
||||
|
||||
adapter = DataSourceAdapter(datasource_name, symbol, resolution)
|
||||
node = PipelineNode(node_id, adapter, dependencies=[])
|
||||
|
||||
self.nodes[node_id] = node
|
||||
self._invalidate_execution_order()
|
||||
|
||||
logger.info(f"Added DataSource node '{node_id}': {datasource_name}/{symbol}@{resolution}")
|
||||
|
||||
def add_indicator(
|
||||
self,
|
||||
node_id: str,
|
||||
indicator: Indicator,
|
||||
input_node_ids: List[str]
|
||||
) -> None:
|
||||
"""
|
||||
Add an Indicator to the pipeline.
|
||||
|
||||
Args:
|
||||
node_id: Unique identifier for this node
|
||||
indicator: Indicator instance
|
||||
input_node_ids: List of node IDs providing input data
|
||||
|
||||
Raises:
|
||||
ValueError: If node_id already exists, dependencies not found, or schema mismatch
|
||||
"""
|
||||
if node_id in self.nodes:
|
||||
raise ValueError(f"Node '{node_id}' already exists in pipeline")
|
||||
|
||||
# Validate dependencies exist
|
||||
for dep_id in input_node_ids:
|
||||
if dep_id not in self.nodes:
|
||||
raise ValueError(f"Dependency node '{dep_id}' not found in pipeline")
|
||||
|
||||
# TODO: Validate input schema matches available columns from dependencies
|
||||
# This requires merging output schemas from all input nodes
|
||||
|
||||
node = PipelineNode(node_id, indicator, dependencies=input_node_ids)
|
||||
self.nodes[node_id] = node
|
||||
self._invalidate_execution_order()
|
||||
|
||||
logger.info(f"Added Indicator node '{node_id}': {indicator} with inputs {input_node_ids}")
|
||||
|
||||
def remove_node(self, node_id: str) -> None:
|
||||
"""
|
||||
Remove a node from the pipeline.
|
||||
|
||||
Args:
|
||||
node_id: Node to remove
|
||||
|
||||
Raises:
|
||||
ValueError: If other nodes depend on this node
|
||||
"""
|
||||
if node_id not in self.nodes:
|
||||
return
|
||||
|
||||
# Check for dependent nodes
|
||||
dependents = [
|
||||
n.node_id for n in self.nodes.values()
|
||||
if node_id in n.dependencies
|
||||
]
|
||||
|
||||
if dependents:
|
||||
raise ValueError(
|
||||
f"Cannot remove node '{node_id}': nodes {dependents} depend on it"
|
||||
)
|
||||
|
||||
del self.nodes[node_id]
|
||||
self._invalidate_execution_order()
|
||||
|
||||
logger.info(f"Removed node '{node_id}' from pipeline")
|
||||
|
||||
def _invalidate_execution_order(self) -> None:
|
||||
"""Mark execution order as needing recomputation."""
|
||||
self.execution_order = []
|
||||
|
||||
def _compute_execution_order(self) -> List[str]:
|
||||
"""
|
||||
Compute topological sort of the DAG.
|
||||
|
||||
Returns:
|
||||
List of node IDs in execution order
|
||||
|
||||
Raises:
|
||||
ValueError: If DAG contains cycles
|
||||
"""
|
||||
if self.execution_order:
|
||||
return self.execution_order
|
||||
|
||||
# Kahn's algorithm for topological sort
|
||||
in_degree = {node_id: 0 for node_id in self.nodes}
|
||||
for node in self.nodes.values():
|
||||
for dep in node.dependencies:
|
||||
in_degree[node.node_id] += 1
|
||||
|
||||
queue = deque([node_id for node_id, degree in in_degree.items() if degree == 0])
|
||||
result = []
|
||||
|
||||
while queue:
|
||||
node_id = queue.popleft()
|
||||
result.append(node_id)
|
||||
|
||||
# Find all nodes that depend on this one
|
||||
for other_node in self.nodes.values():
|
||||
if node_id in other_node.dependencies:
|
||||
in_degree[other_node.node_id] -= 1
|
||||
if in_degree[other_node.node_id] == 0:
|
||||
queue.append(other_node.node_id)
|
||||
|
||||
if len(result) != len(self.nodes):
|
||||
raise ValueError("Pipeline contains cycles")
|
||||
|
||||
self.execution_order = result
|
||||
logger.debug(f"Computed execution order: {result}")
|
||||
return result
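The topological sort above scans every node to find dependents each time one is dequeued, which is quadratic in the number of nodes. For larger graphs, a common variant of Kahn's algorithm precomputes a dependents map first. The sketch below works on plain dicts. `topo_order` is a hypothetical function, not part of `Pipeline`, and it also tolerates duplicate entries in a dependency list.

```python
from collections import defaultdict, deque


def topo_order(deps: dict[str, list[str]]) -> list[str]:
    """Kahn's algorithm. deps maps each node id to the ids it depends on."""
    dependents: dict[str, list[str]] = defaultdict(list)
    in_degree: dict[str, int] = {}
    for node_id, node_deps in deps.items():
        unique = set(node_deps)  # count each dependency once
        in_degree[node_id] = len(unique)
        for dep in unique:
            if dep not in deps:
                raise ValueError(f"Unknown dependency '{dep}'")
            dependents[dep].append(node_id)

    # Start with nodes that depend on nothing (data sources)
    queue = deque(n for n, degree in in_degree.items() if degree == 0)
    order: list[str] = []
    while queue:
        node_id = queue.popleft()
        order.append(node_id)
        for child in dependents[node_id]:
            in_degree[child] -= 1
            if in_degree[child] == 0:
                queue.append(child)

    # Any node left unprocessed sits on a cycle
    if len(order) != len(deps):
        raise ValueError("Pipeline contains cycles")
    return order


order = topo_order({
    "price": [],
    "sma_20": ["price"],
    "sma_50": ["price"],
    "cross": ["sma_20", "sma_50"],
})
```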
|
||||
|
||||
def execute(
|
||||
self,
|
||||
datasource_data: Dict[str, List[Dict[str, Any]]],
|
||||
incremental: bool = False,
|
||||
updated_from_time: Optional[int] = None
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Execute the pipeline.
|
||||
|
||||
Args:
|
||||
datasource_data: Mapping of DataSource node_id to input data
|
||||
incremental: Whether this is an incremental update
|
||||
updated_from_time: Timestamp of earliest updated row (for incremental)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping node_id to output data (all nodes)
|
||||
|
||||
Raises:
|
||||
ValueError: If required datasource data is missing
|
||||
"""
|
||||
execution_order = self._compute_execution_order()
|
||||
results: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
logger.info(
|
||||
f"Executing pipeline with {len(execution_order)} nodes "
|
||||
f"(incremental={incremental})"
|
||||
)
|
||||
|
||||
for node_id in execution_order:
|
||||
node = self.nodes[node_id]
|
||||
|
||||
if node.is_datasource():
|
||||
# DataSource node - get data from input
|
||||
if node_id not in datasource_data:
|
||||
raise ValueError(
|
||||
f"DataSource node '{node_id}' has no input data"
|
||||
)
|
||||
results[node_id] = datasource_data[node_id]
|
||||
node.cached_data = results[node_id]
|
||||
logger.debug(f"DataSource node '{node_id}': {len(results[node_id])} rows")
|
||||
|
||||
elif node.is_indicator():
|
||||
# Indicator node - compute from dependencies
|
||||
indicator = node.node
|
||||
|
||||
# Merge input data from all dependencies
|
||||
input_data = self._merge_dependency_data(node.dependencies, results)
|
||||
|
||||
# Create compute context
|
||||
context = ComputeContext(
|
||||
data=input_data,
|
||||
is_incremental=incremental,
|
||||
updated_from_time=updated_from_time
|
||||
)
|
||||
|
||||
# Execute indicator
|
||||
logger.debug(
|
||||
f"Computing indicator '{node_id}' with {len(input_data)} input rows"
|
||||
)
|
||||
compute_result = indicator.compute(context)
|
||||
|
||||
# Merge result with input data (adding prefixed columns)
|
||||
output_data = compute_result.merge_with_prefix(
|
||||
indicator.instance_name,
|
||||
input_data
|
||||
)
|
||||
|
||||
results[node_id] = output_data
|
||||
node.cached_data = output_data
|
||||
logger.debug(f"Indicator node '{node_id}': {len(output_data)} rows")
|
||||
|
||||
logger.info(f"Pipeline execution complete: {len(results)} nodes processed")
|
||||
return results
|
||||
|
||||
def _merge_dependency_data(
|
||||
self,
|
||||
dependency_ids: List[str],
|
||||
results: Dict[str, List[Dict[str, Any]]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Merge data from multiple dependency nodes.
|
||||
|
||||
Data is merged by time, with later dependencies overwriting earlier ones
|
||||
for conflicting column names.
|
||||
|
||||
Args:
|
||||
dependency_ids: List of node IDs to merge
|
||||
results: Current execution results
|
||||
|
||||
Returns:
|
||||
Merged data rows
|
||||
"""
|
||||
if not dependency_ids:
|
||||
return []
|
||||
|
||||
if len(dependency_ids) == 1:
|
||||
return results[dependency_ids[0]]
|
||||
|
||||
# Build time-indexed data from first dependency
|
||||
merged: Dict[int, Dict[str, Any]] = {}
|
||||
for row in results[dependency_ids[0]]:
|
||||
merged[row["time"]] = row.copy()
|
||||
|
||||
# Merge in additional dependencies
|
||||
for dep_id in dependency_ids[1:]:
|
||||
for row in results[dep_id]:
|
||||
time_key = row["time"]
|
||||
if time_key in merged:
|
||||
# Merge columns (later dependencies win)
|
||||
merged[time_key].update(row)
|
||||
else:
|
||||
# New timestamp
|
||||
merged[time_key] = row.copy()
|
||||
|
||||
# Sort by time and return
|
||||
sorted_times = sorted(merged.keys())
|
||||
return [merged[t] for t in sorted_times]
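The merge above is effectively an outer join on the `time` key in which later dependencies overwrite earlier ones. A compact standalone sketch of the same rule, with made-up sample rows, is shown below. `merge_by_time` is a hypothetical function rather than a `Pipeline` method.

```python
from typing import Any

Row = dict[str, Any]


def merge_by_time(*sources: list[Row]) -> list[Row]:
    """Outer-join rows on their "time" key; later sources overwrite earlier ones."""
    merged: dict[int, Row] = {}
    for rows in sources:
        for row in rows:
            # setdefault creates a fresh dict, so input rows are never mutated
            merged.setdefault(row["time"], {}).update(row)
    return [merged[t] for t in sorted(merged)]


prices = [{"time": 1, "close": 10.0}, {"time": 2, "close": 11.0}]
fast = [{"time": 2, "close": 11.0, "fast_sma": 10.5}]
slow = [{"time": 3, "slow_sma": 10.2}]
rows = merge_by_time(prices, fast, slow)
```

Timestamps that appear in only one source keep only that source's columns, so downstream indicators should expect missing keys in such rows.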
|
||||
|
||||
def get_node_output(self, node_id: str) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
Get cached output data for a specific node.
|
||||
|
||||
Args:
|
||||
node_id: Node identifier
|
||||
|
||||
Returns:
|
||||
Cached data or None if not available
|
||||
"""
|
||||
node = self.nodes.get(node_id)
|
||||
return node.cached_data if node else None
|
||||
|
||||
def get_output_schema(self, node_id: str) -> List[ColumnInfo]:
|
||||
"""
|
||||
Get the output schema for a specific node.
|
||||
|
||||
Args:
|
||||
node_id: Node identifier
|
||||
|
||||
Returns:
|
||||
List of ColumnInfo describing output columns
|
||||
|
||||
Raises:
|
||||
ValueError: If node not found
|
||||
"""
|
||||
node = self.nodes.get(node_id)
|
||||
if not node:
|
||||
raise ValueError(f"Node '{node_id}' not found")
|
||||
|
||||
if node.is_datasource():
|
||||
# Would need to query the actual datasource at runtime
|
||||
# For now, return empty - this requires integration with DataSource
|
||||
return []
|
||||
|
||||
elif node.is_indicator():
|
||||
indicator = node.node
|
||||
output_schema = indicator.get_output_schema(**indicator.params)
|
||||
prefixed_schema = output_schema.with_prefix(indicator.instance_name)
|
||||
return prefixed_schema.columns
|
||||
|
||||
return []
|
||||
|
||||
def validate_pipeline(self) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Validate the entire pipeline for correctness.
|
||||
|
||||
Checks:
|
||||
- No cycles (already checked in execution order)
|
||||
- All dependencies exist (already checked in add_indicator)
|
||||
- Input schemas match output schemas (TODO)
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
try:
|
||||
self._compute_execution_order()
|
||||
return True, None
|
||||
except ValueError as e:
|
||||
return False, str(e)
|
||||
|
||||
def get_node_count(self) -> int:
|
||||
"""Get the number of nodes in the pipeline."""
|
||||
return len(self.nodes)
|
||||
|
||||
def get_indicator_count(self) -> int:
|
||||
"""Get the number of indicator nodes in the pipeline."""
|
||||
return sum(1 for node in self.nodes.values() if node.is_indicator())
|
||||
|
||||
def get_datasource_count(self) -> int:
|
||||
"""Get the number of datasource nodes in the pipeline."""
|
||||
return sum(1 for node in self.nodes.values() if node.is_datasource())
|
||||
|
||||
def describe(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get a detailed description of the pipeline structure.
|
||||
|
||||
Returns:
|
||||
Dictionary with pipeline metadata and structure
|
||||
"""
|
||||
return {
|
||||
"node_count": self.get_node_count(),
|
||||
"datasource_count": self.get_datasource_count(),
|
||||
"indicator_count": self.get_indicator_count(),
|
||||
"nodes": [
|
||||
{
|
||||
"id": node.node_id,
|
||||
"type": "datasource" if node.is_datasource() else "indicator",
|
||||
"node": str(node.node),
|
||||
"dependencies": node.dependencies,
|
||||
                    # Guard against nodes that have not been executed yet (cached_data may be None)
                    "cached_rows": len(node.cached_data or [])
|
||||
}
|
||||
for node in self.nodes.values()
|
||||
],
|
||||
"execution_order": self.execution_order or self._compute_execution_order(),
|
||||
"is_valid": self.validate_pipeline()[0]
|
||||
}
|
||||
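The time-keyed merge used by `_merge_dependency_data` can be illustrated in isolation. The following is a minimal sketch, not the pipeline's actual code: `merge_by_time` is a hypothetical helper name, and the column name `sma_20_real` is only an assumed example of a prefixed indicator output.

```python
from typing import Any, Dict, List


def merge_by_time(sources: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
    """Merge row lists on the "time" key; later sources overwrite earlier ones."""
    merged: Dict[int, Dict[str, Any]] = {}
    for rows in sources:
        for row in rows:
            # setdefault creates an empty row the first time a timestamp is seen,
            # update() then lets later sources win on conflicting column names.
            merged.setdefault(row["time"], {}).update(row)
    # Return rows in ascending time order
    return [merged[t] for t in sorted(merged)]


# Hypothetical inputs: a price source and one indicator output
ohlc = [{"time": 1, "close": 10.0}, {"time": 2, "close": 11.0}]
sma = [{"time": 2, "sma_20_real": 10.5}, {"time": 3, "sma_20_real": 10.8}]
rows = merge_by_time([ohlc, sma])
```

Unlike the method above, this sketch always copies rows, so even a single-source merge never hands back the caller's own list. Timestamps that appear in only one source produce rows with missing columns, which downstream indicators would need to tolerate.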
349
backend/src/indicator/registry.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""
|
||||
Indicator registry for managing and discovering indicators.
|
||||
|
||||
Provides AI agents with a queryable catalog of available indicators,
|
||||
their capabilities, and metadata.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from .base import Indicator
|
||||
from .schema import IndicatorMetadata, InputSchema, OutputSchema
|
||||
|
||||
|
||||
class IndicatorRegistry:
|
||||
"""
|
||||
Central registry for indicator classes.
|
||||
|
||||
Enables:
|
||||
- Registration of indicator implementations
|
||||
- Discovery by name, category, or tags
|
||||
- Schema validation
|
||||
- AI agent tool generation
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._indicators: Dict[str, Type[Indicator]] = {}
|
||||
|
||||
def register(self, indicator_class: Type[Indicator]) -> None:
|
||||
"""
|
||||
Register an indicator class.
|
||||
|
||||
Args:
|
||||
indicator_class: Indicator class to register
|
||||
|
||||
Raises:
|
||||
ValueError: If an indicator with this name is already registered
|
||||
"""
|
||||
metadata = indicator_class.get_metadata()
|
||||
|
||||
if metadata.name in self._indicators:
|
||||
raise ValueError(
|
||||
f"Indicator '{metadata.name}' is already registered"
|
||||
)
|
||||
|
||||
self._indicators[metadata.name] = indicator_class
|
||||
|
||||
def unregister(self, name: str) -> None:
|
||||
"""
|
||||
Unregister an indicator class.
|
||||
|
||||
Args:
|
||||
name: Indicator class name
|
||||
"""
|
||||
self._indicators.pop(name, None)
|
||||
|
||||
def get(self, name: str) -> Optional[Type[Indicator]]:
|
||||
"""
|
||||
Get an indicator class by name.
|
||||
|
||||
Args:
|
||||
name: Indicator class name
|
||||
|
||||
Returns:
|
||||
Indicator class or None if not found
|
||||
"""
|
||||
return self._indicators.get(name)
|
||||
|
||||
def list_indicators(self) -> List[str]:
|
||||
"""
|
||||
Get names of all registered indicators.
|
||||
|
||||
Returns:
|
||||
List of indicator class names
|
||||
"""
|
||||
return list(self._indicators.keys())
|
||||
|
||||
def get_metadata(self, name: str) -> Optional[IndicatorMetadata]:
|
||||
"""
|
||||
Get metadata for a specific indicator.
|
||||
|
||||
Args:
|
||||
name: Indicator class name
|
||||
|
||||
Returns:
|
||||
IndicatorMetadata or None if not found
|
||||
"""
|
||||
indicator_class = self.get(name)
|
||||
if indicator_class:
|
||||
return indicator_class.get_metadata()
|
||||
return None
|
||||
|
||||
def get_all_metadata(self) -> List[IndicatorMetadata]:
|
||||
"""
|
||||
Get metadata for all registered indicators.
|
||||
|
||||
Useful for AI agent tool generation and discovery.
|
||||
|
||||
Returns:
|
||||
List of IndicatorMetadata for all registered indicators
|
||||
"""
|
||||
return [cls.get_metadata() for cls in self._indicators.values()]
|
||||
|
||||
def search_by_category(self, category: str) -> List[IndicatorMetadata]:
|
||||
"""
|
||||
Find indicators by category.
|
||||
|
||||
Args:
|
||||
category: Category name (e.g., 'momentum', 'trend', 'volatility')
|
||||
|
||||
Returns:
|
||||
List of matching indicator metadata
|
||||
"""
|
||||
results = []
|
||||
for indicator_class in self._indicators.values():
|
||||
metadata = indicator_class.get_metadata()
|
||||
if metadata.category.lower() == category.lower():
|
||||
results.append(metadata)
|
||||
return results
|
||||
|
||||
def search_by_tag(self, tag: str) -> List[IndicatorMetadata]:
|
||||
"""
|
||||
Find indicators by tag.
|
||||
|
||||
Args:
|
||||
tag: Tag to search for (case-insensitive)
|
||||
|
||||
Returns:
|
||||
List of matching indicator metadata
|
||||
"""
|
||||
tag_lower = tag.lower()
|
||||
results = []
|
||||
for indicator_class in self._indicators.values():
|
||||
metadata = indicator_class.get_metadata()
|
||||
if any(t.lower() == tag_lower for t in metadata.tags):
|
||||
results.append(metadata)
|
||||
return results
|
||||
|
||||
def search_by_text(self, query: str) -> List[IndicatorMetadata]:
|
||||
"""
|
||||
Full-text search across indicator names, descriptions, and use cases.
|
||||
|
||||
Args:
|
||||
query: Search query (case-insensitive)
|
||||
|
||||
Returns:
|
||||
List of matching indicator metadata, ranked by relevance
|
||||
"""
|
||||
query_lower = query.lower()
|
||||
results = []
|
||||
|
||||
for indicator_class in self._indicators.values():
|
||||
metadata = indicator_class.get_metadata()
|
||||
score = 0
|
||||
|
||||
# Check name (highest weight)
|
||||
if query_lower in metadata.name.lower():
|
||||
score += 10
|
||||
if query_lower in metadata.display_name.lower():
|
||||
score += 8
|
||||
|
||||
# Check description
|
||||
if query_lower in metadata.description.lower():
|
||||
score += 5
|
||||
|
||||
# Check use cases
|
||||
for use_case in metadata.use_cases:
|
||||
if query_lower in use_case.lower():
|
||||
score += 3
|
||||
|
||||
# Check tags
|
||||
for tag in metadata.tags:
|
||||
if query_lower in tag.lower():
|
||||
score += 2
|
||||
|
||||
if score > 0:
|
||||
results.append((score, metadata))
|
||||
|
||||
# Sort by score descending
|
||||
results.sort(key=lambda x: x[0], reverse=True)
|
||||
return [metadata for _, metadata in results]
|
||||
|
||||
def find_compatible_indicators(
|
||||
self,
|
||||
available_columns: List[str],
|
||||
column_types: Dict[str, str]
|
||||
) -> List[IndicatorMetadata]:
|
||||
"""
|
||||
Find indicators that can be computed from available columns.
|
||||
|
||||
Args:
|
||||
available_columns: List of column names available
|
||||
column_types: Mapping of column name to type
|
||||
|
||||
Returns:
|
||||
List of indicators whose input schema is satisfied
|
||||
"""
|
||||
from datasource.schema import ColumnInfo
|
||||
|
||||
# Build ColumnInfo list from available data
|
||||
available_schema = [
|
||||
ColumnInfo(
|
||||
name=name,
|
||||
type=column_types.get(name, "float"),
|
||||
description=f"Column {name}"
|
||||
)
|
||||
for name in available_columns
|
||||
]
|
||||
|
||||
results = []
|
||||
for indicator_class in self._indicators.values():
|
||||
input_schema = indicator_class.get_input_schema()
|
||||
if input_schema.matches(available_schema):
|
||||
results.append(indicator_class.get_metadata())
|
||||
|
||||
return results
|
||||
|
||||
def validate_indicator_chain(
|
||||
self,
|
||||
indicator_chain: List[tuple[str, Dict]]
|
||||
) -> tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Validate that a chain of indicators can be connected.
|
||||
|
||||
Args:
|
||||
indicator_chain: List of (indicator_name, params) tuples in execution order
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
if not indicator_chain:
|
||||
return True, None
|
||||
|
||||
# For now, just check that all indicators exist
|
||||
# More sophisticated DAG validation happens in the pipeline engine
|
||||
for indicator_name, params in indicator_chain:
|
||||
if indicator_name not in self._indicators:
|
||||
return False, f"Indicator '{indicator_name}' not found in registry"
|
||||
|
||||
return True, None
|
||||
|
||||
def get_input_schema(self, name: str) -> Optional[InputSchema]:
|
||||
"""
|
||||
Get input schema for a specific indicator.
|
||||
|
||||
Args:
|
||||
name: Indicator class name
|
||||
|
||||
Returns:
|
||||
InputSchema or None if not found
|
||||
"""
|
||||
indicator_class = self.get(name)
|
||||
if indicator_class:
|
||||
return indicator_class.get_input_schema()
|
||||
return None
|
||||
|
||||
def get_output_schema(self, name: str, **params) -> Optional[OutputSchema]:
|
||||
"""
|
||||
Get output schema for a specific indicator with given parameters.
|
||||
|
||||
Args:
|
||||
name: Indicator class name
|
||||
**params: Indicator parameters
|
||||
|
||||
Returns:
|
||||
OutputSchema or None if not found
|
||||
"""
|
||||
indicator_class = self.get(name)
|
||||
if indicator_class:
|
||||
return indicator_class.get_output_schema(**params)
|
||||
return None
|
||||
|
||||
def create_instance(self, name: str, instance_name: str, **params) -> Optional[Indicator]:
|
||||
"""
|
||||
Create an indicator instance with validation.
|
||||
|
||||
Args:
|
||||
name: Indicator class name
|
||||
instance_name: Unique instance name (for output column prefixing)
|
||||
**params: Indicator configuration parameters
|
||||
|
||||
Returns:
|
||||
Indicator instance or None if class not found
|
||||
|
||||
Raises:
|
||||
ValueError: If parameters are invalid
|
||||
"""
|
||||
indicator_class = self.get(name)
|
||||
if not indicator_class:
|
||||
return None
|
||||
|
||||
return indicator_class(instance_name=instance_name, **params)
|
||||
|
||||
def generate_ai_tool_spec(self) -> Dict:
|
||||
"""
|
||||
Generate a JSON specification for AI agent tools.
|
||||
|
||||
Creates a structured representation of all indicators that can be
|
||||
used to build agent tools for indicator selection and composition.
|
||||
|
||||
Returns:
|
||||
Dict suitable for AI agent tool registration
|
||||
"""
|
||||
tools = []
|
||||
|
||||
for indicator_class in self._indicators.values():
|
||||
metadata = indicator_class.get_metadata()
|
||||
|
||||
# Build parameter spec
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
|
||||
for param in metadata.parameters:
|
||||
param_spec = {
|
||||
"type": param.type,
|
||||
"description": param.description
|
||||
}
|
||||
|
||||
if param.default is not None:
|
||||
param_spec["default"] = param.default
|
||||
if param.min_value is not None:
|
||||
param_spec["minimum"] = param.min_value
|
||||
if param.max_value is not None:
|
||||
param_spec["maximum"] = param.max_value
|
||||
|
||||
parameters["properties"][param.name] = param_spec
|
||||
|
||||
if param.required:
|
||||
parameters["required"].append(param.name)
|
||||
|
||||
tool = {
|
||||
"name": f"indicator_{metadata.name.lower()}",
|
||||
"description": f"{metadata.display_name}: {metadata.description}",
|
||||
"category": metadata.category,
|
||||
"use_cases": metadata.use_cases,
|
||||
"tags": metadata.tags,
|
||||
"parameters": parameters,
|
||||
"input_schema": indicator_class.get_input_schema().model_dump(),
|
||||
"output_schema": indicator_class.get_output_schema().model_dump()
|
||||
}
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
return {
|
||||
"indicator_tools": tools,
|
||||
"total_count": len(tools)
|
||||
}
|
||||
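The weighting used by `search_by_text` (name 10, display name 8, description 5, each matching use case 3, each matching tag 2) can be tried out without the registry. This is a minimal sketch under stated assumptions: `Meta` is a hypothetical stand-in for `IndicatorMetadata`, and the sample descriptions and use cases are invented for illustration.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Meta:
    """Hypothetical stand-in for IndicatorMetadata (only the searched fields)."""
    name: str
    display_name: str
    description: str
    use_cases: List[str] = field(default_factory=list)
    tags: List[str] = field(default_factory=list)


def score(meta: Meta, query: str) -> int:
    """Case-insensitive substring score using the same weights as search_by_text."""
    q = query.lower()
    total = 0
    if q in meta.name.lower():
        total += 10
    if q in meta.display_name.lower():
        total += 8
    if q in meta.description.lower():
        total += 5
    # Every matching use case and tag contributes separately
    total += 3 * sum(q in u.lower() for u in meta.use_cases)
    total += 2 * sum(q in t.lower() for t in meta.tags)
    return total


rsi = Meta("RSI", "RSI", "Relative Strength Index",
           ["Spot overbought conditions"], ["talib", "momentum", "rsi"])
sma = Meta("SMA", "SMA", "Simple Moving Average",
           ["Smooth price to follow the trend"], ["talib", "overlap", "sma"])

# Keep only indicators with a positive score, best match first
ranked = sorted(
    (m for m in [sma, rsi] if score(m, "rsi") > 0),
    key=lambda m: score(m, "rsi"),
    reverse=True,
)
```

Because matching is by substring, a short query can hit several fields of the same indicator at once; the scores add up rather than taking the maximum.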
269
backend/src/indicator/schema.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
Data models for the Indicator system.
|
||||
|
||||
Defines schemas for input/output specifications, computation context,
|
||||
and metadata for AI agent discovery.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from datasource.schema import ColumnInfo
|
||||
|
||||
|
||||
class InputSchema(BaseModel):
|
||||
"""
|
||||
Declares the required input columns for an Indicator.
|
||||
|
||||
Indicators match against any data source (DataSource or other Indicator)
|
||||
that provides columns satisfying this schema.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
required_columns: List[ColumnInfo] = Field(
|
||||
description="Columns that must be present in the input data"
|
||||
)
|
||||
optional_columns: List[ColumnInfo] = Field(
|
||||
default_factory=list,
|
||||
description="Columns that may be used if present but are not required"
|
||||
)
|
||||
time_column: str = Field(
|
||||
default="time",
|
||||
description="Name of the timestamp column (must be present)"
|
||||
)
|
||||
|
||||
def matches(self, available_columns: List[ColumnInfo]) -> bool:
|
||||
"""
|
||||
Check if available columns satisfy this input schema.
|
||||
|
||||
Args:
|
||||
available_columns: Columns provided by a data source
|
||||
|
||||
Returns:
|
||||
            True if the time column and all required columns are present and each required column's type is identical to the declared type
|
||||
"""
|
||||
available_map = {col.name: col for col in available_columns}
|
||||
|
||||
# Check time column exists
|
||||
if self.time_column not in available_map:
|
||||
return False
|
||||
|
||||
# Check all required columns exist with compatible types
|
||||
for required in self.required_columns:
|
||||
if required.name not in available_map:
|
||||
return False
|
||||
available = available_map[required.name]
|
||||
if available.type != required.type:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_missing_columns(self, available_columns: List[ColumnInfo]) -> List[str]:
|
||||
"""
|
||||
Get list of missing required column names.
|
||||
|
||||
Args:
|
||||
available_columns: Columns provided by a data source
|
||||
|
||||
Returns:
|
||||
            List of missing column names (columns that are present with the wrong type are not reported)
|
||||
"""
|
||||
available_names = {col.name for col in available_columns}
|
||||
missing = []
|
||||
|
||||
if self.time_column not in available_names:
|
||||
missing.append(self.time_column)
|
||||
|
||||
for required in self.required_columns:
|
||||
if required.name not in available_names:
|
||||
missing.append(required.name)
|
||||
|
||||
return missing
|
||||
|
||||
|
||||
class OutputSchema(BaseModel):
|
||||
"""
|
||||
Declares the output columns produced by an Indicator.
|
||||
|
||||
Column names will be automatically prefixed with the indicator instance name
|
||||
to avoid collisions in the pipeline.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
columns: List[ColumnInfo] = Field(
|
||||
description="Output columns produced by this indicator"
|
||||
)
|
||||
time_column: str = Field(
|
||||
default="time",
|
||||
description="Name of the timestamp column (passed through from input)"
|
||||
)
|
||||
|
||||
def with_prefix(self, prefix: str) -> "OutputSchema":
|
||||
"""
|
||||
Create a new OutputSchema with all column names prefixed.
|
||||
|
||||
Args:
|
||||
prefix: Prefix to add (e.g., indicator instance name)
|
||||
|
||||
Returns:
|
||||
New OutputSchema with prefixed column names
|
||||
"""
|
||||
prefixed_columns = [
|
||||
ColumnInfo(
|
||||
name=f"{prefix}_{col.name}" if col.name != self.time_column else col.name,
|
||||
type=col.type,
|
||||
description=col.description,
|
||||
unit=col.unit,
|
||||
nullable=col.nullable
|
||||
)
|
||||
for col in self.columns
|
||||
]
|
||||
return OutputSchema(
|
||||
columns=prefixed_columns,
|
||||
time_column=self.time_column
|
||||
)
|
||||
|
||||
|
||||
class IndicatorParameter(BaseModel):
|
||||
"""
|
||||
Metadata for a configurable indicator parameter.
|
||||
|
||||
Used for AI agent discovery and dynamic indicator instantiation.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
name: str = Field(description="Parameter name")
|
||||
type: Literal["int", "float", "string", "bool"] = Field(description="Parameter type")
|
||||
description: str = Field(description="Human and LLM-readable description")
|
||||
default: Optional[Any] = Field(default=None, description="Default value if not specified")
|
||||
required: bool = Field(default=False, description="Whether this parameter is required")
|
||||
min_value: Optional[float] = Field(default=None, description="Minimum value (for numeric types)")
|
||||
max_value: Optional[float] = Field(default=None, description="Maximum value (for numeric types)")
|
||||
|
||||
|
||||
class IndicatorMetadata(BaseModel):
|
||||
"""
|
||||
Rich metadata for an Indicator class.
|
||||
|
||||
Enables AI agents to discover, understand, and instantiate indicators.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
name: str = Field(description="Unique indicator class name (e.g., 'RSI', 'SMA', 'BollingerBands')")
|
||||
display_name: str = Field(description="Human-readable display name")
|
||||
description: str = Field(description="Detailed description of what this indicator computes and why it's useful")
|
||||
category: str = Field(
|
||||
description="Indicator category (e.g., 'momentum', 'trend', 'volatility', 'volume', 'custom')"
|
||||
)
|
||||
parameters: List[IndicatorParameter] = Field(
|
||||
default_factory=list,
|
||||
description="Configurable parameters for this indicator"
|
||||
)
|
||||
use_cases: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Common use cases and trading scenarios where this indicator is helpful"
|
||||
)
|
||||
references: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="URLs or citations for indicator methodology"
|
||||
)
|
||||
tags: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Searchable tags (e.g., 'oscillator', 'mean-reversion', 'price-based')"
|
||||
)
|
||||
|
||||
|
||||
class ComputeContext(BaseModel):
|
||||
"""
|
||||
Context passed to an Indicator's compute() method.
|
||||
|
||||
Contains the input data and metadata about what changed (for incremental updates).
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
data: List[Dict[str, Any]] = Field(
|
||||
description="Input data rows (time-ordered). Each dict is {column_name: value, time: timestamp}"
|
||||
)
|
||||
is_incremental: bool = Field(
|
||||
default=False,
|
||||
description="True if this is an incremental update (only new/changed rows), False for full recompute"
|
||||
)
|
||||
updated_from_time: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Unix timestamp (ms) of the earliest updated row (for incremental updates)"
|
||||
)
|
||||
|
||||
def get_column(self, name: str) -> List[Any]:
|
||||
"""
|
||||
Extract a single column as a list of values.
|
||||
|
||||
Args:
|
||||
name: Column name
|
||||
|
||||
Returns:
|
||||
List of values in time order
|
||||
"""
|
||||
return [row.get(name) for row in self.data]
|
||||
|
||||
def get_times(self) -> List[int]:
|
||||
"""
|
||||
Get the time column as a list.
|
||||
|
||||
Returns:
|
||||
List of timestamps in order
|
||||
"""
|
||||
return [row["time"] for row in self.data]
|
||||
|
||||
|
||||
class ComputeResult(BaseModel):
|
||||
"""
|
||||
Result from an Indicator's compute() method.
|
||||
|
||||
Contains the computed output data with proper column naming.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
data: List[Dict[str, Any]] = Field(
|
||||
description="Output data rows (time-ordered). Must include time column."
|
||||
)
|
||||
is_partial: bool = Field(
|
||||
default=False,
|
||||
description="True if this result only contains updates (for incremental computation)"
|
||||
)
|
||||
|
||||
def merge_with_prefix(self, prefix: str, existing_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Merge this result into existing data with column name prefixing.
|
||||
|
||||
Args:
|
||||
prefix: Prefix to add to all column names except time
|
||||
existing_data: Existing data to merge with (matched by time)
|
||||
|
||||
Returns:
|
||||
Merged data with prefixed columns added
|
||||
"""
|
||||
# Build a time index for new data
|
||||
time_index = {row["time"]: row for row in self.data}
|
||||
|
||||
# Merge into existing data
|
||||
result = []
|
||||
for existing_row in existing_data:
|
||||
row_time = existing_row["time"]
|
||||
merged_row = existing_row.copy()
|
||||
|
||||
if row_time in time_index:
|
||||
new_row = time_index[row_time]
|
||||
for key, value in new_row.items():
|
||||
if key != "time":
|
||||
merged_row[f"{prefix}_{key}"] = value
|
||||
|
||||
result.append(merged_row)
|
||||
|
||||
return result
|
||||
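`InputSchema.matches()` rejects a source when a required column has the wrong type, while `get_missing_columns()` only lists absent names, so a failed match can come with an empty "missing" list. A diagnostic that reports both cases could look like the sketch below. It is an illustration only: `schema_problems` is a hypothetical helper, and columns are modelled as plain `(name, type)` tuples rather than `ColumnInfo` objects.

```python
from typing import Dict, List, Tuple

# (name, type) pairs; a hypothetical simplification of ColumnInfo
Column = Tuple[str, str]


def schema_problems(
    required: List[Column],
    available: List[Column],
    time_column: str = "time",
) -> List[str]:
    """List every reason the available columns fail to satisfy the required ones."""
    available_types: Dict[str, str] = dict(available)
    problems: List[str] = []

    # The time column must always be present
    if time_column not in available_types:
        problems.append(f"{time_column}: missing")

    for name, col_type in required:
        if name not in available_types:
            problems.append(f"{name}: missing")
        elif available_types[name] != col_type:
            # Same exact-equality rule that matches() applies
            problems.append(
                f"{name}: expected {col_type}, got {available_types[name]}"
            )
    return problems


ok = schema_problems([("close", "float")], [("time", "int"), ("close", "float")])
bad_type = schema_problems([("close", "float")], [("time", "int"), ("close", "int")])
no_time = schema_problems([("close", "float")], [("close", "float")])
```

An empty result corresponds to `matches()` returning True; anything else explains why it would return False.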
436
backend/src/indicator/talib_adapter.py
Normal file
@@ -0,0 +1,436 @@
|
||||
"""
|
||||
TA-Lib indicator adapter.
|
||||
|
||||
Provides automatic registration of all TA-Lib technical indicators
|
||||
as composable Indicator instances.
|
||||
|
||||
Installation Requirements:
|
||||
--------------------------
|
||||
TA-Lib requires both the C library and Python wrapper:
|
||||
|
||||
1. Install TA-Lib C library:
|
||||
- Ubuntu/Debian: sudo apt-get install libta-lib-dev
|
||||
- macOS: brew install ta-lib
|
||||
- From source: https://ta-lib.org/install.html
|
||||
|
||||
2. Install Python wrapper (already in requirements.txt):
|
||||
pip install TA-Lib
|
||||
|
||||
Usage:
------
    from indicator.registry import IndicatorRegistry
    from indicator.talib_adapter import register_all_talib_indicators

    # Auto-register all TA-Lib indicators
    registry = IndicatorRegistry()
    register_all_talib_indicators(registry)

    # Now you can use any TA-Lib indicator
    # (TA-Lib names the moving-average window "timeperiod", not "period")
    sma = registry.create_instance("SMA", "sma_20", timeperiod=20)
    rsi = registry.create_instance("RSI", "rsi_14", timeperiod=14)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
    import talib
    from talib import abstract
    TALIB_AVAILABLE = True
except ImportError:
    TALIB_AVAILABLE = False
    talib = None
    abstract = None
|
||||
|
||||
from datasource.schema import ColumnInfo
|
||||
|
||||
from .base import Indicator
|
||||
from .schema import (
|
||||
ComputeContext,
|
||||
ComputeResult,
|
||||
IndicatorMetadata,
|
||||
IndicatorParameter,
|
||||
InputSchema,
|
||||
OutputSchema,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Mapping of TA-Lib parameter types to our schema types
|
||||
TALIB_TYPE_MAP = {
|
||||
"double": "float",
|
||||
"double[]": "float",
|
||||
"int": "int",
|
||||
"str": "string",
|
||||
}
|
||||
|
||||
# Categorization of TA-Lib functions
|
||||
TALIB_CATEGORIES = {
|
||||
"overlap": ["SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA", "KAMA", "MAMA", "T3",
|
||||
"BBANDS", "MIDPOINT", "MIDPRICE", "SAR", "SAREXT", "HT_TRENDLINE"],
|
||||
"momentum": ["RSI", "MOM", "ROC", "ROCP", "ROCR", "ROCR100", "TRIX", "CMO", "DX",
|
||||
"ADX", "ADXR", "APO", "PPO", "MACD", "MACDEXT", "MACDFIX", "MFI",
|
||||
"STOCH", "STOCHF", "STOCHRSI", "WILLR", "CCI", "AROON", "AROONOSC",
|
||||
"BOP", "MINUS_DI", "MINUS_DM", "PLUS_DI", "PLUS_DM", "ULTOSC"],
|
||||
"volume": ["AD", "ADOSC", "OBV"],
|
||||
"volatility": ["ATR", "NATR", "TRANGE"],
|
||||
"price": ["AVGPRICE", "MEDPRICE", "TYPPRICE", "WCLPRICE"],
|
||||
"cycle": ["HT_DCPERIOD", "HT_DCPHASE", "HT_PHASOR", "HT_SINE", "HT_TRENDMODE"],
|
||||
"pattern": ["CDL2CROWS", "CDL3BLACKCROWS", "CDL3INSIDE", "CDL3LINESTRIKE",
|
||||
"CDL3OUTSIDE", "CDL3STARSINSOUTH", "CDL3WHITESOLDIERS", "CDLABANDONEDBABY",
|
||||
"CDLADVANCEBLOCK", "CDLBELTHOLD", "CDLBREAKAWAY", "CDLCLOSINGMARUBOZU",
|
||||
"CDLCONCEALBABYSWALL", "CDLCOUNTERATTACK", "CDLDARKCLOUDCOVER", "CDLDOJI",
|
||||
"CDLDOJISTAR", "CDLDRAGONFLYDOJI", "CDLENGULFING", "CDLEVENINGDOJISTAR",
|
||||
"CDLEVENINGSTAR", "CDLGAPSIDESIDEWHITE", "CDLGRAVESTONEDOJI", "CDLHAMMER",
|
||||
"CDLHANGINGMAN", "CDLHARAMI", "CDLHARAMICROSS", "CDLHIGHWAVE", "CDLHIKKAKE",
|
||||
"CDLHIKKAKEMOD", "CDLHOMINGPIGEON", "CDLIDENTICAL3CROWS", "CDLINNECK",
|
||||
"CDLINVERTEDHAMMER", "CDLKICKING", "CDLKICKINGBYLENGTH", "CDLLADDERBOTTOM",
|
||||
"CDLLONGLEGGEDDOJI", "CDLLONGLINE", "CDLMARUBOZU", "CDLMATCHINGLOW",
|
||||
"CDLMATHOLD", "CDLMORNINGDOJISTAR", "CDLMORNINGSTAR", "CDLONNECK",
|
||||
"CDLPIERCING", "CDLRICKSHAWMAN", "CDLRISEFALL3METHODS", "CDLSEPARATINGLINES",
|
||||
"CDLSHOOTINGSTAR", "CDLSHORTLINE", "CDLSPINNINGTOP", "CDLSTALLEDPATTERN",
|
||||
"CDLSTICKSANDWICH", "CDLTAKURI", "CDLTASUKIGAP", "CDLTHRUSTING", "CDLTRISTAR",
|
||||
"CDLUNIQUE3RIVER", "CDLUPSIDEGAP2CROWS", "CDLXSIDEGAP3METHODS"],
|
||||
"statistic": ["BETA", "CORREL", "LINEARREG", "LINEARREG_ANGLE", "LINEARREG_INTERCEPT",
|
||||
"LINEARREG_SLOPE", "STDDEV", "TSF", "VAR"],
|
||||
"math": ["ADD", "DIV", "MAX", "MAXINDEX", "MIN", "MININDEX", "MINMAX", "MINMAXINDEX",
|
||||
"MULT", "SUB", "SUM"],
|
||||
}
|
||||
|
||||
|
||||
def _get_function_category(func_name: str) -> str:
    """Determine the category of a TA-Lib function."""
    for category, functions in TALIB_CATEGORIES.items():
        if func_name in functions:
            return category
    return "other"
|
||||
|
||||
|
||||
class TALibIndicator(Indicator):
|
||||
"""
|
||||
Generic adapter for TA-Lib technical indicators.
|
||||
|
||||
Wraps any TA-Lib function to work within the composable indicator framework.
|
||||
Handles parameter mapping, input validation, and output formatting.
|
||||
"""
|
||||
|
||||
# Class variable to store the TA-Lib function name
|
||||
    talib_function_name: Optional[str] = None
|
||||
|
||||
def __init__(self, instance_name: str, **params):
|
||||
"""
|
||||
Initialize a TA-Lib indicator.
|
||||
|
||||
Args:
|
||||
instance_name: Unique name for this instance
|
||||
**params: TA-Lib function parameters
|
||||
"""
|
||||
if not TALIB_AVAILABLE:
|
||||
raise ImportError(
|
||||
"TA-Lib is not installed. Please install the TA-Lib C library "
|
||||
"and Python wrapper. See indicator/talib_adapter.py for instructions."
|
||||
)
|
||||
|
||||
super().__init__(instance_name, **params)
|
||||
self._talib_func = abstract.Function(self.talib_function_name)
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
"""Get metadata from TA-Lib function info."""
|
||||
if not TALIB_AVAILABLE:
|
||||
raise ImportError("TA-Lib is not installed")
|
||||
|
||||
func = abstract.Function(cls.talib_function_name)
|
||||
info = func.info
|
||||
|
||||
# Build parameters list from TA-Lib function info
|
||||
parameters = []
|
||||
for param_name, param_info in info.get("parameters", {}).items():
|
||||
# Handle case where param_info is a simple value (int/float) instead of a dict
|
||||
if isinstance(param_info, dict):
|
||||
param_type = TALIB_TYPE_MAP.get(param_info.get("type", "double"), "float")
|
||||
default_value = param_info.get("default_value")
|
||||
else:
|
||||
# param_info is a simple value (default), infer type from the value
|
||||
if isinstance(param_info, int):
|
||||
param_type = "int"
|
||||
elif isinstance(param_info, float):
|
||||
param_type = "float"
|
||||
else:
|
||||
param_type = "float" # Default to float
|
||||
default_value = param_info
|
||||
|
||||
parameters.append(
|
||||
IndicatorParameter(
|
||||
name=param_name,
|
||||
type=param_type,
|
||||
description=f"TA-Lib parameter: {param_name}",
|
||||
default=default_value,
|
||||
required=False
|
||||
)
|
||||
)
|
||||
|
||||
# Get function group/category
|
||||
category = _get_function_category(cls.talib_function_name)
|
||||
|
||||
        # Build display name (strip the CDL prefix from candlestick pattern names)
|
||||
display_name = cls.talib_function_name
|
||||
if display_name.startswith("CDL"):
|
||||
display_name = display_name[3:] # Remove CDL prefix for patterns
|
||||
|
||||
return IndicatorMetadata(
|
||||
name=cls.talib_function_name,
|
||||
display_name=display_name,
|
||||
description=info.get("display_name", f"TA-Lib {cls.talib_function_name} indicator"),
|
||||
category=category,
|
||||
parameters=parameters,
|
||||
use_cases=[f"Technical analysis using {cls.talib_function_name}"],
|
||||
references=["https://ta-lib.org/function.html"],
|
||||
tags=["talib", category, cls.talib_function_name.lower()]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
"""
|
||||
Get input schema from TA-Lib function requirements.
|
||||
|
||||
Most TA-Lib functions use OHLCV data, but some use subsets.
|
||||
"""
|
||||
if not TALIB_AVAILABLE:
|
||||
raise ImportError("TA-Lib is not installed")
|
||||
|
||||
func = abstract.Function(cls.talib_function_name)
|
||||
info = func.info
|
||||
input_names = info.get("input_names", {})
|
||||
|
||||
required_columns = []
|
||||
|
||||
# Map TA-Lib input names to our schema
|
||||
if "prices" in input_names:
|
||||
price_inputs = input_names["prices"]
|
||||
if "open" in price_inputs:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="open", type="float", description="Opening price")
|
||||
)
|
||||
if "high" in price_inputs:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="high", type="float", description="High price")
|
||||
)
|
||||
if "low" in price_inputs:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="low", type="float", description="Low price")
|
||||
)
|
||||
if "close" in price_inputs:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="close", type="float", description="Closing price")
|
||||
)
|
||||
if "volume" in price_inputs:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="volume", type="float", description="Trading volume")
|
||||
)
|
||||
|
||||
# Handle functions that take generic price arrays
|
||||
if "price" in input_names:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="close", type="float", description="Price (typically close)")
|
||||
)
|
||||
|
||||
# If no specific inputs found, assume close price
|
||||
if not required_columns:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="close", type="float", description="Closing price")
|
||||
)
|
||||
|
||||
return InputSchema(required_columns=required_columns)
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
"""Get output schema from TA-Lib function outputs."""
|
||||
if not TALIB_AVAILABLE:
|
||||
raise ImportError("TA-Lib is not installed")
|
||||
|
||||
func = abstract.Function(cls.talib_function_name)
|
||||
info = func.info
|
||||
output_names = info.get("output_names", [])
|
||||
|
||||
columns = []
|
||||
|
||||
# Most TA-Lib functions output one or more float arrays
|
||||
if isinstance(output_names, list):
|
||||
for output_name in output_names:
|
||||
columns.append(
|
||||
ColumnInfo(
|
||||
name=output_name.lower(),
|
||||
type="float",
|
||||
description=f"{cls.talib_function_name} output: {output_name}",
|
||||
nullable=True # TA-Lib often has NaN for initial periods
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Single output, use function name
|
||||
columns.append(
|
||||
ColumnInfo(
|
||||
name=cls.talib_function_name.lower(),
|
||||
type="float",
|
||||
description=f"{cls.talib_function_name} indicator value",
|
||||
nullable=True
|
||||
)
|
||||
)
|
||||
|
||||
return OutputSchema(columns=columns)
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
"""Compute indicator using TA-Lib."""
|
||||
# Extract input columns
|
||||
input_data = {}
|
||||
|
||||
# Get the function's expected inputs
|
||||
info = self._talib_func.info
|
||||
input_names = info.get("input_names", {})
|
||||
|
||||
# Prepare input arrays
|
||||
if "prices" in input_names:
|
||||
price_inputs = input_names["prices"]
|
||||
for price_type in price_inputs:
|
||||
column_data = context.get_column(price_type)
|
||||
# Convert to numpy array, replacing None with NaN
|
||||
input_data[price_type] = np.array(
|
||||
[float(v) if v is not None else np.nan for v in column_data]
|
||||
)
|
||||
elif "price" in input_names:
|
||||
# Generic price input, use close
|
||||
column_data = context.get_column("close")
|
||||
input_data["close"] = np.array(  # talib.abstract looks up the series named by input_names["price"] (normally "close")
|
||||
[float(v) if v is not None else np.nan for v in column_data]
|
||||
)
|
||||
else:
|
||||
# Default to close if no inputs specified
|
||||
column_data = context.get_column("close")
|
||||
input_data["close"] = np.array(
|
||||
[float(v) if v is not None else np.nan for v in column_data]
|
||||
)
|
||||
|
||||
# Set parameters on the function
|
||||
self._talib_func.parameters = self.params
|
||||
|
||||
# Execute TA-Lib function
|
||||
try:
|
||||
output = self._talib_func(input_data)
|
||||
except Exception as e:
|
||||
logger.error(f"TA-Lib function {self.talib_function_name} failed: {e}")
|
||||
raise ValueError(f"TA-Lib computation failed: {e}") from e
|
||||
|
||||
# Format output
|
||||
times = context.get_times()
|
||||
output_names = info.get("output_names", [])
|
||||
|
||||
# Handle single vs multiple outputs
|
||||
if isinstance(output, np.ndarray):
|
||||
# Single output
|
||||
output_name = output_names[0].lower() if output_names else self.talib_function_name.lower()
|
||||
result_data = [
|
||||
{
|
||||
"time": times[i],
|
||||
output_name: float(output[i]) if not np.isnan(output[i]) else None
|
||||
}
|
||||
for i in range(len(times))
|
||||
]
|
||||
elif isinstance(output, tuple):
|
||||
# Multiple outputs
|
||||
result_data = []
|
||||
for i in range(len(times)):
|
||||
row = {"time": times[i]}
|
||||
for j, output_array in enumerate(output):
|
||||
output_name = output_names[j].lower() if j < len(output_names) else f"output_{j}"
|
||||
row[output_name] = float(output_array[i]) if not np.isnan(output_array[i]) else None
|
||||
result_data.append(row)
|
||||
else:
|
||||
raise ValueError(f"Unexpected TA-Lib output type: {type(output)}")
|
||||
|
||||
return ComputeResult(
|
||||
data=result_data,
|
||||
is_partial=context.is_incremental
|
||||
)
|
||||
|
||||
|
||||
def create_talib_indicator_class(func_name: str) -> type:
|
||||
"""
|
||||
Dynamically create an Indicator class for a TA-Lib function.
|
||||
|
||||
Args:
|
||||
func_name: TA-Lib function name (e.g., 'SMA', 'RSI')
|
||||
|
||||
Returns:
|
||||
Indicator class for this function
|
||||
"""
|
||||
return type(
|
||||
f"TALib_{func_name}",
|
||||
(TALibIndicator,),
|
||||
{"talib_function_name": func_name}
|
||||
)
|
||||
|
||||
|
||||
def register_all_talib_indicators(registry) -> int:
|
||||
"""
|
||||
Auto-register all available TA-Lib indicators with the registry.
|
||||
|
||||
Args:
|
||||
registry: IndicatorRegistry instance
|
||||
|
||||
Returns:
|
||||
Number of indicators registered
|
||||
|
||||
Note: if TA-Lib is not installed, a warning is logged and 0 is returned; no exception is raised.
|
||||
"""
|
||||
if not TALIB_AVAILABLE:
|
||||
logger.warning(
|
||||
"TA-Lib is not installed. Skipping TA-Lib indicator registration. "
|
||||
"Install TA-Lib C library and Python wrapper to enable TA-Lib indicators."
|
||||
)
|
||||
return 0
|
||||
|
||||
# Get all TA-Lib functions
|
||||
func_groups = talib.get_function_groups()
|
||||
all_functions = []
|
||||
for group, functions in func_groups.items():
|
||||
all_functions.extend(functions)
|
||||
|
||||
# Remove duplicates
|
||||
all_functions = sorted(set(all_functions))
|
||||
|
||||
registered_count = 0
|
||||
for func_name in all_functions:
|
||||
try:
|
||||
# Create indicator class for this function
|
||||
indicator_class = create_talib_indicator_class(func_name)
|
||||
|
||||
# Register with the registry
|
||||
registry.register(indicator_class)
|
||||
registered_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register TA-Lib function {func_name}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Registered {registered_count} TA-Lib indicators")
|
||||
return registered_count
|
||||
|
||||
|
||||
def get_talib_version() -> Optional[str]:
|
||||
"""
|
||||
Get the installed TA-Lib version.
|
||||
|
||||
Returns:
|
||||
Version string or None if not installed
|
||||
"""
|
||||
if TALIB_AVAILABLE:
|
||||
return talib.__version__
|
||||
return None
|
||||
|
||||
|
||||
def is_talib_available() -> bool:
|
||||
"""Check if TA-Lib is available."""
|
||||
return TALIB_AVAILABLE
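
The `create_talib_indicator_class` helper above relies on the three-argument form of `type()` to build one subclass per TA-Lib function name. The sketch below isolates that technique without TA-Lib installed. `BaseIndicator`, `describe`, and `make_indicator_class` are hypothetical stand-ins for illustration, not names from this commit.

```python
class BaseIndicator:
    """Hypothetical stand-in for the adapter's TALibIndicator base class."""

    talib_function_name = None

    @classmethod
    def describe(cls):
        return f"{cls.__name__} wraps {cls.talib_function_name}"


def make_indicator_class(func_name):
    # type(name, bases, namespace) creates a new class at runtime;
    # the namespace dict becomes the class attributes.
    return type(
        f"TALib_{func_name}",
        (BaseIndicator,),
        {"talib_function_name": func_name},
    )


# De-duplicate and sort the names first, as the registration loop does.
names = sorted({"SMA", "RSI", "SMA"})
classes = {name: make_indicator_class(name) for name in names}
sma_description = classes["SMA"].describe()
print(sma_description)
```

Each generated class differs only in its `talib_function_name` attribute, so all schema and compute logic can stay in the shared base class.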
|
||||
@@ -20,13 +20,14 @@ from gateway.hub import Gateway
|
||||
from gateway.channels.websocket import WebSocketChannel
|
||||
from gateway.protocol import WebSocketAgentUserMessage
|
||||
from agent.core import create_agent
|
||||
from agent.tools import set_registry, set_datasource_registry
|
||||
from agent.tools import set_registry, set_datasource_registry, set_indicator_registry
|
||||
from schema.order_spec import SwapOrder
|
||||
from schema.chart_state import ChartState
|
||||
from datasource.registry import DataSourceRegistry
|
||||
from datasource.subscription_manager import SubscriptionManager
|
||||
from datasource.websocket_handler import DatafeedWebSocketHandler
|
||||
from secrets_manager import SecretsStore, InvalidMasterPassword
|
||||
from indicator import IndicatorRegistry, register_all_talib_indicators
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@@ -53,6 +54,9 @@ agent_executor = None
|
||||
datasource_registry = DataSourceRegistry()
|
||||
subscription_manager = SubscriptionManager()
|
||||
|
||||
# Indicator infrastructure
|
||||
indicator_registry = IndicatorRegistry()
|
||||
|
||||
# Global secrets store
|
||||
secrets_store = SecretsStore()
|
||||
|
||||
@@ -80,6 +84,14 @@ async def lifespan(app: FastAPI):
|
||||
logger.warning(f"CCXT not available: {e}. Only demo source will be available.")
|
||||
logger.info("To use real exchange data, install ccxt: pip install ccxt>=4.0.0")
|
||||
|
||||
# Initialize indicator registry with all TA-Lib indicators
|
||||
try:
|
||||
indicator_count = register_all_talib_indicators(indicator_registry)
|
||||
logger.info(f"Indicator registry initialized with {indicator_count} TA-Lib indicators")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register TA-Lib indicators: {e}")
|
||||
logger.info("TA-Lib indicators will not be available. Install TA-Lib C library and Python wrapper to enable.")
|
||||
|
||||
# Get API keys from secrets store if unlocked, otherwise fall back to environment
|
||||
anthropic_api_key = None
|
||||
|
||||
@@ -101,6 +113,7 @@ async def lifespan(app: FastAPI):
|
||||
# Set the registries for agent tools
|
||||
set_registry(registry)
|
||||
set_datasource_registry(datasource_registry)
|
||||
set_indicator_registry(indicator_registry)
|
||||
|
||||
# Create and initialize agent
|
||||
agent_executor = create_agent(
|
||||
|
||||
@@ -40,6 +40,58 @@ class Exchange(StrEnum):
|
||||
UNISWAP_V3 = "UniswapV3"
|
||||
|
||||
|
||||
class Side(StrEnum):
|
||||
"""Order side: buy or sell"""
|
||||
BUY = "BUY"
|
||||
SELL = "SELL"
|
||||
|
||||
|
||||
class AmountType(StrEnum):
|
||||
"""Whether the order amount refers to base or quote currency"""
|
||||
BASE = "BASE" # Amount is in base currency (e.g., BTC in BTC/USD)
|
||||
QUOTE = "QUOTE" # Amount is in quote currency (e.g., USD in BTC/USD)
|
||||
|
||||
|
||||
class TimeInForce(StrEnum):
|
||||
"""Order lifetime specification"""
|
||||
GTC = "GTC" # Good Till Cancel
|
||||
IOC = "IOC" # Immediate or Cancel
|
||||
FOK = "FOK" # Fill or Kill
|
||||
DAY = "DAY" # Good for trading day
|
||||
GTD = "GTD" # Good Till Date
|
||||
|
||||
|
||||
class ConditionalOrderMode(StrEnum):
|
||||
"""How conditional orders behave on partial fills"""
|
||||
NEW_PER_FILL = "NEW_PER_FILL" # Create new conditional order per each fill
|
||||
UNIFIED_ADJUSTING = "UNIFIED_ADJUSTING" # Single conditional order that adjusts amount
|
||||
|
||||
|
||||
class TriggerType(StrEnum):
|
||||
"""Type of conditional trigger"""
|
||||
STOP_LOSS = "STOP_LOSS"
|
||||
TAKE_PROFIT = "TAKE_PROFIT"
|
||||
STOP_LIMIT = "STOP_LIMIT"
|
||||
TRAILING_STOP = "TRAILING_STOP"
|
||||
|
||||
|
||||
class TickSpacingMode(StrEnum):
|
||||
"""How price tick spacing is determined"""
|
||||
FIXED = "FIXED" # Fixed tick size
|
||||
DYNAMIC = "DYNAMIC" # Tick size varies by price level
|
||||
CONTINUOUS = "CONTINUOUS" # No tick restrictions
|
||||
|
||||
|
||||
class AssetType(StrEnum):
|
||||
"""Type of tradeable asset"""
|
||||
SPOT = "SPOT" # Spot/cash market
|
||||
MARGIN = "MARGIN" # Margin trading
|
||||
PERP = "PERP" # Perpetual futures
|
||||
FUTURE = "FUTURE" # Dated futures
|
||||
OPTION = "OPTION" # Options
|
||||
SYNTHETIC = "SYNTHETIC" # Synthetic/derived instruments
|
||||
|
||||
|
||||
class OcoMode(StrEnum):
|
||||
NO_OCO = "NO_OCO"
|
||||
CANCEL_ON_PARTIAL_FILL = "CANCEL_ON_PARTIAL_FILL"
|
||||
@@ -96,6 +148,126 @@ class TrancheStatus(BaseModel):
|
||||
endTime: Uint32 = Field(description="Concrete end timestamp")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Standard Order Models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ConditionalTrigger(BaseModel):
|
||||
"""Conditional order trigger (stop-loss, take-profit, etc.)"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
trigger_type: TriggerType
|
||||
trigger_price: Float = Field(description="Price at which conditional order activates")
|
||||
trailing_delta: Float | None = Field(default=None, description="For trailing stops: delta from peak/trough")
|
||||
|
||||
|
||||
class AmountConstraints(BaseModel):
|
||||
"""Constraints on order amounts for a symbol"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
min_amount: Float = Field(description="Minimum order amount")
|
||||
max_amount: Float = Field(description="Maximum order amount")
|
||||
step_size: Float = Field(description="Amount increment granularity")
|
||||
|
||||
|
||||
class PriceConstraints(BaseModel):
|
||||
"""Constraints on order pricing for a symbol"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
tick_spacing_mode: TickSpacingMode
|
||||
tick_size: Float | None = Field(default=None, description="Fixed tick size (if FIXED mode)")
|
||||
min_price: Float | None = Field(default=None, description="Minimum allowed price")
|
||||
max_price: Float | None = Field(default=None, description="Maximum allowed price")
|
||||
|
||||
|
||||
class MarketCapabilities(BaseModel):
|
||||
"""Describes what order features a market supports"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
supported_sides: list[Side] = Field(description="Supported order sides (usually both)")
|
||||
supported_amount_types: list[AmountType] = Field(description="Whether BASE, QUOTE, or both amounts are supported")
|
||||
supports_market_orders: bool = Field(description="Whether market orders are supported")
|
||||
supports_limit_orders: bool = Field(description="Whether limit orders are supported")
|
||||
supported_time_in_force: list[TimeInForce] = Field(description="Supported order lifetimes")
|
||||
supports_conditional_orders: bool = Field(description="Whether stop-loss/take-profit are supported")
|
||||
supported_trigger_types: list[TriggerType] = Field(default_factory=list, description="Supported trigger types")
|
||||
supports_post_only: bool = Field(default=False, description="Whether post-only orders are supported")
|
||||
supports_reduce_only: bool = Field(default=False, description="Whether reduce-only orders are supported")
|
||||
supports_iceberg: bool = Field(default=False, description="Whether iceberg orders are supported")
|
||||
market_order_amount_type: AmountType | None = Field(
|
||||
default=None,
|
||||
description="Required amount type for market orders (some DEXs require exact-in)"
|
||||
)
|
||||
|
||||
|
||||
class SymbolMetadata(BaseModel):
|
||||
"""Complete metadata describing a tradeable symbol/market"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
symbol_id: str = Field(description="Unique symbol identifier")
|
||||
base_asset: str = Field(description="Base asset (e.g., 'BTC')")
|
||||
quote_asset: str = Field(description="Quote asset (e.g., 'USD')")
|
||||
asset_type: AssetType = Field(description="Type of market")
|
||||
exchange: str = Field(description="Exchange identifier")
|
||||
|
||||
amount_constraints: AmountConstraints
|
||||
price_constraints: PriceConstraints
|
||||
capabilities: MarketCapabilities
|
||||
|
||||
contract_size: Float | None = Field(default=None, description="For futures/options: contract multiplier")
|
||||
settlement_asset: str | None = Field(default=None, description="For derivatives: settlement currency")
|
||||
expiry_timestamp: Uint64 | None = Field(default=None, description="For dated futures/options: expiration")
|
||||
|
||||
|
||||
class StandardOrder(BaseModel):
|
||||
"""Standard order specification for exchange kernels"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
symbol_id: str = Field(description="Symbol to trade")
|
||||
side: Side = Field(description="Buy or sell")
|
||||
amount: Float = Field(description="Order amount")
|
||||
amount_type: AmountType = Field(description="Whether amount is BASE or QUOTE currency")
|
||||
|
||||
limit_price: Float | None = Field(default=None, description="Limit price (None = market order)")
|
||||
time_in_force: TimeInForce = Field(default=TimeInForce.GTC, description="Order lifetime")
|
||||
good_till_date: Uint64 | None = Field(default=None, description="Expiry timestamp for GTD orders")
|
||||
|
||||
conditional_trigger: ConditionalTrigger | None = Field(
|
||||
default=None,
|
||||
description="Stop-loss/take-profit trigger"
|
||||
)
|
||||
conditional_mode: ConditionalOrderMode | None = Field(
|
||||
default=None,
|
||||
description="How conditional orders behave on partial fills"
|
||||
)
|
||||
|
||||
reduce_only: bool = Field(default=False, description="Only reduce existing position")
|
||||
post_only: bool = Field(default=False, description="Only make, never take")
|
||||
iceberg_qty: Float | None = Field(default=None, description="Visible amount for iceberg orders")
|
||||
|
||||
client_order_id: str | None = Field(default=None, description="Client-specified order ID")
|
||||
|
||||
|
||||
class StandardOrderStatus(BaseModel):
|
||||
"""Current status of a standard order"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
order: StandardOrder
|
||||
order_id: str = Field(description="Exchange-assigned order ID")
|
||||
status: str = Field(description="Order status: NEW, PARTIALLY_FILLED, FILLED, CANCELED, REJECTED, EXPIRED")
|
||||
filled_amount: Float = Field(description="Amount filled so far")
|
||||
average_fill_price: Float = Field(description="Average execution price")
|
||||
created_at: Uint64 = Field(description="Order creation timestamp")
|
||||
updated_at: Uint64 = Field(description="Last update timestamp")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Order models
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -117,7 +289,22 @@ class SwapOrder(BaseModel):
|
||||
tranches: list[Tranche] = Field(min_length=1)
|
||||
|
||||
|
||||
class StandardOrderGroup(BaseModel):
|
||||
"""Group of orders with OCO (One-Cancels-Other) relationship"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
mode: OcoMode
|
||||
orders: list[StandardOrder] = Field(min_length=1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Legacy swap order models (kept for backward compatibility)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class OcoGroup(BaseModel):
|
||||
"""DEPRECATED: Use StandardOrderGroup instead"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
mode: OcoMode
|
||||
|
||||
40
backend/src/secrets_manager/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Encrypted secrets management with master password protection.
|
||||
|
||||
This module provides secure storage for sensitive configuration like API keys,
|
||||
using Argon2id for password-based key derivation and Fernet (AES-128-CBC with HMAC-SHA256) for encryption.
|
||||
|
||||
Basic usage:
|
||||
from secrets_manager import SecretsStore
|
||||
|
||||
# First time setup
|
||||
store = SecretsStore()
|
||||
store.initialize("my-master-password")
|
||||
store.set("ANTHROPIC_API_KEY", "sk-ant-...")
|
||||
|
||||
# Later usage
|
||||
store = SecretsStore()
|
||||
store.unlock("my-master-password")
|
||||
api_key = store.get("ANTHROPIC_API_KEY")
|
||||
|
||||
Command-line interface:
|
||||
python -m secrets_manager.cli init
|
||||
python -m secrets_manager.cli set KEY VALUE
|
||||
python -m secrets_manager.cli get KEY
|
||||
python -m secrets_manager.cli list
|
||||
python -m secrets_manager.cli change-password
|
||||
"""
|
||||
|
||||
from .store import (
|
||||
SecretsStore,
|
||||
SecretsStoreError,
|
||||
SecretsStoreLocked,
|
||||
InvalidMasterPassword,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SecretsStore",
|
||||
"SecretsStoreError",
|
||||
"SecretsStoreLocked",
|
||||
"InvalidMasterPassword",
|
||||
]
|
||||
374
backend/src/secrets_manager/cli.py
Normal file
@@ -0,0 +1,374 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Command-line interface for managing the encrypted secrets store.
|
||||
|
||||
Usage:
|
||||
    python -m secrets_manager.cli init                # Initialize new secrets store
    python -m secrets_manager.cli set KEY VALUE       # Set a secret
    python -m secrets_manager.cli get KEY             # Get a secret
    python -m secrets_manager.cli delete KEY          # Delete a secret
    python -m secrets_manager.cli list                # List all secret keys
    python -m secrets_manager.cli change-password     # Change master password
    python -m secrets_manager.cli export FILE         # Export encrypted backup
    python -m secrets_manager.cli import FILE         # Import encrypted backup
    python -m secrets_manager.cli migrate-from-env    # Migrate secrets from .env file
|
||||
"""
|
||||
import sys
|
||||
import argparse
|
||||
import getpass
|
||||
from pathlib import Path
|
||||
|
||||
from .store import SecretsStore, SecretsStoreError, InvalidMasterPassword
|
||||
|
||||
|
||||
def get_password(prompt: str = "Master password: ", confirm: bool = False) -> str:
|
||||
"""
|
||||
Securely get password from user.
|
||||
|
||||
Args:
|
||||
prompt: Password prompt
|
||||
confirm: If True, ask for confirmation
|
||||
|
||||
Returns:
|
||||
Password string
|
||||
"""
|
||||
password = getpass.getpass(prompt)
|
||||
|
||||
if confirm:
|
||||
confirm_password = getpass.getpass("Confirm password: ")
|
||||
if password != confirm_password:
|
||||
print("Error: Passwords do not match", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
return password
|
||||
|
||||
|
||||
def cmd_init(args):
|
||||
"""Initialize a new secrets store."""
|
||||
store = SecretsStore()
|
||||
|
||||
if store.is_initialized:
|
||||
print("Error: Secrets store is already initialized", file=sys.stderr)
|
||||
print(f"Location: {store.secrets_file}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
password = get_password("Create master password: ", confirm=True)
|
||||
|
||||
if len(password) < 8:
|
||||
print("Error: Password must be at least 8 characters", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
store.initialize(password)
|
||||
print(f"Secrets store initialized at {store.secrets_file}")
|
||||
|
||||
|
||||
def cmd_set(args):
|
||||
"""Set a secret value."""
|
||||
store = SecretsStore()
|
||||
|
||||
if not store.is_initialized:
|
||||
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
password = get_password()
|
||||
|
||||
try:
|
||||
store.unlock(password)
|
||||
except InvalidMasterPassword:
|
||||
print("Error: Invalid master password", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
store.set(args.key, args.value)
|
||||
print(f"✓ Secret '{args.key}' saved")
|
||||
|
||||
|
||||
def cmd_get(args):
|
||||
"""Get a secret value."""
|
||||
store = SecretsStore()
|
||||
|
||||
if not store.is_initialized:
|
||||
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
password = get_password()
|
||||
|
||||
try:
|
||||
store.unlock(password)
|
||||
except InvalidMasterPassword:
|
||||
print("Error: Invalid master password", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
value = store.get(args.key)
|
||||
if value is None:
|
||||
print(f"Error: Secret '{args.key}' not found", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Print to stdout (can be captured)
|
||||
print(value)
|
||||
|
||||
|
||||
def cmd_delete(args):
|
||||
"""Delete a secret."""
|
||||
store = SecretsStore()
|
||||
|
||||
if not store.is_initialized:
|
||||
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
password = get_password()
|
||||
|
||||
try:
|
||||
store.unlock(password)
|
||||
except InvalidMasterPassword:
|
||||
print("Error: Invalid master password", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if store.delete(args.key):
|
||||
print(f"✓ Secret '{args.key}' deleted")
|
||||
else:
|
||||
print(f"Error: Secret '{args.key}' not found", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def cmd_list(args):
|
||||
"""List all secret keys."""
|
||||
store = SecretsStore()
|
||||
|
||||
if not store.is_initialized:
|
||||
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
password = get_password()
|
||||
|
||||
try:
|
||||
store.unlock(password)
|
||||
except InvalidMasterPassword:
|
||||
print("Error: Invalid master password", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
keys = store.list_keys()
|
||||
|
||||
if not keys:
|
||||
print("No secrets stored")
|
||||
else:
|
||||
print(f"Stored secrets ({len(keys)}):")
|
||||
for key in sorted(keys):
|
||||
# Show each key with a preview of its value, truncated to 50 characters (this prints secret material to the terminal)
|
||||
value = store.get(key)
|
||||
value_str = str(value)
|
||||
value_preview = value_str[:50] + "..." if len(value_str) > 50 else value_str
|
||||
print(f" {key}: {value_preview}")
|
||||
|
||||
|
||||
def cmd_change_password(args):
|
||||
"""Change the master password."""
|
||||
store = SecretsStore()
|
||||
|
||||
if not store.is_initialized:
|
||||
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
current_password = get_password("Current master password: ")
|
||||
new_password = get_password("New master password: ", confirm=True)
|
||||
|
||||
if len(new_password) < 8:
|
||||
print("Error: Password must be at least 8 characters", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
store.change_master_password(current_password, new_password)
|
||||
except InvalidMasterPassword:
|
||||
print("Error: Invalid current password", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def cmd_export(args):
|
||||
"""Export encrypted secrets to a backup file."""
|
||||
store = SecretsStore()
|
||||
|
||||
if not store.is_initialized:
|
||||
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
output_path = Path(args.file)
|
||||
|
||||
if output_path.exists() and not args.force:
|
||||
print(f"Error: File {output_path} already exists. Use --force to overwrite.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
store.export_encrypted(output_path)
|
||||
except SecretsStoreError as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def cmd_import(args):
|
||||
"""Import encrypted secrets from a backup file."""
|
||||
store = SecretsStore()
|
||||
|
||||
if not store.is_initialized:
|
||||
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
input_path = Path(args.file)
|
||||
|
||||
if not input_path.exists():
|
||||
print(f"Error: File {input_path} does not exist", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
password = get_password()
|
||||
|
||||
try:
|
||||
store.import_encrypted(input_path, password)
|
||||
except InvalidMasterPassword:
|
||||
print("Error: Invalid master password or incompatible backup", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
except SecretsStoreError as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def cmd_migrate_from_env(args):
|
||||
"""Migrate secrets from .env file to encrypted store."""
|
||||
store = SecretsStore()
|
||||
|
||||
if not store.is_initialized:
|
||||
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Look for .env file
|
||||
backend_root = Path(__file__).parent.parent.parent
|
||||
env_file = backend_root / ".env"
|
||||
|
||||
if not env_file.exists():
|
||||
print(f"Error: .env file not found at {env_file}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
password = get_password()
|
||||
|
||||
try:
|
||||
store.unlock(password)
|
||||
except InvalidMasterPassword:
|
||||
print("Error: Invalid master password", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Parse .env file (simple parser - doesn't handle all edge cases)
|
||||
migrated = 0
|
||||
skipped = 0
|
||||
|
||||
with open(env_file) as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
line = line.strip()
|
||||
|
||||
# Skip empty lines and comments
|
||||
if not line or line.startswith('#'):
|
||||
continue
|
||||
|
||||
# Parse KEY=VALUE format
|
||||
if '=' not in line:
|
||||
print(f"Warning: Skipping invalid line {line_num}: {line}", file=sys.stderr)
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
key, value = line.split('=', 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# Remove quotes if present
|
||||
if value.startswith('"') and value.endswith('"'):
|
||||
value = value[1:-1]
|
||||
elif value.startswith("'") and value.endswith("'"):
|
||||
value = value[1:-1]
|
||||
|
||||
# Check if key already exists
|
||||
existing = store.get(key)
|
||||
if existing is not None:
|
||||
print(f"Warning: Secret '{key}' already exists, skipping", file=sys.stderr)
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
store.set(key, value)
|
||||
print(f"✓ Migrated: {key}")
|
||||
migrated += 1
|
||||
|
||||
print(f"\nMigration complete: {migrated} secrets migrated, {skipped} skipped")
|
||||
|
||||
if not args.keep_env:
|
||||
# Ask for confirmation before deleting .env
|
||||
confirm = input(f"\nDelete {env_file}? [y/N]: ").strip().lower()
|
||||
if confirm == 'y':
|
||||
env_file.unlink()
|
||||
print(f"✓ Deleted {env_file}")
|
||||
else:
|
||||
print(f"Kept {env_file} (consider deleting it manually)")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main CLI entry point."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Manage encrypted secrets store",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest='command', help='Command to run')
|
||||
subparsers.required = True
|
||||
|
||||
# init
|
||||
parser_init = subparsers.add_parser('init', help='Initialize new secrets store')
|
||||
parser_init.set_defaults(func=cmd_init)
|
||||
|
||||
# set
|
||||
parser_set = subparsers.add_parser('set', help='Set a secret value')
|
||||
parser_set.add_argument('key', help='Secret key name')
|
||||
parser_set.add_argument('value', help='Secret value')
|
||||
parser_set.set_defaults(func=cmd_set)
|
||||
|
||||
# get
|
||||
parser_get = subparsers.add_parser('get', help='Get a secret value')
|
||||
parser_get.add_argument('key', help='Secret key name')
|
||||
parser_get.set_defaults(func=cmd_get)
|
||||
|
||||
# delete
|
||||
parser_delete = subparsers.add_parser('delete', help='Delete a secret')
|
||||
parser_delete.add_argument('key', help='Secret key name')
|
||||
parser_delete.set_defaults(func=cmd_delete)
|
||||
|
||||
# list
|
||||
parser_list = subparsers.add_parser('list', help='List all secret keys')
|
||||
parser_list.set_defaults(func=cmd_list)
|
||||
|
||||
# change-password
|
||||
parser_change = subparsers.add_parser('change-password', help='Change master password')
|
||||
parser_change.set_defaults(func=cmd_change_password)
|
||||
|
||||
# export
|
||||
parser_export = subparsers.add_parser('export', help='Export encrypted backup')
|
||||
parser_export.add_argument('file', help='Output file path')
|
||||
parser_export.add_argument('--force', action='store_true', help='Overwrite existing file')
|
||||
parser_export.set_defaults(func=cmd_export)
|
||||
|
||||
# import
|
||||
parser_import = subparsers.add_parser('import', help='Import encrypted backup')
|
||||
parser_import.add_argument('file', help='Input file path')
|
||||
parser_import.set_defaults(func=cmd_import)
|
||||
|
||||
# migrate-from-env
|
||||
parser_migrate = subparsers.add_parser('migrate-from-env', help='Migrate from .env file')
|
||||
parser_migrate.add_argument('--keep-env', action='store_true', help='Keep .env file after migration')
|
||||
parser_migrate.set_defaults(func=cmd_migrate_from_env)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
args.func(args)
|
||||
except KeyboardInterrupt:
|
||||
print("\nAborted", file=sys.stderr)
|
||||
sys.exit(130)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
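
The `.env` parsing inside `cmd_migrate_from_env` can be exercised on its own. The sketch below pulls the same rules (skip blanks and comments, split on the first `=`, strip one pair of matching quotes) into a pure function. `parse_env_lines` is a hypothetical helper, not part of this commit. It also requires at least two characters before stripping quotes, a small deviation from the loop above.

```python
def parse_env_lines(lines):
    """Parse KEY=VALUE lines; return (pairs, skipped_line_numbers)."""
    pairs = {}
    skipped = []
    for line_num, raw in enumerate(lines, 1):
        line = raw.strip()
        # Skip empty lines and comments
        if not line or line.startswith("#"):
            continue
        # Lines without '=' are reported rather than silently dropped
        if "=" not in line:
            skipped.append(line_num)
            continue
        # Split on the first '=' only, so values may contain '='
        key, value = line.split("=", 1)
        key, value = key.strip(), value.strip()
        # Remove one pair of matching surrounding quotes, if present
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
            value = value[1:-1]
        pairs[key] = value
    return pairs, skipped


sample = ["# comment", "", 'API_KEY="abc=123"', "NAME = 'demo'", "garbage"]
pairs, skipped = parse_env_lines(sample)
print(pairs, skipped)
```

Keeping the parser free of I/O and store access makes it straightforward to unit test the edge cases the CLI comment mentions it does not fully handle.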
|
||||
144
backend/src/secrets_manager/crypto.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Cryptographic utilities for secrets management.
|
||||
|
||||
Uses Argon2id for password-based key derivation and Fernet for encryption.
|
||||
"""
|
||||
import os
|
||||
import secrets as secrets_module
|
||||
from typing import Tuple
|
||||
|
||||
from argon2 import PasswordHasher
|
||||
from argon2.low_level import hash_secret_raw, Type
|
||||
from cryptography.fernet import Fernet
|
||||
import base64
|
||||
|
||||
|
||||
# Argon2id parameters (these match the low-memory recommendation in RFC 9106 and exceed OWASP's minimums)
|
||||
# These provide strong defense against GPU/ASIC attacks
|
||||
ARGON2_TIME_COST = 3 # iterations
|
||||
ARGON2_MEMORY_COST = 65536 # 64 MiB (argon2-cffi takes this value in KiB)
|
||||
ARGON2_PARALLELISM = 4 # threads
|
||||
ARGON2_HASH_LENGTH = 32 # bytes (Fernet splits the 256-bit key into a 128-bit signing key and a 128-bit encryption key)
|
||||
ARGON2_SALT_LENGTH = 16 # bytes (128 bits)
|
||||
|
||||
|
||||
def generate_salt() -> bytes:
|
||||
"""Generate a cryptographically secure random salt."""
|
||||
return secrets_module.token_bytes(ARGON2_SALT_LENGTH)
|
||||
|
||||
|
||||
def derive_key_from_password(password: str, salt: bytes) -> bytes:
|
||||
"""
|
||||
Derive an encryption key from a password using Argon2id.
|
||||
|
||||
Args:
|
||||
password: The master password
|
||||
salt: The salt (must be consistent for the same password to work)
|
||||
|
||||
Returns:
|
||||
32-byte key suitable for Fernet encryption
|
||||
"""
|
||||
password_bytes = password.encode('utf-8')
|
||||
|
||||
# Use Argon2id (hybrid mode - best of Argon2i and Argon2d)
|
||||
raw_hash = hash_secret_raw(
|
||||
secret=password_bytes,
|
||||
salt=salt,
|
||||
time_cost=ARGON2_TIME_COST,
|
||||
memory_cost=ARGON2_MEMORY_COST,
|
||||
parallelism=ARGON2_PARALLELISM,
|
||||
hash_len=ARGON2_HASH_LENGTH,
|
||||
type=Type.ID # Argon2id
|
||||
)
|
||||
|
||||
return raw_hash
|
||||
|
||||
|
||||
def create_fernet(key: bytes) -> Fernet:
|
||||
"""
|
||||
Create a Fernet cipher instance from a raw key.
|
||||
|
||||
Args:
|
||||
key: 32-byte raw key from Argon2id
|
||||
|
||||
Returns:
|
||||
Fernet instance for encryption/decryption
|
||||
"""
|
||||
# Fernet requires a URL-safe base64-encoded 32-byte key
|
||||
fernet_key = base64.urlsafe_b64encode(key)
|
||||
return Fernet(fernet_key)
|
||||
|
||||
|
||||
def encrypt_data(data: bytes, key: bytes) -> bytes:
|
||||
"""
|
||||
Encrypt data using Fernet (AES-128-CBC with HMAC-SHA256 authentication).
|
||||
|
||||
Args:
|
||||
data: Raw bytes to encrypt
|
||||
key: 32-byte encryption key
|
||||
|
||||
Returns:
|
||||
Encrypted data (includes IV and auth tag)
|
||||
"""
|
||||
fernet = create_fernet(key)
|
||||
return fernet.encrypt(data)
|
||||
|
||||
|
||||
def decrypt_data(encrypted_data: bytes, key: bytes) -> bytes:
|
||||
"""
|
||||
Decrypt data using Fernet.
|
||||
|
||||
Args:
|
||||
encrypted_data: Encrypted bytes from encrypt_data
|
||||
key: 32-byte encryption key (must match encryption key)
|
||||
|
||||
Returns:
|
||||
Decrypted raw bytes
|
||||
|
||||
Raises:
|
||||
cryptography.fernet.InvalidToken: If decryption fails (wrong key/corrupted data)
|
||||
"""
|
||||
fernet = create_fernet(key)
|
||||
return fernet.decrypt(encrypted_data)
|
||||
|
||||
|
||||
def create_verification_hash(password: str, salt: bytes) -> str:
|
||||
"""
|
||||
Create a verification hash to check if a password is correct.
|
||||
|
||||
This is NOT for storing the password - it's for verifying the password
|
||||
unlocks the correct key without trying to decrypt the entire secrets file.
|
||||
|
||||
Args:
|
||||
password: The master password
|
||||
salt: The salt used for key derivation
|
||||
|
||||
Returns:
|
||||
Base64-encoded hash for verification
|
||||
"""
|
||||
# Derive key and hash it again for verification
|
||||
key = derive_key_from_password(password, salt)
|
||||
|
||||
# Store a SHA-256 digest of the derived key rather than raw key bytes:
# key[:16] is the half Fernet uses as its HMAC signing key, so writing it
# to the master key file in plaintext would expose it.
import hashlib
verification = base64.b64encode(hashlib.sha256(key).digest()).decode('ascii')
|
||||
|
||||
return verification
|
||||
|
||||
|
||||
def verify_password(password: str, salt: bytes, verification_hash: str) -> bool:
|
||||
"""
|
||||
Verify a password against a verification hash.
|
||||
|
||||
Args:
|
||||
password: Password to verify
|
||||
salt: Salt used for key derivation
|
||||
verification_hash: Expected verification hash
|
||||
|
||||
Returns:
|
||||
True if password is correct, False otherwise
|
||||
"""
|
||||
computed_hash = create_verification_hash(password, salt)
|
||||
|
||||
# Constant-time comparison to prevent timing attacks
|
||||
return secrets_module.compare_digest(computed_hash, verification_hash)
|
||||
406
backend/src/secrets_manager/store.py
Normal file
@@ -0,0 +1,406 @@
|
||||
"""
|
||||
Encrypted secrets store with master password protection.
|
||||
|
||||
The secrets are stored in an encrypted file, with the encryption key derived
|
||||
from a master password using Argon2id. Changing the master password derives a
new key from a fresh salt and re-encrypts the secrets file with it.
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import stat
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
from cryptography.fernet import InvalidToken
|
||||
|
||||
from .crypto import (
|
||||
generate_salt,
|
||||
derive_key_from_password,
|
||||
encrypt_data,
|
||||
decrypt_data,
|
||||
create_verification_hash,
|
||||
verify_password,
|
||||
)
|
||||
|
||||
|
||||
class SecretsStoreError(Exception):
|
||||
"""Base exception for secrets store errors."""
|
||||
pass
|
||||
|
||||
|
||||
class SecretsStoreLocked(SecretsStoreError):
|
||||
"""Raised when trying to access secrets while store is locked."""
|
||||
pass
|
||||
|
||||
|
||||
class InvalidMasterPassword(SecretsStoreError):
|
||||
"""Raised when master password is incorrect."""
|
||||
pass
|
||||
|
||||
|
||||
class SecretsStore:
|
||||
"""
|
||||
Encrypted secrets store with master password protection.
|
||||
|
||||
Usage:
|
||||
# Initialize (first time)
|
||||
store = SecretsStore()
|
||||
store.initialize("my-secure-password")
|
||||
|
||||
# Unlock
|
||||
store = SecretsStore()
|
||||
store.unlock("my-secure-password")
|
||||
|
||||
# Access secrets
|
||||
api_key = store.get("ANTHROPIC_API_KEY")
|
||||
store.set("NEW_SECRET", "secret-value")
|
||||
|
||||
# Change master password
|
||||
store.change_master_password("my-secure-password", "new-password")
|
||||
|
||||
# Lock when done
|
||||
store.lock()
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir: Optional[Path] = None):
|
||||
"""
|
||||
Initialize secrets store.
|
||||
|
||||
Args:
|
||||
data_dir: Directory for secrets files (defaults to backend/data)
|
||||
"""
|
||||
if data_dir is None:
|
||||
# Default to backend/data
|
||||
backend_root = Path(__file__).parent.parent.parent
|
||||
data_dir = backend_root / "data"
|
||||
|
||||
self.data_dir = Path(data_dir)
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.master_key_file = self.data_dir / ".master.key"
|
||||
self.secrets_file = self.data_dir / "secrets.enc"
|
||||
|
||||
# Runtime state
|
||||
self._encryption_key: Optional[bytes] = None
|
||||
self._secrets: Optional[Dict[str, Any]] = None
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if the secrets store has been initialized."""
|
||||
return self.master_key_file.exists()
|
||||
|
||||
@property
|
||||
def is_unlocked(self) -> bool:
|
||||
"""Check if the secrets store is currently unlocked."""
|
||||
return self._encryption_key is not None
|
||||
|
||||
def initialize(self, master_password: str) -> None:
|
||||
"""
|
||||
Initialize the secrets store with a master password.
|
||||
|
||||
This should only be called once when setting up the store.
|
||||
|
||||
Args:
|
||||
master_password: The master password to protect the secrets
|
||||
|
||||
Raises:
|
||||
SecretsStoreError: If store is already initialized
|
||||
"""
|
||||
if self.is_initialized:
|
||||
raise SecretsStoreError(
|
||||
"Secrets store is already initialized. "
|
||||
"Use unlock() to access it or change_master_password() to change the password."
|
||||
)
|
||||
|
||||
# Generate a new random salt
|
||||
salt = generate_salt()
|
||||
|
||||
# Derive encryption key
|
||||
encryption_key = derive_key_from_password(master_password, salt)
|
||||
|
||||
# Create verification hash
|
||||
verification_hash = create_verification_hash(master_password, salt)
|
||||
|
||||
# Store salt and verification hash
|
||||
master_key_data = {
|
||||
"salt": salt.hex(),
|
||||
"verification": verification_hash,
|
||||
}
|
||||
|
||||
self.master_key_file.write_text(json.dumps(master_key_data, indent=2))
|
||||
|
||||
# Set restrictive permissions (owner read/write only)
|
||||
os.chmod(self.master_key_file, stat.S_IRUSR | stat.S_IWUSR)
|
||||
|
||||
# Initialize empty secrets
|
||||
self._encryption_key = encryption_key
|
||||
self._secrets = {}
|
||||
self._save_secrets()
|
||||
|
||||
print(f"✓ Secrets store initialized at {self.secrets_file}")
|
||||
|
||||
def unlock(self, master_password: str) -> None:
|
||||
"""
|
||||
Unlock the secrets store with the master password.
|
||||
|
||||
Args:
|
||||
master_password: The master password
|
||||
|
||||
Raises:
|
||||
SecretsStoreError: If store is not initialized
|
||||
InvalidMasterPassword: If password is incorrect
|
||||
"""
|
||||
if not self.is_initialized:
|
||||
raise SecretsStoreError(
|
||||
"Secrets store is not initialized. Call initialize() first."
|
||||
)
|
||||
|
||||
# Load salt and verification hash
|
||||
master_key_data = json.loads(self.master_key_file.read_text())
|
||||
salt = bytes.fromhex(master_key_data["salt"])
|
||||
verification_hash = master_key_data["verification"]
|
||||
|
||||
# Verify password
|
||||
if not verify_password(master_password, salt, verification_hash):
|
||||
raise InvalidMasterPassword("Invalid master password")
|
||||
|
||||
# Derive encryption key
|
||||
encryption_key = derive_key_from_password(master_password, salt)
|
||||
|
||||
# Load and decrypt secrets
|
||||
if self.secrets_file.exists():
|
||||
try:
|
||||
encrypted_data = self.secrets_file.read_bytes()
|
||||
decrypted_data = decrypt_data(encrypted_data, encryption_key)
|
||||
self._secrets = json.loads(decrypted_data.decode('utf-8'))
|
||||
except InvalidToken:
|
||||
raise InvalidMasterPassword("Failed to decrypt secrets (invalid password)")
|
||||
except json.JSONDecodeError as e:
|
||||
raise SecretsStoreError(f"Corrupted secrets file: {e}")
|
||||
else:
|
||||
# No secrets file yet (fresh initialization)
|
||||
self._secrets = {}
|
||||
|
||||
self._encryption_key = encryption_key
|
||||
print(f"✓ Secrets store unlocked ({len(self._secrets)} secrets)")
|
||||
|
||||
def lock(self) -> None:
|
||||
"""Lock the secrets store (clear decrypted data from memory)."""
|
||||
self._encryption_key = None
|
||||
self._secrets = None
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Get a secret value.
|
||||
|
||||
Args:
|
||||
key: Secret key name
|
||||
default: Default value if key doesn't exist
|
||||
|
||||
Returns:
|
||||
Secret value or default
|
||||
|
||||
Raises:
|
||||
SecretsStoreLocked: If store is locked
|
||||
"""
|
||||
if not self.is_unlocked:
|
||||
raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")
|
||||
|
||||
return self._secrets.get(key, default)
|
||||
|
||||
def set(self, key: str, value: Any) -> None:
|
||||
"""
|
||||
Set a secret value.
|
||||
|
||||
Args:
|
||||
key: Secret key name
|
||||
value: Secret value (must be JSON-serializable)
|
||||
|
||||
Raises:
|
||||
SecretsStoreLocked: If store is locked
|
||||
"""
|
||||
if not self.is_unlocked:
|
||||
raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")
|
||||
|
||||
self._secrets[key] = value
|
||||
self._save_secrets()
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete a secret.
|
||||
|
||||
Args:
|
||||
key: Secret key name
|
||||
|
||||
Returns:
|
||||
True if secret existed and was deleted, False otherwise
|
||||
|
||||
Raises:
|
||||
SecretsStoreLocked: If store is locked
|
||||
"""
|
||||
if not self.is_unlocked:
|
||||
raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")
|
||||
|
||||
if key in self._secrets:
|
||||
del self._secrets[key]
|
||||
self._save_secrets()
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_keys(self) -> list[str]:
|
||||
"""
|
||||
List all secret keys.
|
||||
|
||||
Returns:
|
||||
List of secret keys
|
||||
|
||||
Raises:
|
||||
SecretsStoreLocked: If store is locked
|
||||
"""
|
||||
if not self.is_unlocked:
|
||||
raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")
|
||||
|
||||
return list(self._secrets.keys())
|
||||
|
||||
def change_master_password(self, current_password: str, new_password: str) -> None:
|
||||
"""
|
||||
Change the master password.
|
||||
|
||||
This re-encrypts the secrets with a new key derived from the new password.
|
||||
|
||||
Args:
|
||||
current_password: Current master password
|
||||
new_password: New master password
|
||||
|
||||
Raises:
|
||||
SecretsStoreError: If store is not initialized
InvalidMasterPassword: If current password is incorrect
|
||||
"""
|
||||
# ALWAYS verify current password before changing
|
||||
# Load salt and verification hash
|
||||
if not self.is_initialized:
|
||||
raise SecretsStoreError(
|
||||
"Secrets store is not initialized. Call initialize() first."
|
||||
)
|
||||
|
||||
master_key_data = json.loads(self.master_key_file.read_text())
|
||||
salt = bytes.fromhex(master_key_data["salt"])
|
||||
verification_hash = master_key_data["verification"]
|
||||
|
||||
# Verify current password is correct
|
||||
if not verify_password(current_password, salt, verification_hash):
|
||||
raise InvalidMasterPassword("Invalid current password")
|
||||
|
||||
# Unlock if needed to access secrets
|
||||
was_unlocked = self.is_unlocked
|
||||
if not was_unlocked:
|
||||
# Store is locked, so unlock with current password
|
||||
# (we already verified it above, so this will succeed)
|
||||
encryption_key = derive_key_from_password(current_password, salt)
|
||||
|
||||
# Load and decrypt secrets
|
||||
if self.secrets_file.exists():
|
||||
encrypted_data = self.secrets_file.read_bytes()
|
||||
decrypted_data = decrypt_data(encrypted_data, encryption_key)
|
||||
self._secrets = json.loads(decrypted_data.decode('utf-8'))
|
||||
else:
|
||||
self._secrets = {}
|
||||
|
||||
self._encryption_key = encryption_key
|
||||
|
||||
# Generate new salt
|
||||
new_salt = generate_salt()
|
||||
|
||||
# Derive new encryption key
|
||||
new_encryption_key = derive_key_from_password(new_password, new_salt)
|
||||
|
||||
# Create new verification hash
|
||||
new_verification_hash = create_verification_hash(new_password, new_salt)
|
||||
|
||||
# Update master key file
|
||||
master_key_data = {
|
||||
"salt": new_salt.hex(),
|
||||
"verification": new_verification_hash,
|
||||
}
|
||||
self.master_key_file.write_text(json.dumps(master_key_data, indent=2))
|
||||
os.chmod(self.master_key_file, stat.S_IRUSR | stat.S_IWUSR)
|
||||
|
||||
# Re-encrypt secrets with new key
|
||||
old_key = self._encryption_key
|
||||
self._encryption_key = new_encryption_key
|
||||
self._save_secrets()
|
||||
|
||||
print("✓ Master password changed successfully")
|
||||
|
||||
# Lock if it wasn't unlocked before
|
||||
if not was_unlocked:
|
||||
self.lock()
|
||||
|
||||
def _save_secrets(self) -> None:
|
||||
"""Save secrets to encrypted file."""
|
||||
if not self.is_unlocked:
|
||||
raise SecretsStoreLocked("Cannot save while locked")
|
||||
|
||||
# Serialize secrets to JSON
|
||||
secrets_json = json.dumps(self._secrets, indent=2)
|
||||
secrets_bytes = secrets_json.encode('utf-8')
|
||||
|
||||
# Encrypt
|
||||
encrypted_data = encrypt_data(secrets_bytes, self._encryption_key)
|
||||
|
||||
# Write to file
|
||||
self.secrets_file.write_bytes(encrypted_data)
|
||||
|
||||
# Set restrictive permissions
|
||||
os.chmod(self.secrets_file, stat.S_IRUSR | stat.S_IWUSR)
|
||||
|
||||
def export_encrypted(self, output_path: Path) -> None:
|
||||
"""
|
||||
Export encrypted secrets to a file (for backup).
|
||||
|
||||
Args:
|
||||
output_path: Path to export file
|
||||
|
||||
Raises:
|
||||
SecretsStoreError: If secrets file doesn't exist
|
||||
"""
|
||||
if not self.secrets_file.exists():
|
||||
raise SecretsStoreError("No secrets to export")
|
||||
|
||||
import shutil
|
||||
shutil.copy2(self.secrets_file, output_path)
|
||||
print(f"✓ Encrypted secrets exported to {output_path}")
|
||||
|
||||
def import_encrypted(self, input_path: Path, master_password: str) -> None:
|
||||
"""
|
||||
Import encrypted secrets from a file.
|
||||
|
||||
This will verify the password can decrypt the import before replacing
|
||||
the current secrets.
|
||||
|
||||
Args:
|
||||
input_path: Path to import file
|
||||
master_password: Master password for the current store
|
||||
|
||||
Raises:
|
||||
InvalidMasterPassword: If password doesn't work with import
|
||||
"""
|
||||
if not self.is_unlocked:
|
||||
self.unlock(master_password)
|
||||
|
||||
# Try to decrypt the imported file with current key
|
||||
try:
|
||||
encrypted_data = Path(input_path).read_bytes()
|
||||
decrypted_data = decrypt_data(encrypted_data, self._encryption_key)
|
||||
imported_secrets = json.loads(decrypted_data.decode('utf-8'))
|
||||
except InvalidToken:
|
||||
raise InvalidMasterPassword(
|
||||
"Cannot decrypt imported secrets with current master password"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise SecretsStoreError(f"Corrupted import file: {e}")
|
||||
|
||||
# Replace secrets
|
||||
self._secrets = imported_secrets
|
||||
self._save_secrets()
|
||||
|
||||
print(f"✓ Imported {len(self._secrets)} secrets from {input_path}")
|
||||
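Review note on `_save_secrets` and the master key writes above: each file is written first and restricted with `os.chmod` afterwards, so a newly created file briefly has default permissions, and an interrupted write can leave a truncated file. A minimal sketch of one alternative follows, assuming a POSIX filesystem. The helper name `write_private_atomic` is hypothetical and is not part of this commit.

```python
import os
import tempfile
from pathlib import Path


def write_private_atomic(path: Path, data: bytes) -> None:
    """Hypothetical helper: replace `path` with `data` via a private temp file."""
    path = Path(path)
    # mkstemp creates the file readable and writable only by the owner.
    # Using the target's directory keeps the final rename on one filesystem.
    fd, tmp_name = tempfile.mkstemp(dir=str(path.parent), prefix=path.name + ".")
    try:
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(data)
            tmp.flush()
            os.fsync(tmp.fileno())  # flush file contents to disk before renaming
        os.replace(tmp_name, path)  # swaps the new file into place in one step
    except BaseException:
        # Remove the temp file if anything failed before the swap.
        try:
            os.unlink(tmp_name)
        except FileNotFoundError:
            pass
        raise
```

With a helper like this, readers of `secrets.enc` would see either the old ciphertext or the new one, never a partial write, and no separate `os.chmod` call would be needed on POSIX.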
@@ -2,8 +2,10 @@
|
||||
Test script for CCXT DataSource adapter (Free Version).
|
||||
|
||||
This demonstrates how to use the free CCXT adapter (not ccxt.pro) with various
|
||||
-exchanges. It uses polling instead of WebSocket for real-time updates and
-verifies that Decimal precision is maintained throughout.
+exchanges. It uses polling instead of WebSocket for real-time updates.
+
+CCXT is configured to use Decimal mode internally for precision, but OHLCV data
+is converted to float for optimal DataFrame/analysis performance.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -73,10 +75,10 @@ async def test_binance_datasource():
|
||||
print(f" Close: {latest.data['close']} (type: {type(latest.data['close']).__name__})")
|
||||
print(f" Volume: {latest.data['volume']} (type: {type(latest.data['volume']).__name__})")
|
||||
|
||||
-# Verify Decimal precision
-assert isinstance(latest.data['close'], Decimal), "Price should be Decimal type!"
-assert isinstance(latest.data['volume'], Decimal), "Volume should be Decimal type!"
-print(f" ✓ Numerical precision verified: using Decimal types")
+# Verify OHLCV uses float (converted from Decimal for DataFrame performance)
+assert isinstance(latest.data['close'], float), "OHLCV price should be float type!"
+assert isinstance(latest.data['volume'], float), "OHLCV volume should be float type!"
+print(f" ✓ OHLCV data type verified: using native float (CCXT uses Decimal internally)")
|
||||
|
||||
# Test 5: Polling subscription (brief test)
|
||||
print("\n5. Testing polling-based subscription...")
|
||||
@@ -87,7 +89,7 @@ async def test_binance_datasource():
|
||||
tick_count[0] += 1
|
||||
if tick_count[0] == 1:
|
||||
print(f" Received tick: close={data['close']} (type: {type(data['close']).__name__})")
|
||||
-assert isinstance(data['close'], Decimal), "Polled data should use Decimal!"
+assert isinstance(data['close'], float), "Polled OHLCV data should use float!"
|
||||
|
||||
subscription_id = await binance.subscribe_bars(
|
||||
symbol="BTC/USDT",
|
||||
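Review note on the Decimal-to-float change in this test: the adapter code that performs the conversion is not shown in this diff, so the following is only a sketch of what such a step might look like. The helper name `ohlcv_row_to_floats` and the `[timestamp, open, high, low, close, volume]` row layout are assumptions for illustration.

```python
from decimal import Decimal
from typing import List, Sequence, Union

Number = Union[Decimal, int, float, str]


def ohlcv_row_to_floats(row: Sequence[Number]) -> List[Union[int, float]]:
    """Hypothetical helper: convert one OHLCV row to an int timestamp plus floats."""
    timestamp, *values = row
    if len(values) != 5:
        raise ValueError("expected [timestamp, open, high, low, close, volume]")
    # Keep the timestamp integral; convert the five price/volume fields to float.
    return [int(timestamp)] + [float(v) for v in values]


row = [1700000000000, Decimal("42000.10"), Decimal("42100.00"),
       Decimal("41950.55"), Decimal("42050.25"), Decimal("12.345")]
converted = ohlcv_row_to_floats(row)
assert all(isinstance(v, float) for v in converted[1:])
```

The trade-off is the one the updated docstring names: floats are cheaper to hold in DataFrames, while exact decimal arithmetic is given up for these fields.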
|
||||