ml-finance-python
python scripts for finance machine learning
git clone https://9o.is/git/ml-finance-python.git
04_sentiment_analysis.ipynb
(31989B)
1 {
2 "cells": [
3 {
4 "cell_type": "markdown",
5 "metadata": {},
6 "source": [
7 "# LSTM & Word Embeddings for Sentiment Classification"
8 ]
9 },
10 {
11 "cell_type": "markdown",
12 "metadata": {},
13 "source": [
14 "RNNs are commonly applied to various natural language processing tasks. We've already encountered sentiment analysis using text data in part three of [this book](https://www.amazon.com/Hands-Machine-Learning-Algorithmic-Trading-ebook/dp/B07JLFH7C5/ref=sr_1_2?ie=UTF8&qid=1548455634&sr=8-2&keywords=machine+learning+algorithmic+trading).\n",
15 "\n",
16 "We are now going to illustrate how to apply an RNN model to text data to detect positive or negative sentiment (which can easily be extended to a finer-grained sentiment scale). We are going to use word embeddings to represent the tokens in the documents. We covered word embeddings in Chapter 15, Word Embeddings. They are an excellent technique to convert text into a continuous vector representation such that the relative location of words in the latent space encodes useful semantic aspects based on the words' usage in context.\n",
17 "\n",
18 "We saw in the previous RNN example that Keras has a built-in embedding layer that allows us to train vector representations specific to the task at hand. Alternatively, we can use pretrained vectors."
19 ]
20 },
21 {
22 "cell_type": "markdown",
23 "metadata": {},
24 "source": [
25 "## Imports & Settings"
26 ]
27 },
28 {
29 "cell_type": "code",
30 "execution_count": 1,
31 "metadata": {},
32 "outputs": [
33 {
34 "name": "stderr",
35 "output_type": "stream",
36 "text": [
37 "Using TensorFlow backend.\n"
38 ]
39 }
40 ],
41 "source": [
42 "%matplotlib inline\n",
43 "import numpy as np\n",
44 "import pandas as pd\n",
45 "import matplotlib.pyplot as plt\n",
46 "import seaborn as sns\n",
47 "from datetime import datetime, date\n",
48 "from sklearn.metrics import mean_squared_error, roc_auc_score\n",
49 "from sklearn.preprocessing import minmax_scale\n",
50 "from keras.callbacks import ModelCheckpoint, EarlyStopping\n",
51 "from keras.datasets import imdb\n",
52 "from keras.models import Sequential, Model\n",
53 "from keras.layers import Dense, LSTM, GRU, Input, concatenate, Embedding, Reshape\n",
54 "from keras.preprocessing.sequence import pad_sequences\n",
55 "import keras\n",
56 "import keras.backend as K\n",
57 "import tensorflow as tf"
58 ]
59 },
60 {
61 "cell_type": "code",
62 "execution_count": 2,
63 "metadata": {},
64 "outputs": [],
65 "source": [
66 "sns.set_style('whitegrid')\n",
67 "np.random.seed(42)\n",
68 "K.clear_session()"
69 ]
70 },
71 {
72 "cell_type": "markdown",
73 "metadata": {},
74 "source": [
75 "## Load Reviews"
76 ]
77 },
78 {
79 "cell_type": "markdown",
80 "metadata": {},
81 "source": [
82 "To keep the data manageable, we will illustrate this use case with the IMDB reviews dataset, which contains 50,000 positive and negative movie reviews evenly split into a train and a test set, and with balanced labels in each dataset. The vocabulary consists of 88,586 tokens.\n",
83 "\n",
84 "The dataset is bundled into Keras and can be loaded so that each review is represented as an integer-encoded sequence. We can limit the vocabulary to num_words while filtering out frequent and likely less informative words using skip_top, as well as sentences longer than maxlen. We can also choose oov_char, which represents tokens we chose to exclude from the vocabulary on frequency grounds, as follows:"
85 ]
86 },
87 {
88 "cell_type": "code",
89 "execution_count": 3,
90 "metadata": {},
91 "outputs": [],
92 "source": [
93 "vocab_size = 20000"
94 ]
95 },
96 {
97 "cell_type": "code",
98 "execution_count": 4,
99 "metadata": {},
100 "outputs": [],
101 "source": [
102 "(X_train, y_train), (X_test, y_test) = imdb.load_data(seed=42, \n",
103 " skip_top=0,\n",
104 " maxlen=None, \n",
105 " oov_char=2, \n",
106 " index_from=3\n",
107 " num_words=vocab_size)"
108 ]
109 },
110 {
111 "cell_type": "code",
112 "execution_count": 5,
113 "metadata": {
114 "scrolled": true
115 },
116 "outputs": [
117 {
118 "name": "stderr",
119 "output_type": "stream",
120 "text": [
121 "/usr/local/lib/python3.5/dist-packages/scipy/stats/stats.py:1713: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.\n",
122 " return np.add.reduce(sorted[indexer] * weights, axis=axis) / sumval\n"
123 ]
124 },
125 {
126 "data": {
127 "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAEACAYAAACznAEdAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAIABJREFUeJzt3Xt4U2W+L/DvStKkLU3T9JaUNo1cClRaEORSUemYGivUyyBFZxxHQRlnGEdx6+BxLrvPhrN1ztnDWGGPZwansx0d54qKF+JYsUVbgVGBYgXKpWhpCm2gbXqhl6RJ1vmDUuwUmrYkXbl8P8/jI0netfJbNOTbtd71vq8giqIIIiIKezKpCyAiosDAQCAiIgAMBCIi6sdAICIiAAwEIiLqx0AgIiIADAQiIurHQCAiIgAMBCIi6sdAICIiAIBC6gJG48CBA1CpVMO2cTgcXttQaONngPgZuMjhcOCaa64ZUdugCgSVSoXMzMxh29TU1HhtQ6GNnwHiZ+CimpqaEbflJSMiIgLAQCAion4MBCIiAsBAICKifgwEIiICwEAgIqJ+DAQiIgLAQCAion5BNTCN6Eq0dzvR6XD5bH9qlQKaaKXP9kckNQYChY1OhwsVx5p9tr/F0xIZCBRSeMmIiIgAMBCIiKgfA4GIiAAwEIiIqB8DgYiIADAQiIioHwOBiIgAMBCIiKgfA4GIiAAwEIiIqB8DgYiIADAQiIioHwOBiIgAMBCIiKgfA4GIiAAwEIiIqB8DgYiIADAQiIioHwOBiIgAMBCIiKjfiAKhoqIC+fn5MJvNePHFF4e87nQ68fjjj8NsNmPFihVoaGgYeG3Lli0wm83Iz89HZWXlwPMmkwm333477rzzTtx1110+OBQiIroSCm8N3G43NmzYgJdeegk6nQ6FhYUwmUyYOnXqQJutW7ciNjYWO3bsgMViwcaNG/H888+jtrYWFosFFosFNpsNq1atQmlpKeRyOQDg5ZdfRnx8vP+OjoiIRszrGUJ1dTWMRiMMBgOUSiUKCgpQVlY2qE15eTmWLVsGAMjPz8eePXsgiiLKyspQUFAApVIJg8EAo9GI6upq/xwJERFdEa9nCDabDXq9fuCxTqcb8qVus9mQkpJyfocKBdRqNex2O2w2G2bPnj1oW5vNNvD4oYcegiAIuOeee3DPPfd4LdbhcKCmpmbYNr29vV7bUGi73GfAqYhBY1Ojz96nJV5AZ9NJn+2PfIffA2PjNRD85S9/+Qt0Oh1aWlqwatUqTJ48GfPnzx92G5VKhczMzGHb1NTUeG1Doe1yn4EGezdS9KLP3ichMQFpWoPP9ke+w++Bi0YTjF4vGel0OjQ1NQ08ttls0Ol0Q9o0Np7/zcvlcqGzsxNarXbYbS/8PyEhAWazmZeSKGC5PSJO2Xvwzy9b8Mb+Bpxs6ZK6JCK/8HqGkJ2djbq6OlitVuh0OlgsFvzqV78a1MZkMmHbtm2YM2cOSktLkZOTA0EQYDKZ8OSTT2LVqlWw2Wyoq6vDrFmz0N3dDY/Hg5iYGHR3d2PXrl344Q9/6LeDJBopURTR3tMHq70H1tZuWO3dOGXvgctz8cyi1+WBMWGChFUS+YfXQFAoFCgqKsLq1avhdruxfPlyZGRkYNOmTcjKykJeXh4KCwuxbt06mM1maDQaFBcXAwAyMjKwZMkSLF26FHK5HEVFRZDL5WhpacEjjzwC4PxdTLfddhsWL17s3yMluozO3j5U1bfBau+GtbUbHb0uAIBCJmBiXBQWToqHIT4ahvholB5qwsmWbokrJvIPQRRF311U9bORXBfktUMarg+h4ljzkOe3VZ3CZ3WtiJ+gRHp8NNK0UUiPj4ZeEwmFbPBV1d0nmrG9uhH/69YZuH12CtK00X47Dho7fg9cNJq/C8k6lYkCRX1rF6bpYrBy0SSvbdPjo/u34VkChR5OXUFhrbfPjTMdDhjiR/ab/vmzBgH17FimEMRAoLDWYO+BCCB9hJd+FDIZUrVRPEOgkMRAoLBmtZ//Yh9NX0B6fDROt/fC6fL4qywiSTAQKKxZW7uRpFYhSikf8Tbp8dFwe0Qcs3X6sTKi8cdAoLAliiLqW7tHfLnoggsdywdPtfujLCLJMBAobLV2OdHtdI+4Q/kCdWQEtNEROHi6w0+VEUmDgUBh60L/gSE+atTbGuKjcegUA4FCCwOBwlZ9aw+Uchl0sZGj3jY9Phpnzzlwuq3HD5URSYOBQGHL2tqNNG0UZIIw6m0v9CPsr7f7uiwiyTAQKCz1uT1obO8Zdf/BBSmaKKgUMuw/2ebjyoikw0CgsHS6rQce8eJv+qMllwmYoVfzDIFCCgOBwtKFkcZp2tF3KF8wM1WDQ6fb0dvn9lVZRJJiIFBYsrZ2QxsdAXVkxJj3kTUxFn1uEYdOczwChQYGAoUlq33s/QcXZKVqAID9CBQyGAgUdtp7+tDe0zfm/oML4icoYYiPYj8ChQwGAoUda3//gcEHi9vMTddif70dQbTOFNFlMRAo7Fhbu6GQCUiJG/2AtH81N10LW4cDp9t7fVAZkbQYCBR26u3dmBgXNWR5zLGYm64FAOw/yctGFPwYCBRW3B4Rp+w9MFzB7aZfNyNFjcgIGfsRKCQwECisNLX3wuURr/gOowsi5DLMSovD/nreaUTBj4FAYaW+f4bTK73D6OvmpmtxmAPUKAQwECisWFu7oY5UQBM19gFp/2puehz63CIXzKGgx0CgsGJt7YZBGw1hDDOcXs5cY3/HMvsRKMgxEChstHU70dLl9OnlIgBIjFEhPT6aI5Yp6DEQKGwcbjy/wpmvOpS/bm56HAeoUdBjIFDYOHSqAzIBSI3zzS2nXzfXqMWZTgdOcQU1CmIjCoSKigrk5+fDbDbjxRdfHPK60+nE448/DrPZjBUrVqChoWHgtS1btsBsNiM/Px+VlZWDtnO73fjmN7+J73//+1d4GETeHTrdAX1sJJQK3/8eNDBAjbefUhDz+i/D7XZjw4YNKCkpgcViwfbt21FbWzuozdatWxEbG4sdO3Zg5cqV2LhxIwCgtrYWFosFFosFJSUlWL9+Pdzui7fmvfLKK5gyZYqPD4loKLdHxOHGDr9cLgKAGXo1oiLkHLFMQc1rIFRXV8NoNMJgMECpVKKgoABlZWWD2pSXl2PZsmUAgPz8fOzZsweiKKKsrAwFBQVQKpUwGAwwGo2orq4GADQ1NeHDDz9EYWGhHw6LaLATZ8+h2+n2WyAo5DLMStOgincaURDzGgg2mw16vX7gsU6ng81mG9ImJSUFAKBQKKBWq2G324fd9tlnn8W6desg88F8MkTeXPiiTvfBDKeXM9eoxaHTHRygRkFLIcWb7ty5E/Hx8cjKysInn3wy4u0cDgdqamqGbdPb2+u1DYW2S30GdlafRYxSDue5FjR2+WYMQku8gM6mkwOPk4UuuDwi3tlVjSzdlc+kSmPH74Gx8RoIOp0OTU1NA49tNht0Ot2QNo2NjdDr9XC5XOjs7IRWq73stuXl5SgvL0dFRQUcDgfOnTuHH//4xwN9D5ejUqmQmZk5bJuamhqvbSi0Xeoz8NV7Z5GVpsHElIk+e5+ExASkaQ0Dj5MMDqzfaUOrEIvMTPaNSYnfAxeNJhi9Xq/Jzs5GXV0drFYrnE4nLBYLTCbToDYmkwnbtm0DAJSWliInJweCIMBkMsFiscDpdMJqtaKurg6zZs3Ck08+iYqKCpSXl+O5555DTk6O1zAgGqvO3j4cO9OJmRNj/fo+iTEqGBOiOWKZgpbXMwSFQoGioiKsXr0abrcby5cvR0ZGBjZt2oSsrCzk5eWhsLAQ69atg9lshkajQXFxMQAgIyMDS5YswdKlSyGXy1FUVAS5XO73gyL6uuqGdogiMHOiBr19Hr++19x0LT6ubYYoij6dHoNoPIyoDyE3Nxe5ubmDnlu7du3An1UqFTZv3nzJbdesWYM1a9Zcdt8LFy7EwoULR1IG0Zhc6FDOTFGjqt6/E9DNTY/DtqpTaLD3+O2OJiJ/4S0+FPKq6tswNTkG6kjfzXB6OXPSOdEdBS8GAoU0URRRZW3DHEPcuLzfhQFqVRyxTEGIgUAhrb61G61dzoHf3P1NIZchO1WDA1YGAgUfBgKFtAu/qc9JH58zBAC4Jj0Oh093wOHiADUKLgwECmlV9XZEK+WYplOP23vOMcTB6fagprFz3N6TyBcYCBTSqqxtmJ0WB7ls/G4Bvab/bOQAO5YpyDAQKGT19rlx+HTHuF4uAoAUTRR0sSr2I1DQYSBQyDp4qh0ujzhuHcpfd40hjoFAQYeBQCHrQofyNeN0y+nXXWPQoq6lG/Yu57i/N9FYMRAoZFVZ7TDERyFJrRr3974QQjxLoGDCQKCQVVXfhjmG8b9cBACz0jSQCec7tYmCBQOBQlJjew8a23vHvUP5ggkqBabp1DxDoKAiyQI5RP52YGBAmv/OEFxuDxrs3Zd9PSM5Bh8eOwtra9eQmU/VKgU00Uq/1UY0FgwECklV1jYoFTJcneK/NRB6+jyoOtF62dcVchk6e13Ytv80Ev+lH2PxtEQGAgUcXjKikFRVb0fWxFgoFdJ9xA396zdbhzmLIAokDAQKOS6PiOqGdknGH3xdcqwKSoWMgUBBg4FAIecruxMOl0eyDuULZIKA1LgoWFt7JK2DaKQYCBRyjpztBeDfDuWRMmij0dTeiz63f5fuJPIFBgKFnCNnHUhWqzBREyl1KTDER8Etimhs41kCBT4GAoWcI2fPjz8IhEXuL3YsMxAo8DEQKKS0djlxutMVEJeLACA2KgKaqAh2LFNQYCBQSDlgPb8GwXitoTwSadooWFsZCBT4GAgUUqrq2yATgOw0jdSlDDBoo2Hv7sM5h0vqUoiGxUCgkFJx7CymxqsQrQycQfiG+P5+BJ4lUIBjIFDIqD3Tic8b2vGNyTFSlzJIalwUZAJHLFPgYyBQyHht3ynIZQK+MWmC1KUMolTIoIuNRAMHqFGAYyBQSHB7RGyrasA3piVBGxU4l4suMGijYbV3wyOKUpdCdFkjCoSKigrk5+fDbDbjxRdfHPK60+nE448/DrPZjBUrVqChoWHgtS1btsBsNiM/Px+VlZUAAIfDgcLCQtxxxx0oKCjA5s2bfXQ4FK521TbD1uHA8mvTpC7lkgzxUXC4PGjudEhdCtFleQ0Et9uNDRs2oKSkBBaLBdu3b0dtbe2gNlu3bkVsbCx27NiBlStXYuPGjQCA2tpaWCwWWCwWlJSUYP369XC73VAqlXj55Zfx9ttv480330RlZSUOHDjgnyOksPD6/gZooiKQl5ksdSmXlMYBahQEvAZCdXU1jEYjDAYDlEolCgoKUFZWNqhNeXk5li1bBgDIz8/Hnj17IIoiysrKUFBQAKVSCYPBAKPRiOrqagiCgAkTzl/ndblccLlcATGqlIJTZ28fSg814fbZKVAp5FKXc0lJahVUnPmUApzXQLDZbNDr9QOPdTodbDbbkDYpKSkAAIVCAbVaDbvdPuy2brcbd955JxYtWoRFixZh9uzZPjkgCj/vftGI3j4Pls8NzMtFwPmZT9O0UWjgracUwCTrfZPL5XjrrbfQ0dGBRx55BMeOHcO0adOG3cbhcKCmpmbYNr29vV7bUGj5Y+VppMVGQHWuETU1TZf9DDgVMWhsavTZ+85IVI5qf1qlB1+e7UX9qdNoiRfQ2XTSZ7XQYPweGBuvgaDT6dDU1DTw2GazQafTDWnT2NgIvV4Pl8uFzs5OaLXaEW0bGxuLhQsXorKy0msgqFQqZGZmDtumpqbGaxsKHfUt3Th45kusy5+Oq6+eCuDyn4EGezdS9L67yycqOhop+pQRt88UO7D31El4VBokJCYgTWvwWS00GL8HLhpNMHq9ZJSdnY26ujpYrVY4nU5YLBaYTKZBbUwmE7Zt2wYAKC0tRU5ODgRBgMlkgsVigdPphNVqRV1dHWbNmoXW1lZ0dHQAOJ/ku3fvxuTJk0dzjEQAzncmCwJw19xUqUvxKk0bBYAD1ChweT1DUCgUKCoqwurVq+F2u7F8+XJkZGRg06ZNyMrKQl5eHgoLC7Fu3TqYzWZoNBoUFxcDADIyMrBkyRIsXboUcrkcRUVFkMvlOHPmDJ5++mm43W6Ioohbb70VN910k98PlkKLxyPijaoGXD8lESmaKKnL8UodGYG46AjeaUQBa0R9CLm5ucjNzR303Nq1awf+rFKpLjuWYM2aNVizZs2g52bMmIE333xztLUSDfJZXSusrT14wjz8pcZAYtBGc04jClgcqUxB6/X9DZiglCN/pt574wBh0EahraePA9QoIDEQKCj1ON1494smLMlOCaiZTb2ZnHR+4r1P61olroRoKAYCBaX3DzfhnMMV0GMPLiVFEwl1pAK7T7RIXQrREAwECkqv7WtAalwUFk6Kl7qUUREEAdN1anz2VSv63B6pyyEahIFAQaepvRe7apuxfG4qZLLgm/Jkul6NLqcbe+vsUpdCNAgDgYLOtqpT8IjAXUF2ueiCqUkxUMgEfHj0jNSlEA3CQKCgIooiXt/fgHlGLa5KDKyFcEZKFSHHbEMcyo8wECiwMBAoqFQ3tKP2zLmAXfdgpBZNScDxM+c4JoECCgOBgsrr+xugUshQMGvkcwgFousmJwAALxtRQGEgUNBwuNx4+/PTuGWmHrGREVKXc0UM8VEwJkTzshEFFAYCBY2dR86grbsPy4NgIjtvBEHATdOTsftEC3r73FKXQwSAgUBB5LV9p5CkVuGGqYlSl+ITN81IhsPlwR4OUqMAwUCgoNByzoEPj57BsjmpUMhD42O7cFI8oiLkvGxEASM0/mVRyHvrwGm4PGLQTVUxnMgIOa6fmoCdR89AFH23cA/RWDEQKCi8vr8BWamxmK5XS12KT900IxkN9h7UnjkndSlEDAQKfEeaOnDodEdInR1c8I3pyQCAnbz9lAIAA4EC3hv7T0EhE3DH7IlSl+JzqXFRmKFXsx+BAgIDgQKay+3BtqpTuGlGMhJiVFKX4xc3zUjG3jo7Onr7pC6FwhwDgQJaZW0zznY6QvJy0QU3TU+GyyPi4+PNUpdCYY6BQAHt9X0N0EZHwDQjWepS/GZuehw0URG8bESSYyBQwGrv6cP7h224Y/ZEKBWh+1FVyGVYPC0JHx49C4+Ht5+SdEL3XxkFPUt1I5wuT9DPbDoSN01PQvM5Bw6ebpe6FApjDAQKWK/vb8DU5Bhkp2qkLsXvcqclQRCAnUfOSl0KhTEGAgWkr5q7sO+kHcvnpkEQgm+ZzNFKiFFhdlocyo/YpC6FwhgDgQLSG/sbIBOAZXOCf2bTkbo1S4/PG9rx5VmOWiZpMBAo4Hg8It7YfwrXT02EXhMpdTnj5q45qZDLBPx9b4PUpVCYUkhdANG/+udXLTjV1oOnbp0udSl+43J70GAfunxmzuR4bN1rxbcWpEEhu/j7mlqlgCZaOZ4lUhgaUSBUVFTgmWeegcfjwYoVK/Dwww8Pet3pdOKpp57CoUOHEBcXh+LiYqSlnb8zZMuWLXjttdcgk8nw85//HDfeeCMaGxvx1FNPoaWlBYIg4O6778YDDzzg+6OjoPT6vlOIUSlwy9V6qUvxm54+D6pOtA55flJCDHbVtuD3lXXITIkdeH7xtEQGAvmd10tGbrcbGzZsQElJCSwWC7Zv347a2tpBbbZu3YrY2Fjs2LEDK1euxMaNGwEAtbW1sFgssFgsKCkpwfr16+F2uyGXy/H000/j3Xffxd/+9jf8+c9/HrJPCk9dDhf+cbARBdkpiFLKpS5n3E3XqxGjUmDvSbvUpVAY8hoI1dXVMBqNMBgMUCqVKCgoQFlZ2aA25eXlWLZsGQAgPz8fe/bsgSiKKCsrQ0FBAZRKJQwGA4xGI6qrq5GcnIyZM2cCAGJiYjB58mTYbLy7goD3Djah2+kOi7EHlyKXCZiTHoejTR3o5NxGNM68BoLNZoNef/HUXafTDfnyttlsSElJAQAoFAqo1WrY7fYRbdvQ0ICamhrMnj37ig6EQsMbVQ1Ij4/G/Ku0UpcimWuNWnhEoKq+TepSKMxI2qnc1dWFxx57DD/96U8RExPjtb3D4UBNTc2wbXp7e722ofElU01Ar9v7WIKz55zYXduCb83Robp27HfauKDC58etQ54X5BFobGoc837/1YxE5Zj3523bFLUCn3x5FlPVfRAEAS3xAjqbTo611LDD74Gx8RoIOp0OTU1NA49tNht0Ot2QNo2NjdDr9XC5XOjs7IRWqx12276+Pjz22GO4/fbbccstt4yoWJVKhczMzGHb1NTUeG1D46vB3o19x7zP5PnhUTtEAKlJcTjUOvY5fRqbziJFnzLk+Tnp0Zd8fqyiose+P2/bXterwhtVp9Cn1MCYMAEJiQlI0xrGWmrY4ffARaMJRq+XjLKzs1FXVwer1Qqn0wmLxQKTyTSojclkwrZt2wAApaWlyMnJgSAIMJlMsFgscDqdsFqtqKurw6xZsyCKIn72s59h8uTJWLVq1SgPj0JRe08f9nzZgqsSohE/gXfTZKdqoJTLsI+dyzSOvJ4hKBQKFBUVYfXq1XC73Vi+fDkyMjKwadMmZGVlIS8vD4WFhVi3bh3MZjM0Gg2Ki4sBABkZGViyZAmWLl0KuVyOoqIiyOVy7N27F2+99RamTZuGO++8EwDwxBNPIDc3179HSwGpx+nGH3Z/BafLg9tDcFW0sVBFyJGdpkH1qXYUzPLdWQ3RcEbUh5Cbmzvky3rt2rUDf1apVNi8efMlt12zZg3WrFkz6Ll58+bh6NGjo62VQlCf24NXPzmJ5k4nHlh0FVI0UVKXFDDmGbXYd9KOg6faYb5a530DoivEqStIMh5RxNZ9DfiquQuF16ZharL3GwvCSXp8NBJjVNhbx8tGND4YCCQJURTx7heNOHiqHUuy9JhtiJO6pIAjCALmGbU42dqNky1dUpdDYYCBQJL4uLYZu0+04PopCbhhaqLU5QSsOelxkAmA5Ysm742JrhADgcbdAWsb/nGwCdmpGizJTgmL9Q7GSh0Zgek6Nf7xRSMcLrfU5VCIYyDQuKo9cw6v72vApMQJWHFtGmQMA69ypiTA3t2Htw6clroUCnEMBBo3p9t68KdPTiJRrcR9C41QyPnxG4mpSTGYkjQBv6/8CqI49gF7RN7wXySNC3uXEy/vrkNkhBwrF00Ky5lMx0oQBNwz34Cjtk5UHvc+4ptorBgI5HftPX14aXcd+jwerFx0FTRREVKXFHRuztQhSa3C7yq/lLoUCmEMBPKr3j43nn69Gm3dTnw35yroYsNnSUxfUipkWLnoKlQeb8aRpg6py6EQxUAgv3F7RDz6lyocPNWBu+cZMClxgtQlBbXvLExHVIQcJZVfSV0KhSgGAvmFKIooeusgdhy24fGbM5CVqpG6pKAXF63EinlpeOvAKZzp6JW6HApBDATyixd21uJPn9TjB7lTwnb1M3948PpJcHlEvLKHayOQ7zEQyOe27rVi4/vHsGxOKp7Kny51OSHlqsQJMGfq8OonJ9HtdEldDoUYBgL51M6jZ/D0G1/gxoxE/N/lsyCTceCZr31v8WS0dffh9X1jX1WO6FIYCOQzn1vb8MNX92OGXo3f3HctlAp+vPxhnlGL2YY4/P7jr+D2cKAa+Q7/xZJPnGzpwoN/+AwJMUq8tGo+YlSSLtcd0gRBwPdunIS6lm68+4Xv1ogmYiDQFWs+58D9//MpPKKIlx9cgGQ1xxr4260z9bg6JRbPWGpwzsG+BPINBgJdkS6HCw/+4TPYOnrx+5XzMSWJi9z4g8vtQYO9e+C/po5erL15Kmwdvfjf7xwaeL692yl1qRTEeF5PY9bn9uCRP+/HwVPt2PLdeZibrpW6pJDV0+dB1YnWIc/PnxSPv+9tQJI6EhPjorB4WiI00UoJKqRQwDMEGhNRFPGzbV/gw6Nn8Z/fzOaavxLJv1qPaJUCbx04BQ9nQqUrxECgMSnecQx/39uAx/IycO/CdKnLCVtRSjmWZulhtffgs7qhZxBEo8FAoFH70ycnsbm8FnfPS8O/3ZwhdTlh7xpDHCYnTkDpoSa0drEPgcaOgUCj8v6hJvz7mwdx0/QkPLMsm8tfBgBBEHDHNRPR5xLxws5aqcuhIMZAoBHbd9KOR/9ShexUDV74zlxEcMWzgJGsjsTiaYkoPWTDR8fOSl0OBSn+i6YRqT1zDg+9/BlSNJH4/cr5iFbyBrVA843pyZiUOAFP/v0AznRyNlQaPQYCDavP7UFJ5ZdY9sIuKGQCXn5wARJjVFKXRZcQIZdhw50zcc7hwr/97QCntaBRYyDQZe2qbcbSTZX4T0sN5hq1eO0Hi2BM4CI3gWxS4gSsv2MmdtW24Dcfsj+BRofn/TTEqbYePGM5jHe/aIIhPgq/u38ebs5MZgdykLh7ngG7alvw3I5jWDApAQsmxUtdEgWJEZ0hVFRUID8/H2azGS+++OKQ151OJx5//HGYzWasWLECDQ0Xp+XdsmULzGYz8vPzUVlZOfD8T37yE1x33XW47bbbfHAY5Au9fW78uvw48n71IcqPnMET5mnY8W+5MF+tYxgEEUEQ8MyyLBjio7H2r1Ww81ZUGiGvgeB2u7FhwwaUlJTAYrFg+/btqK0dfCq6detWxMbGYseOHVi5ciU2btwIAKitrYXFYoHFYkFJSQnWr18Pt9sNALjrrrtQUlLih0OisSirseGW4gpsfP8YbpqejA+eyMVjeRmIjJBLXRqNwoU5j9p7+vDvt12Ns50OPPLn/ahv7eJcR+SV10Corq6G0WiEwWCAUqlEQUEBysrKBrUpLy/HsmXLAAD5+fnYs2cPRFFEWVkZCgoKoFQqYTAYYDQaUV1dDQCYP38+NBqusyu1r5q7sOqlT/HQy3sRIRfw6kML8Zv7rkWaNlrq0mgMevo8qDjWjIpjzTjT4cCtWXrsPtGCZy1HUHGsGZ2cGZWG4bUPwWazQa/XDzzW6XQDX+pfb5OSknJ+hwoF1Go17HY7bDYbZs+ePWhbm83mq9rpCnQ7XXhhZy1+V/EVlAoZfrY0Ew8suoqL2oSY6yYn4KvmLrx/uAnGhGgAiVKXRAEsqDqVHQ4Hampqhm3T29uJmYN+AAANuElEQVTrtU04E0URlSe78LvPWtDc7Ube5Bg8eG084qMdOHH8qF/e06mIQWPT+C3k4urru+T7zUhU+rSOK9nfaLcdaftLtbs+VYGGFhle/Wcdrk+LRGdT6I9R4PfA2HgNBJ1Oh6ampoHHNpsNOp1uSJvGxkbo9Xq4XC50dnZCq9WOaNvRUKlUyMzMHLZNTU2N1zbh6mhTJ/7j7UPY82ULrk6JxW/vn4l5V/n/DpQGezdS9ON3T3xjUyNS9ClDno+Kjr7k82N1Jfsb7bYjbX+5dt+NicdvPzqBX+9uxF8fvi7k17rm98BFowlGr9cHsrOzUVdXB6vVCqfTCYvFApPJNKiNyWTCtm3bAAClpaXIycmBIAgwmUywWCxwOp2wWq2oq6vDrFmzRnk4dKU6evuw4Z3DWLq5EjVNHfjPb2bhnUdvGJcwoMAwMS4Kt82aiE/r7JzviC7L6xmCQqFAUVERVq9eDbfbjeXLlyMjIwObNm1CVlYW8vLyUFhYiHXr1sFsNkOj0aC4uBgAkJGRgSVLlmDp0qWQy+UoKiqCXH7+rpUnnngCn376Kex2OxYvXoxHH30UK1as8O/RhhmPR8Rr+xvwX+8dQUuXE/cuSMePb5kO7QQuoBKO5l+lRbfTheIPjmH+pHjkTE6QuiQKMCPqQ8jNzUVubu6g59auXTvwZ5VKhc2bN19y2zVr1mDNmjVDnn/uuedGUyeNUnVDG4reOoQD1jbMTY/DH1YtQFYq7+oKZ4Ig4Mf503D8zDms/WsV/rF2MeL5ywF9DW8pCTGtXU785I1q3PnCLjTYe/CrFbPx2g8WMQwIABCtVOC/vz0H9q4+/Hjr5xC5yhp9TVDdZUSX53J78OdP6/Gr94/hnMOFh66fhMduzkBsZITUpVGAyUrV4KdLZ+A/3jmM/9lVh4dumCR1SRQgwiYQ2rudITso5/OGNhTvOI7aM+dwrVGLx2/OwKTECejo6UNHT5/U5cHR55a6BOp3YSRzXmYyyo6cwS/erYExPgozUmIBAGqVAppoXkYKV2ETCJ0OFyqONUtdhk919PThvUNNOGBtgyYqAt9ekI6sibGwtvbA2tojdXkD5qTHSV0C9evp86DqxPm1l3OnJaG6oR3rXqvGD3KnQB0ZgcXTEhkIYSxsAiGUuDwe7K5tQfnRM/B4RNw0PQm505I5yphGJVqpwL0L0lHy8Zf4w+46rL5hstQlkcQYCEHmuK0T71Q3ovmcAzP0ahRkpyCBC9bQGBnio3HfQiNe2XMSL++pw+JpnNoinPFXyiBh73Li1X+exEu76yCKIh64zoj7r7uKYUBXLEOnxj3zDbC2duOn276Aw8U+n3DFM4QA1+f2oOLYWXx07CwEAbjlah1umJoIBRe4Jx/KStXgrrlpeH1/A773yj78+t45vEMtDDEQApQoiqhp7IDli0bYu/uQnarBkiw94tjhR35yrVGLjOQJ2Pj+Mdz1/3bj9w/M45KpYYa/Zgags50O/GF3HV79pB5KhQwP3TAJ316QzjAgv7tt9kS88tACNJ9z4M4XduGfX7ZIXRKNIwZCAHH0ufHewUZsLjuO+tZu3DYrBT+6KQNTkmKkLo3ChMvtQXp8NH5731xoIiPwnZJP8Nz7R9HW5ZC6NBoHvGQUAERRxOcN7XjvYCM6el24Nl2LW2bqoOY1XBpnXx+ncP91V+G1fVZsLq/FF6fa8d/3zkWMil8ZoYw/XYk1tvfgnc9Po66lG6lxUfjOQiMM8Vy+kqQXpZTjvhwjKo834/3DTbjj1x9jy33XIkOnlro08hNeMpJIj9ONtz8/jV+X1+JMpwPL5qRizTemMAwooAiCgMXTkrDpW9ego8eFb76wC+8dHL/V72h88QxhnHlEEftO2lF6qAk9TjcWTo7HzZk6RCv5o6DAlZ2qwe/uvxY/23YQP3h1P76bk47VN05GXFQEp7oIIUH1LXSksQMPvVk2bJs+Vx8iFKeHPO/yiHC4PP4qbcRcbg+6nG4YE6Jxx+yJSNFESV0SkVc9fR7UNHbiW/MNeKf6NP74z3rs+bIVz98zm4EQQoIqECZEKnD91OGH1re1tyFOM3QytS6nC7b2wLhTYmpyDGalaSAIob2uLYUehVyGZXPSkKaNxjufn8aDf9iLF++fh2sMnMAwFARVIBi00fjlouEXzr7c4toN9u6Qm+2USCrzr4rHRE0U3qhqwIrf7sZjpgz84BtTEMER9EGNPz0iGpNUbRS23DcXN0xNxK92HMOtz1dgx+EmtHc7pS6NxoiBQERjFqGQw3y1HvctNOJspwPfe2Uf1vxpPz79qpXLcwahoLpkRESB6eqJsZicNAG7TzRjb50dd2/Zg+xUDb61wIA7Zk/kIMsgwTMEIvKJyAg5TDN0+OvDC/HkLdPQ7XThZ9sOYsGzZXhm+2G0d0u/nCsNj4FARD4lQkDCBBUevH4S1uROwdSkGPzu469ww3+V44WdtejlGtsBi4FARH4hCAIM8dH49oJ0/P6BazE7LQ6/LD2K3F/uxEsff4mOXp4xBBr2IRCR302Mi8bS7BTM0Kvx7heNWL+9Br947ygWZyQhZ3I8FkyKx8yJGshlHJsjJQYCEY2byUkxeOSmqYiLjsDOo2exq7YZH9TYAABx0RG4YWoiFk9Lwo0ZidDHRnLw5jhjIBDRuBIEAZOTYtDe48LcdC3ae/oQIQf2nWzDnhMt2F59fvI8TVQEpibHICM5BlmpGphmJGNiHKd68acRBUJFRQWeeeYZeDwerFixAg8//PCg151OJ5566ikcOnQIcXFxKC4uRlpaGgBgy5YteO211yCTyfDzn/8cN95444j2SUThQRMVgTnpcYhWnj9DaOroBSDC2tqDupYu/ONgE/76mRUAkJUai1uu1mP+VfGYpouBOjICIkSc6XCg9FATdh49g8OnO3BjehT+zxQXJ40cJa9/W263Gxs2bMBLL70EnU6HwsJCmEwmTJ06daDN1q1bERsbix07dsBisWDjxo14/vnnUVtbC4vFAovFApvNhlWrVqG0tBQAvO6TiMKPIAhI0URhTnocqurbsGBSAkRRRJo2Cnu+bMHHx5vx3I5jl91+StIEZKdp8PaRZhx5YRf+tDoHSWrVOB5BcPMaCNXV1TAajTAYDACAgoIClJWVDfryLi8vx49+9CMAQH5+PjZs2ABRFFFWVoaCggIolUoYDAYYjUZUV1cDgNd9EhEB50MiOTYS6fETcO/CCbg6JQYnznbhZGv3wC2sMaoIzEqNxen+CSzNGVo88/4J3L1lD36+dAZMmTr2R4yA10Cw2WzQ6/UDj3U63cCX+tfbpKSknN+hQgG1Wg273Q6bzYbZs2cP2tZmO9+B5G2fRESXoopQwOESoY+9eCYBADpN1EAgzJwYi+8sNOLPn9TjoVf2QSmXQTshAurICChkAmT94eDpn15DLhPg9ogQRUAmExAhP//Y7RGhVMigUpy/Q9/lEdHn9vRvI4NMAGSCALkgwC2KEEVx4E4pZ/90+0qFDIIgYLRxJIqAy+OBXCZgcUYSHs3LuKK/t5EIqgtsDocDNTU1Xttdrs2cWF9XRCPS1jWuf/dzYmMBdPm/jivZ32i3HWl7b+3G+vponh/Jc1fyeCR/dnahcIqAwinGSxQdjFwj+u67FIdj5NP+ew0EnU6Hpqamgcc2mw06nW5Im8bGRuj1erhcLnR2dkKr1Q67rbd9Xso111zj/YiIiGhMvI5Uzs7ORl1dHaxWK5xOJywWC0wm06A2JpMJ27ZtAwCUlpYiJycHgiDAZDLBYrHA6XTCarWirq4Os2bNGtE+iYhofHk9Q1AoFCgqKsLq1avhdruxfPlyZGRkYNOmTcjKykJeXh4KCwuxbt06mM1maDQaFBcXAwAyMjKwZMkSLF26FHK5HEVFRZDL5QBwyX0SEZF0BJGTlhMRETi5HRER9WMgEBERAAYCERH1YyAQERGAIBuYNhZWqxW/+c1vcO7cOWzevFnqckgCH3zwAT788EOcO3cOhYWFuOGGG6QuicbRiRMn8PLLL6OtrQ05OTm49957pS4pcIlB6OmnnxZzcnLEgoKCQc9/9NFH4i233CLefPPN4pYtWwa99uijj45nieRnY/kMtLW1iT/5yU/Gs0zyk7H8/N1ut/jkk0+OZ5lBJygD4dNPPxUPHjw46MPgcrnEvLw8sb6+XnQ4HOLtt98uHj9+fOB1BkJoGctn4Be/+IV48OBBKcolHxvtz/+DDz4QH3roIfHtt9+WquSgEJR9CPPnz4dGoxn03NdnZVUqlQMzqFJoGs1nQBRF/PKXv8TixYsxc+ZMiSomXxrtd0BeXh5KSkrwzjvvSFFu0AiZPoTLzcpqt9tRXFyMw4cPY8uWLfj+978vYZXkT5f7DPzxj3/Enj170NnZiZMnT+Lb3/62hFWSv1zu5//JJ59gx44dcDqdyM3NlbDCwBcygXA5Wq0WGzZskLoMktD999+P+++/X+oySCILFy7EwoULpS4jKATlJaNLGcmsrBTa+BkIb/z5X7mQCQTOoEr8DIQ3/vyvXFBObvfEE0/g008/hd1uR0JCAh599FGsWLECH330EZ599tmBGVTXrFkjdankJ/wMhDf+/P0jKAOBiIh8L2QuGRER0ZVhIBAREQAGAhER9WMgEBERAAYCERH1YyAQEREABgIREfVjIBAREQAGAhER9fv/QPXvwfWfLg8AAAAASUVORK5CYII=\n",
128 "text/plain": [
129 "<Figure size 432x288 with 1 Axes>"
130 ]
131 },
132 "metadata": {},
133 "output_type": "display_data"
134 }
135 ],
136 "source": [
137 "ax = sns.distplot([len(review) for review in X_train])\n",
138 "ax.set(xscale='log');"
139 ]
140 },
141 {
142 "cell_type": "markdown",
143 "metadata": {},
144 "source": [
145 "## Prepare Data"
146 ]
147 },
148 {
149 "cell_type": "markdown",
150 "metadata": {},
151 "source": [
152 "In the second step, convert the lists of integers into fixed-size arrays that we can stack and provide as input to our RNN. The pad_sequence function produces arrays of equal length, truncated, and padded to conform to maxlen, as follows:"
153 ]
154 },
155 {
156 "cell_type": "code",
157 "execution_count": 6,
158 "metadata": {},
159 "outputs": [],
160 "source": [
161 "maxlen = 100"
162 ]
163 },
164 {
165 "cell_type": "code",
166 "execution_count": 7,
167 "metadata": {},
168 "outputs": [],
169 "source": [
170 "X_train_padded = pad_sequences(X_train, \n",
171 " truncating='pre', \n",
172 " padding='pre', \n",
173 " maxlen=maxlen)\n",
174 "\n",
175 "X_test_padded = pad_sequences(X_test, \n",
176 " truncating='pre', \n",
177 " padding='pre', \n",
178 " maxlen=maxlen)"
179 ]
180 },
181 {
182 "cell_type": "code",
183 "execution_count": 8,
184 "metadata": {},
185 "outputs": [
186 {
187 "data": {
188 "text/plain": [
189 "((25000, 100), (25000, 100))"
190 ]
191 },
192 "execution_count": 8,
193 "metadata": {},
194 "output_type": "execute_result"
195 }
196 ],
197 "source": [
198 "X_train_padded.shape, X_test_padded.shape"
199 ]
200 },
201 {
202 "cell_type": "markdown",
203 "metadata": {},
204 "source": [
205 "## Define Model Architecture"
206 ]
207 },
208 {
209 "cell_type": "markdown",
210 "metadata": {},
211 "source": [
212 "Now we can define our RNN architecture. The first layer learns the word embeddings. We define the embedding dimension as previously using the input_dim keyword to set the number of tokens that we need to embed, the output_dim keyword, which defines the size of each embedding, and how long each input sequence is going to be."
213 ]
214 },
215 {
216 "cell_type": "markdown",
217 "metadata": {},
218 "source": [
219 "### Custom Loss Metric"
220 ]
221 },
222 {
223 "cell_type": "code",
224 "execution_count": 9,
225 "metadata": {},
226 "outputs": [],
227 "source": [
228 "# source: https://github.com/keras-team/keras/issues/3230\n",
229 "def auc(y_true, y_pred):\n",
230 " ptas = tf.stack([binary_PTA(y_true, y_pred, k) for k in np.linspace(0, 1, 1000)], axis=0)\n",
231 " pfas = tf.stack([binary_PFA(y_true, y_pred, k) for k in np.linspace(0, 1, 1000)], axis=0)\n",
232 " pfas = tf.concat([tf.ones((1,)), pfas], axis=0)\n",
233 " binSizes = -(pfas[1:] - pfas[:-1])\n",
234 " s = ptas * binSizes\n",
235 " return K.sum(s, axis=0)\n",
236 "\n",
237 "\n",
238 "def binary_PFA(y_true, y_pred, threshold=K.variable(value=0.5)):\n",
239 " \"\"\"prob false alert for binary classifier\"\"\"\n",
240 " y_pred = K.cast(y_pred >= threshold, 'float32')\n",
241 " # N = total number of negative labels\n",
242 " N = K.sum(1 - y_true)\n",
243 " # FP = total number of false alerts, alerts from the negative class labels\n",
244 " FP = K.sum(y_pred - y_pred * y_true)\n",
245 " return FP / (N + 1)\n",
246 "\n",
247 "\n",
248 "def binary_PTA(y_true, y_pred, threshold=K.variable(value=0.5)):\n",
249 " \"\"\"prob true alerts for binary classifier\"\"\"\n",
250 " y_pred = K.cast(y_pred >= threshold, 'float32')\n",
251 " # P = total number of positive labels\n",
252 " P = K.sum(y_true)\n",
253 " # TP = total number of correct alerts, alerts from the positive class labels\n",
254 " TP = K.sum(y_pred * y_true)\n",
255 " return TP / (P + 1)"
256 ]
257 },
258 {
259 "cell_type": "code",
260 "execution_count": 10,
261 "metadata": {},
262 "outputs": [],
263 "source": [
264 "embedding_size = 100"
265 ]
266 },
267 {
268 "cell_type": "markdown",
269 "metadata": {},
270 "source": [
271 "Note that we are using GRUs this time, which train faster and perform better on smaller data. We are also using dropout for regularization, as follows:"
272 ]
273 },
274 {
275 "cell_type": "code",
276 "execution_count": 11,
277 "metadata": {},
278 "outputs": [
279 {
280 "name": "stdout",
281 "output_type": "stream",
282 "text": [
283 "_________________________________________________________________\n",
284 "Layer (type) Output Shape Param # \n",
285 "=================================================================\n",
286 "embedding_1 (Embedding) (None, 100, 100) 2000000 \n",
287 "_________________________________________________________________\n",
288 "gru_1 (GRU) (None, 32) 12768 \n",
289 "_________________________________________________________________\n",
290 "dense_1 (Dense) (None, 1) 33 \n",
291 "=================================================================\n",
292 "Total params: 2,012,801\n",
293 "Trainable params: 2,012,801\n",
294 "Non-trainable params: 0\n",
295 "_________________________________________________________________\n"
296 ]
297 }
298 ],
299 "source": [
300 "rnn = Sequential([\n",
301 " Embedding(input_dim=vocab_size, output_dim= embedding_size, input_length=maxlen),\n",
302 " GRU(units=32, dropout=0.2, recurrent_dropout=0.2),\n",
303 " Dense(1, activation='sigmoid')\n",
304 "])\n",
305 "rnn.summary()"
306 ]
307 },
308 {
309 "cell_type": "markdown",
310 "metadata": {},
311 "source": [
312 "The resulting model has over 2 million parameters."
313 ]
314 },
315 {
316 "cell_type": "markdown",
317 "metadata": {},
318 "source": [
319 "We compile the model to use our custom AUC metric, which we introduced previously, and train with early stopping:"
320 ]
321 },
322 {
323 "cell_type": "code",
324 "execution_count": 16,
325 "metadata": {},
326 "outputs": [],
327 "source": [
328 "rnn.compile(loss='binary_crossentropy', \n",
329 " optimizer='RMSProp', \n",
330 " metrics=['accuracy', auc])"
331 ]
332 },
333 {
334 "cell_type": "code",
335 "execution_count": 17,
336 "metadata": {},
337 "outputs": [],
338 "source": [
339 "rnn_path = 'models/imdb.gru.weights.best.hdf5'\n",
340 "checkpointer = ModelCheckpoint(filepath=rnn_path,\n",
341 " monitor='val_auc',\n",
342 " save_best_only=True,\n",
343 " save_weights_only=True,\n",
344 " period=5)"
345 ]
346 },
347 {
348 "cell_type": "code",
349 "execution_count": 18,
350 "metadata": {},
351 "outputs": [],
352 "source": [
353 "early_stopping = EarlyStopping(monitor='val_auc', \n",
354 " mode='max',\n",
355 " patience=5,\n",
356 " restore_best_weights=True)"
357 ]
358 },
359 {
360 "cell_type": "markdown",
361 "metadata": {},
362 "source": [
363 "Training stops after eight epochs and we recover the weights for the best models to find a high test AUC of 0.9346:"
364 ]
365 },
366 {
367 "cell_type": "code",
368 "execution_count": 19,
369 "metadata": {},
370 "outputs": [
371 {
372 "name": "stdout",
373 "output_type": "stream",
374 "text": [
375 "Train on 25000 samples, validate on 25000 samples\n",
376 "Epoch 1/25\n",
377 "25000/25000 [==============================] - 125s 5ms/step - loss: 0.3562 - acc: 0.8514 - auc: 0.8717 - val_loss: 0.4181 - val_acc: 0.8311 - val_auc: 0.8704\n",
378 "Epoch 2/25\n",
379 "25000/25000 [==============================] - 112s 4ms/step - loss: 0.2964 - acc: 0.8833 - auc: 0.8924 - val_loss: 0.3363 - val_acc: 0.8575 - val_auc: 0.8817\n",
380 "Epoch 3/25\n",
381 "25000/25000 [==============================] - 112s 4ms/step - loss: 0.2551 - acc: 0.9006 - auc: 0.9054 - val_loss: 0.3450 - val_acc: 0.8585 - val_auc: 0.8823\n",
382 "Epoch 4/25\n",
383 "25000/25000 [==============================] - 113s 5ms/step - loss: 0.2231 - acc: 0.9145 - auc: 0.9129 - val_loss: 0.3393 - val_acc: 0.8583 - val_auc: 0.8822\n",
384 "Epoch 5/25\n",
385 "25000/25000 [==============================] - 112s 4ms/step - loss: 0.1997 - acc: 0.9258 - auc: 0.9179 - val_loss: 0.3517 - val_acc: 0.8558 - val_auc: 0.8815\n",
386 "Epoch 6/25\n",
387 "25000/25000 [==============================] - 117s 5ms/step - loss: 0.1780 - acc: 0.9335 - auc: 0.9221 - val_loss: 0.3980 - val_acc: 0.8443 - val_auc: 0.8787\n",
388 "Epoch 7/25\n",
389 "25000/25000 [==============================] - 122s 5ms/step - loss: 0.1571 - acc: 0.9436 - auc: 0.9251 - val_loss: 0.3949 - val_acc: 0.8511 - val_auc: 0.8782\n",
390 "Epoch 8/25\n",
391 "25000/25000 [==============================] - 112s 4ms/step - loss: 0.1422 - acc: 0.9484 - auc: 0.9279 - val_loss: 0.3815 - val_acc: 0.8507 - val_auc: 0.8780\n"
392 ]
393 },
394 {
395 "data": {
396 "text/plain": [
397 "<keras.callbacks.History at 0x7f01c4c8a400>"
398 ]
399 },
400 "execution_count": 19,
401 "metadata": {},
402 "output_type": "execute_result"
403 }
404 ],
405 "source": [
406 "rnn.fit(X_train_padded, \n",
407 " y_train, \n",
408 " batch_size=32, \n",
409 " epochs=25, \n",
410 " validation_data=(X_test_padded, y_test), \n",
411 " callbacks=[checkpointer, early_stopping],\n",
412 " verbose=1)"
413 ]
414 },
415 {
416 "cell_type": "markdown",
417 "metadata": {},
418 "source": [
419 "## Evaluate Results"
420 ]
421 },
422 {
423 "cell_type": "code",
424 "execution_count": 20,
425 "metadata": {},
426 "outputs": [],
427 "source": [
428 "rnn.load_weights(rnn_path)"
429 ]
430 },
431 {
432 "cell_type": "code",
433 "execution_count": 22,
434 "metadata": {},
435 "outputs": [
436 {
437 "data": {
438 "text/plain": [
439 "(25000, 1)"
440 ]
441 },
442 "execution_count": 22,
443 "metadata": {},
444 "output_type": "execute_result"
445 }
446 ],
447 "source": [
448 "y_score = rnn.predict(X_test_padded)\n",
449 "y_score.shape"
450 ]
451 },
452 {
453 "cell_type": "code",
454 "execution_count": 23,
455 "metadata": {},
456 "outputs": [
457 {
458 "data": {
459 "text/plain": [
460 "0.9346154079999999"
461 ]
462 },
463 "execution_count": 23,
464 "metadata": {},
465 "output_type": "execute_result"
466 }
467 ],
468 "source": [
469 "roc_auc_score(y_score=y_score.squeeze(), y_true=y_test)"
470 ]
471 },
472 {
473 "cell_type": "code",
474 "execution_count": null,
475 "metadata": {},
476 "outputs": [],
477 "source": []
478 }
479 ],
480 "metadata": {
481 "kernelspec": {
482 "display_name": "Python 3",
483 "language": "python",
484 "name": "python3"
485 },
486 "language_info": {
487 "codemirror_mode": {
488 "name": "ipython",
489 "version": 3
490 },
491 "file_extension": ".py",
492 "mimetype": "text/x-python",
493 "name": "python",
494 "nbconvert_exporter": "python",
495 "pygments_lexer": "ipython3",
496 "version": "3.6.8"
497 },
498 "toc": {
499 "base_numbering": 1,
500 "nav_menu": {},
501 "number_sections": true,
502 "sideBar": true,
503 "skip_h1_title": true,
504 "title_cell": "Table of Contents",
505 "title_sidebar": "Contents",
506 "toc_cell": false,
507 "toc_position": {},
508 "toc_section_display": true,
509 "toc_window_display": false
510 }
511 },
512 "nbformat": 4,
513 "nbformat_minor": 2
514 }