diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index 62a582b4c..c288f3f25 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -87,6 +87,7 @@ def __init__( basis_width_scalar: float = 1.0, distance_resolution: float = 0.02, show_timing_info: bool = False, + resolution: int | None = None, ) -> None: if mmax_list is None: mmax_list = [2] @@ -176,7 +177,7 @@ def __init__( for lval in range(max(self.lmax_list) + 1): SO3_m_grid = nn.ModuleList() for m in range(max(self.lmax_list) + 1): - SO3_m_grid.append(SO3_Grid(lval, m)) + SO3_m_grid.append(SO3_Grid(lval, m, resolution=resolution)) self.SO3_grid.append(SO3_m_grid) diff --git a/src/fairchem/core/models/escn/so3.py b/src/fairchem/core/models/escn/so3.py index 988797df2..34f505d51 100644 --- a/src/fairchem/core/models/escn/so3.py +++ b/src/fairchem/core/models/escn/so3.py @@ -452,11 +452,7 @@ class SO3_Grid(torch.nn.Module): mmax (int): Maximum order of the spherical harmonics """ - def __init__( - self, - lmax: int, - mmax: int, - ) -> None: + def __init__(self, lmax: int, mmax: int, resolution: int | None = None) -> None: super().__init__() self.lmax = lmax self.mmax = mmax @@ -465,6 +461,9 @@ def __init__( self.long_resolution = 2 * (self.mmax + 1) + 1 else: self.long_resolution = 2 * (self.mmax) + 1 + if resolution: + self.long_resolution=resolution + self.lat_resolution=resolution self.initialized = False