Skip to content

Commit

Permalink
fix l/m; make rescaling optional
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Sep 19, 2024
1 parent 8226618 commit dc9f40d
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 21 deletions.
22 changes: 8 additions & 14 deletions src/fairchem/core/models/escn/escn_exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
resolution: int | None = None,
compile: bool = False,
export: bool = False,
rescale_grid: bool = False,
) -> None:
super().__init__()

Expand All @@ -103,6 +104,7 @@ def __init__(
self.distance_function = distance_function
self.compile = compile
self.export = export
self.rescale_grid = rescale_grid

# non-linear activation function used throughout the network
self.act = nn.SiLU()
Expand Down Expand Up @@ -152,12 +154,11 @@ def __init__(
# Initialize the transformations between spherical and grid representations
self.SO3_grid = nn.ModuleDict()
self.SO3_grid["lmax_lmax"] = SO3_Grid(
self.lmax, self.lmax, resolution=resolution
self.lmax, self.lmax, resolution=resolution, rescale=self.rescale_grid
)
self.SO3_grid["lmax_mmax"] = SO3_Grid(
self.lmax, self.mmax, resolution=resolution
self.lmax, self.mmax, resolution=resolution, rescale=self.rescale_grid
)
self.mappingReduced = CoefficientMapping([self.lmax], [self.mmax])

# Initialize the blocks for each layer of the GNN
self.layer_blocks = nn.ModuleList()
Expand All @@ -173,7 +174,6 @@ def __init__(
self.max_num_elements,
self.SO3_grid,
self.act,
self.mappingReduced,
)
self.layer_blocks.append(block)

Expand Down Expand Up @@ -435,7 +435,6 @@ def __init__(
max_num_elements: int,
SO3_grid: SO3_Grid,
act,
mappingReduced,
) -> None:
super().__init__()
self.layer_idx = layer_idx
Expand All @@ -444,7 +443,6 @@ def __init__(
self.mmax = mmax
self.sphere_channels = sphere_channels
self.SO3_grid = SO3_grid
self.mappingReduced = mappingReduced

# Message block
self.message_block = MessageBlock(
Expand All @@ -458,7 +456,6 @@ def __init__(
max_num_elements,
self.SO3_grid,
self.act,
self.mappingReduced,
)

# Non-linear point-wise comvolution for the aggregated messages
Expand Down Expand Up @@ -547,7 +544,6 @@ def __init__(
max_num_elements: int,
SO3_grid: SO3_Grid,
act,
mappingReduced,
) -> None:
super().__init__()
self.layer_idx = layer_idx
Expand All @@ -558,8 +554,9 @@ def __init__(
self.lmax = lmax
self.mmax = mmax
self.edge_channels = edge_channels
self.mappingReduced = mappingReduced
self.out_mask = self.mappingReduced.coefficient_idx(self.lmax, self.mmax)
self.out_mask = CoefficientMapping([self.lmax], [self.lmax]).coefficient_idx(
self.lmax, self.mmax
)

# Create edge scalar (invariant to rotations) features
self.edge_block = EdgeBlock(
Expand All @@ -577,7 +574,6 @@ def __init__(
self.lmax,
self.mmax,
self.act,
self.mappingReduced,
)
self.so2_block_target = SO2Block(
self.sphere_channels,
Expand All @@ -586,7 +582,6 @@ def __init__(
self.lmax,
self.mmax,
self.act,
self.mappingReduced,
)

def forward(
Expand Down Expand Up @@ -666,15 +661,14 @@ def __init__(
lmax: int,
mmax: int,
act,
mappingReduced,
) -> None:
super().__init__()
self.sphere_channels = sphere_channels
self.hidden_channels = hidden_channels
self.lmax = lmax
self.mmax = mmax
self.act = act
self.mappingReduced = mappingReduced
self.mappingReduced = CoefficientMapping([self.lmax], [self.mmax])

num_channels_m0 = (self.lmax + 1) * self.sphere_channels

Expand Down
1 change: 0 additions & 1 deletion src/fairchem/core/models/escn/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,6 @@ def _grid_act(self, SO3_grid, act, mappingReduced) -> None:
from_grid_mat = SO3_grid[self.lmax_list[i]][
self.mmax_list[i]
].get_from_grid_mat(self.device)

x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x_res)
x_grid = act(x_grid)
x_res = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid)
Expand Down
5 changes: 3 additions & 2 deletions src/fairchem/core/models/escn/so3_exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def __init__(
mmax: int,
normalization: str = "integral",
resolution: int | None = None,
rescale: bool = False,
):
super().__init__()

Expand All @@ -276,7 +277,7 @@ def __init__(
)
to_grid_mat = torch.einsum("mbi, am -> bai", to_grid.shb, to_grid.sha).detach()
# rescale based on mmax
if lmax != mmax:
if rescale and lmax != mmax:
for lval in range(lmax + 1):
if lval <= mmax:
continue
Expand All @@ -300,7 +301,7 @@ def __init__(
"am, mbi -> bai", from_grid.sha, from_grid.shb
).detach()
# rescale based on mmax
if lmax != mmax:
if rescale and lmax != mmax:
for lval in range(lmax + 1):
if lval <= mmax:
continue
Expand Down
8 changes: 4 additions & 4 deletions tests/core/models/test_escn_compiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def load_model(type: str, compile=False, export=False):
cutoff=CUTOFF,
max_num_elements=MAX_ELEMENTS,
num_layers=8,
lmax_list=[4],
mmax_list=[2],
lmax_list=[6],
mmax_list=[0],
sphere_channels=128,
hidden_channels=256,
edge_channels=128,
Expand All @@ -87,8 +87,8 @@ def load_model(type: str, compile=False, export=False):
cutoff=CUTOFF,
max_num_elements=MAX_ELEMENTS,
num_layers=8,
lmax=4,
mmax=2,
lmax=6,
mmax=0,
sphere_channels=128,
hidden_channels=256,
edge_channels=128,
Expand Down

0 comments on commit dc9f40d

Please sign in to comment.