348 lines
13 KiB
Python
348 lines
13 KiB
Python
"""
|
|
WebSocket handler for TradingView-compatible datafeed API.
|
|
|
|
Handles incoming requests for symbol search, metadata, historical data,
|
|
and real-time subscriptions.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
from typing import Dict, Optional
|
|
|
|
from fastapi import WebSocket
|
|
|
|
from .base import DataSource
|
|
from .registry import DataSourceRegistry
|
|
from .subscription_manager import SubscriptionManager
|
|
from .websocket_protocol import (
|
|
BarUpdateMessage,
|
|
ErrorResponse,
|
|
GetBarsRequest,
|
|
GetBarsResponse,
|
|
GetConfigRequest,
|
|
GetConfigResponse,
|
|
ResolveSymbolRequest,
|
|
ResolveSymbolResponse,
|
|
SearchSymbolsRequest,
|
|
SearchSymbolsResponse,
|
|
SubscribeBarsRequest,
|
|
SubscribeBarsResponse,
|
|
UnsubscribeBarsRequest,
|
|
UnsubscribeBarsResponse,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DatafeedWebSocketHandler:
    """
    Handles WebSocket connections for TradingView-compatible datafeed API.

    Each handler manages a single WebSocket connection and routes requests
    to the appropriate data sources via the registry.
    """

    def __init__(
        self,
        websocket: WebSocket,
        client_id: str,
        registry: DataSourceRegistry,
        subscription_manager: SubscriptionManager,
        default_source: Optional[str] = None,
    ):
        """
        Initialize handler.

        Args:
            websocket: FastAPI WebSocket connection
            client_id: Unique identifier for this client
            registry: DataSource registry for accessing data sources
            subscription_manager: Shared subscription manager
            default_source: Default data source name if not specified in requests
        """
        self.websocket = websocket
        self.client_id = client_id
        self.registry = registry
        self.subscription_manager = subscription_manager
        self.default_source = default_source
        self._connected = True
        # Strong references to in-flight bar-update tasks. The event loop only
        # keeps weak references to tasks, so without this set a task created
        # by _queue_update could be garbage-collected before it finishes.
        self._pending_updates: set = set()

    async def handle_connection(self) -> None:
        """
        Main connection handler loop.

        Accepts the WebSocket, then receives text frames, decodes them as
        JSON and routes them through ``_handle_message`` until the client
        disconnects or a send failure clears ``_connected``. A frame that is
        not valid JSON is answered with an ``INVALID_JSON`` error and the
        loop continues. On exit the client's subscriptions are removed via
        the subscription manager and pending update tasks are cancelled.
        """
        # Imported locally so this class does not rely on a module-level
        # import that the file does not already have.
        from fastapi import WebSocketDisconnect

        try:
            await self.websocket.accept()
            logger.info("WebSocket connected: client_id=%s", self.client_id)

            while self._connected:
                try:
                    data = await self.websocket.receive_text()
                except WebSocketDisconnect:
                    # The peer closed the socket: a normal end of session,
                    # not an error worth logging at error level.
                    break
                except Exception as e:
                    logger.error("Error receiving/parsing message: %s", e)
                    break

                try:
                    message = json.loads(data)
                except json.JSONDecodeError as e:
                    # One malformed frame should not drop the whole
                    # connection; report it and keep serving the client.
                    logger.warning(
                        "Malformed JSON from client_id=%s: %s", self.client_id, e
                    )
                    await self._send_error(
                        "unknown", "INVALID_JSON", f"Malformed JSON message: {e}"
                    )
                    continue

                # Route to appropriate handler
                await self._handle_message(message)

        except Exception as e:
            logger.error("WebSocket error: %s", e, exc_info=True)
        finally:
            # Mark the handler dead first so late callbacks are dropped even
            # if the subscription cleanup below raises.
            self._connected = False
            try:
                await self.subscription_manager.unsubscribe_client(self.client_id)
            finally:
                for task in list(self._pending_updates):
                    task.cancel()
                logger.info("WebSocket disconnected: client_id=%s", self.client_id)

    async def _handle_message(self, message: dict) -> None:
        """
        Route a decoded message to the handler matching its ``type`` field.

        The message is expanded as keyword arguments into the request class
        for its type. Any exception raised while building the request or
        running the handler is logged and reported to the client as
        ``INTERNAL_ERROR``. An unrecognised ``type`` yields
        ``UNKNOWN_REQUEST_TYPE``; a payload that is not a JSON object yields
        ``INVALID_REQUEST``.
        """
        if not isinstance(message, dict):
            # json.loads may legitimately return a list, str, number, ...;
            # calling .get() on those would raise and tear the connection down.
            await self._send_error(
                "unknown", "INVALID_REQUEST", "Request must be a JSON object"
            )
            return

        msg_type = message.get("type")
        request_id = message.get("request_id", "unknown")

        routes = {
            "search_symbols": (SearchSymbolsRequest, self._handle_search_symbols),
            "resolve_symbol": (ResolveSymbolRequest, self._handle_resolve_symbol),
            "get_bars": (GetBarsRequest, self._handle_get_bars),
            "subscribe_bars": (SubscribeBarsRequest, self._handle_subscribe_bars),
            "unsubscribe_bars": (UnsubscribeBarsRequest, self._handle_unsubscribe_bars),
            "get_config": (GetConfigRequest, self._handle_get_config),
        }
        # Only strings can name a route; this also keeps unhashable values
        # (e.g. a list) on the UNKNOWN_REQUEST_TYPE path.
        route = routes.get(msg_type) if isinstance(msg_type, str) else None

        try:
            if route is None:
                await self._send_error(
                    request_id, "UNKNOWN_REQUEST_TYPE", f"Unknown request type: {msg_type}"
                )
                return

            request_cls, handler = route
            await handler(request_cls(**message))
        except Exception as e:
            logger.error("Error handling %s: %s", msg_type, e, exc_info=True)
            await self._send_error(request_id, "INTERNAL_ERROR", str(e))

    def _split_symbol(self, raw_symbol: str) -> "tuple[Optional[str], str]":
        """
        Split a client symbol into ``(source_name, symbol)``.

        A symbol of the form ``"SOURCE:SYMBOL"`` is split on the first colon.
        A symbol without a colon is paired with ``self.default_source``,
        which may be ``None``. Callers must check that the returned source
        name is truthy before using it.
        """
        if ":" in raw_symbol:
            source_name, symbol = raw_symbol.split(":", 1)
            return source_name, symbol
        return self.default_source, raw_symbol

    async def _handle_search_symbols(self, request: SearchSymbolsRequest) -> None:
        """
        Handle symbol search request.

        Searches only the default source when one is configured, otherwise
        searches every source through the registry and flattens the
        per-source results into a single list.
        """
        if self.default_source:
            source = self.registry.get(self.default_source)
            if not source:
                await self._send_error(
                    request.request_id,
                    "SOURCE_NOT_FOUND",
                    f"Default source '{self.default_source}' not found",
                )
                return

            results = await source.search_symbols(
                query=request.query,
                type=request.symbol_type,
                exchange=request.exchange,
                limit=request.limit,
            )
            results_data = [r.model_dump(mode="json") for r in results]
        else:
            all_results = await self.registry.search_all(
                query=request.query,
                type=request.symbol_type,
                exchange=request.exchange,
                limit=request.limit,
            )
            # NOTE(review): the merged list is not re-truncated to
            # request.limit here - confirm whether search_all applies the
            # limit globally or per source.
            results_data = [
                r.model_dump(mode="json")
                for source_results in all_results.values()
                for r in source_results
            ]

        response = SearchSymbolsResponse(request_id=request.request_id, results=results_data)
        await self._send_response(response)

    async def _handle_resolve_symbol(self, request: ResolveSymbolRequest) -> None:
        """
        Handle symbol resolution request.

        Sends ``NO_SOURCE_SPECIFIED`` when neither the symbol nor the handler
        names a source, and ``SYMBOL_NOT_FOUND`` when the registry raises
        ``ValueError`` while resolving.
        """
        source_name, symbol = self._split_symbol(request.symbol)

        if not source_name:
            await self._send_error(
                request.request_id,
                "NO_SOURCE_SPECIFIED",
                "No data source specified and no default source configured",
            )
            return

        try:
            symbol_info = await self.registry.resolve_symbol(source_name, symbol)
            response = ResolveSymbolResponse(
                request_id=request.request_id,
                symbol_info=symbol_info.model_dump(mode="json"),
            )
            await self._send_response(response)
        except ValueError as e:
            await self._send_error(request.request_id, "SYMBOL_NOT_FOUND", str(e))

    async def _handle_get_bars(self, request: GetBarsRequest) -> None:
        """
        Handle historical bars request.

        Sends ``NO_SOURCE_SPECIFIED`` or ``SOURCE_NOT_FOUND`` when the source
        cannot be determined or looked up, and ``INVALID_REQUEST`` when the
        source raises ``ValueError`` for the requested range.
        """
        source_name, symbol = self._split_symbol(request.symbol)

        if not source_name:
            await self._send_error(
                request.request_id, "NO_SOURCE_SPECIFIED", "No data source specified"
            )
            return

        source = self.registry.get(source_name)
        if not source:
            await self._send_error(
                request.request_id, "SOURCE_NOT_FOUND", f"Source '{source_name}' not found"
            )
            return

        try:
            history = await source.get_bars(
                symbol=symbol,
                resolution=request.resolution,
                from_time=request.from_time,
                to_time=request.to_time,
                countback=request.countback,
            )
            response = GetBarsResponse(
                request_id=request.request_id, history=history.model_dump(mode="json")
            )
            await self._send_response(response)
        except ValueError as e:
            await self._send_error(request.request_id, "INVALID_REQUEST", str(e))

    async def _handle_subscribe_bars(self, request: SubscribeBarsRequest) -> None:
        """
        Handle real-time subscription request.

        Registers a callback with the subscription manager that schedules a
        ``BarUpdateMessage`` to be sent over this WebSocket for each bar.
        Updates carry the symbol exactly as the client sent it (including any
        ``SOURCE:`` prefix). Any failure is reported as a
        ``SubscribeBarsResponse`` with ``success=False``.
        """
        source_name, symbol = self._split_symbol(request.symbol)

        if not source_name:
            await self._send_error(
                request.request_id, "NO_SOURCE_SPECIFIED", "No data source specified"
            )
            return

        try:
            # Create callback that sends updates to this WebSocket
            async def send_update(bar: dict):
                update = BarUpdateMessage(
                    subscription_id=request.subscription_id,
                    symbol=request.symbol,
                    resolution=request.resolution,
                    bar=bar,
                )
                await self._send_response(update)

            # The manager is handed a plain (non-async) callable; the actual
            # send is scheduled as a task so the caller is never blocked.
            await self.subscription_manager.subscribe(
                subscription_id=request.subscription_id,
                client_id=self.client_id,
                source_name=source_name,
                symbol=symbol,
                resolution=request.resolution,
                callback=lambda bar: self._queue_update(send_update(bar)),
            )

            response = SubscribeBarsResponse(
                request_id=request.request_id,
                subscription_id=request.subscription_id,
                success=True,
            )
            await self._send_response(response)

        except Exception as e:
            logger.error("Subscription failed: %s", e, exc_info=True)
            response = SubscribeBarsResponse(
                request_id=request.request_id,
                subscription_id=request.subscription_id,
                success=False,
                message=str(e),
            )
            await self._send_response(response)

    async def _handle_unsubscribe_bars(self, request: UnsubscribeBarsRequest) -> None:
        """
        Handle unsubscribe request.

        Replies with an ``UnsubscribeBarsResponse`` whose ``success`` flag
        reflects whether the subscription manager call raised.
        """
        # NOTE(review): the subscription id is passed through without checking
        # that it belongs to this client_id - confirm the manager enforces
        # ownership.
        try:
            await self.subscription_manager.unsubscribe(request.subscription_id)
            response = UnsubscribeBarsResponse(
                request_id=request.request_id,
                subscription_id=request.subscription_id,
                success=True,
            )
            await self._send_response(response)
        except Exception as e:
            # Include the traceback, matching the subscribe failure path.
            logger.error("Unsubscribe failed: %s", e, exc_info=True)
            response = UnsubscribeBarsResponse(
                request_id=request.request_id,
                subscription_id=request.subscription_id,
                success=False,
            )
            await self._send_response(response)

    async def _handle_get_config(self, request: GetConfigRequest) -> None:
        """
        Handle datafeed config request.

        Uses the default source's config when that source is configured and
        registered; otherwise falls back to the first source listed by the
        registry. Always sends exactly one reply: a ``GetConfigResponse`` or
        an error (``NO_SOURCES`` / ``SOURCE_NOT_FOUND``).
        """
        source = self.registry.get(self.default_source) if self.default_source else None

        if not source:
            all_sources = self.registry.list_sources()
            if not all_sources:
                await self._send_error(
                    request.request_id, "NO_SOURCES", "No data sources available"
                )
                return

            # Just use first source's config for now
            # TODO: Aggregate configs from multiple sources
            fallback_name = all_sources[0]
            source = self.registry.get(fallback_name)
            if not source:
                # Previously this path returned silently and the client
                # never received a reply for its request.
                await self._send_error(
                    request.request_id,
                    "SOURCE_NOT_FOUND",
                    f"Source '{fallback_name}' not found",
                )
                return

        config = await source.get_config()
        response = GetConfigResponse(
            request_id=request.request_id, config=config.model_dump(mode="json")
        )
        await self._send_response(response)

    async def _send_response(self, response) -> None:
        """
        Send a response message to the client.

        The message is serialised with ``model_dump(mode="json")``. A send
        failure is logged and clears ``_connected`` so the receive loop stops
        after the current iteration; it is not re-raised.
        """
        try:
            await self.websocket.send_json(response.model_dump(mode="json"))
        except Exception as e:
            logger.error("Error sending response: %s", e)
            self._connected = False

    async def _send_error(self, request_id: str, error_code: str, error_message: str) -> None:
        """Build an ``ErrorResponse`` from the arguments and send it."""
        error = ErrorResponse(
            request_id=request_id, error_code=error_code, error_message=error_message
        )
        await self._send_response(error)

    def _queue_update(self, coro):
        """
        Queue an async update to be sent (prevents blocking callback).

        Schedules ``coro`` as a task on the running event loop and keeps a
        strong reference to it until it completes. If the handler is no
        longer connected the coroutine is closed without running.

        Raises:
            RuntimeError: if there is no running event loop (propagated from
                ``asyncio.create_task``); the coroutine is closed first.
        """
        import asyncio

        if not self._connected:
            # Nothing to send to; close the coroutine so Python does not warn
            # that it was never awaited.
            coro.close()
            return

        try:
            task = asyncio.create_task(coro)
        except RuntimeError:
            coro.close()
            raise

        self._pending_updates.add(task)
        task.add_done_callback(self._on_update_done)

    def _on_update_done(self, task) -> None:
        """Forget a finished update task and log any exception it raised."""
        self._pending_updates.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error("Bar update task failed: %s", exc, exc_info=exc)