ml-finance-python

Python scripts for machine learning in finance.

git clone https://9o.is/git/ml-finance-python.git

linear_regression.py

(1610B)


      1 import numpy as np
      2 import pandas as pd
      3 import matplotlib.pyplot as plt
      4 from sklearn.datasets import make_regression
      5 
      6 from mlfromscratch.utils import train_test_split, polynomial_features
      7 from mlfromscratch.utils import mean_squared_error, Plot
      8 from mlfromscratch.supervised_learning import LinearRegression
      9 
     10 def main():
     11 
     12     X, y = make_regression(n_samples=100, n_features=1, noise=20)
     13 
     14     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4)
     15 
     16     n_samples, n_features = np.shape(X)
     17 
     18     model = LinearRegression(n_iterations=100)
     19 
     20     model.fit(X_train, y_train)
     21     
     22     # Training error plot
     23     n = len(model.training_errors)
     24     training, = plt.plot(range(n), model.training_errors, label="Training Error")
     25     plt.legend(handles=[training])
     26     plt.title("Error Plot")
     27     plt.ylabel('Mean Squared Error')
     28     plt.xlabel('Iterations')
     29     plt.show()
     30 
     31     y_pred = model.predict(X_test)
     32     mse = mean_squared_error(y_test, y_pred)
     33     print ("Mean squared error: %s" % (mse))
     34 
     35     y_pred_line = model.predict(X)
     36 
     37     # Color map
     38     cmap = plt.get_cmap('viridis')
     39 
     40     # Plot the results
     41     m1 = plt.scatter(366 * X_train, y_train, color=cmap(0.9), s=10)
     42     m2 = plt.scatter(366 * X_test, y_test, color=cmap(0.5), s=10)
     43     plt.plot(366 * X, y_pred_line, color='black', linewidth=2, label="Prediction")
     44     plt.suptitle("Linear Regression")
     45     plt.title("MSE: %.2f" % mse, fontsize=10)
     46     plt.xlabel('Day')
     47     plt.ylabel('Temperature in Celcius')
     48     plt.legend((m1, m2), ("Training data", "Test data"), loc='lower right')
     49     plt.show()
     50 
     51 if __name__ == "__main__":
     52     main()