Skip to content

Commit

Permalink
Merge pull request #619 from ACEsuit/develop
Browse files Browse the repository at this point in the history
add option to return raw model in mace_mp
  • Loading branch information
ilyes319 authored Oct 2, 2024
2 parents 96aa932 + 6be2a94 commit 118a514
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions mace/calculators/foundations_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def mace_mp(
damping: str = "bj", # choices: ["zero", "bj", "zerom", "bjm"]
dispersion_xc: str = "pbe",
dispersion_cutoff: float = 40.0 * units.Bohr,
return_raw_model: bool = False,
**kwargs,
) -> MACECalculator:
"""
Expand All @@ -93,6 +94,7 @@ def mace_mp(
damping (str): The damping function associated with the D3 correction. Defaults to "bj" for D3(BJ).
dispersion_xc (str, optional): Exchange-correlation functional for D3 dispersion corrections.
dispersion_cutoff (float, optional): Cutoff radius in Bohr for D3 dispersion corrections.
return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False.
**kwargs: Passed to MACECalculator and TorchDFTD3Calculator.
Returns:
Expand All @@ -114,6 +116,9 @@ def mace_mp(
"Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization."
)

if return_raw_model:
return torch.load(model_path, map_location=device)

mace_calc = MACECalculator(
model_paths=model_path, device=device, default_dtype=default_dtype, **kwargs
)
Expand Down Expand Up @@ -221,6 +226,7 @@ def mace_off(
def mace_anicc(
device: str = "cuda",
model_path: str = None,
return_raw_model: bool = False,
) -> MACECalculator:
"""
Constructs a MACECalculator with a pretrained model based on the ANI (H, C, N, O).
Expand All @@ -236,6 +242,8 @@ def mace_anicc(
print(
"Using ANI couple cluster model for MACECalculator, see https://doi.org/10.1063/5.0155322"
)
if return_raw_model:
return torch.load(model_path, map_location=device)
return MACECalculator(
model_paths=model_path, device=device, default_dtype="float64"
)

0 comments on commit 118a514

Please sign in to comment.