indicator integration
This commit is contained in:
@@ -1,7 +1,18 @@
|
||||
"""Technical indicator tools."""
|
||||
"""Technical indicator tools.
|
||||
|
||||
These tools allow the agent to:
|
||||
1. Discover available indicators (list, search, get info)
|
||||
2. Add indicators to the chart
|
||||
3. Update/remove indicators
|
||||
4. Query currently applied indicators
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from langchain_core.tools import tool
|
||||
import logging
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_indicator_registry():
|
||||
@@ -10,6 +21,20 @@ def _get_indicator_registry():
|
||||
return _indicator_registry
|
||||
|
||||
|
||||
def _get_registry():
    """Return the module-level SyncRegistry singleton.

    Imported lazily from the package to avoid a circular import at
    module load time. May be None if the registry was never initialized.
    """
    from . import _registry
    return _registry
|
||||
|
||||
|
||||
def _get_indicator_store():
    """Return the shared IndicatorStore model, or None when unavailable.

    Resolves the store through the sync registry; yields None if the
    registry itself is missing or no "IndicatorStore" entry is registered.
    """
    registry = _get_registry()
    if not registry:
        return None
    entry = registry.entries.get("IndicatorStore")
    return entry.model if entry is not None else None
|
||||
|
||||
|
||||
@tool
|
||||
def list_indicators() -> List[str]:
|
||||
"""List all available technical indicators.
|
||||
@@ -161,9 +186,250 @@ def get_indicator_categories() -> Dict[str, int]:
|
||||
return categories
|
||||
|
||||
|
||||
@tool
async def add_indicator_to_chart(
    indicator_id: str,
    talib_name: str,
    parameters: Optional[Dict[str, Any]] = None,
    symbol: Optional[str] = None
) -> Dict[str, Any]:
    """Add a technical indicator to the chart.

    This will create a new indicator instance and display it on the TradingView chart.
    The indicator will be synchronized with the frontend in real-time.

    Args:
        indicator_id: Unique identifier for this indicator instance (e.g., 'rsi_14', 'sma_50')
        talib_name: Name of the TA-Lib indicator (e.g., 'RSI', 'SMA', 'MACD', 'BBANDS')
            Use search_indicators() or get_indicator_info() to find available indicators
        parameters: Optional dictionary of indicator parameters
            Example for RSI: {'timeperiod': 14}
            Example for SMA: {'timeperiod': 50}
            Example for MACD: {'fastperiod': 12, 'slowperiod': 26, 'signalperiod': 9}
            Example for BBANDS: {'timeperiod': 20, 'nbdevup': 2, 'nbdevdn': 2}
        symbol: Optional symbol to apply the indicator to (defaults to current chart symbol)

    Returns:
        Dictionary with:
        - status: 'created' or 'updated'
        - indicator: The complete indicator object

    Example:
        # Add RSI(14)
        await add_indicator_to_chart(
            indicator_id='rsi_14',
            talib_name='RSI',
            parameters={'timeperiod': 14}
        )

        # Add 50-period SMA
        await add_indicator_to_chart(
            indicator_id='sma_50',
            talib_name='SMA',
            parameters={'timeperiod': 50}
        )

        # Add MACD with default parameters
        await add_indicator_to_chart(
            indicator_id='macd_default',
            talib_name='MACD'
        )
    """
    from schema.indicator import IndicatorInstance

    registry = _get_registry()
    if not registry:
        raise ValueError("SyncRegistry not initialized")

    store = _get_indicator_store()
    if not store:
        raise ValueError("IndicatorStore not initialized")

    catalog = _get_indicator_registry()
    if not catalog:
        raise ValueError("IndicatorRegistry not initialized")

    # Reject unknown TA-Lib names before touching any state.
    if not catalog.get_metadata(talib_name):
        raise ValueError(
            f"Indicator '{talib_name}' not found. "
            f"Use search_indicators() to find available indicators."
        )

    # An existing entry under the same id means this is an update, and its
    # original creation timestamp must be preserved.
    previous = store.indicators.get(indicator_id)
    is_update = previous is not None

    # Fall back to the symbol currently shown on the chart when none given.
    if symbol is None and "ChartStore" in registry.entries:
        chart_store = registry.entries["ChartStore"].model
        chart_state = getattr(chart_store, 'chart_state', None)
        if chart_state is not None and hasattr(chart_state, 'symbol'):
            symbol = chart_state.symbol
            logger.info(f"Using current chart symbol for indicator: {symbol}")

    now = int(time.time())

    instance = IndicatorInstance(
        id=indicator_id,
        talib_name=talib_name,
        instance_name=f"{talib_name}_{indicator_id}",
        parameters=parameters or {},
        visible=True,
        pane='chart',  # Most indicators go on the chart pane
        symbol=symbol,
        created_at=previous.get('created_at') if previous else now,
        modified_at=now
    )

    # Persist as plain JSON-compatible data, then push state to the frontend.
    store.indicators[indicator_id] = instance.model_dump(mode="json")
    await registry.push_all()

    logger.info(
        f"{'Updated' if is_update else 'Created'} indicator '{indicator_id}' "
        f"(TA-Lib: {talib_name}) with parameters: {parameters}"
    )

    return {
        "status": "updated" if is_update else "created",
        "indicator": instance.model_dump(mode="json")
    }
|
||||
|
||||
|
||||
@tool
async def remove_indicator_from_chart(indicator_id: str) -> Dict[str, str]:
    """Remove an indicator from the chart.

    Args:
        indicator_id: ID of the indicator instance to remove

    Returns:
        Dictionary with status message

    Raises:
        ValueError: If indicator doesn't exist

    Example:
        await remove_indicator_from_chart('rsi_14')
    """
    registry = _get_registry()
    if not registry:
        raise ValueError("SyncRegistry not initialized")

    store = _get_indicator_store()
    if not store:
        raise ValueError("IndicatorStore not initialized")

    if indicator_id not in store.indicators:
        raise ValueError(f"Indicator '{indicator_id}' not found")

    # Drop the entry, then push the updated state to the frontend.
    store.indicators.pop(indicator_id)
    await registry.push_all()

    logger.info(f"Removed indicator '{indicator_id}'")

    return {
        "status": "success",
        "message": f"Indicator '{indicator_id}' removed"
    }
|
||||
|
||||
|
||||
@tool
def list_chart_indicators(symbol: Optional[str] = None) -> List[Dict[str, Any]]:
    """List all indicators currently applied to the chart.

    Args:
        symbol: Optional filter by symbol (defaults to current chart symbol)

    Returns:
        List of indicator instances, each containing:
        - id: Indicator instance ID
        - talib_name: TA-Lib indicator name
        - instance_name: Display name
        - parameters: Current parameter values
        - visible: Whether indicator is visible
        - pane: Which pane it's displayed in
        - symbol: Symbol it's applied to

    Example:
        # List all indicators on current symbol
        indicators = list_chart_indicators()

        # List indicators on specific symbol
        btc_indicators = list_chart_indicators(symbol='BINANCE:BTC/USDT')
    """
    indicator_store = _get_indicator_store()
    if not indicator_store:
        raise ValueError("IndicatorStore not initialized")

    # Full-state dumps belong at DEBUG: logging the raw store contents at
    # INFO on every call floods the logs. Lazy %-args avoid formatting the
    # dump when DEBUG is disabled.
    logger.debug("list_chart_indicators: raw store indicators: %s",
                 indicator_store.indicators)

    # If symbol is not provided, default to the symbol currently on the chart.
    if symbol is None:
        registry = _get_registry()
        if registry and "ChartStore" in registry.entries:
            chart_store = registry.entries["ChartStore"].model
            if hasattr(chart_store, 'chart_state') and hasattr(chart_store.chart_state, 'symbol'):
                symbol = chart_store.chart_state.symbol

    indicators = list(indicator_store.indicators.values())

    # Filter by symbol if provided
    if symbol:
        indicators = [ind for ind in indicators if ind.get('symbol') == symbol]

    # A one-line summary at INFO is enough for routine observability.
    logger.info("list_chart_indicators: returning %d indicators (symbol=%s)",
                len(indicators), symbol)
    return indicators
|
||||
|
||||
|
||||
@tool
def get_chart_indicator(indicator_id: str) -> Dict[str, Any]:
    """Get details of a specific indicator on the chart.

    Args:
        indicator_id: ID of the indicator instance

    Returns:
        Dictionary containing the indicator data

    Raises:
        ValueError: If indicator doesn't exist

    Example:
        indicator = get_chart_indicator('rsi_14')
        print(f"Indicator: {indicator['talib_name']}")
        print(f"Parameters: {indicator['parameters']}")
    """
    store = _get_indicator_store()
    if not store:
        raise ValueError("IndicatorStore not initialized")

    found = store.indicators.get(indicator_id)
    if not found:
        raise ValueError(f"Indicator '{indicator_id}' not found")
    return found
|
||||
|
||||
|
||||
# All indicator tools exported to the agent. NOTE: a merge/diff artifact had
# left `get_indicator_categories` listed twice (once without a trailing comma,
# a syntax error); the list below keeps a single entry.
INDICATOR_TOOLS = [
    # Discovery tools
    list_indicators,
    get_indicator_info,
    search_indicators,
    get_indicator_categories,
    # Chart indicator management tools
    add_indicator_to_chart,
    remove_indicator_from_chart,
    list_chart_indicators,
    get_chart_indicator,
]
|
||||
|
||||
@@ -149,6 +149,10 @@ from .talib_adapter import (
|
||||
is_talib_available,
|
||||
get_talib_version,
|
||||
)
|
||||
from .custom_indicators import (
|
||||
register_custom_indicators,
|
||||
CUSTOM_INDICATORS,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Core classes
|
||||
@@ -169,4 +173,7 @@ __all__ = [
|
||||
"register_all_talib_indicators",
|
||||
"is_talib_available",
|
||||
"get_talib_version",
|
||||
# Custom indicators
|
||||
"register_custom_indicators",
|
||||
"CUSTOM_INDICATORS",
|
||||
]
|
||||
|
||||
954
backend/src/indicator/custom_indicators.py
Normal file
954
backend/src/indicator/custom_indicators.py
Normal file
@@ -0,0 +1,954 @@
|
||||
"""
|
||||
Custom indicator implementations for TradingView indicators not in TA-Lib.
|
||||
|
||||
These indicators follow TA-Lib style conventions and integrate seamlessly
|
||||
with the indicator framework. All implementations are based on well-known,
|
||||
publicly documented formulas.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
import numpy as np
|
||||
|
||||
from datasource.schema import ColumnInfo
|
||||
from .base import Indicator
|
||||
from .schema import (
|
||||
ComputeContext,
|
||||
ComputeResult,
|
||||
IndicatorMetadata,
|
||||
IndicatorParameter,
|
||||
InputSchema,
|
||||
OutputSchema,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VWAP(Indicator):
    """Volume Weighted Average Price - Most widely used institutional indicator.

    Cumulative over the whole supplied series: no session reset is applied.
    """

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="VWAP",
            display_name="VWAP",
            description="Volume Weighted Average Price - Average price weighted by volume",
            category="volume",
            parameters=[],
            use_cases=[
                "Institutional reference price",
                "Support/resistance levels",
                "Mean reversion trading"
            ],
            references=["https://www.investopedia.com/terms/v/vwap.asp"],
            tags=["vwap", "volume", "institutional"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
            ColumnInfo(name="close", type="float", description="Close price"),
            ColumnInfo(name="volume", type="float", description="Volume"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="vwap", type="float", description="Volume Weighted Average Price", nullable=True)
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        """Compute cumulative VWAP = cumsum(typical * volume) / cumsum(volume)."""

        def column_as_floats(name):
            # Missing values become NaN so numpy math can carry them through.
            return np.array([float(v) if v is not None else np.nan
                             for v in context.get_column(name)])

        high = column_as_floats("high")
        low = column_as_floats("low")
        close = column_as_floats("close")
        volume = column_as_floats("volume")

        # Typical price per bar, then the cumulative volume-weighted ratio.
        typical = (high + low + close) / 3.0
        vwap = np.nancumsum(typical * volume) / np.nancumsum(volume)

        rows = []
        for t, value in zip(context.get_times(), vwap):
            rows.append({"time": t, "vwap": None if np.isnan(value) else float(value)})

        return ComputeResult(data=rows, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class VWMA(Indicator):
    """Volume Weighted Moving Average.

    A rolling average of close prices where each bar is weighted by its
    traded volume.
    """

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="VWMA",
            display_name="VWMA",
            description="Volume Weighted Moving Average - Moving average weighted by volume",
            category="overlap",
            parameters=[
                IndicatorParameter(
                    name="length",
                    type="int",
                    description="Period length",
                    default=20,
                    min_value=1,
                    required=False
                )
            ],
            use_cases=["Volume-aware trend following", "Dynamic support/resistance"],
            references=["https://www.investopedia.com/articles/trading/11/trading-with-vwap-mvwap.asp"],
            tags=["vwma", "volume", "moving average"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="close", type="float", description="Close price"),
            ColumnInfo(name="volume", type="float", description="Volume"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="vwma", type="float", description="Volume Weighted Moving Average", nullable=True)
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        """VWMA[i] = sum(close*volume over window) / sum(volume over window)."""
        close = np.array([float(v) if v is not None else np.nan
                          for v in context.get_column("close")])
        volume = np.array([float(v) if v is not None else np.nan
                           for v in context.get_column("volume")])
        window = self.params.get("length", 20)

        # Leading bars without a full window stay NaN (emitted as None).
        values = np.full_like(close, np.nan)
        for end in range(window - 1, len(close)):
            start = end - window + 1
            window_close = close[start:end + 1]
            window_volume = volume[start:end + 1]
            values[end] = np.sum(window_close * window_volume) / np.sum(window_volume)

        rows = []
        for t, value in zip(context.get_times(), values):
            rows.append({"time": t, "vwma": None if np.isnan(value) else float(value)})

        return ComputeResult(data=rows, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class HullMA(Indicator):
    """Hull Moving Average - Fast and smooth moving average.

    HMA(n) = WMA(2*WMA(n/2) - WMA(n), sqrt(n)), reducing lag while keeping
    the curve smooth.
    """

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="HMA",
            display_name="Hull Moving Average",
            description="Hull Moving Average - Reduces lag while maintaining smoothness",
            category="overlap",
            parameters=[
                IndicatorParameter(
                    name="length",
                    type="int",
                    description="Period length",
                    default=9,
                    min_value=1,
                    required=False
                )
            ],
            use_cases=["Low-lag trend following", "Quick trend reversal detection"],
            references=["https://alanhull.com/hull-moving-average"],
            tags=["hma", "hull", "moving average", "low-lag"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="close", type="float", description="Close price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="hma", type="float", description="Hull Moving Average", nullable=True)
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        """Compute HMA over the close series.

        Bars without a full window remain NaN and are emitted as None.
        """
        close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
        length = self.params.get("length", 9)

        def wma(data, period):
            """Weighted Moving Average (linear weights, newest bar heaviest)."""
            weights = np.arange(1, period + 1)
            weight_sum = np.sum(weights)  # loop-invariant: hoisted out of the loop
            result = np.full_like(data, np.nan)
            for i in range(period - 1, len(data)):
                window = data[i - period + 1:i + 1]
                result[i] = np.sum(weights * window) / weight_sum
            return result

        # HMA = WMA(2 * WMA(n/2) - WMA(n), sqrt(n))
        #
        # Fix: length=1 (allowed by min_value=1) makes length // 2 == 0, so
        # the half-period WMA would divide by an empty weight sum. Clamp both
        # derived periods to >= 1; results for length >= 2 are unchanged.
        half_length = max(1, length // 2)
        sqrt_length = max(1, int(np.sqrt(length)))

        wma_half = wma(close, half_length)
        wma_full = wma(close, length)
        raw_hma = 2 * wma_half - wma_full
        hma = wma(raw_hma, sqrt_length)

        times = context.get_times()
        result_data = [
            {"time": times[i], "hma": float(hma[i]) if not np.isnan(hma[i]) else None}
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class SuperTrend(Indicator):
    """SuperTrend - Popular trend following indicator.

    Tracks an ATR-based band that trails price: below price in an uptrend,
    above price in a downtrend. The compute() logic is order-sensitive
    (each bar's bands depend on the previous bar's final bands), so read it
    sequentially.
    """

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        # Registry-facing description; `length` is the ATR lookback and
        # `multiplier` scales the band offset from the HL midpoint.
        return IndicatorMetadata(
            name="SUPERTREND",
            display_name="SuperTrend",
            description="SuperTrend - Volatility-based trend indicator",
            category="overlap",
            parameters=[
                IndicatorParameter(
                    name="length",
                    type="int",
                    description="ATR period",
                    default=10,
                    min_value=1,
                    required=False
                ),
                IndicatorParameter(
                    name="multiplier",
                    type="float",
                    description="ATR multiplier",
                    default=3.0,
                    min_value=0.1,
                    required=False
                )
            ],
            use_cases=["Trend identification", "Stop loss placement", "Trend reversal signals"],
            references=["https://www.investopedia.com/articles/trading/08/supertrend-indicator.asp"],
            tags=["supertrend", "trend", "volatility"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        # Needs full OHLC except open: true range and the band midpoint use
        # high/low, and close drives the trend-flip decisions.
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
            ColumnInfo(name="close", type="float", description="Close price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        # Two outputs per bar: the trailing band value and the trend sign.
        return OutputSchema(columns=[
            ColumnInfo(name="supertrend", type="float", description="SuperTrend value", nullable=True),
            ColumnInfo(name="direction", type="int", description="Trend direction (1=up, -1=down)", nullable=True)
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        """Compute the SuperTrend band and direction for every bar.

        Bars before index `length` have no ATR yet and stay NaN (emitted
        as None).
        """
        # None -> NaN so numpy arithmetic propagates missing values.
        high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
        close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])

        length = self.params.get("length", 10)
        multiplier = self.params.get("multiplier", 3.0)

        # Calculate ATR
        # True range: max of (high-low), |high - prev close|, |low - prev close|.
        # np.roll wraps the first element, so tr[0] is patched to plain high-low.
        tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
        tr[0] = high[0] - low[0]

        # Wilder-style smoothed ATR: seeded with a simple mean of the first
        # `length` true ranges, then recursively smoothed.
        atr = np.full_like(close, np.nan)
        atr[length - 1] = np.mean(tr[:length])
        for i in range(length, len(tr)):
            atr[i] = (atr[i - 1] * (length - 1) + tr[i]) / length

        # Calculate basic bands
        # Midpoint of the bar offset by a multiple of ATR on each side.
        hl2 = (high + low) / 2
        basic_upper = hl2 + multiplier * atr
        basic_lower = hl2 - multiplier * atr

        # Calculate final bands
        final_upper = np.full_like(close, np.nan)
        final_lower = np.full_like(close, np.nan)
        supertrend = np.full_like(close, np.nan)
        direction = np.full_like(close, np.nan)

        for i in range(length, len(close)):
            if i == length:
                # First computable bar: final bands start at the basic bands.
                final_upper[i] = basic_upper[i]
                final_lower[i] = basic_lower[i]
            else:
                # Bands only "ratchet" toward price; they reset when the
                # previous close breaks through the previous final band.
                final_upper[i] = basic_upper[i] if basic_upper[i] < final_upper[i - 1] or close[i - 1] > final_upper[i - 1] else final_upper[i - 1]
                final_lower[i] = basic_lower[i] if basic_lower[i] > final_lower[i - 1] or close[i - 1] < final_lower[i - 1] else final_lower[i - 1]

            if i == length:
                # Initial trend: downtrend (-1, upper band) if close is at or
                # below the midpoint, otherwise uptrend (1, lower band).
                supertrend[i] = final_upper[i] if close[i] <= hl2[i] else final_lower[i]
                direction[i] = -1 if close[i] <= hl2[i] else 1
            else:
                # Stay on the same band while price respects it; flip to the
                # opposite band when the close crosses the current band.
                if supertrend[i - 1] == final_upper[i - 1] and close[i] <= final_upper[i]:
                    supertrend[i] = final_upper[i]
                    direction[i] = -1
                elif supertrend[i - 1] == final_upper[i - 1] and close[i] > final_upper[i]:
                    supertrend[i] = final_lower[i]
                    direction[i] = 1
                elif supertrend[i - 1] == final_lower[i - 1] and close[i] >= final_lower[i]:
                    supertrend[i] = final_lower[i]
                    direction[i] = 1
                else:
                    supertrend[i] = final_upper[i]
                    direction[i] = -1

        times = context.get_times()
        # Serialize NaN positions as None for JSON-friendly output.
        result_data = [
            {
                "time": times[i],
                "supertrend": float(supertrend[i]) if not np.isnan(supertrend[i]) else None,
                "direction": int(direction[i]) if not np.isnan(direction[i]) else None
            }
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class DonchianChannels(Indicator):
    """Donchian Channels - Breakout indicator using highest high and lowest low.

    Upper channel = highest high over the window, lower channel = lowest low,
    middle line = their average.
    """

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="DONCHIAN",
            display_name="Donchian Channels",
            description="Donchian Channels - Highest high and lowest low over period",
            category="overlap",
            parameters=[
                IndicatorParameter(
                    name="length",
                    type="int",
                    description="Period length",
                    default=20,
                    min_value=1,
                    required=False
                )
            ],
            use_cases=["Breakout trading", "Volatility bands", "Support/resistance"],
            references=["https://www.investopedia.com/terms/d/donchianchannels.asp"],
            tags=["donchian", "channels", "breakout"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="upper", type="float", description="Upper channel", nullable=True),
            ColumnInfo(name="middle", type="float", description="Middle line", nullable=True),
            ColumnInfo(name="lower", type="float", description="Lower channel", nullable=True),
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        """Rolling max of highs / min of lows over the configured window."""
        highs = np.array([float(v) if v is not None else np.nan
                          for v in context.get_column("high")])
        lows = np.array([float(v) if v is not None else np.nan
                         for v in context.get_column("low")])
        window = self.params.get("length", 20)

        upper_band = np.full_like(highs, np.nan)
        lower_band = np.full_like(lows, np.nan)

        # nanmax/nanmin skip missing bars inside a window; bars without a
        # full window stay NaN.
        for end in range(window - 1, len(highs)):
            start = end - window + 1
            upper_band[end] = np.nanmax(highs[start:end + 1])
            lower_band[end] = np.nanmin(lows[start:end + 1])

        mid_band = (upper_band + lower_band) / 2

        rows = []
        for idx, t in enumerate(context.get_times()):
            rows.append({
                "time": t,
                "upper": None if np.isnan(upper_band[idx]) else float(upper_band[idx]),
                "middle": None if np.isnan(mid_band[idx]) else float(mid_band[idx]),
                "lower": None if np.isnan(lower_band[idx]) else float(lower_band[idx]),
            })

        return ComputeResult(data=rows, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class KeltnerChannels(Indicator):
    """Keltner Channels - ATR-based volatility bands.

    Middle line is an EMA of close; upper/lower bands sit a multiple of the
    smoothed ATR above and below it.
    """

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="KELTNER",
            display_name="Keltner Channels",
            description="Keltner Channels - EMA with ATR-based bands",
            category="volatility",
            parameters=[
                IndicatorParameter(
                    name="length",
                    type="int",
                    description="EMA period",
                    default=20,
                    min_value=1,
                    required=False
                ),
                IndicatorParameter(
                    name="multiplier",
                    type="float",
                    description="ATR multiplier",
                    default=2.0,
                    min_value=0.1,
                    required=False
                ),
                IndicatorParameter(
                    name="atr_length",
                    type="int",
                    description="ATR period",
                    default=10,
                    min_value=1,
                    required=False
                )
            ],
            use_cases=["Volatility bands", "Overbought/oversold", "Trend strength"],
            references=["https://www.investopedia.com/terms/k/keltnerchannel.asp"],
            tags=["keltner", "channels", "volatility", "atr"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
            ColumnInfo(name="close", type="float", description="Close price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="upper", type="float", description="Upper band", nullable=True),
            ColumnInfo(name="middle", type="float", description="Middle line (EMA)", nullable=True),
            ColumnInfo(name="lower", type="float", description="Lower band", nullable=True),
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        """EMA midline plus/minus multiplier * smoothed ATR."""
        high = np.array([float(v) if v is not None else np.nan
                         for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan
                        for v in context.get_column("low")])
        close = np.array([float(v) if v is not None else np.nan
                          for v in context.get_column("close")])

        length = self.params.get("length", 20)
        multiplier = self.params.get("multiplier", 2.0)
        atr_length = self.params.get("atr_length", 10)

        # EMA of close, seeded at the first close.
        alpha = 2.0 / (length + 1)
        ema = np.full_like(close, np.nan)
        ema[0] = close[0]
        for i in range(1, len(close)):
            ema[i] = alpha * close[i] + (1 - alpha) * ema[i - 1]

        # True range; np.roll wraps the first element, so patch tr[0].
        tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
        tr[0] = high[0] - low[0]

        # Wilder-smoothed ATR, seeded with a simple mean of the first window.
        atr = np.full_like(close, np.nan)
        atr[atr_length - 1] = np.mean(tr[:atr_length])
        for i in range(atr_length, len(tr)):
            atr[i] = (atr[i - 1] * (atr_length - 1) + tr[i]) / atr_length

        upper_band = ema + multiplier * atr
        lower_band = ema - multiplier * atr

        rows = []
        for idx, t in enumerate(context.get_times()):
            rows.append({
                "time": t,
                "upper": None if np.isnan(upper_band[idx]) else float(upper_band[idx]),
                "middle": None if np.isnan(ema[idx]) else float(ema[idx]),
                "lower": None if np.isnan(lower_band[idx]) else float(lower_band[idx]),
            })

        return ComputeResult(data=rows, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class ChaikinMoneyFlow(Indicator):
    """Chaikin Money Flow - Volume-weighted accumulation/distribution.

    CMF = rolling sum of money-flow volume / rolling sum of volume over the
    configured window.
    """

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="CMF",
            display_name="Chaikin Money Flow",
            description="Chaikin Money Flow - Measures buying and selling pressure",
            category="volume",
            parameters=[
                IndicatorParameter(
                    name="length",
                    type="int",
                    description="Period length",
                    default=20,
                    min_value=1,
                    required=False
                )
            ],
            use_cases=["Buying/selling pressure", "Trend confirmation", "Divergence analysis"],
            references=["https://www.investopedia.com/terms/c/chaikinoscillator.asp"],
            tags=["cmf", "chaikin", "volume", "money flow"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
            ColumnInfo(name="close", type="float", description="Close price"),
            ColumnInfo(name="volume", type="float", description="Volume"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="cmf", type="float", description="Chaikin Money Flow", nullable=True)
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        """Compute CMF over the configured rolling window."""
        high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
        close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
        volume = np.array([float(v) if v is not None else np.nan for v in context.get_column("volume")])
        length = self.params.get("length", 20)

        # Money Flow Multiplier: ((close-low) - (high-close)) / (high-low).
        # Fix: flat bars (high == low) divide by zero and previously emitted
        # numpy RuntimeWarnings before being patched; suppress the warnings
        # here — those positions are overwritten with 0 below, so the final
        # values are unchanged.
        with np.errstate(divide='ignore', invalid='ignore'):
            mfm = ((close - low) - (high - close)) / (high - low)
        mfm = np.where(high == low, 0, mfm)

        # Money Flow Volume
        mfv = mfm * volume

        # CMF: rolling nansum of MFV over rolling nansum of volume; bars
        # without a full window stay NaN.
        cmf = np.full_like(close, np.nan)
        for i in range(length - 1, len(close)):
            cmf[i] = np.nansum(mfv[i - length + 1:i + 1]) / np.nansum(volume[i - length + 1:i + 1])

        times = context.get_times()
        result_data = [
            {"time": times[i], "cmf": float(cmf[i]) if not np.isnan(cmf[i]) else None}
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class VortexIndicator(Indicator):
    """Vortex Indicator - Identifies trend direction and strength."""

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        """Registry metadata: one optional 'length' parameter (default 14)."""
        return IndicatorMetadata(
            name="VORTEX",
            display_name="Vortex Indicator",
            description="Vortex Indicator - Trend direction and strength",
            category="momentum",
            parameters=[
                IndicatorParameter(
                    name="length",
                    type="int",
                    description="Period length",
                    default=14,
                    min_value=1,
                    required=False
                )
            ],
            use_cases=["Trend identification", "Trend reversals", "Trend strength"],
            references=["https://www.investopedia.com/terms/v/vortex-indicator-vi.asp"],
            tags=["vortex", "trend", "momentum"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        """Requires high/low/close; volume is not used by this indicator."""
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
            ColumnInfo(name="close", type="float", description="Close price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        """Two output lines: positive (VI+) and negative (VI-) vortex."""
        return OutputSchema(columns=[
            ColumnInfo(name="vi_plus", type="float", description="Positive Vortex", nullable=True),
            ColumnInfo(name="vi_minus", type="float", description="Negative Vortex", nullable=True),
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        """Compute VI+/VI- as rolling sums of vortex movement divided by
        the rolling sum of true range.

        Bars before the first full window, and windows whose true-range
        sum is 0, are emitted as None.
        """
        # None values become NaN so numpy arithmetic can flow through them.
        high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
        close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
        length = self.params.get("length", 14)

        # Vortex Movement: |high - prev_low| and |low - prev_high|.
        # np.roll wraps index 0 around to the LAST bar, which is why the
        # first element of each series is explicitly zeroed right after.
        vm_plus = np.abs(high - np.roll(low, 1))
        vm_minus = np.abs(low - np.roll(high, 1))
        vm_plus[0] = 0
        vm_minus[0] = 0

        # True Range; the first bar has no previous close (np.roll wraps
        # again), so fall back to the plain high-low range there.
        tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
        tr[0] = high[0] - low[0]

        # Vortex Indicator
        vi_plus = np.full_like(close, np.nan)
        vi_minus = np.full_like(close, np.nan)

        for i in range(length - 1, len(close)):
            sum_vm_plus = np.sum(vm_plus[i - length + 1:i + 1])
            sum_vm_minus = np.sum(vm_minus[i - length + 1:i + 1])
            sum_tr = np.sum(tr[i - length + 1:i + 1])

            # Avoid division by zero on completely flat windows.
            if sum_tr != 0:
                vi_plus[i] = sum_vm_plus / sum_tr
                vi_minus[i] = sum_vm_minus / sum_tr

        times = context.get_times()
        # NaN values are serialized as None to keep the payload JSON-safe.
        result_data = [
            {
                "time": times[i],
                "vi_plus": float(vi_plus[i]) if not np.isnan(vi_plus[i]) else None,
                "vi_minus": float(vi_minus[i]) if not np.isnan(vi_minus[i]) else None,
            }
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class AwesomeOscillator(Indicator):
    """Awesome Oscillator - Bill Williams' momentum indicator."""

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        """Registry metadata; AO has no parameters (5/34 periods are fixed)."""
        return IndicatorMetadata(
            name="AO",
            display_name="Awesome Oscillator",
            description="Awesome Oscillator - Difference between 5 and 34 period SMAs of midpoint",
            category="momentum",
            parameters=[],
            use_cases=["Momentum shifts", "Trend reversals", "Divergence trading"],
            references=["https://www.investopedia.com/terms/a/awesomeoscillator.asp"],
            tags=["awesome", "oscillator", "momentum", "williams"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        """AO only needs the bar midpoint, hence high and low."""
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        """Single output column with the oscillator value."""
        return OutputSchema(columns=[
            ColumnInfo(name="ao", type="float", description="Awesome Oscillator", nullable=True)
        ])

    @staticmethod
    def _sma(values: np.ndarray, window: int) -> np.ndarray:
        """Simple moving average with NaN padding over the warm-up region.

        Vectorized with np.convolve instead of a per-index Python loop;
        any window containing NaN yields NaN, matching per-window np.mean.
        Inputs shorter than the window return an all-NaN array.
        """
        out = np.full_like(values, np.nan)
        if values.size >= window:
            kernel = np.ones(window) / window
            out[window - 1:] = np.convolve(values, kernel, mode="valid")
        return out

    def compute(self, context: ComputeContext) -> ComputeResult:
        """Compute AO = SMA(midpoint, 5) - SMA(midpoint, 34).

        The 5/34 periods are fixed by Bill Williams' definition, which is
        why this indicator exposes no parameters.
        """
        # None values become NaN so numpy arithmetic can flow through them.
        high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])

        midpoint = (high + low) / 2

        # Fast and slow simple moving averages of the bar midpoint.
        sma5 = self._sma(midpoint, 5)
        sma34 = self._sma(midpoint, 34)

        ao = sma5 - sma34

        times = context.get_times()
        # NaN values are serialized as None to keep the payload JSON-safe.
        result_data = [
            {"time": times[i], "ao": float(ao[i]) if not np.isnan(ao[i]) else None}
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class AcceleratorOscillator(Indicator):
    """Accelerator Oscillator - Rate of change of Awesome Oscillator."""

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        """Registry metadata; AC has no parameters (periods fixed at 5/34/5)."""
        return IndicatorMetadata(
            name="AC",
            display_name="Accelerator Oscillator",
            description="Accelerator Oscillator - Rate of change of Awesome Oscillator",
            category="momentum",
            parameters=[],
            use_cases=["Early momentum detection", "Trend acceleration", "Divergence signals"],
            references=["https://www.investopedia.com/terms/a/accelerator-oscillator.asp"],
            tags=["accelerator", "oscillator", "momentum", "williams"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        """AC derives from the bar midpoint, hence high and low only."""
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        """Single output column with the oscillator value."""
        return OutputSchema(columns=[
            ColumnInfo(name="ac", type="float", description="Accelerator Oscillator", nullable=True)
        ])

    @staticmethod
    def _sma(values: np.ndarray, window: int) -> np.ndarray:
        """Simple moving average with NaN padding over the warm-up region.

        Vectorized with np.convolve instead of a per-index Python loop;
        any window containing NaN yields NaN (this reproduces the original
        explicit isnan-window check on the AO smoothing).  Inputs shorter
        than the window return an all-NaN array.
        """
        out = np.full_like(values, np.nan)
        if values.size >= window:
            kernel = np.ones(window) / window
            out[window - 1:] = np.convolve(values, kernel, mode="valid")
        return out

    def compute(self, context: ComputeContext) -> ComputeResult:
        """Compute AC = AO - SMA(AO, 5), where AO = SMA(mid,5) - SMA(mid,34)."""
        # None values become NaN so numpy arithmetic can flow through them.
        high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])

        midpoint = (high + low) / 2

        # Awesome Oscillator first: fast minus slow SMA of the midpoint.
        ao = self._sma(midpoint, 5) - self._sma(midpoint, 34)

        # AC = AO - SMA(AO, 5)
        ac = ao - self._sma(ao, 5)

        times = context.get_times()
        # NaN values are serialized as None to keep the payload JSON-safe.
        result_data = [
            {"time": times[i], "ac": float(ac[i]) if not np.isnan(ac[i]) else None}
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class ChoppinessIndex(Indicator):
    """Choppiness Index - Determines if market is choppy or trending."""

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        """Registry metadata: one optional 'length' parameter (default 14)."""
        return IndicatorMetadata(
            name="CHOP",
            display_name="Choppiness Index",
            description="Choppiness Index - Measures market trendiness vs consolidation",
            category="volatility",
            parameters=[
                IndicatorParameter(
                    name="length",
                    type="int",
                    description="Period length",
                    default=14,
                    min_value=1,
                    required=False
                )
            ],
            use_cases=["Trend vs range identification", "Market regime detection"],
            references=["https://www.tradingview.com/support/solutions/43000501980/"],
            tags=["chop", "choppiness", "trend", "range"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        """Requires high/low/close (close feeds the true-range calculation)."""
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
            ColumnInfo(name="close", type="float", description="Close price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        """Single bounded output: the Choppiness Index value (0-100)."""
        return OutputSchema(columns=[
            ColumnInfo(name="chop", type="float", description="Choppiness Index (0-100)", nullable=True)
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        """Compute CHOP = 100 * log10(sum(TR, n) / (maxHigh - minLow)) / log10(n).

        Windows with a zero high-low spread or a zero true-range sum are
        emitted as None (log10(0) would otherwise yield -inf, which is not
        JSON-serializable).
        """
        # None values become NaN so numpy arithmetic can flow through them.
        high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
        close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
        length = self.params.get("length", 14)

        # Guard empty input before the tr[0] seeding below would raise.
        if close.size == 0:
            return ComputeResult(data=[], is_partial=context.is_incremental)

        # True Range; the first bar has no previous close (np.roll wraps
        # to the last bar), so fall back to the plain high-low range there.
        tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
        tr[0] = high[0] - low[0]

        chop = np.full_like(close, np.nan)
        log_length = np.log10(length)  # loop-invariant, hoisted out of the loop

        for i in range(length - 1, len(close)):
            window = slice(i - length + 1, i + 1)
            sum_tr = np.sum(tr[window])
            high_low_diff = np.max(high[window]) - np.min(low[window])

            # Skip degenerate windows: flat range (division by zero) and
            # zero true-range sum (log10(0) -> -inf).
            if high_low_diff != 0 and sum_tr > 0:
                chop[i] = 100 * np.log10(sum_tr / high_low_diff) / log_length

        times = context.get_times()
        # NaN values are serialized as None to keep the payload JSON-safe.
        result_data = [
            {"time": times[i], "chop": float(chop[i]) if not np.isnan(chop[i]) else None}
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class MassIndex(Indicator):
    """Mass Index - Identifies trend reversals based on range expansion."""

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        """Registry metadata: EMA smoothing period and summation window.

        NOTE(review): 'slow_period' is used in compute() as the summation
        window for the EMA ratio, not as an EMA period — the parameter
        description may be misleading; confirm intent.
        """
        return IndicatorMetadata(
            name="MASS",
            display_name="Mass Index",
            description="Mass Index - Identifies reversals when range narrows then expands",
            category="volatility",
            parameters=[
                IndicatorParameter(
                    name="fast_period",
                    type="int",
                    description="Fast EMA period",
                    default=9,
                    min_value=1,
                    required=False
                ),
                IndicatorParameter(
                    name="slow_period",
                    type="int",
                    description="Slow EMA period",
                    default=25,
                    min_value=1,
                    required=False
                )
            ],
            use_cases=["Reversal detection", "Volatility analysis", "Bulge identification"],
            references=["https://www.investopedia.com/terms/m/mass-index.asp"],
            tags=["mass", "index", "volatility", "reversal"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        """Only the high-low range is needed, hence high and low columns."""
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        """Single output column with the Mass Index value."""
        return OutputSchema(columns=[
            ColumnInfo(name="mass", type="float", description="Mass Index", nullable=True)
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        """Compute Mass Index = rolling sum of EMA(range)/EMA(EMA(range)).

        The single and double EMAs both use the fast_period smoothing
        factor; slow_period sets the window of the final rolling sum.
        """
        # None values become NaN so numpy arithmetic can flow through them.
        high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])

        fast_period = self.params.get("fast_period", 9)
        slow_period = self.params.get("slow_period", 25)

        hl_range = high - low

        # Single EMA of the high-low range, seeded with the first value.
        alpha1 = 2.0 / (fast_period + 1)
        ema1 = np.full_like(hl_range, np.nan)
        ema1[0] = hl_range[0]
        for i in range(1, len(hl_range)):
            ema1[i] = alpha1 * hl_range[i] + (1 - alpha1) * ema1[i - 1]

        # Double EMA: smooth ema1 again with the same alpha.
        # NOTE(review): if ema1 ever yields NaN, ema2 stays NaN for all
        # later bars because ema2[i-1] is NaN — confirm inputs are gap-free.
        ema2 = np.full_like(ema1, np.nan)
        ema2[0] = ema1[0]
        for i in range(1, len(ema1)):
            if not np.isnan(ema1[i]):
                ema2[i] = alpha1 * ema1[i] + (1 - alpha1) * ema2[i - 1]

        # EMA Ratio
        # NOTE(review): ema2 can be 0 when the range is persistently zero,
        # making the ratio inf — confirm upstream data rules this out.
        ema_ratio = ema1 / ema2

        # Mass Index: rolling sum of the ratio over slow_period bars.
        mass = np.full_like(hl_range, np.nan)
        for i in range(slow_period - 1, len(ema_ratio)):
            mass[i] = np.nansum(ema_ratio[i - slow_period + 1:i + 1])

        times = context.get_times()
        # NaN values are serialized as None to keep the payload JSON-safe.
        result_data = [
            {"time": times[i], "mass": float(mass[i]) if not np.isnan(mass[i]) else None}
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
# Registry of all custom indicators
|
||||
# Registry of all custom indicators
# Every class listed here is registered by register_custom_indicators();
# the set mirrors CUSTOM_TO_TV_NAMES in tv_mapping.py — keep them in sync.
CUSTOM_INDICATORS = [
    VWAP,
    VWMA,
    HullMA,
    SuperTrend,
    DonchianChannels,
    KeltnerChannels,
    ChaikinMoneyFlow,
    VortexIndicator,
    AwesomeOscillator,
    AcceleratorOscillator,
    ChoppinessIndex,
    MassIndex,
]
|
||||
|
||||
|
||||
def register_custom_indicators(registry) -> int:
    """
    Register every class in CUSTOM_INDICATORS with the registry.

    Failures are logged and skipped so a single broken indicator cannot
    prevent the remaining ones from loading.

    Args:
        registry: IndicatorRegistry instance

    Returns:
        Number of indicators registered
    """
    count = 0

    for indicator_cls in CUSTOM_INDICATORS:
        try:
            registry.register(indicator_cls)
        except Exception as exc:
            logger.warning(f"Failed to register custom indicator {indicator_cls.__name__}: {exc}")
        else:
            count += 1
            logger.debug(f"Registered custom indicator: {indicator_cls.__name__}")

    logger.info(f"Registered {count} custom indicators")
    return count
|
||||
@@ -372,12 +372,14 @@ def create_talib_indicator_class(func_name: str) -> type:
|
||||
)
|
||||
|
||||
|
||||
def register_all_talib_indicators(registry) -> int:
|
||||
def register_all_talib_indicators(registry, only_tradingview_supported: bool = True) -> int:
|
||||
"""
|
||||
Auto-register all available TA-Lib indicators with the registry.
|
||||
|
||||
Args:
|
||||
registry: IndicatorRegistry instance
|
||||
only_tradingview_supported: If True, only register indicators that have
|
||||
TradingView equivalents (default: True)
|
||||
|
||||
Returns:
|
||||
Number of indicators registered
|
||||
@@ -392,6 +394,9 @@ def register_all_talib_indicators(registry) -> int:
|
||||
)
|
||||
return 0
|
||||
|
||||
# Get list of supported indicators if filtering is enabled
|
||||
from .tv_mapping import is_indicator_supported
|
||||
|
||||
# Get all TA-Lib functions
|
||||
func_groups = talib.get_function_groups()
|
||||
all_functions = []
|
||||
@@ -402,8 +407,16 @@ def register_all_talib_indicators(registry) -> int:
|
||||
all_functions = sorted(set(all_functions))
|
||||
|
||||
registered_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
for func_name in all_functions:
|
||||
try:
|
||||
# Skip if filtering enabled and indicator not supported in TradingView
|
||||
if only_tradingview_supported and not is_indicator_supported(func_name):
|
||||
skipped_count += 1
|
||||
logger.debug(f"Skipping TA-Lib function {func_name} - not supported in TradingView")
|
||||
continue
|
||||
|
||||
# Create indicator class for this function
|
||||
indicator_class = create_talib_indicator_class(func_name)
|
||||
|
||||
@@ -415,7 +428,7 @@ def register_all_talib_indicators(registry) -> int:
|
||||
logger.warning(f"Failed to register TA-Lib function {func_name}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Registered {registered_count} TA-Lib indicators")
|
||||
logger.info(f"Registered {registered_count} TA-Lib indicators (skipped {skipped_count} unsupported)")
|
||||
return registered_count
|
||||
|
||||
|
||||
|
||||
360
backend/src/indicator/tv_mapping.py
Normal file
360
backend/src/indicator/tv_mapping.py
Normal file
@@ -0,0 +1,360 @@
|
||||
"""
|
||||
Mapping layer between TA-Lib indicators and TradingView indicators.
|
||||
|
||||
This module provides bidirectional conversion between our internal TA-Lib-based
|
||||
indicator representation and TradingView's indicator system.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, Tuple, List
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Mapping of TA-Lib indicator names to TradingView indicator names
|
||||
# Only includes indicators that are present in BOTH systems (inner join)
|
||||
# Format: {talib_name: tv_name}
|
||||
TALIB_TO_TV_NAMES = {
    # Overlap Studies (14)
    "SMA": "Moving Average",
    "EMA": "Moving Average Exponential",
    "WMA": "Weighted Moving Average",
    "DEMA": "DEMA",
    "TEMA": "TEMA",
    "TRIMA": "Triangular Moving Average",
    "KAMA": "KAMA",
    "MAMA": "MESA Adaptive Moving Average",
    "T3": "T3",
    "BBANDS": "Bollinger Bands",
    "MIDPOINT": "Midpoint",
    "MIDPRICE": "Midprice",
    "SAR": "Parabolic SAR",
    "HT_TRENDLINE": "Hilbert Transform - Instantaneous Trendline",

    # Momentum Indicators (21)
    "RSI": "Relative Strength Index",
    "MOM": "Momentum",
    "ROC": "Rate of Change",
    "TRIX": "TRIX",
    "CMO": "Chande Momentum Oscillator",
    "DX": "Directional Movement Index",
    "ADX": "Average Directional Movement Index",
    "ADXR": "Average Directional Movement Index Rating",
    "APO": "Absolute Price Oscillator",
    "PPO": "Percentage Price Oscillator",
    "MACD": "MACD",
    "MFI": "Money Flow Index",
    "STOCH": "Stochastic",
    "STOCHF": "Stochastic Fast",
    "STOCHRSI": "Stochastic RSI",
    "WILLR": "Williams %R",
    "CCI": "Commodity Channel Index",
    "AROON": "Aroon",
    "AROONOSC": "Aroon Oscillator",
    "BOP": "Balance Of Power",
    "ULTOSC": "Ultimate Oscillator",

    # Volume Indicators (3)
    "AD": "Chaikin A/D Line",
    "ADOSC": "Chaikin A/D Oscillator",
    "OBV": "On Balance Volume",

    # Volatility Indicators (3)
    "ATR": "Average True Range",
    "NATR": "Normalized Average True Range",
    "TRANGE": "True Range",

    # Price Transform (4)
    "AVGPRICE": "Average Price",
    "MEDPRICE": "Median Price",
    "TYPPRICE": "Typical Price",
    "WCLPRICE": "Weighted Close Price",

    # Cycle Indicators (5)
    "HT_DCPERIOD": "Hilbert Transform - Dominant Cycle Period",
    "HT_DCPHASE": "Hilbert Transform - Dominant Cycle Phase",
    "HT_PHASOR": "Hilbert Transform - Phasor Components",
    "HT_SINE": "Hilbert Transform - SineWave",
    "HT_TRENDMODE": "Hilbert Transform - Trend vs Cycle Mode",

    # Statistic Functions (9)
    "BETA": "Beta",
    "CORREL": "Pearson's Correlation Coefficient",
    "LINEARREG": "Linear Regression",
    "LINEARREG_ANGLE": "Linear Regression Angle",
    "LINEARREG_INTERCEPT": "Linear Regression Intercept",
    "LINEARREG_SLOPE": "Linear Regression Slope",
    "STDDEV": "Standard Deviation",
    "TSF": "Time Series Forecast",
    "VAR": "Variance",
}

# Total: 59 indicators supported in both systems
# (14 overlap + 21 momentum + 3 volume + 3 volatility + 4 price
#  + 5 cycle + 9 statistic)

# Custom indicators (TradingView indicators implemented in our backend)
CUSTOM_TO_TV_NAMES = {
    "VWAP": "VWAP",
    "VWMA": "VWMA",
    "HMA": "Hull Moving Average",
    "SUPERTREND": "SuperTrend",
    "DONCHIAN": "Donchian Channels",
    "KELTNER": "Keltner Channels",
    "CMF": "Chaikin Money Flow",
    "VORTEX": "Vortex Indicator",
    "AO": "Awesome Oscillator",
    "AC": "Accelerator Oscillator",
    "CHOP": "Choppiness Index",
    "MASS": "Mass Index",
}

# Combined mapping (TA-Lib + Custom)
ALL_BACKEND_TO_TV_NAMES = {**TALIB_TO_TV_NAMES, **CUSTOM_TO_TV_NAMES}

# Total: 71 indicators (59 TA-Lib + 12 Custom)

# Reverse mapping (TV name -> backend name); TV display names are unique
# across both tables, so no entries are lost in the inversion.
TV_TO_TALIB_NAMES = {v: k for k, v in TALIB_TO_TV_NAMES.items()}
TV_TO_CUSTOM_NAMES = {v: k for k, v in CUSTOM_TO_TV_NAMES.items()}
TV_TO_BACKEND_NAMES = {v: k for k, v in ALL_BACKEND_TO_TV_NAMES.items()}
|
||||
|
||||
|
||||
def get_tv_indicator_name(talib_name: str) -> str:
    """
    Map a TA-Lib indicator name onto its TradingView equivalent.

    Args:
        talib_name: TA-Lib indicator name (e.g., 'RSI')

    Returns:
        The mapped TradingView name, or *talib_name* unchanged when no
        mapping exists.
    """
    try:
        return TALIB_TO_TV_NAMES[talib_name]
    except KeyError:
        return talib_name
|
||||
|
||||
|
||||
def get_talib_indicator_name(tv_name: str) -> Optional[str]:
    """
    Map a TradingView indicator name back to its TA-Lib name.

    Args:
        tv_name: TradingView indicator name

    Returns:
        TA-Lib indicator name, or None when the name is not mapped.
    """
    if tv_name in TV_TO_TALIB_NAMES:
        return TV_TO_TALIB_NAMES[tv_name]
    return None
|
||||
|
||||
|
||||
def convert_talib_params_to_tv_inputs(
    talib_name: str,
    talib_params: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Translate TA-Lib parameters into TradingView input format.

    Well-known indicators get explicit translations (with TradingView's
    defaults filled in); anything else falls back to a generic per-name
    rename that passes unknown parameters through unchanged.

    Args:
        talib_name: TA-Lib indicator name
        talib_params: TA-Lib parameter dictionary

    Returns:
        TradingView inputs dictionary
    """
    moving_averages = ("SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA")

    if talib_name == "BBANDS":
        # Bollinger Bands
        tv_inputs = {
            "length": talib_params.get("timeperiod", 20),
            "mult": talib_params.get("nbdevup", 2),
            "source": "close",
        }
    elif talib_name == "MACD":
        # MACD
        tv_inputs = {
            "fastLength": talib_params.get("fastperiod", 12),
            "slowLength": talib_params.get("slowperiod", 26),
            "signalLength": talib_params.get("signalperiod", 9),
            "source": "close",
        }
    elif talib_name == "RSI":
        # RSI
        tv_inputs = {
            "length": talib_params.get("timeperiod", 14),
            "source": "close",
        }
    elif talib_name in moving_averages:
        # Moving averages
        tv_inputs = {
            "length": talib_params.get("timeperiod", 14),
            "source": "close",
        }
    elif talib_name == "STOCH":
        # Stochastic
        tv_inputs = {
            "kPeriod": talib_params.get("fastk_period", 14),
            "dPeriod": talib_params.get("slowd_period", 3),
            "smoothK": talib_params.get("slowk_period", 3),
        }
    elif talib_name == "ATR":
        # ATR
        tv_inputs = {"length": talib_params.get("timeperiod", 14)}
    elif talib_name == "CCI":
        # CCI
        tv_inputs = {"length": talib_params.get("timeperiod", 20)}
    else:
        # Generic fallback: rename each TA-Lib parameter to its usual
        # TradingView counterpart; unknown names pass through as-is.
        # Note nbdevup/nbdevdn both map to "mult", so the later one wins.
        param_mapping = {
            "timeperiod": "length",
            "fastperiod": "fastLength",
            "slowperiod": "slowLength",
            "signalperiod": "signalLength",
            "nbdevup": "mult",  # Standard deviations for upper band
            "nbdevdn": "mult",  # Standard deviations for lower band
            "fastlimit": "fastLimit",
            "slowlimit": "slowLimit",
            "acceleration": "start",
            "maximum": "increment",
            "fastk_period": "kPeriod",
            "slowk_period": "kPeriod",
            "slowd_period": "dPeriod",
            "fastd_period": "dPeriod",
            "matype": "maType",
        }
        tv_inputs = {
            param_mapping.get(name, name): value
            for name, value in talib_params.items()
        }

    logger.debug(f"Converted TA-Lib params for {talib_name}: {talib_params} -> TV inputs: {tv_inputs}")
    return tv_inputs
|
||||
|
||||
|
||||
def convert_tv_inputs_to_talib_params(
    tv_name: str,
    tv_inputs: Dict[str, Any]
) -> Tuple[Optional[str], Dict[str, Any]]:
    """
    Convert TradingView inputs to TA-Lib parameters.

    Well-known indicators get explicit translations with TA-Lib defaults
    filled in; other mapped indicators fall back to a generic per-name
    rename.  Unmapped TradingView names return (None, {}).

    Args:
        tv_name: TradingView indicator name
        tv_inputs: TradingView inputs dictionary

    Returns:
        Tuple of (talib_name, talib_params); talib_name is None when the
        TradingView name has no TA-Lib mapping.
    """
    talib_name = get_talib_indicator_name(tv_name)
    if not talib_name:
        logger.warning(f"No TA-Lib mapping for TradingView indicator: {tv_name}")
        return None, {}

    talib_params = {}

    # Reverse parameter mappings (only used by the generic fallback below).
    reverse_mapping = {
        "length": "timeperiod",
        "fastLength": "fastperiod",
        "slowLength": "slowperiod",
        "signalLength": "signalperiod",
        "mult": "nbdevup",  # Use same for both up and down
        "fastLimit": "fastlimit",
        "slowLimit": "slowlimit",
        "start": "acceleration",
        "increment": "maximum",
        "kPeriod": "fastk_period",
        "dPeriod": "slowd_period",
        "smoothK": "slowk_period",
        "maType": "matype",
    }

    # Special handling for specific indicators
    if talib_name == "BBANDS":
        # Bollinger Bands: TV exposes one "mult"; apply it to both bands.
        talib_params["timeperiod"] = tv_inputs.get("length", 20)
        talib_params["nbdevup"] = tv_inputs.get("mult", 2)
        talib_params["nbdevdn"] = tv_inputs.get("mult", 2)
        talib_params["matype"] = 0  # SMA
    elif talib_name == "MACD":
        # MACD
        talib_params["fastperiod"] = tv_inputs.get("fastLength", 12)
        talib_params["slowperiod"] = tv_inputs.get("slowLength", 26)
        talib_params["signalperiod"] = tv_inputs.get("signalLength", 9)
    elif talib_name == "RSI":
        # RSI
        talib_params["timeperiod"] = tv_inputs.get("length", 14)
    elif talib_name in ["SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA"]:
        # Moving averages
        talib_params["timeperiod"] = tv_inputs.get("length", 14)
    elif talib_name == "STOCH":
        # Stochastic: both smoothing MA types pinned to SMA.
        talib_params["fastk_period"] = tv_inputs.get("kPeriod", 14)
        talib_params["slowd_period"] = tv_inputs.get("dPeriod", 3)
        talib_params["slowk_period"] = tv_inputs.get("smoothK", 3)
        talib_params["slowk_matype"] = 0  # SMA
        talib_params["slowd_matype"] = 0  # SMA
    elif talib_name == "ATR":
        # ATR
        talib_params["timeperiod"] = tv_inputs.get("length", 14)
    elif talib_name == "CCI":
        # CCI
        talib_params["timeperiod"] = tv_inputs.get("length", 20)
    else:
        # Generic parameter conversion; "source" has no TA-Lib equivalent.
        for tv_param, value in tv_inputs.items():
            if tv_param == "source":
                continue  # Skip source parameter
            talib_param = reverse_mapping.get(tv_param, tv_param)
            talib_params[talib_param] = value

    logger.debug(f"Converted TV inputs for {tv_name}: {tv_inputs} -> TA-Lib {talib_name} params: {talib_params}")
    return talib_name, talib_params
|
||||
|
||||
|
||||
def is_indicator_supported(talib_name: str) -> bool:
    """
    Report whether a TA-Lib indicator has a TradingView equivalent.

    Args:
        talib_name: TA-Lib indicator name

    Returns:
        True if supported
    """
    supported_names = TALIB_TO_TV_NAMES
    return talib_name in supported_names
|
||||
|
||||
|
||||
def get_supported_indicators() -> List[str]:
    """
    Enumerate the TA-Lib indicators with TradingView equivalents.

    Returns:
        List of TA-Lib indicator names
    """
    return [name for name in TALIB_TO_TV_NAMES]
|
||||
|
||||
|
||||
def get_supported_indicator_count() -> int:
    """
    Count the indicators available in both systems.

    Returns:
        Number of indicators supported in both systems (TA-Lib + Custom)
    """
    combined_mapping = ALL_BACKEND_TO_TV_NAMES
    return len(combined_mapping)
|
||||
|
||||
|
||||
def is_custom_indicator(indicator_name: str) -> bool:
    """
    Report whether an indicator is one of our custom (non-TA-Lib) ones.

    Args:
        indicator_name: Indicator name

    Returns:
        True if custom indicator
    """
    return indicator_name in CUSTOM_TO_TV_NAMES.keys()
|
||||
|
||||
|
||||
def get_backend_indicator_name(tv_name: str) -> Optional[str]:
    """
    Resolve a TradingView name to its backend name (TA-Lib or custom).

    Args:
        tv_name: TradingView indicator name

    Returns:
        Backend indicator name or None if not mapped
    """
    try:
        return TV_TO_BACKEND_NAMES[tv_name]
    except KeyError:
        return None
|
||||
@@ -24,11 +24,12 @@ from agent.tools import set_registry, set_datasource_registry, set_indicator_reg
|
||||
from schema.order_spec import SwapOrder
|
||||
from schema.chart_state import ChartState
|
||||
from schema.shape import ShapeCollection
|
||||
from schema.indicator import IndicatorCollection
|
||||
from datasource.registry import DataSourceRegistry
|
||||
from datasource.subscription_manager import SubscriptionManager
|
||||
from datasource.websocket_handler import DatafeedWebSocketHandler
|
||||
from secrets_manager import SecretsStore, InvalidMasterPassword
|
||||
from indicator import IndicatorRegistry, register_all_talib_indicators
|
||||
from indicator import IndicatorRegistry, register_all_talib_indicators, register_custom_indicators
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@@ -93,6 +94,13 @@ async def lifespan(app: FastAPI):
|
||||
logger.warning(f"Failed to register TA-Lib indicators: {e}")
|
||||
logger.info("TA-Lib indicators will not be available. Install TA-Lib C library and Python wrapper to enable.")
|
||||
|
||||
# Register custom indicators (TradingView indicators not in TA-Lib)
|
||||
try:
|
||||
custom_count = register_custom_indicators(indicator_registry)
|
||||
logger.info(f"Registered {custom_count} custom indicators")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register custom indicators: {e}")
|
||||
|
||||
# Get API keys from secrets store if unlocked, otherwise fall back to environment
|
||||
anthropic_api_key = None
|
||||
|
||||
@@ -164,15 +172,21 @@ class ChartStore(BaseModel):
|
||||
class ShapeStore(BaseModel):
|
||||
shapes: dict[str, dict] = {} # Dictionary of shapes keyed by ID
|
||||
|
||||
# IndicatorStore model for synchronization
class IndicatorStore(BaseModel):
    """Sync-registry store holding the chart's indicator instances."""
    indicators: dict[str, dict] = {}  # Dictionary of indicators keyed by ID
||||
|
||||
# Initialize stores
|
||||
order_store = OrderStore()
|
||||
chart_store = ChartStore()
|
||||
shape_store = ShapeStore()
|
||||
indicator_store = IndicatorStore()
|
||||
|
||||
# Register with SyncRegistry
|
||||
registry.register(order_store, store_name="OrderStore")
|
||||
registry.register(chart_store, store_name="ChartStore")
|
||||
registry.register(shape_store, store_name="ShapeStore")
|
||||
registry.register(indicator_store, store_name="IndicatorStore")
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
40
backend/src/schema/indicator.py
Normal file
40
backend/src/schema/indicator.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class IndicatorInstance(BaseModel):
|
||||
"""
|
||||
Represents an instance of an indicator applied to a chart.
|
||||
|
||||
This schema holds both the TA-Lib metadata and TradingView-specific data
|
||||
needed for synchronization.
|
||||
"""
|
||||
id: str = Field(..., description="Unique identifier for this indicator instance")
|
||||
|
||||
# TA-Lib metadata
|
||||
talib_name: str = Field(..., description="TA-Lib indicator name (e.g., 'RSI', 'SMA', 'MACD')")
|
||||
instance_name: str = Field(..., description="User-friendly instance name")
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="TA-Lib indicator parameters")
|
||||
|
||||
# TradingView metadata
|
||||
tv_study_id: Optional[str] = Field(default=None, description="TradingView study ID assigned by the chart widget")
|
||||
tv_indicator_name: Optional[str] = Field(default=None, description="TradingView indicator name if different from TA-Lib")
|
||||
tv_inputs: Optional[Dict[str, Any]] = Field(default=None, description="TradingView-specific input parameters")
|
||||
|
||||
# Visual properties
|
||||
visible: bool = Field(default=True, description="Whether indicator is visible on chart")
|
||||
pane: str = Field(default="chart", description="Pane where indicator is displayed ('chart' or 'separate')")
|
||||
|
||||
# Metadata
|
||||
symbol: Optional[str] = Field(default=None, description="Symbol this indicator is applied to")
|
||||
created_at: Optional[int] = Field(default=None, description="Creation timestamp (Unix seconds)")
|
||||
modified_at: Optional[int] = Field(default=None, description="Last modification timestamp (Unix seconds)")
|
||||
original_id: Optional[str] = Field(default=None, description="Original ID from backend before TradingView assigns its own ID")
|
||||
|
||||
|
||||
class IndicatorCollection(BaseModel):
|
||||
"""Collection of all indicator instances on the chart."""
|
||||
indicators: Dict[str, IndicatorInstance] = Field(
|
||||
default_factory=dict,
|
||||
description="Dictionary of indicator instances keyed by ID"
|
||||
)
|
||||
@@ -116,6 +116,10 @@ class SyncRegistry:
|
||||
logger.info(f"apply_client_patch: New state after patch: {new_state}")
|
||||
self._update_model(entry.model, new_state)
|
||||
|
||||
# Verify the model was actually updated
|
||||
updated_state = entry.model.model_dump(mode="json")
|
||||
logger.info(f"apply_client_patch: Model state after _update_model: {updated_state}")
|
||||
|
||||
entry.commit_patch(patch)
|
||||
logger.info(f"apply_client_patch: Patch committed, new seq={entry.seq}")
|
||||
# Don't broadcast back to client - they already have this change
|
||||
@@ -206,7 +210,37 @@ class SyncRegistry:
|
||||
await self.websocket.send_json(msg.model_dump(mode="json"))
|
||||
|
||||
def _update_model(self, model: BaseModel, new_data: Dict[str, Any]):
|
||||
# Update model using model_validate for potentially nested models
|
||||
new_model = model.__class__.model_validate(new_data)
|
||||
for field in model.model_fields:
|
||||
setattr(model, field, getattr(new_model, field))
|
||||
# Update model fields in-place to preserve references
|
||||
# This is important for dict fields that may be referenced elsewhere
|
||||
for field_name, field_info in model.model_fields.items():
|
||||
if field_name in new_data:
|
||||
new_value = new_data[field_name]
|
||||
current_value = getattr(model, field_name)
|
||||
|
||||
# For dict fields, update in-place instead of replacing
|
||||
if isinstance(current_value, dict) and isinstance(new_value, dict):
|
||||
self._deep_update_dict(current_value, new_value)
|
||||
else:
|
||||
# For other types, just set the new value
|
||||
setattr(model, field_name, new_value)
|
||||
|
||||
def _deep_update_dict(self, target: dict, source: dict):
|
||||
"""Deep update target dict with source dict, preserving nested dict references."""
|
||||
# Remove keys that are in target but not in source
|
||||
keys_to_remove = set(target.keys()) - set(source.keys())
|
||||
for key in keys_to_remove:
|
||||
del target[key]
|
||||
|
||||
# Update or add keys from source
|
||||
for key, source_value in source.items():
|
||||
if key in target:
|
||||
target_value = target[key]
|
||||
# If both are dicts, recursively update
|
||||
if isinstance(target_value, dict) and isinstance(source_value, dict):
|
||||
self._deep_update_dict(target_value, source_value)
|
||||
else:
|
||||
# Replace the value
|
||||
target[key] = source_value
|
||||
else:
|
||||
# Add new key
|
||||
target[key] = source_value
|
||||
|
||||
Reference in New Issue
Block a user