-
Notifications
You must be signed in to change notification settings - Fork 17
/
utils.py
47 lines (40 loc) · 1.4 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import torch.nn as nn
import os
import os.path
import torch.nn.functional as F
def prepare_save_dir(args, filename):
""" Create saving directory."""
runner_name = os.path.basename(filename).split(".")[0]
model_dir = './experiments/{}/{}/'.format(runner_name, args.name)
args.savedir = model_dir
if not os.path.exists(args.savedir):
os.makedirs(args.savedir)
return args
def entropy(x):
"""
Helper function to compute the entropy over the batch
input: batch w/ shape [b, num_classes]
output: entropy value [is ideally -log(num_classes)]
"""
EPS = 1e-8
x_ = torch.clamp(x, min = EPS)
b = x_ * torch.log(x_)
if len(b.size()) == 2: # Sample-wise entropy
return - b.sum(dim = 1).mean()
elif len(b.size()) == 1: # Distribution-wise entropy
return - b.sum()
else:
raise ValueError('Input tensor is %d-Dimensional' %(len(b.size())))
class MarginLoss(nn.Module):
def __init__(self, m=0.2, weight=None, s=10):
super(MarginLoss, self).__init__()
self.m = m
self.s = s
self.weight = weight
def forward(self, x, target):
index = torch.zeros_like(x, dtype=torch.bool)
index.scatter_(1, target.data.view(-1, 1), 1)
x_m = x - self.m * self.s
output = torch.where(index, x_m, x)
return F.cross_entropy(output, target, weight=self.weight)