diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 000000000..65fbde851 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,10 @@ +root = true + +[*.py] +charset = utf-8 +end_of_line = lf +indent_size = 4 +indent_style = space +insert_final_newline = false +max_line_length = 120 +tab_width = 4 \ No newline at end of file diff --git a/src/fairchem/core/models/equiformer_v2/edge_rot_mat.py b/src/fairchem/core/models/equiformer_v2/edge_rot_mat.py index 6787db96a..7d2c29603 100644 --- a/src/fairchem/core/models/equiformer_v2/edge_rot_mat.py +++ b/src/fairchem/core/models/equiformer_v2/edge_rot_mat.py @@ -44,8 +44,8 @@ def init_edge_rot_mat(edge_distance_vec): norm_y = torch.cross(norm_x, norm_z, dim=1) norm_y = norm_y / (torch.sqrt(torch.sum(norm_y**2, dim=1, keepdim=True))) - y_prod = (norm_x @ norm_x.new_tensor([0,1,0])).abs() - y_aligned = (y_prod > 0.99999) + y_aligned = (norm_x @ norm_x.new_tensor([0,1,0])) + # y_aligned = (y_prod > 0.9999) # Construct the 3D rotation matrix norm_x = norm_x.view(-1, 3, 1) @@ -54,5 +54,8 @@ def init_edge_rot_mat(edge_distance_vec): edge_rot_mat_inv = torch.cat([norm_z, norm_x, norm_y], dim=2) edge_rot_mat = torch.transpose(edge_rot_mat_inv, 1, 2) - - return edge_rot_mat, y_aligned + + # for y-aligned we can apply a random rotation around the z-axis + edge_rot_mat[y_aligned > 0.9999, 1, :] = edge_rot_mat.new_tensor([[0., 1., 0.]]) + edge_rot_mat[y_aligned < -0.9999, 1, :] = edge_rot_mat.new_tensor([[0., -1., 0.]]) + return edge_rot_mat diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py index c548b74a3..8edf81319 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py @@ -485,11 +485,11 @@ def forward(self, data): ############################################################### # Compute 3x3 rotation matrix per edge - edge_rot_mat, y_aligned = self._init_edge_rot_mat(data, edge_index, edge_distance_vec) + edge_rot_mat = self._init_edge_rot_mat(data, edge_index, edge_distance_vec) # Initialize the WignerD matrices and other values for spherical harmonic calculations for i in range(self.num_resolutions): - self.SO3_rotation[i].set_wigner(edge_rot_mat, y_aligned) + self.SO3_rotation[i].set_wigner(edge_rot_mat) ############################################################### # Initialize node embeddings diff --git a/src/fairchem/core/models/equiformer_v2/so3.py b/src/fairchem/core/models/equiformer_v2/so3.py index 56318495f..ebae94d7f 100644 --- a/src/fairchem/core/models/equiformer_v2/so3.py +++ b/src/fairchem/core/models/equiformer_v2/so3.py @@ -449,9 +449,9 @@ def __init__( self.lmax = lmax self.mapping = CoefficientMappingModule([self.lmax], [self.lmax]) - def set_wigner(self, rot_mat3x3, y_aligned): + def set_wigner(self, rot_mat3x3): self.device, self.dtype = rot_mat3x3.device, rot_mat3x3.dtype - self.wigner = self.RotationToWignerDMatrix(rot_mat3x3, y_aligned, 0, self.lmax) + self.wigner = self.RotationToWignerDMatrix(rot_mat3x3, 0, self.lmax) self.wigner_inv = torch.transpose(self.wigner, 1, 2).contiguous() # Rotate the embedding @@ -470,7 +470,7 @@ def rotate_inv(self, embedding, in_lmax: int, in_mmax: int): # Compute Wigner matrices from rotation matrix def RotationToWignerDMatrix( - self, edge_rot_mat, y_aligned, start_lmax: int, end_lmax: int + self, edge_rot_mat, start_lmax: int, end_lmax: int ) -> torch.Tensor: x = edge_rot_mat @ edge_rot_mat.new_tensor([0.0, 1.0, 0.0]) alpha, beta = o3.xyz_to_angles(x) @@ -481,21 +481,42 @@ def RotationToWignerDMatrix( gamma = torch.atan2(R[..., 0, 2], R[..., 0, 0]) # only apply random z-rotation for y-aligned vectors. - alpha_ya = torch.zeros_like(alpha) - beta_ya = torch.zeros_like(beta) - gamma_ya = torch.rand_like(gamma) * 2 * math.pi + # alpha_ya = torch.zeros_like(alpha) + # beta_ya = torch.zeros_like(beta) + # beta_ya[y_aligned > 0.9999] = 0.0 + # beta_ya[y_aligned < -0.9999] = math.pi + # y_aligned = torch.abs(y_aligned) > 0.9999 + # gamma_ya = torch.zeros_like(gamma) + # gamma_ya = torch.rand_like(gamma) * 2 * math.pi - math.pi + # print(gamma[y_aligned]) + # print(gamma_ya[y_aligned]) + # alpha_ya = torch.clone(alpha).detach() + # beta_ya = torch.clone(beta).detach() + # print(beta_ya[y_aligned]) + # mask = torch.abs(beta_ya) < 1. + # print(mask[y_aligned]) + # beta_ya[mask] = 0. + # gamma_ya = torch.clone(gamma) + # gamma_ya = torch.clone(gamma) + # print(beta_ya[y_aligned]) + size = (end_lmax + 1) ** 2 - (start_lmax) ** 2 wigner = torch.zeros(len(alpha), size, size, device=self.device) start = 0 for lmax in range(start_lmax, end_lmax + 1): block = wigner_D( - lmax, alpha, beta, gamma)[~y_aligned].to(wigner.dtype) - block_ya = wigner_D( - lmax, alpha_ya, beta_ya, gamma_ya)[y_aligned].to(wigner.dtype) + lmax, alpha, beta, gamma).to(wigner.dtype) + # lmax, alpha, beta, gamma)[~y_aligned].to(wigner.dtype) + # block_ya = wigner_D( + # lmax, alpha_ya, beta_ya, gamma_ya)[y_aligned].to(wigner.dtype) + # block_ya = wigner_D( + # lmax, alpha, beta, gamma)[y_aligned].to(wigner.dtype) + end = start + block.size()[1] - wigner[~y_aligned, start:end, start:end] = block - wigner[y_aligned, start:end, start:end] = block_ya + # wigner[~y_aligned, start:end, start:end] = block + # wigner[y_aligned, start:end, start:end] = block_ya + wigner[:, start:end, start:end] = block start = end return wigner