ml-finance-python
Python scripts for machine learning in finance
git clone https://9o.is/git/ml-finance-python.git
bayesian_regression.py
(2472B)
1 import numpy as np
2 import pandas as pd
3 import matplotlib.pyplot as plt
4
5 # Import helper functions
6 from mlfromscratch.utils.data_operation import mean_squared_error
7 from mlfromscratch.utils.data_manipulation import train_test_split, polynomial_features
8 from mlfromscratch.supervised_learning import BayesianRegression
9
def main():
    """Fit a Bayesian polynomial regression to temperature data and plot it.

    Reads the tab-separated file ``mlfromscratch/data/TempLinkoping2016.txt``
    (the path is relative to the current working directory), uses its
    ``time`` column as the single input feature and its ``temp`` column as
    the target, holds out 40% of the samples for testing, fits a
    ``BayesianRegression`` model with ``poly_degree=4`` and prints the mean
    squared error on the held-out samples. Finally shows a scatter plot of
    the training and test samples together with the prediction line and the
    lower/upper credible-interval lines returned by the model.
    """
    # Load temperature data
    data = pd.read_csv('mlfromscratch/data/TempLinkoping2016.txt', sep="\t")

    # Column vectors of shape (n_samples, 1)
    time = np.atleast_2d(data["time"].values).T
    temp = np.atleast_2d(data["temp"].values).T

    X = time  # fraction of the year [0, 1]
    y = temp

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4)

    # Only the feature count is needed to size the prior parameters
    n_features = np.shape(X)[1]

    # Prior parameters
    # - Weights are assumed distributed according to a Normal distribution
    # - The variance of the weights is assumed distributed according to
    #   a scaled inverse chi-squared distribution.
    # High prior uncertainty!
    # Normal
    mu0 = np.array([0] * n_features)
    omega0 = np.diag([.0001] * n_features)
    # Scaled inverse chi-squared
    nu0 = 1
    sigma_sq0 = 100

    # The credible interval (shown as a percentage in the plot legend)
    cred_int = 10

    clf = BayesianRegression(n_draws=2000,
                             poly_degree=4,
                             mu0=mu0,
                             omega0=omega0,
                             nu0=nu0,
                             sigma_sq0=sigma_sq0,
                             cred_int=cred_int)
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    mse = mean_squared_error(y_test, y_pred)

    # Get prediction line and the interval bounds over the full data set
    y_pred_, y_lower_, y_upper_ = clf.predict(X=X, eti=True)

    # Print the mean squared error
    print("Mean Squared Error:", mse)

    # Color map
    cmap = plt.get_cmap('viridis')

    # Plot the results; X is scaled by 366 so the x-axis reads in days
    plt.scatter(366 * X_train, y_train, color=cmap(0.9), s=10)
    plt.scatter(366 * X_test, y_test, color=cmap(0.5), s=10)
    plt.plot(366 * X, y_pred_, color="black", linewidth=2, label="Prediction")
    plt.plot(366 * X, y_lower_, color="gray", linewidth=2,
             label="{0}% Credible Interval".format(cred_int))
    # Upper bound is left unlabeled so the interval appears once in the legend
    plt.plot(366 * X, y_upper_, color="gray", linewidth=2)
    plt.axis((0, 366, -20, 25))
    plt.suptitle("Bayesian Regression")
    plt.title("MSE: %.2f" % mse, fontsize=10)
    plt.xlabel('Day')
    plt.ylabel('Temperature in Celsius')
    # A single legend call is enough: it picks up the labeled lines above
    plt.legend(loc='lower right')

    plt.show()
77
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()