initial commit with charts and assistant chat
This commit is contained in:
347
backend/src/datasource/websocket_handler.py
Normal file
347
backend/src/datasource/websocket_handler.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
WebSocket handler for TradingView-compatible datafeed API.
|
||||
|
||||
Handles incoming requests for symbol search, metadata, historical data,
|
||||
and real-time subscriptions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from .base import DataSource
|
||||
from .registry import DataSourceRegistry
|
||||
from .subscription_manager import SubscriptionManager
|
||||
from .websocket_protocol import (
|
||||
BarUpdateMessage,
|
||||
ErrorResponse,
|
||||
GetBarsRequest,
|
||||
GetBarsResponse,
|
||||
GetConfigRequest,
|
||||
GetConfigResponse,
|
||||
ResolveSymbolRequest,
|
||||
ResolveSymbolResponse,
|
||||
SearchSymbolsRequest,
|
||||
SearchSymbolsResponse,
|
||||
SubscribeBarsRequest,
|
||||
SubscribeBarsResponse,
|
||||
UnsubscribeBarsRequest,
|
||||
UnsubscribeBarsResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatafeedWebSocketHandler:
|
||||
"""
|
||||
Handles WebSocket connections for TradingView-compatible datafeed API.
|
||||
|
||||
Each handler manages a single WebSocket connection and routes requests
|
||||
to the appropriate data sources via the registry.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
client_id: str,
|
||||
registry: DataSourceRegistry,
|
||||
subscription_manager: SubscriptionManager,
|
||||
default_source: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize handler.
|
||||
|
||||
Args:
|
||||
websocket: FastAPI WebSocket connection
|
||||
client_id: Unique identifier for this client
|
||||
registry: DataSource registry for accessing data sources
|
||||
subscription_manager: Shared subscription manager
|
||||
default_source: Default data source name if not specified in requests
|
||||
"""
|
||||
self.websocket = websocket
|
||||
self.client_id = client_id
|
||||
self.registry = registry
|
||||
self.subscription_manager = subscription_manager
|
||||
self.default_source = default_source
|
||||
self._connected = True
|
||||
|
||||
async def handle_connection(self) -> None:
|
||||
"""
|
||||
Main connection handler loop.
|
||||
|
||||
Processes incoming messages until the connection closes.
|
||||
"""
|
||||
try:
|
||||
await self.websocket.accept()
|
||||
logger.info(f"WebSocket connected: client_id={self.client_id}")
|
||||
|
||||
while self._connected:
|
||||
# Receive message
|
||||
try:
|
||||
data = await self.websocket.receive_text()
|
||||
message = json.loads(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error receiving/parsing message: {e}")
|
||||
break
|
||||
|
||||
# Route to appropriate handler
|
||||
await self._handle_message(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error: {e}", exc_info=True)
|
||||
finally:
|
||||
# Clean up subscriptions when connection closes
|
||||
await self.subscription_manager.unsubscribe_client(self.client_id)
|
||||
self._connected = False
|
||||
logger.info(f"WebSocket disconnected: client_id={self.client_id}")
|
||||
|
||||
async def _handle_message(self, message: dict) -> None:
|
||||
"""Route message to appropriate handler based on type"""
|
||||
msg_type = message.get("type")
|
||||
request_id = message.get("request_id", "unknown")
|
||||
|
||||
try:
|
||||
if msg_type == "search_symbols":
|
||||
await self._handle_search_symbols(SearchSymbolsRequest(**message))
|
||||
elif msg_type == "resolve_symbol":
|
||||
await self._handle_resolve_symbol(ResolveSymbolRequest(**message))
|
||||
elif msg_type == "get_bars":
|
||||
await self._handle_get_bars(GetBarsRequest(**message))
|
||||
elif msg_type == "subscribe_bars":
|
||||
await self._handle_subscribe_bars(SubscribeBarsRequest(**message))
|
||||
elif msg_type == "unsubscribe_bars":
|
||||
await self._handle_unsubscribe_bars(UnsubscribeBarsRequest(**message))
|
||||
elif msg_type == "get_config":
|
||||
await self._handle_get_config(GetConfigRequest(**message))
|
||||
else:
|
||||
await self._send_error(
|
||||
request_id, "UNKNOWN_REQUEST_TYPE", f"Unknown request type: {msg_type}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling {msg_type}: {e}", exc_info=True)
|
||||
await self._send_error(request_id, "INTERNAL_ERROR", str(e))
|
||||
|
||||
async def _handle_search_symbols(self, request: SearchSymbolsRequest) -> None:
|
||||
"""Handle symbol search request"""
|
||||
# Use default source or search all sources
|
||||
if self.default_source:
|
||||
source = self.registry.get(self.default_source)
|
||||
if not source:
|
||||
await self._send_error(
|
||||
request.request_id,
|
||||
"SOURCE_NOT_FOUND",
|
||||
f"Default source '{self.default_source}' not found",
|
||||
)
|
||||
return
|
||||
|
||||
results = await source.search_symbols(
|
||||
query=request.query,
|
||||
type=request.symbol_type,
|
||||
exchange=request.exchange,
|
||||
limit=request.limit,
|
||||
)
|
||||
results_data = [r.model_dump(mode="json") for r in results]
|
||||
else:
|
||||
# Search all sources
|
||||
all_results = await self.registry.search_all(
|
||||
query=request.query,
|
||||
type=request.symbol_type,
|
||||
exchange=request.exchange,
|
||||
limit=request.limit,
|
||||
)
|
||||
# Flatten results from all sources
|
||||
results_data = []
|
||||
for source_results in all_results.values():
|
||||
results_data.extend([r.model_dump(mode="json") for r in source_results])
|
||||
|
||||
response = SearchSymbolsResponse(request_id=request.request_id, results=results_data)
|
||||
await self._send_response(response)
|
||||
|
||||
async def _handle_resolve_symbol(self, request: ResolveSymbolRequest) -> None:
|
||||
"""Handle symbol resolution request"""
|
||||
# Extract source from symbol if present (format: "SOURCE:SYMBOL")
|
||||
if ":" in request.symbol:
|
||||
source_name, symbol = request.symbol.split(":", 1)
|
||||
else:
|
||||
source_name = self.default_source
|
||||
symbol = request.symbol
|
||||
|
||||
if not source_name:
|
||||
await self._send_error(
|
||||
request.request_id,
|
||||
"NO_SOURCE_SPECIFIED",
|
||||
"No data source specified and no default source configured",
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
symbol_info = await self.registry.resolve_symbol(source_name, symbol)
|
||||
response = ResolveSymbolResponse(
|
||||
request_id=request.request_id,
|
||||
symbol_info=symbol_info.model_dump(mode="json"),
|
||||
)
|
||||
await self._send_response(response)
|
||||
except ValueError as e:
|
||||
await self._send_error(request.request_id, "SYMBOL_NOT_FOUND", str(e))
|
||||
|
||||
async def _handle_get_bars(self, request: GetBarsRequest) -> None:
|
||||
"""Handle historical bars request"""
|
||||
# Extract source from symbol
|
||||
if ":" in request.symbol:
|
||||
source_name, symbol = request.symbol.split(":", 1)
|
||||
else:
|
||||
source_name = self.default_source
|
||||
symbol = request.symbol
|
||||
|
||||
if not source_name:
|
||||
await self._send_error(
|
||||
request.request_id, "NO_SOURCE_SPECIFIED", "No data source specified"
|
||||
)
|
||||
return
|
||||
|
||||
source = self.registry.get(source_name)
|
||||
if not source:
|
||||
await self._send_error(
|
||||
request.request_id, "SOURCE_NOT_FOUND", f"Source '{source_name}' not found"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
history = await source.get_bars(
|
||||
symbol=symbol,
|
||||
resolution=request.resolution,
|
||||
from_time=request.from_time,
|
||||
to_time=request.to_time,
|
||||
countback=request.countback,
|
||||
)
|
||||
response = GetBarsResponse(
|
||||
request_id=request.request_id, history=history.model_dump(mode="json")
|
||||
)
|
||||
await self._send_response(response)
|
||||
except ValueError as e:
|
||||
await self._send_error(request.request_id, "INVALID_REQUEST", str(e))
|
||||
|
||||
async def _handle_subscribe_bars(self, request: SubscribeBarsRequest) -> None:
|
||||
"""Handle real-time subscription request"""
|
||||
# Extract source from symbol
|
||||
if ":" in request.symbol:
|
||||
source_name, symbol = request.symbol.split(":", 1)
|
||||
else:
|
||||
source_name = self.default_source
|
||||
symbol = request.symbol
|
||||
|
||||
if not source_name:
|
||||
await self._send_error(
|
||||
request.request_id, "NO_SOURCE_SPECIFIED", "No data source specified"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Create callback that sends updates to this WebSocket
|
||||
async def send_update(bar: dict):
|
||||
update = BarUpdateMessage(
|
||||
subscription_id=request.subscription_id,
|
||||
symbol=request.symbol,
|
||||
resolution=request.resolution,
|
||||
bar=bar,
|
||||
)
|
||||
await self._send_response(update)
|
||||
|
||||
# Register subscription
|
||||
await self.subscription_manager.subscribe(
|
||||
subscription_id=request.subscription_id,
|
||||
client_id=self.client_id,
|
||||
source_name=source_name,
|
||||
symbol=symbol,
|
||||
resolution=request.resolution,
|
||||
callback=lambda bar: self._queue_update(send_update(bar)),
|
||||
)
|
||||
|
||||
response = SubscribeBarsResponse(
|
||||
request_id=request.request_id,
|
||||
subscription_id=request.subscription_id,
|
||||
success=True,
|
||||
)
|
||||
await self._send_response(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Subscription failed: {e}", exc_info=True)
|
||||
response = SubscribeBarsResponse(
|
||||
request_id=request.request_id,
|
||||
subscription_id=request.subscription_id,
|
||||
success=False,
|
||||
message=str(e),
|
||||
)
|
||||
await self._send_response(response)
|
||||
|
||||
async def _handle_unsubscribe_bars(self, request: UnsubscribeBarsRequest) -> None:
|
||||
"""Handle unsubscribe request"""
|
||||
try:
|
||||
await self.subscription_manager.unsubscribe(request.subscription_id)
|
||||
response = UnsubscribeBarsResponse(
|
||||
request_id=request.request_id,
|
||||
subscription_id=request.subscription_id,
|
||||
success=True,
|
||||
)
|
||||
await self._send_response(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Unsubscribe failed: {e}")
|
||||
response = UnsubscribeBarsResponse(
|
||||
request_id=request.request_id,
|
||||
subscription_id=request.subscription_id,
|
||||
success=False,
|
||||
)
|
||||
await self._send_response(response)
|
||||
|
||||
async def _handle_get_config(self, request: GetConfigRequest) -> None:
|
||||
"""Handle datafeed config request"""
|
||||
if self.default_source:
|
||||
source = self.registry.get(self.default_source)
|
||||
if source:
|
||||
config = await source.get_config()
|
||||
response = GetConfigResponse(
|
||||
request_id=request.request_id, config=config.model_dump(mode="json")
|
||||
)
|
||||
await self._send_response(response)
|
||||
return
|
||||
|
||||
# Return aggregate config from all sources
|
||||
all_sources = self.registry.list_sources()
|
||||
if not all_sources:
|
||||
await self._send_error(
|
||||
request.request_id, "NO_SOURCES", "No data sources available"
|
||||
)
|
||||
return
|
||||
|
||||
# Just use first source's config for now
|
||||
# TODO: Aggregate configs from multiple sources
|
||||
source = self.registry.get(all_sources[0])
|
||||
if source:
|
||||
config = await source.get_config()
|
||||
response = GetConfigResponse(
|
||||
request_id=request.request_id, config=config.model_dump(mode="json")
|
||||
)
|
||||
await self._send_response(response)
|
||||
|
||||
async def _send_response(self, response) -> None:
|
||||
"""Send a response message to the client"""
|
||||
try:
|
||||
await self.websocket.send_json(response.model_dump(mode="json"))
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending response: {e}")
|
||||
self._connected = False
|
||||
|
||||
async def _send_error(self, request_id: str, error_code: str, error_message: str) -> None:
|
||||
"""Send an error response"""
|
||||
error = ErrorResponse(
|
||||
request_id=request_id, error_code=error_code, error_message=error_message
|
||||
)
|
||||
await self._send_response(error)
|
||||
|
||||
def _queue_update(self, coro):
|
||||
"""Queue an async update to be sent (prevents blocking callback)"""
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(coro)
|
||||
Reference in New Issue
Block a user