Project refactoring for the newer Python f-string formatting and some simpler syntax #32

Open · wants to merge 1 commit into base: main
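In short, the commit replaces %-style and str.format() formatting with Python 3.6+ f-strings and tidies a few constructs (tuple isinstance checks, a None default for a mutable argument). A minimal sketch of the formatting change, with illustrative values rather than code from the diff:

epoch, label = 3, 'Network'
old_name = '{}_{}.pth'.format(epoch, label)  # old str.format() style
new_name = f'{epoch}_{label}.pth'            # f-string equivalent
assert old_name == new_name == '3_Network.pth'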
24 changes: 14 additions & 10 deletions core/base_dataset.py
@@ -9,15 +9,17 @@
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
if os.path.isfile(dir):
images = [i for i in np.genfromtxt(dir, dtype=np.str, encoding='utf-8')]
images = list(np.genfromtxt(dir, dtype=str, encoding='utf-8'))  # np.str was removed in NumPy 1.24; plain str works
else:
images = []
assert os.path.isdir(dir), '%s is not a valid directory' % dir
assert os.path.isdir(dir), f'{dir} is not a valid directory'
for root, _, fnames in sorted(os.walk(dir)):
for fname in sorted(fnames):
if is_image_file(fname):
@@ -26,23 +28,25 @@ def make_dataset(dir):

return images


def pil_loader(path):
return Image.open(path).convert('RGB')


class BaseDataset(data.Dataset):
def __init__(self, data_root, image_size=[256, 256], loader=pil_loader):
def __init__(self, data_root, image_size=None, loader=pil_loader):
if image_size is None:
image_size = [256, 256]
self.imgs = make_dataset(data_root)
self.tfs = transforms.Compose([
transforms.Resize((image_size[0], image_size[1])),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
self.tfs = transforms.Compose([transforms.Resize((image_size[0], image_size[1])),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.201))])

self.loader = loader

def __getitem__(self, index):
path = self.imgs[index]
img = self.tfs(self.loader(path))
return img
return self.tfs(self.loader(path))  # load the image and apply the transform pipeline

def __len__(self):
return len(self.imgs)
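The image_size change above sidesteps Python's mutable-default pitfall: a default list is created once, when the function is defined, and shared by every call, so in-place mutations leak between calls. A standalone sketch with hypothetical function names:

def resize_shared(size=[256, 256]):  # one list object reused by all calls
    size[0] //= 2
    return size

def resize_fresh(size=None):  # the None-sentinel pattern the PR adopts
    if size is None:
        size = [256, 256]
    size[0] //= 2
    return size

resize_shared()
print(resize_shared())  # [64, 256] -- state leaked from the first call
print(resize_fresh())   # [128, 256]
print(resize_fresh())   # [128, 256] -- a fresh default on every call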
81 changes: 45 additions & 36 deletions core/base_model.py
@@ -6,11 +6,12 @@
import torch
import torch.nn as nn


import core.util as Util

CustomResult = collections.namedtuple('CustomResult', 'name result')

class BaseModel():

class BaseModel:
def __init__(self, opt, phase_loader, val_loader, metrics, logger, writer):
""" init model with basic input, which are from __init__(**kwargs) function in inherited class """
self.opt = opt
@@ -24,7 +25,7 @@ def __init__(self, opt, phase_loader, val_loader, metrics, logger, writer):
''' process record '''
self.batch_size = self.opt['datasets'][self.phase]['dataloader']['args']['batch_size']
self.epoch = 0
self.iter = 0

self.phase_loader = phase_loader
self.val_loader = val_loader
@@ -33,24 +34,24 @@ def __init__(self, opt, phase_loader, val_loader, metrics, logger, writer):
''' logger writes to the log file (GPU 0 only); writer writes to TensorBoard and the result file '''
self.logger = logger
self.writer = writer
self.results_dict = CustomResult([],[]) # {"name":[], "result":[]}
self.results_dict = CustomResult([], []) # {"name":[], "result":[]}

def train(self):
while self.epoch <= self.opt['train']['n_epoch'] and self.iter <= self.opt['train']['n_iter']:
self.epoch += 1
if self.opt['distributed']:
''' sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas use a different random ordering for each epoch '''
self.phase_loader.sampler.set_epoch(self.epoch)

train_log = self.train_step()

''' save logged information into the log dict '''
train_log.update({'epoch': self.epoch, 'iters': self.iter})

''' print logged information to the screen and tensorboard '''
for key, value in train_log.items():
self.logger.info('{:5s}: {}\t'.format(str(key), value))

if self.epoch % self.opt['train']['save_checkpoint_epoch'] == 0:
self.logger.info('Saving the model at the end of epoch {:.0f}'.format(self.epoch))
self.save_everything()
@@ -79,26 +80,26 @@ def val_step(self):

def test_step(self):
pass

def print_network(self, network):
""" print network structure, only work on GPU 0 """
if self.opt['global_rank'] !=0:
if self.opt['global_rank'] != 0:
return
if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel):
if isinstance(network, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
network = network.module

s, n = str(network), sum(map(lambda x: x.numel(), network.parameters()))
net_struc_str = '{}'.format(network.__class__.__name__)
net_struc_str = f'{network.__class__.__name__}'
self.logger.info('Network structure: {}, with parameters: {:,d}'.format(net_struc_str, n))

self.logger.info(s)

def save_network(self, network, network_label):
""" save network structure, only work on GPU 0 """
if self.opt['global_rank'] !=0:
if self.opt['global_rank'] != 0:
return
save_filename = '{}_{}.pth'.format(self.epoch, network_label)
save_filename = f'{self.epoch}_{network_label}.pth'
save_path = os.path.join(self.opt['path']['checkpoint'], save_filename)
if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel):
if isinstance(network, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
network = network.module
state_dict = network.state_dict()
for key, param in state_dict.items():
@@ -107,54 +108,62 @@ def save_network(self, network, network_label):

def load_network(self, network, network_label, strict=True):
if self.opt['path']['resume_state'] is None:
return
self.logger.info('Begin loading pretrained model [{:s}] ...'.format(network_label))

model_path = "{}_{}.pth".format(self. opt['path']['resume_state'], network_label)

model_path = f"{self.opt['path']['resume_state']}_{network_label}.pth"
if not os.path.exists(model_path):
self.logger.warning('Pretrained model at [{:s}] does not exist; skipping it'.format(model_path))
return

self.logger.info('Loading pretrained model from [{:s}] ...'.format(model_path))
if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel):
if isinstance(network, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
network = network.module
network.load_state_dict(torch.load(model_path, map_location = lambda storage, loc: Util.set_device(storage)), strict=strict)
network.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: Util.set_device(storage)),
strict=strict)

def save_training_state(self):
""" saves training state during training, only work on GPU 0 """
if self.opt['global_rank'] !=0:
if self.opt['global_rank'] != 0:
return
assert isinstance(self.optimizers, list) and isinstance(self.schedulers, list), 'optimizers and schedulers must be a list.'
assert isinstance(self.optimizers, list) and isinstance(self.schedulers,
list), 'optimizers and schedulers must be a list.'

state = {'epoch': self.epoch, 'iter': self.iter, 'schedulers': [], 'optimizers': []}

for s in self.schedulers:
state['schedulers'].append(s.state_dict())
for o in self.optimizers:
state['optimizers'].append(o.state_dict())
save_filename = '{}.state'.format(self.epoch)
save_filename = f'{self.epoch}.state'
save_path = os.path.join(self.opt['path']['checkpoint'], save_filename)
torch.save(state, save_path)

def resume_training(self):
""" resume the optimizers and schedulers for training, only work when phase is test or resume training enable """
if self.phase!='train' or self. opt['path']['resume_state'] is None:
if self.phase != 'train' or self.opt['path']['resume_state'] is None:
return
self.logger.info('Begin loading training states ...')
assert isinstance(self.optimizers, list) and isinstance(self.schedulers, list), 'optimizers and schedulers must be a list.'

state_path = "{}.state".format(self. opt['path']['resume_state'])

assert isinstance(self.optimizers, list) and isinstance(self.schedulers, list), \
'optimizers and schedulers must be a list.'

state_path = f"{self.opt['path']['resume_state']}.state"

if not os.path.exists(state_path):
self.logger.warning('Training state at [{:s}] does not exist; skipping it'.format(state_path))
return

self.logger.info('Loading training state from [{:s}] ...'.format(state_path))
resume_state = torch.load(state_path, map_location = lambda storage, loc: self.set_device(storage))
resume_state = torch.load(state_path, map_location=lambda storage, loc: self.set_device(storage))

resume_optimizers = resume_state['optimizers']
resume_schedulers = resume_state['schedulers']
assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers {} != {}'.format(len(resume_optimizers), len(self.optimizers))
assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers {} != {}'.format(len(resume_schedulers), len(self.schedulers))
assert len(resume_optimizers) == len(
self.optimizers), f'Wrong lengths of optimizers {len(resume_optimizers)} != {len(self.optimizers)}'

assert len(resume_schedulers) == len(
self.schedulers), f'Wrong lengths of schedulers {len(resume_schedulers)} != {len(self.schedulers)}'

for i, o in enumerate(resume_optimizers):
self.optimizers[i].load_state_dict(o)
for i, s in enumerate(resume_schedulers):
@@ -164,8 +173,8 @@ def resume_training(self):
self.iter = resume_state['iter']

def load_everything(self):
pass

@abstractmethod
def save_everything(self):
raise NotImplementedError('You must specify how to save your networks, optimizers and schedulers.')
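A recurring simplification in base_model.py is collapsing chained isinstance checks into the tuple form used to unwrap DataParallel-style wrappers. A minimal sketch of the equivalence:

import torch.nn as nn

wrapped = nn.DataParallel(nn.Linear(4, 2))  # constructing the wrapper needs no GPU

# The two checks below are equivalent; the tuple form is what the PR adopts.
assert isinstance(wrapped, nn.DataParallel) or isinstance(wrapped, nn.parallel.DistributedDataParallel)
assert isinstance(wrapped, (nn.DataParallel, nn.parallel.DistributedDataParallel))

inner = wrapped.module  # unwrap to reach the underlying nn.Linear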
83 changes: 41 additions & 42 deletions core/base_network.py
@@ -1,48 +1,47 @@
import torch.nn as nn
class BaseNetwork(nn.Module):
def __init__(self, init_type='kaiming', gain=0.02):
super(BaseNetwork, self).__init__()
self.init_type = init_type
self.gain = gain

def init_weights(self):
"""
initialize network's weights
init_type: normal | xavier | kaiming | orthogonal
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
"""

def init_func(m):
classname = m.__class__.__name__
if classname.find('InstanceNorm2d') != -1:
if hasattr(m, 'weight') and m.weight is not None:
nn.init.constant_(m.weight.data, 1.0)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias.data, 0.0)
elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if self.init_type == 'normal':
nn.init.normal_(m.weight.data, 0.0, self.gain)
elif self.init_type == 'xavier':
nn.init.xavier_normal_(m.weight.data, gain=self.gain)
elif self.init_type == 'xavier_uniform':
nn.init.xavier_uniform_(m.weight.data, gain=1.0)
elif self.init_type == 'kaiming':
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif self.init_type == 'orthogonal':
nn.init.orthogonal_(m.weight.data, gain=self.gain)
elif self.init_type == 'none': # uses pytorch's default init method
m.reset_parameters()
else:
raise NotImplementedError('initialization method [%s] is not implemented' % self.init_type)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias.data, 0.0)

self.apply(init_func)
# propagate to children
for m in self.children():
if hasattr(m, 'init_weights'):
m.init_weights(self.init_type, self.gain)
class BaseNetwork(nn.Module):
def __init__(self, init_type='kaiming', gain=0.02):
super(BaseNetwork, self).__init__()
self.init_type = init_type
self.gain = gain

def init_weights(self):
"""
initialize network's weights
init_type: normal | xavier | kaiming | orthogonal
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
"""

def init_func(m):
classname = m.__class__.__name__
if classname.find('InstanceNorm2d') != -1:
if hasattr(m, 'weight') and m.weight is not None:
nn.init.constant_(m.weight.data, 1.0)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias.data, 0.0)
elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if self.init_type == 'normal':
nn.init.normal_(m.weight.data, 0.0, self.gain)
elif self.init_type == 'xavier':
nn.init.xavier_normal_(m.weight.data, gain=self.gain)
elif self.init_type == 'xavier_uniform':
nn.init.xavier_uniform_(m.weight.data, gain=1.0)
elif self.init_type == 'kaiming':
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif self.init_type == 'orthogonal':
nn.init.orthogonal_(m.weight.data, gain=self.gain)
elif self.init_type == 'none': # uses pytorch's default init method
m.reset_parameters()
else:
raise NotImplementedError(f'initialization method [{self.init_type}] is not implemented')

if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias.data, 0.0)


self.apply(init_func)
# propagate to children
for m in self.children():
if hasattr(m, 'init_weights'):
m.init_weights(self.init_type, self.gain)
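For context, init_weights applies init_func to every submodule via self.apply. A sketch of the intended call pattern; TinyNet is a hypothetical subclass, not part of the repository:

import torch.nn as nn
from core.base_network import BaseNetwork

class TinyNet(BaseNetwork):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.body = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3),
            nn.InstanceNorm2d(8, affine=True),  # affine=True gives the norm a weight/bias to initialize
        )

net = TinyNet(init_type='orthogonal', gain=0.02)
net.init_weights()  # Conv2d weight -> orthogonal_, InstanceNorm2d weight -> 1.0, all biases -> 0.0

One caveat visible in both versions of the file: the propagation loop calls m.init_weights(self.init_type, self.gain) with two arguments, while init_weights is defined to take none, so a child module that is itself a BaseNetwork would raise a TypeError there; children need a matching signature for that hook to work.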