-
Notifications
You must be signed in to change notification settings - Fork 7
/
utils.py
55 lines (48 loc) · 2.16 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import logging
import math
import numpy as np
import sys
# Module-level logging setup, executed as a side effect of importing this file.
# NOTE(review): logging.getLogger() with no name returns the *root* logger, so
# this configures logging for the whole process, not just this module.
logger = logging.getLogger()
# Records below INFO are dropped at the logger, before any handler sees them.
logger.setLevel(logging.INFO)
# Send log output to stdout (StreamHandler defaults to stderr otherwise).
ch = logging.StreamHandler(sys.stdout)
# The handler itself would accept DEBUG, but the INFO level set on the logger
# above filters first, so DEBUG records reach it only if a descendant logger
# lowers its own level.
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
# NOTE(review): reloading this module would attach a second, duplicate handler.
logger.addHandler(ch)
def angular_error_scalar(output, labels, debug=False):
    """Return the mean angular error, in degrees, between predictions and labels.

    Each row of ``output`` and ``labels`` supplies two components (columns 0
    and 1, presumably red and green chromaticity -- TODO confirm against the
    caller). A third component is derived as ``1 - c0 - c1``, and the angle
    between the resulting 3-vectors is computed row by row.

    Only the first three columns of the augmented array enter the angle. For
    the intended ``(N, 2)`` input these are c0, c1 and the derived component.
    If the input already has three or more columns, its own column 2 is used
    instead of the derived value. This legacy behaviour is kept for
    compatibility.

    Args:
        output: array-like of shape ``(N, >=2)`` with the predictions.
        labels: array-like of shape ``(N, >=2)`` with the ground truth; must
            have the same number of rows as ``output``.
        debug: unused; kept for backward compatibility with existing callers.

    Returns:
        The mean of the per-row angles, in degrees, as a numpy float. A row
        where either vector has zero length produces NaN (numpy also emits a
        RuntimeWarning), which makes the mean NaN. An empty ``(0, 2)`` input
        also yields NaN.
    """
    output_vecs = _with_derived_component(output)
    labels_vecs = _with_derived_component(labels)

    dots = np.sum(output_vecs * labels_vecs, axis=1)
    output_norms = np.sqrt(np.sum(output_vecs * output_vecs, axis=1))
    labels_norms = np.sqrt(np.sum(labels_vecs * labels_vecs, axis=1))
    cosines = dots / (output_norms * labels_norms)

    # Floating-point rounding can push the cosine of (anti)parallel vectors a
    # few ulps outside [-1, 1], where arccos is undefined. Clamp both ends,
    # not just the upper one. np.clip leaves NaN (zero-length rows) as NaN.
    cosines = np.clip(cosines, -1.0, 1.0)
    return np.mean(np.arccos(cosines) * 180 / math.pi)


def _with_derived_component(values):
    """Append ``1 - col0 - col1`` as a new last column and keep the first three columns."""
    values = np.asarray(values, dtype=float)
    derived = 1 - values[:, 0] - values[:, 1]
    return np.column_stack((values, derived))[:, :3]