Files
ai/backend.old/src/datasource/subscription_manager.py
2026-03-11 18:47:11 -04:00

236 lines
8.0 KiB
Python

"""
Subscription manager for real-time data feeds.
Manages subscriptions across multiple data sources and routes updates
to WebSocket clients.
"""
import asyncio
import logging
from typing import Callable, Dict, Optional, Set
from .base import DataSource
# Module-level logger (named after this module) used by SubscriptionManager
# to report subscription lifecycle events and callback errors.
logger = logging.getLogger(__name__)
class Subscription:
    """A single client's subscription to one (source, symbol, resolution) bar feed.

    Attributes:
        subscription_id: Caller-supplied unique ID for this subscription.
        client_id: ID of the client that owns the subscription.
        source_name: Name of the data source the bars come from.
        symbol: Symbol being subscribed to.
        resolution: Time resolution of the bars.
        callback: Called with each bar dict routed to this subscription.
        source_subscription_id: ID returned by the data source for the
            underlying feed. Starts as None; SubscriptionManager may set it.
    """

    def __init__(
        self,
        subscription_id: str,
        client_id: str,
        source_name: str,
        symbol: str,
        resolution: str,
        callback: Callable[[dict], None],
    ):
        # Who owns this subscription.
        self.subscription_id = subscription_id
        self.client_id = client_id
        # Which feed it targets.
        self.source_name, self.symbol, self.resolution = source_name, symbol, resolution
        # Where updates are delivered.
        self.callback = callback
        # Not known at construction time; assigned later by the manager.
        self.source_subscription_id: Optional[str] = None
class SubscriptionManager:
    """
    Manages real-time data subscriptions across multiple data sources.

    Handles:
    - Subscription lifecycle (subscribe/unsubscribe)
    - Routing updates from data sources to clients
    - Multiplexing (multiple clients can subscribe to same symbol/resolution)

    One source subscription is opened per (source_name, symbol, resolution)
    key and shared by every client subscription on that key. It is released
    when the last client subscription on that key is removed.
    """

    def __init__(self):
        # Map subscription_id -> Subscription
        self._subscriptions: Dict[str, Subscription] = {}
        # Map (source_name, symbol, resolution) -> Set[subscription_id]
        # For tracking which client subscriptions use which source subscriptions
        self._source_refs: Dict[tuple, Set[str]] = {}
        # Map source_subscription_id -> (source_name, symbol, resolution)
        self._source_subs: Dict[str, tuple] = {}
        # Map (source_name, symbol, resolution) -> source_subscription_id.
        # Reverse of _source_subs: lets every client subscription on a key find
        # the shared source subscription, not only the one that created it.
        self._source_sub_by_key: Dict[tuple, str] = {}
        # Available data sources
        self._sources: Dict[str, DataSource] = {}

    def register_source(self, name: str, source: DataSource) -> None:
        """Register a data source under ``name``, replacing any existing one."""
        self._sources[name] = source

    def unregister_source(self, name: str) -> None:
        """
        Unregister a data source (no-op if ``name`` is unknown).

        Existing subscriptions on the source are not closed here. Once the
        source is gone, unsubscribe() can no longer release them at the source.
        """
        self._sources.pop(name, None)

    async def subscribe(
        self,
        subscription_id: str,
        client_id: str,
        source_name: str,
        symbol: str,
        resolution: str,
        callback: Callable[[dict], None],
    ) -> None:
        """
        Subscribe a client to real-time updates.

        If ``subscription_id`` is already in use, the existing subscription is
        removed first. The ID then refers only to the new subscription.

        Args:
            subscription_id: Unique ID for this subscription
            client_id: ID of the subscribing client
            source_name: Name of the data source
            symbol: Symbol to subscribe to
            resolution: Time resolution
            callback: Function to call with bar updates

        Raises:
            ValueError: If the data source is not registered.
            Exception: Any error raised by ``source.subscribe_bars`` is logged
                and re-raised unchanged.
        """
        source = self._sources.get(source_name)
        if source is None:
            raise ValueError(f"Data source '{source_name}' not found")

        if subscription_id in self._subscriptions:
            # Overwriting the record in place would leave the old ID in its
            # key's ref set and leak that source subscription.
            logger.warning("Subscription %s already exists; replacing it", subscription_id)
            await self.unsubscribe(subscription_id)

        # Create subscription record
        subscription = Subscription(
            subscription_id=subscription_id,
            client_id=client_id,
            source_name=source_name,
            symbol=symbol,
            resolution=resolution,
            callback=callback,
        )

        # Check if we already have a source subscription for this (source, symbol, resolution)
        source_key = (source_name, symbol, resolution)
        duplicate_source_sub_id: Optional[str] = None

        if source_key not in self._source_refs:
            try:
                new_source_sub_id = await source.subscribe_bars(
                    symbol=symbol,
                    resolution=resolution,
                    on_tick=lambda bar: self._on_source_update(source_key, bar),
                )
            except Exception as e:
                logger.error("Failed to subscribe to source: %s", e)
                raise

            if source_key in self._source_refs:
                # Another subscribe() for the same key finished while we were
                # awaiting. Share its source subscription and release ours below.
                duplicate_source_sub_id = new_source_sub_id
            else:
                self._source_refs[source_key] = set()
                self._source_subs[new_source_sub_id] = source_key
                self._source_sub_by_key[source_key] = new_source_sub_id
                logger.info(
                    "Created new source subscription: %s/%s/%s -> %s",
                    source_name, symbol, resolution, new_source_sub_id,
                )

        # Register synchronously (no await in between), so the key cannot be
        # torn down before this client is counted in its ref set. Every client
        # subscription records the shared ID, so whichever client leaves last
        # can release it.
        subscription.source_subscription_id = self._source_sub_by_key.get(source_key)
        self._source_refs[source_key].add(subscription_id)
        self._subscriptions[subscription_id] = subscription
        logger.info(
            "Client subscription added: %s (%s) -> %s/%s/%s",
            subscription_id, client_id, source_name, symbol, resolution,
        )

        if duplicate_source_sub_id is not None:
            await self._release_source_subscription(source, duplicate_source_sub_id)

    async def unsubscribe(self, subscription_id: str) -> None:
        """
        Unsubscribe a client from updates.

        If this was the last client subscription on its
        (source, symbol, resolution) key, the shared source subscription is
        released as well. Errors from the data source are logged, not raised.

        Args:
            subscription_id: ID of the subscription to remove. Unknown IDs are
                logged and ignored.
        """
        subscription = self._subscriptions.pop(subscription_id, None)
        if subscription is None:
            logger.warning("Subscription %s not found", subscription_id)
            return

        source_key = (subscription.source_name, subscription.symbol, subscription.resolution)
        refs = self._source_refs.get(source_key)
        if refs is not None:
            refs.discard(subscription_id)
            if not refs:
                # Last client on this key. Drop all bookkeeping *before*
                # awaiting, so a concurrent subscribe() for the same key opens
                # a fresh source subscription instead of joining this one.
                del self._source_refs[source_key]
                source_sub_id = self._source_sub_by_key.pop(
                    source_key, subscription.source_subscription_id
                )
                if source_sub_id is not None:
                    self._source_subs.pop(source_sub_id, None)
                    source = self._sources.get(subscription.source_name)
                    if source is None:
                        logger.warning(
                            "Data source '%s' not registered; cannot unsubscribe %s",
                            subscription.source_name, source_sub_id,
                        )
                    else:
                        await self._release_source_subscription(source, source_sub_id)

        logger.info("Client subscription removed: %s", subscription_id)

    async def _release_source_subscription(self, source: DataSource, source_sub_id: str) -> None:
        """Best-effort ``unsubscribe_bars`` on ``source``; errors are logged, not raised."""
        try:
            await source.unsubscribe_bars(source_sub_id)
        except Exception as e:
            logger.error("Error unsubscribing from source: %s", e)
        else:
            logger.info("Unsubscribed from source: %s", source_sub_id)

    async def unsubscribe_client(self, client_id: str) -> None:
        """
        Unsubscribe all subscriptions for a client.

        Useful when a WebSocket connection closes.

        Args:
            client_id: ID of the client
        """
        # Collect IDs up front: unsubscribe() mutates self._subscriptions.
        client_sub_ids = [
            sub_id
            for sub_id, sub in self._subscriptions.items()
            if sub.client_id == client_id
        ]
        for sub_id in client_sub_ids:
            await self.unsubscribe(sub_id)
        logger.info("Unsubscribed all subscriptions for client %s", client_id)

    def _on_source_update(self, source_key: tuple, bar: dict) -> None:
        """
        Handle an update from a data source.

        Routes the update to all client subscriptions that need it. An
        exception in one client's callback is logged and does not stop
        delivery to the others.

        Args:
            source_key: (source_name, symbol, resolution) tuple
            bar: Bar data dict from the source
        """
        # Iterate over a snapshot of the ref set. This is a cheap guard in
        # case the data source invokes this from another thread while the
        # set changes (unverified; see the DataSource implementations).
        for sub_id in tuple(self._source_refs.get(source_key, ())):
            subscription = self._subscriptions.get(sub_id)
            if subscription is None:
                continue
            try:
                subscription.callback(bar)
            except Exception as e:
                logger.error(
                    "Error in subscription callback %s: %s", sub_id, e, exc_info=True
                )

    def get_subscription_count(self) -> int:
        """Get total number of active client subscriptions"""
        return len(self._subscriptions)

    def get_source_subscription_count(self) -> int:
        """Get total number of active source subscriptions"""
        return len(self._source_refs)

    def get_client_subscriptions(self, client_id: str) -> list:
        """Get all subscriptions for a specific client as plain dicts."""
        return [
            {
                "subscription_id": sub.subscription_id,
                "source": sub.source_name,
                "symbol": sub.symbol,
                "resolution": sub.resolution,
            }
            for sub in self._subscriptions.values()
            if sub.client_id == client_id
        ]