ml-finance-python
python scripts for finance machine learning
git clone https://9o.is/git/ml-finance-python.git
linear_discriminant_analysis.py
(1395B)
1 from __future__ import print_function, division
2 import numpy as np
3 from mlfromscratch.utils import calculate_covariance_matrix, normalize, standardize
4
class LDA(object):
    """Two-class Linear Discriminant Analysis (Fisher's linear discriminant).

    Besides classification, the learned direction can also be used to reduce
    the dimensionality of a dataset by projecting it onto a single axis.

    Only the labels ``0`` and ``1`` are recognised; rows of ``X`` carrying any
    other label are ignored by ``fit``.

    Attributes:
        w: Projection vector ``pinv(cov0 + cov1).dot(mean0 - mean1)``;
            ``None`` until ``fit`` has been called.
        threshold: Projection value at which ``predict`` switches class.
            ``fit`` sets it to the projection of the midpoint between the two
            class means. It defaults to ``0.0`` so that an instance whose
            ``w`` was assigned by hand keeps a boundary through the origin.
    """

    def __init__(self):
        self.w = None
        self.threshold = 0.0

    def transform(self, X, y):
        """Fit on ``(X, y)`` and return ``X`` projected onto ``w``.

        Note that this refits the model on every call.

        Args:
            X: Sample matrix, one row per sample.
            y: Class labels (0 or 1), one per row of ``X``.

        Returns:
            The result of ``X.dot(w)``: one projected value per sample.
        """
        self.fit(X, y)
        # Project data onto the discriminant vector
        return np.asarray(X).dot(self.w)

    def fit(self, X, y):
        """Learn the discriminant direction and decision threshold.

        Args:
            X: Sample matrix, one row per sample.
            y: Class labels (0 or 1), one per row of ``X``.

        Raises:
            ValueError: If either class 0 or class 1 has no samples.
        """
        # Coerce to arrays: with a plain list, ``y == 0`` would be a scalar
        # ``False`` and silently select nothing.
        X = np.asarray(X)
        y = np.asarray(y)

        # Separate data by class
        X0 = X[y == 0]
        X1 = X[y == 1]
        if len(X0) == 0 or len(X1) == 0:
            raise ValueError(
                "LDA.fit needs at least one sample of class 0 and one of class 1"
            )

        # Sum of the per-class covariance matrices (within-class scatter).
        # calculate_covariance_matrix is a project helper; presumably it
        # returns a (n_features, n_features) matrix -- confirm against utils.
        cov_tot = calculate_covariance_matrix(X0) + calculate_covariance_matrix(X1)

        # Class means; atleast_1d keeps the algebra valid for 1-D input,
        # where the means are scalars.
        mean0 = X0.mean(0)
        mean1 = X1.mean(0)
        mean_diff = np.atleast_1d(mean0 - mean1)

        # Direction that best separates the classes when X is projected onto
        # it: w = (cov0 + cov1)^-1 (mean0 - mean1). The pseudo-inverse keeps
        # this defined when the scatter matrix is singular.
        self.w = np.linalg.pinv(cov_tot).dot(mean_diff)

        # Class 0 projects higher than class 1 along w, so the boundary is
        # the projection of the midpoint between the means. Thresholding at
        # 0 instead is only correct for data centred on that midpoint.
        midpoint = np.atleast_1d(mean0 + mean1) / 2.0
        self.threshold = self.w.dot(midpoint)

    def predict(self, X):
        """Classify each sample of ``X`` as 0 or 1.

        Args:
            X: Sample matrix, one row per sample. A 1-D array is treated as
                a sequence of single-feature samples.

        Returns:
            A list with one label per sample: 1 where the projection onto
            ``w`` falls below ``threshold``, otherwise 0.

        Raises:
            ValueError: If the model has not been fitted (``w`` is ``None``).
        """
        if self.w is None:
            raise ValueError(
                "LDA instance is not fitted yet; call fit() before predict()"
            )

        X = np.asarray(X)
        if X.size == 0:
            return []
        if X.ndim == 1:
            # n single-feature samples -> column matrix
            X = X.reshape(-1, 1)

        # Project all samples at once instead of looping in Python
        h = X.dot(np.atleast_1d(self.w))
        return (h < self.threshold).astype(int).tolist()