-
Notifications
You must be signed in to change notification settings - Fork 1
/
util.py
39 lines (29 loc) · 1.26 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import logging
import random
from pathlib import Path

import numpy as np
import torch
def save_model(save_dir, generator, discriminator, generator_opt, discriminator_opt, epoch):
    """Write GAN weights and optimizer states to a single checkpoint file.

    The checkpoint is written to ``<save_dir>/model-<epoch>`` as one
    ``torch.save`` payload. The payload is a dict with the keys ``'generator'``,
    ``'discriminator'``, ``'generator_opt'`` and ``'discriminator_opt'``. Each
    key holds the ``state_dict()`` of the matching argument. ``load_model`` in
    this module reads these same keys back. An existing file with the same
    name is overwritten.

    Args:
        save_dir: Directory to write into, as a ``str`` or path-like object.
            It is created, including any missing parents, if it does not exist.
        generator: Object whose ``state_dict()`` is stored under ``'generator'``.
        discriminator: Object whose ``state_dict()`` is stored under
            ``'discriminator'``.
        generator_opt: Object whose ``state_dict()`` is stored under
            ``'generator_opt'``.
        discriminator_opt: Object whose ``state_dict()`` is stored under
            ``'discriminator_opt'``.
        epoch: Value interpolated into the file name ``model-{epoch}``.
    """
    # Accept plain strings as well as Path objects.
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)
    model_path = save_dir / f'model-{epoch}'
    logging.info('save model to %s', model_path)

    # Collect the state dicts before opening the file. If any state_dict()
    # call raises, an existing checkpoint is not truncated to an empty file.
    checkpoint = {
        'generator': generator.state_dict(),
        'discriminator': discriminator.state_dict(),
        'generator_opt': generator_opt.state_dict(),
        'discriminator_opt': discriminator_opt.state_dict(),
    }
    with model_path.open('wb') as f:
        torch.save(checkpoint, f)
def load_model(model_path, generator, discriminator, generator_opt, discriminator_opt, device):
    """Restore weights, and optionally optimizer states, from a checkpoint file.

    The file must hold a dict in the layout written by ``save_model`` in this
    module. The keys ``'generator'`` and ``'discriminator'`` are always read.
    ``'generator_opt'`` and ``'discriminator_opt'`` are read only when the
    matching optimizer argument is not ``None``. This allows inference-only
    loading without optimizers.

    Args:
        model_path: Checkpoint file, as a ``str`` or path-like object.
        generator: Object that receives ``weights['generator']`` via
            ``load_state_dict``.
        discriminator: Object that receives ``weights['discriminator']``.
        generator_opt: Optimizer to restore, or ``None`` to skip it.
        discriminator_opt: Optimizer to restore, or ``None`` to skip it.
        device: Passed to ``torch.load`` as ``map_location``.

    Raises:
        KeyError: If a required key is missing from the checkpoint dict.
        Any error raised by ``torch.load`` or ``load_state_dict`` propagates
        unchanged.
    """
    # Accept plain strings as well as Path objects.
    model_path = Path(model_path)
    logging.info('load model from %s', model_path)
    with model_path.open('rb') as f:
        # NOTE(security): torch.load unpickles the file. Depending on the torch
        # version and its `weights_only` default, a malicious file can execute
        # code, so only load checkpoints from trusted sources.
        weights = torch.load(f, map_location=device)

    generator.load_state_dict(weights['generator'])
    discriminator.load_state_dict(weights['discriminator'])
    if generator_opt is not None:
        generator_opt.load_state_dict(weights['generator_opt'])
    if discriminator_opt is not None:
        discriminator_opt.load_state_dict(weights['discriminator_opt'])
def fix_seed(seed):
    """Seed every RNG this module relies on with the same ``seed``.

    In order, this seeds Python's ``random`` module, NumPy's global
    ``np.random`` state, and PyTorch via ``torch.manual_seed``. It changes no
    other determinism settings, such as cuDNN flags. It returns ``None``.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)