381 lines
13 KiB
Python
381 lines
13 KiB
Python
import glob
import os
import re
import uuid
from datetime import datetime, timezone
from typing import List, Dict, Any, Optional

import aiofiles
import chromadb
from chromadb.config import Settings
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from sentence_transformers import SentenceTransformer
|
|
|
|
# Opt out of ChromaDB's anonymized usage telemetry for this process.
# NOTE(review): this runs after `import chromadb`; if chromadb reads the variable
# only at import time it will not see this value. The PersistentClient below also
# passes Settings(anonymized_telemetry=False) explicitly.
os.environ.update({"ANONYMIZED_TELEMETRY": "False"})
|
|
|
|
class MemoryManager:
    """Manages persistent memory using local tools:

    - LangGraph checkpointing (SQLite) for conversation state
    - ChromaDB for semantic memory search
    - Local sentence-transformers for embeddings
    - Memory graph approach for clustering related concepts (currently
      nodes only; edge computation is still a TODO)

    Nothing is opened in the constructor: call ``await initialize()`` before
    use and ``await close()`` on shutdown.
    """

    # CommonMark ATX heading: up to 3 leading spaces, 1-6 '#', then whitespace
    # or end of line ("#tag" is NOT a heading). Group 2 is the heading text
    # without an optional closing "#" sequence ("## Title ##" -> "Title").
    _HEADING_RE = re.compile(r"^ {0,3}(#{1,6})(?:[ \t]+(.*?))?(?:[ \t]+#+)?[ \t]*$")
    # Opening/closing line of a fenced code block: three or more ` or ~.
    _FENCE_RE = re.compile(r"^ {0,3}(`{3,}|~{3,})")

    def __init__(
        self,
        checkpoint_db_path: str = "data/checkpoints.db",
        chroma_db_path: str = "data/chroma",
        embedding_model: str = "all-MiniLM-L6-v2",
        context_docs_dir: str = "memory",
        base_dir: str = "."
    ):
        """Initialize memory manager (attribute setup only; no I/O).

        Args:
            checkpoint_db_path: Path to SQLite checkpoint database (used as given)
            chroma_db_path: Path to ChromaDB directory (used as given)
            embedding_model: Sentence-transformers model name
            context_docs_dir: Directory containing markdown context files,
                joined onto ``base_dir``
            base_dir: Base directory that ``context_docs_dir`` is resolved
                against; the two database paths are NOT joined onto it
        """
        self.checkpoint_db_path = checkpoint_db_path
        self.chroma_db_path = chroma_db_path
        self.embedding_model_name = embedding_model
        self.context_docs_dir = os.path.join(base_dir, context_docs_dir)

        # Populated by initialize()
        self.checkpointer: Optional[AsyncSqliteSaver] = None
        # The *entered* async context manager owning the checkpointer; exited in close()
        self.checkpointer_context: Optional[Any] = None
        self.chroma_client: Optional[chromadb.Client] = None
        self.memory_collection: Optional[Any] = None
        self.embedding_model: Optional[SentenceTransformer] = None

        # basename of each loaded markdown file -> its raw text
        self.context_documents: Dict[str, str] = {}
        self.initialized = False

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        Produces exactly what the deprecated ``datetime.utcnow().isoformat()``
        did (no "+00:00" suffix), so stored timestamps keep one format.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    async def initialize(self) -> None:
        """Initialize the memory system and load context documents.

        Opens the SQLite checkpointer, the persistent ChromaDB collection and
        the local embedding model, then loads and (re)indexes the markdown
        context documents. Calling it again after success is a no-op. If any
        step fails, the checkpointer is closed again before the error
        propagates, so a retry starts clean.
        """
        if self.initialized:
            return

        # Ensure data directories exist. os.makedirs("") raises, so skip the
        # checkpoint directory when the path has no directory component.
        checkpoint_dir = os.path.dirname(self.checkpoint_db_path)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        os.makedirs(self.chroma_db_path, exist_ok=True)

        try:
            # Initialize LangGraph checkpointer (SQLite). Remember the context
            # only once it has been entered, so close() never exits an
            # un-entered context manager.
            context = AsyncSqliteSaver.from_conn_string(self.checkpoint_db_path)
            self.checkpointer = await context.__aenter__()
            self.checkpointer_context = context
            await self.checkpointer.setup()

            # Initialize ChromaDB
            self.chroma_client = chromadb.PersistentClient(
                path=self.chroma_db_path,
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True
                )
            )

            # Get or create memory collection
            self.memory_collection = self.chroma_client.get_or_create_collection(
                name="conversation_memory",
                metadata={"description": "Semantic memory for conversations"}
            )

            # Initialize local embedding model
            print(f"Loading embedding model: {self.embedding_model_name}")
            self.embedding_model = SentenceTransformer(self.embedding_model_name)

            # Load markdown context documents, then index them in ChromaDB
            await self._load_context_documents()
            await self._index_context_documents()
        except BaseException:
            # Don't leak the open SQLite connection on a partial initialization.
            await self.close()
            raise

        self.initialized = True
        print("Memory system initialized (LangGraph + ChromaDB + local embeddings)")

    async def _load_context_documents(self) -> None:
        """Load all ``*.md`` files (non-recursive) from the context directory.

        Each file is stored in ``self.context_documents`` under its basename.
        A file that cannot be read or decoded as UTF-8 is reported and skipped.
        """
        if not os.path.exists(self.context_docs_dir):
            print(f"Warning: Context directory {self.context_docs_dir} not found")
            return

        # Sorted so load order (and log output) doesn't depend on the filesystem.
        md_files = sorted(glob.glob(os.path.join(self.context_docs_dir, "*.md")))

        for md_file in md_files:
            try:
                async with aiofiles.open(md_file, "r", encoding="utf-8") as f:
                    content = await f.read()
            except (OSError, UnicodeDecodeError) as e:
                print(f"Error loading {md_file}: {e}")
                continue
            filename = os.path.basename(md_file)
            self.context_documents[filename] = content
            print(f"Loaded context document: {filename}")

    async def _index_context_documents(self) -> None:
        """Rebuild the ``type == "context"`` entries in ChromaDB from the loaded documents.

        The collection is persistent and context ids are deterministic
        (``context_<file>_<section index>``). A plain ``add`` would therefore
        leave edited, shrunk or removed documents stale across restarts. All
        existing context entries are deleted first, then every section is
        re-added.
        """
        if self.memory_collection is None or self.embedding_model is None:
            return

        self.memory_collection.delete(where={"type": "context"})

        if not self.context_documents:
            return

        indexed_at = self._utc_now_iso()
        for filename, content in self.context_documents.items():
            # Split into sections (by headers)
            sections = self._split_document_into_sections(content, filename)
            if not sections:
                continue

            texts = [section["content"] for section in sections]
            # One batched encode per document instead of one call per section.
            embeddings = self.embedding_model.encode(texts).tolist()

            self.memory_collection.add(
                ids=[f"context_{filename}_{i}" for i in range(len(sections))],
                embeddings=embeddings,
                documents=texts,
                metadatas=[
                    {
                        "type": "context",
                        "source": filename,
                        "section": section["title"],
                        "indexed_at": indexed_at,
                    }
                    for section in sections
                ],
            )

        print(f"Indexed {len(self.context_documents)} context documents")

    def _split_document_into_sections(self, content: str, filename: str) -> List[Dict[str, str]]:
        """Split markdown document into logical sections at ATX headings.

        Text before the first heading becomes a section titled ``filename``.
        Each heading line starts a new section and is kept as that section's
        first line. Lines inside fenced code blocks are never treated as
        headings, so ``# comment`` lines in code samples don't split sections.
        Sections whose content is only whitespace are dropped.

        Args:
            content: Markdown content
            filename: Source filename (title of the leading section)

        Returns:
            List of section dicts with "title" and "content" keys; every
            original line appears with a trailing "\\n".
        """
        sections: List[Dict[str, str]] = []
        title = filename
        buffer: List[str] = []
        open_fence: Optional[str] = None  # the fence string of the code block we are inside

        for line in content.split("\n"):
            fence = self._FENCE_RE.match(line)

            if open_fence is not None:
                # Inside a code block: only a fence of the same character and at
                # least the same length closes it.
                if fence and fence.group(1)[0] == open_fence[0] and len(fence.group(1)) >= len(open_fence):
                    open_fence = None
                buffer.append(line + "\n")
                continue

            if fence:
                open_fence = fence.group(1)
                buffer.append(line + "\n")
                continue

            heading = self._HEADING_RE.match(line)
            if heading:
                # New section: flush the previous one if it has real content
                body = "".join(buffer)
                if body.strip():
                    sections.append({"title": title, "content": body})
                title = (heading.group(2) or "").strip()
                buffer = [line + "\n"]
            else:
                buffer.append(line + "\n")

        # Add last section
        body = "".join(buffer)
        if body.strip():
            sections.append({"title": title, "content": body})

        return sections

    def get_context_prompt(self) -> str:
        """Generate a context prompt from loaded documents.

        system_prompt.md is ALWAYS included first and prioritized, followed
        by a "---" separator. Other documents follow, in filename order, under
        an "# Additional Context" header with one "## <filename>" subsection
        each. The header is omitted when there are no other documents.

        Returns:
            Formatted string containing all context documents ("" if none)
        """
        if not self.context_documents:
            return ""

        sections = []

        # ALWAYS include system_prompt.md first if it exists
        system_prompt_key = "system_prompt.md"
        if system_prompt_key in self.context_documents:
            sections.append(self.context_documents[system_prompt_key])
            sections.append("\n---\n")

        other_docs = sorted(
            (filename, content)
            for filename, content in self.context_documents.items()
            if filename != system_prompt_key
        )

        if other_docs:
            sections.append("# Additional Context\n")
            sections.append("The following documents provide additional context about the system:\n")
            for filename, content in other_docs:
                sections.append(f"\n## {filename}\n")
                sections.append(content)

        return "\n".join(sections)

    async def add_memory(
        self,
        session_id: str,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """Add a message to semantic memory (ChromaDB).

        Best-effort: does nothing before ``initialize()``, and any error is
        printed and swallowed rather than raised.

        Args:
            session_id: Session identifier
            role: Message role ("user" or "assistant")
            content: Message content
            metadata: Optional extra metadata. It is merged last, so its keys
                override the defaults ("session_id", "role", "timestamp",
                "type").
        """
        if self.memory_collection is None or self.embedding_model is None:
            return

        try:
            timestamp = self._utc_now_iso()
            # Unique ID: session/role/timestamp alone collides for two messages
            # inside the clock's resolution, and the second would be dropped.
            doc_id = f"{session_id}_{role}_{timestamp}_{uuid.uuid4().hex}"

            # Generate embedding
            embedding = self.embedding_model.encode(content).tolist()

            # Prepare metadata
            meta = {
                "session_id": session_id,
                "role": role,
                "timestamp": timestamp,
                "type": "conversation",
                **(metadata or {})
            }

            # Add to ChromaDB
            self.memory_collection.add(
                ids=[doc_id],
                embeddings=[embedding],
                documents=[content],
                metadatas=[meta]
            )
        except Exception as e:
            print(f"Error adding to ChromaDB memory: {e}")

    async def search_memory(
        self,
        session_id: str,
        query: str,
        limit: int = 5,
        include_context: bool = True
    ) -> List[Dict[str, Any]]:
        """Search memory using semantic similarity.

        Args:
            session_id: Session identifier (filters to this session + context docs)
            query: Search query
            limit: Maximum results (values below 1 return an empty list)
            include_context: Whether to include context documents in search

        Returns:
            List of dicts with "content", "metadata" and "distance" (None if
            ChromaDB returned no distances), in ChromaDB's result order. Any
            error is printed and an empty list is returned.
        """
        if self.memory_collection is None or self.embedding_model is None:
            return []
        if limit < 1:
            # A non-positive n_results is not a meaningful query.
            return []

        try:
            # Generate query embedding
            query_embedding = self.embedding_model.encode(query).tolist()

            # Restrict to this session, optionally plus the shared context docs
            if include_context:
                where = {"$or": [{"session_id": session_id}, {"type": "context"}]}
            else:
                where = {"session_id": session_id}

            results = self.memory_collection.query(
                query_embeddings=[query_embedding],
                n_results=limit,
                where=where
            )

            if not results or not results["documents"]:
                return []

            # Results are nested one list per query embedding; we sent one.
            documents = results["documents"][0]
            metadatas = results["metadatas"][0]
            distances = results["distances"][0] if results.get("distances") else [None] * len(documents)

            return [
                {"content": doc, "metadata": meta, "distance": dist}
                for doc, meta, dist in zip(documents, metadatas, distances)
            ]
        except Exception as e:
            print(f"Error searching ChromaDB memory: {e}")
            return []

    async def get_memory_graph(
        self,
        session_id: str,
        max_depth: int = 2
    ) -> Dict[str, Any]:
        """Get a graph of related memories for a session.

        Currently returns one node per stored memory of the session and no
        edges; similarity-based edges are not implemented yet.

        Args:
            session_id: Session identifier
            max_depth: Maximum depth for related memory traversal (currently unused)

        Returns:
            ``{"nodes": [...], "edges": []}``. Each node has "id", "content"
            and "metadata". Returns ``{}`` when not initialized or on error.
        """
        if self.memory_collection is None:
            return {}

        try:
            # Get all memories for this session
            results = self.memory_collection.get(
                where={"session_id": session_id},
                include=["documents", "metadatas"]
            )

            if not results or not results["documents"]:
                return {"nodes": [], "edges": []}

            nodes = [
                {"id": doc_id, "content": doc, "metadata": meta}
                for doc_id, doc, meta in zip(results["ids"], results["documents"], results["metadatas"])
            ]

            # TODO: Compute edges based on embedding similarity (add "embeddings"
            # to the include list above when implementing this).
            edges: List[Dict[str, Any]] = []
            return {"nodes": nodes, "edges": edges}
        except Exception as e:
            print(f"Error building memory graph: {e}")
            return {}

    def get_checkpointer(self) -> Optional[AsyncSqliteSaver]:
        """Get the LangGraph checkpointer for conversation state.

        Returns:
            AsyncSqliteSaver instance for LangGraph persistence, or None
            before ``initialize()`` / after ``close()``
        """
        return self.checkpointer

    async def close(self) -> None:
        """Close the memory manager and cleanup resources.

        Exits the checkpointer's async context (closing its SQLite
        connection) and marks the manager uninitialized, so ``initialize()``
        can be called again. Safe to call more than once.
        """
        context = self.checkpointer_context
        # Clear state first so a failing __aexit__ cannot be re-entered by a second close().
        self.checkpointer = None
        self.checkpointer_context = None
        self.initialized = False
        if context is not None:
            await context.__aexit__(None, None, None)