Files
ai/backend/src/agent/tools/chart_tools.py

455 lines
16 KiB
Python

"""Chart data access and analysis tools."""
from typing import Dict, Any, Optional, Tuple
import io
import uuid
import logging
from pathlib import Path
from contextlib import redirect_stdout, redirect_stderr
from langchain_core.tools import tool
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
def _get_registry():
    """Return the ``_registry`` object exposed by the enclosing package.

    The import runs inside the function, so the package attribute is read on
    every call instead of once when this module is first imported. Callers in
    this module treat a falsy return value as "not initialized".
    """
    from . import _registry as registry_instance

    return registry_instance
def _get_datasource_registry():
    """Return the ``_datasource_registry`` object exposed by the enclosing package.

    The import runs inside the function, so the package attribute is read on
    every call instead of once when this module is first imported. Callers in
    this module treat a falsy return value as "not initialized".
    """
    from . import _datasource_registry as datasource_registry_instance

    return datasource_registry_instance
def _get_indicator_registry():
    """Return the ``_indicator_registry`` object exposed by the enclosing package.

    The import runs inside the function, so the package attribute is read on
    every call instead of once when this module is first imported.
    """
    from . import _indicator_registry as indicator_registry_instance

    return indicator_registry_instance
def _get_order_store():
    """Look up the ``OrderStore`` model held by the sync registry.

    Returns:
        The ``model`` attribute of the registry's ``"OrderStore"`` entry, or
        ``None`` when the registry is falsy or has no such entry.
    """
    sync_registry = _get_registry()
    # Guard clause: nothing to return without a registry or without the entry.
    if not sync_registry or "OrderStore" not in sync_registry.entries:
        return None
    return sync_registry.entries["OrderStore"].model
def _get_chart_store():
    """Look up the ``ChartStore`` model held by the sync registry.

    Returns:
        The ``model`` attribute of the registry's ``"ChartStore"`` entry, or
        ``None`` when the registry is falsy or has no such entry.
    """
    sync_registry = _get_registry()
    # Guard clause: nothing to return without a registry or without the entry.
    if not sync_registry or "ChartStore" not in sync_registry.entries:
        return None
    return sync_registry.entries["ChartStore"].model
async def _get_chart_data_impl(countback: Optional[int] = None):
    """Fetch bars for the chart described by the current ChartStore state.

    Shared helper used by the ``get_chart_data`` and ``execute_python`` tools.
    It reads symbol, interval and time range from the ChartStore entry of the
    sync registry, resolves the data source from the symbol's exchange prefix
    (lower-cased), and awaits that source's ``get_bars``.

    Args:
        countback: Optional bar limit, forwarded unchanged to ``get_bars``.

    Returns:
        Tuple of ``(result, chart_context, source_name)`` where ``result`` is
        whatever ``get_bars`` returned (it exposes ``.bars``), ``chart_context``
        is a dict with ``symbol``, ``interval``, ``start_time`` and
        ``end_time``, and ``source_name`` is the lower-cased exchange prefix.

    Raises:
        ValueError: If either registry is missing, the ChartStore entry is
            absent, no symbol is set, the symbol lacks an ``EXCHANGE:`` prefix,
            the data source is unknown, or the time range is not set.
    """
    sync_registry = _get_registry()
    sources = _get_datasource_registry()

    if not sync_registry:
        raise ValueError("SyncRegistry not initialized - cannot read ChartStore")
    if not sources:
        raise ValueError("DataSourceRegistry not initialized - cannot query data")

    # Read current chart state
    chart_entry = sync_registry.entries.get("ChartStore")
    if not chart_entry:
        raise ValueError("ChartStore not found in registry")

    chart_data = chart_entry.model.model_dump(mode="json").get("chart_state", {})
    symbol = chart_data.get("symbol", "")
    interval = chart_data.get("interval", "15")
    start_time = chart_data.get("start_time")
    end_time = chart_data.get("end_time")

    # A falsy symbol (None or empty string) means no chart is displayed.
    if not symbol:
        raise ValueError(
            "No chart visible - ChartStore symbol is None. "
            "The user is likely on a narrow screen (mobile) where charts are hidden. "
            "Let them know they can view charts on a wider screen, or use get_historical_data() "
            "if they specify a symbol and timeframe."
        )

    # Symbol format is "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD")
    if ":" not in symbol:
        raise ValueError(
            f"Invalid symbol format: '{symbol}'. Expected format is 'EXCHANGE:SYMBOL' "
            "(e.g., 'BINANCE:BTC/USDT' or 'DEMO:BTC/USD')"
        )
    # Split on the first colon only; the remainder may itself contain colons.
    exchange_prefix, _, symbol_name = symbol.partition(":")
    source_name = exchange_prefix.lower()

    # Get the data source
    source = sources.get(source_name)
    if not source:
        available = sources.list_sources()
        raise ValueError(
            f"Data source '{source_name}' not found. Available sources: {available}. "
            f"Make sure the exchange in the symbol '{symbol}' matches an available source."
        )

    # The time range is mandatory - there are deliberately no defaults.
    if start_time is None or end_time is None:
        raise ValueError(
            f"Chart time range not set in ChartStore. start_time={start_time}, end_time={end_time}. "
            "The user needs to load the chart first, or the frontend may not be sending the visible range. "
            "Wait for the chart to fully load before analyzing data."
        )

    from_time = int(start_time)
    to_time = int(end_time)

    range_message = (
        f"Using ChartStore time range: from_time={from_time}, end_time={to_time}, "
        f"countback={countback}"
    )
    logger.info(range_message)
    query_message = (
        f"Querying data source '{source_name}' for symbol '{symbol_name}', "
        f"resolution '{interval}'"
    )
    logger.info(query_message)

    # Query the data source
    history = await source.get_bars(
        symbol=symbol_name,
        resolution=interval,
        from_time=from_time,
        to_time=to_time,
        countback=countback,
    )

    bars = history.bars
    first_bar_time = bars[0].time if bars else 'N/A'
    last_bar_time = bars[-1].time if bars else 'N/A'
    logger.info(
        f"Received {len(bars)} bars from data source. "
        f"First bar time: {first_bar_time}, "
        f"Last bar time: {last_bar_time}"
    )

    # Chart context returned alongside the bars. Note the existing asymmetry:
    # "start_time" is the raw ChartStore value, "end_time" is int-converted.
    chart_context = {
        "symbol": symbol,
        "interval": interval,
        "start_time": start_time,
        "end_time": to_time,
    }
    return history, chart_context, source_name
@tool
async def get_chart_data(countback: Optional[int] = None) -> Dict[str, Any]:
    """Get the candle/bar data for what the user is currently viewing on their chart.

    This is a convenience tool that automatically:
    1. Reads the ChartStore to see what chart the user is viewing
    2. Parses the symbol to determine the data source (exchange prefix)
    3. Queries the appropriate data source for that symbol's data
    4. Returns the data for the visible time range and interval

    This is the preferred way to access chart data when helping the user analyze
    what they're looking at, since it automatically uses their current chart context.

    **IMPORTANT**: This tool will fail if ChartStore.symbol is None (no chart visible).
    This happens when the user is on a narrow screen (mobile) where charts are hidden.
    In that case, let the user know charts are only visible on wider screens, or use
    get_historical_data() if they specify a symbol and timeframe.

    Args:
        countback: Optional limit on number of bars to return. If not specified,
            returns all bars in the visible time range.

    Returns:
        Dictionary containing:
        - chart_context: Current chart state (symbol, interval, time range)
        - symbol: The trading pair being viewed
        - resolution: The chart interval
        - bars: List of bar data with 'time' and 'data' fields
        - columns: Schema describing available data columns
        - source: Which data source was used

    Raises:
        ValueError: If ChartStore or DataSourceRegistry is not initialized,
            if no chart is visible (symbol is None), or if the symbol format is invalid

    Example:
        # User is viewing BINANCE:BTC/USDT on 15min chart
        data = get_chart_data()
        # Returns BTC/USDT data from binance source at 15min resolution
        # for the currently visible time range
    """
    # NOTE: the docstring above doubles as the tool description handed to the
    # model by langchain's @tool, so its wording is kept as-is.
    bars_result, chart_context, source_name = await _get_chart_data_impl(countback)

    # Serialized result first, then the two enrichment keys layered on top.
    return {
        **bars_result.model_dump(),
        "chart_context": chart_context,
        "source": source_name,
    }
async def _run_user_script(code: str, script_globals: Dict[str, Any]) -> Any:
    """Run *code* exactly once in *script_globals* and return its trailing value.

    The script is parsed to an AST. When its final statement is a bare
    expression, that statement is rewritten into an assignment to a reserved
    global so the value can be read back afterwards - the expression is never
    evaluated a second time. The script is compiled with top-level ``await``
    enabled; if the resulting code object is a coroutine it is awaited.

    Args:
        code: Python source to execute.
        script_globals: Namespace the script runs in (mutated by the script).

    Returns:
        The value of the trailing expression, or ``None`` if the script does
        not end with an expression statement.

    Raises:
        SyntaxError: If *code* cannot be parsed or compiled.
        Exception: Anything the script itself raises.
    """
    import ast
    import inspect

    result_name = "__execute_python_last_value__"
    filename = "<execute_python>"
    # Guarded lookup keeps the helper importable on interpreters lacking the flag.
    await_flag = getattr(ast, "PyCF_ALLOW_TOP_LEVEL_AWAIT", 0)

    tree = compile(code, filename, "exec", flags=ast.PyCF_ONLY_AST | await_flag)
    if tree.body and isinstance(tree.body[-1], ast.Expr):
        tail = tree.body[-1]
        capture = ast.Assign(
            targets=[ast.Name(id=result_name, ctx=ast.Store())],
            value=tail.value,
        )
        tree.body[-1] = ast.fix_missing_locations(ast.copy_location(capture, tail))

    compiled = compile(tree, filename, "exec", flags=await_flag)
    if compiled.co_flags & inspect.CO_COROUTINE:
        # Script used top-level await: eval() hands back a coroutine to await.
        await eval(compiled, script_globals)
    else:
        exec(compiled, script_globals)
    return script_globals.pop(result_name, None)


@tool
async def execute_python(code: str, countback: Optional[int] = None) -> Dict[str, Any]:
    """Execute Python code for technical analysis with automatic chart data loading.

    **PRIMARY TOOL for all technical analysis, indicator computation, and chart generation.**

    This is your go-to tool whenever the user asks about indicators, wants to see
    a chart, or needs any computational analysis of market data.

    Pre-loaded Environment:
    - `pd` : pandas
    - `np` : numpy
    - `plt` : matplotlib.pyplot (figures auto-saved to plot_urls)
    - `talib` : TA-Lib technical analysis library
    - `indicator_registry`: 150+ registered indicators
    - `plot_ohlc(df)` : Helper function for beautiful candlestick charts
    - `registry` : SyncRegistry instance - access to all registered stores
    - `datasource_registry`: DataSourceRegistry - access to data sources (binance, etc.)
    - `order_store` : OrderStore instance - current orders list
    - `chart_store` : ChartStore instance - current chart state

    Auto-loaded when user has a chart visible (ChartStore.symbol is not None):
    - `df` : pandas DataFrame with DatetimeIndex and columns:
      open, high, low, close, volume (OHLCV data ready to use)
    - `chart_context` : dict with symbol, interval, start_time, end_time

    When NO chart is visible (narrow screen/mobile):
    - `df` : None
    - `chart_context` : None

    If `df` is None, you can still load alternative data by:
    - Using chart_store to see what symbol/timeframe is configured
    - Using datasource_registry.get('binance') to access data sources
    - Calling `await source.get_bars(symbol=..., resolution=..., from_time=..., to_time=..., countback=None)`
      to load any data (top-level `await` is supported in the executed code)
    - This allows you to make plots of ANY chart even when not connected to chart view

    The `plot_ohlc()` Helper:
    Create professional candlestick charts instantly:
    - `plot_ohlc(df)` - basic OHLC chart with volume
    - `plot_ohlc(df, title='BTC 15min')` - with custom title
    - `plot_ohlc(df, volume=False)` - price only, no volume
    - Returns a matplotlib Figure that's automatically saved to plot_urls

    Args:
        code: Python code to execute
        countback: Optional limit on number of bars to load (default: all visible bars)

    Returns:
        Dictionary with:
        - script_output : printed output + last expression result
        - result_dataframe : serialized DataFrame if last expression is a DataFrame
        - plot_urls : list of image URLs (e.g., ["/uploads/plot_abc123.png"])
        - chart_context : {symbol, interval, start_time, end_time} or None
        - error : traceback if execution failed

    Examples:
        # RSI indicator with chart
        execute_python(\"\"\"
        df['RSI'] = talib.RSI(df['close'], 14)
        fig = plot_ohlc(df, title='BTC/USDT with RSI')
        print(f"Current RSI: {df['RSI'].iloc[-1]:.2f}")
        df[['close', 'RSI']].tail(5)
        \"\"\")

        # Multiple indicators
        execute_python(\"\"\"
        df['SMA_20'] = df['close'].rolling(20).mean()
        df['SMA_50'] = df['close'].rolling(50).mean()
        df['BB_upper'] = df['close'].rolling(20).mean() + 2*df['close'].rolling(20).std()
        df['BB_lower'] = df['close'].rolling(20).mean() - 2*df['close'].rolling(20).std()
        fig = plot_ohlc(df, title=f"{chart_context['symbol']} - Bollinger Bands")
        current_price = df['close'].iloc[-1]
        sma20 = df['SMA_20'].iloc[-1]
        print(f"Price: {current_price:.2f}, SMA20: {sma20:.2f}")
        df[['close', 'SMA_20', 'BB_upper', 'BB_lower']].tail(10)
        \"\"\")

        # Pattern detection
        execute_python(\"\"\"
        # Find swing highs
        df['swing_high'] = (df['high'] > df['high'].shift(1)) & (df['high'] > df['high'].shift(-1))
        swing_highs = df[df['swing_high']][['high']].tail(5)
        fig = plot_ohlc(df, title='Swing High Detection')
        print("Recent swing highs:")
        print(swing_highs)
        \"\"\")

        # Load alternative data when df is None or for different symbol/timeframe
        execute_python(\"\"\"
        from datetime import datetime, timedelta
        # Get data source
        binance = datasource_registry.get('binance')
        # Load ETH data even if viewing BTC chart
        end_time = datetime.now()
        start_time = end_time - timedelta(days=7)
        result = await binance.get_bars(
            symbol='ETH/USDT',
            resolution='15',
            from_time=int(start_time.timestamp()),
            to_time=int(end_time.timestamp()),
            countback=None
        )
        # Convert to DataFrame
        rows = [{'time': pd.to_datetime(bar.time, unit='s'), **bar.data} for bar in result.bars]
        eth_df = pd.DataFrame(rows).set_index('time')
        # Calculate RSI and plot
        eth_df['RSI'] = talib.RSI(eth_df['close'], 14)
        fig = plot_ohlc(eth_df, title='ETH/USDT 15min - RSI Analysis')
        print(f"ETH RSI: {eth_df['RSI'].iloc[-1]:.2f}")
        \"\"\")

        # Access chart store to see current state
        execute_python(\"\"\"
        print(f"Current symbol: {chart_store.chart_state.symbol}")
        print(f"Current interval: {chart_store.chart_state.interval}")
        print(f"Orders: {len(order_store.orders)}")
        \"\"\")
    """
    # NOTE: the docstring above doubles as the tool description handed to the
    # model by langchain's @tool; its data-source example mirrors the
    # registry.get(...) / source.get_bars(...) calls used by
    # _get_chart_data_impl in this module.
    import traceback

    import pandas as pd
    import numpy as np
    import matplotlib
    matplotlib.use('Agg')  # select the Agg backend before pyplot is imported
    import matplotlib.pyplot as plt
    try:
        import talib
    except ImportError:
        talib = None
        logger.warning("TA-Lib not available in execute_python environment")

    # --- Attempt to load chart data (best effort: failure leaves df as None) ---
    df = None
    chart_context = None
    registry = _get_registry()
    datasource_registry = _get_datasource_registry()
    if registry and datasource_registry:
        try:
            result, chart_context, _ = await _get_chart_data_impl(countback)
            bars = result.bars
            if bars:
                rows = [
                    {'time': pd.to_datetime(bar.time, unit='s'), **bar.data}
                    for bar in bars
                ]
                df = pd.DataFrame(rows).set_index('time')
                for col in ['open', 'high', 'low', 'close', 'volume']:
                    if col in df.columns:
                        df[col] = pd.to_numeric(df[col], errors='coerce')
                logger.info(f"execute_python: loaded {len(df)} bars for {chart_context['symbol']}")
        except Exception as e:
            # Deliberately non-fatal: the script can still run without chart data.
            logger.info(f"execute_python: no chart data loaded ({e})")

    # --- Import chart utilities ---
    from .chart_utils import plot_ohlc

    # --- Get indicator registry ---
    indicator_registry = _get_indicator_registry()

    # --- Get DataStores ---
    order_store = _get_order_store()
    chart_store = _get_chart_store()

    # --- Build globals ---
    script_globals: Dict[str, Any] = {
        'pd': pd,
        'np': np,
        'plt': plt,
        'talib': talib,
        'indicator_registry': indicator_registry,
        'registry': registry,
        'datasource_registry': datasource_registry,
        'order_store': order_store,
        'chart_store': chart_store,
        'df': df,
        'chart_context': chart_context,
        'plot_ohlc': plot_ohlc,
    }

    # --- Execute ---
    uploads_dir = Path(__file__).parent.parent.parent.parent / "data" / "uploads"
    uploads_dir.mkdir(parents=True, exist_ok=True)

    stdout_capture = io.StringIO()
    result_df = None
    error_msg = None
    plot_urls = []

    # SECURITY NOTE(review): `code` is executed with full interpreter
    # privileges in this process (no sandbox). Only acceptable if the caller
    # is trusted; flagging rather than changing.
    # NOTE(review): redirect_stdout/redirect_stderr replace sys.stdout and
    # sys.stderr process-wide, so output printed by other tasks while the
    # script awaits would also land in this capture.
    try:
        with redirect_stdout(stdout_capture), redirect_stderr(stdout_capture):
            last_val = await _run_user_script(code, script_globals)

        # Surface the trailing expression: DataFrames are returned structured,
        # anything else non-None is appended to the textual output.
        if isinstance(last_val, pd.DataFrame):
            result_df = last_val
        elif last_val is not None:
            stdout_capture.write(str(last_val))

        # Save plots
        for fig_num in plt.get_fignums():
            fig = plt.figure(fig_num)
            filename = f"plot_{uuid.uuid4()}.png"
            fig.savefig(uploads_dir / filename, format='png', bbox_inches='tight', dpi=100)
            plot_urls.append(f"/uploads/{filename}")
    except Exception as e:
        error_msg = f"{type(e).__name__}: {e}\n{traceback.format_exc()}"
    finally:
        # Always release figures, including after a failed script, so they do
        # not leak into the plot_urls of a later call.
        plt.close('all')

    # --- Build response ---
    response: Dict[str, Any] = {
        'script_output': stdout_capture.getvalue(),
        'chart_context': chart_context,
        'plot_urls': plot_urls,
    }
    if result_df is not None:
        response['result_dataframe'] = {
            'columns': result_df.columns.tolist(),
            'index': result_df.index.astype(str).tolist(),
            'data': result_df.values.tolist(),
            'shape': result_df.shape,
        }
    if error_msg:
        response['error'] = error_msg
    return response
# The tools defined in this module, gathered in one list.
CHART_TOOLS = [get_chart_data, execute_python]