240 lines
7.8 KiB
Python
240 lines
7.8 KiB
Python
"""
Implementation of ChartingAPI using mplfinance for professional financial charts.
"""
|
|
|
|
import logging
|
|
from typing import Optional, Tuple, List
|
|
import pandas as pd
|
|
import numpy as np
|
|
from matplotlib import pyplot as plt
|
|
from matplotlib.figure import Figure
|
|
import mplfinance as mpf
|
|
|
|
from dexorder.api.charting_api import ChartingAPI
|
|
|
|
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
|
|
|
|
|
|
class ChartingAPIImpl(ChartingAPI):
    """
    Implementation of ChartingAPI using mplfinance.

    This implementation provides professional-looking financial charts with:
    - Candlestick plots with various styling options
    - Easy addition of indicator panels stacked below the existing plots
    - Consistent theming across all chart types
    """

    # Vertical extent (in figure fractions) that the stacked panels may occupy.
    _LAYOUT_TOP = 0.98     # leave a small margin above the top panel
    _LAYOUT_BOTTOM = 0.05  # leave room below the bottom panel for tick labels

    def __init__(self):
        """Initialize the charting API implementation (no state is kept)."""

    def plot_ohlc(
        self,
        df: pd.DataFrame,
        title: Optional[str] = None,
        volume: bool = False,
        style: str = "charles",
        figsize: Tuple[int, int] = (12, 8),
        **kwargs
    ) -> Tuple[Figure, plt.Axes]:
        """
        Create a candlestick chart from OHLC data.

        See ChartingAPI.plot_ohlc for full documentation.

        Args:
            df: OHLC data; normalized by ``_prepare_ohlc_dataframe``.
            title: Optional chart title. Omitted from the mplfinance call
                when ``None``.
            volume: Forwarded to ``mpf.plot`` as ``volume``.
            style: Name of the mplfinance style to use.
            figsize: Figure size forwarded to ``mpf.plot``.
            **kwargs: Extra keyword arguments forwarded to ``mpf.plot``.

        Returns:
            ``(fig, main_ax)`` where ``main_ax`` is the first axes returned
            by mplfinance.

        Raises:
            ValueError: If ``df`` cannot be normalized (see
                ``_prepare_ohlc_dataframe``).
        """
        df_plot = self._prepare_ohlc_dataframe(df)

        plot_kwargs = {
            'type': 'candle',
            'style': style,
            'volume': volume,
            'figsize': figsize,
            'returnfig': True,
        }
        # mplfinance validates every kwarg that is passed explicitly, so an
        # unset title must be left out rather than forwarded as None.
        if title is not None:
            plot_kwargs['title'] = title

        fig, axes = mpf.plot(df_plot, **plot_kwargs, **kwargs)

        # The first axes in the returned list is the main price panel.
        return fig, axes[0]

    @staticmethod
    def _group_axes_into_bands(axes: List[plt.Axes]) -> List[List[plt.Axes]]:
        """Group axes that occupy the same vertical band, ordered top to bottom."""
        bands = {}
        for ax in axes:
            pos = ax.get_position()
            # Round so that twin axes with an identical position share a key
            # despite floating-point noise.
            key = (round(pos.y0, 6), round(pos.height, 6))
            bands.setdefault(key, []).append(ax)
        return [bands[key] for key in sorted(bands, reverse=True)]

    def add_indicator_panel(
        self,
        fig: Figure,
        df: pd.DataFrame,
        columns: Optional[List[str]] = None,
        ylabel: Optional[str] = None,
        height_ratio: float = 0.3,
        ylim: Optional[Tuple[float, float]] = None,
        **kwargs
    ) -> plt.Axes:
        """
        Add a new indicator panel below existing plots, sharing their x-axis.

        See ChartingAPI.add_indicator_panel for full documentation.

        Args:
            fig: Figure that already contains at least one axes.
            df: Indicator data; one line is drawn per selected column.
            columns: Columns to plot. ``None`` selects every numeric column.
            ylabel: Optional y-axis label for the new panel.
            height_ratio: Height of the new panel relative to the first
                existing axes. Must be positive.
            ylim: Optional ``(low, high)`` y-axis limits.
            **kwargs: Extra keyword arguments forwarded to every
                ``Axes.plot`` call.

        Returns:
            The newly created axes.

        Raises:
            ValueError: If ``height_ratio`` is not positive, if a requested
                column is missing from ``df``, or if ``fig`` has no axes.
        """
        if height_ratio <= 0:
            raise ValueError(f"height_ratio must be positive, got {height_ratio}")

        # Determine which columns to plot
        if columns is None:
            columns = df.select_dtypes(include=[np.number]).columns.tolist()
        else:
            missing = set(columns) - set(df.columns)
            if missing:
                raise ValueError(f"Columns not found in DataFrame: {missing}")

        existing_axes = fig.get_axes()
        if not existing_axes:
            raise ValueError("Figure has no existing axes. Create a plot first with plot_ohlc().")

        if not columns:
            log.warning("add_indicator_panel: no columns to plot; adding an empty panel")

        # Axes sharing one position (e.g. a panel and its twin y-axes) form a
        # single visual band. They must be counted once in the height budget
        # and moved together, otherwise twins get stacked as separate panels.
        bands = self._group_axes_into_bands(existing_axes)
        band_heights = [band[0].get_position().height for band in bands]

        # The new panel is sized relative to the first (main) axes.
        main_pos = existing_axes[0].get_position()
        new_height = main_pos.height * height_ratio

        # Rescale everything so the bands plus the new panel exactly fill
        # the available vertical space.
        available_height = self._LAYOUT_TOP - self._LAYOUT_BOTTOM
        scale = available_height / (sum(band_heights) + new_height)

        # Reposition existing bands from the top down, keeping each axes'
        # own horizontal placement.
        band_top = self._LAYOUT_TOP
        for band, height in zip(bands, band_heights):
            scaled_height = height * scale
            band_bottom = band_top - scaled_height
            for ax in band:
                pos = ax.get_position()
                ax.set_position([pos.x0, band_bottom, pos.width, scaled_height])
            band_top = band_bottom

        # Create the new axes at the bottom, horizontally aligned with the
        # main axes.
        new_ax = fig.add_axes([
            main_pos.x0,
            self._LAYOUT_BOTTOM,
            main_pos.width,
            new_height * scale,
        ])

        # Share the x-axis with the first axes so pan/zoom stay in sync.
        new_ax.sharex(existing_axes[0])

        # Resolve the x values once; they are the same for every column.
        # A 'timestamp' index is interpreted as nanoseconds since the epoch,
        # matching _prepare_ohlc_dataframe.
        # NOTE(review): mplfinance appears to draw candles at integer row
        # positions unless show_nontrading=True is passed. If so, datetime
        # x values will not line up with the shared axis. Confirm, and map
        # to row positions if needed.
        index = df.index
        if index.name == 'timestamp' or 'timestamp' in str(index.dtype):
            x_values = pd.to_datetime(index, unit='ns')
        else:
            x_values = index

        for col in columns:
            new_ax.plot(x_values, df[col], label=col, **kwargs)

        # Styling
        if ylabel:
            new_ax.set_ylabel(ylabel)

        if ylim:
            new_ax.set_ylim(ylim)

        if len(columns) > 1:
            new_ax.legend(loc='best')

        new_ax.grid(True, alpha=0.3)

        # Only the bottom-most (new) panel keeps its x-axis labels.
        # tick_params is used so the setting survives tick regeneration.
        for ax in existing_axes:
            ax.set_xlabel('')
            ax.tick_params(axis='x', which='both', labelbottom=False)

        return new_ax

    def create_figure(
        self,
        figsize: Tuple[int, int] = (12, 8),
        style: str = "charles"
    ) -> Tuple[Figure, plt.Axes]:
        """
        Create a styled figure without OHLC data for custom visualizations.

        See ChartingAPI.create_figure for full documentation.

        Args:
            figsize: Figure size passed to ``plt.subplots``.
            style: Name of the mplfinance base style whose colors are used.

        Returns:
            ``(fig, ax)`` for a single-axes figure with a light grid.
        """
        mpf_style = mpf.make_mpf_style(base_mpf_style=style)

        fig, ax = plt.subplots(figsize=figsize)

        # In an mplfinance style, 'figcolor' is the figure background and
        # 'facecolor' is the axes background. Either may be absent or None,
        # in which case the matplotlib default is kept.
        figure_color = mpf_style.get('figcolor')
        if figure_color is not None:
            fig.patch.set_facecolor(figure_color)

        axes_color = mpf_style.get('facecolor')
        if axes_color is not None:
            ax.set_facecolor(axes_color)

        ax.grid(True, alpha=0.3)

        return fig, ax

    def _prepare_ohlc_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Prepare a DataFrame for mplfinance plotting.

        Ensures the DataFrame has the correct format:
        - DatetimeIndex
        - Lowercase column names: open, high, low, close, volume

        Column names are matched case-insensitively. A ``timestamp`` column,
        an index named ``timestamp``, or any integer index is interpreted as
        nanoseconds since the epoch.

        Args:
            df: Input DataFrame with OHLC data. It is not modified.

        Returns:
            A new DataFrame with a DatetimeIndex and only the
            open/high/low/close (and volume, if present) columns.

        Raises:
            ValueError: If no DatetimeIndex can be derived or a required
                OHLC column is missing.
        """
        df_copy = df.copy()

        # Normalize names first so the 'timestamp' lookup below is also
        # case-insensitive. str() guards against non-string labels, which
        # the .str accessor would reject.
        df_copy.columns = [str(col).lower() for col in df_copy.columns]

        # Handle timestamp column (in nanoseconds) -> DatetimeIndex
        if 'timestamp' in df_copy.columns:
            df_copy.index = pd.to_datetime(df_copy['timestamp'], unit='ns')
            df_copy = df_copy.drop(columns=['timestamp'])
        elif (
            df_copy.index.name == 'timestamp'
            or pd.api.types.is_integer_dtype(df_copy.index.dtype)
        ):
            # Index is timestamp in nanoseconds
            df_copy.index = pd.to_datetime(df_copy.index, unit='ns')

        if not isinstance(df_copy.index, pd.DatetimeIndex):
            raise ValueError(
                "DataFrame must have a DatetimeIndex or a 'timestamp' column in nanoseconds"
            )

        # Validate required columns
        required = ['open', 'high', 'low', 'close']
        missing = set(required) - set(df_copy.columns)
        if missing:
            raise ValueError(f"DataFrame missing required OHLC columns: {missing}")

        # Keep only OHLC(V) columns for mplfinance
        keep_cols = list(required)
        if 'volume' in df_copy.columns:
            keep_cols.append('volume')

        return df_copy[keep_cols]
|