"""
Pipeline execution engine for composable indicators.

Manages DAG construction, dependency resolution, incremental updates,
and efficient data flow through indicator chains.
"""
import logging
from collections import defaultdict, deque
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from datasource.base import DataSource
from datasource.schema import ColumnInfo

from .base import DataSourceAdapter, Indicator
from .schema import ComputeContext, ComputeResult

logger = logging.getLogger(__name__)


class PipelineNode:
    """
    A node in the pipeline DAG.

    Can be either a DataSource adapter (a leaf that receives external data)
    or an Indicator instance (computed from the outputs of its dependencies).
    """

    def __init__(
        self,
        node_id: str,
        node: Union[DataSourceAdapter, Indicator],
        dependencies: List[str]
    ):
        """
        Create a pipeline node.

        Args:
            node_id: Unique identifier for this node
            node: The DataSourceAdapter or Indicator instance
            dependencies: List of node_ids this node depends on
        """
        self.node_id = node_id
        self.node = node
        self.dependencies = dependencies
        # Names of columns this node produces (populated externally as needed).
        self.output_columns: List[str] = []
        # Output rows from the most recent pipeline execution.
        self.cached_data: List[Dict[str, Any]] = []

    def is_datasource(self) -> bool:
        """Check if this node is a DataSource adapter."""
        return isinstance(self.node, DataSourceAdapter)

    def is_indicator(self) -> bool:
        """Check if this node is an Indicator."""
        return isinstance(self.node, Indicator)

    def __repr__(self) -> str:
        return f"PipelineNode(id='{self.node_id}', node={self.node}, deps={self.dependencies})"


class Pipeline:
    """
    Execution engine for indicator DAGs.

    Manages:
    - DAG construction and validation
    - Topological sorting for execution order
    - Data flow and caching
    - Incremental updates (only recompute what changed)
    - Schema validation
    """

    def __init__(self, datasource_registry):
        """
        Initialize a pipeline.

        Args:
            datasource_registry: DataSourceRegistry for resolving data sources
        """
        self.datasource_registry = datasource_registry
        self.nodes: Dict[str, PipelineNode] = {}
        # Cached topological order; empty list means "needs recomputation".
        self.execution_order: List[str] = []
        self._dirty_nodes: Set[str] = set()

    def add_datasource(
        self,
        node_id: str,
        datasource_name: str,
        symbol: str,
        resolution: str
    ) -> None:
        """
        Add a DataSource to the pipeline.

        Args:
            node_id: Unique identifier for this node
            datasource_name: Name of the datasource in the registry
            symbol: Symbol to query
            resolution: Time resolution

        Raises:
            ValueError: If node_id already exists or datasource not found
        """
        if node_id in self.nodes:
            raise ValueError(f"Node '{node_id}' already exists in pipeline")

        # Resolve now only to fail fast; the adapter stores the name, not the
        # datasource instance, so actual data is supplied at execute() time.
        if not self.datasource_registry.get(datasource_name):
            raise ValueError(f"DataSource '{datasource_name}' not found in registry")

        adapter = DataSourceAdapter(datasource_name, symbol, resolution)
        self.nodes[node_id] = PipelineNode(node_id, adapter, dependencies=[])
        self._invalidate_execution_order()
        logger.info(
            "Added DataSource node '%s': %s/%s@%s",
            node_id, datasource_name, symbol, resolution
        )

    def add_indicator(
        self,
        node_id: str,
        indicator: Indicator,
        input_node_ids: List[str]
    ) -> None:
        """
        Add an Indicator to the pipeline.

        Args:
            node_id: Unique identifier for this node
            indicator: Indicator instance
            input_node_ids: List of node IDs providing input data

        Raises:
            ValueError: If node_id already exists, dependencies not found,
                or schema mismatch
        """
        if node_id in self.nodes:
            raise ValueError(f"Node '{node_id}' already exists in pipeline")

        # Validate dependencies exist
        for dep_id in input_node_ids:
            if dep_id not in self.nodes:
                raise ValueError(f"Dependency node '{dep_id}' not found in pipeline")

        # TODO: Validate input schema matches available columns from dependencies
        # This requires merging output schemas from all input nodes

        self.nodes[node_id] = PipelineNode(node_id, indicator, dependencies=input_node_ids)
        self._invalidate_execution_order()
        logger.info(
            "Added Indicator node '%s': %s with inputs %s",
            node_id, indicator, input_node_ids
        )

    def remove_node(self, node_id: str) -> None:
        """
        Remove a node from the pipeline.

        Removing an unknown node is a no-op.

        Args:
            node_id: Node to remove

        Raises:
            ValueError: If other nodes depend on this node
        """
        if node_id not in self.nodes:
            return

        # Refuse to leave dangling dependencies behind.
        dependents = [
            n.node_id for n in self.nodes.values()
            if node_id in n.dependencies
        ]
        if dependents:
            raise ValueError(
                f"Cannot remove node '{node_id}': nodes {dependents} depend on it"
            )

        del self.nodes[node_id]
        self._invalidate_execution_order()
        logger.info("Removed node '%s' from pipeline", node_id)

    def _invalidate_execution_order(self) -> None:
        """Mark execution order as needing recomputation."""
        self.execution_order = []

    def _compute_execution_order(self) -> List[str]:
        """
        Compute topological sort of the DAG (Kahn's algorithm).

        The result is cached in ``self.execution_order`` until the graph
        is next modified.

        Returns:
            List of node IDs in execution order

        Raises:
            ValueError: If DAG contains cycles
        """
        if self.execution_order:
            return self.execution_order

        # In-degree of a node is simply how many dependencies it declares.
        in_degree = {
            node_id: len(node.dependencies)
            for node_id, node in self.nodes.items()
        }

        # Reverse adjacency: dependency -> nodes that consume it. Built once
        # so the main loop is O(V + E) instead of rescanning all nodes.
        dependents: Dict[str, List[str]] = defaultdict(list)
        for node in self.nodes.values():
            for dep in node.dependencies:
                dependents[dep].append(node.node_id)

        queue = deque(
            node_id for node_id, degree in in_degree.items() if degree == 0
        )
        result: List[str] = []

        while queue:
            node_id = queue.popleft()
            result.append(node_id)
            for dependent_id in dependents[node_id]:
                in_degree[dependent_id] -= 1
                if in_degree[dependent_id] == 0:
                    queue.append(dependent_id)

        # Any node never reaching in-degree 0 is part of a cycle.
        if len(result) != len(self.nodes):
            raise ValueError("Pipeline contains cycles")

        self.execution_order = result
        logger.debug("Computed execution order: %s", result)
        return result

    def execute(
        self,
        datasource_data: Dict[str, List[Dict[str, Any]]],
        incremental: bool = False,
        updated_from_time: Optional[int] = None
    ) -> Dict[str, List[Dict[str, Any]]]:
        """
        Execute the pipeline.

        Nodes run in topological order; each node's output is cached on the
        node and returned to the caller.

        Args:
            datasource_data: Mapping of DataSource node_id to input data
            incremental: Whether this is an incremental update
            updated_from_time: Timestamp of earliest updated row (for incremental)

        Returns:
            Dictionary mapping node_id to output data (all nodes)

        Raises:
            ValueError: If required datasource data is missing
        """
        execution_order = self._compute_execution_order()
        results: Dict[str, List[Dict[str, Any]]] = {}

        logger.info(
            "Executing pipeline with %d nodes (incremental=%s)",
            len(execution_order), incremental
        )

        for node_id in execution_order:
            node = self.nodes[node_id]

            if node.is_datasource():
                # DataSource node - get data from input
                if node_id not in datasource_data:
                    raise ValueError(
                        f"DataSource node '{node_id}' has no input data"
                    )
                results[node_id] = datasource_data[node_id]
                node.cached_data = results[node_id]
                logger.debug(
                    "DataSource node '%s': %d rows", node_id, len(results[node_id])
                )

            elif node.is_indicator():
                # Indicator node - compute from dependencies
                indicator = node.node

                # Merge input data from all dependencies
                input_data = self._merge_dependency_data(node.dependencies, results)

                context = ComputeContext(
                    data=input_data,
                    is_incremental=incremental,
                    updated_from_time=updated_from_time
                )

                logger.debug(
                    "Computing indicator '%s' with %d input rows",
                    node_id, len(input_data)
                )
                compute_result = indicator.compute(context)

                # Merge result with input data (adding prefixed columns)
                output_data = compute_result.merge_with_prefix(
                    indicator.instance_name,
                    input_data
                )
                results[node_id] = output_data
                node.cached_data = output_data
                logger.debug(
                    "Indicator node '%s': %d rows", node_id, len(output_data)
                )

        logger.info("Pipeline execution complete: %d nodes processed", len(results))
        return results

    def _merge_dependency_data(
        self,
        dependency_ids: List[str],
        results: Dict[str, List[Dict[str, Any]]]
    ) -> List[Dict[str, Any]]:
        """
        Merge data from multiple dependency nodes.

        Data is merged by time, with later dependencies overwriting earlier
        ones for conflicting column names.

        Args:
            dependency_ids: List of node IDs to merge
            results: Current execution results

        Returns:
            Merged data rows, sorted by time
        """
        if not dependency_ids:
            return []

        if len(dependency_ids) == 1:
            # Single input: pass through without copying.
            return results[dependency_ids[0]]

        # Build time-indexed data from first dependency
        merged: Dict[int, Dict[str, Any]] = {}
        for row in results[dependency_ids[0]]:
            merged[row["time"]] = row.copy()

        # Merge in additional dependencies
        for dep_id in dependency_ids[1:]:
            for row in results[dep_id]:
                time_key = row["time"]
                if time_key in merged:
                    # Merge columns (later dependencies win)
                    merged[time_key].update(row)
                else:
                    # New timestamp
                    merged[time_key] = row.copy()

        # Sort by time and return
        return [merged[t] for t in sorted(merged)]

    def get_node_output(self, node_id: str) -> Optional[List[Dict[str, Any]]]:
        """
        Get cached output data for a specific node.

        Args:
            node_id: Node identifier

        Returns:
            Cached data or None if the node does not exist
        """
        node = self.nodes.get(node_id)
        return node.cached_data if node else None

    def get_output_schema(self, node_id: str) -> List[ColumnInfo]:
        """
        Get the output schema for a specific node.

        Args:
            node_id: Node identifier

        Returns:
            List of ColumnInfo describing output columns

        Raises:
            ValueError: If node not found
        """
        node = self.nodes.get(node_id)
        if not node:
            raise ValueError(f"Node '{node_id}' not found")

        if node.is_datasource():
            # Would need to query the actual datasource at runtime
            # For now, return empty - this requires integration with DataSource
            return []
        elif node.is_indicator():
            indicator = node.node
            output_schema = indicator.get_output_schema(**indicator.params)
            # Columns are exposed to consumers under the instance-name prefix.
            prefixed_schema = output_schema.with_prefix(indicator.instance_name)
            return prefixed_schema.columns
        return []

    def validate_pipeline(self) -> Tuple[bool, Optional[str]]:
        """
        Validate the entire pipeline for correctness.

        Checks:
        - No cycles (already checked in execution order)
        - All dependencies exist (already checked in add_indicator)
        - Input schemas match output schemas (TODO)

        Returns:
            Tuple of (is_valid, error_message)
        """
        try:
            self._compute_execution_order()
            return True, None
        except ValueError as e:
            return False, str(e)

    def get_node_count(self) -> int:
        """Get the number of nodes in the pipeline."""
        return len(self.nodes)

    def get_indicator_count(self) -> int:
        """Get the number of indicator nodes in the pipeline."""
        return sum(1 for node in self.nodes.values() if node.is_indicator())

    def get_datasource_count(self) -> int:
        """Get the number of datasource nodes in the pipeline."""
        return sum(1 for node in self.nodes.values() if node.is_datasource())

    def describe(self) -> Dict[str, Any]:
        """
        Get a detailed description of the pipeline structure.

        Never raises: an invalid (cyclic) pipeline is reported via
        ``is_valid: False`` with an empty ``execution_order`` instead of
        propagating the topological-sort error.

        Returns:
            Dictionary with pipeline metadata and structure
        """
        # Validate first so a cyclic graph doesn't make describe() raise.
        # On success this also populates self.execution_order.
        is_valid, _ = self.validate_pipeline()
        return {
            "node_count": self.get_node_count(),
            "datasource_count": self.get_datasource_count(),
            "indicator_count": self.get_indicator_count(),
            "nodes": [
                {
                    "id": node.node_id,
                    "type": "datasource" if node.is_datasource() else "indicator",
                    "node": str(node.node),
                    "dependencies": node.dependencies,
                    "cached_rows": len(node.cached_data)
                }
                for node in self.nodes.values()
            ],
            "execution_order": list(self.execution_order) if is_valid else [],
            "is_valid": is_valid
        }