Clean up logging #45

Merged · 2 commits · Feb 5, 2024
68 changes: 26 additions & 42 deletions src/metatensor/models/soap_bpnn/train.py
@@ -1,7 +1,9 @@
import logging
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Union

import rascaline
import torch
from metatensor.torch.atomistic import ModelCapabilities

@@ -15,13 +17,20 @@
get_all_targets,
)
from ..utils.info import finalize_aggregated_info, update_aggregated_info
from ..utils.logging import MetricLogger
from ..utils.loss import TensorMapDictLoss
from ..utils.model_io import save_model
from .model import DEFAULT_HYPERS, Model


logger = logging.getLogger(__name__)

# disable rascaline logger
rascaline.set_logging_callback(lambda x, y: None)

# Filter out the second derivative warning from rascaline-torch
warnings.filterwarnings("ignore", category=UserWarning, message="second derivative")


def train(
train_datasets: List[Union[Dataset, torch.utils.data.Subset]],
@@ -95,20 +104,11 @@ def train(

# Extract all the possible outputs and their gradients from the training set:
outputs_dict = _get_outputs_dict(train_datasets)
energy_counter = 0
for output_name in outputs_dict.keys():
if output_name not in model_capabilities.outputs:
raise ValueError(
f"Output {output_name} is not in the model's capabilities."
)
if model_capabilities.outputs[output_name].quantity == "energy":
energy_counter += 1

# This will be useful later for printing forces/virials/stresses:
if energy_counter == 1:
only_one_energy = True
else:
only_one_energy = False

# Create a loss weight dict:
loss_weights_dict = {}
@@ -145,7 +145,7 @@
loss.backward()
optimizer.step()
aggregated_train_info = update_aggregated_info(aggregated_train_info, info)
aggregated_train_info = finalize_aggregated_info(aggregated_train_info)
finalized_train_info = finalize_aggregated_info(aggregated_train_info)

validation_loss = 0.0
for batch in validation_dataloader:
@@ -156,41 +156,25 @@
aggregated_validation_info = update_aggregated_info(
aggregated_validation_info, info
)
aggregated_validation_info = finalize_aggregated_info(
aggregated_validation_info
)
finalized_validation_info = finalize_aggregated_info(aggregated_validation_info)

# Now we log the information:
if epoch == 0:
metric_logger = MetricLogger(
model_capabilities,
train_loss,
validation_loss,
finalized_train_info,
finalized_validation_info,
)
Comment on lines +162 to +169

Contributor: Why don't you initialize the logger before entering the training loop?

Collaborator (author): Because it needs the values from the first epoch to format things properly.

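The pattern being discussed, as a standalone sketch (the names here are invented for illustration, not taken from the PR): the logger is constructed lazily inside the loop because its column widths are sized from the first epoch's values.

```python
# Minimal sketch of the lazy-initialization pattern discussed above;
# WidthAwareLogger is a made-up stand-in, not part of this PR.
class WidthAwareLogger:
    def __init__(self, first_value: float):
        # column width is fixed from the first observed value
        self.width = len(f"{first_value:.4f}")

    def log(self, epoch: int, value: float) -> None:
        print(f"Epoch {epoch:4}, loss: {value:{self.width}.4f}")


logger = None
for epoch, loss in enumerate([123.4567, 98.7654, 87.6543]):
    if logger is None:  # first epoch: magnitudes are now known
        logger = WidthAwareLogger(loss)
    logger.log(epoch, loss)
```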
if epoch % hypers_training["log_interval"] == 0:
logging_string = (
f"Epoch {epoch:4}, train loss: {train_loss:10.4f}, "
f"validation loss: {validation_loss:10.4f}"
metric_logger.log(
epoch,
train_loss,
validation_loss,
finalized_train_info,
finalized_validation_info,
)
for name, information_holder in zip(
["train", "valid"], [aggregated_train_info, aggregated_validation_info]
):
for key, value in information_holder.items():
if key.endswith("_positions_gradients"):
# check if this is a force
target_name = key[: -len("_positions_gradients")]
if model.capabilities.outputs[target_name].quantity == "energy":
# if this is a force, replace the ugly name with "force"
if only_one_energy:
key = "force"
else:
key = f"force[{target_name}]"
elif key.endswith("_displacement_gradients"):
# check if this is a virial/stress
target_name = key[: -len("_displacement_gradients")]
if model.capabilities.outputs[target_name].quantity == "energy":
# if this is a virial/stress,
# replace the ugly name with "virial/stress"
if only_one_energy:
key = "virial/stress"
else:
key = f"virial/stress[{target_name}]"
logging_string += f", {name} {key} RMSE: {value:10.4f}"
logger.info(logging_string)

if epoch % hypers_training["checkpoint_interval"] == 0:
save_model(
@@ -206,7 +190,7 @@
epochs_without_improvement += 1
if epochs_without_improvement >= 50:
logger.info(
f"Early stopping criterion reached after {epoch} "
"Early stopping criterion reached after 50 "
"epochs without improvement."
)
break
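An aside on the warnings filter added at the top of this file: the message argument to warnings.filterwarnings is a regular expression matched against the start of the warning text, so only warnings that begin with "second derivative" are silenced. A standalone check (the warning messages here are invented):

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warnings.filterwarnings("ignore", category=UserWarning, message="second derivative")
    warnings.warn("second derivative is not implemented", UserWarning)  # silenced
    warnings.warn("computing a second derivative", UserWarning)  # still emitted

print([str(w.message) for w in caught])  # only the second message survives
```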
4 changes: 3 additions & 1 deletion src/metatensor/models/utils/info.py
@@ -34,7 +34,9 @@ def update_aggregated_info(
return aggregated_info


def finalize_aggregated_info(aggregated_info):
def finalize_aggregated_info(
aggregated_info: Dict[str, Tuple[float, int]]
) -> Dict[str, float]:
"""
Finalize the aggregated information dictionary by calculating RMSEs.

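The diff is truncated here, but the new signature suggests the shape of the computation. A hedged sketch, assuming each value is a (sum of squared errors, sample count) pair — one natural reading of Dict[str, Tuple[float, int]] together with "calculating RMSEs":

```python
import math
from typing import Dict, Tuple


def finalize_aggregated_info_sketch(
    aggregated_info: Dict[str, Tuple[float, int]]
) -> Dict[str, float]:
    # RMSE = sqrt(SSE / N) for each tracked quantity (assumed semantics)
    return {key: math.sqrt(sse / n) for key, (sse, n) in aggregated_info.items()}


print(finalize_aggregated_info_sketch({"energy": (50.0, 2)}))  # {'energy': 5.0}
```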
156 changes: 156 additions & 0 deletions src/metatensor/models/utils/logging.py
@@ -0,0 +1,156 @@
import logging
from typing import Dict, Tuple

import numpy as np
from metatensor.torch.atomistic import ModelCapabilities


logger = logging.getLogger(__name__)


class MetricLogger:
Contributor: You also promised tests for this in your issue. Maybe at least one test that checks whether a log is produced should be added.

Collaborator (author): Sure, I will add tests once we find a way to log the way we want.

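One possible shape for such a test, using pytest's caplog fixture. This is a sketch, not the test that was eventually added; the SimpleNamespace stand-in works only because MetricLogger reads nothing from the capabilities object except .outputs:

```python
import logging
from types import SimpleNamespace

from metatensor.models.utils.logging import MetricLogger


def test_metric_logger_produces_a_log(caplog):
    caplog.set_level(logging.INFO)
    capabilities = SimpleNamespace(outputs={})  # stand-in for ModelCapabilities
    metric_logger = MetricLogger(capabilities, 1.0, 1.0, {}, {})
    metric_logger.log(0, 1.0, 1.0, {}, {})
    assert "train loss" in caplog.text
```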
"""This class provides a simple interface to log training metrics to a file."""

def __init__(
self,
model_capabilities: ModelCapabilities,
train_loss_0: float,
validation_loss_0: float,
train_info_0: Dict[str, float],
validation_info_0: Dict[str, float],
):
"""
Initialize the logger with metrics that are supposed to
decrease during training.

In this way, the logger can align the output to make it easier to read.

Args:
model_capabilities: The capabilities of the model.
train_loss_0: The initial training loss.
validation_loss_0: The initial validation loss.
train_info_0: The initial training metrics.
validation_info_0: The initial validation metrics.
Contributor: This is not how we write docstrings. Can you adjust this, please?

"""

# Since the quantities are supposed to decrease, we want to store the
# number of digits at the start of the training, so that we can align
# the output later:
self.digits = {}
self.digits["train_loss"] = _get_digits(train_loss_0)
self.digits["validation_loss"] = _get_digits(validation_loss_0)
for name, information_holder in zip(
["train", "valid"], [train_info_0, validation_info_0]
):
for key, value in information_holder.items():
self.digits[f"{name}_{key}"] = _get_digits(value)

# This will be useful later for printing forces/virials/stresses:
energy_counter = 0
for output in model_capabilities.outputs.values():
if output.quantity == "energy":
energy_counter += 1
if energy_counter == 1:
self.only_one_energy = True
else:
self.only_one_energy = False

# Save the model capabilities for later use:
self.model_capabilities = model_capabilities

def log(
self,
epoch: int,
train_loss: float,
validation_loss: float,
train_info: Dict[str, float],
validation_info: Dict[str, float],
):
"""
Log the training metrics.

The training metrics are automatically aligned to make them easier to read,
based on the order of magnitude of each metric at the start of the training.

Args:
epoch: The current epoch.
train_loss: The current training loss.
validation_loss: The current validation loss.
train_info: The current training metrics.
validation_info: The current validation metrics.
"""

# The epoch is printed with 4 digits, assuming that the training
# will not last more than 9999 epochs
logging_string = (
f"Epoch {epoch:4}, train loss: "
f"{train_loss:{self.digits['train_loss'][0]}.{self.digits['train_loss'][1]}f}, " # noqa: E501
f"validation loss: "
f"{validation_loss:{self.digits['validation_loss'][0]}.{self.digits['validation_loss'][1]}f}" # noqa: E501
)
for name, information_holder in zip(
["train", "valid"], [train_info, validation_info]
):
for key, value in information_holder.items():
new_key = key
if key.endswith("_positions_gradients"):
# check if this is a force
target_name = key[: -len("_positions_gradients")]
if (
self.model_capabilities.outputs[target_name].quantity
== "energy"
):
# if this is a force, replace the ugly name with "force"
if self.only_one_energy:
new_key = "force"
else:
new_key = f"force[{target_name}]"
elif key.endswith("_displacement_gradients"):
# check if this is a virial/stress
target_name = key[: -len("_displacement_gradients")]
if (
self.model_capabilities.outputs[target_name].quantity
== "energy"
):
# if this is a virial/stress,
# replace the ugly name with "virial/stress"
if self.only_one_energy:
new_key = "virial/stress"
else:
new_key = f"virial/stress[{target_name}]"
logging_string += (
f", {name} {new_key} RMSE: "
f"{value:{self.digits[f'{name}_{key}'][0]}.{self.digits[f'{name}_{key}'][1]}f}" # noqa: E501
)
logger.info(logging_string)


def _get_digits(value: float) -> Tuple[int, int]:
Contributor: Is this ChatGPT code? Looks very cumbersome. Python has a Decimal library which does this for you, i.e.

t = Decimal(1.532).normalize().as_tuple()

gives:

DecimalTuple(sign=0, digits=(1, 5, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 8, 4, 2, 1, 7, 0, 9, 4, 3), exponent=-26)

which I think contains all the information you need in a single line. You can even get the mantissa and the exponent if you like.

Collaborator (author): Hmm, I don't think I can use it to do what I want. Yes, what I do is quite cumbersome, but I think it's the best solution for now. The function essentially takes a number and decides which parts of it should be printed so that there are always at least 5 significant digits.

Comment: What's wrong with %12.5e or equivalent?

Collaborator (author): I have the strong feeling that many of the people we're dealing with won't be comfortable with exponential notation.

Member: Yes, I would go for format specifiers: f"something {some_float:12.5e}". See the documentation here: https://docs.python.org/3.11/library/string.html#formatspec

Member (@Luthaf, Feb 5, 2024):

> I have the strong feeling that many of the people we're dealing with won't be comfortable with exponential notation

I'm not sure about this, but we could also use g specifiers instead.

Contributor: I agree. Maybe g is a better option than f.

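For concreteness, how the three specifiers discussed above render the same value (the value is illustrative):

```python
x = 0.00012345
print(f"{x:12.5f}")  # '     0.00012' -> fixed point loses significant digits
print(f"{x:12.5e}")  # ' 1.23450e-04' -> exponential notation
print(f"{x:12.5g}")  # '  0.00012345' -> 'g' keeps 5 significant digits
```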
"""
Finds the number of digits to print before and after the decimal point,
based on the order of magnitude of the value.

5 "significant" digits are guaranteed to be printed.

Args:
value: The value for which the number of digits is calculated.
"""

# Get order of magnitude of the value:
order = int(np.floor(np.log10(value)))

# Get the number of digits before the decimal point:
if order < 0:
digits_before = 1
else:
digits_before = order + 1

# Get the number of digits after the decimal point:
if order < 0:
digits_after = 4 - order
else:
digits_after = max(1, 4 - order)

total_characters = digits_before + digits_after + 1 # +1 for the point

return total_characters, digits_after
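A worked usage of the function above, with outputs computed by hand from the code (assuming the module is importable):

```python
width, after = _get_digits(123.456)   # order 2  -> (6, 2)
print(f"{123.456:{width}.{after}f}")  # '123.46'

width, after = _get_digits(0.0123)    # order -2 -> (8, 6)
print(f"{0.0123:{width}.{after}f}")   # '0.012300'
```

Both render exactly 5 significant digits, which is the invariant the author describes in the thread above.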
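Finally, to make the gradient-key renaming in MetricLogger.log concrete, a simplified restatement of that logic (the target names are invented for illustration):

```python
def pretty_gradient_key(key: str, target_is_energy: bool, only_one_energy: bool) -> str:
    """Simplified mirror of the renaming performed in MetricLogger.log."""
    for suffix, pretty in [
        ("_positions_gradients", "force"),
        ("_displacement_gradients", "virial/stress"),
    ]:
        if key.endswith(suffix) and target_is_energy:
            target = key[: -len(suffix)]
            return pretty if only_one_energy else f"{pretty}[{target}]"
    return key


print(pretty_gradient_key("energy_positions_gradients", True, True))  # 'force'
print(pretty_gradient_key("U0_displacement_gradients", True, False))  # 'virial/stress[U0]'
```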