Skip to content

Commit

Permalink
Warn about missing units when exporting
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Feb 2, 2024
1 parent 2431ef3 commit 87f8611
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/metatensor/models/cli/export_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import warnings

from metatensor.torch.atomistic import MetatensorAtomisticModel

Expand Down Expand Up @@ -36,7 +37,7 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:


def export_model(model: str, output: str) -> None:
"""Export a pretrained model to run MD simulations
"""Export a pre-trained model to run MD simulations
:param model: Path to a saved model
:param output: Path to save the exported model
Expand All @@ -45,6 +46,16 @@ def export_model(model: str, output: str) -> None:
# Load the model
loaded_model = load_model(model)

# Warn if the units are not provided for one or more of the model's possible outputs
for model_output_name, model_output in loaded_model.capabilities.outputs.items():
if model_output.unit == "":
warnings.warn(
f"No units were provided for the `{model_output_name}` output. "
"As a result, this model output will be passed to MD engines as is.",
UserWarning,
stacklevel=1,
)

# Export the model
wrapper = MetatensorAtomisticModel(loaded_model.eval(), loaded_model.capabilities)
wrapper.export(output)

0 comments on commit 87f8611

Please sign in to comment.