Skip to content

Commit

Permalink
simplified fix for correct and stable gradient-based force model trai…
Browse files Browse the repository at this point in the history
…ning.
  • Loading branch information
kyonofx committed Jul 19, 2024
1 parent f9ecf73 commit 79d61a6
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/fairchem/core/models/equiformer_v2/edge_rot_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def init_edge_rot_mat(edge_distance_vec):
norm_y = torch.cross(norm_x, norm_z, dim=1)
norm_y = norm_y / (torch.sqrt(torch.sum(norm_y**2, dim=1, keepdim=True)))

yprod = (norm_x @ norm_x.new_tensor([0,1,0]))

# Construct the 3D rotation matrix
norm_x = norm_x.view(-1, 3, 1)
norm_y = -norm_y.view(-1, 3, 1)
Expand All @@ -52,4 +54,12 @@ def init_edge_rot_mat(edge_distance_vec):
edge_rot_mat_inv = torch.cat([norm_z, norm_x, norm_y], dim=2)
edge_rot_mat = torch.transpose(edge_rot_mat_inv, 1, 2)

return edge_rot_mat.detach()
output = torch.zeros_like(edge_rot_mat)
mask = (yprod > -0.9999) & (yprod < 0.9999)
output[mask] = edge_rot_mat[mask]
output[~mask, 0, :] = edge_rot_mat[~mask, 0, :]
output[~mask, 2, :] = edge_rot_mat[~mask, 2, :]
output[yprod > 0.9999, 1, :] = edge_rot_mat.new_tensor([[0., 1., 0.]])
output[yprod < -0.9999, 1, :] = edge_rot_mat.new_tensor([[0., -1., 0.]])

return output
3 changes: 3 additions & 0 deletions src/fairchem/core/models/equiformer_v2/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,9 @@ def RotationToWignerDMatrix(
) -> torch.Tensor:
x = edge_rot_mat @ edge_rot_mat.new_tensor([0.0, 1.0, 0.0])
alpha, beta = o3.xyz_to_angles(x)
yprod = (x @ x.new_tensor([0, 1, 0]))
beta[yprod > 0.9999] = 0.
beta[yprod < -0.9999] = math.pi
R = (
o3.angles_to_matrix(alpha, beta, torch.zeros_like(alpha)).transpose(-1, -2)
@ edge_rot_mat
Expand Down

0 comments on commit 79d61a6

Please sign in to comment.