ml-finance-python
Python scripts for machine learning in finance
git clone https://9o.is/git/ml-finance-python.git
decision_tree_classifier.py
(956B)
1 from __future__ import division, print_function
2 import numpy as np
3 from sklearn import datasets
4 import matplotlib.pyplot as plt
5 import sys
6 import os
7
8 # Import helper functions
9 from mlfromscratch.utils import train_test_split, standardize, accuracy_score
10 from mlfromscratch.utils import mean_squared_error, calculate_variance, Plot
11 from mlfromscratch.supervised_learning import ClassificationTree
12
def main():
    """Run a classification-tree demo on the iris dataset.

    Loads iris through ``sklearn.datasets.load_iris`` and splits it with
    ``train_test_split(..., test_size=0.4)``. It fits a ``ClassificationTree``
    on the training portion, predicts the held-out portion, and prints the
    resulting accuracy. Finally it passes the held-out samples and their
    predicted labels to ``Plot().plot_in_2d`` for visualisation.
    """
    print("-- Classification Tree --")

    iris = datasets.load_iris()
    features, labels = iris.data, iris.target

    # Reserve part of the data (test_size=0.4) for evaluating the tree.
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=0.4)

    tree = ClassificationTree()
    tree.fit(X_train, y_train)
    predictions = tree.predict(X_test)

    acc = accuracy_score(y_test, predictions)
    print("Accuracy:", acc)

    # Visualise the held-out samples coloured by the tree's predictions,
    # using the dataset's class names for the legend.
    plotter = Plot()
    plotter.plot_in_2d(
        X_test,
        predictions,
        title="Decision Tree",
        accuracy=acc,
        legend_labels=iris.target_names,
    )
35
36
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script,
    # not when it is imported as a module.
    main()