backend redesign
4  backend.old/src/agent/__init__.py  Normal file
@@ -0,0 +1,4 @@
# Don't import at module level to avoid circular imports
# Users should import directly: from agent.core import create_agent

__all__ = ["core", "tools"]
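Since nothing is re-exported at package level, callers follow the direct-import pattern the comment describes. A minimal sketch (the call site itself is assumed, not part of this commit):

```python
# Hypothetical call site; only the import path is taken from this file's comment.
from agent.core import create_agent

executor = create_agent()  # returns an AgentExecutor; call `await executor.initialize()` before use
```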
379  backend.old/src/agent/core.py  Normal file
@@ -0,0 +1,379 @@
import asyncio
import logging
from typing import AsyncIterator, Dict, Any, Optional, Union

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import create_react_agent

from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS, INDICATOR_TOOLS, RESEARCH_TOOLS, CHART_TOOLS, SHAPE_TOOLS, TRIGGER_TOOLS
from agent.memory import MemoryManager
from agent.session import SessionManager
from agent.prompts import build_system_prompt
from agent.subagent import SubAgent
from agent.routers import ROUTER_TOOLS, set_chart_agent, set_data_agent, set_automation_agent, set_research_agent
from gateway.user_session import UserSession
from gateway.protocol import UserMessage as GatewayUserMessage

logger = logging.getLogger(__name__)


class AgentExecutor:
    """LangGraph-based agent executor with streaming support.

    Handles agent invocation, tool execution, and response streaming.
    Supports interruption for real-time user interaction.
    """

    def __init__(
        self,
        model_name: str = "claude-sonnet-4-20250514",
        temperature: float = 0.7,
        api_key: Optional[str] = None,
        memory_manager: Optional[MemoryManager] = None,
        base_dir: str = "."
    ):
        """Initialize agent executor.

        Args:
            model_name: Anthropic model name
            temperature: Model temperature
            api_key: Anthropic API key
            memory_manager: MemoryManager instance
            base_dir: Base directory for resolving paths
        """
        self.model_name = model_name
        self.temperature = temperature
        self.api_key = api_key
        self.base_dir = base_dir

        # Initialize LLM
        self.llm = ChatAnthropic(
            model=model_name,
            temperature=temperature,
            api_key=api_key,
            streaming=True
        )

        # Memory and session management
        self.memory_manager = memory_manager or MemoryManager()
        self.session_manager = SessionManager(self.memory_manager)
        self.agent = None  # Will be created after initialization

        # Sub-agents (only if using hierarchical tools)
        self.chart_agent = None
        self.data_agent = None
        self.automation_agent = None
        self.research_agent = None

    async def initialize(self) -> None:
        """Initialize the agent system."""
        await self.memory_manager.initialize()

        # Create agent with tools and LangGraph checkpointer
        checkpointer = self.memory_manager.get_checkpointer()

        # Create specialized sub-agents
        logger.info("Initializing hierarchical agent architecture with sub-agents")

        self.chart_agent = SubAgent(
            name="chart",
            soul_file="chart_agent.md",
            tools=CHART_TOOLS + INDICATOR_TOOLS + SHAPE_TOOLS,
            model_name=self.model_name,
            temperature=self.temperature,
            api_key=self.api_key,
            base_dir=self.base_dir
        )

        self.data_agent = SubAgent(
            name="data",
            soul_file="data_agent.md",
            tools=DATASOURCE_TOOLS,
            model_name=self.model_name,
            temperature=self.temperature,
            api_key=self.api_key,
            base_dir=self.base_dir
        )

        self.automation_agent = SubAgent(
            name="automation",
            soul_file="automation_agent.md",
            tools=TRIGGER_TOOLS,
            model_name=self.model_name,
            temperature=self.temperature,
            api_key=self.api_key,
            base_dir=self.base_dir
        )

        self.research_agent = SubAgent(
            name="research",
            soul_file="research_agent.md",
            tools=RESEARCH_TOOLS,
            model_name=self.model_name,
            temperature=self.temperature,
            api_key=self.api_key,
            base_dir=self.base_dir
        )

        # Set global sub-agent instances for router tools
        set_chart_agent(self.chart_agent)
        set_data_agent(self.data_agent)
        set_automation_agent(self.automation_agent)
        set_research_agent(self.research_agent)

        # Main agent only gets SYNC_TOOLS (state management) and ROUTER_TOOLS
        logger.info("Main agent using router tools (4 routers + sync tools)")
        agent_tools = SYNC_TOOLS + ROUTER_TOOLS

        # Create main agent without a static system prompt
        # We'll pass the dynamic system prompt via state_modifier at runtime
        self.agent = create_react_agent(
            self.llm,
            agent_tools,
            checkpointer=checkpointer
        )

        logger.info(f"Agent initialized with {len(agent_tools)} tools")

    async def _clear_checkpoint(self, session_id: str) -> None:
        """Clear the checkpoint for a session to prevent resuming from invalid state.

        This is called when an error occurs during agent execution to ensure
        the next interaction starts fresh instead of trying to resume from
        a broken state (e.g., orphaned tool calls).

        Args:
            session_id: The session ID whose checkpoint should be cleared
        """
        try:
            checkpointer = self.memory_manager.get_checkpointer()
            if checkpointer:
                # Delete all checkpoints for this thread
                # LangGraph uses thread_id as the key
                # The checkpointer API doesn't have a direct delete method,
                # but we can use the underlying connection
                async with checkpointer.conn.cursor() as cur:
                    await cur.execute(
                        "DELETE FROM checkpoints WHERE thread_id = ?",
                        (session_id,)
                    )
                await checkpointer.conn.commit()
                logger.info(f"Cleared checkpoint for session {session_id}")
        except Exception as e:
            logger.warning(f"Failed to clear checkpoint for session {session_id}: {e}")

    async def execute(
        self,
        session: UserSession,
        message: GatewayUserMessage
    ) -> AsyncIterator[Union[str, Dict[str, Any]]]:
        """Execute the agent and stream responses.

        Args:
            session: User session
            message: User message

        Yields:
            Response text chunks as they're generated, plus dict entries
            carrying plot metadata from execute_python tool calls
        """
        logger.info(f"AgentExecutor.execute called for session {session.session_id}")

        # Get session lock to prevent concurrent execution
        lock = await self.session_manager.get_session_lock(session.session_id)
        logger.info(f"Session lock acquired for {session.session_id}")

        async with lock:
            try:
                # Build system prompt with current context
                context = self.memory_manager.get_context_prompt()
                system_prompt = build_system_prompt(context, session.active_channels)

                # Build message history WITHOUT prepending system message
                # The system prompt will be passed via state_modifier in the config
                messages = []
                history = session.get_history(limit=10)
                logger.info(f"Building message history, {len(history)} messages in history")

                for i, msg in enumerate(history):
                    logger.info(f"History message {i}: role={msg.role}, content_len={len(msg.content)}, content='{msg.content[:100]}'")
                    if msg.role == "user":
                        messages.append(HumanMessage(content=msg.content))
                    elif msg.role == "assistant":
                        messages.append(AIMessage(content=msg.content))

                logger.info(f"Prepared {len(messages)} messages for agent (system prompt passed via config)")
                for i, msg in enumerate(messages):
                    msg_type = type(msg).__name__
                    content_preview = msg.content[:100] if msg.content else 'EMPTY'
                    logger.info(f"LangChain message {i}: type={msg_type}, content_len={len(msg.content)}, content='{content_preview}'")

                # Prepare config with metadata and dynamic system prompt
                # Pass system_prompt via state_modifier to avoid multiple system messages
                config = RunnableConfig(
                    configurable={
                        "thread_id": session.session_id,
                        "state_modifier": system_prompt  # Dynamic system prompt injection
                    },
                    metadata={
                        "session_id": session.session_id,
                        "user_id": session.user_id,
                        "active_channels": session.active_channels
                    }
                )
                logger.info(f"Agent config prepared: thread_id={session.session_id}")

                # Invoke agent with streaming
                logger.info("Starting agent.astream_events...")
                full_response = ""
                event_count = 0
                chunk_count = 0

                plot_urls = []  # Accumulate plot URLs from execute_python tool calls

                async for event in self.agent.astream_events(
                    {"messages": messages},
                    config=config,
                    version="v2"
                ):
                    event_count += 1

                    # Check for cancellation
                    if asyncio.current_task().cancelled():
                        logger.warning("Agent execution cancelled")
                        break

                    # Log tool calls
                    if event["event"] == "on_tool_start":
                        tool_name = event.get("name", "unknown")
                        tool_input = event.get("data", {}).get("input", {})
                        logger.info(f"Tool call started: {tool_name} with input: {tool_input}")

                    elif event["event"] == "on_tool_end":
                        tool_name = event.get("name", "unknown")
                        tool_output = event.get("data", {}).get("output")

                        # LangChain may wrap the output in a ToolMessage with content field
                        # Try to extract the actual content from the ToolMessage
                        actual_output = tool_output
                        if hasattr(tool_output, "content"):
                            actual_output = tool_output.content

                        logger.info(f"Tool call completed: {tool_name} with output type: {type(actual_output)}")

                        # Extract plot_urls from execute_python tool results
                        if tool_name == "execute_python":
                            # Try to parse as JSON if it's a string
                            import json
                            if isinstance(actual_output, str):
                                try:
                                    actual_output = json.loads(actual_output)
                                except (json.JSONDecodeError, ValueError):
                                    logger.warning(f"Could not parse execute_python output as JSON: {actual_output[:200]}")

                            if isinstance(actual_output, dict):
                                tool_plot_urls = actual_output.get("plot_urls", [])
                                if tool_plot_urls:
                                    logger.info(f"execute_python generated {len(tool_plot_urls)} plots: {tool_plot_urls}")
                                    plot_urls.extend(tool_plot_urls)
                                    # Yield metadata about plots immediately
                                    yield {
                                        "content": "",
                                        "metadata": {"plot_urls": tool_plot_urls}
                                    }

                    # Extract streaming tokens
                    elif event["event"] == "on_chat_model_stream":
                        chunk = event["data"]["chunk"]
                        if hasattr(chunk, "content") and chunk.content:
                            content = chunk.content
                            # Handle both string and list content
                            if isinstance(content, list):
                                # Extract text from content blocks
                                text_parts = []
                                for block in content:
                                    if isinstance(block, dict) and "text" in block:
                                        text_parts.append(block["text"])
                                    elif hasattr(block, "text"):
                                        text_parts.append(block.text)
                                content = "".join(text_parts)

                            if content:  # Only yield non-empty content
                                full_response += content
                                chunk_count += 1
                                logger.debug(f"Yielding content chunk #{chunk_count}")
                                yield content

                logger.info(f"Agent streaming complete: {event_count} events, {chunk_count} content chunks, {len(full_response)} chars")

                # Save to persistent memory
                logger.info("Saving assistant message to persistent memory")
                await self.session_manager.save_message(
                    session.session_id,
                    "assistant",
                    full_response
                )
                logger.info("Assistant message saved to memory")

            except asyncio.CancelledError:
                logger.warning(f"Agent execution cancelled for session {session.session_id}")
                # Clear checkpoint on cancellation to prevent orphaned tool calls
                await self._clear_checkpoint(session.session_id)
                raise
            except Exception as e:
                error_msg = f"Agent execution error: {str(e)}"
                logger.error(error_msg, exc_info=True)

                # Clear checkpoint on error to prevent invalid state (e.g., orphaned tool calls)
                # This ensures the next interaction starts fresh instead of trying to resume
                # from a broken state
                await self._clear_checkpoint(session.session_id)

                yield error_msg


def create_agent(
    model_name: str = "claude-sonnet-4-20250514",
    temperature: float = 0.7,
    api_key: Optional[str] = None,
    checkpoint_db_path: str = "data/checkpoints.db",
    chroma_db_path: str = "data/chroma",
    embedding_model: str = "all-MiniLM-L6-v2",
    context_docs_dir: str = "doc",
    base_dir: str = "."
) -> AgentExecutor:
    """Create an agent executor with hierarchical tool routing.

    Note: the returned executor is not yet initialized; callers must run
    `await executor.initialize()` before first use.

    Args:
        model_name: Anthropic model name
        temperature: Model temperature
        api_key: Anthropic API key
        checkpoint_db_path: Path to LangGraph checkpoint SQLite DB
        chroma_db_path: Path to ChromaDB storage directory
        embedding_model: Sentence-transformers model name
        context_docs_dir: Directory with context markdown files
        base_dir: Base directory for resolving paths

    Returns:
        AgentExecutor configured with hierarchical tool routing
    """
    # Initialize memory manager
    memory_manager = MemoryManager(
        checkpoint_db_path=checkpoint_db_path,
        chroma_db_path=chroma_db_path,
        embedding_model=embedding_model,
        context_docs_dir=context_docs_dir,
        base_dir=base_dir
    )

    # Create executor
    executor = AgentExecutor(
        model_name=model_name,
        temperature=temperature,
        api_key=api_key,
        memory_manager=memory_manager,
        base_dir=base_dir
    )

    return executor
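For orientation, a sketch of how the factory and executor are intended to be driven end to end. The session and message construction below is assumed (those types live in the gateway package), and the API key is a placeholder:

```python
# Sketch only, not part of the commit: drive AgentExecutor from an async entry point.
import asyncio

from agent.core import create_agent
from gateway.user_session import UserSession
from gateway.protocol import UserMessage


async def main() -> None:
    executor = create_agent(api_key="sk-ant-...", base_dir=".")  # placeholder key
    await executor.initialize()  # builds sub-agents, router tools, and the checkpointer

    session = UserSession(...)   # construction depends on the gateway package (assumed)
    message = UserMessage(...)   # likewise assumed here

    # execute() yields text chunks, plus dict entries carrying plot metadata
    async for chunk in executor.execute(session, message):
        if isinstance(chunk, str):
            print(chunk, end="", flush=True)


asyncio.run(main())
```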
380  backend.old/src/agent/memory.py  Normal file
@@ -0,0 +1,380 @@
import os
import glob
from typing import List, Dict, Any, Optional
from datetime import datetime

import aiofiles
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

# Prevent ChromaDB from reporting telemetry to the mothership
os.environ["ANONYMIZED_TELEMETRY"] = "False"


class MemoryManager:
    """Manages persistent memory using local tools:

    - LangGraph checkpointing (SQLite) for conversation state
    - ChromaDB for semantic memory search
    - Local sentence-transformers for embeddings
    - Memory graph approach for clustering related concepts
    """

    def __init__(
        self,
        checkpoint_db_path: str = "data/checkpoints.db",
        chroma_db_path: str = "data/chroma",
        embedding_model: str = "all-MiniLM-L6-v2",
        context_docs_dir: str = "memory",
        base_dir: str = "."
    ):
        """Initialize memory manager.

        Args:
            checkpoint_db_path: Path to SQLite checkpoint database
            chroma_db_path: Path to ChromaDB directory
            embedding_model: Sentence-transformers model name
            context_docs_dir: Directory containing markdown context files
            base_dir: Base directory for resolving relative paths
        """
        self.checkpoint_db_path = checkpoint_db_path
        self.chroma_db_path = chroma_db_path
        self.embedding_model_name = embedding_model
        self.context_docs_dir = os.path.join(base_dir, context_docs_dir)

        # Will be initialized on startup
        self.checkpointer: Optional[AsyncSqliteSaver] = None
        self.checkpointer_context: Optional[Any] = None  # Store the context manager
        self.chroma_client: Optional[chromadb.Client] = None
        self.memory_collection: Optional[Any] = None
        self.embedding_model: Optional[SentenceTransformer] = None

        self.context_documents: Dict[str, str] = {}
        self.initialized = False

    async def initialize(self) -> None:
        """Initialize the memory system and load context documents."""
        if self.initialized:
            return

        # Ensure data directories exist
        os.makedirs(os.path.dirname(self.checkpoint_db_path), exist_ok=True)
        os.makedirs(self.chroma_db_path, exist_ok=True)

        # Initialize LangGraph checkpointer (SQLite)
        self.checkpointer_context = AsyncSqliteSaver.from_conn_string(
            self.checkpoint_db_path
        )
        self.checkpointer = await self.checkpointer_context.__aenter__()
        await self.checkpointer.setup()

        # Initialize ChromaDB
        self.chroma_client = chromadb.PersistentClient(
            path=self.chroma_db_path,
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True
            )
        )

        # Get or create memory collection
        self.memory_collection = self.chroma_client.get_or_create_collection(
            name="conversation_memory",
            metadata={"description": "Semantic memory for conversations"}
        )

        # Initialize local embedding model
        print(f"Loading embedding model: {self.embedding_model_name}")
        self.embedding_model = SentenceTransformer(self.embedding_model_name)

        # Load markdown context documents
        await self._load_context_documents()

        # Index context documents in ChromaDB
        await self._index_context_documents()

        self.initialized = True
        print("Memory system initialized (LangGraph + ChromaDB + local embeddings)")

    async def _load_context_documents(self) -> None:
        """Load all markdown files from context directory."""
        if not os.path.exists(self.context_docs_dir):
            print(f"Warning: Context directory {self.context_docs_dir} not found")
            return

        md_files = glob.glob(os.path.join(self.context_docs_dir, "*.md"))

        for md_file in md_files:
            try:
                async with aiofiles.open(md_file, "r", encoding="utf-8") as f:
                    content = await f.read()
                filename = os.path.basename(md_file)
                self.context_documents[filename] = content
                print(f"Loaded context document: {filename}")
            except Exception as e:
                print(f"Error loading {md_file}: {e}")

    async def _index_context_documents(self) -> None:
        """Index context documents in ChromaDB for semantic search."""
        if not self.context_documents or not self.memory_collection:
            return

        for filename, content in self.context_documents.items():
            # Split into sections (by headers)
            sections = self._split_document_into_sections(content, filename)

            for i, section in enumerate(sections):
                doc_id = f"context_{filename}_{i}"

                # Generate embedding
                embedding = self.embedding_model.encode(section["content"]).tolist()

                # Add to ChromaDB
                self.memory_collection.add(
                    ids=[doc_id],
                    embeddings=[embedding],
                    documents=[section["content"]],
                    metadatas=[{
                        "type": "context",
                        "source": filename,
                        "section": section["title"],
                        "indexed_at": datetime.utcnow().isoformat()
                    }]
                )

        print(f"Indexed {len(self.context_documents)} context documents")

    def _split_document_into_sections(self, content: str, filename: str) -> List[Dict[str, str]]:
        """Split markdown document into logical sections.

        Args:
            content: Markdown content
            filename: Source filename

        Returns:
            List of section dicts with title and content
        """
        sections = []
        current_section = {"title": filename, "content": ""}

        for line in content.split("\n"):
            if line.startswith("#"):
                # New section
                if current_section["content"].strip():
                    sections.append(current_section)
                current_section = {
                    "title": line.strip("#").strip(),
                    "content": line + "\n"
                }
            else:
                current_section["content"] += line + "\n"

        # Add last section
        if current_section["content"].strip():
            sections.append(current_section)

        return sections

    def get_context_prompt(self) -> str:
        """Generate a context prompt from loaded documents.

        system_prompt.md is ALWAYS included first and prioritized.
        Other documents are included after.

        Returns:
            Formatted string containing all context documents
        """
        if not self.context_documents:
            return ""

        sections = []

        # ALWAYS include system_prompt.md first if it exists
        system_prompt_key = "system_prompt.md"
        if system_prompt_key in self.context_documents:
            sections.append(self.context_documents[system_prompt_key])
            sections.append("\n---\n")

        # Add other context documents
        sections.append("# Additional Context\n")
        sections.append("The following documents provide additional context about the system:\n")

        for filename, content in sorted(self.context_documents.items()):
            # Skip system_prompt.md since we already added it
            if filename == system_prompt_key:
                continue

            sections.append(f"\n## {filename}\n")
            sections.append(content)

        return "\n".join(sections)

    async def add_memory(
        self,
        session_id: str,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """Add a message to semantic memory (ChromaDB).

        Args:
            session_id: Session identifier
            role: Message role ("user" or "assistant")
            content: Message content
            metadata: Optional metadata
        """
        if not self.memory_collection or not self.embedding_model:
            return

        try:
            # Generate unique ID
            timestamp = datetime.utcnow().isoformat()
            doc_id = f"{session_id}_{role}_{timestamp}"

            # Generate embedding
            embedding = self.embedding_model.encode(content).tolist()

            # Prepare metadata
            meta = {
                "session_id": session_id,
                "role": role,
                "timestamp": timestamp,
                "type": "conversation",
                **(metadata or {})
            }

            # Add to ChromaDB
            self.memory_collection.add(
                ids=[doc_id],
                embeddings=[embedding],
                documents=[content],
                metadatas=[meta]
            )
        except Exception as e:
            print(f"Error adding to ChromaDB memory: {e}")

    async def search_memory(
        self,
        session_id: str,
        query: str,
        limit: int = 5,
        include_context: bool = True
    ) -> List[Dict[str, Any]]:
        """Search memory using semantic similarity.

        Args:
            session_id: Session identifier (filters to this session + context docs)
            query: Search query
            limit: Maximum results
            include_context: Whether to include context documents in search

        Returns:
            List of relevant memory items with content and metadata
        """
        if not self.memory_collection or not self.embedding_model:
            return []

        try:
            # Generate query embedding
            query_embedding = self.embedding_model.encode(query).tolist()

            # Build where filter
            if include_context:
                # Search both session messages and context docs
                where_filters = {
                    "$or": [
                        {"session_id": session_id},
                        {"type": "context"}
                    ]
                }
            else:
                where_filters = {"session_id": session_id}

            # Query ChromaDB
            results = self.memory_collection.query(
                query_embeddings=[query_embedding],
                n_results=limit,
                where=where_filters if where_filters else None
            )

            # Format results
            memories = []
            if results and results["documents"]:
                for i, doc in enumerate(results["documents"][0]):
                    memories.append({
                        "content": doc,
                        "metadata": results["metadatas"][0][i],
                        "distance": results["distances"][0][i] if "distances" in results else None
                    })

            return memories
        except Exception as e:
            print(f"Error searching ChromaDB memory: {e}")
            return []

    async def get_memory_graph(
        self,
        session_id: str,
        max_depth: int = 2
    ) -> Dict[str, Any]:
        """Get a graph of related memories using clustering.

        This creates a simple memory graph by finding clusters of related concepts.

        Args:
            session_id: Session identifier
            max_depth: Maximum depth for related memory traversal

        Returns:
            Dict representing memory graph structure
        """
        # Simple implementation: get all memories for session and cluster by similarity
        if not self.memory_collection:
            return {}

        try:
            # Get all memories for this session
            results = self.memory_collection.get(
                where={"session_id": session_id},
                include=["embeddings", "documents", "metadatas"]
            )

            if not results or not results["documents"]:
                return {"nodes": [], "edges": []}

            # Build simple graph structure
            nodes = []
            edges = []

            for i, doc in enumerate(results["documents"]):
                nodes.append({
                    "id": results["ids"][i],
                    "content": doc,
                    "metadata": results["metadatas"][i]
                })

            # TODO: Compute edges based on embedding similarity
            # For now, return just nodes
            return {"nodes": nodes, "edges": edges}

        except Exception as e:
            print(f"Error building memory graph: {e}")
            return {}

    def get_checkpointer(self) -> Optional[AsyncSqliteSaver]:
        """Get the LangGraph checkpointer for conversation state.

        Returns:
            AsyncSqliteSaver instance for LangGraph persistence
        """
        return self.checkpointer

    async def close(self) -> None:
        """Close the memory manager and cleanup resources."""
        if self.checkpointer_context:
            await self.checkpointer_context.__aexit__(None, None, None)
            self.checkpointer = None
            self.checkpointer_context = None
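MemoryManager can also be exercised on its own; a small sketch under the defaults defined above (the paths, session ID, and query text are illustrative):

```python
# Sketch: standalone use of MemoryManager with its default local storage.
import asyncio

from agent.memory import MemoryManager


async def main() -> None:
    memory = MemoryManager(context_docs_dir="memory", base_dir=".")
    await memory.initialize()

    await memory.add_memory("session-1", "user", "Watch BTC/USDT on the 15m chart")
    hits = await memory.search_memory("session-1", "bitcoin chart", limit=3)
    for hit in hits:
        # Each hit carries the stored text, its metadata, and a ChromaDB distance
        print(hit["metadata"].get("role"), hit["distance"], hit["content"][:80])

    await memory.close()


asyncio.run(main())
```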
118  backend.old/src/agent/prompts.py  Normal file
@@ -0,0 +1,118 @@
from typing import List, Dict, Any
from gateway.user_session import UserSession


def _get_chart_store_context() -> str:
    """Get current ChartStore state for context injection.

    Returns:
        Formatted string with ChartStore contents, or empty string if unavailable
    """
    try:
        from agent.tools import _registry

        if not _registry:
            return ""

        chart_store = _registry.entries.get("ChartStore")
        if not chart_store:
            return ""

        chart_state = chart_store.model.model_dump(mode="json")
        chart_data = chart_state.get("chart_state", {})

        # Only include if there's actual chart data
        if not chart_data or not chart_data.get("symbol"):
            return ""

        # Format the chart information
        symbol = chart_data.get("symbol", "N/A")
        interval = chart_data.get("interval", "N/A")
        start_time = chart_data.get("start_time")
        end_time = chart_data.get("end_time")
        selected_shapes = chart_data.get("selected_shapes", [])

        selected_info = ""
        if selected_shapes:
            selected_info = f"\n- **Selected Shapes**: {len(selected_shapes)} shape(s) selected (IDs: {', '.join(selected_shapes)})"

        chart_context = f"""
## Current Chart Context

The user is currently viewing a chart with the following settings:
- **Symbol**: {symbol}
- **Interval**: {interval}
- **Time Range**: {f"from {start_time} to {end_time}" if start_time and end_time else "not set"}{selected_info}

This information is automatically available because you're connected via websocket.
When the user refers to "the chart", "this chart", or "what I'm viewing", this is what they mean.
"""
        return chart_context

    except Exception:
        # Silently fail - chart context is optional enhancement
        return ""


def build_system_prompt(context: str, active_channels: List[str]) -> str:
    """Build the system prompt for the agent.

    The main system prompt comes from system_prompt.md (loaded in context).
    This function adds dynamic session information.

    Args:
        context: Context from loaded markdown documents (includes system_prompt.md)
        active_channels: List of active channel IDs for this session

    Returns:
        Formatted system prompt
    """
    channels_str = ", ".join(active_channels) if active_channels else "none"

    # Check if user is connected via websocket - if so, inject chart context
    # Note: We check for websocket by looking for "websocket" in channel IDs
    # since WebSocketChannel uses channel_id like "websocket-{uuid}"
    has_websocket = any("websocket" in channel_id.lower() for channel_id in active_channels)

    chart_context = ""
    if has_websocket:
        chart_context = _get_chart_store_context()

    # Context already includes system_prompt.md and other docs
    # Just add current session information
    prompt = f"""{context}

## Current Session Information

**Active Channels**: {channels_str}

Your responses will be sent to all active channels. Your responses are streamed back in real-time.
If the user sends a new message while you're responding, your current response will be interrupted
and you'll be re-invoked with the updated context.
{chart_context}"""
    return prompt


def build_user_prompt_with_history(session: UserSession, current_message: str) -> str:
    """Build a user prompt including conversation history.

    Args:
        session: User session with conversation history
        current_message: Current user message

    Returns:
        Formatted prompt with history
    """
    messages = []

    # Get recent history (last 10 messages)
    history = session.get_history(limit=10)

    for msg in history:
        role_label = "User" if msg.role == "user" else "Assistant"
        messages.append(f"{role_label}: {msg.content}")

    # Add current message
    messages.append(f"User: {current_message}")

    return "\n\n".join(messages)
218  backend.old/src/agent/routers.py  Normal file
@@ -0,0 +1,218 @@
"""Tool router functions for hierarchical agent architecture.

This module provides meta-tools that route tasks to specialized sub-agents.
The main agent uses these routers instead of accessing all tools directly.
"""

import logging
from typing import Optional
from langchain_core.tools import tool

logger = logging.getLogger(__name__)

# Global sub-agent instances (set by create_agent)
_chart_agent = None
_data_agent = None
_automation_agent = None
_research_agent = None


def set_chart_agent(agent):
    """Set the global chart sub-agent instance."""
    global _chart_agent
    _chart_agent = agent


def set_data_agent(agent):
    """Set the global data sub-agent instance."""
    global _data_agent
    _data_agent = agent


def set_automation_agent(agent):
    """Set the global automation sub-agent instance."""
    global _automation_agent
    _automation_agent = agent


def set_research_agent(agent):
    """Set the global research sub-agent instance."""
    global _research_agent
    _research_agent = agent


@tool
async def use_chart_analysis(task: str) -> str:
    """Analyze charts, compute indicators, execute Python code, and create visualizations.

    This tool delegates to a specialized chart analysis agent that has access to:
    - Chart data retrieval (get_chart_data)
    - Python execution environment with pandas, numpy, matplotlib, talib
    - Technical indicator tools (add/remove indicators, search indicators)
    - Shape drawing tools (create/update/delete shapes on charts)

    Use this when the user wants to:
    - Analyze price action or patterns
    - Calculate technical indicators (RSI, MACD, Bollinger Bands, etc.)
    - Execute custom Python analysis on OHLCV data
    - Generate charts and visualizations
    - Draw trendlines, support/resistance, or other shapes
    - Perform statistical analysis on market data

    Args:
        task: Detailed description of the chart analysis task. Include:
            - What to analyze (which symbol, timeframe if different from current)
            - What indicators or calculations to perform
            - What visualizations to create
            - Any specific questions to answer

    Returns:
        The chart agent's analysis results, including computed values,
        plot URLs if visualizations were created, and interpretation.

    Examples:
        - "Calculate RSI(14) for the current chart and tell me if it's overbought"
        - "Draw a trendline connecting the last 3 swing lows"
        - "Compute Bollinger Bands (20, 2) and create a chart showing price vs bands"
        - "Analyze the last 100 bars and identify key support/resistance levels"
        - "Execute Python: calculate correlation between BTC and ETH over the last 30 days"
    """
    if not _chart_agent:
        return "Error: Chart analysis agent not initialized"

    logger.info(f"Routing to chart agent: {task[:100]}...")
    result = await _chart_agent.execute(task)
    return result


@tool
async def use_data_access(task: str) -> str:
    """Search for symbols and retrieve market data from exchanges.

    This tool delegates to a specialized data access agent that has access to:
    - Symbol search across multiple exchanges
    - Historical OHLCV data retrieval
    - Symbol metadata and info
    - Available data sources and exchanges

    Use this when the user wants to:
    - Search for a trading symbol or ticker
    - Get historical price data
    - Find out what exchanges support a symbol
    - Retrieve symbol metadata (price scale, supported resolutions, etc.)
    - Check what data sources are available

    Args:
        task: Detailed description of the data access task. Include:
            - What symbol or instrument to search for
            - What data to retrieve (time range, resolution)
            - What metadata is needed

    Returns:
        The data agent's response with requested symbols, data, or metadata.

    Examples:
        - "Search for Bitcoin symbols on Binance"
        - "Get the last 100 hours of BTC/USDT 1-hour data from Binance"
        - "Find all symbols matching 'ETH' on all exchanges"
        - "Get detailed info about symbol BTC/USDT on Binance"
        - "List all available data sources"
    """
    if not _data_agent:
        return "Error: Data access agent not initialized"

    logger.info(f"Routing to data agent: {task[:100]}...")
    result = await _data_agent.execute(task)
    return result


@tool
async def use_automation(task: str) -> str:
    """Schedule recurring tasks, create triggers, and manage automation.

    This tool delegates to a specialized automation agent that has access to:
    - Scheduled agent prompts (cron and interval-based)
    - One-time agent prompt execution
    - Trigger management (list, cancel scheduled jobs)
    - System stats and monitoring

    Use this when the user wants to:
    - Schedule a recurring task (hourly, daily, weekly, etc.)
    - Run a one-time background analysis
    - Set up automated monitoring or alerts
    - List or cancel existing scheduled tasks
    - Check trigger system status

    Args:
        task: Detailed description of the automation task. Include:
            - What should happen (what analysis or action)
            - When it should happen (schedule, frequency)
            - Any priorities or conditions

    Returns:
        The automation agent's response with job IDs, confirmation,
        or status information.

    Examples:
        - "Schedule a task to check BTC price every 5 minutes"
        - "Run a one-time analysis of ETH volume in the background"
        - "Set up a daily report at 9 AM with market summary"
        - "Show me all my scheduled tasks"
        - "Cancel the hourly BTC monitor job"
    """
    if not _automation_agent:
        return "Error: Automation agent not initialized"

    logger.info(f"Routing to automation agent: {task[:100]}...")
    result = await _automation_agent.execute(task)
    return result


@tool
async def use_research(task: str) -> str:
    """Search the web, academic papers, and external APIs for information.

    This tool delegates to a specialized research agent that has access to:
    - Web search (DuckDuckGo)
    - Academic paper search (arXiv)
    - Wikipedia lookup
    - HTTP requests to public APIs

    Use this when the user wants to:
    - Search for current news or events
    - Find academic papers on trading strategies
    - Look up financial concepts or terms
    - Fetch data from external public APIs
    - Research market trends or sentiment

    Args:
        task: Detailed description of the research task. Include:
            - What information to find
            - What sources to search (web, arxiv, wikipedia, APIs)
            - What to focus on or filter

    Returns:
        The research agent's findings with sources, summaries, and links.

    Examples:
        - "Search arXiv for papers on reinforcement learning for trading"
        - "Look up 'technical analysis' on Wikipedia"
        - "Search the web for latest Ethereum news"
        - "Fetch current BTC price from CoinGecko API"
        - "Find recent papers on market microstructure"
    """
    if not _research_agent:
        return "Error: Research agent not initialized"

    logger.info(f"Routing to research agent: {task[:100]}...")
    result = await _research_agent.execute(task)
    return result


# Export router tools
ROUTER_TOOLS = [
    use_chart_analysis,
    use_data_access,
    use_automation,
    use_research
]
93  backend.old/src/agent/session.py  Normal file
@@ -0,0 +1,93 @@
import asyncio
from typing import Dict, Optional
from agent.memory import MemoryManager


class SessionManager:
    """Manages agent sessions and their associated memory.

    Coordinates between the gateway's UserSession (for conversation state)
    and the agent's MemoryManager (for persistent memory in ChromaDB).
    """

    def __init__(self, memory_manager: MemoryManager):
        """Initialize session manager.

        Args:
            memory_manager: MemoryManager instance for persistent storage
        """
        self.memory = memory_manager
        self._locks: Dict[str, asyncio.Lock] = {}

    async def get_session_lock(self, session_id: str) -> asyncio.Lock:
        """Get or create a lock for a session.

        Prevents concurrent execution for the same session.

        Args:
            session_id: Session identifier

        Returns:
            asyncio.Lock for this session
        """
        if session_id not in self._locks:
            self._locks[session_id] = asyncio.Lock()
        return self._locks[session_id]

    async def save_message(
        self,
        session_id: str,
        role: str,
        content: str,
        metadata: Optional[Dict] = None
    ) -> None:
        """Save a message to persistent memory.

        Args:
            session_id: Session identifier
            role: Message role ("user" or "assistant")
            content: Message content
            metadata: Optional metadata
        """
        await self.memory.add_memory(session_id, role, content, metadata)

    async def get_relevant_context(
        self,
        session_id: str,
        query: str,
        limit: int = 5
    ) -> str:
        """Get relevant historical context for a query.

        Uses semantic search over past conversations.

        Args:
            session_id: Session identifier
            query: Search query
            limit: Maximum results

        Returns:
            Formatted string of relevant past messages
        """
        results = await self.memory.search_memory(session_id, query, limit)

        if not results:
            return ""

        context_parts = ["## Relevant Past Context\n"]
        for i, result in enumerate(results, 1):
            # search_memory returns items with "content", "metadata", and "distance" keys
            metadata = result.get("metadata", {})
            role = metadata.get("role", "unknown").capitalize()
            content = result["content"]
            distance = result.get("distance")
            relevance = f", distance: {distance:.2f}" if distance is not None else ""
            context_parts.append(f"{i}. [{role}{relevance}] {content}\n")

        return "\n".join(context_parts)

    async def cleanup_session(self, session_id: str) -> None:
        """Cleanup session resources.

        Args:
            session_id: Session to cleanup
        """
        if session_id in self._locks:
            del self._locks[session_id]
248  backend.old/src/agent/subagent.py  Normal file
@@ -0,0 +1,248 @@
"""Sub-agent infrastructure for specialized tool routing.

This module provides the SubAgent class that wraps specialized agents
with their own tools and system prompts.
"""

import logging
import uuid
from typing import List, Optional, AsyncIterator
from pathlib import Path

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver

logger = logging.getLogger(__name__)


class SubAgent:
    """A specialized sub-agent with its own tools and system prompt.

    Sub-agents are lightweight, stateless agents that focus on specific domains.
    They use in-memory checkpointing since they don't need persistent state.
    """

    def __init__(
        self,
        name: str,
        soul_file: str,
        tools: List,
        model_name: str = "claude-sonnet-4-20250514",
        temperature: float = 0.7,
        api_key: Optional[str] = None,
        base_dir: str = "."
    ):
        """Initialize a sub-agent.

        Args:
            name: Agent name (e.g., "chart", "data", "automation")
            soul_file: Filename in /soul directory (e.g., "chart_agent.md")
            tools: List of LangChain tools for this agent
            model_name: Anthropic model name
            temperature: Model temperature
            api_key: Anthropic API key
            base_dir: Base directory for resolving paths
        """
        self.name = name
        self.soul_file = soul_file
        self.tools = tools
        self.model_name = model_name
        self.temperature = temperature
        self.api_key = api_key
        self.base_dir = base_dir

        # Load system prompt from soul file
        soul_path = Path(base_dir) / "soul" / soul_file
        if soul_path.exists():
            with open(soul_path, "r") as f:
                self.system_prompt = f.read()
            logger.info(f"SubAgent '{name}': Loaded system prompt from {soul_path}")
        else:
            logger.warning(f"SubAgent '{name}': Soul file not found at {soul_path}, using default")
            self.system_prompt = f"You are a specialized {name} agent."

        # Initialize LLM
        self.llm = ChatAnthropic(
            model=model_name,
            temperature=temperature,
            api_key=api_key,
            streaming=True
        )

        # Create agent with in-memory checkpointer (stateless)
        checkpointer = MemorySaver()
        self.agent = create_react_agent(
            self.llm,
            tools,
            checkpointer=checkpointer
        )

        logger.info(
            f"SubAgent '{name}' initialized with {len(tools)} tools, "
            f"model={model_name}, temp={temperature}"
        )

    async def execute(
        self,
        task: str,
        thread_id: Optional[str] = None
    ) -> str:
        """Execute a task with this sub-agent.

        Args:
            task: The task/prompt for this sub-agent
            thread_id: Optional thread ID for checkpointing (uses ephemeral ID if not provided)

        Returns:
            The agent's complete response as a string
        """
        # Use ephemeral thread ID if not provided
        if thread_id is None:
            thread_id = f"subagent-{self.name}-{uuid.uuid4()}"

        logger.info(f"SubAgent '{self.name}': Executing task (thread_id={thread_id})")
        logger.debug(f"SubAgent '{self.name}': Task: {task[:200]}...")

        # Build messages (the system prompt is injected via the config below)
        messages = [
            HumanMessage(content=task)
        ]

        # Prepare config with system prompt injection
        config = RunnableConfig(
            configurable={
                "thread_id": thread_id,
                "state_modifier": self.system_prompt
            },
            metadata={
                "subagent_name": self.name
            }
        )

        # Execute and collect response
        full_response = ""
        event_count = 0

        try:
            async for event in self.agent.astream_events(
                {"messages": messages},
                config=config,
                version="v2"
            ):
                event_count += 1

                # Log tool calls
                if event["event"] == "on_tool_start":
                    tool_name = event.get("name", "unknown")
                    logger.debug(f"SubAgent '{self.name}': Tool call started: {tool_name}")

                elif event["event"] == "on_tool_end":
                    tool_name = event.get("name", "unknown")
                    logger.debug(f"SubAgent '{self.name}': Tool call completed: {tool_name}")

                # Extract streaming tokens
                elif event["event"] == "on_chat_model_stream":
                    chunk = event["data"]["chunk"]
                    if hasattr(chunk, "content") and chunk.content:
                        content = chunk.content
                        # Handle both string and list content
                        if isinstance(content, list):
                            text_parts = []
                            for block in content:
                                if isinstance(block, dict) and "text" in block:
                                    text_parts.append(block["text"])
                                elif hasattr(block, "text"):
                                    text_parts.append(block.text)
                            content = "".join(text_parts)

                        if content:
                            full_response += content

            logger.info(
                f"SubAgent '{self.name}': Completed task "
                f"({event_count} events, {len(full_response)} chars)"
            )

        except Exception as e:
            error_msg = f"SubAgent '{self.name}' execution error: {str(e)}"
            logger.error(error_msg, exc_info=True)
            return f"Error: {error_msg}"

        return full_response

    async def stream(
        self,
        task: str,
        thread_id: Optional[str] = None
    ) -> AsyncIterator[str]:
        """Execute a task with streaming response.

        Args:
            task: The task/prompt for this sub-agent
            thread_id: Optional thread ID for checkpointing

        Yields:
            Response chunks as they're generated
        """
        # Use ephemeral thread ID if not provided
        if thread_id is None:
            thread_id = f"subagent-{self.name}-{uuid.uuid4()}"

        logger.info(f"SubAgent '{self.name}': Streaming task (thread_id={thread_id})")

        # Build messages (the system prompt is injected via the config below)
        messages = [
            HumanMessage(content=task)
        ]

        # Prepare config
        config = RunnableConfig(
            configurable={
                "thread_id": thread_id,
                "state_modifier": self.system_prompt
            },
            metadata={
                "subagent_name": self.name
            }
        )

        # Stream response
        try:
            async for event in self.agent.astream_events(
                {"messages": messages},
                config=config,
                version="v2"
            ):
                # Log tool calls
                if event["event"] == "on_tool_start":
                    tool_name = event.get("name", "unknown")
                    logger.debug(f"SubAgent '{self.name}': Tool call started: {tool_name}")

                # Extract streaming tokens
                elif event["event"] == "on_chat_model_stream":
                    chunk = event["data"]["chunk"]
                    if hasattr(chunk, "content") and chunk.content:
                        content = chunk.content
                        # Handle both string and list content
                        if isinstance(content, list):
                            text_parts = []
                            for block in content:
                                if isinstance(block, dict) and "text" in block:
                                    text_parts.append(block["text"])
                                elif hasattr(block, "text"):
                                    text_parts.append(block.text)
                            content = "".join(text_parts)

                        if content:
                            yield content

        except Exception as e:
            error_msg = f"SubAgent '{self.name}' streaming error: {str(e)}"
            logger.error(error_msg, exc_info=True)
            yield f"Error: {error_msg}"
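A short sketch of constructing and streaming from a SubAgent directly, mirroring how AgentExecutor.initialize() builds them. RESEARCH_TOOLS and research_agent.md exist elsewhere in this commit, but running a sub-agent standalone like this is an assumption, and the API key is a placeholder:

```python
# Sketch only: standalone SubAgent usage outside AgentExecutor.
import asyncio

from agent.subagent import SubAgent
from agent.tools import RESEARCH_TOOLS


async def main() -> None:
    agent = SubAgent(
        name="research",
        soul_file="research_agent.md",
        tools=RESEARCH_TOOLS,
        api_key="sk-ant-...",  # placeholder
    )
    # stream() yields text chunks; execute() would return the full string instead
    async for chunk in agent.stream("Summarize today's BTC headlines"):
        print(chunk, end="", flush=True)


asyncio.run(main())
```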
139  backend.old/src/agent/tools/CHART_UTILS_README.md  Normal file
@@ -0,0 +1,139 @@
# Chart Utilities - Standard OHLC Plotting

## Overview

The `chart_utils.py` module provides convenience functions for creating beautiful, professional OHLC candlestick charts with a consistent look and feel. This is designed to be used by the LLM in `analyze_chart_data` scripts, eliminating the need to write custom matplotlib code for every chart.

## Key Features

- **Beautiful by default**: Uses mplfinance with seaborn-inspired aesthetics
- **Consistent styling**: Professional color scheme (teal green up, coral red down)
- **Easy to use**: Simple function calls instead of complex matplotlib code
- **Customizable**: Supports all mplfinance options via kwargs
- **Volume integration**: Optional volume subplot

## Installation

The required package `mplfinance` has been added to `requirements.txt`:

```bash
pip install mplfinance
```

## Available Functions

### 1. `plot_ohlc(df, title=None, volume=True, figsize=(14, 8), **kwargs)`

Main function for creating standard OHLC candlestick charts.

**Parameters:**
- `df`: pandas DataFrame with DatetimeIndex and OHLCV columns
- `title`: Optional chart title
- `volume`: Whether to include volume subplot (default: True)
- `figsize`: Figure size in inches (default: (14, 8))
- `**kwargs`: Additional mplfinance.plot() arguments

**Example:**
```python
fig = plot_ohlc(df, title='BTC/USDT 15min', volume=True)
```

### 2. `add_indicators_to_plot(df, indicators, **plot_kwargs)`

Creates OHLC chart with technical indicators overlaid.

**Parameters:**
- `df`: DataFrame with OHLCV data and indicator columns
- `indicators`: Dict mapping indicator column names to display parameters
- `**plot_kwargs`: Additional arguments for plot_ohlc()

**Example:**
```python
df['SMA_20'] = df['close'].rolling(20).mean()
df['SMA_50'] = df['close'].rolling(50).mean()

fig = add_indicators_to_plot(
    df,
    indicators={
        'SMA_20': {'color': 'blue', 'width': 1.5},
        'SMA_50': {'color': 'red', 'width': 1.5}
    },
    title='Price with Moving Averages'
)
```

### 3. Preset Functions

- `plot_price_volume(df, title=None)` - Standard price + volume chart
- `plot_price_only(df, title=None)` - Candlesticks without volume

## Integration with analyze_chart_data

These functions are automatically available in the `analyze_chart_data` tool's script environment:

```python
# In an analyze_chart_data script:
# df is already provided

# Simple usage
fig = plot_ohlc(df, title='Price Action')

# With indicators
df['SMA'] = df['close'].rolling(20).mean()
fig = add_indicators_to_plot(
    df,
    indicators={'SMA': {'color': 'blue', 'width': 1.5}},
    title='Price with SMA'
)

# Return data for the assistant
df[['close', 'SMA']].tail(10)
```

## Styling

The default style includes:
- **Up candles**: Teal green (#26a69a)
- **Down candles**: Coral red (#ef5350)
- **Background**: Light gray with white axes
- **Grid**: Subtle dashed lines with 30% alpha
- **Professional fonts**: Clean, readable sizes

## Why This Matters

**Before:**
```python
# LLM had to write this every time
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(df.index, df['close'], label='Close')
# ... lots more code for styling, colors, etc.
```

**After:**
```python
# LLM can now just do this
fig = plot_ohlc(df, title='BTC/USDT')
```

Benefits:
- ✅ Less code to generate → faster response
- ✅ Consistent appearance across all charts
- ✅ Professional look out of the box
- ✅ Easier to maintain and customize
- ✅ Better use of mplfinance's candlestick rendering

## Example Output

See `chart_utils_example.py` for runnable examples demonstrating:
1. Basic OHLC chart with volume
2. OHLC chart with multiple indicators
3. Price-only chart
4. Custom styling options

## File Locations

- **Main module**: `backend/src/agent/tools/chart_utils.py`
- **Integration**: `backend/src/agent/tools/chart_tools.py` (lines 306-328)
- **Examples**: `backend/src/agent/tools/chart_utils_example.py`
- **Dependency**: `backend/requirements.txt` (mplfinance added)
373
backend.old/src/agent/tools/TRIGGER_TOOLS.md
Normal file
373
backend.old/src/agent/tools/TRIGGER_TOOLS.md
Normal file
@@ -0,0 +1,373 @@
|
||||
# Agent Trigger Tools
|
||||
|
||||
Agent tools for automating tasks via the trigger system.
|
||||
|
||||
## Overview
|
||||
|
||||
These tools allow the agent to:
|
||||
- **Schedule recurring tasks** - Run agent prompts on intervals or cron schedules
|
||||
- **Execute one-time tasks** - Trigger sub-agent runs immediately
|
||||
- **Manage scheduled jobs** - List and cancel scheduled triggers
|
||||
- **React to events** - (Future) Connect data updates to agent actions
|
||||
|
||||
## Available Tools
|
||||
|
||||
### 1. `schedule_agent_prompt`
|
||||
|
||||
Schedule an agent to run with a specific prompt on a recurring schedule.
|
||||
|
||||
**Use Cases:**
|
||||
- Daily market analysis reports
|
||||
- Hourly portfolio rebalancing checks
|
||||
- Weekly performance summaries
|
||||
- Monitoring alerts
|
||||
|
||||
**Arguments:**
|
||||
- `prompt` (str): The prompt to send to the agent when triggered
|
||||
- `schedule_type` (str): "interval" or "cron"
|
||||
- `schedule_config` (dict): Schedule configuration
|
||||
- `name` (str, optional): Descriptive name for this task
|
||||
|
||||
**Schedule Config:**
|
||||
|
||||
*Interval-based:*
|
||||
```json
|
||||
{"minutes": 5}
|
||||
{"hours": 1, "minutes": 30}
|
||||
{"seconds": 30}
|
||||
```
|
||||
|
||||
*Cron-based:*
|
||||
```json
|
||||
{"hour": "9", "minute": "0"} // Daily at 9:00 AM
|
||||
{"hour": "9", "minute": "0", "day_of_week": "mon-fri"} // Weekdays at 9 AM
|
||||
{"minute": "0"} // Every hour on the hour
|
||||
{"hour": "*/6", "minute": "0"} // Every 6 hours
|
||||
```
|
||||
|
||||
**Returns:**
|
||||
```json
|
||||
{
|
||||
"job_id": "interval_123",
|
||||
"message": "Scheduled 'daily_report' with job_id=interval_123",
|
||||
"schedule_type": "cron",
|
||||
"config": {"hour": "9", "minute": "0"}
|
||||
}
|
||||
```
|
||||
|
||||
**Examples:**
|
||||
|
||||
```python
|
||||
# Every 5 minutes: check BTC price
|
||||
schedule_agent_prompt(
|
||||
prompt="Check current BTC price on Binance. If > $50k, alert me.",
|
||||
schedule_type="interval",
|
||||
schedule_config={"minutes": 5},
|
||||
name="btc_price_monitor"
|
||||
)
|
||||
|
||||
# Daily at 9 AM: market summary
|
||||
schedule_agent_prompt(
|
||||
prompt="Generate a comprehensive market summary for BTC, ETH, and SOL. Include price changes, volume, and notable events from the last 24 hours.",
|
||||
schedule_type="cron",
|
||||
schedule_config={"hour": "9", "minute": "0"},
|
||||
name="daily_market_summary"
|
||||
)
|
||||
|
||||
# Every hour on weekdays: portfolio check
|
||||
schedule_agent_prompt(
|
||||
prompt="Review current portfolio positions. Check if any rebalancing is needed based on target allocations.",
|
||||
schedule_type="cron",
|
||||
schedule_config={"minute": "0", "day_of_week": "mon-fri"},
|
||||
name="hourly_portfolio_check"
|
||||
)
|
||||
```
|
||||
|
||||
### 2. `execute_agent_prompt_once`
|
||||
|
||||
Execute an agent prompt once, immediately (enqueued with priority).
|
||||
|
||||
**Use Cases:**
|
||||
- Background analysis tasks
|
||||
- One-time data processing
|
||||
- Responding to specific events
|
||||
- Sub-agent delegation
|
||||
|
||||
**Arguments:**
|
||||
- `prompt` (str): The prompt to send to the agent
|
||||
- `priority` (str): "high", "normal", or "low" (default: "normal")
|
||||
|
||||
**Returns:**
|
||||
```json
|
||||
{
|
||||
"queue_seq": 42,
|
||||
"message": "Enqueued agent prompt with priority=normal",
|
||||
"prompt": "Analyze the last 100 BTC/USDT bars..."
|
||||
}
|
||||
```
|
||||
|
||||
**Examples:**
|
||||
|
||||
```python
|
||||
# Immediate analysis with high priority
|
||||
execute_agent_prompt_once(
|
||||
prompt="Analyze the last 100 BTC/USDT 1m bars and identify key support/resistance levels",
|
||||
priority="high"
|
||||
)
|
||||
|
||||
# Background task with normal priority
|
||||
execute_agent_prompt_once(
|
||||
prompt="Research the latest news about Ethereum upgrades and summarize findings",
|
||||
priority="normal"
|
||||
)
|
||||
|
||||
# Low priority cleanup task
|
||||
execute_agent_prompt_once(
|
||||
prompt="Review and archive old chart drawings from last month",
|
||||
priority="low"
|
||||
)
|
||||
```
|
||||
|
||||
### 3. `list_scheduled_triggers`
|
||||
|
||||
List all currently scheduled triggers.
|
||||
|
||||
**Returns:**
|
||||
```json
|
||||
[
|
||||
{
|
||||
"id": "cron_456",
|
||||
"name": "Cron: daily_market_summary",
|
||||
"next_run_time": "2024-03-05 09:00:00",
|
||||
"trigger": "cron[hour='9', minute='0']"
|
||||
},
|
||||
{
|
||||
"id": "interval_123",
|
||||
"name": "Interval: btc_price_monitor",
|
||||
"next_run_time": "2024-03-04 14:35:00",
|
||||
"trigger": "interval[0:05:00]"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
jobs = list_scheduled_triggers()
|
||||
|
||||
for job in jobs:
|
||||
print(f"{job['name']} - next run: {job['next_run_time']}")
|
||||
```
|
||||
|
||||
### 4. `cancel_scheduled_trigger`
|
||||
|
||||
Cancel a scheduled trigger by its job ID.
|
||||
|
||||
**Arguments:**
|
||||
- `job_id` (str): The job ID from `schedule_agent_prompt` or `list_scheduled_triggers`
|
||||
|
||||
**Returns:**
|
||||
```json
|
||||
{
|
||||
"status": "success",
|
||||
"message": "Cancelled job interval_123"
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
# List jobs to find the ID
|
||||
jobs = list_scheduled_triggers()
|
||||
|
||||
# Cancel specific job
|
||||
cancel_scheduled_trigger("interval_123")
|
||||
```
|
||||
|
||||
### 5. `on_data_update_run_agent`
|
||||
|
||||
**(Future)** Set up an agent to run whenever new data arrives for a specific symbol.
|
||||
|
||||
**Arguments:**
|
||||
- `source_name` (str): Data source name (e.g., "binance")
|
||||
- `symbol` (str): Trading pair (e.g., "BTC/USDT")
|
||||
- `resolution` (str): Time resolution (e.g., "1m", "5m")
|
||||
- `prompt_template` (str): Template with variables like {close}, {volume}, {symbol}
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
on_data_update_run_agent(
|
||||
source_name="binance",
|
||||
symbol="BTC/USDT",
|
||||
resolution="1m",
|
||||
prompt_template="New bar on {symbol}: close={close}, volume={volume}. Check if price crossed any key levels."
|
||||
)
|
||||
```
|
||||
|
||||
### 6. `get_trigger_system_stats`
|
||||
|
||||
Get statistics about the trigger system.
|
||||
|
||||
**Returns:**
|
||||
```json
|
||||
{
|
||||
"queue_depth": 3,
|
||||
"queue_running": true,
|
||||
"coordinator_stats": {
|
||||
"current_seq": 1042,
|
||||
"next_commit_seq": 1043,
|
||||
"pending_commits": 1,
|
||||
"total_executions": 1042,
|
||||
"state_counts": {
|
||||
"COMMITTED": 1038,
|
||||
"EXECUTING": 2,
|
||||
"WAITING_COMMIT": 1,
|
||||
"FAILED": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
stats = get_trigger_system_stats()
|
||||
print(f"Queue has {stats['queue_depth']} pending triggers")
|
||||
print(f"System has processed {stats['coordinator_stats']['total_executions']} total triggers")
|
||||
```
|
||||
|
||||
## Integration Example
|
||||
|
||||
Here's how these tools enable autonomous agent behavior:
|
||||
|
||||
```python
|
||||
# Agent conversation:
|
||||
User: "Monitor BTC price and send me a summary every hour during market hours"
|
||||
|
||||
Agent: I'll set that up for you using the trigger system.
|
||||
|
||||
# Agent uses tool:
|
||||
schedule_agent_prompt(
|
||||
prompt="""
|
||||
Check the current BTC/USDT price on Binance.
|
||||
Calculate the price change from 1 hour ago.
|
||||
If price moved > 2%, provide a detailed analysis.
|
||||
Otherwise, provide a brief status update.
|
||||
Send results to user as a notification.
|
||||
""",
|
||||
schedule_type="cron",
|
||||
schedule_config={
|
||||
"minute": "0",
|
||||
"hour": "9-17", # 9 AM to 5 PM
|
||||
"day_of_week": "mon-fri"
|
||||
},
|
||||
name="btc_hourly_monitor"
|
||||
)
|
||||
|
||||
Agent: Done! I've scheduled an hourly BTC price monitor that runs during market hours (9 AM - 5 PM on weekdays). You'll receive updates every hour.
|
||||
|
||||
# Later...
|
||||
User: "Can you show me all my scheduled tasks?"
|
||||
|
||||
Agent: Let me check what's scheduled.
|
||||
|
||||
# Agent uses tool:
|
||||
jobs = list_scheduled_triggers()
|
||||
|
||||
Agent: You have 3 scheduled tasks:
|
||||
1. "btc_hourly_monitor" - runs every hour during market hours
|
||||
2. "daily_market_summary" - runs daily at 9 AM
|
||||
3. "portfolio_rebalance_check" - runs every 4 hours
|
||||
|
||||
Would you like to modify or cancel any of these?
|
||||
```
|
||||
|
||||
## Use Case: Autonomous Trading Bot
|
||||
|
||||
```python
|
||||
# Step 1: Set up data monitoring
|
||||
execute_agent_prompt_once(
|
||||
prompt="""
|
||||
Subscribe to BTC/USDT 1m bars from Binance.
|
||||
When subscribed, set up the following:
|
||||
1. Calculate RSI(14) on each new bar
|
||||
2. If RSI > 70, execute prompt: "RSI overbought on BTC, check if we should sell"
|
||||
3. If RSI < 30, execute prompt: "RSI oversold on BTC, check if we should buy"
|
||||
""",
|
||||
priority="high"
|
||||
)
|
||||
|
||||
# Step 2: Schedule periodic portfolio review
|
||||
schedule_agent_prompt(
|
||||
prompt="""
|
||||
Review current portfolio:
|
||||
1. Calculate current allocation percentages
|
||||
2. Compare to target allocation (60% BTC, 30% ETH, 10% stable)
|
||||
3. If deviation > 5%, generate rebalancing trades
|
||||
4. Submit trades for execution
|
||||
""",
|
||||
schedule_type="interval",
|
||||
schedule_config={"hours": 4},
|
||||
name="portfolio_rebalance"
|
||||
)
|
||||
|
||||
# Step 3: Schedule daily risk check
|
||||
schedule_agent_prompt(
|
||||
prompt="""
|
||||
Daily risk assessment:
|
||||
1. Calculate portfolio VaR (Value at Risk)
|
||||
2. Check current leverage across all positions
|
||||
3. Review stop-loss placements
|
||||
4. If risk exceeds threshold, alert and suggest adjustments
|
||||
""",
|
||||
schedule_type="cron",
|
||||
schedule_config={"hour": "8", "minute": "0"},
|
||||
name="daily_risk_check"
|
||||
)
|
||||
```
|
||||
|
||||
## Benefits
|
||||
|
||||
✅ **Autonomous operation** - Agent can schedule its own tasks
|
||||
✅ **Event-driven** - React to market data, time, or custom events
|
||||
✅ **Flexible scheduling** - Interval or cron-based
|
||||
✅ **Self-managing** - Agent can list and cancel its own jobs
|
||||
✅ **Priority control** - High-priority tasks jump the queue
|
||||
✅ **Future-proof** - Easy to add Python lambdas, strategy execution, etc.
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- **Python script execution** - Schedule arbitrary Python code
|
||||
- **Strategy triggers** - Connect to strategy execution system
|
||||
- **Event composition** - AND/OR logic for complex event patterns
|
||||
- **Conditional execution** - Only run if conditions met (e.g., volatility > threshold)
|
||||
- **Result chaining** - Use output of one trigger as input to another
|
||||
- **Backtesting mode** - Test trigger logic on historical data
|
||||
|
||||
## Setup in main.py
|
||||
|
||||
```python
|
||||
from agent.tools import set_trigger_queue, set_trigger_scheduler, set_coordinator
|
||||
from trigger import TriggerQueue, CommitCoordinator
|
||||
from trigger.scheduler import TriggerScheduler
|
||||
|
||||
# Initialize trigger system
|
||||
coordinator = CommitCoordinator()
|
||||
queue = TriggerQueue(coordinator)
|
||||
scheduler = TriggerScheduler(queue)
|
||||
|
||||
await queue.start()
|
||||
scheduler.start()
|
||||
|
||||
# Make available to agent tools
|
||||
set_trigger_queue(queue)
|
||||
set_trigger_scheduler(scheduler)
|
||||
set_coordinator(coordinator)
|
||||
|
||||
# Add TRIGGER_TOOLS to agent's tool list
|
||||
from agent.tools import TRIGGER_TOOLS
|
||||
agent_tools = [..., *TRIGGER_TOOLS]
|
||||
```
|
||||
|
||||
Now the agent has full control over the trigger system! 🚀
|
||||
64
backend.old/src/agent/tools/__init__.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Agent tools for trading operations.
|
||||
|
||||
This package provides tools for:
|
||||
- Synchronization stores (sync_tools)
|
||||
- Data sources and market data (datasource_tools)
|
||||
- Chart data access and analysis (chart_tools)
|
||||
- Technical indicators (indicator_tools)
|
||||
- Shape/drawing management (shape_tools)
|
||||
- Trigger system and automation (trigger_tools)
|
||||
"""
|
||||
|
||||
# Global registries that will be set by main.py
|
||||
_registry = None
|
||||
_datasource_registry = None
|
||||
_indicator_registry = None
|
||||
|
||||
|
||||
def set_registry(registry):
|
||||
"""Set the global SyncRegistry instance for tools to use."""
|
||||
global _registry
|
||||
_registry = registry
|
||||
|
||||
|
||||
def set_datasource_registry(datasource_registry):
|
||||
"""Set the global DataSourceRegistry instance for tools to use."""
|
||||
global _datasource_registry
|
||||
_datasource_registry = datasource_registry
|
||||
|
||||
|
||||
def set_indicator_registry(indicator_registry):
|
||||
"""Set the global IndicatorRegistry instance for tools to use."""
|
||||
global _indicator_registry
|
||||
_indicator_registry = indicator_registry
|
||||
|
||||
|
||||
# Import all tools from submodules
|
||||
from .sync_tools import SYNC_TOOLS
|
||||
from .datasource_tools import DATASOURCE_TOOLS
|
||||
from .chart_tools import CHART_TOOLS
|
||||
from .indicator_tools import INDICATOR_TOOLS
|
||||
from .research_tools import RESEARCH_TOOLS
|
||||
from .shape_tools import SHAPE_TOOLS
|
||||
from .trigger_tools import (
|
||||
TRIGGER_TOOLS,
|
||||
set_trigger_queue,
|
||||
set_trigger_scheduler,
|
||||
set_coordinator,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"set_registry",
|
||||
"set_datasource_registry",
|
||||
"set_indicator_registry",
|
||||
"set_trigger_queue",
|
||||
"set_trigger_scheduler",
|
||||
"set_coordinator",
|
||||
"SYNC_TOOLS",
|
||||
"DATASOURCE_TOOLS",
|
||||
"CHART_TOOLS",
|
||||
"INDICATOR_TOOLS",
|
||||
"RESEARCH_TOOLS",
|
||||
"SHAPE_TOOLS",
|
||||
"TRIGGER_TOOLS",
|
||||
]
|
||||
454
backend.old/src/agent/tools/chart_tools.py
Normal file
@@ -0,0 +1,454 @@
|
||||
"""Chart data access and analysis tools."""
|
||||
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
import io
|
||||
import uuid
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from contextlib import redirect_stdout, redirect_stderr
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_registry():
|
||||
"""Get the global registry instance."""
|
||||
from . import _registry
|
||||
return _registry
|
||||
|
||||
|
||||
def _get_datasource_registry():
|
||||
"""Get the global datasource registry instance."""
|
||||
from . import _datasource_registry
|
||||
return _datasource_registry
|
||||
|
||||
|
||||
def _get_indicator_registry():
|
||||
"""Get the global indicator registry instance."""
|
||||
from . import _indicator_registry
|
||||
return _indicator_registry
|
||||
|
||||
|
||||
def _get_order_store():
|
||||
"""Get the global OrderStore instance."""
|
||||
registry = _get_registry()
|
||||
if registry and "OrderStore" in registry.entries:
|
||||
return registry.entries["OrderStore"].model
|
||||
return None
|
||||
|
||||
|
||||
def _get_chart_store():
|
||||
"""Get the global ChartStore instance."""
|
||||
registry = _get_registry()
|
||||
if registry and "ChartStore" in registry.entries:
|
||||
return registry.entries["ChartStore"].model
|
||||
return None
|
||||
|
||||
|
||||
async def _get_chart_data_impl(countback: Optional[int] = None):
|
||||
"""Internal implementation for getting chart data.
|
||||
|
||||
This is a helper function that can be called by both get_chart_data tool
|
||||
and analyze_chart_data tool.
|
||||
|
||||
Returns:
|
||||
Tuple of (HistoryResult, chart_context dict, source_name)
|
||||
"""
|
||||
registry = _get_registry()
|
||||
datasource_registry = _get_datasource_registry()
|
||||
|
||||
if not registry:
|
||||
raise ValueError("SyncRegistry not initialized - cannot read ChartStore")
|
||||
|
||||
if not datasource_registry:
|
||||
raise ValueError("DataSourceRegistry not initialized - cannot query data")
|
||||
|
||||
# Read current chart state
|
||||
chart_store = registry.entries.get("ChartStore")
|
||||
if not chart_store:
|
||||
raise ValueError("ChartStore not found in registry")
|
||||
|
||||
chart_state = chart_store.model.model_dump(mode="json")
|
||||
chart_data = chart_state.get("chart_state", {})
|
||||
|
||||
symbol = chart_data.get("symbol", "")
|
||||
interval = chart_data.get("interval", "15")
|
||||
start_time = chart_data.get("start_time")
|
||||
end_time = chart_data.get("end_time")
|
||||
|
||||
if not symbol:
|
||||
raise ValueError(
|
||||
"No chart visible - ChartStore symbol is None. "
|
||||
"The user is likely on a narrow screen (mobile) where charts are hidden. "
|
||||
"Let them know they can view charts on a wider screen, or use get_historical_data() "
|
||||
"if they specify a symbol and timeframe."
|
||||
)
|
||||
|
||||
# Parse the symbol to extract exchange/source and symbol name
|
||||
# Format is "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD")
|
||||
if ":" not in symbol:
|
||||
raise ValueError(
|
||||
f"Invalid symbol format: '{symbol}'. Expected format is 'EXCHANGE:SYMBOL' "
|
||||
f"(e.g., 'BINANCE:BTC/USDT' or 'DEMO:BTC/USD')"
|
||||
)
|
||||
|
||||
exchange_prefix, symbol_name = symbol.split(":", 1)
|
||||
source_name = exchange_prefix.lower()
|
||||
|
||||
# Get the data source
|
||||
source = datasource_registry.get(source_name)
|
||||
if not source:
|
||||
available = datasource_registry.list_sources()
|
||||
raise ValueError(
|
||||
f"Data source '{source_name}' not found. Available sources: {available}. "
|
||||
f"Make sure the exchange in the symbol '{symbol}' matches an available source."
|
||||
)
|
||||
|
||||
# Determine time range - REQUIRE it to be set, no defaults
|
||||
if start_time is None or end_time is None:
|
||||
raise ValueError(
|
||||
f"Chart time range not set in ChartStore. start_time={start_time}, end_time={end_time}. "
|
||||
f"The user needs to load the chart first, or the frontend may not be sending the visible range. "
|
||||
f"Wait for the chart to fully load before analyzing data."
|
||||
)
|
||||
|
||||
from_time = int(start_time)
|
||||
end_time = int(end_time)
|
||||
logger.info(
|
||||
f"Using ChartStore time range: from_time={from_time}, end_time={end_time}, "
|
||||
f"countback={countback}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Querying data source '{source_name}' for symbol '{symbol_name}', "
|
||||
f"resolution '{interval}'"
|
||||
)
|
||||
|
||||
# Query the data source
|
||||
result = await source.get_bars(
|
||||
symbol=symbol_name,
|
||||
resolution=interval,
|
||||
from_time=from_time,
|
||||
to_time=end_time,
|
||||
countback=countback
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Received {len(result.bars)} bars from data source. "
|
||||
f"First bar time: {result.bars[0].time if result.bars else 'N/A'}, "
|
||||
f"Last bar time: {result.bars[-1].time if result.bars else 'N/A'}"
|
||||
)
|
||||
|
||||
# Build chart context to return along with result
|
||||
chart_context = {
|
||||
"symbol": symbol,
|
||||
"interval": interval,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time
|
||||
}
|
||||
|
||||
return result, chart_context, source_name
|
||||
|
||||
|
||||
@tool
|
||||
async def get_chart_data(countback: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""Get the candle/bar data for what the user is currently viewing on their chart.
|
||||
|
||||
This is a convenience tool that automatically:
|
||||
1. Reads the ChartStore to see what chart the user is viewing
|
||||
2. Parses the symbol to determine the data source (exchange prefix)
|
||||
3. Queries the appropriate data source for that symbol's data
|
||||
4. Returns the data for the visible time range and interval
|
||||
|
||||
This is the preferred way to access chart data when helping the user analyze
|
||||
what they're looking at, since it automatically uses their current chart context.
|
||||
|
||||
**IMPORTANT**: This tool will fail if ChartStore.symbol is None (no chart visible).
|
||||
This happens when the user is on a narrow screen (mobile) where charts are hidden.
|
||||
In that case, let the user know charts are only visible on wider screens, or use
|
||||
get_historical_data() if they specify a symbol and timeframe.
|
||||
|
||||
Args:
|
||||
countback: Optional limit on number of bars to return. If not specified,
|
||||
returns all bars in the visible time range.
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- chart_context: Current chart state (symbol, interval, time range)
|
||||
- symbol: The trading pair being viewed
|
||||
- resolution: The chart interval
|
||||
- bars: List of bar data with 'time' and 'data' fields
|
||||
- columns: Schema describing available data columns
|
||||
- source: Which data source was used
|
||||
|
||||
Raises:
|
||||
ValueError: If ChartStore or DataSourceRegistry is not initialized,
|
||||
if no chart is visible (symbol is None), or if the symbol format is invalid
|
||||
|
||||
Example:
|
||||
# User is viewing BINANCE:BTC/USDT on 15min chart
|
||||
data = get_chart_data()
|
||||
# Returns BTC/USDT data from binance source at 15min resolution
|
||||
# for the currently visible time range
|
||||
"""
|
||||
result, chart_context, source_name = await _get_chart_data_impl(countback)
|
||||
|
||||
# Return enriched result with chart context
|
||||
response = result.model_dump()
|
||||
response["chart_context"] = chart_context
|
||||
response["source"] = source_name
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@tool
|
||||
async def execute_python(code: str, countback: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""Execute Python code for technical analysis with automatic chart data loading.
|
||||
|
||||
**PRIMARY TOOL for all technical analysis, indicator computation, and chart generation.**
|
||||
|
||||
This is your go-to tool whenever the user asks about indicators, wants to see
|
||||
a chart, or needs any computational analysis of market data.
|
||||
|
||||
Pre-loaded Environment:
|
||||
- `pd` : pandas
|
||||
- `np` : numpy
|
||||
- `plt` : matplotlib.pyplot (figures auto-saved to plot_urls)
|
||||
- `talib` : TA-Lib technical analysis library
|
||||
- `indicator_registry`: 150+ registered indicators
|
||||
- `plot_ohlc(df)` : Helper function for beautiful candlestick charts
|
||||
- `registry` : SyncRegistry instance - access to all registered stores
|
||||
- `datasource_registry`: DataSourceRegistry - access to data sources (binance, etc.)
|
||||
- `order_store` : OrderStore instance - current orders list
|
||||
- `chart_store` : ChartStore instance - current chart state
|
||||
|
||||
Auto-loaded when user has a chart visible (ChartStore.symbol is not None):
|
||||
- `df` : pandas DataFrame with DatetimeIndex and columns:
|
||||
open, high, low, close, volume (OHLCV data ready to use)
|
||||
- `chart_context` : dict with symbol, interval, start_time, end_time
|
||||
|
||||
When NO chart is visible (narrow screen/mobile):
|
||||
- `df` : None
|
||||
- `chart_context` : None
|
||||
|
||||
If `df` is None, you can still load alternative data by:
|
||||
- Using chart_store to see what symbol/timeframe is configured
|
||||
- Using datasource_registry.get('binance') to access data sources
|
||||
- Calling source.get_bars(symbol, resolution, from_time, to_time) to load any data
|
||||
- This allows you to make plots of ANY chart even when not connected to chart view
|
||||
|
||||
The `plot_ohlc()` Helper:
|
||||
Create professional candlestick charts instantly:
|
||||
- `plot_ohlc(df)` - basic OHLC chart with volume
|
||||
- `plot_ohlc(df, title='BTC 15min')` - with custom title
|
||||
- `plot_ohlc(df, volume=False)` - price only, no volume
|
||||
- Returns a matplotlib Figure that's automatically saved to plot_urls
|
||||
|
||||
Args:
|
||||
code: Python code to execute
|
||||
countback: Optional limit on number of bars to load (default: all visible bars)
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- script_output : printed output + last expression result
|
||||
- result_dataframe : serialized DataFrame if last expression is a DataFrame
|
||||
- plot_urls : list of image URLs (e.g., ["/uploads/plot_abc123.png"])
|
||||
- chart_context : {symbol, interval, start_time, end_time} or None
|
||||
- error : traceback if execution failed
|
||||
|
||||
Examples:
|
||||
# RSI indicator with chart
|
||||
execute_python(\"\"\"
|
||||
df['RSI'] = talib.RSI(df['close'], 14)
|
||||
fig = plot_ohlc(df, title='BTC/USDT with RSI')
|
||||
print(f"Current RSI: {df['RSI'].iloc[-1]:.2f}")
|
||||
df[['close', 'RSI']].tail(5)
|
||||
\"\"\")
|
||||
|
||||
# Multiple indicators
|
||||
execute_python(\"\"\"
|
||||
df['SMA_20'] = df['close'].rolling(20).mean()
|
||||
df['SMA_50'] = df['close'].rolling(50).mean()
|
||||
df['BB_upper'] = df['close'].rolling(20).mean() + 2*df['close'].rolling(20).std()
|
||||
df['BB_lower'] = df['close'].rolling(20).mean() - 2*df['close'].rolling(20).std()
|
||||
|
||||
fig = plot_ohlc(df, title=f"{chart_context['symbol']} - Bollinger Bands")
|
||||
|
||||
current_price = df['close'].iloc[-1]
|
||||
sma20 = df['SMA_20'].iloc[-1]
|
||||
print(f"Price: {current_price:.2f}, SMA20: {sma20:.2f}")
|
||||
df[['close', 'SMA_20', 'BB_upper', 'BB_lower']].tail(10)
|
||||
\"\"\")
|
||||
|
||||
# Pattern detection
|
||||
execute_python(\"\"\"
|
||||
# Find swing highs
|
||||
df['swing_high'] = (df['high'] > df['high'].shift(1)) & (df['high'] > df['high'].shift(-1))
|
||||
swing_highs = df[df['swing_high']][['high']].tail(5)
|
||||
|
||||
fig = plot_ohlc(df, title='Swing High Detection')
|
||||
print("Recent swing highs:")
|
||||
print(swing_highs)
|
||||
\"\"\")
|
||||
|
||||
# Load alternative data when df is None or for different symbol/timeframe
|
||||
execute_python(\"\"\"
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Get data source
|
||||
binance = datasource_registry.get('binance')
|
||||
|
||||
# Load ETH data even if viewing BTC chart
|
||||
end_time = datetime.now()
|
||||
start_time = end_time - timedelta(days=7)
|
||||
|
||||
result = await binance.get_bars(
|
||||
symbol='ETH/USDT',
|
||||
resolution='60',  # '60' = 1-hour bars
|
||||
from_time=int(start_time.timestamp()),
|
||||
to_time=int(end_time.timestamp())
|
||||
)
|
||||
|
||||
# Convert to DataFrame
|
||||
rows = [{'time': pd.to_datetime(bar.time, unit='s'), **bar.data} for bar in result.bars]
|
||||
eth_df = pd.DataFrame(rows).set_index('time')
|
||||
|
||||
# Calculate RSI and plot
|
||||
eth_df['RSI'] = talib.RSI(eth_df['close'], 14)
|
||||
fig = plot_ohlc(eth_df, title='ETH/USDT 1h - RSI Analysis')
|
||||
print(f"ETH RSI: {eth_df['RSI'].iloc[-1]:.2f}")
|
||||
\"\"\")
|
||||
|
||||
# Access chart store to see current state
|
||||
execute_python(\"\"\"
|
||||
print(f"Current symbol: {chart_store.chart_state.symbol}")
|
||||
print(f"Current interval: {chart_store.chart_state.interval}")
|
||||
print(f"Orders: {len(order_store.orders)}")
|
||||
\"\"\")
|
||||
"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
try:
|
||||
import talib
|
||||
except ImportError:
|
||||
talib = None
|
||||
logger.warning("TA-Lib not available in execute_python environment")
|
||||
|
||||
# --- Attempt to load chart data ---
|
||||
df = None
|
||||
chart_context = None
|
||||
|
||||
registry = _get_registry()
|
||||
datasource_registry = _get_datasource_registry()
|
||||
|
||||
if registry and datasource_registry:
|
||||
try:
|
||||
result, chart_context, source_name = await _get_chart_data_impl(countback)
|
||||
bars = result.bars
|
||||
if bars:
|
||||
rows = []
|
||||
for bar in bars:
|
||||
rows.append({'time': pd.to_datetime(bar.time, unit='s'), **bar.data})
|
||||
df = pd.DataFrame(rows).set_index('time')
|
||||
for col in ['open', 'high', 'low', 'close', 'volume']:
|
||||
if col in df.columns:
|
||||
df[col] = pd.to_numeric(df[col], errors='coerce')
|
||||
logger.info(f"execute_python: loaded {len(df)} bars for {chart_context['symbol']}")
|
||||
except Exception as e:
|
||||
logger.info(f"execute_python: no chart data loaded ({e})")
|
||||
|
||||
# --- Import chart utilities ---
|
||||
from .chart_utils import plot_ohlc
|
||||
|
||||
# --- Get indicator registry ---
|
||||
indicator_registry = _get_indicator_registry()
|
||||
|
||||
# --- Get DataStores ---
|
||||
order_store = _get_order_store()
|
||||
chart_store = _get_chart_store()
|
||||
|
||||
# --- Build globals ---
|
||||
script_globals: Dict[str, Any] = {
|
||||
'pd': pd,
|
||||
'np': np,
|
||||
'plt': plt,
|
||||
'talib': talib,
|
||||
'indicator_registry': indicator_registry,
|
||||
'registry': registry,
|
||||
'datasource_registry': datasource_registry,
|
||||
'order_store': order_store,
|
||||
'chart_store': chart_store,
|
||||
'df': df,
|
||||
'chart_context': chart_context,
|
||||
'plot_ohlc': plot_ohlc,
|
||||
}
|
||||
|
||||
# --- Execute ---
|
||||
uploads_dir = Path(__file__).parent.parent.parent.parent / "data" / "uploads"  # tools/ -> agent/ -> src/ -> backend root, then data/uploads
|
||||
uploads_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
stdout_capture = io.StringIO()
|
||||
result_df = None
|
||||
error_msg = None
|
||||
plot_urls = []
|
||||
|
||||
try:
|
||||
with redirect_stdout(stdout_capture), redirect_stderr(stdout_capture):
|
||||
exec(code, script_globals)
|
||||
|
||||
# Capture last expression
|
||||
lines = code.strip().splitlines()
|
||||
if lines:
|
||||
last = lines[-1].strip()
|
||||
if last and not any(last.startswith(kw) for kw in (
|
||||
'if', 'for', 'while', 'def', 'class', 'import',
|
||||
'from', 'with', 'try', 'return', '#'
|
||||
)):
|
||||
try:
|
||||
last_val = eval(last, script_globals)
|
||||
if isinstance(last_val, pd.DataFrame):
|
||||
result_df = last_val
|
||||
elif last_val is not None:
|
||||
stdout_capture.write(str(last_val))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Save plots
|
||||
for fig_num in plt.get_fignums():
|
||||
fig = plt.figure(fig_num)
|
||||
filename = f"plot_{uuid.uuid4()}.png"
|
||||
fig.savefig(uploads_dir / filename, format='png', bbox_inches='tight', dpi=100)
|
||||
plot_urls.append(f"/uploads/{filename}")
|
||||
plt.close(fig)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_msg = f"{type(e).__name__}: {e}\n{traceback.format_exc()}"
|
||||
|
||||
# --- Build response ---
|
||||
response: Dict[str, Any] = {
|
||||
'script_output': stdout_capture.getvalue(),
|
||||
'chart_context': chart_context,
|
||||
'plot_urls': plot_urls,
|
||||
}
|
||||
if result_df is not None:
|
||||
response['result_dataframe'] = {
|
||||
'columns': result_df.columns.tolist(),
|
||||
'index': result_df.index.astype(str).tolist(),
|
||||
'data': result_df.values.tolist(),
|
||||
'shape': result_df.shape,
|
||||
}
|
||||
if error_msg:
|
||||
response['error'] = error_msg
|
||||
|
||||
return response
|
||||
|
||||
|
||||
CHART_TOOLS = [
|
||||
get_chart_data,
|
||||
execute_python
|
||||
]
|
||||
224
backend.old/src/agent/tools/chart_utils.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""Chart plotting utilities for creating standard, beautiful OHLC charts."""
|
||||
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import Optional, Tuple
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def plot_ohlc(
|
||||
df: pd.DataFrame,
|
||||
title: Optional[str] = None,
|
||||
volume: bool = True,
|
||||
figsize: Tuple[int, int] = (14, 8),
|
||||
style: str = 'seaborn-v0_8-darkgrid',
|
||||
**kwargs
|
||||
) -> plt.Figure:
|
||||
"""Create a beautiful standard OHLC candlestick chart.
|
||||
|
||||
This is a convenience function that generates a professional-looking candlestick
|
||||
chart with consistent styling across all generated charts. It uses mplfinance
|
||||
with seaborn aesthetics for a polished appearance.
|
||||
|
||||
Args:
|
||||
df: pandas DataFrame with DatetimeIndex and columns: open, high, low, close, volume
|
||||
title: Optional chart title. If None, uses symbol from chart context
|
||||
volume: Whether to include volume subplot (default: True)
|
||||
figsize: Figure size as (width, height) in inches (default: (14, 8))
|
||||
style: Base matplotlib style to use (default: 'seaborn-v0_8-darkgrid')
|
||||
**kwargs: Additional arguments to pass to mplfinance.plot()
|
||||
|
||||
Returns:
|
||||
matplotlib.figure.Figure: The created figure object
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Basic usage in analyze_chart_data script
|
||||
fig = plot_ohlc(df, title='BTC/USDT 15min')
|
||||
|
||||
# Customize with additional indicators
|
||||
fig = plot_ohlc(df, volume=True, title='Price Action')
|
||||
|
||||
# Add custom overlays after calling plot_ohlc
|
||||
df['SMA20'] = df['close'].rolling(20).mean()
|
||||
fig = plot_ohlc(df, title='With SMA')
|
||||
# Note: For mplfinance overlays, use the mav or addplot parameters
|
||||
```
|
||||
|
||||
Note:
|
||||
The DataFrame must have a DatetimeIndex and the standard OHLCV columns.
|
||||
Column names should be lowercase: open, high, low, close, volume
|
||||
"""
|
||||
try:
|
||||
import mplfinance as mpf
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"mplfinance is required for plot_ohlc(). "
|
||||
"Install it with: pip install mplfinance"
|
||||
)
|
||||
|
||||
# Validate DataFrame structure
|
||||
required_cols = ['open', 'high', 'low', 'close']
|
||||
missing_cols = [col for col in required_cols if col not in df.columns]
|
||||
if missing_cols:
|
||||
raise ValueError(
|
||||
f"DataFrame missing required columns: {missing_cols}. "
|
||||
f"Required: {required_cols}"
|
||||
)
|
||||
|
||||
if not isinstance(df.index, pd.DatetimeIndex):
|
||||
raise ValueError(
|
||||
"DataFrame must have a DatetimeIndex. "
|
||||
"Convert with: df.index = pd.to_datetime(df.index)"
|
||||
)
|
||||
|
||||
# Ensure volume column exists for volume plot
|
||||
if volume and 'volume' not in df.columns:
|
||||
logger.warning("volume=True but 'volume' column not found in DataFrame. Disabling volume.")
|
||||
volume = False
|
||||
|
||||
# Create custom style with seaborn aesthetics
|
||||
# Using a professional color scheme: green for up candles, red for down candles
|
||||
mc = mpf.make_marketcolors(
|
||||
up='#26a69a', # Teal green (calmer than bright green)
|
||||
down='#ef5350', # Coral red (softer than pure red)
|
||||
edge='inherit', # Match candle color for edges
|
||||
wick='inherit', # Match candle color for wicks
|
||||
volume='in', # Volume bars colored by price direction
|
||||
alpha=0.9 # Slight transparency for elegance
|
||||
)
|
||||
|
||||
s = mpf.make_mpf_style(
|
||||
base_mpf_style='charles', # Clean base style
|
||||
marketcolors=mc,
|
||||
rc={
|
||||
'font.size': 10,
|
||||
'axes.labelsize': 11,
|
||||
'axes.titlesize': 12,
|
||||
'xtick.labelsize': 9,
|
||||
'ytick.labelsize': 9,
|
||||
'legend.fontsize': 10,
|
||||
'figure.facecolor': '#f0f0f0',
|
||||
'axes.facecolor': '#ffffff',
|
||||
'axes.grid': True,
|
||||
'grid.alpha': 0.3,
|
||||
'grid.linestyle': '--',
|
||||
}
|
||||
)
|
||||
|
||||
# Prepare plot parameters
|
||||
plot_params = {
|
||||
'type': 'candle',
|
||||
'style': s,
|
||||
'volume': volume,
|
||||
'figsize': figsize,
|
||||
'tight_layout': True,
|
||||
'returnfig': True,
|
||||
'warn_too_much_data': 1000, # Warn if > 1000 candles for performance
|
||||
}
|
||||
|
||||
# Add title if provided
|
||||
if title:
|
||||
plot_params['title'] = title
|
||||
|
||||
# Merge any additional kwargs
|
||||
plot_params.update(kwargs)
|
||||
|
||||
# Create the plot
|
||||
logger.info(
|
||||
f"Creating OHLC chart with {len(df)} candles, "
|
||||
f"date range: {df.index.min()} to {df.index.max()}, "
|
||||
f"volume: {volume}"
|
||||
)
|
||||
|
||||
fig, axes = mpf.plot(df, **plot_params)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def add_indicators_to_plot(
|
||||
df: pd.DataFrame,
|
||||
indicators: dict,
|
||||
**plot_kwargs
|
||||
) -> plt.Figure:
|
||||
"""Create an OHLC chart with technical indicators overlaid.
|
||||
|
||||
This extends plot_ohlc() to include common technical indicators using
|
||||
mplfinance's addplot functionality for proper overlay on candlestick charts.
|
||||
|
||||
Args:
|
||||
df: pandas DataFrame with OHLCV data and indicator columns
|
||||
indicators: Dictionary mapping indicator names to parameters
|
||||
Example: {
|
||||
'SMA_20': {'color': 'blue', 'width': 1.5},
|
||||
'EMA_50': {'color': 'orange', 'width': 1.5}
|
||||
}
|
||||
**plot_kwargs: Additional arguments for plot_ohlc()
|
||||
|
||||
Returns:
|
||||
matplotlib.figure.Figure: The created figure object
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Calculate indicators
|
||||
df['SMA_20'] = df['close'].rolling(20).mean()
|
||||
df['SMA_50'] = df['close'].rolling(50).mean()
|
||||
|
||||
# Plot with indicators
|
||||
fig = add_indicators_to_plot(
|
||||
df,
|
||||
indicators={
|
||||
'SMA_20': {'color': 'blue', 'width': 1.5, 'label': '20 SMA'},
|
||||
'SMA_50': {'color': 'red', 'width': 1.5, 'label': '50 SMA'}
|
||||
},
|
||||
title='BTC/USDT with Moving Averages'
|
||||
)
|
||||
```
|
||||
"""
|
||||
try:
|
||||
import mplfinance as mpf
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"mplfinance is required. Install it with: pip install mplfinance"
|
||||
)
|
||||
|
||||
# Build addplot list for indicators
|
||||
addplots = []
|
||||
for indicator_col, params in indicators.items():
|
||||
if indicator_col not in df.columns:
|
||||
logger.warning(f"Indicator column '{indicator_col}' not found in DataFrame. Skipping.")
|
||||
continue
|
||||
|
||||
color = params.get('color', 'blue')
|
||||
width = params.get('width', 1.0)
|
||||
panel = params.get('panel', 0) # 0 = main panel with candles
|
||||
ylabel = params.get('ylabel', '')
|
||||
|
||||
addplots.append(
|
||||
mpf.make_addplot(
|
||||
df[indicator_col],
|
||||
color=color,
|
||||
width=width,
|
||||
panel=panel,
|
||||
ylabel=ylabel
|
||||
)
|
||||
)
|
||||
|
||||
# Pass addplot to plot_ohlc via kwargs
|
||||
if addplots:
|
||||
plot_kwargs['addplot'] = addplots
|
||||
|
||||
return plot_ohlc(df, **plot_kwargs)
|
||||
|
||||
|
||||
# Convenience presets for common chart types
|
||||
def plot_price_volume(df: pd.DataFrame, title: Optional[str] = None) -> plt.Figure:
|
||||
"""Create a standard price + volume chart."""
|
||||
return plot_ohlc(df, title=title, volume=True, figsize=(14, 8))
|
||||
|
||||
|
||||
def plot_price_only(df: pd.DataFrame, title: Optional[str] = None) -> plt.Figure:
|
||||
"""Create a price-only candlestick chart without volume."""
|
||||
return plot_ohlc(df, title=title, volume=False, figsize=(14, 6))
|
||||
154
backend.old/src/agent/tools/chart_utils_example.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""
|
||||
Example usage of chart_utils.py plotting functions.
|
||||
|
||||
This demonstrates how the LLM can use the plot_ohlc() convenience function
|
||||
in analyze_chart_data scripts to create beautiful, standard OHLC charts.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
def create_sample_data(days=30):
|
||||
"""Create sample OHLCV data for testing."""
|
||||
dates = pd.date_range(end=datetime.now(), periods=days * 24, freq='1H')
|
||||
|
||||
# Simulate price movement
|
||||
np.random.seed(42)
|
||||
close = 50000 + np.cumsum(np.random.randn(len(dates)) * 100)
|
||||
|
||||
data = {
|
||||
'open': close + np.random.randn(len(dates)) * 50,
|
||||
'high': close + np.abs(np.random.randn(len(dates))) * 100,
|
||||
'low': close - np.abs(np.random.randn(len(dates))) * 100,
|
||||
'close': close,
|
||||
'volume': np.abs(np.random.randn(len(dates))) * 1000000
|
||||
}
|
||||
|
||||
df = pd.DataFrame(data, index=dates)
|
||||
|
||||
# Ensure high is highest and low is lowest
|
||||
df['high'] = df[['open', 'high', 'low', 'close']].max(axis=1)
|
||||
df['low'] = df[['open', 'high', 'low', 'close']].min(axis=1)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from chart_utils import plot_ohlc, add_indicators_to_plot, plot_price_volume
|
||||
|
||||
# Create sample data
|
||||
df = create_sample_data(days=30)
|
||||
|
||||
print("=" * 60)
|
||||
print("Example 1: Basic OHLC chart with volume")
|
||||
print("=" * 60)
|
||||
print("\nScript the LLM would generate:")
|
||||
print("""
|
||||
fig = plot_ohlc(df, title='BTC/USDT 1H', volume=True)
|
||||
df.tail(5)
|
||||
""")
|
||||
|
||||
# Execute it
|
||||
fig = plot_ohlc(df, title='BTC/USDT 1H', volume=True)
|
||||
print("\n✓ Chart created successfully!")
|
||||
print(f" Figure size: {fig.get_size_inches()}")
|
||||
print(f" Number of axes: {len(fig.axes)}")
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Example 2: OHLC chart with indicators")
|
||||
print("=" * 60)
|
||||
print("\nScript the LLM would generate:")
|
||||
print("""
|
||||
# Calculate indicators
|
||||
df['SMA_20'] = df['close'].rolling(20).mean()
|
||||
df['SMA_50'] = df['close'].rolling(50).mean()
|
||||
df['EMA_12'] = df['close'].ewm(span=12, adjust=False).mean()
|
||||
|
||||
# Plot with indicators
|
||||
fig = add_indicators_to_plot(
|
||||
df,
|
||||
indicators={
|
||||
'SMA_20': {'color': 'blue', 'width': 1.5},
|
||||
'SMA_50': {'color': 'red', 'width': 1.5},
|
||||
'EMA_12': {'color': 'green', 'width': 1.0}
|
||||
},
|
||||
title='BTC/USDT with Moving Averages',
|
||||
volume=True
|
||||
)
|
||||
|
||||
df[['close', 'SMA_20', 'SMA_50', 'EMA_12']].tail(5)
|
||||
""")
|
||||
|
||||
# Execute it
|
||||
df['SMA_20'] = df['close'].rolling(20).mean()
|
||||
df['SMA_50'] = df['close'].rolling(50).mean()
|
||||
df['EMA_12'] = df['close'].ewm(span=12, adjust=False).mean()
|
||||
|
||||
fig = add_indicators_to_plot(
|
||||
df,
|
||||
indicators={
|
||||
'SMA_20': {'color': 'blue', 'width': 1.5},
|
||||
'SMA_50': {'color': 'red', 'width': 1.5},
|
||||
'EMA_12': {'color': 'green', 'width': 1.0}
|
||||
},
|
||||
title='BTC/USDT with Moving Averages',
|
||||
volume=True
|
||||
)
|
||||
|
||||
print("\n✓ Chart with indicators created successfully!")
|
||||
print(f" Last close: ${df['close'].iloc[-1]:,.2f}")
|
||||
print(f" SMA 20: ${df['SMA_20'].iloc[-1]:,.2f}")
|
||||
print(f" SMA 50: ${df['SMA_50'].iloc[-1]:,.2f}")
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Example 3: Price-only chart (no volume)")
|
||||
print("=" * 60)
|
||||
print("\nScript the LLM would generate:")
|
||||
print("""
|
||||
from chart_utils import plot_price_only
|
||||
|
||||
fig = plot_price_only(df, title='Clean Price Action')
|
||||
""")
|
||||
|
||||
# Execute it
|
||||
from chart_utils import plot_price_only
|
||||
fig = plot_price_only(df, title='Clean Price Action')
|
||||
|
||||
print("\n✓ Price-only chart created successfully!")
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Summary")
|
||||
print("=" * 60)
|
||||
print("""
|
||||
The chart_utils module provides:
|
||||
|
||||
1. plot_ohlc() - Main function for beautiful candlestick charts
|
||||
- Professional seaborn-inspired styling
|
||||
- Consistent color scheme (teal up, coral down)
|
||||
- Optional volume subplot
|
||||
- Customizable figure size
|
||||
|
||||
2. add_indicators_to_plot() - OHLC charts with technical indicators
|
||||
- Overlay multiple indicators
|
||||
- Customizable colors and line widths
|
||||
- Proper integration with mplfinance
|
||||
|
||||
3. Preset functions for common chart types:
|
||||
- plot_price_volume() - Standard price + volume
|
||||
- plot_price_only() - Candlesticks without volume
|
||||
|
||||
Benefits:
|
||||
✓ Consistent look and feel across all charts
|
||||
✓ Less code for the LLM to generate
|
||||
✓ Professional appearance out of the box
|
||||
✓ Easy to customize when needed
|
||||
✓ Works seamlessly with analyze_chart_data tool
|
||||
|
||||
The LLM can now simply call plot_ohlc(df) instead of writing
|
||||
custom matplotlib code for every chart request!
|
||||
""")
|
||||
158
backend.old/src/agent/tools/datasource_tools.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Data source and market data tools."""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from langchain_core.tools import tool
|
||||
|
||||
|
||||
def _get_datasource_registry():
|
||||
"""Get the global datasource registry instance."""
|
||||
from . import _datasource_registry
|
||||
return _datasource_registry
|
||||
|
||||
|
||||
@tool
|
||||
def list_data_sources() -> List[str]:
|
||||
"""List all available data sources.
|
||||
|
||||
Returns:
|
||||
List of data source names that can be queried for market data
|
||||
"""
|
||||
registry = _get_datasource_registry()
|
||||
if not registry:
|
||||
return []
|
||||
return registry.list_sources()
|
||||
|
||||
|
||||
@tool
|
||||
async def search_symbols(
|
||||
query: str,
|
||||
type: Optional[str] = None,
|
||||
exchange: Optional[str] = None,
|
||||
limit: int = 30,
|
||||
) -> Dict[str, Any]:
|
||||
"""Search for trading symbols across all data sources.
|
||||
|
||||
Automatically searches all available data sources and returns aggregated results.
|
||||
Use this to find symbols before calling get_symbol_info or get_historical_data.
|
||||
|
||||
Args:
|
||||
query: Search query (e.g., "BTC", "AAPL", "EUR")
|
||||
type: Optional filter by instrument type (e.g., "crypto", "stock", "forex")
|
||||
exchange: Optional filter by exchange (e.g., "binance", "nasdaq")
|
||||
limit: Maximum number of results per source (default: 30)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping source names to lists of matching symbols.
|
||||
Each symbol includes: symbol, full_name, description, exchange, type.
|
||||
Use the source name and symbol from results with get_symbol_info or get_historical_data.
|
||||
|
||||
Example response:
|
||||
{
|
||||
"demo": [
|
||||
{
|
||||
"symbol": "BTC/USDT",
|
||||
"full_name": "Bitcoin / Tether USD",
|
||||
"description": "Bitcoin perpetual futures",
|
||||
"exchange": "demo",
|
||||
"type": "crypto"
|
||||
}
|
||||
]
|
||||
}
|
||||
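Example:
# Illustrative query; the result maps each source name to its matching symbols
results = search_symbols("BTC", type="crypto")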
"""
|
||||
registry = _get_datasource_registry()
|
||||
if not registry:
|
||||
raise ValueError("DataSourceRegistry not initialized")
|
||||
|
||||
# Always search all sources
|
||||
results = await registry.search_all(query, type, exchange, limit)
|
||||
return {name: [r.model_dump() for r in matches] for name, matches in results.items()}
|
||||
|
||||
|
||||
@tool
|
||||
async def get_symbol_info(source_name: str, symbol: str) -> Dict[str, Any]:
|
||||
"""Get complete metadata for a trading symbol.
|
||||
|
||||
This retrieves full information about a symbol including:
|
||||
- Description and type
|
||||
- Supported time resolutions
|
||||
- Available data columns (OHLCV, volume, funding rates, etc.)
|
||||
- Trading session information
|
||||
- Price scale and precision
|
||||
|
||||
Args:
|
||||
source_name: Name of the data source (use list_data_sources to see available)
|
||||
symbol: Symbol identifier (e.g., "BTC/USDT", "AAPL", "EUR/USD")
|
||||
|
||||
Returns:
|
||||
Dictionary containing complete symbol metadata including column schema
|
||||
|
||||
Raises:
|
||||
ValueError: If source_name or symbol is not found
|
||||
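Example:
# Illustrative values; "demo" and "BTC/USDT" are the sample source/symbol used in these docs
info = get_symbol_info("demo", "BTC/USDT")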
"""
|
||||
registry = _get_datasource_registry()
|
||||
if not registry:
|
||||
raise ValueError("DataSourceRegistry not initialized")
|
||||
|
||||
symbol_info = await registry.resolve_symbol(source_name, symbol)
|
||||
return symbol_info.model_dump()
|
||||
|
||||
|
||||
@tool
|
||||
async def get_historical_data(
|
||||
source_name: str,
|
||||
symbol: str,
|
||||
resolution: str,
|
||||
from_time: int,
|
||||
to_time: int,
|
||||
countback: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get historical bar/candle data for a symbol.
|
||||
|
||||
Retrieves time-series data between the specified timestamps. The data
|
||||
includes all columns defined for the symbol (OHLCV + any custom columns).
|
||||
|
||||
Args:
|
||||
source_name: Name of the data source
|
||||
symbol: Symbol identifier
|
||||
resolution: Time resolution (e.g., "1" = 1min, "5" = 5min, "60" = 1hour, "1D" = 1day)
|
||||
from_time: Start time as Unix timestamp in seconds
|
||||
to_time: End time as Unix timestamp in seconds
|
||||
countback: Optional limit on number of bars to return
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- symbol: The requested symbol
|
||||
- resolution: The time resolution
|
||||
- bars: List of bar data with 'time' and 'data' fields
|
||||
- columns: Schema describing available data columns
|
||||
- nextTime: If present, indicates more data is available for pagination
|
||||
|
||||
Raises:
|
||||
ValueError: If source, symbol, or resolution is invalid
|
||||
|
||||
Example:
|
||||
# Get 1-hour BTC data for the last 24 hours
|
||||
import time
|
||||
to_time = int(time.time())
|
||||
from_time = to_time - 86400 # 24 hours ago
|
||||
data = get_historical_data("demo", "BTC/USDT", "60", from_time, to_time)
|
||||
"""
|
||||
registry = _get_datasource_registry()
|
||||
if not registry:
|
||||
raise ValueError("DataSourceRegistry not initialized")
|
||||
|
||||
source = registry.get(source_name)
|
||||
if not source:
|
||||
available = registry.list_sources()
|
||||
raise ValueError(f"Data source '{source_name}' not found. Available sources: {available}")
|
||||
|
||||
result = await source.get_bars(symbol, resolution, from_time, to_time, countback)
|
||||
return result.model_dump()
|
||||
|
||||
|
||||
DATASOURCE_TOOLS = [
|
||||
list_data_sources,
|
||||
search_symbols,
|
||||
get_symbol_info,
|
||||
get_historical_data,
|
||||
]
|
||||
435
backend.old/src/agent/tools/indicator_tools.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""Technical indicator tools.
|
||||
|
||||
These tools allow the agent to:
|
||||
1. Discover available indicators (list, search, get info)
|
||||
2. Add indicators to the chart
|
||||
3. Update/remove indicators
|
||||
4. Query currently applied indicators
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from langchain_core.tools import tool
|
||||
import logging
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_indicator_registry():
|
||||
"""Get the global indicator registry instance."""
|
||||
from . import _indicator_registry
|
||||
return _indicator_registry
|
||||
|
||||
|
||||
def _get_registry():
|
||||
"""Get the global sync registry instance."""
|
||||
from . import _registry
|
||||
return _registry
|
||||
|
||||
|
||||
def _get_indicator_store():
|
||||
"""Get the global IndicatorStore instance."""
|
||||
registry = _get_registry()
|
||||
if registry and "IndicatorStore" in registry.entries:
|
||||
return registry.entries["IndicatorStore"].model
|
||||
return None
|
||||
|
||||
|
||||
@tool
|
||||
def list_indicators() -> List[str]:
|
||||
"""List all available technical indicators.
|
||||
|
||||
Returns:
|
||||
List of indicator names that can be used in analysis and strategies
|
||||
"""
|
||||
registry = _get_indicator_registry()
|
||||
if not registry:
|
||||
return []
|
||||
return registry.list_indicators()
|
||||
|
||||
|
||||
@tool
|
||||
def get_indicator_info(indicator_name: str) -> Dict[str, Any]:
|
||||
"""Get detailed information about a specific indicator.
|
||||
|
||||
Retrieves metadata including description, parameters, category, use cases,
|
||||
input/output schemas, and references.
|
||||
|
||||
Args:
|
||||
indicator_name: Name of the indicator (e.g., "RSI", "SMA", "MACD")
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- name: Indicator name
|
||||
- display_name: Human-readable name
|
||||
- description: What the indicator computes and why it's useful
|
||||
- category: Category (momentum, trend, volatility, volume, etc.)
|
||||
- parameters: List of configurable parameters with types and defaults
|
||||
- use_cases: Common trading scenarios where this indicator helps
|
||||
- tags: Searchable tags
|
||||
- input_schema: Required input columns (e.g., OHLCV requirements)
|
||||
- output_schema: Columns this indicator produces
|
||||
|
||||
Raises:
|
||||
ValueError: If indicator_name is not found
|
||||
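Example:
# 'parameters' lists the configurable inputs (e.g. timeperiod for RSI)
info = get_indicator_info("RSI")
print(info["parameters"])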
"""
|
||||
registry = _get_indicator_registry()
|
||||
if not registry:
|
||||
raise ValueError("IndicatorRegistry not initialized")
|
||||
|
||||
metadata = registry.get_metadata(indicator_name)
|
||||
if not metadata:
|
||||
total_count = len(registry.list_indicators())
|
||||
raise ValueError(
|
||||
f"Indicator '{indicator_name}' not found. "
|
||||
f"Total available: {total_count} indicators. "
|
||||
f"Use search_indicators() to find indicators by name, category, or tag."
|
||||
)
|
||||
|
||||
input_schema = registry.get_input_schema(indicator_name)
|
||||
output_schema = registry.get_output_schema(indicator_name)
|
||||
|
||||
result = metadata.model_dump()
|
||||
result["input_schema"] = input_schema.model_dump() if input_schema else None
|
||||
result["output_schema"] = output_schema.model_dump() if output_schema else None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@tool
|
||||
def search_indicators(
|
||||
query: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
tag: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search for indicators by text query, category, or tag.
|
||||
|
||||
Returns lightweight summaries - use get_indicator_info() for full details on specific indicators.
|
||||
|
||||
Use this to discover relevant indicators for your trading strategy or analysis.
|
||||
Can filter by category (momentum, trend, volatility, etc.) or search by keywords.
|
||||
|
||||
Args:
|
||||
query: Optional text search across names, descriptions, and use cases
|
||||
category: Optional category filter (momentum, trend, volatility, volume, pattern, etc.)
|
||||
tag: Optional tag filter (e.g., "oscillator", "moving-average", "talib")
|
||||
|
||||
Returns:
|
||||
List of lightweight indicator summaries. Each contains:
|
||||
- name: Indicator name (use with get_indicator_info() for full details)
|
||||
- display_name: Human-readable name
|
||||
- description: Brief one-line description
|
||||
- category: Category (momentum, trend, volatility, etc.)
|
||||
|
||||
Example:
|
||||
# Find all momentum indicators
|
||||
results = search_indicators(category="momentum")
|
||||
# Returns [{name: "RSI", display_name: "RSI", description: "...", category: "momentum"}, ...]
|
||||
|
||||
# Then get details on interesting ones
|
||||
rsi_details = get_indicator_info("RSI") # Full parameters, schemas, use cases
|
||||
|
||||
# Search for moving average indicators
|
||||
search_indicators(query="moving average")
|
||||
|
||||
# Find all TA-Lib indicators
|
||||
search_indicators(tag="talib")
|
||||
"""
|
||||
registry = _get_indicator_registry()
|
||||
if not registry:
|
||||
raise ValueError("IndicatorRegistry not initialized")
|
||||
|
||||
results = []
|
||||
|
||||
if query:
|
||||
results = registry.search_by_text(query)
|
||||
elif category:
|
||||
results = registry.search_by_category(category)
|
||||
elif tag:
|
||||
results = registry.search_by_tag(tag)
|
||||
else:
|
||||
# Return all indicators if no filter
|
||||
results = registry.get_all_metadata()
|
||||
|
||||
# Return lightweight summaries only
|
||||
return [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"description": r.description,
|
||||
"category": r.category
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
|
||||
@tool
|
||||
def get_indicator_categories() -> Dict[str, int]:
|
||||
"""Get all indicator categories and their counts.
|
||||
|
||||
Returns a summary of available indicator categories, useful for
|
||||
exploring what types of indicators are available.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping category name to count of indicators in that category.
|
||||
Example: {"momentum": 25, "trend": 15, "volatility": 8, ...}
|
||||
"""
|
||||
registry = _get_indicator_registry()
|
||||
if not registry:
|
||||
raise ValueError("IndicatorRegistry not initialized")
|
||||
|
||||
categories: Dict[str, int] = {}
|
||||
for metadata in registry.get_all_metadata():
|
||||
category = metadata.category
|
||||
categories[category] = categories.get(category, 0) + 1
|
||||
|
||||
return categories
|
||||
|
||||
|
||||
@tool
|
||||
async def add_indicator_to_chart(
|
||||
indicator_id: str,
|
||||
talib_name: str,
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
symbol: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Add a technical indicator to the chart.
|
||||
|
||||
This will create a new indicator instance and display it on the TradingView chart.
|
||||
The indicator will be synchronized with the frontend in real-time.
|
||||
|
||||
Args:
|
||||
indicator_id: Unique identifier for this indicator instance (e.g., 'rsi_14', 'sma_50')
|
||||
talib_name: Name of the TA-Lib indicator (e.g., 'RSI', 'SMA', 'MACD', 'BBANDS')
|
||||
Use search_indicators() or get_indicator_info() to find available indicators
|
||||
parameters: Optional dictionary of indicator parameters
|
||||
Example for RSI: {'timeperiod': 14}
|
||||
Example for SMA: {'timeperiod': 50}
|
||||
Example for MACD: {'fastperiod': 12, 'slowperiod': 26, 'signalperiod': 9}
|
||||
Example for BBANDS: {'timeperiod': 20, 'nbdevup': 2, 'nbdevdn': 2}
|
||||
symbol: Optional symbol to apply the indicator to (defaults to current chart symbol)
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: 'created' or 'updated'
|
||||
- indicator: The complete indicator object
|
||||
|
||||
Example:
|
||||
# Add RSI(14)
|
||||
await add_indicator_to_chart(
|
||||
indicator_id='rsi_14',
|
||||
talib_name='RSI',
|
||||
parameters={'timeperiod': 14}
|
||||
)
|
||||
|
||||
# Add 50-period SMA
|
||||
await add_indicator_to_chart(
|
||||
indicator_id='sma_50',
|
||||
talib_name='SMA',
|
||||
parameters={'timeperiod': 50}
|
||||
)
|
||||
|
||||
# Add MACD with default parameters
|
||||
await add_indicator_to_chart(
|
||||
indicator_id='macd_default',
|
||||
talib_name='MACD'
|
||||
)
|
||||
"""
|
||||
from schema.indicator import IndicatorInstance
import time  # needed for the created_at/modified_at timestamps below; harmless if already imported at module level
|
||||
|
||||
registry = _get_registry()
|
||||
if not registry:
|
||||
raise ValueError("SyncRegistry not initialized")
|
||||
|
||||
indicator_store = _get_indicator_store()
|
||||
if not indicator_store:
|
||||
raise ValueError("IndicatorStore not initialized")
|
||||
|
||||
# Verify the indicator exists
|
||||
indicator_registry = _get_indicator_registry()
|
||||
if not indicator_registry:
|
||||
raise ValueError("IndicatorRegistry not initialized")
|
||||
|
||||
metadata = indicator_registry.get_metadata(talib_name)
|
||||
if not metadata:
|
||||
raise ValueError(
|
||||
f"Indicator '{talib_name}' not found. "
|
||||
f"Use search_indicators() to find available indicators."
|
||||
)
|
||||
|
||||
# Check if updating existing indicator
|
||||
existing_indicator = indicator_store.indicators.get(indicator_id)
|
||||
is_update = existing_indicator is not None
|
||||
|
||||
# If symbol is not provided, try to get it from ChartStore
|
||||
if symbol is None and "ChartStore" in registry.entries:
|
||||
chart_store = registry.entries["ChartStore"].model
|
||||
if hasattr(chart_store, 'chart_state') and hasattr(chart_store.chart_state, 'symbol'):
|
||||
symbol = chart_store.chart_state.symbol
|
||||
logger.info(f"Using current chart symbol for indicator: {symbol}")
|
||||
|
||||
now = int(time.time())
|
||||
|
||||
# Create indicator instance
|
||||
indicator = IndicatorInstance(
|
||||
id=indicator_id,
|
||||
talib_name=talib_name,
|
||||
instance_name=f"{talib_name}_{indicator_id}",
|
||||
parameters=parameters or {},
|
||||
visible=True,
|
||||
pane='chart', # Most indicators go on the chart pane
|
||||
symbol=symbol,
|
||||
created_at=existing_indicator.get('created_at') if existing_indicator else now,
|
||||
modified_at=now
|
||||
)
|
||||
|
||||
# Update the store
|
||||
indicator_store.indicators[indicator_id] = indicator.model_dump(mode="json")
|
||||
|
||||
# Trigger sync
|
||||
await registry.push_all()
|
||||
|
||||
logger.info(
|
||||
f"{'Updated' if is_update else 'Created'} indicator '{indicator_id}' "
|
||||
f"(TA-Lib: {talib_name}) with parameters: {parameters}"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "updated" if is_update else "created",
|
||||
"indicator": indicator.model_dump(mode="json")
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
async def remove_indicator_from_chart(indicator_id: str) -> Dict[str, str]:
|
||||
"""Remove an indicator from the chart.
|
||||
|
||||
Args:
|
||||
indicator_id: ID of the indicator instance to remove
|
||||
|
||||
Returns:
|
||||
Dictionary with status message
|
||||
|
||||
Raises:
|
||||
ValueError: If indicator doesn't exist
|
||||
|
||||
Example:
|
||||
await remove_indicator_from_chart('rsi_14')
|
||||
"""
|
||||
registry = _get_registry()
|
||||
if not registry:
|
||||
raise ValueError("SyncRegistry not initialized")
|
||||
|
||||
indicator_store = _get_indicator_store()
|
||||
if not indicator_store:
|
||||
raise ValueError("IndicatorStore not initialized")
|
||||
|
||||
if indicator_id not in indicator_store.indicators:
|
||||
raise ValueError(f"Indicator '{indicator_id}' not found")
|
||||
|
||||
# Delete the indicator
|
||||
del indicator_store.indicators[indicator_id]
|
||||
|
||||
# Trigger sync
|
||||
await registry.push_all()
|
||||
|
||||
logger.info(f"Removed indicator '{indicator_id}'")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Indicator '{indicator_id}' removed"
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
def list_chart_indicators(symbol: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""List all indicators currently applied to the chart.
|
||||
|
||||
Args:
|
||||
symbol: Optional filter by symbol (defaults to current chart symbol)
|
||||
|
||||
Returns:
|
||||
List of indicator instances, each containing:
|
||||
- id: Indicator instance ID
|
||||
- talib_name: TA-Lib indicator name
|
||||
- instance_name: Display name
|
||||
- parameters: Current parameter values
|
||||
- visible: Whether indicator is visible
|
||||
- pane: Which pane it's displayed in
|
||||
- symbol: Symbol it's applied to
|
||||
|
||||
Example:
|
||||
# List all indicators on current symbol
|
||||
indicators = list_chart_indicators()
|
||||
|
||||
# List indicators on specific symbol
|
||||
btc_indicators = list_chart_indicators(symbol='BINANCE:BTC/USDT')
|
||||
"""
|
||||
indicator_store = _get_indicator_store()
|
||||
if not indicator_store:
|
||||
raise ValueError("IndicatorStore not initialized")
|
||||
|
||||
logger.info(f"list_chart_indicators: Raw store indicators: {indicator_store.indicators}")
|
||||
|
||||
# If symbol is not provided, try to get it from ChartStore
|
||||
if symbol is None:
|
||||
registry = _get_registry()
|
||||
if registry and "ChartStore" in registry.entries:
|
||||
chart_store = registry.entries["ChartStore"].model
|
||||
if hasattr(chart_store, 'chart_state') and hasattr(chart_store.chart_state, 'symbol'):
|
||||
symbol = chart_store.chart_state.symbol
|
||||
|
||||
indicators = list(indicator_store.indicators.values())
|
||||
|
||||
logger.info(f"list_chart_indicators: Converted to list: {indicators}")
|
||||
logger.info(f"list_chart_indicators: Filtering by symbol: {symbol}")
|
||||
|
||||
# Filter by symbol if provided
|
||||
if symbol:
|
||||
indicators = [ind for ind in indicators if ind.get('symbol') == symbol]
|
||||
|
||||
logger.info(f"list_chart_indicators: Returning {len(indicators)} indicators")
|
||||
return indicators
|
||||
|
||||
|
||||
@tool
|
||||
def get_chart_indicator(indicator_id: str) -> Dict[str, Any]:
|
||||
"""Get details of a specific indicator on the chart.
|
||||
|
||||
Args:
|
||||
indicator_id: ID of the indicator instance
|
||||
|
||||
Returns:
|
||||
Dictionary containing the indicator data
|
||||
|
||||
Raises:
|
||||
ValueError: If indicator doesn't exist
|
||||
|
||||
Example:
|
||||
indicator = get_chart_indicator('rsi_14')
|
||||
print(f"Indicator: {indicator['talib_name']}")
|
||||
print(f"Parameters: {indicator['parameters']}")
|
||||
"""
|
||||
indicator_store = _get_indicator_store()
|
||||
if not indicator_store:
|
||||
raise ValueError("IndicatorStore not initialized")
|
||||
|
||||
indicator = indicator_store.indicators.get(indicator_id)
|
||||
if not indicator:
|
||||
raise ValueError(f"Indicator '{indicator_id}' not found")
|
||||
|
||||
return indicator
|
||||
|
||||
|
||||
INDICATOR_TOOLS = [
|
||||
# Discovery tools
|
||||
list_indicators,
|
||||
get_indicator_info,
|
||||
search_indicators,
|
||||
get_indicator_categories,
|
||||
# Chart indicator management tools
|
||||
add_indicator_to_chart,
|
||||
remove_indicator_from_chart,
|
||||
list_chart_indicators,
|
||||
get_chart_indicator
|
||||
]
|
||||
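# --- Illustrative usage sketch (not part of the original module) -------------
# A minimal, hedged example of how these tools could be exercised through the
# LangChain Runnable interface (`invoke` / `ainvoke`), assuming the indicator
# and sync registries have already been initialized by application startup.
#
# async def _demo_indicator_tools():
#     categories = get_indicator_categories.invoke({})             # {"momentum": 25, ...}
#     momentum = search_indicators.invoke({"category": "momentum"})
#     await add_indicator_to_chart.ainvoke({
#         "indicator_id": "rsi_14",
#         "talib_name": "RSI",
#         "parameters": {"timeperiod": 14},
#     })
#     return categories, len(momentum)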
171
backend.old/src/agent/tools/research_tools.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""Research and external data tools for trading analysis."""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from langchain_core.tools import tool
|
||||
from langchain_community.tools import (
|
||||
ArxivQueryRun,
|
||||
WikipediaQueryRun,
|
||||
DuckDuckGoSearchRun
|
||||
)
|
||||
from langchain_community.utilities import (
|
||||
ArxivAPIWrapper,
|
||||
WikipediaAPIWrapper,
|
||||
DuckDuckGoSearchAPIWrapper
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
def search_arxiv(query: str, max_results: int = 5) -> str:
|
||||
"""Search arXiv for academic papers on quantitative finance, trading strategies, and machine learning.
|
||||
|
||||
Use this to find research papers on topics like:
|
||||
- Market microstructure and order flow
|
||||
- Algorithmic trading strategies
|
||||
- Machine learning for finance
|
||||
- Time series forecasting
|
||||
- Risk management
|
||||
- Portfolio optimization
|
||||
|
||||
Args:
|
||||
query: Search query (e.g., "machine learning algorithmic trading", "deep learning stock prediction")
|
||||
max_results: Maximum number of results to return (default: 5)
|
||||
|
||||
Returns:
|
||||
Summary of papers including titles, authors, abstracts, and links
|
||||
|
||||
Example:
|
||||
search_arxiv("reinforcement learning trading", max_results=3)
|
||||
"""
|
||||
arxiv = ArxivQueryRun(api_wrapper=ArxivAPIWrapper(top_k_results=max_results))
|
||||
return arxiv.run(query)
|
||||
|
||||
|
||||
@tool
|
||||
def search_wikipedia(query: str) -> str:
|
||||
"""Search Wikipedia for information on finance, trading, and economics concepts.
|
||||
|
||||
Use this to get background information on:
|
||||
- Financial instruments and markets
|
||||
- Economic indicators
|
||||
- Trading terminology
|
||||
- Technical analysis concepts
|
||||
- Historical market events
|
||||
|
||||
Args:
|
||||
query: Search query (e.g., "Black-Scholes model", "technical analysis", "options trading")
|
||||
|
||||
Returns:
|
||||
Wikipedia article summary with key information
|
||||
|
||||
Example:
|
||||
search_wikipedia("Bollinger Bands")
|
||||
"""
|
||||
wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
|
||||
return wikipedia.run(query)
|
||||
|
||||
|
||||
@tool
|
||||
def search_web(query: str, max_results: int = 5) -> str:
|
||||
"""Search the web for current information on markets, news, and trading.
|
||||
|
||||
Use this to find:
|
||||
- Latest market news and analysis
|
||||
- Company announcements and earnings
|
||||
- Economic events and indicators
|
||||
- Cryptocurrency updates
|
||||
- Exchange status and updates
|
||||
- Trading strategy discussions
|
||||
|
||||
Args:
|
||||
query: Search query (e.g., "Bitcoin price news", "Fed interest rate decision")
|
||||
max_results: Maximum number of results to return (default: 5)
|
||||
|
||||
Returns:
|
||||
Search results with titles, snippets, and links
|
||||
|
||||
Example:
|
||||
search_web("Ethereum merge update", max_results=3)
|
||||
"""
|
||||
# Lazy initialization to avoid hanging during import
|
||||
search = DuckDuckGoSearchRun(api_wrapper=DuckDuckGoSearchAPIWrapper())
|
||||
# Note: max_results parameter doesn't work properly with current wrapper
|
||||
return search.run(query)
|
||||
|
||||
|
||||
@tool
|
||||
def http_get(url: str, params: Optional[Dict[str, str]] = None) -> str:
|
||||
"""Make HTTP GET request to fetch data from APIs or web pages.
|
||||
|
||||
Use this to retrieve:
|
||||
- Exchange API data (if public endpoints)
|
||||
- Market data from external APIs
|
||||
- Documentation and specifications
|
||||
- News articles and blog posts
|
||||
- JSON/XML data from web services
|
||||
|
||||
Args:
|
||||
url: The URL to fetch
|
||||
params: Optional query parameters as a dictionary
|
||||
|
||||
Returns:
|
||||
Response text from the URL
|
||||
|
||||
Raises:
|
||||
ValueError: If the request fails
|
||||
|
||||
Example:
|
||||
http_get("https://api.coingecko.com/api/v3/simple/price",
|
||||
params={"ids": "bitcoin", "vs_currencies": "usd"})
|
||||
"""
|
||||
import requests
|
||||
|
||||
try:
|
||||
response = requests.get(url, params=params, timeout=10)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except requests.RequestException as e:
|
||||
raise ValueError(f"HTTP GET request failed: {str(e)}")
|
||||
|
||||
|
||||
@tool
|
||||
def http_post(url: str, data: Dict[str, Any]) -> str:
|
||||
"""Make HTTP POST request to send data to APIs.
|
||||
|
||||
Use this to:
|
||||
- Submit data to external APIs
|
||||
- Trigger webhooks
|
||||
- Post analysis results
|
||||
- Interact with exchange APIs (if authenticated)
|
||||
|
||||
Args:
|
||||
url: The URL to post to
|
||||
data: Dictionary of data to send in the request body
|
||||
|
||||
Returns:
|
||||
Response text from the server
|
||||
|
||||
Raises:
|
||||
ValueError: If the request fails
|
||||
|
||||
Example:
|
||||
http_post("https://webhook.site/xxx", {"message": "Trade executed"})
|
||||
"""
|
||||
import requests
|
||||
import json
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=data, timeout=10)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except requests.RequestException as e:
|
||||
raise ValueError(f"HTTP POST request failed: {str(e)}")
|
||||
|
||||
|
||||
# Export tools list
|
||||
RESEARCH_TOOLS = [
|
||||
search_arxiv,
|
||||
search_wikipedia,
|
||||
search_web,
|
||||
http_get,
|
||||
http_post
|
||||
]
|
||||
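# --- Illustrative usage sketch (not part of the original module) -------------
# Hedged example combining a public-API lookup with a literature search; the
# CoinGecko endpoint comes from the http_get docstring above and is only an
# example, not a dependency of this module.
#
# price_json = http_get.invoke({
#     "url": "https://api.coingecko.com/api/v3/simple/price",
#     "params": {"ids": "bitcoin", "vs_currencies": "usd"},
# })
# papers = search_arxiv.invoke({"query": "reinforcement learning trading", "max_results": 3})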
475
backend.old/src/agent/tools/shape_tools.py
Normal file
@@ -0,0 +1,475 @@
|
||||
"""Shape/drawing tools for chart analysis."""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from langchain_core.tools import tool
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Map legacy/common shape type names to TradingView's native names
|
||||
SHAPE_TYPE_ALIASES: Dict[str, str] = {
|
||||
'trendline': 'trend_line',
|
||||
'fibonacci': 'fib_retracement',
|
||||
'fibonacci_extension': 'fib_trend_ext',
|
||||
'gann_fan': 'gannbox_fan',
|
||||
}
|
||||
|
||||
|
||||
def _get_registry():
|
||||
"""Get the global registry instance."""
|
||||
from . import _registry
|
||||
return _registry
|
||||
|
||||
|
||||
def _get_shape_store():
|
||||
"""Get the global ShapeStore instance."""
|
||||
registry = _get_registry()
|
||||
if registry and "ShapeStore" in registry.entries:
|
||||
return registry.entries["ShapeStore"].model
|
||||
return None
|
||||
|
||||
|
||||
@tool
|
||||
def search_shapes(
|
||||
start_time: Optional[int] = None,
|
||||
end_time: Optional[int] = None,
|
||||
shape_type: Optional[str] = None,
|
||||
symbol: Optional[str] = None,
|
||||
shape_ids: Optional[List[str]] = None,
|
||||
original_ids: Optional[List[str]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search for shapes/drawings using flexible filters.
|
||||
|
||||
This tool can search shapes by:
|
||||
- Time range (finds shapes that overlap the range)
|
||||
- Shape type (e.g., 'trendline', 'horizontal_line')
|
||||
- Symbol (e.g., 'BINANCE:BTC/USDT')
|
||||
- Specific shape IDs (TradingView's assigned IDs)
|
||||
- Original IDs (the IDs you specified when creating shapes)
|
||||
|
||||
Args:
|
||||
start_time: Optional start of time range (Unix timestamp in seconds)
|
||||
end_time: Optional end of time range (Unix timestamp in seconds)
|
||||
shape_type: Optional filter by shape type (e.g., 'trend_line', 'horizontal_line', 'rectangle')
|
||||
symbol: Optional filter by symbol (e.g., 'BINANCE:BTC/USDT')
|
||||
shape_ids: Optional list of specific shape IDs to retrieve (searches both id and original_id fields)
|
||||
original_ids: Optional list of original IDs to search for (the IDs you specified when creating)
|
||||
|
||||
Returns:
|
||||
List of matching shapes, each as a dictionary with:
|
||||
- id: Shape identifier (TradingView's assigned ID)
|
||||
- original_id: The ID you specified when creating the shape (if applicable)
|
||||
- type: Shape type
|
||||
- points: List of control points with time and price
|
||||
- color, line_width, line_style: Visual properties
|
||||
- properties: Additional shape-specific properties
|
||||
- symbol: Symbol the shape is drawn on
|
||||
- created_at, modified_at: Timestamps
|
||||
|
||||
Examples:
|
||||
# Find all shapes in the currently visible chart range
|
||||
shapes = search_shapes(
|
||||
start_time=chart_state.start_time,
|
||||
end_time=chart_state.end_time
|
||||
)
|
||||
|
||||
# Find only trendlines in a specific time range
|
||||
trendlines = search_shapes(
|
||||
start_time=1640000000,
|
||||
end_time=1650000000,
|
||||
shape_type='trend_line'
|
||||
)
|
||||
|
||||
# Find shapes for a specific symbol
|
||||
btc_shapes = search_shapes(
|
||||
start_time=1640000000,
|
||||
end_time=1650000000,
|
||||
symbol='BINANCE:BTC/USDT'
|
||||
)
|
||||
|
||||
# Get specific shapes by TradingView ID or original ID
|
||||
# This searches both the 'id' and 'original_id' fields
|
||||
selected = search_shapes(
|
||||
shape_ids=['trendline-1', 'support-42k', 'fib-retracement-1']
|
||||
)
|
||||
|
||||
# Get shapes by the original IDs you specified when creating them
|
||||
my_shapes = search_shapes(
|
||||
original_ids=['my-support-line', 'my-resistance-line']
|
||||
)
|
||||
|
||||
# Get all trendlines (no time filter)
|
||||
all_trendlines = search_shapes(shape_type='trend_line')
|
||||
"""
|
||||
shape_store = _get_shape_store()
|
||||
if not shape_store:
|
||||
raise ValueError("ShapeStore not initialized")
|
||||
|
||||
shapes_dict = shape_store.shapes
|
||||
matching_shapes = []
|
||||
|
||||
# If specific shape IDs are requested, search by both id and original_id
|
||||
if shape_ids:
|
||||
for requested_id in shape_ids:
|
||||
# First try direct ID lookup
|
||||
shape = shapes_dict.get(requested_id)
|
||||
if shape:
|
||||
# Still apply other filters if specified
|
||||
if symbol and shape.get('symbol') != symbol:
|
||||
continue
|
||||
if shape_type and shape.get('type') != shape_type:
|
||||
continue
|
||||
matching_shapes.append(shape)
|
||||
else:
|
||||
# If not found by ID, search by original_id
|
||||
for shape_id, shape in shapes_dict.items():
|
||||
if shape.get('original_id') == requested_id:
|
||||
# Still apply other filters if specified
|
||||
if symbol and shape.get('symbol') != symbol:
|
||||
continue
|
||||
if shape_type and shape.get('type') != shape_type:
|
||||
continue
|
||||
matching_shapes.append(shape)
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Found {len(matching_shapes)} shapes by ID filter (requested {len(shape_ids)} IDs)"
|
||||
+ (f" for type '{shape_type}'" if shape_type else "")
|
||||
+ (f" on symbol '{symbol}'" if symbol else "")
|
||||
)
|
||||
return matching_shapes
|
||||
|
||||
# If specific original IDs are requested, search by original_id only
|
||||
if original_ids:
|
||||
for original_id in original_ids:
|
||||
for shape_id, shape in shapes_dict.items():
|
||||
if shape.get('original_id') == original_id:
|
||||
# Still apply other filters if specified
|
||||
if symbol and shape.get('symbol') != symbol:
|
||||
continue
|
||||
if shape_type and shape.get('type') != shape_type:
|
||||
continue
|
||||
matching_shapes.append(shape)
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Found {len(matching_shapes)} shapes by original_id filter (requested {len(original_ids)} IDs)"
|
||||
+ (f" for type '{shape_type}'" if shape_type else "")
|
||||
+ (f" on symbol '{symbol}'" if symbol else "")
|
||||
)
|
||||
return matching_shapes
|
||||
|
||||
# Otherwise, search all shapes with filters
|
||||
for shape_id, shape in shapes_dict.items():
|
||||
# Filter by symbol if specified
|
||||
if symbol and shape.get('symbol') != symbol:
|
||||
continue
|
||||
|
||||
# Filter by type if specified
|
||||
if shape_type and shape.get('type') != shape_type:
|
||||
continue
|
||||
|
||||
# Filter by time range if specified
|
||||
if start_time is not None and end_time is not None:
|
||||
# Check if any control point falls within the time range
|
||||
# or if the shape spans across the time range
|
||||
points = shape.get('points', [])
|
||||
if not points:
|
||||
continue
|
||||
|
||||
# Get min and max times from shape's control points
|
||||
shape_times = [point['time'] for point in points]
|
||||
shape_min_time = min(shape_times)
|
||||
shape_max_time = max(shape_times)
|
||||
|
||||
# Check for overlap: shape overlaps if its range intersects with query range
|
||||
if not (shape_max_time >= start_time and shape_min_time <= end_time):
|
||||
continue
|
||||
|
||||
matching_shapes.append(shape)
|
||||
|
||||
logger.info(
|
||||
f"Found {len(matching_shapes)} shapes"
|
||||
+ (f" in time range {start_time}-{end_time}" if start_time and end_time else "")
|
||||
+ (f" for type '{shape_type}'" if shape_type else "")
|
||||
+ (f" on symbol '{symbol}'" if symbol else "")
|
||||
)
|
||||
|
||||
return matching_shapes
|
||||
|
||||
|
||||
@tool
|
||||
async def create_or_update_shape(
|
||||
shape_id: str,
|
||||
shape_type: str,
|
||||
points: List[Dict[str, Any]],
|
||||
color: Optional[str] = None,
|
||||
line_width: Optional[int] = None,
|
||||
line_style: Optional[str] = None,
|
||||
properties: Optional[Dict[str, Any]] = None,
|
||||
symbol: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a new shape or update an existing shape on the chart.
|
||||
|
||||
This tool allows the agent to draw shapes on the user's chart or modify
|
||||
existing shapes. Shapes are synchronized to the frontend in real-time.
|
||||
|
||||
IMPORTANT - Shape ID Mapping:
|
||||
When you create a shape, TradingView will assign its own internal ID that differs
|
||||
from the shape_id you provide. The shape will be updated in the store with:
|
||||
- id: TradingView's assigned ID
|
||||
- original_id: The shape_id you provided
|
||||
|
||||
To find your shape later, use search_shapes() and filter by original_id field.
|
||||
|
||||
Example:
|
||||
# Create a shape
|
||||
await create_or_update_shape(shape_id='my-support', ...)
|
||||
|
||||
# Later, find it by original_id
|
||||
shapes = search_shapes(symbol='BINANCE:BTC/USDT')
|
||||
my_shape = next((s for s in shapes if s.get('original_id') == 'my-support'), None)
|
||||
|
||||
Args:
|
||||
shape_id: Unique identifier for the shape (use existing ID to update, new ID to create)
|
||||
Note: TradingView will assign its own ID; your ID will be stored in original_id
|
||||
shape_type: Type of shape using TradingView's native names.
|
||||
|
||||
Single-point shapes (use 1 point):
|
||||
- 'horizontal_line': Horizontal support/resistance line
|
||||
- 'vertical_line': Vertical time marker
|
||||
- 'text': Text label
|
||||
- 'anchored_text': Anchored text annotation
|
||||
- 'anchored_note': Anchored note
|
||||
- 'note': Note annotation
|
||||
- 'emoji': Emoji marker
|
||||
- 'icon': Icon marker
|
||||
- 'sticker': Sticker marker
|
||||
- 'arrow_up': Upward arrow marker
|
||||
- 'arrow_down': Downward arrow marker
|
||||
- 'flag': Flag marker
|
||||
- 'long_position': Long position marker
|
||||
- 'short_position': Short position marker
|
||||
|
||||
Multi-point shapes (use 2+ points):
|
||||
- 'trend_line': Trendline (2 points)
|
||||
- 'rectangle': Rectangle (2 points: top-left, bottom-right)
|
||||
- 'fib_retracement': Fibonacci retracement (2 points)
|
||||
- 'fib_trend_ext': Fibonacci extension (3 points)
|
||||
- 'parallel_channel': Parallel channel (3 points)
|
||||
- 'arrow': Arrow (2 points)
|
||||
- 'circle': Circle/ellipse (2-3 points)
|
||||
- 'path': Free drawing path (3+ points)
|
||||
- 'pitchfork': Andrews' pitchfork (3 points)
|
||||
- 'gannbox_fan': Gann fan (2 points)
|
||||
- 'head_and_shoulders': Head and shoulders pattern (5 points)
|
||||
|
||||
points: List of control points, each with 'time' (Unix seconds) and 'price' fields
|
||||
color: Optional color (hex like '#FF0000' or name like 'red')
|
||||
line_width: Optional line width in pixels (default: 1)
|
||||
line_style: Optional line style: 'solid', 'dashed', 'dotted' (default: 'solid')
|
||||
properties: Optional dict of additional shape-specific properties
|
||||
symbol: Optional symbol to associate with the shape (defaults to current chart symbol)
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: 'created' or 'updated'
|
||||
- shape: The complete shape object (initially with your ID, will be updated to TV ID)
|
||||
|
||||
Examples:
|
||||
# Draw a trendline between two points
|
||||
await create_or_update_shape(
|
||||
shape_id='my-trendline-1',
|
||||
shape_type='trend_line',
|
||||
points=[
|
||||
{'time': 1640000000, 'price': 45000.0},
|
||||
{'time': 1650000000, 'price': 50000.0}
|
||||
],
|
||||
color='#00FF00',
|
||||
line_width=2
|
||||
)
|
||||
|
||||
# Draw a horizontal support line
|
||||
await create_or_update_shape(
|
||||
shape_id='support-1',
|
||||
shape_type='horizontal_line',
|
||||
points=[{'time': 1640000000, 'price': 42000.0}],
|
||||
color='blue',
|
||||
line_style='dashed'
|
||||
)
|
||||
|
||||
# Find your shape after creation using original_id
|
||||
shapes = search_shapes(symbol='BINANCE:BTC/USDT')
|
||||
my_shape = next((s for s in shapes if s.get('original_id') == 'support-1'), None)
|
||||
if my_shape:
|
||||
print(f"TradingView assigned ID: {my_shape['id']}")
|
||||
"""
|
||||
from schema.shape import Shape, ControlPoint
|
||||
import time as time_module
|
||||
|
||||
registry = _get_registry()
|
||||
if not registry:
|
||||
raise ValueError("SyncRegistry not initialized")
|
||||
|
||||
shape_store = _get_shape_store()
|
||||
if not shape_store:
|
||||
raise ValueError("ShapeStore not initialized")
|
||||
|
||||
# Normalize shape type (handle legacy names)
|
||||
normalized_type = SHAPE_TYPE_ALIASES.get(shape_type, shape_type)
|
||||
if normalized_type != shape_type:
|
||||
logger.info(f"Normalized shape type '{shape_type}' -> '{normalized_type}'")
|
||||
|
||||
# Convert points to ControlPoint objects
|
||||
control_points = []
|
||||
for p in points:
|
||||
point_data = {
|
||||
'time': p['time'],
|
||||
'price': p['price']
|
||||
}
|
||||
# Only include channel if it's actually provided
|
||||
if 'channel' in p and p['channel'] is not None:
|
||||
point_data['channel'] = p['channel']
|
||||
control_points.append(ControlPoint(**point_data))
|
||||
|
||||
# Check if updating existing shape
|
||||
existing_shape = shape_store.shapes.get(shape_id)
|
||||
is_update = existing_shape is not None
|
||||
|
||||
# If symbol is not provided, try to get it from ChartStore
|
||||
if symbol is None and "ChartStore" in registry.entries:
|
||||
chart_store = registry.entries["ChartStore"].model
|
||||
if hasattr(chart_store, 'chart_state') and hasattr(chart_store.chart_state, 'symbol'):
|
||||
symbol = chart_store.chart_state.symbol
|
||||
logger.info(f"Using current chart symbol for shape: {symbol}")
|
||||
|
||||
now = int(time_module.time())
|
||||
|
||||
# Create shape object
|
||||
shape = Shape(
|
||||
id=shape_id,
|
||||
type=normalized_type,
|
||||
points=control_points,
|
||||
color=color,
|
||||
line_width=line_width,
|
||||
line_style=line_style,
|
||||
properties=properties or {},
|
||||
symbol=symbol,
|
||||
created_at=existing_shape.get('created_at') if existing_shape else now,
|
||||
modified_at=now
|
||||
)
|
||||
|
||||
# Update the store
|
||||
shape_store.shapes[shape_id] = shape.model_dump(mode="json")
|
||||
|
||||
# Trigger sync
|
||||
await registry.push_all()
|
||||
|
||||
logger.info(
|
||||
f"{'Updated' if is_update else 'Created'} shape '{shape_id}' "
|
||||
f"of type '{shape_type}' with {len(points)} points"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "updated" if is_update else "created",
|
||||
"shape": shape.model_dump(mode="json")
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_shape(shape_id: str) -> Dict[str, str]:
|
||||
"""Delete a shape from the chart.
|
||||
|
||||
Args:
|
||||
shape_id: ID of the shape to delete
|
||||
|
||||
Returns:
|
||||
Dictionary with status message
|
||||
|
||||
Raises:
|
||||
ValueError: If shape doesn't exist
|
||||
|
||||
Example:
|
||||
await delete_shape('my-trendline-1')
|
||||
"""
|
||||
registry = _get_registry()
|
||||
if not registry:
|
||||
raise ValueError("SyncRegistry not initialized")
|
||||
|
||||
shape_store = _get_shape_store()
|
||||
if not shape_store:
|
||||
raise ValueError("ShapeStore not initialized")
|
||||
|
||||
if shape_id not in shape_store.shapes:
|
||||
raise ValueError(f"Shape '{shape_id}' not found")
|
||||
|
||||
# Delete the shape
|
||||
del shape_store.shapes[shape_id]
|
||||
|
||||
# Trigger sync
|
||||
await registry.push_all()
|
||||
|
||||
logger.info(f"Deleted shape '{shape_id}'")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Shape '{shape_id}' deleted"
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
def get_shape(shape_id: str) -> Dict[str, Any]:
|
||||
"""Get details of a specific shape by ID.
|
||||
|
||||
Args:
|
||||
shape_id: ID of the shape to retrieve
|
||||
|
||||
Returns:
|
||||
Dictionary containing the shape data
|
||||
|
||||
Raises:
|
||||
ValueError: If shape doesn't exist
|
||||
|
||||
Example:
|
||||
shape = get_shape('my-trendline-1')
|
||||
print(f"Shape type: {shape['type']}")
|
||||
print(f"Points: {shape['points']}")
|
||||
"""
|
||||
shape_store = _get_shape_store()
|
||||
if not shape_store:
|
||||
raise ValueError("ShapeStore not initialized")
|
||||
|
||||
shape = shape_store.shapes.get(shape_id)
|
||||
if not shape:
|
||||
raise ValueError(f"Shape '{shape_id}' not found")
|
||||
|
||||
return shape
|
||||
|
||||
|
||||
@tool
|
||||
def list_all_shapes() -> List[Dict[str, Any]]:
|
||||
"""List all shapes currently on the chart.
|
||||
|
||||
Returns:
|
||||
List of all shapes as dictionaries
|
||||
|
||||
Example:
|
||||
shapes = list_all_shapes()
|
||||
print(f"Total shapes: {len(shapes)}")
|
||||
for shape in shapes:
|
||||
print(f" - {shape['id']}: {shape['type']}")
|
||||
"""
|
||||
shape_store = _get_shape_store()
|
||||
if not shape_store:
|
||||
raise ValueError("ShapeStore not initialized")
|
||||
|
||||
return list(shape_store.shapes.values())
|
||||
|
||||
|
||||
SHAPE_TOOLS = [
|
||||
search_shapes,
|
||||
create_or_update_shape,
|
||||
delete_shape,
|
||||
get_shape,
|
||||
list_all_shapes
|
||||
]
|
||||
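# --- Illustrative usage sketch (not part of the original module) -------------
# Hedged example of the create-then-lookup flow described in the
# create_or_update_shape docstring: draw a horizontal level, then recover the
# TradingView-assigned ID via the original_id field. Assumes the sync registry
# and ShapeStore are initialized.
#
# async def _demo_shape_tools():
#     await create_or_update_shape.ainvoke({
#         "shape_id": "support-42k",
#         "shape_type": "horizontal_line",
#         "points": [{"time": 1640000000, "price": 42000.0}],
#         "color": "blue",
#         "line_style": "dashed",
#     })
#     shapes = search_shapes.invoke({"original_ids": ["support-42k"]})
#     return shapes[0]["id"] if shapes else None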
138
backend.old/src/agent/tools/sync_tools.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Synchronization store tools."""
|
||||
|
||||
from typing import Dict, Any, List
|
||||
from langchain_core.tools import tool
|
||||
|
||||
|
||||
def _get_registry():
|
||||
"""Get the global registry instance."""
|
||||
from . import _registry
|
||||
return _registry
|
||||
|
||||
|
||||
@tool
|
||||
def list_sync_stores() -> List[str]:
|
||||
"""List all available synchronization stores.
|
||||
|
||||
Returns:
|
||||
List of store names that can be read/written
|
||||
"""
|
||||
registry = _get_registry()
|
||||
if not registry:
|
||||
return []
|
||||
return list(registry.entries.keys())
|
||||
|
||||
|
||||
@tool
|
||||
def read_sync_state(store_name: str) -> Dict[str, Any]:
|
||||
"""Read the current state of a synchronization store.
|
||||
|
||||
Args:
|
||||
store_name: Name of the store to read (e.g., "TraderState", "StrategyState")
|
||||
|
||||
Returns:
|
||||
Dictionary containing the current state of the store
|
||||
|
||||
Raises:
|
||||
ValueError: If store_name doesn't exist
|
||||
"""
|
||||
registry = _get_registry()
|
||||
if not registry:
|
||||
raise ValueError("SyncRegistry not initialized")
|
||||
|
||||
entry = registry.entries.get(store_name)
|
||||
if not entry:
|
||||
available = list(registry.entries.keys())
|
||||
raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
|
||||
|
||||
return entry.model.model_dump(mode="json")
|
||||
|
||||
|
||||
@tool
|
||||
async def write_sync_state(store_name: str, updates: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Update the state of a synchronization store.
|
||||
|
||||
This will apply the updates to the store and trigger synchronization
|
||||
with all connected clients.
|
||||
|
||||
Args:
|
||||
store_name: Name of the store to update
|
||||
updates: Dictionary of field updates (field_name: new_value)
|
||||
|
||||
Returns:
|
||||
Dictionary with status and updated fields
|
||||
|
||||
Raises:
|
||||
ValueError: If store_name doesn't exist or updates are invalid
|
||||
"""
|
||||
registry = _get_registry()
|
||||
if not registry:
|
||||
raise ValueError("SyncRegistry not initialized")
|
||||
|
||||
entry = registry.entries.get(store_name)
|
||||
if not entry:
|
||||
available = list(registry.entries.keys())
|
||||
raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
|
||||
|
||||
try:
|
||||
# Get current state
|
||||
current_state = entry.model.model_dump(mode="json")
|
||||
|
||||
# Apply updates
|
||||
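# Note: this is a shallow, top-level merge; any nested structure supplied in
# `updates` replaces the existing value wholesale rather than being deep-merged.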
new_state = {**current_state, **updates}
|
||||
|
||||
# Update the model
|
||||
registry._update_model(entry.model, new_state)
|
||||
|
||||
# Trigger sync
|
||||
await registry.push_all()
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"store": store_name,
|
||||
"updated_fields": list(updates.keys())
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to update store '{store_name}': {str(e)}")
|
||||
|
||||
|
||||
@tool
|
||||
def get_store_schema(store_name: str) -> Dict[str, Any]:
|
||||
"""Get the schema/structure of a synchronization store.
|
||||
|
||||
This shows what fields are available and their types.
|
||||
|
||||
Args:
|
||||
store_name: Name of the store
|
||||
|
||||
Returns:
|
||||
Dictionary describing the store's schema
|
||||
|
||||
Raises:
|
||||
ValueError: If store_name doesn't exist
|
||||
"""
|
||||
registry = _get_registry()
|
||||
if not registry:
|
||||
raise ValueError("SyncRegistry not initialized")
|
||||
|
||||
entry = registry.entries.get(store_name)
|
||||
if not entry:
|
||||
available = list(registry.entries.keys())
|
||||
raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")
|
||||
|
||||
# Get model schema
|
||||
schema = entry.model.model_json_schema()
|
||||
|
||||
return {
|
||||
"store_name": store_name,
|
||||
"schema": schema
|
||||
}
|
||||
|
||||
|
||||
SYNC_TOOLS = [
|
||||
list_sync_stores,
|
||||
read_sync_state,
|
||||
write_sync_state,
|
||||
get_store_schema
|
||||
]
|
||||
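# --- Illustrative usage sketch (not part of the original module) -------------
# Hedged read-modify-write example; "TraderState" is taken from the
# read_sync_state docstring and the updated field name is purely hypothetical.
#
# async def _demo_sync_tools():
#     stores = list_sync_stores.invoke({})
#     if "TraderState" in stores:
#         state = read_sync_state.invoke({"store_name": "TraderState"})
#         await write_sync_state.ainvoke({
#             "store_name": "TraderState",
#             "updates": {"mode": "paper"},  # hypothetical field
#         })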
366
backend.old/src/agent/tools/trigger_tools.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""
|
||||
Agent tools for trigger system.
|
||||
|
||||
Allows agents to:
|
||||
- Schedule recurring tasks (cron-style)
|
||||
- Execute one-time triggers
|
||||
- Manage scheduled triggers (list, cancel)
|
||||
- Connect events to sub-agent runs or lambdas
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global references set by main.py
|
||||
_trigger_queue = None
|
||||
_trigger_scheduler = None
|
||||
_coordinator = None
|
||||
|
||||
|
||||
def set_trigger_queue(queue):
|
||||
"""Set the global TriggerQueue instance for tools to use."""
|
||||
global _trigger_queue
|
||||
_trigger_queue = queue
|
||||
|
||||
|
||||
def set_trigger_scheduler(scheduler):
|
||||
"""Set the global TriggerScheduler instance for tools to use."""
|
||||
global _trigger_scheduler
|
||||
_trigger_scheduler = scheduler
|
||||
|
||||
|
||||
def set_coordinator(coordinator):
|
||||
"""Set the global CommitCoordinator instance for tools to use."""
|
||||
global _coordinator
|
||||
_coordinator = coordinator
|
||||
|
||||
|
||||
def _get_trigger_queue():
|
||||
"""Get the global trigger queue instance."""
|
||||
if not _trigger_queue:
|
||||
raise ValueError("TriggerQueue not initialized")
|
||||
return _trigger_queue
|
||||
|
||||
|
||||
def _get_trigger_scheduler():
|
||||
"""Get the global trigger scheduler instance."""
|
||||
if not _trigger_scheduler:
|
||||
raise ValueError("TriggerScheduler not initialized")
|
||||
return _trigger_scheduler
|
||||
|
||||
|
||||
def _get_coordinator():
|
||||
"""Get the global coordinator instance."""
|
||||
if not _coordinator:
|
||||
raise ValueError("CommitCoordinator not initialized")
|
||||
return _coordinator
|
||||
|
||||
|
||||
@tool
|
||||
async def schedule_agent_prompt(
|
||||
prompt: str,
|
||||
schedule_type: str,
|
||||
schedule_config: Dict[str, Any],
|
||||
name: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Schedule an agent to run with a specific prompt on a recurring schedule.
|
||||
|
||||
This allows you to set up automated tasks where the agent runs periodically
|
||||
with a predefined prompt. Useful for:
|
||||
- Daily market analysis reports
|
||||
- Hourly portfolio rebalancing checks
|
||||
- Weekly performance summaries
|
||||
- Monitoring alerts
|
||||
|
||||
Args:
|
||||
prompt: The prompt to send to the agent when triggered
|
||||
schedule_type: Type of schedule - "interval" or "cron"
|
||||
schedule_config: Schedule configuration:
|
||||
For "interval": {"minutes": 5} or {"hours": 1, "minutes": 30}
|
||||
For "cron": {"hour": "9", "minute": "0"} for 9:00 AM daily
|
||||
{"hour": "9", "minute": "0", "day_of_week": "mon-fri"}
|
||||
name: Optional descriptive name for this scheduled task
|
||||
|
||||
Returns:
|
||||
Dictionary with job_id and confirmation message
|
||||
|
||||
Examples:
|
||||
# Run every 5 minutes
|
||||
schedule_agent_prompt(
|
||||
prompt="Check BTC price and alert if > $50k",
|
||||
schedule_type="interval",
|
||||
schedule_config={"minutes": 5}
|
||||
)
|
||||
|
||||
# Run daily at 9 AM
|
||||
schedule_agent_prompt(
|
||||
prompt="Generate daily market summary",
|
||||
schedule_type="cron",
|
||||
schedule_config={"hour": "9", "minute": "0"}
|
||||
)
|
||||
|
||||
# Run hourly on weekdays
|
||||
schedule_agent_prompt(
|
||||
prompt="Monitor portfolio for rebalancing opportunities",
|
||||
schedule_type="cron",
|
||||
schedule_config={"minute": "0", "day_of_week": "mon-fri"}
|
||||
)
|
||||
"""
|
||||
from trigger.handlers import LambdaHandler
|
||||
from trigger import Priority
|
||||
|
||||
scheduler = _get_trigger_scheduler()
|
||||
queue = _get_trigger_queue()
|
||||
|
||||
if not name:
|
||||
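# hash() is salted per process (PYTHONHASHSEED), so this fallback name is only
# stable within the current run, not across restarts.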
name = f"agent_prompt_{hash(prompt) % 10000}"
|
||||
|
||||
# Create a lambda that enqueues an agent trigger with the prompt
|
||||
async def agent_prompt_lambda():
|
||||
from trigger.handlers import AgentTriggerHandler
|
||||
|
||||
# Create agent trigger (will use current session's context)
|
||||
# In production, you'd want to specify which session/user this belongs to
|
||||
trigger = AgentTriggerHandler(
|
||||
session_id="scheduled", # Special session for scheduled tasks
|
||||
message_content=prompt,
|
||||
coordinator=_get_coordinator(),
|
||||
)
|
||||
|
||||
await queue.enqueue(trigger)
|
||||
return [] # No direct commit intents
|
||||
|
||||
# Wrap in lambda handler
|
||||
lambda_trigger = LambdaHandler(
|
||||
name=f"scheduled_{name}",
|
||||
func=agent_prompt_lambda,
|
||||
priority=Priority.TIMER,
|
||||
)
|
||||
|
||||
# Schedule based on type
|
||||
if schedule_type == "interval":
|
||||
job_id = scheduler.schedule_interval(
|
||||
lambda_trigger,
|
||||
seconds=schedule_config.get("seconds"),
|
||||
minutes=schedule_config.get("minutes"),
|
||||
hours=schedule_config.get("hours"),
|
||||
priority=Priority.TIMER,
|
||||
)
|
||||
elif schedule_type == "cron":
|
||||
job_id = scheduler.schedule_cron(
|
||||
lambda_trigger,
|
||||
minute=schedule_config.get("minute"),
|
||||
hour=schedule_config.get("hour"),
|
||||
day=schedule_config.get("day"),
|
||||
month=schedule_config.get("month"),
|
||||
day_of_week=schedule_config.get("day_of_week"),
|
||||
priority=Priority.TIMER,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid schedule_type: {schedule_type}. Use 'interval' or 'cron'")
|
||||
|
||||
return {
|
||||
"job_id": job_id,
|
||||
"message": f"Scheduled '{name}' with job_id={job_id}",
|
||||
"schedule_type": schedule_type,
|
||||
"config": schedule_config,
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
async def execute_agent_prompt_once(
|
||||
prompt: str,
|
||||
priority: str = "normal",
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute an agent prompt once, immediately (enqueued with priority).
|
||||
|
||||
Use this to trigger a sub-agent with a specific task without waiting for
|
||||
a user message. Useful for:
|
||||
- Background analysis tasks
|
||||
- One-time data processing
|
||||
- Responding to specific events
|
||||
|
||||
Args:
|
||||
prompt: The prompt to send to the agent
|
||||
priority: Priority level - "high", "normal", or "low"
|
||||
|
||||
Returns:
|
||||
Confirmation that the prompt was enqueued
|
||||
|
||||
Example:
|
||||
execute_agent_prompt_once(
|
||||
prompt="Analyze the last 100 BTC/USDT bars and identify support levels",
|
||||
priority="high"
|
||||
)
|
||||
"""
|
||||
from trigger.handlers import AgentTriggerHandler
|
||||
from trigger import Priority
|
||||
|
||||
queue = _get_trigger_queue()
|
||||
|
||||
# Map string priority to enum
|
||||
priority_map = {
|
||||
"high": Priority.USER_AGENT, # Same priority as user messages
|
||||
"normal": Priority.SYSTEM,
|
||||
"low": Priority.LOW,
|
||||
}
|
||||
priority_enum = priority_map.get(priority.lower(), Priority.SYSTEM)
|
||||
|
||||
# Create agent trigger
|
||||
trigger = AgentTriggerHandler(
|
||||
session_id="oneshot",
|
||||
message_content=prompt,
|
||||
coordinator=_get_coordinator(),
|
||||
)
|
||||
|
||||
# Enqueue with priority override
|
||||
queue_seq = await queue.enqueue(trigger, priority_enum)
|
||||
|
||||
return {
|
||||
"queue_seq": queue_seq,
|
||||
"message": f"Enqueued agent prompt with priority={priority}",
|
||||
"prompt": prompt[:100] + "..." if len(prompt) > 100 else prompt,
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
def list_scheduled_triggers() -> List[Dict[str, Any]]:
|
||||
"""List all currently scheduled triggers.
|
||||
|
||||
Returns:
|
||||
List of dictionaries with job information (id, name, next_run_time)
|
||||
|
||||
Example:
|
||||
jobs = list_scheduled_triggers()
|
||||
for job in jobs:
|
||||
print(f"{job['id']}: {job['name']} - next run at {job['next_run_time']}")
|
||||
"""
|
||||
scheduler = _get_trigger_scheduler()
|
||||
jobs = scheduler.get_jobs()
|
||||
|
||||
result = []
|
||||
for job in jobs:
|
||||
result.append({
|
||||
"id": job.id,
|
||||
"name": job.name,
|
||||
"next_run_time": str(job.next_run_time) if job.next_run_time else None,
|
||||
"trigger": str(job.trigger),
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@tool
|
||||
def cancel_scheduled_trigger(job_id: str) -> Dict[str, str]:
|
||||
"""Cancel a scheduled trigger by its job ID.
|
||||
|
||||
Args:
|
||||
job_id: The job ID returned from schedule_agent_prompt or list_scheduled_triggers
|
||||
|
||||
Returns:
|
||||
Confirmation message
|
||||
|
||||
Example:
|
||||
cancel_scheduled_trigger("interval_123")
|
||||
"""
|
||||
scheduler = _get_trigger_scheduler()
|
||||
success = scheduler.remove_job(job_id)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Cancelled job {job_id}",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Job {job_id} not found",
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
async def on_data_update_run_agent(
|
||||
source_name: str,
|
||||
symbol: str,
|
||||
resolution: str,
|
||||
prompt_template: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Set up an agent to run whenever new data arrives for a specific symbol.
|
||||
|
||||
The prompt_template can include {variables} that will be filled with bar data:
|
||||
- {time}: Bar timestamp
|
||||
- {open}, {high}, {low}, {close}, {volume}: OHLCV values
|
||||
- {symbol}: Trading pair symbol
|
||||
- {source}: Data source name
|
||||
|
||||
Args:
|
||||
source_name: Name of data source (e.g., "binance")
|
||||
symbol: Trading pair (e.g., "BTC/USDT")
|
||||
resolution: Time resolution (e.g., "1m", "5m", "1h")
|
||||
prompt_template: Template string for agent prompt
|
||||
|
||||
Returns:
|
||||
Confirmation with subscription details
|
||||
|
||||
Example:
|
||||
on_data_update_run_agent(
|
||||
source_name="binance",
|
||||
symbol="BTC/USDT",
|
||||
resolution="1m",
|
||||
prompt_template="New bar on {symbol}: close={close}. Check if we should trade."
|
||||
)
|
||||
|
||||
Note:
|
||||
This is a simplified version. Full implementation would wire into
|
||||
DataSource subscription system to trigger on every bar update.
|
||||
"""
|
||||
# TODO: Implement proper DataSource subscription integration
|
||||
# For now, return placeholder
|
||||
|
||||
return {
|
||||
"status": "not_implemented",
|
||||
"message": "Data-driven agent triggers coming soon",
|
||||
"config": {
|
||||
"source": source_name,
|
||||
"symbol": symbol,
|
||||
"resolution": resolution,
|
||||
"prompt_template": prompt_template,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
def get_trigger_system_stats() -> Dict[str, Any]:
|
||||
"""Get statistics about the trigger system.
|
||||
|
||||
Returns:
|
||||
Dictionary with queue depth, execution stats, etc.
|
||||
|
||||
Example:
|
||||
stats = get_trigger_system_stats()
|
||||
print(f"Queue depth: {stats['queue_depth']}")
|
||||
print(f"Current seq: {stats['current_seq']}")
|
||||
"""
|
||||
queue = _get_trigger_queue()
|
||||
coordinator = _get_coordinator()
|
||||
|
||||
return {
|
||||
"queue_depth": queue.get_queue_size(),
|
||||
"queue_running": queue.is_running(),
|
||||
"coordinator_stats": coordinator.get_stats(),
|
||||
}
|
||||
|
||||
|
||||
# Export tools list
|
||||
TRIGGER_TOOLS = [
|
||||
schedule_agent_prompt,
|
||||
execute_agent_prompt_once,
|
||||
list_scheduled_triggers,
|
||||
cancel_scheduled_trigger,
|
||||
on_data_update_run_agent,
|
||||
get_trigger_system_stats,
|
||||
]
|
||||
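# --- Illustrative usage sketch (not part of the original module) -------------
# Hedged example of scheduling and later cancelling a daily report, mirroring
# the schedule_agent_prompt docstring. Assumes the trigger queue, scheduler and
# coordinator globals have been set by main.py.
#
# async def _demo_trigger_tools():
#     job = await schedule_agent_prompt.ainvoke({
#         "prompt": "Generate daily market summary",
#         "schedule_type": "cron",
#         "schedule_config": {"hour": "9", "minute": "0"},
#         "name": "daily_summary",
#     })
#     print(list_scheduled_triggers.invoke({}))
#     return cancel_scheduled_trigger.invoke({"job_id": job["job_id"]})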
23
backend.old/src/datasource/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from .base import DataSource
|
||||
from .schema import (
|
||||
ColumnInfo,
|
||||
SymbolInfo,
|
||||
Bar,
|
||||
HistoryResult,
|
||||
DatafeedConfig,
|
||||
Resolution,
|
||||
SearchResult,
|
||||
)
|
||||
from .registry import DataSourceRegistry
|
||||
|
||||
__all__ = [
|
||||
"DataSource",
|
||||
"ColumnInfo",
|
||||
"SymbolInfo",
|
||||
"Bar",
|
||||
"HistoryResult",
|
||||
"DatafeedConfig",
|
||||
"Resolution",
|
||||
"SearchResult",
|
||||
"DataSourceRegistry",
|
||||
]
|
||||
3
backend.old/src/datasource/adapters/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .ccxt_adapter import CCXTDataSource
|
||||
|
||||
__all__ = ["CCXTDataSource"]
|
||||
568
backend.old/src/datasource/adapters/ccxt_adapter.py
Normal file
@@ -0,0 +1,568 @@
|
||||
"""
|
||||
CCXT DataSource adapter for accessing cryptocurrency exchange data.
|
||||
|
||||
This adapter provides access to hundreds of cryptocurrency exchanges through
|
||||
the free CCXT library (not ccxt.pro), supporting both historical data and
|
||||
polling-based subscriptions.
|
||||
|
||||
Numerical Precision:
|
||||
- OHLCV data uses native floats for optimal DataFrame/analysis performance
|
||||
- Account balances and order data should use Decimal (via _to_decimal method)
|
||||
- CCXT returns numeric values as strings or floats depending on configuration
|
||||
- Price data converted to float (_to_float), financial data to Decimal (_to_decimal)
|
||||
|
||||
Real-time Updates:
|
||||
- Uses polling instead of WebSocket (free CCXT doesn't have WebSocket support)
|
||||
- Default polling interval: 60 seconds (configurable)
|
||||
- Simulates real-time subscriptions by periodically fetching latest bars
|
||||
"""
|
||||
|
||||
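# Illustrative precision contract, stated as an assumption based on the docstring above:
#   _to_float(Decimal("42000.50"))  -> 42000.5            (OHLCV / charting path)
#   _to_decimal("0.00000001")       -> Decimal("1E-8")     (balances, order sizes)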
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from typing import Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
import ccxt.async_support as ccxt
|
||||
|
||||
from ..base import DataSource
|
||||
|
||||
from ..schema import (
    Bar,
    ColumnInfo,
    DatafeedConfig,
    HistoryResult,
    Resolution,
    SearchResult,
    SymbolInfo,
)

logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CCXTDataSource(DataSource):
|
||||
"""
|
||||
DataSource adapter for CCXT cryptocurrency exchanges (free version).
|
||||
|
||||
Provides access to:
|
||||
- Multiple cryptocurrency exchanges (Binance, Coinbase, Kraken, etc.)
|
||||
- Historical OHLCV data via REST API
|
||||
- Polling-based real-time updates (configurable interval)
|
||||
- Symbol search and metadata
|
||||
|
||||
Args:
|
||||
exchange_id: CCXT exchange identifier (e.g., 'binance', 'coinbase', 'kraken')
|
||||
config: Optional exchange-specific configuration (API keys, options)
|
||||
sandbox: Whether to use sandbox/testnet mode (default: False)
|
||||
poll_interval: Interval in seconds for polling updates (default: 60)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
exchange_id: str = "binance",
|
||||
config: Optional[Dict] = None,
|
||||
sandbox: bool = False,
|
||||
poll_interval: int = 60,
|
||||
):
|
||||
self.exchange_id = exchange_id
|
||||
self._config = config or {}
|
||||
self._sandbox = sandbox
|
||||
self._poll_interval = poll_interval
|
||||
|
||||
# Initialize exchange (using free async_support, not pro)
|
||||
exchange_class = getattr(ccxt, exchange_id)
|
||||
self.exchange = exchange_class(self._config)
|
||||
|
||||
# Configure CCXT to use Decimal mode for precise financial calculations
|
||||
# This ensures all numeric values from the exchange use Decimal internally
|
||||
# We then convert OHLCV to float for DataFrame performance, but keep
|
||||
# Decimal precision for account balances, order sizes, etc.
|
||||
from decimal import Decimal as PythonDecimal
|
||||
self.exchange.number = PythonDecimal
|
||||
|
||||
# Log the precision mode being used by this exchange
|
||||
precision_mode = getattr(self.exchange, 'precisionMode', 'UNKNOWN')
|
||||
logger.info(
|
||||
f"CCXT {exchange_id}: Configured with Decimal mode. "
|
||||
f"Exchange precision mode: {precision_mode}"
|
||||
)
|
||||
|
||||
if sandbox and hasattr(self.exchange, 'set_sandbox_mode'):
|
||||
self.exchange.set_sandbox_mode(True)
|
||||
|
||||
# Cache for markets
|
||||
self._markets: Optional[Dict] = None
|
||||
self._markets_loaded = False
|
||||
|
||||
# Active subscriptions (polling-based)
|
||||
self._subscriptions: Dict[str, asyncio.Task] = {}
|
||||
self._subscription_callbacks: Dict[str, Callable] = {}
|
||||
self._last_bars: Dict[str, int] = {} # Track last bar timestamp per subscription
|
||||
|
||||
@staticmethod
|
||||
def _to_decimal(value: Union[str, int, float, Decimal, None]) -> Optional[Decimal]:
|
||||
"""
|
||||
Convert a value to Decimal for numerical precision.
|
||||
|
||||
Handles CCXT's mixed output (strings, floats, ints, None).
|
||||
Floats are converted via str() first to avoid binary float precision loss.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, Decimal):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return Decimal(value)
|
||||
if isinstance(value, (int, float)):
|
||||
# Convert to string first to avoid float precision issues
|
||||
return Decimal(str(value))
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _to_float(value: Union[str, int, float, Decimal, None]) -> Optional[float]:
|
||||
"""
|
||||
Convert a value to float for OHLCV data.
|
||||
|
||||
OHLCV data is used for charting and DataFrame analysis, where native
|
||||
floats provide better performance and compatibility with pandas/numpy.
|
||||
For financial precision (balances, order sizes), use _to_decimal() instead.
|
||||
|
||||
When CCXT is in Decimal mode (exchange.number = Decimal), it returns
|
||||
Decimal objects. This method converts them to float for performance.
|
||||
|
||||
Handles CCXT's output in both modes:
|
||||
- Decimal mode: receives Decimal objects
|
||||
- Default mode: receives strings, floats, or ints
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, float):
|
||||
return value
|
||||
if isinstance(value, Decimal):
|
||||
# CCXT in Decimal mode - convert to float for OHLCV
|
||||
return float(value)
|
||||
if isinstance(value, (str, int)):
|
||||
return float(value)
|
||||
return None
|
||||
|
||||
async def _ensure_markets_loaded(self):
|
||||
"""Ensure markets are loaded from exchange"""
|
||||
if not self._markets_loaded:
|
||||
self._markets = await self.exchange.load_markets()
|
||||
self._markets_loaded = True
|
||||
|
||||
async def get_config(self) -> DatafeedConfig:
|
||||
"""Get datafeed configuration"""
|
||||
await self._ensure_markets_loaded()
|
||||
|
||||
# Determine supported resolutions based on exchange capabilities
|
||||
supported_resolutions = [
|
||||
Resolution.M1,
|
||||
Resolution.M5,
|
||||
Resolution.M15,
|
||||
Resolution.M30,
|
||||
Resolution.H1,
|
||||
Resolution.H4,
|
||||
Resolution.D1,
|
||||
]
|
||||
|
||||
# Get unique exchange names (most CCXT exchanges are just one)
|
||||
exchanges = [self.exchange_id.upper()]
|
||||
|
||||
return DatafeedConfig(
|
||||
name=f"CCXT {self.exchange_id.title()}",
|
||||
description=f"Live and historical cryptocurrency data from {self.exchange_id} via CCXT library. "
|
||||
f"Supports OHLCV data for {len(self._markets) if self._markets else 'many'} trading pairs.",
|
||||
supported_resolutions=supported_resolutions,
|
||||
supports_search=True,
|
||||
supports_time=True,
|
||||
exchanges=exchanges,
|
||||
symbols_types=["crypto", "spot", "futures", "swap"],
|
||||
)
|
||||
|
||||
async def search_symbols(
|
||||
self,
|
||||
query: str,
|
||||
type: Optional[str] = None,
|
||||
exchange: Optional[str] = None,
|
||||
limit: int = 30,
|
||||
) -> List[SearchResult]:
|
||||
"""Search for symbols on the exchange"""
|
||||
await self._ensure_markets_loaded()
|
||||
|
||||
query_upper = query.upper()
|
||||
results = []
|
||||
|
||||
for symbol, market in self._markets.items():
|
||||
# Match query against symbol or base/quote currencies
|
||||
if (query_upper in symbol or
|
||||
query_upper in market.get('base', '') or
|
||||
query_upper in market.get('quote', '')):
|
||||
|
||||
# Filter by type if specified
|
||||
market_type = market.get('type', 'spot')
|
||||
if type and market_type != type:
|
||||
continue
|
||||
|
||||
# Create search result
|
||||
base = market.get('base', '')
|
||||
quote = market.get('quote', '')
|
||||
|
||||
results.append(
|
||||
SearchResult(
|
||||
symbol=f"{base}/{quote}", # Clean user-facing format
|
||||
ticker=f"{self.exchange_id.upper()}:{symbol}", # Ticker with exchange prefix for routing
|
||||
full_name=f"{base}/{quote} ({self.exchange_id.upper()})",
|
||||
description=f"{base}/{quote} {market_type} trading pair on {self.exchange_id}",
|
||||
exchange=self.exchange_id.upper(),
|
||||
type=market_type,
|
||||
)
|
||||
)
|
||||
|
||||
if len(results) >= limit:
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
async def resolve_symbol(self, symbol: str) -> SymbolInfo:
|
||||
"""Get complete metadata for a symbol"""
|
||||
await self._ensure_markets_loaded()
|
||||
|
||||
if symbol not in self._markets:
|
||||
raise ValueError(f"Symbol '{symbol}' not found on {self.exchange_id}")
|
||||
|
||||
market = self._markets[symbol]
|
||||
base = market.get('base', '')
|
||||
quote = market.get('quote', '')
|
||||
market_type = market.get('type', 'spot')
|
||||
|
||||
# Determine price scale from market precision
|
||||
# CCXT precision can be in different modes:
|
||||
# - DECIMAL_PLACES (int): number of decimal places (e.g., 2 = 0.01)
|
||||
# - TICK_SIZE (float): actual tick size (e.g., 0.01, 0.00001)
|
||||
# We need to convert to pricescale (10^n where n is decimal places)
|
||||
price_precision = market.get('precision', {}).get('price', 2)
|
||||
|
||||
if isinstance(price_precision, float):
|
||||
# TICK_SIZE mode: precision is the actual tick size (e.g., 0.01, 0.00001)
|
||||
# Convert tick size to decimal places
|
||||
# For 0.01 -> 2 decimal places, 0.00001 -> 5 decimal places
|
||||
# Use the Decimal exponent rather than string parsing, so very small tick
# sizes (e.g. 1e-8), which str() would render in scientific notation, still
# yield the correct number of decimal places.
tick = Decimal(str(price_precision)).normalize()
decimal_places = max(0, -tick.as_tuple().exponent)
pricescale = 10 ** decimal_places
|
||||
else:
|
||||
# DECIMAL_PLACES or SIGNIFICANT_DIGITS mode: precision is an integer
|
||||
# Assume DECIMAL_PLACES mode (most common for price)
|
||||
pricescale = 10 ** int(price_precision)
|
||||
|
||||
return SymbolInfo(
|
||||
symbol=f"{base}/{quote}", # Clean user-facing format
|
||||
ticker=f"{self.exchange_id.upper()}:{symbol}", # Ticker with exchange prefix for routing
|
||||
name=f"{base}/{quote}",
|
||||
description=f"{base}/{quote} {market_type} pair on {self.exchange_id}. "
|
||||
f"Minimum order: {market.get('limits', {}).get('amount', {}).get('min', 'N/A')} {base}",
|
||||
type=market_type,
|
||||
exchange=self.exchange_id.upper(),
|
||||
timezone="Etc/UTC",
|
||||
session="24x7",
|
||||
supported_resolutions=[
|
||||
Resolution.M1,
|
||||
Resolution.M5,
|
||||
Resolution.M15,
|
||||
Resolution.M30,
|
||||
Resolution.H1,
|
||||
Resolution.H4,
|
||||
Resolution.D1,
|
||||
],
|
||||
has_intraday=True,
|
||||
has_daily=True,
|
||||
has_weekly_and_monthly=False,
|
||||
columns=[
|
||||
ColumnInfo(
|
||||
name="open",
|
||||
type="float",
|
||||
description=f"Opening price in {quote}",
|
||||
unit=quote,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="high",
|
||||
type="float",
|
||||
description=f"Highest price in {quote}",
|
||||
unit=quote,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="low",
|
||||
type="float",
|
||||
description=f"Lowest price in {quote}",
|
||||
unit=quote,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="close",
|
||||
type="float",
|
||||
description=f"Closing price in {quote}",
|
||||
unit=quote,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="volume",
|
||||
type="float",
|
||||
description=f"Trading volume in {base}",
|
||||
unit=base,
|
||||
),
|
||||
],
|
||||
time_column="time",
|
||||
has_ohlcv=True,
|
||||
pricescale=pricescale,
|
||||
minmov=1,
|
||||
base_currency=base,
|
||||
quote_currency=quote,
|
||||
)
|
||||
|
||||
def _resolution_to_timeframe(self, resolution: str) -> str:
|
||||
"""Convert our resolution format to CCXT timeframe format"""
|
||||
# Map our resolutions to CCXT timeframes
|
||||
mapping = {
|
||||
"1": "1m",
|
||||
"5": "5m",
|
||||
"15": "15m",
|
||||
"30": "30m",
|
||||
"60": "1h",
|
||||
"120": "2h",
|
||||
"240": "4h",
|
||||
"360": "6h",
|
||||
"720": "12h",
|
||||
"1D": "1d",
|
||||
"1W": "1w",
|
||||
"1M": "1M",
|
||||
}
|
||||
return mapping.get(resolution, "1m")
|
||||
|
||||
def _timeframe_to_milliseconds(self, timeframe: str) -> int:
|
||||
"""Convert CCXT timeframe to milliseconds"""
|
||||
unit = timeframe[-1]
|
||||
amount = int(timeframe[:-1]) if len(timeframe) > 1 else 1
|
||||
|
||||
units = {
|
||||
's': 1000,
|
||||
'm': 60 * 1000,
|
||||
'h': 60 * 60 * 1000,
|
||||
'd': 24 * 60 * 60 * 1000,
|
||||
'w': 7 * 24 * 60 * 60 * 1000,
|
||||
'M': 30 * 24 * 60 * 60 * 1000, # Approximate
|
||||
}
|
||||
|
||||
return amount * units.get(unit, 60000)
|
||||
|
||||
async def get_bars(
|
||||
self,
|
||||
symbol: str,
|
||||
resolution: str,
|
||||
from_time: int,
|
||||
to_time: int,
|
||||
countback: Optional[int] = None,
|
||||
) -> HistoryResult:
|
||||
"""Get historical bars from the exchange"""
|
||||
logger.info(
|
||||
f"CCXTDataSource({self.exchange_id}).get_bars: symbol={symbol}, resolution={resolution}, "
|
||||
f"from_time={from_time}, to_time={to_time}, countback={countback}"
|
||||
)
|
||||
|
||||
await self._ensure_markets_loaded()
|
||||
|
||||
if symbol not in self._markets:
|
||||
raise ValueError(f"Symbol '{symbol}' not found on {self.exchange_id}")
|
||||
|
||||
timeframe = self._resolution_to_timeframe(resolution)
|
||||
|
||||
# CCXT uses milliseconds for timestamps
|
||||
since = from_time * 1000
|
||||
until = to_time * 1000
|
||||
|
||||
# Fetch OHLCV data
|
||||
limit = countback if countback else 1000
|
||||
|
||||
try:
|
||||
# Fetch in batches if needed
|
||||
all_ohlcv = []
|
||||
current_since = since
|
||||
|
||||
while current_since < until:
|
||||
ohlcv = await self.exchange.fetch_ohlcv(
|
||||
symbol,
|
||||
timeframe=timeframe,
|
||||
since=current_since,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
if not ohlcv:
|
||||
break
|
||||
|
||||
all_ohlcv.extend(ohlcv)
|
||||
|
||||
# Update since for next batch
|
||||
last_timestamp = ohlcv[-1][0]
|
||||
if last_timestamp <= current_since:
|
||||
break # No progress, avoid infinite loop
|
||||
current_since = last_timestamp + 1
|
||||
|
||||
# Stop if we have enough bars
|
||||
if countback and len(all_ohlcv) >= countback:
|
||||
all_ohlcv = all_ohlcv[:countback]
|
||||
break
|
||||
|
||||
# Convert to our Bar format with float for OHLCV (used in DataFrames)
|
||||
bars = []
|
||||
for candle in all_ohlcv:
|
||||
timestamp_ms, open_price, high, low, close, volume = candle
|
||||
timestamp = timestamp_ms // 1000 # Convert to seconds
|
||||
|
||||
# Only include bars within requested range
|
||||
if timestamp < from_time or timestamp >= to_time:
|
||||
continue
|
||||
|
||||
bars.append(
|
||||
Bar(
|
||||
time=timestamp,
|
||||
data={
|
||||
"open": self._to_float(open_price),
|
||||
"high": self._to_float(high),
|
||||
"low": self._to_float(low),
|
||||
"close": self._to_float(close),
|
||||
"volume": self._to_float(volume),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Get symbol info for column metadata
|
||||
symbol_info = await self.resolve_symbol(symbol)
|
||||
|
||||
logger.info(
|
||||
f"CCXTDataSource({self.exchange_id}).get_bars: Returning {len(bars)} bars. "
|
||||
f"First: {bars[0].time if bars else 'N/A'}, Last: {bars[-1].time if bars else 'N/A'}"
|
||||
)
|
||||
|
||||
# Determine if more data is available
|
||||
next_time = None
|
||||
if bars and countback and len(bars) >= countback:
|
||||
next_time = bars[-1].time + (bars[-1].time - bars[-2].time if len(bars) > 1 else 60)
|
||||
|
||||
return HistoryResult(
|
||||
symbol=symbol,
|
||||
resolution=resolution,
|
||||
bars=bars,
|
||||
columns=symbol_info.columns,
|
||||
nextTime=next_time,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch bars for {symbol}: {str(e)}")
|
||||
|
||||
async def subscribe_bars(
|
||||
self,
|
||||
symbol: str,
|
||||
resolution: str,
|
||||
on_tick: Callable[[dict], None],
|
||||
) -> str:
|
||||
"""
|
||||
Subscribe to bar updates via polling.
|
||||
|
||||
Note: Uses polling rather than WebSocket streams, since the plain CCXT library (unlike ccxt.pro) does not expose them.
|
||||
Polls at the configured interval (default: 60 seconds).
|
||||
"""
|
||||
await self._ensure_markets_loaded()
|
||||
|
||||
if symbol not in self._markets:
|
||||
raise ValueError(f"Symbol '{symbol}' not found on {self.exchange_id}")
|
||||
|
||||
subscription_id = f"{symbol}:{resolution}:{time.time()}"
|
||||
|
||||
# Store callback
|
||||
self._subscription_callbacks[subscription_id] = on_tick
|
||||
|
||||
# Start polling task
|
||||
timeframe = self._resolution_to_timeframe(resolution)
|
||||
task = asyncio.create_task(
|
||||
self._poll_ohlcv(symbol, timeframe, subscription_id)
|
||||
)
|
||||
self._subscriptions[subscription_id] = task
|
||||
|
||||
return subscription_id
|
||||
|
||||
async def _poll_ohlcv(self, symbol: str, timeframe: str, subscription_id: str):
|
||||
"""
|
||||
Poll for OHLCV updates at regular intervals.
|
||||
|
||||
This simulates real-time updates by fetching the latest bars periodically.
|
||||
Only sends updates when new bars are detected.
|
||||
"""
|
||||
try:
|
||||
while subscription_id in self._subscription_callbacks:
|
||||
try:
|
||||
# Fetch latest bars
|
||||
ohlcv = await self.exchange.fetch_ohlcv(
|
||||
symbol,
|
||||
timeframe=timeframe,
|
||||
limit=2, # Get last 2 bars to detect new ones
|
||||
)
|
||||
|
||||
if ohlcv and len(ohlcv) > 0:
|
||||
# Get the latest candle
|
||||
latest = ohlcv[-1]
|
||||
timestamp_ms, open_price, high, low, close, volume = latest
|
||||
timestamp = timestamp_ms // 1000
|
||||
|
||||
# Only send update if this is a new bar
|
||||
last_timestamp = self._last_bars.get(subscription_id, 0)
|
||||
if timestamp > last_timestamp:
|
||||
self._last_bars[subscription_id] = timestamp
|
||||
|
||||
# Convert to our format with float for OHLCV (used in DataFrames)
|
||||
tick_data = {
|
||||
"time": timestamp,
|
||||
"open": self._to_float(open_price),
|
||||
"high": self._to_float(high),
|
||||
"low": self._to_float(low),
|
||||
"close": self._to_float(close),
|
||||
"volume": self._to_float(volume),
|
||||
}
|
||||
|
||||
# Call the callback
|
||||
callback = self._subscription_callbacks.get(subscription_id)
|
||||
if callback:
|
||||
callback(tick_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error polling OHLCV for {symbol}: {e}")
|
||||
|
||||
# Wait for next poll interval
|
||||
await asyncio.sleep(self._poll_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def unsubscribe_bars(self, subscription_id: str) -> None:
|
||||
"""Unsubscribe from polling updates"""
|
||||
# Remove callback and tracking
|
||||
self._subscription_callbacks.pop(subscription_id, None)
|
||||
self._last_bars.pop(subscription_id, None)
|
||||
|
||||
# Cancel polling task
|
||||
task = self._subscriptions.pop(subscription_id, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def close(self):
|
||||
"""Close exchange connection and cleanup"""
|
||||
# Cancel all subscriptions
|
||||
for subscription_id in list(self._subscriptions.keys()):
|
||||
await self.unsubscribe_bars(subscription_id)
|
||||
|
||||
# Close exchange
|
||||
if hasattr(self.exchange, 'close'):
|
||||
await self.exchange.close()
|
||||
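Usage sketch for the adapter above. The CCXTDataSource constructor is defined earlier in this file; the exchange_id keyword below is an assumption about its signature, not a confirmed API.

import asyncio
import time

async def main():
    source = CCXTDataSource(exchange_id="binance")  # constructor args assumed
    info = await source.resolve_symbol("BTC/USDT")
    now = int(time.time())
    history = await source.get_bars(
        symbol="BTC/USDT",
        resolution="60",               # maps to the CCXT "1h" timeframe
        from_time=now - 24 * 3600,
        to_time=now,
    )
    print(info.pricescale, len(history.bars))
    await source.close()

asyncio.run(main())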
353
backend.old/src/datasource/adapters/demo.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
Demo data source with synthetic data.
|
||||
|
||||
Generates realistic-looking OHLCV data plus additional columns for testing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
from ..base import DataSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from ..schema import (
|
||||
Bar,
|
||||
ColumnInfo,
|
||||
DatafeedConfig,
|
||||
HistoryResult,
|
||||
Resolution,
|
||||
SearchResult,
|
||||
SymbolInfo,
|
||||
)
|
||||
|
||||
|
||||
class DemoDataSource(DataSource):
|
||||
"""
|
||||
Demo data source that generates synthetic time-series data.
|
||||
|
||||
Provides:
|
||||
- Standard OHLCV columns
|
||||
- Additional demo columns (RSI, sentiment, volume_profile)
|
||||
- Real-time updates via polling simulation
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._subscriptions: Dict[str, asyncio.Task] = {}
|
||||
self._symbols = {
|
||||
"DEMO:BTC/USD": {
|
||||
"name": "Bitcoin",
|
||||
"type": "crypto",
|
||||
"base_price": 50000.0,
|
||||
"volatility": 0.02,
|
||||
},
|
||||
"DEMO:ETH/USD": {
|
||||
"name": "Ethereum",
|
||||
"type": "crypto",
|
||||
"base_price": 3000.0,
|
||||
"volatility": 0.03,
|
||||
},
|
||||
"DEMO:SOL/USD": {
|
||||
"name": "Solana",
|
||||
"type": "crypto",
|
||||
"base_price": 100.0,
|
||||
"volatility": 0.04,
|
||||
},
|
||||
}
|
||||
|
||||
async def get_config(self) -> DatafeedConfig:
|
||||
return DatafeedConfig(
|
||||
name="Demo DataSource",
|
||||
description="Synthetic data generator for testing. Provides OHLCV plus additional indicator columns.",
|
||||
supported_resolutions=[
|
||||
Resolution.M1,
|
||||
Resolution.M5,
|
||||
Resolution.M15,
|
||||
Resolution.H1,
|
||||
Resolution.D1,
|
||||
],
|
||||
supports_search=True,
|
||||
supports_time=True,
|
||||
exchanges=["DEMO"],
|
||||
symbols_types=["crypto"],
|
||||
)
|
||||
|
||||
async def search_symbols(
|
||||
self,
|
||||
query: str,
|
||||
type: Optional[str] = None,
|
||||
exchange: Optional[str] = None,
|
||||
limit: int = 30,
|
||||
) -> List[SearchResult]:
|
||||
query_lower = query.lower()
|
||||
results = []
|
||||
|
||||
for symbol, info in self._symbols.items():
|
||||
if query_lower in symbol.lower() or query_lower in info["name"].lower():
|
||||
if type and info["type"] != type:
|
||||
continue
|
||||
results.append(
|
||||
SearchResult(
|
||||
symbol=info['name'], # Clean user-facing format (e.g., "Bitcoin")
|
||||
ticker=symbol, # Keep DEMO:BTC/USD format for routing
|
||||
full_name=f"{info['name']} (DEMO)",
|
||||
description=f"Demo {info['name']} pair",
|
||||
exchange="DEMO",
|
||||
type=info["type"],
|
||||
)
|
||||
)
|
||||
|
||||
return results[:limit]
|
||||
|
||||
async def resolve_symbol(self, symbol: str) -> SymbolInfo:
|
||||
if symbol not in self._symbols:
|
||||
raise ValueError(f"Symbol '{symbol}' not found")
|
||||
|
||||
info = self._symbols[symbol]
|
||||
base, quote = symbol.split(":")[1].split("/")
|
||||
|
||||
return SymbolInfo(
|
||||
symbol=info["name"], # Clean user-facing format (e.g., "Bitcoin")
|
||||
ticker=symbol, # Keep DEMO:BTC/USD format for routing
|
||||
name=info["name"],
|
||||
description=f"Demo {info['name']} spot price with synthetic indicators",
|
||||
type=info["type"],
|
||||
exchange="DEMO",
|
||||
timezone="Etc/UTC",
|
||||
session="24x7",
|
||||
supported_resolutions=[Resolution.M1, Resolution.M5, Resolution.M15, Resolution.H1, Resolution.D1],
|
||||
has_intraday=True,
|
||||
has_daily=True,
|
||||
has_weekly_and_monthly=False,
|
||||
columns=[
|
||||
ColumnInfo(
|
||||
name="open",
|
||||
type="float",
|
||||
description=f"Opening price in {quote}",
|
||||
unit=quote,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="high",
|
||||
type="float",
|
||||
description=f"Highest price in {quote}",
|
||||
unit=quote,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="low",
|
||||
type="float",
|
||||
description=f"Lowest price in {quote}",
|
||||
unit=quote,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="close",
|
||||
type="float",
|
||||
description=f"Closing price in {quote}",
|
||||
unit=quote,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="volume",
|
||||
type="float",
|
||||
description=f"Trading volume in {base}",
|
||||
unit=base,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="rsi",
|
||||
type="float",
|
||||
description="Relative Strength Index (14-period), range 0-100",
|
||||
unit=None,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="sentiment",
|
||||
type="float",
|
||||
description="Synthetic social sentiment score, range -1.0 to 1.0",
|
||||
unit=None,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="volume_profile",
|
||||
type="float",
|
||||
description="Volume as percentage of 24h average",
|
||||
unit="%",
|
||||
),
|
||||
],
|
||||
time_column="time",
|
||||
has_ohlcv=True,
|
||||
pricescale=100,
|
||||
minmov=1,
|
||||
base_currency=base,
|
||||
quote_currency=quote,
|
||||
)
|
||||
|
||||
async def get_bars(
|
||||
self,
|
||||
symbol: str,
|
||||
resolution: str,
|
||||
from_time: int,
|
||||
to_time: int,
|
||||
countback: Optional[int] = None,
|
||||
) -> HistoryResult:
|
||||
if symbol not in self._symbols:
|
||||
raise ValueError(f"Symbol '{symbol}' not found")
|
||||
|
||||
logger.info(
|
||||
f"DemoDataSource.get_bars: symbol={symbol}, resolution={resolution}, "
|
||||
f"from_time={from_time}, to_time={to_time}, countback={countback}"
|
||||
)
|
||||
|
||||
info = self._symbols[symbol]
|
||||
symbol_meta = await self.resolve_symbol(symbol)
|
||||
|
||||
# Convert resolution to seconds
|
||||
resolution_seconds = self._resolution_to_seconds(resolution)
|
||||
|
||||
# Generate bars
|
||||
bars = []
|
||||
# Align current_time to resolution, but ensure it's >= from_time
|
||||
current_time = from_time - (from_time % resolution_seconds)
|
||||
if current_time < from_time:
|
||||
current_time += resolution_seconds
|
||||
|
||||
price = info["base_price"]
|
||||
|
||||
bar_count = 0
|
||||
max_bars = countback if countback else 5000
|
||||
|
||||
while current_time <= to_time and bar_count < max_bars:
|
||||
bar_data = self._generate_bar(current_time, price, info["volatility"], resolution_seconds)
|
||||
|
||||
# Only include bars within the requested range
|
||||
if from_time <= current_time <= to_time:
|
||||
bars.append(Bar(time=current_time, data=bar_data)) # Unix seconds, matching the Bar schema and the nextTime value below
|
||||
bar_count += 1
|
||||
|
||||
price = bar_data["close"] # Next bar starts from previous close
|
||||
current_time += resolution_seconds
|
||||
|
||||
logger.info(
|
||||
f"DemoDataSource.get_bars: Generated {len(bars)} bars. "
|
||||
f"First: {bars[0].time if bars else 'N/A'}, Last: {bars[-1].time if bars else 'N/A'}"
|
||||
)
|
||||
|
||||
# Determine if there's more data (for pagination)
|
||||
next_time = current_time if current_time <= to_time else None
|
||||
|
||||
return HistoryResult(
|
||||
symbol=symbol,
|
||||
resolution=resolution,
|
||||
bars=bars,
|
||||
columns=symbol_meta.columns,
|
||||
nextTime=next_time,
|
||||
)
|
||||
|
||||
async def subscribe_bars(
|
||||
self,
|
||||
symbol: str,
|
||||
resolution: str,
|
||||
on_tick: Callable[[dict], None],
|
||||
) -> str:
|
||||
if symbol not in self._symbols:
|
||||
raise ValueError(f"Symbol '{symbol}' not found")
|
||||
|
||||
subscription_id = f"{symbol}:{resolution}:{time.time()}"
|
||||
|
||||
# Start background task to simulate real-time updates
|
||||
task = asyncio.create_task(
|
||||
self._tick_generator(symbol, resolution, on_tick, subscription_id)
|
||||
)
|
||||
self._subscriptions[subscription_id] = task
|
||||
|
||||
return subscription_id
|
||||
|
||||
async def unsubscribe_bars(self, subscription_id: str) -> None:
|
||||
task = self._subscriptions.pop(subscription_id, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def _resolution_to_seconds(self, resolution: str) -> int:
|
||||
"""Convert resolution string to seconds"""
|
||||
if resolution.endswith("D"):
|
||||
return int(resolution[:-1]) * 86400
|
||||
elif resolution.endswith("W"):
|
||||
return int(resolution[:-1]) * 604800
|
||||
elif resolution.endswith("M"):
|
||||
return int(resolution[:-1]) * 2592000 # Approximate month
|
||||
else:
|
||||
# Assume minutes
|
||||
return int(resolution) * 60
|
||||
|
||||
def _generate_bar(self, timestamp: int, base_price: float, volatility: float, period_seconds: int) -> dict:
|
||||
"""Generate a single synthetic OHLCV bar"""
|
||||
# Random walk for the period
|
||||
open_price = base_price
|
||||
|
||||
# Generate intra-period price movement
|
||||
num_ticks = max(10, period_seconds // 60) # More ticks for longer periods
|
||||
prices = [open_price]
|
||||
|
||||
for _ in range(num_ticks):
|
||||
change = random.gauss(0, volatility / math.sqrt(num_ticks))
|
||||
prices.append(prices[-1] * (1 + change))
|
||||
|
||||
close_price = prices[-1]
|
||||
high_price = max(prices)
|
||||
low_price = min(prices)
|
||||
|
||||
# Volume with some randomness
|
||||
base_volume = 1000000 * (period_seconds / 60) # Scale with period
|
||||
volume = base_volume * random.uniform(0.5, 2.0)
|
||||
|
||||
# Additional synthetic indicators
|
||||
rsi = 30 + random.random() * 40 # Biased toward middle range
|
||||
sentiment = math.sin(timestamp / 3600) * 0.5 + random.gauss(0, 0.2) # Hourly cycle + noise
|
||||
sentiment = max(-1.0, min(1.0, sentiment))
|
||||
volume_profile = 100 * random.uniform(0.5, 1.5)
|
||||
|
||||
return {
|
||||
"open": round(open_price, 2),
|
||||
"high": round(high_price, 2),
|
||||
"low": round(low_price, 2),
|
||||
"close": round(close_price, 2),
|
||||
"volume": round(volume, 2),
|
||||
"rsi": round(rsi, 2),
|
||||
"sentiment": round(sentiment, 3),
|
||||
"volume_profile": round(volume_profile, 2),
|
||||
}
|
||||
|
||||
async def _tick_generator(
|
||||
self,
|
||||
symbol: str,
|
||||
resolution: str,
|
||||
on_tick: Callable[[dict], None],
|
||||
subscription_id: str,
|
||||
):
|
||||
"""Background task that generates periodic ticks"""
|
||||
info = self._symbols[symbol]
|
||||
resolution_seconds = self._resolution_to_seconds(resolution)
|
||||
|
||||
# Start from current aligned time
|
||||
current_time = int(time.time())
|
||||
current_time = current_time - (current_time % resolution_seconds)
|
||||
price = info["base_price"]
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Wait until next bar
|
||||
await asyncio.sleep(resolution_seconds)
|
||||
|
||||
current_time += resolution_seconds
|
||||
bar_data = self._generate_bar(current_time, price, info["volatility"], resolution_seconds)
|
||||
price = bar_data["close"]
|
||||
|
||||
# Call the tick handler
|
||||
tick_data = {"time": current_time, **bar_data}
|
||||
on_tick(tick_data)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Subscription cancelled
|
||||
pass
|
||||
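A quick sketch of exercising the demo source, useful for checking the extra synthetic columns (rsi, sentiment, volume_profile) without hitting a real exchange:

import asyncio
import time

async def demo():
    source = DemoDataSource()
    results = await source.search_symbols("btc")
    print([r.ticker for r in results])          # ['DEMO:BTC/USD']

    now = int(time.time())
    history = await source.get_bars("DEMO:BTC/USD", "60", now - 6 * 3600, now)
    first = history.bars[0]
    print(first.close, first.data["rsi"], first.data["sentiment"])

asyncio.run(demo())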
146
backend.old/src/datasource/base.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Abstract DataSource interface.
|
||||
|
||||
Inspired by TradingView's Datafeed API with extensions for flexible column schemas
|
||||
and AI-native metadata.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from .schema import DatafeedConfig, HistoryResult, SearchResult, SymbolInfo
|
||||
|
||||
|
||||
class DataSource(ABC):
|
||||
"""
|
||||
Abstract base class for time-series data sources.
|
||||
|
||||
Provides a standardized interface for:
|
||||
- Symbol search and metadata retrieval
|
||||
- Historical data queries (time-based, paginated)
|
||||
- Real-time data subscriptions
|
||||
|
||||
All data rows must have a timestamp. Additional columns are flexible
|
||||
and described via ColumnInfo metadata.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_config(self) -> DatafeedConfig:
|
||||
"""
|
||||
Get datafeed configuration and capabilities.
|
||||
|
||||
Called once during initialization to understand what this data source
|
||||
supports (resolutions, exchanges, search, etc.).
|
||||
|
||||
Returns:
|
||||
DatafeedConfig describing this datafeed's capabilities
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def search_symbols(
|
||||
self,
|
||||
query: str,
|
||||
type: Optional[str] = None,
|
||||
exchange: Optional[str] = None,
|
||||
limit: int = 30,
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
Search for symbols matching a text query.
|
||||
|
||||
Args:
|
||||
query: Free-text search string
|
||||
type: Optional filter by instrument type
|
||||
exchange: Optional filter by exchange
|
||||
limit: Maximum number of results
|
||||
|
||||
Returns:
|
||||
List of matching symbols with basic metadata
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def resolve_symbol(self, symbol: str) -> SymbolInfo:
|
||||
"""
|
||||
Get complete metadata for a symbol.
|
||||
|
||||
Called after a symbol is selected to retrieve full information including
|
||||
supported resolutions, column schema, trading session, etc.
|
||||
|
||||
Args:
|
||||
symbol: Symbol identifier
|
||||
|
||||
Returns:
|
||||
Complete SymbolInfo including column definitions
|
||||
|
||||
Raises:
|
||||
ValueError: If symbol is not found
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_bars(
|
||||
self,
|
||||
symbol: str,
|
||||
resolution: str,
|
||||
from_time: int,
|
||||
to_time: int,
|
||||
countback: Optional[int] = None,
|
||||
) -> HistoryResult:
|
||||
"""
|
||||
Get historical bars for a symbol and resolution.
|
||||
|
||||
Time range is specified in Unix timestamps (seconds). If more data is
|
||||
available beyond the requested range, the result should include a
|
||||
nextTime value for pagination.
|
||||
|
||||
Args:
|
||||
symbol: Symbol identifier
|
||||
resolution: Time resolution (e.g., "1", "5", "60", "1D")
|
||||
from_time: Start time (Unix timestamp in seconds)
|
||||
to_time: End time (Unix timestamp in seconds)
|
||||
countback: Optional limit on number of bars to return
|
||||
|
||||
Returns:
|
||||
HistoryResult with bars, column schema, and pagination info
|
||||
|
||||
Raises:
|
||||
ValueError: If symbol or resolution is not supported
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def subscribe_bars(
|
||||
self,
|
||||
symbol: str,
|
||||
resolution: str,
|
||||
on_tick: Callable[[dict], None],
|
||||
) -> str:
|
||||
"""
|
||||
Subscribe to real-time bar updates.
|
||||
|
||||
The callback will be invoked with new bar data as it becomes available.
|
||||
The data dict will match the column schema from resolve_symbol().
|
||||
|
||||
Args:
|
||||
symbol: Symbol identifier
|
||||
resolution: Time resolution
|
||||
on_tick: Callback function receiving bar data dict
|
||||
|
||||
Returns:
|
||||
Subscription ID for later unsubscribe
|
||||
|
||||
Raises:
|
||||
ValueError: If symbol or resolution is not supported
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def unsubscribe_bars(self, subscription_id: str) -> None:
|
||||
"""
|
||||
Unsubscribe from real-time updates.
|
||||
|
||||
Args:
|
||||
subscription_id: ID returned from subscribe_bars()
|
||||
"""
|
||||
pass
|
||||
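For reference, a skeleton of the minimum surface a concrete adapter must provide; the bodies below are placeholders for illustration, not a working feed:

class StaticDataSource(DataSource):
    # Hypothetical adapter with no symbols and no streaming support.

    async def get_config(self) -> DatafeedConfig:
        return DatafeedConfig(
            name="Static",
            description="Illustrative skeleton adapter.",
            supported_resolutions=["1D"],
        )

    async def search_symbols(self, query, type=None, exchange=None, limit=30):
        return []

    async def resolve_symbol(self, symbol):
        raise ValueError(f"Symbol '{symbol}' not found")

    async def get_bars(self, symbol, resolution, from_time, to_time, countback=None):
        raise ValueError(f"Symbol '{symbol}' not found")

    async def subscribe_bars(self, symbol, resolution, on_tick):
        raise ValueError("Real-time updates not supported")

    async def unsubscribe_bars(self, subscription_id):
        pass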
109
backend.old/src/datasource/registry.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
DataSource registry for managing multiple data sources.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .base import DataSource
|
||||
from .schema import SearchResult, SymbolInfo
|
||||
|
||||
|
||||
class DataSourceRegistry:
|
||||
"""
|
||||
Central registry for managing multiple DataSource instances.
|
||||
|
||||
Allows routing symbol queries to the appropriate data source and
|
||||
aggregating search results across multiple sources.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._sources: Dict[str, DataSource] = {}
|
||||
|
||||
def register(self, name: str, source: DataSource) -> None:
|
||||
"""
|
||||
Register a data source.
|
||||
|
||||
Args:
|
||||
name: Unique name for this data source
|
||||
source: DataSource implementation
|
||||
"""
|
||||
self._sources[name] = source
|
||||
|
||||
def unregister(self, name: str) -> None:
|
||||
"""
|
||||
Unregister a data source.
|
||||
|
||||
Args:
|
||||
name: Name of the data source to remove
|
||||
"""
|
||||
self._sources.pop(name, None)
|
||||
|
||||
def get(self, name: str) -> Optional[DataSource]:
|
||||
"""
|
||||
Get a registered data source by name.
|
||||
|
||||
Args:
|
||||
name: Data source name
|
||||
|
||||
Returns:
|
||||
DataSource instance or None if not found
|
||||
"""
|
||||
return self._sources.get(name)
|
||||
|
||||
def list_sources(self) -> List[str]:
|
||||
"""
|
||||
Get names of all registered data sources.
|
||||
|
||||
Returns:
|
||||
List of data source names
|
||||
"""
|
||||
return list(self._sources.keys())
|
||||
|
||||
async def search_all(
|
||||
self,
|
||||
query: str,
|
||||
type: Optional[str] = None,
|
||||
exchange: Optional[str] = None,
|
||||
limit: int = 30,
|
||||
) -> Dict[str, List[SearchResult]]:
|
||||
"""
|
||||
Search across all registered data sources.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
type: Optional instrument type filter
|
||||
exchange: Optional exchange filter
|
||||
limit: Maximum results per source
|
||||
|
||||
Returns:
|
||||
Dict mapping source name to search results
|
||||
"""
|
||||
results = {}
|
||||
for name, source in self._sources.items():
|
||||
try:
|
||||
source_results = await source.search_symbols(query, type, exchange, limit)
|
||||
if source_results:
|
||||
results[name] = source_results
|
||||
except Exception:
|
||||
# Silently skip sources that error during search
|
||||
continue
|
||||
return results
|
||||
|
||||
async def resolve_symbol(self, source_name: str, symbol: str) -> SymbolInfo:
|
||||
"""
|
||||
Resolve a symbol from a specific data source.
|
||||
|
||||
Args:
|
||||
source_name: Name of the data source
|
||||
symbol: Symbol identifier
|
||||
|
||||
Returns:
|
||||
SymbolInfo from the specified source
|
||||
|
||||
Raises:
|
||||
ValueError: If source not found or symbol not found
|
||||
"""
|
||||
source = self.get(source_name)
|
||||
if not source:
|
||||
raise ValueError(f"Data source '{source_name}' not found")
|
||||
return await source.resolve_symbol(symbol)
|
||||
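Typical startup wiring for the registry, sketched with the adapters shown above (the CCXTDataSource constructor arguments are assumed):

async def setup_sources() -> DataSourceRegistry:
    registry = DataSourceRegistry()
    registry.register("demo", DemoDataSource())
    registry.register("binance", CCXTDataSource(exchange_id="binance"))  # args assumed

    # Fan-out search across every registered source.
    hits_by_source = await registry.search_all("eth", limit=10)
    for name, hits in hits_by_source.items():
        print(name, [h.ticker for h in hits])

    # Route a resolve to one specific source.
    info = await registry.resolve_symbol("binance", "ETH/USDT")
    print(info.description)
    return registry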
194
backend.old/src/datasource/schema.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
Data models for the DataSource interface.
|
||||
|
||||
Inspired by TradingView's Datafeed API but with flexible column schemas
|
||||
for AI-native trading platform needs.
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Resolution(StrEnum):
|
||||
"""Standard time resolutions for bar data"""
|
||||
|
||||
# Seconds
|
||||
S1 = "1S"
|
||||
S5 = "5S"
|
||||
S15 = "15S"
|
||||
S30 = "30S"
|
||||
|
||||
# Minutes
|
||||
M1 = "1"
|
||||
M5 = "5"
|
||||
M15 = "15"
|
||||
M30 = "30"
|
||||
|
||||
# Hours
|
||||
H1 = "60"
|
||||
H2 = "120"
|
||||
H4 = "240"
|
||||
H6 = "360"
|
||||
H12 = "720"
|
||||
|
||||
# Days
|
||||
D1 = "1D"
|
||||
|
||||
# Weeks
|
||||
W1 = "1W"
|
||||
|
||||
# Months
|
||||
MO1 = "1M"
|
||||
|
||||
|
||||
class ColumnInfo(BaseModel):
|
||||
"""
|
||||
Metadata for a single data column.
|
||||
|
||||
Provides rich, LLM-readable descriptions so AI agents can understand
|
||||
and reason about available data fields.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
name: str = Field(description="Column name (e.g., 'close', 'volume', 'funding_rate')")
|
||||
type: Literal["float", "int", "bool", "string", "decimal"] = Field(description="Data type")
|
||||
description: str = Field(description="Human and LLM-readable description of what this column represents")
|
||||
unit: Optional[str] = Field(default=None, description="Unit of measurement (e.g., 'USD', 'BTC', '%', 'contracts')")
|
||||
nullable: bool = Field(default=False, description="Whether this column can contain null values")
|
||||
|
||||
|
||||
class SymbolInfo(BaseModel):
|
||||
"""
|
||||
Complete metadata for a tradeable symbol.
|
||||
|
||||
Includes both TradingView-compatible fields and flexible schema definition
|
||||
for arbitrary data columns.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
# Core identification
|
||||
symbol: str = Field(description="Unique symbol identifier (primary key for data fetching)")
|
||||
ticker: Optional[str] = Field(default=None, description="TradingView ticker (if different from symbol)")
|
||||
name: str = Field(description="Display name")
|
||||
description: str = Field(description="LLM-readable description of the instrument")
|
||||
type: str = Field(description="Instrument type: 'crypto', 'stock', 'forex', 'futures', 'derived', etc.")
|
||||
exchange: str = Field(description="Exchange or data source identifier")
|
||||
|
||||
# Trading session info
|
||||
timezone: str = Field(default="Etc/UTC", description="IANA timezone identifier")
|
||||
session: str = Field(default="24x7", description="Trading session spec (e.g., '0930-1600' or '24x7')")
|
||||
|
||||
# Resolution support
|
||||
supported_resolutions: List[str] = Field(description="List of supported time resolutions")
|
||||
has_intraday: bool = Field(default=True, description="Whether intraday resolutions are supported")
|
||||
has_daily: bool = Field(default=True, description="Whether daily resolution is supported")
|
||||
has_weekly_and_monthly: bool = Field(default=False, description="Whether weekly/monthly resolutions are supported")
|
||||
|
||||
# Flexible schema definition
|
||||
columns: List[ColumnInfo] = Field(description="Available data columns for this symbol")
|
||||
time_column: str = Field(default="time", description="Name of the timestamp column")
|
||||
|
||||
# Convenience flags
|
||||
has_ohlcv: bool = Field(default=False, description="Whether standard OHLCV columns are present")
|
||||
|
||||
# Price display (for OHLCV data)
|
||||
pricescale: int = Field(default=100, description="Price scale factor (e.g., 100 for 2 decimals)")
|
||||
minmov: int = Field(default=1, description="Minimum price movement in pricescale units")
|
||||
|
||||
# Additional metadata
|
||||
base_currency: Optional[str] = Field(default=None, description="Base currency (for crypto/forex)")
|
||||
quote_currency: Optional[str] = Field(default=None, description="Quote currency (for crypto/forex)")
|
||||
|
||||
|
||||
class Bar(BaseModel):
|
||||
"""
|
||||
A single bar/row of time-series data with flexible columns.
|
||||
|
||||
All bars must have a timestamp. Additional columns are stored in the
|
||||
data dict and described by the associated ColumnInfo metadata.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
time: int = Field(description="Unix timestamp in seconds")
|
||||
data: Dict[str, Any] = Field(description="Column name -> value mapping")
|
||||
|
||||
# Convenience accessors for common OHLCV columns
|
||||
@property
|
||||
def open(self) -> Optional[float]:
|
||||
return self.data.get("open")
|
||||
|
||||
@property
|
||||
def high(self) -> Optional[float]:
|
||||
return self.data.get("high")
|
||||
|
||||
@property
|
||||
def low(self) -> Optional[float]:
|
||||
return self.data.get("low")
|
||||
|
||||
@property
|
||||
def close(self) -> Optional[float]:
|
||||
return self.data.get("close")
|
||||
|
||||
@property
|
||||
def volume(self) -> Optional[float]:
|
||||
return self.data.get("volume")
|
||||
|
||||
|
||||
class HistoryResult(BaseModel):
|
||||
"""
|
||||
Result from a historical data query.
|
||||
|
||||
Includes the bars, schema information, and pagination metadata.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
symbol: str = Field(description="Symbol identifier")
|
||||
resolution: str = Field(description="Time resolution of the bars")
|
||||
bars: List[Bar] = Field(description="The actual data bars")
|
||||
columns: List[ColumnInfo] = Field(description="Schema describing the bar data columns")
|
||||
nextTime: Optional[int] = Field(default=None, description="Unix timestamp for pagination (if more data available)")
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
"""
|
||||
A single result from symbol search.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
symbol: str = Field(description="Display symbol (e.g., 'BINANCE:ETH/BTC')")
|
||||
ticker: Optional[str] = Field(default=None, description="Backend ticker for data fetching (e.g., 'ETH/BTC')")
|
||||
full_name: str = Field(description="Full display name including exchange")
|
||||
description: str = Field(description="Human-readable description")
|
||||
exchange: str = Field(description="Exchange identifier")
|
||||
type: str = Field(description="Instrument type")
|
||||
|
||||
|
||||
class DatafeedConfig(BaseModel):
|
||||
"""
|
||||
Configuration and capabilities of a DataSource.
|
||||
|
||||
Similar to TradingView's onReady configuration object.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
# Supported features
|
||||
supported_resolutions: List[str] = Field(description="All resolutions this datafeed supports")
|
||||
supports_search: bool = Field(default=True, description="Whether symbol search is available")
|
||||
supports_time: bool = Field(default=True, description="Whether time-based queries are supported")
|
||||
supports_marks: bool = Field(default=False, description="Whether marks/events are supported")
|
||||
|
||||
# Data characteristics
|
||||
exchanges: List[str] = Field(default_factory=list, description="Available exchanges")
|
||||
symbols_types: List[str] = Field(default_factory=list, description="Available instrument types")
|
||||
|
||||
# Metadata
|
||||
name: str = Field(description="Datafeed name")
|
||||
description: str = Field(description="LLM-readable description of this data source")
|
||||
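The models above are plain Pydantic; a small sketch of how a bar and a history payload fit together (the values are made up):

bar = Bar(
    time=1700000000,                                  # Unix seconds
    data={"open": 50000.0, "high": 50500.0, "low": 49800.0,
          "close": 50250.0, "volume": 12.5},
)
assert bar.close == 50250.0                           # convenience property reads from data

history = HistoryResult(
    symbol="BTC/USDT",
    resolution="60",
    bars=[bar],
    columns=[
        ColumnInfo(name="close", type="float",
                   description="Closing price in USDT", unit="USDT"),
    ],
)
print(history.model_dump(mode="json")["nextTime"])    # None when no more data is available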
235
backend.old/src/datasource/subscription_manager.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
Subscription manager for real-time data feeds.
|
||||
|
||||
Manages subscriptions across multiple data sources and routes updates
|
||||
to WebSocket clients.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Callable, Dict, Optional, Set
|
||||
|
||||
from .base import DataSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Subscription:
|
||||
"""Represents a single client subscription"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
subscription_id: str,
|
||||
client_id: str,
|
||||
source_name: str,
|
||||
symbol: str,
|
||||
resolution: str,
|
||||
callback: Callable[[dict], None],
|
||||
):
|
||||
self.subscription_id = subscription_id
|
||||
self.client_id = client_id
|
||||
self.source_name = source_name
|
||||
self.symbol = symbol
|
||||
self.resolution = resolution
|
||||
self.callback = callback
|
||||
self.source_subscription_id: Optional[str] = None
|
||||
|
||||
|
||||
class SubscriptionManager:
|
||||
"""
|
||||
Manages real-time data subscriptions across multiple data sources.
|
||||
|
||||
Handles:
|
||||
- Subscription lifecycle (subscribe/unsubscribe)
|
||||
- Routing updates from data sources to clients
|
||||
- Multiplexing (multiple clients can subscribe to same symbol/resolution)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Map subscription_id -> Subscription
|
||||
self._subscriptions: Dict[str, Subscription] = {}
|
||||
|
||||
# Map (source_name, symbol, resolution) -> Set[subscription_id]
|
||||
# For tracking which client subscriptions use which source subscriptions
|
||||
self._source_refs: Dict[tuple, Set[str]] = {}
|
||||
|
||||
# Map source_subscription_id -> (source_name, symbol, resolution)
|
||||
self._source_subs: Dict[str, tuple] = {}
|
||||
|
||||
# Available data sources
|
||||
self._sources: Dict[str, DataSource] = {}
|
||||
|
||||
def register_source(self, name: str, source: DataSource) -> None:
|
||||
"""Register a data source"""
|
||||
self._sources[name] = source
|
||||
|
||||
def unregister_source(self, name: str) -> None:
|
||||
"""Unregister a data source"""
|
||||
self._sources.pop(name, None)
|
||||
|
||||
async def subscribe(
|
||||
self,
|
||||
subscription_id: str,
|
||||
client_id: str,
|
||||
source_name: str,
|
||||
symbol: str,
|
||||
resolution: str,
|
||||
callback: Callable[[dict], None],
|
||||
) -> None:
|
||||
"""
|
||||
Subscribe a client to real-time updates.
|
||||
|
||||
Args:
|
||||
subscription_id: Unique ID for this subscription
|
||||
client_id: ID of the subscribing client
|
||||
source_name: Name of the data source
|
||||
symbol: Symbol to subscribe to
|
||||
resolution: Time resolution
|
||||
callback: Function to call with bar updates
|
||||
|
||||
Raises:
|
||||
ValueError: If source not found or subscription fails
|
||||
"""
|
||||
source = self._sources.get(source_name)
|
||||
if not source:
|
||||
raise ValueError(f"Data source '{source_name}' not found")
|
||||
|
||||
# Create subscription record
|
||||
subscription = Subscription(
|
||||
subscription_id=subscription_id,
|
||||
client_id=client_id,
|
||||
source_name=source_name,
|
||||
symbol=symbol,
|
||||
resolution=resolution,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
# Check if we already have a source subscription for this (source, symbol, resolution)
|
||||
source_key = (source_name, symbol, resolution)
|
||||
if source_key not in self._source_refs:
|
||||
# Need to create a new source subscription
|
||||
try:
|
||||
source_sub_id = await source.subscribe_bars(
|
||||
symbol=symbol,
|
||||
resolution=resolution,
|
||||
on_tick=lambda bar: self._on_source_update(source_key, bar),
|
||||
)
|
||||
subscription.source_subscription_id = source_sub_id
|
||||
self._source_subs[source_sub_id] = source_key
|
||||
self._source_refs[source_key] = set()
|
||||
logger.info(
|
||||
f"Created new source subscription: {source_name}/{symbol}/{resolution} -> {source_sub_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to subscribe to source: {e}")
|
||||
raise
|
||||
|
||||
# Add this subscription to the reference set
|
||||
self._source_refs[source_key].add(subscription_id)
|
||||
self._subscriptions[subscription_id] = subscription
|
||||
|
||||
logger.info(
|
||||
f"Client subscription added: {subscription_id} ({client_id}) -> {source_name}/{symbol}/{resolution}"
|
||||
)
|
||||
|
||||
async def unsubscribe(self, subscription_id: str) -> None:
|
||||
"""
|
||||
Unsubscribe a client from updates.
|
||||
|
||||
Args:
|
||||
subscription_id: ID of the subscription to remove
|
||||
"""
|
||||
subscription = self._subscriptions.pop(subscription_id, None)
|
||||
if not subscription:
|
||||
logger.warning(f"Subscription {subscription_id} not found")
|
||||
return
|
||||
|
||||
source_key = (subscription.source_name, subscription.symbol, subscription.resolution)
|
||||
|
||||
# Remove from reference set
|
||||
if source_key in self._source_refs:
|
||||
self._source_refs[source_key].discard(subscription_id)
|
||||
|
||||
# If no more clients need this source subscription, unsubscribe from source
|
||||
if not self._source_refs[source_key]:
|
||||
del self._source_refs[source_key]
|
||||
|
||||
if subscription.source_subscription_id:
|
||||
source = self._sources.get(subscription.source_name)
|
||||
if source:
|
||||
try:
|
||||
await source.unsubscribe_bars(subscription.source_subscription_id)
|
||||
logger.info(
|
||||
f"Unsubscribed from source: {subscription.source_subscription_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error unsubscribing from source: {e}")
|
||||
|
||||
self._source_subs.pop(subscription.source_subscription_id, None)
|
||||
|
||||
logger.info(f"Client subscription removed: {subscription_id}")
|
||||
|
||||
async def unsubscribe_client(self, client_id: str) -> None:
|
||||
"""
|
||||
Unsubscribe all subscriptions for a client.
|
||||
|
||||
Useful when a WebSocket connection closes.
|
||||
|
||||
Args:
|
||||
client_id: ID of the client
|
||||
"""
|
||||
# Find all subscriptions for this client
|
||||
client_subs = [
|
||||
sub_id
|
||||
for sub_id, sub in self._subscriptions.items()
|
||||
if sub.client_id == client_id
|
||||
]
|
||||
|
||||
# Unsubscribe each one
|
||||
for sub_id in client_subs:
|
||||
await self.unsubscribe(sub_id)
|
||||
|
||||
logger.info(f"Unsubscribed all subscriptions for client {client_id}")
|
||||
|
||||
def _on_source_update(self, source_key: tuple, bar: dict) -> None:
|
||||
"""
|
||||
Handle an update from a data source.
|
||||
|
||||
Routes the update to all client subscriptions that need it.
|
||||
|
||||
Args:
|
||||
source_key: (source_name, symbol, resolution) tuple
|
||||
bar: Bar data dict from the source
|
||||
"""
|
||||
subscription_ids = self._source_refs.get(source_key, set())
|
||||
|
||||
for sub_id in subscription_ids:
|
||||
subscription = self._subscriptions.get(sub_id)
|
||||
if subscription:
|
||||
try:
|
||||
subscription.callback(bar)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in subscription callback {sub_id}: {e}", exc_info=True
|
||||
)
|
||||
|
||||
def get_subscription_count(self) -> int:
|
||||
"""Get total number of active client subscriptions"""
|
||||
return len(self._subscriptions)
|
||||
|
||||
def get_source_subscription_count(self) -> int:
|
||||
"""Get total number of active source subscriptions"""
|
||||
return len(self._source_refs)
|
||||
|
||||
def get_client_subscriptions(self, client_id: str) -> list:
|
||||
"""Get all subscriptions for a specific client"""
|
||||
return [
|
||||
{
|
||||
"subscription_id": sub.subscription_id,
|
||||
"source": sub.source_name,
|
||||
"symbol": sub.symbol,
|
||||
"resolution": sub.resolution,
|
||||
}
|
||||
for sub in self._subscriptions.values()
|
||||
if sub.client_id == client_id
|
||||
]
|
||||
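Driving the manager directly, sketched against the demo source; the subscription and client IDs are arbitrary client-chosen strings:

manager = SubscriptionManager()
manager.register_source("demo", DemoDataSource())

def on_bar(bar: dict) -> None:
    print("update:", bar["time"], bar.get("close"))

async def run_subscription() -> None:
    await manager.subscribe(
        subscription_id="sub-1",
        client_id="client-42",
        source_name="demo",
        symbol="DEMO:BTC/USD",
        resolution="1",
        callback=on_bar,
    )
    print(manager.get_subscription_count())   # 1
    await manager.unsubscribe("sub-1")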
347
backend.old/src/datasource/websocket_handler.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
WebSocket handler for TradingView-compatible datafeed API.
|
||||
|
||||
Handles incoming requests for symbol search, metadata, historical data,
|
||||
and real-time subscriptions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from .base import DataSource
|
||||
from .registry import DataSourceRegistry
|
||||
from .subscription_manager import SubscriptionManager
|
||||
from .websocket_protocol import (
|
||||
BarUpdateMessage,
|
||||
ErrorResponse,
|
||||
GetBarsRequest,
|
||||
GetBarsResponse,
|
||||
GetConfigRequest,
|
||||
GetConfigResponse,
|
||||
ResolveSymbolRequest,
|
||||
ResolveSymbolResponse,
|
||||
SearchSymbolsRequest,
|
||||
SearchSymbolsResponse,
|
||||
SubscribeBarsRequest,
|
||||
SubscribeBarsResponse,
|
||||
UnsubscribeBarsRequest,
|
||||
UnsubscribeBarsResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatafeedWebSocketHandler:
|
||||
"""
|
||||
Handles WebSocket connections for TradingView-compatible datafeed API.
|
||||
|
||||
Each handler manages a single WebSocket connection and routes requests
|
||||
to the appropriate data sources via the registry.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
client_id: str,
|
||||
registry: DataSourceRegistry,
|
||||
subscription_manager: SubscriptionManager,
|
||||
default_source: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize handler.
|
||||
|
||||
Args:
|
||||
websocket: FastAPI WebSocket connection
|
||||
client_id: Unique identifier for this client
|
||||
registry: DataSource registry for accessing data sources
|
||||
subscription_manager: Shared subscription manager
|
||||
default_source: Default data source name if not specified in requests
|
||||
"""
|
||||
self.websocket = websocket
|
||||
self.client_id = client_id
|
||||
self.registry = registry
|
||||
self.subscription_manager = subscription_manager
|
||||
self.default_source = default_source
|
||||
self._connected = True
|
||||
|
||||
async def handle_connection(self) -> None:
|
||||
"""
|
||||
Main connection handler loop.
|
||||
|
||||
Processes incoming messages until the connection closes.
|
||||
"""
|
||||
try:
|
||||
await self.websocket.accept()
|
||||
logger.info(f"WebSocket connected: client_id={self.client_id}")
|
||||
|
||||
while self._connected:
|
||||
# Receive message
|
||||
try:
|
||||
data = await self.websocket.receive_text()
|
||||
message = json.loads(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error receiving/parsing message: {e}")
|
||||
break
|
||||
|
||||
# Route to appropriate handler
|
||||
await self._handle_message(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error: {e}", exc_info=True)
|
||||
finally:
|
||||
# Clean up subscriptions when connection closes
|
||||
await self.subscription_manager.unsubscribe_client(self.client_id)
|
||||
self._connected = False
|
||||
logger.info(f"WebSocket disconnected: client_id={self.client_id}")
|
||||
|
||||
async def _handle_message(self, message: dict) -> None:
|
||||
"""Route message to appropriate handler based on type"""
|
||||
msg_type = message.get("type")
|
||||
request_id = message.get("request_id", "unknown")
|
||||
|
||||
try:
|
||||
if msg_type == "search_symbols":
|
||||
await self._handle_search_symbols(SearchSymbolsRequest(**message))
|
||||
elif msg_type == "resolve_symbol":
|
||||
await self._handle_resolve_symbol(ResolveSymbolRequest(**message))
|
||||
elif msg_type == "get_bars":
|
||||
await self._handle_get_bars(GetBarsRequest(**message))
|
||||
elif msg_type == "subscribe_bars":
|
||||
await self._handle_subscribe_bars(SubscribeBarsRequest(**message))
|
||||
elif msg_type == "unsubscribe_bars":
|
||||
await self._handle_unsubscribe_bars(UnsubscribeBarsRequest(**message))
|
||||
elif msg_type == "get_config":
|
||||
await self._handle_get_config(GetConfigRequest(**message))
|
||||
else:
|
||||
await self._send_error(
|
||||
request_id, "UNKNOWN_REQUEST_TYPE", f"Unknown request type: {msg_type}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling {msg_type}: {e}", exc_info=True)
|
||||
await self._send_error(request_id, "INTERNAL_ERROR", str(e))
|
||||
|
||||
async def _handle_search_symbols(self, request: SearchSymbolsRequest) -> None:
|
||||
"""Handle symbol search request"""
|
||||
# Use default source or search all sources
|
||||
if self.default_source:
|
||||
source = self.registry.get(self.default_source)
|
||||
if not source:
|
||||
await self._send_error(
|
||||
request.request_id,
|
||||
"SOURCE_NOT_FOUND",
|
||||
f"Default source '{self.default_source}' not found",
|
||||
)
|
||||
return
|
||||
|
||||
results = await source.search_symbols(
|
||||
query=request.query,
|
||||
type=request.symbol_type,
|
||||
exchange=request.exchange,
|
||||
limit=request.limit,
|
||||
)
|
||||
results_data = [r.model_dump(mode="json") for r in results]
|
||||
else:
|
||||
# Search all sources
|
||||
all_results = await self.registry.search_all(
|
||||
query=request.query,
|
||||
type=request.symbol_type,
|
||||
exchange=request.exchange,
|
||||
limit=request.limit,
|
||||
)
|
||||
# Flatten results from all sources
|
||||
results_data = []
|
||||
for source_results in all_results.values():
|
||||
results_data.extend([r.model_dump(mode="json") for r in source_results])
|
||||
|
||||
response = SearchSymbolsResponse(request_id=request.request_id, results=results_data)
|
||||
await self._send_response(response)
|
||||
|
||||
async def _handle_resolve_symbol(self, request: ResolveSymbolRequest) -> None:
|
||||
"""Handle symbol resolution request"""
|
||||
# Extract source from symbol if present (format: "SOURCE:SYMBOL")
|
||||
if ":" in request.symbol:
|
||||
source_name, symbol = request.symbol.split(":", 1)
|
||||
else:
|
||||
source_name = self.default_source
|
||||
symbol = request.symbol
|
||||
|
||||
if not source_name:
|
||||
await self._send_error(
|
||||
request.request_id,
|
||||
"NO_SOURCE_SPECIFIED",
|
||||
"No data source specified and no default source configured",
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
symbol_info = await self.registry.resolve_symbol(source_name, symbol)
|
||||
response = ResolveSymbolResponse(
|
||||
request_id=request.request_id,
|
||||
symbol_info=symbol_info.model_dump(mode="json"),
|
||||
)
|
||||
await self._send_response(response)
|
||||
except ValueError as e:
|
||||
await self._send_error(request.request_id, "SYMBOL_NOT_FOUND", str(e))
|
||||
|
||||
async def _handle_get_bars(self, request: GetBarsRequest) -> None:
|
||||
"""Handle historical bars request"""
|
||||
# Extract source from symbol
|
||||
if ":" in request.symbol:
|
||||
source_name, symbol = request.symbol.split(":", 1)
|
||||
else:
|
||||
source_name = self.default_source
|
||||
symbol = request.symbol
|
||||
|
||||
if not source_name:
|
||||
await self._send_error(
|
||||
request.request_id, "NO_SOURCE_SPECIFIED", "No data source specified"
|
||||
)
|
||||
return
|
||||
|
||||
source = self.registry.get(source_name)
|
||||
if not source:
|
||||
await self._send_error(
|
||||
request.request_id, "SOURCE_NOT_FOUND", f"Source '{source_name}' not found"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
history = await source.get_bars(
|
||||
symbol=symbol,
|
||||
resolution=request.resolution,
|
||||
from_time=request.from_time,
|
||||
to_time=request.to_time,
|
||||
countback=request.countback,
|
||||
)
|
||||
response = GetBarsResponse(
|
||||
request_id=request.request_id, history=history.model_dump(mode="json")
|
||||
)
|
||||
await self._send_response(response)
|
||||
except ValueError as e:
|
||||
await self._send_error(request.request_id, "INVALID_REQUEST", str(e))
|
||||
|
||||
async def _handle_subscribe_bars(self, request: SubscribeBarsRequest) -> None:
|
||||
"""Handle real-time subscription request"""
|
||||
# Extract source from symbol
|
||||
if ":" in request.symbol:
|
||||
source_name, symbol = request.symbol.split(":", 1)
|
||||
else:
|
||||
source_name = self.default_source
|
||||
symbol = request.symbol
|
||||
|
||||
if not source_name:
|
||||
await self._send_error(
|
||||
request.request_id, "NO_SOURCE_SPECIFIED", "No data source specified"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Create callback that sends updates to this WebSocket
|
||||
async def send_update(bar: dict):
|
||||
update = BarUpdateMessage(
|
||||
subscription_id=request.subscription_id,
|
||||
symbol=request.symbol,
|
||||
resolution=request.resolution,
|
||||
bar=bar,
|
||||
)
|
||||
await self._send_response(update)
|
||||
|
||||
# Register subscription
|
||||
await self.subscription_manager.subscribe(
|
||||
subscription_id=request.subscription_id,
|
||||
client_id=self.client_id,
|
||||
source_name=source_name,
|
||||
symbol=symbol,
|
||||
resolution=request.resolution,
|
||||
callback=lambda bar: self._queue_update(send_update(bar)),
|
||||
)
|
||||
|
||||
response = SubscribeBarsResponse(
|
||||
request_id=request.request_id,
|
||||
subscription_id=request.subscription_id,
|
||||
success=True,
|
||||
)
|
||||
await self._send_response(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Subscription failed: {e}", exc_info=True)
|
||||
response = SubscribeBarsResponse(
|
||||
request_id=request.request_id,
|
||||
subscription_id=request.subscription_id,
|
||||
success=False,
|
||||
message=str(e),
|
||||
)
|
||||
await self._send_response(response)
|
||||
|
||||
async def _handle_unsubscribe_bars(self, request: UnsubscribeBarsRequest) -> None:
|
||||
"""Handle unsubscribe request"""
|
||||
try:
|
||||
await self.subscription_manager.unsubscribe(request.subscription_id)
|
||||
response = UnsubscribeBarsResponse(
|
||||
request_id=request.request_id,
|
||||
subscription_id=request.subscription_id,
|
||||
success=True,
|
||||
)
|
||||
await self._send_response(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Unsubscribe failed: {e}")
|
||||
response = UnsubscribeBarsResponse(
|
||||
request_id=request.request_id,
|
||||
subscription_id=request.subscription_id,
|
||||
success=False,
|
||||
)
|
||||
await self._send_response(response)
|
||||
|
||||
async def _handle_get_config(self, request: GetConfigRequest) -> None:
|
||||
"""Handle datafeed config request"""
|
||||
if self.default_source:
|
||||
source = self.registry.get(self.default_source)
|
||||
if source:
|
||||
config = await source.get_config()
|
||||
response = GetConfigResponse(
|
||||
request_id=request.request_id, config=config.model_dump(mode="json")
|
||||
)
|
||||
await self._send_response(response)
|
||||
return
|
||||
|
||||
# Return aggregate config from all sources
|
||||
all_sources = self.registry.list_sources()
|
||||
if not all_sources:
|
||||
await self._send_error(
|
||||
request.request_id, "NO_SOURCES", "No data sources available"
|
||||
)
|
||||
return
|
||||
|
||||
# Just use first source's config for now
|
||||
# TODO: Aggregate configs from multiple sources
|
||||
source = self.registry.get(all_sources[0])
|
||||
if source:
|
||||
config = await source.get_config()
|
||||
response = GetConfigResponse(
|
||||
request_id=request.request_id, config=config.model_dump(mode="json")
|
||||
)
|
||||
await self._send_response(response)
|
||||
|
||||
async def _send_response(self, response) -> None:
|
||||
"""Send a response message to the client"""
|
||||
try:
|
||||
await self.websocket.send_json(response.model_dump(mode="json"))
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending response: {e}")
|
||||
self._connected = False
|
||||
|
||||
async def _send_error(self, request_id: str, error_code: str, error_message: str) -> None:
|
||||
"""Send an error response"""
|
||||
error = ErrorResponse(
|
||||
request_id=request_id, error_code=error_code, error_message=error_message
|
||||
)
|
||||
await self._send_response(error)
|
||||
|
||||
def _queue_update(self, coro):
|
||||
"""Queue an async update to be sent (prevents blocking callback)"""
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(coro)
|
||||
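From the client side the handler above speaks a simple JSON request/response protocol; a sketch using the third-party websockets package (the endpoint URL is an assumption, the actual route is defined by the server wiring elsewhere):

import asyncio
import json
import websockets

async def fetch_config() -> None:
    async with websockets.connect("ws://localhost:8000/datafeed") as ws:  # URL assumed
        await ws.send(json.dumps({"type": "get_config", "request_id": "req-1"}))
        reply = json.loads(await ws.recv())
        print(reply["request_id"], reply["config"]["supported_resolutions"])

asyncio.run(fetch_config())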
170
backend.old/src/datasource/websocket_protocol.py
Normal file
@@ -0,0 +1,170 @@
"""
WebSocket protocol messages for TradingView-compatible datafeed API.

These messages define the wire format for client-server communication
over WebSocket for symbol search, historical data, and real-time subscriptions.
"""

from typing import Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field


# ============================================================================
# Client -> Server Messages
# ============================================================================


class SearchSymbolsRequest(BaseModel):
    """Request to search for symbols matching a query"""

    type: Literal["search_symbols"] = "search_symbols"
    request_id: str = Field(description="Client-generated request ID for matching responses")
    query: str = Field(description="Search query string")
    symbol_type: Optional[str] = Field(default=None, description="Filter by instrument type")
    exchange: Optional[str] = Field(default=None, description="Filter by exchange")
    limit: int = Field(default=30, description="Maximum number of results")


class ResolveSymbolRequest(BaseModel):
    """Request full metadata for a specific symbol"""

    type: Literal["resolve_symbol"] = "resolve_symbol"
    request_id: str
    symbol: str = Field(description="Symbol identifier to resolve")


class GetBarsRequest(BaseModel):
    """Request historical bar data"""

    type: Literal["get_bars"] = "get_bars"
    request_id: str
    symbol: str
    resolution: str = Field(description="Time resolution (e.g., '1', '5', '60', '1D')")
    from_time: int = Field(description="Start time (Unix timestamp in seconds)")
    to_time: int = Field(description="End time (Unix timestamp in seconds)")
    countback: Optional[int] = Field(default=None, description="Maximum number of bars to return")


class SubscribeBarsRequest(BaseModel):
    """Subscribe to real-time bar updates"""

    type: Literal["subscribe_bars"] = "subscribe_bars"
    request_id: str
    symbol: str
    resolution: str
    subscription_id: str = Field(description="Client-generated subscription ID")


class UnsubscribeBarsRequest(BaseModel):
    """Unsubscribe from real-time updates"""

    type: Literal["unsubscribe_bars"] = "unsubscribe_bars"
    request_id: str
    subscription_id: str


class GetConfigRequest(BaseModel):
    """Request datafeed configuration"""

    type: Literal["get_config"] = "get_config"
    request_id: str


# Union of all client request types
ClientRequest = Union[
    SearchSymbolsRequest,
    ResolveSymbolRequest,
    GetBarsRequest,
    SubscribeBarsRequest,
    UnsubscribeBarsRequest,
    GetConfigRequest,
]


# ============================================================================
# Server -> Client Messages
# ============================================================================


class SearchSymbolsResponse(BaseModel):
    """Response with search results"""

    type: Literal["search_symbols_response"] = "search_symbols_response"
    request_id: str
    results: List[Dict[str, Any]] = Field(description="List of SearchResult objects")


class ResolveSymbolResponse(BaseModel):
    """Response with symbol metadata"""

    type: Literal["resolve_symbol_response"] = "resolve_symbol_response"
    request_id: str
    symbol_info: Dict[str, Any] = Field(description="SymbolInfo object")


class GetBarsResponse(BaseModel):
    """Response with historical bars"""

    type: Literal["get_bars_response"] = "get_bars_response"
    request_id: str
    history: Dict[str, Any] = Field(description="HistoryResult object with bars and metadata")


class SubscribeBarsResponse(BaseModel):
    """Acknowledgment of subscription"""

    type: Literal["subscribe_bars_response"] = "subscribe_bars_response"
    request_id: str
    subscription_id: str
    success: bool
    message: Optional[str] = None


class UnsubscribeBarsResponse(BaseModel):
    """Acknowledgment of unsubscribe"""

    type: Literal["unsubscribe_bars_response"] = "unsubscribe_bars_response"
    request_id: str
    subscription_id: str
    success: bool


class GetConfigResponse(BaseModel):
    """Response with datafeed configuration"""

    type: Literal["get_config_response"] = "get_config_response"
    request_id: str
    config: Dict[str, Any] = Field(description="DatafeedConfig object")


class BarUpdateMessage(BaseModel):
    """Real-time bar update (server-initiated, no request_id)"""

    type: Literal["bar_update"] = "bar_update"
    subscription_id: str
    symbol: str
    resolution: str
    bar: Dict[str, Any] = Field(description="Bar data including time and all columns")


class ErrorResponse(BaseModel):
    """Error response for any failed request"""

    type: Literal["error"] = "error"
    request_id: str
    error_code: str = Field(description="Machine-readable error code")
    error_message: str = Field(description="Human-readable error description")


# Union of all server response types
ServerResponse = Union[
    SearchSymbolsResponse,
    ResolveSymbolResponse,
    GetBarsResponse,
    SubscribeBarsResponse,
    UnsubscribeBarsResponse,
    GetConfigResponse,
    BarUpdateMessage,
    ErrorResponse,
]
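Because every message carries a literal `type` field, raw JSON can be validated straight into the matching model with a discriminated union. A minimal sketch, assuming pydantic v2 (`TypeAdapter` and `Field(discriminator=...)` are standard pydantic; everything else mirrors the models above):

```python
# Sketch only: dispatch an incoming client payload by its "type" field.
from typing import Annotated, Union

from pydantic import Field, TypeAdapter

ClientRequestAdapter = TypeAdapter(
    Annotated[
        Union[
            SearchSymbolsRequest,
            ResolveSymbolRequest,
            GetBarsRequest,
            SubscribeBarsRequest,
            UnsubscribeBarsRequest,
            GetConfigRequest,
        ],
        Field(discriminator="type"),
    ]
)


def parse_client_request(payload: dict) -> ClientRequest:
    """Validate a raw JSON payload into the matching request model."""
    # e.g. {"type": "get_bars", ...} becomes a GetBarsRequest instance,
    # and an unknown or malformed "type" raises a ValidationError.
    return ClientRequestAdapter.validate_python(payload)
```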
179
backend.old/src/exchange_kernel/README.md
Normal file
179
backend.old/src/exchange_kernel/README.md
Normal file
@@ -0,0 +1,179 @@
# Exchange Kernel API

A Kubernetes-style declarative API for managing orders across different exchanges.

## Architecture Overview

The Exchange Kernel maintains two separate views of order state:

1. **Desired State (Intent)**: What the strategy kernel wants
2. **Actual State (Reality)**: What currently exists on the exchange

A reconciliation loop continuously works to bring actual state into alignment with desired state, handling errors, retries, and edge cases automatically.
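To make the loop concrete, here is a minimal, illustrative pass over a single intent. This is not the shipped engine (that is still a TODO below); it assumes the package is importable as `exchange_kernel`, and `exchange.submit()` is a hypothetical stand-in for an exchange adapter call:

```python
from exchange_kernel.models import ReconciliationStatus


async def reconcile_once(intent_id, intents, actuals, exchange):
    """One illustrative pass: compare desired vs. actual state and act on the difference."""
    intent = await intents.get_intent(intent_id)            # desired state (IntentStateStore)
    try:
        state = await actuals.get_order_state(intent_id)    # actual state (ActualStateStore)
    except KeyError:
        # Nothing on the exchange yet -> place the order (hypothetical adapter call).
        await exchange.submit(intent.order)
        return

    if state.reconciliation_status == ReconciliationStatus.FAILED:
        # A real engine would add retry limits, backoff, and error classification here.
        await exchange.submit(intent.order)
```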
## Core Components

### Models (`models.py`)

- **OrderIntent**: Desired order state from strategy kernel
- **OrderState**: Actual current order state on exchange
- **Position**: Current position (spot, margin, perp, futures, options)
- **Asset**: Asset holdings with metadata
- **AccountState**: Complete account snapshot (balances, positions, margin)
- **AssetMetadata**: Asset type descriptions and trading parameters

### Events (`events.py`)

Order lifecycle events:
- `OrderSubmitted`, `OrderAccepted`, `OrderRejected`
- `OrderPartiallyFilled`, `OrderFilled`, `OrderCanceled`
- `OrderModified`, `OrderExpired`

Position events:
- `PositionOpened`, `PositionModified`, `PositionClosed`

Account events:
- `AccountBalanceUpdated`, `MarginCallWarning`

### Base Interface (`base.py`)

Abstract `ExchangeKernel` class defining:

**Command API**:
- `place_order()`, `place_order_group()` - Create order intents
- `cancel_order()`, `modify_order()` - Update intents
- `cancel_all_orders()` - Bulk cancellation

**Query API**:
- `get_order_intent()`, `get_order_state()` - Query single order
- `get_all_intents()`, `get_all_orders()` - Query all orders
- `get_positions()`, `get_account_state()` - Query positions/balances
- `get_symbol_metadata()`, `get_asset_metadata()` - Query market info

**Event API**:
- `subscribe_events()`, `unsubscribe_events()` - Event notifications

**Lifecycle**:
- `start()`, `stop()` - Kernel lifecycle
- `health_check()` - Connection status
- `force_reconciliation()` - Manual reconciliation trigger

### State Management (`state.py`)

- **IntentStateStore**: Storage for desired state (durable, survives restarts)
- **ActualStateStore**: Storage for actual exchange state (ephemeral cache)
- **ReconciliationEngine**: Framework for intent→reality reconciliation
- **InMemory implementations**: For testing/prototyping

## Standard Order Model

Defined in `schema/order_spec.py`:

```python
StandardOrder(
    symbol_id="BTC/USD",
    side=Side.BUY,
    amount=1.0,
    amount_type=AmountType.BASE,  # or QUOTE for exact-out
    limit_price=50000.0,  # None for market orders
    time_in_force=TimeInForce.GTC,
    conditional_trigger=ConditionalTrigger(...),  # Optional stop-loss/take-profit
    conditional_mode=ConditionalOrderMode.UNIFIED_ADJUSTING,
    reduce_only=False,
    post_only=False,
    iceberg_qty=None,
)
```
## Symbol Metadata

Markets describe their capabilities via `SymbolMetadata` (a hedged pre-trade check against these constraints is sketched after this list):

- **AmountConstraints**: Min/max order size, step size
- **PriceConstraints**: Tick size, tick spacing mode (fixed/dynamic/continuous)
- **MarketCapabilities**:
  - Supported sides (BUY, SELL)
  - Supported amount types (BASE, QUOTE, or both)
  - Market vs limit order support
  - Time-in-force options (GTC, IOC, FOK, DAY, GTD)
  - Conditional order support (stop-loss, take-profit, trailing stops)
  - Advanced features (post-only, reduce-only, iceberg)
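As an illustration of how a strategy might use this metadata before calling `place_order()`, here is a sketch of a pre-trade check. The attribute names on the metadata (`amount_constraints.min_amount`, `price_constraints.tick_size`, `capabilities.supported_sides`, etc.) are assumptions, not confirmed by `order_spec.py`:

```python
def check_order_against_metadata(order, meta) -> list[str]:
    """Collect human-readable constraint violations (illustrative only).

    `order` is a StandardOrder, `meta` a SymbolMetadata; the nested field
    names used below are assumed, not verified against order_spec.py.
    """
    problems: list[str] = []

    amt = meta.amount_constraints                       # assumed field
    if order.amount < amt.min_amount:                   # assumed field
        problems.append(f"amount {order.amount} below minimum {amt.min_amount}")

    if order.limit_price is not None:
        tick = meta.price_constraints.tick_size         # assumed field
        if tick and order.limit_price % tick != 0:
            problems.append(f"limit price {order.limit_price} not aligned to tick {tick}")

    if order.side not in meta.capabilities.supported_sides:  # assumed fields
        problems.append(f"side {order.side} not supported on this market")

    return problems
```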
## Asset Types

Comprehensive asset type system supporting:
- **SPOT**: Cash markets
- **MARGIN**: Margin trading
- **PERP**: Perpetual futures
- **FUTURE**: Dated futures
- **OPTION**: Options contracts
- **SYNTHETIC**: Derived instruments

Each asset has metadata describing contract specs, settlement, margin requirements, etc.

## Usage Pattern

```python
# Create exchange kernel for specific exchange
kernel = SomeExchangeKernel(exchange_id="binance_main")

# Subscribe to events (a handler sketch follows this block)
kernel.subscribe_events(my_event_handler)

# Start kernel
await kernel.start()

# Place order (creates intent, kernel handles execution)
intent_id = await kernel.place_order(
    StandardOrder(
        symbol_id="BTC/USD",
        side=Side.BUY,
        amount=1.0,
        amount_type=AmountType.BASE,
        limit_price=50000.0,
    )
)

# Query desired state
intent = await kernel.get_order_intent(intent_id)

# Query actual state
state = await kernel.get_order_state(intent_id)

# Modify order (updates intent, kernel reconciles)
await kernel.modify_order(intent_id, new_order)

# Cancel order
await kernel.cancel_order(intent_id)

# Query positions
positions = await kernel.get_positions()

# Query account state
account = await kernel.get_account_state()
```
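The snippet above registers `my_event_handler` without defining it. A minimal sketch of such a handler, dispatching on the event classes from `events.py` (assuming the package is importable as `exchange_kernel`):

```python
from exchange_kernel.events import BaseEvent, OrderFilled, OrderRejected, MarginCallWarning


def my_event_handler(event: BaseEvent) -> None:
    """Illustrative subscriber: react to a few interesting event types, ignore the rest."""
    if isinstance(event, OrderFilled):
        print(f"{event.symbol_id}: filled {event.total_quantity} @ {event.average_fill_price}")
    elif isinstance(event, OrderRejected):
        print(f"{event.symbol_id}: rejected ({event.reason})")
    elif isinstance(event, MarginCallWarning):
        print(f"margin level {event.margin_level} near threshold {event.liquidation_threshold}")
```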
## Implementation Status

✅ **Complete**:
- Data models and type definitions
- Event definitions
- Abstract interface
- State store framework
- In-memory stores for testing

⏳ **TODO** (Exchange-specific implementations):
- Concrete ExchangeKernel implementations per exchange
- Reconciliation engine implementation
- Exchange API adapters
- Persistent state storage (database)
- Error handling and retry logic
- Monitoring and observability

## Next Steps

1. Create concrete implementations for specific exchanges (Binance, Uniswap, etc.)
2. Implement reconciliation engine with proper error handling
3. Add persistent storage backend for intents
4. Build integration tests
5. Add monitoring/metrics collection
75
backend.old/src/exchange_kernel/__init__.py
Normal file
75
backend.old/src/exchange_kernel/__init__.py
Normal file
@@ -0,0 +1,75 @@
"""
Exchange Kernel API

The exchange kernel provides a Kubernetes-style declarative API for managing orders
across different exchanges. It maintains both desired state (intent) and actual state
(current orders on exchange) and reconciles them continuously.

Key concepts:
- OrderIntent: What the strategy kernel wants
- OrderState: What actually exists on the exchange
- Reconciliation: Bringing actual state into alignment with desired state
"""

from .base import ExchangeKernel
from .events import (
    OrderEvent,
    OrderSubmitted,
    OrderAccepted,
    OrderRejected,
    OrderPartiallyFilled,
    OrderFilled,
    OrderCanceled,
    OrderModified,
    OrderExpired,
    PositionEvent,
    PositionOpened,
    PositionModified,
    PositionClosed,
    AccountEvent,
    AccountBalanceUpdated,
    MarginCallWarning,
)
from .models import (
    OrderIntent,
    OrderState,
    Position,
    Asset,
    AssetMetadata,
    AccountState,
    Balance,
)
from .state import IntentStateStore, ActualStateStore

__all__ = [
    # Core interface
    "ExchangeKernel",
    # Events
    "OrderEvent",
    "OrderSubmitted",
    "OrderAccepted",
    "OrderRejected",
    "OrderPartiallyFilled",
    "OrderFilled",
    "OrderCanceled",
    "OrderModified",
    "OrderExpired",
    "PositionEvent",
    "PositionOpened",
    "PositionModified",
    "PositionClosed",
    "AccountEvent",
    "AccountBalanceUpdated",
    "MarginCallWarning",
    # Models
    "OrderIntent",
    "OrderState",
    "Position",
    "Asset",
    "AssetMetadata",
    "AccountState",
    "Balance",
    # State management
    "IntentStateStore",
    "ActualStateStore",
]
361
backend.old/src/exchange_kernel/base.py
Normal file
361
backend.old/src/exchange_kernel/base.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
Base interface for Exchange Kernels.
|
||||
|
||||
Defines the abstract API that all exchange kernel implementations must support.
|
||||
Each exchange (or exchange type) will have its own kernel implementation.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Any
|
||||
|
||||
from .models import (
|
||||
OrderIntent,
|
||||
OrderState,
|
||||
Position,
|
||||
AccountState,
|
||||
AssetMetadata,
|
||||
)
|
||||
from .events import BaseEvent
|
||||
from ..schema.order_spec import (
|
||||
StandardOrder,
|
||||
StandardOrderGroup,
|
||||
SymbolMetadata,
|
||||
)
|
||||
|
||||
|
||||
class ExchangeKernel(ABC):
|
||||
"""
|
||||
Abstract base class for exchange kernels.
|
||||
|
||||
An exchange kernel manages the lifecycle of orders on a specific exchange,
|
||||
maintaining both desired state (intents from strategy kernel) and actual
|
||||
state (current orders on exchange), and continuously reconciling them.
|
||||
|
||||
Think of it as a Kubernetes-style controller for trading orders.
|
||||
"""
|
||||
|
||||
def __init__(self, exchange_id: str):
|
||||
"""
|
||||
Initialize the exchange kernel.
|
||||
|
||||
Args:
|
||||
exchange_id: Unique identifier for this exchange instance
|
||||
"""
|
||||
self.exchange_id = exchange_id
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Command API - Strategy kernel sends intents
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@abstractmethod
|
||||
async def place_order(self, order: StandardOrder, metadata: dict[str, Any] | None = None) -> str:
|
||||
"""
|
||||
Place a single order on the exchange.
|
||||
|
||||
This creates an OrderIntent and begins the reconciliation process to
|
||||
get the order onto the exchange.
|
||||
|
||||
Args:
|
||||
order: The order specification
|
||||
metadata: Optional strategy-specific metadata
|
||||
|
||||
Returns:
|
||||
intent_id: Unique identifier for this order intent
|
||||
|
||||
Raises:
|
||||
ValidationError: If order violates market constraints
|
||||
ExchangeError: If exchange rejects the order
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def place_order_group(
|
||||
self,
|
||||
group: StandardOrderGroup,
|
||||
metadata: dict[str, Any] | None = None
|
||||
) -> list[str]:
|
||||
"""
|
||||
Place a group of orders with OCO (One-Cancels-Other) relationship.
|
||||
|
||||
Args:
|
||||
group: Group of orders with OCO mode
|
||||
metadata: Optional strategy-specific metadata
|
||||
|
||||
Returns:
|
||||
intent_ids: List of intent IDs for each order in the group
|
||||
|
||||
Raises:
|
||||
ValidationError: If any order violates market constraints
|
||||
ExchangeError: If exchange rejects the group
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def cancel_order(self, intent_id: str) -> None:
|
||||
"""
|
||||
Cancel an order by intent ID.
|
||||
|
||||
Updates the intent to indicate cancellation is desired, and the
|
||||
reconciliation loop will handle the actual exchange cancellation.
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID of the order to cancel
|
||||
|
||||
Raises:
|
||||
NotFoundError: If intent_id doesn't exist
|
||||
ExchangeError: If exchange rejects cancellation
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def modify_order(
|
||||
self,
|
||||
intent_id: str,
|
||||
new_order: StandardOrder,
|
||||
) -> None:
|
||||
"""
|
||||
Modify an existing order.
|
||||
|
||||
Updates the order intent, and the reconciliation loop will update
|
||||
the exchange order (via modify API if available, or cancel+replace).
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID of the order to modify
|
||||
new_order: New order specification
|
||||
|
||||
Raises:
|
||||
NotFoundError: If intent_id doesn't exist
|
||||
ValidationError: If new order violates market constraints
|
||||
ExchangeError: If exchange rejects modification
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def cancel_all_orders(self, symbol_id: str | None = None) -> int:
|
||||
"""
|
||||
Cancel all orders, optionally filtered by symbol.
|
||||
|
||||
Args:
|
||||
symbol_id: If provided, only cancel orders for this symbol
|
||||
|
||||
Returns:
|
||||
count: Number of orders canceled
|
||||
"""
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Query API - Read desired and actual state
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@abstractmethod
|
||||
async def get_order_intent(self, intent_id: str) -> OrderIntent:
|
||||
"""
|
||||
Get the desired order state (what strategy kernel wants).
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID to query
|
||||
|
||||
Returns:
|
||||
The order intent
|
||||
|
||||
Raises:
|
||||
NotFoundError: If intent_id doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_order_state(self, intent_id: str) -> OrderState:
|
||||
"""
|
||||
Get the actual order state (what's currently on exchange).
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID to query
|
||||
|
||||
Returns:
|
||||
The current order state
|
||||
|
||||
Raises:
|
||||
NotFoundError: If intent_id doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_all_intents(self, symbol_id: str | None = None) -> list[OrderIntent]:
|
||||
"""
|
||||
Get all order intents, optionally filtered by symbol.
|
||||
|
||||
Args:
|
||||
symbol_id: If provided, only return intents for this symbol
|
||||
|
||||
Returns:
|
||||
List of order intents
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_all_orders(self, symbol_id: str | None = None) -> list[OrderState]:
|
||||
"""
|
||||
Get all actual order states, optionally filtered by symbol.
|
||||
|
||||
Args:
|
||||
symbol_id: If provided, only return orders for this symbol
|
||||
|
||||
Returns:
|
||||
List of order states
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_positions(self, symbol_id: str | None = None) -> list[Position]:
|
||||
"""
|
||||
Get current positions, optionally filtered by symbol.
|
||||
|
||||
Args:
|
||||
symbol_id: If provided, only return positions for this symbol
|
||||
|
||||
Returns:
|
||||
List of positions
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_account_state(self) -> AccountState:
|
||||
"""
|
||||
Get current account state (balances, margin, etc.).
|
||||
|
||||
Returns:
|
||||
Current account state
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_symbol_metadata(self, symbol_id: str) -> SymbolMetadata:
|
||||
"""
|
||||
Get metadata for a symbol (constraints, capabilities, etc.).
|
||||
|
||||
Args:
|
||||
symbol_id: Symbol to query
|
||||
|
||||
Returns:
|
||||
Symbol metadata
|
||||
|
||||
Raises:
|
||||
NotFoundError: If symbol doesn't exist on this exchange
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_asset_metadata(self, asset_id: str) -> AssetMetadata:
|
||||
"""
|
||||
Get metadata for an asset.
|
||||
|
||||
Args:
|
||||
asset_id: Asset to query
|
||||
|
||||
Returns:
|
||||
Asset metadata
|
||||
|
||||
Raises:
|
||||
NotFoundError: If asset doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def list_symbols(self) -> list[str]:
|
||||
"""
|
||||
List all available symbols on this exchange.
|
||||
|
||||
Returns:
|
||||
List of symbol IDs
|
||||
"""
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Event Subscription API
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@abstractmethod
|
||||
def subscribe_events(
|
||||
self,
|
||||
callback: Callable[[BaseEvent], None],
|
||||
event_filter: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Subscribe to events from this exchange kernel.
|
||||
|
||||
Args:
|
||||
callback: Function to call when events occur
|
||||
event_filter: Optional filter criteria (event_type, symbol_id, etc.)
|
||||
|
||||
Returns:
|
||||
subscription_id: Unique ID for this subscription (for unsubscribe)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unsubscribe_events(self, subscription_id: str) -> None:
|
||||
"""
|
||||
Unsubscribe from events.
|
||||
|
||||
Args:
|
||||
subscription_id: Subscription ID returned from subscribe_events
|
||||
"""
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Lifecycle Management
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Start the exchange kernel.
|
||||
|
||||
Initializes connections, starts reconciliation loops, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""
|
||||
Stop the exchange kernel.
|
||||
|
||||
Closes connections, stops reconciliation loops, etc.
|
||||
Does NOT cancel open orders - call cancel_all_orders() first if desired.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def health_check(self) -> dict[str, Any]:
|
||||
"""
|
||||
Check health status of the exchange kernel.
|
||||
|
||||
Returns:
|
||||
Health status dict with connection state, latency, error counts, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Reconciliation Control (advanced)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@abstractmethod
|
||||
async def force_reconciliation(self, intent_id: str | None = None) -> None:
|
||||
"""
|
||||
Force immediate reconciliation.
|
||||
|
||||
Args:
|
||||
intent_id: If provided, only reconcile this specific intent.
|
||||
If None, reconcile all intents.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_reconciliation_metrics(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get metrics about the reconciliation process.
|
||||
|
||||
Returns:
|
||||
Metrics dict with reconciliation lag, error rates, retry counts, etc.
|
||||
"""
|
||||
pass
|
||||
250
backend.old/src/exchange_kernel/events.py
Normal file
250
backend.old/src/exchange_kernel/events.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
Event definitions for the Exchange Kernel.
|
||||
|
||||
All events that can occur during the order lifecycle, position management,
|
||||
and account updates.
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..schema.order_spec import Float, Uint64
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Base Event Classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class EventType(StrEnum):
|
||||
"""Types of events emitted by the exchange kernel"""
|
||||
# Order lifecycle
|
||||
ORDER_SUBMITTED = "ORDER_SUBMITTED"
|
||||
ORDER_ACCEPTED = "ORDER_ACCEPTED"
|
||||
ORDER_REJECTED = "ORDER_REJECTED"
|
||||
ORDER_PARTIALLY_FILLED = "ORDER_PARTIALLY_FILLED"
|
||||
ORDER_FILLED = "ORDER_FILLED"
|
||||
ORDER_CANCELED = "ORDER_CANCELED"
|
||||
ORDER_MODIFIED = "ORDER_MODIFIED"
|
||||
ORDER_EXPIRED = "ORDER_EXPIRED"
|
||||
|
||||
# Position events
|
||||
POSITION_OPENED = "POSITION_OPENED"
|
||||
POSITION_MODIFIED = "POSITION_MODIFIED"
|
||||
POSITION_CLOSED = "POSITION_CLOSED"
|
||||
|
||||
# Account events
|
||||
ACCOUNT_BALANCE_UPDATED = "ACCOUNT_BALANCE_UPDATED"
|
||||
MARGIN_CALL_WARNING = "MARGIN_CALL_WARNING"
|
||||
|
||||
# System events
|
||||
RECONCILIATION_FAILED = "RECONCILIATION_FAILED"
|
||||
CONNECTION_LOST = "CONNECTION_LOST"
|
||||
CONNECTION_RESTORED = "CONNECTION_RESTORED"
|
||||
|
||||
|
||||
class BaseEvent(BaseModel):
|
||||
"""Base class for all exchange kernel events"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
event_type: EventType = Field(description="Type of event")
|
||||
timestamp: Uint64 = Field(description="Event timestamp (Unix milliseconds)")
|
||||
exchange: str = Field(description="Exchange identifier")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Additional event data")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Order Events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class OrderEvent(BaseEvent):
|
||||
"""Base class for order-related events"""
|
||||
|
||||
intent_id: str = Field(description="Order intent ID")
|
||||
order_id: str | None = Field(default=None, description="Exchange order ID (if assigned)")
|
||||
symbol_id: str = Field(description="Symbol being traded")
|
||||
|
||||
|
||||
class OrderSubmitted(OrderEvent):
|
||||
"""Order has been submitted to the exchange"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_SUBMITTED)
|
||||
client_order_id: str | None = Field(default=None, description="Client-assigned order ID")
|
||||
|
||||
|
||||
class OrderAccepted(OrderEvent):
|
||||
"""Order has been accepted by the exchange"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_ACCEPTED)
|
||||
order_id: str = Field(description="Exchange-assigned order ID")
|
||||
accepted_at: Uint64 = Field(description="Exchange acceptance timestamp")
|
||||
|
||||
|
||||
class OrderRejected(OrderEvent):
|
||||
"""Order was rejected by the exchange"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_REJECTED)
|
||||
reason: str = Field(description="Rejection reason")
|
||||
error_code: str | None = Field(default=None, description="Exchange error code")
|
||||
|
||||
|
||||
class OrderPartiallyFilled(OrderEvent):
|
||||
"""Order was partially filled"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_PARTIALLY_FILLED)
|
||||
order_id: str = Field(description="Exchange order ID")
|
||||
fill_price: Float = Field(description="Fill price for this execution")
|
||||
fill_quantity: Float = Field(description="Quantity filled in this execution")
|
||||
total_filled: Float = Field(description="Total quantity filled so far")
|
||||
remaining_quantity: Float = Field(description="Remaining quantity to fill")
|
||||
commission: Float = Field(default=0.0, description="Commission/fee for this fill")
|
||||
commission_asset: str | None = Field(default=None, description="Asset used for commission")
|
||||
trade_id: str | None = Field(default=None, description="Exchange trade ID")
|
||||
|
||||
|
||||
class OrderFilled(OrderEvent):
|
||||
"""Order was completely filled"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_FILLED)
|
||||
order_id: str = Field(description="Exchange order ID")
|
||||
average_fill_price: Float = Field(description="Average execution price")
|
||||
total_quantity: Float = Field(description="Total quantity filled")
|
||||
total_commission: Float = Field(default=0.0, description="Total commission/fees")
|
||||
commission_asset: str | None = Field(default=None, description="Asset used for commission")
|
||||
completed_at: Uint64 = Field(description="Completion timestamp")
|
||||
|
||||
|
||||
class OrderCanceled(OrderEvent):
|
||||
"""Order was canceled"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_CANCELED)
|
||||
order_id: str = Field(description="Exchange order ID")
|
||||
reason: str = Field(description="Cancellation reason")
|
||||
filled_quantity: Float = Field(default=0.0, description="Quantity filled before cancellation")
|
||||
canceled_at: Uint64 = Field(description="Cancellation timestamp")
|
||||
|
||||
|
||||
class OrderModified(OrderEvent):
|
||||
"""Order was modified (price, quantity, etc.)"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_MODIFIED)
|
||||
order_id: str = Field(description="Exchange order ID")
|
||||
old_price: Float | None = Field(default=None, description="Previous price")
|
||||
new_price: Float | None = Field(default=None, description="New price")
|
||||
old_quantity: Float | None = Field(default=None, description="Previous quantity")
|
||||
new_quantity: Float | None = Field(default=None, description="New quantity")
|
||||
modified_at: Uint64 = Field(description="Modification timestamp")
|
||||
|
||||
|
||||
class OrderExpired(OrderEvent):
|
||||
"""Order expired (GTD, DAY orders)"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_EXPIRED)
|
||||
order_id: str = Field(description="Exchange order ID")
|
||||
filled_quantity: Float = Field(default=0.0, description="Quantity filled before expiration")
|
||||
expired_at: Uint64 = Field(description="Expiration timestamp")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Position Events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class PositionEvent(BaseEvent):
|
||||
"""Base class for position-related events"""
|
||||
|
||||
position_id: str = Field(description="Position identifier")
|
||||
symbol_id: str = Field(description="Symbol identifier")
|
||||
asset_id: str = Field(description="Asset identifier")
|
||||
|
||||
|
||||
class PositionOpened(PositionEvent):
|
||||
"""New position was opened"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.POSITION_OPENED)
|
||||
quantity: Float = Field(description="Position quantity")
|
||||
entry_price: Float = Field(description="Entry price")
|
||||
side: str = Field(description="LONG or SHORT")
|
||||
leverage: Float | None = Field(default=None, description="Leverage")
|
||||
|
||||
|
||||
class PositionModified(PositionEvent):
|
||||
"""Existing position was modified (size change, etc.)"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.POSITION_MODIFIED)
|
||||
old_quantity: Float = Field(description="Previous quantity")
|
||||
new_quantity: Float = Field(description="New quantity")
|
||||
average_entry_price: Float = Field(description="Updated average entry price")
|
||||
unrealized_pnl: Float | None = Field(default=None, description="Current unrealized P&L")
|
||||
|
||||
|
||||
class PositionClosed(PositionEvent):
|
||||
"""Position was closed"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.POSITION_CLOSED)
|
||||
exit_price: Float = Field(description="Exit price")
|
||||
realized_pnl: Float = Field(description="Realized profit/loss")
|
||||
closed_at: Uint64 = Field(description="Closure timestamp")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Account Events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class AccountEvent(BaseEvent):
|
||||
"""Base class for account-related events"""
|
||||
|
||||
account_id: str = Field(description="Account identifier")
|
||||
|
||||
|
||||
class AccountBalanceUpdated(AccountEvent):
|
||||
"""Account balance was updated"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ACCOUNT_BALANCE_UPDATED)
|
||||
asset_id: str = Field(description="Asset that changed")
|
||||
old_balance: Float = Field(description="Previous balance")
|
||||
new_balance: Float = Field(description="New balance")
|
||||
old_available: Float = Field(description="Previous available")
|
||||
new_available: Float = Field(description="New available")
|
||||
change_reason: str = Field(description="Why balance changed (TRADE, DEPOSIT, WITHDRAWAL, etc.)")
|
||||
|
||||
|
||||
class MarginCallWarning(AccountEvent):
|
||||
"""Margin level is approaching liquidation threshold"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.MARGIN_CALL_WARNING)
|
||||
margin_level: Float = Field(description="Current margin level")
|
||||
liquidation_threshold: Float = Field(description="Liquidation threshold")
|
||||
required_action: str = Field(description="Required action to avoid liquidation")
|
||||
estimated_liquidation_price: Float | None = Field(
|
||||
default=None,
|
||||
description="Estimated liquidation price for positions"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# System Events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ReconciliationFailed(BaseEvent):
|
||||
"""Failed to reconcile intent with actual state"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.RECONCILIATION_FAILED)
|
||||
intent_id: str = Field(description="Order intent ID")
|
||||
error_message: str = Field(description="Error details")
|
||||
retry_count: int = Field(description="Number of retry attempts")
|
||||
|
||||
|
||||
class ConnectionLost(BaseEvent):
|
||||
"""Connection to exchange was lost"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.CONNECTION_LOST)
|
||||
reason: str = Field(description="Disconnection reason")
|
||||
|
||||
|
||||
class ConnectionRestored(BaseEvent):
|
||||
"""Connection to exchange was restored"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.CONNECTION_RESTORED)
|
||||
downtime_duration: int = Field(description="Duration of downtime in milliseconds")
|
||||
194
backend.old/src/exchange_kernel/models.py
Normal file
194
backend.old/src/exchange_kernel/models.py
Normal file
@@ -0,0 +1,194 @@
"""
Data models for the Exchange Kernel.

Defines order intents, order state, positions, assets, and account state.
"""

from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field

from ..schema.order_spec import (
    StandardOrder,
    StandardOrderStatus,
    AssetType,
    Float,
    Uint64,
)


# ---------------------------------------------------------------------------
# Order Intent and State
# ---------------------------------------------------------------------------


class OrderIntent(BaseModel):
    """
    Desired order state from the strategy kernel.

    This represents what the strategy wants, not what currently exists.
    The exchange kernel will work to reconcile actual state with this intent.
    """

    model_config = {"extra": "forbid"}

    intent_id: str = Field(description="Unique identifier for this intent (client-assigned)")
    order: StandardOrder = Field(description="The desired order specification")
    group_id: str | None = Field(default=None, description="Group ID for OCO relationships")
    created_at: Uint64 = Field(description="When this intent was created")
    updated_at: Uint64 = Field(description="When this intent was last modified")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Strategy-specific metadata")


class ReconciliationStatus(StrEnum):
    """Status of reconciliation between intent and actual state"""

    PENDING = "PENDING"          # Not yet submitted to exchange
    SUBMITTING = "SUBMITTING"    # Currently being submitted
    ACTIVE = "ACTIVE"            # Successfully placed on exchange
    RECONCILING = "RECONCILING"  # Intent changed, updating exchange order
    FAILED = "FAILED"            # Failed to submit or reconcile
    COMPLETED = "COMPLETED"      # Order fully filled
    CANCELED = "CANCELED"        # Order canceled


class OrderState(BaseModel):
    """
    Actual current state of an order on the exchange.

    This represents reality - what the exchange reports about the order.
    May differ from OrderIntent during reconciliation.
    """

    model_config = {"extra": "forbid"}

    intent_id: str = Field(description="Links back to the OrderIntent")
    exchange_order_id: str = Field(description="Exchange-assigned order ID")
    status: StandardOrderStatus = Field(description="Current order status from exchange")
    reconciliation_status: ReconciliationStatus = Field(description="Reconciliation state")
    last_sync_at: Uint64 = Field(description="Last time we synced with exchange")
    error_message: str | None = Field(default=None, description="Error details if FAILED")


# ---------------------------------------------------------------------------
# Position and Asset Models
# ---------------------------------------------------------------------------


class AssetMetadata(BaseModel):
    """
    Metadata describing an asset type.

    Provides context for positions, balances, and trading.
    """

    model_config = {"extra": "forbid"}

    asset_id: str = Field(description="Unique asset identifier")
    symbol: str = Field(description="Asset symbol (e.g., 'BTC', 'ETH', 'USD')")
    asset_type: AssetType = Field(description="Type of asset")
    name: str = Field(description="Full name")

    # Contract specifications (for derivatives)
    contract_size: Float | None = Field(default=None, description="Contract multiplier")
    settlement_asset: str | None = Field(default=None, description="Settlement currency")
    expiry_timestamp: Uint64 | None = Field(default=None, description="Expiration timestamp")

    # Trading parameters
    tick_size: Float | None = Field(default=None, description="Minimum price increment")
    lot_size: Float | None = Field(default=None, description="Minimum quantity increment")

    # Margin requirements (for leveraged products)
    initial_margin_rate: Float | None = Field(default=None, description="Initial margin requirement")
    maintenance_margin_rate: Float | None = Field(default=None, description="Maintenance margin requirement")

    # Additional metadata
    metadata: dict[str, Any] = Field(default_factory=dict, description="Exchange-specific metadata")


class Asset(BaseModel):
    """
    An asset holding (spot, margin, derivative position, etc.)
    """

    model_config = {"extra": "forbid"}

    asset_id: str = Field(description="References AssetMetadata")
    quantity: Float = Field(description="Amount held (positive or negative for short positions)")
    available: Float = Field(description="Amount available for trading (not locked in orders)")
    locked: Float = Field(description="Amount locked in open orders")

    # For derivative positions
    entry_price: Float | None = Field(default=None, description="Average entry price")
    mark_price: Float | None = Field(default=None, description="Current mark price")
    liquidation_price: Float | None = Field(default=None, description="Estimated liquidation price")
    unrealized_pnl: Float | None = Field(default=None, description="Unrealized profit/loss")
    realized_pnl: Float | None = Field(default=None, description="Realized profit/loss")

    # Margin info
    margin_used: Float | None = Field(default=None, description="Margin allocated to this position")

    updated_at: Uint64 = Field(description="Last update timestamp")


class Position(BaseModel):
    """
    A trading position (spot, margin, perpetual, futures, etc.)

    Tracks both the asset holdings and associated metadata.
    """

    model_config = {"extra": "forbid"}

    position_id: str = Field(description="Unique position identifier")
    symbol_id: str = Field(description="Trading symbol")
    asset: Asset = Field(description="Asset holding details")
    metadata: AssetMetadata = Field(description="Asset metadata")

    # Position-level info
    leverage: Float | None = Field(default=None, description="Current leverage")
    side: str | None = Field(default=None, description="LONG or SHORT (for derivatives)")

    updated_at: Uint64 = Field(description="Last update timestamp")


class Balance(BaseModel):
    """Account balance for a single currency/asset"""

    model_config = {"extra": "forbid"}

    asset_id: str = Field(description="Asset identifier")
    total: Float = Field(description="Total balance")
    available: Float = Field(description="Available for trading")
    locked: Float = Field(description="Locked in orders/positions")

    # For margin accounts
    borrowed: Float = Field(default=0.0, description="Borrowed amount (margin)")
    interest: Float = Field(default=0.0, description="Accrued interest")

    updated_at: Uint64 = Field(description="Last update timestamp")


class AccountState(BaseModel):
    """
    Complete account state including balances, positions, and margin info.
    """

    model_config = {"extra": "forbid"}

    account_id: str = Field(description="Account identifier")
    exchange: str = Field(description="Exchange identifier")

    balances: list[Balance] = Field(default_factory=list, description="All asset balances")
    positions: list[Position] = Field(default_factory=list, description="All open positions")

    # Margin account info
    total_equity: Float | None = Field(default=None, description="Total account equity")
    total_margin_used: Float | None = Field(default=None, description="Total margin in use")
    total_available_margin: Float | None = Field(default=None, description="Available margin")
    margin_level: Float | None = Field(default=None, description="Margin level (equity/margin_used)")

    # Risk metrics
    total_unrealized_pnl: Float | None = Field(default=None, description="Total unrealized P&L")
    total_realized_pnl: Float | None = Field(default=None, description="Total realized P&L")

    updated_at: Uint64 = Field(description="Last update timestamp")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Exchange-specific data")
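A quick construction example for these models (illustrative only; `Float`/`Uint64` are treated as plain `float`/`int` here, and `my_standard_order` stands in for a `StandardOrder` built elsewhere):

```python
import time

# Illustrative: wrap a StandardOrder (from schema/order_spec.py) in an intent.
now_ms = int(time.time() * 1000)
intent = OrderIntent(
    intent_id="intent-0001",
    order=my_standard_order,                   # a StandardOrder built elsewhere
    group_id=None,                             # no OCO group
    created_at=now_ms,
    updated_at=now_ms,
    metadata={"strategy": "mean_reversion"},   # hypothetical strategy tag
)
```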
472
backend.old/src/exchange_kernel/state.py
Normal file
472
backend.old/src/exchange_kernel/state.py
Normal file
@@ -0,0 +1,472 @@
|
||||
"""
|
||||
State management for the Exchange Kernel.
|
||||
|
||||
Implements the storage and reconciliation logic for desired vs actual state.
|
||||
This is the "Kubernetes for orders" concept - maintaining intent and continuously
|
||||
reconciling reality to match intent.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from collections import defaultdict
|
||||
|
||||
from .models import OrderIntent, OrderState, ReconciliationStatus
|
||||
from ..schema.order_spec import Uint64
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Intent State Store - Desired State
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class IntentStateStore(ABC):
|
||||
"""
|
||||
Storage for order intents (desired state).
|
||||
|
||||
This represents what the strategy kernel wants. Intents are durable and
|
||||
persist across restarts. The reconciliation loop continuously works to
|
||||
make actual state match these intents.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def create_intent(self, intent: OrderIntent) -> None:
|
||||
"""
|
||||
Store a new order intent.
|
||||
|
||||
Args:
|
||||
intent: The order intent to store
|
||||
|
||||
Raises:
|
||||
AlreadyExistsError: If intent_id already exists
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_intent(self, intent_id: str) -> OrderIntent:
|
||||
"""
|
||||
Retrieve an order intent.
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID to retrieve
|
||||
|
||||
Returns:
|
||||
The order intent
|
||||
|
||||
Raises:
|
||||
NotFoundError: If intent_id doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def update_intent(self, intent: OrderIntent) -> None:
|
||||
"""
|
||||
Update an existing order intent.
|
||||
|
||||
Args:
|
||||
intent: Updated intent (intent_id must match existing)
|
||||
|
||||
Raises:
|
||||
NotFoundError: If intent_id doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete_intent(self, intent_id: str) -> None:
|
||||
"""
|
||||
Delete an order intent.
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID to delete
|
||||
|
||||
Raises:
|
||||
NotFoundError: If intent_id doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def list_intents(
|
||||
self,
|
||||
symbol_id: str | None = None,
|
||||
group_id: str | None = None,
|
||||
) -> list[OrderIntent]:
|
||||
"""
|
||||
List all order intents, optionally filtered.
|
||||
|
||||
Args:
|
||||
symbol_id: Filter by symbol
|
||||
group_id: Filter by OCO group
|
||||
|
||||
Returns:
|
||||
List of matching intents
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_intents_by_group(self, group_id: str) -> list[OrderIntent]:
|
||||
"""
|
||||
Get all intents in an OCO group.
|
||||
|
||||
Args:
|
||||
group_id: Group ID to query
|
||||
|
||||
Returns:
|
||||
List of intents in the group
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Actual State Store - Current Reality
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ActualStateStore(ABC):
|
||||
"""
|
||||
Storage for actual order state (reality on exchange).
|
||||
|
||||
This represents what actually exists on the exchange right now.
|
||||
Updated frequently from exchange feeds and order status queries.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def create_order_state(self, state: OrderState) -> None:
|
||||
"""
|
||||
Store a new order state.
|
||||
|
||||
Args:
|
||||
state: The order state to store
|
||||
|
||||
Raises:
|
||||
AlreadyExistsError: If order state for this intent_id already exists
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_order_state(self, intent_id: str) -> OrderState:
|
||||
"""
|
||||
Retrieve order state for an intent.
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID to query
|
||||
|
||||
Returns:
|
||||
The current order state
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no state exists for this intent
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_order_state_by_exchange_id(self, exchange_order_id: str) -> OrderState:
|
||||
"""
|
||||
Retrieve order state by exchange order ID.
|
||||
|
||||
Useful for processing exchange callbacks that only provide exchange_order_id.
|
||||
|
||||
Args:
|
||||
exchange_order_id: Exchange's order ID
|
||||
|
||||
Returns:
|
||||
The order state
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no state exists for this exchange order ID
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def update_order_state(self, state: OrderState) -> None:
|
||||
"""
|
||||
Update an existing order state.
|
||||
|
||||
Args:
|
||||
state: Updated state (intent_id must match existing)
|
||||
|
||||
Raises:
|
||||
NotFoundError: If state doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete_order_state(self, intent_id: str) -> None:
|
||||
"""
|
||||
Delete an order state.
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID whose state to delete
|
||||
|
||||
Raises:
|
||||
NotFoundError: If state doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def list_order_states(
|
||||
self,
|
||||
symbol_id: str | None = None,
|
||||
reconciliation_status: ReconciliationStatus | None = None,
|
||||
) -> list[OrderState]:
|
||||
"""
|
||||
List all order states, optionally filtered.
|
||||
|
||||
Args:
|
||||
symbol_id: Filter by symbol
|
||||
reconciliation_status: Filter by reconciliation status
|
||||
|
||||
Returns:
|
||||
List of matching order states
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_stale_orders(self, max_age_seconds: int) -> list[OrderState]:
|
||||
"""
|
||||
Find orders that haven't been synced recently.
|
||||
|
||||
Used to identify orders that need status updates from exchange.
|
||||
|
||||
Args:
|
||||
max_age_seconds: Maximum age since last sync
|
||||
|
||||
Returns:
|
||||
List of order states that need refresh
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# In-Memory Implementations (for testing/prototyping)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class InMemoryIntentStore(IntentStateStore):
|
||||
"""Simple in-memory implementation of IntentStateStore"""
|
||||
|
||||
def __init__(self):
|
||||
self._intents: dict[str, OrderIntent] = {}
|
||||
self._by_symbol: dict[str, set[str]] = defaultdict(set)
|
||||
self._by_group: dict[str, set[str]] = defaultdict(set)
|
||||
|
||||
async def create_intent(self, intent: OrderIntent) -> None:
|
||||
if intent.intent_id in self._intents:
|
||||
raise ValueError(f"Intent {intent.intent_id} already exists")
|
||||
self._intents[intent.intent_id] = intent
|
||||
self._by_symbol[intent.order.symbol_id].add(intent.intent_id)
|
||||
if intent.group_id:
|
||||
self._by_group[intent.group_id].add(intent.intent_id)
|
||||
|
||||
async def get_intent(self, intent_id: str) -> OrderIntent:
|
||||
if intent_id not in self._intents:
|
||||
raise KeyError(f"Intent {intent_id} not found")
|
||||
return self._intents[intent_id]
|
||||
|
||||
async def update_intent(self, intent: OrderIntent) -> None:
|
||||
if intent.intent_id not in self._intents:
|
||||
raise KeyError(f"Intent {intent.intent_id} not found")
|
||||
old_intent = self._intents[intent.intent_id]
|
||||
|
||||
# Update indices if symbol or group changed
|
||||
if old_intent.order.symbol_id != intent.order.symbol_id:
|
||||
self._by_symbol[old_intent.order.symbol_id].discard(intent.intent_id)
|
||||
self._by_symbol[intent.order.symbol_id].add(intent.intent_id)
|
||||
|
||||
if old_intent.group_id != intent.group_id:
|
||||
if old_intent.group_id:
|
||||
self._by_group[old_intent.group_id].discard(intent.intent_id)
|
||||
if intent.group_id:
|
||||
self._by_group[intent.group_id].add(intent.intent_id)
|
||||
|
||||
self._intents[intent.intent_id] = intent
|
||||
|
||||
async def delete_intent(self, intent_id: str) -> None:
|
||||
if intent_id not in self._intents:
|
||||
raise KeyError(f"Intent {intent_id} not found")
|
||||
intent = self._intents[intent_id]
|
||||
self._by_symbol[intent.order.symbol_id].discard(intent_id)
|
||||
if intent.group_id:
|
||||
self._by_group[intent.group_id].discard(intent_id)
|
||||
del self._intents[intent_id]
|
||||
|
||||
async def list_intents(
|
||||
self,
|
||||
symbol_id: str | None = None,
|
||||
group_id: str | None = None,
|
||||
) -> list[OrderIntent]:
|
||||
if symbol_id and group_id:
|
||||
# Intersection of both filters
|
||||
symbol_ids = self._by_symbol.get(symbol_id, set())
|
||||
group_ids = self._by_group.get(group_id, set())
|
||||
intent_ids = symbol_ids & group_ids
|
||||
elif symbol_id:
|
||||
intent_ids = self._by_symbol.get(symbol_id, set())
|
||||
elif group_id:
|
||||
intent_ids = self._by_group.get(group_id, set())
|
||||
else:
|
||||
intent_ids = self._intents.keys()
|
||||
|
||||
return [self._intents[iid] for iid in intent_ids]
|
||||
|
||||
async def get_intents_by_group(self, group_id: str) -> list[OrderIntent]:
|
||||
intent_ids = self._by_group.get(group_id, set())
|
||||
return [self._intents[iid] for iid in intent_ids]
|
||||
|
||||
|
||||
class InMemoryActualStateStore(ActualStateStore):
|
||||
"""Simple in-memory implementation of ActualStateStore"""
|
||||
|
||||
def __init__(self):
|
||||
self._states: dict[str, OrderState] = {}
|
||||
self._by_exchange_id: dict[str, str] = {} # exchange_order_id -> intent_id
|
||||
self._by_symbol: dict[str, set[str]] = defaultdict(set)
|
||||
|
||||
async def create_order_state(self, state: OrderState) -> None:
|
||||
if state.intent_id in self._states:
|
||||
raise ValueError(f"Order state for intent {state.intent_id} already exists")
|
||||
self._states[state.intent_id] = state
|
||||
self._by_exchange_id[state.exchange_order_id] = state.intent_id
|
||||
self._by_symbol[state.status.order.symbol_id].add(state.intent_id)
|
||||
|
||||
async def get_order_state(self, intent_id: str) -> OrderState:
|
||||
if intent_id not in self._states:
|
||||
raise KeyError(f"Order state for intent {intent_id} not found")
|
||||
return self._states[intent_id]
|
||||
|
||||
async def get_order_state_by_exchange_id(self, exchange_order_id: str) -> OrderState:
|
||||
if exchange_order_id not in self._by_exchange_id:
|
||||
raise KeyError(f"Order state for exchange order {exchange_order_id} not found")
|
||||
intent_id = self._by_exchange_id[exchange_order_id]
|
||||
return self._states[intent_id]
|
||||
|
||||
async def update_order_state(self, state: OrderState) -> None:
|
||||
if state.intent_id not in self._states:
|
||||
raise KeyError(f"Order state for intent {state.intent_id} not found")
|
||||
old_state = self._states[state.intent_id]
|
||||
|
||||
# Update exchange_id index if it changed
|
||||
if old_state.exchange_order_id != state.exchange_order_id:
|
||||
del self._by_exchange_id[old_state.exchange_order_id]
|
||||
self._by_exchange_id[state.exchange_order_id] = state.intent_id
|
||||
|
||||
# Update symbol index if it changed
|
||||
old_symbol = old_state.status.order.symbol_id
|
||||
new_symbol = state.status.order.symbol_id
|
||||
if old_symbol != new_symbol:
|
||||
self._by_symbol[old_symbol].discard(state.intent_id)
|
||||
self._by_symbol[new_symbol].add(state.intent_id)
|
||||
|
||||
self._states[state.intent_id] = state
|
||||
|
||||
async def delete_order_state(self, intent_id: str) -> None:
|
||||
if intent_id not in self._states:
|
||||
raise KeyError(f"Order state for intent {intent_id} not found")
|
||||
state = self._states[intent_id]
|
||||
del self._by_exchange_id[state.exchange_order_id]
|
||||
self._by_symbol[state.status.order.symbol_id].discard(intent_id)
|
||||
del self._states[intent_id]
|
||||
|
||||
async def list_order_states(
|
||||
self,
|
||||
symbol_id: str | None = None,
|
||||
reconciliation_status: ReconciliationStatus | None = None,
|
||||
) -> list[OrderState]:
|
||||
if symbol_id:
|
||||
intent_ids = self._by_symbol.get(symbol_id, set())
|
||||
states = [self._states[iid] for iid in intent_ids]
|
||||
else:
|
||||
states = list(self._states.values())
|
||||
|
||||
if reconciliation_status:
|
||||
states = [s for s in states if s.reconciliation_status == reconciliation_status]
|
||||
|
||||
return states
|
||||
|
||||
async def get_stale_orders(self, max_age_seconds: int) -> list[OrderState]:
|
||||
import time
|
||||
current_time = int(time.time())
|
||||
threshold = current_time - max_age_seconds
|
||||
|
||||
return [
|
||||
state
|
||||
for state in self._states.values()
|
||||
if state.last_sync_at < threshold
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reconciliation Engine (framework only, no implementation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ReconciliationEngine:
|
||||
"""
|
||||
Reconciliation engine that continuously works to make actual state match intent.
|
||||
|
||||
This is the heart of the "Kubernetes for orders" concept. It:
|
||||
1. Compares desired state (intents) with actual state (exchange orders)
|
||||
2. Computes necessary actions (place, modify, cancel)
|
||||
3. Executes those actions via the exchange API
|
||||
4. Handles retries, errors, and edge cases
|
||||
|
||||
This is a framework class - concrete implementations will be exchange-specific.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
intent_store: IntentStateStore,
|
||||
actual_store: ActualStateStore,
|
||||
):
|
||||
"""
|
||||
Initialize the reconciliation engine.
|
||||
|
||||
Args:
|
||||
intent_store: Store for desired state
|
||||
actual_store: Store for actual state
|
||||
"""
|
||||
self.intent_store = intent_store
|
||||
self.actual_store = actual_store
|
||||
self._running = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the reconciliation loop"""
|
||||
self._running = True
|
||||
# Implementation would start async reconciliation loop here
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the reconciliation loop"""
|
||||
self._running = False
|
||||
# Implementation would stop reconciliation loop here
|
||||
pass
|
||||
|
||||
async def reconcile_intent(self, intent_id: str) -> None:
|
||||
"""
|
||||
Reconcile a specific intent.
|
||||
|
||||
Compares the intent with actual state and takes necessary actions.
|
||||
|
||||
Args:
|
||||
intent_id: Intent to reconcile
|
||||
"""
|
||||
# Framework only - concrete implementation needed
|
||||
pass
|
||||
|
||||
async def reconcile_all(self) -> None:
|
||||
"""
|
||||
Reconcile all intents.
|
||||
|
||||
Full reconciliation pass over all orders.
|
||||
"""
|
||||
# Framework only - concrete implementation needed
|
||||
pass
|
||||
|
||||
def get_metrics(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get reconciliation metrics.
|
||||
|
||||
Returns:
|
||||
Metrics about reconciliation performance, errors, etc.
|
||||
"""
|
||||
return {
|
||||
"running": self._running,
|
||||
"reconciliation_lag_ms": 0, # Framework only
|
||||
"pending_reconciliations": 0, # Framework only
|
||||
"error_count": 0, # Framework only
|
||||
"retry_count": 0, # Framework only
|
||||
}
|
||||
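To make the reconciliation loop above concrete, here is a minimal sketch of what an exchange-specific subclass could look like. It relies on accessors that are not part of the framework shown here (a get_intent lookup on the intent store, a place_order/cancel_order exchange client, and a cancelled flag on OrderIntent); those names are illustrative only.

# Minimal sketch only; anything marked "assumed" is not defined by the framework above.
class NaiveReconciliationEngine(ReconciliationEngine):
    def __init__(self, intent_store, actual_store, exchange_client):
        super().__init__(intent_store, actual_store)
        self.exchange = exchange_client  # assumed exchange adapter

    async def reconcile_intent(self, intent_id: str) -> None:
        intent = await self.intent_store.get_intent(intent_id)  # assumed accessor
        try:
            state = await self.actual_store.get_order_state(intent_id)
        except KeyError:
            state = None

        if state is None:
            # Desired order has no counterpart on the exchange yet -> place it
            await self.exchange.place_order(intent)
            # ...followed by actual_store.create_order_state(...) with the new exchange id
        elif getattr(intent, "cancelled", False):  # assumed flag on OrderIntent
            # Intent withdrawn -> cancel the live order and drop the tracked state
            await self.exchange.cancel_order(state.exchange_order_id)
            await self.actual_store.delete_order_state(intent_id)
        # Price/size amendments, partial fills and retry/backoff handling would be
        # further branches here, each ending in an update_order_state() call.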
4
backend.old/src/gateway/__init__.py
Normal file
4
backend.old/src/gateway/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from gateway.protocol import UserMessage, AgentMessage, ChannelStatus
|
||||
from gateway.hub import Gateway
|
||||
|
||||
__all__ = ["UserMessage", "AgentMessage", "ChannelStatus", "Gateway"]
|
||||
3
backend.old/src/gateway/channels/__init__.py
Normal file
3
backend.old/src/gateway/channels/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from gateway.channels.base import Channel
|
||||
|
||||
__all__ = ["Channel"]
|
||||
73
backend.old/src/gateway/channels/base.py
Normal file
73
backend.old/src/gateway/channels/base.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncIterator, Dict, Any
|
||||
|
||||
from gateway.protocol import UserMessage, AgentMessage, ChannelStatus
|
||||
|
||||
|
||||
class Channel(ABC):
|
||||
"""Abstract base class for communication channels.
|
||||
|
||||
Channels are the transport layer between users and the agent system.
|
||||
They handle protocol-specific encoding/decoding and maintain connection state.
|
||||
"""
|
||||
|
||||
def __init__(self, channel_id: str, channel_type: str):
|
||||
self.channel_id = channel_id
|
||||
self.channel_type = channel_type
|
||||
self._connected = False
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, message: AgentMessage) -> None:
|
||||
"""Send a message from the agent to the user through this channel.
|
||||
|
||||
Args:
|
||||
message: AgentMessage to send (may be streaming chunk or complete message)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def receive(self) -> AsyncIterator[UserMessage]:
|
||||
"""Receive messages from the user through this channel.
|
||||
|
||||
Yields:
|
||||
UserMessage objects as they arrive
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Close the channel and clean up resources."""
|
||||
pass
|
||||
|
||||
def supports_streaming(self) -> bool:
|
||||
"""Whether this channel supports streaming responses.
|
||||
|
||||
Returns:
|
||||
True if the channel can handle streaming chunks
|
||||
"""
|
||||
return False
|
||||
|
||||
def supports_attachments(self) -> bool:
|
||||
"""Whether this channel supports file attachments.
|
||||
|
||||
Returns:
|
||||
True if the channel can handle attachments
|
||||
"""
|
||||
return False
|
||||
|
||||
def get_status(self) -> ChannelStatus:
|
||||
"""Get current channel status.
|
||||
|
||||
Returns:
|
||||
ChannelStatus object with connection info and capabilities
|
||||
"""
|
||||
return ChannelStatus(
|
||||
channel_id=self.channel_id,
|
||||
channel_type=self.channel_type,
|
||||
connected=self._connected,
|
||||
capabilities={
|
||||
"streaming": self.supports_streaming(),
|
||||
"attachments": self.supports_attachments(),
|
||||
"markdown": True # Most channels support some form of markdown
|
||||
}
|
||||
)
|
||||
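The WebSocketChannel below is the production implementation; for tests, a minimal in-memory subclass is enough to exercise the Gateway. A sketch (not part of the original code):

import asyncio
from typing import AsyncIterator

class LoopbackChannel(Channel):
    """In-memory channel for tests: user messages are fed into a queue and
    agent output is collected into a list."""

    def __init__(self, channel_id: str):
        super().__init__(channel_id, "loopback")
        self._connected = True
        self._incoming: asyncio.Queue[UserMessage] = asyncio.Queue()
        self.sent: list[AgentMessage] = []

    def supports_streaming(self) -> bool:
        return True

    async def send(self, message: AgentMessage) -> None:
        self.sent.append(message)

    async def receive(self) -> AsyncIterator[UserMessage]:
        while self._connected:
            yield await self._incoming.get()

    async def close(self) -> None:
        self._connected = False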
99
backend.old/src/gateway/channels/websocket.py
Normal file
99
backend.old/src/gateway/channels/websocket.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncIterator, Optional
|
||||
from datetime import datetime, timezone
|
||||
from fastapi import WebSocket
|
||||
|
||||
from gateway.channels.base import Channel
|
||||
from gateway.protocol import UserMessage, AgentMessage, WebSocketAgentUserMessage, WebSocketAgentChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebSocketChannel(Channel):
|
||||
"""WebSocket-based communication channel.
|
||||
|
||||
Integrates with the existing WebSocket endpoint to provide
|
||||
bidirectional agent communication with streaming support.
|
||||
"""
|
||||
|
||||
def __init__(self, channel_id: str, websocket: WebSocket, session_id: str):
|
||||
super().__init__(channel_id, "websocket")
|
||||
self.websocket = websocket
|
||||
self.session_id = session_id
|
||||
self._connected = True
|
||||
self._receive_queue: asyncio.Queue[UserMessage] = asyncio.Queue()
|
||||
self._receive_task: Optional[asyncio.Task] = None
|
||||
|
||||
def supports_streaming(self) -> bool:
|
||||
"""WebSocket supports streaming responses."""
|
||||
return True
|
||||
|
||||
def supports_attachments(self) -> bool:
|
||||
"""WebSocket can support attachments via URLs."""
|
||||
return True
|
||||
|
||||
async def send(self, message: AgentMessage) -> None:
|
||||
"""Send agent message through WebSocket.
|
||||
|
||||
For streaming messages, sends chunks as they arrive.
|
||||
For complete messages, sends as a single chunk.
|
||||
"""
|
||||
if not self._connected:
|
||||
logger.warning(f"Cannot send message, channel {self.channel_id} not connected")
|
||||
return
|
||||
|
||||
try:
|
||||
chunk = WebSocketAgentChunk(
|
||||
session_id=message.session_id,
|
||||
content=message.content,
|
||||
done=message.done,
|
||||
metadata=message.metadata
|
||||
)
|
||||
chunk_data = chunk.model_dump(mode="json")
|
||||
logger.debug(f"Sending WebSocket message: done={message.done}, content_length={len(message.content)}")
|
||||
await self.websocket.send_json(chunk_data)
|
||||
logger.debug(f"WebSocket message sent successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket send error: {e}", exc_info=True)
|
||||
self._connected = False
|
||||
|
||||
async def receive(self) -> AsyncIterator[UserMessage]:
|
||||
"""Receive messages from WebSocket.
|
||||
|
||||
Yields:
|
||||
UserMessage objects as they arrive from the client
|
||||
"""
|
||||
try:
|
||||
while self._connected:
|
||||
# Read from WebSocket
|
||||
data = await self.websocket.receive_text()
|
||||
message_json = json.loads(data)
|
||||
|
||||
# Only process agent_user_message types
|
||||
if message_json.get("type") == "agent_user_message":
|
||||
msg = WebSocketAgentUserMessage(**message_json)
|
||||
|
||||
user_msg = UserMessage(
|
||||
session_id=msg.session_id,
|
||||
channel_id=self.channel_id,
|
||||
content=msg.content,
|
||||
attachments=msg.attachments,
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
yield user_msg
|
||||
except Exception as e:
|
||||
print(f"WebSocket receive error: {e}")
|
||||
self._connected = False
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the WebSocket connection."""
|
||||
self._connected = False
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
# Note: WebSocket close is handled by the main WebSocket endpoint
|
||||
253
backend.old/src/gateway/hub.py
Normal file
253
backend.old/src/gateway/hub.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Optional, Callable, Awaitable, AsyncIterator, Any
|
||||
from datetime import datetime
|
||||
|
||||
from gateway.channels.base import Channel
|
||||
from gateway.user_session import UserSession
|
||||
from gateway.protocol import UserMessage, AgentMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Gateway:
|
||||
"""Central hub for routing messages between users, channels, and the agent.
|
||||
|
||||
The Gateway:
|
||||
- Maintains active channels and user sessions
|
||||
- Routes user messages to the agent
|
||||
- Routes agent responses back to appropriate channels
|
||||
- Handles interruption of in-flight agent tasks
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.channels: Dict[str, Channel] = {}
|
||||
self.sessions: Dict[str, UserSession] = {}
|
||||
self.agent_executor: Optional[Callable[[UserSession, UserMessage], AsyncIterator[Any]]] = None
|
||||
|
||||
def set_agent_executor(
|
||||
self,
|
||||
executor: Callable[[UserSession, UserMessage], AsyncIterator[Any]]
|
||||
) -> None:
|
||||
"""Set the agent executor function.
|
||||
|
||||
Args:
|
||||
executor: Async function that takes (session, message) and returns async iterator of response chunks
|
||||
"""
|
||||
self.agent_executor = executor
|
||||
|
||||
def register_channel(self, channel: Channel) -> None:
|
||||
"""Register a new communication channel.
|
||||
|
||||
Args:
|
||||
channel: Channel instance to register
|
||||
"""
|
||||
self.channels[channel.channel_id] = channel
|
||||
|
||||
def unregister_channel(self, channel_id: str) -> None:
|
||||
"""Unregister a channel.
|
||||
|
||||
Args:
|
||||
channel_id: ID of channel to remove
|
||||
"""
|
||||
if channel_id in self.channels:
|
||||
# Remove channel from any sessions
|
||||
for session in self.sessions.values():
|
||||
session.remove_channel(channel_id)
|
||||
del self.channels[channel_id]
|
||||
|
||||
def get_or_create_session(self, session_id: str, user_id: str) -> UserSession:
|
||||
"""Get existing session or create a new one.
|
||||
|
||||
Args:
|
||||
session_id: Unique session identifier
|
||||
user_id: User identifier
|
||||
|
||||
Returns:
|
||||
UserSession instance
|
||||
"""
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = UserSession(session_id, user_id)
|
||||
return self.sessions[session_id]
|
||||
|
||||
async def route_user_message(self, message: UserMessage) -> None:
|
||||
"""Route a user message to the agent.
|
||||
|
||||
If the agent is currently processing a message, it will be interrupted
|
||||
and restarted with the new message appended to the conversation.
|
||||
|
||||
Args:
|
||||
message: UserMessage from a channel
|
||||
"""
|
||||
logger.info(f"route_user_message called - session: {message.session_id}, channel: {message.channel_id}")
|
||||
|
||||
# Get or create session
|
||||
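# UserMessage carries no separate user_id field, so the session_id doubles as the user identifier here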
session = self.get_or_create_session(message.session_id, message.session_id)
|
||||
logger.info(f"Session obtained/created: {message.session_id}, channels: {session.active_channels}")
|
||||
|
||||
# Ensure channel is attached to session
|
||||
session.add_channel(message.channel_id)
|
||||
logger.info(f"Channel added to session: {message.channel_id}")
|
||||
|
||||
# If agent is busy, interrupt it
|
||||
if session.is_busy():
|
||||
logger.info(f"Session is busy, interrupting existing task")
|
||||
await session.interrupt()
|
||||
|
||||
# Check if this is a stop interrupt (empty message)
|
||||
if not message.content.strip() and not message.attachments:
|
||||
logger.info("Received stop interrupt (empty message), not starting new agent round")
|
||||
return
|
||||
|
||||
# Add user message to history
|
||||
session.add_message("user", message.content, message.channel_id)
|
||||
logger.info(f"User message added to history, history length: {len(session.get_history())}")
|
||||
|
||||
# Start agent task
|
||||
if self.agent_executor:
|
||||
logger.info("Starting agent task execution")
|
||||
task = asyncio.create_task(
|
||||
self._execute_agent_and_stream(session, message)
|
||||
)
|
||||
session.set_task(task)
|
||||
logger.info(f"Agent task created and set on session")
|
||||
else:
|
||||
logger.error("No agent_executor configured! Cannot process message.")
|
||||
# Send error message to user
|
||||
error_msg = AgentMessage(
|
||||
session_id=session.session_id,
|
||||
target_channels=session.active_channels,
|
||||
content="Error: Agent system not initialized. Please check that ANTHROPIC_API_KEY is configured.",
|
||||
done=True
|
||||
)
|
||||
await self._send_to_channels(error_msg)
|
||||
|
||||
async def _execute_agent_and_stream(self, session: UserSession, message: UserMessage) -> None:
|
||||
"""Execute agent and stream responses back to channels.
|
||||
|
||||
Args:
|
||||
session: User session
|
||||
message: Triggering user message
|
||||
"""
|
||||
logger.info(f"_execute_agent_and_stream starting for session {session.session_id}")
|
||||
try:
|
||||
# Execute agent (returns async generator directly)
|
||||
logger.info("Calling agent_executor...")
|
||||
response_stream = self.agent_executor(session, message)
|
||||
logger.info("Agent executor returned response stream")
|
||||
|
||||
# Stream chunks back to active channels
|
||||
full_response = ""
|
||||
chunk_count = 0
|
||||
accumulated_metadata = {}
|
||||
|
||||
async for chunk in response_stream:
|
||||
# Handle dict response with metadata (from agent executor)
|
||||
if isinstance(chunk, dict):
|
||||
content = chunk.get("content", "")
|
||||
metadata = chunk.get("metadata", {})
|
||||
# Accumulate metadata (e.g., plot_urls)
|
||||
for key, value in metadata.items():
|
||||
if key == "plot_urls" and value:
|
||||
# Append to existing plot_urls
|
||||
if "plot_urls" not in accumulated_metadata:
|
||||
accumulated_metadata["plot_urls"] = []
|
||||
accumulated_metadata["plot_urls"].extend(value)
|
||||
logger.info(f"Accumulated plot_urls: {accumulated_metadata['plot_urls']}")
|
||||
else:
|
||||
accumulated_metadata[key] = value
|
||||
chunk = content
|
||||
|
||||
# Only send non-empty chunks
|
||||
if chunk:
|
||||
chunk_count += 1
|
||||
full_response += chunk
|
||||
logger.debug(f"Received chunk #{chunk_count}, length: {len(chunk)}")
|
||||
|
||||
# Send chunk to all active channels with accumulated metadata
|
||||
agent_msg = AgentMessage(
|
||||
session_id=session.session_id,
|
||||
target_channels=session.active_channels,
|
||||
content=chunk,
|
||||
stream_chunk=True,
|
||||
done=False,
|
||||
metadata=accumulated_metadata.copy()
|
||||
)
|
||||
await self._send_to_channels(agent_msg)
|
||||
|
||||
logger.info(f"Agent streaming completed, total chunks: {chunk_count}, response length: {len(full_response)}")
|
||||
|
||||
# Send final done message with all accumulated metadata
|
||||
agent_msg = AgentMessage(
|
||||
session_id=session.session_id,
|
||||
target_channels=session.active_channels,
|
||||
content="",
|
||||
stream_chunk=True,
|
||||
done=True,
|
||||
metadata=accumulated_metadata
|
||||
)
|
||||
await self._send_to_channels(agent_msg)
|
||||
logger.info(f"Sent final done message to channels with metadata: {accumulated_metadata}")
|
||||
|
||||
# Add to history
|
||||
session.add_message("assistant", full_response)
|
||||
logger.info("Assistant response added to history")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Task was interrupted
|
||||
logger.warning(f"Agent task interrupted for session {session.session_id}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Agent execution error: {e}", exc_info=True)
|
||||
# Send error message
|
||||
error_msg = AgentMessage(
|
||||
session_id=session.session_id,
|
||||
target_channels=session.active_channels,
|
||||
content=f"Error: {str(e)}",
|
||||
done=True
|
||||
)
|
||||
await self._send_to_channels(error_msg)
|
||||
finally:
|
||||
session.set_task(None)
|
||||
logger.info(f"Agent task completed for session {session.session_id}")
|
||||
|
||||
async def _send_to_channels(self, message: AgentMessage) -> None:
|
||||
"""Send message to specified channels.
|
||||
|
||||
Args:
|
||||
message: AgentMessage to send
|
||||
"""
|
||||
logger.debug(f"Sending message to {len(message.target_channels)} channels")
|
||||
for channel_id in message.target_channels:
|
||||
channel = self.channels.get(channel_id)
|
||||
if channel and channel.get_status().connected:
|
||||
try:
|
||||
await channel.send(message)
|
||||
logger.debug(f"Message sent to channel {channel_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending to channel {channel_id}: {e}", exc_info=True)
|
||||
else:
|
||||
logger.warning(f"Channel {channel_id} not found or not connected")
|
||||
|
||||
def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get session information.
|
||||
|
||||
Args:
|
||||
session_id: Session ID to query
|
||||
|
||||
Returns:
|
||||
Session info dict or None if not found
|
||||
"""
|
||||
session = self.sessions.get(session_id)
|
||||
return session.to_dict() if session else None
|
||||
|
||||
def get_active_sessions(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get all active sessions.
|
||||
|
||||
Returns:
|
||||
Dict mapping session_id to session info
|
||||
"""
|
||||
return {
|
||||
sid: session.to_dict()
|
||||
for sid, session in self.sessions.items()
|
||||
}
|
||||
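Wiring these pieces together looks roughly like the following sketch. The executor here is a stand-in that just echoes; a real deployment passes the AgentExecutor's streaming method, and LoopbackChannel refers to the test channel sketched after the Channel base class above.

async def echo_executor(session: UserSession, message: UserMessage):
    # Stand-in for the real agent executor: callable as executor(session, message)
    # and returning an async iterator of response chunks.
    yield f"Echo: {message.content}"

async def demo() -> None:
    gateway = Gateway()
    gateway.set_agent_executor(echo_executor)

    channel = LoopbackChannel("chan-1")
    gateway.register_channel(channel)

    await gateway.route_user_message(
        UserMessage(session_id="s-1", channel_id="chan-1", content="hello")
    )

    # route_user_message starts the agent round as a background task;
    # wait for it so the streamed chunks land on the channel.
    session = gateway.get_or_create_session("s-1", "s-1")
    task = session.current_task
    if task:
        await task
    assert channel.sent[-1].done  # final chunk marks completion of the round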
57
backend.old/src/gateway/protocol.py
Normal file
57
backend.old/src/gateway/protocol.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class UserMessage(BaseModel):
|
||||
"""Message from user to agent through a communication channel."""
|
||||
session_id: str
|
||||
channel_id: str
|
||||
content: str
|
||||
attachments: List[str] = Field(default_factory=list, description="URLs or file paths")
|
||||
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class AgentMessage(BaseModel):
|
||||
"""Message from agent to user(s) through one or more channels."""
|
||||
session_id: str
|
||||
target_channels: List[str] = Field(description="List of channel IDs to send to")
|
||||
content: str
|
||||
stream_chunk: bool = Field(default=False, description="True if this is a streaming chunk")
|
||||
done: bool = Field(default=False, description="True if streaming is complete")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ChannelStatus(BaseModel):
|
||||
"""Status information about a communication channel."""
|
||||
channel_id: str
|
||||
channel_type: str = Field(description="Type: 'websocket', 'slack', 'telegram', etc.")
|
||||
connected: bool
|
||||
user_id: Optional[str] = None
|
||||
session_id: Optional[str] = None
|
||||
capabilities: Dict[str, bool] = Field(
|
||||
default_factory=lambda: {
|
||||
"streaming": False,
|
||||
"attachments": False,
|
||||
"markdown": False
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# WebSocket-specific protocol extensions
|
||||
class WebSocketAgentUserMessage(BaseModel):
|
||||
"""WebSocket message: User → Backend (agent chat)"""
|
||||
type: Literal["agent_user_message"] = "agent_user_message"
|
||||
session_id: str
|
||||
content: str
|
||||
attachments: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WebSocketAgentChunk(BaseModel):
|
||||
"""WebSocket message: Backend → User (streaming agent response)"""
|
||||
type: Literal["agent_chunk"] = "agent_chunk"
|
||||
session_id: str
|
||||
content: str
|
||||
done: bool = False
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
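For reference, a round trip over this protocol looks roughly like the following (values are illustrative):

# Client -> backend frame, as parsed by WebSocketChannel.receive()
incoming = WebSocketAgentUserMessage(
    session_id="s-1",
    content="Plot BTC/USD with a 20-period SMA",
)
# incoming.model_dump(mode="json") ->
# {"type": "agent_user_message", "session_id": "s-1",
#  "content": "Plot BTC/USD with a 20-period SMA", "attachments": []}

# Backend -> client frame, as emitted by WebSocketChannel.send()
outgoing = WebSocketAgentChunk(
    session_id="s-1",
    content="Here is the chart.",
    done=True,
    metadata={"plot_urls": ["/plots/example.png"]},  # illustrative metadata
)
payload = outgoing.model_dump(mode="json")  # this dict goes over the wire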
107
backend.old/src/gateway/user_session.py
Normal file
107
backend.old/src/gateway/user_session.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""A single message in a conversation."""
|
||||
role: str # "user" or "assistant"
|
||||
content: str
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
channel_id: Optional[str] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class UserSession:
|
||||
"""Manages a user's conversation session with the agent.
|
||||
|
||||
A session tracks:
|
||||
- Active communication channels
|
||||
- Conversation history
|
||||
- In-flight agent tasks (for interruption)
|
||||
- User metadata
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str, user_id: str):
|
||||
self.session_id = session_id
|
||||
self.user_id = user_id
|
||||
self.active_channels: List[str] = []
|
||||
self.conversation_history: List[Message] = []
|
||||
self.current_task: Optional[asyncio.Task] = None
|
||||
self.metadata: Dict[str, Any] = {}
|
||||
self.created_at = datetime.utcnow()
|
||||
self.last_activity = datetime.utcnow()
|
||||
|
||||
def add_channel(self, channel_id: str) -> None:
|
||||
"""Attach a channel to this session."""
|
||||
if channel_id not in self.active_channels:
|
||||
self.active_channels.append(channel_id)
|
||||
self.last_activity = datetime.utcnow()
|
||||
|
||||
def remove_channel(self, channel_id: str) -> None:
|
||||
"""Detach a channel from this session."""
|
||||
if channel_id in self.active_channels:
|
||||
self.active_channels.remove(channel_id)
|
||||
self.last_activity = datetime.utcnow()
|
||||
|
||||
def add_message(self, role: str, content: str, channel_id: Optional[str] = None, **kwargs) -> None:
|
||||
"""Add a message to conversation history."""
|
||||
message = Message(
|
||||
role=role,
|
||||
content=content,
|
||||
channel_id=channel_id,
|
||||
metadata=kwargs
|
||||
)
|
||||
self.conversation_history.append(message)
|
||||
self.last_activity = datetime.utcnow()
|
||||
|
||||
def get_history(self, limit: Optional[int] = None) -> List[Message]:
|
||||
"""Get conversation history.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of recent messages to return (None = all)
|
||||
|
||||
Returns:
|
||||
List of Message objects
|
||||
"""
|
||||
if limit is not None:
|
||||
return self.conversation_history[-limit:]
|
||||
return self.conversation_history
|
||||
|
||||
async def interrupt(self) -> bool:
|
||||
"""Interrupt the current agent task if one is running.
|
||||
|
||||
Returns:
|
||||
True if a task was interrupted, False otherwise
|
||||
"""
|
||||
if self.current_task and not self.current_task.done():
|
||||
self.current_task.cancel()
|
||||
try:
|
||||
await self.current_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self.current_task = None
|
||||
return True
|
||||
return False
|
||||
|
||||
def set_task(self, task: Optional[asyncio.Task]) -> None:
|
||||
"""Set the current agent task."""
|
||||
self.current_task = task
|
||||
|
||||
def is_busy(self) -> bool:
|
||||
"""Check if the agent is currently processing a request."""
|
||||
return self.current_task is not None and not self.current_task.done()
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize session to dict."""
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"user_id": self.user_id,
|
||||
"active_channels": self.active_channels,
|
||||
"message_count": len(self.conversation_history),
|
||||
"is_busy": self.is_busy(),
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"last_activity": self.last_activity.isoformat(),
|
||||
"metadata": self.metadata
|
||||
}
|
||||
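The interruption contract the Gateway relies on can be exercised in isolation; a small sketch:

import asyncio

from gateway.user_session import UserSession

async def interruption_demo() -> None:
    session = UserSession("s-1", "user-1")

    async def slow_round():
        await asyncio.sleep(60)  # stands in for a long agent invocation

    session.set_task(asyncio.create_task(slow_round()))
    assert session.is_busy()

    interrupted = await session.interrupt()  # cancels the in-flight task
    assert interrupted
    assert not session.is_busy()

asyncio.run(interruption_demo())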
179
backend.old/src/indicator/__init__.py
Normal file
179
backend.old/src/indicator/__init__.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
Composable Indicator System.
|
||||
|
||||
Provides a framework for building DAGs of data transformation pipelines
|
||||
that process time-series data incrementally. Indicators can consume
|
||||
DataSources or other Indicators as inputs, composing into arbitrarily
|
||||
complex processing graphs.
|
||||
|
||||
Key Components:
|
||||
---------------
|
||||
|
||||
Indicator (base.py):
|
||||
Abstract base class for all indicator implementations.
|
||||
Declares input/output schemas and implements synchronous compute().
|
||||
|
||||
IndicatorRegistry (registry.py):
|
||||
Central catalog of available indicators with rich metadata
|
||||
for AI agent discovery and tool generation.
|
||||
|
||||
Pipeline (pipeline.py):
|
||||
Execution engine that builds DAGs, resolves dependencies,
|
||||
and orchestrates incremental data flow through indicator chains.
|
||||
|
||||
Schema Types (schema.py):
|
||||
Type definitions for input/output schemas, computation context,
|
||||
and metadata for AI-native documentation.
|
||||
|
||||
Usage Example:
|
||||
--------------
|
||||
|
||||
from indicator import Indicator, IndicatorRegistry, Pipeline
|
||||
from indicator.schema import (
|
||||
InputSchema, OutputSchema, ComputeContext, ComputeResult,
|
||||
IndicatorMetadata, IndicatorParameter
|
||||
)
|
||||
|
||||
# Define an indicator
|
||||
class SimpleMovingAverage(Indicator):
|
||||
@classmethod
|
||||
def get_metadata(cls):
|
||||
return IndicatorMetadata(
|
||||
name="SMA",
|
||||
display_name="Simple Moving Average",
|
||||
description="Arithmetic mean of prices over N periods",
|
||||
category="trend",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="period",
|
||||
type="int",
|
||||
description="Number of periods to average",
|
||||
default=20,
|
||||
min_value=1
|
||||
)
|
||||
],
|
||||
tags=["moving-average", "trend-following"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls):
|
||||
return InputSchema(
|
||||
required_columns=[
|
||||
ColumnInfo(name="close", type="float", description="Closing price")
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params):
|
||||
return OutputSchema(
|
||||
columns=[
|
||||
ColumnInfo(
|
||||
name="sma",
|
||||
type="float",
|
||||
description=f"Simple moving average over {params.get('period', 20)} periods"
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
period = self.params["period"]
|
||||
closes = context.get_column("close")
|
||||
times = context.get_times()
|
||||
|
||||
sma_values = []
|
||||
for i in range(len(closes)):
|
||||
if i < period - 1:
|
||||
sma_values.append(None)
|
||||
else:
|
||||
window = closes[i - period + 1 : i + 1]
|
||||
sma_values.append(sum(window) / period)
|
||||
|
||||
return ComputeResult(
|
||||
data=[
|
||||
{"time": times[i], "sma": sma_values[i]}
|
||||
for i in range(len(times))
|
||||
]
|
||||
)
|
||||
|
||||
# Register the indicator
|
||||
registry = IndicatorRegistry()
|
||||
registry.register(SimpleMovingAverage)
|
||||
|
||||
# Create a pipeline
|
||||
pipeline = Pipeline(datasource_registry)
|
||||
pipeline.add_datasource("price_data", "ccxt", "BTC/USD", "1D")
|
||||
|
||||
sma_indicator = registry.create_instance("SMA", "sma_20", period=20)
|
||||
pipeline.add_indicator("sma_20", sma_indicator, input_node_ids=["price_data"])
|
||||
|
||||
# Execute
|
||||
results = pipeline.execute(datasource_data={"price_data": price_bars})
|
||||
sma_output = results["sma_20"] # Contains columns: time, close, sma_20_sma
|
||||
|
||||
Design Philosophy:
|
||||
------------------
|
||||
|
||||
1. **Schema-based composition**: Indicators declare inputs/outputs via schemas,
|
||||
enabling automatic validation and flexible composition.
|
||||
|
||||
2. **Synchronous execution**: All computation is synchronous for simplicity.
|
||||
Async handling happens at the event/strategy layer.
|
||||
|
||||
3. **Incremental updates**: Indicators receive context about what changed,
|
||||
allowing optimized recomputation of only affected values.
|
||||
|
||||
4. **AI-native metadata**: Rich descriptions, use cases, and parameter specs
|
||||
make indicators discoverable and usable by AI agents.
|
||||
|
||||
5. **Generic data flow**: Indicators work with any data source that matches
|
||||
their input schema, not specific DataSource instances.
|
||||
|
||||
6. **Event-driven**: Designed to react to DataSource updates and propagate
|
||||
changes through the DAG efficiently.
|
||||
"""
|
||||
|
||||
from .base import DataSourceAdapter, Indicator
|
||||
from .pipeline import Pipeline, PipelineNode
|
||||
from .registry import IndicatorRegistry
|
||||
from .schema import (
|
||||
ComputeContext,
|
||||
ComputeResult,
|
||||
IndicatorMetadata,
|
||||
IndicatorParameter,
|
||||
InputSchema,
|
||||
OutputSchema,
|
||||
)
|
||||
from .talib_adapter import (
|
||||
TALibIndicator,
|
||||
register_all_talib_indicators,
|
||||
is_talib_available,
|
||||
get_talib_version,
|
||||
)
|
||||
from .custom_indicators import (
|
||||
register_custom_indicators,
|
||||
CUSTOM_INDICATORS,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Core classes
|
||||
"Indicator",
|
||||
"IndicatorRegistry",
|
||||
"Pipeline",
|
||||
"PipelineNode",
|
||||
"DataSourceAdapter",
|
||||
# Schema types
|
||||
"InputSchema",
|
||||
"OutputSchema",
|
||||
"ComputeContext",
|
||||
"ComputeResult",
|
||||
"IndicatorMetadata",
|
||||
"IndicatorParameter",
|
||||
# TA-Lib integration
|
||||
"TALibIndicator",
|
||||
"register_all_talib_indicators",
|
||||
"is_talib_available",
|
||||
"get_talib_version",
|
||||
# Custom indicators
|
||||
"register_custom_indicators",
|
||||
"CUSTOM_INDICATORS",
|
||||
]
|
||||
230
backend.old/src/indicator/base.py
Normal file
230
backend.old/src/indicator/base.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
Abstract Indicator interface.
|
||||
|
||||
Provides the base class for all technical indicators and derived data transformations.
|
||||
Indicators compose into DAGs, processing data incrementally as updates arrive.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .schema import (
|
||||
ComputeContext,
|
||||
ComputeResult,
|
||||
IndicatorMetadata,
|
||||
InputSchema,
|
||||
OutputSchema,
|
||||
)
|
||||
|
||||
|
||||
class Indicator(ABC):
|
||||
"""
|
||||
Abstract base class for all indicators.
|
||||
|
||||
Indicators are composable transformation nodes that:
|
||||
- Declare input schema (columns they need)
|
||||
- Declare output schema (columns they produce)
|
||||
- Compute outputs synchronously from inputs
|
||||
- Support incremental updates (process only what changed)
|
||||
- Provide rich metadata for AI agent discovery
|
||||
|
||||
Indicators are stateless at the instance level - all state is managed
|
||||
by the pipeline execution engine. This allows the same indicator class
|
||||
to be reused with different parameters.
|
||||
"""
|
||||
|
||||
def __init__(self, instance_name: str, **params):
|
||||
"""
|
||||
Initialize an indicator instance.
|
||||
|
||||
Args:
|
||||
instance_name: Unique name for this instance (used for output column prefixing)
|
||||
**params: Configuration parameters (validated against metadata.parameters)
|
||||
"""
|
||||
self.instance_name = instance_name
|
||||
self.params = params
|
||||
self._validate_params()
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
"""
|
||||
Get metadata for this indicator class.
|
||||
|
||||
Called by the registry for AI agent discovery and documentation.
|
||||
Should return comprehensive information about the indicator's purpose,
|
||||
parameters, and use cases.
|
||||
|
||||
Returns:
|
||||
IndicatorMetadata describing this indicator class
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
"""
|
||||
Get the input schema required by this indicator.
|
||||
|
||||
Declares what columns must be present in the input data.
|
||||
The pipeline will match this against available data sources.
|
||||
|
||||
Returns:
|
||||
InputSchema describing required and optional input columns
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
"""
|
||||
Get the output schema produced by this indicator.
|
||||
|
||||
Output column names will be automatically prefixed with the instance name
|
||||
by the pipeline engine.
|
||||
|
||||
Args:
|
||||
**params: Configuration parameters (may affect output schema)
|
||||
|
||||
Returns:
|
||||
OutputSchema describing the columns this indicator produces
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
"""
|
||||
Compute indicator values from input data.
|
||||
|
||||
This method is called synchronously by the pipeline engine whenever
|
||||
input data changes. Implementations should:
|
||||
|
||||
1. Extract needed columns from context.data
|
||||
2. Perform calculations
|
||||
3. Return results with proper time alignment
|
||||
|
||||
For incremental updates (context.is_incremental == True):
|
||||
- context.data contains only new/updated rows
|
||||
- Implementations MAY optimize by computing only these rows
|
||||
- OR implementations MAY recompute everything (simpler but slower)
|
||||
|
||||
Args:
|
||||
context: Input data and update metadata
|
||||
|
||||
Returns:
|
||||
ComputeResult with calculated indicator values
|
||||
|
||||
Raises:
|
||||
ValueError: If input data doesn't match expected schema
|
||||
"""
|
||||
pass
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
"""
|
||||
Validate that provided parameters match the metadata specification.
|
||||
|
||||
Raises:
|
||||
ValueError: If required parameters are missing or invalid
|
||||
"""
|
||||
metadata = self.get_metadata()
|
||||
|
||||
# Check for required parameters
|
||||
for param_def in metadata.parameters:
|
||||
if param_def.required and param_def.name not in self.params:
|
||||
raise ValueError(
|
||||
f"Indicator '{metadata.name}' requires parameter '{param_def.name}'"
|
||||
)
|
||||
|
||||
# Validate parameter types and ranges
|
||||
for name, value in self.params.items():
|
||||
# Find parameter definition
|
||||
param_def = next(
|
||||
(p for p in metadata.parameters if p.name == name),
|
||||
None
|
||||
)
|
||||
|
||||
if param_def is None:
|
||||
raise ValueError(
|
||||
f"Unknown parameter '{name}' for indicator '{metadata.name}'"
|
||||
)
|
||||
|
||||
# Type checking
|
||||
if param_def.type == "int" and not isinstance(value, int):
|
||||
raise ValueError(
|
||||
f"Parameter '{name}' must be int, got {type(value).__name__}"
|
||||
)
|
||||
elif param_def.type == "float" and not isinstance(value, (int, float)):
|
||||
raise ValueError(
|
||||
f"Parameter '{name}' must be float, got {type(value).__name__}"
|
||||
)
|
||||
elif param_def.type == "bool" and not isinstance(value, bool):
|
||||
raise ValueError(
|
||||
f"Parameter '{name}' must be bool, got {type(value).__name__}"
|
||||
)
|
||||
elif param_def.type == "string" and not isinstance(value, str):
|
||||
raise ValueError(
|
||||
f"Parameter '{name}' must be string, got {type(value).__name__}"
|
||||
)
|
||||
|
||||
# Range checking for numeric types
|
||||
if param_def.type in ("int", "float"):
|
||||
if param_def.min_value is not None and value < param_def.min_value:
|
||||
raise ValueError(
|
||||
f"Parameter '{name}' must be >= {param_def.min_value}, got {value}"
|
||||
)
|
||||
if param_def.max_value is not None and value > param_def.max_value:
|
||||
raise ValueError(
|
||||
f"Parameter '{name}' must be <= {param_def.max_value}, got {value}"
|
||||
)
|
||||
|
||||
def get_output_columns(self) -> List[str]:
|
||||
"""
|
||||
Get the output column names with instance name prefix.
|
||||
|
||||
Returns:
|
||||
List of prefixed output column names
|
||||
"""
|
||||
output_schema = self.get_output_schema(**self.params)
|
||||
prefixed = output_schema.with_prefix(self.instance_name)
|
||||
return [col.name for col in prefixed.columns if col.name != output_schema.time_column]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(instance_name='{self.instance_name}', params={self.params})"
|
||||
|
||||
|
||||
class DataSourceAdapter:
|
||||
"""
|
||||
Adapter to make a DataSource look like an Indicator for pipeline composition.
|
||||
|
||||
This allows DataSources to be inputs to indicators in a unified way.
|
||||
"""
|
||||
|
||||
def __init__(self, datasource_id: str, symbol: str, resolution: str):
|
||||
"""
|
||||
Create a DataSource adapter.
|
||||
|
||||
Args:
|
||||
datasource_id: Identifier for the datasource (e.g., 'ccxt', 'demo')
|
||||
symbol: Symbol to query (e.g., 'BTC/USD')
|
||||
resolution: Time resolution (e.g., '1', '5', '1D')
|
||||
"""
|
||||
self.datasource_id = datasource_id
|
||||
self.symbol = symbol
|
||||
self.resolution = resolution
|
||||
self.instance_name = f"ds_{datasource_id}_{symbol}_{resolution}".replace("/", "_").replace(":", "_")
|
||||
|
||||
def get_output_columns(self) -> List[str]:
|
||||
"""
|
||||
Get the columns provided by this datasource.
|
||||
|
||||
Note: This requires runtime resolution - the pipeline engine
|
||||
will need to query the actual DataSource to get the schema.
|
||||
|
||||
Returns:
|
||||
List of column names (placeholder - needs runtime resolution)
|
||||
"""
|
||||
# This will be resolved at runtime by the pipeline engine
|
||||
return []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DataSourceAdapter(datasource='{self.datasource_id}', symbol='{self.symbol}', resolution='{self.resolution}')"
|
||||
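Parameter validation happens at construction time; using the SimpleMovingAverage example class from the package docstring above, the behaviour looks like this (output column name per the docstring's prefixing convention):

sma = SimpleMovingAverage("sma_20", period=20)
print(sma.get_output_columns())   # -> ["sma_20_sma"] after instance prefixing

try:
    SimpleMovingAverage("sma_bad", period=0)       # violates min_value=1
except ValueError as err:
    print(err)   # Parameter 'period' must be >= 1, got 0

try:
    SimpleMovingAverage("sma_typo", perod=20)      # unknown parameter name
except ValueError as err:
    print(err)   # Unknown parameter 'perod' for indicator 'SMA'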
954
backend.old/src/indicator/custom_indicators.py
Normal file
954
backend.old/src/indicator/custom_indicators.py
Normal file
@@ -0,0 +1,954 @@
|
||||
"""
|
||||
Custom indicator implementations for TradingView indicators not in TA-Lib.
|
||||
|
||||
These indicators follow TA-Lib style conventions and integrate seamlessly
|
||||
with the indicator framework. All implementations are based on well-known,
|
||||
publicly documented formulas.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
import numpy as np
|
||||
|
||||
from datasource.schema import ColumnInfo
|
||||
from .base import Indicator
|
||||
from .schema import (
|
||||
ComputeContext,
|
||||
ComputeResult,
|
||||
IndicatorMetadata,
|
||||
IndicatorParameter,
|
||||
InputSchema,
|
||||
OutputSchema,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VWAP(Indicator):
|
||||
"""Volume Weighted Average Price - Most widely used institutional indicator."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="VWAP",
|
||||
display_name="VWAP",
|
||||
description="Volume Weighted Average Price - Average price weighted by volume",
|
||||
category="volume",
|
||||
parameters=[],
|
||||
use_cases=[
|
||||
"Institutional reference price",
|
||||
"Support/resistance levels",
|
||||
"Mean reversion trading"
|
||||
],
|
||||
references=["https://www.investopedia.com/terms/v/vwap.asp"],
|
||||
tags=["vwap", "volume", "institutional"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
ColumnInfo(name="close", type="float", description="Close price"),
|
||||
ColumnInfo(name="volume", type="float", description="Volume"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="vwap", type="float", description="Volume Weighted Average Price", nullable=True)
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
|
||||
volume = np.array([float(v) if v is not None else np.nan for v in context.get_column("volume")])
|
||||
|
||||
# Typical price
|
||||
typical_price = (high + low + close) / 3.0
|
||||
|
||||
# VWAP = cumsum(typical_price * volume) / cumsum(volume)
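# Note: sums accumulate over the entire input series; a session-anchored VWAP would reset them at each session boundary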
|
||||
cumulative_tp_vol = np.nancumsum(typical_price * volume)
|
||||
cumulative_vol = np.nancumsum(volume)
|
||||
|
||||
vwap = cumulative_tp_vol / cumulative_vol
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{"time": times[i], "vwap": float(vwap[i]) if not np.isnan(vwap[i]) else None}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class VWMA(Indicator):
|
||||
"""Volume Weighted Moving Average."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="VWMA",
|
||||
display_name="VWMA",
|
||||
description="Volume Weighted Moving Average - Moving average weighted by volume",
|
||||
category="overlap",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="length",
|
||||
type="int",
|
||||
description="Period length",
|
||||
default=20,
|
||||
min_value=1,
|
||||
required=False
|
||||
)
|
||||
],
|
||||
use_cases=["Volume-aware trend following", "Dynamic support/resistance"],
|
||||
references=["https://www.investopedia.com/articles/trading/11/trading-with-vwap-mvwap.asp"],
|
||||
tags=["vwma", "volume", "moving average"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="close", type="float", description="Close price"),
|
||||
ColumnInfo(name="volume", type="float", description="Volume"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="vwma", type="float", description="Volume Weighted Moving Average", nullable=True)
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
|
||||
volume = np.array([float(v) if v is not None else np.nan for v in context.get_column("volume")])
|
||||
length = self.params.get("length", 20)
|
||||
|
||||
vwma = np.full_like(close, np.nan)
|
||||
|
||||
for i in range(length - 1, len(close)):
|
||||
window_close = close[i - length + 1:i + 1]
|
||||
window_volume = volume[i - length + 1:i + 1]
|
||||
vwma[i] = np.sum(window_close * window_volume) / np.sum(window_volume)
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{"time": times[i], "vwma": float(vwma[i]) if not np.isnan(vwma[i]) else None}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class HullMA(Indicator):
|
||||
"""Hull Moving Average - Fast and smooth moving average."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="HMA",
|
||||
display_name="Hull Moving Average",
|
||||
description="Hull Moving Average - Reduces lag while maintaining smoothness",
|
||||
category="overlap",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="length",
|
||||
type="int",
|
||||
description="Period length",
|
||||
default=9,
|
||||
min_value=1,
|
||||
required=False
|
||||
)
|
||||
],
|
||||
use_cases=["Low-lag trend following", "Quick trend reversal detection"],
|
||||
references=["https://alanhull.com/hull-moving-average"],
|
||||
tags=["hma", "hull", "moving average", "low-lag"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="close", type="float", description="Close price"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="hma", type="float", description="Hull Moving Average", nullable=True)
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
|
||||
length = self.params.get("length", 9)
|
||||
|
||||
def wma(data, period):
|
||||
"""Weighted Moving Average."""
|
||||
weights = np.arange(1, period + 1)
|
||||
result = np.full_like(data, np.nan)
|
||||
for i in range(period - 1, len(data)):
|
||||
window = data[i - period + 1:i + 1]
|
||||
result[i] = np.sum(weights * window) / np.sum(weights)
|
||||
return result
|
||||
|
||||
# HMA = WMA(2 * WMA(price, n/2) - WMA(price, n), sqrt(n))
|
||||
half_length = max(1, length // 2)  # guard against length == 1, where n // 2 would be 0
|
||||
sqrt_length = int(np.sqrt(length))
|
||||
|
||||
wma_half = wma(close, half_length)
|
||||
wma_full = wma(close, length)
|
||||
raw_hma = 2 * wma_half - wma_full
|
||||
hma = wma(raw_hma, sqrt_length)
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{"time": times[i], "hma": float(hma[i]) if not np.isnan(hma[i]) else None}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class SuperTrend(Indicator):
|
||||
"""SuperTrend - Popular trend following indicator."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="SUPERTREND",
|
||||
display_name="SuperTrend",
|
||||
description="SuperTrend - Volatility-based trend indicator",
|
||||
category="overlap",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="length",
|
||||
type="int",
|
||||
description="ATR period",
|
||||
default=10,
|
||||
min_value=1,
|
||||
required=False
|
||||
),
|
||||
IndicatorParameter(
|
||||
name="multiplier",
|
||||
type="float",
|
||||
description="ATR multiplier",
|
||||
default=3.0,
|
||||
min_value=0.1,
|
||||
required=False
|
||||
)
|
||||
],
|
||||
use_cases=["Trend identification", "Stop loss placement", "Trend reversal signals"],
|
||||
references=["https://www.investopedia.com/articles/trading/08/supertrend-indicator.asp"],
|
||||
tags=["supertrend", "trend", "volatility"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
ColumnInfo(name="close", type="float", description="Close price"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="supertrend", type="float", description="SuperTrend value", nullable=True),
|
||||
ColumnInfo(name="direction", type="int", description="Trend direction (1=up, -1=down)", nullable=True)
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
|
||||
|
||||
length = self.params.get("length", 10)
|
||||
multiplier = self.params.get("multiplier", 3.0)
|
||||
|
||||
# Calculate ATR
|
||||
tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
|
||||
tr[0] = high[0] - low[0]
|
||||
|
||||
atr = np.full_like(close, np.nan)
|
||||
atr[length - 1] = np.mean(tr[:length])
|
||||
for i in range(length, len(tr)):
|
||||
atr[i] = (atr[i - 1] * (length - 1) + tr[i]) / length
|
||||
|
||||
# Calculate basic bands
|
||||
hl2 = (high + low) / 2
|
||||
basic_upper = hl2 + multiplier * atr
|
||||
basic_lower = hl2 - multiplier * atr
|
||||
|
||||
# Calculate final bands
|
||||
final_upper = np.full_like(close, np.nan)
|
||||
final_lower = np.full_like(close, np.nan)
|
||||
supertrend = np.full_like(close, np.nan)
|
||||
direction = np.full_like(close, np.nan)
|
||||
|
||||
for i in range(length, len(close)):
|
||||
if i == length:
|
||||
final_upper[i] = basic_upper[i]
|
||||
final_lower[i] = basic_lower[i]
|
||||
else:
|
||||
final_upper[i] = basic_upper[i] if basic_upper[i] < final_upper[i - 1] or close[i - 1] > final_upper[i - 1] else final_upper[i - 1]
|
||||
final_lower[i] = basic_lower[i] if basic_lower[i] > final_lower[i - 1] or close[i - 1] < final_lower[i - 1] else final_lower[i - 1]
|
||||
|
||||
if i == length:
|
||||
supertrend[i] = final_upper[i] if close[i] <= hl2[i] else final_lower[i]
|
||||
direction[i] = -1 if close[i] <= hl2[i] else 1
|
||||
else:
|
||||
if supertrend[i - 1] == final_upper[i - 1] and close[i] <= final_upper[i]:
|
||||
supertrend[i] = final_upper[i]
|
||||
direction[i] = -1
|
||||
elif supertrend[i - 1] == final_upper[i - 1] and close[i] > final_upper[i]:
|
||||
supertrend[i] = final_lower[i]
|
||||
direction[i] = 1
|
||||
elif supertrend[i - 1] == final_lower[i - 1] and close[i] >= final_lower[i]:
|
||||
supertrend[i] = final_lower[i]
|
||||
direction[i] = 1
|
||||
else:
|
||||
supertrend[i] = final_upper[i]
|
||||
direction[i] = -1
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{
|
||||
"time": times[i],
|
||||
"supertrend": float(supertrend[i]) if not np.isnan(supertrend[i]) else None,
|
||||
"direction": int(direction[i]) if not np.isnan(direction[i]) else None
|
||||
}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class DonchianChannels(Indicator):
|
||||
"""Donchian Channels - Breakout indicator using highest high and lowest low."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="DONCHIAN",
|
||||
display_name="Donchian Channels",
|
||||
description="Donchian Channels - Highest high and lowest low over period",
|
||||
category="overlap",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="length",
|
||||
type="int",
|
||||
description="Period length",
|
||||
default=20,
|
||||
min_value=1,
|
||||
required=False
|
||||
)
|
||||
],
|
||||
use_cases=["Breakout trading", "Volatility bands", "Support/resistance"],
|
||||
references=["https://www.investopedia.com/terms/d/donchianchannels.asp"],
|
||||
tags=["donchian", "channels", "breakout"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="upper", type="float", description="Upper channel", nullable=True),
|
||||
ColumnInfo(name="middle", type="float", description="Middle line", nullable=True),
|
||||
ColumnInfo(name="lower", type="float", description="Lower channel", nullable=True),
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
length = self.params.get("length", 20)
|
||||
|
||||
upper = np.full_like(high, np.nan)
|
||||
lower = np.full_like(low, np.nan)
|
||||
|
||||
for i in range(length - 1, len(high)):
|
||||
upper[i] = np.nanmax(high[i - length + 1:i + 1])
|
||||
lower[i] = np.nanmin(low[i - length + 1:i + 1])
|
||||
|
||||
middle = (upper + lower) / 2
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{
|
||||
"time": times[i],
|
||||
"upper": float(upper[i]) if not np.isnan(upper[i]) else None,
|
||||
"middle": float(middle[i]) if not np.isnan(middle[i]) else None,
|
||||
"lower": float(lower[i]) if not np.isnan(lower[i]) else None,
|
||||
}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class KeltnerChannels(Indicator):
    """Keltner Channels - ATR-based volatility bands."""

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="KELTNER",
            display_name="Keltner Channels",
            description="Keltner Channels - EMA with ATR-based bands",
            category="volatility",
            parameters=[
                IndicatorParameter(
                    name="length",
                    type="int",
                    description="EMA period",
                    default=20,
                    min_value=1,
                    required=False
                ),
                IndicatorParameter(
                    name="multiplier",
                    type="float",
                    description="ATR multiplier",
                    default=2.0,
                    min_value=0.1,
                    required=False
                ),
                IndicatorParameter(
                    name="atr_length",
                    type="int",
                    description="ATR period",
                    default=10,
                    min_value=1,
                    required=False
                )
            ],
            use_cases=["Volatility bands", "Overbought/oversold", "Trend strength"],
            references=["https://www.investopedia.com/terms/k/keltnerchannel.asp"],
            tags=["keltner", "channels", "volatility", "atr"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
            ColumnInfo(name="close", type="float", description="Close price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="upper", type="float", description="Upper band", nullable=True),
            ColumnInfo(name="middle", type="float", description="Middle line (EMA)", nullable=True),
            ColumnInfo(name="lower", type="float", description="Lower band", nullable=True),
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
        close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])

        length = self.params.get("length", 20)
        multiplier = self.params.get("multiplier", 2.0)
        atr_length = self.params.get("atr_length", 10)

        # Calculate EMA
        alpha = 2.0 / (length + 1)
        ema = np.full_like(close, np.nan)
        ema[0] = close[0]
        for i in range(1, len(close)):
            ema[i] = alpha * close[i] + (1 - alpha) * ema[i - 1]

        # Calculate ATR
        tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
        tr[0] = high[0] - low[0]

        atr = np.full_like(close, np.nan)
        atr[atr_length - 1] = np.mean(tr[:atr_length])
        for i in range(atr_length, len(tr)):
            atr[i] = (atr[i - 1] * (atr_length - 1) + tr[i]) / atr_length

        upper = ema + multiplier * atr
        lower = ema - multiplier * atr

        times = context.get_times()
        result_data = [
            {
                "time": times[i],
                "upper": float(upper[i]) if not np.isnan(upper[i]) else None,
                "middle": float(ema[i]) if not np.isnan(ema[i]) else None,
                "lower": float(lower[i]) if not np.isnan(lower[i]) else None,
            }
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)

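# Illustrative usage sketch (added; not part of the original module). Assumes the
# ComputeContext import already present at the top of this file; the OHLC values
# below are made up purely for demonstration.
def _keltner_usage_sketch() -> None:
    rows = [
        {"time": t, "high": 11.0 + t, "low": 9.0 + t, "close": 10.0 + t}
        for t in range(30)
    ]
    kc = KeltnerChannels(instance_name="kc_20", length=20, multiplier=2.0, atr_length=10)
    result = kc.compute(ComputeContext(data=rows))
    # Each output row carries time, upper, middle (EMA) and lower band values.
    assert set(result.data[0]) == {"time", "upper", "middle", "lower"}
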
class ChaikinMoneyFlow(Indicator):
    """Chaikin Money Flow - Volume-weighted accumulation/distribution."""

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="CMF",
            display_name="Chaikin Money Flow",
            description="Chaikin Money Flow - Measures buying and selling pressure",
            category="volume",
            parameters=[
                IndicatorParameter(
                    name="length",
                    type="int",
                    description="Period length",
                    default=20,
                    min_value=1,
                    required=False
                )
            ],
            use_cases=["Buying/selling pressure", "Trend confirmation", "Divergence analysis"],
            references=["https://www.investopedia.com/terms/c/chaikinoscillator.asp"],
            tags=["cmf", "chaikin", "volume", "money flow"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
            ColumnInfo(name="close", type="float", description="Close price"),
            ColumnInfo(name="volume", type="float", description="Volume"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="cmf", type="float", description="Chaikin Money Flow", nullable=True)
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
        close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
        volume = np.array([float(v) if v is not None else np.nan for v in context.get_column("volume")])
        length = self.params.get("length", 20)

        # Money Flow Multiplier
        mfm = ((close - low) - (high - close)) / (high - low)
        mfm = np.where(high == low, 0, mfm)

        # Money Flow Volume
        mfv = mfm * volume

        # CMF
        cmf = np.full_like(close, np.nan)
        for i in range(length - 1, len(close)):
            cmf[i] = np.nansum(mfv[i - length + 1:i + 1]) / np.nansum(volume[i - length + 1:i + 1])

        times = context.get_times()
        result_data = [
            {"time": times[i], "cmf": float(cmf[i]) if not np.isnan(cmf[i]) else None}
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)

class VortexIndicator(Indicator):
    """Vortex Indicator - Identifies trend direction and strength."""

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="VORTEX",
            display_name="Vortex Indicator",
            description="Vortex Indicator - Trend direction and strength",
            category="momentum",
            parameters=[
                IndicatorParameter(
                    name="length",
                    type="int",
                    description="Period length",
                    default=14,
                    min_value=1,
                    required=False
                )
            ],
            use_cases=["Trend identification", "Trend reversals", "Trend strength"],
            references=["https://www.investopedia.com/terms/v/vortex-indicator-vi.asp"],
            tags=["vortex", "trend", "momentum"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
            ColumnInfo(name="close", type="float", description="Close price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="vi_plus", type="float", description="Positive Vortex", nullable=True),
            ColumnInfo(name="vi_minus", type="float", description="Negative Vortex", nullable=True),
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
        close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
        length = self.params.get("length", 14)

        # Vortex Movement
        vm_plus = np.abs(high - np.roll(low, 1))
        vm_minus = np.abs(low - np.roll(high, 1))
        vm_plus[0] = 0
        vm_minus[0] = 0

        # True Range
        tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
        tr[0] = high[0] - low[0]

        # Vortex Indicator
        vi_plus = np.full_like(close, np.nan)
        vi_minus = np.full_like(close, np.nan)

        for i in range(length - 1, len(close)):
            sum_vm_plus = np.sum(vm_plus[i - length + 1:i + 1])
            sum_vm_minus = np.sum(vm_minus[i - length + 1:i + 1])
            sum_tr = np.sum(tr[i - length + 1:i + 1])

            if sum_tr != 0:
                vi_plus[i] = sum_vm_plus / sum_tr
                vi_minus[i] = sum_vm_minus / sum_tr

        times = context.get_times()
        result_data = [
            {
                "time": times[i],
                "vi_plus": float(vi_plus[i]) if not np.isnan(vi_plus[i]) else None,
                "vi_minus": float(vi_minus[i]) if not np.isnan(vi_minus[i]) else None,
            }
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)

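# Illustrative sketch (added; not part of the original module): one common way to
# consume the Vortex output is to look for VI+ crossing above VI-. This helper
# works directly on the result rows produced by compute() above.
def _vortex_cross_sketch(rows: list) -> list:
    crossings = []
    for prev, curr in zip(rows, rows[1:]):
        values = (prev["vi_plus"], prev["vi_minus"], curr["vi_plus"], curr["vi_minus"])
        if any(v is None for v in values):
            continue
        if prev["vi_plus"] <= prev["vi_minus"] and curr["vi_plus"] > curr["vi_minus"]:
            crossings.append(curr["time"])
    return crossings
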
class AwesomeOscillator(Indicator):
    """Awesome Oscillator - Bill Williams' momentum indicator."""

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="AO",
            display_name="Awesome Oscillator",
            description="Awesome Oscillator - Difference between 5 and 34 period SMAs of midpoint",
            category="momentum",
            parameters=[],
            use_cases=["Momentum shifts", "Trend reversals", "Divergence trading"],
            references=["https://www.investopedia.com/terms/a/awesomeoscillator.asp"],
            tags=["awesome", "oscillator", "momentum", "williams"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="ao", type="float", description="Awesome Oscillator", nullable=True)
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])

        midpoint = (high + low) / 2

        # SMA 5
        sma5 = np.full_like(midpoint, np.nan)
        for i in range(4, len(midpoint)):
            sma5[i] = np.mean(midpoint[i - 4:i + 1])

        # SMA 34
        sma34 = np.full_like(midpoint, np.nan)
        for i in range(33, len(midpoint)):
            sma34[i] = np.mean(midpoint[i - 33:i + 1])

        ao = sma5 - sma34

        times = context.get_times()
        result_data = [
            {"time": times[i], "ao": float(ao[i]) if not np.isnan(ao[i]) else None}
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)

class AcceleratorOscillator(Indicator):
    """Accelerator Oscillator - Rate of change of Awesome Oscillator."""

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="AC",
            display_name="Accelerator Oscillator",
            description="Accelerator Oscillator - Rate of change of Awesome Oscillator",
            category="momentum",
            parameters=[],
            use_cases=["Early momentum detection", "Trend acceleration", "Divergence signals"],
            references=["https://www.investopedia.com/terms/a/accelerator-oscillator.asp"],
            tags=["accelerator", "oscillator", "momentum", "williams"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="ac", type="float", description="Accelerator Oscillator", nullable=True)
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])

        midpoint = (high + low) / 2

        # Calculate AO first
        sma5 = np.full_like(midpoint, np.nan)
        for i in range(4, len(midpoint)):
            sma5[i] = np.mean(midpoint[i - 4:i + 1])

        sma34 = np.full_like(midpoint, np.nan)
        for i in range(33, len(midpoint)):
            sma34[i] = np.mean(midpoint[i - 33:i + 1])

        ao = sma5 - sma34

        # AC = AO - SMA(AO, 5)
        sma_ao = np.full_like(ao, np.nan)
        for i in range(4, len(ao)):
            if not np.isnan(ao[i - 4:i + 1]).any():
                sma_ao[i] = np.mean(ao[i - 4:i + 1])

        ac = ao - sma_ao

        times = context.get_times()
        result_data = [
            {"time": times[i], "ac": float(ac[i]) if not np.isnan(ac[i]) else None}
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)

class ChoppinessIndex(Indicator):
    """Choppiness Index - Determines if market is choppy or trending."""

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="CHOP",
            display_name="Choppiness Index",
            description="Choppiness Index - Measures market trendiness vs consolidation",
            category="volatility",
            parameters=[
                IndicatorParameter(
                    name="length",
                    type="int",
                    description="Period length",
                    default=14,
                    min_value=1,
                    required=False
                )
            ],
            use_cases=["Trend vs range identification", "Market regime detection"],
            references=["https://www.tradingview.com/support/solutions/43000501980/"],
            tags=["chop", "choppiness", "trend", "range"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
            ColumnInfo(name="close", type="float", description="Close price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="chop", type="float", description="Choppiness Index (0-100)", nullable=True)
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
        close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
        length = self.params.get("length", 14)

        # True Range
        tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
        tr[0] = high[0] - low[0]

        chop = np.full_like(close, np.nan)

        for i in range(length - 1, len(close)):
            sum_tr = np.sum(tr[i - length + 1:i + 1])
            high_low_diff = np.max(high[i - length + 1:i + 1]) - np.min(low[i - length + 1:i + 1])

            if high_low_diff != 0:
                chop[i] = 100 * np.log10(sum_tr / high_low_diff) / np.log10(length)

        times = context.get_times()
        result_data = [
            {"time": times[i], "chop": float(chop[i]) if not np.isnan(chop[i]) else None}
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)

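# Interpretation note (added; not part of the original module): Choppiness readings
# are commonly read against the 61.8 / 38.2 levels - values above roughly 61.8
# suggest a choppy, range-bound market, values below roughly 38.2 suggest a trending
# one. A tiny regime filter over the compute() output rows:
def _chop_regime_sketch(rows: list) -> list:
    return [
        {
            "time": r["time"],
            "regime": "choppy" if r["chop"] > 61.8 else "trending" if r["chop"] < 38.2 else "neutral",
        }
        for r in rows
        if r["chop"] is not None
    ]
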
class MassIndex(Indicator):
    """Mass Index - Identifies trend reversals based on range expansion."""

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="MASS",
            display_name="Mass Index",
            description="Mass Index - Identifies reversals when range narrows then expands",
            category="volatility",
            parameters=[
                IndicatorParameter(
                    name="fast_period",
                    type="int",
                    description="Fast EMA period",
                    default=9,
                    min_value=1,
                    required=False
                ),
                IndicatorParameter(
                    name="slow_period",
                    type="int",
                    description="Summation period for the EMA ratio",
                    default=25,
                    min_value=1,
                    required=False
                )
            ],
            use_cases=["Reversal detection", "Volatility analysis", "Bulge identification"],
            references=["https://www.investopedia.com/terms/m/mass-index.asp"],
            tags=["mass", "index", "volatility", "reversal"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="mass", type="float", description="Mass Index", nullable=True)
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])

        fast_period = self.params.get("fast_period", 9)
        slow_period = self.params.get("slow_period", 25)

        hl_range = high - low

        # Single EMA of the high-low range
        alpha1 = 2.0 / (fast_period + 1)
        ema1 = np.full_like(hl_range, np.nan)
        ema1[0] = hl_range[0]
        for i in range(1, len(hl_range)):
            ema1[i] = alpha1 * hl_range[i] + (1 - alpha1) * ema1[i - 1]

        # Double EMA (EMA of the EMA, same smoothing factor)
        ema2 = np.full_like(ema1, np.nan)
        ema2[0] = ema1[0]
        for i in range(1, len(ema1)):
            if not np.isnan(ema1[i]):
                ema2[i] = alpha1 * ema1[i] + (1 - alpha1) * ema2[i - 1]

        # EMA Ratio
        ema_ratio = ema1 / ema2

        # Mass Index: rolling sum of the EMA ratio over slow_period bars
        mass = np.full_like(hl_range, np.nan)
        for i in range(slow_period - 1, len(ema_ratio)):
            mass[i] = np.nansum(ema_ratio[i - slow_period + 1:i + 1])

        times = context.get_times()
        result_data = [
            {"time": times[i], "mass": float(mass[i]) if not np.isnan(mass[i]) else None}
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)

# Registry of all custom indicators
CUSTOM_INDICATORS = [
    VWAP,
    VWMA,
    HullMA,
    SuperTrend,
    DonchianChannels,
    KeltnerChannels,
    ChaikinMoneyFlow,
    VortexIndicator,
    AwesomeOscillator,
    AcceleratorOscillator,
    ChoppinessIndex,
    MassIndex,
]


def register_custom_indicators(registry) -> int:
    """
    Register all custom indicators with the registry.

    Args:
        registry: IndicatorRegistry instance

    Returns:
        Number of indicators registered
    """
    registered_count = 0

    for indicator_class in CUSTOM_INDICATORS:
        try:
            registry.register(indicator_class)
            registered_count += 1
            logger.debug(f"Registered custom indicator: {indicator_class.__name__}")
        except Exception as e:
            logger.warning(f"Failed to register custom indicator {indicator_class.__name__}: {e}")

    logger.info(f"Registered {registered_count} custom indicators")
    return registered_count
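
# Illustrative bootstrap sketch (added; not part of the original module). The import
# path mirrors the package layout used elsewhere in this backend and may differ in
# other deployments.
def _registry_bootstrap_sketch() -> None:
    from indicator.registry import IndicatorRegistry  # assumed package path

    registry = IndicatorRegistry()
    count = register_custom_indicators(registry)
    # Discover what was registered, e.g. all volatility-style indicators.
    volatility = [meta.name for meta in registry.search_by_category("volatility")]
    assert count == len(CUSTOM_INDICATORS)
    assert "KELTNER" in volatility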
439
backend.old/src/indicator/pipeline.py
Normal file
439
backend.old/src/indicator/pipeline.py
Normal file
@@ -0,0 +1,439 @@
|
||||
"""
|
||||
Pipeline execution engine for composable indicators.
|
||||
|
||||
Manages DAG construction, dependency resolution, incremental updates,
|
||||
and efficient data flow through indicator chains.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict, deque
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from datasource.base import DataSource
|
||||
from datasource.schema import ColumnInfo
|
||||
|
||||
from .base import DataSourceAdapter, Indicator
|
||||
from .schema import ComputeContext, ComputeResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PipelineNode:
|
||||
"""
|
||||
A node in the pipeline DAG.
|
||||
|
||||
Can be either a DataSource adapter or an Indicator instance.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
node: Union[DataSourceAdapter, Indicator],
|
||||
dependencies: List[str]
|
||||
):
|
||||
"""
|
||||
Create a pipeline node.
|
||||
|
||||
Args:
|
||||
node_id: Unique identifier for this node
|
||||
node: The DataSourceAdapter or Indicator instance
|
||||
dependencies: List of node_ids this node depends on
|
||||
"""
|
||||
self.node_id = node_id
|
||||
self.node = node
|
||||
self.dependencies = dependencies
|
||||
self.output_columns: List[str] = []
|
||||
self.cached_data: List[Dict[str, Any]] = []
|
||||
|
||||
def is_datasource(self) -> bool:
|
||||
"""Check if this node is a DataSource adapter."""
|
||||
return isinstance(self.node, DataSourceAdapter)
|
||||
|
||||
def is_indicator(self) -> bool:
|
||||
"""Check if this node is an Indicator."""
|
||||
return isinstance(self.node, Indicator)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"PipelineNode(id='{self.node_id}', node={self.node}, deps={self.dependencies})"
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""
|
||||
Execution engine for indicator DAGs.
|
||||
|
||||
Manages:
|
||||
- DAG construction and validation
|
||||
- Topological sorting for execution order
|
||||
- Data flow and caching
|
||||
- Incremental updates (only recompute what changed)
|
||||
- Schema validation
|
||||
"""
|
||||
|
||||
def __init__(self, datasource_registry):
|
||||
"""
|
||||
Initialize a pipeline.
|
||||
|
||||
Args:
|
||||
datasource_registry: DataSourceRegistry for resolving data sources
|
||||
"""
|
||||
self.datasource_registry = datasource_registry
|
||||
self.nodes: Dict[str, PipelineNode] = {}
|
||||
self.execution_order: List[str] = []
|
||||
self._dirty_nodes: Set[str] = set()
|
||||
|
||||
def add_datasource(
|
||||
self,
|
||||
node_id: str,
|
||||
datasource_name: str,
|
||||
symbol: str,
|
||||
resolution: str
|
||||
) -> None:
|
||||
"""
|
||||
Add a DataSource to the pipeline.
|
||||
|
||||
Args:
|
||||
node_id: Unique identifier for this node
|
||||
datasource_name: Name of the datasource in the registry
|
||||
symbol: Symbol to query
|
||||
resolution: Time resolution
|
||||
|
||||
Raises:
|
||||
ValueError: If node_id already exists or datasource not found
|
||||
"""
|
||||
if node_id in self.nodes:
|
||||
raise ValueError(f"Node '{node_id}' already exists in pipeline")
|
||||
|
||||
datasource = self.datasource_registry.get(datasource_name)
|
||||
if not datasource:
|
||||
raise ValueError(f"DataSource '{datasource_name}' not found in registry")
|
||||
|
||||
adapter = DataSourceAdapter(datasource_name, symbol, resolution)
|
||||
node = PipelineNode(node_id, adapter, dependencies=[])
|
||||
|
||||
self.nodes[node_id] = node
|
||||
self._invalidate_execution_order()
|
||||
|
||||
logger.info(f"Added DataSource node '{node_id}': {datasource_name}/{symbol}@{resolution}")
|
||||
|
||||
def add_indicator(
|
||||
self,
|
||||
node_id: str,
|
||||
indicator: Indicator,
|
||||
input_node_ids: List[str]
|
||||
) -> None:
|
||||
"""
|
||||
Add an Indicator to the pipeline.
|
||||
|
||||
Args:
|
||||
node_id: Unique identifier for this node
|
||||
indicator: Indicator instance
|
||||
input_node_ids: List of node IDs providing input data
|
||||
|
||||
Raises:
|
||||
ValueError: If node_id already exists, dependencies not found, or schema mismatch
|
||||
"""
|
||||
if node_id in self.nodes:
|
||||
raise ValueError(f"Node '{node_id}' already exists in pipeline")
|
||||
|
||||
# Validate dependencies exist
|
||||
for dep_id in input_node_ids:
|
||||
if dep_id not in self.nodes:
|
||||
raise ValueError(f"Dependency node '{dep_id}' not found in pipeline")
|
||||
|
||||
# TODO: Validate input schema matches available columns from dependencies
|
||||
# This requires merging output schemas from all input nodes
|
||||
|
||||
node = PipelineNode(node_id, indicator, dependencies=input_node_ids)
|
||||
self.nodes[node_id] = node
|
||||
self._invalidate_execution_order()
|
||||
|
||||
logger.info(f"Added Indicator node '{node_id}': {indicator} with inputs {input_node_ids}")
|
||||
|
||||
def remove_node(self, node_id: str) -> None:
|
||||
"""
|
||||
Remove a node from the pipeline.
|
||||
|
||||
Args:
|
||||
node_id: Node to remove
|
||||
|
||||
Raises:
|
||||
ValueError: If other nodes depend on this node
|
||||
"""
|
||||
if node_id not in self.nodes:
|
||||
return
|
||||
|
||||
# Check for dependent nodes
|
||||
dependents = [
|
||||
n.node_id for n in self.nodes.values()
|
||||
if node_id in n.dependencies
|
||||
]
|
||||
|
||||
if dependents:
|
||||
raise ValueError(
|
||||
f"Cannot remove node '{node_id}': nodes {dependents} depend on it"
|
||||
)
|
||||
|
||||
del self.nodes[node_id]
|
||||
self._invalidate_execution_order()
|
||||
|
||||
logger.info(f"Removed node '{node_id}' from pipeline")
|
||||
|
||||
def _invalidate_execution_order(self) -> None:
|
||||
"""Mark execution order as needing recomputation."""
|
||||
self.execution_order = []
|
||||
|
||||
def _compute_execution_order(self) -> List[str]:
|
||||
"""
|
||||
Compute topological sort of the DAG.
|
||||
|
||||
Returns:
|
||||
List of node IDs in execution order
|
||||
|
||||
Raises:
|
||||
ValueError: If DAG contains cycles
|
||||
"""
|
||||
if self.execution_order:
|
||||
return self.execution_order
|
||||
|
||||
# Kahn's algorithm for topological sort
|
||||
in_degree = {node_id: 0 for node_id in self.nodes}
|
||||
for node in self.nodes.values():
|
||||
for dep in node.dependencies:
|
||||
in_degree[node.node_id] += 1
|
||||
|
||||
queue = deque([node_id for node_id, degree in in_degree.items() if degree == 0])
|
||||
result = []
|
||||
|
||||
while queue:
|
||||
node_id = queue.popleft()
|
||||
result.append(node_id)
|
||||
|
||||
# Find all nodes that depend on this one
|
||||
for other_node in self.nodes.values():
|
||||
if node_id in other_node.dependencies:
|
||||
in_degree[other_node.node_id] -= 1
|
||||
if in_degree[other_node.node_id] == 0:
|
||||
queue.append(other_node.node_id)
|
||||
|
||||
if len(result) != len(self.nodes):
|
||||
raise ValueError("Pipeline contains cycles")
|
||||
|
||||
self.execution_order = result
|
||||
logger.debug(f"Computed execution order: {result}")
|
||||
return result
|
||||
|
||||
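
    # Worked example (added comment, not in the original file): for a chain
    #   candles (datasource) -> sma (indicator, depends on candles)
    #                        -> signal (indicator, depends on sma)
    # in-degrees start as {candles: 0, sma: 1, signal: 1}; Kahn's algorithm pops
    # "candles" first, which drops sma's in-degree to 0, then "sma", then
    # "signal", giving the execution order ["candles", "sma", "signal"].
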
    def execute(
        self,
        datasource_data: Dict[str, List[Dict[str, Any]]],
        incremental: bool = False,
        updated_from_time: Optional[int] = None
    ) -> Dict[str, List[Dict[str, Any]]]:
        """
        Execute the pipeline.

        Args:
            datasource_data: Mapping of DataSource node_id to input data
            incremental: Whether this is an incremental update
            updated_from_time: Timestamp of earliest updated row (for incremental)

        Returns:
            Dictionary mapping node_id to output data (all nodes)

        Raises:
            ValueError: If required datasource data is missing
        """
        execution_order = self._compute_execution_order()
        results: Dict[str, List[Dict[str, Any]]] = {}

        logger.info(
            f"Executing pipeline with {len(execution_order)} nodes "
            f"(incremental={incremental})"
        )

        for node_id in execution_order:
            node = self.nodes[node_id]

            if node.is_datasource():
                # DataSource node - get data from input
                if node_id not in datasource_data:
                    raise ValueError(
                        f"DataSource node '{node_id}' has no input data"
                    )
                results[node_id] = datasource_data[node_id]
                node.cached_data = results[node_id]
                logger.debug(f"DataSource node '{node_id}': {len(results[node_id])} rows")

            elif node.is_indicator():
                # Indicator node - compute from dependencies
                indicator = node.node

                # Merge input data from all dependencies
                input_data = self._merge_dependency_data(node.dependencies, results)

                # Create compute context
                context = ComputeContext(
                    data=input_data,
                    is_incremental=incremental,
                    updated_from_time=updated_from_time
                )

                # Execute indicator
                logger.debug(
                    f"Computing indicator '{node_id}' with {len(input_data)} input rows"
                )
                compute_result = indicator.compute(context)

                # Merge result with input data (adding prefixed columns)
                output_data = compute_result.merge_with_prefix(
                    indicator.instance_name,
                    input_data
                )

                results[node_id] = output_data
                node.cached_data = output_data
                logger.debug(f"Indicator node '{node_id}': {len(output_data)} rows")

        logger.info(f"Pipeline execution complete: {len(results)} nodes processed")
        return results

    def _merge_dependency_data(
        self,
        dependency_ids: List[str],
        results: Dict[str, List[Dict[str, Any]]]
    ) -> List[Dict[str, Any]]:
        """
        Merge data from multiple dependency nodes.

        Data is merged by time, with later dependencies overwriting earlier ones
        for conflicting column names.

        Args:
            dependency_ids: List of node IDs to merge
            results: Current execution results

        Returns:
            Merged data rows
        """
        if not dependency_ids:
            return []

        if len(dependency_ids) == 1:
            return results[dependency_ids[0]]

        # Build time-indexed data from first dependency
        merged: Dict[int, Dict[str, Any]] = {}
        for row in results[dependency_ids[0]]:
            merged[row["time"]] = row.copy()

        # Merge in additional dependencies
        for dep_id in dependency_ids[1:]:
            for row in results[dep_id]:
                time_key = row["time"]
                if time_key in merged:
                    # Merge columns (later dependencies win)
                    merged[time_key].update(row)
                else:
                    # New timestamp
                    merged[time_key] = row.copy()

        # Sort by time and return
        sorted_times = sorted(merged.keys())
        return [merged[t] for t in sorted_times]

    def get_node_output(self, node_id: str) -> Optional[List[Dict[str, Any]]]:
        """
        Get cached output data for a specific node.

        Args:
            node_id: Node identifier

        Returns:
            Cached data or None if not available
        """
        node = self.nodes.get(node_id)
        return node.cached_data if node else None

    def get_output_schema(self, node_id: str) -> List[ColumnInfo]:
        """
        Get the output schema for a specific node.

        Args:
            node_id: Node identifier

        Returns:
            List of ColumnInfo describing output columns

        Raises:
            ValueError: If node not found
        """
        node = self.nodes.get(node_id)
        if not node:
            raise ValueError(f"Node '{node_id}' not found")

        if node.is_datasource():
            # Would need to query the actual datasource at runtime
            # For now, return empty - this requires integration with DataSource
            return []

        elif node.is_indicator():
            indicator = node.node
            output_schema = indicator.get_output_schema(**indicator.params)
            prefixed_schema = output_schema.with_prefix(indicator.instance_name)
            return prefixed_schema.columns

        return []

    def validate_pipeline(self) -> Tuple[bool, Optional[str]]:
        """
        Validate the entire pipeline for correctness.

        Checks:
        - No cycles (already checked in execution order)
        - All dependencies exist (already checked in add_indicator)
        - Input schemas match output schemas (TODO)

        Returns:
            Tuple of (is_valid, error_message)
        """
        try:
            self._compute_execution_order()
            return True, None
        except ValueError as e:
            return False, str(e)

    def get_node_count(self) -> int:
        """Get the number of nodes in the pipeline."""
        return len(self.nodes)

    def get_indicator_count(self) -> int:
        """Get the number of indicator nodes in the pipeline."""
        return sum(1 for node in self.nodes.values() if node.is_indicator())

    def get_datasource_count(self) -> int:
        """Get the number of datasource nodes in the pipeline."""
        return sum(1 for node in self.nodes.values() if node.is_datasource())

    def describe(self) -> Dict[str, Any]:
        """
        Get a detailed description of the pipeline structure.

        Returns:
            Dictionary with pipeline metadata and structure
        """
        return {
            "node_count": self.get_node_count(),
            "datasource_count": self.get_datasource_count(),
            "indicator_count": self.get_indicator_count(),
            "nodes": [
                {
                    "id": node.node_id,
                    "type": "datasource" if node.is_datasource() else "indicator",
                    "node": str(node.node),
                    "dependencies": node.dependencies,
                    "cached_rows": len(node.cached_data)
                }
                for node in self.nodes.values()
            ],
            "execution_order": self.execution_order or self._compute_execution_order(),
            "is_valid": self.validate_pipeline()[0]
        }
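
# Illustrative end-to-end sketch (added; not part of the original file). It assumes a
# minimal stand-in for the datasource registry (anything exposing get()) and an
# already-constructed indicator whose input schema only needs "close"; the
# datasource/symbol names below are hypothetical.
def _pipeline_usage_sketch(sma_indicator: Indicator) -> Dict[str, List[Dict[str, Any]]]:
    class _StubRegistry:
        def get(self, name):
            return object()  # any truthy object satisfies add_datasource()

    pipeline = Pipeline(_StubRegistry())
    pipeline.add_datasource("candles", "binance", "BTCUSDT", "1h")
    pipeline.add_indicator("sma", sma_indicator, input_node_ids=["candles"])

    rows = [{"time": t, "close": 100.0 + t} for t in range(5)]
    # Returns per-node outputs; the "sma" node carries prefixed indicator columns.
    return pipeline.execute({"candles": rows})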
349
backend.old/src/indicator/registry.py
Normal file
349
backend.old/src/indicator/registry.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""
|
||||
Indicator registry for managing and discovering indicators.
|
||||
|
||||
Provides AI agents with a queryable catalog of available indicators,
|
||||
their capabilities, and metadata.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from .base import Indicator
|
||||
from .schema import IndicatorMetadata, InputSchema, OutputSchema
|
||||
|
||||
|
||||
class IndicatorRegistry:
|
||||
"""
|
||||
Central registry for indicator classes.
|
||||
|
||||
Enables:
|
||||
- Registration of indicator implementations
|
||||
- Discovery by name, category, or tags
|
||||
- Schema validation
|
||||
- AI agent tool generation
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._indicators: Dict[str, Type[Indicator]] = {}
|
||||
|
||||
def register(self, indicator_class: Type[Indicator]) -> None:
|
||||
"""
|
||||
Register an indicator class.
|
||||
|
||||
Args:
|
||||
indicator_class: Indicator class to register
|
||||
|
||||
Raises:
|
||||
ValueError: If an indicator with this name is already registered
|
||||
"""
|
||||
metadata = indicator_class.get_metadata()
|
||||
|
||||
if metadata.name in self._indicators:
|
||||
raise ValueError(
|
||||
f"Indicator '{metadata.name}' is already registered"
|
||||
)
|
||||
|
||||
self._indicators[metadata.name] = indicator_class
|
||||
|
||||
def unregister(self, name: str) -> None:
|
||||
"""
|
||||
Unregister an indicator class.
|
||||
|
||||
Args:
|
||||
name: Indicator class name
|
||||
"""
|
||||
self._indicators.pop(name, None)
|
||||
|
||||
def get(self, name: str) -> Optional[Type[Indicator]]:
|
||||
"""
|
||||
Get an indicator class by name.
|
||||
|
||||
Args:
|
||||
name: Indicator class name
|
||||
|
||||
Returns:
|
||||
Indicator class or None if not found
|
||||
"""
|
||||
return self._indicators.get(name)
|
||||
|
||||
def list_indicators(self) -> List[str]:
|
||||
"""
|
||||
Get names of all registered indicators.
|
||||
|
||||
Returns:
|
||||
List of indicator class names
|
||||
"""
|
||||
return list(self._indicators.keys())
|
||||
|
||||
def get_metadata(self, name: str) -> Optional[IndicatorMetadata]:
|
||||
"""
|
||||
Get metadata for a specific indicator.
|
||||
|
||||
Args:
|
||||
name: Indicator class name
|
||||
|
||||
Returns:
|
||||
IndicatorMetadata or None if not found
|
||||
"""
|
||||
indicator_class = self.get(name)
|
||||
if indicator_class:
|
||||
return indicator_class.get_metadata()
|
||||
return None
|
||||
|
||||
def get_all_metadata(self) -> List[IndicatorMetadata]:
|
||||
"""
|
||||
Get metadata for all registered indicators.
|
||||
|
||||
Useful for AI agent tool generation and discovery.
|
||||
|
||||
Returns:
|
||||
List of IndicatorMetadata for all registered indicators
|
||||
"""
|
||||
return [cls.get_metadata() for cls in self._indicators.values()]
|
||||
|
||||
def search_by_category(self, category: str) -> List[IndicatorMetadata]:
|
||||
"""
|
||||
Find indicators by category.
|
||||
|
||||
Args:
|
||||
category: Category name (e.g., 'momentum', 'trend', 'volatility')
|
||||
|
||||
Returns:
|
||||
List of matching indicator metadata
|
||||
"""
|
||||
results = []
|
||||
for indicator_class in self._indicators.values():
|
||||
metadata = indicator_class.get_metadata()
|
||||
if metadata.category.lower() == category.lower():
|
||||
results.append(metadata)
|
||||
return results
|
||||
|
||||
def search_by_tag(self, tag: str) -> List[IndicatorMetadata]:
|
||||
"""
|
||||
Find indicators by tag.
|
||||
|
||||
Args:
|
||||
tag: Tag to search for (case-insensitive)
|
||||
|
||||
Returns:
|
||||
List of matching indicator metadata
|
||||
"""
|
||||
tag_lower = tag.lower()
|
||||
results = []
|
||||
for indicator_class in self._indicators.values():
|
||||
metadata = indicator_class.get_metadata()
|
||||
if any(t.lower() == tag_lower for t in metadata.tags):
|
||||
results.append(metadata)
|
||||
return results
|
||||
|
||||
def search_by_text(self, query: str) -> List[IndicatorMetadata]:
|
||||
"""
|
||||
Full-text search across indicator names, descriptions, and use cases.
|
||||
|
||||
Args:
|
||||
query: Search query (case-insensitive)
|
||||
|
||||
Returns:
|
||||
List of matching indicator metadata, ranked by relevance
|
||||
"""
|
||||
query_lower = query.lower()
|
||||
results = []
|
||||
|
||||
for indicator_class in self._indicators.values():
|
||||
metadata = indicator_class.get_metadata()
|
||||
score = 0
|
||||
|
||||
# Check name (highest weight)
|
||||
if query_lower in metadata.name.lower():
|
||||
score += 10
|
||||
if query_lower in metadata.display_name.lower():
|
||||
score += 8
|
||||
|
||||
# Check description
|
||||
if query_lower in metadata.description.lower():
|
||||
score += 5
|
||||
|
||||
# Check use cases
|
||||
for use_case in metadata.use_cases:
|
||||
if query_lower in use_case.lower():
|
||||
score += 3
|
||||
|
||||
# Check tags
|
||||
for tag in metadata.tags:
|
||||
if query_lower in tag.lower():
|
||||
score += 2
|
||||
|
||||
if score > 0:
|
||||
results.append((score, metadata))
|
||||
|
||||
# Sort by score descending
|
||||
results.sort(key=lambda x: x[0], reverse=True)
|
||||
return [metadata for _, metadata in results]
|
||||
|
||||
def find_compatible_indicators(
|
||||
self,
|
||||
available_columns: List[str],
|
||||
column_types: Dict[str, str]
|
||||
) -> List[IndicatorMetadata]:
|
||||
"""
|
||||
Find indicators that can be computed from available columns.
|
||||
|
||||
Args:
|
||||
available_columns: List of column names available
|
||||
column_types: Mapping of column name to type
|
||||
|
||||
Returns:
|
||||
List of indicators whose input schema is satisfied
|
||||
"""
|
||||
from datasource.schema import ColumnInfo
|
||||
|
||||
# Build ColumnInfo list from available data
|
||||
available_schema = [
|
||||
ColumnInfo(
|
||||
name=name,
|
||||
type=column_types.get(name, "float"),
|
||||
description=f"Column {name}"
|
||||
)
|
||||
for name in available_columns
|
||||
]
|
||||
|
||||
results = []
|
||||
for indicator_class in self._indicators.values():
|
||||
input_schema = indicator_class.get_input_schema()
|
||||
if input_schema.matches(available_schema):
|
||||
results.append(indicator_class.get_metadata())
|
||||
|
||||
return results
|
||||
|
||||
def validate_indicator_chain(
|
||||
self,
|
||||
indicator_chain: List[tuple[str, Dict]]
|
||||
) -> tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Validate that a chain of indicators can be connected.
|
||||
|
||||
Args:
|
||||
indicator_chain: List of (indicator_name, params) tuples in execution order
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
if not indicator_chain:
|
||||
return True, None
|
||||
|
||||
# For now, just check that all indicators exist
|
||||
# More sophisticated DAG validation happens in the pipeline engine
|
||||
for indicator_name, params in indicator_chain:
|
||||
if indicator_name not in self._indicators:
|
||||
return False, f"Indicator '{indicator_name}' not found in registry"
|
||||
|
||||
return True, None
|
||||
|
||||
def get_input_schema(self, name: str) -> Optional[InputSchema]:
|
||||
"""
|
||||
Get input schema for a specific indicator.
|
||||
|
||||
Args:
|
||||
name: Indicator class name
|
||||
|
||||
Returns:
|
||||
InputSchema or None if not found
|
||||
"""
|
||||
indicator_class = self.get(name)
|
||||
if indicator_class:
|
||||
return indicator_class.get_input_schema()
|
||||
return None
|
||||
|
||||
def get_output_schema(self, name: str, **params) -> Optional[OutputSchema]:
|
||||
"""
|
||||
Get output schema for a specific indicator with given parameters.
|
||||
|
||||
Args:
|
||||
name: Indicator class name
|
||||
**params: Indicator parameters
|
||||
|
||||
Returns:
|
||||
OutputSchema or None if not found
|
||||
"""
|
||||
indicator_class = self.get(name)
|
||||
if indicator_class:
|
||||
return indicator_class.get_output_schema(**params)
|
||||
return None
|
||||
|
||||
def create_instance(self, name: str, instance_name: str, **params) -> Optional[Indicator]:
|
||||
"""
|
||||
Create an indicator instance with validation.
|
||||
|
||||
Args:
|
||||
name: Indicator class name
|
||||
instance_name: Unique instance name (for output column prefixing)
|
||||
**params: Indicator configuration parameters
|
||||
|
||||
Returns:
|
||||
Indicator instance or None if class not found
|
||||
|
||||
Raises:
|
||||
ValueError: If parameters are invalid
|
||||
"""
|
||||
indicator_class = self.get(name)
|
||||
if not indicator_class:
|
||||
return None
|
||||
|
||||
return indicator_class(instance_name=instance_name, **params)
|
||||
|
||||
def generate_ai_tool_spec(self) -> Dict:
|
||||
"""
|
||||
Generate a JSON specification for AI agent tools.
|
||||
|
||||
Creates a structured representation of all indicators that can be
|
||||
used to build agent tools for indicator selection and composition.
|
||||
|
||||
Returns:
|
||||
Dict suitable for AI agent tool registration
|
||||
"""
|
||||
tools = []
|
||||
|
||||
for indicator_class in self._indicators.values():
|
||||
metadata = indicator_class.get_metadata()
|
||||
|
||||
# Build parameter spec
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
|
||||
for param in metadata.parameters:
|
||||
param_spec = {
|
||||
"type": param.type,
|
||||
"description": param.description
|
||||
}
|
||||
|
||||
if param.default is not None:
|
||||
param_spec["default"] = param.default
|
||||
if param.min_value is not None:
|
||||
param_spec["minimum"] = param.min_value
|
||||
if param.max_value is not None:
|
||||
param_spec["maximum"] = param.max_value
|
||||
|
||||
parameters["properties"][param.name] = param_spec
|
||||
|
||||
if param.required:
|
||||
parameters["required"].append(param.name)
|
||||
|
||||
tool = {
|
||||
"name": f"indicator_{metadata.name.lower()}",
|
||||
"description": f"{metadata.display_name}: {metadata.description}",
|
||||
"category": metadata.category,
|
||||
"use_cases": metadata.use_cases,
|
||||
"tags": metadata.tags,
|
||||
"parameters": parameters,
|
||||
"input_schema": indicator_class.get_input_schema().model_dump(),
|
||||
"output_schema": indicator_class.get_output_schema().model_dump()
|
||||
}
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
return {
|
||||
"indicator_tools": tools,
|
||||
"total_count": len(tools)
|
||||
}
|
||||
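
# Illustrative sketch (added; not part of the original file): given plain OHLCV
# columns, find_compatible_indicators() returns every registered indicator whose
# InputSchema is satisfied by those columns.
def _compatibility_sketch(registry: "IndicatorRegistry") -> List[str]:
    columns = ["time", "open", "high", "low", "close", "volume"]
    types = {name: "float" for name in columns}
    types["time"] = "int"
    compatible = registry.find_compatible_indicators(columns, types)
    return [meta.name for meta in compatible]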
269
backend.old/src/indicator/schema.py
Normal file
269
backend.old/src/indicator/schema.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
Data models for the Indicator system.
|
||||
|
||||
Defines schemas for input/output specifications, computation context,
|
||||
and metadata for AI agent discovery.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from datasource.schema import ColumnInfo
|
||||
|
||||
|
||||
class InputSchema(BaseModel):
|
||||
"""
|
||||
Declares the required input columns for an Indicator.
|
||||
|
||||
Indicators match against any data source (DataSource or other Indicator)
|
||||
that provides columns satisfying this schema.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
required_columns: List[ColumnInfo] = Field(
|
||||
description="Columns that must be present in the input data"
|
||||
)
|
||||
optional_columns: List[ColumnInfo] = Field(
|
||||
default_factory=list,
|
||||
description="Columns that may be used if present but are not required"
|
||||
)
|
||||
time_column: str = Field(
|
||||
default="time",
|
||||
description="Name of the timestamp column (must be present)"
|
||||
)
|
||||
|
||||
def matches(self, available_columns: List[ColumnInfo]) -> bool:
|
||||
"""
|
||||
Check if available columns satisfy this input schema.
|
||||
|
||||
Args:
|
||||
available_columns: Columns provided by a data source
|
||||
|
||||
Returns:
|
||||
True if all required columns are present with compatible types
|
||||
"""
|
||||
available_map = {col.name: col for col in available_columns}
|
||||
|
||||
# Check time column exists
|
||||
if self.time_column not in available_map:
|
||||
return False
|
||||
|
||||
# Check all required columns exist with compatible types
|
||||
for required in self.required_columns:
|
||||
if required.name not in available_map:
|
||||
return False
|
||||
available = available_map[required.name]
|
||||
if available.type != required.type:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_missing_columns(self, available_columns: List[ColumnInfo]) -> List[str]:
|
||||
"""
|
||||
Get list of missing required column names.
|
||||
|
||||
Args:
|
||||
available_columns: Columns provided by a data source
|
||||
|
||||
Returns:
|
||||
List of missing column names
|
||||
"""
|
||||
available_names = {col.name for col in available_columns}
|
||||
missing = []
|
||||
|
||||
if self.time_column not in available_names:
|
||||
missing.append(self.time_column)
|
||||
|
||||
for required in self.required_columns:
|
||||
if required.name not in available_names:
|
||||
missing.append(required.name)
|
||||
|
||||
return missing
|
||||
|
||||
|
||||
class OutputSchema(BaseModel):
|
||||
"""
|
||||
Declares the output columns produced by an Indicator.
|
||||
|
||||
Column names will be automatically prefixed with the indicator instance name
|
||||
to avoid collisions in the pipeline.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
columns: List[ColumnInfo] = Field(
|
||||
description="Output columns produced by this indicator"
|
||||
)
|
||||
time_column: str = Field(
|
||||
default="time",
|
||||
description="Name of the timestamp column (passed through from input)"
|
||||
)
|
||||
|
||||
def with_prefix(self, prefix: str) -> "OutputSchema":
|
||||
"""
|
||||
Create a new OutputSchema with all column names prefixed.
|
||||
|
||||
Args:
|
||||
prefix: Prefix to add (e.g., indicator instance name)
|
||||
|
||||
Returns:
|
||||
New OutputSchema with prefixed column names
|
||||
"""
|
||||
prefixed_columns = [
|
||||
ColumnInfo(
|
||||
name=f"{prefix}_{col.name}" if col.name != self.time_column else col.name,
|
||||
type=col.type,
|
||||
description=col.description,
|
||||
unit=col.unit,
|
||||
nullable=col.nullable
|
||||
)
|
||||
for col in self.columns
|
||||
]
|
||||
return OutputSchema(
|
||||
columns=prefixed_columns,
|
||||
time_column=self.time_column
|
||||
)
|
||||
|
||||
|
||||
class IndicatorParameter(BaseModel):
|
||||
"""
|
||||
Metadata for a configurable indicator parameter.
|
||||
|
||||
Used for AI agent discovery and dynamic indicator instantiation.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
name: str = Field(description="Parameter name")
|
||||
type: Literal["int", "float", "string", "bool"] = Field(description="Parameter type")
|
||||
description: str = Field(description="Human and LLM-readable description")
|
||||
default: Optional[Any] = Field(default=None, description="Default value if not specified")
|
||||
required: bool = Field(default=False, description="Whether this parameter is required")
|
||||
min_value: Optional[float] = Field(default=None, description="Minimum value (for numeric types)")
|
||||
max_value: Optional[float] = Field(default=None, description="Maximum value (for numeric types)")
|
||||
|
||||
|
||||
class IndicatorMetadata(BaseModel):
|
||||
"""
|
||||
Rich metadata for an Indicator class.
|
||||
|
||||
Enables AI agents to discover, understand, and instantiate indicators.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
name: str = Field(description="Unique indicator class name (e.g., 'RSI', 'SMA', 'BollingerBands')")
|
||||
display_name: str = Field(description="Human-readable display name")
|
||||
description: str = Field(description="Detailed description of what this indicator computes and why it's useful")
|
||||
category: str = Field(
|
||||
description="Indicator category (e.g., 'momentum', 'trend', 'volatility', 'volume', 'custom')"
|
||||
)
|
||||
parameters: List[IndicatorParameter] = Field(
|
||||
default_factory=list,
|
||||
description="Configurable parameters for this indicator"
|
||||
)
|
||||
use_cases: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Common use cases and trading scenarios where this indicator is helpful"
|
||||
)
|
||||
references: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="URLs or citations for indicator methodology"
|
||||
)
|
||||
tags: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Searchable tags (e.g., 'oscillator', 'mean-reversion', 'price-based')"
|
||||
)
|
||||
|
||||
|
||||
class ComputeContext(BaseModel):
|
||||
"""
|
||||
Context passed to an Indicator's compute() method.
|
||||
|
||||
Contains the input data and metadata about what changed (for incremental updates).
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
data: List[Dict[str, Any]] = Field(
|
||||
description="Input data rows (time-ordered). Each dict is {column_name: value, time: timestamp}"
|
||||
)
|
||||
is_incremental: bool = Field(
|
||||
default=False,
|
||||
description="True if this is an incremental update (only new/changed rows), False for full recompute"
|
||||
)
|
||||
updated_from_time: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Unix timestamp (ms) of the earliest updated row (for incremental updates)"
|
||||
)
|
||||
|
||||
def get_column(self, name: str) -> List[Any]:
|
||||
"""
|
||||
Extract a single column as a list of values.
|
||||
|
||||
Args:
|
||||
name: Column name
|
||||
|
||||
Returns:
|
||||
List of values in time order
|
||||
"""
|
||||
return [row.get(name) for row in self.data]
|
||||
|
||||
def get_times(self) -> List[int]:
|
||||
"""
|
||||
Get the time column as a list.
|
||||
|
||||
Returns:
|
||||
List of timestamps in order
|
||||
"""
|
||||
return [row["time"] for row in self.data]
|
||||
|
||||
|
||||
class ComputeResult(BaseModel):
|
||||
"""
|
||||
Result from an Indicator's compute() method.
|
||||
|
||||
Contains the computed output data with proper column naming.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
data: List[Dict[str, Any]] = Field(
|
||||
description="Output data rows (time-ordered). Must include time column."
|
||||
)
|
||||
is_partial: bool = Field(
|
||||
default=False,
|
||||
description="True if this result only contains updates (for incremental computation)"
|
||||
)
|
||||
|
||||
def merge_with_prefix(self, prefix: str, existing_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Merge this result into existing data with column name prefixing.
|
||||
|
||||
Args:
|
||||
prefix: Prefix to add to all column names except time
|
||||
existing_data: Existing data to merge with (matched by time)
|
||||
|
||||
Returns:
|
||||
Merged data with prefixed columns added
|
||||
"""
|
||||
# Build a time index for new data
|
||||
time_index = {row["time"]: row for row in self.data}
|
||||
|
||||
# Merge into existing data
|
||||
result = []
|
||||
for existing_row in existing_data:
|
||||
row_time = existing_row["time"]
|
||||
merged_row = existing_row.copy()
|
||||
|
||||
if row_time in time_index:
|
||||
new_row = time_index[row_time]
|
||||
for key, value in new_row.items():
|
||||
if key != "time":
|
||||
merged_row[f"{prefix}_{key}"] = value
|
||||
|
||||
result.append(merged_row)
|
||||
|
||||
return result
|
||||
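
# Illustrative sketch (added; not part of the original file) showing how
# ComputeResult.merge_with_prefix() namespaces indicator outputs by instance name.
def _merge_with_prefix_sketch() -> List[Dict[str, Any]]:
    existing = [{"time": 1, "close": 10.0}, {"time": 2, "close": 11.0}]
    result = ComputeResult(data=[{"time": 2, "rsi": 55.0}])
    merged = result.merge_with_prefix("rsi_14", existing)
    # -> [{"time": 1, "close": 10.0}, {"time": 2, "close": 11.0, "rsi_14_rsi": 55.0}]
    return merged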
449
backend.old/src/indicator/talib_adapter.py
Normal file
449
backend.old/src/indicator/talib_adapter.py
Normal file
@@ -0,0 +1,449 @@
|
||||
"""
|
||||
TA-Lib indicator adapter.
|
||||
|
||||
Provides automatic registration of all TA-Lib technical indicators
|
||||
as composable Indicator instances.
|
||||
|
||||
Installation Requirements:
|
||||
--------------------------
|
||||
TA-Lib requires both the C library and Python wrapper:
|
||||
|
||||
1. Install TA-Lib C library:
|
||||
- Ubuntu/Debian: sudo apt-get install libta-lib-dev
|
||||
- macOS: brew install ta-lib
|
||||
- From source: https://ta-lib.org/install.html
|
||||
|
||||
2. Install Python wrapper (already in requirements.txt):
|
||||
pip install TA-Lib
|
||||
|
||||
Usage:
|
||||
------
|
||||
from indicator.talib_adapter import register_all_talib_indicators
|
||||
|
||||
# Auto-register all TA-Lib indicators
|
||||
registry = IndicatorRegistry()
|
||||
register_all_talib_indicators(registry)
|
||||
|
||||
# Now you can use any TA-Lib indicator
|
||||
sma = registry.create_instance("SMA", "sma_20", period=20)
|
||||
rsi = registry.create_instance("RSI", "rsi_14", timeperiod=14)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import talib
|
||||
from talib import abstract
|
||||
TALIB_AVAILABLE = True
|
||||
except ImportError:
|
||||
TALIB_AVAILABLE = False
|
||||
talib = None
|
||||
abstract = None
|
||||
|
||||
from datasource.schema import ColumnInfo
|
||||
|
||||
from .base import Indicator
|
||||
from .schema import (
|
||||
ComputeContext,
|
||||
ComputeResult,
|
||||
IndicatorMetadata,
|
||||
IndicatorParameter,
|
||||
InputSchema,
|
||||
OutputSchema,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Mapping of TA-Lib parameter types to our schema types
|
||||
TALIB_TYPE_MAP = {
|
||||
"double": "float",
|
||||
"double[]": "float",
|
||||
"int": "int",
|
||||
"str": "string",
|
||||
}
|
||||
|
||||
# Categorization of TA-Lib functions
|
||||
TALIB_CATEGORIES = {
|
||||
"overlap": ["SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA", "KAMA", "MAMA", "T3",
|
||||
"BBANDS", "MIDPOINT", "MIDPRICE", "SAR", "SAREXT", "HT_TRENDLINE"],
|
||||
"momentum": ["RSI", "MOM", "ROC", "ROCP", "ROCR", "ROCR100", "TRIX", "CMO", "DX",
|
||||
"ADX", "ADXR", "APO", "PPO", "MACD", "MACDEXT", "MACDFIX", "MFI",
|
||||
"STOCH", "STOCHF", "STOCHRSI", "WILLR", "CCI", "AROON", "AROONOSC",
|
||||
"BOP", "MINUS_DI", "MINUS_DM", "PLUS_DI", "PLUS_DM", "ULTOSC"],
|
||||
"volume": ["AD", "ADOSC", "OBV"],
|
||||
"volatility": ["ATR", "NATR", "TRANGE"],
|
||||
"price": ["AVGPRICE", "MEDPRICE", "TYPPRICE", "WCLPRICE"],
|
||||
"cycle": ["HT_DCPERIOD", "HT_DCPHASE", "HT_PHASOR", "HT_SINE", "HT_TRENDMODE"],
|
||||
"pattern": ["CDL2CROWS", "CDL3BLACKCROWS", "CDL3INSIDE", "CDL3LINESTRIKE",
|
||||
"CDL3OUTSIDE", "CDL3STARSINSOUTH", "CDL3WHITESOLDIERS", "CDLABANDONEDBABY",
|
||||
"CDLADVANCEBLOCK", "CDLBELTHOLD", "CDLBREAKAWAY", "CDLCLOSINGMARUBOZU",
|
||||
"CDLCONCEALBABYSWALL", "CDLCOUNTERATTACK", "CDLDARKCLOUDCOVER", "CDLDOJI",
|
||||
"CDLDOJISTAR", "CDLDRAGONFLYDOJI", "CDLENGULFING", "CDLEVENINGDOJISTAR",
|
||||
"CDLEVENINGSTAR", "CDLGAPSIDESIDEWHITE", "CDLGRAVESTONEDOJI", "CDLHAMMER",
|
||||
"CDLHANGINGMAN", "CDLHARAMI", "CDLHARAMICROSS", "CDLHIGHWAVE", "CDLHIKKAKE",
|
||||
"CDLHIKKAKEMOD", "CDLHOMINGPIGEON", "CDLIDENTICAL3CROWS", "CDLINNECK",
|
||||
"CDLINVERTEDHAMMER", "CDLKICKING", "CDLKICKINGBYLENGTH", "CDLLADDERBOTTOM",
|
||||
"CDLLONGLEGGEDDOJI", "CDLLONGLINE", "CDLMARUBOZU", "CDLMATCHINGLOW",
|
||||
"CDLMATHOLD", "CDLMORNINGDOJISTAR", "CDLMORNINGSTAR", "CDLONNECK",
|
||||
"CDLPIERCING", "CDLRICKSHAWMAN", "CDLRISEFALL3METHODS", "CDLSEPARATINGLINES",
|
||||
"CDLSHOOTINGSTAR", "CDLSHORTLINE", "CDLSPINNINGTOP", "CDLSTALLEDPATTERN",
|
||||
"CDLSTICKSANDWICH", "CDLTAKURI", "CDLTASUKIGAP", "CDLTHRUSTING", "CDLTRISTAR",
|
||||
"CDLUNIQUE3RIVER", "CDLUPSIDEGAP2CROWS", "CDLXSIDEGAP3METHODS"],
|
||||
"statistic": ["BETA", "CORREL", "LINEARREG", "LINEARREG_ANGLE", "LINEARREG_INTERCEPT",
|
||||
"LINEARREG_SLOPE", "STDDEV", "TSF", "VAR"],
|
||||
"math": ["ADD", "DIV", "MAX", "MAXINDEX", "MIN", "MININDEX", "MINMAX", "MINMAXINDEX",
|
||||
"MULT", "SUB", "SUM"],
|
||||
}
|
||||
|
||||
|
||||
def _get_function_category(func_name: str) -> str:
|
||||
"""Determine the category of a TA-Lib function."""
|
||||
for category, functions in TALIB_CATEGORIES.items():
|
||||
if func_name in functions:
|
||||
return category
|
||||
return "other"
|
||||
|
||||
|
||||
class TALibIndicator(Indicator):
|
||||
"""
|
||||
Generic adapter for TA-Lib technical indicators.
|
||||
|
||||
Wraps any TA-Lib function to work within the composable indicator framework.
|
||||
Handles parameter mapping, input validation, and output formatting.
|
||||
"""
|
||||
|
||||
# Class variable to store the TA-Lib function name
|
||||
talib_function_name: str = None
|
||||
|
||||
def __init__(self, instance_name: str, **params):
|
||||
"""
|
||||
Initialize a TA-Lib indicator.
|
||||
|
||||
Args:
|
||||
instance_name: Unique name for this instance
|
||||
**params: TA-Lib function parameters
|
||||
"""
|
||||
if not TALIB_AVAILABLE:
|
||||
raise ImportError(
|
||||
"TA-Lib is not installed. Please install the TA-Lib C library "
|
||||
"and Python wrapper. See indicator/talib_adapter.py for instructions."
|
||||
)
|
||||
|
||||
super().__init__(instance_name, **params)
|
||||
self._talib_func = abstract.Function(self.talib_function_name)
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
"""Get metadata from TA-Lib function info."""
|
||||
if not TALIB_AVAILABLE:
|
||||
raise ImportError("TA-Lib is not installed")
|
||||
|
||||
func = abstract.Function(cls.talib_function_name)
|
||||
info = func.info
|
||||
|
||||
# Build parameters list from TA-Lib function info
|
||||
parameters = []
|
||||
for param_name, param_info in info.get("parameters", {}).items():
|
||||
# Handle case where param_info is a simple value (int/float) instead of a dict
|
||||
if isinstance(param_info, dict):
|
||||
param_type = TALIB_TYPE_MAP.get(param_info.get("type", "double"), "float")
|
||||
default_value = param_info.get("default_value")
|
||||
else:
|
||||
# param_info is a simple value (default), infer type from the value
|
||||
if isinstance(param_info, int):
|
||||
param_type = "int"
|
||||
elif isinstance(param_info, float):
|
||||
param_type = "float"
|
||||
else:
|
||||
param_type = "float" # Default to float
|
||||
default_value = param_info
|
||||
|
||||
parameters.append(
|
||||
IndicatorParameter(
|
||||
name=param_name,
|
||||
type=param_type,
|
||||
description=f"TA-Lib parameter: {param_name}",
|
||||
default=default_value,
|
||||
required=False
|
||||
)
|
||||
)
|
||||
|
||||
# Get function group/category
|
||||
category = _get_function_category(cls.talib_function_name)
|
||||
|
||||
# Build display name (strip the CDL prefix for candlestick pattern functions)
|
||||
display_name = cls.talib_function_name
|
||||
if display_name.startswith("CDL"):
|
||||
display_name = display_name[3:] # Remove CDL prefix for patterns
|
||||
|
||||
return IndicatorMetadata(
|
||||
name=cls.talib_function_name,
|
||||
display_name=display_name,
|
||||
description=info.get("display_name", f"TA-Lib {cls.talib_function_name} indicator"),
|
||||
category=category,
|
||||
parameters=parameters,
|
||||
use_cases=[f"Technical analysis using {cls.talib_function_name}"],
|
||||
references=["https://ta-lib.org/function.html"],
|
||||
tags=["talib", category, cls.talib_function_name.lower()]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
"""
|
||||
Get input schema from TA-Lib function requirements.
|
||||
|
||||
Most TA-Lib functions use OHLCV data, but some use subsets.
|
||||
"""
|
||||
if not TALIB_AVAILABLE:
|
||||
raise ImportError("TA-Lib is not installed")
|
||||
|
||||
func = abstract.Function(cls.talib_function_name)
|
||||
info = func.info
|
||||
input_names = info.get("input_names", {})
|
||||
|
||||
required_columns = []
|
||||
|
||||
# Map TA-Lib input names to our schema
|
||||
if "prices" in input_names:
|
||||
price_inputs = input_names["prices"]
|
||||
if "open" in price_inputs:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="open", type="float", description="Opening price")
|
||||
)
|
||||
if "high" in price_inputs:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="high", type="float", description="High price")
|
||||
)
|
||||
if "low" in price_inputs:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="low", type="float", description="Low price")
|
||||
)
|
||||
if "close" in price_inputs:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="close", type="float", description="Closing price")
|
||||
)
|
||||
if "volume" in price_inputs:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="volume", type="float", description="Trading volume")
|
||||
)
|
||||
|
||||
# Handle functions that take generic price arrays
|
||||
if "price" in input_names:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="close", type="float", description="Price (typically close)")
|
||||
)
|
||||
|
||||
# If no specific inputs found, assume close price
|
||||
if not required_columns:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="close", type="float", description="Closing price")
|
||||
)
|
||||
|
||||
return InputSchema(required_columns=required_columns)
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
"""Get output schema from TA-Lib function outputs."""
|
||||
if not TALIB_AVAILABLE:
|
||||
raise ImportError("TA-Lib is not installed")
|
||||
|
||||
func = abstract.Function(cls.talib_function_name)
|
||||
info = func.info
|
||||
output_names = info.get("output_names", [])
|
||||
|
||||
columns = []
|
||||
|
||||
# Most TA-Lib functions output one or more float arrays
|
||||
if isinstance(output_names, list):
|
||||
for output_name in output_names:
|
||||
columns.append(
|
||||
ColumnInfo(
|
||||
name=output_name.lower(),
|
||||
type="float",
|
||||
description=f"{cls.talib_function_name} output: {output_name}",
|
||||
nullable=True # TA-Lib often has NaN for initial periods
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Single output, use function name
|
||||
columns.append(
|
||||
ColumnInfo(
|
||||
name=cls.talib_function_name.lower(),
|
||||
type="float",
|
||||
description=f"{cls.talib_function_name} indicator value",
|
||||
nullable=True
|
||||
)
|
||||
)
|
||||
|
||||
return OutputSchema(columns=columns)
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
"""Compute indicator using TA-Lib."""
|
||||
# Extract input columns
|
||||
input_data = {}
|
||||
|
||||
# Get the function's expected inputs
|
||||
info = self._talib_func.info
|
||||
input_names = info.get("input_names", {})
|
||||
|
||||
# Prepare input arrays
|
||||
if "prices" in input_names:
|
||||
price_inputs = input_names["prices"]
|
||||
for price_type in price_inputs:
|
||||
column_data = context.get_column(price_type)
|
||||
# Convert to numpy array, replacing None with NaN
|
||||
input_data[price_type] = np.array(
|
||||
[float(v) if v is not None else np.nan for v in column_data]
|
||||
)
|
||||
elif "price" in input_names:
|
||||
# Generic price input, use close
|
||||
column_data = context.get_column("close")
|
||||
input_data["price"] = np.array(
|
||||
[float(v) if v is not None else np.nan for v in column_data]
|
||||
)
|
||||
else:
|
||||
# Default to close if no inputs specified
|
||||
column_data = context.get_column("close")
|
||||
input_data["close"] = np.array(
|
||||
[float(v) if v is not None else np.nan for v in column_data]
|
||||
)
|
||||
|
||||
# Set parameters on the function
|
||||
self._talib_func.parameters = self.params
|
||||
|
||||
# Execute TA-Lib function
|
||||
try:
|
||||
output = self._talib_func(input_data)
|
||||
except Exception as e:
|
||||
logger.error(f"TA-Lib function {self.talib_function_name} failed: {e}")
|
||||
raise ValueError(f"TA-Lib computation failed: {e}") from e
|
||||
|
||||
# Format output
|
||||
times = context.get_times()
|
||||
output_names = info.get("output_names", [])
|
||||
|
||||
# Handle single vs multiple outputs
|
||||
if isinstance(output, np.ndarray):
|
||||
# Single output
|
||||
output_name = output_names[0].lower() if output_names else self.talib_function_name.lower()
|
||||
result_data = [
|
||||
{
|
||||
"time": times[i],
|
||||
output_name: float(output[i]) if not np.isnan(output[i]) else None
|
||||
}
|
||||
for i in range(len(times))
|
||||
]
|
||||
elif isinstance(output, tuple):
|
||||
# Multiple outputs
|
||||
result_data = []
|
||||
for i in range(len(times)):
|
||||
row = {"time": times[i]}
|
||||
for j, output_array in enumerate(output):
|
||||
output_name = output_names[j].lower() if j < len(output_names) else f"output_{j}"
|
||||
row[output_name] = float(output_array[i]) if not np.isnan(output_array[i]) else None
|
||||
result_data.append(row)
|
||||
else:
|
||||
raise ValueError(f"Unexpected TA-Lib output type: {type(output)}")
|
||||
|
||||
return ComputeResult(
|
||||
data=result_data,
|
||||
is_partial=context.is_incremental
|
||||
)
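# Note on the rows produced by compute() above (illustrative values): each row
# carries the bar time plus one key per TA-Lib output name, lower-cased, with
# NaN warm-up periods mapped to None. For a multi-output function such as MACD
# the TA-Lib output names are typically "macd", "macdsignal" and "macdhist":
#
#   {"time": 1700000000, "macd": 1.2, "macdsignal": 0.9, "macdhist": 0.3}
#
# Single-output functions use whatever name TA-Lib reports (often "real").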
|
||||
|
||||
|
||||
def create_talib_indicator_class(func_name: str) -> type:
|
||||
"""
|
||||
Dynamically create an Indicator class for a TA-Lib function.
|
||||
|
||||
Args:
|
||||
func_name: TA-Lib function name (e.g., 'SMA', 'RSI')
|
||||
|
||||
Returns:
|
||||
Indicator class for this function
|
||||
"""
|
||||
return type(
|
||||
f"TALib_{func_name}",
|
||||
(TALibIndicator,),
|
||||
{"talib_function_name": func_name}
|
||||
)
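# Illustrative use of the factory above (a sketch, not part of the module;
# assumes the TA-Lib C library and Python wrapper are installed, and the
# instance name / parameter values are arbitrary examples):
#
#   RSIIndicator = create_talib_indicator_class("RSI")   # -> class TALib_RSI
#   meta = RSIIndicator.get_metadata()                    # parameters derived from TA-Lib's function info
#   rsi = RSIIndicator("rsi_14", timeperiod=14)           # raises ImportError if TA-Lib is missing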
|
||||
|
||||
|
||||
def register_all_talib_indicators(registry, only_tradingview_supported: bool = True) -> int:
|
||||
"""
|
||||
Auto-register all available TA-Lib indicators with the registry.
|
||||
|
||||
Args:
|
||||
registry: IndicatorRegistry instance
|
||||
only_tradingview_supported: If True, only register indicators that have
|
||||
TradingView equivalents (default: True)
|
||||
|
||||
Returns:
|
||||
Number of indicators registered
|
||||
|
||||
Raises:
|
||||
ImportError: If TA-Lib is not installed
|
||||
"""
|
||||
if not TALIB_AVAILABLE:
|
||||
logger.warning(
|
||||
"TA-Lib is not installed. Skipping TA-Lib indicator registration. "
|
||||
"Install TA-Lib C library and Python wrapper to enable TA-Lib indicators."
|
||||
)
|
||||
return 0
|
||||
|
||||
# Get list of supported indicators if filtering is enabled
|
||||
from .tv_mapping import is_indicator_supported
|
||||
|
||||
# Get all TA-Lib functions
|
||||
func_groups = talib.get_function_groups()
|
||||
all_functions = []
|
||||
for group, functions in func_groups.items():
|
||||
all_functions.extend(functions)
|
||||
|
||||
# Remove duplicates
|
||||
all_functions = sorted(set(all_functions))
|
||||
|
||||
registered_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
for func_name in all_functions:
|
||||
try:
|
||||
# Skip if filtering enabled and indicator not supported in TradingView
|
||||
if only_tradingview_supported and not is_indicator_supported(func_name):
|
||||
skipped_count += 1
|
||||
logger.debug(f"Skipping TA-Lib function {func_name} - not supported in TradingView")
|
||||
continue
|
||||
|
||||
# Create indicator class for this function
|
||||
indicator_class = create_talib_indicator_class(func_name)
|
||||
|
||||
# Register with the registry
|
||||
registry.register(indicator_class)
|
||||
registered_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register TA-Lib function {func_name}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Registered {registered_count} TA-Lib indicators (skipped {skipped_count} unsupported)")
|
||||
return registered_count
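# Sketch of the registration flow (assumes the IndicatorRegistry wired up in
# main.py; the exact count depends on the installed TA-Lib build):
#
#   registry = IndicatorRegistry()
#   count = register_all_talib_indicators(registry)  # TradingView-supported subset only
#   # or, to register everything TA-Lib exposes:
#   # count = register_all_talib_indicators(registry, only_tradingview_supported=False)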
|
||||
|
||||
|
||||
def get_talib_version() -> Optional[str]:
|
||||
"""
|
||||
Get the installed TA-Lib version.
|
||||
|
||||
Returns:
|
||||
Version string or None if not installed
|
||||
"""
|
||||
if TALIB_AVAILABLE:
|
||||
return talib.__version__
|
||||
return None
|
||||
|
||||
|
||||
def is_talib_available() -> bool:
|
||||
"""Check if TA-Lib is available."""
|
||||
return TALIB_AVAILABLE
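# Typical guard before enabling TA-Lib-backed features (illustrative sketch):
#
#   if is_talib_available():
#       logger.info(f"TA-Lib {get_talib_version()} detected")
#   else:
#       logger.warning("TA-Lib not installed; TA-Lib indicators disabled")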
|
||||
360
backend.old/src/indicator/tv_mapping.py
Normal file
360
backend.old/src/indicator/tv_mapping.py
Normal file
@@ -0,0 +1,360 @@
|
||||
"""
|
||||
Mapping layer between TA-Lib indicators and TradingView indicators.
|
||||
|
||||
This module provides bidirectional conversion between our internal TA-Lib-based
|
||||
indicator representation and TradingView's indicator system.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, Tuple, List
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Mapping of TA-Lib indicator names to TradingView indicator names
|
||||
# Only includes indicators that are present in BOTH systems (inner join)
|
||||
# Format: {talib_name: tv_name}
|
||||
TALIB_TO_TV_NAMES = {
|
||||
# Overlap Studies (14)
|
||||
"SMA": "Moving Average",
|
||||
"EMA": "Moving Average Exponential",
|
||||
"WMA": "Weighted Moving Average",
|
||||
"DEMA": "DEMA",
|
||||
"TEMA": "TEMA",
|
||||
"TRIMA": "Triangular Moving Average",
|
||||
"KAMA": "KAMA",
|
||||
"MAMA": "MESA Adaptive Moving Average",
|
||||
"T3": "T3",
|
||||
"BBANDS": "Bollinger Bands",
|
||||
"MIDPOINT": "Midpoint",
|
||||
"MIDPRICE": "Midprice",
|
||||
"SAR": "Parabolic SAR",
|
||||
"HT_TRENDLINE": "Hilbert Transform - Instantaneous Trendline",
|
||||
|
||||
# Momentum Indicators (21)
|
||||
"RSI": "Relative Strength Index",
|
||||
"MOM": "Momentum",
|
||||
"ROC": "Rate of Change",
|
||||
"TRIX": "TRIX",
|
||||
"CMO": "Chande Momentum Oscillator",
|
||||
"DX": "Directional Movement Index",
|
||||
"ADX": "Average Directional Movement Index",
|
||||
"ADXR": "Average Directional Movement Index Rating",
|
||||
"APO": "Absolute Price Oscillator",
|
||||
"PPO": "Percentage Price Oscillator",
|
||||
"MACD": "MACD",
|
||||
"MFI": "Money Flow Index",
|
||||
"STOCH": "Stochastic",
|
||||
"STOCHF": "Stochastic Fast",
|
||||
"STOCHRSI": "Stochastic RSI",
|
||||
"WILLR": "Williams %R",
|
||||
"CCI": "Commodity Channel Index",
|
||||
"AROON": "Aroon",
|
||||
"AROONOSC": "Aroon Oscillator",
|
||||
"BOP": "Balance Of Power",
|
||||
"ULTOSC": "Ultimate Oscillator",
|
||||
|
||||
# Volume Indicators (3)
|
||||
"AD": "Chaikin A/D Line",
|
||||
"ADOSC": "Chaikin A/D Oscillator",
|
||||
"OBV": "On Balance Volume",
|
||||
|
||||
# Volatility Indicators (3)
|
||||
"ATR": "Average True Range",
|
||||
"NATR": "Normalized Average True Range",
|
||||
"TRANGE": "True Range",
|
||||
|
||||
# Price Transform (4)
|
||||
"AVGPRICE": "Average Price",
|
||||
"MEDPRICE": "Median Price",
|
||||
"TYPPRICE": "Typical Price",
|
||||
"WCLPRICE": "Weighted Close Price",
|
||||
|
||||
# Cycle Indicators (5)
|
||||
"HT_DCPERIOD": "Hilbert Transform - Dominant Cycle Period",
|
||||
"HT_DCPHASE": "Hilbert Transform - Dominant Cycle Phase",
|
||||
"HT_PHASOR": "Hilbert Transform - Phasor Components",
|
||||
"HT_SINE": "Hilbert Transform - SineWave",
|
||||
"HT_TRENDMODE": "Hilbert Transform - Trend vs Cycle Mode",
|
||||
|
||||
# Statistic Functions (9)
|
||||
"BETA": "Beta",
|
||||
"CORREL": "Pearson's Correlation Coefficient",
|
||||
"LINEARREG": "Linear Regression",
|
||||
"LINEARREG_ANGLE": "Linear Regression Angle",
|
||||
"LINEARREG_INTERCEPT": "Linear Regression Intercept",
|
||||
"LINEARREG_SLOPE": "Linear Regression Slope",
|
||||
"STDDEV": "Standard Deviation",
|
||||
"TSF": "Time Series Forecast",
|
||||
"VAR": "Variance",
|
||||
}
|
||||
|
||||
# Total: 59 indicators supported in both systems
|
||||
|
||||
# Custom indicators (TradingView indicators implemented in our backend)
|
||||
CUSTOM_TO_TV_NAMES = {
|
||||
"VWAP": "VWAP",
|
||||
"VWMA": "VWMA",
|
||||
"HMA": "Hull Moving Average",
|
||||
"SUPERTREND": "SuperTrend",
|
||||
"DONCHIAN": "Donchian Channels",
|
||||
"KELTNER": "Keltner Channels",
|
||||
"CMF": "Chaikin Money Flow",
|
||||
"VORTEX": "Vortex Indicator",
|
||||
"AO": "Awesome Oscillator",
|
||||
"AC": "Accelerator Oscillator",
|
||||
"CHOP": "Choppiness Index",
|
||||
"MASS": "Mass Index",
|
||||
}
|
||||
|
||||
# Combined mapping (TA-Lib + Custom)
|
||||
ALL_BACKEND_TO_TV_NAMES = {**TALIB_TO_TV_NAMES, **CUSTOM_TO_TV_NAMES}
|
||||
|
||||
# Total: 71 indicators (59 TA-Lib + 12 Custom)
|
||||
|
||||
# Reverse mapping
|
||||
TV_TO_TALIB_NAMES = {v: k for k, v in TALIB_TO_TV_NAMES.items()}
|
||||
TV_TO_CUSTOM_NAMES = {v: k for k, v in CUSTOM_TO_TV_NAMES.items()}
|
||||
TV_TO_BACKEND_NAMES = {v: k for k, v in ALL_BACKEND_TO_TV_NAMES.items()}
|
||||
|
||||
|
||||
def get_tv_indicator_name(talib_name: str) -> str:
|
||||
"""
|
||||
Convert TA-Lib indicator name to TradingView indicator name.
|
||||
|
||||
Args:
|
||||
talib_name: TA-Lib indicator name (e.g., 'RSI')
|
||||
|
||||
Returns:
|
||||
TradingView indicator name
|
||||
"""
|
||||
return TALIB_TO_TV_NAMES.get(talib_name, talib_name)
|
||||
|
||||
|
||||
def get_talib_indicator_name(tv_name: str) -> Optional[str]:
|
||||
"""
|
||||
Convert TradingView indicator name to TA-Lib indicator name.
|
||||
|
||||
Args:
|
||||
tv_name: TradingView indicator name
|
||||
|
||||
Returns:
|
||||
TA-Lib indicator name or None if not mapped
|
||||
"""
|
||||
return TV_TO_TALIB_NAMES.get(tv_name)
|
||||
|
||||
|
||||
def convert_talib_params_to_tv_inputs(
|
||||
talib_name: str,
|
||||
talib_params: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert TA-Lib parameters to TradingView input format.
|
||||
|
||||
Args:
|
||||
talib_name: TA-Lib indicator name
|
||||
talib_params: TA-Lib parameter dictionary
|
||||
|
||||
Returns:
|
||||
TradingView inputs dictionary
|
||||
"""
|
||||
tv_inputs = {}
|
||||
|
||||
# Common parameter mappings
|
||||
param_mapping = {
|
||||
"timeperiod": "length",
|
||||
"fastperiod": "fastLength",
|
||||
"slowperiod": "slowLength",
|
||||
"signalperiod": "signalLength",
|
||||
"nbdevup": "mult", # Standard deviations for upper band
|
||||
"nbdevdn": "mult", # Standard deviations for lower band
|
||||
"fastlimit": "fastLimit",
|
||||
"slowlimit": "slowLimit",
|
||||
"acceleration": "start",
|
||||
"maximum": "increment",
|
||||
"fastk_period": "kPeriod",
|
||||
"slowk_period": "kPeriod",
|
||||
"slowd_period": "dPeriod",
|
||||
"fastd_period": "dPeriod",
|
||||
"matype": "maType",
|
||||
}
|
||||
|
||||
# Special handling for specific indicators
|
||||
if talib_name == "BBANDS":
|
||||
# Bollinger Bands
|
||||
tv_inputs["length"] = talib_params.get("timeperiod", 20)
|
||||
tv_inputs["mult"] = talib_params.get("nbdevup", 2)
|
||||
tv_inputs["source"] = "close"
|
||||
elif talib_name == "MACD":
|
||||
# MACD
|
||||
tv_inputs["fastLength"] = talib_params.get("fastperiod", 12)
|
||||
tv_inputs["slowLength"] = talib_params.get("slowperiod", 26)
|
||||
tv_inputs["signalLength"] = talib_params.get("signalperiod", 9)
|
||||
tv_inputs["source"] = "close"
|
||||
elif talib_name == "RSI":
|
||||
# RSI
|
||||
tv_inputs["length"] = talib_params.get("timeperiod", 14)
|
||||
tv_inputs["source"] = "close"
|
||||
elif talib_name in ["SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA"]:
|
||||
# Moving averages
|
||||
tv_inputs["length"] = talib_params.get("timeperiod", 14)
|
||||
tv_inputs["source"] = "close"
|
||||
elif talib_name == "STOCH":
|
||||
# Stochastic
|
||||
tv_inputs["kPeriod"] = talib_params.get("fastk_period", 14)
|
||||
tv_inputs["dPeriod"] = talib_params.get("slowd_period", 3)
|
||||
tv_inputs["smoothK"] = talib_params.get("slowk_period", 3)
|
||||
elif talib_name == "ATR":
|
||||
# ATR
|
||||
tv_inputs["length"] = talib_params.get("timeperiod", 14)
|
||||
elif talib_name == "CCI":
|
||||
# CCI
|
||||
tv_inputs["length"] = talib_params.get("timeperiod", 20)
|
||||
else:
|
||||
# Generic parameter conversion
|
||||
for talib_param, value in talib_params.items():
|
||||
tv_param = param_mapping.get(talib_param, talib_param)
|
||||
tv_inputs[tv_param] = value
|
||||
|
||||
logger.debug(f"Converted TA-Lib params for {talib_name}: {talib_params} -> TV inputs: {tv_inputs}")
|
||||
return tv_inputs
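# Example (values follow directly from the MACD branch above):
#
#   convert_talib_params_to_tv_inputs("MACD", {"fastperiod": 12, "slowperiod": 26, "signalperiod": 9})
#   # -> {"fastLength": 12, "slowLength": 26, "signalLength": 9, "source": "close"}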
|
||||
|
||||
|
||||
def convert_tv_inputs_to_talib_params(
|
||||
tv_name: str,
|
||||
tv_inputs: Dict[str, Any]
|
||||
) -> Tuple[Optional[str], Dict[str, Any]]:
|
||||
"""
|
||||
Convert TradingView inputs to TA-Lib parameters.
|
||||
|
||||
Args:
|
||||
tv_name: TradingView indicator name
|
||||
tv_inputs: TradingView inputs dictionary
|
||||
|
||||
Returns:
|
||||
Tuple of (talib_name, talib_params)
|
||||
"""
|
||||
talib_name = get_talib_indicator_name(tv_name)
|
||||
if not talib_name:
|
||||
logger.warning(f"No TA-Lib mapping for TradingView indicator: {tv_name}")
|
||||
return None, {}
|
||||
|
||||
talib_params = {}
|
||||
|
||||
# Reverse parameter mappings
|
||||
reverse_mapping = {
|
||||
"length": "timeperiod",
|
||||
"fastLength": "fastperiod",
|
||||
"slowLength": "slowperiod",
|
||||
"signalLength": "signalperiod",
|
||||
"mult": "nbdevup", # Use same for both up and down
|
||||
"fastLimit": "fastlimit",
|
||||
"slowLimit": "slowlimit",
|
||||
"start": "acceleration",
|
||||
"increment": "maximum",
|
||||
"kPeriod": "fastk_period",
|
||||
"dPeriod": "slowd_period",
|
||||
"smoothK": "slowk_period",
|
||||
"maType": "matype",
|
||||
}
|
||||
|
||||
# Special handling for specific indicators
|
||||
if talib_name == "BBANDS":
|
||||
# Bollinger Bands
|
||||
talib_params["timeperiod"] = tv_inputs.get("length", 20)
|
||||
talib_params["nbdevup"] = tv_inputs.get("mult", 2)
|
||||
talib_params["nbdevdn"] = tv_inputs.get("mult", 2)
|
||||
talib_params["matype"] = 0 # SMA
|
||||
elif talib_name == "MACD":
|
||||
# MACD
|
||||
talib_params["fastperiod"] = tv_inputs.get("fastLength", 12)
|
||||
talib_params["slowperiod"] = tv_inputs.get("slowLength", 26)
|
||||
talib_params["signalperiod"] = tv_inputs.get("signalLength", 9)
|
||||
elif talib_name == "RSI":
|
||||
# RSI
|
||||
talib_params["timeperiod"] = tv_inputs.get("length", 14)
|
||||
elif talib_name in ["SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA"]:
|
||||
# Moving averages
|
||||
talib_params["timeperiod"] = tv_inputs.get("length", 14)
|
||||
elif talib_name == "STOCH":
|
||||
# Stochastic
|
||||
talib_params["fastk_period"] = tv_inputs.get("kPeriod", 14)
|
||||
talib_params["slowd_period"] = tv_inputs.get("dPeriod", 3)
|
||||
talib_params["slowk_period"] = tv_inputs.get("smoothK", 3)
|
||||
talib_params["slowk_matype"] = 0 # SMA
|
||||
talib_params["slowd_matype"] = 0 # SMA
|
||||
elif talib_name == "ATR":
|
||||
# ATR
|
||||
talib_params["timeperiod"] = tv_inputs.get("length", 14)
|
||||
elif talib_name == "CCI":
|
||||
# CCI
|
||||
talib_params["timeperiod"] = tv_inputs.get("length", 20)
|
||||
else:
|
||||
# Generic parameter conversion
|
||||
for tv_param, value in tv_inputs.items():
|
||||
if tv_param == "source":
|
||||
continue # Skip source parameter
|
||||
talib_param = reverse_mapping.get(tv_param, tv_param)
|
||||
talib_params[talib_param] = value
|
||||
|
||||
logger.debug(f"Converted TV inputs for {tv_name}: {tv_inputs} -> TA-Lib {talib_name} params: {talib_params}")
|
||||
return talib_name, talib_params
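# Example (values follow directly from the BBANDS branch above):
#
#   convert_tv_inputs_to_talib_params("Bollinger Bands", {"length": 20, "mult": 2})
#   # -> ("BBANDS", {"timeperiod": 20, "nbdevup": 2, "nbdevdn": 2, "matype": 0})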
|
||||
|
||||
|
||||
def is_indicator_supported(talib_name: str) -> bool:
|
||||
"""
|
||||
Check if a TA-Lib indicator is supported in TradingView.
|
||||
|
||||
Args:
|
||||
talib_name: TA-Lib indicator name
|
||||
|
||||
Returns:
|
||||
True if supported
|
||||
"""
|
||||
return talib_name in TALIB_TO_TV_NAMES
|
||||
|
||||
|
||||
def get_supported_indicators() -> List[str]:
|
||||
"""
|
||||
Get list of supported TA-Lib indicators.
|
||||
|
||||
Returns:
|
||||
List of TA-Lib indicator names
|
||||
"""
|
||||
return list(TALIB_TO_TV_NAMES.keys())
|
||||
|
||||
|
||||
def get_supported_indicator_count() -> int:
|
||||
"""
|
||||
Get count of supported indicators.
|
||||
|
||||
Returns:
|
||||
Number of indicators supported in both systems (TA-Lib + Custom)
|
||||
"""
|
||||
return len(ALL_BACKEND_TO_TV_NAMES)
|
||||
|
||||
|
||||
def is_custom_indicator(indicator_name: str) -> bool:
|
||||
"""
|
||||
Check if an indicator is a custom implementation (not TA-Lib).
|
||||
|
||||
Args:
|
||||
indicator_name: Indicator name
|
||||
|
||||
Returns:
|
||||
True if custom indicator
|
||||
"""
|
||||
return indicator_name in CUSTOM_TO_TV_NAMES
|
||||
|
||||
|
||||
def get_backend_indicator_name(tv_name: str) -> Optional[str]:
|
||||
"""
|
||||
Get backend indicator name from TradingView name (TA-Lib or custom).
|
||||
|
||||
Args:
|
||||
tv_name: TradingView indicator name
|
||||
|
||||
Returns:
|
||||
Backend indicator name or None if not mapped
|
||||
"""
|
||||
return TV_TO_BACKEND_NAMES.get(tv_name)
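# Quick lookup examples based on the mappings defined above:
#
#   get_backend_indicator_name("Hull Moving Average")      # -> "HMA" (custom implementation)
#   get_backend_indicator_name("Relative Strength Index")  # -> "RSI" (TA-Lib)
#   is_custom_indicator("HMA")                              # -> True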
|
||||
712
backend.old/src/main.py
Normal file
712
backend.old/src/main.py
Normal file
@@ -0,0 +1,712 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel
|
||||
import uuid
|
||||
import shutil
|
||||
|
||||
from sync.protocol import HelloMessage, PatchMessage, AuthMessage, AuthResponseMessage
|
||||
from sync.registry import SyncRegistry
|
||||
from gateway.hub import Gateway
|
||||
from gateway.channels.websocket import WebSocketChannel
|
||||
from gateway.protocol import WebSocketAgentUserMessage
|
||||
from agent.core import create_agent
|
||||
from agent.tools import set_registry, set_datasource_registry, set_indicator_registry
|
||||
from agent.tools import set_trigger_queue, set_trigger_scheduler, set_coordinator
|
||||
from schema.order_spec import SwapOrder
|
||||
from schema.chart_state import ChartState
|
||||
from schema.shape import ShapeCollection
|
||||
from schema.indicator import IndicatorCollection
|
||||
from datasource.registry import DataSourceRegistry
|
||||
from datasource.subscription_manager import SubscriptionManager
|
||||
from datasource.websocket_handler import DatafeedWebSocketHandler
|
||||
from secrets_manager import SecretsStore, InvalidMasterPassword
|
||||
from indicator import IndicatorRegistry, register_all_talib_indicators, register_custom_indicators
|
||||
from trigger import CommitCoordinator, TriggerQueue
|
||||
from trigger.scheduler import TriggerScheduler
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Load environment variables from .env file (if present)
|
||||
env_path = Path(__file__).parent.parent / ".env"
|
||||
if env_path.exists():
|
||||
load_dotenv(env_path)
|
||||
|
||||
# Load configuration
|
||||
config_path = Path(__file__).parent.parent / "config.yaml"
|
||||
with open(config_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
registry = SyncRegistry()
|
||||
gateway = Gateway()
|
||||
agent_executor = None
|
||||
|
||||
# DataSource infrastructure
|
||||
datasource_registry = DataSourceRegistry()
|
||||
subscription_manager = SubscriptionManager()
|
||||
|
||||
# Indicator infrastructure
|
||||
indicator_registry = IndicatorRegistry()
|
||||
|
||||
# Trigger system infrastructure
|
||||
trigger_coordinator = None
|
||||
trigger_queue = None
|
||||
trigger_scheduler = None
|
||||
|
||||
# Global secrets store
|
||||
secrets_store = SecretsStore()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Initialize agent system and data sources on startup."""
|
||||
global agent_executor, trigger_coordinator, trigger_queue, trigger_scheduler
|
||||
|
||||
# Initialize CCXT data sources
|
||||
try:
|
||||
from datasource.adapters.ccxt_adapter import CCXTDataSource
|
||||
|
||||
# Binance
|
||||
try:
|
||||
binance_source = CCXTDataSource(exchange_id="binance", poll_interval=60)
|
||||
datasource_registry.register("binance", binance_source)
|
||||
subscription_manager.register_source("binance", binance_source)
|
||||
logger.info("DataSource: Registered Binance source")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize Binance source: {e}")
|
||||
|
||||
logger.info(f"DataSource infrastructure initialized with sources: {datasource_registry.list_sources()}")
|
||||
except ImportError as e:
|
||||
logger.warning(f"CCXT not available: {e}. Only demo source will be available.")
|
||||
logger.info("To use real exchange data, install ccxt: pip install ccxt>=4.0.0")
|
||||
|
||||
# Initialize indicator registry with all TA-Lib indicators
|
||||
try:
|
||||
indicator_count = register_all_talib_indicators(indicator_registry)
|
||||
logger.info(f"Indicator registry initialized with {indicator_count} TA-Lib indicators")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register TA-Lib indicators: {e}")
|
||||
logger.info("TA-Lib indicators will not be available. Install TA-Lib C library and Python wrapper to enable.")
|
||||
|
||||
# Register custom indicators (TradingView indicators not in TA-Lib)
|
||||
try:
|
||||
custom_count = register_custom_indicators(indicator_registry)
|
||||
logger.info(f"Registered {custom_count} custom indicators")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register custom indicators: {e}")
|
||||
|
||||
# Get API keys from secrets store if unlocked, otherwise fall back to environment
|
||||
anthropic_api_key = None
|
||||
|
||||
if secrets_store.is_unlocked:
|
||||
anthropic_api_key = secrets_store.get("ANTHROPIC_API_KEY")
|
||||
if anthropic_api_key:
|
||||
logger.info("Loaded API key from encrypted secrets store")
|
||||
|
||||
# Fall back to environment variable
|
||||
if not anthropic_api_key:
|
||||
anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
|
||||
if anthropic_api_key:
|
||||
logger.info("Loaded API key from environment")
|
||||
|
||||
# Initialize trigger system
|
||||
logger.info("Initializing trigger system...")
|
||||
trigger_coordinator = CommitCoordinator()
|
||||
trigger_queue = TriggerQueue(trigger_coordinator)
|
||||
trigger_scheduler = TriggerScheduler(trigger_queue)
|
||||
|
||||
# Start trigger queue and scheduler
|
||||
await trigger_queue.start()
|
||||
trigger_scheduler.start()
|
||||
logger.info("Trigger system initialized and started")
|
||||
|
||||
# Set trigger system for agent tools
|
||||
set_coordinator(trigger_coordinator)
|
||||
set_trigger_queue(trigger_queue)
|
||||
set_trigger_scheduler(trigger_scheduler)
|
||||
|
||||
if not anthropic_api_key:
|
||||
logger.error("ANTHROPIC_API_KEY not found in environment!")
|
||||
logger.info("Agent system will not be available")
|
||||
else:
|
||||
# Set the registries for agent tools
|
||||
set_registry(registry)
|
||||
set_datasource_registry(datasource_registry)
|
||||
set_indicator_registry(indicator_registry)
|
||||
|
||||
# Create and initialize agent
|
||||
agent_executor = create_agent(
|
||||
model_name=config["agent"]["model"],
|
||||
temperature=config["agent"]["temperature"],
|
||||
api_key=anthropic_api_key,
|
||||
checkpoint_db_path=config["memory"]["checkpoint_db"],
|
||||
chroma_db_path=config["memory"]["chroma_db"],
|
||||
embedding_model=config["memory"]["embedding_model"],
|
||||
context_docs_dir=config["agent"]["context_docs_dir"],
|
||||
base_dir="." # paths resolve relative to the process working directory, where memory/ and soul/ live
|
||||
)
|
||||
|
||||
await agent_executor.initialize()
|
||||
|
||||
# Set agent executor in gateway
|
||||
gateway.set_agent_executor(agent_executor.execute)
|
||||
|
||||
logger.info("Agent system initialized")
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
logger.info("Shutting down systems...")
|
||||
|
||||
# Shutdown trigger system
|
||||
if trigger_scheduler:
|
||||
trigger_scheduler.shutdown(wait=True)
|
||||
logger.info("Trigger scheduler shut down")
|
||||
|
||||
if trigger_queue:
|
||||
await trigger_queue.stop()
|
||||
logger.info("Trigger queue stopped")
|
||||
|
||||
# Shutdown agent system
|
||||
if agent_executor and agent_executor.memory_manager:
|
||||
await agent_executor.memory_manager.close()
|
||||
|
||||
logger.info("All systems shut down")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# Create uploads directory
|
||||
UPLOAD_DIR = Path(__file__).parent.parent / "data" / "uploads"
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Mount static files for serving uploads
|
||||
app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
|
||||
|
||||
# OrderStore model for synchronization
|
||||
class OrderStore(BaseModel):
|
||||
orders: list[SwapOrder] = []
|
||||
|
||||
# ChartStore model for synchronization
|
||||
class ChartStore(BaseModel):
|
||||
chart_state: ChartState = ChartState()
|
||||
|
||||
# ShapeStore model for synchronization
|
||||
class ShapeStore(BaseModel):
|
||||
shapes: dict[str, dict] = {} # Dictionary of shapes keyed by ID
|
||||
|
||||
# IndicatorStore model for synchronization
|
||||
class IndicatorStore(BaseModel):
|
||||
indicators: dict[str, dict] = {} # Dictionary of indicators keyed by ID
|
||||
|
||||
# Initialize stores
|
||||
order_store = OrderStore()
|
||||
chart_store = ChartStore()
|
||||
shape_store = ShapeStore()
|
||||
indicator_store = IndicatorStore()
|
||||
|
||||
# Register with SyncRegistry
|
||||
registry.register(order_store, store_name="OrderStore")
|
||||
registry.register(chart_store, store_name="ChartStore")
|
||||
registry.register(shape_store, store_name="ShapeStore")
|
||||
registry.register(indicator_store, store_name="IndicatorStore")
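# Illustrative client messages handled by the /ws endpoint below (field names
# mirror the protocol models imported above; the values and the patch body are
# placeholders, not a specification):
#
#   {"type": "auth", "password": "..."}                                   # must be the first message
#   {"type": "hello", "seqs": {"OrderStore": 0, "ChartStore": 3}}         # request store resync
#   {"type": "patch", "store": "ShapeStore", "seq": 4, "patch": [...]}    # client-side store change
#   {"type": "agent_user_message", "session_id": "default", "content": "add RSI"}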
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
|
||||
# Helper function to send responses
|
||||
async def send_response(response):
|
||||
try:
|
||||
await websocket.send_json(response.model_dump(mode="json"))
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending response: {e}")
|
||||
|
||||
# Authentication state
|
||||
is_authenticated = False
|
||||
|
||||
# Wait for authentication message (must be first message)
|
||||
try:
|
||||
auth_timeout = 30 # 30 seconds to authenticate
|
||||
auth_data = await asyncio.wait_for(websocket.receive_text(), timeout=auth_timeout)
|
||||
auth_message_json = json.loads(auth_data)
|
||||
|
||||
if auth_message_json.get("type") != "auth":
|
||||
logger.warning("First message was not auth message")
|
||||
await send_response(AuthResponseMessage(
|
||||
success=False,
|
||||
message="First message must be authentication"
|
||||
))
|
||||
await websocket.close(code=1008, reason="Authentication required")
|
||||
return
|
||||
|
||||
auth_msg = AuthMessage(**auth_message_json)
|
||||
logger.info("Received authentication message")
|
||||
|
||||
# Check if secrets store needs initialization
|
||||
if not secrets_store.is_initialized:
|
||||
logger.info("Secrets store not initialized, performing first-time setup")
|
||||
|
||||
# Require password confirmation for initialization
|
||||
if not auth_msg.confirm_password:
|
||||
await send_response(AuthResponseMessage(
|
||||
success=False,
|
||||
needs_confirmation=True,
|
||||
message="First-time setup: password confirmation required"
|
||||
))
|
||||
await websocket.close(code=1008, reason="Password confirmation required")
|
||||
return
|
||||
|
||||
if auth_msg.password != auth_msg.confirm_password:
|
||||
await send_response(AuthResponseMessage(
|
||||
success=False,
|
||||
needs_confirmation=True,
|
||||
message="Passwords do not match"
|
||||
))
|
||||
await websocket.close(code=1008, reason="Password confirmation failed")
|
||||
return
|
||||
|
||||
# Initialize secrets store
|
||||
try:
|
||||
secrets_store.initialize(auth_msg.password)
|
||||
|
||||
# Migrate ANTHROPIC_API_KEY from environment if present
|
||||
env_key = os.environ.get("ANTHROPIC_API_KEY")
|
||||
if env_key:
|
||||
secrets_store.set("ANTHROPIC_API_KEY", env_key)
|
||||
logger.info("Migrated ANTHROPIC_API_KEY from environment to secrets store")
|
||||
|
||||
is_authenticated = True
|
||||
await send_response(AuthResponseMessage(
|
||||
success=True,
|
||||
message="Secrets store initialized successfully"
|
||||
))
|
||||
logger.info("Secrets store initialized and authenticated")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize secrets store: {e}")
|
||||
await send_response(AuthResponseMessage(
|
||||
success=False,
|
||||
message=f"Initialization failed: {str(e)}"
|
||||
))
|
||||
await websocket.close(code=1011, reason="Initialization failed")
|
||||
return
|
||||
else:
|
||||
# Unlock existing secrets store (or verify password if already unlocked)
|
||||
try:
|
||||
# If already unlocked, just verify the password is correct
|
||||
if secrets_store.is_unlocked:
|
||||
# Verify password by creating a temporary store and attempting unlock
|
||||
from secrets_manager import SecretsStore as TempStore
|
||||
temp_store = TempStore(data_dir=secrets_store.data_dir)
|
||||
temp_store.unlock(auth_msg.password) # This will throw if wrong password
|
||||
logger.info("Password verified (store already unlocked)")
|
||||
else:
|
||||
secrets_store.unlock(auth_msg.password)
|
||||
logger.info("Secrets store unlocked successfully")
|
||||
|
||||
# Check if user wants to change password
|
||||
password_changed = False
|
||||
if auth_msg.change_to_password:
|
||||
# Validate password change request
|
||||
if not auth_msg.confirm_new_password:
|
||||
await send_response(AuthResponseMessage(
|
||||
success=False,
|
||||
message="New password confirmation required"
|
||||
))
|
||||
await websocket.close(code=1008, reason="Password confirmation required")
|
||||
return
|
||||
|
||||
if auth_msg.change_to_password != auth_msg.confirm_new_password:
|
||||
await send_response(AuthResponseMessage(
|
||||
success=False,
|
||||
message="New passwords do not match"
|
||||
))
|
||||
await websocket.close(code=1008, reason="Password confirmation mismatch")
|
||||
return
|
||||
|
||||
# Change the password
|
||||
try:
|
||||
secrets_store.change_master_password(auth_msg.password, auth_msg.change_to_password)
|
||||
password_changed = True
|
||||
logger.info("Master password changed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to change password: {e}")
|
||||
await send_response(AuthResponseMessage(
|
||||
success=False,
|
||||
message=f"Failed to change password: {str(e)}"
|
||||
))
|
||||
await websocket.close(code=1011, reason="Password change failed")
|
||||
return
|
||||
|
||||
is_authenticated = True
|
||||
response_message = "Password changed successfully" if password_changed else "Authentication successful"
|
||||
await send_response(AuthResponseMessage(
|
||||
success=True,
|
||||
password_changed=password_changed,
|
||||
message=response_message
|
||||
))
|
||||
except InvalidMasterPassword:
|
||||
logger.warning("Invalid password attempt")
|
||||
await send_response(AuthResponseMessage(
|
||||
success=False,
|
||||
message="Invalid password"
|
||||
))
|
||||
await websocket.close(code=1008, reason="Invalid password")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication error: {e}")
|
||||
await send_response(AuthResponseMessage(
|
||||
success=False,
|
||||
message="Authentication failed"
|
||||
))
|
||||
await websocket.close(code=1011, reason="Authentication error")
|
||||
return
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Authentication timeout")
|
||||
await websocket.close(code=1008, reason="Authentication timeout")
|
||||
return
|
||||
except WebSocketDisconnect:
|
||||
logger.info("Client disconnected during authentication")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"Error during authentication: {e}")
|
||||
await websocket.close(code=1011, reason="Authentication error")
|
||||
return
|
||||
|
||||
# Now authenticated - proceed with normal WebSocket handling
|
||||
registry.websocket = websocket
|
||||
|
||||
# Create WebSocket channel for agent communication
|
||||
channel_id = f"ws_{id(websocket)}"
|
||||
client_id = f"client_{id(websocket)}"
|
||||
logger.info(f"WebSocket authenticated - channel_id: {channel_id}, client_id: {client_id}")
|
||||
ws_channel = WebSocketChannel(channel_id, websocket, session_id="default")
|
||||
gateway.register_channel(ws_channel)
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
logger.debug(f"Received WebSocket message: {data[:200]}...") # Log first 200 chars
|
||||
message_json = json.loads(data)
|
||||
|
||||
if "type" not in message_json:
|
||||
logger.warning(f"Message missing 'type' field: {message_json}")
|
||||
continue
|
||||
|
||||
msg_type = message_json["type"]
|
||||
logger.info(f"Processing message type: {msg_type}")
|
||||
|
||||
# Handle sync protocol messages
|
||||
if msg_type == "hello":
|
||||
hello_msg = HelloMessage(**message_json)
|
||||
logger.info(f"Hello message received with seqs: {hello_msg.seqs}")
|
||||
await registry.sync_client(hello_msg.seqs)
|
||||
elif msg_type == "patch":
|
||||
patch_msg = PatchMessage(**message_json)
|
||||
logger.info(f"Patch message received for store: {patch_msg.store}, seq: {patch_msg.seq}")
|
||||
try:
|
||||
await registry.apply_client_patch(
|
||||
store_name=patch_msg.store,
|
||||
client_base_seq=patch_msg.seq,
|
||||
patch=patch_msg.patch
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error applying client patch: {e}. Client will receive snapshot to resync.", exc_info=True)
|
||||
elif msg_type == "agent_user_message":
|
||||
# Handle agent messages directly here
|
||||
print(f"[DEBUG] Raw message_json: {message_json}")
|
||||
logger.info(f"Raw message_json: {message_json}")
|
||||
msg = WebSocketAgentUserMessage(**message_json)
|
||||
print(f"[DEBUG] Parsed message - session: {msg.session_id}, content: '{msg.content}' (len={len(msg.content)})")
|
||||
logger.info(f"Agent user message received - session: {msg.session_id}, content: '{msg.content}' (len={len(msg.content)})")
|
||||
from gateway.protocol import UserMessage
|
||||
from datetime import datetime, timezone
|
||||
|
||||
user_msg = UserMessage(
|
||||
session_id=msg.session_id,
|
||||
channel_id=channel_id,
|
||||
content=msg.content,
|
||||
attachments=msg.attachments,
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
logger.info(f"Routing user message to gateway - channel: {channel_id}, session: {msg.session_id}")
|
||||
await gateway.route_user_message(user_msg)
|
||||
logger.info("Message routing completed")
|
||||
|
||||
# Handle datafeed protocol messages
|
||||
elif msg_type in ["get_config", "search_symbols", "resolve_symbol", "get_bars", "subscribe_bars", "unsubscribe_bars"]:
|
||||
from datasource.websocket_protocol import (
|
||||
GetConfigRequest, GetConfigResponse,
|
||||
SearchSymbolsRequest, SearchSymbolsResponse,
|
||||
ResolveSymbolRequest, ResolveSymbolResponse,
|
||||
GetBarsRequest, GetBarsResponse,
|
||||
SubscribeBarsRequest, SubscribeBarsResponse,
|
||||
UnsubscribeBarsRequest, UnsubscribeBarsResponse,
|
||||
ErrorResponse
|
||||
)
|
||||
|
||||
request_id = message_json.get("request_id", "unknown")
|
||||
try:
|
||||
if msg_type == "get_config":
|
||||
req = GetConfigRequest(**message_json)
|
||||
logger.info(f"Getting config, request_id={req.request_id}")
|
||||
sources = datasource_registry.list_sources()
|
||||
logger.info(f"Available sources: {sources}")
|
||||
|
||||
if not sources:
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="NO_SOURCES", error_message="No data sources available")
|
||||
await send_response(error_response)
|
||||
else:
|
||||
# Get config from first source (we can enhance this later to aggregate)
|
||||
source = datasource_registry.get(sources[0])
|
||||
if source:
|
||||
try:
|
||||
config = await source.get_config()
|
||||
logger.info(f"Got config from {sources[0]}")
|
||||
# Enhance with all available exchanges
|
||||
all_exchanges = set()
|
||||
for source_name in sources:
|
||||
s = datasource_registry.get(source_name)
|
||||
if s:
|
||||
try:
|
||||
cfg = await asyncio.wait_for(s.get_config(), timeout=5.0)
|
||||
all_exchanges.update(cfg.exchanges)
|
||||
logger.info(f"Added exchanges from {source_name}: {cfg.exchanges}")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Timeout getting config from {source_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting config from {source_name}: {e}")
|
||||
config_dict = config.model_dump(mode="json")
|
||||
config_dict["exchanges"] = list(all_exchanges)
|
||||
logger.info(f"Sending config with exchanges: {list(all_exchanges)}")
|
||||
response = GetConfigResponse(request_id=req.request_id, config=config_dict)
|
||||
await send_response(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting config: {e}", exc_info=True)
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="ERROR", error_message=str(e))
|
||||
await send_response(error_response)
|
||||
else:
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="SOURCE_NOT_FOUND", error_message="Data sources not available")
|
||||
await send_response(error_response)
|
||||
|
||||
elif msg_type == "search_symbols":
|
||||
req = SearchSymbolsRequest(**message_json)
|
||||
logger.info(f"Searching symbols: query='{req.query}', request_id={req.request_id}")
|
||||
|
||||
# Search all data sources
|
||||
all_results = []
|
||||
sources = datasource_registry.list_sources()
|
||||
logger.info(f"Available data sources: {sources}")
|
||||
|
||||
for source_name in sources:
|
||||
source = datasource_registry.get(source_name)
|
||||
if source:
|
||||
try:
|
||||
results = await asyncio.wait_for(
|
||||
source.search_symbols(query=req.query, type=req.symbol_type, exchange=req.exchange, limit=req.limit),
|
||||
timeout=5.0
|
||||
)
|
||||
all_results.extend([r.model_dump(mode="json") for r in results])
|
||||
logger.info(f"Source '{source_name}' returned {len(results)} results")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Timeout searching source '{source_name}'")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error searching source '{source_name}': {e}")
|
||||
|
||||
logger.info(f"Total search results: {len(all_results)}")
|
||||
response = SearchSymbolsResponse(request_id=req.request_id, results=all_results[:req.limit])
|
||||
await send_response(response)
|
||||
|
||||
elif msg_type == "resolve_symbol":
|
||||
req = ResolveSymbolRequest(**message_json)
|
||||
logger.info(f"Resolving symbol: {req.symbol}")
|
||||
|
||||
# Parse ticker format: "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD")
|
||||
symbol = req.symbol
|
||||
source_name = None
|
||||
symbol_without_exchange = symbol
|
||||
|
||||
# Check if ticker has exchange prefix
|
||||
if ":" in symbol:
|
||||
exchange_prefix, symbol_without_exchange = symbol.split(":", 1)
|
||||
source_name = exchange_prefix.lower()
|
||||
logger.info(f"Parsed ticker: exchange={source_name}, symbol={symbol_without_exchange}")
|
||||
|
||||
# If we identified a source, try it directly
|
||||
if source_name:
|
||||
try:
|
||||
source = datasource_registry.get(source_name)
|
||||
if source:
|
||||
logger.info(f"Trying to resolve '{symbol_without_exchange}' in source '{source_name}'")
|
||||
symbol_info = await asyncio.wait_for(
|
||||
source.resolve_symbol(symbol_without_exchange),
|
||||
timeout=5.0
|
||||
)
|
||||
logger.info(f"Successfully resolved '{symbol_without_exchange}' in source '{source_name}'")
|
||||
response = ResolveSymbolResponse(request_id=req.request_id, symbol_info=symbol_info.model_dump(mode="json"))
|
||||
await send_response(response)
|
||||
else:
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="SOURCE_NOT_FOUND", error_message=f"Data source '{source_name}' not found")
|
||||
await send_response(error_response)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Timeout resolving '{symbol_without_exchange}' in source '{source_name}'")
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="TIMEOUT", error_message="Timeout resolving symbol")
|
||||
await send_response(error_response)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error resolving '{symbol_without_exchange}' in source '{source_name}': {e}")
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="SYMBOL_NOT_FOUND", error_message=str(e))
|
||||
await send_response(error_response)
|
||||
else:
|
||||
# No exchange prefix, try all sources
|
||||
found = False
|
||||
for src in datasource_registry.list_sources():
|
||||
try:
|
||||
s = datasource_registry.get(src)
|
||||
if s:
|
||||
logger.info(f"Trying to resolve '{symbol}' in source '{src}'")
|
||||
symbol_info = await asyncio.wait_for(s.resolve_symbol(symbol), timeout=5.0)
|
||||
if symbol_info:
|
||||
logger.info(f"Successfully resolved '{symbol}' in source '{src}'")
|
||||
response = ResolveSymbolResponse(request_id=req.request_id, symbol_info=symbol_info.model_dump(mode="json"))
|
||||
await send_response(response)
|
||||
found = True
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Timeout resolving '{symbol}' in source '{src}'")
|
||||
except Exception as e:
|
||||
logger.info(f"Symbol '{symbol}' not found in source '{src}': {e}")
|
||||
continue
|
||||
|
||||
if not found:
|
||||
# Symbol not found in any source
|
||||
logger.warning(f"Symbol '{symbol}' not found in any data source")
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="SYMBOL_NOT_FOUND", error_message=f"Symbol '{symbol}' not found in any data source")
|
||||
await send_response(error_response)
|
||||
|
||||
elif msg_type == "get_bars":
|
||||
req = GetBarsRequest(**message_json)
|
||||
logger.info(f"Getting bars for symbol: {req.symbol}")
|
||||
|
||||
# Parse ticker format: "EXCHANGE:SYMBOL"
|
||||
symbol = req.symbol
|
||||
source_name = None
|
||||
symbol_without_exchange = symbol
|
||||
|
||||
# Check if ticker has exchange prefix
|
||||
if ":" in symbol:
|
||||
exchange_prefix, symbol_without_exchange = symbol.split(":", 1)
|
||||
source_name = exchange_prefix.lower()
|
||||
logger.info(f"Parsed ticker for bars: exchange={source_name}, symbol={symbol_without_exchange}")
|
||||
|
||||
# If we identified a source, use it directly
|
||||
if source_name:
|
||||
try:
|
||||
source = datasource_registry.get(source_name)
|
||||
if source:
|
||||
logger.info(f"Getting bars for '{symbol_without_exchange}' from source '{source_name}'")
|
||||
history = await asyncio.wait_for(
|
||||
source.get_bars(symbol=symbol_without_exchange, resolution=req.resolution, from_time=req.from_time, to_time=req.to_time, countback=req.countback),
|
||||
timeout=10.0
|
||||
)
|
||||
logger.info(f"Successfully got {len(history.bars)} bars for '{symbol_without_exchange}' from source '{source_name}'")
|
||||
response = GetBarsResponse(request_id=req.request_id, history=history.model_dump(mode="json"))
|
||||
await send_response(response)
|
||||
else:
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="SOURCE_NOT_FOUND", error_message=f"Data source '{source_name}' not found")
|
||||
await send_response(error_response)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Timeout getting bars for '{symbol_without_exchange}' from source '{source_name}'")
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="TIMEOUT", error_message="Timeout fetching bars")
|
||||
await send_response(error_response)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting bars for '{symbol_without_exchange}' from source '{source_name}': {e}")
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="ERROR", error_message=str(e))
|
||||
await send_response(error_response)
|
||||
else:
|
||||
# No exchange prefix - this shouldn't happen with proper tickers
|
||||
logger.warning(f"Ticker '{symbol}' has no exchange prefix")
|
||||
error_response = ErrorResponse(request_id=req.request_id, error_code="INVALID_TICKER", error_message="Ticker must include exchange prefix (e.g., BINANCE:BTC/USDT)")
|
||||
await send_response(error_response)
|
||||
|
||||
elif msg_type == "subscribe_bars":
|
||||
req = SubscribeBarsRequest(**message_json)
|
||||
# TODO: Implement subscription management
|
||||
response = SubscribeBarsResponse(request_id=req.request_id, subscription_id=req.subscription_id, success=True)
|
||||
await send_response(response)
|
||||
|
||||
elif msg_type == "unsubscribe_bars":
|
||||
req = UnsubscribeBarsRequest(**message_json)
|
||||
response = UnsubscribeBarsResponse(request_id=req.request_id, subscription_id=req.subscription_id, success=True)
|
||||
await send_response(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling {msg_type}: {e}", exc_info=True)
|
||||
error_response = ErrorResponse(request_id=request_id, error_code="INTERNAL_ERROR", error_message=str(e))
|
||||
await send_response(error_response)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket disconnected - channel_id: {channel_id}")
|
||||
registry.websocket = None
|
||||
gateway.unregister_channel(channel_id)
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error: {e}", exc_info=True)
|
||||
registry.websocket = None
|
||||
gateway.unregister_channel(channel_id)
|
||||
|
||||
@app.post("/api/upload")
|
||||
async def upload_file(file: UploadFile = File(...)):
|
||||
"""Upload a file and return its URL."""
|
||||
try:
|
||||
# Generate unique filename
|
||||
file_extension = Path(file.filename).suffix if file.filename else ""
|
||||
unique_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
file_path = UPLOAD_DIR / unique_filename
|
||||
|
||||
# Save file
|
||||
with open(file_path, "wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
|
||||
# Return URL (relative to backend)
|
||||
file_url = f"/uploads/{unique_filename}"
|
||||
logger.info(f"File uploaded: {file.filename} -> {file_url}")
|
||||
|
||||
return {
|
||||
"url": file_url,
|
||||
"filename": file.filename,
|
||||
"size": file_path.stat().st_size
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"File upload error: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/healthz")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
# Background task to simulate backend updates (optional, for demo)
|
||||
async def simulate_backend_updates():
|
||||
while True:
|
||||
await asyncio.sleep(5)
|
||||
if registry.websocket:
|
||||
# Example: could add/modify orders here
|
||||
await registry.push_all()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=config["www_port"])
|
||||
28
backend.old/src/schema/chart_state.py
Normal file
28
backend.old/src/schema/chart_state.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ChartState(BaseModel):
|
||||
"""Tracks the user's current chart view state.
|
||||
|
||||
This state is synchronized between the frontend and backend to allow
|
||||
the AI agent to understand what the user is currently viewing.
|
||||
|
||||
All fields can be None when no chart is visible (e.g., on mobile/narrow screens).
|
||||
"""
|
||||
|
||||
# Current symbol being viewed (e.g., "BINANCE:BTC/USDT", "BINANCE:ETH/USDT")
|
||||
# None when chart is not visible
|
||||
symbol: Optional[str] = Field(default="BINANCE:BTC/USDT", description="Current trading pair symbol, or None if no chart visible")
|
||||
|
||||
# Time range currently visible on chart (Unix timestamps in seconds)
|
||||
# These represent the leftmost and rightmost visible candle times
|
||||
start_time: Optional[int] = Field(default=None, description="Start time of visible range (Unix timestamp in seconds)")
|
||||
end_time: Optional[int] = Field(default=None, description="End time of visible range (Unix timestamp in seconds)")
|
||||
|
||||
# Optional: Chart interval/resolution
|
||||
# None when chart is not visible
|
||||
interval: Optional[str] = Field(default="15", description="Chart interval (e.g., '1', '5', '15', '60', 'D'), or None if no chart visible")
|
||||
|
||||
# Selected shapes/drawings on the chart
|
||||
selected_shapes: List[str] = Field(default_factory=list, description="Array of selected shape IDs")
|
||||
40
backend.old/src/schema/indicator.py
Normal file
40
backend.old/src/schema/indicator.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class IndicatorInstance(BaseModel):
|
||||
"""
|
||||
Represents an instance of an indicator applied to a chart.
|
||||
|
||||
This schema holds both the TA-Lib metadata and TradingView-specific data
|
||||
needed for synchronization.
|
||||
"""
|
||||
id: str = Field(..., description="Unique identifier for this indicator instance")
|
||||
|
||||
# TA-Lib metadata
|
||||
talib_name: str = Field(..., description="TA-Lib indicator name (e.g., 'RSI', 'SMA', 'MACD')")
|
||||
instance_name: str = Field(..., description="User-friendly instance name")
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="TA-Lib indicator parameters")
|
||||
|
||||
# TradingView metadata
|
||||
tv_study_id: Optional[str] = Field(default=None, description="TradingView study ID assigned by the chart widget")
|
||||
tv_indicator_name: Optional[str] = Field(default=None, description="TradingView indicator name if different from TA-Lib")
|
||||
tv_inputs: Optional[Dict[str, Any]] = Field(default=None, description="TradingView-specific input parameters")
|
||||
|
||||
# Visual properties
|
||||
visible: bool = Field(default=True, description="Whether indicator is visible on chart")
|
||||
pane: str = Field(default="chart", description="Pane where indicator is displayed ('chart' or 'separate')")
|
||||
|
||||
# Metadata
|
||||
symbol: Optional[str] = Field(default=None, description="Symbol this indicator is applied to")
|
||||
created_at: Optional[int] = Field(default=None, description="Creation timestamp (Unix seconds)")
|
||||
modified_at: Optional[int] = Field(default=None, description="Last modification timestamp (Unix seconds)")
|
||||
original_id: Optional[str] = Field(default=None, description="Original ID from backend before TradingView assigns its own ID")
|
||||
|
||||
|
||||
class IndicatorCollection(BaseModel):
|
||||
"""Collection of all indicator instances on the chart."""
|
||||
indicators: Dict[str, IndicatorInstance] = Field(
|
||||
default_factory=dict,
|
||||
description="Dictionary of indicator instances keyed by ID"
|
||||
)
|
||||
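A brief sketch of how these models might be populated (illustrative values; the import path is assumed from the file location):

from schema.indicator import IndicatorInstance, IndicatorCollection

rsi = IndicatorInstance(
    id="ind-1",
    talib_name="RSI",
    instance_name="RSI (14)",
    parameters={"timeperiod": 14},
    pane="separate",
    symbol="BINANCE:BTC/USDT",
)

collection = IndicatorCollection(indicators={rsi.id: rsi})
# tv_study_id remains None until the TradingView widget assigns one;
# original_id can then keep a reference to the backend-side ID.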
327
backend.old/src/schema/order_spec.py
Normal file
327
backend.old/src/schema/order_spec.py
Normal file
@@ -0,0 +1,327 @@
|
||||
from enum import StrEnum
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel, Field, BeforeValidator, PlainSerializer
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scalar coercion helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _to_int(v: int | str) -> int:
|
||||
return int(v, 0) if isinstance(v, str) else int(v)
|
||||
|
||||
|
||||
def _to_float(v: float | int | str) -> float:
|
||||
return float(v)
|
||||
|
||||
|
||||
_int_to_str = PlainSerializer(str, return_type=str, when_used="json")
|
||||
_float_to_str = PlainSerializer(str, return_type=str, when_used="json")
|
||||
|
||||
# Always stored as Python int; accepts int or string on input; serialises to string in JSON.
|
||||
type Uint8 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
|
||||
type Uint16 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
|
||||
type Uint24 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
|
||||
type Uint32 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
|
||||
type Uint64 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
|
||||
type Uint256 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
|
||||
type Float = Annotated[float, BeforeValidator(_to_float), _float_to_str]
|
||||
|
||||
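# Illustrative sketch (not part of the original module): the aliases above accept
# ints or numeric strings (including hex) on input, keep Python ints internally,
# and serialise to strings in JSON. With a hypothetical model:
#
#     class _Demo(BaseModel):
#         amount: Uint256
#
#     _Demo(amount="0xff").model_dump_json()  # -> '{"amount":"255"}'
#     _Demo(amount=255).model_dump()          # -> {'amount': 255}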
ETH_ADDRESS_PATTERN = r"^0x[0-9a-fA-F]{40}$"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Enums
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Exchange(StrEnum):
|
||||
UNISWAP_V2 = "UniswapV2"
|
||||
UNISWAP_V3 = "UniswapV3"
|
||||
|
||||
|
||||
class Side(StrEnum):
|
||||
"""Order side: buy or sell"""
|
||||
BUY = "BUY"
|
||||
SELL = "SELL"
|
||||
|
||||
|
||||
class AmountType(StrEnum):
|
||||
"""Whether the order amount refers to base or quote currency"""
|
||||
BASE = "BASE" # Amount is in base currency (e.g., BTC in BTC/USD)
|
||||
QUOTE = "QUOTE" # Amount is in quote currency (e.g., USD in BTC/USD)
|
||||
|
||||
|
||||
class TimeInForce(StrEnum):
|
||||
"""Order lifetime specification"""
|
||||
GTC = "GTC" # Good Till Cancel
|
||||
IOC = "IOC" # Immediate or Cancel
|
||||
FOK = "FOK" # Fill or Kill
|
||||
DAY = "DAY" # Good for trading day
|
||||
GTD = "GTD" # Good Till Date
|
||||
|
||||
|
||||
class ConditionalOrderMode(StrEnum):
|
||||
"""How conditional orders behave on partial fills"""
|
||||
NEW_PER_FILL = "NEW_PER_FILL" # Create new conditional order per each fill
|
||||
UNIFIED_ADJUSTING = "UNIFIED_ADJUSTING" # Single conditional order that adjusts amount
|
||||
|
||||
|
||||
class TriggerType(StrEnum):
|
||||
"""Type of conditional trigger"""
|
||||
STOP_LOSS = "STOP_LOSS"
|
||||
TAKE_PROFIT = "TAKE_PROFIT"
|
||||
STOP_LIMIT = "STOP_LIMIT"
|
||||
TRAILING_STOP = "TRAILING_STOP"
|
||||
|
||||
|
||||
class TickSpacingMode(StrEnum):
|
||||
"""How price tick spacing is determined"""
|
||||
FIXED = "FIXED" # Fixed tick size
|
||||
DYNAMIC = "DYNAMIC" # Tick size varies by price level
|
||||
CONTINUOUS = "CONTINUOUS" # No tick restrictions
|
||||
|
||||
|
||||
class AssetType(StrEnum):
|
||||
"""Type of tradeable asset"""
|
||||
SPOT = "SPOT" # Spot/cash market
|
||||
MARGIN = "MARGIN" # Margin trading
|
||||
PERP = "PERP" # Perpetual futures
|
||||
FUTURE = "FUTURE" # Dated futures
|
||||
OPTION = "OPTION" # Options
|
||||
SYNTHETIC = "SYNTHETIC" # Synthetic/derived instruments
|
||||
|
||||
|
||||
class OcoMode(StrEnum):
|
||||
NO_OCO = "NO_OCO"
|
||||
CANCEL_ON_PARTIAL_FILL = "CANCEL_ON_PARTIAL_FILL"
|
||||
CANCEL_ON_COMPLETION = "CANCEL_ON_COMPLETION"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Supporting models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Route(BaseModel):
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
exchange: Exchange
|
||||
fee: Uint24 = Field(description="Pool fee tier; also used as maxFee on UniswapV3")
|
||||
|
||||
|
||||
class Line(BaseModel):
|
||||
"""Price line: price = intercept + slope * time. Both zero means line is disabled."""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
intercept: Float
|
||||
slope: Float
|
||||
|
||||
|
||||
class Tranche(BaseModel):
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
fraction: Uint16 = Field(description="Fraction of total order amount; MAX_FRACTION (65535) = 100%")
|
||||
startTimeIsRelative: bool
|
||||
endTimeIsRelative: bool
|
||||
minIsBarrier: bool = Field(description="Not yet supported")
|
||||
maxIsBarrier: bool = Field(description="Not yet supported")
|
||||
marketOrder: bool = Field(
|
||||
description="If true, min/max lines ignored; minLine intercept treated as max slippage"
|
||||
)
|
||||
minIsRatio: bool
|
||||
maxIsRatio: bool
|
||||
rateLimitFraction: Uint16 = Field(description="Max fraction of this tranche's amount per rate-limited execution")
|
||||
rateLimitPeriod: Uint24 = Field(description="Seconds between rate limit resets")
|
||||
startTime: Uint32 = Field(description="Unix timestamp; 0 (DISTANT_PAST) effectively disables")
|
||||
endTime: Uint32 = Field(description="Unix timestamp; 4294967295 (DISTANT_FUTURE) effectively disables")
|
||||
minLine: Line = Field(description="Traditional limit order constraint; can be diagonal")
|
||||
maxLine: Line = Field(description="Upper price boundary (too-good-a-price guard)")
|
||||
|
||||
|
||||
class TrancheStatus(BaseModel):
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
filled: Uint256 = Field(description="Amount filled by this tranche")
|
||||
activationTime: Uint32 = Field(description="Earliest time this tranche can execute; 0 = not yet concrete")
|
||||
startTime: Uint32 = Field(description="Concrete start timestamp")
|
||||
endTime: Uint32 = Field(description="Concrete end timestamp")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Standard Order Models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ConditionalTrigger(BaseModel):
|
||||
"""Conditional order trigger (stop-loss, take-profit, etc.)"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
trigger_type: TriggerType
|
||||
trigger_price: Float = Field(description="Price at which conditional order activates")
|
||||
trailing_delta: Float | None = Field(default=None, description="For trailing stops: delta from peak/trough")
|
||||
|
||||
|
||||
class AmountConstraints(BaseModel):
|
||||
"""Constraints on order amounts for a symbol"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
min_amount: Float = Field(description="Minimum order amount")
|
||||
max_amount: Float = Field(description="Maximum order amount")
|
||||
step_size: Float = Field(description="Amount increment granularity")
|
||||
|
||||
|
||||
class PriceConstraints(BaseModel):
|
||||
"""Constraints on order pricing for a symbol"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
tick_spacing_mode: TickSpacingMode
|
||||
tick_size: Float | None = Field(default=None, description="Fixed tick size (if FIXED mode)")
|
||||
min_price: Float | None = Field(default=None, description="Minimum allowed price")
|
||||
max_price: Float | None = Field(default=None, description="Maximum allowed price")
|
||||
|
||||
|
||||
class MarketCapabilities(BaseModel):
|
||||
"""Describes what order features a market supports"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
supported_sides: list[Side] = Field(description="Supported order sides (usually both)")
|
||||
supported_amount_types: list[AmountType] = Field(description="Whether BASE, QUOTE, or both amounts are supported")
|
||||
supports_market_orders: bool = Field(description="Whether market orders are supported")
|
||||
supports_limit_orders: bool = Field(description="Whether limit orders are supported")
|
||||
supported_time_in_force: list[TimeInForce] = Field(description="Supported order lifetimes")
|
||||
supports_conditional_orders: bool = Field(description="Whether stop-loss/take-profit are supported")
|
||||
supported_trigger_types: list[TriggerType] = Field(default_factory=list, description="Supported trigger types")
|
||||
supports_post_only: bool = Field(default=False, description="Whether post-only orders are supported")
|
||||
supports_reduce_only: bool = Field(default=False, description="Whether reduce-only orders are supported")
|
||||
supports_iceberg: bool = Field(default=False, description="Whether iceberg orders are supported")
|
||||
market_order_amount_type: AmountType | None = Field(
|
||||
default=None,
|
||||
description="Required amount type for market orders (some DEXs require exact-in)"
|
||||
)
|
||||
|
||||
|
||||
class SymbolMetadata(BaseModel):
|
||||
"""Complete metadata describing a tradeable symbol/market"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
symbol_id: str = Field(description="Unique symbol identifier")
|
||||
base_asset: str = Field(description="Base asset (e.g., 'BTC')")
|
||||
quote_asset: str = Field(description="Quote asset (e.g., 'USD')")
|
||||
asset_type: AssetType = Field(description="Type of market")
|
||||
exchange: str = Field(description="Exchange identifier")
|
||||
|
||||
amount_constraints: AmountConstraints
|
||||
price_constraints: PriceConstraints
|
||||
capabilities: MarketCapabilities
|
||||
|
||||
contract_size: Float | None = Field(default=None, description="For futures/options: contract multiplier")
|
||||
settlement_asset: str | None = Field(default=None, description="For derivatives: settlement currency")
|
||||
expiry_timestamp: Uint64 | None = Field(default=None, description="For dated futures/options: expiration")
|
||||
|
||||
|
||||
class StandardOrder(BaseModel):
|
||||
"""Standard order specification for exchange kernels"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
symbol_id: str = Field(description="Symbol to trade")
|
||||
side: Side = Field(description="Buy or sell")
|
||||
amount: Float = Field(description="Order amount")
|
||||
amount_type: AmountType = Field(description="Whether amount is BASE or QUOTE currency")
|
||||
|
||||
limit_price: Float | None = Field(default=None, description="Limit price (None = market order)")
|
||||
time_in_force: TimeInForce = Field(default=TimeInForce.GTC, description="Order lifetime")
|
||||
good_till_date: Uint64 | None = Field(default=None, description="Expiry timestamp for GTD orders")
|
||||
|
||||
conditional_trigger: ConditionalTrigger | None = Field(
|
||||
default=None,
|
||||
description="Stop-loss/take-profit trigger"
|
||||
)
|
||||
conditional_mode: ConditionalOrderMode | None = Field(
|
||||
default=None,
|
||||
description="How conditional orders behave on partial fills"
|
||||
)
|
||||
|
||||
reduce_only: bool = Field(default=False, description="Only reduce existing position")
|
||||
post_only: bool = Field(default=False, description="Only make, never take")
|
||||
iceberg_qty: Float | None = Field(default=None, description="Visible amount for iceberg orders")
|
||||
|
||||
client_order_id: str | None = Field(default=None, description="Client-specified order ID")
|
||||
|
||||
|
||||
class StandardOrderStatus(BaseModel):
|
||||
"""Current status of a standard order"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
order: StandardOrder
|
||||
order_id: str = Field(description="Exchange-assigned order ID")
|
||||
status: str = Field(description="Order status: NEW, PARTIALLY_FILLED, FILLED, CANCELED, REJECTED, EXPIRED")
|
||||
filled_amount: Float = Field(description="Amount filled so far")
|
||||
average_fill_price: Float = Field(description="Average execution price")
|
||||
created_at: Uint64 = Field(description="Order creation timestamp")
|
||||
updated_at: Uint64 = Field(description="Last update timestamp")
|
||||
|
||||
|
||||
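# Illustrative sketch (not part of the original module): a GTC limit buy with an
# attached stop-loss trigger could be expressed as:
#
#     order = StandardOrder(
#         symbol_id="BINANCE:BTC/USDT",
#         side=Side.BUY,
#         amount=0.5,
#         amount_type=AmountType.BASE,
#         limit_price=60000.0,
#         conditional_trigger=ConditionalTrigger(
#             trigger_type=TriggerType.STOP_LOSS,
#             trigger_price=55000.0,
#         ),
#         conditional_mode=ConditionalOrderMode.UNIFIED_ADJUSTING,
#     )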
# ---------------------------------------------------------------------------
|
||||
# Order models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SwapOrder(BaseModel):
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
tokenIn: str = Field(pattern=ETH_ADDRESS_PATTERN, description="ERC-20 input token address")
|
||||
tokenOut: str = Field(pattern=ETH_ADDRESS_PATTERN, description="ERC-20 output token address")
|
||||
route: Route
|
||||
amount: Uint256 = Field(description="Maximum quantity to fill")
|
||||
minFillAmount: Uint256 = Field(description="Minimum tranche amount before tranche is considered complete")
|
||||
amountIsInput: bool = Field(description="true = amount is tokenIn quantity; false = tokenOut")
|
||||
outputDirectlyToOwner: bool = Field(description="true = proceeds go to vault owner; false = vault")
|
||||
inverted: bool = Field(description="false = tokenIn/tokenOut price direction (Uniswap natural)")
|
||||
conditionalOrder: Uint64 = Field(
|
||||
description="NO_CONDITIONAL_ORDER = 2^64-1; high bit set = relative index within placement group"
|
||||
)
|
||||
tranches: list[Tranche] = Field(min_length=1)
|
||||
|
||||
|
||||
class StandardOrderGroup(BaseModel):
|
||||
"""Group of orders with OCO (One-Cancels-Other) relationship"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
mode: OcoMode
|
||||
orders: list[StandardOrder] = Field(min_length=1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Legacy swap order models (kept for backward compatibility)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class OcoGroup(BaseModel):
|
||||
"""DEPRECATED: Use StandardOrderGroup instead"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
mode: OcoMode
|
||||
orders: list[SwapOrder] = Field(min_length=1)
|
||||
|
||||
|
||||
class SwapOrderStatus(BaseModel):
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
order: SwapOrder
|
||||
fillFeeHalfBps: Uint8 = Field(description="Fill fee in half-bps (1/20000); max 255 = 1.275%")
|
||||
canceled: bool = Field(description="If true, order is canceled regardless of cancelAllIndex")
|
||||
startTime: Uint32 = Field(description="Earliest block.timestamp at which order may execute")
|
||||
ocoGroup: Uint64 = Field(description="Index into ocoGroups; NO_OCO_INDEX = 2^64-1")
|
||||
originalOrder: Uint64 = Field(description="Index of the original order in the orders array")
|
||||
startPrice: Uint256 = Field(description="Price at order start")
|
||||
filled: Uint256 = Field(description="Total amount filled so far")
|
||||
trancheStatus: list[TrancheStatus]
|
||||
|
||||
|
||||
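As a sketch only (placeholder token addresses and illustrative values, based on the field descriptions above), a single-tranche market-style swap could be built as:

order = SwapOrder(
    tokenIn="0x1111111111111111111111111111111111111111",
    tokenOut="0x2222222222222222222222222222222222222222",
    route=Route(exchange=Exchange.UNISWAP_V3, fee=3000),
    amount=10**18,                  # 1.0 tokenIn, assuming 18 decimals
    minFillAmount=0,
    amountIsInput=True,
    outputDirectlyToOwner=False,
    inverted=False,
    conditionalOrder=2**64 - 1,     # NO_CONDITIONAL_ORDER sentinel
    tranches=[
        Tranche(
            fraction=65535,                 # MAX_FRACTION = 100% of the order amount
            startTimeIsRelative=False,
            endTimeIsRelative=False,
            minIsBarrier=False,
            maxIsBarrier=False,
            marketOrder=True,               # min/max lines ignored; minLine intercept = max slippage
            minIsRatio=False,
            maxIsRatio=False,
            rateLimitFraction=65535,
            rateLimitPeriod=0,
            startTime=0,                    # DISTANT_PAST
            endTime=4294967295,             # DISTANT_FUTURE
            minLine=Line(intercept=0.01, slope=0.0),
            maxLine=Line(intercept=0.0, slope=0.0),   # disabled (both zero)
        )
    ],
)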
44
backend.old/src/schema/shape.py
Normal file
44
backend.old/src/schema/shape.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ControlPoint(BaseModel):
|
||||
"""A control point for a drawing shape.
|
||||
|
||||
Control points define the position and properties of a shape.
|
||||
Different shapes have different numbers of control points.
|
||||
"""
|
||||
time: int = Field(..., description="Unix timestamp in seconds")
|
||||
price: float = Field(..., description="Price level")
|
||||
# Optional channel for multi-point shapes (e.g., parallel channels)
|
||||
channel: Optional[str] = Field(default=None, description="Channel identifier for multi-point shapes")
|
||||
|
||||
|
||||
class Shape(BaseModel):
|
||||
"""A TradingView drawing shape/study.
|
||||
|
||||
Represents any drawing the user creates on the chart (trendlines,
|
||||
horizontal lines, rectangles, Fibonacci retracements, etc.)
|
||||
"""
|
||||
id: str = Field(..., description="Unique identifier for the shape")
|
||||
type: str = Field(..., description="Shape type (e.g., 'trendline', 'horizontal_line', 'rectangle', 'fibonacci')")
|
||||
points: List[ControlPoint] = Field(default_factory=list, description="Control points that define the shape")
|
||||
|
||||
# Visual properties
|
||||
color: Optional[str] = Field(default=None, description="Shape color (hex or color name)")
|
||||
line_width: Optional[int] = Field(default=1, description="Line width in pixels")
|
||||
line_style: Optional[str] = Field(default="solid", description="Line style: 'solid', 'dashed', 'dotted'")
|
||||
|
||||
# Shape-specific properties stored as flexible dict
|
||||
properties: Dict[str, Any] = Field(default_factory=dict, description="Additional shape-specific properties")
|
||||
|
||||
# Metadata
|
||||
symbol: Optional[str] = Field(default=None, description="Symbol this shape is drawn on")
|
||||
created_at: Optional[int] = Field(default=None, description="Creation timestamp (Unix seconds)")
|
||||
modified_at: Optional[int] = Field(default=None, description="Last modification timestamp (Unix seconds)")
|
||||
original_id: Optional[str] = Field(default=None, description="Original ID from backend/agent before TradingView assigns its own ID")
|
||||
|
||||
|
||||
class ShapeCollection(BaseModel):
|
||||
"""Collection of all shapes/drawings on the chart."""
|
||||
shapes: Dict[str, Shape] = Field(default_factory=dict, description="Dictionary of shapes keyed by ID")
|
||||
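A small sketch of a two-point trendline (illustrative values; the import path is assumed from the file location):

from schema.shape import ControlPoint, Shape, ShapeCollection

trendline = Shape(
    id="shape-1",
    type="trendline",
    points=[
        ControlPoint(time=1_700_000_000, price=61250.0),
        ControlPoint(time=1_700_086_400, price=63400.0),
    ],
    color="#2962FF",
    line_width=2,
    symbol="BINANCE:BTC/USDT",
)

drawings = ShapeCollection(shapes={trendline.id: trendline})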
40
backend.old/src/secrets_manager/__init__.py
Normal file
40
backend.old/src/secrets_manager/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Encrypted secrets management with master password protection.
|
||||
|
||||
This module provides secure storage for sensitive configuration like API keys,
|
||||
using Argon2id for password-based key derivation and Fernet (AES-128-CBC with HMAC-SHA256) for encryption.
|
||||
|
||||
Basic usage:
|
||||
from secrets_manager import SecretsStore
|
||||
|
||||
# First time setup
|
||||
store = SecretsStore()
|
||||
store.initialize("my-master-password")
|
||||
store.set("ANTHROPIC_API_KEY", "sk-ant-...")
|
||||
|
||||
# Later usage
|
||||
store = SecretsStore()
|
||||
store.unlock("my-master-password")
|
||||
api_key = store.get("ANTHROPIC_API_KEY")
|
||||
|
||||
Command-line interface:
|
||||
python -m secrets_manager.cli init
|
||||
python -m secrets_manager.cli set KEY VALUE
|
||||
python -m secrets_manager.cli get KEY
|
||||
python -m secrets_manager.cli list
|
||||
python -m secrets_manager.cli change-password
|
||||
"""
|
||||
|
||||
from .store import (
|
||||
SecretsStore,
|
||||
SecretsStoreError,
|
||||
SecretsStoreLocked,
|
||||
InvalidMasterPassword,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SecretsStore",
|
||||
"SecretsStoreError",
|
||||
"SecretsStoreLocked",
|
||||
"InvalidMasterPassword",
|
||||
]
|
||||
374
backend.old/src/secrets_manager/cli.py
Normal file
374
backend.old/src/secrets_manager/cli.py
Normal file
@@ -0,0 +1,374 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Command-line interface for managing the encrypted secrets store.
|
||||
|
||||
Usage:
|
||||
python -m secrets_manager.cli init              # Initialize new secrets store
python -m secrets_manager.cli set KEY VALUE     # Set a secret
python -m secrets_manager.cli get KEY           # Get a secret
python -m secrets_manager.cli delete KEY        # Delete a secret
python -m secrets_manager.cli list              # List all secret keys
python -m secrets_manager.cli change-password   # Change master password
python -m secrets_manager.cli export FILE       # Export encrypted backup
python -m secrets_manager.cli import FILE       # Import encrypted backup
python -m secrets_manager.cli migrate-from-env  # Migrate secrets from .env file
|
||||
"""
|
||||
import sys
|
||||
import argparse
|
||||
import getpass
|
||||
from pathlib import Path
|
||||
|
||||
from .store import SecretsStore, SecretsStoreError, InvalidMasterPassword
|
||||
|
||||
|
||||
def get_password(prompt: str = "Master password: ", confirm: bool = False) -> str:
|
||||
"""
|
||||
Securely get password from user.
|
||||
|
||||
Args:
|
||||
prompt: Password prompt
|
||||
confirm: If True, ask for confirmation
|
||||
|
||||
Returns:
|
||||
Password string
|
||||
"""
|
||||
password = getpass.getpass(prompt)
|
||||
|
||||
if confirm:
|
||||
confirm_password = getpass.getpass("Confirm password: ")
|
||||
if password != confirm_password:
|
||||
print("Error: Passwords do not match", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
return password
|
||||
|
||||
|
||||
def cmd_init(args):
|
||||
"""Initialize a new secrets store."""
|
||||
store = SecretsStore()
|
||||
|
||||
if store.is_initialized:
|
||||
print("Error: Secrets store is already initialized", file=sys.stderr)
|
||||
print(f"Location: {store.secrets_file}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
password = get_password("Create master password: ", confirm=True)
|
||||
|
||||
if len(password) < 8:
|
||||
print("Error: Password must be at least 8 characters", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
store.initialize(password)
|
||||
print(f"Secrets store initialized at {store.secrets_file}")
|
||||
|
||||
|
||||
def cmd_set(args):
|
||||
"""Set a secret value."""
|
||||
store = SecretsStore()
|
||||
|
||||
if not store.is_initialized:
|
||||
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
password = get_password()
|
||||
|
||||
try:
|
||||
store.unlock(password)
|
||||
except InvalidMasterPassword:
|
||||
print("Error: Invalid master password", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
store.set(args.key, args.value)
|
||||
print(f"✓ Secret '{args.key}' saved")
|
||||
|
||||
|
||||
def cmd_get(args):
|
||||
"""Get a secret value."""
|
||||
store = SecretsStore()
|
||||
|
||||
if not store.is_initialized:
|
||||
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
password = get_password()
|
||||
|
||||
try:
|
||||
store.unlock(password)
|
||||
except InvalidMasterPassword:
|
||||
print("Error: Invalid master password", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
value = store.get(args.key)
|
||||
if value is None:
|
||||
print(f"Error: Secret '{args.key}' not found", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Print to stdout (can be captured)
|
||||
print(value)
|
||||
|
||||
|
||||
def cmd_delete(args):
|
||||
"""Delete a secret."""
|
||||
store = SecretsStore()
|
||||
|
||||
if not store.is_initialized:
|
||||
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
password = get_password()
|
||||
|
||||
try:
|
||||
store.unlock(password)
|
||||
except InvalidMasterPassword:
|
||||
print("Error: Invalid master password", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if store.delete(args.key):
|
||||
print(f"✓ Secret '{args.key}' deleted")
|
||||
else:
|
||||
print(f"Error: Secret '{args.key}' not found", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def cmd_list(args):
|
||||
"""List all secret keys."""
|
||||
store = SecretsStore()
|
||||
|
||||
if not store.is_initialized:
|
||||
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
password = get_password()
|
||||
|
||||
try:
|
||||
store.unlock(password)
|
||||
except InvalidMasterPassword:
|
||||
print("Error: Invalid master password", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
keys = store.list_keys()
|
||||
|
||||
if not keys:
|
||||
print("No secrets stored")
|
||||
else:
|
||||
print(f"Stored secrets ({len(keys)}):")
|
||||
for key in sorted(keys):
|
||||
# Show the key and a truncated value preview for verification
|
||||
value = store.get(key)
|
||||
value_str = str(value)
|
||||
value_preview = value_str[:50] + "..." if len(value_str) > 50 else value_str
|
||||
print(f" {key}: {value_preview}")
|
||||
|
||||
|
||||
def cmd_change_password(args):
|
||||
"""Change the master password."""
|
||||
store = SecretsStore()
|
||||
|
||||
if not store.is_initialized:
|
||||
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
current_password = get_password("Current master password: ")
|
||||
new_password = get_password("New master password: ", confirm=True)
|
||||
|
||||
if len(new_password) < 8:
|
||||
print("Error: Password must be at least 8 characters", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
store.change_master_password(current_password, new_password)
|
||||
except InvalidMasterPassword:
|
||||
print("Error: Invalid current password", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def cmd_export(args):
|
||||
"""Export encrypted secrets to a backup file."""
|
||||
store = SecretsStore()
|
||||
|
||||
if not store.is_initialized:
|
||||
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
output_path = Path(args.file)
|
||||
|
||||
if output_path.exists() and not args.force:
|
||||
print(f"Error: File {output_path} already exists. Use --force to overwrite.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
store.export_encrypted(output_path)
|
||||
except SecretsStoreError as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def cmd_import(args):
|
||||
"""Import encrypted secrets from a backup file."""
|
||||
store = SecretsStore()
|
||||
|
||||
if not store.is_initialized:
|
||||
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
input_path = Path(args.file)
|
||||
|
||||
if not input_path.exists():
|
||||
print(f"Error: File {input_path} does not exist", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
password = get_password()
|
||||
|
||||
try:
|
||||
store.import_encrypted(input_path, password)
|
||||
except InvalidMasterPassword:
|
||||
print("Error: Invalid master password or incompatible backup", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
except SecretsStoreError as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def cmd_migrate_from_env(args):
|
||||
"""Migrate secrets from .env file to encrypted store."""
|
||||
store = SecretsStore()
|
||||
|
||||
if not store.is_initialized:
|
||||
print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Look for .env file
|
||||
backend_root = Path(__file__).parent.parent.parent
|
||||
env_file = backend_root / ".env"
|
||||
|
||||
if not env_file.exists():
|
||||
print(f"Error: .env file not found at {env_file}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
password = get_password()
|
||||
|
||||
try:
|
||||
store.unlock(password)
|
||||
except InvalidMasterPassword:
|
||||
print("Error: Invalid master password", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Parse .env file (simple parser - doesn't handle all edge cases)
|
||||
migrated = 0
|
||||
skipped = 0
|
||||
|
||||
with open(env_file) as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
line = line.strip()
|
||||
|
||||
# Skip empty lines and comments
|
||||
if not line or line.startswith('#'):
|
||||
continue
|
||||
|
||||
# Parse KEY=VALUE format
|
||||
if '=' not in line:
|
||||
print(f"Warning: Skipping invalid line {line_num}: {line}", file=sys.stderr)
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
key, value = line.split('=', 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# Remove quotes if present
|
||||
if value.startswith('"') and value.endswith('"'):
|
||||
value = value[1:-1]
|
||||
elif value.startswith("'") and value.endswith("'"):
|
||||
value = value[1:-1]
|
||||
|
||||
# Check if key already exists
|
||||
existing = store.get(key)
|
||||
if existing is not None:
|
||||
print(f"Warning: Secret '{key}' already exists, skipping", file=sys.stderr)
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
store.set(key, value)
|
||||
print(f"✓ Migrated: {key}")
|
||||
migrated += 1
|
||||
|
||||
print(f"\nMigration complete: {migrated} secrets migrated, {skipped} skipped")
|
||||
|
||||
if not args.keep_env:
|
||||
# Ask for confirmation before deleting .env
|
||||
confirm = input(f"\nDelete {env_file}? [y/N]: ").strip().lower()
|
||||
if confirm == 'y':
|
||||
env_file.unlink()
|
||||
print(f"✓ Deleted {env_file}")
|
||||
else:
|
||||
print(f"Kept {env_file} (consider deleting it manually)")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main CLI entry point."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Manage encrypted secrets store",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest='command', help='Command to run')
|
||||
subparsers.required = True
|
||||
|
||||
# init
|
||||
parser_init = subparsers.add_parser('init', help='Initialize new secrets store')
|
||||
parser_init.set_defaults(func=cmd_init)
|
||||
|
||||
# set
|
||||
parser_set = subparsers.add_parser('set', help='Set a secret value')
|
||||
parser_set.add_argument('key', help='Secret key name')
|
||||
parser_set.add_argument('value', help='Secret value')
|
||||
parser_set.set_defaults(func=cmd_set)
|
||||
|
||||
# get
|
||||
parser_get = subparsers.add_parser('get', help='Get a secret value')
|
||||
parser_get.add_argument('key', help='Secret key name')
|
||||
parser_get.set_defaults(func=cmd_get)
|
||||
|
||||
# delete
|
||||
parser_delete = subparsers.add_parser('delete', help='Delete a secret')
|
||||
parser_delete.add_argument('key', help='Secret key name')
|
||||
parser_delete.set_defaults(func=cmd_delete)
|
||||
|
||||
# list
|
||||
parser_list = subparsers.add_parser('list', help='List all secret keys')
|
||||
parser_list.set_defaults(func=cmd_list)
|
||||
|
||||
# change-password
|
||||
parser_change = subparsers.add_parser('change-password', help='Change master password')
|
||||
parser_change.set_defaults(func=cmd_change_password)
|
||||
|
||||
# export
|
||||
parser_export = subparsers.add_parser('export', help='Export encrypted backup')
|
||||
parser_export.add_argument('file', help='Output file path')
|
||||
parser_export.add_argument('--force', action='store_true', help='Overwrite existing file')
|
||||
parser_export.set_defaults(func=cmd_export)
|
||||
|
||||
# import
|
||||
parser_import = subparsers.add_parser('import', help='Import encrypted backup')
|
||||
parser_import.add_argument('file', help='Input file path')
|
||||
parser_import.set_defaults(func=cmd_import)
|
||||
|
||||
# migrate-from-env
|
||||
parser_migrate = subparsers.add_parser('migrate-from-env', help='Migrate from .env file')
|
||||
parser_migrate.add_argument('--keep-env', action='store_true', help='Keep .env file after migration')
|
||||
parser_migrate.set_defaults(func=cmd_migrate_from_env)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
args.func(args)
|
||||
except KeyboardInterrupt:
|
||||
print("\nAborted", file=sys.stderr)
|
||||
sys.exit(130)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
144
backend.old/src/secrets_manager/crypto.py
Normal file
144
backend.old/src/secrets_manager/crypto.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Cryptographic utilities for secrets management.
|
||||
|
||||
Uses Argon2id for password-based key derivation and Fernet for encryption.
|
||||
"""
|
||||
import os
|
||||
import secrets as secrets_module
|
||||
from typing import Tuple
|
||||
|
||||
from argon2 import PasswordHasher
|
||||
from argon2.low_level import hash_secret_raw, Type
|
||||
from cryptography.fernet import Fernet
|
||||
import base64
|
||||
|
||||
|
||||
# Argon2id parameters (in line with RFC 9106 guidance for password-based KDFs)
|
||||
# These provide strong defense against GPU/ASIC attacks
|
||||
ARGON2_TIME_COST = 3 # iterations
|
||||
ARGON2_MEMORY_COST = 65536 # 64 MB
|
||||
ARGON2_PARALLELISM = 4 # threads
|
||||
ARGON2_HASH_LENGTH = 32 # bytes (256 bits for Fernet key)
|
||||
ARGON2_SALT_LENGTH = 16 # bytes (128 bits)
|
||||
|
||||
|
||||
def generate_salt() -> bytes:
|
||||
"""Generate a cryptographically secure random salt."""
|
||||
return secrets_module.token_bytes(ARGON2_SALT_LENGTH)
|
||||
|
||||
|
||||
def derive_key_from_password(password: str, salt: bytes) -> bytes:
|
||||
"""
|
||||
Derive an encryption key from a password using Argon2id.
|
||||
|
||||
Args:
|
||||
password: The master password
|
||||
salt: The salt (must be consistent for the same password to work)
|
||||
|
||||
Returns:
|
||||
32-byte key suitable for Fernet encryption
|
||||
"""
|
||||
password_bytes = password.encode('utf-8')
|
||||
|
||||
# Use Argon2id (hybrid mode - best of Argon2i and Argon2d)
|
||||
raw_hash = hash_secret_raw(
|
||||
secret=password_bytes,
|
||||
salt=salt,
|
||||
time_cost=ARGON2_TIME_COST,
|
||||
memory_cost=ARGON2_MEMORY_COST,
|
||||
parallelism=ARGON2_PARALLELISM,
|
||||
hash_len=ARGON2_HASH_LENGTH,
|
||||
type=Type.ID # Argon2id
|
||||
)
|
||||
|
||||
return raw_hash
|
||||
|
||||
|
||||
def create_fernet(key: bytes) -> Fernet:
|
||||
"""
|
||||
Create a Fernet cipher instance from a raw key.
|
||||
|
||||
Args:
|
||||
key: 32-byte raw key from Argon2id
|
||||
|
||||
Returns:
|
||||
Fernet instance for encryption/decryption
|
||||
"""
|
||||
# Fernet requires a URL-safe base64-encoded 32-byte key
|
||||
fernet_key = base64.urlsafe_b64encode(key)
|
||||
return Fernet(fernet_key)
|
||||
|
||||
|
||||
def encrypt_data(data: bytes, key: bytes) -> bytes:
|
||||
"""
|
||||
Encrypt data using Fernet (AES-128-CBC with HMAC-SHA256).
|
||||
|
||||
Args:
|
||||
data: Raw bytes to encrypt
|
||||
key: 32-byte encryption key
|
||||
|
||||
Returns:
|
||||
Encrypted data (includes IV and auth tag)
|
||||
"""
|
||||
fernet = create_fernet(key)
|
||||
return fernet.encrypt(data)
|
||||
|
||||
|
||||
def decrypt_data(encrypted_data: bytes, key: bytes) -> bytes:
|
||||
"""
|
||||
Decrypt data using Fernet.
|
||||
|
||||
Args:
|
||||
encrypted_data: Encrypted bytes from encrypt_data
|
||||
key: 32-byte encryption key (must match encryption key)
|
||||
|
||||
Returns:
|
||||
Decrypted raw bytes
|
||||
|
||||
Raises:
|
||||
cryptography.fernet.InvalidToken: If decryption fails (wrong key/corrupted data)
|
||||
"""
|
||||
fernet = create_fernet(key)
|
||||
return fernet.decrypt(encrypted_data)
|
||||
|
||||
|
||||
def create_verification_hash(password: str, salt: bytes) -> str:
|
||||
"""
|
||||
Create a verification hash to check if a password is correct.
|
||||
|
||||
This is NOT for storing the password - it's for verifying the password
|
||||
unlocks the correct key without trying to decrypt the entire secrets file.
|
||||
|
||||
Args:
|
||||
password: The master password
|
||||
salt: The salt used for key derivation
|
||||
|
||||
Returns:
|
||||
Base64-encoded hash for verification
|
||||
"""
|
||||
# Derive the key from the password
key = derive_key_from_password(password, salt)

# Note: this stores the first 16 bytes of the derived key verbatim (base64-encoded)
# rather than a hash of it; with Fernet those bytes double as the HMAC signing half
# of the key, so the verification value should be treated as sensitive.
|
||||
verification = base64.b64encode(key[:16]).decode('ascii')
|
||||
|
||||
return verification
|
||||
|
||||
|
||||
def verify_password(password: str, salt: bytes, verification_hash: str) -> bool:
|
||||
"""
|
||||
Verify a password against a verification hash.
|
||||
|
||||
Args:
|
||||
password: Password to verify
|
||||
salt: Salt used for key derivation
|
||||
verification_hash: Expected verification hash
|
||||
|
||||
Returns:
|
||||
True if password is correct, False otherwise
|
||||
"""
|
||||
computed_hash = create_verification_hash(password, salt)
|
||||
|
||||
# Constant-time comparison to prevent timing attacks
|
||||
return secrets_module.compare_digest(computed_hash, verification_hash)
|
||||
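A minimal round-trip sketch using the helpers above (password and plaintext are illustrative):

salt = generate_salt()
key = derive_key_from_password("correct horse battery staple", salt)

token = encrypt_data(b'{"ANTHROPIC_API_KEY": "sk-ant-example"}', key)
assert decrypt_data(token, key) == b'{"ANTHROPIC_API_KEY": "sk-ant-example"}'

verification = create_verification_hash("correct horse battery staple", salt)
assert verify_password("correct horse battery staple", salt, verification)
assert not verify_password("wrong password", salt, verification)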
406
backend.old/src/secrets_manager/store.py
Normal file
406
backend.old/src/secrets_manager/store.py
Normal file
@@ -0,0 +1,406 @@
|
||||
"""
|
||||
Encrypted secrets store with master password protection.
|
||||
|
||||
The secrets are stored in an encrypted file, with the encryption key derived
|
||||
from a master password using Argon2id. The master password can be changed
|
||||
without re-encrypting all secrets.
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import stat
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
from cryptography.fernet import InvalidToken
|
||||
|
||||
from .crypto import (
|
||||
generate_salt,
|
||||
derive_key_from_password,
|
||||
encrypt_data,
|
||||
decrypt_data,
|
||||
create_verification_hash,
|
||||
verify_password,
|
||||
)
|
||||
|
||||
|
||||
class SecretsStoreError(Exception):
|
||||
"""Base exception for secrets store errors."""
|
||||
pass
|
||||
|
||||
|
||||
class SecretsStoreLocked(SecretsStoreError):
|
||||
"""Raised when trying to access secrets while store is locked."""
|
||||
pass
|
||||
|
||||
|
||||
class InvalidMasterPassword(SecretsStoreError):
|
||||
"""Raised when master password is incorrect."""
|
||||
pass
|
||||
|
||||
|
||||
class SecretsStore:
|
||||
"""
|
||||
Encrypted secrets store with master password protection.
|
||||
|
||||
Usage:
|
||||
# Initialize (first time)
|
||||
store = SecretsStore()
|
||||
store.initialize("my-secure-password")
|
||||
|
||||
# Unlock
|
||||
store = SecretsStore()
|
||||
store.unlock("my-secure-password")
|
||||
|
||||
# Access secrets
|
||||
api_key = store.get("ANTHROPIC_API_KEY")
|
||||
store.set("NEW_SECRET", "secret-value")
|
||||
|
||||
# Change master password
|
||||
store.change_master_password("my-secure-password", "new-password")
|
||||
|
||||
# Lock when done
|
||||
store.lock()
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir: Optional[Path] = None):
|
||||
"""
|
||||
Initialize secrets store.
|
||||
|
||||
Args:
|
||||
data_dir: Directory for secrets files (defaults to backend/data)
|
||||
"""
|
||||
if data_dir is None:
|
||||
# Default to backend/data
|
||||
backend_root = Path(__file__).parent.parent.parent
|
||||
data_dir = backend_root / "data"
|
||||
|
||||
self.data_dir = Path(data_dir)
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.master_key_file = self.data_dir / ".master.key"
|
||||
self.secrets_file = self.data_dir / "secrets.enc"
|
||||
|
||||
# Runtime state
|
||||
self._encryption_key: Optional[bytes] = None
|
||||
self._secrets: Optional[Dict[str, Any]] = None
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if the secrets store has been initialized."""
|
||||
return self.master_key_file.exists()
|
||||
|
||||
@property
|
||||
def is_unlocked(self) -> bool:
|
||||
"""Check if the secrets store is currently unlocked."""
|
||||
return self._encryption_key is not None
|
||||
|
||||
def initialize(self, master_password: str) -> None:
|
||||
"""
|
||||
Initialize the secrets store with a master password.
|
||||
|
||||
This should only be called once when setting up the store.
|
||||
|
||||
Args:
|
||||
master_password: The master password to protect the secrets
|
||||
|
||||
Raises:
|
||||
SecretsStoreError: If store is already initialized
|
||||
"""
|
||||
if self.is_initialized:
|
||||
raise SecretsStoreError(
|
||||
"Secrets store is already initialized. "
|
||||
"Use unlock() to access it or change_master_password() to change the password."
|
||||
)
|
||||
|
||||
# Generate a new random salt
|
||||
salt = generate_salt()
|
||||
|
||||
# Derive encryption key
|
||||
encryption_key = derive_key_from_password(master_password, salt)
|
||||
|
||||
# Create verification hash
|
||||
verification_hash = create_verification_hash(master_password, salt)
|
||||
|
||||
# Store salt and verification hash
|
||||
master_key_data = {
|
||||
"salt": salt.hex(),
|
||||
"verification": verification_hash,
|
||||
}
|
||||
|
||||
self.master_key_file.write_text(json.dumps(master_key_data, indent=2))
|
||||
|
||||
# Set restrictive permissions (owner read/write only)
|
||||
os.chmod(self.master_key_file, stat.S_IRUSR | stat.S_IWUSR)
|
||||
|
||||
# Initialize empty secrets
|
||||
self._encryption_key = encryption_key
|
||||
self._secrets = {}
|
||||
self._save_secrets()
|
||||
|
||||
print(f"✓ Secrets store initialized at {self.secrets_file}")
|
||||
|
||||
def unlock(self, master_password: str) -> None:
|
||||
"""
|
||||
Unlock the secrets store with the master password.
|
||||
|
||||
Args:
|
||||
master_password: The master password
|
||||
|
||||
Raises:
|
||||
SecretsStoreError: If store is not initialized
|
||||
InvalidMasterPassword: If password is incorrect
|
||||
"""
|
||||
if not self.is_initialized:
|
||||
raise SecretsStoreError(
|
||||
"Secrets store is not initialized. Call initialize() first."
|
||||
)
|
||||
|
||||
# Load salt and verification hash
|
||||
master_key_data = json.loads(self.master_key_file.read_text())
|
||||
salt = bytes.fromhex(master_key_data["salt"])
|
||||
verification_hash = master_key_data["verification"]
|
||||
|
||||
# Verify password
|
||||
if not verify_password(master_password, salt, verification_hash):
|
||||
raise InvalidMasterPassword("Invalid master password")
|
||||
|
||||
# Derive encryption key
|
||||
encryption_key = derive_key_from_password(master_password, salt)
|
||||
|
||||
# Load and decrypt secrets
|
||||
if self.secrets_file.exists():
|
||||
try:
|
||||
encrypted_data = self.secrets_file.read_bytes()
|
||||
decrypted_data = decrypt_data(encrypted_data, encryption_key)
|
||||
self._secrets = json.loads(decrypted_data.decode('utf-8'))
|
||||
except InvalidToken:
|
||||
raise InvalidMasterPassword("Failed to decrypt secrets (invalid password)")
|
||||
except json.JSONDecodeError as e:
|
||||
raise SecretsStoreError(f"Corrupted secrets file: {e}")
|
||||
else:
|
||||
# No secrets file yet (fresh initialization)
|
||||
self._secrets = {}
|
||||
|
||||
self._encryption_key = encryption_key
|
||||
print(f"✓ Secrets store unlocked ({len(self._secrets)} secrets)")
|
||||
|
||||
def lock(self) -> None:
|
||||
"""Lock the secrets store (clear decrypted data from memory)."""
|
||||
self._encryption_key = None
|
||||
self._secrets = None
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Get a secret value.
|
||||
|
||||
Args:
|
||||
key: Secret key name
|
||||
default: Default value if key doesn't exist
|
||||
|
||||
Returns:
|
||||
Secret value or default
|
||||
|
||||
Raises:
|
||||
SecretsStoreLocked: If store is locked
|
||||
"""
|
||||
if not self.is_unlocked:
|
||||
raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")
|
||||
|
||||
return self._secrets.get(key, default)
|
||||
|
||||
def set(self, key: str, value: Any) -> None:
|
||||
"""
|
||||
Set a secret value.
|
||||
|
||||
Args:
|
||||
key: Secret key name
|
||||
value: Secret value (must be JSON-serializable)
|
||||
|
||||
Raises:
|
||||
SecretsStoreLocked: If store is locked
|
||||
"""
|
||||
if not self.is_unlocked:
|
||||
raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")
|
||||
|
||||
self._secrets[key] = value
|
||||
self._save_secrets()
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete a secret.
|
||||
|
||||
Args:
|
||||
key: Secret key name
|
||||
|
||||
Returns:
|
||||
True if secret existed and was deleted, False otherwise
|
||||
|
||||
Raises:
|
||||
SecretsStoreLocked: If store is locked
|
||||
"""
|
||||
if not self.is_unlocked:
|
||||
raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")
|
||||
|
||||
if key in self._secrets:
|
||||
del self._secrets[key]
|
||||
self._save_secrets()
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_keys(self) -> list[str]:
|
||||
"""
|
||||
List all secret keys.
|
||||
|
||||
Returns:
|
||||
List of secret keys
|
||||
|
||||
Raises:
|
||||
SecretsStoreLocked: If store is locked
|
||||
"""
|
||||
if not self.is_unlocked:
|
||||
raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")
|
||||
|
||||
return list(self._secrets.keys())
|
||||
|
||||
def change_master_password(self, current_password: str, new_password: str) -> None:
|
||||
"""
|
||||
Change the master password.
|
||||
|
||||
This re-encrypts the secrets with a new key derived from the new password.
|
||||
|
||||
Args:
|
||||
current_password: Current master password
|
||||
new_password: New master password
|
||||
|
||||
Raises:
|
||||
InvalidMasterPassword: If current password is incorrect
|
||||
"""
|
||||
# ALWAYS verify current password before changing
|
||||
# Load salt and verification hash
|
||||
if not self.is_initialized:
|
||||
raise SecretsStoreError(
|
||||
"Secrets store is not initialized. Call initialize() first."
|
||||
)
|
||||
|
||||
master_key_data = json.loads(self.master_key_file.read_text())
|
||||
salt = bytes.fromhex(master_key_data["salt"])
|
||||
verification_hash = master_key_data["verification"]
|
||||
|
||||
# Verify current password is correct
|
||||
if not verify_password(current_password, salt, verification_hash):
|
||||
raise InvalidMasterPassword("Invalid current password")
|
||||
|
||||
# Unlock if needed to access secrets
|
||||
was_unlocked = self.is_unlocked
|
||||
if not was_unlocked:
|
||||
# Store is locked, so unlock with current password
|
||||
# (we already verified it above, so this will succeed)
|
||||
encryption_key = derive_key_from_password(current_password, salt)
|
||||
|
||||
# Load and decrypt secrets
|
||||
if self.secrets_file.exists():
|
||||
encrypted_data = self.secrets_file.read_bytes()
|
||||
decrypted_data = decrypt_data(encrypted_data, encryption_key)
|
||||
self._secrets = json.loads(decrypted_data.decode('utf-8'))
|
||||
else:
|
||||
self._secrets = {}
|
||||
|
||||
self._encryption_key = encryption_key
|
||||
|
||||
# Generate new salt
|
||||
new_salt = generate_salt()
|
||||
|
||||
# Derive new encryption key
|
||||
new_encryption_key = derive_key_from_password(new_password, new_salt)
|
||||
|
||||
# Create new verification hash
|
||||
new_verification_hash = create_verification_hash(new_password, new_salt)
|
||||
|
||||
# Update master key file
|
||||
master_key_data = {
|
||||
"salt": new_salt.hex(),
|
||||
"verification": new_verification_hash,
|
||||
}
|
||||
self.master_key_file.write_text(json.dumps(master_key_data, indent=2))
|
||||
os.chmod(self.master_key_file, stat.S_IRUSR | stat.S_IWUSR)
|
||||
|
||||
# Re-encrypt secrets with the new key
self._encryption_key = new_encryption_key
|
||||
self._save_secrets()
|
||||
|
||||
print("✓ Master password changed successfully")
|
||||
|
||||
# Lock if it wasn't unlocked before
|
||||
if not was_unlocked:
|
||||
self.lock()
|
||||
|
||||
def _save_secrets(self) -> None:
|
||||
"""Save secrets to encrypted file."""
|
||||
if not self.is_unlocked:
|
||||
raise SecretsStoreLocked("Cannot save while locked")
|
||||
|
||||
# Serialize secrets to JSON
|
||||
secrets_json = json.dumps(self._secrets, indent=2)
|
||||
secrets_bytes = secrets_json.encode('utf-8')
|
||||
|
||||
# Encrypt
|
||||
encrypted_data = encrypt_data(secrets_bytes, self._encryption_key)
|
||||
|
||||
# Write to file
|
||||
self.secrets_file.write_bytes(encrypted_data)
|
||||
|
||||
# Set restrictive permissions
|
||||
os.chmod(self.secrets_file, stat.S_IRUSR | stat.S_IWUSR)
|
||||
|
||||
def export_encrypted(self, output_path: Path) -> None:
|
||||
"""
|
||||
Export encrypted secrets to a file (for backup).
|
||||
|
||||
Args:
|
||||
output_path: Path to export file
|
||||
|
||||
Raises:
|
||||
SecretsStoreError: If secrets file doesn't exist
|
||||
"""
|
||||
if not self.secrets_file.exists():
|
||||
raise SecretsStoreError("No secrets to export")
|
||||
|
||||
import shutil
|
||||
shutil.copy2(self.secrets_file, output_path)
|
||||
print(f"✓ Encrypted secrets exported to {output_path}")
|
||||
|
||||
def import_encrypted(self, input_path: Path, master_password: str) -> None:
|
||||
"""
|
||||
Import encrypted secrets from a file.
|
||||
|
||||
This will verify the password can decrypt the import before replacing
|
||||
the current secrets.
|
||||
|
||||
Args:
|
||||
input_path: Path to import file
|
||||
master_password: Master password for the current store
|
||||
|
||||
Raises:
|
||||
InvalidMasterPassword: If password doesn't work with import
|
||||
"""
|
||||
if not self.is_unlocked:
|
||||
self.unlock(master_password)
|
||||
|
||||
# Try to decrypt the imported file with current key
|
||||
try:
|
||||
encrypted_data = Path(input_path).read_bytes()
|
||||
decrypted_data = decrypt_data(encrypted_data, self._encryption_key)
|
||||
imported_secrets = json.loads(decrypted_data.decode('utf-8'))
|
||||
except InvalidToken:
|
||||
raise InvalidMasterPassword(
|
||||
"Cannot decrypt imported secrets with current master password"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise SecretsStoreError(f"Corrupted import file: {e}")
|
||||
|
||||
# Replace secrets
|
||||
self._secrets = imported_secrets
|
||||
self._save_secrets()
|
||||
|
||||
print(f"✓ Imported {len(self._secrets)} secrets from {input_path}")
|
||||
42
backend.old/src/sync/protocol.py
Normal file
42
backend.old/src/sync/protocol.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AuthMessage(BaseModel):
|
||||
"""Authentication message (must be first message from client)"""
|
||||
type: Literal["auth"] = "auth"
|
||||
password: str
|
||||
confirm_password: Optional[str] = None # Required only for initialization
|
||||
change_to_password: Optional[str] = None # If provided, change password after auth
|
||||
confirm_new_password: Optional[str] = None # Required if change_to_password is set
|
||||
|
||||
class AuthResponseMessage(BaseModel):
|
||||
"""Authentication response from server"""
|
||||
type: Literal["auth_response"] = "auth_response"
|
||||
success: bool
|
||||
needs_confirmation: bool = False # True if this is first-time setup
|
||||
password_changed: bool = False # True if password was changed
|
||||
message: str
|
||||
|
||||
class SnapshotMessage(BaseModel):
|
||||
type: Literal["snapshot"] = "snapshot"
|
||||
store: str
|
||||
seq: int
|
||||
state: Dict[str, Any]
|
||||
|
||||
class PatchMessage(BaseModel):
|
||||
type: Literal["patch"] = "patch"
|
||||
store: str
|
||||
seq: int
|
||||
patch: List[Dict[str, Any]]
|
||||
|
||||
class HelloMessage(BaseModel):
|
||||
type: Literal["hello"] = "hello"
|
||||
seqs: Dict[str, int]
|
||||
|
||||
# Union type for all messages from backend to frontend
|
||||
BackendMessage = Union[SnapshotMessage, PatchMessage, AuthResponseMessage]
|
||||
|
||||
# Union type for all messages from frontend to backend
|
||||
FrontendMessage = Union[AuthMessage, HelloMessage, PatchMessage]
|
||||
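A short sketch of how these messages might be parsed and produced (pydantic v2 assumed; values illustrative):

from pydantic import TypeAdapter

incoming = TypeAdapter(FrontendMessage).validate_python(
    {"type": "hello", "seqs": {"chart_state": 3, "indicators": 0}}
)
assert isinstance(incoming, HelloMessage)

outgoing = PatchMessage(
    store="chart_state",
    seq=4,
    patch=[{"op": "replace", "path": "/symbol", "value": "BINANCE:ETH/USDT"}],
)
payload = outgoing.model_dump(mode="json")  # shape of the payload sent over the websocket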
246
backend.old/src/sync/registry.py
Normal file
246
backend.old/src/sync/registry.py
Normal file
@@ -0,0 +1,246 @@
|
||||
from collections import deque
|
||||
from typing import Any, Dict, List, Optional, Tuple, Deque
|
||||
|
||||
import jsonpatch
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sync.protocol import SnapshotMessage, PatchMessage
|
||||
|
||||
|
||||
class SyncEntry:
|
||||
def __init__(self, model: BaseModel, store_name: str, history_size: int = 50):
|
||||
self.model = model
|
||||
self.store_name = store_name
|
||||
self.seq = 0
|
||||
self.last_snapshot = model.model_dump(mode="json")
|
||||
self.history: Deque[Tuple[int, List[Dict[str, Any]]]] = deque(maxlen=history_size)
|
||||
|
||||
def compute_patch(self) -> Optional[List[Dict[str, Any]]]:
|
||||
current_state = self.model.model_dump(mode="json")
|
||||
patch = jsonpatch.make_patch(self.last_snapshot, current_state)
|
||||
if not patch.patch:
|
||||
return None
|
||||
return patch.patch
|
||||
|
||||
def commit_patch(self, patch: List[Dict[str, Any]]):
|
||||
self.seq += 1
|
||||
self.history.append((self.seq, patch))
|
||||
self.last_snapshot = self.model.model_dump(mode="json")
|
||||
|
||||
def catchup_patches(self, since_seq: int) -> Optional[List[Tuple[int, List[Dict[str, Any]]]]]:
|
||||
if since_seq == self.seq:
|
||||
return []
|
||||
|
||||
# Check if all patches from since_seq + 1 to self.seq are in history
|
||||
if not self.history or self.history[0][0] > since_seq + 1:
|
||||
return None
|
||||
|
||||
result = []
|
||||
for seq, patch in self.history:
|
||||
if seq > since_seq:
|
||||
result.append((seq, patch))
|
||||
return result
|
||||
|
||||
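# Illustrative sketch (not part of the original module): the patch lifecycle of a
# single SyncEntry, assuming ChartState from schema.chart_state:
#
#     state = ChartState()
#     entry = SyncEntry(state, "chart_state")
#
#     state.symbol = "BINANCE:ETH/USDT"
#     patch = entry.compute_patch()
#     # -> [{"op": "replace", "path": "/symbol", "value": "BINANCE:ETH/USDT"}]
#     entry.commit_patch(patch)     # seq becomes 1, snapshot refreshed
#
#     entry.catchup_patches(0)      # -> [(1, [...])] for a client last synced at seq 0
#     entry.catchup_patches(1)      # -> [] (client already up to date)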
class SyncRegistry:
|
||||
def __init__(self):
|
||||
self.entries: Dict[str, SyncEntry] = {}
|
||||
self.websocket: Optional[Any] = None # Expecting a FastAPI WebSocket or similar
|
||||
|
||||
def register(self, model: BaseModel, store_name: str):
|
||||
self.entries[store_name] = SyncEntry(model, store_name)
|
||||
|
||||
async def push_all(self):
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if not self.websocket:
|
||||
logger.warning("push_all: No websocket connected, cannot push updates")
|
||||
return
|
||||
|
||||
logger.info(f"push_all: Processing {len(self.entries)} store entries")
|
||||
for entry in self.entries.values():
|
||||
patch = entry.compute_patch()
|
||||
if patch:
|
||||
logger.info(f"push_all: Found patch for store '{entry.store_name}': {patch}")
|
||||
entry.commit_patch(patch)
|
||||
msg = PatchMessage(store=entry.store_name, seq=entry.seq, patch=patch)
|
||||
logger.info(f"push_all: Sending patch message for '{entry.store_name}' seq={entry.seq}")
|
||||
await self.websocket.send_json(msg.model_dump(mode="json"))
|
||||
logger.info(f"push_all: Patch sent successfully for '{entry.store_name}'")
|
||||
else:
|
||||
logger.debug(f"push_all: No changes detected for store '{entry.store_name}'")
|
||||
|
||||
async def sync_client(self, client_seqs: Dict[str, int]):
|
||||
if not self.websocket:
|
||||
return
|
||||
|
||||
for store_name, entry in self.entries.items():
|
||||
client_seq = client_seqs.get(store_name, -1)
|
||||
patches = entry.catchup_patches(client_seq)
|
||||
|
||||
if patches is not None:
|
||||
# Replay patches
|
||||
for seq, patch in patches:
|
||||
msg = PatchMessage(store=store_name, seq=seq, patch=patch)
|
||||
await self.websocket.send_json(msg.model_dump(mode="json"))
|
||||
else:
|
||||
# Send full snapshot
|
||||
msg = SnapshotMessage(
|
||||
store=store_name,
|
||||
seq=entry.seq,
|
||||
state=entry.model.model_dump(mode="json")
|
||||
)
|
||||
await self.websocket.send_json(msg.model_dump(mode="json"))
|
||||
|
||||
async def apply_client_patch(self, store_name: str, client_base_seq: int, patch: List[Dict[str, Any]]):
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger.info(f"apply_client_patch: store={store_name}, client_base_seq={client_base_seq}, patch={patch}")
|
||||
|
||||
entry = self.entries.get(store_name)
|
||||
if not entry:
|
||||
logger.warning(f"apply_client_patch: Store '{store_name}' not found in registry")
|
||||
return
|
||||
|
||||
logger.info(f"apply_client_patch: Current backend seq={entry.seq}")
|
||||
|
||||
try:
|
||||
if client_base_seq == entry.seq:
|
||||
# No conflict
|
||||
logger.info("apply_client_patch: No conflict - applying patch directly")
|
||||
current_state = entry.model.model_dump(mode="json")
|
||||
logger.info(f"apply_client_patch: Current state before patch: {current_state}")
|
||||
try:
|
||||
new_state = jsonpatch.apply_patch(current_state, patch)
|
||||
logger.info(f"apply_client_patch: New state after patch: {new_state}")
|
||||
self._update_model(entry.model, new_state)
|
||||
|
||||
# Verify the model was actually updated
|
||||
updated_state = entry.model.model_dump(mode="json")
|
||||
logger.info(f"apply_client_patch: Model state after _update_model: {updated_state}")
|
||||
|
||||
entry.commit_patch(patch)
|
||||
logger.info(f"apply_client_patch: Patch committed, new seq={entry.seq}")
|
||||
# Don't broadcast back to client - they already have this change
|
||||
# Broadcasting would cause an infinite loop
|
||||
logger.info("apply_client_patch: Not broadcasting back to originating client")
|
||||
except jsonpatch.JsonPatchConflict as e:
|
||||
logger.warning(f"apply_client_patch: Patch conflict on no-conflict path: {e}. Sending snapshot to resync.")
|
||||
# Send snapshot to force resync
|
||||
if self.websocket:
|
||||
msg = SnapshotMessage(
|
||||
store=entry.store_name,
|
||||
seq=entry.seq,
|
||||
state=entry.model.model_dump(mode="json")
|
||||
)
|
||||
await self.websocket.send_json(msg.model_dump(mode="json"))
|
||||
|
||||
elif client_base_seq < entry.seq:
|
||||
# Conflict! Frontend wins.
|
||||
# 1. Get backend patches since client_base_seq
|
||||
backend_patches = []
|
||||
for seq, p in entry.history:
|
||||
if seq > client_base_seq:
|
||||
backend_patches.append(p)
|
||||
|
||||
# 2. Resolve the conflict with a "frontend wins" merge. We cannot reconstruct
# the state as it was at client_base_seq (history stores only patches, not
# snapshots), so instead:
# - apply the frontend patch to the current authoritative model first, then
# - re-apply the backend patches, discarding any ops whose paths overlap
# the frontend's changed paths.
|
||||
|
||||
frontend_paths = {p['path'] for p in patch}
|
||||
|
||||
current_state = entry.model.model_dump(mode="json")
|
||||
# Apply frontend patch
|
||||
try:
|
||||
new_state = jsonpatch.apply_patch(current_state, patch)
|
||||
except jsonpatch.JsonPatchConflict as e:
|
||||
logger.warning(f"apply_client_patch: Failed to apply client patch during conflict resolution: {e}. Sending snapshot to resync.")
|
||||
# Send snapshot to force resync
|
||||
if self.websocket:
|
||||
msg = SnapshotMessage(
|
||||
store=entry.store_name,
|
||||
seq=entry.seq,
|
||||
state=entry.model.model_dump(mode="json")
|
||||
)
|
||||
await self.websocket.send_json(msg.model_dump(mode="json"))
|
||||
return
|
||||
|
||||
# Re-apply backend patches that don't overlap
|
||||
for b_patch in backend_patches:
|
||||
filtered_b_patch = [op for op in b_patch if op['path'] not in frontend_paths]
|
||||
if filtered_b_patch:
|
||||
try:
|
||||
new_state = jsonpatch.apply_patch(new_state, filtered_b_patch)
|
||||
except jsonpatch.JsonPatchConflict as e:
|
||||
logger.warning(f"apply_client_patch: Failed to apply backend patch during conflict resolution: {e}. Skipping this patch.")
|
||||
continue
|
||||
|
||||
self._update_model(entry.model, new_state)
|
||||
|
||||
# Commit the result as a single new patch
|
||||
# We need to compute what changed from last_snapshot to new_state
|
||||
final_patch = jsonpatch.make_patch(entry.last_snapshot, new_state).patch
|
||||
if final_patch:
|
||||
entry.commit_patch(final_patch)
|
||||
# Broadcast resolved state as snapshot to converge
|
||||
if self.websocket:
|
||||
msg = SnapshotMessage(
|
||||
store=entry.store_name,
|
||||
seq=entry.seq,
|
||||
state=entry.model.model_dump(mode="json")
|
||||
)
|
||||
await self.websocket.send_json(msg.model_dump(mode="json"))
|
||||
except Exception as e:
|
||||
logger.error(f"apply_client_patch: Unexpected error: {e}. Sending snapshot to resync.", exc_info=True)
|
||||
# Send snapshot to force resync
|
||||
if self.websocket:
|
||||
msg = SnapshotMessage(
|
||||
store=entry.store_name,
|
||||
seq=entry.seq,
|
||||
state=entry.model.model_dump(mode="json")
|
||||
)
|
||||
await self.websocket.send_json(msg.model_dump(mode="json"))
|
||||
|
||||
def _update_model(self, model: BaseModel, new_data: Dict[str, Any]):
|
||||
# Update model fields in-place to preserve references
|
||||
# This is important for dict fields that may be referenced elsewhere
|
||||
for field_name, field_info in model.model_fields.items():
|
||||
if field_name in new_data:
|
||||
new_value = new_data[field_name]
|
||||
current_value = getattr(model, field_name)
|
||||
|
||||
# For dict fields, update in-place instead of replacing
|
||||
if isinstance(current_value, dict) and isinstance(new_value, dict):
|
||||
self._deep_update_dict(current_value, new_value)
|
||||
else:
|
||||
# For other types, just set the new value
|
||||
setattr(model, field_name, new_value)
|
||||
|
||||
def _deep_update_dict(self, target: dict, source: dict):
|
||||
"""Deep update target dict with source dict, preserving nested dict references."""
|
||||
# Remove keys that are in target but not in source
|
||||
keys_to_remove = set(target.keys()) - set(source.keys())
|
||||
for key in keys_to_remove:
|
||||
del target[key]
|
||||
|
||||
# Update or add keys from source
|
||||
for key, source_value in source.items():
|
||||
if key in target:
|
||||
target_value = target[key]
|
||||
# If both are dicts, recursively update
|
||||
if isinstance(target_value, dict) and isinstance(source_value, dict):
|
||||
self._deep_update_dict(target_value, source_value)
|
||||
else:
|
||||
# Replace the value
|
||||
target[key] = source_value
|
||||
else:
|
||||
# Add new key
|
||||
target[key] = source_value
|
||||
216
backend.old/src/trigger/PRIORITIES.md
Normal file
@@ -0,0 +1,216 @@
|
||||
# Priority System
|
||||
|
||||
Simple tuple-based priorities for deterministic execution ordering.
|
||||
|
||||
## Basic Concept
|
||||
|
||||
Priorities are just **Python tuples**. Python compares tuples element-by-element, left-to-right:
|
||||
|
||||
```python
|
||||
(0, 1000, 5) < (0, 1001, 3) # True: 0==0, but 1000 < 1001
|
||||
(0, 1000, 5) < (1, 500, 2) # True: 0 < 1
|
||||
(0, 1000) < (0, 1000, 5) # True: shorter wins if equal so far
|
||||
```
|
||||
|
||||
**Lower values = higher priority** (processed first).
|
||||
|
||||
## Priority Categories
|
||||
|
||||
```python
|
||||
class Priority(IntEnum):
|
||||
DATA_SOURCE = 0 # Market data, real-time feeds
|
||||
TIMER = 1 # Scheduled tasks, cron jobs
|
||||
USER_AGENT = 2 # User-agent interactions (chat)
|
||||
USER_DATA_REQUEST = 3 # User data requests (charts)
|
||||
SYSTEM = 4 # Background tasks, cleanup
|
||||
LOW = 5 # Retries after conflicts
|
||||
```
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Simple Priority
|
||||
|
||||
```python
|
||||
# Just use the Priority enum
|
||||
trigger = MyTrigger("task", priority=Priority.SYSTEM)
|
||||
await queue.enqueue(trigger)
|
||||
|
||||
# Results in tuple: (4, queue_seq)
|
||||
```
|
||||
|
||||
### Compound Priority (Tuple)
|
||||
|
||||
```python
|
||||
# DataSource: sort by event time (older bars first)
|
||||
trigger = DataUpdateTrigger(
|
||||
source_name="binance",
|
||||
symbol="BTC/USDT",
|
||||
resolution="1m",
|
||||
bar_data={"time": 1678896000, "open": 50000, ...}
|
||||
)
|
||||
await queue.enqueue(trigger)
|
||||
|
||||
# Results in tuple: (0, 1678896000, queue_seq)
|
||||
# ^ ^ ^
|
||||
# | | Queue insertion order (FIFO)
|
||||
# | Event time (candle end time)
|
||||
# DATA_SOURCE priority
|
||||
```
|
||||
|
||||
### Manual Override
|
||||
|
||||
```python
|
||||
# Override at enqueue time
|
||||
await queue.enqueue(
|
||||
trigger,
|
||||
priority_override=(Priority.DATA_SOURCE, custom_time, custom_sort)
|
||||
)
|
||||
|
||||
# Queue appends queue_seq: (0, custom_time, custom_sort, queue_seq)
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Market Data (Process Chronologically)
|
||||
|
||||
```python
|
||||
# Bar from 10:00 → (0, 10:00_timestamp, queue_seq)
|
||||
# Bar from 10:05 → (0, 10:05_timestamp, queue_seq)
|
||||
#
|
||||
# 10:00 bar processes first (earlier event_time)
|
||||
|
||||
DataUpdateTrigger(
|
||||
...,
|
||||
bar_data={"time": event_timestamp, ...}
|
||||
)
|
||||
```
|
||||
|
||||
### User Messages (FIFO Order)
|
||||
|
||||
```python
|
||||
# Message #1 → (2, msg1_timestamp, queue_seq)
|
||||
# Message #2 → (2, msg2_timestamp, queue_seq)
|
||||
#
|
||||
# Message #1 processes first (earlier timestamp)
|
||||
|
||||
AgentTriggerHandler(
|
||||
session_id="user1",
|
||||
message_content="...",
|
||||
message_timestamp=unix_timestamp # Optional, defaults to now
|
||||
)
|
||||
```
|
||||
|
||||
### Scheduled Tasks (By Schedule Time)
|
||||
|
||||
```python
|
||||
# Job scheduled for 9 AM → (1, 9am_timestamp, queue_seq)
|
||||
# Job scheduled for 2 PM → (1, 2pm_timestamp, queue_seq)
|
||||
#
|
||||
# 9 AM job processes first
|
||||
|
||||
CronTrigger(
|
||||
name="morning_sync",
|
||||
inner_trigger=...,
|
||||
scheduled_time=scheduled_timestamp
|
||||
)
|
||||
```
|
||||
|
||||
## Execution Order Example
|
||||
|
||||
```
|
||||
Queue contains:
|
||||
1. DataSource (BTC @ 10:00) → (0, 10:00, 1)
|
||||
2. DataSource (BTC @ 10:05) → (0, 10:05, 2)
|
||||
3. Timer (scheduled 9 AM) → (1, 09:00, 3)
|
||||
4. User message #1 → (2, 14:30, 4)
|
||||
5. User message #2 → (2, 14:35, 5)
|
||||
|
||||
Dequeue order:
|
||||
1. DataSource (BTC @ 10:00) ← 0 < all others
|
||||
2. DataSource (BTC @ 10:05) ← 0 < all others, 10:05 > 10:00
|
||||
3. Timer (scheduled 9 AM) ← 1 < remaining
|
||||
4. User message #1 ← 2 < remaining, 14:30 < 14:35
|
||||
5. User message #2 ← last
|
||||
```
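Since ordering relies only on native tuple comparison, the dequeue order above can be sanity-checked with plain Python (a quick sketch; the timestamps are simplified to integers purely for illustration):

```python
# Same five queue entries as above, with simplified integer timestamps.
entries = [
    ((0, 1000, 1), "DataSource BTC @ 10:00"),
    ((0, 1005, 2), "DataSource BTC @ 10:05"),
    ((1, 900, 3), "Timer scheduled 9 AM"),
    ((2, 1430, 4), "User message #1"),
    ((2, 1435, 5), "User message #2"),
]

# sorted() compares the priority tuples element-by-element, giving the dequeue order.
for priority_tuple, label in sorted(entries):
    print(priority_tuple, label)
```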
|
||||
|
||||
## Short Tuple Wins
|
||||
|
||||
If tuples are equal up to the length of the shorter one, **the shorter tuple has higher priority**:
|
||||
|
||||
```python
|
||||
(0, 1000) < (0, 1000, 5) # True: shorter wins
|
||||
(0,) < (0, 1000) # True: shorter wins
|
||||
(Priority.DATA_SOURCE,) < (Priority.DATA_SOURCE, 1000) # True
|
||||
```
|
||||
|
||||
This is Python's default tuple comparison behavior. In practice, we always append `queue_seq`, so this rarely matters (all tuples end up the same length).
|
||||
|
||||
## Integration with Triggers
|
||||
|
||||
### Trigger Sets Its Own Priority
|
||||
|
||||
```python
|
||||
class MyTrigger(Trigger):
|
||||
def __init__(self, event_time):
|
||||
super().__init__(
|
||||
name="my_trigger",
|
||||
priority=Priority.DATA_SOURCE,
|
||||
priority_tuple=(Priority.DATA_SOURCE.value, event_time)
|
||||
)
|
||||
```
|
||||
|
||||
Queue appends `queue_seq` automatically:
|
||||
```python
|
||||
# Trigger's tuple: (0, event_time)
|
||||
# After enqueue: (0, event_time, queue_seq)
|
||||
```
|
||||
|
||||
### Override at Enqueue
|
||||
|
||||
```python
|
||||
# Ignore trigger's priority, use override
|
||||
await queue.enqueue(
|
||||
trigger,
|
||||
priority_override=(Priority.TIMER, scheduled_time)
|
||||
)
|
||||
```
|
||||
|
||||
## Why Tuples?
|
||||
|
||||
✅ **Simple**: No custom classes, just native Python tuples
|
||||
✅ **Flexible**: Add as many sort keys as needed
|
||||
✅ **Efficient**: Python's tuple comparison is highly optimized
|
||||
✅ **Readable**: it is obvious at a glance what `(0, 1000, 5)` means
|
||||
✅ **Debuggable**: Can print and inspect easily
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Old: CompoundPriority(primary=0, secondary=1000, tertiary=5)
|
||||
# New: (0, 1000, 5)
|
||||
|
||||
# Same semantics, much simpler!
|
||||
```
|
||||
|
||||
## Advanced: Custom Sorting
|
||||
|
||||
Want to sort by multiple factors? Just add more elements:
|
||||
|
||||
```python
|
||||
# Sort by: priority → symbol → event_time → queue_seq
|
||||
priority_tuple = (
|
||||
Priority.DATA_SOURCE.value,
|
||||
symbol_id, # e.g., hash("BTC/USDT")
|
||||
event_time,
|
||||
# queue_seq appended by queue
|
||||
)
|
||||
```
|
||||
|
||||
## Summary
|
||||
|
||||
- **Priorities are tuples**: `(primary, secondary, ..., queue_seq)`
|
||||
- **Lower = higher priority**: Processed first
|
||||
- **Element-by-element comparison**: Left-to-right
|
||||
- **Shorter tuple wins**: If equal up to shorter length
|
||||
- **Queue appends queue_seq**: Always last element (FIFO within same priority)
|
||||
|
||||
That's it! No complex classes, just tuples. 🎯
|
||||
386
backend.old/src/trigger/README.md
Normal file
@@ -0,0 +1,386 @@
|
||||
# Trigger System
|
||||
|
||||
Lock-free, sequence-based execution system for deterministic event processing.
|
||||
|
||||
## Overview
|
||||
|
||||
All operations (WebSocket messages, cron tasks, data updates) flow through a **priority queue**, execute in **parallel**, but commit in **strict sequential order** with **optimistic conflict detection**.
|
||||
|
||||
### Key Features
|
||||
|
||||
- **Lock-free reads**: Snapshots are deep copies, no blocking
|
||||
- **Sequential commits**: Total ordering via sequence numbers
|
||||
- **Optimistic concurrency**: Conflicts detected, retry with same seq
|
||||
- **Priority preservation**: High-priority work never blocked by low-priority
|
||||
- **Long-running agents**: Execute in parallel, commit sequentially
|
||||
- **Deterministic replay**: Can reproduce exact system state at any seq
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌─────────────┐
|
||||
│ WebSocket │───┐
|
||||
│ Messages │ │
|
||||
└─────────────┘ │
|
||||
├──→ ┌─────────────────┐
|
||||
┌─────────────┐ │ │ TriggerQueue │
|
||||
│ Cron │───┤ │ (Priority Queue)│
|
||||
│ Scheduled │ │ └────────┬────────┘
|
||||
└─────────────┘ │ │ Assign seq
|
||||
│ ↓
|
||||
┌─────────────┐ │ ┌─────────────────┐
|
||||
│ DataSource │───┘ │ Execute Trigger│
|
||||
│ Updates │ │ (Parallel OK) │
|
||||
└─────────────┘ └────────┬────────┘
|
||||
│ CommitIntents
|
||||
↓
|
||||
┌─────────────────┐
|
||||
│ CommitCoordinator│
|
||||
│ (Sequential) │
|
||||
└────────┬────────┘
|
||||
│ Commit in seq order
|
||||
↓
|
||||
┌─────────────────┐
|
||||
│ VersionedStores │
|
||||
│ (w/ Backends) │
|
||||
└─────────────────┘
|
||||
```
|
||||
|
||||
## Core Components
|
||||
|
||||
### 1. ExecutionContext (`context.py`)
|
||||
|
||||
Tracks execution seq and store snapshots via `contextvars` (auto-propagates through async calls).
|
||||
|
||||
```python
|
||||
from trigger import get_execution_context
|
||||
|
||||
ctx = get_execution_context()
|
||||
print(f"Running at seq {ctx.seq}")
|
||||
```
|
||||
|
||||
### 2. Trigger Types (`types.py`)
|
||||
|
||||
```python
|
||||
from trigger import Trigger, Priority, CommitIntent
|
||||
|
||||
class MyTrigger(Trigger):
|
||||
async def execute(self) -> list[CommitIntent]:
|
||||
# Read snapshot
|
||||
seq, data = some_store.read_snapshot()
|
||||
|
||||
# Modify
|
||||
new_data = modify(data)
|
||||
|
||||
# Prepare commit
|
||||
intent = some_store.prepare_commit(seq, new_data)
|
||||
return [intent]
|
||||
```
|
||||
|
||||
### 3. VersionedStore (`store.py`)
|
||||
|
||||
Stores with pluggable backends and optimistic concurrency:
|
||||
|
||||
```python
|
||||
from trigger import VersionedStore, PydanticStoreBackend
|
||||
|
||||
# Wrap existing Pydantic model
|
||||
backend = PydanticStoreBackend(order_store)
|
||||
versioned_store = VersionedStore("OrderStore", backend)
|
||||
|
||||
# Lock-free snapshot read
|
||||
seq, snapshot = versioned_store.read_snapshot()
|
||||
|
||||
# Prepare commit (does not modify yet)
|
||||
intent = versioned_store.prepare_commit(seq, modified_snapshot)
|
||||
```
|
||||
|
||||
**Pluggable Backends**:
|
||||
- `PydanticStoreBackend`: For existing Pydantic models (OrderStore, ChartStore, etc.)
|
||||
- `FileStoreBackend`: Future - version files (Python scripts, configs)
|
||||
- `DatabaseStoreBackend`: Future - version database rows
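As a rough illustration (not the actual `store.py` code — the method names here are assumptions), a backend is anything that can hand out a deep-copied snapshot and accept a committed state back:

```python
from copy import deepcopy
from typing import Any, Protocol

from pydantic import BaseModel


class StoreBackendSketch(Protocol):
    """Assumed backend shape: produce snapshots, apply committed state."""

    def snapshot(self) -> Any: ...
    def apply(self, new_data: Any) -> None: ...


class PydanticBackendSketch:
    """Hypothetical Pydantic-backed example (for illustration only)."""

    def __init__(self, model: BaseModel):
        self._model = model

    def snapshot(self) -> BaseModel:
        # Deep copy so readers never see in-place mutations (lock-free reads)
        return deepcopy(self._model)

    def apply(self, new_data: BaseModel) -> None:
        # Copy committed fields back onto the live model instance
        for field_name in self._model.model_fields:
            setattr(self._model, field_name, getattr(new_data, field_name))
```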
|
||||
|
||||
### 4. CommitCoordinator (`coordinator.py`)
|
||||
|
||||
Manages sequential commits with conflict detection:
|
||||
|
||||
- Waits for seq N to commit before N+1
|
||||
- Detects conflicts (expected_seq vs committed_seq)
|
||||
- Re-executes (not re-enqueues) on conflict **with same seq**
|
||||
- Tracks execution state for debugging
|
||||
|
||||
### 5. TriggerQueue (`queue.py`)
|
||||
|
||||
Priority queue with seq assignment:
|
||||
|
||||
```python
|
||||
from trigger import TriggerQueue
|
||||
|
||||
queue = TriggerQueue(coordinator)
|
||||
await queue.start()
|
||||
|
||||
# Enqueue trigger
|
||||
await queue.enqueue(my_trigger, Priority.USER_AGENT)
|
||||
```
|
||||
|
||||
### 6. TriggerScheduler (`scheduler.py`)
|
||||
|
||||
APScheduler integration for cron triggers:
|
||||
|
||||
```python
|
||||
from trigger.scheduler import TriggerScheduler
|
||||
|
||||
scheduler = TriggerScheduler(queue)
|
||||
scheduler.start()
|
||||
|
||||
# Every 5 minutes
|
||||
scheduler.schedule_interval(
|
||||
IndicatorUpdateTrigger("rsi_14"),
|
||||
minutes=5
|
||||
)
|
||||
|
||||
# Daily at 9 AM
|
||||
scheduler.schedule_cron(
|
||||
SyncExchangeStateTrigger(),
|
||||
hour="9",
|
||||
minute="0"
|
||||
)
|
||||
```
|
||||
|
||||
## Integration Example
|
||||
|
||||
### Basic Setup in `main.py`
|
||||
|
||||
```python
|
||||
from trigger import (
|
||||
CommitCoordinator,
|
||||
TriggerQueue,
|
||||
VersionedStore,
|
||||
PydanticStoreBackend,
|
||||
)
|
||||
from trigger.scheduler import TriggerScheduler
|
||||
|
||||
# Create coordinator
|
||||
coordinator = CommitCoordinator()
|
||||
|
||||
# Wrap existing stores
|
||||
order_store_versioned = VersionedStore(
|
||||
"OrderStore",
|
||||
PydanticStoreBackend(order_store)
|
||||
)
|
||||
coordinator.register_store(order_store_versioned)
|
||||
|
||||
chart_store_versioned = VersionedStore(
|
||||
"ChartStore",
|
||||
PydanticStoreBackend(chart_store)
|
||||
)
|
||||
coordinator.register_store(chart_store_versioned)
|
||||
|
||||
# Create queue and scheduler
|
||||
trigger_queue = TriggerQueue(coordinator)
|
||||
await trigger_queue.start()
|
||||
|
||||
scheduler = TriggerScheduler(trigger_queue)
|
||||
scheduler.start()
|
||||
```
|
||||
|
||||
### WebSocket Message Handler
|
||||
|
||||
```python
|
||||
from trigger.handlers import AgentTriggerHandler
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
|
||||
if data["type"] == "agent_user_message":
|
||||
# Enqueue agent trigger instead of direct Gateway call
|
||||
trigger = AgentTriggerHandler(
|
||||
session_id=data["session_id"],
|
||||
message_content=data["content"],
|
||||
gateway_handler=gateway.route_user_message,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
await trigger_queue.enqueue(trigger)
|
||||
```
|
||||
|
||||
### DataSource Updates
|
||||
|
||||
```python
|
||||
from trigger.handlers import DataUpdateTrigger
|
||||
|
||||
# In subscription_manager._on_source_update()
|
||||
def _on_source_update(self, source_key: tuple, bar: dict):
|
||||
# Enqueue data update trigger
|
||||
trigger = DataUpdateTrigger(
|
||||
source_name=source_key[0],
|
||||
symbol=source_key[1],
|
||||
resolution=source_key[2],
|
||||
bar_data=bar,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
asyncio.create_task(trigger_queue.enqueue(trigger))
|
||||
```
|
||||
|
||||
### Custom Trigger
|
||||
|
||||
```python
|
||||
from trigger import Trigger, CommitIntent, Priority
|
||||
|
||||
class RecalculatePortfolioTrigger(Trigger):
|
||||
def __init__(self, coordinator):
|
||||
super().__init__("recalc_portfolio", Priority.NORMAL)
|
||||
self.coordinator = coordinator
|
||||
|
||||
async def execute(self) -> list[CommitIntent]:
|
||||
# Read snapshots from multiple stores
|
||||
order_seq, orders = self.coordinator.get_store("OrderStore").read_snapshot()
|
||||
chart_seq, chart = self.coordinator.get_store("ChartStore").read_snapshot()
|
||||
|
||||
# Calculate portfolio value
|
||||
portfolio_value = calculate_portfolio(orders, chart)
|
||||
|
||||
# Update chart state with portfolio value
|
||||
chart.portfolio_value = portfolio_value
|
||||
|
||||
# Prepare commit
|
||||
intent = self.coordinator.get_store("ChartStore").prepare_commit(
|
||||
chart_seq,
|
||||
chart
|
||||
)
|
||||
|
||||
return [intent]
|
||||
|
||||
# Schedule it
|
||||
scheduler.schedule_interval(
|
||||
RecalculatePortfolioTrigger(coordinator),
|
||||
minutes=1
|
||||
)
|
||||
```
|
||||
|
||||
## Execution Flow
|
||||
|
||||
### Normal Flow (No Conflicts)
|
||||
|
||||
```
|
||||
seq=100: WebSocket message arrives → enqueue → dequeue → assign seq=100 → execute
|
||||
seq=101: Cron trigger fires → enqueue → dequeue → assign seq=101 → execute
|
||||
|
||||
seq=101 finishes first → waits in commit queue
|
||||
seq=100 finishes → commits immediately (next in order)
|
||||
seq=101 commits next
|
||||
```
|
||||
|
||||
### Conflict Flow
|
||||
|
||||
```
|
||||
seq=100: reads OrderStore at seq=99 → executes for 30 seconds
|
||||
seq=101: reads OrderStore at seq=99 → executes for 5 seconds
|
||||
|
||||
seq=101 finishes first → tries to commit based on seq=99
|
||||
seq=100 finishes → commits OrderStore at seq=100
|
||||
|
||||
Coordinator detects conflict:
|
||||
expected_seq=99, committed_seq=100
|
||||
|
||||
seq=101 evicted → RE-EXECUTES with same seq=101 (not re-enqueued)
|
||||
reads OrderStore at seq=100 → executes again
|
||||
finishes → commits successfully at seq=101
|
||||
```
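In store terms, the conflict in this trace is detected by comparing the seq an execution's snapshot was taken at against the seq the store has since committed. A minimal sketch using the `VersionedStore`/`CommitIntent` API shown above (`order_store` is a registered `VersionedStore`; `modify_slow`/`modify_fast` stand in for whatever work each execution does):

```python
# Both executions read the same OrderStore snapshot (committed seq 99).
seq_a, orders_a = order_store.read_snapshot()   # seq_a == 99
seq_b, orders_b = order_store.read_snapshot()   # seq_b == 99

intent_a = order_store.prepare_commit(seq_a, modify_slow(orders_a))   # execution seq=100
intent_b = order_store.prepare_commit(seq_b, modify_fast(orders_b))   # execution seq=101

# seq=100 commits first; the store now records committed seq 100.
order_store.commit(intent_a.new_data, 100)

# When seq=101 reaches the coordinator, its intent was prepared against seq 99,
# so the conflict check fires and the trigger is re-executed with the same seq.
assert order_store.check_conflict(intent_b.expected_seq)   # expected 99, committed 100
```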
|
||||
|
||||
## Benefits
|
||||
|
||||
### For Agent System
|
||||
|
||||
- **Long-running agents work naturally**: Agent starts at seq=100, runs for 60 seconds while market data updates at seq=101-110, commits only if no conflicts
|
||||
- **No deadlocks**: No locks = no deadlock possibility
|
||||
- **Deterministic**: Can replay from any seq for debugging
|
||||
|
||||
### For Strategy Execution
|
||||
|
||||
- **High-frequency data doesn't block strategies**: Data updates enqueued, executed in parallel, commit sequentially
|
||||
- **Priority preservation**: Critical order execution never blocked by indicator calculations
|
||||
- **Conflict detection**: If market moved during strategy calculation, automatically retry with fresh data
|
||||
|
||||
### For Scaling
|
||||
|
||||
- **Single-node first**: Runs on single asyncio event loop, no complex distributed coordination
|
||||
- **Future-proof**: Can swap queue for Redis/PostgreSQL-backed distributed queue later
|
||||
- **Event sourcing ready**: All commits have seq numbers, can build event log
|
||||
|
||||
## Debugging
|
||||
|
||||
### Check Current State
|
||||
|
||||
```python
|
||||
# Coordinator stats
|
||||
stats = coordinator.get_stats()
|
||||
print(f"Current seq: {stats['current_seq']}")
|
||||
print(f"Pending commits: {stats['pending_commits']}")
|
||||
print(f"Executions by state: {stats['state_counts']}")
|
||||
|
||||
# Store state
|
||||
store = coordinator.get_store("OrderStore")
|
||||
print(f"Store: {store}") # Shows committed_seq and version
|
||||
|
||||
# Execution record
|
||||
record = coordinator.get_execution_record(100)
|
||||
print(f"Seq 100: {record}") # Shows state, retry_count, error
|
||||
```
|
||||
|
||||
### Common Issues
|
||||
|
||||
**Symptoms: High conflict rate**
|
||||
- **Cause**: Multiple triggers modifying same store frequently
|
||||
- **Solution**: Batch updates, use debouncing (see the sketch after this list), or redesign to reduce contention
|
||||
|
||||
**Symptoms: Commits stuck (next_commit_seq not advancing)**
|
||||
- **Cause**: Execution at that seq failed or is taking too long
|
||||
- **Solution**: Check execution_records for that seq, look for errors in logs
|
||||
|
||||
**Symptoms: Queue depth growing**
|
||||
- **Cause**: Executions slower than enqueue rate
|
||||
- **Solution**: Profile trigger execution, optimize slow paths, add rate limiting
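For the batching/debouncing suggestion above, one possible shape (a hypothetical helper, not part of the trigger package) is to coalesce bursts of notifications and enqueue a single trigger once things go quiet:

```python
import asyncio
from typing import Callable, Optional


class DebouncedEnqueue:
    """Coalesce rapid notifications into one trigger enqueue after a quiet period."""

    def __init__(self, queue, make_trigger: Callable, delay: float = 0.5):
        self._queue = queue                # TriggerQueue
        self._make_trigger = make_trigger  # zero-arg factory building a fresh Trigger
        self._delay = delay
        self._task: Optional[asyncio.Task] = None

    def notify(self) -> None:
        # Restart the quiet-period timer on every notification
        if self._task and not self._task.done():
            self._task.cancel()
        self._task = asyncio.create_task(self._enqueue_later())

    async def _enqueue_later(self) -> None:
        try:
            await asyncio.sleep(self._delay)
        except asyncio.CancelledError:
            return  # superseded by a newer notification
        await self._queue.enqueue(self._make_trigger())
```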
|
||||
|
||||
## Testing
|
||||
|
||||
### Unit Test: Conflict Detection
|
||||
|
||||
```python
|
||||
import pytest
|
||||
from trigger import VersionedStore, PydanticStoreBackend, CommitCoordinator
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conflict_detection():
|
||||
coordinator = CommitCoordinator()
|
||||
|
||||
store = VersionedStore("TestStore", PydanticStoreBackend(TestModel()))
|
||||
coordinator.register_store(store)
|
||||
|
||||
# Seq 1: read at 0, modify, commit
|
||||
seq1, data1 = store.read_snapshot()
|
||||
data1.value = "seq1"
|
||||
intent1 = store.prepare_commit(seq1, data1)
|
||||
|
||||
# Seq 2: read at 0 (same snapshot), modify
|
||||
seq2, data2 = store.read_snapshot()
|
||||
data2.value = "seq2"
|
||||
intent2 = store.prepare_commit(seq2, data2)
|
||||
|
||||
# Commit seq 1 (should succeed)
|
||||
# ... coordinator logic ...
|
||||
|
||||
# Commit seq 2 (should conflict and retry)
|
||||
# ... verify conflict detected ...
|
||||
```
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- **Distributed queue**: Redis-backed queue for multi-worker deployment
|
||||
- **Event log persistence**: Store all commits for event sourcing/audit
|
||||
- **Metrics dashboard**: Real-time view of queue depth, conflict rate, latency
|
||||
- **Transaction snapshots**: Full system state at any seq for replay/debugging
|
||||
- **Automatic batching**: Coalesce rapid updates to same store
|
||||
35
backend.old/src/trigger/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Sequential execution trigger system with optimistic concurrency control.
|
||||
|
||||
All operations (websocket, cron, data events) flow through a priority queue,
|
||||
execute in parallel, but commit in strict sequential order with conflict detection.
|
||||
"""
|
||||
|
||||
from .context import ExecutionContext, get_execution_context
|
||||
from .types import Priority, PriorityTuple, Trigger, CommitIntent, ExecutionState
|
||||
from .store import VersionedStore, StoreBackend, PydanticStoreBackend
|
||||
from .coordinator import CommitCoordinator
|
||||
from .queue import TriggerQueue
|
||||
from .handlers import AgentTriggerHandler, LambdaHandler
|
||||
|
||||
__all__ = [
|
||||
# Context
|
||||
"ExecutionContext",
|
||||
"get_execution_context",
|
||||
# Types
|
||||
"Priority",
|
||||
"PriorityTuple",
|
||||
"Trigger",
|
||||
"CommitIntent",
|
||||
"ExecutionState",
|
||||
# Store
|
||||
"VersionedStore",
|
||||
"StoreBackend",
|
||||
"PydanticStoreBackend",
|
||||
# Coordination
|
||||
"CommitCoordinator",
|
||||
"TriggerQueue",
|
||||
# Handlers
|
||||
"AgentTriggerHandler",
|
||||
"LambdaHandler",
|
||||
]
|
||||
61
backend.old/src/trigger/context.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Execution context tracking using Python's contextvars.
|
||||
|
||||
Each execution gets a unique seq number that propagates through all async calls,
|
||||
allowing us to track which execution made which changes for conflict detection.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Context variables - automatically propagate through async call chains
|
||||
_execution_context: ContextVar[Optional["ExecutionContext"]] = ContextVar(
|
||||
"execution_context", default=None
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionContext:
|
||||
"""
|
||||
Execution context for a single trigger execution.
|
||||
|
||||
Automatically propagates through async calls via contextvars.
|
||||
Tracks the seq number and which store snapshots were read.
|
||||
"""
|
||||
|
||||
seq: int
|
||||
"""Sequential execution number - determines commit order"""
|
||||
|
||||
trigger_name: str
|
||||
"""Name/type of trigger being executed"""
|
||||
|
||||
snapshot_seqs: dict[str, int] = field(default_factory=dict)
|
||||
"""Store name -> seq number of snapshot that was read"""
|
||||
|
||||
def record_snapshot(self, store_name: str, snapshot_seq: int) -> None:
|
||||
"""Record that we read a snapshot from a store at a specific seq"""
|
||||
self.snapshot_seqs[store_name] = snapshot_seq
|
||||
logger.debug(f"Seq {self.seq}: Read {store_name} at seq {snapshot_seq}")
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ExecutionContext(seq={self.seq}, trigger={self.trigger_name})"
|
||||
|
||||
|
||||
def get_execution_context() -> Optional[ExecutionContext]:
|
||||
"""Get the current execution context, or None if not in an execution"""
|
||||
return _execution_context.get()
|
||||
|
||||
|
||||
def set_execution_context(ctx: ExecutionContext) -> None:
|
||||
"""Set the execution context for the current async task"""
|
||||
_execution_context.set(ctx)
|
||||
logger.debug(f"Set execution context: {ctx}")
|
||||
|
||||
|
||||
def clear_execution_context() -> None:
|
||||
"""Clear the execution context"""
|
||||
_execution_context.set(None)
|
||||
302
backend.old/src/trigger/coordinator.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""
|
||||
Commit coordinator - manages sequential commits with conflict detection.
|
||||
|
||||
Ensures that commits happen in strict sequence order, even when executions
|
||||
complete out of order. Detects conflicts and triggers re-execution with the
|
||||
same seq number (not re-enqueue, just re-execute).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from .context import ExecutionContext
|
||||
from .store import VersionedStore
|
||||
from .types import CommitIntent, ExecutionRecord, ExecutionState, Trigger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommitCoordinator:
|
||||
"""
|
||||
Manages sequential commits with optimistic concurrency control.
|
||||
|
||||
Key responsibilities:
|
||||
- Maintain strict sequential commit order (seq N+1 commits after seq N)
|
||||
- Detect conflicts between execution snapshot and committed state
|
||||
- Trigger re-execution (not re-enqueue) on conflicts with same seq
|
||||
- Track in-flight executions for debugging and monitoring
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._stores: dict[str, VersionedStore] = {}
|
||||
self._current_seq = 0 # Highest committed seq across all operations
|
||||
self._next_commit_seq = 1 # Next seq we're waiting to commit
|
||||
self._pending_commits: dict[int, tuple[ExecutionRecord, list[CommitIntent]]] = {}
|
||||
self._execution_records: dict[int, ExecutionRecord] = {}
|
||||
self._lock = asyncio.Lock() # Only for coordinator internal state, not stores
|
||||
|
||||
def register_store(self, store: VersionedStore) -> None:
|
||||
"""Register a versioned store with the coordinator"""
|
||||
self._stores[store.name] = store
|
||||
logger.info(f"Registered store: {store.name}")
|
||||
|
||||
def get_store(self, name: str) -> Optional[VersionedStore]:
|
||||
"""Get a registered store by name"""
|
||||
return self._stores.get(name)
|
||||
|
||||
async def start_execution(self, seq: int, trigger: Trigger) -> ExecutionRecord:
|
||||
"""
|
||||
Record that an execution is starting.
|
||||
|
||||
Args:
|
||||
seq: Sequence number assigned to this execution
|
||||
trigger: The trigger being executed
|
||||
|
||||
Returns:
|
||||
ExecutionRecord for tracking
|
||||
"""
|
||||
async with self._lock:
|
||||
record = ExecutionRecord(
|
||||
seq=seq,
|
||||
trigger=trigger,
|
||||
state=ExecutionState.EXECUTING,
|
||||
)
|
||||
self._execution_records[seq] = record
|
||||
logger.info(f"Started execution: seq={seq}, trigger={trigger.name}")
|
||||
return record
|
||||
|
||||
async def submit_for_commit(
|
||||
self,
|
||||
seq: int,
|
||||
commit_intents: list[CommitIntent],
|
||||
) -> None:
|
||||
"""
|
||||
Submit commit intents for sequential commit.
|
||||
|
||||
The commit will only happen when:
|
||||
1. All prior seq numbers have committed
|
||||
2. No conflicts detected with committed state
|
||||
|
||||
Args:
|
||||
seq: Sequence number of this execution
|
||||
commit_intents: List of changes to commit (empty if no changes)
|
||||
"""
|
||||
async with self._lock:
|
||||
record = self._execution_records.get(seq)
|
||||
if not record:
|
||||
logger.error(f"No execution record found for seq={seq}")
|
||||
return
|
||||
|
||||
record.state = ExecutionState.WAITING_COMMIT
|
||||
record.commit_intents = commit_intents
|
||||
self._pending_commits[seq] = (record, commit_intents)
|
||||
|
||||
logger.info(
|
||||
f"Seq {seq} submitted for commit with {len(commit_intents)} intents"
|
||||
)
|
||||
|
||||
# Try to process commits (this will handle sequential ordering)
|
||||
await self._process_commits()
|
||||
|
||||
async def _process_commits(self) -> None:
|
||||
"""
|
||||
Process pending commits in strict sequential order.
|
||||
|
||||
Only commits seq N if seq N-1 has already committed.
|
||||
Detects conflicts and triggers re-execution with same seq.
|
||||
"""
|
||||
while True:
|
||||
async with self._lock:
|
||||
# Check if next expected seq is ready to commit
|
||||
if self._next_commit_seq not in self._pending_commits:
|
||||
# Waiting for this seq to complete execution
|
||||
break
|
||||
|
||||
seq = self._next_commit_seq
|
||||
record, intents = self._pending_commits[seq]
|
||||
|
||||
logger.info(
|
||||
f"Processing commit for seq={seq} (current_seq={self._current_seq})"
|
||||
)
|
||||
|
||||
# Check for conflicts
|
||||
conflicts = self._check_conflicts(intents)
|
||||
|
||||
if conflicts:
|
||||
# Conflict detected - re-execute with same seq
|
||||
logger.warning(
|
||||
f"Seq {seq} has conflicts in stores: {conflicts}. Re-executing..."
|
||||
)
|
||||
|
||||
# Remove from pending (will be re-added when execution completes)
|
||||
del self._pending_commits[seq]
|
||||
|
||||
# Mark as evicted
|
||||
record.state = ExecutionState.EVICTED
|
||||
record.retry_count += 1
|
||||
|
||||
# Advance to next seq (this seq will be retried in background)
|
||||
self._next_commit_seq += 1
|
||||
self._current_seq += 1
|
||||
|
||||
# Trigger re-execution (outside lock)
|
||||
asyncio.create_task(self._retry_execution(record))
|
||||
|
||||
continue
|
||||
|
||||
# No conflicts - commit all intents atomically
|
||||
for intent in intents:
|
||||
store = self._stores.get(intent.store_name)
|
||||
if not store:
|
||||
logger.error(
|
||||
f"Seq {seq}: Store '{intent.store_name}' not found"
|
||||
)
|
||||
continue
|
||||
|
||||
store.commit(intent.new_data, seq)
|
||||
|
||||
# Mark as committed
|
||||
record.state = ExecutionState.COMMITTED
|
||||
del self._pending_commits[seq]
|
||||
|
||||
# Advance seq counters
|
||||
self._current_seq = seq
|
||||
self._next_commit_seq = seq + 1
|
||||
|
||||
logger.info(
|
||||
f"Committed seq={seq}, current_seq now {self._current_seq}"
|
||||
)
|
||||
|
||||
def _check_conflicts(self, intents: list[CommitIntent]) -> list[str]:
|
||||
"""
|
||||
Check if any commit intents conflict with current committed state.
|
||||
|
||||
Args:
|
||||
intents: List of commit intents to check
|
||||
|
||||
Returns:
|
||||
List of store names that have conflicts (empty if no conflicts)
|
||||
"""
|
||||
conflicts = []
|
||||
|
||||
for intent in intents:
|
||||
store = self._stores.get(intent.store_name)
|
||||
if not store:
|
||||
logger.error(f"Store '{intent.store_name}' not found during conflict check")
|
||||
continue
|
||||
|
||||
if store.check_conflict(intent.expected_seq):
|
||||
conflicts.append(intent.store_name)
|
||||
|
||||
return conflicts
|
||||
|
||||
async def _retry_execution(self, record: ExecutionRecord) -> None:
|
||||
"""
|
||||
Re-execute a trigger that had conflicts.
|
||||
|
||||
Executes with the SAME seq number (not re-enqueued, just re-executed).
|
||||
This ensures the execution order remains deterministic.
|
||||
|
||||
Args:
|
||||
record: Execution record to retry
|
||||
"""
|
||||
from .context import ExecutionContext, set_execution_context, clear_execution_context
|
||||
|
||||
logger.info(
|
||||
f"Retrying execution: seq={record.seq}, trigger={record.trigger.name}, "
|
||||
f"retry_count={record.retry_count}"
|
||||
)
|
||||
|
||||
# Set execution context for retry
|
||||
ctx = ExecutionContext(
|
||||
seq=record.seq,
|
||||
trigger_name=record.trigger.name,
|
||||
)
|
||||
set_execution_context(ctx)
|
||||
|
||||
try:
|
||||
# Re-execute trigger
|
||||
record.state = ExecutionState.EXECUTING
|
||||
commit_intents = await record.trigger.execute()
|
||||
|
||||
# Submit for commit again (with same seq)
|
||||
await self.submit_for_commit(record.seq, commit_intents)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retry execution failed for seq={record.seq}: {e}", exc_info=True
|
||||
)
|
||||
record.state = ExecutionState.FAILED
|
||||
record.error = str(e)
|
||||
|
||||
# Still need to advance past this seq
|
||||
async with self._lock:
|
||||
if record.seq == self._next_commit_seq:
|
||||
self._next_commit_seq += 1
|
||||
self._current_seq += 1
|
||||
|
||||
# Try to process any pending commits
|
||||
await self._process_commits()
|
||||
|
||||
finally:
|
||||
clear_execution_context()
|
||||
|
||||
async def execution_failed(self, seq: int, error: Exception) -> None:
|
||||
"""
|
||||
Mark an execution as failed.
|
||||
|
||||
Args:
|
||||
seq: Sequence number that failed
|
||||
error: The exception that caused the failure
|
||||
"""
|
||||
async with self._lock:
|
||||
record = self._execution_records.get(seq)
|
||||
if record:
|
||||
record.state = ExecutionState.FAILED
|
||||
record.error = str(error)
|
||||
|
||||
# Remove from pending if present
|
||||
self._pending_commits.pop(seq, None)
|
||||
|
||||
# If this is the next seq to commit, advance past it
|
||||
if seq == self._next_commit_seq:
|
||||
self._next_commit_seq += 1
|
||||
self._current_seq += 1
|
||||
|
||||
logger.info(
|
||||
f"Seq {seq} failed, advancing current_seq to {self._current_seq}"
|
||||
)
|
||||
|
||||
# Try to process any pending commits
|
||||
await self._process_commits()
|
||||
|
||||
def get_current_seq(self) -> int:
|
||||
"""Get the current committed sequence number"""
|
||||
return self._current_seq
|
||||
|
||||
def get_execution_record(self, seq: int) -> Optional[ExecutionRecord]:
|
||||
"""Get execution record for a specific seq"""
|
||||
return self._execution_records.get(seq)
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get statistics about the coordinator state"""
|
||||
state_counts = {}
|
||||
for record in self._execution_records.values():
|
||||
state_name = record.state.name
|
||||
state_counts[state_name] = state_counts.get(state_name, 0) + 1
|
||||
|
||||
return {
|
||||
"current_seq": self._current_seq,
|
||||
"next_commit_seq": self._next_commit_seq,
|
||||
"pending_commits": len(self._pending_commits),
|
||||
"total_executions": len(self._execution_records),
|
||||
"state_counts": state_counts,
|
||||
"stores": {name: str(store) for name, store in self._stores.items()},
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"CommitCoordinator(current_seq={self._current_seq}, "
|
||||
f"pending={len(self._pending_commits)}, stores={len(self._stores)})"
|
||||
)
|
||||
304
backend.old/src/trigger/handlers.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Trigger handlers - concrete implementations for common trigger types.
|
||||
|
||||
Provides ready-to-use trigger handlers for:
|
||||
- Agent execution (WebSocket user messages)
|
||||
- Lambda/callable execution
|
||||
- Data update triggers
|
||||
- Indicator updates
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Awaitable, Callable, Optional
|
||||
|
||||
from .coordinator import CommitCoordinator
|
||||
from .types import CommitIntent, Priority, Trigger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentTriggerHandler(Trigger):
|
||||
"""
|
||||
Trigger for agent execution from WebSocket user messages.
|
||||
|
||||
Wraps the Gateway's agent execution flow and captures any
|
||||
store modifications as commit intents.
|
||||
|
||||
Priority tuple: (USER_AGENT, message_timestamp, queue_seq)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
message_content: str,
|
||||
message_timestamp: Optional[int] = None,
|
||||
attachments: Optional[list] = None,
|
||||
gateway_handler: Optional[Callable] = None,
|
||||
coordinator: Optional[CommitCoordinator] = None,
|
||||
):
|
||||
"""
|
||||
Initialize agent trigger.
|
||||
|
||||
Args:
|
||||
session_id: User session ID
|
||||
message_content: User message content
|
||||
message_timestamp: When user sent message (unix timestamp, defaults to now)
|
||||
attachments: Optional message attachments
|
||||
gateway_handler: Callable to route to Gateway (set during integration)
|
||||
coordinator: CommitCoordinator for accessing stores
|
||||
"""
|
||||
if message_timestamp is None:
|
||||
message_timestamp = int(time.time())
|
||||
|
||||
# Priority tuple: sort by USER_AGENT priority, then message timestamp
|
||||
super().__init__(
|
||||
name=f"agent_{session_id}",
|
||||
priority=Priority.USER_AGENT,
|
||||
priority_tuple=(Priority.USER_AGENT.value, message_timestamp)
|
||||
)
|
||||
self.session_id = session_id
|
||||
self.message_content = message_content
|
||||
self.message_timestamp = message_timestamp
|
||||
self.attachments = attachments or []
|
||||
self.gateway_handler = gateway_handler
|
||||
self.coordinator = coordinator
|
||||
|
||||
async def execute(self) -> list[CommitIntent]:
|
||||
"""
|
||||
Execute agent interaction.
|
||||
|
||||
This will call into the Gateway, which will run the agent.
|
||||
The agent may read from stores and generate responses.
|
||||
Any store modifications are captured as commit intents.
|
||||
|
||||
Returns:
|
||||
List of commit intents (typically empty for now, as agent
|
||||
modifies stores via tools which will be integrated later)
|
||||
"""
|
||||
if not self.gateway_handler:
|
||||
logger.error("No gateway_handler configured for AgentTriggerHandler")
|
||||
return []
|
||||
|
||||
logger.info(
|
||||
f"Agent trigger executing: session={self.session_id}, "
|
||||
f"content='{self.message_content[:50]}...'"
|
||||
)
|
||||
|
||||
try:
|
||||
# Call Gateway to handle message
|
||||
# In future, Gateway/agent tools will use coordinator stores
|
||||
await self.gateway_handler(
|
||||
self.session_id,
|
||||
self.message_content,
|
||||
self.attachments,
|
||||
)
|
||||
|
||||
# For now, agent doesn't directly modify stores
|
||||
# Future: agent tools will return commit intents
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent execution error: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
class LambdaHandler(Trigger):
|
||||
"""
|
||||
Generic trigger that executes an arbitrary async callable.
|
||||
|
||||
Useful for custom triggers, one-off tasks, or testing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
func: Callable[[], Awaitable[list[CommitIntent]]],
|
||||
priority: Priority = Priority.SYSTEM,
|
||||
):
|
||||
"""
|
||||
Initialize lambda handler.
|
||||
|
||||
Args:
|
||||
name: Descriptive name for this trigger
|
||||
func: Async callable that returns commit intents
|
||||
priority: Execution priority
|
||||
"""
|
||||
super().__init__(name, priority)
|
||||
self.func = func
|
||||
|
||||
async def execute(self) -> list[CommitIntent]:
|
||||
"""Execute the callable"""
|
||||
logger.info(f"Lambda trigger executing: {self.name}")
|
||||
return await self.func()
|
||||
|
||||
|
||||
class DataUpdateTrigger(Trigger):
|
||||
"""
|
||||
Trigger for DataSource bar updates.
|
||||
|
||||
Fired when new market data arrives. Can update indicators,
|
||||
trigger strategy logic, or notify the agent of market events.
|
||||
|
||||
Priority tuple: (DATA_SOURCE, event_time, queue_seq)
|
||||
Ensures older bars process before newer ones.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source_name: str,
|
||||
symbol: str,
|
||||
resolution: str,
|
||||
bar_data: dict,
|
||||
coordinator: Optional[CommitCoordinator] = None,
|
||||
):
|
||||
"""
|
||||
Initialize data update trigger.
|
||||
|
||||
Args:
|
||||
source_name: Name of data source (e.g., "binance")
|
||||
symbol: Trading pair symbol
|
||||
resolution: Time resolution
|
||||
bar_data: Bar data dict (time, open, high, low, close, volume)
|
||||
coordinator: CommitCoordinator for accessing stores
|
||||
"""
|
||||
event_time = bar_data.get('time', int(time.time()))
|
||||
|
||||
# Priority tuple: sort by DATA_SOURCE priority, then event time
|
||||
super().__init__(
|
||||
name=f"data_{source_name}_{symbol}_{resolution}",
|
||||
priority=Priority.DATA_SOURCE,
|
||||
priority_tuple=(Priority.DATA_SOURCE.value, event_time)
|
||||
)
|
||||
self.source_name = source_name
|
||||
self.symbol = symbol
|
||||
self.resolution = resolution
|
||||
self.bar_data = bar_data
|
||||
self.coordinator = coordinator
|
||||
|
||||
async def execute(self) -> list[CommitIntent]:
|
||||
"""
|
||||
Process bar update.
|
||||
|
||||
Future implementations will:
|
||||
- Update indicator values
|
||||
- Check strategy conditions
|
||||
- Trigger alerts/notifications
|
||||
|
||||
Returns:
|
||||
Commit intents for any store updates
|
||||
"""
|
||||
logger.info(
|
||||
f"Data update trigger: {self.source_name}:{self.symbol}@{self.resolution}, "
|
||||
f"time={self.bar_data.get('time')}"
|
||||
)
|
||||
|
||||
# TODO: Update indicators
|
||||
# TODO: Check strategy conditions
|
||||
# TODO: Notify agent of significant events
|
||||
|
||||
# For now, just log
|
||||
return []
|
||||
|
||||
|
||||
class IndicatorUpdateTrigger(Trigger):
|
||||
"""
|
||||
Trigger for updating indicator values.
|
||||
|
||||
Can be fired by cron (periodic recalculation) or by data updates.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
indicator_id: str,
|
||||
force_full_recalc: bool = False,
|
||||
coordinator: Optional[CommitCoordinator] = None,
|
||||
priority: Priority = Priority.SYSTEM,
|
||||
):
|
||||
"""
|
||||
Initialize indicator update trigger.
|
||||
|
||||
Args:
|
||||
indicator_id: ID of indicator to update
|
||||
force_full_recalc: If True, recalculate entire history
|
||||
coordinator: CommitCoordinator for accessing stores
|
||||
priority: Execution priority
|
||||
"""
|
||||
super().__init__(f"indicator_{indicator_id}", priority)
|
||||
self.indicator_id = indicator_id
|
||||
self.force_full_recalc = force_full_recalc
|
||||
self.coordinator = coordinator
|
||||
|
||||
async def execute(self) -> list[CommitIntent]:
|
||||
"""
|
||||
Update indicator value.
|
||||
|
||||
Reads from IndicatorStore, recalculates, prepares commit.
|
||||
|
||||
Returns:
|
||||
Commit intents for updated indicator data
|
||||
"""
|
||||
if not self.coordinator:
|
||||
logger.error("No coordinator configured")
|
||||
return []
|
||||
|
||||
# Get indicator store
|
||||
indicator_store = self.coordinator.get_store("IndicatorStore")
|
||||
if not indicator_store:
|
||||
logger.error("IndicatorStore not registered")
|
||||
return []
|
||||
|
||||
# Read snapshot
|
||||
snapshot_seq, indicator_data = indicator_store.read_snapshot()
|
||||
|
||||
logger.info(
|
||||
f"Indicator update trigger: {self.indicator_id}, "
|
||||
f"snapshot_seq={snapshot_seq}, force_full={self.force_full_recalc}"
|
||||
)
|
||||
|
||||
# TODO: Implement indicator recalculation logic
|
||||
# For now, just return empty (no changes)
|
||||
|
||||
return []
|
||||
|
||||
|
||||
class CronTrigger(Trigger):
|
||||
"""
|
||||
Trigger fired by APScheduler on a schedule.
|
||||
|
||||
Wraps another trigger or callable to execute periodically.
|
||||
|
||||
Priority tuple: (TIMER, scheduled_time, queue_seq)
|
||||
Ensures jobs scheduled for earlier times run first.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
inner_trigger: Trigger,
|
||||
scheduled_time: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Initialize cron trigger.
|
||||
|
||||
Args:
|
||||
name: Descriptive name (e.g., "hourly_sync")
|
||||
inner_trigger: Trigger to execute on schedule
|
||||
scheduled_time: When this was scheduled to run (defaults to now)
|
||||
"""
|
||||
if scheduled_time is None:
|
||||
scheduled_time = int(time.time())
|
||||
|
||||
# Priority tuple: sort by TIMER priority, then scheduled time
|
||||
super().__init__(
|
||||
name=f"cron_{name}",
|
||||
priority=Priority.TIMER,
|
||||
priority_tuple=(Priority.TIMER.value, scheduled_time)
|
||||
)
|
||||
self.inner_trigger = inner_trigger
|
||||
self.scheduled_time = scheduled_time
|
||||
|
||||
async def execute(self) -> list[CommitIntent]:
|
||||
"""Execute the wrapped trigger"""
|
||||
logger.info(f"Cron trigger firing: {self.name}")
|
||||
return await self.inner_trigger.execute()
|
||||
224
backend.old/src/trigger/queue.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
Trigger queue - priority queue with sequence number assignment.
|
||||
|
||||
All operations flow through this queue:
|
||||
- WebSocket messages from users
|
||||
- Cron scheduled tasks
|
||||
- DataSource bar updates
|
||||
- Manual triggers
|
||||
|
||||
Queue assigns seq numbers on dequeue, executes triggers, and submits to coordinator.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from .context import ExecutionContext, clear_execution_context, set_execution_context
|
||||
from .coordinator import CommitCoordinator
|
||||
from .types import Priority, PriorityTuple, Trigger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerQueue:
|
||||
"""
|
||||
Priority queue for trigger execution.
|
||||
|
||||
Key responsibilities:
|
||||
- Maintain priority queue (high priority dequeued first)
|
||||
- Assign sequence numbers on dequeue (determines commit order)
|
||||
- Execute triggers with context set
|
||||
- Submit results to CommitCoordinator
|
||||
- Handle execution errors gracefully
|
||||
"""
|
||||
|
||||
def __init__(self, coordinator: CommitCoordinator):
|
||||
"""
|
||||
Initialize trigger queue.
|
||||
|
||||
Args:
|
||||
coordinator: CommitCoordinator for handling commits
|
||||
"""
|
||||
self._coordinator = coordinator
|
||||
self._queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
|
||||
self._seq_counter = 0
|
||||
self._seq_lock = asyncio.Lock()
|
||||
self._processor_task: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the queue processor"""
|
||||
if self._running:
|
||||
logger.warning("TriggerQueue already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._processor_task = asyncio.create_task(self._process_loop())
|
||||
logger.info("TriggerQueue started")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the queue processor gracefully"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
|
||||
if self._processor_task:
|
||||
self._processor_task.cancel()
|
||||
try:
|
||||
await self._processor_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("TriggerQueue stopped")
|
||||
|
||||
async def enqueue(
|
||||
self,
|
||||
trigger: Trigger,
|
||||
priority_override: Optional[Priority | PriorityTuple] = None
|
||||
) -> int:
|
||||
"""
|
||||
Add a trigger to the queue.
|
||||
|
||||
Args:
|
||||
trigger: Trigger to execute
|
||||
priority_override: Override priority (simple Priority or tuple)
|
||||
If None, uses trigger's priority/priority_tuple
|
||||
If Priority enum, creates single-element tuple
|
||||
If tuple, uses as-is
|
||||
|
||||
Returns:
|
||||
Queue sequence number (appended to priority tuple)
|
||||
|
||||
Examples:
|
||||
# Simple priority
|
||||
await queue.enqueue(trigger, Priority.USER_AGENT)
|
||||
# Results in: (Priority.USER_AGENT, queue_seq)
|
||||
|
||||
# Tuple priority with event time
|
||||
await queue.enqueue(
|
||||
trigger,
|
||||
(Priority.DATA_SOURCE, bar_data['time'])
|
||||
)
|
||||
# Results in: (Priority.DATA_SOURCE, bar_time, queue_seq)
|
||||
|
||||
# Let trigger decide
|
||||
await queue.enqueue(trigger)
|
||||
"""
|
||||
# Get monotonic seq for queue ordering (appended to tuple)
|
||||
async with self._seq_lock:
|
||||
queue_seq = self._seq_counter
|
||||
self._seq_counter += 1
|
||||
|
||||
# Determine priority tuple
|
||||
if priority_override is not None:
|
||||
if isinstance(priority_override, Priority):
|
||||
# Convert simple priority to tuple
|
||||
priority_tuple = (priority_override.value, queue_seq)
|
||||
else:
|
||||
# Use provided tuple, append queue_seq
|
||||
priority_tuple = priority_override + (queue_seq,)
|
||||
else:
|
||||
# Let trigger determine its own priority tuple
|
||||
priority_tuple = trigger.get_priority_tuple(queue_seq)
|
||||
|
||||
# Priority queue: (priority_tuple, trigger)
|
||||
# Python's PriorityQueue compares tuples element-by-element
|
||||
await self._queue.put((priority_tuple, trigger))
|
||||
|
||||
logger.debug(
|
||||
f"Enqueued: {trigger.name} with priority_tuple={priority_tuple}"
|
||||
)
|
||||
|
||||
return queue_seq
|
||||
|
||||
async def _process_loop(self) -> None:
|
||||
"""
|
||||
Main processing loop.
|
||||
|
||||
Dequeues triggers, assigns execution seq, executes, and submits to coordinator.
|
||||
"""
|
||||
execution_seq = 0 # Separate counter for execution sequence
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
# Wait for next trigger (with timeout to check _running flag)
|
||||
try:
|
||||
priority_tuple, trigger = await asyncio.wait_for(
|
||||
self._queue.get(), timeout=1.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
# Assign execution sequence number
|
||||
execution_seq += 1
|
||||
|
||||
logger.info(
|
||||
f"Dequeued: seq={execution_seq}, trigger={trigger.name}, "
|
||||
f"priority_tuple={priority_tuple}"
|
||||
)
|
||||
|
||||
# Execute in background (don't block queue)
|
||||
asyncio.create_task(
|
||||
self._execute_trigger(execution_seq, trigger)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in process loop: {e}", exc_info=True)
|
||||
|
||||
async def _execute_trigger(self, seq: int, trigger: Trigger) -> None:
|
||||
"""
|
||||
Execute a trigger with proper context and error handling.
|
||||
|
||||
Args:
|
||||
seq: Execution sequence number
|
||||
trigger: Trigger to execute
|
||||
"""
|
||||
# Set up execution context
|
||||
ctx = ExecutionContext(
|
||||
seq=seq,
|
||||
trigger_name=trigger.name,
|
||||
)
|
||||
set_execution_context(ctx)
|
||||
|
||||
# Record execution start with coordinator
|
||||
await self._coordinator.start_execution(seq, trigger)
|
||||
|
||||
try:
|
||||
logger.info(f"Executing: seq={seq}, trigger={trigger.name}")
|
||||
|
||||
# Execute trigger (can be long-running)
|
||||
commit_intents = await trigger.execute()
|
||||
|
||||
logger.info(
|
||||
f"Execution complete: seq={seq}, {len(commit_intents)} commit intents"
|
||||
)
|
||||
|
||||
# Submit for sequential commit
|
||||
await self._coordinator.submit_for_commit(seq, commit_intents)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Execution failed: seq={seq}, trigger={trigger.name}, error={e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Notify coordinator of failure
|
||||
await self._coordinator.execution_failed(seq, e)
|
||||
|
||||
finally:
|
||||
clear_execution_context()
|
||||
|
||||
def get_queue_size(self) -> int:
|
||||
"""Get current queue size (approximate)"""
|
||||
return self._queue.qsize()
|
||||
|
||||
def is_running(self) -> bool:
|
||||
"""Check if queue processor is running"""
|
||||
return self._running
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"TriggerQueue(running={self._running}, queue_size={self.get_queue_size()})"
|
||||
)
|
||||
187
backend.old/src/trigger/scheduler.py
Normal file
@@ -0,0 +1,187 @@
"""
APScheduler integration for cron-style triggers.

Provides scheduling of periodic triggers (e.g., sync exchange state hourly,
recompute indicators every 5 minutes, daily portfolio reports).
"""

import logging
from typing import Optional

from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger as APSCronTrigger
from apscheduler.triggers.interval import IntervalTrigger

from .queue import TriggerQueue
from .types import Priority, Trigger

logger = logging.getLogger(__name__)


class TriggerScheduler:
    """
    Scheduler for periodic trigger execution.

    Wraps APScheduler to enqueue triggers at scheduled times.
    """

    def __init__(self, trigger_queue: TriggerQueue):
        """
        Initialize scheduler.

        Args:
            trigger_queue: TriggerQueue to enqueue triggers into
        """
        self.trigger_queue = trigger_queue
        self.scheduler = AsyncIOScheduler()
        self._job_counter = 0

    def start(self) -> None:
        """Start the scheduler"""
        self.scheduler.start()
        logger.info("TriggerScheduler started")

    def shutdown(self, wait: bool = True) -> None:
        """
        Shut down the scheduler.

        Args:
            wait: If True, wait for running jobs to complete
        """
        self.scheduler.shutdown(wait=wait)
        logger.info("TriggerScheduler shut down")

    def schedule_interval(
        self,
        trigger: Trigger,
        seconds: Optional[int] = None,
        minutes: Optional[int] = None,
        hours: Optional[int] = None,
        priority: Optional[Priority] = None,
    ) -> str:
        """
        Schedule a trigger to run at regular intervals.

        Args:
            trigger: Trigger to execute
            seconds: Interval in seconds
            minutes: Interval in minutes
            hours: Interval in hours
            priority: Priority override for execution

        Returns:
            Job ID (can be used to remove job later)

        Example:
            # Run every 5 minutes
            scheduler.schedule_interval(
                IndicatorUpdateTrigger("rsi_14"),
                minutes=5
            )
        """
        job_id = f"interval_{self._job_counter}"
        self._job_counter += 1

        async def job_func():
            await self.trigger_queue.enqueue(trigger, priority)

        self.scheduler.add_job(
            job_func,
            trigger=IntervalTrigger(seconds=seconds, minutes=minutes, hours=hours),
            id=job_id,
            name=f"Interval: {trigger.name}",
        )

        logger.info(
            f"Scheduled interval job: {job_id}, trigger={trigger.name}, "
            f"interval=(s={seconds}, m={minutes}, h={hours})"
        )

        return job_id

    def schedule_cron(
        self,
        trigger: Trigger,
        minute: Optional[str] = None,
        hour: Optional[str] = None,
        day: Optional[str] = None,
        month: Optional[str] = None,
        day_of_week: Optional[str] = None,
        priority: Optional[Priority] = None,
    ) -> str:
        """
        Schedule a trigger to run on a cron schedule.

        Args:
            trigger: Trigger to execute
            minute: Minute expression (0-59, *, */5, etc.)
            hour: Hour expression (0-23, *, etc.)
            day: Day of month expression (1-31, *, etc.)
            month: Month expression (1-12, *, etc.)
            day_of_week: Day of week expression (0-6, mon-sun, *, etc.)
            priority: Priority override for execution

        Returns:
            Job ID (can be used to remove job later)

        Example:
            # Run at 9:00 AM every weekday
            scheduler.schedule_cron(
                SyncExchangeStateTrigger(),
                hour="9",
                minute="0",
                day_of_week="mon-fri"
            )
        """
        job_id = f"cron_{self._job_counter}"
        self._job_counter += 1

        async def job_func():
            await self.trigger_queue.enqueue(trigger, priority)

        self.scheduler.add_job(
            job_func,
            trigger=APSCronTrigger(
                minute=minute,
                hour=hour,
                day=day,
                month=month,
                day_of_week=day_of_week,
            ),
            id=job_id,
            name=f"Cron: {trigger.name}",
        )

        logger.info(
            f"Scheduled cron job: {job_id}, trigger={trigger.name}, "
            f"schedule=(m={minute}, h={hour}, d={day}, dow={day_of_week})"
        )

        return job_id

    def remove_job(self, job_id: str) -> bool:
        """
        Remove a scheduled job.

        Args:
            job_id: Job ID returned from schedule_* methods

        Returns:
            True if job was removed, False if not found
        """
        try:
            self.scheduler.remove_job(job_id)
            logger.info(f"Removed scheduled job: {job_id}")
            return True
        except Exception as e:
            logger.warning(f"Could not remove job {job_id}: {e}")
            return False

    def get_jobs(self) -> list:
        """Get list of all scheduled jobs"""
        return self.scheduler.get_jobs()

    def __repr__(self) -> str:
        job_count = len(self.scheduler.get_jobs())
        running = self.scheduler.running
        return f"TriggerScheduler(running={running}, jobs={job_count})"
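A minimal usage sketch (illustration only, not part of the commit) of attaching periodic jobs to an already-running TriggerQueue. HeartbeatTrigger is a toy trigger invented for the sketch, and the top-level trigger.* import path is an assumption inferred from the relative imports above.

from trigger.queue import TriggerQueue
from trigger.scheduler import TriggerScheduler
from trigger.types import CommitIntent, Priority, Trigger

class HeartbeatTrigger(Trigger):
    """Toy trigger used only for this sketch."""

    def __init__(self) -> None:
        super().__init__(name="heartbeat", priority=Priority.TIMER)

    async def execute(self) -> list[CommitIntent]:
        return []  # no state changes, nothing to commit

def wire_scheduler(queue: TriggerQueue) -> TriggerScheduler:
    scheduler = TriggerScheduler(queue)
    scheduler.start()  # AsyncIOScheduler needs a running asyncio event loop
    # Enqueue a heartbeat every 30 seconds; keep the job id so it can be removed later.
    job_id = scheduler.schedule_interval(HeartbeatTrigger(), seconds=30)
    # A 09:00 weekday job would instead use:
    # scheduler.schedule_cron(HeartbeatTrigger(), hour="9", minute="0", day_of_week="mon-fri")
    return scheduler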
301
backend.old/src/trigger/store.py
Normal file
@@ -0,0 +1,301 @@
"""
Versioned store with pluggable backends.

Provides optimistic concurrency control via sequence numbers with support
for different storage backends (Pydantic models, files, databases, etc.).
"""

import logging
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Generic, TypeVar

from .context import get_execution_context
from .types import CommitIntent

logger = logging.getLogger(__name__)

T = TypeVar("T")


class StoreBackend(ABC, Generic[T]):
    """
    Abstract backend for versioned stores.

    Allows different storage mechanisms (Pydantic models, files, databases)
    to be used with the same versioned store infrastructure.
    """

    @abstractmethod
    def read(self) -> T:
        """
        Read the current data.

        Returns:
            Current data in backend-specific format
        """
        pass

    @abstractmethod
    def write(self, data: T) -> None:
        """
        Write new data (replaces existing).

        Args:
            data: New data to write
        """
        pass

    @abstractmethod
    def snapshot(self) -> T:
        """
        Create an immutable snapshot of current data.

        Must return a deep copy or immutable version to prevent
        modifications from affecting the committed state.

        Returns:
            Immutable snapshot of data
        """
        pass

    @abstractmethod
    def validate(self, data: T) -> bool:
        """
        Validate that data is in correct format for this backend.

        Args:
            data: Data to validate

        Returns:
            True if valid

        Raises:
            ValueError: If invalid with explanation
        """
        pass


class PydanticStoreBackend(StoreBackend[T]):
    """
    Backend for Pydantic BaseModel stores.

    Supports the existing OrderStore, ChartStore, etc. pattern.
    """

    def __init__(self, model_instance: T):
        """
        Initialize with a Pydantic model instance.

        Args:
            model_instance: Instance of a Pydantic BaseModel
        """
        self._model = model_instance

    def read(self) -> T:
        return self._model

    def write(self, data: T) -> None:
        # Replace the internal model
        self._model = data

    def snapshot(self) -> T:
        # Use Pydantic's model_copy for deep copy
        if hasattr(self._model, "model_copy"):
            return self._model.model_copy(deep=True)
        # Fallback for older Pydantic or non-model types
        return deepcopy(self._model)

    def validate(self, data: T) -> bool:
        # Pydantic models validate themselves on construction.
        # If we got here with a model instance, it's valid.
        return True


class FileStoreBackend(StoreBackend[str]):
    """
    Backend for file-based storage.

    Future implementation for versioning files (e.g., Python scripts, configs).
    """

    def __init__(self, file_path: str):
        self.file_path = file_path
        raise NotImplementedError("FileStoreBackend not yet implemented")

    def read(self) -> str:
        raise NotImplementedError()

    def write(self, data: str) -> None:
        raise NotImplementedError()

    def snapshot(self) -> str:
        raise NotImplementedError()

    def validate(self, data: str) -> bool:
        raise NotImplementedError()


class DatabaseStoreBackend(StoreBackend[dict]):
    """
    Backend for database table storage.

    Future implementation for versioning database interactions.
    """

    def __init__(self, table_name: str, connection):
        self.table_name = table_name
        self.connection = connection
        raise NotImplementedError("DatabaseStoreBackend not yet implemented")

    def read(self) -> dict:
        raise NotImplementedError()

    def write(self, data: dict) -> None:
        raise NotImplementedError()

    def snapshot(self) -> dict:
        raise NotImplementedError()

    def validate(self, data: dict) -> bool:
        raise NotImplementedError()


class VersionedStore(Generic[T]):
    """
    Store with optimistic concurrency control via sequence numbers.

    Wraps any StoreBackend and provides:
    - Lock-free snapshot reads
    - Conflict detection on commit
    - Version tracking for debugging
    """

    def __init__(self, name: str, backend: StoreBackend[T]):
        """
        Initialize versioned store.

        Args:
            name: Unique name for this store (e.g., "OrderStore")
            backend: Backend implementation for storage
        """
        self.name = name
        self._backend = backend
        self._committed_seq = 0  # Highest committed seq
        self._version = 0  # Increments on each commit (for debugging)

    @property
    def committed_seq(self) -> int:
        """Get the current committed sequence number"""
        return self._committed_seq

    @property
    def version(self) -> int:
        """Get the current version (increments on each commit)"""
        return self._version

    def read_snapshot(self) -> tuple[int, T]:
        """
        Read an immutable snapshot of the store.

        This is lock-free and can be called concurrently. The snapshot
        captures the current committed seq and a deep copy of the data.

        Automatically records the snapshot seq in the execution context
        for conflict detection during commit.

        Returns:
            Tuple of (seq, snapshot_data)
        """
        snapshot_seq = self._committed_seq
        snapshot_data = self._backend.snapshot()

        # Record in execution context for conflict detection
        ctx = get_execution_context()
        if ctx:
            ctx.record_snapshot(self.name, snapshot_seq)

        logger.debug(
            f"Store '{self.name}': read_snapshot() -> seq={snapshot_seq}, version={self._version}"
        )

        return (snapshot_seq, snapshot_data)

    def read_current(self) -> T:
        """
        Read the current data without snapshot tracking.

        Use this for read-only operations that don't need conflict detection.

        Returns:
            Current data (not a snapshot, modifications visible)
        """
        return self._backend.read()

    def prepare_commit(self, expected_seq: int, new_data: T) -> CommitIntent:
        """
        Create a commit intent for later sequential commit.

        Does NOT modify the store - that happens during the commit phase.

        Args:
            expected_seq: The seq of the snapshot that was read
            new_data: The new data to commit

        Returns:
            CommitIntent to be submitted to CommitCoordinator
        """
        # Validate data before creating intent
        self._backend.validate(new_data)

        intent = CommitIntent(
            store_name=self.name,
            expected_seq=expected_seq,
            new_data=new_data,
        )

        logger.debug(
            f"Store '{self.name}': prepare_commit(expected_seq={expected_seq}, current_seq={self._committed_seq})"
        )

        return intent

    def commit(self, new_data: T, commit_seq: int) -> None:
        """
        Commit new data at a specific seq.

        Called by CommitCoordinator during sequential commit phase.
        NOT for direct use by triggers.

        Args:
            new_data: Data to commit
            commit_seq: Seq number of this commit
        """
        self._backend.write(new_data)
        self._committed_seq = commit_seq
        self._version += 1

        logger.info(
            f"Store '{self.name}': committed seq={commit_seq}, version={self._version}"
        )

    def check_conflict(self, expected_seq: int) -> bool:
        """
        Check if committing at expected_seq would conflict.

        Args:
            expected_seq: The seq that was expected during execution

        Returns:
            True if conflict (committed_seq has advanced beyond expected_seq)
        """
        has_conflict = self._committed_seq != expected_seq
        if has_conflict:
            logger.warning(
                f"Store '{self.name}': conflict detected - "
                f"expected_seq={expected_seq}, committed_seq={self._committed_seq}"
            )
        return has_conflict

    def __repr__(self) -> str:
        return f"VersionedStore(name='{self.name}', committed_seq={self._committed_seq}, version={self._version})"
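A minimal sketch (illustration only, not part of the commit) of the snapshot -> prepare_commit -> commit cycle using the Pydantic backend above. CounterState is a throwaway model and the trigger.* import path is an assumption; the explicit check_conflict/commit calls stand in for what the CommitCoordinator does sequentially.

from pydantic import BaseModel

from trigger.store import PydanticStoreBackend, VersionedStore

class CounterState(BaseModel):
    count: int = 0

store = VersionedStore("CounterStore", PydanticStoreBackend(CounterState()))

# 1. Execution phase: lock-free snapshot read (seq == 0 on a fresh store)
seq, snap = store.read_snapshot()
new_state = snap.model_copy(update={"count": snap.count + 1})

# 2. Describe the change without touching the store
intent = store.prepare_commit(expected_seq=seq, new_data=new_state)

# 3. Commit phase (normally driven by the CommitCoordinator, one seq at a time)
if not store.check_conflict(intent.expected_seq):
    store.commit(intent.new_data, commit_seq=seq + 1)

print(store.read_current().count)  # -> 1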
175
backend.old/src/trigger/types.py
Normal file
@@ -0,0 +1,175 @@
"""
Core types for the trigger system.
"""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any, Optional

logger = logging.getLogger(__name__)


class Priority(IntEnum):
    """
    Primary execution priority for triggers.

    Lower numeric value = higher priority (dequeued first).

    Priority hierarchy (highest to lowest):
    - DATA_SOURCE: Market data, real-time feeds (most time-sensitive)
    - TIMER: Scheduled tasks, cron jobs
    - USER_AGENT: User-agent interactions (WebSocket chat)
    - USER_DATA_REQUEST: User data requests (chart loads, symbol search)
    - SYSTEM: Background tasks, cleanup
    - LOW: Retries after conflicts, non-critical tasks
    """

    DATA_SOURCE = 0  # Market data updates, real-time feeds
    TIMER = 1  # Scheduled tasks, cron jobs
    USER_AGENT = 2  # User-agent interactions (WebSocket chat)
    USER_DATA_REQUEST = 3  # User data requests (chart loads, etc.)
    SYSTEM = 4  # Background tasks, cleanup, etc.
    LOW = 5  # Retries after conflicts, non-critical tasks


# Type alias for priority tuples
# Examples:
#   (Priority.DATA_SOURCE,) - Simple priority
#   (Priority.DATA_SOURCE, event_time) - Priority + event time
#   (Priority.DATA_SOURCE, event_time, queue_seq) - Full ordering
#
# Python compares tuples element-by-element, left-to-right.
# Shorter tuple wins if all shared elements are equal.
PriorityTuple = tuple[int, ...]


class ExecutionState(IntEnum):
    """State of an execution in the system"""

    QUEUED = 0  # In queue, waiting to be dequeued
    EXECUTING = 1  # Currently executing
    WAITING_COMMIT = 2  # Finished executing, waiting for sequential commit
    COMMITTED = 3  # Successfully committed
    EVICTED = 4  # Evicted due to conflict, will retry
    FAILED = 5  # Failed with error


@dataclass
class CommitIntent:
    """
    Intent to commit changes to a store.

    Created during execution, validated and applied during sequential commit phase.
    """

    store_name: str
    """Name of the store to commit to"""

    expected_seq: int
    """The seq number of the snapshot that was read (for conflict detection)"""

    new_data: Any
    """The new data to commit (format depends on store backend)"""

    def __repr__(self) -> str:
        data_preview = str(self.new_data)[:50]
        return f"CommitIntent(store={self.store_name}, expected_seq={self.expected_seq}, data={data_preview}...)"


class Trigger(ABC):
    """
    Abstract base class for all triggers.

    A trigger represents a unit of work that:
    1. Gets assigned a seq number when dequeued
    2. Executes (potentially long-running, async)
    3. Returns CommitIntents for any state changes
    4. Waits for sequential commit
    """

    def __init__(
        self,
        name: str,
        priority: Priority = Priority.SYSTEM,
        priority_tuple: Optional[PriorityTuple] = None
    ):
        """
        Initialize trigger.

        Args:
            name: Descriptive name for logging
            priority: Simple priority (used if priority_tuple not provided)
            priority_tuple: Optional tuple for compound sorting
                Examples:
                    (Priority.DATA_SOURCE, event_time)
                    (Priority.USER_AGENT, message_timestamp)
                    (Priority.TIMER, scheduled_time)
        """
        self.name = name
        self.priority = priority
        self._priority_tuple = priority_tuple

    def get_priority_tuple(self, queue_seq: int) -> PriorityTuple:
        """
        Get the priority tuple for queue ordering.

        If a priority tuple was provided at construction, append queue_seq.
        Otherwise, create tuple from simple priority.

        Args:
            queue_seq: Queue insertion order (final sort key)

        Returns:
            Priority tuple for queue ordering

        Examples:
            (Priority.DATA_SOURCE,) + (queue_seq,) = (0, queue_seq)
            (Priority.DATA_SOURCE, 1000) + (queue_seq,) = (0, 1000, queue_seq)
        """
        if self._priority_tuple is not None:
            return self._priority_tuple + (queue_seq,)
        else:
            return (self.priority.value, queue_seq)

    @abstractmethod
    async def execute(self) -> list[CommitIntent]:
        """
        Execute the trigger logic.

        Can be long-running and async. Should read from stores via
        VersionedStore.read_snapshot() and return CommitIntents for any changes.

        Returns:
            List of CommitIntents (empty if no state changes)

        Raises:
            Exception: On execution failure (will be logged, no commit)
        """
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(name='{self.name}', priority={self.priority.name})"


@dataclass
class ExecutionRecord:
    """
    Record of an execution for tracking and debugging.

    Maintained by the CommitCoordinator to track in-flight executions.
    """

    seq: int
    trigger: Trigger
    state: ExecutionState
    commit_intents: Optional[list[CommitIntent]] = None
    error: Optional[str] = None
    retry_count: int = 0

    def __repr__(self) -> str:
        return (
            f"ExecutionRecord(seq={self.seq}, trigger={self.trigger.name}, "
            f"state={self.state.name}, retry={self.retry_count})"
        )
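A minimal concrete Trigger sketch (illustration only, not part of the commit) showing the execute-and-return-intents contract against a VersionedStore. Portfolio, MarkToMarketTrigger, and the trigger.* import path are assumptions made for the example.

import asyncio

from pydantic import BaseModel

from trigger.store import PydanticStoreBackend, VersionedStore
from trigger.types import CommitIntent, Priority, Trigger

class Portfolio(BaseModel):
    equity: float = 0.0

class MarkToMarketTrigger(Trigger):
    """Hypothetical trigger used only for this sketch."""

    def __init__(self, store: VersionedStore) -> None:
        super().__init__(name="mark_to_market", priority=Priority.TIMER)
        self._store = store

    async def execute(self) -> list[CommitIntent]:
        # 1. Snapshot read (records the seq used for conflict detection)
        seq, snap = self._store.read_snapshot()
        # 2. Compute the new state from the snapshot, never from live data
        new_state = snap.model_copy(update={"equity": snap.equity + 1.0})
        # 3. Return a CommitIntent; the CommitCoordinator applies it sequentially
        return [self._store.prepare_commit(expected_seq=seq, new_data=new_state)]

store = VersionedStore("PortfolioStore", PydanticStoreBackend(Portfolio()))
intents = asyncio.run(MarkToMarketTrigger(store).execute())
print(intents)  # one CommitIntent with expected_seq=0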