Skip to content

Commit

Permalink
Fix bad use of torch.squeeze.
Browse files Browse the repository at this point in the history
  • Loading branch information
Shyue Ping Ong committed Jun 21, 2023
1 parent 701017a commit 340b12c
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions matgl/utils/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def step(self, batch: tuple):
batch_size = preds.numel()
return results, batch_size

def loss_fn(self, loss: nn.Module, labels: tuple, preds: tuple):
def loss_fn(self, loss: nn.Module, labels: torch.Tensor, preds: torch.Tensor):
"""Args:
loss: Loss function.
labels: Labels to compute the loss.
Expand All @@ -221,9 +221,10 @@ def loss_fn(self, loss: nn.Module, labels: tuple, preds: tuple):
Returns:
{"Total_Loss": total_loss, "MAE": mae, "RMSE": rmse}
"""
total_loss = loss(labels, torch.squeeze(preds * self.data_std + self.data_mean))
mae = self.mae(labels, torch.squeeze(preds * self.data_std + self.data_mean))
rmse = self.rmse(labels, torch.squeeze(preds * self.data_std + self.data_mean))
scaled_pred = torch.reshape(preds * self.data_std + self.data_mean, labels.size())
total_loss = loss(labels, scaled_pred)
mae = self.mae(labels, scaled_pred)
rmse = self.rmse(labels, scaled_pred)
return {"Total_Loss": total_loss, "MAE": mae, "RMSE": rmse}


Expand Down

0 comments on commit 340b12c

Please sign in to comment.