Skip to content

Commit

Permalink
fix loading normalizers from checkpoint with OTF (#824)
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque authored Sep 4, 2024
1 parent e97fa8d commit a98cb9d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 18 deletions.
23 changes: 13 additions & 10 deletions src/fairchem/core/modules/normalization/_load_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,19 @@ def _load_from_config(
"""
modules = _load_check_duplicates(config, name)
for target in config:
if target == "fit" and not config["fit"].get("fitted", False):
# remove values for output targets that have already been read from files
targets = [
target for target in config["fit"]["targets"] if target not in modules
]
fit_kwargs.update(
{k: v for k, v in config["fit"].items() if k != "targets"}
)
modules.update(fit_fun(targets=targets, dataset=dataset, **fit_kwargs))
config["fit"]["fitted"] = True
if target == "fit":
if not config["fit"].get("fitted", False):
# remove values for output targets that have already been read from files
targets = [
target
for target in config["fit"]["targets"]
if target not in modules
]
fit_kwargs.update(
{k: v for k, v in config["fit"].items() if k != "targets"}
)
modules.update(fit_fun(targets=targets, dataset=dataset, **fit_kwargs))
config["fit"]["fitted"] = True
# if a single file for all outputs is not provided,
# then check if a single file is provided for a specific output
elif target != "file":
Expand Down
19 changes: 11 additions & 8 deletions src/fairchem/core/modules/normalization/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,14 +269,17 @@ def load_normalizers_from_config(
) -> dict[str, Normalizer]:
"""Create a dictionary with element references from a config."""
# edit the config slightly to extract override args
if "fit" in config:
override_values = {
target: vals
for target, vals in config["fit"]["targets"].items()
if isinstance(vals, dict)
}
config["fit"]["override_values"] = override_values
config["fit"]["targets"] = list(config["fit"]["targets"].keys())
if "fit" in config: # noqa
if "override_values" not in config["fit"] and isinstance(
config["fit"]["targets"], dict
):
override_values = {
target: vals
for target, vals in config["fit"]["targets"].items()
if isinstance(vals, dict)
}
config["fit"]["override_values"] = override_values
config["fit"]["targets"] = list(config["fit"]["targets"].keys())

return _load_from_config(
config,
Expand Down

0 comments on commit a98cb9d

Please sign in to comment.