initial commit with charts and assistant chat
This commit is contained in:
20
backend/config.yaml
Normal file
20
backend/config.yaml
Normal file
@@ -0,0 +1,20 @@
|
||||
www_port: 8080
|
||||
server_port: 8081
|
||||
|
||||
# Agent configuration
|
||||
agent:
|
||||
model: "claude-sonnet-4-20250514"
|
||||
temperature: 0.7
|
||||
context_docs_dir: "doc"
|
||||
|
||||
# Local memory configuration (free & sophisticated!)
|
||||
memory:
|
||||
# LangGraph checkpointing (SQLite for conversation state)
|
||||
checkpoint_db: "data/checkpoints.db"
|
||||
|
||||
# ChromaDB (local vector DB for semantic search)
|
||||
chroma_db: "data/chroma"
|
||||
|
||||
# Sentence-transformers model (local embeddings)
|
||||
# Options: all-MiniLM-L6-v2 (fast, small), all-mpnet-base-v2 (better quality)
|
||||
embedding_model: "all-MiniLM-L6-v2"
|
||||
32
backend/requirements.txt
Normal file
32
backend/requirements.txt
Normal file
@@ -0,0 +1,32 @@
|
||||
pydantic>=2.0
|
||||
seaborn
|
||||
pandas
|
||||
numpy
|
||||
scipy
|
||||
matplotlib
|
||||
fastapi
|
||||
uvicorn
|
||||
websockets
|
||||
jsonpatch
|
||||
python-multipart
|
||||
ccxt>=4.0.0
|
||||
pyyaml
|
||||
|
||||
# LangChain agent dependencies
|
||||
langchain>=0.3.0
|
||||
langgraph>=0.2.0
|
||||
langgraph-checkpoint-sqlite>=1.0.0
|
||||
langchain-anthropic>=0.3.0
|
||||
langchain-community>=0.3.0
|
||||
|
||||
# Local memory system
|
||||
chromadb>=0.4.0
|
||||
sentence-transformers>=2.0.0
|
||||
sqlalchemy>=2.0.0
|
||||
aiosqlite>=0.19.0
|
||||
|
||||
# Async utilities
|
||||
aiofiles>=24.0.0
|
||||
|
||||
# Environment configuration
|
||||
python-dotenv>=1.0.0
|
||||
3
backend/src/agent/__init__.py
Normal file
3
backend/src/agent/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# Public surface of the agent package: re-export the create_agent factory
# (defined in agent.core) so callers can simply `from agent import create_agent`.
from agent.core import create_agent

# Explicitly declare the public API of this package.
__all__ = ["create_agent"]
|
||||
296
backend/src/agent/core.py
Normal file
296
backend/src/agent/core.py
Normal file
@@ -0,0 +1,296 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import AsyncIterator, Dict, Any, Optional
|
||||
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS
|
||||
from agent.memory import MemoryManager
|
||||
from agent.session import SessionManager
|
||||
from agent.prompts import build_system_prompt
|
||||
from gateway.user_session import UserSession
|
||||
from gateway.protocol import UserMessage as GatewayUserMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentExecutor:
    """LangGraph-based agent executor with streaming support.

    Handles agent invocation, tool execution, and response streaming.
    Supports interruption for real-time user interaction.
    """

    def __init__(
        self,
        model_name: str = "claude-sonnet-4-20250514",
        temperature: float = 0.7,
        api_key: Optional[str] = None,
        memory_manager: Optional[MemoryManager] = None
    ):
        """Initialize agent executor.

        Args:
            model_name: Anthropic model name
            temperature: Model temperature
            api_key: Anthropic API key
            memory_manager: MemoryManager instance; a default-configured one
                is created when None is passed
        """
        self.model_name = model_name
        self.temperature = temperature
        self.api_key = api_key

        # Initialize LLM with streaming enabled so tokens can be yielded live.
        self.llm = ChatAnthropic(
            model=model_name,
            temperature=temperature,
            api_key=api_key,
            streaming=True
        )

        # Memory and session management
        self.memory_manager = memory_manager or MemoryManager()
        self.session_manager = SessionManager(self.memory_manager)
        self.agent = None  # Will be created after initialization (see initialize())

    async def initialize(self) -> None:
        """Initialize the agent system.

        Must be awaited before execute(): it initializes the memory system
        and builds the LangGraph react agent (self.agent).
        """
        await self.memory_manager.initialize()

        # Create agent with tools and LangGraph checkpointing
        checkpointer = self.memory_manager.get_checkpointer()

        # Build initial system prompt with context. Active channels are not
        # known yet at startup, hence the empty list.
        context = self.memory_manager.get_context_prompt()
        system_prompt = build_system_prompt(context, [])

        self.agent = create_react_agent(
            self.llm,
            SYNC_TOOLS + DATASOURCE_TOOLS,
            prompt=system_prompt,
            checkpointer=checkpointer
        )

    async def _clear_checkpoint(self, session_id: str) -> None:
        """Clear the checkpoint for a session to prevent resuming from invalid state.

        This is called when an error occurs during agent execution to ensure
        the next interaction starts fresh instead of trying to resume from
        a broken state (e.g., orphaned tool calls).

        Args:
            session_id: The session ID whose checkpoint should be cleared
        """
        try:
            checkpointer = self.memory_manager.get_checkpointer()
            if checkpointer:
                # Delete all checkpoints for this thread
                # LangGraph uses thread_id as the key
                # The checkpointer API doesn't have a direct delete method,
                # but we can use the underlying connection
                # NOTE(review): relies on AsyncSqliteSaver exposing its raw
                # `conn` and on the internal "checkpoints" table name —
                # verify this still holds across langgraph versions.
                async with checkpointer.conn.cursor() as cur:
                    await cur.execute(
                        "DELETE FROM checkpoints WHERE thread_id = ?",
                        (session_id,)
                    )
                await checkpointer.conn.commit()
                logger.info(f"Cleared checkpoint for session {session_id}")
        except Exception as e:
            # Best-effort cleanup: failure to clear is logged, never raised.
            logger.warning(f"Failed to clear checkpoint for session {session_id}: {e}")

    def _build_system_message(self, state: Dict[str, Any]) -> SystemMessage:
        """Build system message with context.

        NOTE(review): not referenced from within this class as shown —
        presumably intended for per-turn prompt rebuilding; confirm callers.

        Args:
            state: Agent state (reads metadata.active_channels)

        Returns:
            SystemMessage with full context
        """
        # Get context from loaded documents
        context = self.memory_manager.get_context_prompt()

        # Get active channels from metadata
        active_channels = state.get("metadata", {}).get("active_channels", [])

        # Build system prompt
        system_prompt = build_system_prompt(context, active_channels)

        return SystemMessage(content=system_prompt)

    async def execute(
        self,
        session: UserSession,
        message: GatewayUserMessage
    ) -> AsyncIterator[str]:
        """Execute the agent and stream responses.

        Serializes execution per session via a lock, replays the last 10
        history messages, streams model tokens as they arrive, and persists
        the full assistant response afterwards. On error/cancellation the
        session checkpoint is cleared so the next turn starts fresh.

        Args:
            session: User session
            message: User message

        Yields:
            Response chunks as they're generated
        """
        logger.info(f"AgentExecutor.execute called for session {session.session_id}")

        # Get session lock to prevent concurrent execution
        lock = await self.session_manager.get_session_lock(session.session_id)
        logger.info(f"Session lock acquired for {session.session_id}")

        async with lock:
            try:
                # Build message history (most recent 10 messages only)
                messages = []
                history = session.get_history(limit=10)
                logger.info(f"Building message history, {len(history)} messages in history")

                for i, msg in enumerate(history):
                    logger.info(f"History message {i}: role={msg.role}, content_len={len(msg.content)}, content='{msg.content[:100]}'")
                    # Roles other than user/assistant are silently dropped here.
                    if msg.role == "user":
                        messages.append(HumanMessage(content=msg.content))
                    elif msg.role == "assistant":
                        messages.append(AIMessage(content=msg.content))

                logger.info(f"Prepared {len(messages)} messages for agent")
                for i, msg in enumerate(messages):
                    logger.info(f"LangChain message {i}: type={type(msg).__name__}, content_len={len(msg.content)}, content='{msg.content[:100] if msg.content else 'EMPTY'}'")

                # Prepare config with metadata; thread_id keys the LangGraph
                # checkpoint for this session.
                config = RunnableConfig(
                    configurable={
                        "thread_id": session.session_id
                    },
                    metadata={
                        "session_id": session.session_id,
                        "user_id": session.user_id,
                        "active_channels": session.active_channels
                    }
                )
                logger.info(f"Agent config prepared: thread_id={session.session_id}")

                # Invoke agent with streaming
                logger.info("Starting agent.astream_events...")
                full_response = ""
                event_count = 0
                chunk_count = 0

                async for event in self.agent.astream_events(
                    {"messages": messages},
                    config=config,
                    version="v2"
                ):
                    event_count += 1

                    # Check for cancellation
                    # NOTE(review): for a running task .cancelled() is
                    # typically False — cancellation normally surfaces as
                    # CancelledError at an await (handled below); confirm
                    # this branch can ever fire.
                    if asyncio.current_task().cancelled():
                        logger.warning("Agent execution cancelled")
                        break

                    # Log tool calls
                    if event["event"] == "on_tool_start":
                        tool_name = event.get("name", "unknown")
                        tool_input = event.get("data", {}).get("input", {})
                        logger.info(f"Tool call started: {tool_name} with input: {tool_input}")

                    elif event["event"] == "on_tool_end":
                        tool_name = event.get("name", "unknown")
                        tool_output = event.get("data", {}).get("output")
                        logger.info(f"Tool call completed: {tool_name} with output: {tool_output}")

                    # Extract streaming tokens
                    elif event["event"] == "on_chat_model_stream":
                        chunk = event["data"]["chunk"]
                        if hasattr(chunk, "content") and chunk.content:
                            content = chunk.content
                            # Handle both string and list content
                            if isinstance(content, list):
                                # Extract text from content blocks (Anthropic
                                # may emit a list of typed blocks rather than
                                # a plain string).
                                text_parts = []
                                for block in content:
                                    if isinstance(block, dict) and "text" in block:
                                        text_parts.append(block["text"])
                                    elif hasattr(block, "text"):
                                        text_parts.append(block.text)
                                content = "".join(text_parts)

                            if content:  # Only yield non-empty content
                                full_response += content
                                chunk_count += 1
                                logger.debug(f"Yielding content chunk #{chunk_count}")
                                yield content

                logger.info(f"Agent streaming complete: {event_count} events, {chunk_count} content chunks, {len(full_response)} chars")

                # Save to persistent memory (even when full_response is empty).
                logger.info(f"Saving assistant message to persistent memory")
                await self.session_manager.save_message(
                    session.session_id,
                    "assistant",
                    full_response
                )
                logger.info("Assistant message saved to memory")

            except asyncio.CancelledError:
                logger.warning(f"Agent execution cancelled for session {session.session_id}")
                # Clear checkpoint on cancellation to prevent orphaned tool calls
                await self._clear_checkpoint(session.session_id)
                raise
            except Exception as e:
                error_msg = f"Agent execution error: {str(e)}"
                logger.error(error_msg, exc_info=True)

                # Clear checkpoint on error to prevent invalid state (e.g., orphaned tool calls)
                # This ensures the next interaction starts fresh instead of trying to resume
                # from a broken state
                await self._clear_checkpoint(session.session_id)

                # The error text is surfaced to the user as a normal chunk.
                yield error_msg
|
||||
|
||||
|
||||
def create_agent(
    model_name: str = "claude-sonnet-4-20250514",
    temperature: float = 0.7,
    api_key: Optional[str] = None,
    checkpoint_db_path: str = "data/checkpoints.db",
    chroma_db_path: str = "data/chroma",
    embedding_model: str = "all-MiniLM-L6-v2",
    context_docs_dir: str = "doc",
    base_dir: str = "."
) -> AgentExecutor:
    """Create an agent executor wired to a local memory manager.

    The caller is still responsible for awaiting ``AgentExecutor.initialize()``
    before first use.

    Args:
        model_name: Anthropic model name
        temperature: Model temperature
        api_key: Anthropic API key
        checkpoint_db_path: Path to LangGraph checkpoint SQLite DB
        chroma_db_path: Path to ChromaDB storage directory
        embedding_model: Sentence-transformers model name
        context_docs_dir: Directory with context markdown files
        base_dir: Base directory for resolving paths

    Returns:
        A newly constructed AgentExecutor
    """
    # Build the persistence layer first, then hand it to the executor.
    memory = MemoryManager(
        checkpoint_db_path=checkpoint_db_path,
        chroma_db_path=chroma_db_path,
        embedding_model=embedding_model,
        context_docs_dir=context_docs_dir,
        base_dir=base_dir
    )

    return AgentExecutor(
        model_name=model_name,
        temperature=temperature,
        api_key=api_key,
        memory_manager=memory
    )
|
||||
380
backend/src/agent/memory.py
Normal file
380
backend/src/agent/memory.py
Normal file
@@ -0,0 +1,380 @@
|
||||
import os
|
||||
import glob
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
import aiofiles
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
||||
|
||||
# Prevent ChromaDB from reporting telemetry to the mothership
|
||||
os.environ["ANONYMIZED_TELEMETRY"] = "False"
|
||||
|
||||
class MemoryManager:
    """Manages persistent memory using local tools:

    - LangGraph checkpointing (SQLite) for conversation state
    - ChromaDB for semantic memory search
    - Local sentence-transformers for embeddings
    - Memory graph approach for clustering related concepts
    """

    def __init__(
        self,
        checkpoint_db_path: str = "data/checkpoints.db",
        chroma_db_path: str = "data/chroma",
        embedding_model: str = "all-MiniLM-L6-v2",
        context_docs_dir: str = "memory",
        base_dir: str = "."
    ):
        """Initialize memory manager.

        Args:
            checkpoint_db_path: Path to SQLite checkpoint database
            chroma_db_path: Path to ChromaDB directory
            embedding_model: Sentence-transformers model name
            context_docs_dir: Directory containing markdown context files
            base_dir: Base directory for resolving relative paths
        """
        self.checkpoint_db_path = checkpoint_db_path
        self.chroma_db_path = chroma_db_path
        self.embedding_model_name = embedding_model
        # Only the context docs dir is resolved against base_dir; the DB
        # paths are used as given. NOTE(review): confirm that asymmetry is
        # intended.
        self.context_docs_dir = os.path.join(base_dir, context_docs_dir)

        # Lazily initialized in initialize().
        self.checkpointer: Optional["AsyncSqliteSaver"] = None
        self.checkpointer_context: Optional[Any] = None  # entered context manager, exited in close()
        self.chroma_client: Optional[Any] = None
        self.memory_collection: Optional[Any] = None
        self.embedding_model: Optional[Any] = None

        # filename -> markdown content, populated by _load_context_documents()
        self.context_documents: Dict[str, str] = {}
        self.initialized = False

    async def initialize(self) -> None:
        """Initialize the memory system and load context documents.

        Idempotent: returns immediately once initialization has succeeded.
        """
        if self.initialized:
            return

        # Ensure data directories exist. Guard against a bare filename,
        # where os.path.dirname() returns "" and makedirs("") would raise.
        checkpoint_dir = os.path.dirname(self.checkpoint_db_path)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        os.makedirs(self.chroma_db_path, exist_ok=True)

        # Initialize LangGraph checkpointer (SQLite). from_conn_string
        # returns an async context manager; we enter it manually here and
        # exit it in close().
        self.checkpointer_context = AsyncSqliteSaver.from_conn_string(
            self.checkpoint_db_path
        )
        self.checkpointer = await self.checkpointer_context.__aenter__()
        await self.checkpointer.setup()

        # Initialize ChromaDB (local persistent store, telemetry disabled)
        self.chroma_client = chromadb.PersistentClient(
            path=self.chroma_db_path,
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True
            )
        )

        # Get or create memory collection
        self.memory_collection = self.chroma_client.get_or_create_collection(
            name="conversation_memory",
            metadata={"description": "Semantic memory for conversations"}
        )

        # Initialize local embedding model
        print(f"Loading embedding model: {self.embedding_model_name}")
        self.embedding_model = SentenceTransformer(self.embedding_model_name)

        # Load markdown context documents, then index them for semantic search
        await self._load_context_documents()
        await self._index_context_documents()

        self.initialized = True
        print("Memory system initialized (LangGraph + ChromaDB + local embeddings)")

    async def _load_context_documents(self) -> None:
        """Load all markdown files from context directory."""
        if not os.path.exists(self.context_docs_dir):
            print(f"Warning: Context directory {self.context_docs_dir} not found")
            return

        md_files = glob.glob(os.path.join(self.context_docs_dir, "*.md"))

        for md_file in md_files:
            try:
                async with aiofiles.open(md_file, "r", encoding="utf-8") as f:
                    content = await f.read()
                filename = os.path.basename(md_file)
                self.context_documents[filename] = content
                print(f"Loaded context document: {filename}")
            except Exception as e:
                # Best-effort: one unreadable file must not abort startup.
                print(f"Error loading {md_file}: {e}")

    async def _index_context_documents(self) -> None:
        """Index context documents in ChromaDB for semantic search."""
        if not self.context_documents or not self.memory_collection:
            return

        for filename, content in self.context_documents.items():
            # Split into sections (by markdown headers) so retrieval works
            # at section granularity rather than whole-file.
            sections = self._split_document_into_sections(content, filename)

            for i, section in enumerate(sections):
                doc_id = f"context_{filename}_{i}"

                # Generate embedding locally (no network calls)
                embedding = self.embedding_model.encode(section["content"]).tolist()

                # Add to ChromaDB
                self.memory_collection.add(
                    ids=[doc_id],
                    embeddings=[embedding],
                    documents=[section["content"]],
                    metadatas=[{
                        "type": "context",
                        "source": filename,
                        "section": section["title"],
                        # NOTE: utcnow() is deprecated in 3.12 but kept so
                        # existing IDs/timestamps stay format-compatible.
                        "indexed_at": datetime.utcnow().isoformat()
                    }]
                )

        print(f"Indexed {len(self.context_documents)} context documents")

    def _split_document_into_sections(self, content: str, filename: str) -> List[Dict[str, str]]:
        """Split markdown document into logical sections.

        A new section starts at every line beginning with '#'; text before
        the first header is grouped under a section titled with the filename.

        Args:
            content: Markdown content
            filename: Source filename

        Returns:
            List of section dicts with title and content
        """
        sections = []
        current_section = {"title": filename, "content": ""}

        for line in content.split("\n"):
            if line.startswith("#"):
                # New section; flush the previous one if it has any content.
                if current_section["content"].strip():
                    sections.append(current_section)
                current_section = {
                    "title": line.strip("#").strip(),
                    "content": line + "\n"
                }
            else:
                current_section["content"] += line + "\n"

        # Add last section
        if current_section["content"].strip():
            sections.append(current_section)

        return sections

    def get_context_prompt(self) -> str:
        """Generate a context prompt from loaded documents.

        system_prompt.md is ALWAYS included first and prioritized.
        Other documents are included after.

        Returns:
            Formatted string containing all context documents
        """
        if not self.context_documents:
            return ""

        sections = []

        # ALWAYS include system_prompt.md first if it exists
        system_prompt_key = "system_prompt.md"
        if system_prompt_key in self.context_documents:
            sections.append(self.context_documents[system_prompt_key])
            sections.append("\n---\n")

        # Add other context documents
        sections.append("# Additional Context\n")
        sections.append("The following documents provide additional context about the system:\n")

        for filename, content in sorted(self.context_documents.items()):
            # Skip system_prompt.md since we already added it
            if filename == system_prompt_key:
                continue

            sections.append(f"\n## {filename}\n")
            sections.append(content)

        return "\n".join(sections)

    async def add_memory(
        self,
        session_id: str,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """Add a message to semantic memory (ChromaDB).

        Silently returns if the memory system is not initialized; errors
        are printed, never raised (memory writes are best-effort).

        Args:
            session_id: Session identifier
            role: Message role ("user" or "assistant")
            content: Message content
            metadata: Optional metadata
        """
        if not self.memory_collection or not self.embedding_model:
            return

        try:
            # Generate a unique ID from session, role and timestamp.
            timestamp = datetime.utcnow().isoformat()
            doc_id = f"{session_id}_{role}_{timestamp}"

            # Generate embedding
            embedding = self.embedding_model.encode(content).tolist()

            # Prepare metadata; caller-supplied keys override the defaults.
            meta = {
                "session_id": session_id,
                "role": role,
                "timestamp": timestamp,
                "type": "conversation",
                **(metadata or {})
            }

            # Add to ChromaDB
            self.memory_collection.add(
                ids=[doc_id],
                embeddings=[embedding],
                documents=[content],
                metadatas=[meta]
            )
        except Exception as e:
            print(f"Error adding to ChromaDB memory: {e}")

    async def search_memory(
        self,
        session_id: str,
        query: str,
        limit: int = 5,
        include_context: bool = True
    ) -> List[Dict[str, Any]]:
        """Search memory using semantic similarity.

        Args:
            session_id: Session identifier (filters to this session + context docs)
            query: Search query
            limit: Maximum results
            include_context: Whether to include context documents in search

        Returns:
            List of dicts with keys "content", "metadata" and "distance"
            (distance is None when ChromaDB omits it)
        """
        if not self.memory_collection or not self.embedding_model:
            return []

        try:
            # Generate query embedding
            query_embedding = self.embedding_model.encode(query).tolist()

            # Build metadata filter: this session's messages, optionally
            # plus the indexed context documents.
            if include_context:
                where_filters: Dict[str, Any] = {
                    "$or": [
                        {"session_id": session_id},
                        {"type": "context"}
                    ]
                }
            else:
                where_filters = {"session_id": session_id}

            # Query ChromaDB
            results = self.memory_collection.query(
                query_embeddings=[query_embedding],
                n_results=limit,
                where=where_filters
            )

            # Format results (ChromaDB nests per-query lists; we sent one query)
            memories = []
            if results and results["documents"]:
                for i, doc in enumerate(results["documents"][0]):
                    memories.append({
                        "content": doc,
                        "metadata": results["metadatas"][0][i],
                        "distance": results["distances"][0][i] if "distances" in results else None
                    })

            return memories
        except Exception as e:
            print(f"Error searching ChromaDB memory: {e}")
            return []

    async def get_memory_graph(
        self,
        session_id: str,
        max_depth: int = 2
    ) -> Dict[str, Any]:
        """Get a graph of related memories using clustering.

        This creates a simple memory graph by finding clusters of related concepts.

        Args:
            session_id: Session identifier
            max_depth: Maximum depth for related memory traversal (currently unused)

        Returns:
            Dict with "nodes" and "edges" lists ({} on error)
        """
        # Simple implementation: get all memories for session and cluster by similarity
        if not self.memory_collection:
            return {}

        try:
            # Get all memories for this session (embeddings fetched for the
            # future edge computation below).
            results = self.memory_collection.get(
                where={"session_id": session_id},
                include=["embeddings", "documents", "metadatas"]
            )

            if not results or not results["documents"]:
                return {"nodes": [], "edges": []}

            # Build simple graph structure
            nodes = []
            edges = []

            for i, doc in enumerate(results["documents"]):
                nodes.append({
                    "id": results["ids"][i],
                    "content": doc,
                    "metadata": results["metadatas"][i]
                })

            # TODO: Compute edges based on embedding similarity
            # For now, return just nodes
            return {"nodes": nodes, "edges": edges}

        except Exception as e:
            print(f"Error building memory graph: {e}")
            return {}

    def get_checkpointer(self) -> Optional["AsyncSqliteSaver"]:
        """Get the LangGraph checkpointer for conversation state.

        Returns:
            AsyncSqliteSaver instance for LangGraph persistence
            (None before initialize() / after close())
        """
        return self.checkpointer

    async def close(self) -> None:
        """Close the memory manager and cleanup resources."""
        if self.checkpointer_context:
            await self.checkpointer_context.__aexit__(None, None, None)
            self.checkpointer = None
            self.checkpointer_context = None
|
||||
57
backend/src/agent/prompts.py
Normal file
57
backend/src/agent/prompts.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import List
|
||||
from gateway.user_session import UserSession
|
||||
|
||||
|
||||
def build_system_prompt(context: str, active_channels: List[str]) -> str:
    """Build the system prompt for the agent.

    The static portion of the prompt comes from system_prompt.md (carried
    inside *context*); this function appends the dynamic per-session details.

    Args:
        context: Context from loaded markdown documents (includes system_prompt.md)
        active_channels: List of active channel IDs for this session

    Returns:
        Formatted system prompt
    """
    if active_channels:
        channels_str = ", ".join(active_channels)
    else:
        channels_str = "none"

    # Session blurb is appended verbatim after the document context.
    session_info = (
        "\n\n## Current Session Information\n\n"
        f"**Active Channels**: {channels_str}\n\n"
        "Your responses will be sent to all active channels. Your responses are streamed back in real-time.\n"
        "If the user sends a new message while you're responding, your current response will be interrupted\n"
        "and you'll be re-invoked with the updated context.\n"
    )
    return context + session_info
|
||||
|
||||
|
||||
def build_user_prompt_with_history(session: UserSession, current_message: str) -> str:
    """Build a user prompt including conversation history.

    Args:
        session: User session with conversation history
        current_message: Current user message

    Returns:
        Formatted prompt with history, entries separated by blank lines
    """
    # Label each of the last 10 history entries; any non-"user" role is
    # rendered as Assistant.
    labeled = [
        f'{"User" if msg.role == "user" else "Assistant"}: {msg.content}'
        for msg in session.get_history(limit=10)
    ]

    # The current message always comes last.
    labeled.append(f"User: {current_message}")

    return "\n\n".join(labeled)
|
||||
93
backend/src/agent/session.py
Normal file
93
backend/src/agent/session.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import asyncio
|
||||
from typing import Dict, Optional
|
||||
from agent.memory import MemoryManager
|
||||
|
||||
|
||||
class SessionManager:
    """Manages agent sessions and their associated memory.

    Coordinates between the gateway's UserSession (for conversation state)
    and the agent's MemoryManager (for persistent local memory: LangGraph
    checkpoints plus ChromaDB semantic search).
    """

    def __init__(self, memory_manager: "MemoryManager"):
        """Initialize session manager.

        Args:
            memory_manager: MemoryManager instance for persistent storage
        """
        self.memory = memory_manager
        # One lock per session_id, created on demand in get_session_lock().
        self._locks: Dict[str, asyncio.Lock] = {}

    async def get_session_lock(self, session_id: str) -> asyncio.Lock:
        """Get or create a lock for a session.

        Prevents concurrent execution for the same session.

        Args:
            session_id: Session identifier

        Returns:
            asyncio.Lock for this session (same object on repeated calls)
        """
        if session_id not in self._locks:
            self._locks[session_id] = asyncio.Lock()
        return self._locks[session_id]

    async def save_message(
        self,
        session_id: str,
        role: str,
        content: str,
        metadata: Optional[Dict] = None
    ) -> None:
        """Save a message to persistent memory.

        Args:
            session_id: Session identifier
            role: Message role ("user" or "assistant")
            content: Message content
            metadata: Optional metadata
        """
        await self.memory.add_memory(session_id, role, content, metadata)

    async def get_relevant_context(
        self,
        session_id: str,
        query: str,
        limit: int = 5
    ) -> str:
        """Get relevant historical context for a query.

        Uses semantic search over past conversations.

        Args:
            session_id: Session identifier
            query: Search query
            limit: Maximum results

        Returns:
            Formatted string of relevant past messages ("" when none found)
        """
        results = await self.memory.search_memory(session_id, query, limit)

        if not results:
            return ""

        context_parts = ["## Relevant Past Context\n"]
        for i, result in enumerate(results, 1):
            # search_memory items are shaped {"content", "metadata", "distance"}:
            # the role lives inside metadata and relevance is the vector
            # distance (previously this read non-existent top-level "role"
            # and "score" keys, which raised KeyError at runtime).
            role = str(result.get("metadata", {}).get("role", "unknown")).capitalize()
            content = result["content"]
            distance = result.get("distance")
            score = distance if distance is not None else 0.0
            context_parts.append(f"{i}. [{role}, relevance: {score:.2f}] {content}\n")

        return "\n".join(context_parts)

    async def cleanup_session(self, session_id: str) -> None:
        """Cleanup session resources.

        Args:
            session_id: Session to cleanup
        """
        if session_id in self._locks:
            del self._locks[session_id]
|
||||
662
backend/src/agent/tools.py
Normal file
662
backend/src/agent/tools.py
Normal file
@@ -0,0 +1,662 @@
|
||||
from typing import Dict, Any, List, Optional
|
||||
import io
|
||||
import base64
|
||||
import sys
|
||||
import uuid
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from contextlib import redirect_stdout, redirect_stderr
|
||||
from langchain_core.tools import tool
|
||||
|
||||
# Module-level logger for tool diagnostics.
logger = logging.getLogger(__name__)

# Global registry instances (will be set by main.py at startup). The @tool
# functions below are plain module-level callables, so they reach the
# registries through these globals rather than dependency injection.
_registry = None
_datasource_registry = None


def set_registry(registry):
    """Set the global SyncRegistry instance for tools to use."""
    global _registry
    _registry = registry


def set_datasource_registry(datasource_registry):
    """Set the global DataSourceRegistry instance for tools to use."""
    global _datasource_registry
    _datasource_registry = datasource_registry
|
||||
|
||||
|
||||
@tool
def list_sync_stores() -> List[str]:
    """List all available synchronization stores.

    Returns:
        List of store names that can be read/written
    """
    # Empty list when the registry has not been wired up yet.
    return list(_registry.entries.keys()) if _registry else []
|
||||
|
||||
|
||||
@tool
def read_sync_state(store_name: str) -> Dict[str, Any]:
    """Read the current state of a synchronization store.

    Args:
        store_name: Name of the store to read (e.g., "TraderState", "StrategyState")

    Returns:
        Dictionary containing the current state of the store

    Raises:
        ValueError: If store_name doesn't exist
    """
    if not _registry:
        raise ValueError("SyncRegistry not initialized")

    # Unknown store: report what is available instead of a bare failure.
    if not (entry := _registry.entries.get(store_name)):
        available = list(_registry.entries.keys())
        raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")

    # Serialize the store's pydantic model to plain JSON-compatible types.
    return entry.model.model_dump(mode="json")
|
||||
|
||||
|
||||
@tool
async def write_sync_state(store_name: str, updates: Dict[str, Any]) -> Dict[str, str]:
    """Update the state of a synchronization store.

    This will apply the updates to the store and trigger synchronization
    with all connected clients.

    Args:
        store_name: Name of the store to update
        updates: Dictionary of field updates (field_name: new_value)

    Returns:
        Dictionary with status and updated fields

    Raises:
        ValueError: If store_name doesn't exist or updates are invalid
    """
    if not _registry:
        raise ValueError("SyncRegistry not initialized")

    entry = _registry.entries.get(store_name)
    if not entry:
        available = list(_registry.entries.keys())
        raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")

    try:
        # Merge the partial update over a JSON snapshot of the current state.
        current_state = entry.model.model_dump(mode="json")
        new_state = {**current_state, **updates}

        # Write the merged state back into the store's pydantic model.
        _registry._update_model(entry.model, new_state)

        # Push the new state to all connected clients.
        await _registry.push_all()

        return {
            "status": "success",
            "store": store_name,
            "updated_fields": list(updates.keys())
        }

    except Exception as e:
        # FIX: chain with `from e` so the root-cause traceback stays attached
        # as __cause__ instead of relying on implicit context only.
        raise ValueError(f"Failed to update store '{store_name}': {str(e)}") from e
|
||||
|
||||
|
||||
@tool
def get_store_schema(store_name: str) -> Dict[str, Any]:
    """Get the schema/structure of a synchronization store.

    This shows what fields are available and their types.

    Args:
        store_name: Name of the store

    Returns:
        Dictionary describing the store's schema

    Raises:
        ValueError: If store_name doesn't exist
    """
    if not _registry:
        raise ValueError("SyncRegistry not initialized")

    entry = _registry.entries.get(store_name)
    if not entry:
        available = list(_registry.entries.keys())
        raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")

    # Pydantic emits a JSON-Schema description of the model's fields.
    return {"store_name": store_name, "schema": entry.model.model_json_schema()}
|
||||
|
||||
|
||||
# DataSource tools
|
||||
|
||||
@tool
def list_data_sources() -> List[str]:
    """List all available data sources.

    Returns:
        List of data source names that can be queried for market data
    """
    # Before main.py injects the registry there are no sources to report.
    return _datasource_registry.list_sources() if _datasource_registry else []
|
||||
|
||||
|
||||
@tool
async def search_symbols(
    query: str,
    type: Optional[str] = None,
    exchange: Optional[str] = None,
    limit: int = 30,
) -> Dict[str, Any]:
    """Search for trading symbols across all data sources.

    Automatically searches all available data sources and returns aggregated results.
    Use this to find symbols before calling get_symbol_info or get_historical_data.

    Args:
        query: Search query (e.g., "BTC", "AAPL", "EUR")
        type: Optional filter by instrument type (e.g., "crypto", "stock", "forex")
        exchange: Optional filter by exchange (e.g., "binance", "nasdaq")
        limit: Maximum number of results per source (default: 30)

    Returns:
        Dictionary mapping source names to lists of matching symbols.
        Each symbol includes: symbol, full_name, description, exchange, type.
        Use the source name and symbol from results with get_symbol_info or get_historical_data.

    Example response:
        {
            "demo": [
                {
                    "symbol": "BTC/USDT",
                    "full_name": "Bitcoin / Tether USD",
                    "description": "Bitcoin perpetual futures",
                    "exchange": "demo",
                    "type": "crypto"
                }
            ]
        }
    """
    if not _datasource_registry:
        raise ValueError("DataSourceRegistry not initialized")

    # Fan the query out to every registered source, then serialize each
    # pydantic SearchResult so the response is plain JSON-safe data.
    matches_by_source = await _datasource_registry.search_all(query, type, exchange, limit)
    serialized: Dict[str, Any] = {}
    for source, matches in matches_by_source.items():
        serialized[source] = [match.model_dump() for match in matches]
    return serialized
|
||||
|
||||
|
||||
@tool
async def get_symbol_info(source_name: str, symbol: str) -> Dict[str, Any]:
    """Get complete metadata for a trading symbol.

    This retrieves full information about a symbol including:
    - Description and type
    - Supported time resolutions
    - Available data columns (OHLCV, volume, funding rates, etc.)
    - Trading session information
    - Price scale and precision

    Args:
        source_name: Name of the data source (use list_data_sources to see available)
        symbol: Symbol identifier (e.g., "BTC/USDT", "AAPL", "EUR/USD")

    Returns:
        Dictionary containing complete symbol metadata including column schema

    Raises:
        ValueError: If source_name or symbol is not found
    """
    if not _datasource_registry:
        raise ValueError("DataSourceRegistry not initialized")

    # resolve_symbol raises if the source or symbol is unknown.
    info = await _datasource_registry.resolve_symbol(source_name, symbol)
    return info.model_dump()
|
||||
|
||||
|
||||
@tool
async def get_historical_data(
    source_name: str,
    symbol: str,
    resolution: str,
    from_time: int,
    to_time: int,
    countback: Optional[int] = None,
) -> Dict[str, Any]:
    """Get historical bar/candle data for a symbol.

    Retrieves time-series data between the specified timestamps. The data
    includes all columns defined for the symbol (OHLCV + any custom columns).

    Args:
        source_name: Name of the data source
        symbol: Symbol identifier
        resolution: Time resolution (e.g., "1" = 1min, "5" = 5min, "60" = 1hour, "1D" = 1day)
        from_time: Start time as Unix timestamp in seconds
        to_time: End time as Unix timestamp in seconds
        countback: Optional limit on number of bars to return

    Returns:
        Dictionary containing:
        - symbol: The requested symbol
        - resolution: The time resolution
        - bars: List of bar data with 'time' and 'data' fields
        - columns: Schema describing available data columns
        - nextTime: If present, indicates more data is available for pagination

    Raises:
        ValueError: If source, symbol, or resolution is invalid

    Example:
        # Get 1-hour BTC data for the last 24 hours
        import time
        to_time = int(time.time())
        from_time = to_time - 86400  # 24 hours ago
        data = get_historical_data("demo", "BTC/USDT", "60", from_time, to_time)
    """
    if not _datasource_registry:
        raise ValueError("DataSourceRegistry not initialized")

    # Fail fast with the list of valid sources when the name is wrong.
    if not (source := _datasource_registry.get(source_name)):
        available = _datasource_registry.list_sources()
        raise ValueError(f"Data source '{source_name}' not found. Available sources: {available}")

    history = await source.get_bars(symbol, resolution, from_time, to_time, countback)
    return history.model_dump()
|
||||
|
||||
|
||||
async def _get_chart_data_impl(countback: Optional[int] = None):
    """Internal implementation for getting chart data.

    This is a helper function that can be called by both get_chart_data tool
    and analyze_chart_data tool.

    Args:
        countback: Optional cap on the number of bars fetched from the source.

    Returns:
        Tuple of (HistoryResult, chart_context dict, source_name)

    Raises:
        ValueError: If a registry is missing, no chart/symbol is loaded, the
            symbol is not "EXCHANGE:SYMBOL"-formatted, the exchange has no
            matching data source, or the visible time range is unset.
    """
    if not _registry:
        raise ValueError("SyncRegistry not initialized - cannot read ChartStore")

    if not _datasource_registry:
        raise ValueError("DataSourceRegistry not initialized - cannot query data")

    # Read current chart state
    chart_store = _registry.entries.get("ChartStore")
    if not chart_store:
        raise ValueError("ChartStore not found in registry")

    # The frontend mirrors the user's visible chart into this nested dict.
    chart_state = chart_store.model.model_dump(mode="json")
    chart_data = chart_state.get("chart_state", {})

    symbol = chart_data.get("symbol", "")
    interval = chart_data.get("interval", "15")
    start_time = chart_data.get("start_time")
    end_time = chart_data.get("end_time")

    if not symbol:
        raise ValueError("No symbol set in ChartStore - user may not have loaded a chart yet")

    # Parse the symbol to extract exchange/source and symbol name
    # Format is "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD")
    if ":" not in symbol:
        raise ValueError(
            f"Invalid symbol format: '{symbol}'. Expected format is 'EXCHANGE:SYMBOL' "
            f"(e.g., 'BINANCE:BTC/USDT' or 'DEMO:BTC/USD')"
        )

    exchange_prefix, symbol_name = symbol.split(":", 1)
    # Data sources are registered under lowercase names, so normalize the prefix.
    source_name = exchange_prefix.lower()

    # Get the data source
    source = _datasource_registry.get(source_name)
    if not source:
        available = _datasource_registry.list_sources()
        raise ValueError(
            f"Data source '{source_name}' not found. Available sources: {available}. "
            f"Make sure the exchange in the symbol '{symbol}' matches an available source."
        )

    # Determine time range - REQUIRE it to be set, no defaults
    if start_time is None or end_time is None:
        raise ValueError(
            f"Chart time range not set in ChartStore. start_time={start_time}, end_time={end_time}. "
            f"The user needs to load the chart first, or the frontend may not be sending the visible range. "
            f"Wait for the chart to fully load before analyzing data."
        )

    # NOTE(review): end_time is re-bound to an int here while start_time keeps
    # its original (possibly non-int) value for chart_context below — confirm
    # this asymmetry is intentional.
    from_time = int(start_time)
    end_time = int(end_time)
    logger.info(
        f"Using ChartStore time range: from_time={from_time}, end_time={end_time}, "
        f"countback={countback}"
    )

    logger.info(
        f"Querying data source '{source_name}' for symbol '{symbol_name}', "
        f"resolution '{interval}'"
    )

    # Query the data source
    result = await source.get_bars(
        symbol=symbol_name,
        resolution=interval,
        from_time=from_time,
        to_time=end_time,
        countback=countback
    )

    logger.info(
        f"Received {len(result.bars)} bars from data source. "
        f"First bar time: {result.bars[0].time if result.bars else 'N/A'}, "
        f"Last bar time: {result.bars[-1].time if result.bars else 'N/A'}"
    )

    # Build chart context to return along with result
    chart_context = {
        "symbol": symbol,
        "interval": interval,
        "start_time": start_time,
        "end_time": end_time
    }

    return result, chart_context, source_name
|
||||
|
||||
|
||||
@tool
async def get_chart_data(countback: Optional[int] = None) -> Dict[str, Any]:
    """Get the candle/bar data for what the user is currently viewing on their chart.

    This is a convenience tool that automatically:
    1. Reads the ChartStore to see what chart the user is viewing
    2. Parses the symbol to determine the data source (exchange prefix)
    3. Queries the appropriate data source for that symbol's data
    4. Returns the data for the visible time range and interval

    This is the preferred way to access chart data when helping the user analyze
    what they're looking at, since it automatically uses their current chart context.

    Args:
        countback: Optional limit on number of bars to return. If not specified,
            returns all bars in the visible time range.

    Returns:
        Dictionary containing:
        - chart_context: Current chart state (symbol, interval, time range)
        - symbol: The trading pair being viewed
        - resolution: The chart interval
        - bars: List of bar data with 'time' and 'data' fields
        - columns: Schema describing available data columns
        - source: Which data source was used

    Raises:
        ValueError: If ChartStore or DataSourceRegistry is not initialized,
            or if the symbol format is invalid

    Example:
        # User is viewing BINANCE:BTC/USDT on 15min chart
        data = get_chart_data()
        # Returns BTC/USDT data from binance source at 15min resolution
        # for the currently visible time range
    """
    history, chart_context, source_name = await _get_chart_data_impl(countback)

    # Serialize the HistoryResult and tack on the chart context so the
    # caller can tell which chart/source the bars came from.
    payload = history.model_dump()
    payload["chart_context"] = chart_context
    payload["source"] = source_name
    return payload
|
||||
|
||||
|
||||
@tool
async def analyze_chart_data(python_script: str, countback: Optional[int] = None) -> Dict[str, Any]:
    """Analyze the current chart data using a Python script with pandas and matplotlib.

    This tool:
    1. Gets the current chart data (same as get_chart_data)
    2. Converts it to a pandas DataFrame with columns: time, open, high, low, close, volume
    3. Executes your Python script with access to the DataFrame as 'df'
    4. Saves any matplotlib plots to disk and returns URLs to access them
    5. Returns any final DataFrame result and plot URLs

    The script has access to:
    - `df`: pandas DataFrame with OHLCV data indexed by datetime
    - `pandas` (as `pd`): For data manipulation
    - `numpy` (as `np`): For numerical operations
    - `matplotlib.pyplot` (as `plt`): For plotting (use plt.figure() for each plot)

    All matplotlib figures are automatically saved to disk and accessible via URLs.
    The last expression in the script (if it's a DataFrame) is returned as the result.

    Args:
        python_script: Python code to execute. The DataFrame is available as 'df'.
            Can use pandas, numpy, matplotlib. Return a DataFrame to include it in results.
        countback: Optional limit on number of bars to analyze

    Returns:
        Dictionary containing:
        - chart_context: Current chart state (symbol, interval, time range)
        - source: Data source used
        - script_output: Any printed output from the script
        - result_dataframe: If script returns a DataFrame, it's included here as dict
        - plot_urls: List of URLs to saved plot images (one per plt.figure())
        - error: Error message if script execution failed

    Example scripts:
        # Calculate 20-period SMA and plot
        ```python
        df['SMA20'] = df['close'].rolling(20).mean()
        plt.figure(figsize=(12, 6))
        plt.plot(df.index, df['close'], label='Close')
        plt.plot(df.index, df['SMA20'], label='SMA20')
        plt.legend()
        plt.title('Price with SMA')
        df[['close', 'SMA20']].tail(10)  # Return last 10 rows
        ```

        # Calculate RSI
        ```python
        delta = df['close'].diff()
        gain = (delta.where(delta > 0, 0)).rolling(14).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
        rs = gain / loss
        df['RSI'] = 100 - (100 / (1 + rs))
        df[['close', 'RSI']].tail(20)
        ```

        # Multiple plots
        ```python
        # Price chart
        plt.figure(figsize=(12, 4))
        plt.plot(df['close'])
        plt.title('Price')

        # Volume chart
        plt.figure(figsize=(12, 3))
        plt.bar(df.index, df['volume'])
        plt.title('Volume')

        df.describe()  # Return statistics
        ```
    """
    if not _registry:
        raise ValueError("SyncRegistry not initialized - cannot read ChartStore")

    if not _datasource_registry:
        raise ValueError("DataSourceRegistry not initialized - cannot query data")

    try:
        # Import pandas and numpy here to allow lazy loading
        import pandas as pd
        import numpy as np
        import matplotlib
        matplotlib.use('Agg')  # Non-interactive backend
        import matplotlib.pyplot as plt
    except ImportError as e:
        # FIX: chain with `from e` so the underlying ImportError stays visible.
        raise ValueError(
            f"Required library not installed: {e}. "
            "Please install pandas, numpy, and matplotlib: pip install pandas numpy matplotlib"
        ) from e

    # Get chart data using the internal helper function
    result, chart_context, source_name = await _get_chart_data_impl(countback)

    # Build the same response format as get_chart_data
    chart_data = result.model_dump()
    chart_data["chart_context"] = chart_context
    chart_data["source"] = source_name

    # Convert bars to DataFrame
    bars = chart_data.get('bars', [])
    if not bars:
        return {
            "chart_context": chart_data.get('chart_context', {}),
            "source": chart_data.get('source', ''),
            "error": "No data available for the current chart"
        }

    # Build DataFrame
    rows = []
    for bar in bars:
        row = {
            'time': pd.to_datetime(bar['time'], unit='s'),
            **bar['data']  # Includes open, high, low, close, volume, etc.
        }
        rows.append(row)

    df = pd.DataFrame(rows)
    df.set_index('time', inplace=True)

    # Convert price columns to float for clean numeric operations
    price_columns = ['open', 'high', 'low', 'close', 'volume']
    for col in price_columns:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')

    logger.info(
        f"Created DataFrame with {len(df)} rows, columns: {df.columns.tolist()}, "
        f"time range: {df.index.min()} to {df.index.max()}, "
        f"dtypes: {df.dtypes.to_dict()}"
    )

    # Prepare execution environment.
    # SECURITY NOTE: python_script is executed with exec() and is NOT sandboxed.
    # This is only acceptable while the script originates from the trusted
    # agent itself, never directly from untrusted user input.
    script_globals = {
        'df': df,
        'pd': pd,
        'np': np,
        'plt': plt,
    }

    # Capture stdout/stderr
    stdout_capture = io.StringIO()
    stderr_capture = io.StringIO()

    result_df = None
    error_msg = None
    plot_urls = []

    # Determine uploads directory (relative to this file)
    uploads_dir = Path(__file__).parent.parent.parent / "uploads"
    uploads_dir.mkdir(exist_ok=True)

    try:
        with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
            # Execute the script
            exec(python_script, script_globals)

            # Check if the last line is an expression that returns a DataFrame
            # We'll try to evaluate it separately
            script_lines = python_script.strip().split('\n')
            if script_lines:
                last_line = script_lines[-1].strip()
                # Only evaluate if it doesn't look like a statement
                if last_line and not any(last_line.startswith(kw) for kw in ['if', 'for', 'while', 'def', 'class', 'import', 'from', 'with', 'try', 'return']):
                    try:
                        last_result = eval(last_line, script_globals)
                        if isinstance(last_result, pd.DataFrame):
                            result_df = last_result
                    except Exception:
                        # FIX: narrowed from a bare `except:` so that
                        # KeyboardInterrupt/SystemExit are no longer swallowed.
                        # If eval fails, that's okay - might not be an expression.
                        pass

            # Save all matplotlib figures to disk
            for fig_num in plt.get_fignums():
                fig = plt.figure(fig_num)

                # Generate unique filename
                plot_id = str(uuid.uuid4())
                filename = f"plot_{plot_id}.png"
                filepath = uploads_dir / filename

                # Save figure to file
                fig.savefig(filepath, format='png', bbox_inches='tight', dpi=100)

                # FIX: interpolate the saved file's name into the URL. The
                # previous literal "/uploads/(unknown)" meant every returned
                # plot link pointed at a nonexistent file.
                plot_url = f"/uploads/{filename}"
                plot_urls.append(plot_url)

                plt.close(fig)

    except Exception as e:
        error_msg = f"{type(e).__name__}: {str(e)}"
        import traceback
        error_msg += f"\n{traceback.format_exc()}"

    # Build response
    response = {
        "chart_context": chart_data.get('chart_context', {}),
        "source": chart_data.get('source', ''),
        "script_output": stdout_capture.getvalue(),
    }

    if error_msg:
        response["error"] = error_msg
        response["stderr"] = stderr_capture.getvalue()

    if result_df is not None:
        # Convert DataFrame to dict for JSON serialization
        response["result_dataframe"] = {
            "columns": result_df.columns.tolist(),
            "index": result_df.index.astype(str).tolist() if hasattr(result_df.index, 'astype') else result_df.index.tolist(),
            "data": result_df.values.tolist(),
            "shape": result_df.shape,
        }

    if plot_urls:
        response["plot_urls"] = plot_urls

    return response
|
||||
|
||||
|
||||
# Export all tools

# Tools that read, write, and introspect the shared synchronization stores.
# Registered with the agent by main.py.
SYNC_TOOLS = [
    list_sync_stores,
    read_sync_state,
    write_sync_state,
    get_store_schema
]
|
||||
|
||||
# Tools that discover and query market-data sources, including the two
# chart-context helpers that follow the user's currently loaded chart.
DATASOURCE_TOOLS = [
    list_data_sources,
    search_symbols,
    get_symbol_info,
    get_historical_data,
    get_chart_data,
    analyze_chart_data
]
|
||||
23
backend/src/datasource/__init__.py
Normal file
23
backend/src/datasource/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from .base import DataSource
|
||||
from .schema import (
|
||||
ColumnInfo,
|
||||
SymbolInfo,
|
||||
Bar,
|
||||
HistoryResult,
|
||||
DatafeedConfig,
|
||||
Resolution,
|
||||
SearchResult,
|
||||
)
|
||||
from .registry import DataSourceRegistry
|
||||
|
||||
# Public API of the datasource package.
__all__ = [
    "DataSource",
    "ColumnInfo",
    "SymbolInfo",
    "Bar",
    "HistoryResult",
    "DatafeedConfig",
    "Resolution",
    "SearchResult",
    "DataSourceRegistry",
]
|
||||
3
backend/src/datasource/adapters/__init__.py
Normal file
3
backend/src/datasource/adapters/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .ccxt_adapter import CCXTDataSource
|
||||
|
||||
# Public API of the adapters subpackage.
__all__ = ["CCXTDataSource"]
|
||||
526
backend/src/datasource/adapters/ccxt_adapter.py
Normal file
526
backend/src/datasource/adapters/ccxt_adapter.py
Normal file
@@ -0,0 +1,526 @@
|
||||
"""
|
||||
CCXT DataSource adapter for accessing cryptocurrency exchange data.
|
||||
|
||||
This adapter provides access to hundreds of cryptocurrency exchanges through
|
||||
the free CCXT library (not ccxt.pro), supporting both historical data and
|
||||
polling-based subscriptions.
|
||||
|
||||
Numerical Precision:
|
||||
- Uses Decimal for all monetary values (prices, volumes) to avoid floating-point errors
|
||||
- CCXT returns numeric values as strings or floats depending on configuration
|
||||
- All financial values are converted to Decimal to maintain precision
|
||||
|
||||
Real-time Updates:
|
||||
- Uses polling instead of WebSocket (free CCXT doesn't have WebSocket support)
|
||||
- Default polling interval: 60 seconds (configurable)
|
||||
- Simulates real-time subscriptions by periodically fetching latest bars
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from typing import Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
import ccxt.async_support as ccxt
|
||||
|
||||
from ..base import DataSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from ..schema import (
|
||||
Bar,
|
||||
ColumnInfo,
|
||||
DatafeedConfig,
|
||||
HistoryResult,
|
||||
Resolution,
|
||||
SearchResult,
|
||||
SymbolInfo,
|
||||
)
|
||||
|
||||
|
||||
class CCXTDataSource(DataSource):
|
||||
"""
|
||||
DataSource adapter for CCXT cryptocurrency exchanges (free version).
|
||||
|
||||
Provides access to:
|
||||
- Multiple cryptocurrency exchanges (Binance, Coinbase, Kraken, etc.)
|
||||
- Historical OHLCV data via REST API
|
||||
- Polling-based real-time updates (configurable interval)
|
||||
- Symbol search and metadata
|
||||
|
||||
Args:
|
||||
exchange_id: CCXT exchange identifier (e.g., 'binance', 'coinbase', 'kraken')
|
||||
config: Optional exchange-specific configuration (API keys, options)
|
||||
sandbox: Whether to use sandbox/testnet mode (default: False)
|
||||
poll_interval: Interval in seconds for polling updates (default: 60)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
exchange_id: str = "binance",
|
||||
config: Optional[Dict] = None,
|
||||
sandbox: bool = False,
|
||||
poll_interval: int = 60,
|
||||
):
|
||||
self.exchange_id = exchange_id
|
||||
self._config = config or {}
|
||||
self._sandbox = sandbox
|
||||
self._poll_interval = poll_interval
|
||||
|
||||
# Initialize exchange (using free async_support, not pro)
|
||||
exchange_class = getattr(ccxt, exchange_id)
|
||||
self.exchange = exchange_class(self._config)
|
||||
|
||||
if sandbox and hasattr(self.exchange, 'set_sandbox_mode'):
|
||||
self.exchange.set_sandbox_mode(True)
|
||||
|
||||
# Cache for markets
|
||||
self._markets: Optional[Dict] = None
|
||||
self._markets_loaded = False
|
||||
|
||||
# Active subscriptions (polling-based)
|
||||
self._subscriptions: Dict[str, asyncio.Task] = {}
|
||||
self._subscription_callbacks: Dict[str, Callable] = {}
|
||||
self._last_bars: Dict[str, int] = {} # Track last bar timestamp per subscription
|
||||
|
||||
@staticmethod
|
||||
def _to_decimal(value: Union[str, int, float, Decimal, None]) -> Optional[Decimal]:
|
||||
"""
|
||||
Convert a value to Decimal for numerical precision.
|
||||
|
||||
Handles CCXT's mixed output (strings, floats, ints, None).
|
||||
Converts floats by converting to string first to avoid precision loss.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, Decimal):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return Decimal(value)
|
||||
if isinstance(value, (int, float)):
|
||||
# Convert to string first to avoid float precision issues
|
||||
return Decimal(str(value))
|
||||
return None
|
||||
|
||||
async def _ensure_markets_loaded(self):
|
||||
"""Ensure markets are loaded from exchange"""
|
||||
if not self._markets_loaded:
|
||||
self._markets = await self.exchange.load_markets()
|
||||
self._markets_loaded = True
|
||||
|
||||
async def get_config(self) -> DatafeedConfig:
|
||||
"""Get datafeed configuration"""
|
||||
await self._ensure_markets_loaded()
|
||||
|
||||
# Determine supported resolutions based on exchange capabilities
|
||||
supported_resolutions = [
|
||||
Resolution.M1,
|
||||
Resolution.M5,
|
||||
Resolution.M15,
|
||||
Resolution.M30,
|
||||
Resolution.H1,
|
||||
Resolution.H4,
|
||||
Resolution.D1,
|
||||
]
|
||||
|
||||
# Get unique exchange names (most CCXT exchanges are just one)
|
||||
exchanges = [self.exchange_id.upper()]
|
||||
|
||||
return DatafeedConfig(
|
||||
name=f"CCXT {self.exchange_id.title()}",
|
||||
description=f"Live and historical cryptocurrency data from {self.exchange_id} via CCXT library. "
|
||||
f"Supports OHLCV data for {len(self._markets) if self._markets else 'many'} trading pairs.",
|
||||
supported_resolutions=supported_resolutions,
|
||||
supports_search=True,
|
||||
supports_time=True,
|
||||
exchanges=exchanges,
|
||||
symbols_types=["crypto", "spot", "futures", "swap"],
|
||||
)
|
||||
|
||||
async def search_symbols(
|
||||
self,
|
||||
query: str,
|
||||
type: Optional[str] = None,
|
||||
exchange: Optional[str] = None,
|
||||
limit: int = 30,
|
||||
) -> List[SearchResult]:
|
||||
"""Search for symbols on the exchange"""
|
||||
await self._ensure_markets_loaded()
|
||||
|
||||
query_upper = query.upper()
|
||||
results = []
|
||||
|
||||
for symbol, market in self._markets.items():
|
||||
# Match query against symbol or base/quote currencies
|
||||
if (query_upper in symbol or
|
||||
query_upper in market.get('base', '') or
|
||||
query_upper in market.get('quote', '')):
|
||||
|
||||
# Filter by type if specified
|
||||
market_type = market.get('type', 'spot')
|
||||
if type and market_type != type:
|
||||
continue
|
||||
|
||||
# Create search result
|
||||
base = market.get('base', '')
|
||||
quote = market.get('quote', '')
|
||||
|
||||
results.append(
|
||||
SearchResult(
|
||||
symbol=f"{base}/{quote}", # Clean user-facing format
|
||||
ticker=f"{self.exchange_id.upper()}:{symbol}", # Ticker with exchange prefix for routing
|
||||
full_name=f"{base}/{quote} ({self.exchange_id.upper()})",
|
||||
description=f"{base}/{quote} {market_type} trading pair on {self.exchange_id}",
|
||||
exchange=self.exchange_id.upper(),
|
||||
type=market_type,
|
||||
)
|
||||
)
|
||||
|
||||
if len(results) >= limit:
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
    async def resolve_symbol(self, symbol: str) -> SymbolInfo:
        """Get complete metadata for a symbol.

        Args:
            symbol: CCXT market key (e.g. "BTC/USDT") as listed by load_markets().

        Returns:
            SymbolInfo with display names, session info, supported resolutions,
            the OHLCV column schema, and the derived price scale.

        Raises:
            ValueError: If the symbol is not listed on this exchange.
        """
        await self._ensure_markets_loaded()

        if symbol not in self._markets:
            raise ValueError(f"Symbol '{symbol}' not found on {self.exchange_id}")

        market = self._markets[symbol]
        base = market.get('base', '')
        quote = market.get('quote', '')
        market_type = market.get('type', 'spot')

        # Determine price scale from market precision
        # CCXT precision can be in different modes:
        # - DECIMAL_PLACES (int): number of decimal places (e.g., 2 = 0.01)
        # - TICK_SIZE (float): actual tick size (e.g., 0.01, 0.00001)
        # We need to convert to pricescale (10^n where n is decimal places)
        price_precision = market.get('precision', {}).get('price', 2)

        if isinstance(price_precision, float):
            # TICK_SIZE mode: precision is the actual tick size (e.g., 0.01, 0.00001)
            # Convert tick size to decimal places
            # For 0.01 -> 2 decimal places, 0.00001 -> 5 decimal places
            # Decimal(str(...)) gives an exact textual representation of the tick.
            tick_str = str(Decimal(str(price_precision)))
            if '.' in tick_str:
                # rstrip('0') drops trailing zeros so 0.0100 still maps to 2 places.
                decimal_places = len(tick_str.split('.')[1].rstrip('0'))
            else:
                decimal_places = 0
            pricescale = 10 ** decimal_places
        else:
            # DECIMAL_PLACES or SIGNIFICANT_DIGITS mode: precision is an integer
            # Assume DECIMAL_PLACES mode (most common for price)
            pricescale = 10 ** int(price_precision)

        return SymbolInfo(
            symbol=f"{base}/{quote}",  # Clean user-facing format
            ticker=f"{self.exchange_id.upper()}:{symbol}",  # Ticker with exchange prefix for routing
            name=f"{base}/{quote}",
            description=f"{base}/{quote} {market_type} pair on {self.exchange_id}. "
                        f"Minimum order: {market.get('limits', {}).get('amount', {}).get('min', 'N/A')} {base}",
            type=market_type,
            exchange=self.exchange_id.upper(),
            # Crypto markets trade around the clock in UTC.
            timezone="Etc/UTC",
            session="24x7",
            supported_resolutions=[
                Resolution.M1,
                Resolution.M5,
                Resolution.M15,
                Resolution.M30,
                Resolution.H1,
                Resolution.H4,
                Resolution.D1,
            ],
            has_intraday=True,
            has_daily=True,
            has_weekly_and_monthly=False,
            columns=[
                ColumnInfo(
                    name="open",
                    type="decimal",
                    description=f"Opening price in {quote}",
                    unit=quote,
                ),
                ColumnInfo(
                    name="high",
                    type="decimal",
                    description=f"Highest price in {quote}",
                    unit=quote,
                ),
                ColumnInfo(
                    name="low",
                    type="decimal",
                    description=f"Lowest price in {quote}",
                    unit=quote,
                ),
                ColumnInfo(
                    name="close",
                    type="decimal",
                    description=f"Closing price in {quote}",
                    unit=quote,
                ),
                ColumnInfo(
                    name="volume",
                    type="decimal",
                    description=f"Trading volume in {base}",
                    unit=base,
                ),
            ],
            time_column="time",
            has_ohlcv=True,
            pricescale=pricescale,
            minmov=1,
            base_currency=base,
            quote_currency=quote,
        )
|
||||
|
||||
def _resolution_to_timeframe(self, resolution: str) -> str:
|
||||
"""Convert our resolution format to CCXT timeframe format"""
|
||||
# Map our resolutions to CCXT timeframes
|
||||
mapping = {
|
||||
"1": "1m",
|
||||
"5": "5m",
|
||||
"15": "15m",
|
||||
"30": "30m",
|
||||
"60": "1h",
|
||||
"120": "2h",
|
||||
"240": "4h",
|
||||
"360": "6h",
|
||||
"720": "12h",
|
||||
"1D": "1d",
|
||||
"1W": "1w",
|
||||
"1M": "1M",
|
||||
}
|
||||
return mapping.get(resolution, "1m")
|
||||
|
||||
def _timeframe_to_milliseconds(self, timeframe: str) -> int:
|
||||
"""Convert CCXT timeframe to milliseconds"""
|
||||
unit = timeframe[-1]
|
||||
amount = int(timeframe[:-1]) if len(timeframe) > 1 else 1
|
||||
|
||||
units = {
|
||||
's': 1000,
|
||||
'm': 60 * 1000,
|
||||
'h': 60 * 60 * 1000,
|
||||
'd': 24 * 60 * 60 * 1000,
|
||||
'w': 7 * 24 * 60 * 60 * 1000,
|
||||
'M': 30 * 24 * 60 * 60 * 1000, # Approximate
|
||||
}
|
||||
|
||||
return amount * units.get(unit, 60000)
|
||||
|
||||
async def get_bars(
    self,
    symbol: str,
    resolution: str,
    from_time: int,
    to_time: int,
    countback: Optional[int] = None,
) -> HistoryResult:
    """Get historical bars from the exchange.

    Fetches OHLCV candles in batches until the requested window is covered
    (or `countback` bars are collected), converts them to Bar objects with
    Decimal precision, and returns them with the symbol's column schema.

    Args:
        symbol: CCXT market symbol; must exist on this exchange.
        resolution: Chart resolution string (e.g. "1", "60", "1D").
        from_time: Range start, Unix seconds (inclusive).
        to_time: Range end, Unix seconds (exclusive).
        countback: Optional cap on the number of bars to fetch.

    Returns:
        HistoryResult with bars, column metadata, and optional nextTime.

    Raises:
        ValueError: If the symbol is unknown, or if the exchange fetch fails
            (the original exception is attached as __cause__).
    """
    logger.info(
        f"CCXTDataSource({self.exchange_id}).get_bars: symbol={symbol}, resolution={resolution}, "
        f"from_time={from_time}, to_time={to_time}, countback={countback}"
    )

    await self._ensure_markets_loaded()

    if symbol not in self._markets:
        raise ValueError(f"Symbol '{symbol}' not found on {self.exchange_id}")

    timeframe = self._resolution_to_timeframe(resolution)

    # CCXT uses milliseconds for timestamps
    since = from_time * 1000
    until = to_time * 1000

    # Per-request fetch limit; 1000 is a common exchange maximum.
    limit = countback if countback else 1000

    try:
        # Fetch in batches until the window is covered or data runs out.
        all_ohlcv = []
        current_since = since

        while current_since < until:
            ohlcv = await self.exchange.fetch_ohlcv(
                symbol,
                timeframe=timeframe,
                since=current_since,
                limit=limit,
            )

            if not ohlcv:
                break

            all_ohlcv.extend(ohlcv)

            # Advance past the last candle for the next batch.
            last_timestamp = ohlcv[-1][0]
            if last_timestamp <= current_since:
                break  # No progress, avoid infinite loop
            current_since = last_timestamp + 1

            # Stop if we have enough bars
            if countback and len(all_ohlcv) >= countback:
                all_ohlcv = all_ohlcv[:countback]
                break

        # Convert to our Bar format with Decimal precision
        bars = []
        for candle in all_ohlcv:
            timestamp_ms, open_price, high, low, close, volume = candle
            timestamp = timestamp_ms // 1000  # Convert to seconds

            # Only include bars within requested range
            if timestamp < from_time or timestamp >= to_time:
                continue

            bars.append(
                Bar(
                    time=timestamp,
                    data={
                        "open": self._to_decimal(open_price),
                        "high": self._to_decimal(high),
                        "low": self._to_decimal(low),
                        "close": self._to_decimal(close),
                        "volume": self._to_decimal(volume),
                    },
                )
            )

        # Get symbol info for column metadata
        symbol_info = await self.resolve_symbol(symbol)

        logger.info(
            f"CCXTDataSource({self.exchange_id}).get_bars: Returning {len(bars)} bars. "
            f"First: {bars[0].time if bars else 'N/A'}, Last: {bars[-1].time if bars else 'N/A'}"
        )

        # More data may exist if we hit the countback cap; extrapolate one
        # bar period past the last bar for the pagination cursor.
        next_time = None
        if bars and countback and len(bars) >= countback:
            next_time = bars[-1].time + (bars[-1].time - bars[-2].time if len(bars) > 1 else 60)

        return HistoryResult(
            symbol=symbol,
            resolution=resolution,
            bars=bars,
            columns=symbol_info.columns,
            nextTime=next_time,
        )

    except Exception as e:
        # Chain the original exception so the root cause and traceback are
        # preserved (the previous version silently discarded them).
        raise ValueError(f"Failed to fetch bars for {symbol}: {str(e)}") from e
|
||||
|
||||
async def subscribe_bars(
    self,
    symbol: str,
    resolution: str,
    on_tick: Callable[[dict], None],
) -> str:
    """Subscribe to bar updates via polling.

    Because free CCXT is used (no WebSocket), updates are produced by a
    background task that polls the exchange at the configured interval.

    Returns:
        A subscription id usable with unsubscribe_bars().

    Raises:
        ValueError: If the symbol is unknown on this exchange.
    """
    await self._ensure_markets_loaded()

    if symbol not in self._markets:
        raise ValueError(f"Symbol '{symbol}' not found on {self.exchange_id}")

    # Unique id: symbol + resolution + wall-clock timestamp.
    sub_id = f"{symbol}:{resolution}:{time.time()}"
    self._subscription_callbacks[sub_id] = on_tick

    # One polling task per subscription.
    ccxt_timeframe = self._resolution_to_timeframe(resolution)
    poll_task = asyncio.create_task(
        self._poll_ohlcv(symbol, ccxt_timeframe, sub_id)
    )
    self._subscriptions[sub_id] = poll_task

    return sub_id
|
||||
|
||||
async def _poll_ohlcv(self, symbol: str, timeframe: str, subscription_id: str):
    """
    Poll for OHLCV updates at regular intervals.

    This simulates real-time updates by fetching the latest bars periodically.
    Only sends updates when a bar with a newer timestamp is detected.
    Loops until the subscription is removed or the task is cancelled.
    """
    try:
        while subscription_id in self._subscription_callbacks:
            try:
                # Fetch latest bars
                ohlcv = await self.exchange.fetch_ohlcv(
                    symbol,
                    timeframe=timeframe,
                    limit=2,  # Get last 2 bars to detect new ones
                )

                if ohlcv and len(ohlcv) > 0:
                    # Get the latest candle
                    latest = ohlcv[-1]
                    timestamp_ms, open_price, high, low, close, volume = latest
                    timestamp = timestamp_ms // 1000

                    # Only send update if this is a new bar
                    last_timestamp = self._last_bars.get(subscription_id, 0)
                    if timestamp > last_timestamp:
                        self._last_bars[subscription_id] = timestamp

                        # Convert to our format with Decimal precision
                        tick_data = {
                            "time": timestamp,
                            "open": self._to_decimal(open_price),
                            "high": self._to_decimal(high),
                            "low": self._to_decimal(low),
                            "close": self._to_decimal(close),
                            "volume": self._to_decimal(volume),
                        }

                        # Callback may have been removed concurrently; re-check.
                        callback = self._subscription_callbacks.get(subscription_id)
                        if callback:
                            callback(tick_data)

            except Exception as e:
                # Log (with traceback) instead of print(), consistent with the
                # module-level logger used elsewhere; keep polling on errors.
                logger.exception(f"Error polling OHLCV for {symbol}: {e}")

            # Wait for next poll interval
            await asyncio.sleep(self._poll_interval)

    except asyncio.CancelledError:
        # Normal shutdown path via unsubscribe_bars()/close().
        pass
|
||||
|
||||
async def unsubscribe_bars(self, subscription_id: str) -> None:
    """Stop polling for the given subscription and drop all its state."""
    # Drop the callback and last-bar bookkeeping first so the poll loop
    # observes the removal and exits on its own if not cancelled in time.
    self._subscription_callbacks.pop(subscription_id, None)
    self._last_bars.pop(subscription_id, None)

    polling_task = self._subscriptions.pop(subscription_id, None)
    if polling_task is None:
        return

    polling_task.cancel()
    try:
        await polling_task
    except asyncio.CancelledError:
        pass
|
||||
|
||||
async def close(self):
    """Tear down all live subscriptions and release the exchange client."""
    # Snapshot the keys: unsubscribe_bars mutates self._subscriptions.
    for sub_id in list(self._subscriptions):
        await self.unsubscribe_bars(sub_id)

    # Some CCXT exchange objects expose an async close(); call it if present.
    exchange_close = getattr(self.exchange, 'close', None)
    if exchange_close is not None:
        await exchange_close()
|
||||
353
backend/src/datasource/adapters/demo.py
Normal file
353
backend/src/datasource/adapters/demo.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
Demo data source with synthetic data.
|
||||
|
||||
Generates realistic-looking OHLCV data plus additional columns for testing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
from ..base import DataSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from ..schema import (
|
||||
Bar,
|
||||
ColumnInfo,
|
||||
DatafeedConfig,
|
||||
HistoryResult,
|
||||
Resolution,
|
||||
SearchResult,
|
||||
SymbolInfo,
|
||||
)
|
||||
|
||||
|
||||
class DemoDataSource(DataSource):
    """
    Demo data source that generates synthetic time-series data.

    Provides:
    - Standard OHLCV columns
    - Additional demo columns (RSI, sentiment, volume_profile)
    - Real-time updates via polling simulation
    """

    def __init__(self):
        # Active tick-generator tasks, keyed by subscription id.
        self._subscriptions: Dict[str, asyncio.Task] = {}
        # Static symbol universe. `base_price` seeds the random walk and
        # `volatility` controls per-bar price dispersion.
        self._symbols = {
            "DEMO:BTC/USD": {
                "name": "Bitcoin",
                "type": "crypto",
                "base_price": 50000.0,
                "volatility": 0.02,
            },
            "DEMO:ETH/USD": {
                "name": "Ethereum",
                "type": "crypto",
                "base_price": 3000.0,
                "volatility": 0.03,
            },
            "DEMO:SOL/USD": {
                "name": "Solana",
                "type": "crypto",
                "base_price": 100.0,
                "volatility": 0.04,
            },
        }
|
||||
|
||||
async def get_config(self) -> DatafeedConfig:
    """Describe this demo datafeed's capabilities."""
    demo_resolutions = [
        Resolution.M1,
        Resolution.M5,
        Resolution.M15,
        Resolution.H1,
        Resolution.D1,
    ]
    return DatafeedConfig(
        name="Demo DataSource",
        description="Synthetic data generator for testing. Provides OHLCV plus additional indicator columns.",
        supported_resolutions=demo_resolutions,
        supports_search=True,
        supports_time=True,
        exchanges=["DEMO"],
        symbols_types=["crypto"],
    )
|
||||
|
||||
async def search_symbols(
    self,
    query: str,
    type: Optional[str] = None,
    exchange: Optional[str] = None,
    limit: int = 30,
) -> List[SearchResult]:
    """Case-insensitive substring search over the demo symbol universe.

    Matches either the ticker or the display name; `type` filters by
    instrument type. (`exchange` is accepted for interface compatibility
    but unused — everything here is "DEMO".)
    """
    needle = query.lower()
    matches: List[SearchResult] = []

    for demo_ticker, meta in self._symbols.items():
        if needle not in demo_ticker.lower() and needle not in meta["name"].lower():
            continue
        if type and meta["type"] != type:
            continue
        matches.append(
            SearchResult(
                symbol=meta['name'],  # Clean user-facing format (e.g., "Bitcoin")
                ticker=demo_ticker,  # Keep DEMO:BTC/USD format for routing
                full_name=f"{meta['name']} (DEMO)",
                description=f"Demo {meta['name']} pair",
                exchange="DEMO",
                type=meta["type"],
            )
        )

    return matches[:limit]
|
||||
|
||||
async def resolve_symbol(self, symbol: str) -> SymbolInfo:
    """Return full metadata (including the column schema) for a demo symbol.

    Args:
        symbol: Routing ticker in "DEMO:BASE/QUOTE" form.

    Raises:
        ValueError: If the symbol is not in the demo universe.
    """
    if symbol not in self._symbols:
        raise ValueError(f"Symbol '{symbol}' not found")

    info = self._symbols[symbol]
    # Split "DEMO:BTC/USD" into base="BTC", quote="USD".
    base, quote = symbol.split(":")[1].split("/")

    return SymbolInfo(
        symbol=info["name"],  # Clean user-facing format (e.g., "Bitcoin")
        ticker=symbol,  # Keep DEMO:BTC/USD format for routing
        name=info["name"],
        description=f"Demo {info['name']} spot price with synthetic indicators",
        type=info["type"],
        exchange="DEMO",
        timezone="Etc/UTC",
        session="24x7",
        supported_resolutions=[Resolution.M1, Resolution.M5, Resolution.M15, Resolution.H1, Resolution.D1],
        has_intraday=True,
        has_daily=True,
        has_weekly_and_monthly=False,
        # OHLCV plus the synthetic indicator columns emitted by _generate_bar().
        columns=[
            ColumnInfo(
                name="open",
                type="float",
                description=f"Opening price in {quote}",
                unit=quote,
            ),
            ColumnInfo(
                name="high",
                type="float",
                description=f"Highest price in {quote}",
                unit=quote,
            ),
            ColumnInfo(
                name="low",
                type="float",
                description=f"Lowest price in {quote}",
                unit=quote,
            ),
            ColumnInfo(
                name="close",
                type="float",
                description=f"Closing price in {quote}",
                unit=quote,
            ),
            ColumnInfo(
                name="volume",
                type="float",
                description=f"Trading volume in {base}",
                unit=base,
            ),
            ColumnInfo(
                name="rsi",
                type="float",
                description="Relative Strength Index (14-period), range 0-100",
                unit=None,
            ),
            ColumnInfo(
                name="sentiment",
                type="float",
                description="Synthetic social sentiment score, range -1.0 to 1.0",
                unit=None,
            ),
            ColumnInfo(
                name="volume_profile",
                type="float",
                description="Volume as percentage of 24h average",
                unit="%",
            ),
        ],
        time_column="time",
        has_ohlcv=True,
        pricescale=100,  # Two decimal places for display
        minmov=1,
        base_currency=base,
        quote_currency=quote,
    )
|
||||
|
||||
async def get_bars(
    self,
    symbol: str,
    resolution: str,
    from_time: int,
    to_time: int,
    countback: Optional[int] = None,
) -> HistoryResult:
    """Generate synthetic bars for the requested time range.

    Args:
        symbol: Routing ticker ("DEMO:BASE/QUOTE").
        resolution: Resolution string (e.g. "1", "60", "1D").
        from_time: Range start, Unix seconds (inclusive).
        to_time: Range end, Unix seconds (inclusive).
        countback: Optional cap on the number of bars (default 5000).

    Raises:
        ValueError: If the symbol is unknown.
    """
    if symbol not in self._symbols:
        raise ValueError(f"Symbol '{symbol}' not found")

    logger.info(
        f"DemoDataSource.get_bars: symbol={symbol}, resolution={resolution}, "
        f"from_time={from_time}, to_time={to_time}, countback={countback}"
    )

    info = self._symbols[symbol]
    symbol_meta = await self.resolve_symbol(symbol)

    # Convert resolution to seconds
    resolution_seconds = self._resolution_to_seconds(resolution)

    # Generate bars
    bars = []
    # Align current_time to the resolution grid, then snap forward so we
    # never start before from_time.
    current_time = from_time - (from_time % resolution_seconds)
    if current_time < from_time:
        current_time += resolution_seconds

    price = info["base_price"]

    bar_count = 0
    max_bars = countback if countback else 5000

    while current_time <= to_time and bar_count < max_bars:
        bar_data = self._generate_bar(current_time, price, info["volatility"], resolution_seconds)

        # Only include bars within the requested range
        if from_time <= current_time <= to_time:
            # Bar.time is Unix seconds (per the schema and the CCXT adapter).
            # The previous code multiplied by 1000 here, which disagreed
            # with nextTime below (seconds) — if the frontend truly needs
            # milliseconds, it should convert at the boundary instead.
            bars.append(Bar(time=current_time, data=bar_data))
            bar_count += 1

        price = bar_data["close"]  # Next bar starts from previous close
        current_time += resolution_seconds

    logger.info(
        f"DemoDataSource.get_bars: Generated {len(bars)} bars. "
        f"First: {bars[0].time if bars else 'N/A'}, Last: {bars[-1].time if bars else 'N/A'}"
    )

    # Determine if there's more data (for pagination)
    next_time = current_time if current_time <= to_time else None

    return HistoryResult(
        symbol=symbol,
        resolution=resolution,
        bars=bars,
        columns=symbol_meta.columns,
        nextTime=next_time,
    )
|
||||
|
||||
async def subscribe_bars(
    self,
    symbol: str,
    resolution: str,
    on_tick: Callable[[dict], None],
) -> str:
    """Start a background tick generator and return its subscription id.

    Raises:
        ValueError: If the symbol is not in the demo universe.
    """
    if symbol not in self._symbols:
        raise ValueError(f"Symbol '{symbol}' not found")

    sub_id = f"{symbol}:{resolution}:{time.time()}"

    # Simulate a live feed: one task per subscription emits bars on a timer.
    generator = self._tick_generator(symbol, resolution, on_tick, sub_id)
    self._subscriptions[sub_id] = asyncio.create_task(generator)

    return sub_id
|
||||
|
||||
async def unsubscribe_bars(self, subscription_id: str) -> None:
    """Cancel the tick-generator task behind the given subscription id."""
    generator_task = self._subscriptions.pop(subscription_id, None)
    if generator_task is None:
        return  # Unknown or already-removed subscription: nothing to do.

    generator_task.cancel()
    try:
        await generator_task
    except asyncio.CancelledError:
        pass
|
||||
|
||||
def _resolution_to_seconds(self, resolution: str) -> int:
|
||||
"""Convert resolution string to seconds"""
|
||||
if resolution.endswith("D"):
|
||||
return int(resolution[:-1]) * 86400
|
||||
elif resolution.endswith("W"):
|
||||
return int(resolution[:-1]) * 604800
|
||||
elif resolution.endswith("M"):
|
||||
return int(resolution[:-1]) * 2592000 # Approximate month
|
||||
else:
|
||||
# Assume minutes
|
||||
return int(resolution) * 60
|
||||
|
||||
def _generate_bar(self, timestamp: int, base_price: float, volatility: float, period_seconds: int) -> dict:
    """Generate a single synthetic OHLCV bar.

    A random walk of intra-period ticks produces open/high/low/close;
    volume and the extra indicator columns are sampled independently.
    NOTE: draws from the global `random` module, so output is
    nondeterministic unless the caller seeds it.
    """
    # Random walk for the period
    open_price = base_price

    # Generate intra-period price movement
    num_ticks = max(10, period_seconds // 60)  # More ticks for longer periods
    prices = [open_price]

    for _ in range(num_ticks):
        # Per-tick sigma scaled so the period's total variance tracks `volatility`.
        change = random.gauss(0, volatility / math.sqrt(num_ticks))
        prices.append(prices[-1] * (1 + change))

    close_price = prices[-1]
    high_price = max(prices)
    low_price = min(prices)

    # Volume with some randomness
    base_volume = 1000000 * (period_seconds / 60)  # Scale with period
    volume = base_volume * random.uniform(0.5, 2.0)

    # Additional synthetic indicators
    rsi = 30 + random.random() * 40  # Biased toward middle range
    sentiment = math.sin(timestamp / 3600) * 0.5 + random.gauss(0, 0.2)  # Hourly cycle + noise
    sentiment = max(-1.0, min(1.0, sentiment))  # Clamp into the documented [-1, 1] range
    volume_profile = 100 * random.uniform(0.5, 1.5)

    return {
        "open": round(open_price, 2),
        "high": round(high_price, 2),
        "low": round(low_price, 2),
        "close": round(close_price, 2),
        "volume": round(volume, 2),
        "rsi": round(rsi, 2),
        "sentiment": round(sentiment, 3),
        "volume_profile": round(volume_profile, 2),
    }
|
||||
|
||||
async def _tick_generator(
    self,
    symbol: str,
    resolution: str,
    on_tick: Callable[[dict], None],
    subscription_id: str,
):
    """Background task that generates periodic ticks.

    Sleeps one full resolution period between bars, then invokes `on_tick`
    with {"time": <unix seconds>, **bar columns}. Runs until the task is
    cancelled by unsubscribe_bars().
    # NOTE(review): `subscription_id` is accepted but not read here —
    # presumably kept for interface symmetry with the CCXT poller; confirm.
    """
    info = self._symbols[symbol]
    resolution_seconds = self._resolution_to_seconds(resolution)

    # Start from current aligned time
    current_time = int(time.time())
    current_time = current_time - (current_time % resolution_seconds)
    price = info["base_price"]

    try:
        while True:
            # Wait until next bar
            await asyncio.sleep(resolution_seconds)

            current_time += resolution_seconds
            bar_data = self._generate_bar(current_time, price, info["volatility"], resolution_seconds)
            price = bar_data["close"]  # Walk continues from the last close

            # Call the tick handler
            tick_data = {"time": current_time, **bar_data}
            on_tick(tick_data)

    except asyncio.CancelledError:
        # Subscription cancelled
        pass
|
||||
146
backend/src/datasource/base.py
Normal file
146
backend/src/datasource/base.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Abstract DataSource interface.
|
||||
|
||||
Inspired by TradingView's Datafeed API with extensions for flexible column schemas
|
||||
and AI-native metadata.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from .schema import DatafeedConfig, HistoryResult, SearchResult, SymbolInfo
|
||||
|
||||
|
||||
class DataSource(ABC):
    """
    Abstract base class for time-series data sources.

    Provides a standardized interface for:
    - Symbol search and metadata retrieval
    - Historical data queries (time-based, paginated)
    - Real-time data subscriptions

    All data rows must have a timestamp (Unix seconds). Additional columns
    are flexible and described via ColumnInfo metadata.
    """

    @abstractmethod
    async def get_config(self) -> DatafeedConfig:
        """
        Get datafeed configuration and capabilities.

        Called once during initialization to understand what this data source
        supports (resolutions, exchanges, search, etc.).

        Returns:
            DatafeedConfig describing this datafeed's capabilities
        """
        pass

    @abstractmethod
    async def search_symbols(
        self,
        query: str,
        type: Optional[str] = None,
        exchange: Optional[str] = None,
        limit: int = 30,
    ) -> List[SearchResult]:
        """
        Search for symbols matching a text query.

        Args:
            query: Free-text search string
            type: Optional filter by instrument type
            exchange: Optional filter by exchange
            limit: Maximum number of results

        Returns:
            List of matching symbols with basic metadata
        """
        pass

    @abstractmethod
    async def resolve_symbol(self, symbol: str) -> SymbolInfo:
        """
        Get complete metadata for a symbol.

        Called after a symbol is selected to retrieve full information including
        supported resolutions, column schema, trading session, etc.

        Args:
            symbol: Symbol identifier

        Returns:
            Complete SymbolInfo including column definitions

        Raises:
            ValueError: If symbol is not found
        """
        pass

    @abstractmethod
    async def get_bars(
        self,
        symbol: str,
        resolution: str,
        from_time: int,
        to_time: int,
        countback: Optional[int] = None,
    ) -> HistoryResult:
        """
        Get historical bars for a symbol and resolution.

        Time range is specified in Unix timestamps (seconds). If more data is
        available beyond the requested range, the result should include a
        nextTime value for pagination.

        Args:
            symbol: Symbol identifier
            resolution: Time resolution (e.g., "1", "5", "60", "1D")
            from_time: Start time (Unix timestamp in seconds)
            to_time: End time (Unix timestamp in seconds)
            countback: Optional limit on number of bars to return

        Returns:
            HistoryResult with bars, column schema, and pagination info

        Raises:
            ValueError: If symbol or resolution is not supported
        """
        pass

    @abstractmethod
    async def subscribe_bars(
        self,
        symbol: str,
        resolution: str,
        on_tick: Callable[[dict], None],
    ) -> str:
        """
        Subscribe to real-time bar updates.

        The callback will be invoked with new bar data as it becomes available.
        The data dict will match the column schema from resolve_symbol().

        Args:
            symbol: Symbol identifier
            resolution: Time resolution
            on_tick: Callback function receiving bar data dict

        Returns:
            Subscription ID for later unsubscribe

        Raises:
            ValueError: If symbol or resolution is not supported
        """
        pass

    @abstractmethod
    async def unsubscribe_bars(self, subscription_id: str) -> None:
        """
        Unsubscribe from real-time updates.

        Implementations should treat an unknown subscription_id as a no-op.

        Args:
            subscription_id: ID returned from subscribe_bars()
        """
        pass
|
||||
109
backend/src/datasource/registry.py
Normal file
109
backend/src/datasource/registry.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
DataSource registry for managing multiple data sources.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .base import DataSource
|
||||
from .schema import SearchResult, SymbolInfo
|
||||
|
||||
|
||||
class DataSourceRegistry:
    """
    Keeps track of named DataSource instances and routes queries to them.

    Supports fan-out symbol search across every registered source, and
    symbol resolution against one specific source.
    """

    def __init__(self):
        # Name -> DataSource; insertion order is preserved for iteration.
        self._sources: Dict[str, DataSource] = {}

    def register(self, name: str, source: DataSource) -> None:
        """Add (or replace) a data source under *name*."""
        self._sources[name] = source

    def unregister(self, name: str) -> None:
        """Remove the data source registered under *name*; no-op if absent."""
        self._sources.pop(name, None)

    def get(self, name: str) -> Optional[DataSource]:
        """Return the source registered under *name*, or None if unknown."""
        return self._sources.get(name)

    def list_sources(self) -> List[str]:
        """Return the names of all registered data sources."""
        return list(self._sources.keys())

    async def search_all(
        self,
        query: str,
        type: Optional[str] = None,
        exchange: Optional[str] = None,
        limit: int = 30,
    ) -> Dict[str, List[SearchResult]]:
        """
        Run the search against every registered source.

        Args:
            query: Search query
            type: Optional instrument type filter
            exchange: Optional exchange filter
            limit: Maximum results per source

        Returns:
            Dict mapping source name to its (non-empty) search results.
            Sources that raise during search are skipped silently —
            a failing source must not break the fan-out.
        """
        hits: Dict[str, List[SearchResult]] = {}
        for source_name, source in self._sources.items():
            try:
                found = await source.search_symbols(query, type, exchange, limit)
            except Exception:
                continue
            if found:
                hits[source_name] = found
        return hits

    async def resolve_symbol(self, source_name: str, symbol: str) -> SymbolInfo:
        """
        Resolve *symbol* via the data source registered as *source_name*.

        Raises:
            ValueError: If the source is unknown (symbol-not-found errors
                propagate from the underlying source).
        """
        target = self.get(source_name)
        if not target:
            raise ValueError(f"Data source '{source_name}' not found")
        return await target.resolve_symbol(symbol)
|
||||
194
backend/src/datasource/schema.py
Normal file
194
backend/src/datasource/schema.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
Data models for the DataSource interface.
|
||||
|
||||
Inspired by TradingView's Datafeed API but with flexible column schemas
|
||||
for AI-native trading platform needs.
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Resolution(StrEnum):
    """Standard time resolutions for bar data.

    Values follow the TradingView convention: bare numbers are minutes
    ("60" = 1 hour); suffixed values are seconds/days/weeks/months.
    """

    # Seconds
    S1 = "1S"
    S5 = "5S"
    S15 = "15S"
    S30 = "30S"

    # Minutes
    M1 = "1"
    M5 = "5"
    M15 = "15"
    M30 = "30"

    # Hours (expressed as minute counts)
    H1 = "60"
    H2 = "120"
    H4 = "240"
    H6 = "360"
    H12 = "720"

    # Days
    D1 = "1D"

    # Weeks
    W1 = "1W"

    # Months
    MO1 = "1M"
|
||||
|
||||
|
||||
class ColumnInfo(BaseModel):
    """
    Metadata for a single data column.

    Provides rich, LLM-readable descriptions so AI agents can understand
    and reason about available data fields.
    """

    # Reject unknown keys so schema typos fail loudly.
    model_config = {"extra": "forbid"}

    name: str = Field(description="Column name (e.g., 'close', 'volume', 'funding_rate')")
    type: Literal["float", "int", "bool", "string", "decimal"] = Field(description="Data type")
    description: str = Field(description="Human and LLM-readable description of what this column represents")
    unit: Optional[str] = Field(default=None, description="Unit of measurement (e.g., 'USD', 'BTC', '%', 'contracts')")
    nullable: bool = Field(default=False, description="Whether this column can contain null values")
|
||||
|
||||
|
||||
class SymbolInfo(BaseModel):
    """
    Complete metadata for a tradeable symbol.

    Includes both TradingView-compatible fields and flexible schema definition
    for arbitrary data columns.
    """

    # Reject unknown keys so schema typos fail loudly.
    model_config = {"extra": "forbid"}

    # Core identification
    symbol: str = Field(description="Unique symbol identifier (primary key for data fetching)")
    ticker: Optional[str] = Field(default=None, description="TradingView ticker (if different from symbol)")
    name: str = Field(description="Display name")
    description: str = Field(description="LLM-readable description of the instrument")
    type: str = Field(description="Instrument type: 'crypto', 'stock', 'forex', 'futures', 'derived', etc.")
    exchange: str = Field(description="Exchange or data source identifier")

    # Trading session info
    timezone: str = Field(default="Etc/UTC", description="IANA timezone identifier")
    session: str = Field(default="24x7", description="Trading session spec (e.g., '0930-1600' or '24x7')")

    # Resolution support
    supported_resolutions: List[str] = Field(description="List of supported time resolutions")
    has_intraday: bool = Field(default=True, description="Whether intraday resolutions are supported")
    has_daily: bool = Field(default=True, description="Whether daily resolution is supported")
    has_weekly_and_monthly: bool = Field(default=False, description="Whether weekly/monthly resolutions are supported")

    # Flexible schema definition
    columns: List[ColumnInfo] = Field(description="Available data columns for this symbol")
    time_column: str = Field(default="time", description="Name of the timestamp column")

    # Convenience flags
    has_ohlcv: bool = Field(default=False, description="Whether standard OHLCV columns are present")

    # Price display (for OHLCV data)
    pricescale: int = Field(default=100, description="Price scale factor (e.g., 100 for 2 decimals)")
    minmov: int = Field(default=1, description="Minimum price movement in pricescale units")

    # Additional metadata
    base_currency: Optional[str] = Field(default=None, description="Base currency (for crypto/forex)")
    quote_currency: Optional[str] = Field(default=None, description="Quote currency (for crypto/forex)")
|
||||
|
||||
|
||||
class Bar(BaseModel):
    """
    A single bar/row of time-series data with flexible columns.

    All bars must have a timestamp (Unix seconds). Additional columns are
    stored in the data dict and described by the associated ColumnInfo
    metadata. The OHLCV properties return None when the column is absent.
    """

    # Reject unknown keys so schema typos fail loudly.
    model_config = {"extra": "forbid"}

    time: int = Field(description="Unix timestamp in seconds")
    data: Dict[str, Any] = Field(description="Column name -> value mapping")

    # Convenience accessors for common OHLCV columns
    @property
    def open(self) -> Optional[float]:
        return self.data.get("open")

    @property
    def high(self) -> Optional[float]:
        return self.data.get("high")

    @property
    def low(self) -> Optional[float]:
        return self.data.get("low")

    @property
    def close(self) -> Optional[float]:
        return self.data.get("close")

    @property
    def volume(self) -> Optional[float]:
        return self.data.get("volume")
|
||||
|
||||
|
||||
class HistoryResult(BaseModel):
    """
    Result from a historical data query.

    Includes the bars, schema information, and pagination metadata.
    """

    # Reject unknown keys so schema typos fail loudly.
    model_config = {"extra": "forbid"}

    symbol: str = Field(description="Symbol identifier")
    resolution: str = Field(description="Time resolution of the bars")
    bars: List[Bar] = Field(description="The actual data bars")
    columns: List[ColumnInfo] = Field(description="Schema describing the bar data columns")
    nextTime: Optional[int] = Field(default=None, description="Unix timestamp for pagination (if more data available)")
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
    """
    A single result from symbol search.
    """

    # Reject unexpected keys so malformed payloads fail loudly.
    model_config = {"extra": "forbid"}

    symbol: str = Field(description="Display symbol (e.g., 'BINANCE:ETH/BTC')")
    ticker: Optional[str] = Field(default=None, description="Backend ticker for data fetching (e.g., 'ETH/BTC')")
    full_name: str = Field(description="Full display name including exchange")
    description: str = Field(description="Human-readable description")
    exchange: str = Field(description="Exchange identifier")
    # Field name shadows the `type` builtin; kept to match the wire format.
    type: str = Field(description="Instrument type")
|
||||
|
||||
|
||||
class DatafeedConfig(BaseModel):
    """
    Configuration and capabilities of a DataSource.

    Similar to TradingView's onReady configuration object.
    """

    # Reject unexpected keys so malformed payloads fail loudly.
    model_config = {"extra": "forbid"}

    # Supported features — boolean flags default to the common case and are
    # overridden per data source.
    supported_resolutions: List[str] = Field(description="All resolutions this datafeed supports")
    supports_search: bool = Field(default=True, description="Whether symbol search is available")
    supports_time: bool = Field(default=True, description="Whether time-based queries are supported")
    supports_marks: bool = Field(default=False, description="Whether marks/events are supported")

    # Data characteristics
    exchanges: List[str] = Field(default_factory=list, description="Available exchanges")
    symbols_types: List[str] = Field(default_factory=list, description="Available instrument types")

    # Metadata
    name: str = Field(description="Datafeed name")
    description: str = Field(description="LLM-readable description of this data source")
|
||||
235
backend/src/datasource/subscription_manager.py
Normal file
235
backend/src/datasource/subscription_manager.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
Subscription manager for real-time data feeds.
|
||||
|
||||
Manages subscriptions across multiple data sources and routes updates
|
||||
to WebSocket clients.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Callable, Dict, Optional, Set
|
||||
|
||||
from .base import DataSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Subscription:
    """A single client subscription to one (source, symbol, resolution) stream."""

    def __init__(
        self,
        subscription_id: str,
        client_id: str,
        source_name: str,
        symbol: str,
        resolution: str,
        callback: Callable[[dict], None],
    ):
        self.subscription_id = subscription_id
        self.client_id = client_id
        self.source_name = source_name
        self.symbol = symbol
        self.resolution = resolution
        self.callback = callback
        # ID of the underlying source subscription. Only populated on the
        # client subscription that actually created it; kept for backward
        # compatibility (the manager now tracks this per key as well).
        self.source_subscription_id: Optional[str] = None


class SubscriptionManager:
    """
    Manages real-time data subscriptions across multiple data sources.

    Handles:
    - Subscription lifecycle (subscribe/unsubscribe)
    - Routing updates from data sources to clients
    - Multiplexing (multiple clients can subscribe to same symbol/resolution)
    """

    def __init__(self):
        # Map subscription_id -> Subscription
        self._subscriptions: Dict[str, Subscription] = {}

        # Map (source_name, symbol, resolution) -> Set[subscription_id]
        # Tracks which client subscriptions share each source subscription.
        self._source_refs: Dict[tuple, Set[str]] = {}

        # Map source_subscription_id -> (source_name, symbol, resolution)
        self._source_subs: Dict[str, tuple] = {}

        # Map (source_name, symbol, resolution) -> source_subscription_id.
        # BUG FIX: tracked at key level so the *last* client to leave can
        # always tear down the upstream subscription. Previously the ID lived
        # only on the creating client's Subscription, so if the creator
        # unsubscribed before the other sharers, the upstream subscription
        # was never released (leak).
        self._key_to_source_sub: Dict[tuple, str] = {}

        # Available data sources by name
        self._sources: Dict[str, "DataSource"] = {}

    def register_source(self, name: str, source: "DataSource") -> None:
        """Register a data source under the given name."""
        self._sources[name] = source

    def unregister_source(self, name: str) -> None:
        """Unregister a data source (no-op if the name is unknown)."""
        self._sources.pop(name, None)

    async def subscribe(
        self,
        subscription_id: str,
        client_id: str,
        source_name: str,
        symbol: str,
        resolution: str,
        callback: Callable[[dict], None],
    ) -> None:
        """
        Subscribe a client to real-time updates.

        Args:
            subscription_id: Unique ID for this subscription
            client_id: ID of the subscribing client
            source_name: Name of the data source
            symbol: Symbol to subscribe to
            resolution: Time resolution
            callback: Function to call with bar updates

        Raises:
            ValueError: If source not found or subscription fails
        """
        source = self._sources.get(source_name)
        if not source:
            raise ValueError(f"Data source '{source_name}' not found")

        # Create subscription record
        subscription = Subscription(
            subscription_id=subscription_id,
            client_id=client_id,
            source_name=source_name,
            symbol=symbol,
            resolution=resolution,
            callback=callback,
        )

        # Reuse an existing source subscription for this (source, symbol,
        # resolution) if one is already live; otherwise create it.
        source_key = (source_name, symbol, resolution)
        if source_key not in self._source_refs:
            try:
                source_sub_id = await source.subscribe_bars(
                    symbol=symbol,
                    resolution=resolution,
                    on_tick=lambda bar: self._on_source_update(source_key, bar),
                )
                subscription.source_subscription_id = source_sub_id
                self._source_subs[source_sub_id] = source_key
                self._key_to_source_sub[source_key] = source_sub_id
                self._source_refs[source_key] = set()
                logging.getLogger(__name__).info(
                    f"Created new source subscription: {source_name}/{symbol}/{resolution} -> {source_sub_id}"
                )
            except Exception as e:
                logging.getLogger(__name__).error(f"Failed to subscribe to source: {e}")
                raise

        # Add this subscription to the reference set
        self._source_refs[source_key].add(subscription_id)
        self._subscriptions[subscription_id] = subscription

        logging.getLogger(__name__).info(
            f"Client subscription added: {subscription_id} ({client_id}) -> {source_name}/{symbol}/{resolution}"
        )

    async def unsubscribe(self, subscription_id: str) -> None:
        """
        Unsubscribe a client from updates.

        When the last client sharing a source subscription leaves, the
        upstream source subscription is released as well.

        Args:
            subscription_id: ID of the subscription to remove
        """
        subscription = self._subscriptions.pop(subscription_id, None)
        if not subscription:
            logging.getLogger(__name__).warning(f"Subscription {subscription_id} not found")
            return

        source_key = (subscription.source_name, subscription.symbol, subscription.resolution)

        # Remove from reference set
        if source_key in self._source_refs:
            self._source_refs[source_key].discard(subscription_id)

            # If no more clients need this source subscription, release it.
            if not self._source_refs[source_key]:
                del self._source_refs[source_key]

                # Look the upstream ID up by key (not via the departing
                # subscription) so teardown works regardless of unsubscribe
                # order — see _key_to_source_sub.
                source_sub_id = self._key_to_source_sub.pop(source_key, None)
                if source_sub_id:
                    source = self._sources.get(subscription.source_name)
                    if source:
                        try:
                            await source.unsubscribe_bars(source_sub_id)
                            logging.getLogger(__name__).info(
                                f"Unsubscribed from source: {source_sub_id}"
                            )
                        except Exception as e:
                            logging.getLogger(__name__).error(f"Error unsubscribing from source: {e}")

                    self._source_subs.pop(source_sub_id, None)

        logging.getLogger(__name__).info(f"Client subscription removed: {subscription_id}")

    async def unsubscribe_client(self, client_id: str) -> None:
        """
        Unsubscribe all subscriptions for a client.

        Useful when a WebSocket connection closes.

        Args:
            client_id: ID of the client
        """
        # Collect first: unsubscribe() mutates self._subscriptions.
        client_subs = [
            sub_id
            for sub_id, sub in self._subscriptions.items()
            if sub.client_id == client_id
        ]

        for sub_id in client_subs:
            await self.unsubscribe(sub_id)

        logging.getLogger(__name__).info(f"Unsubscribed all subscriptions for client {client_id}")

    def _on_source_update(self, source_key: tuple, bar: dict) -> None:
        """
        Handle an update from a data source.

        Routes the update to all client subscriptions that need it. Callback
        failures are logged per-subscription so one bad client cannot starve
        the others.

        Args:
            source_key: (source_name, symbol, resolution) tuple
            bar: Bar data dict from the source
        """
        # Iterate a snapshot: a callback may (un)subscribe and mutate the set.
        for sub_id in tuple(self._source_refs.get(source_key, ())):
            subscription = self._subscriptions.get(sub_id)
            if subscription is None:
                continue
            try:
                subscription.callback(bar)
            except Exception as e:
                logging.getLogger(__name__).error(
                    f"Error in subscription callback {sub_id}: {e}", exc_info=True
                )

    def get_subscription_count(self) -> int:
        """Get total number of active client subscriptions"""
        return len(self._subscriptions)

    def get_source_subscription_count(self) -> int:
        """Get total number of active source subscriptions"""
        return len(self._source_refs)

    def get_client_subscriptions(self, client_id: str) -> list:
        """Get all subscriptions for a specific client as plain dicts."""
        return [
            {
                "subscription_id": sub.subscription_id,
                "source": sub.source_name,
                "symbol": sub.symbol,
                "resolution": sub.resolution,
            }
            for sub in self._subscriptions.values()
            if sub.client_id == client_id
        ]
|
||||
347
backend/src/datasource/websocket_handler.py
Normal file
347
backend/src/datasource/websocket_handler.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
WebSocket handler for TradingView-compatible datafeed API.
|
||||
|
||||
Handles incoming requests for symbol search, metadata, historical data,
|
||||
and real-time subscriptions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from .base import DataSource
|
||||
from .registry import DataSourceRegistry
|
||||
from .subscription_manager import SubscriptionManager
|
||||
from .websocket_protocol import (
|
||||
BarUpdateMessage,
|
||||
ErrorResponse,
|
||||
GetBarsRequest,
|
||||
GetBarsResponse,
|
||||
GetConfigRequest,
|
||||
GetConfigResponse,
|
||||
ResolveSymbolRequest,
|
||||
ResolveSymbolResponse,
|
||||
SearchSymbolsRequest,
|
||||
SearchSymbolsResponse,
|
||||
SubscribeBarsRequest,
|
||||
SubscribeBarsResponse,
|
||||
UnsubscribeBarsRequest,
|
||||
UnsubscribeBarsResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatafeedWebSocketHandler:
    """
    Handles WebSocket connections for TradingView-compatible datafeed API.

    Each handler manages a single WebSocket connection and routes requests
    to the appropriate data sources via the registry.
    """

    def __init__(
        self,
        websocket: WebSocket,
        client_id: str,
        registry: DataSourceRegistry,
        subscription_manager: SubscriptionManager,
        default_source: Optional[str] = None,
    ):
        """
        Initialize handler.

        Args:
            websocket: FastAPI WebSocket connection
            client_id: Unique identifier for this client
            registry: DataSource registry for accessing data sources
            subscription_manager: Shared subscription manager
            default_source: Default data source name if not specified in requests
        """
        self.websocket = websocket
        self.client_id = client_id
        self.registry = registry
        self.subscription_manager = subscription_manager
        self.default_source = default_source
        # Flipped to False on any send failure or receive error; the main
        # loop in handle_connection exits once this is False.
        self._connected = True

    async def handle_connection(self) -> None:
        """
        Main connection handler loop.

        Accepts the socket, then processes incoming JSON messages one at a
        time until the connection closes or a receive/parse error occurs.
        All of this client's subscriptions are cleaned up on exit.
        """
        try:
            await self.websocket.accept()
            logger.info(f"WebSocket connected: client_id={self.client_id}")

            while self._connected:
                # Receive message
                try:
                    data = await self.websocket.receive_text()
                    message = json.loads(data)
                except Exception as e:
                    # Covers disconnects and malformed JSON alike; either way
                    # the session ends and cleanup happens in finally.
                    logger.error(f"Error receiving/parsing message: {e}")
                    break

                # Route to appropriate handler
                await self._handle_message(message)

        except Exception as e:
            logger.error(f"WebSocket error: {e}", exc_info=True)
        finally:
            # Clean up subscriptions when connection closes
            await self.subscription_manager.unsubscribe_client(self.client_id)
            self._connected = False
            logger.info(f"WebSocket disconnected: client_id={self.client_id}")

    async def _handle_message(self, message: dict) -> None:
        """Route message to appropriate handler based on type"""
        msg_type = message.get("type")
        request_id = message.get("request_id", "unknown")

        try:
            # Pydantic validation errors from Request(**message) are caught by
            # the broad except below and reported as INTERNAL_ERROR.
            if msg_type == "search_symbols":
                await self._handle_search_symbols(SearchSymbolsRequest(**message))
            elif msg_type == "resolve_symbol":
                await self._handle_resolve_symbol(ResolveSymbolRequest(**message))
            elif msg_type == "get_bars":
                await self._handle_get_bars(GetBarsRequest(**message))
            elif msg_type == "subscribe_bars":
                await self._handle_subscribe_bars(SubscribeBarsRequest(**message))
            elif msg_type == "unsubscribe_bars":
                await self._handle_unsubscribe_bars(UnsubscribeBarsRequest(**message))
            elif msg_type == "get_config":
                await self._handle_get_config(GetConfigRequest(**message))
            else:
                await self._send_error(
                    request_id, "UNKNOWN_REQUEST_TYPE", f"Unknown request type: {msg_type}"
                )
        except Exception as e:
            logger.error(f"Error handling {msg_type}: {e}", exc_info=True)
            await self._send_error(request_id, "INTERNAL_ERROR", str(e))

    async def _handle_search_symbols(self, request: SearchSymbolsRequest) -> None:
        """Handle symbol search request.

        With a default source configured, only that source is searched;
        otherwise results from every registered source are flattened into
        a single list.
        """
        # Use default source or search all sources
        if self.default_source:
            source = self.registry.get(self.default_source)
            if not source:
                await self._send_error(
                    request.request_id,
                    "SOURCE_NOT_FOUND",
                    f"Default source '{self.default_source}' not found",
                )
                return

            results = await source.search_symbols(
                query=request.query,
                type=request.symbol_type,
                exchange=request.exchange,
                limit=request.limit,
            )
            results_data = [r.model_dump(mode="json") for r in results]
        else:
            # Search all sources
            all_results = await self.registry.search_all(
                query=request.query,
                type=request.symbol_type,
                exchange=request.exchange,
                limit=request.limit,
            )
            # Flatten results from all sources
            # NOTE(review): `limit` is passed per source, so the flattened
            # list may exceed `limit` overall — confirm intended.
            results_data = []
            for source_results in all_results.values():
                results_data.extend([r.model_dump(mode="json") for r in source_results])

        response = SearchSymbolsResponse(request_id=request.request_id, results=results_data)
        await self._send_response(response)

    async def _handle_resolve_symbol(self, request: ResolveSymbolRequest) -> None:
        """Handle symbol resolution request.

        Symbols may be namespaced as "SOURCE:SYMBOL"; otherwise the default
        source is used. Unknown symbols produce a SYMBOL_NOT_FOUND error.
        """
        # Extract source from symbol if present (format: "SOURCE:SYMBOL")
        if ":" in request.symbol:
            source_name, symbol = request.symbol.split(":", 1)
        else:
            source_name = self.default_source
            symbol = request.symbol

        if not source_name:
            await self._send_error(
                request.request_id,
                "NO_SOURCE_SPECIFIED",
                "No data source specified and no default source configured",
            )
            return

        try:
            symbol_info = await self.registry.resolve_symbol(source_name, symbol)
            response = ResolveSymbolResponse(
                request_id=request.request_id,
                symbol_info=symbol_info.model_dump(mode="json"),
            )
            await self._send_response(response)
        except ValueError as e:
            await self._send_error(request.request_id, "SYMBOL_NOT_FOUND", str(e))

    async def _handle_get_bars(self, request: GetBarsRequest) -> None:
        """Handle historical bars request.

        Resolves the source from a "SOURCE:SYMBOL" prefix or the default
        source, then forwards the time-range query to that source.
        """
        # Extract source from symbol
        if ":" in request.symbol:
            source_name, symbol = request.symbol.split(":", 1)
        else:
            source_name = self.default_source
            symbol = request.symbol

        if not source_name:
            await self._send_error(
                request.request_id, "NO_SOURCE_SPECIFIED", "No data source specified"
            )
            return

        source = self.registry.get(source_name)
        if not source:
            await self._send_error(
                request.request_id, "SOURCE_NOT_FOUND", f"Source '{source_name}' not found"
            )
            return

        try:
            history = await source.get_bars(
                symbol=symbol,
                resolution=request.resolution,
                from_time=request.from_time,
                to_time=request.to_time,
                countback=request.countback,
            )
            response = GetBarsResponse(
                request_id=request.request_id, history=history.model_dump(mode="json")
            )
            await self._send_response(response)
        except ValueError as e:
            await self._send_error(request.request_id, "INVALID_REQUEST", str(e))

    async def _handle_subscribe_bars(self, request: SubscribeBarsRequest) -> None:
        """Handle real-time subscription request.

        Registers a callback with the subscription manager that forwards
        each bar update to this WebSocket, then acknowledges with a
        SubscribeBarsResponse (success=False with a message on failure).
        """
        # Extract source from symbol
        if ":" in request.symbol:
            source_name, symbol = request.symbol.split(":", 1)
        else:
            source_name = self.default_source
            symbol = request.symbol

        if not source_name:
            await self._send_error(
                request.request_id, "NO_SOURCE_SPECIFIED", "No data source specified"
            )
            return

        try:
            # Create callback that sends updates to this WebSocket
            async def send_update(bar: dict):
                update = BarUpdateMessage(
                    subscription_id=request.subscription_id,
                    symbol=request.symbol,
                    resolution=request.resolution,
                    bar=bar,
                )
                await self._send_response(update)

            # Register subscription
            # The sync callback wraps the coroutine into a task via
            # _queue_update so the data source's tick path is not blocked.
            await self.subscription_manager.subscribe(
                subscription_id=request.subscription_id,
                client_id=self.client_id,
                source_name=source_name,
                symbol=symbol,
                resolution=request.resolution,
                callback=lambda bar: self._queue_update(send_update(bar)),
            )

            response = SubscribeBarsResponse(
                request_id=request.request_id,
                subscription_id=request.subscription_id,
                success=True,
            )
            await self._send_response(response)

        except Exception as e:
            logger.error(f"Subscription failed: {e}", exc_info=True)
            response = SubscribeBarsResponse(
                request_id=request.request_id,
                subscription_id=request.subscription_id,
                success=False,
                message=str(e),
            )
            await self._send_response(response)

    async def _handle_unsubscribe_bars(self, request: UnsubscribeBarsRequest) -> None:
        """Handle unsubscribe request"""
        try:
            await self.subscription_manager.unsubscribe(request.subscription_id)
            response = UnsubscribeBarsResponse(
                request_id=request.request_id,
                subscription_id=request.subscription_id,
                success=True,
            )
            await self._send_response(response)
        except Exception as e:
            logger.error(f"Unsubscribe failed: {e}")
            response = UnsubscribeBarsResponse(
                request_id=request.request_id,
                subscription_id=request.subscription_id,
                success=False,
            )
            await self._send_response(response)

    async def _handle_get_config(self, request: GetConfigRequest) -> None:
        """Handle datafeed config request.

        Returns the default source's config when one is set; otherwise
        falls back to the first registered source's config.
        """
        if self.default_source:
            source = self.registry.get(self.default_source)
            if source:
                config = await source.get_config()
                response = GetConfigResponse(
                    request_id=request.request_id, config=config.model_dump(mode="json")
                )
                await self._send_response(response)
                return

        # Return aggregate config from all sources
        all_sources = self.registry.list_sources()
        if not all_sources:
            await self._send_error(
                request.request_id, "NO_SOURCES", "No data sources available"
            )
            return

        # Just use first source's config for now
        # TODO: Aggregate configs from multiple sources
        source = self.registry.get(all_sources[0])
        if source:
            config = await source.get_config()
            response = GetConfigResponse(
                request_id=request.request_id, config=config.model_dump(mode="json")
            )
            await self._send_response(response)

    async def _send_response(self, response) -> None:
        """Send a response message to the client.

        Any send failure marks the connection dead so the main loop exits.
        """
        try:
            await self.websocket.send_json(response.model_dump(mode="json"))
        except Exception as e:
            logger.error(f"Error sending response: {e}")
            self._connected = False

    async def _send_error(self, request_id: str, error_code: str, error_message: str) -> None:
        """Send an error response"""
        error = ErrorResponse(
            request_id=request_id, error_code=error_code, error_message=error_message
        )
        await self._send_response(error)

    def _queue_update(self, coro):
        """Queue an async update to be sent (prevents blocking callback)"""
        # NOTE(review): asyncio.create_task requires a *running* event loop in
        # the calling thread; if a data source fires ticks from a non-asyncio
        # thread this raises RuntimeError — confirm the tick delivery context.
        import asyncio

        asyncio.create_task(coro)
|
||||
170
backend/src/datasource/websocket_protocol.py
Normal file
170
backend/src/datasource/websocket_protocol.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
WebSocket protocol messages for TradingView-compatible datafeed API.
|
||||
|
||||
These messages define the wire format for client-server communication
|
||||
over WebSocket for symbol search, historical data, and real-time subscriptions.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Client -> Server Messages
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class SearchSymbolsRequest(BaseModel):
    """Request to search for symbols matching a query"""

    # Literal `type` acts as the wire-format discriminator for routing.
    type: Literal["search_symbols"] = "search_symbols"
    request_id: str = Field(description="Client-generated request ID for matching responses")
    query: str = Field(description="Search query string")
    symbol_type: Optional[str] = Field(default=None, description="Filter by instrument type")
    exchange: Optional[str] = Field(default=None, description="Filter by exchange")
    limit: int = Field(default=30, description="Maximum number of results")
|
||||
|
||||
|
||||
class ResolveSymbolRequest(BaseModel):
    """Request full metadata for a specific symbol"""

    # Literal `type` acts as the wire-format discriminator for routing.
    type: Literal["resolve_symbol"] = "resolve_symbol"
    request_id: str
    # May be namespaced as "SOURCE:SYMBOL"; otherwise the server's default
    # source is used.
    symbol: str = Field(description="Symbol identifier to resolve")
|
||||
|
||||
|
||||
class GetBarsRequest(BaseModel):
    """Request historical bar data"""

    # Literal `type` acts as the wire-format discriminator for routing.
    type: Literal["get_bars"] = "get_bars"
    request_id: str
    symbol: str
    resolution: str = Field(description="Time resolution (e.g., '1', '5', '60', '1D')")
    from_time: int = Field(description="Start time (Unix timestamp in seconds)")
    to_time: int = Field(description="End time (Unix timestamp in seconds)")
    countback: Optional[int] = Field(default=None, description="Maximum number of bars to return")
|
||||
|
||||
|
||||
class SubscribeBarsRequest(BaseModel):
    """Subscribe to real-time bar updates"""

    # Literal `type` acts as the wire-format discriminator for routing.
    type: Literal["subscribe_bars"] = "subscribe_bars"
    request_id: str
    symbol: str
    resolution: str
    # Echoed back in every BarUpdateMessage and used for unsubscribing.
    subscription_id: str = Field(description="Client-generated subscription ID")
|
||||
|
||||
|
||||
class UnsubscribeBarsRequest(BaseModel):
    """Unsubscribe from real-time updates"""

    # Literal `type` acts as the wire-format discriminator for routing.
    type: Literal["unsubscribe_bars"] = "unsubscribe_bars"
    request_id: str
    # Must match the ID supplied in the original SubscribeBarsRequest.
    subscription_id: str
|
||||
|
||||
|
||||
class GetConfigRequest(BaseModel):
    """Request datafeed configuration"""

    # Literal `type` acts as the wire-format discriminator for routing.
    type: Literal["get_config"] = "get_config"
    request_id: str
|
||||
|
||||
|
||||
# Union of all client request types; each member is discriminated by its
# Literal `type` field.
ClientRequest = Union[
    SearchSymbolsRequest,
    ResolveSymbolRequest,
    GetBarsRequest,
    SubscribeBarsRequest,
    UnsubscribeBarsRequest,
    GetConfigRequest,
]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Server -> Client Messages
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class SearchSymbolsResponse(BaseModel):
    """Response with search results"""

    # Literal `type` identifies this message kind on the wire.
    type: Literal["search_symbols_response"] = "search_symbols_response"
    request_id: str
    # Results are pre-serialized dicts (model_dump output), not model objects.
    results: List[Dict[str, Any]] = Field(description="List of SearchResult objects")
|
||||
|
||||
|
||||
class ResolveSymbolResponse(BaseModel):
    """Response with symbol metadata"""

    # Literal `type` identifies this message kind on the wire.
    type: Literal["resolve_symbol_response"] = "resolve_symbol_response"
    request_id: str
    # Pre-serialized dict (model_dump output), not a model object.
    symbol_info: Dict[str, Any] = Field(description="SymbolInfo object")
|
||||
|
||||
|
||||
class GetBarsResponse(BaseModel):
    """Response with historical bars"""

    # Literal `type` identifies this message kind on the wire.
    type: Literal["get_bars_response"] = "get_bars_response"
    request_id: str
    # Pre-serialized dict (model_dump output), not a model object.
    history: Dict[str, Any] = Field(description="HistoryResult object with bars and metadata")
|
||||
|
||||
|
||||
class SubscribeBarsResponse(BaseModel):
    """Acknowledgment of subscription"""

    # Literal `type` identifies this message kind on the wire.
    type: Literal["subscribe_bars_response"] = "subscribe_bars_response"
    request_id: str
    subscription_id: str
    success: bool
    # Populated with the failure reason when success is False.
    message: Optional[str] = None
|
||||
|
||||
|
||||
class UnsubscribeBarsResponse(BaseModel):
    """Acknowledgment of unsubscribe"""

    # Literal `type` identifies this message kind on the wire.
    type: Literal["unsubscribe_bars_response"] = "unsubscribe_bars_response"
    request_id: str
    subscription_id: str
    success: bool
|
||||
|
||||
|
||||
class GetConfigResponse(BaseModel):
    """Response with datafeed configuration"""

    # Literal `type` identifies this message kind on the wire.
    type: Literal["get_config_response"] = "get_config_response"
    request_id: str
    # Pre-serialized dict (model_dump output), not a model object.
    config: Dict[str, Any] = Field(description="DatafeedConfig object")
|
||||
|
||||
|
||||
class BarUpdateMessage(BaseModel):
    """Real-time bar update (server-initiated, no request_id)"""

    # Literal `type` identifies this message kind on the wire.
    type: Literal["bar_update"] = "bar_update"
    # Matches the ID from the client's SubscribeBarsRequest.
    subscription_id: str
    symbol: str
    resolution: str
    bar: Dict[str, Any] = Field(description="Bar data including time and all columns")
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
    """Error response for any failed request"""

    # Literal `type` identifies this message kind on the wire.
    type: Literal["error"] = "error"
    # Echo of the failing request's ID ("unknown" when it could not be parsed).
    request_id: str
    error_code: str = Field(description="Machine-readable error code")
    error_message: str = Field(description="Human-readable error description")
|
||||
|
||||
|
||||
# Union of all server response types; each member is discriminated by its
# Literal `type` field.
ServerResponse = Union[
    SearchSymbolsResponse,
    ResolveSymbolResponse,
    GetBarsResponse,
    SubscribeBarsResponse,
    UnsubscribeBarsResponse,
    GetConfigResponse,
    BarUpdateMessage,
    ErrorResponse,
]
|
||||
4
backend/src/gateway/__init__.py
Normal file
4
backend/src/gateway/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from gateway.protocol import UserMessage, AgentMessage, ChannelStatus
|
||||
from gateway.hub import Gateway
|
||||
|
||||
__all__ = ["UserMessage", "AgentMessage", "ChannelStatus", "Gateway"]
|
||||
3
backend/src/gateway/channels/__init__.py
Normal file
3
backend/src/gateway/channels/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from gateway.channels.base import Channel
|
||||
|
||||
__all__ = ["Channel"]
|
||||
73
backend/src/gateway/channels/base.py
Normal file
73
backend/src/gateway/channels/base.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncIterator, Dict, Any
|
||||
|
||||
from gateway.protocol import UserMessage, AgentMessage, ChannelStatus
|
||||
|
||||
|
||||
class Channel(ABC):
    """Abstract base class for communication channels.

    Channels are the transport layer between users and the agent system.
    They handle protocol-specific encoding/decoding and maintain connection state.
    """

    def __init__(self, channel_id: str, channel_type: str):
        self.channel_id = channel_id
        self.channel_type = channel_type
        # Subclasses flip this once their transport is established.
        self._connected = False

    @abstractmethod
    async def send(self, message: AgentMessage) -> None:
        """Send a message from the agent to the user through this channel.

        Args:
            message: AgentMessage to send (may be streaming chunk or complete message)
        """
        pass

    @abstractmethod
    async def receive(self) -> AsyncIterator[UserMessage]:
        """Receive messages from the user through this channel.

        Yields:
            UserMessage objects as they arrive
        """
        pass

    @abstractmethod
    async def close(self) -> None:
        """Close the channel and clean up resources."""
        pass

    def supports_streaming(self) -> bool:
        """Whether this channel supports streaming responses.

        Defaults to False; streaming-capable subclasses override.

        Returns:
            True if the channel can handle streaming chunks
        """
        return False

    def supports_attachments(self) -> bool:
        """Whether this channel supports file attachments.

        Defaults to False; attachment-capable subclasses override.

        Returns:
            True if the channel can handle attachments
        """
        return False

    def get_status(self) -> ChannelStatus:
        """Get current channel status.

        Returns:
            ChannelStatus object with connection info and capabilities
        """
        return ChannelStatus(
            channel_id=self.channel_id,
            channel_type=self.channel_type,
            connected=self._connected,
            capabilities={
                "streaming": self.supports_streaming(),
                "attachments": self.supports_attachments(),
                "markdown": True  # Most channels support some form of markdown
            }
        )
|
||||
99
backend/src/gateway/channels/websocket.py
Normal file
99
backend/src/gateway/channels/websocket.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncIterator, Optional
|
||||
from datetime import datetime
|
||||
from fastapi import WebSocket
|
||||
|
||||
from gateway.channels.base import Channel
|
||||
from gateway.protocol import UserMessage, AgentMessage, WebSocketAgentUserMessage, WebSocketAgentChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebSocketChannel(Channel):
|
||||
"""WebSocket-based communication channel.
|
||||
|
||||
Integrates with the existing WebSocket endpoint to provide
|
||||
bidirectional agent communication with streaming support.
|
||||
"""
|
||||
|
||||
def __init__(self, channel_id: str, websocket: WebSocket, session_id: str):
|
||||
super().__init__(channel_id, "websocket")
|
||||
self.websocket = websocket
|
||||
self.session_id = session_id
|
||||
self._connected = True
|
||||
self._receive_queue: asyncio.Queue[UserMessage] = asyncio.Queue()
|
||||
self._receive_task: Optional[asyncio.Task] = None
|
||||
|
||||
def supports_streaming(self) -> bool:
|
||||
"""WebSocket supports streaming responses."""
|
||||
return True
|
||||
|
||||
def supports_attachments(self) -> bool:
|
||||
"""WebSocket can support attachments via URLs."""
|
||||
return True
|
||||
|
||||
async def send(self, message: AgentMessage) -> None:
|
||||
"""Send agent message through WebSocket.
|
||||
|
||||
For streaming messages, sends chunks as they arrive.
|
||||
For complete messages, sends as a single chunk.
|
||||
"""
|
||||
if not self._connected:
|
||||
logger.warning(f"Cannot send message, channel {self.channel_id} not connected")
|
||||
return
|
||||
|
||||
try:
|
||||
chunk = WebSocketAgentChunk(
|
||||
session_id=message.session_id,
|
||||
content=message.content,
|
||||
done=message.done,
|
||||
metadata=message.metadata
|
||||
)
|
||||
chunk_data = chunk.model_dump(mode="json")
|
||||
logger.debug(f"Sending WebSocket message: done={message.done}, content_length={len(message.content)}")
|
||||
await self.websocket.send_json(chunk_data)
|
||||
logger.debug(f"WebSocket message sent successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket send error: {e}", exc_info=True)
|
||||
self._connected = False
|
||||
|
||||
async def receive(self) -> AsyncIterator[UserMessage]:
|
||||
"""Receive messages from WebSocket.
|
||||
|
||||
Yields:
|
||||
UserMessage objects as they arrive from the client
|
||||
"""
|
||||
try:
|
||||
while self._connected:
|
||||
# Read from WebSocket
|
||||
data = await self.websocket.receive_text()
|
||||
message_json = json.loads(data)
|
||||
|
||||
# Only process agent_user_message types
|
||||
if message_json.get("type") == "agent_user_message":
|
||||
msg = WebSocketAgentUserMessage(**message_json)
|
||||
|
||||
user_msg = UserMessage(
|
||||
session_id=msg.session_id,
|
||||
channel_id=self.channel_id,
|
||||
content=msg.content,
|
||||
attachments=msg.attachments,
|
||||
timestamp=datetime.utcnow()
|
||||
)
|
||||
yield user_msg
|
||||
except Exception as e:
|
||||
print(f"WebSocket receive error: {e}")
|
||||
self._connected = False
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the WebSocket connection."""
|
||||
self._connected = False
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
# Note: WebSocket close is handled by the main WebSocket endpoint
|
||||
226
backend/src/gateway/hub.py
Normal file
226
backend/src/gateway/hub.py
Normal file
@@ -0,0 +1,226 @@
|
||||
import asyncio
import logging
from datetime import datetime
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Optional

from gateway.channels.base import Channel
from gateway.protocol import UserMessage, AgentMessage
from gateway.user_session import UserSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Gateway:
|
||||
"""Central hub for routing messages between users, channels, and the agent.
|
||||
|
||||
The Gateway:
|
||||
- Maintains active channels and user sessions
|
||||
- Routes user messages to the agent
|
||||
- Routes agent responses back to appropriate channels
|
||||
- Handles interruption of in-flight agent tasks
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.channels: Dict[str, Channel] = {}
|
||||
self.sessions: Dict[str, UserSession] = {}
|
||||
self.agent_executor: Optional[Callable[[UserSession, UserMessage], Awaitable[AsyncIterator[str]]]] = None
|
||||
|
||||
def set_agent_executor(
|
||||
self,
|
||||
executor: Callable[[UserSession, UserMessage], Awaitable[Any]]
|
||||
) -> None:
|
||||
"""Set the agent executor function.
|
||||
|
||||
Args:
|
||||
executor: Async function that takes (session, message) and returns async iterator of response chunks
|
||||
"""
|
||||
self.agent_executor = executor
|
||||
|
||||
def register_channel(self, channel: Channel) -> None:
|
||||
"""Register a new communication channel.
|
||||
|
||||
Args:
|
||||
channel: Channel instance to register
|
||||
"""
|
||||
self.channels[channel.channel_id] = channel
|
||||
|
||||
def unregister_channel(self, channel_id: str) -> None:
|
||||
"""Unregister a channel.
|
||||
|
||||
Args:
|
||||
channel_id: ID of channel to remove
|
||||
"""
|
||||
if channel_id in self.channels:
|
||||
# Remove channel from any sessions
|
||||
for session in self.sessions.values():
|
||||
session.remove_channel(channel_id)
|
||||
del self.channels[channel_id]
|
||||
|
||||
def get_or_create_session(self, session_id: str, user_id: str) -> UserSession:
|
||||
"""Get existing session or create a new one.
|
||||
|
||||
Args:
|
||||
session_id: Unique session identifier
|
||||
user_id: User identifier
|
||||
|
||||
Returns:
|
||||
UserSession instance
|
||||
"""
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = UserSession(session_id, user_id)
|
||||
return self.sessions[session_id]
|
||||
|
||||
async def route_user_message(self, message: UserMessage) -> None:
|
||||
"""Route a user message to the agent.
|
||||
|
||||
If the agent is currently processing a message, it will be interrupted
|
||||
and restarted with the new message appended to the conversation.
|
||||
|
||||
Args:
|
||||
message: UserMessage from a channel
|
||||
"""
|
||||
logger.info(f"route_user_message called - session: {message.session_id}, channel: {message.channel_id}")
|
||||
|
||||
# Get or create session
|
||||
session = self.get_or_create_session(message.session_id, message.session_id)
|
||||
logger.info(f"Session obtained/created: {message.session_id}, channels: {session.active_channels}")
|
||||
|
||||
# Ensure channel is attached to session
|
||||
session.add_channel(message.channel_id)
|
||||
logger.info(f"Channel added to session: {message.channel_id}")
|
||||
|
||||
# If agent is busy, interrupt it
|
||||
if session.is_busy():
|
||||
logger.info(f"Session is busy, interrupting existing task")
|
||||
await session.interrupt()
|
||||
|
||||
# Add user message to history
|
||||
session.add_message("user", message.content, message.channel_id)
|
||||
logger.info(f"User message added to history, history length: {len(session.get_history())}")
|
||||
|
||||
# Start agent task
|
||||
if self.agent_executor:
|
||||
logger.info("Starting agent task execution")
|
||||
task = asyncio.create_task(
|
||||
self._execute_agent_and_stream(session, message)
|
||||
)
|
||||
session.set_task(task)
|
||||
logger.info(f"Agent task created and set on session")
|
||||
else:
|
||||
logger.error("No agent_executor configured! Cannot process message.")
|
||||
# Send error message to user
|
||||
error_msg = AgentMessage(
|
||||
session_id=session.session_id,
|
||||
target_channels=session.active_channels,
|
||||
content="Error: Agent system not initialized. Please check that ANTHROPIC_API_KEY is configured.",
|
||||
done=True
|
||||
)
|
||||
await self._send_to_channels(error_msg)
|
||||
|
||||
async def _execute_agent_and_stream(self, session: UserSession, message: UserMessage) -> None:
|
||||
"""Execute agent and stream responses back to channels.
|
||||
|
||||
Args:
|
||||
session: User session
|
||||
message: Triggering user message
|
||||
"""
|
||||
logger.info(f"_execute_agent_and_stream starting for session {session.session_id}")
|
||||
try:
|
||||
# Execute agent (returns async generator directly)
|
||||
logger.info("Calling agent_executor...")
|
||||
response_stream = self.agent_executor(session, message)
|
||||
logger.info("Agent executor returned response stream")
|
||||
|
||||
# Stream chunks back to active channels
|
||||
full_response = ""
|
||||
chunk_count = 0
|
||||
async for chunk in response_stream:
|
||||
chunk_count += 1
|
||||
full_response += chunk
|
||||
logger.debug(f"Received chunk #{chunk_count}, length: {len(chunk)}")
|
||||
|
||||
# Send chunk to all active channels
|
||||
agent_msg = AgentMessage(
|
||||
session_id=session.session_id,
|
||||
target_channels=session.active_channels,
|
||||
content=chunk,
|
||||
stream_chunk=True,
|
||||
done=False
|
||||
)
|
||||
await self._send_to_channels(agent_msg)
|
||||
|
||||
logger.info(f"Agent streaming completed, total chunks: {chunk_count}, response length: {len(full_response)}")
|
||||
|
||||
# Send final done message
|
||||
agent_msg = AgentMessage(
|
||||
session_id=session.session_id,
|
||||
target_channels=session.active_channels,
|
||||
content="",
|
||||
stream_chunk=True,
|
||||
done=True
|
||||
)
|
||||
await self._send_to_channels(agent_msg)
|
||||
logger.info("Sent final done message to channels")
|
||||
|
||||
# Add to history
|
||||
session.add_message("assistant", full_response)
|
||||
logger.info("Assistant response added to history")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Task was interrupted
|
||||
logger.warning(f"Agent task interrupted for session {session.session_id}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Agent execution error: {e}", exc_info=True)
|
||||
# Send error message
|
||||
error_msg = AgentMessage(
|
||||
session_id=session.session_id,
|
||||
target_channels=session.active_channels,
|
||||
content=f"Error: {str(e)}",
|
||||
done=True
|
||||
)
|
||||
await self._send_to_channels(error_msg)
|
||||
finally:
|
||||
session.set_task(None)
|
||||
logger.info(f"Agent task completed for session {session.session_id}")
|
||||
|
||||
async def _send_to_channels(self, message: AgentMessage) -> None:
|
||||
"""Send message to specified channels.
|
||||
|
||||
Args:
|
||||
message: AgentMessage to send
|
||||
"""
|
||||
logger.debug(f"Sending message to {len(message.target_channels)} channels")
|
||||
for channel_id in message.target_channels:
|
||||
channel = self.channels.get(channel_id)
|
||||
if channel and channel.get_status().connected:
|
||||
try:
|
||||
await channel.send(message)
|
||||
logger.debug(f"Message sent to channel {channel_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending to channel {channel_id}: {e}", exc_info=True)
|
||||
else:
|
||||
logger.warning(f"Channel {channel_id} not found or not connected")
|
||||
|
||||
def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get session information.
|
||||
|
||||
Args:
|
||||
session_id: Session ID to query
|
||||
|
||||
Returns:
|
||||
Session info dict or None if not found
|
||||
"""
|
||||
session = self.sessions.get(session_id)
|
||||
return session.to_dict() if session else None
|
||||
|
||||
def get_active_sessions(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get all active sessions.
|
||||
|
||||
Returns:
|
||||
Dict mapping session_id to session info
|
||||
"""
|
||||
return {
|
||||
sid: session.to_dict()
|
||||
for sid, session in self.sessions.items()
|
||||
}
|
||||
57
backend/src/gateway/protocol.py
Normal file
57
backend/src/gateway/protocol.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class UserMessage(BaseModel):
|
||||
"""Message from user to agent through a communication channel."""
|
||||
session_id: str
|
||||
channel_id: str
|
||||
content: str
|
||||
attachments: List[str] = Field(default_factory=list, description="URLs or file paths")
|
||||
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class AgentMessage(BaseModel):
|
||||
"""Message from agent to user(s) through one or more channels."""
|
||||
session_id: str
|
||||
target_channels: List[str] = Field(description="List of channel IDs to send to")
|
||||
content: str
|
||||
stream_chunk: bool = Field(default=False, description="True if this is a streaming chunk")
|
||||
done: bool = Field(default=False, description="True if streaming is complete")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ChannelStatus(BaseModel):
|
||||
"""Status information about a communication channel."""
|
||||
channel_id: str
|
||||
channel_type: str = Field(description="Type: 'websocket', 'slack', 'telegram', etc.")
|
||||
connected: bool
|
||||
user_id: Optional[str] = None
|
||||
session_id: Optional[str] = None
|
||||
capabilities: Dict[str, bool] = Field(
|
||||
default_factory=lambda: {
|
||||
"streaming": False,
|
||||
"attachments": False,
|
||||
"markdown": False
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# WebSocket-specific protocol extensions
|
||||
class WebSocketAgentUserMessage(BaseModel):
|
||||
"""WebSocket message: User → Backend (agent chat)"""
|
||||
type: Literal["agent_user_message"] = "agent_user_message"
|
||||
session_id: str
|
||||
content: str
|
||||
attachments: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WebSocketAgentChunk(BaseModel):
|
||||
"""WebSocket message: Backend → User (streaming agent response)"""
|
||||
type: Literal["agent_chunk"] = "agent_chunk"
|
||||
session_id: str
|
||||
content: str
|
||||
done: bool = False
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
107
backend/src/gateway/user_session.py
Normal file
107
backend/src/gateway/user_session.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import asyncio
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""A single message in a conversation."""
|
||||
role: str # "user" or "assistant"
|
||||
content: str
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
channel_id: Optional[str] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class UserSession:
|
||||
"""Manages a user's conversation session with the agent.
|
||||
|
||||
A session tracks:
|
||||
- Active communication channels
|
||||
- Conversation history
|
||||
- In-flight agent tasks (for interruption)
|
||||
- User metadata
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str, user_id: str):
|
||||
self.session_id = session_id
|
||||
self.user_id = user_id
|
||||
self.active_channels: List[str] = []
|
||||
self.conversation_history: List[Message] = []
|
||||
self.current_task: Optional[asyncio.Task] = None
|
||||
self.metadata: Dict[str, Any] = {}
|
||||
self.created_at = datetime.utcnow()
|
||||
self.last_activity = datetime.utcnow()
|
||||
|
||||
def add_channel(self, channel_id: str) -> None:
|
||||
"""Attach a channel to this session."""
|
||||
if channel_id not in self.active_channels:
|
||||
self.active_channels.append(channel_id)
|
||||
self.last_activity = datetime.utcnow()
|
||||
|
||||
def remove_channel(self, channel_id: str) -> None:
|
||||
"""Detach a channel from this session."""
|
||||
if channel_id in self.active_channels:
|
||||
self.active_channels.remove(channel_id)
|
||||
self.last_activity = datetime.utcnow()
|
||||
|
||||
def add_message(self, role: str, content: str, channel_id: Optional[str] = None, **kwargs) -> None:
|
||||
"""Add a message to conversation history."""
|
||||
message = Message(
|
||||
role=role,
|
||||
content=content,
|
||||
channel_id=channel_id,
|
||||
metadata=kwargs
|
||||
)
|
||||
self.conversation_history.append(message)
|
||||
self.last_activity = datetime.utcnow()
|
||||
|
||||
def get_history(self, limit: Optional[int] = None) -> List[Message]:
|
||||
"""Get conversation history.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of recent messages to return (None = all)
|
||||
|
||||
Returns:
|
||||
List of Message objects
|
||||
"""
|
||||
if limit:
|
||||
return self.conversation_history[-limit:]
|
||||
return self.conversation_history
|
||||
|
||||
async def interrupt(self) -> bool:
|
||||
"""Interrupt the current agent task if one is running.
|
||||
|
||||
Returns:
|
||||
True if a task was interrupted, False otherwise
|
||||
"""
|
||||
if self.current_task and not self.current_task.done():
|
||||
self.current_task.cancel()
|
||||
try:
|
||||
await self.current_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self.current_task = None
|
||||
return True
|
||||
return False
|
||||
|
||||
def set_task(self, task: asyncio.Task) -> None:
|
||||
"""Set the current agent task."""
|
||||
self.current_task = task
|
||||
|
||||
def is_busy(self) -> bool:
|
||||
"""Check if the agent is currently processing a request."""
|
||||
return self.current_task is not None and not self.current_task.done()
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize session to dict."""
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"user_id": self.user_id,
|
||||
"active_channels": self.active_channels,
|
||||
"message_count": len(self.conversation_history),
|
||||
"is_busy": self.is_busy(),
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"last_activity": self.last_activity.isoformat(),
|
||||
"metadata": self.metadata
|
||||
}
|
||||
469
backend/src/main.py
Normal file
469
backend/src/main.py
Normal file
@@ -0,0 +1,469 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel
|
||||
import uuid
|
||||
import shutil
|
||||
|
||||
from sync.protocol import HelloMessage, PatchMessage
|
||||
from sync.registry import SyncRegistry
|
||||
from gateway.hub import Gateway
|
||||
from gateway.channels.websocket import WebSocketChannel
|
||||
from gateway.protocol import WebSocketAgentUserMessage
|
||||
from agent.core import create_agent
|
||||
from agent.tools import set_registry, set_datasource_registry
|
||||
from schema.order_spec import SwapOrder
|
||||
from schema.chart_state import ChartState
|
||||
from datasource.registry import DataSourceRegistry
|
||||
from datasource.subscription_manager import SubscriptionManager
|
||||
from datasource.websocket_handler import DatafeedWebSocketHandler
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Load environment variables from .env file (if present)
|
||||
env_path = Path(__file__).parent.parent / ".env"
|
||||
if env_path.exists():
|
||||
load_dotenv(env_path)
|
||||
|
||||
# Load configuration
|
||||
config_path = Path(__file__).parent.parent / "config.yaml"
|
||||
with open(config_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
registry = SyncRegistry()
|
||||
gateway = Gateway()
|
||||
agent_executor = None
|
||||
|
||||
# DataSource infrastructure
|
||||
datasource_registry = DataSourceRegistry()
|
||||
subscription_manager = SubscriptionManager()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Initialize agent system and data sources on startup."""
|
||||
global agent_executor
|
||||
|
||||
# Initialize CCXT data sources
|
||||
try:
|
||||
from datasource.adapters.ccxt_adapter import CCXTDataSource
|
||||
|
||||
# Binance
|
||||
try:
|
||||
binance_source = CCXTDataSource(exchange_id="binance", poll_interval=60)
|
||||
datasource_registry.register("binance", binance_source)
|
||||
subscription_manager.register_source("binance", binance_source)
|
||||
logger.info("DataSource: Registered Binance source")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize Binance source: {e}")
|
||||
|
||||
logger.info(f"DataSource infrastructure initialized with sources: {datasource_registry.list_sources()}")
|
||||
except ImportError as e:
|
||||
logger.warning(f"CCXT not available: {e}. Only demo source will be available.")
|
||||
logger.info("To use real exchange data, install ccxt: pip install ccxt>=4.0.0")
|
||||
|
||||
# Get API keys from environment
|
||||
anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
|
||||
|
||||
if not anthropic_api_key:
|
||||
logger.error("ANTHROPIC_API_KEY not found in environment!")
|
||||
logger.info("Agent system will not be available")
|
||||
else:
|
||||
# Set the registries for agent tools
|
||||
set_registry(registry)
|
||||
set_datasource_registry(datasource_registry)
|
||||
|
||||
# Create and initialize agent
|
||||
agent_executor = create_agent(
|
||||
model_name=config["agent"]["model"],
|
||||
temperature=config["agent"]["temperature"],
|
||||
api_key=anthropic_api_key,
|
||||
checkpoint_db_path=config["memory"]["checkpoint_db"],
|
||||
chroma_db_path=config["memory"]["chroma_db"],
|
||||
embedding_model=config["memory"]["embedding_model"],
|
||||
context_docs_dir=config["agent"]["context_docs_dir"],
|
||||
base_dir=".." # Point to project root from backend/src
|
||||
)
|
||||
|
||||
await agent_executor.initialize()
|
||||
|
||||
# Set agent executor in gateway
|
||||
gateway.set_agent_executor(agent_executor.execute)
|
||||
|
||||
logger.info("Agent system initialized")
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
if agent_executor and agent_executor.memory_manager:
|
||||
await agent_executor.memory_manager.close()
|
||||
logger.info("Agent system shut down")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# Create uploads directory
|
||||
UPLOAD_DIR = Path(__file__).parent.parent / "uploads"
|
||||
UPLOAD_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# Mount static files for serving uploads
|
||||
app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
|
||||
|
||||
# OrderStore model for synchronization
|
||||
class OrderStore(BaseModel):
|
||||
orders: list[SwapOrder] = []
|
||||
|
||||
# ChartStore model for synchronization
|
||||
class ChartStore(BaseModel):
|
||||
chart_state: ChartState = ChartState()
|
||||
|
||||
# Initialize stores
|
||||
order_store = OrderStore()
|
||||
chart_store = ChartStore()
|
||||
|
||||
# Register with SyncRegistry
|
||||
registry.register(order_store, store_name="OrderStore")
|
||||
registry.register(chart_store, store_name="ChartStore")
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
registry.websocket = websocket
|
||||
|
||||
# Create WebSocket channel for agent communication
|
||||
channel_id = f"ws_{id(websocket)}"
|
||||
client_id = f"client_{id(websocket)}"
|
||||
logger.info(f"WebSocket connected - channel_id: {channel_id}, client_id: {client_id}")
|
||||
ws_channel = WebSocketChannel(channel_id, websocket, session_id="default")
|
||||
gateway.register_channel(ws_channel)
|
||||
|
||||
# Helper function to send responses
|
||||
async def send_response(response):
|
||||
try:
|
||||
await websocket.send_json(response.model_dump(mode="json"))
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending response: {e}")
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
logger.debug(f"Received WebSocket message: {data[:200]}...") # Log first 200 chars
|
||||
message_json = json.loads(data)
|
||||
|
||||
if "type" not in message_json:
|
||||
logger.warning(f"Message missing 'type' field: {message_json}")
|
||||
continue
|
||||
|
||||
msg_type = message_json["type"]
|
||||
logger.info(f"Processing message type: {msg_type}")
|
||||
|
||||
# Handle sync protocol messages
|
||||
if msg_type == "hello":
|
||||
hello_msg = HelloMessage(**message_json)
|
||||
logger.info(f"Hello message received with seqs: {hello_msg.seqs}")
|
||||
await registry.sync_client(hello_msg.seqs)
|
||||
elif msg_type == "patch":
|
||||
patch_msg = PatchMessage(**message_json)
|
||||
logger.info(f"Patch message received for store: {patch_msg.store}, seq: {patch_msg.seq}")
|
||||
await registry.apply_client_patch(
|
||||
store_name=patch_msg.store,
|
||||
client_base_seq=patch_msg.seq,
|
||||
patch=patch_msg.patch
|
||||
)
|
||||
elif msg_type == "agent_user_message":
|
||||
# Handle agent messages directly here
|
||||
print(f"[DEBUG] Raw message_json: {message_json}")
|
||||
logger.info(f"Raw message_json: {message_json}")
|
||||
msg = WebSocketAgentUserMessage(**message_json)
|
||||
print(f"[DEBUG] Parsed message - session: {msg.session_id}, content: '{msg.content}' (len={len(msg.content)})")
|
||||
logger.info(f"Agent user message received - session: {msg.session_id}, content: '{msg.content}' (len={len(msg.content)})")
|
||||
from gateway.protocol import UserMessage
|
||||
from datetime import datetime, timezone
|
||||
|
||||
user_msg = UserMessage(
|
||||
session_id=msg.session_id,
|
||||
channel_id=channel_id,
|
||||
content=msg.content,
|
||||
attachments=msg.attachments,
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
logger.info(f"Routing user message to gateway - channel: {channel_id}, session: {msg.session_id}")
|
||||
await gateway.route_user_message(user_msg)
|
||||
logger.info("Message routing completed")
|
||||
|
||||
# Handle datafeed protocol messages
|
||||
elif msg_type in ["get_config", "search_symbols", "resolve_symbol", "get_bars", "subscribe_bars", "unsubscribe_bars"]:
|
||||
from datasource.websocket_protocol import (
|
||||
GetConfigRequest, GetConfigResponse,
|
||||
SearchSymbolsRequest, SearchSymbolsResponse,
|
||||
ResolveSymbolRequest, ResolveSymbolResponse,
|
||||
GetBarsRequest, GetBarsResponse,
|
||||
SubscribeBarsRequest, SubscribeBarsResponse,
|
||||
UnsubscribeBarsRequest, UnsubscribeBarsResponse,
|
||||
ErrorResponse
|
||||
)
|
||||
|
||||
request_id = message_json.get("request_id", "unknown")
|
||||
try:
|
||||
if msg_type == "get_config":
|
||||
req = GetConfigRequest(**message_json)
|
||||
logger.info(f"Getting config, request_id={req.request_id}")
|
||||
sources = datasource_registry.list_sources()
|
||||
logger.info(f"Available sources: {sources}")
|
||||
|
||||
if not sources:
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="NO_SOURCES", error_message="No data sources available")
|
||||
await send_response(error_response)
|
||||
else:
|
||||
# Get config from first source (we can enhance this later to aggregate)
|
||||
source = datasource_registry.get(sources[0])
|
||||
if source:
|
||||
try:
|
||||
config = await source.get_config()
|
||||
logger.info(f"Got config from {sources[0]}")
|
||||
# Enhance with all available exchanges
|
||||
all_exchanges = set()
|
||||
for source_name in sources:
|
||||
s = datasource_registry.get(source_name)
|
||||
if s:
|
||||
try:
|
||||
cfg = await asyncio.wait_for(s.get_config(), timeout=5.0)
|
||||
all_exchanges.update(cfg.exchanges)
|
||||
logger.info(f"Added exchanges from {source_name}: {cfg.exchanges}")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Timeout getting config from {source_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting config from {source_name}: {e}")
|
||||
config_dict = config.model_dump(mode="json")
|
||||
config_dict["exchanges"] = list(all_exchanges)
|
||||
logger.info(f"Sending config with exchanges: {list(all_exchanges)}")
|
||||
response = GetConfigResponse(request_id=req.request_id, config=config_dict)
|
||||
await send_response(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting config: {e}", exc_info=True)
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="ERROR", error_message=str(e))
|
||||
await send_response(error_response)
|
||||
else:
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="SOURCE_NOT_FOUND", error_message="Data sources not available")
|
||||
await send_response(error_response)
|
||||
|
||||
elif msg_type == "search_symbols":
|
||||
req = SearchSymbolsRequest(**message_json)
|
||||
logger.info(f"Searching symbols: query='{req.query}', request_id={req.request_id}")
|
||||
|
||||
# Search all data sources
|
||||
all_results = []
|
||||
sources = datasource_registry.list_sources()
|
||||
logger.info(f"Available data sources: {sources}")
|
||||
|
||||
for source_name in sources:
|
||||
source = datasource_registry.get(source_name)
|
||||
if source:
|
||||
try:
|
||||
results = await asyncio.wait_for(
|
||||
source.search_symbols(query=req.query, type=req.symbol_type, exchange=req.exchange, limit=req.limit),
|
||||
timeout=5.0
|
||||
)
|
||||
all_results.extend([r.model_dump(mode="json") for r in results])
|
||||
logger.info(f"Source '{source_name}' returned {len(results)} results")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Timeout searching source '{source_name}'")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error searching source '{source_name}': {e}")
|
||||
|
||||
logger.info(f"Total search results: {len(all_results)}")
|
||||
response = SearchSymbolsResponse(request_id=req.request_id, results=all_results[:req.limit])
|
||||
await send_response(response)
|
||||
|
||||
elif msg_type == "resolve_symbol":
|
||||
req = ResolveSymbolRequest(**message_json)
|
||||
logger.info(f"Resolving symbol: {req.symbol}")
|
||||
|
||||
# Parse ticker format: "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD")
|
||||
symbol = req.symbol
|
||||
source_name = None
|
||||
symbol_without_exchange = symbol
|
||||
|
||||
# Check if ticker has exchange prefix
|
||||
if ":" in symbol:
|
||||
exchange_prefix, symbol_without_exchange = symbol.split(":", 1)
|
||||
source_name = exchange_prefix.lower()
|
||||
logger.info(f"Parsed ticker: exchange={source_name}, symbol={symbol_without_exchange}")
|
||||
|
||||
# If we identified a source, try it directly
|
||||
if source_name:
|
||||
try:
|
||||
source = datasource_registry.get(source_name)
|
||||
if source:
|
||||
logger.info(f"Trying to resolve '{symbol_without_exchange}' in source '{source_name}'")
|
||||
symbol_info = await asyncio.wait_for(
|
||||
source.resolve_symbol(symbol_without_exchange),
|
||||
timeout=5.0
|
||||
)
|
||||
logger.info(f"Successfully resolved '{symbol_without_exchange}' in source '{source_name}'")
|
||||
response = ResolveSymbolResponse(request_id=req.request_id, symbol_info=symbol_info.model_dump(mode="json"))
|
||||
await send_response(response)
|
||||
else:
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="SOURCE_NOT_FOUND", error_message=f"Data source '{source_name}' not found")
|
||||
await send_response(error_response)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Timeout resolving '{symbol_without_exchange}' in source '{source_name}'")
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="TIMEOUT", error_message=f"Timeout resolving symbol")
|
||||
await send_response(error_response)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error resolving '{symbol_without_exchange}' in source '{source_name}': {e}")
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="SYMBOL_NOT_FOUND", error_message=str(e))
|
||||
await send_response(error_response)
|
||||
else:
|
||||
# No exchange prefix, try all sources
|
||||
found = False
|
||||
for src in datasource_registry.list_sources():
|
||||
try:
|
||||
s = datasource_registry.get(src)
|
||||
if s:
|
||||
logger.info(f"Trying to resolve '{symbol}' in source '{src}'")
|
||||
symbol_info = await asyncio.wait_for(s.resolve_symbol(symbol), timeout=5.0)
|
||||
if symbol_info:
|
||||
logger.info(f"Successfully resolved '{symbol}' in source '{src}'")
|
||||
response = ResolveSymbolResponse(request_id=req.request_id, symbol_info=symbol_info.model_dump(mode="json"))
|
||||
await send_response(response)
|
||||
found = True
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Timeout resolving '{symbol}' in source '{src}'")
|
||||
except Exception as e:
|
||||
logger.info(f"Symbol '{symbol}' not found in source '{src}': {e}")
|
||||
continue
|
||||
|
||||
if not found:
|
||||
# Symbol not found in any source
|
||||
logger.warning(f"Symbol '{symbol}' not found in any data source")
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="SYMBOL_NOT_FOUND", error_message=f"Symbol '{symbol}' not found in any data source")
|
||||
await send_response(error_response)
|
||||
|
||||
elif msg_type == "get_bars":
|
||||
req = GetBarsRequest(**message_json)
|
||||
logger.info(f"Getting bars for symbol: {req.symbol}")
|
||||
|
||||
# Parse ticker format: "EXCHANGE:SYMBOL"
|
||||
symbol = req.symbol
|
||||
source_name = None
|
||||
symbol_without_exchange = symbol
|
||||
|
||||
# Check if ticker has exchange prefix
|
||||
if ":" in symbol:
|
||||
exchange_prefix, symbol_without_exchange = symbol.split(":", 1)
|
||||
source_name = exchange_prefix.lower()
|
||||
logger.info(f"Parsed ticker for bars: exchange={source_name}, symbol={symbol_without_exchange}")
|
||||
|
||||
# If we identified a source, use it directly
|
||||
if source_name:
|
||||
try:
|
||||
source = datasource_registry.get(source_name)
|
||||
if source:
|
||||
logger.info(f"Getting bars for '{symbol_without_exchange}' from source '{source_name}'")
|
||||
history = await asyncio.wait_for(
|
||||
source.get_bars(symbol=symbol_without_exchange, resolution=req.resolution, from_time=req.from_time, to_time=req.to_time, countback=req.countback),
|
||||
timeout=10.0
|
||||
)
|
||||
logger.info(f"Successfully got {len(history.bars)} bars for '{symbol_without_exchange}' from source '{source_name}'")
|
||||
response = GetBarsResponse(request_id=req.request_id, history=history.model_dump(mode="json"))
|
||||
await send_response(response)
|
||||
else:
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="SOURCE_NOT_FOUND", error_message=f"Data source '{source_name}' not found")
|
||||
await send_response(error_response)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Timeout getting bars for '{symbol_without_exchange}' from source '{source_name}'")
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="TIMEOUT", error_message="Timeout fetching bars")
|
||||
await send_response(error_response)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting bars for '{symbol_without_exchange}' from source '{source_name}': {e}")
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="ERROR", error_message=str(e))
|
||||
await send_response(error_response)
|
||||
else:
|
||||
# No exchange prefix - this shouldn't happen with proper tickers
|
||||
logger.warning(f"Ticker '{symbol}' has no exchange prefix")
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="INVALID_TICKER", error_message="Ticker must include exchange prefix (e.g., BINANCE:BTC/USDT)")
|
||||
await send_response(error_response)
|
||||
|
||||
elif msg_type == "subscribe_bars":
|
||||
req = SubscribeBarsRequest(**message_json)
|
||||
# TODO: Implement subscription management
|
||||
response = SubscribeBarsResponse(request_id=req.request_id, subscription_id=req.subscription_id, success=True)
|
||||
await send_response(response)
|
||||
|
||||
elif msg_type == "unsubscribe_bars":
|
||||
req = UnsubscribeBarsRequest(**message_json)
|
||||
response = UnsubscribeBarsResponse(request_id=req.request_id, subscription_id=req.subscription_id, success=True)
|
||||
await send_response(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling {msg_type}: {e}", exc_info=True)
|
||||
error_response = ErrorResponse(request_id=request_id, error_code="INTERNAL_ERROR", error_message=str(e))
|
||||
await send_response(error_response)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket disconnected - channel_id: {channel_id}")
|
||||
registry.websocket = None
|
||||
gateway.unregister_channel(channel_id)
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error: {e}", exc_info=True)
|
||||
registry.websocket = None
|
||||
gateway.unregister_channel(channel_id)
|
||||
|
||||
@app.post("/api/upload")
async def upload_file(file: UploadFile = File(...)):
    """Upload a file and return its URL.

    The file is stored under UPLOAD_DIR with a random UUID name (the
    client's extension is preserved) so uploads can never collide or
    overwrite each other.

    Returns:
        dict with "url" (path relative to the backend), the original
        "filename", and "size" in bytes.

    Raises:
        HTTPException: 500 on any storage failure.
    """
    try:
        # Generate unique filename; keep the client's extension so the
        # static file server can infer the content type.
        file_extension = Path(file.filename).suffix if file.filename else ""
        unique_filename = f"{uuid.uuid4()}{file_extension}"
        file_path = UPLOAD_DIR / unique_filename

        # Save file; copyfileobj streams in chunks, so large uploads do
        # not need to fit in memory.
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # Return URL (relative to backend)
        file_url = f"/uploads/{unique_filename}"
        logger.info(f"File uploaded: {file.filename} -> {file_url}")

        return {
            "url": file_url,
            "filename": file.filename,
            "size": file_path.stat().st_size,
        }
    except Exception as e:
        logger.error(f"File upload error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Fix: always release the spooled temporary file backing the
        # upload; previously the handle leaked on every request.
        await file.close()
|
||||
|
||||
@app.get("/healthz")
async def health():
    """Liveness probe: always reports the service as up."""
    payload = {"status": "ok"}
    return payload
|
||||
|
||||
# Background task to simulate backend updates (optional, for demo)
|
||||
async def simulate_backend_updates():
    """Demo background loop: every 5 seconds push pending state diffs.

    Runs forever. When no websocket client is connected the tick is a
    no-op; push_all() itself only sends traffic for stores that changed.
    """
    while True:
        await asyncio.sleep(5)
        if not registry.websocket:
            # Nobody connected yet -- nothing to push this tick.
            continue
        # Example: could add/modify orders here
        await registry.push_all()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Development convenience: run the app directly with uvicorn, binding
    # all interfaces on the port configured as "www_port" in config.yaml.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=config["www_port"])
|
||||
21
backend/src/schema/chart_state.py
Normal file
21
backend/src/schema/chart_state.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ChartState(BaseModel):
    """Tracks the user's current chart view state.

    This state is synchronized between the frontend and backend to allow
    the AI agent to understand what the user is currently viewing.
    """

    # Current symbol being viewed (e.g., "BINANCE:BTC/USDT", "BINANCE:ETH/USDT")
    symbol: str = Field(default="BINANCE:BTC/USDT", description="Current trading pair symbol")

    # Time range currently visible on chart (Unix timestamps in seconds)
    # These represent the leftmost and rightmost visible candle times
    # NOTE(review): None presumably means the frontend has not reported a
    # visible range yet -- confirm against the frontend sync code.
    start_time: Optional[int] = Field(default=None, description="Start time of visible range (Unix timestamp in seconds)")
    end_time: Optional[int] = Field(default=None, description="End time of visible range (Unix timestamp in seconds)")

    # Optional: Chart interval/resolution
    interval: str = Field(default="15", description="Chart interval (e.g., '1', '5', '15', '60', 'D')")
|
||||
140
backend/src/schema/order_spec.py
Normal file
140
backend/src/schema/order_spec.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from enum import StrEnum
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel, Field, BeforeValidator, PlainSerializer
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scalar coercion helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _to_int(v: int | str) -> int:
|
||||
return int(v, 0) if isinstance(v, str) else int(v)
|
||||
|
||||
|
||||
def _to_float(v: float | int | str) -> float:
|
||||
return float(v)
|
||||
|
||||
|
||||
# JSON serialisers that emit numeric values as strings, presumably to
# protect big integers from JavaScript's 53-bit number precision --
# confirm against the frontend's expectations.
_int_to_str = PlainSerializer(str, return_type=str, when_used="json")
_float_to_str = PlainSerializer(str, return_type=str, when_used="json")

# Always stored as Python int; accepts int or string on input; serialises to string in JSON.
# NOTE(review): the bit-widths in these alias names are nominal only --
# _to_int performs no range validation; confirm ranges are enforced elsewhere.
type Uint8 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
type Uint16 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
type Uint24 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
type Uint32 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
type Uint64 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
type Uint256 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
type Float = Annotated[float, BeforeValidator(_to_float), _float_to_str]

# Ethereum address: "0x" followed by exactly 40 hex digits.
ETH_ADDRESS_PATTERN = r"^0x[0-9a-fA-F]{40}$"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Enums
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Exchange(StrEnum):
    """Supported execution venues for a swap route."""
    UNISWAP_V2 = "UniswapV2"
    UNISWAP_V3 = "UniswapV3"
|
||||
|
||||
|
||||
class OcoMode(StrEnum):
    """When sibling orders in a one-cancels-other group are cancelled."""
    NO_OCO = "NO_OCO"
    CANCEL_ON_PARTIAL_FILL = "CANCEL_ON_PARTIAL_FILL"
    CANCEL_ON_COMPLETION = "CANCEL_ON_COMPLETION"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Supporting models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Route(BaseModel):
    """Execution venue plus fee tier used to route a swap."""

    model_config = {"extra": "forbid"}

    exchange: Exchange
    fee: Uint24 = Field(description="Pool fee tier; also used as maxFee on UniswapV3")
|
||||
|
||||
|
||||
class Line(BaseModel):
    """Price line: price = intercept + slope * time. Both zero means line is disabled."""

    model_config = {"extra": "forbid"}

    # Price at time zero (per the formula above).
    intercept: Float
    # Price change per unit time; zero gives a horizontal (classic limit) line.
    slope: Float
|
||||
|
||||
|
||||
class Tranche(BaseModel):
    """One time-windowed slice of a SwapOrder with its own price constraints."""

    model_config = {"extra": "forbid"}

    fraction: Uint16 = Field(description="Fraction of total order amount; MAX_FRACTION (65535) = 100%")
    # NOTE(review): "relative" presumably means start/end are offsets from
    # order placement rather than absolute timestamps -- confirm against
    # the contract that consumes this struct.
    startTimeIsRelative: bool
    endTimeIsRelative: bool
    minIsBarrier: bool = Field(description="Not yet supported")
    maxIsBarrier: bool = Field(description="Not yet supported")
    marketOrder: bool = Field(
        description="If true, min/max lines ignored; minLine intercept treated as max slippage"
    )
    minIsRatio: bool
    maxIsRatio: bool
    rateLimitFraction: Uint16 = Field(description="Max fraction of this tranche's amount per rate-limited execution")
    rateLimitPeriod: Uint24 = Field(description="Seconds between rate limit resets")
    startTime: Uint32 = Field(description="Unix timestamp; 0 (DISTANT_PAST) effectively disables")
    endTime: Uint32 = Field(description="Unix timestamp; 4294967295 (DISTANT_FUTURE) effectively disables")
    minLine: Line = Field(description="Traditional limit order constraint; can be diagonal")
    maxLine: Line = Field(description="Upper price boundary (too-good-a-price guard)")
|
||||
|
||||
|
||||
class TrancheStatus(BaseModel):
    """Execution progress and concrete timing for a single tranche."""

    model_config = {"extra": "forbid"}

    filled: Uint256 = Field(description="Amount filled by this tranche")
    activationTime: Uint32 = Field(description="Earliest time this tranche can execute; 0 = not yet concrete")
    startTime: Uint32 = Field(description="Concrete start timestamp")
    endTime: Uint32 = Field(description="Concrete end timestamp")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Order models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SwapOrder(BaseModel):
    """A tranchable swap order between two ERC-20 tokens.

    NOTE(review): camelCase field names presumably mirror an on-chain
    struct layout -- do not rename without checking the serializer.
    """

    model_config = {"extra": "forbid"}

    tokenIn: str = Field(pattern=ETH_ADDRESS_PATTERN, description="ERC-20 input token address")
    tokenOut: str = Field(pattern=ETH_ADDRESS_PATTERN, description="ERC-20 output token address")
    route: Route
    amount: Uint256 = Field(description="Maximum quantity to fill")
    minFillAmount: Uint256 = Field(description="Minimum tranche amount before tranche is considered complete")
    amountIsInput: bool = Field(description="true = amount is tokenIn quantity; false = tokenOut")
    outputDirectlyToOwner: bool = Field(description="true = proceeds go to vault owner; false = vault")
    inverted: bool = Field(description="false = tokenIn/tokenOut price direction (Uniswap natural)")
    conditionalOrder: Uint64 = Field(
        description="NO_CONDITIONAL_ORDER = 2^64-1; high bit set = relative index within placement group"
    )
    # Every order has at least one tranche; the whole-order case is a
    # single tranche with fraction = MAX_FRACTION.
    tranches: list[Tranche] = Field(min_length=1)
|
||||
|
||||
|
||||
class OcoGroup(BaseModel):
    """Orders linked one-cancels-other; *mode* controls when cancellation fires."""

    model_config = {"extra": "forbid"}

    mode: OcoMode
    orders: list[SwapOrder] = Field(min_length=1)
|
||||
|
||||
|
||||
class SwapOrderStatus(BaseModel):
    """A SwapOrder together with its execution / cancellation state."""

    model_config = {"extra": "forbid"}

    order: SwapOrder
    fillFeeHalfBps: Uint8 = Field(description="Fill fee in half-bps (1/20000); max 255 = 1.275%")
    canceled: bool = Field(description="If true, order is canceled regardless of cancelAllIndex")
    startTime: Uint32 = Field(description="Earliest block.timestamp at which order may execute")
    ocoGroup: Uint64 = Field(description="Index into ocoGroups; NO_OCO_INDEX = 2^64-1")
    originalOrder: Uint64 = Field(description="Index of the original order in the orders array")
    startPrice: Uint256 = Field(description="Price at order start")
    filled: Uint256 = Field(description="Total amount filled so far")
    # One status entry per tranche of `order`.
    trancheStatus: list[TrancheStatus]
|
||||
|
||||
|
||||
26
backend/src/sync/protocol.py
Normal file
26
backend/src/sync/protocol.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from typing import Any, Dict, List, Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SnapshotMessage(BaseModel):
    """Full state of one store, replacing whatever the client currently holds."""

    type: Literal["snapshot"] = "snapshot"
    # Name of the store this snapshot belongs to.
    store: str
    # Store sequence number the snapshot corresponds to.
    seq: int
    # Complete JSON-serialisable state of the store's model.
    state: Dict[str, Any]
|
||||
|
||||
class PatchMessage(BaseModel):
    """A list of JSON-Patch operations advancing one store to sequence *seq*."""

    type: Literal["patch"] = "patch"
    # Name of the store the patch applies to.
    store: str
    # Sequence number the store reaches after applying this patch.
    seq: int
    # JSON-Patch operations (op/path/value dicts).
    patch: List[Dict[str, Any]]
|
||||
|
||||
class HelloMessage(BaseModel):
    """Client greeting: maps store name -> last sequence number the client has seen."""

    type: Literal["hello"] = "hello"
    seqs: Dict[str, int]
|
||||
|
||||
# Union type for all messages from backend to frontend
BackendMessage = Union[SnapshotMessage, PatchMessage]

# Union type for all messages from frontend to backend
# (PatchMessage is shared: both sides send patches; only the backend sends
# snapshots, only the frontend sends hello.)
FrontendMessage = Union[HelloMessage, PatchMessage]
|
||||
174
backend/src/sync/registry.py
Normal file
174
backend/src/sync/registry.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from collections import deque
|
||||
from typing import Any, Dict, List, Optional, Tuple, Deque
|
||||
|
||||
import jsonpatch
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sync.protocol import SnapshotMessage, PatchMessage
|
||||
|
||||
|
||||
class SyncEntry:
    """Per-store sync bookkeeping: the model, its sequence number, and patch history."""

    def __init__(self, model: BaseModel, store_name: str, history_size: int = 50):
        self.model = model
        self.store_name = store_name
        # Monotonic sequence number; bumped once per committed patch.
        self.seq = 0
        # JSON-mode dump of the model at the last commit; diffs are computed
        # against this baseline.
        self.last_snapshot = model.model_dump(mode="json")
        # Ring buffer of (seq, patch) pairs used for client catch-up; the
        # oldest entries fall off once history_size is exceeded.
        self.history: Deque[Tuple[int, List[Dict[str, Any]]]] = deque(maxlen=history_size)

    def compute_patch(self) -> Optional[List[Dict[str, Any]]]:
        """Return the JSON-Patch from last_snapshot to the model's current
        state, or None when nothing changed."""
        current_state = self.model.model_dump(mode="json")
        patch = jsonpatch.make_patch(self.last_snapshot, current_state)
        if not patch.patch:
            return None
        return patch.patch

    def commit_patch(self, patch: List[Dict[str, Any]]):
        """Record *patch* under the next sequence number and refresh the baseline."""
        self.seq += 1
        self.history.append((self.seq, patch))
        self.last_snapshot = self.model.model_dump(mode="json")

    def catchup_patches(self, since_seq: int) -> Optional[List[Tuple[int, List[Dict[str, Any]]]]]:
        """Return the (seq, patch) pairs a client at *since_seq* is missing.

        Returns [] when the client is already current, and None when the
        required patches have aged out of the ring buffer (caller should
        fall back to sending a full snapshot).
        """
        if since_seq == self.seq:
            return []

        # Check if all patches from since_seq + 1 to self.seq are in history
        if not self.history or self.history[0][0] > since_seq + 1:
            return None

        result = []
        for seq, patch in self.history:
            if seq > since_seq:
                result.append((seq, patch))
        return result
|
||||
|
||||
class SyncRegistry:
    """Holds all synced stores and mirrors their changes over one websocket.

    Single-client design: exactly one websocket is tracked at a time.
    """

    def __init__(self):
        # store_name -> SyncEntry for every registered model.
        self.entries: Dict[str, SyncEntry] = {}
        self.websocket: Optional[Any] = None  # Expecting a FastAPI WebSocket or similar

    def register(self, model: BaseModel, store_name: str):
        """Start tracking *model* under *store_name* (replaces any previous entry)."""
        self.entries[store_name] = SyncEntry(model, store_name)

    async def push_all(self):
        """Diff every store against its baseline and send a PatchMessage for each change."""
        import logging
        logger = logging.getLogger(__name__)

        if not self.websocket:
            logger.warning("push_all: No websocket connected, cannot push updates")
            return

        logger.info(f"push_all: Processing {len(self.entries)} store entries")
        for entry in self.entries.values():
            patch = entry.compute_patch()
            if patch:
                logger.info(f"push_all: Found patch for store '{entry.store_name}': {patch}")
                # Commit before sending so entry.seq reflects this patch.
                entry.commit_patch(patch)
                msg = PatchMessage(store=entry.store_name, seq=entry.seq, patch=patch)
                logger.info(f"push_all: Sending patch message for '{entry.store_name}' seq={entry.seq}")
                await self.websocket.send_json(msg.model_dump(mode="json"))
                logger.info(f"push_all: Patch sent successfully for '{entry.store_name}'")
            else:
                logger.debug(f"push_all: No changes detected for store '{entry.store_name}'")

    async def sync_client(self, client_seqs: Dict[str, int]):
        """Bring a (re)connecting client up to date.

        For each store, replay the patches the client missed when they are
        still in history; otherwise send a full snapshot.
        """
        if not self.websocket:
            return

        for store_name, entry in self.entries.items():
            # -1 for unknown stores guarantees catchup_patches misses and a
            # full snapshot is sent.
            client_seq = client_seqs.get(store_name, -1)
            patches = entry.catchup_patches(client_seq)

            if patches is not None:
                # Replay patches
                for seq, patch in patches:
                    msg = PatchMessage(store=store_name, seq=seq, patch=patch)
                    await self.websocket.send_json(msg.model_dump(mode="json"))
            else:
                # Send full snapshot
                msg = SnapshotMessage(
                    store=store_name,
                    seq=entry.seq,
                    state=entry.model.model_dump(mode="json")
                )
                await self.websocket.send_json(msg.model_dump(mode="json"))

    async def apply_client_patch(self, store_name: str, client_base_seq: int, patch: List[Dict[str, Any]]):
        """Apply a patch received from the client.

        If the client's base sequence matches ours, the patch is applied
        directly. If the client is behind, the conflict is resolved
        frontend-wins: backend changes on paths the client touched are
        discarded and the merged state is broadcast as a snapshot.

        NOTE(review): client_base_seq > entry.seq (client ahead of backend)
        falls through silently -- confirm that is intended.
        """
        import logging
        logger = logging.getLogger(__name__)

        logger.info(f"apply_client_patch: store={store_name}, client_base_seq={client_base_seq}, patch={patch}")

        entry = self.entries.get(store_name)
        if not entry:
            logger.warning(f"apply_client_patch: Store '{store_name}' not found in registry")
            return

        logger.info(f"apply_client_patch: Current backend seq={entry.seq}")

        if client_base_seq == entry.seq:
            # No conflict
            logger.info("apply_client_patch: No conflict - applying patch directly")
            current_state = entry.model.model_dump(mode="json")
            logger.info(f"apply_client_patch: Current state before patch: {current_state}")
            new_state = jsonpatch.apply_patch(current_state, patch)
            logger.info(f"apply_client_patch: New state after patch: {new_state}")
            self._update_model(entry.model, new_state)

            entry.commit_patch(patch)
            logger.info(f"apply_client_patch: Patch committed, new seq={entry.seq}")
            # Don't broadcast back to client - they already have this change
            # Broadcasting would cause an infinite loop
            logger.info("apply_client_patch: Not broadcasting back to originating client")

        elif client_base_seq < entry.seq:
            # Conflict! Frontend wins.
            # 1. Collect the backend patches the client hasn't seen.
            backend_patches = []
            for seq, p in entry.history:
                if seq > client_base_seq:
                    backend_patches.append(p)

            # 2. We cannot reconstruct the state at client_base_seq (history
            # only stores forward patches), so instead: apply the frontend
            # patch to the CURRENT state, then re-apply the backend deltas
            # whose paths the frontend did not touch. Frontend wins on any
            # overlapping path.

            frontend_paths = {p['path'] for p in patch}

            current_state = entry.model.model_dump(mode="json")
            # Apply frontend patch
            new_state = jsonpatch.apply_patch(current_state, patch)

            # Re-apply backend patches that don't overlap
            for b_patch in backend_patches:
                filtered_b_patch = [op for op in b_patch if op['path'] not in frontend_paths]
                if filtered_b_patch:
                    new_state = jsonpatch.apply_patch(new_state, filtered_b_patch)

            self._update_model(entry.model, new_state)

            # Commit the result as a single new patch
            # We need to compute what changed from last_snapshot to new_state
            final_patch = jsonpatch.make_patch(entry.last_snapshot, new_state).patch
            if final_patch:
                entry.commit_patch(final_patch)
                # Broadcast resolved state as snapshot to converge
                if self.websocket:
                    msg = SnapshotMessage(
                        store=entry.store_name,
                        seq=entry.seq,
                        state=entry.model.model_dump(mode="json")
                    )
                    await self.websocket.send_json(msg.model_dump(mode="json"))

    def _update_model(self, model: BaseModel, new_data: Dict[str, Any]):
        """Copy *new_data* onto *model* in place so external references stay valid."""
        # Update model using model_validate for potentially nested models
        new_model = model.__class__.model_validate(new_data)
        for field in model.model_fields:
            setattr(model, field, getattr(new_model, field))
|
||||
0
backend/tests/__init__.py
Normal file
0
backend/tests/__init__.py
Normal file
54
backend/tests/test_websocket.py
Normal file
54
backend/tests/test_websocket.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import asyncio
|
||||
import json
|
||||
import websockets
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
async def test_websocket_sync():
    """Integration check of the websocket snapshot/patch sync protocol.

    Requires a live backend server (see the __main__ note below); this is
    a manual smoke test, not an isolated unit test.
    """
    # NOTE(review): port 8000 is hardcoded here, but backend/config.yaml
    # declares www_port 8080 / server_port 8081 -- confirm which port
    # actually serves /ws.
    uri = "ws://localhost:8000/ws"
    async with websockets.connect(uri) as websocket:
        # 1. Send hello (empty seqs -> server should respond with full snapshots)
        hello = {
            "type": "hello",
            "seqs": {}
        }
        await websocket.send(json.dumps(hello))

        # 2. Receive snapshots
        # Expecting TraderState and StrategyState
        responses = []
        for _ in range(2):
            resp = await websocket.recv()
            responses.append(json.loads(resp))

        assert any(r["store"] == "TraderState" for r in responses)
        assert any(r["store"] == "StrategyState" for r in responses)

        # 3. Send a patch for TraderState, based on the seq we just received
        trader_resp = next(r for r in responses if r["store"] == "TraderState")
        current_seq = trader_resp["seq"]

        patch_msg = {
            "type": "patch",
            "store": "TraderState",
            "seq": current_seq,
            "patch": [{"op": "replace", "path": "/status", "value": "busy"}]
        }
        await websocket.send(json.dumps(patch_msg))

        # 4. Receive confirmation patch: the server is expected to echo the
        # change back with the sequence number advanced by one.
        confirm_resp = await websocket.recv()
        confirm_json = json.loads(confirm_resp)
        assert confirm_json["type"] == "patch"
        assert confirm_json["store"] == "TraderState"
        assert confirm_json["seq"] == current_seq + 1
        assert confirm_json["patch"][0]["value"] == "busy"
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test -- this script requires the server to be running:
    #   PYTHONPATH=backend/src python3 backend/src/main.py
    try:
        asyncio.run(test_websocket_sync())
    except Exception as e:
        print(f"Test failed: {e}")
    else:
        print("Test passed!")
||||
Reference in New Issue
Block a user