diff --git a/src/fairchem/core/models/escn/escn_exportable.py b/src/fairchem/core/models/escn/escn_exportable.py index 9e97b4849..031126d40 100644 --- a/src/fairchem/core/models/escn/escn_exportable.py +++ b/src/fairchem/core/models/escn/escn_exportable.py @@ -79,6 +79,7 @@ def __init__( resolution: int | None = None, compile: bool = False, export: bool = False, + rescale_grid: bool = False, ) -> None: super().__init__() @@ -103,6 +104,7 @@ def __init__( self.distance_function = distance_function self.compile = compile self.export = export + self.rescale_grid = rescale_grid # non-linear activation function used throughout the network self.act = nn.SiLU() @@ -152,12 +154,11 @@ def __init__( # Initialize the transformations between spherical and grid representations self.SO3_grid = nn.ModuleDict() self.SO3_grid["lmax_lmax"] = SO3_Grid( - self.lmax, self.lmax, resolution=resolution + self.lmax, self.lmax, resolution=resolution, rescale=self.rescale_grid ) self.SO3_grid["lmax_mmax"] = SO3_Grid( - self.lmax, self.mmax, resolution=resolution + self.lmax, self.mmax, resolution=resolution, rescale=self.rescale_grid ) - self.mappingReduced = CoefficientMapping([self.lmax], [self.mmax]) # Initialize the blocks for each layer of the GNN self.layer_blocks = nn.ModuleList() @@ -173,7 +174,6 @@ def __init__( self.max_num_elements, self.SO3_grid, self.act, - self.mappingReduced, ) self.layer_blocks.append(block) @@ -435,7 +435,6 @@ def __init__( max_num_elements: int, SO3_grid: SO3_Grid, act, - mappingReduced, ) -> None: super().__init__() self.layer_idx = layer_idx @@ -444,7 +443,6 @@ def __init__( self.mmax = mmax self.sphere_channels = sphere_channels self.SO3_grid = SO3_grid - self.mappingReduced = mappingReduced # Message block self.message_block = MessageBlock( @@ -458,7 +456,6 @@ def __init__( max_num_elements, self.SO3_grid, self.act, - self.mappingReduced, ) # Non-linear point-wise comvolution for the aggregated messages @@ -547,7 +544,6 @@ def __init__( max_num_elements: int, SO3_grid: SO3_Grid, act, - mappingReduced, ) -> None: super().__init__() self.layer_idx = layer_idx @@ -558,8 +554,9 @@ def __init__( self.lmax = lmax self.mmax = mmax self.edge_channels = edge_channels - self.mappingReduced = mappingReduced - self.out_mask = self.mappingReduced.coefficient_idx(self.lmax, self.mmax) + self.out_mask = CoefficientMapping([self.lmax], [self.lmax]).coefficient_idx( + self.lmax, self.mmax + ) # Create edge scalar (invariant to rotations) features self.edge_block = EdgeBlock( @@ -577,7 +574,6 @@ def __init__( self.lmax, self.mmax, self.act, - self.mappingReduced, ) self.so2_block_target = SO2Block( self.sphere_channels, @@ -586,7 +582,6 @@ def __init__( self.lmax, self.mmax, self.act, - self.mappingReduced, ) def forward( @@ -666,7 +661,6 @@ def __init__( lmax: int, mmax: int, act, - mappingReduced, ) -> None: super().__init__() self.sphere_channels = sphere_channels @@ -674,7 +668,7 @@ def __init__( self.lmax = lmax self.mmax = mmax self.act = act - self.mappingReduced = mappingReduced + self.mappingReduced = CoefficientMapping([self.lmax], [self.mmax]) num_channels_m0 = (self.lmax + 1) * self.sphere_channels diff --git a/src/fairchem/core/models/escn/so3.py b/src/fairchem/core/models/escn/so3.py index 36e9d96cf..fcbc5c8f5 100644 --- a/src/fairchem/core/models/escn/so3.py +++ b/src/fairchem/core/models/escn/so3.py @@ -280,7 +280,6 @@ def _grid_act(self, SO3_grid, act, mappingReduced) -> None: from_grid_mat = SO3_grid[self.lmax_list[i]][ self.mmax_list[i] ].get_from_grid_mat(self.device) - x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x_res) x_grid = act(x_grid) x_res = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) diff --git a/src/fairchem/core/models/escn/so3_exportable.py b/src/fairchem/core/models/escn/so3_exportable.py index 537813476..115d3c198 100644 --- a/src/fairchem/core/models/escn/so3_exportable.py +++ b/src/fairchem/core/models/escn/so3_exportable.py @@ -250,6 +250,7 @@ def __init__( mmax: int, normalization: str = "integral", resolution: int | None = None, + rescale: bool = False, ): super().__init__() @@ -276,7 +277,7 @@ def __init__( ) to_grid_mat = torch.einsum("mbi, am -> bai", to_grid.shb, to_grid.sha).detach() # rescale based on mmax - if lmax != mmax: + if rescale and lmax != mmax: for lval in range(lmax + 1): if lval <= mmax: continue @@ -300,7 +301,7 @@ def __init__( "am, mbi -> bai", from_grid.sha, from_grid.shb ).detach() # rescale based on mmax - if lmax != mmax: + if rescale and lmax != mmax: for lval in range(lmax + 1): if lval <= mmax: continue diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index aaa8c9b36..2e9f7b4b9 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -70,8 +70,8 @@ def load_model(type: str, compile=False, export=False): cutoff=CUTOFF, max_num_elements=MAX_ELEMENTS, num_layers=8, - lmax_list=[4], - mmax_list=[2], + lmax_list=[6], + mmax_list=[0], sphere_channels=128, hidden_channels=256, edge_channels=128, @@ -87,8 +87,8 @@ def load_model(type: str, compile=False, export=False): cutoff=CUTOFF, max_num_elements=MAX_ELEMENTS, num_layers=8, - lmax=4, - mmax=2, + lmax=6, + mmax=0, sphere_channels=128, hidden_channels=256, edge_channels=128,