From 1b98f064dca53efff122e241b3c7072adb2e9aa6 Mon Sep 17 00:00:00 2001
From: Samuel Burbulla
Date: Thu, 15 Feb 2024 09:56:25 +0100
Subject: [PATCH] Use Grad in Div.

---
 src/continuity/pde/grad.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/src/continuity/pde/grad.py b/src/continuity/pde/grad.py
index 6b91793f..d1589ad7 100644
--- a/src/continuity/pde/grad.py
+++ b/src/continuity/pde/grad.py
@@ -84,15 +84,8 @@ def forward(self, x: Tensor, u: Tensor, y: Optional[Tensor] = None) -> Tensor:
 
         assert x.requires_grad, "x must require gradients for divergence operator"
 
-        # Compute gradients
-        gradients = torch.autograd.grad(
-            u,
-            x,
-            grad_outputs=torch.ones_like(u),
-            create_graph=True,
-            retain_graph=True,
-        )[0]
-
+        # Compute divergence
+        gradients = Grad()(x, u)
         return torch.sum(gradients, dim=-1, keepdim=True)
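
Note: for reference, below is a minimal standalone sketch of the computation this
patch factors out. The removed block is a vector-Jacobian product with ones, and
summing its components along the last dimension yields the divergence, exactly as
Div.forward still does after the change. The sketch assumes Grad() wraps this same
autograd call; the test field u(x) = x**2 (applied componentwise) is a hypothetical
example chosen so the result can be checked in closed form, not part of the library.

import torch

# Sample coordinates; gradients must flow back to x, as asserted in Div.forward.
x = torch.rand(10, 3, requires_grad=True)

# Hypothetical componentwise field u_i = x_i ** 2, so the Jacobian du_i/dx_j is
# diagonal and the divergence is sum_i 2 * x_i.
u = x**2

# The autograd block removed by this patch: a vector-Jacobian product with ones,
# i.e. gradients[..., j] = sum_i du_i/dx_j.
gradients = torch.autograd.grad(
    u,
    x,
    grad_outputs=torch.ones_like(u),
    create_graph=True,
    retain_graph=True,
)[0]

# Summing over the last dimension gives the divergence, as in Div.forward.
div = torch.sum(gradients, dim=-1, keepdim=True)
assert torch.allclose(div, 2 * x.sum(dim=-1, keepdim=True))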