ml-finance-python

Python scripts for machine learning in finance

git clone https://9o.is/git/ml-finance-python.git

restricted_boltzmann_machine.py

(1821B)


      1 import logging
      2 
      3 import numpy as np
      4 from sklearn import datasets
      5 from sklearn.datasets import fetch_mldata
      6 import matplotlib.pyplot as plt
      7 
      8 from mlfromscratch.unsupervised_learning import RBM
      9 
     10 logging.basicConfig(level=logging.DEBUG)
     11 
     12 def main():
     13 
     14     mnist = fetch_mldata('MNIST original')
     15 
     16     X = mnist.data / 255.0
     17     y = mnist.target
     18 
     19     # Select the samples of the digit 2
     20     X = X[y == 2]
     21 
     22     # Limit dataset to 500 samples
     23     idx = np.random.choice(range(X.shape[0]), size=500, replace=False)
     24     X = X[idx]
     25 
     26     rbm = RBM(n_hidden=50, n_iterations=200, batch_size=25, learning_rate=0.001)
     27     rbm.fit(X)
     28 
     29     # Training error plot
     30     training, = plt.plot(range(len(rbm.training_errors)), rbm.training_errors, label="Training Error")
     31     plt.legend(handles=[training])
     32     plt.title("Error Plot")
     33     plt.ylabel('Error')
     34     plt.xlabel('Iterations')
     35     plt.show()
     36 
     37     # Get the images that were reconstructed during training
     38     gen_imgs = rbm.training_reconstructions
     39 
     40     # Plot the reconstructed images during the first iteration
     41     fig, axs = plt.subplots(5, 5)
     42     plt.suptitle("Restricted Boltzmann Machine - First Iteration")
     43     cnt = 0
     44     for i in range(5):
     45         for j in range(5):
     46             axs[i,j].imshow(gen_imgs[0][cnt].reshape((28, 28)), cmap='gray')
     47             axs[i,j].axis('off')
     48             cnt += 1
     49     fig.savefig("rbm_first.png")
     50     plt.close()
     51 
     52     # Plot the images during the last iteration
     53     fig, axs = plt.subplots(5, 5)
     54     plt.suptitle("Restricted Boltzmann Machine - Last Iteration")
     55     cnt = 0
     56     for i in range(5):
     57         for j in range(5):
     58             axs[i,j].imshow(gen_imgs[-1][cnt].reshape((28, 28)), cmap='gray')
     59             axs[i,j].axis('off')
     60             cnt += 1
     61     fig.savefig("rbm_last.png")
     62     plt.close()
     63 
     64 
     65 if __name__ == "__main__":
     66     main()