ml-finance-python

Python scripts for machine learning in finance.

git clone https://9o.is/git/ml-finance-python.git

linear_discriminant_analysis.py

(1395B)


      1 from __future__ import print_function, division
      2 import numpy as np
      3 from mlfromscratch.utils import calculate_covariance_matrix, normalize, standardize
      4 
      5 class LDA():
      6     """The Linear Discriminant Analysis classifier, also known as Fisher's linear discriminant.
      7     Can besides from classification also be used to reduce the dimensionaly of the dataset.
      8     """
      9     def __init__(self):
     10         self.w = None
     11 
     12     def transform(self, X, y):
     13         self.fit(X, y)
     14         # Project data onto vector
     15         X_transform = X.dot(self.w)
     16         return X_transform
     17 
     18     def fit(self, X, y):
     19         # Separate data by class
     20         X1 = X[y == 0]
     21         X2 = X[y == 1]
     22 
     23         # Calculate the covariance matrices of the two datasets
     24         cov1 = calculate_covariance_matrix(X1)
     25         cov2 = calculate_covariance_matrix(X2)
     26         cov_tot = cov1 + cov2
     27 
     28         # Calculate the mean of the two datasets
     29         mean1 = X1.mean(0)
     30         mean2 = X2.mean(0)
     31         mean_diff = np.atleast_1d(mean1 - mean2)
     32 
     33         # Determine the vector which when X is projected onto it best separates the
     34         # data by class. w = (mean1 - mean2) / (cov1 + cov2)
     35         self.w = np.linalg.pinv(cov_tot).dot(mean_diff)
     36 
     37     def predict(self, X):
     38         y_pred = []
     39         for sample in X:
     40             h = sample.dot(self.w)
     41             y = 1 * (h < 0)
     42             y_pred.append(y)
     43         return y_pred