Files
ai/backend.old/src/agent/core.py
2026-03-11 18:47:11 -04:00

380 lines
16 KiB
Python

import asyncio
import logging
from typing import AsyncIterator, Dict, Any, Optional
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import create_react_agent
from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS, INDICATOR_TOOLS, RESEARCH_TOOLS, CHART_TOOLS, SHAPE_TOOLS, TRIGGER_TOOLS
from agent.memory import MemoryManager
from agent.session import SessionManager
from agent.prompts import build_system_prompt
from agent.subagent import SubAgent
from agent.routers import ROUTER_TOOLS, set_chart_agent, set_data_agent, set_automation_agent, set_research_agent
from gateway.user_session import UserSession
from gateway.protocol import UserMessage as GatewayUserMessage
logger = logging.getLogger(__name__)
class AgentExecutor:
    """LangGraph-based agent executor with streaming support.

    Handles agent invocation, tool execution, and response streaming.
    Supports interruption for real-time user interaction.
    """

    def __init__(
        self,
        model_name: str = "claude-sonnet-4-20250514",
        temperature: float = 0.7,
        api_key: Optional[str] = None,
        memory_manager: Optional[MemoryManager] = None,
        base_dir: str = "."
    ):
        """Initialize agent executor.

        Args:
            model_name: Anthropic model name
            temperature: Model temperature
            api_key: Anthropic API key
            memory_manager: MemoryManager instance
            base_dir: Base directory for resolving paths
        """
        # Plain configuration, retained so sub-agents built later in
        # initialize() can share the exact same model settings.
        self.model_name = model_name
        self.temperature = temperature
        self.api_key = api_key
        self.base_dir = base_dir

        # Streaming LLM backing the main ReAct agent.
        self.llm = ChatAnthropic(
            model=model_name,
            temperature=temperature,
            api_key=api_key,
            streaming=True
        )

        # Persistent memory, with per-session bookkeeping layered on top.
        self.memory_manager = memory_manager or MemoryManager()
        self.session_manager = SessionManager(self.memory_manager)

        # The LangGraph agent needs async setup, so it is built in initialize().
        self.agent = None

        # Specialized sub-agents; populated by initialize() as well.
        self.chart_agent = None
        self.data_agent = None
        self.automation_agent = None
        self.research_agent = None
async def initialize(self) -> None:
"""Initialize the agent system."""
await self.memory_manager.initialize()
# Create agent with tools and LangGraph checkpointer
checkpointer = self.memory_manager.get_checkpointer()
# Create specialized sub-agents
logger.info("Initializing hierarchical agent architecture with sub-agents")
self.chart_agent = SubAgent(
name="chart",
soul_file="chart_agent.md",
tools=CHART_TOOLS + INDICATOR_TOOLS + SHAPE_TOOLS,
model_name=self.model_name,
temperature=self.temperature,
api_key=self.api_key,
base_dir=self.base_dir
)
self.data_agent = SubAgent(
name="data",
soul_file="data_agent.md",
tools=DATASOURCE_TOOLS,
model_name=self.model_name,
temperature=self.temperature,
api_key=self.api_key,
base_dir=self.base_dir
)
self.automation_agent = SubAgent(
name="automation",
soul_file="automation_agent.md",
tools=TRIGGER_TOOLS,
model_name=self.model_name,
temperature=self.temperature,
api_key=self.api_key,
base_dir=self.base_dir
)
self.research_agent = SubAgent(
name="research",
soul_file="research_agent.md",
tools=RESEARCH_TOOLS,
model_name=self.model_name,
temperature=self.temperature,
api_key=self.api_key,
base_dir=self.base_dir
)
# Set global sub-agent instances for router tools
set_chart_agent(self.chart_agent)
set_data_agent(self.data_agent)
set_automation_agent(self.automation_agent)
set_research_agent(self.research_agent)
# Main agent only gets SYNC_TOOLS (state management) and ROUTER_TOOLS
logger.info("Main agent using router tools (4 routers + sync tools)")
agent_tools = SYNC_TOOLS + ROUTER_TOOLS
# Create main agent without a static system prompt
# We'll pass the dynamic system prompt via state_modifier at runtime
self.agent = create_react_agent(
self.llm,
agent_tools,
checkpointer=checkpointer
)
logger.info(f"Agent initialized with {len(agent_tools)} tools")
async def _clear_checkpoint(self, session_id: str) -> None:
"""Clear the checkpoint for a session to prevent resuming from invalid state.
This is called when an error occurs during agent execution to ensure
the next interaction starts fresh instead of trying to resume from
a broken state (e.g., orphaned tool calls).
Args:
session_id: The session ID whose checkpoint should be cleared
"""
try:
checkpointer = self.memory_manager.get_checkpointer()
if checkpointer:
# Delete all checkpoints for this thread
# LangGraph uses thread_id as the key
# The checkpointer API doesn't have a direct delete method,
# but we can use the underlying connection
async with checkpointer.conn.cursor() as cur:
await cur.execute(
"DELETE FROM checkpoints WHERE thread_id = ?",
(session_id,)
)
await checkpointer.conn.commit()
logger.info(f"Cleared checkpoint for session {session_id}")
except Exception as e:
logger.warning(f"Failed to clear checkpoint for session {session_id}: {e}")
async def execute(
self,
session: UserSession,
message: GatewayUserMessage
) -> AsyncIterator[str]:
"""Execute the agent and stream responses.
Args:
session: User session
message: User message
Yields:
Response chunks as they're generated
"""
logger.info(f"AgentExecutor.execute called for session {session.session_id}")
# Get session lock to prevent concurrent execution
lock = await self.session_manager.get_session_lock(session.session_id)
logger.info(f"Session lock acquired for {session.session_id}")
async with lock:
try:
# Build system prompt with current context
context = self.memory_manager.get_context_prompt()
system_prompt = build_system_prompt(context, session.active_channels)
# Build message history WITHOUT prepending system message
# The system prompt will be passed via state_modifier in the config
messages = []
history = session.get_history(limit=10)
logger.info(f"Building message history, {len(history)} messages in history")
for i, msg in enumerate(history):
logger.info(f"History message {i}: role={msg.role}, content_len={len(msg.content)}, content='{msg.content[:100]}'")
if msg.role == "user":
messages.append(HumanMessage(content=msg.content))
elif msg.role == "assistant":
messages.append(AIMessage(content=msg.content))
logger.info(f"Prepared {len(messages)} messages for agent (including system prompt)")
for i, msg in enumerate(messages):
msg_type = type(msg).__name__
content_preview = msg.content[:100] if msg.content else 'EMPTY'
logger.info(f"LangChain message {i}: type={msg_type}, content_len={len(msg.content)}, content='{content_preview}'")
# Prepare config with metadata and dynamic system prompt
# Pass system_prompt via state_modifier to avoid multiple system messages
config = RunnableConfig(
configurable={
"thread_id": session.session_id,
"state_modifier": system_prompt # Dynamic system prompt injection
},
metadata={
"session_id": session.session_id,
"user_id": session.user_id,
"active_channels": session.active_channels
}
)
logger.info(f"Agent config prepared: thread_id={session.session_id}")
# Invoke agent with streaming
logger.info("Starting agent.astream_events...")
full_response = ""
event_count = 0
chunk_count = 0
plot_urls = [] # Accumulate plot URLs from execute_python tool calls
async for event in self.agent.astream_events(
{"messages": messages},
config=config,
version="v2"
):
event_count += 1
# Check for cancellation
if asyncio.current_task().cancelled():
logger.warning("Agent execution cancelled")
break
# Log tool calls
if event["event"] == "on_tool_start":
tool_name = event.get("name", "unknown")
tool_input = event.get("data", {}).get("input", {})
logger.info(f"Tool call started: {tool_name} with input: {tool_input}")
elif event["event"] == "on_tool_end":
tool_name = event.get("name", "unknown")
tool_output = event.get("data", {}).get("output")
# LangChain may wrap the output in a ToolMessage with content field
# Try to extract the actual content from the ToolMessage
actual_output = tool_output
if hasattr(tool_output, "content"):
actual_output = tool_output.content
logger.info(f"Tool call completed: {tool_name} with output type: {type(actual_output)}")
# Extract plot_urls from execute_python tool results
if tool_name == "execute_python":
# Try to parse as JSON if it's a string
import json
if isinstance(actual_output, str):
try:
actual_output = json.loads(actual_output)
except (json.JSONDecodeError, ValueError):
logger.warning(f"Could not parse execute_python output as JSON: {actual_output[:200]}")
if isinstance(actual_output, dict):
tool_plot_urls = actual_output.get("plot_urls", [])
if tool_plot_urls:
logger.info(f"execute_python generated {len(tool_plot_urls)} plots: {tool_plot_urls}")
plot_urls.extend(tool_plot_urls)
# Yield metadata about plots immediately
yield {
"content": "",
"metadata": {"plot_urls": tool_plot_urls}
}
# Extract streaming tokens
elif event["event"] == "on_chat_model_stream":
chunk = event["data"]["chunk"]
if hasattr(chunk, "content") and chunk.content:
content = chunk.content
# Handle both string and list content
if isinstance(content, list):
# Extract text from content blocks
text_parts = []
for block in content:
if isinstance(block, dict) and "text" in block:
text_parts.append(block["text"])
elif hasattr(block, "text"):
text_parts.append(block.text)
content = "".join(text_parts)
if content: # Only yield non-empty content
full_response += content
chunk_count += 1
logger.debug(f"Yielding content chunk #{chunk_count}")
yield content
logger.info(f"Agent streaming complete: {event_count} events, {chunk_count} content chunks, {len(full_response)} chars")
# Save to persistent memory
logger.info(f"Saving assistant message to persistent memory")
await self.session_manager.save_message(
session.session_id,
"assistant",
full_response
)
logger.info("Assistant message saved to memory")
except asyncio.CancelledError:
logger.warning(f"Agent execution cancelled for session {session.session_id}")
# Clear checkpoint on cancellation to prevent orphaned tool calls
await self._clear_checkpoint(session.session_id)
raise
except Exception as e:
error_msg = f"Agent execution error: {str(e)}"
logger.error(error_msg, exc_info=True)
# Clear checkpoint on error to prevent invalid state (e.g., orphaned tool calls)
# This ensures the next interaction starts fresh instead of trying to resume
# from a broken state
await self._clear_checkpoint(session.session_id)
yield error_msg
def create_agent(
    model_name: str = "claude-sonnet-4-20250514",
    temperature: float = 0.7,
    api_key: Optional[str] = None,
    checkpoint_db_path: str = "data/checkpoints.db",
    chroma_db_path: str = "data/chroma",
    embedding_model: str = "all-MiniLM-L6-v2",
    context_docs_dir: str = "doc",
    base_dir: str = "."
) -> AgentExecutor:
    """Create and initialize an agent executor.

    Args:
        model_name: Anthropic model name
        temperature: Model temperature
        api_key: Anthropic API key
        checkpoint_db_path: Path to LangGraph checkpoint SQLite DB
        chroma_db_path: Path to ChromaDB storage directory
        embedding_model: Sentence-transformers model name
        context_docs_dir: Directory with context markdown files
        base_dir: Base directory for resolving paths

    Returns:
        Initialized AgentExecutor with hierarchical tool routing
    """
    # Wire up persistent memory first; the executor takes ownership of it.
    memory = MemoryManager(
        checkpoint_db_path=checkpoint_db_path,
        chroma_db_path=chroma_db_path,
        embedding_model=embedding_model,
        context_docs_dir=context_docs_dir,
        base_dir=base_dir
    )
    # NOTE(review): the executor's async initialize() is NOT awaited here —
    # presumably the caller does that; confirm.
    return AgentExecutor(
        model_name=model_name,
        temperature=temperature,
        api_key=api_key,
        memory_manager=memory,
        base_dir=base_dir
    )