Skip to content

Commit

Permalink
Remove defunct graph_transformations arg.
Browse files Browse the repository at this point in the history
  • Loading branch information
Shyue Ping Ong committed Jun 21, 2023
1 parent f92699c commit 701017a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 8 deletions.
2 changes: 1 addition & 1 deletion matgl/models/_m3gnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
activation_type: str = "swish",
**kwargs,
):
r"""
"""
Args:
element_types (tuple): list of elements appearing in the dataset
dim_node_embedding (int): number of embedded atomic features
Expand Down
9 changes: 2 additions & 7 deletions matgl/models/_megnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def __init__(
is_classification: bool = False,
include_state: bool = True,
dropout: float | None = None,
graph_transformations: list | None = None,
element_types: tuple[str, ...] = DEFAULT_ELEMENT_TYPES,
bond_expansion: BondExpansion | None = None,
cutoff: float = 4.0,
Expand Down Expand Up @@ -73,8 +72,6 @@ def __init__(
include_state: Whether the state embedding is included
dropout: Randomly zeroes some elements in the input tensor with given probability (0 < x < 1) according to
a Bernoulli distribution
graph_transformations: Perform a graph transformation, e.g., incorporate three-body interactions, prior to
performing the GCL updates.
element_types: Elements included in the training set
bond_expansion: Gaussian expansion for edge attributes
cutoff: cutoff for forming bonds
Expand Down Expand Up @@ -151,7 +148,6 @@ def __init__(
self.dropout = nn.Dropout(dropout) if dropout else None

self.is_classification = is_classification
self.graph_transformations = graph_transformations or [nn.Identity()] * nblocks
self.include_state_embedding = include_state

def forward(
Expand All @@ -172,14 +168,13 @@ def forward(
Returns:
Prediction
"""
graph_transformations = self.graph_transformations
node_feat, edge_feat, state_feat = self.embedding(node_feat, edge_feat, state_feat)
edge_feat = self.edge_encoder(edge_feat)
node_feat = self.node_encoder(node_feat)
state_feat = self.state_encoder(state_feat)

for gt, block in zip(graph_transformations, self.blocks):
output = block(gt(graph), edge_feat, node_feat, state_feat)
for block in self.blocks:
output = block(graph, edge_feat, node_feat, state_feat)
edge_feat, node_feat, state_feat = output

node_vec = self.node_s2s(graph, node_feat)
Expand Down

0 comments on commit 701017a

Please sign in to comment.