ml-finance-python
Python scripts for machine learning in finance.
git clone https://9o.is/git/ml-finance-python.git
linear_discriminant_analysis.py
(929B)
1 from __future__ import print_function
2 from sklearn import datasets
3 import matplotlib.pyplot as plt
4 import numpy as np
5
6 from mlfromscratch.supervised_learning import LDA
7 from mlfromscratch.utils import calculate_covariance_matrix, accuracy_score
8 from mlfromscratch.utils import normalize, standardize, train_test_split, Plot
9 from mlfromscratch.unsupervised_learning import PCA
10
def main():
    """Run a small Linear Discriminant Analysis demo on the Iris dataset.

    Keeps only the samples whose target label is not 2, so two classes
    remain. Holds out 33% of them for testing and fits ``LDA`` on the rest.
    Prints the accuracy on the held-out split and passes the test samples,
    together with their *predicted* labels, to ``Plot().plot_in_2d``.
    """
    iris = datasets.load_iris()

    # Drop class 2 so that only two classes remain.
    # NOTE(review): presumably because this LDA implementation is a
    # two-class classifier — confirm against mlfromscratch.
    keep = iris.target != 2
    features, labels = iris.data[keep], iris.target[keep]

    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=0.33
    )

    # Fit on the training split, then predict the held-out split.
    classifier = LDA()
    classifier.fit(X_train, y_train)
    predictions = classifier.predict(X_test)

    score = accuracy_score(y_test, predictions)
    print("Accuracy:", score)

    # Plot the test samples with their predicted (not true) labels.
    Plot().plot_in_2d(X_test, predictions, title="LDA", accuracy=score)
33
if __name__ == "__main__":
    # Run the demo only when this file is executed directly, not on import.
    main()