diff --git a/src/metatrain/utils/data/readers/readers.py b/src/metatrain/utils/data/readers/readers.py index 60b46878..436134b6 100644 --- a/src/metatrain/utils/data/readers/readers.py +++ b/src/metatrain/utils/data/readers/readers.py @@ -31,11 +31,11 @@ def _base_reader( ) -> List[Any]: if reader is None: try: - filesuffix = Path(filename).suffix - reader = DEFAULT_READER[filesuffix] + file_suffix = Path(filename).suffix + reader = DEFAULT_READER[file_suffix] except KeyError: raise ValueError( - f"File extension {filesuffix!r} is not linked to a default reader " + f"File extension {file_suffix!r} is not linked to a default reader " "library. You can try reading it by setting a specific 'reader' from " f"the known ones: {', '.join(AVAILABLE_READERS)} " ) @@ -171,7 +171,7 @@ def read_targets( This function uses subfunctions like :func:`read_energy` to parse the requested target quantity. Currently only `energy` is a supported target property. But, within the `energy` section gradients such as `forces`, the `stress` or the `virial` can be - added. Other gradients are silentlty irgnored. + added. Other gradients are silently ignored. :param conf: config containing the keys for what should be read. :returns: Dictionary containing a list of TensorMaps for each target section in the @@ -213,14 +213,11 @@ def read_targets( reader=target["forces"]["reader"], ) except Exception: - logger.warning( - f"No Forces found in section {target_key!r}. " - "Continue without forces!" - ) + logger.warning(f"No forces found in section {target_key!r}.") else: logger.info( - f"Forces found in section {target_key!r}. Forces are taken for " - "training!" + f"Forces found in section {target_key!r}, " + "we will use this gradient to train the model" ) for block, position_gradient in zip(blocks, position_gradients): block.add_gradient( @@ -230,7 +227,7 @@ def read_targets( target_info_gradients.append("positions") if target["stress"] and target["virial"]: - raise ValueError("Cannot use stress and virial at the same time!") + raise ValueError("Cannot use stress and virial at the same time") if target["stress"]: try: @@ -240,14 +237,11 @@ def read_targets( reader=target["stress"]["reader"], ) except Exception: - logger.warning( - f"No Stress found in section {target_key!r}. " - "Continue without stress!" - ) + logger.warning(f"No stress found in section {target_key!r}.") else: logger.info( - f"Stress found in section {target_key!r}. Stress is taken for " - f"training!" + f"Stress found in section {target_key!r}, " + "we will use this gradient to train the model" ) for block, strain_gradient in zip(blocks, strain_gradients): block.add_gradient(parameter="strain", gradient=strain_gradient) @@ -262,14 +256,11 @@ def read_targets( reader=target["virial"]["reader"], ) except Exception: - logger.warning( - f"No Virial found in section {target_key!r}. " - "Continue without virial!" - ) + logger.warning(f"No virial found in section {target_key!r}.") else: logger.info( - f"Virial found in section {target_key!r}. Virial is taken for " - f"training!" + f"Virial found in section {target_key!r}, " + "we will use this gradient to train the model" ) for block, strain_gradient in zip(blocks, strain_gradients): block.add_gradient(parameter="strain", gradient=strain_gradient) diff --git a/tests/utils/data/test_readers.py b/tests/utils/data/test_readers.py index 8a9a57ba..92b5e656 100644 --- a/tests/utils/data/test_readers.py +++ b/tests/utils/data/test_readers.py @@ -256,12 +256,12 @@ def test_read_targets_warnings(stress_dict, virial_dict, monkeypatch, tmp_path, caplog.set_level(logging.WARNING) read_targets(OmegaConf.create(conf)) # , slice_samples_by="system") - assert any(["No Forces found" in rec.message for rec in caplog.records]) + assert any(["No forces found" in rec.message for rec in caplog.records]) if stress_dict: - assert any(["No Stress found" in rec.message for rec in caplog.records]) + assert any(["No stress found" in rec.message for rec in caplog.records]) if virial_dict: - assert any(["No Virial found" in rec.message for rec in caplog.records]) + assert any(["No virial found" in rec.message for rec in caplog.records]) def test_read_targets_error(monkeypatch, tmp_path):