-
Notifications
You must be signed in to change notification settings - Fork 1
/
heterogeneity_loss.py
40 lines (36 loc) · 1.26 KB
/
heterogeneity_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
from torch import nn, tensor
import torch
from torch.autograd import Variable
class hetero_loss(nn.Module):
    """Cross-modality heterogeneity loss.

    For every identity, computes the per-class feature center (mean over
    that class's rows) in each of the two modality batches and penalizes
    the distance between the two centers with a hinge at ``margin``:
    ``sum_i max(0, d(center1_i, center2_i) - margin)`` (for ``'cos'`` the
    distance is ``1 - cosine_similarity``).

    NOTE(review): ``forward`` splits each batch into ``label1.unique()``
    equal chunks, so it assumes both batches are ordered by label with the
    same number of samples per class — verify against the data sampler.

    Args:
        margin: hinge margin; center distances below it contribute zero loss.
        dist_type: one of ``'l2'`` (sum of squared differences),
            ``'l1'`` (mean absolute difference), or ``'cos'``.

    Raises:
        ValueError: if ``dist_type`` is not one of the supported values.
    """

    def __init__(self, margin=0.1, dist_type='l2'):
        super(hetero_loss, self).__init__()
        self.margin = margin
        self.dist_type = dist_type
        if dist_type == 'l2':
            # Sum of squared element-wise differences (returns a scalar tensor).
            self.dist = nn.MSELoss(reduction='sum')
        elif dist_type == 'cos':
            self.dist = nn.CosineSimilarity(dim=0)
        elif dist_type == 'l1':
            self.dist = nn.L1Loss()
        else:
            # Fail fast here instead of crashing later in forward() with an
            # AttributeError on the missing self.dist.
            raise ValueError("unsupported dist_type: %r" % (dist_type,))

    def forward(self, feat1, feat2, label1, label2):
        """Return the hinged sum of center distances over all classes.

        Args:
            feat1: (N, D) features from modality one, ordered by label.
            feat2: (N, D) features from modality two, same ordering.
            label1: labels for ``feat1``; only its unique count is used.
            label2: unused — presumably kept for interface symmetry with
                ``label1``; assumed identical to ``label1``.

        Returns:
            A scalar tensor (0 if there are no classes).
        """
        label_num = len(label1.unique())
        # chunk() requires equal-sized, label-contiguous groups — see the
        # class-level note.
        feat1 = feat1.chunk(label_num, 0)
        feat2 = feat2.chunk(label_num, 0)
        loss = 0
        for i in range(label_num):
            center1 = torch.mean(feat1[i], dim=0)
            center2 = torch.mean(feat2[i], dim=0)
            if self.dist_type == 'cos':
                gap = 1 - self.dist(center1, center2) - self.margin
            else:  # 'l2' or 'l1'
                gap = self.dist(center1, center2) - self.margin
            # clamp instead of max(0, ...): max(0, t) returns the Python
            # int 0 when t <= 0, which would detach the result from the
            # autograd graph and break .backward() on the returned loss.
            loss = loss + torch.clamp(gap, min=0)
        return loss