From 6be2a940121d0a12b4d50cba2c1bf359aaa9e5f9 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Wed, 2 Oct 2024 18:12:13 +0100 Subject: [PATCH] add option to return raw model in mace_mp --- mace/calculators/foundations_models.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mace/calculators/foundations_models.py b/mace/calculators/foundations_models.py index 0fbe3a8f..5c9a896f 100644 --- a/mace/calculators/foundations_models.py +++ b/mace/calculators/foundations_models.py @@ -68,6 +68,7 @@ def mace_mp( damping: str = "bj", # choices: ["zero", "bj", "zerom", "bjm"] dispersion_xc: str = "pbe", dispersion_cutoff: float = 40.0 * units.Bohr, + return_raw_model: bool = False, **kwargs, ) -> MACECalculator: """ @@ -93,6 +94,7 @@ def mace_mp( damping (str): The damping function associated with the D3 correction. Defaults to "bj" for D3(BJ). dispersion_xc (str, optional): Exchange-correlation functional for D3 dispersion corrections. dispersion_cutoff (float, optional): Cutoff radius in Bohr for D3 dispersion corrections. + return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False. **kwargs: Passed to MACECalculator and TorchDFTD3Calculator. Returns: @@ -114,6 +116,9 @@ def mace_mp( "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization." ) + if return_raw_model: + return torch.load(model_path, map_location=device) + mace_calc = MACECalculator( model_paths=model_path, device=device, default_dtype=default_dtype, **kwargs ) @@ -221,6 +226,7 @@ def mace_off( def mace_anicc( device: str = "cuda", model_path: str = None, + return_raw_model: bool = False, ) -> MACECalculator: """ Constructs a MACECalculator with a pretrained model based on the ANI (H, C, N, O). @@ -236,6 +242,8 @@ def mace_anicc( print( "Using ANI couple cluster model for MACECalculator, see https://doi.org/10.1063/5.0155322" ) + if return_raw_model: + return torch.load(model_path, map_location=device) return MACECalculator( model_paths=model_path, device=device, default_dtype="float64" )