"""
Pipeline execution engine for composable indicators.

Manages DAG construction, dependency resolution, incremental updates,
and efficient data flow through indicator chains.
"""

import logging
from collections import defaultdict, deque
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from datasource.base import DataSource
from datasource.schema import ColumnInfo

from .base import DataSourceAdapter, Indicator
from .schema import ComputeContext, ComputeResult

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)


class PipelineNode:
    """
    A single vertex in the pipeline DAG.

    Wraps either a DataSourceAdapter (a source of rows) or an Indicator
    (a computation over upstream rows).
    """

    def __init__(
        self,
        node_id: str,
        node: Union[DataSourceAdapter, Indicator],
        dependencies: List[str]
    ):
        """
        Create a pipeline node.

        Args:
            node_id: Unique identifier for this node
            node: The DataSourceAdapter or Indicator instance
            dependencies: List of node_ids this node depends on
        """
        self.node_id = node_id
        self.node = node
        self.dependencies = dependencies
        # Column names this node produces; filled in during execution.
        self.output_columns: List[str] = []
        # Most recent output rows, retained for later inspection/reuse.
        self.cached_data: List[Dict[str, Any]] = []

    def is_datasource(self) -> bool:
        """Return True when the wrapped node is a DataSource adapter."""
        return isinstance(self.node, DataSourceAdapter)

    def is_indicator(self) -> bool:
        """Return True when the wrapped node is an Indicator."""
        return isinstance(self.node, Indicator)

    def __repr__(self) -> str:
        return f"PipelineNode(id='{self.node_id}', node={self.node}, deps={self.dependencies})"


class Pipeline:
    """
    Execution engine for indicator DAGs.

    Manages:
    - DAG construction and validation
    - Topological sorting for execution order
    - Data flow and caching
    - Incremental updates (only recompute what changed)
    - Schema validation
    """

    def __init__(self, datasource_registry):
        """
        Initialize a pipeline.

        Args:
            datasource_registry: DataSourceRegistry for resolving data sources
        """
        self.datasource_registry = datasource_registry
        self.nodes: Dict[str, PipelineNode] = {}
        # Cached topological order; cleared whenever the graph changes.
        self.execution_order: List[str] = []
        # Nodes whose cached output is stale (reserved for incremental updates).
        self._dirty_nodes: Set[str] = set()

    def add_datasource(
        self,
        node_id: str,
        datasource_name: str,
        symbol: str,
        resolution: str
    ) -> None:
        """
        Add a DataSource to the pipeline.

        Args:
            node_id: Unique identifier for this node
            datasource_name: Name of the datasource in the registry
            symbol: Symbol to query
            resolution: Time resolution

        Raises:
            ValueError: If node_id already exists or datasource not found
        """
        if node_id in self.nodes:
            raise ValueError(f"Node '{node_id}' already exists in pipeline")

        # Fail fast if the named datasource is not registered.
        datasource = self.datasource_registry.get(datasource_name)
        if not datasource:
            raise ValueError(f"DataSource '{datasource_name}' not found in registry")

        adapter = DataSourceAdapter(datasource_name, symbol, resolution)
        node = PipelineNode(node_id, adapter, dependencies=[])

        self.nodes[node_id] = node
        self._invalidate_execution_order()

        logger.info(f"Added DataSource node '{node_id}': {datasource_name}/{symbol}@{resolution}")

    def add_indicator(
        self,
        node_id: str,
        indicator: Indicator,
        input_node_ids: List[str]
    ) -> None:
        """
        Add an Indicator to the pipeline.

        Args:
            node_id: Unique identifier for this node
            indicator: Indicator instance
            input_node_ids: List of node IDs providing input data

        Raises:
            ValueError: If node_id already exists, dependencies not found, or schema mismatch
        """
        if node_id in self.nodes:
            raise ValueError(f"Node '{node_id}' already exists in pipeline")

        # Validate dependencies exist
        for dep_id in input_node_ids:
            if dep_id not in self.nodes:
                raise ValueError(f"Dependency node '{dep_id}' not found in pipeline")

        # TODO: Validate input schema matches available columns from dependencies
        # This requires merging output schemas from all input nodes

        node = PipelineNode(node_id, indicator, dependencies=input_node_ids)
        self.nodes[node_id] = node
        self._invalidate_execution_order()

        logger.info(f"Added Indicator node '{node_id}': {indicator} with inputs {input_node_ids}")

    def remove_node(self, node_id: str) -> None:
        """
        Remove a node from the pipeline.

        Removing an unknown node is a no-op.

        Args:
            node_id: Node to remove

        Raises:
            ValueError: If other nodes depend on this node
        """
        if node_id not in self.nodes:
            return

        # Refuse removal while any node still consumes this one.
        dependents = [
            n.node_id for n in self.nodes.values()
            if node_id in n.dependencies
        ]

        if dependents:
            raise ValueError(
                f"Cannot remove node '{node_id}': nodes {dependents} depend on it"
            )

        del self.nodes[node_id]
        self._invalidate_execution_order()

        logger.info(f"Removed node '{node_id}' from pipeline")

    def _invalidate_execution_order(self) -> None:
        """Mark execution order as needing recomputation."""
        self.execution_order = []

    def _compute_execution_order(self) -> List[str]:
        """
        Compute a topological sort of the DAG via Kahn's algorithm.

        The result is cached until the graph changes. Duplicate entries in a
        node's dependency list are deduplicated first: the previous version
        counted each occurrence toward the in-degree but only decremented
        once per dependent, so a repeated dependency was falsely reported as
        a cycle. A reverse-adjacency map is built once so each popped node
        reaches its dependents in O(out-degree) instead of rescanning all
        nodes (previously O(V*E)).

        Returns:
            List of node IDs in execution order

        Raises:
            ValueError: If DAG contains cycles
        """
        if self.execution_order:
            return self.execution_order

        # Deduplicate dependencies, preserving first-seen order so the
        # resulting execution order stays deterministic.
        unique_deps = {
            node_id: list(dict.fromkeys(node.dependencies))
            for node_id, node in self.nodes.items()
        }

        # in_degree: number of distinct dependencies per node.
        # dependents: reverse adjacency (dependency -> consuming node ids).
        in_degree: Dict[str, int] = {}
        dependents: Dict[str, List[str]] = defaultdict(list)
        for node_id, deps in unique_deps.items():
            in_degree[node_id] = len(deps)
            for dep in deps:
                dependents[dep].append(node_id)

        queue = deque(node_id for node_id, degree in in_degree.items() if degree == 0)
        result: List[str] = []

        while queue:
            node_id = queue.popleft()
            result.append(node_id)

            # Release each consumer; enqueue it once all its inputs are done.
            for consumer in dependents[node_id]:
                in_degree[consumer] -= 1
                if in_degree[consumer] == 0:
                    queue.append(consumer)

        # Any node never reaching in-degree zero implies a cycle.
        if len(result) != len(self.nodes):
            raise ValueError("Pipeline contains cycles")

        self.execution_order = result
        logger.debug(f"Computed execution order: {result}")
        return result

    def execute(
        self,
        datasource_data: Dict[str, List[Dict[str, Any]]],
        incremental: bool = False,
        updated_from_time: Optional[int] = None
    ) -> Dict[str, List[Dict[str, Any]]]:
        """
        Execute the pipeline.

        Nodes run in topological order: DataSource nodes pass through their
        rows from ``datasource_data``; Indicator nodes compute over the merged
        output of their dependencies. Every node's output is cached on the
        node for later retrieval via get_node_output().

        Args:
            datasource_data: Mapping of DataSource node_id to input data
            incremental: Whether this is an incremental update
            updated_from_time: Timestamp of earliest updated row (for incremental)

        Returns:
            Dictionary mapping node_id to output data (all nodes)

        Raises:
            ValueError: If required datasource data is missing
        """
        execution_order = self._compute_execution_order()
        results: Dict[str, List[Dict[str, Any]]] = {}

        logger.info(
            f"Executing pipeline with {len(execution_order)} nodes "
            f"(incremental={incremental})"
        )

        for node_id in execution_order:
            node = self.nodes[node_id]

            if node.is_datasource():
                results[node_id] = self._execute_datasource_node(node, datasource_data)
            elif node.is_indicator():
                results[node_id] = self._execute_indicator_node(
                    node, results, incremental, updated_from_time
                )

        logger.info(f"Pipeline execution complete: {len(results)} nodes processed")
        return results

    def _execute_datasource_node(
        self,
        node: PipelineNode,
        datasource_data: Dict[str, List[Dict[str, Any]]]
    ) -> List[Dict[str, Any]]:
        """Pass through externally supplied rows for a DataSource node."""
        node_id = node.node_id
        if node_id not in datasource_data:
            raise ValueError(
                f"DataSource node '{node_id}' has no input data"
            )
        data = datasource_data[node_id]
        node.cached_data = data
        logger.debug(f"DataSource node '{node_id}': {len(data)} rows")
        return data

    def _execute_indicator_node(
        self,
        node: PipelineNode,
        results: Dict[str, List[Dict[str, Any]]],
        incremental: bool,
        updated_from_time: Optional[int]
    ) -> List[Dict[str, Any]]:
        """Compute an Indicator node from its dependencies' merged output."""
        node_id = node.node_id
        indicator = node.node

        # Merge input data from all dependencies into one time-ordered stream.
        input_data = self._merge_dependency_data(node.dependencies, results)

        context = ComputeContext(
            data=input_data,
            is_incremental=incremental,
            updated_from_time=updated_from_time
        )

        logger.debug(
            f"Computing indicator '{node_id}' with {len(input_data)} input rows"
        )
        compute_result = indicator.compute(context)

        # Prefix the indicator's output columns with its instance name and
        # merge them back into the input rows.
        output_data = compute_result.merge_with_prefix(
            indicator.instance_name,
            input_data
        )

        node.cached_data = output_data
        logger.debug(f"Indicator node '{node_id}': {len(output_data)} rows")
        return output_data

    def _merge_dependency_data(
        self,
        dependency_ids: List[str],
        results: Dict[str, List[Dict[str, Any]]]
    ) -> List[Dict[str, Any]]:
        """
        Merge data from multiple dependency nodes.

        Data is merged by the "time" key of each row, with later dependencies
        overwriting earlier ones for conflicting column names. With a single
        dependency the rows are returned as-is (no copy).

        Args:
            dependency_ids: List of node IDs to merge
            results: Current execution results

        Returns:
            Merged data rows, sorted by time
        """
        if not dependency_ids:
            return []

        if len(dependency_ids) == 1:
            return results[dependency_ids[0]]

        # Build time-indexed data from first dependency
        merged: Dict[int, Dict[str, Any]] = {}
        for row in results[dependency_ids[0]]:
            merged[row["time"]] = row.copy()

        # Merge in additional dependencies
        for dep_id in dependency_ids[1:]:
            for row in results[dep_id]:
                time_key = row["time"]
                if time_key in merged:
                    # Merge columns (later dependencies win)
                    merged[time_key].update(row)
                else:
                    # New timestamp
                    merged[time_key] = row.copy()

        # Sort by time and return
        sorted_times = sorted(merged.keys())
        return [merged[t] for t in sorted_times]

    def get_node_output(self, node_id: str) -> Optional[List[Dict[str, Any]]]:
        """
        Get cached output data for a specific node.

        Args:
            node_id: Node identifier

        Returns:
            Cached rows (an empty list if the node has not been executed yet),
            or None if no node with this id exists
        """
        node = self.nodes.get(node_id)
        return node.cached_data if node else None

    def get_output_schema(self, node_id: str) -> List[ColumnInfo]:
        """
        Get the output schema for a specific node.

        Args:
            node_id: Node identifier

        Returns:
            List of ColumnInfo describing output columns

        Raises:
            ValueError: If node not found
        """
        node = self.nodes.get(node_id)
        if not node:
            raise ValueError(f"Node '{node_id}' not found")

        if node.is_datasource():
            # Would need to query the actual datasource at runtime
            # For now, return empty - this requires integration with DataSource
            return []

        elif node.is_indicator():
            indicator = node.node
            output_schema = indicator.get_output_schema(**indicator.params)
            # Columns are exposed under the indicator's instance-name prefix,
            # matching how execute() merges results.
            prefixed_schema = output_schema.with_prefix(indicator.instance_name)
            return prefixed_schema.columns

        return []

    def validate_pipeline(self) -> Tuple[bool, Optional[str]]:
        """
        Validate the entire pipeline for correctness.

        Checks:
        - No cycles (already checked in execution order)
        - All dependencies exist (already checked in add_indicator)
        - Input schemas match output schemas (TODO)

        Returns:
            Tuple of (is_valid, error_message)
        """
        try:
            self._compute_execution_order()
            return True, None
        except ValueError as e:
            return False, str(e)

    def get_node_count(self) -> int:
        """Get the number of nodes in the pipeline."""
        return len(self.nodes)

    def get_indicator_count(self) -> int:
        """Get the number of indicator nodes in the pipeline."""
        return sum(1 for node in self.nodes.values() if node.is_indicator())

    def get_datasource_count(self) -> int:
        """Get the number of datasource nodes in the pipeline."""
        return sum(1 for node in self.nodes.values() if node.is_datasource())

    def describe(self) -> Dict[str, Any]:
        """
        Get a detailed description of the pipeline structure.

        Returns:
            Dictionary with pipeline metadata and structure
        """
        return {
            "node_count": self.get_node_count(),
            "datasource_count": self.get_datasource_count(),
            "indicator_count": self.get_indicator_count(),
            "nodes": [
                {
                    "id": node.node_id,
                    "type": "datasource" if node.is_datasource() else "indicator",
                    "node": str(node.node),
                    "dependencies": node.dependencies,
                    "cached_rows": len(node.cached_data)
                }
                for node in self.nodes.values()
            ],
            "execution_order": self.execution_order or self._compute_execution_order(),
            "is_valid": self.validate_pipeline()[0]
        }