Clean up logging #45
Changes from 1 commit
@@ -0,0 +1,156 @@

```python
import logging
from typing import Dict, Tuple

import numpy as np
from metatensor.torch.atomistic import ModelCapabilities


logger = logging.getLogger(__name__)


class MetricLogger:
```
**Reviewer:** You also promised testing for this in your issue. Maybe at least one test that checks whether a log is produced should be added.

**Author:** Sure, I will add tests once we find a way to log the way we want.
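A minimal sketch of such a test, using pytest's built-in `caplog` fixture; `make_capabilities()` is a hypothetical helper standing in for however this repository builds a small `ModelCapabilities` object:

```python
import logging


def test_metric_logger_emits_log(caplog):
    # `make_capabilities()` is hypothetical; replace with the repository's
    # actual way of constructing a minimal ModelCapabilities object.
    capabilities = make_capabilities()
    metric_logger = MetricLogger(
        model_capabilities=capabilities,
        train_loss_0=1.0,
        validation_loss_0=1.0,
        train_info_0={"energy": 0.5},
        validation_info_0={"energy": 0.5},
    )
    with caplog.at_level(logging.INFO):
        metric_logger.log(
            epoch=1,
            train_loss=0.9,
            validation_loss=0.95,
            train_info={"energy": 0.4},
            validation_info={"energy": 0.45},
        )
    # The point of the test: a log record was actually produced.
    assert "Epoch" in caplog.text
```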
```python
    """This class provides a simple interface to log training metrics to a file."""

    def __init__(
        self,
        model_capabilities: ModelCapabilities,
        train_loss_0: float,
        validation_loss_0: float,
        train_info_0: Dict[str, float],
        validation_info_0: Dict[str, float],
    ):
        """
        Initialize the logger with metrics that are supposed to
        decrease during training.

        In this way, the logger can align the output to make it easier to read.

        Args:
            model_capabilities: The capabilities of the model.
            train_loss_0: The initial training loss.
            validation_loss_0: The initial validation loss.
            train_info_0: The initial training metrics.
            validation_info_0: The initial validation metrics.
        """
```

**Reviewer:** This is not how we write docstrings. Can you adjust this, please?
```python
        # Since the quantities are supposed to decrease, we want to store the
        # number of digits at the start of the training, so that we can align
        # the output later:
        self.digits = {}
        self.digits["train_loss"] = _get_digits(train_loss_0)
        self.digits["validation_loss"] = _get_digits(validation_loss_0)
        for name, information_holder in zip(
            ["train", "valid"], [train_info_0, validation_info_0]
        ):
            for key, value in information_holder.items():
                self.digits[f"{name}_{key}"] = _get_digits(value)

        # This will be useful later for printing forces/virials/stresses:
        energy_counter = 0
        for output in model_capabilities.outputs.values():
            if output.quantity == "energy":
                energy_counter += 1
        if energy_counter == 1:
            self.only_one_energy = True
        else:
            self.only_one_energy = False

        # Save the model capabilities for later use:
        self.model_capabilities = model_capabilities

    def log(
        self,
        epoch: int,
        train_loss: float,
        validation_loss: float,
        train_info: Dict[str, float],
        validation_info: Dict[str, float],
    ):
        """
        Log the training metrics.

        The training metrics are automatically aligned to make them easier to read,
        based on the order of magnitude of each metric at the start of the training.

        Args:
            epoch: The current epoch.
            train_loss: The current training loss.
            validation_loss: The current validation loss.
            train_info: The current training metrics.
            validation_info: The current validation metrics.
        """

        # The epoch is printed with 4 digits, assuming that the training
        # will not last more than 9999 epochs
        logging_string = (
            f"Epoch {epoch:4}, train loss: "
            f"{train_loss:{self.digits['train_loss'][0]}.{self.digits['train_loss'][1]}f}, "  # noqa: E501
            f"validation loss: "
            f"{validation_loss:{self.digits['validation_loss'][0]}.{self.digits['validation_loss'][1]}f}"  # noqa: E501
        )
        for name, information_holder in zip(
            ["train", "valid"], [train_info, validation_info]
        ):
            for key, value in information_holder.items():
                new_key = key
                if key.endswith("_positions_gradients"):
                    # check if this is a force
                    target_name = key[: -len("_positions_gradients")]
                    if (
                        self.model_capabilities.outputs[target_name].quantity
                        == "energy"
                    ):
                        # if this is a force, replace the ugly name with "force"
                        if self.only_one_energy:
                            new_key = "force"
                        else:
                            new_key = f"force[{target_name}]"
                elif key.endswith("_displacement_gradients"):
                    # check if this is a virial/stress
                    target_name = key[: -len("_displacement_gradients")]
                    if (
                        self.model_capabilities.outputs[target_name].quantity
                        == "energy"
                    ):
                        # if this is a virial/stress,
                        # replace the ugly name with "virial/stress"
                        if self.only_one_energy:
                            new_key = "virial/stress"
                        else:
                            new_key = f"virial/stress[{target_name}]"
                logging_string += (
                    f", {name} {new_key} RMSE: "
                    f"{value:{self.digits[f'{name}_{key}'][0]}.{self.digits[f'{name}_{key}'][1]}f}"  # noqa: E501
                )
        logger.info(logging_string)


def _get_digits(value: float) -> Tuple[int, int]:
```
**Reviewer:** Is this ChatGPT code? It looks very cumbersome. Python has a Decimal library that does this for you, i.e. `t = Decimal(1.532).normalize().as_tuple()` gives `DecimalTuple(sign=0, digits=(1, 5, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 8, 4, 2, 1, 7, 0, 9, 4, 3), exponent=-26)`, which I think contains all the information you need in a single line. You can even get the mantissa and the exponent if you like...

**Author:** Hmmm, I don't think I can use it to do what I want... Yes, what I do is quite cumbersome, but I think it's the best solution for now. The function essentially takes a number and decides which parts of it should be printed so that there are always at least 5 significant digits.

**Reviewer:** What's wrong with `%12.5e` or equivalent?

**Author:** I have the strong feeling that many of the people we're dealing with won't be comfortable with exponential notation.

**Reviewer:** Yes, I would go for format specifiers:

**Reviewer:** I'm not sure about this, but we can also use

**Author:** I agree. Maybe the
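For reference, a minimal sketch of the two alternatives raised in this thread (exponential format specifiers and the `Decimal` library); this is illustrative only, not what the PR implements:

```python
from decimal import Decimal

value = 1.532

# Alternative 1: a plain exponential format specifier. Always shows five
# significant digits, at the cost of scientific notation in the output.
print(f"{value:12.5e}")  # ' 1.53200e+00'

# Alternative 2: inspect the number via the Decimal library.
# normalize().as_tuple() exposes sign, digits, and exponent, which carry
# the same order-of-magnitude information as the np.log10 call below.
t = Decimal(value).normalize().as_tuple()
print(t.sign, t.digits[:4], t.exponent)
```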
```python
    """
    Finds the number of digits to print before and after the decimal point,
    based on the order of magnitude of the value.

    5 "significant" digits are guaranteed to be printed.

    Args:
        value: The value for which the number of digits is calculated.
    """

    # Get order of magnitude of the value:
    order = int(np.floor(np.log10(value)))

    # Get the number of digits before the decimal point:
    if order < 0:
        digits_before = 1
    else:
        digits_before = order + 1

    # Get the number of digits after the decimal point:
    if order < 0:
        digits_after = 4 - order
    else:
        digits_after = max(1, 4 - order)

    total_characters = digits_before + digits_after + 1  # +1 for the point

    return total_characters, digits_after
```
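To make the behavior concrete, here is what `_get_digits` returns (total field width, digits after the decimal point) for a few magnitudes:

```python
# (total width, decimals) such that 5 significant digits are printed:
print(_get_digits(123.456))   # (6, 2)  -> f"{123.456:6.2f}"  == '123.46'
print(_get_digits(1.5))       # (6, 4)  -> f"{1.5:6.4f}"      == '1.5000'
print(_get_digits(0.001234))  # (9, 7)  -> f"{0.001234:9.7f}" == '0.0012340'
```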
**Reviewer:** Why don't you initialize the logger before entering the training loop?

**Author:** Because it needs the values from the first epoch to format things properly.
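A minimal sketch of the resulting usage pattern, where `compute_losses_and_info` and `capabilities` are hypothetical stand-ins for the repository's actual training code; the logger is created lazily, once the first epoch's values are available:

```python
# The logger is created inside the loop, on the first epoch, because it
# needs the initial metric values to size its output fields.
metric_logger = None
for epoch in range(num_epochs):
    # Hypothetical helper returning this epoch's losses and metric dicts:
    train_loss, validation_loss, train_info, validation_info = (
        compute_losses_and_info(epoch)
    )
    if metric_logger is None:
        metric_logger = MetricLogger(
            model_capabilities=capabilities,
            train_loss_0=train_loss,
            validation_loss_0=validation_loss,
            train_info_0=train_info,
            validation_info_0=validation_info,
        )
    metric_logger.log(
        epoch, train_loss, validation_loss, train_info, validation_info
    )
```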