From a7e2652220c2cbbd56f09c5a8f40bf51eb61e99e Mon Sep 17 00:00:00 2001 From: lin Date: Sat, 8 May 2021 15:52:25 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E6=AD=A5=E5=AE=8C=E6=88=90=E5=A4=9A?= =?UTF-8?q?=E6=A0=87=E7=AD=BE=E4=B8=8D=E5=90=8C=E9=A2=9C=E8=89=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- iann/app.py | 106 ++++++++++++++++++++++----------------------- iann/controller.py | 99 +++++++++++++++++++++++++++--------------- iann/ui.py | 1 + iann/util/vis.py | 45 +++++++++++++------ 4 files changed, 147 insertions(+), 104 deletions(-) diff --git a/iann/app.py b/iann/app.py index 8d079a44a..3e38bfcca 100644 --- a/iann/app.py +++ b/iann/app.py @@ -45,47 +45,12 @@ def __init__(self, parent=None): self.statusbar.showMessage("模型未加载") self.initActions() - # TODO: 按照labelme的方式用action - ## 菜单栏点击 - # for menu_act in self.menuBar.actions(): - # if menu_act.text() == "文件": - # for ac_act in menu_act.menu().actions(): - # if ac_act.text() == "加载图像": - # ac_act.triggered.connect(self.openImage) - # else: - # ac_act.triggered.connect(self.openFolder) - # elif menu_act.text() == "设置": - # for ac_act in menu_act.menu().actions(): - # if ac_act.text() == "设置保存路径": - # ac_act.triggered.connect(self.changeOutputDir) - # else: - # ac_act.triggered.connect(self.check_click) - # elif menu_act.text() == "帮助": - # for ac_act in menu_act.menu().actions(): - # if ac_act.text() == "快速上手": - # ac_act.triggered.connect(self.help_dialog.show) - # else: - # ac_act.triggered.connect(self.check_click) - # - # ## 工具栏点击 - # for tool_act in self.toolBar.actions(): - # if tool_act.text() == "完成当前": - # tool_act.triggered.connect(self.finishObject) - # elif tool_act.text() == "清除全部": - # tool_act.triggered.connect(self.undoAll) - # elif tool_act.text() == "撤销": - # tool_act.triggered.connect(self.undoClick) - # elif tool_act.text() == "重做": - # tool_act.triggered.connect(self.check_click) - # elif tool_act.text() == "上一张": - # tool_act.triggered.connect(partial(self.turnImg, -1)) - # elif tool_act.text() == "下一张": - # tool_act.triggered.connect(partial(self.turnImg, 1)) ## 按钮点击 self.btnSave.clicked.connect(self.saveLabel) # 保存 self.listFiles.itemDoubleClicked.connect(self.listClicked) # list选择 self.comboModelSelect.currentIndexChanged.connect(self.changeModel) # 模型选择 + self.btnAddClass.clicked.connect(self.addLabel) # 滑动 self.sldOpacity.valueChanged.connect(self.maskOpacityChanged) @@ -100,13 +65,13 @@ def toBeImplemented(self): self.statusbar.showMessage("功能尚在开发") pass - def menu(self, title, actions=None): - menu = self.menuBar().addMenu(title) - if actions: - util.addActions(menu, actions) - return menu - def initActions(self): + def menu(title, actions=None): + menu = self.menuBar().addMenu(title) + if actions: + util.addActions(menu, actions) + return menu + action = partial(util.newAction, self) shortcuts = { "turn_next": "F", @@ -210,14 +175,7 @@ def initActions(self): "redo", self.tr("重做一次点击"), ) - # finish_object = action( - # self.tr("&N^2宫格标注"), - # self.toBeImplemented, - # None, - # # TODO: 搞个图 - # "", - # self.tr("使用N^2宫格进行细粒度标注"), - # ) + # TODO: 改用manager self.actions = util.struct( turn_next=turn_next, @@ -236,9 +194,9 @@ def initActions(self): labelMenu=(grid_ann,), toolBar=(finish_object, clear, undo, redo, turn_prev, turn_next), ) - self.menu("文件", self.actions.fileMenu) - self.menu("标注", self.actions.labelMenu) - self.menu("帮助", self.actions.helpMenu) + menu("文件", self.actions.fileMenu) + menu("标注", self.actions.labelMenu) + menu("帮助", self.actions.helpMenu) util.addActions(self.toolBar, self.actions.toolBar) def changeOutputDir(self, dir=None): @@ -278,9 +236,24 @@ def changeModel(self, idx): self.statusbar.showMessage(f"{ models[idx].name}模型加载完成", 5000) + def addLabel(self): + table = self.labelListTable + table.insertRow(table.rowCount()) + idx = table.rowCount() - 1 + print(idx) + numberItem = QTableWidgetItem(str(idx + 1)) + numberItem.setFlags(QtCore.Qt.ItemIsEnabled) + table.setItem(idx, 0, numberItem) + table.setItem(idx, 1, QTableWidgetItem()) + c = [255, 255, 255] + colorItem = QTableWidgetItem() + colorItem.setBackground(QtGui.QColor(c[0], c[1], c[2])) + colorItem.setFlags(QtCore.Qt.ItemIsEnabled) + table.setItem(idx, 2, colorItem) + self.labelList.append([idx + 1, "", [255, 255, 255]]) + def refreshLabelList(self): table = self.labelListTable - # TODO: 添加表头 table.clearContents() table.setRowCount(len(self.labelList)) table.setColumnCount(3) @@ -303,12 +276,26 @@ def changeLabelColor(row, col): if col != 2: return color = QtWidgets.QColorDialog.getColor() + # BUG: 判断颜色没变 print(color.getRgb()) table.item(row, col).setBackground(color) self.labelList[row][2] = color.getRgb()[:3] + if self.controller: + self.controller.label_list = self.labelList table.cellDoubleClicked.connect(changeLabelColor) + def cellClicked(row, col): + print("cell clicked", row, col) + for idx in range(3): + table.item(row, idx).setSelected(True) + if self.controller: + print(int(table.item(row, 0).text())) + self.controller.change_label_num(int(table.item(row, 0).text())) + self.controller.label_list = self.labelList + + table.cellClicked.connect(cellClicked) + def openImage(self): formats = [ "*.{}".format(fmt.data().decode()) @@ -471,7 +458,7 @@ def threshChanged(self): self._update_image() def undoClick(self): - if not self.image: + if self.image is None: return self.controller.undo_click() @@ -488,6 +475,15 @@ def canvasClick(self, x, y, isLeft): return if x < 0 or y < 0: return + if not self.controller.curr_label_number: + msg = QMessageBox() + msg.setIcon(QMessageBox.Warning) + msg.setWindowTitle("未选择当前标签") + msg.setText("请先在标签列表中单击点选标签") + msg.setStandardButtons(QMessageBox.Yes) + res = msg.exec_() + return + s = self.controller.img_size if x > s[0] or y > s[1]: return diff --git a/iann/controller.py b/iann/controller.py index 8a6219a58..de24286bb 100644 --- a/iann/controller.py +++ b/iann/controller.py @@ -14,8 +14,9 @@ def __init__(self, net, predictor_params, update_image_callback, prob_thresh=0.5 self.clicker = clicker.Clicker() self.states = [] self.probs_history = [] - self.object_count = 0 + self.curr_label_number = 0 self._result_mask = None + self.label_list = None # 存标签编号和颜色的对照 self.image = None self.image_nd = None @@ -41,12 +42,10 @@ def set_image(self, image): self.image_nd = input_transform(image)[0] self._result_mask = np.zeros(image.shape[:2], dtype=np.uint8) - self.object_count = 0 + self.curr_label_number = 0 self.reset_last_object(update_image=False) self.update_image_callback(reset_canvas=True) - # def change_alpha(self, alpha_blend): - # self. def add_click(self, x, y, is_positive): """添加一个点 跑推理,保存历史用于undo @@ -114,11 +113,27 @@ def finish_object(self): if object_prob is None: return - self.object_count += 1 # TODO: 当前是按照第几个目标给结果中的数,改成根据目标编号 + # self.curr_label_number += 1 # TODO: 当前是按照第几个目标给结果中的数,改成根据目标编号 object_mask = object_prob > self.prob_thresh - self._result_mask[object_mask] = self.object_count + self._result_mask[object_mask] = self.curr_label_number self.reset_last_object() + def change_label_num(self, number): + """修改当前标签的编号 + 如果当前有标注到一半的目标,改mask。 + 如果没有,下一个目标是这个数 + + Parameters + ---------- + number : int + 换成目标的编号 + """ + assert isinstance(number, int), "标签编号应为整数" + self.curr_label_number = number + if self.is_incomplete_mask: + pass + # TODO: 改当前mask的编号 + def reset_last_object(self, update_image=True): """重置控制器状态 @@ -158,6 +173,49 @@ def reset_predictor(self, net=None, predictor_params=None): if self.image_nd is not None: self.predictor.set_input_image(self.image_nd) + def get_visualization(self, alpha_blend, click_radius): + if self.image is None: + return None + + # 1. 画当前没标完的mask + results_mask_for_vis = self.result_mask + if self.probs_history: + results_mask_for_vis[ + self.current_object_prob > self.prob_thresh + ] = self.curr_label_number + + vis = draw_with_blend_and_clicks( + self.image, + mask=results_mask_for_vis, + alpha=alpha_blend, + clicks_list=self.clicker.clicks_list, + radius=click_radius, + palette=self.palette, + ) + + # 2. 在图片和当前mask的基础上画之前标完的mask + if self.probs_history: + total_mask = self.probs_history[-1][0] > self.prob_thresh + results_mask_for_vis[np.logical_not(total_mask)] = 0 + vis = draw_with_blend_and_clicks( + vis, + mask=results_mask_for_vis, + alpha=alpha_blend, + palette=self.palette, + ) + + return vis + + @property + def palette(self): + if self.label_list: + colors = [l[2] for l in self.label_list] + colors.insert(0, [0, 0, 0]) + else: + colors = [[0, 0, 0]] + print(colors) + return colors + @property def current_object_prob(self): """获取当前推理标签""" @@ -185,32 +243,3 @@ def result_mask(self): def img_size(self): print(self.image.shape) return self.image.shape[1::-1] - - def get_visualization(self, alpha_blend, click_radius): - if self.image is None: - return None - - # 1. 画当前没标完的mask - results_mask_for_vis = self.result_mask - if self.probs_history: - results_mask_for_vis[self.current_object_prob > self.prob_thresh] = ( - self.object_count + 1 - ) - - vis = draw_with_blend_and_clicks( - self.image, - mask=results_mask_for_vis, - alpha=alpha_blend, - clicks_list=self.clicker.clicks_list, - radius=click_radius, - ) - - # 2. 在图片和当前mask的基础上画之前标完的mask - if self.probs_history: - total_mask = self.probs_history[-1][0] > self.prob_thresh - results_mask_for_vis[np.logical_not(total_mask)] = 0 - vis = draw_with_blend_and_clicks( - vis, mask=results_mask_for_vis, alpha=alpha_blend - ) - - return vis diff --git a/iann/ui.py b/iann/ui.py index 4a37df50d..c0d851218 100644 --- a/iann/ui.py +++ b/iann/ui.py @@ -185,6 +185,7 @@ def setupUi(self, MainWindow): # 标签列表 labelListLab = self.create_text(CentralWidget, "labelListLab", "标签列表") listRegion.addWidget(labelListLab) + # TODO: 改成 list widget self.labelListTable = QtWidgets.QTableWidget(CentralWidget) self.labelListTable.horizontalHeader().hide() self.labelListTable.verticalHeader().hide() diff --git a/iann/util/vis.py b/iann/util/vis.py index 9d744b7cf..7fa1a7bfc 100644 --- a/iann/util/vis.py +++ b/iann/util/vis.py @@ -4,8 +4,9 @@ import numpy as np -def visualize_instances(imask, bg_color=255, - boundaries_color=None, boundaries_width=1, boundaries_alpha=0.8): +def visualize_instances( + imask, bg_color=255, boundaries_color=None, boundaries_width=1, boundaries_alpha=0.8 +): num_objects = imask.max() + 1 palette = get_palette(num_objects) if bg_color is not None: @@ -24,6 +25,7 @@ def visualize_instances(imask, bg_color=255, @lru_cache(maxsize=16) def get_palette(num_cls): + return np.array([[0, 0, 0], [128, 0, 0], [0, 128, 0], [0, 0, 128]]) palette = np.zeros(3 * num_cls, dtype=np.int32) for j in range(0, num_cls): @@ -31,9 +33,9 @@ def get_palette(num_cls): i = 0 while lab > 0: - palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) - palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) - palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) + palette[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i) + palette[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i) + palette[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i) i = i + 1 lab >>= 3 @@ -87,7 +89,9 @@ def blend_mask(image, mask, alpha=0.6): def get_boundaries(instances_masks, boundaries_width=1): - boundaries = np.zeros((instances_masks.shape[0], instances_masks.shape[1]), dtype=np.bool) + boundaries = np.zeros( + (instances_masks.shape[0], instances_masks.shape[1]), dtype=np.bool + ) for obj_id in np.unique(instances_masks.flatten()): if obj_id == 0: @@ -95,25 +99,39 @@ def get_boundaries(instances_masks, boundaries_width=1): obj_mask = instances_masks == obj_id kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) - inner_mask = cv2.erode(obj_mask.astype(np.uint8), kernel, iterations=boundaries_width).astype(np.bool) + inner_mask = cv2.erode( + obj_mask.astype(np.uint8), kernel, iterations=boundaries_width + ).astype(np.bool) obj_boundary = np.logical_xor(obj_mask, np.logical_and(inner_mask, obj_mask)) boundaries = np.logical_or(boundaries, obj_boundary) return boundaries -def draw_with_blend_and_clicks(img, mask=None, alpha=0.6, clicks_list=None, pos_color=(0, 255, 0), - neg_color=(255, 0, 0), radius=4): +def draw_with_blend_and_clicks( + img, + mask=None, + alpha=0.6, + clicks_list=None, + pos_color=(0, 255, 0), + neg_color=(255, 0, 0), + radius=4, + palette=None, +): result = img.copy() if mask is not None: - palette = get_palette(np.max(mask) + 1) + if not palette: + palette = get_palette(np.max(mask) + 1) + palette = np.array(palette) rgb_mask = palette[mask.astype(np.uint8)] mask_region = (mask > 0).astype(np.uint8) - result = result * (1 - mask_region[:, :, np.newaxis]) + \ - (1 - alpha) * mask_region[:, :, np.newaxis] * result + \ - alpha * rgb_mask + result = ( + result * (1 - mask_region[:, :, np.newaxis]) + + (1 - alpha) * mask_region[:, :, np.newaxis] * result + + alpha * rgb_mask + ) result = result.astype(np.uint8) # result = (result * (1 - alpha) + alpha * rgb_mask).astype(np.uint8) @@ -126,4 +144,3 @@ def draw_with_blend_and_clicks(img, mask=None, alpha=0.6, clicks_list=None, pos_ result = draw_points(result, neg_points, neg_color, radius=radius) return result -