100 lines
3.5 KiB
Python
100 lines
3.5 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import AsyncIterator, Optional
|
|
from datetime import datetime
|
|
from fastapi import WebSocket
|
|
|
|
from gateway.channels.base import Channel
|
|
from gateway.protocol import UserMessage, AgentMessage, WebSocketAgentUserMessage, WebSocketAgentChunk
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class WebSocketChannel(Channel):
    """WebSocket-based communication channel.

    Integrates with the existing WebSocket endpoint to provide
    bidirectional agent communication with streaming support.

    The channel only reads from and writes to the ``WebSocket`` it is
    given; it never closes the socket itself (see ``close``).
    """

    def __init__(self, channel_id: str, websocket: WebSocket, session_id: str):
        """Bind the channel to a WebSocket.

        Args:
            channel_id: Identifier passed on to the base ``Channel``.
            websocket: Socket used for every send and receive. The channel
                starts out marked as connected, so this is presumably an
                already-accepted connection — confirm against the caller.
            session_id: Session identifier stored on the channel.
        """
        super().__init__(channel_id, "websocket")
        self.websocket = websocket
        self.session_id = session_id
        # Cleared by send()/receive() when the socket fails and by close().
        self._connected = True
        # Neither of the following is populated inside this class; only
        # close() inspects _receive_task. Presumably managed by the owner
        # of the channel — verify before removing.
        self._receive_queue: asyncio.Queue[UserMessage] = asyncio.Queue()
        self._receive_task: Optional[asyncio.Task] = None

    def supports_streaming(self) -> bool:
        """WebSocket supports streaming responses."""
        return True

    def supports_attachments(self) -> bool:
        """WebSocket can support attachments via URLs."""
        return True

    async def send(self, message: AgentMessage) -> None:
        """Send agent message through WebSocket.

        Each call wraps ``message`` in a single ``WebSocketAgentChunk`` and
        sends it as one JSON frame; the ``done`` flag of the message is
        forwarded unchanged, so a streamed response is one call per chunk.

        Args:
            message: Agent message to forward to the client.

        The method never raises on a send failure: the error is logged and
        the channel is marked as disconnected, after which further calls
        are dropped with a warning.
        """
        if not self._connected:
            logger.warning(f"Cannot send message, channel {self.channel_id} not connected")
            return

        try:
            chunk = WebSocketAgentChunk(
                session_id=message.session_id,
                content=message.content,
                done=message.done,
                metadata=message.metadata,
            )
            chunk_data = chunk.model_dump(mode="json")
            logger.debug(
                "Sending WebSocket message: done=%s, content_length=%s",
                message.done,
                len(message.content),
            )
            await self.websocket.send_json(chunk_data)
            logger.debug("WebSocket message sent successfully")
        except Exception as e:
            logger.error(f"WebSocket send error: {e}", exc_info=True)
            self._connected = False

    async def receive(self) -> AsyncIterator[UserMessage]:
        """Receive messages from WebSocket.

        Reads text frames for as long as the channel is connected, decodes
        each one as JSON and converts frames whose ``type`` is
        ``"agent_user_message"`` into ``UserMessage`` objects. Frames of any
        other type are ignored.

        Iteration ends, and the channel is marked as disconnected, when the
        client disconnects or when reading/decoding/validating a frame
        fails. Neither case propagates an exception to the consumer.

        Yields:
            UserMessage objects as they arrive from the client
        """
        # Local import: fastapi is already a dependency of this module, and
        # importing here leaves the module-level import block untouched.
        from fastapi import WebSocketDisconnect

        try:
            while self._connected:
                # Read from WebSocket
                data = await self.websocket.receive_text()
                message_json = json.loads(data)

                # Only process agent_user_message types
                if message_json.get("type") != "agent_user_message":
                    continue

                msg = WebSocketAgentUserMessage(**message_json)
                yield UserMessage(
                    session_id=msg.session_id,
                    channel_id=self.channel_id,
                    content=msg.content,
                    attachments=msg.attachments,
                    # Naive UTC timestamp. NOTE(review): utcnow() is
                    # deprecated since Python 3.12; switching to an aware
                    # datetime needs confirmation that consumers of
                    # UserMessage.timestamp accept tzinfo.
                    timestamp=datetime.utcnow(),
                )
        except WebSocketDisconnect:
            # A client going away is the normal end of the stream, not a
            # failure, so it is reported without a traceback.
            logger.info("WebSocket client disconnected from channel %s", self.channel_id)
            self._connected = False
        except Exception as e:
            # Reported through the module logger (with traceback) so that
            # receive failures are handled the same way as send failures.
            logger.error(f"WebSocket receive error: {e}", exc_info=True)
            self._connected = False

    async def close(self) -> None:
        """Close the WebSocket connection.

        Marks the channel as disconnected and, if a receive task has been
        attached to ``_receive_task``, cancels it and waits for the
        cancellation to complete. The underlying socket is not closed here.
        """
        self._connected = False
        if self._receive_task:
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                # Expected outcome of the cancel() above.
                pass
        # Note: WebSocket close is handled by the main WebSocket endpoint
|