# Source code for aurora.pipelines.transfer_function_helpers

"""
This module contains helper methods that are used during transfer function processing.

Development Notes:
 Note #1: Repeatedly applying edf_weights appears to have no effect at all.
 Tested 2024-01-18: test_compare in synthetic passed whether this step was commented
 out or not.  TODO: confirm this is a one-and-done operation and add documentation
 explaining why that is so.

"""

from aurora.time_series.frequency_band_helpers import get_band_for_tf_estimate
from aurora.time_series.xarray_helpers import handle_nan
from aurora.transfer_function.regression.base import RegressionEstimator
from aurora.transfer_function.regression.iter_control import IterControl
from aurora.transfer_function.regression.RME import RME
from aurora.transfer_function.regression.RME_RR import RME_RR

from aurora.transfer_function.weights.edf_weights import (
    effective_degrees_of_freedom_weights,
)
from mt_metadata.transfer_functions.processing.aurora.decimation_level import (
    DecimationLevel as AuroraDecimationLevel,
)
from loguru import logger
from typing import Literal, Optional, Union

import numpy as np
import xarray as xr

ESTIMATOR_LIBRARY = {"OLS": RegressionEstimator, "RME": RME, "RME_RR": RME_RR}
SUPPORTED_REGRESSION_ESTIMATOR = Literal["OLS", "RME", "RME_RR"]


def get_estimator_class(
    estimation_engine: SUPPORTED_REGRESSION_ESTIMATOR,
) -> "type[RegressionEstimator]":
    """
    Look up the regression estimator class for a given estimation engine name.

    Parameters
    ----------
    estimation_engine: Literal["OLS", "RME", "RME_RR"]
        One of the keys in the ESTIMATOR_LIBRARY, designates the method that
        will be used to estimate the transfer function

    Returns
    -------
    estimator_class: type[aurora.transfer_function.regression.base.RegressionEstimator]
        The class (not an instance) that will do the TF estimation.

    Raises
    ------
    KeyError
        If estimation_engine is not a key of ESTIMATOR_LIBRARY.
    """
    try:
        estimator_class = ESTIMATOR_LIBRARY[estimation_engine]
    except KeyError as error:
        # Previously this was a bare `raise Exception` which dropped both the
        # message and the original cause; raise a descriptive, chained error.
        msg = (
            f"processing_scheme {estimation_engine} not supported; "
            f"must be one of {list(ESTIMATOR_LIBRARY.keys())}"
        )
        logger.error(msg)
        raise KeyError(msg) from error
    return estimator_class
def set_up_iter_control(config: AuroraDecimationLevel) -> Optional[IterControl]:
    """
    Initializes an IterControl object based on values in the processing config.

    Development Notes:
    TODO: Review: maybe better to just make this the __init__ method of the
    IterControl object, iter_control = IterControl(config)

    Parameters
    ----------
    config: AuroraDecimationLevel
        metadata about the decimation level processing.

    Returns
    -------
    iter_control: aurora.transfer_function.regression.iter_control.IterControl or None
        Object with parameters about iteration control in regression.
        None when the engine is "OLS" (a direct solve needs no iteration).

    Raises
    ------
    ValueError
        If the configured estimator engine is not one of the supported engines.
    """
    engine = config.estimator.engine
    if engine in ("RME", "RME_RR"):
        iter_control = IterControl(
            max_number_of_iterations=config.regression.max_iterations,
            max_number_of_redescending_iterations=config.regression.max_redescending_iterations,
            r0=config.regression.r0,
            u0=config.regression.u0,
            tolerance=config.regression.tolerance,
        )
    elif engine == "OLS":
        iter_control = None
    else:
        # Previously an unknown engine fell through with iter_control unbound,
        # producing an opaque UnboundLocalError; fail explicitly instead.
        msg = f"estimator engine {engine} not supported by set_up_iter_control"
        logger.error(msg)
        raise ValueError(msg)
    return iter_control
def drop_nans(X: xr.Dataset, Y: xr.Dataset, RR: Union[xr.Dataset, None]) -> tuple:
    """
    Drops any observation where any variable in X, Y, or RR is NaN (or infinite).

    Parameters
    ----------
    X: xarray.Dataset
        Input channels; must have an 'observation' dimension.
    Y: xarray.Dataset
        Output channels; must have an 'observation' dimension.
    RR: xarray.Dataset or None
        Remote reference channels (optional).

    Returns
    -------
    X, Y, RR: tuple
        Same datasets restricted to observations that are finite in every
        data variable of every provided dataset.
    """

    def _finite_observation_mask(ds):
        """
        Boolean mask over the 'observation' dimension: True where every data
        variable in ds is finite (not NaN or infinite) for that observation.
        """
        mask = None
        for var in ds.data_vars.values():
            # Reduce over all axes except 'observation'
            reduce_axes = tuple(
                i for i, dim in enumerate(var.dims) if dim != "observation"
            )
            finite = np.isfinite(var)
            if reduce_axes:
                finite = finite.all(axis=reduce_axes)
            mask = finite if mask is None else mask & finite
        return mask

    # np is the module-level numpy import; the original re-imported it locally.
    mask = _finite_observation_mask(X) & _finite_observation_mask(Y)
    if RR is not None:
        mask = mask & _finite_observation_mask(RR)
    X = X.isel(observation=mask)
    Y = Y.isel(observation=mask)
    if RR is not None:
        RR = RR.isel(observation=mask)
    return X, Y, RR
def stack_fcs(X, Y, RR):
    """
    Reshape 2D arrays of frequency and time to 1D.

    Notes: When the data for a frequency band are extracted from the
    Spectrogram, each channel is a 2D array: one axis is time (the time of the
    window that was FFT-ed) and the other axis is frequency.  If we make no
    distinction between the harmonics (bins) within a band in regression, all
    the FCs for each channel can be collapsed into a single 1D 'observation'
    axis.  How the FCs are unravelled does not matter, but the same indexing
    scheme must be used for X, Y and RR.

    TODO: Consider this take a list and return a list rather than X,Y,RR
    TODO: Consider decorate this with @dataset_or_dataarray

    Parameters
    ----------
    X: xarray.core.dataset.Dataset
    Y: xarray.core.dataset.Dataset
    RR: xarray.core.dataset.Dataset or None

    Returns
    -------
    X, Y, RR: Same as input but with stacked time and frequency dimensions
    """
    stacked_dims = ("frequency", "time")
    X = X.stack(observation=stacked_dims)
    Y = Y.stack(observation=stacked_dims)
    if RR is None:
        return X, Y, RR
    RR = RR.stack(observation=stacked_dims)
    return X, Y, RR
def apply_weights(
    X: xr.Dataset,
    Y: xr.Dataset,
    RR: xr.Dataset,
    W: np.ndarray,
    segment: bool = False,
    dropna: bool = False,
) -> tuple:
    """
    Applies data weights (W) to each of X, Y, RR.
    If a weight is zero, the weighted value becomes nan; optionally drop nans.

    Notes
    -----
    X, Y and RR are multiplied IN PLACE via ``*=``; some callers (e.g.
    process_transfer_functions_with_weights) ignore the return value and rely
    on that mutation.  W itself is NOT modified (the previous implementation
    clobbered the caller's weight array by writing nan into it).

    Parameters
    ----------
    X: xarray.core.dataset.Dataset
    Y: xarray.core.dataset.Dataset
    RR: xarray.core.dataset.Dataset or None
    W: numpy array
        The Weights to apply to the data
    segment: bool
        If True, W is reshaped to a column vector so it broadcasts along the
        second axis of the data arrays.
    dropna: bool
        Whether or not to drop zero-weighted data.  If true, we drop the nans.

    Returns
    -------
    X, Y, RR: tuple
        Same as input but with weights applied and (optionally) nan dropped.
    """
    # Work on a copy so the caller's weight array is left untouched;
    # zero weights map to nan so the weighted observations can be dropped.
    W = np.where(W == 0, np.nan, W)
    if segment:
        # one weight per segment: make it a column vector for broadcasting
        W = np.atleast_2d(W).T
    X *= W
    Y *= W
    if RR is not None:
        RR *= W
    if dropna:
        X, Y, RR = drop_nans(X, Y, RR)
    return X, Y, RR
def process_transfer_functions(
    dec_level_config: AuroraDecimationLevel,
    local_stft_obj: xr.Dataset,
    remote_stft_obj: xr.Dataset,
    transfer_function_obj,
):
    """
    This is the main tf_processing method.  It is based on the Matlab legacy
    code TTFestBand.m.

    Note #1: Although it is advantageous to execute the regression
    channel-by-channel vs. all-at-once, we need to keep the all-at-once to get
    residual covariances (see aurora issue #87)

    TODO: Consider push the nan-handling into the band extraction as a kwarg.

    Parameters
    ----------
    dec_level_config: AuroraDecimationLevel
        Processing parameters for the active decimation level.
    local_stft_obj: xarray.core.dataset.Dataset
    remote_stft_obj: xarray.core.dataset.Dataset or None
    transfer_function_obj: aurora.transfer_function.TTFZ.TTFZ
        The transfer function container ready to receive values in this method.

    Returns
    -------
    transfer_function_obj: aurora.transfer_function.TTFZ.TTFZ
    """
    estimator_class: RegressionEstimator = get_estimator_class(
        dec_level_config.estimator.engine
    )
    iter_control = set_up_iter_control(dec_level_config)
    # One regression (or one per output channel) per frequency band.
    for band in transfer_function_obj.frequency_bands.bands():
        X, Y, RR = get_band_for_tf_estimate(
            band, dec_level_config, local_stft_obj, remote_stft_obj
        )
        # Reshape to 2d: collapse (frequency, time) into one 'observation' axis
        X, Y, RR = stack_fcs(X, Y, RR)
        # Should only be needed if weights were applied
        X, Y, RR = drop_nans(X, Y, RR)
        # Down-weight observations with few effective degrees of freedom;
        # zero-weighted observations become nan and are dropped by apply_weights.
        W = effective_degrees_of_freedom_weights(X, RR, edf_obj=None)
        X, Y, RR = apply_weights(X, Y, RR, W, segment=False, dropna=True)
        if dec_level_config.estimator.estimate_per_channel:
            # Regress each output channel separately against the same X (and RR).
            for ch in dec_level_config.output_channels:
                Y_ch = Y[ch].to_dataset()  # keep as a dataset, maybe not needed
                X_, Y_, RR_ = handle_nan(X, Y_ch, RR, drop_dim="observation")
                # see note #1 (module docstring): re-applying edf weights here
                # appeared to have no effect, hence left commented out.
                # if RR is not None:
                #     W = effective_degrees_of_freedom_weights(X_, RR_, edf_obj=None)
                #     X_, Y_, RR_ = apply_weights(X_, Y_, RR_, W, segment=False)
                regression_estimator = estimator_class(
                    X=X_, Y=Y_, Z=RR_, iter_control=iter_control
                )
                regression_estimator.estimate()
                transfer_function_obj.set_tf(regression_estimator, band.center_period)
        else:
            # All output channels at once (needed for residual covariances; see
            # Note #1 above).
            X, Y, RR = handle_nan(X, Y, RR, drop_dim="observation")
            regression_estimator = estimator_class(
                X=X, Y=Y, Z=RR, iter_control=iter_control
            )
            regression_estimator.estimate()
            transfer_function_obj.set_tf(regression_estimator, band.center_period)
    return transfer_function_obj
def process_transfer_functions_with_weights(
    dec_level_config: AuroraDecimationLevel,
    local_stft_obj: xr.Dataset,
    remote_stft_obj: xr.Dataset,
    transfer_function_obj,
):
    """
    This is version of process_transfer_functions applies weights to the data.

    Development Notes:
    Note #1: This is only for per-channel estimation, so it does not support
    the dec_level_config.estimator.estimate_per_channel = False
    Note #2: This was adapted from the process_transfer_functions method but
    the core loop is inverted to loop over channels first, then bands.

    Parameters
    ----------
    dec_level_config: AuroraDecimationLevel
        Processing parameters for the active decimation level.
    local_stft_obj: xarray.core.dataset.Dataset
    remote_stft_obj: xarray.core.dataset.Dataset or None
    transfer_function_obj: aurora.transfer_function.TTFZ.TTFZ
        The transfer function container ready to receive values in this method.

    Returns
    -------
    transfer_function_obj: aurora.transfer_function.TTFZ.TTFZ

    Raises
    ------
    ValueError
        If dec_level_config.estimator.estimate_per_channel is False (Note #1).
    """
    if not dec_level_config.estimator.estimate_per_channel:
        msg = (
            "process_transfer_functions_with_weights is only for per-channel estimation"
        )
        logger.error(msg)
        raise ValueError(msg)
    estimator_class: RegressionEstimator = get_estimator_class(
        dec_level_config.estimator.engine
    )
    iter_control = set_up_iter_control(dec_level_config)
    for ch in dec_level_config.output_channels:
        # check if there are channel weights for this channel
        # NOTE: if several specs list this channel, the LAST match wins
        # (no break in the loop below).
        weights = None
        for chws in dec_level_config.channel_weight_specs:
            if ch in chws.output_channels:
                weights = chws.weights
        for band in transfer_function_obj.frequency_bands.bands():
            X, Y, RR = get_band_for_tf_estimate(
                band, dec_level_config, local_stft_obj, remote_stft_obj
            )
            Y_ch = Y[ch].to_dataset()  # keep as a dataset, maybe not needed
            # extract the weights for this band
            if weights is not None:
                # TODO: Investigate best way to extract the weights for band
                # This may involve finding the nearest frequency bin to the band center period
                # and then applying the weights for that bin, or some tapered region around it.
                # For now, we will just use the mean of the weights for the band.
                # This is a temporary solution and should be replaced with a more robust method.
                # band_weights = chws.get_weights_for_band(band)
                band_weights = weights.mean(axis=1)  # chws.get_weights_for_band(band)
                # NOTE: return value intentionally unused; apply_weights
                # multiplies X, Y_ch, RR in place via `*=`.
                apply_weights(
                    X, Y_ch, RR, band_weights.squeeze(), segment=True, dropna=False
                )
            X, Y_ch, RR = stack_fcs(X, Y_ch, RR)  # Reshape to 2d
            # Should only be needed if weights were applied
            X, Y_ch, RR = drop_nans(X, Y_ch, RR)
            # Down-weight observations with few effective degrees of freedom.
            W = effective_degrees_of_freedom_weights(X, RR, edf_obj=None)
            X, Y_ch, RR = apply_weights(X, Y_ch, RR, W, segment=False, dropna=True)
            X_, Y_, RR_ = handle_nan(X, Y_ch, RR, drop_dim="observation")
            regression_estimator = estimator_class(
                X=X_, Y=Y_, Z=RR_, iter_control=iter_control
            )
            regression_estimator.estimate()
            transfer_function_obj.set_tf(regression_estimator, band.center_period)
    return transfer_function_obj