Skip to content

Commit

Permalink
add support for randomized strides
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Oct 9, 2024
1 parent b174356 commit 5c5a778
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 19 deletions.
3 changes: 2 additions & 1 deletion mwatershed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ def agglom(
seeds: Optional[np.ndarray] = None,
edges: Optional[list[tuple[bool, int, int]]] = None,
strides: Optional[list[list[int]]] = None,
randomized_strides: bool = False,
):
return agglom_rs(affinities, offsets, seeds, edges, strides)
return agglom_rs(affinities, offsets, seeds, edges, strides, randomized_strides)


__all__ = ["agglom", "cluster"]
103 changes: 85 additions & 18 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use pyo3::PyResult;

use ndarray_rand::rand;

use ordered_float::NotNan;
use std::collections::{HashMap, HashSet};
use std::convert::TryInto;
Expand All @@ -26,6 +28,7 @@ pub fn get_edges<const D: usize>(
offsets: Vec<Vec<isize>>,
seeds: &Array<usize, IxDyn>,
strides: Option<Vec<Vec<usize>>>,
randomized_strides: bool,
) -> (Vec<AgglomEdge>, HashSet<usize>) {
// let (_, array_shape) = get_dims::<D>(seeds.dim(), 0);
let offsets: Vec<[isize; D]> = offsets
Expand All @@ -52,40 +55,59 @@ pub fn get_edges<const D: usize>(
.for_each(|(offset_index, (offset, stride))| {
let all_offset_affs = affinities.index_axis(Axis(0), offset_index);
let offset_affs = all_offset_affs.slice_each_axis(|ax| {
let step = if randomized_strides {
1
} else {
stride[ax.axis.index()].try_into().unwrap()
};
Slice::new(
std::cmp::max(0, -offset[ax.axis.index()]),
Some(std::cmp::min(
ax.len as isize,
(ax.len as isize) - offset[ax.axis.index()],
)),
stride[ax.axis.index()].try_into().unwrap(),
step,
)
});
let u_seeds = seeds.slice_each_axis(|ax| {
let step = if randomized_strides {
1
} else {
stride[ax.axis.index()].try_into().unwrap()
};
Slice::new(
std::cmp::max(0, -offset[ax.axis.index()]),
Some(std::cmp::min(
ax.len as isize,
(ax.len as isize) - offset[ax.axis.index()],
)),
stride[ax.axis.index()].try_into().unwrap(),
step,
)
});
let v_seeds = seeds.slice_each_axis(|ax| {
let step = if randomized_strides {
1
} else {
stride[ax.axis.index()].try_into().unwrap()
};
Slice::new(
std::cmp::max(0, offset[ax.axis.index()]),
Some(std::cmp::min(
ax.len as isize,
(ax.len as isize) + offset[ax.axis.index()],
)),
stride[ax.axis.index()].try_into().unwrap(),
step,
)
});
offset_affs.indexed_iter().for_each(|(index, aff)| {
let u = u_seeds[&index];
let v = v_seeds[&index];
affs.push(NotNan::new(aff.abs()).expect("Cannot handle `nan` affinities"));
edges.push(AgglomEdge(aff > &0.0, u, v))
if !randomized_strides
|| rand::random::<f32>() < 1.0 / stride.iter().product::<usize>() as f32
{
let u = u_seeds[&index];
let v = v_seeds[&index];
affs.push(NotNan::new(aff.abs()).expect("Cannot handle `nan` affinities"));
edges.push(AgglomEdge(aff > &0.0, u, v))
}
});
});
let agglom_edges: Vec<AgglomEdge> = affs
Expand Down Expand Up @@ -116,6 +138,7 @@ pub fn agglomerate<const D: usize>(
mut edges: Vec<AgglomEdge>,
mut seeds: Array<usize, IxDyn>,
strides: Option<Vec<Vec<usize>>>,
randomized_strides: bool,
) -> Array<usize, IxDyn> {
// relabel to consecutive ids
let mut lookup = HashMap::new();
Expand Down Expand Up @@ -143,7 +166,7 @@ pub fn agglomerate<const D: usize>(

// main algorithm
let (sorted_edges, mut filtered_background) =
get_edges::<D>(affinities, offsets, &seeds, strides);
get_edges::<D>(affinities, offsets, &seeds, strides, randomized_strides);
edges.extend(sorted_edges);
lookup.values().for_each(|node_id| {
filtered_background.remove(node_id);
Expand Down Expand Up @@ -225,6 +248,7 @@ fn agglom_rs<'py>(
seeds: Option<&PyArrayDyn<usize>>,
edges: Option<Vec<(bool, usize, usize)>>,
strides: Option<Vec<Vec<usize>>>,
randomized_strides: Option<bool>,
) -> PyResult<&'py PyArrayDyn<usize>> {
let affinities = unsafe { affinities.as_array() }.to_owned();
let seeds = match seeds {
Expand All @@ -238,12 +262,54 @@ fn agglom_rs<'py>(
.map(|(pos, u, v)| AgglomEdge(pos, u, v))
.collect();
let result = match dim {
1 => agglomerate::<1>(&affinities, offsets, edges, seeds, strides),
2 => agglomerate::<2>(&affinities, offsets, edges, seeds, strides),
3 => agglomerate::<3>(&affinities, offsets, edges, seeds, strides),
4 => agglomerate::<4>(&affinities, offsets, edges, seeds, strides),
5 => agglomerate::<5>(&affinities, offsets, edges, seeds, strides),
6 => agglomerate::<6>(&affinities, offsets, edges, seeds, strides),
1 => agglomerate::<1>(
&affinities,
offsets,
edges,
seeds,
strides,
randomized_strides.unwrap_or(false),
),
2 => agglomerate::<2>(
&affinities,
offsets,
edges,
seeds,
strides,
randomized_strides.unwrap_or(false),
),
3 => agglomerate::<3>(
&affinities,
offsets,
edges,
seeds,
strides,
randomized_strides.unwrap_or(false),
),
4 => agglomerate::<4>(
&affinities,
offsets,
edges,
seeds,
strides,
randomized_strides.unwrap_or(false),
),
5 => agglomerate::<5>(
&affinities,
offsets,
edges,
seeds,
strides,
randomized_strides.unwrap_or(false),
),
6 => agglomerate::<6>(
&affinities,
offsets,
edges,
seeds,
strides,
randomized_strides.unwrap_or(false),
),
_ => panic!["Only 1-6 dimensional arrays supported"],
};
Ok(result.into_pyarray(_py))
Expand Down Expand Up @@ -308,7 +374,7 @@ mod tests {
- 0.5;
let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn();
let offsets = vec![vec![0, 1], vec![1, 0]];
let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, None);
let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, None, false);
let ids = components
.clone()
.into_iter()
Expand Down Expand Up @@ -352,7 +418,8 @@ mod tests {
let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn();
let offsets = vec![vec![0, 1], vec![1, 0]];
let strides = vec![vec![2, 1], vec![1, 2]];
let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, Some(strides));
let components =
agglomerate::<2>(&affinities, offsets, vec![], seeds, Some(strides), false);
let ids = components
.clone()
.into_iter()
Expand Down Expand Up @@ -396,7 +463,7 @@ mod tests {
- 0.5;
let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn();
let offsets = vec![vec![0, -1], vec![-1, 0]];
let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, None);
let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, None, false);
let ids = components
.clone()
.into_iter()
Expand Down Expand Up @@ -440,7 +507,7 @@ mod tests {
- 0.5;
let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn();
let offsets = vec![vec![0, -1], vec![-1, 0]];
let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, None);
let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, None, false);
let ids = components
.clone()
.into_iter()
Expand Down

0 comments on commit 5c5a778

Please sign in to comment.