"""
WebSocket handler for TradingView-compatible datafeed API.

Handles incoming requests for symbol search, metadata, historical data,
and real-time subscriptions.
"""

import asyncio
import json
import logging
from typing import Any, Coroutine, Dict, Optional, Set, Tuple

from fastapi import WebSocket, WebSocketDisconnect

from .base import DataSource
from .registry import DataSourceRegistry
from .subscription_manager import SubscriptionManager
from .websocket_protocol import (
    BarUpdateMessage,
    ErrorResponse,
    GetBarsRequest,
    GetBarsResponse,
    GetConfigRequest,
    GetConfigResponse,
    ResolveSymbolRequest,
    ResolveSymbolResponse,
    SearchSymbolsRequest,
    SearchSymbolsResponse,
    SubscribeBarsRequest,
    SubscribeBarsResponse,
    UnsubscribeBarsRequest,
    UnsubscribeBarsResponse,
)

logger = logging.getLogger(__name__)


class DatafeedWebSocketHandler:
    """
    Handles WebSocket connections for TradingView-compatible datafeed API.

    Each handler manages a single WebSocket connection and routes requests
    to the appropriate data sources via the registry.
    """

    def __init__(
        self,
        websocket: WebSocket,
        client_id: str,
        registry: DataSourceRegistry,
        subscription_manager: SubscriptionManager,
        default_source: Optional[str] = None,
    ):
        """
        Initialize handler.

        Args:
            websocket: FastAPI WebSocket connection
            client_id: Unique identifier for this client
            registry: DataSource registry for accessing data sources
            subscription_manager: Shared subscription manager
            default_source: Default data source name if not specified in requests
        """
        self.websocket = websocket
        self.client_id = client_id
        self.registry = registry
        self.subscription_manager = subscription_manager
        self.default_source = default_source
        self._connected = True
        # Strong references to in-flight bar-update tasks. asyncio only keeps
        # weak references to tasks, so an unreferenced task may be
        # garbage-collected before it finishes sending.
        self._pending_updates: Set[asyncio.Task] = set()

    async def handle_connection(self) -> None:
        """
        Main connection handler loop.

        Accepts the WebSocket, then receives and dispatches JSON messages
        until the client disconnects, a receive fails, or a send fails.
        On exit, pending bar-update tasks are cancelled and all of this
        client's subscriptions are removed.
        """
        try:
            await self.websocket.accept()
            logger.info("WebSocket connected: client_id=%s", self.client_id)

            while self._connected:
                try:
                    data = await self.websocket.receive_text()
                except WebSocketDisconnect:
                    # Normal client-initiated close; not an error condition.
                    break
                except Exception as e:
                    logger.error("Error receiving message: %s", e)
                    break

                try:
                    message = json.loads(data)
                except json.JSONDecodeError as e:
                    # One malformed frame should not tear down every live
                    # subscription this client holds: report it and keep going.
                    logger.warning(
                        "Malformed JSON from client_id=%s: %s", self.client_id, e
                    )
                    await self._send_error(
                        "unknown", "INVALID_JSON", f"Malformed JSON message: {e}"
                    )
                    continue

                if not isinstance(message, dict):
                    # _handle_message relies on dict access (message.get).
                    await self._send_error(
                        "unknown", "INVALID_REQUEST", "Message must be a JSON object"
                    )
                    continue

                # Route to appropriate handler
                await self._handle_message(message)
        except Exception as e:
            logger.error("WebSocket error: %s", e, exc_info=True)
        finally:
            # Mark disconnected first so subscription callbacks fired during
            # cleanup stop scheduling sends on a dead socket.
            self._connected = False
            self._cancel_pending_updates()
            try:
                # Clean up subscriptions when connection closes
                await self.subscription_manager.unsubscribe_client(self.client_id)
            except Exception:
                logger.exception(
                    "Failed to clean up subscriptions for client_id=%s", self.client_id
                )
            logger.info("WebSocket disconnected: client_id=%s", self.client_id)

    async def _handle_message(self, message: dict) -> None:
        """
        Route a decoded message to the handler for its ``type`` field.

        Unknown types produce an ``UNKNOWN_REQUEST_TYPE`` error. Any exception
        raised while building the request model or running the handler is
        logged and reported to the client as ``INTERNAL_ERROR``.
        """
        msg_type = message.get("type")
        request_id = message.get("request_id", "unknown")

        routes = {
            "search_symbols": (SearchSymbolsRequest, self._handle_search_symbols),
            "resolve_symbol": (ResolveSymbolRequest, self._handle_resolve_symbol),
            "get_bars": (GetBarsRequest, self._handle_get_bars),
            "subscribe_bars": (SubscribeBarsRequest, self._handle_subscribe_bars),
            "unsubscribe_bars": (UnsubscribeBarsRequest, self._handle_unsubscribe_bars),
            "get_config": (GetConfigRequest, self._handle_get_config),
        }

        try:
            # Guard the lookup: a JSON array/object "type" is unhashable and
            # must be reported as unknown rather than raising TypeError.
            route = routes.get(msg_type) if isinstance(msg_type, str) else None
            if route is None:
                await self._send_error(
                    request_id, "UNKNOWN_REQUEST_TYPE", f"Unknown request type: {msg_type}"
                )
                return

            request_cls, handler = route
            await handler(request_cls(**message))
        except Exception as e:
            logger.error("Error handling %s: %s", msg_type, e, exc_info=True)
            await self._send_error(request_id, "INTERNAL_ERROR", str(e))

    def _split_symbol(self, full_symbol: str) -> Tuple[Optional[str], str]:
        """
        Split a ``"SOURCE:SYMBOL"`` string into ``(source_name, symbol)``.

        Only the first ``:`` separates the source; the remainder is the symbol.
        When no ``:`` is present, ``default_source`` (possibly ``None``) is
        returned as the source and the input is returned unchanged as the symbol.
        """
        if ":" in full_symbol:
            source_name, symbol = full_symbol.split(":", 1)
            return source_name, symbol
        return self.default_source, full_symbol

    async def _handle_search_symbols(self, request: SearchSymbolsRequest) -> None:
        """
        Handle symbol search request.

        Searches only the default source when one is configured; otherwise
        searches every registered source and flattens the results.
        """
        if self.default_source:
            source = self.registry.get(self.default_source)
            if not source:
                await self._send_error(
                    request.request_id,
                    "SOURCE_NOT_FOUND",
                    f"Default source '{self.default_source}' not found",
                )
                return

            results = await source.search_symbols(
                query=request.query,
                type=request.symbol_type,
                exchange=request.exchange,
                limit=request.limit,
            )
            results_data = [r.model_dump(mode="json") for r in results]
        else:
            # Search all sources
            all_results = await self.registry.search_all(
                query=request.query,
                type=request.symbol_type,
                exchange=request.exchange,
                limit=request.limit,
            )
            # Flatten results from all sources
            results_data = [
                r.model_dump(mode="json")
                for source_results in all_results.values()
                for r in source_results
            ]

        response = SearchSymbolsResponse(request_id=request.request_id, results=results_data)
        await self._send_response(response)

    async def _handle_resolve_symbol(self, request: ResolveSymbolRequest) -> None:
        """
        Handle symbol resolution request.

        A ``ValueError`` from the registry is reported as ``SYMBOL_NOT_FOUND``.
        """
        source_name, symbol = self._split_symbol(request.symbol)

        if not source_name:
            await self._send_error(
                request.request_id,
                "NO_SOURCE_SPECIFIED",
                "No data source specified and no default source configured",
            )
            return

        try:
            symbol_info = await self.registry.resolve_symbol(source_name, symbol)
            response = ResolveSymbolResponse(
                request_id=request.request_id,
                symbol_info=symbol_info.model_dump(mode="json"),
            )
            await self._send_response(response)
        except ValueError as e:
            await self._send_error(request.request_id, "SYMBOL_NOT_FOUND", str(e))

    async def _handle_get_bars(self, request: GetBarsRequest) -> None:
        """
        Handle historical bars request.

        A ``ValueError`` from the data source is reported as ``INVALID_REQUEST``.
        """
        source_name, symbol = self._split_symbol(request.symbol)

        if not source_name:
            await self._send_error(
                request.request_id, "NO_SOURCE_SPECIFIED", "No data source specified"
            )
            return

        source = self.registry.get(source_name)
        if not source:
            await self._send_error(
                request.request_id, "SOURCE_NOT_FOUND", f"Source '{source_name}' not found"
            )
            return

        try:
            history = await source.get_bars(
                symbol=symbol,
                resolution=request.resolution,
                from_time=request.from_time,
                to_time=request.to_time,
                countback=request.countback,
            )
            response = GetBarsResponse(
                request_id=request.request_id, history=history.model_dump(mode="json")
            )
            await self._send_response(response)
        except ValueError as e:
            await self._send_error(request.request_id, "INVALID_REQUEST", str(e))

    async def _handle_subscribe_bars(self, request: SubscribeBarsRequest) -> None:
        """
        Handle real-time subscription request.

        Registers a callback with the subscription manager that pushes each
        bar to this client as a ``BarUpdateMessage``. Any failure is reported
        back as ``SubscribeBarsResponse(success=False, message=...)``.
        """
        source_name, symbol = self._split_symbol(request.symbol)

        if not source_name:
            await self._send_error(
                request.request_id, "NO_SOURCE_SPECIFIED", "No data source specified"
            )
            return

        try:
            # Create callback that sends updates to this WebSocket
            async def send_update(bar: dict):
                update = BarUpdateMessage(
                    subscription_id=request.subscription_id,
                    symbol=request.symbol,
                    resolution=request.resolution,
                    bar=bar,
                )
                await self._send_response(update)

            # Register subscription. The callback is synchronous from the
            # manager's point of view; the actual send runs as a background task.
            await self.subscription_manager.subscribe(
                subscription_id=request.subscription_id,
                client_id=self.client_id,
                source_name=source_name,
                symbol=symbol,
                resolution=request.resolution,
                callback=lambda bar: self._queue_update(send_update(bar)),
            )

            response = SubscribeBarsResponse(
                request_id=request.request_id,
                subscription_id=request.subscription_id,
                success=True,
            )
            await self._send_response(response)

        except Exception as e:
            logger.error("Subscription failed: %s", e, exc_info=True)
            response = SubscribeBarsResponse(
                request_id=request.request_id,
                subscription_id=request.subscription_id,
                success=False,
                message=str(e),
            )
            await self._send_response(response)

    async def _handle_unsubscribe_bars(self, request: UnsubscribeBarsRequest) -> None:
        """Handle unsubscribe request; failures are reported as ``success=False``."""
        try:
            await self.subscription_manager.unsubscribe(request.subscription_id)
            response = UnsubscribeBarsResponse(
                request_id=request.request_id,
                subscription_id=request.subscription_id,
                success=True,
            )
            await self._send_response(response)
        except Exception as e:
            logger.error("Unsubscribe failed: %s", e, exc_info=True)
            response = UnsubscribeBarsResponse(
                request_id=request.request_id,
                subscription_id=request.subscription_id,
                success=False,
            )
            await self._send_response(response)

    async def _handle_get_config(self, request: GetConfigRequest) -> None:
        """
        Handle datafeed config request.

        Uses the default source's config when it exists; otherwise falls back
        to the first registered source. Always sends either a config response
        or an error, so the client never waits on an unanswered request.
        """
        if self.default_source:
            source = self.registry.get(self.default_source)
            if source:
                await self._send_config(request.request_id, source)
                return

        # Return aggregate config from all sources
        all_sources = self.registry.list_sources()
        if not all_sources:
            await self._send_error(
                request.request_id, "NO_SOURCES", "No data sources available"
            )
            return

        # Just use first source's config for now
        # TODO: Aggregate configs from multiple sources
        first_name = all_sources[0]
        source = self.registry.get(first_name)
        if not source:
            # Previously this path sent nothing, leaving the request unanswered.
            await self._send_error(
                request.request_id, "SOURCE_NOT_FOUND", f"Source '{first_name}' not found"
            )
            return
        await self._send_config(request.request_id, source)

    async def _send_config(self, request_id: str, source: DataSource) -> None:
        """Fetch *source*'s config and send it as a ``GetConfigResponse``."""
        config = await source.get_config()
        response = GetConfigResponse(
            request_id=request_id, config=config.model_dump(mode="json")
        )
        await self._send_response(response)

    async def _send_response(self, response) -> None:
        """
        Send a response message to the client.

        A send failure is logged and marks the handler as disconnected, which
        ends the receive loop on its next iteration and stops new bar updates
        from being scheduled.
        """
        try:
            await self.websocket.send_json(response.model_dump(mode="json"))
        except Exception as e:
            logger.error("Error sending response: %s", e)
            self._connected = False

    async def _send_error(self, request_id: str, error_code: str, error_message: str) -> None:
        """Send an error response"""
        error = ErrorResponse(
            request_id=request_id, error_code=error_code, error_message=error_message
        )
        await self._send_response(error)

    def _queue_update(self, coro: Coroutine[Any, Any, None]) -> None:
        """
        Schedule *coro* as a background task so the subscription callback does not block.

        Uses ``asyncio.create_task``, so it must be called from a thread with a
        running event loop. The task is kept in ``_pending_updates`` until it
        completes. If the connection is already closed, the coroutine is
        closed without running.
        """
        if not self._connected:
            # Close the never-started coroutine to avoid a
            # "coroutine was never awaited" RuntimeWarning.
            coro.close()
            return
        task = asyncio.create_task(coro)
        self._pending_updates.add(task)
        task.add_done_callback(self._on_update_done)

    def _on_update_done(self, task: asyncio.Task) -> None:
        """Drop the finished task's reference and log any exception it raised."""
        self._pending_updates.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error("Bar update task failed: %s", exc, exc_info=exc)

    def _cancel_pending_updates(self) -> None:
        """Cancel all in-flight bar-update tasks (done-callbacks remove them from the set)."""
        for task in list(self._pending_updates):
            task.cancel()