Skip to content

Commit

Permalink
Ensure the state attr from molecular graph is consistent with matgl.float_th and include linear layer in TensorNet to match the original implementations (#244)
Browse files Browse the repository at this point in the history

* model version for Potential class is added

* model version for Potential class is modified

* Enable the smooth version of Spherical Bessel function in TensorNet

* max_n, max_l for SphericalBessel radial basis functions are included in TensorNet class

* add unit tests to improve the coverage score

* little clean up in _so3.py and so3.py

* remove unnecessary data storage in dgl graphs

* update pymatgen version to fix the bug

* refactor all include_states into include_state for consistency

* change include_states into include_state in test_graph_conv.py

* Ensure the state attr from the molecular graph is consistent with matgl.float_th and include the linear layer in TensorNet to match the original implementations
  • Loading branch information
kenko911 authored Mar 29, 2024
1 parent ee9e987 commit cb41e60
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/matgl/graph/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def process(self):
if self.graph_labels is not None:
state_attrs = torch.tensor(self.graph_labels).long()
else:
state_attrs = torch.tensor(np.array(state_attrs))
state_attrs = torch.tensor(np.array(state_attrs), dtype=matgl.float_th)

if self.clear_processed:
del self.structures
Expand Down
6 changes: 4 additions & 2 deletions src/matgl/models/_tensornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,9 @@ def __init__(
)

self.out_norm = nn.LayerNorm(3 * units, dtype=dtype)
self.linear = nn.Linear(3 * units, units, dtype=dtype)
if is_intensive:
input_feats = 3 * units if field == "node_feat" else units
input_feats = units
if readout_type == "set2set":
self.readout = Set2SetReadOut(
in_feats=input_feats, n_iters=niters_set2set, n_layers=nlayers_set2set, field=field
Expand All @@ -203,7 +204,7 @@ def __init__(
if task_type == "classification":
raise ValueError("Classification task cannot be extensive.")
self.final_layer = WeightedReadOut(
in_feats=3 * units,
in_feats=units,
dims=[units, units],
num_targets=ntargets, # type: ignore
)
Expand Down Expand Up @@ -247,6 +248,7 @@ def forward(self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, **kwa

x = torch.cat((tensor_norm(scalars), tensor_norm(skew_metrices), tensor_norm(traceless_tensors)), dim=-1)
x = self.out_norm(x)
x = self.linear(x)

g.ndata["node_feat"] = x
if self.is_intensive:
Expand Down

0 comments on commit cb41e60

Please sign in to comment.