diff --git a/eulerpi/examples/heat/heat.py b/eulerpi/examples/heat/heat.py index 7fab8ca5..1be0b201 100644 --- a/eulerpi/examples/heat/heat.py +++ b/eulerpi/examples/heat/heat.py @@ -24,16 +24,15 @@ def heat_rhs(t: float, u: jnp.ndarray, args: tuple | list) -> jnp.ndarray: # use the central difference scheme to approximate the derivatives. Gradient preserves the size of the array by using one-sided differences. We throw away the boundary points later. du_dx = jnp.gradient(u, dx, axis=0) - du_dy = jnp.gradient(u, dy, axis=1) - du_dx2 = jnp.gradient(du_dx, dx, axis=0) - du_dy2 = jnp.gradient(du_dy, dy, axis=1) + du_dx2 = (u[2:, 1:-1] - 2 * u[1:-1, 1:-1] + u[:-2, 1:-1]) / dx**2 + du_dy2 = (u[1:-1, 2:] - 2 * u[1:-1, 1:-1] + u[1:-1, :-2]) / dy**2 du_dx_dy = jnp.gradient(du_dx, dy, axis=1) # compute the right hand side of the heat equation rhs = jnp.zeros(u.shape) rhs = rhs.at[1:-1, 1:-1].set( - param[0] * du_dx2[1:-1, 1:-1] - + param[1] * du_dy2[1:-1, 1:-1] + param[0] * du_dx2 + + param[1] * du_dy2 + 2 * param[2] * du_dx_dy[1:-1, 1:-1] ) return rhs