Skip to content

Commit

Permalink
Add length_units and rename variables
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jan 24, 2024
1 parent 50c5e36 commit b269f6a
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 40 deletions.
79 changes: 41 additions & 38 deletions src/metatensor/models/cli/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,50 +96,50 @@ def train_model(options: DictConfig) -> None:
generator = torch.Generator()

logger.info("Setting up training set")
conf_training_set = expand_dataset_config(options["training_set"])
structures_train = read_structures(
filename=conf_training_set["structures"]["read_from"],
fileformat=conf_training_set["structures"]["file_format"],
train_options = expand_dataset_config(options["training_set"])
train_structures = read_structures(
filename=train_options["structures"]["read_from"],
fileformat=train_options["structures"]["file_format"],
)
targets_train = read_targets(conf_training_set["targets"])
train_dataset = Dataset(structures_train, targets_train)
train_targets = read_targets(train_options["targets"])
train_dataset = Dataset(train_structures, train_targets)

logger.info("Setting up test set")
conf_test_set = options["test_set"]
if not isinstance(conf_test_set, float):
conf_test_set = expand_dataset_config(conf_test_set)
structures_test = read_structures(
filename=conf_training_set["structures"]["read_from"],
fileformat=conf_training_set["structures"]["file_format"],
test_options = options["test_set"]
if not isinstance(test_options, float):
test_options = expand_dataset_config(test_options)
test_structures = read_structures(
filename=train_options["structures"]["read_from"],
fileformat=train_options["structures"]["file_format"],
)
targets_test = read_targets(conf_test_set["targets"])
test_dataset = Dataset(structures_test, targets_test)
fraction_test_set = 0.0
test_targets = read_targets(test_options["targets"])
test_dataset = Dataset(test_structures, test_targets)
test_fraction = 0.0
else:
if conf_test_set < 0 or conf_test_set >= 1:
if test_options < 0 or test_options >= 1:
raise ValueError("Test set split must be between 0 and 1.")
fraction_test_set = conf_test_set
test_fraction = test_options

logger.info("Setting up validation set")
conf_validation_set = options["validation_set"]
if not isinstance(conf_validation_set, float):
conf_validation_set = expand_dataset_config(conf_validation_set)
structures_validation = read_structures(
filename=conf_training_set["structures"]["read_from"],
fileformat=conf_training_set["structures"]["file_format"],
validation_options = options["validation_set"]
if not isinstance(validation_options, float):
validation_options = expand_dataset_config(validation_options)
validation_structures = read_structures(
filename=train_options["structures"]["read_from"],
fileformat=train_options["structures"]["file_format"],
)
targets_validation = read_targets(conf_validation_set["targets"])
validation_dataset = Dataset(structures_validation, targets_validation)
fraction_validation_set = 0.0
validation_targets = read_targets(validation_options["targets"])
validation_dataset = Dataset(validation_structures, validation_targets)
validation_fraction = 0.0
else:
if conf_validation_set < 0 or conf_validation_set >= 1:
if validation_options < 0 or validation_options >= 1:
raise ValueError("Validation set split must be between 0 and 1.")
fraction_validation_set = conf_validation_set
validation_fraction = validation_options

# Split train dataset if requested
if fraction_test_set or fraction_validation_set:
fraction_train_set = 1 - fraction_test_set - fraction_validation_set
if fraction_train_set < 0:
if test_fraction or validation_fraction:
train_fraction = 1 - test_fraction - validation_fraction
if train_fraction < 0:
raise ValueError("fraction of the train set is smaller then 0!")

# ignore warning of possible empty dataset
Expand All @@ -148,17 +148,17 @@ def train_model(options: DictConfig) -> None:
subsets = torch.utils.data.random_split(
dataset=train_dataset,
lengths=[
fraction_train_set,
fraction_test_set,
fraction_validation_set,
train_fraction,
test_fraction,
validation_fraction,
],
generator=generator,
)

train_dataset = subsets[0]
if fraction_test_set and not fraction_validation_set:
if test_fraction and not validation_fraction:
test_dataset = subsets[1]
elif not fraction_validation_set and fraction_validation_set:
elif not validation_fraction and validation_fraction:
validation_dataset = subsets[1]
else:
test_dataset = subsets[1]
Expand All @@ -184,12 +184,13 @@ def train_model(options: DictConfig) -> None:
outputs = {
key: ModelOutput(
quantity=value["quantity"],
unit=(value["unit"] if value["unit"] is not None else ""), # potential HACK
unit=(value["unit"] if value["unit"] is not None else ""),
)
for key, value in options["training_set"]["targets"].items()
}
length_unit = train_options["structures"]["length_unit"]
model_capabilities = ModelCapabilities(
length_unit="Angstrom",
length_unit=length_unit if length_unit is not None else "",
species=all_species,
outputs=outputs,
)
Expand All @@ -203,3 +204,5 @@ def train_model(options: DictConfig) -> None:
)

save_model(model, options["output_path"])

# TODO: add evaluation of the test set
1 change: 1 addition & 0 deletions src/metatensor/models/utils/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def _resolve_single_str(config):
"read_from": "${..read_from}",
"file_format": "${file_format:}",
"key": None,
"length_unit": None,
}
)

Expand Down
5 changes: 3 additions & 2 deletions tests/utils/test_omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_expand_dataset_config():
file_name = "foo.xyz"
file_format = ".xyz"

structure_section = {"read_from": file_name, "unit": "angstrom"}
structure_section = {"read_from": file_name, "length_unit": "angstrom"}

target_section = {
"quantity": "energy",
Expand All @@ -31,7 +31,7 @@ def test_expand_dataset_config():

assert conf_expanded["structures"]["read_from"] == file_name
assert conf_expanded["structures"]["file_format"] == file_format
assert conf_expanded["structures"]["unit"] == "angstrom"
assert conf_expanded["structures"]["length_unit"] == "angstrom"

targets_conf = conf_expanded["targets"]
assert len(targets_conf) == 2
Expand Down Expand Up @@ -89,6 +89,7 @@ def test_expand_dataset_config_min():

assert conf_expanded["structures"]["read_from"] == file_name
assert conf_expanded["structures"]["file_format"] == file_format
assert conf_expanded["structures"]["length_unit"] is None

targets_conf = conf_expanded["targets"]
assert targets_conf["energy"]["quantity"] == "energy"
Expand Down

0 comments on commit b269f6a

Please sign in to comment.