-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/PaddleCV-SIG/iann
- Loading branch information
Showing
68 changed files
with
615 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from pathlib import Path | ||
|
||
import cv2 | ||
import numpy as np | ||
|
||
from .base import ISDataset | ||
|
||
|
||
class MyDataset(ISDataset):
    """Interactive-segmentation dataset stored as parallel image/mask folders.

    Layout: ``<dataset_path>/<folder_name>/<images_dir_name>`` holds the
    images and ``<dataset_path>/<folder_name>/<masks_dir_name>`` holds one
    mask per image, matched by file stem (the extensions may differ).
    """

    def __init__(self, dataset_path, folder_name,
                 images_dir_name, masks_dir_name,
                 **kwargs):
        super(MyDataset, self).__init__(**kwargs)

        self.dataset_path = Path(dataset_path) / folder_name
        self._images_path = self.dataset_path / images_dir_name
        self._insts_path = self.dataset_path / masks_dir_name

        # Sample order is the sorted image listing; masks are keyed by stem.
        self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))]
        self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')}

    def get_sample(self, index):
        """Load the index-th image and its binarized instance mask.

        Returns:
            dict with keys 'image' (RGB uint8 array), 'instances_mask'
            (int32 array of 0/1), 'instances_info', and 'image_id' (index).
        """
        image_name = self.dataset_samples[index]
        image_path = str(self._images_path / image_name)
        # BUGFIX: use Path.stem instead of split('.')[0] so image names that
        # contain dots (e.g. "img.v2.jpg") match the stem keys built in
        # __init__; split('.')[0] would look up "img" and miss/mismatch.
        mask_key = Path(image_name).stem
        try:
            mask_path = str(self._masks_paths[mask_key])
        except KeyError:
            raise KeyError(
                f'No mask found for image {image_name!r} in {self._insts_path}')

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Collapse a possibly colored 3-channel mask to one channel, then
        # binarize: any non-zero pixel becomes foreground (1).
        instances_mask = np.max(cv2.imread(mask_path).astype(np.int32), axis=2)
        instances_mask[instances_mask > 0] = 1

        # Binary segmentation: a single foreground instance with id 1.
        instances_ids = [1]
        instances_info = {x: {'ignore': False} for x in instances_ids}

        return {
            'image': image,
            'instances_mask': instances_mask,
            'instances_info': instances_info,
            'image_id': index
        }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import numpy as np | ||
import paddle | ||
import paddle.nn as nn | ||
import paddle.nn.functional as F | ||
import util.util as U | ||
|
||
|
||
class NormalizedFocalLossSigmoid(nn.Layer):
    """Normalized focal loss (NFL) for binary sigmoid segmentation.

    The focal term ``(1 - pt)^gamma`` is rescaled per image so that its sum
    over valid pixels equals the number of valid pixels, which stabilizes the
    loss magnitude during training (RITM-style interactive segmentation).

    Args:
        alpha: weight for the negative class via ``1 - alpha``.
        gamma: focal exponent applied to ``(1 - pt)``.
        from_logits: if True, ``pred`` is already a probability map.
        size_average: average over valid pixels instead of summing.
        detach_delimeter: stop gradients through the normalization factor.
        eps: numerical stabilizer for divisions and logs.
        scale: final multiplier on the loss.
        ignore_label: label value excluded from the loss.
    """

    def __init__(self, axis=-1, alpha=0.25, gamma=2,
                 from_logits=False, batch_axis=0,
                 weight=None, size_average=True, detach_delimeter=True,
                 eps=1e-12, scale=1.0,
                 ignore_label=-1):
        super(NormalizedFocalLossSigmoid, self).__init__()
        self._axis = axis
        self._alpha = alpha
        self._gamma = gamma
        self._ignore_label = ignore_label
        self._weight = weight if weight is not None else 1.0
        self._batch_axis = batch_axis

        self._scale = scale
        self._from_logits = from_logits
        self._eps = eps
        self._size_average = size_average
        self._detach_delimeter = detach_delimeter
        # Running (EMA) estimate of the normalization factor, for monitoring.
        self._k_sum = 0

    def forward(self, pred, label, sample_weight=None):
        # NOTE(review): the sample_weight argument is ignored and recomputed
        # from ignore_label, mirroring the upstream implementation.
        one_hot = label > 0
        sample_weight = (label != self._ignore_label).astype('float32')

        if not self._from_logits:
            pred = F.sigmoid(pred)

        # Per-pixel class weight, zeroed on ignored pixels.
        # NOTE(review): the reference implementation uses self._alpha for the
        # positive branch; 0.5 is hard-coded here — confirm this is intended.
        alpha = paddle.where(one_hot,
                             sample_weight * 0.5,
                             (1 - self._alpha) * sample_weight)
        pt = paddle.where(one_hot, pred, 1 - pred)
        # Ignored pixels get pt = 1, i.e. a zero focal term.
        valid = sample_weight.astype('bool')
        pt = paddle.where(valid, pt, paddle.ones_like(pt))
        beta = (1 - pt) ** self._gamma

        # Normalize beta so it sums to the number of valid pixels per image.
        sw_sum = paddle.sum(sample_weight, axis=(-2, -1), keepdim=True)
        beta_sum = paddle.sum(beta, axis=(-2, -1), keepdim=True)
        mult = sw_sum / (beta_sum + self._eps)
        if self._detach_delimeter:
            mult = mult.detach()
        beta = beta * mult

        # Track the average normalization factor on images without ignored
        # pixels (monitoring only; .numpy() detaches from the graph).
        ignore_area = paddle.sum(
            (label == self._ignore_label).astype('float32'),
            axis=tuple(range(1, len(label.shape)))).numpy()
        sample_mult = paddle.mean(
            mult, axis=tuple(range(1, len(mult.shape)))).numpy()
        if np.any(ignore_area == 0):
            self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean()

        # BUGFIX: the log must be element-wise. The previous code computed
        # paddle.log(paddle.mean(pt + eps)), which collapses pt to a scalar
        # and destroys the per-pixel focal loss. Clipping at 1.0 reproduces
        # the reference's min(pt + eps, 1) and keeps log(...) <= 0.
        loss = -alpha * beta * paddle.log(paddle.clip(pt + self._eps, max=1.0))
        loss = self._weight * (loss * sample_weight)

        if self._size_average:
            bsum = paddle.sum(
                sample_weight,
                axis=U.get_dims_with_exclusion(len(sample_weight.shape), self._batch_axis))
            loss = paddle.sum(
                loss,
                axis=U.get_dims_with_exclusion(len(loss.shape), self._batch_axis)) / (bsum + self._eps)
        else:
            loss = paddle.sum(
                loss,
                axis=U.get_dims_with_exclusion(len(loss.shape), self._batch_axis))

        return self._scale * loss
|
||
|
||
class FocalLoss(nn.Layer):
    """Standard focal loss for binary sigmoid predictions.

    Pixels labeled ``-1`` are ignored.

    Args:
        alpha: weight of the positive class (1 - alpha for negatives).
        gamma: focal exponent applied to ``(1 - pt)``.
        from_logits: if True, ``pred`` is already a probability map.
        eps: numerical stabilizer added before the log / division.
        size_average: divide by the number of positive pixels per sample.
        scale: final multiplier on the loss.
    """

    def __init__(self, axis=-1, alpha=0.25, gamma=2,
                 from_logits=False, batch_axis=0,
                 weight=None, num_class=None,
                 eps=1e-9, size_average=True, scale=1.0):
        super(FocalLoss, self).__init__()
        self._axis = axis
        self._alpha = alpha
        self._gamma = gamma
        self._weight = weight if weight is not None else 1.0
        self._batch_axis = batch_axis

        self._scale = scale
        self._num_class = num_class
        self._from_logits = from_logits
        self._eps = eps
        self._size_average = size_average

    def forward(self, pred, label, sample_weight=None):
        if not self._from_logits:
            pred = F.sigmoid(pred)

        one_hot = label > 0
        pt = paddle.where(one_hot, pred, 1 - pred)

        # Valid-pixel mask (label == -1 means "ignore"), cast to float so it
        # can be scaled by alpha and multiplied with the float loss tensor.
        valid = (label != -1).astype('float32')
        alpha = paddle.where(one_hot, self._alpha * valid, (1 - self._alpha) * valid)
        beta = (1 - pt) ** self._gamma

        # BUGFIX: element-wise minimum is paddle.minimum; paddle.min is a
        # reduction whose second argument is an axis, so the original call
        # paddle.min(pt + eps, paddle.ones(1)) was invalid. Clamping at 1
        # keeps log(pt + eps) <= 0.
        loss = -alpha * beta * paddle.log(
            paddle.minimum(pt + self._eps, paddle.ones_like(pt)))
        loss = self._weight * (loss * valid)

        if self._size_average:
            # Normalize by the number of positive pixels per sample.
            tsum = paddle.sum(
                (label == 1).astype('float32'),
                axis=U.get_dims_with_exclusion(len(label.shape), self._batch_axis))
            loss = paddle.sum(
                loss,
                axis=U.get_dims_with_exclusion(len(loss.shape), self._batch_axis)) / (tsum + self._eps)
        else:
            loss = paddle.sum(
                loss,
                axis=U.get_dims_with_exclusion(len(loss.shape), self._batch_axis))

        return self._scale * loss
|
||
|
||
class SigmoidBinaryCrossEntropyLoss(nn.Layer):
    """Binary cross-entropy on logits (or probabilities) with an ignore label.

    Args:
        from_sigmoid: if True, ``pred`` already holds probabilities.
        weight: global multiplier on the loss (default 1.0).
        batch_axis: axis preserved when averaging (the batch dimension).
        ignore_label: label value excluded from the loss.
    """

    def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1):
        super(SigmoidBinaryCrossEntropyLoss, self).__init__()
        self._from_sigmoid = from_sigmoid
        self._ignore_label = ignore_label
        self._weight = weight if weight is not None else 1.0
        self._batch_axis = batch_axis

    def forward(self, pred, label):
        label = label.reshape(pred.shape)
        # Boolean mask of valid pixels; ignored pixels get label 0 and are
        # weighted out of the loss below.
        valid_mask = label != self._ignore_label
        label = paddle.where(valid_mask, label, paddle.zeros_like(label))

        if not self._from_sigmoid:
            # Numerically stable BCE-with-logits:
            # max(x, 0) - x*z + log(1 + exp(-|x|)).
            loss = F.relu(pred) - pred * label + F.softplus(-paddle.abs(pred))
        else:
            eps = 1e-12
            loss = -(paddle.log(pred + eps) * label
                     + paddle.log(1. - pred + eps) * (1. - label))

        # BUGFIX(port): cast the bool mask to float32 before multiplying —
        # Paddle's elementwise multiply requires matching dtypes, unlike the
        # PyTorch original this was ported from. Values are unchanged.
        loss = self._weight * (loss * valid_mask.astype('float32'))
        return paddle.mean(
            loss, axis=U.get_dims_with_exclusion(len(loss.shape), self._batch_axis))
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# 训练iann可用的自定义模型 | ||
|
||
目前已经可以通过简单的配置完成模型训练了,但其中有些设置还不能通过配置文件进行修改。 | ||
|
||
## 一、数据组织 | ||
|
||
在需要训练自己的数据集时,目前需要将数据集构造为如下格式,直接放在datasets文件夹中。文件名可以根据要求来进行设置,只需要在配置文件中设定好即可,图像和标签与平时使用的分割图像的用法相同。 | ||
|
||
``` | ||
datasets | ||
| | ||
├── train_data | ||
| ├── img | ||
| | └── filename_1.jpg | ||
| └── gt | ||
| └── filename_1.png | ||
| | ||
└── eval_data | ||
├── img | ||
| └── filename_1.jpg | ||
└── gt | ||
└── filename_1.png | ||
``` | ||
|
||
## 二、训练 | ||
|
||
直接运行ritm_train.py即可开始训练。 | ||
|
||
```python | ||
%cd train | ||
! python ritm_train.py --config train_config.yaml | ||
``` | ||
|
||
目前一些简单的参数已经可以在yaml配置文件中进行自定义设置,不过现阶段仍然不够灵活,可能出现各种问题。 | ||
|
||
``` | ||
iters: 100000 # 训练轮数 | ||
batch_size: 16 # bs大小 | ||
save_interval: 1000 # 保存间隔 | ||
log_iters: 10 # 打印log的间隔 | ||
worker: 4 # 子进程数 | ||
save_dir: model_output # 保存路径 | ||
use_vdl: False # 是否使用vdl | ||
dataset: | ||
dataset_path: iann/train/datasets # 数据集所在路径 | ||
image_name: img # 图像文件夹的名称 | ||
label_name: gt # 标签文件夹的名称 | ||
train_dataset: # 训练数据 | ||
crop_size: [320, 480] # 裁剪大小 | ||
folder_name: train_data # 训练数据文件夹的名称 | ||
val_dataset: # 验证数据 | ||
folder_name: val_data # 验证数据文件夹的名称 | ||
optimizer: | ||
type: adam # 优化器,目前仅可以选择‘adam’和‘sgd’ | ||
learning_rate: | ||
value_1: 5e-5 # 需要设置两个学习率 | ||
value_2: 5e-6 | ||
decay: | ||
type: poly # 学习率衰减,目前仅支持‘poly’,可以修改下面的参数 | ||
steps: 1000 | ||
power: 0.9 | ||
end_lr: 0.0 | ||
model: | ||
type: deeplab # 模型名称,目前支持‘hrnet’、‘deeplab’以及‘shufflenet’ | ||
backbone: resnet18 # 下面的参数是模型对应的参数,可在源码中查看 | ||
is_ritm: True | ||
weights: None # 加载权重的路径 | ||
``` | ||
|
||
|
||
|
||
### 说明
|
||
1. 注意:数据集中不能包含没有任何前景标签的纯背景图像,否则采样不到正样本点,训练会卡住且不会报错。
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.