Provide better logging #37

Merged Jan 26, 2024 · 7 commits
1 change: 1 addition & 0 deletions src/metatensor/models/__main__.py
@@ -1,4 +1,5 @@
"""The main entry point for the metatensor-models interface."""

import argparse
import sys
from pathlib import Path
69 changes: 56 additions & 13 deletions src/metatensor/models/soap_bpnn/train.py
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
from typing import Dict, List, Union
from typing import Dict, List, Tuple, Union

import torch
from metatensor.torch.atomistic import ModelCapabilities
@@ -14,6 +14,7 @@
combine_dataloaders,
get_all_targets,
)
from ..utils.info import finalize_aggregated_info, update_aggregated_info
from ..utils.loss import TensorMapDictLoss
from ..utils.model_io import save_model
from .model import DEFAULT_HYPERS, Model
@@ -91,11 +92,20 @@ def train(

# Extract all the possible outputs and their gradients from the training set:
outputs_dict = _get_outputs_dict(train_datasets)
energy_counter = 0
for output_name in outputs_dict.keys():
if output_name not in model_capabilities.outputs:
raise ValueError(
f"Output {output_name} is not in the model's capabilities."
)
if model_capabilities.outputs[output_name].quantity == "energy":
energy_counter += 1

# This will be useful later for printing forces/virials/stresses:
if energy_counter == 1:
only_one_energy = True
else:
only_one_energy = False

# Create a loss weight dict:
loss_weights_dict = {}
@@ -107,11 +117,6 @@ def train(
# Create a loss function:
loss_fn = TensorMapDictLoss(loss_weights_dict)

# Create a loss function:
loss_fn = TensorMapDictLoss(
{target_name: {"values": 1.0}},
)

# Create an optimizer:
optimizer = torch.optim.Adam(
model.parameters(), lr=hypers_training["learning_rate"]
@@ -123,27 +128,65 @@

# Train the model:
for epoch in range(hypers_training["num_epochs"]):
# aggregated information holders:
aggregated_train_info: Dict[str, Tuple[float, int]] = {}
aggregated_validation_info: Dict[str, Tuple[float, int]] = {}

train_loss = 0.0
for batch in train_dataloader:
optimizer.zero_grad()
structures, targets = batch
loss = compute_model_loss(loss_fn, model, structures, targets)
loss, info = compute_model_loss(loss_fn, model, structures, targets)
train_loss += loss.item()
loss.backward()
optimizer.step()
aggregated_train_info = update_aggregated_info(aggregated_train_info, info)
aggregated_train_info = finalize_aggregated_info(aggregated_train_info)

validation_loss = 0.0
for batch in validation_dataloader:
structures, targets = batch
# TODO: specify that the model is not training here to save some autograd
loss = compute_model_loss(loss_fn, model, structures, targets)
loss, info = compute_model_loss(loss_fn, model, structures, targets)
validation_loss += loss.item()
aggregated_validation_info = update_aggregated_info(
aggregated_validation_info, info
)
aggregated_validation_info = finalize_aggregated_info(
aggregated_validation_info
)

# Now we log the information:
if epoch % hypers_training["log_interval"] == 0:
logger.info(
f"Epoch {epoch}, train loss: {train_loss:.4f}, "
f"validation loss: {validation_loss:.4f}"
logging_string = (
f"Epoch {epoch:4}, train loss: {train_loss:10.4f}, "
f"validation loss: {validation_loss:10.4f}"
)
for name, information_holder in zip(
["train", "valid"], [aggregated_train_info, aggregated_validation_info]
):
for key, value in information_holder.items():
if key.endswith("_positions_gradients"):
# check if this is a force
target_name = key[: -len("_positions_gradients")]
if model.capabilities.outputs[target_name].quantity == "energy":
# if this is a force, replace the ugly name with "force"
if only_one_energy:
key = "force"
else:
key = f"force[{target_name}]"
elif key.endswith("_displacement_gradients"):
# check if this is a virial/stress
target_name = key[: -len("_displacement_gradients")]
if model.capabilities.outputs[target_name].quantity == "energy":
# if this is a virial/stress,
# replace the ugly name with "virial/stress"
if only_one_energy:
key = "virial/stress"
else:
key = f"virial/stress[{target_name}]"
logging_string += f", {name} {key} RMSE: {value:10.4f}"
logger.info(logging_string)

if epoch % hypers_training["checkpoint_interval"] == 0:
save_model(
@@ -167,12 +210,12 @@ def train(
return model


def _get_outputs_dict(datasets: List[Dataset]):
def _get_outputs_dict(datasets: List[Union[Dataset, torch.utils.data.Subset]]):
"""
This is a helper function that extracts all the possible outputs and their gradients
from a list of datasets.

:param datasets: A list of datasets.
:param datasets: A list of Datasets or Subsets.

:returns: A dictionary mapping output names to a list of "values" (always)
and possible gradients.
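The key-renaming branch in the logging block above boils down to a single mapping; here is a condensed sketch of it (not part of the diff; `capabilities` and `only_one_energy` stand in for the attributes used in train.py):

def prettify_gradient_key(key, capabilities, only_one_energy):
    # Map internal gradient keys to readable names for logging:
    #   "<target>_positions_gradients"    -> "force" for energy targets
    #   "<target>_displacement_gradients" -> "virial/stress" for energy targets
    for suffix, pretty in [
        ("_positions_gradients", "force"),
        ("_displacement_gradients", "virial/stress"),
    ]:
        if key.endswith(suffix):
            target_name = key[: -len(suffix)]
            if capabilities.outputs[target_name].quantity == "energy":
                # Only qualify the name if there are several energy targets:
                return pretty if only_one_energy else f"{pretty}[{target_name}]"
    return key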
15 changes: 7 additions & 8 deletions src/metatensor/models/utils/compute_loss.py
@@ -39,7 +39,7 @@ def compute_model_loss(
# Check if the energy requires gradients:
if targets[target_name].block().has_gradient("positions"):
energy_targets_that_require_position_gradients.append(target_name)
if targets[target_name].block().has_gradient("displacements"):
if targets[target_name].block().has_gradient("displacement"):
energy_targets_that_require_displacement_gradients.append(target_name)

if len(energy_targets_that_require_displacement_gradients) > 0:
@@ -93,7 +93,7 @@ def compute_model_loss(
"positions", _position_gradients_to_block(gradients[: len(systems)])
)
new_block.add_gradient(
"displacements",
"displacement",
_displacement_gradients_to_block(gradients[len(systems) :]),
)
new_energy_tensor_map = TensorMap(
@@ -124,7 +124,7 @@ def compute_model_loss(
old_energy_tensor_map = model_outputs[energy_target]
new_block = old_energy_tensor_map.block().copy()
new_block.add_gradient(
"displacements", _displacement_gradients_to_block(gradients)
"displacement", _displacement_gradients_to_block(gradients)
)
new_energy_tensor_map = TensorMap(
keys=old_energy_tensor_map.keys,
Expand All @@ -134,7 +134,7 @@ def compute_model_loss(
else:
pass

# Compute the loss:
# Compute and return the loss and associated info:
return loss(model_outputs, targets)


@@ -166,7 +166,7 @@ def _position_gradients_to_block(gradients_list):

components = [
Labels(
names=["coordinate"],
names=["direction"],
values=torch.tensor([[0], [1], [2]]),
)
]
@@ -183,8 +183,7 @@ def _displacement_gradients_to_block(gradients_list):
"""Convert a list of displacement gradients to a `TensorBlock`
which can act as a gradient block to an energy block."""

# `gradients` consists of a list of tensors where the second dimension is 3
gradients = torch.concatenate(gradients_list, dim=0).unsqueeze(-1)
gradients = torch.stack(gradients_list, dim=0).unsqueeze(-1)
# unsqueeze for the property dimension

samples = Labels(
@@ -193,7 +192,7 @@

components = [
Labels(
names=["cell vector"],
names=["cell_vector"],
values=torch.tensor([[0], [1], [2]]),
),
Labels(
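One detail worth spelling out in `_displacement_gradients_to_block`: each entry of `gradients_list` is the gradient for a single structure, so `torch.stack` (which adds a new sample axis) rather than `torch.concatenate` (which merges along an existing axis) produces the layout the block metadata expects. A shape check, under the assumption of 3x3 per-structure gradients:

import torch

gradients_list = [torch.zeros(3, 3), torch.zeros(3, 3)]  # two structures

stacked = torch.stack(gradients_list, dim=0).unsqueeze(-1)
print(stacked.shape)  # torch.Size([2, 3, 3, 1]): samples, two components, property

concatenated = torch.concatenate(gradients_list, dim=0).unsqueeze(-1)
print(concatenated.shape)  # torch.Size([6, 3, 1]): samples and components conflated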
4 changes: 2 additions & 2 deletions src/metatensor/models/utils/data/readers/targets/ase.py
@@ -150,8 +150,8 @@ def _read_virial_stress_ase(
samples = Labels(["sample"], torch.tensor([[s] for s in range(n_structures)]))

components = [
Labels(["direction_1"], torch.arange(3).reshape(-1, 1)),
Labels(["direction_2"], torch.arange(3).reshape(-1, 1)),
Labels(["cell_vector"], torch.arange(3).reshape(-1, 1)),
Labels(["coordinate"], torch.arange(3).reshape(-1, 1)),
]

block = TensorBlock(
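For two structures, the resulting virial/stress block would be laid out roughly as below; this is a sketch with placeholder values and an assumed single-property axis, following the label names in the diff:

import torch
from metatensor.torch import Labels, TensorBlock

n_structures = 2
block = TensorBlock(
    values=torch.zeros(n_structures, 3, 3, 1),  # one 3x3 tensor per structure
    samples=Labels(["sample"], torch.tensor([[s] for s in range(n_structures)])),
    components=[
        Labels(["cell_vector"], torch.arange(3).reshape(-1, 1)),
        Labels(["coordinate"], torch.arange(3).reshape(-1, 1)),
    ],
    properties=Labels(["energy"], torch.tensor([[0]])),  # assumed property label
)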
54 changes: 54 additions & 0 deletions src/metatensor/models/utils/info.py
@@ -0,0 +1,54 @@
from typing import Dict, Tuple


def update_aggregated_info(
aggregated_info: Dict[str, Tuple[float, int]],
new_info: Dict[str, Tuple[float, int]],
):
"""
Update the aggregated information dictionary with new information.

For now, the new_info must be a dictionary of tuples, where the first
element is a sum of squared errors and the second element is the
number of samples.

If a key is present in both dictionaries, the values are added.
If a key is present in ``new_info`` but not ``aggregated_info``,
it is simply copied.

:param aggregated_info: The aggregated information dictionary.
:param new_info: The new information dictionary.

:returns: The updated aggregated information dictionary.
"""

for key, value in new_info.items():
if key in aggregated_info:
aggregated_info[key] = (
aggregated_info[key][0] + value[0],
aggregated_info[key][1] + value[1],
)
else:
aggregated_info[key] = value

return aggregated_info


def finalize_aggregated_info(aggregated_info):
"""
Finalize the aggregated information dictionary by calculating RMSEs.

For now, the aggregated_info must be a dictionary of tuples, where the first
element is a sum of squared errors and the second element is the
number of samples.

:param aggregated_info: The aggregated information dictionary.

:returns: The finalized aggregated information dictionary.
"""

finalized_info = {}
for key, value in aggregated_info.items():
finalized_info[key] = (value[0] / value[1]) ** 0.5

return finalized_info
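A minimal usage sketch of these two helpers (the numbers are invented): each info entry is a (sum of squared errors, number of samples) tuple, and finalization turns the running sums into RMSEs.

batch_1 = {"energy": (4.0, 4), "energy_positions_gradients": (9.0, 12)}
batch_2 = {"energy": (12.0, 4), "energy_positions_gradients": (18.0, 12)}

aggregated = {}
aggregated = update_aggregated_info(aggregated, batch_1)
aggregated = update_aggregated_info(aggregated, batch_2)
# aggregated == {"energy": (16.0, 8), "energy_positions_gradients": (27.0, 24)}

finalized = finalize_aggregated_info(aggregated)
# finalized["energy"] == (16.0 / 8) ** 0.5, i.e. the RMSE over all samples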
58 changes: 41 additions & 17 deletions src/metatensor/models/utils/loss.py
@@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Dict, Optional, Tuple

import torch
from metatensor.torch import TensorMap
@@ -21,7 +21,9 @@ class TensorMapLoss:
:param weight: The weight to apply to the loss on the block values.
:param gradient_weights: The weights to apply to the loss on the gradients.

:returns: The loss as a scalar `torch.Tensor`.
:returns: The loss as a scalar `torch.Tensor`, as well as an information
dictionary with the sum of squared errors and number of samples for values
and each of the gradients.
"""

def __init__(
@@ -36,7 +38,7 @@ def __init__(

def __call__(
self, tensor_map_1: TensorMap, tensor_map_2: TensorMap
) -> torch.Tensor:
) -> Tuple[torch.Tensor, Dict[str, Tuple[float, int]]]:
# Check that the two have the same metadata, except for the samples,
# which can be different due to batching, but must have the same size:
if tensor_map_1.keys != tensor_map_2.keys:
@@ -87,22 +89,32 @@ def __call__(
"TensorMapLoss does not yet support multiple symmetry keys."
)

# Compute the loss:
# Compute the loss and info:
loss = torch.zeros(
(),
dtype=tensor_map_1.block().values.dtype,
device=tensor_map_1.block().values.device,
)
loss += self.weight * self.loss(
tensor_map_1.block().values, tensor_map_2.block().values
info = {}

values_1 = tensor_map_1.block().values
values_2 = tensor_map_2.block().values
loss += self.weight * self.loss(values_1, values_2)
info["values"] = (
torch.sum((values_1 - values_2) ** 2).item(),
values_1.numel(),
)

for gradient_name, gradient_weight in self.gradient_weights.items():
loss += gradient_weight * self.loss(
tensor_map_1.block().gradient(gradient_name).values,
tensor_map_2.block().gradient(gradient_name).values,
values_1 = tensor_map_1.block().gradient(gradient_name).values
values_2 = tensor_map_2.block().gradient(gradient_name).values
loss += gradient_weight * self.loss(values_1, values_2)
info[gradient_name] = (
torch.sum((values_1 - values_2) ** 2).item(),
values_1.numel(),
)

return loss
return loss, info


class TensorMapDictLoss:
@@ -121,7 +133,9 @@ class TensorMapDictLoss:
the gradients.
:param reduction: The reduction to apply to the loss. See `torch.nn.MSELoss`.

:returns: The loss as a scalar `torch.Tensor`.
:returns: The loss as a scalar `torch.Tensor`, as well as an information
dictionary with the sum of squared errors and number of samples for values
and each of the gradients.
"""

def __init__(
@@ -142,16 +156,26 @@ def __call__(
self,
tensor_map_dict_1: Dict[str, TensorMap],
tensor_map_dict_2: Dict[str, TensorMap],
) -> torch.Tensor:
) -> Tuple[torch.Tensor, Dict[str, Tuple[float, int]]]:
# Assert that the two have the same keys:
assert set(tensor_map_dict_1.keys()) == set(tensor_map_dict_2.keys())

# Initialize the loss:
first_values = next(iter(tensor_map_dict_1.values())).block(0).values
loss = torch.zeros((), dtype=first_values.dtype, device=first_values.device)
info = {}

# Compute the loss:
for key in tensor_map_dict_1.keys():
loss += self.losses[key](tensor_map_dict_1[key], tensor_map_dict_2[key])

return loss
# Compute the loss and associated info:
for target in tensor_map_dict_1.keys():
target_loss, target_info = self.losses[target](
tensor_map_dict_1[target], tensor_map_dict_2[target]
)
loss += target_loss
info[target] = target_info["values"]
for gradient_name in target_info.keys():
if gradient_name != "values":
info[f"{target}_{gradient_name}_gradients"] = target_info[
gradient_name
]

return loss, info
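Putting the pieces together, the new return signature is consumed roughly as follows; a sketch assuming `predictions` and `targets` are `Dict[str, TensorMap]`, with illustrative target and gradient-weight names:

loss_weights_dict = {
    "energy": {
        "values": 1.0,        # weight on the energies themselves
        "positions": 0.1,     # assumed gradient weight for forces
        "displacement": 0.1,  # assumed gradient weight for virials/stresses
    }
}
loss_fn = TensorMapDictLoss(loss_weights_dict)

loss, info = loss_fn(predictions, targets)
loss.backward()

# info maps "energy", "energy_positions_gradients" and
# "energy_displacement_gradients" to (sum of squared errors, n_samples),
# which is exactly what update_aggregated_info/finalize_aggregated_info expect:
for key, (sse, n_samples) in info.items():
    print(key, (sse / n_samples) ** 0.5)  # per-quantity RMSE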