Skip to content

Commit

Permalink
Support points 2d (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuang11 authored Jul 2, 2024
1 parent e9ca146 commit 0e8d4e8
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 4 deletions.
37 changes: 33 additions & 4 deletions holonote/annotate/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,14 @@ class Style(param.Parameterized):
line_opts = _StyleOpts(default={})
span_opts = _StyleOpts(default={})
rectangle_opts = _StyleOpts(default={})
points_opts = _StyleOpts(default={})

# Editor opts
edit_opts = _StyleOpts(default={"line_color": "black"})
edit_line_opts = _StyleOpts(default={})
edit_span_opts = _StyleOpts(default={})
edit_rectangle_opts = _StyleOpts(default={})
edit_points_opts = _StyleOpts(default={})

_groupby = ()
_colormap = None
Expand Down Expand Up @@ -133,6 +135,7 @@ def indicator(self, **select_opts) -> tuple[hv.Options, ...]:
hv.opts.HSpans(**opts, **self.span_opts),
hv.opts.VLines(**opts, **self.line_opts),
hv.opts.HLines(**opts, **self.line_opts),
hv.opts.Points(**opts, **self.points_opts),
)

def editor(self) -> tuple[hv.Options, ...]:
Expand All @@ -148,6 +151,7 @@ def editor(self) -> tuple[hv.Options, ...]:
hv.opts.HSpan(**opts, **self.edit_span_opts),
hv.opts.VLine(**opts, **self.edit_line_opts),
hv.opts.HLine(**opts, **self.edit_line_opts),
hv.opts.Points(**opts, **self.edit_points_opts),
)

def reset(self) -> None:
Expand Down Expand Up @@ -177,8 +181,6 @@ def points_2d(
cls, data, region_labels, fields_labels, invert_axes=False, groupby: str | None = None
):
"Vectorizes point regions to VLines * HLines. Note does not support hover info"
msg = "2D point regions not supported yet"
raise NotImplementedError(msg)
vdims = [*fields_labels, "__selected__"]
element = hv.Points(data, kdims=region_labels, vdims=vdims)
hover = cls._build_hover_tool(data)
Expand Down Expand Up @@ -236,6 +238,17 @@ class AnnotationDisplay(param.Parameterized):

data = param.DataFrame(doc="Combined dataframe of annotation data", constant=True)

nearest_2d_point_threshold = param.Number(
default=None,
bounds=(0, None),
doc="""
Threshold In the distance in data coordinates between the two dimensions;
it does not consider the unit and magnitude differences between the dimensions
for selecting an existing 2D point; anything over this threshold will create
a new point instead. This parameter is experimental and is subject to change.
""",
)

invert_axis = param.Boolean(default=False, doc="Switch the annotation axis")

_count = param.Integer(default=0, precedence=-1)
Expand Down Expand Up @@ -425,13 +438,27 @@ def get_indices_by_position(self, **inputs) -> list[Any]:
iter_mask = (
(df[f"start[{k}]"] <= v) & (v < df[f"end[{k}]"]) for k, v in inputs.items()
)
subset = reduce(np.logical_and, iter_mask)
out = list(df[subset].index)
elif self.region_format == "point-point":
xk, yk = list(inputs.keys())
xdist = (df[f"point[{xk}]"] - inputs[xk]) ** 2
ydist = (df[f"point[{yk}]"] - inputs[yk]) ** 2
distance_squared = xdist + ydist
if (
self.nearest_2d_point_threshold
and (distance_squared > self.nearest_2d_point_threshold**2).all()
):
return []
out = [df.loc[distance_squared.idxmin()].name] # index == name of series
elif "point" in self.region_format:
iter_mask = ((df[f"point[{k}]"] - v).abs().argmin() for k, v in inputs.items())
out = list(df[reduce(np.logical_and, iter_mask)].index)
else:
msg = f"{self.region_format} not implemented"
raise NotImplementedError(msg)

return list(df[reduce(np.logical_and, iter_mask)].index)
return out

def register_tap_selector(self, element: hv.Element) -> hv.Element:
def tap_selector(x, y) -> None: # Tap tool must be enabled on the element
Expand Down Expand Up @@ -484,7 +511,9 @@ def overlay(self, indicators=True, editor=True) -> hv.Overlay:

def static_indicators(self, **events):
fields_labels = self.annotator.all_fields
region_labels = [k for k in self.data.columns if k not in fields_labels]
region_labels = [
k for k in self.data.columns if k not in fields_labels and k != "__selected__"
]

self.data["__selected__"] = self.data.index.isin(self.annotator.selected_indices)

Expand Down
68 changes: 68 additions & 0 deletions holonote/tests/test_display.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import holoviews as hv

hv.extension("bokeh")


class TestPoint2D:
def test_get_indices_by_position_exact(self, annotator_point2d):
x, y = 0.5, 0.3
description = "A test annotation!"
annotator_point2d.set_regions(x=x, y=y)
annotator_point2d.add_annotation(description=description)
display = annotator_point2d.get_display("x", "y")
indices = display.get_indices_by_position(x=x, y=y)
assert len(indices) == 1

def test_get_indices_by_position_nearest_2d_point_threshold(self, annotator_point2d):
x, y = 0.5, 0.3
annotator_point2d.get_display("x", "y").nearest_2d_point_threshold = 1
description = "A test annotation!"
annotator_point2d.set_regions(x=x, y=y)
annotator_point2d.add_annotation(description=description)
display = annotator_point2d.get_display("x", "y")
indices = display.get_indices_by_position(x=x + 1.5, y=y + 1.5)
assert len(indices) == 0

display.nearest_2d_point_threshold = 5
indices = display.get_indices_by_position(x=x + 0.5, y=y + 0.5)
assert len(indices) == 1

def test_get_indices_by_position_nearest(self, annotator_point2d):
x, y = 0.5, 0.3
description = "A test annotation!"
annotator_point2d.set_regions(x=x, y=y)
annotator_point2d.add_annotation(description=description)
display = annotator_point2d.get_display("x", "y")
indices = display.get_indices_by_position(x=x + 1.5, y=y + 1.5)
assert len(indices) == 1

display.nearest_2d_point_threshold = 5
indices = display.get_indices_by_position(x=x + 0.5, y=y + 0.5)
assert len(indices) == 1

def test_get_indices_by_position_empty(self, annotator_point2d):
display = annotator_point2d.get_display("x", "y")
indices = display.get_indices_by_position(x=0.5, y=0.3)
assert len(indices) == 0

def test_get_indices_by_position_no_position(self, annotator_point2d):
display = annotator_point2d.get_display("x", "y")
indices = display.get_indices_by_position(x=None, y=None)
assert len(indices) == 0

def test_get_indices_by_position_multi_choice(self, annotator_point2d):
x, y = 0.5, 0.3
description = "A test annotation!"
annotator_point2d.set_regions(x=x, y=y)
annotator_point2d.add_annotation(description=description)

x2, y2 = 0.51, 0.31
description = "A test annotation!"
annotator_point2d.set_regions(x=x2, y=y2)
annotator_point2d.add_annotation(description=description)

display = annotator_point2d.get_display("x", "y")
display.nearest_2d_point_threshold = 1000

indices = display.get_indices_by_position(x=x, y=y)
assert len(indices) == 1

0 comments on commit 0e8d4e8

Please sign in to comment.