Skip to content

Commit

Permalink
Towards issue #119
Browse files Browse the repository at this point in the history
- tidied up workflow in process_transfer_functions
- created placeholders for segment weights
- factored out some methods applied to X, Y, RR for readability
  • Loading branch information
kkappler committed Jan 19, 2024
1 parent aad4d79 commit 93e5785
Showing 1 changed file with 67 additions and 35 deletions.
102 changes: 67 additions & 35 deletions aurora/pipelines/transfer_function_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from aurora.time_series.frequency_band_helpers import get_band_for_tf_estimate
from aurora.time_series.xarray_helpers import handle_nan
from aurora.transfer_function.regression.base import RegressionEstimator
from aurora.transfer_function.regression.iter_control import IterControl
from aurora.transfer_function.regression.TRME import TRME
from aurora.transfer_function.regression.TRME_RR import TRME_RR

from aurora.transfer_function.regression.base import RegressionEstimator
# from aurora.transfer_function.weights.coherence_weights import compute_multiple_coherence_weights
from aurora.transfer_function.weights.edf_weights import (
effective_degrees_of_freedom_weights,
)
Expand Down Expand Up @@ -89,14 +90,50 @@ def select_channel(xrda, channel_label):
return ch


def dropna(X, Y, RR):
"""
Just a helper intended to enhance readability
TODO: document the implications of dropna on index of xarray for other weights
"""
X = X.dropna(dim="observation")
Y = Y.dropna(dim="observation")
if RR is not None:
RR = RR.dropna(dim="observation")
return X, Y, RR


def stack_fcs(X, Y, RR):
"""Reshape 2D arrays of frequency and time to 1D"""
X = X.stack(observation=("frequency", "time"))
Y = Y.stack(observation=("frequency", "time"))
if RR is not None:
RR = RR.stack(observation=("frequency", "time"))
return X, Y, RR


def apply_weights(X, Y, RR, W, segment=False, dropna=False):
W[W == 0] = np.nan
if segment:
W = np.atleast_2d(W).T
X *= W
Y *= W
if RR is not None:
RR *= W

if dropna:
X, Y, RR = dropna(X, Y, RR)

return X, Y, RR


def process_transfer_functions(
dec_level_config,
local_stft_obj,
remote_stft_obj,
transfer_function_obj,
segment_weights=None,
# segment_weights=["jj84_coherence_weights",],
segment_weights=[],
channel_weights=None,
# use_multiple_coherence_weights=False,
):
"""
This method based on TTFestBand.m
Expand Down Expand Up @@ -131,51 +168,46 @@ def process_transfer_functions(
estimator_class = get_estimator_class(dec_level_config.estimator.engine)
iter_control = set_up_iter_control(dec_level_config)
for band in transfer_function_obj.frequency_bands.bands():
# if use_multiple_coherence_weights:
# from aurora.transfer_function.weights.coherence_weights import compute_multiple_coherence_weights
# Wmc = compute_multiple_coherence_weights(band, local_stft_obj, remote_stft_obj)

X, Y, RR = get_band_for_tf_estimate(
band, dec_level_config, local_stft_obj, remote_stft_obj
)
# if there are segment weights apply them here

# Apply segment weights first
# This could be replaced by a method that combines (product) all segment weights in a dict
# weights = {}
if "jj84_coherence_weights" in segment_weights:
from aurora.transfer_function.weights.coherence_weights import (
coherence_weights_jj84,
)

Wjj84 = coherence_weights_jj84(band, local_stft_obj, remote_stft_obj)
apply_weights(X, Y, RR, Wjj84, segment=True, dropna=False)

# if multiple_coherence_weights in segment_weights:
# from aurora.transfer_function.weights.coherence_weights import compute_multiple_coherence_weights
# Wmc = compute_multiple_coherence_weights(band, local_stft_obj, remote_stft_obj)
# apply_segment_weights(X, Y, RR, Wmc)

# if there are channel weights apply them here
# Reshape to 2d - maybe push this into extract band method
X = X.stack(observation=("frequency", "time"))
Y = Y.stack(observation=("frequency", "time"))
if RR is not None:
RR = RR.stack(observation=("frequency", "time"))

# Reshape to 2d
X, Y, RR = stack_fcs(X, Y, RR)

# Should only be needed if weights were applied
X, Y, RR = dropna(X, Y, RR)

W = effective_degrees_of_freedom_weights(X, RR, edf_obj=None)
# if use_multiple_coherence_weights:
# W *= Wmc
W[W == 0] = np.nan # use this to drop values in the handle_nan
# apply weights
X *= W
Y *= W
if RR is not None:
RR *= W
X = X.dropna(dim="observation")
Y = Y.dropna(dim="observation")
if RR is not None:
RR = RR.dropna(dim="observation")

# COHERENCE SORTING
# coh_type = "local"
# if dec_level_config.decimation.level == 0:
# from aurora.transfer_function.weights.coherence_weights import coherence_weights_jj84
# X, Y, RR = coherence_weights_jj84(X,Y,RR, coh_type=coh_type)
X, Y, RR = apply_weights(X, Y, RR, W, segment=False, dropna=True)

if dec_level_config.estimator.estimate_per_channel:
for ch in dec_level_config.output_channels:
Y_ch = Y[ch].to_dataset() # keep as a dataset, maybe not needed

X_, Y_, RR_ = handle_nan(X, Y_ch, RR, drop_dim="observation")

# W = effective_degrees_of_freedom_weights(X_, RR_, edf_obj=None)
# X_ *= W
# Y_ *= W
# if RR is not None:
# RR_ *= W
W = effective_degrees_of_freedom_weights(X_, RR_, edf_obj=None)
X_, Y_, RR_ = apply_weights(X_, Y_, RR_, W, segment=False)

regression_estimator = estimator_class(
X=X_, Y=Y_, Z=RR_, iter_control=iter_control
Expand Down

0 comments on commit 93e5785

Please sign in to comment.