-
Notifications
You must be signed in to change notification settings - Fork 0
/
cifar100_knn.py
65 lines (52 loc) · 2.38 KB
/
cifar100_knn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
import math
import operator
class NearestNeighbour(object):
    """K-nearest-neighbour classifier (L1 distance) that predicts two labels.

    Every training example carries two labels, a "fine" one (``y``) and a
    "coarse" one (``z``).  Prediction finds the ``K`` training rows closest
    to each query row and takes an independent majority vote for each of the
    two label sets.
    """

    def __init__(self):
        # Populated by train(); None means "not trained yet".
        self.Xtr = None
        self.ytr = None
        self.ztr = None

    def train(self, X, y, z):
        """Memorise the training data; no model fitting is performed.

        Args:
            X: N x D array-like, one training example per row.
            y: length-N array-like of fine labels.
            z: length-N array-like of coarse labels.
        """
        # np.asarray is a no-op for ndarrays and lets callers pass lists too.
        self.Xtr = np.asarray(X)
        self.ytr = np.asarray(y)
        self.ztr = np.asarray(z)

    def predict(self, X, K, *, verbose=True):
        """Predict the fine and coarse label of every row of ``X``.

        Args:
            X: M x D array-like, one query example per row.
            K: number of neighbours that vote.  Values larger than the number
                of training examples are clamped to that number.
            verbose: when true (the default) print one line per query row
                with the two predicted labels.

        Returns:
            A tuple ``(Ypred, Zpred)`` of length-M arrays holding the fine
            and coarse predictions, with the dtypes of the training labels.

        Raises:
            AttributeError: if ``train`` has not been called.
            ValueError: if ``K`` is smaller than 1 or the training set is
                empty.
        """
        if self.Xtr is None:
            raise AttributeError("train() must be called before predict()")
        if K < 1:
            raise ValueError("K must be at least 1, got " + str(K))

        X = np.asarray(X)
        num_train = self.Xtr.shape[0]
        if num_train == 0:
            raise ValueError("cannot predict from an empty training set")

        # Asking for more neighbours than exist means "use all of them".
        k = min(K, num_train)

        num_test = X.shape[0]
        # Make sure the output types match the training label types.
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
        Zpred = np.zeros(num_test, dtype=self.ztr.dtype)

        wide = self._difference_dtype(self.Xtr.dtype, X.dtype)

        for i in range(num_test):
            row = X[i, :]
            if wide is not None:
                # Widening only the query row is enough: NumPy promotes the
                # subtraction to the wider signed type, so e.g. uint8 pixel
                # differences no longer wrap around modulo 256, and no
                # widened copy of the whole training set is kept around.
                row = row.astype(wide)

            # L1 distance (sum of absolute differences) to every training row.
            distances = np.sum(np.abs(self.Xtr - row), axis=1)

            if k == 1:
                # Single neighbour: a linear scan beats a full sort.
                nearest = np.argmin(distances)
                Ypred[i] = self.ytr[nearest]
                Zpred[i] = self.ztr[nearest]
            else:
                # A stable sort keeps equally distant rows in training order,
                # which makes the vote below deterministic.
                nearest = np.argsort(distances, axis=-1, kind='mergesort')[:k]
                Zpred[i] = self._majority_label(self.ztr[nearest])
                Ypred[i] = self._majority_label(self.ytr[nearest])

            if verbose:
                print("Test case " + str(i) + ": \t Predicted: " + str(Ypred[i]) + "\t Predicted: " + str(Zpred[i]))

        return Ypred, Zpred

    @staticmethod
    def _difference_dtype(train_dtype, test_dtype):
        """Return a signed dtype that can hold train - test, or None.

        None means the operands are not integers, so the subtraction needs no
        widening.  For integers the result has twice the bits of the common
        type, capped at 64 bits.
        """
        common = np.result_type(train_dtype, test_dtype)
        if common.kind not in "iu":
            return None
        bits = min(64, common.itemsize * 16)
        return np.dtype("int" + str(bits))

    @staticmethod
    def _majority_label(labels):
        """Return the most frequent label; ties go to the earliest one seen."""
        counts = {}
        for label in labels:
            counts[label] = counts.get(label, 0) + 1
        # max() returns the first maximal item and dicts preserve insertion
        # order, so with labels ordered nearest-first the nearest class wins.
        return max(counts.items(), key=operator.itemgetter(1))[0]