initial commit with charts and assistant chat
This commit is contained in:
23
backend/src/datasource/__init__.py
Normal file
23
backend/src/datasource/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from .base import DataSource
|
||||
from .schema import (
|
||||
ColumnInfo,
|
||||
SymbolInfo,
|
||||
Bar,
|
||||
HistoryResult,
|
||||
DatafeedConfig,
|
||||
Resolution,
|
||||
SearchResult,
|
||||
)
|
||||
from .registry import DataSourceRegistry
|
||||
|
||||
# Public API of the datasource package: the abstract interface, the shared
# schema models, and the registry used to look adapters up by name.
__all__ = [
    "DataSource",
    "ColumnInfo",
    "SymbolInfo",
    "Bar",
    "HistoryResult",
    "DatafeedConfig",
    "Resolution",
    "SearchResult",
    "DataSourceRegistry",
]
|
||||
3
backend/src/datasource/adapters/__init__.py
Normal file
3
backend/src/datasource/adapters/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .ccxt_adapter import CCXTDataSource
|
||||
|
||||
__all__ = ["CCXTDataSource"]
|
||||
526
backend/src/datasource/adapters/ccxt_adapter.py
Normal file
526
backend/src/datasource/adapters/ccxt_adapter.py
Normal file
@@ -0,0 +1,526 @@
|
||||
"""
|
||||
CCXT DataSource adapter for accessing cryptocurrency exchange data.
|
||||
|
||||
This adapter provides access to hundreds of cryptocurrency exchanges through
|
||||
the free CCXT library (not ccxt.pro), supporting both historical data and
|
||||
polling-based subscriptions.
|
||||
|
||||
Numerical Precision:
|
||||
- Uses Decimal for all monetary values (prices, volumes) to avoid floating-point errors
|
||||
- CCXT returns numeric values as strings or floats depending on configuration
|
||||
- All financial values are converted to Decimal to maintain precision
|
||||
|
||||
Real-time Updates:
|
||||
- Uses polling instead of WebSocket (free CCXT doesn't have WebSocket support)
|
||||
- Default polling interval: 60 seconds (configurable)
|
||||
- Simulates real-time subscriptions by periodically fetching latest bars
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from typing import Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
import ccxt.async_support as ccxt
|
||||
|
||||
from ..base import DataSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from ..schema import (
|
||||
Bar,
|
||||
ColumnInfo,
|
||||
DatafeedConfig,
|
||||
HistoryResult,
|
||||
Resolution,
|
||||
SearchResult,
|
||||
SymbolInfo,
|
||||
)
|
||||
|
||||
|
||||
class CCXTDataSource(DataSource):
|
||||
"""
|
||||
DataSource adapter for CCXT cryptocurrency exchanges (free version).
|
||||
|
||||
Provides access to:
|
||||
- Multiple cryptocurrency exchanges (Binance, Coinbase, Kraken, etc.)
|
||||
- Historical OHLCV data via REST API
|
||||
- Polling-based real-time updates (configurable interval)
|
||||
- Symbol search and metadata
|
||||
|
||||
Args:
|
||||
exchange_id: CCXT exchange identifier (e.g., 'binance', 'coinbase', 'kraken')
|
||||
config: Optional exchange-specific configuration (API keys, options)
|
||||
sandbox: Whether to use sandbox/testnet mode (default: False)
|
||||
poll_interval: Interval in seconds for polling updates (default: 60)
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    exchange_id: str = "binance",
    config: Optional[Dict] = None,
    sandbox: bool = False,
    poll_interval: int = 60,
):
    """
    Create an adapter bound to a single CCXT exchange.

    Args:
        exchange_id: CCXT exchange identifier (e.g. 'binance', 'kraken').
        config: Exchange-specific CCXT options (API keys, rate limits).
        sandbox: Switch the exchange to sandbox/testnet mode if supported.
        poll_interval: Seconds between polls for subscription updates.
    """
    self.exchange_id = exchange_id
    self._config = config or {}
    self._sandbox = sandbox
    self._poll_interval = poll_interval

    # Instantiate the async client from the free CCXT build (not ccxt.pro)
    self.exchange = getattr(ccxt, exchange_id)(self._config)

    # Not every exchange implements sandbox mode, so guard the call
    if sandbox and hasattr(self.exchange, 'set_sandbox_mode'):
        self.exchange.set_sandbox_mode(True)

    # Market catalog, fetched lazily on first use
    self._markets: Optional[Dict] = None
    self._markets_loaded = False

    # Polling-based subscription state, keyed by subscription ID
    self._subscriptions: Dict[str, asyncio.Task] = {}
    self._subscription_callbacks: Dict[str, Callable] = {}
    self._last_bars: Dict[str, int] = {}  # last emitted bar timestamp per subscription
|
||||
|
||||
@staticmethod
def _to_decimal(value: Union[str, int, float, Decimal, None]) -> Optional[Decimal]:
    """
    Normalize a CCXT numeric value to Decimal (or None).

    CCXT returns prices/volumes as strings, ints, or floats depending on
    configuration. Floats are routed through str() so their repr, rather
    than their binary expansion, becomes the Decimal value. Unsupported
    types map to None.
    """
    if value is None or isinstance(value, Decimal):
        return value
    if isinstance(value, str):
        return Decimal(value)
    if isinstance(value, (int, float)):
        # str() first: Decimal(0.1) would capture the float's full
        # binary expansion instead of the intended "0.1"
        return Decimal(str(value))
    return None
|
||||
|
||||
async def _ensure_markets_loaded(self):
    """Lazily load and cache the exchange's market catalog on first use."""
    if self._markets_loaded:
        return
    self._markets = await self.exchange.load_markets()
    self._markets_loaded = True
|
||||
|
||||
async def get_config(self) -> DatafeedConfig:
    """
    Describe this datafeed's capabilities.

    Loads markets first so the description can report how many trading
    pairs the exchange exposes.
    """
    await self._ensure_markets_loaded()

    # Most CCXT sources represent exactly one exchange
    pair_count = len(self._markets) if self._markets else 'many'

    return DatafeedConfig(
        name=f"CCXT {self.exchange_id.title()}",
        description=f"Live and historical cryptocurrency data from {self.exchange_id} via CCXT library. "
        f"Supports OHLCV data for {pair_count} trading pairs.",
        supported_resolutions=[
            Resolution.M1,
            Resolution.M5,
            Resolution.M15,
            Resolution.M30,
            Resolution.H1,
            Resolution.H4,
            Resolution.D1,
        ],
        supports_search=True,
        supports_time=True,
        exchanges=[self.exchange_id.upper()],
        symbols_types=["crypto", "spot", "futures", "swap"],
    )
|
||||
|
||||
async def search_symbols(
    self,
    query: str,
    type: Optional[str] = None,
    exchange: Optional[str] = None,
    limit: int = 30,
) -> List[SearchResult]:
    """
    Search the exchange's markets by symbol or base/quote currency.

    Matching is an uppercase substring test against the market symbol and
    both currency legs. ``exchange`` is accepted for interface
    compatibility but unused here (one adapter == one exchange).
    Returns at most ``limit`` results.
    """
    await self._ensure_markets_loaded()

    needle = query.upper()
    matches: List[SearchResult] = []

    for symbol, market in self._markets.items():
        # Substring match on the full symbol or either currency leg
        if (needle not in symbol and
                needle not in market.get('base', '') and
                needle not in market.get('quote', '')):
            continue

        market_type = market.get('type', 'spot')
        if type and market_type != type:
            continue

        base = market.get('base', '')
        quote = market.get('quote', '')

        matches.append(
            SearchResult(
                symbol=f"{base}/{quote}",  # Clean user-facing format
                ticker=f"{self.exchange_id.upper()}:{symbol}",  # Ticker with exchange prefix for routing
                full_name=f"{base}/{quote} ({self.exchange_id.upper()})",
                description=f"{base}/{quote} {market_type} trading pair on {self.exchange_id}",
                exchange=self.exchange_id.upper(),
                type=market_type,
            )
        )

        if len(matches) >= limit:
            break

    return matches
|
||||
|
||||
async def resolve_symbol(self, symbol: str) -> SymbolInfo:
    """
    Get complete metadata for a symbol.

    Builds a SymbolInfo from the exchange's market record: currency legs,
    market type, column schema, and a pricescale derived from the
    exchange's reported price precision.

    Args:
        symbol: Exchange symbol as it appears in the loaded markets
            (e.g. "BTC/USDT").

    Returns:
        Complete SymbolInfo including Decimal-typed OHLCV column definitions.

    Raises:
        ValueError: If the symbol is not listed on this exchange.
    """
    await self._ensure_markets_loaded()

    if symbol not in self._markets:
        raise ValueError(f"Symbol '{symbol}' not found on {self.exchange_id}")

    market = self._markets[symbol]
    base = market.get('base', '')
    quote = market.get('quote', '')
    market_type = market.get('type', 'spot')

    # Determine price scale from market precision
    # CCXT precision can be in different modes:
    # - DECIMAL_PLACES (int): number of decimal places (e.g., 2 = 0.01)
    # - TICK_SIZE (float): actual tick size (e.g., 0.01, 0.00001)
    # We need to convert to pricescale (10^n where n is decimal places)
    price_precision = market.get('precision', {}).get('price', 2)

    if isinstance(price_precision, float):
        # TICK_SIZE mode: precision is the actual tick size (e.g., 0.01, 0.00001)
        # Convert tick size to decimal places
        # For 0.01 -> 2 decimal places, 0.00001 -> 5 decimal places
        # Decimal(str(...)) keeps the repr digits rather than the float's
        # binary expansion; trailing zeros are stripped before counting.
        tick_str = str(Decimal(str(price_precision)))
        if '.' in tick_str:
            decimal_places = len(tick_str.split('.')[1].rstrip('0'))
        else:
            decimal_places = 0
        pricescale = 10 ** decimal_places
        # NOTE(review): ticks that are not powers of ten (e.g. 0.025) yield
        # only the decimal-place count here; minmov stays 1 below — confirm
        # that is acceptable for such markets.
    else:
        # DECIMAL_PLACES or SIGNIFICANT_DIGITS mode: precision is an integer
        # Assume DECIMAL_PLACES mode (most common for price)
        pricescale = 10 ** int(price_precision)

    return SymbolInfo(
        symbol=f"{base}/{quote}",  # Clean user-facing format
        ticker=f"{self.exchange_id.upper()}:{symbol}",  # Ticker with exchange prefix for routing
        name=f"{base}/{quote}",
        description=f"{base}/{quote} {market_type} pair on {self.exchange_id}. "
        f"Minimum order: {market.get('limits', {}).get('amount', {}).get('min', 'N/A')} {base}",
        type=market_type,
        exchange=self.exchange_id.upper(),
        timezone="Etc/UTC",       # crypto markets trade around the clock in UTC
        session="24x7",
        supported_resolutions=[
            Resolution.M1,
            Resolution.M5,
            Resolution.M15,
            Resolution.M30,
            Resolution.H1,
            Resolution.H4,
            Resolution.D1,
        ],
        has_intraday=True,
        has_daily=True,
        has_weekly_and_monthly=False,
        # All monetary columns are typed "decimal" to match _to_decimal()
        columns=[
            ColumnInfo(
                name="open",
                type="decimal",
                description=f"Opening price in {quote}",
                unit=quote,
            ),
            ColumnInfo(
                name="high",
                type="decimal",
                description=f"Highest price in {quote}",
                unit=quote,
            ),
            ColumnInfo(
                name="low",
                type="decimal",
                description=f"Lowest price in {quote}",
                unit=quote,
            ),
            ColumnInfo(
                name="close",
                type="decimal",
                description=f"Closing price in {quote}",
                unit=quote,
            ),
            ColumnInfo(
                name="volume",
                type="decimal",
                description=f"Trading volume in {base}",
                unit=base,
            ),
        ],
        time_column="time",
        has_ohlcv=True,
        pricescale=pricescale,
        minmov=1,
        base_currency=base,
        quote_currency=quote,
    )
|
||||
|
||||
def _resolution_to_timeframe(self, resolution: str) -> str:
    """
    Translate a chart resolution (e.g. "60", "1D") into a CCXT
    timeframe string (e.g. "1h", "1d").

    Unrecognized resolutions fall back to "1m".
    """
    timeframes = {
        "1": "1m", "5": "5m", "15": "15m", "30": "30m",
        "60": "1h", "120": "2h", "240": "4h", "360": "6h", "720": "12h",
        "1D": "1d", "1W": "1w", "1M": "1M",
    }
    return timeframes.get(resolution, "1m")
|
||||
|
||||
def _timeframe_to_milliseconds(self, timeframe: str) -> int:
    """
    Length of one CCXT timeframe (e.g. "5m", "1h") in milliseconds.

    Months are approximated as 30 days; unknown unit suffixes fall back
    to one minute.
    """
    ms_per_unit = {
        's': 1000,
        'm': 60000,
        'h': 3600000,
        'd': 86400000,
        'w': 604800000,
        'M': 2592000000,  # 30 days, approximate
    }
    count = int(timeframe[:-1]) if len(timeframe) > 1 else 1
    return count * ms_per_unit.get(timeframe[-1], 60000)
|
||||
|
||||
async def get_bars(
    self,
    symbol: str,
    resolution: str,
    from_time: int,
    to_time: int,
    countback: Optional[int] = None,
) -> HistoryResult:
    """
    Get historical bars from the exchange.

    Fetches OHLCV candles in batches (exchanges cap each request, usually
    at 500-1000 candles), filters them to [from_time, to_time), and
    converts all monetary values to Decimal.

    Args:
        symbol: Exchange symbol (e.g. "BTC/USDT"); must exist in markets.
        resolution: Resolution string (e.g. "1", "60", "1D").
        from_time: Range start, Unix seconds (inclusive).
        to_time: Range end, Unix seconds (exclusive).
        countback: Optional cap on the number of bars returned.

    Returns:
        HistoryResult with bars, column metadata, and an optional
        nextTime pagination hint when the countback cap was hit.

    Raises:
        ValueError: If the symbol is unknown, or (with the underlying
            exception chained) if the exchange fetch fails.
    """
    logger.info(
        f"CCXTDataSource({self.exchange_id}).get_bars: symbol={symbol}, resolution={resolution}, "
        f"from_time={from_time}, to_time={to_time}, countback={countback}"
    )

    await self._ensure_markets_loaded()

    if symbol not in self._markets:
        raise ValueError(f"Symbol '{symbol}' not found on {self.exchange_id}")

    timeframe = self._resolution_to_timeframe(resolution)

    # CCXT uses milliseconds for timestamps
    since = from_time * 1000
    until = to_time * 1000

    # Per-request batch size (the exchange may clamp it further)
    limit = countback if countback else 1000

    try:
        # Fetch in batches until the range is covered or enough bars collected
        all_ohlcv = []
        current_since = since

        while current_since < until:
            ohlcv = await self.exchange.fetch_ohlcv(
                symbol,
                timeframe=timeframe,
                since=current_since,
                limit=limit,
            )

            if not ohlcv:
                break

            all_ohlcv.extend(ohlcv)

            # Advance past the newest received candle for the next batch
            last_timestamp = ohlcv[-1][0]
            if last_timestamp <= current_since:
                break  # No progress, avoid infinite loop
            current_since = last_timestamp + 1

            # Stop if we have enough bars
            if countback and len(all_ohlcv) >= countback:
                all_ohlcv = all_ohlcv[:countback]
                break

        # Convert to our Bar format with Decimal precision
        bars = []
        for candle in all_ohlcv:
            timestamp_ms, open_price, high, low, close, volume = candle
            timestamp = timestamp_ms // 1000  # Convert to seconds

            # Only include bars within requested range (half-open interval)
            if timestamp < from_time or timestamp >= to_time:
                continue

            bars.append(
                Bar(
                    time=timestamp,
                    data={
                        "open": self._to_decimal(open_price),
                        "high": self._to_decimal(high),
                        "low": self._to_decimal(low),
                        "close": self._to_decimal(close),
                        "volume": self._to_decimal(volume),
                    },
                )
            )

        # Column metadata comes from the resolved symbol info
        symbol_info = await self.resolve_symbol(symbol)

        logger.info(
            f"CCXTDataSource({self.exchange_id}).get_bars: Returning {len(bars)} bars. "
            f"First: {bars[0].time if bars else 'N/A'}, Last: {bars[-1].time if bars else 'N/A'}"
        )

        # If the countback cap was hit, hint at where the next page starts
        # (extrapolate one bar interval past the last bar)
        next_time = None
        if bars and countback and len(bars) >= countback:
            next_time = bars[-1].time + (bars[-1].time - bars[-2].time if len(bars) > 1 else 60)

        return HistoryResult(
            symbol=symbol,
            resolution=resolution,
            bars=bars,
            columns=symbol_info.columns,
            nextTime=next_time,
        )

    except ValueError:
        # Already a user-facing error (e.g. from resolve_symbol) — don't re-wrap
        raise
    except Exception as e:
        # Chain the cause so the underlying CCXT error isn't lost
        raise ValueError(f"Failed to fetch bars for {symbol}: {str(e)}") from e
|
||||
|
||||
async def subscribe_bars(
    self,
    symbol: str,
    resolution: str,
    on_tick: Callable[[dict], None],
) -> str:
    """
    Start a polling-based subscription for new bars.

    Free CCXT has no WebSocket support, so a background task polls the
    exchange at the configured interval (default: 60 seconds) and invokes
    ``on_tick`` when a new bar appears.

    Returns:
        An opaque subscription ID for unsubscribe_bars().

    Raises:
        ValueError: If the symbol is not listed on this exchange.
    """
    await self._ensure_markets_loaded()

    if symbol not in self._markets:
        raise ValueError(f"Symbol '{symbol}' not found on {self.exchange_id}")

    # Timestamp suffix keeps IDs unique across repeated subscriptions
    subscription_id = f"{symbol}:{resolution}:{time.time()}"
    self._subscription_callbacks[subscription_id] = on_tick

    poll_task = asyncio.create_task(
        self._poll_ohlcv(symbol, self._resolution_to_timeframe(resolution), subscription_id)
    )
    self._subscriptions[subscription_id] = poll_task

    return subscription_id
|
||||
|
||||
async def _poll_ohlcv(self, symbol: str, timeframe: str, subscription_id: str):
    """
    Background polling loop backing one subscription.

    Periodically fetches the latest bars and invokes the subscriber's
    callback only when a bar with a newer timestamp appears, suppressing
    duplicate updates. Exits when the subscription's callback is removed
    or the task is cancelled. Transient fetch errors are logged and the
    loop keeps polling.
    """
    try:
        while subscription_id in self._subscription_callbacks:
            try:
                # Fetch the last 2 bars so a just-closed bar is detected
                ohlcv = await self.exchange.fetch_ohlcv(
                    symbol,
                    timeframe=timeframe,
                    limit=2,
                )

                if ohlcv and len(ohlcv) > 0:
                    # Get the latest candle
                    latest = ohlcv[-1]
                    timestamp_ms, open_price, high, low, close, volume = latest
                    timestamp = timestamp_ms // 1000

                    # Only send update if this is a new bar
                    last_timestamp = self._last_bars.get(subscription_id, 0)
                    if timestamp > last_timestamp:
                        self._last_bars[subscription_id] = timestamp

                        # Convert to our format with Decimal precision
                        tick_data = {
                            "time": timestamp,
                            "open": self._to_decimal(open_price),
                            "high": self._to_decimal(high),
                            "low": self._to_decimal(low),
                            "close": self._to_decimal(close),
                            "volume": self._to_decimal(volume),
                        }

                        # Re-check: the subscription may have been removed mid-poll
                        callback = self._subscription_callbacks.get(subscription_id)
                        if callback:
                            callback(tick_data)

            except asyncio.CancelledError:
                raise  # Propagate cancellation to the outer handler
            except Exception:
                # Was print(); use the module logger so errors reach the
                # application's logging config. Keep polling — transient
                # exchange errors shouldn't kill the subscription.
                logger.exception(f"Error polling OHLCV for {symbol}")

            # Wait for next poll interval
            await asyncio.sleep(self._poll_interval)

    except asyncio.CancelledError:
        pass  # Normal shutdown via unsubscribe_bars()
|
||||
|
||||
async def unsubscribe_bars(self, subscription_id: str) -> None:
    """
    Stop polling for the given subscription and drop its state.

    Unknown subscription IDs are silently ignored.
    """
    # Drop the callback and last-bar marker first so the poll loop
    # condition fails even before the task is cancelled
    self._subscription_callbacks.pop(subscription_id, None)
    self._last_bars.pop(subscription_id, None)

    poll_task = self._subscriptions.pop(subscription_id, None)
    if poll_task is None:
        return

    poll_task.cancel()
    try:
        await poll_task
    except asyncio.CancelledError:
        pass  # Expected: the task acknowledges cancellation
|
||||
|
||||
async def close(self):
    """Tear down all subscriptions and release the exchange client."""
    # Copy the keys: unsubscribe_bars mutates the dict while we iterate
    for sub_id in list(self._subscriptions):
        await self.unsubscribe_bars(sub_id)

    # Close the exchange's underlying HTTP session; guard in case a
    # particular exchange class lacks close()
    if hasattr(self.exchange, 'close'):
        await self.exchange.close()
|
||||
353
backend/src/datasource/adapters/demo.py
Normal file
353
backend/src/datasource/adapters/demo.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
Demo data source with synthetic data.
|
||||
|
||||
Generates realistic-looking OHLCV data plus additional columns for testing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
from ..base import DataSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from ..schema import (
|
||||
Bar,
|
||||
ColumnInfo,
|
||||
DatafeedConfig,
|
||||
HistoryResult,
|
||||
Resolution,
|
||||
SearchResult,
|
||||
SymbolInfo,
|
||||
)
|
||||
|
||||
|
||||
class DemoDataSource(DataSource):
|
||||
"""
|
||||
Demo data source that generates synthetic time-series data.
|
||||
|
||||
Provides:
|
||||
- Standard OHLCV columns
|
||||
- Additional demo columns (RSI, sentiment, volume_profile)
|
||||
- Real-time updates via polling simulation
|
||||
"""
|
||||
|
||||
def __init__(self):
    # subscription_id -> background tick-generator task
    self._subscriptions: Dict[str, asyncio.Task] = {}

    # Static catalog of synthetic instruments: the starting price and
    # per-bar volatility drive the random walk in _generate_bar()
    catalog = [
        ("DEMO:BTC/USD", "Bitcoin", 50000.0, 0.02),
        ("DEMO:ETH/USD", "Ethereum", 3000.0, 0.03),
        ("DEMO:SOL/USD", "Solana", 100.0, 0.04),
    ]
    self._symbols = {
        ticker: {
            "name": name,
            "type": "crypto",
            "base_price": price,
            "volatility": volatility,
        }
        for ticker, name, price, volatility in catalog
    }
|
||||
|
||||
async def get_config(self) -> DatafeedConfig:
    """Describe the demo feed: supported resolutions, search, symbol types."""
    resolutions = [
        Resolution.M1,
        Resolution.M5,
        Resolution.M15,
        Resolution.H1,
        Resolution.D1,
    ]
    return DatafeedConfig(
        name="Demo DataSource",
        description="Synthetic data generator for testing. Provides OHLCV plus additional indicator columns.",
        supported_resolutions=resolutions,
        supports_search=True,
        supports_time=True,
        exchanges=["DEMO"],
        symbols_types=["crypto"],
    )
|
||||
|
||||
async def search_symbols(
    self,
    query: str,
    type: Optional[str] = None,
    exchange: Optional[str] = None,
    limit: int = 30,
) -> List[SearchResult]:
    """
    Case-insensitive substring search over the demo symbol catalog.

    Matches against the ticker and the display name; ``exchange`` is
    accepted for interface compatibility but unused (everything is DEMO).
    """
    needle = query.lower()
    matches: List[SearchResult] = []

    for ticker, info in self._symbols.items():
        if needle not in ticker.lower() and needle not in info["name"].lower():
            continue
        if type and info["type"] != type:
            continue
        matches.append(
            SearchResult(
                symbol=info['name'],  # Clean user-facing format (e.g., "Bitcoin")
                ticker=ticker,  # Keep DEMO:BTC/USD format for routing
                full_name=f"{info['name']} (DEMO)",
                description=f"Demo {info['name']} pair",
                exchange="DEMO",
                type=info["type"],
            )
        )

    return matches[:limit]
|
||||
|
||||
async def resolve_symbol(self, symbol: str) -> SymbolInfo:
    """
    Return full metadata for a demo symbol.

    Raises:
        ValueError: If the symbol is not in the demo catalog.
    """
    if symbol not in self._symbols:
        raise ValueError(f"Symbol '{symbol}' not found")

    info = self._symbols[symbol]
    # "DEMO:BTC/USD" -> base "BTC", quote "USD"
    base, quote = symbol.split(":")[1].split("/")

    # OHLCV columns plus the synthetic indicator columns, all float-typed
    column_specs = [
        ("open", f"Opening price in {quote}", quote),
        ("high", f"Highest price in {quote}", quote),
        ("low", f"Lowest price in {quote}", quote),
        ("close", f"Closing price in {quote}", quote),
        ("volume", f"Trading volume in {base}", base),
        ("rsi", "Relative Strength Index (14-period), range 0-100", None),
        ("sentiment", "Synthetic social sentiment score, range -1.0 to 1.0", None),
        ("volume_profile", "Volume as percentage of 24h average", "%"),
    ]
    columns = [
        ColumnInfo(name=col_name, type="float", description=col_desc, unit=col_unit)
        for col_name, col_desc, col_unit in column_specs
    ]

    return SymbolInfo(
        symbol=info["name"],  # Clean user-facing format (e.g., "Bitcoin")
        ticker=symbol,  # Keep DEMO:BTC/USD format for routing
        name=info["name"],
        description=f"Demo {info['name']} spot price with synthetic indicators",
        type=info["type"],
        exchange="DEMO",
        timezone="Etc/UTC",
        session="24x7",
        supported_resolutions=[Resolution.M1, Resolution.M5, Resolution.M15, Resolution.H1, Resolution.D1],
        has_intraday=True,
        has_daily=True,
        has_weekly_and_monthly=False,
        columns=columns,
        time_column="time",
        has_ohlcv=True,
        pricescale=100,
        minmov=1,
        base_currency=base,
        quote_currency=quote,
    )
|
||||
|
||||
async def get_bars(
    self,
    symbol: str,
    resolution: str,
    from_time: int,
    to_time: int,
    countback: Optional[int] = None,
) -> HistoryResult:
    """
    Generate synthetic bars for the requested range.

    Prices follow a random walk seeded at the symbol's base price, so
    repeated calls over the same range produce different data.

    Raises:
        ValueError: If the symbol is not in the demo catalog.
    """
    if symbol not in self._symbols:
        raise ValueError(f"Symbol '{symbol}' not found")

    logger.info(
        f"DemoDataSource.get_bars: symbol={symbol}, resolution={resolution}, "
        f"from_time={from_time}, to_time={to_time}, countback={countback}"
    )

    info = self._symbols[symbol]
    symbol_meta = await self.resolve_symbol(symbol)
    step = self._resolution_to_seconds(resolution)

    # First bar time: from_time rounded up to the resolution grid
    cursor = from_time - (from_time % step)
    if cursor < from_time:
        cursor += step

    price = info["base_price"]
    bars: List[Bar] = []
    cap = countback if countback else 5000

    while cursor <= to_time and len(bars) < cap:
        bar_data = self._generate_bar(cursor, price, info["volatility"], step)

        if from_time <= cursor <= to_time:
            # NOTE(review): bar times here are milliseconds, while the CCXT
            # adapter emits seconds — confirm which unit consumers expect.
            bars.append(Bar(time=cursor * 1000, data=bar_data))  # Convert to milliseconds

        price = bar_data["close"]  # Chain bars: next one opens at this close
        cursor += step

    logger.info(
        f"DemoDataSource.get_bars: Generated {len(bars)} bars. "
        f"First: {bars[0].time if bars else 'N/A'}, Last: {bars[-1].time if bars else 'N/A'}"
    )

    # Pagination hint: the next aligned time if we stopped before to_time
    next_time = cursor if cursor <= to_time else None

    return HistoryResult(
        symbol=symbol,
        resolution=resolution,
        bars=bars,
        columns=symbol_meta.columns,
        nextTime=next_time,
    )
|
||||
|
||||
async def subscribe_bars(
    self,
    symbol: str,
    resolution: str,
    on_tick: Callable[[dict], None],
) -> str:
    """
    Start a background task that emits one synthetic bar per period.

    Returns:
        An opaque subscription ID for unsubscribe_bars().

    Raises:
        ValueError: If the symbol is not in the demo catalog.
    """
    if symbol not in self._symbols:
        raise ValueError(f"Symbol '{symbol}' not found")

    # Timestamp suffix keeps IDs unique across repeated subscriptions
    subscription_id = f"{symbol}:{resolution}:{time.time()}"
    self._subscriptions[subscription_id] = asyncio.create_task(
        self._tick_generator(symbol, resolution, on_tick, subscription_id)
    )
    return subscription_id
|
||||
|
||||
async def unsubscribe_bars(self, subscription_id: str) -> None:
    """Cancel the tick-generator task; unknown IDs are silently ignored."""
    tick_task = self._subscriptions.pop(subscription_id, None)
    if tick_task is None:
        return
    tick_task.cancel()
    try:
        await tick_task
    except asyncio.CancelledError:
        pass  # Expected: the task acknowledges cancellation
|
||||
|
||||
def _resolution_to_seconds(self, resolution: str) -> int:
    """
    Convert a resolution string to seconds.

    "1D"/"1W"/"1M" suffixes mean days/weeks/months (months approximated
    as 30 days); a bare number means minutes.
    """
    suffix_seconds = {
        "D": 86400,
        "W": 604800,
        "M": 2592000,  # Approximate month
    }
    unit = resolution[-1]
    if unit in suffix_seconds:
        return int(resolution[:-1]) * suffix_seconds[unit]
    # No recognized suffix: interpret the whole string as minutes
    return int(resolution) * 60
|
||||
|
||||
def _generate_bar(self, timestamp: int, base_price: float, volatility: float, period_seconds: int) -> dict:
    """
    Produce one synthetic OHLCV bar as a plain dict.

    Simulates intra-period ticks with a Gaussian random walk starting at
    ``base_price``, derives OHLC from the tick path, and attaches
    randomized volume and indicator columns.
    """
    open_price = base_price

    # Simulate the intra-period tick path; longer periods get more ticks
    num_ticks = max(10, period_seconds // 60)
    tick_sigma = volatility / math.sqrt(num_ticks)
    path = [open_price]
    for _ in range(num_ticks):
        path.append(path[-1] * (1 + random.gauss(0, tick_sigma)))

    # Volume scales with the period length, with random spread
    volume = 1000000 * (period_seconds / 60) * random.uniform(0.5, 2.0)

    # Additional synthetic indicators
    rsi = 30 + random.random() * 40  # Biased toward middle range
    sentiment = math.sin(timestamp / 3600) * 0.5 + random.gauss(0, 0.2)  # Hourly cycle + noise
    sentiment = max(-1.0, min(1.0, sentiment))  # Clamp to [-1, 1]
    volume_profile = 100 * random.uniform(0.5, 1.5)

    return {
        "open": round(open_price, 2),
        "high": round(max(path), 2),
        "low": round(min(path), 2),
        "close": round(path[-1], 2),
        "volume": round(volume, 2),
        "rsi": round(rsi, 2),
        "sentiment": round(sentiment, 3),
        "volume_profile": round(volume_profile, 2),
    }
|
||||
|
||||
async def _tick_generator(
    self,
    symbol: str,
    resolution: str,
    on_tick: Callable[[dict], None],
    subscription_id: str,
):
    """Emit one synthetic bar per resolution period until cancelled."""
    info = self._symbols[symbol]
    step = self._resolution_to_seconds(resolution)

    # Align the first bar to the resolution grid
    bar_time = int(time.time())
    bar_time -= bar_time % step
    price = info["base_price"]

    try:
        while True:
            # Sleep through the period, then emit the bar that just closed
            await asyncio.sleep(step)

            bar_time += step
            bar_data = self._generate_bar(bar_time, price, info["volatility"], step)
            price = bar_data["close"]

            on_tick({"time": bar_time, **bar_data})

    except asyncio.CancelledError:
        pass  # Subscription cancelled — normal shutdown
|
||||
146
backend/src/datasource/base.py
Normal file
146
backend/src/datasource/base.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Abstract DataSource interface.
|
||||
|
||||
Inspired by TradingView's Datafeed API with extensions for flexible column schemas
|
||||
and AI-native metadata.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from .schema import DatafeedConfig, HistoryResult, SearchResult, SymbolInfo
|
||||
|
||||
|
||||
class DataSource(ABC):
    """
    Abstract base class for time-series data sources.

    Provides a standardized interface for:
    - Symbol search and metadata retrieval
    - Historical data queries (time-based, paginated)
    - Real-time data subscriptions

    All data rows must have a timestamp. Additional columns are flexible
    and described via ColumnInfo metadata. Concrete adapters (e.g. CCXT,
    demo) implement every abstract method below.
    """

    @abstractmethod
    async def get_config(self) -> DatafeedConfig:
        """
        Get datafeed configuration and capabilities.

        Called once during initialization to understand what this data source
        supports (resolutions, exchanges, search, etc.).

        Returns:
            DatafeedConfig describing this datafeed's capabilities
        """
        pass

    @abstractmethod
    async def search_symbols(
        self,
        query: str,
        type: Optional[str] = None,
        exchange: Optional[str] = None,
        limit: int = 30,
    ) -> List[SearchResult]:
        """
        Search for symbols matching a text query.

        Args:
            query: Free-text search string
            type: Optional filter by instrument type
            exchange: Optional filter by exchange
            limit: Maximum number of results

        Returns:
            List of matching symbols with basic metadata
        """
        pass

    @abstractmethod
    async def resolve_symbol(self, symbol: str) -> SymbolInfo:
        """
        Get complete metadata for a symbol.

        Called after a symbol is selected to retrieve full information including
        supported resolutions, column schema, trading session, etc.

        Args:
            symbol: Symbol identifier

        Returns:
            Complete SymbolInfo including column definitions

        Raises:
            ValueError: If symbol is not found
        """
        pass

    @abstractmethod
    async def get_bars(
        self,
        symbol: str,
        resolution: str,
        from_time: int,
        to_time: int,
        countback: Optional[int] = None,
    ) -> HistoryResult:
        """
        Get historical bars for a symbol and resolution.

        Time range is specified in Unix timestamps (seconds). If more data is
        available beyond the requested range, the result should include a
        nextTime value for pagination.

        Args:
            symbol: Symbol identifier
            resolution: Time resolution (e.g., "1", "5", "60", "1D")
            from_time: Start time (Unix timestamp in seconds)
            to_time: End time (Unix timestamp in seconds)
            countback: Optional limit on number of bars to return

        Returns:
            HistoryResult with bars, column schema, and pagination info

        Raises:
            ValueError: If symbol or resolution is not supported
        """
        pass

    @abstractmethod
    async def subscribe_bars(
        self,
        symbol: str,
        resolution: str,
        on_tick: Callable[[dict], None],
    ) -> str:
        """
        Subscribe to real-time bar updates.

        The callback will be invoked with new bar data as it becomes available.
        The data dict will match the column schema from resolve_symbol().

        Args:
            symbol: Symbol identifier
            resolution: Time resolution
            on_tick: Callback function receiving bar data dict

        Returns:
            Subscription ID for later unsubscribe

        Raises:
            ValueError: If symbol or resolution is not supported
        """
        pass

    @abstractmethod
    async def unsubscribe_bars(self, subscription_id: str) -> None:
        """
        Unsubscribe from real-time updates.

        Args:
            subscription_id: ID returned from subscribe_bars()
        """
        pass
|
||||
109
backend/src/datasource/registry.py
Normal file
109
backend/src/datasource/registry.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
DataSource registry for managing multiple data sources.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .base import DataSource
|
||||
from .schema import SearchResult, SymbolInfo
|
||||
|
||||
|
||||
class DataSourceRegistry:
    """
    Central registry that tracks named DataSource instances.

    Routes symbol lookups to one specific source and can fan a search
    query out across every registered source at once.
    """

    def __init__(self):
        # name -> DataSource; a re-register under the same name replaces.
        self._sources: Dict[str, DataSource] = {}

    def register(self, name: str, source: DataSource) -> None:
        """
        Register a data source under a unique name.

        Args:
            name: Unique name for this data source
            source: DataSource implementation
        """
        self._sources[name] = source

    def unregister(self, name: str) -> None:
        """
        Remove a data source; silently ignores unknown names.

        Args:
            name: Name of the data source to remove
        """
        self._sources.pop(name, None)

    def get(self, name: str) -> Optional[DataSource]:
        """
        Look up a registered data source by name.

        Args:
            name: Data source name

        Returns:
            The DataSource instance, or None if no such source exists
        """
        return self._sources.get(name)

    def list_sources(self) -> List[str]:
        """
        Names of every registered data source.

        Returns:
            List of data source names
        """
        return [*self._sources]

    async def search_all(
        self,
        query: str,
        type: Optional[str] = None,
        exchange: Optional[str] = None,
        limit: int = 30,
    ) -> Dict[str, List[SearchResult]]:
        """
        Run a symbol search against every registered source.

        Sources that raise are skipped so one broken source cannot break
        the aggregate search; sources with no hits are omitted.

        Args:
            query: Search query
            type: Optional instrument type filter
            exchange: Optional exchange filter
            limit: Maximum results per source

        Returns:
            Dict mapping source name to its non-empty search results
        """
        hits: Dict[str, List[SearchResult]] = {}
        for name, source in self._sources.items():
            try:
                found = await source.search_symbols(query, type, exchange, limit)
            except Exception:
                # Best-effort aggregation: a failing source is skipped.
                continue
            if found:
                hits[name] = found
        return hits

    async def resolve_symbol(self, source_name: str, symbol: str) -> SymbolInfo:
        """
        Resolve a symbol through one named data source.

        Args:
            source_name: Name of the data source
            symbol: Symbol identifier

        Returns:
            SymbolInfo from the specified source

        Raises:
            ValueError: If the source is unknown (or it reports the
                symbol as not found)
        """
        source = self._sources.get(source_name)
        if not source:
            raise ValueError(f"Data source '{source_name}' not found")
        return await source.resolve_symbol(symbol)
|
||||
194
backend/src/datasource/schema.py
Normal file
194
backend/src/datasource/schema.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
Data models for the DataSource interface.
|
||||
|
||||
Inspired by TradingView's Datafeed API but with flexible column schemas
|
||||
for AI-native trading platform needs.
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Resolution(StrEnum):
    """Standard time resolutions for bar data.

    String values encode the interval: second-based resolutions carry an
    "S" suffix, minute- and hour-based resolutions are plain minute
    counts, and day/week/month resolutions use "D"/"W"/"M" suffixes.
    """

    # Seconds ("<n>S")
    S1 = "1S"
    S5 = "5S"
    S15 = "15S"
    S30 = "30S"

    # Minutes (bare minute counts)
    M1 = "1"
    M5 = "5"
    M15 = "15"
    M30 = "30"

    # Hours, expressed in minutes (60 = 1h ... 720 = 12h)
    H1 = "60"
    H2 = "120"
    H4 = "240"
    H6 = "360"
    H12 = "720"

    # Days
    D1 = "1D"

    # Weeks
    W1 = "1W"

    # Months ("MO1" avoids clashing with the minute prefix "M1")
    MO1 = "1M"
|
||||
|
||||
|
||||
class ColumnInfo(BaseModel):
    """
    Metadata for a single data column.

    Provides rich, LLM-readable descriptions so AI agents can understand
    and reason about available data fields.
    """

    # Reject unexpected keys during validation so schema typos fail loudly.
    model_config = {"extra": "forbid"}

    name: str = Field(description="Column name (e.g., 'close', 'volume', 'funding_rate')")
    type: Literal["float", "int", "bool", "string", "decimal"] = Field(description="Data type")
    description: str = Field(description="Human and LLM-readable description of what this column represents")
    unit: Optional[str] = Field(default=None, description="Unit of measurement (e.g., 'USD', 'BTC', '%', 'contracts')")
    nullable: bool = Field(default=False, description="Whether this column can contain null values")
|
||||
|
||||
|
||||
class SymbolInfo(BaseModel):
    """
    Complete metadata for a tradeable symbol.

    Includes both TradingView-compatible fields and flexible schema definition
    for arbitrary data columns.
    """

    # Reject unexpected keys during validation.
    model_config = {"extra": "forbid"}

    # Core identification
    symbol: str = Field(description="Unique symbol identifier (primary key for data fetching)")
    ticker: Optional[str] = Field(default=None, description="TradingView ticker (if different from symbol)")
    name: str = Field(description="Display name")
    description: str = Field(description="LLM-readable description of the instrument")
    type: str = Field(description="Instrument type: 'crypto', 'stock', 'forex', 'futures', 'derived', etc.")
    exchange: str = Field(description="Exchange or data source identifier")

    # Trading session info (defaults suit 24/7 crypto markets)
    timezone: str = Field(default="Etc/UTC", description="IANA timezone identifier")
    session: str = Field(default="24x7", description="Trading session spec (e.g., '0930-1600' or '24x7')")

    # Resolution support
    supported_resolutions: List[str] = Field(description="List of supported time resolutions")
    has_intraday: bool = Field(default=True, description="Whether intraday resolutions are supported")
    has_daily: bool = Field(default=True, description="Whether daily resolution is supported")
    has_weekly_and_monthly: bool = Field(default=False, description="Whether weekly/monthly resolutions are supported")

    # Flexible schema definition
    columns: List[ColumnInfo] = Field(description="Available data columns for this symbol")
    time_column: str = Field(default="time", description="Name of the timestamp column")

    # Convenience flags
    has_ohlcv: bool = Field(default=False, description="Whether standard OHLCV columns are present")

    # Price display (for OHLCV data).
    # With pricescale=100 and minmov=1, the minimum tick is 1/100 = 0.01.
    pricescale: int = Field(default=100, description="Price scale factor (e.g., 100 for 2 decimals)")
    minmov: int = Field(default=1, description="Minimum price movement in pricescale units")

    # Additional metadata
    base_currency: Optional[str] = Field(default=None, description="Base currency (for crypto/forex)")
    quote_currency: Optional[str] = Field(default=None, description="Quote currency (for crypto/forex)")
|
||||
|
||||
|
||||
class Bar(BaseModel):
    """
    A single bar/row of time-series data with flexible columns.

    All bars must have a timestamp. Additional columns are stored in the
    data dict and described by the associated ColumnInfo metadata.
    """

    # Reject unexpected keys during validation.
    model_config = {"extra": "forbid"}

    time: int = Field(description="Unix timestamp in seconds")
    data: Dict[str, Any] = Field(description="Column name -> value mapping")

    # Convenience accessors for common OHLCV columns.
    # Each returns None when the feed has no such column in `data`.
    @property
    def open(self) -> Optional[float]:
        """Open price, or None if absent."""
        return self.data.get("open")

    @property
    def high(self) -> Optional[float]:
        """High price, or None if absent."""
        return self.data.get("high")

    @property
    def low(self) -> Optional[float]:
        """Low price, or None if absent."""
        return self.data.get("low")

    @property
    def close(self) -> Optional[float]:
        """Close price, or None if absent."""
        return self.data.get("close")

    @property
    def volume(self) -> Optional[float]:
        """Traded volume, or None if absent."""
        return self.data.get("volume")
|
||||
|
||||
|
||||
class HistoryResult(BaseModel):
    """
    Result from a historical data query.

    Includes the bars, schema information, and pagination metadata.
    """

    # Reject unexpected keys during validation.
    model_config = {"extra": "forbid"}

    symbol: str = Field(description="Symbol identifier")
    resolution: str = Field(description="Time resolution of the bars")
    bars: List[Bar] = Field(description="The actual data bars")
    columns: List[ColumnInfo] = Field(description="Schema describing the bar data columns")
    # camelCase (not snake_case) to mirror the TradingView-style wire
    # format this module follows.
    nextTime: Optional[int] = Field(default=None, description="Unix timestamp for pagination (if more data available)")
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
    """
    A single result from symbol search.
    """

    # Reject unexpected keys during validation.
    model_config = {"extra": "forbid"}

    symbol: str = Field(description="Display symbol (e.g., 'BINANCE:ETH/BTC')")
    ticker: Optional[str] = Field(default=None, description="Backend ticker for data fetching (e.g., 'ETH/BTC')")
    full_name: str = Field(description="Full display name including exchange")
    description: str = Field(description="Human-readable description")
    exchange: str = Field(description="Exchange identifier")
    type: str = Field(description="Instrument type")
|
||||
|
||||
|
||||
class DatafeedConfig(BaseModel):
    """
    Configuration and capabilities of a DataSource.

    Similar to TradingView's onReady configuration object.
    """

    # Reject unexpected keys during validation.
    model_config = {"extra": "forbid"}

    # Supported features
    supported_resolutions: List[str] = Field(description="All resolutions this datafeed supports")
    supports_search: bool = Field(default=True, description="Whether symbol search is available")
    supports_time: bool = Field(default=True, description="Whether time-based queries are supported")
    supports_marks: bool = Field(default=False, description="Whether marks/events are supported")

    # Data characteristics
    exchanges: List[str] = Field(default_factory=list, description="Available exchanges")
    symbols_types: List[str] = Field(default_factory=list, description="Available instrument types")

    # Metadata
    name: str = Field(description="Datafeed name")
    description: str = Field(description="LLM-readable description of this data source")
|
||||
235
backend/src/datasource/subscription_manager.py
Normal file
235
backend/src/datasource/subscription_manager.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
Subscription manager for real-time data feeds.
|
||||
|
||||
Manages subscriptions across multiple data sources and routes updates
|
||||
to WebSocket clients.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Callable, Dict, Optional, Set
|
||||
|
||||
from .base import DataSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Subscription:
    """One client's interest in a single (source, symbol, resolution) stream."""

    def __init__(
        self,
        subscription_id: str,
        client_id: str,
        source_name: str,
        symbol: str,
        resolution: str,
        callback: Callable[[dict], None],
    ):
        # Identity of this client-level subscription.
        self.subscription_id = subscription_id
        self.client_id = client_id
        # What stream the client is watching, and where it comes from.
        self.source_name = source_name
        self.symbol = symbol
        self.resolution = resolution
        # Invoked with each incoming bar-data dict.
        self.callback = callback
        # Set once an upstream source subscription exists for this stream;
        # None until then.
        self.source_subscription_id: Optional[str] = None
|
||||
|
||||
|
||||
class SubscriptionManager:
    """
    Manages real-time data subscriptions across multiple data sources.

    Handles:
    - Subscription lifecycle (subscribe/unsubscribe)
    - Routing updates from data sources to clients
    - Multiplexing (multiple clients can subscribe to same symbol/resolution)

    A single upstream ("source") subscription is shared by every client
    interested in the same (source, symbol, resolution); it is reference
    counted and torn down when the last interested client unsubscribes.
    """

    def __init__(self):
        # Map subscription_id -> Subscription (one per client-level subscription)
        self._subscriptions: Dict[str, Subscription] = {}

        # Map (source_name, symbol, resolution) -> Set[subscription_id]
        # Reference counting: which client subscriptions share a source subscription
        self._source_refs: Dict[tuple, Set[str]] = {}

        # Map source_subscription_id -> (source_name, symbol, resolution)
        self._source_subs: Dict[str, tuple] = {}

        # Map (source_name, symbol, resolution) -> source_subscription_id.
        # BUGFIX: previously the source subscription id was stored only on the
        # FIRST subscriber's record; if that client unsubscribed before the
        # others, the last unsubscribe found no id and the upstream source
        # subscription leaked forever. Keying the id by stream fixes that.
        self._key_to_source_sub: Dict[tuple, str] = {}

        # Available data sources by name
        self._sources: Dict[str, DataSource] = {}

    def register_source(self, name: str, source: DataSource) -> None:
        """Register a data source under a unique name."""
        self._sources[name] = source

    def unregister_source(self, name: str) -> None:
        """Unregister a data source (no-op if the name is unknown)."""
        self._sources.pop(name, None)

    async def subscribe(
        self,
        subscription_id: str,
        client_id: str,
        source_name: str,
        symbol: str,
        resolution: str,
        callback: Callable[[dict], None],
    ) -> None:
        """
        Subscribe a client to real-time updates.

        Args:
            subscription_id: Unique ID for this subscription
            client_id: ID of the subscribing client
            source_name: Name of the data source
            symbol: Symbol to subscribe to
            resolution: Time resolution
            callback: Function to call with bar updates

        Raises:
            ValueError: If source not found or subscription fails
        """
        source = self._sources.get(source_name)
        if not source:
            raise ValueError(f"Data source '{source_name}' not found")

        # Create subscription record
        subscription = Subscription(
            subscription_id=subscription_id,
            client_id=client_id,
            source_name=source_name,
            symbol=symbol,
            resolution=resolution,
            callback=callback,
        )

        # Check if we already have a source subscription for this (source, symbol, resolution)
        source_key = (source_name, symbol, resolution)
        if source_key not in self._source_refs:
            # First client for this stream: open the upstream subscription.
            try:
                source_sub_id = await source.subscribe_bars(
                    symbol=symbol,
                    resolution=resolution,
                    on_tick=lambda bar: self._on_source_update(source_key, bar),
                )
            except Exception as e:
                logger.error(f"Failed to subscribe to source: {e}")
                raise
            self._source_subs[source_sub_id] = source_key
            self._key_to_source_sub[source_key] = source_sub_id
            self._source_refs[source_key] = set()
            logger.info(
                f"Created new source subscription: {source_name}/{symbol}/{resolution} -> {source_sub_id}"
            )

        # Every client record carries the shared upstream id (informational;
        # teardown uses _key_to_source_sub, not this field).
        subscription.source_subscription_id = self._key_to_source_sub[source_key]

        # Add this subscription to the reference set
        self._source_refs[source_key].add(subscription_id)
        self._subscriptions[subscription_id] = subscription

        logger.info(
            f"Client subscription added: {subscription_id} ({client_id}) -> {source_name}/{symbol}/{resolution}"
        )

    async def unsubscribe(self, subscription_id: str) -> None:
        """
        Unsubscribe a client from updates.

        When the last client sharing an upstream source subscription leaves,
        the upstream subscription is also closed.

        Args:
            subscription_id: ID of the subscription to remove
        """
        subscription = self._subscriptions.pop(subscription_id, None)
        if not subscription:
            logger.warning(f"Subscription {subscription_id} not found")
            return

        source_key = (subscription.source_name, subscription.symbol, subscription.resolution)

        refs = self._source_refs.get(source_key)
        if refs is not None:
            refs.discard(subscription_id)

            # If no more clients need this source subscription, unsubscribe from source
            if not refs:
                del self._source_refs[source_key]

                # Look the upstream id up by stream key (NOT via the record of
                # whichever client happens to leave last) so teardown succeeds
                # regardless of unsubscribe order.
                source_sub_id = self._key_to_source_sub.pop(source_key, None)
                if source_sub_id:
                    self._source_subs.pop(source_sub_id, None)
                    source = self._sources.get(subscription.source_name)
                    if source:
                        try:
                            await source.unsubscribe_bars(source_sub_id)
                            logger.info(
                                f"Unsubscribed from source: {source_sub_id}"
                            )
                        except Exception as e:
                            logger.error(f"Error unsubscribing from source: {e}")

        logger.info(f"Client subscription removed: {subscription_id}")

    async def unsubscribe_client(self, client_id: str) -> None:
        """
        Unsubscribe all subscriptions for a client.

        Useful when a WebSocket connection closes.

        Args:
            client_id: ID of the client
        """
        # Collect first, then unsubscribe: unsubscribe() mutates _subscriptions.
        client_subs = [
            sub_id
            for sub_id, sub in self._subscriptions.items()
            if sub.client_id == client_id
        ]

        for sub_id in client_subs:
            await self.unsubscribe(sub_id)

        logger.info(f"Unsubscribed all subscriptions for client {client_id}")

    def _on_source_update(self, source_key: tuple, bar: dict) -> None:
        """
        Handle an update from a data source.

        Routes the update to all client subscriptions that need it.

        Args:
            source_key: (source_name, symbol, resolution) tuple
            bar: Bar data dict from the source
        """
        # Iterate a snapshot: a callback may (un)subscribe and mutate the set.
        subscription_ids = list(self._source_refs.get(source_key, set()))

        for sub_id in subscription_ids:
            subscription = self._subscriptions.get(sub_id)
            if subscription:
                try:
                    subscription.callback(bar)
                except Exception as e:
                    # One misbehaving client must not stop fan-out to the rest.
                    logger.error(
                        f"Error in subscription callback {sub_id}: {e}", exc_info=True
                    )

    def get_subscription_count(self) -> int:
        """Get total number of active client subscriptions"""
        return len(self._subscriptions)

    def get_source_subscription_count(self) -> int:
        """Get total number of active source subscriptions"""
        return len(self._source_refs)

    def get_client_subscriptions(self, client_id: str) -> list:
        """Get all subscriptions for a specific client"""
        return [
            {
                "subscription_id": sub.subscription_id,
                "source": sub.source_name,
                "symbol": sub.symbol,
                "resolution": sub.resolution,
            }
            for sub in self._subscriptions.values()
            if sub.client_id == client_id
        ]
|
||||
347
backend/src/datasource/websocket_handler.py
Normal file
347
backend/src/datasource/websocket_handler.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
WebSocket handler for TradingView-compatible datafeed API.
|
||||
|
||||
Handles incoming requests for symbol search, metadata, historical data,
|
||||
and real-time subscriptions.
|
||||
"""
|
||||
|
||||
import asyncio
import json
import logging
from typing import Dict, Optional
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from .base import DataSource
|
||||
from .registry import DataSourceRegistry
|
||||
from .subscription_manager import SubscriptionManager
|
||||
from .websocket_protocol import (
|
||||
BarUpdateMessage,
|
||||
ErrorResponse,
|
||||
GetBarsRequest,
|
||||
GetBarsResponse,
|
||||
GetConfigRequest,
|
||||
GetConfigResponse,
|
||||
ResolveSymbolRequest,
|
||||
ResolveSymbolResponse,
|
||||
SearchSymbolsRequest,
|
||||
SearchSymbolsResponse,
|
||||
SubscribeBarsRequest,
|
||||
SubscribeBarsResponse,
|
||||
UnsubscribeBarsRequest,
|
||||
UnsubscribeBarsResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatafeedWebSocketHandler:
|
||||
"""
|
||||
Handles WebSocket connections for TradingView-compatible datafeed API.
|
||||
|
||||
Each handler manages a single WebSocket connection and routes requests
|
||||
to the appropriate data sources via the registry.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
client_id: str,
|
||||
registry: DataSourceRegistry,
|
||||
subscription_manager: SubscriptionManager,
|
||||
default_source: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize handler.
|
||||
|
||||
Args:
|
||||
websocket: FastAPI WebSocket connection
|
||||
client_id: Unique identifier for this client
|
||||
registry: DataSource registry for accessing data sources
|
||||
subscription_manager: Shared subscription manager
|
||||
default_source: Default data source name if not specified in requests
|
||||
"""
|
||||
self.websocket = websocket
|
||||
self.client_id = client_id
|
||||
self.registry = registry
|
||||
self.subscription_manager = subscription_manager
|
||||
self.default_source = default_source
|
||||
self._connected = True
|
||||
|
||||
async def handle_connection(self) -> None:
|
||||
"""
|
||||
Main connection handler loop.
|
||||
|
||||
Processes incoming messages until the connection closes.
|
||||
"""
|
||||
try:
|
||||
await self.websocket.accept()
|
||||
logger.info(f"WebSocket connected: client_id={self.client_id}")
|
||||
|
||||
while self._connected:
|
||||
# Receive message
|
||||
try:
|
||||
data = await self.websocket.receive_text()
|
||||
message = json.loads(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error receiving/parsing message: {e}")
|
||||
break
|
||||
|
||||
# Route to appropriate handler
|
||||
await self._handle_message(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error: {e}", exc_info=True)
|
||||
finally:
|
||||
# Clean up subscriptions when connection closes
|
||||
await self.subscription_manager.unsubscribe_client(self.client_id)
|
||||
self._connected = False
|
||||
logger.info(f"WebSocket disconnected: client_id={self.client_id}")
|
||||
|
||||
async def _handle_message(self, message: dict) -> None:
|
||||
"""Route message to appropriate handler based on type"""
|
||||
msg_type = message.get("type")
|
||||
request_id = message.get("request_id", "unknown")
|
||||
|
||||
try:
|
||||
if msg_type == "search_symbols":
|
||||
await self._handle_search_symbols(SearchSymbolsRequest(**message))
|
||||
elif msg_type == "resolve_symbol":
|
||||
await self._handle_resolve_symbol(ResolveSymbolRequest(**message))
|
||||
elif msg_type == "get_bars":
|
||||
await self._handle_get_bars(GetBarsRequest(**message))
|
||||
elif msg_type == "subscribe_bars":
|
||||
await self._handle_subscribe_bars(SubscribeBarsRequest(**message))
|
||||
elif msg_type == "unsubscribe_bars":
|
||||
await self._handle_unsubscribe_bars(UnsubscribeBarsRequest(**message))
|
||||
elif msg_type == "get_config":
|
||||
await self._handle_get_config(GetConfigRequest(**message))
|
||||
else:
|
||||
await self._send_error(
|
||||
request_id, "UNKNOWN_REQUEST_TYPE", f"Unknown request type: {msg_type}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling {msg_type}: {e}", exc_info=True)
|
||||
await self._send_error(request_id, "INTERNAL_ERROR", str(e))
|
||||
|
||||
async def _handle_search_symbols(self, request: SearchSymbolsRequest) -> None:
|
||||
"""Handle symbol search request"""
|
||||
# Use default source or search all sources
|
||||
if self.default_source:
|
||||
source = self.registry.get(self.default_source)
|
||||
if not source:
|
||||
await self._send_error(
|
||||
request.request_id,
|
||||
"SOURCE_NOT_FOUND",
|
||||
f"Default source '{self.default_source}' not found",
|
||||
)
|
||||
return
|
||||
|
||||
results = await source.search_symbols(
|
||||
query=request.query,
|
||||
type=request.symbol_type,
|
||||
exchange=request.exchange,
|
||||
limit=request.limit,
|
||||
)
|
||||
results_data = [r.model_dump(mode="json") for r in results]
|
||||
else:
|
||||
# Search all sources
|
||||
all_results = await self.registry.search_all(
|
||||
query=request.query,
|
||||
type=request.symbol_type,
|
||||
exchange=request.exchange,
|
||||
limit=request.limit,
|
||||
)
|
||||
# Flatten results from all sources
|
||||
results_data = []
|
||||
for source_results in all_results.values():
|
||||
results_data.extend([r.model_dump(mode="json") for r in source_results])
|
||||
|
||||
response = SearchSymbolsResponse(request_id=request.request_id, results=results_data)
|
||||
await self._send_response(response)
|
||||
|
||||
async def _handle_resolve_symbol(self, request: ResolveSymbolRequest) -> None:
|
||||
"""Handle symbol resolution request"""
|
||||
# Extract source from symbol if present (format: "SOURCE:SYMBOL")
|
||||
if ":" in request.symbol:
|
||||
source_name, symbol = request.symbol.split(":", 1)
|
||||
else:
|
||||
source_name = self.default_source
|
||||
symbol = request.symbol
|
||||
|
||||
if not source_name:
|
||||
await self._send_error(
|
||||
request.request_id,
|
||||
"NO_SOURCE_SPECIFIED",
|
||||
"No data source specified and no default source configured",
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
symbol_info = await self.registry.resolve_symbol(source_name, symbol)
|
||||
response = ResolveSymbolResponse(
|
||||
request_id=request.request_id,
|
||||
symbol_info=symbol_info.model_dump(mode="json"),
|
||||
)
|
||||
await self._send_response(response)
|
||||
except ValueError as e:
|
||||
await self._send_error(request.request_id, "SYMBOL_NOT_FOUND", str(e))
|
||||
|
||||
async def _handle_get_bars(self, request: GetBarsRequest) -> None:
|
||||
"""Handle historical bars request"""
|
||||
# Extract source from symbol
|
||||
if ":" in request.symbol:
|
||||
source_name, symbol = request.symbol.split(":", 1)
|
||||
else:
|
||||
source_name = self.default_source
|
||||
symbol = request.symbol
|
||||
|
||||
if not source_name:
|
||||
await self._send_error(
|
||||
request.request_id, "NO_SOURCE_SPECIFIED", "No data source specified"
|
||||
)
|
||||
return
|
||||
|
||||
source = self.registry.get(source_name)
|
||||
if not source:
|
||||
await self._send_error(
|
||||
request.request_id, "SOURCE_NOT_FOUND", f"Source '{source_name}' not found"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
history = await source.get_bars(
|
||||
symbol=symbol,
|
||||
resolution=request.resolution,
|
||||
from_time=request.from_time,
|
||||
to_time=request.to_time,
|
||||
countback=request.countback,
|
||||
)
|
||||
response = GetBarsResponse(
|
||||
request_id=request.request_id, history=history.model_dump(mode="json")
|
||||
)
|
||||
await self._send_response(response)
|
||||
except ValueError as e:
|
||||
await self._send_error(request.request_id, "INVALID_REQUEST", str(e))
|
||||
|
||||
async def _handle_subscribe_bars(self, request: SubscribeBarsRequest) -> None:
|
||||
"""Handle real-time subscription request"""
|
||||
# Extract source from symbol
|
||||
if ":" in request.symbol:
|
||||
source_name, symbol = request.symbol.split(":", 1)
|
||||
else:
|
||||
source_name = self.default_source
|
||||
symbol = request.symbol
|
||||
|
||||
if not source_name:
|
||||
await self._send_error(
|
||||
request.request_id, "NO_SOURCE_SPECIFIED", "No data source specified"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Create callback that sends updates to this WebSocket
|
||||
async def send_update(bar: dict):
|
||||
update = BarUpdateMessage(
|
||||
subscription_id=request.subscription_id,
|
||||
symbol=request.symbol,
|
||||
resolution=request.resolution,
|
||||
bar=bar,
|
||||
)
|
||||
await self._send_response(update)
|
||||
|
||||
# Register subscription
|
||||
await self.subscription_manager.subscribe(
|
||||
subscription_id=request.subscription_id,
|
||||
client_id=self.client_id,
|
||||
source_name=source_name,
|
||||
symbol=symbol,
|
||||
resolution=request.resolution,
|
||||
callback=lambda bar: self._queue_update(send_update(bar)),
|
||||
)
|
||||
|
||||
response = SubscribeBarsResponse(
|
||||
request_id=request.request_id,
|
||||
subscription_id=request.subscription_id,
|
||||
success=True,
|
||||
)
|
||||
await self._send_response(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Subscription failed: {e}", exc_info=True)
|
||||
response = SubscribeBarsResponse(
|
||||
request_id=request.request_id,
|
||||
subscription_id=request.subscription_id,
|
||||
success=False,
|
||||
message=str(e),
|
||||
)
|
||||
await self._send_response(response)
|
||||
|
||||
async def _handle_unsubscribe_bars(self, request: UnsubscribeBarsRequest) -> None:
|
||||
"""Handle unsubscribe request"""
|
||||
try:
|
||||
await self.subscription_manager.unsubscribe(request.subscription_id)
|
||||
response = UnsubscribeBarsResponse(
|
||||
request_id=request.request_id,
|
||||
subscription_id=request.subscription_id,
|
||||
success=True,
|
||||
)
|
||||
await self._send_response(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Unsubscribe failed: {e}")
|
||||
response = UnsubscribeBarsResponse(
|
||||
request_id=request.request_id,
|
||||
subscription_id=request.subscription_id,
|
||||
success=False,
|
||||
)
|
||||
await self._send_response(response)
|
||||
|
||||
async def _handle_get_config(self, request: GetConfigRequest) -> None:
|
||||
"""Handle datafeed config request"""
|
||||
if self.default_source:
|
||||
source = self.registry.get(self.default_source)
|
||||
if source:
|
||||
config = await source.get_config()
|
||||
response = GetConfigResponse(
|
||||
request_id=request.request_id, config=config.model_dump(mode="json")
|
||||
)
|
||||
await self._send_response(response)
|
||||
return
|
||||
|
||||
# Return aggregate config from all sources
|
||||
all_sources = self.registry.list_sources()
|
||||
if not all_sources:
|
||||
await self._send_error(
|
||||
request.request_id, "NO_SOURCES", "No data sources available"
|
||||
)
|
||||
return
|
||||
|
||||
# Just use first source's config for now
|
||||
# TODO: Aggregate configs from multiple sources
|
||||
source = self.registry.get(all_sources[0])
|
||||
if source:
|
||||
config = await source.get_config()
|
||||
response = GetConfigResponse(
|
||||
request_id=request.request_id, config=config.model_dump(mode="json")
|
||||
)
|
||||
await self._send_response(response)
|
||||
|
||||
async def _send_response(self, response) -> None:
    """Serialize *response* (a pydantic message) and push it to the client.

    Any failure while serializing or sending is treated as a dead
    connection: it is logged and ``self._connected`` is cleared.
    """
    try:
        payload = response.model_dump(mode="json")
        await self.websocket.send_json(payload)
    except Exception as e:
        logger.error(f"Error sending response: {e}")
        # A failed send means the socket is unusable; mark disconnected.
        self._connected = False
|
||||
|
||||
async def _send_error(self, request_id: str, error_code: str, error_message: str) -> None:
    """Build an ``ErrorResponse`` for *request_id* and deliver it.

    Args:
        request_id: ID of the failed client request.
        error_code: Machine-readable error code.
        error_message: Human-readable error description.
    """
    await self._send_response(
        ErrorResponse(
            request_id=request_id,
            error_code=error_code,
            error_message=error_message,
        )
    )
|
||||
|
||||
def _queue_update(self, coro):
    """Schedule *coro* on the running event loop without blocking the caller.

    ``asyncio.create_task`` keeps only a weak reference to the task, so a
    task whose handle is discarded (as the original code did) can be
    garbage-collected before it finishes and its exceptions silently
    dropped.  Hold a strong reference in ``self._pending_tasks`` until the
    task completes, then let the done-callback drop it.

    Args:
        coro: The coroutine to run (typically a send of an update message).
    """
    task = asyncio.create_task(coro)
    # Lazily create the holder set so this works even if __init__ predates
    # this attribute.
    pending = getattr(self, "_pending_tasks", None)
    if pending is None:
        pending = self._pending_tasks = set()
    pending.add(task)
    task.add_done_callback(pending.discard)
|
||||
170
backend/src/datasource/websocket_protocol.py
Normal file
170
backend/src/datasource/websocket_protocol.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
WebSocket protocol messages for TradingView-compatible datafeed API.
|
||||
|
||||
These messages define the wire format for client-server communication
|
||||
over WebSocket for symbol search, historical data, and real-time subscriptions.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Client -> Server Messages
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class SearchSymbolsRequest(BaseModel):
    """Request to search for symbols matching a query.

    The server answers with a ``SearchSymbolsResponse`` carrying the same
    ``request_id``.
    """

    # Discriminator tag identifying this message type on the wire.
    type: Literal["search_symbols"] = "search_symbols"
    request_id: str = Field(description="Client-generated request ID for matching responses")
    query: str = Field(description="Search query string")
    symbol_type: Optional[str] = Field(default=None, description="Filter by instrument type")
    exchange: Optional[str] = Field(default=None, description="Filter by exchange")
    limit: int = Field(default=30, description="Maximum number of results")
|
||||
|
||||
|
||||
class ResolveSymbolRequest(BaseModel):
    """Request full metadata for a specific symbol.

    The server answers with a ``ResolveSymbolResponse`` carrying the same
    ``request_id``.
    """

    # Discriminator tag identifying this message type on the wire.
    type: Literal["resolve_symbol"] = "resolve_symbol"
    # Client-generated ID echoed back in the matching response.
    request_id: str
    symbol: str = Field(description="Symbol identifier to resolve")
|
||||
|
||||
|
||||
class GetBarsRequest(BaseModel):
    """Request historical bar data for one symbol over a time window.

    The server answers with a ``GetBarsResponse`` carrying the same
    ``request_id``.
    """

    # Discriminator tag identifying this message type on the wire.
    type: Literal["get_bars"] = "get_bars"
    # Client-generated ID echoed back in the matching response.
    request_id: str
    # Symbol identifier to fetch bars for.
    symbol: str
    resolution: str = Field(description="Time resolution (e.g., '1', '5', '60', '1D')")
    from_time: int = Field(description="Start time (Unix timestamp in seconds)")
    to_time: int = Field(description="End time (Unix timestamp in seconds)")
    countback: Optional[int] = Field(default=None, description="Maximum number of bars to return")
|
||||
|
||||
|
||||
class SubscribeBarsRequest(BaseModel):
    """Subscribe to real-time bar updates.

    Acknowledged by a ``SubscribeBarsResponse``; subsequent updates arrive
    as ``BarUpdateMessage``s tagged with ``subscription_id``.
    """

    # Discriminator tag identifying this message type on the wire.
    type: Literal["subscribe_bars"] = "subscribe_bars"
    # Client-generated ID echoed back in the matching response.
    request_id: str
    # Symbol identifier to stream bars for.
    symbol: str
    # Time resolution string (same format as GetBarsRequest.resolution).
    resolution: str
    subscription_id: str = Field(description="Client-generated subscription ID")
|
||||
|
||||
|
||||
class UnsubscribeBarsRequest(BaseModel):
    """Unsubscribe from real-time updates.

    Acknowledged by an ``UnsubscribeBarsResponse``.
    """

    # Discriminator tag identifying this message type on the wire.
    type: Literal["unsubscribe_bars"] = "unsubscribe_bars"
    # Client-generated ID echoed back in the matching response.
    request_id: str
    # The subscription_id previously sent in a SubscribeBarsRequest.
    subscription_id: str
|
||||
|
||||
|
||||
class GetConfigRequest(BaseModel):
    """Request datafeed configuration.

    The server answers with a ``GetConfigResponse`` carrying the same
    ``request_id``.
    """

    # Discriminator tag identifying this message type on the wire.
    type: Literal["get_config"] = "get_config"
    # Client-generated ID echoed back in the matching response.
    request_id: str
|
||||
|
||||
|
||||
# Union of all client request types.
# Each member carries a distinct Literal ``type`` tag, so an incoming
# message can be dispatched on its ``type`` field.
ClientRequest = Union[
    SearchSymbolsRequest,
    ResolveSymbolRequest,
    GetBarsRequest,
    SubscribeBarsRequest,
    UnsubscribeBarsRequest,
    GetConfigRequest,
]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Server -> Client Messages
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class SearchSymbolsResponse(BaseModel):
    """Response with search results."""

    # Discriminator tag identifying this message type on the wire.
    type: Literal["search_symbols_response"] = "search_symbols_response"
    # Matches the request_id of the originating SearchSymbolsRequest.
    request_id: str
    results: List[Dict[str, Any]] = Field(description="List of SearchResult objects")
|
||||
|
||||
|
||||
class ResolveSymbolResponse(BaseModel):
    """Response with symbol metadata."""

    # Discriminator tag identifying this message type on the wire.
    type: Literal["resolve_symbol_response"] = "resolve_symbol_response"
    # Matches the request_id of the originating ResolveSymbolRequest.
    request_id: str
    symbol_info: Dict[str, Any] = Field(description="SymbolInfo object")
|
||||
|
||||
|
||||
class GetBarsResponse(BaseModel):
    """Response with historical bars."""

    # Discriminator tag identifying this message type on the wire.
    type: Literal["get_bars_response"] = "get_bars_response"
    # Matches the request_id of the originating GetBarsRequest.
    request_id: str
    history: Dict[str, Any] = Field(description="HistoryResult object with bars and metadata")
|
||||
|
||||
|
||||
class SubscribeBarsResponse(BaseModel):
    """Acknowledgment of a subscription request."""

    # Discriminator tag identifying this message type on the wire.
    type: Literal["subscribe_bars_response"] = "subscribe_bars_response"
    # Matches the request_id of the originating SubscribeBarsRequest.
    request_id: str
    # Echo of the client-generated subscription ID being acknowledged.
    subscription_id: str
    # True when the subscription was established.
    success: bool
    # Optional human-readable detail (e.g. an error reason on failure).
    message: Optional[str] = None
|
||||
|
||||
|
||||
class UnsubscribeBarsResponse(BaseModel):
    """Acknowledgment of an unsubscribe request."""

    # Discriminator tag identifying this message type on the wire.
    type: Literal["unsubscribe_bars_response"] = "unsubscribe_bars_response"
    # Matches the request_id of the originating UnsubscribeBarsRequest.
    request_id: str
    # Echo of the subscription ID that was torn down.
    subscription_id: str
    # True when the unsubscribe succeeded.
    success: bool
|
||||
|
||||
|
||||
class GetConfigResponse(BaseModel):
    """Response with datafeed configuration."""

    # Discriminator tag identifying this message type on the wire.
    type: Literal["get_config_response"] = "get_config_response"
    # Matches the request_id of the originating GetConfigRequest.
    request_id: str
    config: Dict[str, Any] = Field(description="DatafeedConfig object")
|
||||
|
||||
|
||||
class BarUpdateMessage(BaseModel):
    """Real-time bar update (server-initiated, no request_id)."""

    # Discriminator tag identifying this message type on the wire.
    type: Literal["bar_update"] = "bar_update"
    # Matches the subscription_id from the originating SubscribeBarsRequest.
    subscription_id: str
    # Symbol this update belongs to.
    symbol: str
    # Time resolution string of the subscribed series.
    resolution: str
    bar: Dict[str, Any] = Field(description="Bar data including time and all columns")
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
    """Error response for any failed request."""

    # Discriminator tag identifying this message type on the wire.
    type: Literal["error"] = "error"
    # Matches the request_id of the failed request.
    request_id: str
    error_code: str = Field(description="Machine-readable error code")
    error_message: str = Field(description="Human-readable error description")
|
||||
|
||||
|
||||
# Union of all server response types.
# Each member carries a distinct Literal ``type`` tag, so a client can
# dispatch an incoming message on its ``type`` field.  BarUpdateMessage is
# server-initiated and carries no request_id.
ServerResponse = Union[
    SearchSymbolsResponse,
    ResolveSymbolResponse,
    GetBarsResponse,
    SubscribeBarsResponse,
    UnsubscribeBarsResponse,
    GetConfigResponse,
    BarUpdateMessage,
    ErrorResponse,
]
|
||||
Reference in New Issue
Block a user