Source code for aurora.pipelines.time_series_helpers

"""
    Collection of modules used in processing pipeline that operate on time series
"""
import mt_metadata
import numpy as np
import pandas as pd
import scipy.signal as ssig
import xarray as xr

from loguru import logger

from aurora.time_series.windowed_time_series import WindowedTimeSeries
from aurora.time_series.windowing_scheme import window_scheme_from_decimation


def validate_sample_rate(run_ts, expected_sample_rate, tol=1e-4):
    """
    Check that the sample rate of a run_ts is the expected value, and warn if not.

    Parameters
    ----------
    run_ts: mth5.timeseries.run_ts.RunTS
        Time series object with data and metadata.
    expected_sample_rate: float
        The sample rate the time series is expected to have.  Normally taken
        from the processing config.
    tol: float
        Tolerance on the difference between the actual and expected sample
        rates; an Exception is raised if this tolerance is exceeded.
    """
    if run_ts.sample_rate != expected_sample_rate:
        msg = (
            f"sample rate in run time series {run_ts.sample_rate} and "
            f"processing decimation_obj {expected_sample_rate} do not match"
        )
        logger.warning(msg)
        delta = run_ts.sample_rate - expected_sample_rate
        if np.abs(delta) > tol:
            msg = f"Delta sample rate {delta} > {tol} tolerance. "
            msg += "TOL should be a percentage"
            raise Exception(msg)
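
# Example usage -- a minimal sketch; the file name, station, and run labels
# below are hypothetical, and an MTH5 archive is assumed to be available:
#
# from mth5.mth5 import MTH5
# m = MTH5()
# m.open_mth5("example.h5", mode="r")
# run_ts = m.get_run("mt01", "001").to_runts()
# validate_sample_rate(run_ts, expected_sample_rate=1.0)
# m.close_mth5()
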
def apply_prewhitening(decimation_obj, run_xrds_input):
    """
    Applies pre-whitening to time series to avoid spectral leakage when FFT is applied.

    If "first difference", may want to consider clipping first and last sample from
    the differentiated time series.

    Parameters
    ----------
    decimation_obj : mt_metadata.transfer_functions.processing.aurora.DecimationLevel
        Information about how the decimation level is to be processed
    run_xrds_input : xarray.core.dataset.Dataset
        Time series to be pre-whitened

    Returns
    -------
    run_xrds : xarray.core.dataset.Dataset
        pre-whitened time series
    """
    if not decimation_obj.prewhitening_type:
        return run_xrds_input

    if decimation_obj.prewhitening_type == "first difference":
        run_xrds = run_xrds_input.differentiate("time")
    else:
        msg = f"{decimation_obj.prewhitening_type} pre-whitening not implemented"
        logger.exception(msg)
        raise NotImplementedError(msg)
    return run_xrds
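
# Sketch of the "first difference" pre-whitening on a synthetic dataset
# (standalone illustration; does not require a DecimationLevel object):
#
# import numpy as np
# import pandas as pd
# import xarray as xr
# t = pd.date_range("2020-01-01", periods=1000, freq="1s")
# ds = xr.Dataset({"hx": ("time", np.random.randn(1000))}, coords={"time": t})
# ds_white = ds.differentiate("time")  # same operation used above
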
def apply_recoloring(decimation_obj, stft_obj):
    """
    Inverts the pre-whitening operation in frequency domain.

    Parameters
    ----------
    decimation_obj : mt_metadata.transfer_functions.processing.fourier_coefficients.decimation.Decimation
        Information about how the decimation level is to be processed
    stft_obj : xarray.core.dataset.Dataset
        Time series of Fourier coefficients to be recolored

    Returns
    -------
    stft_obj : xarray.core.dataset.Dataset
        Recolored time series of Fourier coefficients
    """
    # No recoloring needed if prewhitening was not applied, or recoloring is set to False
    if not decimation_obj.prewhitening_type:
        return stft_obj
    if not decimation_obj.recoloring:
        return stft_obj

    if decimation_obj.prewhitening_type == "first difference":
        freqs = decimation_obj.fft_frequencies
        prewhitening_correction = 1.0j * 2 * np.pi * freqs  # jw
        stft_obj /= prewhitening_correction

        # suppress nan and inf to mute later warnings
        if prewhitening_correction[0] == 0.0:
            cond = stft_obj.frequency != 0.0
            stft_obj = stft_obj.where(cond, complex(0.0))
    # elif decimation_obj.prewhitening_type == "ARMA":
    #     from statsmodels.tsa.arima.model import ARIMA
    #     AR = 3  # add this to processing config
    #     MA = 4  # add this to processing config
    else:
        msg = f"{decimation_obj.prewhitening_type} recoloring not yet implemented"
        logger.error(msg)
        raise NotImplementedError(msg)

    return stft_obj
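
# Why divide by jw: differentiating in time multiplies each Fourier
# coefficient at angular frequency w = 2*pi*f by (jw), so recoloring undoes
# the "first difference" pre-whitening by dividing the spectrum by (jw).
# Minimal check on a pure sinusoid (standalone sketch, not part of the pipeline):
#
# import numpy as np
# fs, n = 64.0, 1024
# t = np.arange(n) / fs
# x = np.sin(2 * np.pi * 4.0 * t)          # 4 Hz tone
# X = np.fft.rfft(np.gradient(x, 1 / fs))  # spectrum of the derivative
# f = np.fft.rfftfreq(n, 1 / fs)
# X[1:] /= 1j * 2 * np.pi * f[1:]          # recolor; skip the DC bin (jw = 0)
# # X now approximates the spectrum of the original sinusoid x
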
def run_ts_to_stft_scipy(decimation_obj, run_xrds_orig):
    """
    Converts a runts object into a time series of Fourier coefficients.
    This method uses scipy.signal.spectrogram.

    Parameters
    ----------
    decimation_obj : mt_metadata.transfer_functions.processing.aurora.DecimationLevel
        Information about how the decimation level is to be processed
    run_xrds_orig : xarray.core.dataset.Dataset
        Time series to be processed

    Returns
    -------
    stft_obj : xarray.core.dataset.Dataset
        Time series of Fourier coefficients
    """
    run_xrds = apply_prewhitening(decimation_obj, run_xrds_orig)
    windowing_scheme = window_scheme_from_decimation(decimation_obj)

    stft_obj = xr.Dataset()
    for channel_id in run_xrds.data_vars:
        ff, tt, specgm = ssig.spectrogram(
            run_xrds[channel_id].data,
            fs=decimation_obj.sample_rate_decimation,
            window=windowing_scheme.taper,
            nperseg=decimation_obj.window.num_samples,
            noverlap=decimation_obj.window.overlap,
            detrend="linear",
            scaling="density",
            mode="complex",
        )

        # drop the Nyquist frequency
        ff = ff[:-1]
        specgm = specgm[:-1, :]
        specgm *= np.sqrt(2)

        # make time_axis
        tt = tt - tt[0]
        tt *= decimation_obj.sample_rate_decimation
        time_axis = run_xrds.time.data[tt.astype(int)]

        xrd = xr.DataArray(
            specgm.T,
            dims=["time", "frequency"],
            coords={"frequency": ff, "time": time_axis},
        )
        stft_obj.update({channel_id: xrd})

    stft_obj = apply_recoloring(decimation_obj, stft_obj)

    return stft_obj
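
# The core of the function above is the scipy call; a standalone sketch with
# concrete numbers (128-sample Hamming windows, 75% overlap -- values here are
# illustrative, not taken from any processing config):
#
# import numpy as np
# import scipy.signal as ssig
# x = np.random.randn(4096)
# ff, tt, specgm = ssig.spectrogram(
#     x, fs=1.0, window=ssig.windows.hamming(128), nperseg=128,
#     noverlap=96, detrend="linear", scaling="density", mode="complex",
# )
# # specgm has shape (len(ff), len(tt)) and complex dtype
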
def truncate_to_clock_zero(decimation_obj, run_xrds):
    """
    Compute the time interval between the first data sample and the clock zero.
    Identify the first sample in the xarray time series that corresponds to a
    window start sample.

    Parameters
    ----------
    decimation_obj: mt_metadata.transfer_functions.processing.aurora.DecimationLevel
        Information about how the decimation level is to be processed
    run_xrds : xarray.core.dataset.Dataset
        normally extracted from mth5.RunTS

    Returns
    -------
    run_xrds : xarray.core.dataset.Dataset
        same as the input time series, but possibly slightly shortened
    """
    if decimation_obj.window.clock_zero_type == "ignore":
        pass
    else:
        clock_zero = pd.Timestamp(decimation_obj.window.clock_zero)
        clock_zero = clock_zero.to_datetime64()
        delta_t = clock_zero - run_xrds.time[0]
        assert delta_t.dtype == "<m8[ns]"  # expected in nanoseconds
        delta_t_seconds = int(delta_t) / 1e9
        if delta_t_seconds == 0:
            pass  # time series start is already clock zero
        else:
            windowing_scheme = window_scheme_from_decimation(decimation_obj)
            number_of_steps = delta_t_seconds / windowing_scheme.duration_advance
            n_partial_steps = number_of_steps - np.floor(number_of_steps)
            n_clip = n_partial_steps * windowing_scheme.num_samples_advance
            n_clip = int(np.round(n_clip))
            t_clip = run_xrds.time[n_clip]
            cond1 = run_xrds.time >= t_clip
            msg = (
                f"dropping {n_clip} samples to agree with "
                f"{decimation_obj.window.clock_zero_type} clock zero {clock_zero}"
            )
            logger.info(msg)
            run_xrds = run_xrds.where(cond1, drop=True)
    return run_xrds
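
# Worked example of the clipping arithmetic (numbers are illustrative): with a
# window advance of 32 samples at 1 Hz (duration_advance = 32 s) and a clock
# zero 100 s after the first data sample, number_of_steps = 100 / 32 = 3.125,
# the partial step is 0.125, and n_clip = round(0.125 * 32) = 4 samples are
# dropped so the remaining series starts on a window-advance boundary.
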
def nan_to_mean(xrds: xr.Dataset) -> xr.Dataset:
    """
    Set NaN values in xr.Dataset to the mean value per channel.

    Parameters
    ----------
    xrds: xr.Dataset
        Time series data

    Returns
    -------
    run_xrds : xr.Dataset
        The same as the input time series but NaN values are replaced by the
        mean of the time series (per channel).
    """
    for ch in xrds.keys():
        null_values_present = xrds[ch].isnull().any()
        if null_values_present:
            msg = "Unexpected Null values in xrds -- this should be examined"
            logger.warning(msg)
            value = np.nan_to_num(np.nanmean(xrds[ch].data))
            xrds[ch] = xrds[ch].fillna(value)
    return xrds
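
# Minimal standalone check of the NaN replacement (synthetic data):
#
# import numpy as np
# import xarray as xr
# ds = xr.Dataset({"ex": ("time", np.array([1.0, np.nan, 3.0]))})
# ds = nan_to_mean(ds)
# # ds["ex"].data is now [1.0, 2.0, 3.0]; the NaN became the channel mean
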
def run_ts_to_stft(decimation_obj, run_xrds_orig):
    """
    Converts a runts object into a time series of Fourier coefficients.
    Similar to run_ts_to_stft_scipy, but in this implementation operations on
    individual windows are possible (for example pre-whitening per time window
    via ARMA filtering).

    Parameters
    ----------
    decimation_obj : mt_metadata.transfer_functions.processing.aurora.DecimationLevel
        Information about how the decimation level is to be processed
    run_xrds_orig : xarray.core.dataset.Dataset
        normally extracted from mth5.RunTS

    Returns
    -------
    stft_obj : xarray.core.dataset.Dataset
        Note that the STFT object may have inf/nan in the DC harmonic, introduced by
        recoloring. This really doesn't matter since we don't use the DC harmonic for
        anything.
    """
    # need to remove any nans before windowing, or else if there is a single
    # nan then the whole channel becomes nan.
    run_xrds = nan_to_mean(run_xrds_orig)
    run_xrds = apply_prewhitening(decimation_obj, run_xrds)
    run_xrds = truncate_to_clock_zero(decimation_obj, run_xrds)
    windowing_scheme = window_scheme_from_decimation(decimation_obj)
    windowed_obj = windowing_scheme.apply_sliding_window(
        run_xrds, dt=1.0 / decimation_obj.sample_rate_decimation
    )
    if not np.prod(windowed_obj.to_array().data.shape):
        raise ValueError("Sliding window returned an empty array")

    windowed_obj = WindowedTimeSeries.detrend(data=windowed_obj, detrend_type="linear")
    tapered_obj = WindowedTimeSeries.apply_taper(
        data=windowed_obj, taper=windowing_scheme.taper
    )
    stft_obj = WindowedTimeSeries.apply_fft(
        data=tapered_obj,
        sample_rate=windowing_scheme.sample_rate,
        spectral_density_correction=windowing_scheme.linear_spectral_density_calibration_factor,
        detrend_type=decimation_obj.extra_pre_fft_detrend_type,
    )
    stft_obj = apply_recoloring(decimation_obj, stft_obj)

    return stft_obj
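
# Sketch of calling the pipeline above -- a minimal sketch assuming the
# mt_metadata defaults are acceptable and that `run_xrds` is an xr.Dataset of
# time series as produced by mth5 (the window values below are illustrative):
#
# from mt_metadata.transfer_functions.processing.aurora import DecimationLevel
# decimation_level = DecimationLevel()
# decimation_level.window.num_samples = 256
# decimation_level.window.overlap = 192
# stft = run_ts_to_stft(decimation_level, run_xrds)
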
def calibrate_stft_obj(stft_obj, run_obj, units="MT", channel_scale_factors=None):
    """
    Calibrates frequency domain data into MT units.

    Development Notes:
    The calibration often raises a runtime warning due to DC term in
    calibration response = 0.
    TODO: It would be nice to suppress this, maybe by only calibrating the
    non-dc terms and directly assigning np.nan to the dc component when
    DC-response is zero.

    Parameters
    ----------
    stft_obj : xarray.core.dataset.Dataset
        Time series of Fourier coefficients to be calibrated
    run_obj : mth5.groups.master_station_run_channel.RunGroup
        Provides information about filters for calibration
    units : string
        usually "MT", contemplating supporting "SI"
    channel_scale_factors : dict or None
        keyed by channel, supports a single scalar to apply to that channel's
        data. Useful for debugging. Should not be used in production and
        should throw a warning if it is not None.

    Returns
    -------
    stft_obj : xarray.core.dataset.Dataset
        Time series of calibrated Fourier coefficients
    """
    for channel_id in stft_obj.keys():
        channel = run_obj.get_channel(channel_id)
        channel_response = channel.channel_response
        if not channel_response.filters_list:
            msg = f"Channel {channel_id} with empty filters list detected"
            logger.warning(msg)
            if channel_id == "hy":
                msg = "Channel hy has no filters, try using filters from hx"
                logger.warning(msg)
                channel_response = run_obj.get_channel("hx").channel_response

        indices_to_flip = channel_response.get_indices_of_filters_to_remove(
            include_decimation=False, include_delay=False
        )
        indices_to_flip = [
            i for i in indices_to_flip if channel.metadata.filter.applied[i]
        ]
        filters_to_remove = [channel_response.filters_list[i] for i in indices_to_flip]
        if not filters_to_remove:
            logger.warning("No filters to remove")

        calibration_response = channel_response.complex_response(
            stft_obj.frequency.data, filters_list=filters_to_remove
        )

        if channel_scale_factors:
            try:
                channel_scale_factor = channel_scale_factors[channel_id]
            except KeyError:
                channel_scale_factor = 1.0
            calibration_response /= channel_scale_factor

        if units == "SI":
            logger.warning("Warning: SI Units are not robustly supported issue #36")

        # TODO: This often raises a runtime warning due to DC term in calibration response = 0
        stft_obj[channel_id].data /= calibration_response
    return stft_obj
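
# The essence of the calibration step: each channel's spectrum is divided by
# the complex frequency response of the filters still applied to it.  A toy
# illustration with a single-pole low-pass response (values are hypothetical,
# not from any real instrument):
#
# import numpy as np
# f = np.array([0.1, 1.0, 10.0])             # Hz
# response = 1.0 / (1.0 + 1j * f / 1.0)      # pretend instrument response
# raw_spectrum = np.array([1 + 1j, 2 + 0j, 0.5 - 0.5j])
# calibrated = raw_spectrum / response       # same operation as above
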
def prototype_decimate(
    config: mt_metadata.transfer_functions.processing.aurora.decimation.Decimation,
    run_xrds: xr.Dataset,
) -> xr.Dataset:
    """
    Basically a wrapper for scipy.signal.decimate.  Takes input time series
    (as an xarray Dataset) and a Decimation config object and returns a
    decimated version of the input time series.

    TODO: Consider moving this function into time_series/decimate.py
    TODO: Consider replacing the downsampled_time_axis with rolling mean, or
     something that takes the average value of the time, not the window start
    TODO: Compare outputs with scipy resample_poly, which also has an FIR AAF
     and appears faster
    TODO: Add handling for case that could occur when sliced time axis has a
     different length than the decimated data -- see mth5 issue #217
     https://github.com/kujaku11/mth5/issues/217

    Parameters
    ----------
    config : mt_metadata.transfer_functions.processing.aurora.Decimation
    run_xrds: xr.Dataset
        Originally from mth5.timeseries.run_ts.RunTS.dataset, but possibly
        decimated multiple times

    Returns
    -------
    xr_ds: xr.Dataset
        Decimated version of the input run_xrds
    """
    # downsample the time axis
    slicer = slice(None, None, int(config.factor))  # decimation.factor
    downsampled_time_axis = run_xrds.time.data[slicer]

    # decimate the time series
    num_observations = len(downsampled_time_axis)
    channel_labels = list(run_xrds.data_vars.keys())  # run_ts.channels
    num_channels = len(channel_labels)
    new_data = np.full((num_observations, num_channels), np.nan)
    for i_ch, ch_label in enumerate(channel_labels):
        new_data[:, i_ch] = ssig.decimate(run_xrds[ch_label], int(config.factor))

    xr_da = xr.DataArray(
        new_data,
        dims=["time", "channel"],
        coords={"time": downsampled_time_axis, "channel": channel_labels},
    )
    attr_dict = run_xrds.attrs
    attr_dict["sample_rate"] = config.sample_rate
    xr_da.attrs = attr_dict
    xr_ds = xr_da.to_dataset("channel")
    return xr_ds
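
# Standalone sketch of the underlying scipy call (illustrative values): a
# factor-of-4 decimation, using scipy's default order-8 Chebyshev type I
# anti-alias filter before downsampling.
#
# import numpy as np
# import scipy.signal as ssig
# x = np.random.randn(1024)
# y = ssig.decimate(x, 4)  # len(y) == 256
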
# Here are some other ways to downsample/decimate that were experimented with. # TODO: Jupyter notebook tutorial showing how these differ (cf standard method). # def prototype_decimate_2(config, run_xrds): # """ # Uses the built-in xarray coarsen method. Not clear what AAF effects are. # Method is fast. Might be non-linear. Seems to give similar performance to # prototype_decimate for synthetic data. # # N.B. config.factor must be integer valued # # Parameters # ---------- # config : mt_metadata.transfer_functions.processing.aurora.Decimation # run_xrds: xr.Dataset # Originally from mth5.timeseries.run_ts.RunTS.dataset, but possibly decimated # multiple times # # Returns # ------- # xr_ds: xr.Dataset # Decimated version of the input run_xrds # """ # new_xr_ds = run_xrds.coarsen(time=int(config.factor), boundary="trim").mean() # attr_dict = run_xrds.attrs # attr_dict["sample_rate"] = config.sample_rate # new_xr_ds.attrs = attr_dict # return new_xr_ds # # # def prototype_decimate_3(config, run_xrds): # """ # Uses scipy's resample method. Not clear what AAF effects are. # Method is fast. # # Parameters # ---------- # config : mt_metadata.transfer_functions.processing.aurora.Decimation # run_xrds: xr.Dataset # Originally from mth5.timeseries.run_ts.RunTS.dataset, but possibly decimated # multiple times # # Returns # ------- # xr_ds: xr.Dataset # Decimated version of the input run_xrds # """ # dt = run_xrds.time.diff(dim="time").median().values # dt_new = config.factor * dt # dt_new = dt_new.__str__().replace("nanoseconds", "ns") # new_xr_ds = run_xrds.resample(time=dt_new).mean(dim="time") # attr_dict = run_xrds.attrs # attr_dict["sample_rate"] = config.sample_rate # new_xr_ds.attrs = attr_dict # return new_xr_ds