"""
Subscription manager for real-time data feeds.

Manages subscriptions across multiple data sources and routes updates to
WebSocket clients.
"""

import asyncio
import logging
from typing import Callable, Dict, Optional, Set

from .base import DataSource

logger = logging.getLogger(__name__)


class Subscription:
    """Represents a single client subscription.

    Several ``Subscription`` objects can share one upstream source
    subscription when they target the same (source, symbol, resolution).
    """

    def __init__(
        self,
        subscription_id: str,
        client_id: str,
        source_name: str,
        symbol: str,
        resolution: str,
        callback: Callable[[dict], None],
    ):
        self.subscription_id = subscription_id
        self.client_id = client_id
        self.source_name = source_name
        self.symbol = symbol
        self.resolution = resolution
        # Called synchronously with every bar routed to this subscription.
        self.callback = callback
        # ID returned by the source's subscribe_bars() for the shared upstream
        # subscription this client subscription is attached to.
        self.source_subscription_id: Optional[str] = None

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(subscription_id={self.subscription_id!r}, "
            f"client_id={self.client_id!r}, source_name={self.source_name!r}, "
            f"symbol={self.symbol!r}, resolution={self.resolution!r})"
        )


class SubscriptionManager:
    """
    Manages real-time data subscriptions across multiple data sources.

    Handles:
    - Subscription lifecycle (subscribe/unsubscribe)
    - Routing updates from data sources to clients
    - Multiplexing (multiple clients can subscribe to same symbol/resolution)
    """

    def __init__(self):
        # Map subscription_id -> Subscription
        self._subscriptions: Dict[str, Subscription] = {}

        # Map (source_name, symbol, resolution) -> Set[subscription_id]
        # For tracking which client subscriptions use which source subscriptions
        self._source_refs: Dict[tuple, Set[str]] = {}

        # Map source_subscription_id -> (source_name, symbol, resolution)
        self._source_subs: Dict[str, tuple] = {}

        # Map (source_name, symbol, resolution) -> source_subscription_id.
        # Lets every client subscription on a key (not only the first one)
        # find the shared upstream ID, and lets teardown use it regardless of
        # which client leaves last.
        self._source_sub_ids: Dict[tuple, str] = {}

        # Per-key locks serializing subscribe/unsubscribe for one
        # (source_name, symbol, resolution). Without them, two concurrent
        # subscribe() calls could both pass the "no source subscription yet"
        # check while awaiting subscribe_bars() and create duplicates.
        # Locks are created lazily inside coroutines (so they bind to the
        # running loop) and are intentionally never removed: dropping a lock
        # that still has waiters would break mutual exclusion.
        self._key_locks: Dict[tuple, asyncio.Lock] = {}

        # Available data sources
        self._sources: Dict[str, DataSource] = {}

    def register_source(self, name: str, source: DataSource) -> None:
        """Register a data source"""
        self._sources[name] = source

    def unregister_source(self, name: str) -> None:
        """Unregister a data source"""
        self._sources.pop(name, None)

    def _lock_for(self, source_key: tuple) -> asyncio.Lock:
        """Return the lock guarding *source_key*, creating it on first use."""
        lock = self._key_locks.get(source_key)
        if lock is None:
            lock = self._key_locks[source_key] = asyncio.Lock()
        return lock

    async def subscribe(
        self,
        subscription_id: str,
        client_id: str,
        source_name: str,
        symbol: str,
        resolution: str,
        callback: Callable[[dict], None],
    ) -> None:
        """
        Subscribe a client to real-time updates.

        Client subscriptions for the same (source, symbol, resolution) share
        one upstream subscription; only the first one calls
        ``subscribe_bars`` on the data source.

        If ``subscription_id`` is already active, the existing subscription is
        removed first (releasing its upstream reference) and then replaced.

        Args:
            subscription_id: Unique ID for this subscription
            client_id: ID of the subscribing client
            source_name: Name of the data source
            symbol: Symbol to subscribe to
            resolution: Time resolution
            callback: Function to call with bar updates

        Raises:
            ValueError: If the data source is not registered.
            Exception: Any error raised by the source's ``subscribe_bars`` is
                logged and re-raised; nothing is recorded in that case.
        """
        source = self._sources.get(source_name)
        if source is None:
            raise ValueError(f"Data source '{source_name}' not found")

        # Re-using an active ID would otherwise leave the old ID in the previous
        # key's ref set: bars for the old symbol would be routed to the new
        # callback and the old upstream subscription could never be released.
        if subscription_id in self._subscriptions:
            logger.warning(
                f"Subscription {subscription_id} already exists; replacing it"
            )
            await self.unsubscribe(subscription_id)

        # Create subscription record
        subscription = Subscription(
            subscription_id=subscription_id,
            client_id=client_id,
            source_name=source_name,
            symbol=symbol,
            resolution=resolution,
            callback=callback,
        )

        source_key = (source_name, symbol, resolution)
        async with self._lock_for(source_key):
            if source_key not in self._source_refs:
                # Need to create a new source subscription
                try:
                    source_sub_id = await source.subscribe_bars(
                        symbol=symbol,
                        resolution=resolution,
                        on_tick=lambda bar: self._on_source_update(source_key, bar),
                    )
                except Exception as e:
                    logger.error(f"Failed to subscribe to source: {e}")
                    raise

                self._source_subs[source_sub_id] = source_key
                self._source_sub_ids[source_key] = source_sub_id
                self._source_refs[source_key] = set()
                logger.info(
                    f"Created new source subscription: {source_name}/{symbol}/{resolution} -> {source_sub_id}"
                )

            # Every client subscription records the shared upstream ID, not just
            # the one that created it.
            subscription.source_subscription_id = self._source_sub_ids.get(source_key)

            # Add this subscription to the reference set
            self._source_refs[source_key].add(subscription_id)
            self._subscriptions[subscription_id] = subscription

        logger.info(
            f"Client subscription added: {subscription_id} ({client_id}) -> {source_name}/{symbol}/{resolution}"
        )

    async def unsubscribe(self, subscription_id: str) -> None:
        """
        Unsubscribe a client from updates.

        When the last client subscription for a (source, symbol, resolution)
        is removed, the shared upstream subscription is released as well
        (best-effort: errors from the source are logged, not raised).

        Args:
            subscription_id: ID of the subscription to remove
        """
        subscription = self._subscriptions.pop(subscription_id, None)
        if not subscription:
            logger.warning(f"Subscription {subscription_id} not found")
            return

        source_key = (subscription.source_name, subscription.symbol, subscription.resolution)

        async with self._lock_for(source_key):
            # Remove from reference set
            refs = self._source_refs.get(source_key)
            if refs is not None:
                refs.discard(subscription_id)

                # If no more clients need this source subscription, unsubscribe from source
                if not refs:
                    del self._source_refs[source_key]
                    await self._release_source_subscription(
                        subscription.source_name, source_key
                    )

        logger.info(f"Client subscription removed: {subscription_id}")

    async def _release_source_subscription(self, source_name: str, source_key: tuple) -> None:
        """Tear down the upstream subscription for *source_key* (best-effort)."""
        # Looked up by key rather than from the departing Subscription, so the
        # release happens no matter which client happened to leave last.
        source_sub_id = self._source_sub_ids.pop(source_key, None)
        if not source_sub_id:
            return

        source = self._sources.get(source_name)
        if source:
            try:
                await source.unsubscribe_bars(source_sub_id)
                logger.info(f"Unsubscribed from source: {source_sub_id}")
            except Exception as e:
                logger.error(f"Error unsubscribing from source: {e}")

        self._source_subs.pop(source_sub_id, None)

    async def unsubscribe_client(self, client_id: str) -> None:
        """
        Unsubscribe all subscriptions for a client.

        Useful when a WebSocket connection closes.

        Args:
            client_id: ID of the client
        """
        # Collect IDs first: unsubscribe() mutates self._subscriptions.
        client_subs = [
            sub_id
            for sub_id, sub in self._subscriptions.items()
            if sub.client_id == client_id
        ]

        # Unsubscribe each one
        for sub_id in client_subs:
            await self.unsubscribe(sub_id)

        logger.info(f"Unsubscribed all subscriptions for client {client_id}")

    def _on_source_update(self, source_key: tuple, bar: dict) -> None:
        """
        Handle an update from a data source.

        Routes the update to all client subscriptions that need it. An
        exception from one callback is logged and does not stop delivery to
        the others.

        Args:
            source_key: (source_name, symbol, resolution) tuple
            bar: Bar data dict from the source
        """
        # Iterate over a snapshot so a change to the live ref set during
        # delivery cannot raise "Set changed size during iteration".
        subscription_ids = tuple(self._source_refs.get(source_key, ()))

        for sub_id in subscription_ids:
            subscription = self._subscriptions.get(sub_id)
            if subscription:
                try:
                    subscription.callback(bar)
                except Exception as e:
                    logger.error(
                        f"Error in subscription callback {sub_id}: {e}", exc_info=True
                    )

    def get_subscription_count(self) -> int:
        """Get total number of active client subscriptions"""
        return len(self._subscriptions)

    def get_source_subscription_count(self) -> int:
        """Get total number of active source subscriptions"""
        return len(self._source_refs)

    def get_client_subscriptions(self, client_id: str) -> list:
        """Get all subscriptions for a specific client"""
        return [
            {
                "subscription_id": sub.subscription_id,
                "source": sub.source_name,
                "symbol": sub.symbol,
                "resolution": sub.resolution,
            }
            for sub in self._subscriptions.values()
            if sub.client_id == client_id
        ]