Skip to content

Commit

Permalink
Addressing review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
NewtonSander committed Aug 8, 2023
1 parent 987c69e commit c806dd1
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 22 deletions.
3 changes: 1 addition & 2 deletions docs/examples/plot_phased_array_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,7 @@
tilt_angle=30,
height=5.0e-3,
)
scenario_3d.sources.clear()
scenario_3d.sources.append(phased_3d)
scenario_3d.sources = [phased_3d]
scenario_3d.make_grid()
scenario_3d.compile_problem()
results = scenario_3d.simulate_steady_state()
Expand Down
6 changes: 3 additions & 3 deletions src/neurotechdevkit/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def make_shaped_grid(
@staticmethod
def make_grid(
extent: Union[Tuple[float, float], Tuple[float, float, float]],
speed_water=float,
center_frequency=float,
ppw=int,
speed_water: float,
center_frequency: float,
ppw: int,
extra: Union[int, Iterable[int]] = 50,
absorbing: Union[int, Iterable[int]] = 40,
) -> "Grid":
Expand Down
2 changes: 1 addition & 1 deletion src/neurotechdevkit/results/_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,7 @@ def create_pulsed_result(
wavefield: npt.NDArray[np.float_],
traces: stride.Traces,
recorded_slice: tuple[int, float] | None = None,
) -> PulsedResult:
) -> Union[PulsedResult2D, PulsedResult3D]:
"""Create results from pulsed simulations.
Creates a PulsedResult2D or PulsedResult3D depending on the number of wavefield
Expand Down
16 changes: 10 additions & 6 deletions src/neurotechdevkit/scenarios/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
from enum import IntEnum
from types import SimpleNamespace
from typing import Mapping, Optional
from typing import Mapping, Optional, Union

import nest_asyncio
import numpy as np
Expand Down Expand Up @@ -475,7 +475,7 @@ def _simulate_pulse(
n_jobs: int | None = None,
slice_axis: int | None = None,
slice_position: float | None = None,
) -> results.PulsedResult:
) -> Union[results.PulsedResult2D, results.PulsedResult3D]:
"""Execute a pulsed simulation.
In this simulation, the sources will emit a pulse containing a few cycles of
Expand Down Expand Up @@ -833,7 +833,7 @@ def simulate_pulse(
simulation_time: float | None = None,
recording_time_undersampling: int = 4,
n_jobs: int | None = None,
) -> results.PulsedResult:
) -> results.PulsedResult2D:
"""Execute a pulsed simulation in 2D.
In this simulation, the sources will emit a pulse containing a few cycles of
Expand All @@ -860,14 +860,16 @@ def simulate_pulse(
Returns:
An object containing the result of the 2D pulsed simulation.
"""
return self._simulate_pulse(
result = self._simulate_pulse(
points_per_period=points_per_period,
simulation_time=simulation_time,
recording_time_undersampling=recording_time_undersampling,
n_jobs=n_jobs,
slice_axis=None,
slice_position=None,
)
assert isinstance(result, results.PulsedResult2D)
return result

def render_layout(
self,
Expand Down Expand Up @@ -1015,7 +1017,7 @@ def simulate_pulse(
n_jobs: int | None = None,
slice_axis: int | None = None,
slice_position: float | None = None,
) -> results.PulsedResult:
) -> results.PulsedResult3D:
"""Execute a pulsed simulation in 3D.
In this simulation, the sources will emit a pulse containing a few cycles of
Expand Down Expand Up @@ -1048,14 +1050,16 @@ def simulate_pulse(
Returns:
An object containing the result of the 3D pulsed simulation.
"""
return self._simulate_pulse(
result = self._simulate_pulse(
points_per_period=points_per_period,
simulation_time=simulation_time,
recording_time_undersampling=recording_time_undersampling,
n_jobs=n_jobs,
slice_axis=slice_axis,
slice_position=slice_position,
)
assert isinstance(result, results.PulsedResult3D)
return result

def render_layout(
self,
Expand Down
4 changes: 2 additions & 2 deletions src/neurotechdevkit/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Source(abc.ABC):
appropriate source geometry.
Args:
position (list[float]): a numpy float array indicating the
position (npt.ArrayLike): a numpy float array indicating the
coordinates (in meters) of the point at the center of the source.
direction (list[float]): a numpy float array representing a vector
located at position and pointing towards the focal point. Only the
Expand All @@ -37,7 +37,7 @@ class Source(abc.ABC):
def __init__(
self,
*,
position: list[float],
position: npt.ArrayLike,
direction: list[float],
aperture: float,
focal_length: float,
Expand Down
8 changes: 3 additions & 5 deletions tests/neurotechdevkit/scenarios/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __init__(self):
self.return_value = "I'll be back"
self.last_args = None
self.last_kwargs = None
fake_wf = np.arange(12).reshape((3, 4, 1))
fake_wf = np.arange(12).reshape((3, 4, 1, 1))
self.wavefield = SimpleNamespace(data=fake_wf)

async def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -349,8 +349,7 @@ def test_simulate_steady_state_result_wavefield(base_tester, fake_pde, monkeypat
points_per_period=9, n_cycles_steady_state=4, recording_time_undersampling=7
)

# drop the final timestep, then swap axes
expected_wavefield = np.expand_dims(np.arange(8).reshape((2, 4)).T, 1)
expected_wavefield = np.expand_dims(np.arange(8).reshape((2, 4)).T, (1, 2))
np.testing.assert_array_equal(result.wavefield, expected_wavefield)


Expand Down Expand Up @@ -398,8 +397,7 @@ def test_simulate_pulse_result_wavefield(base_tester, fake_pde, monkeypatch):
points_per_period=9, simulation_time=3e-4, recording_time_undersampling=7
)

# drop the final timestep, then swap axes
expected_wavefield = np.expand_dims(np.arange(8).reshape((2, 4)).T, 1)
expected_wavefield = np.expand_dims(np.arange(8).reshape((2, 4)).T, (1, 2))
np.testing.assert_array_equal(result.wavefield, expected_wavefield)


Expand Down
6 changes: 3 additions & 3 deletions tests/neurotechdevkit/scenarios/test_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ def test_add_material_fields():
"brain": np.array([[1, 0], [0, 1]], dtype=bool),
},
)
assert np.array_equal(
np.testing.assert_array_equal(
problem.medium.vp.data,
np.array([[brain.vp, water.vp], [water.vp, brain.vp]], dtype=np.float32),
)
assert np.array_equal(
np.testing.assert_array_equal(
problem.medium.rho.data,
np.array([[brain.rho, water.rho], [water.rho, brain.rho]], dtype=np.float32),
)
assert np.array_equal(
np.testing.assert_array_equal(
problem.medium.alpha.data,
np.array(
[[brain.alpha, water.alpha], [water.alpha, brain.alpha]], dtype=np.float32
Expand Down

0 comments on commit c806dd1

Please sign in to comment.