From 6a7966bbf230349916e47d69761befc2901b9f6d Mon Sep 17 00:00:00 2001 From: Benjamin C <63878559+castellinilinguini@users.noreply.github.com> Date: Sun, 16 Jul 2023 15:10:57 +0200 Subject: [PATCH] Revert "Add linear interpolation to heat model evaluation" --- CHANGELOG.md | 4 --- eulerpi/examples/heat/heat.py | 46 ++++++++++++----------------------- 2 files changed, 16 insertions(+), 34 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1dffd84f..ff0614f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,10 +20,6 @@ All notable changes to this project will be documented in this file. ## [Unreleased] -### Changed - -- Added linear interpolation in heat model evaluation - ## [0.4.0] ### Added diff --git a/eulerpi/examples/heat/heat.py b/eulerpi/examples/heat/heat.py index dd662f13..b2b496e3 100644 --- a/eulerpi/examples/heat/heat.py +++ b/eulerpi/examples/heat/heat.py @@ -56,7 +56,6 @@ class Heat(JaxModel): t_end = 0.1 plate_length = jnp.array([1.0, 1.0]) - num_grid_points = 20 param_dim = 3 data_dim = 5 # The values of the heat equation solution at five points are observed. See evaluation_points. @@ -105,31 +104,17 @@ def forward(cls, param: np.ndarray) -> np.ndarray: solution = cls.perform_simulation(kappa=param) - # linearly interpolate the solution at the evaluation points, use own interpolation function as jax doesn't support scipy.interpolate.interp2d and doesn't provide a 2d interpolation function - x = jnp.linspace(0, cls.plate_length[0], cls.num_grid_points) - y = jnp.linspace(0, cls.plate_length[1], cls.num_grid_points) - X, Y = jnp.meshgrid(x, y) - - # compute the indices between which the evaluation points lie - x_indices = jnp.searchsorted(x, cls.evaluation_points[:, 0]) - 1 - y_indices = jnp.searchsorted(y, cls.evaluation_points[:, 1]) - 1 - dx = cls.plate_length[0] / cls.num_grid_points - dy = cls.plate_length[1] / cls.num_grid_points - - # interpolate the solution at the evaluation points - solution_at_evaluation_points = (1 / (dx * dy)) * ( - solution[x_indices, y_indices] - * (x[x_indices + 1] - cls.evaluation_points[:, 0]) - * (y[y_indices + 1] - cls.evaluation_points[:, 1]) - + solution[x_indices + 1, y_indices] - * (cls.evaluation_points[:, 0] - x[x_indices]) - * (y[y_indices + 1] - cls.evaluation_points[:, 1]) - + solution[x_indices, y_indices + 1] - * (x[x_indices + 1] - cls.evaluation_points[:, 0]) - * (cls.evaluation_points[:, 1] - y[y_indices]) - + solution[x_indices + 1, y_indices + 1] - * (cls.evaluation_points[:, 0] - x[x_indices]) - * (cls.evaluation_points[:, 1] - y[y_indices]) + # the indices of the wanted evaluation points + evaluation_indices = jnp.multiply( + cls.evaluation_points, + jnp.array(solution.shape), + ).astype(int) + + solution_at_evaluation_points = jnp.array( + solution[ + evaluation_indices[:, 0], + evaluation_indices[:, 1], + ] ) return solution_at_evaluation_points @@ -160,10 +145,11 @@ def perform_simulation(cls, kappa: np.ndarray) -> np.ndarray: # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false" # The grid - x = jnp.linspace(0, cls.plate_length[0], cls.num_grid_points) - y = jnp.linspace(0, cls.plate_length[1], cls.num_grid_points) - dx = cls.plate_length[0] / cls.num_grid_points - dy = cls.plate_length[1] / cls.num_grid_points + num_grid_points = 20 + x = jnp.linspace(0, cls.plate_length[0], num_grid_points) + y = jnp.linspace(0, cls.plate_length[1], num_grid_points) + dx = cls.plate_length[0] / num_grid_points + dy = cls.plate_length[1] / num_grid_points def stable_time_step(dx, dy, kappa): trace_kappa = kappa[0] + kappa[1]