Skip to content
This repository has been archived by the owner on Jun 2, 2023. It is now read-only.

Commit

Permalink
[#98] multitask nse, kge functions; rm weights
Browse files Browse the repository at this point in the history
  • Loading branch information
jsadler2 committed Jun 4, 2021
1 parent 2c7580b commit 3799509
Showing 1 changed file with 8 additions and 32 deletions.
40 changes: 8 additions & 32 deletions river_dl/loss_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import numpy as np
import tensorflow as tf


Expand Down Expand Up @@ -69,38 +68,20 @@ def samplewise_nnse_loss(y_true, y_pred):
return 1 - nnse_val


def nnse_masked_one_var(data, y_pred, var_idx, tasks):
y_true, y_pred, weights = y_data_components(data, y_pred, var_idx, tasks)
return nnse_loss(y_true, y_pred)
def multitask_nse(lambdas):
return multitask_loss(lambdas, nnse_loss)


def nnse_one_var_samplewise(data, y_pred, var_idx, tasks):
y_true, y_pred, weights = y_data_components(data, y_pred, var_idx, tasks)
return samplewise_nnse_loss(y_true, y_pred)
def multitask_samplewise_nse(lambdas):
return multitask_loss(lambdas, samplewise_nnse_loss)


def y_data_components(data, y_pred, var_idx, tasks):
weights = data[:, :, -tasks:]
y_true = data[:, :, :-tasks]

# ensure y_pred, weights, and y_true are all tensors the same data type
y_true = tf.convert_to_tensor(y_true)
weights = tf.convert_to_tensor(weights)
y_true = tf.cast(y_true, y_pred.dtype)
weights = tf.cast(weights, y_pred.dtype)

# make all zero-weighted observations 'nan' so they don't get counted
# at all in the loss calculation
y_true = tf.where(weights == 0, np.nan, y_true)
def multitask_rmse(lambdas):
return multitask_loss(lambdas, rmse)

weights = weights[:, :, var_idx]
y_true = y_true[:, :, var_idx]
y_pred = y_pred[:, :, var_idx]
return y_true, y_pred, weights


def weighted_masked_rmse(lambdas):
return multitask_loss(lambdas, rmse)
def multitask_kge(lambdas):
return multitask_loss(lambdas, kge_loss)


def multitask_loss(lambdas, loss_func):
Expand Down Expand Up @@ -185,10 +166,5 @@ def kge_norm_loss(y_true, y_pred):
return 1 - norm_kge(y_true, y_pred)


def kge_loss_one_var(data, y_pred, var_idx, tasks):
y_true, y_pred, weights = y_data_components(data, y_pred, var_idx, tasks)
return kge_loss(y_true, y_pred)


def kge_loss(y_true, y_pred):
return -1 * kge(y_true, y_pred)

0 comments on commit 3799509

Please sign in to comment.