ml-finance-python

Python scripts for machine learning in finance.

git clone https://9o.is/git/ml-finance-python.git

deep_q_network.py

(1117B)


      1 from __future__ import print_function
      2 import numpy as np
      3 from mlfromscratch.utils import to_categorical
      4 from mlfromscratch.deep_learning.optimizers import Adam
      5 from mlfromscratch.deep_learning.loss_functions import SquareLoss
      6 from mlfromscratch.deep_learning.layers import Dense, Dropout, Flatten, Activation, Reshape, BatchNormalization
      7 from mlfromscratch.deep_learning import NeuralNetwork
      8 from mlfromscratch.reinforcement_learning import DeepQNetwork
      9 
     10 
     11 def main():
     12     dqn = DeepQNetwork(env_name='CartPole-v1',
     13                         epsilon=0.9, 
     14                         gamma=0.8, 
     15                         decay_rate=0.005, 
     16                         min_epsilon=0.1)
     17 
     18     # Model builder
     19     def model(n_inputs, n_outputs):    
     20         clf = NeuralNetwork(optimizer=Adam(), loss=SquareLoss)
     21         clf.add(Dense(64, input_shape=(n_inputs,)))
     22         clf.add(Activation('relu'))
     23         clf.add(Dense(n_outputs))
     24         return clf
     25 
     26     dqn.set_model(model)
     27 
     28     print ()
     29     dqn.model.summary(name="Deep Q-Network")
     30 
     31     dqn.train(n_epochs=500)
     32     dqn.play(n_epochs=100)
     33 
     34 if __name__ == "__main__":
     35     main()