Skip to content

Commit

Permalink
初步完成多标签不同颜色
Browse files Browse the repository at this point in the history
  • Loading branch information
linhandev committed May 8, 2021
1 parent f2d3d0f commit a7e2652
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 104 deletions.
106 changes: 51 additions & 55 deletions iann/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,47 +45,12 @@ def __init__(self, parent=None):
self.statusbar.showMessage("模型未加载")

self.initActions()
# TODO: 按照labelme的方式用action
## 菜单栏点击
# for menu_act in self.menuBar.actions():
# if menu_act.text() == "文件":
# for ac_act in menu_act.menu().actions():
# if ac_act.text() == "加载图像":
# ac_act.triggered.connect(self.openImage)
# else:
# ac_act.triggered.connect(self.openFolder)
# elif menu_act.text() == "设置":
# for ac_act in menu_act.menu().actions():
# if ac_act.text() == "设置保存路径":
# ac_act.triggered.connect(self.changeOutputDir)
# else:
# ac_act.triggered.connect(self.check_click)
# elif menu_act.text() == "帮助":
# for ac_act in menu_act.menu().actions():
# if ac_act.text() == "快速上手":
# ac_act.triggered.connect(self.help_dialog.show)
# else:
# ac_act.triggered.connect(self.check_click)
#
# ## 工具栏点击
# for tool_act in self.toolBar.actions():
# if tool_act.text() == "完成当前":
# tool_act.triggered.connect(self.finishObject)
# elif tool_act.text() == "清除全部":
# tool_act.triggered.connect(self.undoAll)
# elif tool_act.text() == "撤销":
# tool_act.triggered.connect(self.undoClick)
# elif tool_act.text() == "重做":
# tool_act.triggered.connect(self.check_click)
# elif tool_act.text() == "上一张":
# tool_act.triggered.connect(partial(self.turnImg, -1))
# elif tool_act.text() == "下一张":
# tool_act.triggered.connect(partial(self.turnImg, 1))

## 按钮点击
self.btnSave.clicked.connect(self.saveLabel) # 保存
self.listFiles.itemDoubleClicked.connect(self.listClicked) # list选择
self.comboModelSelect.currentIndexChanged.connect(self.changeModel) # 模型选择
self.btnAddClass.clicked.connect(self.addLabel)

# 滑动
self.sldOpacity.valueChanged.connect(self.maskOpacityChanged)
Expand All @@ -100,13 +65,13 @@ def toBeImplemented(self):
self.statusbar.showMessage("功能尚在开发")
pass

def menu(self, title, actions=None):
menu = self.menuBar().addMenu(title)
if actions:
util.addActions(menu, actions)
return menu

def initActions(self):
def menu(title, actions=None):
menu = self.menuBar().addMenu(title)
if actions:
util.addActions(menu, actions)
return menu

action = partial(util.newAction, self)
shortcuts = {
"turn_next": "F",
Expand Down Expand Up @@ -210,14 +175,7 @@ def initActions(self):
"redo",
self.tr("重做一次点击"),
)
# finish_object = action(
# self.tr("&N^2宫格标注"),
# self.toBeImplemented,
# None,
# # TODO: 搞个图
# "",
# self.tr("使用N^2宫格进行细粒度标注"),
# )

# TODO: 改用manager
self.actions = util.struct(
turn_next=turn_next,
Expand All @@ -236,9 +194,9 @@ def initActions(self):
labelMenu=(grid_ann,),
toolBar=(finish_object, clear, undo, redo, turn_prev, turn_next),
)
self.menu("文件", self.actions.fileMenu)
self.menu("标注", self.actions.labelMenu)
self.menu("帮助", self.actions.helpMenu)
menu("文件", self.actions.fileMenu)
menu("标注", self.actions.labelMenu)
menu("帮助", self.actions.helpMenu)
util.addActions(self.toolBar, self.actions.toolBar)

def changeOutputDir(self, dir=None):
Expand Down Expand Up @@ -278,9 +236,24 @@ def changeModel(self, idx):

self.statusbar.showMessage(f"{ models[idx].name}模型加载完成", 5000)

def addLabel(self):
table = self.labelListTable
table.insertRow(table.rowCount())
idx = table.rowCount() - 1
print(idx)
numberItem = QTableWidgetItem(str(idx + 1))
numberItem.setFlags(QtCore.Qt.ItemIsEnabled)
table.setItem(idx, 0, numberItem)
table.setItem(idx, 1, QTableWidgetItem())
c = [255, 255, 255]
colorItem = QTableWidgetItem()
colorItem.setBackground(QtGui.QColor(c[0], c[1], c[2]))
colorItem.setFlags(QtCore.Qt.ItemIsEnabled)
table.setItem(idx, 2, colorItem)
self.labelList.append([idx + 1, "", [255, 255, 255]])

def refreshLabelList(self):
table = self.labelListTable
# TODO: 添加表头
table.clearContents()
table.setRowCount(len(self.labelList))
table.setColumnCount(3)
Expand All @@ -303,12 +276,26 @@ def changeLabelColor(row, col):
if col != 2:
return
color = QtWidgets.QColorDialog.getColor()
# BUG: 判断颜色没变
print(color.getRgb())
table.item(row, col).setBackground(color)
self.labelList[row][2] = color.getRgb()[:3]
if self.controller:
self.controller.label_list = self.labelList

table.cellDoubleClicked.connect(changeLabelColor)

def cellClicked(row, col):
print("cell clicked", row, col)
for idx in range(3):
table.item(row, idx).setSelected(True)
if self.controller:
print(int(table.item(row, 0).text()))
self.controller.change_label_num(int(table.item(row, 0).text()))
self.controller.label_list = self.labelList

table.cellClicked.connect(cellClicked)

def openImage(self):
formats = [
"*.{}".format(fmt.data().decode())
Expand Down Expand Up @@ -471,7 +458,7 @@ def threshChanged(self):
self._update_image()

def undoClick(self):
if not self.image:
if self.image is None:
return
self.controller.undo_click()

Expand All @@ -488,6 +475,15 @@ def canvasClick(self, x, y, isLeft):
return
if x < 0 or y < 0:
return
if not self.controller.curr_label_number:
msg = QMessageBox()
msg.setIcon(QMessageBox.Warning)
msg.setWindowTitle("未选择当前标签")
msg.setText("请先在标签列表中单击点选标签")
msg.setStandardButtons(QMessageBox.Yes)
res = msg.exec_()
return

s = self.controller.img_size
if x > s[0] or y > s[1]:
return
Expand Down
99 changes: 64 additions & 35 deletions iann/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ def __init__(self, net, predictor_params, update_image_callback, prob_thresh=0.5
self.clicker = clicker.Clicker()
self.states = []
self.probs_history = []
self.object_count = 0
self.curr_label_number = 0
self._result_mask = None
self.label_list = None # 存标签编号和颜色的对照

self.image = None
self.image_nd = None
Expand All @@ -41,12 +42,10 @@ def set_image(self, image):
self.image_nd = input_transform(image)[0]

self._result_mask = np.zeros(image.shape[:2], dtype=np.uint8)
self.object_count = 0
self.curr_label_number = 0
self.reset_last_object(update_image=False)
self.update_image_callback(reset_canvas=True)

# def change_alpha(self, alpha_blend):
# self.
def add_click(self, x, y, is_positive):
"""添加一个点
跑推理,保存历史用于undo
Expand Down Expand Up @@ -114,11 +113,27 @@ def finish_object(self):
if object_prob is None:
return

self.object_count += 1 # TODO: 当前是按照第几个目标给结果中的数,改成根据目标编号
# self.curr_label_number += 1 # TODO: 当前是按照第几个目标给结果中的数,改成根据目标编号
object_mask = object_prob > self.prob_thresh
self._result_mask[object_mask] = self.object_count
self._result_mask[object_mask] = self.curr_label_number
self.reset_last_object()

def change_label_num(self, number):
"""修改当前标签的编号
如果当前有标注到一半的目标,改mask。
如果没有,下一个目标是这个数
Parameters
----------
number : int
换成目标的编号
"""
assert isinstance(number, int), "标签编号应为整数"
self.curr_label_number = number
if self.is_incomplete_mask:
pass
# TODO: 改当前mask的编号

def reset_last_object(self, update_image=True):
"""重置控制器状态
Expand Down Expand Up @@ -158,6 +173,49 @@ def reset_predictor(self, net=None, predictor_params=None):
if self.image_nd is not None:
self.predictor.set_input_image(self.image_nd)

def get_visualization(self, alpha_blend, click_radius):
if self.image is None:
return None

# 1. 画当前没标完的mask
results_mask_for_vis = self.result_mask
if self.probs_history:
results_mask_for_vis[
self.current_object_prob > self.prob_thresh
] = self.curr_label_number

vis = draw_with_blend_and_clicks(
self.image,
mask=results_mask_for_vis,
alpha=alpha_blend,
clicks_list=self.clicker.clicks_list,
radius=click_radius,
palette=self.palette,
)

# 2. 在图片和当前mask的基础上画之前标完的mask
if self.probs_history:
total_mask = self.probs_history[-1][0] > self.prob_thresh
results_mask_for_vis[np.logical_not(total_mask)] = 0
vis = draw_with_blend_and_clicks(
vis,
mask=results_mask_for_vis,
alpha=alpha_blend,
palette=self.palette,
)

return vis

@property
def palette(self):
if self.label_list:
colors = [l[2] for l in self.label_list]
colors.insert(0, [0, 0, 0])
else:
colors = [[0, 0, 0]]
print(colors)
return colors

@property
def current_object_prob(self):
"""获取当前推理标签"""
Expand Down Expand Up @@ -185,32 +243,3 @@ def result_mask(self):
def img_size(self):
print(self.image.shape)
return self.image.shape[1::-1]

def get_visualization(self, alpha_blend, click_radius):
if self.image is None:
return None

# 1. 画当前没标完的mask
results_mask_for_vis = self.result_mask
if self.probs_history:
results_mask_for_vis[self.current_object_prob > self.prob_thresh] = (
self.object_count + 1
)

vis = draw_with_blend_and_clicks(
self.image,
mask=results_mask_for_vis,
alpha=alpha_blend,
clicks_list=self.clicker.clicks_list,
radius=click_radius,
)

# 2. 在图片和当前mask的基础上画之前标完的mask
if self.probs_history:
total_mask = self.probs_history[-1][0] > self.prob_thresh
results_mask_for_vis[np.logical_not(total_mask)] = 0
vis = draw_with_blend_and_clicks(
vis, mask=results_mask_for_vis, alpha=alpha_blend
)

return vis
1 change: 1 addition & 0 deletions iann/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def setupUi(self, MainWindow):
# 标签列表
labelListLab = self.create_text(CentralWidget, "labelListLab", "标签列表")
listRegion.addWidget(labelListLab)
# TODO: 改成 list widget
self.labelListTable = QtWidgets.QTableWidget(CentralWidget)
self.labelListTable.horizontalHeader().hide()
self.labelListTable.verticalHeader().hide()
Expand Down
Loading

0 comments on commit a7e2652

Please sign in to comment.