ml-finance-python

Python scripts for machine learning in finance

git clone https://9o.is/git/ml-finance-python.git

linear_discriminant_analysis.py

(929B)


      1 from __future__ import print_function
      2 from sklearn import datasets
      3 import matplotlib.pyplot as plt
      4 import numpy as np
      5 
      6 from mlfromscratch.supervised_learning import LDA
      7 from mlfromscratch.utils import calculate_covariance_matrix, accuracy_score
      8 from mlfromscratch.utils import normalize, standardize, train_test_split, Plot
      9 from mlfromscratch.unsupervised_learning import PCA
     10 
     11 def main():
     12     # Load the dataset
     13     data = datasets.load_iris()
     14     X = data.data
     15     y = data.target
     16 
     17     # Three -> two classes
     18     X = X[y != 2]
     19     y = y[y != 2]
     20 
     21     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)
     22 
     23     # Fit and predict using LDA
     24     lda = LDA()
     25     lda.fit(X_train, y_train)
     26     y_pred = lda.predict(X_test)
     27 
     28     accuracy = accuracy_score(y_test, y_pred)
     29 
     30     print ("Accuracy:", accuracy)
     31 
     32     Plot().plot_in_2d(X_test, y_pred, title="LDA", accuracy=accuracy)
     33 
     34 if __name__ == "__main__":
     35     main()