-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
executable file
·124 lines (98 loc) · 4.38 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
Implementation of the YOLO (v1) loss function from the original YOLO paper.
"""
import torch
import torch.nn as nn
from utils import intersection_over_union
class YoloLoss(nn.Module):
    """
    Calculate the loss for the YOLO (v1) model.

    Per grid cell, ``forward`` reads predictions laid out as
    ``[C class scores, then B groups of (confidence, 4 box values)]`` and
    targets laid out as ``[C class values, object indicator, 4 box values, ...]``.
    Within each group of 4 box values, indices 2:4 are treated as width and
    height.
    """

    def __init__(self, S=7, B=2, C=20):
        """
        Args:
            S: Split size of the image grid (7 in the paper).
            B: Number of boxes predicted per cell (2 in the paper).
            C: Number of classes (20 in the paper and for VOC).
        """
        super().__init__()
        self.mse = nn.MSELoss(reduction="sum")
        self.S = S
        self.B = B
        self.C = C
        # These are from the YOLO paper, signifying how much we should weight
        # the loss for no object (noobj) and for the box coordinates (coord).
        self.lambda_noobj = 0.5
        self.lambda_coord = 5

    @staticmethod
    def _signed_sqrt(x, eps=1e-6):
        """Return ``sign(x) * sqrt(|x| + eps)``.

        ``eps`` is added *outside* ``abs`` so the sqrt argument is never 0.
        With ``sqrt(abs(x + eps))``, an input of exactly ``-eps`` yields a NaN
        gradient.
        """
        return torch.sign(x) * torch.sqrt(torch.abs(x) + eps)

    @staticmethod
    def _select(candidates, index):
        """Pick ``candidates[index]`` element-wise using 0/1 masks.

        Args:
            candidates: Sequence of same-shaped tensors.
            index: Integer tensor, broadcastable against each candidate, that
                holds the position of the candidate to keep.

        Only the chosen candidate receives a non-zero gradient.
        """
        return sum((index == i).to(c.dtype) * c for i, c in enumerate(candidates))

    def forward(self, predictions, target):
        """
        Compute the summed YOLO v1 loss over a batch.

        Args:
            predictions: Network output with ``S * S * (C + B * 5)`` values per
                sample; reshaped here to ``(batch, S, S, C + B * 5)``.
            target: Ground truth indexed per cell as described on the class;
                assumed to already be shaped ``(batch, S, S, ...)`` (TODO confirm
                against the dataset).

        Returns:
            Scalar tensor ``lambda_coord * box_loss + object_loss
            + lambda_noobj * no_object_loss + class_loss``.
        """
        C, B = self.C, self.B
        predictions = predictions.reshape(-1, self.S, self.S, C + B * 5)

        # Target channel C is the object indicator (Iobj_i in the paper).
        exists_box = target[..., C:C + 1]
        target_box = target[..., C + 1:C + 5]

        # Per-predictor confidence (..., 1) and box (..., 4) slices.
        pred_confs = [predictions[..., C + 5 * b:C + 5 * b + 1] for b in range(B)]
        pred_boxes = [predictions[..., C + 5 * b + 1:C + 5 * b + 5] for b in range(B)]

        # The predictor whose box has the highest IoU with the target box is
        # "responsible" for it. bestbox holds that predictor's index per cell.
        # NOTE(review): assumes intersection_over_union keeps a trailing
        # size-1 dim so bestbox broadcasts against (..., 4) slices.
        ious = torch.stack(
            [intersection_over_union(box, target_box) for box in pred_boxes], dim=0
        )
        bestbox = torch.max(ious, dim=0).indices

        # ======================== #
        #   FOR BOX COORDINATES    #
        # ======================== #
        # Zero out cells without an object and keep only the responsible box.
        box_predictions = exists_box * self._select(pred_boxes, bestbox)
        box_targets = exists_box * target_box
        # Compare sqrt(width) and sqrt(height) so the same absolute error costs
        # more on small boxes than on large ones, as motivated in the paper.
        # The signed sqrt tolerates negative raw width/height predictions.
        box_predictions = torch.cat(
            [box_predictions[..., :2], self._signed_sqrt(box_predictions[..., 2:4])],
            dim=-1,
        )
        box_targets = torch.cat(
            [box_targets[..., :2], torch.sqrt(box_targets[..., 2:4])], dim=-1
        )
        box_loss = self.mse(box_predictions, box_targets)

        # ==================== #
        #   FOR OBJECT LOSS    #
        # ==================== #
        # Confidence of the responsible predictor, regressed toward target
        # channel C (the object indicator itself) in cells containing an object.
        pred_conf = self._select(pred_confs, bestbox)
        object_loss = self.mse(exists_box * pred_conf, exists_box * exists_box)

        # ======================= #
        #   FOR NO OBJECT LOSS    #
        # ======================= #
        # In cells without an object, every predictor's confidence is penalized.
        no_obj = 1 - exists_box
        no_object_loss = sum(
            self.mse(no_obj * conf, no_obj * exists_box) for conf in pred_confs
        )

        # ================== #
        #   FOR CLASS LOSS   #
        # ================== #
        class_loss = self.mse(
            exists_box * predictions[..., :C], exists_box * target[..., :C]
        )

        loss = (
            self.lambda_coord * box_loss  # first two rows in paper
            + object_loss  # third row in paper
            + self.lambda_noobj * no_object_loss  # fourth row
            + class_loss  # fifth row
        )
        return loss