Skip to content

Commit

Permalink
add resolution flag to escn
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Aug 12, 2024
1 parent 917056a commit 7a71c46
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
3 changes: 2 additions & 1 deletion src/fairchem/core/models/escn/escn.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(
basis_width_scalar: float = 1.0,
distance_resolution: float = 0.02,
show_timing_info: bool = False,
resolution: int | None = None,
) -> None:
if mmax_list is None:
mmax_list = [2]
Expand Down Expand Up @@ -176,7 +177,7 @@ def __init__(
for lval in range(max(self.lmax_list) + 1):
SO3_m_grid = nn.ModuleList()
for m in range(max(self.lmax_list) + 1):
SO3_m_grid.append(SO3_Grid(lval, m))
SO3_m_grid.append(SO3_Grid(lval, m, resolution=resolution))

self.SO3_grid.append(SO3_m_grid)

Expand Down
9 changes: 4 additions & 5 deletions src/fairchem/core/models/escn/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,11 +452,7 @@ class SO3_Grid(torch.nn.Module):
mmax (int): Maximum order of the spherical harmonics
"""

def __init__(
self,
lmax: int,
mmax: int,
) -> None:
def __init__(self, lmax: int, mmax: int, resolution: int | None = None) -> None:
super().__init__()
self.lmax = lmax
self.mmax = mmax
Expand All @@ -465,6 +461,9 @@ def __init__(
self.long_resolution = 2 * (self.mmax + 1) + 1
else:
self.long_resolution = 2 * (self.mmax) + 1
if resolution:
self.long_resolution=resolution
self.lat_resolution=resolution

self.initialized = False

Expand Down

0 comments on commit 7a71c46

Please sign in to comment.