Skip to content

Commit

Permalink
Less intense logger message when gradients are used for training
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed Sep 23, 2024
1 parent 2ab36b0 commit 0b4bb21
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 26 deletions.
37 changes: 14 additions & 23 deletions src/metatrain/utils/data/readers/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ def _base_reader(
) -> List[Any]:
if reader is None:
try:
filesuffix = Path(filename).suffix
reader = DEFAULT_READER[filesuffix]
file_suffix = Path(filename).suffix
reader = DEFAULT_READER[file_suffix]
except KeyError:
raise ValueError(
f"File extension {filesuffix!r} is not linked to a default reader "
f"File extension {file_suffix!r} is not linked to a default reader "
"library. You can try reading it by setting a specific 'reader' from "
f"the known ones: {', '.join(AVAILABLE_READERS)} "
)
Expand Down Expand Up @@ -171,7 +171,7 @@ def read_targets(
This function uses subfunctions like :func:`read_energy` to parse the requested
target quantity. Currently only `energy` is a supported target property. But, within
the `energy` section gradients such as `forces`, the `stress` or the `virial` can be
added. Other gradients are silentlty irgnored.
added. Other gradients are silently ignored.
:param conf: config containing the keys for what should be read.
:returns: Dictionary containing a list of TensorMaps for each target section in the
Expand Down Expand Up @@ -213,14 +213,11 @@ def read_targets(
reader=target["forces"]["reader"],
)
except Exception:
logger.warning(
f"No Forces found in section {target_key!r}. "
"Continue without forces!"
)
logger.warning(f"No forces found in section {target_key!r}.")
else:
logger.info(
f"Forces found in section {target_key!r}. Forces are taken for "
"training!"
f"Forces found in section {target_key!r}, "
"we will use this gradient to train the model"
)
for block, position_gradient in zip(blocks, position_gradients):
block.add_gradient(
Expand All @@ -230,7 +227,7 @@ def read_targets(
target_info_gradients.append("positions")

if target["stress"] and target["virial"]:
raise ValueError("Cannot use stress and virial at the same time!")
raise ValueError("Cannot use stress and virial at the same time")

if target["stress"]:
try:
Expand All @@ -240,14 +237,11 @@ def read_targets(
reader=target["stress"]["reader"],
)
except Exception:
logger.warning(
f"No Stress found in section {target_key!r}. "
"Continue without stress!"
)
logger.warning(f"No stress found in section {target_key!r}.")
else:
logger.info(
f"Stress found in section {target_key!r}. Stress is taken for "
f"training!"
f"Stress found in section {target_key!r}, "
"we will use this gradient to train the model"
)
for block, strain_gradient in zip(blocks, strain_gradients):
block.add_gradient(parameter="strain", gradient=strain_gradient)
Expand All @@ -262,14 +256,11 @@ def read_targets(
reader=target["virial"]["reader"],
)
except Exception:
logger.warning(
f"No Virial found in section {target_key!r}. "
"Continue without virial!"
)
logger.warning(f"No virial found in section {target_key!r}.")
else:
logger.info(
f"Virial found in section {target_key!r}. Virial is taken for "
f"training!"
f"Virial found in section {target_key!r}, "
"we will use this gradient to train the model"
)
for block, strain_gradient in zip(blocks, strain_gradients):
block.add_gradient(parameter="strain", gradient=strain_gradient)
Expand Down
6 changes: 3 additions & 3 deletions tests/utils/data/test_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,12 @@ def test_read_targets_warnings(stress_dict, virial_dict, monkeypatch, tmp_path,
caplog.set_level(logging.WARNING)
read_targets(OmegaConf.create(conf)) # , slice_samples_by="system")

assert any(["No Forces found" in rec.message for rec in caplog.records])
assert any(["No forces found" in rec.message for rec in caplog.records])

if stress_dict:
assert any(["No Stress found" in rec.message for rec in caplog.records])
assert any(["No stress found" in rec.message for rec in caplog.records])
if virial_dict:
assert any(["No Virial found" in rec.message for rec in caplog.records])
assert any(["No virial found" in rec.message for rec in caplog.records])


def test_read_targets_error(monkeypatch, tmp_path):
Expand Down

0 comments on commit 0b4bb21

Please sign in to comment.