"""FastAPI backend entry point.

Wires up the sync registry, agent gateway, data sources, indicators, trigger
system and encrypted secrets store, and exposes:

* ``/ws``         -- WebSocket that must authenticate first, then carries store
                     sync (``hello``/``patch``), agent chat and datafeed requests
* ``/api/upload`` -- file upload; files are served back under ``/uploads``
* ``/healthz``    -- health check
"""
import asyncio
import json
import logging
import os
import shutil
import uuid
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from pathlib import Path

import yaml
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, ValidationError

from sync.protocol import HelloMessage, PatchMessage, AuthMessage, AuthResponseMessage
from sync.registry import SyncRegistry
from gateway.hub import Gateway
from gateway.channels.websocket import WebSocketChannel
from gateway.protocol import UserMessage, WebSocketAgentUserMessage
from agent.core import create_agent
from agent.tools import set_registry, set_datasource_registry, set_indicator_registry
from agent.tools import set_trigger_queue, set_trigger_scheduler, set_coordinator
from schema.order_spec import SwapOrder
from schema.chart_state import ChartState
from schema.shape import ShapeCollection
from schema.indicator import IndicatorCollection
from datasource.registry import DataSourceRegistry
from datasource.subscription_manager import SubscriptionManager
from datasource.websocket_handler import DatafeedWebSocketHandler
from secrets_manager import SecretsStore, InvalidMasterPassword
from indicator import IndicatorRegistry, register_all_talib_indicators, register_custom_indicators
from trigger import CommitCoordinator, TriggerQueue
from trigger.scheduler import TriggerScheduler

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Load environment variables from .env file (if present)
env_path = Path(__file__).parent.parent / ".env"
if env_path.exists():
    load_dotenv(env_path)

# Load configuration
config_path = Path(__file__).parent.parent / "config.yaml"
with open(config_path, encoding="utf-8") as f:
    config = yaml.safe_load(f)

registry = SyncRegistry()
gateway = Gateway()
agent_executor = None  # created in lifespan() once an API key is available

# DataSource infrastructure
datasource_registry = DataSourceRegistry()
subscription_manager = SubscriptionManager()

# Indicator infrastructure
indicator_registry = IndicatorRegistry()

# Trigger system infrastructure (created and started in lifespan())
trigger_coordinator = None
trigger_queue = None
trigger_scheduler = None

# Global secrets store (unlocked by the /ws authentication handshake)
secrets_store = SecretsStore()

# Message types answered by _handle_datafeed_message().
DATAFEED_MESSAGE_TYPES = frozenset({
    "get_config",
    "search_symbols",
    "resolve_symbol",
    "get_bars",
    "subscribe_bars",
    "unsubscribe_bars",
})


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start all subsystems before the app serves requests and stop them afterwards.

    Startup:
      * registers the Binance CCXT data source if ``ccxt`` is importable;
      * registers TA-Lib and custom indicators (failures are logged, not fatal);
      * creates and starts the trigger queue/scheduler and hands them to the agent tools;
      * creates the agent only if an Anthropic API key is found -- first in the
        secrets store (if it is already unlocked), then in ``ANTHROPIC_API_KEY``.

    Shutdown (after ``yield``) stops the trigger system and closes the agent's
    memory manager. It runs in ``finally`` so it also happens when the lifespan
    context is exited with an exception.
    """
    global agent_executor, trigger_coordinator, trigger_queue, trigger_scheduler

    # Initialize CCXT data sources (optional dependency)
    try:
        from datasource.adapters.ccxt_adapter import CCXTDataSource

        # Binance
        try:
            binance_source = CCXTDataSource(exchange_id="binance", poll_interval=60)
            datasource_registry.register("binance", binance_source)
            subscription_manager.register_source("binance", binance_source)
            logger.info("DataSource: Registered Binance source")
        except Exception as e:
            logger.warning(f"Failed to initialize Binance source: {e}")

        logger.info(f"DataSource infrastructure initialized with sources: {datasource_registry.list_sources()}")
    except ImportError as e:
        logger.warning(f"CCXT not available: {e}. Only demo source will be available.")
        logger.info("To use real exchange data, install ccxt: pip install ccxt>=4.0.0")

    # Initialize indicator registry with all TA-Lib indicators
    try:
        indicator_count = register_all_talib_indicators(indicator_registry)
        logger.info(f"Indicator registry initialized with {indicator_count} TA-Lib indicators")
    except Exception as e:
        logger.warning(f"Failed to register TA-Lib indicators: {e}")
        logger.info("TA-Lib indicators will not be available. Install TA-Lib C library and Python wrapper to enable.")

    # Register custom indicators (TradingView indicators not in TA-Lib)
    try:
        custom_count = register_custom_indicators(indicator_registry)
        logger.info(f"Registered {custom_count} custom indicators")
    except Exception as e:
        logger.warning(f"Failed to register custom indicators: {e}")

    # Get API keys from secrets store if unlocked, otherwise fall back to environment
    anthropic_api_key = None
    if secrets_store.is_unlocked:
        anthropic_api_key = secrets_store.get("ANTHROPIC_API_KEY")
        if anthropic_api_key:
            logger.info("Loaded API key from encrypted secrets store")

    # Fall back to environment variable
    if not anthropic_api_key:
        anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
        if anthropic_api_key:
            logger.info("Loaded API key from environment")

    # Initialize trigger system
    logger.info("Initializing trigger system...")
    trigger_coordinator = CommitCoordinator()
    trigger_queue = TriggerQueue(trigger_coordinator)
    trigger_scheduler = TriggerScheduler(trigger_queue)

    # Start trigger queue and scheduler
    await trigger_queue.start()
    trigger_scheduler.start()
    logger.info("Trigger system initialized and started")

    # Set trigger system for agent tools
    set_coordinator(trigger_coordinator)
    set_trigger_queue(trigger_queue)
    set_trigger_scheduler(trigger_scheduler)

    if not anthropic_api_key:
        # Both the secrets store and the environment were checked above.
        logger.error("ANTHROPIC_API_KEY not found in secrets store or environment!")
        logger.info("Agent system will not be available")
    else:
        # Set the registries for agent tools
        set_registry(registry)
        set_datasource_registry(datasource_registry)
        set_indicator_registry(indicator_registry)

        # Create and initialize agent
        agent_executor = create_agent(
            model_name=config["agent"]["model"],
            temperature=config["agent"]["temperature"],
            api_key=anthropic_api_key,
            checkpoint_db_path=config["memory"]["checkpoint_db"],
            chroma_db_path=config["memory"]["chroma_db"],
            embedding_model=config["memory"]["embedding_model"],
            context_docs_dir=config["agent"]["context_docs_dir"],
            base_dir="."  # backend/src is the working directory, so . goes to backend, where memory/ and soul/ live
        )

        await agent_executor.initialize()

        # Set agent executor in gateway
        gateway.set_agent_executor(agent_executor.execute)

        logger.info("Agent system initialized")

    try:
        yield
    finally:
        # Cleanup
        logger.info("Shutting down systems...")

        # Shutdown trigger system
        if trigger_scheduler:
            trigger_scheduler.shutdown(wait=True)
            logger.info("Trigger scheduler shut down")

        if trigger_queue:
            await trigger_queue.stop()
            logger.info("Trigger queue stopped")

        # Shutdown agent system
        if agent_executor and agent_executor.memory_manager:
            await agent_executor.memory_manager.close()

        logger.info("All systems shut down")


app = FastAPI(lifespan=lifespan)

# Create uploads directory
UPLOAD_DIR = Path(__file__).parent.parent / "data" / "uploads"
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

# Mount static files for serving uploads
app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")


# OrderStore model for synchronization
class OrderStore(BaseModel):
    """Synchronized store holding the list of swap orders."""

    orders: list[SwapOrder] = []


# ChartStore model for synchronization
class ChartStore(BaseModel):
    """Synchronized store holding the chart state."""

    chart_state: ChartState = ChartState()


# ShapeStore model for synchronization
class ShapeStore(BaseModel):
    """Synchronized store holding chart shapes."""

    shapes: dict[str, dict] = {}  # Dictionary of shapes keyed by ID


# IndicatorStore model for synchronization
class IndicatorStore(BaseModel):
    """Synchronized store holding chart indicators."""

    indicators: dict[str, dict] = {}  # Dictionary of indicators keyed by ID


# Initialize stores
order_store = OrderStore()
chart_store = ChartStore()
shape_store = ShapeStore()
indicator_store = IndicatorStore()

# Register with SyncRegistry
registry.register(order_store, store_name="OrderStore")
registry.register(chart_store, store_name="ChartStore")
registry.register(shape_store, store_name="ShapeStore")
registry.register(indicator_store, store_name="IndicatorStore")


def _split_ticker(ticker: str) -> "tuple[str | None, str]":
    """Split an ``EXCHANGE:SYMBOL`` ticker into ``(source_name, symbol)``.

    The exchange prefix is lower-cased to match data-source registry names,
    e.g. ``"BINANCE:BTC/USDT"`` -> ``("binance", "BTC/USDT")``. A ticker without
    a ``:`` is returned as ``(None, ticker)``. Only the first ``:`` splits.
    """
    if ":" not in ticker:
        return None, ticker
    exchange_prefix, symbol = ticker.split(":", 1)
    return exchange_prefix.lower(), symbol


async def _initialize_secrets_store(auth_msg, websocket: WebSocket, send_response) -> bool:
    """First-time setup: create the secrets store with the client's password.

    Requires ``confirm_password`` to be present and equal to ``password``. On
    success, ``ANTHROPIC_API_KEY`` is copied from the environment into the store
    (if set). Returns True on success; otherwise an error is sent, the socket is
    closed, and False is returned.
    """
    logger.info("Secrets store not initialized, performing first-time setup")

    # Require password confirmation for initialization
    if not auth_msg.confirm_password:
        await send_response(AuthResponseMessage(
            success=False,
            needs_confirmation=True,
            message="First-time setup: password confirmation required"
        ))
        await websocket.close(code=1008, reason="Password confirmation required")
        return False

    if auth_msg.password != auth_msg.confirm_password:
        await send_response(AuthResponseMessage(
            success=False,
            needs_confirmation=True,
            message="Passwords do not match"
        ))
        await websocket.close(code=1008, reason="Password confirmation failed")
        return False

    try:
        secrets_store.initialize(auth_msg.password)

        # Migrate ANTHROPIC_API_KEY from environment if present
        env_key = os.environ.get("ANTHROPIC_API_KEY")
        if env_key:
            secrets_store.set("ANTHROPIC_API_KEY", env_key)
            logger.info("Migrated ANTHROPIC_API_KEY from environment to secrets store")
    except Exception as e:
        logger.error(f"Failed to initialize secrets store: {e}")
        await send_response(AuthResponseMessage(
            success=False,
            message=f"Initialization failed: {str(e)}"
        ))
        await websocket.close(code=1011, reason="Initialization failed")
        return False

    await send_response(AuthResponseMessage(
        success=True,
        message="Secrets store initialized successfully"
    ))
    logger.info("Secrets store initialized and authenticated")
    return True


async def _unlock_secrets_store(auth_msg, websocket: WebSocket, send_response) -> bool:
    """Check the password against the existing store and optionally change it.

    If the store is already unlocked (by an earlier connection), the password is
    verified by unlocking a separate ``SecretsStore`` over the same ``data_dir``
    so the shared store is left untouched. A password change is performed when
    ``change_to_password`` is set and matches ``confirm_new_password``.

    Returns True on success; otherwise an error is sent, the socket is closed,
    and False is returned.
    """
    try:
        if secrets_store.is_unlocked:
            temp_store = SecretsStore(data_dir=secrets_store.data_dir)
            temp_store.unlock(auth_msg.password)  # This will throw if wrong password
            logger.info("Password verified (store already unlocked)")
        else:
            secrets_store.unlock(auth_msg.password)
            logger.info("Secrets store unlocked successfully")

        # Check if user wants to change password
        password_changed = False
        if auth_msg.change_to_password:
            # Validate password change request
            if not auth_msg.confirm_new_password:
                await send_response(AuthResponseMessage(
                    success=False,
                    message="New password confirmation required"
                ))
                await websocket.close(code=1008, reason="Password confirmation required")
                return False

            if auth_msg.change_to_password != auth_msg.confirm_new_password:
                await send_response(AuthResponseMessage(
                    success=False,
                    message="New passwords do not match"
                ))
                await websocket.close(code=1008, reason="Password confirmation mismatch")
                return False

            # Change the password
            try:
                secrets_store.change_master_password(auth_msg.password, auth_msg.change_to_password)
                password_changed = True
                logger.info("Master password changed successfully")
            except Exception as e:
                logger.error(f"Failed to change password: {e}")
                await send_response(AuthResponseMessage(
                    success=False,
                    message=f"Failed to change password: {str(e)}"
                ))
                await websocket.close(code=1011, reason="Password change failed")
                return False

        response_message = "Password changed successfully" if password_changed else "Authentication successful"
        await send_response(AuthResponseMessage(
            success=True,
            password_changed=password_changed,
            message=response_message
        ))
        return True

    except InvalidMasterPassword:
        logger.warning("Invalid password attempt")
        await send_response(AuthResponseMessage(
            success=False,
            message="Invalid password"
        ))
        await websocket.close(code=1008, reason="Invalid password")
        return False
    except Exception as e:
        logger.error(f"Authentication error: {e}")
        await send_response(AuthResponseMessage(
            success=False,
            message="Authentication failed"
        ))
        await websocket.close(code=1011, reason="Authentication error")
        return False


async def _authenticate_websocket(websocket: WebSocket, send_response) -> bool:
    """Run the handshake that must open every ``/ws`` connection.

    The first message must arrive within 30 seconds and have ``type == "auth"``.
    It either initializes the secrets store (first run) or unlocks/verifies it.

    Returns:
        True if the client is authenticated. False otherwise; in that case the
        socket has already been closed (or the client went away) and the caller
        must stop handling the connection.
    """
    try:
        auth_timeout = 30  # 30 seconds to authenticate
        auth_data = await asyncio.wait_for(websocket.receive_text(), timeout=auth_timeout)
        auth_message_json = json.loads(auth_data)

        if auth_message_json.get("type") != "auth":
            logger.warning("First message was not auth message")
            await send_response(AuthResponseMessage(
                success=False,
                message="First message must be authentication"
            ))
            await websocket.close(code=1008, reason="Authentication required")
            return False

        auth_msg = AuthMessage(**auth_message_json)
        logger.info("Received authentication message")

        if not secrets_store.is_initialized:
            return await _initialize_secrets_store(auth_msg, websocket, send_response)
        return await _unlock_secrets_store(auth_msg, websocket, send_response)

    except asyncio.TimeoutError:
        logger.warning("Authentication timeout")
        await websocket.close(code=1008, reason="Authentication timeout")
        return False
    except WebSocketDisconnect:
        logger.info("Client disconnected during authentication")
        return False
    except Exception as e:
        logger.error(f"Error during authentication: {e}")
        await websocket.close(code=1011, reason="Authentication error")
        return False


async def _route_agent_user_message(message_json: dict, channel_id: str) -> None:
    """Validate an ``agent_user_message`` and hand it to the gateway for the agent."""
    # Message content may be sensitive, so it is only logged at DEBUG level.
    logger.debug(f"Raw message_json: {message_json}")
    msg = WebSocketAgentUserMessage(**message_json)
    logger.debug(f"Agent user message received - session: {msg.session_id}, content: '{msg.content}' (len={len(msg.content)})")
    logger.info(f"Agent user message received - session: {msg.session_id} (len={len(msg.content)})")

    user_msg = UserMessage(
        session_id=msg.session_id,
        channel_id=channel_id,
        content=msg.content,
        attachments=msg.attachments,
        timestamp=datetime.now(timezone.utc)
    )

    logger.info(f"Routing user message to gateway - channel: {channel_id}, session: {msg.session_id}")
    await gateway.route_user_message(user_msg)
    logger.info("Message routing completed")


async def _handle_datafeed_message(msg_type: str, message_json: dict, send_response) -> None:
    """Answer one datafeed request (one of ``DATAFEED_MESSAGE_TYPES``).

    Replies through ``send_response`` with the matching ``*Response`` model, or
    with an ``ErrorResponse`` carrying the request's ``request_id``. Unexpected
    exceptions are reported as ``INTERNAL_ERROR`` rather than propagated, so a
    bad datafeed request never tears down the WebSocket connection.
    """
    from datasource.websocket_protocol import (
        GetConfigRequest, GetConfigResponse,
        SearchSymbolsRequest, SearchSymbolsResponse,
        ResolveSymbolRequest, ResolveSymbolResponse,
        GetBarsRequest, GetBarsResponse,
        SubscribeBarsRequest, SubscribeBarsResponse,
        UnsubscribeBarsRequest, UnsubscribeBarsResponse,
        ErrorResponse
    )

    request_id = message_json.get("request_id", "unknown")

    try:
        if msg_type == "get_config":
            req = GetConfigRequest(**message_json)
            logger.info(f"Getting config, request_id={req.request_id}")
            sources = datasource_registry.list_sources()
            logger.info(f"Available sources: {sources}")

            if not sources:
                await send_response(ErrorResponse(request_id=req.request_id, error_code="NO_SOURCES", error_message="No data sources available"))
                return

            # Get config from first source (we can enhance this later to aggregate)
            source = datasource_registry.get(sources[0])
            if not source:
                await send_response(ErrorResponse(request_id=req.request_id, error_code="SOURCE_NOT_FOUND", error_message="Data sources not available"))
                return

            try:
                # Bounded like every other source call: this handler runs inline in
                # the receive loop, so a hung source would stall the whole socket.
                source_config = await asyncio.wait_for(source.get_config(), timeout=5.0)
                logger.info(f"Got config from {sources[0]}")

                # Enhance with all available exchanges (first source already fetched)
                all_exchanges = set(source_config.exchanges or [])
                for source_name in sources[1:]:
                    s = datasource_registry.get(source_name)
                    if not s:
                        continue
                    try:
                        cfg = await asyncio.wait_for(s.get_config(), timeout=5.0)
                        all_exchanges.update(cfg.exchanges)
                        logger.info(f"Added exchanges from {source_name}: {cfg.exchanges}")
                    except asyncio.TimeoutError:
                        logger.warning(f"Timeout getting config from {source_name}")
                    except Exception as e:
                        logger.warning(f"Error getting config from {source_name}: {e}")

                config_dict = source_config.model_dump(mode="json")
                config_dict["exchanges"] = list(all_exchanges)
                logger.info(f"Sending config with exchanges: {list(all_exchanges)}")
                await send_response(GetConfigResponse(request_id=req.request_id, config=config_dict))
            except asyncio.TimeoutError:
                logger.warning(f"Timeout getting config from {sources[0]}")
                await send_response(ErrorResponse(request_id=req.request_id, error_code="TIMEOUT", error_message="Timeout fetching config"))
            except Exception as e:
                logger.error(f"Error getting config: {e}", exc_info=True)
                await send_response(ErrorResponse(request_id=req.request_id, error_code="ERROR", error_message=str(e)))

        elif msg_type == "search_symbols":
            req = SearchSymbolsRequest(**message_json)
            logger.info(f"Searching symbols: query='{req.query}', request_id={req.request_id}")

            # Search all data sources
            all_results = []
            sources = datasource_registry.list_sources()
            logger.info(f"Available data sources: {sources}")
            for source_name in sources:
                source = datasource_registry.get(source_name)
                if not source:
                    continue
                try:
                    results = await asyncio.wait_for(
                        source.search_symbols(query=req.query, type=req.symbol_type, exchange=req.exchange, limit=req.limit),
                        timeout=5.0
                    )
                    all_results.extend([r.model_dump(mode="json") for r in results])
                    logger.info(f"Source '{source_name}' returned {len(results)} results")
                except asyncio.TimeoutError:
                    logger.warning(f"Timeout searching source '{source_name}'")
                except Exception as e:
                    logger.warning(f"Error searching source '{source_name}': {e}")

            logger.info(f"Total search results: {len(all_results)}")
            await send_response(SearchSymbolsResponse(request_id=req.request_id, results=all_results[:req.limit]))

        elif msg_type == "resolve_symbol":
            req = ResolveSymbolRequest(**message_json)
            logger.info(f"Resolving symbol: {req.symbol}")

            # Parse ticker format: "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD")
            symbol = req.symbol
            source_name, symbol_without_exchange = _split_ticker(symbol)
            if source_name is not None:
                logger.info(f"Parsed ticker: exchange={source_name}, symbol={symbol_without_exchange}")

            if source_name:
                # Exchange prefix given: ask that source only
                try:
                    source = datasource_registry.get(source_name)
                    if source:
                        logger.info(f"Trying to resolve '{symbol_without_exchange}' in source '{source_name}'")
                        symbol_info = await asyncio.wait_for(
                            source.resolve_symbol(symbol_without_exchange),
                            timeout=5.0
                        )
                        logger.info(f"Successfully resolved '{symbol_without_exchange}' in source '{source_name}'")
                        await send_response(ResolveSymbolResponse(request_id=req.request_id, symbol_info=symbol_info.model_dump(mode="json")))
                    else:
                        await send_response(ErrorResponse(request_id=req.request_id, error_code="SOURCE_NOT_FOUND", error_message=f"Data source '{source_name}' not found"))
                except asyncio.TimeoutError:
                    logger.warning(f"Timeout resolving '{symbol_without_exchange}' in source '{source_name}'")
                    await send_response(ErrorResponse(request_id=req.request_id, error_code="TIMEOUT", error_message="Timeout resolving symbol"))
                except Exception as e:
                    logger.warning(f"Error resolving '{symbol_without_exchange}' in source '{source_name}': {e}")
                    await send_response(ErrorResponse(request_id=req.request_id, error_code="SYMBOL_NOT_FOUND", error_message=str(e)))
            else:
                # No exchange prefix, try all sources; first successful resolution wins
                found = False
                for src in datasource_registry.list_sources():
                    try:
                        s = datasource_registry.get(src)
                        if s:
                            logger.info(f"Trying to resolve '{symbol}' in source '{src}'")
                            symbol_info = await asyncio.wait_for(s.resolve_symbol(symbol), timeout=5.0)
                            if symbol_info:
                                logger.info(f"Successfully resolved '{symbol}' in source '{src}'")
                                await send_response(ResolveSymbolResponse(request_id=req.request_id, symbol_info=symbol_info.model_dump(mode="json")))
                                found = True
                                break
                    except asyncio.TimeoutError:
                        logger.warning(f"Timeout resolving '{symbol}' in source '{src}'")
                    except Exception as e:
                        logger.info(f"Symbol '{symbol}' not found in source '{src}': {e}")

                if not found:
                    # Symbol not found in any source
                    logger.warning(f"Symbol '{symbol}' not found in any data source")
                    await send_response(ErrorResponse(request_id=req.request_id, error_code="SYMBOL_NOT_FOUND", error_message=f"Symbol '{symbol}' not found in any data source"))

        elif msg_type == "get_bars":
            req = GetBarsRequest(**message_json)
            logger.info(f"Getting bars for symbol: {req.symbol}")

            # Parse ticker format: "EXCHANGE:SYMBOL"
            symbol = req.symbol
            source_name, symbol_without_exchange = _split_ticker(symbol)
            if source_name is not None:
                logger.info(f"Parsed ticker for bars: exchange={source_name}, symbol={symbol_without_exchange}")

            if not source_name:
                # No exchange prefix - this shouldn't happen with proper tickers
                logger.warning(f"Ticker '{symbol}' has no exchange prefix")
                await send_response(ErrorResponse(request_id=req.request_id, error_code="INVALID_TICKER", error_message="Ticker must include exchange prefix (e.g., BINANCE:BTC/USDT)"))
                return

            try:
                source = datasource_registry.get(source_name)
                if source:
                    logger.info(f"Getting bars for '{symbol_without_exchange}' from source '{source_name}'")
                    history = await asyncio.wait_for(
                        source.get_bars(symbol=symbol_without_exchange, resolution=req.resolution, from_time=req.from_time, to_time=req.to_time, countback=req.countback),
                        timeout=10.0
                    )
                    logger.info(f"Successfully got {len(history.bars)} bars for '{symbol_without_exchange}' from source '{source_name}'")
                    await send_response(GetBarsResponse(request_id=req.request_id, history=history.model_dump(mode="json")))
                else:
                    await send_response(ErrorResponse(request_id=req.request_id, error_code="SOURCE_NOT_FOUND", error_message=f"Data source '{source_name}' not found"))
            except asyncio.TimeoutError:
                logger.warning(f"Timeout getting bars for '{symbol_without_exchange}' from source '{source_name}'")
                await send_response(ErrorResponse(request_id=req.request_id, error_code="TIMEOUT", error_message="Timeout fetching bars"))
            except Exception as e:
                logger.warning(f"Error getting bars for '{symbol_without_exchange}' from source '{source_name}': {e}")
                await send_response(ErrorResponse(request_id=req.request_id, error_code="ERROR", error_message=str(e)))

        elif msg_type == "subscribe_bars":
            req = SubscribeBarsRequest(**message_json)
            # TODO: Implement subscription management
            await send_response(SubscribeBarsResponse(request_id=req.request_id, subscription_id=req.subscription_id, success=True))

        elif msg_type == "unsubscribe_bars":
            req = UnsubscribeBarsRequest(**message_json)
            await send_response(UnsubscribeBarsResponse(request_id=req.request_id, subscription_id=req.subscription_id, success=True))

    except Exception as e:
        logger.error(f"Error handling {msg_type}: {e}", exc_info=True)
        await send_response(ErrorResponse(request_id=request_id, error_code="INTERNAL_ERROR", error_message=str(e)))


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one client connection: authenticate, then dispatch messages until disconnect.

    After the auth handshake, each JSON message is dispatched on its ``type``:

    * ``hello`` / ``patch``      -- store synchronization through ``registry``
    * ``agent_user_message``     -- routed to the agent through ``gateway``
    * ``DATAFEED_MESSAGE_TYPES`` -- answered from ``datasource_registry``

    Malformed or invalid individual messages are logged and skipped instead of
    closing the connection. On exit (disconnect, error or cancellation) the
    gateway channel is unregistered.
    """
    await websocket.accept()

    # Helper function to send responses; send failures are logged, not raised.
    async def send_response(response):
        try:
            await websocket.send_json(response.model_dump(mode="json"))
        except Exception as e:
            logger.error(f"Error sending response: {e}")

    if not await _authenticate_websocket(websocket, send_response):
        return

    # Now authenticated - proceed with normal WebSocket handling.
    # NOTE: registry holds a single socket, so the most recent client wins.
    registry.websocket = websocket

    # Create WebSocket channel for agent communication
    channel_id = f"ws_{id(websocket)}"
    client_id = f"client_{id(websocket)}"
    logger.info(f"WebSocket authenticated - channel_id: {channel_id}, client_id: {client_id}")

    ws_channel = WebSocketChannel(channel_id, websocket, session_id="default")
    gateway.register_channel(ws_channel)

    try:
        while True:
            data = await websocket.receive_text()
            logger.debug(f"Received WebSocket message: {data[:200]}...")  # Log first 200 chars

            try:
                message_json = json.loads(data)
            except json.JSONDecodeError as e:
                logger.warning(f"Ignoring malformed JSON message: {e}")
                continue

            # Non-object payloads (lists, strings, numbers) carry no 'type' field.
            if not isinstance(message_json, dict) or "type" not in message_json:
                logger.warning(f"Message missing 'type' field: {message_json}")
                continue

            msg_type = message_json["type"]
            logger.info(f"Processing message type: {msg_type}")

            try:
                # Handle sync protocol messages
                if msg_type == "hello":
                    hello_msg = HelloMessage(**message_json)
                    logger.info(f"Hello message received with seqs: {hello_msg.seqs}")
                    await registry.sync_client(hello_msg.seqs)

                elif msg_type == "patch":
                    patch_msg = PatchMessage(**message_json)
                    logger.info(f"Patch message received for store: {patch_msg.store}, seq: {patch_msg.seq}")
                    try:
                        await registry.apply_client_patch(
                            store_name=patch_msg.store,
                            client_base_seq=patch_msg.seq,
                            patch=patch_msg.patch
                        )
                    except Exception as e:
                        logger.error(f"Error applying client patch: {e}. Client will receive snapshot to resync.", exc_info=True)

                elif msg_type == "agent_user_message":
                    await _route_agent_user_message(message_json, channel_id)

                # Handle datafeed protocol messages
                elif msg_type in DATAFEED_MESSAGE_TYPES:
                    await _handle_datafeed_message(msg_type, message_json, send_response)

                else:
                    logger.warning(f"Unknown message type: {msg_type}")
            except ValidationError as e:
                # One invalid message should not end the whole session.
                logger.warning(f"Invalid '{msg_type}' message: {e}")

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected - channel_id: {channel_id}")
    except Exception as e:
        logger.error(f"WebSocket error: {e}", exc_info=True)
    finally:
        # Runs on cancellation too. Only clear the registry's socket if it still
        # points at this connection -- a newer client may have replaced it.
        if registry.websocket is websocket:
            registry.websocket = None
        gateway.unregister_channel(channel_id)


@app.post("/api/upload")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded file under a random name and return its URL.

    Only the extension of the client-supplied filename is kept; the stored name
    is a UUID, so the client cannot choose the path inside ``UPLOAD_DIR``.

    NOTE(review): this endpoint is not behind the ``/ws`` authentication, and the
    client-controlled extension is served back as-is from ``/uploads`` (e.g. an
    uploaded ``.html`` file is served same-origin). Consider auth and an
    extension allow-list.

    Raises:
        HTTPException: 500 if the file cannot be saved.
    """
    file_path = None
    try:
        # Generate unique filename
        file_extension = Path(file.filename).suffix if file.filename else ""
        unique_filename = f"{uuid.uuid4()}{file_extension}"
        file_path = UPLOAD_DIR / unique_filename

        def _write_upload():
            with open(file_path, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)

        # Blocking disk I/O runs in a worker thread so the event loop stays responsive.
        await asyncio.to_thread(_write_upload)

        # Return URL (relative to backend)
        file_url = f"/uploads/{unique_filename}"

        logger.info(f"File uploaded: {file.filename} -> {file_url}")

        return {
            "url": file_url,
            "filename": file.filename,
            "size": file_path.stat().st_size
        }
    except Exception as e:
        logger.error(f"File upload error: {e}", exc_info=True)
        # Don't leave a partially written file behind.
        if file_path is not None:
            file_path.unlink(missing_ok=True)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/healthz")
async def health():
    """Health check; always reports ok while the app is serving."""
    return {"status": "ok"}


# Background task to simulate backend updates (optional, for demo)
async def simulate_backend_updates():
    """Every 5 seconds, push all stores to the connected client (if any)."""
    while True:
        await asyncio.sleep(5)
        if registry.websocket:
            # Example: could add/modify orders here
            await registry.push_all()


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=config["www_port"])