Skip to content

Commit

Permalink
Fix model loading with map_location.
Browse files Browse the repository at this point in the history
  • Loading branch information
shyuep committed Aug 6, 2023
1 parent 4961e1c commit 4d53fe7
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 10 deletions.
6 changes: 1 addition & 5 deletions matgl/graph/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,7 @@ def compute_theta_and_phi(edges: dgl.udf.EdgeBatch):
triple_bond_lengths (torch.tensor):
"""
angles = compute_theta(edges, cosine=True)
angles.update(
{
"phi": torch.zeros_like(angles["cos_theta"]),
}
)
angles["phi"] = torch.zeros_like(angles["cos_theta"])
return angles


Expand Down
8 changes: 3 additions & 5 deletions matgl/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,9 @@ def load(cls, path: str | Path | dict, **kwargs):

_check_ver(cls, model_data)

if not torch.cuda.is_available():
state = torch.load(fpaths["state.pt"], map_location=torch.device("cpu"))
else:
state = torch.load(fpaths["state.pt"])
d = torch.load(fpaths["model.pt"])
map_location = torch.device("cpu") if not torch.cuda.is_available() else None
state = torch.load(fpaths["state.pt"], map_location=map_location)
d = torch.load(fpaths["model.pt"], map_location=map_location)

# Deserialize any args that are IOMixIn subclasses.
for k, v in d.items():
Expand Down

0 comments on commit 4d53fe7

Please sign in to comment.