backend redesign
This commit is contained in:
253
backend.old/src/gateway/hub.py
Normal file
253
backend.old/src/gateway/hub.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Optional, Callable, Awaitable, Any
|
||||
from datetime import datetime
|
||||
|
||||
from gateway.channels.base import Channel
|
||||
from gateway.user_session import UserSession
|
||||
from gateway.protocol import UserMessage, AgentMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Gateway:
|
||||
"""Central hub for routing messages between users, channels, and the agent.
|
||||
|
||||
The Gateway:
|
||||
- Maintains active channels and user sessions
|
||||
- Routes user messages to the agent
|
||||
- Routes agent responses back to appropriate channels
|
||||
- Handles interruption of in-flight agent tasks
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.channels: Dict[str, Channel] = {}
|
||||
self.sessions: Dict[str, UserSession] = {}
|
||||
self.agent_executor: Optional[Callable[[UserSession, UserMessage], Awaitable[AsyncIterator[str]]]] = None
|
||||
|
||||
def set_agent_executor(
|
||||
self,
|
||||
executor: Callable[[UserSession, UserMessage], Awaitable[Any]]
|
||||
) -> None:
|
||||
"""Set the agent executor function.
|
||||
|
||||
Args:
|
||||
executor: Async function that takes (session, message) and returns async iterator of response chunks
|
||||
"""
|
||||
self.agent_executor = executor
|
||||
|
||||
def register_channel(self, channel: Channel) -> None:
|
||||
"""Register a new communication channel.
|
||||
|
||||
Args:
|
||||
channel: Channel instance to register
|
||||
"""
|
||||
self.channels[channel.channel_id] = channel
|
||||
|
||||
def unregister_channel(self, channel_id: str) -> None:
|
||||
"""Unregister a channel.
|
||||
|
||||
Args:
|
||||
channel_id: ID of channel to remove
|
||||
"""
|
||||
if channel_id in self.channels:
|
||||
# Remove channel from any sessions
|
||||
for session in self.sessions.values():
|
||||
session.remove_channel(channel_id)
|
||||
del self.channels[channel_id]
|
||||
|
||||
def get_or_create_session(self, session_id: str, user_id: str) -> UserSession:
|
||||
"""Get existing session or create a new one.
|
||||
|
||||
Args:
|
||||
session_id: Unique session identifier
|
||||
user_id: User identifier
|
||||
|
||||
Returns:
|
||||
UserSession instance
|
||||
"""
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = UserSession(session_id, user_id)
|
||||
return self.sessions[session_id]
|
||||
|
||||
async def route_user_message(self, message: UserMessage) -> None:
|
||||
"""Route a user message to the agent.
|
||||
|
||||
If the agent is currently processing a message, it will be interrupted
|
||||
and restarted with the new message appended to the conversation.
|
||||
|
||||
Args:
|
||||
message: UserMessage from a channel
|
||||
"""
|
||||
logger.info(f"route_user_message called - session: {message.session_id}, channel: {message.channel_id}")
|
||||
|
||||
# Get or create session
|
||||
session = self.get_or_create_session(message.session_id, message.session_id)
|
||||
logger.info(f"Session obtained/created: {message.session_id}, channels: {session.active_channels}")
|
||||
|
||||
# Ensure channel is attached to session
|
||||
session.add_channel(message.channel_id)
|
||||
logger.info(f"Channel added to session: {message.channel_id}")
|
||||
|
||||
# If agent is busy, interrupt it
|
||||
if session.is_busy():
|
||||
logger.info(f"Session is busy, interrupting existing task")
|
||||
await session.interrupt()
|
||||
|
||||
# Check if this is a stop interrupt (empty message)
|
||||
if not message.content.strip() and not message.attachments:
|
||||
logger.info("Received stop interrupt (empty message), not starting new agent round")
|
||||
return
|
||||
|
||||
# Add user message to history
|
||||
session.add_message("user", message.content, message.channel_id)
|
||||
logger.info(f"User message added to history, history length: {len(session.get_history())}")
|
||||
|
||||
# Start agent task
|
||||
if self.agent_executor:
|
||||
logger.info("Starting agent task execution")
|
||||
task = asyncio.create_task(
|
||||
self._execute_agent_and_stream(session, message)
|
||||
)
|
||||
session.set_task(task)
|
||||
logger.info(f"Agent task created and set on session")
|
||||
else:
|
||||
logger.error("No agent_executor configured! Cannot process message.")
|
||||
# Send error message to user
|
||||
error_msg = AgentMessage(
|
||||
session_id=session.session_id,
|
||||
target_channels=session.active_channels,
|
||||
content="Error: Agent system not initialized. Please check that ANTHROPIC_API_KEY is configured.",
|
||||
done=True
|
||||
)
|
||||
await self._send_to_channels(error_msg)
|
||||
|
||||
async def _execute_agent_and_stream(self, session: UserSession, message: UserMessage) -> None:
|
||||
"""Execute agent and stream responses back to channels.
|
||||
|
||||
Args:
|
||||
session: User session
|
||||
message: Triggering user message
|
||||
"""
|
||||
logger.info(f"_execute_agent_and_stream starting for session {session.session_id}")
|
||||
try:
|
||||
# Execute agent (returns async generator directly)
|
||||
logger.info("Calling agent_executor...")
|
||||
response_stream = self.agent_executor(session, message)
|
||||
logger.info("Agent executor returned response stream")
|
||||
|
||||
# Stream chunks back to active channels
|
||||
full_response = ""
|
||||
chunk_count = 0
|
||||
accumulated_metadata = {}
|
||||
|
||||
async for chunk in response_stream:
|
||||
# Handle dict response with metadata (from agent executor)
|
||||
if isinstance(chunk, dict):
|
||||
content = chunk.get("content", "")
|
||||
metadata = chunk.get("metadata", {})
|
||||
# Accumulate metadata (e.g., plot_urls)
|
||||
for key, value in metadata.items():
|
||||
if key == "plot_urls" and value:
|
||||
# Append to existing plot_urls
|
||||
if "plot_urls" not in accumulated_metadata:
|
||||
accumulated_metadata["plot_urls"] = []
|
||||
accumulated_metadata["plot_urls"].extend(value)
|
||||
logger.info(f"Accumulated plot_urls: {accumulated_metadata['plot_urls']}")
|
||||
else:
|
||||
accumulated_metadata[key] = value
|
||||
chunk = content
|
||||
|
||||
# Only send non-empty chunks
|
||||
if chunk:
|
||||
chunk_count += 1
|
||||
full_response += chunk
|
||||
logger.debug(f"Received chunk #{chunk_count}, length: {len(chunk)}")
|
||||
|
||||
# Send chunk to all active channels with accumulated metadata
|
||||
agent_msg = AgentMessage(
|
||||
session_id=session.session_id,
|
||||
target_channels=session.active_channels,
|
||||
content=chunk,
|
||||
stream_chunk=True,
|
||||
done=False,
|
||||
metadata=accumulated_metadata.copy()
|
||||
)
|
||||
await self._send_to_channels(agent_msg)
|
||||
|
||||
logger.info(f"Agent streaming completed, total chunks: {chunk_count}, response length: {len(full_response)}")
|
||||
|
||||
# Send final done message with all accumulated metadata
|
||||
agent_msg = AgentMessage(
|
||||
session_id=session.session_id,
|
||||
target_channels=session.active_channels,
|
||||
content="",
|
||||
stream_chunk=True,
|
||||
done=True,
|
||||
metadata=accumulated_metadata
|
||||
)
|
||||
await self._send_to_channels(agent_msg)
|
||||
logger.info(f"Sent final done message to channels with metadata: {accumulated_metadata}")
|
||||
|
||||
# Add to history
|
||||
session.add_message("assistant", full_response)
|
||||
logger.info("Assistant response added to history")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Task was interrupted
|
||||
logger.warning(f"Agent task interrupted for session {session.session_id}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Agent execution error: {e}", exc_info=True)
|
||||
# Send error message
|
||||
error_msg = AgentMessage(
|
||||
session_id=session.session_id,
|
||||
target_channels=session.active_channels,
|
||||
content=f"Error: {str(e)}",
|
||||
done=True
|
||||
)
|
||||
await self._send_to_channels(error_msg)
|
||||
finally:
|
||||
session.set_task(None)
|
||||
logger.info(f"Agent task completed for session {session.session_id}")
|
||||
|
||||
async def _send_to_channels(self, message: AgentMessage) -> None:
|
||||
"""Send message to specified channels.
|
||||
|
||||
Args:
|
||||
message: AgentMessage to send
|
||||
"""
|
||||
logger.debug(f"Sending message to {len(message.target_channels)} channels")
|
||||
for channel_id in message.target_channels:
|
||||
channel = self.channels.get(channel_id)
|
||||
if channel and channel.get_status().connected:
|
||||
try:
|
||||
await channel.send(message)
|
||||
logger.debug(f"Message sent to channel {channel_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending to channel {channel_id}: {e}", exc_info=True)
|
||||
else:
|
||||
logger.warning(f"Channel {channel_id} not found or not connected")
|
||||
|
||||
def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get session information.
|
||||
|
||||
Args:
|
||||
session_id: Session ID to query
|
||||
|
||||
Returns:
|
||||
Session info dict or None if not found
|
||||
"""
|
||||
session = self.sessions.get(session_id)
|
||||
return session.to_dict() if session else None
|
||||
|
||||
def get_active_sessions(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get all active sessions.
|
||||
|
||||
Returns:
|
||||
Dict mapping session_id to session info
|
||||
"""
|
||||
return {
|
||||
sid: session.to_dict()
|
||||
for sid, session in self.sessions.items()
|
||||
}
|
||||
Reference in New Issue
Block a user