forked from chrisorozco1097/brain_segmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
70 lines (52 loc) · 1.76 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from datahandler import DataHandler
from models.unet.unet import *
from generator import *
from params import *
from callbacks import getCallbacks
from tqdm import tqdm
import os
import skimage.io as io
from keras.models import *
from keras import backend as K
import argparse
import sys
import math

# Pin training to a single GPU. CUDA reads these variables when it first
# initialises, so they must be set before any GPU context is created.
# CUDA_DEVICE_ORDER only accepts "FASTEST_FIRST" (the default) or
# "PCI_BUS_ID". It previously held a PCI bus address ("00000000:D8:00.0"),
# which is not a valid ordering. With PCI_BUS_ID, the index given in
# CUDA_VISIBLE_DEVICES follows PCI bus order.
# NOTE(review): confirm that index "1" is the card at bus 00000000:D8:00.0
# on the target machine.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Parse the experiment name from the command line. It selects the parameter
# set, and (per the --exp help text) is used when saving weights and logs.
parser = argparse.ArgumentParser()
parser.add_argument('--exp', required=True,
                    help='Experiments name, to save weights and logs')
args = parser.parse_args()
exp_name = args.exp

# Load the experiment's parameter dict and pull out the common settings.
params = getParams(exp_name)
epochs = params['epochs']
batch_size = params['batch_size']
verbose = params['verbose']
# NOTE(review): val_to_monitor is read but not used in this script;
# presumably getCallbacks(params) reads it from params itself -- confirm.
val_to_monitor = params['val_to_monitor']
# NOTE(review): resetSeed comes from params; presumably seeds the RNGs for
# reproducibility -- confirm.
resetSeed()

# Load the train/test splits and wrap each in a batch generator.
# Augmentation is disabled for both, and the test split is used as the
# validation data during training.
dh = DataHandler()
tr_images, tr_masks, te_images, te_masks = dh.getData()
train_generator = getGenerator(tr_images, tr_masks,
                               augmentation=False, batch_size=batch_size)
val_generator = getGenerator(te_images, te_masks,
                             augmentation=False, batch_size=batch_size)

# Build the U-Net.
model = getUnet()
# Transfer learning: uncomment to start from weights trained on another problem.
# model.load_weights('./weights/unet_transfer.h5')
# print(model.summary())

# Save the model architecture (Keras to_json holds no weights) to the path
# given by params['model_name'].
model_json = model.to_json()
with open(params['model_name'], "w") as json_file:
    json_file.write(model_json)

Checkpoint, EarlyStop, ReduceLR, Logger, TenBoard = getCallbacks(params)

# Keras draws from a generator indefinitely, so each epoch is bounded by a
# step count, which Keras documents as an integer. Plain `/` yields a float.
# Taking the ceiling gives an int and keeps the final partial batch. It also
# matches the number of steps Keras actually ran with the float value.
steps_per_epoch = math.ceil(len(tr_images) / batch_size)
validation_steps = math.ceil(len(te_images) / batch_size)

# Train the model.
# NOTE(review): EarlyStop and ReduceLR are created but not passed to fit.
# Add them to the list below if early stopping / LR reduction is wanted.
history = model.fit_generator(train_generator,
                              epochs=epochs,
                              steps_per_epoch=steps_per_epoch,
                              validation_data=val_generator,
                              validation_steps=validation_steps,
                              verbose=verbose,
                              callbacks=[Checkpoint, Logger, TenBoard])