Merge pull request #163 from asogaard/target-dimensions
Ensure targets always have two dimensions
asogaard authored Feb 8, 2022
2 parents 141d3dd + 58682f8 commit 63390a8
Showing 2 changed files with 6 additions and 7 deletions.
9 changes: 4 additions & 5 deletions src/graphnet/components/loss_functions.py
@@ -70,8 +70,7 @@ def _log_cosh(cls, x: Tensor) -> Tensor:  # pylint: disable=invalid-name

     def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
         """Implementation of loss calculation."""
-        assert prediction.dim() == target.dim() + 1
-        diff = prediction[:,0] - target
+        diff = prediction - target
         elements = self._log_cosh(diff)
         return elements
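
The `_log_cosh` helper itself is collapsed above. As a point of reference, here is a minimal standalone sketch of a log-cosh loss under the new shape convention, assuming the helper uses the standard numerically stable identity log(cosh(x)) = x + softplus(-2x) - log(2) (plain `torch.cosh` overflows in float32 for |x| ≳ 89):

    import math
    import torch
    import torch.nn.functional as F
    from torch import Tensor

    def log_cosh(x: Tensor) -> Tensor:
        # Stable for large |x|: equals log(cosh(x)) without computing cosh directly.
        return x + F.softplus(-2.0 * x) - math.log(2.0)

    prediction = torch.randn(5, 1)            # Shape [N, 1]
    target = torch.zeros(5, 1)                # Shape [N, 1] -- always two dimensions
    elements = log_cosh(prediction - target)  # Shape [N, 1], purely elementwise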

@@ -226,11 +225,11 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
"""
# Check(s)
assert prediction.dim() == 2 and prediction.size()[1] == 2
assert target.dim() == 1
assert target.dim() == 2
assert prediction.size()[0] == target.size()[0]

# Formatting target
angle_true = target
angle_true = target[:,0]
t = torch.stack([
torch.cos(angle_true),
torch.sin(angle_true),
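
The reformatted target maps each true angle onto the unit circle so that the two-column prediction can be compared against a unit vector. A standalone sketch of that conversion, with hypothetical data shaped to satisfy the assertions above:

    import math
    import torch

    prediction = torch.randn(5, 2)           # [N, 2]: one 2D vector per event
    target = torch.rand(5, 1) * 2 * math.pi  # [N, 1]: true angle, now always 2D

    angle_true = target[:, 0]                # [N,]: drop the trailing dimension
    t = torch.stack([
        torch.cos(angle_true),
        torch.sin(angle_true),
    ], dim=1)                                # [N, 2]: unit vector (cos θ, sin θ)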
@@ -248,6 +247,6 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:

 class XYZWithMaxScaling(LossFunction):
     def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
-        diff = (prediction[:,0] - target[:,0]/764.431509)**2 + (prediction[:,1] - target[:,1]/785.041607)**2 + (prediction[:,2] - target[:,2]/1083.249944)**2  #+(prediction[:,3] - target[:,3]/14721.646883)
+        diff = (prediction[:,0] - target[:,0]/764.431509)**2 + (prediction[:,1] - target[:,1]/785.041607)**2 + (prediction[:,2] - target[:,2]/1083.249944)**2  #+(prediction[:,3] - target[:,3]/14721.646883)
         elements = torch.sqrt(diff)
         return elements
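
The hard-coded divisors scale the target coordinates by per-axis maxima so they match the units of the (already scaled) prediction, and the result is a per-event Euclidean distance. An equivalent vectorized sketch, using the same constants from the diff (the commented-out fourth term hints at an optional fourth component with divisor 14721.646883):

    import torch

    # Per-axis maxima taken from the diff above.
    SCALE_XYZ = torch.tensor([764.431509, 785.041607, 1083.249944])

    def xyz_distance(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Elementwise difference between scaled quantities, then the
        # per-event Euclidean norm, matching torch.sqrt(diff) above.
        diff = prediction[:, :3] - target[:, :3] / SCALE_XYZ
        return torch.sqrt((diff ** 2).sum(dim=1))  # Shape [N,]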
4 changes: 2 additions & 2 deletions tests/test_loss_functions.py
@@ -38,12 +38,12 @@ def _compute_elementwise_gradient(outputs: Tensor, inputs: Tensor) -> Tensor:
 def test_log_cosh(dtype=torch.float32):
     # Prepare test data
     x = torch.tensor([-100, -10, -1, 0, 1, 10, 100], dtype=dtype).unsqueeze(1)  # Shape [N, 1]
-    y = 0. * x.clone().squeeze()  # Shape [N,]
+    y = 0. * x.clone()  # Shape [N, 1]

     # Calculate losses using loss function, and manually
     log_cosh_loss = LogCoshLoss()
     losses = log_cosh_loss(x, y, return_elements=True)
-    losses_reference = torch.log(torch.cosh(x[:,0] - y))
+    losses_reference = torch.log(torch.cosh(x - y))

     # (1) Loss functions should not return `inf` losses, even for large
     # differences between prediction and target. This is not necessarily
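The old one-dimensional target shape is exactly where PyTorch broadcasting bites: subtracting an [N,] tensor from an [N, 1] tensor aligns shapes from the right and silently produces an [N, N] matrix instead of an elementwise difference. A small sketch of the pitfall the two-dimensional convention avoids:

    import torch

    x = torch.randn(7, 1)      # [N, 1], like the test's predictions
    y_old = torch.zeros(7)     # [N,]   -- old target convention
    y_new = torch.zeros(7, 1)  # [N, 1] -- new target convention

    print((x - y_old).shape)   # torch.Size([7, 7]): silent outer broadcast
    print((x - y_new).shape)   # torch.Size([7, 1]): elementwise, as intended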
