# Source: ai/sandbox/dexorder/impl/charting_api_impl.py
# (file-viewer listing: 240 lines, 7.8 KiB, Python)

"""
Implementation of ChartingAPI using mplfinance for professional financial charts.
"""
import logging
from typing import Optional, Tuple, List
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
import mplfinance as mpf
from dexorder.api.charting_api import ChartingAPI
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
class ChartingAPIImpl(ChartingAPI):
    """
    Implementation of ChartingAPI using mplfinance.

    This implementation provides professional-looking financial charts with:
    - Candlestick plots with various styling options
    - Easy addition of indicator panels with proper alignment
    - Consistent theming across all chart types
    """

    # Vertical band of the figure (figure-fraction coordinates) inside which
    # add_indicator_panel() stacks the panels; small margins are left at the
    # top and bottom.
    _PANEL_TOP = 0.98
    _PANEL_BOTTOM = 0.05

    def __init__(self):
        """Initialize the charting API implementation (no state is kept)."""

    def plot_ohlc(
        self,
        df: pd.DataFrame,
        title: Optional[str] = None,
        volume: bool = False,
        style: str = "charles",
        figsize: Tuple[int, int] = (12, 8),
        **kwargs
    ) -> Tuple[Figure, plt.Axes]:
        """
        Create a candlestick chart from OHLC data.

        See ChartingAPI.plot_ohlc for full documentation.

        Args:
            df: OHLC(V) data; normalized by _prepare_ohlc_dataframe().
            title: Optional chart title, forwarded to mplfinance.
            volume: Whether mplfinance should draw a volume panel.
            style: Name of the mplfinance style to use.
            figsize: Figure size, forwarded to mplfinance.
            **kwargs: Extra keyword arguments forwarded to ``mpf.plot``.

        Returns:
            The figure and the first axes of the list returned by mplfinance
            (the main price axes).

        Raises:
            ValueError: If ``df`` cannot be normalized (see
                _prepare_ohlc_dataframe()).
        """
        # Prepare the dataframe for mplfinance
        df_plot = self._prepare_ohlc_dataframe(df)

        # Create the plot
        fig, axes = mpf.plot(
            df_plot,
            type='candle',
            style=style,
            title=title,
            volume=volume,
            figsize=figsize,
            returnfig=True,
            **kwargs
        )

        # The first entry of the returned axes list is the main price axes.
        # NOTE(review): mplfinance appears to return a primary/secondary
        # (twinx) pair per panel, so axes[1] is presumably the price panel's
        # secondary axes rather than the volume panel - confirm before
        # indexing further into this list.
        return fig, axes[0]

    def add_indicator_panel(
        self,
        fig: Figure,
        df: pd.DataFrame,
        columns: Optional[List[str]] = None,
        ylabel: Optional[str] = None,
        height_ratio: float = 0.3,
        ylim: Optional[Tuple[float, float]] = None,
        **kwargs
    ) -> plt.Axes:
        """
        Add a new indicator panel below existing plots with aligned x-axis.

        See ChartingAPI.add_indicator_panel for full documentation.

        The existing panels are rescaled so that they, plus the new panel,
        fill the band between _PANEL_BOTTOM and _PANEL_TOP while keeping
        their relative heights. Axes that occupy the same rectangle (for
        example a twinx() pair) are treated as ONE panel; counting them
        separately would reserve space twice for the same panel.

        Args:
            fig: Figure that already contains at least one axes.
            df: Indicator data to plot.
            columns: Columns of ``df`` to plot; all numeric columns when None.
            ylabel: Optional y-axis label for the new panel.
            height_ratio: Height of the new panel relative to the first
                (main) panel; must be positive.
            ylim: Optional y-axis limits for the new panel.
            **kwargs: Extra keyword arguments forwarded to ``Axes.plot`` for
                every plotted column.

        Returns:
            The newly created axes.

        Raises:
            ValueError: If ``height_ratio`` is not positive, if a requested
                column is missing from ``df``, or if ``fig`` has no axes.
        """
        if height_ratio <= 0:
            raise ValueError(f"height_ratio must be positive, got {height_ratio}")

        # Determine which columns to plot
        if columns is None:
            # Plot all numeric columns
            columns = df.select_dtypes(include=[np.number]).columns.tolist()
        else:
            # Validate columns exist
            missing = set(columns) - set(df.columns)
            if missing:
                raise ValueError(f"Columns not found in DataFrame: {missing}")

        # Get existing axes
        existing_axes = fig.get_axes()
        if not existing_axes:
            raise ValueError(
                "Figure has no existing axes. Create a plot first with plot_ohlc()."
            )

        # One entry per visual panel, in creation order (main panel first).
        panels = self._group_axes_by_position(existing_axes)
        panel_heights = [group[0].get_position().height for group in panels]

        # The new panel's height is expressed relative to the main panel.
        new_height = panel_heights[0] * height_ratio
        total_height = sum(panel_heights) + new_height
        scale = (self._PANEL_TOP - self._PANEL_BOTTOM) / total_height

        # Re-stack the existing panels from the top down, keeping each
        # panel's horizontal extent.
        top = self._PANEL_TOP
        for group, height in zip(panels, panel_heights):
            scaled_height = height * scale
            bottom = top - scaled_height
            for ax in group:
                pos = ax.get_position()
                ax.set_position([pos.x0, bottom, pos.width, scaled_height])
            top = bottom

        # Create the new axes at the bottom, horizontally aligned with the
        # main panel.
        main_pos = existing_axes[0].get_position()
        new_ax = fig.add_axes([
            main_pos.x0,
            self._PANEL_BOTTOM,
            main_pos.width,
            new_height * scale,
        ])

        # Share x-axis with the first axes for time alignment.
        # NOTE(review): mplfinance appears to draw candles against integer
        # row positions unless show_nontrading=True; datetime x-values on a
        # shared axis may then not line up with the candles - verify.
        new_ax.sharex(existing_axes[0])

        # An index that is named 'timestamp' (or whose dtype mentions
        # 'timestamp') is interpreted as nanoseconds since the epoch. This is
        # computed once because it does not depend on the column.
        index = df.index
        if index.name == 'timestamp' or 'timestamp' in str(index.dtype):
            plot_index = pd.to_datetime(index, unit='ns')
        else:
            plot_index = index

        # Plot the indicator data
        for col in columns:
            new_ax.plot(plot_index, df[col], label=col, **kwargs)

        # Styling
        if ylabel:
            new_ax.set_ylabel(ylabel)
        if ylim is not None:
            new_ax.set_ylim(ylim)
        if len(columns) > 1:
            new_ax.legend(loc='best')
        new_ax.grid(True, alpha=0.3)

        # Only show x-axis labels on the bottom-most panel. tick_params is
        # used so the setting also applies to ticks created on later redraws.
        for ax in existing_axes:
            ax.set_xlabel('')
            ax.tick_params(axis='x', labelbottom=False)

        return new_ax

    @staticmethod
    def _group_axes_by_position(axes: List[plt.Axes]) -> List[List[plt.Axes]]:
        """Group axes that occupy the same figure rectangle, keeping order."""
        groups = {}
        for ax in axes:
            # Rounded bounds make the key robust to float noise.
            key = tuple(round(value, 9) for value in ax.get_position().bounds)
            groups.setdefault(key, []).append(ax)
        return list(groups.values())

    def create_figure(
        self,
        figsize: Tuple[int, int] = (12, 8),
        style: str = "charles"
    ) -> Tuple[Figure, plt.Axes]:
        """
        Create a styled figure without OHLC data for custom visualizations.

        See ChartingAPI.create_figure for full documentation.

        Args:
            figsize: Figure size passed to ``plt.subplots``.
            style: Name of the mplfinance base style whose colors are used.

        Returns:
            The new figure and its single axes.
        """
        # Get the style parameters from mplfinance
        mpf_style = mpf.make_mpf_style(base_mpf_style=style)

        fig, ax = plt.subplots(figsize=figsize)

        # In an mplfinance style, 'figcolor' is the figure background and
        # 'facecolor' is the axes background. Unset (None) entries are
        # skipped so matplotlib's defaults stay in place; passing None to a
        # patch's set_facecolor would select the generic patch color instead.
        figure_color = mpf_style.get('figcolor')
        if figure_color is not None:
            fig.patch.set_facecolor(figure_color)
        axes_color = mpf_style.get('facecolor')
        if axes_color is not None:
            ax.set_facecolor(axes_color)

        ax.grid(True, alpha=0.3)
        return fig, ax

    def _prepare_ohlc_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Prepare a DataFrame for mplfinance plotting.

        Ensures the DataFrame has the correct format:
        - DatetimeIndex
        - Lowercase column names: open, high, low, close, volume

        The input is not modified; a new DataFrame is returned.

        Args:
            df: Input DataFrame with OHLC data. Time is taken from a
                'timestamp' column (nanoseconds), from an index named
                'timestamp' or of integer dtype (nanoseconds), or from an
                existing DatetimeIndex.

        Returns:
            DataFrame ready for mplfinance, containing only the columns
            open, high, low, close and (when present) volume.

        Raises:
            ValueError: If no DatetimeIndex can be derived, or if any of the
                open/high/low/close columns is missing.
        """
        df_copy = df.copy()

        # Handle timestamp column (in nanoseconds) -> DatetimeIndex
        if 'timestamp' in df_copy.columns:
            df_copy.index = pd.to_datetime(df_copy['timestamp'], unit='ns')
            df_copy = df_copy.drop(columns=['timestamp'])
        elif (
            df_copy.index.name == 'timestamp'
            or pd.api.types.is_integer_dtype(df_copy.index.dtype)
        ):
            # Index is timestamp in nanoseconds. A real dtype check is used
            # because a substring test on the dtype name would also match
            # e.g. 'interval[int64, right]'.
            df_copy.index = pd.to_datetime(df_copy.index, unit='ns')

        # Ensure index is DatetimeIndex
        if not isinstance(df_copy.index, pd.DatetimeIndex):
            raise ValueError(
                "DataFrame must have a DatetimeIndex or a 'timestamp' column in nanoseconds"
            )

        # Normalize column names to lowercase; str() keeps this working when
        # some column labels are not strings.
        df_copy.columns = [str(label).lower() for label in df_copy.columns]

        # Validate required columns
        required = ['open', 'high', 'low', 'close']
        missing = set(required) - set(df_copy.columns)
        if missing:
            raise ValueError(f"DataFrame missing required OHLC columns: {missing}")

        # Keep only OHLC(V) columns for mplfinance
        keep_cols = list(required)
        if 'volume' in df_copy.columns:
            keep_cols.append('volume')
        return df_copy[keep_cols]