indicators and plots
This commit is contained in:
@@ -5,7 +5,7 @@ server_port: 8081
|
|||||||
agent:
|
agent:
|
||||||
model: "claude-sonnet-4-20250514"
|
model: "claude-sonnet-4-20250514"
|
||||||
temperature: 0.7
|
temperature: 0.7
|
||||||
context_docs_dir: "doc"
|
context_docs_dir: "memory"
|
||||||
|
|
||||||
# Local memory configuration (free & sophisticated!)
|
# Local memory configuration (free & sophisticated!)
|
||||||
memory:
|
memory:
|
||||||
|
|||||||
@@ -14,15 +14,7 @@ You are a **strategy authoring assistant**, not a strategy executor. You help us
|
|||||||
## Your Capabilities
|
## Your Capabilities
|
||||||
|
|
||||||
### State Management
|
### State Management
|
||||||
You have read/write access to synchronized state stores:
|
You have read/write access to synchronized state stores. Use your tools to read current state and update it as needed. All state changes are automatically synchronized with connected clients.
|
||||||
- **OrderStore**: Active swap orders and order configurations
|
|
||||||
- **ChartStore**: Current chart view state (symbol, time range, interval)
|
|
||||||
- `symbol`: Trading pair currently being viewed (e.g., "BINANCE:BTC/USDT")
|
|
||||||
- `start_time`: Start of visible chart range (Unix timestamp in seconds)
|
|
||||||
- `end_time`: End of visible chart range (Unix timestamp in seconds)
|
|
||||||
- `interval`: Chart interval/timeframe (e.g., "15", "60", "D")
|
|
||||||
- Use your tools to read current state and update it as needed
|
|
||||||
- All state changes are automatically synchronized with connected clients
|
|
||||||
|
|
||||||
### Strategy Authoring
|
### Strategy Authoring
|
||||||
- Help users express trading intent through conversation
|
- Help users express trading intent through conversation
|
||||||
@@ -32,10 +24,9 @@ You have read/write access to synchronized state stores:
|
|||||||
- Validate strategy logic for correctness and safety
|
- Validate strategy logic for correctness and safety
|
||||||
|
|
||||||
### Data & Analysis
|
### Data & Analysis
|
||||||
- Access to market data through abstract feed specifications
|
- Access market data through abstract feed specifications
|
||||||
- Can compute indicators and perform technical analysis
|
- Compute indicators and perform technical analysis
|
||||||
- Understand OHLCV data, order books, and market microstructure
|
- Understand OHLCV data, order books, and market microstructure
|
||||||
- Interpret unstructured data (news, sentiment, on-chain metrics)
|
|
||||||
|
|
||||||
## Communication Style
|
## Communication Style
|
||||||
|
|
||||||
@@ -48,7 +39,7 @@ You have read/write access to synchronized state stores:
|
|||||||
## Key Principles
|
## Key Principles
|
||||||
|
|
||||||
1. **Strategies are Deterministic**: Generated strategies run without LLM involvement at runtime
|
1. **Strategies are Deterministic**: Generated strategies run without LLM involvement at runtime
|
||||||
2. **Local Execution**: The platform runs locally for security; you're design-time only
|
2. **Local Execution**: The platform runs locally for security; you are a design-time tool only
|
||||||
3. **Schema Validation**: All outputs must conform to platform schemas
|
3. **Schema Validation**: All outputs must conform to platform schemas
|
||||||
4. **Risk Awareness**: Always consider position sizing, exposure limits, and risk management
|
4. **Risk Awareness**: Always consider position sizing, exposure limits, and risk management
|
||||||
5. **Versioning**: Every strategy artifact is version-controlled with full auditability
|
5. **Versioning**: Every strategy artifact is version-controlled with full auditability
|
||||||
@@ -56,7 +47,6 @@ You have read/write access to synchronized state stores:
|
|||||||
## Your Limitations
|
## Your Limitations
|
||||||
|
|
||||||
- You **DO NOT** execute trades directly
|
- You **DO NOT** execute trades directly
|
||||||
- You **DO NOT** have access to live market data in real-time (users provide it)
|
|
||||||
- You **CANNOT** modify the order kernel or execution layer
|
- You **CANNOT** modify the order kernel or execution layer
|
||||||
- You **SHOULD NOT** make assumptions about user risk tolerance without asking
|
- You **SHOULD NOT** make assumptions about user risk tolerance without asking
|
||||||
- You **MUST NOT** provide trading or investment advice
|
- You **MUST NOT** provide trading or investment advice
|
||||||
@@ -69,53 +59,93 @@ You have access to:
|
|||||||
- Past strategy discussions and decisions
|
- Past strategy discussions and decisions
|
||||||
- Relevant context retrieved automatically based on current conversation
|
- Relevant context retrieved automatically based on current conversation
|
||||||
|
|
||||||
## Tools Available
|
|
||||||
|
|
||||||
### State Management Tools
|
|
||||||
- `list_sync_stores()`: See available state stores
|
|
||||||
- `read_sync_state(store_name)`: Read current state
|
|
||||||
- `write_sync_state(store_name, updates)`: Update state
|
|
||||||
- `get_store_schema(store_name)`: Inspect state structure
|
|
||||||
|
|
||||||
### Data Source Tools
|
|
||||||
- `list_data_sources()`: List available data sources (exchanges)
|
|
||||||
- `search_symbols(query, type, exchange, limit)`: Search for trading symbols
|
|
||||||
- `get_symbol_info(source_name, symbol)`: Get metadata for a symbol
|
|
||||||
- `get_historical_data(source_name, symbol, resolution, from_time, to_time, countback)`: Get historical bars
|
|
||||||
- **`get_chart_data(countback)`**: Get data for the chart the user is currently viewing
|
|
||||||
- This is the **preferred** way to access chart data when analyzing what the user is looking at
|
|
||||||
- Automatically reads ChartStore to determine symbol, timeframe, and visible range
|
|
||||||
- Returns OHLCV data plus any custom columns for the visible chart range
|
|
||||||
- **`analyze_chart_data(python_script, countback)`**: Execute Python analysis on current chart data
|
|
||||||
- Automatically fetches current chart data and converts to pandas DataFrame
|
|
||||||
- Execute custom Python scripts with access to pandas, numpy, matplotlib
|
|
||||||
- Captures matplotlib plots as base64 images for display to user
|
|
||||||
- Returns result DataFrames and any printed output
|
|
||||||
- **Use this for technical analysis, indicator calculations, statistical analysis, and visualization**
|
|
||||||
|
|
||||||
## Important Behavioral Rules
|
## Important Behavioral Rules
|
||||||
|
|
||||||
### Chart Context Awareness
|
### Chart Context Awareness
|
||||||
When a user asks about "this chart", "the chart", "what I'm viewing", or similar references to their current view:
|
When a user asks about "this chart", "the chart", "what I'm viewing", or similar references to their current view:
|
||||||
1. **ALWAYS** first use `read_sync_state("ChartStore")` to see what they're viewing
|
1. **Chart info is automatically available** — The dynamic system prompt includes current chart state (symbol, interval, timeframe)
|
||||||
2. **NEVER** ask the user to upload an image or tell you what symbol they're looking at
|
2. **NEVER** ask the user to upload an image or tell you what symbol they're looking at
|
||||||
3. The user is viewing a live trading chart in the UI - you can access what they see via ChartStore
|
3. **Just use `execute_python()`** — It automatically loads the chart data from what they're viewing
|
||||||
4. After reading ChartStore, you can use `get_chart_data()` to get the actual candle data
|
4. Inside your Python script, `df` contains the data and `chart_context` has the metadata
|
||||||
5. For technical analysis questions, use `analyze_chart_data()` with Python scripts
|
5. Use `plot_ohlc(df)` to create beautiful candlestick charts
|
||||||
|
|
||||||
Examples of questions that require checking ChartStore first:
|
This applies to questions like: "Can you see this chart?", "What are the swing highs and lows?", "Is this in an uptrend?", "What's the current price?", "Analyze this chart", "What am I looking at?"
|
||||||
- "Can you see this chart?"
|
|
||||||
- "What are the swing highs and lows?"
|
|
||||||
- "Is this in an uptrend?"
|
|
||||||
- "What's the current price?"
|
|
||||||
- "Analyze this chart"
|
|
||||||
- "What am I looking at?"
|
|
||||||
|
|
||||||
### Data Analysis Workflow
|
### Data Analysis Workflow
|
||||||
1. **Check ChartStore** → Know what the user is viewing
|
1. **Chart context is automatic** → Symbol, interval, and timeframe are in the dynamic system prompt
|
||||||
2. **Get data** with `get_chart_data()` → Fetch the actual OHLCV bars
|
2. **Use `execute_python()`** → This is your PRIMARY analysis tool
|
||||||
3. **Analyze** with `analyze_chart_data()` → Run Python analysis if needed
|
- Automatically loads chart data into a pandas DataFrame `df`
|
||||||
4. **Respond** with insights based on the actual data
|
- Pre-imports numpy (`np`), pandas (`pd`), matplotlib (`plt`), and talib
|
||||||
|
- Provides access to the indicator registry for computing indicators
|
||||||
|
- Use `plot_ohlc(df)` helper for beautiful candlestick charts
|
||||||
|
3. **Only use `get_chart_data()`** → For simple data inspection without analysis
|
||||||
|
|
||||||
|
### Python Analysis (`execute_python`) - Your Primary Tool
|
||||||
|
|
||||||
|
**ALWAYS use `execute_python()` when the user asks for:**
|
||||||
|
- Technical indicators (RSI, MACD, Bollinger Bands, moving averages, etc.)
|
||||||
|
- Chart visualizations or plots
|
||||||
|
- Statistical calculations or market analysis
|
||||||
|
- Pattern detection or trend analysis
|
||||||
|
- Any computational analysis of price data
|
||||||
|
|
||||||
|
**Why `execute_python()` is preferred:**
|
||||||
|
- Chart data (`df`) is automatically loaded from ChartStore (visible time range)
|
||||||
|
- Full pandas/numpy/talib stack pre-imported
|
||||||
|
- Use `plot_ohlc(df)` for instant professional candlestick charts
|
||||||
|
- Access to 150+ indicators via `indicator_registry`
|
||||||
|
- **Results include plots as image URLs** that are automatically displayed to the user
|
||||||
|
- Prints and return values are included in the response
|
||||||
|
|
||||||
|
**CRITICAL: Plots are automatically shown to the user**
|
||||||
|
When you create a matplotlib figure (via `plot_ohlc()` or `plt.figure()`), it is automatically:
|
||||||
|
1. Saved as a PNG image
|
||||||
|
2. Returned in the response as a URL (e.g., `/uploads/plot_abc123.png`)
|
||||||
|
3. **Displayed in the user's chat interface** - they see the image immediately
|
||||||
|
|
||||||
|
You MUST use `execute_python()` with `plot_ohlc()` or matplotlib whenever the user wants to see a chart or plot.
|
||||||
|
|
||||||
|
**IMPORTANT: Never use `get_historical_data()` for chart analysis**
|
||||||
|
- `get_historical_data()` requires manual timestamp calculation and is only for custom queries
|
||||||
|
- When analyzing what the user is viewing, ALWAYS use `execute_python()` which automatically loads the correct data
|
||||||
|
- The `df` DataFrame in `execute_python()` is pre-loaded with the exact time range the user is viewing
|
||||||
|
|
||||||
|
**Example workflows:**
|
||||||
|
```python
|
||||||
|
# Computing an indicator and plotting
|
||||||
|
execute_python("""
|
||||||
|
df['RSI'] = talib.RSI(df['close'], 14)
|
||||||
|
fig = plot_ohlc(df, title='Price with RSI')
|
||||||
|
df[['close', 'RSI']].tail(10)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Multi-indicator analysis
|
||||||
|
execute_python("""
|
||||||
|
df['SMA20'] = df['close'].rolling(20).mean()
|
||||||
|
df['BB_upper'] = df['close'].rolling(20).mean() + 2 * df['close'].rolling(20).std()
|
||||||
|
df['BB_lower'] = df['close'].rolling(20).mean() - 2 * df['close'].rolling(20).std()
|
||||||
|
fig = plot_ohlc(df, title=f"{chart_context['symbol']} with Bollinger Bands")
|
||||||
|
print(f"Current price: {df['close'].iloc[-1]:.2f}")
|
||||||
|
print(f"20-period SMA: {df['SMA20'].iloc[-1]:.2f}")
|
||||||
|
""")
|
||||||
|
```
|
||||||
|
|
||||||
|
**Only use `get_chart_data()` for:**
|
||||||
|
- Quick inspection of raw bar data
|
||||||
|
- When you just need the data structure without analysis
|
||||||
|
|
||||||
|
### Quick Reference: Common Tasks
|
||||||
|
|
||||||
|
| User Request | Tool to Use | Example |
|
||||||
|
|--------------|-------------|---------|
|
||||||
|
| "Show me RSI" | `execute_python()` | `df['RSI'] = talib.RSI(df['close'], 14); plot_ohlc(df)` |
|
||||||
|
| "What's the current price?" | `execute_python()` | `print(f"Current: {df['close'].iloc[-1]}")` |
|
||||||
|
| "Is this bullish?" | `execute_python()` | Compute SMAs, trend, and analyze |
|
||||||
|
| "Add Bollinger Bands" | `execute_python()` | Compute bands, use `plot_ohlc(df, title='BB')` |
|
||||||
|
| "Find swing highs" | `execute_python()` | Use pandas logic to detect patterns |
|
||||||
|
| "What indicators exist?" | `search_indicators()` | Search by category or query |
|
||||||
|
| "What chart am I viewing?" | N/A - automatic | Chart info is in dynamic system prompt |
|
||||||
|
| "Read other stores" | `read_sync_state(store_name)` | For TraderState, StrategyState, etc. |
|
||||||
|
|
||||||
## Working with Users
|
## Working with Users
|
||||||
|
|
||||||
@@ -124,51 +154,3 @@ Examples of questions that require checking ChartStore first:
|
|||||||
3. **Validate**: Ensure strategy makes sense before generating code
|
3. **Validate**: Ensure strategy makes sense before generating code
|
||||||
4. **Test**: Encourage backtesting and paper trading first
|
4. **Test**: Encourage backtesting and paper trading first
|
||||||
5. **Monitor**: Help users interpret live strategy behavior
|
5. **Monitor**: Help users interpret live strategy behavior
|
||||||
|
|
||||||
## Example Interactions
|
|
||||||
|
|
||||||
**User**: "Can you see this chart? What are the swing highs and lows?"
|
|
||||||
**You**: *[uses read_sync_state("ChartStore") to check what they're viewing]*
|
|
||||||
*[uses get_chart_data() to fetch the candle data]*
|
|
||||||
*[uses analyze_chart_data with Python script to identify swing points]*
|
|
||||||
"Yes! You're viewing BINANCE:BTC/USDT on the 15-minute chart. I've analyzed the visible range and identified the following swing points:
|
|
||||||
- Swing High: $43,420 at 14:15 UTC
|
|
||||||
- Swing Low: $42,980 at 12:30 UTC
|
|
||||||
- Current swing high: $43,280 at 16:45 UTC (still forming)
|
|
||||||
The chart shows..."
|
|
||||||
|
|
||||||
**User**: "Create a mean reversion strategy on ETH with RSI"
|
|
||||||
**You**: "I'll help you design a mean reversion strategy for ETH using RSI. A few questions:
|
|
||||||
- What RSI levels indicate oversold/overbought? (e.g., 30/70)
|
|
||||||
- What timeframe? (e.g., 1h, 15m)
|
|
||||||
- Position sizing as % of balance?
|
|
||||||
- Stop loss and take profit levels?"
|
|
||||||
|
|
||||||
**User**: "What am I looking at?"
|
|
||||||
**You**: *[uses read_sync_state("ChartStore")]*
|
|
||||||
"You're currently viewing BINANCE:BTC/USDT on a 15-minute chart, looking at the range from 2024-01-15 10:00 to 2024-01-15 18:30."
|
|
||||||
|
|
||||||
**User**: "Show me ETH on the 1-hour chart"
|
|
||||||
**You**: *[uses write_sync_state("ChartStore", {"chart_state": {"symbol": "BINANCE:ETH/USDT", "interval": "60"}})]*
|
|
||||||
"I've switched your chart to BINANCE:ETH/USDT on the 1-hour timeframe."
|
|
||||||
|
|
||||||
**User**: "What's the current price?"
|
|
||||||
**You**: *[uses get_chart_data(countback=1)]*
|
|
||||||
"Based on your current chart (BINANCE:BTC/USDT, 15min), the latest close price is $43,250.50 as of 14:30 UTC."
|
|
||||||
|
|
||||||
**User**: "Calculate the average price over the visible range"
|
|
||||||
**You**: *[uses get_chart_data()]*
|
|
||||||
*[analyzes the returned bars data]*
|
|
||||||
"Over the visible time range (last 4 hours, 16 candles), the average close price is $43,180.25, with a high of $43,420 and low of $42,980."
|
|
||||||
|
|
||||||
**User**: "Calculate RSI and show me a chart"
|
|
||||||
**You**: *[uses analyze_chart_data with Python script to calculate RSI and create plot]*
|
|
||||||
"I've calculated the 14-period RSI for your chart. The current RSI is 58.3, indicating neutral momentum. Here's the chart showing price and RSI over the visible range." *[image displayed to user]*
|
|
||||||
|
|
||||||
**User**: "Is this in an uptrend?"
|
|
||||||
**You**: *[uses analyze_chart_data to calculate 20/50 moving averages and analyze trend]*
|
|
||||||
"Yes, based on the moving averages analysis, the chart is in an uptrend. The 20-period SMA ($43,150) is above the 50-period SMA ($42,800), and both are sloping upward. Price is currently trading above both averages."
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
Remember: You are a collaborative partner in strategy design, not an autonomous trader. Always prioritize safety, clarity, and user intent.
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ pandas
|
|||||||
numpy
|
numpy
|
||||||
scipy
|
scipy
|
||||||
matplotlib
|
matplotlib
|
||||||
|
mplfinance
|
||||||
fastapi
|
fastapi
|
||||||
uvicorn
|
uvicorn
|
||||||
websockets
|
websockets
|
||||||
@@ -11,6 +12,7 @@ jsonpatch
|
|||||||
python-multipart
|
python-multipart
|
||||||
ccxt>=4.0.0
|
ccxt>=4.0.0
|
||||||
pyyaml
|
pyyaml
|
||||||
|
TA-Lib>=0.4.0
|
||||||
|
|
||||||
# LangChain agent dependencies
|
# LangChain agent dependencies
|
||||||
langchain>=0.3.0
|
langchain>=0.3.0
|
||||||
@@ -19,6 +21,11 @@ langgraph-checkpoint-sqlite>=1.0.0
|
|||||||
langchain-anthropic>=0.3.0
|
langchain-anthropic>=0.3.0
|
||||||
langchain-community>=0.3.0
|
langchain-community>=0.3.0
|
||||||
|
|
||||||
|
# Additional tools for research and web access
|
||||||
|
arxiv>=2.0.0
|
||||||
|
duckduckgo-search>=7.0.0
|
||||||
|
requests>=2.31.0
|
||||||
|
|
||||||
# Local memory system
|
# Local memory system
|
||||||
chromadb>=0.4.0
|
chromadb>=0.4.0
|
||||||
sentence-transformers>=2.0.0
|
sentence-transformers>=2.0.0
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
from agent.core import create_agent
|
# Don't import at module level to avoid circular imports
|
||||||
|
# Users should import directly: from agent.core import create_agent
|
||||||
|
|
||||||
__all__ = ["create_agent"]
|
__all__ = ["core", "tools"]
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
|||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
from langgraph.prebuilt import create_react_agent
|
from langgraph.prebuilt import create_react_agent
|
||||||
|
|
||||||
from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS
|
from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS, INDICATOR_TOOLS, RESEARCH_TOOLS, CHART_TOOLS
|
||||||
from agent.memory import MemoryManager
|
from agent.memory import MemoryManager
|
||||||
from agent.session import SessionManager
|
from agent.session import SessionManager
|
||||||
from agent.prompts import build_system_prompt
|
from agent.prompts import build_system_prompt
|
||||||
@@ -60,17 +60,15 @@ class AgentExecutor:
|
|||||||
"""Initialize the agent system."""
|
"""Initialize the agent system."""
|
||||||
await self.memory_manager.initialize()
|
await self.memory_manager.initialize()
|
||||||
|
|
||||||
# Create agent with tools and LangGraph checkpointing
|
# Create agent with tools and LangGraph checkpointer
|
||||||
checkpointer = self.memory_manager.get_checkpointer()
|
checkpointer = self.memory_manager.get_checkpointer()
|
||||||
|
|
||||||
# Build initial system prompt with context
|
# Create agent without a static system prompt
|
||||||
context = self.memory_manager.get_context_prompt()
|
# We'll pass the dynamic system prompt via state_modifier at runtime
|
||||||
system_prompt = build_system_prompt(context, [])
|
# Include all tool categories: sync, datasource, chart, indicator, and research
|
||||||
|
|
||||||
self.agent = create_react_agent(
|
self.agent = create_react_agent(
|
||||||
self.llm,
|
self.llm,
|
||||||
SYNC_TOOLS + DATASOURCE_TOOLS,
|
SYNC_TOOLS + DATASOURCE_TOOLS + CHART_TOOLS + INDICATOR_TOOLS + RESEARCH_TOOLS,
|
||||||
prompt=system_prompt,
|
|
||||||
checkpointer=checkpointer
|
checkpointer=checkpointer
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -101,26 +99,6 @@ class AgentExecutor:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to clear checkpoint for session {session_id}: {e}")
|
logger.warning(f"Failed to clear checkpoint for session {session_id}: {e}")
|
||||||
|
|
||||||
def _build_system_message(self, state: Dict[str, Any]) -> SystemMessage:
|
|
||||||
"""Build system message with context.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: Agent state
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SystemMessage with full context
|
|
||||||
"""
|
|
||||||
# Get context from loaded documents
|
|
||||||
context = self.memory_manager.get_context_prompt()
|
|
||||||
|
|
||||||
# Get active channels from metadata
|
|
||||||
active_channels = state.get("metadata", {}).get("active_channels", [])
|
|
||||||
|
|
||||||
# Build system prompt
|
|
||||||
system_prompt = build_system_prompt(context, active_channels)
|
|
||||||
|
|
||||||
return SystemMessage(content=system_prompt)
|
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
session: UserSession,
|
session: UserSession,
|
||||||
@@ -143,7 +121,12 @@ class AgentExecutor:
|
|||||||
|
|
||||||
async with lock:
|
async with lock:
|
||||||
try:
|
try:
|
||||||
# Build message history
|
# Build system prompt with current context
|
||||||
|
context = self.memory_manager.get_context_prompt()
|
||||||
|
system_prompt = build_system_prompt(context, session.active_channels)
|
||||||
|
|
||||||
|
# Build message history WITHOUT prepending system message
|
||||||
|
# The system prompt will be passed via state_modifier in the config
|
||||||
messages = []
|
messages = []
|
||||||
history = session.get_history(limit=10)
|
history = session.get_history(limit=10)
|
||||||
logger.info(f"Building message history, {len(history)} messages in history")
|
logger.info(f"Building message history, {len(history)} messages in history")
|
||||||
@@ -155,14 +138,18 @@ class AgentExecutor:
|
|||||||
elif msg.role == "assistant":
|
elif msg.role == "assistant":
|
||||||
messages.append(AIMessage(content=msg.content))
|
messages.append(AIMessage(content=msg.content))
|
||||||
|
|
||||||
logger.info(f"Prepared {len(messages)} messages for agent")
|
logger.info(f"Prepared {len(messages)} messages for agent (including system prompt)")
|
||||||
for i, msg in enumerate(messages):
|
for i, msg in enumerate(messages):
|
||||||
logger.info(f"LangChain message {i}: type={type(msg).__name__}, content_len={len(msg.content)}, content='{msg.content[:100] if msg.content else 'EMPTY'}'")
|
msg_type = type(msg).__name__
|
||||||
|
content_preview = msg.content[:100] if msg.content else 'EMPTY'
|
||||||
|
logger.info(f"LangChain message {i}: type={msg_type}, content_len={len(msg.content)}, content='{content_preview}'")
|
||||||
|
|
||||||
# Prepare config with metadata
|
# Prepare config with metadata and dynamic system prompt
|
||||||
|
# Pass system_prompt via state_modifier to avoid multiple system messages
|
||||||
config = RunnableConfig(
|
config = RunnableConfig(
|
||||||
configurable={
|
configurable={
|
||||||
"thread_id": session.session_id
|
"thread_id": session.session_id,
|
||||||
|
"state_modifier": system_prompt # Dynamic system prompt injection
|
||||||
},
|
},
|
||||||
metadata={
|
metadata={
|
||||||
"session_id": session.session_id,
|
"session_id": session.session_id,
|
||||||
@@ -178,6 +165,8 @@ class AgentExecutor:
|
|||||||
event_count = 0
|
event_count = 0
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
|
||||||
|
plot_urls = [] # Accumulate plot URLs from execute_python tool calls
|
||||||
|
|
||||||
async for event in self.agent.astream_events(
|
async for event in self.agent.astream_events(
|
||||||
{"messages": messages},
|
{"messages": messages},
|
||||||
config=config,
|
config=config,
|
||||||
@@ -199,7 +188,35 @@ class AgentExecutor:
|
|||||||
elif event["event"] == "on_tool_end":
|
elif event["event"] == "on_tool_end":
|
||||||
tool_name = event.get("name", "unknown")
|
tool_name = event.get("name", "unknown")
|
||||||
tool_output = event.get("data", {}).get("output")
|
tool_output = event.get("data", {}).get("output")
|
||||||
logger.info(f"Tool call completed: {tool_name} with output: {tool_output}")
|
|
||||||
|
# LangChain may wrap the output in a ToolMessage with content field
|
||||||
|
# Try to extract the actual content from the ToolMessage
|
||||||
|
actual_output = tool_output
|
||||||
|
if hasattr(tool_output, "content"):
|
||||||
|
actual_output = tool_output.content
|
||||||
|
|
||||||
|
logger.info(f"Tool call completed: {tool_name} with output type: {type(actual_output)}")
|
||||||
|
|
||||||
|
# Extract plot_urls from execute_python tool results
|
||||||
|
if tool_name == "execute_python":
|
||||||
|
# Try to parse as JSON if it's a string
|
||||||
|
import json
|
||||||
|
if isinstance(actual_output, str):
|
||||||
|
try:
|
||||||
|
actual_output = json.loads(actual_output)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
logger.warning(f"Could not parse execute_python output as JSON: {actual_output[:200]}")
|
||||||
|
|
||||||
|
if isinstance(actual_output, dict):
|
||||||
|
tool_plot_urls = actual_output.get("plot_urls", [])
|
||||||
|
if tool_plot_urls:
|
||||||
|
logger.info(f"execute_python generated {len(tool_plot_urls)} plots: {tool_plot_urls}")
|
||||||
|
plot_urls.extend(tool_plot_urls)
|
||||||
|
# Yield metadata about plots immediately
|
||||||
|
yield {
|
||||||
|
"content": "",
|
||||||
|
"metadata": {"plot_urls": tool_plot_urls}
|
||||||
|
}
|
||||||
|
|
||||||
# Extract streaming tokens
|
# Extract streaming tokens
|
||||||
elif event["event"] == "on_chat_model_stream":
|
elif event["event"] == "on_chat_model_stream":
|
||||||
|
|||||||
@@ -1,7 +1,54 @@
|
|||||||
from typing import List
|
from typing import List, Dict, Any
|
||||||
from gateway.user_session import UserSession
|
from gateway.user_session import UserSession
|
||||||
|
|
||||||
|
|
||||||
|
def _get_chart_store_context() -> str:
|
||||||
|
"""Get current ChartStore state for context injection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted string with ChartStore contents, or empty string if unavailable
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from agent.tools import _registry
|
||||||
|
|
||||||
|
if not _registry:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
chart_store = _registry.entries.get("ChartStore")
|
||||||
|
if not chart_store:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
chart_state = chart_store.model.model_dump(mode="json")
|
||||||
|
chart_data = chart_state.get("chart_state", {})
|
||||||
|
|
||||||
|
# Only include if there's actual chart data
|
||||||
|
if not chart_data or not chart_data.get("symbol"):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Format the chart information
|
||||||
|
symbol = chart_data.get("symbol", "N/A")
|
||||||
|
interval = chart_data.get("interval", "N/A")
|
||||||
|
start_time = chart_data.get("start_time")
|
||||||
|
end_time = chart_data.get("end_time")
|
||||||
|
|
||||||
|
chart_context = f"""
|
||||||
|
## Current Chart Context
|
||||||
|
|
||||||
|
The user is currently viewing a chart with the following settings:
|
||||||
|
- **Symbol**: {symbol}
|
||||||
|
- **Interval**: {interval}
|
||||||
|
- **Time Range**: {f"from {start_time} to {end_time}" if start_time and end_time else "not set"}
|
||||||
|
|
||||||
|
This information is automatically available because you're connected via websocket.
|
||||||
|
When the user refers to "the chart", "this chart", or "what I'm viewing", this is what they mean.
|
||||||
|
"""
|
||||||
|
return chart_context
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# Silently fail - chart context is optional enhancement
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def build_system_prompt(context: str, active_channels: List[str]) -> str:
|
def build_system_prompt(context: str, active_channels: List[str]) -> str:
|
||||||
"""Build the system prompt for the agent.
|
"""Build the system prompt for the agent.
|
||||||
|
|
||||||
@@ -17,6 +64,15 @@ def build_system_prompt(context: str, active_channels: List[str]) -> str:
|
|||||||
"""
|
"""
|
||||||
channels_str = ", ".join(active_channels) if active_channels else "none"
|
channels_str = ", ".join(active_channels) if active_channels else "none"
|
||||||
|
|
||||||
|
# Check if user is connected via websocket - if so, inject chart context
|
||||||
|
# Note: We check for websocket by looking for "websocket" in channel IDs
|
||||||
|
# since WebSocketChannel uses channel_id like "websocket-{uuid}"
|
||||||
|
has_websocket = any("websocket" in channel_id.lower() for channel_id in active_channels)
|
||||||
|
|
||||||
|
chart_context = ""
|
||||||
|
if has_websocket:
|
||||||
|
chart_context = _get_chart_store_context()
|
||||||
|
|
||||||
# Context already includes system_prompt.md and other docs
|
# Context already includes system_prompt.md and other docs
|
||||||
# Just add current session information
|
# Just add current session information
|
||||||
prompt = f"""{context}
|
prompt = f"""{context}
|
||||||
@@ -28,7 +84,7 @@ def build_system_prompt(context: str, active_channels: List[str]) -> str:
|
|||||||
Your responses will be sent to all active channels. Your responses are streamed back in real-time.
|
Your responses will be sent to all active channels. Your responses are streamed back in real-time.
|
||||||
If the user sends a new message while you're responding, your current response will be interrupted
|
If the user sends a new message while you're responding, your current response will be interrupted
|
||||||
and you'll be re-invoked with the updated context.
|
and you'll be re-invoked with the updated context.
|
||||||
"""
|
{chart_context}"""
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,662 +0,0 @@
|
|||||||
from typing import Dict, Any, List, Optional
|
|
||||||
import io
|
|
||||||
import base64
|
|
||||||
import sys
|
|
||||||
import uuid
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from contextlib import redirect_stdout, redirect_stderr
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Global registry instance (will be set by main.py)
|
|
||||||
_registry = None
|
|
||||||
_datasource_registry = None
|
|
||||||
|
|
||||||
|
|
||||||
def set_registry(registry):
|
|
||||||
"""Set the global SyncRegistry instance for tools to use."""
|
|
||||||
global _registry
|
|
||||||
_registry = registry
|
|
||||||
|
|
||||||
|
|
||||||
def set_datasource_registry(datasource_registry):
|
|
||||||
"""Set the global DataSourceRegistry instance for tools to use."""
|
|
||||||
global _datasource_registry
|
|
||||||
_datasource_registry = datasource_registry
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
def list_sync_stores() -> List[str]:
|
|
||||||
"""List all available synchronization stores.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of store names that can be read/written
|
|
||||||
"""
|
|
||||||
if not _registry:
|
|
||||||
return []
|
|
||||||
return list(_registry.entries.keys())
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
def read_sync_state(store_name: str) -> Dict[str, Any]:
|
|
||||||
"""Read the current state of a synchronization store.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
store_name: Name of the store to read (e.g., "TraderState", "StrategyState")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing the current state of the store
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If store_name doesn't exist
|
|
||||||
"""
|
|
||||||
if not _registry:
|
|
||||||
raise ValueError("SyncRegistry not initialized")
|
|
||||||
|
|
||||||
entry = _registry.entries.get(store_name)
|
|
||||||
if not entry:
|
|
||||||
available = list(_registry.entries.keys())
|
|
||||||
raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
|
|
||||||
|
|
||||||
return entry.model.model_dump(mode="json")
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def write_sync_state(store_name: str, updates: Dict[str, Any]) -> Dict[str, str]:
|
|
||||||
"""Update the state of a synchronization store.
|
|
||||||
|
|
||||||
This will apply the updates to the store and trigger synchronization
|
|
||||||
with all connected clients.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
store_name: Name of the store to update
|
|
||||||
updates: Dictionary of field updates (field_name: new_value)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with status and updated fields
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If store_name doesn't exist or updates are invalid
|
|
||||||
"""
|
|
||||||
if not _registry:
|
|
||||||
raise ValueError("SyncRegistry not initialized")
|
|
||||||
|
|
||||||
entry = _registry.entries.get(store_name)
|
|
||||||
if not entry:
|
|
||||||
available = list(_registry.entries.keys())
|
|
||||||
raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get current state
|
|
||||||
current_state = entry.model.model_dump(mode="json")
|
|
||||||
|
|
||||||
# Apply updates
|
|
||||||
new_state = {**current_state, **updates}
|
|
||||||
|
|
||||||
# Update the model
|
|
||||||
_registry._update_model(entry.model, new_state)
|
|
||||||
|
|
||||||
# Trigger sync
|
|
||||||
await _registry.push_all()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "success",
|
|
||||||
"store": store_name,
|
|
||||||
"updated_fields": list(updates.keys())
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Failed to update store '{store_name}': {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
def get_store_schema(store_name: str) -> Dict[str, Any]:
|
|
||||||
"""Get the schema/structure of a synchronization store.
|
|
||||||
|
|
||||||
This shows what fields are available and their types.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
store_name: Name of the store
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary describing the store's schema
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If store_name doesn't exist
|
|
||||||
"""
|
|
||||||
if not _registry:
|
|
||||||
raise ValueError("SyncRegistry not initialized")
|
|
||||||
|
|
||||||
entry = _registry.entries.get(store_name)
|
|
||||||
if not entry:
|
|
||||||
available = list(_registry.entries.keys())
|
|
||||||
raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
|
|
||||||
|
|
||||||
# Get model schema
|
|
||||||
schema = entry.model.model_json_schema()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"store_name": store_name,
|
|
||||||
"schema": schema
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# DataSource tools
|
|
||||||
|
|
||||||
@tool
|
|
||||||
def list_data_sources() -> List[str]:
|
|
||||||
"""List all available data sources.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of data source names that can be queried for market data
|
|
||||||
"""
|
|
||||||
if not _datasource_registry:
|
|
||||||
return []
|
|
||||||
return _datasource_registry.list_sources()
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def search_symbols(
|
|
||||||
query: str,
|
|
||||||
type: Optional[str] = None,
|
|
||||||
exchange: Optional[str] = None,
|
|
||||||
limit: int = 30,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Search for trading symbols across all data sources.
|
|
||||||
|
|
||||||
Automatically searches all available data sources and returns aggregated results.
|
|
||||||
Use this to find symbols before calling get_symbol_info or get_historical_data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: Search query (e.g., "BTC", "AAPL", "EUR")
|
|
||||||
type: Optional filter by instrument type (e.g., "crypto", "stock", "forex")
|
|
||||||
exchange: Optional filter by exchange (e.g., "binance", "nasdaq")
|
|
||||||
limit: Maximum number of results per source (default: 30)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary mapping source names to lists of matching symbols.
|
|
||||||
Each symbol includes: symbol, full_name, description, exchange, type.
|
|
||||||
Use the source name and symbol from results with get_symbol_info or get_historical_data.
|
|
||||||
|
|
||||||
Example response:
|
|
||||||
{
|
|
||||||
"demo": [
|
|
||||||
{
|
|
||||||
"symbol": "BTC/USDT",
|
|
||||||
"full_name": "Bitcoin / Tether USD",
|
|
||||||
"description": "Bitcoin perpetual futures",
|
|
||||||
"exchange": "demo",
|
|
||||||
"type": "crypto"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
if not _datasource_registry:
|
|
||||||
raise ValueError("DataSourceRegistry not initialized")
|
|
||||||
|
|
||||||
# Always search all sources
|
|
||||||
results = await _datasource_registry.search_all(query, type, exchange, limit)
|
|
||||||
return {name: [r.model_dump() for r in matches] for name, matches in results.items()}
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def get_symbol_info(source_name: str, symbol: str) -> Dict[str, Any]:
|
|
||||||
"""Get complete metadata for a trading symbol.
|
|
||||||
|
|
||||||
This retrieves full information about a symbol including:
|
|
||||||
- Description and type
|
|
||||||
- Supported time resolutions
|
|
||||||
- Available data columns (OHLCV, volume, funding rates, etc.)
|
|
||||||
- Trading session information
|
|
||||||
- Price scale and precision
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_name: Name of the data source (use list_data_sources to see available)
|
|
||||||
symbol: Symbol identifier (e.g., "BTC/USDT", "AAPL", "EUR/USD")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing complete symbol metadata including column schema
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If source_name or symbol is not found
|
|
||||||
"""
|
|
||||||
if not _datasource_registry:
|
|
||||||
raise ValueError("DataSourceRegistry not initialized")
|
|
||||||
|
|
||||||
symbol_info = await _datasource_registry.resolve_symbol(source_name, symbol)
|
|
||||||
return symbol_info.model_dump()
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def get_historical_data(
|
|
||||||
source_name: str,
|
|
||||||
symbol: str,
|
|
||||||
resolution: str,
|
|
||||||
from_time: int,
|
|
||||||
to_time: int,
|
|
||||||
countback: Optional[int] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Get historical bar/candle data for a symbol.
|
|
||||||
|
|
||||||
Retrieves time-series data between the specified timestamps. The data
|
|
||||||
includes all columns defined for the symbol (OHLCV + any custom columns).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_name: Name of the data source
|
|
||||||
symbol: Symbol identifier
|
|
||||||
resolution: Time resolution (e.g., "1" = 1min, "5" = 5min, "60" = 1hour, "1D" = 1day)
|
|
||||||
from_time: Start time as Unix timestamp in seconds
|
|
||||||
to_time: End time as Unix timestamp in seconds
|
|
||||||
countback: Optional limit on number of bars to return
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing:
|
|
||||||
- symbol: The requested symbol
|
|
||||||
- resolution: The time resolution
|
|
||||||
- bars: List of bar data with 'time' and 'data' fields
|
|
||||||
- columns: Schema describing available data columns
|
|
||||||
- nextTime: If present, indicates more data is available for pagination
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If source, symbol, or resolution is invalid
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# Get 1-hour BTC data for the last 24 hours
|
|
||||||
import time
|
|
||||||
to_time = int(time.time())
|
|
||||||
from_time = to_time - 86400 # 24 hours ago
|
|
||||||
data = get_historical_data("demo", "BTC/USDT", "60", from_time, to_time)
|
|
||||||
"""
|
|
||||||
if not _datasource_registry:
|
|
||||||
raise ValueError("DataSourceRegistry not initialized")
|
|
||||||
|
|
||||||
source = _datasource_registry.get(source_name)
|
|
||||||
if not source:
|
|
||||||
available = _datasource_registry.list_sources()
|
|
||||||
raise ValueError(f"Data source '{source_name}' not found. Available sources: {available}")
|
|
||||||
|
|
||||||
result = await source.get_bars(symbol, resolution, from_time, to_time, countback)
|
|
||||||
return result.model_dump()
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_chart_data_impl(countback: Optional[int] = None):
|
|
||||||
"""Internal implementation for getting chart data.
|
|
||||||
|
|
||||||
This is a helper function that can be called by both get_chart_data tool
|
|
||||||
and analyze_chart_data tool.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (HistoryResult, chart_context dict, source_name)
|
|
||||||
"""
|
|
||||||
if not _registry:
|
|
||||||
raise ValueError("SyncRegistry not initialized - cannot read ChartStore")
|
|
||||||
|
|
||||||
if not _datasource_registry:
|
|
||||||
raise ValueError("DataSourceRegistry not initialized - cannot query data")
|
|
||||||
|
|
||||||
# Read current chart state
|
|
||||||
chart_store = _registry.entries.get("ChartStore")
|
|
||||||
if not chart_store:
|
|
||||||
raise ValueError("ChartStore not found in registry")
|
|
||||||
|
|
||||||
chart_state = chart_store.model.model_dump(mode="json")
|
|
||||||
chart_data = chart_state.get("chart_state", {})
|
|
||||||
|
|
||||||
symbol = chart_data.get("symbol", "")
|
|
||||||
interval = chart_data.get("interval", "15")
|
|
||||||
start_time = chart_data.get("start_time")
|
|
||||||
end_time = chart_data.get("end_time")
|
|
||||||
|
|
||||||
if not symbol:
|
|
||||||
raise ValueError("No symbol set in ChartStore - user may not have loaded a chart yet")
|
|
||||||
|
|
||||||
# Parse the symbol to extract exchange/source and symbol name
|
|
||||||
# Format is "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD")
|
|
||||||
if ":" not in symbol:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid symbol format: '{symbol}'. Expected format is 'EXCHANGE:SYMBOL' "
|
|
||||||
f"(e.g., 'BINANCE:BTC/USDT' or 'DEMO:BTC/USD')"
|
|
||||||
)
|
|
||||||
|
|
||||||
exchange_prefix, symbol_name = symbol.split(":", 1)
|
|
||||||
source_name = exchange_prefix.lower()
|
|
||||||
|
|
||||||
# Get the data source
|
|
||||||
source = _datasource_registry.get(source_name)
|
|
||||||
if not source:
|
|
||||||
available = _datasource_registry.list_sources()
|
|
||||||
raise ValueError(
|
|
||||||
f"Data source '{source_name}' not found. Available sources: {available}. "
|
|
||||||
f"Make sure the exchange in the symbol '{symbol}' matches an available source."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine time range - REQUIRE it to be set, no defaults
|
|
||||||
if start_time is None or end_time is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"Chart time range not set in ChartStore. start_time={start_time}, end_time={end_time}. "
|
|
||||||
f"The user needs to load the chart first, or the frontend may not be sending the visible range. "
|
|
||||||
f"Wait for the chart to fully load before analyzing data."
|
|
||||||
)
|
|
||||||
|
|
||||||
from_time = int(start_time)
|
|
||||||
end_time = int(end_time)
|
|
||||||
logger.info(
|
|
||||||
f"Using ChartStore time range: from_time={from_time}, end_time={end_time}, "
|
|
||||||
f"countback={countback}"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Querying data source '{source_name}' for symbol '{symbol_name}', "
|
|
||||||
f"resolution '{interval}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Query the data source
|
|
||||||
result = await source.get_bars(
|
|
||||||
symbol=symbol_name,
|
|
||||||
resolution=interval,
|
|
||||||
from_time=from_time,
|
|
||||||
to_time=end_time,
|
|
||||||
countback=countback
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Received {len(result.bars)} bars from data source. "
|
|
||||||
f"First bar time: {result.bars[0].time if result.bars else 'N/A'}, "
|
|
||||||
f"Last bar time: {result.bars[-1].time if result.bars else 'N/A'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build chart context to return along with result
|
|
||||||
chart_context = {
|
|
||||||
"symbol": symbol,
|
|
||||||
"interval": interval,
|
|
||||||
"start_time": start_time,
|
|
||||||
"end_time": end_time
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, chart_context, source_name
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def get_chart_data(countback: Optional[int] = None) -> Dict[str, Any]:
|
|
||||||
"""Get the candle/bar data for what the user is currently viewing on their chart.
|
|
||||||
|
|
||||||
This is a convenience tool that automatically:
|
|
||||||
1. Reads the ChartStore to see what chart the user is viewing
|
|
||||||
2. Parses the symbol to determine the data source (exchange prefix)
|
|
||||||
3. Queries the appropriate data source for that symbol's data
|
|
||||||
4. Returns the data for the visible time range and interval
|
|
||||||
|
|
||||||
This is the preferred way to access chart data when helping the user analyze
|
|
||||||
what they're looking at, since it automatically uses their current chart context.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
countback: Optional limit on number of bars to return. If not specified,
|
|
||||||
returns all bars in the visible time range.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing:
|
|
||||||
- chart_context: Current chart state (symbol, interval, time range)
|
|
||||||
- symbol: The trading pair being viewed
|
|
||||||
- resolution: The chart interval
|
|
||||||
- bars: List of bar data with 'time' and 'data' fields
|
|
||||||
- columns: Schema describing available data columns
|
|
||||||
- source: Which data source was used
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If ChartStore or DataSourceRegistry is not initialized,
|
|
||||||
or if the symbol format is invalid
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# User is viewing BINANCE:BTC/USDT on 15min chart
|
|
||||||
data = get_chart_data()
|
|
||||||
# Returns BTC/USDT data from binance source at 15min resolution
|
|
||||||
# for the currently visible time range
|
|
||||||
"""
|
|
||||||
result, chart_context, source_name = await _get_chart_data_impl(countback)
|
|
||||||
|
|
||||||
# Return enriched result with chart context
|
|
||||||
response = result.model_dump()
|
|
||||||
response["chart_context"] = chart_context
|
|
||||||
response["source"] = source_name
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def analyze_chart_data(python_script: str, countback: Optional[int] = None) -> Dict[str, Any]:
|
|
||||||
"""Analyze the current chart data using a Python script with pandas and matplotlib.
|
|
||||||
|
|
||||||
This tool:
|
|
||||||
1. Gets the current chart data (same as get_chart_data)
|
|
||||||
2. Converts it to a pandas DataFrame with columns: time, open, high, low, close, volume
|
|
||||||
3. Executes your Python script with access to the DataFrame as 'df'
|
|
||||||
4. Saves any matplotlib plots to disk and returns URLs to access them
|
|
||||||
5. Returns any final DataFrame result and plot URLs
|
|
||||||
|
|
||||||
The script has access to:
|
|
||||||
- `df`: pandas DataFrame with OHLCV data indexed by datetime
|
|
||||||
- `pandas` (as `pd`): For data manipulation
|
|
||||||
- `numpy` (as `np`): For numerical operations
|
|
||||||
- `matplotlib.pyplot` (as `plt`): For plotting (use plt.figure() for each plot)
|
|
||||||
|
|
||||||
All matplotlib figures are automatically saved to disk and accessible via URLs.
|
|
||||||
The last expression in the script (if it's a DataFrame) is returned as the result.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
python_script: Python code to execute. The DataFrame is available as 'df'.
|
|
||||||
Can use pandas, numpy, matplotlib. Return a DataFrame to include it in results.
|
|
||||||
countback: Optional limit on number of bars to analyze
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing:
|
|
||||||
- chart_context: Current chart state (symbol, interval, time range)
|
|
||||||
- source: Data source used
|
|
||||||
- script_output: Any printed output from the script
|
|
||||||
- result_dataframe: If script returns a DataFrame, it's included here as dict
|
|
||||||
- plot_urls: List of URLs to saved plot images (one per plt.figure())
|
|
||||||
- error: Error message if script execution failed
|
|
||||||
|
|
||||||
Example scripts:
|
|
||||||
# Calculate 20-period SMA and plot
|
|
||||||
```python
|
|
||||||
df['SMA20'] = df['close'].rolling(20).mean()
|
|
||||||
plt.figure(figsize=(12, 6))
|
|
||||||
plt.plot(df.index, df['close'], label='Close')
|
|
||||||
plt.plot(df.index, df['SMA20'], label='SMA20')
|
|
||||||
plt.legend()
|
|
||||||
plt.title('Price with SMA')
|
|
||||||
df[['close', 'SMA20']].tail(10) # Return last 10 rows
|
|
||||||
```
|
|
||||||
|
|
||||||
# Calculate RSI
|
|
||||||
```python
|
|
||||||
delta = df['close'].diff()
|
|
||||||
gain = (delta.where(delta > 0, 0)).rolling(14).mean()
|
|
||||||
loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
|
|
||||||
rs = gain / loss
|
|
||||||
df['RSI'] = 100 - (100 / (1 + rs))
|
|
||||||
df[['close', 'RSI']].tail(20)
|
|
||||||
```
|
|
||||||
|
|
||||||
# Multiple plots
|
|
||||||
```python
|
|
||||||
# Price chart
|
|
||||||
plt.figure(figsize=(12, 4))
|
|
||||||
plt.plot(df['close'])
|
|
||||||
plt.title('Price')
|
|
||||||
|
|
||||||
# Volume chart
|
|
||||||
plt.figure(figsize=(12, 3))
|
|
||||||
plt.bar(df.index, df['volume'])
|
|
||||||
plt.title('Volume')
|
|
||||||
|
|
||||||
df.describe() # Return statistics
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
if not _registry:
|
|
||||||
raise ValueError("SyncRegistry not initialized - cannot read ChartStore")
|
|
||||||
|
|
||||||
if not _datasource_registry:
|
|
||||||
raise ValueError("DataSourceRegistry not initialized - cannot query data")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Import pandas and numpy here to allow lazy loading
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
import matplotlib
|
|
||||||
matplotlib.use('Agg') # Non-interactive backend
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
except ImportError as e:
|
|
||||||
raise ValueError(
|
|
||||||
f"Required library not installed: {e}. "
|
|
||||||
"Please install pandas, numpy, and matplotlib: pip install pandas numpy matplotlib"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get chart data using the internal helper function
|
|
||||||
result, chart_context, source_name = await _get_chart_data_impl(countback)
|
|
||||||
|
|
||||||
# Build the same response format as get_chart_data
|
|
||||||
chart_data = result.model_dump()
|
|
||||||
chart_data["chart_context"] = chart_context
|
|
||||||
chart_data["source"] = source_name
|
|
||||||
|
|
||||||
# Convert bars to DataFrame
|
|
||||||
bars = chart_data.get('bars', [])
|
|
||||||
if not bars:
|
|
||||||
return {
|
|
||||||
"chart_context": chart_data.get('chart_context', {}),
|
|
||||||
"source": chart_data.get('source', ''),
|
|
||||||
"error": "No data available for the current chart"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Build DataFrame
|
|
||||||
rows = []
|
|
||||||
for bar in bars:
|
|
||||||
row = {
|
|
||||||
'time': pd.to_datetime(bar['time'], unit='s'),
|
|
||||||
**bar['data'] # Includes open, high, low, close, volume, etc.
|
|
||||||
}
|
|
||||||
rows.append(row)
|
|
||||||
|
|
||||||
df = pd.DataFrame(rows)
|
|
||||||
df.set_index('time', inplace=True)
|
|
||||||
|
|
||||||
# Convert price columns to float for clean numeric operations
|
|
||||||
price_columns = ['open', 'high', 'low', 'close', 'volume']
|
|
||||||
for col in price_columns:
|
|
||||||
if col in df.columns:
|
|
||||||
df[col] = pd.to_numeric(df[col], errors='coerce')
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Created DataFrame with {len(df)} rows, columns: {df.columns.tolist()}, "
|
|
||||||
f"time range: {df.index.min()} to {df.index.max()}, "
|
|
||||||
f"dtypes: {df.dtypes.to_dict()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare execution environment
|
|
||||||
script_globals = {
|
|
||||||
'df': df,
|
|
||||||
'pd': pd,
|
|
||||||
'np': np,
|
|
||||||
'plt': plt,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Capture stdout/stderr
|
|
||||||
stdout_capture = io.StringIO()
|
|
||||||
stderr_capture = io.StringIO()
|
|
||||||
|
|
||||||
result_df = None
|
|
||||||
error_msg = None
|
|
||||||
plot_urls = []
|
|
||||||
|
|
||||||
# Determine uploads directory (relative to this file)
|
|
||||||
uploads_dir = Path(__file__).parent.parent.parent / "data" / "uploads"
|
|
||||||
uploads_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
|
|
||||||
# Execute the script
|
|
||||||
exec(python_script, script_globals)
|
|
||||||
|
|
||||||
# Check if the last line is an expression that returns a DataFrame
|
|
||||||
# We'll try to evaluate it separately
|
|
||||||
script_lines = python_script.strip().split('\n')
|
|
||||||
if script_lines:
|
|
||||||
last_line = script_lines[-1].strip()
|
|
||||||
# Only evaluate if it doesn't look like a statement
|
|
||||||
if last_line and not any(last_line.startswith(kw) for kw in ['if', 'for', 'while', 'def', 'class', 'import', 'from', 'with', 'try', 'return']):
|
|
||||||
try:
|
|
||||||
last_result = eval(last_line, script_globals)
|
|
||||||
if isinstance(last_result, pd.DataFrame):
|
|
||||||
result_df = last_result
|
|
||||||
except:
|
|
||||||
# If eval fails, that's okay - might not be an expression
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Save all matplotlib figures to disk
|
|
||||||
for fig_num in plt.get_fignums():
|
|
||||||
fig = plt.figure(fig_num)
|
|
||||||
|
|
||||||
# Generate unique filename
|
|
||||||
plot_id = str(uuid.uuid4())
|
|
||||||
filename = f"plot_{plot_id}.png"
|
|
||||||
filepath = uploads_dir / filename
|
|
||||||
|
|
||||||
# Save figure to file
|
|
||||||
fig.savefig(filepath, format='png', bbox_inches='tight', dpi=100)
|
|
||||||
|
|
||||||
# Generate URL that can be accessed via the web server
|
|
||||||
plot_url = f"/uploads/{filename}"
|
|
||||||
plot_urls.append(plot_url)
|
|
||||||
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"{type(e).__name__}: {str(e)}"
|
|
||||||
import traceback
|
|
||||||
error_msg += f"\n{traceback.format_exc()}"
|
|
||||||
|
|
||||||
# Build response
|
|
||||||
response = {
|
|
||||||
"chart_context": chart_data.get('chart_context', {}),
|
|
||||||
"source": chart_data.get('source', ''),
|
|
||||||
"script_output": stdout_capture.getvalue(),
|
|
||||||
}
|
|
||||||
|
|
||||||
if error_msg:
|
|
||||||
response["error"] = error_msg
|
|
||||||
response["stderr"] = stderr_capture.getvalue()
|
|
||||||
|
|
||||||
if result_df is not None:
|
|
||||||
# Convert DataFrame to dict for JSON serialization
|
|
||||||
response["result_dataframe"] = {
|
|
||||||
"columns": result_df.columns.tolist(),
|
|
||||||
"index": result_df.index.astype(str).tolist() if hasattr(result_df.index, 'astype') else result_df.index.tolist(),
|
|
||||||
"data": result_df.values.tolist(),
|
|
||||||
"shape": result_df.shape,
|
|
||||||
}
|
|
||||||
|
|
||||||
if plot_urls:
|
|
||||||
response["plot_urls"] = plot_urls
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
# Export all tools
|
|
||||||
SYNC_TOOLS = [
|
|
||||||
list_sync_stores,
|
|
||||||
read_sync_state,
|
|
||||||
write_sync_state,
|
|
||||||
get_store_schema
|
|
||||||
]
|
|
||||||
|
|
||||||
DATASOURCE_TOOLS = [
|
|
||||||
list_data_sources,
|
|
||||||
search_symbols,
|
|
||||||
get_symbol_info,
|
|
||||||
get_historical_data,
|
|
||||||
get_chart_data,
|
|
||||||
analyze_chart_data
|
|
||||||
]
|
|
||||||
139
backend/src/agent/tools/CHART_UTILS_README.md
Normal file
139
backend/src/agent/tools/CHART_UTILS_README.md
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
# Chart Utilities - Standard OHLC Plotting
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `chart_utils.py` module provides convenience functions for creating beautiful, professional OHLC candlestick charts with a consistent look and feel. This is designed to be used by the LLM in `analyze_chart_data` scripts, eliminating the need to write custom matplotlib code for every chart.
|
||||||
|
|
||||||
|
## Key Features
|
||||||
|
|
||||||
|
- **Beautiful by default**: Uses mplfinance with seaborn-inspired aesthetics
|
||||||
|
- **Consistent styling**: Professional color scheme (teal green up, coral red down)
|
||||||
|
- **Easy to use**: Simple function calls instead of complex matplotlib code
|
||||||
|
- **Customizable**: Supports all mplfinance options via kwargs
|
||||||
|
- **Volume integration**: Optional volume subplot
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
The required package `mplfinance` has been added to `requirements.txt`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install mplfinance
|
||||||
|
```
|
||||||
|
|
||||||
|
## Available Functions
|
||||||
|
|
||||||
|
### 1. `plot_ohlc(df, title=None, volume=True, figsize=(14, 8), **kwargs)`
|
||||||
|
|
||||||
|
Main function for creating standard OHLC candlestick charts.
|
||||||
|
|
||||||
|
**Parameters:**
|
||||||
|
- `df`: pandas DataFrame with DatetimeIndex and OHLCV columns
|
||||||
|
- `title`: Optional chart title
|
||||||
|
- `volume`: Whether to include volume subplot (default: True)
|
||||||
|
- `figsize`: Figure size in inches (default: (14, 8))
|
||||||
|
- `**kwargs`: Additional mplfinance.plot() arguments
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```python
|
||||||
|
fig = plot_ohlc(df, title='BTC/USDT 15min', volume=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. `add_indicators_to_plot(df, indicators, **plot_kwargs)`
|
||||||
|
|
||||||
|
Creates OHLC chart with technical indicators overlaid.
|
||||||
|
|
||||||
|
**Parameters:**
|
||||||
|
- `df`: DataFrame with OHLCV data and indicator columns
|
||||||
|
- `indicators`: Dict mapping indicator column names to display parameters
|
||||||
|
- `**plot_kwargs`: Additional arguments for plot_ohlc()
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```python
|
||||||
|
df['SMA_20'] = df['close'].rolling(20).mean()
|
||||||
|
df['SMA_50'] = df['close'].rolling(50).mean()
|
||||||
|
|
||||||
|
fig = add_indicators_to_plot(
|
||||||
|
df,
|
||||||
|
indicators={
|
||||||
|
'SMA_20': {'color': 'blue', 'width': 1.5},
|
||||||
|
'SMA_50': {'color': 'red', 'width': 1.5}
|
||||||
|
},
|
||||||
|
title='Price with Moving Averages'
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Preset Functions
|
||||||
|
|
||||||
|
- `plot_price_volume(df, title=None)` - Standard price + volume chart
|
||||||
|
- `plot_price_only(df, title=None)` - Candlesticks without volume
|
||||||
|
|
||||||
|
## Integration with analyze_chart_data
|
||||||
|
|
||||||
|
These functions are automatically available in the `analyze_chart_data` tool's script environment:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# In an analyze_chart_data script:
|
||||||
|
# df is already provided
|
||||||
|
|
||||||
|
# Simple usage
|
||||||
|
fig = plot_ohlc(df, title='Price Action')
|
||||||
|
|
||||||
|
# With indicators
|
||||||
|
df['SMA'] = df['close'].rolling(20).mean()
|
||||||
|
fig = add_indicators_to_plot(
|
||||||
|
df,
|
||||||
|
indicators={'SMA': {'color': 'blue', 'width': 1.5}},
|
||||||
|
title='Price with SMA'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return data for the assistant
|
||||||
|
df[['close', 'SMA']].tail(10)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Styling
|
||||||
|
|
||||||
|
The default style includes:
|
||||||
|
- **Up candles**: Teal green (#26a69a)
|
||||||
|
- **Down candles**: Coral red (#ef5350)
|
||||||
|
- **Background**: Light gray with white axes
|
||||||
|
- **Grid**: Subtle dashed lines with 30% alpha
|
||||||
|
- **Professional fonts**: Clean, readable sizes
|
||||||
|
|
||||||
|
## Why This Matters
|
||||||
|
|
||||||
|
**Before:**
|
||||||
|
```python
|
||||||
|
# LLM had to write this every time
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
fig, ax = plt.subplots(figsize=(12, 6))
|
||||||
|
ax.plot(df.index, df['close'], label='Close')
|
||||||
|
# ... lots more code for styling, colors, etc.
|
||||||
|
```
|
||||||
|
|
||||||
|
**After:**
|
||||||
|
```python
|
||||||
|
# LLM can now just do this
|
||||||
|
fig = plot_ohlc(df, title='BTC/USDT')
|
||||||
|
```
|
||||||
|
|
||||||
|
Benefits:
|
||||||
|
- ✅ Less code to generate → faster response
|
||||||
|
- ✅ Consistent appearance across all charts
|
||||||
|
- ✅ Professional look out of the box
|
||||||
|
- ✅ Easier to maintain and customize
|
||||||
|
- ✅ Better use of mplfinance's candlestick rendering
|
||||||
|
|
||||||
|
## Example Output
|
||||||
|
|
||||||
|
See `chart_utils_example.py` for runnable examples demonstrating:
|
||||||
|
1. Basic OHLC chart with volume
|
||||||
|
2. OHLC chart with multiple indicators
|
||||||
|
3. Price-only chart
|
||||||
|
4. Custom styling options
|
||||||
|
|
||||||
|
## File Locations
|
||||||
|
|
||||||
|
- **Main module**: `backend/src/agent/tools/chart_utils.py`
|
||||||
|
- **Integration**: `backend/src/agent/tools/chart_tools.py` (lines 306-328)
|
||||||
|
- **Examples**: `backend/src/agent/tools/chart_utils_example.py`
|
||||||
|
- **Dependency**: `backend/requirements.txt` (mplfinance added)
|
||||||
50
backend/src/agent/tools/__init__.py
Normal file
50
backend/src/agent/tools/__init__.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""Agent tools for trading operations.
|
||||||
|
|
||||||
|
This package provides tools for:
|
||||||
|
- Synchronization stores (sync_tools)
|
||||||
|
- Data sources and market data (datasource_tools)
|
||||||
|
- Chart data access and analysis (chart_tools)
|
||||||
|
- Technical indicators (indicator_tools)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Global registries that will be set by main.py
|
||||||
|
_registry = None
|
||||||
|
_datasource_registry = None
|
||||||
|
_indicator_registry = None
|
||||||
|
|
||||||
|
|
||||||
|
def set_registry(registry):
|
||||||
|
"""Set the global SyncRegistry instance for tools to use."""
|
||||||
|
global _registry
|
||||||
|
_registry = registry
|
||||||
|
|
||||||
|
|
||||||
|
def set_datasource_registry(datasource_registry):
|
||||||
|
"""Set the global DataSourceRegistry instance for tools to use."""
|
||||||
|
global _datasource_registry
|
||||||
|
_datasource_registry = datasource_registry
|
||||||
|
|
||||||
|
|
||||||
|
def set_indicator_registry(indicator_registry):
|
||||||
|
"""Set the global IndicatorRegistry instance for tools to use."""
|
||||||
|
global _indicator_registry
|
||||||
|
_indicator_registry = indicator_registry
|
||||||
|
|
||||||
|
|
||||||
|
# Import all tools from submodules
|
||||||
|
from .sync_tools import SYNC_TOOLS
|
||||||
|
from .datasource_tools import DATASOURCE_TOOLS
|
||||||
|
from .chart_tools import CHART_TOOLS
|
||||||
|
from .indicator_tools import INDICATOR_TOOLS
|
||||||
|
from .research_tools import RESEARCH_TOOLS
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"set_registry",
|
||||||
|
"set_datasource_registry",
|
||||||
|
"set_indicator_registry",
|
||||||
|
"SYNC_TOOLS",
|
||||||
|
"DATASOURCE_TOOLS",
|
||||||
|
"CHART_TOOLS",
|
||||||
|
"INDICATOR_TOOLS",
|
||||||
|
"RESEARCH_TOOLS",
|
||||||
|
]
|
||||||
371
backend/src/agent/tools/chart_tools.py
Normal file
371
backend/src/agent/tools/chart_tools.py
Normal file
@@ -0,0 +1,371 @@
|
|||||||
|
"""Chart data access and analysis tools."""
|
||||||
|
|
||||||
|
from typing import Dict, Any, Optional, Tuple
|
||||||
|
import io
|
||||||
|
import uuid
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from contextlib import redirect_stdout, redirect_stderr
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_registry():
|
||||||
|
"""Get the global registry instance."""
|
||||||
|
from . import _registry
|
||||||
|
return _registry
|
||||||
|
|
||||||
|
|
||||||
|
def _get_datasource_registry():
|
||||||
|
"""Get the global datasource registry instance."""
|
||||||
|
from . import _datasource_registry
|
||||||
|
return _datasource_registry
|
||||||
|
|
||||||
|
|
||||||
|
def _get_indicator_registry():
|
||||||
|
"""Get the global indicator registry instance."""
|
||||||
|
from . import _indicator_registry
|
||||||
|
return _indicator_registry
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_chart_data_impl(countback: Optional[int] = None):
|
||||||
|
"""Internal implementation for getting chart data.
|
||||||
|
|
||||||
|
This is a helper function that can be called by both get_chart_data tool
|
||||||
|
and analyze_chart_data tool.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (HistoryResult, chart_context dict, source_name)
|
||||||
|
"""
|
||||||
|
registry = _get_registry()
|
||||||
|
datasource_registry = _get_datasource_registry()
|
||||||
|
|
||||||
|
if not registry:
|
||||||
|
raise ValueError("SyncRegistry not initialized - cannot read ChartStore")
|
||||||
|
|
||||||
|
if not datasource_registry:
|
||||||
|
raise ValueError("DataSourceRegistry not initialized - cannot query data")
|
||||||
|
|
||||||
|
# Read current chart state
|
||||||
|
chart_store = registry.entries.get("ChartStore")
|
||||||
|
if not chart_store:
|
||||||
|
raise ValueError("ChartStore not found in registry")
|
||||||
|
|
||||||
|
chart_state = chart_store.model.model_dump(mode="json")
|
||||||
|
chart_data = chart_state.get("chart_state", {})
|
||||||
|
|
||||||
|
symbol = chart_data.get("symbol", "")
|
||||||
|
interval = chart_data.get("interval", "15")
|
||||||
|
start_time = chart_data.get("start_time")
|
||||||
|
end_time = chart_data.get("end_time")
|
||||||
|
|
||||||
|
if not symbol:
|
||||||
|
raise ValueError("No symbol set in ChartStore - user may not have loaded a chart yet")
|
||||||
|
|
||||||
|
# Parse the symbol to extract exchange/source and symbol name
|
||||||
|
# Format is "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD")
|
||||||
|
if ":" not in symbol:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid symbol format: '{symbol}'. Expected format is 'EXCHANGE:SYMBOL' "
|
||||||
|
f"(e.g., 'BINANCE:BTC/USDT' or 'DEMO:BTC/USD')"
|
||||||
|
)
|
||||||
|
|
||||||
|
exchange_prefix, symbol_name = symbol.split(":", 1)
|
||||||
|
source_name = exchange_prefix.lower()
|
||||||
|
|
||||||
|
# Get the data source
|
||||||
|
source = datasource_registry.get(source_name)
|
||||||
|
if not source:
|
||||||
|
available = datasource_registry.list_sources()
|
||||||
|
raise ValueError(
|
||||||
|
f"Data source '{source_name}' not found. Available sources: {available}. "
|
||||||
|
f"Make sure the exchange in the symbol '{symbol}' matches an available source."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine time range - REQUIRE it to be set, no defaults
|
||||||
|
if start_time is None or end_time is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Chart time range not set in ChartStore. start_time={start_time}, end_time={end_time}. "
|
||||||
|
f"The user needs to load the chart first, or the frontend may not be sending the visible range. "
|
||||||
|
f"Wait for the chart to fully load before analyzing data."
|
||||||
|
)
|
||||||
|
|
||||||
|
from_time = int(start_time)
|
||||||
|
end_time = int(end_time)
|
||||||
|
logger.info(
|
||||||
|
f"Using ChartStore time range: from_time={from_time}, end_time={end_time}, "
|
||||||
|
f"countback={countback}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Querying data source '{source_name}' for symbol '{symbol_name}', "
|
||||||
|
f"resolution '{interval}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Query the data source
|
||||||
|
result = await source.get_bars(
|
||||||
|
symbol=symbol_name,
|
||||||
|
resolution=interval,
|
||||||
|
from_time=from_time,
|
||||||
|
to_time=end_time,
|
||||||
|
countback=countback
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Received {len(result.bars)} bars from data source. "
|
||||||
|
f"First bar time: {result.bars[0].time if result.bars else 'N/A'}, "
|
||||||
|
f"Last bar time: {result.bars[-1].time if result.bars else 'N/A'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build chart context to return along with result
|
||||||
|
chart_context = {
|
||||||
|
"symbol": symbol,
|
||||||
|
"interval": interval,
|
||||||
|
"start_time": start_time,
|
||||||
|
"end_time": end_time
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, chart_context, source_name
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_chart_data(countback: Optional[int] = None) -> Dict[str, Any]:
|
||||||
|
"""Get the candle/bar data for what the user is currently viewing on their chart.
|
||||||
|
|
||||||
|
This is a convenience tool that automatically:
|
||||||
|
1. Reads the ChartStore to see what chart the user is viewing
|
||||||
|
2. Parses the symbol to determine the data source (exchange prefix)
|
||||||
|
3. Queries the appropriate data source for that symbol's data
|
||||||
|
4. Returns the data for the visible time range and interval
|
||||||
|
|
||||||
|
This is the preferred way to access chart data when helping the user analyze
|
||||||
|
what they're looking at, since it automatically uses their current chart context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
countback: Optional limit on number of bars to return. If not specified,
|
||||||
|
returns all bars in the visible time range.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing:
|
||||||
|
- chart_context: Current chart state (symbol, interval, time range)
|
||||||
|
- symbol: The trading pair being viewed
|
||||||
|
- resolution: The chart interval
|
||||||
|
- bars: List of bar data with 'time' and 'data' fields
|
||||||
|
- columns: Schema describing available data columns
|
||||||
|
- source: Which data source was used
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If ChartStore or DataSourceRegistry is not initialized,
|
||||||
|
or if the symbol format is invalid
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# User is viewing BINANCE:BTC/USDT on 15min chart
|
||||||
|
data = get_chart_data()
|
||||||
|
# Returns BTC/USDT data from binance source at 15min resolution
|
||||||
|
# for the currently visible time range
|
||||||
|
"""
|
||||||
|
result, chart_context, source_name = await _get_chart_data_impl(countback)
|
||||||
|
|
||||||
|
# Return enriched result with chart context
|
||||||
|
response = result.model_dump()
|
||||||
|
response["chart_context"] = chart_context
|
||||||
|
response["source"] = source_name
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def execute_python(code: str, countback: Optional[int] = None) -> Dict[str, Any]:
|
||||||
|
"""Execute Python code for technical analysis with automatic chart data loading.
|
||||||
|
|
||||||
|
**PRIMARY TOOL for all technical analysis, indicator computation, and chart generation.**
|
||||||
|
|
||||||
|
This is your go-to tool whenever the user asks about indicators, wants to see
|
||||||
|
a chart, or needs any computational analysis of market data.
|
||||||
|
|
||||||
|
Pre-loaded Environment:
|
||||||
|
- `pd` : pandas
|
||||||
|
- `np` : numpy
|
||||||
|
- `plt` : matplotlib.pyplot (figures auto-saved to plot_urls)
|
||||||
|
- `talib` : TA-Lib technical analysis library
|
||||||
|
- `indicator_registry`: 150+ registered indicators
|
||||||
|
- `plot_ohlc(df)` : Helper function for beautiful candlestick charts
|
||||||
|
|
||||||
|
Auto-loaded when user has a chart open:
|
||||||
|
- `df` : pandas DataFrame with DatetimeIndex and columns:
|
||||||
|
open, high, low, close, volume (OHLCV data ready to use)
|
||||||
|
- `chart_context` : dict with symbol, interval, start_time, end_time
|
||||||
|
|
||||||
|
The `plot_ohlc()` Helper:
|
||||||
|
Create professional candlestick charts instantly:
|
||||||
|
- `plot_ohlc(df)` - basic OHLC chart with volume
|
||||||
|
- `plot_ohlc(df, title='BTC 15min')` - with custom title
|
||||||
|
- `plot_ohlc(df, volume=False)` - price only, no volume
|
||||||
|
- Returns a matplotlib Figure that's automatically saved to plot_urls
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Python code to execute
|
||||||
|
countback: Optional limit on number of bars to load (default: all visible bars)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with:
|
||||||
|
- script_output : printed output + last expression result
|
||||||
|
- result_dataframe : serialized DataFrame if last expression is a DataFrame
|
||||||
|
- plot_urls : list of image URLs (e.g., ["/uploads/plot_abc123.png"])
|
||||||
|
- chart_context : {symbol, interval, start_time, end_time} or None
|
||||||
|
- error : traceback if execution failed
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
# RSI indicator with chart
|
||||||
|
execute_python(\"\"\"
|
||||||
|
df['RSI'] = talib.RSI(df['close'], 14)
|
||||||
|
fig = plot_ohlc(df, title='BTC/USDT with RSI')
|
||||||
|
print(f"Current RSI: {df['RSI'].iloc[-1]:.2f}")
|
||||||
|
df[['close', 'RSI']].tail(5)
|
||||||
|
\"\"\")
|
||||||
|
|
||||||
|
# Multiple indicators
|
||||||
|
execute_python(\"\"\"
|
||||||
|
df['SMA_20'] = df['close'].rolling(20).mean()
|
||||||
|
df['SMA_50'] = df['close'].rolling(50).mean()
|
||||||
|
df['BB_upper'] = df['close'].rolling(20).mean() + 2*df['close'].rolling(20).std()
|
||||||
|
df['BB_lower'] = df['close'].rolling(20).mean() - 2*df['close'].rolling(20).std()
|
||||||
|
|
||||||
|
fig = plot_ohlc(df, title=f"{chart_context['symbol']} - Bollinger Bands")
|
||||||
|
|
||||||
|
current_price = df['close'].iloc[-1]
|
||||||
|
sma20 = df['SMA_20'].iloc[-1]
|
||||||
|
print(f"Price: {current_price:.2f}, SMA20: {sma20:.2f}")
|
||||||
|
df[['close', 'SMA_20', 'BB_upper', 'BB_lower']].tail(10)
|
||||||
|
\"\"\")
|
||||||
|
|
||||||
|
# Pattern detection
|
||||||
|
execute_python(\"\"\"
|
||||||
|
# Find swing highs
|
||||||
|
df['swing_high'] = (df['high'] > df['high'].shift(1)) & (df['high'] > df['high'].shift(-1))
|
||||||
|
swing_highs = df[df['swing_high']][['high']].tail(5)
|
||||||
|
|
||||||
|
fig = plot_ohlc(df, title='Swing High Detection')
|
||||||
|
print("Recent swing highs:")
|
||||||
|
print(swing_highs)
|
||||||
|
\"\"\")
|
||||||
|
"""
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
try:
|
||||||
|
import talib
|
||||||
|
except ImportError:
|
||||||
|
talib = None
|
||||||
|
logger.warning("TA-Lib not available in execute_python environment")
|
||||||
|
|
||||||
|
# --- Attempt to load chart data ---
|
||||||
|
df = None
|
||||||
|
chart_context = None
|
||||||
|
|
||||||
|
registry = _get_registry()
|
||||||
|
datasource_registry = _get_datasource_registry()
|
||||||
|
|
||||||
|
if registry and datasource_registry:
|
||||||
|
try:
|
||||||
|
result, chart_context, source_name = await _get_chart_data_impl(countback)
|
||||||
|
bars = result.bars
|
||||||
|
if bars:
|
||||||
|
rows = []
|
||||||
|
for bar in bars:
|
||||||
|
rows.append({'time': pd.to_datetime(bar.time, unit='s'), **bar.data})
|
||||||
|
df = pd.DataFrame(rows).set_index('time')
|
||||||
|
for col in ['open', 'high', 'low', 'close', 'volume']:
|
||||||
|
if col in df.columns:
|
||||||
|
df[col] = pd.to_numeric(df[col], errors='coerce')
|
||||||
|
logger.info(f"execute_python: loaded {len(df)} bars for {chart_context['symbol']}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f"execute_python: no chart data loaded ({e})")
|
||||||
|
|
||||||
|
# --- Import chart utilities ---
|
||||||
|
from .chart_utils import plot_ohlc
|
||||||
|
|
||||||
|
# --- Get indicator registry ---
|
||||||
|
indicator_registry = _get_indicator_registry()
|
||||||
|
|
||||||
|
# --- Build globals ---
|
||||||
|
script_globals: Dict[str, Any] = {
|
||||||
|
'pd': pd,
|
||||||
|
'np': np,
|
||||||
|
'plt': plt,
|
||||||
|
'talib': talib,
|
||||||
|
'indicator_registry': indicator_registry,
|
||||||
|
'df': df,
|
||||||
|
'chart_context': chart_context,
|
||||||
|
'plot_ohlc': plot_ohlc,
|
||||||
|
}
|
||||||
|
|
||||||
|
# --- Execute ---
|
||||||
|
uploads_dir = Path(__file__).parent.parent.parent.parent / "data" / "uploads"
|
||||||
|
uploads_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
stdout_capture = io.StringIO()
|
||||||
|
result_df = None
|
||||||
|
error_msg = None
|
||||||
|
plot_urls = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
with redirect_stdout(stdout_capture), redirect_stderr(stdout_capture):
|
||||||
|
exec(code, script_globals)
|
||||||
|
|
||||||
|
# Capture last expression
|
||||||
|
lines = code.strip().splitlines()
|
||||||
|
if lines:
|
||||||
|
last = lines[-1].strip()
|
||||||
|
if last and not any(last.startswith(kw) for kw in (
|
||||||
|
'if', 'for', 'while', 'def', 'class', 'import',
|
||||||
|
'from', 'with', 'try', 'return', '#'
|
||||||
|
)):
|
||||||
|
try:
|
||||||
|
last_val = eval(last, script_globals)
|
||||||
|
if isinstance(last_val, pd.DataFrame):
|
||||||
|
result_df = last_val
|
||||||
|
elif last_val is not None:
|
||||||
|
stdout_capture.write(str(last_val))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Save plots
|
||||||
|
for fig_num in plt.get_fignums():
|
||||||
|
fig = plt.figure(fig_num)
|
||||||
|
filename = f"plot_{uuid.uuid4()}.png"
|
||||||
|
fig.savefig(uploads_dir / filename, format='png', bbox_inches='tight', dpi=100)
|
||||||
|
plot_urls.append(f"/uploads/{filename}")
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
error_msg = f"{type(e).__name__}: {e}\n{traceback.format_exc()}"
|
||||||
|
|
||||||
|
# --- Build response ---
|
||||||
|
response: Dict[str, Any] = {
|
||||||
|
'script_output': stdout_capture.getvalue(),
|
||||||
|
'chart_context': chart_context,
|
||||||
|
'plot_urls': plot_urls,
|
||||||
|
}
|
||||||
|
if result_df is not None:
|
||||||
|
response['result_dataframe'] = {
|
||||||
|
'columns': result_df.columns.tolist(),
|
||||||
|
'index': result_df.index.astype(str).tolist(),
|
||||||
|
'data': result_df.values.tolist(),
|
||||||
|
'shape': result_df.shape,
|
||||||
|
}
|
||||||
|
if error_msg:
|
||||||
|
response['error'] = error_msg
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
CHART_TOOLS = [
|
||||||
|
get_chart_data,
|
||||||
|
execute_python
|
||||||
|
]
|
||||||
224
backend/src/agent/tools/chart_utils.py
Normal file
224
backend/src/agent/tools/chart_utils.py
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
"""Chart plotting utilities for creating standard, beautiful OHLC charts."""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_ohlc(
|
||||||
|
df: pd.DataFrame,
|
||||||
|
title: Optional[str] = None,
|
||||||
|
volume: bool = True,
|
||||||
|
figsize: Tuple[int, int] = (14, 8),
|
||||||
|
style: str = 'seaborn-v0_8-darkgrid',
|
||||||
|
**kwargs
|
||||||
|
) -> plt.Figure:
|
||||||
|
"""Create a beautiful standard OHLC candlestick chart.
|
||||||
|
|
||||||
|
This is a convenience function that generates a professional-looking candlestick
|
||||||
|
chart with consistent styling across all generated charts. It uses mplfinance
|
||||||
|
with seaborn aesthetics for a polished appearance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: pandas DataFrame with DatetimeIndex and columns: open, high, low, close, volume
|
||||||
|
title: Optional chart title. If None, uses symbol from chart context
|
||||||
|
volume: Whether to include volume subplot (default: True)
|
||||||
|
figsize: Figure size as (width, height) in inches (default: (14, 8))
|
||||||
|
style: Base matplotlib style to use (default: 'seaborn-v0_8-darkgrid')
|
||||||
|
**kwargs: Additional arguments to pass to mplfinance.plot()
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
matplotlib.figure.Figure: The created figure object
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
# Basic usage in analyze_chart_data script
|
||||||
|
fig = plot_ohlc(df, title='BTC/USDT 15min')
|
||||||
|
|
||||||
|
# Customize with additional indicators
|
||||||
|
fig = plot_ohlc(df, volume=True, title='Price Action')
|
||||||
|
|
||||||
|
# Add custom overlays after calling plot_ohlc
|
||||||
|
df['SMA20'] = df['close'].rolling(20).mean()
|
||||||
|
fig = plot_ohlc(df, title='With SMA')
|
||||||
|
# Note: For mplfinance overlays, use the mav or addplot parameters
|
||||||
|
```
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The DataFrame must have a DatetimeIndex and the standard OHLCV columns.
|
||||||
|
Column names should be lowercase: open, high, low, close, volume
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import mplfinance as mpf
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"mplfinance is required for plot_ohlc(). "
|
||||||
|
"Install it with: pip install mplfinance"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate DataFrame structure
|
||||||
|
required_cols = ['open', 'high', 'low', 'close']
|
||||||
|
missing_cols = [col for col in required_cols if col not in df.columns]
|
||||||
|
if missing_cols:
|
||||||
|
raise ValueError(
|
||||||
|
f"DataFrame missing required columns: {missing_cols}. "
|
||||||
|
f"Required: {required_cols}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(df.index, pd.DatetimeIndex):
|
||||||
|
raise ValueError(
|
||||||
|
"DataFrame must have a DatetimeIndex. "
|
||||||
|
"Convert with: df.index = pd.to_datetime(df.index)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure volume column exists for volume plot
|
||||||
|
if volume and 'volume' not in df.columns:
|
||||||
|
logger.warning("volume=True but 'volume' column not found in DataFrame. Disabling volume.")
|
||||||
|
volume = False
|
||||||
|
|
||||||
|
# Create custom style with seaborn aesthetics
|
||||||
|
# Using a professional color scheme: green for up candles, red for down candles
|
||||||
|
mc = mpf.make_marketcolors(
|
||||||
|
up='#26a69a', # Teal green (calmer than bright green)
|
||||||
|
down='#ef5350', # Coral red (softer than pure red)
|
||||||
|
edge='inherit', # Match candle color for edges
|
||||||
|
wick='inherit', # Match candle color for wicks
|
||||||
|
volume='in', # Volume bars colored by price direction
|
||||||
|
alpha=0.9 # Slight transparency for elegance
|
||||||
|
)
|
||||||
|
|
||||||
|
s = mpf.make_mpf_style(
|
||||||
|
base_mpf_style='charles', # Clean base style
|
||||||
|
marketcolors=mc,
|
||||||
|
rc={
|
||||||
|
'font.size': 10,
|
||||||
|
'axes.labelsize': 11,
|
||||||
|
'axes.titlesize': 12,
|
||||||
|
'xtick.labelsize': 9,
|
||||||
|
'ytick.labelsize': 9,
|
||||||
|
'legend.fontsize': 10,
|
||||||
|
'figure.facecolor': '#f0f0f0',
|
||||||
|
'axes.facecolor': '#ffffff',
|
||||||
|
'axes.grid': True,
|
||||||
|
'grid.alpha': 0.3,
|
||||||
|
'grid.linestyle': '--',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare plot parameters
|
||||||
|
plot_params = {
|
||||||
|
'type': 'candle',
|
||||||
|
'style': s,
|
||||||
|
'volume': volume,
|
||||||
|
'figsize': figsize,
|
||||||
|
'tight_layout': True,
|
||||||
|
'returnfig': True,
|
||||||
|
'warn_too_much_data': 1000, # Warn if > 1000 candles for performance
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add title if provided
|
||||||
|
if title:
|
||||||
|
plot_params['title'] = title
|
||||||
|
|
||||||
|
# Merge any additional kwargs
|
||||||
|
plot_params.update(kwargs)
|
||||||
|
|
||||||
|
# Create the plot
|
||||||
|
logger.info(
|
||||||
|
f"Creating OHLC chart with {len(df)} candles, "
|
||||||
|
f"date range: {df.index.min()} to {df.index.max()}, "
|
||||||
|
f"volume: {volume}"
|
||||||
|
)
|
||||||
|
|
||||||
|
fig, axes = mpf.plot(df, **plot_params)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def add_indicators_to_plot(
|
||||||
|
df: pd.DataFrame,
|
||||||
|
indicators: dict,
|
||||||
|
**plot_kwargs
|
||||||
|
) -> plt.Figure:
|
||||||
|
"""Create an OHLC chart with technical indicators overlaid.
|
||||||
|
|
||||||
|
This extends plot_ohlc() to include common technical indicators using
|
||||||
|
mplfinance's addplot functionality for proper overlay on candlestick charts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: pandas DataFrame with OHLCV data and indicator columns
|
||||||
|
indicators: Dictionary mapping indicator names to parameters
|
||||||
|
Example: {
|
||||||
|
'SMA_20': {'color': 'blue', 'width': 1.5},
|
||||||
|
'EMA_50': {'color': 'orange', 'width': 1.5}
|
||||||
|
}
|
||||||
|
**plot_kwargs: Additional arguments for plot_ohlc()
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
matplotlib.figure.Figure: The created figure object
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
# Calculate indicators
|
||||||
|
df['SMA_20'] = df['close'].rolling(20).mean()
|
||||||
|
df['SMA_50'] = df['close'].rolling(50).mean()
|
||||||
|
|
||||||
|
# Plot with indicators
|
||||||
|
fig = add_indicators_to_plot(
|
||||||
|
df,
|
||||||
|
indicators={
|
||||||
|
'SMA_20': {'color': 'blue', 'width': 1.5, 'label': '20 SMA'},
|
||||||
|
'SMA_50': {'color': 'red', 'width': 1.5, 'label': '50 SMA'}
|
||||||
|
},
|
||||||
|
title='BTC/USDT with Moving Averages'
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import mplfinance as mpf
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"mplfinance is required. Install it with: pip install mplfinance"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build addplot list for indicators
|
||||||
|
addplots = []
|
||||||
|
for indicator_col, params in indicators.items():
|
||||||
|
if indicator_col not in df.columns:
|
||||||
|
logger.warning(f"Indicator column '{indicator_col}' not found in DataFrame. Skipping.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
color = params.get('color', 'blue')
|
||||||
|
width = params.get('width', 1.0)
|
||||||
|
panel = params.get('panel', 0) # 0 = main panel with candles
|
||||||
|
ylabel = params.get('ylabel', '')
|
||||||
|
|
||||||
|
addplots.append(
|
||||||
|
mpf.make_addplot(
|
||||||
|
df[indicator_col],
|
||||||
|
color=color,
|
||||||
|
width=width,
|
||||||
|
panel=panel,
|
||||||
|
ylabel=ylabel
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pass addplot to plot_ohlc via kwargs
|
||||||
|
if addplots:
|
||||||
|
plot_kwargs['addplot'] = addplots
|
||||||
|
|
||||||
|
return plot_ohlc(df, **plot_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience presets for common chart types
|
||||||
|
def plot_price_volume(df: pd.DataFrame, title: Optional[str] = None) -> plt.Figure:
|
||||||
|
"""Create a standard price + volume chart."""
|
||||||
|
return plot_ohlc(df, title=title, volume=True, figsize=(14, 8))
|
||||||
|
|
||||||
|
|
||||||
|
def plot_price_only(df: pd.DataFrame, title: Optional[str] = None) -> plt.Figure:
|
||||||
|
"""Create a price-only candlestick chart without volume."""
|
||||||
|
return plot_ohlc(df, title=title, volume=False, figsize=(14, 6))
|
||||||
154
backend/src/agent/tools/chart_utils_example.py
Normal file
154
backend/src/agent/tools/chart_utils_example.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
"""
|
||||||
|
Example usage of chart_utils.py plotting functions.
|
||||||
|
|
||||||
|
This demonstrates how the LLM can use the plot_ohlc() convenience function
|
||||||
|
in analyze_chart_data scripts to create beautiful, standard OHLC charts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
|
||||||
|
def create_sample_data(days=30):
|
||||||
|
"""Create sample OHLCV data for testing."""
|
||||||
|
dates = pd.date_range(end=datetime.now(), periods=days * 24, freq='1H')
|
||||||
|
|
||||||
|
# Simulate price movement
|
||||||
|
np.random.seed(42)
|
||||||
|
close = 50000 + np.cumsum(np.random.randn(len(dates)) * 100)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
'open': close + np.random.randn(len(dates)) * 50,
|
||||||
|
'high': close + np.abs(np.random.randn(len(dates))) * 100,
|
||||||
|
'low': close - np.abs(np.random.randn(len(dates))) * 100,
|
||||||
|
'close': close,
|
||||||
|
'volume': np.abs(np.random.randn(len(dates))) * 1000000
|
||||||
|
}
|
||||||
|
|
||||||
|
df = pd.DataFrame(data, index=dates)
|
||||||
|
|
||||||
|
# Ensure high is highest and low is lowest
|
||||||
|
df['high'] = df[['open', 'high', 'low', 'close']].max(axis=1)
|
||||||
|
df['low'] = df[['open', 'high', 'low', 'close']].min(axis=1)
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from chart_utils import plot_ohlc, add_indicators_to_plot, plot_price_volume
|
||||||
|
|
||||||
|
# Create sample data
|
||||||
|
df = create_sample_data(days=30)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("Example 1: Basic OHLC chart with volume")
|
||||||
|
print("=" * 60)
|
||||||
|
print("\nScript the LLM would generate:")
|
||||||
|
print("""
|
||||||
|
fig = plot_ohlc(df, title='BTC/USDT 1H', volume=True)
|
||||||
|
df.tail(5)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Execute it
|
||||||
|
fig = plot_ohlc(df, title='BTC/USDT 1H', volume=True)
|
||||||
|
print("\n✓ Chart created successfully!")
|
||||||
|
print(f" Figure size: {fig.get_size_inches()}")
|
||||||
|
print(f" Number of axes: {len(fig.axes)}")
|
||||||
|
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Example 2: OHLC chart with indicators")
|
||||||
|
print("=" * 60)
|
||||||
|
print("\nScript the LLM would generate:")
|
||||||
|
print("""
|
||||||
|
# Calculate indicators
|
||||||
|
df['SMA_20'] = df['close'].rolling(20).mean()
|
||||||
|
df['SMA_50'] = df['close'].rolling(50).mean()
|
||||||
|
df['EMA_12'] = df['close'].ewm(span=12, adjust=False).mean()
|
||||||
|
|
||||||
|
# Plot with indicators
|
||||||
|
fig = add_indicators_to_plot(
|
||||||
|
df,
|
||||||
|
indicators={
|
||||||
|
'SMA_20': {'color': 'blue', 'width': 1.5},
|
||||||
|
'SMA_50': {'color': 'red', 'width': 1.5},
|
||||||
|
'EMA_12': {'color': 'green', 'width': 1.0}
|
||||||
|
},
|
||||||
|
title='BTC/USDT with Moving Averages',
|
||||||
|
volume=True
|
||||||
|
)
|
||||||
|
|
||||||
|
df[['close', 'SMA_20', 'SMA_50', 'EMA_12']].tail(5)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Execute it
|
||||||
|
df['SMA_20'] = df['close'].rolling(20).mean()
|
||||||
|
df['SMA_50'] = df['close'].rolling(50).mean()
|
||||||
|
df['EMA_12'] = df['close'].ewm(span=12, adjust=False).mean()
|
||||||
|
|
||||||
|
fig = add_indicators_to_plot(
|
||||||
|
df,
|
||||||
|
indicators={
|
||||||
|
'SMA_20': {'color': 'blue', 'width': 1.5},
|
||||||
|
'SMA_50': {'color': 'red', 'width': 1.5},
|
||||||
|
'EMA_12': {'color': 'green', 'width': 1.0}
|
||||||
|
},
|
||||||
|
title='BTC/USDT with Moving Averages',
|
||||||
|
volume=True
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n✓ Chart with indicators created successfully!")
|
||||||
|
print(f" Last close: ${df['close'].iloc[-1]:,.2f}")
|
||||||
|
print(f" SMA 20: ${df['SMA_20'].iloc[-1]:,.2f}")
|
||||||
|
print(f" SMA 50: ${df['SMA_50'].iloc[-1]:,.2f}")
|
||||||
|
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Example 3: Price-only chart (no volume)")
|
||||||
|
print("=" * 60)
|
||||||
|
print("\nScript the LLM would generate:")
|
||||||
|
print("""
|
||||||
|
from chart_utils import plot_price_only
|
||||||
|
|
||||||
|
fig = plot_price_only(df, title='Clean Price Action')
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Execute it
|
||||||
|
from chart_utils import plot_price_only
|
||||||
|
fig = plot_price_only(df, title='Clean Price Action')
|
||||||
|
|
||||||
|
print("\n✓ Price-only chart created successfully!")
|
||||||
|
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Summary")
|
||||||
|
print("=" * 60)
|
||||||
|
print("""
|
||||||
|
The chart_utils module provides:
|
||||||
|
|
||||||
|
1. plot_ohlc() - Main function for beautiful candlestick charts
|
||||||
|
- Professional seaborn-inspired styling
|
||||||
|
- Consistent color scheme (teal up, coral down)
|
||||||
|
- Optional volume subplot
|
||||||
|
- Customizable figure size
|
||||||
|
|
||||||
|
2. add_indicators_to_plot() - OHLC charts with technical indicators
|
||||||
|
- Overlay multiple indicators
|
||||||
|
- Customizable colors and line widths
|
||||||
|
- Proper integration with mplfinance
|
||||||
|
|
||||||
|
3. Preset functions for common chart types:
|
||||||
|
- plot_price_volume() - Standard price + volume
|
||||||
|
- plot_price_only() - Candlesticks without volume
|
||||||
|
|
||||||
|
Benefits:
|
||||||
|
✓ Consistent look and feel across all charts
|
||||||
|
✓ Less code for the LLM to generate
|
||||||
|
✓ Professional appearance out of the box
|
||||||
|
✓ Easy to customize when needed
|
||||||
|
✓ Works seamlessly with analyze_chart_data tool
|
||||||
|
|
||||||
|
The LLM can now simply call plot_ohlc(df) instead of writing
|
||||||
|
custom matplotlib code for every chart request!
|
||||||
|
""")
|
||||||
158
backend/src/agent/tools/datasource_tools.py
Normal file
158
backend/src/agent/tools/datasource_tools.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
"""Data source and market data tools."""
|
||||||
|
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
|
||||||
|
def _get_datasource_registry():
|
||||||
|
"""Get the global datasource registry instance."""
|
||||||
|
from . import _datasource_registry
|
||||||
|
return _datasource_registry
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def list_data_sources() -> List[str]:
|
||||||
|
"""List all available data sources.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of data source names that can be queried for market data
|
||||||
|
"""
|
||||||
|
registry = _get_datasource_registry()
|
||||||
|
if not registry:
|
||||||
|
return []
|
||||||
|
return registry.list_sources()
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def search_symbols(
|
||||||
|
query: str,
|
||||||
|
type: Optional[str] = None,
|
||||||
|
exchange: Optional[str] = None,
|
||||||
|
limit: int = 30,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Search for trading symbols across all data sources.
|
||||||
|
|
||||||
|
Automatically searches all available data sources and returns aggregated results.
|
||||||
|
Use this to find symbols before calling get_symbol_info or get_historical_data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query (e.g., "BTC", "AAPL", "EUR")
|
||||||
|
type: Optional filter by instrument type (e.g., "crypto", "stock", "forex")
|
||||||
|
exchange: Optional filter by exchange (e.g., "binance", "nasdaq")
|
||||||
|
limit: Maximum number of results per source (default: 30)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping source names to lists of matching symbols.
|
||||||
|
Each symbol includes: symbol, full_name, description, exchange, type.
|
||||||
|
Use the source name and symbol from results with get_symbol_info or get_historical_data.
|
||||||
|
|
||||||
|
Example response:
|
||||||
|
{
|
||||||
|
"demo": [
|
||||||
|
{
|
||||||
|
"symbol": "BTC/USDT",
|
||||||
|
"full_name": "Bitcoin / Tether USD",
|
||||||
|
"description": "Bitcoin perpetual futures",
|
||||||
|
"exchange": "demo",
|
||||||
|
"type": "crypto"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
registry = _get_datasource_registry()
|
||||||
|
if not registry:
|
||||||
|
raise ValueError("DataSourceRegistry not initialized")
|
||||||
|
|
||||||
|
# Always search all sources
|
||||||
|
results = await registry.search_all(query, type, exchange, limit)
|
||||||
|
return {name: [r.model_dump() for r in matches] for name, matches in results.items()}
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_symbol_info(source_name: str, symbol: str) -> Dict[str, Any]:
|
||||||
|
"""Get complete metadata for a trading symbol.
|
||||||
|
|
||||||
|
This retrieves full information about a symbol including:
|
||||||
|
- Description and type
|
||||||
|
- Supported time resolutions
|
||||||
|
- Available data columns (OHLCV, volume, funding rates, etc.)
|
||||||
|
- Trading session information
|
||||||
|
- Price scale and precision
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_name: Name of the data source (use list_data_sources to see available)
|
||||||
|
symbol: Symbol identifier (e.g., "BTC/USDT", "AAPL", "EUR/USD")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing complete symbol metadata including column schema
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If source_name or symbol is not found
|
||||||
|
"""
|
||||||
|
registry = _get_datasource_registry()
|
||||||
|
if not registry:
|
||||||
|
raise ValueError("DataSourceRegistry not initialized")
|
||||||
|
|
||||||
|
symbol_info = await registry.resolve_symbol(source_name, symbol)
|
||||||
|
return symbol_info.model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_historical_data(
|
||||||
|
source_name: str,
|
||||||
|
symbol: str,
|
||||||
|
resolution: str,
|
||||||
|
from_time: int,
|
||||||
|
to_time: int,
|
||||||
|
countback: Optional[int] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Get historical bar/candle data for a symbol.
|
||||||
|
|
||||||
|
Retrieves time-series data between the specified timestamps. The data
|
||||||
|
includes all columns defined for the symbol (OHLCV + any custom columns).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_name: Name of the data source
|
||||||
|
symbol: Symbol identifier
|
||||||
|
resolution: Time resolution (e.g., "1" = 1min, "5" = 5min, "60" = 1hour, "1D" = 1day)
|
||||||
|
from_time: Start time as Unix timestamp in seconds
|
||||||
|
to_time: End time as Unix timestamp in seconds
|
||||||
|
countback: Optional limit on number of bars to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing:
|
||||||
|
- symbol: The requested symbol
|
||||||
|
- resolution: The time resolution
|
||||||
|
- bars: List of bar data with 'time' and 'data' fields
|
||||||
|
- columns: Schema describing available data columns
|
||||||
|
- nextTime: If present, indicates more data is available for pagination
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If source, symbol, or resolution is invalid
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Get 1-hour BTC data for the last 24 hours
|
||||||
|
import time
|
||||||
|
to_time = int(time.time())
|
||||||
|
from_time = to_time - 86400 # 24 hours ago
|
||||||
|
data = get_historical_data("demo", "BTC/USDT", "60", from_time, to_time)
|
||||||
|
"""
|
||||||
|
registry = _get_datasource_registry()
|
||||||
|
if not registry:
|
||||||
|
raise ValueError("DataSourceRegistry not initialized")
|
||||||
|
|
||||||
|
source = registry.get(source_name)
|
||||||
|
if not source:
|
||||||
|
available = registry.list_sources()
|
||||||
|
raise ValueError(f"Data source '{source_name}' not found. Available sources: {available}")
|
||||||
|
|
||||||
|
result = await source.get_bars(symbol, resolution, from_time, to_time, countback)
|
||||||
|
return result.model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
DATASOURCE_TOOLS = [
|
||||||
|
list_data_sources,
|
||||||
|
search_symbols,
|
||||||
|
get_symbol_info,
|
||||||
|
get_historical_data,
|
||||||
|
]
|
||||||
169
backend/src/agent/tools/indicator_tools.py
Normal file
169
backend/src/agent/tools/indicator_tools.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
"""Technical indicator tools."""
|
||||||
|
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
|
||||||
|
def _get_indicator_registry():
|
||||||
|
"""Get the global indicator registry instance."""
|
||||||
|
from . import _indicator_registry
|
||||||
|
return _indicator_registry
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def list_indicators() -> List[str]:
|
||||||
|
"""List all available technical indicators.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of indicator names that can be used in analysis and strategies
|
||||||
|
"""
|
||||||
|
registry = _get_indicator_registry()
|
||||||
|
if not registry:
|
||||||
|
return []
|
||||||
|
return registry.list_indicators()
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_indicator_info(indicator_name: str) -> Dict[str, Any]:
|
||||||
|
"""Get detailed information about a specific indicator.
|
||||||
|
|
||||||
|
Retrieves metadata including description, parameters, category, use cases,
|
||||||
|
input/output schemas, and references.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indicator_name: Name of the indicator (e.g., "RSI", "SMA", "MACD")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing:
|
||||||
|
- name: Indicator name
|
||||||
|
- display_name: Human-readable name
|
||||||
|
- description: What the indicator computes and why it's useful
|
||||||
|
- category: Category (momentum, trend, volatility, volume, etc.)
|
||||||
|
- parameters: List of configurable parameters with types and defaults
|
||||||
|
- use_cases: Common trading scenarios where this indicator helps
|
||||||
|
- tags: Searchable tags
|
||||||
|
- input_schema: Required input columns (e.g., OHLCV requirements)
|
||||||
|
- output_schema: Columns this indicator produces
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If indicator_name is not found
|
||||||
|
"""
|
||||||
|
registry = _get_indicator_registry()
|
||||||
|
if not registry:
|
||||||
|
raise ValueError("IndicatorRegistry not initialized")
|
||||||
|
|
||||||
|
metadata = registry.get_metadata(indicator_name)
|
||||||
|
if not metadata:
|
||||||
|
total_count = len(registry.list_indicators())
|
||||||
|
raise ValueError(
|
||||||
|
f"Indicator '{indicator_name}' not found. "
|
||||||
|
f"Total available: {total_count} indicators. "
|
||||||
|
f"Use search_indicators() to find indicators by name, category, or tag."
|
||||||
|
)
|
||||||
|
|
||||||
|
input_schema = registry.get_input_schema(indicator_name)
|
||||||
|
output_schema = registry.get_output_schema(indicator_name)
|
||||||
|
|
||||||
|
result = metadata.model_dump()
|
||||||
|
result["input_schema"] = input_schema.model_dump() if input_schema else None
|
||||||
|
result["output_schema"] = output_schema.model_dump() if output_schema else None
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def search_indicators(
|
||||||
|
query: Optional[str] = None,
|
||||||
|
category: Optional[str] = None,
|
||||||
|
tag: Optional[str] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Search for indicators by text query, category, or tag.
|
||||||
|
|
||||||
|
Returns lightweight summaries - use get_indicator_info() for full details on specific indicators.
|
||||||
|
|
||||||
|
Use this to discover relevant indicators for your trading strategy or analysis.
|
||||||
|
Can filter by category (momentum, trend, volatility, etc.) or search by keywords.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Optional text search across names, descriptions, and use cases
|
||||||
|
category: Optional category filter (momentum, trend, volatility, volume, pattern, etc.)
|
||||||
|
tag: Optional tag filter (e.g., "oscillator", "moving-average", "talib")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of lightweight indicator summaries. Each contains:
|
||||||
|
- name: Indicator name (use with get_indicator_info() for full details)
|
||||||
|
- display_name: Human-readable name
|
||||||
|
- description: Brief one-line description
|
||||||
|
- category: Category (momentum, trend, volatility, etc.)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Find all momentum indicators
|
||||||
|
results = search_indicators(category="momentum")
|
||||||
|
# Returns [{name: "RSI", display_name: "RSI", description: "...", category: "momentum"}, ...]
|
||||||
|
|
||||||
|
# Then get details on interesting ones
|
||||||
|
rsi_details = get_indicator_info("RSI") # Full parameters, schemas, use cases
|
||||||
|
|
||||||
|
# Search for moving average indicators
|
||||||
|
search_indicators(query="moving average")
|
||||||
|
|
||||||
|
# Find all TA-Lib indicators
|
||||||
|
search_indicators(tag="talib")
|
||||||
|
"""
|
||||||
|
registry = _get_indicator_registry()
|
||||||
|
if not registry:
|
||||||
|
raise ValueError("IndicatorRegistry not initialized")
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
if query:
|
||||||
|
results = registry.search_by_text(query)
|
||||||
|
elif category:
|
||||||
|
results = registry.search_by_category(category)
|
||||||
|
elif tag:
|
||||||
|
results = registry.search_by_tag(tag)
|
||||||
|
else:
|
||||||
|
# Return all indicators if no filter
|
||||||
|
results = registry.get_all_metadata()
|
||||||
|
|
||||||
|
# Return lightweight summaries only
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": r.name,
|
||||||
|
"display_name": r.display_name,
|
||||||
|
"description": r.description,
|
||||||
|
"category": r.category
|
||||||
|
}
|
||||||
|
for r in results
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_indicator_categories() -> Dict[str, int]:
|
||||||
|
"""Get all indicator categories and their counts.
|
||||||
|
|
||||||
|
Returns a summary of available indicator categories, useful for
|
||||||
|
exploring what types of indicators are available.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping category name to count of indicators in that category.
|
||||||
|
Example: {"momentum": 25, "trend": 15, "volatility": 8, ...}
|
||||||
|
"""
|
||||||
|
registry = _get_indicator_registry()
|
||||||
|
if not registry:
|
||||||
|
raise ValueError("IndicatorRegistry not initialized")
|
||||||
|
|
||||||
|
categories: Dict[str, int] = {}
|
||||||
|
for metadata in registry.get_all_metadata():
|
||||||
|
category = metadata.category
|
||||||
|
categories[category] = categories.get(category, 0) + 1
|
||||||
|
|
||||||
|
return categories
|
||||||
|
|
||||||
|
|
||||||
|
INDICATOR_TOOLS = [
|
||||||
|
list_indicators,
|
||||||
|
get_indicator_info,
|
||||||
|
search_indicators,
|
||||||
|
get_indicator_categories
|
||||||
|
]
|
||||||
171
backend/src/agent/tools/research_tools.py
Normal file
171
backend/src/agent/tools/research_tools.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""Research and external data tools for trading analysis."""
|
||||||
|
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langchain_community.tools import (
|
||||||
|
ArxivQueryRun,
|
||||||
|
WikipediaQueryRun,
|
||||||
|
DuckDuckGoSearchRun
|
||||||
|
)
|
||||||
|
from langchain_community.utilities import (
|
||||||
|
ArxivAPIWrapper,
|
||||||
|
WikipediaAPIWrapper,
|
||||||
|
DuckDuckGoSearchAPIWrapper
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def search_arxiv(query: str, max_results: int = 5) -> str:
|
||||||
|
"""Search arXiv for academic papers on quantitative finance, trading strategies, and machine learning.
|
||||||
|
|
||||||
|
Use this to find research papers on topics like:
|
||||||
|
- Market microstructure and order flow
|
||||||
|
- Algorithmic trading strategies
|
||||||
|
- Machine learning for finance
|
||||||
|
- Time series forecasting
|
||||||
|
- Risk management
|
||||||
|
- Portfolio optimization
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query (e.g., "machine learning algorithmic trading", "deep learning stock prediction")
|
||||||
|
max_results: Maximum number of results to return (default: 5)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Summary of papers including titles, authors, abstracts, and links
|
||||||
|
|
||||||
|
Example:
|
||||||
|
search_arxiv("reinforcement learning trading", max_results=3)
|
||||||
|
"""
|
||||||
|
arxiv = ArxivQueryRun(api_wrapper=ArxivAPIWrapper(top_k_results=max_results))
|
||||||
|
return arxiv.run(query)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def search_wikipedia(query: str) -> str:
|
||||||
|
"""Search Wikipedia for information on finance, trading, and economics concepts.
|
||||||
|
|
||||||
|
Use this to get background information on:
|
||||||
|
- Financial instruments and markets
|
||||||
|
- Economic indicators
|
||||||
|
- Trading terminology
|
||||||
|
- Technical analysis concepts
|
||||||
|
- Historical market events
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query (e.g., "Black-Scholes model", "technical analysis", "options trading")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Wikipedia article summary with key information
|
||||||
|
|
||||||
|
Example:
|
||||||
|
search_wikipedia("Bollinger Bands")
|
||||||
|
"""
|
||||||
|
wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
|
||||||
|
return wikipedia.run(query)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def search_web(query: str, max_results: int = 5) -> str:
|
||||||
|
"""Search the web for current information on markets, news, and trading.
|
||||||
|
|
||||||
|
Use this to find:
|
||||||
|
- Latest market news and analysis
|
||||||
|
- Company announcements and earnings
|
||||||
|
- Economic events and indicators
|
||||||
|
- Cryptocurrency updates
|
||||||
|
- Exchange status and updates
|
||||||
|
- Trading strategy discussions
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query (e.g., "Bitcoin price news", "Fed interest rate decision")
|
||||||
|
max_results: Maximum number of results to return (default: 5)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Search results with titles, snippets, and links
|
||||||
|
|
||||||
|
Example:
|
||||||
|
search_web("Ethereum merge update", max_results=3)
|
||||||
|
"""
|
||||||
|
# Lazy initialization to avoid hanging during import
|
||||||
|
search = DuckDuckGoSearchRun(api_wrapper=DuckDuckGoSearchAPIWrapper())
|
||||||
|
# Note: max_results parameter doesn't work properly with current wrapper
|
||||||
|
return search.run(query)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def http_get(url: str, params: Optional[Dict[str, str]] = None) -> str:
|
||||||
|
"""Make HTTP GET request to fetch data from APIs or web pages.
|
||||||
|
|
||||||
|
Use this to retrieve:
|
||||||
|
- Exchange API data (if public endpoints)
|
||||||
|
- Market data from external APIs
|
||||||
|
- Documentation and specifications
|
||||||
|
- News articles and blog posts
|
||||||
|
- JSON/XML data from web services
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The URL to fetch
|
||||||
|
params: Optional query parameters as a dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response text from the URL
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the request fails
|
||||||
|
|
||||||
|
Example:
|
||||||
|
http_get("https://api.coingecko.com/api/v3/simple/price",
|
||||||
|
params={"ids": "bitcoin", "vs_currencies": "usd"})
|
||||||
|
"""
|
||||||
|
import requests
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.get(url, params=params, timeout=10)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.text
|
||||||
|
except requests.RequestException as e:
|
||||||
|
raise ValueError(f"HTTP GET request failed: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def http_post(url: str, data: Dict[str, Any]) -> str:
|
||||||
|
"""Make HTTP POST request to send data to APIs.
|
||||||
|
|
||||||
|
Use this to:
|
||||||
|
- Submit data to external APIs
|
||||||
|
- Trigger webhooks
|
||||||
|
- Post analysis results
|
||||||
|
- Interact with exchange APIs (if authenticated)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The URL to post to
|
||||||
|
data: Dictionary of data to send in the request body
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response text from the server
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the request fails
|
||||||
|
|
||||||
|
Example:
|
||||||
|
http_post("https://webhook.site/xxx", {"message": "Trade executed"})
|
||||||
|
"""
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(url, json=data, timeout=10)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.text
|
||||||
|
except requests.RequestException as e:
|
||||||
|
raise ValueError(f"HTTP POST request failed: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
# Export tools list
|
||||||
|
RESEARCH_TOOLS = [
|
||||||
|
search_arxiv,
|
||||||
|
search_wikipedia,
|
||||||
|
search_web,
|
||||||
|
http_get,
|
||||||
|
http_post
|
||||||
|
]
|
||||||
138
backend/src/agent/tools/sync_tools.py
Normal file
138
backend/src/agent/tools/sync_tools.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
"""Synchronization store tools."""
|
||||||
|
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
|
||||||
|
def _get_registry():
|
||||||
|
"""Get the global registry instance."""
|
||||||
|
from . import _registry
|
||||||
|
return _registry
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def list_sync_stores() -> List[str]:
|
||||||
|
"""List all available synchronization stores.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of store names that can be read/written
|
||||||
|
"""
|
||||||
|
registry = _get_registry()
|
||||||
|
if not registry:
|
||||||
|
return []
|
||||||
|
return list(registry.entries.keys())
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def read_sync_state(store_name: str) -> Dict[str, Any]:
|
||||||
|
"""Read the current state of a synchronization store.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
store_name: Name of the store to read (e.g., "TraderState", "StrategyState")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing the current state of the store
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If store_name doesn't exist
|
||||||
|
"""
|
||||||
|
registry = _get_registry()
|
||||||
|
if not registry:
|
||||||
|
raise ValueError("SyncRegistry not initialized")
|
||||||
|
|
||||||
|
entry = registry.entries.get(store_name)
|
||||||
|
if not entry:
|
||||||
|
available = list(registry.entries.keys())
|
||||||
|
raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
|
||||||
|
|
||||||
|
return entry.model.model_dump(mode="json")
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def write_sync_state(store_name: str, updates: Dict[str, Any]) -> Dict[str, str]:
|
||||||
|
"""Update the state of a synchronization store.
|
||||||
|
|
||||||
|
This will apply the updates to the store and trigger synchronization
|
||||||
|
with all connected clients.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
store_name: Name of the store to update
|
||||||
|
updates: Dictionary of field updates (field_name: new_value)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status and updated fields
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If store_name doesn't exist or updates are invalid
|
||||||
|
"""
|
||||||
|
registry = _get_registry()
|
||||||
|
if not registry:
|
||||||
|
raise ValueError("SyncRegistry not initialized")
|
||||||
|
|
||||||
|
entry = registry.entries.get(store_name)
|
||||||
|
if not entry:
|
||||||
|
available = list(registry.entries.keys())
|
||||||
|
raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get current state
|
||||||
|
current_state = entry.model.model_dump(mode="json")
|
||||||
|
|
||||||
|
# Apply updates
|
||||||
|
new_state = {**current_state, **updates}
|
||||||
|
|
||||||
|
# Update the model
|
||||||
|
registry._update_model(entry.model, new_state)
|
||||||
|
|
||||||
|
# Trigger sync
|
||||||
|
await registry.push_all()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"store": store_name,
|
||||||
|
"updated_fields": list(updates.keys())
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Failed to update store '{store_name}': {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_store_schema(store_name: str) -> Dict[str, Any]:
|
||||||
|
"""Get the schema/structure of a synchronization store.
|
||||||
|
|
||||||
|
This shows what fields are available and their types.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
store_name: Name of the store
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary describing the store's schema
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If store_name doesn't exist
|
||||||
|
"""
|
||||||
|
registry = _get_registry()
|
||||||
|
if not registry:
|
||||||
|
raise ValueError("SyncRegistry not initialized")
|
||||||
|
|
||||||
|
entry = registry.entries.get(store_name)
|
||||||
|
if not entry:
|
||||||
|
available = list(registry.entries.keys())
|
||||||
|
raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
|
||||||
|
|
||||||
|
# Get model schema
|
||||||
|
schema = entry.model.model_json_schema()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"store_name": store_name,
|
||||||
|
"schema": schema
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
SYNC_TOOLS = [
|
||||||
|
list_sync_stores,
|
||||||
|
read_sync_state,
|
||||||
|
write_sync_state,
|
||||||
|
get_store_schema
|
||||||
|
]
|
||||||
@@ -6,9 +6,10 @@ the free CCXT library (not ccxt.pro), supporting both historical data and
|
|||||||
polling-based subscriptions.
|
polling-based subscriptions.
|
||||||
|
|
||||||
Numerical Precision:
|
Numerical Precision:
|
||||||
- Uses Decimal for all monetary values (prices, volumes) to avoid floating-point errors
|
- OHLCV data uses native floats for optimal DataFrame/analysis performance
|
||||||
|
- Account balances and order data should use Decimal (via _to_decimal method)
|
||||||
- CCXT returns numeric values as strings or floats depending on configuration
|
- CCXT returns numeric values as strings or floats depending on configuration
|
||||||
- All financial values are converted to Decimal to maintain precision
|
- Price data converted to float (_to_float), financial data to Decimal (_to_decimal)
|
||||||
|
|
||||||
Real-time Updates:
|
Real-time Updates:
|
||||||
- Uses polling instead of WebSocket (free CCXT doesn't have WebSocket support)
|
- Uses polling instead of WebSocket (free CCXT doesn't have WebSocket support)
|
||||||
@@ -72,6 +73,20 @@ class CCXTDataSource(DataSource):
|
|||||||
exchange_class = getattr(ccxt, exchange_id)
|
exchange_class = getattr(ccxt, exchange_id)
|
||||||
self.exchange = exchange_class(self._config)
|
self.exchange = exchange_class(self._config)
|
||||||
|
|
||||||
|
# Configure CCXT to use Decimal mode for precise financial calculations
|
||||||
|
# This ensures all numeric values from the exchange use Decimal internally
|
||||||
|
# We then convert OHLCV to float for DataFrame performance, but keep
|
||||||
|
# Decimal precision for account balances, order sizes, etc.
|
||||||
|
from decimal import Decimal as PythonDecimal
|
||||||
|
self.exchange.number = PythonDecimal
|
||||||
|
|
||||||
|
# Log the precision mode being used by this exchange
|
||||||
|
precision_mode = getattr(self.exchange, 'precisionMode', 'UNKNOWN')
|
||||||
|
logger.info(
|
||||||
|
f"CCXT {exchange_id}: Configured with Decimal mode. "
|
||||||
|
f"Exchange precision mode: {precision_mode}"
|
||||||
|
)
|
||||||
|
|
||||||
if sandbox and hasattr(self.exchange, 'set_sandbox_mode'):
|
if sandbox and hasattr(self.exchange, 'set_sandbox_mode'):
|
||||||
self.exchange.set_sandbox_mode(True)
|
self.exchange.set_sandbox_mode(True)
|
||||||
|
|
||||||
@@ -103,6 +118,33 @@ class CCXTDataSource(DataSource):
|
|||||||
return Decimal(str(value))
|
return Decimal(str(value))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _to_float(value: Union[str, int, float, Decimal, None]) -> Optional[float]:
|
||||||
|
"""
|
||||||
|
Convert a value to float for OHLCV data.
|
||||||
|
|
||||||
|
OHLCV data is used for charting and DataFrame analysis, where native
|
||||||
|
floats provide better performance and compatibility with pandas/numpy.
|
||||||
|
For financial precision (balances, order sizes), use _to_decimal() instead.
|
||||||
|
|
||||||
|
When CCXT is in Decimal mode (exchange.number = Decimal), it returns
|
||||||
|
Decimal objects. This method converts them to float for performance.
|
||||||
|
|
||||||
|
Handles CCXT's output in both modes:
|
||||||
|
- Decimal mode: receives Decimal objects
|
||||||
|
- Default mode: receives strings, floats, or ints
|
||||||
|
"""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, float):
|
||||||
|
return value
|
||||||
|
if isinstance(value, Decimal):
|
||||||
|
# CCXT in Decimal mode - convert to float for OHLCV
|
||||||
|
return float(value)
|
||||||
|
if isinstance(value, (str, int)):
|
||||||
|
return float(value)
|
||||||
|
return None
|
||||||
|
|
||||||
async def _ensure_markets_loaded(self):
|
async def _ensure_markets_loaded(self):
|
||||||
"""Ensure markets are loaded from exchange"""
|
"""Ensure markets are loaded from exchange"""
|
||||||
if not self._markets_loaded:
|
if not self._markets_loaded:
|
||||||
@@ -241,31 +283,31 @@ class CCXTDataSource(DataSource):
|
|||||||
columns=[
|
columns=[
|
||||||
ColumnInfo(
|
ColumnInfo(
|
||||||
name="open",
|
name="open",
|
||||||
type="decimal",
|
type="float",
|
||||||
description=f"Opening price in {quote}",
|
description=f"Opening price in {quote}",
|
||||||
unit=quote,
|
unit=quote,
|
||||||
),
|
),
|
||||||
ColumnInfo(
|
ColumnInfo(
|
||||||
name="high",
|
name="high",
|
||||||
type="decimal",
|
type="float",
|
||||||
description=f"Highest price in {quote}",
|
description=f"Highest price in {quote}",
|
||||||
unit=quote,
|
unit=quote,
|
||||||
),
|
),
|
||||||
ColumnInfo(
|
ColumnInfo(
|
||||||
name="low",
|
name="low",
|
||||||
type="decimal",
|
type="float",
|
||||||
description=f"Lowest price in {quote}",
|
description=f"Lowest price in {quote}",
|
||||||
unit=quote,
|
unit=quote,
|
||||||
),
|
),
|
||||||
ColumnInfo(
|
ColumnInfo(
|
||||||
name="close",
|
name="close",
|
||||||
type="decimal",
|
type="float",
|
||||||
description=f"Closing price in {quote}",
|
description=f"Closing price in {quote}",
|
||||||
unit=quote,
|
unit=quote,
|
||||||
),
|
),
|
||||||
ColumnInfo(
|
ColumnInfo(
|
||||||
name="volume",
|
name="volume",
|
||||||
type="decimal",
|
type="float",
|
||||||
description=f"Trading volume in {base}",
|
description=f"Trading volume in {base}",
|
||||||
unit=base,
|
unit=base,
|
||||||
),
|
),
|
||||||
@@ -370,7 +412,7 @@ class CCXTDataSource(DataSource):
|
|||||||
all_ohlcv = all_ohlcv[:countback]
|
all_ohlcv = all_ohlcv[:countback]
|
||||||
break
|
break
|
||||||
|
|
||||||
# Convert to our Bar format with Decimal precision
|
# Convert to our Bar format with float for OHLCV (used in DataFrames)
|
||||||
bars = []
|
bars = []
|
||||||
for candle in all_ohlcv:
|
for candle in all_ohlcv:
|
||||||
timestamp_ms, open_price, high, low, close, volume = candle
|
timestamp_ms, open_price, high, low, close, volume = candle
|
||||||
@@ -384,11 +426,11 @@ class CCXTDataSource(DataSource):
|
|||||||
Bar(
|
Bar(
|
||||||
time=timestamp,
|
time=timestamp,
|
||||||
data={
|
data={
|
||||||
"open": self._to_decimal(open_price),
|
"open": self._to_float(open_price),
|
||||||
"high": self._to_decimal(high),
|
"high": self._to_float(high),
|
||||||
"low": self._to_decimal(low),
|
"low": self._to_float(low),
|
||||||
"close": self._to_decimal(close),
|
"close": self._to_float(close),
|
||||||
"volume": self._to_decimal(volume),
|
"volume": self._to_float(volume),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -476,14 +518,14 @@ class CCXTDataSource(DataSource):
|
|||||||
if timestamp > last_timestamp:
|
if timestamp > last_timestamp:
|
||||||
self._last_bars[subscription_id] = timestamp
|
self._last_bars[subscription_id] = timestamp
|
||||||
|
|
||||||
# Convert to our format with Decimal precision
|
# Convert to our format with float for OHLCV (used in DataFrames)
|
||||||
tick_data = {
|
tick_data = {
|
||||||
"time": timestamp,
|
"time": timestamp,
|
||||||
"open": self._to_decimal(open_price),
|
"open": self._to_float(open_price),
|
||||||
"high": self._to_decimal(high),
|
"high": self._to_float(high),
|
||||||
"low": self._to_decimal(low),
|
"low": self._to_float(low),
|
||||||
"close": self._to_decimal(close),
|
"close": self._to_float(close),
|
||||||
"volume": self._to_decimal(volume),
|
"volume": self._to_float(volume),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Call the callback
|
# Call the callback
|
||||||
|
|||||||
179
backend/src/exchange_kernel/README.md
Normal file
179
backend/src/exchange_kernel/README.md
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
# Exchange Kernel API
|
||||||
|
|
||||||
|
A Kubernetes-style declarative API for managing orders across different exchanges.
|
||||||
|
|
||||||
|
## Architecture Overview
|
||||||
|
|
||||||
|
The Exchange Kernel maintains two separate views of order state:
|
||||||
|
|
||||||
|
1. **Desired State (Intent)**: What the strategy kernel wants
|
||||||
|
2. **Actual State (Reality)**: What currently exists on the exchange
|
||||||
|
|
||||||
|
A reconciliation loop continuously works to bring actual state into alignment with desired state, handling errors, retries, and edge cases automatically.
|
||||||
|
|
||||||
|
## Core Components
|
||||||
|
|
||||||
|
### Models (`models.py`)
|
||||||
|
|
||||||
|
- **OrderIntent**: Desired order state from strategy kernel
|
||||||
|
- **OrderState**: Actual current order state on exchange
|
||||||
|
- **Position**: Current position (spot, margin, perp, futures, options)
|
||||||
|
- **Asset**: Asset holdings with metadata
|
||||||
|
- **AccountState**: Complete account snapshot (balances, positions, margin)
|
||||||
|
- **AssetMetadata**: Asset type descriptions and trading parameters
|
||||||
|
|
||||||
|
### Events (`events.py`)
|
||||||
|
|
||||||
|
Order lifecycle events:
|
||||||
|
- `OrderSubmitted`, `OrderAccepted`, `OrderRejected`
|
||||||
|
- `OrderPartiallyFilled`, `OrderFilled`, `OrderCanceled`
|
||||||
|
- `OrderModified`, `OrderExpired`
|
||||||
|
|
||||||
|
Position events:
|
||||||
|
- `PositionOpened`, `PositionModified`, `PositionClosed`
|
||||||
|
|
||||||
|
Account events:
|
||||||
|
- `AccountBalanceUpdated`, `MarginCallWarning`
|
||||||
|
|
||||||
|
### Base Interface (`base.py`)
|
||||||
|
|
||||||
|
Abstract `ExchangeKernel` class defining:
|
||||||
|
|
||||||
|
**Command API**:
|
||||||
|
- `place_order()`, `place_order_group()` - Create order intents
|
||||||
|
- `cancel_order()`, `modify_order()` - Update intents
|
||||||
|
- `cancel_all_orders()` - Bulk cancellation
|
||||||
|
|
||||||
|
**Query API**:
|
||||||
|
- `get_order_intent()`, `get_order_state()` - Query single order
|
||||||
|
- `get_all_intents()`, `get_all_orders()` - Query all orders
|
||||||
|
- `get_positions()`, `get_account_state()` - Query positions/balances
|
||||||
|
- `get_symbol_metadata()`, `get_asset_metadata()` - Query market info
|
||||||
|
|
||||||
|
**Event API**:
|
||||||
|
- `subscribe_events()`, `unsubscribe_events()` - Event notifications
|
||||||
|
|
||||||
|
**Lifecycle**:
|
||||||
|
- `start()`, `stop()` - Kernel lifecycle
|
||||||
|
- `health_check()` - Connection status
|
||||||
|
- `force_reconciliation()` - Manual reconciliation trigger
|
||||||
|
|
||||||
|
### State Management (`state.py`)
|
||||||
|
|
||||||
|
- **IntentStateStore**: Storage for desired state (durable, survives restarts)
|
||||||
|
- **ActualStateStore**: Storage for actual exchange state (ephemeral cache)
|
||||||
|
- **ReconciliationEngine**: Framework for intent→reality reconciliation
|
||||||
|
- **InMemory implementations**: For testing/prototyping
|
||||||
|
|
||||||
|
## Standard Order Model
|
||||||
|
|
||||||
|
Defined in `schema/order_spec.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
StandardOrder(
|
||||||
|
symbol_id="BTC/USD",
|
||||||
|
side=Side.BUY,
|
||||||
|
amount=1.0,
|
||||||
|
amount_type=AmountType.BASE, # or QUOTE for exact-out
|
||||||
|
limit_price=50000.0, # None for market orders
|
||||||
|
time_in_force=TimeInForce.GTC,
|
||||||
|
conditional_trigger=ConditionalTrigger(...), # Optional stop-loss/take-profit
|
||||||
|
conditional_mode=ConditionalOrderMode.UNIFIED_ADJUSTING,
|
||||||
|
reduce_only=False,
|
||||||
|
post_only=False,
|
||||||
|
iceberg_qty=None,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Symbol Metadata
|
||||||
|
|
||||||
|
Markets describe their capabilities via `SymbolMetadata`:
|
||||||
|
|
||||||
|
- **AmountConstraints**: Min/max order size, step size
|
||||||
|
- **PriceConstraints**: Tick size, tick spacing mode (fixed/dynamic/continuous)
|
||||||
|
- **MarketCapabilities**:
|
||||||
|
- Supported sides (BUY, SELL)
|
||||||
|
- Supported amount types (BASE, QUOTE, or both)
|
||||||
|
- Market vs limit order support
|
||||||
|
- Time-in-force options (GTC, IOC, FOK, DAY, GTD)
|
||||||
|
- Conditional order support (stop-loss, take-profit, trailing stops)
|
||||||
|
- Advanced features (post-only, reduce-only, iceberg)
|
||||||
|
|
||||||
|
## Asset Types
|
||||||
|
|
||||||
|
Comprehensive asset type system supporting:
|
||||||
|
- **SPOT**: Cash markets
|
||||||
|
- **MARGIN**: Margin trading
|
||||||
|
- **PERP**: Perpetual futures
|
||||||
|
- **FUTURE**: Dated futures
|
||||||
|
- **OPTION**: Options contracts
|
||||||
|
- **SYNTHETIC**: Derived instruments
|
||||||
|
|
||||||
|
Each asset has metadata describing contract specs, settlement, margin requirements, etc.
|
||||||
|
|
||||||
|
## Usage Pattern
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Create exchange kernel for specific exchange
|
||||||
|
kernel = SomeExchangeKernel(exchange_id="binance_main")
|
||||||
|
|
||||||
|
# Subscribe to events
|
||||||
|
kernel.subscribe_events(my_event_handler)
|
||||||
|
|
||||||
|
# Start kernel
|
||||||
|
await kernel.start()
|
||||||
|
|
||||||
|
# Place order (creates intent, kernel handles execution)
|
||||||
|
intent_id = await kernel.place_order(
|
||||||
|
StandardOrder(
|
||||||
|
symbol_id="BTC/USD",
|
||||||
|
side=Side.BUY,
|
||||||
|
amount=1.0,
|
||||||
|
amount_type=AmountType.BASE,
|
||||||
|
limit_price=50000.0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Query desired state
|
||||||
|
intent = await kernel.get_order_intent(intent_id)
|
||||||
|
|
||||||
|
# Query actual state
|
||||||
|
state = await kernel.get_order_state(intent_id)
|
||||||
|
|
||||||
|
# Modify order (updates intent, kernel reconciles)
|
||||||
|
await kernel.modify_order(intent_id, new_order)
|
||||||
|
|
||||||
|
# Cancel order
|
||||||
|
await kernel.cancel_order(intent_id)
|
||||||
|
|
||||||
|
# Query positions
|
||||||
|
positions = await kernel.get_positions()
|
||||||
|
|
||||||
|
# Query account state
|
||||||
|
account = await kernel.get_account_state()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Implementation Status
|
||||||
|
|
||||||
|
✅ **Complete**:
|
||||||
|
- Data models and type definitions
|
||||||
|
- Event definitions
|
||||||
|
- Abstract interface
|
||||||
|
- State store framework
|
||||||
|
- In-memory stores for testing
|
||||||
|
|
||||||
|
⏳ **TODO** (Exchange-specific implementations):
|
||||||
|
- Concrete ExchangeKernel implementations per exchange
|
||||||
|
- Reconciliation engine implementation
|
||||||
|
- Exchange API adapters
|
||||||
|
- Persistent state storage (database)
|
||||||
|
- Error handling and retry logic
|
||||||
|
- Monitoring and observability
|
||||||
|
|
||||||
|
## Next Steps
|
||||||
|
|
||||||
|
1. Create concrete implementations for specific exchanges (Binance, Uniswap, etc.)
|
||||||
|
2. Implement reconciliation engine with proper error handling
|
||||||
|
3. Add persistent storage backend for intents
|
||||||
|
4. Build integration tests
|
||||||
|
5. Add monitoring/metrics collection
|
||||||
75
backend/src/exchange_kernel/__init__.py
Normal file
75
backend/src/exchange_kernel/__init__.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
"""
|
||||||
|
Exchange Kernel API
|
||||||
|
|
||||||
|
The exchange kernel provides a Kubernetes-style declarative API for managing orders
|
||||||
|
across different exchanges. It maintains both desired state (intent) and actual state
|
||||||
|
(current orders on exchange) and reconciles them continuously.
|
||||||
|
|
||||||
|
Key concepts:
|
||||||
|
- OrderIntent: What the strategy kernel wants
|
||||||
|
- OrderState: What actually exists on the exchange
|
||||||
|
- Reconciliation: Bringing actual state into alignment with desired state
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import ExchangeKernel
|
||||||
|
from .events import (
|
||||||
|
OrderEvent,
|
||||||
|
OrderSubmitted,
|
||||||
|
OrderAccepted,
|
||||||
|
OrderRejected,
|
||||||
|
OrderPartiallyFilled,
|
||||||
|
OrderFilled,
|
||||||
|
OrderCanceled,
|
||||||
|
OrderModified,
|
||||||
|
OrderExpired,
|
||||||
|
PositionEvent,
|
||||||
|
PositionOpened,
|
||||||
|
PositionModified,
|
||||||
|
PositionClosed,
|
||||||
|
AccountEvent,
|
||||||
|
AccountBalanceUpdated,
|
||||||
|
MarginCallWarning,
|
||||||
|
)
|
||||||
|
from .models import (
|
||||||
|
OrderIntent,
|
||||||
|
OrderState,
|
||||||
|
Position,
|
||||||
|
Asset,
|
||||||
|
AssetMetadata,
|
||||||
|
AccountState,
|
||||||
|
Balance,
|
||||||
|
)
|
||||||
|
from .state import IntentStateStore, ActualStateStore
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Core interface
|
||||||
|
"ExchangeKernel",
|
||||||
|
# Events
|
||||||
|
"OrderEvent",
|
||||||
|
"OrderSubmitted",
|
||||||
|
"OrderAccepted",
|
||||||
|
"OrderRejected",
|
||||||
|
"OrderPartiallyFilled",
|
||||||
|
"OrderFilled",
|
||||||
|
"OrderCanceled",
|
||||||
|
"OrderModified",
|
||||||
|
"OrderExpired",
|
||||||
|
"PositionEvent",
|
||||||
|
"PositionOpened",
|
||||||
|
"PositionModified",
|
||||||
|
"PositionClosed",
|
||||||
|
"AccountEvent",
|
||||||
|
"AccountBalanceUpdated",
|
||||||
|
"MarginCallWarning",
|
||||||
|
# Models
|
||||||
|
"OrderIntent",
|
||||||
|
"OrderState",
|
||||||
|
"Position",
|
||||||
|
"Asset",
|
||||||
|
"AssetMetadata",
|
||||||
|
"AccountState",
|
||||||
|
"Balance",
|
||||||
|
# State management
|
||||||
|
"IntentStateStore",
|
||||||
|
"ActualStateStore",
|
||||||
|
]
|
||||||
361
backend/src/exchange_kernel/base.py
Normal file
361
backend/src/exchange_kernel/base.py
Normal file
@@ -0,0 +1,361 @@
|
|||||||
|
"""
|
||||||
|
Base interface for Exchange Kernels.
|
||||||
|
|
||||||
|
Defines the abstract API that all exchange kernel implementations must support.
|
||||||
|
Each exchange (or exchange type) will have its own kernel implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Callable, Any
|
||||||
|
|
||||||
|
from .models import (
|
||||||
|
OrderIntent,
|
||||||
|
OrderState,
|
||||||
|
Position,
|
||||||
|
AccountState,
|
||||||
|
AssetMetadata,
|
||||||
|
)
|
||||||
|
from .events import BaseEvent
|
||||||
|
from ..schema.order_spec import (
|
||||||
|
StandardOrder,
|
||||||
|
StandardOrderGroup,
|
||||||
|
SymbolMetadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExchangeKernel(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for exchange kernels.
|
||||||
|
|
||||||
|
An exchange kernel manages the lifecycle of orders on a specific exchange,
|
||||||
|
maintaining both desired state (intents from strategy kernel) and actual
|
||||||
|
state (current orders on exchange), and continuously reconciling them.
|
||||||
|
|
||||||
|
Think of it as a Kubernetes-style controller for trading orders.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, exchange_id: str):
|
||||||
|
"""
|
||||||
|
Initialize the exchange kernel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exchange_id: Unique identifier for this exchange instance
|
||||||
|
"""
|
||||||
|
self.exchange_id = exchange_id
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Command API - Strategy kernel sends intents
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def place_order(self, order: StandardOrder, metadata: dict[str, Any] | None = None) -> str:
|
||||||
|
"""
|
||||||
|
Place a single order on the exchange.
|
||||||
|
|
||||||
|
This creates an OrderIntent and begins the reconciliation process to
|
||||||
|
get the order onto the exchange.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
order: The order specification
|
||||||
|
metadata: Optional strategy-specific metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
intent_id: Unique identifier for this order intent
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If order violates market constraints
|
||||||
|
ExchangeError: If exchange rejects the order
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def place_order_group(
|
||||||
|
self,
|
||||||
|
group: StandardOrderGroup,
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
) -> list[str]:
|
||||||
|
"""
|
||||||
|
Place a group of orders with OCO (One-Cancels-Other) relationship.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group: Group of orders with OCO mode
|
||||||
|
metadata: Optional strategy-specific metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
intent_ids: List of intent IDs for each order in the group
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If any order violates market constraints
|
||||||
|
ExchangeError: If exchange rejects the group
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def cancel_order(self, intent_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Cancel an order by intent ID.
|
||||||
|
|
||||||
|
Updates the intent to indicate cancellation is desired, and the
|
||||||
|
reconciliation loop will handle the actual exchange cancellation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intent_id: Intent ID of the order to cancel
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If intent_id doesn't exist
|
||||||
|
ExchangeError: If exchange rejects cancellation
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def modify_order(
|
||||||
|
self,
|
||||||
|
intent_id: str,
|
||||||
|
new_order: StandardOrder,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Modify an existing order.
|
||||||
|
|
||||||
|
Updates the order intent, and the reconciliation loop will update
|
||||||
|
the exchange order (via modify API if available, or cancel+replace).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intent_id: Intent ID of the order to modify
|
||||||
|
new_order: New order specification
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If intent_id doesn't exist
|
||||||
|
ValidationError: If new order violates market constraints
|
||||||
|
ExchangeError: If exchange rejects modification
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def cancel_all_orders(self, symbol_id: str | None = None) -> int:
|
||||||
|
"""
|
||||||
|
Cancel all orders, optionally filtered by symbol.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol_id: If provided, only cancel orders for this symbol
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
count: Number of orders canceled
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Query API - Read desired and actual state
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_order_intent(self, intent_id: str) -> OrderIntent:
|
||||||
|
"""
|
||||||
|
Get the desired order state (what strategy kernel wants).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intent_id: Intent ID to query
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The order intent
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If intent_id doesn't exist
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_order_state(self, intent_id: str) -> OrderState:
|
||||||
|
"""
|
||||||
|
Get the actual order state (what's currently on exchange).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intent_id: Intent ID to query
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The current order state
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If intent_id doesn't exist
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_all_intents(self, symbol_id: str | None = None) -> list[OrderIntent]:
|
||||||
|
"""
|
||||||
|
Get all order intents, optionally filtered by symbol.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol_id: If provided, only return intents for this symbol
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of order intents
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_all_orders(self, symbol_id: str | None = None) -> list[OrderState]:
|
||||||
|
"""
|
||||||
|
Get all actual order states, optionally filtered by symbol.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol_id: If provided, only return orders for this symbol
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of order states
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_positions(self, symbol_id: str | None = None) -> list[Position]:
|
||||||
|
"""
|
||||||
|
Get current positions, optionally filtered by symbol.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol_id: If provided, only return positions for this symbol
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of positions
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_account_state(self) -> AccountState:
|
||||||
|
"""
|
||||||
|
Get current account state (balances, margin, etc.).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Current account state
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_symbol_metadata(self, symbol_id: str) -> SymbolMetadata:
|
||||||
|
"""
|
||||||
|
Get metadata for a symbol (constraints, capabilities, etc.).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol_id: Symbol to query
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Symbol metadata
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If symbol doesn't exist on this exchange
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_asset_metadata(self, asset_id: str) -> AssetMetadata:
|
||||||
|
"""
|
||||||
|
Get metadata for an asset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
asset_id: Asset to query
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Asset metadata
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If asset doesn't exist
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def list_symbols(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
List all available symbols on this exchange.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of symbol IDs
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Event Subscription API
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def subscribe_events(
|
||||||
|
self,
|
||||||
|
callback: Callable[[BaseEvent], None],
|
||||||
|
event_filter: dict[str, Any] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Subscribe to events from this exchange kernel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback: Function to call when events occur
|
||||||
|
event_filter: Optional filter criteria (event_type, symbol_id, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
subscription_id: Unique ID for this subscription (for unsubscribe)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def unsubscribe_events(self, subscription_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Unsubscribe from events.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subscription_id: Subscription ID returned from subscribe_events
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Lifecycle Management
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""
|
||||||
|
Start the exchange kernel.
|
||||||
|
|
||||||
|
Initializes connections, starts reconciliation loops, etc.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""
|
||||||
|
Stop the exchange kernel.
|
||||||
|
|
||||||
|
Closes connections, stops reconciliation loops, etc.
|
||||||
|
Does NOT cancel open orders - call cancel_all_orders() first if desired.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def health_check(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Check health status of the exchange kernel.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Health status dict with connection state, latency, error counts, etc.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Reconciliation Control (advanced)
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def force_reconciliation(self, intent_id: str | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Force immediate reconciliation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intent_id: If provided, only reconcile this specific intent.
|
||||||
|
If None, reconcile all intents.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_reconciliation_metrics(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get metrics about the reconciliation process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Metrics dict with reconciliation lag, error rates, retry counts, etc.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
250
backend/src/exchange_kernel/events.py
Normal file
250
backend/src/exchange_kernel/events.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
"""
|
||||||
|
Event definitions for the Exchange Kernel.
|
||||||
|
|
||||||
|
All events that can occur during the order lifecycle, position management,
|
||||||
|
and account updates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from ..schema.order_spec import Float, Uint64
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Base Event Classes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class EventType(StrEnum):
|
||||||
|
"""Types of events emitted by the exchange kernel"""
|
||||||
|
# Order lifecycle
|
||||||
|
ORDER_SUBMITTED = "ORDER_SUBMITTED"
|
||||||
|
ORDER_ACCEPTED = "ORDER_ACCEPTED"
|
||||||
|
ORDER_REJECTED = "ORDER_REJECTED"
|
||||||
|
ORDER_PARTIALLY_FILLED = "ORDER_PARTIALLY_FILLED"
|
||||||
|
ORDER_FILLED = "ORDER_FILLED"
|
||||||
|
ORDER_CANCELED = "ORDER_CANCELED"
|
||||||
|
ORDER_MODIFIED = "ORDER_MODIFIED"
|
||||||
|
ORDER_EXPIRED = "ORDER_EXPIRED"
|
||||||
|
|
||||||
|
# Position events
|
||||||
|
POSITION_OPENED = "POSITION_OPENED"
|
||||||
|
POSITION_MODIFIED = "POSITION_MODIFIED"
|
||||||
|
POSITION_CLOSED = "POSITION_CLOSED"
|
||||||
|
|
||||||
|
# Account events
|
||||||
|
ACCOUNT_BALANCE_UPDATED = "ACCOUNT_BALANCE_UPDATED"
|
||||||
|
MARGIN_CALL_WARNING = "MARGIN_CALL_WARNING"
|
||||||
|
|
||||||
|
# System events
|
||||||
|
RECONCILIATION_FAILED = "RECONCILIATION_FAILED"
|
||||||
|
CONNECTION_LOST = "CONNECTION_LOST"
|
||||||
|
CONNECTION_RESTORED = "CONNECTION_RESTORED"
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEvent(BaseModel):
|
||||||
|
"""Base class for all exchange kernel events"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
event_type: EventType = Field(description="Type of event")
|
||||||
|
timestamp: Uint64 = Field(description="Event timestamp (Unix milliseconds)")
|
||||||
|
exchange: str = Field(description="Exchange identifier")
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict, description="Additional event data")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Order Events
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class OrderEvent(BaseEvent):
|
||||||
|
"""Base class for order-related events"""
|
||||||
|
|
||||||
|
intent_id: str = Field(description="Order intent ID")
|
||||||
|
order_id: str | None = Field(default=None, description="Exchange order ID (if assigned)")
|
||||||
|
symbol_id: str = Field(description="Symbol being traded")
|
||||||
|
|
||||||
|
|
||||||
|
class OrderSubmitted(OrderEvent):
|
||||||
|
"""Order has been submitted to the exchange"""
|
||||||
|
|
||||||
|
event_type: EventType = Field(default=EventType.ORDER_SUBMITTED)
|
||||||
|
client_order_id: str | None = Field(default=None, description="Client-assigned order ID")
|
||||||
|
|
||||||
|
|
||||||
|
class OrderAccepted(OrderEvent):
|
||||||
|
"""Order has been accepted by the exchange"""
|
||||||
|
|
||||||
|
event_type: EventType = Field(default=EventType.ORDER_ACCEPTED)
|
||||||
|
order_id: str = Field(description="Exchange-assigned order ID")
|
||||||
|
accepted_at: Uint64 = Field(description="Exchange acceptance timestamp")
|
||||||
|
|
||||||
|
|
||||||
|
class OrderRejected(OrderEvent):
|
||||||
|
"""Order was rejected by the exchange"""
|
||||||
|
|
||||||
|
event_type: EventType = Field(default=EventType.ORDER_REJECTED)
|
||||||
|
reason: str = Field(description="Rejection reason")
|
||||||
|
error_code: str | None = Field(default=None, description="Exchange error code")
|
||||||
|
|
||||||
|
|
||||||
|
class OrderPartiallyFilled(OrderEvent):
|
||||||
|
"""Order was partially filled"""
|
||||||
|
|
||||||
|
event_type: EventType = Field(default=EventType.ORDER_PARTIALLY_FILLED)
|
||||||
|
order_id: str = Field(description="Exchange order ID")
|
||||||
|
fill_price: Float = Field(description="Fill price for this execution")
|
||||||
|
fill_quantity: Float = Field(description="Quantity filled in this execution")
|
||||||
|
total_filled: Float = Field(description="Total quantity filled so far")
|
||||||
|
remaining_quantity: Float = Field(description="Remaining quantity to fill")
|
||||||
|
commission: Float = Field(default=0.0, description="Commission/fee for this fill")
|
||||||
|
commission_asset: str | None = Field(default=None, description="Asset used for commission")
|
||||||
|
trade_id: str | None = Field(default=None, description="Exchange trade ID")
|
||||||
|
|
||||||
|
|
||||||
|
class OrderFilled(OrderEvent):
|
||||||
|
"""Order was completely filled"""
|
||||||
|
|
||||||
|
event_type: EventType = Field(default=EventType.ORDER_FILLED)
|
||||||
|
order_id: str = Field(description="Exchange order ID")
|
||||||
|
average_fill_price: Float = Field(description="Average execution price")
|
||||||
|
total_quantity: Float = Field(description="Total quantity filled")
|
||||||
|
total_commission: Float = Field(default=0.0, description="Total commission/fees")
|
||||||
|
commission_asset: str | None = Field(default=None, description="Asset used for commission")
|
||||||
|
completed_at: Uint64 = Field(description="Completion timestamp")
|
||||||
|
|
||||||
|
|
||||||
|
class OrderCanceled(OrderEvent):
|
||||||
|
"""Order was canceled"""
|
||||||
|
|
||||||
|
event_type: EventType = Field(default=EventType.ORDER_CANCELED)
|
||||||
|
order_id: str = Field(description="Exchange order ID")
|
||||||
|
reason: str = Field(description="Cancellation reason")
|
||||||
|
filled_quantity: Float = Field(default=0.0, description="Quantity filled before cancellation")
|
||||||
|
canceled_at: Uint64 = Field(description="Cancellation timestamp")
|
||||||
|
|
||||||
|
|
||||||
|
class OrderModified(OrderEvent):
|
||||||
|
"""Order was modified (price, quantity, etc.)"""
|
||||||
|
|
||||||
|
event_type: EventType = Field(default=EventType.ORDER_MODIFIED)
|
||||||
|
order_id: str = Field(description="Exchange order ID")
|
||||||
|
old_price: Float | None = Field(default=None, description="Previous price")
|
||||||
|
new_price: Float | None = Field(default=None, description="New price")
|
||||||
|
old_quantity: Float | None = Field(default=None, description="Previous quantity")
|
||||||
|
new_quantity: Float | None = Field(default=None, description="New quantity")
|
||||||
|
modified_at: Uint64 = Field(description="Modification timestamp")
|
||||||
|
|
||||||
|
|
||||||
|
class OrderExpired(OrderEvent):
|
||||||
|
"""Order expired (GTD, DAY orders)"""
|
||||||
|
|
||||||
|
event_type: EventType = Field(default=EventType.ORDER_EXPIRED)
|
||||||
|
order_id: str = Field(description="Exchange order ID")
|
||||||
|
filled_quantity: Float = Field(default=0.0, description="Quantity filled before expiration")
|
||||||
|
expired_at: Uint64 = Field(description="Expiration timestamp")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Position Events
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class PositionEvent(BaseEvent):
|
||||||
|
"""Base class for position-related events"""
|
||||||
|
|
||||||
|
position_id: str = Field(description="Position identifier")
|
||||||
|
symbol_id: str = Field(description="Symbol identifier")
|
||||||
|
asset_id: str = Field(description="Asset identifier")
|
||||||
|
|
||||||
|
|
||||||
|
class PositionOpened(PositionEvent):
|
||||||
|
"""New position was opened"""
|
||||||
|
|
||||||
|
event_type: EventType = Field(default=EventType.POSITION_OPENED)
|
||||||
|
quantity: Float = Field(description="Position quantity")
|
||||||
|
entry_price: Float = Field(description="Entry price")
|
||||||
|
side: str = Field(description="LONG or SHORT")
|
||||||
|
leverage: Float | None = Field(default=None, description="Leverage")
|
||||||
|
|
||||||
|
|
||||||
|
class PositionModified(PositionEvent):
|
||||||
|
"""Existing position was modified (size change, etc.)"""
|
||||||
|
|
||||||
|
event_type: EventType = Field(default=EventType.POSITION_MODIFIED)
|
||||||
|
old_quantity: Float = Field(description="Previous quantity")
|
||||||
|
new_quantity: Float = Field(description="New quantity")
|
||||||
|
average_entry_price: Float = Field(description="Updated average entry price")
|
||||||
|
unrealized_pnl: Float | None = Field(default=None, description="Current unrealized P&L")
|
||||||
|
|
||||||
|
|
||||||
|
class PositionClosed(PositionEvent):
|
||||||
|
"""Position was closed"""
|
||||||
|
|
||||||
|
event_type: EventType = Field(default=EventType.POSITION_CLOSED)
|
||||||
|
exit_price: Float = Field(description="Exit price")
|
||||||
|
realized_pnl: Float = Field(description="Realized profit/loss")
|
||||||
|
closed_at: Uint64 = Field(description="Closure timestamp")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Account Events
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class AccountEvent(BaseEvent):
|
||||||
|
"""Base class for account-related events"""
|
||||||
|
|
||||||
|
account_id: str = Field(description="Account identifier")
|
||||||
|
|
||||||
|
|
||||||
|
class AccountBalanceUpdated(AccountEvent):
|
||||||
|
"""Account balance was updated"""
|
||||||
|
|
||||||
|
event_type: EventType = Field(default=EventType.ACCOUNT_BALANCE_UPDATED)
|
||||||
|
asset_id: str = Field(description="Asset that changed")
|
||||||
|
old_balance: Float = Field(description="Previous balance")
|
||||||
|
new_balance: Float = Field(description="New balance")
|
||||||
|
old_available: Float = Field(description="Previous available")
|
||||||
|
new_available: Float = Field(description="New available")
|
||||||
|
change_reason: str = Field(description="Why balance changed (TRADE, DEPOSIT, WITHDRAWAL, etc.)")
|
||||||
|
|
||||||
|
|
||||||
|
class MarginCallWarning(AccountEvent):
|
||||||
|
"""Margin level is approaching liquidation threshold"""
|
||||||
|
|
||||||
|
event_type: EventType = Field(default=EventType.MARGIN_CALL_WARNING)
|
||||||
|
margin_level: Float = Field(description="Current margin level")
|
||||||
|
liquidation_threshold: Float = Field(description="Liquidation threshold")
|
||||||
|
required_action: str = Field(description="Required action to avoid liquidation")
|
||||||
|
estimated_liquidation_price: Float | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Estimated liquidation price for positions"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# System Events
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ReconciliationFailed(BaseEvent):
|
||||||
|
"""Failed to reconcile intent with actual state"""
|
||||||
|
|
||||||
|
event_type: EventType = Field(default=EventType.RECONCILIATION_FAILED)
|
||||||
|
intent_id: str = Field(description="Order intent ID")
|
||||||
|
error_message: str = Field(description="Error details")
|
||||||
|
retry_count: int = Field(description="Number of retry attempts")
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionLost(BaseEvent):
|
||||||
|
"""Connection to exchange was lost"""
|
||||||
|
|
||||||
|
event_type: EventType = Field(default=EventType.CONNECTION_LOST)
|
||||||
|
reason: str = Field(description="Disconnection reason")
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionRestored(BaseEvent):
|
||||||
|
"""Connection to exchange was restored"""
|
||||||
|
|
||||||
|
event_type: EventType = Field(default=EventType.CONNECTION_RESTORED)
|
||||||
|
downtime_duration: int = Field(description="Duration of downtime in milliseconds")
|
||||||
194
backend/src/exchange_kernel/models.py
Normal file
194
backend/src/exchange_kernel/models.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
"""
|
||||||
|
Data models for the Exchange Kernel.
|
||||||
|
|
||||||
|
Defines order intents, order state, positions, assets, and account state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from ..schema.order_spec import (
|
||||||
|
StandardOrder,
|
||||||
|
StandardOrderStatus,
|
||||||
|
AssetType,
|
||||||
|
Float,
|
||||||
|
Uint64,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Order Intent and State
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class OrderIntent(BaseModel):
|
||||||
|
"""
|
||||||
|
Desired order state from the strategy kernel.
|
||||||
|
|
||||||
|
This represents what the strategy wants, not what currently exists.
|
||||||
|
The exchange kernel will work to reconcile actual state with this intent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
intent_id: str = Field(description="Unique identifier for this intent (client-assigned)")
|
||||||
|
order: StandardOrder = Field(description="The desired order specification")
|
||||||
|
group_id: str | None = Field(default=None, description="Group ID for OCO relationships")
|
||||||
|
created_at: Uint64 = Field(description="When this intent was created")
|
||||||
|
updated_at: Uint64 = Field(description="When this intent was last modified")
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict, description="Strategy-specific metadata")
|
||||||
|
|
||||||
|
|
||||||
|
class ReconciliationStatus(StrEnum):
|
||||||
|
"""Status of reconciliation between intent and actual state"""
|
||||||
|
PENDING = "PENDING" # Not yet submitted to exchange
|
||||||
|
SUBMITTING = "SUBMITTING" # Currently being submitted
|
||||||
|
ACTIVE = "ACTIVE" # Successfully placed on exchange
|
||||||
|
RECONCILING = "RECONCILING" # Intent changed, updating exchange order
|
||||||
|
FAILED = "FAILED" # Failed to submit or reconcile
|
||||||
|
COMPLETED = "COMPLETED" # Order fully filled
|
||||||
|
CANCELED = "CANCELED" # Order canceled
|
||||||
|
|
||||||
|
|
||||||
|
class OrderState(BaseModel):
|
||||||
|
"""
|
||||||
|
Actual current state of an order on the exchange.
|
||||||
|
|
||||||
|
This represents reality - what the exchange reports about the order.
|
||||||
|
May differ from OrderIntent during reconciliation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
intent_id: str = Field(description="Links back to the OrderIntent")
|
||||||
|
exchange_order_id: str = Field(description="Exchange-assigned order ID")
|
||||||
|
status: StandardOrderStatus = Field(description="Current order status from exchange")
|
||||||
|
reconciliation_status: ReconciliationStatus = Field(description="Reconciliation state")
|
||||||
|
last_sync_at: Uint64 = Field(description="Last time we synced with exchange")
|
||||||
|
error_message: str | None = Field(default=None, description="Error details if FAILED")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Position and Asset Models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class AssetMetadata(BaseModel):
|
||||||
|
"""
|
||||||
|
Metadata describing an asset type.
|
||||||
|
|
||||||
|
Provides context for positions, balances, and trading.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
asset_id: str = Field(description="Unique asset identifier")
|
||||||
|
symbol: str = Field(description="Asset symbol (e.g., 'BTC', 'ETH', 'USD')")
|
||||||
|
asset_type: AssetType = Field(description="Type of asset")
|
||||||
|
name: str = Field(description="Full name")
|
||||||
|
|
||||||
|
# Contract specifications (for derivatives)
|
||||||
|
contract_size: Float | None = Field(default=None, description="Contract multiplier")
|
||||||
|
settlement_asset: str | None = Field(default=None, description="Settlement currency")
|
||||||
|
expiry_timestamp: Uint64 | None = Field(default=None, description="Expiration timestamp")
|
||||||
|
|
||||||
|
# Trading parameters
|
||||||
|
tick_size: Float | None = Field(default=None, description="Minimum price increment")
|
||||||
|
lot_size: Float | None = Field(default=None, description="Minimum quantity increment")
|
||||||
|
|
||||||
|
# Margin requirements (for leveraged products)
|
||||||
|
initial_margin_rate: Float | None = Field(default=None, description="Initial margin requirement")
|
||||||
|
maintenance_margin_rate: Float | None = Field(default=None, description="Maintenance margin requirement")
|
||||||
|
|
||||||
|
# Additional metadata
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict, description="Exchange-specific metadata")
|
||||||
|
|
||||||
|
|
||||||
|
class Asset(BaseModel):
|
||||||
|
"""
|
||||||
|
An asset holding (spot, margin, derivative position, etc.)
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
asset_id: str = Field(description="References AssetMetadata")
|
||||||
|
quantity: Float = Field(description="Amount held (positive or negative for short positions)")
|
||||||
|
available: Float = Field(description="Amount available for trading (not locked in orders)")
|
||||||
|
locked: Float = Field(description="Amount locked in open orders")
|
||||||
|
|
||||||
|
# For derivative positions
|
||||||
|
entry_price: Float | None = Field(default=None, description="Average entry price")
|
||||||
|
mark_price: Float | None = Field(default=None, description="Current mark price")
|
||||||
|
liquidation_price: Float | None = Field(default=None, description="Estimated liquidation price")
|
||||||
|
unrealized_pnl: Float | None = Field(default=None, description="Unrealized profit/loss")
|
||||||
|
realized_pnl: Float | None = Field(default=None, description="Realized profit/loss")
|
||||||
|
|
||||||
|
# Margin info
|
||||||
|
margin_used: Float | None = Field(default=None, description="Margin allocated to this position")
|
||||||
|
|
||||||
|
updated_at: Uint64 = Field(description="Last update timestamp")
|
||||||
|
|
||||||
|
|
||||||
|
class Position(BaseModel):
|
||||||
|
"""
|
||||||
|
A trading position (spot, margin, perpetual, futures, etc.)
|
||||||
|
|
||||||
|
Tracks both the asset holdings and associated metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
position_id: str = Field(description="Unique position identifier")
|
||||||
|
symbol_id: str = Field(description="Trading symbol")
|
||||||
|
asset: Asset = Field(description="Asset holding details")
|
||||||
|
metadata: AssetMetadata = Field(description="Asset metadata")
|
||||||
|
|
||||||
|
# Position-level info
|
||||||
|
leverage: Float | None = Field(default=None, description="Current leverage")
|
||||||
|
side: str | None = Field(default=None, description="LONG or SHORT (for derivatives)")
|
||||||
|
|
||||||
|
updated_at: Uint64 = Field(description="Last update timestamp")
|
||||||
|
|
||||||
|
|
||||||
|
class Balance(BaseModel):
|
||||||
|
"""Account balance for a single currency/asset"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
asset_id: str = Field(description="Asset identifier")
|
||||||
|
total: Float = Field(description="Total balance")
|
||||||
|
available: Float = Field(description="Available for trading")
|
||||||
|
locked: Float = Field(description="Locked in orders/positions")
|
||||||
|
|
||||||
|
# For margin accounts
|
||||||
|
borrowed: Float = Field(default=0.0, description="Borrowed amount (margin)")
|
||||||
|
interest: Float = Field(default=0.0, description="Accrued interest")
|
||||||
|
|
||||||
|
updated_at: Uint64 = Field(description="Last update timestamp")
|
||||||
|
|
||||||
|
|
||||||
|
class AccountState(BaseModel):
|
||||||
|
"""
|
||||||
|
Complete account state including balances, positions, and margin info.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
account_id: str = Field(description="Account identifier")
|
||||||
|
exchange: str = Field(description="Exchange identifier")
|
||||||
|
|
||||||
|
balances: list[Balance] = Field(default_factory=list, description="All asset balances")
|
||||||
|
positions: list[Position] = Field(default_factory=list, description="All open positions")
|
||||||
|
|
||||||
|
# Margin account info
|
||||||
|
total_equity: Float | None = Field(default=None, description="Total account equity")
|
||||||
|
total_margin_used: Float | None = Field(default=None, description="Total margin in use")
|
||||||
|
total_available_margin: Float | None = Field(default=None, description="Available margin")
|
||||||
|
margin_level: Float | None = Field(default=None, description="Margin level (equity/margin_used)")
|
||||||
|
|
||||||
|
# Risk metrics
|
||||||
|
total_unrealized_pnl: Float | None = Field(default=None, description="Total unrealized P&L")
|
||||||
|
total_realized_pnl: Float | None = Field(default=None, description="Total realized P&L")
|
||||||
|
|
||||||
|
updated_at: Uint64 = Field(description="Last update timestamp")
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict, description="Exchange-specific data")
|
||||||
472
backend/src/exchange_kernel/state.py
Normal file
472
backend/src/exchange_kernel/state.py
Normal file
@@ -0,0 +1,472 @@
|
|||||||
|
"""
|
||||||
|
State management for the Exchange Kernel.
|
||||||
|
|
||||||
|
Implements the storage and reconciliation logic for desired vs actual state.
|
||||||
|
This is the "Kubernetes for orders" concept - maintaining intent and continuously
|
||||||
|
reconciling reality to match intent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from .models import OrderIntent, OrderState, ReconciliationStatus
|
||||||
|
from ..schema.order_spec import Uint64
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Intent State Store - Desired State
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class IntentStateStore(ABC):
|
||||||
|
"""
|
||||||
|
Storage for order intents (desired state).
|
||||||
|
|
||||||
|
This represents what the strategy kernel wants. Intents are durable and
|
||||||
|
persist across restarts. The reconciliation loop continuously works to
|
||||||
|
make actual state match these intents.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def create_intent(self, intent: OrderIntent) -> None:
|
||||||
|
"""
|
||||||
|
Store a new order intent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intent: The order intent to store
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AlreadyExistsError: If intent_id already exists
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_intent(self, intent_id: str) -> OrderIntent:
|
||||||
|
"""
|
||||||
|
Retrieve an order intent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intent_id: Intent ID to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The order intent
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If intent_id doesn't exist
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def update_intent(self, intent: OrderIntent) -> None:
|
||||||
|
"""
|
||||||
|
Update an existing order intent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intent: Updated intent (intent_id must match existing)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If intent_id doesn't exist
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def delete_intent(self, intent_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Delete an order intent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intent_id: Intent ID to delete
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If intent_id doesn't exist
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def list_intents(
|
||||||
|
self,
|
||||||
|
symbol_id: str | None = None,
|
||||||
|
group_id: str | None = None,
|
||||||
|
) -> list[OrderIntent]:
|
||||||
|
"""
|
||||||
|
List all order intents, optionally filtered.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol_id: Filter by symbol
|
||||||
|
group_id: Filter by OCO group
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching intents
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_intents_by_group(self, group_id: str) -> list[OrderIntent]:
|
||||||
|
"""
|
||||||
|
Get all intents in an OCO group.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_id: Group ID to query
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of intents in the group
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Actual State Store - Current Reality
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ActualStateStore(ABC):
|
||||||
|
"""
|
||||||
|
Storage for actual order state (reality on exchange).
|
||||||
|
|
||||||
|
This represents what actually exists on the exchange right now.
|
||||||
|
Updated frequently from exchange feeds and order status queries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def create_order_state(self, state: OrderState) -> None:
|
||||||
|
"""
|
||||||
|
Store a new order state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The order state to store
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AlreadyExistsError: If order state for this intent_id already exists
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_order_state(self, intent_id: str) -> OrderState:
|
||||||
|
"""
|
||||||
|
Retrieve order state for an intent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intent_id: Intent ID to query
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The current order state
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If no state exists for this intent
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_order_state_by_exchange_id(self, exchange_order_id: str) -> OrderState:
|
||||||
|
"""
|
||||||
|
Retrieve order state by exchange order ID.
|
||||||
|
|
||||||
|
Useful for processing exchange callbacks that only provide exchange_order_id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exchange_order_id: Exchange's order ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The order state
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If no state exists for this exchange order ID
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def update_order_state(self, state: OrderState) -> None:
|
||||||
|
"""
|
||||||
|
Update an existing order state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Updated state (intent_id must match existing)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If state doesn't exist
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def delete_order_state(self, intent_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Delete an order state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intent_id: Intent ID whose state to delete
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If state doesn't exist
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def list_order_states(
|
||||||
|
self,
|
||||||
|
symbol_id: str | None = None,
|
||||||
|
reconciliation_status: ReconciliationStatus | None = None,
|
||||||
|
) -> list[OrderState]:
|
||||||
|
"""
|
||||||
|
List all order states, optionally filtered.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol_id: Filter by symbol
|
||||||
|
reconciliation_status: Filter by reconciliation status
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching order states
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_stale_orders(self, max_age_seconds: int) -> list[OrderState]:
|
||||||
|
"""
|
||||||
|
Find orders that haven't been synced recently.
|
||||||
|
|
||||||
|
Used to identify orders that need status updates from exchange.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_age_seconds: Maximum age since last sync
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of order states that need refresh
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# In-Memory Implementations (for testing/prototyping)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class InMemoryIntentStore(IntentStateStore):
|
||||||
|
"""Simple in-memory implementation of IntentStateStore"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._intents: dict[str, OrderIntent] = {}
|
||||||
|
self._by_symbol: dict[str, set[str]] = defaultdict(set)
|
||||||
|
self._by_group: dict[str, set[str]] = defaultdict(set)
|
||||||
|
|
||||||
|
async def create_intent(self, intent: OrderIntent) -> None:
|
||||||
|
if intent.intent_id in self._intents:
|
||||||
|
raise ValueError(f"Intent {intent.intent_id} already exists")
|
||||||
|
self._intents[intent.intent_id] = intent
|
||||||
|
self._by_symbol[intent.order.symbol_id].add(intent.intent_id)
|
||||||
|
if intent.group_id:
|
||||||
|
self._by_group[intent.group_id].add(intent.intent_id)
|
||||||
|
|
||||||
|
async def get_intent(self, intent_id: str) -> OrderIntent:
|
||||||
|
if intent_id not in self._intents:
|
||||||
|
raise KeyError(f"Intent {intent_id} not found")
|
||||||
|
return self._intents[intent_id]
|
||||||
|
|
||||||
|
async def update_intent(self, intent: OrderIntent) -> None:
|
||||||
|
if intent.intent_id not in self._intents:
|
||||||
|
raise KeyError(f"Intent {intent.intent_id} not found")
|
||||||
|
old_intent = self._intents[intent.intent_id]
|
||||||
|
|
||||||
|
# Update indices if symbol or group changed
|
||||||
|
if old_intent.order.symbol_id != intent.order.symbol_id:
|
||||||
|
self._by_symbol[old_intent.order.symbol_id].discard(intent.intent_id)
|
||||||
|
self._by_symbol[intent.order.symbol_id].add(intent.intent_id)
|
||||||
|
|
||||||
|
if old_intent.group_id != intent.group_id:
|
||||||
|
if old_intent.group_id:
|
||||||
|
self._by_group[old_intent.group_id].discard(intent.intent_id)
|
||||||
|
if intent.group_id:
|
||||||
|
self._by_group[intent.group_id].add(intent.intent_id)
|
||||||
|
|
||||||
|
self._intents[intent.intent_id] = intent
|
||||||
|
|
||||||
|
async def delete_intent(self, intent_id: str) -> None:
|
||||||
|
if intent_id not in self._intents:
|
||||||
|
raise KeyError(f"Intent {intent_id} not found")
|
||||||
|
intent = self._intents[intent_id]
|
||||||
|
self._by_symbol[intent.order.symbol_id].discard(intent_id)
|
||||||
|
if intent.group_id:
|
||||||
|
self._by_group[intent.group_id].discard(intent_id)
|
||||||
|
del self._intents[intent_id]
|
||||||
|
|
||||||
|
async def list_intents(
|
||||||
|
self,
|
||||||
|
symbol_id: str | None = None,
|
||||||
|
group_id: str | None = None,
|
||||||
|
) -> list[OrderIntent]:
|
||||||
|
if symbol_id and group_id:
|
||||||
|
# Intersection of both filters
|
||||||
|
symbol_ids = self._by_symbol.get(symbol_id, set())
|
||||||
|
group_ids = self._by_group.get(group_id, set())
|
||||||
|
intent_ids = symbol_ids & group_ids
|
||||||
|
elif symbol_id:
|
||||||
|
intent_ids = self._by_symbol.get(symbol_id, set())
|
||||||
|
elif group_id:
|
||||||
|
intent_ids = self._by_group.get(group_id, set())
|
||||||
|
else:
|
||||||
|
intent_ids = self._intents.keys()
|
||||||
|
|
||||||
|
return [self._intents[iid] for iid in intent_ids]
|
||||||
|
|
||||||
|
async def get_intents_by_group(self, group_id: str) -> list[OrderIntent]:
|
||||||
|
intent_ids = self._by_group.get(group_id, set())
|
||||||
|
return [self._intents[iid] for iid in intent_ids]
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryActualStateStore(ActualStateStore):
|
||||||
|
"""Simple in-memory implementation of ActualStateStore"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._states: dict[str, OrderState] = {}
|
||||||
|
self._by_exchange_id: dict[str, str] = {} # exchange_order_id -> intent_id
|
||||||
|
self._by_symbol: dict[str, set[str]] = defaultdict(set)
|
||||||
|
|
||||||
|
async def create_order_state(self, state: OrderState) -> None:
|
||||||
|
if state.intent_id in self._states:
|
||||||
|
raise ValueError(f"Order state for intent {state.intent_id} already exists")
|
||||||
|
self._states[state.intent_id] = state
|
||||||
|
self._by_exchange_id[state.exchange_order_id] = state.intent_id
|
||||||
|
self._by_symbol[state.status.order.symbol_id].add(state.intent_id)
|
||||||
|
|
||||||
|
async def get_order_state(self, intent_id: str) -> OrderState:
|
||||||
|
if intent_id not in self._states:
|
||||||
|
raise KeyError(f"Order state for intent {intent_id} not found")
|
||||||
|
return self._states[intent_id]
|
||||||
|
|
||||||
|
async def get_order_state_by_exchange_id(self, exchange_order_id: str) -> OrderState:
|
||||||
|
if exchange_order_id not in self._by_exchange_id:
|
||||||
|
raise KeyError(f"Order state for exchange order {exchange_order_id} not found")
|
||||||
|
intent_id = self._by_exchange_id[exchange_order_id]
|
||||||
|
return self._states[intent_id]
|
||||||
|
|
||||||
|
async def update_order_state(self, state: OrderState) -> None:
|
||||||
|
if state.intent_id not in self._states:
|
||||||
|
raise KeyError(f"Order state for intent {state.intent_id} not found")
|
||||||
|
old_state = self._states[state.intent_id]
|
||||||
|
|
||||||
|
# Update exchange_id index if it changed
|
||||||
|
if old_state.exchange_order_id != state.exchange_order_id:
|
||||||
|
del self._by_exchange_id[old_state.exchange_order_id]
|
||||||
|
self._by_exchange_id[state.exchange_order_id] = state.intent_id
|
||||||
|
|
||||||
|
# Update symbol index if it changed
|
||||||
|
old_symbol = old_state.status.order.symbol_id
|
||||||
|
new_symbol = state.status.order.symbol_id
|
||||||
|
if old_symbol != new_symbol:
|
||||||
|
self._by_symbol[old_symbol].discard(state.intent_id)
|
||||||
|
self._by_symbol[new_symbol].add(state.intent_id)
|
||||||
|
|
||||||
|
self._states[state.intent_id] = state
|
||||||
|
|
||||||
|
async def delete_order_state(self, intent_id: str) -> None:
|
||||||
|
if intent_id not in self._states:
|
||||||
|
raise KeyError(f"Order state for intent {intent_id} not found")
|
||||||
|
state = self._states[intent_id]
|
||||||
|
del self._by_exchange_id[state.exchange_order_id]
|
||||||
|
self._by_symbol[state.status.order.symbol_id].discard(intent_id)
|
||||||
|
del self._states[intent_id]
|
||||||
|
|
||||||
|
async def list_order_states(
|
||||||
|
self,
|
||||||
|
symbol_id: str | None = None,
|
||||||
|
reconciliation_status: ReconciliationStatus | None = None,
|
||||||
|
) -> list[OrderState]:
|
||||||
|
if symbol_id:
|
||||||
|
intent_ids = self._by_symbol.get(symbol_id, set())
|
||||||
|
states = [self._states[iid] for iid in intent_ids]
|
||||||
|
else:
|
||||||
|
states = list(self._states.values())
|
||||||
|
|
||||||
|
if reconciliation_status:
|
||||||
|
states = [s for s in states if s.reconciliation_status == reconciliation_status]
|
||||||
|
|
||||||
|
return states
|
||||||
|
|
||||||
|
async def get_stale_orders(self, max_age_seconds: int) -> list[OrderState]:
|
||||||
|
import time
|
||||||
|
current_time = int(time.time())
|
||||||
|
threshold = current_time - max_age_seconds
|
||||||
|
|
||||||
|
return [
|
||||||
|
state
|
||||||
|
for state in self._states.values()
|
||||||
|
if state.last_sync_at < threshold
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Reconciliation Engine (framework only, no implementation)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ReconciliationEngine:
|
||||||
|
"""
|
||||||
|
Reconciliation engine that continuously works to make actual state match intent.
|
||||||
|
|
||||||
|
This is the heart of the "Kubernetes for orders" concept. It:
|
||||||
|
1. Compares desired state (intents) with actual state (exchange orders)
|
||||||
|
2. Computes necessary actions (place, modify, cancel)
|
||||||
|
3. Executes those actions via the exchange API
|
||||||
|
4. Handles retries, errors, and edge cases
|
||||||
|
|
||||||
|
This is a framework class - concrete implementations will be exchange-specific.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
intent_store: IntentStateStore,
|
||||||
|
actual_store: ActualStateStore,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the reconciliation engine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intent_store: Store for desired state
|
||||||
|
actual_store: Store for actual state
|
||||||
|
"""
|
||||||
|
self.intent_store = intent_store
|
||||||
|
self.actual_store = actual_store
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""Start the reconciliation loop"""
|
||||||
|
self._running = True
|
||||||
|
# Implementation would start async reconciliation loop here
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""Stop the reconciliation loop"""
|
||||||
|
self._running = False
|
||||||
|
# Implementation would stop reconciliation loop here
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def reconcile_intent(self, intent_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Reconcile a specific intent.
|
||||||
|
|
||||||
|
Compares the intent with actual state and takes necessary actions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intent_id: Intent to reconcile
|
||||||
|
"""
|
||||||
|
# Framework only - concrete implementation needed
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def reconcile_all(self) -> None:
|
||||||
|
"""
|
||||||
|
Reconcile all intents.
|
||||||
|
|
||||||
|
Full reconciliation pass over all orders.
|
||||||
|
"""
|
||||||
|
# Framework only - concrete implementation needed
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_metrics(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get reconciliation metrics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Metrics about reconciliation performance, errors, etc.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"running": self._running,
|
||||||
|
"reconciliation_lag_ms": 0, # Framework only
|
||||||
|
"pending_reconciliations": 0, # Framework only
|
||||||
|
"error_count": 0, # Framework only
|
||||||
|
"retry_count": 0, # Framework only
|
||||||
|
}
|
||||||
@@ -94,6 +94,11 @@ class Gateway:
|
|||||||
logger.info(f"Session is busy, interrupting existing task")
|
logger.info(f"Session is busy, interrupting existing task")
|
||||||
await session.interrupt()
|
await session.interrupt()
|
||||||
|
|
||||||
|
# Check if this is a stop interrupt (empty message)
|
||||||
|
if not message.content.strip() and not message.attachments:
|
||||||
|
logger.info("Received stop interrupt (empty message), not starting new agent round")
|
||||||
|
return
|
||||||
|
|
||||||
# Add user message to history
|
# Add user message to history
|
||||||
session.add_message("user", message.content, message.channel_id)
|
session.add_message("user", message.content, message.channel_id)
|
||||||
logger.info(f"User message added to history, history length: {len(session.get_history())}")
|
logger.info(f"User message added to history, history length: {len(session.get_history())}")
|
||||||
@@ -134,33 +139,55 @@ class Gateway:
|
|||||||
# Stream chunks back to active channels
|
# Stream chunks back to active channels
|
||||||
full_response = ""
|
full_response = ""
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
async for chunk in response_stream:
|
accumulated_metadata = {}
|
||||||
chunk_count += 1
|
|
||||||
full_response += chunk
|
|
||||||
logger.debug(f"Received chunk #{chunk_count}, length: {len(chunk)}")
|
|
||||||
|
|
||||||
# Send chunk to all active channels
|
async for chunk in response_stream:
|
||||||
agent_msg = AgentMessage(
|
# Handle dict response with metadata (from agent executor)
|
||||||
session_id=session.session_id,
|
if isinstance(chunk, dict):
|
||||||
target_channels=session.active_channels,
|
content = chunk.get("content", "")
|
||||||
content=chunk,
|
metadata = chunk.get("metadata", {})
|
||||||
stream_chunk=True,
|
# Accumulate metadata (e.g., plot_urls)
|
||||||
done=False
|
for key, value in metadata.items():
|
||||||
)
|
if key == "plot_urls" and value:
|
||||||
await self._send_to_channels(agent_msg)
|
# Append to existing plot_urls
|
||||||
|
if "plot_urls" not in accumulated_metadata:
|
||||||
|
accumulated_metadata["plot_urls"] = []
|
||||||
|
accumulated_metadata["plot_urls"].extend(value)
|
||||||
|
logger.info(f"Accumulated plot_urls: {accumulated_metadata['plot_urls']}")
|
||||||
|
else:
|
||||||
|
accumulated_metadata[key] = value
|
||||||
|
chunk = content
|
||||||
|
|
||||||
|
# Only send non-empty chunks
|
||||||
|
if chunk:
|
||||||
|
chunk_count += 1
|
||||||
|
full_response += chunk
|
||||||
|
logger.debug(f"Received chunk #{chunk_count}, length: {len(chunk)}")
|
||||||
|
|
||||||
|
# Send chunk to all active channels with accumulated metadata
|
||||||
|
agent_msg = AgentMessage(
|
||||||
|
session_id=session.session_id,
|
||||||
|
target_channels=session.active_channels,
|
||||||
|
content=chunk,
|
||||||
|
stream_chunk=True,
|
||||||
|
done=False,
|
||||||
|
metadata=accumulated_metadata.copy()
|
||||||
|
)
|
||||||
|
await self._send_to_channels(agent_msg)
|
||||||
|
|
||||||
logger.info(f"Agent streaming completed, total chunks: {chunk_count}, response length: {len(full_response)}")
|
logger.info(f"Agent streaming completed, total chunks: {chunk_count}, response length: {len(full_response)}")
|
||||||
|
|
||||||
# Send final done message
|
# Send final done message with all accumulated metadata
|
||||||
agent_msg = AgentMessage(
|
agent_msg = AgentMessage(
|
||||||
session_id=session.session_id,
|
session_id=session.session_id,
|
||||||
target_channels=session.active_channels,
|
target_channels=session.active_channels,
|
||||||
content="",
|
content="",
|
||||||
stream_chunk=True,
|
stream_chunk=True,
|
||||||
done=True
|
done=True,
|
||||||
|
metadata=accumulated_metadata
|
||||||
)
|
)
|
||||||
await self._send_to_channels(agent_msg)
|
await self._send_to_channels(agent_msg)
|
||||||
logger.info("Sent final done message to channels")
|
logger.info(f"Sent final done message to channels with metadata: {accumulated_metadata}")
|
||||||
|
|
||||||
# Add to history
|
# Add to history
|
||||||
session.add_message("assistant", full_response)
|
session.add_message("assistant", full_response)
|
||||||
|
|||||||
172
backend/src/indicator/__init__.py
Normal file
172
backend/src/indicator/__init__.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
"""
|
||||||
|
Composable Indicator System.
|
||||||
|
|
||||||
|
Provides a framework for building DAGs of data transformation pipelines
|
||||||
|
that process time-series data incrementally. Indicators can consume
|
||||||
|
DataSources or other Indicators as inputs, composing into arbitrarily
|
||||||
|
complex processing graphs.
|
||||||
|
|
||||||
|
Key Components:
|
||||||
|
---------------
|
||||||
|
|
||||||
|
Indicator (base.py):
|
||||||
|
Abstract base class for all indicator implementations.
|
||||||
|
Declares input/output schemas and implements synchronous compute().
|
||||||
|
|
||||||
|
IndicatorRegistry (registry.py):
|
||||||
|
Central catalog of available indicators with rich metadata
|
||||||
|
for AI agent discovery and tool generation.
|
||||||
|
|
||||||
|
Pipeline (pipeline.py):
|
||||||
|
Execution engine that builds DAGs, resolves dependencies,
|
||||||
|
and orchestrates incremental data flow through indicator chains.
|
||||||
|
|
||||||
|
Schema Types (schema.py):
|
||||||
|
Type definitions for input/output schemas, computation context,
|
||||||
|
and metadata for AI-native documentation.
|
||||||
|
|
||||||
|
Usage Example:
|
||||||
|
--------------
|
||||||
|
|
||||||
|
from indicator import Indicator, IndicatorRegistry, Pipeline
|
||||||
|
from indicator.schema import (
|
||||||
|
InputSchema, OutputSchema, ComputeContext, ComputeResult,
|
||||||
|
IndicatorMetadata, IndicatorParameter
|
||||||
|
)
|
||||||
|
|
||||||
|
# Define an indicator
|
||||||
|
class SimpleMovingAverage(Indicator):
|
||||||
|
@classmethod
|
||||||
|
def get_metadata(cls):
|
||||||
|
return IndicatorMetadata(
|
||||||
|
name="SMA",
|
||||||
|
display_name="Simple Moving Average",
|
||||||
|
description="Arithmetic mean of prices over N periods",
|
||||||
|
category="trend",
|
||||||
|
parameters=[
|
||||||
|
IndicatorParameter(
|
||||||
|
name="period",
|
||||||
|
type="int",
|
||||||
|
description="Number of periods to average",
|
||||||
|
default=20,
|
||||||
|
min_value=1
|
||||||
|
)
|
||||||
|
],
|
||||||
|
tags=["moving-average", "trend-following"]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_input_schema(cls):
|
||||||
|
return InputSchema(
|
||||||
|
required_columns=[
|
||||||
|
ColumnInfo(name="close", type="float", description="Closing price")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_output_schema(cls, **params):
|
||||||
|
return OutputSchema(
|
||||||
|
columns=[
|
||||||
|
ColumnInfo(
|
||||||
|
name="sma",
|
||||||
|
type="float",
|
||||||
|
description=f"Simple moving average over {params.get('period', 20)} periods"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||||
|
period = self.params["period"]
|
||||||
|
closes = context.get_column("close")
|
||||||
|
times = context.get_times()
|
||||||
|
|
||||||
|
sma_values = []
|
||||||
|
for i in range(len(closes)):
|
||||||
|
if i < period - 1:
|
||||||
|
sma_values.append(None)
|
||||||
|
else:
|
||||||
|
window = closes[i - period + 1 : i + 1]
|
||||||
|
sma_values.append(sum(window) / period)
|
||||||
|
|
||||||
|
return ComputeResult(
|
||||||
|
data=[
|
||||||
|
{"time": times[i], "sma": sma_values[i]}
|
||||||
|
for i in range(len(times))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register the indicator
|
||||||
|
registry = IndicatorRegistry()
|
||||||
|
registry.register(SimpleMovingAverage)
|
||||||
|
|
||||||
|
# Create a pipeline
|
||||||
|
pipeline = Pipeline(datasource_registry)
|
||||||
|
pipeline.add_datasource("price_data", "ccxt", "BTC/USD", "1D")
|
||||||
|
|
||||||
|
sma_indicator = registry.create_instance("SMA", "sma_20", period=20)
|
||||||
|
pipeline.add_indicator("sma_20", sma_indicator, input_node_ids=["price_data"])
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
results = pipeline.execute(datasource_data={"price_data": price_bars})
|
||||||
|
sma_output = results["sma_20"] # Contains columns: time, close, sma_20_sma
|
||||||
|
|
||||||
|
Design Philosophy:
|
||||||
|
------------------
|
||||||
|
|
||||||
|
1. **Schema-based composition**: Indicators declare inputs/outputs via schemas,
|
||||||
|
enabling automatic validation and flexible composition.
|
||||||
|
|
||||||
|
2. **Synchronous execution**: All computation is synchronous for simplicity.
|
||||||
|
Async handling happens at the event/strategy layer.
|
||||||
|
|
||||||
|
3. **Incremental updates**: Indicators receive context about what changed,
|
||||||
|
allowing optimized recomputation of only affected values.
|
||||||
|
|
||||||
|
4. **AI-native metadata**: Rich descriptions, use cases, and parameter specs
|
||||||
|
make indicators discoverable and usable by AI agents.
|
||||||
|
|
||||||
|
5. **Generic data flow**: Indicators work with any data source that matches
|
||||||
|
their input schema, not specific DataSource instances.
|
||||||
|
|
||||||
|
6. **Event-driven**: Designed to react to DataSource updates and propagate
|
||||||
|
changes through the DAG efficiently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import DataSourceAdapter, Indicator
|
||||||
|
from .pipeline import Pipeline, PipelineNode
|
||||||
|
from .registry import IndicatorRegistry
|
||||||
|
from .schema import (
|
||||||
|
ComputeContext,
|
||||||
|
ComputeResult,
|
||||||
|
IndicatorMetadata,
|
||||||
|
IndicatorParameter,
|
||||||
|
InputSchema,
|
||||||
|
OutputSchema,
|
||||||
|
)
|
||||||
|
from .talib_adapter import (
|
||||||
|
TALibIndicator,
|
||||||
|
register_all_talib_indicators,
|
||||||
|
is_talib_available,
|
||||||
|
get_talib_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Core classes
|
||||||
|
"Indicator",
|
||||||
|
"IndicatorRegistry",
|
||||||
|
"Pipeline",
|
||||||
|
"PipelineNode",
|
||||||
|
"DataSourceAdapter",
|
||||||
|
# Schema types
|
||||||
|
"InputSchema",
|
||||||
|
"OutputSchema",
|
||||||
|
"ComputeContext",
|
||||||
|
"ComputeResult",
|
||||||
|
"IndicatorMetadata",
|
||||||
|
"IndicatorParameter",
|
||||||
|
# TA-Lib integration
|
||||||
|
"TALibIndicator",
|
||||||
|
"register_all_talib_indicators",
|
||||||
|
"is_talib_available",
|
||||||
|
"get_talib_version",
|
||||||
|
]
|
||||||
230
backend/src/indicator/base.py
Normal file
230
backend/src/indicator/base.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
"""
|
||||||
|
Abstract Indicator interface.
|
||||||
|
|
||||||
|
Provides the base class for all technical indicators and derived data transformations.
|
||||||
|
Indicators compose into DAGs, processing data incrementally as updates arrive.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from .schema import (
|
||||||
|
ComputeContext,
|
||||||
|
ComputeResult,
|
||||||
|
IndicatorMetadata,
|
||||||
|
InputSchema,
|
||||||
|
OutputSchema,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Indicator(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for all indicators.
|
||||||
|
|
||||||
|
Indicators are composable transformation nodes that:
|
||||||
|
- Declare input schema (columns they need)
|
||||||
|
- Declare output schema (columns they produce)
|
||||||
|
- Compute outputs synchronously from inputs
|
||||||
|
- Support incremental updates (process only what changed)
|
||||||
|
- Provide rich metadata for AI agent discovery
|
||||||
|
|
||||||
|
Indicators are stateless at the instance level - all state is managed
|
||||||
|
by the pipeline execution engine. This allows the same indicator class
|
||||||
|
to be reused with different parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, instance_name: str, **params):
|
||||||
|
"""
|
||||||
|
Initialize an indicator instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance_name: Unique name for this instance (used for output column prefixing)
|
||||||
|
**params: Configuration parameters (validated against metadata.parameters)
|
||||||
|
"""
|
||||||
|
self.instance_name = instance_name
|
||||||
|
self.params = params
|
||||||
|
self._validate_params()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def get_metadata(cls) -> IndicatorMetadata:
|
||||||
|
"""
|
||||||
|
Get metadata for this indicator class.
|
||||||
|
|
||||||
|
Called by the registry for AI agent discovery and documentation.
|
||||||
|
Should return comprehensive information about the indicator's purpose,
|
||||||
|
parameters, and use cases.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
IndicatorMetadata describing this indicator class
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def get_input_schema(cls) -> InputSchema:
|
||||||
|
"""
|
||||||
|
Get the input schema required by this indicator.
|
||||||
|
|
||||||
|
Declares what columns must be present in the input data.
|
||||||
|
The pipeline will match this against available data sources.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
InputSchema describing required and optional input columns
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def get_output_schema(cls, **params) -> OutputSchema:
|
||||||
|
"""
|
||||||
|
Get the output schema produced by this indicator.
|
||||||
|
|
||||||
|
Output column names will be automatically prefixed with the instance name
|
||||||
|
by the pipeline engine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**params: Configuration parameters (may affect output schema)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OutputSchema describing the columns this indicator produces
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||||
|
"""
|
||||||
|
Compute indicator values from input data.
|
||||||
|
|
||||||
|
This method is called synchronously by the pipeline engine whenever
|
||||||
|
input data changes. Implementations should:
|
||||||
|
|
||||||
|
1. Extract needed columns from context.data
|
||||||
|
2. Perform calculations
|
||||||
|
3. Return results with proper time alignment
|
||||||
|
|
||||||
|
For incremental updates (context.is_incremental == True):
|
||||||
|
- context.data contains only new/updated rows
|
||||||
|
- Implementations MAY optimize by computing only these rows
|
||||||
|
- OR implementations MAY recompute everything (simpler but slower)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Input data and update metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ComputeResult with calculated indicator values
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If input data doesn't match expected schema
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _validate_params(self) -> None:
|
||||||
|
"""
|
||||||
|
Validate that provided parameters match the metadata specification.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If required parameters are missing or invalid
|
||||||
|
"""
|
||||||
|
metadata = self.get_metadata()
|
||||||
|
|
||||||
|
# Check for required parameters
|
||||||
|
for param_def in metadata.parameters:
|
||||||
|
if param_def.required and param_def.name not in self.params:
|
||||||
|
raise ValueError(
|
||||||
|
f"Indicator '{metadata.name}' requires parameter '{param_def.name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate parameter types and ranges
|
||||||
|
for name, value in self.params.items():
|
||||||
|
# Find parameter definition
|
||||||
|
param_def = next(
|
||||||
|
(p for p in metadata.parameters if p.name == name),
|
||||||
|
None
|
||||||
|
)
|
||||||
|
|
||||||
|
if param_def is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown parameter '{name}' for indicator '{metadata.name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Type checking
|
||||||
|
if param_def.type == "int" and not isinstance(value, int):
|
||||||
|
raise ValueError(
|
||||||
|
f"Parameter '{name}' must be int, got {type(value).__name__}"
|
||||||
|
)
|
||||||
|
elif param_def.type == "float" and not isinstance(value, (int, float)):
|
||||||
|
raise ValueError(
|
||||||
|
f"Parameter '{name}' must be float, got {type(value).__name__}"
|
||||||
|
)
|
||||||
|
elif param_def.type == "bool" and not isinstance(value, bool):
|
||||||
|
raise ValueError(
|
||||||
|
f"Parameter '{name}' must be bool, got {type(value).__name__}"
|
||||||
|
)
|
||||||
|
elif param_def.type == "string" and not isinstance(value, str):
|
||||||
|
raise ValueError(
|
||||||
|
f"Parameter '{name}' must be string, got {type(value).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Range checking for numeric types
|
||||||
|
if param_def.type in ("int", "float"):
|
||||||
|
if param_def.min_value is not None and value < param_def.min_value:
|
||||||
|
raise ValueError(
|
||||||
|
f"Parameter '{name}' must be >= {param_def.min_value}, got {value}"
|
||||||
|
)
|
||||||
|
if param_def.max_value is not None and value > param_def.max_value:
|
||||||
|
raise ValueError(
|
||||||
|
f"Parameter '{name}' must be <= {param_def.max_value}, got {value}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_output_columns(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Get the output column names with instance name prefix.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of prefixed output column names
|
||||||
|
"""
|
||||||
|
output_schema = self.get_output_schema(**self.params)
|
||||||
|
prefixed = output_schema.with_prefix(self.instance_name)
|
||||||
|
return [col.name for col in prefixed.columns if col.name != output_schema.time_column]
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"{self.__class__.__name__}(instance_name='{self.instance_name}', params={self.params})"
|
||||||
|
|
||||||
|
|
||||||
|
class DataSourceAdapter:
|
||||||
|
"""
|
||||||
|
Adapter to make a DataSource look like an Indicator for pipeline composition.
|
||||||
|
|
||||||
|
This allows DataSources to be inputs to indicators in a unified way.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, datasource_id: str, symbol: str, resolution: str):
|
||||||
|
"""
|
||||||
|
Create a DataSource adapter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
datasource_id: Identifier for the datasource (e.g., 'ccxt', 'demo')
|
||||||
|
symbol: Symbol to query (e.g., 'BTC/USD')
|
||||||
|
resolution: Time resolution (e.g., '1', '5', '1D')
|
||||||
|
"""
|
||||||
|
self.datasource_id = datasource_id
|
||||||
|
self.symbol = symbol
|
||||||
|
self.resolution = resolution
|
||||||
|
self.instance_name = f"ds_{datasource_id}_{symbol}_{resolution}".replace("/", "_").replace(":", "_")
|
||||||
|
|
||||||
|
def get_output_columns(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Get the columns provided by this datasource.
|
||||||
|
|
||||||
|
Note: This requires runtime resolution - the pipeline engine
|
||||||
|
will need to query the actual DataSource to get the schema.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of column names (placeholder - needs runtime resolution)
|
||||||
|
"""
|
||||||
|
# This will be resolved at runtime by the pipeline engine
|
||||||
|
return []
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"DataSourceAdapter(datasource='{self.datasource_id}', symbol='{self.symbol}', resolution='{self.resolution}')"
|
||||||
439
backend/src/indicator/pipeline.py
Normal file
439
backend/src/indicator/pipeline.py
Normal file
@@ -0,0 +1,439 @@
|
|||||||
|
"""
|
||||||
|
Pipeline execution engine for composable indicators.
|
||||||
|
|
||||||
|
Manages DAG construction, dependency resolution, incremental updates,
|
||||||
|
and efficient data flow through indicator chains.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
|
from datasource.base import DataSource
|
||||||
|
from datasource.schema import ColumnInfo
|
||||||
|
|
||||||
|
from .base import DataSourceAdapter, Indicator
|
||||||
|
from .schema import ComputeContext, ComputeResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineNode:
|
||||||
|
"""
|
||||||
|
A node in the pipeline DAG.
|
||||||
|
|
||||||
|
Can be either a DataSource adapter or an Indicator instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
node_id: str,
|
||||||
|
node: Union[DataSourceAdapter, Indicator],
|
||||||
|
dependencies: List[str]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a pipeline node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: Unique identifier for this node
|
||||||
|
node: The DataSourceAdapter or Indicator instance
|
||||||
|
dependencies: List of node_ids this node depends on
|
||||||
|
"""
|
||||||
|
self.node_id = node_id
|
||||||
|
self.node = node
|
||||||
|
self.dependencies = dependencies
|
||||||
|
self.output_columns: List[str] = []
|
||||||
|
self.cached_data: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
def is_datasource(self) -> bool:
|
||||||
|
"""Check if this node is a DataSource adapter."""
|
||||||
|
return isinstance(self.node, DataSourceAdapter)
|
||||||
|
|
||||||
|
def is_indicator(self) -> bool:
|
||||||
|
"""Check if this node is an Indicator."""
|
||||||
|
return isinstance(self.node, Indicator)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"PipelineNode(id='{self.node_id}', node={self.node}, deps={self.dependencies})"
|
||||||
|
|
||||||
|
|
||||||
|
class Pipeline:
|
||||||
|
"""
|
||||||
|
Execution engine for indicator DAGs.
|
||||||
|
|
||||||
|
Manages:
|
||||||
|
- DAG construction and validation
|
||||||
|
- Topological sorting for execution order
|
||||||
|
- Data flow and caching
|
||||||
|
- Incremental updates (only recompute what changed)
|
||||||
|
- Schema validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, datasource_registry):
|
||||||
|
"""
|
||||||
|
Initialize a pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
datasource_registry: DataSourceRegistry for resolving data sources
|
||||||
|
"""
|
||||||
|
self.datasource_registry = datasource_registry
|
||||||
|
self.nodes: Dict[str, PipelineNode] = {}
|
||||||
|
self.execution_order: List[str] = []
|
||||||
|
self._dirty_nodes: Set[str] = set()
|
||||||
|
|
||||||
|
def add_datasource(
|
||||||
|
self,
|
||||||
|
node_id: str,
|
||||||
|
datasource_name: str,
|
||||||
|
symbol: str,
|
||||||
|
resolution: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Add a DataSource to the pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: Unique identifier for this node
|
||||||
|
datasource_name: Name of the datasource in the registry
|
||||||
|
symbol: Symbol to query
|
||||||
|
resolution: Time resolution
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If node_id already exists or datasource not found
|
||||||
|
"""
|
||||||
|
if node_id in self.nodes:
|
||||||
|
raise ValueError(f"Node '{node_id}' already exists in pipeline")
|
||||||
|
|
||||||
|
datasource = self.datasource_registry.get(datasource_name)
|
||||||
|
if not datasource:
|
||||||
|
raise ValueError(f"DataSource '{datasource_name}' not found in registry")
|
||||||
|
|
||||||
|
adapter = DataSourceAdapter(datasource_name, symbol, resolution)
|
||||||
|
node = PipelineNode(node_id, adapter, dependencies=[])
|
||||||
|
|
||||||
|
self.nodes[node_id] = node
|
||||||
|
self._invalidate_execution_order()
|
||||||
|
|
||||||
|
logger.info(f"Added DataSource node '{node_id}': {datasource_name}/{symbol}@{resolution}")
|
||||||
|
|
||||||
|
def add_indicator(
|
||||||
|
self,
|
||||||
|
node_id: str,
|
||||||
|
indicator: Indicator,
|
||||||
|
input_node_ids: List[str]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Add an Indicator to the pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: Unique identifier for this node
|
||||||
|
indicator: Indicator instance
|
||||||
|
input_node_ids: List of node IDs providing input data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If node_id already exists, dependencies not found, or schema mismatch
|
||||||
|
"""
|
||||||
|
if node_id in self.nodes:
|
||||||
|
raise ValueError(f"Node '{node_id}' already exists in pipeline")
|
||||||
|
|
||||||
|
# Validate dependencies exist
|
||||||
|
for dep_id in input_node_ids:
|
||||||
|
if dep_id not in self.nodes:
|
||||||
|
raise ValueError(f"Dependency node '{dep_id}' not found in pipeline")
|
||||||
|
|
||||||
|
# TODO: Validate input schema matches available columns from dependencies
|
||||||
|
# This requires merging output schemas from all input nodes
|
||||||
|
|
||||||
|
node = PipelineNode(node_id, indicator, dependencies=input_node_ids)
|
||||||
|
self.nodes[node_id] = node
|
||||||
|
self._invalidate_execution_order()
|
||||||
|
|
||||||
|
logger.info(f"Added Indicator node '{node_id}': {indicator} with inputs {input_node_ids}")
|
||||||
|
|
||||||
|
def remove_node(self, node_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Remove a node from the pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: Node to remove
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If other nodes depend on this node
|
||||||
|
"""
|
||||||
|
if node_id not in self.nodes:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check for dependent nodes
|
||||||
|
dependents = [
|
||||||
|
n.node_id for n in self.nodes.values()
|
||||||
|
if node_id in n.dependencies
|
||||||
|
]
|
||||||
|
|
||||||
|
if dependents:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot remove node '{node_id}': nodes {dependents} depend on it"
|
||||||
|
)
|
||||||
|
|
||||||
|
del self.nodes[node_id]
|
||||||
|
self._invalidate_execution_order()
|
||||||
|
|
||||||
|
logger.info(f"Removed node '{node_id}' from pipeline")
|
||||||
|
|
||||||
|
def _invalidate_execution_order(self) -> None:
|
||||||
|
"""Mark execution order as needing recomputation."""
|
||||||
|
self.execution_order = []
|
||||||
|
|
||||||
|
def _compute_execution_order(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Compute topological sort of the DAG.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of node IDs in execution order
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If DAG contains cycles
|
||||||
|
"""
|
||||||
|
if self.execution_order:
|
||||||
|
return self.execution_order
|
||||||
|
|
||||||
|
# Kahn's algorithm for topological sort
|
||||||
|
in_degree = {node_id: 0 for node_id in self.nodes}
|
||||||
|
for node in self.nodes.values():
|
||||||
|
for dep in node.dependencies:
|
||||||
|
in_degree[node.node_id] += 1
|
||||||
|
|
||||||
|
queue = deque([node_id for node_id, degree in in_degree.items() if degree == 0])
|
||||||
|
result = []
|
||||||
|
|
||||||
|
while queue:
|
||||||
|
node_id = queue.popleft()
|
||||||
|
result.append(node_id)
|
||||||
|
|
||||||
|
# Find all nodes that depend on this one
|
||||||
|
for other_node in self.nodes.values():
|
||||||
|
if node_id in other_node.dependencies:
|
||||||
|
in_degree[other_node.node_id] -= 1
|
||||||
|
if in_degree[other_node.node_id] == 0:
|
||||||
|
queue.append(other_node.node_id)
|
||||||
|
|
||||||
|
if len(result) != len(self.nodes):
|
||||||
|
raise ValueError("Pipeline contains cycles")
|
||||||
|
|
||||||
|
self.execution_order = result
|
||||||
|
logger.debug(f"Computed execution order: {result}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
def execute(
|
||||||
|
self,
|
||||||
|
datasource_data: Dict[str, List[Dict[str, Any]]],
|
||||||
|
incremental: bool = False,
|
||||||
|
updated_from_time: Optional[int] = None
|
||||||
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Execute the pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
datasource_data: Mapping of DataSource node_id to input data
|
||||||
|
incremental: Whether this is an incremental update
|
||||||
|
updated_from_time: Timestamp of earliest updated row (for incremental)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping node_id to output data (all nodes)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If required datasource data is missing
|
||||||
|
"""
|
||||||
|
execution_order = self._compute_execution_order()
|
||||||
|
results: Dict[str, List[Dict[str, Any]]] = {}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Executing pipeline with {len(execution_order)} nodes "
|
||||||
|
f"(incremental={incremental})"
|
||||||
|
)
|
||||||
|
|
||||||
|
for node_id in execution_order:
|
||||||
|
node = self.nodes[node_id]
|
||||||
|
|
||||||
|
if node.is_datasource():
|
||||||
|
# DataSource node - get data from input
|
||||||
|
if node_id not in datasource_data:
|
||||||
|
raise ValueError(
|
||||||
|
f"DataSource node '{node_id}' has no input data"
|
||||||
|
)
|
||||||
|
results[node_id] = datasource_data[node_id]
|
||||||
|
node.cached_data = results[node_id]
|
||||||
|
logger.debug(f"DataSource node '{node_id}': {len(results[node_id])} rows")
|
||||||
|
|
||||||
|
elif node.is_indicator():
|
||||||
|
# Indicator node - compute from dependencies
|
||||||
|
indicator = node.node
|
||||||
|
|
||||||
|
# Merge input data from all dependencies
|
||||||
|
input_data = self._merge_dependency_data(node.dependencies, results)
|
||||||
|
|
||||||
|
# Create compute context
|
||||||
|
context = ComputeContext(
|
||||||
|
data=input_data,
|
||||||
|
is_incremental=incremental,
|
||||||
|
updated_from_time=updated_from_time
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute indicator
|
||||||
|
logger.debug(
|
||||||
|
f"Computing indicator '{node_id}' with {len(input_data)} input rows"
|
||||||
|
)
|
||||||
|
compute_result = indicator.compute(context)
|
||||||
|
|
||||||
|
# Merge result with input data (adding prefixed columns)
|
||||||
|
output_data = compute_result.merge_with_prefix(
|
||||||
|
indicator.instance_name,
|
||||||
|
input_data
|
||||||
|
)
|
||||||
|
|
||||||
|
results[node_id] = output_data
|
||||||
|
node.cached_data = output_data
|
||||||
|
logger.debug(f"Indicator node '{node_id}': {len(output_data)} rows")
|
||||||
|
|
||||||
|
logger.info(f"Pipeline execution complete: {len(results)} nodes processed")
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _merge_dependency_data(
|
||||||
|
self,
|
||||||
|
dependency_ids: List[str],
|
||||||
|
results: Dict[str, List[Dict[str, Any]]]
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Merge data from multiple dependency nodes.
|
||||||
|
|
||||||
|
Data is merged by time, with later dependencies overwriting earlier ones
|
||||||
|
for conflicting column names.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dependency_ids: List of node IDs to merge
|
||||||
|
results: Current execution results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Merged data rows
|
||||||
|
"""
|
||||||
|
if not dependency_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if len(dependency_ids) == 1:
|
||||||
|
return results[dependency_ids[0]]
|
||||||
|
|
||||||
|
# Build time-indexed data from first dependency
|
||||||
|
merged: Dict[int, Dict[str, Any]] = {}
|
||||||
|
for row in results[dependency_ids[0]]:
|
||||||
|
merged[row["time"]] = row.copy()
|
||||||
|
|
||||||
|
# Merge in additional dependencies
|
||||||
|
for dep_id in dependency_ids[1:]:
|
||||||
|
for row in results[dep_id]:
|
||||||
|
time_key = row["time"]
|
||||||
|
if time_key in merged:
|
||||||
|
# Merge columns (later dependencies win)
|
||||||
|
merged[time_key].update(row)
|
||||||
|
else:
|
||||||
|
# New timestamp
|
||||||
|
merged[time_key] = row.copy()
|
||||||
|
|
||||||
|
# Sort by time and return
|
||||||
|
sorted_times = sorted(merged.keys())
|
||||||
|
return [merged[t] for t in sorted_times]
|
||||||
|
|
||||||
|
def get_node_output(self, node_id: str) -> Optional[List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Get cached output data for a specific node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: Node identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached data or None if not available
|
||||||
|
"""
|
||||||
|
node = self.nodes.get(node_id)
|
||||||
|
return node.cached_data if node else None
|
||||||
|
|
||||||
|
def get_output_schema(self, node_id: str) -> List[ColumnInfo]:
|
||||||
|
"""
|
||||||
|
Get the output schema for a specific node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: Node identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ColumnInfo describing output columns
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If node not found
|
||||||
|
"""
|
||||||
|
node = self.nodes.get(node_id)
|
||||||
|
if not node:
|
||||||
|
raise ValueError(f"Node '{node_id}' not found")
|
||||||
|
|
||||||
|
if node.is_datasource():
|
||||||
|
# Would need to query the actual datasource at runtime
|
||||||
|
# For now, return empty - this requires integration with DataSource
|
||||||
|
return []
|
||||||
|
|
||||||
|
elif node.is_indicator():
|
||||||
|
indicator = node.node
|
||||||
|
output_schema = indicator.get_output_schema(**indicator.params)
|
||||||
|
prefixed_schema = output_schema.with_prefix(indicator.instance_name)
|
||||||
|
return prefixed_schema.columns
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
def validate_pipeline(self) -> Tuple[bool, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Validate the entire pipeline for correctness.
|
||||||
|
|
||||||
|
Checks:
|
||||||
|
- No cycles (already checked in execution order)
|
||||||
|
- All dependencies exist (already checked in add_indicator)
|
||||||
|
- Input schemas match output schemas (TODO)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_valid, error_message)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._compute_execution_order()
|
||||||
|
return True, None
|
||||||
|
except ValueError as e:
|
||||||
|
return False, str(e)
|
||||||
|
|
||||||
|
def get_node_count(self) -> int:
|
||||||
|
"""Get the number of nodes in the pipeline."""
|
||||||
|
return len(self.nodes)
|
||||||
|
|
||||||
|
def get_indicator_count(self) -> int:
|
||||||
|
"""Get the number of indicator nodes in the pipeline."""
|
||||||
|
return sum(1 for node in self.nodes.values() if node.is_indicator())
|
||||||
|
|
||||||
|
def get_datasource_count(self) -> int:
|
||||||
|
"""Get the number of datasource nodes in the pipeline."""
|
||||||
|
return sum(1 for node in self.nodes.values() if node.is_datasource())
|
||||||
|
|
||||||
|
def describe(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get a detailed description of the pipeline structure.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with pipeline metadata and structure
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"node_count": self.get_node_count(),
|
||||||
|
"datasource_count": self.get_datasource_count(),
|
||||||
|
"indicator_count": self.get_indicator_count(),
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"id": node.node_id,
|
||||||
|
"type": "datasource" if node.is_datasource() else "indicator",
|
||||||
|
"node": str(node.node),
|
||||||
|
"dependencies": node.dependencies,
|
||||||
|
"cached_rows": len(node.cached_data)
|
||||||
|
}
|
||||||
|
for node in self.nodes.values()
|
||||||
|
],
|
||||||
|
"execution_order": self.execution_order or self._compute_execution_order(),
|
||||||
|
"is_valid": self.validate_pipeline()[0]
|
||||||
|
}
|
||||||
349
backend/src/indicator/registry.py
Normal file
349
backend/src/indicator/registry.py
Normal file
@@ -0,0 +1,349 @@
|
|||||||
|
"""
|
||||||
|
Indicator registry for managing and discovering indicators.
|
||||||
|
|
||||||
|
Provides AI agents with a queryable catalog of available indicators,
|
||||||
|
their capabilities, and metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Type
|
||||||
|
|
||||||
|
from .base import Indicator
|
||||||
|
from .schema import IndicatorMetadata, InputSchema, OutputSchema
|
||||||
|
|
||||||
|
|
||||||
|
class IndicatorRegistry:
|
||||||
|
"""
|
||||||
|
Central registry for indicator classes.
|
||||||
|
|
||||||
|
Enables:
|
||||||
|
- Registration of indicator implementations
|
||||||
|
- Discovery by name, category, or tags
|
||||||
|
- Schema validation
|
||||||
|
- AI agent tool generation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._indicators: Dict[str, Type[Indicator]] = {}
|
||||||
|
|
||||||
|
def register(self, indicator_class: Type[Indicator]) -> None:
|
||||||
|
"""
|
||||||
|
Register an indicator class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indicator_class: Indicator class to register
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If an indicator with this name is already registered
|
||||||
|
"""
|
||||||
|
metadata = indicator_class.get_metadata()
|
||||||
|
|
||||||
|
if metadata.name in self._indicators:
|
||||||
|
raise ValueError(
|
||||||
|
f"Indicator '{metadata.name}' is already registered"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._indicators[metadata.name] = indicator_class
|
||||||
|
|
||||||
|
def unregister(self, name: str) -> None:
|
||||||
|
"""
|
||||||
|
Unregister an indicator class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Indicator class name
|
||||||
|
"""
|
||||||
|
self._indicators.pop(name, None)
|
||||||
|
|
||||||
|
def get(self, name: str) -> Optional[Type[Indicator]]:
|
||||||
|
"""
|
||||||
|
Get an indicator class by name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Indicator class name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Indicator class or None if not found
|
||||||
|
"""
|
||||||
|
return self._indicators.get(name)
|
||||||
|
|
||||||
|
def list_indicators(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Get names of all registered indicators.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of indicator class names
|
||||||
|
"""
|
||||||
|
return list(self._indicators.keys())
|
||||||
|
|
||||||
|
def get_metadata(self, name: str) -> Optional[IndicatorMetadata]:
|
||||||
|
"""
|
||||||
|
Get metadata for a specific indicator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Indicator class name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
IndicatorMetadata or None if not found
|
||||||
|
"""
|
||||||
|
indicator_class = self.get(name)
|
||||||
|
if indicator_class:
|
||||||
|
return indicator_class.get_metadata()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_all_metadata(self) -> List[IndicatorMetadata]:
|
||||||
|
"""
|
||||||
|
Get metadata for all registered indicators.
|
||||||
|
|
||||||
|
Useful for AI agent tool generation and discovery.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of IndicatorMetadata for all registered indicators
|
||||||
|
"""
|
||||||
|
return [cls.get_metadata() for cls in self._indicators.values()]
|
||||||
|
|
||||||
|
def search_by_category(self, category: str) -> List[IndicatorMetadata]:
|
||||||
|
"""
|
||||||
|
Find indicators by category.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
category: Category name (e.g., 'momentum', 'trend', 'volatility')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching indicator metadata
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for indicator_class in self._indicators.values():
|
||||||
|
metadata = indicator_class.get_metadata()
|
||||||
|
if metadata.category.lower() == category.lower():
|
||||||
|
results.append(metadata)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def search_by_tag(self, tag: str) -> List[IndicatorMetadata]:
|
||||||
|
"""
|
||||||
|
Find indicators by tag.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tag: Tag to search for (case-insensitive)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching indicator metadata
|
||||||
|
"""
|
||||||
|
tag_lower = tag.lower()
|
||||||
|
results = []
|
||||||
|
for indicator_class in self._indicators.values():
|
||||||
|
metadata = indicator_class.get_metadata()
|
||||||
|
if any(t.lower() == tag_lower for t in metadata.tags):
|
||||||
|
results.append(metadata)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def search_by_text(self, query: str) -> List[IndicatorMetadata]:
|
||||||
|
"""
|
||||||
|
Full-text search across indicator names, descriptions, and use cases.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query (case-insensitive)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching indicator metadata, ranked by relevance
|
||||||
|
"""
|
||||||
|
query_lower = query.lower()
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for indicator_class in self._indicators.values():
|
||||||
|
metadata = indicator_class.get_metadata()
|
||||||
|
score = 0
|
||||||
|
|
||||||
|
# Check name (highest weight)
|
||||||
|
if query_lower in metadata.name.lower():
|
||||||
|
score += 10
|
||||||
|
if query_lower in metadata.display_name.lower():
|
||||||
|
score += 8
|
||||||
|
|
||||||
|
# Check description
|
||||||
|
if query_lower in metadata.description.lower():
|
||||||
|
score += 5
|
||||||
|
|
||||||
|
# Check use cases
|
||||||
|
for use_case in metadata.use_cases:
|
||||||
|
if query_lower in use_case.lower():
|
||||||
|
score += 3
|
||||||
|
|
||||||
|
# Check tags
|
||||||
|
for tag in metadata.tags:
|
||||||
|
if query_lower in tag.lower():
|
||||||
|
score += 2
|
||||||
|
|
||||||
|
if score > 0:
|
||||||
|
results.append((score, metadata))
|
||||||
|
|
||||||
|
# Sort by score descending
|
||||||
|
results.sort(key=lambda x: x[0], reverse=True)
|
||||||
|
return [metadata for _, metadata in results]
|
||||||
|
|
||||||
|
def find_compatible_indicators(
|
||||||
|
self,
|
||||||
|
available_columns: List[str],
|
||||||
|
column_types: Dict[str, str]
|
||||||
|
) -> List[IndicatorMetadata]:
|
||||||
|
"""
|
||||||
|
Find indicators that can be computed from available columns.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
available_columns: List of column names available
|
||||||
|
column_types: Mapping of column name to type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of indicators whose input schema is satisfied
|
||||||
|
"""
|
||||||
|
from datasource.schema import ColumnInfo
|
||||||
|
|
||||||
|
# Build ColumnInfo list from available data
|
||||||
|
available_schema = [
|
||||||
|
ColumnInfo(
|
||||||
|
name=name,
|
||||||
|
type=column_types.get(name, "float"),
|
||||||
|
description=f"Column {name}"
|
||||||
|
)
|
||||||
|
for name in available_columns
|
||||||
|
]
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for indicator_class in self._indicators.values():
|
||||||
|
input_schema = indicator_class.get_input_schema()
|
||||||
|
if input_schema.matches(available_schema):
|
||||||
|
results.append(indicator_class.get_metadata())
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def validate_indicator_chain(
|
||||||
|
self,
|
||||||
|
indicator_chain: List[tuple[str, Dict]]
|
||||||
|
) -> tuple[bool, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Validate that a chain of indicators can be connected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indicator_chain: List of (indicator_name, params) tuples in execution order
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_valid, error_message)
|
||||||
|
"""
|
||||||
|
if not indicator_chain:
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
# For now, just check that all indicators exist
|
||||||
|
# More sophisticated DAG validation happens in the pipeline engine
|
||||||
|
for indicator_name, params in indicator_chain:
|
||||||
|
if indicator_name not in self._indicators:
|
||||||
|
return False, f"Indicator '{indicator_name}' not found in registry"
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
def get_input_schema(self, name: str) -> Optional[InputSchema]:
|
||||||
|
"""
|
||||||
|
Get input schema for a specific indicator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Indicator class name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
InputSchema or None if not found
|
||||||
|
"""
|
||||||
|
indicator_class = self.get(name)
|
||||||
|
if indicator_class:
|
||||||
|
return indicator_class.get_input_schema()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_output_schema(self, name: str, **params) -> Optional[OutputSchema]:
|
||||||
|
"""
|
||||||
|
Get output schema for a specific indicator with given parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Indicator class name
|
||||||
|
**params: Indicator parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OutputSchema or None if not found
|
||||||
|
"""
|
||||||
|
indicator_class = self.get(name)
|
||||||
|
if indicator_class:
|
||||||
|
return indicator_class.get_output_schema(**params)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def create_instance(self, name: str, instance_name: str, **params) -> Optional[Indicator]:
|
||||||
|
"""
|
||||||
|
Create an indicator instance with validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Indicator class name
|
||||||
|
instance_name: Unique instance name (for output column prefixing)
|
||||||
|
**params: Indicator configuration parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Indicator instance or None if class not found
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If parameters are invalid
|
||||||
|
"""
|
||||||
|
indicator_class = self.get(name)
|
||||||
|
if not indicator_class:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return indicator_class(instance_name=instance_name, **params)
|
||||||
|
|
||||||
|
def generate_ai_tool_spec(self) -> Dict:
|
||||||
|
"""
|
||||||
|
Generate a JSON specification for AI agent tools.
|
||||||
|
|
||||||
|
Creates a structured representation of all indicators that can be
|
||||||
|
used to build agent tools for indicator selection and composition.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict suitable for AI agent tool registration
|
||||||
|
"""
|
||||||
|
tools = []
|
||||||
|
|
||||||
|
for indicator_class in self._indicators.values():
|
||||||
|
metadata = indicator_class.get_metadata()
|
||||||
|
|
||||||
|
# Build parameter spec
|
||||||
|
parameters = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
"required": []
|
||||||
|
}
|
||||||
|
|
||||||
|
for param in metadata.parameters:
|
||||||
|
param_spec = {
|
||||||
|
"type": param.type,
|
||||||
|
"description": param.description
|
||||||
|
}
|
||||||
|
|
||||||
|
if param.default is not None:
|
||||||
|
param_spec["default"] = param.default
|
||||||
|
if param.min_value is not None:
|
||||||
|
param_spec["minimum"] = param.min_value
|
||||||
|
if param.max_value is not None:
|
||||||
|
param_spec["maximum"] = param.max_value
|
||||||
|
|
||||||
|
parameters["properties"][param.name] = param_spec
|
||||||
|
|
||||||
|
if param.required:
|
||||||
|
parameters["required"].append(param.name)
|
||||||
|
|
||||||
|
tool = {
|
||||||
|
"name": f"indicator_{metadata.name.lower()}",
|
||||||
|
"description": f"{metadata.display_name}: {metadata.description}",
|
||||||
|
"category": metadata.category,
|
||||||
|
"use_cases": metadata.use_cases,
|
||||||
|
"tags": metadata.tags,
|
||||||
|
"parameters": parameters,
|
||||||
|
"input_schema": indicator_class.get_input_schema().model_dump(),
|
||||||
|
"output_schema": indicator_class.get_output_schema().model_dump()
|
||||||
|
}
|
||||||
|
|
||||||
|
tools.append(tool)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"indicator_tools": tools,
|
||||||
|
"total_count": len(tools)
|
||||||
|
}
|
||||||
269
backend/src/indicator/schema.py
Normal file
269
backend/src/indicator/schema.py
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
"""
|
||||||
|
Data models for the Indicator system.
|
||||||
|
|
||||||
|
Defines schemas for input/output specifications, computation context,
|
||||||
|
and metadata for AI agent discovery.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from datasource.schema import ColumnInfo
|
||||||
|
|
||||||
|
|
||||||
|
class InputSchema(BaseModel):
|
||||||
|
"""
|
||||||
|
Declares the required input columns for an Indicator.
|
||||||
|
|
||||||
|
Indicators match against any data source (DataSource or other Indicator)
|
||||||
|
that provides columns satisfying this schema.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
required_columns: List[ColumnInfo] = Field(
|
||||||
|
description="Columns that must be present in the input data"
|
||||||
|
)
|
||||||
|
optional_columns: List[ColumnInfo] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Columns that may be used if present but are not required"
|
||||||
|
)
|
||||||
|
time_column: str = Field(
|
||||||
|
default="time",
|
||||||
|
description="Name of the timestamp column (must be present)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def matches(self, available_columns: List[ColumnInfo]) -> bool:
|
||||||
|
"""
|
||||||
|
Check if available columns satisfy this input schema.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
available_columns: Columns provided by a data source
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if all required columns are present with compatible types
|
||||||
|
"""
|
||||||
|
available_map = {col.name: col for col in available_columns}
|
||||||
|
|
||||||
|
# Check time column exists
|
||||||
|
if self.time_column not in available_map:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check all required columns exist with compatible types
|
||||||
|
for required in self.required_columns:
|
||||||
|
if required.name not in available_map:
|
||||||
|
return False
|
||||||
|
available = available_map[required.name]
|
||||||
|
if available.type != required.type:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_missing_columns(self, available_columns: List[ColumnInfo]) -> List[str]:
|
||||||
|
"""
|
||||||
|
Get list of missing required column names.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
available_columns: Columns provided by a data source
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of missing column names
|
||||||
|
"""
|
||||||
|
available_names = {col.name for col in available_columns}
|
||||||
|
missing = []
|
||||||
|
|
||||||
|
if self.time_column not in available_names:
|
||||||
|
missing.append(self.time_column)
|
||||||
|
|
||||||
|
for required in self.required_columns:
|
||||||
|
if required.name not in available_names:
|
||||||
|
missing.append(required.name)
|
||||||
|
|
||||||
|
return missing
|
||||||
|
|
||||||
|
|
||||||
|
class OutputSchema(BaseModel):
|
||||||
|
"""
|
||||||
|
Declares the output columns produced by an Indicator.
|
||||||
|
|
||||||
|
Column names will be automatically prefixed with the indicator instance name
|
||||||
|
to avoid collisions in the pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
columns: List[ColumnInfo] = Field(
|
||||||
|
description="Output columns produced by this indicator"
|
||||||
|
)
|
||||||
|
time_column: str = Field(
|
||||||
|
default="time",
|
||||||
|
description="Name of the timestamp column (passed through from input)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def with_prefix(self, prefix: str) -> "OutputSchema":
|
||||||
|
"""
|
||||||
|
Create a new OutputSchema with all column names prefixed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefix: Prefix to add (e.g., indicator instance name)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New OutputSchema with prefixed column names
|
||||||
|
"""
|
||||||
|
prefixed_columns = [
|
||||||
|
ColumnInfo(
|
||||||
|
name=f"{prefix}_{col.name}" if col.name != self.time_column else col.name,
|
||||||
|
type=col.type,
|
||||||
|
description=col.description,
|
||||||
|
unit=col.unit,
|
||||||
|
nullable=col.nullable
|
||||||
|
)
|
||||||
|
for col in self.columns
|
||||||
|
]
|
||||||
|
return OutputSchema(
|
||||||
|
columns=prefixed_columns,
|
||||||
|
time_column=self.time_column
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class IndicatorParameter(BaseModel):
|
||||||
|
"""
|
||||||
|
Metadata for a configurable indicator parameter.
|
||||||
|
|
||||||
|
Used for AI agent discovery and dynamic indicator instantiation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
name: str = Field(description="Parameter name")
|
||||||
|
type: Literal["int", "float", "string", "bool"] = Field(description="Parameter type")
|
||||||
|
description: str = Field(description="Human and LLM-readable description")
|
||||||
|
default: Optional[Any] = Field(default=None, description="Default value if not specified")
|
||||||
|
required: bool = Field(default=False, description="Whether this parameter is required")
|
||||||
|
min_value: Optional[float] = Field(default=None, description="Minimum value (for numeric types)")
|
||||||
|
max_value: Optional[float] = Field(default=None, description="Maximum value (for numeric types)")
|
||||||
|
|
||||||
|
|
||||||
|
class IndicatorMetadata(BaseModel):
|
||||||
|
"""
|
||||||
|
Rich metadata for an Indicator class.
|
||||||
|
|
||||||
|
Enables AI agents to discover, understand, and instantiate indicators.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
name: str = Field(description="Unique indicator class name (e.g., 'RSI', 'SMA', 'BollingerBands')")
|
||||||
|
display_name: str = Field(description="Human-readable display name")
|
||||||
|
description: str = Field(description="Detailed description of what this indicator computes and why it's useful")
|
||||||
|
category: str = Field(
|
||||||
|
description="Indicator category (e.g., 'momentum', 'trend', 'volatility', 'volume', 'custom')"
|
||||||
|
)
|
||||||
|
parameters: List[IndicatorParameter] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Configurable parameters for this indicator"
|
||||||
|
)
|
||||||
|
use_cases: List[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Common use cases and trading scenarios where this indicator is helpful"
|
||||||
|
)
|
||||||
|
references: List[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="URLs or citations for indicator methodology"
|
||||||
|
)
|
||||||
|
tags: List[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Searchable tags (e.g., 'oscillator', 'mean-reversion', 'price-based')"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ComputeContext(BaseModel):
|
||||||
|
"""
|
||||||
|
Context passed to an Indicator's compute() method.
|
||||||
|
|
||||||
|
Contains the input data and metadata about what changed (for incremental updates).
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
data: List[Dict[str, Any]] = Field(
|
||||||
|
description="Input data rows (time-ordered). Each dict is {column_name: value, time: timestamp}"
|
||||||
|
)
|
||||||
|
is_incremental: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="True if this is an incremental update (only new/changed rows), False for full recompute"
|
||||||
|
)
|
||||||
|
updated_from_time: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Unix timestamp (ms) of the earliest updated row (for incremental updates)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_column(self, name: str) -> List[Any]:
|
||||||
|
"""
|
||||||
|
Extract a single column as a list of values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Column name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of values in time order
|
||||||
|
"""
|
||||||
|
return [row.get(name) for row in self.data]
|
||||||
|
|
||||||
|
def get_times(self) -> List[int]:
|
||||||
|
"""
|
||||||
|
Get the time column as a list.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of timestamps in order
|
||||||
|
"""
|
||||||
|
return [row["time"] for row in self.data]
|
||||||
|
|
||||||
|
|
||||||
|
class ComputeResult(BaseModel):
|
||||||
|
"""
|
||||||
|
Result from an Indicator's compute() method.
|
||||||
|
|
||||||
|
Contains the computed output data with proper column naming.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
data: List[Dict[str, Any]] = Field(
|
||||||
|
description="Output data rows (time-ordered). Must include time column."
|
||||||
|
)
|
||||||
|
is_partial: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="True if this result only contains updates (for incremental computation)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def merge_with_prefix(self, prefix: str, existing_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Merge this result into existing data with column name prefixing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefix: Prefix to add to all column names except time
|
||||||
|
existing_data: Existing data to merge with (matched by time)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Merged data with prefixed columns added
|
||||||
|
"""
|
||||||
|
# Build a time index for new data
|
||||||
|
time_index = {row["time"]: row for row in self.data}
|
||||||
|
|
||||||
|
# Merge into existing data
|
||||||
|
result = []
|
||||||
|
for existing_row in existing_data:
|
||||||
|
row_time = existing_row["time"]
|
||||||
|
merged_row = existing_row.copy()
|
||||||
|
|
||||||
|
if row_time in time_index:
|
||||||
|
new_row = time_index[row_time]
|
||||||
|
for key, value in new_row.items():
|
||||||
|
if key != "time":
|
||||||
|
merged_row[f"{prefix}_{key}"] = value
|
||||||
|
|
||||||
|
result.append(merged_row)
|
||||||
|
|
||||||
|
return result
|
||||||
436
backend/src/indicator/talib_adapter.py
Normal file
436
backend/src/indicator/talib_adapter.py
Normal file
@@ -0,0 +1,436 @@
|
|||||||
|
"""
|
||||||
|
TA-Lib indicator adapter.
|
||||||
|
|
||||||
|
Provides automatic registration of all TA-Lib technical indicators
|
||||||
|
as composable Indicator instances.
|
||||||
|
|
||||||
|
Installation Requirements:
|
||||||
|
--------------------------
|
||||||
|
TA-Lib requires both the C library and Python wrapper:
|
||||||
|
|
||||||
|
1. Install TA-Lib C library:
|
||||||
|
- Ubuntu/Debian: sudo apt-get install libta-lib-dev
|
||||||
|
- macOS: brew install ta-lib
|
||||||
|
- From source: https://ta-lib.org/install.html
|
||||||
|
|
||||||
|
2. Install Python wrapper (already in requirements.txt):
|
||||||
|
pip install TA-Lib
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
------
|
||||||
|
from indicator.talib_adapter import register_all_talib_indicators
|
||||||
|
|
||||||
|
# Auto-register all TA-Lib indicators
|
||||||
|
registry = IndicatorRegistry()
|
||||||
|
register_all_talib_indicators(registry)
|
||||||
|
|
||||||
|
# Now you can use any TA-Lib indicator
|
||||||
|
sma = registry.create_instance("SMA", "sma_20", period=20)
|
||||||
|
rsi = registry.create_instance("RSI", "rsi_14", timeperiod=14)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
import talib
|
||||||
|
from talib import abstract
|
||||||
|
TALIB_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
TALIB_AVAILABLE = False
|
||||||
|
talib = None
|
||||||
|
abstract = None
|
||||||
|
|
||||||
|
from datasource.schema import ColumnInfo
|
||||||
|
|
||||||
|
from .base import Indicator
|
||||||
|
from .schema import (
|
||||||
|
ComputeContext,
|
||||||
|
ComputeResult,
|
||||||
|
IndicatorMetadata,
|
||||||
|
IndicatorParameter,
|
||||||
|
InputSchema,
|
||||||
|
OutputSchema,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Mapping of TA-Lib parameter types to our schema types
|
||||||
|
TALIB_TYPE_MAP = {
|
||||||
|
"double": "float",
|
||||||
|
"double[]": "float",
|
||||||
|
"int": "int",
|
||||||
|
"str": "string",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Categorization of TA-Lib functions
|
||||||
|
TALIB_CATEGORIES = {
|
||||||
|
"overlap": ["SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA", "KAMA", "MAMA", "T3",
|
||||||
|
"BBANDS", "MIDPOINT", "MIDPRICE", "SAR", "SAREXT", "HT_TRENDLINE"],
|
||||||
|
"momentum": ["RSI", "MOM", "ROC", "ROCP", "ROCR", "ROCR100", "TRIX", "CMO", "DX",
|
||||||
|
"ADX", "ADXR", "APO", "PPO", "MACD", "MACDEXT", "MACDFIX", "MFI",
|
||||||
|
"STOCH", "STOCHF", "STOCHRSI", "WILLR", "CCI", "AROON", "AROONOSC",
|
||||||
|
"BOP", "MINUS_DI", "MINUS_DM", "PLUS_DI", "PLUS_DM", "ULTOSC"],
|
||||||
|
"volume": ["AD", "ADOSC", "OBV"],
|
||||||
|
"volatility": ["ATR", "NATR", "TRANGE"],
|
||||||
|
"price": ["AVGPRICE", "MEDPRICE", "TYPPRICE", "WCLPRICE"],
|
||||||
|
"cycle": ["HT_DCPERIOD", "HT_DCPHASE", "HT_PHASOR", "HT_SINE", "HT_TRENDMODE"],
|
||||||
|
"pattern": ["CDL2CROWS", "CDL3BLACKCROWS", "CDL3INSIDE", "CDL3LINESTRIKE",
|
||||||
|
"CDL3OUTSIDE", "CDL3STARSINSOUTH", "CDL3WHITESOLDIERS", "CDLABANDONEDBABY",
|
||||||
|
"CDLADVANCEBLOCK", "CDLBELTHOLD", "CDLBREAKAWAY", "CDLCLOSINGMARUBOZU",
|
||||||
|
"CDLCONCEALBABYSWALL", "CDLCOUNTERATTACK", "CDLDARKCLOUDCOVER", "CDLDOJI",
|
||||||
|
"CDLDOJISTAR", "CDLDRAGONFLYDOJI", "CDLENGULFING", "CDLEVENINGDOJISTAR",
|
||||||
|
"CDLEVENINGSTAR", "CDLGAPSIDESIDEWHITE", "CDLGRAVESTONEDOJI", "CDLHAMMER",
|
||||||
|
"CDLHANGINGMAN", "CDLHARAMI", "CDLHARAMICROSS", "CDLHIGHWAVE", "CDLHIKKAKE",
|
||||||
|
"CDLHIKKAKEMOD", "CDLHOMINGPIGEON", "CDLIDENTICAL3CROWS", "CDLINNECK",
|
||||||
|
"CDLINVERTEDHAMMER", "CDLKICKING", "CDLKICKINGBYLENGTH", "CDLLADDERBOTTOM",
|
||||||
|
"CDLLONGLEGGEDDOJI", "CDLLONGLINE", "CDLMARUBOZU", "CDLMATCHINGLOW",
|
||||||
|
"CDLMATHOLD", "CDLMORNINGDOJISTAR", "CDLMORNINGSTAR", "CDLONNECK",
|
||||||
|
"CDLPIERCING", "CDLRICKSHAWMAN", "CDLRISEFALL3METHODS", "CDLSEPARATINGLINES",
|
||||||
|
"CDLSHOOTINGSTAR", "CDLSHORTLINE", "CDLSPINNINGTOP", "CDLSTALLEDPATTERN",
|
||||||
|
"CDLSTICKSANDWICH", "CDLTAKURI", "CDLTASUKIGAP", "CDLTHRUSTING", "CDLTRISTAR",
|
||||||
|
"CDLUNIQUE3RIVER", "CDLUPSIDEGAP2CROWS", "CDLXSIDEGAP3METHODS"],
|
||||||
|
"statistic": ["BETA", "CORREL", "LINEARREG", "LINEARREG_ANGLE", "LINEARREG_INTERCEPT",
|
||||||
|
"LINEARREG_SLOPE", "STDDEV", "TSF", "VAR"],
|
||||||
|
"math": ["ADD", "DIV", "MAX", "MAXINDEX", "MIN", "MININDEX", "MINMAX", "MINMAXINDEX",
|
||||||
|
"MULT", "SUB", "SUM"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_function_category(func_name: str) -> str:
|
||||||
|
"""Determine the category of a TA-Lib function."""
|
||||||
|
for category, functions in TALIB_CATEGORIES.items():
|
||||||
|
if func_name in functions:
|
||||||
|
return category
|
||||||
|
return "other"
|
||||||
|
|
||||||
|
|
||||||
|
class TALibIndicator(Indicator):
|
||||||
|
"""
|
||||||
|
Generic adapter for TA-Lib technical indicators.
|
||||||
|
|
||||||
|
Wraps any TA-Lib function to work within the composable indicator framework.
|
||||||
|
Handles parameter mapping, input validation, and output formatting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Class variable to store the TA-Lib function name
|
||||||
|
talib_function_name: str = None
|
||||||
|
|
||||||
|
def __init__(self, instance_name: str, **params):
|
||||||
|
"""
|
||||||
|
Initialize a TA-Lib indicator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance_name: Unique name for this instance
|
||||||
|
**params: TA-Lib function parameters
|
||||||
|
"""
|
||||||
|
if not TALIB_AVAILABLE:
|
||||||
|
raise ImportError(
|
||||||
|
"TA-Lib is not installed. Please install the TA-Lib C library "
|
||||||
|
"and Python wrapper. See indicator/talib_adapter.py for instructions."
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(instance_name, **params)
|
||||||
|
self._talib_func = abstract.Function(self.talib_function_name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_metadata(cls) -> IndicatorMetadata:
|
||||||
|
"""Get metadata from TA-Lib function info."""
|
||||||
|
if not TALIB_AVAILABLE:
|
||||||
|
raise ImportError("TA-Lib is not installed")
|
||||||
|
|
||||||
|
func = abstract.Function(cls.talib_function_name)
|
||||||
|
info = func.info
|
||||||
|
|
||||||
|
# Build parameters list from TA-Lib function info
|
||||||
|
parameters = []
|
||||||
|
for param_name, param_info in info.get("parameters", {}).items():
|
||||||
|
# Handle case where param_info is a simple value (int/float) instead of a dict
|
||||||
|
if isinstance(param_info, dict):
|
||||||
|
param_type = TALIB_TYPE_MAP.get(param_info.get("type", "double"), "float")
|
||||||
|
default_value = param_info.get("default_value")
|
||||||
|
else:
|
||||||
|
# param_info is a simple value (default), infer type from the value
|
||||||
|
if isinstance(param_info, int):
|
||||||
|
param_type = "int"
|
||||||
|
elif isinstance(param_info, float):
|
||||||
|
param_type = "float"
|
||||||
|
else:
|
||||||
|
param_type = "float" # Default to float
|
||||||
|
default_value = param_info
|
||||||
|
|
||||||
|
parameters.append(
|
||||||
|
IndicatorParameter(
|
||||||
|
name=param_name,
|
||||||
|
type=param_type,
|
||||||
|
description=f"TA-Lib parameter: {param_name}",
|
||||||
|
default=default_value,
|
||||||
|
required=False
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get function group/category
|
||||||
|
category = _get_function_category(cls.talib_function_name)
|
||||||
|
|
||||||
|
# Build display name (split camelCase or handle CDL prefix)
|
||||||
|
display_name = cls.talib_function_name
|
||||||
|
if display_name.startswith("CDL"):
|
||||||
|
display_name = display_name[3:] # Remove CDL prefix for patterns
|
||||||
|
|
||||||
|
return IndicatorMetadata(
|
||||||
|
name=cls.talib_function_name,
|
||||||
|
display_name=display_name,
|
||||||
|
description=info.get("display_name", f"TA-Lib {cls.talib_function_name} indicator"),
|
||||||
|
category=category,
|
||||||
|
parameters=parameters,
|
||||||
|
use_cases=[f"Technical analysis using {cls.talib_function_name}"],
|
||||||
|
references=["https://ta-lib.org/function.html"],
|
||||||
|
tags=["talib", category, cls.talib_function_name.lower()]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_input_schema(cls) -> InputSchema:
|
||||||
|
"""
|
||||||
|
Get input schema from TA-Lib function requirements.
|
||||||
|
|
||||||
|
Most TA-Lib functions use OHLCV data, but some use subsets.
|
||||||
|
"""
|
||||||
|
if not TALIB_AVAILABLE:
|
||||||
|
raise ImportError("TA-Lib is not installed")
|
||||||
|
|
||||||
|
func = abstract.Function(cls.talib_function_name)
|
||||||
|
info = func.info
|
||||||
|
input_names = info.get("input_names", {})
|
||||||
|
|
||||||
|
required_columns = []
|
||||||
|
|
||||||
|
# Map TA-Lib input names to our schema
|
||||||
|
if "prices" in input_names:
|
||||||
|
price_inputs = input_names["prices"]
|
||||||
|
if "open" in price_inputs:
|
||||||
|
required_columns.append(
|
||||||
|
ColumnInfo(name="open", type="float", description="Opening price")
|
||||||
|
)
|
||||||
|
if "high" in price_inputs:
|
||||||
|
required_columns.append(
|
||||||
|
ColumnInfo(name="high", type="float", description="High price")
|
||||||
|
)
|
||||||
|
if "low" in price_inputs:
|
||||||
|
required_columns.append(
|
||||||
|
ColumnInfo(name="low", type="float", description="Low price")
|
||||||
|
)
|
||||||
|
if "close" in price_inputs:
|
||||||
|
required_columns.append(
|
||||||
|
ColumnInfo(name="close", type="float", description="Closing price")
|
||||||
|
)
|
||||||
|
if "volume" in price_inputs:
|
||||||
|
required_columns.append(
|
||||||
|
ColumnInfo(name="volume", type="float", description="Trading volume")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle functions that take generic price arrays
|
||||||
|
if "price" in input_names:
|
||||||
|
required_columns.append(
|
||||||
|
ColumnInfo(name="close", type="float", description="Price (typically close)")
|
||||||
|
)
|
||||||
|
|
||||||
|
# If no specific inputs found, assume close price
|
||||||
|
if not required_columns:
|
||||||
|
required_columns.append(
|
||||||
|
ColumnInfo(name="close", type="float", description="Closing price")
|
||||||
|
)
|
||||||
|
|
||||||
|
return InputSchema(required_columns=required_columns)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_output_schema(cls, **params) -> OutputSchema:
|
||||||
|
"""Get output schema from TA-Lib function outputs."""
|
||||||
|
if not TALIB_AVAILABLE:
|
||||||
|
raise ImportError("TA-Lib is not installed")
|
||||||
|
|
||||||
|
func = abstract.Function(cls.talib_function_name)
|
||||||
|
info = func.info
|
||||||
|
output_names = info.get("output_names", [])
|
||||||
|
|
||||||
|
columns = []
|
||||||
|
|
||||||
|
# Most TA-Lib functions output one or more float arrays
|
||||||
|
if isinstance(output_names, list):
|
||||||
|
for output_name in output_names:
|
||||||
|
columns.append(
|
||||||
|
ColumnInfo(
|
||||||
|
name=output_name.lower(),
|
||||||
|
type="float",
|
||||||
|
description=f"{cls.talib_function_name} output: {output_name}",
|
||||||
|
nullable=True # TA-Lib often has NaN for initial periods
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Single output, use function name
|
||||||
|
columns.append(
|
||||||
|
ColumnInfo(
|
||||||
|
name=cls.talib_function_name.lower(),
|
||||||
|
type="float",
|
||||||
|
description=f"{cls.talib_function_name} indicator value",
|
||||||
|
nullable=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return OutputSchema(columns=columns)
|
||||||
|
|
||||||
|
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||||
|
"""Compute indicator using TA-Lib."""
|
||||||
|
# Extract input columns
|
||||||
|
input_data = {}
|
||||||
|
|
||||||
|
# Get the function's expected inputs
|
||||||
|
info = self._talib_func.info
|
||||||
|
input_names = info.get("input_names", {})
|
||||||
|
|
||||||
|
# Prepare input arrays
|
||||||
|
if "prices" in input_names:
|
||||||
|
price_inputs = input_names["prices"]
|
||||||
|
for price_type in price_inputs:
|
||||||
|
column_data = context.get_column(price_type)
|
||||||
|
# Convert to numpy array, replacing None with NaN
|
||||||
|
input_data[price_type] = np.array(
|
||||||
|
[float(v) if v is not None else np.nan for v in column_data]
|
||||||
|
)
|
||||||
|
elif "price" in input_names:
|
||||||
|
# Generic price input, use close
|
||||||
|
column_data = context.get_column("close")
|
||||||
|
input_data["price"] = np.array(
|
||||||
|
[float(v) if v is not None else np.nan for v in column_data]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Default to close if no inputs specified
|
||||||
|
column_data = context.get_column("close")
|
||||||
|
input_data["close"] = np.array(
|
||||||
|
[float(v) if v is not None else np.nan for v in column_data]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set parameters on the function
|
||||||
|
self._talib_func.parameters = self.params
|
||||||
|
|
||||||
|
# Execute TA-Lib function
|
||||||
|
try:
|
||||||
|
output = self._talib_func(input_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"TA-Lib function {self.talib_function_name} failed: {e}")
|
||||||
|
raise ValueError(f"TA-Lib computation failed: {e}")
|
||||||
|
|
||||||
|
# Format output
|
||||||
|
times = context.get_times()
|
||||||
|
output_names = info.get("output_names", [])
|
||||||
|
|
||||||
|
# Handle single vs multiple outputs
|
||||||
|
if isinstance(output, np.ndarray):
|
||||||
|
# Single output
|
||||||
|
output_name = output_names[0].lower() if output_names else self.talib_function_name.lower()
|
||||||
|
result_data = [
|
||||||
|
{
|
||||||
|
"time": times[i],
|
||||||
|
output_name: float(output[i]) if not np.isnan(output[i]) else None
|
||||||
|
}
|
||||||
|
for i in range(len(times))
|
||||||
|
]
|
||||||
|
elif isinstance(output, tuple):
|
||||||
|
# Multiple outputs
|
||||||
|
result_data = []
|
||||||
|
for i in range(len(times)):
|
||||||
|
row = {"time": times[i]}
|
||||||
|
for j, output_array in enumerate(output):
|
||||||
|
output_name = output_names[j].lower() if j < len(output_names) else f"output_{j}"
|
||||||
|
row[output_name] = float(output_array[i]) if not np.isnan(output_array[i]) else None
|
||||||
|
result_data.append(row)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected TA-Lib output type: {type(output)}")
|
||||||
|
|
||||||
|
return ComputeResult(
|
||||||
|
data=result_data,
|
||||||
|
is_partial=context.is_incremental
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_talib_indicator_class(func_name: str) -> type:
|
||||||
|
"""
|
||||||
|
Dynamically create an Indicator class for a TA-Lib function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func_name: TA-Lib function name (e.g., 'SMA', 'RSI')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Indicator class for this function
|
||||||
|
"""
|
||||||
|
return type(
|
||||||
|
f"TALib_{func_name}",
|
||||||
|
(TALibIndicator,),
|
||||||
|
{"talib_function_name": func_name}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register_all_talib_indicators(registry) -> int:
|
||||||
|
"""
|
||||||
|
Auto-register all available TA-Lib indicators with the registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry: IndicatorRegistry instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of indicators registered
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If TA-Lib is not installed
|
||||||
|
"""
|
||||||
|
if not TALIB_AVAILABLE:
|
||||||
|
logger.warning(
|
||||||
|
"TA-Lib is not installed. Skipping TA-Lib indicator registration. "
|
||||||
|
"Install TA-Lib C library and Python wrapper to enable TA-Lib indicators."
|
||||||
|
)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Get all TA-Lib functions
|
||||||
|
func_groups = talib.get_function_groups()
|
||||||
|
all_functions = []
|
||||||
|
for group, functions in func_groups.items():
|
||||||
|
all_functions.extend(functions)
|
||||||
|
|
||||||
|
# Remove duplicates
|
||||||
|
all_functions = sorted(set(all_functions))
|
||||||
|
|
||||||
|
registered_count = 0
|
||||||
|
for func_name in all_functions:
|
||||||
|
try:
|
||||||
|
# Create indicator class for this function
|
||||||
|
indicator_class = create_talib_indicator_class(func_name)
|
||||||
|
|
||||||
|
# Register with the registry
|
||||||
|
registry.register(indicator_class)
|
||||||
|
registered_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to register TA-Lib function {func_name}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(f"Registered {registered_count} TA-Lib indicators")
|
||||||
|
return registered_count
|
||||||
|
|
||||||
|
|
||||||
|
def get_talib_version() -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get the installed TA-Lib version.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Version string or None if not installed
|
||||||
|
"""
|
||||||
|
if TALIB_AVAILABLE:
|
||||||
|
return talib.__version__
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def is_talib_available() -> bool:
|
||||||
|
"""Check if TA-Lib is available."""
|
||||||
|
return TALIB_AVAILABLE
|
||||||
@@ -20,13 +20,14 @@ from gateway.hub import Gateway
|
|||||||
from gateway.channels.websocket import WebSocketChannel
|
from gateway.channels.websocket import WebSocketChannel
|
||||||
from gateway.protocol import WebSocketAgentUserMessage
|
from gateway.protocol import WebSocketAgentUserMessage
|
||||||
from agent.core import create_agent
|
from agent.core import create_agent
|
||||||
from agent.tools import set_registry, set_datasource_registry
|
from agent.tools import set_registry, set_datasource_registry, set_indicator_registry
|
||||||
from schema.order_spec import SwapOrder
|
from schema.order_spec import SwapOrder
|
||||||
from schema.chart_state import ChartState
|
from schema.chart_state import ChartState
|
||||||
from datasource.registry import DataSourceRegistry
|
from datasource.registry import DataSourceRegistry
|
||||||
from datasource.subscription_manager import SubscriptionManager
|
from datasource.subscription_manager import SubscriptionManager
|
||||||
from datasource.websocket_handler import DatafeedWebSocketHandler
|
from datasource.websocket_handler import DatafeedWebSocketHandler
|
||||||
from secrets_manager import SecretsStore, InvalidMasterPassword
|
from secrets_manager import SecretsStore, InvalidMasterPassword
|
||||||
|
from indicator import IndicatorRegistry, register_all_talib_indicators
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -53,6 +54,9 @@ agent_executor = None
|
|||||||
datasource_registry = DataSourceRegistry()
|
datasource_registry = DataSourceRegistry()
|
||||||
subscription_manager = SubscriptionManager()
|
subscription_manager = SubscriptionManager()
|
||||||
|
|
||||||
|
# Indicator infrastructure
|
||||||
|
indicator_registry = IndicatorRegistry()
|
||||||
|
|
||||||
# Global secrets store
|
# Global secrets store
|
||||||
secrets_store = SecretsStore()
|
secrets_store = SecretsStore()
|
||||||
|
|
||||||
@@ -80,6 +84,14 @@ async def lifespan(app: FastAPI):
|
|||||||
logger.warning(f"CCXT not available: {e}. Only demo source will be available.")
|
logger.warning(f"CCXT not available: {e}. Only demo source will be available.")
|
||||||
logger.info("To use real exchange data, install ccxt: pip install ccxt>=4.0.0")
|
logger.info("To use real exchange data, install ccxt: pip install ccxt>=4.0.0")
|
||||||
|
|
||||||
|
# Initialize indicator registry with all TA-Lib indicators
|
||||||
|
try:
|
||||||
|
indicator_count = register_all_talib_indicators(indicator_registry)
|
||||||
|
logger.info(f"Indicator registry initialized with {indicator_count} TA-Lib indicators")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to register TA-Lib indicators: {e}")
|
||||||
|
logger.info("TA-Lib indicators will not be available. Install TA-Lib C library and Python wrapper to enable.")
|
||||||
|
|
||||||
# Get API keys from secrets store if unlocked, otherwise fall back to environment
|
# Get API keys from secrets store if unlocked, otherwise fall back to environment
|
||||||
anthropic_api_key = None
|
anthropic_api_key = None
|
||||||
|
|
||||||
@@ -101,6 +113,7 @@ async def lifespan(app: FastAPI):
|
|||||||
# Set the registries for agent tools
|
# Set the registries for agent tools
|
||||||
set_registry(registry)
|
set_registry(registry)
|
||||||
set_datasource_registry(datasource_registry)
|
set_datasource_registry(datasource_registry)
|
||||||
|
set_indicator_registry(indicator_registry)
|
||||||
|
|
||||||
# Create and initialize agent
|
# Create and initialize agent
|
||||||
agent_executor = create_agent(
|
agent_executor = create_agent(
|
||||||
|
|||||||
@@ -40,6 +40,58 @@ class Exchange(StrEnum):
|
|||||||
UNISWAP_V3 = "UniswapV3"
|
UNISWAP_V3 = "UniswapV3"
|
||||||
|
|
||||||
|
|
||||||
|
class Side(StrEnum):
|
||||||
|
"""Order side: buy or sell"""
|
||||||
|
BUY = "BUY"
|
||||||
|
SELL = "SELL"
|
||||||
|
|
||||||
|
|
||||||
|
class AmountType(StrEnum):
|
||||||
|
"""Whether the order amount refers to base or quote currency"""
|
||||||
|
BASE = "BASE" # Amount is in base currency (e.g., BTC in BTC/USD)
|
||||||
|
QUOTE = "QUOTE" # Amount is in quote currency (e.g., USD in BTC/USD)
|
||||||
|
|
||||||
|
|
||||||
|
class TimeInForce(StrEnum):
|
||||||
|
"""Order lifetime specification"""
|
||||||
|
GTC = "GTC" # Good Till Cancel
|
||||||
|
IOC = "IOC" # Immediate or Cancel
|
||||||
|
FOK = "FOK" # Fill or Kill
|
||||||
|
DAY = "DAY" # Good for trading day
|
||||||
|
GTD = "GTD" # Good Till Date
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionalOrderMode(StrEnum):
|
||||||
|
"""How conditional orders behave on partial fills"""
|
||||||
|
NEW_PER_FILL = "NEW_PER_FILL" # Create new conditional order per each fill
|
||||||
|
UNIFIED_ADJUSTING = "UNIFIED_ADJUSTING" # Single conditional order that adjusts amount
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerType(StrEnum):
|
||||||
|
"""Type of conditional trigger"""
|
||||||
|
STOP_LOSS = "STOP_LOSS"
|
||||||
|
TAKE_PROFIT = "TAKE_PROFIT"
|
||||||
|
STOP_LIMIT = "STOP_LIMIT"
|
||||||
|
TRAILING_STOP = "TRAILING_STOP"
|
||||||
|
|
||||||
|
|
||||||
|
class TickSpacingMode(StrEnum):
|
||||||
|
"""How price tick spacing is determined"""
|
||||||
|
FIXED = "FIXED" # Fixed tick size
|
||||||
|
DYNAMIC = "DYNAMIC" # Tick size varies by price level
|
||||||
|
CONTINUOUS = "CONTINUOUS" # No tick restrictions
|
||||||
|
|
||||||
|
|
||||||
|
class AssetType(StrEnum):
|
||||||
|
"""Type of tradeable asset"""
|
||||||
|
SPOT = "SPOT" # Spot/cash market
|
||||||
|
MARGIN = "MARGIN" # Margin trading
|
||||||
|
PERP = "PERP" # Perpetual futures
|
||||||
|
FUTURE = "FUTURE" # Dated futures
|
||||||
|
OPTION = "OPTION" # Options
|
||||||
|
SYNTHETIC = "SYNTHETIC" # Synthetic/derived instruments
|
||||||
|
|
||||||
|
|
||||||
class OcoMode(StrEnum):
|
class OcoMode(StrEnum):
|
||||||
NO_OCO = "NO_OCO"
|
NO_OCO = "NO_OCO"
|
||||||
CANCEL_ON_PARTIAL_FILL = "CANCEL_ON_PARTIAL_FILL"
|
CANCEL_ON_PARTIAL_FILL = "CANCEL_ON_PARTIAL_FILL"
|
||||||
@@ -96,6 +148,126 @@ class TrancheStatus(BaseModel):
|
|||||||
endTime: Uint32 = Field(description="Concrete end timestamp")
|
endTime: Uint32 = Field(description="Concrete end timestamp")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Standard Order Models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ConditionalTrigger(BaseModel):
|
||||||
|
"""Conditional order trigger (stop-loss, take-profit, etc.)"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
trigger_type: TriggerType
|
||||||
|
trigger_price: Float = Field(description="Price at which conditional order activates")
|
||||||
|
trailing_delta: Float | None = Field(default=None, description="For trailing stops: delta from peak/trough")
|
||||||
|
|
||||||
|
|
||||||
|
class AmountConstraints(BaseModel):
|
||||||
|
"""Constraints on order amounts for a symbol"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
min_amount: Float = Field(description="Minimum order amount")
|
||||||
|
max_amount: Float = Field(description="Maximum order amount")
|
||||||
|
step_size: Float = Field(description="Amount increment granularity")
|
||||||
|
|
||||||
|
|
||||||
|
class PriceConstraints(BaseModel):
|
||||||
|
"""Constraints on order pricing for a symbol"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
tick_spacing_mode: TickSpacingMode
|
||||||
|
tick_size: Float | None = Field(default=None, description="Fixed tick size (if FIXED mode)")
|
||||||
|
min_price: Float | None = Field(default=None, description="Minimum allowed price")
|
||||||
|
max_price: Float | None = Field(default=None, description="Maximum allowed price")
|
||||||
|
|
||||||
|
|
||||||
|
class MarketCapabilities(BaseModel):
|
||||||
|
"""Describes what order features a market supports"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
supported_sides: list[Side] = Field(description="Supported order sides (usually both)")
|
||||||
|
supported_amount_types: list[AmountType] = Field(description="Whether BASE, QUOTE, or both amounts are supported")
|
||||||
|
supports_market_orders: bool = Field(description="Whether market orders are supported")
|
||||||
|
supports_limit_orders: bool = Field(description="Whether limit orders are supported")
|
||||||
|
supported_time_in_force: list[TimeInForce] = Field(description="Supported order lifetimes")
|
||||||
|
supports_conditional_orders: bool = Field(description="Whether stop-loss/take-profit are supported")
|
||||||
|
supported_trigger_types: list[TriggerType] = Field(default_factory=list, description="Supported trigger types")
|
||||||
|
supports_post_only: bool = Field(default=False, description="Whether post-only orders are supported")
|
||||||
|
supports_reduce_only: bool = Field(default=False, description="Whether reduce-only orders are supported")
|
||||||
|
supports_iceberg: bool = Field(default=False, description="Whether iceberg orders are supported")
|
||||||
|
market_order_amount_type: AmountType | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Required amount type for market orders (some DEXs require exact-in)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SymbolMetadata(BaseModel):
|
||||||
|
"""Complete metadata describing a tradeable symbol/market"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
symbol_id: str = Field(description="Unique symbol identifier")
|
||||||
|
base_asset: str = Field(description="Base asset (e.g., 'BTC')")
|
||||||
|
quote_asset: str = Field(description="Quote asset (e.g., 'USD')")
|
||||||
|
asset_type: AssetType = Field(description="Type of market")
|
||||||
|
exchange: str = Field(description="Exchange identifier")
|
||||||
|
|
||||||
|
amount_constraints: AmountConstraints
|
||||||
|
price_constraints: PriceConstraints
|
||||||
|
capabilities: MarketCapabilities
|
||||||
|
|
||||||
|
contract_size: Float | None = Field(default=None, description="For futures/options: contract multiplier")
|
||||||
|
settlement_asset: str | None = Field(default=None, description="For derivatives: settlement currency")
|
||||||
|
expiry_timestamp: Uint64 | None = Field(default=None, description="For dated futures/options: expiration")
|
||||||
|
|
||||||
|
|
||||||
|
class StandardOrder(BaseModel):
|
||||||
|
"""Standard order specification for exchange kernels"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
symbol_id: str = Field(description="Symbol to trade")
|
||||||
|
side: Side = Field(description="Buy or sell")
|
||||||
|
amount: Float = Field(description="Order amount")
|
||||||
|
amount_type: AmountType = Field(description="Whether amount is BASE or QUOTE currency")
|
||||||
|
|
||||||
|
limit_price: Float | None = Field(default=None, description="Limit price (None = market order)")
|
||||||
|
time_in_force: TimeInForce = Field(default=TimeInForce.GTC, description="Order lifetime")
|
||||||
|
good_till_date: Uint64 | None = Field(default=None, description="Expiry timestamp for GTD orders")
|
||||||
|
|
||||||
|
conditional_trigger: ConditionalTrigger | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Stop-loss/take-profit trigger"
|
||||||
|
)
|
||||||
|
conditional_mode: ConditionalOrderMode | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="How conditional orders behave on partial fills"
|
||||||
|
)
|
||||||
|
|
||||||
|
reduce_only: bool = Field(default=False, description="Only reduce existing position")
|
||||||
|
post_only: bool = Field(default=False, description="Only make, never take")
|
||||||
|
iceberg_qty: Float | None = Field(default=None, description="Visible amount for iceberg orders")
|
||||||
|
|
||||||
|
client_order_id: str | None = Field(default=None, description="Client-specified order ID")
|
||||||
|
|
||||||
|
|
||||||
|
class StandardOrderStatus(BaseModel):
|
||||||
|
"""Current status of a standard order"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
order: StandardOrder
|
||||||
|
order_id: str = Field(description="Exchange-assigned order ID")
|
||||||
|
status: str = Field(description="Order status: NEW, PARTIALLY_FILLED, FILLED, CANCELED, REJECTED, EXPIRED")
|
||||||
|
filled_amount: Float = Field(description="Amount filled so far")
|
||||||
|
average_fill_price: Float = Field(description="Average execution price")
|
||||||
|
created_at: Uint64 = Field(description="Order creation timestamp")
|
||||||
|
updated_at: Uint64 = Field(description="Last update timestamp")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Order models
|
# Order models
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -117,7 +289,22 @@ class SwapOrder(BaseModel):
|
|||||||
tranches: list[Tranche] = Field(min_length=1)
|
tranches: list[Tranche] = Field(min_length=1)
|
||||||
|
|
||||||
|
|
||||||
|
class StandardOrderGroup(BaseModel):
|
||||||
|
"""Group of orders with OCO (One-Cancels-Other) relationship"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
mode: OcoMode
|
||||||
|
orders: list[StandardOrder] = Field(min_length=1)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Legacy swap order models (kept for backward compatibility)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
class OcoGroup(BaseModel):
|
class OcoGroup(BaseModel):
|
||||||
|
"""DEPRECATED: Use StandardOrderGroup instead"""
|
||||||
|
|
||||||
model_config = {"extra": "forbid"}
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
mode: OcoMode
|
mode: OcoMode
|
||||||
|
|||||||
40
backend/src/secrets_manager/__init__.py
Normal file
40
backend/src/secrets_manager/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""
|
||||||
|
Encrypted secrets management with master password protection.
|
||||||
|
|
||||||
|
This module provides secure storage for sensitive configuration like API keys,
|
||||||
|
using Argon2id for password-based key derivation and Fernet (AES-256) for encryption.
|
||||||
|
|
||||||
|
Basic usage:
|
||||||
|
from secrets_manager import SecretsStore
|
||||||
|
|
||||||
|
# First time setup
|
||||||
|
store = SecretsStore()
|
||||||
|
store.initialize("my-master-password")
|
||||||
|
store.set("ANTHROPIC_API_KEY", "sk-ant-...")
|
||||||
|
|
||||||
|
# Later usage
|
||||||
|
store = SecretsStore()
|
||||||
|
store.unlock("my-master-password")
|
||||||
|
api_key = store.get("ANTHROPIC_API_KEY")
|
||||||
|
|
||||||
|
Command-line interface:
|
||||||
|
python -m secrets_manager.cli init
|
||||||
|
python -m secrets_manager.cli set KEY VALUE
|
||||||
|
python -m secrets_manager.cli get KEY
|
||||||
|
python -m secrets_manager.cli list
|
||||||
|
python -m secrets_manager.cli change-password
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .store import (
|
||||||
|
SecretsStore,
|
||||||
|
SecretsStoreError,
|
||||||
|
SecretsStoreLocked,
|
||||||
|
InvalidMasterPassword,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SecretsStore",
|
||||||
|
"SecretsStoreError",
|
||||||
|
"SecretsStoreLocked",
|
||||||
|
"InvalidMasterPassword",
|
||||||
|
]
|
||||||
374
backend/src/secrets_manager/cli.py
Normal file
374
backend/src/secrets_manager/cli.py
Normal file
@@ -0,0 +1,374 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Command-line interface for managing the encrypted secrets store.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m secrets.cli init # Initialize new secrets store
|
||||||
|
python -m secrets.cli set KEY VALUE # Set a secret
|
||||||
|
python -m secrets.cli get KEY # Get a secret
|
||||||
|
python -m secrets.cli delete KEY # Delete a secret
|
||||||
|
python -m secrets.cli list # List all secret keys
|
||||||
|
python -m secrets.cli change-password # Change master password
|
||||||
|
python -m secrets.cli export FILE # Export encrypted backup
|
||||||
|
python -m secrets.cli import FILE # Import encrypted backup
|
||||||
|
python -m secrets.cli migrate-from-env # Migrate secrets from .env file
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
import getpass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from .store import SecretsStore, SecretsStoreError, InvalidMasterPassword
|
||||||
|
|
||||||
|
|
||||||
|
def get_password(prompt: str = "Master password: ", confirm: bool = False) -> str:
|
||||||
|
"""
|
||||||
|
Securely get password from user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: Password prompt
|
||||||
|
confirm: If True, ask for confirmation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Password string
|
||||||
|
"""
|
||||||
|
password = getpass.getpass(prompt)
|
||||||
|
|
||||||
|
if confirm:
|
||||||
|
confirm_password = getpass.getpass("Confirm password: ")
|
||||||
|
if password != confirm_password:
|
||||||
|
print("Error: Passwords do not match", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
return password
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_init(args):
|
||||||
|
"""Initialize a new secrets store."""
|
||||||
|
store = SecretsStore()
|
||||||
|
|
||||||
|
if store.is_initialized:
|
||||||
|
print("Error: Secrets store is already initialized", file=sys.stderr)
|
||||||
|
print(f"Location: {store.secrets_file}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
password = get_password("Create master password: ", confirm=True)
|
||||||
|
|
||||||
|
if len(password) < 8:
|
||||||
|
print("Error: Password must be at least 8 characters", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
store.initialize(password)
|
||||||
|
print(f"Secrets store initialized at {store.secrets_file}")
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_set(args):
|
||||||
|
"""Set a secret value."""
|
||||||
|
store = SecretsStore()
|
||||||
|
|
||||||
|
if not store.is_initialized:
|
||||||
|
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
password = get_password()
|
||||||
|
|
||||||
|
try:
|
||||||
|
store.unlock(password)
|
||||||
|
except InvalidMasterPassword:
|
||||||
|
print("Error: Invalid master password", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
store.set(args.key, args.value)
|
||||||
|
print(f"✓ Secret '{args.key}' saved")
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_get(args):
|
||||||
|
"""Get a secret value."""
|
||||||
|
store = SecretsStore()
|
||||||
|
|
||||||
|
if not store.is_initialized:
|
||||||
|
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
password = get_password()
|
||||||
|
|
||||||
|
try:
|
||||||
|
store.unlock(password)
|
||||||
|
except InvalidMasterPassword:
|
||||||
|
print("Error: Invalid master password", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
value = store.get(args.key)
|
||||||
|
if value is None:
|
||||||
|
print(f"Error: Secret '{args.key}' not found", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Print to stdout (can be captured)
|
||||||
|
print(value)
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_delete(args):
|
||||||
|
"""Delete a secret."""
|
||||||
|
store = SecretsStore()
|
||||||
|
|
||||||
|
if not store.is_initialized:
|
||||||
|
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
password = get_password()
|
||||||
|
|
||||||
|
try:
|
||||||
|
store.unlock(password)
|
||||||
|
except InvalidMasterPassword:
|
||||||
|
print("Error: Invalid master password", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if store.delete(args.key):
|
||||||
|
print(f"✓ Secret '{args.key}' deleted")
|
||||||
|
else:
|
||||||
|
print(f"Error: Secret '{args.key}' not found", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_list(args):
|
||||||
|
"""List all secret keys."""
|
||||||
|
store = SecretsStore()
|
||||||
|
|
||||||
|
if not store.is_initialized:
|
||||||
|
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
password = get_password()
|
||||||
|
|
||||||
|
try:
|
||||||
|
store.unlock(password)
|
||||||
|
except InvalidMasterPassword:
|
||||||
|
print("Error: Invalid master password", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
keys = store.list_keys()
|
||||||
|
|
||||||
|
if not keys:
|
||||||
|
print("No secrets stored")
|
||||||
|
else:
|
||||||
|
print(f"Stored secrets ({len(keys)}):")
|
||||||
|
for key in sorted(keys):
|
||||||
|
# Show key and value length for verification
|
||||||
|
value = store.get(key)
|
||||||
|
value_str = str(value)
|
||||||
|
value_preview = value_str[:50] + "..." if len(value_str) > 50 else value_str
|
||||||
|
print(f" {key}: {value_preview}")
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_change_password(args):
|
||||||
|
"""Change the master password."""
|
||||||
|
store = SecretsStore()
|
||||||
|
|
||||||
|
if not store.is_initialized:
|
||||||
|
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
current_password = get_password("Current master password: ")
|
||||||
|
new_password = get_password("New master password: ", confirm=True)
|
||||||
|
|
||||||
|
if len(new_password) < 8:
|
||||||
|
print("Error: Password must be at least 8 characters", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
store.change_master_password(current_password, new_password)
|
||||||
|
except InvalidMasterPassword:
|
||||||
|
print("Error: Invalid current password", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_export(args):
|
||||||
|
"""Export encrypted secrets to a backup file."""
|
||||||
|
store = SecretsStore()
|
||||||
|
|
||||||
|
if not store.is_initialized:
|
||||||
|
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
output_path = Path(args.file)
|
||||||
|
|
||||||
|
if output_path.exists() and not args.force:
|
||||||
|
print(f"Error: File {output_path} already exists. Use --force to overwrite.", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
store.export_encrypted(output_path)
|
||||||
|
except SecretsStoreError as e:
|
||||||
|
print(f"Error: {e}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_import(args):
|
||||||
|
"""Import encrypted secrets from a backup file."""
|
||||||
|
store = SecretsStore()
|
||||||
|
|
||||||
|
if not store.is_initialized:
|
||||||
|
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
input_path = Path(args.file)
|
||||||
|
|
||||||
|
if not input_path.exists():
|
||||||
|
print(f"Error: File {input_path} does not exist", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
password = get_password()
|
||||||
|
|
||||||
|
try:
|
||||||
|
store.import_encrypted(input_path, password)
|
||||||
|
except InvalidMasterPassword:
|
||||||
|
print("Error: Invalid master password or incompatible backup", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
except SecretsStoreError as e:
|
||||||
|
print(f"Error: {e}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_migrate_from_env(args):
|
||||||
|
"""Migrate secrets from .env file to encrypted store."""
|
||||||
|
store = SecretsStore()
|
||||||
|
|
||||||
|
if not store.is_initialized:
|
||||||
|
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Look for .env file
|
||||||
|
backend_root = Path(__file__).parent.parent.parent
|
||||||
|
env_file = backend_root / ".env"
|
||||||
|
|
||||||
|
if not env_file.exists():
|
||||||
|
print(f"Error: .env file not found at {env_file}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
password = get_password()
|
||||||
|
|
||||||
|
try:
|
||||||
|
store.unlock(password)
|
||||||
|
except InvalidMasterPassword:
|
||||||
|
print("Error: Invalid master password", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Parse .env file (simple parser - doesn't handle all edge cases)
|
||||||
|
migrated = 0
|
||||||
|
skipped = 0
|
||||||
|
|
||||||
|
with open(env_file) as f:
|
||||||
|
for line_num, line in enumerate(f, 1):
|
||||||
|
line = line.strip()
|
||||||
|
|
||||||
|
# Skip empty lines and comments
|
||||||
|
if not line or line.startswith('#'):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Parse KEY=VALUE format
|
||||||
|
if '=' not in line:
|
||||||
|
print(f"Warning: Skipping invalid line {line_num}: {line}", file=sys.stderr)
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
key, value = line.split('=', 1)
|
||||||
|
key = key.strip()
|
||||||
|
value = value.strip()
|
||||||
|
|
||||||
|
# Remove quotes if present
|
||||||
|
if value.startswith('"') and value.endswith('"'):
|
||||||
|
value = value[1:-1]
|
||||||
|
elif value.startswith("'") and value.endswith("'"):
|
||||||
|
value = value[1:-1]
|
||||||
|
|
||||||
|
# Check if key already exists
|
||||||
|
existing = store.get(key)
|
||||||
|
if existing is not None:
|
||||||
|
print(f"Warning: Secret '{key}' already exists, skipping", file=sys.stderr)
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
store.set(key, value)
|
||||||
|
print(f"✓ Migrated: {key}")
|
||||||
|
migrated += 1
|
||||||
|
|
||||||
|
print(f"\nMigration complete: {migrated} secrets migrated, {skipped} skipped")
|
||||||
|
|
||||||
|
if not args.keep_env:
|
||||||
|
# Ask for confirmation before deleting .env
|
||||||
|
confirm = input(f"\nDelete {env_file}? [y/N]: ").strip().lower()
|
||||||
|
if confirm == 'y':
|
||||||
|
env_file.unlink()
|
||||||
|
print(f"✓ Deleted {env_file}")
|
||||||
|
else:
|
||||||
|
print(f"Kept {env_file} (consider deleting it manually)")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main CLI entry point."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Manage encrypted secrets store",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
subparsers = parser.add_subparsers(dest='command', help='Command to run')
|
||||||
|
subparsers.required = True
|
||||||
|
|
||||||
|
# init
|
||||||
|
parser_init = subparsers.add_parser('init', help='Initialize new secrets store')
|
||||||
|
parser_init.set_defaults(func=cmd_init)
|
||||||
|
|
||||||
|
# set
|
||||||
|
parser_set = subparsers.add_parser('set', help='Set a secret value')
|
||||||
|
parser_set.add_argument('key', help='Secret key name')
|
||||||
|
parser_set.add_argument('value', help='Secret value')
|
||||||
|
parser_set.set_defaults(func=cmd_set)
|
||||||
|
|
||||||
|
# get
|
||||||
|
parser_get = subparsers.add_parser('get', help='Get a secret value')
|
||||||
|
parser_get.add_argument('key', help='Secret key name')
|
||||||
|
parser_get.set_defaults(func=cmd_get)
|
||||||
|
|
||||||
|
# delete
|
||||||
|
parser_delete = subparsers.add_parser('delete', help='Delete a secret')
|
||||||
|
parser_delete.add_argument('key', help='Secret key name')
|
||||||
|
parser_delete.set_defaults(func=cmd_delete)
|
||||||
|
|
||||||
|
# list
|
||||||
|
parser_list = subparsers.add_parser('list', help='List all secret keys')
|
||||||
|
parser_list.set_defaults(func=cmd_list)
|
||||||
|
|
||||||
|
# change-password
|
||||||
|
parser_change = subparsers.add_parser('change-password', help='Change master password')
|
||||||
|
parser_change.set_defaults(func=cmd_change_password)
|
||||||
|
|
||||||
|
# export
|
||||||
|
parser_export = subparsers.add_parser('export', help='Export encrypted backup')
|
||||||
|
parser_export.add_argument('file', help='Output file path')
|
||||||
|
parser_export.add_argument('--force', action='store_true', help='Overwrite existing file')
|
||||||
|
parser_export.set_defaults(func=cmd_export)
|
||||||
|
|
||||||
|
# import
|
||||||
|
parser_import = subparsers.add_parser('import', help='Import encrypted backup')
|
||||||
|
parser_import.add_argument('file', help='Input file path')
|
||||||
|
parser_import.set_defaults(func=cmd_import)
|
||||||
|
|
||||||
|
# migrate-from-env
|
||||||
|
parser_migrate = subparsers.add_parser('migrate-from-env', help='Migrate from .env file')
|
||||||
|
parser_migrate.add_argument('--keep-env', action='store_true', help='Keep .env file after migration')
|
||||||
|
parser_migrate.set_defaults(func=cmd_migrate_from_env)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
args.func(args)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nAborted", file=sys.stderr)
|
||||||
|
sys.exit(130)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
144
backend/src/secrets_manager/crypto.py
Normal file
144
backend/src/secrets_manager/crypto.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
"""
|
||||||
|
Cryptographic utilities for secrets management.
|
||||||
|
|
||||||
|
Uses Argon2id for password-based key derivation and Fernet for encryption.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import secrets as secrets_module
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from argon2 import PasswordHasher
|
||||||
|
from argon2.low_level import hash_secret_raw, Type
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
import base64
|
||||||
|
|
||||||
|
|
||||||
|
# Argon2id parameters (OWASP recommended for password-based KDF)
|
||||||
|
# These provide strong defense against GPU/ASIC attacks
|
||||||
|
ARGON2_TIME_COST = 3 # iterations
|
||||||
|
ARGON2_MEMORY_COST = 65536 # 64 MB
|
||||||
|
ARGON2_PARALLELISM = 4 # threads
|
||||||
|
ARGON2_HASH_LENGTH = 32 # bytes (256 bits for Fernet key)
|
||||||
|
ARGON2_SALT_LENGTH = 16 # bytes (128 bits)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_salt() -> bytes:
|
||||||
|
"""Generate a cryptographically secure random salt."""
|
||||||
|
return secrets_module.token_bytes(ARGON2_SALT_LENGTH)
|
||||||
|
|
||||||
|
|
||||||
|
def derive_key_from_password(password: str, salt: bytes) -> bytes:
|
||||||
|
"""
|
||||||
|
Derive an encryption key from a password using Argon2id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
password: The master password
|
||||||
|
salt: The salt (must be consistent for the same password to work)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
32-byte key suitable for Fernet encryption
|
||||||
|
"""
|
||||||
|
password_bytes = password.encode('utf-8')
|
||||||
|
|
||||||
|
# Use Argon2id (hybrid mode - best of Argon2i and Argon2d)
|
||||||
|
raw_hash = hash_secret_raw(
|
||||||
|
secret=password_bytes,
|
||||||
|
salt=salt,
|
||||||
|
time_cost=ARGON2_TIME_COST,
|
||||||
|
memory_cost=ARGON2_MEMORY_COST,
|
||||||
|
parallelism=ARGON2_PARALLELISM,
|
||||||
|
hash_len=ARGON2_HASH_LENGTH,
|
||||||
|
type=Type.ID # Argon2id
|
||||||
|
)
|
||||||
|
|
||||||
|
return raw_hash
|
||||||
|
|
||||||
|
|
||||||
|
def create_fernet(key: bytes) -> Fernet:
|
||||||
|
"""
|
||||||
|
Create a Fernet cipher instance from a raw key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: 32-byte raw key from Argon2id
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Fernet instance for encryption/decryption
|
||||||
|
"""
|
||||||
|
# Fernet requires a URL-safe base64-encoded 32-byte key
|
||||||
|
fernet_key = base64.urlsafe_b64encode(key)
|
||||||
|
return Fernet(fernet_key)
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt_data(data: bytes, key: bytes) -> bytes:
|
||||||
|
"""
|
||||||
|
Encrypt data using Fernet (AES-256-CBC).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Raw bytes to encrypt
|
||||||
|
key: 32-byte encryption key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Encrypted data (includes IV and auth tag)
|
||||||
|
"""
|
||||||
|
fernet = create_fernet(key)
|
||||||
|
return fernet.encrypt(data)
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_data(encrypted_data: bytes, key: bytes) -> bytes:
|
||||||
|
"""
|
||||||
|
Decrypt data using Fernet.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encrypted_data: Encrypted bytes from encrypt_data
|
||||||
|
key: 32-byte encryption key (must match encryption key)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decrypted raw bytes
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
cryptography.fernet.InvalidToken: If decryption fails (wrong key/corrupted data)
|
||||||
|
"""
|
||||||
|
fernet = create_fernet(key)
|
||||||
|
return fernet.decrypt(encrypted_data)
|
||||||
|
|
||||||
|
|
||||||
|
def create_verification_hash(password: str, salt: bytes) -> str:
|
||||||
|
"""
|
||||||
|
Create a verification hash to check if a password is correct.
|
||||||
|
|
||||||
|
This is NOT for storing the password - it's for verifying the password
|
||||||
|
unlocks the correct key without trying to decrypt the entire secrets file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
password: The master password
|
||||||
|
salt: The salt used for key derivation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Base64-encoded hash for verification
|
||||||
|
"""
|
||||||
|
# Derive key and hash it again for verification
|
||||||
|
key = derive_key_from_password(password, salt)
|
||||||
|
|
||||||
|
# Simple hash of the key for verification (not security critical since
|
||||||
|
# the key itself is already derived from Argon2id)
|
||||||
|
verification = base64.b64encode(key[:16]).decode('ascii')
|
||||||
|
|
||||||
|
return verification
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(password: str, salt: bytes, verification_hash: str) -> bool:
|
||||||
|
"""
|
||||||
|
Verify a password against a verification hash.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
password: Password to verify
|
||||||
|
salt: Salt used for key derivation
|
||||||
|
verification_hash: Expected verification hash
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if password is correct, False otherwise
|
||||||
|
"""
|
||||||
|
computed_hash = create_verification_hash(password, salt)
|
||||||
|
|
||||||
|
# Constant-time comparison to prevent timing attacks
|
||||||
|
return secrets_module.compare_digest(computed_hash, verification_hash)
|
||||||
406
backend/src/secrets_manager/store.py
Normal file
406
backend/src/secrets_manager/store.py
Normal file
@@ -0,0 +1,406 @@
|
|||||||
|
"""
|
||||||
|
Encrypted secrets store with master password protection.
|
||||||
|
|
||||||
|
The secrets are stored in an encrypted file, with the encryption key derived
|
||||||
|
from a master password using Argon2id. The master password can be changed
|
||||||
|
without re-encrypting all secrets.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import stat
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional, Any
|
||||||
|
|
||||||
|
from cryptography.fernet import InvalidToken
|
||||||
|
|
||||||
|
from .crypto import (
|
||||||
|
generate_salt,
|
||||||
|
derive_key_from_password,
|
||||||
|
encrypt_data,
|
||||||
|
decrypt_data,
|
||||||
|
create_verification_hash,
|
||||||
|
verify_password,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SecretsStoreError(Exception):
|
||||||
|
"""Base exception for secrets store errors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SecretsStoreLocked(SecretsStoreError):
|
||||||
|
"""Raised when trying to access secrets while store is locked."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidMasterPassword(SecretsStoreError):
|
||||||
|
"""Raised when master password is incorrect."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SecretsStore:
|
||||||
|
"""
|
||||||
|
Encrypted secrets store with master password protection.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Initialize (first time)
|
||||||
|
store = SecretsStore()
|
||||||
|
store.initialize("my-secure-password")
|
||||||
|
|
||||||
|
# Unlock
|
||||||
|
store = SecretsStore()
|
||||||
|
store.unlock("my-secure-password")
|
||||||
|
|
||||||
|
# Access secrets
|
||||||
|
api_key = store.get("ANTHROPIC_API_KEY")
|
||||||
|
store.set("NEW_SECRET", "secret-value")
|
||||||
|
|
||||||
|
# Change master password
|
||||||
|
store.change_master_password("my-secure-password", "new-password")
|
||||||
|
|
||||||
|
# Lock when done
|
||||||
|
store.lock()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data_dir: Optional[Path] = None):
|
||||||
|
"""
|
||||||
|
Initialize secrets store.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dir: Directory for secrets files (defaults to backend/data)
|
||||||
|
"""
|
||||||
|
if data_dir is None:
|
||||||
|
# Default to backend/data
|
||||||
|
backend_root = Path(__file__).parent.parent.parent
|
||||||
|
data_dir = backend_root / "data"
|
||||||
|
|
||||||
|
self.data_dir = Path(data_dir)
|
||||||
|
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
self.master_key_file = self.data_dir / ".master.key"
|
||||||
|
self.secrets_file = self.data_dir / "secrets.enc"
|
||||||
|
|
||||||
|
# Runtime state
|
||||||
|
self._encryption_key: Optional[bytes] = None
|
||||||
|
self._secrets: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_initialized(self) -> bool:
|
||||||
|
"""Check if the secrets store has been initialized."""
|
||||||
|
return self.master_key_file.exists()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_unlocked(self) -> bool:
|
||||||
|
"""Check if the secrets store is currently unlocked."""
|
||||||
|
return self._encryption_key is not None
|
||||||
|
|
||||||
|
def initialize(self, master_password: str) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the secrets store with a master password.
|
||||||
|
|
||||||
|
This should only be called once when setting up the store.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
master_password: The master password to protect the secrets
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SecretsStoreError: If store is already initialized
|
||||||
|
"""
|
||||||
|
if self.is_initialized:
|
||||||
|
raise SecretsStoreError(
|
||||||
|
"Secrets store is already initialized. "
|
||||||
|
"Use unlock() to access it or change_master_password() to change the password."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate a new random salt
|
||||||
|
salt = generate_salt()
|
||||||
|
|
||||||
|
# Derive encryption key
|
||||||
|
encryption_key = derive_key_from_password(master_password, salt)
|
||||||
|
|
||||||
|
# Create verification hash
|
||||||
|
verification_hash = create_verification_hash(master_password, salt)
|
||||||
|
|
||||||
|
# Store salt and verification hash
|
||||||
|
master_key_data = {
|
||||||
|
"salt": salt.hex(),
|
||||||
|
"verification": verification_hash,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.master_key_file.write_text(json.dumps(master_key_data, indent=2))
|
||||||
|
|
||||||
|
# Set restrictive permissions (owner read/write only)
|
||||||
|
os.chmod(self.master_key_file, stat.S_IRUSR | stat.S_IWUSR)
|
||||||
|
|
||||||
|
# Initialize empty secrets
|
||||||
|
self._encryption_key = encryption_key
|
||||||
|
self._secrets = {}
|
||||||
|
self._save_secrets()
|
||||||
|
|
||||||
|
print(f"✓ Secrets store initialized at {self.secrets_file}")
|
||||||
|
|
||||||
|
def unlock(self, master_password: str) -> None:
|
||||||
|
"""
|
||||||
|
Unlock the secrets store with the master password.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
master_password: The master password
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SecretsStoreError: If store is not initialized
|
||||||
|
InvalidMasterPassword: If password is incorrect
|
||||||
|
"""
|
||||||
|
if not self.is_initialized:
|
||||||
|
raise SecretsStoreError(
|
||||||
|
"Secrets store is not initialized. Call initialize() first."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load salt and verification hash
|
||||||
|
master_key_data = json.loads(self.master_key_file.read_text())
|
||||||
|
salt = bytes.fromhex(master_key_data["salt"])
|
||||||
|
verification_hash = master_key_data["verification"]
|
||||||
|
|
||||||
|
# Verify password
|
||||||
|
if not verify_password(master_password, salt, verification_hash):
|
||||||
|
raise InvalidMasterPassword("Invalid master password")
|
||||||
|
|
||||||
|
# Derive encryption key
|
||||||
|
encryption_key = derive_key_from_password(master_password, salt)
|
||||||
|
|
||||||
|
# Load and decrypt secrets
|
||||||
|
if self.secrets_file.exists():
|
||||||
|
try:
|
||||||
|
encrypted_data = self.secrets_file.read_bytes()
|
||||||
|
decrypted_data = decrypt_data(encrypted_data, encryption_key)
|
||||||
|
self._secrets = json.loads(decrypted_data.decode('utf-8'))
|
||||||
|
except InvalidToken:
|
||||||
|
raise InvalidMasterPassword("Failed to decrypt secrets (invalid password)")
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise SecretsStoreError(f"Corrupted secrets file: {e}")
|
||||||
|
else:
|
||||||
|
# No secrets file yet (fresh initialization)
|
||||||
|
self._secrets = {}
|
||||||
|
|
||||||
|
self._encryption_key = encryption_key
|
||||||
|
print(f"✓ Secrets store unlocked ({len(self._secrets)} secrets)")
|
||||||
|
|
||||||
|
def lock(self) -> None:
|
||||||
|
"""Lock the secrets store (clear decrypted data from memory)."""
|
||||||
|
self._encryption_key = None
|
||||||
|
self._secrets = None
|
||||||
|
|
||||||
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
|
"""
|
||||||
|
Get a secret value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Secret key name
|
||||||
|
default: Default value if key doesn't exist
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Secret value or default
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SecretsStoreLocked: If store is locked
|
||||||
|
"""
|
||||||
|
if not self.is_unlocked:
|
||||||
|
raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")
|
||||||
|
|
||||||
|
return self._secrets.get(key, default)
|
||||||
|
|
||||||
|
def set(self, key: str, value: Any) -> None:
|
||||||
|
"""
|
||||||
|
Set a secret value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Secret key name
|
||||||
|
value: Secret value (must be JSON-serializable)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SecretsStoreLocked: If store is locked
|
||||||
|
"""
|
||||||
|
if not self.is_unlocked:
|
||||||
|
raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")
|
||||||
|
|
||||||
|
self._secrets[key] = value
|
||||||
|
self._save_secrets()
|
||||||
|
|
||||||
|
def delete(self, key: str) -> bool:
|
||||||
|
"""
|
||||||
|
Delete a secret.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Secret key name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if secret existed and was deleted, False otherwise
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SecretsStoreLocked: If store is locked
|
||||||
|
"""
|
||||||
|
if not self.is_unlocked:
|
||||||
|
raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")
|
||||||
|
|
||||||
|
if key in self._secrets:
|
||||||
|
del self._secrets[key]
|
||||||
|
self._save_secrets()
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def list_keys(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
List all secret keys.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of secret keys
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SecretsStoreLocked: If store is locked
|
||||||
|
"""
|
||||||
|
if not self.is_unlocked:
|
||||||
|
raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")
|
||||||
|
|
||||||
|
return list(self._secrets.keys())
|
||||||
|
|
||||||
|
def change_master_password(self, current_password: str, new_password: str) -> None:
|
||||||
|
"""
|
||||||
|
Change the master password.
|
||||||
|
|
||||||
|
This re-encrypts the secrets with a new key derived from the new password.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_password: Current master password
|
||||||
|
new_password: New master password
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidMasterPassword: If current password is incorrect
|
||||||
|
"""
|
||||||
|
# ALWAYS verify current password before changing
|
||||||
|
# Load salt and verification hash
|
||||||
|
if not self.is_initialized:
|
||||||
|
raise SecretsStoreError(
|
||||||
|
"Secrets store is not initialized. Call initialize() first."
|
||||||
|
)
|
||||||
|
|
||||||
|
master_key_data = json.loads(self.master_key_file.read_text())
|
||||||
|
salt = bytes.fromhex(master_key_data["salt"])
|
||||||
|
verification_hash = master_key_data["verification"]
|
||||||
|
|
||||||
|
# Verify current password is correct
|
||||||
|
if not verify_password(current_password, salt, verification_hash):
|
||||||
|
raise InvalidMasterPassword("Invalid current password")
|
||||||
|
|
||||||
|
# Unlock if needed to access secrets
|
||||||
|
was_unlocked = self.is_unlocked
|
||||||
|
if not was_unlocked:
|
||||||
|
# Store is locked, so unlock with current password
|
||||||
|
# (we already verified it above, so this will succeed)
|
||||||
|
encryption_key = derive_key_from_password(current_password, salt)
|
||||||
|
|
||||||
|
# Load and decrypt secrets
|
||||||
|
if self.secrets_file.exists():
|
||||||
|
encrypted_data = self.secrets_file.read_bytes()
|
||||||
|
decrypted_data = decrypt_data(encrypted_data, encryption_key)
|
||||||
|
self._secrets = json.loads(decrypted_data.decode('utf-8'))
|
||||||
|
else:
|
||||||
|
self._secrets = {}
|
||||||
|
|
||||||
|
self._encryption_key = encryption_key
|
||||||
|
|
||||||
|
# Generate new salt
|
||||||
|
new_salt = generate_salt()
|
||||||
|
|
||||||
|
# Derive new encryption key
|
||||||
|
new_encryption_key = derive_key_from_password(new_password, new_salt)
|
||||||
|
|
||||||
|
# Create new verification hash
|
||||||
|
new_verification_hash = create_verification_hash(new_password, new_salt)
|
||||||
|
|
||||||
|
# Update master key file
|
||||||
|
master_key_data = {
|
||||||
|
"salt": new_salt.hex(),
|
||||||
|
"verification": new_verification_hash,
|
||||||
|
}
|
||||||
|
self.master_key_file.write_text(json.dumps(master_key_data, indent=2))
|
||||||
|
os.chmod(self.master_key_file, stat.S_IRUSR | stat.S_IWUSR)
|
||||||
|
|
||||||
|
# Re-encrypt secrets with new key
|
||||||
|
old_key = self._encryption_key
|
||||||
|
self._encryption_key = new_encryption_key
|
||||||
|
self._save_secrets()
|
||||||
|
|
||||||
|
print("✓ Master password changed successfully")
|
||||||
|
|
||||||
|
# Lock if it wasn't unlocked before
|
||||||
|
if not was_unlocked:
|
||||||
|
self.lock()
|
||||||
|
|
||||||
|
def _save_secrets(self) -> None:
|
||||||
|
"""Save secrets to encrypted file."""
|
||||||
|
if not self.is_unlocked:
|
||||||
|
raise SecretsStoreLocked("Cannot save while locked")
|
||||||
|
|
||||||
|
# Serialize secrets to JSON
|
||||||
|
secrets_json = json.dumps(self._secrets, indent=2)
|
||||||
|
secrets_bytes = secrets_json.encode('utf-8')
|
||||||
|
|
||||||
|
# Encrypt
|
||||||
|
encrypted_data = encrypt_data(secrets_bytes, self._encryption_key)
|
||||||
|
|
||||||
|
# Write to file
|
||||||
|
self.secrets_file.write_bytes(encrypted_data)
|
||||||
|
|
||||||
|
# Set restrictive permissions
|
||||||
|
os.chmod(self.secrets_file, stat.S_IRUSR | stat.S_IWUSR)
|
||||||
|
|
||||||
|
def export_encrypted(self, output_path: Path) -> None:
|
||||||
|
"""
|
||||||
|
Export encrypted secrets to a file (for backup).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_path: Path to export file
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SecretsStoreError: If secrets file doesn't exist
|
||||||
|
"""
|
||||||
|
if not self.secrets_file.exists():
|
||||||
|
raise SecretsStoreError("No secrets to export")
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
shutil.copy2(self.secrets_file, output_path)
|
||||||
|
print(f"✓ Encrypted secrets exported to {output_path}")
|
||||||
|
|
||||||
|
def import_encrypted(self, input_path: Path, master_password: str) -> None:
|
||||||
|
"""
|
||||||
|
Import encrypted secrets from a file.
|
||||||
|
|
||||||
|
This will verify the password can decrypt the import before replacing
|
||||||
|
the current secrets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_path: Path to import file
|
||||||
|
master_password: Master password for the current store
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidMasterPassword: If password doesn't work with import
|
||||||
|
"""
|
||||||
|
if not self.is_unlocked:
|
||||||
|
self.unlock(master_password)
|
||||||
|
|
||||||
|
# Try to decrypt the imported file with current key
|
||||||
|
try:
|
||||||
|
encrypted_data = Path(input_path).read_bytes()
|
||||||
|
decrypted_data = decrypt_data(encrypted_data, self._encryption_key)
|
||||||
|
imported_secrets = json.loads(decrypted_data.decode('utf-8'))
|
||||||
|
except InvalidToken:
|
||||||
|
raise InvalidMasterPassword(
|
||||||
|
"Cannot decrypt imported secrets with current master password"
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise SecretsStoreError(f"Corrupted import file: {e}")
|
||||||
|
|
||||||
|
# Replace secrets
|
||||||
|
self._secrets = imported_secrets
|
||||||
|
self._save_secrets()
|
||||||
|
|
||||||
|
print(f"✓ Imported {len(self._secrets)} secrets from {input_path}")
|
||||||
@@ -2,8 +2,10 @@
|
|||||||
Test script for CCXT DataSource adapter (Free Version).
|
Test script for CCXT DataSource adapter (Free Version).
|
||||||
|
|
||||||
This demonstrates how to use the free CCXT adapter (not ccxt.pro) with various
|
This demonstrates how to use the free CCXT adapter (not ccxt.pro) with various
|
||||||
exchanges. It uses polling instead of WebSocket for real-time updates and
|
exchanges. It uses polling instead of WebSocket for real-time updates.
|
||||||
verifies that Decimal precision is maintained throughout.
|
|
||||||
|
CCXT is configured to use Decimal mode internally for precision, but OHLCV data
|
||||||
|
is converted to float for optimal DataFrame/analysis performance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -73,10 +75,10 @@ async def test_binance_datasource():
|
|||||||
print(f" Close: {latest.data['close']} (type: {type(latest.data['close']).__name__})")
|
print(f" Close: {latest.data['close']} (type: {type(latest.data['close']).__name__})")
|
||||||
print(f" Volume: {latest.data['volume']} (type: {type(latest.data['volume']).__name__})")
|
print(f" Volume: {latest.data['volume']} (type: {type(latest.data['volume']).__name__})")
|
||||||
|
|
||||||
# Verify Decimal precision
|
# Verify OHLCV uses float (converted from Decimal for DataFrame performance)
|
||||||
assert isinstance(latest.data['close'], Decimal), "Price should be Decimal type!"
|
assert isinstance(latest.data['close'], float), "OHLCV price should be float type!"
|
||||||
assert isinstance(latest.data['volume'], Decimal), "Volume should be Decimal type!"
|
assert isinstance(latest.data['volume'], float), "OHLCV volume should be float type!"
|
||||||
print(f" ✓ Numerical precision verified: using Decimal types")
|
print(f" ✓ OHLCV data type verified: using native float (CCXT uses Decimal internally)")
|
||||||
|
|
||||||
# Test 5: Polling subscription (brief test)
|
# Test 5: Polling subscription (brief test)
|
||||||
print("\n5. Testing polling-based subscription...")
|
print("\n5. Testing polling-based subscription...")
|
||||||
@@ -87,7 +89,7 @@ async def test_binance_datasource():
|
|||||||
tick_count[0] += 1
|
tick_count[0] += 1
|
||||||
if tick_count[0] == 1:
|
if tick_count[0] == 1:
|
||||||
print(f" Received tick: close={data['close']} (type: {type(data['close']).__name__})")
|
print(f" Received tick: close={data['close']} (type: {type(data['close']).__name__})")
|
||||||
assert isinstance(data['close'], Decimal), "Polled data should use Decimal!"
|
assert isinstance(data['close'], float), "Polled OHLCV data should use float!"
|
||||||
|
|
||||||
subscription_id = await binance.subscribe_bars(
|
subscription_id = await binance.subscribe_bars(
|
||||||
symbol="BTC/USDT",
|
symbol="BTC/USDT",
|
||||||
|
|||||||
@@ -1,3 +1,27 @@
|
|||||||
FROM python:3.14-alpine
|
FROM python:3.14-alpine
|
||||||
|
|
||||||
COPY python/src /app/src
|
# Install TA-Lib C library and build dependencies
|
||||||
|
RUN apk add --no-cache --virtual .build-deps \
|
||||||
|
gcc \
|
||||||
|
g++ \
|
||||||
|
make \
|
||||||
|
musl-dev \
|
||||||
|
wget \
|
||||||
|
&& apk add --no-cache \
|
||||||
|
ta-lib \
|
||||||
|
&& rm -rf /var/cache/apk/*
|
||||||
|
|
||||||
|
# Set working directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy requirements first for better caching
|
||||||
|
COPY backend/requirements.txt /app/requirements.txt
|
||||||
|
|
||||||
|
# Install Python dependencies
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
# Clean up build dependencies
|
||||||
|
RUN apk del .build-deps
|
||||||
|
|
||||||
|
# Copy application code
|
||||||
|
COPY backend/src /app/src
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import { wsManager } from './composables/useWebSocket'
|
|||||||
const isAuthenticated = ref(false)
|
const isAuthenticated = ref(false)
|
||||||
const needsConfirmation = ref(false)
|
const needsConfirmation = ref(false)
|
||||||
const authError = ref<string>()
|
const authError = ref<string>()
|
||||||
|
const isDragging = ref(false)
|
||||||
let stateSyncCleanup: (() => void) | null = null
|
let stateSyncCleanup: (() => void) | null = null
|
||||||
|
|
||||||
// Check if we need password confirmation on first load
|
// Check if we need password confirmation on first load
|
||||||
@@ -58,6 +59,21 @@ const handleAuthenticate = async (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
onMounted(() => {
|
||||||
|
// Listen for splitter drag events
|
||||||
|
document.addEventListener('mousedown', (e) => {
|
||||||
|
// Check if the mousedown is on a splitter gutter
|
||||||
|
const target = e.target as HTMLElement
|
||||||
|
if (target.closest('.p-splitter-gutter')) {
|
||||||
|
isDragging.value = true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
document.addEventListener('mouseup', () => {
|
||||||
|
isDragging.value = false
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
onBeforeUnmount(() => {
|
onBeforeUnmount(() => {
|
||||||
if (stateSyncCleanup) {
|
if (stateSyncCleanup) {
|
||||||
stateSyncCleanup()
|
stateSyncCleanup()
|
||||||
@@ -82,6 +98,8 @@ onBeforeUnmount(() => {
|
|||||||
<ChatPanel />
|
<ChatPanel />
|
||||||
</SplitterPanel>
|
</SplitterPanel>
|
||||||
</Splitter>
|
</Splitter>
|
||||||
|
<!-- Transparent overlay to prevent iframe from capturing mouse events during drag -->
|
||||||
|
<div v-if="isDragging" class="drag-overlay"></div>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
@@ -90,19 +108,24 @@ onBeforeUnmount(() => {
|
|||||||
width: 100vw !important;
|
width: 100vw !important;
|
||||||
height: 100vh !important;
|
height: 100vh !important;
|
||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
background: var(--p-surface-0);
|
background: var(--p-surface-0) !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.app-container.dark {
|
||||||
|
background: var(--p-surface-0) !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.main-splitter {
|
.main-splitter {
|
||||||
height: 100vh !important;
|
height: 100vh !important;
|
||||||
|
background: var(--p-surface-0) !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.main-splitter :deep(.p-splitter-gutter) {
|
.main-splitter :deep(.p-splitter-gutter) {
|
||||||
background: var(--p-surface-100);
|
background: var(--p-surface-0) !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.main-splitter :deep(.p-splitter-gutter-handle) {
|
.main-splitter :deep(.p-splitter-gutter-handle) {
|
||||||
background: var(--p-primary-color);
|
background: var(--p-surface-400) !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.chart-panel,
|
.chart-panel,
|
||||||
@@ -119,4 +142,15 @@ onBeforeUnmount(() => {
|
|||||||
display: flex;
|
display: flex;
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.drag-overlay {
|
||||||
|
position: fixed;
|
||||||
|
top: 0;
|
||||||
|
left: 0;
|
||||||
|
right: 0;
|
||||||
|
bottom: 0;
|
||||||
|
z-index: 9999;
|
||||||
|
cursor: col-resize;
|
||||||
|
background: transparent;
|
||||||
|
}
|
||||||
</style>
|
</style>
|
||||||
|
|||||||
@@ -26,3 +26,16 @@ html, body, #app {
|
|||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
background-color: var(--p-surface-0) !important;
|
background-color: var(--p-surface-0) !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.dark {
|
||||||
|
background-color: var(--p-surface-0) !important;
|
||||||
|
color: var(--p-surface-900) !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Ensure dark background for main containers */
|
||||||
|
.app-container,
|
||||||
|
.main-splitter,
|
||||||
|
.p-splitter,
|
||||||
|
.p-splitter-panel {
|
||||||
|
background-color: var(--p-surface-0) !important;
|
||||||
|
}
|
||||||
|
|||||||
@@ -191,9 +191,14 @@ onBeforeUnmount(() => {
|
|||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
border: none;
|
border: none;
|
||||||
|
border-radius: 0 !important;
|
||||||
background: var(--p-surface-0);
|
background: var(--p-surface-0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.chart-card :deep(.p-card) {
|
||||||
|
border-radius: 0 !important;
|
||||||
|
}
|
||||||
|
|
||||||
.chart-card :deep(.p-card-body) {
|
.chart-card :deep(.p-card-body) {
|
||||||
flex: 1;
|
flex: 1;
|
||||||
display: flex;
|
display: flex;
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
import { ref, onMounted, onUnmounted, computed } from 'vue'
|
import { ref, onMounted, onUnmounted, computed } from 'vue'
|
||||||
import { register } from 'vue-advanced-chat'
|
import { register } from 'vue-advanced-chat'
|
||||||
import Badge from 'primevue/badge'
|
import Badge from 'primevue/badge'
|
||||||
|
import Button from 'primevue/button'
|
||||||
import { wsManager } from '../composables/useWebSocket'
|
import { wsManager } from '../composables/useWebSocket'
|
||||||
import type { WebSocketMessage } from '../composables/useWebSocket'
|
import type { WebSocketMessage } from '../composables/useWebSocket'
|
||||||
|
|
||||||
@@ -17,7 +18,7 @@ const messages = ref<any[]>([])
|
|||||||
const messagesLoaded = ref(false)
|
const messagesLoaded = ref(false)
|
||||||
const isConnected = wsManager.isConnected
|
const isConnected = wsManager.isConnected
|
||||||
|
|
||||||
// Reactive rooms that update based on WebSocket connection
|
// Reactive rooms that update based on WebSocket connection and agent processing state
|
||||||
const rooms = computed(() => [{
|
const rooms = computed(() => [{
|
||||||
roomId: SESSION_ID,
|
roomId: SESSION_ID,
|
||||||
roomName: 'AI Agent',
|
roomName: 'AI Agent',
|
||||||
@@ -26,23 +27,29 @@ const rooms = computed(() => [{
|
|||||||
{ _id: CURRENT_USER_ID, username: 'You' },
|
{ _id: CURRENT_USER_ID, username: 'You' },
|
||||||
{ _id: AGENT_ID, username: 'AI Agent', status: { state: isConnected.value ? 'online' : 'offline' } }
|
{ _id: AGENT_ID, username: 'AI Agent', status: { state: isConnected.value ? 'online' : 'offline' } }
|
||||||
],
|
],
|
||||||
unreadCount: 0
|
unreadCount: 0,
|
||||||
|
typingUsers: isAgentProcessing.value ? [AGENT_ID] : []
|
||||||
}])
|
}])
|
||||||
|
|
||||||
// Streaming state
|
// Streaming state
|
||||||
let currentStreamingMessageId: string | null = null
|
let currentStreamingMessageId: string | null = null
|
||||||
let streamingBuffer = ''
|
let streamingBuffer = ''
|
||||||
|
const isAgentProcessing = ref(false)
|
||||||
|
|
||||||
// Generate message ID
|
// Generate message ID
|
||||||
const generateMessageId = () => `msg-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`
|
const generateMessageId = () => `msg-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`
|
||||||
|
|
||||||
// Handle WebSocket messages
|
// Handle WebSocket messages
|
||||||
const handleMessage = (data: WebSocketMessage) => {
|
const handleMessage = (data: WebSocketMessage) => {
|
||||||
|
console.log('[ChatPanel] Received message:', data)
|
||||||
if (data.type === 'agent_chunk') {
|
if (data.type === 'agent_chunk') {
|
||||||
|
console.log('[ChatPanel] Processing agent_chunk, content:', data.content, 'done:', data.done)
|
||||||
const timestamp = new Date().toTimeString().split(' ')[0].slice(0, 5)
|
const timestamp = new Date().toTimeString().split(' ')[0].slice(0, 5)
|
||||||
|
|
||||||
if (!currentStreamingMessageId) {
|
if (!currentStreamingMessageId) {
|
||||||
|
console.log('[ChatPanel] Starting new streaming message')
|
||||||
// Start new streaming message
|
// Start new streaming message
|
||||||
|
isAgentProcessing.value = true
|
||||||
currentStreamingMessageId = generateMessageId()
|
currentStreamingMessageId = generateMessageId()
|
||||||
streamingBuffer = data.content
|
streamingBuffer = data.content
|
||||||
|
|
||||||
@@ -54,7 +61,8 @@ const handleMessage = (data: WebSocketMessage) => {
|
|||||||
date: new Date().toLocaleDateString(),
|
date: new Date().toLocaleDateString(),
|
||||||
saved: false,
|
saved: false,
|
||||||
distributed: false,
|
distributed: false,
|
||||||
seen: false
|
seen: false,
|
||||||
|
files: []
|
||||||
}]
|
}]
|
||||||
} else {
|
} else {
|
||||||
// Update existing streaming message
|
// Update existing streaming message
|
||||||
@@ -62,10 +70,24 @@ const handleMessage = (data: WebSocketMessage) => {
|
|||||||
const msgIndex = messages.value.findIndex(m => m._id === currentStreamingMessageId)
|
const msgIndex = messages.value.findIndex(m => m._id === currentStreamingMessageId)
|
||||||
|
|
||||||
if (msgIndex !== -1) {
|
if (msgIndex !== -1) {
|
||||||
messages.value[msgIndex] = {
|
const updatedMessage: any = {
|
||||||
...messages.value[msgIndex],
|
...messages.value[msgIndex],
|
||||||
content: streamingBuffer
|
content: streamingBuffer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add plot images if present in metadata
|
||||||
|
if (data.metadata && data.metadata.plot_urls && Array.isArray(data.metadata.plot_urls)) {
|
||||||
|
const plotFiles = data.metadata.plot_urls.map((url: string, idx: number) => ({
|
||||||
|
name: `plot_${idx + 1}.png`,
|
||||||
|
size: 0,
|
||||||
|
type: 'png',
|
||||||
|
url: `${BACKEND_URL}${url}`,
|
||||||
|
preview: `${BACKEND_URL}${url}`
|
||||||
|
}))
|
||||||
|
updatedMessage.files = plotFiles
|
||||||
|
}
|
||||||
|
|
||||||
|
messages.value[msgIndex] = updatedMessage
|
||||||
messages.value = [...messages.value]
|
messages.value = [...messages.value]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -74,21 +96,49 @@ const handleMessage = (data: WebSocketMessage) => {
|
|||||||
// Mark message as complete
|
// Mark message as complete
|
||||||
const msgIndex = messages.value.findIndex(m => m._id === currentStreamingMessageId)
|
const msgIndex = messages.value.findIndex(m => m._id === currentStreamingMessageId)
|
||||||
if (msgIndex !== -1) {
|
if (msgIndex !== -1) {
|
||||||
messages.value[msgIndex] = {
|
const finalMessage: any = {
|
||||||
...messages.value[msgIndex],
|
...messages.value[msgIndex],
|
||||||
saved: true,
|
saved: true,
|
||||||
distributed: true,
|
distributed: true,
|
||||||
seen: true
|
seen: true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure plot images are included in final message
|
||||||
|
if (data.metadata && data.metadata.plot_urls && Array.isArray(data.metadata.plot_urls)) {
|
||||||
|
const plotFiles = data.metadata.plot_urls.map((url: string, idx: number) => ({
|
||||||
|
name: `plot_${idx + 1}.png`,
|
||||||
|
size: 0,
|
||||||
|
type: 'png',
|
||||||
|
url: `${BACKEND_URL}${url}`,
|
||||||
|
preview: `${BACKEND_URL}${url}`
|
||||||
|
}))
|
||||||
|
finalMessage.files = plotFiles
|
||||||
|
}
|
||||||
|
|
||||||
|
messages.value[msgIndex] = finalMessage
|
||||||
messages.value = [...messages.value]
|
messages.value = [...messages.value]
|
||||||
}
|
}
|
||||||
|
|
||||||
currentStreamingMessageId = null
|
currentStreamingMessageId = null
|
||||||
streamingBuffer = ''
|
streamingBuffer = ''
|
||||||
|
isAgentProcessing.value = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop agent processing
|
||||||
|
const stopAgent = () => {
|
||||||
|
// Send empty message to trigger interrupt without new agent round
|
||||||
|
const wsMessage = {
|
||||||
|
type: 'agent_user_message',
|
||||||
|
session_id: SESSION_ID,
|
||||||
|
content: '',
|
||||||
|
attachments: []
|
||||||
|
}
|
||||||
|
wsManager.send(wsMessage)
|
||||||
|
isAgentProcessing.value = false
|
||||||
|
}
|
||||||
|
|
||||||
// Send message handler
|
// Send message handler
|
||||||
const sendMessage = async (event: any) => {
|
const sendMessage = async (event: any) => {
|
||||||
// Extract data from CustomEvent.detail[0]
|
// Extract data from CustomEvent.detail[0]
|
||||||
@@ -191,39 +241,39 @@ const openFile = ({ file }: any) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Theme configuration for dark mode
|
// Theme configuration for dark mode
|
||||||
const chatTheme = 'light'
|
const chatTheme = 'dark'
|
||||||
|
|
||||||
// Styles to match PrimeVue theme
|
// Styles to match PrimeVue theme
|
||||||
const chatStyles = computed(() => JSON.stringify({
|
const chatStyles = computed(() => JSON.stringify({
|
||||||
general: {
|
general: {
|
||||||
color: 'var(--p-surface-900)',
|
color: '#cdd6e8',
|
||||||
colorSpinner: 'var(--p-primary-color)',
|
colorSpinner: '#00d4aa',
|
||||||
borderStyle: '1px solid var(--p-surface-200)'
|
borderStyle: '1px solid #263452'
|
||||||
},
|
},
|
||||||
container: {
|
container: {
|
||||||
background: 'var(--p-surface-0)'
|
background: '#0a0e1a'
|
||||||
},
|
},
|
||||||
header: {
|
header: {
|
||||||
background: 'var(--p-surface-50)',
|
background: '#0f1629',
|
||||||
colorRoomName: 'var(--p-surface-900)',
|
colorRoomName: '#cdd6e8',
|
||||||
colorRoomInfo: 'var(--p-surface-700)'
|
colorRoomInfo: '#8892a4'
|
||||||
},
|
},
|
||||||
footer: {
|
footer: {
|
||||||
background: 'var(--p-surface-50)',
|
background: '#0f1629',
|
||||||
borderStyleInput: '1px solid var(--p-surface-300)',
|
borderStyleInput: '1px solid #263452',
|
||||||
backgroundInput: 'var(--p-surface-200)',
|
backgroundInput: '#161e35',
|
||||||
colorInput: 'var(--p-surface-900)',
|
colorInput: '#cdd6e8',
|
||||||
colorPlaceholder: 'var(--p-surface-400)',
|
colorPlaceholder: '#8892a4',
|
||||||
colorIcons: 'var(--p-surface-400)'
|
colorIcons: '#8892a4'
|
||||||
},
|
},
|
||||||
content: {
|
content: {
|
||||||
background: 'var(--p-surface-0)'
|
background: '#0a0e1a'
|
||||||
},
|
},
|
||||||
message: {
|
message: {
|
||||||
background: 'var(--p-surface-100)',
|
background: '#161e35',
|
||||||
backgroundMe: 'var(--p-primary-color)',
|
backgroundMe: '#00d4aa',
|
||||||
color: 'var(--p-surface-900)',
|
color: '#cdd6e8',
|
||||||
colorMe: 'var(--p-primary-contrast-color)'
|
colorMe: '#0a0e1a'
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@@ -231,6 +281,14 @@ onMounted(() => {
|
|||||||
wsManager.addHandler(handleMessage)
|
wsManager.addHandler(handleMessage)
|
||||||
// Mark messages as loaded after initialization
|
// Mark messages as loaded after initialization
|
||||||
messagesLoaded.value = true
|
messagesLoaded.value = true
|
||||||
|
|
||||||
|
// Focus on the chat input when component mounts
|
||||||
|
setTimeout(() => {
|
||||||
|
const chatInput = document.querySelector('.vac-textarea') as HTMLTextAreaElement
|
||||||
|
if (chatInput) {
|
||||||
|
chatInput.focus()
|
||||||
|
}
|
||||||
|
}, 300)
|
||||||
})
|
})
|
||||||
|
|
||||||
onUnmounted(() => {
|
onUnmounted(() => {
|
||||||
@@ -251,7 +309,7 @@ onUnmounted(() => {
|
|||||||
-->
|
-->
|
||||||
|
|
||||||
<vue-advanced-chat
|
<vue-advanced-chat
|
||||||
height="calc(100vh - 60px)"
|
height="100vh"
|
||||||
:current-user-id="CURRENT_USER_ID"
|
:current-user-id="CURRENT_USER_ID"
|
||||||
:rooms="JSON.stringify(rooms)"
|
:rooms="JSON.stringify(rooms)"
|
||||||
:messages="JSON.stringify(messages)"
|
:messages="JSON.stringify(messages)"
|
||||||
@@ -267,10 +325,22 @@ onUnmounted(() => {
|
|||||||
:show-emojis="true"
|
:show-emojis="true"
|
||||||
:show-reaction-emojis="false"
|
:show-reaction-emojis="false"
|
||||||
:accepted-files="'image/*,video/*,application/pdf'"
|
:accepted-files="'image/*,video/*,application/pdf'"
|
||||||
|
:message-images="true"
|
||||||
@send-message="sendMessage"
|
@send-message="sendMessage"
|
||||||
@fetch-messages="fetchMessages"
|
@fetch-messages="fetchMessages"
|
||||||
@open-file="openFile"
|
@open-file="openFile"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
<!-- Stop button overlay -->
|
||||||
|
<div v-if="isAgentProcessing" class="stop-button-container">
|
||||||
|
<Button
|
||||||
|
icon="pi pi-stop-circle"
|
||||||
|
label="Stop"
|
||||||
|
severity="danger"
|
||||||
|
@click="stopAgent"
|
||||||
|
class="stop-button"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
@@ -279,8 +349,9 @@ onUnmounted(() => {
|
|||||||
height: 100% !important;
|
height: 100% !important;
|
||||||
display: flex;
|
display: flex;
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
background: var(--p-surface-0);
|
background: var(--p-surface-0) !important;
|
||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
|
position: relative;
|
||||||
}
|
}
|
||||||
|
|
||||||
.chat-container :deep(.vac-container) {
|
.chat-container :deep(.vac-container) {
|
||||||
@@ -306,4 +377,25 @@ onUnmounted(() => {
|
|||||||
font-weight: 600;
|
font-weight: 600;
|
||||||
color: var(--p-surface-900);
|
color: var(--p-surface-900);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.stop-button-container {
|
||||||
|
position: absolute;
|
||||||
|
bottom: 80px;
|
||||||
|
right: 20px;
|
||||||
|
z-index: 1000;
|
||||||
|
}
|
||||||
|
|
||||||
|
.stop-button {
|
||||||
|
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
|
||||||
|
animation: pulse 2s infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes pulse {
|
||||||
|
0%, 100% {
|
||||||
|
opacity: 1;
|
||||||
|
}
|
||||||
|
50% {
|
||||||
|
opacity: 0.8;
|
||||||
|
}
|
||||||
|
}
|
||||||
</style>
|
</style>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, computed } from 'vue'
|
import { ref, computed, onMounted } from 'vue'
|
||||||
import Card from 'primevue/card'
|
import Card from 'primevue/card'
|
||||||
import InputText from 'primevue/inputtext'
|
import InputText from 'primevue/inputtext'
|
||||||
import Password from 'primevue/password'
|
import Password from 'primevue/password'
|
||||||
@@ -66,6 +66,14 @@ const togglePasswordChange = () => {
|
|||||||
newPassword.value = ''
|
newPassword.value = ''
|
||||||
confirmNewPassword.value = ''
|
confirmNewPassword.value = ''
|
||||||
}
|
}
|
||||||
|
|
||||||
|
onMounted(() => {
|
||||||
|
// Focus on the password input when component mounts
|
||||||
|
const passwordInput = document.querySelector('#password input') as HTMLInputElement
|
||||||
|
if (passwordInput) {
|
||||||
|
passwordInput.focus()
|
||||||
|
}
|
||||||
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<template>
|
<template>
|
||||||
@@ -81,13 +89,13 @@ const togglePasswordChange = () => {
|
|||||||
<template #content>
|
<template #content>
|
||||||
<div class="login-content">
|
<div class="login-content">
|
||||||
<p v-if="needsConfirmation" class="welcome-message">
|
<p v-if="needsConfirmation" class="welcome-message">
|
||||||
This is your first time connecting. Please create a master password to secure your workspace.
|
This is your first time connecting. Please create a password to secure your workspace.
|
||||||
</p>
|
</p>
|
||||||
<p v-else-if="isChangingPassword" class="welcome-message">
|
<p v-else-if="isChangingPassword" class="welcome-message">
|
||||||
Enter your current password and choose a new one.
|
Enter your current password and choose a new one.
|
||||||
</p>
|
</p>
|
||||||
<p v-else class="welcome-message">
|
<p v-else class="welcome-message">
|
||||||
Enter your master password to connect.
|
Enter your password to connect.
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<Message v-if="errorMessage" severity="error" :closable="false">
|
<Message v-if="errorMessage" severity="error" :closable="false">
|
||||||
|
|||||||
Reference in New Issue
Block a user