From 4d53fe74064558a9af3a6f2e0e7604e2645766db Mon Sep 17 00:00:00 2001 From: Shyue Ping Ong Date: Sun, 6 Aug 2023 15:52:14 -0700 Subject: [PATCH] Fix model loading with map_location. --- matgl/graph/compute.py | 6 +----- matgl/utils/io.py | 8 +++----- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/matgl/graph/compute.py b/matgl/graph/compute.py index 68d59245..dad60bf0 100644 --- a/matgl/graph/compute.py +++ b/matgl/graph/compute.py @@ -104,11 +104,7 @@ def compute_theta_and_phi(edges: dgl.udf.EdgeBatch): triple_bond_lengths (torch.tensor): """ angles = compute_theta(edges, cosine=True) - angles.update( - { - "phi": torch.zeros_like(angles["cos_theta"]), - } - ) + angles["phi"] = torch.zeros_like(angles["cos_theta"]) return angles diff --git a/matgl/utils/io.py b/matgl/utils/io.py index b0b96e63..1585dab1 100644 --- a/matgl/utils/io.py +++ b/matgl/utils/io.py @@ -111,11 +111,9 @@ def load(cls, path: str | Path | dict, **kwargs): _check_ver(cls, model_data) - if not torch.cuda.is_available(): - state = torch.load(fpaths["state.pt"], map_location=torch.device("cpu")) - else: - state = torch.load(fpaths["state.pt"]) - d = torch.load(fpaths["model.pt"]) + map_location = torch.device("cpu") if not torch.cuda.is_available() else None + state = torch.load(fpaths["state.pt"], map_location=map_location) + d = torch.load(fpaths["model.pt"], map_location=map_location) # Deserialize any args that are IOMixIn subclasses. for k, v in d.items():