
"""
This module contains the main methods used in processing mth5 objects to transfer functions.

The main function is called process_mth5.
This function was recently renamed to process_mth5_legacy, so that process_mth5
can be repurposed for other TF estimation schemes.  The "legacy" version
corresponds to aurora default processing.


Notes on process_mth5_legacy:
Note 1: process_mth5 assumes application of cascading decimation, and that the
decimated data will be accessed from the previous decimation level.  This should be
revisited. It may make more sense to have a get_decimation_level() interface that
provides an option of applying decimation or loading pre-decimated data.
This will be addressed via creation of the FC layer inside mth5.

Note 2: We can encounter cases where some runs can be decimated and others cannot.
We need a way to handle this. For example, a short run may not yield any data from a
later decimation level. An attempt to handle this has been made in TF Kernel by
adding an is_valid_dataset column, associated with each run-decimation level pair.


Note 3: This point in the loop marks the interface between _generation_ of the FCs and
 their _usage_. In future the code above this comment would be pushed into
 create_fourier_coefficients() and the code below this would access those FCs and
 execute compute_transfer_function().
 This would also be an appropriate place to place a feature extraction layer, and
 compute weights for the FCs.

"""

# =============================================================================
# Imports
# =============================================================================

import mth5.groups

from aurora.pipelines.time_series_helpers import calibrate_stft_obj
from aurora.pipelines.time_series_helpers import run_ts_to_stft
from aurora.pipelines.transfer_function_helpers import (
    process_transfer_functions,
)
from aurora.pipelines.transfer_function_kernel import TransferFunctionKernel
from aurora.pipelines.transfer_function_kernel import station_obj_from_row
from aurora.sandbox.triage_metadata import triage_run_id
from aurora.transfer_function.transfer_function_collection import (
    TransferFunctionCollection,
)
from aurora.transfer_function.TTFZ import TTFZ
from loguru import logger
from mth5.helpers import close_open_files
from typing import Optional, Union

import aurora.config.metadata.processing
import pandas as pd
import xarray as xr


SUPPORTED_PROCESSINGS = [
    "legacy",
]

# =============================================================================


def make_stft_objects(processing_config, i_dec_level, run_obj, run_xrds, units="MT"):
    """
    Operates on a "per-run" basis.  Applies STFT to all time series in the input run.

    This method could be modified in a multiple station code so that it
    doesn't care if the station is "local" or "remote" but rather uses scale
    factors keyed by station_id (WIP - issue #329).

    Parameters
    ----------
    processing_config: mt_metadata.transfer_functions.processing.aurora.Processing
        Metadata about the processing to be applied
    i_dec_level: int
        The decimation level to process
    run_obj: mth5.groups.master_station_run_channel.RunGroup
        The run to transform to stft
    run_xrds: xarray.core.dataset.Dataset
        The data time series from the run to transform
    units: str
        Expects "MT".  May change so that this is the only accepted set of units

    Returns
    -------
    stft_obj: xarray.core.dataset.Dataset
        Time series of calibrated Fourier coefficients per each channel in the run
    """
    stft_config = processing_config.get_decimation_level(i_dec_level)
    stft_obj = run_ts_to_stft(stft_config, run_xrds)
    run_id = run_obj.metadata.id

    # Look up the channel scale factors for this run, keyed by whether the
    # run belongs to the local or the remote station.  Default to None so the
    # calibration call below does not raise NameError if neither matches.
    scale_factors = None
    if run_obj.station_metadata.id == processing_config.stations.local.id:
        scale_factors = processing_config.stations.local.run_dict[
            run_id
        ].channel_scale_factors
    elif run_obj.station_metadata.id == processing_config.stations.remote[0].id:
        scale_factors = (
            processing_config.stations.remote[0]
            .run_dict[run_id]
            .channel_scale_factors
        )

    stft_obj = calibrate_stft_obj(
        stft_obj,
        run_obj,
        units=units,
        channel_scale_factors=scale_factors,
    )
    return stft_obj


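# Illustrative call pattern for make_stft_objects (a sketch mirroring the loop
# in get_spectrograms below; `tfk`, `row`, and `i_dec_level` are assumed to
# come from a TransferFunctionKernel iteration over its dataset_df):
#
# run_obj = row.mth5_obj.from_reference(row.run_hdf5_reference)
# run_xrds = row["run_dataarray"].to_dataset("channel")
# stft_obj = make_stft_objects(tfk.config, i_dec_level, run_obj, run_xrds, "MT")

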
def process_tf_decimation_level(
    config: aurora.config.metadata.processing.Processing,
    i_dec_level: int,
    local_stft_obj: xr.core.dataset.Dataset,
    remote_stft_obj: Union[xr.core.dataset.Dataset, None],
    units="MT",
):
    """
    Processing pipeline for a single decimation_level.

    TODO: Add a check that the processing config sample rates agree with the
    data sampling rates, otherwise raise an Exception.
    TODO: Add units to local_stft_obj, remote_stft_obj.

    This method can be single station or remote based on the process cfg.

    Parameters
    ----------
    config: aurora.config.metadata.processing.Processing
        The full processing configuration; the decimation level at index
        i_dec_level is the one processed here.
    i_dec_level: int
        decimation level_id
        ?could we pack this into the decimation level as an attr?
    local_stft_obj: xarray.core.dataset.Dataset
        The time series of Fourier coefficients from the local station
    remote_stft_obj: xarray.core.dataset.Dataset or None
        The time series of Fourier coefficients from the remote station
    units: str
        one of ["MT", "SI"]

    Returns
    -------
    transfer_function_obj: aurora.transfer_function.TTFZ.TTFZ
        The transfer function values packed into an object
    """
    frequency_bands = config.decimations[i_dec_level].frequency_bands_obj()
    transfer_function_obj = TTFZ(i_dec_level, frequency_bands, processing_config=config)
    dec_level_config = config.decimations[i_dec_level]
    # segment_weights = coherence_weights(dec_level_config, local_stft_obj, remote_stft_obj)
    transfer_function_obj = process_transfer_functions(
        dec_level_config, local_stft_obj, remote_stft_obj, transfer_function_obj
    )
    return transfer_function_obj


# def enrich_row(row):
#     pass


def triage_issue_289(local_stfts: list, remote_stfts: list):
    """
    Takes STFT objects in and returns them after shape-checking, making sure
    the local and remote time axes have the same shape.

    WIP: Timing Error Workaround, see Aurora Issue #289.  Seems associated
    with getting one fewer sample than expected from the edge of a run.

    Returns
    -------
    Tuple (local_stfts, remote_stfts)
        The original input arguments, shape-matched
    """
    n_chunks = len(local_stfts)
    for i_chunk in range(n_chunks):
        ok = local_stfts[i_chunk].time.shape == remote_stfts[i_chunk].time.shape
        if not ok:
            logger.warning("Mismatch in FC array lengths detected -- Issue #289")
            # Trim both STFTs to their overlapping time interval: the greatest
            # lower bound (glb) and least upper bound (lub) of the two time axes.
            glb = max(
                local_stfts[i_chunk].time.min(),
                remote_stfts[i_chunk].time.min(),
            )
            lub = min(
                local_stfts[i_chunk].time.max(),
                remote_stfts[i_chunk].time.max(),
            )
            cond1 = local_stfts[i_chunk].time >= glb
            cond2 = local_stfts[i_chunk].time <= lub
            local_stfts[i_chunk] = local_stfts[i_chunk].where(
                cond1 & cond2, drop=True
            )
            cond1 = remote_stfts[i_chunk].time >= glb
            cond2 = remote_stfts[i_chunk].time <= lub
            remote_stfts[i_chunk] = remote_stfts[i_chunk].where(
                cond1 & cond2, drop=True
            )
            assert (
                local_stfts[i_chunk].time.shape == remote_stfts[i_chunk].time.shape
            )
    return local_stfts, remote_stfts


def merge_stfts(stfts: dict, tfk: TransferFunctionKernel):
    """
    Concatenates multiple arrays of STFTs from different runs along the time axis.
    At the TF estimation level we treat all the FCs in one array.
    This builds the array for both the local and the remote STFTs.

    Parameters
    ----------
    stfts: dict
        The dict is keyed by "local" and "remote".  Each value is a list of
        STFTs (one list for local and one for remote)
    tfk: TransferFunctionKernel
        Just here to let us know if there is a remote reference to merge or not.

    Returns
    -------
    local_merged_stft_obj, remote_merged_stft_obj: Tuple
        Both are xr.Datasets; remote_merged_stft_obj is None when there is no
        remote reference.
    """
    # Timing Error Workaround, see Aurora Issue #289
    local_stfts = stfts["local"]
    remote_stfts = stfts["remote"]
    if tfk.config.stations.remote:
        local_stfts, remote_stfts = triage_issue_289(local_stfts, remote_stfts)
        remote_merged_stft_obj = xr.concat(remote_stfts, "time")
    else:
        remote_merged_stft_obj = None
    local_merged_stft_obj = xr.concat(local_stfts, "time")
    return local_merged_stft_obj, remote_merged_stft_obj


def append_chunk_to_stfts(stfts: dict, chunk: xr.Dataset, remote: bool) -> dict:
    """
    Aggregate one STFT into a larger dictionary that tracks all the STFTs.

    Parameters
    ----------
    stfts: dict
        Has keys "local" and "remote".
    chunk: xr.Dataset
        The data to append to the dictionary
    remote: bool
        If True, append the chunk to stfts["remote"], else append to stfts["local"]

    Returns
    -------
    stfts: dict
        Same as input but with the new chunk appended.
    """
    if remote:
        stfts["remote"].append(chunk)
    else:
        stfts["local"].append(chunk)
    return stfts


# def load_spectrogram_from_station_object(station_obj, fc_group_id, fc_decimation_id):
#     """
#     Placeholder.  This could also be a method in mth5.
#     """
#     return station_obj.fourier_coefficients_group.get_fc_group(fc_group_id).get_decimation_level(fc_decimation_id)


def load_stft_obj_from_mth5(
    i_dec_level: int,
    row: pd.Series,
    run_obj: mth5.groups.RunGroup,
    channels: Optional[list] = None,
) -> xr.Dataset:
    """
    Load stft_obj from mth5 (instead of computing it).

    Note #1: See note #1 in time_series.frequency_band_helpers.extract_band.

    Parameters
    ----------
    i_dec_level: int
        The decimation level where the data are stored within the Fourier
        Coefficient group
    row: pandas.core.series.Series
        A row of the TFK.dataset_df
    run_obj: mth5.groups.run.RunGroup
        The original time-domain run associated with the data to load
    channels: list, optional
        The channels to load; default None loads all channels.

    Returns
    -------
    stft_chunk: xr.Dataset
        An STFT from mth5.
    """
    station_obj = station_obj_from_row(row)
    fc_group = station_obj.fourier_coefficients_group.get_fc_group(
        run_obj.metadata.id
    )
    fc_decimation_level = fc_group.get_decimation_level(f"{i_dec_level}")
    stft_obj = fc_decimation_level.to_xarray(channels=channels)

    # Trim the loaded FCs to the [start, end] window of this dataset_df row
    cond1 = stft_obj.time >= row.start.tz_localize(None)
    cond2 = stft_obj.time <= row.end.tz_localize(None)
    try:
        stft_chunk = stft_obj.where(cond1 & cond2, drop=True)
    except TypeError:  # see Note #1
        tmp = stft_obj.to_array()
        tmp = tmp.where(cond1 & cond2, drop=True)
        stft_chunk = tmp.to_dataset("variable")
    return stft_chunk


def save_fourier_coefficients(dec_level_config, row, run_obj, stft_obj) -> None:
    """
    Optionally saves the stft object into the MTH5.
    Note that the dec_level_config must have its save_fcs attr set to True
    for the data to actually be saved.

    WIP

    Note #1: Logic for building FC layers:
    If the processing config decimation_level.save_fcs_type = "h5" and
    fc_levels_already_exist is False, then open in append mode, else open in
    read mode.  We should support a flag: force_rebuild_fcs, normally False.
    This flag is only needed when save_fcs_type == "h5".  If True, then we
    open in append mode, regardless of fc_levels_already_exist.  The task of
    setting mode="a", mode="r" can be handled by tfk (maybe in tfk.validate()).

    Parameters
    ----------
    dec_level_config: mt_metadata.transfer_functions.processing.aurora.decimation_level.DecimationLevel
        The information about the decimation level associated with row, run, stft_obj
    row: pd.Series
        A row of the TFK.dataset_df
    run_obj: mth5.groups.run.RunGroup
        The run object associated with the STFTs.
    stft_obj: xr.Dataset
        The data to pack into the mth5.
    """
    if not dec_level_config.save_fcs:
        msg = f"Skip saving FCs. dec_level_config.save_fcs = {dec_level_config.save_fcs}"
        logger.info(msg)
        return

    i_dec_level = dec_level_config.decimation.level
    if dec_level_config.save_fcs_type == "csv":
        msg = "Unless debugging or testing, saving FCs to csv is unexpected"
        logger.warning(msg)
        csv_name = f"{row.station}_dec_level_{i_dec_level}.csv"
        stft_df = stft_obj.to_dataframe()
        stft_df.to_csv(csv_name)
    elif dec_level_config.save_fcs_type == "h5":
        logger.info("Saving FC level")
        station_obj = station_obj_from_row(row)

        if not row.mth5_obj.h5_is_write():
            msg = "See Note #1: Logic for building FC layers"
            raise NotImplementedError(msg)

        # Get FC group (create if needed)
        if run_obj.metadata.id in station_obj.fourier_coefficients_group.groups_list:
            fc_group = station_obj.fourier_coefficients_group.get_fc_group(
                run_obj.metadata.id
            )
        else:
            fc_group = station_obj.fourier_coefficients_group.add_fc_group(
                run_obj.metadata.id
            )

        decimation_level_metadata = dec_level_config.to_fc_decimation()

        # Get FC Decimation Level (create if needed)
        dec_level_name = f"{i_dec_level}"
        if dec_level_name in fc_group.groups_list:
            fc_decimation_level = fc_group.get_decimation_level(dec_level_name)
            fc_decimation_level.metadata = decimation_level_metadata
        else:
            fc_decimation_level = fc_group.add_decimation_level(
                dec_level_name,
                decimation_level_metadata=decimation_level_metadata,
            )
        fc_decimation_level.from_xarray(
            stft_obj, decimation_level_metadata.sample_rate
        )
        fc_decimation_level.update_metadata()
        fc_group.update_metadata()
    else:
        msg = f"dec_level_config.save_fcs_type {dec_level_config.save_fcs_type} not recognized"
        msg = f"{msg}\nSkipping save fcs"
        logger.warning(msg)
    return


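# A minimal sketch of the open-mode logic described in Note #1 above.  The
# names `force_rebuild_fcs` and `fc_levels_already_exist` are assumptions
# taken from that docstring, not an existing API; something like this could
# eventually live in tfk.validate():
#
# def choose_h5_mode(save_fcs_type: str, fc_levels_already_exist: bool,
#                    force_rebuild_fcs: bool = False) -> str:
#     """Return "a" (append) when FCs must be (re)built, else "r" (read)."""
#     if save_fcs_type != "h5":
#         return "r"  # nothing to write; open read-only
#     if force_rebuild_fcs or not fc_levels_already_exist:
#         return "a"
#     return "r"

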
def get_spectrograms(tfk, i_dec_level, units="MT"):
    """
    Given a decimation level id, loads a dictionary of all spectrograms from
    information in tfk.
    TODO: Make this a method of TFK

    Parameters
    ----------
    tfk: TransferFunctionKernel
        Tracks the dataset_df and the processing config.
    i_dec_level: integer
        The decimation level of the spectrograms.
    units: str
        "MT" or "SI", likely to be deprecated

    Returns
    -------
    stfts: dict
        The short time fourier transforms for the decimation level as a
        dictionary.  Keys are "local" and "remote".  Values are lists, one
        (element) xr.Dataset per run
    """
    stfts = {}
    stfts["local"] = []
    stfts["remote"] = []

    # Check first if TS processing or accessing FC Levels
    for i, row in tfk.dataset_df.iterrows():
        # This iterator could be updated to iterate over row-pairs if remote
        # is True, corresponding to simultaneous data
        if not tfk.is_valid_dataset(row, i_dec_level):
            continue

        run_obj = row.mth5_obj.from_reference(row.run_hdf5_reference)
        if row.fc:
            stft_obj = load_stft_obj_from_mth5(i_dec_level, row, run_obj)
            stfts = append_chunk_to_stfts(stfts, stft_obj, row.remote)
            continue

        run_xrds = row["run_dataarray"].to_dataset("channel")

        # Musgraves workaround for old MT data
        triage_run_id(row.run, run_obj)

        stft_obj = make_stft_objects(
            tfk.config,
            i_dec_level,
            run_obj,
            run_xrds,
            units,
        )

        # Pack FCs into h5
        dec_level_config = tfk.config.decimations[i_dec_level]
        save_fourier_coefficients(dec_level_config, row, run_obj, stft_obj)

        stfts = append_chunk_to_stfts(stfts, stft_obj, row.remote)

    return stfts


def process_mth5_legacy(
    config,
    tfk_dataset=None,
    units="MT",
    show_plot=False,
    z_file_path=None,
    return_collection=False,
):
    """
    This is the main method used to transform a processing_config and a
    kernel_dataset into a transfer function estimate.

    Parameters
    ----------
    config: mt_metadata.transfer_functions.processing.aurora.Processing or path to json
        All processing parameters
    tfk_dataset: aurora.tf_kernel.dataset.Dataset or None
        Specifies what datasets to process according to config
    units: string
        "MT" or "SI".  To be deprecated once data have units embedded
    show_plot: boolean
        Only used for dev
    z_file_path: string or pathlib.Path
        Target path for a z_file output if desired
    return_collection: boolean
        return_collection=False will return an mt_metadata TF object
        return_collection=True will return
        aurora.transfer_function.transfer_function_collection.TransferFunctionCollection

    Returns
    -------
    tf_collection or tf_cls: TransferFunctionCollection or mt_metadata.transfer_functions.TF
        The transfer function object; a TransferFunctionCollection if
        return_collection is True, otherwise an mt_metadata TF object.
    """
    # Initialize config and mth5s
    tfk = TransferFunctionKernel(dataset=tfk_dataset, config=config)
    tfk.make_processing_summary()
    tfk.show_processing_summary()
    tfk.validate()

    tfk.initialize_mth5s()

    msg = (
        f"Processing config indicates {len(tfk.config.decimations)} "
        f"decimation levels"
    )
    logger.info(msg)
    tf_dict = {}

    for i_dec_level, dec_level_config in enumerate(tfk.valid_decimations()):
        # if not tfk.all_fcs_already_exist():
        tfk.update_dataset_df(i_dec_level)
        tfk.apply_clock_zero(dec_level_config)

        stfts = get_spectrograms(tfk, i_dec_level, units=units)

        local_merged_stft_obj, remote_merged_stft_obj = merge_stfts(stfts, tfk)

        # FC TF Interface here (see Note #3)
        # Feature Extraction, Selection of weights

        ttfz_obj = process_tf_decimation_level(
            tfk.config,
            i_dec_level,
            local_merged_stft_obj,
            remote_merged_stft_obj,
        )
        ttfz_obj.apparent_resistivity(tfk.config.channel_nomenclature, units=units)
        tf_dict[i_dec_level] = ttfz_obj

        if show_plot:
            from aurora.sandbox.plot_helpers import plot_tf_obj

            plot_tf_obj(ttfz_obj, out_filename="")

    tf_collection = TransferFunctionCollection(
        tf_dict=tf_dict, processing_config=tfk.config
    )

    tf_cls = tfk.export_tf_collection(tf_collection)

    if z_file_path:
        tf_cls.write(z_file_path)

    tfk.dataset.close_mth5s()
    if return_collection:
        # this is now really only to be used for debugging and may be deprecated soon
        return tf_collection
    else:
        return tf_cls


def process_mth5(
    config,
    tfk_dataset=None,
    units="MT",
    show_plot=False,
    z_file_path=None,
    return_collection=False,
    processing_type="legacy",
):
    """
    This is a pass-through method that routes the config and tfk_dataset to
    MT data processing.  It currently only supports legacy aurora processing.

    Parameters
    ----------
    config: mt_metadata.transfer_functions.processing.aurora.Processing or path to json
        All processing parameters
    tfk_dataset: aurora.tf_kernel.dataset.Dataset or None
        Specifies what datasets to process according to config
    units: string
        "MT" or "SI".  To be deprecated once data have units embedded
    show_plot: boolean
        Only used for dev
    z_file_path: string or pathlib.Path
        Target path for a z_file output if desired
    return_collection: boolean
        return_collection=False will return an mt_metadata TF object
        return_collection=True will return
        aurora.transfer_function.transfer_function_collection.TransferFunctionCollection
    processing_type: string
        Controlled vocabulary, must be one of ["legacy",].
        This is not really supported now, but the idea is that in future,
        the config and tfk_dataset can be passed to another processing method
        if desired.

    Returns
    -------
    tf_obj: TransferFunctionCollection or mt_metadata.transfer_functions.TF
        The transfer function object
    """
    if processing_type not in SUPPORTED_PROCESSINGS:
        raise NotImplementedError(f"Processing type {processing_type} not supported")

    if processing_type == "legacy":
        try:
            return process_mth5_legacy(
                config,
                tfk_dataset=tfk_dataset,
                units=units,
                show_plot=show_plot,
                z_file_path=z_file_path,
                return_collection=return_collection,
            )
        except Exception as e:
            close_open_files()
            msg = "Failed to run legacy processing\n"
            msg += "closing all open mth5 files and exiting\n"
            msg += f"The encountered exception was {e}"
            logger.error(msg)
            logger.exception(msg)
            return
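

# Example usage (an illustrative sketch, not part of this module; the file
# path and station id are hypothetical, and the RunSummary / KernelDataset /
# ConfigCreator imports are assumed from typical aurora workflows):
#
# from mth5.processing import RunSummary, KernelDataset
# from aurora.config.config_creator import ConfigCreator
#
# run_summary = RunSummary()
# run_summary.from_mth5s(["example.h5"])  # hypothetical mth5 file
# kernel_dataset = KernelDataset()
# kernel_dataset.from_run_summary(run_summary, "MT001")  # hypothetical station
# config = ConfigCreator().create_from_kernel_dataset(kernel_dataset)
# tf_cls = process_mth5(config, kernel_dataset, z_file_path="MT001.zss")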