ml-finance-python
python scripts for finance machine learning
git clone https://9o.is/git/ml-finance-python.git
word2vec.ipynb
(33952B)
1 {
2 "cells": [
3 {
4 "cell_type": "markdown",
5 "metadata": {},
6 "source": [
7 "## Imports & Settings"
8 ]
9 },
10 {
11 "cell_type": "code",
12 "execution_count": 1,
13 "metadata": {
14 "ExecuteTime": {
15 "end_time": "2018-12-08T23:57:04.619453Z",
16 "start_time": "2018-12-08T23:57:04.488154Z"
17 }
18 },
19 "outputs": [],
20 "source": [
21 "from pathlib import Path\n",
22 "from time import time\n",
23 "import warnings\n",
24 "from collections import Counter\n",
25 "import logging\n",
26 "from ast import literal_eval as make_tuple\n",
27 "import numpy as np\n",
28 "import pandas as pd\n",
29 "\n",
30 "from gensim.models import Word2Vec, KeyedVectors\n",
31 "from gensim.models.word2vec import LineSentence\n",
32 "import word2vec"
33 ]
34 },
35 {
36 "cell_type": "code",
37 "execution_count": 2,
38 "metadata": {
39 "ExecuteTime": {
40 "end_time": "2018-12-08T23:57:05.049257Z",
41 "start_time": "2018-12-08T23:57:05.040701Z"
42 }
43 },
44 "outputs": [],
45 "source": [
46 "np.random.seed(42)\n",
47 "warnings.filterwarnings(action='ignore')\n",
48 "pd.set_option('display.expand_frame_repr', False)"
49 ]
50 },
51 {
52 "cell_type": "code",
53 "execution_count": 3,
54 "metadata": {
55 "ExecuteTime": {
56 "end_time": "2018-12-08T23:57:05.244408Z",
57 "start_time": "2018-12-08T23:57:05.240318Z"
58 }
59 },
60 "outputs": [],
61 "source": [
62 "def format_time(t):  # elapsed seconds -> 'HH:MM:SS' string\n",
63 "    minutes, seconds = divmod(t, 60)\n",
64 "    hours, minutes = divmod(minutes, 60)\n",
65 "    return f'{hours:02.0f}:{minutes:02.0f}:{seconds:02.0f}'"
66 ]
67 },
68 {
69 "cell_type": "markdown",
70 "metadata": {},
71 "source": [
72 "### Logging Setup"
73 ]
74 },
75 {
76 "cell_type": "code",
77 "execution_count": 4,
78 "metadata": {
79 "ExecuteTime": {
80 "end_time": "2018-12-08T23:57:06.423935Z",
81 "start_time": "2018-12-08T23:57:06.421773Z"
82 }
83 },
84 "outputs": [],
85 "source": [
86 "logging.basicConfig(\n",
87 "    level=logging.DEBUG,\n",
88 "    filename='logs/word2vec.log',  # the logs/ directory must already exist\n",
89 "    datefmt='%H:%M:%S',\n",
90 "    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')"
91 ]
92 },
93 {
94 "cell_type": "markdown",
95 "metadata": {},
96 "source": [
97 "## word2vec"
98 ]
99 },
100 {
101 "cell_type": "code",
102 "execution_count": 6,
103 "metadata": {
104 "ExecuteTime": {
105 "end_time": "2018-12-08T23:57:34.969991Z",
106 "start_time": "2018-12-08T23:57:34.967461Z"
107 }
108 },
109 "outputs": [],
110 "source": [
111 "ANALOGIES_PATH = analogies_path = Path.cwd().parent / 'data' / 'analogies' / 'analogies-en.txt'  # eval_analogies and the CSV preview read ANALOGIES_PATH"
112 ]
113 },
114 {
115 "cell_type": "markdown",
116 "metadata": {},
117 "source": [
118 "### Set up Sentence Generator"
119 ]
120 },
121 {
122 "cell_type": "code",
123 "execution_count": 8,
124 "metadata": {
125 "ExecuteTime": {
126 "end_time": "2018-12-08T23:57:57.298178Z",
127 "start_time": "2018-12-08T23:57:57.289388Z"
128 }
129 },
130 "outputs": [],
131 "source": [
132 "NGRAMS = 2"
133 ]
134 },
135 {
136 "cell_type": "markdown",
137 "metadata": {},
138 "source": [
139 "To facilitate memory-efficient text ingestion, the `LineSentence` class streams the sentences in the provided text file (one sentence per line, tokens separated by whitespace) as a restartable iterable, so the corpus never has to be held in memory:"
140 ]
141 },
142 {
143 "cell_type": "code",
144 "execution_count": 9,
145 "metadata": {
146 "ExecuteTime": {
147 "end_time": "2018-12-08T23:57:58.496781Z",
148 "start_time": "2018-12-08T23:57:58.494515Z"
149 }
150 },
151 "outputs": [],
152 "source": [
153 "sentence_path = Path('data') / 'ngrams' / f'ngrams_{NGRAMS}.txt'\n",
154 "sentences = LineSentence(sentence_path)  # re-iterable: Word2Vec makes several passes over it"
155 ]
156 },
157 {
158 "cell_type": "markdown",
159 "metadata": {},
160 "source": [
161 "### Train word2vec Model"
162 ]
163 },
164 {
165 "cell_type": "markdown",
166 "metadata": {},
167 "source": [
168 "The `Word2Vec` class in the [gensim.models.word2vec](https://radimrehurek.com/gensim/models/word2vec.html) module implements both the skip-gram and the continuous bag-of-words (CBOW) architectures. The notebook [word2vec](../03_word2vec.ipynb) contains additional implementation details."
169 ]
170 },
171 {
172 "cell_type": "code",
173 "execution_count": 10,
174 "metadata": {
175 "ExecuteTime": {
176 "end_time": "2018-12-09T00:09:31.218671Z",
177 "start_time": "2018-12-08T23:58:43.716464Z"
178 }
179 },
180 "outputs": [
181 {
182 "name": "stdout",
183 "output_type": "stream",
184 "text": [
185 "Duration: 00:10:47\n"
186 ]
187 }
188 ],
189 "source": [
190 "start = time()\n",
191 "model = Word2Vec(sentences,\n",
192 "                 sg=1,  # 1 for skip-gram; otherwise CBOW\n",
193 "                 hs=0,  # hierarchical softmax if 1, negative sampling if 0\n",
194 "                 size=300,  # vector dimensionality\n",
195 "                 window=3,  # max distance between current and predicted word\n",
196 "                 min_count=50,  # ignore words with lower total frequency\n",
197 "                 negative=10,  # number of noise words drawn for negative sampling\n",
198 "                 workers=8,  # number of worker threads\n",
199 "                 iter=1,  # number of epochs (passes over the corpus)\n",
200 "                 alpha=0.025,  # initial learning rate\n",
201 "                 min_alpha=0.0001  # final learning rate (linearly decayed from alpha)\n",
202 "                 )\n",
203 "print('Duration:', format_time(time() - start))"
204 ]
205 },
206 {
207 "cell_type": "markdown",
208 "metadata": {},
209 "source": [
210 "### Persist model & vectors"
211 ]
212 },
213 {
214 "cell_type": "code",
215 "execution_count": 11,
216 "metadata": {
217 "ExecuteTime": {
218 "end_time": "2018-12-09T00:10:01.380925Z",
219 "start_time": "2018-12-09T00:10:01.143768Z"
220 }
221 },
222 "outputs": [],
223 "source": [
224 "model.wv.save('models/baseline/word_vectors.bin')  # word vectors only (KeyedVectors)\n",
225 "model.save('models/baseline/word2vec.model')  # full model, needed to resume training"
226 ]
227 },
228 {
229 "cell_type": "markdown",
230 "metadata": {},
231 "source": [
232 "### Load model and vectors"
233 ]
234 },
235 {
236 "cell_type": "code",
237 "execution_count": 40,
238 "metadata": {
239 "ExecuteTime": {
240 "end_time": "2018-12-10T00:45:27.525905Z",
241 "model = Word2Vec.load('models/baseline/word2vec.model')  # the model persisted above (vectors below load from the same dir)"
242 }
243 },
244 "outputs": [],
245 "source": [
246 "model = Word2Vec.load('models/archive/word2vec.model')"
247 ]
248 },
249 {
250 "cell_type": "code",
251 "execution_count": 8,
252 "metadata": {
253 "ExecuteTime": {
254 "end_time": "2018-12-08T22:53:13.020767Z",
255 "start_time": "2018-12-08T22:53:12.843245Z"
256 }
257 },
258 "outputs": [],
259 "source": [
260 "wv = KeyedVectors.load('models/baseline/word_vectors.bin')"
261 ]
262 },
263 {
264 "cell_type": "markdown",
265 "metadata": {},
266 "source": [
267 "### Get vocabulary"
268 ]
269 },
270 {
271 "cell_type": "code",
272 "execution_count": 12,
273 "metadata": {
274 "ExecuteTime": {
275 "end_time": "2018-12-09T00:11:04.596716Z",
276 "start_time": "2018-12-09T00:11:04.539228Z"
277 }
278 },
279 "outputs": [],
280 "source": [
281 "vocab = [\n",
282 "    [token, entry.index, entry.count]  # token, vector row, corpus frequency\n",
283 "    for token, entry in model.wv.vocab.items()\n",
284 "]"
285 ]
286 },
287 {
288 "cell_type": "code",
289 "execution_count": 13,
290 "metadata": {
291 "ExecuteTime": {
292 "end_time": "2018-12-09T00:11:04.905084Z",
293 "start_time": "2018-12-09T00:11:04.868230Z"
294 }
295 },
296 "outputs": [],
297 "source": [
298 "vocab = pd.DataFrame(vocab, columns=['token', 'idx', 'count'])\n",
299 "vocab = vocab.sort_values(by='count',\n",
300 "                          ascending=False)"
301 ]
302 },
303 {
304 "cell_type": "code",
305 "execution_count": 14,
306 "metadata": {
307 "ExecuteTime": {
308 "end_time": "2018-12-09T00:11:06.691657Z",
309 "start_time": "2018-12-09T00:11:06.679881Z"
310 }
311 },
312 "outputs": [
313 {
314 "name": "stdout",
315 "output_type": "stream",
316 "text": [
317 "<class 'pandas.core.frame.DataFrame'>\n",
318 "Int64Index: 50491 entries, 104 to 46372\n",
319 "Data columns (total 3 columns):\n",
320 "token 50491 non-null object\n",
321 "idx 50491 non-null int64\n",
322 "count 50491 non-null int64\n",
323 "dtypes: int64(2), object(1)\n",
324 "memory usage: 1.5+ MB\n"
325 ]
326 }
327 ],
328 "source": [
329 "vocab.info()"
330 ]
331 },
332 {
333 "cell_type": "code",
334 "execution_count": 15,
335 "metadata": {
336 "ExecuteTime": {
337 "end_time": "2018-12-09T00:11:07.220241Z",
338 "start_time": "2018-12-09T00:11:07.202935Z"
339 }
340 },
341 "outputs": [
342 {
343 "data": {
344 "text/html": [
345 "<div>\n",
346 "<style scoped>\n",
347 " .dataframe tbody tr th:only-of-type {\n",
348 " vertical-align: middle;\n",
349 " }\n",
350 "\n",
351 " .dataframe tbody tr th {\n",
352 " vertical-align: top;\n",
353 " }\n",
354 "\n",
355 " .dataframe thead th {\n",
356 " text-align: right;\n",
357 " }\n",
358 "</style>\n",
359 "<table border=\"1\" class=\"dataframe\">\n",
360 " <thead>\n",
361 " <tr style=\"text-align: right;\">\n",
362 " <th></th>\n",
363 " <th>token</th>\n",
364 " <th>idx</th>\n",
365 " <th>count</th>\n",
366 " </tr>\n",
367 " </thead>\n",
368 " <tbody>\n",
369 " <tr>\n",
370 " <th>104</th>\n",
371 " <td>million</td>\n",
372 " <td>0</td>\n",
373 " <td>2340243</td>\n",
374 " </tr>\n",
375 " <tr>\n",
376 " <th>0</th>\n",
377 " <td>business</td>\n",
378 " <td>1</td>\n",
379 " <td>1700662</td>\n",
380 " </tr>\n",
381 " <tr>\n",
382 " <th>66</th>\n",
383 " <td>december</td>\n",
384 " <td>2</td>\n",
385 " <td>1513533</td>\n",
386 " </tr>\n",
387 " <tr>\n",
388 " <th>627</th>\n",
389 " <td>company</td>\n",
390 " <td>3</td>\n",
391 " <td>1490752</td>\n",
392 " </tr>\n",
393 " <tr>\n",
394 " <th>477</th>\n",
395 " <td>products</td>\n",
396 " <td>4</td>\n",
397 " <td>1368711</td>\n",
398 " </tr>\n",
399 " <tr>\n",
400 " <th>1071</th>\n",
401 " <td>net</td>\n",
402 " <td>5</td>\n",
403 " <td>1253343</td>\n",
404 " </tr>\n",
405 " <tr>\n",
406 " <th>145</th>\n",
407 " <td>market</td>\n",
408 " <td>6</td>\n",
409 " <td>1149048</td>\n",
410 " </tr>\n",
411 " <tr>\n",
412 " <th>380</th>\n",
413 " <td>including</td>\n",
414 " <td>7</td>\n",
415 " <td>1110482</td>\n",
416 " </tr>\n",
417 " <tr>\n",
418 " <th>381</th>\n",
419 " <td>sales</td>\n",
420 " <td>8</td>\n",
421 " <td>1098312</td>\n",
422 " </tr>\n",
423 " <tr>\n",
424 " <th>60</th>\n",
425 " <td>costs</td>\n",
426 " <td>9</td>\n",
427 " <td>1020383</td>\n",
428 " </tr>\n",
429 " </tbody>\n",
430 "</table>\n",
431 "</div>"
432 ],
433 "text/plain": [
434 " token idx count\n",
435 "104 million 0 2340243\n",
436 "0 business 1 1700662\n",
437 "66 december 2 1513533\n",
438 "627 company 3 1490752\n",
439 "477 products 4 1368711\n",
440 "1071 net 5 1253343\n",
441 "145 market 6 1149048\n",
442 "380 including 7 1110482\n",
443 "381 sales 8 1098312\n",
444 "60 costs 9 1020383"
445 ]
446 },
447 "execution_count": 15,
448 "metadata": {},
449 "output_type": "execute_result"
450 }
451 ],
452 "source": [
453 "vocab.head(10)"
454 ]
455 },
456 {
457 "cell_type": "code",
458 "execution_count": 16,
459 "metadata": {
460 "ExecuteTime": {
461 "end_time": "2018-12-09T00:11:14.683574Z",
462 "start_time": "2018-12-09T00:11:14.648032Z"
463 }
464 },
465 "outputs": [
466 {
467 "data": {
468 "text/plain": [
469 "count 50491\n",
470 "mean 5110\n",
471 "std 37525\n",
472 "min 50\n",
473 "10% 61\n",
474 "20% 78\n",
475 "30.0% 102\n",
476 "40% 137\n",
477 "50% 195\n",
478 "60% 300\n",
479 "70% 522\n",
480 "80% 1164\n",
481 "90% 4578\n",
482 "max 2340243\n",
483 "Name: count, dtype: int64"
484 ]
485 },
486 "execution_count": 16,
487 "metadata": {},
488 "output_type": "execute_result"
489 }
490 ],
491 "source": [
492 "vocab['count'].describe(percentiles=np.arange(1, 10) / 10).astype(int)  # integer steps avoid float-drift labels like '30.0%'"
493 ]
494 },
495 {
496 "cell_type": "markdown",
497 "metadata": {},
498 "source": [
499 "### Evaluate Analogies"
500 ]
501 },
502 {
503 "cell_type": "code",
504 "execution_count": 110,
505 "metadata": {
506 "ExecuteTime": {
507 "end_time": "2018-12-10T04:38:54.485888Z",
508 "start_time": "2018-12-10T04:38:54.482447Z"
509 }
510 },
511 "outputs": [],
512 "source": [
513 "def eval_analogies(w2v, max_vocab=15000):  # per-section analogy accuracy as a DataFrame\n",
514 "    accuracy = w2v.wv.accuracy(ANALOGIES_PATH,\n",
515 "                               restrict_vocab=max_vocab,  # only test the max_vocab most frequent words\n",
516 "                               case_insensitive=True)\n",
517 "    return (pd.DataFrame([[c['section'],\n",
518 "                           len(c['correct']),\n",
519 "                           len(c['incorrect'])] for c in accuracy],\n",
520 "                         columns=['category', 'correct', 'incorrect'])\n",
521 "            .assign(average=lambda x:\n",
522 "                    x.correct.div(x.correct.add(x.incorrect))))"
523 ]
524 },
525 {
526 "cell_type": "code",
527 "execution_count": 52,
528 "metadata": {
529 "ExecuteTime": {
530 "end_time": "2018-12-08T23:21:32.500459Z",
531 "start_time": "2018-12-08T23:21:32.498477Z"
532 }
533 },
534 "outputs": [],
535 "source": [
536 "def total_accuracy(w2v):  # [correct, incorrect, average] from the 'total' row\n",
537 "    results = eval_analogies(w2v)\n",
538 "    return results.loc[results['category'].eq('total'), ['correct', 'incorrect', 'average']].squeeze().tolist()"
539 ]
540 },
541 {
542 "cell_type": "code",
543 "execution_count": 42,
544 "metadata": {
545 "ExecuteTime": {
546 "end_time": "2018-12-10T00:45:44.852024Z",
547 "start_time": "2018-12-10T00:45:38.732034Z"
548 }
549 },
550 "outputs": [
551 {
552 "data": {
553 "text/html": [
554 "<div>\n",
555 "<style scoped>\n",
556 " .dataframe tbody tr th:only-of-type {\n",
557 " vertical-align: middle;\n",
558 " }\n",
559 "\n",
560 " .dataframe tbody tr th {\n",
561 " vertical-align: top;\n",
562 " }\n",
563 "\n",
564 " .dataframe thead th {\n",
565 " text-align: right;\n",
566 " }\n",
567 "</style>\n",
568 "<table border=\"1\" class=\"dataframe\">\n",
569 " <thead>\n",
570 " <tr style=\"text-align: right;\">\n",
571 " <th></th>\n",
572 " <th>category</th>\n",
573 " <th>correct</th>\n",
574 " <th>incorrect</th>\n",
575 " <th>average</th>\n",
576 " </tr>\n",
577 " </thead>\n",
578 " <tbody>\n",
579 " <tr>\n",
580 " <th>0</th>\n",
581 " <td>capital-common-countries</td>\n",
582 " <td>2</td>\n",
583 " <td>4</td>\n",
584 " <td>0.333333</td>\n",
585 " </tr>\n",
586 " <tr>\n",
587 " <th>1</th>\n",
588 " <td>capital-world</td>\n",
589 " <td>0</td>\n",
590 " <td>0</td>\n",
591 " <td>0.000000</td>\n",
592 " </tr>\n",
593 " <tr>\n",
594 " <th>2</th>\n",
595 " <td>city-in-state</td>\n",
596 " <td>140</td>\n",
597 " <td>390</td>\n",
598 " <td>0.264151</td>\n",
599 " </tr>\n",
600 " <tr>\n",
601 " <th>3</th>\n",
602 " <td>currency</td>\n",
603 " <td>2</td>\n",
604 " <td>26</td>\n",
605 " <td>0.071429</td>\n",
606 " </tr>\n",
607 " <tr>\n",
608 " <th>4</th>\n",
609 " <td>family</td>\n",
610 " <td>0</td>\n",
611 " <td>0</td>\n",
612 " <td>0.000000</td>\n",
613 " </tr>\n",
614 " <tr>\n",
615 " <th>5</th>\n",
616 " <td>gram1-adjective-to-adverb</td>\n",
617 " <td>48</td>\n",
618 " <td>134</td>\n",
619 " <td>0.263736</td>\n",
620 " </tr>\n",
621 " <tr>\n",
622 " <th>6</th>\n",
623 " <td>gram2-opposite</td>\n",
624 " <td>23</td>\n",
625 " <td>67</td>\n",
626 " <td>0.255556</td>\n",
627 " </tr>\n",
628 " <tr>\n",
629 " <th>7</th>\n",
630 " <td>gram3-comparative</td>\n",
631 " <td>240</td>\n",
632 " <td>222</td>\n",
633 " <td>0.519481</td>\n",
634 " </tr>\n",
635 " <tr>\n",
636 " <th>8</th>\n",
637 " <td>gram4-superlative</td>\n",
638 " <td>19</td>\n",
639 " <td>53</td>\n",
640 " <td>0.263889</td>\n",
641 " </tr>\n",
642 " <tr>\n",
643 " <th>9</th>\n",
644 " <td>gram5-present-participle</td>\n",
645 " <td>90</td>\n",
646 " <td>182</td>\n",
647 " <td>0.330882</td>\n",
648 " </tr>\n",
649 " <tr>\n",
650 " <th>10</th>\n",
651 " <td>gram6-nationality-adjective</td>\n",
652 " <td>250</td>\n",
653 " <td>130</td>\n",
654 " <td>0.657895</td>\n",
655 " </tr>\n",
656 " <tr>\n",
657 " <th>11</th>\n",
658 " <td>gram7-past-tense</td>\n",
659 " <td>94</td>\n",
660 " <td>286</td>\n",
661 " <td>0.247368</td>\n",
662 " </tr>\n",
663 " <tr>\n",
664 " <th>12</th>\n",
665 " <td>gram8-plural</td>\n",
666 " <td>87</td>\n",
667 " <td>69</td>\n",
668 " <td>0.557692</td>\n",
669 " </tr>\n",
670 " <tr>\n",
671 " <th>13</th>\n",
672 " <td>gram9-plural-verbs</td>\n",
673 " <td>72</td>\n",
674 " <td>138</td>\n",
675 " <td>0.342857</td>\n",
676 " </tr>\n",
677 " <tr>\n",
678 " <th>14</th>\n",
679 " <td>total</td>\n",
680 " <td>1067</td>\n",
681 " <td>1701</td>\n",
682 " <td>0.385477</td>\n",
683 " </tr>\n",
684 " </tbody>\n",
685 "</table>\n",
686 "</div>"
687 ],
688 "text/plain": [
689 " category correct incorrect average\n",
690 "0 capital-common-countries 2 4 0.333333\n",
691 "1 capital-world 0 0 0.000000\n",
692 "2 city-in-state 140 390 0.264151\n",
693 "3 currency 2 26 0.071429\n",
694 "4 family 0 0 0.000000\n",
695 "5 gram1-adjective-to-adverb 48 134 0.263736\n",
696 "6 gram2-opposite 23 67 0.255556\n",
697 "7 gram3-comparative 240 222 0.519481\n",
698 "8 gram4-superlative 19 53 0.263889\n",
699 "9 gram5-present-participle 90 182 0.330882\n",
700 "10 gram6-nationality-adjective 250 130 0.657895\n",
701 "11 gram7-past-tense 94 286 0.247368\n",
702 "12 gram8-plural 87 69 0.557692\n",
703 "13 gram9-plural-verbs 72 138 0.342857\n",
704 "14 total 1067 1701 0.385477"
705 ]
706 },
707 "execution_count": 42,
708 "metadata": {},
709 "output_type": "execute_result"
710 }
711 ],
712 "source": [
713 "accuracy = eval_analogies(model)\n",
714 "accuracy"
715 ]
716 },
717 {
718 "cell_type": "markdown",
719 "metadata": {},
720 "source": [
721 "### Validate Vector Arithmetic"
722 ]
723 },
724 {
725 "cell_type": "code",
726 "execution_count": 105,
727 "metadata": {
728 "ExecuteTime": {
729 "end_time": "2018-12-10T01:00:35.772447Z",
730 "start_time": "2018-12-10T01:00:35.756869Z"
731 }
732 },
733 "outputs": [
734 {
735 "data": {
736 "text/html": [
737 "<div>\n",
738 "<style scoped>\n",
739 " .dataframe tbody tr th:only-of-type {\n",
740 " vertical-align: middle;\n",
741 " }\n",
742 "\n",
743 " .dataframe tbody tr th {\n",
744 " vertical-align: top;\n",
745 " }\n",
746 "\n",
747 " .dataframe thead th {\n",
748 " text-align: right;\n",
749 " }\n",
750 "</style>\n",
751 "<table border=\"1\" class=\"dataframe\">\n",
752 " <thead>\n",
753 " <tr style=\"text-align: right;\">\n",
754 " <th></th>\n",
755 " <th>0</th>\n",
756 " <th>1</th>\n",
757 " <th>2</th>\n",
758 " <th>3</th>\n",
759 " </tr>\n",
760 " </thead>\n",
761 " <tbody>\n",
762 " <tr>\n",
763 " <th>0</th>\n",
764 " <td>:</td>\n",
765 " <td>capital-common-countries</td>\n",
766 " <td>NaN</td>\n",
767 " <td>NaN</td>\n",
768 " </tr>\n",
769 " <tr>\n",
770 " <th>1</th>\n",
771 " <td>athens</td>\n",
772 " <td>greece</td>\n",
773 " <td>baghdad</td>\n",
774 " <td>iraq</td>\n",
775 " </tr>\n",
776 " <tr>\n",
777 " <th>2</th>\n",
778 " <td>athens</td>\n",
779 " <td>greece</td>\n",
780 " <td>bangkok</td>\n",
781 " <td>thailand</td>\n",
782 " </tr>\n",
783 " <tr>\n",
784 " <th>3</th>\n",
785 " <td>athens</td>\n",
786 " <td>greece</td>\n",
787 " <td>beijing</td>\n",
788 " <td>china</td>\n",
789 " </tr>\n",
790 " <tr>\n",
791 " <th>4</th>\n",
792 " <td>athens</td>\n",
793 " <td>greece</td>\n",
794 " <td>berlin</td>\n",
795 " <td>germany</td>\n",
796 " </tr>\n",
797 " </tbody>\n",
798 "</table>\n",
799 "</div>"
800 ],
801 "text/plain": [
802 " 0 1 2 3\n",
803 "0 : capital-common-countries NaN NaN\n",
804 "1 athens greece baghdad iraq\n",
805 "2 athens greece bangkok thailand\n",
806 "3 athens greece beijing china\n",
807 "4 athens greece berlin germany"
808 ]
809 },
810 "execution_count": 105,
811 "metadata": {},
812 "output_type": "execute_result"
813 }
814 ],
815 "source": [
816 "pd.read_csv(ANALOGIES_PATH, header=None, sep=' ').head()"
817 ]
818 },
819 {
820 "cell_type": "code",
821 "execution_count": 112,
822 "metadata": {
823 "ExecuteTime": {
824 "end_time": "2018-12-10T08:11:19.340922Z",
825 "start_time": "2018-12-10T08:11:19.334225Z"
826 }
827 },
828 "outputs": [
829 {
830 "name": "stdout",
831 "output_type": "stream",
832 "text": [
833 " term similarity\n",
834 "0 android 0.600454\n",
835 "1 smartphone 0.581685\n",
836 "2 app 0.559129\n",
837 "3 smartphones 0.533848\n",
838 "4 smartphones_tablets 0.526129\n",
839 "5 handsets 0.514813\n",
840 "6 smart_phones 0.512868\n",
841 "7 apple 0.507795\n",
842 "8 apps 0.505517\n",
843 "9 handset 0.491526\n"
844 ]
845 }
846 ],
847 "source": [
848 "sims = model.wv.most_similar(positive=['iphone'],\n",
849 "                              restrict_vocab=15000)  # search only the 15,000 most frequent tokens\n",
850 "pd.DataFrame(sims, columns=['term', 'similarity'])"
851 ]
852 },
853 {
854 "cell_type": "code",
855 "execution_count": 113,
856 "metadata": {
857 "ExecuteTime": {
858 "end_time": "2018-12-10T08:14:19.395370Z",
859 "start_time": "2018-12-10T08:14:19.381754Z"
860 }
861 },
862 "outputs": [
863 {
864 "name": "stdout",
865 "output_type": "stream",
866 "text": [
867 " term similarity\n",
868 "0 united_kingdom 0.606630\n",
869 "1 germany 0.585644\n",
870 "2 netherlands 0.578868\n",
871 "3 italy 0.547168\n",
872 "4 india 0.545213\n",
873 "5 spain 0.539029\n",
874 "6 singapore 0.535106\n",
875 "7 australia 0.525464\n",
876 "8 belgium 0.523677\n",
877 "9 sweden 0.510462\n"
878 ]
879 }
880 ],
881 "source": [
882 "analogy = model.wv.most_similar(positive=['france', 'london'],  # france - paris + london\n",
883 "                                 negative=['paris'],\n",
884 "                                 restrict_vocab=15000)\n",
885 "pd.DataFrame(analogy, columns=['term', 'similarity'])"
886 ]
887 },
888 {
889 "cell_type": "markdown",
890 "metadata": {},
891 "source": [
892 "### Nearest neighbors of randomly selected frequent words"
893 ]
894 },
895 {
896 "cell_type": "code",
897 "execution_count": 41,
898 "metadata": {
899 "ExecuteTime": {
900 "end_time": "2018-12-08T23:10:41.702789Z",
901 "start_time": "2018-12-08T23:10:41.640280Z"
902 }
903 },
904 "outputs": [
905 {
906 "name": "stderr",
907 "output_type": "stream",
908 "text": [
909 "/home/stefan/.pyenv/versions/at-3.6/lib/python3.6/site-packages/gensim/matutils.py:737: FutureWarning: Conversion of the second argument of issubdtype from `int` to `np.signedinteger` is deprecated. In future, it will be treated as `np.int64 == np.dtype(int).type`.\n",
910 " if np.issubdtype(vec.dtype, np.int):\n"
911 ]
912 },
913 {
914 "data": {
915 "text/html": [
916 "<div>\n",
917 "<style scoped>\n",
918 " .dataframe tbody tr th:only-of-type {\n",
919 " vertical-align: middle;\n",
920 " }\n",
921 "\n",
922 " .dataframe tbody tr th {\n",
923 " vertical-align: top;\n",
924 " }\n",
925 "\n",
926 " .dataframe thead th {\n",
927 " text-align: right;\n",
928 " }\n",
929 "</style>\n",
930 "<table border=\"1\" class=\"dataframe\">\n",
931 " <thead>\n",
932 " <tr style=\"text-align: right;\">\n",
933 " <th></th>\n",
934 " <th>staff</th>\n",
935 " <th>enables</th>\n",
936 " <th>times</th>\n",
937 " <th>fees</th>\n",
938 " <th>sources</th>\n",
939 " </tr>\n",
940 " </thead>\n",
941 " <tbody>\n",
942 " <tr>\n",
943 " <th>0</th>\n",
944 " <td>personnel</td>\n",
945 " <td>allows</td>\n",
946 " <td>twice</td>\n",
947 " <td>fee</td>\n",
948 " <td>source</td>\n",
949 " </tr>\n",
950 " <tr>\n",
951 " <th>1</th>\n",
952 " <td>team</td>\n",
953 " <td>enabling</td>\n",
954 " <td>standpoint_advantageous</td>\n",
955 " <td>professional_fees</td>\n",
956 " <td>primary_source</td>\n",
957 " </tr>\n",
958 " <tr>\n",
959 " <th>2</th>\n",
960 " <td>teams</td>\n",
961 " <td>helps</td>\n",
962 " <td>vimovo_orange_book</td>\n",
963 " <td>checkcard</td>\n",
964 " <td>sourced</td>\n",
965 " </tr>\n",
966 " <tr>\n",
967 " <th>3</th>\n",
968 " <td>professionals</td>\n",
969 " <td>enable</td>\n",
970 " <td>millisecond</td>\n",
971 " <td>commissions</td>\n",
972 " <td>readily_available</td>\n",
973 " </tr>\n",
974 " <tr>\n",
975 " <th>4</th>\n",
976 " <td>staffed</td>\n",
977 " <td>allowing</td>\n",
978 " <td>saturdays</td>\n",
979 " <td>atm_debit_card</td>\n",
980 " <td>internally_generated</td>\n",
981 " </tr>\n",
982 " <tr>\n",
983 " <th>5</th>\n",
984 " <td>hiring</td>\n",
985 " <td>enabled</td>\n",
986 " <td>assets_liabilities_react_differently</td>\n",
987 " <td>gds_reservation_booking</td>\n",
988 " <td>generated</td>\n",
989 " </tr>\n",
990 " <tr>\n",
991 " <th>6</th>\n",
992 " <td>consultants</td>\n",
993 " <td>allow</td>\n",
994 " <td>twice_weekly</td>\n",
995 " <td>interchange_fees_swipe</td>\n",
996 " <td>biological_contaminants_pollen</td>\n",
997 " </tr>\n",
998 " <tr>\n",
999 " <th>7</th>\n",
1000 " <td>hired</td>\n",
1001 " <td>leverages</td>\n",
1002 " <td>day</td>\n",
1003 " <td>noticing</td>\n",
1004 " <td>repair_reconstruct_damaged</td>\n",
1005 " </tr>\n",
1006 " <tr>\n",
1007 " <th>8</th>\n",
1008 " <td>engineers</td>\n",
1009 " <td>lets</td>\n",
1010 " <td>weekdays</td>\n",
1011 " <td>nonsufficient</td>\n",
1012 " <td>alternative</td>\n",
1013 " </tr>\n",
1014 " <tr>\n",
1015 " <th>9</th>\n",
1016 " <td>salespeople</td>\n",
1017 " <td>easy</td>\n",
1018 " <td>uvb</td>\n",
1019 " <td>bno_usci_cper_usag</td>\n",
1020 " <td>znse</td>\n",
1021 " </tr>\n",
1022 " </tbody>\n",
1023 "</table>\n",
1024 "</div>"
1025 ],
1026 "text/plain": [
1027 " staff enables times fees sources\n",
1028 "0 personnel allows twice fee source\n",
1029 "1 team enabling standpoint_advantageous professional_fees primary_source\n",
1030 "2 teams helps vimovo_orange_book checkcard sourced\n",
1031 "3 professionals enable millisecond commissions readily_available\n",
1032 "4 staffed allowing saturdays atm_debit_card internally_generated\n",
1033 "5 hiring enabled assets_liabilities_react_differently gds_reservation_booking generated\n",
1034 "6 consultants allow twice_weekly interchange_fees_swipe biological_contaminants_pollen\n",
1035 "7 hired leverages day noticing repair_reconstruct_damaged\n",
1036 "8 engineers lets weekdays nonsufficient alternative\n",
1037 "9 salespeople easy uvb bno_usci_cper_usag znse"
1038 ]
1039 },
1040 "execution_count": 41,
1041 "metadata": {},
1042 "output_type": "execute_result"
1043 }
1044 ],
1045 "source": [
1046 "VALID_SET = 5 # Random set of words to get nearest neighbors for\n",
1047 "VALID_WINDOW = 100 # Most frequent words to draw validation set from\n",
1048 "valid_examples = np.random.choice(VALID_WINDOW, size=VALID_SET, replace=False)\n",
1049 "similars = pd.DataFrame()\n",
1050 "\n",
1051 "for rank in sorted(valid_examples):\n",
1052 "    word = vocab['token'].iloc[rank]  # positional: vocab is sorted by count, so rank 0 is the most frequent token\n",
1053 "    similars[word] = [s[0] for s in model.wv.most_similar(word)]\n",
1054 "similars"
1055 ]
1056 },
1057 {
1058 "cell_type": "markdown",
1059 "metadata": {},
1060 "source": [
1061 "## Continue Training"
1062 ]
1063 },
1064 {
1065 "cell_type": "code",
1066 "execution_count": null,
1067 "metadata": {},
1068 "outputs": [],
1069 "source": [
1070 "accuracies = (eval_analogies(model)\n",
1071 "              .set_index('category')['average']\n",
1072 "              .rename('baseline')\n",
1073 "              .to_frame())"
1074 ]
1075 },
1076 {
1077 "cell_type": "code",
1078 "execution_count": 76,
1079 "metadata": {
1080 "ExecuteTime": {
1081 "end_time": "2018-12-08T21:26:29.866811Z",
1082 "start_time": "2018-12-08T20:10:12.950824Z"
1083 }
1084 },
1085 "outputs": [
1086 {
1087 "name": "stderr",
1088 "output_type": "stream",
1089 "text": [
1090 "/home/stefan/.pyenv/versions/at-3.6/lib/python3.6/site-packages/ipykernel_launcher.py:5: DeprecationWarning: Call to deprecated `accuracy` (Method will be removed in 4.0.0, use self.evaluate_word_analogies() instead).\n",
1091 " \"\"\"\n",
1092 "/home/stefan/.pyenv/versions/at-3.6/lib/python3.6/site-packages/gensim/matutils.py:737: FutureWarning: Conversion of the second argument of issubdtype from `int` to `np.signedinteger` is deprecated. In future, it will be treated as `np.int64 == np.dtype(int).type`.\n",
1093 " if np.issubdtype(vec.dtype, np.int):\n"
1094 ]
1095 },
1096 {
1097 "name": "stdout",
1098 "output_type": "stream",
1099 "text": [
1100 "1 | Duration: 464.0 | Accuracy: 28.93% \n",
1101 "2 | Duration: 457.8 | Accuracy: 28.83% \n",
1102 "3 | Duration: 459.2 | Accuracy: 28.97% \n",
1103 "4 | Duration: 456.9 | Accuracy: 28.60% \n",
1104 "5 | Duration: 457.4 | Accuracy: 29.69% \n",
1105 "6 | Duration: 456.8 | Accuracy: 29.40% \n",
1106 "7 | Duration: 457.7 | Accuracy: 29.91% \n",
1107 "8 | Duration: 456.4 | Accuracy: 29.61% \n",
1108 "9 | Duration: 456.1 | Accuracy: 29.37% \n",
1109 "10 | Duration: 454.6 | Accuracy: 29.17% \n"
1110 ]
1111 }
1112 ],
1113 "source": [
1114 "for i in range(1, 11):\n",
1115 "    start = time()\n",
1116 "    model.train(sentences, epochs=1, total_examples=model.corpus_count)\n",
1117 "    accuracy = eval_analogies(model).set_index('category').average\n",
1118 "    accuracies = accuracies.join(accuracy.to_frame(f'{i}'))  # one column per additional epoch\n",
1119 "    print(f'{i} | Duration: {format_time(time() - start)} | Accuracy: {accuracy.total:.2%}')\n",
1120 "    model.save(f'models/word2vec_{i}.model')  # same models/ root as the baseline checkpoint"
1121 ]
1122 },
1123 {
1124 "cell_type": "code",
1125 "execution_count": null,
1126 "metadata": {},
1127 "outputs": [],
1128 "source": [
1129 "model.wv.save('models/word_vectors_final.bin')  # alongside the per-epoch checkpoints"
1130 ]
1131 }
1132 ],
1133 "metadata": {
1134 "kernelspec": {
1135 "display_name": "Python 3",
1136 "language": "python",
1137 "name": "python3"
1138 },
1139 "language_info": {
1140 "codemirror_mode": {
1141 "name": "ipython",
1142 "version": 3
1143 },
1144 "file_extension": ".py",
1145 "mimetype": "text/x-python",
1146 "name": "python",
1147 "nbconvert_exporter": "python",
1148 "pygments_lexer": "ipython3",
1149 "version": "3.6.8"
1150 },
1151 "toc": {
1152 "base_numbering": 1,
1153 "nav_menu": {},
1154 "number_sections": true,
1155 "sideBar": true,
1156 "skip_h1_title": false,
1157 "title_cell": "Table of Contents",
1158 "title_sidebar": "Contents",
1159 "toc_cell": false,
1160 "toc_position": {},
1161 "toc_section_display": true,
1162 "toc_window_display": true
1163 }
1164 },
1165 "nbformat": 4,
1166 "nbformat_minor": 2
1167 }