Clean up logging #45
@@ -0,0 +1,153 @@
import logging
from typing import Dict, Tuple

import numpy as np
from metatensor.torch.atomistic import ModelCapabilities


logger = logging.getLogger(__name__)


class MetricLogger:
Comment: You also promised tests for this in your issue. Maybe add at least one test that checks whether a log is produced.
Reply: Sure, I will add tests once we find a way to log the way we want.
    """This class provides a simple interface to log training metrics to a file."""

    def __init__(
        self,
        model_capabilities: ModelCapabilities,
        train_loss_0: float,
        validation_loss_0: float,
        train_info_0: Dict[str, float],
        validation_info_0: Dict[str, float],
    ):
        """
        Initialize the logger with metrics that are supposed to
        decrease during training.

        In this way, the logger can align the output to make it easier to read.

        :param model_capabilities: The capabilities of the model.
        :param train_loss_0: The initial training loss.
        :param validation_loss_0: The initial validation loss.
        :param train_info_0: The initial training metrics.
        :param validation_info_0: The initial validation metrics.
        """

        # Since the quantities are supposed to decrease, we want to store the
        # number of digits at the start of the training, so that we can align
        # the output later:
        self.digits = {}
        self.digits["train_loss"] = _get_digits(train_loss_0)
        self.digits["validation_loss"] = _get_digits(validation_loss_0)
        for name, information_holder in zip(
            ["train", "valid"], [train_info_0, validation_info_0]
        ):
            for key, value in information_holder.items():
                self.digits[f"{name}_{key}"] = _get_digits(value)

        # This will be useful later for printing forces/virials/stresses:
        energy_counter = 0
        for output in model_capabilities.outputs.values():
            if output.quantity == "energy":
                energy_counter += 1
        if energy_counter == 1:
            self.only_one_energy = True
        else:
            self.only_one_energy = False

        # Save the model capabilities for later use:
        self.model_capabilities = model_capabilities

    def log(
        self,
        epoch: int,
        train_loss: float,
        validation_loss: float,
        train_info: Dict[str, float],
        validation_info: Dict[str, float],
    ):
        """
        Log the training metrics.

        The training metrics are automatically aligned to make them easier to read,
        based on the order of magnitude of each metric at the start of the training.

        :param epoch: The current epoch.
        :param train_loss: The current training loss.
        :param validation_loss: The current validation loss.
        :param train_info: The current training metrics.
        :param validation_info: The current validation metrics.
        """

        # The epoch is printed with 4 digits, assuming that the training
        # will not last more than 9999 epochs
        logging_string = (
            f"Epoch {epoch:4}, train loss: "
            f"{train_loss:{self.digits['train_loss'][0]}.{self.digits['train_loss'][1]}f}, "  # noqa: E501
            f"validation loss: "
            f"{validation_loss:{self.digits['validation_loss'][0]}.{self.digits['validation_loss'][1]}f}"  # noqa: E501
        )
        for name, information_holder in zip(
            ["train", "valid"], [train_info, validation_info]
        ):
            for key, value in information_holder.items():
                new_key = key
                if key.endswith("_positions_gradients"):
                    # check if this is a force
                    target_name = key[: -len("_positions_gradients")]
                    if (
                        self.model_capabilities.outputs[target_name].quantity
                        == "energy"
                    ):
                        # if this is a force, replace the ugly name with "force"
                        if self.only_one_energy:
                            new_key = "force"
                        else:
                            new_key = f"force[{target_name}]"
                elif key.endswith("_displacement_gradients"):
                    # check if this is a virial/stress
                    target_name = key[: -len("_displacement_gradients")]
                    if (
                        self.model_capabilities.outputs[target_name].quantity
                        == "energy"
                    ):
                        # if this is a virial/stress,
                        # replace the ugly name with "virial/stress"
                        if self.only_one_energy:
                            new_key = "virial/stress"
                        else:
                            new_key = f"virial/stress[{target_name}]"
                logging_string += (
                    f", {name} {new_key} RMSE: "
                    f"{value:{self.digits[f'{name}_{key}'][0]}.{self.digits[f'{name}_{key}'][1]}f}"  # noqa: E501
                )
        logger.info(logging_string)


def _get_digits(value: float) -> Tuple[int, int]:
Comment: Is this ChatGPT code? It looks very cumbersome. Python has a Decimal library which does this for you, i.e. t = Decimal(1.532).normalize().as_tuple() gives: DecimalTuple(sign=0, digits=(1, 5, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 8, 4, 2, 1, 7, 0, 9, 4, 3), exponent=-26), which I think contains all the information you need in a single line. You can even get the mantissa and the exponent if you like...
Reply: Hmm, I don't think I can use it to do what I want... Yes, what I do is quite cumbersome, but I think it's the best solution for now. The function essentially takes a number and decides which parts of it should be printed so that there are always at least 5 significant digits.
Reply: What's wrong with %12.5e or equivalent?
Reply: I have the strong feeling that many of the people we're dealing with won't be comfortable with exponential notation.
Reply: Yes, I would go for format specifiers:
Reply: I'm not sure about this, but we can also use
Reply: I agree. Maybe the
    """
    Finds the number of digits to print before and after the decimal point,
    based on the order of magnitude of the value.

    5 "significant" digits are guaranteed to be printed.

    :param value: The value for which the number of digits is calculated.
    """

    # Get order of magnitude of the value:
    order = int(np.floor(np.log10(value)))

    # Get the number of digits before the decimal point:
    if order < 0:
        digits_before = 1
    else:
        digits_before = order + 1

    # Get the number of digits after the decimal point:
    if order < 0:
        digits_after = 4 - order
    else:
        digits_after = max(1, 4 - order)

    total_characters = digits_before + digits_after + 1  # +1 for the point

    return total_characters, digits_after
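To make the discussion around `_get_digits` concrete, here is a minimal sketch of how the returned `(total_characters, digits_after)` pair is meant to be used as a fixed-point format spec. It is not the code from this PR: `get_digits_sketch` and `format_aligned` are hypothetical names, `math` stands in for `numpy` so the block is self-contained, and it assumes strictly positive values (`log10` is undefined for zero and negative inputs). The last two lines show the alternatives raised in the review thread: exponential notation, and `Decimal.adjusted()` as one possible way to read off the order of magnitude.

```python
import math
from decimal import Decimal
from typing import Tuple


def get_digits_sketch(value: float) -> Tuple[int, int]:
    """Same rule as _get_digits, restated with math (assumes value > 0)."""
    order = math.floor(math.log10(value))
    digits_before = 1 if order < 0 else order + 1
    digits_after = 4 - order if order < 0 else max(1, 4 - order)
    return digits_before + digits_after + 1, digits_after  # +1 for the point


def format_aligned(value: float, digits: Tuple[int, int]) -> str:
    """Format value in fixed-point notation from a (width, precision) pair."""
    width, precision = digits
    return f"{value:{width}.{precision}f}"


# The width is fixed from the first value and reused for later, smaller ones,
# so the columns stay aligned as the metric decreases.
digits = get_digits_sketch(123.456)
first = format_aligned(123.456, digits)
later = format_aligned(12.3456, digits)

# Alternatives mentioned in the review thread (illustrative only):
exponential = f"{0.00123:12.5e}"  # exponential notation, fixed width
order_via_decimal = Decimal(str(0.00123)).adjusted()  # order of magnitude
```

Building the `Decimal` from `str(value)` rather than from the float itself avoids the long binary expansion seen in the `DecimalTuple` quoted in the thread.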
Comment: Why don't you initialize the logger before entering the training loop?
Reply: Because it needs the values from the first epoch to format things properly.
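One possible way to reconcile the two points above, sketched under assumptions: the logger could be created before the training loop and fix its column widths lazily, on the first call to `log`. The same sketch also shows the kind of test requested earlier in the review, one that only checks that a log record is produced. `LazyLossLogger`, `_digits`, `check_log_is_produced`, and the logger name `metric_logger_sketch` are all hypothetical and not part of this PR; only the two losses are handled, to keep the example short.

```python
import logging
import math
import unittest
from typing import Dict, List, Optional, Tuple

logger = logging.getLogger("metric_logger_sketch")  # hypothetical logger name


def _digits(value: float) -> Tuple[int, int]:
    """Same width/precision rule as in the diff (assumes value > 0)."""
    order = math.floor(math.log10(value))
    before = 1 if order < 0 else order + 1
    after = 4 - order if order < 0 else max(1, 4 - order)
    return before + after + 1, after


class LazyLossLogger:
    """Hypothetical stand-in that fixes column widths on the first log() call."""

    def __init__(self) -> None:
        self.digits: Optional[Dict[str, Tuple[int, int]]] = None

    def log(self, epoch: int, train_loss: float, validation_loss: float) -> None:
        if self.digits is None:
            # First epoch: remember the widths so later lines stay aligned.
            self.digits = {
                "train": _digits(train_loss),
                "valid": _digits(validation_loss),
            }
        train_width, train_precision = self.digits["train"]
        valid_width, valid_precision = self.digits["valid"]
        logger.info(
            f"Epoch {epoch:4}, "
            f"train loss: {train_loss:{train_width}.{train_precision}f}, "
            f"validation loss: {validation_loss:{valid_width}.{valid_precision}f}"
        )


def check_log_is_produced() -> List[str]:
    """Capture the records emitted by two log() calls using unittest.assertLogs."""
    test_case = unittest.TestCase()
    with test_case.assertLogs("metric_logger_sketch", level="INFO") as captured:
        lazy_logger = LazyLossLogger()  # can be created before the loop
        lazy_logger.log(0, 12.5, 3.25)
        lazy_logger.log(1, 1.25, 0.5)
    return captured.output
```

`assertLogs` fails if no record at the given level is emitted, so the check covers "a log is produced" without tying the test to the exact layout of the line.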