# FastAPI backend entry point: synced stores, agent gateway, and datafeed served over a single WebSocket.
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
|
|
import yaml
|
|
from dotenv import load_dotenv
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File, HTTPException
|
|
from fastapi.responses import FileResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from pydantic import BaseModel
|
|
import uuid
|
|
import shutil
|
|
|
|
from sync.protocol import HelloMessage, PatchMessage, AuthMessage, AuthResponseMessage
|
|
from sync.registry import SyncRegistry
|
|
from gateway.hub import Gateway
|
|
from gateway.channels.websocket import WebSocketChannel
|
|
from gateway.protocol import WebSocketAgentUserMessage
|
|
from agent.core import create_agent
|
|
from agent.tools import set_registry, set_datasource_registry, set_indicator_registry
|
|
from agent.tools import set_trigger_queue, set_trigger_scheduler, set_coordinator
|
|
from schema.order_spec import SwapOrder
|
|
from schema.chart_state import ChartState
|
|
from schema.shape import ShapeCollection
|
|
from schema.indicator import IndicatorCollection
|
|
from datasource.registry import DataSourceRegistry
|
|
from datasource.subscription_manager import SubscriptionManager
|
|
from datasource.websocket_handler import DatafeedWebSocketHandler
|
|
from secrets_manager import SecretsStore, InvalidMasterPassword
|
|
from indicator import IndicatorRegistry, register_all_talib_indicators, register_custom_indicators
|
|
from trigger import CommitCoordinator, TriggerQueue
|
|
from trigger.scheduler import TriggerScheduler
|
|
|
|
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Load environment variables from .env file (if present)
env_path = Path(__file__).parent.parent / ".env"
if env_path.exists():
    load_dotenv(env_path)

# Load configuration (required; startup fails fast if config.yaml is missing)
config_path = Path(__file__).parent.parent / "config.yaml"
with open(config_path) as f:
    config = yaml.safe_load(f)

# Core singletons shared by the WebSocket endpoint and the lifespan hook.
registry = SyncRegistry()
gateway = Gateway()
agent_executor = None  # populated in lifespan() once an API key is available

# DataSource infrastructure
datasource_registry = DataSourceRegistry()
subscription_manager = SubscriptionManager()

# Indicator infrastructure
indicator_registry = IndicatorRegistry()

# Trigger system infrastructure (all three are created in lifespan())
trigger_coordinator = None
trigger_queue = None
trigger_scheduler = None

# Global secrets store (unlocked per-connection via the WebSocket auth handshake)
secrets_store = SecretsStore()
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize agent system and data sources on startup.

    Everything before ``yield`` runs once at application startup; everything
    after runs at shutdown. Populates the module-level globals declared above.
    Each subsystem is initialized best-effort: a failure disables that
    subsystem but does not prevent the server from starting.
    """
    global agent_executor, trigger_coordinator, trigger_queue, trigger_scheduler

    # Initialize CCXT data sources (optional dependency; skipped when ccxt is absent)
    try:
        from datasource.adapters.ccxt_adapter import CCXTDataSource

        # Binance -- failure here is non-fatal; the server runs without it
        try:
            binance_source = CCXTDataSource(exchange_id="binance", poll_interval=60)
            datasource_registry.register("binance", binance_source)
            subscription_manager.register_source("binance", binance_source)
            logger.info("DataSource: Registered Binance source")
        except Exception as e:
            logger.warning(f"Failed to initialize Binance source: {e}")

        logger.info(f"DataSource infrastructure initialized with sources: {datasource_registry.list_sources()}")
    except ImportError as e:
        logger.warning(f"CCXT not available: {e}. Only demo source will be available.")
        logger.info("To use real exchange data, install ccxt: pip install ccxt>=4.0.0")

    # Initialize indicator registry with all TA-Lib indicators (best-effort)
    try:
        indicator_count = register_all_talib_indicators(indicator_registry)
        logger.info(f"Indicator registry initialized with {indicator_count} TA-Lib indicators")
    except Exception as e:
        logger.warning(f"Failed to register TA-Lib indicators: {e}")
        logger.info("TA-Lib indicators will not be available. Install TA-Lib C library and Python wrapper to enable.")

    # Register custom indicators (TradingView indicators not in TA-Lib)
    try:
        custom_count = register_custom_indicators(indicator_registry)
        logger.info(f"Registered {custom_count} custom indicators")
    except Exception as e:
        logger.warning(f"Failed to register custom indicators: {e}")

    # Get API keys from secrets store if unlocked, otherwise fall back to environment.
    # NOTE(review): at startup the store is normally still locked (it is unlocked
    # via the WebSocket auth handshake), so the environment fallback is the
    # common path on a cold start.
    anthropic_api_key = None

    if secrets_store.is_unlocked:
        anthropic_api_key = secrets_store.get("ANTHROPIC_API_KEY")
        if anthropic_api_key:
            logger.info("Loaded API key from encrypted secrets store")

    # Fall back to environment variable
    if not anthropic_api_key:
        anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
        if anthropic_api_key:
            logger.info("Loaded API key from environment")

    # Initialize trigger system (coordinator -> queue -> scheduler dependency chain)
    logger.info("Initializing trigger system...")
    trigger_coordinator = CommitCoordinator()
    trigger_queue = TriggerQueue(trigger_coordinator)
    trigger_scheduler = TriggerScheduler(trigger_queue)

    # Start trigger queue and scheduler
    await trigger_queue.start()
    trigger_scheduler.start()
    logger.info("Trigger system initialized and started")

    # Set trigger system for agent tools
    set_coordinator(trigger_coordinator)
    set_trigger_queue(trigger_queue)
    set_trigger_scheduler(trigger_scheduler)

    if not anthropic_api_key:
        # No key anywhere: the server still runs, but agent features are disabled.
        logger.error("ANTHROPIC_API_KEY not found in environment!")
        logger.info("Agent system will not be available")
    else:
        # Set the registries for agent tools
        set_registry(registry)
        set_datasource_registry(datasource_registry)
        set_indicator_registry(indicator_registry)

        # Create and initialize agent
        agent_executor = create_agent(
            model_name=config["agent"]["model"],
            temperature=config["agent"]["temperature"],
            api_key=anthropic_api_key,
            checkpoint_db_path=config["memory"]["checkpoint_db"],
            chroma_db_path=config["memory"]["chroma_db"],
            embedding_model=config["memory"]["embedding_model"],
            context_docs_dir=config["agent"]["context_docs_dir"],
            base_dir="." # backend/src is the working directory, so . goes to backend, where memory/ and soul/ live
        )

        await agent_executor.initialize()

        # Set agent executor in gateway
        gateway.set_agent_executor(agent_executor.execute)

        logger.info("Agent system initialized")

    yield

    # Cleanup (reverse of startup order)
    logger.info("Shutting down systems...")

    # Shutdown trigger system
    if trigger_scheduler:
        trigger_scheduler.shutdown(wait=True)
        logger.info("Trigger scheduler shut down")

    if trigger_queue:
        await trigger_queue.stop()
        logger.info("Trigger queue stopped")

    # Shutdown agent system
    if agent_executor and agent_executor.memory_manager:
        await agent_executor.memory_manager.close()

    logger.info("All systems shut down")
app = FastAPI(lifespan=lifespan)

# Create uploads directory (backend/data/uploads), shared with /api/upload below
UPLOAD_DIR = Path(__file__).parent.parent / "data" / "uploads"
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

# Mount static files for serving uploads back to the client
app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
|
# OrderStore model for synchronization
class OrderStore(BaseModel):
    """Synced store holding all swap orders, mirrored to the client via SyncRegistry."""
    orders: list[SwapOrder] = []  # pydantic copies mutable defaults per instance
|
# ChartStore model for synchronization
class ChartStore(BaseModel):
    """Synced store holding the current chart view state."""
    chart_state: ChartState = ChartState()  # default evaluated once; pydantic copies it per instance
|
# ShapeStore model for synchronization
class ShapeStore(BaseModel):
    """Synced store of drawing shapes."""
    shapes: dict[str, dict] = {}  # Dictionary of shapes keyed by ID
|
# IndicatorStore model for synchronization
class IndicatorStore(BaseModel):
    """Synced store of active chart indicators."""
    indicators: dict[str, dict] = {}  # Dictionary of indicators keyed by ID
|
# Initialize stores (one module-level instance of each synced model)
order_store = OrderStore()
chart_store = ChartStore()
shape_store = ShapeStore()
indicator_store = IndicatorStore()

# Register with SyncRegistry; store_name is the key used by the sync protocol
# ("hello"/"patch" messages) to address each store.
registry.register(order_store, store_name="OrderStore")
registry.register(chart_store, store_name="ChartStore")
registry.register(shape_store, store_name="ShapeStore")
registry.register(indicator_store, store_name="IndicatorStore")
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Single multiplexed WebSocket endpoint.

    Protocol: the FIRST client message must be an ``auth`` message, within
    30 seconds. Depending on secrets-store state this performs first-time
    setup (with password confirmation), unlock/verification, and an optional
    master-password change. After authentication the loop dispatches three
    message families:

    - sync protocol ("hello" / "patch") against the SyncRegistry stores
    - agent traffic ("agent_user_message") routed through the Gateway
    - datafeed requests ("get_config", "search_symbols", "resolve_symbol",
      "get_bars", "subscribe_bars", "unsubscribe_bars")

    Close codes: 1008 for policy/auth failures, 1011 for internal errors.
    """
    await websocket.accept()

    # Helper function to send responses (send failures are logged, never raised)
    async def send_response(response):
        try:
            await websocket.send_json(response.model_dump(mode="json"))
        except Exception as e:
            logger.error(f"Error sending response: {e}")

    # Wait for authentication message (must be first message)
    try:
        auth_timeout = 30  # 30 seconds to authenticate
        auth_data = await asyncio.wait_for(websocket.receive_text(), timeout=auth_timeout)
        auth_message_json = json.loads(auth_data)

        if auth_message_json.get("type") != "auth":
            logger.warning("First message was not auth message")
            await send_response(AuthResponseMessage(
                success=False,
                message="First message must be authentication"
            ))
            await websocket.close(code=1008, reason="Authentication required")
            return

        auth_msg = AuthMessage(**auth_message_json)
        logger.info("Received authentication message")

        # Check if secrets store needs initialization
        if not secrets_store.is_initialized:
            logger.info("Secrets store not initialized, performing first-time setup")

            # Require password confirmation for initialization
            if not auth_msg.confirm_password:
                await send_response(AuthResponseMessage(
                    success=False,
                    needs_confirmation=True,
                    message="First-time setup: password confirmation required"
                ))
                await websocket.close(code=1008, reason="Password confirmation required")
                return

            if auth_msg.password != auth_msg.confirm_password:
                await send_response(AuthResponseMessage(
                    success=False,
                    needs_confirmation=True,
                    message="Passwords do not match"
                ))
                await websocket.close(code=1008, reason="Password confirmation failed")
                return

            # Initialize secrets store
            try:
                secrets_store.initialize(auth_msg.password)

                # Migrate ANTHROPIC_API_KEY from environment if present
                env_key = os.environ.get("ANTHROPIC_API_KEY")
                if env_key:
                    secrets_store.set("ANTHROPIC_API_KEY", env_key)
                    logger.info("Migrated ANTHROPIC_API_KEY from environment to secrets store")

                await send_response(AuthResponseMessage(
                    success=True,
                    message="Secrets store initialized successfully"
                ))
                logger.info("Secrets store initialized and authenticated")
            except Exception as e:
                logger.error(f"Failed to initialize secrets store: {e}")
                await send_response(AuthResponseMessage(
                    success=False,
                    message=f"Initialization failed: {str(e)}"
                ))
                await websocket.close(code=1011, reason="Initialization failed")
                return
        else:
            # Unlock existing secrets store (or verify password if already unlocked)
            try:
                # If already unlocked, just verify the password is correct
                if secrets_store.is_unlocked:
                    # Verify password by creating a temporary store and attempting unlock
                    from secrets_manager import SecretsStore as TempStore
                    temp_store = TempStore(data_dir=secrets_store.data_dir)
                    temp_store.unlock(auth_msg.password)  # This will throw if wrong password
                    logger.info("Password verified (store already unlocked)")
                else:
                    secrets_store.unlock(auth_msg.password)
                    logger.info("Secrets store unlocked successfully")

                # Check if user wants to change password
                password_changed = False
                if auth_msg.change_to_password:
                    # Validate password change request
                    if not auth_msg.confirm_new_password:
                        await send_response(AuthResponseMessage(
                            success=False,
                            message="New password confirmation required"
                        ))
                        await websocket.close(code=1008, reason="Password confirmation required")
                        return

                    if auth_msg.change_to_password != auth_msg.confirm_new_password:
                        await send_response(AuthResponseMessage(
                            success=False,
                            message="New passwords do not match"
                        ))
                        await websocket.close(code=1008, reason="Password confirmation mismatch")
                        return

                    # Change the password
                    try:
                        secrets_store.change_master_password(auth_msg.password, auth_msg.change_to_password)
                        password_changed = True
                        logger.info("Master password changed successfully")
                    except Exception as e:
                        logger.error(f"Failed to change password: {e}")
                        await send_response(AuthResponseMessage(
                            success=False,
                            message=f"Failed to change password: {str(e)}"
                        ))
                        await websocket.close(code=1011, reason="Password change failed")
                        return

                response_message = "Password changed successfully" if password_changed else "Authentication successful"
                await send_response(AuthResponseMessage(
                    success=True,
                    password_changed=password_changed,
                    message=response_message
                ))
            except InvalidMasterPassword:
                logger.warning("Invalid password attempt")
                await send_response(AuthResponseMessage(
                    success=False,
                    message="Invalid password"
                ))
                await websocket.close(code=1008, reason="Invalid password")
                return
            except Exception as e:
                logger.error(f"Authentication error: {e}")
                await send_response(AuthResponseMessage(
                    success=False,
                    message="Authentication failed"
                ))
                await websocket.close(code=1011, reason="Authentication error")
                return

    except asyncio.TimeoutError:
        logger.warning("Authentication timeout")
        await websocket.close(code=1008, reason="Authentication timeout")
        return
    except WebSocketDisconnect:
        logger.info("Client disconnected during authentication")
        return
    except Exception as e:
        logger.error(f"Error during authentication: {e}")
        await websocket.close(code=1011, reason="Authentication error")
        return

    # Now authenticated - proceed with normal WebSocket handling.
    # NOTE(review): only one sync client is tracked at a time; a second
    # connection overwrites registry.websocket.
    registry.websocket = websocket

    # Create WebSocket channel for agent communication
    channel_id = f"ws_{id(websocket)}"
    client_id = f"client_{id(websocket)}"
    logger.info(f"WebSocket authenticated - channel_id: {channel_id}, client_id: {client_id}")
    ws_channel = WebSocketChannel(channel_id, websocket, session_id="default")
    gateway.register_channel(ws_channel)

    try:
        while True:
            data = await websocket.receive_text()
            logger.debug(f"Received WebSocket message: {data[:200]}...")  # Log first 200 chars
            message_json = json.loads(data)

            if "type" not in message_json:
                logger.warning(f"Message missing 'type' field: {message_json}")
                continue

            msg_type = message_json["type"]
            logger.info(f"Processing message type: {msg_type}")

            # Handle sync protocol messages
            if msg_type == "hello":
                hello_msg = HelloMessage(**message_json)
                logger.info(f"Hello message received with seqs: {hello_msg.seqs}")
                await registry.sync_client(hello_msg.seqs)
            elif msg_type == "patch":
                patch_msg = PatchMessage(**message_json)
                logger.info(f"Patch message received for store: {patch_msg.store}, seq: {patch_msg.seq}")
                try:
                    await registry.apply_client_patch(
                        store_name=patch_msg.store,
                        client_base_seq=patch_msg.seq,
                        patch=patch_msg.patch
                    )
                except Exception as e:
                    logger.error(f"Error applying client patch: {e}. Client will receive snapshot to resync.", exc_info=True)
            elif msg_type == "agent_user_message":
                # Handle agent messages directly here
                logger.info(f"Raw message_json: {message_json}")
                msg = WebSocketAgentUserMessage(**message_json)
                logger.info(f"Agent user message received - session: {msg.session_id}, content: '{msg.content}' (len={len(msg.content)})")
                from gateway.protocol import UserMessage
                from datetime import datetime, timezone

                user_msg = UserMessage(
                    session_id=msg.session_id,
                    channel_id=channel_id,
                    content=msg.content,
                    attachments=msg.attachments,
                    timestamp=datetime.now(timezone.utc)
                )
                logger.info(f"Routing user message to gateway - channel: {channel_id}, session: {msg.session_id}")
                await gateway.route_user_message(user_msg)
                logger.info("Message routing completed")

            # Handle datafeed protocol messages
            elif msg_type in ["get_config", "search_symbols", "resolve_symbol", "get_bars", "subscribe_bars", "unsubscribe_bars"]:
                from datasource.websocket_protocol import (
                    GetConfigRequest, GetConfigResponse,
                    SearchSymbolsRequest, SearchSymbolsResponse,
                    ResolveSymbolRequest, ResolveSymbolResponse,
                    GetBarsRequest, GetBarsResponse,
                    SubscribeBarsRequest, SubscribeBarsResponse,
                    UnsubscribeBarsRequest, UnsubscribeBarsResponse,
                    ErrorResponse
                )

                request_id = message_json.get("request_id", "unknown")
                try:
                    if msg_type == "get_config":
                        req = GetConfigRequest(**message_json)
                        logger.info(f"Getting config, request_id={req.request_id}")
                        sources = datasource_registry.list_sources()
                        logger.info(f"Available sources: {sources}")

                        if not sources:
                            error_response = ErrorResponse(request_id=req.request_id, error_code="NO_SOURCES", error_message="No data sources available")
                            await send_response(error_response)
                        else:
                            # Get config from first source (we can enhance this later to aggregate)
                            source = datasource_registry.get(sources[0])
                            if source:
                                try:
                                    # Renamed from `config` to avoid shadowing the
                                    # module-level config dict.
                                    base_config = await source.get_config()
                                    logger.info(f"Got config from {sources[0]}")
                                    # Enhance with all available exchanges
                                    all_exchanges = set()
                                    for source_name in sources:
                                        s = datasource_registry.get(source_name)
                                        if s:
                                            try:
                                                cfg = await asyncio.wait_for(s.get_config(), timeout=5.0)
                                                all_exchanges.update(cfg.exchanges)
                                                logger.info(f"Added exchanges from {source_name}: {cfg.exchanges}")
                                            except asyncio.TimeoutError:
                                                logger.warning(f"Timeout getting config from {source_name}")
                                            except Exception as e:
                                                logger.warning(f"Error getting config from {source_name}: {e}")
                                    config_dict = base_config.model_dump(mode="json")
                                    config_dict["exchanges"] = list(all_exchanges)
                                    logger.info(f"Sending config with exchanges: {list(all_exchanges)}")
                                    response = GetConfigResponse(request_id=req.request_id, config=config_dict)
                                    await send_response(response)
                                except Exception as e:
                                    logger.error(f"Error getting config: {e}", exc_info=True)
                                    error_response = ErrorResponse(request_id=req.request_id, error_code="ERROR", error_message=str(e))
                                    await send_response(error_response)
                            else:
                                error_response = ErrorResponse(request_id=req.request_id, error_code="SOURCE_NOT_FOUND", error_message="Data sources not available")
                                await send_response(error_response)

                    elif msg_type == "search_symbols":
                        req = SearchSymbolsRequest(**message_json)
                        logger.info(f"Searching symbols: query='{req.query}', request_id={req.request_id}")

                        # Search all data sources, tolerating per-source failures
                        all_results = []
                        sources = datasource_registry.list_sources()
                        logger.info(f"Available data sources: {sources}")

                        for source_name in sources:
                            source = datasource_registry.get(source_name)
                            if source:
                                try:
                                    results = await asyncio.wait_for(
                                        source.search_symbols(query=req.query, type=req.symbol_type, exchange=req.exchange, limit=req.limit),
                                        timeout=5.0
                                    )
                                    all_results.extend([r.model_dump(mode="json") for r in results])
                                    logger.info(f"Source '{source_name}' returned {len(results)} results")
                                except asyncio.TimeoutError:
                                    logger.warning(f"Timeout searching source '{source_name}'")
                                except Exception as e:
                                    logger.warning(f"Error searching source '{source_name}': {e}")

                        logger.info(f"Total search results: {len(all_results)}")
                        response = SearchSymbolsResponse(request_id=req.request_id, results=all_results[:req.limit])
                        await send_response(response)

                    elif msg_type == "resolve_symbol":
                        req = ResolveSymbolRequest(**message_json)
                        logger.info(f"Resolving symbol: {req.symbol}")

                        # Parse ticker format: "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD")
                        symbol = req.symbol
                        source_name = None
                        symbol_without_exchange = symbol

                        # Check if ticker has exchange prefix
                        if ":" in symbol:
                            exchange_prefix, symbol_without_exchange = symbol.split(":", 1)
                            source_name = exchange_prefix.lower()
                            logger.info(f"Parsed ticker: exchange={source_name}, symbol={symbol_without_exchange}")

                        # If we identified a source, try it directly
                        if source_name:
                            try:
                                source = datasource_registry.get(source_name)
                                if source:
                                    logger.info(f"Trying to resolve '{symbol_without_exchange}' in source '{source_name}'")
                                    symbol_info = await asyncio.wait_for(
                                        source.resolve_symbol(symbol_without_exchange),
                                        timeout=5.0
                                    )
                                    logger.info(f"Successfully resolved '{symbol_without_exchange}' in source '{source_name}'")
                                    response = ResolveSymbolResponse(request_id=req.request_id, symbol_info=symbol_info.model_dump(mode="json"))
                                    await send_response(response)
                                else:
                                    error_response = ErrorResponse(request_id=req.request_id, error_code="SOURCE_NOT_FOUND", error_message=f"Data source '{source_name}' not found")
                                    await send_response(error_response)
                            except asyncio.TimeoutError:
                                logger.warning(f"Timeout resolving '{symbol_without_exchange}' in source '{source_name}'")
                                error_response = ErrorResponse(request_id=req.request_id, error_code="TIMEOUT", error_message=f"Timeout resolving symbol")
                                await send_response(error_response)
                            except Exception as e:
                                logger.warning(f"Error resolving '{symbol_without_exchange}' in source '{source_name}': {e}")
                                error_response = ErrorResponse(request_id=req.request_id, error_code="SYMBOL_NOT_FOUND", error_message=str(e))
                                await send_response(error_response)
                        else:
                            # No exchange prefix, try all sources
                            found = False
                            for src in datasource_registry.list_sources():
                                try:
                                    s = datasource_registry.get(src)
                                    if s:
                                        logger.info(f"Trying to resolve '{symbol}' in source '{src}'")
                                        symbol_info = await asyncio.wait_for(s.resolve_symbol(symbol), timeout=5.0)
                                        if symbol_info:
                                            logger.info(f"Successfully resolved '{symbol}' in source '{src}'")
                                            response = ResolveSymbolResponse(request_id=req.request_id, symbol_info=symbol_info.model_dump(mode="json"))
                                            await send_response(response)
                                            found = True
                                            break
                                except asyncio.TimeoutError:
                                    logger.warning(f"Timeout resolving '{symbol}' in source '{src}'")
                                except Exception as e:
                                    logger.info(f"Symbol '{symbol}' not found in source '{src}': {e}")
                                    continue

                            if not found:
                                # Symbol not found in any source
                                logger.warning(f"Symbol '{symbol}' not found in any data source")
                                error_response = ErrorResponse(request_id=req.request_id, error_code="SYMBOL_NOT_FOUND", error_message=f"Symbol '{symbol}' not found in any data source")
                                await send_response(error_response)

                    elif msg_type == "get_bars":
                        req = GetBarsRequest(**message_json)
                        logger.info(f"Getting bars for symbol: {req.symbol}")

                        # Parse ticker format: "EXCHANGE:SYMBOL"
                        symbol = req.symbol
                        source_name = None
                        symbol_without_exchange = symbol

                        # Check if ticker has exchange prefix
                        if ":" in symbol:
                            exchange_prefix, symbol_without_exchange = symbol.split(":", 1)
                            source_name = exchange_prefix.lower()
                            logger.info(f"Parsed ticker for bars: exchange={source_name}, symbol={symbol_without_exchange}")

                        # If we identified a source, use it directly
                        if source_name:
                            try:
                                source = datasource_registry.get(source_name)
                                if source:
                                    logger.info(f"Getting bars for '{symbol_without_exchange}' from source '{source_name}'")
                                    history = await asyncio.wait_for(
                                        source.get_bars(symbol=symbol_without_exchange, resolution=req.resolution, from_time=req.from_time, to_time=req.to_time, countback=req.countback),
                                        timeout=10.0
                                    )
                                    logger.info(f"Successfully got {len(history.bars)} bars for '{symbol_without_exchange}' from source '{source_name}'")
                                    response = GetBarsResponse(request_id=req.request_id, history=history.model_dump(mode="json"))
                                    await send_response(response)
                                else:
                                    error_response = ErrorResponse(request_id=req.request_id, error_code="SOURCE_NOT_FOUND", error_message=f"Data source '{source_name}' not found")
                                    await send_response(error_response)
                            except asyncio.TimeoutError:
                                logger.warning(f"Timeout getting bars for '{symbol_without_exchange}' from source '{source_name}'")
                                error_response = ErrorResponse(request_id=req.request_id, error_code="TIMEOUT", error_message="Timeout fetching bars")
                                await send_response(error_response)
                            except Exception as e:
                                logger.warning(f"Error getting bars for '{symbol_without_exchange}' from source '{source_name}': {e}")
                                error_response = ErrorResponse(request_id=req.request_id, error_code="ERROR", error_message=str(e))
                                await send_response(error_response)
                        else:
                            # No exchange prefix - this shouldn't happen with proper tickers
                            logger.warning(f"Ticker '{symbol}' has no exchange prefix")
                            error_response = ErrorResponse(request_id=req.request_id, error_code="INVALID_TICKER", error_message="Ticker must include exchange prefix (e.g., BINANCE:BTC/USDT)")
                            await send_response(error_response)

                    elif msg_type == "subscribe_bars":
                        req = SubscribeBarsRequest(**message_json)
                        # TODO: Implement subscription management
                        response = SubscribeBarsResponse(request_id=req.request_id, subscription_id=req.subscription_id, success=True)
                        await send_response(response)

                    elif msg_type == "unsubscribe_bars":
                        req = UnsubscribeBarsRequest(**message_json)
                        response = UnsubscribeBarsResponse(request_id=req.request_id, subscription_id=req.subscription_id, success=True)
                        await send_response(response)

                except Exception as e:
                    logger.error(f"Error handling {msg_type}: {e}", exc_info=True)
                    error_response = ErrorResponse(request_id=request_id, error_code="INTERNAL_ERROR", error_message=str(e))
                    await send_response(error_response)

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected - channel_id: {channel_id}")
        registry.websocket = None
        gateway.unregister_channel(channel_id)
    except Exception as e:
        logger.error(f"WebSocket error: {e}", exc_info=True)
        registry.websocket = None
        gateway.unregister_channel(channel_id)
|
|
@app.post("/api/upload")
async def upload_file(file: UploadFile = File(...)):
    """Upload a file and return its URL."""
    try:
        # Store under a random UUID name (original extension preserved) so
        # client-supplied filenames can neither collide nor escape UPLOAD_DIR.
        suffix = Path(file.filename).suffix if file.filename else ""
        stored_name = f"{uuid.uuid4()}{suffix}"
        destination = UPLOAD_DIR / stored_name

        # Stream the upload straight to disk.
        with destination.open("wb") as out:
            shutil.copyfileobj(file.file, out)

        # URL is relative to the backend's /uploads static mount.
        file_url = f"/uploads/{stored_name}"
        logger.info(f"File uploaded: {file.filename} -> {file_url}")

        return {
            "url": file_url,
            "filename": file.filename,
            "size": destination.stat().st_size,
        }
    except Exception as e:
        logger.error(f"File upload error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
@app.get("/healthz")
async def health():
    """Liveness probe: always reports the service as up."""
    return dict(status="ok")
|
|
# Background task to simulate backend updates (optional, for demo)
async def simulate_backend_updates():
    """Every 5 seconds, push the full state of all stores to the
    currently connected sync client, if there is one."""
    while True:
        await asyncio.sleep(5)
        if not registry.websocket:
            continue
        # Example: could add/modify orders here
        await registry.push_all()
|
|
|
if __name__ == "__main__":
    import uvicorn
    # Listen on all interfaces; the port comes from config.yaml ("www_port").
    uvicorn.run(app, host="0.0.0.0", port=config["www_port"])