Files
ai/backend.old/src/agent/subagent.py
2026-03-11 18:47:11 -04:00

249 lines
8.4 KiB
Python

"""Sub-agent infrastructure for specialized tool routing.
This module provides the SubAgent class that wraps specialized agents
with their own tools and system prompts.
"""
import logging
from typing import List, Optional, AsyncIterator
from pathlib import Path
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver
# Module-level logger; its name is this module's __name__.
logger = logging.getLogger(__name__)
class SubAgent:
    """A specialized sub-agent with its own tools and system prompt.

    The system prompt is read from ``<base_dir>/soul/<soul_file>`` and bound
    to a LangGraph ReAct agent when the instance is constructed. The agent is
    checkpointed with an in-memory ``MemorySaver``: calls that omit
    ``thread_id`` get a fresh random thread and therefore start from an empty
    conversation, while calls that pass the same ``thread_id`` to the same
    instance share whatever the checkpointer has stored for it.
    """

    def __init__(
        self,
        name: str,
        soul_file: str,
        tools: List,
        model_name: str = "claude-sonnet-4-20250514",
        temperature: float = 0.7,
        api_key: Optional[str] = None,
        base_dir: str = ".",
    ) -> None:
        """Initialize a sub-agent.

        Args:
            name: Agent name (e.g., "chart", "data", "automation")
            soul_file: Filename inside ``<base_dir>/soul`` (e.g., "chart_agent.md")
            tools: List of LangChain tools for this agent
            model_name: Anthropic model name
            temperature: Model temperature
            api_key: Anthropic API key
            base_dir: Base directory for resolving paths
        """
        self.name = name
        self.soul_file = soul_file
        self.tools = tools
        self.model_name = model_name
        self.temperature = temperature
        self.api_key = api_key
        self.base_dir = base_dir

        # Load the system prompt from the soul file. Read it in one attempt
        # (no exists-then-open race) and as UTF-8 so the result does not
        # depend on the host locale.
        soul_path = Path(base_dir) / "soul" / soul_file
        try:
            self.system_prompt = soul_path.read_text(encoding="utf-8")
        except (FileNotFoundError, NotADirectoryError):
            logger.warning(f"SubAgent '{name}': Soul file not found at {soul_path}, using default")
            self.system_prompt = f"You are a specialized {name} agent."
        else:
            logger.info(f"SubAgent '{name}': Loaded system prompt from {soul_path}")

        # Initialize LLM
        self.llm = ChatAnthropic(
            model=model_name,
            temperature=temperature,
            api_key=api_key,
            streaming=True,
        )

        # Create the agent with an in-memory checkpointer.
        # The system prompt has to be bound here: create_react_agent does not
        # read a "state_modifier" entry from the per-call config, so passing
        # it there (as this class used to) left the model without its prompt.
        # NOTE(review): older LangGraph releases name this parameter
        # `state_modifier` instead of `prompt` - confirm against the pinned
        # version.
        checkpointer = MemorySaver()
        self.agent = create_react_agent(
            self.llm,
            tools,
            # An empty soul file yields no system message rather than an
            # empty one.
            prompt=self.system_prompt or None,
            checkpointer=checkpointer,
        )
        logger.info(
            f"SubAgent '{name}' initialized with {len(tools)} tools, "
            f"model={model_name}, temp={temperature}"
        )

    def _resolve_thread_id(self, thread_id: Optional[str]) -> str:
        """Return ``thread_id``, or a fresh ephemeral one when it is None."""
        if thread_id is not None:
            return thread_id
        import uuid

        return f"subagent-{self.name}-{uuid.uuid4()}"

    def _iter_events(self, task: str, thread_id: str):
        """Return the agent's v2 event stream for ``task`` on ``thread_id``."""
        config = RunnableConfig(
            configurable={"thread_id": thread_id},
            metadata={"subagent_name": self.name},
        )
        return self.agent.astream_events(
            {"messages": [HumanMessage(content=task)]},
            config=config,
            version="v2",
        )

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten message-chunk content into plain text.

        ``content`` may be a string or a list of blocks. From a list, plain
        strings, dicts carrying a "text" key, and objects with a ``text``
        attribute contribute their text; every other block (for example a
        tool-call block) is skipped. Anything else yields "".
        """
        if isinstance(content, str):
            return content
        if not isinstance(content, list):
            return ""
        text_parts = []
        for block in content:
            if isinstance(block, str):
                text_parts.append(block)
            elif isinstance(block, dict):
                if "text" in block:
                    text_parts.append(block["text"])
            elif hasattr(block, "text"):
                text_parts.append(block.text)
        return "".join(text_parts)

    def _consume_event(self, event) -> str:
        """Log tool activity and return the text carried by ``event``.

        Returns "" for events that carry no model text.
        """
        kind = event["event"]
        if kind == "on_tool_start":
            tool_name = event.get("name", "unknown")
            logger.debug(f"SubAgent '{self.name}': Tool call started: {tool_name}")
        elif kind == "on_tool_end":
            tool_name = event.get("name", "unknown")
            logger.debug(f"SubAgent '{self.name}': Tool call completed: {tool_name}")
        elif kind == "on_chat_model_stream":
            chunk = event["data"]["chunk"]
            return self._content_to_text(getattr(chunk, "content", None))
        return ""

    async def execute(
        self,
        task: str,
        thread_id: Optional[str] = None
    ) -> str:
        """Execute a task with this sub-agent.

        Args:
            task: The task/prompt for this sub-agent
            thread_id: Optional thread ID for checkpointing (uses ephemeral ID if not provided)

        Returns:
            The agent's complete response as a string. If the run raises, the
            exception is logged and a string starting with "Error: " is
            returned instead; text streamed before the failure is discarded.
        """
        thread_id = self._resolve_thread_id(thread_id)
        logger.info(f"SubAgent '{self.name}': Executing task (thread_id={thread_id})")
        logger.debug(f"SubAgent '{self.name}': Task: {task[:200]}...")

        events = self._iter_events(task, thread_id)
        parts = []
        event_count = 0
        try:
            async for event in events:
                event_count += 1
                text = self._consume_event(event)
                if text:
                    parts.append(text)
        except Exception as e:
            error_msg = f"SubAgent '{self.name}' execution error: {str(e)}"
            logger.error(error_msg, exc_info=True)
            return f"Error: {error_msg}"

        full_response = "".join(parts)
        logger.info(
            f"SubAgent '{self.name}': Completed task "
            f"({event_count} events, {len(full_response)} chars)"
        )
        return full_response

    async def stream(
        self,
        task: str,
        thread_id: Optional[str] = None
    ) -> AsyncIterator[str]:
        """Execute a task with streaming response.

        Args:
            task: The task/prompt for this sub-agent
            thread_id: Optional thread ID for checkpointing

        Yields:
            Response chunks as they're generated. If the run raises, the
            exception is logged and a final chunk starting with "Error: " is
            yielded.
        """
        thread_id = self._resolve_thread_id(thread_id)
        logger.info(f"SubAgent '{self.name}': Streaming task (thread_id={thread_id})")

        events = self._iter_events(task, thread_id)
        try:
            async for event in events:
                text = self._consume_event(event)
                if text:
                    yield text
        except Exception as e:
            error_msg = f"SubAgent '{self.name}' streaming error: {str(e)}"
            logger.error(error_msg, exc_info=True)
            yield f"Error: {error_msg}"