initial commit with charts and assistant chat

This commit is contained in:
2026-03-02 00:08:19 -04:00
commit d907c5765e
1828 changed files with 50054 additions and 0 deletions

20
backend/config.yaml Normal file
View File

@@ -0,0 +1,20 @@
# Backend service configuration.

# Port serving the static web frontend.
www_port: 8080
# Port for the backend API / websocket server.
server_port: 8081

# Agent configuration
agent:
  # Anthropic model used by the LangChain agent (see AgentExecutor defaults).
  model: "claude-sonnet-4-20250514"
  # Sampling temperature for the model.
  temperature: 0.7
  # Directory containing markdown context documents loaded at startup.
  context_docs_dir: "doc"

# Local memory configuration (free & sophisticated!)
# NOTE(review): nesting level reconstructed — confirm whether `memory` is
# expected at top level or under `agent` by the config loader.
memory:
  # LangGraph checkpointing (SQLite for conversation state)
  checkpoint_db: "data/checkpoints.db"
  # ChromaDB (local vector DB for semantic search)
  chroma_db: "data/chroma"
  # Sentence-transformers model (local embeddings)
  # Options: all-MiniLM-L6-v2 (fast, small), all-mpnet-base-v2 (better quality)
  embedding_model: "all-MiniLM-L6-v2"

32
backend/requirements.txt Normal file
View File

@@ -0,0 +1,32 @@
# Core data validation + scientific stack
# Fixed: "pydantic2" is not a package on PyPI — Pydantic v2 ships as "pydantic".
pydantic>=2.0.0
seaborn
pandas
numpy
scipy
matplotlib
# Web server / API
fastapi
uvicorn
websockets
jsonpatch
python-multipart
# Market data
ccxt>=4.0.0
pyyaml
# LangChain agent dependencies
langchain>=0.3.0
langgraph>=0.2.0
langgraph-checkpoint-sqlite>=1.0.0
langchain-anthropic>=0.3.0
langchain-community>=0.3.0
# Local memory system
chromadb>=0.4.0
sentence-transformers>=2.0.0
sqlalchemy>=2.0.0
aiosqlite>=0.19.0
# Async utilities
aiofiles>=24.0.0
# Environment configuration
python-dotenv>=1.0.0

View File

@@ -0,0 +1,3 @@
"""Agent package: exposes the factory that builds the LangGraph agent executor."""
from agent.core import create_agent

__all__ = ["create_agent"]

296
backend/src/agent/core.py Normal file
View File

@@ -0,0 +1,296 @@
import asyncio
import logging
from typing import AsyncIterator, Dict, Any, Optional
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import create_react_agent
from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS
from agent.memory import MemoryManager
from agent.session import SessionManager
from agent.prompts import build_system_prompt
from gateway.user_session import UserSession
from gateway.protocol import UserMessage as GatewayUserMessage
logger = logging.getLogger(__name__)
class AgentExecutor:
    """LangGraph-based agent executor with streaming support.

    Handles agent invocation, tool execution, and response streaming.
    Supports interruption for real-time user interaction.
    """

    def __init__(
        self,
        model_name: str = "claude-sonnet-4-20250514",
        temperature: float = 0.7,
        api_key: Optional[str] = None,
        memory_manager: Optional[MemoryManager] = None
    ):
        """Initialize agent executor.

        Args:
            model_name: Anthropic model name
            temperature: Model temperature
            api_key: Anthropic API key
            memory_manager: MemoryManager instance (a fresh one is created if None)
        """
        self.model_name = model_name
        self.temperature = temperature
        self.api_key = api_key
        # Initialize LLM (streaming=True so tokens can be relayed as they arrive)
        self.llm = ChatAnthropic(
            model=model_name,
            temperature=temperature,
            api_key=api_key,
            streaming=True
        )
        # Memory and session management
        self.memory_manager = memory_manager or MemoryManager()
        self.session_manager = SessionManager(self.memory_manager)
        self.agent = None  # Will be created after initialization (see initialize())

    async def initialize(self) -> None:
        """Initialize the agent system (must be awaited before execute())."""
        await self.memory_manager.initialize()
        # Create agent with tools and LangGraph checkpointing
        checkpointer = self.memory_manager.get_checkpointer()
        # Build initial system prompt with context; active channels start empty
        context = self.memory_manager.get_context_prompt()
        system_prompt = build_system_prompt(context, [])
        self.agent = create_react_agent(
            self.llm,
            SYNC_TOOLS + DATASOURCE_TOOLS,
            prompt=system_prompt,
            checkpointer=checkpointer
        )

    async def _clear_checkpoint(self, session_id: str) -> None:
        """Clear the checkpoint for a session to prevent resuming from invalid state.

        This is called when an error occurs during agent execution to ensure
        the next interaction starts fresh instead of trying to resume from
        a broken state (e.g., orphaned tool calls).

        Args:
            session_id: The session ID whose checkpoint should be cleared
        """
        try:
            checkpointer = self.memory_manager.get_checkpointer()
            if checkpointer:
                # Delete all checkpoints for this thread
                # LangGraph uses thread_id as the key
                # The checkpointer API doesn't have a direct delete method,
                # but we can use the underlying connection
                # NOTE(review): reaches into AsyncSqliteSaver internals (.conn) and
                # assumes a table named "checkpoints" — verify against the installed
                # langgraph-checkpoint-sqlite schema when upgrading that package.
                async with checkpointer.conn.cursor() as cur:
                    await cur.execute(
                        "DELETE FROM checkpoints WHERE thread_id = ?",
                        (session_id,)
                    )
                    await checkpointer.conn.commit()
                logger.info(f"Cleared checkpoint for session {session_id}")
        except Exception as e:
            # Best-effort cleanup: a failed delete must not mask the original error.
            logger.warning(f"Failed to clear checkpoint for session {session_id}: {e}")

    def _build_system_message(self, state: Dict[str, Any]) -> SystemMessage:
        """Build system message with context.

        Args:
            state: Agent state

        Returns:
            SystemMessage with full context
        """
        # Get context from loaded documents
        context = self.memory_manager.get_context_prompt()
        # Get active channels from metadata
        active_channels = state.get("metadata", {}).get("active_channels", [])
        # Build system prompt
        system_prompt = build_system_prompt(context, active_channels)
        return SystemMessage(content=system_prompt)

    async def execute(
        self,
        session: UserSession,
        message: GatewayUserMessage
    ) -> AsyncIterator[str]:
        """Execute the agent and stream responses.

        Args:
            session: User session
            message: User message

        Yields:
            Response chunks as they're generated
        """
        # NOTE(review): `message` is never referenced below — the model input is
        # built solely from session.get_history(); presumably the gateway appends
        # the incoming message to the session history before calling execute — confirm.
        logger.info(f"AgentExecutor.execute called for session {session.session_id}")
        # Get session lock to prevent concurrent execution
        lock = await self.session_manager.get_session_lock(session.session_id)
        logger.info(f"Session lock acquired for {session.session_id}")
        async with lock:
            try:
                # Build message history (last 10 messages, mapped to LangChain types)
                messages = []
                history = session.get_history(limit=10)
                logger.info(f"Building message history, {len(history)} messages in history")
                for i, msg in enumerate(history):
                    logger.info(f"History message {i}: role={msg.role}, content_len={len(msg.content)}, content='{msg.content[:100]}'")
                    if msg.role == "user":
                        messages.append(HumanMessage(content=msg.content))
                    elif msg.role == "assistant":
                        messages.append(AIMessage(content=msg.content))
                logger.info(f"Prepared {len(messages)} messages for agent")
                for i, msg in enumerate(messages):
                    logger.info(f"LangChain message {i}: type={type(msg).__name__}, content_len={len(msg.content)}, content='{msg.content[:100] if msg.content else 'EMPTY'}'")
                # Prepare config with metadata; thread_id keys LangGraph checkpointing
                config = RunnableConfig(
                    configurable={
                        "thread_id": session.session_id
                    },
                    metadata={
                        "session_id": session.session_id,
                        "user_id": session.user_id,
                        "active_channels": session.active_channels
                    }
                )
                logger.info(f"Agent config prepared: thread_id={session.session_id}")
                # Invoke agent with streaming
                logger.info("Starting agent.astream_events...")
                full_response = ""
                event_count = 0
                chunk_count = 0
                async for event in self.agent.astream_events(
                    {"messages": messages},
                    config=config,
                    version="v2"
                ):
                    event_count += 1
                    # Check for cancellation
                    # NOTE(review): Task.cancelled() is only True after cancellation
                    # has fully propagated, so this check is unlikely to ever fire
                    # mid-iteration; real cancellation is handled by the
                    # CancelledError handler below — confirm this branch is needed.
                    if asyncio.current_task().cancelled():
                        logger.warning("Agent execution cancelled")
                        break
                    # Log tool calls
                    if event["event"] == "on_tool_start":
                        tool_name = event.get("name", "unknown")
                        tool_input = event.get("data", {}).get("input", {})
                        logger.info(f"Tool call started: {tool_name} with input: {tool_input}")
                    elif event["event"] == "on_tool_end":
                        tool_name = event.get("name", "unknown")
                        tool_output = event.get("data", {}).get("output")
                        logger.info(f"Tool call completed: {tool_name} with output: {tool_output}")
                    # Extract streaming tokens
                    elif event["event"] == "on_chat_model_stream":
                        chunk = event["data"]["chunk"]
                        if hasattr(chunk, "content") and chunk.content:
                            content = chunk.content
                            # Handle both string and list content
                            if isinstance(content, list):
                                # Extract text from content blocks (Anthropic may emit
                                # dicts or objects with a .text attribute)
                                text_parts = []
                                for block in content:
                                    if isinstance(block, dict) and "text" in block:
                                        text_parts.append(block["text"])
                                    elif hasattr(block, "text"):
                                        text_parts.append(block.text)
                                content = "".join(text_parts)
                            if content:  # Only yield non-empty content
                                full_response += content
                                chunk_count += 1
                                logger.debug(f"Yielding content chunk #{chunk_count}")
                                yield content
                logger.info(f"Agent streaming complete: {event_count} events, {chunk_count} content chunks, {len(full_response)} chars")
                # Save to persistent memory
                logger.info(f"Saving assistant message to persistent memory")
                await self.session_manager.save_message(
                    session.session_id,
                    "assistant",
                    full_response
                )
                logger.info("Assistant message saved to memory")
            except asyncio.CancelledError:
                logger.warning(f"Agent execution cancelled for session {session.session_id}")
                # Clear checkpoint on cancellation to prevent orphaned tool calls
                await self._clear_checkpoint(session.session_id)
                raise
            except Exception as e:
                error_msg = f"Agent execution error: {str(e)}"
                logger.error(error_msg, exc_info=True)
                # Clear checkpoint on error to prevent invalid state (e.g., orphaned tool calls)
                # This ensures the next interaction starts fresh instead of trying to resume
                # from a broken state
                await self._clear_checkpoint(session.session_id)
                yield error_msg
def create_agent(
    model_name: str = "claude-sonnet-4-20250514",
    temperature: float = 0.7,
    api_key: Optional[str] = None,
    checkpoint_db_path: str = "data/checkpoints.db",
    chroma_db_path: str = "data/chroma",
    embedding_model: str = "all-MiniLM-L6-v2",
    context_docs_dir: str = "doc",
    base_dir: str = "."
) -> AgentExecutor:
    """Create an agent executor wired to a local memory manager.

    Note: this synchronous factory does NOT initialize the executor
    (initialization is async). Callers must ``await executor.initialize()``
    before invoking the agent.

    Args:
        model_name: Anthropic model name
        temperature: Model temperature
        api_key: Anthropic API key
        checkpoint_db_path: Path to LangGraph checkpoint SQLite DB
        chroma_db_path: Path to ChromaDB storage directory
        embedding_model: Sentence-transformers model name
        context_docs_dir: Directory with context markdown files
        base_dir: Base directory for resolving paths

    Returns:
        An AgentExecutor that has not yet been initialized
    """
    # Memory manager owns checkpoints (SQLite), semantic memory (ChromaDB),
    # and the markdown context documents; paths resolve relative to base_dir.
    memory_manager = MemoryManager(
        checkpoint_db_path=checkpoint_db_path,
        chroma_db_path=chroma_db_path,
        embedding_model=embedding_model,
        context_docs_dir=context_docs_dir,
        base_dir=base_dir
    )
    # The executor ties the LLM, tools, and memory together.
    executor = AgentExecutor(
        model_name=model_name,
        temperature=temperature,
        api_key=api_key,
        memory_manager=memory_manager
    )
    return executor

380
backend/src/agent/memory.py Normal file
View File

@@ -0,0 +1,380 @@
import os
import glob
from typing import List, Dict, Any, Optional
from datetime import datetime
import aiofiles
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
# Prevent ChromaDB from reporting telemetry to the mothership.
# Set before any client is constructed; initialize() also passes
# anonymized_telemetry=False via Settings as a belt-and-braces measure.
os.environ["ANONYMIZED_TELEMETRY"] = "False"
class MemoryManager:
    """Manages persistent memory using local tools:

    - LangGraph checkpointing (SQLite) for conversation state
    - ChromaDB for semantic memory search
    - Local sentence-transformers for embeddings
    - Memory graph approach for clustering related concepts
    """

    def __init__(
        self,
        checkpoint_db_path: str = "data/checkpoints.db",
        chroma_db_path: str = "data/chroma",
        embedding_model: str = "all-MiniLM-L6-v2",
        context_docs_dir: str = "memory",
        base_dir: str = "."
    ):
        """Initialize memory manager.

        Args:
            checkpoint_db_path: Path to SQLite checkpoint database
            chroma_db_path: Path to ChromaDB directory
            embedding_model: Sentence-transformers model name
            context_docs_dir: Directory containing markdown context files
            base_dir: Base directory for resolving relative paths
        """
        self.checkpoint_db_path = checkpoint_db_path
        self.chroma_db_path = chroma_db_path
        self.embedding_model_name = embedding_model
        self.context_docs_dir = os.path.join(base_dir, context_docs_dir)
        # All of these are populated by initialize(); None until then.
        self.checkpointer: Optional[AsyncSqliteSaver] = None
        self.checkpointer_context: Optional[Any] = None  # Async CM kept for close()
        self.chroma_client: Optional[chromadb.Client] = None
        self.memory_collection: Optional[Any] = None
        self.embedding_model: Optional[SentenceTransformer] = None
        self.context_documents: Dict[str, str] = {}  # filename -> markdown content
        self.initialized = False

    async def initialize(self) -> None:
        """Initialize the memory system and load context documents (idempotent)."""
        if self.initialized:
            return
        # Ensure data directories exist
        os.makedirs(os.path.dirname(self.checkpoint_db_path), exist_ok=True)
        os.makedirs(self.chroma_db_path, exist_ok=True)
        # Initialize LangGraph checkpointer (SQLite). The async context manager is
        # entered manually and retained so close() can exit it later.
        self.checkpointer_context = AsyncSqliteSaver.from_conn_string(
            self.checkpoint_db_path
        )
        self.checkpointer = await self.checkpointer_context.__aenter__()
        await self.checkpointer.setup()
        # Initialize ChromaDB (local, persistent, telemetry off)
        self.chroma_client = chromadb.PersistentClient(
            path=self.chroma_db_path,
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True
            )
        )
        # Get or create memory collection
        self.memory_collection = self.chroma_client.get_or_create_collection(
            name="conversation_memory",
            metadata={"description": "Semantic memory for conversations"}
        )
        # Initialize local embedding model
        print(f"Loading embedding model: {self.embedding_model_name}")
        self.embedding_model = SentenceTransformer(self.embedding_model_name)
        # Load markdown context documents, then index them for semantic search
        await self._load_context_documents()
        await self._index_context_documents()
        self.initialized = True
        print("Memory system initialized (LangGraph + ChromaDB + local embeddings)")

    async def _load_context_documents(self) -> None:
        """Load all markdown files from the context directory into memory."""
        if not os.path.exists(self.context_docs_dir):
            print(f"Warning: Context directory {self.context_docs_dir} not found")
            return
        md_files = glob.glob(os.path.join(self.context_docs_dir, "*.md"))
        for md_file in md_files:
            try:
                async with aiofiles.open(md_file, "r", encoding="utf-8") as f:
                    content = await f.read()
                filename = os.path.basename(md_file)
                self.context_documents[filename] = content
                # Fixed: previously logged a literal "(unknown)" placeholder
                # instead of the actual filename.
                print(f"Loaded context document: {filename}")
            except Exception as e:
                print(f"Error loading {md_file}: {e}")

    async def _index_context_documents(self) -> None:
        """Index context documents in ChromaDB for semantic search."""
        if not self.context_documents or not self.memory_collection:
            return
        for filename, content in self.context_documents.items():
            # Split into sections (by markdown headers)
            sections = self._split_document_into_sections(content, filename)
            for i, section in enumerate(sections):
                # Fixed: the ID must include the source filename. The previous
                # constant "(unknown)" placeholder made IDs collide across
                # documents (section 0 of every file shared one ID).
                doc_id = f"context_{filename}_{i}"
                # Generate embedding locally
                embedding = self.embedding_model.encode(section["content"]).tolist()
                # Add to ChromaDB
                self.memory_collection.add(
                    ids=[doc_id],
                    embeddings=[embedding],
                    documents=[section["content"]],
                    metadatas=[{
                        "type": "context",
                        "source": filename,
                        "section": section["title"],
                        "indexed_at": datetime.utcnow().isoformat()
                    }]
                )
        print(f"Indexed {len(self.context_documents)} context documents")

    def _split_document_into_sections(self, content: str, filename: str) -> List[Dict[str, str]]:
        """Split markdown document into logical sections.

        Args:
            content: Markdown content
            filename: Source filename (used as the title of the leading section)

        Returns:
            List of section dicts with title and content
        """
        sections = []
        current_section = {"title": filename, "content": ""}
        for line in content.split("\n"):
            if line.startswith("#"):
                # Header starts a new section; flush the previous non-empty one
                if current_section["content"].strip():
                    sections.append(current_section)
                current_section = {
                    "title": line.strip("#").strip(),
                    "content": line + "\n"
                }
            else:
                current_section["content"] += line + "\n"
        # Add last section
        if current_section["content"].strip():
            sections.append(current_section)
        return sections

    def get_context_prompt(self) -> str:
        """Generate a context prompt from loaded documents.

        system_prompt.md is ALWAYS included first and prioritized.
        Other documents are included after.

        Returns:
            Formatted string containing all context documents
        """
        if not self.context_documents:
            return ""
        sections = []
        # ALWAYS include system_prompt.md first if it exists
        system_prompt_key = "system_prompt.md"
        if system_prompt_key in self.context_documents:
            sections.append(self.context_documents[system_prompt_key])
            sections.append("\n---\n")
        # Add other context documents
        sections.append("# Additional Context\n")
        sections.append("The following documents provide additional context about the system:\n")
        for filename, content in sorted(self.context_documents.items()):
            # Skip system_prompt.md since we already added it
            if filename == system_prompt_key:
                continue
            # Fixed: heading previously rendered a literal "(unknown)" instead
            # of the document's filename.
            sections.append(f"\n## {filename}\n")
            sections.append(content)
        return "\n".join(sections)

    async def add_memory(
        self,
        session_id: str,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """Add a message to semantic memory (ChromaDB).

        Args:
            session_id: Session identifier
            role: Message role ("user" or "assistant")
            content: Message content
            metadata: Optional metadata (merged into the stored record)
        """
        if not self.memory_collection or not self.embedding_model:
            return
        try:
            # Unique ID: session + role + timestamp
            timestamp = datetime.utcnow().isoformat()
            doc_id = f"{session_id}_{role}_{timestamp}"
            # Generate embedding
            embedding = self.embedding_model.encode(content).tolist()
            # Prepare metadata (caller-supplied keys win over the defaults)
            meta = {
                "session_id": session_id,
                "role": role,
                "timestamp": timestamp,
                "type": "conversation",
                **(metadata or {})
            }
            # Add to ChromaDB
            self.memory_collection.add(
                ids=[doc_id],
                embeddings=[embedding],
                documents=[content],
                metadatas=[meta]
            )
        except Exception as e:
            # Best-effort: memory writes must not break the conversation flow.
            print(f"Error adding to ChromaDB memory: {e}")

    async def search_memory(
        self,
        session_id: str,
        query: str,
        limit: int = 5,
        include_context: bool = True
    ) -> List[Dict[str, Any]]:
        """Search memory using semantic similarity.

        Args:
            session_id: Session identifier (filters to this session + context docs)
            query: Search query
            limit: Maximum results
            include_context: Whether to include context documents in search

        Returns:
            List of dicts with "content", "metadata", and "distance" keys
            (distance is ChromaDB's similarity distance; lower is closer).
        """
        if not self.memory_collection or not self.embedding_model:
            return []
        try:
            # Generate query embedding
            query_embedding = self.embedding_model.encode(query).tolist()
            # Build where filter (always non-empty, so it can be passed directly)
            if include_context:
                # Search both session messages and context docs
                where_filter = {
                    "$or": [
                        {"session_id": session_id},
                        {"type": "context"}
                    ]
                }
            else:
                where_filter = {"session_id": session_id}
            # Query ChromaDB
            results = self.memory_collection.query(
                query_embeddings=[query_embedding],
                n_results=limit,
                where=where_filter
            )
            # Format results (query returns one result-list per query embedding)
            memories = []
            if results and results["documents"]:
                for i, doc in enumerate(results["documents"][0]):
                    memories.append({
                        "content": doc,
                        "metadata": results["metadatas"][0][i],
                        "distance": results["distances"][0][i] if "distances" in results else None
                    })
            return memories
        except Exception as e:
            print(f"Error searching ChromaDB memory: {e}")
            return []

    async def get_memory_graph(
        self,
        session_id: str,
        max_depth: int = 2
    ) -> Dict[str, Any]:
        """Get a graph of related memories using clustering.

        This creates a simple memory graph by finding clusters of related concepts.

        Args:
            session_id: Session identifier
            max_depth: Maximum depth for related memory traversal (currently unused)

        Returns:
            Dict with "nodes" and "edges" lists (edges are not yet computed),
            or {} on error.
        """
        # Simple implementation: get all memories for session and cluster by similarity
        if not self.memory_collection:
            return {}
        try:
            # Get all memories for this session
            results = self.memory_collection.get(
                where={"session_id": session_id},
                include=["embeddings", "documents", "metadatas"]
            )
            if not results or not results["documents"]:
                return {"nodes": [], "edges": []}
            # Build simple graph structure
            nodes = []
            edges = []
            for i, doc in enumerate(results["documents"]):
                nodes.append({
                    "id": results["ids"][i],
                    "content": doc,
                    "metadata": results["metadatas"][i]
                })
            # TODO: Compute edges based on embedding similarity
            # For now, return just nodes
            return {"nodes": nodes, "edges": edges}
        except Exception as e:
            print(f"Error building memory graph: {e}")
            return {}

    def get_checkpointer(self) -> Optional[AsyncSqliteSaver]:
        """Get the LangGraph checkpointer for conversation state.

        Returns:
            AsyncSqliteSaver instance for LangGraph persistence (None before
            initialize() or after close())
        """
        return self.checkpointer

    async def close(self) -> None:
        """Close the memory manager and cleanup resources."""
        if self.checkpointer_context:
            # Exit the manually-entered async context manager from initialize()
            await self.checkpointer_context.__aexit__(None, None, None)
            self.checkpointer = None
            self.checkpointer_context = None

View File

@@ -0,0 +1,57 @@
from typing import List
from gateway.user_session import UserSession
def build_system_prompt(context: str, active_channels: List[str]) -> str:
    """Assemble the agent's system prompt.

    The static portion comes from the loaded context documents (which already
    include system_prompt.md); this function appends only the dynamic,
    per-session information.

    Args:
        context: Concatenated markdown context documents
        active_channels: Channel IDs currently active for this session

    Returns:
        The complete system prompt string
    """
    if active_channels:
        channels_str = ", ".join(active_channels)
    else:
        channels_str = "none"
    return f"""{context}
## Current Session Information
**Active Channels**: {channels_str}
Your responses will be sent to all active channels. Your responses are streamed back in real-time.
If the user sends a new message while you're responding, your current response will be interrupted
and you'll be re-invoked with the updated context.
"""
def build_user_prompt_with_history(session: UserSession, current_message: str) -> str:
    """Render recent conversation history plus the new message as one prompt.

    Args:
        session: User session holding the conversation history
        current_message: The message the user just sent

    Returns:
        History and current message formatted as "Role: text" paragraphs
    """
    # Any role other than "user" is labelled "Assistant", as before.
    labels = {"user": "User"}
    segments = [
        "{}: {}".format(labels.get(msg.role, "Assistant"), msg.content)
        for msg in session.get_history(limit=10)
    ]
    segments.append(f"User: {current_message}")
    return "\n\n".join(segments)

View File

@@ -0,0 +1,93 @@
import asyncio
from typing import Dict, Optional
from agent.memory import MemoryManager
class SessionManager:
    """Manages agent sessions and their associated memory.

    Coordinates between the gateway's UserSession (for conversation state)
    and the agent's MemoryManager (for persistent local memory: LangGraph
    checkpoints plus ChromaDB semantic search).
    """

    def __init__(self, memory_manager: MemoryManager):
        """Initialize session manager.

        Args:
            memory_manager: MemoryManager instance for persistent storage
        """
        self.memory = memory_manager
        # One asyncio.Lock per session_id, created lazily.
        self._locks: Dict[str, asyncio.Lock] = {}

    async def get_session_lock(self, session_id: str) -> asyncio.Lock:
        """Get or create a lock for a session.

        Prevents concurrent execution for the same session.

        Args:
            session_id: Session identifier

        Returns:
            asyncio.Lock for this session
        """
        # No awaits between check and insert, so this is race-free on the
        # single-threaded event loop.
        if session_id not in self._locks:
            self._locks[session_id] = asyncio.Lock()
        return self._locks[session_id]

    async def save_message(
        self,
        session_id: str,
        role: str,
        content: str,
        metadata: Optional[Dict] = None
    ) -> None:
        """Save a message to persistent memory.

        Args:
            session_id: Session identifier
            role: Message role ("user" or "assistant")
            content: Message content
            metadata: Optional metadata
        """
        await self.memory.add_memory(session_id, role, content, metadata)

    async def get_relevant_context(
        self,
        session_id: str,
        query: str,
        limit: int = 5
    ) -> str:
        """Get relevant historical context for a query.

        Uses semantic search over past conversations.

        Args:
            session_id: Session identifier
            query: Search query
            limit: Maximum results

        Returns:
            Formatted string of relevant past messages ("" when nothing matches)
        """
        results = await self.memory.search_memory(session_id, query, limit)
        if not results:
            return ""
        context_parts = ["## Relevant Past Context\n"]
        for i, result in enumerate(results, 1):
            # MemoryManager.search_memory returns items shaped
            # {"content", "metadata", "distance"}; the role lives inside
            # metadata. (Previously this read result["role"]/result["score"],
            # keys that never exist, raising KeyError on every real result.)
            metadata = result.get("metadata") or {}
            role = str(metadata.get("role", "unknown")).capitalize()
            content = result["content"]
            # ChromaDB reports a distance (lower = closer); surface it as the
            # relevance figure, defaulting to 0.0 when unavailable.
            distance = result.get("distance")
            score = distance if distance is not None else 0.0
            context_parts.append(f"{i}. [{role}, relevance: {score:.2f}] {content}\n")
        return "\n".join(context_parts)

    async def cleanup_session(self, session_id: str) -> None:
        """Cleanup session resources.

        Args:
            session_id: Session to cleanup
        """
        # pop() tolerates a double cleanup for the same session.
        self._locks.pop(session_id, None)

662
backend/src/agent/tools.py Normal file
View File

@@ -0,0 +1,662 @@
from typing import Dict, Any, List, Optional
import io
import base64
import sys
import uuid
import logging
from pathlib import Path
from contextlib import redirect_stdout, redirect_stderr
from langchain_core.tools import tool
logger = logging.getLogger(__name__)
# Global registry instances (set by main.py at startup via the setters below).
# Tools read these lazily at call time, so module import order doesn't matter.
_registry = None
_datasource_registry = None

def set_registry(registry):
    """Set the global SyncRegistry instance for tools to use."""
    global _registry
    _registry = registry

def set_datasource_registry(datasource_registry):
    """Set the global DataSourceRegistry instance for tools to use."""
    global _datasource_registry
    _datasource_registry = datasource_registry
@tool
def list_sync_stores() -> List[str]:
    """List all available synchronization stores.

    Returns:
        List of store names that can be read/written
    """
    # Return an empty list (not an error) when the registry isn't wired up yet.
    if not _registry:
        return []
    return [name for name in _registry.entries]
@tool
def read_sync_state(store_name: str) -> Dict[str, Any]:
    """Read the current state of a synchronization store.

    Args:
        store_name: Name of the store to read (e.g., "TraderState", "StrategyState")

    Returns:
        Dictionary containing the current state of the store

    Raises:
        ValueError: If store_name doesn't exist
    """
    if not _registry:
        raise ValueError("SyncRegistry not initialized")
    # Guard clause: unknown store names get a helpful error listing what exists.
    if (entry := _registry.entries.get(store_name)) is None:
        available = list(_registry.entries.keys())
        raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
    return entry.model.model_dump(mode="json")
@tool
async def write_sync_state(store_name: str, updates: Dict[str, Any]) -> Dict[str, str]:
    """Update the state of a synchronization store.

    This will apply the updates to the store and trigger synchronization
    with all connected clients.

    Args:
        store_name: Name of the store to update
        updates: Dictionary of field updates (field_name: new_value)

    Returns:
        Dictionary with status and updated fields
        (NOTE(review): "updated_fields" is a list, so the Dict[str, str]
        annotation is loose — kept unchanged for interface compatibility)

    Raises:
        ValueError: If store_name doesn't exist or updates are invalid
    """
    if not _registry:
        raise ValueError("SyncRegistry not initialized")
    entry = _registry.entries.get(store_name)
    if not entry:
        available = list(_registry.entries.keys())
        raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
    try:
        # Get current state
        current_state = entry.model.model_dump(mode="json")
        # Apply updates (shallow merge: top-level keys in `updates` win)
        new_state = {**current_state, **updates}
        # Update the model
        _registry._update_model(entry.model, new_state)
        # Trigger sync to all connected clients
        await _registry.push_all()
        return {
            "status": "success",
            "store": store_name,
            "updated_fields": list(updates.keys())
        }
    except Exception as e:
        # Fixed: chain the original exception so the real cause isn't lost
        # from tracebacks.
        raise ValueError(f"Failed to update store '{store_name}': {str(e)}") from e
@tool
def get_store_schema(store_name: str) -> Dict[str, Any]:
    """Get the schema/structure of a synchronization store.

    This shows what fields are available and their types.

    Args:
        store_name: Name of the store

    Returns:
        Dictionary describing the store's schema

    Raises:
        ValueError: If store_name doesn't exist
    """
    if not _registry:
        raise ValueError("SyncRegistry not initialized")
    # Guard clause: fail early with the list of known stores.
    if (entry := _registry.entries.get(store_name)) is None:
        available = list(_registry.entries.keys())
        raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
    # Pydantic's JSON schema describes the store's fields and types.
    return {
        "store_name": store_name,
        "schema": entry.model.model_json_schema()
    }
# DataSource tools
@tool
def list_data_sources() -> List[str]:
    """List all available data sources.

    Returns:
        List of data source names that can be queried for market data
    """
    # Empty list (rather than an error) when the registry isn't wired up yet.
    return _datasource_registry.list_sources() if _datasource_registry else []
@tool
async def search_symbols(
    query: str,
    type: Optional[str] = None,
    exchange: Optional[str] = None,
    limit: int = 30,
) -> Dict[str, Any]:
    """Search for trading symbols across all data sources.

    Automatically searches all available data sources and returns aggregated results.
    Use this to find symbols before calling get_symbol_info or get_historical_data.

    Args:
        query: Search query (e.g., "BTC", "AAPL", "EUR")
        type: Optional filter by instrument type (e.g., "crypto", "stock", "forex")
        exchange: Optional filter by exchange (e.g., "binance", "nasdaq")
        limit: Maximum number of results per source (default: 30)

    Returns:
        Dictionary mapping source names to lists of matching symbols.
        Each symbol includes: symbol, full_name, description, exchange, type.
        Use the source name and symbol from results with get_symbol_info or get_historical_data.

    Example response:
        {
            "demo": [
                {
                    "symbol": "BTC/USDT",
                    "full_name": "Bitcoin / Tether USD",
                    "description": "Bitcoin perpetual futures",
                    "exchange": "demo",
                    "type": "crypto"
                }
            ]
        }
    """
    if not _datasource_registry:
        raise ValueError("DataSourceRegistry not initialized")
    # Always search all sources. (`type` shadows the builtin, but the name is
    # part of the tool's public schema, so it is kept.)
    results = await _datasource_registry.search_all(query, type, exchange, limit)
    # Convert pydantic result models to plain dicts for JSON serialization.
    return {name: [r.model_dump() for r in matches] for name, matches in results.items()}
@tool
async def get_symbol_info(source_name: str, symbol: str) -> Dict[str, Any]:
    """Get complete metadata for a trading symbol.

    This retrieves full information about a symbol including:
    - Description and type
    - Supported time resolutions
    - Available data columns (OHLCV, volume, funding rates, etc.)
    - Trading session information
    - Price scale and precision

    Args:
        source_name: Name of the data source (use list_data_sources to see available)
        symbol: Symbol identifier (e.g., "BTC/USDT", "AAPL", "EUR/USD")

    Returns:
        Dictionary containing complete symbol metadata including column schema

    Raises:
        ValueError: If source_name or symbol is not found
    """
    if not _datasource_registry:
        raise ValueError("DataSourceRegistry not initialized")
    # Resolution errors (unknown source/symbol) propagate from the registry.
    resolved = await _datasource_registry.resolve_symbol(source_name, symbol)
    return resolved.model_dump()
@tool
async def get_historical_data(
    source_name: str,
    symbol: str,
    resolution: str,
    from_time: int,
    to_time: int,
    countback: Optional[int] = None,
) -> Dict[str, Any]:
    """Get historical bar/candle data for a symbol.

    Retrieves time-series data between the specified timestamps. The data
    includes all columns defined for the symbol (OHLCV + any custom columns).

    Args:
        source_name: Name of the data source
        symbol: Symbol identifier
        resolution: Time resolution (e.g., "1" = 1min, "5" = 5min, "60" = 1hour, "1D" = 1day)
        from_time: Start time as Unix timestamp in seconds
        to_time: End time as Unix timestamp in seconds
        countback: Optional limit on number of bars to return

    Returns:
        Dictionary containing:
        - symbol: The requested symbol
        - resolution: The time resolution
        - bars: List of bar data with 'time' and 'data' fields
        - columns: Schema describing available data columns
        - nextTime: If present, indicates more data is available for pagination

    Raises:
        ValueError: If source, symbol, or resolution is invalid

    Example:
        # Get 1-hour BTC data for the last 24 hours
        import time
        to_time = int(time.time())
        from_time = to_time - 86400  # 24 hours ago
        data = get_historical_data("demo", "BTC/USDT", "60", from_time, to_time)
    """
    if not _datasource_registry:
        raise ValueError("DataSourceRegistry not initialized")
    # Guard clause: unknown sources get an error listing the available ones.
    source = _datasource_registry.get(source_name)
    if source is None:
        available = _datasource_registry.list_sources()
        raise ValueError(f"Data source '{source_name}' not found. Available sources: {available}")
    history = await source.get_bars(symbol, resolution, from_time, to_time, countback)
    return history.model_dump()
async def _get_chart_data_impl(countback: Optional[int] = None) -> tuple:
    """Internal implementation for getting chart data.

    This is a helper function that can be called by both get_chart_data tool
    and analyze_chart_data tool. It reads the user's current chart context
    out of the synced ChartStore, resolves the "EXCHANGE:SYMBOL" ticker to a
    registered data source, and fetches bars for the visible time range.

    Args:
        countback: Optional cap on the number of bars to fetch.

    Returns:
        Tuple of (HistoryResult, chart_context dict, source_name)

    Raises:
        ValueError: If a registry is missing, no chart is loaded, the symbol
            is malformed, the source is unknown, or the visible time range
            has not been set by the frontend yet.
    """
    if not _registry:
        raise ValueError("SyncRegistry not initialized - cannot read ChartStore")
    if not _datasource_registry:
        raise ValueError("DataSourceRegistry not initialized - cannot query data")
    # Read current chart state
    chart_store = _registry.entries.get("ChartStore")
    if not chart_store:
        raise ValueError("ChartStore not found in registry")
    chart_state = chart_store.model.model_dump(mode="json")
    chart_data = chart_state.get("chart_state", {})
    symbol = chart_data.get("symbol", "")
    # "15" minutes is the frontend's default interval when none is synced.
    interval = chart_data.get("interval", "15")
    start_time = chart_data.get("start_time")
    end_time = chart_data.get("end_time")
    if not symbol:
        raise ValueError("No symbol set in ChartStore - user may not have loaded a chart yet")
    # Parse the symbol to extract exchange/source and symbol name
    # Format is "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD")
    if ":" not in symbol:
        raise ValueError(
            f"Invalid symbol format: '{symbol}'. Expected format is 'EXCHANGE:SYMBOL' "
            f"(e.g., 'BINANCE:BTC/USDT' or 'DEMO:BTC/USD')"
        )
    # maxsplit=1 keeps symbols that themselves contain ':' intact.
    exchange_prefix, symbol_name = symbol.split(":", 1)
    # Source names are registered lowercase; tickers carry the exchange uppercase.
    source_name = exchange_prefix.lower()
    # Get the data source
    source = _datasource_registry.get(source_name)
    if not source:
        available = _datasource_registry.list_sources()
        raise ValueError(
            f"Data source '{source_name}' not found. Available sources: {available}. "
            f"Make sure the exchange in the symbol '{symbol}' matches an available source."
        )
    # Determine time range - REQUIRE it to be set, no defaults
    if start_time is None or end_time is None:
        raise ValueError(
            f"Chart time range not set in ChartStore. start_time={start_time}, end_time={end_time}. "
            f"The user needs to load the chart first, or the frontend may not be sending the visible range. "
            f"Wait for the chart to fully load before analyzing data."
        )
    from_time = int(start_time)
    # NOTE(review): end_time is rebound to its int-coerced value here, so the
    # chart_context below carries the raw start_time but the coerced end_time —
    # possibly unintended asymmetry; confirm what the callers expect.
    end_time = int(end_time)
    logger.info(
        f"Using ChartStore time range: from_time={from_time}, end_time={end_time}, "
        f"countback={countback}"
    )
    logger.info(
        f"Querying data source '{source_name}' for symbol '{symbol_name}', "
        f"resolution '{interval}'"
    )
    # Query the data source
    result = await source.get_bars(
        symbol=symbol_name,
        resolution=interval,
        from_time=from_time,
        to_time=end_time,
        countback=countback
    )
    logger.info(
        f"Received {len(result.bars)} bars from data source. "
        f"First bar time: {result.bars[0].time if result.bars else 'N/A'}, "
        f"Last bar time: {result.bars[-1].time if result.bars else 'N/A'}"
    )
    # Build chart context to return along with result
    chart_context = {
        "symbol": symbol,
        "interval": interval,
        "start_time": start_time,
        "end_time": end_time
    }
    return result, chart_context, source_name
@tool
async def get_chart_data(countback: Optional[int] = None) -> Dict[str, Any]:
    """Get the candle/bar data for what the user is currently viewing on their chart.

    Convenience wrapper around the shared chart-data helper. It reads the
    ChartStore to learn which symbol/interval/time-range is on screen, routes
    the request to the data source named by the ticker's exchange prefix, and
    returns the bars for the visible window. Prefer this tool when helping
    the user analyze what they are looking at, since it automatically uses
    their current chart context.

    Args:
        countback: Optional limit on number of bars to return. If not specified,
            returns all bars in the visible time range.

    Returns:
        Dictionary containing:
        - chart_context: Current chart state (symbol, interval, time range)
        - symbol: The trading pair being viewed
        - resolution: The chart interval
        - bars: List of bar data with 'time' and 'data' fields
        - columns: Schema describing available data columns
        - source: Which data source was used

    Raises:
        ValueError: If ChartStore or DataSourceRegistry is not initialized,
            or if the symbol format is invalid

    Example:
        # User is viewing BINANCE:BTC/USDT on 15min chart
        data = get_chart_data()
        # Returns BTC/USDT data at 15min resolution for the visible range
    """
    history, context, used_source = await _get_chart_data_impl(countback)

    # Enrich the serialized history with the chart context it was queried for.
    payload = history.model_dump()
    payload["chart_context"] = context
    payload["source"] = used_source
    return payload
@tool
async def analyze_chart_data(python_script: str, countback: Optional[int] = None) -> Dict[str, Any]:
    """Analyze the current chart data using a Python script with pandas and matplotlib.

    This tool:
    1. Gets the current chart data (same as get_chart_data)
    2. Converts it to a pandas DataFrame with columns: time, open, high, low, close, volume
    3. Executes your Python script with access to the DataFrame as 'df'
    4. Saves any matplotlib plots to disk and returns URLs to access them
    5. Returns any final DataFrame result and plot URLs

    The script has access to:
    - `df`: pandas DataFrame with OHLCV data indexed by datetime
    - `pandas` (as `pd`): For data manipulation
    - `numpy` (as `np`): For numerical operations
    - `matplotlib.pyplot` (as `plt`): For plotting (use plt.figure() for each plot)

    All matplotlib figures are automatically saved to disk and accessible via URLs.
    The last expression in the script (if it's a DataFrame) is returned as the result.

    Args:
        python_script: Python code to execute. The DataFrame is available as 'df'.
            Can use pandas, numpy, matplotlib. Return a DataFrame to include it in results.
        countback: Optional limit on number of bars to analyze

    Returns:
        Dictionary containing:
        - chart_context: Current chart state (symbol, interval, time range)
        - source: Data source used
        - script_output: Any printed output from the script
        - result_dataframe: If script returns a DataFrame, it's included here as dict
        - plot_urls: List of URLs to saved plot images (one per plt.figure())
        - error: Error message if script execution failed

    Example scripts:
        # Calculate 20-period SMA and plot
        ```python
        df['SMA20'] = df['close'].rolling(20).mean()
        plt.figure(figsize=(12, 6))
        plt.plot(df.index, df['close'], label='Close')
        plt.plot(df.index, df['SMA20'], label='SMA20')
        plt.legend()
        plt.title('Price with SMA')
        df[['close', 'SMA20']].tail(10)  # Return last 10 rows
        ```

        # Calculate RSI
        ```python
        delta = df['close'].diff()
        gain = (delta.where(delta > 0, 0)).rolling(14).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
        rs = gain / loss
        df['RSI'] = 100 - (100 / (1 + rs))
        df[['close', 'RSI']].tail(20)
        ```
    """
    if not _registry:
        raise ValueError("SyncRegistry not initialized - cannot read ChartStore")
    if not _datasource_registry:
        raise ValueError("DataSourceRegistry not initialized - cannot query data")
    try:
        # Import pandas and numpy here to allow lazy loading
        import pandas as pd
        import numpy as np
        import matplotlib
        matplotlib.use('Agg')  # Non-interactive backend - no display needed
        import matplotlib.pyplot as plt
    except ImportError as e:
        # Chain the cause so the original import failure stays visible.
        raise ValueError(
            f"Required library not installed: {e}. "
            "Please install pandas, numpy, and matplotlib: pip install pandas numpy matplotlib"
        ) from e

    # Get chart data using the internal helper function
    result, chart_context, source_name = await _get_chart_data_impl(countback)

    # Build the same response format as get_chart_data
    chart_data = result.model_dump()
    chart_data["chart_context"] = chart_context
    chart_data["source"] = source_name

    # Convert bars to DataFrame
    bars = chart_data.get('bars', [])
    if not bars:
        return {
            "chart_context": chart_data.get('chart_context', {}),
            "source": chart_data.get('source', ''),
            "error": "No data available for the current chart"
        }

    # Build DataFrame: one row per bar, flattening the per-bar data dict
    rows = []
    for bar in bars:
        row = {
            'time': pd.to_datetime(bar['time'], unit='s'),
            **bar['data']  # Includes open, high, low, close, volume, etc.
        }
        rows.append(row)
    df = pd.DataFrame(rows)
    df.set_index('time', inplace=True)

    # Convert price columns to float for clean numeric operations
    # (sources may deliver Decimal or string values)
    price_columns = ['open', 'high', 'low', 'close', 'volume']
    for col in price_columns:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
    logger.info(
        f"Created DataFrame with {len(df)} rows, columns: {df.columns.tolist()}, "
        f"time range: {df.index.min()} to {df.index.max()}, "
        f"dtypes: {df.dtypes.to_dict()}"
    )

    # Prepare execution environment exposed to the user script
    script_globals = {
        'df': df,
        'pd': pd,
        'np': np,
        'plt': plt,
    }

    # Capture stdout/stderr emitted by the script
    stdout_capture = io.StringIO()
    stderr_capture = io.StringIO()
    result_df = None
    error_msg = None
    plot_urls = []

    # Determine uploads directory (relative to this file); create the full
    # path if it does not exist yet (parents=True for fresh deployments).
    uploads_dir = Path(__file__).parent.parent.parent / "uploads"
    uploads_dir.mkdir(parents=True, exist_ok=True)

    try:
        with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
            # SECURITY: exec() runs model-supplied code in-process with no
            # sandboxing. Acceptable only for trusted, local deployments.
            exec(python_script, script_globals)

            # If the script's last line is a bare expression that evaluates to
            # a DataFrame, surface it as the result (mimics notebook behavior).
            # NOTE: the expression is evaluated a second time here, so a
            # side-effecting final line will run twice.
            script_lines = python_script.strip().split('\n')
            if script_lines:
                last_line = script_lines[-1].strip()
                # Only evaluate if it doesn't look like a statement
                if last_line and not any(last_line.startswith(kw) for kw in ['if', 'for', 'while', 'def', 'class', 'import', 'from', 'with', 'try', 'return']):
                    try:
                        last_result = eval(last_line, script_globals)
                        if isinstance(last_result, pd.DataFrame):
                            result_df = last_result
                    except Exception:
                        # Not an evaluable expression (or it raised) - that's
                        # fine, the script itself already ran successfully.
                        # (Was a bare `except:`, which also swallowed
                        # KeyboardInterrupt/SystemExit.)
                        pass

            # Save all matplotlib figures to disk
            for fig_num in plt.get_fignums():
                fig = plt.figure(fig_num)
                # Generate unique filename
                plot_id = str(uuid.uuid4())
                filename = f"plot_{plot_id}.png"
                filepath = uploads_dir / filename
                # Save figure to file
                fig.savefig(filepath, format='png', bbox_inches='tight', dpi=100)
                # Generate URL that can be accessed via the web server.
                # BUG FIX: previously a constant string that ignored `filename`,
                # so every returned URL was identical and pointed nowhere.
                plot_url = f"/uploads/{filename}"
                plot_urls.append(plot_url)
                plt.close(fig)
    except Exception as e:
        error_msg = f"{type(e).__name__}: {str(e)}"
        import traceback
        error_msg += f"\n{traceback.format_exc()}"

    # Build response
    response = {
        "chart_context": chart_data.get('chart_context', {}),
        "source": chart_data.get('source', ''),
        "script_output": stdout_capture.getvalue(),
    }
    if error_msg:
        response["error"] = error_msg
        response["stderr"] = stderr_capture.getvalue()
    if result_df is not None:
        # Convert DataFrame to dict for JSON serialization
        response["result_dataframe"] = {
            "columns": result_df.columns.tolist(),
            "index": result_df.index.astype(str).tolist() if hasattr(result_df.index, 'astype') else result_df.index.tolist(),
            "data": result_df.values.tolist(),
            "shape": result_df.shape,
        }
    if plot_urls:
        response["plot_urls"] = plot_urls
    return response
# Export all tools.
# SYNC_TOOLS operate on the frontend state-sync registry (list/read/write
# stores and inspect their schemas); DATASOURCE_TOOLS query market data
# sources and the user's currently-visible chart.
SYNC_TOOLS = [
    list_sync_stores,
    read_sync_state,
    write_sync_state,
    get_store_schema
]
DATASOURCE_TOOLS = [
    list_data_sources,
    search_symbols,
    get_symbol_info,
    get_historical_data,
    get_chart_data,
    analyze_chart_data
]

View File

@@ -0,0 +1,23 @@
from .base import DataSource
from .schema import (
ColumnInfo,
SymbolInfo,
Bar,
HistoryResult,
DatafeedConfig,
Resolution,
SearchResult,
)
from .registry import DataSourceRegistry
# Public API of the datasource package: the abstract base, the pydantic
# schema types, and the registry that routes requests to concrete sources.
__all__ = [
    "DataSource",
    "ColumnInfo",
    "SymbolInfo",
    "Bar",
    "HistoryResult",
    "DatafeedConfig",
    "Resolution",
    "SearchResult",
    "DataSourceRegistry",
]

View File

@@ -0,0 +1,3 @@
from .ccxt_adapter import CCXTDataSource

# Public API of the adapters package.
__all__ = ["CCXTDataSource"]

View File

@@ -0,0 +1,526 @@
"""
CCXT DataSource adapter for accessing cryptocurrency exchange data.
This adapter provides access to hundreds of cryptocurrency exchanges through
the free CCXT library (not ccxt.pro), supporting both historical data and
polling-based subscriptions.
Numerical Precision:
- Uses Decimal for all monetary values (prices, volumes) to avoid floating-point errors
- CCXT returns numeric values as strings or floats depending on configuration
- All financial values are converted to Decimal to maintain precision
Real-time Updates:
- Uses polling instead of WebSocket (free CCXT doesn't have WebSocket support)
- Default polling interval: 60 seconds (configurable)
- Simulates real-time subscriptions by periodically fetching latest bars
"""
import asyncio
import logging
import time
from datetime import datetime, timezone
from decimal import Decimal
from typing import Callable, Dict, List, Optional, Set, Union
import ccxt.async_support as ccxt
from ..base import DataSource
logger = logging.getLogger(__name__)
from ..schema import (
Bar,
ColumnInfo,
DatafeedConfig,
HistoryResult,
Resolution,
SearchResult,
SymbolInfo,
)
class CCXTDataSource(DataSource):
    """
    DataSource adapter for CCXT cryptocurrency exchanges (free version).

    Provides access to:
    - Multiple cryptocurrency exchanges (Binance, Coinbase, Kraken, etc.)
    - Historical OHLCV data via REST API
    - Polling-based real-time updates (configurable interval)
    - Symbol search and metadata

    Args:
        exchange_id: CCXT exchange identifier (e.g., 'binance', 'coinbase', 'kraken')
        config: Optional exchange-specific configuration (API keys, options)
        sandbox: Whether to use sandbox/testnet mode (default: False)
        poll_interval: Interval in seconds for polling updates (default: 60)
    """

    def __init__(
        self,
        exchange_id: str = "binance",
        config: Optional[Dict] = None,
        sandbox: bool = False,
        poll_interval: int = 60,
    ):
        self.exchange_id = exchange_id
        self._config = config or {}
        self._sandbox = sandbox
        self._poll_interval = poll_interval
        # Initialize exchange (using free async_support, not pro)
        exchange_class = getattr(ccxt, exchange_id)
        self.exchange = exchange_class(self._config)
        if sandbox and hasattr(self.exchange, 'set_sandbox_mode'):
            self.exchange.set_sandbox_mode(True)
        # Cache for markets
        self._markets: Optional[Dict] = None
        self._markets_loaded = False
        # Active subscriptions (polling-based)
        self._subscriptions: Dict[str, asyncio.Task] = {}
        self._subscription_callbacks: Dict[str, Callable] = {}
        self._last_bars: Dict[str, int] = {}  # Track last bar timestamp per subscription

    @staticmethod
    def _to_decimal(value: Union[str, int, float, Decimal, None]) -> Optional[Decimal]:
        """
        Convert a value to Decimal for numerical precision.

        Handles CCXT's mixed output (strings, floats, ints, None).
        Converts floats by converting to string first to avoid precision loss.
        """
        if value is None:
            return None
        if isinstance(value, Decimal):
            return value
        if isinstance(value, str):
            return Decimal(value)
        if isinstance(value, (int, float)):
            # Convert to string first to avoid float precision issues
            return Decimal(str(value))
        return None

    async def _ensure_markets_loaded(self):
        """Ensure markets are loaded from exchange (cached after first call)."""
        if not self._markets_loaded:
            self._markets = await self.exchange.load_markets()
            self._markets_loaded = True

    async def get_config(self) -> DatafeedConfig:
        """Get datafeed configuration"""
        await self._ensure_markets_loaded()
        # Determine supported resolutions based on exchange capabilities
        supported_resolutions = [
            Resolution.M1,
            Resolution.M5,
            Resolution.M15,
            Resolution.M30,
            Resolution.H1,
            Resolution.H4,
            Resolution.D1,
        ]
        # Get unique exchange names (most CCXT exchanges are just one)
        exchanges = [self.exchange_id.upper()]
        return DatafeedConfig(
            name=f"CCXT {self.exchange_id.title()}",
            description=f"Live and historical cryptocurrency data from {self.exchange_id} via CCXT library. "
            f"Supports OHLCV data for {len(self._markets) if self._markets else 'many'} trading pairs.",
            supported_resolutions=supported_resolutions,
            supports_search=True,
            supports_time=True,
            exchanges=exchanges,
            symbols_types=["crypto", "spot", "futures", "swap"],
        )

    async def search_symbols(
        self,
        query: str,
        type: Optional[str] = None,
        exchange: Optional[str] = None,
        limit: int = 30,
    ) -> List[SearchResult]:
        """Search for symbols on the exchange"""
        await self._ensure_markets_loaded()
        query_upper = query.upper()
        results = []
        for symbol, market in self._markets.items():
            # Match query against symbol or base/quote currencies
            if (query_upper in symbol or
                    query_upper in market.get('base', '') or
                    query_upper in market.get('quote', '')):
                # Filter by type if specified
                market_type = market.get('type', 'spot')
                if type and market_type != type:
                    continue
                # Create search result
                base = market.get('base', '')
                quote = market.get('quote', '')
                results.append(
                    SearchResult(
                        symbol=f"{base}/{quote}",  # Clean user-facing format
                        ticker=f"{self.exchange_id.upper()}:{symbol}",  # Ticker with exchange prefix for routing
                        full_name=f"{base}/{quote} ({self.exchange_id.upper()})",
                        description=f"{base}/{quote} {market_type} trading pair on {self.exchange_id}",
                        exchange=self.exchange_id.upper(),
                        type=market_type,
                    )
                )
                if len(results) >= limit:
                    break
        return results

    async def resolve_symbol(self, symbol: str) -> SymbolInfo:
        """Get complete metadata for a symbol"""
        await self._ensure_markets_loaded()
        if symbol not in self._markets:
            raise ValueError(f"Symbol '{symbol}' not found on {self.exchange_id}")
        market = self._markets[symbol]
        base = market.get('base', '')
        quote = market.get('quote', '')
        market_type = market.get('type', 'spot')
        # Determine price scale from market precision
        # CCXT precision can be in different modes:
        # - DECIMAL_PLACES (int): number of decimal places (e.g., 2 = 0.01)
        # - TICK_SIZE (float): actual tick size (e.g., 0.01, 0.00001)
        # We need to convert to pricescale (10^n where n is decimal places)
        price_precision = market.get('precision', {}).get('price', 2)
        if isinstance(price_precision, float):
            # TICK_SIZE mode: precision is the actual tick size (e.g., 0.01, 0.00001)
            # Convert tick size to decimal places
            # For 0.01 -> 2 decimal places, 0.00001 -> 5 decimal places
            tick_str = str(Decimal(str(price_precision)))
            if '.' in tick_str:
                decimal_places = len(tick_str.split('.')[1].rstrip('0'))
            else:
                decimal_places = 0
            pricescale = 10 ** decimal_places
        else:
            # DECIMAL_PLACES or SIGNIFICANT_DIGITS mode: precision is an integer
            # Assume DECIMAL_PLACES mode (most common for price)
            pricescale = 10 ** int(price_precision)
        return SymbolInfo(
            symbol=f"{base}/{quote}",  # Clean user-facing format
            ticker=f"{self.exchange_id.upper()}:{symbol}",  # Ticker with exchange prefix for routing
            name=f"{base}/{quote}",
            description=f"{base}/{quote} {market_type} pair on {self.exchange_id}. "
            f"Minimum order: {market.get('limits', {}).get('amount', {}).get('min', 'N/A')} {base}",
            type=market_type,
            exchange=self.exchange_id.upper(),
            timezone="Etc/UTC",
            session="24x7",
            supported_resolutions=[
                Resolution.M1,
                Resolution.M5,
                Resolution.M15,
                Resolution.M30,
                Resolution.H1,
                Resolution.H4,
                Resolution.D1,
            ],
            has_intraday=True,
            has_daily=True,
            has_weekly_and_monthly=False,
            columns=[
                ColumnInfo(
                    name="open",
                    type="decimal",
                    description=f"Opening price in {quote}",
                    unit=quote,
                ),
                ColumnInfo(
                    name="high",
                    type="decimal",
                    description=f"Highest price in {quote}",
                    unit=quote,
                ),
                ColumnInfo(
                    name="low",
                    type="decimal",
                    description=f"Lowest price in {quote}",
                    unit=quote,
                ),
                ColumnInfo(
                    name="close",
                    type="decimal",
                    description=f"Closing price in {quote}",
                    unit=quote,
                ),
                ColumnInfo(
                    name="volume",
                    type="decimal",
                    description=f"Trading volume in {base}",
                    unit=base,
                ),
            ],
            time_column="time",
            has_ohlcv=True,
            pricescale=pricescale,
            minmov=1,
            base_currency=base,
            quote_currency=quote,
        )

    def _resolution_to_timeframe(self, resolution: str) -> str:
        """Convert our resolution format to CCXT timeframe format"""
        # Map our resolutions to CCXT timeframes
        mapping = {
            "1": "1m",
            "5": "5m",
            "15": "15m",
            "30": "30m",
            "60": "1h",
            "120": "2h",
            "240": "4h",
            "360": "6h",
            "720": "12h",
            "1D": "1d",
            "1W": "1w",
            "1M": "1M",
        }
        return mapping.get(resolution, "1m")

    def _timeframe_to_milliseconds(self, timeframe: str) -> int:
        """Convert CCXT timeframe to milliseconds"""
        unit = timeframe[-1]
        amount = int(timeframe[:-1]) if len(timeframe) > 1 else 1
        units = {
            's': 1000,
            'm': 60 * 1000,
            'h': 60 * 60 * 1000,
            'd': 24 * 60 * 60 * 1000,
            'w': 7 * 24 * 60 * 60 * 1000,
            'M': 30 * 24 * 60 * 60 * 1000,  # Approximate
        }
        return amount * units.get(unit, 60000)

    async def get_bars(
        self,
        symbol: str,
        resolution: str,
        from_time: int,
        to_time: int,
        countback: Optional[int] = None,
    ) -> HistoryResult:
        """Get historical bars from the exchange.

        Paginates through fetch_ohlcv until the requested range is covered
        (or countback is satisfied), converts all monetary values to Decimal,
        and returns times in Unix seconds.
        """
        logger.info(
            f"CCXTDataSource({self.exchange_id}).get_bars: symbol={symbol}, resolution={resolution}, "
            f"from_time={from_time}, to_time={to_time}, countback={countback}"
        )
        await self._ensure_markets_loaded()
        if symbol not in self._markets:
            raise ValueError(f"Symbol '{symbol}' not found on {self.exchange_id}")
        timeframe = self._resolution_to_timeframe(resolution)
        # CCXT uses milliseconds for timestamps
        since = from_time * 1000
        until = to_time * 1000
        # Fetch OHLCV data
        limit = countback if countback else 1000
        try:
            # Fetch in batches if needed
            all_ohlcv = []
            current_since = since
            while current_since < until:
                ohlcv = await self.exchange.fetch_ohlcv(
                    symbol,
                    timeframe=timeframe,
                    since=current_since,
                    limit=limit,
                )
                if not ohlcv:
                    break
                all_ohlcv.extend(ohlcv)
                # Update since for next batch
                last_timestamp = ohlcv[-1][0]
                if last_timestamp <= current_since:
                    break  # No progress, avoid infinite loop
                current_since = last_timestamp + 1
                # Stop if we have enough bars
                if countback and len(all_ohlcv) >= countback:
                    all_ohlcv = all_ohlcv[:countback]
                    break
            # Convert to our Bar format with Decimal precision
            bars = []
            for candle in all_ohlcv:
                timestamp_ms, open_price, high, low, close, volume = candle
                timestamp = timestamp_ms // 1000  # Convert to seconds
                # Only include bars within requested range
                # (to_time is exclusive here; the demo source treats it inclusive)
                if timestamp < from_time or timestamp >= to_time:
                    continue
                bars.append(
                    Bar(
                        time=timestamp,
                        data={
                            "open": self._to_decimal(open_price),
                            "high": self._to_decimal(high),
                            "low": self._to_decimal(low),
                            "close": self._to_decimal(close),
                            "volume": self._to_decimal(volume),
                        },
                    )
                )
            # Get symbol info for column metadata
            symbol_info = await self.resolve_symbol(symbol)
            logger.info(
                f"CCXTDataSource({self.exchange_id}).get_bars: Returning {len(bars)} bars. "
                f"First: {bars[0].time if bars else 'N/A'}, Last: {bars[-1].time if bars else 'N/A'}"
            )
            # Determine if more data is available
            next_time = None
            if bars and countback and len(bars) >= countback:
                next_time = bars[-1].time + (bars[-1].time - bars[-2].time if len(bars) > 1 else 60)
            return HistoryResult(
                symbol=symbol,
                resolution=resolution,
                bars=bars,
                columns=symbol_info.columns,
                nextTime=next_time,
            )
        except Exception as e:
            # Chain the cause so the underlying CCXT error stays debuggable.
            raise ValueError(f"Failed to fetch bars for {symbol}: {str(e)}") from e

    async def subscribe_bars(
        self,
        symbol: str,
        resolution: str,
        on_tick: Callable[[dict], None],
    ) -> str:
        """
        Subscribe to bar updates via polling.

        Note: Uses polling instead of WebSocket since we're using free CCXT.
        Polls at the configured interval (default: 60 seconds).
        """
        await self._ensure_markets_loaded()
        if symbol not in self._markets:
            raise ValueError(f"Symbol '{symbol}' not found on {self.exchange_id}")
        subscription_id = f"{symbol}:{resolution}:{time.time()}"
        # Store callback
        self._subscription_callbacks[subscription_id] = on_tick
        # Start polling task
        timeframe = self._resolution_to_timeframe(resolution)
        task = asyncio.create_task(
            self._poll_ohlcv(symbol, timeframe, subscription_id)
        )
        self._subscriptions[subscription_id] = task
        return subscription_id

    async def _poll_ohlcv(self, symbol: str, timeframe: str, subscription_id: str):
        """
        Poll for OHLCV updates at regular intervals.

        This simulates real-time updates by fetching the latest bars periodically.
        Only sends updates when new bars are detected.
        """
        try:
            while subscription_id in self._subscription_callbacks:
                try:
                    # Fetch latest bars
                    ohlcv = await self.exchange.fetch_ohlcv(
                        symbol,
                        timeframe=timeframe,
                        limit=2,  # Get last 2 bars to detect new ones
                    )
                    if ohlcv and len(ohlcv) > 0:
                        # Get the latest candle
                        latest = ohlcv[-1]
                        timestamp_ms, open_price, high, low, close, volume = latest
                        timestamp = timestamp_ms // 1000
                        # Only send update if this is a new bar
                        last_timestamp = self._last_bars.get(subscription_id, 0)
                        if timestamp > last_timestamp:
                            self._last_bars[subscription_id] = timestamp
                            # Convert to our format with Decimal precision
                            tick_data = {
                                "time": timestamp,
                                "open": self._to_decimal(open_price),
                                "high": self._to_decimal(high),
                                "low": self._to_decimal(low),
                                "close": self._to_decimal(close),
                                "volume": self._to_decimal(volume),
                            }
                            # Call the callback
                            callback = self._subscription_callbacks.get(subscription_id)
                            if callback:
                                callback(tick_data)
                except Exception as e:
                    # Was print(); use the module logger like the rest of the file
                    # so polling failures reach the configured log handlers.
                    logger.warning(f"Error polling OHLCV for {symbol}: {e}")
                # Wait for next poll interval
                await asyncio.sleep(self._poll_interval)
        except asyncio.CancelledError:
            pass

    async def unsubscribe_bars(self, subscription_id: str) -> None:
        """Unsubscribe from polling updates"""
        # Remove callback and tracking
        self._subscription_callbacks.pop(subscription_id, None)
        self._last_bars.pop(subscription_id, None)
        # Cancel polling task
        task = self._subscriptions.pop(subscription_id, None)
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    async def close(self):
        """Close exchange connection and cleanup"""
        # Cancel all subscriptions
        for subscription_id in list(self._subscriptions.keys()):
            await self.unsubscribe_bars(subscription_id)
        # Close exchange
        if hasattr(self.exchange, 'close'):
            await self.exchange.close()

View File

@@ -0,0 +1,353 @@
"""
Demo data source with synthetic data.
Generates realistic-looking OHLCV data plus additional columns for testing.
"""
import asyncio
import logging
import math
import random
import time
from typing import Callable, Dict, List, Optional
from ..base import DataSource
logger = logging.getLogger(__name__)
from ..schema import (
Bar,
ColumnInfo,
DatafeedConfig,
HistoryResult,
Resolution,
SearchResult,
SymbolInfo,
)
class DemoDataSource(DataSource):
"""
Demo data source that generates synthetic time-series data.
Provides:
- Standard OHLCV columns
- Additional demo columns (RSI, sentiment, volume_profile)
- Real-time updates via polling simulation
"""
    def __init__(self):
        """Set up the static demo symbol catalog and subscription tracking."""
        # Active tick-generator tasks keyed by subscription id
        self._subscriptions: Dict[str, asyncio.Task] = {}
        # Static catalog; base_price seeds the random walk and volatility
        # controls its per-bar standard deviation.
        self._symbols = {
            "DEMO:BTC/USD": {
                "name": "Bitcoin",
                "type": "crypto",
                "base_price": 50000.0,
                "volatility": 0.02,
            },
            "DEMO:ETH/USD": {
                "name": "Ethereum",
                "type": "crypto",
                "base_price": 3000.0,
                "volatility": 0.03,
            },
            "DEMO:SOL/USD": {
                "name": "Solana",
                "type": "crypto",
                "base_price": 100.0,
                "volatility": 0.04,
            },
        }
    async def get_config(self) -> DatafeedConfig:
        """Describe this datafeed's capabilities (resolutions, search, symbol types)."""
        return DatafeedConfig(
            name="Demo DataSource",
            description="Synthetic data generator for testing. Provides OHLCV plus additional indicator columns.",
            supported_resolutions=[
                Resolution.M1,
                Resolution.M5,
                Resolution.M15,
                Resolution.H1,
                Resolution.D1,
            ],
            supports_search=True,
            supports_time=True,
            exchanges=["DEMO"],
            symbols_types=["crypto"],
        )
async def search_symbols(
self,
query: str,
type: Optional[str] = None,
exchange: Optional[str] = None,
limit: int = 30,
) -> List[SearchResult]:
query_lower = query.lower()
results = []
for symbol, info in self._symbols.items():
if query_lower in symbol.lower() or query_lower in info["name"].lower():
if type and info["type"] != type:
continue
results.append(
SearchResult(
symbol=info['name'], # Clean user-facing format (e.g., "Bitcoin")
ticker=symbol, # Keep DEMO:BTC/USD format for routing
full_name=f"{info['name']} (DEMO)",
description=f"Demo {info['name']} pair",
exchange="DEMO",
type=info["type"],
)
)
return results[:limit]
    async def resolve_symbol(self, symbol: str) -> SymbolInfo:
        """Return full metadata for a demo symbol, including its column schema.

        Args:
            symbol: Ticker in "DEMO:BASE/QUOTE" form (e.g. "DEMO:BTC/USD").

        Raises:
            ValueError: If the symbol is not in the demo catalog.
        """
        if symbol not in self._symbols:
            raise ValueError(f"Symbol '{symbol}' not found")
        info = self._symbols[symbol]
        # "DEMO:BTC/USD" -> base "BTC", quote "USD"
        base, quote = symbol.split(":")[1].split("/")
        return SymbolInfo(
            symbol=info["name"],  # Clean user-facing format (e.g., "Bitcoin")
            ticker=symbol,  # Keep DEMO:BTC/USD format for routing
            name=info["name"],
            description=f"Demo {info['name']} spot price with synthetic indicators",
            type=info["type"],
            exchange="DEMO",
            timezone="Etc/UTC",
            session="24x7",
            supported_resolutions=[Resolution.M1, Resolution.M5, Resolution.M15, Resolution.H1, Resolution.D1],
            has_intraday=True,
            has_daily=True,
            has_weekly_and_monthly=False,
            # Standard OHLCV columns plus the synthetic indicator columns
            # produced by _generate_bar.
            columns=[
                ColumnInfo(
                    name="open",
                    type="float",
                    description=f"Opening price in {quote}",
                    unit=quote,
                ),
                ColumnInfo(
                    name="high",
                    type="float",
                    description=f"Highest price in {quote}",
                    unit=quote,
                ),
                ColumnInfo(
                    name="low",
                    type="float",
                    description=f"Lowest price in {quote}",
                    unit=quote,
                ),
                ColumnInfo(
                    name="close",
                    type="float",
                    description=f"Closing price in {quote}",
                    unit=quote,
                ),
                ColumnInfo(
                    name="volume",
                    type="float",
                    description=f"Trading volume in {base}",
                    unit=base,
                ),
                ColumnInfo(
                    name="rsi",
                    type="float",
                    description="Relative Strength Index (14-period), range 0-100",
                    unit=None,
                ),
                ColumnInfo(
                    name="sentiment",
                    type="float",
                    description="Synthetic social sentiment score, range -1.0 to 1.0",
                    unit=None,
                ),
                ColumnInfo(
                    name="volume_profile",
                    type="float",
                    description="Volume as percentage of 24h average",
                    unit="%",
                ),
            ],
            time_column="time",
            has_ohlcv=True,
            pricescale=100,
            minmov=1,
            base_currency=base,
            quote_currency=quote,
        )
    async def get_bars(
        self,
        symbol: str,
        resolution: str,
        from_time: int,
        to_time: int,
        countback: Optional[int] = None,
    ) -> HistoryResult:
        """Generate synthetic bars for the requested range via a random walk.

        Args:
            symbol: Demo ticker (must exist in self._symbols).
            resolution: Resolution string ("1", "5", "15", "60", "1D", ...).
            from_time: Range start, Unix seconds.
            to_time: Range end, Unix seconds (inclusive here).
            countback: Optional cap on the number of generated bars
                (defaults to 5000).

        Raises:
            ValueError: If the symbol is unknown.
        """
        if symbol not in self._symbols:
            raise ValueError(f"Symbol '{symbol}' not found")
        logger.info(
            f"DemoDataSource.get_bars: symbol={symbol}, resolution={resolution}, "
            f"from_time={from_time}, to_time={to_time}, countback={countback}"
        )
        info = self._symbols[symbol]
        symbol_meta = await self.resolve_symbol(symbol)
        # Convert resolution to seconds
        resolution_seconds = self._resolution_to_seconds(resolution)
        # Generate bars
        bars = []
        # Align current_time to resolution, but ensure it's >= from_time
        current_time = from_time - (from_time % resolution_seconds)
        if current_time < from_time:
            current_time += resolution_seconds
        price = info["base_price"]
        bar_count = 0
        max_bars = countback if countback else 5000
        while current_time <= to_time and bar_count < max_bars:
            bar_data = self._generate_bar(current_time, price, info["volatility"], resolution_seconds)
            # Only include bars within the requested range
            if from_time <= current_time <= to_time:
                # NOTE(review): bar times are emitted in *milliseconds* here,
                # while nextTime below is in seconds, the CCXT adapter returns
                # seconds, and analyze_chart_data parses times with unit='s'.
                # Confirm which unit consumers expect - this looks inconsistent.
                bars.append(Bar(time=current_time * 1000, data=bar_data))  # Convert to milliseconds
                bar_count += 1
            price = bar_data["close"]  # Next bar starts from previous close
            current_time += resolution_seconds
        logger.info(
            f"DemoDataSource.get_bars: Generated {len(bars)} bars. "
            f"First: {bars[0].time if bars else 'N/A'}, Last: {bars[-1].time if bars else 'N/A'}"
        )
        # Determine if there's more data (for pagination)
        next_time = current_time if current_time <= to_time else None
        return HistoryResult(
            symbol=symbol,
            resolution=resolution,
            bars=bars,
            columns=symbol_meta.columns,
            nextTime=next_time,
        )
async def subscribe_bars(
self,
symbol: str,
resolution: str,
on_tick: Callable[[dict], None],
) -> str:
if symbol not in self._symbols:
raise ValueError(f"Symbol '{symbol}' not found")
subscription_id = f"{symbol}:{resolution}:{time.time()}"
# Start background task to simulate real-time updates
task = asyncio.create_task(
self._tick_generator(symbol, resolution, on_tick, subscription_id)
)
self._subscriptions[subscription_id] = task
return subscription_id
async def unsubscribe_bars(self, subscription_id: str) -> None:
task = self._subscriptions.pop(subscription_id, None)
if task:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
def _resolution_to_seconds(self, resolution: str) -> int:
"""Convert resolution string to seconds"""
if resolution.endswith("D"):
return int(resolution[:-1]) * 86400
elif resolution.endswith("W"):
return int(resolution[:-1]) * 604800
elif resolution.endswith("M"):
return int(resolution[:-1]) * 2592000 # Approximate month
else:
# Assume minutes
return int(resolution) * 60
def _generate_bar(self, timestamp: int, base_price: float, volatility: float, period_seconds: int) -> dict:
"""Generate a single synthetic OHLCV bar"""
# Random walk for the period
open_price = base_price
# Generate intra-period price movement
num_ticks = max(10, period_seconds // 60) # More ticks for longer periods
prices = [open_price]
for _ in range(num_ticks):
change = random.gauss(0, volatility / math.sqrt(num_ticks))
prices.append(prices[-1] * (1 + change))
close_price = prices[-1]
high_price = max(prices)
low_price = min(prices)
# Volume with some randomness
base_volume = 1000000 * (period_seconds / 60) # Scale with period
volume = base_volume * random.uniform(0.5, 2.0)
# Additional synthetic indicators
rsi = 30 + random.random() * 40 # Biased toward middle range
sentiment = math.sin(timestamp / 3600) * 0.5 + random.gauss(0, 0.2) # Hourly cycle + noise
sentiment = max(-1.0, min(1.0, sentiment))
volume_profile = 100 * random.uniform(0.5, 1.5)
return {
"open": round(open_price, 2),
"high": round(high_price, 2),
"low": round(low_price, 2),
"close": round(close_price, 2),
"volume": round(volume, 2),
"rsi": round(rsi, 2),
"sentiment": round(sentiment, 3),
"volume_profile": round(volume_profile, 2),
}
async def _tick_generator(
    self,
    symbol: str,
    resolution: str,
    on_tick: Callable[[dict], None],
    subscription_id: str,
):
    """Emit one synthetic bar per resolution interval until cancelled."""
    info = self._symbols[symbol]
    step = self._resolution_to_seconds(resolution)
    # Align the first bar to a resolution boundary.
    now = int(time.time())
    bar_time = now - (now % step)
    last_close = info["base_price"]
    try:
        while True:
            # Wait out one full bar period before emitting the next bar.
            # NOTE(review): sleep-based pacing drifts slightly over time.
            await asyncio.sleep(step)
            bar_time += step
            bar = self._generate_bar(bar_time, last_close, info["volatility"], step)
            last_close = bar["close"]
            on_tick({"time": bar_time, **bar})
    except asyncio.CancelledError:
        pass  # normal shutdown path via unsubscribe_bars()

View File

@@ -0,0 +1,146 @@
"""
Abstract DataSource interface.
Inspired by TradingView's Datafeed API with extensions for flexible column schemas
and AI-native metadata.
"""
from abc import ABC, abstractmethod
from typing import Callable, List, Optional
from .schema import DatafeedConfig, HistoryResult, SearchResult, SymbolInfo
class DataSource(ABC):
    """
    Abstract base class for time-series data sources.

    Provides a standardized interface for:
    - Symbol search and metadata retrieval
    - Historical data queries (time-based, paginated)
    - Real-time data subscriptions

    All data rows must have a timestamp (Unix seconds). Additional columns
    are flexible and described via ColumnInfo metadata, so implementations
    are not limited to OHLCV.
    """

    @abstractmethod
    async def get_config(self) -> DatafeedConfig:
        """
        Get datafeed configuration and capabilities.

        Called once during initialization to understand what this data source
        supports (resolutions, exchanges, search, etc.).

        Returns:
            DatafeedConfig describing this datafeed's capabilities
        """
        pass

    @abstractmethod
    async def search_symbols(
        self,
        query: str,
        type: Optional[str] = None,
        exchange: Optional[str] = None,
        limit: int = 30,
    ) -> List[SearchResult]:
        """
        Search for symbols matching a text query.

        Args:
            query: Free-text search string
            type: Optional filter by instrument type (e.g., 'crypto', 'stock')
            exchange: Optional filter by exchange
            limit: Maximum number of results

        Returns:
            List of matching symbols with basic metadata (may be empty)
        """
        pass

    @abstractmethod
    async def resolve_symbol(self, symbol: str) -> SymbolInfo:
        """
        Get complete metadata for a symbol.

        Called after a symbol is selected to retrieve full information including
        supported resolutions, column schema, trading session, etc.

        Args:
            symbol: Symbol identifier

        Returns:
            Complete SymbolInfo including column definitions

        Raises:
            ValueError: If symbol is not found
        """
        pass

    @abstractmethod
    async def get_bars(
        self,
        symbol: str,
        resolution: str,
        from_time: int,
        to_time: int,
        countback: Optional[int] = None,
    ) -> HistoryResult:
        """
        Get historical bars for a symbol and resolution.

        Time range is specified in Unix timestamps (seconds). If more data is
        available beyond the requested range, the result should include a
        nextTime value for pagination.

        Args:
            symbol: Symbol identifier
            resolution: Time resolution (e.g., "1", "5", "60", "1D")
            from_time: Start time (Unix timestamp in seconds, inclusive)
            to_time: End time (Unix timestamp in seconds, inclusive)
            countback: Optional limit on number of bars to return

        Returns:
            HistoryResult with bars, column schema, and pagination info

        Raises:
            ValueError: If symbol or resolution is not supported
        """
        pass

    @abstractmethod
    async def subscribe_bars(
        self,
        symbol: str,
        resolution: str,
        on_tick: Callable[[dict], None],
    ) -> str:
        """
        Subscribe to real-time bar updates.

        The callback will be invoked with new bar data as it becomes available.
        The data dict will match the column schema from resolve_symbol().

        Args:
            symbol: Symbol identifier
            resolution: Time resolution
            on_tick: Callback function receiving bar data dict

        Returns:
            Subscription ID for later unsubscribe

        Raises:
            ValueError: If symbol or resolution is not supported
        """
        pass

    @abstractmethod
    async def unsubscribe_bars(self, subscription_id: str) -> None:
        """
        Unsubscribe from real-time updates.

        Implementations should treat an unknown subscription_id as a no-op.

        Args:
            subscription_id: ID returned from subscribe_bars()
        """
        pass

View File

@@ -0,0 +1,109 @@
"""
DataSource registry for managing multiple data sources.
"""
from typing import Dict, List, Optional
from .base import DataSource
from .schema import SearchResult, SymbolInfo
class DataSourceRegistry:
    """
    Central registry for managing multiple DataSource instances.

    Routes symbol queries to a named source and can fan a search out
    across every registered source at once.
    """

    def __init__(self):
        # name -> DataSource implementation
        self._sources: Dict[str, DataSource] = {}

    def register(self, name: str, source: DataSource) -> None:
        """Register *source* under *name*, replacing any previous entry."""
        self._sources[name] = source

    def unregister(self, name: str) -> None:
        """Remove the source registered as *name*; no-op if absent."""
        self._sources.pop(name, None)

    def get(self, name: str) -> Optional[DataSource]:
        """Return the source registered as *name*, or None if unknown."""
        return self._sources.get(name)

    def list_sources(self) -> List[str]:
        """Return the names of every registered data source."""
        return [*self._sources]

    async def search_all(
        self,
        query: str,
        type: Optional[str] = None,
        exchange: Optional[str] = None,
        limit: int = 30,
    ) -> Dict[str, List[SearchResult]]:
        """
        Run *query* against every registered source.

        Args:
            query: Search query
            type: Optional instrument type filter
            exchange: Optional exchange filter
            limit: Maximum results per source

        Returns:
            Mapping of source name to its (non-empty) search results. Sources
            that raise during search are skipped silently, so one broken feed
            cannot poison the aggregate result.
        """
        hits: Dict[str, List[SearchResult]] = {}
        for name, source in self._sources.items():
            try:
                found = await source.search_symbols(query, type, exchange, limit)
                if found:
                    hits[name] = found
            except Exception:
                # Deliberately best-effort: skip sources that error.
                continue
        return hits

    async def resolve_symbol(self, source_name: str, symbol: str) -> SymbolInfo:
        """
        Resolve *symbol* through the source named *source_name*.

        Raises:
            ValueError: If the source is unknown, or propagated from the
                source when the symbol itself is not found.
        """
        source = self.get(source_name)
        if not source:
            raise ValueError(f"Data source '{source_name}' not found")
        return await source.resolve_symbol(symbol)

View File

@@ -0,0 +1,194 @@
"""
Data models for the DataSource interface.
Inspired by TradingView's Datafeed API but with flexible column schemas
for AI-native trading platform needs.
"""
from enum import StrEnum
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
class Resolution(StrEnum):
    """Standard time resolutions for bar data.

    Values follow TradingView's resolution string convention: bare numbers
    are minutes; 'S', 'D', 'W', 'M' suffixes denote seconds, days, weeks,
    and months respectively.
    """
    # Seconds
    S1 = "1S"
    S5 = "5S"
    S15 = "15S"
    S30 = "30S"
    # Minutes (bare numbers)
    M1 = "1"
    M5 = "5"
    M15 = "15"
    M30 = "30"
    # Hours (expressed in minutes)
    H1 = "60"
    H2 = "120"
    H4 = "240"
    H6 = "360"
    H12 = "720"
    # Days
    D1 = "1D"
    # Weeks
    W1 = "1W"
    # Months
    MO1 = "1M"
class ColumnInfo(BaseModel):
    """
    Metadata for a single data column.

    Provides rich, LLM-readable descriptions so AI agents can understand
    and reason about available data fields.
    """
    # Reject unknown keys so schema typos fail loudly instead of being dropped.
    model_config = {"extra": "forbid"}

    name: str = Field(description="Column name (e.g., 'close', 'volume', 'funding_rate')")
    type: Literal["float", "int", "bool", "string", "decimal"] = Field(description="Data type")
    description: str = Field(description="Human and LLM-readable description of what this column represents")
    unit: Optional[str] = Field(default=None, description="Unit of measurement (e.g., 'USD', 'BTC', '%', 'contracts')")
    nullable: bool = Field(default=False, description="Whether this column can contain null values")
class SymbolInfo(BaseModel):
    """
    Complete metadata for a tradeable symbol.

    Includes both TradingView-compatible fields (session, pricescale, minmov)
    and a flexible schema definition (columns) for arbitrary data columns.
    """
    # Reject unknown keys so producers can't silently send extra fields.
    model_config = {"extra": "forbid"}

    # Core identification
    symbol: str = Field(description="Unique symbol identifier (primary key for data fetching)")
    ticker: Optional[str] = Field(default=None, description="TradingView ticker (if different from symbol)")
    name: str = Field(description="Display name")
    description: str = Field(description="LLM-readable description of the instrument")
    type: str = Field(description="Instrument type: 'crypto', 'stock', 'forex', 'futures', 'derived', etc.")
    exchange: str = Field(description="Exchange or data source identifier")
    # Trading session info
    timezone: str = Field(default="Etc/UTC", description="IANA timezone identifier")
    session: str = Field(default="24x7", description="Trading session spec (e.g., '0930-1600' or '24x7')")
    # Resolution support
    supported_resolutions: List[str] = Field(description="List of supported time resolutions")
    has_intraday: bool = Field(default=True, description="Whether intraday resolutions are supported")
    has_daily: bool = Field(default=True, description="Whether daily resolution is supported")
    has_weekly_and_monthly: bool = Field(default=False, description="Whether weekly/monthly resolutions are supported")
    # Flexible schema definition
    columns: List[ColumnInfo] = Field(description="Available data columns for this symbol")
    time_column: str = Field(default="time", description="Name of the timestamp column")
    # Convenience flags
    has_ohlcv: bool = Field(default=False, description="Whether standard OHLCV columns are present")
    # Price display (for OHLCV data)
    pricescale: int = Field(default=100, description="Price scale factor (e.g., 100 for 2 decimals)")
    minmov: int = Field(default=1, description="Minimum price movement in pricescale units")
    # Additional metadata
    base_currency: Optional[str] = Field(default=None, description="Base currency (for crypto/forex)")
    quote_currency: Optional[str] = Field(default=None, description="Quote currency (for crypto/forex)")
class Bar(BaseModel):
    """
    A single bar/row of time-series data with flexible columns.

    All bars must have a timestamp. Additional columns are stored in the
    data dict and described by the associated ColumnInfo metadata.

    The OHLCV properties below are convenience accessors only; they return
    None when the corresponding column is absent from ``data``.
    """
    # Reject unknown keys; all column values belong inside `data`.
    model_config = {"extra": "forbid"}

    time: int = Field(description="Unix timestamp in seconds")
    data: Dict[str, Any] = Field(description="Column name -> value mapping")

    # Convenience accessors for common OHLCV columns
    @property
    def open(self) -> Optional[float]:
        return self.data.get("open")

    @property
    def high(self) -> Optional[float]:
        return self.data.get("high")

    @property
    def low(self) -> Optional[float]:
        return self.data.get("low")

    @property
    def close(self) -> Optional[float]:
        return self.data.get("close")

    @property
    def volume(self) -> Optional[float]:
        return self.data.get("volume")
class HistoryResult(BaseModel):
    """
    Result from a historical data query.

    Includes the bars, schema information, and pagination metadata
    (``nextTime`` is set when more data exists beyond the returned range).
    """
    # Reject unknown keys so the wire format stays explicit.
    model_config = {"extra": "forbid"}

    symbol: str = Field(description="Symbol identifier")
    resolution: str = Field(description="Time resolution of the bars")
    bars: List[Bar] = Field(description="The actual data bars")
    columns: List[ColumnInfo] = Field(description="Schema describing the bar data columns")
    nextTime: Optional[int] = Field(default=None, description="Unix timestamp for pagination (if more data available)")
class SearchResult(BaseModel):
    """
    A single result from symbol search.

    ``symbol`` is the human-facing display form; ``ticker`` (when set) is
    the backend identifier used for data fetching.
    """
    # Reject unknown keys so the wire format stays explicit.
    model_config = {"extra": "forbid"}

    symbol: str = Field(description="Display symbol (e.g., 'BINANCE:ETH/BTC')")
    ticker: Optional[str] = Field(default=None, description="Backend ticker for data fetching (e.g., 'ETH/BTC')")
    full_name: str = Field(description="Full display name including exchange")
    description: str = Field(description="Human-readable description")
    exchange: str = Field(description="Exchange identifier")
    type: str = Field(description="Instrument type")
class DatafeedConfig(BaseModel):
    """
    Configuration and capabilities of a DataSource.

    Similar to TradingView's onReady configuration object; clients use it
    to discover what a datafeed supports before issuing requests.
    """
    # Reject unknown keys so the wire format stays explicit.
    model_config = {"extra": "forbid"}

    # Supported features
    supported_resolutions: List[str] = Field(description="All resolutions this datafeed supports")
    supports_search: bool = Field(default=True, description="Whether symbol search is available")
    supports_time: bool = Field(default=True, description="Whether time-based queries are supported")
    supports_marks: bool = Field(default=False, description="Whether marks/events are supported")
    # Data characteristics
    exchanges: List[str] = Field(default_factory=list, description="Available exchanges")
    symbols_types: List[str] = Field(default_factory=list, description="Available instrument types")
    # Metadata
    name: str = Field(description="Datafeed name")
    description: str = Field(description="LLM-readable description of this data source")

View File

@@ -0,0 +1,235 @@
"""
Subscription manager for real-time data feeds.
Manages subscriptions across multiple data sources and routes updates
to WebSocket clients.
"""
import asyncio
import logging
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Set
from .base import DataSource
logger = logging.getLogger(__name__)
@dataclass
class Subscription:
    """A single client subscription to one (source, symbol, resolution) feed.

    Plain data holder; the dataclass decorator supplies __init__/__repr__/__eq__
    in place of hand-written boilerplate. Field order matches the original
    constructor signature, so positional callers are unaffected.

    Attributes:
        subscription_id: Client-generated unique id for this subscription.
        client_id: Id of the owning client/connection.
        source_name: Name of the backing DataSource.
        symbol: Subscribed symbol identifier.
        resolution: Time resolution string.
        callback: Invoked with each new bar data dict.
        source_subscription_id: Id of the underlying source-level subscription;
            populated once the upstream subscribe succeeds.
    """

    subscription_id: str
    client_id: str
    source_name: str
    symbol: str
    resolution: str
    callback: Callable[[dict], None]
    source_subscription_id: Optional[str] = None
class SubscriptionManager:
    """
    Manages real-time data subscriptions across multiple data sources.

    Handles:
    - Subscription lifecycle (subscribe/unsubscribe)
    - Routing updates from data sources to clients
    - Multiplexing (multiple clients can subscribe to same symbol/resolution)
    """

    def __init__(self):
        # Map subscription_id -> Subscription
        self._subscriptions: Dict[str, Subscription] = {}
        # Map (source_name, symbol, resolution) -> Set[subscription_id]
        # For tracking which client subscriptions share a source subscription
        self._source_refs: Dict[tuple, Set[str]] = {}
        # Map source_subscription_id -> (source_name, symbol, resolution)
        self._source_subs: Dict[str, tuple] = {}
        # Available data sources by name
        self._sources: Dict[str, DataSource] = {}

    def register_source(self, name: str, source: DataSource) -> None:
        """Register a data source"""
        self._sources[name] = source

    def unregister_source(self, name: str) -> None:
        """Unregister a data source"""
        self._sources.pop(name, None)

    def _find_source_sub_id(self, source_key: tuple) -> Optional[str]:
        """Reverse-lookup the upstream subscription id for a source key."""
        for sub_id, key in self._source_subs.items():
            if key == source_key:
                return sub_id
        return None

    async def subscribe(
        self,
        subscription_id: str,
        client_id: str,
        source_name: str,
        symbol: str,
        resolution: str,
        callback: Callable[[dict], None],
    ) -> None:
        """
        Subscribe a client to real-time updates.

        Args:
            subscription_id: Unique ID for this subscription
            client_id: ID of the subscribing client
            source_name: Name of the data source
            symbol: Symbol to subscribe to
            resolution: Time resolution
            callback: Function to call with bar updates

        Raises:
            ValueError: If source not found or subscription fails
        """
        source = self._sources.get(source_name)
        if not source:
            raise ValueError(f"Data source '{source_name}' not found")
        # Create subscription record
        subscription = Subscription(
            subscription_id=subscription_id,
            client_id=client_id,
            source_name=source_name,
            symbol=symbol,
            resolution=resolution,
            callback=callback,
        )
        # Check if we already have a source subscription for this (source, symbol, resolution)
        source_key = (source_name, symbol, resolution)
        if source_key not in self._source_refs:
            # Need to create a new source subscription
            try:
                source_sub_id = await source.subscribe_bars(
                    symbol=symbol,
                    resolution=resolution,
                    on_tick=lambda bar: self._on_source_update(source_key, bar),
                )
            except Exception as e:
                logger.error(f"Failed to subscribe to source: {e}")
                raise
            subscription.source_subscription_id = source_sub_id
            self._source_subs[source_sub_id] = source_key
            self._source_refs[source_key] = set()
            logger.info(
                f"Created new source subscription: {source_name}/{symbol}/{resolution} -> {source_sub_id}"
            )
        else:
            # BUG FIX: previously only the *first* subscriber recorded the
            # upstream id. If that subscriber unsubscribed before the others,
            # the last unsubscriber held None and the upstream source
            # subscription was never cancelled (a leak). Record the shared
            # id on every subscription instead.
            subscription.source_subscription_id = self._find_source_sub_id(source_key)
        # Add this subscription to the reference set
        self._source_refs[source_key].add(subscription_id)
        self._subscriptions[subscription_id] = subscription
        logger.info(
            f"Client subscription added: {subscription_id} ({client_id}) -> {source_name}/{symbol}/{resolution}"
        )

    async def unsubscribe(self, subscription_id: str) -> None:
        """
        Unsubscribe a client from updates.

        When the last client subscription for a (source, symbol, resolution)
        goes away, the underlying source subscription is cancelled too.

        Args:
            subscription_id: ID of the subscription to remove
        """
        subscription = self._subscriptions.pop(subscription_id, None)
        if not subscription:
            logger.warning(f"Subscription {subscription_id} not found")
            return
        source_key = (subscription.source_name, subscription.symbol, subscription.resolution)
        # Remove from reference set
        if source_key in self._source_refs:
            self._source_refs[source_key].discard(subscription_id)
            # If no more clients need this source subscription, unsubscribe from source
            if not self._source_refs[source_key]:
                del self._source_refs[source_key]
                # Fall back to reverse lookup for records that predate the
                # shared-id bookkeeping above.
                source_sub_id = (
                    subscription.source_subscription_id
                    or self._find_source_sub_id(source_key)
                )
                if source_sub_id:
                    source = self._sources.get(subscription.source_name)
                    if source:
                        try:
                            await source.unsubscribe_bars(source_sub_id)
                            logger.info(
                                f"Unsubscribed from source: {source_sub_id}"
                            )
                        except Exception as e:
                            logger.error(f"Error unsubscribing from source: {e}")
                    self._source_subs.pop(source_sub_id, None)
        logger.info(f"Client subscription removed: {subscription_id}")

    async def unsubscribe_client(self, client_id: str) -> None:
        """
        Unsubscribe all subscriptions for a client.

        Useful when a WebSocket connection closes.

        Args:
            client_id: ID of the client
        """
        # Snapshot first: unsubscribe() mutates self._subscriptions.
        client_subs = [
            sub_id
            for sub_id, sub in self._subscriptions.items()
            if sub.client_id == client_id
        ]
        for sub_id in client_subs:
            await self.unsubscribe(sub_id)
        logger.info(f"Unsubscribed all subscriptions for client {client_id}")

    def _on_source_update(self, source_key: tuple, bar: dict) -> None:
        """
        Handle an update from a data source.

        Routes the update to all client subscriptions that need it. Callback
        errors are logged per-subscription so one bad client cannot block
        the rest.

        Args:
            source_key: (source_name, symbol, resolution) tuple
            bar: Bar data dict from the source
        """
        subscription_ids = self._source_refs.get(source_key, set())
        for sub_id in subscription_ids:
            subscription = self._subscriptions.get(sub_id)
            if subscription:
                try:
                    subscription.callback(bar)
                except Exception as e:
                    logger.error(
                        f"Error in subscription callback {sub_id}: {e}", exc_info=True
                    )

    def get_subscription_count(self) -> int:
        """Get total number of active client subscriptions"""
        return len(self._subscriptions)

    def get_source_subscription_count(self) -> int:
        """Get total number of active source subscriptions"""
        return len(self._source_refs)

    def get_client_subscriptions(self, client_id: str) -> list:
        """Get all subscriptions for a specific client"""
        return [
            {
                "subscription_id": sub.subscription_id,
                "source": sub.source_name,
                "symbol": sub.symbol,
                "resolution": sub.resolution,
            }
            for sub in self._subscriptions.values()
            if sub.client_id == client_id
        ]

View File

@@ -0,0 +1,347 @@
"""
WebSocket handler for TradingView-compatible datafeed API.
Handles incoming requests for symbol search, metadata, historical data,
and real-time subscriptions.
"""
import json
import logging
from typing import Dict, Optional
from fastapi import WebSocket
from .base import DataSource
from .registry import DataSourceRegistry
from .subscription_manager import SubscriptionManager
from .websocket_protocol import (
BarUpdateMessage,
ErrorResponse,
GetBarsRequest,
GetBarsResponse,
GetConfigRequest,
GetConfigResponse,
ResolveSymbolRequest,
ResolveSymbolResponse,
SearchSymbolsRequest,
SearchSymbolsResponse,
SubscribeBarsRequest,
SubscribeBarsResponse,
UnsubscribeBarsRequest,
UnsubscribeBarsResponse,
)
logger = logging.getLogger(__name__)
class DatafeedWebSocketHandler:
    """
    Handles WebSocket connections for TradingView-compatible datafeed API.

    Each handler manages a single WebSocket connection and routes requests
    to the appropriate data sources via the registry.
    """

    def __init__(
        self,
        websocket: WebSocket,
        client_id: str,
        registry: DataSourceRegistry,
        subscription_manager: SubscriptionManager,
        default_source: Optional[str] = None,
    ):
        """
        Initialize handler.

        Args:
            websocket: FastAPI WebSocket connection
            client_id: Unique identifier for this client
            registry: DataSource registry for accessing data sources
            subscription_manager: Shared subscription manager
            default_source: Default data source name if not specified in requests
        """
        self.websocket = websocket
        self.client_id = client_id
        self.registry = registry
        self.subscription_manager = subscription_manager
        self.default_source = default_source
        self._connected = True
        # Strong references to in-flight update tasks. The event loop keeps
        # only weak references to tasks, so a bare create_task() result can
        # be garbage-collected before the update is delivered.
        self._pending_tasks: set = set()

    async def handle_connection(self) -> None:
        """
        Main connection handler loop.

        Processes incoming messages until the connection closes, then tears
        down any subscriptions this client still holds.
        """
        try:
            await self.websocket.accept()
            logger.info(f"WebSocket connected: client_id={self.client_id}")
            while self._connected:
                # Receive message; any receive/parse failure ends the loop.
                try:
                    data = await self.websocket.receive_text()
                    message = json.loads(data)
                except Exception as e:
                    logger.error(f"Error receiving/parsing message: {e}")
                    break
                # Route to appropriate handler
                await self._handle_message(message)
        except Exception as e:
            logger.error(f"WebSocket error: {e}", exc_info=True)
        finally:
            # Clean up subscriptions when connection closes
            await self.subscription_manager.unsubscribe_client(self.client_id)
            self._connected = False
            logger.info(f"WebSocket disconnected: client_id={self.client_id}")

    async def _handle_message(self, message: dict) -> None:
        """Route message to appropriate handler based on its 'type' field."""
        msg_type = message.get("type")
        request_id = message.get("request_id", "unknown")
        try:
            if msg_type == "search_symbols":
                await self._handle_search_symbols(SearchSymbolsRequest(**message))
            elif msg_type == "resolve_symbol":
                await self._handle_resolve_symbol(ResolveSymbolRequest(**message))
            elif msg_type == "get_bars":
                await self._handle_get_bars(GetBarsRequest(**message))
            elif msg_type == "subscribe_bars":
                await self._handle_subscribe_bars(SubscribeBarsRequest(**message))
            elif msg_type == "unsubscribe_bars":
                await self._handle_unsubscribe_bars(UnsubscribeBarsRequest(**message))
            elif msg_type == "get_config":
                await self._handle_get_config(GetConfigRequest(**message))
            else:
                await self._send_error(
                    request_id, "UNKNOWN_REQUEST_TYPE", f"Unknown request type: {msg_type}"
                )
        except Exception as e:
            # Broad catch is intentional at this dispatch boundary: one bad
            # request must not kill the connection. The error is logged and
            # reported back to the client.
            logger.error(f"Error handling {msg_type}: {e}", exc_info=True)
            await self._send_error(request_id, "INTERNAL_ERROR", str(e))

    async def _handle_search_symbols(self, request: SearchSymbolsRequest) -> None:
        """Handle symbol search request (default source, or all sources)."""
        if self.default_source:
            source = self.registry.get(self.default_source)
            if not source:
                await self._send_error(
                    request.request_id,
                    "SOURCE_NOT_FOUND",
                    f"Default source '{self.default_source}' not found",
                )
                return
            results = await source.search_symbols(
                query=request.query,
                type=request.symbol_type,
                exchange=request.exchange,
                limit=request.limit,
            )
            results_data = [r.model_dump(mode="json") for r in results]
        else:
            # Search all sources and flatten the per-source results
            all_results = await self.registry.search_all(
                query=request.query,
                type=request.symbol_type,
                exchange=request.exchange,
                limit=request.limit,
            )
            results_data = []
            for source_results in all_results.values():
                results_data.extend([r.model_dump(mode="json") for r in source_results])
        response = SearchSymbolsResponse(request_id=request.request_id, results=results_data)
        await self._send_response(response)

    async def _handle_resolve_symbol(self, request: ResolveSymbolRequest) -> None:
        """Handle symbol resolution request."""
        # Extract source from symbol if present (format: "SOURCE:SYMBOL")
        if ":" in request.symbol:
            source_name, symbol = request.symbol.split(":", 1)
        else:
            source_name = self.default_source
            symbol = request.symbol
        if not source_name:
            await self._send_error(
                request.request_id,
                "NO_SOURCE_SPECIFIED",
                "No data source specified and no default source configured",
            )
            return
        try:
            symbol_info = await self.registry.resolve_symbol(source_name, symbol)
            response = ResolveSymbolResponse(
                request_id=request.request_id,
                symbol_info=symbol_info.model_dump(mode="json"),
            )
            await self._send_response(response)
        except ValueError as e:
            await self._send_error(request.request_id, "SYMBOL_NOT_FOUND", str(e))

    async def _handle_get_bars(self, request: GetBarsRequest) -> None:
        """Handle historical bars request."""
        # Extract source from symbol (format: "SOURCE:SYMBOL")
        if ":" in request.symbol:
            source_name, symbol = request.symbol.split(":", 1)
        else:
            source_name = self.default_source
            symbol = request.symbol
        if not source_name:
            await self._send_error(
                request.request_id, "NO_SOURCE_SPECIFIED", "No data source specified"
            )
            return
        source = self.registry.get(source_name)
        if not source:
            await self._send_error(
                request.request_id, "SOURCE_NOT_FOUND", f"Source '{source_name}' not found"
            )
            return
        try:
            history = await source.get_bars(
                symbol=symbol,
                resolution=request.resolution,
                from_time=request.from_time,
                to_time=request.to_time,
                countback=request.countback,
            )
            response = GetBarsResponse(
                request_id=request.request_id, history=history.model_dump(mode="json")
            )
            await self._send_response(response)
        except ValueError as e:
            await self._send_error(request.request_id, "INVALID_REQUEST", str(e))

    async def _handle_subscribe_bars(self, request: SubscribeBarsRequest) -> None:
        """Handle real-time subscription request."""
        # Extract source from symbol (format: "SOURCE:SYMBOL")
        if ":" in request.symbol:
            source_name, symbol = request.symbol.split(":", 1)
        else:
            source_name = self.default_source
            symbol = request.symbol
        if not source_name:
            await self._send_error(
                request.request_id, "NO_SOURCE_SPECIFIED", "No data source specified"
            )
            return
        try:
            # Create callback that sends updates to this WebSocket
            async def send_update(bar: dict):
                update = BarUpdateMessage(
                    subscription_id=request.subscription_id,
                    symbol=request.symbol,
                    resolution=request.resolution,
                    bar=bar,
                )
                await self._send_response(update)

            # Register subscription; the sync callback schedules the async send
            await self.subscription_manager.subscribe(
                subscription_id=request.subscription_id,
                client_id=self.client_id,
                source_name=source_name,
                symbol=symbol,
                resolution=request.resolution,
                callback=lambda bar: self._queue_update(send_update(bar)),
            )
            response = SubscribeBarsResponse(
                request_id=request.request_id,
                subscription_id=request.subscription_id,
                success=True,
            )
            await self._send_response(response)
        except Exception as e:
            logger.error(f"Subscription failed: {e}", exc_info=True)
            response = SubscribeBarsResponse(
                request_id=request.request_id,
                subscription_id=request.subscription_id,
                success=False,
                message=str(e),
            )
            await self._send_response(response)

    async def _handle_unsubscribe_bars(self, request: UnsubscribeBarsRequest) -> None:
        """Handle unsubscribe request."""
        try:
            await self.subscription_manager.unsubscribe(request.subscription_id)
            response = UnsubscribeBarsResponse(
                request_id=request.request_id,
                subscription_id=request.subscription_id,
                success=True,
            )
            await self._send_response(response)
        except Exception as e:
            logger.error(f"Unsubscribe failed: {e}")
            response = UnsubscribeBarsResponse(
                request_id=request.request_id,
                subscription_id=request.subscription_id,
                success=False,
            )
            await self._send_response(response)

    async def _handle_get_config(self, request: GetConfigRequest) -> None:
        """Handle datafeed config request."""
        if self.default_source:
            source = self.registry.get(self.default_source)
            if source:
                config = await source.get_config()
                response = GetConfigResponse(
                    request_id=request.request_id, config=config.model_dump(mode="json")
                )
                await self._send_response(response)
                return
        # Return aggregate config from all sources
        all_sources = self.registry.list_sources()
        if not all_sources:
            await self._send_error(
                request.request_id, "NO_SOURCES", "No data sources available"
            )
            return
        # Just use first source's config for now
        # TODO: Aggregate configs from multiple sources
        source = self.registry.get(all_sources[0])
        if source:
            config = await source.get_config()
            response = GetConfigResponse(
                request_id=request.request_id, config=config.model_dump(mode="json")
            )
            await self._send_response(response)

    async def _send_response(self, response) -> None:
        """Send a response message to the client; mark disconnected on failure."""
        try:
            await self.websocket.send_json(response.model_dump(mode="json"))
        except Exception as e:
            logger.error(f"Error sending response: {e}")
            self._connected = False

    async def _send_error(self, request_id: str, error_code: str, error_message: str) -> None:
        """Send an error response."""
        error = ErrorResponse(
            request_id=request_id, error_code=error_code, error_message=error_message
        )
        await self._send_response(error)

    def _queue_update(self, coro) -> None:
        """Schedule *coro* on the running loop without blocking the caller.

        BUG FIX: the task returned by create_task() was previously discarded.
        asyncio keeps only weak references to tasks, so an unreferenced task
        can be garbage-collected before it runs and the bar update silently
        dropped. Keep a strong reference until the task completes.
        """
        import asyncio

        task = asyncio.create_task(coro)
        self._pending_tasks.add(task)
        task.add_done_callback(self._pending_tasks.discard)

View File

@@ -0,0 +1,170 @@
"""
WebSocket protocol messages for TradingView-compatible datafeed API.
These messages define the wire format for client-server communication
over WebSocket for symbol search, historical data, and real-time subscriptions.
"""
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
# ============================================================================
# Client -> Server Messages
# ============================================================================
class SearchSymbolsRequest(BaseModel):
"""Request to search for symbols matching a query"""
type: Literal["search_symbols"] = "search_symbols"
request_id: str = Field(description="Client-generated request ID for matching responses")
query: str = Field(description="Search query string")
symbol_type: Optional[str] = Field(default=None, description="Filter by instrument type")
exchange: Optional[str] = Field(default=None, description="Filter by exchange")
limit: int = Field(default=30, description="Maximum number of results")
class ResolveSymbolRequest(BaseModel):
"""Request full metadata for a specific symbol"""
type: Literal["resolve_symbol"] = "resolve_symbol"
request_id: str
symbol: str = Field(description="Symbol identifier to resolve")
class GetBarsRequest(BaseModel):
"""Request historical bar data"""
type: Literal["get_bars"] = "get_bars"
request_id: str
symbol: str
resolution: str = Field(description="Time resolution (e.g., '1', '5', '60', '1D')")
from_time: int = Field(description="Start time (Unix timestamp in seconds)")
to_time: int = Field(description="End time (Unix timestamp in seconds)")
countback: Optional[int] = Field(default=None, description="Maximum number of bars to return")
class SubscribeBarsRequest(BaseModel):
"""Subscribe to real-time bar updates"""
type: Literal["subscribe_bars"] = "subscribe_bars"
request_id: str
symbol: str
resolution: str
subscription_id: str = Field(description="Client-generated subscription ID")
class UnsubscribeBarsRequest(BaseModel):
"""Unsubscribe from real-time updates"""
type: Literal["unsubscribe_bars"] = "unsubscribe_bars"
request_id: str
subscription_id: str
class GetConfigRequest(BaseModel):
    """Request datafeed configuration (supported resolutions, exchanges, etc.)."""
    type: Literal["get_config"] = "get_config"
    request_id: str  # Client-generated ID echoed back in the response
# Union of all client request types.
# Messages are discriminated by each model's literal `type` field.
ClientRequest = Union[
    SearchSymbolsRequest,
    ResolveSymbolRequest,
    GetBarsRequest,
    SubscribeBarsRequest,
    UnsubscribeBarsRequest,
    GetConfigRequest,
]
# ============================================================================
# Server -> Client Messages
# ============================================================================
class SearchSymbolsResponse(BaseModel):
    """Response with search results for a SearchSymbolsRequest."""
    type: Literal["search_symbols_response"] = "search_symbols_response"
    request_id: str  # Mirrors the originating request's ID
    results: List[Dict[str, Any]] = Field(description="List of SearchResult objects")
class ResolveSymbolResponse(BaseModel):
    """Response with symbol metadata for a ResolveSymbolRequest."""
    type: Literal["resolve_symbol_response"] = "resolve_symbol_response"
    request_id: str  # Mirrors the originating request's ID
    symbol_info: Dict[str, Any] = Field(description="SymbolInfo object")
class GetBarsResponse(BaseModel):
    """Response with historical bars for a GetBarsRequest."""
    type: Literal["get_bars_response"] = "get_bars_response"
    request_id: str  # Mirrors the originating request's ID
    history: Dict[str, Any] = Field(description="HistoryResult object with bars and metadata")
class SubscribeBarsResponse(BaseModel):
    """Acknowledgment of a subscribe_bars request."""
    type: Literal["subscribe_bars_response"] = "subscribe_bars_response"
    request_id: str  # Mirrors the originating request's ID
    subscription_id: str  # Subscription this acknowledgment refers to
    success: bool
    message: Optional[str] = None  # Optional human-readable detail (e.g., on failure)
class UnsubscribeBarsResponse(BaseModel):
    """Acknowledgment of an unsubscribe_bars request."""
    type: Literal["unsubscribe_bars_response"] = "unsubscribe_bars_response"
    request_id: str  # Mirrors the originating request's ID
    subscription_id: str  # Subscription that was removed
    success: bool
class GetConfigResponse(BaseModel):
    """Response with datafeed configuration for a GetConfigRequest."""
    type: Literal["get_config_response"] = "get_config_response"
    request_id: str  # Mirrors the originating request's ID
    config: Dict[str, Any] = Field(description="DatafeedConfig object")
class BarUpdateMessage(BaseModel):
    """Real-time bar update (server-initiated, no request_id).

    Pushed for an active subscription; clients correlate via subscription_id.
    """
    type: Literal["bar_update"] = "bar_update"
    subscription_id: str  # Matches the client's SubscribeBarsRequest
    symbol: str
    resolution: str
    bar: Dict[str, Any] = Field(description="Bar data including time and all columns")
class ErrorResponse(BaseModel):
    """Error response for any failed request."""
    type: Literal["error"] = "error"
    request_id: str  # ID of the failed request ("unknown" if it could not be parsed)
    error_code: str = Field(description="Machine-readable error code")
    error_message: str = Field(description="Human-readable error description")
# Union of all server response types.
# All but BarUpdateMessage echo a request_id; bar_update is server-initiated.
ServerResponse = Union[
    SearchSymbolsResponse,
    ResolveSymbolResponse,
    GetBarsResponse,
    SubscribeBarsResponse,
    UnsubscribeBarsResponse,
    GetConfigResponse,
    BarUpdateMessage,
    ErrorResponse,
]

View File

@@ -0,0 +1,4 @@
from gateway.protocol import UserMessage, AgentMessage, ChannelStatus
from gateway.hub import Gateway
__all__ = ["UserMessage", "AgentMessage", "ChannelStatus", "Gateway"]

View File

@@ -0,0 +1,3 @@
from gateway.channels.base import Channel
__all__ = ["Channel"]

View File

@@ -0,0 +1,73 @@
from abc import ABC, abstractmethod
from typing import AsyncIterator, Dict, Any
from gateway.protocol import UserMessage, AgentMessage, ChannelStatus
class Channel(ABC):
    """Abstract transport between a user and the agent system.

    A Channel is the protocol-specific layer: concrete subclasses handle
    encoding/decoding for one medium (WebSocket, Slack, ...) and track
    their own connection state.
    """

    def __init__(self, channel_id: str, channel_type: str):
        self.channel_id = channel_id
        self.channel_type = channel_type
        # Subclasses flip this to True once a live connection exists.
        self._connected = False

    @abstractmethod
    async def send(self, message: AgentMessage) -> None:
        """Deliver an agent message to the user over this channel.

        Args:
            message: AgentMessage to send (streaming chunk or complete message)
        """
        ...

    @abstractmethod
    async def receive(self) -> AsyncIterator[UserMessage]:
        """Yield UserMessage objects as they arrive from the user."""
        ...

    @abstractmethod
    async def close(self) -> None:
        """Tear down the channel and release its resources."""
        ...

    def supports_streaming(self) -> bool:
        """Return True if this channel can deliver incremental response chunks."""
        return False

    def supports_attachments(self) -> bool:
        """Return True if this channel can carry file attachments."""
        return False

    def get_status(self) -> ChannelStatus:
        """Build a ChannelStatus snapshot of connection state and capabilities."""
        capabilities = {
            "streaming": self.supports_streaming(),
            "attachments": self.supports_attachments(),
            # Most channels support some form of markdown rendering.
            "markdown": True,
        }
        return ChannelStatus(
            channel_id=self.channel_id,
            channel_type=self.channel_type,
            connected=self._connected,
            capabilities=capabilities,
        )

View File

@@ -0,0 +1,99 @@
import asyncio
import json
import logging
from typing import AsyncIterator, Optional
from datetime import datetime
from fastapi import WebSocket
from gateway.channels.base import Channel
from gateway.protocol import UserMessage, AgentMessage, WebSocketAgentUserMessage, WebSocketAgentChunk
logger = logging.getLogger(__name__)
class WebSocketChannel(Channel):
    """WebSocket-based communication channel.

    Integrates with the existing WebSocket endpoint to provide
    bidirectional agent communication with streaming support.
    """

    def __init__(self, channel_id: str, websocket: WebSocket, session_id: str):
        super().__init__(channel_id, "websocket")
        self.websocket = websocket
        self.session_id = session_id
        self._connected = True
        # Reserved for a background-reader design; receive() currently reads
        # the socket directly and these are only used by close().
        self._receive_queue: asyncio.Queue[UserMessage] = asyncio.Queue()
        self._receive_task: Optional[asyncio.Task] = None

    def supports_streaming(self) -> bool:
        """WebSocket supports streaming responses."""
        return True

    def supports_attachments(self) -> bool:
        """WebSocket can support attachments via URLs."""
        return True

    async def send(self, message: AgentMessage) -> None:
        """Send agent message through WebSocket.

        For streaming messages, sends chunks as they arrive.
        For complete messages, sends as a single chunk.
        """
        if not self._connected:
            logger.warning(f"Cannot send message, channel {self.channel_id} not connected")
            return
        try:
            chunk = WebSocketAgentChunk(
                session_id=message.session_id,
                content=message.content,
                done=message.done,
                metadata=message.metadata
            )
            chunk_data = chunk.model_dump(mode="json")
            logger.debug(f"Sending WebSocket message: done={message.done}, content_length={len(message.content)}")
            await self.websocket.send_json(chunk_data)
            logger.debug(f"WebSocket message sent successfully")
        except Exception as e:
            # Any transport failure marks the channel dead; the gateway skips
            # disconnected channels on subsequent sends.
            logger.error(f"WebSocket send error: {e}", exc_info=True)
            self._connected = False

    async def receive(self) -> AsyncIterator[UserMessage]:
        """Receive messages from WebSocket.

        Yields:
            UserMessage objects as they arrive from the client
        """
        try:
            while self._connected:
                # Read from WebSocket
                data = await self.websocket.receive_text()
                message_json = json.loads(data)
                # Only process agent_user_message types; other frame types are
                # handled by the main endpoint loop.
                if message_json.get("type") == "agent_user_message":
                    msg = WebSocketAgentUserMessage(**message_json)
                    # Timestamp deliberately omitted: UserMessage's default
                    # factory produces an aware UTC datetime, matching the main
                    # endpoint, instead of the previous naive datetime.utcnow().
                    user_msg = UserMessage(
                        session_id=msg.session_id,
                        channel_id=self.channel_id,
                        content=msg.content,
                        attachments=msg.attachments,
                    )
                    yield user_msg
        except Exception as e:
            # Was a bare print(); use the module logger like send() does.
            logger.error(f"WebSocket receive error: {e}", exc_info=True)
            self._connected = False

    async def close(self) -> None:
        """Close the WebSocket connection."""
        self._connected = False
        if self._receive_task:
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass
        # Note: WebSocket close is handled by the main WebSocket endpoint

226
backend/src/gateway/hub.py Normal file
View File

@@ -0,0 +1,226 @@
import asyncio
import logging
from typing import Dict, Optional, Callable, Awaitable, Any
from datetime import datetime
from gateway.channels.base import Channel
from gateway.user_session import UserSession
from gateway.protocol import UserMessage, AgentMessage
logger = logging.getLogger(__name__)
class Gateway:
    """Central hub for routing messages between users, channels, and the agent.

    The Gateway:
    - Maintains active channels and user sessions
    - Routes user messages to the agent
    - Routes agent responses back to appropriate channels
    - Handles interruption of in-flight agent tasks
    """

    def __init__(self):
        self.channels: Dict[str, Channel] = {}
        self.sessions: Dict[str, UserSession] = {}
        # Annotation is quoted on purpose: AsyncIterator is not imported in
        # this module, and attribute annotations ARE evaluated at runtime, so
        # an unquoted reference raised NameError when Gateway() was built.
        # The executor is called (not awaited) and returns an async iterator
        # of response chunks — see _execute_agent_and_stream().
        self.agent_executor: Optional["Callable[[UserSession, UserMessage], AsyncIterator[str]]"] = None

    def set_agent_executor(
        self,
        executor: Callable[[UserSession, UserMessage], Any]
    ) -> None:
        """Set the agent executor function.

        Args:
            executor: Callable taking (session, message) that returns an async
                iterator of response chunks when invoked.
        """
        self.agent_executor = executor

    def register_channel(self, channel: Channel) -> None:
        """Register a new communication channel.

        Args:
            channel: Channel instance to register
        """
        self.channels[channel.channel_id] = channel

    def unregister_channel(self, channel_id: str) -> None:
        """Unregister a channel.

        Args:
            channel_id: ID of channel to remove
        """
        if channel_id in self.channels:
            # Remove channel from any sessions so later sends skip it
            for session in self.sessions.values():
                session.remove_channel(channel_id)
            del self.channels[channel_id]

    def get_or_create_session(self, session_id: str, user_id: str) -> UserSession:
        """Get existing session or create a new one.

        Args:
            session_id: Unique session identifier
            user_id: User identifier

        Returns:
            UserSession instance
        """
        if session_id not in self.sessions:
            self.sessions[session_id] = UserSession(session_id, user_id)
        return self.sessions[session_id]

    async def route_user_message(self, message: UserMessage) -> None:
        """Route a user message to the agent.

        If the agent is currently processing a message, it will be interrupted
        and restarted with the new message appended to the conversation.

        Args:
            message: UserMessage from a channel
        """
        logger.info(f"route_user_message called - session: {message.session_id}, channel: {message.channel_id}")
        # Get or create session. UserMessage carries no user_id, so the
        # session_id doubles as the user identifier here.
        session = self.get_or_create_session(message.session_id, message.session_id)
        logger.info(f"Session obtained/created: {message.session_id}, channels: {session.active_channels}")
        # Ensure channel is attached to session
        session.add_channel(message.channel_id)
        logger.info(f"Channel added to session: {message.channel_id}")
        # If agent is busy, interrupt it so the new message takes over
        if session.is_busy():
            logger.info(f"Session is busy, interrupting existing task")
            await session.interrupt()
        # Add user message to history
        session.add_message("user", message.content, message.channel_id)
        logger.info(f"User message added to history, history length: {len(session.get_history())}")
        # Start agent task
        if self.agent_executor:
            logger.info("Starting agent task execution")
            task = asyncio.create_task(
                self._execute_agent_and_stream(session, message)
            )
            session.set_task(task)
            logger.info(f"Agent task created and set on session")
        else:
            logger.error("No agent_executor configured! Cannot process message.")
            # Send error message to user
            error_msg = AgentMessage(
                session_id=session.session_id,
                target_channels=session.active_channels,
                content="Error: Agent system not initialized. Please check that ANTHROPIC_API_KEY is configured.",
                done=True
            )
            await self._send_to_channels(error_msg)

    async def _execute_agent_and_stream(self, session: UserSession, message: UserMessage) -> None:
        """Execute agent and stream responses back to channels.

        Args:
            session: User session
            message: Triggering user message
        """
        logger.info(f"_execute_agent_and_stream starting for session {session.session_id}")
        try:
            # Execute agent (returns async generator directly, not awaited)
            logger.info("Calling agent_executor...")
            response_stream = self.agent_executor(session, message)
            logger.info("Agent executor returned response stream")
            # Stream chunks back to active channels
            full_response = ""
            chunk_count = 0
            async for chunk in response_stream:
                chunk_count += 1
                full_response += chunk
                logger.debug(f"Received chunk #{chunk_count}, length: {len(chunk)}")
                # Send chunk to all active channels
                agent_msg = AgentMessage(
                    session_id=session.session_id,
                    target_channels=session.active_channels,
                    content=chunk,
                    stream_chunk=True,
                    done=False
                )
                await self._send_to_channels(agent_msg)
            logger.info(f"Agent streaming completed, total chunks: {chunk_count}, response length: {len(full_response)}")
            # Send final done message (empty content, done=True terminator)
            agent_msg = AgentMessage(
                session_id=session.session_id,
                target_channels=session.active_channels,
                content="",
                stream_chunk=True,
                done=True
            )
            await self._send_to_channels(agent_msg)
            logger.info("Sent final done message to channels")
            # Add to history
            session.add_message("assistant", full_response)
            logger.info("Assistant response added to history")
        except asyncio.CancelledError:
            # Task was interrupted (e.g., a newer user message arrived)
            logger.warning(f"Agent task interrupted for session {session.session_id}")
            raise
        except Exception as e:
            logger.error(f"Agent execution error: {e}", exc_info=True)
            # Send error message
            error_msg = AgentMessage(
                session_id=session.session_id,
                target_channels=session.active_channels,
                content=f"Error: {str(e)}",
                done=True
            )
            await self._send_to_channels(error_msg)
        finally:
            session.set_task(None)
            logger.info(f"Agent task completed for session {session.session_id}")

    async def _send_to_channels(self, message: AgentMessage) -> None:
        """Send message to specified channels.

        Args:
            message: AgentMessage to send
        """
        logger.debug(f"Sending message to {len(message.target_channels)} channels")
        for channel_id in message.target_channels:
            channel = self.channels.get(channel_id)
            if channel and channel.get_status().connected:
                try:
                    await channel.send(message)
                    logger.debug(f"Message sent to channel {channel_id}")
                except Exception as e:
                    # One failing channel must not stop delivery to the others
                    logger.error(f"Error sending to channel {channel_id}: {e}", exc_info=True)
            else:
                logger.warning(f"Channel {channel_id} not found or not connected")

    def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Get session information.

        Args:
            session_id: Session ID to query

        Returns:
            Session info dict or None if not found
        """
        session = self.sessions.get(session_id)
        return session.to_dict() if session else None

    def get_active_sessions(self) -> Dict[str, Dict[str, Any]]:
        """Get all active sessions.

        Returns:
            Dict mapping session_id to session info
        """
        return {
            sid: session.to_dict()
            for sid, session in self.sessions.items()
        }

View File

@@ -0,0 +1,57 @@
from datetime import datetime, timezone
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
class UserMessage(BaseModel):
    """Message from user to agent through a communication channel."""
    session_id: str  # Conversation this message belongs to
    channel_id: str  # Transport the message arrived on
    content: str
    attachments: List[str] = Field(default_factory=list, description="URLs or file paths")
    # Timezone-aware UTC by default (unlike the naive datetimes elsewhere).
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    metadata: Dict[str, Any] = Field(default_factory=dict)  # Free-form extras
class AgentMessage(BaseModel):
    """Message from agent to user(s) through one or more channels."""
    session_id: str
    target_channels: List[str] = Field(description="List of channel IDs to send to")
    content: str  # Chunk text; empty string on the final done=True terminator
    stream_chunk: bool = Field(default=False, description="True if this is a streaming chunk")
    done: bool = Field(default=False, description="True if streaming is complete")
    metadata: Dict[str, Any] = Field(default_factory=dict)
class ChannelStatus(BaseModel):
    """Status information about a communication channel."""
    channel_id: str
    channel_type: str = Field(description="Type: 'websocket', 'slack', 'telegram', etc.")
    connected: bool
    user_id: Optional[str] = None  # Populated when the channel is bound to a user
    session_id: Optional[str] = None  # Populated when attached to a session
    # Feature flags describing what this transport can render/carry.
    capabilities: Dict[str, bool] = Field(
        default_factory=lambda: {
            "streaming": False,
            "attachments": False,
            "markdown": False
        }
    )
# WebSocket-specific protocol extensions
class WebSocketAgentUserMessage(BaseModel):
    """WebSocket message: User → Backend (agent chat)."""
    # Literal discriminator matched by the endpoint's message dispatch.
    type: Literal["agent_user_message"] = "agent_user_message"
    session_id: str
    content: str
    attachments: List[str] = Field(default_factory=list)
class WebSocketAgentChunk(BaseModel):
    """WebSocket message: Backend → User (streaming agent response)."""
    type: Literal["agent_chunk"] = "agent_chunk"
    session_id: str
    content: str  # Chunk text; empty on the final frame
    done: bool = False  # True on the terminating frame of a stream
    metadata: Dict[str, Any] = Field(default_factory=dict)

View File

@@ -0,0 +1,107 @@
import asyncio
from datetime import datetime
from typing import Dict, List, Optional, Any
from pydantic import BaseModel, Field
class Message(BaseModel):
    """A single message in a conversation.

    Stored in UserSession.conversation_history; channel_id records the
    transport the message arrived on (None for agent-originated messages).
    """
    role: str  # "user" or "assistant"
    content: str
    # NOTE(review): naive UTC (datetime.utcnow); gateway/protocol.py uses
    # timezone-aware datetime.now(timezone.utc) — confirm mixing is intended.
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    channel_id: Optional[str] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
class UserSession:
    """Manages a user's conversation session with the agent.

    A session tracks:
    - Active communication channels
    - Conversation history
    - In-flight agent tasks (for interruption)
    - User metadata
    """

    def __init__(self, session_id: str, user_id: str):
        self.session_id = session_id
        self.user_id = user_id
        self.active_channels: List[str] = []
        # Quoted forward reference to the sibling Message model.
        self.conversation_history: List["Message"] = []
        # Currently running agent task, if any (see set_task / interrupt).
        self.current_task: Optional[asyncio.Task] = None
        self.metadata: Dict[str, Any] = {}
        self.created_at = datetime.utcnow()
        self.last_activity = datetime.utcnow()

    def add_channel(self, channel_id: str) -> None:
        """Attach a channel to this session (idempotent)."""
        if channel_id not in self.active_channels:
            self.active_channels.append(channel_id)
        self.last_activity = datetime.utcnow()

    def remove_channel(self, channel_id: str) -> None:
        """Detach a channel from this session (no-op if absent)."""
        if channel_id in self.active_channels:
            self.active_channels.remove(channel_id)
        self.last_activity = datetime.utcnow()

    def add_message(self, role: str, content: str, channel_id: Optional[str] = None, **kwargs) -> None:
        """Add a message to conversation history.

        Extra keyword arguments are stored as the message's metadata.
        """
        message = Message(
            role=role,
            content=content,
            channel_id=channel_id,
            metadata=kwargs
        )
        self.conversation_history.append(message)
        self.last_activity = datetime.utcnow()

    def get_history(self, limit: Optional[int] = None) -> List["Message"]:
        """Get conversation history.

        Args:
            limit: Maximum number of recent messages to return (None = all)

        Returns:
            List of Message objects
        """
        if limit:
            return self.conversation_history[-limit:]
        return self.conversation_history

    async def interrupt(self) -> bool:
        """Interrupt the current agent task if one is running.

        Returns:
            True if a task was interrupted, False otherwise
        """
        if self.current_task and not self.current_task.done():
            self.current_task.cancel()
            try:
                # Wait for the cancellation to actually land before returning
                await self.current_task
            except asyncio.CancelledError:
                pass
            self.current_task = None
            return True
        return False

    def set_task(self, task: Optional[asyncio.Task]) -> None:
        """Set (or clear, with None) the current agent task.

        Annotation corrected to Optional: the gateway calls set_task(None)
        when an agent task finishes.
        """
        self.current_task = task

    def is_busy(self) -> bool:
        """Check if the agent is currently processing a request."""
        return self.current_task is not None and not self.current_task.done()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize session to a JSON-friendly dict (for status endpoints)."""
        return {
            "session_id": self.session_id,
            "user_id": self.user_id,
            "active_channels": self.active_channels,
            "message_count": len(self.conversation_history),
            "is_busy": self.is_busy(),
            "created_at": self.created_at.isoformat(),
            "last_activity": self.last_activity.isoformat(),
            "metadata": self.metadata
        }

469
backend/src/main.py Normal file
View File

@@ -0,0 +1,469 @@
import asyncio
import json
import logging
import os
from contextlib import asynccontextmanager
from pathlib import Path
import yaml
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import uuid
import shutil
from sync.protocol import HelloMessage, PatchMessage
from sync.registry import SyncRegistry
from gateway.hub import Gateway
from gateway.channels.websocket import WebSocketChannel
from gateway.protocol import WebSocketAgentUserMessage
from agent.core import create_agent
from agent.tools import set_registry, set_datasource_registry
from schema.order_spec import SwapOrder
from schema.chart_state import ChartState
from datasource.registry import DataSourceRegistry
from datasource.subscription_manager import SubscriptionManager
from datasource.websocket_handler import DatafeedWebSocketHandler
# Configure logging for the whole backend process.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Load environment variables from .env file (if present)
env_path = Path(__file__).parent.parent / ".env"
if env_path.exists():
    load_dotenv(env_path)
# Load configuration from backend/config.yaml (ports, agent, memory settings).
config_path = Path(__file__).parent.parent / "config.yaml"
with open(config_path) as f:
    config = yaml.safe_load(f)
# Module-level singletons shared by the endpoint handlers below.
registry = SyncRegistry()
gateway = Gateway()
agent_executor = None  # Populated in lifespan() once the API key is available
# DataSource infrastructure
datasource_registry = DataSourceRegistry()
subscription_manager = SubscriptionManager()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize agent system and data sources on startup.

    Registers CCXT-backed market-data sources (best effort), then builds the
    agent executor if ANTHROPIC_API_KEY is present and wires it into the
    gateway. On shutdown, closes the agent's memory manager.
    """
    global agent_executor
    # Initialize CCXT data sources
    try:
        from datasource.adapters.ccxt_adapter import CCXTDataSource
        # Binance — failure here is non-fatal; the app runs without it.
        try:
            binance_source = CCXTDataSource(exchange_id="binance", poll_interval=60)
            datasource_registry.register("binance", binance_source)
            subscription_manager.register_source("binance", binance_source)
            logger.info("DataSource: Registered Binance source")
        except Exception as e:
            logger.warning(f"Failed to initialize Binance source: {e}")
        logger.info(f"DataSource infrastructure initialized with sources: {datasource_registry.list_sources()}")
    except ImportError as e:
        logger.warning(f"CCXT not available: {e}. Only demo source will be available.")
        logger.info("To use real exchange data, install ccxt: pip install ccxt>=4.0.0")
    # Get API keys from environment
    anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not anthropic_api_key:
        logger.error("ANTHROPIC_API_KEY not found in environment!")
        logger.info("Agent system will not be available")
    else:
        # Set the registries for agent tools
        set_registry(registry)
        set_datasource_registry(datasource_registry)
        # Create and initialize agent
        agent_executor = create_agent(
            model_name=config["agent"]["model"],
            temperature=config["agent"]["temperature"],
            api_key=anthropic_api_key,
            checkpoint_db_path=config["memory"]["checkpoint_db"],
            chroma_db_path=config["memory"]["chroma_db"],
            embedding_model=config["memory"]["embedding_model"],
            context_docs_dir=config["agent"]["context_docs_dir"],
            base_dir=".."  # Point to project root from backend/src
        )
        await agent_executor.initialize()
        # Set agent executor in gateway
        gateway.set_agent_executor(agent_executor.execute)
        logger.info("Agent system initialized")
    yield
    # Cleanup on application shutdown
    if agent_executor and agent_executor.memory_manager:
        await agent_executor.memory_manager.close()
    logger.info("Agent system shut down")
app = FastAPI(lifespan=lifespan)
# Create uploads directory (served back to clients via /uploads)
UPLOAD_DIR = Path(__file__).parent.parent / "uploads"
UPLOAD_DIR.mkdir(exist_ok=True)
# Mount static files for serving uploads
app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
# OrderStore model for synchronization (mirrored to the frontend via SyncRegistry)
class OrderStore(BaseModel):
    orders: list[SwapOrder] = []  # All known swap orders, patched incrementally
# ChartStore model for synchronization (mirrored to the frontend via SyncRegistry)
class ChartStore(BaseModel):
    chart_state: ChartState = ChartState()  # User's current chart view state
# Initialize the singleton store instances shared with connected clients.
order_store = OrderStore()
chart_store = ChartStore()
# Register with SyncRegistry so patches flow both ways over the WebSocket.
registry.register(order_store, store_name="OrderStore")
registry.register(chart_store, store_name="ChartStore")
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Single multiplexed WebSocket endpoint.

    Dispatches on the incoming JSON `type` field and handles three protocol
    families over one connection:
    - sync protocol ("hello", "patch") → SyncRegistry store synchronization
    - agent chat ("agent_user_message") → routed through the Gateway
    - datafeed ("get_config", "search_symbols", "resolve_symbol", "get_bars",
      "subscribe_bars", "unsubscribe_bars") → datasource_registry queries
    """
    await websocket.accept()
    registry.websocket = websocket
    # Create WebSocket channel for agent communication
    channel_id = f"ws_{id(websocket)}"
    client_id = f"client_{id(websocket)}"
    logger.info(f"WebSocket connected - channel_id: {channel_id}, client_id: {client_id}")
    ws_channel = WebSocketChannel(channel_id, websocket, session_id="default")
    gateway.register_channel(ws_channel)

    # Helper function to send responses (swallows transport errors, logging them)
    async def send_response(response):
        try:
            await websocket.send_json(response.model_dump(mode="json"))
        except Exception as e:
            logger.error(f"Error sending response: {e}")

    try:
        while True:
            data = await websocket.receive_text()
            logger.debug(f"Received WebSocket message: {data[:200]}...") # Log first 200 chars
            message_json = json.loads(data)
            if "type" not in message_json:
                logger.warning(f"Message missing 'type' field: {message_json}")
                continue
            msg_type = message_json["type"]
            logger.info(f"Processing message type: {msg_type}")
            # Handle sync protocol messages
            if msg_type == "hello":
                hello_msg = HelloMessage(**message_json)
                logger.info(f"Hello message received with seqs: {hello_msg.seqs}")
                await registry.sync_client(hello_msg.seqs)
            elif msg_type == "patch":
                patch_msg = PatchMessage(**message_json)
                logger.info(f"Patch message received for store: {patch_msg.store}, seq: {patch_msg.seq}")
                await registry.apply_client_patch(
                    store_name=patch_msg.store,
                    client_base_seq=patch_msg.seq,
                    patch=patch_msg.patch
                )
            elif msg_type == "agent_user_message":
                # Handle agent messages directly here
                # NOTE(review): bare print() debug statements left in — consider
                # removing once the logger output is confirmed sufficient.
                print(f"[DEBUG] Raw message_json: {message_json}")
                logger.info(f"Raw message_json: {message_json}")
                msg = WebSocketAgentUserMessage(**message_json)
                print(f"[DEBUG] Parsed message - session: {msg.session_id}, content: '{msg.content}' (len={len(msg.content)})")
                logger.info(f"Agent user message received - session: {msg.session_id}, content: '{msg.content}' (len={len(msg.content)})")
                from gateway.protocol import UserMessage
                from datetime import datetime, timezone
                user_msg = UserMessage(
                    session_id=msg.session_id,
                    channel_id=channel_id,
                    content=msg.content,
                    attachments=msg.attachments,
                    timestamp=datetime.now(timezone.utc)
                )
                logger.info(f"Routing user message to gateway - channel: {channel_id}, session: {msg.session_id}")
                await gateway.route_user_message(user_msg)
                logger.info("Message routing completed")
            # Handle datafeed protocol messages
            elif msg_type in ["get_config", "search_symbols", "resolve_symbol", "get_bars", "subscribe_bars", "unsubscribe_bars"]:
                from datasource.websocket_protocol import (
                    GetConfigRequest, GetConfigResponse,
                    SearchSymbolsRequest, SearchSymbolsResponse,
                    ResolveSymbolRequest, ResolveSymbolResponse,
                    GetBarsRequest, GetBarsResponse,
                    SubscribeBarsRequest, SubscribeBarsResponse,
                    UnsubscribeBarsRequest, UnsubscribeBarsResponse,
                    ErrorResponse
                )
                # Fallback request_id for the catch-all error handler below.
                request_id = message_json.get("request_id", "unknown")
                try:
                    if msg_type == "get_config":
                        req = GetConfigRequest(**message_json)
                        logger.info(f"Getting config, request_id={req.request_id}")
                        sources = datasource_registry.list_sources()
                        logger.info(f"Available sources: {sources}")
                        if not sources:
                            error_response = ErrorResponse(request_id=req.request_id, error_code="NO_SOURCES", error_message="No data sources available")
                            await send_response(error_response)
                        else:
                            # Get config from first source (we can enhance this later to aggregate)
                            source = datasource_registry.get(sources[0])
                            if source:
                                try:
                                    # NOTE(review): this local `config` shadows the
                                    # module-level config dict within this function.
                                    config = await source.get_config()
                                    logger.info(f"Got config from {sources[0]}")
                                    # Enhance with all available exchanges
                                    all_exchanges = set()
                                    for source_name in sources:
                                        s = datasource_registry.get(source_name)
                                        if s:
                                            try:
                                                cfg = await asyncio.wait_for(s.get_config(), timeout=5.0)
                                                all_exchanges.update(cfg.exchanges)
                                                logger.info(f"Added exchanges from {source_name}: {cfg.exchanges}")
                                            except asyncio.TimeoutError:
                                                logger.warning(f"Timeout getting config from {source_name}")
                                            except Exception as e:
                                                logger.warning(f"Error getting config from {source_name}: {e}")
                                    config_dict = config.model_dump(mode="json")
                                    config_dict["exchanges"] = list(all_exchanges)
                                    logger.info(f"Sending config with exchanges: {list(all_exchanges)}")
                                    response = GetConfigResponse(request_id=req.request_id, config=config_dict)
                                    await send_response(response)
                                except Exception as e:
                                    logger.error(f"Error getting config: {e}", exc_info=True)
                                    error_response = ErrorResponse(request_id=req.request_id, error_code="ERROR", error_message=str(e))
                                    await send_response(error_response)
                            else:
                                error_response = ErrorResponse(request_id=req.request_id, error_code="SOURCE_NOT_FOUND", error_message="Data sources not available")
                                await send_response(error_response)
                    elif msg_type == "search_symbols":
                        req = SearchSymbolsRequest(**message_json)
                        logger.info(f"Searching symbols: query='{req.query}', request_id={req.request_id}")
                        # Search all data sources; per-source failures are non-fatal
                        all_results = []
                        sources = datasource_registry.list_sources()
                        logger.info(f"Available data sources: {sources}")
                        for source_name in sources:
                            source = datasource_registry.get(source_name)
                            if source:
                                try:
                                    results = await asyncio.wait_for(
                                        source.search_symbols(query=req.query, type=req.symbol_type, exchange=req.exchange, limit=req.limit),
                                        timeout=5.0
                                    )
                                    all_results.extend([r.model_dump(mode="json") for r in results])
                                    logger.info(f"Source '{source_name}' returned {len(results)} results")
                                except asyncio.TimeoutError:
                                    logger.warning(f"Timeout searching source '{source_name}'")
                                except Exception as e:
                                    logger.warning(f"Error searching source '{source_name}': {e}")
                        logger.info(f"Total search results: {len(all_results)}")
                        response = SearchSymbolsResponse(request_id=req.request_id, results=all_results[:req.limit])
                        await send_response(response)
                    elif msg_type == "resolve_symbol":
                        req = ResolveSymbolRequest(**message_json)
                        logger.info(f"Resolving symbol: {req.symbol}")
                        # Parse ticker format: "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD")
                        symbol = req.symbol
                        source_name = None
                        symbol_without_exchange = symbol
                        # Check if ticker has exchange prefix
                        if ":" in symbol:
                            exchange_prefix, symbol_without_exchange = symbol.split(":", 1)
                            source_name = exchange_prefix.lower()
                            logger.info(f"Parsed ticker: exchange={source_name}, symbol={symbol_without_exchange}")
                        # If we identified a source, try it directly
                        if source_name:
                            try:
                                source = datasource_registry.get(source_name)
                                if source:
                                    logger.info(f"Trying to resolve '{symbol_without_exchange}' in source '{source_name}'")
                                    symbol_info = await asyncio.wait_for(
                                        source.resolve_symbol(symbol_without_exchange),
                                        timeout=5.0
                                    )
                                    logger.info(f"Successfully resolved '{symbol_without_exchange}' in source '{source_name}'")
                                    response = ResolveSymbolResponse(request_id=req.request_id, symbol_info=symbol_info.model_dump(mode="json"))
                                    await send_response(response)
                                else:
                                    error_response = ErrorResponse(request_id=req.request_id, error_code="SOURCE_NOT_FOUND", error_message=f"Data source '{source_name}' not found")
                                    await send_response(error_response)
                            except asyncio.TimeoutError:
                                logger.warning(f"Timeout resolving '{symbol_without_exchange}' in source '{source_name}'")
                                error_response = ErrorResponse(request_id=req.request_id, error_code="TIMEOUT", error_message=f"Timeout resolving symbol")
                                await send_response(error_response)
                            except Exception as e:
                                logger.warning(f"Error resolving '{symbol_without_exchange}' in source '{source_name}': {e}")
                                error_response = ErrorResponse(request_id=req.request_id, error_code="SYMBOL_NOT_FOUND", error_message=str(e))
                                await send_response(error_response)
                        else:
                            # No exchange prefix, try all sources until one resolves
                            found = False
                            for src in datasource_registry.list_sources():
                                try:
                                    s = datasource_registry.get(src)
                                    if s:
                                        logger.info(f"Trying to resolve '{symbol}' in source '{src}'")
                                        symbol_info = await asyncio.wait_for(s.resolve_symbol(symbol), timeout=5.0)
                                        if symbol_info:
                                            logger.info(f"Successfully resolved '{symbol}' in source '{src}'")
                                            response = ResolveSymbolResponse(request_id=req.request_id, symbol_info=symbol_info.model_dump(mode="json"))
                                            await send_response(response)
                                            found = True
                                            break
                                except asyncio.TimeoutError:
                                    logger.warning(f"Timeout resolving '{symbol}' in source '{src}'")
                                except Exception as e:
                                    logger.info(f"Symbol '{symbol}' not found in source '{src}': {e}")
                                    continue
                            if not found:
                                # Symbol not found in any source
                                logger.warning(f"Symbol '{symbol}' not found in any data source")
                                error_response = ErrorResponse(request_id=req.request_id, error_code="SYMBOL_NOT_FOUND", error_message=f"Symbol '{symbol}' not found in any data source")
                                await send_response(error_response)
                    elif msg_type == "get_bars":
                        req = GetBarsRequest(**message_json)
                        logger.info(f"Getting bars for symbol: {req.symbol}")
                        # Parse ticker format: "EXCHANGE:SYMBOL"
                        symbol = req.symbol
                        source_name = None
                        symbol_without_exchange = symbol
                        # Check if ticker has exchange prefix
                        if ":" in symbol:
                            exchange_prefix, symbol_without_exchange = symbol.split(":", 1)
                            source_name = exchange_prefix.lower()
                            logger.info(f"Parsed ticker for bars: exchange={source_name}, symbol={symbol_without_exchange}")
                        # If we identified a source, use it directly
                        if source_name:
                            try:
                                source = datasource_registry.get(source_name)
                                if source:
                                    logger.info(f"Getting bars for '{symbol_without_exchange}' from source '{source_name}'")
                                    history = await asyncio.wait_for(
                                        source.get_bars(symbol=symbol_without_exchange, resolution=req.resolution, from_time=req.from_time, to_time=req.to_time, countback=req.countback),
                                        timeout=10.0
                                    )
                                    logger.info(f"Successfully got {len(history.bars)} bars for '{symbol_without_exchange}' from source '{source_name}'")
                                    response = GetBarsResponse(request_id=req.request_id, history=history.model_dump(mode="json"))
                                    await send_response(response)
                                else:
                                    error_response = ErrorResponse(request_id=req.request_id, error_code="SOURCE_NOT_FOUND", error_message=f"Data source '{source_name}' not found")
                                    await send_response(error_response)
                            except asyncio.TimeoutError:
                                logger.warning(f"Timeout getting bars for '{symbol_without_exchange}' from source '{source_name}'")
                                error_response = ErrorResponse(request_id=req.request_id, error_code="TIMEOUT", error_message="Timeout fetching bars")
                                await send_response(error_response)
                            except Exception as e:
                                logger.warning(f"Error getting bars for '{symbol_without_exchange}' from source '{source_name}': {e}")
                                error_response = ErrorResponse(request_id=req.request_id, error_code="ERROR", error_message=str(e))
                                await send_response(error_response)
                        else:
                            # No exchange prefix - this shouldn't happen with proper tickers
                            logger.warning(f"Ticker '{symbol}' has no exchange prefix")
                            error_response = ErrorResponse(request_id=req.request_id, error_code="INVALID_TICKER", error_message="Ticker must include exchange prefix (e.g., BINANCE:BTC/USDT)")
                            await send_response(error_response)
                    elif msg_type == "subscribe_bars":
                        req = SubscribeBarsRequest(**message_json)
                        # TODO: Implement subscription management (currently acks without subscribing)
                        response = SubscribeBarsResponse(request_id=req.request_id, subscription_id=req.subscription_id, success=True)
                        await send_response(response)
                    elif msg_type == "unsubscribe_bars":
                        req = UnsubscribeBarsRequest(**message_json)
                        response = UnsubscribeBarsResponse(request_id=req.request_id, subscription_id=req.subscription_id, success=True)
                        await send_response(response)
                except Exception as e:
                    # Catch-all so one malformed datafeed request can't kill the socket
                    logger.error(f"Error handling {msg_type}: {e}", exc_info=True)
                    error_response = ErrorResponse(request_id=request_id, error_code="INTERNAL_ERROR", error_message=str(e))
                    await send_response(error_response)
    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected - channel_id: {channel_id}")
        registry.websocket = None
        gateway.unregister_channel(channel_id)
    except Exception as e:
        logger.error(f"WebSocket error: {e}", exc_info=True)
        registry.websocket = None
        gateway.unregister_channel(channel_id)
@app.post("/api/upload")
async def upload_file(file: UploadFile = File(...)):
    """Persist an uploaded file under a UUID-based name and return its URL.

    Returns a dict with ``url`` (path relative to the backend), the original
    ``filename``, and the stored ``size`` in bytes.  Any failure is reported
    to the caller as HTTP 500.
    """
    try:
        # Keep the caller's extension but replace the name with a UUID so
        # stored files can never collide and the client cannot pick the path.
        suffix = Path(file.filename).suffix if file.filename else ""
        stored_name = f"{uuid.uuid4()}{suffix}"
        destination = UPLOAD_DIR / stored_name

        # Stream the upload to disk without loading it all into memory.
        with open(destination, "wb") as out:
            shutil.copyfileobj(file.file, out)

        file_url = f"/uploads/{stored_name}"
        logger.info(f"File uploaded: {file.filename} -> {file_url}")
        return {
            "url": file_url,
            "filename": file.filename,
            "size": destination.stat().st_size,
        }
    except Exception as e:
        logger.error(f"File upload error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/healthz")
async def health():
    """Liveness probe endpoint: always reports the service as up."""
    payload = {"status": "ok"}
    return payload
# Background task to simulate backend updates (optional, for demo)
async def simulate_backend_updates():
    """Every five seconds, push any pending state diffs to the connected client."""
    while True:
        await asyncio.sleep(5)
        if not registry.websocket:
            continue
        # Example: could add/modify orders here
        await registry.push_all()
if __name__ == "__main__":
    import uvicorn
    # Serve the FastAPI app on all interfaces; the port comes from config.yaml (www_port).
    uvicorn.run(app, host="0.0.0.0", port=config["www_port"])

View File

@@ -0,0 +1,21 @@
from typing import Optional
from pydantic import BaseModel, Field
class ChartState(BaseModel):
    """Tracks the user's current chart view state.

    This state is synchronized between the frontend and backend to allow
    the AI agent to understand what the user is currently viewing.
    """
    # Current symbol being viewed (e.g., "BINANCE:BTC/USDT", "BINANCE:ETH/USDT")
    symbol: str = Field(default="BINANCE:BTC/USDT", description="Current trading pair symbol")
    # Time range currently visible on chart (Unix timestamps in seconds)
    # These represent the leftmost and rightmost visible candle times
    # None presumably means the frontend has not reported a range yet — confirm against caller.
    start_time: Optional[int] = Field(default=None, description="Start time of visible range (Unix timestamp in seconds)")
    end_time: Optional[int] = Field(default=None, description="End time of visible range (Unix timestamp in seconds)")
    # Optional: Chart interval/resolution
    # NOTE(review): values look like TradingView resolutions (minutes as strings, "D" = daily) — confirm.
    interval: str = Field(default="15", description="Chart interval (e.g., '1', '5', '15', '60', 'D')")

View File

@@ -0,0 +1,140 @@
from enum import StrEnum
from typing import Annotated
from pydantic import BaseModel, Field, BeforeValidator, PlainSerializer
# ---------------------------------------------------------------------------
# Scalar coercion helpers
# ---------------------------------------------------------------------------
def _to_int(v: int | str) -> int:
return int(v, 0) if isinstance(v, str) else int(v)
def _to_float(v: float | int | str) -> float:
return float(v)
# JSON serializers that emit numeric fields as strings — presumably so
# uint256-scale values survive JavaScript clients intact; TODO confirm.
_int_to_str = PlainSerializer(str, return_type=str, when_used="json")
_float_to_str = PlainSerializer(str, return_type=str, when_used="json")
# Always stored as Python int; accepts int or string on input; serialises to string in JSON.
# NOTE(review): the width names (Uint8..Uint256) document the intended on-chain
# field width only — no range validation is performed by these aliases.
type Uint8 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
type Uint16 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
type Uint24 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
type Uint32 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
type Uint64 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
type Uint256 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
type Float = Annotated[float, BeforeValidator(_to_float), _float_to_str]
# 0x-prefixed 20-byte hex address (case-insensitive; no checksum validation).
ETH_ADDRESS_PATTERN = r"^0x[0-9a-fA-F]{40}$"
# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------
class Exchange(StrEnum):
    """DEX protocols a Route can trade through."""
    UNISWAP_V2 = "UniswapV2"
    UNISWAP_V3 = "UniswapV3"
class OcoMode(StrEnum):
    """One-cancels-other behavior for an OcoGroup.

    NOTE(review): names suggest sibling orders are canceled either when one
    order partially fills or only when one completes — confirm against the
    on-chain contract semantics.
    """
    NO_OCO = "NO_OCO"
    CANCEL_ON_PARTIAL_FILL = "CANCEL_ON_PARTIAL_FILL"
    CANCEL_ON_COMPLETION = "CANCEL_ON_COMPLETION"
# ---------------------------------------------------------------------------
# Supporting models
# ---------------------------------------------------------------------------
class Route(BaseModel):
    """Routing information for a swap: which exchange and fee tier to use."""
    model_config = {"extra": "forbid"}
    exchange: Exchange
    fee: Uint24 = Field(description="Pool fee tier; also used as maxFee on UniswapV3")
class Line(BaseModel):
    """Price line: price = intercept + slope * time. Both zero means line is disabled."""
    model_config = {"extra": "forbid"}
    intercept: Float
    # Slope units are presumably price-per-second (time is a Unix timestamp) — TODO confirm.
    slope: Float
class Tranche(BaseModel):
    """One slice of a SwapOrder with its own timing, pricing, and rate limits."""
    model_config = {"extra": "forbid"}
    fraction: Uint16 = Field(description="Fraction of total order amount; MAX_FRACTION (65535) = 100%")
    # Relative times are presumably resolved to concrete timestamps at
    # activation (see TrancheStatus.startTime/endTime) — TODO confirm.
    startTimeIsRelative: bool
    endTimeIsRelative: bool
    minIsBarrier: bool = Field(description="Not yet supported")
    maxIsBarrier: bool = Field(description="Not yet supported")
    marketOrder: bool = Field(
        description="If true, min/max lines ignored; minLine intercept treated as max slippage"
    )
    minIsRatio: bool
    maxIsRatio: bool
    rateLimitFraction: Uint16 = Field(description="Max fraction of this tranche's amount per rate-limited execution")
    rateLimitPeriod: Uint24 = Field(description="Seconds between rate limit resets")
    startTime: Uint32 = Field(description="Unix timestamp; 0 (DISTANT_PAST) effectively disables")
    endTime: Uint32 = Field(description="Unix timestamp; 4294967295 (DISTANT_FUTURE) effectively disables")
    minLine: Line = Field(description="Traditional limit order constraint; can be diagonal")
    maxLine: Line = Field(description="Upper price boundary (too-good-a-price guard)")
class TrancheStatus(BaseModel):
    """Execution progress for a single tranche of an order."""
    model_config = {"extra": "forbid"}
    filled: Uint256 = Field(description="Amount filled by this tranche")
    activationTime: Uint32 = Field(description="Earliest time this tranche can execute; 0 = not yet concrete")
    startTime: Uint32 = Field(description="Concrete start timestamp")
    endTime: Uint32 = Field(description="Concrete end timestamp")
# ---------------------------------------------------------------------------
# Order models
# ---------------------------------------------------------------------------
class SwapOrder(BaseModel):
    """A token swap order composed of one or more tranches."""
    model_config = {"extra": "forbid"}
    tokenIn: str = Field(pattern=ETH_ADDRESS_PATTERN, description="ERC-20 input token address")
    tokenOut: str = Field(pattern=ETH_ADDRESS_PATTERN, description="ERC-20 output token address")
    route: Route
    amount: Uint256 = Field(description="Maximum quantity to fill")
    minFillAmount: Uint256 = Field(description="Minimum tranche amount before tranche is considered complete")
    amountIsInput: bool = Field(description="true = amount is tokenIn quantity; false = tokenOut")
    outputDirectlyToOwner: bool = Field(description="true = proceeds go to vault owner; false = vault")
    inverted: bool = Field(description="false = tokenIn/tokenOut price direction (Uniswap natural)")
    conditionalOrder: Uint64 = Field(
        description="NO_CONDITIONAL_ORDER = 2^64-1; high bit set = relative index within placement group"
    )
    tranches: list[Tranche] = Field(min_length=1)
class OcoGroup(BaseModel):
    """A set of orders linked by one-cancels-other semantics (see OcoMode)."""
    model_config = {"extra": "forbid"}
    mode: OcoMode
    orders: list[SwapOrder] = Field(min_length=1)
class SwapOrderStatus(BaseModel):
    """Runtime status of a SwapOrder, including per-tranche progress."""
    model_config = {"extra": "forbid"}
    order: SwapOrder
    fillFeeHalfBps: Uint8 = Field(description="Fill fee in half-bps (1/20000); max 255 = 1.275%")
    canceled: bool = Field(description="If true, order is canceled regardless of cancelAllIndex")
    startTime: Uint32 = Field(description="Earliest block.timestamp at which order may execute")
    ocoGroup: Uint64 = Field(description="Index into ocoGroups; NO_OCO_INDEX = 2^64-1")
    originalOrder: Uint64 = Field(description="Index of the original order in the orders array")
    startPrice: Uint256 = Field(description="Price at order start")
    filled: Uint256 = Field(description="Total amount filled so far")
    # One entry per tranche in order.tranches — presumably index-aligned; TODO confirm.
    trancheStatus: list[TrancheStatus]

View File

@@ -0,0 +1,26 @@
from typing import Any, Dict, List, Literal, Union
from pydantic import BaseModel
class SnapshotMessage(BaseModel):
    """Full state of one store at sequence *seq* (backend -> frontend)."""
    type: Literal["snapshot"] = "snapshot"
    store: str
    seq: int
    state: Dict[str, Any]
class PatchMessage(BaseModel):
    """A JSON-patch (RFC 6902 op list) advancing a store to sequence *seq*.

    Sent in both directions; for frontend->backend, *seq* is the sequence the
    patch was based on (see SyncRegistry.apply_client_patch).
    """
    type: Literal["patch"] = "patch"
    store: str
    seq: int
    patch: List[Dict[str, Any]]
class HelloMessage(BaseModel):
    """Frontend greeting: last-seen sequence number per store (frontend -> backend)."""
    type: Literal["hello"] = "hello"
    seqs: Dict[str, int]
# Union type for all messages from backend to frontend
BackendMessage = Union[SnapshotMessage, PatchMessage]
# Union type for all messages from frontend to backend
# (hello is sent once on connect; patches follow as the user edits state)
FrontendMessage = Union[HelloMessage, PatchMessage]

View File

@@ -0,0 +1,174 @@
from collections import deque
from typing import Any, Dict, List, Optional, Tuple, Deque
import jsonpatch
from pydantic import BaseModel
from sync.protocol import SnapshotMessage, PatchMessage
class SyncEntry:
    """Sync bookkeeping for one pydantic model ("store").

    Tracks the model's sequence number, the last snapshot that was committed,
    and a bounded history of JSON patches so reconnecting clients can be
    caught up without a full snapshot.
    """

    def __init__(self, model: "BaseModel", store_name: str, history_size: int = 50):
        self.model = model
        self.store_name = store_name
        # Increments once per committed patch; 0 = initial state.
        self.seq = 0
        self.last_snapshot = model.model_dump(mode="json")
        # (seq, patch) pairs; maxlen bounds memory for long-lived stores.
        self.history: Deque[Tuple[int, List[Dict[str, Any]]]] = deque(maxlen=history_size)

    def compute_patch(self) -> Optional[List[Dict[str, Any]]]:
        """Diff the live model against the last committed snapshot.

        Returns a JSON-patch op list, or None when nothing changed.
        """
        current_state = self.model.model_dump(mode="json")
        patch = jsonpatch.make_patch(self.last_snapshot, current_state)
        if not patch.patch:
            return None
        return patch.patch

    def commit_patch(self, patch: List[Dict[str, Any]]):
        """Record *patch* as the next sequence step and refresh the snapshot."""
        self.seq += 1
        self.history.append((self.seq, patch))
        self.last_snapshot = self.model.model_dump(mode="json")

    def catchup_patches(self, since_seq: int) -> Optional[List[Tuple[int, List[Dict[str, Any]]]]]:
        """Return the (seq, patch) pairs a client at *since_seq* is missing.

        Returns [] when the client is already current, and None when the gap
        cannot be bridged from history — the caller must send a full snapshot.
        """
        if since_seq == self.seq:
            return []
        # Bug fix: a client claiming a seq AHEAD of ours (e.g. the server
        # restarted and reset the counter) previously fell through and got [],
        # leaving it desynced forever.  Force a snapshot instead.
        if since_seq > self.seq:
            return None
        # All patches from since_seq + 1 .. self.seq must still be in history.
        if not self.history or self.history[0][0] > since_seq + 1:
            return None
        return [(seq, patch) for seq, patch in self.history if seq > since_seq]
class SyncRegistry:
    """Holds all synchronized stores and pushes diffs to the websocket client.

    NOTE(review): only a single websocket is tracked; a second connection
    would replace the first — confirm single-client operation is intended.
    """
    def __init__(self):
        # store_name -> SyncEntry
        self.entries: Dict[str, SyncEntry] = {}
        self.websocket: Optional[Any] = None  # Expecting a FastAPI WebSocket or similar
    def register(self, model: BaseModel, store_name: str):
        """Begin tracking *model* under *store_name* with a fresh SyncEntry."""
        self.entries[store_name] = SyncEntry(model, store_name)
    async def push_all(self):
        """Diff every store against its last snapshot and send patch messages.

        No-op when no client is connected; unchanged stores send nothing.
        """
        import logging
        logger = logging.getLogger(__name__)
        if not self.websocket:
            logger.warning("push_all: No websocket connected, cannot push updates")
            return
        logger.info(f"push_all: Processing {len(self.entries)} store entries")
        for entry in self.entries.values():
            patch = entry.compute_patch()
            if patch:
                logger.info(f"push_all: Found patch for store '{entry.store_name}': {patch}")
                entry.commit_patch(patch)
                msg = PatchMessage(store=entry.store_name, seq=entry.seq, patch=patch)
                logger.info(f"push_all: Sending patch message for '{entry.store_name}' seq={entry.seq}")
                await self.websocket.send_json(msg.model_dump(mode="json"))
                logger.info(f"push_all: Patch sent successfully for '{entry.store_name}'")
            else:
                logger.debug(f"push_all: No changes detected for store '{entry.store_name}'")
    async def sync_client(self, client_seqs: Dict[str, int]):
        """Bring a (re)connecting client up to date on every store.

        Replays missed patches when the entry's history covers the gap,
        otherwise falls back to a full snapshot.
        """
        if not self.websocket:
            return
        for store_name, entry in self.entries.items():
            # -1 = "client has never seen this store", guaranteeing a snapshot.
            client_seq = client_seqs.get(store_name, -1)
            patches = entry.catchup_patches(client_seq)
            if patches is not None:
                # Replay patches
                for seq, patch in patches:
                    msg = PatchMessage(store=store_name, seq=seq, patch=patch)
                    await self.websocket.send_json(msg.model_dump(mode="json"))
            else:
                # Send full snapshot
                msg = SnapshotMessage(
                    store=store_name,
                    seq=entry.seq,
                    state=entry.model.model_dump(mode="json")
                )
                await self.websocket.send_json(msg.model_dump(mode="json"))
    async def apply_client_patch(self, store_name: str, client_base_seq: int, patch: List[Dict[str, Any]]):
        """Apply a patch sent by the frontend, resolving conflicts frontend-first.

        Cases:
        - client_base_seq == entry.seq: no conflict; apply and commit, without
          echoing back to the originating client.
        - client_base_seq <  entry.seq: conflict; the frontend's ops win on the
          paths it touched, non-overlapping backend ops are re-applied on top,
          and a snapshot is broadcast so both sides converge.
        - client_base_seq >  entry.seq: NOTE(review): falls through all
          branches — the patch is silently dropped with no resync. Confirm.
        """
        import logging
        logger = logging.getLogger(__name__)
        logger.info(f"apply_client_patch: store={store_name}, client_base_seq={client_base_seq}, patch={patch}")
        entry = self.entries.get(store_name)
        if not entry:
            logger.warning(f"apply_client_patch: Store '{store_name}' not found in registry")
            return
        logger.info(f"apply_client_patch: Current backend seq={entry.seq}")
        if client_base_seq == entry.seq:
            # No conflict
            logger.info("apply_client_patch: No conflict - applying patch directly")
            current_state = entry.model.model_dump(mode="json")
            logger.info(f"apply_client_patch: Current state before patch: {current_state}")
            new_state = jsonpatch.apply_patch(current_state, patch)
            logger.info(f"apply_client_patch: New state after patch: {new_state}")
            self._update_model(entry.model, new_state)
            entry.commit_patch(patch)
            logger.info(f"apply_client_patch: Patch committed, new seq={entry.seq}")
            # Don't broadcast back to client - they already have this change
            # Broadcasting would cause an infinite loop
            logger.info("apply_client_patch: Not broadcasting back to originating client")
        elif client_base_seq < entry.seq:
            # Conflict! Frontend wins.
            # 1. Get backend patches since client_base_seq
            backend_patches = []
            for seq, p in entry.history:
                if seq > client_base_seq:
                    backend_patches.append(p)
            # 2. Apply frontend patch first to the state at client_base_seq
            # But we only have the current authoritative model.
            # "Apply the frontend patch first to the model (frontend wins)"
            # "Re-apply the backend deltas that do not overlap the frontend's changed paths on top"
            # Let's get the state as it was at client_base_seq if possible?
            # No, history only has patches.
            # Alternative: Apply frontend patch to current model.
            # Then re-apply backend patches, but discard parts that overlap.
            # NOTE(review): the overlap test below compares exact JSON-pointer
            # paths only; a backend op on a CHILD path of a frontend-changed
            # path is not filtered out — confirm this is acceptable.
            frontend_paths = {p['path'] for p in patch}
            current_state = entry.model.model_dump(mode="json")
            # Apply frontend patch
            new_state = jsonpatch.apply_patch(current_state, patch)
            # Re-apply backend patches that don't overlap
            for b_patch in backend_patches:
                filtered_b_patch = [op for op in b_patch if op['path'] not in frontend_paths]
                if filtered_b_patch:
                    new_state = jsonpatch.apply_patch(new_state, filtered_b_patch)
            self._update_model(entry.model, new_state)
            # Commit the result as a single new patch
            # We need to compute what changed from last_snapshot to new_state
            final_patch = jsonpatch.make_patch(entry.last_snapshot, new_state).patch
            if final_patch:
                entry.commit_patch(final_patch)
            # Broadcast resolved state as snapshot to converge
            if self.websocket:
                msg = SnapshotMessage(
                    store=entry.store_name,
                    seq=entry.seq,
                    state=entry.model.model_dump(mode="json")
                )
                await self.websocket.send_json(msg.model_dump(mode="json"))
    def _update_model(self, model: BaseModel, new_data: Dict[str, Any]):
        """Copy *new_data* into *model* field by field, in place, so external
        references to the model object remain valid."""
        # Update model using model_validate for potentially nested models
        new_model = model.__class__.model_validate(new_data)
        for field in model.model_fields:
            setattr(model, field, getattr(new_model, field))

View File

View File

@@ -0,0 +1,54 @@
import asyncio
import json
import websockets
import pytest
from pydantic import BaseModel
async def test_websocket_sync():
    """End-to-end check of the hello/snapshot/patch sync protocol.

    Requires the backend server to be running on localhost:8000.  Raises
    AssertionError (or TimeoutError) on protocol violations.
    """
    uri = "ws://localhost:8000/ws"
    async with websockets.connect(uri) as websocket:
        # 1. Send hello with no known sequence numbers so the server replies
        #    with a full snapshot for every store.
        hello = {
            "type": "hello",
            "seqs": {}
        }
        await websocket.send(json.dumps(hello))
        # 2. Receive snapshots
        # Expecting TraderState and StrategyState.  Each recv is bounded by a
        # timeout so a misbehaving server fails the test instead of hanging it.
        responses = []
        for _ in range(2):
            resp = await asyncio.wait_for(websocket.recv(), timeout=10)
            responses.append(json.loads(resp))
        assert any(r["store"] == "TraderState" for r in responses)
        assert any(r["store"] == "StrategyState" for r in responses)
        # 3. Send a patch for TraderState based on the server's current seq
        trader_resp = next(r for r in responses if r["store"] == "TraderState")
        current_seq = trader_resp["seq"]
        patch_msg = {
            "type": "patch",
            "store": "TraderState",
            "seq": current_seq,
            "patch": [{"op": "replace", "path": "/status", "value": "busy"}]
        }
        await websocket.send(json.dumps(patch_msg))
        # 4. Receive confirmation patch
        # NOTE(review): SyncRegistry deliberately does not echo no-conflict
        # patches back to the originator — confirm the server actually sends
        # this confirmation, otherwise this step times out by design.
        confirm_resp = await asyncio.wait_for(websocket.recv(), timeout=10)
        confirm_json = json.loads(confirm_resp)
        assert confirm_json["type"] == "patch"
        assert confirm_json["store"] == "TraderState"
        assert confirm_json["seq"] == current_seq + 1
        assert confirm_json["patch"][0]["value"] == "busy"
if __name__ == "__main__":
    # This script requires the server to be running:
    #   PYTHONPATH=backend/src python3 backend/src/main.py
    try:
        asyncio.run(test_websocket_sync())
        print("Test passed!")
    except Exception as e:
        print(f"Test failed: {e}")
        # Bug fix: previously a failure still exited with status 0; exit
        # nonzero so shell/CI callers can detect it.
        raise SystemExit(1) from e