From cb41e6073b56b81fd69c16a903db203d2ccaa394 Mon Sep 17 00:00:00 2001 From: Tsz Wai Ko <47970742+kenko911@users.noreply.github.com> Date: Fri, 29 Mar 2024 11:54:04 -0700 Subject: [PATCH] Ensure the state attr from molecular graph is consistent with matgl.float_th and include linear layer in TensorNet to match the original implementations (#244) * model version for Potential class is added * model version for Potential class is modified * Enable the smooth version of Spherical Bessel function in TensorNet * max_n, max_l for SphericalBessel radial basis functions are included in TensorNet class * adding unit tests for improving the coverage score * little clean up in _so3.py and so3.py * remove unnecessary data storage in dgl graphs * update pymatgen version to fix the bug * refactor all include_states into include_state for consistency * change include_states into include_state in test_graph_conv.py * Ensure the state attr from the molecular graph is consistent with matgl.float_th and include the linear layer in TensorNet to match the original implementations --- src/matgl/graph/data.py | 2 +- src/matgl/models/_tensornet.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/matgl/graph/data.py b/src/matgl/graph/data.py index 7ffb2dfe..93836e68 100644 --- a/src/matgl/graph/data.py +++ b/src/matgl/graph/data.py @@ -189,7 +189,7 @@ def process(self): if self.graph_labels is not None: state_attrs = torch.tensor(self.graph_labels).long() else: - state_attrs = torch.tensor(np.array(state_attrs)) + state_attrs = torch.tensor(np.array(state_attrs), dtype=matgl.float_th) if self.clear_processed: del self.structures diff --git a/src/matgl/models/_tensornet.py b/src/matgl/models/_tensornet.py index 9050dbbb..66d8131a 100644 --- a/src/matgl/models/_tensornet.py +++ b/src/matgl/models/_tensornet.py @@ -180,8 +180,9 @@ def __init__( ) self.out_norm = nn.LayerNorm(3 * units, dtype=dtype) + self.linear = nn.Linear(3 * units, units, dtype=dtype) if 
is_intensive: - input_feats = 3 * units if field == "node_feat" else units + input_feats = units if readout_type == "set2set": self.readout = Set2SetReadOut( in_feats=input_feats, n_iters=niters_set2set, n_layers=nlayers_set2set, field=field @@ -203,7 +204,7 @@ def __init__( if task_type == "classification": raise ValueError("Classification task cannot be extensive.") self.final_layer = WeightedReadOut( - in_feats=3 * units, + in_feats=units, dims=[units, units], num_targets=ntargets, # type: ignore ) @@ -247,6 +248,7 @@ def forward(self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, **kwa x = torch.cat((tensor_norm(scalars), tensor_norm(skew_metrices), tensor_norm(traceless_tensors)), dim=-1) x = self.out_norm(x) + x = self.linear(x) g.ndata["node_feat"] = x if self.is_intensive: