Source code for Hydrological_model_validator.Processing.time_utils

###############################################################################
##                                                                           ##
##                               LIBRARIES                                   ##
##                                                                           ##
###############################################################################

# Type hinting
from typing import Dict, List, Union, Any, Tuple

# Data handling
import pandas as pd
import numpy as np
import xarray as xr
import itertools

# Parallel processing and diagnostics
from dask.diagnostics import ProgressBar
from concurrent.futures import ThreadPoolExecutor

# Logging and tracing
import logging
from eliot import start_action, log_message
from eliot.stdlib import EliotHandler

# Date and time handling
from datetime import datetime

# Context management
from contextlib import contextmanager

###############################################################################
##                                                                           ##
##                               FUNCTIONS                                   ##
##                                                                           ##
###############################################################################


def leapyear(year: int) -> int:
    """
    Tell whether a Gregorian calendar year is a leap year.

    Parameters
    ----------
    year : int
        Year to test. Must be an ``int`` strictly greater than zero.

    Returns
    -------
    int
        1 when `year` is a leap year, 0 otherwise.

    Raises
    ------
    ValueError
        If `year` is not a positive integer.

    Examples
    --------
    >>> leapyear(2020)
    1
    >>> leapyear(1900)
    0
    >>> leapyear(2000)
    1
    >>> leapyear(2023)
    0
    """
    # Guard clause: anything that is not a strictly positive int is rejected
    # before any timing/logging machinery is started.
    if not isinstance(year, int) or year <= 0:
        raise ValueError("❌ Year must be a positive integer. ❌")

    with Timer("leapyear function"), start_action(action_type="leapyear function", year=year):
        # Multiples of 4 are leap years, except century years, which only
        # qualify when they are also multiples of 400.
        is_leap = year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
        result = int(is_leap)

        log_message("Computed leap year result", year=year, result=result)
        verdict = "a leap year" if result else "not a leap year"
        logging.info(f"Year {year} is {verdict}.")

    return result
############################################################################### ###############################################################################
def true_time_series_length(
    chlfstart: List[int],
    chlfend: List[int],
    DinY: int
) -> int:
    """
    Calculate the true time series length in days over multiple files,
    accounting for leap years.

    Every year in each inclusive range ``chlfstart[i]..chlfend[i]`` contributes
    ``DinY`` days plus one extra day when `leapyear` reports a leap year.
    Ranges are simply chained: a year that appears in more than one range is
    counted once per occurrence.

    Parameters
    ----------
    chlfstart : list[int]
        List of start years per file.
    chlfend : list[int]
        List of end years per file (inclusive), same length as `chlfstart`.
    DinY : int
        Number of days in a normal year. Must be exactly 365.

    Returns
    -------
    int
        Total number of days including leap days.

    Raises
    ------
    ValueError
        If the two lists differ in length, contain non-integers, contain a
        range whose end precedes its start, or if `DinY` is not 365.
        `leapyear` additionally raises ValueError for non-positive years.

    Examples
    --------
    >>> true_time_series_length([2000], [2001], 365)
    731
    >>> true_time_series_length([1999, 2001], [2000, 2002], 365)
    1461
    """
    # ===== VALIDATION =====
    # Start and end lists must pair up one-to-one
    if len(chlfstart) != len(chlfend):
        raise ValueError("❌ chlfstart and chlfend must be the same length. ❌")

    # Every year must be an int
    if not all(isinstance(x, int) for x in itertools.chain(chlfstart, chlfend)):
        raise ValueError("❌ chlfstart and chlfend must contain integers only. ❌")

    # Each range must run forwards (or cover a single year)
    if not all(end >= start for start, end in zip(chlfstart, chlfend)):
        raise ValueError("❌ Each chlfend must be >= corresponding chlfstart. ❌")

    # Only the standard year length is accepted; the leap day is added below
    if not (isinstance(DinY, int) and DinY == 365):
        raise ValueError("❌ DinY must be 365. ❌")

    with Timer("true_time_series_length computation"):
        logging.info(f"Inputs received - chlfstart: {chlfstart}, chlfend: {chlfend}, DinY: {DinY}")
        log_message("Inputs received", chlfstart=chlfstart, chlfend=chlfend, DinY=DinY)

        # ===== CALCULATION =====
        # Lazily walk every year covered by the (inclusive) file ranges
        years = itertools.chain.from_iterable(
            range(start, end + 1) for start, end in zip(chlfstart, chlfend)
        )

        # Accumulate days year by year; leapyear() returns 1 or 0
        total_days = 0
        for year in years:
            days = DinY + leapyear(year)
            total_days += days
            logging.info(f"Year: {year}, days added: {days}")
            log_message("Year days added", year=year, days=days)

        logging.info(f"Computed total_days: {total_days}")
        log_message("Computed total_days", total_days=total_days)

    return total_days
############################################################################### ###############################################################################
def split_to_monthly(
    yearly_data: Dict[int, Union[pd.Series, pd.DataFrame]]
) -> Dict[int, List[Union[pd.Series, pd.DataFrame]]]:
    """
    Break each year's datetime-indexed data into twelve monthly pieces.

    Parameters
    ----------
    yearly_data : dict[int, pd.Series or pd.DataFrame]
        Mapping from year to a pandas Series or DataFrame whose index is
        datetime-like.

    Returns
    -------
    dict[int, list[pd.Series or pd.DataFrame]]
        Mapping from year to a list of 12 objects (January first, December
        last). A month without rows is represented by an empty object of the
        same type as the input for that year.

    Raises
    ------
    ValueError
        If `yearly_data` is not a dict, a key is not an int, a value is not a
        pandas Series/DataFrame, or a value's index is not datetime-like.

    Examples
    --------
    >>> import pandas as pd
    >>> import numpy as np
    >>> dates = pd.date_range('2020-01-01', '2020-12-31')
    >>> data = pd.Series(np.random.rand(len(dates)), index=dates)
    >>> monthly = split_to_monthly({2020: data})
    >>> len(monthly[2020])
    12
    >>> monthly[2020][0].index.month.unique().tolist()
    [1]
    """
    # ===== VALIDATION =====
    if not isinstance(yearly_data, dict):
        raise ValueError("❌ Input yearly_data must be a dictionary. ❌")

    for year, data in yearly_data.items():
        # Keys must be integer years
        if not isinstance(year, int):
            raise ValueError(f"❌ Year keys must be int, got {type(year)}. ❌")
        # Values must be pandas containers
        if not isinstance(data, (pd.Series, pd.DataFrame)):
            raise ValueError(f"❌ Values must be pandas Series or DataFrame, got {type(data)}. ❌")
        # The index must carry datetimes so months can be extracted
        if not pd.api.types.is_datetime64_any_dtype(data.index):
            raise ValueError(f"❌ Data index for year {year} must be datetime-like. ❌")

    with Timer("Splitting yearly data into monthly segments"):
        monthly_data_dict = {}

        for year, data in yearly_data.items():
            kind = type(data).__name__
            logging.info(f"Processing year {year}, data type: {kind}")
            log_message("Processing year", year=year, data_type=kind)

            # Zero-row slice of the same type, used for months with no data
            blank = data.iloc[0:0]
            # Month number of every row, computed once per year
            month_of_row = data.index.month

            # ===== SPLITTING =====
            pieces = []
            for month in range(1, 13):
                chunk = data.loc[month_of_row == month]

                if not chunk.empty:
                    # Month has data: keep the filtered rows
                    pieces.append(chunk)
                    n_rows = len(chunk)
                    logging.info(f"Month slice added for year {year}, month {month}, length {n_rows}")
                    log_message("Month slice added", year=year, month=month, length=n_rows)
                    continue

                # Month has no data: store an independent empty placeholder
                pieces.append(blank.copy())
                logging.info(f"Empty month slice added for year {year}, month {month}")
                log_message("Empty month slice added", year=year, month=month)

            monthly_data_dict[year] = pieces

        n_years = len(monthly_data_dict)
        logging.info(f"Completed splitting all years, total years processed: {n_years}")
        log_message("Completed splitting all years", years_processed=n_years)

    return monthly_data_dict
############################################################################### ###############################################################################
def split_to_yearly(
    series: pd.Series,
    unique_years: List[Union[int, str]]
) -> Dict[Union[int, str], pd.Series]:
    """
    Split a pandas Series with a datetime index into a dictionary keyed by year.

    Parameters
    ----------
    series : pd.Series
        Time-indexed pandas Series; its index must be a ``pd.DatetimeIndex``.
    unique_years : list of int or str
        Years to extract. Each entry is converted with ``int()`` for the
        comparison, but the result is keyed by the entry exactly as given.

    Returns
    -------
    dict of year (int or str) to pd.Series
        For every requested year, a copy of the rows of `series` whose index
        falls in that year. A year with no matching rows maps to an empty
        Series.

    Raises
    ------
    ValueError
        If `series` does not have a DatetimeIndex, if `unique_years` contains
        anything other than ints or strs, or if a string year cannot be
        converted to int (the original conversion error is chained as the
        cause).

    Examples
    --------
    >>> import pandas as pd
    >>> dates = pd.date_range('2020-01-01', periods=365)
    >>> s = pd.Series(range(365), index=dates)
    >>> split_yearly = split_to_yearly(s, [2020])
    >>> list(split_yearly.keys())
    [2020]
    >>> split_yearly[2020].index.year.unique().tolist()
    [2020]
    """
    # ===== VALIDATION =====
    # A datetime index is required to read the year of each row
    if not isinstance(series.index, pd.DatetimeIndex):
        raise ValueError("❌ series must have a DatetimeIndex ❌")

    # Only ints and strings are accepted as year labels
    if not all(isinstance(year, (int, str)) for year in unique_years):
        raise ValueError("❌ unique_years must contain only ints or strs ❌")

    with Timer("split_to_yearly computation"):
        logging.info(f"Received series with index type: {type(series.index)}")
        log_message("Input validation passed", unique_years=unique_years)

        # The year of every row does not depend on the requested year,
        # so extract it once instead of once per loop iteration.
        index_years = series.index.year

        yearly_data = {}

        # ===== SPLITTING =====
        for year in unique_years:
            try:
                year_int = int(year)  # Convert year to int for comparison
            except ValueError as err:
                # Keep the underlying conversion error attached as the cause
                raise ValueError(f"❌ Year {year} cannot be converted to int ❌") from err

            filtered = series.loc[index_years == year_int].copy()
            yearly_data[year] = filtered  # Keyed by the caller's original label

            logging.info(f"Year {year} filtered with {len(filtered)} records")
            log_message("Yearly data filtered", year=year, count=len(filtered))

        logging.info(f"Split completed for years: {list(yearly_data.keys())}")
        log_message("Completed splitting to yearly", years=list(yearly_data.keys()))

    return yearly_data
############################################################################### ###############################################################################
def get_common_years(
    data_dict: Dict[str, Dict[Union[int, str], Any]],
    mod_key: str,
    sat_key: str
) -> List[Union[int, str]]:
    """
    Return the sorted years shared by the model and satellite datasets.

    Parameters
    ----------
    data_dict : dict
        Mapping from dataset name to a dictionary keyed by year.
    mod_key : str
        Name of the model dataset inside `data_dict`.
    sat_key : str
        Name of the satellite dataset inside `data_dict`.

    Returns
    -------
    List[int or str]
        Sorted list of the year keys that occur in both datasets.

    Raises
    ------
    ValueError
        If `mod_key` or `sat_key` is missing from `data_dict`, or if either
        entry is not a dictionary.

    Examples
    --------
    >>> data = {
    ...     'model': {2020: 'data1', 2021: 'data2'},
    ...     'satellite': {2021: 'dataA', 2022: 'dataB'}
    ... }
    >>> get_common_years(data, 'model', 'satellite')
    [2021]
    """
    # ===== VALIDATE KEYS =====
    # Presence is checked for both keys first, then the type of both entries,
    # always model before satellite.
    for label, key in (("mod_key", mod_key), ("sat_key", sat_key)):
        if key not in data_dict:
            raise ValueError(f"❌ {label} '{key}' not found in data_dict ❌")

    for key in (mod_key, sat_key):
        if not isinstance(data_dict[key], dict):
            raise ValueError(f"❌ data_dict[{key}] must be a dictionary ❌")

    with Timer("get_common_years computation"):
        logging.info(f"Validated presence and type of keys: '{mod_key}', '{sat_key}'")
        log_message("Validated keys", mod_key=mod_key, sat_key=sat_key)

        # ===== COMPUTE COMMON YEARS =====
        # Year keys present in both datasets, in ascending order
        shared = set(data_dict[mod_key]) & set(data_dict[sat_key])
        common_years = sorted(shared)

        logging.info(f"Common years computed: {common_years}")
        log_message("Computed common years", common_years=common_years)

        return common_years
############################################################################### ###############################################################################
def get_season_mask(
    dates: Union[pd.DatetimeIndex, pd.Series],
    season_name: str
) -> np.ndarray:
    """
    Build a boolean mask selecting the dates that belong to a season.

    Parameters
    ----------
    dates : pd.DatetimeIndex or pd.Series
        Either a DatetimeIndex, or a Series whose *index* is a DatetimeIndex
        (the Series values are not used).
    season_name : str
        One of {'DJF', 'MAM', 'JJA', 'SON'}.

    Returns
    -------
    np.ndarray
        Boolean array, one entry per date, True where the month of the date
        belongs to the requested season.

    Raises
    ------
    ValueError
        If `season_name` is not one of the four accepted names.
    TypeError
        If `dates` is neither a DatetimeIndex nor a Series indexed by one.

    Examples
    --------
    >>> import pandas as pd
    >>> dates = pd.date_range('2023-01-01', periods=12, freq='M')
    >>> get_season_mask(dates, 'DJF')
    array([ True,  True, False, False, False, False, False, False, False,
           False, False,  True])
    """
    # ===== VALIDATE SEASON NAME =====
    valid_seasons = {'DJF', 'MAM', 'JJA', 'SON'}
    if season_name not in valid_seasons:
        raise ValueError(f"❌ Invalid season name: {season_name}. Expected one of {valid_seasons} ❌")

    # ===== VALIDATE DATES TYPE AND EXTRACT DATETIME INDEX =====
    if isinstance(dates, pd.DatetimeIndex):
        dt_index = dates
    elif isinstance(dates, pd.Series):
        # For a Series the timestamps live in the index
        dt_index = dates.index
        if not isinstance(dt_index, pd.DatetimeIndex):
            raise TypeError("❌ If input is pd.Series, its index must be a pd.DatetimeIndex ❌")
    else:
        raise TypeError("❌ dates must be a pandas DatetimeIndex or a Series with DatetimeIndex ❌")

    with Timer(f"get_season_mask computation for season '{season_name}'"):
        logging.info(f"Validated season_name: {season_name}")
        log_message("Validated season_name", season_name=season_name)

        n_dates = len(dt_index)
        logging.info(f"Using datetime index of length {n_dates}")
        log_message("Datetime index info", length=n_dates)

        # ===== MAP SEASON TO MONTHS =====
        season_months = {
            'DJF': [12, 1, 2],
            'MAM': [3, 4, 5],
            'JJA': [6, 7, 8],
            'SON': [9, 10, 11],
        }
        wanted_months = season_months[season_name]

        # True wherever the month number is one of the season's months
        mask = np.isin(dt_index.month, wanted_months)

        logging.info(f"Computed season mask for '{season_name}'")
        log_message("Computed season mask", season_name=season_name, mask_shape=mask.shape)

    return mask
############################################################################### ###############################################################################
def resample_and_compute(
    model_sst_chunked: Union[xr.DataArray, xr.Dataset],
    sat_sst_chunked: Union[xr.DataArray, xr.Dataset]
) -> Tuple[Union[xr.DataArray, xr.Dataset], Union[xr.DataArray, xr.Dataset]]:
    """
    Resample two chunked SST datasets to monthly means and compute both at once.

    Each input is resampled along ``time`` with the ``'1MS'`` (month-start)
    rule and averaged. The two resulting objects are then computed from two
    worker threads while a dask ``ProgressBar`` is active.

    Parameters
    ----------
    model_sst_chunked : xarray.DataArray or Dataset
        The model SST dataset chunked for dask processing.
    sat_sst_chunked : xarray.DataArray or Dataset
        The satellite SST dataset chunked for dask processing.

    Returns
    -------
    model_sst_monthly : xarray.DataArray or Dataset
        The computed monthly mean of the model SST.
    sat_sst_monthly : xarray.DataArray or Dataset
        The computed monthly mean of the satellite SST.
    """
    with Timer("Resample and compute monthly SST datasets concurrently"):
        # ===== RESAMPLE TO MONTHLY MEAN (LAZY) =====
        # Model first, then satellite; nothing is evaluated yet
        lazy_model = model_sst_chunked.resample(time='1MS').mean()
        lazy_sat = sat_sst_chunked.resample(time='1MS').mean()

        logging.info("Resampled model and satellite SST to monthly means (lazy)")
        log_message("Resampled SST to monthly means", model_resample="lazy", sat_resample="lazy")

        # ===== COMPUTE RESAMPLED DATA IN PARALLEL =====
        with ProgressBar(), ThreadPoolExecutor(max_workers=2) as executor:
            # One task per dataset, each using dask's threaded scheduler
            pending = [
                executor.submit(lazy.compute, scheduler='threads')
                for lazy in (lazy_model, lazy_sat)
            ]
            # Collect in submission order: model, then satellite
            model_sst_monthly, sat_sst_monthly = [task.result() for task in pending]

        logging.info("Computed monthly SST datasets concurrently")
        log_message("Computed SST datasets", model_computed=True, satellite_computed=True)

    return model_sst_monthly, sat_sst_monthly
############################################################################### ###############################################################################
def is_invalid_time_index(
    time_index: Union[pd.Index, np.ndarray]
) -> bool:
    """
    Check whether a given time index is invalid based on dtype and value range.

    The index is reported as invalid when either:

    - its dtype is not a datetime64 type (timezone-aware datetimes count as
      datetime64 here), or
    - every timestamp lies between the Unix epoch (1970-01-01) and one day
      after it, there are at least two timestamps, and every difference
      between consecutive timestamps is below 1 millisecond. Such an index
      might indicate corrupted or placeholder data.

    Timezone-aware input is converted to UTC before being compared with the
    epoch.

    Parameters
    ----------
    time_index : array-like
        A pandas Index or NumPy array of time values; it must expose a
        ``dtype`` attribute.

    Returns
    -------
    bool
        True if the time index is considered invalid, False otherwise.
        An empty datetime index or a single timestamp is never flagged by the
        value-range rule.

    Example
    -------
    >>> import numpy as np
    >>> import pandas as pd
    >>> # Valid datetime index
    >>> idx = pd.date_range("2023-01-01", periods=3)
    >>> is_invalid_time_index(idx)
    False
    >>> # Invalid: non-datetime dtype
    >>> is_invalid_time_index(np.array([1, 2, 3]))
    True
    """
    # pandas' dtype check understands NumPy datetime64 as well as pandas
    # extension dtypes (e.g. tz-aware datetimes), which np.issubdtype cannot
    # interpret and rejects with a TypeError.
    if not pd.api.types.is_datetime64_any_dtype(time_index.dtype):
        return True  # Not datetime, so invalid

    # Convert to pandas DatetimeIndex for easier time operations
    dt_index = pd.DatetimeIndex(time_index)

    # Compare on a common footing: tz-aware values are converted to UTC and
    # made naive so they can be compared with the naive epoch below.
    if dt_index.tz is not None:
        dt_index = dt_index.tz_convert(None)

    epoch = pd.Timestamp("1970-01-01")            # Unix epoch start
    window_end = epoch + pd.Timedelta('1 day')    # Upper bound of the suspicious range

    # For an empty index min()/max() are NaT, so these comparisons are False
    if (dt_index.min() >= epoch) and (dt_index.max() <= window_end):
        # Differences between consecutive timestamps (first entry has none)
        diffs = dt_index.to_series().diff().dropna()

        # Only flag as invalid if diffs exist and all are below 1 millisecond
        if len(diffs) > 0 and bool((diffs < pd.Timedelta(milliseconds=1)).all()):
            return True

    # Otherwise, the time index is valid
    return False
###############################################################################
def prompt_for_datetime_index(length: int) -> pd.DatetimeIndex:
    """
    Prompt the user to manually enter a valid datetime index for a time series.

    The user is asked for a start date and a frequency; a pandas
    DatetimeIndex with `length` entries is built from them. If pandas rejects
    the input, the error is printed and both questions are asked again until
    a valid index can be generated.

    Parameters
    ----------
    length : int
        The desired length of the datetime index to generate. Must not be
        negative.

    Returns
    -------
    pd.DatetimeIndex
        A pandas DatetimeIndex starting from the user-provided date, with the
        specified frequency and length.

    Raises
    ------
    ValueError
        If `length` is negative. No index can be built in that case, so the
        prompt would otherwise repeat forever whatever the user types.

    Example
    -------
    >>> idx = prompt_for_datetime_index(10)
    Enter the start date for the time series (e.g. 2000-01-01): 2020-01-01
    Enter the frequency (e.g. 'D' for daily, 'H' for hourly): D
    Generated datetime index from 2020-01-01 00:00:00 with frequency 'D'.
    """
    # Fail fast: pd.date_range rejects a negative number of periods, so every
    # retry below would fail regardless of the date/frequency entered.
    if length < 0:
        raise ValueError("❌ length must be a non-negative integer. ❌")

    # Notify user about the invalid or missing time index
    print("⚠️ Detected invalid or missing time index.")

    border = "#" * 60

    while True:
        # Prompt for start date input as string
        start_date_str = input("Enter the start date for the time series (e.g. 2000-01-01): ").strip()

        # Prompt for frequency string, e.g., 'D', 'H', 'M'
        freq = input("Enter the frequency (e.g. 'D' for daily, 'H' for hourly): ").strip()

        try:
            # Only the two pandas calls that parse user input are guarded
            start_date = pd.to_datetime(start_date_str)
            time_index = pd.date_range(start=start_date, periods=length, freq=freq)
        except Exception as e:
            # Deliberately broad: any parsing/generation error (e.g. invalid
            # date format or frequency) is reported and the user is asked again.
            print(f"Invalid input: {e}. Please try again.")
            print("\n" + border + "\n")
        else:
            # Inform user about the generated datetime index and return it
            print(f"Generated datetime index from {start_date} with frequency '{freq}'.")
            print("\n" + border + "\n")
            return time_index
############################################################################### ###############################################################################
def ensure_datetime_index(series: pd.Series, label: str) -> pd.Series:
    """
    Make sure a pandas Series is indexed by a DatetimeIndex, asking the user if not.

    A Series that already has a DatetimeIndex is returned untouched.
    Otherwise the user is asked once for a start date and a frequency, and a
    DatetimeIndex of the same length as the Series is built from the answers
    and assigned to ``series.index`` — the Series passed in is modified in
    place and then returned.

    Parameters
    ----------
    series : pd.Series
        The pandas Series whose index is checked and possibly replaced.
    label : str
        A descriptive name for the series, used in prompts and messages.

    Returns
    -------
    pd.Series
        The same Series object, with its original DatetimeIndex or with the
        newly assigned one.

    Example
    -------
    >>> s = pd.Series([1, 2, 3])
    >>> s = ensure_datetime_index(s, "Sample Series")
    Please enter the start date for Sample Series (YYYY-MM-DD): 2020-01-01
    Please enter the data frequency (e.g. 'D') for Sample Series: D
    DatetimeIndex created for Sample Series from 2020-01-01 with frequency 'D'.
    """
    # Nothing to do when the index is already datetime-based
    if isinstance(series.index, pd.DatetimeIndex):
        return series

    print(f"{label} has no DatetimeIndex.")

    # Ask for the two pieces of information needed to rebuild the index
    start_date_str = input(f"Please enter the start date for {label} (YYYY-MM-DD): ").strip()
    freq_str = input(f"Please enter the data frequency (e.g. 'D') for {label}: ").strip()

    # One timestamp per element of the series, starting at the given date
    rebuilt_index = pd.date_range(
        start=pd.to_datetime(start_date_str),
        periods=len(series),
        freq=freq_str,
    )
    series.index = rebuilt_index

    print(f"DatetimeIndex created for {label} from {start_date_str} with frequency '{freq_str}'.")
    print("\n" + "#" * 60 + "\n")

    return series
############################################################################### ## ## ## TIMER ## ## ## ############################################################################### # ----- SETUP OF THE LOGGER AND TIMER ----- # Clear existing handlers on root logger root_logger = logging.getLogger() if root_logger.hasHandlers(): root_logger.handlers.clear() # --- File handler: capture everything (including third-party libs) --- file_handler = logging.FileHandler("app.log", mode='a', encoding='utf-8') file_handler.setLevel(logging.INFO) # or DEBUG if you want file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) # --- Stream handler: only show YOUR logs in console --- stream_handler = logging.StreamHandler() stream_handler.setLevel(logging.WARNING) # Filter out INFO/DEBUG logs from chat stream_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) # Add handlers root_logger.addHandler(file_handler) root_logger.addHandler(stream_handler) root_logger.addHandler(EliotHandler()) root_logger.setLevel(logging.INFO) # Optional: raise level for noisy libs to WARNING so they skip stream logging.getLogger("xarray").setLevel(logging.INFO) # still captured in file, but not printed logging.getLogger("netCDF4").setLevel(logging.INFO) logging.getLogger("h5netcdf").setLevel(logging.INFO) # --- Timer context manager that logs to both systems ---
@contextmanager
def Timer(description):
    """
    Context manager that times the enclosed block and logs start/end.

    The start and end of the block are reported both through eliot
    (``log_message`` plus a ``"timed block"`` action) and through the standard
    ``logging`` module. The end-of-block reporting lives in a ``finally``
    clause, so it is also emitted when the block raises.

    Parameters
    ----------
    description : str
        Human-readable label for the timed block, included in every message.

    Yields
    ------
    None
        Nothing is bound by ``with Timer(...)``.
    """
    began = datetime.now()
    began_txt = str(began)

    # Announce the start through both logging systems
    log_message("Timer started", description=description, start=began_txt)
    logging.info(f"[Timer Start] {description} at {began}")

    with start_action(
        action_type="timed block",
        description=description,
        start_time=began_txt,
    ) as action:
        try:
            yield
        finally:
            # Wall-clock duration of the block, measured from `began`
            finished = datetime.now()
            duration = finished - began
            finished_txt = str(finished)
            duration_txt = str(duration)

            action.add_success_fields(end_time=finished_txt, elapsed=duration_txt)
            log_message("Timer ended", description=description, end=finished_txt, elapsed=duration_txt)
            logging.info(f"[Timer End] {description} at {finished} (Elapsed: {duration})")