-
Notifications
You must be signed in to change notification settings - Fork 1
/
criteria.py
88 lines (72 loc) · 3.05 KB
/
criteria.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
import torch.nn as nn
# Identifiers for the masked regression losses defined in this module;
# presumably 'l1' -> MaskedL1Loss and 'l2' -> MaskedMSELoss (verify in callers).
loss_names = ["l1", "l2"]
class MaskedMSELoss(nn.Module):
    """Mean squared error restricted to entries where ``target > 0``.

    Entries of ``target`` that are zero or negative are treated as missing and
    excluded from the average. The most recent result is also stored in
    ``self.loss``.
    """

    def __init__(self):
        super(MaskedMSELoss, self).__init__()

    def forward(self, pred, target):
        """Return the masked MSE between ``pred`` and ``target``.

        Args:
            pred: prediction tensor with the same rank as ``target``. It is
                expected to have the same shape, because the boolean mask is
                built from ``target`` and applied to ``target - pred``.
            target: reference tensor; only entries ``> 0`` contribute.

        Returns:
            A scalar tensor. If no entry of ``target`` is positive, a zero
            tensor (still attached to ``pred``'s autograd graph) is returned
            instead of the NaN that ``mean()`` of an empty tensor produces.

        Raises:
            AssertionError: if ``pred`` and ``target`` differ in rank.
        """
        assert pred.dim() == target.dim(), "inconsistent dimensions"
        # Comparison results never require grad, so no detach() is needed.
        valid_mask = target > 0
        diff = (target - pred)[valid_mask]
        if diff.numel() == 0:
            # mean() of an empty tensor is NaN and would poison the gradients;
            # sum() of an empty tensor is an exact, differentiable zero.
            self.loss = diff.sum()
        else:
            self.loss = (diff ** 2).mean()
        return self.loss
class MaskedL1Loss(nn.Module):
    """Mean absolute error restricted to entries where ``target > 0``.

    Entries of ``target`` that are zero or negative are treated as missing and
    excluded from the average. The most recent result is also stored in
    ``self.loss``.
    """

    def __init__(self):
        super(MaskedL1Loss, self).__init__()

    def forward(self, pred, target, weight=None):
        """Return the masked L1 error between ``pred`` and ``target``.

        Args:
            pred: prediction tensor with the same rank as ``target``. It is
                expected to have the same shape, because the boolean mask is
                built from ``target`` and applied to ``target - pred``.
            target: reference tensor; only entries ``> 0`` contribute.
            weight: accepted for interface compatibility but not used.

        Returns:
            A scalar tensor. If no entry of ``target`` is positive, a zero
            tensor (still attached to ``pred``'s autograd graph) is returned
            instead of the NaN that ``mean()`` of an empty tensor produces.

        Raises:
            AssertionError: if ``pred`` and ``target`` differ in rank.
        """
        assert pred.dim() == target.dim(), "inconsistent dimensions"
        # Comparison results never require grad, so no detach() is needed.
        valid_mask = target > 0
        diff = (target - pred)[valid_mask]
        if diff.numel() == 0:
            # mean() of an empty tensor is NaN and would poison the gradients;
            # sum() of an empty tensor is an exact, differentiable zero.
            self.loss = diff.sum()
        else:
            self.loss = diff.abs().mean()
        return self.loss
class PhotometricLoss(nn.Module):
    """Masked L1 photometric error between ``target`` and ``recon``.

    Per-pixel absolute differences are summed over dim 1 (the colour channel)
    and averaged over pixels that are non-black in both images and, when
    ``mask`` is given, selected by ``mask``. The most recent result is also
    stored in ``self.loss``.
    """

    def __init__(self):
        super(PhotometricLoss, self).__init__()

    def forward(self, target, recon, mask=None):
        """Return the masked photometric error.

        Args:
            target: 4-D image tensor; dim 1 is summed as the colour channel.
            recon: reconstruction with exactly the same size as ``target``.
            mask: optional extra validity mask. It is passed through
                ``torch.squeeze`` and must broadcast against the shape of
                ``target.sum(1)``; nonzero entries mark valid pixels.
                Presumably (B, 1, H, W) and binary -- TODO confirm in callers.

        Returns:
            A scalar tensor with the mean error over valid pixels, or the
            Python int ``0`` (after printing a warning) when there are none.

        Raises:
            AssertionError: if either input is not 4-D or their sizes differ.
        """
        assert recon.dim() == 4, \
            "expected recon dimension to be 4, but instead got {}.".format(recon.dim())
        assert target.dim() == 4, \
            "expected target dimension to be 4, but instead got {}.".format(target.dim())
        assert recon.size() == target.size(), \
            "expected recon and target to have the same size, but got {} and {} instead".format(
                recon.size(), target.size())

        diff = torch.sum((target - recon).abs(), 1)  # sum along the color channel
        # Compare only pixels that are not black in either image. A bool mask
        # is used because indexing with a uint8 mask (the former `.byte()`)
        # is deprecated in PyTorch in favour of torch.bool.
        valid_mask = (torch.sum(recon, 1) > 0) & (torch.sum(target, 1) > 0)
        if mask is not None:
            # Nonzero entries count as valid. For binary masks this matches
            # the old float -> uint8 cast, which truncated values in (0, 1)
            # to 0.
            valid_mask = valid_mask & (torch.squeeze(mask) != 0)

        # numel() counts all elements, not valid ones, so this guard only
        # fails for an empty input. An all-invalid mask reaches the
        # nelement() == 0 branch below instead.
        if valid_mask.numel() > 0:
            diff = diff[valid_mask]
            if diff.nelement() > 0:
                self.loss = diff.mean()
            else:
                print(
                    "warning: diff.nelement()==0 in PhotometricLoss (this is expected during early stage of training, try larger batch size)."
                )
                self.loss = 0
        else:
            print("warning: 0 valid pixel in PhotometricLoss")
            self.loss = 0
        return self.loss
class SmoothnessLoss(nn.Module):
    """Smoothness penalty built from discrete second derivatives.

    For a 4-D input the loss is the mean, over interior positions of the last
    two dims, of ``|d2/dx2| + |d2/dy2|``. Each second derivative is the central
    difference ``2*x[i] - x[i-1] - x[i+1]``. The result is also stored in
    ``self.loss``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, depth):
        """Return the second-derivative smoothness loss of ``depth``.

        NOTE: if either of the last two dims is smaller than 3 there are no
        interior positions, and the mean of the resulting empty tensor is NaN.

        Raises:
            AssertionError: if ``depth`` is not 4-D.
        """
        assert depth.dim() == 4, \
            "expected 4-dimensional data, but instead got {}".format(depth.dim())
        core = depth[:, :, 1:-1, 1:-1]
        # Neighbours along the last dim ...
        west, east = depth[:, :, 1:-1, :-2], depth[:, :, 1:-1, 2:]
        # ... and along dim 2.
        north, south = depth[:, :, :-2, 1:-1], depth[:, :, 2:, 1:-1]
        along_w = (2 * core - west - east).abs()
        along_h = (2 * core - north - south).abs()
        self.loss = (along_w + along_h).mean()
        return self.loss