Skip to content

Commit

Permalink
Merge the predict_structure and featurize_structure into a single met…
Browse files Browse the repository at this point in the history
…hod (#290)

* improve TensorNet model coverage

* Update pyproject.toml

Signed-off-by: Tsz Wai Ko <[email protected]>

* Improve the unit test for SO(3) equivarance in TensorNet class

* improve SO3Net model class coverage and simplify TensorNet implementations

* improve the coverage in MLP_norm class

* Improve the implementation of three-body interactions

* fixed black

* Optimize the speed of _compute_3body class

* type checking is added for scheduler

* update M3GNet Potential training notebook for the demonstration of obtaining and using element offsets

* Downgrade sympy to avoid crash of SO3 operations

* Smooth l1 loss function is added and united tests are improved

* merge the method predict_structure and featurize_structure into a function including both

---------

Signed-off-by: Tsz Wai Ko <[email protected]>
  • Loading branch information
kenko911 authored Jul 21, 2024
1 parent 1cb40b5 commit e97203e
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 30 deletions.
13 changes: 13 additions & 0 deletions examples/Property Predictions using MEGNet or M3GNet Models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,19 @@
"print(f\"The predicted formation energy for CsCl is {float(eform):.3f} eV/atom.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0245162",
"metadata": {},
"outputs": [],
"source": [
"# Extract the structure features of a structure\n",
"feat_dict = model.model.predict_structure(struct, return_features=True)\n",
"# Print out structure-wise features, it should be the dimension of node_features * 2 from set2set layer\n",
"print(feat_dict[\"readout\"].shape)"
]
},
{
"cell_type": "markdown",
"id": "90e95671",
Expand Down
44 changes: 17 additions & 27 deletions src/matgl/models/_m3gnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,25 +294,28 @@ def forward(
return fea_dict
return torch.squeeze(output)

def featurize_structure(
def predict_structure(
self,
structure,
state_feats: torch.Tensor | None = None,
graph_converter: GraphConverter | None = None,
output_layers: list | None = None,
return_features: bool = False,
):
"""Convenience method to featurize a structure with M3GNet model.
"""Convenience method to featurize or predict properties of a structure with M3GNet model.
Args:
structure: An input crystal/molecule.
state_feats (torch.tensor): Graph attributes.
graph_converter: Object that implements a get_graph_from_structure.
output_layers: List of names for the layer of GNN as output. Choose from "bond_expansion", "embedding",
"three_body_basis", "gc_1", "gc_2", "gc_3", "readout", and "final". By default, all M3GNet layer
outputs are returned.
outputs are returned. Ignored if `return_features` is False.
return_features (bool): If True, return specified layer outputs. If False, only return final output.
Returns:
output (dict): M3GNet intermediate and final layer outputs for a structure.
output (dict or torch.tensor): M3GNet intermediate and final layer outputs for a structure, or final
predicted property if `return_features` is False.
"""
allowed_output_layers = [
"bond_expansion",
Expand All @@ -321,7 +324,10 @@ def featurize_structure(
"readout",
"final",
] + [f"gc_{i + 1}" for i in range(self.n_blocks)]
if output_layers is None:

if not return_features:
output_layers = ["final"]
elif output_layers is None:
output_layers = allowed_output_layers
elif not isinstance(output_layers, list) or set(output_layers).difference(allowed_output_layers):
raise ValueError(f"Invalid output_layers, it must be a sublist of {allowed_output_layers}.")
Expand All @@ -330,33 +336,17 @@ def featurize_structure(
from matgl.ext.pymatgen import Structure2Graph

graph_converter = Structure2Graph(element_types=self.element_types, cutoff=self.cutoff) # type: ignore

g, lat, state_feats_default = graph_converter.get_graph(structure)
g.edata["pbc_offshift"] = torch.matmul(g.edata["pbc_offset"], lat[0])
g.ndata["pos"] = g.ndata["frac_coords"] @ lat[0]

if state_feats is None:
state_feats = torch.tensor(state_feats_default)
if output_layers == ["final"]:
return self(g=g, state_attr=state_feats).detach()
return {
k: v
for k, v in self(g=g, state_attr=state_feats, return_all_layer_output=True).items()
if k in output_layers
}

def predict_structure(
self,
structure,
state_feats: torch.Tensor | None = None,
graph_converter: GraphConverter | None = None,
):
"""Convenient method to directly predict property from structure.
model_output = self(g=g, state_attr=state_feats, return_all_layer_output=True)

Args:
structure: An input crystal/molecule.
state_feats (torch.tensor): Graph attributes
graph_converter: Object that implements a get_graph_from_structure.
if not return_features:
return model_output["final"].detach()

Returns:
output(torch.tensor): output property for a structure
"""
return self.featurize_structure(structure, state_feats, graph_converter, ["final"])
return {k: v for k, v in model_output.items() if k in output_layers}
8 changes: 5 additions & 3 deletions tests/models/test_m3gnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def test_featurize_structure(self, graph_MoS):
model_extensive = M3GNet(is_intensive=False)
for model in [model_extensive, model_intensive]:
with pytest.raises(ValueError, match="Invalid output_layers"):
model.featurize_structure(structure, output_layers=["whatever"])
features = model.featurize_structure(structure)
model.predict_structure(structure, output_layers=["whatever"], return_features=True)
features = model.predict_structure(structure, return_features=True)
assert torch.numel(features["bond_expansion"]) == 252
assert torch.numel(features["three_body_basis"]) == 3276
for output_layer in ["embedding", "gc_1", "gc_2", "gc_3"]:
Expand All @@ -95,4 +95,6 @@ def test_featurize_structure(self, graph_MoS):
else:
assert torch.numel(features["readout"]) == 2
assert torch.numel(features["final"]) == 1
assert list(model.featurize_structure(structure, output_layers=["gc_1"]).keys()) == ["gc_1"]
assert list(model.predict_structure(structure, output_layers=["gc_1"], return_features=True).keys()) == [
"gc_1"
]

0 comments on commit e97203e

Please sign in to comment.