-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_nodule_segmentation.py
47 lines (43 loc) · 2.64 KB
/
train_nodule_segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Launch training for lung segmentation
import argparse
from lucanode.training import detection
# Names of the supported training variations. Each corresponds to a function
# in lucanode.training.detection called train_nodule_segmentation_<name>.
_VARIATION_NAMES = (
    "no_augmentation_no_normalization_binary_crossentropy",
    "no_augmentation_normalization_binary_crossentropy",
    "no_augmentation_normalization_dice",
    "augmentation_normalization_dice_3ch_laplacian_mislabeling",
    "augmentation_normalization_bce",
    "augmentation_normalization_dice_3ch",
    "augmentation_normalization_dice_3ch_laplacian",
    "augmentation_normalization_bce_3ch_laplacian",
)

# Registry mapping each CLI-selectable variation name to its training
# entry point, resolved from the detection module by naming convention.
NETWORK_VARIATIONS = {
    name: getattr(detection, "train_nodule_segmentation_" + name)
    for name in _VARIATION_NAMES
}
def _main():
    """Parse command-line arguments and launch the chosen training variation."""
    parser = argparse.ArgumentParser(description='Train nodule segmentation neural network')
    # Positional arguments: input dataset, output weights path, variation name.
    parser.add_argument('dataset_hdf5', type=str, help="path where the dataset hdf5 (1mm spacing) is stored")
    parser.add_argument('weights_file_output', type=str, help="path where the network weights will be saved")
    parser.add_argument('variation', type=str, help="Name of the function", choices=NETWORK_VARIATIONS.keys())
    # Optional training hyperparameters ('store' is the argparse default action).
    parser.add_argument('--batch-size', dest='batch_size', type=int, default=5,
                        help="Training batch size")
    parser.add_argument('--num-epochs', dest='num_epochs', type=int, default=5,
                        help="Number of training epochs")
    parser.add_argument('--last-epoch', dest='last_epoch', type=int, default=0,
                        help="Last finished epoch (picks up training from there). Useful if passing --initial-weights")
    parser.add_argument('--initial-weights', dest='initial_weights', type=str, default=None,
                        help="Initial weights to load into the network (.h5 file path)")
    args = parser.parse_args()

    # Dispatch to the selected training function; argparse's `choices`
    # guarantees the key exists in the registry.
    train_fn = NETWORK_VARIATIONS[args.variation]
    train_fn(
        args.dataset_hdf5,
        args.weights_file_output,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        last_epoch=args.last_epoch,
        initial_weights=args.initial_weights,
    )


if __name__ == '__main__':
    _main()