diff --git a/src/metatensor/models/soap_bpnn/model.py b/src/metatensor/models/soap_bpnn/model.py
index b257e1d4e..50579a63e 100644
--- a/src/metatensor/models/soap_bpnn/model.py
+++ b/src/metatensor/models/soap_bpnn/model.py
@@ -125,7 +125,7 @@ def __init__(
             values=torch.tensor(all_species).reshape(-1, 1),
         )
 
-    def forward(self, systems: List[rascaline.torch.System]) -> Dict[str, TensorMap]:
+    def forward(self, systems: List[metatensor.torch.atomistic.System]) -> Dict[str, TensorMap]:
         soap_features = self.soap_calculator(systems)
 
         device = soap_features.block(0).values.device
diff --git a/src/metatensor/models/utils/compute_loss.py b/src/metatensor/models/utils/compute_loss.py
new file mode 100644
index 000000000..7edb47acf
--- /dev/null
+++ b/src/metatensor/models/utils/compute_loss.py
@@ -0,0 +1,208 @@
+from typing import Dict, List
+
+import torch
+from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatensor.torch.atomistic import System
+
+from .loss import TensorMapDictLoss
+from .output_gradient import compute_gradient
+
+
+def compute_model_loss(
+    loss: TensorMapDictLoss,
+    model: torch.nn.Module,
+    systems: List[System],
+    targets: Dict[str, TensorMap],
+):
+    """
+    Compute the loss of a model on a set of targets.
+
+    This function assumes that the model returns a dictionary of
+    TensorMaps, with the same keys as the targets. Energy targets that
+    carry "positions" and/or "displacements" gradients get the matching
+    gradients of the predicted energy attached before the loss is evaluated.
+
+    :param loss: loss called as ``loss(model_outputs, targets)``
+    :param model: model exposing ``capabilities.outputs`` and callable as
+        ``model(systems, output_names)``
+    :param systems: systems the model is evaluated on
+    :param targets: dictionary mapping target names to target TensorMaps
+    :raises ValueError: if a target is not one of the model's outputs
+    :return: the value returned by ``loss``
+    """
+    # Assert that all targets are within the model's capabilities:
+    if not set(targets.keys()).issubset(model.capabilities.outputs.keys()):
+        raise ValueError("Not all targets are within the model's capabilities.")
+
+    # Find if there are any energy targets that require gradients:
+    energy_targets = []
+    energy_targets_that_require_position_gradients = []
+    energy_targets_that_require_displacement_gradients = []
+    for target_name in targets.keys():
+        # Check if the target is an energy:
+        if model.capabilities.outputs[target_name].quantity == "energy":
+            energy_targets.append(target_name)
+            # Check if the energy requires gradients. `has_gradient` is a
+            # TensorBlock method, so query the block rather than the TensorMap:
+            target_block = targets[target_name].block()
+            if target_block.has_gradient("positions"):
+                energy_targets_that_require_position_gradients.append(target_name)
+            if target_block.has_gradient("displacements"):
+                energy_targets_that_require_displacement_gradients.append(target_name)
+
+    if len(energy_targets_that_require_displacement_gradients) > 0:
+        # TODO: raise an error if the systems do not have a cell
+        # if not all([system.has_cell for system in systems]):
+        #     raise ValueError("One or more systems does not have a cell.")
+        displacements = [
+            torch.eye(
+                3,
+                requires_grad=True,
+                dtype=system.positions.dtype,
+                device=system.positions.device,
+            )
+            for system in systems
+        ]
+        # Create new "displaced" systems:
+        systems = [
+            System(
+                positions=system.positions @ displacement,
+                cell=system.cell @ displacement,
+                species=system.species,
+            )
+            for system, displacement in zip(systems, displacements)
+        ]
+    else:
+        displacements = []
+        if len(energy_targets_that_require_position_gradients) > 0:
+            # Set positions to require gradients:
+            for system in systems:
+                system.positions.requires_grad_(True)
+
+    # Based on the keys of the targets, get the outputs of the model. The keys
+    # are passed as a list since a `dict_keys` view is not a list of strings:
+    model_outputs = model(systems, list(targets.keys()))
+
+    for energy_target in energy_targets:
+        # If the energy target requires gradients, compute them:
+        target_requires_pos_gradients = (
+            energy_target in energy_targets_that_require_position_gradients
+        )
+        target_requires_disp_gradients = (
+            energy_target in energy_targets_that_require_displacement_gradients
+        )
+        if not (target_requires_pos_gradients or target_requires_disp_gradients):
+            continue
+
+        # Differentiate with respect to the positions first and the
+        # displacements second, so the result can be split at `len(systems)`:
+        differentiation_inputs = []
+        if target_requires_pos_gradients:
+            differentiation_inputs += [system.positions for system in systems]
+        if target_requires_disp_gradients:
+            differentiation_inputs += displacements
+
+        old_energy_tensor_map = model_outputs[energy_target]
+        gradients = compute_gradient(
+            old_energy_tensor_map.block().values,
+            differentiation_inputs,
+            is_training=True,
+        )
+
+        new_block = old_energy_tensor_map.block().copy()
+        if target_requires_pos_gradients:
+            new_block.add_gradient(
+                "positions",
+                _position_gradients_to_block(gradients[: len(systems)]),
+            )
+            gradients = gradients[len(systems) :]
+        if target_requires_disp_gradients:
+            new_block.add_gradient(
+                "displacements",
+                _displacement_gradients_to_block(gradients),
+            )
+        model_outputs[energy_target] = TensorMap(
+            keys=old_energy_tensor_map.keys,
+            blocks=[new_block],
+        )
+
+    # Compute the loss:
+    return loss(model_outputs, targets)
+
+
+def _position_gradients_to_block(gradients_list):
+    """Convert a list of position gradients to a `TensorBlock`
+    which can act as a gradient block to an energy block."""
+
+    # `gradients_list` holds one tensor per system whose second dimension is 3.
+    # The first dimension (presumably the atoms of that system) can differ from
+    # one system to the next, so the tensors are concatenated, not stacked:
+    gradients = torch.concatenate(gradients_list, dim=0).unsqueeze(-1)
+    # unsqueeze for the property dimension
+
+    samples = Labels(
+        names=["sample", "atom"],
+        values=torch.stack(
+            [
+                torch.concatenate(
+                    [
+                        torch.tensor([i] * len(structure))
+                        for i, structure in enumerate(gradients_list)
+                    ]
+                ),
+                torch.concatenate(
+                    [torch.arange(len(structure)) for structure in gradients_list]
+                ),
+            ],
+            dim=1,
+        ),
+    )
+
+    components = [
+        Labels(
+            names=["coordinate"],
+            values=torch.tensor([[0], [1], [2]]),
+        )
+    ]
+
+    return TensorBlock(
+        values=gradients,
+        samples=samples,
+        components=components,
+        properties=Labels.single(),
+    )
+
+
+def _displacement_gradients_to_block(gradients_list):
+    """Convert a list of displacement gradients to a `TensorBlock`
+    which can act as a gradient block to an energy block."""
+
+    # `gradients_list` holds one tensor per system (presumably 3x3, like the
+    # `torch.eye(3)` displacements they are gradients of), so they can be stacked:
+    gradients = torch.stack(gradients_list, dim=0).unsqueeze(-1)
+    # unsqueeze for the property dimension
+
+    samples = Labels(
+        names=["sample"],
+        values=torch.arange(len(gradients_list)).unsqueeze(-1),
+    )
+
+    # TODO: check if this makes physical sense
+    # (dimension names use an underscore: Labels names must be valid identifiers)
+    components = [
+        Labels(
+            names=["cell_vector"],
+            values=torch.tensor([[0], [1], [2]]),
+        ),
+        Labels(
+            names=["coordinate"],
+            values=torch.tensor([[0], [1], [2]]),
+        ),
+    ]
+
+    return TensorBlock(
+        values=gradients,
+        samples=samples,
+        components=components,
+        properties=Labels.single(),
+    )
diff --git a/src/metatensor/models/utils/data/dataset.py b/src/metatensor/models/utils/data/dataset.py
index 6dd57405b..7640b3541 100644
--- a/src/metatensor/models/utils/data/dataset.py
+++ b/src/metatensor/models/utils/data/dataset.py
@@ -8,7 +8,7 @@
 
 class Dataset(torch.utils.data.Dataset):
     def __init__(
-        self, structures: List[rascaline.torch.System], targets: Dict[str, TensorMap]
+        self, structures: List[metatensor.torch.atomistic.System], targets: Dict[str, TensorMap]
     ):
         """
-        Creates a dataset from a list of `rascaline.torch.System` objects
+        Creates a dataset from a list of `metatensor.torch.atomistic.System` objects
diff --git a/src/metatensor/models/utils/loss.py b/src/metatensor/models/utils/loss.py
index 1d939ea17..7d10ac556 100644
--- a/src/metatensor/models/utils/loss.py
+++ b/src/metatensor/models/utils/loss.py
@@ -1,12 +1,8 @@
 import metatensor.torch
 from metatensor.torch import TensorMap
-from rascaline.torch.system import System
-
 import torch
 
-from typing import Dict, List, Optional
-
-from .output_gradient import compute_gradient
+from typing import Dict, Optional
 
 
 # This file defines losses for metatensor models.
@@ -85,83 +81,3 @@ def __call__(self, tensor_map_dict_1: Dict[str, TensorMap], tensor_map_dict_2: D
             loss += self.losses[key](tensor_map_dict_1[key], tensor_map_dict_2[key])
 
         return loss
-
-
-def compute_model_loss(
-    loss: TensorMapDictLoss,
-    model: torch.nn.Module,
-    systems: List[System],
-    targets: Dict[str, TensorMap],
-):
-    """
-    Compute the loss of a model on a set of targets.
-
-    This function assumes that the model returns a dictionary of
-    TensorMaps, with the same keys as the targets.
-    """
-    # Assert that all targets are within the model's capabilities:
-    if not set(targets.keys()).issubset(model.capabilities.outputs.keys()):
-        raise ValueError("Not all targets are within the model's capabilities.")
-
-    # Find if there are any energy targets that require gradients:
-    energy_targets = []
-    energy_targets_that_require_position_gradients = []
-    energy_targets_that_require_displacement_gradients = []
-    for target_name in targets.keys():
-        # Check if the target is an energy:
-        if model.capabilities.outputs[target_name].quantity == "energy":
-            energy_targets.append(target_name)
-            # Check if the energy requires gradients:
-            if targets[target_name].has_gradients("positions"):
-                energy_targets_that_require_position_gradients.append(target_name)
-            if targets[target_name].has_gradients("displacements"):
-                energy_targets_that_require_displacement_gradients.append(target_name)
-
-    if len(energy_targets_that_require_displacement_gradients) > 0:
-        # TODO: raise an error if the systems do not have a cell
-        # if not all([system.has_cell for system in systems]):
-        #     raise ValueError("One or more systems does not have a cell.")
-        displacements = [torch.eye(3, requires_grad=True, dtype=system.dtype, device=system.device) for system in systems]
-        # Create new "displaced" systems:
-        systems = [
-            System(
-                positions=system.positions @ displacement,
-                cell=system.cell @ displacement,
-                species=system.species,
-            )
-            for system, displacement in zip(systems, displacements)
-        ]
-    else:
-        if len(energy_targets_that_require_position_gradients) > 0:
-            # Set positions to require gradients:
-            for system in systems:
-                system.positions.requires_grad_(True)
-
-    # Based on the keys of the targets, get the outputs of the model:
-    raw_model_outputs = model(systems, targets.keys())
-
-    for energy_target in energy_targets:
-        # If the energy target requires gradients, compute them:
-        target_requires_pos_gradients = energy_target in energy_targets_that_require_position_gradients
-        target_requires_disp_gradients = energy_target in energy_targets_that_require_displacement_gradients
-        if target_requires_pos_gradients and target_requires_disp_gradients:
-            gradients = compute_gradient(
-                raw_model_outputs[energy_target].block().values,
-                [system.positions for system in systems] + displacements,
-                is_training=True,
-            )
-            new_energy_tensor_map
-        elif target_requires_pos_gradients:
-            gradients = compute_gradient(
-                raw_model_outputs[energy_target].block().values,
-                [system.positions for system in systems],
-                is_training=True,
-            )
-        elif target_requires_disp_gradients:
-            gradients = compute_gradient(
-                raw_model_outputs[energy_target].block().values,
-                displacements,
-                is_training=True,
-            )
-        else:
-            pass
diff --git a/tests/utils/test_compute_loss.py b/tests/utils/test_compute_loss.py
new file mode 100644
index 000000000..b7793a91f
--- /dev/null
+++ b/tests/utils/test_compute_loss.py
@@ -0,0 +1,86 @@
+from pathlib import Path
+
+import torch
+from metatensor.torch import TensorMap, TensorBlock, Labels
+from metatensor.models.utils.loss import TensorMapDictLoss
+from metatensor.models.utils.compute_loss import compute_model_loss
+
+from metatensor.models import soap_bpnn
+from metatensor.models.utils.data import read_structures
+
+
+RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources"
+
+
+def test_compute_model_loss():
+    """Test that the model loss is computed."""
+
+    loss_fn = TensorMapDictLoss(
+        weights={
+            "energy": {"values": 1.0, "positions": 10.0},
+        }
+    )
+
+    model = soap_bpnn.Model(all_species=[1, 6, 7, 8])
+    model = torch.jit.script(model)  # jit the model for good measure
+
+    structures = read_structures(RESOURCES_PATH / "alchemical_reduced_10.xyz")[:5]
+    # Number of atoms per structure, shared by the gradient samples and values:
+    n_atoms = [len(structure.positions) for structure in structures]
+
+    gradient_samples = Labels(
+        names=["sample", "atom"],
+        values=torch.stack(
+            [
+                torch.concatenate(
+                    [torch.tensor([i] * n) for i, n in enumerate(n_atoms)]
+                ),
+                torch.concatenate([torch.arange(n) for n in n_atoms]),
+            ],
+            dim=1,
+        ),
+    )
+
+    gradient_components = [
+        Labels(
+            names=["coordinate"],
+            values=torch.tensor([[0], [1], [2]]),
+        )
+    ]
+
+    block = TensorBlock(
+        values=torch.tensor([[0.0] * len(structures)]).T,
+        samples=Labels.range("structure", len(structures)),
+        components=[],
+        properties=Labels.single(),
+    )
+
+    block.add_gradient(
+        "positions",
+        TensorBlock(
+            values=torch.ones(sum(n_atoms), 3, 1),
+            samples=gradient_samples,
+            components=gradient_components,
+            properties=Labels.single(),
+        ),
+    )
+
+    targets = {
+        "energy": TensorMap(
+            keys=Labels(
+                names=["lambda", "sigma"],
+                values=torch.tensor([[0, 1]]),
+            ),
+            blocks=[block],
+        ),
+    }
+
+    loss = compute_model_loss(
+        loss_fn,
+        model,
+        structures,
+        targets,
+    )
+
+    # Smoke check: a loss value was actually produced, as a tensor.
+    assert isinstance(loss, torch.Tensor)
diff --git a/tests/utils/test_loss.py b/tests/utils/test_loss.py
index 2db407f51..273db3b0c 100644
--- a/tests/utils/test_loss.py
+++ b/tests/utils/test_loss.py
@@ -1,3 +1,4 @@
+from pathlib import Path
 import pytest
 
 import torch
@@ -6,6 +7,9 @@ from metatensor.models.utils.loss import TensorMapLoss, TensorMapDictLoss
 
 
 
+RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources"
+
+
 @pytest.fixture
 def tensor_map_with_grad_1():
     block = TensorBlock(
@@ -172,3 +176,28 @@ def test_tmap_dict_loss(tensor_map_with_grad_1, tensor_map_with_grad_2, tensor_m
     assert torch.allclose(loss(output_dict, target_dict), expected_result)
 
 
+def test_tmap_dict_loss_subset(tensor_map_with_grad_1, tensor_map_with_grad_3):
+    """Test that the dict loss is computed correctly when only a subset
+    of the possible targets is present both in outputs and targets."""
+
+    loss = TensorMapDictLoss(
+        weights={
+            "output_1": {"values": 1.0, "gradient": 0.5},
+            "output_2": {"values": 1.0, "gradient": 0.5},
+        }
+    )
+
+    output_dict = {
+        "output_1": tensor_map_with_grad_1,
+    }
+
+    target_dict = {
+        "output_1": tensor_map_with_grad_3,
+    }
+
+    expected_result = (
+        1.0 * (tensor_map_with_grad_1.block().values - tensor_map_with_grad_3.block().values).pow(2).mean()
+        + 0.5 * (tensor_map_with_grad_1.block().gradient("gradient").values - tensor_map_with_grad_3.block().gradient("gradient").values).pow(2).mean()
+    )
+
+    assert torch.allclose(loss(output_dict, target_dict), expected_result)