-
Notifications
You must be signed in to change notification settings - Fork 0
/
knn.py
70 lines (59 loc) · 2.52 KB
/
knn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import numpy as np
class NearestNeighbour(object):
    """k-nearest-neighbour classifier using the L1 (Manhattan) distance.

    Training only memorises the data; all computation happens in ``predict``.
    """

    # Display names for class indices 0-9 (these are the CIFAR-10 class names).
    # In the K > 1 path, votes are only counted for labels in range(len(LABELS)).
    LABELS = ('plane', 'car', 'bird', 'cat', 'deer',
              'dog', 'frog', 'horse', 'ship', 'truck')

    def __init__(self):
        # Populated by train(); predict() refuses to run while these are None.
        self.Xtr = None
        self.ytr = None

    def train(self, X, y):
        """Remember the training data.

        Args:
            X: array-like of shape (N, D); each row is one training example.
            y: array-like of length N holding integer class labels.
        """
        # np.asarray does not copy an existing ndarray, but lets plain lists
        # work too (predict relies on ndarray attributes such as ``.dtype``).
        self.Xtr = np.asarray(X)
        self.ytr = np.asarray(y)

    @staticmethod
    def _difference_dtype(a, b):
        """Return a dtype in which ``a - b`` cannot silently wrap around."""
        dtype = np.result_type(a, b)
        if np.issubdtype(dtype, np.integer):
            # Unsigned data (e.g. uint8 pixels) wraps modulo 2**bits on
            # subtraction, and narrow signed types overflow; either corrupts
            # the distances, so do the arithmetic in int64 instead.
            return np.dtype(np.int64)
        return dtype

    def predict(self, X, K, *, verbose=True):
        """Predict a label for each row of ``X`` from its K nearest training rows.

        With K == 1 the label of the single nearest training row is used.
        Otherwise the K nearest rows vote and the most common label wins,
        with ties going to the smallest class index. Neighbour labels outside
        ``range(len(LABELS))`` get no vote, and 'Error - Invalid class' is printed.

        Args:
            X: array-like of shape (M, D); each row is an example to classify.
            K: number of neighbours to consult; must satisfy 1 <= K <= N.
            verbose: if True (the default), print each prediction using the
                names in ``LABELS``.

        Returns:
            ndarray of length M with the predicted labels, in the dtype of the
            training labels.

        Raises:
            RuntimeError: if ``train`` has not been called.
            ValueError: if K is outside [1, N], where N is the number of
                training rows.
        """
        if self.Xtr is None or self.ytr is None:
            raise RuntimeError('train() must be called before predict()')
        X = np.asarray(X)
        num_train = self.Xtr.shape[0]
        if not 1 <= K <= num_train:
            raise ValueError(
                f'K must be between 1 and {num_train} (training size), got {K}')

        num_classes = len(self.LABELS)
        num_test = X.shape[0]
        # Make sure the output type matches the training label type.
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
        # Decided once per call: a dtype in which subtraction cannot wrap.
        work_dtype = self._difference_dtype(self.Xtr, X)

        for i in range(num_test):
            # L1 distance (sum of absolute differences) to every training row.
            diff = np.subtract(self.Xtr, X[i], dtype=work_dtype)
            distances = np.abs(diff).sum(axis=1)
            if K == 1:
                # argmin returns the first minimum: the same neighbour a
                # stable sort would place first.
                Ypred[i] = self.ytr[np.argmin(distances)]
            else:
                # A stable sort keeps training-set order among equal
                # distances, so the choice of neighbours is deterministic.
                nearest = np.argsort(distances, kind='mergesort')[:K]
                votes = np.zeros(num_classes)
                for label in self.ytr[nearest]:
                    if 0 <= label < num_classes:
                        votes[label] += 1
                    else:
                        print('Error - Invalid class')
                # Majority vote; np.argmax breaks ties toward the lowest index.
                Ypred[i] = np.argmax(votes)
            if verbose:
                print(f"Test case {i}: \t Predicted label is:"
                      f"{self.LABELS[Ypred[i]]}")
        return Ypred