import asyncio
import inspect
import logging
from typing import Dict, Optional, Callable, Awaitable, Any, AsyncIterator, Union
from datetime import datetime

from gateway.channels.base import Channel
from gateway.user_session import UserSession
from gateway.protocol import UserMessage, AgentMessage

logger = logging.getLogger(__name__)

# An agent executor is called as ``executor(session, message)``. It may return
# an async iterator of chunks directly (e.g. an async generator function), or an
# awaitable resolving to one (an ``async def`` that *returns* an iterator).
# Chunks are either ``str`` or ``{"content": str, "metadata": dict}``.
_AgentExecutor = Callable[
    [UserSession, UserMessage],
    Union[AsyncIterator[Any], Awaitable[AsyncIterator[Any]]],
]


class Gateway:
    """Central hub for routing messages between users, channels, and the agent.

    The Gateway:
    - Maintains active channels and user sessions
    - Routes user messages to the agent
    - Routes agent responses back to appropriate channels
    - Handles interruption of in-flight agent tasks
    """

    def __init__(self):
        self.channels: Dict[str, Channel] = {}
        self.sessions: Dict[str, UserSession] = {}
        # Unset until set_agent_executor() is called; route_user_message()
        # reports an error to the user while it is None.
        self.agent_executor: Optional[_AgentExecutor] = None

    def set_agent_executor(self, executor: _AgentExecutor) -> None:
        """Set the agent executor function.

        Args:
            executor: Callable taking (session, message). It must return an
                async iterator of response chunks, or an awaitable that resolves
                to one. Each chunk is a ``str`` or a dict with ``"content"`` and
                optional ``"metadata"`` keys.
        """
        self.agent_executor = executor

    def register_channel(self, channel: Channel) -> None:
        """Register a new communication channel.

        Args:
            channel: Channel instance to register; it is keyed by its ``channel_id``
                and replaces any channel previously registered under that ID.
        """
        self.channels[channel.channel_id] = channel

    def unregister_channel(self, channel_id: str) -> None:
        """Unregister a channel and detach it from every session.

        Unknown IDs are ignored.

        Args:
            channel_id: ID of channel to remove
        """
        if channel_id in self.channels:
            # Detach from sessions first so no session keeps targeting a
            # channel the gateway no longer knows about.
            for session in self.sessions.values():
                session.remove_channel(channel_id)
            del self.channels[channel_id]

    def get_or_create_session(self, session_id: str, user_id: str) -> UserSession:
        """Get existing session or create a new one.

        Args:
            session_id: Unique session identifier
            user_id: User identifier (only used when a new session is created)

        Returns:
            UserSession instance
        """
        if session_id not in self.sessions:
            self.sessions[session_id] = UserSession(session_id, user_id)
        return self.sessions[session_id]

    async def route_user_message(self, message: UserMessage) -> None:
        """Route a user message to the agent.

        If the agent is currently processing a message, it will be interrupted
        and restarted with the new message appended to the conversation. A
        message with blank content and no attachments only interrupts (a
        "stop" request) and does not start a new agent round.

        Args:
            message: UserMessage from a channel
        """
        logger.info(f"route_user_message called - session: {message.session_id}, channel: {message.channel_id}")

        # NOTE(review): the session ID is also passed as the user ID; confirm
        # whether UserMessage carries a real user identifier that belongs here.
        session = self.get_or_create_session(message.session_id, message.session_id)
        logger.info(f"Session obtained/created: {message.session_id}, channels: {session.active_channels}")

        # Ensure responses are delivered back to the channel that sent this message.
        session.add_channel(message.channel_id)
        logger.info(f"Channel added to session: {message.channel_id}")

        # A new message always supersedes an in-flight agent round.
        if session.is_busy():
            logger.info("Session is busy, interrupting existing task")
            await session.interrupt()

        # An empty message is a pure stop request: interrupt only, no new round.
        if not message.content.strip() and not message.attachments:
            logger.info("Received stop interrupt (empty message), not starting new agent round")
            return

        session.add_message("user", message.content, message.channel_id)
        logger.info(f"User message added to history, history length: {len(session.get_history())}")

        if self.agent_executor:
            logger.info("Starting agent task execution")
            task = asyncio.create_task(
                self._execute_agent_and_stream(session, message)
            )
            session.set_task(task)
            logger.info("Agent task created and set on session")
        else:
            logger.error("No agent_executor configured! Cannot process message.")
            error_msg = AgentMessage(
                session_id=session.session_id,
                target_channels=session.active_channels,
                content=(
                    "Error: Agent system not initialized. "
                    "Please check that ANTHROPIC_API_KEY is configured."
                ),
                done=True,
            )
            await self._send_to_channels(error_msg)

    @staticmethod
    def _accumulate_metadata(accumulated: Dict[str, Any], metadata: Dict[str, Any]) -> None:
        """Merge one chunk's metadata into the running metadata of a response.

        ``plot_urls`` values are concatenated across chunks; every other key is
        overwritten by the most recent value.

        Args:
            accumulated: Running metadata dict, updated in place.
            metadata: Metadata attached to the current chunk.
        """
        for key, value in metadata.items():
            if key != "plot_urls":
                accumulated[key] = value
                continue
            # Build a new list rather than extending in place: messages already
            # handed to channels hold shallow copies of ``accumulated`` that
            # share this list and must not change after the fact. A falsy value
            # ([] / None) must not discard URLs gathered from earlier chunks.
            accumulated["plot_urls"] = list(accumulated.get("plot_urls") or []) + list(value or [])
            if value:
                logger.info(f"Accumulated plot_urls: {accumulated['plot_urls']}")

    async def _execute_agent_and_stream(self, session: UserSession, message: UserMessage) -> None:
        """Execute agent and stream responses back to channels.

        Each non-empty chunk is forwarded to the session's active channels as a
        ``stream_chunk`` message, followed by a final ``done`` message that
        carries all accumulated metadata. The full response is then appended to
        the session history. On failure an error message is sent instead;
        cancellation (interrupt) is propagated without sending anything.

        Args:
            session: User session
            message: Triggering user message
        """
        logger.info(f"_execute_agent_and_stream starting for session {session.session_id}")

        try:
            logger.info("Calling agent_executor...")
            response_stream = self.agent_executor(session, message)
            # An ``async def`` executor that *returns* an iterator yields a
            # coroutine, which ``async for`` cannot consume. Async generators
            # are not awaitable, so they pass through unchanged.
            if not hasattr(response_stream, "__aiter__") and inspect.isawaitable(response_stream):
                response_stream = await response_stream
            logger.info("Agent executor returned response stream")

            response_parts = []  # joined once at the end; avoids quadratic +=
            chunk_count = 0
            accumulated_metadata: Dict[str, Any] = {}

            async for chunk in response_stream:
                if isinstance(chunk, dict):
                    # "metadata" may be present but None; treat that as empty.
                    self._accumulate_metadata(accumulated_metadata, chunk.get("metadata") or {})
                    chunk = chunk.get("content", "")

                if not chunk:
                    continue  # never forward empty chunks

                chunk_count += 1
                response_parts.append(chunk)
                logger.debug(f"Received chunk #{chunk_count}, length: {len(chunk)}")

                agent_msg = AgentMessage(
                    session_id=session.session_id,
                    target_channels=session.active_channels,
                    content=chunk,
                    stream_chunk=True,
                    done=False,
                    # Snapshot so later merges don't alter this message's metadata.
                    metadata=accumulated_metadata.copy(),
                )
                await self._send_to_channels(agent_msg)

            full_response = "".join(response_parts)
            logger.info(f"Agent streaming completed, total chunks: {chunk_count}, response length: {len(full_response)}")

            agent_msg = AgentMessage(
                session_id=session.session_id,
                target_channels=session.active_channels,
                content="",
                stream_chunk=True,
                done=True,
                metadata=accumulated_metadata,
            )
            await self._send_to_channels(agent_msg)
            logger.info(f"Sent final done message to channels with metadata: {accumulated_metadata}")

            session.add_message("assistant", full_response)
            logger.info("Assistant response added to history")

        except asyncio.CancelledError:
            # Interrupted by a newer message or a stop request; partial output
            # is intentionally not added to history.
            logger.warning(f"Agent task interrupted for session {session.session_id}")
            raise

        except Exception as e:
            logger.error(f"Agent execution error: {e}", exc_info=True)
            error_msg = AgentMessage(
                session_id=session.session_id,
                target_channels=session.active_channels,
                content=f"Error: {str(e)}",
                done=True,
            )
            await self._send_to_channels(error_msg)

        finally:
            # NOTE(review): if session.interrupt() does not wait for the
            # cancelled task to finish, this could clear a newer task that was
            # set after the interrupt; confirm interrupt() awaits cancellation.
            session.set_task(None)
            logger.info(f"Agent task completed for session {session.session_id}")

    async def _send_to_channels(self, message: AgentMessage) -> None:
        """Send message to specified channels.

        Channels that are unknown or report themselves disconnected are skipped
        with a warning; a failure on one channel is logged and does not stop
        delivery to the others.

        Args:
            message: AgentMessage to send
        """
        # Snapshot the targets: they may be the session's live channel
        # collection, which unregister_channel() can mutate while this
        # coroutine is suspended in channel.send().
        target_channels = list(message.target_channels)
        logger.debug(f"Sending message to {len(target_channels)} channels")

        for channel_id in target_channels:
            channel = self.channels.get(channel_id)
            if channel and channel.get_status().connected:
                try:
                    await channel.send(message)
                    logger.debug(f"Message sent to channel {channel_id}")
                except Exception as e:
                    logger.error(f"Error sending to channel {channel_id}: {e}", exc_info=True)
            else:
                logger.warning(f"Channel {channel_id} not found or not connected")

    def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Get session information.

        Args:
            session_id: Session ID to query

        Returns:
            Session info dict (from ``UserSession.to_dict``) or None if not found
        """
        session = self.sessions.get(session_id)
        return session.to_dict() if session else None

    def get_active_sessions(self) -> Dict[str, Dict[str, Any]]:
        """Get all active sessions.

        Returns:
            Dict mapping session_id to session info
        """
        return {
            sid: session.to_dict()
            for sid, session in self.sessions.items()
        }