import asyncio
import json
import logging
from datetime import datetime
from typing import AsyncIterator, Optional

from fastapi import WebSocket, WebSocketDisconnect

from gateway.channels.base import Channel
from gateway.protocol import (
    AgentMessage,
    UserMessage,
    WebSocketAgentChunk,
    WebSocketAgentUserMessage,
)

logger = logging.getLogger(__name__)


class WebSocketChannel(Channel):
    """WebSocket-based communication channel.

    Integrates with the existing WebSocket endpoint to provide bidirectional
    agent communication with streaming support. Outgoing ``AgentMessage``
    objects are serialized as ``WebSocketAgentChunk`` JSON; incoming text
    frames of type ``"agent_user_message"`` are converted to ``UserMessage``.
    """

    def __init__(self, channel_id: str, websocket: WebSocket, session_id: str):
        super().__init__(channel_id, "websocket")
        self.websocket = websocket
        self.session_id = session_id
        # Cleared on any send/receive failure, on client disconnect, or by close().
        self._connected = True
        # NOTE(review): not read or written anywhere in this module — presumably
        # used by the base class or external code; confirm before removing.
        self._receive_queue: asyncio.Queue[UserMessage] = asyncio.Queue()
        # Never assigned in this module; presumably set by the code that drives
        # receive(). close() cancels it if present.
        self._receive_task: Optional[asyncio.Task] = None

    def supports_streaming(self) -> bool:
        """WebSocket supports streaming responses."""
        return True

    def supports_attachments(self) -> bool:
        """WebSocket can support attachments via URLs."""
        return True

    async def send(self, message: AgentMessage) -> None:
        """Send an agent message through the WebSocket as one JSON chunk.

        Streaming callers send one call per chunk; complete messages are sent
        as a single chunk. Errors are logged, not raised: a failed send marks
        the channel as disconnected so later sends become no-ops.

        Args:
            message: The agent message to forward; its ``done`` flag and
                ``metadata`` are passed through unchanged.
        """
        if not self._connected:
            logger.warning("Cannot send message, channel %s not connected", self.channel_id)
            return

        try:
            chunk = WebSocketAgentChunk(
                session_id=message.session_id,
                content=message.content,
                done=message.done,
                metadata=message.metadata,
            )
            # mode="json" makes non-JSON-native fields (e.g. datetimes) serializable.
            chunk_data = chunk.model_dump(mode="json")
            logger.debug(
                "Sending WebSocket message: done=%s, content_length=%d",
                message.done,
                len(message.content),
            )
            await self.websocket.send_json(chunk_data)
            logger.debug("WebSocket message sent successfully")
        except Exception as e:
            # Best-effort delivery: record the failure and stop using the socket.
            logger.error("WebSocket send error: %s", e, exc_info=True)
            self._connected = False

    async def receive(self) -> AsyncIterator[UserMessage]:
        """Receive messages from the WebSocket.

        Frames that are not valid JSON objects, are not of type
        ``"agent_user_message"``, or fail validation are logged and skipped
        rather than tearing down the channel.

        Yields:
            UserMessage objects as they arrive from the client.
        """
        try:
            while self._connected:
                data = await self.websocket.receive_text()
                user_msg = self._parse_user_message(data)
                if user_msg is not None:
                    yield user_msg
        except WebSocketDisconnect as e:
            # Normal client-initiated close: not an error condition.
            logger.info("WebSocket channel %s disconnected (code=%s)", self.channel_id, e.code)
            self._connected = False
        except Exception as e:
            logger.error("WebSocket receive error: %s", e, exc_info=True)
            self._connected = False

    def _parse_user_message(self, data: str) -> Optional[UserMessage]:
        """Convert one raw text frame into a UserMessage, or None to skip it."""
        try:
            message_json = json.loads(data)
        except json.JSONDecodeError as e:
            logger.warning("Ignoring non-JSON WebSocket frame on channel %s: %s", self.channel_id, e)
            return None

        # Only process agent_user_message types; non-object JSON is ignored too.
        if not isinstance(message_json, dict) or message_json.get("type") != "agent_user_message":
            return None

        try:
            msg = WebSocketAgentUserMessage(**message_json)
        except (TypeError, ValueError) as e:
            # Assumes the protocol models are pydantic (model_dump is used in
            # send()); pydantic's ValidationError subclasses ValueError.
            logger.warning("Ignoring invalid agent_user_message on channel %s: %s", self.channel_id, e)
            return None

        return UserMessage(
            session_id=msg.session_id,
            channel_id=self.channel_id,
            content=msg.content,
            attachments=msg.attachments,
            # NOTE(review): utcnow() returns a naive datetime and is deprecated
            # since Python 3.12; switching to datetime.now(timezone.utc) makes it
            # timezone-aware — confirm downstream consumers before changing.
            timestamp=datetime.utcnow(),
        )

    async def close(self) -> None:
        """Mark the channel closed and cancel the receive task, if any."""
        self._connected = False
        task, self._receive_task = self._receive_task, None
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        # Note: WebSocket close is handled by the main WebSocket endpoint