backend redesign

2026-03-11 18:47:11 -04:00
parent 8ff277c8c6
commit e99ef5d2dd
210 changed files with 12147 additions and 155 deletions


@@ -0,0 +1,179 @@
"""
Composable Indicator System.
Provides a framework for building directed acyclic graphs (DAGs) of
data transformations that process time-series data incrementally.
Indicators can consume DataSources or other Indicators as inputs,
composing into arbitrarily complex processing graphs.
Key Components:
---------------
Indicator (base.py):
Abstract base class for all indicator implementations.
Declares input/output schemas and implements synchronous compute().
IndicatorRegistry (registry.py):
Central catalog of available indicators with rich metadata
for AI agent discovery and tool generation.
Pipeline (pipeline.py):
Execution engine that builds DAGs, resolves dependencies,
and orchestrates incremental data flow through indicator chains.
Schema Types (schema.py):
Type definitions for input/output schemas, computation context,
and metadata for AI-native documentation.
Usage Example:
--------------
from indicator import Indicator, IndicatorRegistry, Pipeline
from indicator.schema import (
InputSchema, OutputSchema, ComputeContext, ComputeResult,
IndicatorMetadata, IndicatorParameter
)
from datasource.schema import ColumnInfo
# Define an indicator
class SimpleMovingAverage(Indicator):
@classmethod
def get_metadata(cls):
return IndicatorMetadata(
name="SMA",
display_name="Simple Moving Average",
description="Arithmetic mean of prices over N periods",
category="trend",
parameters=[
IndicatorParameter(
name="period",
type="int",
description="Number of periods to average",
default=20,
min_value=1
)
],
tags=["moving-average", "trend-following"]
)
@classmethod
def get_input_schema(cls):
return InputSchema(
required_columns=[
ColumnInfo(name="close", type="float", description="Closing price")
]
)
@classmethod
def get_output_schema(cls, **params):
return OutputSchema(
columns=[
ColumnInfo(
name="sma",
type="float",
description=f"Simple moving average over {params.get('period', 20)} periods"
)
]
)
def compute(self, context: ComputeContext) -> ComputeResult:
period = self.params["period"]
closes = context.get_column("close")
times = context.get_times()
sma_values = []
for i in range(len(closes)):
if i < period - 1:
sma_values.append(None)
else:
window = closes[i - period + 1 : i + 1]
sma_values.append(sum(window) / period)
return ComputeResult(
data=[
{"time": times[i], "sma": sma_values[i]}
for i in range(len(times))
]
)
# Register the indicator
registry = IndicatorRegistry()
registry.register(SimpleMovingAverage)
# Create a pipeline
pipeline = Pipeline(datasource_registry)
pipeline.add_datasource("price_data", "ccxt", "BTC/USD", "1D")
sma_indicator = registry.create_instance("SMA", "sma_20", period=20)
pipeline.add_indicator("sma_20", sma_indicator, input_node_ids=["price_data"])
# Execute
results = pipeline.execute(datasource_data={"price_data": price_bars})
sma_output = results["sma_20"] # Contains columns: time, close, sma_20_sma
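Incremental Update Example:
---------------------------
# Illustrative sketch reusing the objects above; `latest_bar` stands in
# for the newest OHLCV row received.
new_bars = price_bars + [latest_bar]
results = pipeline.execute(
    datasource_data={"price_data": new_bars},
    incremental=True,
    updated_from_time=latest_bar["time"],
)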
Design Philosophy:
------------------
1. **Schema-based composition**: Indicators declare inputs/outputs via schemas,
enabling automatic validation and flexible composition.
2. **Synchronous execution**: All computation is synchronous for simplicity.
Async handling happens at the event/strategy layer.
3. **Incremental updates**: Indicators receive context about what changed,
allowing optimized recomputation of only affected values.
4. **AI-native metadata**: Rich descriptions, use cases, and parameter specs
make indicators discoverable and usable by AI agents.
5. **Generic data flow**: Indicators work with any data source that matches
their input schema, not specific DataSource instances.
6. **Event-driven**: Designed to react to DataSource updates and propagate
changes through the DAG efficiently.
"""
from .base import DataSourceAdapter, Indicator
from .pipeline import Pipeline, PipelineNode
from .registry import IndicatorRegistry
from .schema import (
ComputeContext,
ComputeResult,
IndicatorMetadata,
IndicatorParameter,
InputSchema,
OutputSchema,
)
from .talib_adapter import (
TALibIndicator,
register_all_talib_indicators,
is_talib_available,
get_talib_version,
)
from .custom_indicators import (
register_custom_indicators,
CUSTOM_INDICATORS,
)
__all__ = [
# Core classes
"Indicator",
"IndicatorRegistry",
"Pipeline",
"PipelineNode",
"DataSourceAdapter",
# Schema types
"InputSchema",
"OutputSchema",
"ComputeContext",
"ComputeResult",
"IndicatorMetadata",
"IndicatorParameter",
# TA-Lib integration
"TALibIndicator",
"register_all_talib_indicators",
"is_talib_available",
"get_talib_version",
# Custom indicators
"register_custom_indicators",
"CUSTOM_INDICATORS",
]


@@ -0,0 +1,230 @@
"""
Abstract Indicator interface.
Provides the base class for all technical indicators and derived data transformations.
Indicators compose into DAGs, processing data incrementally as updates arrive.
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from .schema import (
ComputeContext,
ComputeResult,
IndicatorMetadata,
InputSchema,
OutputSchema,
)
class Indicator(ABC):
"""
Abstract base class for all indicators.
Indicators are composable transformation nodes that:
- Declare input schema (columns they need)
- Declare output schema (columns they produce)
- Compute outputs synchronously from inputs
- Support incremental updates (process only what changed)
- Provide rich metadata for AI agent discovery
Indicators are stateless at the instance level - all state is managed
by the pipeline execution engine. This allows the same indicator class
to be reused with different parameters.
"""
def __init__(self, instance_name: str, **params):
"""
Initialize an indicator instance.
Args:
instance_name: Unique name for this instance (used for output column prefixing)
**params: Configuration parameters (validated against metadata.parameters)
"""
self.instance_name = instance_name
self.params = params
self._validate_params()
@classmethod
@abstractmethod
def get_metadata(cls) -> IndicatorMetadata:
"""
Get metadata for this indicator class.
Called by the registry for AI agent discovery and documentation.
Should return comprehensive information about the indicator's purpose,
parameters, and use cases.
Returns:
IndicatorMetadata describing this indicator class
"""
pass
@classmethod
@abstractmethod
def get_input_schema(cls) -> InputSchema:
"""
Get the input schema required by this indicator.
Declares what columns must be present in the input data.
The pipeline will match this against available data sources.
Returns:
InputSchema describing required and optional input columns
"""
pass
@classmethod
@abstractmethod
def get_output_schema(cls, **params) -> OutputSchema:
"""
Get the output schema produced by this indicator.
Output column names will be automatically prefixed with the instance name
by the pipeline engine.
Args:
**params: Configuration parameters (may affect output schema)
Returns:
OutputSchema describing the columns this indicator produces
"""
pass
@abstractmethod
def compute(self, context: ComputeContext) -> ComputeResult:
"""
Compute indicator values from input data.
This method is called synchronously by the pipeline engine whenever
input data changes. Implementations should:
1. Extract needed columns from context.data
2. Perform calculations
3. Return results with proper time alignment
For incremental updates (context.is_incremental == True):
- context.data contains only new/updated rows
- Implementations MAY optimize by computing only these rows
- OR implementations MAY recompute everything (simpler but slower)
Args:
context: Input data and update metadata
Returns:
ComputeResult with calculated indicator values
Raises:
ValueError: If input data doesn't match expected schema
"""
pass
def _validate_params(self) -> None:
"""
Validate that provided parameters match the metadata specification.
Raises:
ValueError: If required parameters are missing or invalid
"""
metadata = self.get_metadata()
# Check for required parameters
for param_def in metadata.parameters:
if param_def.required and param_def.name not in self.params:
raise ValueError(
f"Indicator '{metadata.name}' requires parameter '{param_def.name}'"
)
# Validate parameter types and ranges
for name, value in self.params.items():
# Find parameter definition
param_def = next(
(p for p in metadata.parameters if p.name == name),
None
)
if param_def is None:
raise ValueError(
f"Unknown parameter '{name}' for indicator '{metadata.name}'"
)
# Type checking
if param_def.type == "int" and not isinstance(value, int):
raise ValueError(
f"Parameter '{name}' must be int, got {type(value).__name__}"
)
elif param_def.type == "float" and not isinstance(value, (int, float)):
raise ValueError(
f"Parameter '{name}' must be float, got {type(value).__name__}"
)
elif param_def.type == "bool" and not isinstance(value, bool):
raise ValueError(
f"Parameter '{name}' must be bool, got {type(value).__name__}"
)
elif param_def.type == "string" and not isinstance(value, str):
raise ValueError(
f"Parameter '{name}' must be string, got {type(value).__name__}"
)
# Range checking for numeric types
if param_def.type in ("int", "float"):
if param_def.min_value is not None and value < param_def.min_value:
raise ValueError(
f"Parameter '{name}' must be >= {param_def.min_value}, got {value}"
)
if param_def.max_value is not None and value > param_def.max_value:
raise ValueError(
f"Parameter '{name}' must be <= {param_def.max_value}, got {value}"
)
def get_output_columns(self) -> List[str]:
"""
Get the output column names with instance name prefix.
Returns:
List of prefixed output column names
"""
output_schema = self.get_output_schema(**self.params)
prefixed = output_schema.with_prefix(self.instance_name)
return [col.name for col in prefixed.columns if col.name != output_schema.time_column]
def __repr__(self) -> str:
return f"{self.__class__.__name__}(instance_name='{self.instance_name}', params={self.params})"
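# Illustrative sketch: because instances hold only configuration, never
# computed state, one Indicator subclass can back many pipeline nodes.
# Using the hypothetical SimpleMovingAverage from the package docstring:
#
#     sma_fast = SimpleMovingAverage("sma_10", period=10)
#     sma_slow = SimpleMovingAverage("sma_50", period=50)
#     # Both share compute(); the pipeline caches each node's output separately.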
class DataSourceAdapter:
"""
Adapter to make a DataSource look like an Indicator for pipeline composition.
This allows DataSources to be inputs to indicators in a unified way.
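Example (illustrative):
    DataSourceAdapter("ccxt", "BTC/USD", "1D").instance_name  # -> "ds_ccxt_BTC_USD_1D"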
"""
def __init__(self, datasource_id: str, symbol: str, resolution: str):
"""
Create a DataSource adapter.
Args:
datasource_id: Identifier for the datasource (e.g., 'ccxt', 'demo')
symbol: Symbol to query (e.g., 'BTC/USD')
resolution: Time resolution (e.g., '1', '5', '1D')
"""
self.datasource_id = datasource_id
self.symbol = symbol
self.resolution = resolution
self.instance_name = f"ds_{datasource_id}_{symbol}_{resolution}".replace("/", "_").replace(":", "_")
def get_output_columns(self) -> List[str]:
"""
Get the columns provided by this datasource.
Note: This requires runtime resolution - the pipeline engine
will need to query the actual DataSource to get the schema.
Returns:
List of column names (placeholder - needs runtime resolution)
"""
# This will be resolved at runtime by the pipeline engine
return []
def __repr__(self) -> str:
return f"DataSourceAdapter(datasource='{self.datasource_id}', symbol='{self.symbol}', resolution='{self.resolution}')"


@@ -0,0 +1,954 @@
"""
Custom indicator implementations for TradingView indicators not in TA-Lib.
These indicators follow TA-Lib style conventions and integrate seamlessly
with the indicator framework. All implementations are based on well-known,
publicly documented formulas.
"""
import logging
from typing import List, Optional
import numpy as np
from datasource.schema import ColumnInfo
from .base import Indicator
from .schema import (
ComputeContext,
ComputeResult,
IndicatorMetadata,
IndicatorParameter,
InputSchema,
OutputSchema,
)
logger = logging.getLogger(__name__)
class VWAP(Indicator):
"""Volume Weighted Average Price - Most widely used institutional indicator."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="VWAP",
display_name="VWAP",
description="Volume Weighted Average Price - Average price weighted by volume",
category="volume",
parameters=[],
use_cases=[
"Institutional reference price",
"Support/resistance levels",
"Mean reversion trading"
],
references=["https://www.investopedia.com/terms/v/vwap.asp"],
tags=["vwap", "volume", "institutional"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
ColumnInfo(name="close", type="float", description="Close price"),
ColumnInfo(name="volume", type="float", description="Volume"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="vwap", type="float", description="Volume Weighted Average Price", nullable=True)
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
volume = np.array([float(v) if v is not None else np.nan for v in context.get_column("volume")])
# Typical price
typical_price = (high + low + close) / 3.0
# VWAP = cumsum(typical_price * volume) / cumsum(volume)
cumulative_tp_vol = np.nancumsum(typical_price * volume)
cumulative_vol = np.nancumsum(volume)
vwap = cumulative_tp_vol / cumulative_vol
times = context.get_times()
result_data = [
{"time": times[i], "vwap": float(vwap[i]) if not np.isnan(vwap[i]) else None}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class VWMA(Indicator):
"""Volume Weighted Moving Average."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="VWMA",
display_name="VWMA",
description="Volume Weighted Moving Average - Moving average weighted by volume",
category="overlap",
parameters=[
IndicatorParameter(
name="length",
type="int",
description="Period length",
default=20,
min_value=1,
required=False
)
],
use_cases=["Volume-aware trend following", "Dynamic support/resistance"],
references=["https://www.investopedia.com/articles/trading/11/trading-with-vwap-mvwap.asp"],
tags=["vwma", "volume", "moving average"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="close", type="float", description="Close price"),
ColumnInfo(name="volume", type="float", description="Volume"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="vwma", type="float", description="Volume Weighted Moving Average", nullable=True)
])
def compute(self, context: ComputeContext) -> ComputeResult:
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
volume = np.array([float(v) if v is not None else np.nan for v in context.get_column("volume")])
length = self.params.get("length", 20)
vwma = np.full_like(close, np.nan)
for i in range(length - 1, len(close)):
window_close = close[i - length + 1:i + 1]
window_volume = volume[i - length + 1:i + 1]
vwma[i] = np.sum(window_close * window_volume) / np.sum(window_volume)
times = context.get_times()
result_data = [
{"time": times[i], "vwma": float(vwma[i]) if not np.isnan(vwma[i]) else None}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class HullMA(Indicator):
"""Hull Moving Average - Fast and smooth moving average."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="HMA",
display_name="Hull Moving Average",
description="Hull Moving Average - Reduces lag while maintaining smoothness",
category="overlap",
parameters=[
IndicatorParameter(
name="length",
type="int",
description="Period length",
default=9,
min_value=1,
required=False
)
],
use_cases=["Low-lag trend following", "Quick trend reversal detection"],
references=["https://alanhull.com/hull-moving-average"],
tags=["hma", "hull", "moving average", "low-lag"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="close", type="float", description="Close price"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="hma", type="float", description="Hull Moving Average", nullable=True)
])
def compute(self, context: ComputeContext) -> ComputeResult:
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
length = self.params.get("length", 9)
def wma(data, period):
"""Weighted Moving Average."""
weights = np.arange(1, period + 1)
result = np.full_like(data, np.nan)
for i in range(period - 1, len(data)):
window = data[i - period + 1:i + 1]
result[i] = np.sum(weights * window) / np.sum(weights)
return result
# HMA = WMA(2 * WMA(price, n/2) - WMA(price, n), sqrt(n))
half_length = max(1, length // 2)  # guard against a zero-length window when length == 1
sqrt_length = int(np.sqrt(length))
wma_half = wma(close, half_length)
wma_full = wma(close, length)
raw_hma = 2 * wma_half - wma_full
hma = wma(raw_hma, sqrt_length)
times = context.get_times()
result_data = [
{"time": times[i], "hma": float(hma[i]) if not np.isnan(hma[i]) else None}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class SuperTrend(Indicator):
"""SuperTrend - Popular trend following indicator."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="SUPERTREND",
display_name="SuperTrend",
description="SuperTrend - Volatility-based trend indicator",
category="overlap",
parameters=[
IndicatorParameter(
name="length",
type="int",
description="ATR period",
default=10,
min_value=1,
required=False
),
IndicatorParameter(
name="multiplier",
type="float",
description="ATR multiplier",
default=3.0,
min_value=0.1,
required=False
)
],
use_cases=["Trend identification", "Stop loss placement", "Trend reversal signals"],
references=["https://www.investopedia.com/articles/trading/08/supertrend-indicator.asp"],
tags=["supertrend", "trend", "volatility"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
ColumnInfo(name="close", type="float", description="Close price"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="supertrend", type="float", description="SuperTrend value", nullable=True),
ColumnInfo(name="direction", type="int", description="Trend direction (1=up, -1=down)", nullable=True)
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
length = self.params.get("length", 10)
multiplier = self.params.get("multiplier", 3.0)
# Calculate ATR
tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
tr[0] = high[0] - low[0]
atr = np.full_like(close, np.nan)
atr[length - 1] = np.mean(tr[:length])
for i in range(length, len(tr)):
atr[i] = (atr[i - 1] * (length - 1) + tr[i]) / length
# Calculate basic bands
hl2 = (high + low) / 2
basic_upper = hl2 + multiplier * atr
basic_lower = hl2 - multiplier * atr
# Calculate final bands
final_upper = np.full_like(close, np.nan)
final_lower = np.full_like(close, np.nan)
supertrend = np.full_like(close, np.nan)
direction = np.full_like(close, np.nan)
for i in range(length, len(close)):
if i == length:
final_upper[i] = basic_upper[i]
final_lower[i] = basic_lower[i]
else:
final_upper[i] = basic_upper[i] if basic_upper[i] < final_upper[i - 1] or close[i - 1] > final_upper[i - 1] else final_upper[i - 1]
final_lower[i] = basic_lower[i] if basic_lower[i] > final_lower[i - 1] or close[i - 1] < final_lower[i - 1] else final_lower[i - 1]
if i == length:
supertrend[i] = final_upper[i] if close[i] <= hl2[i] else final_lower[i]
direction[i] = -1 if close[i] <= hl2[i] else 1
else:
if supertrend[i - 1] == final_upper[i - 1] and close[i] <= final_upper[i]:
supertrend[i] = final_upper[i]
direction[i] = -1
elif supertrend[i - 1] == final_upper[i - 1] and close[i] > final_upper[i]:
supertrend[i] = final_lower[i]
direction[i] = 1
elif supertrend[i - 1] == final_lower[i - 1] and close[i] >= final_lower[i]:
supertrend[i] = final_lower[i]
direction[i] = 1
else:
supertrend[i] = final_upper[i]
direction[i] = -1
times = context.get_times()
result_data = [
{
"time": times[i],
"supertrend": float(supertrend[i]) if not np.isnan(supertrend[i]) else None,
"direction": int(direction[i]) if not np.isnan(direction[i]) else None
}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class DonchianChannels(Indicator):
"""Donchian Channels - Breakout indicator using highest high and lowest low."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="DONCHIAN",
display_name="Donchian Channels",
description="Donchian Channels - Highest high and lowest low over period",
category="overlap",
parameters=[
IndicatorParameter(
name="length",
type="int",
description="Period length",
default=20,
min_value=1,
required=False
)
],
use_cases=["Breakout trading", "Volatility bands", "Support/resistance"],
references=["https://www.investopedia.com/terms/d/donchianchannels.asp"],
tags=["donchian", "channels", "breakout"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="upper", type="float", description="Upper channel", nullable=True),
ColumnInfo(name="middle", type="float", description="Middle line", nullable=True),
ColumnInfo(name="lower", type="float", description="Lower channel", nullable=True),
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
length = self.params.get("length", 20)
upper = np.full_like(high, np.nan)
lower = np.full_like(low, np.nan)
for i in range(length - 1, len(high)):
upper[i] = np.nanmax(high[i - length + 1:i + 1])
lower[i] = np.nanmin(low[i - length + 1:i + 1])
middle = (upper + lower) / 2
times = context.get_times()
result_data = [
{
"time": times[i],
"upper": float(upper[i]) if not np.isnan(upper[i]) else None,
"middle": float(middle[i]) if not np.isnan(middle[i]) else None,
"lower": float(lower[i]) if not np.isnan(lower[i]) else None,
}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class KeltnerChannels(Indicator):
"""Keltner Channels - ATR-based volatility bands."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="KELTNER",
display_name="Keltner Channels",
description="Keltner Channels - EMA with ATR-based bands",
category="volatility",
parameters=[
IndicatorParameter(
name="length",
type="int",
description="EMA period",
default=20,
min_value=1,
required=False
),
IndicatorParameter(
name="multiplier",
type="float",
description="ATR multiplier",
default=2.0,
min_value=0.1,
required=False
),
IndicatorParameter(
name="atr_length",
type="int",
description="ATR period",
default=10,
min_value=1,
required=False
)
],
use_cases=["Volatility bands", "Overbought/oversold", "Trend strength"],
references=["https://www.investopedia.com/terms/k/keltnerchannel.asp"],
tags=["keltner", "channels", "volatility", "atr"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
ColumnInfo(name="close", type="float", description="Close price"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="upper", type="float", description="Upper band", nullable=True),
ColumnInfo(name="middle", type="float", description="Middle line (EMA)", nullable=True),
ColumnInfo(name="lower", type="float", description="Lower band", nullable=True),
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
length = self.params.get("length", 20)
multiplier = self.params.get("multiplier", 2.0)
atr_length = self.params.get("atr_length", 10)
# Calculate EMA
alpha = 2.0 / (length + 1)
ema = np.full_like(close, np.nan)
ema[0] = close[0]
for i in range(1, len(close)):
ema[i] = alpha * close[i] + (1 - alpha) * ema[i - 1]
# Calculate ATR
tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
tr[0] = high[0] - low[0]
atr = np.full_like(close, np.nan)
atr[atr_length - 1] = np.mean(tr[:atr_length])
for i in range(atr_length, len(tr)):
atr[i] = (atr[i - 1] * (atr_length - 1) + tr[i]) / atr_length
upper = ema + multiplier * atr
lower = ema - multiplier * atr
times = context.get_times()
result_data = [
{
"time": times[i],
"upper": float(upper[i]) if not np.isnan(upper[i]) else None,
"middle": float(ema[i]) if not np.isnan(ema[i]) else None,
"lower": float(lower[i]) if not np.isnan(lower[i]) else None,
}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class ChaikinMoneyFlow(Indicator):
"""Chaikin Money Flow - Volume-weighted accumulation/distribution."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="CMF",
display_name="Chaikin Money Flow",
description="Chaikin Money Flow - Measures buying and selling pressure",
category="volume",
parameters=[
IndicatorParameter(
name="length",
type="int",
description="Period length",
default=20,
min_value=1,
required=False
)
],
use_cases=["Buying/selling pressure", "Trend confirmation", "Divergence analysis"],
references=["https://www.investopedia.com/terms/c/chaikinoscillator.asp"],
tags=["cmf", "chaikin", "volume", "money flow"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
ColumnInfo(name="close", type="float", description="Close price"),
ColumnInfo(name="volume", type="float", description="Volume"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="cmf", type="float", description="Chaikin Money Flow", nullable=True)
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
volume = np.array([float(v) if v is not None else np.nan for v in context.get_column("volume")])
length = self.params.get("length", 20)
# Money Flow Multiplier (divide-by-zero on flat bars is silenced here
# and zeroed out on the next line)
with np.errstate(divide="ignore", invalid="ignore"):
    mfm = ((close - low) - (high - close)) / (high - low)
mfm = np.where(high == low, 0, mfm)
# Money Flow Volume
mfv = mfm * volume
# CMF
cmf = np.full_like(close, np.nan)
for i in range(length - 1, len(close)):
cmf[i] = np.nansum(mfv[i - length + 1:i + 1]) / np.nansum(volume[i - length + 1:i + 1])
times = context.get_times()
result_data = [
{"time": times[i], "cmf": float(cmf[i]) if not np.isnan(cmf[i]) else None}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class VortexIndicator(Indicator):
"""Vortex Indicator - Identifies trend direction and strength."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="VORTEX",
display_name="Vortex Indicator",
description="Vortex Indicator - Trend direction and strength",
category="momentum",
parameters=[
IndicatorParameter(
name="length",
type="int",
description="Period length",
default=14,
min_value=1,
required=False
)
],
use_cases=["Trend identification", "Trend reversals", "Trend strength"],
references=["https://www.investopedia.com/terms/v/vortex-indicator-vi.asp"],
tags=["vortex", "trend", "momentum"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
ColumnInfo(name="close", type="float", description="Close price"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="vi_plus", type="float", description="Positive Vortex", nullable=True),
ColumnInfo(name="vi_minus", type="float", description="Negative Vortex", nullable=True),
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
length = self.params.get("length", 14)
# Vortex Movement
vm_plus = np.abs(high - np.roll(low, 1))
vm_minus = np.abs(low - np.roll(high, 1))
vm_plus[0] = 0
vm_minus[0] = 0
# True Range
tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
tr[0] = high[0] - low[0]
# Vortex Indicator
vi_plus = np.full_like(close, np.nan)
vi_minus = np.full_like(close, np.nan)
for i in range(length - 1, len(close)):
sum_vm_plus = np.sum(vm_plus[i - length + 1:i + 1])
sum_vm_minus = np.sum(vm_minus[i - length + 1:i + 1])
sum_tr = np.sum(tr[i - length + 1:i + 1])
if sum_tr != 0:
vi_plus[i] = sum_vm_plus / sum_tr
vi_minus[i] = sum_vm_minus / sum_tr
times = context.get_times()
result_data = [
{
"time": times[i],
"vi_plus": float(vi_plus[i]) if not np.isnan(vi_plus[i]) else None,
"vi_minus": float(vi_minus[i]) if not np.isnan(vi_minus[i]) else None,
}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class AwesomeOscillator(Indicator):
"""Awesome Oscillator - Bill Williams' momentum indicator."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="AO",
display_name="Awesome Oscillator",
description="Awesome Oscillator - Difference between 5 and 34 period SMAs of midpoint",
category="momentum",
parameters=[],
use_cases=["Momentum shifts", "Trend reversals", "Divergence trading"],
references=["https://www.investopedia.com/terms/a/awesomeoscillator.asp"],
tags=["awesome", "oscillator", "momentum", "williams"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="ao", type="float", description="Awesome Oscillator", nullable=True)
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
midpoint = (high + low) / 2
# SMA 5
sma5 = np.full_like(midpoint, np.nan)
for i in range(4, len(midpoint)):
sma5[i] = np.mean(midpoint[i - 4:i + 1])
# SMA 34
sma34 = np.full_like(midpoint, np.nan)
for i in range(33, len(midpoint)):
sma34[i] = np.mean(midpoint[i - 33:i + 1])
ao = sma5 - sma34
times = context.get_times()
result_data = [
{"time": times[i], "ao": float(ao[i]) if not np.isnan(ao[i]) else None}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class AcceleratorOscillator(Indicator):
"""Accelerator Oscillator - Rate of change of Awesome Oscillator."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="AC",
display_name="Accelerator Oscillator",
description="Accelerator Oscillator - Rate of change of Awesome Oscillator",
category="momentum",
parameters=[],
use_cases=["Early momentum detection", "Trend acceleration", "Divergence signals"],
references=["https://www.investopedia.com/terms/a/accelerator-oscillator.asp"],
tags=["accelerator", "oscillator", "momentum", "williams"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="ac", type="float", description="Accelerator Oscillator", nullable=True)
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
midpoint = (high + low) / 2
# Calculate AO first
sma5 = np.full_like(midpoint, np.nan)
for i in range(4, len(midpoint)):
sma5[i] = np.mean(midpoint[i - 4:i + 1])
sma34 = np.full_like(midpoint, np.nan)
for i in range(33, len(midpoint)):
sma34[i] = np.mean(midpoint[i - 33:i + 1])
ao = sma5 - sma34
# AC = AO - SMA(AO, 5)
sma_ao = np.full_like(ao, np.nan)
for i in range(4, len(ao)):
if not np.isnan(ao[i - 4:i + 1]).any():
sma_ao[i] = np.mean(ao[i - 4:i + 1])
ac = ao - sma_ao
times = context.get_times()
result_data = [
{"time": times[i], "ac": float(ac[i]) if not np.isnan(ac[i]) else None}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class ChoppinessIndex(Indicator):
"""Choppiness Index - Determines if market is choppy or trending."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="CHOP",
display_name="Choppiness Index",
description="Choppiness Index - Measures market trendiness vs consolidation",
category="volatility",
parameters=[
IndicatorParameter(
name="length",
type="int",
description="Period length",
default=14,
min_value=1,
required=False
)
],
use_cases=["Trend vs range identification", "Market regime detection"],
references=["https://www.tradingview.com/support/solutions/43000501980/"],
tags=["chop", "choppiness", "trend", "range"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
ColumnInfo(name="close", type="float", description="Close price"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="chop", type="float", description="Choppiness Index (0-100)", nullable=True)
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
length = self.params.get("length", 14)
# True Range
tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
tr[0] = high[0] - low[0]
chop = np.full_like(close, np.nan)
for i in range(length - 1, len(close)):
sum_tr = np.sum(tr[i - length + 1:i + 1])
high_low_diff = np.max(high[i - length + 1:i + 1]) - np.min(low[i - length + 1:i + 1])
if high_low_diff != 0:
chop[i] = 100 * np.log10(sum_tr / high_low_diff) / np.log10(length)
times = context.get_times()
result_data = [
{"time": times[i], "chop": float(chop[i]) if not np.isnan(chop[i]) else None}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class MassIndex(Indicator):
"""Mass Index - Identifies trend reversals based on range expansion."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="MASS",
display_name="Mass Index",
description="Mass Index - Identifies reversals when range narrows then expands",
category="volatility",
parameters=[
IndicatorParameter(
name="fast_period",
type="int",
description="Fast EMA period",
default=9,
min_value=1,
required=False
),
IndicatorParameter(
name="slow_period",
type="int",
description="Slow EMA period",
default=25,
min_value=1,
required=False
)
],
use_cases=["Reversal detection", "Volatility analysis", "Bulge identification"],
references=["https://www.investopedia.com/terms/m/mass-index.asp"],
tags=["mass", "index", "volatility", "reversal"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="mass", type="float", description="Mass Index", nullable=True)
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
fast_period = self.params.get("fast_period", 9)
slow_period = self.params.get("slow_period", 25)
hl_range = high - low
# Single EMA
alpha1 = 2.0 / (fast_period + 1)
ema1 = np.full_like(hl_range, np.nan)
ema1[0] = hl_range[0]
for i in range(1, len(hl_range)):
ema1[i] = alpha1 * hl_range[i] + (1 - alpha1) * ema1[i - 1]
# Double EMA
ema2 = np.full_like(ema1, np.nan)
ema2[0] = ema1[0]
for i in range(1, len(ema1)):
if not np.isnan(ema1[i]):
ema2[i] = alpha1 * ema1[i] + (1 - alpha1) * ema2[i - 1]
# EMA Ratio
ema_ratio = ema1 / ema2
# Mass Index
mass = np.full_like(hl_range, np.nan)
for i in range(slow_period - 1, len(ema_ratio)):
mass[i] = np.nansum(ema_ratio[i - slow_period + 1:i + 1])
times = context.get_times()
result_data = [
{"time": times[i], "mass": float(mass[i]) if not np.isnan(mass[i]) else None}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
# Registry of all custom indicators
CUSTOM_INDICATORS = [
VWAP,
VWMA,
HullMA,
SuperTrend,
DonchianChannels,
KeltnerChannels,
ChaikinMoneyFlow,
VortexIndicator,
AwesomeOscillator,
AcceleratorOscillator,
ChoppinessIndex,
MassIndex,
]
def register_custom_indicators(registry) -> int:
"""
Register all custom indicators with the registry.
Args:
registry: IndicatorRegistry instance
Returns:
Number of indicators registered
"""
registered_count = 0
for indicator_class in CUSTOM_INDICATORS:
try:
registry.register(indicator_class)
registered_count += 1
logger.debug(f"Registered custom indicator: {indicator_class.__name__}")
except Exception as e:
logger.warning(f"Failed to register custom indicator {indicator_class.__name__}: {e}")
logger.info(f"Registered {registered_count} custom indicators")
return registered_count
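# Usage sketch (illustrative; uses the IndicatorRegistry exported by this
# package):
#
#     from indicator import IndicatorRegistry
#
#     registry = IndicatorRegistry()
#     register_custom_indicators(registry)  # -> 12 on a clean registry
#     vwap = registry.create_instance("VWAP", "vwap_main")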


@@ -0,0 +1,439 @@
"""
Pipeline execution engine for composable indicators.
Manages DAG construction, dependency resolution, incremental updates,
and efficient data flow through indicator chains.
"""
import logging
from collections import deque
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from datasource.base import DataSource
from datasource.schema import ColumnInfo
from .base import DataSourceAdapter, Indicator
from .schema import ComputeContext, ComputeResult
logger = logging.getLogger(__name__)
class PipelineNode:
"""
A node in the pipeline DAG.
Can be either a DataSource adapter or an Indicator instance.
"""
def __init__(
self,
node_id: str,
node: Union[DataSourceAdapter, Indicator],
dependencies: List[str]
):
"""
Create a pipeline node.
Args:
node_id: Unique identifier for this node
node: The DataSourceAdapter or Indicator instance
dependencies: List of node_ids this node depends on
"""
self.node_id = node_id
self.node = node
self.dependencies = dependencies
self.output_columns: List[str] = []
self.cached_data: List[Dict[str, Any]] = []
def is_datasource(self) -> bool:
"""Check if this node is a DataSource adapter."""
return isinstance(self.node, DataSourceAdapter)
def is_indicator(self) -> bool:
"""Check if this node is an Indicator."""
return isinstance(self.node, Indicator)
def __repr__(self) -> str:
return f"PipelineNode(id='{self.node_id}', node={self.node}, deps={self.dependencies})"
class Pipeline:
"""
Execution engine for indicator DAGs.
Manages:
- DAG construction and validation
- Topological sorting for execution order
- Data flow and caching
- Incremental updates (only recompute what changed)
- Schema validation
"""
def __init__(self, datasource_registry):
"""
Initialize a pipeline.
Args:
datasource_registry: DataSourceRegistry for resolving data sources
"""
self.datasource_registry = datasource_registry
self.nodes: Dict[str, PipelineNode] = {}
self.execution_order: List[str] = []
self._dirty_nodes: Set[str] = set()
def add_datasource(
self,
node_id: str,
datasource_name: str,
symbol: str,
resolution: str
) -> None:
"""
Add a DataSource to the pipeline.
Args:
node_id: Unique identifier for this node
datasource_name: Name of the datasource in the registry
symbol: Symbol to query
resolution: Time resolution
Raises:
ValueError: If node_id already exists or datasource not found
"""
if node_id in self.nodes:
raise ValueError(f"Node '{node_id}' already exists in pipeline")
datasource = self.datasource_registry.get(datasource_name)
if not datasource:
raise ValueError(f"DataSource '{datasource_name}' not found in registry")
adapter = DataSourceAdapter(datasource_name, symbol, resolution)
node = PipelineNode(node_id, adapter, dependencies=[])
self.nodes[node_id] = node
self._invalidate_execution_order()
logger.info(f"Added DataSource node '{node_id}': {datasource_name}/{symbol}@{resolution}")
def add_indicator(
self,
node_id: str,
indicator: Indicator,
input_node_ids: List[str]
) -> None:
"""
Add an Indicator to the pipeline.
Args:
node_id: Unique identifier for this node
indicator: Indicator instance
input_node_ids: List of node IDs providing input data
Raises:
ValueError: If node_id already exists, dependencies not found, or schema mismatch
"""
if node_id in self.nodes:
raise ValueError(f"Node '{node_id}' already exists in pipeline")
# Validate dependencies exist
for dep_id in input_node_ids:
if dep_id not in self.nodes:
raise ValueError(f"Dependency node '{dep_id}' not found in pipeline")
# TODO: Validate input schema matches available columns from dependencies
# This requires merging output schemas from all input nodes
node = PipelineNode(node_id, indicator, dependencies=input_node_ids)
self.nodes[node_id] = node
self._invalidate_execution_order()
logger.info(f"Added Indicator node '{node_id}': {indicator} with inputs {input_node_ids}")
def remove_node(self, node_id: str) -> None:
"""
Remove a node from the pipeline.
Args:
node_id: Node to remove
Raises:
ValueError: If other nodes depend on this node
"""
if node_id not in self.nodes:
return
# Check for dependent nodes
dependents = [
n.node_id for n in self.nodes.values()
if node_id in n.dependencies
]
if dependents:
raise ValueError(
f"Cannot remove node '{node_id}': nodes {dependents} depend on it"
)
del self.nodes[node_id]
self._invalidate_execution_order()
logger.info(f"Removed node '{node_id}' from pipeline")
def _invalidate_execution_order(self) -> None:
"""Mark execution order as needing recomputation."""
self.execution_order = []
def _compute_execution_order(self) -> List[str]:
"""
Compute topological sort of the DAG.
Returns:
List of node IDs in execution order
Raises:
ValueError: If DAG contains cycles
"""
if self.execution_order:
return self.execution_order
# Kahn's algorithm for topological sort
in_degree = {
    node_id: len(node.dependencies)
    for node_id, node in self.nodes.items()
}
queue = deque([node_id for node_id, degree in in_degree.items() if degree == 0])
result = []
while queue:
node_id = queue.popleft()
result.append(node_id)
# Find all nodes that depend on this one
for other_node in self.nodes.values():
if node_id in other_node.dependencies:
in_degree[other_node.node_id] -= 1
if in_degree[other_node.node_id] == 0:
queue.append(other_node.node_id)
if len(result) != len(self.nodes):
raise ValueError("Pipeline contains cycles")
self.execution_order = result
logger.debug(f"Computed execution order: {result}")
return result
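# Worked example (illustrative): with nodes
#     price_data (no deps), sma_20 (deps: [price_data]),
#     cross (deps: [sma_20, price_data])
# the resulting order is ["price_data", "sma_20", "cross"]. In a cyclic
# graph (a -> b -> a) every node keeps a nonzero in-degree, the queue
# never fills, and the length check above raises ValueError.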
def execute(
self,
datasource_data: Dict[str, List[Dict[str, Any]]],
incremental: bool = False,
updated_from_time: Optional[int] = None
) -> Dict[str, List[Dict[str, Any]]]:
"""
Execute the pipeline.
Args:
datasource_data: Mapping of DataSource node_id to input data
incremental: Whether this is an incremental update
updated_from_time: Timestamp of earliest updated row (for incremental)
Returns:
Dictionary mapping node_id to output data (all nodes)
Raises:
ValueError: If required datasource data is missing
"""
execution_order = self._compute_execution_order()
results: Dict[str, List[Dict[str, Any]]] = {}
logger.info(
f"Executing pipeline with {len(execution_order)} nodes "
f"(incremental={incremental})"
)
for node_id in execution_order:
node = self.nodes[node_id]
if node.is_datasource():
# DataSource node - get data from input
if node_id not in datasource_data:
raise ValueError(
f"DataSource node '{node_id}' has no input data"
)
results[node_id] = datasource_data[node_id]
node.cached_data = results[node_id]
logger.debug(f"DataSource node '{node_id}': {len(results[node_id])} rows")
elif node.is_indicator():
# Indicator node - compute from dependencies
indicator = node.node
# Merge input data from all dependencies
input_data = self._merge_dependency_data(node.dependencies, results)
# Create compute context
context = ComputeContext(
data=input_data,
is_incremental=incremental,
updated_from_time=updated_from_time
)
# Execute indicator
logger.debug(
f"Computing indicator '{node_id}' with {len(input_data)} input rows"
)
compute_result = indicator.compute(context)
# Merge result with input data (adding prefixed columns)
output_data = compute_result.merge_with_prefix(
indicator.instance_name,
input_data
)
results[node_id] = output_data
node.cached_data = output_data
logger.debug(f"Indicator node '{node_id}': {len(output_data)} rows")
logger.info(f"Pipeline execution complete: {len(results)} nodes processed")
return results
def _merge_dependency_data(
self,
dependency_ids: List[str],
results: Dict[str, List[Dict[str, Any]]]
) -> List[Dict[str, Any]]:
"""
Merge data from multiple dependency nodes.
Data is merged by time, with later dependencies overwriting earlier ones
for conflicting column names.
Args:
dependency_ids: List of node IDs to merge
results: Current execution results
Returns:
Merged data rows
"""
if not dependency_ids:
return []
if len(dependency_ids) == 1:
return results[dependency_ids[0]]
# Build time-indexed data from first dependency
merged: Dict[int, Dict[str, Any]] = {}
for row in results[dependency_ids[0]]:
merged[row["time"]] = row.copy()
# Merge in additional dependencies
for dep_id in dependency_ids[1:]:
for row in results[dep_id]:
time_key = row["time"]
if time_key in merged:
# Merge columns (later dependencies win)
merged[time_key].update(row)
else:
# New timestamp
merged[time_key] = row.copy()
# Sort by time and return
sorted_times = sorted(merged.keys())
return [merged[t] for t in sorted_times]
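# Worked example (illustrative): merging dependency A, then B, keyed by time
# (later dependencies win on conflicting columns):
#     A: [{"time": 1, "close": 10.0}]
#     B: [{"time": 1, "sma": 9.5}, {"time": 2, "sma": 9.7}]
#     -> [{"time": 1, "close": 10.0, "sma": 9.5}, {"time": 2, "sma": 9.7}]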
def get_node_output(self, node_id: str) -> Optional[List[Dict[str, Any]]]:
"""
Get cached output data for a specific node.
Args:
node_id: Node identifier
Returns:
Cached data or None if not available
"""
node = self.nodes.get(node_id)
return node.cached_data if node else None
def get_output_schema(self, node_id: str) -> List[ColumnInfo]:
"""
Get the output schema for a specific node.
Args:
node_id: Node identifier
Returns:
List of ColumnInfo describing output columns
Raises:
ValueError: If node not found
"""
node = self.nodes.get(node_id)
if not node:
raise ValueError(f"Node '{node_id}' not found")
if node.is_datasource():
# Would need to query the actual datasource at runtime
# For now, return empty - this requires integration with DataSource
return []
elif node.is_indicator():
indicator = node.node
output_schema = indicator.get_output_schema(**indicator.params)
prefixed_schema = output_schema.with_prefix(indicator.instance_name)
return prefixed_schema.columns
return []
def validate_pipeline(self) -> Tuple[bool, Optional[str]]:
"""
Validate the entire pipeline for correctness.
Checks:
- No cycles (already checked in execution order)
- All dependencies exist (already checked in add_indicator)
- Input schemas match output schemas (TODO)
Returns:
Tuple of (is_valid, error_message)
"""
try:
self._compute_execution_order()
return True, None
except ValueError as e:
return False, str(e)
def get_node_count(self) -> int:
"""Get the number of nodes in the pipeline."""
return len(self.nodes)
def get_indicator_count(self) -> int:
"""Get the number of indicator nodes in the pipeline."""
return sum(1 for node in self.nodes.values() if node.is_indicator())
def get_datasource_count(self) -> int:
"""Get the number of datasource nodes in the pipeline."""
return sum(1 for node in self.nodes.values() if node.is_datasource())
def describe(self) -> Dict[str, Any]:
"""
Get a detailed description of the pipeline structure.
Returns:
Dictionary with pipeline metadata and structure
"""
return {
"node_count": self.get_node_count(),
"datasource_count": self.get_datasource_count(),
"indicator_count": self.get_indicator_count(),
"nodes": [
{
"id": node.node_id,
"type": "datasource" if node.is_datasource() else "indicator",
"node": str(node.node),
"dependencies": node.dependencies,
"cached_rows": len(node.cached_data)
}
for node in self.nodes.values()
],
"execution_order": self.execution_order or self._compute_execution_order(),
"is_valid": self.validate_pipeline()[0]
}


@@ -0,0 +1,349 @@
"""
Indicator registry for managing and discovering indicators.
Provides AI agents with a queryable catalog of available indicators,
their capabilities, and metadata.
"""
from typing import Dict, List, Optional, Type
from .base import Indicator
from .schema import IndicatorMetadata, InputSchema, OutputSchema
class IndicatorRegistry:
"""
Central registry for indicator classes.
Enables:
- Registration of indicator implementations
- Discovery by name, category, or tags
- Schema validation
- AI agent tool generation
"""
def __init__(self):
self._indicators: Dict[str, Type[Indicator]] = {}
def register(self, indicator_class: Type[Indicator]) -> None:
"""
Register an indicator class.
Args:
indicator_class: Indicator class to register
Raises:
ValueError: If an indicator with this name is already registered
"""
metadata = indicator_class.get_metadata()
if metadata.name in self._indicators:
raise ValueError(
f"Indicator '{metadata.name}' is already registered"
)
self._indicators[metadata.name] = indicator_class
def unregister(self, name: str) -> None:
"""
Unregister an indicator class.
Args:
name: Indicator class name
"""
self._indicators.pop(name, None)
def get(self, name: str) -> Optional[Type[Indicator]]:
"""
Get an indicator class by name.
Args:
name: Indicator class name
Returns:
Indicator class or None if not found
"""
return self._indicators.get(name)
def list_indicators(self) -> List[str]:
"""
Get names of all registered indicators.
Returns:
List of indicator class names
"""
return list(self._indicators.keys())
def get_metadata(self, name: str) -> Optional[IndicatorMetadata]:
"""
Get metadata for a specific indicator.
Args:
name: Indicator class name
Returns:
IndicatorMetadata or None if not found
"""
indicator_class = self.get(name)
if indicator_class:
return indicator_class.get_metadata()
return None
def get_all_metadata(self) -> List[IndicatorMetadata]:
"""
Get metadata for all registered indicators.
Useful for AI agent tool generation and discovery.
Returns:
List of IndicatorMetadata for all registered indicators
"""
return [cls.get_metadata() for cls in self._indicators.values()]
def search_by_category(self, category: str) -> List[IndicatorMetadata]:
"""
Find indicators by category.
Args:
category: Category name (e.g., 'momentum', 'trend', 'volatility')
Returns:
List of matching indicator metadata
"""
results = []
for indicator_class in self._indicators.values():
metadata = indicator_class.get_metadata()
if metadata.category.lower() == category.lower():
results.append(metadata)
return results
def search_by_tag(self, tag: str) -> List[IndicatorMetadata]:
"""
Find indicators by tag.
Args:
tag: Tag to search for (case-insensitive)
Returns:
List of matching indicator metadata
"""
tag_lower = tag.lower()
results = []
for indicator_class in self._indicators.values():
metadata = indicator_class.get_metadata()
if any(t.lower() == tag_lower for t in metadata.tags):
results.append(metadata)
return results
def search_by_text(self, query: str) -> List[IndicatorMetadata]:
"""
Full-text search across indicator names, descriptions, and use cases.
Args:
query: Search query (case-insensitive)
Returns:
List of matching indicator metadata, ranked by relevance
"""
query_lower = query.lower()
results = []
for indicator_class in self._indicators.values():
metadata = indicator_class.get_metadata()
score = 0
# Check name (highest weight)
if query_lower in metadata.name.lower():
score += 10
if query_lower in metadata.display_name.lower():
score += 8
# Check description
if query_lower in metadata.description.lower():
score += 5
# Check use cases
for use_case in metadata.use_cases:
if query_lower in use_case.lower():
score += 3
# Check tags
for tag in metadata.tags:
if query_lower in tag.lower():
score += 2
if score > 0:
results.append((score, metadata))
# Sort by score descending
results.sort(key=lambda x: x[0], reverse=True)
return [metadata for _, metadata in results]
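# Worked example (illustrative, assuming only the indicators from
# custom_indicators.py are registered): search_by_text("volume") scores
#     VWMA = 10 (description +5, use case +3, tag +2)
#     VWAP = 7  (description +5, tag +2)
#     CMF  = 2  (tag +2)
# and returns their metadata in that order.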
def find_compatible_indicators(
self,
available_columns: List[str],
column_types: Dict[str, str]
) -> List[IndicatorMetadata]:
"""
Find indicators that can be computed from available columns.
Args:
available_columns: List of column names available
column_types: Mapping of column name to type
Returns:
List of indicators whose input schema is satisfied
"""
from datasource.schema import ColumnInfo
# Build ColumnInfo list from available data
available_schema = [
ColumnInfo(
name=name,
type=column_types.get(name, "float"),
description=f"Column {name}"
)
for name in available_columns
]
results = []
for indicator_class in self._indicators.values():
input_schema = indicator_class.get_input_schema()
if input_schema.matches(available_schema):
results.append(indicator_class.get_metadata())
return results
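# Example (illustrative): given plain OHLCV columns,
#
#     registry.find_compatible_indicators(
#         ["time", "open", "high", "low", "close", "volume"],
#         {"time": "int", "open": "float", "high": "float",
#          "low": "float", "close": "float", "volume": "float"},
#     )
#
# returns metadata for every indicator whose required columns are covered,
# from close-only ones (e.g. HMA) to ones needing high/low/close/volume
# (e.g. VWAP).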
def validate_indicator_chain(
self,
indicator_chain: List[tuple[str, Dict]]
) -> tuple[bool, Optional[str]]:
"""
Validate that a chain of indicators can be connected.
Args:
indicator_chain: List of (indicator_name, params) tuples in execution order
Returns:
Tuple of (is_valid, error_message)
"""
if not indicator_chain:
return True, None
# For now, just check that all indicators exist
# More sophisticated DAG validation happens in the pipeline engine
for indicator_name, params in indicator_chain:
if indicator_name not in self._indicators:
return False, f"Indicator '{indicator_name}' not found in registry"
return True, None
def get_input_schema(self, name: str) -> Optional[InputSchema]:
"""
Get input schema for a specific indicator.
Args:
name: Indicator class name
Returns:
InputSchema or None if not found
"""
indicator_class = self.get(name)
if indicator_class:
return indicator_class.get_input_schema()
return None
def get_output_schema(self, name: str, **params) -> Optional[OutputSchema]:
"""
Get output schema for a specific indicator with given parameters.
Args:
name: Indicator class name
**params: Indicator parameters
Returns:
OutputSchema or None if not found
"""
indicator_class = self.get(name)
if indicator_class:
return indicator_class.get_output_schema(**params)
return None
def create_instance(self, name: str, instance_name: str, **params) -> Optional[Indicator]:
"""
Create an indicator instance with validation.
Args:
name: Indicator class name
instance_name: Unique instance name (for output column prefixing)
**params: Indicator configuration parameters
Returns:
Indicator instance or None if class not found
Raises:
ValueError: If parameters are invalid
"""
indicator_class = self.get(name)
if not indicator_class:
return None
return indicator_class(instance_name=instance_name, **params)
def generate_ai_tool_spec(self) -> Dict:
"""
Generate a JSON specification for AI agent tools.
Creates a structured representation of all indicators that can be
used to build agent tools for indicator selection and composition.
Returns:
Dict suitable for AI agent tool registration
"""
tools = []
for indicator_class in self._indicators.values():
metadata = indicator_class.get_metadata()
# Build parameter spec
parameters = {
"type": "object",
"properties": {},
"required": []
}
for param in metadata.parameters:
param_spec = {
"type": param.type,
"description": param.description
}
if param.default is not None:
param_spec["default"] = param.default
if param.min_value is not None:
param_spec["minimum"] = param.min_value
if param.max_value is not None:
param_spec["maximum"] = param.max_value
parameters["properties"][param.name] = param_spec
if param.required:
parameters["required"].append(param.name)
tool = {
"name": f"indicator_{metadata.name.lower()}",
"description": f"{metadata.display_name}: {metadata.description}",
"category": metadata.category,
"use_cases": metadata.use_cases,
"tags": metadata.tags,
"parameters": parameters,
"input_schema": indicator_class.get_input_schema().model_dump(),
"output_schema": indicator_class.get_output_schema().model_dump()
}
tools.append(tool)
return {
"indicator_tools": tools,
"total_count": len(tools)
}
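# Usage sketch (illustrative only, not part of the registry API; assumes at
# least one indicator class has already been registered):
def _example_registry_usage(registry: IndicatorRegistry) -> Dict:
    tagged = registry.search_by_tag("moving-average")  # case-insensitive tag match
    ranked = registry.search_by_text("average")        # relevance-scored, best first
    spec = registry.generate_ai_tool_spec()
    # total_count always mirrors the number of generated tool entries.
    assert spec["total_count"] == len(spec["indicator_tools"])
    return spec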

View File

@@ -0,0 +1,269 @@
"""
Data models for the Indicator system.
Defines schemas for input/output specifications, computation context,
and metadata for AI agent discovery.
"""
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
from datasource.schema import ColumnInfo
class InputSchema(BaseModel):
"""
Declares the required input columns for an Indicator.
Indicators match against any data source (DataSource or other Indicator)
that provides columns satisfying this schema.
"""
model_config = {"extra": "forbid"}
required_columns: List[ColumnInfo] = Field(
description="Columns that must be present in the input data"
)
optional_columns: List[ColumnInfo] = Field(
default_factory=list,
description="Columns that may be used if present but are not required"
)
time_column: str = Field(
default="time",
description="Name of the timestamp column (must be present)"
)
def matches(self, available_columns: List[ColumnInfo]) -> bool:
"""
Check if available columns satisfy this input schema.
Args:
available_columns: Columns provided by a data source
Returns:
True if all required columns are present with compatible types
"""
available_map = {col.name: col for col in available_columns}
# Check time column exists
if self.time_column not in available_map:
return False
# Check all required columns exist with compatible types
for required in self.required_columns:
if required.name not in available_map:
return False
available = available_map[required.name]
if available.type != required.type:
return False
return True
def get_missing_columns(self, available_columns: List[ColumnInfo]) -> List[str]:
"""
Get list of missing required column names.
Args:
available_columns: Columns provided by a data source
Returns:
List of missing column names
"""
available_names = {col.name for col in available_columns}
missing = []
if self.time_column not in available_names:
missing.append(self.time_column)
for required in self.required_columns:
if required.name not in available_names:
missing.append(required.name)
return missing
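# Illustrative sketch (not used at runtime): matches() looks the time column
# up by name only, then requires every required column to be present with an
# exactly equal type string.
def _example_input_schema_match() -> bool:
    schema = InputSchema(
        required_columns=[
            ColumnInfo(name="close", type="float", description="Closing price")
        ]
    )
    available = [
        ColumnInfo(name="time", type="float", description="Bar timestamp"),
        ColumnInfo(name="close", type="float", description="Closing price"),
    ]
    return schema.matches(available)  # True; drop "close" and this is False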
class OutputSchema(BaseModel):
"""
Declares the output columns produced by an Indicator.
Column names will be automatically prefixed with the indicator instance name
to avoid collisions in the pipeline.
"""
model_config = {"extra": "forbid"}
columns: List[ColumnInfo] = Field(
description="Output columns produced by this indicator"
)
time_column: str = Field(
default="time",
description="Name of the timestamp column (passed through from input)"
)
def with_prefix(self, prefix: str) -> "OutputSchema":
"""
Create a new OutputSchema with all column names prefixed.
Args:
prefix: Prefix to add (e.g., indicator instance name)
Returns:
New OutputSchema with prefixed column names
"""
prefixed_columns = [
ColumnInfo(
name=f"{prefix}_{col.name}" if col.name != self.time_column else col.name,
type=col.type,
description=col.description,
unit=col.unit,
nullable=col.nullable
)
for col in self.columns
]
return OutputSchema(
columns=prefixed_columns,
time_column=self.time_column
)
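# Illustrative sketch: prefixing renames every output column except the
# time column, matching how the pipeline namespaces indicator instances.
def _example_output_prefixing() -> OutputSchema:
    schema = OutputSchema(
        columns=[ColumnInfo(name="sma", type="float", description="Moving average")]
    )
    return schema.with_prefix("sma_20")  # column is now named "sma_20_sma"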
class IndicatorParameter(BaseModel):
"""
Metadata for a configurable indicator parameter.
Used for AI agent discovery and dynamic indicator instantiation.
"""
model_config = {"extra": "forbid"}
name: str = Field(description="Parameter name")
type: Literal["int", "float", "string", "bool"] = Field(description="Parameter type")
description: str = Field(description="Human and LLM-readable description")
default: Optional[Any] = Field(default=None, description="Default value if not specified")
required: bool = Field(default=False, description="Whether this parameter is required")
min_value: Optional[float] = Field(default=None, description="Minimum value (for numeric types)")
max_value: Optional[float] = Field(default=None, description="Maximum value (for numeric types)")
class IndicatorMetadata(BaseModel):
"""
Rich metadata for an Indicator class.
Enables AI agents to discover, understand, and instantiate indicators.
"""
model_config = {"extra": "forbid"}
name: str = Field(description="Unique indicator class name (e.g., 'RSI', 'SMA', 'BollingerBands')")
display_name: str = Field(description="Human-readable display name")
description: str = Field(description="Detailed description of what this indicator computes and why it's useful")
category: str = Field(
description="Indicator category (e.g., 'momentum', 'trend', 'volatility', 'volume', 'custom')"
)
parameters: List[IndicatorParameter] = Field(
default_factory=list,
description="Configurable parameters for this indicator"
)
use_cases: List[str] = Field(
default_factory=list,
description="Common use cases and trading scenarios where this indicator is helpful"
)
references: List[str] = Field(
default_factory=list,
description="URLs or citations for indicator methodology"
)
tags: List[str] = Field(
default_factory=list,
description="Searchable tags (e.g., 'oscillator', 'mean-reversion', 'price-based')"
)
class ComputeContext(BaseModel):
"""
Context passed to an Indicator's compute() method.
Contains the input data and metadata about what changed (for incremental updates).
"""
model_config = {"extra": "forbid"}
data: List[Dict[str, Any]] = Field(
description="Input data rows (time-ordered). Each dict is {column_name: value, time: timestamp}"
)
is_incremental: bool = Field(
default=False,
description="True if this is an incremental update (only new/changed rows), False for full recompute"
)
updated_from_time: Optional[int] = Field(
default=None,
description="Unix timestamp (ms) of the earliest updated row (for incremental updates)"
)
def get_column(self, name: str) -> List[Any]:
"""
Extract a single column as a list of values.
Args:
name: Column name
Returns:
List of values in time order
"""
return [row.get(name) for row in self.data]
def get_times(self) -> List[int]:
"""
Get the time column as a list.
Returns:
List of timestamps in order
"""
return [row["time"] for row in self.data]
class ComputeResult(BaseModel):
"""
Result from an Indicator's compute() method.
Contains the computed output data with proper column naming.
"""
model_config = {"extra": "forbid"}
data: List[Dict[str, Any]] = Field(
description="Output data rows (time-ordered). Must include time column."
)
is_partial: bool = Field(
default=False,
description="True if this result only contains updates (for incremental computation)"
)
def merge_with_prefix(self, prefix: str, existing_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Merge this result into existing data with column name prefixing.
Args:
prefix: Prefix to add to all column names except time
existing_data: Existing data to merge with (matched by time)
Returns:
Merged data with prefixed columns added
"""
# Build a time index for new data
time_index = {row["time"]: row for row in self.data}
# Merge into existing data
result = []
for existing_row in existing_data:
row_time = existing_row["time"]
merged_row = existing_row.copy()
if row_time in time_index:
new_row = time_index[row_time]
for key, value in new_row.items():
if key != "time":
merged_row[f"{prefix}_{key}"] = value
result.append(merged_row)
return result
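# Illustrative sketch of the merge: rows join on "time" and each non-time
# output column gains the instance prefix.
def _example_merge() -> List[Dict[str, Any]]:
    result = ComputeResult(data=[{"time": 1, "sma": 10.5}, {"time": 2, "sma": 11.0}])
    existing = [{"time": 1, "close": 10.0}, {"time": 2, "close": 11.5}]
    merged = result.merge_with_prefix("sma_20", existing)
    # merged[0] == {"time": 1, "close": 10.0, "sma_20_sma": 10.5}
    return merged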

View File

@@ -0,0 +1,449 @@
"""
TA-Lib indicator adapter.
Provides automatic registration of all TA-Lib technical indicators
as composable Indicator subclasses.
Installation Requirements:
--------------------------
TA-Lib requires both the C library and Python wrapper:
1. Install TA-Lib C library:
- Ubuntu/Debian: sudo apt-get install libta-lib-dev
- macOS: brew install ta-lib
- From source: https://ta-lib.org/install.html
2. Install Python wrapper (already in requirements.txt):
pip install TA-Lib
Usage:
------
from indicator.talib_adapter import register_all_talib_indicators
# Auto-register all TA-Lib indicators
registry = IndicatorRegistry()
register_all_talib_indicators(registry)
# Now you can use any TA-Lib indicator
sma = registry.create_instance("SMA", "sma_20", period=20)
rsi = registry.create_instance("RSI", "rsi_14", timeperiod=14)
"""
import logging
from typing import Any, Dict, List, Optional
import numpy as np
try:
import talib
from talib import abstract
TALIB_AVAILABLE = True
except ImportError:
TALIB_AVAILABLE = False
talib = None
abstract = None
from datasource.schema import ColumnInfo
from .base import Indicator
from .schema import (
ComputeContext,
ComputeResult,
IndicatorMetadata,
IndicatorParameter,
InputSchema,
OutputSchema,
)
logger = logging.getLogger(__name__)
# Mapping of TA-Lib parameter types to our schema types
TALIB_TYPE_MAP = {
"double": "float",
"double[]": "float",
"int": "int",
"str": "string",
}
# Categorization of TA-Lib functions
TALIB_CATEGORIES = {
"overlap": ["SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA", "KAMA", "MAMA", "T3",
"BBANDS", "MIDPOINT", "MIDPRICE", "SAR", "SAREXT", "HT_TRENDLINE"],
"momentum": ["RSI", "MOM", "ROC", "ROCP", "ROCR", "ROCR100", "TRIX", "CMO", "DX",
"ADX", "ADXR", "APO", "PPO", "MACD", "MACDEXT", "MACDFIX", "MFI",
"STOCH", "STOCHF", "STOCHRSI", "WILLR", "CCI", "AROON", "AROONOSC",
"BOP", "MINUS_DI", "MINUS_DM", "PLUS_DI", "PLUS_DM", "ULTOSC"],
"volume": ["AD", "ADOSC", "OBV"],
"volatility": ["ATR", "NATR", "TRANGE"],
"price": ["AVGPRICE", "MEDPRICE", "TYPPRICE", "WCLPRICE"],
"cycle": ["HT_DCPERIOD", "HT_DCPHASE", "HT_PHASOR", "HT_SINE", "HT_TRENDMODE"],
"pattern": ["CDL2CROWS", "CDL3BLACKCROWS", "CDL3INSIDE", "CDL3LINESTRIKE",
"CDL3OUTSIDE", "CDL3STARSINSOUTH", "CDL3WHITESOLDIERS", "CDLABANDONEDBABY",
"CDLADVANCEBLOCK", "CDLBELTHOLD", "CDLBREAKAWAY", "CDLCLOSINGMARUBOZU",
"CDLCONCEALBABYSWALL", "CDLCOUNTERATTACK", "CDLDARKCLOUDCOVER", "CDLDOJI",
"CDLDOJISTAR", "CDLDRAGONFLYDOJI", "CDLENGULFING", "CDLEVENINGDOJISTAR",
"CDLEVENINGSTAR", "CDLGAPSIDESIDEWHITE", "CDLGRAVESTONEDOJI", "CDLHAMMER",
"CDLHANGINGMAN", "CDLHARAMI", "CDLHARAMICROSS", "CDLHIGHWAVE", "CDLHIKKAKE",
"CDLHIKKAKEMOD", "CDLHOMINGPIGEON", "CDLIDENTICAL3CROWS", "CDLINNECK",
"CDLINVERTEDHAMMER", "CDLKICKING", "CDLKICKINGBYLENGTH", "CDLLADDERBOTTOM",
"CDLLONGLEGGEDDOJI", "CDLLONGLINE", "CDLMARUBOZU", "CDLMATCHINGLOW",
"CDLMATHOLD", "CDLMORNINGDOJISTAR", "CDLMORNINGSTAR", "CDLONNECK",
"CDLPIERCING", "CDLRICKSHAWMAN", "CDLRISEFALL3METHODS", "CDLSEPARATINGLINES",
"CDLSHOOTINGSTAR", "CDLSHORTLINE", "CDLSPINNINGTOP", "CDLSTALLEDPATTERN",
"CDLSTICKSANDWICH", "CDLTAKURI", "CDLTASUKIGAP", "CDLTHRUSTING", "CDLTRISTAR",
"CDLUNIQUE3RIVER", "CDLUPSIDEGAP2CROWS", "CDLXSIDEGAP3METHODS"],
"statistic": ["BETA", "CORREL", "LINEARREG", "LINEARREG_ANGLE", "LINEARREG_INTERCEPT",
"LINEARREG_SLOPE", "STDDEV", "TSF", "VAR"],
"math": ["ADD", "DIV", "MAX", "MAXINDEX", "MIN", "MININDEX", "MINMAX", "MINMAXINDEX",
"MULT", "SUB", "SUM"],
}
def _get_function_category(func_name: str) -> str:
"""Determine the category of a TA-Lib function."""
for category, functions in TALIB_CATEGORIES.items():
if func_name in functions:
return category
return "other"
class TALibIndicator(Indicator):
"""
Generic adapter for TA-Lib technical indicators.
Wraps any TA-Lib function to work within the composable indicator framework.
Handles parameter mapping, input validation, and output formatting.
"""
# Class variable to store the TA-Lib function name
talib_function_name: Optional[str] = None
def __init__(self, instance_name: str, **params):
"""
Initialize a TA-Lib indicator.
Args:
instance_name: Unique name for this instance
**params: TA-Lib function parameters
"""
if not TALIB_AVAILABLE:
raise ImportError(
"TA-Lib is not installed. Please install the TA-Lib C library "
"and Python wrapper. See indicator/talib_adapter.py for instructions."
)
super().__init__(instance_name, **params)
self._talib_func = abstract.Function(self.talib_function_name)
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
"""Get metadata from TA-Lib function info."""
if not TALIB_AVAILABLE:
raise ImportError("TA-Lib is not installed")
func = abstract.Function(cls.talib_function_name)
info = func.info
# Build parameters list from TA-Lib function info
parameters = []
for param_name, param_info in info.get("parameters", {}).items():
# Handle case where param_info is a simple value (int/float) instead of a dict
if isinstance(param_info, dict):
param_type = TALIB_TYPE_MAP.get(param_info.get("type", "double"), "float")
default_value = param_info.get("default_value")
else:
# param_info is a simple value (default), infer type from the value
if isinstance(param_info, int):
param_type = "int"
elif isinstance(param_info, float):
param_type = "float"
else:
param_type = "float" # Default to float
default_value = param_info
parameters.append(
IndicatorParameter(
name=param_name,
type=param_type,
description=f"TA-Lib parameter: {param_name}",
default=default_value,
required=False
)
)
# Get function group/category
category = _get_function_category(cls.talib_function_name)
# Build display name (strip the CDL prefix from candlestick patterns)
display_name = cls.talib_function_name
if display_name.startswith("CDL"):
display_name = display_name[3:] # Remove CDL prefix for patterns
return IndicatorMetadata(
name=cls.talib_function_name,
display_name=display_name,
description=info.get("display_name", f"TA-Lib {cls.talib_function_name} indicator"),
category=category,
parameters=parameters,
use_cases=[f"Technical analysis using {cls.talib_function_name}"],
references=["https://ta-lib.org/function.html"],
tags=["talib", category, cls.talib_function_name.lower()]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
"""
Get input schema from TA-Lib function requirements.
Most TA-Lib functions use OHLCV data, but some use subsets.
"""
if not TALIB_AVAILABLE:
raise ImportError("TA-Lib is not installed")
func = abstract.Function(cls.talib_function_name)
info = func.info
input_names = info.get("input_names", {})
required_columns = []
# Map TA-Lib input names to our schema
if "prices" in input_names:
price_inputs = input_names["prices"]
if "open" in price_inputs:
required_columns.append(
ColumnInfo(name="open", type="float", description="Opening price")
)
if "high" in price_inputs:
required_columns.append(
ColumnInfo(name="high", type="float", description="High price")
)
if "low" in price_inputs:
required_columns.append(
ColumnInfo(name="low", type="float", description="Low price")
)
if "close" in price_inputs:
required_columns.append(
ColumnInfo(name="close", type="float", description="Closing price")
)
if "volume" in price_inputs:
required_columns.append(
ColumnInfo(name="volume", type="float", description="Trading volume")
)
# Handle functions that take generic price arrays
if "price" in input_names:
required_columns.append(
ColumnInfo(name="close", type="float", description="Price (typically close)")
)
# If no specific inputs found, assume close price
if not required_columns:
required_columns.append(
ColumnInfo(name="close", type="float", description="Closing price")
)
return InputSchema(required_columns=required_columns)
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
"""Get output schema from TA-Lib function outputs."""
if not TALIB_AVAILABLE:
raise ImportError("TA-Lib is not installed")
func = abstract.Function(cls.talib_function_name)
info = func.info
output_names = info.get("output_names", [])
columns = []
# Most TA-Lib functions output one or more float arrays
if isinstance(output_names, list):
for output_name in output_names:
columns.append(
ColumnInfo(
name=output_name.lower(),
type="float",
description=f"{cls.talib_function_name} output: {output_name}",
nullable=True # TA-Lib often has NaN for initial periods
)
)
else:
# Single output, use function name
columns.append(
ColumnInfo(
name=cls.talib_function_name.lower(),
type="float",
description=f"{cls.talib_function_name} indicator value",
nullable=True
)
)
return OutputSchema(columns=columns)
def compute(self, context: ComputeContext) -> ComputeResult:
"""Compute indicator using TA-Lib."""
# Extract input columns
input_data = {}
# Get the function's expected inputs
info = self._talib_func.info
input_names = info.get("input_names", {})
# Prepare input arrays
if "prices" in input_names:
price_inputs = input_names["prices"]
for price_type in price_inputs:
column_data = context.get_column(price_type)
# Convert to numpy array, replacing None with NaN
input_data[price_type] = np.array(
[float(v) if v is not None else np.nan for v in column_data]
)
elif "price" in input_names:
# Generic price input, use close
column_data = context.get_column("close")
input_data["price"] = np.array(
[float(v) if v is not None else np.nan for v in column_data]
)
else:
# Default to close if no inputs specified
column_data = context.get_column("close")
input_data["close"] = np.array(
[float(v) if v is not None else np.nan for v in column_data]
)
# Set parameters on the function
self._talib_func.parameters = self.params
# Execute TA-Lib function
try:
output = self._talib_func(input_data)
except Exception as e:
logger.error(f"TA-Lib function {self.talib_function_name} failed: {e}")
raise ValueError(f"TA-Lib computation failed: {e}")
# Format output
times = context.get_times()
output_names = info.get("output_names", [])
# Handle single vs multiple outputs
if isinstance(output, np.ndarray):
# Single output
output_name = output_names[0].lower() if output_names else self.talib_function_name.lower()
result_data = [
{
"time": times[i],
output_name: float(output[i]) if not np.isnan(output[i]) else None
}
for i in range(len(times))
]
elif isinstance(output, tuple):
# Multiple outputs
result_data = []
for i in range(len(times)):
row = {"time": times[i]}
for j, output_array in enumerate(output):
output_name = output_names[j].lower() if j < len(output_names) else f"output_{j}"
row[output_name] = float(output_array[i]) if not np.isnan(output_array[i]) else None
result_data.append(row)
else:
raise ValueError(f"Unexpected TA-Lib output type: {type(output)}")
return ComputeResult(
data=result_data,
is_partial=context.is_incremental
)
def create_talib_indicator_class(func_name: str) -> type:
"""
Dynamically create an Indicator class for a TA-Lib function.
Args:
func_name: TA-Lib function name (e.g., 'SMA', 'RSI')
Returns:
Indicator class for this function
"""
return type(
f"TALib_{func_name}",
(TALibIndicator,),
{"talib_function_name": func_name}
)
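# Illustrative sketch (requires a working TA-Lib install; "SMA" and its
# "timeperiod" parameter are standard TA-Lib names):
def _example_dynamic_sma() -> ComputeResult:
    sma_cls = create_talib_indicator_class("SMA")
    sma = sma_cls(instance_name="sma_3", timeperiod=3)
    context = ComputeContext(
        data=[
            {"time": t, "close": c}
            for t, c in enumerate([10.0, 11.0, 12.0, 13.0], start=1)
        ]
    )
    # The first two rows come back as None while the 3-bar window fills.
    return sma.compute(context)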
def register_all_talib_indicators(registry, only_tradingview_supported: bool = True) -> int:
"""
Auto-register all available TA-Lib indicators with the registry.
Args:
registry: IndicatorRegistry instance
only_tradingview_supported: If True, only register indicators that have
TradingView equivalents (default: True)
Returns:
Number of indicators registered (0, with a warning, if TA-Lib is not installed)
"""
if not TALIB_AVAILABLE:
logger.warning(
"TA-Lib is not installed. Skipping TA-Lib indicator registration. "
"Install TA-Lib C library and Python wrapper to enable TA-Lib indicators."
)
return 0
# TradingView-support lookup, used when filtering is enabled below
from .tv_mapping import is_indicator_supported
# Get all TA-Lib functions
func_groups = talib.get_function_groups()
all_functions = []
for group, functions in func_groups.items():
all_functions.extend(functions)
# Remove duplicates
all_functions = sorted(set(all_functions))
registered_count = 0
skipped_count = 0
for func_name in all_functions:
try:
# Skip if filtering enabled and indicator not supported in TradingView
if only_tradingview_supported and not is_indicator_supported(func_name):
skipped_count += 1
logger.debug(f"Skipping TA-Lib function {func_name} - not supported in TradingView")
continue
# Create indicator class for this function
indicator_class = create_talib_indicator_class(func_name)
# Register with the registry
registry.register(indicator_class)
registered_count += 1
except Exception as e:
logger.warning(f"Failed to register TA-Lib function {func_name}: {e}")
continue
logger.info(f"Registered {registered_count} TA-Lib indicators (skipped {skipped_count} unsupported)")
return registered_count
def get_talib_version() -> Optional[str]:
"""
Get the installed TA-Lib version.
Returns:
Version string or None if not installed
"""
if TALIB_AVAILABLE:
return talib.__version__
return None
def is_talib_available() -> bool:
"""Check if TA-Lib is available."""
return TALIB_AVAILABLE

View File

@@ -0,0 +1,360 @@
"""
Mapping layer between TA-Lib indicators and TradingView indicators.
This module provides bidirectional conversion between our internal TA-Lib-based
indicator representation and TradingView's indicator system.
"""
from typing import Dict, Any, Optional, Tuple, List
import logging
logger = logging.getLogger(__name__)
# Mapping of TA-Lib indicator names to TradingView indicator names
# Only includes indicators that are present in BOTH systems (inner join)
# Format: {talib_name: tv_name}
TALIB_TO_TV_NAMES = {
# Overlap Studies (14)
"SMA": "Moving Average",
"EMA": "Moving Average Exponential",
"WMA": "Weighted Moving Average",
"DEMA": "DEMA",
"TEMA": "TEMA",
"TRIMA": "Triangular Moving Average",
"KAMA": "KAMA",
"MAMA": "MESA Adaptive Moving Average",
"T3": "T3",
"BBANDS": "Bollinger Bands",
"MIDPOINT": "Midpoint",
"MIDPRICE": "Midprice",
"SAR": "Parabolic SAR",
"HT_TRENDLINE": "Hilbert Transform - Instantaneous Trendline",
# Momentum Indicators (21)
"RSI": "Relative Strength Index",
"MOM": "Momentum",
"ROC": "Rate of Change",
"TRIX": "TRIX",
"CMO": "Chande Momentum Oscillator",
"DX": "Directional Movement Index",
"ADX": "Average Directional Movement Index",
"ADXR": "Average Directional Movement Index Rating",
"APO": "Absolute Price Oscillator",
"PPO": "Percentage Price Oscillator",
"MACD": "MACD",
"MFI": "Money Flow Index",
"STOCH": "Stochastic",
"STOCHF": "Stochastic Fast",
"STOCHRSI": "Stochastic RSI",
"WILLR": "Williams %R",
"CCI": "Commodity Channel Index",
"AROON": "Aroon",
"AROONOSC": "Aroon Oscillator",
"BOP": "Balance Of Power",
"ULTOSC": "Ultimate Oscillator",
# Volume Indicators (3)
"AD": "Chaikin A/D Line",
"ADOSC": "Chaikin A/D Oscillator",
"OBV": "On Balance Volume",
# Volatility Indicators (3)
"ATR": "Average True Range",
"NATR": "Normalized Average True Range",
"TRANGE": "True Range",
# Price Transform (4)
"AVGPRICE": "Average Price",
"MEDPRICE": "Median Price",
"TYPPRICE": "Typical Price",
"WCLPRICE": "Weighted Close Price",
# Cycle Indicators (5)
"HT_DCPERIOD": "Hilbert Transform - Dominant Cycle Period",
"HT_DCPHASE": "Hilbert Transform - Dominant Cycle Phase",
"HT_PHASOR": "Hilbert Transform - Phasor Components",
"HT_SINE": "Hilbert Transform - SineWave",
"HT_TRENDMODE": "Hilbert Transform - Trend vs Cycle Mode",
# Statistic Functions (9)
"BETA": "Beta",
"CORREL": "Pearson's Correlation Coefficient",
"LINEARREG": "Linear Regression",
"LINEARREG_ANGLE": "Linear Regression Angle",
"LINEARREG_INTERCEPT": "Linear Regression Intercept",
"LINEARREG_SLOPE": "Linear Regression Slope",
"STDDEV": "Standard Deviation",
"TSF": "Time Series Forecast",
"VAR": "Variance",
}
# Total: 59 indicators supported in both systems (14 + 21 + 3 + 3 + 4 + 5 + 9)
# Custom indicators (TradingView indicators implemented in our backend)
CUSTOM_TO_TV_NAMES = {
"VWAP": "VWAP",
"VWMA": "VWMA",
"HMA": "Hull Moving Average",
"SUPERTREND": "SuperTrend",
"DONCHIAN": "Donchian Channels",
"KELTNER": "Keltner Channels",
"CMF": "Chaikin Money Flow",
"VORTEX": "Vortex Indicator",
"AO": "Awesome Oscillator",
"AC": "Accelerator Oscillator",
"CHOP": "Choppiness Index",
"MASS": "Mass Index",
}
# Combined mapping (TA-Lib + Custom)
ALL_BACKEND_TO_TV_NAMES = {**TALIB_TO_TV_NAMES, **CUSTOM_TO_TV_NAMES}
# Total: 71 indicators (59 TA-Lib + 12 custom)
# Reverse mapping
TV_TO_TALIB_NAMES = {v: k for k, v in TALIB_TO_TV_NAMES.items()}
TV_TO_CUSTOM_NAMES = {v: k for k, v in CUSTOM_TO_TV_NAMES.items()}
TV_TO_BACKEND_NAMES = {v: k for k, v in ALL_BACKEND_TO_TV_NAMES.items()}
def get_tv_indicator_name(talib_name: str) -> str:
"""
Convert TA-Lib indicator name to TradingView indicator name.
Args:
talib_name: TA-Lib indicator name (e.g., 'RSI')
Returns:
TradingView indicator name (falls back to the TA-Lib name if unmapped)
"""
return TALIB_TO_TV_NAMES.get(talib_name, talib_name)
def get_talib_indicator_name(tv_name: str) -> Optional[str]:
"""
Convert TradingView indicator name to TA-Lib indicator name.
Args:
tv_name: TradingView indicator name
Returns:
TA-Lib indicator name or None if not mapped
"""
return TV_TO_TALIB_NAMES.get(tv_name)
def convert_talib_params_to_tv_inputs(
talib_name: str,
talib_params: Dict[str, Any]
) -> Dict[str, Any]:
"""
Convert TA-Lib parameters to TradingView input format.
Args:
talib_name: TA-Lib indicator name
talib_params: TA-Lib parameter dictionary
Returns:
TradingView inputs dictionary
"""
tv_inputs = {}
# Common parameter mappings
param_mapping = {
"timeperiod": "length",
"fastperiod": "fastLength",
"slowperiod": "slowLength",
"signalperiod": "signalLength",
"nbdevup": "mult", # Standard deviations for upper band
"nbdevdn": "mult", # Standard deviations for lower band
"fastlimit": "fastLimit",
"slowlimit": "slowLimit",
"acceleration": "start",
"maximum": "increment",
"fastk_period": "kPeriod",
"slowk_period": "kPeriod",
"slowd_period": "dPeriod",
"fastd_period": "dPeriod",
"matype": "maType",
}
# Special handling for specific indicators
if talib_name == "BBANDS":
# Bollinger Bands
tv_inputs["length"] = talib_params.get("timeperiod", 20)
tv_inputs["mult"] = talib_params.get("nbdevup", 2)
tv_inputs["source"] = "close"
elif talib_name == "MACD":
# MACD
tv_inputs["fastLength"] = talib_params.get("fastperiod", 12)
tv_inputs["slowLength"] = talib_params.get("slowperiod", 26)
tv_inputs["signalLength"] = talib_params.get("signalperiod", 9)
tv_inputs["source"] = "close"
elif talib_name == "RSI":
# RSI
tv_inputs["length"] = talib_params.get("timeperiod", 14)
tv_inputs["source"] = "close"
elif talib_name in ["SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA"]:
# Moving averages
tv_inputs["length"] = talib_params.get("timeperiod", 14)
tv_inputs["source"] = "close"
elif talib_name == "STOCH":
# Stochastic
tv_inputs["kPeriod"] = talib_params.get("fastk_period", 14)
tv_inputs["dPeriod"] = talib_params.get("slowd_period", 3)
tv_inputs["smoothK"] = talib_params.get("slowk_period", 3)
elif talib_name == "ATR":
# ATR
tv_inputs["length"] = talib_params.get("timeperiod", 14)
elif talib_name == "CCI":
# CCI
tv_inputs["length"] = talib_params.get("timeperiod", 20)
else:
# Generic parameter conversion
for talib_param, value in talib_params.items():
tv_param = param_mapping.get(talib_param, talib_param)
tv_inputs[tv_param] = value
logger.debug(f"Converted TA-Lib params for {talib_name}: {talib_params} -> TV inputs: {tv_inputs}")
return tv_inputs
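# Illustrative sketch: MACD takes the special-cased branch above, so its
# TA-Lib period names land on TradingView's fast/slow/signal lengths.
def _example_macd_to_tv() -> Dict[str, Any]:
    tv_inputs = convert_talib_params_to_tv_inputs(
        "MACD", {"fastperiod": 12, "slowperiod": 26, "signalperiod": 9}
    )
    # {"fastLength": 12, "slowLength": 26, "signalLength": 9, "source": "close"}
    return tv_inputs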
def convert_tv_inputs_to_talib_params(
tv_name: str,
tv_inputs: Dict[str, Any]
) -> Tuple[Optional[str], Dict[str, Any]]:
"""
Convert TradingView inputs to TA-Lib parameters.
Args:
tv_name: TradingView indicator name
tv_inputs: TradingView inputs dictionary
Returns:
Tuple of (talib_name, talib_params)
"""
talib_name = get_talib_indicator_name(tv_name)
if not talib_name:
logger.warning(f"No TA-Lib mapping for TradingView indicator: {tv_name}")
return None, {}
talib_params = {}
# Reverse parameter mappings
reverse_mapping = {
"length": "timeperiod",
"fastLength": "fastperiod",
"slowLength": "slowperiod",
"signalLength": "signalperiod",
"mult": "nbdevup", # Use same for both up and down
"fastLimit": "fastlimit",
"slowLimit": "slowlimit",
"start": "acceleration",
"increment": "maximum",
"kPeriod": "fastk_period",
"dPeriod": "slowd_period",
"smoothK": "slowk_period",
"maType": "matype",
}
# Special handling for specific indicators
if talib_name == "BBANDS":
# Bollinger Bands
talib_params["timeperiod"] = tv_inputs.get("length", 20)
talib_params["nbdevup"] = tv_inputs.get("mult", 2)
talib_params["nbdevdn"] = tv_inputs.get("mult", 2)
talib_params["matype"] = 0 # SMA
elif talib_name == "MACD":
# MACD
talib_params["fastperiod"] = tv_inputs.get("fastLength", 12)
talib_params["slowperiod"] = tv_inputs.get("slowLength", 26)
talib_params["signalperiod"] = tv_inputs.get("signalLength", 9)
elif talib_name == "RSI":
# RSI
talib_params["timeperiod"] = tv_inputs.get("length", 14)
elif talib_name in ["SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA"]:
# Moving averages
talib_params["timeperiod"] = tv_inputs.get("length", 14)
elif talib_name == "STOCH":
# Stochastic
talib_params["fastk_period"] = tv_inputs.get("kPeriod", 14)
talib_params["slowd_period"] = tv_inputs.get("dPeriod", 3)
talib_params["slowk_period"] = tv_inputs.get("smoothK", 3)
talib_params["slowk_matype"] = 0 # SMA
talib_params["slowd_matype"] = 0 # SMA
elif talib_name == "ATR":
# ATR
talib_params["timeperiod"] = tv_inputs.get("length", 14)
elif talib_name == "CCI":
# CCI
talib_params["timeperiod"] = tv_inputs.get("length", 20)
else:
# Generic parameter conversion
for tv_param, value in tv_inputs.items():
if tv_param == "source":
continue # Skip source parameter
talib_param = reverse_mapping.get(tv_param, tv_param)
talib_params[talib_param] = value
logger.debug(f"Converted TV inputs for {tv_name}: {tv_inputs} -> TA-Lib {talib_name} params: {talib_params}")
return talib_name, talib_params
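# Illustrative sketch: the BBANDS branch shows where the conversion is
# deliberately lossy -- one TradingView "mult" feeds both nbdevup and
# nbdevdn, and matype is pinned to SMA (0).
def _example_bbands_from_tv() -> Dict[str, Any]:
    name, params = convert_tv_inputs_to_talib_params(
        "Bollinger Bands", {"length": 20, "mult": 2, "source": "close"}
    )
    assert name == "BBANDS"
    # params == {"timeperiod": 20, "nbdevup": 2, "nbdevdn": 2, "matype": 0}
    return params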
def is_indicator_supported(talib_name: str) -> bool:
"""
Check if a TA-Lib indicator is supported in TradingView.
Args:
talib_name: TA-Lib indicator name
Returns:
True if supported
"""
return talib_name in TALIB_TO_TV_NAMES
def get_supported_indicators() -> List[str]:
"""
Get list of supported TA-Lib indicators.
Returns:
List of TA-Lib indicator names
"""
return list(TALIB_TO_TV_NAMES.keys())
def get_supported_indicator_count() -> int:
"""
Get count of supported indicators.
Returns:
Number of indicators supported in both systems (TA-Lib + Custom)
"""
return len(ALL_BACKEND_TO_TV_NAMES)
def is_custom_indicator(indicator_name: str) -> bool:
"""
Check if an indicator is a custom implementation (not TA-Lib).
Args:
indicator_name: Indicator name
Returns:
True if custom indicator
"""
return indicator_name in CUSTOM_TO_TV_NAMES
def get_backend_indicator_name(tv_name: str) -> Optional[str]:
"""
Get backend indicator name from TradingView name (TA-Lib or custom).
Args:
tv_name: TradingView indicator name
Returns:
Backend indicator name or None if not mapped
"""
return TV_TO_BACKEND_NAMES.get(tv_name)
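# Quick self-check (illustrative): the helpers agree with the tables above.
if __name__ == "__main__":
    assert is_indicator_supported("RSI")
    assert not is_indicator_supported("SAREXT")  # in TA-Lib, but not mapped here
    assert is_custom_indicator("VWAP")
    assert get_backend_indicator_name("Relative Strength Index") == "RSI"
    print(f"{get_supported_indicator_count()} indicators mapped")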