ml-finance-python

Python scripts for machine learning in finance

git clone https://9o.is/git/ml-finance-python.git

gradient_boosting_regressor.py

(1849B)


      1 from __future__ import division, print_function
      2 import numpy as np
      3 import pandas as pd
      4 import matplotlib.pyplot as plt
      5 import progressbar
      6 
      7 from mlfromscratch.utils import train_test_split, standardize, to_categorical
      8 from mlfromscratch.utils import mean_squared_error, accuracy_score, Plot
      9 from mlfromscratch.utils.loss_functions import SquareLoss
     10 from mlfromscratch.utils.misc import bar_widgets
     11 from mlfromscratch.supervised_learning import GradientBoostingRegressor
     12 
     13 
     14 def main():
     15     print ("-- Gradient Boosting Regression --")
     16 
     17     # Load temperature data
     18     data = pd.read_csv('mlfromscratch/data/TempLinkoping2016.txt', sep="\t")
     19 
     20     time = np.atleast_2d(data["time"].values).T
     21     temp = np.atleast_2d(data["temp"].values).T
     22 
     23     X = time.reshape((-1, 1))               # Time. Fraction of the year [0, 1]
     24     X = np.insert(X, 0, values=1, axis=1)   # Insert bias term
     25     y = temp[:, 0]                          # Temperature. Reduce to one-dim
     26 
     27     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)
     28 
     29     model = GradientBoostingRegressor()
     30     model.fit(X_train, y_train)
     31     y_pred = model.predict(X_test)
     32 
     33     y_pred_line = model.predict(X)
     34 
     35     # Color map
     36     cmap = plt.get_cmap('viridis')
     37 
     38     mse = mean_squared_error(y_test, y_pred)
     39 
     40     print ("Mean Squared Error:", mse)
     41 
     42     # Plot the results
     43     m1 = plt.scatter(366 * X_train[:, 1], y_train, color=cmap(0.9), s=10)
     44     m2 = plt.scatter(366 * X_test[:, 1], y_test, color=cmap(0.5), s=10)
     45     m3 = plt.scatter(366 * X_test[:, 1], y_pred, color='black', s=10)
     46     plt.suptitle("Regression Tree")
     47     plt.title("MSE: %.2f" % mse, fontsize=10)
     48     plt.xlabel('Day')
     49     plt.ylabel('Temperature in Celcius')
     50     plt.legend((m1, m2, m3), ("Training data", "Test data", "Prediction"), loc='lower right')
     51     plt.show()
     52 
     53 
     54 if __name__ == "__main__":
     55     main()