ml-finance-python
Python scripts for machine learning in finance
git clone https://9o.is/git/ml-finance-python.git
decision_tree_regressor.py
(1599B)
1 from __future__ import division, print_function
2 import numpy as np
3 import matplotlib.pyplot as plt
4 import pandas as pd
5
6 from mlfromscratch.utils import train_test_split, standardize, accuracy_score
7 from mlfromscratch.utils import mean_squared_error, calculate_variance, Plot
8 from mlfromscratch.supervised_learning import RegressionTree
9
def main(data_path='mlfromscratch/data/TempLinkoping2016.txt', test_size=0.3):
    """Fit a regression tree to daily temperature data and plot the results.

    Reads a tab-separated file with ``time`` and ``temp`` columns, fits a
    ``RegressionTree`` on a random train/test split, prints the test-set
    mean squared error, and shows a scatter plot of the training data, the
    test data and the model's predictions on the test data.

    Args:
        data_path: Path to the tab-separated data file. The default is the
            path this script has always used (relative to the current
            working directory).
        test_size: Fraction of the samples held out for testing; passed
            straight to ``train_test_split``.
    """
    print("-- Regression Tree --")

    # Load temperature data
    data = pd.read_csv(data_path, sep="\t")

    # Assumed (per the original comment and the 366-day scaling used for
    # plotting below): "time" is the fraction of the year in [0, 1].
    time = np.atleast_2d(data["time"].values).T
    temp = np.atleast_2d(data["temp"].values).T

    # Use the raw time as the feature. Standardizing it (as this script used
    # to) broke the ``366 * X`` day axis in the plots below. A threshold-based
    # tree only depends on the ordering of feature values, so, assuming
    # `standardize` is the usual increasing (x - mean) / std rescaling, the
    # fitted splits and predictions are unaffected by skipping it.
    X = time
    y = temp[:, 0]  # Temperature. Reduce to one-dim

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size)

    model = RegressionTree()
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    mse = mean_squared_error(y_test, y_pred)
    print("Mean Squared Error:", mse)

    # Plot the results against the day of the year (fraction * 366).
    cmap = plt.get_cmap('viridis')
    m1 = plt.scatter(366 * X_train, y_train, color=cmap(0.9), s=10)
    m2 = plt.scatter(366 * X_test, y_test, color=cmap(0.5), s=10)
    m3 = plt.scatter(366 * X_test, y_pred, color='black', s=10)
    plt.suptitle("Regression Tree")
    plt.title("MSE: %.2f" % mse, fontsize=10)
    plt.xlabel('Day')
    plt.ylabel('Temperature in Celsius')
    plt.legend((m1, m2, m3), ("Training data", "Test data", "Prediction"),
               loc='lower right')
    plt.show()
49
50
if __name__ == "__main__":
    # Run the demo only when executed as a script, never on import.
    main()