-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
38 lines (31 loc) · 1.33 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import json
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import ConcatDataset
from util.trainer import Trainer
from util.dataset import AutoFocusDataset
def _build_dataloader(config, dataset_names, split, shuffle):
    """Concatenate one ``AutoFocusDataset`` per name for *split* into a DataLoader.

    Args:
        config: Parsed config dict; ``config['General']`` supplies
            ``batch_size`` and ``num_workers``.
        dataset_names: Dataset identifiers, each passed to ``AutoFocusDataset``.
        split: Split name forwarded to ``AutoFocusDataset`` (``'train'`` or
            ``'val'`` in this script).
        shuffle: Whether the DataLoader reshuffles samples every epoch.

    Returns:
        A ``DataLoader`` over the ``ConcatDataset`` of all per-name datasets,
        with ``pin_memory=True``.
    """
    datasets = [AutoFocusDataset(config, name, split) for name in dataset_names]
    return DataLoader(
        ConcatDataset(datasets),
        batch_size=config['General']['batch_size'],
        shuffle=shuffle,
        pin_memory=True,
        num_workers=config['General']['num_workers'],
    )


def main(config_path='config.json'):
    """Load the JSON config, build the train/val loaders and run training.

    Args:
        config_path: Path to the JSON config file (defaults to
            ``'config.json'`` in the current working directory).

    Raises:
        FileNotFoundError: If *config_path* does not exist.
        KeyError: If a required config key is missing.
    """
    with open(config_path, encoding='utf-8') as f:
        config = json.load(f)

    # Seed before any dataset/model construction so that it is reproducible too.
    seed = config['General']['seed']
    np.random.seed(seed)
    # DataLoader(shuffle=True) draws its permutation from torch's RNG, not
    # NumPy's, so seeding NumPy alone does not make batch order reproducible.
    torch.manual_seed(seed)

    list_data = config['Dataset']['paths']['list_datasets']
    train_dataloader = _build_dataloader(config, list_data, 'train', shuffle=True)
    val_dataloader = _build_dataloader(config, list_data, 'val', shuffle=False)

    trainer = Trainer(config)
    trainer.train(train_dataloader, val_dataloader)


# Required when num_workers > 0 under the "spawn" start method (default on
# Windows and macOS): each worker re-imports this module, and without the
# guard it would rebuild the loaders and start training again.
if __name__ == '__main__':
    main()