Skip to content

Commit

Permalink
Correctly merge field_df and region_df (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
hoxbro committed Sep 20, 2023
1 parent 1615a2a commit eb621ab
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 13 deletions.
30 changes: 17 additions & 13 deletions holonote/annotate/annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,28 @@ def ranges_1d(cls, region_df, field_df, invert_axes=False, extra_params=None):

@classmethod
def _range_indicators(cls, region_df, field_df, dimensionality, invert_axes=False, extra_params=None):
rect_data = []
# TODO: Clean this up VSpans/HSpans/VLines/HLines
index_col_name = 'id' if field_df.index.name is None else field_df.index.name

mdata_vals = ([None] * len(region_df['_id'])
if len(field_df.columns)==0 else field_df.to_dict('records'))
for id_val, value, mdata in zip(region_df['_id'], region_df["value"], mdata_vals):
if dimensionality=='1d':
coords = (value[0], extra_params['rect_min'], value[1], extra_params['rect_max'])
else:
coords = (value[0], value[2], value[1], value[3]) # LBRT format
if region_df.empty:
return hv.Rectangles([], vdims=[*field_df.columns, index_col_name])

if None in coords: continue
data = region_df.merge(field_df, left_on="_id", right_index=True)
values = pd.DataFrame.from_records(data["value"])
id_vals = data["_id"].rename({"_id": index_col_name})
mdata_vals = data[field_df.columns]

mdata_tuple = () if len(field_df.columns)==0 else tuple(mdata.values())
rect_data.append(coords + mdata_tuple + (id_val,))
# TODO: Add check for None, (None, None), or (None, None, None, None) in values?

index_col_name = ['id'] if field_df.index.name is None else [field_df.index.name]
return hv.Rectangles(rect_data, vdims=list(field_df.columns)+index_col_name) # kdims?
if dimensionality=='1d':
coords = values[[0, 0, 1, 1]].copy()
coords.iloc[:, 1] = extra_params["rect_min"]
coords.iloc[:, 3] = extra_params["rect_max"]
else:
coords = values[[0, 2, 1, 3]] # LBRT format

rect_data = list(pd.concat([coords, mdata_vals, id_vals], axis=1).itertuples(index=False))
return hv.Rectangles(rect_data, vdims=[*field_df.columns, index_col_name]) # kdims?


class AnnotatorInterface(param.Parameterized):
Expand Down
1 change: 1 addition & 0 deletions holonote/app/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .panel import PanelWidgets # noqa: F401
150 changes: 150 additions & 0 deletions holonote/app/panel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
from __future__ import annotations

import datetime as dt
from typing import TYPE_CHECKING, Any

import panel as pn
import param

if TYPE_CHECKING:
from holonote.annotate import Annotator


class PanelWidgets:
mapping = {
str: pn.widgets.TextInput,
bool: pn.widgets.Checkbox,
dt.datetime: pn.widgets.DatePicker,
dt.date: pn.widgets.DatePicker,
int: pn.widgets.IntSlider,
float: pn.widgets.FloatSlider,
}

def __init__(self, annotator: Annotator, field_values: dict[str, Any] | None=None):
self.annotator = annotator
self._widget_mode_group = pn.widgets.RadioButtonGroup(
name="Mode", options=["+", "-", "✏"], width=90
)
self._widget_apply_button = pn.widgets.Button(name="✓", width=20)
self._widget_revert_button = pn.widgets.Button(name="↺", width=20)
self._widget_commit_button = pn.widgets.Button(name="▲", width=20)

if field_values is None:
self._fields_values = {k: "" for k in self.annotator.fields}
else:
self._fields_values = {
k: field_values.get(k, "") for k in self.annotator.fields
}
self._fields_widgets = self._create_fields_widgets(self._fields_values)

self._set_standard_callbacks()

@property
def tool_widgets(self):
return pn.Row(
self._widget_apply_button,
pn.Spacer(width=10),
self._widget_mode_group,
pn.Spacer(width=10),
self._widget_revert_button,
self._widget_commit_button,
)

def _create_fields_widgets(self, fields_values):
fields_widgets = {}
for widget_name, default in fields_values.items():
if isinstance(default, param.Parameter):
parameterized = type(
"widgets", (param.Parameterized,), {widget_name: default}
)
pane = pn.Param(parameterized)
fields_widgets[widget_name] = pane.layout[1]
elif isinstance(default, list):
fields_widgets[widget_name] = pn.widgets.Select(
value=default[0], options=default, name=widget_name
)
else:
widget_type = self.mapping[type(default)]
if issubclass(widget_type, pn.widgets.TextInput):
fields_widgets[widget_name] = widget_type(
value=default, placeholder=widget_name, name=widget_name
)
else:
fields_widgets[widget_name] = widget_type(
value=default, name=widget_name
)
return fields_widgets

@property
def fields_widgets(self):
accordion = False # Experimental
widgets = pn.Column(*self._fields_widgets.values())
if accordion:
return pn.Accordion(("fields", widgets))
else:
return widgets

def _reset_fields_widgets(self):
for widget_name, default in self._fields_values.items():
if isinstance(default, param.Parameter):
default = default.default
try:
self._fields_widgets[widget_name].value = default
except Exception:
pass # TODO: Fix when lists (for categories, not the same as the default!)

def _callback_apply(self, event):
selected_ind = (
self.annotator.selected_indices[0]
if len(self.annotator.selected_indices) == 1
else None
)
self.annotator.select_by_index()

if self._widget_mode_group.value in ["+", "✏"]:
fields_values = {k: v.value for k, v in self._fields_widgets.items()}
if self._widget_mode_group.value == "+":
self.annotator.add_annotation(**fields_values)
self._reset_fields_widgets()
elif (self._widget_mode_group.value == "✏") and (selected_ind is not None):
self.annotator.update_annotation_fields(
selected_ind, **fields_values
) # TODO: Handle only changed
elif self._widget_mode_group.value == "-":
if selected_ind is not None:
self.annotator.delete_annotation(selected_ind)

def _callback_commit(self, event):
self.annotator.commit()

def _watcher_selected_indices(self, event):
if len(event.new) != 1:
return
selected_index = event.new[0]
# if self._widget_mode_group.value == '✏':
for name, widget in self._fields_widgets.items():
value = self.annotator.annotation_table._field_df.loc[selected_index][name]
widget.value = value

def _watcher_mode_group(self, event):
if event.new in ["-", "✏"]:
self.annotator.selection_enabled = True
self.annotator.select_by_index()
self.annotator.editable_enabled = False
elif event.new == "+":
self.annotator.editable_enabled = True
self.annotator.select_by_index()
self.annotator.selection_enabled = False

for widget in self._fields_widgets.values():
widget.disabled = event.new == "-"

def _set_standard_callbacks(self):
self._widget_apply_button.on_click(self._callback_apply)
# self._widget_revert_button.on_click(self._callback_revert)
self._widget_commit_button.on_click(self._callback_commit)
self.annotator.param.watch(self._watcher_selected_indices, "selected_indices")
self._widget_mode_group.param.watch(self._watcher_mode_group, "value")

def __panel__(self):
return pn.Column(self.fields_widgets, self.tool_widgets)
46 changes: 46 additions & 0 deletions holonote/tests/test_indicators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import numpy as np
import pandas as pd

from holonote.annotate.annotator import Indicator


def test_range2d_id_matches() -> None:
value = np.arange(8).reshape(2, 4)
region_df = pd.DataFrame({"value": list(value), "_id": ["A", "B"]})
field_df = pd.DataFrame(["B", "A"], index=["B", "A"], columns=["description"])

# id and description should match
output = Indicator.ranges_2d(region_df, field_df).data
expected = pd.DataFrame(
{
"x0": {0: 0, 1: 4},
"y0": {0: 2, 1: 6},
"x1": {0: 1, 1: 5},
"y1": {0: 3, 1: 7},
"description": {0: "A", 1: "B"},
"id": {0: "A", 1: "B"},
}
)
pd.testing.assert_frame_equal(output, expected)


def test_range1d_id_matches() -> None:
value = np.arange(4).reshape(2, 2)
region_df = pd.DataFrame({"value": list(value), "_id": ["A", "B"]})
field_df = pd.DataFrame(["B", "A"], index=["B", "A"], columns=["description"])

# id and description should match
output = Indicator.ranges_1d(
region_df, field_df, extra_params={"rect_min": -2, "rect_max": -2}
).data
expected = pd.DataFrame(
{
"x0": {0: 0, 1: 2},
"y0": {0: -2, 1: -2},
"x1": {0: 1, 1: 3},
"y1": {0: -2, 1: -2},
"description": {0: "A", 1: "B"},
"id": {0: "A", 1: "B"},
}
)
pd.testing.assert_frame_equal(output, expected)

0 comments on commit eb621ab

Please sign in to comment.