diff --git a/holonote/annotate/display.py b/holonote/annotate/display.py index 5cb4964..e941e00 100644 --- a/holonote/annotate/display.py +++ b/holonote/annotate/display.py @@ -88,12 +88,14 @@ class Style(param.Parameterized): line_opts = _StyleOpts(default={}) span_opts = _StyleOpts(default={}) rectangle_opts = _StyleOpts(default={}) + points_opts = _StyleOpts(default={}) # Editor opts edit_opts = _StyleOpts(default={"line_color": "black"}) edit_line_opts = _StyleOpts(default={}) edit_span_opts = _StyleOpts(default={}) edit_rectangle_opts = _StyleOpts(default={}) + edit_points_opts = _StyleOpts(default={}) _groupby = () _colormap = None @@ -133,6 +135,7 @@ def indicator(self, **select_opts) -> tuple[hv.Options, ...]: hv.opts.HSpans(**opts, **self.span_opts), hv.opts.VLines(**opts, **self.line_opts), hv.opts.HLines(**opts, **self.line_opts), + hv.opts.Points(**opts, **self.points_opts), ) def editor(self) -> tuple[hv.Options, ...]: @@ -148,6 +151,7 @@ def editor(self) -> tuple[hv.Options, ...]: hv.opts.HSpan(**opts, **self.edit_span_opts), hv.opts.VLine(**opts, **self.edit_line_opts), hv.opts.HLine(**opts, **self.edit_line_opts), + hv.opts.Points(**opts, **self.edit_points_opts), ) def reset(self) -> None: @@ -177,8 +181,6 @@ def points_2d( cls, data, region_labels, fields_labels, invert_axes=False, groupby: str | None = None ): "Vectorizes point regions to VLines * HLines. Note does not support hover info" - msg = "2D point regions not supported yet" - raise NotImplementedError(msg) vdims = [*fields_labels, "__selected__"] element = hv.Points(data, kdims=region_labels, vdims=vdims) hover = cls._build_hover_tool(data) @@ -236,6 +238,17 @@ class AnnotationDisplay(param.Parameterized): data = param.DataFrame(doc="Combined dataframe of annotation data", constant=True) + nearest_2d_point_threshold = param.Number( + default=None, + bounds=(0, None), + doc=""" + Threshold In the distance in data coordinates between the two dimensions; + it does not consider the unit and magnitude differences between the dimensions + for selecting an existing 2D point; anything over this threshold will create + a new point instead. This parameter is experimental and is subject to change. + """, + ) + invert_axis = param.Boolean(default=False, doc="Switch the annotation axis") _count = param.Integer(default=0, precedence=-1) @@ -425,13 +438,27 @@ def get_indices_by_position(self, **inputs) -> list[Any]: iter_mask = ( (df[f"start[{k}]"] <= v) & (v < df[f"end[{k}]"]) for k, v in inputs.items() ) + subset = reduce(np.logical_and, iter_mask) + out = list(df[subset].index) + elif self.region_format == "point-point": + xk, yk = list(inputs.keys()) + xdist = (df[f"point[{xk}]"] - inputs[xk]) ** 2 + ydist = (df[f"point[{yk}]"] - inputs[yk]) ** 2 + distance_squared = xdist + ydist + if ( + self.nearest_2d_point_threshold + and (distance_squared > self.nearest_2d_point_threshold**2).all() + ): + return [] + out = [df.loc[distance_squared.idxmin()].name] # index == name of series elif "point" in self.region_format: iter_mask = ((df[f"point[{k}]"] - v).abs().argmin() for k, v in inputs.items()) + out = list(df[reduce(np.logical_and, iter_mask)].index) else: msg = f"{self.region_format} not implemented" raise NotImplementedError(msg) - return list(df[reduce(np.logical_and, iter_mask)].index) + return out def register_tap_selector(self, element: hv.Element) -> hv.Element: def tap_selector(x, y) -> None: # Tap tool must be enabled on the element @@ -484,7 +511,9 @@ def overlay(self, indicators=True, editor=True) -> hv.Overlay: def static_indicators(self, **events): fields_labels = self.annotator.all_fields - region_labels = [k for k in self.data.columns if k not in fields_labels] + region_labels = [ + k for k in self.data.columns if k not in fields_labels and k != "__selected__" + ] self.data["__selected__"] = self.data.index.isin(self.annotator.selected_indices) diff --git a/holonote/tests/test_display.py b/holonote/tests/test_display.py new file mode 100644 index 0000000..3f28b05 --- /dev/null +++ b/holonote/tests/test_display.py @@ -0,0 +1,68 @@ +import holoviews as hv + +hv.extension("bokeh") + + +class TestPoint2D: + def test_get_indices_by_position_exact(self, annotator_point2d): + x, y = 0.5, 0.3 + description = "A test annotation!" + annotator_point2d.set_regions(x=x, y=y) + annotator_point2d.add_annotation(description=description) + display = annotator_point2d.get_display("x", "y") + indices = display.get_indices_by_position(x=x, y=y) + assert len(indices) == 1 + + def test_get_indices_by_position_nearest_2d_point_threshold(self, annotator_point2d): + x, y = 0.5, 0.3 + annotator_point2d.get_display("x", "y").nearest_2d_point_threshold = 1 + description = "A test annotation!" + annotator_point2d.set_regions(x=x, y=y) + annotator_point2d.add_annotation(description=description) + display = annotator_point2d.get_display("x", "y") + indices = display.get_indices_by_position(x=x + 1.5, y=y + 1.5) + assert len(indices) == 0 + + display.nearest_2d_point_threshold = 5 + indices = display.get_indices_by_position(x=x + 0.5, y=y + 0.5) + assert len(indices) == 1 + + def test_get_indices_by_position_nearest(self, annotator_point2d): + x, y = 0.5, 0.3 + description = "A test annotation!" + annotator_point2d.set_regions(x=x, y=y) + annotator_point2d.add_annotation(description=description) + display = annotator_point2d.get_display("x", "y") + indices = display.get_indices_by_position(x=x + 1.5, y=y + 1.5) + assert len(indices) == 1 + + display.nearest_2d_point_threshold = 5 + indices = display.get_indices_by_position(x=x + 0.5, y=y + 0.5) + assert len(indices) == 1 + + def test_get_indices_by_position_empty(self, annotator_point2d): + display = annotator_point2d.get_display("x", "y") + indices = display.get_indices_by_position(x=0.5, y=0.3) + assert len(indices) == 0 + + def test_get_indices_by_position_no_position(self, annotator_point2d): + display = annotator_point2d.get_display("x", "y") + indices = display.get_indices_by_position(x=None, y=None) + assert len(indices) == 0 + + def test_get_indices_by_position_multi_choice(self, annotator_point2d): + x, y = 0.5, 0.3 + description = "A test annotation!" + annotator_point2d.set_regions(x=x, y=y) + annotator_point2d.add_annotation(description=description) + + x2, y2 = 0.51, 0.31 + description = "A test annotation!" + annotator_point2d.set_regions(x=x2, y=y2) + annotator_point2d.add_annotation(description=description) + + display = annotator_point2d.get_display("x", "y") + display.nearest_2d_point_threshold = 1000 + + indices = display.get_indices_by_position(x=x, y=y) + assert len(indices) == 1