backend redesign
backend.old/src/agent/core.py · 379 additions · new file
@@ -0,0 +1,379 @@
import asyncio
import json
import logging
from typing import AsyncIterator, Dict, Any, Optional

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import create_react_agent

from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS, INDICATOR_TOOLS, RESEARCH_TOOLS, CHART_TOOLS, SHAPE_TOOLS, TRIGGER_TOOLS
from agent.memory import MemoryManager
from agent.session import SessionManager
from agent.prompts import build_system_prompt
from agent.subagent import SubAgent
from agent.routers import ROUTER_TOOLS, set_chart_agent, set_data_agent, set_automation_agent, set_research_agent
from gateway.user_session import UserSession
from gateway.protocol import UserMessage as GatewayUserMessage

logger = logging.getLogger(__name__)


class AgentExecutor:
    """LangGraph-based agent executor with streaming support.

    Handles agent invocation, tool execution, and response streaming.
    Supports interruption for real-time user interaction.
    """

    def __init__(
        self,
        model_name: str = "claude-sonnet-4-20250514",
        temperature: float = 0.7,
        api_key: Optional[str] = None,
        memory_manager: Optional[MemoryManager] = None,
        base_dir: str = "."
    ):
        """Initialize agent executor.

        Args:
            model_name: Anthropic model name
            temperature: Model temperature
            api_key: Anthropic API key
            memory_manager: MemoryManager instance
            base_dir: Base directory for resolving paths
        """
        self.model_name = model_name
        self.temperature = temperature
        self.api_key = api_key
        self.base_dir = base_dir

        # Initialize LLM
        self.llm = ChatAnthropic(
            model=model_name,
            temperature=temperature,
            api_key=api_key,
            streaming=True
        )

        # Memory and session management
        self.memory_manager = memory_manager or MemoryManager()
        self.session_manager = SessionManager(self.memory_manager)
        self.agent = None  # Will be created after initialization

        # Dynamic system prompt, set per invocation and injected via state_modifier
        self._system_prompt: Optional[str] = None

        # Sub-agents (only if using hierarchical tools)
        self.chart_agent = None
        self.data_agent = None
        self.automation_agent = None
        self.research_agent = None

    async def initialize(self) -> None:
        """Initialize the agent system."""
        await self.memory_manager.initialize()

        # Create agent with tools and LangGraph checkpointer
        checkpointer = self.memory_manager.get_checkpointer()

        # Create specialized sub-agents
        logger.info("Initializing hierarchical agent architecture with sub-agents")

        self.chart_agent = SubAgent(
            name="chart",
            soul_file="chart_agent.md",
            tools=CHART_TOOLS + INDICATOR_TOOLS + SHAPE_TOOLS,
            model_name=self.model_name,
            temperature=self.temperature,
            api_key=self.api_key,
            base_dir=self.base_dir
        )

        self.data_agent = SubAgent(
            name="data",
            soul_file="data_agent.md",
            tools=DATASOURCE_TOOLS,
            model_name=self.model_name,
            temperature=self.temperature,
            api_key=self.api_key,
            base_dir=self.base_dir
        )

        self.automation_agent = SubAgent(
            name="automation",
            soul_file="automation_agent.md",
            tools=TRIGGER_TOOLS,
            model_name=self.model_name,
            temperature=self.temperature,
            api_key=self.api_key,
            base_dir=self.base_dir
        )

        self.research_agent = SubAgent(
            name="research",
            soul_file="research_agent.md",
            tools=RESEARCH_TOOLS,
            model_name=self.model_name,
            temperature=self.temperature,
            api_key=self.api_key,
            base_dir=self.base_dir
        )

        # Set global sub-agent instances for router tools
        set_chart_agent(self.chart_agent)
        set_data_agent(self.data_agent)
        set_automation_agent(self.automation_agent)
        set_research_agent(self.research_agent)

        # Main agent only gets SYNC_TOOLS (state management) and ROUTER_TOOLS
        logger.info("Main agent using router tools (4 routers + sync tools)")
        agent_tools = SYNC_TOOLS + ROUTER_TOOLS

        # Create main agent without a static system prompt. The dynamic system
        # prompt is injected on every model call via the state_modifier callable
        # below; passing it through config["configurable"] would be silently
        # ignored by create_react_agent.
        self.agent = create_react_agent(
            self.llm,
            agent_tools,
            state_modifier=self._inject_system_prompt,
            checkpointer=checkpointer
        )

        logger.info(f"Agent initialized with {len(agent_tools)} tools")

    def _inject_system_prompt(self, state: Dict[str, Any]) -> list:
        """Prepend the current dynamic system prompt to the model input.

        Called by the agent on every model invocation. Returning the system
        message here, rather than appending it to the input messages, keeps
        it out of the checkpointed conversation history.
        """
        messages = state["messages"]
        if self._system_prompt:
            return [SystemMessage(content=self._system_prompt)] + list(messages)
        return list(messages)
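
The router tools registered above live in agent/routers as module-level delegates. A minimal sketch of the assumed pattern for one of the four routers; route_to_chart_agent and the SubAgent.run() coroutine are hypothetical names, not confirmed by this commit:

from langchain_core.tools import tool

_chart_agent = None  # populated by set_chart_agent() during initialize()

def set_chart_agent(agent) -> None:
    global _chart_agent
    _chart_agent = agent

@tool
async def route_to_chart_agent(task: str) -> str:
    """Delegate a charting task to the chart sub-agent and return its answer."""
    if _chart_agent is None:
        return "Error: chart agent not initialized"
    # Hypothetical SubAgent entry point; substitute the real one
    return await _chart_agent.run(task)
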
    async def _clear_checkpoint(self, session_id: str) -> None:
        """Clear the checkpoint for a session to prevent resuming from invalid state.

        Called when an error occurs during agent execution so that the next
        interaction starts fresh instead of trying to resume from a broken
        state (e.g., orphaned tool calls).

        Args:
            session_id: The session ID whose checkpoint should be cleared
        """
        try:
            checkpointer = self.memory_manager.get_checkpointer()
            if checkpointer:
                # LangGraph keys checkpoints by thread_id. The checkpointer API
                # has no direct delete method, so go through the underlying
                # SQLite connection
                async with checkpointer.conn.cursor() as cur:
                    await cur.execute(
                        "DELETE FROM checkpoints WHERE thread_id = ?",
                        (session_id,)
                    )
                await checkpointer.conn.commit()
                logger.info(f"Cleared checkpoint for session {session_id}")
        except Exception as e:
            logger.warning(f"Failed to clear checkpoint for session {session_id}: {e}")
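
If the underlying saver also records pending writes (recent langgraph-checkpoint-sqlite versions keep a separate writes table, also keyed by thread_id), a fuller cleanup would delete from both tables. A sketch, assuming that schema; verify the table names against the installed version before relying on it:

    async def _clear_checkpoint_full(self, session_id: str) -> None:
        """Variant of _clear_checkpoint that also drops pending writes."""
        checkpointer = self.memory_manager.get_checkpointer()
        if not checkpointer:
            return
        async with checkpointer.conn.cursor() as cur:
            # Table names are an assumption about the SQLite saver's schema
            for table in ("checkpoints", "writes"):
                await cur.execute(
                    f"DELETE FROM {table} WHERE thread_id = ?", (session_id,)
                )
        await checkpointer.conn.commit()
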
    async def execute(
        self,
        session: UserSession,
        message: GatewayUserMessage
    ) -> AsyncIterator[str | Dict[str, Any]]:
        """Execute the agent and stream responses.

        Args:
            session: User session
            message: User message

        Yields:
            Response text chunks as they're generated, plus dicts carrying
            plot metadata (see the plot_urls handling below)
        """
        logger.info(f"AgentExecutor.execute called for session {session.session_id}")

        # Get session lock to prevent concurrent execution
        lock = await self.session_manager.get_session_lock(session.session_id)
        logger.info(f"Session lock acquired for {session.session_id}")

        async with lock:
            try:
                # Build system prompt with current context. It is injected by
                # _inject_system_prompt on every model call rather than added
                # to the message history, so it never accumulates in the
                # checkpointed state
                context = self.memory_manager.get_context_prompt()
                self._system_prompt = build_system_prompt(context, session.active_channels)

                # Build message history without a system message
                messages = []
                history = session.get_history(limit=10)
                logger.info(f"Building message history, {len(history)} messages in history")

                for i, msg in enumerate(history):
                    logger.info(f"History message {i}: role={msg.role}, content_len={len(msg.content)}, content='{msg.content[:100]}'")
                    if msg.role == "user":
                        messages.append(HumanMessage(content=msg.content))
                    elif msg.role == "assistant":
                        messages.append(AIMessage(content=msg.content))

                logger.info(f"Prepared {len(messages)} messages for agent (system prompt injected separately)")
                for i, msg in enumerate(messages):
                    msg_type = type(msg).__name__
                    content_preview = msg.content[:100] if msg.content else 'EMPTY'
                    logger.info(f"LangChain message {i}: type={msg_type}, content_len={len(msg.content)}, content='{content_preview}'")

                # Prepare config with thread ID and metadata
                config = RunnableConfig(
                    configurable={
                        "thread_id": session.session_id
                    },
                    metadata={
                        "session_id": session.session_id,
                        "user_id": session.user_id,
                        "active_channels": session.active_channels
                    }
                )
                logger.info(f"Agent config prepared: thread_id={session.session_id}")

                # Invoke agent with streaming
                logger.info("Starting agent.astream_events...")
                full_response = ""
                event_count = 0
                chunk_count = 0

                plot_urls = []  # Accumulate plot URLs from execute_python tool calls

                async for event in self.agent.astream_events(
                    {"messages": messages},
                    config=config,
                    version="v2"
                ):
                    event_count += 1

                    # Cancellation is delivered as asyncio.CancelledError while
                    # awaiting the stream and is handled in the except below

                    # Log tool calls
                    if event["event"] == "on_tool_start":
                        tool_name = event.get("name", "unknown")
                        tool_input = event.get("data", {}).get("input", {})
                        logger.info(f"Tool call started: {tool_name} with input: {tool_input}")

                    elif event["event"] == "on_tool_end":
                        tool_name = event.get("name", "unknown")
                        tool_output = event.get("data", {}).get("output")

                        # LangChain may wrap the output in a ToolMessage with a
                        # content field; extract the actual content if so
                        actual_output = tool_output
                        if hasattr(tool_output, "content"):
                            actual_output = tool_output.content

                        logger.info(f"Tool call completed: {tool_name} with output type: {type(actual_output)}")

                        # Extract plot_urls from execute_python tool results
                        if tool_name == "execute_python":
                            # Try to parse as JSON if it's a string
                            if isinstance(actual_output, str):
                                try:
                                    actual_output = json.loads(actual_output)
                                except (json.JSONDecodeError, ValueError):
                                    logger.warning(f"Could not parse execute_python output as JSON: {actual_output[:200]}")

                            if isinstance(actual_output, dict):
                                tool_plot_urls = actual_output.get("plot_urls", [])
                                if tool_plot_urls:
                                    logger.info(f"execute_python generated {len(tool_plot_urls)} plots: {tool_plot_urls}")
                                    plot_urls.extend(tool_plot_urls)
                                    # Yield metadata about plots immediately
                                    yield {
                                        "content": "",
                                        "metadata": {"plot_urls": tool_plot_urls}
                                    }

                    # Extract streaming tokens
                    elif event["event"] == "on_chat_model_stream":
                        chunk = event["data"]["chunk"]
                        if hasattr(chunk, "content") and chunk.content:
                            content = chunk.content
                            # Handle both string and list content
                            if isinstance(content, list):
                                # Extract text from content blocks
                                text_parts = []
                                for block in content:
                                    if isinstance(block, dict) and "text" in block:
                                        text_parts.append(block["text"])
                                    elif hasattr(block, "text"):
                                        text_parts.append(block.text)
                                content = "".join(text_parts)

                            if content:  # Only yield non-empty content
                                full_response += content
                                chunk_count += 1
                                logger.debug(f"Yielding content chunk #{chunk_count}")
                                yield content

                logger.info(f"Agent streaming complete: {event_count} events, {chunk_count} content chunks, {len(full_response)} chars")

                # Save to persistent memory
                logger.info("Saving assistant message to persistent memory")
                await self.session_manager.save_message(
                    session.session_id,
                    "assistant",
                    full_response
                )
                logger.info("Assistant message saved to memory")

            except asyncio.CancelledError:
                logger.warning(f"Agent execution cancelled for session {session.session_id}")
                # Clear checkpoint on cancellation to prevent orphaned tool calls
                await self._clear_checkpoint(session.session_id)
                raise
            except Exception as e:
                error_msg = f"Agent execution error: {str(e)}"
                logger.error(error_msg, exc_info=True)

                # Clear checkpoint on error so the next interaction starts fresh
                # instead of resuming from a broken state (e.g., orphaned tool calls)
                await self._clear_checkpoint(session.session_id)

                yield error_msg
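
Because execute() yields plain text chunks interleaved with metadata dicts, consumers need a type check. A sketch of a consuming gateway handler; handle_user_message, send_text, and send_metadata are hypothetical names for whatever the gateway actually exposes:

async def handle_user_message(executor: AgentExecutor, session, message) -> None:
    async for chunk in executor.execute(session, message):
        if isinstance(chunk, dict):
            # Plot metadata emitted mid-stream by the execute_python handling
            await session.send_metadata(chunk["metadata"])
        else:
            await session.send_text(chunk)
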
def create_agent(
    model_name: str = "claude-sonnet-4-20250514",
    temperature: float = 0.7,
    api_key: Optional[str] = None,
    checkpoint_db_path: str = "data/checkpoints.db",
    chroma_db_path: str = "data/chroma",
    embedding_model: str = "all-MiniLM-L6-v2",
    context_docs_dir: str = "doc",
    base_dir: str = "."
) -> AgentExecutor:
    """Create an agent executor.

    The executor is not yet initialized; callers must await
    ``executor.initialize()`` before use.

    Args:
        model_name: Anthropic model name
        temperature: Model temperature
        api_key: Anthropic API key
        checkpoint_db_path: Path to LangGraph checkpoint SQLite DB
        chroma_db_path: Path to ChromaDB storage directory
        embedding_model: Sentence-transformers model name
        context_docs_dir: Directory with context markdown files
        base_dir: Base directory for resolving paths

    Returns:
        AgentExecutor configured for hierarchical tool routing
    """
    # Initialize memory manager
    memory_manager = MemoryManager(
        checkpoint_db_path=checkpoint_db_path,
        chroma_db_path=chroma_db_path,
        embedding_model=embedding_model,
        context_docs_dir=context_docs_dir,
        base_dir=base_dir
    )

    # Create executor
    executor = AgentExecutor(
        model_name=model_name,
        temperature=temperature,
        api_key=api_key,
        memory_manager=memory_manager,
        base_dir=base_dir
    )

    return executor
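
Typical startup wiring, as a sketch; the environment variable name and the event-loop entry point are assumptions, not part of this module:

import asyncio
import os

async def main() -> None:
    executor = create_agent(api_key=os.environ.get("ANTHROPIC_API_KEY"))
    await executor.initialize()  # builds sub-agents, routers, and the main agent
    # executor.execute(session, message) can now stream responses

if __name__ == "__main__":
    asyncio.run(main())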