ml-finance-python

python scripts for finance machine learning

git clone https://9o.is/git/ml-finance-python.git

word2vec.ipynb

(33952B)


      1 {
      2  "cells": [
      3   {
      4    "cell_type": "markdown",
      5    "metadata": {},
      6    "source": [
      7     "## Imports & Settings"
      8    ]
      9   },
     10   {
     11    "cell_type": "code",
     12    "execution_count": 1,
     13    "metadata": {
     14     "ExecuteTime": {
     15      "end_time": "2018-12-08T23:57:04.619453Z",
     16      "start_time": "2018-12-08T23:57:04.488154Z"
     17     }
     18    },
     19    "outputs": [],
     20    "source": [
     21     "from pathlib import Path\n",
     22     "from time import time\n",
     23     "import warnings\n",
     24     "from collections import Counter\n",
     25     "import logging\n",
     26     "from ast import literal_eval as make_tuple\n",
     27     "import numpy as np\n",
     28     "import pandas as pd\n",
     29     "\n",
     30     "from gensim.models import Word2Vec, KeyedVectors\n",
     31     "from gensim.models.word2vec import LineSentence\n",
     32     "import word2vec"
     33    ]
     34   },
     35   {
     36    "cell_type": "code",
     37    "execution_count": 2,
     38    "metadata": {
     39     "ExecuteTime": {
     40      "end_time": "2018-12-08T23:57:05.049257Z",
     41      "start_time": "2018-12-08T23:57:05.040701Z"
     42     }
     43    },
     44    "outputs": [],
     45    "source": [
     46     "np.random.seed(42)  # seeds NumPy's global RNG\n",
     47     "warnings.filterwarnings('ignore')  # silence all warnings in cell output\n",
     48     "pd.set_option('display.expand_frame_repr', False)  # do not wrap wide frames across lines"
     49    ]
     50   },
     51   {
     52    "cell_type": "code",
     53    "execution_count": 3,
     54    "metadata": {
     55     "ExecuteTime": {
     56      "end_time": "2018-12-08T23:57:05.244408Z",
     57      "start_time": "2018-12-08T23:57:05.240318Z"
     58     }
     59    },
     60    "outputs": [],
     61    "source": [
     62     "def format_time(t):  # seconds -> 'HH:MM:SS' (hours are not wrapped at 24)\n",
     63     "    minutes, seconds = divmod(t, 60)\n",
     64     "    hours, minutes = divmod(minutes, 60)\n",
     65     "    return f'{hours:02.0f}:{minutes:02.0f}:{seconds:02.0f}'"
     66    ]
     67   },
     68   {
     69    "cell_type": "markdown",
     70    "metadata": {},
     71    "source": [
     72     "### Logging Setup"
     73    ]
     74   },
     75   {
     76    "cell_type": "code",
     77    "execution_count": 4,
     78    "metadata": {
     79     "ExecuteTime": {
     80      "end_time": "2018-12-08T23:57:06.423935Z",
     81      "start_time": "2018-12-08T23:57:06.421773Z"
     82     }
     83    },
     84    "outputs": [],
     85    "source": [
     86     "Path('logs').mkdir(exist_ok=True)  # the log file handler cannot create a missing directory\n",
     87     "logging.basicConfig(filename='logs/word2vec.log',  # root logger -> file, so library records (e.g. gensim) land here\n",
     88     "                    level=logging.DEBUG,\n",
     89     "                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',\n",
     90     "                    datefmt='%H:%M:%S')"
     91    ]
     92   },
     93   {
     94    "cell_type": "markdown",
     95    "metadata": {},
     96    "source": [
     97     "## word2vec"
     98    ]
     99   },
    100   {
    101    "cell_type": "code",
    102    "execution_count": 6,
    103    "metadata": {
    104     "ExecuteTime": {
    105      "end_time": "2018-12-08T23:57:34.969991Z",
    106      "start_time": "2018-12-08T23:57:34.967461Z"
    107     }
    108    },
    109    "outputs": [],
    110    "source": [
    111     "ANALOGIES_PATH = analogies_path = Path.cwd().parent / 'data' / 'analogies' / 'analogies-en.txt'  # later cells read ANALOGIES_PATH"
    112    ]
    113   },
    114   {
    115    "cell_type": "markdown",
    116    "metadata": {},
    117    "source": [
    118     "### Set up Sentence Generator"
    119    ]
    120   },
    121   {
    122    "cell_type": "code",
    123    "execution_count": 8,
    124    "metadata": {
    125     "ExecuteTime": {
    126      "end_time": "2018-12-08T23:57:57.298178Z",
    127      "start_time": "2018-12-08T23:57:57.289388Z"
    128     }
    129    },
    130    "outputs": [],
    131    "source": [
    132     "NGRAMS = 2  # selects the corpus file data/ngrams/ngrams_{NGRAMS}.txt read via sentence_path"
    133    ]
    134   },
    135   {
    136    "cell_type": "markdown",
    137    "metadata": {},
    138    "source": [
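    139     "To keep memory use low, the LineSentence class streams the corpus from disk instead of loading it all at once: it is a re-iterable object that yields one whitespace-tokenized sentence per line of the provided text file (Word2Vec makes several passes over the data, so a one-shot generator would not suffice):"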
    140    ]
    141   },
    142   {
    143    "cell_type": "code",
    144    "execution_count": 9,
    145    "metadata": {
    146     "ExecuteTime": {
    147      "end_time": "2018-12-08T23:57:58.496781Z",
    148      "start_time": "2018-12-08T23:57:58.494515Z"
    149     }
    150    },
    151    "outputs": [],
    152    "source": [
    153     "sentence_path = Path('data') / 'ngrams' / f'ngrams_{NGRAMS}.txt'  # corpus chosen by NGRAMS\n",
    154     "sentences = LineSentence(sentence_path)  # streamed from disk rather than loaded into memory"
    155    ]
    156   },
    157   {
    158    "cell_type": "markdown",
    159    "metadata": {},
    160    "source": [
    161     "### Train word2vec Model"
    162    ]
    163   },
    164   {
    165    "cell_type": "markdown",
    166    "metadata": {},
    167    "source": [
    168     "The `Word2Vec` class in [gensim.models.word2vec](https://radimrehurek.com/gensim/models/word2vec.html) implements both the skip-gram and the continuous bag-of-words (CBOW) architectures; the `sg` parameter below selects between them. The notebook [word2vec](../03_word2vec.ipynb) contains additional implementation details."
    169    ]
    170   },
    171   {
    172    "cell_type": "code",
    173    "execution_count": 10,
    174    "metadata": {
    175     "ExecuteTime": {
    176      "end_time": "2018-12-09T00:09:31.218671Z",
    177      "start_time": "2018-12-08T23:58:43.716464Z"
    178     }
    179    },
    180    "outputs": [
    181     {
    182      "name": "stdout",
    183      "output_type": "stream",
    184      "text": [
    185       "Duration: 00:10:47\n"
    186      ]
    187     }
    188    ],
    189    "source": [
    190     "start = time()  # wall-clock start of the training run\n",
    191     "model = Word2Vec(sentences,\n",
    192     "                 sg=1,          # 1 = skip-gram, 0 = CBOW\n",
    193     "                 hs=0,          # 0 = negative sampling (see `negative`), 1 = hierarchical softmax\n",
    194     "                 size=300,      # vector dimensionality (NOTE(review): named `vector_size` in gensim>=4)\n",
    195     "                 window=3,      # max distance between the current and a predicted word\n",
    196     "                 min_count=50,  # ignore words seen fewer than 50 times\n",
    197     "                 negative=10,    # number of noise words drawn for negative sampling\n",
    198     "                 workers=8,     # number of worker threads\n",
    199     "                 iter=1,        # number of epochs over the corpus (NOTE(review): named `epochs` in gensim>=4)\n",
    200     "                 alpha=0.025,   # initial learning rate\n",
    201     "                 min_alpha=0.0001 # final learning rate\n",
    202     "                ) \n",
    203     "print('Duration:', format_time(time() - start))"
    204    ]
    205   },
    206   {
    207    "cell_type": "markdown",
    208    "metadata": {},
    209    "source": [
    210     "### Persist model & vectors"
    211    ]
    212   },
    213   {
    214    "cell_type": "code",
    215    "execution_count": 11,
    216    "metadata": {
    217     "ExecuteTime": {
    218      "end_time": "2018-12-09T00:10:01.380925Z",
    219      "start_time": "2018-12-09T00:10:01.143768Z"
    220     }
    221    },
    222    "outputs": [],
    223    "source": [
    224     "model.save('models/baseline/word2vec.model')  # full Word2Vec model (NOTE(review): models/baseline/ must already exist)\n",
    225     "model.wv.save('models/baseline/word_vectors.bin')  # word vectors only; read back with KeyedVectors.load"
    226    ]
    227   },
    228   {
    229    "cell_type": "markdown",
    230    "metadata": {},
    231    "source": [
    232     "### Load model and vectors"
    233    ]
    234   },
    235   {
    236    "cell_type": "code",
    237    "execution_count": 40,
    238    "metadata": {
    239     "ExecuteTime": {
    240      "end_time": "2018-12-10T00:45:27.525905Z",
    241      "start_time": "2018-12-10T00:45:27.171700Z"
    242     }
    243    },
    244    "outputs": [],
    245    "source": [
    246     "model = Word2Vec.load('models/baseline/word2vec.model')  # same path the persist cell writes, matching the vector load below"
    247    ]
    248   },
    249   {
    250    "cell_type": "code",
    251    "execution_count": 8,
    252    "metadata": {
    253     "ExecuteTime": {
    254      "end_time": "2018-12-08T22:53:13.020767Z",
    255      "start_time": "2018-12-08T22:53:12.843245Z"
    256     }
    257    },
    258    "outputs": [],
    259    "source": [
    260     "wv = KeyedVectors.load('models/baseline/word_vectors.bin')  # vectors written by model.wv.save"
    261    ]
    262   },
    263   {
    264    "cell_type": "markdown",
    265    "metadata": {},
    266    "source": [
    267     "### Get vocabulary"
    268    ]
    269   },
    270   {
    271    "cell_type": "code",
    272    "execution_count": 12,
    273    "metadata": {
    274     "ExecuteTime": {
    275      "end_time": "2018-12-09T00:11:04.596716Z",
    276      "start_time": "2018-12-09T00:11:04.539228Z"
    277     }
    278    },
    279    "outputs": [],
    280    "source": [
    281     "vocab = [\n",
    282     "    [token, entry.index, entry.count]  # one [token, idx, count] row per vocabulary entry\n",
    283     "    for token, entry in model.wv.vocab.items()\n",
    284     "]"
    285    ]
    286   },
    287   {
    288    "cell_type": "code",
    289    "execution_count": 13,
    290    "metadata": {
    291     "ExecuteTime": {
    292      "end_time": "2018-12-09T00:11:04.905084Z",
    293      "start_time": "2018-12-09T00:11:04.868230Z"
    294     }
    295    },
    296    "outputs": [],
    297    "source": [
    298     "# vocabulary as a DataFrame, most frequent tokens first\n",
    299     "vocab = pd.DataFrame(vocab, columns=['token', 'idx', 'count'])\n",
    300     "vocab = vocab.sort_values(by='count', ascending=False)"
    301    ]
    302   },
    303   {
    304    "cell_type": "code",
    305    "execution_count": 14,
    306    "metadata": {
    307     "ExecuteTime": {
    308      "end_time": "2018-12-09T00:11:06.691657Z",
    309      "start_time": "2018-12-09T00:11:06.679881Z"
    310     }
    311    },
    312    "outputs": [
    313     {
    314      "name": "stdout",
    315      "output_type": "stream",
    316      "text": [
    317       "<class 'pandas.core.frame.DataFrame'>\n",
    318       "Int64Index: 50491 entries, 104 to 46372\n",
    319       "Data columns (total 3 columns):\n",
    320       "token    50491 non-null object\n",
    321       "idx      50491 non-null int64\n",
    322       "count    50491 non-null int64\n",
    323       "dtypes: int64(2), object(1)\n",
    324       "memory usage: 1.5+ MB\n"
    325      ]
    326     }
    327    ],
    328    "source": [
    329     "vocab.info()  # dtypes and non-null counts per column"
    330    ]
    331   },
    332   {
    333    "cell_type": "code",
    334    "execution_count": 15,
    335    "metadata": {
    336     "ExecuteTime": {
    337      "end_time": "2018-12-09T00:11:07.220241Z",
    338      "start_time": "2018-12-09T00:11:07.202935Z"
    339     }
    340    },
    341    "outputs": [
    342     {
    343      "data": {
    344       "text/html": [
    345        "<div>\n",
    346        "<style scoped>\n",
    347        "    .dataframe tbody tr th:only-of-type {\n",
    348        "        vertical-align: middle;\n",
    349        "    }\n",
    350        "\n",
    351        "    .dataframe tbody tr th {\n",
    352        "        vertical-align: top;\n",
    353        "    }\n",
    354        "\n",
    355        "    .dataframe thead th {\n",
    356        "        text-align: right;\n",
    357        "    }\n",
    358        "</style>\n",
    359        "<table border=\"1\" class=\"dataframe\">\n",
    360        "  <thead>\n",
    361        "    <tr style=\"text-align: right;\">\n",
    362        "      <th></th>\n",
    363        "      <th>token</th>\n",
    364        "      <th>idx</th>\n",
    365        "      <th>count</th>\n",
    366        "    </tr>\n",
    367        "  </thead>\n",
    368        "  <tbody>\n",
    369        "    <tr>\n",
    370        "      <th>104</th>\n",
    371        "      <td>million</td>\n",
    372        "      <td>0</td>\n",
    373        "      <td>2340243</td>\n",
    374        "    </tr>\n",
    375        "    <tr>\n",
    376        "      <th>0</th>\n",
    377        "      <td>business</td>\n",
    378        "      <td>1</td>\n",
    379        "      <td>1700662</td>\n",
    380        "    </tr>\n",
    381        "    <tr>\n",
    382        "      <th>66</th>\n",
    383        "      <td>december</td>\n",
    384        "      <td>2</td>\n",
    385        "      <td>1513533</td>\n",
    386        "    </tr>\n",
    387        "    <tr>\n",
    388        "      <th>627</th>\n",
    389        "      <td>company</td>\n",
    390        "      <td>3</td>\n",
    391        "      <td>1490752</td>\n",
    392        "    </tr>\n",
    393        "    <tr>\n",
    394        "      <th>477</th>\n",
    395        "      <td>products</td>\n",
    396        "      <td>4</td>\n",
    397        "      <td>1368711</td>\n",
    398        "    </tr>\n",
    399        "    <tr>\n",
    400        "      <th>1071</th>\n",
    401        "      <td>net</td>\n",
    402        "      <td>5</td>\n",
    403        "      <td>1253343</td>\n",
    404        "    </tr>\n",
    405        "    <tr>\n",
    406        "      <th>145</th>\n",
    407        "      <td>market</td>\n",
    408        "      <td>6</td>\n",
    409        "      <td>1149048</td>\n",
    410        "    </tr>\n",
    411        "    <tr>\n",
    412        "      <th>380</th>\n",
    413        "      <td>including</td>\n",
    414        "      <td>7</td>\n",
    415        "      <td>1110482</td>\n",
    416        "    </tr>\n",
    417        "    <tr>\n",
    418        "      <th>381</th>\n",
    419        "      <td>sales</td>\n",
    420        "      <td>8</td>\n",
    421        "      <td>1098312</td>\n",
    422        "    </tr>\n",
    423        "    <tr>\n",
    424        "      <th>60</th>\n",
    425        "      <td>costs</td>\n",
    426        "      <td>9</td>\n",
    427        "      <td>1020383</td>\n",
    428        "    </tr>\n",
    429        "  </tbody>\n",
    430        "</table>\n",
    431        "</div>"
    432       ],
    433       "text/plain": [
    434        "          token  idx    count\n",
    435        "104     million    0  2340243\n",
    436        "0      business    1  1700662\n",
    437        "66     december    2  1513533\n",
    438        "627     company    3  1490752\n",
    439        "477    products    4  1368711\n",
    440        "1071        net    5  1253343\n",
    441        "145      market    6  1149048\n",
    442        "380   including    7  1110482\n",
    443        "381       sales    8  1098312\n",
    444        "60        costs    9  1020383"
    445       ]
    446      },
    447      "execution_count": 15,
    448      "metadata": {},
    449      "output_type": "execute_result"
    450     }
    451    ],
    452    "source": [
    453     "vocab.head(10)  # frame is sorted by count, so these are the most frequent tokens"
    454    ]
    455   },
    456   {
    457    "cell_type": "code",
    458    "execution_count": 16,
    459    "metadata": {
    460     "ExecuteTime": {
    461      "end_time": "2018-12-09T00:11:14.683574Z",
    462      "start_time": "2018-12-09T00:11:14.648032Z"
    463     }
    464    },
    465    "outputs": [
    466     {
    467      "data": {
    468       "text/plain": [
    469        "count      50491\n",
    470        "mean        5110\n",
    471        "std        37525\n",
    472        "min           50\n",
    473        "10%           61\n",
    474        "20%           78\n",
    475        "30.0%        102\n",
    476        "40%          137\n",
    477        "50%          195\n",
    478        "60%          300\n",
    479        "70%          522\n",
    480        "80%         1164\n",
    481        "90%         4578\n",
    482        "max      2340243\n",
    483        "Name: count, dtype: int64"
    484       ]
    485      },
    486      "execution_count": 16,
    487      "metadata": {},
    488      "output_type": "execute_result"
    489     }
    490    ],
    491    "source": [
    492     "vocab['count'].describe(percentiles=np.arange(1, 10) / 10).astype(int)  # exact deciles 0.1..0.9; arange(.1, 1, .1) yields 0.30000000000000004 -> '30.0%' label"
    493    ]
    494   },
    495   {
    496    "cell_type": "markdown",
    497    "metadata": {},
    498    "source": [
    499     "### Evaluate Analogies"
    500    ]
    501   },
    502   {
    503    "cell_type": "code",
    504    "execution_count": 110,
    505    "metadata": {
    506     "ExecuteTime": {
    507      "end_time": "2018-12-10T04:38:54.485888Z",
    508      "start_time": "2018-12-10T04:38:54.482447Z"
    509     }
    510    },
    511    "outputs": [],
    512    "source": [
    513     "def eval_analogies(w2v, max_vocab=15000):\n",
    514     "    \"\"\"Score `w2v` on ANALOGIES_PATH: one row per section (incl. 'total') with correct/incorrect counts and their ratio.\"\"\"\n",
    515     "    accuracy = w2v.wv.accuracy(ANALOGIES_PATH,\n",
    516     "                               restrict_vocab=max_vocab,  # only the first `max_vocab` vocabulary entries take part\n",
    517     "                               case_insensitive=True)\n",
    518     "    rows = [[c['section'], len(c['correct']), len(c['incorrect'])]\n",
    519     "            for c in accuracy]\n",
    520     "    return (pd.DataFrame(rows, columns=['category', 'correct', 'incorrect'])\n",
    521     "            # average = share of scored analogies answered correctly\n",
    522     "            .assign(average=lambda x: x.correct.div(x.correct.add(x.incorrect))))"
    523    ]
    524   },
    525   {
    526    "cell_type": "code",
    527    "execution_count": 52,
    528    "metadata": {
    529     "ExecuteTime": {
    530      "end_time": "2018-12-08T23:21:32.500459Z",
    531      "start_time": "2018-12-08T23:21:32.498477Z"
    532     }
    533    },
    534    "outputs": [],
    535    "source": [
    536     "def total_accuracy(w2v):  # [correct, incorrect, average] taken from the 'total' row of eval_analogies\n",
    537     "    scores = eval_analogies(w2v)\n",
    538     "    return scores.loc[scores['category'].eq('total'), ['correct', 'incorrect', 'average']].squeeze().tolist()"
    539    ]
    540   },
    541   {
    542    "cell_type": "code",
    543    "execution_count": 42,
    544    "metadata": {
    545     "ExecuteTime": {
    546      "end_time": "2018-12-10T00:45:44.852024Z",
    547      "start_time": "2018-12-10T00:45:38.732034Z"
    548     }
    549    },
    550    "outputs": [
    551     {
    552      "data": {
    553       "text/html": [
    554        "<div>\n",
    555        "<style scoped>\n",
    556        "    .dataframe tbody tr th:only-of-type {\n",
    557        "        vertical-align: middle;\n",
    558        "    }\n",
    559        "\n",
    560        "    .dataframe tbody tr th {\n",
    561        "        vertical-align: top;\n",
    562        "    }\n",
    563        "\n",
    564        "    .dataframe thead th {\n",
    565        "        text-align: right;\n",
    566        "    }\n",
    567        "</style>\n",
    568        "<table border=\"1\" class=\"dataframe\">\n",
    569        "  <thead>\n",
    570        "    <tr style=\"text-align: right;\">\n",
    571        "      <th></th>\n",
    572        "      <th>category</th>\n",
    573        "      <th>correct</th>\n",
    574        "      <th>incorrect</th>\n",
    575        "      <th>average</th>\n",
    576        "    </tr>\n",
    577        "  </thead>\n",
    578        "  <tbody>\n",
    579        "    <tr>\n",
    580        "      <th>0</th>\n",
    581        "      <td>capital-common-countries</td>\n",
    582        "      <td>2</td>\n",
    583        "      <td>4</td>\n",
    584        "      <td>0.333333</td>\n",
    585        "    </tr>\n",
    586        "    <tr>\n",
    587        "      <th>1</th>\n",
    588        "      <td>capital-world</td>\n",
    589        "      <td>0</td>\n",
    590        "      <td>0</td>\n",
    591        "      <td>0.000000</td>\n",
    592        "    </tr>\n",
    593        "    <tr>\n",
    594        "      <th>2</th>\n",
    595        "      <td>city-in-state</td>\n",
    596        "      <td>140</td>\n",
    597        "      <td>390</td>\n",
    598        "      <td>0.264151</td>\n",
    599        "    </tr>\n",
    600        "    <tr>\n",
    601        "      <th>3</th>\n",
    602        "      <td>currency</td>\n",
    603        "      <td>2</td>\n",
    604        "      <td>26</td>\n",
    605        "      <td>0.071429</td>\n",
    606        "    </tr>\n",
    607        "    <tr>\n",
    608        "      <th>4</th>\n",
    609        "      <td>family</td>\n",
    610        "      <td>0</td>\n",
    611        "      <td>0</td>\n",
    612        "      <td>0.000000</td>\n",
    613        "    </tr>\n",
    614        "    <tr>\n",
    615        "      <th>5</th>\n",
    616        "      <td>gram1-adjective-to-adverb</td>\n",
    617        "      <td>48</td>\n",
    618        "      <td>134</td>\n",
    619        "      <td>0.263736</td>\n",
    620        "    </tr>\n",
    621        "    <tr>\n",
    622        "      <th>6</th>\n",
    623        "      <td>gram2-opposite</td>\n",
    624        "      <td>23</td>\n",
    625        "      <td>67</td>\n",
    626        "      <td>0.255556</td>\n",
    627        "    </tr>\n",
    628        "    <tr>\n",
    629        "      <th>7</th>\n",
    630        "      <td>gram3-comparative</td>\n",
    631        "      <td>240</td>\n",
    632        "      <td>222</td>\n",
    633        "      <td>0.519481</td>\n",
    634        "    </tr>\n",
    635        "    <tr>\n",
    636        "      <th>8</th>\n",
    637        "      <td>gram4-superlative</td>\n",
    638        "      <td>19</td>\n",
    639        "      <td>53</td>\n",
    640        "      <td>0.263889</td>\n",
    641        "    </tr>\n",
    642        "    <tr>\n",
    643        "      <th>9</th>\n",
    644        "      <td>gram5-present-participle</td>\n",
    645        "      <td>90</td>\n",
    646        "      <td>182</td>\n",
    647        "      <td>0.330882</td>\n",
    648        "    </tr>\n",
    649        "    <tr>\n",
    650        "      <th>10</th>\n",
    651        "      <td>gram6-nationality-adjective</td>\n",
    652        "      <td>250</td>\n",
    653        "      <td>130</td>\n",
    654        "      <td>0.657895</td>\n",
    655        "    </tr>\n",
    656        "    <tr>\n",
    657        "      <th>11</th>\n",
    658        "      <td>gram7-past-tense</td>\n",
    659        "      <td>94</td>\n",
    660        "      <td>286</td>\n",
    661        "      <td>0.247368</td>\n",
    662        "    </tr>\n",
    663        "    <tr>\n",
    664        "      <th>12</th>\n",
    665        "      <td>gram8-plural</td>\n",
    666        "      <td>87</td>\n",
    667        "      <td>69</td>\n",
    668        "      <td>0.557692</td>\n",
    669        "    </tr>\n",
    670        "    <tr>\n",
    671        "      <th>13</th>\n",
    672        "      <td>gram9-plural-verbs</td>\n",
    673        "      <td>72</td>\n",
    674        "      <td>138</td>\n",
    675        "      <td>0.342857</td>\n",
    676        "    </tr>\n",
    677        "    <tr>\n",
    678        "      <th>14</th>\n",
    679        "      <td>total</td>\n",
    680        "      <td>1067</td>\n",
    681        "      <td>1701</td>\n",
    682        "      <td>0.385477</td>\n",
    683        "    </tr>\n",
    684        "  </tbody>\n",
    685        "</table>\n",
    686        "</div>"
    687       ],
    688       "text/plain": [
    689        "                       category  correct  incorrect   average\n",
    690        "0      capital-common-countries        2          4  0.333333\n",
    691        "1                 capital-world        0          0  0.000000\n",
    692        "2                 city-in-state      140        390  0.264151\n",
    693        "3                      currency        2         26  0.071429\n",
    694        "4                        family        0          0  0.000000\n",
    695        "5     gram1-adjective-to-adverb       48        134  0.263736\n",
    696        "6                gram2-opposite       23         67  0.255556\n",
    697        "7             gram3-comparative      240        222  0.519481\n",
    698        "8             gram4-superlative       19         53  0.263889\n",
    699        "9      gram5-present-participle       90        182  0.330882\n",
    700        "10  gram6-nationality-adjective      250        130  0.657895\n",
    701        "11             gram7-past-tense       94        286  0.247368\n",
    702        "12                 gram8-plural       87         69  0.557692\n",
    703        "13           gram9-plural-verbs       72        138  0.342857\n",
    704        "14                        total     1067       1701  0.385477"
    705       ]
    706      },
    707      "execution_count": 42,
    708      "metadata": {},
    709      "output_type": "execute_result"
    710     }
    711    ],
    712    "source": [
    713     "accuracy = eval_analogies(model)  # per-section analogy scores; the last row is the 'total'\n",
    714     "accuracy"
    715    ]
    716   },
    717   {
    718    "cell_type": "markdown",
    719    "metadata": {},
    720    "source": [
    721     "### Validate Vector Arithmetic"
    722    ]
    723   },
    724   {
    725    "cell_type": "code",
    726    "execution_count": 105,
    727    "metadata": {
    728     "ExecuteTime": {
    729      "end_time": "2018-12-10T01:00:35.772447Z",
    730      "start_time": "2018-12-10T01:00:35.756869Z"
    731     }
    732    },
    733    "outputs": [
    734     {
    735      "data": {
    736       "text/html": [
    737        "<div>\n",
    738        "<style scoped>\n",
    739        "    .dataframe tbody tr th:only-of-type {\n",
    740        "        vertical-align: middle;\n",
    741        "    }\n",
    742        "\n",
    743        "    .dataframe tbody tr th {\n",
    744        "        vertical-align: top;\n",
    745        "    }\n",
    746        "\n",
    747        "    .dataframe thead th {\n",
    748        "        text-align: right;\n",
    749        "    }\n",
    750        "</style>\n",
    751        "<table border=\"1\" class=\"dataframe\">\n",
    752        "  <thead>\n",
    753        "    <tr style=\"text-align: right;\">\n",
    754        "      <th></th>\n",
    755        "      <th>0</th>\n",
    756        "      <th>1</th>\n",
    757        "      <th>2</th>\n",
    758        "      <th>3</th>\n",
    759        "    </tr>\n",
    760        "  </thead>\n",
    761        "  <tbody>\n",
    762        "    <tr>\n",
    763        "      <th>0</th>\n",
    764        "      <td>:</td>\n",
    765        "      <td>capital-common-countries</td>\n",
    766        "      <td>NaN</td>\n",
    767        "      <td>NaN</td>\n",
    768        "    </tr>\n",
    769        "    <tr>\n",
    770        "      <th>1</th>\n",
    771        "      <td>athens</td>\n",
    772        "      <td>greece</td>\n",
    773        "      <td>baghdad</td>\n",
    774        "      <td>iraq</td>\n",
    775        "    </tr>\n",
    776        "    <tr>\n",
    777        "      <th>2</th>\n",
    778        "      <td>athens</td>\n",
    779        "      <td>greece</td>\n",
    780        "      <td>bangkok</td>\n",
    781        "      <td>thailand</td>\n",
    782        "    </tr>\n",
    783        "    <tr>\n",
    784        "      <th>3</th>\n",
    785        "      <td>athens</td>\n",
    786        "      <td>greece</td>\n",
    787        "      <td>beijing</td>\n",
    788        "      <td>china</td>\n",
    789        "    </tr>\n",
    790        "    <tr>\n",
    791        "      <th>4</th>\n",
    792        "      <td>athens</td>\n",
    793        "      <td>greece</td>\n",
    794        "      <td>berlin</td>\n",
    795        "      <td>germany</td>\n",
    796        "    </tr>\n",
    797        "  </tbody>\n",
    798        "</table>\n",
    799        "</div>"
    800       ],
    801       "text/plain": [
    802        "        0                         1        2         3\n",
    803        "0       :  capital-common-countries      NaN       NaN\n",
    804        "1  athens                    greece  baghdad      iraq\n",
    805        "2  athens                    greece  bangkok  thailand\n",
    806        "3  athens                    greece  beijing     china\n",
    807        "4  athens                    greece   berlin   germany"
    808       ]
    809      },
    810      "execution_count": 105,
    811      "metadata": {},
    812      "output_type": "execute_result"
    813     }
    814    ],
    815    "source": [
    816     "pd.read_csv(ANALOGIES_PATH, header=None, sep=' ').head()  # raw file: ': <section>' header lines, then one 4-word analogy per line"
    817    ]
    818   },
    819   {
    820    "cell_type": "code",
    821    "execution_count": 112,
    822    "metadata": {
    823     "ExecuteTime": {
    824      "end_time": "2018-12-10T08:11:19.340922Z",
    825      "start_time": "2018-12-10T08:11:19.334225Z"
    826     }
    827    },
    828    "outputs": [
    829     {
    830      "name": "stdout",
    831      "output_type": "stream",
    832      "text": [
    833       "                  term  similarity\n",
    834       "0              android    0.600454\n",
    835       "1           smartphone    0.581685\n",
    836       "2                  app    0.559129\n",
    837       "3          smartphones    0.533848\n",
    838       "4  smartphones_tablets    0.526129\n",
    839       "5             handsets    0.514813\n",
    840       "6         smart_phones    0.512868\n",
    841       "7                apple    0.507795\n",
    842       "8                 apps    0.505517\n",
    843       "9              handset    0.491526\n"
    844      ]
    845     }
    846    ],
    847    "source": [
    848     "sims = model.wv.most_similar(positive=['iphone'],\n",
    849     "                             restrict_vocab=15000)  # search only the first 15,000 (most frequent) vocabulary entries\n",
    850     "pd.DataFrame(sims, columns=['term', 'similarity'])  # nearest neighbours of 'iphone' as a rich table"
    851    ]
    852   },
    853   {
    854    "cell_type": "code",
    855    "execution_count": 113,
    856    "metadata": {
    857     "ExecuteTime": {
    858      "end_time": "2018-12-10T08:14:19.395370Z",
    859      "start_time": "2018-12-10T08:14:19.381754Z"
    860     }
    861    },
    862    "outputs": [
    863     {
    864      "name": "stdout",
    865      "output_type": "stream",
    866      "text": [
    867       "             term  similarity\n",
    868       "0  united_kingdom    0.606630\n",
    869       "1         germany    0.585644\n",
    870       "2     netherlands    0.578868\n",
    871       "3           italy    0.547168\n",
    872       "4           india    0.545213\n",
    873       "5           spain    0.539029\n",
    874       "6       singapore    0.535106\n",
    875       "7       australia    0.525464\n",
    876       "8         belgium    0.523677\n",
    877       "9          sweden    0.510462\n"
    878      ]
    879     }
    880    ],
    881    "source": [
    882     "analogy = model.wv.most_similar(positive=['france', 'london'],  # france - paris + london, i.e. paris:france :: london:?\n",
    883     "                                negative=['paris'],\n",
    884     "                                restrict_vocab=15000)\n",
    885     "pd.DataFrame(analogy, columns=['term', 'similarity'])"
    886    ]
    887   },
    888   {
    889    "cell_type": "markdown",
    890    "metadata": {},
    891    "source": [
    892     "### Check similarity for random words"
    893    ]
    894   },
    895   {
    896    "cell_type": "code",
    897    "execution_count": 41,
    898    "metadata": {
    899     "ExecuteTime": {
    900      "end_time": "2018-12-08T23:10:41.702789Z",
    901      "start_time": "2018-12-08T23:10:41.640280Z"
    902     }
    903    },
    904    "outputs": [
    905     {
    906      "name": "stderr",
    907      "output_type": "stream",
    908      "text": [
    909       "/home/stefan/.pyenv/versions/at-3.6/lib/python3.6/site-packages/gensim/matutils.py:737: FutureWarning: Conversion of the second argument of issubdtype from `int` to `np.signedinteger` is deprecated. In future, it will be treated as `np.int64 == np.dtype(int).type`.\n",
    910       "  if np.issubdtype(vec.dtype, np.int):\n"
    911      ]
    912     },
    913     {
    914      "data": {
    915       "text/html": [
    916        "<div>\n",
    917        "<style scoped>\n",
    918        "    .dataframe tbody tr th:only-of-type {\n",
    919        "        vertical-align: middle;\n",
    920        "    }\n",
    921        "\n",
    922        "    .dataframe tbody tr th {\n",
    923        "        vertical-align: top;\n",
    924        "    }\n",
    925        "\n",
    926        "    .dataframe thead th {\n",
    927        "        text-align: right;\n",
    928        "    }\n",
    929        "</style>\n",
    930        "<table border=\"1\" class=\"dataframe\">\n",
    931        "  <thead>\n",
    932        "    <tr style=\"text-align: right;\">\n",
    933        "      <th></th>\n",
    934        "      <th>staff</th>\n",
    935        "      <th>enables</th>\n",
    936        "      <th>times</th>\n",
    937        "      <th>fees</th>\n",
    938        "      <th>sources</th>\n",
    939        "    </tr>\n",
    940        "  </thead>\n",
    941        "  <tbody>\n",
    942        "    <tr>\n",
    943        "      <th>0</th>\n",
    944        "      <td>personnel</td>\n",
    945        "      <td>allows</td>\n",
    946        "      <td>twice</td>\n",
    947        "      <td>fee</td>\n",
    948        "      <td>source</td>\n",
    949        "    </tr>\n",
    950        "    <tr>\n",
    951        "      <th>1</th>\n",
    952        "      <td>team</td>\n",
    953        "      <td>enabling</td>\n",
    954        "      <td>standpoint_advantageous</td>\n",
    955        "      <td>professional_fees</td>\n",
    956        "      <td>primary_source</td>\n",
    957        "    </tr>\n",
    958        "    <tr>\n",
    959        "      <th>2</th>\n",
    960        "      <td>teams</td>\n",
    961        "      <td>helps</td>\n",
    962        "      <td>vimovo_orange_book</td>\n",
    963        "      <td>checkcard</td>\n",
    964        "      <td>sourced</td>\n",
    965        "    </tr>\n",
    966        "    <tr>\n",
    967        "      <th>3</th>\n",
    968        "      <td>professionals</td>\n",
    969        "      <td>enable</td>\n",
    970        "      <td>millisecond</td>\n",
    971        "      <td>commissions</td>\n",
    972        "      <td>readily_available</td>\n",
    973        "    </tr>\n",
    974        "    <tr>\n",
    975        "      <th>4</th>\n",
    976        "      <td>staffed</td>\n",
    977        "      <td>allowing</td>\n",
    978        "      <td>saturdays</td>\n",
    979        "      <td>atm_debit_card</td>\n",
    980        "      <td>internally_generated</td>\n",
    981        "    </tr>\n",
    982        "    <tr>\n",
    983        "      <th>5</th>\n",
    984        "      <td>hiring</td>\n",
    985        "      <td>enabled</td>\n",
    986        "      <td>assets_liabilities_react_differently</td>\n",
    987        "      <td>gds_reservation_booking</td>\n",
    988        "      <td>generated</td>\n",
    989        "    </tr>\n",
    990        "    <tr>\n",
    991        "      <th>6</th>\n",
    992        "      <td>consultants</td>\n",
    993        "      <td>allow</td>\n",
    994        "      <td>twice_weekly</td>\n",
    995        "      <td>interchange_fees_swipe</td>\n",
    996        "      <td>biological_contaminants_pollen</td>\n",
    997        "    </tr>\n",
    998        "    <tr>\n",
    999        "      <th>7</th>\n",
   1000        "      <td>hired</td>\n",
   1001        "      <td>leverages</td>\n",
   1002        "      <td>day</td>\n",
   1003        "      <td>noticing</td>\n",
   1004        "      <td>repair_reconstruct_damaged</td>\n",
   1005        "    </tr>\n",
   1006        "    <tr>\n",
   1007        "      <th>8</th>\n",
   1008        "      <td>engineers</td>\n",
   1009        "      <td>lets</td>\n",
   1010        "      <td>weekdays</td>\n",
   1011        "      <td>nonsufficient</td>\n",
   1012        "      <td>alternative</td>\n",
   1013        "    </tr>\n",
   1014        "    <tr>\n",
   1015        "      <th>9</th>\n",
   1016        "      <td>salespeople</td>\n",
   1017        "      <td>easy</td>\n",
   1018        "      <td>uvb</td>\n",
   1019        "      <td>bno_usci_cper_usag</td>\n",
   1020        "      <td>znse</td>\n",
   1021        "    </tr>\n",
   1022        "  </tbody>\n",
   1023        "</table>\n",
   1024        "</div>"
   1025       ],
   1026       "text/plain": [
   1027        "           staff    enables                                 times                     fees                         sources\n",
   1028        "0      personnel     allows                                 twice                      fee                          source\n",
   1029        "1           team   enabling               standpoint_advantageous        professional_fees                  primary_source\n",
   1030        "2          teams      helps                    vimovo_orange_book                checkcard                         sourced\n",
   1031        "3  professionals     enable                           millisecond              commissions               readily_available\n",
   1032        "4        staffed   allowing                             saturdays           atm_debit_card            internally_generated\n",
   1033        "5         hiring    enabled  assets_liabilities_react_differently  gds_reservation_booking                       generated\n",
   1034        "6    consultants      allow                          twice_weekly   interchange_fees_swipe  biological_contaminants_pollen\n",
   1035        "7          hired  leverages                                   day                 noticing      repair_reconstruct_damaged\n",
   1036        "8      engineers       lets                              weekdays            nonsufficient                     alternative\n",
   1037        "9    salespeople       easy                                   uvb       bno_usci_cper_usag                            znse"
   1038       ]
   1039      },
   1040      "execution_count": 41,
   1041      "metadata": {},
   1042      "output_type": "execute_result"
   1043     }
   1044    ],
   1045    "source": [
   1046     "VALID_SET = 5  # Number of words to sample and show nearest neighbors for\n",
   1047     "VALID_WINDOW = 100  # Sample from the first VALID_WINDOW vocab rows (assumed to be the most frequent words)\n",
   1048     "valid_examples = np.random.choice(VALID_WINDOW, size=VALID_SET, replace=False)\n",
   1049     "similars = pd.DataFrame()\n",
   1050     "\n",
   1051     "for idx in sorted(valid_examples):  # not `id`: a notebook loop variable would shadow the builtin for the whole kernel\n",
   1052     "    word = vocab.loc[idx, 'token']\n",
   1053     "    similars[word] = [s[0] for s in model.wv.most_similar(word)]  # one column of neighbor terms per sampled word\n",
   1054     "similars"
   1055    ]
   1056   },
   1057   {
   1058    "cell_type": "markdown",
   1059    "metadata": {},
   1060    "source": [
   1061     "## Continue Training"
   1062    ]
   1063   },
   1064   {
   1065    "cell_type": "code",
   1066    "execution_count": null,
   1067    "metadata": {},
   1068    "outputs": [],
   1069    "source": [
   1070     "# Per-category `average` score from eval_analogies before continued training;\n",
   1071     "# the training loop below joins further columns onto `accuracies`.\n",
   1072     "baseline_scores = eval_analogies(model).set_index('category')\n",
   1073     "accuracies = baseline_scores['average'].to_frame('baseline')"
   1074    ]
   1075   },
   1076   {
   1077    "cell_type": "code",
   1078    "execution_count": 76,
   1079    "metadata": {
   1080     "ExecuteTime": {
   1081      "end_time": "2018-12-08T21:26:29.866811Z",
   1082      "start_time": "2018-12-08T20:10:12.950824Z"
   1083     }
   1084    },
   1085    "outputs": [
   1086     {
   1087      "name": "stderr",
   1088      "output_type": "stream",
   1089      "text": [
   1090       "/home/stefan/.pyenv/versions/at-3.6/lib/python3.6/site-packages/ipykernel_launcher.py:5: DeprecationWarning: Call to deprecated `accuracy` (Method will be removed in 4.0.0, use self.evaluate_word_analogies() instead).\n",
   1091       "  \"\"\"\n",
   1092       "/home/stefan/.pyenv/versions/at-3.6/lib/python3.6/site-packages/gensim/matutils.py:737: FutureWarning: Conversion of the second argument of issubdtype from `int` to `np.signedinteger` is deprecated. In future, it will be treated as `np.int64 == np.dtype(int).type`.\n",
   1093       "  if np.issubdtype(vec.dtype, np.int):\n"
   1094      ]
   1095     },
   1096     {
   1097      "name": "stdout",
   1098      "output_type": "stream",
   1099      "text": [
   1100       "1 | Duration: 464.0 | Accuracy: 28.93% \n",
   1101       "2 | Duration: 457.8 | Accuracy: 28.83% \n",
   1102       "3 | Duration: 459.2 | Accuracy: 28.97% \n",
   1103       "4 | Duration: 456.9 | Accuracy: 28.60% \n",
   1104       "5 | Duration: 457.4 | Accuracy: 29.69% \n",
   1105       "6 | Duration: 456.8 | Accuracy: 29.40% \n",
   1106       "7 | Duration: 457.7 | Accuracy: 29.91% \n",
   1107       "8 | Duration: 456.4 | Accuracy: 29.61% \n",
   1108       "9 | Duration: 456.1 | Accuracy: 29.37% \n",
   1109       "10 | Duration: 454.6 | Accuracy: 29.17% \n"
   1110      ]
   1111     }
   1112    ],
   1113    "source": [
   1114     "for i in range(1, 11):  # 10 additional single-epoch passes; i labels the epoch\n",
   1115     "    start = time()\n",
   1116     "    model.train(sentences, epochs=1, total_examples=model.corpus_count)\n",
   1117     "    accuracy = eval_analogies(model).set_index('category').average\n",
   1118     "    accuracies = accuracies.join(accuracy.to_frame(f'{i}'))  # unique column per epoch (was undefined `n`)\n",
   1119     "    print(f'{i} | Duration: {format_time(time() - start)} | Accuracy: {accuracy.total:.2%}')\n",
   1120     "    model.save(f'word2vec/models/word2vec_{i}.model')  # checkpoint after every epoch"
   1121    ]
   1122   },
   1123   {
   1124    "cell_type": "code",
   1125    "execution_count": null,
   1126    "metadata": {},
   1127    "outputs": [],
   1128    "source": [
   1129     "model.wv.save(fname_or_handle='word_vectors_final.bin')  # gensim native format (reload via KeyedVectors.load), not word2vec C binary"
   1130    ]
   1131   }
   1132  ],
   1133  "metadata": {
   1134   "kernelspec": {
   1135    "display_name": "Python 3",
   1136    "language": "python",
   1137    "name": "python3"
   1138   },
   1139   "language_info": {
   1140    "codemirror_mode": {
   1141     "name": "ipython",
   1142     "version": 3
   1143    },
   1144    "file_extension": ".py",
   1145    "mimetype": "text/x-python",
   1146    "name": "python",
   1147    "nbconvert_exporter": "python",
   1148    "pygments_lexer": "ipython3",
   1149    "version": "3.6.8"
   1150   },
   1151   "toc": {
   1152    "base_numbering": 1,
   1153    "nav_menu": {},
   1154    "number_sections": true,
   1155    "sideBar": true,
   1156    "skip_h1_title": false,
   1157    "title_cell": "Table of Contents",
   1158    "title_sidebar": "Contents",
   1159    "toc_cell": false,
   1160    "toc_position": {},
   1161    "toc_section_display": true,
   1162    "toc_window_display": true
   1163   }
   1164  },
   1165  "nbformat": 4,
   1166  "nbformat_minor": 2
   1167 }