Commit 8591d7d

Add tests for losses
frostedoyster committed Jan 11, 2024
1 parent e65679d commit 8591d7d
Showing 2 changed files with 175 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/metatensor/models/utils/loss.py
@@ -44,7 +44,7 @@ def __call__(self, tensor_map_1: TensorMap, tensor_map_2: TensorMap) -> torch.Tensor:
        loss = torch.zeros((), dtype=tensor_map_1.block().values.dtype, device=tensor_map_1.block().values.device)
        loss += self.weight * self.loss(tensor_map_1.block().values, tensor_map_2.block().values)
        for gradient_name, gradient_weight in self.gradient_weights.items():
-            loss += gradient_weight * self.loss(tensor_map_1.gradient(gradient_name).values, tensor_map_2.gradient(gradient_name).values)
+            loss += gradient_weight * self.loss(tensor_map_1.block().gradient(gradient_name).values, tensor_map_2.block().gradient(gradient_name).values)

        return loss
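
The one-line change above follows from how metatensor stores gradients: they are attached to each TensorBlock rather than to the TensorMap itself, so they have to be reached through tensor_map.block(). A minimal sketch of that access pattern, not part of the commit and simply mirroring the fixtures added below:

import torch
from metatensor.torch import Labels, TensorBlock, TensorMap

# One block with three samples, one property, and a gradient named "gradient",
# exactly as in the test fixtures below.
block = TensorBlock(
    values=torch.tensor([[1.0], [2.0], [3.0]]),
    samples=Labels.range("samples", 3),
    components=[],
    properties=Labels.single(),
)
block.add_gradient(
    "gradient",
    TensorBlock(
        values=torch.tensor([[1.0], [2.0], [3.0]]),
        samples=Labels.range("sample", 3),
        components=[],
        properties=Labels.single(),
    ),
)
tensor_map = TensorMap(keys=Labels.single(), blocks=[block])

# Gradients hang off the block, hence the tensor_map.block().gradient(...) fix:
gradient_values = tensor_map.block().gradient("gradient").values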

174 changes: 174 additions & 0 deletions tests/utils/test_loss.py
@@ -0,0 +1,174 @@
import pytest

import torch
from metatensor.torch import TensorMap, TensorBlock, Labels

from metatensor.models.utils.loss import TensorMapLoss, TensorMapDictLoss


@pytest.fixture
def tensor_map_with_grad_1():
    block = TensorBlock(
        values=torch.tensor([[1.0], [2.0], [3.0]]),
        samples=Labels.range("samples", 3),
        components=[],
        properties=Labels.single(),
    )
    block.add_gradient(
        "gradient",
        TensorBlock(
            values=torch.tensor([[1.0], [2.0], [3.0]]),
            samples=Labels.range("sample", 3),
            components=[],
            properties=Labels.single(),
        )
    )
    tensor_map = TensorMap(
        keys=Labels.single(),
        blocks=[block]
    )
    return tensor_map

@pytest.fixture
def tensor_map_with_grad_2():
    block = TensorBlock(
        values=torch.tensor([[1.0], [1.0], [3.0]]),
        samples=Labels.range("samples", 3),
        components=[],
        properties=Labels.single(),
    )
    block.add_gradient(
        "gradient",
        TensorBlock(
            values=torch.tensor([[1.0], [0.0], [3.0]]),
            samples=Labels.range("sample", 3),
            components=[],
            properties=Labels.single(),
        )
    )
    tensor_map = TensorMap(
        keys=Labels.single(),
        blocks=[block]
    )
    return tensor_map

@pytest.fixture
def tensor_map_with_grad_3():
    block = TensorBlock(
        values=torch.tensor([[0.0], [1.0], [3.0]]),
        samples=Labels.range("samples", 3),
        components=[],
        properties=Labels.single(),
    )
    block.add_gradient(
        "gradient",
        TensorBlock(
            values=torch.tensor([[1.0], [0.0], [3.0]]),
            samples=Labels.range("sample", 3),
            components=[],
            properties=Labels.single(),
        )
    )
    tensor_map = TensorMap(
        keys=Labels.single(),
        blocks=[block]
    )
    return tensor_map

@pytest.fixture
def tensor_map_with_grad_4():
    block = TensorBlock(
        values=torch.tensor([[0.0], [1.0], [3.0]]),
        samples=Labels.range("samples", 3),
        components=[],
        properties=Labels.single(),
    )
    block.add_gradient(
        "gradient",
        TensorBlock(
            values=torch.tensor([[1.0], [0.0], [2.0]]),
            samples=Labels.range("sample", 3),
            components=[],
            properties=Labels.single(),
        )
    )
    tensor_map = TensorMap(
        keys=Labels.single(),
        blocks=[block]
    )
    return tensor_map


def test_tmap_loss_no_gradients():
    """Test that the loss is computed correctly when there are no gradients."""
    loss = TensorMapLoss()

    tensor_map_1 = TensorMap(
        keys=Labels.single(),
        blocks=[
            TensorBlock(
                values=torch.tensor([[1.0], [2.0], [3.0]]),
                samples=Labels.range("samples", 3),
                components=[],
                properties=Labels.single(),
            )
        ]
    )
    tensor_map_2 = TensorMap(
        keys=Labels.single(),
        blocks=[
            TensorBlock(
                values=torch.tensor([[0.0], [2.0], [3.0]]),
                samples=Labels.range("samples", 3),
                components=[],
                properties=Labels.single(),
            )
        ]
    )

    assert torch.allclose(loss(tensor_map_1, tensor_map_1), torch.tensor(0.0))

    # Expected result: 1.0/3.0 (a squared error of 1.0 on one of the three values,
    # averaged over all three)
    assert torch.allclose(loss(tensor_map_1, tensor_map_2), torch.tensor(1.0/3.0))


def test_tmap_loss_with_gradients(tensor_map_with_grad_1, tensor_map_with_grad_2):
    """Test that the loss is computed correctly when there are gradients."""
    loss = TensorMapLoss(gradient_weights={"gradient": 0.5})

    assert torch.allclose(loss(tensor_map_with_grad_1, tensor_map_with_grad_1), torch.tensor(0.0))

    # Expected result: 1.0/3.0 + 0.5 * 4.0/3.0 (the values differ by 1.0 on one of the
    # three entries, the gradients differ by 2.0 on one entry, and the gradient term
    # carries a weight of 0.5)
    assert torch.allclose(loss(tensor_map_with_grad_1, tensor_map_with_grad_2), torch.tensor(1.0/3.0 + 0.5 * 4.0 / 3.0))


def test_tmap_dict_loss(tensor_map_with_grad_1, tensor_map_with_grad_2, tensor_map_with_grad_3, tensor_map_with_grad_4):
    """Test that the dict loss is computed correctly."""

    loss = TensorMapDictLoss(
        weights={
            "output_1": {"values": 1.0, "gradient": 0.5},
            "output_2": {"values": 1.0, "gradient": 0.5},
        }
    )

    output_dict = {
        "output_1": tensor_map_with_grad_1,
        "output_2": tensor_map_with_grad_2,
    }

    target_dict = {
        "output_1": tensor_map_with_grad_3,
        "output_2": tensor_map_with_grad_4,
    }

    expected_result = (
        1.0 * (tensor_map_with_grad_1.block().values - tensor_map_with_grad_3.block().values).pow(2).mean() +
        0.5 * (tensor_map_with_grad_1.block().gradient("gradient").values - tensor_map_with_grad_3.block().gradient("gradient").values).pow(2).mean() +
        1.0 * (tensor_map_with_grad_2.block().values - tensor_map_with_grad_4.block().values).pow(2).mean() +
        0.5 * (tensor_map_with_grad_2.block().gradient("gradient").values - tensor_map_with_grad_4.block().gradient("gradient").values).pow(2).mean()
    )

    assert torch.allclose(loss(output_dict, target_dict), expected_result)
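
For reference, a quick hand check (not part of the diff): plugging the fixture numbers into the weighted sum above, and assuming each contribution is a mean squared error as the .pow(2).mean() terms suggest, gives the value the final assertion compares against.

# output_1 vs. its target: values (1, 2, 3) vs. (0, 1, 3), gradients (1, 2, 3) vs. (1, 0, 3)
mse_values_1 = ((1 - 0) ** 2 + (2 - 1) ** 2 + (3 - 3) ** 2) / 3  # 2/3
mse_grads_1 = ((1 - 1) ** 2 + (2 - 0) ** 2 + (3 - 3) ** 2) / 3   # 4/3

# output_2 vs. its target: values (1, 1, 3) vs. (0, 1, 3), gradients (1, 0, 3) vs. (1, 0, 2)
mse_values_2 = ((1 - 0) ** 2 + (1 - 1) ** 2 + (3 - 3) ** 2) / 3  # 1/3
mse_grads_2 = ((1 - 1) ** 2 + (0 - 0) ** 2 + (3 - 2) ** 2) / 3   # 1/3

# Weighted sum with value weights of 1.0 and gradient weights of 0.5, as above.
expected = 1.0 * mse_values_1 + 0.5 * mse_grads_1 + 1.0 * mse_values_2 + 0.5 * mse_grads_2
print(expected)  # 11/6, roughly 1.8333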

