From 3b3747eb2b8fb8b133f4284fa538358378493b23 Mon Sep 17 00:00:00 2001 From: lin Date: Sun, 9 May 2021 19:22:46 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84=E5=A4=9A=E6=A0=87=E7=AD=BE?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + iann/app.py | 421 ++++++++++++++++++++++++++++++------------ iann/controller.py | 11 +- iann/util/__init__.py | 3 +- iann/util/label.py | 40 ++++ iann/util/poly.py | 21 +++ iann/util/util.py | 147 ++++++++++----- setup.py | 162 ++++++++++++++++ test.txt | 2 + 9 files changed, 632 insertions(+), 176 deletions(-) create mode 100644 iann/util/label.py create mode 100644 iann/util/poly.py create mode 100644 setup.py create mode 100644 test.txt diff --git a/.gitignore b/.gitignore index 6ec60bb30..a651cc9ca 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ __pycache__/ dzq* .vscode .vscode/ +vis_temp.py # C extensions *.so diff --git a/iann/app.py b/iann/app.py index 3e38bfcca..97202acfb 100644 --- a/iann/app.py +++ b/iann/app.py @@ -30,13 +30,13 @@ def __init__(self, parent=None): # app变量 self.controller = None self.outputDir = None # 标签保存路径 - self.labelFiles = None # 保存所有从outputdir发现的标签文件路径 + self.labelPaths = [] # 保存所有从outputdir发现的标签文件路径 self.currIdx = 0 # 标注文件夹时到第几个了 self.filePaths = [] # 标注文件夹时所有文件路径 + # TODO: labelList用一个class实现 self.labelList = [] # 标签列表(数字,名字,颜色) - - self.labelList = [[1, "人", [0, 0, 0]], [2, "车", [128, 128, 128]]] - + # self.labelList = [[1, "人", [0, 0, 0]], [2, "车", [128, 128, 128]]] + self.isDirty = False # 画布部分 self.canvas.clickRequest.connect(self.canvasClick) self.image = None @@ -57,13 +57,16 @@ def __init__(self, parent=None): self.sldClickRadius.valueChanged.connect(self.clickRadiusChanged) self.sldThresh.valueChanged.connect(self.threshChanged) self.refreshLabelList() + # 标签列表点击 + self.labelListTable.cellDoubleClicked.connect(self.labelListDoubleClick) + self.labelListTable.cellClicked.connect(self.labelListClicked) + self.labelListTable.cellChanged.connect(self.labelListItemChanged) # TODO: 打开上次关软件时用的模型 # TODO: 在ui展示后再加载模型 def toBeImplemented(self): self.statusbar.showMessage("功能尚在开发") - pass def initActions(self): def menu(title, actions=None): @@ -115,6 +118,14 @@ def menu(title, actions=None): "", self.tr("打开一个文件夹下所有的图像进行标注"), ) + open_recent = action( + self.tr("&最近标注"), + self.toBeImplemented, + "", + # TODO: 搞个图 + "", + self.tr("打开一个文件夹下所有的图像进行标注"), + ) change_output_dir = action( self.tr("&改变标签保存路径"), self.changeOutputDir, @@ -175,23 +186,104 @@ def menu(title, actions=None): "redo", self.tr("重做一次点击"), ) - + save = action( + self.tr("&保存"), + self.saveLabel, + "", + "redo", + self.tr("保存图像标签"), + ) + save_as = action( + self.tr("&另存为"), + partial(self.saveLabel, True), + "", + "redo", + self.tr("指定标签保存路径"), + ) + auto_save = action( + self.tr("&自动保存"), + self.toggleAutoSave, + "", + None, + self.tr("翻页同时自动保存"), + checkable=True, + ) + recent = action( + self.tr("&近期图片"), + self.toBeImplemented, + "", + "redo", + self.tr("近期打开的图片"), + ) + close = action( + self.tr("&关闭"), + self.toBeImplemented, + "", + "redo", + self.tr("关闭当前图像"), + ) + connected = action( + self.tr("&连通块"), + self.toBeImplemented, + "", + "redo", + self.tr(""), + ) + quit = action( + self.tr("&退出"), + self.close, + "", + "redo", + self.tr("退出软件"), + ) + save_label = action( + self.tr("&保存标签列表"), + self.saveLabelList, + "", + "redo", + self.tr("将标签保存成标签配置文件"), + ) + load_label = action( + self.tr("&加载标签列表"), + self.loadLabelList, + "", + "redo", + self.tr("从标签配置文件中加载标签"), + ) + clear_label = action( + self.tr("&清空标签列表"), + self.clearLabelList, + "", + "redo", + self.tr("清空所有的标签"), + ) + shortcuts = action( + self.tr("&快捷键列表"), + self.showShortcuts, + "", + "redo", + self.tr("查看所有快捷键"), + ) # TODO: 改用manager self.actions = util.struct( - turn_next=turn_next, - turn_prev=turn_prev, - open_image=open_image, - open_folder=open_folder, + auto_save=auto_save, fileMenu=( open_image, open_folder, change_output_dir, + open_recent, None, - turn_prev, + save, + save_as, + auto_save, turn_next, + turn_prev, + close, + None, + quit, ), - helpMenu=(quick_start, about), - labelMenu=(grid_ann,), + labelMenu=(save_label, load_label, clear_label, None, grid_ann), + helpMenu=(quick_start, about, shortcuts), toolBar=(finish_object, clear, undo, redo, turn_prev, turn_next), ) menu("文件", self.actions.fileMenu) @@ -199,24 +291,59 @@ def menu(title, actions=None): menu("帮助", self.actions.helpMenu) util.addActions(self.toolBar, self.actions.toolBar) - def changeOutputDir(self, dir=None): - if dir is not None: - outputDir = QtWidgets.QFileDialog.getExistingDirectory( - self, - self.tr("%s - 选择标签文件夹") % __appname__, - "/home/aistudio/git/paddle/iann", - QtWidgets.QFileDialog.ShowDirsOnly - | QtWidgets.QFileDialog.DontResolveSymlinks, - ) - if len(outputDir) == 0: - return + def showShortcuts(self): + pass - labelFiles = os.listdir(outputDir) - exts = QtGui.QImageReader.supportedImageFormats() - self.labelFiles = [ - osp.join(outputDir, n) for n in labelFiles if n.split(".")[-1] in exts - ] - self.outputDir = outputDir + def clearMask(self): + self.controller.reset_last_object() + + def toggleAutoSave(self, x): + if x and not self.outputDir: + self.changeOutputDir() + if x and not self.outputDir: + self.actions.auto_save.setChecked(False) + + def clearLabelList(self): + self.labelList = [] + if self.controller: + self.controller.label_list = [] + self.controller.curr_label_number = None + self.labelListTable.clear() + self.labelListTable.setRowCount(0) + + def saveLabelList(self): + if len(self.labelList) == 0: + msg = QMessageBox() + msg.setIcon(QMessageBox.Warning) + msg.setWindowTitle("没有需要保存的标签") + msg.setText("请先添加标签之后再进行保存") + msg.setStandardButtons(QMessageBox.Yes) + res = msg.exec_() + filters = self.tr("标签配置文件 (*.txt)") + dlg = QtWidgets.QFileDialog(self, "保存标签配置文件", ".", filters) + dlg.setDefaultSuffix("txt") + dlg.setAcceptMode(QtWidgets.QFileDialog.AcceptSave) + dlg.setOption(QtWidgets.QFileDialog.DontConfirmOverwrite, False) + dlg.setOption(QtWidgets.QFileDialog.DontUseNativeDialog, False) + savePath, _ = dlg.getSaveFileName( + self, + self.tr("保存标签配置文件"), + ".", + ) + # print(savePath) + util.saveLabel(self.labelList, savePath) + + def loadLabelList(self): + filters = self.tr("标签配置文件 (*.txt)") + file_path, _ = QtWidgets.QFileDialog.getOpenFileName( + self, + self.tr("%s - 选择标签配置文件路径") % __appname__, + ".", + filters, + ) + self.labelList = util.readLabel(file_path) + print(self.labelList) + self.refreshLabelList() def changeModel(self, idx): # TODO: 设置gpu还是cpu运行 @@ -240,7 +367,6 @@ def addLabel(self): table = self.labelListTable table.insertRow(table.rowCount()) idx = table.rowCount() - 1 - print(idx) numberItem = QTableWidgetItem(str(idx + 1)) numberItem.setFlags(QtCore.Qt.ItemIsEnabled) table.setItem(idx, 0, numberItem) @@ -256,7 +382,7 @@ def refreshLabelList(self): table = self.labelListTable table.clearContents() table.setRowCount(len(self.labelList)) - table.setColumnCount(3) + table.setColumnCount(4) for idx, lab in enumerate(self.labelList): numberItem = QTableWidgetItem(str(lab[0])) numberItem.setFlags(QtCore.Qt.ItemIsEnabled) @@ -267,26 +393,42 @@ def refreshLabelList(self): colorItem.setBackground(QtGui.QColor(c[0], c[1], c[2])) colorItem.setFlags(QtCore.Qt.ItemIsEnabled) table.setItem(idx, 2, colorItem) - - for idx in range(2): + here = osp.dirname(osp.abspath(__file__)) + delItem = QTableWidgetItem() + delItem.setIcon(util.newIcon("clear")) + delItem.setTextAlignment(Qt.AlignCenter) + delItem.setFlags(QtCore.Qt.ItemIsEnabled) + table.setItem(idx, 3, delItem) + + cols = [0, 1, 3] + for idx in cols: table.resizeColumnToContents(idx) - def changeLabelColor(row, col): - print(row, col) - if col != 2: - return - color = QtWidgets.QColorDialog.getColor() - # BUG: 判断颜色没变 - print(color.getRgb()) - table.item(row, col).setBackground(color) - self.labelList[row][2] = color.getRgb()[:3] - if self.controller: - self.controller.label_list = self.labelList + def labelListDoubleClick(self, row, col): + print("cell double clicked", row, col) + if col != 2: + return + table = self.labelListTable + color = QtWidgets.QColorDialog.getColor() + # BUG: 判断颜色没变 + print(color.getRgb()) + table.item(row, col).setBackground(color) + self.labelList[row][2] = color.getRgb()[:3] + if self.controller: + self.controller.label_list = self.labelList - table.cellDoubleClicked.connect(changeLabelColor) + # table.cellDoubleClicked.connect(changeLabelColor) - def cellClicked(row, col): - print("cell clicked", row, col) + def labelListClicked(self, row, col): + print("cell clicked", row, col) + table = self.labelListTable + if col == 3: + table.removeRow(row) + del self.labelList[row] + if col == 0 or col == 1: + for idx in range(len(self.labelList)): + table.item(idx, 0).setBackground(QtGui.QColor(255, 255, 255)) + table.item(row, 0).setBackground(QtGui.QColor(48, 140, 198)) for idx in range(3): table.item(row, idx).setSelected(True) if self.controller: @@ -294,7 +436,14 @@ def cellClicked(row, col): self.controller.change_label_num(int(table.item(row, 0).text())) self.controller.label_list = self.labelList - table.cellClicked.connect(cellClicked) + # table.cellClicked.connect(cellClicked) + + def labelListItemChanged(self, row, col): + print("cell changed", row, col) + if col != 1: + return + name = self.labelListTable.item(row, col).text() + self.labelList[row][1] = name def openImage(self): formats = [ @@ -313,6 +462,23 @@ def openImage(self): self.loadFile(file_path) self.imagePath = file_path + def loadLabel(self, imgPath): + if imgPath == "" or len(self.labelPaths) == 0: + return None + + def getName(path): + return osp.basename(path).split(".")[0] + + imgName = getName(imgPath) + for path in self.labelPaths: + if getName(path) == imgName: + labPath = path + print(labPath) + break + label = cv2.imread(path, cv2.IMREAD_UNCHANGED) + print("label shape", label.shape) + return label + def loadFile(self, path): if len(path) == 0 or not osp.exists(path): return @@ -320,18 +486,12 @@ def loadFile(self, path): image = cv2.imdecode(np.fromfile(path, dtype=np.uint8), 1) image = image[:, :, ::-1] # BGR转RGB self.image = image - imgName = osp.basename(path).split(".")[0] - # TODO: 专门搞一个getlabel的方法 - # if self.outputDir: - # for labelName in self.labelFiles: - # if osp.basename(labelName).split(".")[0] == imgName: - # label = cv2.imdecode(np.fromfile(labelName, dtype=np.uint8), 1) - # print(label.shape) - # break if self.controller: self.controller.set_image(self.image) else: self.changeModel(0) + self.controller.label_list = self.labelList + self.controller.set_label(self.loadLabel(path)) def openFolder(self): self.inputDir = QtWidgets.QFileDialog.getExistingDirectory( @@ -351,7 +511,6 @@ def openFolder(self): self.listFiles.addItems(self.filePaths) self.currIdx = 0 self.turnImg(0) - # self.loadFile(self.filePaths[0]) def listClicked(self): if self.controller.is_incomplete_mask: @@ -368,7 +527,8 @@ def turnImg(self, delta): return if not self.controller: self.changeModel(0) - if self.controller.is_incomplete_mask: + self.completeMask() + if self.actions.auto_save.isChecked(): self.saveLabel() imagePath = self.filePaths[self.currIdx] self.loadFile(imagePath) @@ -384,58 +544,65 @@ def finishObject(self): return self.controller.finish_object() - def saveLabel(self): - if self.controller.image is None: - return - + def completeMask(self): if self.controller.is_incomplete_mask: - # TODO: 如果没选,直接esc,什么也不做 msg = QMessageBox() msg.setIcon(QMessageBox.Warning) - msg.setWindowTitle("保存最后一个目标?") - msg.setText("最后一个目标尚未完成标注,是否进行保存?") - # msg.setInformativeText("") - # msg.setDetailedText("The details are as follows:") + msg.setWindowTitle("完成最后一个目标?") + msg.setText("最后一个目标尚未完成标注,是否完成标注?") msg.setStandardButtons(QMessageBox.Yes | QMessageBox.Cancel) - # msg.buttonClicked.connect() res = msg.exec_() - print(QMessageBox.Yes, res) if res == QMessageBox.Yes: - print("Yes") self.finishObject() + return True + return False + return True + + def saveLabel(self, saveAs=False, savePath=None): + if not self.controller: + return + if self.controller.image is None: + return + self.completeMask() + if not savePath: + if not saveAs and self.outputDir is not None: + savePath = osp.join( + self.outputDir, osp.basename(self.imagePath).split(".")[0] + ".png" + ) else: - return + filters = self.tr("Label files (*.png)") + # BUG: 默认打开路径有问题 + dlg = QtWidgets.QFileDialog( + self, "保存标签文件路径", osp.dirname(self.imagePath), filters + ) + dlg.setDefaultSuffix("png") + dlg.setAcceptMode(QtWidgets.QFileDialog.AcceptSave) + dlg.setOption(QtWidgets.QFileDialog.DontConfirmOverwrite, False) + dlg.setOption(QtWidgets.QFileDialog.DontUseNativeDialog, False) + savePath, _ = dlg.getSaveFileName( + self, + self.tr("选择标签文件保存路径"), + osp.basename(self.imagePath).split(".")[0] + ".png", + ) + if ( + savePath is None + or len(savePath) == 0 + or not osp.exists(osp.dirname(savePath)) + ): + return - if not self.outputDir: - filters = self.tr("Label files (*.png)") - # BUG: 默认打开路径有问题 - dlg = QtWidgets.QFileDialog( - self, "保存标签文件路径", osp.dirname(self.imagePath), filters - ) - dlg.setDefaultSuffix("png") - dlg.setAcceptMode(QtWidgets.QFileDialog.AcceptSave) - dlg.setOption(QtWidgets.QFileDialog.DontConfirmOverwrite, False) - dlg.setOption(QtWidgets.QFileDialog.DontUseNativeDialog, False) - savePath, _ = dlg.getSaveFileName( - self, - self.tr("选择标签文件保存路径"), - osp.basename(self.imagePath).split(".")[0] + ".png", - ) - if ( - savePath is None - or len(savePath) == 0 - or not osp.exists(osp.dirname(savePath)) - ): - return - else: - savePath = osp.join( - self.outputDir, osp.basename(self.imagePath).split(".")[0] + ".png" - ) - print(self.controller.result_mask.shape) cv2.imwrite(savePath, self.controller.result_mask) + self.setClean() + self.statusbar.showMessage(f"标签成功保存至 {savePath}") + + def setClean(self): + self.isDirty = False + + def setDirty(self): + self.isDirty = True def changeOutputDir(self): - self.outputDir = QtWidgets.QFileDialog.getExistingDirectory( + outputDir = QtWidgets.QFileDialog.getExistingDirectory( self, self.tr("%s - 选择标签保存路径") % __appname__, # osp.dirname(self.imagePath), @@ -443,6 +610,15 @@ def changeOutputDir(self): QtWidgets.QFileDialog.ShowDirsOnly | QtWidgets.QFileDialog.DontResolveSymlinks, ) + if len(outputDir) == 0 or not osp.exists(outputDir): + return False + labelPaths = os.listdir(outputDir) + exts = ["png"] + labelPaths = [n for n in labelPaths if n.split(".")[-1] in exts] + labelPaths = [osp.join(outputDir, n) for n in labelPaths] + self.outputDir = outputDir + self.labelPaths = labelPaths + return True def maskOpacityChanged(self): self.sldOpacity.textLab.setText(str(self.opacity)) @@ -475,7 +651,8 @@ def canvasClick(self, x, y, isLeft): return if x < 0 or y < 0: return - if not self.controller.curr_label_number: + currLabel = self.controller.curr_label_number + if not currLabel or currLabel == 0: msg = QMessageBox() msg.setIcon(QMessageBox.Warning) msg.setWindowTitle("未选择当前标签") @@ -489,18 +666,6 @@ def canvasClick(self, x, y, isLeft): return self.controller.add_click(x, y, isLeft) - @property - def opacity(self): - return self.sldOpacity.value() / 10 - - @property - def click_radius(self): - return self.sldClickRadius.value() - - @property - def seg_thresh(self): - return self.sldThresh.value() / 10 - def _update_image(self, reset_canvas=False): image = self.controller.get_visualization( alpha_blend=self.opacity, @@ -510,22 +675,22 @@ def _update_image(self, reset_canvas=False): bytesPerLine = 3 * width image = QImage(image.data, width, height, bytesPerLine, QImage.Format_RGB888) if reset_canvas: - self.zoom_restart(width, height) + self.resetScene(width, height) self.scene.addPixmap(QPixmap(image)) # TODO: 研究是否有类似swap的更高效方式 self.scene.removeItem(self.scene.items()[1]) - # 确认点击 - def check_click(self): - print(self.sender().text()) - - # 当前打开的模型名称或类别更新 - def update_model_name(self): - self.labModelName.setText(self.sender().text()) - self.check_click() + # # 确认点击 + # def check_click(self): + # print(self.sender().text()) + # + # # 当前打开的模型名称或类别更新 + # def update_model_name(self): + # self.labModelName.setText(self.sender().text()) + # self.check_click() # 界面缩放重置 - def zoom_restart(self, width, height): + def resetScene(self, width, height): # 每次加载图像前设定下当前的显示框,解决图像缩小后不在中心的问题 self.scene.setSceneRect(0, 0, width, height) # 缩放清除 @@ -542,3 +707,15 @@ def zoom_restart(self, width, height): else: self.canvas.zoom_all = scr_cont[0] self.canvas.scale(self.canvas.zoom_all, self.canvas.zoom_all) + + @property + def opacity(self): + return self.sldOpacity.value() / 10 + + @property + def click_radius(self): + return self.sldClickRadius.value() + + @property + def seg_thresh(self): + return self.sldThresh.value() / 10 diff --git a/iann/controller.py b/iann/controller.py index de24286bb..b18a94ddb 100644 --- a/iann/controller.py +++ b/iann/controller.py @@ -77,9 +77,16 @@ def add_click(self, x, y, is_positive): self.probs_history.append((self.probs_history[-1][0], pred)) else: self.probs_history.append((np.zeros_like(pred), pred)) - self.update_image_callback() + def set_label(self, label): + # if label is None: + # return + # self.probs_history.append((np.zeros_like(label), label)) + # print("len", len(self.probs_history)) + # self.update_image_callback() + pass + def undo_click(self): """undo一步点击""" if not self.states: # 如果还没点 @@ -164,7 +171,7 @@ def reset_predictor(self, net=None, predictor_params=None): predictor_params : 网络权重 新的网络权重 """ - # print("resetting", self.image.shape) + print("palette", self.palette) if net is not None: self.net = net if predictor_params is not None: diff --git a/iann/util/__init__.py b/iann/util/__init__.py index d1634bf0a..e2872af7a 100644 --- a/iann/util/__init__.py +++ b/iann/util/__init__.py @@ -1 +1,2 @@ -from .qt import newAction, addActions, struct +from .qt import newAction, addActions, struct, newIcon +from .label import saveLabel, readLabel diff --git a/iann/util/label.py b/iann/util/label.py new file mode 100644 index 000000000..fbf7e5d01 --- /dev/null +++ b/iann/util/label.py @@ -0,0 +1,40 @@ +def toint(seq): + for idx in range(len(seq)): + try: + seq[idx] = int(seq[idx]) + except ValueError: + pass + return seq + + +def saveLabel(labelList, path): + # labelList = [[1, "人", [0, 0, 0]], [2, "车", [128, 128, 128]]] + with open(path, "w") as f: + for l in labelList: + for idx in range(2): + print(l[idx], end=" ", file=f) + for idx in range(3): + print(l[2][idx], end=" ", file=f) + print(file=f) + + +# saveLabel("label.txt") + + +def readLabel(path): + with open(path, "r") as f: + labels = f.readlines() + labelList = [] + for lab in labels: + lab = lab.replace("\n", "").strip(" ").split(" ") + if len(lab) != 2 and len(lab) != 5: + print("标签不合法") + continue + label = toint(lab[:2]) + label.append(toint(lab[2:])) + labelList.append(label) + + return labelList + + +# readLabel("label.txt") diff --git a/iann/util/poly.py b/iann/util/poly.py new file mode 100644 index 000000000..a9332318e --- /dev/null +++ b/iann/util/poly.py @@ -0,0 +1,21 @@ +import cv2 +import matplotlib.pyplot as plt +import numpy as np + +lab = cv2.imread("../../dzq.png", cv2.IMREAD_UNCHANGED) + +plt.imshow(lab) +plt.show() + +poly, hierarchy = cv2.findContours(lab, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + +result = np.zeros_like(lab) +print(result.shape) +print(poly) +for p in poly[0]: + p = p[0] + print(p[0], p[1]) + result[p[1], p[0]] = 1 + +plt.imshow(result) +plt.show() diff --git a/iann/util/util.py b/iann/util/util.py index 2eeb520ca..b5c210c80 100644 --- a/iann/util/util.py +++ b/iann/util/util.py @@ -13,9 +13,21 @@ from albumentations.augmentations import functional as Func +def toint(seq): + for idx in range(len(seq)): + try: + seq[idx] = int(seq[idx]) + except ValueError: + pass + return seq + + +# TODO: 精简这里的函数,只留推理的 + + def SyncBatchNorm(*args, **kwargs): """In cpu environment nn.SyncBatchNorm does not have kernel so use nn.BatchNorm2D instead""" - if paddle.distributed.ParallelEnv().nranks == 1: + if paddle.distributed.ParallelEnv().nranks == 1: return nn.BatchNorm2D(*args, **kwargs) else: return nn.SyncBatchNorm(*args, **kwargs) @@ -53,8 +65,12 @@ def expand_bbox(bbox, expand_ratio, min_crop_size=None): def clamp_bbox(bbox, rmin, rmax, cmin, cmax): - return (max(rmin, bbox[0]), min(rmax, bbox[1]), - max(cmin, bbox[2]), min(cmax, bbox[3])) + return ( + max(rmin, bbox[0]), + min(rmax, bbox[1]), + max(cmin, bbox[2]), + min(cmax, bbox[3]), + ) def get_bbox_iou(b1, b2): @@ -78,36 +94,40 @@ def get_dims_with_exclusion(dim, exclude=None): return dims - def get_iou(gt_mask, pred_mask, ignore_label=-1): ignore_gt_mask_inv = gt_mask != ignore_label obj_gt_mask = gt_mask == 1 - intersection = np.logical_and(np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum() + intersection = np.logical_and( + np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv + ).sum() - union = np.logical_and(np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum() + union = np.logical_and( + np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv + ).sum() return intersection / union def get_dataset(dataset_name, cfg): - if dataset_name == 'GrabCut': - dataset = GrabCutDataset('./datasets/GrabCut') - elif dataset_name == 'Berkeley': - dataset = BerkeleyDataset('./datasets/Berkeley') - elif dataset_name == 'DAVIS': - dataset = DavisDataset('./datasets/DAVIS') - elif dataset_name == 'COCO_MVal': - dataset = DavisDataset('./datasets/COCO_MVal') - elif dataset_name == 'SBD': - dataset = SBDEvaluationDataset('./datasets/SBD') - elif dataset_name == 'SBD_Train': - dataset = SBDEvaluationDataset('./datasets/SBD', split='train') + if dataset_name == "GrabCut": + dataset = GrabCutDataset("./datasets/GrabCut") + elif dataset_name == "Berkeley": + dataset = BerkeleyDataset("./datasets/Berkeley") + elif dataset_name == "DAVIS": + dataset = DavisDataset("./datasets/DAVIS") + elif dataset_name == "COCO_MVal": + dataset = DavisDataset("./datasets/COCO_MVal") + elif dataset_name == "SBD": + dataset = SBDEvaluationDataset("./datasets/SBD") + elif dataset_name == "SBD_Train": + dataset = SBDEvaluationDataset("./datasets/SBD", split="train") else: dataset = None return dataset + def get_time_metrics(all_ious, elapsed_time): n_images = len(all_ious) n_clicks = sum(map(len, all_ious)) @@ -117,6 +137,7 @@ def get_time_metrics(all_ious, elapsed_time): return mean_spc, mean_spi + def compute_noc_metric(all_ious, iou_thrs, max_clicks=20): def _get_noc(iou_arr, iou_thr): vals = iou_arr >= iou_thr @@ -125,8 +146,9 @@ def _get_noc(iou_arr, iou_thr): noc_list = [] over_max_list = [] for iou_thr in iou_thrs: - scores_arr = np.array([_get_noc(iou_arr, iou_thr) - for iou_arr in all_ious], dtype=np.int) + scores_arr = np.array( + [_get_noc(iou_arr, iou_thr) for iou_arr in all_ious], dtype=np.int + ) score = scores_arr.mean() over_max = (scores_arr == max_clicks).sum() @@ -137,38 +159,50 @@ def _get_noc(iou_arr, iou_thr): return noc_list, over_max_list -def get_results_table(noc_list, over_max_list, brs_type, dataset_name, mean_spc, elapsed_time, - n_clicks=20, model_name=None): - table_header = (f'|{"BRS Type":^13}|{"Dataset":^11}|' - f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|' - f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|' - f'{"SPC,s":^7}|{"Time":^9}|') +def get_results_table( + noc_list, + over_max_list, + brs_type, + dataset_name, + mean_spc, + elapsed_time, + n_clicks=20, + model_name=None, +): + table_header = ( + f'|{"BRS Type":^13}|{"Dataset":^11}|' + f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|' + f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|' + f'{"SPC,s":^7}|{"Time":^9}|' + ) row_width = len(table_header) - header = f'Eval results for model: {model_name}\n' if model_name is not None else '' - header += '-' * row_width + '\n' - header += table_header + '\n' + '-' * row_width + header = f"Eval results for model: {model_name}\n" if model_name is not None else "" + header += "-" * row_width + "\n" + header += table_header + "\n" + "-" * row_width eval_time = str(timedelta(seconds=int(elapsed_time))) - table_row = f'|{brs_type:^13}|{dataset_name:^11}|' - table_row += f'{noc_list[0]:^9.2f}|' - table_row += f'{noc_list[1]:^9.2f}|' if len(noc_list) > 1 else f'{"?":^9}|' - table_row += f'{noc_list[2]:^9.2f}|' if len(noc_list) > 2 else f'{"?":^9}|' - table_row += f'{over_max_list[1]:^9}|' if len(noc_list) > 1 else f'{"?":^9}|' - table_row += f'{over_max_list[2]:^9}|' if len(noc_list) > 2 else f'{"?":^9}|' - table_row += f'{mean_spc:^7.3f}|{eval_time:^9}|' + table_row = f"|{brs_type:^13}|{dataset_name:^11}|" + table_row += f"{noc_list[0]:^9.2f}|" + table_row += f"{noc_list[1]:^9.2f}|" if len(noc_list) > 1 else f'{"?":^9}|' + table_row += f"{noc_list[2]:^9.2f}|" if len(noc_list) > 2 else f'{"?":^9}|' + table_row += f"{over_max_list[1]:^9}|" if len(noc_list) > 1 else f'{"?":^9}|' + table_row += f"{over_max_list[2]:^9}|" if len(noc_list) > 2 else f'{"?":^9}|' + table_row += f"{mean_spc:^7.3f}|{eval_time:^9}|" return header, table_row + def get_eval_exp_name(args): - if ':' in args.checkpoint: - model_name, checkpoint_prefix = args.checkpoint.split(':') - model_name = model_name.split('/')[-1] + if ":" in args.checkpoint: + model_name, checkpoint_prefix = args.checkpoint.split(":") + model_name = model_name.split("/")[-1] return f"{model_name}_{checkpoint_prefix}" else: return Path(args.checkpoint).stem - + + def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49): assert click_indx > 0 pred = pred.numpy()[:, 0, :, :] @@ -177,8 +211,8 @@ def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49): fn_mask = np.logical_and(gt, pred < pred_thresh) fp_mask = np.logical_and(np.logical_not(gt), pred > pred_thresh) - fn_mask = np.pad(fn_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8) - fp_mask = np.pad(fp_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8) + fn_mask = np.pad(fn_mask, ((0, 0), (1, 1), (1, 1)), "constant").astype(np.uint8) + fp_mask = np.pad(fp_mask, ((0, 0), (1, 1), (1, 1)), "constant").astype(np.uint8) num_points = points.shape[1] // 2 points = points.clone() @@ -198,28 +232,39 @@ def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49): if is_positive: points[bindx, num_points - click_indx, 0] = float(coords[0]) points[bindx, num_points - click_indx, 1] = float(coords[1]) - #points[bindx, num_points - click_indx, 2] = float(click_indx) + # points[bindx, num_points - click_indx, 2] = float(click_indx) else: points[bindx, 2 * num_points - click_indx, 0] = float(coords[0]) points[bindx, 2 * num_points - click_indx, 1] = float(coords[1]) - #points[bindx, 2 * num_points - click_indx, 2] = float(click_indx) + # points[bindx, 2 * num_points - click_indx, 2] = float(click_indx) return points + class UniformRandomResize(DualTransform): - def __init__(self, scale_range=(0.9, 1.1), interpolation=cv2.INTER_LINEAR, always_apply=False, p=1): + def __init__( + self, + scale_range=(0.9, 1.1), + interpolation=cv2.INTER_LINEAR, + always_apply=False, + p=1, + ): super().__init__(always_apply, p) self.scale_range = scale_range self.interpolation = interpolation def get_params_dependent_on_targets(self, params): scale = random.uniform(*self.scale_range) - height = int(round(params['image'].shape[0] * scale)) - width = int(round(params['image'].shape[1] * scale)) - return {'new_height': height, 'new_width': width} + height = int(round(params["image"].shape[0] * scale)) + width = int(round(params["image"].shape[1] * scale)) + return {"new_height": height, "new_width": width} - def apply(self, img, new_height=0, new_width=0, interpolation=cv2.INTER_LINEAR, **params): - return Func.resize(img, height=new_height, width=new_width, interpolation=interpolation) + def apply( + self, img, new_height=0, new_width=0, interpolation=cv2.INTER_LINEAR, **params + ): + return Func.resize( + img, height=new_height, width=new_width, interpolation=interpolation + ) def apply_to_keypoint(self, keypoint, new_height=0, new_width=0, **params): scale_x = new_width / params["cols"] @@ -231,4 +276,4 @@ def get_transform_init_args_names(self): @property def targets_as_params(self): - return ["image"] \ No newline at end of file + return ["image"] diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..2129e141d --- /dev/null +++ b/setup.py @@ -0,0 +1,162 @@ +from __future__ import print_function + +import distutils.spawn +import os +import re +import shlex +import subprocess +import sys + +from setuptools import find_packages +from setuptools import setup + + +def get_version(): + filename = "labelx/__init__.py" + with open(filename) as f: + match = re.search(r"""^__version__ = ['"]([^'"]*)['"]""", f.read(), re.M) + if not match: + raise RuntimeError("{} doesn't contain __version__".format(filename)) + version = match.groups()[0] + return version + + +def get_install_requires(): + PY3 = sys.version_info[0] == 3 + PY2 = sys.version_info[0] == 2 + assert PY3 or PY2 + + install_requires = [ + "imgviz>=0.11.0", + "matplotlib", # for PyInstaller + "numpy", + "Pillow>=2.8.0", + "PyYAML", + "qtpy", + "termcolor", + "nibabel", + "watchpoints", + "SimpleITK", + "pyqt5", + "scikit-image", + ] + + # Find python binding for qt with priority: + # PyQt5 -> PySide2 -> PyQt4, + # and PyQt5 is automatically installed on Python3. + QT_BINDING = None + + try: + import PyQt5 # NOQA + + QT_BINDING = "pyqt5" + except ImportError: + pass + + if QT_BINDING is None: + try: + import PySide2 # NOQA + + QT_BINDING = "pyside2" + except ImportError: + pass + + if QT_BINDING is None: + try: + import PyQt4 # NOQA + + QT_BINDING = "pyqt4" + except ImportError: + if PY2: + print( + "Please install PyQt5, PySide2 or PyQt4 for Python2.\n" + "Note that PyQt5 can be installed via pip for Python3.", + file=sys.stderr, + ) + sys.exit(1) + assert PY3 + # PyQt5 can be installed via pip for Python3 + install_requires.append("PyQt5") + QT_BINDING = "pyqt5" + del QT_BINDING + + if os.name == "nt": # Windows + install_requires.append("colorama") + + return install_requires + + +def get_long_description(): + with open("README.md", encoding="utf-8") as f: + long_description = f.read() + try: + import github2pypi + + return github2pypi.replace_url(slug="wkentaro/labelx", content=long_description) + except Exception: + return long_description + + +def main(): + version = get_version() + + if sys.argv[1] == "release": + if not distutils.spawn.find_executable("twine"): + print( + "Please install twine:\n\n\tpip install twine\n", + file=sys.stderr, + ) + sys.exit(1) + + commands = [ + "python tests/docs_tests/man_tests/test_labelx_1.py", + "git tag v{:s}".format(version), + "git push origin master --tag", + "python setup.py sdist", + "twine upload dist/labelx-{:s}.tar.gz".format(version), + ] + for cmd in commands: + subprocess.check_call(shlex.split(cmd)) + sys.exit(0) + + setup( + name="labelx", + version=version, + packages=find_packages(exclude=["github2pypi"]), + description="Image Polygonal Annotation with Python", + long_description=get_long_description(), + long_description_content_type="text/markdown", + author="Kentaro Wada", + author_email="www.kentaro.wada@gmail.com", + url="https://github.com/wkentaro/labelx", + install_requires=get_install_requires(), + license="GPLv3", + keywords="Image Annotation, Machine Learning", + classifiers=[ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Natural Language :: English", + "Programming Language :: Python", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + ], + package_data={"labelx": ["icons/*", "config/*.yaml"]}, + entry_points={ + "console_scripts": [ + "labelx=labelx.__main__:main", + "labelx_draw_json=labelx.cli.draw_json:main", + "labelx_draw_label_png=labelx.cli.draw_label_png:main", + "labelx_json_to_dataset=labelx.cli.json_to_dataset:main", + "labelx_on_docker=labelx.cli.on_docker:main", + ], + }, + data_files=[("share/man/man1", ["docs/man/labelx.1"])], + ) + + +if __name__ == "__main__": + main() diff --git a/test.txt b/test.txt new file mode 100644 index 000000000..21cad8dd6 --- /dev/null +++ b/test.txt @@ -0,0 +1,2 @@ +1 人 255 120 0 +2 车 128 128 128