Skip to content

Commit

Permalink
Multiple smaller improvements
Browse files Browse the repository at this point in the history
* Replace Union by |
* Allow multiprocessing in determining emulsions from data
* Add function for refining multiple droplets
  • Loading branch information
david-zwicker committed Oct 16, 2023
1 parent dd9f8b8 commit e78f3ca
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 58 deletions.
11 changes: 0 additions & 11 deletions codecov.yml

This file was deleted.

53 changes: 40 additions & 13 deletions droplets/droplet_tracks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import json
import logging
from typing import Callable, List, Optional, Union
from typing import Callable, List, Literal, Optional

import numpy as np
from numpy.lib import recfunctions as rfn
Expand Down Expand Up @@ -118,7 +118,7 @@ def __len__(self):
"""number of time points"""
return len(self.times)

def __getitem__(self, key: Union[int, slice]):
def __getitem__(self, key: int | slice):
"""return the droplets identified by the given index/slice"""
result = self.droplets.__getitem__(key)
if isinstance(key, slice):
Expand Down Expand Up @@ -357,7 +357,12 @@ def to_file(self, path: str, info: Optional[InfoDict] = None) -> None:

@plot_on_axes()
def plot(
self, attribute: str = "radius", smoothing: float = 0, ax=None, **kwargs
self,
attribute: str = "radius",
smoothing: float = 0,
t_max: Optional[float] = None,
ax=None,
**kwargs,
) -> PlotReference:
"""plot the time evolution of the droplet
Expand Down Expand Up @@ -390,7 +395,14 @@ def plot(
data = self.get_trajectory(smoothing=smoothing, attribute=attribute)
ylabel = attribute.capitalize()

(line,) = ax.plot(self.times, data, **kwargs)
if t_max is not None and len(self.times) >= 2 and self.times[-1] < t_max:
dt = self.times[-1] - self.times[-2]
times = np.r_[self.times, self.times[-1] + dt]
data = np.r_[data, 0]
else:
times = self.times

(line,) = ax.plot(times, data, **kwargs)
ax.set_xlabel("Time")
ax.set_ylabel(ylabel)
return PlotReference(ax, line, {"attribute": attribute})
Expand Down Expand Up @@ -469,7 +481,7 @@ def plot_positions(
class DropletTrackList(list):
"""a list of instances of :class:`DropletTrack`"""

def __getitem__(self, key: Union[int, slice]): # type: ignore
def __getitem__(self, key: int | slice): # type: ignore
"""return the droplets identified by the given index/slice"""
result = super().__getitem__(key)
if isinstance(key, slice):
Expand All @@ -482,7 +494,7 @@ def from_emulsion_time_course(
cls,
time_course: EmulsionTimeCourse,
*,
method: str = "overlap",
method: Literal["distance", "overlap"] = "overlap",
grid: Optional[GridBase] = None,
progress: bool = False,
**kwargs,
Expand Down Expand Up @@ -554,7 +566,7 @@ def match_tracks(
# calculate the distance between droplets
if tracks_alive:
if grid is None:
metric: Union[str, Callable] = "euclidean"
metric: str | Callable = "euclidean"
else:
metric = grid.distance_real
points_prev = [track.last.position for track in tracks_alive]
Expand Down Expand Up @@ -603,8 +615,10 @@ def match_tracks(
def from_storage(
cls,
storage: StorageBase,
*,
method: Literal["distance", "overlap"] = "overlap",
refine: bool = False,
method: str = "overlap",
num_processes: int | Literal["auto"] = 1,
progress: Optional[bool] = None,
) -> DropletTrackList:
r"""obtain droplet tracks from stored scalar field data
Expand All @@ -615,22 +629,27 @@ def from_storage(
Args:
storage (:class:`~pde.storage.base.StorageBase`):
The phase fields for many time instances
refine (bool):
Flag determining whether the droplet properties should be refined
using fitting. This is a potentially slow procedure.
method (str):
The method used for tracking droplet identities. Possible methods are
"overlap" (adding droplets that overlap with those in previous frames)
and "distance" (matching droplets to minimize center-to-center
distances).
refine (bool):
Flag determining whether the droplet properties should be refined
using fitting. This is a potentially slow procedure.
num_processes (int or "auto"):
Number of processes used for the refinement. If set to "auto", the
number of processes is chosen automatically.
progress (bool):
Whether to show the progress of the process. If `None`, the progress is
not shown, except for the first step if `refine` is `True`.
Returns:
:class:`DropletTrackList`: the resulting droplet tracks
"""
etc = EmulsionTimeCourse.from_storage(storage, refine=refine, progress=progress)
etc = EmulsionTimeCourse.from_storage(
storage, refine=refine, num_processes=num_processes, progress=progress
)
if progress is None:
progress = False
return cls.from_emulsion_time_course(etc, method=method, progress=progress)
Expand Down Expand Up @@ -717,11 +736,19 @@ def plot(self, attribute: str = "radius", ax=None, **kwargs) -> PlotReference:
else:
kwargs["color"] = "k" # use black by default

# get maximal time
if self:
t_max = max(track.times[-1] for track in self if len(track.times) > 0)
else:
t_max = None

# adjust alpha such that multiple tracks are visible well
kwargs.setdefault("alpha", min(0.8, 20 / len(self)))
elements = []
for track in self:
elements.append(track.plot(attribute=attribute, ax=ax, **kwargs).element)
elements.append(
track.plot(attribute=attribute, t_max=t_max, ax=ax, **kwargs).element
)
kwargs["label"] = "" # set potential plot label only for first track

return PlotReference(ax, elements, {"attribute": attribute})
Expand Down
56 changes: 40 additions & 16 deletions droplets/emulsions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from __future__ import annotations

import functools
import json
import logging
import math
Expand All @@ -24,11 +25,11 @@
Iterable,
Iterator,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
overload,
)

Expand Down Expand Up @@ -62,7 +63,7 @@ def __init__(
droplets: Optional[Iterable[SphericalDroplet]] = None,
*,
copy: bool = True,
dtype: Union[np.typing.DTypeLike, np.ndarray, SphericalDroplet] = None,
dtype: np.typing.DTypeLike | np.ndarray | SphericalDroplet = None,
force_consistency: bool = False,
grid: Optional[GridBase] = None,
):
Expand Down Expand Up @@ -121,8 +122,8 @@ def empty(cls, droplet: SphericalDroplet) -> Emulsion:
def from_random(
cls,
num: int,
grid_or_bounds: Union[GridBase, Sequence[Tuple[float, float]]],
radius: Union[float, Tuple[float, float]],
grid_or_bounds: GridBase | Sequence[Tuple[float, float]],
radius: float | Tuple[float, float],
*,
remove_overlapping: bool = True,
droplet_class: Type[SphericalDroplet] = SphericalDroplet,
Expand Down Expand Up @@ -267,7 +268,10 @@ def append(
force_consistency (bool, optional):
Whether to ensure that all droplets are of the same type
"""
if self.dtype is None:
# during some multiprocessing examples, Emulsions might apparently not have
# a proper `dtype` defined. This might be because they use some copying or
# __getstate__ methods of the underlying list class
if not hasattr(self, "dtype") or self.dtype is None:
self.dtype = droplet.data.dtype
elif force_consistency and self.dtype != droplet.data.dtype:
raise ValueError(
Expand Down Expand Up @@ -646,7 +650,7 @@ def plot(
color_value: Optional[Callable] = None,
cmap=None,
norm=None,
colorbar: Union[bool, str] = True,
colorbar: bool | str = True,
**kwargs,
) -> PlotReference:
"""plot the current emulsion together with a corresponding field
Expand Down Expand Up @@ -791,7 +795,7 @@ class EmulsionTimeCourse:
def __init__(
self,
emulsions: Optional[Iterable[Emulsion]] = None,
times: Union[np.ndarray, Sequence[float], None] = None,
times: np.ndarray | Sequence[float] | None = None,
) -> None:
"""
Args:
Expand Down Expand Up @@ -858,7 +862,7 @@ def __repr__(self):
def __len__(self):
return len(self.times)

def __getitem__(self, key: Union[int, slice]):
def __getitem__(self, key: int | slice):
"""return the information for the given index"""
result = self.emulsions.__getitem__(key)
if isinstance(key, slice):
Expand All @@ -882,6 +886,8 @@ def __eq__(self, other):
def from_storage(
cls,
storage: StorageBase,
*,
num_processes: int | Literal["auto"] = 1,
refine: bool = False,
progress: Optional[bool] = None,
**kwargs,
Expand All @@ -894,9 +900,13 @@ def from_storage(
refine (bool):
Flag determining whether the droplet properties should be refined
using fitting. This is a potentially slow procedure.
num_processes (int or "auto"):
Number of processes used for the refinement. If set to "auto", the
number of processes is chosen automatically.
progress (bool):
Whether to show the progress of the process. If `None`, the progress is
only shown when `refine` is `True`.
only shown when `refine` is `True`. Progress bars are only shown for
serial calculations (where `num_processes == 1`).
\**kwargs:
All other parameters are forwarded to the
:meth:`~droplets.image_analysis.locate_droplets`.
Expand All @@ -906,14 +916,28 @@ def from_storage(
"""
from .image_analysis import locate_droplets

if progress is None:
progress = refine # show progress only when refining by default
if num_processes == 1:
# obtain the emulsion data for all frames in this process

# obtain the emulsion data for all frames
emulsions = (
locate_droplets(frame, refine=refine, **kwargs)
for frame in display_progress(storage, enabled=progress)
)
if progress is None:
progress = refine # show progress only when refining by default

emulsions: Iterable[Emulsion] = (
locate_droplets(frame, refine=refine, **kwargs)
for frame in display_progress(storage, enabled=progress)
)

else:
# use multiprocessing to obtain emulsion data
from concurrent.futures import ProcessPoolExecutor

_get_emulsion: Callable[[Emulsion], Emulsion] = functools.partial(
locate_droplets, refine=refine, **kwargs
)

max_workers = None if num_processes == "auto" else num_processes
with ProcessPoolExecutor(max_workers=max_workers) as executor:
emulsions = list(executor.map(_get_emulsion, storage))

return cls(emulsions, times=storage.times)

Expand Down
Loading

0 comments on commit e78f3ca

Please sign in to comment.