indicators and plots
This commit is contained in:
371
backend/src/agent/tools/chart_tools.py
Normal file
371
backend/src/agent/tools/chart_tools.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""Chart data access and analysis tools."""
|
||||
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
import io
|
||||
import uuid
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from contextlib import redirect_stdout, redirect_stderr
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_registry():
|
||||
"""Get the global registry instance."""
|
||||
from . import _registry
|
||||
return _registry
|
||||
|
||||
|
||||
def _get_datasource_registry():
|
||||
"""Get the global datasource registry instance."""
|
||||
from . import _datasource_registry
|
||||
return _datasource_registry
|
||||
|
||||
|
||||
def _get_indicator_registry():
|
||||
"""Get the global indicator registry instance."""
|
||||
from . import _indicator_registry
|
||||
return _indicator_registry
|
||||
|
||||
|
||||
async def _get_chart_data_impl(countback: Optional[int] = None):
|
||||
"""Internal implementation for getting chart data.
|
||||
|
||||
This is a helper function that can be called by both get_chart_data tool
|
||||
and analyze_chart_data tool.
|
||||
|
||||
Returns:
|
||||
Tuple of (HistoryResult, chart_context dict, source_name)
|
||||
"""
|
||||
registry = _get_registry()
|
||||
datasource_registry = _get_datasource_registry()
|
||||
|
||||
if not registry:
|
||||
raise ValueError("SyncRegistry not initialized - cannot read ChartStore")
|
||||
|
||||
if not datasource_registry:
|
||||
raise ValueError("DataSourceRegistry not initialized - cannot query data")
|
||||
|
||||
# Read current chart state
|
||||
chart_store = registry.entries.get("ChartStore")
|
||||
if not chart_store:
|
||||
raise ValueError("ChartStore not found in registry")
|
||||
|
||||
chart_state = chart_store.model.model_dump(mode="json")
|
||||
chart_data = chart_state.get("chart_state", {})
|
||||
|
||||
symbol = chart_data.get("symbol", "")
|
||||
interval = chart_data.get("interval", "15")
|
||||
start_time = chart_data.get("start_time")
|
||||
end_time = chart_data.get("end_time")
|
||||
|
||||
if not symbol:
|
||||
raise ValueError("No symbol set in ChartStore - user may not have loaded a chart yet")
|
||||
|
||||
# Parse the symbol to extract exchange/source and symbol name
|
||||
# Format is "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD")
|
||||
if ":" not in symbol:
|
||||
raise ValueError(
|
||||
f"Invalid symbol format: '{symbol}'. Expected format is 'EXCHANGE:SYMBOL' "
|
||||
f"(e.g., 'BINANCE:BTC/USDT' or 'DEMO:BTC/USD')"
|
||||
)
|
||||
|
||||
exchange_prefix, symbol_name = symbol.split(":", 1)
|
||||
source_name = exchange_prefix.lower()
|
||||
|
||||
# Get the data source
|
||||
source = datasource_registry.get(source_name)
|
||||
if not source:
|
||||
available = datasource_registry.list_sources()
|
||||
raise ValueError(
|
||||
f"Data source '{source_name}' not found. Available sources: {available}. "
|
||||
f"Make sure the exchange in the symbol '{symbol}' matches an available source."
|
||||
)
|
||||
|
||||
# Determine time range - REQUIRE it to be set, no defaults
|
||||
if start_time is None or end_time is None:
|
||||
raise ValueError(
|
||||
f"Chart time range not set in ChartStore. start_time={start_time}, end_time={end_time}. "
|
||||
f"The user needs to load the chart first, or the frontend may not be sending the visible range. "
|
||||
f"Wait for the chart to fully load before analyzing data."
|
||||
)
|
||||
|
||||
from_time = int(start_time)
|
||||
end_time = int(end_time)
|
||||
logger.info(
|
||||
f"Using ChartStore time range: from_time={from_time}, end_time={end_time}, "
|
||||
f"countback={countback}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Querying data source '{source_name}' for symbol '{symbol_name}', "
|
||||
f"resolution '{interval}'"
|
||||
)
|
||||
|
||||
# Query the data source
|
||||
result = await source.get_bars(
|
||||
symbol=symbol_name,
|
||||
resolution=interval,
|
||||
from_time=from_time,
|
||||
to_time=end_time,
|
||||
countback=countback
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Received {len(result.bars)} bars from data source. "
|
||||
f"First bar time: {result.bars[0].time if result.bars else 'N/A'}, "
|
||||
f"Last bar time: {result.bars[-1].time if result.bars else 'N/A'}"
|
||||
)
|
||||
|
||||
# Build chart context to return along with result
|
||||
chart_context = {
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time
|
||||
}
|
||||
|
||||
return result, chart_context, source_name
|
||||
|
||||
|
||||
@tool
|
||||
async def get_chart_data(countback: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""Get the candle/bar data for what the user is currently viewing on their chart.
|
||||
|
||||
This is a convenience tool that automatically:
|
||||
1. Reads the ChartStore to see what chart the user is viewing
|
||||
2. Parses the symbol to determine the data source (exchange prefix)
|
||||
3. Queries the appropriate data source for that symbol's data
|
||||
4. Returns the data for the visible time range and interval
|
||||
|
||||
This is the preferred way to access chart data when helping the user analyze
|
||||
what they're looking at, since it automatically uses their current chart context.
|
||||
|
||||
Args:
|
||||
countback: Optional limit on number of bars to return. If not specified,
|
||||
returns all bars in the visible time range.
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- chart_context: Current chart state (symbol, interval, time range)
|
||||
- symbol: The trading pair being viewed
|
||||
- resolution: The chart interval
|
||||
- bars: List of bar data with 'time' and 'data' fields
|
||||
- columns: Schema describing available data columns
|
||||
- source: Which data source was used
|
||||
|
||||
Raises:
|
||||
ValueError: If ChartStore or DataSourceRegistry is not initialized,
|
||||
or if the symbol format is invalid
|
||||
|
||||
Example:
|
||||
# User is viewing BINANCE:BTC/USDT on 15min chart
|
||||
data = get_chart_data()
|
||||
# Returns BTC/USDT data from binance source at 15min resolution
|
||||
# for the currently visible time range
|
||||
"""
|
||||
result, chart_context, source_name = await _get_chart_data_impl(countback)
|
||||
|
||||
# Return enriched result with chart context
|
||||
response = result.model_dump()
|
||||
response["chart_context"] = chart_context
|
||||
response["source"] = source_name
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@tool
|
||||
async def execute_python(code: str, countback: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""Execute Python code for technical analysis with automatic chart data loading.
|
||||
|
||||
**PRIMARY TOOL for all technical analysis, indicator computation, and chart generation.**
|
||||
|
||||
This is your go-to tool whenever the user asks about indicators, wants to see
|
||||
a chart, or needs any computational analysis of market data.
|
||||
|
||||
Pre-loaded Environment:
|
||||
- `pd` : pandas
|
||||
- `np` : numpy
|
||||
- `plt` : matplotlib.pyplot (figures auto-saved to plot_urls)
|
||||
- `talib` : TA-Lib technical analysis library
|
||||
- `indicator_registry`: 150+ registered indicators
|
||||
- `plot_ohlc(df)` : Helper function for beautiful candlestick charts
|
||||
|
||||
Auto-loaded when user has a chart open:
|
||||
- `df` : pandas DataFrame with DatetimeIndex and columns:
|
||||
open, high, low, close, volume (OHLCV data ready to use)
|
||||
- `chart_context` : dict with symbol, interval, start_time, end_time
|
||||
|
||||
The `plot_ohlc()` Helper:
|
||||
Create professional candlestick charts instantly:
|
||||
- `plot_ohlc(df)` - basic OHLC chart with volume
|
||||
- `plot_ohlc(df, title='BTC 15min')` - with custom title
|
||||
- `plot_ohlc(df, volume=False)` - price only, no volume
|
||||
- Returns a matplotlib Figure that's automatically saved to plot_urls
|
||||
|
||||
Args:
|
||||
code: Python code to execute
|
||||
countback: Optional limit on number of bars to load (default: all visible bars)
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- script_output : printed output + last expression result
|
||||
- result_dataframe : serialized DataFrame if last expression is a DataFrame
|
||||
- plot_urls : list of image URLs (e.g., ["/uploads/plot_abc123.png"])
|
||||
- chart_context : {symbol, interval, start_time, end_time} or None
|
||||
- error : traceback if execution failed
|
||||
|
||||
Examples:
|
||||
# RSI indicator with chart
|
||||
execute_python(\"\"\"
|
||||
df['RSI'] = talib.RSI(df['close'], 14)
|
||||
fig = plot_ohlc(df, title='BTC/USDT with RSI')
|
||||
print(f"Current RSI: {df['RSI'].iloc[-1]:.2f}")
|
||||
df[['close', 'RSI']].tail(5)
|
||||
\"\"\")
|
||||
|
||||
# Multiple indicators
|
||||
execute_python(\"\"\"
|
||||
df['SMA_20'] = df['close'].rolling(20).mean()
|
||||
df['SMA_50'] = df['close'].rolling(50).mean()
|
||||
df['BB_upper'] = df['close'].rolling(20).mean() + 2*df['close'].rolling(20).std()
|
||||
df['BB_lower'] = df['close'].rolling(20).mean() - 2*df['close'].rolling(20).std()
|
||||
|
||||
fig = plot_ohlc(df, title=f"{chart_context['symbol']} - Bollinger Bands")
|
||||
|
||||
current_price = df['close'].iloc[-1]
|
||||
sma20 = df['SMA_20'].iloc[-1]
|
||||
print(f"Price: {current_price:.2f}, SMA20: {sma20:.2f}")
|
||||
df[['close', 'SMA_20', 'BB_upper', 'BB_lower']].tail(10)
|
||||
\"\"\")
|
||||
|
||||
# Pattern detection
|
||||
execute_python(\"\"\"
|
||||
# Find swing highs
|
||||
df['swing_high'] = (df['high'] > df['high'].shift(1)) & (df['high'] > df['high'].shift(-1))
|
||||
swing_highs = df[df['swing_high']][['high']].tail(5)
|
||||
|
||||
fig = plot_ohlc(df, title='Swing High Detection')
|
||||
print("Recent swing highs:")
|
||||
print(swing_highs)
|
||||
\"\"\")
|
||||
"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
try:
|
||||
import talib
|
||||
except ImportError:
|
||||
talib = None
|
||||
logger.warning("TA-Lib not available in execute_python environment")
|
||||
|
||||
# --- Attempt to load chart data ---
|
||||
df = None
|
||||
chart_context = None
|
||||
|
||||
registry = _get_registry()
|
||||
datasource_registry = _get_datasource_registry()
|
||||
|
||||
if registry and datasource_registry:
|
||||
try:
|
||||
result, chart_context, source_name = await _get_chart_data_impl(countback)
|
||||
bars = result.bars
|
||||
if bars:
|
||||
rows = []
|
||||
for bar in bars:
|
||||
rows.append({'time': pd.to_datetime(bar.time, unit='s'), **bar.data})
|
||||
df = pd.DataFrame(rows).set_index('time')
|
||||
for col in ['open', 'high', 'low', 'close', 'volume']:
|
||||
if col in df.columns:
|
||||
df[col] = pd.to_numeric(df[col], errors='coerce')
|
||||
logger.info(f"execute_python: loaded {len(df)} bars for {chart_context['symbol']}")
|
||||
except Exception as e:
|
||||
logger.info(f"execute_python: no chart data loaded ({e})")
|
||||
|
||||
# --- Import chart utilities ---
|
||||
from .chart_utils import plot_ohlc
|
||||
|
||||
# --- Get indicator registry ---
|
||||
indicator_registry = _get_indicator_registry()
|
||||
|
||||
# --- Build globals ---
|
||||
script_globals: Dict[str, Any] = {
|
||||
'pd': pd,
|
||||
'np': np,
|
||||
'plt': plt,
|
||||
'talib': talib,
|
||||
'indicator_registry': indicator_registry,
|
||||
'df': df,
|
||||
'chart_context': chart_context,
|
||||
'plot_ohlc': plot_ohlc,
|
||||
}
|
||||
|
||||
# --- Execute ---
|
||||
uploads_dir = Path(__file__).parent.parent.parent.parent / "data" / "uploads"
|
||||
uploads_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
stdout_capture = io.StringIO()
|
||||
result_df = None
|
||||
error_msg = None
|
||||
plot_urls = []
|
||||
|
||||
try:
|
||||
with redirect_stdout(stdout_capture), redirect_stderr(stdout_capture):
|
||||
exec(code, script_globals)
|
||||
|
||||
# Capture last expression
|
||||
lines = code.strip().splitlines()
|
||||
if lines:
|
||||
last = lines[-1].strip()
|
||||
if last and not any(last.startswith(kw) for kw in (
|
||||
'if', 'for', 'while', 'def', 'class', 'import',
|
||||
'from', 'with', 'try', 'return', '#'
|
||||
)):
|
||||
try:
|
||||
last_val = eval(last, script_globals)
|
||||
if isinstance(last_val, pd.DataFrame):
|
||||
result_df = last_val
|
||||
elif last_val is not None:
|
||||
stdout_capture.write(str(last_val))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Save plots
|
||||
for fig_num in plt.get_fignums():
|
||||
fig = plt.figure(fig_num)
|
||||
filename = f"plot_{uuid.uuid4()}.png"
|
||||
fig.savefig(uploads_dir / filename, format='png', bbox_inches='tight', dpi=100)
|
||||
plot_urls.append(f"/uploads/{filename}")
|
||||
plt.close(fig)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_msg = f"{type(e).__name__}: {e}\n{traceback.format_exc()}"
|
||||
|
||||
# --- Build response ---
|
||||
response: Dict[str, Any] = {
|
||||
'script_output': stdout_capture.getvalue(),
|
||||
'chart_context': chart_context,
|
||||
'plot_urls': plot_urls,
|
||||
}
|
||||
if result_df is not None:
|
||||
response['result_dataframe'] = {
|
||||
'columns': result_df.columns.tolist(),
|
||||
'index': result_df.index.astype(str).tolist(),
|
||||
'data': result_df.values.tolist(),
|
||||
'shape': result_df.shape,
|
||||
}
|
||||
if error_msg:
|
||||
response['error'] = error_msg
|
||||
|
||||
return response
|
||||
|
||||
|
||||
CHART_TOOLS = [
|
||||
get_chart_data,
|
||||
execute_python
|
||||
]
|
||||
Reference in New Issue
Block a user