"""Sub-agent infrastructure for specialized tool routing.

This module provides the SubAgent class that wraps specialized agents
with their own tools and system prompts.
"""
import logging
import uuid
from pathlib import Path
from typing import AsyncIterator, List, Optional

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
# Module-level logger for sub-agent lifecycle and streaming diagnostics.
logger = logging.getLogger(__name__)
class SubAgent:
    """A specialized sub-agent with its own tools and system prompt.

    Sub-agents are lightweight, stateless agents that focus on specific
    domains. They use in-memory checkpointing since they do not need
    persistent state across process restarts.
    """

    def __init__(
        self,
        name: str,
        soul_file: str,
        tools: List,
        model_name: str = "claude-sonnet-4-20250514",
        temperature: float = 0.7,
        api_key: Optional[str] = None,
        base_dir: str = "."
    ):
        """Initialize a sub-agent.

        Args:
            name: Agent name (e.g., "chart", "data", "automation")
            soul_file: Filename in /soul directory (e.g., "chart_agent.md")
            tools: List of LangChain tools for this agent
            model_name: Anthropic model name
            temperature: Model temperature
            api_key: Anthropic API key
            base_dir: Base directory for resolving paths
        """
        self.name = name
        self.soul_file = soul_file
        self.tools = tools
        self.model_name = model_name
        self.temperature = temperature
        self.api_key = api_key
        self.base_dir = base_dir

        # Load the system prompt from the agent's soul file, falling back
        # to a generic prompt when the file is missing.
        soul_path = Path(base_dir) / "soul" / soul_file
        if soul_path.exists():
            # Explicit UTF-8: soul files may contain non-ASCII prose and the
            # platform default encoding is not guaranteed to handle it.
            self.system_prompt = soul_path.read_text(encoding="utf-8")
            logger.info(f"SubAgent '{name}': Loaded system prompt from {soul_path}")
        else:
            logger.warning(f"SubAgent '{name}': Soul file not found at {soul_path}, using default")
            self.system_prompt = f"You are a specialized {name} agent."

        # Streaming enabled so astream_events yields token-level chunks.
        self.llm = ChatAnthropic(
            model=model_name,
            temperature=temperature,
            api_key=api_key,
            streaming=True
        )

        # In-memory checkpointer: sub-agents are intentionally stateless.
        self.agent = create_react_agent(
            self.llm,
            tools,
            checkpointer=MemorySaver()
        )

        logger.info(
            f"SubAgent '{name}' initialized with {len(tools)} tools, "
            f"model={model_name}, temp={temperature}"
        )

    def _resolve_thread_id(self, thread_id: Optional[str]) -> str:
        """Return *thread_id* unchanged, or a fresh ephemeral ID when None."""
        if thread_id is not None:
            return thread_id
        return f"subagent-{self.name}-{uuid.uuid4()}"

    def _build_config(self, thread_id: str):
        """Build the RunnableConfig shared by execute() and stream()."""
        return RunnableConfig(
            configurable={"thread_id": thread_id},
            metadata={"subagent_name": self.name}
        )

    def _build_messages(self, task: str) -> List:
        """Build the model input, injecting the system prompt.

        BUG FIX: the system prompt was previously passed as a
        "state_modifier" key inside config["configurable"], which
        create_react_agent never reads -- so the soul-file prompt was
        silently ignored (note SystemMessage was imported but unused).
        Prepending a SystemMessage guarantees the model actually sees it.
        """
        return [
            SystemMessage(content=self.system_prompt),
            HumanMessage(content=task)
        ]

    @staticmethod
    def _chunk_text(chunk) -> str:
        """Extract plain text from a streamed model chunk.

        Anthropic chunks may carry either a plain string or a list of
        content blocks (dicts with a "text" key, or objects with a .text
        attribute). Returns "" when the chunk holds no usable text.
        """
        content = getattr(chunk, "content", None)
        if not content:
            return ""
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, dict) and "text" in block:
                    parts.append(block["text"])
                elif hasattr(block, "text"):
                    parts.append(block.text)
            return "".join(parts)
        return content

    async def execute(
        self,
        task: str,
        thread_id: Optional[str] = None
    ) -> str:
        """Execute a task with this sub-agent.

        Args:
            task: The task/prompt for this sub-agent
            thread_id: Optional thread ID for checkpointing (uses ephemeral ID if not provided)

        Returns:
            The agent's complete response as a string, or an
            "Error: ..." string when execution failed.
        """
        thread_id = self._resolve_thread_id(thread_id)
        logger.info(f"SubAgent '{self.name}': Executing task (thread_id={thread_id})")
        logger.debug(f"SubAgent '{self.name}': Task: {task[:200]}...")

        messages = self._build_messages(task)
        config = self._build_config(thread_id)

        full_response = ""
        event_count = 0

        try:
            async for event in self.agent.astream_events(
                {"messages": messages},
                config=config,
                version="v2"
            ):
                event_count += 1
                kind = event["event"]

                if kind == "on_tool_start":
                    tool_name = event.get("name", "unknown")
                    logger.debug(f"SubAgent '{self.name}': Tool call started: {tool_name}")
                elif kind == "on_tool_end":
                    tool_name = event.get("name", "unknown")
                    logger.debug(f"SubAgent '{self.name}': Tool call completed: {tool_name}")
                elif kind == "on_chat_model_stream":
                    full_response += self._chunk_text(event["data"]["chunk"])

            logger.info(
                f"SubAgent '{self.name}': Completed task "
                f"({event_count} events, {len(full_response)} chars)"
            )

        except Exception as e:
            error_msg = f"SubAgent '{self.name}' execution error: {str(e)}"
            logger.error(error_msg, exc_info=True)
            return f"Error: {error_msg}"

        return full_response

    async def stream(
        self,
        task: str,
        thread_id: Optional[str] = None
    ) -> AsyncIterator[str]:
        """Execute a task with streaming response.

        Args:
            task: The task/prompt for this sub-agent
            thread_id: Optional thread ID for checkpointing

        Yields:
            Response text chunks as they're generated; on failure, a
            final "Error: ..." chunk.
        """
        thread_id = self._resolve_thread_id(thread_id)
        logger.info(f"SubAgent '{self.name}': Streaming task (thread_id={thread_id})")

        messages = self._build_messages(task)
        config = self._build_config(thread_id)

        try:
            async for event in self.agent.astream_events(
                {"messages": messages},
                config=config,
                version="v2"
            ):
                kind = event["event"]

                if kind == "on_tool_start":
                    tool_name = event.get("name", "unknown")
                    logger.debug(f"SubAgent '{self.name}': Tool call started: {tool_name}")
                elif kind == "on_tool_end":
                    # Logged for parity with execute(); previously only tool
                    # starts were logged on the streaming path.
                    tool_name = event.get("name", "unknown")
                    logger.debug(f"SubAgent '{self.name}': Tool call completed: {tool_name}")
                elif kind == "on_chat_model_stream":
                    text = self._chunk_text(event["data"]["chunk"])
                    if text:
                        yield text

        except Exception as e:
            error_msg = f"SubAgent '{self.name}' streaming error: {str(e)}"
            logger.error(error_msg, exc_info=True)
            yield f"Error: {error_msg}"