-
Notifications
You must be signed in to change notification settings - Fork 18
/
utils.py
92 lines (75 loc) · 2.55 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import sys
from subprocess import call
import torch
import config as cfg
class Logger(object):
    """Writes everything it receives both to the terminal and to a log file.

    The terminal stream is whatever ``sys.stdout`` was when the Logger was
    constructed. The caller presumably installs the Logger itself as
    ``sys.stdout`` so that ``print`` output is teed; this is not shown here.
    """

    def __init__(self, savepath, mode='a'):
        """Open the log file and remember the current terminal stream.

        Args:
            savepath: Prefix that is concatenated verbatim with
                ``'logfile.log'``. Include a trailing path separator if it
                names a directory.
            mode: File mode for the log file (``'a'`` appends by default).
        """
        self.terminal = sys.stdout
        self.log = open(savepath + 'logfile.log', mode)

    def write(self, message):
        """Write ``message`` to the terminal and to the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both the terminal stream and the log file.

        A file-like object standing in for ``sys.stdout`` must provide
        ``flush``. Flushing the log file as well means buffered log output is
        not lost if the process dies before the file is closed.
        """
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file; the terminal stream is left open."""
        if not self.log.closed:
            self.log.close()
class Normalizer(object):
    """Normalize a tensor with fixed statistics and restore it later.

    The mean and standard deviation are computed once, over all elements of
    a sample tensor. They are then reused by ``norm``/``denorm`` and can be
    saved and restored through ``state_dict``/``load_state_dict``.
    """

    def __init__(self, tensor):
        """Take ``tensor`` as the sample used to compute the mean and std.

        Both statistics are cast with ``.type(cfg.FloatTensor)``. Note that
        ``torch.std`` uses the unbiased estimator by default, so a one-element
        sample yields a NaN std. ``norm``/``denorm`` then act as the identity.
        """
        self.mean = torch.mean(tensor).type(cfg.FloatTensor)
        self.std = torch.std(tensor).type(cfg.FloatTensor)

    def _has_nan_stats(self):
        """Return True if the stored mean or std is NaN."""
        # NaN is the only value not equal to itself. This works for 0-dim
        # tensors and for plain floats restored via load_state_dict.
        return bool(self.mean != self.mean) or bool(self.std != self.std)

    def _scale(self):
        """Return the divisor used for normalization: std, or 1 if std is 0."""
        # A constant sample has std == 0. Dividing by it would turn every
        # value into NaN/inf and break the norm/denorm round trip, so fall
        # back to a pure mean shift.
        return 1.0 if bool(self.std == 0) else self.std

    def norm(self, tensor):
        """Return ``(tensor - mean) / std``, or ``tensor`` unchanged if the stats are NaN."""
        if self._has_nan_stats():
            return tensor
        return (tensor - self.mean) / self._scale()

    def denorm(self, normed_tensor):
        """Invert ``norm``: return ``normed_tensor * std + mean`` (identity if the stats are NaN)."""
        if self._has_nan_stats():
            return normed_tensor
        return normed_tensor * self._scale() + self.mean

    def state_dict(self):
        """Return the normalization statistics as ``{'mean': ..., 'std': ...}``."""
        return {'mean': self.mean,
                'std': self.std}

    def load_state_dict(self, state_dict):
        """Restore statistics previously produced by ``state_dict``."""
        self.mean = state_dict['mean']
        self.std = state_dict['std']
class AverageMeter(object):
    """
    Computes and stores the average and current value. Accommodates both numbers and tensors.

    If the monitored value is a tensor, the tensor's dimensions/shape must also
    be given. In tensor mode ``val``, ``sum`` and ``avg`` are tensors of that
    shape, updated elementwise. ``count`` is always a single scalar shared by
    all elements.
    """

    def __init__(self, is_tensor=False, dimensions=None):
        """
        Args:
            is_tensor: Track tensor values instead of plain numbers.
            dimensions: Shape passed to ``torch.zeros``; required when
                ``is_tensor`` is true.

        Raises:
            ValueError: If ``is_tensor`` is true but ``dimensions`` is None.
        """
        if is_tensor and dimensions is None:
            # Raise instead of sys.exit() so a misconfiguration surfaces as a
            # normal, catchable error with a traceback.
            raise ValueError('Bad definition of AverageMeter: '
                             'is_tensor=True requires dimensions')
        self.is_tensor = is_tensor
        self.dimensions = dimensions
        self.reset()

    def reset(self):
        """Zero the count and the value/average/sum accumulators."""
        self.count = 0
        if self.is_tensor:
            self.val = torch.zeros(self.dimensions, device=cfg.device)
            self.avg = torch.zeros(self.dimensions, device=cfg.device)
            self.sum = torch.zeros(self.dimensions, device=cfg.device)
        else:
            self.val = 0
            self.avg = 0
            self.sum = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted as ``n`` samples.

        Args:
            val: Number, or tensor broadcast-compatible with ``dimensions``.
            n: Weight (number of samples) that ``val`` represents.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # With nothing counted yet (e.g. n=0 on a fresh meter) the average is
        # undefined; keep the previous value instead of dividing by zero.
        if self.count:
            self.avg = self.sum / self.count
def count_parameters(model):
    """Return the total number of scalar elements across the trainable parameters of ``model``.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
def randomSeed(random_seed):
    """Seed PyTorch's random number generators so results can be reproduced across runs.

    Passing ``None`` leaves every generator untouched. When CUDA is
    available, all CUDA devices are seeded as well.
    """
    if random_seed is None:
        return
    torch.manual_seed(random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(random_seed)
def clearCache():
    """Release cached, unused GPU memory held by PyTorch (thin wrapper over ``torch.cuda.empty_cache``)."""
    cuda_module = torch.cuda
    cuda_module.empty_cache()