backend redesign
This commit is contained in:
380
backend.old/src/agent/memory.py
Normal file
380
backend.old/src/agent/memory.py
Normal file
@@ -0,0 +1,380 @@
|
||||
import glob
import os
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import aiofiles
import chromadb
from chromadb.config import Settings
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# Prevent ChromaDB from reporting telemetry to the mothership
|
||||
os.environ["ANONYMIZED_TELEMETRY"] = "False"
|
||||
|
||||
class MemoryManager:
    """Manages persistent memory using local tools:

    - LangGraph checkpointing (SQLite) for conversation state
    - ChromaDB for semantic memory search
    - Local sentence-transformers for embeddings
    - Memory graph approach for clustering related concepts
    """

    def __init__(
        self,
        checkpoint_db_path: str = "data/checkpoints.db",
        chroma_db_path: str = "data/chroma",
        embedding_model: str = "all-MiniLM-L6-v2",
        context_docs_dir: str = "memory",
        base_dir: str = "."
    ):
        """Initialize memory manager.

        Heavy resources (SQLite checkpointer, ChromaDB client, embedding
        model) are NOT created here — call ``initialize()`` first. This keeps
        construction cheap and side-effect free.

        Args:
            checkpoint_db_path: Path to SQLite checkpoint database
            chroma_db_path: Path to ChromaDB directory
            embedding_model: Sentence-transformers model name
            context_docs_dir: Directory containing markdown context files
            base_dir: Base directory for resolving relative paths
        """
        self.checkpoint_db_path = checkpoint_db_path
        self.chroma_db_path = chroma_db_path
        self.embedding_model_name = embedding_model
        self.context_docs_dir = os.path.join(base_dir, context_docs_dir)

        # Will be initialized on startup (forward-referenced annotations so
        # the class is definable even if the optional deps are absent).
        self.checkpointer: Optional["AsyncSqliteSaver"] = None
        # The async context manager that owns the checkpointer's DB
        # connection; retained so close() can exit it cleanly.
        self.checkpointer_context: Optional[Any] = None
        self.chroma_client: Optional["chromadb.Client"] = None
        self.memory_collection: Optional[Any] = None
        self.embedding_model: Optional["SentenceTransformer"] = None

        # Maps markdown filename -> raw file content.
        self.context_documents: Dict[str, str] = {}
        self.initialized = False

    async def initialize(self) -> None:
        """Initialize the memory system and load context documents.

        Idempotent: subsequent calls return immediately.
        """
        if self.initialized:
            return

        # Ensure data directories exist
        os.makedirs(os.path.dirname(self.checkpoint_db_path), exist_ok=True)
        os.makedirs(self.chroma_db_path, exist_ok=True)

        # Initialize LangGraph checkpointer (SQLite). from_conn_string()
        # returns an async context manager; we enter it manually and keep it
        # around so close() can exit it later.
        self.checkpointer_context = AsyncSqliteSaver.from_conn_string(
            self.checkpoint_db_path
        )
        self.checkpointer = await self.checkpointer_context.__aenter__()
        await self.checkpointer.setup()

        # Initialize ChromaDB (local, persistent, no telemetry)
        self.chroma_client = chromadb.PersistentClient(
            path=self.chroma_db_path,
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True
            )
        )

        # Get or create memory collection
        self.memory_collection = self.chroma_client.get_or_create_collection(
            name="conversation_memory",
            metadata={"description": "Semantic memory for conversations"}
        )

        # Initialize local embedding model (downloads on first use)
        print(f"Loading embedding model: {self.embedding_model_name}")
        self.embedding_model = SentenceTransformer(self.embedding_model_name)

        # Load markdown context documents, then index them for semantic search
        await self._load_context_documents()
        await self._index_context_documents()

        self.initialized = True
        print("Memory system initialized (LangGraph + ChromaDB + local embeddings)")

    async def _load_context_documents(self) -> None:
        """Load all markdown files from the context directory into memory.

        Missing directory or unreadable files are reported but non-fatal.
        """
        if not os.path.exists(self.context_docs_dir):
            print(f"Warning: Context directory {self.context_docs_dir} not found")
            return

        md_files = glob.glob(os.path.join(self.context_docs_dir, "*.md"))

        for md_file in md_files:
            try:
                async with aiofiles.open(md_file, "r", encoding="utf-8") as f:
                    content = await f.read()
                filename = os.path.basename(md_file)
                self.context_documents[filename] = content
                print(f"Loaded context document: {filename}")
            except Exception as e:
                # Best-effort: one bad file must not abort startup.
                print(f"Error loading {md_file}: {e}")

    async def _index_context_documents(self) -> None:
        """Index context documents in ChromaDB for semantic search.

        Documents are split into header-delimited sections; all sections are
        embedded locally and written to ChromaDB in a single batched add().
        """
        if not self.context_documents or not self.memory_collection:
            return

        ids: List[str] = []
        embeddings: List[List[float]] = []
        documents: List[str] = []
        metadatas: List[Dict[str, Any]] = []
        indexed_at = datetime.now(timezone.utc).isoformat()

        for filename, content in self.context_documents.items():
            # Split into sections (by markdown headers)
            sections = self._split_document_into_sections(content, filename)

            for i, section in enumerate(sections):
                ids.append(f"context_{filename}_{i}")
                # Generate embedding with the local model
                embeddings.append(
                    self.embedding_model.encode(section["content"]).tolist()
                )
                documents.append(section["content"])
                metadatas.append({
                    "type": "context",
                    "source": filename,
                    "section": section["title"],
                    "indexed_at": indexed_at
                })

        # One batched write instead of one round-trip per section
        if ids:
            self.memory_collection.add(
                ids=ids,
                embeddings=embeddings,
                documents=documents,
                metadatas=metadatas
            )

        print(f"Indexed {len(self.context_documents)} context documents")

    def _split_document_into_sections(self, content: str, filename: str) -> List[Dict[str, str]]:
        """Split markdown document into logical sections.

        A new section starts at every header line (any ``#`` level). Content
        before the first header is titled with the filename. Sections whose
        content is only whitespace are dropped.

        Args:
            content: Markdown content
            filename: Source filename (used as the title of the preamble)

        Returns:
            List of section dicts with "title" and "content" keys
        """
        sections: List[Dict[str, str]] = []
        current_section = {"title": filename, "content": ""}

        for line in content.split("\n"):
            if line.startswith("#"):
                # New section: flush the previous one if it has content
                if current_section["content"].strip():
                    sections.append(current_section)
                current_section = {
                    "title": line.strip("#").strip(),
                    "content": line + "\n"
                }
            else:
                current_section["content"] += line + "\n"

        # Add last section
        if current_section["content"].strip():
            sections.append(current_section)

        return sections

    def get_context_prompt(self) -> str:
        """Generate a context prompt from loaded documents.

        system_prompt.md is ALWAYS included first and prioritized.
        Other documents are included after, sorted by filename.

        Returns:
            Formatted string containing all context documents, or "" when
            no documents are loaded.
        """
        if not self.context_documents:
            return ""

        sections = []

        # ALWAYS include system_prompt.md first if it exists
        system_prompt_key = "system_prompt.md"
        if system_prompt_key in self.context_documents:
            sections.append(self.context_documents[system_prompt_key])
            sections.append("\n---\n")

        # Add other context documents
        sections.append("# Additional Context\n")
        sections.append("The following documents provide additional context about the system:\n")

        for filename, content in sorted(self.context_documents.items()):
            # Skip system_prompt.md since we already added it
            if filename == system_prompt_key:
                continue

            sections.append(f"\n## {filename}\n")
            sections.append(content)

        return "\n".join(sections)

    async def add_memory(
        self,
        session_id: str,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """Add a message to semantic memory (ChromaDB).

        No-op if the memory system is not initialized; indexing failures are
        logged and swallowed so memory writes never break the conversation.

        Args:
            session_id: Session identifier
            role: Message role ("user" or "assistant")
            content: Message content
            metadata: Optional metadata merged into the stored record
        """
        if not self.memory_collection or not self.embedding_model:
            return

        try:
            # Unique ID from session + role + timezone-aware UTC timestamp
            timestamp = datetime.now(timezone.utc).isoformat()
            doc_id = f"{session_id}_{role}_{timestamp}"

            # Generate embedding
            embedding = self.embedding_model.encode(content).tolist()

            # Prepare metadata; caller-supplied keys override the defaults
            meta = {
                "session_id": session_id,
                "role": role,
                "timestamp": timestamp,
                "type": "conversation",
                **(metadata or {})
            }

            # Add to ChromaDB
            self.memory_collection.add(
                ids=[doc_id],
                embeddings=[embedding],
                documents=[content],
                metadatas=[meta]
            )
        except Exception as e:
            print(f"Error adding to ChromaDB memory: {e}")

    async def search_memory(
        self,
        session_id: str,
        query: str,
        limit: int = 5,
        include_context: bool = True
    ) -> List[Dict[str, Any]]:
        """Search memory using semantic similarity.

        Args:
            session_id: Session identifier (filters to this session + context docs)
            query: Search query
            limit: Maximum results
            include_context: Whether to include context documents in search

        Returns:
            List of relevant memory items with content, metadata, and
            distance (lower = more similar; None if backend omits distances).
            Empty list if uninitialized or on query failure.
        """
        if not self.memory_collection or not self.embedding_model:
            return []

        try:
            # Generate query embedding
            query_embedding = self.embedding_model.encode(query).tolist()

            # Build where filter: this session's messages, optionally also
            # the shared context documents
            if include_context:
                where_filter: Dict[str, Any] = {
                    "$or": [
                        {"session_id": session_id},
                        {"type": "context"}
                    ]
                }
            else:
                where_filter = {"session_id": session_id}

            # Query ChromaDB
            results = self.memory_collection.query(
                query_embeddings=[query_embedding],
                n_results=limit,
                where=where_filter
            )

            # Format results (query returns one result-list per query embedding)
            memories = []
            if results and results["documents"]:
                for i, doc in enumerate(results["documents"][0]):
                    memories.append({
                        "content": doc,
                        "metadata": results["metadatas"][0][i],
                        "distance": results["distances"][0][i] if "distances" in results else None
                    })

            return memories
        except Exception as e:
            print(f"Error searching ChromaDB memory: {e}")
            return []

    async def get_memory_graph(
        self,
        session_id: str,
        max_depth: int = 2
    ) -> Dict[str, Any]:
        """Get a graph of related memories using clustering.

        This creates a simple memory graph by finding clusters of related
        concepts. Current implementation returns nodes only; edge computation
        from embedding similarity is still TODO (max_depth is unused until
        then, kept for interface stability).

        Args:
            session_id: Session identifier
            max_depth: Maximum depth for related memory traversal

        Returns:
            Dict with "nodes" and "edges" lists; {} on failure.
        """
        if not self.memory_collection:
            return {}

        try:
            # Get all memories for this session
            results = self.memory_collection.get(
                where={"session_id": session_id},
                include=["embeddings", "documents", "metadatas"]
            )

            if not results or not results["documents"]:
                return {"nodes": [], "edges": []}

            # Build simple graph structure
            nodes = []
            edges: List[Dict[str, Any]] = []

            for i, doc in enumerate(results["documents"]):
                nodes.append({
                    "id": results["ids"][i],
                    "content": doc,
                    "metadata": results["metadatas"][i]
                })

            # TODO: Compute edges based on embedding similarity
            # For now, return just nodes
            return {"nodes": nodes, "edges": edges}

        except Exception as e:
            print(f"Error building memory graph: {e}")
            return {}

    def get_checkpointer(self) -> Optional["AsyncSqliteSaver"]:
        """Get the LangGraph checkpointer for conversation state.

        Returns:
            AsyncSqliteSaver instance for LangGraph persistence, or None if
            initialize() has not been called.
        """
        return self.checkpointer

    async def close(self) -> None:
        """Close the memory manager and cleanup resources.

        Exits the checkpointer's async context (closing its SQLite
        connection) and clears the references. Safe to call when never
        initialized.
        """
        if self.checkpointer_context:
            await self.checkpointer_context.__aexit__(None, None, None)
            self.checkpointer = None
            self.checkpointer_context = None
|
||||
Reference in New Issue
Block a user