From 79d61a6891397b576af4bfa59213dd2c6b1ed1ef Mon Sep 17 00:00:00 2001 From: Xiang Fu Date: Fri, 19 Jul 2024 23:57:18 +0000 Subject: [PATCH] simplified fix for correct and stable gradient-based force model training. --- .../core/models/equiformer_v2/edge_rot_mat.py | 12 +++++++++++- src/fairchem/core/models/equiformer_v2/so3.py | 3 +++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/fairchem/core/models/equiformer_v2/edge_rot_mat.py b/src/fairchem/core/models/equiformer_v2/edge_rot_mat.py index c83cc3143..f2c25359e 100644 --- a/src/fairchem/core/models/equiformer_v2/edge_rot_mat.py +++ b/src/fairchem/core/models/equiformer_v2/edge_rot_mat.py @@ -44,6 +44,8 @@ def init_edge_rot_mat(edge_distance_vec): norm_y = torch.cross(norm_x, norm_z, dim=1) norm_y = norm_y / (torch.sqrt(torch.sum(norm_y**2, dim=1, keepdim=True))) + yprod = (norm_x @ norm_x.new_tensor([0,1,0])) + # Construct the 3D rotation matrix norm_x = norm_x.view(-1, 3, 1) norm_y = -norm_y.view(-1, 3, 1) @@ -52,4 +54,12 @@ def init_edge_rot_mat(edge_distance_vec): edge_rot_mat_inv = torch.cat([norm_z, norm_x, norm_y], dim=2) edge_rot_mat = torch.transpose(edge_rot_mat_inv, 1, 2) - return edge_rot_mat.detach() + output = torch.zeros_like(edge_rot_mat) + mask = (yprod > -0.9999) & (yprod < 0.9999) + output[mask] = edge_rot_mat[mask] + output[~mask, 0, :] = edge_rot_mat[~mask, 0, :] + output[~mask, 2, :] = edge_rot_mat[~mask, 2, :] + output[yprod > 0.9999, 1, :] = edge_rot_mat.new_tensor([[0., 1., 0.]]) + output[yprod < -0.9999, 1, :] = edge_rot_mat.new_tensor([[0., -1., 0.]]) + + return output \ No newline at end of file diff --git a/src/fairchem/core/models/equiformer_v2/so3.py b/src/fairchem/core/models/equiformer_v2/so3.py index a3d58586e..ea04f4910 100644 --- a/src/fairchem/core/models/equiformer_v2/so3.py +++ b/src/fairchem/core/models/equiformer_v2/so3.py @@ -476,6 +476,9 @@ def RotationToWignerDMatrix( ) -> torch.Tensor: x = edge_rot_mat @ edge_rot_mat.new_tensor([0.0, 1.0, 0.0]) alpha, beta = o3.xyz_to_angles(x) + yprod = (x @ x.new_tensor([0, 1, 0])) + beta[yprod > 0.9999] = 0. + beta[yprod < -0.9999] = math.pi R = ( o3.angles_to_matrix(alpha, beta, torch.zeros_like(alpha)).transpose(-1, -2) @ edge_rot_mat