Skip to content

Commit

Permalink
Update unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
asogaard committed Feb 8, 2022
1 parent 9dbcd16 commit 58682f8
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/test_loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ def _compute_elementwise_gradient(outputs: Tensor, inputs: Tensor) -> Tensor:
def test_log_cosh(dtype=torch.float32):
# Prepare test data
x = torch.tensor([-100, -10, -1, 0, 1, 10, 100], dtype=dtype).unsqueeze(1) # Shape [N, 1]
y = 0. * x.clone().squeeze() # Shape [N,]
y = 0. * x.clone() # Shape [N,1]

# Calculate losses using loss function, and manually
log_cosh_loss = LogCoshLoss()
losses = log_cosh_loss(x, y, return_elements=True)
losses_reference = torch.log(torch.cosh(x[:,0] - y))
losses_reference = torch.log(torch.cosh(x - y))

# (1) Loss functions should not return `inf` losses, even for large
# differences between prediction and target. This is not necessarily
Expand Down

0 comments on commit 58682f8

Please sign in to comment.