Remove consistency checks (except for tests) #293

Merged: 6 commits, Jul 18, 2024
examples/programmatic/llpr/llpr.py (2 changes: 1 addition & 1 deletion)
@@ -150,7 +150,7 @@
selected_atoms=None,
)

- outputs = exported_model([ethanol_system], evaluation_options, check_consistency=True)
+ outputs = exported_model([ethanol_system], evaluation_options, check_consistency=False)
lpr = outputs["mtt::aux::energy_uncertainty"].block().values.detach().cpu().numpy()

# %%
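For readers following the example above, a minimal sketch of toggling the check at call time. The names `exported_model`, `ethanol_system`, and `evaluation_options` come from the surrounding example; the environment-variable switch is an assumption added here for illustration, not part of the change.

```python
import os

# Hypothetical debug switch: keep the fast path by default, but allow turning
# metatensor's input validation back on when troubleshooting the example.
debug = os.environ.get("LLPR_DEBUG", "0") == "1"

outputs = exported_model(
    [ethanol_system],
    evaluation_options,
    check_consistency=debug,
)
lpr = outputs["mtt::aux::energy_uncertainty"].block().values.detach().cpu().numpy()
```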
src/metatrain/cli/eval.py (20 changes: 18 additions & 2 deletions)
@@ -79,6 +79,12 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
default="output.xyz",
help="filename of the predictions (default: %(default)s)",
)
+ parser.add_argument(
+     "--check-consistency",
+     dest="check_consistency",
+     action="store_true",
+     help="whether to run consistency checks (default: %(default)s)",
+ )


def _prepare_eval_model_args(args: argparse.Namespace) -> None:
@@ -150,6 +156,7 @@ def _eval_targets(
dataset: Union[Dataset, torch.utils.data.Subset],
options: TargetInfoDict,
return_predictions: bool,
+ check_consistency: bool = False,
) -> Optional[Dict[str, TensorMap]]:
"""Evaluates an exported model on a dataset and prints the RMSEs for each target.
Optionally, it also returns the predictions of the model.
@@ -195,7 +202,13 @@
key: value.to(dtype=dtype, device=device)
for key, value in batch_targets.items()
}
- batch_predictions = evaluate_model(model, systems, options, is_training=False)
+ batch_predictions = evaluate_model(
+     model,
+     systems,
+     options,
+     is_training=False,
+     check_consistency=check_consistency,
+ )
batch_predictions = average_by_num_atoms(
batch_predictions, systems, per_structure_keys=[]
)
@@ -228,6 +241,7 @@ def eval_model(
model: Union[MetatensorAtomisticModel, torch.jit._script.RecursiveScriptModule],
options: DictConfig,
output: Union[Path, str] = "output.xyz",
+ check_consistency: bool = False,
) -> None:
"""Evaluate an exported model on a given data set.

@@ -237,7 +251,8 @@

:param model: Saved model to be evaluated.
:param options: DictConfig to define a test dataset taken for the evaluation.
- :param output: Path to save the predicted values
+ :param output: Path to save the predicted values.
+ :param check_consistency: Whether to run consistency checks during model evaluation.
"""
logger.info("Setting up evaluation set.")

@@ -290,6 +305,7 @@
dataset=eval_dataset,
options=eval_info_dict,
return_predictions=True,
+ check_consistency=check_consistency,
)
except Exception as e:
raise ArchitectureError(e)
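A hedged sketch of how the new keyword is used from Python; only `eval_model` and its `check_consistency` parameter come from this diff, while the model object and the `eval.yaml` options file are assumptions. On the command line the same switch is the new `--check-consistency` flag (a `store_true` option, so checks stay off unless requested).

```python
from omegaconf import OmegaConf

from metatrain.cli.eval import eval_model

# `exported_model` is assumed to be an already-loaded exported model;
# `eval.yaml` is a hypothetical options file describing the evaluation dataset.
options = OmegaConf.load("eval.yaml")

eval_model(
    model=exported_model,
    options=options,
    output="predictions.xyz",
    check_consistency=True,  # off by default after this change; opt in explicitly
)
```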
src/metatrain/utils/evaluate_model.py (14 changes: 9 additions & 5 deletions)
@@ -35,6 +35,7 @@ def evaluate_model(
systems: List[System],
targets: TargetInfoDict,
is_training: bool,
+ check_consistency: bool = False,
) -> Dict[str, TensorMap]:
"""
Evaluate the model (in training or exported) on a set of requested targets.
@@ -75,13 +76,14 @@
system,
positions_grad=len(energy_targets_that_require_position_gradients) > 0,
strain_grad=len(energy_targets_that_require_strain_gradients) > 0,
+ check_consistency=check_consistency,
)
new_systems.append(new_system)
strains.append(strain)
systems = new_systems

# Based on the keys of the targets, get the outputs of the model:
- model_outputs = _get_model_outputs(model, systems, targets)
+ model_outputs = _get_model_outputs(model, systems, targets, check_consistency)

for energy_target in energy_targets:
# If the energy target requires gradients, compute them:
@@ -233,6 +235,7 @@ def _get_model_outputs(
],
systems: List[System],
targets: TargetInfoDict,
+ check_consistency: bool,
) -> Dict[str, TensorMap]:
if is_exported(model):
# put together an EvaluationOptions object
@@ -245,8 +248,7 @@
for key, value in targets.items()
},
)
- # we check consistency here because this could be called from eval
- return model(systems, options, check_consistency=True)
+ return model(systems, options, check_consistency=check_consistency)
else:
return model(
systems,
@@ -259,7 +261,9 @@
)


- def _prepare_system(system: System, positions_grad: bool, strain_grad: bool):
+ def _prepare_system(
+     system: System, positions_grad: bool, strain_grad: bool, check_consistency: bool
+ ):
"""
Prepares a system for gradient calculation.
"""
@@ -294,7 +298,7 @@ def _prepare_system(system: System, positions_grad: bool, strain_grad: bool):
for nl_options in system.known_neighbor_lists():
nl = system.get_neighbor_list(nl_options)
nl = metatensor.torch.detach_block(nl)
- register_autograd_neighbors(new_system, nl, check_consistency=True)
+ register_autograd_neighbors(new_system, nl, check_consistency)
new_system.add_neighbor_list(nl_options, nl)

return new_system, strain
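For context, a small sketch (under the same assumptions as the tests) of how the flag now propagates through this module: the caller passes `check_consistency` once to `evaluate_model`, which forwards it both to `register_autograd_neighbors` when preparing each system and to the exported model call in `_get_model_outputs`.

```python
from metatrain.utils.evaluate_model import evaluate_model

# `model`, `systems`, and `targets` are assumed to be prepared as in
# tests/utils/test_evaluate_model.py; nothing here is new API beyond the
# `check_consistency` keyword introduced in this pull request.
predictions = evaluate_model(
    model,
    systems,
    targets,
    is_training=False,
    check_consistency=True,  # training code keeps the default of False
)
```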
src/metatrain/utils/llpr.py (10 changes: 3 additions & 7 deletions)
@@ -99,7 +99,7 @@
outputs=outputs,
selected_atoms=selected_atoms,
)
- return self.model(systems, options, check_consistency=True)
+ return self.model(systems, options, check_consistency=False)

Codecov / codecov/patch warning on line 102 in src/metatrain/utils/llpr.py: the added line was not covered by tests.

per_atom_all_targets = [output.per_atom for output in outputs.values()]
# impose either all per atom or all not per atom
@@ -130,9 +130,7 @@
outputs=outputs_for_model,
selected_atoms=selected_atoms,
)
- return_dict = self.model(
-     systems, options, check_consistency=True
- )  # TODO: True or False here?
+ return_dict = self.model(systems, options, check_consistency=False)

ll_features = return_dict["mtt::aux::last_layer_features"]

@@ -248,9 +246,7 @@
length_unit="",
outputs=outputs,
)
- output = self.model(
-     systems, options, check_consistency=True
- )  # TODO: True or False here?
+ output = self.model(systems, options, check_consistency=False)
ll_feat_tmap = output["mtt::aux::last_layer_features"]
ll_feats = ll_feat_tmap.block().values / n_atoms.unsqueeze(1)
self.covariance += ll_feats.T @ ll_feats
tests/cli/test_eval_model.py (5 changes: 5 additions & 0 deletions)
@@ -37,6 +37,7 @@ def test_eval_cli(monkeypatch, tmp_path):
str(EVAL_OPTIONS_PATH),
"-e",
str(RESOURCES_PATH / "extensions"),
"--check-consistency",
]

output = subprocess.check_output(command, stderr=subprocess.STDOUT)
@@ -60,6 +61,7 @@ def test_eval(monkeypatch, tmp_path, caplog, model_name, options):
model=model,
options=options,
output="foo.xyz",
+ check_consistency=True,
)

# Test target predictions
@@ -94,6 +96,7 @@ def test_eval_export(monkeypatch, tmp_path, options):
model=exported_model,
options=options,
output="foo.xyz",
+ check_consistency=True,
)


@@ -108,6 +111,7 @@ def test_eval_multi_dataset(monkeypatch, tmp_path, caplog, model, options):
model=model,
options=OmegaConf.create([options, options]),
output="foo.xyz",
+ check_consistency=True,
)

# Test target predictions
@@ -131,6 +135,7 @@ def test_eval_no_targets(monkeypatch, tmp_path, model, options):
eval_model(
model=model,
options=options,
+ check_consistency=True,
)

assert Path("output.xyz").is_file()
tests/utils/test_evaluate_model.py (4 changes: 3 additions & 1 deletion)
@@ -51,7 +51,9 @@ def test_evaluate_model(training, exported):
]

systems = [system.to(torch.float32) for system in systems]
- outputs = evaluate_model(model, systems, targets, is_training=training)
+ outputs = evaluate_model(
+     model, systems, targets, is_training=training, check_consistency=True
+ )

assert isinstance(outputs, dict)
assert "energy" in outputs