"""LangGraph-based agent execution with hierarchical sub-agent routing."""

import asyncio
import json
import logging
from typing import Any, AsyncIterator, Dict, Optional, Union

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import create_react_agent

from agent.memory import MemoryManager
from agent.prompts import build_system_prompt
from agent.routers import (
    ROUTER_TOOLS,
    set_automation_agent,
    set_chart_agent,
    set_data_agent,
    set_research_agent,
)
from agent.session import SessionManager
from agent.subagent import SubAgent
from agent.tools import (
    CHART_TOOLS,
    DATASOURCE_TOOLS,
    INDICATOR_TOOLS,
    RESEARCH_TOOLS,
    SHAPE_TOOLS,
    SYNC_TOOLS,
    TRIGGER_TOOLS,
)
from gateway.protocol import UserMessage as GatewayUserMessage
from gateway.user_session import UserSession

logger = logging.getLogger(__name__)


class AgentExecutor:
    """LangGraph-based agent executor with streaming support.

    Handles agent invocation, tool execution, and response streaming.
    Supports interruption for real-time user interaction.
    """

    def __init__(
        self,
        model_name: str = "claude-sonnet-4-20250514",
        temperature: float = 0.7,
        api_key: Optional[str] = None,
        memory_manager: Optional[MemoryManager] = None,
        base_dir: str = "."
    ):
        """Initialize agent executor.

        Args:
            model_name: Anthropic model name
            temperature: Model temperature
            api_key: Anthropic API key
            memory_manager: MemoryManager instance; a default one is created
                when omitted
            base_dir: Base directory for resolving paths
        """
        self.model_name = model_name
        self.temperature = temperature
        self.api_key = api_key
        self.base_dir = base_dir

        # Streaming LLM shared by the main agent; sub-agents build their own.
        self.llm = ChatAnthropic(
            model=model_name,
            temperature=temperature,
            api_key=api_key,
            streaming=True
        )

        # Memory and session management
        self.memory_manager = memory_manager or MemoryManager()
        self.session_manager = SessionManager(self.memory_manager)

        # Main agent graph, built lazily in initialize()
        self.agent = None

        # Specialized sub-agents, also created in initialize()
        self.chart_agent = None
        self.data_agent = None
        self.automation_agent = None
        self.research_agent = None

    def _make_subagent(self, name: str, soul_file: str, tools) -> SubAgent:
        """Build a SubAgent that shares this executor's model configuration."""
        return SubAgent(
            name=name,
            soul_file=soul_file,
            tools=tools,
            model_name=self.model_name,
            temperature=self.temperature,
            api_key=self.api_key,
            base_dir=self.base_dir
        )

    async def initialize(self) -> None:
        """Initialize the agent system: memory, sub-agents, and main agent."""
        await self.memory_manager.initialize()

        # LangGraph checkpointer persists conversation state per thread_id
        checkpointer = self.memory_manager.get_checkpointer()

        # Create specialized sub-agents
        logger.info("Initializing hierarchical agent architecture with sub-agents")
        self.chart_agent = self._make_subagent(
            "chart", "chart_agent.md",
            CHART_TOOLS + INDICATOR_TOOLS + SHAPE_TOOLS
        )
        self.data_agent = self._make_subagent(
            "data", "data_agent.md", DATASOURCE_TOOLS
        )
        self.automation_agent = self._make_subagent(
            "automation", "automation_agent.md", TRIGGER_TOOLS
        )
        self.research_agent = self._make_subagent(
            "research", "research_agent.md", RESEARCH_TOOLS
        )

        # Router tools dispatch to these globally-registered sub-agent instances
        set_chart_agent(self.chart_agent)
        set_data_agent(self.data_agent)
        set_automation_agent(self.automation_agent)
        set_research_agent(self.research_agent)

        # Main agent only gets SYNC_TOOLS (state management) and ROUTER_TOOLS
        logger.info("Main agent using router tools (4 routers + sync tools)")
        agent_tools = SYNC_TOOLS + ROUTER_TOOLS

        # Create main agent without a static system prompt; the dynamic
        # system prompt is supplied per-invocation via config (see execute()).
        self.agent = create_react_agent(
            self.llm,
            agent_tools,
            checkpointer=checkpointer
        )

        logger.info("Agent initialized with %d tools", len(agent_tools))

    async def _clear_checkpoint(self, session_id: str) -> None:
        """Clear the checkpoint for a session to prevent resuming from invalid state.

        This is called when an error occurs during agent execution to ensure
        the next interaction starts fresh instead of trying to resume from a
        broken state (e.g., orphaned tool calls).

        Args:
            session_id: The session ID whose checkpoint should be cleared
        """
        try:
            checkpointer = self.memory_manager.get_checkpointer()
            if checkpointer:
                # The checkpointer API has no direct delete method, so go
                # through the underlying connection. LangGraph keys
                # checkpoints by thread_id, which we set to the session ID.
                # NOTE(review): assumes an async SQLite-backed checkpointer
                # exposing `.conn` — confirm against MemoryManager.
                async with checkpointer.conn.cursor() as cur:
                    await cur.execute(
                        "DELETE FROM checkpoints WHERE thread_id = ?",
                        (session_id,)
                    )
                    await checkpointer.conn.commit()
                logger.info("Cleared checkpoint for session %s", session_id)
        except Exception as e:
            # Best-effort cleanup: a failed delete must not mask the original error.
            logger.warning(
                "Failed to clear checkpoint for session %s: %s", session_id, e
            )

    @staticmethod
    def _coerce_text(content) -> str:
        """Flatten a chat-model chunk's content (str or block list) to text."""
        if isinstance(content, list):
            # Anthropic-style content blocks: dicts or objects carrying "text"
            parts = []
            for block in content:
                if isinstance(block, dict) and "text" in block:
                    parts.append(block["text"])
                elif hasattr(block, "text"):
                    parts.append(block.text)
            return "".join(parts)
        return content

    @staticmethod
    def _extract_plot_urls(actual_output) -> list:
        """Extract the plot_urls list from an execute_python tool result.

        Accepts either a dict or a JSON string; returns [] when no plots
        were produced or the output cannot be parsed.
        """
        if isinstance(actual_output, str):
            try:
                actual_output = json.loads(actual_output)
            except (json.JSONDecodeError, ValueError):
                logger.warning(
                    "Could not parse execute_python output as JSON: %s",
                    actual_output[:200]
                )
                return []
        if isinstance(actual_output, dict):
            return actual_output.get("plot_urls", [])
        return []

    async def execute(
        self,
        session: UserSession,
        message: GatewayUserMessage
    ) -> AsyncIterator[Union[str, Dict[str, Any]]]:
        """Execute the agent and stream responses.

        Args:
            session: User session
            message: User message

        Yields:
            Text chunks as they are generated. Additionally yields dict
            payloads of the form
            ``{"content": "", "metadata": {"plot_urls": [...]}}`` whenever
            the execute_python tool produces plots.
        """
        logger.info(
            "AgentExecutor.execute called for session %s", session.session_id
        )

        # Serialize execution per session to prevent concurrent agent runs
        lock = await self.session_manager.get_session_lock(session.session_id)
        logger.info("Session lock acquired for %s", session.session_id)

        async with lock:
            try:
                # Build system prompt with current context
                context = self.memory_manager.get_context_prompt()
                system_prompt = build_system_prompt(context, session.active_channels)

                # Build message history WITHOUT prepending a system message;
                # the system prompt travels in the config instead.
                messages = []
                history = session.get_history(limit=10)
                logger.info(
                    "Building message history, %d messages in history", len(history)
                )
                for i, msg in enumerate(history):
                    logger.info(
                        "History message %d: role=%s, content_len=%d, content='%s'",
                        i, msg.role, len(msg.content), msg.content[:100]
                    )
                    if msg.role == "user":
                        messages.append(HumanMessage(content=msg.content))
                    elif msg.role == "assistant":
                        messages.append(AIMessage(content=msg.content))

                logger.info("Prepared %d messages for agent", len(messages))
                for i, msg in enumerate(messages):
                    msg_type = type(msg).__name__
                    content_preview = msg.content[:100] if msg.content else 'EMPTY'
                    logger.info(
                        "LangChain message %d: type=%s, content_len=%d, content='%s'",
                        i, msg_type, len(msg.content), content_preview
                    )

                # Prepare config with metadata and dynamic system prompt.
                # NOTE(review): the graph was built with no static prompt;
                # confirm create_react_agent actually consumes
                # "state_modifier" from configurable, otherwise the system
                # prompt is silently dropped.
                config = RunnableConfig(
                    configurable={
                        "thread_id": session.session_id,
                        "state_modifier": system_prompt  # dynamic system prompt
                    },
                    metadata={
                        "session_id": session.session_id,
                        "user_id": session.user_id,
                        "active_channels": session.active_channels
                    }
                )
                logger.info(
                    "Agent config prepared: thread_id=%s", session.session_id
                )

                # Invoke agent with streaming
                logger.info("Starting agent.astream_events...")
                full_response = ""
                event_count = 0
                chunk_count = 0
                plot_urls = []  # plot URLs accumulated across execute_python calls

                # Cancellation surfaces as asyncio.CancelledError at the
                # `async for` await point and is handled in the except clause
                # below. (Task.cancelled() is always False while the task is
                # still running, so polling it here would be dead code.)
                async for event in self.agent.astream_events(
                    {"messages": messages},
                    config=config,
                    version="v2"
                ):
                    event_count += 1
                    kind = event["event"]

                    if kind == "on_tool_start":
                        tool_name = event.get("name", "unknown")
                        tool_input = event.get("data", {}).get("input", {})
                        logger.info(
                            "Tool call started: %s with input: %s",
                            tool_name, tool_input
                        )

                    elif kind == "on_tool_end":
                        tool_name = event.get("name", "unknown")
                        tool_output = event.get("data", {}).get("output")
                        # LangChain may wrap the output in a ToolMessage;
                        # unwrap its content field when present.
                        actual_output = tool_output
                        if hasattr(tool_output, "content"):
                            actual_output = tool_output.content
                        logger.info(
                            "Tool call completed: %s with output type: %s",
                            tool_name, type(actual_output)
                        )

                        if tool_name == "execute_python":
                            tool_plot_urls = self._extract_plot_urls(actual_output)
                            if tool_plot_urls:
                                logger.info(
                                    "execute_python generated %d plots: %s",
                                    len(tool_plot_urls), tool_plot_urls
                                )
                                plot_urls.extend(tool_plot_urls)
                                # Surface plot metadata to the caller immediately
                                yield {
                                    "content": "",
                                    "metadata": {"plot_urls": tool_plot_urls}
                                }

                    elif kind == "on_chat_model_stream":
                        chunk = event["data"]["chunk"]
                        if hasattr(chunk, "content") and chunk.content:
                            content = self._coerce_text(chunk.content)
                            if content:  # only yield non-empty content
                                full_response += content
                                chunk_count += 1
                                logger.debug(
                                    "Yielding content chunk #%d", chunk_count
                                )
                                yield content

                logger.info(
                    "Agent streaming complete: %d events, %d content chunks, %d chars",
                    event_count, chunk_count, len(full_response)
                )

                # Save to persistent memory
                logger.info("Saving assistant message to persistent memory")
                await self.session_manager.save_message(
                    session.session_id,
                    "assistant",
                    full_response
                )
                logger.info("Assistant message saved to memory")

            except asyncio.CancelledError:
                logger.warning(
                    "Agent execution cancelled for session %s", session.session_id
                )
                # Clear checkpoint on cancellation to prevent orphaned tool calls
                await self._clear_checkpoint(session.session_id)
                raise
            except Exception as e:
                error_msg = f"Agent execution error: {str(e)}"
                logger.error(error_msg, exc_info=True)
                # Clear checkpoint on error to prevent invalid state (e.g.,
                # orphaned tool calls) so the next interaction starts fresh
                # instead of trying to resume from a broken state.
                await self._clear_checkpoint(session.session_id)
                yield error_msg


def create_agent(
    model_name: str = "claude-sonnet-4-20250514",
    temperature: float = 0.7,
    api_key: Optional[str] = None,
    checkpoint_db_path: str = "data/checkpoints.db",
    chroma_db_path: str = "data/chroma",
    embedding_model: str = "all-MiniLM-L6-v2",
    context_docs_dir: str = "doc",
    base_dir: str = "."
) -> AgentExecutor:
    """Create and initialize an agent executor.

    Note: the returned executor still needs `await executor.initialize()`
    before use — this factory only wires up the memory manager.

    Args:
        model_name: Anthropic model name
        temperature: Model temperature
        api_key: Anthropic API key
        checkpoint_db_path: Path to LangGraph checkpoint SQLite DB
        chroma_db_path: Path to ChromaDB storage directory
        embedding_model: Sentence-transformers model name
        context_docs_dir: Directory with context markdown files
        base_dir: Base directory for resolving paths

    Returns:
        AgentExecutor with hierarchical tool routing
    """
    memory_manager = MemoryManager(
        checkpoint_db_path=checkpoint_db_path,
        chroma_db_path=chroma_db_path,
        embedding_model=embedding_model,
        context_docs_dir=context_docs_dir,
        base_dir=base_dir
    )

    return AgentExecutor(
        model_name=model_name,
        temperature=temperature,
        api_key=api_key,
        memory_manager=memory_manager,
        base_dir=base_dir
    )