Skip to content

Commit

Permalink
Fix second order derivates by manual implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
kaiserls authored and castellinilinguini committed Jul 20, 2023
1 parent b5eeab4 commit 1fbf1c2
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions eulerpi/examples/heat/heat.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,15 @@ def heat_rhs(t: float, u: jnp.ndarray, args: tuple | list) -> jnp.ndarray:

# use the central difference scheme to approximate the derivatives. Gradient preserves the size of the array by using one-sided differences. We throw away the boundary points later.
du_dx = jnp.gradient(u, dx, axis=0)
du_dy = jnp.gradient(u, dy, axis=1)
du_dx2 = jnp.gradient(du_dx, dx, axis=0)
du_dy2 = jnp.gradient(du_dy, dy, axis=1)
du_dx2 = (u[2:, 1:-1] - 2 * u[1:-1, 1:-1] + u[:-2, 1:-1]) / dx**2
du_dy2 = (u[1:-1, 2:] - 2 * u[1:-1, 1:-1] + u[1:-1, :-2]) / dy**2
du_dx_dy = jnp.gradient(du_dx, dy, axis=1)

# compute the right hand side of the heat equation
rhs = jnp.zeros(u.shape)
rhs = rhs.at[1:-1, 1:-1].set(
param[0] * du_dx2[1:-1, 1:-1]
+ param[1] * du_dy2[1:-1, 1:-1]
param[0] * du_dx2
+ param[1] * du_dy2
+ 2 * param[2] * du_dx_dy[1:-1, 1:-1]
)
return rhs
Expand Down

0 comments on commit 1fbf1c2

Please sign in to comment.