"""LangGraph-based agent executor with hierarchical sub-agent routing."""
import asyncio
import json
import logging
from typing import AsyncIterator, Dict, Any, Optional

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import create_react_agent

from agent.memory import MemoryManager
from agent.prompts import build_system_prompt
from agent.routers import ROUTER_TOOLS, set_chart_agent, set_data_agent, set_automation_agent, set_research_agent
from agent.session import SessionManager
from agent.subagent import SubAgent
from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS, INDICATOR_TOOLS, RESEARCH_TOOLS, CHART_TOOLS, SHAPE_TOOLS, TRIGGER_TOOLS
from gateway.protocol import UserMessage as GatewayUserMessage
from gateway.user_session import UserSession
|
|
|
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AgentExecutor:
    """LangGraph-based agent executor with streaming support.

    Handles agent invocation, tool execution, and response streaming.
    Supports interruption for real-time user interaction; on cancellation
    or error the session checkpoint is cleared so the next turn starts
    fresh instead of resuming a broken state.
    """
|
|
|
|
def __init__(
|
|
self,
|
|
model_name: str = "claude-sonnet-4-20250514",
|
|
temperature: float = 0.7,
|
|
api_key: Optional[str] = None,
|
|
memory_manager: Optional[MemoryManager] = None,
|
|
base_dir: str = "."
|
|
):
|
|
"""Initialize agent executor.
|
|
|
|
Args:
|
|
model_name: Anthropic model name
|
|
temperature: Model temperature
|
|
api_key: Anthropic API key
|
|
memory_manager: MemoryManager instance
|
|
base_dir: Base directory for resolving paths
|
|
"""
|
|
self.model_name = model_name
|
|
self.temperature = temperature
|
|
self.api_key = api_key
|
|
self.base_dir = base_dir
|
|
|
|
# Initialize LLM
|
|
self.llm = ChatAnthropic(
|
|
model=model_name,
|
|
temperature=temperature,
|
|
api_key=api_key,
|
|
streaming=True
|
|
)
|
|
|
|
# Memory and session management
|
|
self.memory_manager = memory_manager or MemoryManager()
|
|
self.session_manager = SessionManager(self.memory_manager)
|
|
self.agent = None # Will be created after initialization
|
|
|
|
# Sub-agents (only if using hierarchical tools)
|
|
self.chart_agent = None
|
|
self.data_agent = None
|
|
self.automation_agent = None
|
|
self.research_agent = None
|
|
|
|
async def initialize(self) -> None:
|
|
"""Initialize the agent system."""
|
|
await self.memory_manager.initialize()
|
|
|
|
# Create agent with tools and LangGraph checkpointer
|
|
checkpointer = self.memory_manager.get_checkpointer()
|
|
|
|
# Create specialized sub-agents
|
|
logger.info("Initializing hierarchical agent architecture with sub-agents")
|
|
|
|
self.chart_agent = SubAgent(
|
|
name="chart",
|
|
soul_file="chart_agent.md",
|
|
tools=CHART_TOOLS + INDICATOR_TOOLS + SHAPE_TOOLS,
|
|
model_name=self.model_name,
|
|
temperature=self.temperature,
|
|
api_key=self.api_key,
|
|
base_dir=self.base_dir
|
|
)
|
|
|
|
self.data_agent = SubAgent(
|
|
name="data",
|
|
soul_file="data_agent.md",
|
|
tools=DATASOURCE_TOOLS,
|
|
model_name=self.model_name,
|
|
temperature=self.temperature,
|
|
api_key=self.api_key,
|
|
base_dir=self.base_dir
|
|
)
|
|
|
|
self.automation_agent = SubAgent(
|
|
name="automation",
|
|
soul_file="automation_agent.md",
|
|
tools=TRIGGER_TOOLS,
|
|
model_name=self.model_name,
|
|
temperature=self.temperature,
|
|
api_key=self.api_key,
|
|
base_dir=self.base_dir
|
|
)
|
|
|
|
self.research_agent = SubAgent(
|
|
name="research",
|
|
soul_file="research_agent.md",
|
|
tools=RESEARCH_TOOLS,
|
|
model_name=self.model_name,
|
|
temperature=self.temperature,
|
|
api_key=self.api_key,
|
|
base_dir=self.base_dir
|
|
)
|
|
|
|
# Set global sub-agent instances for router tools
|
|
set_chart_agent(self.chart_agent)
|
|
set_data_agent(self.data_agent)
|
|
set_automation_agent(self.automation_agent)
|
|
set_research_agent(self.research_agent)
|
|
|
|
# Main agent only gets SYNC_TOOLS (state management) and ROUTER_TOOLS
|
|
logger.info("Main agent using router tools (4 routers + sync tools)")
|
|
agent_tools = SYNC_TOOLS + ROUTER_TOOLS
|
|
|
|
# Create main agent without a static system prompt
|
|
# We'll pass the dynamic system prompt via state_modifier at runtime
|
|
self.agent = create_react_agent(
|
|
self.llm,
|
|
agent_tools,
|
|
checkpointer=checkpointer
|
|
)
|
|
|
|
logger.info(f"Agent initialized with {len(agent_tools)} tools")
|
|
|
|
async def _clear_checkpoint(self, session_id: str) -> None:
|
|
"""Clear the checkpoint for a session to prevent resuming from invalid state.
|
|
|
|
This is called when an error occurs during agent execution to ensure
|
|
the next interaction starts fresh instead of trying to resume from
|
|
a broken state (e.g., orphaned tool calls).
|
|
|
|
Args:
|
|
session_id: The session ID whose checkpoint should be cleared
|
|
"""
|
|
try:
|
|
checkpointer = self.memory_manager.get_checkpointer()
|
|
if checkpointer:
|
|
# Delete all checkpoints for this thread
|
|
# LangGraph uses thread_id as the key
|
|
# The checkpointer API doesn't have a direct delete method,
|
|
# but we can use the underlying connection
|
|
async with checkpointer.conn.cursor() as cur:
|
|
await cur.execute(
|
|
"DELETE FROM checkpoints WHERE thread_id = ?",
|
|
(session_id,)
|
|
)
|
|
await checkpointer.conn.commit()
|
|
logger.info(f"Cleared checkpoint for session {session_id}")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to clear checkpoint for session {session_id}: {e}")
|
|
|
|
async def execute(
|
|
self,
|
|
session: UserSession,
|
|
message: GatewayUserMessage
|
|
) -> AsyncIterator[str]:
|
|
"""Execute the agent and stream responses.
|
|
|
|
Args:
|
|
session: User session
|
|
message: User message
|
|
|
|
Yields:
|
|
Response chunks as they're generated
|
|
"""
|
|
logger.info(f"AgentExecutor.execute called for session {session.session_id}")
|
|
|
|
# Get session lock to prevent concurrent execution
|
|
lock = await self.session_manager.get_session_lock(session.session_id)
|
|
logger.info(f"Session lock acquired for {session.session_id}")
|
|
|
|
async with lock:
|
|
try:
|
|
# Build system prompt with current context
|
|
context = self.memory_manager.get_context_prompt()
|
|
system_prompt = build_system_prompt(context, session.active_channels)
|
|
|
|
# Build message history WITHOUT prepending system message
|
|
# The system prompt will be passed via state_modifier in the config
|
|
messages = []
|
|
history = session.get_history(limit=10)
|
|
logger.info(f"Building message history, {len(history)} messages in history")
|
|
|
|
for i, msg in enumerate(history):
|
|
logger.info(f"History message {i}: role={msg.role}, content_len={len(msg.content)}, content='{msg.content[:100]}'")
|
|
if msg.role == "user":
|
|
messages.append(HumanMessage(content=msg.content))
|
|
elif msg.role == "assistant":
|
|
messages.append(AIMessage(content=msg.content))
|
|
|
|
logger.info(f"Prepared {len(messages)} messages for agent (including system prompt)")
|
|
for i, msg in enumerate(messages):
|
|
msg_type = type(msg).__name__
|
|
content_preview = msg.content[:100] if msg.content else 'EMPTY'
|
|
logger.info(f"LangChain message {i}: type={msg_type}, content_len={len(msg.content)}, content='{content_preview}'")
|
|
|
|
# Prepare config with metadata and dynamic system prompt
|
|
# Pass system_prompt via state_modifier to avoid multiple system messages
|
|
config = RunnableConfig(
|
|
configurable={
|
|
"thread_id": session.session_id,
|
|
"state_modifier": system_prompt # Dynamic system prompt injection
|
|
},
|
|
metadata={
|
|
"session_id": session.session_id,
|
|
"user_id": session.user_id,
|
|
"active_channels": session.active_channels
|
|
}
|
|
)
|
|
logger.info(f"Agent config prepared: thread_id={session.session_id}")
|
|
|
|
# Invoke agent with streaming
|
|
logger.info("Starting agent.astream_events...")
|
|
full_response = ""
|
|
event_count = 0
|
|
chunk_count = 0
|
|
|
|
plot_urls = [] # Accumulate plot URLs from execute_python tool calls
|
|
|
|
async for event in self.agent.astream_events(
|
|
{"messages": messages},
|
|
config=config,
|
|
version="v2"
|
|
):
|
|
event_count += 1
|
|
|
|
# Check for cancellation
|
|
if asyncio.current_task().cancelled():
|
|
logger.warning("Agent execution cancelled")
|
|
break
|
|
|
|
# Log tool calls
|
|
if event["event"] == "on_tool_start":
|
|
tool_name = event.get("name", "unknown")
|
|
tool_input = event.get("data", {}).get("input", {})
|
|
logger.info(f"Tool call started: {tool_name} with input: {tool_input}")
|
|
|
|
elif event["event"] == "on_tool_end":
|
|
tool_name = event.get("name", "unknown")
|
|
tool_output = event.get("data", {}).get("output")
|
|
|
|
# LangChain may wrap the output in a ToolMessage with content field
|
|
# Try to extract the actual content from the ToolMessage
|
|
actual_output = tool_output
|
|
if hasattr(tool_output, "content"):
|
|
actual_output = tool_output.content
|
|
|
|
logger.info(f"Tool call completed: {tool_name} with output type: {type(actual_output)}")
|
|
|
|
# Extract plot_urls from execute_python tool results
|
|
if tool_name == "execute_python":
|
|
# Try to parse as JSON if it's a string
|
|
import json
|
|
if isinstance(actual_output, str):
|
|
try:
|
|
actual_output = json.loads(actual_output)
|
|
except (json.JSONDecodeError, ValueError):
|
|
logger.warning(f"Could not parse execute_python output as JSON: {actual_output[:200]}")
|
|
|
|
if isinstance(actual_output, dict):
|
|
tool_plot_urls = actual_output.get("plot_urls", [])
|
|
if tool_plot_urls:
|
|
logger.info(f"execute_python generated {len(tool_plot_urls)} plots: {tool_plot_urls}")
|
|
plot_urls.extend(tool_plot_urls)
|
|
# Yield metadata about plots immediately
|
|
yield {
|
|
"content": "",
|
|
"metadata": {"plot_urls": tool_plot_urls}
|
|
}
|
|
|
|
# Extract streaming tokens
|
|
elif event["event"] == "on_chat_model_stream":
|
|
chunk = event["data"]["chunk"]
|
|
if hasattr(chunk, "content") and chunk.content:
|
|
content = chunk.content
|
|
# Handle both string and list content
|
|
if isinstance(content, list):
|
|
# Extract text from content blocks
|
|
text_parts = []
|
|
for block in content:
|
|
if isinstance(block, dict) and "text" in block:
|
|
text_parts.append(block["text"])
|
|
elif hasattr(block, "text"):
|
|
text_parts.append(block.text)
|
|
content = "".join(text_parts)
|
|
|
|
if content: # Only yield non-empty content
|
|
full_response += content
|
|
chunk_count += 1
|
|
logger.debug(f"Yielding content chunk #{chunk_count}")
|
|
yield content
|
|
|
|
logger.info(f"Agent streaming complete: {event_count} events, {chunk_count} content chunks, {len(full_response)} chars")
|
|
|
|
# Save to persistent memory
|
|
logger.info(f"Saving assistant message to persistent memory")
|
|
await self.session_manager.save_message(
|
|
session.session_id,
|
|
"assistant",
|
|
full_response
|
|
)
|
|
logger.info("Assistant message saved to memory")
|
|
|
|
except asyncio.CancelledError:
|
|
logger.warning(f"Agent execution cancelled for session {session.session_id}")
|
|
# Clear checkpoint on cancellation to prevent orphaned tool calls
|
|
await self._clear_checkpoint(session.session_id)
|
|
raise
|
|
except Exception as e:
|
|
error_msg = f"Agent execution error: {str(e)}"
|
|
logger.error(error_msg, exc_info=True)
|
|
|
|
# Clear checkpoint on error to prevent invalid state (e.g., orphaned tool calls)
|
|
# This ensures the next interaction starts fresh instead of trying to resume
|
|
# from a broken state
|
|
await self._clear_checkpoint(session.session_id)
|
|
|
|
yield error_msg
|
|
|
|
|
|
def create_agent(
    model_name: str = "claude-sonnet-4-20250514",
    temperature: float = 0.7,
    api_key: Optional[str] = None,
    checkpoint_db_path: str = "data/checkpoints.db",
    chroma_db_path: str = "data/chroma",
    embedding_model: str = "all-MiniLM-L6-v2",
    context_docs_dir: str = "doc",
    base_dir: str = "."
) -> AgentExecutor:
    """Create an agent executor wired to a fresh memory manager.

    BUGFIX(docs): the previous docstring claimed the returned executor was
    "initialized", but AgentExecutor.initialize() is async and is never
    awaited here — the caller MUST `await executor.initialize()` before use.

    Args:
        model_name: Anthropic model name
        temperature: Model temperature
        api_key: Anthropic API key
        checkpoint_db_path: Path to LangGraph checkpoint SQLite DB
        chroma_db_path: Path to ChromaDB storage directory
        embedding_model: Sentence-transformers model name
        context_docs_dir: Directory with context markdown files
        base_dir: Base directory for resolving paths

    Returns:
        AgentExecutor with hierarchical tool routing, NOT yet initialized;
        await executor.initialize() before the first execute() call.
    """
    # Memory manager owns checkpoints, the vector store, and context docs.
    memory_manager = MemoryManager(
        checkpoint_db_path=checkpoint_db_path,
        chroma_db_path=chroma_db_path,
        embedding_model=embedding_model,
        context_docs_dir=context_docs_dir,
        base_dir=base_dir
    )

    # The executor defers agent construction until initialize() is awaited.
    executor = AgentExecutor(
        model_name=model_name,
        temperature=temperature,
        api_key=api_key,
        memory_manager=memory_manager,
        base_dir=base_dir
    )

    return executor
|