ml-finance-python
Python scripts for machine learning in finance
git clone https://9o.is/git/ml-finance-python.git
k_nearest_neighbors.py
(1265B)
1 from __future__ import print_function, division
2 import numpy as np
3 from mlfromscratch.utils import euclidean_distance
4
class KNN():
    """ K Nearest Neighbors classifier.

    Predicts each test sample's class as the most common label among its
    ``k`` closest training samples under Euclidean distance.

    Parameters:
    -----------
    k: int
        The number of closest neighbors that will determine the class of the
        sample that we wish to predict. Must be at least 1; if it exceeds the
        number of training samples, all training samples vote.
    """
    def __init__(self, k=5):
        self.k = k

    def _vote(self, neighbor_labels):
        """ Return the most common class among the neighbor samples.

        Ties are broken in favour of the smallest label. Any sortable labels
        are supported, including negative (e.g. -1/+1) and non-integer values,
        which ``np.bincount`` would reject or silently truncate.
        """
        labels, counts = np.unique(neighbor_labels, return_counts=True)
        # np.unique returns labels sorted ascending and argmax picks the first
        # maximum, so ties resolve to the smallest label.
        return labels[counts.argmax()]

    def predict(self, X_test, X_train, y_train):
        """ Predict a class label for every row of X_test.

        Parameters:
        -----------
        X_test: array-like, shape (n_test, n_features)
            Samples to classify.
        X_train: array-like, shape (n_train, n_features)
            Training samples.
        y_train: array-like, shape (n_train,)
            Numeric class labels of the training samples.

        Returns:
        --------
        numpy.ndarray of float, shape (n_test,)
            The predicted label of each test sample.

        Raises:
        -------
        ValueError
            If k < 1, or if there are test samples but no training samples.
        """
        if self.k < 1:
            raise ValueError("k must be at least 1, got %r" % (self.k,))
        X_test = np.asarray(X_test)
        X_train = np.asarray(X_train)
        y_train = np.asarray(y_train)
        if X_test.shape[0] and len(X_train) == 0:
            raise ValueError("cannot predict with an empty training set")

        y_pred = np.empty(X_test.shape[0])
        # Use a distinct loop name: reusing it inside a comprehension would
        # clobber it on Python 2, where comprehension variables leak.
        for row, test_sample in enumerate(X_test):
            # Squared Euclidean distance to every training sample at once;
            # sqrt is monotonic, so the neighbor ranking is unchanged.
            sq_dist = np.sum((X_train - test_sample) ** 2, axis=1)
            # Stable sort keeps equidistant samples in training order, making
            # the choice among tied neighbors deterministic.
            nearest = np.argsort(sq_dist, kind="stable")[:self.k]
            # Label the sample with the most common class among its neighbors.
            y_pred[row] = self._vote(y_train[nearest])

        return y_pred
34