Source code for Hydrological_model_validator.Processing.Taylor_computations

###############################################################################
##                                                                           ##
##                               LIBRARIES                                   ##
##                                                                           ##
###############################################################################

# Data handling libraries
import skill_metrics as sm
import numpy as np
import pandas as pd
from typing import Dict, List, Union, Tuple, Optional

# Logging and tracing
import logging
from eliot import start_action, log_message

# Module utilities
from Hydrological_model_validator.Processing.time_utils import Timer
from Hydrological_model_validator.Processing.data_alignment import (get_common_series_by_year, 
                                                                    get_valid_mask, 
                                                                    extract_mod_sat_keys)

###############################################################################
##                                                                           ##
##                               FUNCTIONS                                   ##
##                                                                           ##
###############################################################################


def compute_taylor_stat_tuple(mod_values: np.ndarray,
                              sat_values: np.ndarray,
                              label: str) -> Tuple[str, float, float, float]:
    """
    Compute Taylor statistics (standard deviation, centered RMSD, and
    correlation coefficient) for a given pair of model and satellite data
    arrays.

    Only positions where both arrays hold finite values are used.

    Parameters
    ----------
    mod_values : np.ndarray
        Array of model data values.
    sat_values : np.ndarray
        Array of satellite data values. Must have the same shape as
        ``mod_values``.
    label : str
        Identifier associated with these data (e.g., year or month string).

    Returns
    -------
    Tuple[str, float, float, float]
        A tuple containing:
        - label (str): The input label
        - model standard deviation (float)
        - centered RMSD (float)
        - correlation coefficient (float)

    Raises
    ------
    ValueError
        If either input is not a NumPy array.
        If input arrays are empty.
        If the two arrays do not have the same shape.
        If no finite data pairs exist between model and satellite arrays.

    Example
    -------
    >>> mod = np.array([1.1, 2.0, 3.2])
    >>> sat = np.array([1.0, 2.1, 3.0])
    >>> compute_taylor_stat_tuple(mod, sat, '2001')
    ('2001', ..., ..., ...)
    """
    # =====INPUT VALIDATION=====
    # Ensure inputs are numpy arrays
    if not isinstance(mod_values, np.ndarray) or not isinstance(sat_values, np.ndarray):
        raise ValueError("❌ Both 'mod_values' and 'sat_values' must be NumPy arrays. ❌")

    # Ensure inputs are not empty
    if mod_values.size == 0 or sat_values.size == 0:
        raise ValueError("❌ Input arrays must not be empty. ❌")

    # Ensure the arrays can be compared element by element. Without this
    # check a mismatch only surfaces further down, as a broadcasting error
    # or an IndexError raised by the boolean indexing.
    if mod_values.shape != sat_values.shape:
        raise ValueError(
            "❌ 'mod_values' and 'sat_values' must have the same shape. "
            f"Got {mod_values.shape} and {sat_values.shape}. ❌"
        )

    with Timer(f"compute_taylor_stat_tuple: {label}"):
        with start_action(action_type="compute_taylor_stat_tuple", label=label):
            logging.info(f"Computing Taylor stats for label: {label}")
            log_message(f"Start compute_taylor_stat_tuple for label {label}")

            # =====VALID DATA MASK=====
            # Keep only the positions where both series are finite
            valid_mask = np.isfinite(mod_values) & np.isfinite(sat_values)
            if not np.any(valid_mask):
                raise ValueError("❌ No valid finite data pairs found in input arrays. ❌")

            # =====FILTER DATA=====
            mod_valid = mod_values[valid_mask]
            sat_valid = sat_values[valid_mask]

            # =====COMPUTE STATISTICS=====
            # The model is passed first and the satellite second; entry [1]
            # of each returned array is the one reported for the model.
            stats = sm.taylor_statistics(mod_valid, sat_valid, 'data')
            result = (label, stats['sdev'][1], stats['crmsd'][1], stats['ccoef'][1])

            logging.info(f"Computed Taylor stats for {label}: {result[1:]}")
            log_message(f"Completed compute_taylor_stat_tuple for label {label}")

            return result
############################################################################### ###############################################################################
def compute_std_reference(sat_data_by_year: Dict[Union[int, str], List[Union[np.ndarray, list]]],
                          years: List[Union[int, str]],
                          month_index: int) -> float:
    """
    Compute the reference standard deviation for satellite data of a specific
    month across multiple years.

    Parameters
    ----------
    sat_data_by_year : dict
        Dictionary keyed by year (int or str), with each value being a list
        of monthly data arrays.
    years : list
        List of years (int or str) to include in the computation. Years that
        are not keys of ``sat_data_by_year`` are skipped.
    month_index : int
        Index of the month (0 = January). Built-in and NumPy integers are
        both accepted. Years whose monthly list is too short to contain this
        index are skipped.

    Returns
    -------
    float
        NaN-ignoring standard deviation of the concatenated satellite values
        for the specified month across all given years, as a built-in float.

    Raises
    ------
    ValueError
        If 'month_index' is not a non-negative integer.
        If no valid data is found for the specified month across the
        selected years.
        If any matched monthly array is empty.

    Example
    -------
    >>> sat_data_by_year = {
    ...     2000: [np.random.rand(10) for _ in range(6)],  # up to June
    ...     2001: [np.random.rand(10) for _ in range(6)]
    ... }
    >>> std = compute_std_reference(sat_data_by_year, [2000, 2001], 2)  # March
    >>> isinstance(std, float)
    True
    """
    # =====INPUT VALIDATION=====
    # Accept NumPy integers too: indices produced by NumPy or pandas are not
    # instances of the built-in int.
    if not isinstance(month_index, (int, np.integer)) or month_index < 0:
        raise ValueError(f"❌ 'month_index' must be a non-negative integer. Got {month_index}. ❌")

    # Normalise to a built-in int so logging/tracing fields are plain values
    month_index = int(month_index)

    with Timer(f"compute_std_reference: month {month_index}"):
        with start_action(action_type="compute_std_reference", month_index=month_index):
            logging.info(f"Computing std reference for month_index: {month_index}")
            log_message(f"Start compute_std_reference for month_index {month_index}")

            # =====COLLECT MONTHLY DATA=====
            monthly_data = []

            for year in years:
                # Skip year if not present in satellite data dictionary
                if year not in sat_data_by_year:
                    continue

                monthly_series = sat_data_by_year[year]

                # Skip if month_index out of range for that year
                if month_index >= len(monthly_series):
                    continue

                # Flatten monthly data to 1D so years can be concatenated
                arr = np.asarray(monthly_series[month_index]).ravel()

                # A present-but-empty month is treated as an input error
                if arr.size == 0:
                    raise ValueError(f"❌ Empty data array for year {year}, month index {month_index}. ❌")

                monthly_data.append(arr)

            # =====VALIDATION=====
            if not monthly_data:
                raise ValueError(f"❌ No valid satellite data found for month index {month_index} across given years. ❌")

            # =====CONCATENATE & COMPUTE STD=====
            all_monthly_sat = np.concatenate(monthly_data)
            # Convert the NumPy scalar to a built-in float so the return
            # value honours the annotation whatever the input dtype
            # (e.g. float32 data would otherwise yield np.float32).
            std_value = float(np.nanstd(all_monthly_sat))

            logging.info(f"Computed std reference: {std_value} for month_index: {month_index}")
            log_message(f"Completed compute_std_reference with std {std_value} for month_index {month_index}")

            return std_value
############################################################################### ###############################################################################
def compute_norm_taylor_stats(mod_vals: np.ndarray,
                              sat_vals: np.ndarray,
                              std_ref: float) -> Optional[Dict[str, float]]:
    """
    Compute normalized Taylor statistics for a given pair of model and
    satellite data arrays.

    Parameters
    ----------
    mod_vals : np.ndarray
        Array of model data values.
    sat_vals : np.ndarray
        Array of satellite data values.
    std_ref : float
        Reference standard deviation used to normalize the statistics.
        Built-in numbers and NumPy real scalars (e.g. ``np.float32``) are
        accepted; the value must be finite and strictly positive.

    Returns
    -------
    dict or None
        Dictionary containing:
        - 'sdev' : Normalized model standard deviation
        - 'crmsd': Normalized centered root-mean-square difference
        - 'ccoef': Correlation coefficient (not normalized)
        Returns None if there are no valid overlapping values.

    Raises
    ------
    ValueError
        If `std_ref` is not a finite, positive number.

    Example
    -------
    >>> mod = np.array([1.0, 2.0, 3.0])
    >>> sat = np.array([1.1, 2.1, 3.1])
    >>> std_ref = 0.5
    >>> stats = compute_norm_taylor_stats(mod, sat, std_ref)
    >>> stats.keys()
    dict_keys(['sdev', 'crmsd', 'ccoef'])
    """
    # =====INPUT VALIDATION=====
    # NumPy scalars are accepted because reductions such as np.nanstd return
    # them, and np.float32 is not an instance of the built-in float.
    # Non-finite values are rejected explicitly: NaN compares False against
    # `<= 0` and would otherwise slip through and turn every statistic
    # into NaN.
    is_real_number = isinstance(std_ref, (int, float, np.integer, np.floating))
    if not is_real_number or not np.isfinite(std_ref) or std_ref <= 0:
        raise ValueError(f"❌ 'std_ref' must be a positive number. Got {std_ref}. ❌")

    # Work with a built-in float from here on
    std_ref = float(std_ref)

    with Timer("compute_norm_taylor_stats"):
        with start_action(action_type="compute_norm_taylor_stats", std_ref=std_ref):
            logging.info(f"Computing normalized Taylor stats with std_ref={std_ref}")
            log_message(f"Start compute_norm_taylor_stats with std_ref={std_ref}")

            # =====VALID DATA MASK=====
            # Project helper; presumably a boolean mask of positions that are
            # usable in both arrays (confirm against data_alignment).
            valid = get_valid_mask(mod_vals, sat_vals)

            if not np.any(valid):
                logging.info("No valid overlapping data points found.")
                log_message("No valid overlapping data points found, returning None")
                return None

            # =====COMPUTE TAYLOR STATISTICS=====
            # Entry [1] of each returned array is the one used for the model
            stats = sm.taylor_statistics(mod_vals[valid], sat_vals[valid], 'data')

            # =====NORMALIZE & RETURN=====
            # The correlation coefficient is dimensionless, so only the
            # standard deviation and the CRMSD are scaled by the reference.
            result = {
                "sdev": stats['sdev'][1] / std_ref,
                "crmsd": stats['crmsd'][1] / std_ref,
                "ccoef": stats['ccoef'][1],
            }

            logging.info(f"Computed normalized stats: {result}")
            log_message(f"Completed compute_norm_taylor_stats with result {result}")

            return result
############################################################################### ###############################################################################
def build_all_points(
    data_dict: Dict[Union[str, int], Dict[int, List[Union[np.ndarray, list]]]]
) -> Tuple[pd.DataFrame, List[Union[str, int]]]:
    """
    Build a DataFrame of normalized Taylor statistics points for all months
    and years, including reference points per month.

    Parameters
    ----------
    data_dict : dict
        Dictionary containing model and satellite data structured as:
        {
            'model_key': { year: [monthly data arrays/lists] },
            'satellite_key': { year: [monthly data arrays/lists] }
        }
        Years can be strings or integers, months are indexed 0-based.

    Returns
    -------
    tuple of (pandas.DataFrame, list)
        - DataFrame with columns ['sdev', 'crmsd', 'ccoef', 'month', 'year']
          containing normalized Taylor statistics for each year and month,
          plus monthly reference points. The columns are present even when
          no point could be built.
        - Sorted list of years found in the satellite data.

    Raises
    ------
    KeyError
        If expected model or satellite keys are missing in `data_dict`.

    Notes
    -----
    - Months with invalid or zero reference standard deviation are skipped.
    - Year/month combinations missing from either dataset are skipped.
    - The reference point per month has sdev=1, crmsd=0, ccoef=1, labeled
      year='Ref'.

    Example
    -------
    >>> data_dict = {
    ...     'model': {
    ...         2000: [np.array([...]), np.array([...]), ...],
    ...         2001: [np.array([...]), np.array([...]), ...],
    ...     },
    ...     'satellite': {
    ...         2000: [np.array([...]), np.array([...]), ...],
    ...         2001: [np.array([...]), np.array([...]), ...],
    ...     }
    ... }
    >>> df, years = build_all_points(data_dict)
    >>> df.head()
       sdev  crmsd  ccoef  month  year
    0  1.00   0.00   1.00      0   Ref
    1  0.85   0.12   0.95      0  2000
    2  0.88   0.10   0.96      0  2001
    ...
    """
    columns = ["sdev", "crmsd", "ccoef", "month", "year"]

    with Timer("build_all_points"):
        with start_action(action_type="build_all_points"):
            # =====EXTRACT MODEL AND SATELLITE KEYS=====
            mod_key, sat_key = extract_mod_sat_keys(data_dict)

            # =====VALIDATE PRESENCE OF REQUIRED KEYS=====
            if mod_key not in data_dict or sat_key not in data_dict:
                raise KeyError(f"❌ Expected keys '{mod_key}' and '{sat_key}' not found in data_dict. ❌")

            model_data_by_year = data_dict[mod_key]
            sat_data_by_year = data_dict[sat_key]

            # =====SORT YEARS FROM SATELLITE DATA=====
            years = sorted(sat_data_by_year)

            # =====DETERMINE MAXIMUM MONTHS AVAILABLE=====
            # default=0 keeps an empty satellite dictionary from failing with
            # "max() arg is an empty sequence"; the result is then simply an
            # empty table.
            max_months = max((len(sat_data_by_year[year]) for year in years), default=0)

            # =====COMPUTE REFERENCE STANDARD DEVIATION PER MONTH=====
            std_refs = {}
            for month_idx in range(max_months):
                try:
                    # Store a built-in float so downstream validation and the
                    # logged values do not depend on the NumPy scalar type of
                    # the input data.
                    std_refs[month_idx] = float(
                        compute_std_reference(sat_data_by_year, years, month_idx)
                    )
                except ValueError:
                    # Skip months without valid reference standard deviation
                    continue

            # =====BUILD DATA POINTS INCLUDING REFERENCE AND NORMALIZED STATS=====
            all_points = []
            for month_idx, std_ref in std_refs.items():
                if std_ref <= 0 or np.isnan(std_ref):
                    continue

                # Add reference point for perfect agreement in this month
                all_points.append({
                    "sdev": 1.0,
                    "crmsd": 0.0,
                    "ccoef": 1.0,
                    "month": month_idx,
                    "year": "Ref",
                })

                for year in years:
                    try:
                        mod_vals = np.asarray(model_data_by_year[year][month_idx])
                        sat_vals = np.asarray(sat_data_by_year[year][month_idx])
                    except (IndexError, KeyError):
                        # Year or month absent from one of the datasets
                        continue

                    norm_stats = compute_norm_taylor_stats(mod_vals, sat_vals, std_ref)
                    if norm_stats is None:
                        continue

                    all_points.append({**norm_stats, "month": month_idx, "year": year})

            # Passing the columns explicitly guarantees the documented schema
            # even when all_points is empty.
            df = pd.DataFrame(all_points, columns=columns)

            log_message(f"Built {len(df)} Taylor stat points across {len(years)} years and {len(std_refs)} months.")
            logging.info(f"build_all_points completed: {len(df)} points, years={years}, months={list(std_refs.keys())}")

            return df, years
############################################################################### ###############################################################################
def compute_yearly_taylor_stats(
    data_dict: Dict[Union[str, int], Dict[int, List[Union[np.ndarray, list]]]]
) -> Tuple[List[Tuple[str, float, float, float]], float]:
    """
    Compute Taylor statistics for each year using model and satellite data
    from the data dictionary. Also computes the global standard deviation of
    all satellite data.

    Parameters
    ----------
    data_dict : dict
        Dictionary containing model and satellite data organized by year and
        month. Expected structure:
        {
            'model_key': { year: [monthly data arrays/lists] },
            'satellite_key': { year: [monthly data arrays/lists] }
        }

    Returns
    -------
    tuple
        - yearly_stats : list of tuples
            List of (year, sdev, crmsd, ccoef) tuples representing Taylor
            statistics for each year. The year is converted to a string.
        - std_ref : float
            Global satellite standard deviation (NaN-ignoring) across all
            years and months, used as a normalization reference.

    Raises
    ------
    KeyError
        If expected model or satellite keys are missing in data_dict.
    ValueError
        If the satellite dataset contains no non-empty monthly array.
        If global satellite standard deviation is zero or NaN (indicating
        invalid data).

    Example
    -------
    >>> yearly_stats, std_ref = compute_yearly_taylor_stats(data_dict)
    >>> for year, sdev, crmsd, ccoef in yearly_stats:
    ...     print(f"{year}: sdev={sdev:.2f}, crmsd={crmsd:.2f}, ccoef={ccoef:.2f}")
    ...
    2000: sdev=0.85, crmsd=0.12, ccoef=0.95
    2001: sdev=0.88, crmsd=0.10, ccoef=0.96
    ...
    """
    # =====EXTRACT MODEL AND SATELLITE KEYS=====
    mod_key, sat_key = extract_mod_sat_keys(data_dict)

    # =====CHECK THAT KEYS EXIST IN INPUT DICTIONARY=====
    if mod_key not in data_dict or sat_key not in data_dict:
        raise KeyError(f"❌ Expected keys '{mod_key}' and '{sat_key}' not found in data_dict. ❌")

    sat_data_by_year = data_dict[sat_key]

    # =====FLATTEN ALL SATELLITE DATA ACROSS ALL YEARS AND MONTHS=====
    # Each monthly entry is converted once; missing (None) and empty months
    # are left out of the global sample.
    sat_chunks = []
    for year_data in sat_data_by_year.values():
        for month_array in year_data:
            if month_array is None:
                continue
            flat = np.asarray(month_array).ravel()
            if flat.size > 0:
                sat_chunks.append(flat)

    # np.concatenate refuses an empty list with an unrelated-looking message,
    # so report the real problem instead.
    if not sat_chunks:
        raise ValueError(
            "No non-empty satellite arrays found in data_dict; "
            "cannot compute the global standard deviation."
        )

    # =====COMPUTE GLOBAL STANDARD DEVIATION AS REFERENCE=====
    # Converted to a built-in float so the return value honours the
    # annotation regardless of the input dtype.
    std_ref = float(np.nanstd(np.concatenate(sat_chunks)))

    # =====VALIDATE STANDARD DEVIATION=====
    if np.isnan(std_ref) or std_ref == 0:
        raise ValueError("Global satellite standard deviation is zero or NaN, indicating invalid or missing data.")

    # =====START TIMER AND ELIOT TRACING FOR THE MAIN COMPUTATION=====
    with Timer("compute_yearly_taylor_stats"):
        with start_action(action_type="compute_yearly_taylor_stats"):
            # =====ALIGN DATA BY YEAR TO OBTAIN COMMON SERIES FOR MODEL AND SATELLITE=====
            # Iterated as (year, model values, satellite values) triples
            aligned_data = get_common_series_by_year(data_dict)

            # =====COMPUTE TAYLOR STATISTICS FOR EACH YEAR=====
            yearly_stats = [
                compute_taylor_stat_tuple(mod_values, sat_values, str(year))
                for year, mod_values, sat_values in aligned_data
            ]

            log_message(f"Computed yearly Taylor stats for {len(yearly_stats)} years with std_ref={std_ref:.3f}.")
            logging.info(f"compute_yearly_taylor_stats completed: years={len(yearly_stats)}, std_ref={std_ref}")

    # =====RETURN THE LIST OF YEARLY STATS AND GLOBAL STD REF=====
    return yearly_stats, std_ref