ml-finance-python

Python scripts for machine learning in finance

git clone https://9o.is/git/ml-finance-python.git

decision_tree_classifier.py

(956B)


      1 from __future__ import division, print_function
      2 import numpy as np
      3 from sklearn import datasets
      4 import matplotlib.pyplot as plt
      5 import sys
      6 import os
      7 
      8 # Import helper functions
      9 from mlfromscratch.utils import train_test_split, standardize, accuracy_score
     10 from mlfromscratch.utils import mean_squared_error, calculate_variance, Plot
     11 from mlfromscratch.supervised_learning import ClassificationTree
     12 
     13 def main():
     14 
     15     print ("-- Classification Tree --")
     16 
     17     data = datasets.load_iris()
     18     X = data.data
     19     y = data.target
     20 
     21     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4)
     22 
     23     clf = ClassificationTree()
     24     clf.fit(X_train, y_train)
     25     y_pred = clf.predict(X_test)
     26 
     27     accuracy = accuracy_score(y_test, y_pred)
     28 
     29     print ("Accuracy:", accuracy)
     30 
     31     Plot().plot_in_2d(X_test, y_pred, 
     32         title="Decision Tree", 
     33         accuracy=accuracy, 
     34         legend_labels=data.target_names)
     35 
     36 
     37 if __name__ == "__main__":
     38     main()