-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
69 lines (68 loc) · 2.63 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import copy
import numpy as np
def DPTrain(images, teacher, student, criterion_s, lr_gamma, clipping='auto', c=1e-4, e=0.001, sigma=100, batch_size=256, multiT=False, n_teachers=None, teachers=None):
    '''
    Produce a differentially-private target output for the student: take the
    gradient of the student-vs-teacher loss w.r.t. the student's output,
    clip or normalize it, average it over the batch, add Gaussian noise, and
    step the student output in the negative gradient direction.

    :param images: synthetic data batch (generated by the generator in our case)
    :param teacher: teacher model (when multiT is True the per-teacher models
        in `teachers` are used instead; `teacher` is still queried once for
        its output shape/dtype)
    :param student: student model
    :param criterion_s: loss between student and teacher outputs (KLDiv in our case)
    :param lr_gamma: step size applied to the noised gradient
    :param clipping: 'abadi' clamps each gradient entry to [-c, c]; any other
        value divides the gradient by (|max entry| + e)
    :param c: clipping bound for 'abadi' mode
    :param e: positive stability constant for the normalizing mode (0.001 in our case)
    :param sigma: Gaussian noise scale parameter
    :param batch_size: divisor used when averaging the gradient and the noise
    :param multiT: aggregate gradients over a list of teachers
    :param n_teachers: number of teachers when multiT is True
    :param teachers: list of teacher models when multiT is True
    :return: s_out_new, a DP-perturbed version of the student's output that
        the caller then trains the student towards
    '''
    # Only the teacher forward pass is grad-free: s_out must keep its graph
    # (and requires_grad) so retain_grad()/backward() can populate s_out.grad.
    with torch.no_grad():
        t_out = teacher(images)
    s_out = student(images.detach())
    if multiT:
        g = []
        for j in range(n_teachers):
            t_out = teachers[j](images)
            loss = criterion_s(s_out, t_out.detach())
            s_out.retain_grad()
            loss.backward(retain_graph=True)
            # NOTE(review): s_out.grad is never zeroed between teachers, so
            # each iteration's s_g also contains the gradients of all earlier
            # teachers -- confirm this accumulation is intended.
            s_g = s_out.grad
            if clipping == 'abadi':
                with torch.no_grad():
                    s_g = torch.clamp(s_g, min=-1 * c, max=c)
            else:
                with torch.no_grad():
                    # (max**2)**0.5 == |largest entry|, used as the normalizer.
                    norm = (torch.max(s_g) ** 2) ** (0.5)
                    s_g = s_g / (norm + e)
            g.append(s_g)
        s_g = sum(g) / n_teachers
    else:
        loss_old = criterion_s(s_out, t_out.detach())
        s_out.retain_grad()
        loss_old.backward(retain_graph=True)
        s_g = s_out.grad
        if clipping == 'abadi':
            with torch.no_grad():
                s_g = torch.clamp(s_g, min=-1 * c, max=c)
        else:
            with torch.no_grad():
                norm = (torch.max(s_g) ** 2) ** (0.5)
                s_g = s_g / (norm + e)
    # Replace every per-sample gradient with the batch-averaged gradient.
    with torch.no_grad():
        s_ = sum(s_g) / batch_size
        for k in range(len(s_g)):
            s_g[k] = s_
    # Gaussian noise, scaled by the clipping bound c in 'abadi' mode (the
    # bound is the sensitivity). Placed on t_out's device rather than a
    # hard-coded .cuda() so the routine also runs on CPU; behavior on GPU
    # inputs is unchanged.
    if clipping == 'abadi':
        noise = torch.tensor(np.random.normal(0, sigma * c, size=t_out.shape) / batch_size, dtype=t_out.dtype).to(t_out.device)
    else:
        noise = torch.tensor(np.random.normal(0, sigma, size=t_out.shape) / batch_size, dtype=t_out.dtype).to(t_out.device)
    # Perturbed student output; no graph is kept on the DP target.
    with torch.no_grad():
        s_out_new = s_out - lr_gamma * (s_g + noise)
    return s_out_new
# Typical usage:
#   s_out_new = DPTrain(...)
#   loss = criterion(s_out, s_out_new)
#   ...then backpropagate the loss to update the student.