From 37995092a6ebadd710211f70f08cacac4fb068ff Mon Sep 17 00:00:00 2001 From: jsadler2 Date: Fri, 4 Jun 2021 11:54:42 -0500 Subject: [PATCH] [#98] multitask nse, kge functions; rm weights --- river_dl/loss_functions.py | 40 ++++++++------------------------------ 1 file changed, 8 insertions(+), 32 deletions(-) diff --git a/river_dl/loss_functions.py b/river_dl/loss_functions.py index 431b972..198daf8 100644 --- a/river_dl/loss_functions.py +++ b/river_dl/loss_functions.py @@ -1,4 +1,3 @@ -import numpy as np import tensorflow as tf @@ -69,38 +68,20 @@ def samplewise_nnse_loss(y_true, y_pred): return 1 - nnse_val -def nnse_masked_one_var(data, y_pred, var_idx, tasks): - y_true, y_pred, weights = y_data_components(data, y_pred, var_idx, tasks) - return nnse_loss(y_true, y_pred) +def multitask_nse(lambdas): + return multitask_loss(lambdas, nnse_loss) -def nnse_one_var_samplewise(data, y_pred, var_idx, tasks): - y_true, y_pred, weights = y_data_components(data, y_pred, var_idx, tasks) - return samplewise_nnse_loss(y_true, y_pred) +def multitask_samplewise_nse(lambdas): + return multitask_loss(lambdas, samplewise_nnse_loss) -def y_data_components(data, y_pred, var_idx, tasks): - weights = data[:, :, -tasks:] - y_true = data[:, :, :-tasks] - - # ensure y_pred, weights, and y_true are all tensors the same data type - y_true = tf.convert_to_tensor(y_true) - weights = tf.convert_to_tensor(weights) - y_true = tf.cast(y_true, y_pred.dtype) - weights = tf.cast(weights, y_pred.dtype) - - # make all zero-weighted observations 'nan' so they don't get counted - # at all in the loss calculation - y_true = tf.where(weights == 0, np.nan, y_true) +def multitask_rmse(lambdas): + return multitask_loss(lambdas, rmse) - weights = weights[:, :, var_idx] - y_true = y_true[:, :, var_idx] - y_pred = y_pred[:, :, var_idx] - return y_true, y_pred, weights - -def weighted_masked_rmse(lambdas): - return multitask_loss(lambdas, rmse) +def multitask_kge(lambdas): + return multitask_loss(lambdas, kge_loss) def multitask_loss(lambdas, loss_func): @@ -185,10 +166,5 @@ def kge_norm_loss(y_true, y_pred): return 1 - norm_kge(y_true, y_pred) -def kge_loss_one_var(data, y_pred, var_idx, tasks): - y_true, y_pred, weights = y_data_components(data, y_pred, var_idx, tasks) - return kge_loss(y_true, y_pred) - - def kge_loss(y_true, y_pred): return -1 * kge(y_true, y_pred)