"""
DataSource registry for managing multiple data sources.
"""

import logging
from typing import Dict, List, Optional

from .base import DataSource
from .schema import SearchResult, SymbolInfo


class DataSourceRegistry:
    """
    Central registry for managing multiple DataSource instances.

    Allows routing symbol queries to the appropriate data source and
    aggregating search results across multiple sources.
    """

    def __init__(self):
        # Maps each registered name to its DataSource implementation.
        self._sources: Dict[str, DataSource] = {}

    def register(self, name: str, source: DataSource) -> None:
        """
        Register a data source.

        Registering under a name that is already in use replaces the
        previously registered source.

        Args:
            name: Unique name for this data source
            source: DataSource implementation
        """
        self._sources[name] = source

    def unregister(self, name: str) -> None:
        """
        Unregister a data source.

        Unknown names are ignored, so this is safe to call repeatedly.

        Args:
            name: Name of the data source to remove
        """
        self._sources.pop(name, None)

    def get(self, name: str) -> Optional[DataSource]:
        """
        Get a registered data source by name.

        Args:
            name: Data source name

        Returns:
            DataSource instance or None if not found
        """
        return self._sources.get(name)

    def list_sources(self) -> List[str]:
        """
        Get names of all registered data sources.

        Returns:
            List of data source names, in registration order
        """
        return list(self._sources)

    async def search_all(
        self,
        query: str,
        type: Optional[str] = None,
        exchange: Optional[str] = None,
        limit: int = 30,
    ) -> Dict[str, List[SearchResult]]:
        """
        Search across all registered data sources.

        The search is best-effort: a source that raises is logged and
        skipped, and sources that return no results are omitted.

        Args:
            query: Search query
            type: Optional instrument type filter
            exchange: Optional exchange filter
            limit: Maximum results per source

        Returns:
            Dict mapping source name to search results
        """
        results: Dict[str, List[SearchResult]] = {}
        # Iterate over a snapshot: each await yields to the event loop, and a
        # concurrent register()/unregister() would otherwise raise
        # "dictionary changed size during iteration" and abort the search.
        for name, source in list(self._sources.items()):
            try:
                source_results = await source.search_symbols(
                    query, type, exchange, limit
                )
                if source_results:
                    results[name] = source_results
            except Exception:
                # One failing source must not break the aggregate search, but
                # the failure is recorded instead of disappearing silently.
                logging.getLogger(__name__).warning(
                    "Search failed for data source %r; skipping",
                    name,
                    exc_info=True,
                )
                continue
        return results

    async def resolve_symbol(self, source_name: str, symbol: str) -> SymbolInfo:
        """
        Resolve a symbol from a specific data source.

        Args:
            source_name: Name of the data source
            symbol: Symbol identifier

        Returns:
            SymbolInfo from the specified source

        Raises:
            ValueError: If no data source is registered under source_name.
                Errors raised by the source's own resolve_symbol propagate
                unchanged.
        """
        source = self.get(source_name)
        # Compare against None explicitly so that a registered source which
        # happens to be falsy (defines __len__/__bool__) is still used.
        if source is None:
            raise ValueError(f"Data source '{source_name}' not found")
        return await source.resolve_symbol(symbol)