initial commit with charts and assistant chat

This commit is contained in:
2026-03-02 00:08:19 -04:00
commit d907c5765e
1828 changed files with 50054 additions and 0 deletions

View File

@@ -0,0 +1,23 @@
from .base import DataSource
from .schema import (
ColumnInfo,
SymbolInfo,
Bar,
HistoryResult,
DatafeedConfig,
Resolution,
SearchResult,
)
from .registry import DataSourceRegistry
__all__ = [
"DataSource",
"ColumnInfo",
"SymbolInfo",
"Bar",
"HistoryResult",
"DatafeedConfig",
"Resolution",
"SearchResult",
"DataSourceRegistry",
]

View File

@@ -0,0 +1,3 @@
from .ccxt_adapter import CCXTDataSource
__all__ = ["CCXTDataSource"]

View File

@@ -0,0 +1,526 @@
"""
CCXT DataSource adapter for accessing cryptocurrency exchange data.
This adapter provides access to hundreds of cryptocurrency exchanges through
the free CCXT library (not ccxt.pro), supporting both historical data and
polling-based subscriptions.
Numerical Precision:
- Uses Decimal for all monetary values (prices, volumes) to avoid floating-point errors
- CCXT returns numeric values as strings or floats depending on configuration
- All financial values are converted to Decimal to maintain precision
Real-time Updates:
- Uses polling instead of WebSocket (free CCXT doesn't have WebSocket support)
- Default polling interval: 60 seconds (configurable)
- Simulates real-time subscriptions by periodically fetching latest bars
"""
import asyncio
import logging
import time
from datetime import datetime, timezone
from decimal import Decimal
from typing import Callable, Dict, List, Optional, Set, Union
import ccxt.async_support as ccxt
from ..base import DataSource
logger = logging.getLogger(__name__)
from ..schema import (
Bar,
ColumnInfo,
DatafeedConfig,
HistoryResult,
Resolution,
SearchResult,
SymbolInfo,
)
class CCXTDataSource(DataSource):
"""
DataSource adapter for CCXT cryptocurrency exchanges (free version).
Provides access to:
- Multiple cryptocurrency exchanges (Binance, Coinbase, Kraken, etc.)
- Historical OHLCV data via REST API
- Polling-based real-time updates (configurable interval)
- Symbol search and metadata
Args:
exchange_id: CCXT exchange identifier (e.g., 'binance', 'coinbase', 'kraken')
config: Optional exchange-specific configuration (API keys, options)
sandbox: Whether to use sandbox/testnet mode (default: False)
poll_interval: Interval in seconds for polling updates (default: 60)
"""
def __init__(
self,
exchange_id: str = "binance",
config: Optional[Dict] = None,
sandbox: bool = False,
poll_interval: int = 60,
):
self.exchange_id = exchange_id
self._config = config or {}
self._sandbox = sandbox
self._poll_interval = poll_interval
# Initialize exchange (using free async_support, not pro)
exchange_class = getattr(ccxt, exchange_id)
self.exchange = exchange_class(self._config)
if sandbox and hasattr(self.exchange, 'set_sandbox_mode'):
self.exchange.set_sandbox_mode(True)
# Cache for markets
self._markets: Optional[Dict] = None
self._markets_loaded = False
# Active subscriptions (polling-based)
self._subscriptions: Dict[str, asyncio.Task] = {}
self._subscription_callbacks: Dict[str, Callable] = {}
self._last_bars: Dict[str, int] = {} # Track last bar timestamp per subscription
@staticmethod
def _to_decimal(value: Union[str, int, float, Decimal, None]) -> Optional[Decimal]:
"""
Convert a value to Decimal for numerical precision.
Handles CCXT's mixed output (strings, floats, ints, None).
Converts floats by converting to string first to avoid precision loss.
"""
if value is None:
return None
if isinstance(value, Decimal):
return value
if isinstance(value, str):
return Decimal(value)
if isinstance(value, (int, float)):
# Convert to string first to avoid float precision issues
return Decimal(str(value))
return None
async def _ensure_markets_loaded(self):
"""Ensure markets are loaded from exchange"""
if not self._markets_loaded:
self._markets = await self.exchange.load_markets()
self._markets_loaded = True
async def get_config(self) -> DatafeedConfig:
"""Get datafeed configuration"""
await self._ensure_markets_loaded()
# Determine supported resolutions based on exchange capabilities
supported_resolutions = [
Resolution.M1,
Resolution.M5,
Resolution.M15,
Resolution.M30,
Resolution.H1,
Resolution.H4,
Resolution.D1,
]
# Get unique exchange names (most CCXT exchanges are just one)
exchanges = [self.exchange_id.upper()]
return DatafeedConfig(
name=f"CCXT {self.exchange_id.title()}",
description=f"Live and historical cryptocurrency data from {self.exchange_id} via CCXT library. "
f"Supports OHLCV data for {len(self._markets) if self._markets else 'many'} trading pairs.",
supported_resolutions=supported_resolutions,
supports_search=True,
supports_time=True,
exchanges=exchanges,
symbols_types=["crypto", "spot", "futures", "swap"],
)
async def search_symbols(
self,
query: str,
type: Optional[str] = None,
exchange: Optional[str] = None,
limit: int = 30,
) -> List[SearchResult]:
"""Search for symbols on the exchange"""
await self._ensure_markets_loaded()
query_upper = query.upper()
results = []
for symbol, market in self._markets.items():
# Match query against symbol or base/quote currencies
if (query_upper in symbol or
query_upper in market.get('base', '') or
query_upper in market.get('quote', '')):
# Filter by type if specified
market_type = market.get('type', 'spot')
if type and market_type != type:
continue
# Create search result
base = market.get('base', '')
quote = market.get('quote', '')
results.append(
SearchResult(
symbol=f"{base}/{quote}", # Clean user-facing format
ticker=f"{self.exchange_id.upper()}:{symbol}", # Ticker with exchange prefix for routing
full_name=f"{base}/{quote} ({self.exchange_id.upper()})",
description=f"{base}/{quote} {market_type} trading pair on {self.exchange_id}",
exchange=self.exchange_id.upper(),
type=market_type,
)
)
if len(results) >= limit:
break
return results
async def resolve_symbol(self, symbol: str) -> SymbolInfo:
"""Get complete metadata for a symbol"""
await self._ensure_markets_loaded()
if symbol not in self._markets:
raise ValueError(f"Symbol '{symbol}' not found on {self.exchange_id}")
market = self._markets[symbol]
base = market.get('base', '')
quote = market.get('quote', '')
market_type = market.get('type', 'spot')
# Determine price scale from market precision
# CCXT precision can be in different modes:
# - DECIMAL_PLACES (int): number of decimal places (e.g., 2 = 0.01)
# - TICK_SIZE (float): actual tick size (e.g., 0.01, 0.00001)
# We need to convert to pricescale (10^n where n is decimal places)
price_precision = market.get('precision', {}).get('price', 2)
if isinstance(price_precision, float):
# TICK_SIZE mode: precision is the actual tick size (e.g., 0.01, 0.00001)
# Convert tick size to decimal places
# For 0.01 -> 2 decimal places, 0.00001 -> 5 decimal places
tick_str = str(Decimal(str(price_precision)))
if '.' in tick_str:
decimal_places = len(tick_str.split('.')[1].rstrip('0'))
else:
decimal_places = 0
pricescale = 10 ** decimal_places
else:
# DECIMAL_PLACES or SIGNIFICANT_DIGITS mode: precision is an integer
# Assume DECIMAL_PLACES mode (most common for price)
pricescale = 10 ** int(price_precision)
return SymbolInfo(
symbol=f"{base}/{quote}", # Clean user-facing format
ticker=f"{self.exchange_id.upper()}:{symbol}", # Ticker with exchange prefix for routing
name=f"{base}/{quote}",
description=f"{base}/{quote} {market_type} pair on {self.exchange_id}. "
f"Minimum order: {market.get('limits', {}).get('amount', {}).get('min', 'N/A')} {base}",
type=market_type,
exchange=self.exchange_id.upper(),
timezone="Etc/UTC",
session="24x7",
supported_resolutions=[
Resolution.M1,
Resolution.M5,
Resolution.M15,
Resolution.M30,
Resolution.H1,
Resolution.H4,
Resolution.D1,
],
has_intraday=True,
has_daily=True,
has_weekly_and_monthly=False,
columns=[
ColumnInfo(
name="open",
type="decimal",
description=f"Opening price in {quote}",
unit=quote,
),
ColumnInfo(
name="high",
type="decimal",
description=f"Highest price in {quote}",
unit=quote,
),
ColumnInfo(
name="low",
type="decimal",
description=f"Lowest price in {quote}",
unit=quote,
),
ColumnInfo(
name="close",
type="decimal",
description=f"Closing price in {quote}",
unit=quote,
),
ColumnInfo(
name="volume",
type="decimal",
description=f"Trading volume in {base}",
unit=base,
),
],
time_column="time",
has_ohlcv=True,
pricescale=pricescale,
minmov=1,
base_currency=base,
quote_currency=quote,
)
def _resolution_to_timeframe(self, resolution: str) -> str:
"""Convert our resolution format to CCXT timeframe format"""
# Map our resolutions to CCXT timeframes
mapping = {
"1": "1m",
"5": "5m",
"15": "15m",
"30": "30m",
"60": "1h",
"120": "2h",
"240": "4h",
"360": "6h",
"720": "12h",
"1D": "1d",
"1W": "1w",
"1M": "1M",
}
return mapping.get(resolution, "1m")
def _timeframe_to_milliseconds(self, timeframe: str) -> int:
"""Convert CCXT timeframe to milliseconds"""
unit = timeframe[-1]
amount = int(timeframe[:-1]) if len(timeframe) > 1 else 1
units = {
's': 1000,
'm': 60 * 1000,
'h': 60 * 60 * 1000,
'd': 24 * 60 * 60 * 1000,
'w': 7 * 24 * 60 * 60 * 1000,
'M': 30 * 24 * 60 * 60 * 1000, # Approximate
}
return amount * units.get(unit, 60000)
async def get_bars(
self,
symbol: str,
resolution: str,
from_time: int,
to_time: int,
countback: Optional[int] = None,
) -> HistoryResult:
"""Get historical bars from the exchange"""
logger.info(
f"CCXTDataSource({self.exchange_id}).get_bars: symbol={symbol}, resolution={resolution}, "
f"from_time={from_time}, to_time={to_time}, countback={countback}"
)
await self._ensure_markets_loaded()
if symbol not in self._markets:
raise ValueError(f"Symbol '{symbol}' not found on {self.exchange_id}")
timeframe = self._resolution_to_timeframe(resolution)
# CCXT uses milliseconds for timestamps
since = from_time * 1000
until = to_time * 1000
# Fetch OHLCV data
limit = countback if countback else 1000
try:
# Fetch in batches if needed
all_ohlcv = []
current_since = since
while current_since < until:
ohlcv = await self.exchange.fetch_ohlcv(
symbol,
timeframe=timeframe,
since=current_since,
limit=limit,
)
if not ohlcv:
break
all_ohlcv.extend(ohlcv)
# Update since for next batch
last_timestamp = ohlcv[-1][0]
if last_timestamp <= current_since:
break # No progress, avoid infinite loop
current_since = last_timestamp + 1
# Stop if we have enough bars
if countback and len(all_ohlcv) >= countback:
all_ohlcv = all_ohlcv[:countback]
break
# Convert to our Bar format with Decimal precision
bars = []
for candle in all_ohlcv:
timestamp_ms, open_price, high, low, close, volume = candle
timestamp = timestamp_ms // 1000 # Convert to seconds
# Only include bars within requested range
if timestamp < from_time or timestamp >= to_time:
continue
bars.append(
Bar(
time=timestamp,
data={
"open": self._to_decimal(open_price),
"high": self._to_decimal(high),
"low": self._to_decimal(low),
"close": self._to_decimal(close),
"volume": self._to_decimal(volume),
},
)
)
# Get symbol info for column metadata
symbol_info = await self.resolve_symbol(symbol)
logger.info(
f"CCXTDataSource({self.exchange_id}).get_bars: Returning {len(bars)} bars. "
f"First: {bars[0].time if bars else 'N/A'}, Last: {bars[-1].time if bars else 'N/A'}"
)
# Determine if more data is available
next_time = None
if bars and countback and len(bars) >= countback:
next_time = bars[-1].time + (bars[-1].time - bars[-2].time if len(bars) > 1 else 60)
return HistoryResult(
symbol=symbol,
resolution=resolution,
bars=bars,
columns=symbol_info.columns,
nextTime=next_time,
)
except Exception as e:
raise ValueError(f"Failed to fetch bars for {symbol}: {str(e)}")
async def subscribe_bars(
self,
symbol: str,
resolution: str,
on_tick: Callable[[dict], None],
) -> str:
"""
Subscribe to bar updates via polling.
Note: Uses polling instead of WebSocket since we're using free CCXT.
Polls at the configured interval (default: 60 seconds).
"""
await self._ensure_markets_loaded()
if symbol not in self._markets:
raise ValueError(f"Symbol '{symbol}' not found on {self.exchange_id}")
subscription_id = f"{symbol}:{resolution}:{time.time()}"
# Store callback
self._subscription_callbacks[subscription_id] = on_tick
# Start polling task
timeframe = self._resolution_to_timeframe(resolution)
task = asyncio.create_task(
self._poll_ohlcv(symbol, timeframe, subscription_id)
)
self._subscriptions[subscription_id] = task
return subscription_id
async def _poll_ohlcv(self, symbol: str, timeframe: str, subscription_id: str):
"""
Poll for OHLCV updates at regular intervals.
This simulates real-time updates by fetching the latest bars periodically.
Only sends updates when new bars are detected.
"""
try:
while subscription_id in self._subscription_callbacks:
try:
# Fetch latest bars
ohlcv = await self.exchange.fetch_ohlcv(
symbol,
timeframe=timeframe,
limit=2, # Get last 2 bars to detect new ones
)
if ohlcv and len(ohlcv) > 0:
# Get the latest candle
latest = ohlcv[-1]
timestamp_ms, open_price, high, low, close, volume = latest
timestamp = timestamp_ms // 1000
# Only send update if this is a new bar
last_timestamp = self._last_bars.get(subscription_id, 0)
if timestamp > last_timestamp:
self._last_bars[subscription_id] = timestamp
# Convert to our format with Decimal precision
tick_data = {
"time": timestamp,
"open": self._to_decimal(open_price),
"high": self._to_decimal(high),
"low": self._to_decimal(low),
"close": self._to_decimal(close),
"volume": self._to_decimal(volume),
}
# Call the callback
callback = self._subscription_callbacks.get(subscription_id)
if callback:
callback(tick_data)
except Exception as e:
print(f"Error polling OHLCV for {symbol}: {e}")
# Wait for next poll interval
await asyncio.sleep(self._poll_interval)
except asyncio.CancelledError:
pass
async def unsubscribe_bars(self, subscription_id: str) -> None:
"""Unsubscribe from polling updates"""
# Remove callback and tracking
self._subscription_callbacks.pop(subscription_id, None)
self._last_bars.pop(subscription_id, None)
# Cancel polling task
task = self._subscriptions.pop(subscription_id, None)
if task:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
async def close(self):
"""Close exchange connection and cleanup"""
# Cancel all subscriptions
for subscription_id in list(self._subscriptions.keys()):
await self.unsubscribe_bars(subscription_id)
# Close exchange
if hasattr(self.exchange, 'close'):
await self.exchange.close()

View File

@@ -0,0 +1,353 @@
"""
Demo data source with synthetic data.
Generates realistic-looking OHLCV data plus additional columns for testing.
"""
import asyncio
import logging
import math
import random
import time
from typing import Callable, Dict, List, Optional
from ..base import DataSource
logger = logging.getLogger(__name__)
from ..schema import (
Bar,
ColumnInfo,
DatafeedConfig,
HistoryResult,
Resolution,
SearchResult,
SymbolInfo,
)
class DemoDataSource(DataSource):
"""
Demo data source that generates synthetic time-series data.
Provides:
- Standard OHLCV columns
- Additional demo columns (RSI, sentiment, volume_profile)
- Real-time updates via polling simulation
"""
def __init__(self):
self._subscriptions: Dict[str, asyncio.Task] = {}
self._symbols = {
"DEMO:BTC/USD": {
"name": "Bitcoin",
"type": "crypto",
"base_price": 50000.0,
"volatility": 0.02,
},
"DEMO:ETH/USD": {
"name": "Ethereum",
"type": "crypto",
"base_price": 3000.0,
"volatility": 0.03,
},
"DEMO:SOL/USD": {
"name": "Solana",
"type": "crypto",
"base_price": 100.0,
"volatility": 0.04,
},
}
async def get_config(self) -> DatafeedConfig:
return DatafeedConfig(
name="Demo DataSource",
description="Synthetic data generator for testing. Provides OHLCV plus additional indicator columns.",
supported_resolutions=[
Resolution.M1,
Resolution.M5,
Resolution.M15,
Resolution.H1,
Resolution.D1,
],
supports_search=True,
supports_time=True,
exchanges=["DEMO"],
symbols_types=["crypto"],
)
async def search_symbols(
self,
query: str,
type: Optional[str] = None,
exchange: Optional[str] = None,
limit: int = 30,
) -> List[SearchResult]:
query_lower = query.lower()
results = []
for symbol, info in self._symbols.items():
if query_lower in symbol.lower() or query_lower in info["name"].lower():
if type and info["type"] != type:
continue
results.append(
SearchResult(
symbol=info['name'], # Clean user-facing format (e.g., "Bitcoin")
ticker=symbol, # Keep DEMO:BTC/USD format for routing
full_name=f"{info['name']} (DEMO)",
description=f"Demo {info['name']} pair",
exchange="DEMO",
type=info["type"],
)
)
return results[:limit]
async def resolve_symbol(self, symbol: str) -> SymbolInfo:
if symbol not in self._symbols:
raise ValueError(f"Symbol '{symbol}' not found")
info = self._symbols[symbol]
base, quote = symbol.split(":")[1].split("/")
return SymbolInfo(
symbol=info["name"], # Clean user-facing format (e.g., "Bitcoin")
ticker=symbol, # Keep DEMO:BTC/USD format for routing
name=info["name"],
description=f"Demo {info['name']} spot price with synthetic indicators",
type=info["type"],
exchange="DEMO",
timezone="Etc/UTC",
session="24x7",
supported_resolutions=[Resolution.M1, Resolution.M5, Resolution.M15, Resolution.H1, Resolution.D1],
has_intraday=True,
has_daily=True,
has_weekly_and_monthly=False,
columns=[
ColumnInfo(
name="open",
type="float",
description=f"Opening price in {quote}",
unit=quote,
),
ColumnInfo(
name="high",
type="float",
description=f"Highest price in {quote}",
unit=quote,
),
ColumnInfo(
name="low",
type="float",
description=f"Lowest price in {quote}",
unit=quote,
),
ColumnInfo(
name="close",
type="float",
description=f"Closing price in {quote}",
unit=quote,
),
ColumnInfo(
name="volume",
type="float",
description=f"Trading volume in {base}",
unit=base,
),
ColumnInfo(
name="rsi",
type="float",
description="Relative Strength Index (14-period), range 0-100",
unit=None,
),
ColumnInfo(
name="sentiment",
type="float",
description="Synthetic social sentiment score, range -1.0 to 1.0",
unit=None,
),
ColumnInfo(
name="volume_profile",
type="float",
description="Volume as percentage of 24h average",
unit="%",
),
],
time_column="time",
has_ohlcv=True,
pricescale=100,
minmov=1,
base_currency=base,
quote_currency=quote,
)
async def get_bars(
self,
symbol: str,
resolution: str,
from_time: int,
to_time: int,
countback: Optional[int] = None,
) -> HistoryResult:
if symbol not in self._symbols:
raise ValueError(f"Symbol '{symbol}' not found")
logger.info(
f"DemoDataSource.get_bars: symbol={symbol}, resolution={resolution}, "
f"from_time={from_time}, to_time={to_time}, countback={countback}"
)
info = self._symbols[symbol]
symbol_meta = await self.resolve_symbol(symbol)
# Convert resolution to seconds
resolution_seconds = self._resolution_to_seconds(resolution)
# Generate bars
bars = []
# Align current_time to resolution, but ensure it's >= from_time
current_time = from_time - (from_time % resolution_seconds)
if current_time < from_time:
current_time += resolution_seconds
price = info["base_price"]
bar_count = 0
max_bars = countback if countback else 5000
while current_time <= to_time and bar_count < max_bars:
bar_data = self._generate_bar(current_time, price, info["volatility"], resolution_seconds)
# Only include bars within the requested range
if from_time <= current_time <= to_time:
bars.append(Bar(time=current_time * 1000, data=bar_data)) # Convert to milliseconds
bar_count += 1
price = bar_data["close"] # Next bar starts from previous close
current_time += resolution_seconds
logger.info(
f"DemoDataSource.get_bars: Generated {len(bars)} bars. "
f"First: {bars[0].time if bars else 'N/A'}, Last: {bars[-1].time if bars else 'N/A'}"
)
# Determine if there's more data (for pagination)
next_time = current_time if current_time <= to_time else None
return HistoryResult(
symbol=symbol,
resolution=resolution,
bars=bars,
columns=symbol_meta.columns,
nextTime=next_time,
)
async def subscribe_bars(
self,
symbol: str,
resolution: str,
on_tick: Callable[[dict], None],
) -> str:
if symbol not in self._symbols:
raise ValueError(f"Symbol '{symbol}' not found")
subscription_id = f"{symbol}:{resolution}:{time.time()}"
# Start background task to simulate real-time updates
task = asyncio.create_task(
self._tick_generator(symbol, resolution, on_tick, subscription_id)
)
self._subscriptions[subscription_id] = task
return subscription_id
async def unsubscribe_bars(self, subscription_id: str) -> None:
task = self._subscriptions.pop(subscription_id, None)
if task:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
def _resolution_to_seconds(self, resolution: str) -> int:
"""Convert resolution string to seconds"""
if resolution.endswith("D"):
return int(resolution[:-1]) * 86400
elif resolution.endswith("W"):
return int(resolution[:-1]) * 604800
elif resolution.endswith("M"):
return int(resolution[:-1]) * 2592000 # Approximate month
else:
# Assume minutes
return int(resolution) * 60
def _generate_bar(self, timestamp: int, base_price: float, volatility: float, period_seconds: int) -> dict:
"""Generate a single synthetic OHLCV bar"""
# Random walk for the period
open_price = base_price
# Generate intra-period price movement
num_ticks = max(10, period_seconds // 60) # More ticks for longer periods
prices = [open_price]
for _ in range(num_ticks):
change = random.gauss(0, volatility / math.sqrt(num_ticks))
prices.append(prices[-1] * (1 + change))
close_price = prices[-1]
high_price = max(prices)
low_price = min(prices)
# Volume with some randomness
base_volume = 1000000 * (period_seconds / 60) # Scale with period
volume = base_volume * random.uniform(0.5, 2.0)
# Additional synthetic indicators
rsi = 30 + random.random() * 40 # Biased toward middle range
sentiment = math.sin(timestamp / 3600) * 0.5 + random.gauss(0, 0.2) # Hourly cycle + noise
sentiment = max(-1.0, min(1.0, sentiment))
volume_profile = 100 * random.uniform(0.5, 1.5)
return {
"open": round(open_price, 2),
"high": round(high_price, 2),
"low": round(low_price, 2),
"close": round(close_price, 2),
"volume": round(volume, 2),
"rsi": round(rsi, 2),
"sentiment": round(sentiment, 3),
"volume_profile": round(volume_profile, 2),
}
async def _tick_generator(
self,
symbol: str,
resolution: str,
on_tick: Callable[[dict], None],
subscription_id: str,
):
"""Background task that generates periodic ticks"""
info = self._symbols[symbol]
resolution_seconds = self._resolution_to_seconds(resolution)
# Start from current aligned time
current_time = int(time.time())
current_time = current_time - (current_time % resolution_seconds)
price = info["base_price"]
try:
while True:
# Wait until next bar
await asyncio.sleep(resolution_seconds)
current_time += resolution_seconds
bar_data = self._generate_bar(current_time, price, info["volatility"], resolution_seconds)
price = bar_data["close"]
# Call the tick handler
tick_data = {"time": current_time, **bar_data}
on_tick(tick_data)
except asyncio.CancelledError:
# Subscription cancelled
pass

View File

@@ -0,0 +1,146 @@
"""
Abstract DataSource interface.
Inspired by TradingView's Datafeed API with extensions for flexible column schemas
and AI-native metadata.
"""
from abc import ABC, abstractmethod
from typing import Callable, List, Optional
from .schema import DatafeedConfig, HistoryResult, SearchResult, SymbolInfo
class DataSource(ABC):
"""
Abstract base class for time-series data sources.
Provides a standardized interface for:
- Symbol search and metadata retrieval
- Historical data queries (time-based, paginated)
- Real-time data subscriptions
All data rows must have a timestamp. Additional columns are flexible
and described via ColumnInfo metadata.
"""
@abstractmethod
async def get_config(self) -> DatafeedConfig:
"""
Get datafeed configuration and capabilities.
Called once during initialization to understand what this data source
supports (resolutions, exchanges, search, etc.).
Returns:
DatafeedConfig describing this datafeed's capabilities
"""
pass
@abstractmethod
async def search_symbols(
self,
query: str,
type: Optional[str] = None,
exchange: Optional[str] = None,
limit: int = 30,
) -> List[SearchResult]:
"""
Search for symbols matching a text query.
Args:
query: Free-text search string
type: Optional filter by instrument type
exchange: Optional filter by exchange
limit: Maximum number of results
Returns:
List of matching symbols with basic metadata
"""
pass
@abstractmethod
async def resolve_symbol(self, symbol: str) -> SymbolInfo:
"""
Get complete metadata for a symbol.
Called after a symbol is selected to retrieve full information including
supported resolutions, column schema, trading session, etc.
Args:
symbol: Symbol identifier
Returns:
Complete SymbolInfo including column definitions
Raises:
ValueError: If symbol is not found
"""
pass
@abstractmethod
async def get_bars(
self,
symbol: str,
resolution: str,
from_time: int,
to_time: int,
countback: Optional[int] = None,
) -> HistoryResult:
"""
Get historical bars for a symbol and resolution.
Time range is specified in Unix timestamps (seconds). If more data is
available beyond the requested range, the result should include a
nextTime value for pagination.
Args:
symbol: Symbol identifier
resolution: Time resolution (e.g., "1", "5", "60", "1D")
from_time: Start time (Unix timestamp in seconds)
to_time: End time (Unix timestamp in seconds)
countback: Optional limit on number of bars to return
Returns:
HistoryResult with bars, column schema, and pagination info
Raises:
ValueError: If symbol or resolution is not supported
"""
pass
@abstractmethod
async def subscribe_bars(
self,
symbol: str,
resolution: str,
on_tick: Callable[[dict], None],
) -> str:
"""
Subscribe to real-time bar updates.
The callback will be invoked with new bar data as it becomes available.
The data dict will match the column schema from resolve_symbol().
Args:
symbol: Symbol identifier
resolution: Time resolution
on_tick: Callback function receiving bar data dict
Returns:
Subscription ID for later unsubscribe
Raises:
ValueError: If symbol or resolution is not supported
"""
pass
@abstractmethod
async def unsubscribe_bars(self, subscription_id: str) -> None:
"""
Unsubscribe from real-time updates.
Args:
subscription_id: ID returned from subscribe_bars()
"""
pass

View File

@@ -0,0 +1,109 @@
"""
DataSource registry for managing multiple data sources.
"""
from typing import Dict, List, Optional
from .base import DataSource
from .schema import SearchResult, SymbolInfo
class DataSourceRegistry:
"""
Central registry for managing multiple DataSource instances.
Allows routing symbol queries to the appropriate data source and
aggregating search results across multiple sources.
"""
def __init__(self):
self._sources: Dict[str, DataSource] = {}
def register(self, name: str, source: DataSource) -> None:
"""
Register a data source.
Args:
name: Unique name for this data source
source: DataSource implementation
"""
self._sources[name] = source
def unregister(self, name: str) -> None:
"""
Unregister a data source.
Args:
name: Name of the data source to remove
"""
self._sources.pop(name, None)
def get(self, name: str) -> Optional[DataSource]:
"""
Get a registered data source by name.
Args:
name: Data source name
Returns:
DataSource instance or None if not found
"""
return self._sources.get(name)
def list_sources(self) -> List[str]:
"""
Get names of all registered data sources.
Returns:
List of data source names
"""
return list(self._sources.keys())
async def search_all(
self,
query: str,
type: Optional[str] = None,
exchange: Optional[str] = None,
limit: int = 30,
) -> Dict[str, List[SearchResult]]:
"""
Search across all registered data sources.
Args:
query: Search query
type: Optional instrument type filter
exchange: Optional exchange filter
limit: Maximum results per source
Returns:
Dict mapping source name to search results
"""
results = {}
for name, source in self._sources.items():
try:
source_results = await source.search_symbols(query, type, exchange, limit)
if source_results:
results[name] = source_results
except Exception:
# Silently skip sources that error during search
continue
return results
async def resolve_symbol(self, source_name: str, symbol: str) -> SymbolInfo:
"""
Resolve a symbol from a specific data source.
Args:
source_name: Name of the data source
symbol: Symbol identifier
Returns:
SymbolInfo from the specified source
Raises:
ValueError: If source not found or symbol not found
"""
source = self.get(source_name)
if not source:
raise ValueError(f"Data source '{source_name}' not found")
return await source.resolve_symbol(symbol)

View File

@@ -0,0 +1,194 @@
"""
Data models for the DataSource interface.
Inspired by TradingView's Datafeed API but with flexible column schemas
for AI-native trading platform needs.
"""
from enum import StrEnum
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
class Resolution(StrEnum):
"""Standard time resolutions for bar data"""
# Seconds
S1 = "1S"
S5 = "5S"
S15 = "15S"
S30 = "30S"
# Minutes
M1 = "1"
M5 = "5"
M15 = "15"
M30 = "30"
# Hours
H1 = "60"
H2 = "120"
H4 = "240"
H6 = "360"
H12 = "720"
# Days
D1 = "1D"
# Weeks
W1 = "1W"
# Months
MO1 = "1M"
class ColumnInfo(BaseModel):
"""
Metadata for a single data column.
Provides rich, LLM-readable descriptions so AI agents can understand
and reason about available data fields.
"""
model_config = {"extra": "forbid"}
name: str = Field(description="Column name (e.g., 'close', 'volume', 'funding_rate')")
type: Literal["float", "int", "bool", "string", "decimal"] = Field(description="Data type")
description: str = Field(description="Human and LLM-readable description of what this column represents")
unit: Optional[str] = Field(default=None, description="Unit of measurement (e.g., 'USD', 'BTC', '%', 'contracts')")
nullable: bool = Field(default=False, description="Whether this column can contain null values")
class SymbolInfo(BaseModel):
"""
Complete metadata for a tradeable symbol.
Includes both TradingView-compatible fields and flexible schema definition
for arbitrary data columns.
"""
model_config = {"extra": "forbid"}
# Core identification
symbol: str = Field(description="Unique symbol identifier (primary key for data fetching)")
ticker: Optional[str] = Field(default=None, description="TradingView ticker (if different from symbol)")
name: str = Field(description="Display name")
description: str = Field(description="LLM-readable description of the instrument")
type: str = Field(description="Instrument type: 'crypto', 'stock', 'forex', 'futures', 'derived', etc.")
exchange: str = Field(description="Exchange or data source identifier")
# Trading session info
timezone: str = Field(default="Etc/UTC", description="IANA timezone identifier")
session: str = Field(default="24x7", description="Trading session spec (e.g., '0930-1600' or '24x7')")
# Resolution support
supported_resolutions: List[str] = Field(description="List of supported time resolutions")
has_intraday: bool = Field(default=True, description="Whether intraday resolutions are supported")
has_daily: bool = Field(default=True, description="Whether daily resolution is supported")
has_weekly_and_monthly: bool = Field(default=False, description="Whether weekly/monthly resolutions are supported")
# Flexible schema definition
columns: List[ColumnInfo] = Field(description="Available data columns for this symbol")
time_column: str = Field(default="time", description="Name of the timestamp column")
# Convenience flags
has_ohlcv: bool = Field(default=False, description="Whether standard OHLCV columns are present")
# Price display (for OHLCV data)
pricescale: int = Field(default=100, description="Price scale factor (e.g., 100 for 2 decimals)")
minmov: int = Field(default=1, description="Minimum price movement in pricescale units")
# Additional metadata
base_currency: Optional[str] = Field(default=None, description="Base currency (for crypto/forex)")
quote_currency: Optional[str] = Field(default=None, description="Quote currency (for crypto/forex)")
class Bar(BaseModel):
"""
A single bar/row of time-series data with flexible columns.
All bars must have a timestamp. Additional columns are stored in the
data dict and described by the associated ColumnInfo metadata.
"""
model_config = {"extra": "forbid"}
time: int = Field(description="Unix timestamp in seconds")
data: Dict[str, Any] = Field(description="Column name -> value mapping")
# Convenience accessors for common OHLCV columns
@property
def open(self) -> Optional[float]:
return self.data.get("open")
@property
def high(self) -> Optional[float]:
return self.data.get("high")
@property
def low(self) -> Optional[float]:
return self.data.get("low")
@property
def close(self) -> Optional[float]:
return self.data.get("close")
@property
def volume(self) -> Optional[float]:
return self.data.get("volume")
class HistoryResult(BaseModel):
"""
Result from a historical data query.
Includes the bars, schema information, and pagination metadata.
"""
model_config = {"extra": "forbid"}
symbol: str = Field(description="Symbol identifier")
resolution: str = Field(description="Time resolution of the bars")
bars: List[Bar] = Field(description="The actual data bars")
columns: List[ColumnInfo] = Field(description="Schema describing the bar data columns")
nextTime: Optional[int] = Field(default=None, description="Unix timestamp for pagination (if more data available)")
class SearchResult(BaseModel):
"""
A single result from symbol search.
"""
model_config = {"extra": "forbid"}
symbol: str = Field(description="Display symbol (e.g., 'BINANCE:ETH/BTC')")
ticker: Optional[str] = Field(default=None, description="Backend ticker for data fetching (e.g., 'ETH/BTC')")
full_name: str = Field(description="Full display name including exchange")
description: str = Field(description="Human-readable description")
exchange: str = Field(description="Exchange identifier")
type: str = Field(description="Instrument type")
class DatafeedConfig(BaseModel):
"""
Configuration and capabilities of a DataSource.
Similar to TradingView's onReady configuration object.
"""
model_config = {"extra": "forbid"}
# Supported features
supported_resolutions: List[str] = Field(description="All resolutions this datafeed supports")
supports_search: bool = Field(default=True, description="Whether symbol search is available")
supports_time: bool = Field(default=True, description="Whether time-based queries are supported")
supports_marks: bool = Field(default=False, description="Whether marks/events are supported")
# Data characteristics
exchanges: List[str] = Field(default_factory=list, description="Available exchanges")
symbols_types: List[str] = Field(default_factory=list, description="Available instrument types")
# Metadata
name: str = Field(description="Datafeed name")
description: str = Field(description="LLM-readable description of this data source")

View File

@@ -0,0 +1,235 @@
"""
Subscription manager for real-time data feeds.
Manages subscriptions across multiple data sources and routes updates
to WebSocket clients.
"""
import asyncio
import logging
from typing import Callable, Dict, Optional, Set
from .base import DataSource
logger = logging.getLogger(__name__)
class Subscription:
"""Represents a single client subscription"""
def __init__(
self,
subscription_id: str,
client_id: str,
source_name: str,
symbol: str,
resolution: str,
callback: Callable[[dict], None],
):
self.subscription_id = subscription_id
self.client_id = client_id
self.source_name = source_name
self.symbol = symbol
self.resolution = resolution
self.callback = callback
self.source_subscription_id: Optional[str] = None
class SubscriptionManager:
"""
Manages real-time data subscriptions across multiple data sources.
Handles:
- Subscription lifecycle (subscribe/unsubscribe)
- Routing updates from data sources to clients
- Multiplexing (multiple clients can subscribe to same symbol/resolution)
"""
def __init__(self):
# Map subscription_id -> Subscription
self._subscriptions: Dict[str, Subscription] = {}
# Map (source_name, symbol, resolution) -> Set[subscription_id]
# For tracking which client subscriptions use which source subscriptions
self._source_refs: Dict[tuple, Set[str]] = {}
# Map source_subscription_id -> (source_name, symbol, resolution)
self._source_subs: Dict[str, tuple] = {}
# Available data sources
self._sources: Dict[str, DataSource] = {}
def register_source(self, name: str, source: DataSource) -> None:
"""Register a data source"""
self._sources[name] = source
def unregister_source(self, name: str) -> None:
"""Unregister a data source"""
self._sources.pop(name, None)
async def subscribe(
self,
subscription_id: str,
client_id: str,
source_name: str,
symbol: str,
resolution: str,
callback: Callable[[dict], None],
) -> None:
"""
Subscribe a client to real-time updates.
Args:
subscription_id: Unique ID for this subscription
client_id: ID of the subscribing client
source_name: Name of the data source
symbol: Symbol to subscribe to
resolution: Time resolution
callback: Function to call with bar updates
Raises:
ValueError: If source not found or subscription fails
"""
source = self._sources.get(source_name)
if not source:
raise ValueError(f"Data source '{source_name}' not found")
# Create subscription record
subscription = Subscription(
subscription_id=subscription_id,
client_id=client_id,
source_name=source_name,
symbol=symbol,
resolution=resolution,
callback=callback,
)
# Check if we already have a source subscription for this (source, symbol, resolution)
source_key = (source_name, symbol, resolution)
if source_key not in self._source_refs:
# Need to create a new source subscription
try:
source_sub_id = await source.subscribe_bars(
symbol=symbol,
resolution=resolution,
on_tick=lambda bar: self._on_source_update(source_key, bar),
)
subscription.source_subscription_id = source_sub_id
self._source_subs[source_sub_id] = source_key
self._source_refs[source_key] = set()
logger.info(
f"Created new source subscription: {source_name}/{symbol}/{resolution} -> {source_sub_id}"
)
except Exception as e:
logger.error(f"Failed to subscribe to source: {e}")
raise
# Add this subscription to the reference set
self._source_refs[source_key].add(subscription_id)
self._subscriptions[subscription_id] = subscription
logger.info(
f"Client subscription added: {subscription_id} ({client_id}) -> {source_name}/{symbol}/{resolution}"
)
async def unsubscribe(self, subscription_id: str) -> None:
"""
Unsubscribe a client from updates.
Args:
subscription_id: ID of the subscription to remove
"""
subscription = self._subscriptions.pop(subscription_id, None)
if not subscription:
logger.warning(f"Subscription {subscription_id} not found")
return
source_key = (subscription.source_name, subscription.symbol, subscription.resolution)
# Remove from reference set
if source_key in self._source_refs:
self._source_refs[source_key].discard(subscription_id)
# If no more clients need this source subscription, unsubscribe from source
if not self._source_refs[source_key]:
del self._source_refs[source_key]
if subscription.source_subscription_id:
source = self._sources.get(subscription.source_name)
if source:
try:
await source.unsubscribe_bars(subscription.source_subscription_id)
logger.info(
f"Unsubscribed from source: {subscription.source_subscription_id}"
)
except Exception as e:
logger.error(f"Error unsubscribing from source: {e}")
self._source_subs.pop(subscription.source_subscription_id, None)
logger.info(f"Client subscription removed: {subscription_id}")
async def unsubscribe_client(self, client_id: str) -> None:
"""
Unsubscribe all subscriptions for a client.
Useful when a WebSocket connection closes.
Args:
client_id: ID of the client
"""
# Find all subscriptions for this client
client_subs = [
sub_id
for sub_id, sub in self._subscriptions.items()
if sub.client_id == client_id
]
# Unsubscribe each one
for sub_id in client_subs:
await self.unsubscribe(sub_id)
logger.info(f"Unsubscribed all subscriptions for client {client_id}")
def _on_source_update(self, source_key: tuple, bar: dict) -> None:
"""
Handle an update from a data source.
Routes the update to all client subscriptions that need it.
Args:
source_key: (source_name, symbol, resolution) tuple
bar: Bar data dict from the source
"""
subscription_ids = self._source_refs.get(source_key, set())
for sub_id in subscription_ids:
subscription = self._subscriptions.get(sub_id)
if subscription:
try:
subscription.callback(bar)
except Exception as e:
logger.error(
f"Error in subscription callback {sub_id}: {e}", exc_info=True
)
def get_subscription_count(self) -> int:
"""Get total number of active client subscriptions"""
return len(self._subscriptions)
def get_source_subscription_count(self) -> int:
"""Get total number of active source subscriptions"""
return len(self._source_refs)
def get_client_subscriptions(self, client_id: str) -> list:
"""Get all subscriptions for a specific client"""
return [
{
"subscription_id": sub.subscription_id,
"source": sub.source_name,
"symbol": sub.symbol,
"resolution": sub.resolution,
}
for sub in self._subscriptions.values()
if sub.client_id == client_id
]

View File

@@ -0,0 +1,347 @@
"""
WebSocket handler for TradingView-compatible datafeed API.
Handles incoming requests for symbol search, metadata, historical data,
and real-time subscriptions.
"""
import json
import logging
from typing import Dict, Optional
from fastapi import WebSocket
from .base import DataSource
from .registry import DataSourceRegistry
from .subscription_manager import SubscriptionManager
from .websocket_protocol import (
BarUpdateMessage,
ErrorResponse,
GetBarsRequest,
GetBarsResponse,
GetConfigRequest,
GetConfigResponse,
ResolveSymbolRequest,
ResolveSymbolResponse,
SearchSymbolsRequest,
SearchSymbolsResponse,
SubscribeBarsRequest,
SubscribeBarsResponse,
UnsubscribeBarsRequest,
UnsubscribeBarsResponse,
)
logger = logging.getLogger(__name__)
class DatafeedWebSocketHandler:
"""
Handles WebSocket connections for TradingView-compatible datafeed API.
Each handler manages a single WebSocket connection and routes requests
to the appropriate data sources via the registry.
"""
def __init__(
self,
websocket: WebSocket,
client_id: str,
registry: DataSourceRegistry,
subscription_manager: SubscriptionManager,
default_source: Optional[str] = None,
):
"""
Initialize handler.
Args:
websocket: FastAPI WebSocket connection
client_id: Unique identifier for this client
registry: DataSource registry for accessing data sources
subscription_manager: Shared subscription manager
default_source: Default data source name if not specified in requests
"""
self.websocket = websocket
self.client_id = client_id
self.registry = registry
self.subscription_manager = subscription_manager
self.default_source = default_source
self._connected = True
async def handle_connection(self) -> None:
"""
Main connection handler loop.
Processes incoming messages until the connection closes.
"""
try:
await self.websocket.accept()
logger.info(f"WebSocket connected: client_id={self.client_id}")
while self._connected:
# Receive message
try:
data = await self.websocket.receive_text()
message = json.loads(data)
except Exception as e:
logger.error(f"Error receiving/parsing message: {e}")
break
# Route to appropriate handler
await self._handle_message(message)
except Exception as e:
logger.error(f"WebSocket error: {e}", exc_info=True)
finally:
# Clean up subscriptions when connection closes
await self.subscription_manager.unsubscribe_client(self.client_id)
self._connected = False
logger.info(f"WebSocket disconnected: client_id={self.client_id}")
async def _handle_message(self, message: dict) -> None:
"""Route message to appropriate handler based on type"""
msg_type = message.get("type")
request_id = message.get("request_id", "unknown")
try:
if msg_type == "search_symbols":
await self._handle_search_symbols(SearchSymbolsRequest(**message))
elif msg_type == "resolve_symbol":
await self._handle_resolve_symbol(ResolveSymbolRequest(**message))
elif msg_type == "get_bars":
await self._handle_get_bars(GetBarsRequest(**message))
elif msg_type == "subscribe_bars":
await self._handle_subscribe_bars(SubscribeBarsRequest(**message))
elif msg_type == "unsubscribe_bars":
await self._handle_unsubscribe_bars(UnsubscribeBarsRequest(**message))
elif msg_type == "get_config":
await self._handle_get_config(GetConfigRequest(**message))
else:
await self._send_error(
request_id, "UNKNOWN_REQUEST_TYPE", f"Unknown request type: {msg_type}"
)
except Exception as e:
logger.error(f"Error handling {msg_type}: {e}", exc_info=True)
await self._send_error(request_id, "INTERNAL_ERROR", str(e))
async def _handle_search_symbols(self, request: SearchSymbolsRequest) -> None:
"""Handle symbol search request"""
# Use default source or search all sources
if self.default_source:
source = self.registry.get(self.default_source)
if not source:
await self._send_error(
request.request_id,
"SOURCE_NOT_FOUND",
f"Default source '{self.default_source}' not found",
)
return
results = await source.search_symbols(
query=request.query,
type=request.symbol_type,
exchange=request.exchange,
limit=request.limit,
)
results_data = [r.model_dump(mode="json") for r in results]
else:
# Search all sources
all_results = await self.registry.search_all(
query=request.query,
type=request.symbol_type,
exchange=request.exchange,
limit=request.limit,
)
# Flatten results from all sources
results_data = []
for source_results in all_results.values():
results_data.extend([r.model_dump(mode="json") for r in source_results])
response = SearchSymbolsResponse(request_id=request.request_id, results=results_data)
await self._send_response(response)
async def _handle_resolve_symbol(self, request: ResolveSymbolRequest) -> None:
"""Handle symbol resolution request"""
# Extract source from symbol if present (format: "SOURCE:SYMBOL")
if ":" in request.symbol:
source_name, symbol = request.symbol.split(":", 1)
else:
source_name = self.default_source
symbol = request.symbol
if not source_name:
await self._send_error(
request.request_id,
"NO_SOURCE_SPECIFIED",
"No data source specified and no default source configured",
)
return
try:
symbol_info = await self.registry.resolve_symbol(source_name, symbol)
response = ResolveSymbolResponse(
request_id=request.request_id,
symbol_info=symbol_info.model_dump(mode="json"),
)
await self._send_response(response)
except ValueError as e:
await self._send_error(request.request_id, "SYMBOL_NOT_FOUND", str(e))
async def _handle_get_bars(self, request: GetBarsRequest) -> None:
"""Handle historical bars request"""
# Extract source from symbol
if ":" in request.symbol:
source_name, symbol = request.symbol.split(":", 1)
else:
source_name = self.default_source
symbol = request.symbol
if not source_name:
await self._send_error(
request.request_id, "NO_SOURCE_SPECIFIED", "No data source specified"
)
return
source = self.registry.get(source_name)
if not source:
await self._send_error(
request.request_id, "SOURCE_NOT_FOUND", f"Source '{source_name}' not found"
)
return
try:
history = await source.get_bars(
symbol=symbol,
resolution=request.resolution,
from_time=request.from_time,
to_time=request.to_time,
countback=request.countback,
)
response = GetBarsResponse(
request_id=request.request_id, history=history.model_dump(mode="json")
)
await self._send_response(response)
except ValueError as e:
await self._send_error(request.request_id, "INVALID_REQUEST", str(e))
async def _handle_subscribe_bars(self, request: SubscribeBarsRequest) -> None:
"""Handle real-time subscription request"""
# Extract source from symbol
if ":" in request.symbol:
source_name, symbol = request.symbol.split(":", 1)
else:
source_name = self.default_source
symbol = request.symbol
if not source_name:
await self._send_error(
request.request_id, "NO_SOURCE_SPECIFIED", "No data source specified"
)
return
try:
# Create callback that sends updates to this WebSocket
async def send_update(bar: dict):
update = BarUpdateMessage(
subscription_id=request.subscription_id,
symbol=request.symbol,
resolution=request.resolution,
bar=bar,
)
await self._send_response(update)
# Register subscription
await self.subscription_manager.subscribe(
subscription_id=request.subscription_id,
client_id=self.client_id,
source_name=source_name,
symbol=symbol,
resolution=request.resolution,
callback=lambda bar: self._queue_update(send_update(bar)),
)
response = SubscribeBarsResponse(
request_id=request.request_id,
subscription_id=request.subscription_id,
success=True,
)
await self._send_response(response)
except Exception as e:
logger.error(f"Subscription failed: {e}", exc_info=True)
response = SubscribeBarsResponse(
request_id=request.request_id,
subscription_id=request.subscription_id,
success=False,
message=str(e),
)
await self._send_response(response)
async def _handle_unsubscribe_bars(self, request: UnsubscribeBarsRequest) -> None:
"""Handle unsubscribe request"""
try:
await self.subscription_manager.unsubscribe(request.subscription_id)
response = UnsubscribeBarsResponse(
request_id=request.request_id,
subscription_id=request.subscription_id,
success=True,
)
await self._send_response(response)
except Exception as e:
logger.error(f"Unsubscribe failed: {e}")
response = UnsubscribeBarsResponse(
request_id=request.request_id,
subscription_id=request.subscription_id,
success=False,
)
await self._send_response(response)
async def _handle_get_config(self, request: GetConfigRequest) -> None:
"""Handle datafeed config request"""
if self.default_source:
source = self.registry.get(self.default_source)
if source:
config = await source.get_config()
response = GetConfigResponse(
request_id=request.request_id, config=config.model_dump(mode="json")
)
await self._send_response(response)
return
# Return aggregate config from all sources
all_sources = self.registry.list_sources()
if not all_sources:
await self._send_error(
request.request_id, "NO_SOURCES", "No data sources available"
)
return
# Just use first source's config for now
# TODO: Aggregate configs from multiple sources
source = self.registry.get(all_sources[0])
if source:
config = await source.get_config()
response = GetConfigResponse(
request_id=request.request_id, config=config.model_dump(mode="json")
)
await self._send_response(response)
async def _send_response(self, response) -> None:
"""Send a response message to the client"""
try:
await self.websocket.send_json(response.model_dump(mode="json"))
except Exception as e:
logger.error(f"Error sending response: {e}")
self._connected = False
async def _send_error(self, request_id: str, error_code: str, error_message: str) -> None:
"""Send an error response"""
error = ErrorResponse(
request_id=request_id, error_code=error_code, error_message=error_message
)
await self._send_response(error)
def _queue_update(self, coro):
"""Queue an async update to be sent (prevents blocking callback)"""
import asyncio
asyncio.create_task(coro)

View File

@@ -0,0 +1,170 @@
"""
WebSocket protocol messages for TradingView-compatible datafeed API.
These messages define the wire format for client-server communication
over WebSocket for symbol search, historical data, and real-time subscriptions.
"""
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
# ============================================================================
# Client -> Server Messages
# ============================================================================
class SearchSymbolsRequest(BaseModel):
"""Request to search for symbols matching a query"""
type: Literal["search_symbols"] = "search_symbols"
request_id: str = Field(description="Client-generated request ID for matching responses")
query: str = Field(description="Search query string")
symbol_type: Optional[str] = Field(default=None, description="Filter by instrument type")
exchange: Optional[str] = Field(default=None, description="Filter by exchange")
limit: int = Field(default=30, description="Maximum number of results")
class ResolveSymbolRequest(BaseModel):
"""Request full metadata for a specific symbol"""
type: Literal["resolve_symbol"] = "resolve_symbol"
request_id: str
symbol: str = Field(description="Symbol identifier to resolve")
class GetBarsRequest(BaseModel):
"""Request historical bar data"""
type: Literal["get_bars"] = "get_bars"
request_id: str
symbol: str
resolution: str = Field(description="Time resolution (e.g., '1', '5', '60', '1D')")
from_time: int = Field(description="Start time (Unix timestamp in seconds)")
to_time: int = Field(description="End time (Unix timestamp in seconds)")
countback: Optional[int] = Field(default=None, description="Maximum number of bars to return")
class SubscribeBarsRequest(BaseModel):
"""Subscribe to real-time bar updates"""
type: Literal["subscribe_bars"] = "subscribe_bars"
request_id: str
symbol: str
resolution: str
subscription_id: str = Field(description="Client-generated subscription ID")
class UnsubscribeBarsRequest(BaseModel):
"""Unsubscribe from real-time updates"""
type: Literal["unsubscribe_bars"] = "unsubscribe_bars"
request_id: str
subscription_id: str
class GetConfigRequest(BaseModel):
"""Request datafeed configuration"""
type: Literal["get_config"] = "get_config"
request_id: str
# Union of all client request types
ClientRequest = Union[
SearchSymbolsRequest,
ResolveSymbolRequest,
GetBarsRequest,
SubscribeBarsRequest,
UnsubscribeBarsRequest,
GetConfigRequest,
]
# ============================================================================
# Server -> Client Messages
# ============================================================================
class SearchSymbolsResponse(BaseModel):
"""Response with search results"""
type: Literal["search_symbols_response"] = "search_symbols_response"
request_id: str
results: List[Dict[str, Any]] = Field(description="List of SearchResult objects")
class ResolveSymbolResponse(BaseModel):
"""Response with symbol metadata"""
type: Literal["resolve_symbol_response"] = "resolve_symbol_response"
request_id: str
symbol_info: Dict[str, Any] = Field(description="SymbolInfo object")
class GetBarsResponse(BaseModel):
"""Response with historical bars"""
type: Literal["get_bars_response"] = "get_bars_response"
request_id: str
history: Dict[str, Any] = Field(description="HistoryResult object with bars and metadata")
class SubscribeBarsResponse(BaseModel):
"""Acknowledgment of subscription"""
type: Literal["subscribe_bars_response"] = "subscribe_bars_response"
request_id: str
subscription_id: str
success: bool
message: Optional[str] = None
class UnsubscribeBarsResponse(BaseModel):
"""Acknowledgment of unsubscribe"""
type: Literal["unsubscribe_bars_response"] = "unsubscribe_bars_response"
request_id: str
subscription_id: str
success: bool
class GetConfigResponse(BaseModel):
"""Response with datafeed configuration"""
type: Literal["get_config_response"] = "get_config_response"
request_id: str
config: Dict[str, Any] = Field(description="DatafeedConfig object")
class BarUpdateMessage(BaseModel):
"""Real-time bar update (server-initiated, no request_id)"""
type: Literal["bar_update"] = "bar_update"
subscription_id: str
symbol: str
resolution: str
bar: Dict[str, Any] = Field(description="Bar data including time and all columns")
class ErrorResponse(BaseModel):
"""Error response for any failed request"""
type: Literal["error"] = "error"
request_id: str
error_code: str = Field(description="Machine-readable error code")
error_message: str = Field(description="Human-readable error description")
# Union of all server response types
ServerResponse = Union[
SearchSymbolsResponse,
ResolveSymbolResponse,
GetBarsResponse,
SubscribeBarsResponse,
UnsubscribeBarsResponse,
GetConfigResponse,
BarUpdateMessage,
ErrorResponse,
]