###############################################################################
## ##
## LIBRARIES ##
## ##
###############################################################################
# Ignoring a depracation warning to ensure a better console run
import warnings
# Suppress FutureWarning from Seaborn regarding 'use_inf_as_na'
warnings.filterwarnings("ignore", category=FutureWarning, message="use_inf_as_na option is deprecated")
# General Libraries
import numpy as np
import pandas as pd
from pathlib import Path
import calendar
from typing import Union, Dict, Any
from types import SimpleNamespace
import itertools
from itertools import starmap, chain, cycle
from functools import partial
# Plotting Libraries
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D
import matplotlib.colors as mcolors
import seaborn as sns
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from scipy.signal import csd
# Module imports
from Hydrological_model_validator.Plotting.formatting import (plot_line,
get_min_max_for_identity_line,
get_variable_label_unit,
style_axes_spines,
format_unit)
from Hydrological_model_validator.Processing.data_alignment import (extract_mod_sat_keys,
gather_monthly_data_across_years)
from Hydrological_model_validator.Processing.stats_math_utils import (fit_huber,
fit_lowess,
)
from Hydrological_model_validator.Processing.time_utils import get_season_mask
from Hydrological_model_validator.Processing.utils import extract_options
from Hydrological_model_validator.Plotting.default_plot_options import (default_plot_options_ts,
default_plot_options_scatter,
default_scatter_by_season_options,
default_boxplot_options,
default_violinplot_options,
default_efficiency_plot_options,
spatial_efficiency_defaults,
default_error_timeseries_options,
default_spectral)
###############################################################################
## ##
## FUNCTIONS ##
## ##
###############################################################################
[docs]
def timeseries(data_dict: Dict[str, Union[pd.Series, list]], BIAS: Union[pd.Series, list, None], **kwargs: Any) -> None:
"""
Plot time series of daily mean values from multiple datasets along with BIAS.
This function generates a two-panel time series plot:
1. The upper subplot displays daily mean values of each dataset (typically model and satellite data).
2. The lower subplot shows the BIAS (model - satellite) as a time series.
The figure is saved to a specified output directory as a PNG file and displayed using matplotlib.
Parameters
----------
data_dict : Dict[str, Union[pd.Series, list]]
Dictionary containing daily mean values for different sources (e.g., model and satellite).
Keys are strings identifying the data source.
Values should be `pandas.Series` with datetime indices or lists convertible to Series.
BIAS : Union[pd.Series, list]
Series (or list) representing the BIAS time series (typically model - satellite).
Should be time-aligned with the values in `data_dict`.
Keyword Arguments
-----------------
output_path : str or Path
Required. Path where the figure should be saved.
variable_name : str
Required. Variable code name used to infer full name and unit.
variable : str, optional
Full variable name (e.g., "Chlorophyll"). Used in titles and axis.
unit : str, optional
Unit of measurement (e.g., "mg Chl/m³"). Displayed on axis.
BA : bool, optional
If True, appends " (Basin Average)" to the title.
figsize : tuple of float, optional
Size of figure in inches (default typically (12, 8)).
dpi : int, optional
Resolution of the figure (default 100).
color_palette : iterator, optional
Iterator of colors (e.g., `itertools.cycle(sns.color_palette("tab10"))`).
line_width : float, optional
Width of plotted lines (default 2.0).
title_fontsize : int, optional
Font size of the main title.
bias_title_fontsize : int, optional
Font size of the BIAS subplot title.
label_fontsize : int, optional
Font size of axis labels.
legend_fontsize : int, optional
Font size of the legend.
savefig_kwargs : dict, optional
Additional args for `plt.savefig()`, e.g., `bbox_inches`, `transparent`.
Example
-------
>>> timeseries(
... data_dict={'model': model_series, 'satellite': sat_series},
... BIAS=model_series - sat_series,
... variable_name='Chl',
... output_path='figures/',
... BA=True
... )
Notes
-----
- If `variable` or `unit` is not provided, the function attempts to resolve them using
`get_variable_label_unit(variable_name)`.
- The data series are auto-converted to `pandas.Series` if passed as lists.
- The order and coloring of plotted datasets depend on the order in `data_dict` and `color_palette`.
"""
# ----- RETRIEVE DEFAULT OPTIONS -----
options = SimpleNamespace(**{**default_plot_options_ts, **kwargs})
# ----- REQUIRED PARAMS CHECK -----
if options.output_path is None:
raise ValueError("output_path must be specified either in kwargs or default options.")
if options.variable_name is not None:
variable, unit = get_variable_label_unit(options.variable_name)
options.variable = options.variable or variable
options.unit = options.unit or unit
else:
if options.variable is None or options.unit is None:
raise ValueError("If 'variable_name' is not provided, both 'variable' and 'unit' must be specified.")
# ----- BASIN AVERAGE LABEL -----
title = f'Daily Mean Values for {options.variable} Datasets'
if options.BA:
title += ' (Basin Average)'
mod_key, sat_key = extract_mod_sat_keys(data_dict)
label_lookup = {mod_key: "Model Output", sat_key: "Satellite Observations"}
# ----- CONVERT INPUTS TO SERIES -----
data_dict = {k: pd.Series(v) if not isinstance(v, pd.Series) else v for k, v in data_dict.items()}
if BIAS is not None:
BIAS = pd.Series(BIAS) if not isinstance(BIAS, pd.Series) else BIAS
# ----- FIGURE SETUP -----
if BIAS is not None:
fig = plt.figure(figsize=options.figsize, dpi=options.dpi)
gs = GridSpec(2, 1, height_ratios=[8, 4])
ax1 = fig.add_subplot(gs[0])
else:
fig = plt.figure(figsize=(options.figsize[0], options.figsize[1] * 0.6), dpi=options.dpi)
gs = GridSpec(1, 1)
ax1 = fig.add_subplot(gs[0])
# ----- MAIN TIMESERIES PLOT -----
plotter = partial(
plot_line,
ax=ax1,
label_lookup=label_lookup,
color_palette=options.color_palette,
line_width=options.line_width,
library='plt',
)
list(starmap(plotter, data_dict.items()))
ax1.set_title(title, fontsize=options.title_fontsize, fontweight='bold')
ax1.set_ylabel(f'{options.variable} {options.unit}', fontsize=options.label_fontsize)
ax1.tick_params(width=2)
ax1.legend(loc='upper left', fontsize=options.legend_fontsize)
ax1.grid(True, linestyle='--')
style_axes_spines(ax1)
# ----- OPTIONAL BIAS PLOT -----
if BIAS is not None:
ax2 = fig.add_subplot(gs[1])
ax2.plot(BIAS.index, BIAS.values, color='k')
ax2.set_title(f'BIAS ({options.variable})', fontsize=options.bias_title_fontsize, fontweight='bold')
ax2.set_ylabel(f'BIAS {options.unit}', fontsize=options.label_fontsize)
ax2.tick_params(width=2)
ax2.grid(True, linestyle='--')
style_axes_spines(ax2)
plt.tight_layout()
# ----- SAVE FIGURE -----
output_path = Path(options.output_path)
output_path.mkdir(parents=True, exist_ok=True)
filename = f'{options.variable}_timeseries.png'
save_path = output_path / filename
plt.savefig(save_path, **options.savefig_kwargs)
plt.close()
###############################################################################
###############################################################################
[docs]
def scatter_plot(data_dict: Dict[str, Union[pd.Series, list]], **kwargs: Any) -> None:
"""
Generate a scatter plot comparing daily mean values between model and satellite datasets.
This function creates a single scatter plot showing the relationship between model outputs
and satellite observations. It also includes an identity line (`y = x`) for reference,
allowing visual evaluation of model accuracy.
Parameters
----------
data_dict : Dict[str, Union[pd.Series, list]]
Dictionary containing model and satellite data.
Keys should correspond to model and satellite dataset names.
Values must be 1D arrays or pandas Series.
Keyword Arguments
-----------------
output_path : str or Path
Required. Directory to save figures.
variable_name : str
Required. Variable code (e.g., 'SST').
BA : bool, optional
Whether to append "(Basin Average)" to titles.
figsize : tuple of float, optional
Figure size (width, height).
dpi : int, optional
Figure resolution.
color : str, optional
Scatter point color.
season_colors : dict, optional
Map of season names to colors (for seasonal plots).
alpha : float, optional
Transparency of scatter points.
marker_size : int, optional
Size of scatter markers.
title_fontsize : int, optional
Font size for plot titles.
label_fontsize : int, optional
Font size for axis labels.
tick_labelsize : int, optional
Size of tick labels.
line_width : float, optional
Width of lines (identity, fits, axes).
legend_fontsize : int, optional
Size of legend text.
variable : str, optional
Long name of variable (used in title).
unit : str, optional
Unit of variable (used in labels).
Example
-------
>>> scatter_plot(
... data_dict={'model': model_series, 'satellite': sat_series},
... variable_name='SST',
... output_path='figures/',
... BA=True
... )
"""
# ----- RETRIEVE DEFAULT OPTIONS -----
options = SimpleNamespace(**{**default_plot_options_scatter, **kwargs})
# ----- RETRIEVE NECESSARY OUTPUT PATH AND VARIABLE/UNIT INFO -----
if getattr(options, 'output_path', None) is None:
raise ValueError("output_path must be specified either in kwargs or default options.")
variable_name = getattr(options, 'variable_name', None)
variable = getattr(options, 'variable', None)
unit = getattr(options, 'unit', None)
if variable_name is not None:
# Infer full variable name and unit from short name
variable_label, unit_label = get_variable_label_unit(variable_name)
options.variable = variable or variable_label
options.unit = unit or unit_label
else:
# variable_name not given — require both variable and unit
if variable is None or unit is None:
raise ValueError(
"If 'variable_name' is not provided, both 'variable' and 'unit' must be specified in kwargs or defaults."
)
options.variable = variable
options.unit = unit
# ----- EXTRACT MODEL AND SATELLITE KEYS FROM DATASET FOR PLOTTING -----
mod_key, sat_key = extract_mod_sat_keys(data_dict)
BAmod = pd.Series(data_dict[mod_key])
BAsat = pd.Series(data_dict[sat_key])
# ----- BUILD FULL DATAFRAME -----
df = pd.DataFrame({'Model': BAmod, 'Satellite': BAsat})
# ----- BASIC SEABORN SETTINGS -----
sns.set(style="whitegrid", context='notebook')
sns.set_style("ticks")
# ----- CREATE THE FIGURE -----
fig = plt.figure(figsize=options.figsize, dpi=options.dpi)
ax1 = fig.add_subplot(1, 1, 1)
# ----- CREATE THE MARKERS -----
sns.scatterplot(
x='Model',
y='Satellite',
data=df,
color=options.color,
alpha=options.alpha,
s=options.marker_size,
ax=ax1
)
# ----- SET THE TITLE AND BA TAG IF NECESSARY -----
title = f'Scatter Plot of {options.variable} (Model vs. Satellite)'
if getattr(options, 'BA', False):
title += ' (Basin Average)'
# ----- PLOTTING OPTIONS -----
ax1.set_title(title, fontsize=options.title_fontsize, fontweight='bold')
ax1.set_xlabel(f'{options.variable} (Model) {options.unit}', fontsize=options.label_fontsize)
ax1.set_ylabel(f'{options.variable} (Satellite) {options.unit}', fontsize=options.label_fontsize)
# ----- EXTRACT AND PLOT IDEAL IDENTITY LINE -----
min_val, max_val = get_min_max_for_identity_line(df['Model'], df['Satellite'])
ax1.plot([min_val, max_val], [min_val, max_val], 'k--', linewidth=options.line_width, label='y = x (Ideal)')
# ----- OTHER FORMATTING OPTIONS -----
ax1.tick_params(width=2, labelsize=options.tick_labelsize)
ax1.grid(True, linestyle='--')
list(starmap(lambda _, spine: (spine.set_linewidth(2), spine.set_edgecolor('black')),
enumerate(ax1.spines.values())))
ax1.legend(fontsize=options.legend_fontsize)
plt.tight_layout()
# ----- CHECK IF FOLDER IS AVAILABLE -----
output_path = Path(options.output_path)
output_path.mkdir(parents=True, exist_ok=True)
# ----- SAVE AND PRINT PLOT -----
filename = f'{variable_name or options.variable or "scatterplot"}_scatterplot.png'
save_path = output_path / filename
plt.savefig(save_path)
plt.close()
###############################################################################
###############################################################################
[docs]
def seasonal_scatter_plot(daily_means_dict: Dict[str, Union[np.ndarray, pd.Series]], **kwargs: Any) -> None:
"""
Generate seasonal scatter plots (DJF, MAM, JJA, SON) and a combined plot comparing
model vs satellite daily mean values.
Each seasonal subplot shows:
- Scatter comparison of model and satellite values.
- Identity line (y = x).
- Robust linear regression fit (Huber).
- Nonparametric LOWESS fit.
The final composite plot displays all seasons together, color-coded with a shared legend.
Parameters
----------
daily_means_dict : Dict[str, Union[np.ndarray, pd.Series]]
Dictionary with keys typically "mod" and "sat", each containing 1D arrays or pandas Series
of daily mean values. Assumes data starts from 2000-01-01 and is daily.
Keyword Arguments
-----------------
output_path : str or Path
Required. Directory to save figures.
variable_name : str
Required. Variable code (e.g., 'SST').
BA : bool, optional
Whether to append "(Basin Average)" to titles.
figsize : tuple of int, optional
Figure size (width, height).
dpi : int, optional
Figure resolution.
season_colors : dict, optional
Map of season names to colors. Default covers DJF, MAM, JJA, SON.
alpha : float, optional
Transparency of scatter points.
marker_size : int, optional
Size of scatter markers.
title_fontsize : int, optional
Font size for plot titles.
label_fontsize : int, optional
Font size for axis labels.
line_width : int, optional
Width of lines (identity, fits, axes).
tick_labelsize : int, optional
Size of tick labels.
legend_fontsize : int, optional
Size of legend text.
variable : str, optional
Long name of variable (used in titles).
unit : str, optional
Unit of variable (used in labels).
Returns
-------
None
Saves each seasonal plot and one combined plot to the output path.
Notes
-----
- Assumes data is continuous and aligned from 2000-01-01 onward.
- Handles NaNs automatically during fitting and plotting.
- Useful for visualizing seasonal agreement in long-term model or climate datasets.
Example
-------
>>> seasonal_scatter_plot(
... daily_means_dict={'mod': model_series, 'sat': sat_series},
... variable_name='SST',
... output_path='figures/',
... BA=True
... )
"""
# ----- RETRIEVE DEFAULT OPTIONS -----
options = SimpleNamespace(**{**default_scatter_by_season_options, **kwargs})
# ----- RETRIEVE NECESSARY OUTPUT PATH AND VARIABLE LABEL/UNIT -----
if options.output_path is None:
raise ValueError("output_path must be specified either in kwargs or default options.")
if options.variable_name is not None:
# Infer full variable name and unit from short name
variable, unit = get_variable_label_unit(options.variable_name)
options.variable = options.variable or variable
options.unit = options.unit or unit
else:
# variable_name not given — require both variable and unit
if options.variable is None or options.unit is None:
raise ValueError(
"If 'variable_name' is not provided, both 'variable' and 'unit' must be specified in kwargs or defaults."
)
# ----- ASSIGN DATE RANGE IF NOT AVAILABLE -----
sample_array = next(iter(daily_means_dict.values()))
dates = pd.date_range(start="2000-01-01", periods=len(sample_array), freq='D')
# ----- EXTRACT MODEL AND SATELLITE KEYS -----
mod_key, sat_key = extract_mod_sat_keys(daily_means_dict)
# ----- BUILD MODEL/SATELLITE DATASETS -----
BAmod = np.asarray(daily_means_dict[mod_key])
BAsat = np.asarray(daily_means_dict[sat_key])
# ----- RETRIEVE SEASON LIST -----
seasons = options.season_colors
# ----- CHECK IF THE DIRECTORY IS AVAILABLE, DONE ONCE FOR ALL -----
# ----- PLOTS IN THE FUNCTION -----
output_path = Path(options.output_path)
output_path.mkdir(parents=True, exist_ok=True)
# ----- INITIALIZE EMPTY ARRAYS -----
all_mod_points, all_sat_points, all_colors = [], [], []
# ----- BASIC SEABORN OPTIONS -----
sns.set(style="whitegrid", context='notebook')
sns.set_style("ticks")
# ----- BEGIN PLOTTING THE SEASONAL SCATTERPLOTS -----
for season_name, color in seasons.items():
# ----- MASK TO RETRIEVE THE SEASON -----
mask = get_season_mask(dates, season_name)
# ----- MASK TO USE ONLY VALID POINTS -----
# ----- BOTH MUST NOT BE NANS -----
mod_season = BAmod[mask]
sat_season = BAsat[mask]
valid = ~np.isnan(mod_season) & ~np.isnan(sat_season)
# ----- ASSIGN VALID POINTS TO SEASONAL ARRAYS -----
mod_season = mod_season[valid]
sat_season = sat_season[valid]
# ----- CHECK FOR SEASON NOT BEING EMPTY -----
# ----- OTHERWISE SKIP THE SEASON -----
if mod_season.size == 0:
print(f"Skipping {season_name}: no valid data.")
continue
# USE MODEL/SATELLITE/SEASON INFO TO BUILD THE DATAFRAME -----
df = pd.DataFrame({
'Model': mod_season,
'Satellite': sat_season,
'Season': [season_name] * len(mod_season)
})
# ----- INITIALIZE FIGURE -----
plt.figure(figsize=options.figsize, dpi=options.dpi)
# ----- PLOT -----
ax = sns.scatterplot(
x='Model', y='Satellite', data=df,
color=color, alpha=options.alpha, s=options.marker_size,
label=f'{options.variable} {season_name}'
)
# ----- PERFORM THE HUBER REGRESSION FIT -----
x_vals, y_vals = fit_huber(mod_season, sat_season)
ax.plot(x_vals, y_vals, color='black', linestyle='-', linewidth=options.line_width, label='Linear Fit (Huber)')
# ----- PERFORM THE LOWESS REGRESSION FIT -----
smoothed = fit_lowess(mod_season, sat_season)
ax.plot(smoothed[:, 0], smoothed[:, 1], color='magenta', linestyle='-.', linewidth=options.line_width, label='Smoothed Fit (LOWESS)')
# ----- PERFORM THE IDEAL FIT -----
min_val, max_val = get_min_max_for_identity_line(mod_season, sat_season)
ax.plot([min_val, max_val], [min_val, max_val], 'b--', linewidth=options.line_width, label='Ideal Fit')
# ----- TITLE AND BASIN AVERAGE TAG IF USER SPECIFIED IT -----
# ----- BASIN AVERAGE STRONGLY SUGGESTED TO MARK L4 SATELLITE DATA -----
title = f'{options.variable} Scatter Plot (Model vs Satellite) - {season_name}'
if options.BA:
title += ' (Basin Average)'
# ----- PLOT FORMATTING -----
ax.set_title(title, fontsize=options.title_fontsize, fontweight='bold')
ax.set_xlabel(f'{options.variable} (Model - {season_name}) {options.unit}', fontsize=options.label_fontsize)
ax.set_ylabel(f'{options.variable} (Satellite - {season_name}) {options.unit}', fontsize=options.label_fontsize)
ax.legend(fontsize=options.legend_fontsize)
ax.tick_params(axis='both', labelsize=options.tick_labelsize, width=options.line_width)
style_axes_spines(ax, linewidth=options.line_width, edgecolor='black')
ax.grid(True, linestyle='--')
plt.tight_layout()
# ----- SAVING AND PRINTING THE PLOT -----
filename = f"{options.variable}_{season_name}_scatterplot.png"
plt.savefig(output_path / filename)
plt.close()
# ----- APPEND THE DATA TO FINAL PLOT DATAFRAME -----
all_mod_points.extend(mod_season)
all_sat_points.extend(sat_season)
all_colors.extend([season_name] * len(mod_season))
# ----- CONVERT POINTS INTO NP.ARRAY TO BE USED FOR REGRESSION LINES -----
all_mod_points = np.array(all_mod_points)
all_sat_points = np.array(all_sat_points)
# ----- SWICTH TO SUMMARY PLOT -----
# ----- DONE ONLY IF THE CONTENT OF THE ALL_POINT IS NOT NONE -----
if len(all_mod_points) > 0:
# ----- INITIALIZE FIGURE -----
plt.figure(figsize=options.figsize, dpi=options.dpi)
# ----- BUILD COMBINED DATAFRAME -----
df_combined = pd.DataFrame({
'Model': all_mod_points,
'Satellite': all_sat_points,
'Season': all_colors
})
# ----- PLOT MARKERS -----
ax = sns.scatterplot(
x='Model', y='Satellite', data=df_combined,
hue='Season', palette=seasons,
alpha=options.alpha, s=options.marker_size
)
# ----- PERFORM IDEAL REGRESSION FIT -----
min_val, max_val = get_min_max_for_identity_line(all_mod_points, all_sat_points)
ax.plot([min_val, max_val], [min_val, max_val], 'b--', linewidth=options.line_width, label='Ideal Fit')
# ----- PERFORM HUBER REGRESSION FIT -----
x_vals, y_vals = fit_huber(all_mod_points, all_sat_points)
ax.plot(x_vals, y_vals, color='black', linestyle='-', linewidth=options.line_width, label='Linear Fit (Huber)')
# ----- PERFORM LOWESS FIT -----
smoothed = fit_lowess(all_mod_points, all_sat_points)
ax.plot(smoothed[:, 0], smoothed[:, 1], color='magenta', linestyle='-.', linewidth=options.line_width, label='Smoothed Fit (LOWESS)')
# ----- BUILD HANDLES FOR LEGEND -----
handles = [
Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=season)
for season, color in seasons.items()
] + [
Line2D([0], [0], color='b', linestyle='--', linewidth=options.line_width, label='Ideal Fit'),
Line2D([0], [0], color='black', linestyle='-', linewidth=options.line_width, label='Linear Fit (Huber)'),
Line2D([0], [0], color='magenta', linestyle='-.', linewidth=options.line_width, label='Smoothed Fit (LOWESS)')
]
ax.legend(handles=handles, fontsize=options.legend_fontsize, loc='upper left')
# ----- FORMATTING OPTIONS -----
ax.set_title(f'{options.variable} Scatter Plot (Model vs Satellite) - All Seasons', fontsize=options.title_fontsize, fontweight='bold')
ax.set_xlabel(f'{options.variable} (Model - All Seasons) {options.unit}', fontsize=options.label_fontsize)
ax.set_ylabel(f'{options.variable} (Satellite - All Seasons) {options.unit}', fontsize=options.label_fontsize)
ax.tick_params(axis='both', labelsize=options.tick_labelsize, width=options.line_width)
style_axes_spines(ax, linewidth=options.line_width, edgecolor='black')
ax.grid(True, linestyle='--')
plt.tight_layout()
# ----- SAVE AND PLOT -----
filename = f"{options.variable}_all_seasons_scatterplot.png"
plt.savefig(output_path / filename)
plt.close()
###############################################################################
###############################################################################
[docs]
def whiskerbox(data_dict, **kwargs):
"""
Create a boxplot comparing monthly values of model and satellite data.
This function plots side-by-side boxplots for each month, showing model vs satellite distributions.
It's useful for visualizing variability and central tendency over time.
Parameters
----------
data_dict : dict
Dictionary containing time-series data for model and satellite.
Keys must distinguish between model and satellite (e.g., 'model', 'satellite').
Keyword Arguments
-----------------
output_path : str or Path
Required. Directory to save the resulting PNG plot.
variable_name : str
Required. Variable short name (e.g., 'SST').
variable : str, optional
Full variable name (e.g., 'Sea Surface Temperature').
unit : str, optional
Unit of the variable (e.g., '°C').
figsize : tuple of float, optional
Figure size in inches, e.g., (14, 8).
dpi : int, optional
Plot resolution in dots per inch.
palette : str or list, optional
Seaborn-compatible color palette.
showfliers : bool, optional
Whether to show outlier points in the boxplot.
title_fontsize : int, optional
Font size of the plot title.
title_fontweight : str, optional
Font weight of the plot title (e.g., 'bold').
ylabel_fontsize : int, optional
Font size of the y-axis label.
xlabel : str, optional
Label for the x-axis (default: '').
grid_alpha : float, optional
Transparency for grid lines.
xtick_rotation : int or float, optional
Rotation angle for x-axis tick labels.
tick_width : float, optional
Width of axis ticks.
Returns
-------
None
Saves the boxplot figure as a PNG file to the specified output directory.
Example
-------
>>> whiskerbox(
... data_dict={'model': model_series, 'satellite': sat_series},
... variable_name='Chl',
... output_path='figures/',
... figsize=(14, 8),
... palette='Set2',
... showfliers=False
... )
"""
# ----- FETCH DEFAULT OPTIONS -----
options = SimpleNamespace(**{**default_boxplot_options, **kwargs})
# ----- RETRIEVE NECESSARY OUTPUT PATH AND VARIABLE/UNIT INFO -----
if options.output_path is None:
raise ValueError("output_path must be specified either in kwargs or default options.")
variable_name = getattr(options, 'variable_name', None)
variable = getattr(options, 'variable', None)
unit = getattr(options, 'unit', None)
if variable_name is not None:
var, un = get_variable_label_unit(variable_name)
variable = variable or var
unit = unit or un
else:
if variable is None or unit is None:
raise ValueError(
"If 'variable_name' is not provided, both 'variable' and 'unit' must be specified in kwargs or defaults."
)
# ----- EXTRACT MODEL AND SATELLITE KEYS FROM THE DATASET -----
model_key, sat_key = extract_mod_sat_keys(data_dict)
# ----- DEFINE MONTH NAMES -----
months = [calendar.month_abbr[i] for i in range(1, 13)]
# ----- INITIALIZE ARRAY FOR DATA -----
plot_data = []
# ----- GATHER THE DATA BASED ON MONTHS -----
plot_data = list(chain.from_iterable(
chain(
((val, f"{months[month_idx]} Model") for val in gather_monthly_data_across_years(data_dict, model_key, month_idx)),
((val, f"{months[month_idx]} Satellite") for val in gather_monthly_data_across_years(data_dict, sat_key, month_idx))
)
for month_idx in range(12)
))
# ----- BUILD THE DATAFRAME -----
plot_df = pd.DataFrame(plot_data, columns=['Value', 'Label'])
# ----- INITIALIZE FIGURE -----
plt.figure(figsize=options.figsize, dpi=options.dpi)
# ----- PLOT DATA -----
ax = sns.boxplot(
x='Label', y='Value', data=plot_df,
palette=options.palette,
showfliers=options.showfliers
)
# ----- FORMATTING AND PLOT OPTIONS
ax.set_title(f'Monthly {variable} Comparison: Model vs Satellite',
fontsize=options.title_fontsize,
fontweight=options.title_fontweight)
ylabel = f'{variable} {unit}'
ax.set_ylabel(ylabel, fontsize=options.ylabel_fontsize)
ax.set_xlabel(options.xlabel)
ax.grid(True, linestyle='--', alpha=options.grid_alpha)
plt.xticks(rotation=options.xtick_rotation)
ax.tick_params(width=options.tick_width)
style_axes_spines(ax, linewidth=2, edgecolor='black')
plt.tight_layout()
# ----- CHECK IF FOLDER EXISTS -----
output_path = Path(options.output_path)
output_path.mkdir(parents=True, exist_ok=True)
# ----- PRINT AND SAVE PLOT -----
filename = f'{variable}_boxplot.png'
save_path = output_path / filename
plt.savefig(save_path)
plt.close()
###############################################################################
###############################################################################
[docs]
def violinplot(data_dict, **kwargs):
"""
Plot a violin plot comparing monthly model and satellite values.
Parameters
----------
data_dict : dict
Dictionary containing model and satellite data indexed by datetime.
Keyword Arguments
-----------------
output_path : str or Path
Required. Directory where the figure is saved.
variable_name : str
Short name used to infer full variable name and unit.
variable : str, optional
Full variable name (e.g., "Chlorophyll"). Used in the title.
unit : str, optional
Unit of measurement (e.g., "mg Chl/m³"). Shown on the y-axis.
figsize : tuple of float, optional
Size of the figure in inches.
dpi : int, optional
Resolution of the plot.
palette : list or dict, optional
Colors for the violin plots.
cut : float, optional
Defines how far the violin extends past extreme datapoints.
title_fontsize : int, optional
Font size of the title.
title_fontweight : str or int, optional
Font weight of the title (e.g., 'bold').
ylabel_fontsize : int, optional
Font size of the y-axis label.
xlabel_fontsize : int, optional
Font size of the x-axis label.
xtick_rotation : int, optional
Degree of x-tick label rotation.
tick_width : float, optional
Width of axis ticks.
spine_linewidth : float, optional
Line width for axis spines.
grid_alpha : float, optional
Transparency of grid lines.
Returns
-------
None
Saves the violin plot figure to the specified output directory.
Example
-------
>>> violinplot(
... data_dict={'model': model_series, 'satellite': sat_series},
... variable_name='SST',
... output_path='figures/',
... figsize=(12, 6)
... )
"""
# ----- FETCH DEFAULT OPTIONS -----
options = SimpleNamespace(**{**default_violinplot_options, **kwargs})
# ----- OBTAIN NECESSARY VALUES -----
if options.output_path is None:
raise ValueError("output_path must be specified either in kwargs or default options.")
if options.variable_name is not None:
variable, unit = get_variable_label_unit(options.variable_name)
options.variable = options.variable or variable
options.unit = options.unit or unit
elif options.variable is None or options.unit is None:
raise ValueError("If 'variable_name' is not provided, both 'variable' and 'unit' must be specified.")
# ----- FETCH MOD AND SAT KEYS FROM DICTIONARY -----
model_key, sat_key = extract_mod_sat_keys(data_dict)
# ----- DEFINE MONTH NAMES -----
months = [calendar.month_abbr[i] for i in range(1, 13)]
# ----- GATHER DATA -----
plot_data = list(chain.from_iterable(
chain(
((val, f"{months[month_idx]} Model") for val in gather_monthly_data_across_years(data_dict, model_key, month_idx)),
((val, f"{months[month_idx]} Satellite") for val in gather_monthly_data_across_years(data_dict, sat_key, month_idx))
)
for month_idx in range(12)
))
# ----- SWITCH TO PANDAS DATAFRAME -----
plot_df = pd.DataFrame(plot_data, columns=['Value', 'Label'])
# ----- DEFINE FIGURE ------
plt.figure(figsize=options.figsize, dpi=options.dpi)
# ----- PLOT DATA -----
ax = sns.violinplot(x='Label', y='Value', data=plot_df,
palette=options.palette, cut=options.cut)
# ----- FORMATTING THE PLOT -----
ax.set_title(f'Monthly {options.variable} Comparison: Model vs Satellite',
fontsize=options.title_fontsize, fontweight=options.title_fontweight)
ax.set_ylabel(f'{options.variable} {options.unit}', fontsize=options.ylabel_fontsize)
ax.set_xlabel('', fontsize=options.xlabel_fontsize)
ax.grid(True, linestyle='--', alpha=options.grid_alpha)
plt.xticks(rotation=options.xtick_rotation)
ax.tick_params(width=options.tick_width)
style_axes_spines(ax, linewidth=2, edgecolor='black')
plt.tight_layout()
# ----- CHECK IF PROVIDED FOLDER EXISTS -----
output_path = Path(options.output_path)
output_path.mkdir(parents=True, exist_ok=True)
# ----- PRINT THE PLOT AND SAVE -----
filename = f'{options.variable}_violinplot.png'
plt.savefig(output_path / filename)
plt.close()
###############################################################################
###############################################################################
[docs]
def efficiency_plot(total_value, monthly_values, **kwargs):
"""
Plot efficiency metric (e.g., NSE) for each month with color-coded markers.
Parameters
----------
total_value : float
The efficiency value computed over the full time period (used as reference line).
monthly_values : list of float
Efficiency values for each month (12 total).
Keyword Arguments
-----------------
output_path : str or Path
Required. Directory where the figure is saved.
metric_name : str
Required. Used for the filename (e.g., "NSE").
title : str, optional
Title of the plot (e.g., "Nash-Sutcliffe Efficiency").
y_label : str, optional
LaTeX-formatted y-axis label (e.g., "$E_{rel}$").
figsize : tuple of float, optional
Size of the figure in inches.
dpi : int, optional
Resolution of the plot.
line_color : str, optional
Color of the line connecting monthly points.
line_width : float, optional
Width of the connecting line.
marker_size : float, optional
Size of circular markers.
marker_edge_color : str, optional
Color of marker edge.
marker_edge_width : float, optional
Width of marker edge.
xtick_rotation : int, optional
Degree of x-tick label rotation.
tick_width : float, optional
Width of axis ticks.
spine_width : float, optional
Width of axis spines.
legend_loc : str, optional
Location of the legend.
grid_style : str, optional
Style of grid lines (e.g., "--", ":").
zero_line : dict, optional
Zero reference line options, keys:
- show (bool)
- style (str)
- width (float)
- color (str)
- label (str)
overall_line : dict, optional
Overall mean line options, keys:
- style (str)
- width (float)
- color (str)
- label (str)
Returns
-------
None
Saves the efficiency metric plot to the specified output directory.
Example
-------
>>> plot_efficiency_metric(
... total_value=0.75,
... monthly_values=[0.6, 0.7, 0.8, 0.7, 0.75, 0.76, 0.77, 0.78, 0.74, 0.72, 0.7, 0.69],
... output_path='figures/',
... metric_name='NSE',
... title='Nash-Sutcliffe Efficiency',
... y_label='$E_{rel}$',
... figsize=(10, 6)
... )
"""
# ----- FETCH DEFAULT OPTIONS -----
options = SimpleNamespace(**{**default_efficiency_plot_options, **kwargs})
# --- OPTIONS VALIDATION ---
if options.output_path is None:
raise ValueError("output_path must be specified.")
if options.metric_name is None:
raise KeyError("metric_name must be specified.")
if options.y_label is None:
raise KeyError("y_label must be specified.")
# --- PREPARE DATA ---
months = list(calendar.month_name[1:13])
df = pd.DataFrame({'Month': months, 'Value': monthly_values})
# --- CMAP ---
cmap = plt.cm.RdYlGn
norm = mcolors.Normalize(vmin=0, vmax=1)
marker_colors = [
'gray' if not isinstance(val, (int, float)) else
'red' if val < 0 else
'green' if val > 1 else
cmap(norm(val))
for val in monthly_values
]
# --- SNS/FIGURE SETUP ---
sns.set(style="whitegrid")
sns.set_style("ticks")
plt.figure(figsize=options.figsize, dpi=options.dpi)
ax = sns.lineplot(x='Month', y='Value', data=df,
color=options.line_color,
lw=options.line_width)
# --- XERO LINE PLOT ---
if options.zero_line.get("show", False) and options.title in {
"Nash-Sutcliffe Efficiency",
"Nash-Sutcliffe Efficiency (Logarithmic)",
"Modified NSE ($E_1$, j=1)",
"Relative NSE ($E_{rel}$)"
}:
ax.axhline(0,
linestyle=options.zero_line["style"],
lw=options.zero_line["width"],
color=options.zero_line["color"],
label=options.zero_line["label"])
# --- OVERALL LINE PLOT ---
ax.axhline(total_value,
linestyle=options.overall_line["style"],
lw=options.overall_line["width"],
color=options.overall_line["color"],
label=options.overall_line["label"])
# --- MARKERS ---
for month, value, color in itertools.zip_longest(months, monthly_values, marker_colors):
if value is not None:
ax.plot(month, value, marker='o',
markersize=options.marker_size,
color=color,
markeredgecolor=options.marker_edge_color,
markeredgewidth=options.marker_edge_width)
# --- FORMATTING ---
ax.set_title(options.title, fontsize=options.title_fontsize)
ax.set_xlabel('')
ax.set_ylabel(f'${options.y_label}$', fontsize=options.ylabel_fontsize)
ax.set_xticks(range(len(months)))
ax.set_xticklabels(months, rotation=options.xtick_rotation)
ax.tick_params(width=options.tick_width)
ax.legend(loc=options.legend_loc)
ax.grid(True, linestyle=options.grid_style)
style_axes_spines(ax, linewidth=options.spine_width, edgecolor='black')
plt.tight_layout()
# ----- PRINT AND SAVE -----
output_path = Path(options.output_path)
output_path.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path / f'{options.metric_name}.png')
plt.close()
###############################################################################
###############################################################################
[docs]
def plot_spatial_efficiency(data_array, geo_coords, output_path, title_prefix, **kwargs):
"""
Plot spatial efficiency metric maps (e.g., correlation, NSE) by month or year with Cartopy projection.
Generates a grid of spatial maps showing the spatial distribution of a performance metric
(e.g., correlation, NSE) between model and satellite data for each month or year. Supports
extensive customization via keyword arguments and a centralized defaults dictionary.
Parameters
----------
data_array : xarray.DataArray
3D data with shape (month/year, lat, lon). Must contain either a 'month' or 'year' dimension.
geo_coords : dict
Dictionary containing geographic info:
- latp (2D array): Latitude grid.
- lonp (2D array): Longitude grid.
- MinLambda, MaxLambda (float): Longitude bounds.
- MinPhi, MaxPhi (float): Latitude bounds.
- Epsilon (float, optional): Spatial padding offset for label adjustment.
output_path : str or Path
Directory where the figure will be saved.
title_prefix : str
Title prefix for colorbar and subplot titles (e.g., "Correlation").
Keyword Arguments
-----------------
cmap : str or Colormap, optional
Colormap to use (default: "coolwarm").
vmin : float, optional
Minimum value for colorbar.
vmax : float, optional
Maximum value for colorbar.
suffix : str, optional
Suffix for plot title and filename.
suptitle_fontsize : int, optional
Font size of the super title (reduced by 6 if only one column).
suptitle_fontweight : str, optional
Font weight of the super title.
suptitle_y : float, optional
Vertical position of the super title.
title_fontsize : int, optional
Font size of subplot titles.
title_fontweight : str, optional
Font weight of subplot titles.
cbar_labelsize : int, optional
Font size of colorbar tick labels.
cbar_labelpad : int, optional
Padding between colorbar and label.
cbar_shrink : float, optional
Shrink factor for horizontal colorbar.
cbar_ticks : int, optional
Number of colorbar ticks.
figsize_per_plot : tuple, optional
Size (width, height) per subplot.
max_cols : int, optional
Maximum number of columns in subplot grid.
epsilon : float, optional
Padding fallback if not in geo_coords.
lat_offset_base : float, optional
Extra latitude offset if needed.
gridline_color : str, optional
Color of gridlines.
gridline_style : str, optional
Line style of gridlines (e.g., "--").
gridline_alpha : float, optional
Transparency of gridlines.
gridline_dms : bool, optional
Format labels in degrees-minutes-seconds.
gridline_labels_top : bool, optional
Show gridline labels on top axis.
gridline_labels_right : bool, optional
Show gridline labels on right axis.
projection : str, optional
Cartopy projection class name.
resolution : str, optional
Resolution of coastlines (e.g., "10m").
land_color : str, optional
Color for landmasses.
show : bool, optional
Display the plot interactively.
block : bool, optional
Block execution on plt.show().
dpi : int, optional
Resolution of the output figure.
Raises
------
ValueError
If the `data_array` does not contain a 'month' or 'year' dimension.
Returns
-------
None
Saves the spatial efficiency maps to the specified output directory.
Examples
--------
>>> plot_spatial_efficiency(
... data_array, geo_coords, "figures", "Correlation",
... cmap="coolwarm", vmax=1.0, vmin=-1.0, show=True
... )
"""
# ----- GET DEFAULT OPTIONS -----
options = extract_options(kwargs, spatial_efficiency_defaults)
# ----- SET GEOMETRY -----
latp = geo_coords['latp']
lonp = geo_coords['lonp']
epsilon = geo_coords.get("Epsilon", options["epsilon"])
lat_offset = epsilon + options["lat_offset_base"]
min_lon = geo_coords['MinLambda']
max_lon = geo_coords['MaxLambda']
min_lat = geo_coords['MinPhi']
max_lat = geo_coords['MaxPhi']
# ----- FETCH THE TIME LABELS -----
if 'month' in data_array.dims:
labels = ["January", "February", "March", "April", "May", "June",
"July", "August", "September", "October", "November", "December"]
dim_name = 'month'
n_plots = data_array.sizes['month']
suptitle_prefix = "Monthly"
filename_prefix = "Monthly"
elif 'year' in data_array.dims:
labels = data_array.year.values.astype(str)
dim_name = 'year'
n_plots = data_array.sizes['year']
suptitle_prefix = "Yearly"
filename_prefix = "Yearly"
else:
raise ValueError("data_array must have either 'month' or 'year' dimension.")
# ----- COMPUTE OPTIMAL GRID LAYOUT -----
nrows = int(np.ceil(np.sqrt(n_plots)))
ncols = int(np.ceil(n_plots / nrows))
# ----- SETUP THE FIGURE AND SIZE -----
figsize = ((options["figsize_per_plot"][0] if ncols !=1 else options["figsize_per_plot"][0] + 3) * ncols,
(options["figsize_per_plot"][1] if ncols !=1 else options["figsize_per_plot"][1] + 3) * nrows)
fig = plt.figure(figsize=figsize, dpi=options["dpi"], constrained_layout=True)
gs = GridSpec(nrows, ncols, figure=fig)
axes = []
# ----- ADD SUBPLOTS DYNAMICALLY, CENTER LAST ROW IF NEEDED -----
axes = []
if n_plots > 6 and (n_plots % ncols != 0):
last_row_filled = n_plots % ncols
pad_left = (ncols - last_row_filled) // 2
plot_index = 0
for row in range(nrows):
for col in range(ncols):
# If it's the last row and we're before the pad_left, skip to pad
if row == nrows - 1 and col < pad_left:
continue
# If all plots are placed, break
if plot_index >= n_plots:
break
ax = fig.add_subplot(gs[row, col], projection=getattr(ccrs, options["projection"])())
axes.append(ax)
plot_index += 1
else:
for i in range(n_plots):
row = i // ncols
col = i % ncols
ax = fig.add_subplot(gs[row, col], projection=getattr(ccrs, options["projection"])())
axes.append(ax)
# ---- BUILD ORANGE - GREEEN COLORMAP -----
cmap = options["cmap"]
vmin = options["vmin"]
vmax = options["vmax"]
if cmap == "OrangeGreen":
colors = ['#086e04', 'white', '#ff6700']
base_cmap = mcolors.LinearSegmentedColormap.from_list("custom_cmap", colors)
norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
white_start = int(norm(-0.2) * 255)
white_end = int(norm(0.2) * 255)
new_colors = base_cmap(np.linspace(0, 1, 256))
new_colors[white_start:white_end, :] = [1, 1, 1, 1]
cmap = mcolors.ListedColormap(new_colors)
contour_levels = np.linspace(vmin, vmax, 11)
# ----- BEGIN PLOTTING -----
for i, ax in enumerate(axes):
data_slice = data_array.isel({dim_name: i})
ax.set_extent([min_lon, max_lon, min_lat + lat_offset, max_lat], crs=ccrs.PlateCarree())
contour = ax.contourf(
lonp + (0.35 * epsilon), latp + (0.1 * epsilon), data_slice,
levels=contour_levels, cmap=cmap, vmin=vmin, vmax=vmax,
extend="both", transform=ccrs.PlateCarree()
)
ax.coastlines(resolution=options["resolution"])
ax.add_feature(cfeature.LAND, facecolor=options["land_color"])
label = labels[i] if i < len(labels) else f"{dim_name.capitalize()} {i+1}"
ax.set_title(label, fontsize=options["title_fontsize"], fontweight=options["title_fontweight"])
gl = ax.gridlines(draw_labels=True, dms=options["gridline_dms"],
color=options["gridline_color"], linestyle=options["gridline_style"],
alpha=options["gridline_alpha"])
gl.top_labels = options["gridline_labels_top"]
gl.right_labels = options["gridline_labels_right"]
# ----- ADD A COLORBAR -----
unit = options["unit"]
# Only format if unit is not None, else use empty string
if unit is not None:
formatted_unit = format_unit(unit)[1:-1]
else:
formatted_unit = ""
if ncols == 1:
cbar = fig.colorbar(contour, ax=axes, orientation="horizontal",
shrink=options["cbar_shrink"]+0.1,
ticks=np.linspace(vmin, vmax, options["cbar_ticks"]))
if formatted_unit:
cbar.set_label(rf'$\left[{formatted_unit}\right]$', fontsize=16, labelpad=options["cbar_labelpad"])
cbar.ax.tick_params(labelsize=14)
else:
cbar.set_label("", fontsize=16, labelpad=options["cbar_labelpad"])
cbar.ax.tick_params(labelsize=14)
else:
cbar = fig.colorbar(contour, ax=axes, orientation="horizontal",
shrink=(options["cbar_shrink"]),
ticks=np.linspace(vmin, vmax, options["cbar_ticks"]))
if formatted_unit:
cbar.set_label(rf'$\left[{formatted_unit}\right]$', fontsize=22, labelpad=options["cbar_labelpad"])
cbar.ax.tick_params(labelsize=16)
else:
cbar.set_label("", fontsize=22, labelpad=options["cbar_labelpad"])
cbar.ax.tick_params(labelsize=16)
# ----- MAKE THE TITLE -----
detrended = kwargs.get("detrended", False)
det_text = "Detrended" if detrended else "Raw"
# Prepare unit string for title safely:
unit_title = unit if unit is not None else ""
if ncols == 1:
plt.suptitle(
f"{suptitle_prefix} {title_prefix} {unit_title} \n ({det_text}) \n {options['suffix']}",
fontsize=(options["suptitle_fontsize"]-6),
fontweight=options["suptitle_fontweight"],
)
else :
plt.suptitle(
f"{suptitle_prefix} {title_prefix} {unit_title} ({det_text}) \n {options['suffix']}",
fontsize=options["suptitle_fontsize"],
fontweight=options["suptitle_fontweight"],
)
# ----- SET OUTPUT PATH -----
output_path = Path(output_path)
output_path.mkdir(parents=True, exist_ok=True)
safe_title = title_prefix.replace("/", "_").replace("\\", "_")
filename = f"{filename_prefix} {safe_title} ({det_text}) {options['suffix']}.png"
plt.savefig(output_path / filename, bbox_inches='tight', pad_inches=0.2)
plt.close()
###############################################################################
###############################################################################
[docs]
def error_components_timeseries(
stats_df,
output_path,
cloud_cover=None,
variable_name='',
**kwargs
):
"""
Plot time series of error components and optional cloud cover.
Generates a multi-panel plot showing:
- Mean Bias
- Unbiased RMSE
- Standard Deviation of Error
- Correlation
- (Optional) Cloud Cover with smoothed version
Parameters
----------
stats_df : pd.DataFrame
DataFrame with time series of statistical error components. Must include columns:
['mean_bias', 'unbiased_rmse', 'std_error', 'correlation'].
output_path : str or Path
Directory path where the figure will be saved.
cloud_cover : pd.Series, optional
Time series of cloud cover (%) data. Adds extra subplot if provided.
variable_name : str, optional
Variable name (e.g., "SST", "Chlorophyll") used for labeling.
Keyword Arguments
-----------------
fig_width : float, optional
Width of the full figure.
fig_height_per_plot : float, optional
Height allocated per subplot.
sharex : bool, optional
If True, subplots share the x-axis.
title_fontsize : int, optional
Font size of the figure title.
title_fontweight : str, optional
Font weight of the figure title.
label_fontsize : int, optional
Font size of y-axis labels.
grid_color : str, optional
Grid line color.
grid_linestyle : str, optional
Grid line style (e.g., '--').
grid_alpha : float, optional
Transparency of grid lines.
mean_bias_color : str, optional
Line color for Mean Bias subplot.
unbiased_rmse_color : str, optional
Line color for Unbiased RMSE subplot.
std_error_color : str, optional
Line color for Std Error subplot.
correlation_color : str, optional
Line color for Correlation subplot.
cloud_cover_color : str, optional
Line color for raw Cloud Cover.
cloud_cover_smoothed_color : str, optional
Line color for smoothed Cloud Cover.
cloud_cover_rolling_window : int, optional
Window size for rolling smoothing of cloud cover.
spine_linewidth : float, optional
Width of axes spines.
spine_edgecolor : str, optional
Color of axes spines.
filename_template : str, optional
Template for saved filename (e.g., '{}_errors.png').
Returns
-------
None
Saves the multi-panel error component plot to the output path.
Example
-------
>>> error_components_timeseries(
... stats_df=error_df,
... cloud_cover=cloud_series,
... output_path="figures/",
... variable_name="SST"
... )
Notes
-----
- Uses Seaborn and Matplotlib styling.
- Default styles and colors controlled by `default_error_timeseries_options`.
"""
# ----- OPTIONS -----
options = extract_options(kwargs, default_error_timeseries_options)
# ----- STYLE -----
sns.set(style="whitegrid", context='notebook')
sns.set_style("ticks")
# ----- SETUP -----
n_plots = 5 if cloud_cover is not None else 4
fig, axes = plt.subplots(
n_plots, 1,
figsize=(options['fig_width'], options['fig_height_per_plot'] * n_plots),
sharex=options['sharex']
)
# ----- TITLE -----
title = "Comparison between error components timeseries"
if cloud_cover is not None:
title += " and cloud cover"
if variable_name:
title += f" ({variable_name})"
fig.suptitle(title, fontsize=options['title_fontsize'], fontweight=options['title_fontweight'])
fig.subplots_adjust(top=0.85)
# ----- GRID STYLE -----
grid_style = {
'color': options['grid_color'],
'linestyle': options['grid_linestyle'],
'alpha': options['grid_alpha']
}
# ----- MEAN BIAS -----
stats_df['mean_bias'].plot(ax=axes[0], color=options['mean_bias_color'], legend=False)
axes[0].set_ylabel('Mean Bias', fontsize=options['label_fontsize'])
axes[0].grid(**grid_style)
# ----- UNBIASED RMSE -----
stats_df['unbiased_rmse'].plot(ax=axes[1], color=options['unbiased_rmse_color'], legend=False)
axes[1].set_ylabel('Unbiased RMSE', fontsize=options['label_fontsize'])
axes[1].grid(**grid_style)
# ----- STD ERROR -----
stats_df['std_error'].plot(ax=axes[2], color=options['std_error_color'], legend=False)
axes[2].set_ylabel('Std Error', fontsize=options['label_fontsize'])
axes[2].grid(**grid_style)
# ----- CORRELATION -----
stats_df['cross_correlation'].plot(ax=axes[3], color=options['correlation_color'], legend=False)
axes[3].set_ylabel('Correlation', fontsize=options['label_fontsize'])
axes[3].grid(**grid_style)
# ----- CLOUD COVER -----
if cloud_cover is not None:
cloud_cover_30d = cloud_cover.rolling(window=options['cloud_cover_rolling_window'], center=True).mean()
axes[4].plot(cloud_cover.index, cloud_cover, color=options['cloud_cover_color'], label='Cloud Cover')
axes[4].plot(cloud_cover_30d.index, cloud_cover_30d, color=options['cloud_cover_smoothed_color'], label='30-day Smoothed')
axes[4].set_ylabel('Cloud Cover (%)', fontsize=options['label_fontsize'])
axes[4].grid(**grid_style)
axes[4].legend()
else:
axes[3].set_xlabel('')
axes[3].grid(**grid_style)
# ----- STYLE AXES -----
for ax in axes:
ax.label_outer()
style_axes_spines(
ax,
linewidth=options['spine_linewidth'],
edgecolor=options['spine_edgecolor']
)
# ----- LAYOUT -----
plt.tight_layout(rect=[0, 0, 1, 0.96])
# ----- SAVE FIGURE -----
output_path = Path(output_path)
output_path.mkdir(parents=True, exist_ok=True)
filename = options['filename_template'].format(variable_name=variable_name)
plt.savefig(output_path / filename)
plt.close()
###############################################################################
###############################################################################
[docs]
def plot_spectral(
data=None,
plot_type='PSD',
freqs=None,
fft_components=None,
error_comp=None,
cloud_covers=None,
output_path=None,
variable_name=None,
fs=1.0,
nperseg=256,
**kwargs
):
"""
Plot spectral analysis of time series data using either PSD or CSD.
Parameters
----------
data : pd.Series or dict, optional
Optional time series input (not used directly in current implementation).
plot_type : str
Type of spectral plot to generate: 'PSD' (Power Spectral Density) or 'CSD' (Cross Spectral Density).
freqs : array-like, optional
Frequency values used for PSD plotting.
fft_components : dict of arrays, optional
Dictionary mapping labels to FFT-transformed series for PSD plotting.
error_comp : pd.DataFrame or dict, optional
Error component data used in CSD plotting.
cloud_covers : list of (pd.Series, str), optional
List of tuples with cloud cover time series and their labels (used in CSD).
output_path : str or Path
Path to save the resulting spectral plot.
variable_name : str
Short code name of the variable, used in output filename.
fs : float, optional
Sampling frequency (default: 1.0).
nperseg : int, optional
Segment length for computing CSD (default: 256).
Keyword Arguments
-----------------
figsize : tuple, optional
Figure size (e.g., (12, 6)).
xlabel_fontsize : int, optional
Font size of the x-axis label.
ylabel_fontsize : int, optional
Font size of the y-axis label.
title_fontsize : int, optional
Font size of the plot title.
title_fontweight : str, optional
Font weight of the plot title (e.g., 'bold').
tick_labelsize : int, optional
Font size of tick labels.
grid_color : str, optional
Grid line color.
grid_alpha : float, optional
Grid line transparency.
grid_linestyle : str, optional
Grid line style (e.g., '--').
freq_xlim : tuple, optional
Frequency axis limits (e.g., (0.0, 0.5)).
additional_linestyles : list, optional
Linestyles for multiple cloud cover series (e.g., ['--', '-.', ':']).
spine_linewidth : float, optional
Width of plot spines.
spine_edgecolor : str, optional
Color of plot spines.
Raises
------
ValueError
If required inputs are missing or an unknown plot_type is provided.
"""
# ----- OPTIONS -----
options = extract_options(kwargs, default_spectral)
# ----- STYLE -----
sns.set(style="whitegrid", context="notebook")
sns.set_style("ticks")
# ----- FIGURE -----
plt.figure(figsize=options['figsize'])
# ----- PSD -----
if plot_type == 'PSD':
if freqs is None or fft_components is None:
raise ValueError("freqs and fft_components must be provided for PSD plot")
for col, fft_vals in fft_components.items():
if np.all(np.abs(fft_vals) == 0):
continue
continue
with np.errstate(divide='ignore', invalid='ignore'):
plt.plot(freqs, np.abs(fft_vals), label=col)
plt.xlabel('Frequency (1/day)', fontsize=options['xlabel_fontsize'])
plt.ylabel('Aplitude', fontsize=options['ylabel_fontsize'])
plt.title('Power Spectral Density (PSD)', fontsize=options['title_fontsize'], fontweight=options['title_fontweight'])
# ----- CSD -----
elif plot_type == 'CSD':
if error_comp is None:
raise ValueError("error_comp must be provided for CSD plot")
if cloud_covers is None or len(cloud_covers) == 0:
raise ValueError("At least one cloud_cover tuple (data,label) must be provided in cloud_covers")
additional_styles = cycle(options['additional_linestyles'])
columns = error_comp.columns if hasattr(error_comp, 'columns') else list(error_comp.keys())
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']
var_colors = {col: colors[i % len(colors)] for i, col in enumerate(columns)}
for i, (cloud_cover, label) in enumerate(cloud_covers):
linestyle = '-' if i == 0 else next(additional_styles)
for col in columns:
try:
f, Pxy = csd(error_comp[col], cloud_cover, fs=fs, nperseg=nperseg)
if np.all(np.abs(Pxy) == 0):
continue
plt.semilogy(f, np.abs(Pxy), linestyle=linestyle, color=var_colors[col], label=f'{col} vs {label}')
except ZeroDivisionError:
continue
plt.xlabel('Frequency (1/day)', fontsize=options['xlabel_fontsize'])
plt.ylabel('Cross Power', fontsize=options['ylabel_fontsize'])
labels = ', '.join([label for _, label in cloud_covers])
plt.title(f'Cross-Spectral Density with {labels}', fontsize=options['title_fontsize'], fontweight=options['title_fontweight'])
# ----- UNKNOWN TYPE -----
else:
raise ValueError(f"Unknown plot_type: {plot_type}")
# ----- FINAL FORMATTING -----
plt.legend()
plt.grid(True, color=options['grid_color'], alpha=options['grid_alpha'], linestyle=options['grid_linestyle'])
plt.xlim(*options['freq_xlim'])
plt.tick_params(axis='both', which='major', labelsize=options['tick_labelsize'])
# ----- SPINES -----
ax = plt.gca()
ax.spines['top'].set_visible(True)
ax.spines['right'].set_visible(True)
ax.spines['bottom'].set_visible(True)
ax.spines['left'].set_visible(True)
# ----- STYLE SPINES -----
style_axes_spines(ax, linewidth=options['spine_linewidth'], edgecolor=options['spine_edgecolor'])
# ----- SAVE FIGURE -----
output_path = Path(output_path)
output_path.mkdir(parents=True, exist_ok=True)
filename = f"Spectral_Plot_{plot_type}_{variable_name}"
plt.savefig(output_path / filename)
plt.close()