Skip to content

Commit

Permalink
simplified fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
kyonofx committed Jul 19, 2024
1 parent 3518827 commit 3008f9d
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 17 deletions.
10 changes: 10 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
root = true

[*.py]
charset = utf-8
end_of_line = lf
indent_size = 4
indent_style = space
insert_final_newline = false
max_line_length = 120
tab_width = 4
11 changes: 7 additions & 4 deletions src/fairchem/core/models/equiformer_v2/edge_rot_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def init_edge_rot_mat(edge_distance_vec):
norm_y = torch.cross(norm_x, norm_z, dim=1)
norm_y = norm_y / (torch.sqrt(torch.sum(norm_y**2, dim=1, keepdim=True)))

y_prod = (norm_x @ norm_x.new_tensor([0,1,0])).abs()
y_aligned = (y_prod > 0.99999)
y_aligned = (norm_x @ norm_x.new_tensor([0,1,0]))
# y_aligned = (y_prod > 0.9999)

# Construct the 3D rotation matrix
norm_x = norm_x.view(-1, 3, 1)
Expand All @@ -54,5 +54,8 @@ def init_edge_rot_mat(edge_distance_vec):

edge_rot_mat_inv = torch.cat([norm_z, norm_x, norm_y], dim=2)
edge_rot_mat = torch.transpose(edge_rot_mat_inv, 1, 2)

return edge_rot_mat, y_aligned

# for y-aligned we can apply a random rotation around the z-axis
edge_rot_mat[y_aligned > 0.9999, 1, :] = edge_rot_mat.new_tensor([[0., 1., 0.]])
edge_rot_mat[y_aligned < -0.9999, 1, :] = edge_rot_mat.new_tensor([[0., -1., 0.]])
return edge_rot_mat
4 changes: 2 additions & 2 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,11 +485,11 @@ def forward(self, data):
###############################################################

# Compute 3x3 rotation matrix per edge
edge_rot_mat, y_aligned = self._init_edge_rot_mat(data, edge_index, edge_distance_vec)
edge_rot_mat = self._init_edge_rot_mat(data, edge_index, edge_distance_vec)

# Initialize the WignerD matrices and other values for spherical harmonic calculations
for i in range(self.num_resolutions):
self.SO3_rotation[i].set_wigner(edge_rot_mat, y_aligned)
self.SO3_rotation[i].set_wigner(edge_rot_mat)

###############################################################
# Initialize node embeddings
Expand Down
43 changes: 32 additions & 11 deletions src/fairchem/core/models/equiformer_v2/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,9 +449,9 @@ def __init__(
self.lmax = lmax
self.mapping = CoefficientMappingModule([self.lmax], [self.lmax])

def set_wigner(self, rot_mat3x3, y_aligned):
def set_wigner(self, rot_mat3x3):
self.device, self.dtype = rot_mat3x3.device, rot_mat3x3.dtype
self.wigner = self.RotationToWignerDMatrix(rot_mat3x3, y_aligned, 0, self.lmax)
self.wigner = self.RotationToWignerDMatrix(rot_mat3x3, 0, self.lmax)
self.wigner_inv = torch.transpose(self.wigner, 1, 2).contiguous()

# Rotate the embedding
Expand All @@ -470,7 +470,7 @@ def rotate_inv(self, embedding, in_lmax: int, in_mmax: int):

# Compute Wigner matrices from rotation matrix
def RotationToWignerDMatrix(
self, edge_rot_mat, y_aligned, start_lmax: int, end_lmax: int
self, edge_rot_mat, start_lmax: int, end_lmax: int
) -> torch.Tensor:
x = edge_rot_mat @ edge_rot_mat.new_tensor([0.0, 1.0, 0.0])
alpha, beta = o3.xyz_to_angles(x)
Expand All @@ -481,21 +481,42 @@ def RotationToWignerDMatrix(
gamma = torch.atan2(R[..., 0, 2], R[..., 0, 0])

# only apply random z-rotation for y-aligned vectors.
alpha_ya = torch.zeros_like(alpha)
beta_ya = torch.zeros_like(beta)
gamma_ya = torch.rand_like(gamma) * 2 * math.pi
# alpha_ya = torch.zeros_like(alpha)
# beta_ya = torch.zeros_like(beta)
# beta_ya[y_aligned > 0.9999] = 0.0
# beta_ya[y_aligned < -0.9999] = math.pi
# y_aligned = torch.abs(y_aligned) > 0.9999
# gamma_ya = torch.zeros_like(gamma)
# gamma_ya = torch.rand_like(gamma) * 2 * math.pi - math.pi
# print(gamma[y_aligned])
# print(gamma_ya[y_aligned])
# alpha_ya = torch.clone(alpha).detach()
# beta_ya = torch.clone(beta).detach()
# print(beta_ya[y_aligned])
# mask = torch.abs(beta_ya) < 1.
# print(mask[y_aligned])
# beta_ya[mask] = 0.
# gamma_ya = torch.clone(gamma)
# gamma_ya = torch.clone(gamma)
# print(beta_ya[y_aligned])


size = (end_lmax + 1) ** 2 - (start_lmax) ** 2
wigner = torch.zeros(len(alpha), size, size, device=self.device)
start = 0
for lmax in range(start_lmax, end_lmax + 1):
block = wigner_D(
lmax, alpha, beta, gamma)[~y_aligned].to(wigner.dtype)
block_ya = wigner_D(
lmax, alpha_ya, beta_ya, gamma_ya)[y_aligned].to(wigner.dtype)
lmax, alpha, beta, gamma).to(wigner.dtype)
# lmax, alpha, beta, gamma)[~y_aligned].to(wigner.dtype)
# block_ya = wigner_D(
# lmax, alpha_ya, beta_ya, gamma_ya)[y_aligned].to(wigner.dtype)
# block_ya = wigner_D(
# lmax, alpha, beta, gamma)[y_aligned].to(wigner.dtype)

end = start + block.size()[1]
wigner[~y_aligned, start:end, start:end] = block
wigner[y_aligned, start:end, start:end] = block_ya
# wigner[~y_aligned, start:end, start:end] = block
# wigner[y_aligned, start:end, start:end] = block_ya
wigner[:, start:end, start:end] = block
start = end

return wigner
Expand Down

0 comments on commit 3008f9d

Please sign in to comment.