Add length_unit + cleanup #35

Merged: 2 commits, Jan 24, 2024
7 changes: 6 additions & 1 deletion README.rst
@@ -3,6 +3,11 @@ metatensor-models

 |tests| |docs|
 
+.. warning::
+
+   **metatensor-models is still very early in the concept stage. You should not use it
+   for anything important.**
+
 This is a repository for models using metatensor, in one shape or another. The only
 requirement is for these models to be able to take metatensor objects as inputs and
 outputs. The models do not need to live entirely in this repository: in the most extreme
@@ -32,7 +37,7 @@ Documentation
 ------------
 
 For details, tutorials, and examples, please have a look at our
-[documentation](https://lab-cosmo.github.io/metatensor-models/latest/).
+`documentation <https://lab-cosmo.github.io/metatensor-models/latest/>`_.
 
 .. marker-installation
 
79 changes: 41 additions & 38 deletions src/metatensor/models/cli/train_model.py
@@ -96,50 +96,50 @@ def train_model(options: DictConfig) -> None:
     generator = torch.Generator()
 
     logger.info("Setting up training set")
-    conf_training_set = expand_dataset_config(options["training_set"])
-    structures_train = read_structures(
-        filename=conf_training_set["structures"]["read_from"],
-        fileformat=conf_training_set["structures"]["file_format"],
+    train_options = expand_dataset_config(options["training_set"])
+    train_structures = read_structures(
+        filename=train_options["structures"]["read_from"],
+        fileformat=train_options["structures"]["file_format"],
     )
-    targets_train = read_targets(conf_training_set["targets"])
-    train_dataset = Dataset(structures_train, targets_train)
+    train_targets = read_targets(train_options["targets"])
+    train_dataset = Dataset(train_structures, train_targets)
 
     logger.info("Setting up test set")
-    conf_test_set = options["test_set"]
-    if not isinstance(conf_test_set, float):
-        conf_test_set = expand_dataset_config(conf_test_set)
-        structures_test = read_structures(
-            filename=conf_training_set["structures"]["read_from"],
-            fileformat=conf_training_set["structures"]["file_format"],
+    test_options = options["test_set"]
+    if not isinstance(test_options, float):
+        test_options = expand_dataset_config(test_options)
+        test_structures = read_structures(
+            filename=train_options["structures"]["read_from"],
+            fileformat=train_options["structures"]["file_format"],
         )
-        targets_test = read_targets(conf_test_set["targets"])
-        test_dataset = Dataset(structures_test, targets_test)
-        fraction_test_set = 0.0
+        test_targets = read_targets(test_options["targets"])
+        test_dataset = Dataset(test_structures, test_targets)
+        test_fraction = 0.0
     else:
-        if conf_test_set < 0 or conf_test_set >= 1:
+        if test_options < 0 or test_options >= 1:
             raise ValueError("Test set split must be between 0 and 1.")
-        fraction_test_set = conf_test_set
+        test_fraction = test_options
 
     logger.info("Setting up validation set")
-    conf_validation_set = options["validation_set"]
-    if not isinstance(conf_validation_set, float):
-        conf_validation_set = expand_dataset_config(conf_validation_set)
-        structures_validation = read_structures(
-            filename=conf_training_set["structures"]["read_from"],
-            fileformat=conf_training_set["structures"]["file_format"],
+    validation_options = options["validation_set"]
+    if not isinstance(validation_options, float):
+        validation_options = expand_dataset_config(validation_options)
+        validation_structures = read_structures(
+            filename=train_options["structures"]["read_from"],
+            fileformat=train_options["structures"]["file_format"],
         )
-        targets_validation = read_targets(conf_validation_set["targets"])
-        validation_dataset = Dataset(structures_validation, targets_validation)
-        fraction_validation_set = 0.0
+        validation_targets = read_targets(validation_options["targets"])
+        validation_dataset = Dataset(validation_structures, validation_targets)
+        validation_fraction = 0.0
     else:
-        if conf_validation_set < 0 or conf_validation_set >= 1:
+        if validation_options < 0 or validation_options >= 1:
             raise ValueError("Validation set split must be between 0 and 1.")
-        fraction_validation_set = conf_validation_set
+        validation_fraction = validation_options
 
     # Split train dataset if requested
-    if fraction_test_set or fraction_validation_set:
-        fraction_train_set = 1 - fraction_test_set - fraction_validation_set
-        if fraction_train_set < 0:
+    if test_fraction or validation_fraction:
+        train_fraction = 1 - test_fraction - validation_fraction
+        if train_fraction < 0:
             raise ValueError("fraction of the train set is smaller then 0!")
 
         # ignore warning of possible empty dataset
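Note on the pattern in the hunk above: each of `test_set` and `validation_set` is either a float, read as a fraction to split off the training set, or a full dataset section of its own. A condensed sketch of that branching, using a hypothetical `resolve_split` helper that is not part of this PR:

from typing import Any, Optional, Tuple


def resolve_split(option: Any) -> Tuple[Optional[Any], float]:
    # A float means "carve this fraction out of the training set";
    # anything else is treated as a complete dataset configuration.
    if isinstance(option, float):
        if option < 0 or option >= 1:
            raise ValueError("split must be between 0 and 1")
        return None, option
    return option, 0.0


Both branches of the real code reduce to this shape, once for the test set and once for the validation set.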
@@ -148,17 +148,17 @@ def train_model(options: DictConfig) -> None:
         subsets = torch.utils.data.random_split(
             dataset=train_dataset,
             lengths=[
-                fraction_train_set,
-                fraction_test_set,
-                fraction_validation_set,
+                train_fraction,
+                test_fraction,
+                validation_fraction,
             ],
             generator=generator,
         )
 
         train_dataset = subsets[0]
-        if fraction_test_set and not fraction_validation_set:
+        if test_fraction and not validation_fraction:
             test_dataset = subsets[1]
-        elif not fraction_validation_set and fraction_validation_set:
+        elif not validation_fraction and validation_fraction:
             validation_dataset = subsets[1]
         else:
             test_dataset = subsets[1]
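The fractional `lengths` passed to `torch.utils.data.random_split` above rely on PyTorch resolving fractions that sum to 1 into integer subset sizes (supported since PyTorch 1.13). A minimal, self-contained sketch with dummy data in place of the real training structures:

import torch
from torch.utils.data import TensorDataset, random_split

# 100 dummy samples standing in for the real training dataset
dataset = TensorDataset(torch.arange(100).reshape(-1, 1))
generator = torch.Generator().manual_seed(0)  # fixed seed for a reproducible split

train, test, validation = random_split(
    dataset,
    lengths=[0.8, 0.1, 0.1],  # train/test/validation fractions, as in the code above
    generator=generator,
)
assert (len(train), len(test), len(validation)) == (80, 10, 10)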
@@ -184,12 +184,13 @@ def train_model(options: DictConfig) -> None:
     outputs = {
         key: ModelOutput(
             quantity=value["quantity"],
-            unit=(value["unit"] if value["unit"] is not None else ""),  # potential HACK
+            unit=(value["unit"] if value["unit"] is not None else ""),
         )
         for key, value in options["training_set"]["targets"].items()
     }
+    length_unit = train_options["structures"]["length_unit"]
     model_capabilities = ModelCapabilities(
-        length_unit="Angstrom",
+        length_unit=length_unit if length_unit is not None else "",
         species=all_species,
         outputs=outputs,
    )
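To trace the new option end to end: `expand_dataset_config` leaves `length_unit` as `None` when the user does not set it, and the hunk above maps that `None` to an empty string before it reaches `ModelCapabilities`. A sketch of the resulting call, assuming the `metatensor.torch.atomistic` classes exactly as the diff uses them; the species list and energy unit are placeholders:

from metatensor.torch.atomistic import ModelCapabilities, ModelOutput

length_unit = None  # what the expanded config yields when the option is unset
capabilities = ModelCapabilities(
    length_unit=length_unit if length_unit is not None else "",
    species=[1, 6, 8],  # placeholder atomic numbers
    outputs={"energy": ModelOutput(quantity="energy", unit="eV")},
)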
@@ -203,3 +204,5 @@ def train_model(options: DictConfig) -> None:
     )
 
     save_model(model, options["output_path"])
+
+    # TODO: add evaluation of the test set
1 change: 1 addition & 0 deletions src/metatensor/models/utils/omegaconf.py
@@ -29,6 +29,7 @@ def _resolve_single_str(config):
"read_from": "${..read_from}",
"file_format": "${file_format:}",
"key": None,
"length_unit": None,
}
)

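This one-line default is what makes `length_unit` appear in every expanded structures section. A simplified illustration of the underlying OmegaConf merge, with the interpolation defaults of the real code (`"${..read_from}"` and friends) replaced by plain values:

from omegaconf import OmegaConf

# simplified stand-in for the defaults used in _resolve_single_str
defaults = OmegaConf.create(
    {"read_from": None, "file_format": None, "key": None, "length_unit": None}
)
user_section = OmegaConf.create({"read_from": "foo.xyz"})

structures = OmegaConf.merge(defaults, user_section)
assert structures["read_from"] == "foo.xyz"
assert structures["length_unit"] is None  # present, with a well-defined default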
5 changes: 3 additions & 2 deletions tests/utils/test_omegaconf.py
@@ -14,7 +14,7 @@ def test_expand_dataset_config():
file_name = "foo.xyz"
file_format = ".xyz"

structure_section = {"read_from": file_name, "unit": "angstrom"}
structure_section = {"read_from": file_name, "length_unit": "angstrom"}

target_section = {
"quantity": "energy",
@@ -31,7 +31,7 @@ def test_expand_dataset_config():

assert conf_expanded["structures"]["read_from"] == file_name
assert conf_expanded["structures"]["file_format"] == file_format
assert conf_expanded["structures"]["unit"] == "angstrom"
assert conf_expanded["structures"]["length_unit"] == "angstrom"

targets_conf = conf_expanded["targets"]
assert len(targets_conf) == 2
@@ -89,6 +89,7 @@ def test_expand_dataset_config_min():

assert conf_expanded["structures"]["read_from"] == file_name
assert conf_expanded["structures"]["file_format"] == file_format
assert conf_expanded["structures"]["length_unit"] is None

targets_conf = conf_expanded["targets"]
assert targets_conf["energy"]["quantity"] == "energy"