ml-finance-python

python scripts for finance machine learning

git clone https://9o.is/git/ml-finance-python.git

04_news_text_classification.ipynb

(11467B)


      1 {
      2  "cells": [
      3   {
      4    "cell_type": "markdown",
      5    "metadata": {},
      6    "source": [
      7     "# Text classification and sentiment analysis"
      8    ]
      9   },
     10   {
     11    "cell_type": "markdown",
     12    "metadata": {},
     13    "source": [
     14     "Once text data has been converted into numerical features using the natural language processing techniques discussed in the previous sections, text classification works just like any other classification task.\n",
     15     "\n",
     16     "In this notebook, we will apply these preprocessing techniques to news articles, product reviews, and Twitter data and teach various classifiers to predict discrete news categories, review scores, and sentiment polarity."
     17    ]
     18   },
     19   {
     20    "cell_type": "markdown",
     21    "metadata": {},
     22    "source": [
     23     "## Imports"
     24    ]
     25   },
     26   {
     27    "cell_type": "code",
     28    "execution_count": 3,
     29    "metadata": {
     30     "ExecuteTime": {
     31      "end_time": "2018-11-26T06:37:13.006374Z",
     32      "start_time": "2018-11-26T06:37:12.515786Z"
     33     }
     34    },
     35    "outputs": [],
     36    "source": [
     37     "%matplotlib inline\n",
     38     "import warnings\n",
     39     "from collections import Counter, OrderedDict\n",
     40     "from pathlib import Path\n",
     41     "\n",
     42     "import numpy as np\n",
     43     "import pandas as pd\n",
     44     "from pandas.io.json import json_normalize\n",
     45     "import pyarrow as pa   \n",
     46     "import pyarrow.parquet as pq\n",
     47     "from fastparquet import ParquetFile \n",
     48     "from scipy import sparse\n",
     49     "from scipy.spatial.distance import pdist, squareform\n",
     50     "\n",
     51     "# Visualization\n",
     52     "import matplotlib.pyplot as plt\n",
     53     "from matplotlib.ticker import FuncFormatter, ScalarFormatter\n",
     54     "import seaborn as sns\n",
     55     "\n",
     56     "# spacy, textblob and nltk for language processing\n",
     57     "from textblob import TextBlob, Word\n",
     58     "\n",
     59     "# sklearn for feature extraction & modeling\n",
     60     "from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer\n",
     61     "from sklearn.model_selection import train_test_split\n",
     62     "from sklearn.naive_bayes import MultinomialNB\n",
     63     "from sklearn.linear_model import LogisticRegression\n",
     64     "from sklearn.metrics import roc_auc_score, roc_curve, accuracy_score, confusion_matrix\n",
     65     "from sklearn.externals import joblib\n",
     66     "\n",
     67     "import lightgbm as lgb\n",
     68     "\n",
     69     "import json\n",
     70     "from time import clock, time"
     71    ]
     72   },
     73   {
     74    "cell_type": "code",
     75    "execution_count": 4,
     76    "metadata": {
     77     "ExecuteTime": {
     78      "end_time": "2018-11-26T06:37:13.010613Z",
     79      "start_time": "2018-11-26T06:37:13.007802Z"
     80     }
     81    },
     82    "outputs": [],
     83    "source": [
     84     "plt.style.use('fivethirtyeight')\n",
     85     "warnings.filterwarnings('ignore')"
     86    ]
     87   },
     88   {
     89    "cell_type": "markdown",
     90    "metadata": {},
     91    "source": [
     92     "## News article classification"
     93    ]
     94   },
     95   {
     96    "cell_type": "markdown",
     97    "metadata": {},
     98    "source": [
     99     "We start with an illustration of the Naive Bayes model for news article classification using the BBC articles, which we read as before into a DataFrame with 2,225 articles from five categories."
    100    ]
    101   },
    102   {
    103    "cell_type": "markdown",
    104    "metadata": {},
    105    "source": [
    106     "### Read BBC articles"
    107    ]
    108   },
    109   {
    110    "cell_type": "code",
    111    "execution_count": 5,
    112    "metadata": {
    113     "ExecuteTime": {
    114      "end_time": "2018-11-22T18:43:12.744002Z",
    115      "start_time": "2018-11-22T18:43:12.665470Z"
    116     }
    117    },
    118    "outputs": [],
    119    "source": [
    120     "path = Path('data', 'bbc')\n",
    121     "files = sorted(path.glob('**/*.txt'))  # glob order is filesystem-dependent; sort for a reproducible split\n",
    122     "doc_list = []\n",
    123     "for file in files:\n",
    124     "    topic = file.parts[-2]\n",
    125     "    article = file.read_text(encoding='latin1').split('\\n')\n",
    126     "    heading = article[0].strip()\n",
    127     "    body = ' '.join(line.strip() for line in article[1:])\n",
    128     "    doc_list.append([topic, heading, body])"
    129    ]
    130   },
    131   {
    132    "cell_type": "code",
    133    "execution_count": 6,
    134    "metadata": {
    135     "ExecuteTime": {
    136      "end_time": "2018-11-22T18:43:13.064527Z",
    137      "start_time": "2018-11-22T18:43:13.048580Z"
    138     }
    139    },
    140    "outputs": [
    141     {
    142      "name": "stdout",
    143      "output_type": "stream",
    144      "text": [
    145       "<class 'pandas.core.frame.DataFrame'>\n",
    146       "RangeIndex: 2225 entries, 0 to 2224\n",
    147       "Data columns (total 3 columns):\n",
    148       "topic      2225 non-null object\n",
    149       "heading    2225 non-null object\n",
    150       "body       2225 non-null object\n",
    151       "dtypes: object(3)\n",
    152       "memory usage: 52.2+ KB\n"
    153      ]
    154     }
    155    ],
    156    "source": [
    157     "docs = pd.DataFrame(doc_list, columns=['topic', 'heading', 'body'])\n",
    158     "docs.info()"
    159    ]
    160   },
    161   {
    162    "cell_type": "markdown",
    163    "metadata": {},
    164    "source": [
    165     "### Create stratified train-test split"
    166    ]
    167   },
    168   {
    169    "cell_type": "markdown",
    170    "metadata": {},
    171    "source": [
    172     "We split the data into the default 75:25 train-test sets, stratifying by topic so that the class proportions in the test set closely mirror those in the training set:"
    173    ]
    174   },
    175   {
    176    "cell_type": "code",
    177    "execution_count": 7,
    178    "metadata": {
    179     "ExecuteTime": {
    180      "end_time": "2018-11-22T18:43:29.481585Z",
    181      "start_time": "2018-11-22T18:43:29.476360Z"
    182     }
    183    },
    184    "outputs": [],
    185    "source": [
    186     "y = pd.factorize(docs.topic)[0]\n",
    187     "X = docs.body\n",
    188     "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1, stratify=y)"
    189    ]
    190   },
    191   {
    192    "cell_type": "markdown",
    193    "metadata": {},
    194    "source": [
    195     "### Vectorize text data"
    196    ]
    197   },
    198   {
    199    "cell_type": "markdown",
    200    "metadata": {},
    201    "source": [
    202     "We then learn the vocabulary from the training set and transform both datasets using the CountVectorizer with default settings, obtaining almost 26,000 features:"
    203    ]
    204   },
    205   {
    206    "cell_type": "code",
    207    "execution_count": 8,
    208    "metadata": {
    209     "ExecuteTime": {
    210      "end_time": "2018-11-22T18:43:39.761327Z",
    211      "start_time": "2018-11-22T18:43:39.332132Z"
    212     }
    213    },
    214    "outputs": [],
    215    "source": [
    216     "vectorizer = CountVectorizer()\n",
    217     "X_train_dtm = vectorizer.fit_transform(X_train)\n",
    218     "X_test_dtm = vectorizer.transform(X_test)"
    219    ]
    220   },
    221   {
    222    "cell_type": "code",
    223    "execution_count": 9,
    224    "metadata": {
    225     "ExecuteTime": {
    226      "end_time": "2018-11-22T18:43:42.048891Z",
    227      "start_time": "2018-11-22T18:43:42.042059Z"
    228     }
    229    },
    230    "outputs": [
    231     {
    232      "data": {
    233       "text/plain": [
    234        "((1668, 25919), (557, 25919))"
    235       ]
    236      },
    237      "execution_count": 9,
    238      "metadata": {},
    239      "output_type": "execute_result"
    240     }
    241    ],
    242    "source": [
    243     "X_train_dtm.shape, X_test_dtm.shape"
    244    ]
    245   },
    246   {
    247    "cell_type": "markdown",
    248    "metadata": {},
    249    "source": [
    250     "### Train Multi-class Naive Bayes model"
    251    ]
    252   },
    253   {
    254    "cell_type": "code",
    255    "execution_count": 10,
    256    "metadata": {
    257     "ExecuteTime": {
    258      "end_time": "2018-11-22T18:43:50.244608Z",
    259      "start_time": "2018-11-22T18:43:50.223308Z"
    260     }
    261    },
    262    "outputs": [],
    263    "source": [
    264     "nb = MultinomialNB()\n",
    265     "nb.fit(X_train_dtm, y_train)\n",
    266     "y_pred_class = nb.predict(X_test_dtm)"
    267    ]
    268   },
    269   {
    270    "cell_type": "markdown",
    271    "metadata": {},
    272    "source": [
    273     "### Evaluate Results"
    274    ]
    275   },
    276   {
    277    "cell_type": "markdown",
    278    "metadata": {},
    279    "source": [
    280     "We evaluate the multiclass predictions using accuracy and find that the default classifier achieves almost 98%:"
    281    ]
    282   },
    283   {
    284    "cell_type": "markdown",
    285    "metadata": {},
    286    "source": [
    287     "#### Accuracy"
    288    ]
    289   },
    290   {
    291    "cell_type": "code",
    292    "execution_count": 12,
    293    "metadata": {
    294     "ExecuteTime": {
    295      "end_time": "2018-11-22T18:43:57.224720Z",
    296      "start_time": "2018-11-22T18:43:57.208935Z"
    297     }
    298    },
    299    "outputs": [
    300     {
    301      "data": {
    302       "text/plain": [
    303        "0.9766606822262118"
    304       ]
    305      },
    306      "execution_count": 12,
    307      "metadata": {},
    308      "output_type": "execute_result"
    309     }
    310    ],
    311    "source": [
    312     "accuracy_score(y_test, y_pred_class)"
    313    ]
    314   },
    315   {
    316    "cell_type": "markdown",
    317    "metadata": {},
    318    "source": [
    319     "#### Confusion matrix"
    320    ]
    321   },
    322   {
    323    "cell_type": "code",
    324    "execution_count": 13,
    325    "metadata": {
    326     "ExecuteTime": {
    327      "end_time": "2018-11-22T18:44:00.292728Z",
    328      "start_time": "2018-11-22T18:44:00.268874Z"
    329     }
    330    },
    331    "outputs": [
    332     {
    333      "data": {
    334       "text/html": [
    335        "<div>\n",
    336        "<style scoped>\n",
    337        "    .dataframe tbody tr th:only-of-type {\n",
    338        "        vertical-align: middle;\n",
    339        "    }\n",
    340        "\n",
    341        "    .dataframe tbody tr th {\n",
    342        "        vertical-align: top;\n",
    343        "    }\n",
    344        "\n",
    345        "    .dataframe thead th {\n",
    346        "        text-align: right;\n",
    347        "    }\n",
    348        "</style>\n",
    349        "<table border=\"1\" class=\"dataframe\">\n",
    350        "  <thead>\n",
    351        "    <tr style=\"text-align: right;\">\n",
    352        "      <th></th>\n",
    353        "      <th>0</th>\n",
    354        "      <th>1</th>\n",
    355        "      <th>2</th>\n",
    356        "      <th>3</th>\n",
    357        "      <th>4</th>\n",
    358        "    </tr>\n",
    359        "  </thead>\n",
    360        "  <tbody>\n",
    361        "    <tr>\n",
    362        "      <th>0</th>\n",
    363        "      <td>98</td>\n",
    364        "      <td>0</td>\n",
    365        "      <td>0</td>\n",
    366        "      <td>2</td>\n",
    367        "      <td>0</td>\n",
    368        "    </tr>\n",
    369        "    <tr>\n",
    370        "      <th>1</th>\n",
    371        "      <td>0</td>\n",
    372        "      <td>128</td>\n",
    373        "      <td>0</td>\n",
    374        "      <td>0</td>\n",
    375        "      <td>0</td>\n",
    376        "    </tr>\n",
    377        "    <tr>\n",
    378        "      <th>2</th>\n",
    379        "      <td>0</td>\n",
    380        "      <td>0</td>\n",
    381        "      <td>102</td>\n",
    382        "      <td>2</td>\n",
    383        "      <td>0</td>\n",
    384        "    </tr>\n",
    385        "    <tr>\n",
    386        "      <th>3</th>\n",
    387        "      <td>2</td>\n",
    388        "      <td>0</td>\n",
    389        "      <td>5</td>\n",
    390        "      <td>121</td>\n",
    391        "      <td>0</td>\n",
    392        "    </tr>\n",
    393        "    <tr>\n",
    394        "      <th>4</th>\n",
    395        "      <td>0</td>\n",
    396        "      <td>0</td>\n",
    397        "      <td>1</td>\n",
    398        "      <td>1</td>\n",
    399        "      <td>95</td>\n",
    400        "    </tr>\n",
    401        "  </tbody>\n",
    402        "</table>\n",
    403        "</div>"
    404       ],
    405       "text/plain": [
    406        "    0    1    2    3   4\n",
    407        "0  98    0    0    2   0\n",
    408        "1   0  128    0    0   0\n",
    409        "2   0    0  102    2   0\n",
    410        "3   2    0    5  121   0\n",
    411        "4   0    0    1    1  95"
    412       ]
    413      },
    414      "execution_count": 13,
    415      "metadata": {},
    416      "output_type": "execute_result"
    417     }
    418    ],
    419    "source": [
    420     "pd.DataFrame(confusion_matrix(y_true=y_test, y_pred=y_pred_class))"
    421    ]
    422   }
    423  ],
    424  "metadata": {
    425   "celltoolbar": "Slideshow",
    426   "kernelspec": {
    427    "display_name": "Python 3",
    428    "language": "python",
    429    "name": "python3"
    430   },
    431   "language_info": {
    432    "codemirror_mode": {
    433     "name": "ipython",
    434     "version": 3
    435    },
    436    "file_extension": ".py",
    437    "mimetype": "text/x-python",
    438    "name": "python",
    439    "nbconvert_exporter": "python",
    440    "pygments_lexer": "ipython3",
    441    "version": "3.6.8"
    442   },
    443   "toc": {
    444    "base_numbering": 1,
    445    "nav_menu": {},
    446    "number_sections": true,
    447    "sideBar": true,
    448    "skip_h1_title": true,
    449    "title_cell": "Table of Contents",
    450    "title_sidebar": "Contents",
    451    "toc_cell": false,
    452    "toc_position": {
    453     "height": "calc(100% - 180px)",
    454     "left": "10px",
    455     "top": "150px",
    456     "width": "316px"
    457    },
    458    "toc_section_display": true,
    459    "toc_window_display": true
    460   }
    461  },
    462  "nbformat": 4,
    463  "nbformat_minor": 2
    464 }