Skip to content

Commit

Permalink
Revert "Add linear interpolation to heat model evaluation"
Browse files Browse the repository at this point in the history
  • Loading branch information
castellinilinguini authored Jul 16, 2023
1 parent 68c4466 commit 6a7966b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 34 deletions.
4 changes: 0 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ All notable changes to this project will be documented in this file.
## [Unreleased]


### Changed

- Added linear interpolation in heat model evaluation

## [0.4.0]

### Added
Expand Down
46 changes: 16 additions & 30 deletions eulerpi/examples/heat/heat.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ class Heat(JaxModel):

t_end = 0.1
plate_length = jnp.array([1.0, 1.0])
num_grid_points = 20

param_dim = 3
data_dim = 5 # The values of the heat equation solution at five points are observed. See evaluation_points.
Expand Down Expand Up @@ -105,31 +104,17 @@ def forward(cls, param: np.ndarray) -> np.ndarray:

solution = cls.perform_simulation(kappa=param)

# linearly interpolate the solution at the evaluation points, use own interpolation function as jax doesn't support scipy.interpolate.interp2d and doesn't provide a 2d interpolation function
x = jnp.linspace(0, cls.plate_length[0], cls.num_grid_points)
y = jnp.linspace(0, cls.plate_length[1], cls.num_grid_points)
X, Y = jnp.meshgrid(x, y)

# compute the indices between which the evaluation points lie
x_indices = jnp.searchsorted(x, cls.evaluation_points[:, 0]) - 1
y_indices = jnp.searchsorted(y, cls.evaluation_points[:, 1]) - 1
dx = cls.plate_length[0] / cls.num_grid_points
dy = cls.plate_length[1] / cls.num_grid_points

# interpolate the solution at the evaluation points
solution_at_evaluation_points = (1 / (dx * dy)) * (
solution[x_indices, y_indices]
* (x[x_indices + 1] - cls.evaluation_points[:, 0])
* (y[y_indices + 1] - cls.evaluation_points[:, 1])
+ solution[x_indices + 1, y_indices]
* (cls.evaluation_points[:, 0] - x[x_indices])
* (y[y_indices + 1] - cls.evaluation_points[:, 1])
+ solution[x_indices, y_indices + 1]
* (x[x_indices + 1] - cls.evaluation_points[:, 0])
* (cls.evaluation_points[:, 1] - y[y_indices])
+ solution[x_indices + 1, y_indices + 1]
* (cls.evaluation_points[:, 0] - x[x_indices])
* (cls.evaluation_points[:, 1] - y[y_indices])
# the indices of the wanted evaluation points
evaluation_indices = jnp.multiply(
cls.evaluation_points,
jnp.array(solution.shape),
).astype(int)

solution_at_evaluation_points = jnp.array(
solution[
evaluation_indices[:, 0],
evaluation_indices[:, 1],
]
)
return solution_at_evaluation_points

Expand Down Expand Up @@ -160,10 +145,11 @@ def perform_simulation(cls, kappa: np.ndarray) -> np.ndarray:
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

# The grid
x = jnp.linspace(0, cls.plate_length[0], cls.num_grid_points)
y = jnp.linspace(0, cls.plate_length[1], cls.num_grid_points)
dx = cls.plate_length[0] / cls.num_grid_points
dy = cls.plate_length[1] / cls.num_grid_points
num_grid_points = 20
x = jnp.linspace(0, cls.plate_length[0], num_grid_points)
y = jnp.linspace(0, cls.plate_length[1], num_grid_points)
dx = cls.plate_length[0] / num_grid_points
dy = cls.plate_length[1] / num_grid_points

def stable_time_step(dx, dy, kappa):
trace_kappa = kappa[0] + kappa[1]
Expand Down

0 comments on commit 6a7966b

Please sign in to comment.