# ai/backend.old/src/indicator/pipeline.py
"""
Pipeline execution engine for composable indicators.
Manages DAG construction, dependency resolution, incremental updates,
and efficient data flow through indicator chains.
"""
import logging
from collections import defaultdict, deque
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from datasource.schema import ColumnInfo
from .base import DataSourceAdapter, Indicator
from .schema import ComputeContext
logger = logging.getLogger(__name__)
class PipelineNode:
"""
A node in the pipeline DAG.
Can be either a DataSource adapter or an Indicator instance.
"""
def __init__(
self,
node_id: str,
node: Union[DataSourceAdapter, Indicator],
dependencies: List[str]
):
"""
Create a pipeline node.
Args:
node_id: Unique identifier for this node
node: The DataSourceAdapter or Indicator instance
dependencies: List of node_ids this node depends on
"""
self.node_id = node_id
self.node = node
self.dependencies = dependencies
self.output_columns: List[str] = []
self.cached_data: List[Dict[str, Any]] = []
def is_datasource(self) -> bool:
"""Check if this node is a DataSource adapter."""
return isinstance(self.node, DataSourceAdapter)
def is_indicator(self) -> bool:
"""Check if this node is an Indicator."""
return isinstance(self.node, Indicator)
def __repr__(self) -> str:
return f"PipelineNode(id='{self.node_id}', node={self.node}, deps={self.dependencies})"
class Pipeline:
"""
Execution engine for indicator DAGs.
Manages:
- DAG construction and validation
- Topological sorting for execution order
- Data flow and caching
- Incremental updates (only recompute what changed)
- Schema validation
"""
def __init__(self, datasource_registry):
"""
Initialize a pipeline.
Args:
datasource_registry: DataSourceRegistry for resolving data sources
"""
self.datasource_registry = datasource_registry
self.nodes: Dict[str, PipelineNode] = {}
self.execution_order: List[str] = []
self._dirty_nodes: Set[str] = set()
def add_datasource(
self,
node_id: str,
datasource_name: str,
symbol: str,
resolution: str
) -> None:
"""
Add a DataSource to the pipeline.
Args:
node_id: Unique identifier for this node
datasource_name: Name of the datasource in the registry
symbol: Symbol to query
resolution: Time resolution
Raises:
ValueError: If node_id already exists or datasource not found
"""
if node_id in self.nodes:
raise ValueError(f"Node '{node_id}' already exists in pipeline")
datasource = self.datasource_registry.get(datasource_name)
if not datasource:
raise ValueError(f"DataSource '{datasource_name}' not found in registry")
adapter = DataSourceAdapter(datasource_name, symbol, resolution)
node = PipelineNode(node_id, adapter, dependencies=[])
self.nodes[node_id] = node
self._invalidate_execution_order()
logger.info(f"Added DataSource node '{node_id}': {datasource_name}/{symbol}@{resolution}")
def add_indicator(
self,
node_id: str,
indicator: Indicator,
input_node_ids: List[str]
) -> None:
"""
Add an Indicator to the pipeline.
Args:
node_id: Unique identifier for this node
indicator: Indicator instance
input_node_ids: List of node IDs providing input data
Raises:
ValueError: If node_id already exists, dependencies not found, or schema mismatch
"""
if node_id in self.nodes:
raise ValueError(f"Node '{node_id}' already exists in pipeline")
# Validate dependencies exist
for dep_id in input_node_ids:
if dep_id not in self.nodes:
raise ValueError(f"Dependency node '{dep_id}' not found in pipeline")
# TODO: Validate input schema matches available columns from dependencies
# This requires merging output schemas from all input nodes
node = PipelineNode(node_id, indicator, dependencies=input_node_ids)
self.nodes[node_id] = node
self._invalidate_execution_order()
logger.info(f"Added Indicator node '{node_id}': {indicator} with inputs {input_node_ids}")
def remove_node(self, node_id: str) -> None:
"""
Remove a node from the pipeline.
Args:
node_id: Node to remove
Raises:
ValueError: If other nodes depend on this node
"""
if node_id not in self.nodes:
return
# Check for dependent nodes
dependents = [
n.node_id for n in self.nodes.values()
if node_id in n.dependencies
]
if dependents:
raise ValueError(
f"Cannot remove node '{node_id}': nodes {dependents} depend on it"
)
del self.nodes[node_id]
self._invalidate_execution_order()
logger.info(f"Removed node '{node_id}' from pipeline")
def _invalidate_execution_order(self) -> None:
"""Mark execution order as needing recomputation."""
self.execution_order = []
def _compute_execution_order(self) -> List[str]:
"""
Compute topological sort of the DAG.
Returns:
List of node IDs in execution order
Raises:
ValueError: If DAG contains cycles
"""
if self.execution_order:
return self.execution_order
        # Kahn's algorithm for topological sort
        in_degree = {
            node_id: len(node.dependencies)
            for node_id, node in self.nodes.items()
        }
        # Build a reverse adjacency map (dependency -> dependents) once so
        # each edge is visited a single time, instead of rescanning every
        # node on each pop. This also counts duplicate dependency entries
        # consistently on both sides.
        dependents: Dict[str, List[str]] = defaultdict(list)
        for node in self.nodes.values():
            for dep in node.dependencies:
                dependents[dep].append(node.node_id)
        queue = deque(node_id for node_id, degree in in_degree.items() if degree == 0)
        result = []
        while queue:
            node_id = queue.popleft()
            result.append(node_id)
            for dependent_id in dependents[node_id]:
                in_degree[dependent_id] -= 1
                if in_degree[dependent_id] == 0:
                    queue.append(dependent_id)
if len(result) != len(self.nodes):
raise ValueError("Pipeline contains cycles")
self.execution_order = result
logger.debug(f"Computed execution order: {result}")
return result
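    # For a diamond DAG (prices -> fast, prices -> slow, {fast, slow} ->
    # spread), _compute_execution_order yields ["prices", "fast", "slow",
    # "spread"]; ties between zero in-degree nodes resolve in node insertion
    # order because the deque is seeded and appended in that order.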
def execute(
self,
datasource_data: Dict[str, List[Dict[str, Any]]],
incremental: bool = False,
updated_from_time: Optional[int] = None
) -> Dict[str, List[Dict[str, Any]]]:
"""
Execute the pipeline.
Args:
datasource_data: Mapping of DataSource node_id to input data
incremental: Whether this is an incremental update
updated_from_time: Timestamp of earliest updated row (for incremental)
Returns:
Dictionary mapping node_id to output data (all nodes)
Raises:
ValueError: If required datasource data is missing
"""
execution_order = self._compute_execution_order()
results: Dict[str, List[Dict[str, Any]]] = {}
logger.info(
f"Executing pipeline with {len(execution_order)} nodes "
f"(incremental={incremental})"
)
for node_id in execution_order:
node = self.nodes[node_id]
if node.is_datasource():
# DataSource node - get data from input
if node_id not in datasource_data:
raise ValueError(
f"DataSource node '{node_id}' has no input data"
)
results[node_id] = datasource_data[node_id]
node.cached_data = results[node_id]
logger.debug(f"DataSource node '{node_id}': {len(results[node_id])} rows")
elif node.is_indicator():
# Indicator node - compute from dependencies
indicator = node.node
# Merge input data from all dependencies
input_data = self._merge_dependency_data(node.dependencies, results)
# Create compute context
context = ComputeContext(
data=input_data,
is_incremental=incremental,
updated_from_time=updated_from_time
)
# Execute indicator
logger.debug(
f"Computing indicator '{node_id}' with {len(input_data)} input rows"
)
compute_result = indicator.compute(context)
# Merge result with input data (adding prefixed columns)
output_data = compute_result.merge_with_prefix(
indicator.instance_name,
input_data
)
results[node_id] = output_data
node.cached_data = output_data
logger.debug(f"Indicator node '{node_id}': {len(output_data)} rows")
logger.info(f"Pipeline execution complete: {len(results)} nodes processed")
return results
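    # Incremental sketch for execute() (names illustrative): a caller holding
    # fresh rows for the "prices" node might run
    #
    #     pipeline.execute(
    #         {"prices": new_rows},
    #         incremental=True,
    #         updated_from_time=new_rows[0]["time"],
    #     )
    #
    # The flags are forwarded through ComputeContext; whether an indicator
    # actually short-circuits on them is up to its compute() implementation.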
def _merge_dependency_data(
self,
dependency_ids: List[str],
results: Dict[str, List[Dict[str, Any]]]
) -> List[Dict[str, Any]]:
"""
Merge data from multiple dependency nodes.
Data is merged by time, with later dependencies overwriting earlier ones
for conflicting column names.
Args:
dependency_ids: List of node IDs to merge
results: Current execution results
Returns:
Merged data rows
"""
if not dependency_ids:
return []
if len(dependency_ids) == 1:
return results[dependency_ids[0]]
# Build time-indexed data from first dependency
merged: Dict[int, Dict[str, Any]] = {}
for row in results[dependency_ids[0]]:
merged[row["time"]] = row.copy()
# Merge in additional dependencies
for dep_id in dependency_ids[1:]:
for row in results[dep_id]:
time_key = row["time"]
if time_key in merged:
# Merge columns (later dependencies win)
merged[time_key].update(row)
else:
# New timestamp
merged[time_key] = row.copy()
# Sort by time and return
sorted_times = sorted(merged.keys())
return [merged[t] for t in sorted_times]
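    # Merge semantics of _merge_dependency_data, concretely: given
    #     a = [{"time": 1, "close": 10.0}]
    #     b = [{"time": 1, "sma": 9.5}, {"time": 2, "sma": 9.8}]
    # merging ["a", "b"] produces
    #     [{"time": 1, "close": 10.0, "sma": 9.5}, {"time": 2, "sma": 9.8}]
    # with b winning any column-name collision at a shared timestamp.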
def get_node_output(self, node_id: str) -> Optional[List[Dict[str, Any]]]:
"""
Get cached output data for a specific node.
Args:
node_id: Node identifier
Returns:
Cached data or None if not available
"""
node = self.nodes.get(node_id)
return node.cached_data if node else None
def get_output_schema(self, node_id: str) -> List[ColumnInfo]:
"""
Get the output schema for a specific node.
Args:
node_id: Node identifier
Returns:
List of ColumnInfo describing output columns
Raises:
ValueError: If node not found
"""
node = self.nodes.get(node_id)
if not node:
raise ValueError(f"Node '{node_id}' not found")
if node.is_datasource():
# Would need to query the actual datasource at runtime
# For now, return empty - this requires integration with DataSource
return []
elif node.is_indicator():
indicator = node.node
output_schema = indicator.get_output_schema(**indicator.params)
prefixed_schema = output_schema.with_prefix(indicator.instance_name)
return prefixed_schema.columns
return []
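    # Prefixing sketch: an indicator whose instance_name is "sma20" and which
    # emits a "value" column is reported under a prefixed name such as
    # "sma20_value" (the exact separator is whatever with_prefix in .schema
    # applies; it is not defined in this file).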
def validate_pipeline(self) -> Tuple[bool, Optional[str]]:
"""
Validate the entire pipeline for correctness.
Checks:
- No cycles (already checked in execution order)
- All dependencies exist (already checked in add_indicator)
- Input schemas match output schemas (TODO)
Returns:
Tuple of (is_valid, error_message)
"""
try:
self._compute_execution_order()
return True, None
except ValueError as e:
return False, str(e)
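    # Usage sketch:
    #
    #     ok, err = pipeline.validate_pipeline()
    #     if not ok:
    #         logger.error("Pipeline invalid: %s", err)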
def get_node_count(self) -> int:
"""Get the number of nodes in the pipeline."""
return len(self.nodes)
def get_indicator_count(self) -> int:
"""Get the number of indicator nodes in the pipeline."""
return sum(1 for node in self.nodes.values() if node.is_indicator())
def get_datasource_count(self) -> int:
"""Get the number of datasource nodes in the pipeline."""
return sum(1 for node in self.nodes.values() if node.is_datasource())
def describe(self) -> Dict[str, Any]:
"""
Get a detailed description of the pipeline structure.
Returns:
Dictionary with pipeline metadata and structure
"""
        is_valid, _ = self.validate_pipeline()
        return {
            "node_count": self.get_node_count(),
            "datasource_count": self.get_datasource_count(),
            "indicator_count": self.get_indicator_count(),
            "nodes": [
                {
                    "id": node.node_id,
                    "type": "datasource" if node.is_datasource() else "indicator",
                    "node": str(node.node),
                    "dependencies": node.dependencies,
                    "cached_rows": len(node.cached_data)
                }
                for node in self.nodes.values()
            ],
            # validate_pipeline() caches the order on success; skip it for a
            # cyclic pipeline so describe() reports is_valid=False instead
            # of raising.
            "execution_order": self.execution_order if is_valid else [],
            "is_valid": is_valid
        }
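    # describe() returns a plain dict, e.g. (values illustrative):
    #
    #     {
    #         "node_count": 2,
    #         "datasource_count": 1,
    #         "indicator_count": 1,
    #         "nodes": [...],
    #         "execution_order": ["prices", "sma"],
    #         "is_valid": True,
    #     }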