Skip to content

Commit

Permalink
Wrap forces and stresses
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Jan 11, 2024
1 parent 8591d7d commit 8df7e41
Show file tree
Hide file tree
Showing 6 changed files with 291 additions and 87 deletions.
2 changes: 1 addition & 1 deletion src/metatensor/models/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(
values=torch.tensor(all_species).reshape(-1, 1),
)

def forward(self, systems: List[rascaline.torch.System]) -> Dict[str, TensorMap]:
def forward(self, systems: List[metatensor.torch.atomistic.System]) -> Dict[str, TensorMap]:
soap_features = self.soap_calculator(systems)

device = soap_features.block(0).values.device
Expand Down
181 changes: 181 additions & 0 deletions src/metatensor/models/utils/compute_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import torch
from metatensor.torch.atomistic import System
from metatensor.torch import Labels, TensorBlock, TensorMap

from typing import Dict, List
from .loss import TensorMapDictLoss
from .output_gradient import compute_gradient


def compute_model_loss(
loss: TensorMapDictLoss,
model: torch.nn.Module,
systems: List[System],
targets: Dict[str, TensorMap],
):
"""
Compute the loss of a model on a set of targets.
This function assumes that the model returns a dictionary of
TensorMaps, with the same keys as the targets.
"""
# Assert that all targets are within the model's capabilities:
if not set(targets.keys()).issubset(model.capabilities.outputs.keys()):
raise ValueError("Not all targets are within the model's capabilities.")

# Find if there are any energy targets that require gradients:
energy_targets = []
energy_targets_that_require_position_gradients = []
energy_targets_that_require_displacement_gradients = []
for target_name in targets.keys():
# Check if the target is an energy:
if model.capabilities.outputs[target_name].quantity == "energy":
energy_targets.append(target_name)
# Check if the energy requires gradients:
if targets[target_name].has_gradients("positions"):
energy_targets_that_require_position_gradients.append(target_name)
if targets[target_name].has_gradients("displacements"):
energy_targets_that_require_displacement_gradients.append(target_name)

if len(energy_targets_that_require_displacement_gradients) > 0:
# TODO: raise an error if the systems do not have a cell
# if not all([system.has_cell for system in systems]):
# raise ValueError("One or more systems does not have a cell.")
displacements = [torch.eye(3, requires_grad=True, dtype=system.dtype, device=system.device) for system in systems]
# Create new "displaced" systems:
systems = [
System(
positions=system.positions @ displacement,
cell=system.cell @ displacement,
species=system.species,
)
for system, displacement in zip(systems, displacements)
]
else:
if len(energy_targets_that_require_position_gradients) > 0:
# Set positions to require gradients:
for system in systems:
system.positions.requires_grad_(True)

# Based on the keys of the targets, get the outputs of the model:
model_outputs = model(systems, targets.keys())

for energy_target in energy_targets:
# If the energy target requires gradients, compute them:
target_requires_pos_gradients = energy_target in energy_targets_that_require_position_gradients
target_requires_disp_gradients = energy_target in energy_targets_that_require_displacement_gradients
if target_requires_pos_gradients and target_requires_disp_gradients:
gradients = compute_gradient(
model_outputs[energy_target].block().values,
[system.positions for system in systems] + displacements,
is_training=True,
)
old_energy_tensor_map = model_outputs[energy_target]
new_block = old_energy_tensor_map.block().copy()
new_block.add_gradient("positions", _position_gradients_to_block(gradients[:len(systems)]))
new_block.add_gradient("displacements", _displacement_gradients_to_block(gradients[len(systems):]))
new_energy_tensor_map = TensorMap(
keys=old_energy_tensor_map.keys,
blocks=[new_block],
)
model_outputs[energy_target] = new_energy_tensor_map
elif target_requires_pos_gradients:
gradients = compute_gradient(
model_outputs[energy_target].block().values,
[system.positions for system in systems],
is_training=True,
)
old_energy_tensor_map = model_outputs[energy_target]
new_block = old_energy_tensor_map.block().copy()
new_block.add_gradient("positions", _position_gradients_to_block(gradients))
new_energy_tensor_map = TensorMap(
keys=old_energy_tensor_map.keys,
blocks=[new_block],
)
model_outputs[energy_target] = new_energy_tensor_map
elif target_requires_disp_gradients:
gradients = compute_gradient(
model_outputs[energy_target].block().values,
displacements,
is_training=True,
)
old_energy_tensor_map = model_outputs[energy_target]
new_block = old_energy_tensor_map.block().copy()
new_block.add_gradient("displacements", _displacement_gradients_to_block(gradients))
new_energy_tensor_map = TensorMap(
keys=old_energy_tensor_map.keys,
blocks=[new_block],
)
model_outputs[energy_target] = new_energy_tensor_map
else:
pass

# Compute the loss:
return loss(model_outputs, targets)


def _position_gradients_to_block(gradients_list):
"""Convert a list of position gradients to a `TensorBlock`
which can act as a gradient block to an energy block."""

# `gradients` consists of a list of tensors where the second dimension is 3
gradients = torch.stack(gradients_list, dim=0).unsqueeze(-1)
# unsqueeze for the property dimension

samples = Labels(
names=["sample", "atom"],
values=torch.stack([
torch.concatenate([torch.tensor([i]*len(structure)) for i, structure in enumerate(gradients_list)]),
torch.concatenate([torch.arange(len(structure)) for structure in gradients_list]),
], dim=1),
)

components = [
Labels(
names=["coordinate"],
values=torch.tensor([[0], [1], [2]]),
)
]

return TensorBlock(
values=gradients,
samples=samples,
components=components,
properties=Labels.single(),
)


def _displacement_gradients_to_block(gradients_list):
"""Convert a list of displacement gradients to a `TensorBlock`
which can act as a gradient block to an energy block."""

"""Convert a list of position gradients to a `TensorBlock`
which can act as a gradient block to an energy block."""

# `gradients` consists of a list of tensors where the second dimension is 3
gradients = torch.stack(gradients_list, dim=0).unsqueeze(-1)
# unsqueeze for the property dimension

samples = Labels(
names=["sample"],
values=torch.arange(len(gradients_list)).unsqueeze(-1)
)

# TODO: check if this makes physical sense
components = [
Labels(
names=["cell vector"],
values=torch.tensor([[0], [1], [2]]),
),
Labels(
names=["coordinate"],
values=torch.tensor([[0], [1], [2]]),
)
]

return TensorBlock(
values=gradients,
samples=samples,
components=components,
properties=Labels.single(),
)
2 changes: 1 addition & 1 deletion src/metatensor/models/utils/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class Dataset(torch.utils.data.Dataset):
def __init__(
self, structures: List[rascaline.torch.System], targets: Dict[str, TensorMap]
self, structures: List[metatensor.torch.atomistic.System], targets: Dict[str, TensorMap]
):
"""
Creates a dataset from a list of `rascaline.torch.System` objects
Expand Down
86 changes: 1 addition & 85 deletions src/metatensor/models/utils/loss.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import metatensor.torch
from metatensor.torch import TensorMap

from rascaline.torch.system import System

import torch
from typing import Dict, List, Optional

from .output_gradient import compute_gradient
from typing import Dict, Optional

# This file defines losses for metatensor models.

Expand Down Expand Up @@ -85,83 +81,3 @@ def __call__(self, tensor_map_dict_1: Dict[str, TensorMap], tensor_map_dict_2: D
loss += self.losses[key](tensor_map_dict_1[key], tensor_map_dict_2[key])

return loss


def compute_model_loss(
loss: TensorMapDictLoss,
model: torch.nn.Module,
systems: List[System],
targets: Dict[str, TensorMap],
):
"""
Compute the loss of a model on a set of targets.
This function assumes that the model returns a dictionary of
TensorMaps, with the same keys as the targets.
"""
# Assert that all targets are within the model's capabilities:
if not set(targets.keys()).issubset(model.capabilities.outputs.keys()):
raise ValueError("Not all targets are within the model's capabilities.")

# Find if there are any energy targets that require gradients:
energy_targets = []
energy_targets_that_require_position_gradients = []
energy_targets_that_require_displacement_gradients = []
for target_name in targets.keys():
# Check if the target is an energy:
if model.capabilities.outputs[target_name].quantity == "energy":
energy_targets.append(target_name)
# Check if the energy requires gradients:
if targets[target_name].has_gradients("positions"):
energy_targets_that_require_position_gradients.append(target_name)
if targets[target_name].has_gradients("displacements"):
energy_targets_that_require_displacement_gradients.append(target_name)

if len(energy_targets_that_require_displacement_gradients) > 0:
# TODO: raise an error if the systems do not have a cell
# if not all([system.has_cell for system in systems]):
# raise ValueError("One or more systems does not have a cell.")
displacements = [torch.eye(3, requires_grad=True, dtype=system.dtype, device=system.device) for system in systems]
# Create new "displaced" systems:
systems = [
System(
positions=system.positions @ displacement,
cell=system.cell @ displacement,
species=system.species,
)
for system, displacement in zip(systems, displacements)
]
else:
if len(energy_targets_that_require_position_gradients) > 0:
# Set positions to require gradients:
for system in systems:
system.positions.requires_grad_(True)

# Based on the keys of the targets, get the outputs of the model:
raw_model_outputs = model(systems, targets.keys())

for energy_target in energy_targets:
# If the energy target requires gradients, compute them:
target_requires_pos_gradients = energy_target in energy_targets_that_require_position_gradients
target_requires_disp_gradients = energy_target in energy_targets_that_require_displacement_gradients
if target_requires_pos_gradients and target_requires_disp_gradients:
gradients = compute_gradient(
raw_model_outputs[energy_target].block().values,
[system.positions for system in systems] + displacements,
is_training=True,
)
new_energy_tensor_map
elif target_requires_pos_gradients:
gradients = compute_gradient(
raw_model_outputs[energy_target].block().values,
[system.positions for system in systems],
is_training=True,
)
elif target_requires_disp_gradients:
gradients = compute_gradient(
raw_model_outputs[energy_target].block().values,
displacements,
is_training=True,
)
else:
pass
78 changes: 78 additions & 0 deletions tests/utils/test_compute_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from pathlib import Path

import torch
from metatensor.torch import TensorMap, TensorBlock, Labels
from metatensor.models.utils.loss import TensorMapDictLoss
from metatensor.models.utils.compute_loss import compute_model_loss

from metatensor.models import soap_bpnn
from metatensor.models.utils.data import read_structures


RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources"


def test_compute_model_loss():
"""Test that the model loss is computed."""

loss_fn = TensorMapDictLoss(
weights={
"energy": {"values": 1.0, "positions": 10.0},
}
)

model = soap_bpnn.Model(all_species=[1, 6, 7, 8])
model = torch.jit.script(model) # jit the model for good measure

structures = read_structures(RESOURCES_PATH / "alchemical_reduced_10.xyz")[:5]

gradient_samples = Labels(
names=["sample", "atom"],
values=torch.stack([
torch.concatenate([torch.tensor([i]*len(structure)) for i, structure in enumerate(structures)]),
torch.concatenate([torch.arange(len(structure)) for structure in structures]),
], dim=1),
)

gradient_components = [
Labels(
names=["coordinate"],
values=torch.tensor([[0], [1], [2]]),
)
]

block = TensorBlock(
values=torch.tensor([[0.0]*len(structures)]).T,
samples=Labels.range("structure", len(structures)),
components=[],
properties=Labels.single(),
)

block.add_gradient(
"positions",
TensorBlock(
values=torch.tensor([[[1.0], [1.0], [1.0]] for structure in structures for _ in range(len(structure.positions))]),
samples=gradient_samples,
components=gradient_components,
properties=Labels.single(),
)
)

targets = {
"energy": TensorMap(
keys=Labels(
names=["lambda", "sigma"],
values=torch.tensor([[0, 1]]),
),
blocks=[block]
),
}

loss = compute_model_loss(
loss_fn,
model,
structures,
targets,
)


Loading

0 comments on commit 8df7e41

Please sign in to comment.