ml-finance-python
Python scripts for machine learning in finance
git clone https://9o.is/git/ml-finance-python.git
deep_q_network.py
(1117B)
1 from __future__ import print_function
2 import numpy as np
3 from mlfromscratch.utils import to_categorical
4 from mlfromscratch.deep_learning.optimizers import Adam
5 from mlfromscratch.deep_learning.loss_functions import SquareLoss
6 from mlfromscratch.deep_learning.layers import Dense, Dropout, Flatten, Activation, Reshape, BatchNormalization
7 from mlfromscratch.deep_learning import NeuralNetwork
8 from mlfromscratch.reinforcement_learning import DeepQNetwork
9
10
def main():
    """Train a Deep Q-Network agent on the 'CartPole-v1' environment, then play.

    Steps:
      1. Construct a ``DeepQNetwork`` for 'CartPole-v1' with fixed
         hyperparameters (epsilon=0.9, gamma=0.8, decay_rate=0.005,
         min_epsilon=0.1). Presumably an epsilon-greedy exploration rate
         decaying toward ``min_epsilon`` plus a reward discount factor.
         Confirm against ``DeepQNetwork``.
      2. Register a small fully connected network builder via ``set_model``.
      3. Print a blank line and the model summary.
      4. Train for 500 epochs, then play for 100 epochs.
    """

    def build_network(n_inputs, n_outputs):
        """Return the network handed to ``DeepQNetwork.set_model``.

        Architecture: Dense(64) -> ReLU -> Dense(n_outputs), optimized with
        Adam on a squared-error loss. ``n_inputs``/``n_outputs`` are supplied
        by ``set_model``. They are presumably the observation size and the
        number of actions (TODO confirm).
        """
        # SquareLoss is passed as the class, not an instance. Presumably
        # NeuralNetwork instantiates it itself; confirm before changing.
        net = NeuralNetwork(optimizer=Adam(), loss=SquareLoss)
        net.add(Dense(64, input_shape=(n_inputs,)))
        net.add(Activation('relu'))
        net.add(Dense(n_outputs))
        return net

    hyperparameters = {
        'epsilon': 0.9,
        'gamma': 0.8,
        'decay_rate': 0.005,
        'min_epsilon': 0.1,
    }
    agent = DeepQNetwork(env_name='CartPole-v1', **hyperparameters)
    agent.set_model(build_network)

    print()
    agent.model.summary(name="Deep Q-Network")

    training_epochs, playing_epochs = 500, 100
    agent.train(n_epochs=training_epochs)
    agent.play(n_epochs=playing_epochs)
33
if __name__ == "__main__":
    # Run the training/playing demo only when executed as a script,
    # not when this module is imported.
    main()