ml-finance-python
Python scripts for machine learning in finance
git clone https://9o.is/git/ml-finance-python.git
run.py
(2173B)
# logging
import logging
import os.path
import sys

# serialization (cPickle exists only on Python 2; fall back to pickle on Python 3)
try:
    import cPickle as pickle
except ImportError:
    import pickle

# shuffle
from random import shuffle

# numpy
import numpy

# gensim modules
from gensim import utils
from gensim.models import Doc2Vec
from gensim.models.doc2vec import LabeledSentence, TaggedDocument

# Derive a logger name from the invoked script's file name so every log
# line identifies which script produced it.
script_path = sys.argv[0]
program = os.path.basename(script_path)
logger = logging.getLogger(program)

# Root logging: timestamp, level and message on each line; INFO and above.
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s')
root_logger = logging.root
root_logger.setLevel(level=logging.INFO)

# Log the full command line the script was started with.
command_line = ' '.join(sys.argv)
logger.info("running %s", command_line)

class LabeledLineSentence(object):
    """Corpus of tagged documents read line-by-line from several text files.

    Each line of every source file becomes one document.
    Its words are the line's whitespace-separated tokens.
    It gets a single tag ``'<PREFIX>_<n>'``, where ``PREFIX`` is the value
    mapped to that file in *sources* and ``n`` is the 0-based line index
    within that file.

    Args:
        sources: mapping of file path -> tag prefix. Prefixes must be unique
            so that tags produced for different files never collide.

    Raises:
        ValueError: if two files are mapped to the same prefix.
    """

    def __init__(self, sources):
        self.sources = sources
        # In-memory list of documents, filled by to_array(); None until then.
        self.sentences = None

        # Two files sharing a prefix would generate identical tags
        # (e.g. both would yield 'X_0'), silently merging their documents.
        prefixes = list(sources.values())
        if len(set(prefixes)) != len(prefixes):
            raise ValueError('Non-unique prefix encountered')

    def __iter__(self):
        """Yield documents lazily, file by file, without caching them."""
        for source, prefix in self.sources.items():
            with utils.smart_open(source) as fin:
                for item_no, line in enumerate(fin):
                    yield TaggedDocument(utils.to_unicode(line).split(),
                                         [prefix + '_%s' % item_no])

    def to_array(self):
        """Read every document into memory, store it on ``self.sentences``, and return it."""
        # Reuse __iter__ so the tagging logic lives in exactly one place.
        self.sentences = list(self)
        return self.sentences

    def sentences_perm(self):
        """Shuffle the in-memory documents in place and return the same list.

        If the corpus has not been loaded yet, it is loaded with to_array()
        first, instead of failing because ``self.sentences`` is missing.
        """
        if self.sentences is None:
            self.to_array()
        shuffle(self.sentences)
        return self.sentences

# Input files mapped to the tag prefix given to each of their lines.
sources = {
    'test-neg.txt': 'TEST_NEG',
    'test-pos.txt': 'TEST_POS',
    'train-neg.txt': 'TRAIN_NEG',
    'train-pos.txt': 'TRAIN_POS',
    'train-unsup.txt': 'TRAIN_UNS',
}

sentences = LabeledLineSentence(sources)

model = Doc2Vec(min_count=1, window=10, size=100, sample=1e-4, negative=5,
                workers=7)

model.build_vocab(sentences.to_array())

# Every train() call runs `epochs` passes and decays the learning rate from
# start_alpha down to end_alpha within that call.
# The previous loop passed epochs=model.iter on each of 50 calls. That
# trained 50 * model.iter passes and restarted the alpha decay every call.
# Now each iteration runs exactly one pass, so the corpus can be reshuffled
# between passes. Each pass receives its slice of one linear schedule from
# the model's initial alpha to min_alpha.
# The endpoints are captured before the loop because train() may overwrite
# model.alpha / model.min_alpha with the values it was given.
n_epochs = 50
alpha_start = model.alpha
alpha_end = model.min_alpha
alpha_step = (alpha_start - alpha_end) / n_epochs

for epoch in range(n_epochs):
    logger.info('Epoch %d', epoch)
    epoch_alpha = alpha_start - epoch * alpha_step
    model.train(sentences.sentences_perm(),
                total_examples=model.corpus_count,
                epochs=1,
                start_alpha=epoch_alpha,
                end_alpha=epoch_alpha - alpha_step,
                )

model.save('./imdb.d2v')