# regularization.py
import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
from auto_LiRPA import BoundedModule, BoundDataParallel, BoundedTensor, CrossEntropyWrapper
from auto_LiRPA.bound_ops import *
from collections import namedtuple
Node = namedtuple('Node', 'node lower upper')

def compute_L1_reg(args, model, meter):
    # L1 weight regularization over all linear and convolutional layers.
    loss = torch.zeros(()).to(args.device)
    for module in model._modules.values():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            loss += torch.abs(module.weight).sum()
    meter.update('L1_loss', loss)
    return loss * args.l1_coeff
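
# Usage sketch (hypothetical, not taken from this file): compute_L1_reg is
# typically added on top of the main training loss, e.g.
#
#     loss = ce_loss + compute_L1_reg(args, model_ori, meter)
#
# where `ce_loss`, `model_ori` (the unwrapped nn.Module), and `meter` (an object
# exposing update(name, value)) are assumed names for illustration only.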

def compute_reg(args, model, meter, eps, eps_scheduler):
    # Bound-tightness and ReLU-balance regularization computed from the
    # intermediate IBP bounds stored on the BoundedModule's nodes.
    loss = torch.zeros(()).to(args.device)

    # Handle the non-feedforward case
    l0 = torch.zeros_like(loss)
    loss_tightness, loss_std, loss_relu, loss_ratio = (l0.clone() for _ in range(4))

    if isinstance(model, BoundDataParallel):
        modules = list(model._modules.values())[0]._modules
    else:
        modules = model._modules
    node_inp = modules['/input.1']
    tightness_0 = ((node_inp.upper - node_inp.lower) / 2).mean()
    ratio_init = tightness_0 / ((node_inp.upper + node_inp.lower) / 2).std()
    cnt_layers = 0
    cnt = 0
    for m in model._modules.values():
        if isinstance(m, BoundRelu):
            lower, upper = m.inputs[0].lower, m.inputs[0].upper
            center = (upper + lower) / 2
            diff = (upper - lower) / 2
            tightness = diff.mean()
            mean_ = center.mean()
            std_ = center.std()

            loss_tightness += F.relu(args.tol - tightness_0 / tightness.clamp(min=1e-12)) / args.tol
            loss_std += F.relu(args.tol - std_) / args.tol
            cnt += 1

            # L_{relu}: balance the means and variances of the active
            # (lower > 0) and inactive (upper < 0) pre-activation neurons.
            mask_act, mask_inact = lower > 0, upper < 0
            mean_act = (center * mask_act).mean()
            mean_inact = (center * mask_inact).mean()
            delta = (center - mean_) ** 2
            var_act = (delta * mask_act).sum()  # / center.numel()
            var_inact = (delta * mask_inact).sum()  # / center.numel()

            mean_ratio = mean_act / -mean_inact
            var_ratio = var_act / var_inact
            mean_ratio = torch.min(mean_ratio, 1 / mean_ratio.clamp(min=1e-12))
            var_ratio = torch.min(var_ratio, 1 / var_ratio.clamp(min=1e-12))
            loss_relu_ = (
                F.relu(args.tol - mean_ratio) + F.relu(args.tol - var_ratio)
            ) / args.tol
            if not torch.isnan(loss_relu_) and not torch.isinf(loss_relu_):
                loss_relu += loss_relu_

            if args.debug:
                bn_mean = (lower.mean() + upper.mean()) / 2
                bn_var = ((upper ** 2 + lower ** 2) / 2).mean() - bn_mean ** 2
                print(m.name, m,
                      'tightness {:.4f} gain {:.4f} std {:.4f}'.format(
                          tightness.item(), (tightness / tightness_0).item(), std_.item()),
                      'input', m.inputs[0], m.inputs[0].name,
                      'active {:.4f} inactive {:.4f}'.format(
                          (lower > 0).float().sum() / lower.numel(),
                          (upper < 0).float().sum() / lower.numel()),
                      'bnv2_mean {:.5f} bnv2_var {:.5f}'.format(bn_mean.item(), bn_var.item()))
                # pre-bn
                lower, upper = m.inputs[0].inputs[0].lower, m.inputs[0].inputs[0].upper
                bn_mean = (lower.mean() + upper.mean()) / 2
                bn_var = ((upper ** 2 + lower ** 2) / 2).mean() - bn_mean ** 2
                print('pre-bn',
                      'bnv2_mean {:.5f} bnv2_var {:.5f}'.format(bn_mean.item(), bn_var.item()))

    loss_tightness /= cnt
    loss_std /= cnt
    loss_relu /= cnt

    if args.debug:
        pdb.set_trace()

    for item in ['tightness', 'relu', 'std']:
        loss_ = eval('loss_{}'.format(item))
        if item in args.reg_obj:
            loss += loss_
        meter.update('L_{}'.format(item), loss_)

    meter.update('loss_reg', loss)

    # The regularization intensity decays as eps ramps up, unless disabled.
    if args.no_reg_dec:
        intensity = args.reg_lambda
    else:
        intensity = args.reg_lambda * (1 - eps_scheduler.get_eps() / eps_scheduler.get_max_eps())
    loss *= intensity

    return loss
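
# Minimal self-contained sketch exercising compute_L1_reg on its own. The
# SimpleNamespace args and the _PrintMeter stub below are hypothetical stand-ins
# for the real argparse namespace and statistics meter used by the training
# code; compute_reg is not exercised here because it needs a BoundedModule
# whose BoundRelu nodes already carry .lower/.upper bounds from an IBP forward
# pass.
if __name__ == '__main__':
    from types import SimpleNamespace

    class _PrintMeter:
        # Stand-in for the training meter: just prints each reported value.
        def update(self, name, value):
            print('{}: {:.6f}'.format(name, float(value)))

    demo_args = SimpleNamespace(device='cpu', l1_coeff=1e-5)
    demo_model = nn.Sequential(
        nn.Conv2d(3, 8, 3),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(8 * 30 * 30, 10),
    )
    l1_loss = compute_L1_reg(demo_args, demo_model, _PrintMeter())
    print('scaled L1 loss: {:.6f}'.format(l1_loss.item()))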