From 8591d7d2fed3fde1440f976596d1be36c4c2dde9 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Thu, 11 Jan 2024 11:58:18 +0100 Subject: [PATCH] Add tests for losses --- src/metatensor/models/utils/loss.py | 2 +- tests/utils/test_loss.py | 174 ++++++++++++++++++++++++++++ 2 files changed, 175 insertions(+), 1 deletion(-) create mode 100644 tests/utils/test_loss.py diff --git a/src/metatensor/models/utils/loss.py b/src/metatensor/models/utils/loss.py index 2883ee881..1d939ea17 100644 --- a/src/metatensor/models/utils/loss.py +++ b/src/metatensor/models/utils/loss.py @@ -44,7 +44,7 @@ def __call__(self, tensor_map_1: TensorMap, tensor_map_2: TensorMap) -> torch.Te loss = torch.zeros((), dtype=tensor_map_1.block().values.dtype, device=tensor_map_1.block().values.device) loss += self.weight * self.loss(tensor_map_1.block().values, tensor_map_2.block().values) for gradient_name, gradient_weight in self.gradient_weights.items(): - loss += gradient_weight * self.loss(tensor_map_1.gradient(gradient_name).values, tensor_map_2.gradient(gradient_name).values) + loss += gradient_weight * self.loss(tensor_map_1.block().gradient(gradient_name).values, tensor_map_2.block().gradient(gradient_name).values) return loss diff --git a/tests/utils/test_loss.py b/tests/utils/test_loss.py new file mode 100644 index 000000000..2db407f51 --- /dev/null +++ b/tests/utils/test_loss.py @@ -0,0 +1,174 @@ +import pytest + +import torch +from metatensor.torch import TensorMap, TensorBlock, Labels + +from metatensor.models.utils.loss import TensorMapLoss, TensorMapDictLoss + + +@pytest.fixture +def tensor_map_with_grad_1(): + block = TensorBlock( + values=torch.tensor([[1.0], [2.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels.single(), + ) + block.add_gradient( + "gradient", + TensorBlock( + values=torch.tensor([[1.0], [2.0], [3.0]]), + samples=Labels.range("sample", 3), + components=[], + properties=Labels.single(), + ) + ) + tensor_map = TensorMap( + keys=Labels.single(), + blocks=[block] + ) + return tensor_map + +@pytest.fixture +def tensor_map_with_grad_2(): + block = TensorBlock( + values=torch.tensor([[1.0], [1.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels.single(), + ) + block.add_gradient( + "gradient", + TensorBlock( + values=torch.tensor([[1.0], [0.0], [3.0]]), + samples=Labels.range("sample", 3), + components=[], + properties=Labels.single(), + ) + ) + tensor_map = TensorMap( + keys=Labels.single(), + blocks=[block] + ) + return tensor_map + +@pytest.fixture +def tensor_map_with_grad_3(): + block = TensorBlock( + values=torch.tensor([[0.0], [1.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels.single(), + ) + block.add_gradient( + "gradient", + TensorBlock( + values=torch.tensor([[1.0], [0.0], [3.0]]), + samples=Labels.range("sample", 3), + components=[], + properties=Labels.single(), + ) + ) + tensor_map = TensorMap( + keys=Labels.single(), + blocks=[block] + ) + return tensor_map + +@pytest.fixture +def tensor_map_with_grad_4(): + block = TensorBlock( + values=torch.tensor([[0.0], [1.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels.single(), + ) + block.add_gradient( + "gradient", + TensorBlock( + values=torch.tensor([[1.0], [0.0], [2.0]]), + samples=Labels.range("sample", 3), + components=[], + properties=Labels.single(), + ) + ) + tensor_map = TensorMap( + keys=Labels.single(), + blocks=[block] + ) + return tensor_map + + +def test_tmap_loss_no_gradients(): + """Test that the loss is computed correctly when there are no gradients.""" + loss = TensorMapLoss() + + tensor_map_1 = TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor([[1.0], [2.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels.single(), + ) + ] + ) + tensor_map_2 = TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor([[0.0], [2.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels.single(), + ) + ] + ) + + assert torch.allclose(loss(tensor_map_1, tensor_map_1), torch.tensor(0.0)) + + # Expected result: 1.0/3.0 (there are three values) + assert torch.allclose(loss(tensor_map_1, tensor_map_2), torch.tensor(1.0/3.0)) + + +def test_tmap_loss_with_gradients(tensor_map_with_grad_1, tensor_map_with_grad_2): + """Test that the loss is computed correctly when there are gradients.""" + loss = TensorMapLoss(gradient_weights={"gradient": 0.5}) + + assert torch.allclose(loss(tensor_map_with_grad_1, tensor_map_with_grad_1), torch.tensor(0.0)) + + # Expected result: 1.0/3.0 + 0.5 * 4.0 / 3.0 (there are three values) + assert torch.allclose(loss(tensor_map_with_grad_1, tensor_map_with_grad_2), torch.tensor(1.0/3.0 + 0.5 * 4.0 / 3.0)) + + +def test_tmap_dict_loss(tensor_map_with_grad_1, tensor_map_with_grad_2, tensor_map_with_grad_3, tensor_map_with_grad_4): + """Test that the dict loss is computed correctly.""" + + loss = TensorMapDictLoss( + weights={ + "output_1": {"values": 1.0, "gradient": 0.5}, + "output_2": {"values": 1.0, "gradient": 0.5}, + } + ) + + output_dict = { + "output_1": tensor_map_with_grad_1, + "output_2": tensor_map_with_grad_2, + } + + target_dict = { + "output_1": tensor_map_with_grad_3, + "output_2": tensor_map_with_grad_4, + } + + expected_result = ( + 1.0 * (tensor_map_with_grad_1.block().values - tensor_map_with_grad_3.block().values).pow(2).mean() + + 0.5 * (tensor_map_with_grad_1.block().gradient("gradient").values - tensor_map_with_grad_3.block().gradient("gradient").values).pow(2).mean() + + 1.0 * (tensor_map_with_grad_2.block().values - tensor_map_with_grad_4.block().values).pow(2).mean() + + 0.5 * (tensor_map_with_grad_2.block().gradient("gradient").values - tensor_map_with_grad_4.block().gradient("gradient").values).pow(2).mean() + ) + + assert torch.allclose(loss(output_dict, target_dict), expected_result) + +