Provide better logging (#37)
frostedoyster authored Jan 26, 2024
1 parent 1e487a9 commit e07e2a8
Showing 14 changed files with 2,636 additions and 53 deletions.
1 change: 1 addition & 0 deletions src/metatensor/models/__main__.py
@@ -1,4 +1,5 @@
"""The main entry point for the metatensor-models interface."""

import argparse
import sys
from pathlib import Path
69 changes: 56 additions & 13 deletions src/metatensor/models/soap_bpnn/train.py
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
-from typing import Dict, List, Union
+from typing import Dict, List, Tuple, Union

import torch
from metatensor.torch.atomistic import ModelCapabilities
@@ -14,6 +14,7 @@
combine_dataloaders,
get_all_targets,
)
+from ..utils.info import finalize_aggregated_info, update_aggregated_info
from ..utils.loss import TensorMapDictLoss
from ..utils.model_io import save_model
from .model import DEFAULT_HYPERS, Model
@@ -91,11 +92,20 @@ def train(

# Extract all the possible outputs and their gradients from the training set:
outputs_dict = _get_outputs_dict(train_datasets)
+energy_counter = 0
for output_name in outputs_dict.keys():
if output_name not in model_capabilities.outputs:
raise ValueError(
f"Output {output_name} is not in the model's capabilities."
)
+if model_capabilities.outputs[output_name].quantity == "energy":
+energy_counter += 1

+# This will be useful later for printing forces/virials/stresses:
+if energy_counter == 1:
+only_one_energy = True
+else:
+only_one_energy = False

# Create a loss weight dict:
loss_weights_dict = {}
@@ -107,11 +117,6 @@
# Create a loss function:
loss_fn = TensorMapDictLoss(loss_weights_dict)

-# Create a loss function:
-loss_fn = TensorMapDictLoss(
-{target_name: {"values": 1.0}},
-)

# Create an optimizer:
optimizer = torch.optim.Adam(
model.parameters(), lr=hypers_training["learning_rate"]
@@ -123,27 +128,65 @@

# Train the model:
for epoch in range(hypers_training["num_epochs"]):
+# aggregated information holders:
+aggregated_train_info: Dict[str, Tuple[float, int]] = {}
+aggregated_validation_info: Dict[str, Tuple[float, int]] = {}

train_loss = 0.0
for batch in train_dataloader:
optimizer.zero_grad()
structures, targets = batch
-loss = compute_model_loss(loss_fn, model, structures, targets)
+loss, info = compute_model_loss(loss_fn, model, structures, targets)
train_loss += loss.item()
loss.backward()
optimizer.step()
+aggregated_train_info = update_aggregated_info(aggregated_train_info, info)
+aggregated_train_info = finalize_aggregated_info(aggregated_train_info)

validation_loss = 0.0
for batch in validation_dataloader:
structures, targets = batch
# TODO: specify that the model is not training here to save some autograd
-loss = compute_model_loss(loss_fn, model, structures, targets)
+loss, info = compute_model_loss(loss_fn, model, structures, targets)
validation_loss += loss.item()
+aggregated_validation_info = update_aggregated_info(
+aggregated_validation_info, info
+)
+aggregated_validation_info = finalize_aggregated_info(
+aggregated_validation_info
+)

+# Now we log the information:
if epoch % hypers_training["log_interval"] == 0:
-logger.info(
-f"Epoch {epoch}, train loss: {train_loss:.4f}, "
-f"validation loss: {validation_loss:.4f}"
+logging_string = (
+f"Epoch {epoch:4}, train loss: {train_loss:10.4f}, "
+f"validation loss: {validation_loss:10.4f}"
)
+for name, information_holder in zip(
+["train", "valid"], [aggregated_train_info, aggregated_validation_info]
+):
+for key, value in information_holder.items():
+if key.endswith("_positions_gradients"):
+# check if this is a force
+target_name = key[: -len("_positions_gradients")]
+if model.capabilities.outputs[target_name].quantity == "energy":
+# if this is a force, replace the ugly name with "force"
+if only_one_energy:
+key = "force"
+else:
+key = f"force[{target_name}]"
+elif key.endswith("_displacement_gradients"):
+# check if this is a virial/stress
+target_name = key[: -len("_displacement_gradients")]
+if model.capabilities.outputs[target_name].quantity == "energy":
+# if this is a virial/stress,
+# replace the ugly name with "virial/stress"
+if only_one_energy:
+key = "virial/stress"
+else:
+key = f"virial/stress[{target_name}]"
+logging_string += f", {name} {key} RMSE: {value:10.4f}"
+logger.info(logging_string)

if epoch % hypers_training["checkpoint_interval"] == 0:
save_model(
@@ -167,12 +210,12 @@ def train(
return model


-def _get_outputs_dict(datasets: List[Dataset]):
+def _get_outputs_dict(datasets: List[Union[Dataset, torch.utils.data.Subset]]):
"""
This is a helper function that extracts all the possible outputs and their gradients
from a list of datasets.
-:param datasets: A list of datasets.
+:param datasets: A list of Datasets or Subsets.
:returns: A dictionary mapping output names to a list of "values" (always)
and possible gradients.
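The renaming logic in the logging loop above can be read as a small standalone helper. The following is an illustrative sketch, not code from this commit; `_pretty_key` and `output_quantities` are invented names standing in for the inline logic and for `model.capabilities.outputs[...].quantity`:

from typing import Dict


def _pretty_key(key: str, output_quantities: Dict[str, str], only_one_energy: bool) -> str:
    """Turn raw info-dictionary keys into readable names for the epoch log line."""
    for suffix, pretty in [
        ("_positions_gradients", "force"),  # gradients w.r.t. positions
        ("_displacement_gradients", "virial/stress"),  # gradients w.r.t. displacements
    ]:
        if key.endswith(suffix):
            target_name = key[: -len(suffix)]
            # only gradients of energy-like outputs get the pretty names
            if output_quantities.get(target_name) == "energy":
                return pretty if only_one_energy else f"{pretty}[{target_name}]"
    return key


# With a single energy output, the internal key collapses to "force":
assert _pretty_key("energy_positions_gradients", {"energy": "energy"}, True) == "force"
# With several energy-like outputs, the target name is kept for disambiguation:
assert (
    _pretty_key("U0_positions_gradients", {"U0": "energy", "U1": "energy"}, False)
    == "force[U0]"
)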
15 changes: 7 additions & 8 deletions src/metatensor/models/utils/compute_loss.py
@@ -39,7 +39,7 @@ def compute_model_loss(
# Check if the energy requires gradients:
if targets[target_name].block().has_gradient("positions"):
energy_targets_that_require_position_gradients.append(target_name)
if targets[target_name].block().has_gradient("displacements"):
if targets[target_name].block().has_gradient("displacement"):
energy_targets_that_require_displacement_gradients.append(target_name)

if len(energy_targets_that_require_displacement_gradients) > 0:
@@ -93,7 +93,7 @@ def compute_model_loss(
"positions", _position_gradients_to_block(gradients[: len(systems)])
)
new_block.add_gradient(
"displacements",
"displacement",
_displacement_gradients_to_block(gradients[len(systems) :]),
)
new_energy_tensor_map = TensorMap(
@@ -124,7 +124,7 @@ def compute_model_loss(
old_energy_tensor_map = model_outputs[energy_target]
new_block = old_energy_tensor_map.block().copy()
new_block.add_gradient(
"displacements", _displacement_gradients_to_block(gradients)
"displacement", _displacement_gradients_to_block(gradients)
)
new_energy_tensor_map = TensorMap(
keys=old_energy_tensor_map.keys,
@@ -134,7 +134,7 @@ def compute_model_loss(
else:
pass

-# Compute the loss:
+# Compute and return the loss and associated info:
return loss(model_outputs, targets)


@@ -166,7 +166,7 @@ def _position_gradients_to_block(gradients_list):

components = [
Labels(
names=["coordinate"],
names=["direction"],
values=torch.tensor([[0], [1], [2]]),
)
]
@@ -183,8 +183,7 @@ def _displacement_gradients_to_block(gradients_list):
"""Convert a list of displacement gradients to a `TensorBlock`
which can act as a gradient block to an energy block."""

-# `gradients` consists of a list of tensors where the second dimension is 3
-gradients = torch.concatenate(gradients_list, dim=0).unsqueeze(-1)
+gradients = torch.stack(gradients_list, dim=0).unsqueeze(-1)
# unsqueeze for the property dimension

samples = Labels(
Expand All @@ -193,7 +192,7 @@ def _displacement_gradients_to_block(gradients_list):

components = [
Labels(
names=["cell vector"],
names=["cell_vector"],
values=torch.tensor([[0], [1], [2]]),
),
Labels(
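The position gradients consumed above are produced by backpropagating the predicted energy through the model. As a minimal illustration of that pattern in plain torch (the quadratic `toy_energy` below is invented for the example; it is not the model used here):

import torch


def toy_energy(positions: torch.Tensor) -> torch.Tensor:
    # stand-in for a model's predicted energy of one structure
    return (positions**2).sum()


positions = torch.randn(5, 3, requires_grad=True)  # 5 atoms, 3 coordinates
energy = toy_energy(positions)

# dE/dr for every atom; the forces are the negative of this gradient
(gradient,) = torch.autograd.grad(energy, positions)
forces = -gradient
print(forces.shape)  # torch.Size([5, 3])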
4 changes: 2 additions & 2 deletions src/metatensor/models/utils/data/readers/targets/ase.py
@@ -150,8 +150,8 @@ def _read_virial_stress_ase(
samples = Labels(["sample"], torch.tensor([[s] for s in range(n_structures)]))

components = [
Labels(["direction_1"], torch.arange(3).reshape(-1, 1)),
Labels(["direction_2"], torch.arange(3).reshape(-1, 1)),
Labels(["cell_vector"], torch.arange(3).reshape(-1, 1)),
Labels(["coordinate"], torch.arange(3).reshape(-1, 1)),
]

block = TensorBlock(
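The new component names give the two 3x3 axes of a virial/stress target distinct, self-describing labels; "cell_vector" matches the renamed component of the displacement gradients above. A sketch isolating just these labels, assuming only that metatensor.torch is available:

import torch
from metatensor.torch import Labels

# the two 3x3 component axes of a virial/stress block after this change
components = [
    Labels(["cell_vector"], torch.arange(3).reshape(-1, 1)),  # cell axis
    Labels(["coordinate"], torch.arange(3).reshape(-1, 1)),  # Cartesian axis
]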
54 changes: 54 additions & 0 deletions src/metatensor/models/utils/info.py
@@ -0,0 +1,54 @@
from typing import Dict, Tuple


def update_aggregated_info(
aggregated_info: Dict[str, Tuple[float, int]],
new_info: Dict[str, Tuple[float, int]],
):
"""
Update the aggregated information dictionary with new information.
For now, the new_info must be a dictionary of tuples, where the first
element is a sum of squared errors and the second element is the
number of samples.
If a key is present in both dictionaries, the values are added.
If a key is present in ``new_info`` but not ``aggregated_info``,
it is simply copied.
:param aggregated_info: The aggregated information dictionary.
:param new_info: The new information dictionary.
:returns: The updated aggregated information dictionary.
"""

for key, value in new_info.items():
if key in aggregated_info:
aggregated_info[key] = (
aggregated_info[key][0] + value[0],
aggregated_info[key][1] + value[1],
)
else:
aggregated_info[key] = value

return aggregated_info


def finalize_aggregated_info(aggregated_info):
"""
Finalize the aggregated information dictionary by calculating RMSEs.
For now, the aggregated_info must be a dictionary of tuples, where the first
element is a sum of squared errors and the second element is the
number of samples.
:param aggregated_info: The aggregated information dictionary.
:returns: The finalized aggregated information dictionary.
"""

finalized_info = {}
for key, value in aggregated_info.items():
finalized_info[key] = (value[0] / value[1]) ** 0.5

return finalized_info
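A quick usage sketch of the two helpers above, with the functions in scope and invented numbers; each tuple is (sum of squared errors, number of samples):

batch_1 = {"energy": (4.0, 4)}
batch_2 = {"energy": (12.0, 4), "energy_positions_gradients": (2.0, 8)}

aggregated: dict = {}
aggregated = update_aggregated_info(aggregated, batch_1)
aggregated = update_aggregated_info(aggregated, batch_2)

finalized = finalize_aggregated_info(aggregated)
print(finalized["energy"])  # sqrt((4.0 + 12.0) / (4 + 4)) = 1.4142...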
58 changes: 41 additions & 17 deletions src/metatensor/models/utils/loss.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Dict, Optional, Tuple

import torch
from metatensor.torch import TensorMap
@@ -21,7 +21,9 @@ class TensorMapLoss:
:param weight: The weight to apply to the loss on the block values.
:param gradient_weights: The weights to apply to the loss on the gradients.
-:returns: The loss as a scalar `torch.Tensor`.
+:returns: The loss as a scalar `torch.Tensor`, as well as an information
+dictionary with the sum of squared errors and number of samples for values
+and each of the gradients.
"""

def __init__(
@@ -36,7 +38,7 @@ def __init__(

def __call__(
self, tensor_map_1: TensorMap, tensor_map_2: TensorMap
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, Dict[str, Tuple[float, int]]]:
# Check that the two have the same metadata, except for the samples,
# which can be different due to batching, but must have the same size:
if tensor_map_1.keys != tensor_map_2.keys:
@@ -87,22 +89,32 @@ def __call__(
"TensorMapLoss does not yet support multiple symmetry keys."
)

-# Compute the loss:
+# Compute the loss and info:
loss = torch.zeros(
(),
dtype=tensor_map_1.block().values.dtype,
device=tensor_map_1.block().values.device,
)
-loss += self.weight * self.loss(
-tensor_map_1.block().values, tensor_map_2.block().values
+info = {}
+
+values_1 = tensor_map_1.block().values
+values_2 = tensor_map_2.block().values
+loss += self.weight * self.loss(values_1, values_2)
+info["values"] = (
+torch.sum((values_1 - values_2) ** 2).item(),
+values_1.numel(),
)

for gradient_name, gradient_weight in self.gradient_weights.items():
-loss += gradient_weight * self.loss(
-tensor_map_1.block().gradient(gradient_name).values,
-tensor_map_2.block().gradient(gradient_name).values,
+values_1 = tensor_map_1.block().gradient(gradient_name).values
+values_2 = tensor_map_2.block().gradient(gradient_name).values
+loss += gradient_weight * self.loss(values_1, values_2)
+info[gradient_name] = (
+torch.sum((values_1 - values_2) ** 2).item(),
+values_1.numel(),
)

-return loss
+return loss, info


class TensorMapDictLoss:
@@ -121,7 +133,9 @@ class TensorMapDictLoss:
the gradients.
:param reduction: The reduction to apply to the loss. See `torch.nn.MSELoss`.
-:returns: The loss as a scalar `torch.Tensor`.
+:returns: The loss as a scalar `torch.Tensor`, as well as an information
+dictionary with the sum of squared errors and number of samples for values
+and each of the gradients.
"""

def __init__(
Expand All @@ -142,16 +156,26 @@ def __call__(
self,
tensor_map_dict_1: Dict[str, TensorMap],
tensor_map_dict_2: Dict[str, TensorMap],
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, Dict[str, Tuple[float, int]]]:
# Assert that the two have the same keys:
assert set(tensor_map_dict_1.keys()) == set(tensor_map_dict_2.keys())

# Initialize the loss:
first_values = next(iter(tensor_map_dict_1.values())).block(0).values
loss = torch.zeros((), dtype=first_values.dtype, device=first_values.device)
+info = {}

-# Compute the loss:
-for key in tensor_map_dict_1.keys():
-loss += self.losses[key](tensor_map_dict_1[key], tensor_map_dict_2[key])

-return loss
+# Compute the loss and associated info:
+for target in tensor_map_dict_1.keys():
+target_loss, target_info = self.losses[target](
+tensor_map_dict_1[target], tensor_map_dict_2[target]
+)
+loss += target_loss
+info[target] = target_info["values"]
+for gradient_name in target_info.keys():
+if gradient_name != "values":
+info[f"{target}_{gradient_name}_gradients"] = target_info[
+gradient_name
+]

+return loss, info
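The flattening at the end of `TensorMapDictLoss.__call__` is what produces the `{target}_{gradient_name}_gradients` keys that the logging code in train.py later prettifies. A self-contained sketch of just that convention, with invented numbers:

target = "energy"
# per-target info as returned by TensorMapLoss.__call__:
target_info = {"values": (3.0, 6), "positions": (1.5, 18)}

flattened = {target: target_info["values"]}
for gradient_name, sse_and_count in target_info.items():
    if gradient_name != "values":
        flattened[f"{target}_{gradient_name}_gradients"] = sse_and_count

print(flattened)
# {'energy': (3.0, 6), 'energy_positions_gradients': (1.5, 18)}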