ml-finance-python
Python scripts for machine learning in finance
git clone https://9o.is/git/ml-finance-python.git
elastic_net.py
(1999B)
1 from __future__ import print_function
2 import matplotlib.pyplot as plt
3 import numpy as np
4 import pandas as pd
5 # Import helper functions
6 from mlfromscratch.supervised_learning import ElasticNet
7 from mlfromscratch.utils import k_fold_cross_validation_sets, normalize, mean_squared_error
8 from mlfromscratch.utils import train_test_split, polynomial_features, Plot
9
10
def _plot_training_error(training_errors):
    """Plot the model's recorded training errors against iteration index and show the figure."""
    n = len(training_errors)
    training, = plt.plot(range(n), training_errors, label="Training Error")
    plt.legend(handles=[training])
    plt.title("Error Plot")
    plt.ylabel('Mean Squared Error')
    plt.xlabel('Iterations')
    plt.show()


def _plot_predictions(X, y_pred_line, X_train, y_train, X_test, y_test, mse,
                      days_in_year=366):
    """Scatter train/test temperatures by day of year, overlay the fitted curve, and show it.

    ``X``, ``X_train`` and ``X_test`` hold the fraction-of-year feature; they are
    multiplied by ``days_in_year`` so the x-axis reads in days.
    """
    cmap = plt.get_cmap('viridis')

    m1 = plt.scatter(days_in_year * X_train, y_train, color=cmap(0.9), s=10)
    m2 = plt.scatter(days_in_year * X_test, y_test, color=cmap(0.5), s=10)
    # Keep the line handle so the prediction curve is included in the legend
    # (an explicit handle list otherwise ignores its label).
    prediction, = plt.plot(days_in_year * X, y_pred_line, color='black', linewidth=2,
                           label="Prediction")
    plt.suptitle("Elastic Net")
    plt.title("MSE: %.2f" % mse, fontsize=10)
    plt.xlabel('Day')
    plt.ylabel('Temperature in Celsius')
    plt.legend((m1, m2, prediction), ("Training data", "Test data", "Prediction"),
               loc='lower right')
    plt.show()


def main():
    """Fit an ElasticNet regressor to the 2016 Linkoping temperature data and plot the result.

    Reads ``mlfromscratch/data/TempLinkoping2016.txt`` (tab-separated, with
    ``time`` and ``temp`` columns) relative to the current working directory.
    It holds out 40% of the rows as a test set and trains the model. It then
    shows the training-error curve and prints the test-set mean squared error.
    Finally it plots the fitted curve over the training and test points.
    """
    # Hyper-parameters are defined once so the model and the printed summary
    # cannot drift apart.
    poly_degree = 15
    reg_factor = 0.01
    l1_ratio = 0.7
    learning_rate = 0.001
    n_iterations = 4000

    # The time feature is a fraction of the year in [0, 1]. 2016 is a leap
    # year, so scaling by 366 turns it into a day-of-year axis for plotting.
    days_in_year = 366

    # Load temperature data
    data = pd.read_csv('mlfromscratch/data/TempLinkoping2016.txt', sep="\t")

    # atleast_2d(...).T turns the 1-D column into shape (n_samples, 1).
    X = np.atleast_2d(data["time"].values).T  # fraction of the year [0, 1]
    y = data["temp"].values

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4)

    model = ElasticNet(degree=poly_degree,
                       reg_factor=reg_factor,
                       l1_ratio=l1_ratio,
                       learning_rate=learning_rate,
                       n_iterations=n_iterations)
    model.fit(X_train, y_train)

    _plot_training_error(model.training_errors)

    y_pred = model.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)
    print("Mean squared error: %s (given by reg. factor: %s)" % (mse, reg_factor))

    # Predict over every sample so the fitted curve spans the whole year.
    y_pred_line = model.predict(X)

    _plot_predictions(X, y_pred_line, X_train, y_train, X_test, y_test, mse,
                      days_in_year=days_in_year)
61
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()