ml-finance-python

python scripts for finance machine learning

git clone https://9o.is/git/ml-finance-python.git

Untitled.ipynb

(3512B)


      1 {
      2  "cells": [
      3   {
      4    "cell_type": "markdown",
      5    "metadata": {},
      6    "source": [
      7     "# Polynomial Ridge Regression on Linkoping 2016 temperatures\n",
      8     "\n",
      9     "Fit a degree-15 polynomial ridge regression to temperature as a function of the fraction of the year elapsed.\n",
     10     "The regularization factor is chosen by k-fold cross-validation on the training split; the selected model\n",
     11     "is then refit on the whole training split and scored on the held-out test split.\n",
     12     "\n",
     13     "Assumes the notebook runs from the folder that contains the `mlfromscratch` package: the data path below is relative to it and the helpers are imported from it."
     14    ]
     15   },
     16   {
     17    "cell_type": "code",
     18    "execution_count": null,
     19    "metadata": {},
     20    "outputs": [],
     21    "source": [
     22     "import numpy as np\n",
     23     "import pandas as pd\n",
     24     "import matplotlib.pyplot as plt\n",
     25     "\n",
     26     "from mlfromscratch.supervised_learning import PolynomialRidgeRegression\n",
     27     "from mlfromscratch.utils import k_fold_cross_validation_sets, mean_squared_error, train_test_split"
     28    ]
     29   },
     30   {
     31    "cell_type": "code",
     32    "execution_count": null,
     33    "metadata": {},
     34    "outputs": [],
     35    "source": [
     36     "DATA_PATH = 'mlfromscratch/data/TempLinkoping2016.txt'\n",
     37     "SEED = 42\n",
     38     "TEST_SIZE = 0.4\n",
     39     "POLY_DEGREE = 15\n",
     40     "N_FOLDS = 10\n",
     41     "LEARNING_RATE = 0.001\n",
     42     "N_ITERATIONS = 10000\n",
     43     "REG_FACTORS = np.arange(0, 0.1, 0.01)  # candidate ridge penalties searched by cross-validation\n",
     44     "DAYS_IN_YEAR = 366  # 2016 is a leap year; maps the year fraction in X to a day number\n",
     45     "\n",
     46     "# Seed NumPy's global RNG so the split and CV folds are reproducible under Restart & Run All\n",
     47     "# (assumes the mlfromscratch helpers draw from np.random -- confirm if results still vary).\n",
     48     "np.random.seed(SEED)\n",
     49     "\n",
     50     "\n",
     51     "def make_model(reg_factor):\n",
     52     "    \"\"\"Return an unfitted PolynomialRidgeRegression with this notebook's fixed hyperparameters.\"\"\"\n",
     53     "    return PolynomialRidgeRegression(degree=POLY_DEGREE,\n",
     54     "                                     reg_factor=reg_factor,\n",
     55     "                                     learning_rate=LEARNING_RATE,\n",
     56     "                                     n_iterations=N_ITERATIONS)"
     57    ]
     58   },
     59   {
     60    "cell_type": "code",
     61    "execution_count": null,
     62    "metadata": {},
     63    "outputs": [],
     64    "source": [
     65     "# Load temperature data; X is the fraction of the year [0, 1] as an (n_samples, 1) column.\n",
     66     "data = pd.read_csv(DATA_PATH, sep=\"\\t\")\n",
     67     "X = np.atleast_2d(data[\"time\"].values).T\n",
     68     "y = data[\"temp\"].values\n",
     69     "\n",
     70     "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE)"
     71    ]
     72   },
     73   {
     74    "cell_type": "code",
     75    "execution_count": null,
     76    "metadata": {},
     77    "outputs": [],
     78    "source": [
     79     "# Pick the regularization factor with the lowest mean k-fold CV error on the training split.\n",
     80     "lowest_error = float(\"inf\")\n",
     81     "best_reg_factor = None\n",
     82     "print(\"Finding regularization constant using cross validation:\")\n",
     83     "for reg_factor in REG_FACTORS:\n",
     84     "    folds = k_fold_cross_validation_sets(X_train, y_train, k=N_FOLDS)\n",
     85     "    fold_errors = []\n",
     86     "    for X_fold_train, X_fold_val, y_fold_train, y_fold_val in folds:\n",
     87     "        model = make_model(reg_factor)\n",
     88     "        model.fit(X_fold_train, y_fold_train)\n",
     89     "        fold_errors.append(mean_squared_error(y_fold_val, model.predict(X_fold_val)))\n",
     90     "    # Average over the folds actually returned rather than assuming exactly N_FOLDS of them.\n",
     91     "    mse = np.mean(fold_errors)\n",
     92     "    print(\"\\tMean Squared Error: %s (regularization: %s)\" % (mse, reg_factor))\n",
     93     "\n",
     94     "    if mse < lowest_error:\n",
     95     "        best_reg_factor = reg_factor\n",
     96     "        lowest_error = mse\n",
     97     "\n",
     98     "# If no candidate yields a finite error (e.g. every fold diverges to NaN), nothing is selected.\n",
     99     "assert best_reg_factor is not None, \"cross-validation did not produce a finite error\""
    100    ]
    101   },
    102   {
    103    "cell_type": "code",
    104    "execution_count": null,
    105    "metadata": {},
    106    "outputs": [],
    107    "source": [
    108     "# Refit on the full training split with the selected factor and score on the held-out test split.\n",
    109     "model = make_model(best_reg_factor)\n",
    110     "model.fit(X_train, y_train)\n",
    111     "test_mse = mean_squared_error(y_test, model.predict(X_test))\n",
    112     "print(\"Cross-validation MSE: %s (given by reg. factor: %s)\" % (lowest_error, best_reg_factor))\n",
    113     "print(\"Test mean squared error: %s\" % test_mse)"
    114    ]
    115   },
    116   {
    117    "cell_type": "code",
    118    "execution_count": null,
    119    "metadata": {},
    120    "outputs": [],
    121    "source": [
    122     "# Predict on the sorted inputs so the fitted curve is drawn left to right.\n",
    123     "X_line = np.sort(X, axis=0)\n",
    124     "y_pred_line = model.predict(X_line)\n",
    125     "\n",
    126     "cmap = plt.get_cmap('viridis')\n",
    127     "fig, ax = plt.subplots()\n",
    128     "ax.scatter(DAYS_IN_YEAR * X_train, y_train, color=cmap(0.9), s=10, label=\"Training data\")\n",
    129     "ax.scatter(DAYS_IN_YEAR * X_test, y_test, color=cmap(0.5), s=10, label=\"Test data\")\n",
    130     "ax.plot(DAYS_IN_YEAR * X_line, y_pred_line, color='black', linewidth=2, label=\"Prediction\")\n",
    131     "fig.suptitle(\"Polynomial Ridge Regression\")\n",
    132     "ax.set_title(\"Test MSE: %.2f\" % test_mse, fontsize=10)\n",
    133     "ax.set_xlabel('Day')\n",
    134     "ax.set_ylabel('Temperature in Celsius')\n",
    135     "ax.legend(loc='lower right')\n",
    136     "plt.show()"
    137    ]
    138   }
    139  ],
    140  "metadata": {
    141   "kernelspec": {
    142    "display_name": "Python 3",
    143    "language": "python",
    144    "name": "python3"
    145   },
    146   "language_info": {
    147    "codemirror_mode": {
    148     "name": "ipython",
    149     "version": 3
    150    },
    151    "file_extension": ".py",
    152    "mimetype": "text/x-python",
    153    "name": "python",
    154    "nbconvert_exporter": "python",
    155    "pygments_lexer": "ipython3",
    156    "version": "3.8.7"
    157   }
    158  },
    159  "nbformat": 4,
    160  "nbformat_minor": 4
    161 }