-
Notifications
You must be signed in to change notification settings - Fork 16
/
common.py
134 lines (102 loc) · 3.14 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import math
from collections.abc import Sequence, Mapping
import torch
class Args(tuple):
    """Immutable bundle of positional arguments.

    Exists as a distinct tuple subclass so container-walking helpers in this
    module (e.g. ``detach``) can recognize it and treat it as an opaque leaf.
    """
    def __new__(cls, *args):
        # args is already a tuple of the varargs; hand it straight to tuple.
        return tuple.__new__(cls, args)
    def __repr__(self):
        # e.g. Args(1, 2, 3)
        return "Args" + tuple.__repr__(self)
def one_hot(tensor, C=None, dtype=torch.float):
    """Return the one-hot encoding of an integer index tensor.

    Args:
        tensor: integer tensor of arbitrary shape holding class indices.
        C: number of classes (size of the appended trailing dim).  When None
            it is inferred as ``tensor.max() + 1``.
        dtype: dtype of the returned tensor (default ``torch.float``).

    Returns:
        Tensor of shape ``(*tensor.shape, C)`` with a 1 at each index position.
    """
    d = tensor.dim()
    if C is None:
        # The old form `C = C or tensor.max() + 1` left C as a 0-dim tensor
        # and also (wrongly) triggered for an explicit C == 0; be explicit
        # and convert to a plain Python int for use as a size.
        C = int(tensor.max()) + 1
    t = tensor.new_zeros(*tensor.size(), C, dtype=dtype)
    # Scatter 1s along the new trailing dimension at the index positions.
    return t.scatter_(d, tensor.unsqueeze(d), 1)
# Whether a CUDA device is available, checked once at import time and used
# by cuda() below to decide if tensors should be moved to the GPU.
CUDA = torch.cuda.is_available()
def detach(t, clone=True):
    """Recursively detach every tensor inside a (possibly nested) structure.

    Args:
        t: a tensor, or a Sequence/Mapping nesting of tensors; any other
            value is returned unchanged.  ``Args`` instances are treated as
            opaque leaves and returned as-is.
        clone: when True, detached tensors are also cloned so they do not
            share storage with the originals.

    Returns:
        A structure of the same shape with all tensors detached from autograd.
    """
    if torch.is_tensor(t):
        return t.clone().detach() if clone else t.detach()
    # str/bytes are registered as Sequence, so the Sequence branch below
    # would rebuild them as str(<generator object>) — garbage.  Treat text
    # as a leaf.  (This was a latent data-corruption bug.)
    elif isinstance(t, (str, bytes)):
        return t
    elif isinstance(t, Args):
        # deliberately opaque: not walked
        return t
    elif isinstance(t, Sequence):
        return t.__class__(detach(x, clone) for x in t)
    elif isinstance(t, Mapping):
        return t.__class__((k, detach(v, clone)) for k, v in t.items())
    else:
        return t
def cuda(t):
    """Recursively move every tensor inside a nested structure to the GPU.

    Tensors are moved only when the module-level ``CUDA`` flag is set;
    non-tensor leaves are returned unchanged.
    """
    if torch.is_tensor(t):
        return t.cuda() if CUDA else t
    # str/bytes are Sequences; rebuilding them from a generator would yield
    # str(<generator object>) garbage, so pass text through as a leaf.
    elif isinstance(t, (str, bytes)):
        return t
    # Treat Args as an opaque leaf, consistent with detach(); the Sequence
    # branch would otherwise rebuild it as Args(<generator>) — a 1-tuple
    # wrapping a generator object.
    elif isinstance(t, Args):
        return t
    elif isinstance(t, Sequence):
        return t.__class__(cuda(x) for x in t)
    elif isinstance(t, Mapping):
        return t.__class__((k, cuda(v)) for k, v in t.items())
    else:
        return t
def cpu(t):
    """Recursively move every tensor inside a nested structure to the CPU.

    Non-tensor leaves are returned unchanged.
    """
    if torch.is_tensor(t):
        return t.cpu()
    # str/bytes are Sequences; rebuilding them from a generator would yield
    # str(<generator object>) garbage, so pass text through as a leaf.
    elif isinstance(t, (str, bytes)):
        return t
    # Treat Args as an opaque leaf, consistent with detach(); the Sequence
    # branch would otherwise rebuild it as Args(<generator>).
    elif isinstance(t, Args):
        return t
    elif isinstance(t, Sequence):
        return t.__class__(cpu(x) for x in t)
    elif isinstance(t, Mapping):
        return t.__class__((k, cpu(v)) for k, v in t.items())
    else:
        return t
def _tuple(x, n=-1):
if x is None:
return ()
elif torch.is_tensor(x):
return (x,)
elif not isinstance(x, Sequence):
assert n > 0, "Length must be positive, but got %d" % n
return (x,) * n
else:
if n == -1:
n = len(x)
else:
assert len(x) == n, "The length of x is %d, not equal to the expected length %d" % (len(x), n)
return tuple(x)
def select0(t, indices):
    """For each column j of 2-D ``t``, pick ``t[indices[j], j]``."""
    cols = torch.arange(t.size(1), device=t.device)
    return t[indices, cols]
def select1(t, indices):
    """For each row i of 2-D ``t``, pick ``t[i, indices[i]]``."""
    rows = torch.arange(t.size(0), device=t.device)
    return t[rows, indices]
def select(t, dim, indices):
    """Pick one element per position of 2-D ``t`` along ``dim`` (0 or 1).

    ``dim == 0``: for each column j return ``t[indices[j], j]``.
    ``dim == 1``: for each row i return ``t[i, indices[i]]``.

    Raises:
        ValueError: if ``dim`` is neither 0 nor 1.
    """
    if dim not in (0, 1):
        raise ValueError("dim could be only 0 or 1, not %d" % dim)
    # Pair each given index with the running position along the other dim
    # (this is select0 / select1 inlined).
    other = torch.arange(t.size(1 - dim), device=t.device)
    return t[indices, other] if dim == 0 else t[other, indices]
def sample(t, n):
    """Draw ``n`` rows from ``t`` at random.

    A random subset without replacement when ``t`` has at least ``n`` rows,
    otherwise ``n`` uniform draws with replacement.
    """
    m = len(t)
    if m < n:
        # too few rows: sample with replacement
        idx = torch.randint(m, size=(n,), device=t.device)
    else:
        idx = torch.randperm(m, device=t.device)[:n]
    return t[idx]
def _concat(xs, dim=1):
if torch.is_tensor(xs):
return xs
elif len(xs) == 1:
return xs[0]
else:
return torch.cat(xs, dim=dim)
def inverse_sigmoid(x, eps=1e-6, inplace=False):
    """Logit (inverse sigmoid) of a Python scalar or a tensor.

    Inputs are first clamped into [eps, 1 - eps] (pass ``eps=0`` to skip the
    clamp).  With ``inplace=True`` a tensor argument is modified in place by
    delegating to ``inverse_sigmoid_``.
    """
    if not torch.is_tensor(x):
        # Python-scalar path
        if eps != 0:
            x = min(max(x, eps), 1 - eps)
        return math.log(x / (1 - x))
    if inplace:
        return inverse_sigmoid_(x, eps)
    clamped = torch.clamp(x, eps, 1 - eps) if eps != 0 else x
    return (clamped / (1 - clamped)).log()
def inverse_sigmoid_(x, eps=1e-6):
    """In-place logit: clamp ``x`` into [eps, 1 - eps] (skipped when
    ``eps == 0``), then overwrite it with log(x / (1 - x)); returns ``x``."""
    if eps:
        torch.clamp_(x, eps, 1 - eps)
    return x.div_(1 - x).log_()
def expand_last_dim(t, *size):
    """View ``t`` with its last dimension reshaped into ``size`` (whose
    product must equal the original last-dim size)."""
    new_shape = t.size()[:-1] + size
    return t.view(new_shape)