diff --git a/src/fairchem/core/models/equiformer_v2/so3.py b/src/fairchem/core/models/equiformer_v2/so3.py index ea04f4910..fcd3c994f 100644 --- a/src/fairchem/core/models/equiformer_v2/so3.py +++ b/src/fairchem/core/models/equiformer_v2/so3.py @@ -453,8 +453,6 @@ def set_wigner(self, rot_mat3x3): self.device, self.dtype = rot_mat3x3.device, rot_mat3x3.dtype self.wigner = self.RotationToWignerDMatrix(rot_mat3x3, 0, self.lmax) self.wigner_inv = torch.transpose(self.wigner, 1, 2).contiguous() - self.wigner = self.wigner.detach() - self.wigner_inv = self.wigner_inv.detach() # Rotate the embedding def rotate(self, embedding, out_lmax: int, out_mmax: int): @@ -494,7 +492,7 @@ def RotationToWignerDMatrix( wigner[:, start:end, start:end] = block start = end - return wigner.detach() + return wigner class SO3_Grid(torch.nn.Module):