diff --git a/src/continuity/pde/grad.py b/src/continuity/pde/grad.py index 6b91793f..d1589ad7 100644 --- a/src/continuity/pde/grad.py +++ b/src/continuity/pde/grad.py @@ -84,15 +84,8 @@ def forward(self, x: Tensor, u: Tensor, y: Optional[Tensor] = None) -> Tensor: assert x.requires_grad, "x must require gradients for divergence operator" - # Compute gradients - gradients = torch.autograd.grad( - u, - x, - grad_outputs=torch.ones_like(u), - create_graph=True, - retain_graph=True, - )[0] - + # Compute divergence + gradients = Grad()(x, u) return torch.sum(gradients, dim=-1, keepdim=True)