ml-finance-python

python scripts for finance machine learning

git clone https://9o.is/git/ml-finance-python.git

01_gridworld_dynamic_programming.ipynb



      1 {
      2  "cells": [
      3   {
      4    "cell_type": "markdown",
      5    "metadata": {},
      6    "source": [
      7     "# Dynamic programming: Value and Policy Iteration"
      8    ]
      9   },
     10   {
     11    "cell_type": "markdown",
     12    "metadata": {},
     13    "source": [
     14     "In this section, we apply value and policy iteration to a toy environment: a 3 x 4 grid, depicted in the following diagram, with these features:\n",
     15     "\n",
     16     "- **States**: 11 states represented as two-dimensional coordinates. One cell is not accessible, and the top two states in the rightmost column are terminal, that is, they end the episode.\n",
     17     "- **Actions**: Movements on each step, that is, up, down, left, and right. The environment is stochastic, so actions can have unintended outcomes: each action moves the agent in the intended direction with 80% probability and in each of the two perpendicular directions with 10% probability (for example, right or left instead of up, or up or down instead of right). A move that would leave the grid or enter the inaccessible cell leaves the agent in place.\n",
     18     "- **Rewards**: As depicted in the left-hand panel, each step that ends in a non-terminal state yields a reward of -0.02, while reaching the two terminal states yields +1 and -1, respectively:"
     19    ]
     20   },
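Here is a minimal sketch of the stochastic action model described above. It assumes the notebook's action order `L`, `U`, `R`, `D`, in which neighbors in the list are perpendicular moves. `slip_distribution` and `sample_move` are hypothetical helpers and are not part of the notebook:

```python
import numpy as np

ACTIONS = ["L", "U", "R", "D"]  # neighbors in this list are perpendicular moves

def slip_distribution(intended):
    """Map each possible actual move to its probability under the 80/10/10 model."""
    i = ACTIONS.index(intended)
    return {
        intended: 0.8,
        ACTIONS[(i - 1) % 4]: 0.1,  # perpendicular move on one side
        ACTIONS[(i + 1) % 4]: 0.1,  # perpendicular move on the other side
    }

def sample_move(intended, rng):
    """Draw the move that is actually executed for an intended action."""
    dist = slip_distribution(intended)
    moves = list(dist)
    return moves[rng.choice(len(moves), p=list(dist.values()))]

rng = np.random.default_rng(42)
draws = [sample_move("U", rng) for _ in range(10_000)]
freq = {move: draws.count(move) / len(draws) for move in ACTIONS}
print(freq)  # 'U' should appear about 80% of the time, 'L' and 'R' about 10% each
```

Sampling is only needed to simulate the environment. Dynamic programming works with the probabilities directly.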
     21   {
     22    "cell_type": "markdown",
     23    "metadata": {},
     24    "source": [
     25     "<img src=\"img/03_mdp.png\" width=\"500\">"
     26    ]
     27   },
     28   {
     29    "cell_type": "markdown",
     30    "metadata": {},
     31    "source": [
     32     "The right panel of the preceding GridWorld diagram shows the optimal value estimates produced by value iteration and the corresponding greedy policy. The negative rewards, combined with the uncertainty in the environment, produce an optimal policy that moves away from the negative terminal state.\n",
     33     "\n",
     34     "The results are sensitive to both the rewards and the discount factor. The size of the penalty in the negative terminal state affects the policy in the neighboring cells; you can modify the parameters in this notebook to find the threshold levels at which the optimal action changes."
     35    ]
     36   },
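One way to search for such thresholds is to re-solve the grid for a range of penalties in the negative terminal state and record the action chosen next to it. The sketch below is self-contained. `build_gridworld` and `value_iteration` are hypothetical helpers that mirror this notebook's setup (80/10/10 slips, blocked cell at (1, 1), terminal states at (0, 3) and (1, 3), step reward -0.02, discount 0.99). They use plain NumPy instead of `pymdptoolbox`, and the penalty values are arbitrary:

```python
import numpy as np

ACTIONS = ["L", "U", "R", "D"]
OFFSETS = {"L": (0, -1), "U": (-1, 0), "R": (0, 1), "D": (1, 0)}
SHAPE, BLOCKED = (3, 4), (1, 1)

def build_gridworld(step_reward=-0.02, neg_reward=-1.0):
    """Return transitions P and rewards R, both of shape (A, S, S), for the 3 x 4 grid."""
    terminals = {(0, 3): 1.0, (1, 3): neg_reward}
    n_rows, n_cols = SHAPE
    n = n_rows * n_cols
    P = np.zeros((len(ACTIONS), n, n))
    R = np.zeros((len(ACTIONS), n, n))
    for s in range(n):
        cell = divmod(s, n_cols)  # (row, col) as plain ints
        if cell in terminals or cell == BLOCKED:
            P[:, s, s] = 1.0  # absorbing or inaccessible: stay put, zero reward
            continue
        for a, action in enumerate(ACTIONS):
            # intended move with probability 0.8, each perpendicular move with 0.1
            for move, p in ((action, 0.8),
                            (ACTIONS[(a - 1) % 4], 0.1),
                            (ACTIONS[(a + 1) % 4], 0.1)):
                dr, dc = OFFSETS[move]
                nxt = (cell[0] + dr, cell[1] + dc)
                if not (0 <= nxt[0] < n_rows and 0 <= nxt[1] < n_cols) or nxt == BLOCKED:
                    nxt = cell  # bumping into a wall leaves the agent in place
                s_next = nxt[0] * n_cols + nxt[1]
                P[a, s, s_next] += p
                R[a, s, s_next] = terminals.get(nxt, step_reward)
    return P, R

def value_iteration(P, R, gamma=0.99, tol=1e-8):
    """Synchronous value iteration; returns the values and the greedy policy."""
    V = np.zeros(P.shape[1])
    while True:
        Q = np.sum(P * (R + gamma * V), axis=2)  # Q[a, s]
        V_new = Q.max(axis=0)
        if np.abs(V_new - V).sum() < tol:
            return V_new, Q.argmax(axis=0)
        V = V_new

cell = (1, 2)  # the cell next to the negative terminal state
s = cell[0] * SHAPE[1] + cell[1]
chosen = {}
for neg_reward in [-1.0, -0.5, -0.2, -0.05]:
    _, greedy = value_iteration(*build_gridworld(neg_reward=neg_reward))
    chosen[neg_reward] = ACTIONS[greedy[s]]
    print(f"negative terminal reward {neg_reward:+.2f}: best action at {cell} = {chosen[neg_reward]}")
```

You can probe the other sensitivities mentioned above by sweeping `step_reward` or `gamma` in the same loop.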
     37   {
     38    "cell_type": "markdown",
     39    "metadata": {},
     40    "source": [
     41     "## Imports & Settings"
     42    ]
     43   },
     44   {
     45    "cell_type": "code",
     46    "execution_count": 1,
     47    "metadata": {},
     48    "outputs": [],
     49    "source": [
     50     "%matplotlib inline\n",
     51     "\n",
     52     "from pathlib import Path\n",
     53     "from time import process_time\n",
     54     "import numpy as np\n",
     55     "import pandas as pd\n",
     56     "from mdptoolbox import mdp\n",
     57     "from itertools import product\n",
     58     "import gym"
     59    ]
     60   },
     61   {
     62    "cell_type": "markdown",
     63    "metadata": {},
     64    "source": [
     65     "## Set up Gridworld"
     66    ]
     67   },
     68   {
     69    "cell_type": "markdown",
     70    "metadata": {},
     71    "source": [
     72     "### States, Actions and Rewards"
     73    ]
     74   },
     75   {
     76    "cell_type": "markdown",
     77    "metadata": {},
     78    "source": [
     79     "We will begin by defining the environment parameters:"
     80    ]
     81   },
     82   {
     83    "cell_type": "code",
     84    "execution_count": 2,
     85    "metadata": {},
     86    "outputs": [],
     87    "source": [
     88     "grid_size = (3, 4)\n",
     89     "blocked_cell = (1, 1)\n",
     90     "baseline_reward = -0.02\n",
     91     "absorbing_cells = {(0, 3): 1, (1, 3): -1}"
     92    ]
     93   },
     94   {
     95    "cell_type": "code",
     96    "execution_count": 3,
     97    "metadata": {},
     98    "outputs": [],
     99    "source": [
    100     "actions = ['L', 'U', 'R', 'D']\n",
    101     "num_actions = len(actions)\n",
    102     "probs = [.1, .8, .1, 0]"
    103    ]
    104   },
    105   {
    106    "cell_type": "markdown",
    107    "metadata": {},
    108    "source": [
    109     "We will frequently need to convert between one-dimensional and two-dimensional representations, so we will define two helper functions for this purpose; states are one-dimensional and cells are the corresponding two-dimensional coordinates:"
    110    ]
    111   },
    112   {
    113    "cell_type": "code",
    114    "execution_count": 4,
    115    "metadata": {},
    116    "outputs": [],
    117    "source": [
    118     "to_1d = lambda x: np.ravel_multi_index(x, grid_size)\n",
    119     "to_2d = lambda x: np.unravel_index(x, grid_size)"
    120    ]
    121   },
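Both helpers rely on NumPy's row-major (C-order) indexing, so cell `(row, col)` of the 3 x 4 grid becomes state `row * 4 + col`. A quick round-trip check, as a small sketch:

```python
import numpy as np

grid_size = (3, 4)

# Row-major (C-order) mapping: state = row * n_cols + col
state = np.ravel_multi_index((1, 3), grid_size)  # cell (1, 3) -> state 7
cell = np.unravel_index(7, grid_size)            # state 7 -> cell (1, 3)
print(int(state), tuple(int(i) for i in cell))

# Every state survives the round trip
all_states = np.arange(np.prod(grid_size))
rows, cols = np.unravel_index(all_states, grid_size)
assert np.array_equal(np.ravel_multi_index((rows, cols), grid_size), all_states)
```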
    122   {
    123    "cell_type": "markdown",
    124    "metadata": {},
    125    "source": [
    126     "Furthermore, we will precompute some data points to make the code more concise:"
    127    ]
    128   },
    129   {
    130    "cell_type": "code",
    131    "execution_count": 5,
    132    "metadata": {},
    133    "outputs": [],
    134    "source": [
    135     "num_states = np.prod(grid_size)\n",
    136     "cells = list(np.ndindex(grid_size))\n",
    137     "states = list(range(len(cells)))"
    138    ]
    139   },
    140   {
    141    "cell_type": "code",
    142    "execution_count": 6,
    143    "metadata": {},
    144    "outputs": [],
    145    "source": [
    146     "cell_state = dict(zip(cells, states))\n",
    147     "state_cell = dict(zip(states, cells))"
    148    ]
    149   },
    150   {
    151    "cell_type": "code",
    152    "execution_count": 7,
    153    "metadata": {},
    154    "outputs": [],
    155    "source": [
    156     "absorbing_states = {to_1d(s):r for s, r in absorbing_cells.items()}\n",
    157     "blocked_state = to_1d(blocked_cell)"
    158    ]
    159   },
    160   {
    161    "cell_type": "markdown",
    162    "metadata": {},
    163    "source": [
    164     "We store the rewards for each state:"
    165    ]
    166   },
    167   {
    168    "cell_type": "code",
    169    "execution_count": 8,
    170    "metadata": {},
    171    "outputs": [],
    172    "source": [
    173     "state_rewards = np.full(num_states, baseline_reward)\n",
    174     "state_rewards[blocked_state] = 0\n",
    175     "for state, reward in absorbing_states.items():\n",
    176     "    state_rewards[state] = reward"
    177    ]
    178   },
    179   {
    180    "cell_type": "code",
    181    "execution_count": 9,
    182    "metadata": {},
    183    "outputs": [],
    184    "source": [
    185     "action_outcomes = {}\n",
    186     "for i, action in enumerate(actions):\n",
    187     "    probs_ = dict(zip([actions[j % 4] for j in range(i, num_actions + i)], probs))\n",
    188     "    action_outcomes[actions[(i + 1) % 4]] = probs_"
    189    ]
    190   },
    191   {
    192    "cell_type": "markdown",
    193    "metadata": {},
    194    "source": [
    195     "To account for the probabilistic environment, we also need to compute the probability distribution over the actual move for a given action:"
    196    ]
    197   },
    198   {
    199    "cell_type": "code",
    200    "execution_count": 10,
    201    "metadata": {},
    202    "outputs": [
    203     {
    204      "data": {
    205       "text/plain": [
    206        "{'U': {'L': 0.1, 'U': 0.8, 'R': 0.1, 'D': 0},\n",
    207        " 'R': {'U': 0.1, 'R': 0.8, 'D': 0.1, 'L': 0},\n",
    208        " 'D': {'R': 0.1, 'D': 0.8, 'L': 0.1, 'U': 0},\n",
    209        " 'L': {'D': 0.1, 'L': 0.8, 'U': 0.1, 'R': 0}}"
    210       ]
    211      },
    212      "execution_count": 10,
    213      "metadata": {},
    214      "output_type": "execute_result"
    215     }
    216    ],
    217    "source": [
    218     "action_outcomes"
    219    ]
    220   },
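The dictionary holds the same four probabilities for every action, rotated so that the 0.8 falls on the intended move. As an illustrative alternative, and not how the notebook stores it, the distribution fits in a 4 x 4 array built with `np.roll`. Row `a` lists the probabilities of each actual move when action `a` is chosen:

```python
import numpy as np

actions = ['L', 'U', 'R', 'D']
# Distribution for intended action 'L' (index 0): 'L' 0.8, 'U' 0.1, 'R' 0.0, 'D' 0.1
base = np.array([0.8, 0.1, 0.0, 0.1])
# Rolling the vector by a positions moves the 0.8 onto action a
outcome_probs = np.stack([np.roll(base, a) for a in range(len(actions))])

# Row = intended action, column = actual move, both in the order of `actions`
print(outcome_probs)
```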
    221   {
    222    "cell_type": "markdown",
    223    "metadata": {},
    224    "source": [
    225     "Now, we are ready to compute the transition matrix, which is the key input to the MDP."
    226    ]
    227   },
    228   {
    229    "cell_type": "markdown",
    230    "metadata": {},
    231    "source": [
    232     "### Transition Matrix"
    233    ]
    234   },
    235   {
    236    "cell_type": "markdown",
    237    "metadata": {},
    238    "source": [
    239     "The transition matrix defines the probability $P(s^\\prime \\mid s, a)$ of ending up in state $s^\\prime$ for each previous state $s$ and action $a$. We will use `pymdptoolbox`, which accepts several formats to specify transitions and rewards; for both the transition probabilities and the rewards, we will create a `NumPy` array with dimensions $A \\times S \\times S$.\n",
    240     "\n",
    241     "First, we compute the target cell for each starting cell and move:"
    242    ]
    243   },
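In this layout, `P[a, s, :]` is a probability distribution over next states, so each such row must sum to one. The expected next-state value for every action and state is then a single tensor contraction. Here is a minimal sketch on a made-up example with two actions and three states; the numbers are illustrative only, and `check_transitions` is a hypothetical helper:

```python
import numpy as np

# Hypothetical example: 2 actions, 3 states; P_toy[a, s, s'] = P(s' | s, a)
P_toy = np.array([
    [[0.9, 0.1, 0.0],
     [0.0, 0.9, 0.1],
     [0.0, 0.0, 1.0]],  # action 0
    [[0.1, 0.9, 0.0],
     [0.1, 0.0, 0.9],
     [0.0, 0.0, 1.0]],  # action 1
])
V_toy = np.array([0.0, 0.5, 1.0])  # some value estimate for each next state

def check_transitions(P, atol=1e-8):
    """Raise ValueError if P is not an (A, S, S) stack of row-stochastic matrices."""
    _, S, S2 = P.shape
    if S != S2:
        raise ValueError("each P[a] must be square (S x S)")
    if (P < 0).any():
        raise ValueError("probabilities must be non-negative")
    if not np.allclose(P.sum(axis=2), 1.0, atol=atol):
        raise ValueError("every row P[a, s, :] must sum to 1")

check_transitions(P_toy)
# Expected next-state value E[V(s') | s, a] for every (a, s): shape (A, S)
expected_next = np.einsum("ast,t->as", P_toy, V_toy)
print(expected_next)
```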
    244   {
    245    "cell_type": "code",
    246    "execution_count": 11,
    247    "metadata": {},
    248    "outputs": [],
    249    "source": [
    250     "def get_new_cell(state, move):\n",
    251     "    cell = to_2d(state)\n",
    252     "    if actions[move] == 'U':\n",
    253     "        return cell[0] - 1, cell[1]\n",
    254     "    elif actions[move] == 'D':\n",
    255     "        return cell[0] + 1, cell[1]\n",
    256     "    elif actions[move] == 'R':\n",
    257     "        return cell[0], cell[1] + 1\n",
    258     "    elif actions[move] == 'L':\n",
    259     "        return cell[0], cell[1] - 1"
    260    ]
    261   },
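`get_new_cell` returns the raw target coordinates, which may lie outside the grid. The boundary and blocked-cell checks happen later, when the transitions are filled in. As an alternative sketch, both steps can be combined using a lookup table of row and column offsets. `OFFSETS` and `next_cell` are hypothetical names, and the grid constants repeat the values defined above:

```python
GRID_SIZE = (3, 4)
BLOCKED_CELL = (1, 1)
# (row, column) offset for each move; row 0 is the top of the grid
OFFSETS = {'L': (0, -1), 'U': (-1, 0), 'R': (0, 1), 'D': (1, 0)}

def next_cell(cell, move):
    """Return the cell reached by `move`, or `cell` itself if the move is blocked."""
    row, col = cell[0] + OFFSETS[move][0], cell[1] + OFFSETS[move][1]
    inside = 0 <= row < GRID_SIZE[0] and 0 <= col < GRID_SIZE[1]
    if not inside or (row, col) == BLOCKED_CELL:
        return tuple(cell)  # bump into the wall or the blocked cell: stay put
    return row, col

print(next_cell((0, 0), 'U'), next_cell((0, 0), 'R'), next_cell((1, 0), 'R'))
```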
    262   {
    263    "cell_type": "code",
    264    "execution_count": 12,
    265    "metadata": {},
    266    "outputs": [
    267     {
    268      "data": {
    269       "text/plain": [
    270        "array([-0.02, -0.02, -0.02,  1.  , -0.02,  0.  , -0.02, -1.  , -0.02,\n",
    271        "       -0.02, -0.02, -0.02])"
    272       ]
    273      },
    274      "execution_count": 12,
    275      "metadata": {},
    276      "output_type": "execute_result"
    277     }
    278    ],
    279    "source": [
    280     "state_rewards"
    281    ]
    282   },
    283   {
    284    "cell_type": "markdown",
    285    "metadata": {},
    286    "source": [
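    287     "The following function takes a starting `state`, the intended `action`, and the move that is actually executed, `outcome`, and fills in the corresponding transition probabilities and rewards. Absorbing and blocked states transition to themselves with probability 1, and a move that would leave the grid or enter the blocked cell leaves the agent in place:"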
    288    ]
    289   },
    290   {
    291    "cell_type": "code",
    292    "execution_count": 13,
    293    "metadata": {},
    294    "outputs": [],
    295    "source": [
    296     "def update_transitions_and_rewards(state, action, outcome):\n",
    297     "    if state in absorbing_states.keys() or state == blocked_state:\n",
    298     "        transitions[action, state, state] = 1\n",
    299     "    else:\n",
    300     "        new_cell = get_new_cell(state, outcome)\n",
    301     "        p = action_outcomes[actions[action]][actions[outcome]]\n",
    302     "        if new_cell not in cells or new_cell == blocked_cell:\n",
    303     "            transitions[action, state, state] += p\n",
    304     "            rewards[action, state, state] = baseline_reward\n",
    305     "        else:\n",
    306     "            new_state= to_1d(new_cell)\n",
    307     "            transitions[action, state, new_state] = p\n",
    308     "            rewards[action, state, new_state] = state_rewards[new_state]"
    309    ]
    310   },
    311   {
    312    "cell_type": "markdown",
    313    "metadata": {},
    314    "source": [
    315     "We generate the transition and reward values by creating placeholder arrays and iterating over the Cartesian product $A \\times A \\times S$, that is, over every combination of intended action, actual move, and starting state:"
    316    ]
    317   },
    318   {
    319    "cell_type": "code",
    320    "execution_count": 14,
    321    "metadata": {},
    322    "outputs": [],
    323    "source": [
    324     "rewards = np.zeros(shape=(num_actions, num_states, num_states))\n",
    325     "transitions = np.zeros((num_actions, num_states, num_states))\n",
    326     "actions_ = list(range(num_actions))\n",
    327     "for action, outcome, state in product(actions_, actions_, states):\n",
    328     "    update_transitions_and_rewards(state, action, outcome)"
    329    ]
    330   },
    331   {
    332    "cell_type": "code",
    333    "execution_count": 15,
    334    "metadata": {},
    335    "outputs": [
    336     {
    337      "data": {
    338       "text/plain": [
    339        "((4, 12, 12), (4, 12, 12))"
    340       ]
    341      },
    342      "execution_count": 15,
    343      "metadata": {},
    344      "output_type": "execute_result"
    345     }
    346    ],
    347    "source": [
    348     "rewards.shape, transitions.shape"
    349    ]
    350   },
    351   {
    352    "cell_type": "markdown",
    353    "metadata": {},
    354    "source": [
    355     "## PyMDPToolbox"
    356    ]
    357   },
    358   {
    359    "cell_type": "markdown",
    360    "metadata": {},
    361    "source": [
    362     "We can solve the MDP using the [pymdptoolbox](https://pymdptoolbox.readthedocs.io/en/latest/api/mdptoolbox.html) Python library, which also implements several other algorithms, including Q-learning."
    363    ]
    364   },
    365   {
    366    "cell_type": "markdown",
    367    "metadata": {},
    368    "source": [
    369     "### Value Iteration"
    370    ]
    371   },
    372   {
    373    "cell_type": "code",
    374    "execution_count": 16,
    375    "metadata": {},
    376    "outputs": [],
    377    "source": [
    378     "gamma = .99\n",
    379     "epsilon = 1e-5"
    380    ]
    381   },
    382   {
    383    "cell_type": "markdown",
    384    "metadata": {},
    385    "source": [
    386     "To run `ValueIteration`, instantiate the corresponding object with the transition and reward matrices and the desired configuration options, then call its `.run()` method:"
    387    ]
    388   },
    389   {
    390    "cell_type": "code",
    391    "execution_count": 17,
    392    "metadata": {},
    393    "outputs": [
    394     {
    395      "data": {
    396       "text/plain": [
    397        "'# Iterations: 31 | Time: 0.0006'"
    398       ]
    399      },
    400      "execution_count": 17,
    401      "metadata": {},
    402      "output_type": "execute_result"
    403     }
    404    ],
    405    "source": [
    406     "vi = mdp.ValueIteration(transitions=transitions,\n",
    407     "                        reward=rewards,\n",
    408     "                        discount=gamma,\n",
    409     "                        epsilon=epsilon)\n",
    410     "\n",
    411     "vi.run()\n",
    412     "f'# Iterations: {vi.iter:,d} | Time: {vi.time:.4f}'"
    413    ]
    414   },
    415   {
    416    "cell_type": "code",
    417    "execution_count": 18,
    418    "metadata": {},
    419    "outputs": [
    420     {
    421      "data": {
    422       "text/html": [
    423        "<div>\n",
    424        "<style scoped>\n",
    425        "    .dataframe tbody tr th:only-of-type {\n",
    426        "        vertical-align: middle;\n",
    427        "    }\n",
    428        "\n",
    429        "    .dataframe tbody tr th {\n",
    430        "        vertical-align: top;\n",
    431        "    }\n",
    432        "\n",
    433        "    .dataframe thead th {\n",
    434        "        text-align: right;\n",
    435        "    }\n",
    436        "</style>\n",
    437        "<table border=\"1\" class=\"dataframe\">\n",
    438        "  <thead>\n",
    439        "    <tr style=\"text-align: right;\">\n",
    440        "      <th></th>\n",
    441        "      <th>0</th>\n",
    442        "      <th>1</th>\n",
    443        "      <th>2</th>\n",
    444        "      <th>3</th>\n",
    445        "    </tr>\n",
    446        "  </thead>\n",
    447        "  <tbody>\n",
    448        "    <tr>\n",
    449        "      <th>0</th>\n",
    450        "      <td>R</td>\n",
    451        "      <td>R</td>\n",
    452        "      <td>R</td>\n",
    453        "      <td>L</td>\n",
    454        "    </tr>\n",
    455        "    <tr>\n",
    456        "      <th>1</th>\n",
    457        "      <td>U</td>\n",
    458        "      <td>L</td>\n",
    459        "      <td>U</td>\n",
    460        "      <td>L</td>\n",
    461        "    </tr>\n",
    462        "    <tr>\n",
    463        "      <th>2</th>\n",
    464        "      <td>U</td>\n",
    465        "      <td>L</td>\n",
    466        "      <td>L</td>\n",
    467        "      <td>L</td>\n",
    468        "    </tr>\n",
    469        "  </tbody>\n",
    470        "</table>\n",
    471        "</div>"
    472       ],
    473       "text/plain": [
    474        "   0  1  2  3\n",
    475        "0  R  R  R  L\n",
    476        "1  U  L  U  L\n",
    477        "2  U  L  L  L"
    478       ]
    479      },
    480      "execution_count": 18,
    481      "metadata": {},
    482      "output_type": "execute_result"
    483     }
    484    ],
    485    "source": [
    486     "policy = np.asarray([actions[i] for i in vi.policy])\n",
    487     "pd.DataFrame(policy.reshape(grid_size))"
    488    ]
    489   },
    490   {
    491    "cell_type": "code",
    492    "execution_count": 19,
    493    "metadata": {},
    494    "outputs": [
    495     {
    496      "data": {
    497       "text/html": [
    498        "<div>\n",
    499        "<style scoped>\n",
    500        "    .dataframe tbody tr th:only-of-type {\n",
    501        "        vertical-align: middle;\n",
    502        "    }\n",
    503        "\n",
    504        "    .dataframe tbody tr th {\n",
    505        "        vertical-align: top;\n",
    506        "    }\n",
    507        "\n",
    508        "    .dataframe thead th {\n",
    509        "        text-align: right;\n",
    510        "    }\n",
    511        "</style>\n",
    512        "<table border=\"1\" class=\"dataframe\">\n",
    513        "  <thead>\n",
    514        "    <tr style=\"text-align: right;\">\n",
    515        "      <th></th>\n",
    516        "      <th>0</th>\n",
    517        "      <th>1</th>\n",
    518        "      <th>2</th>\n",
    519        "      <th>3</th>\n",
    520        "    </tr>\n",
    521        "  </thead>\n",
    522        "  <tbody>\n",
    523        "    <tr>\n",
    524        "      <th>0</th>\n",
    525        "      <td>0.884143</td>\n",
    526        "      <td>0.925054</td>\n",
    527        "      <td>0.961986</td>\n",
    528        "      <td>0.000000</td>\n",
    529        "    </tr>\n",
    530        "    <tr>\n",
    531        "      <th>1</th>\n",
    532        "      <td>0.848181</td>\n",
    533        "      <td>0.000000</td>\n",
    534        "      <td>0.714643</td>\n",
    535        "      <td>0.000000</td>\n",
    536        "    </tr>\n",
    537        "    <tr>\n",
    538        "      <th>2</th>\n",
    539        "      <td>0.808345</td>\n",
    540        "      <td>0.773328</td>\n",
    541        "      <td>0.736099</td>\n",
    542        "      <td>0.516083</td>\n",
    543        "    </tr>\n",
    544        "  </tbody>\n",
    545        "</table>\n",
    546        "</div>"
    547       ],
    548       "text/plain": [
    549        "          0         1         2         3\n",
    550        "0  0.884143  0.925054  0.961986  0.000000\n",
    551        "1  0.848181  0.000000  0.714643  0.000000\n",
    552        "2  0.808345  0.773328  0.736099  0.516083"
    553       ]
    554      },
    555      "execution_count": 19,
    556      "metadata": {},
    557      "output_type": "execute_result"
    558     }
    559    ],
    560    "source": [
    561     "value = np.asarray(vi.V).reshape(grid_size)\n",
    562     "pd.DataFrame(value)"
    563    ]
    564   },
    565   {
    566    "cell_type": "markdown",
    567    "metadata": {},
    568    "source": [
    569     "### Policy Iteration"
    570    ]
    571   },
    572   {
    573    "cell_type": "markdown",
    574    "metadata": {},
    575    "source": [
    576     "The `PolicyIteration` class works similarly:"
    577    ]
    578   },
    579   {
    580    "cell_type": "code",
    581    "execution_count": 20,
    582    "metadata": {},
    583    "outputs": [
    584     {
    585      "data": {
    586       "text/plain": [
    587        "'# Iterations: 7 | Time: 0.0087'"
    588       ]
    589      },
    590      "execution_count": 20,
    591      "metadata": {},
    592      "output_type": "execute_result"
    593     }
    594    ],
    595    "source": [
    596     "pi = mdp.PolicyIteration(transitions=transitions,\n",
    597     "                         reward=rewards,\n",
    598     "                         discount=gamma,\n",
    599     "                         max_iter=1000)\n",
    600     "\n",
    601     "pi.run()\n",
    602     "f'# Iterations: {pi.iter:,d} | Time: {pi.time:.4f}'"
    603    ]
    604   },
    605   {
    606    "cell_type": "markdown",
    607    "metadata": {},
    608    "source": [
    609     "It also yields the same policy. In general, the policy can stop changing before the value function reaches its optimal values; in this run, the values also match the value iteration result up to floating-point noise in the terminal states."
    610    ]
    611   },
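Here is a minimal NumPy sketch of textbook policy iteration. Each round evaluates the current policy exactly by solving $V = r_\pi + \gamma P_\pi V$ and then improves it greedily, stopping once the policy no longer changes. This is not `pymdptoolbox`'s implementation, and the three-state MDP is made up for illustration:

```python
import numpy as np

def policy_iteration(P, R, gamma=0.9, max_iter=100):
    """Policy iteration for P, R of shape (A, S, S); returns (V, policy, iterations)."""
    _, n_states, _ = P.shape
    idx = np.arange(n_states)
    expected_r = np.sum(P * R, axis=2)  # r(s, a), stored with shape (A, S)
    policy = np.zeros(n_states, dtype=int)
    for it in range(1, max_iter + 1):
        # Policy evaluation: solve (I - gamma * P_pi) V = r_pi exactly
        P_pi = P[policy, idx]           # row s is P(. | s, policy[s]), shape (S, S)
        r_pi = expected_r[policy, idx]  # shape (S,)
        V = np.linalg.solve(np.eye(n_states) - gamma * P_pi, r_pi)
        # Policy improvement: act greedily with respect to V
        Q = expected_r + gamma * (P @ V)  # shape (A, S)
        new_policy = Q.argmax(axis=0)
        if np.array_equal(new_policy, policy):
            return V, policy, it
        policy = new_policy
    return V, policy, max_iter

# Hypothetical 3-state chain: action 0 = "wait" (stay, reward 0.05 per step),
# action 1 = "advance" (move one state right); reaching absorbing state 2 pays 1.
P_chain = np.zeros((2, 3, 3))
R_chain = np.zeros((2, 3, 3))
for s in range(2):
    P_chain[0, s, s] = 1.0
    R_chain[0, s, s] = 0.05
    P_chain[1, s, s + 1] = 1.0
R_chain[1, 1, 2] = 1.0
P_chain[:, 2, 2] = 1.0  # absorbing terminal state with zero reward

V_pi, policy_pi, n_iter = policy_iteration(P_chain, R_chain, gamma=0.9)
print(V_pi, policy_pi, n_iter)
```

On this chain, the optimal policy (always advance) emerges after three rounds of evaluation and improvement. Each round costs a linear solve, which is why policy iteration typically needs fewer but more expensive iterations than value iteration.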
    612   {
    613    "cell_type": "code",
    614    "execution_count": 21,
    615    "metadata": {},
    616    "outputs": [
    617     {
    618      "data": {
    619       "text/html": [
    620        "<div>\n",
    621        "<style scoped>\n",
    622        "    .dataframe tbody tr th:only-of-type {\n",
    623        "        vertical-align: middle;\n",
    624        "    }\n",
    625        "\n",
    626        "    .dataframe tbody tr th {\n",
    627        "        vertical-align: top;\n",
    628        "    }\n",
    629        "\n",
    630        "    .dataframe thead th {\n",
    631        "        text-align: right;\n",
    632        "    }\n",
    633        "</style>\n",
    634        "<table border=\"1\" class=\"dataframe\">\n",
    635        "  <thead>\n",
    636        "    <tr style=\"text-align: right;\">\n",
    637        "      <th></th>\n",
    638        "      <th>0</th>\n",
    639        "      <th>1</th>\n",
    640        "      <th>2</th>\n",
    641        "      <th>3</th>\n",
    642        "    </tr>\n",
    643        "  </thead>\n",
    644        "  <tbody>\n",
    645        "    <tr>\n",
    646        "      <th>0</th>\n",
    647        "      <td>R</td>\n",
    648        "      <td>R</td>\n",
    649        "      <td>R</td>\n",
    650        "      <td>L</td>\n",
    651        "    </tr>\n",
    652        "    <tr>\n",
    653        "      <th>1</th>\n",
    654        "      <td>U</td>\n",
    655        "      <td>L</td>\n",
    656        "      <td>U</td>\n",
    657        "      <td>L</td>\n",
    658        "    </tr>\n",
    659        "    <tr>\n",
    660        "      <th>2</th>\n",
    661        "      <td>U</td>\n",
    662        "      <td>L</td>\n",
    663        "      <td>L</td>\n",
    664        "      <td>L</td>\n",
    665        "    </tr>\n",
    666        "  </tbody>\n",
    667        "</table>\n",
    668        "</div>"
    669       ],
    670       "text/plain": [
    671        "   0  1  2  3\n",
    672        "0  R  R  R  L\n",
    673        "1  U  L  U  L\n",
    674        "2  U  L  L  L"
    675       ]
    676      },
    677      "execution_count": 21,
    678      "metadata": {},
    679      "output_type": "execute_result"
    680     }
    681    ],
    682    "source": [
    683     "policy = np.asarray([actions[i] for i in pi.policy])\n",
    684     "pd.DataFrame(policy.reshape(grid_size))"
    685    ]
    686   },
    687   {
    688    "cell_type": "code",
    689    "execution_count": 22,
    690    "metadata": {},
    691    "outputs": [
    692     {
    693      "data": {
    694       "text/html": [
    695        "<div>\n",
    696        "<style scoped>\n",
    697        "    .dataframe tbody tr th:only-of-type {\n",
    698        "        vertical-align: middle;\n",
    699        "    }\n",
    700        "\n",
    701        "    .dataframe tbody tr th {\n",
    702        "        vertical-align: top;\n",
    703        "    }\n",
    704        "\n",
    705        "    .dataframe thead th {\n",
    706        "        text-align: right;\n",
    707        "    }\n",
    708        "</style>\n",
    709        "<table border=\"1\" class=\"dataframe\">\n",
    710        "  <thead>\n",
    711        "    <tr style=\"text-align: right;\">\n",
    712        "      <th></th>\n",
    713        "      <th>0</th>\n",
    714        "      <th>1</th>\n",
    715        "      <th>2</th>\n",
    716        "      <th>3</th>\n",
    717        "    </tr>\n",
    718        "  </thead>\n",
    719        "  <tbody>\n",
    720        "    <tr>\n",
    721        "      <th>0</th>\n",
    722        "      <td>0.884143</td>\n",
    723        "      <td>0.925054</td>\n",
    724        "      <td>0.961986</td>\n",
    725        "      <td>1.594721e-16</td>\n",
    726        "    </tr>\n",
    727        "    <tr>\n",
    728        "      <th>1</th>\n",
    729        "      <td>0.848181</td>\n",
    730        "      <td>0.000000</td>\n",
    731        "      <td>0.714643</td>\n",
    732        "      <td>-0.000000e+00</td>\n",
    733        "    </tr>\n",
    734        "    <tr>\n",
    735        "      <th>2</th>\n",
    736        "      <td>0.808345</td>\n",
    737        "      <td>0.773328</td>\n",
    738        "      <td>0.736099</td>\n",
    739        "      <td>5.160828e-01</td>\n",
    740        "    </tr>\n",
    741        "  </tbody>\n",
    742        "</table>\n",
    743        "</div>"
    744       ],
    745       "text/plain": [
    746        "          0         1         2             3\n",
    747        "0  0.884143  0.925054  0.961986  1.594721e-16\n",
    748        "1  0.848181  0.000000  0.714643 -0.000000e+00\n",
    749        "2  0.808345  0.773328  0.736099  5.160828e-01"
    750       ]
    751      },
    752      "execution_count": 22,
    753      "metadata": {},
    754      "output_type": "execute_result"
    755     }
    756    ],
    757    "source": [
    758     "value = np.asarray(pi.V).reshape(grid_size)\n",
    759     "pd.DataFrame(value)"
    760    ]
    761   },
    762   {
    763    "cell_type": "markdown",
    764    "metadata": {},
    765    "source": [
    766     "## Value Iteration"
    767    ]
    768   },
    769   {
    770    "cell_type": "code",
    771    "execution_count": 26,
    772    "metadata": {},
    773    "outputs": [],
    774    "source": [
    775     "skip_states = list(absorbing_states.keys())+[blocked_state]\n",
    776     "states_to_update = [s for s in states if s not in skip_states]"
    777    ]
    778   },
    779   {
    780    "cell_type": "markdown",
    781    "metadata": {},
    782    "source": [
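    783     "After excluding the absorbing and blocked states from the updates, we initialize the value function with random values (zero for the excluded states) and set the discount factor gamma and the convergence threshold epsilon:"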
    784    ]
    785   },
    786   {
    787    "cell_type": "code",
    788    "execution_count": 27,
    789    "metadata": {},
    790    "outputs": [],
    791    "source": [
    792     "V = np.random.rand(num_states)\n",
    793     "V[skip_states] = 0"
    794    ]
    795   },
    796   {
    797    "cell_type": "code",
    798    "execution_count": 28,
    799    "metadata": {},
    800    "outputs": [],
    801    "source": [
    802     "gamma = .99\n",
    803     "epsilon = 1e-5"
    804    ]
    805   },
    806   {
    807    "cell_type": "markdown",
    808    "metadata": {},
    809    "source": [
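    810     "The algorithm updates the value function using the Bellman optimality equation and terminates when the L1 norm of the change in V over a full sweep falls below epsilon. Because V is updated in place, states later in a sweep already use the new values of states updated earlier in the same sweep:"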
    811    ]
    812   },
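Written out, each sweep applies the following update to every non-terminal state, and the loop stops under the rule on the right. In the in-place loop, some values on the right-hand side have already been updated earlier in the same sweep:

$$
V_{k+1}(s) = \max_{a} \sum_{s^\prime} P(s^\prime \mid s, a)\,\big[R(s, a, s^\prime) + \gamma\, V_k(s^\prime)\big],
\qquad \text{stop once } \sum_{s} \big|V_{k+1}(s) - V_k(s)\big| < \epsilon
$$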
    813   {
    814    "cell_type": "code",
    815    "execution_count": 29,
    816    "metadata": {},
    817    "outputs": [
    818     {
    819      "data": {
    820       "text/plain": [
    821        "'# Iterations 17 | Time 0.0037'"
    822       ]
    823      },
    824      "execution_count": 29,
    825      "metadata": {},
    826      "output_type": "execute_result"
    827     }
    828    ],
    829    "source": [
    830     "iterations = 0\n",
    831     "start = process_time()\n",
    832     "converged = False\n",
    833     "while not converged:\n",
    834     "    V_ = np.copy(V)\n",
    835     "    for state in states_to_update:\n",
    836     "        q_sa = np.sum(transitions[:, state] * (rewards[:, state] + gamma* V), axis=1)\n",
    837     "        V[state] = np.max(q_sa)\n",
    838     "    if np.sum(np.fabs(V - V_)) < epsilon:\n",
    839     "        converged = True\n",
    840     "\n",
    841     "    iterations += 1\n",
    842     "    if iterations % 1000 == 0:\n",
    843     "        print(np.sum(np.fabs(V - V_)))\n",
    844     "\n",
    845     "f'# Iterations {iterations} | Time {process_time() - start:.4f}'"
    846    ]
    847   },
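Updating `V` in place lets later states in a sweep use values that were already updated earlier in the same sweep. The following hedged sketch illustrates the effect on a made-up five-state chain, not on the grid world. It compares the in-place (Gauss-Seidel) sweep with a synchronous (Jacobi) sweep that only uses the previous sweep's values. The function name and the chain are hypothetical:

```python
import numpy as np

def sweep_value_iteration(P, R, discount, states, eps=1e-10, in_place=True):
    """Value iteration over `states`; returns (V, number of sweeps).

    in_place=True  -> Gauss-Seidel: updates are visible within the same sweep
    in_place=False -> Jacobi: each sweep only uses the previous sweep's values
    """
    V = np.zeros(P.shape[1])
    sweeps = 0
    while True:
        V_old = V.copy()
        source = V if in_place else V_old  # values used on the right-hand side
        for s in states:
            q_sa = np.sum(P[:, s] * (R[:, s] + discount * source), axis=1)
            V[s] = q_sa.max()
        sweeps += 1
        if np.abs(V - V_old).sum() < eps:
            return V, sweeps

# Hypothetical 5-state chain: action 0 moves one step left, action 1 stays put.
# State 0 is absorbing; entering it from state 1 pays a reward of 1.
n = 5
P_chain = np.zeros((2, n, n))
R_chain = np.zeros((2, n, n))
for s in range(1, n):
    P_chain[0, s, s - 1] = 1.0
    P_chain[1, s, s] = 1.0
P_chain[:, 0, 0] = 1.0
R_chain[0, 1, 0] = 1.0
chain_states = list(range(1, n))  # ascending sweep order, terminal state skipped

V_gs, sweeps_gs = sweep_value_iteration(P_chain, R_chain, 0.9, chain_states, in_place=True)
V_j, sweeps_j = sweep_value_iteration(P_chain, R_chain, 0.9, chain_states, in_place=False)
print(sweeps_gs, sweeps_j)  # 2 5
print(V_gs)
```

This chain is swept in the direction in which the reward propagates, so the in-place variant converges in two sweeps instead of five. The size of the gain depends on the sweep order and is not guaranteed in general.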
    848   {
    849    "cell_type": "markdown",
    850    "metadata": {},
    851    "source": [
    852     "### Value Function"
    853    ]
    854   },
    855   {
    856    "cell_type": "code",
    857    "execution_count": 30,
    858    "metadata": {},
    859    "outputs": [
    860     {
    861      "name": "stdout",
    862      "output_type": "stream",
    863      "text": [
    864       "          0         1         2         3\n",
    865       "0  0.884143  0.925054  0.961986  0.000000\n",
    866       "1  0.848181  0.000000  0.714643  0.000000\n",
    867       "2  0.808344  0.773327  0.736099  0.516082\n"
    868      ]
    869     }
    870    ],
    871    "source": [
    872     "print(pd.DataFrame(V.reshape(grid_size)))"
    873    ]
    874   },
    875   {
    876    "cell_type": "code",
    877    "execution_count": 31,
    878    "metadata": {},
    879    "outputs": [
    880     {
    881      "data": {
    882       "text/plain": [
    883        "True"
    884       ]
    885      },
    886      "execution_count": 31,
    887      "metadata": {},
    888      "output_type": "execute_result"
    889     }
    890    ],
    891    "source": [
    892     "np.allclose(V.reshape(grid_size), np.asarray(vi.V).reshape(grid_size))"
    893    ]
    894   },
    895   {
    896    "cell_type": "markdown",
    897    "metadata": {},
    898    "source": [
    899     "### Optimal Policy"
    900    ]
    901   },
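Once the values have converged, the optimal policy is the greedy policy with respect to them. It is computed from the same transition and reward arrays as the Bellman update:

$$
\pi^*(s) = \arg\max_{a} \sum_{s^\prime} P(s^\prime \mid s, a)\,\big[R(s, a, s^\prime) + \gamma\, V^*(s^\prime)\big]
$$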
    902   {
    903    "cell_type": "code",
    904    "execution_count": 35,
    905    "metadata": {
    906     "scrolled": true
    907    },
    908    "outputs": [
    909     {
    910      "data": {
    911       "text/plain": [
    912        "array([2, 2, 2, 0, 1, 0, 1, 0, 1, 0, 0, 0])"
    913       ]
    914      },
    915      "execution_count": 35,
    916      "metadata": {},
    917      "output_type": "execute_result"
    918     }
    919    ],
    920    "source": [
    921     "# Greedy policy: maximize sum_s' P(s'|s, a) * (R(s, a, s') + gamma * V(s')) over actions a\n",
    922     "# (terminal payoffs are already stored in `rewards`, so V remains 0 in terminal states)\n",
    923     "q_values = np.sum(transitions * (rewards + gamma * V), axis=2)  # shape (A, S)\n",
    924     "policy = np.argmax(q_values, axis=0)\n",
    925     "policy"
    926    ]
    927   },
    928   {
    929    "cell_type": "code",
    930    "execution_count": 36,
    931    "metadata": {},
    932    "outputs": [
    933     {
    934      "data": {
    935       "text/html": [
    936        "<div>\n",
    937        "<style scoped>\n",
    938        "    .dataframe tbody tr th:only-of-type {\n",
    939        "        vertical-align: middle;\n",
    940        "    }\n",
    941        "\n",
    942        "    .dataframe tbody tr th {\n",
    943        "        vertical-align: top;\n",
    944        "    }\n",
    945        "\n",
    946        "    .dataframe thead th {\n",
    947        "        text-align: right;\n",
    948        "    }\n",
    949        "</style>\n",
    950        "<table border=\"1\" class=\"dataframe\">\n",
    951        "  <thead>\n",
    952        "    <tr style=\"text-align: right;\">\n",
    953        "      <th></th>\n",
    954        "      <th>0</th>\n",
    955        "      <th>1</th>\n",
    956        "      <th>2</th>\n",
    957        "      <th>3</th>\n",
    958        "    </tr>\n",
    959        "  </thead>\n",
    960        "  <tbody>\n",
    961        "    <tr>\n",
    962        "      <th>0</th>\n",
    963        "      <td>R</td>\n",
    964        "      <td>R</td>\n",
    965        "      <td>R</td>\n",
    966        "      <td>L</td>\n",
    967        "    </tr>\n",
    968        "    <tr>\n",
    969        "      <th>1</th>\n",
    970        "      <td>U</td>\n",
    971        "      <td>L</td>\n",
    972        "      <td>U</td>\n",
    973        "      <td>L</td>\n",
    974        "    </tr>\n",
    975        "    <tr>\n",
    976        "      <th>2</th>\n",
    977        "      <td>U</td>\n",
    978        "      <td>L</td>\n",
    979        "      <td>L</td>\n",
    980        "      <td>L</td>\n",
    981        "    </tr>\n",
    982        "  </tbody>\n",
    983        "</table>\n",
    984        "</div>"
    985       ],
    986       "text/plain": [
    987        "   0  1  2  3\n",
    988        "0  R  R  R  L\n",
    989        "1  U  L  U  L\n",
    990        "2  U  L  L  L"
    991       ]
    992      },
    993      "execution_count": 36,
    994      "metadata": {},
    995      "output_type": "execute_result"
    996     }
    997    ],
    998    "source": [
    999     "pd.DataFrame(policy.reshape(grid_size)).replace(dict(enumerate(actions)))"
   1000    ]
   1001   },
   1002   {
   1003    "cell_type": "markdown",
   1004    "metadata": {},
   1005    "source": [
   1006     "## Policy Iteration"
   1007    ]
   1008   },
   1009   {
   1010    "cell_type": "markdown",
   1011    "metadata": {},
   1012    "source": [
    1013     "Policy iteration involves separate evaluation and improvement steps. We define the improvement step as selecting, in each state, the action that maximizes the expected value of the next state, that is, the sum of next-state values weighted by their transition probabilities. Note that we temporarily fill in the rewards of the terminal states as their values so that actions leading there are not ignored:"
   1014    ]
   1015   },
   1016   {
   1017    "cell_type": "code",
   1018    "execution_count": 41,
   1019    "metadata": {},
   1020    "outputs": [],
   1021    "source": [
    1022     "def policy_improvement(value, transitions):  # note: overwrites the terminal entries of value in place\n",
    1023     "    for state, reward in absorbing_states.items():\n",
    1024     "        value[state] = reward\n",
    1025     "    return np.argmax(np.sum(transitions * value, axis=2), axis=0)  # best action per state"
   1026    ]
   1027   },
   1028   {
   1029    "cell_type": "code",
   1030    "execution_count": null,
   1031    "metadata": {},
   1032    "outputs": [],
   1033    "source": [
   1034     "V = np.random.rand(num_states)\n",
   1035     "V[skip_states] = 0\n",
    1036     "pi = np.random.choice(num_actions, size=num_states)"
   1037    ]
   1038   },
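  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Policy evaluation does not have to be iterative. For a fixed policy, the state values satisfy `V = r_pi + gamma * P_pi @ V`, where `P_pi` holds the transition probabilities under the policy and `r_pi` the expected immediate reward, which corresponds to `np.dot(transitions[action, state], rewards[action, state])` in the update used by the policy iteration loop. This is a linear system that can be solved directly. The minimal sketch below assumes a hypothetical 3-state chain with one absorbing state (not the GridWorld; `P_pi`, `r_pi` and `gamma_toy` are illustrative) and checks that the direct solution matches repeated Bellman backups:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy example, not the GridWorld: 3 states, state 2 is absorbing\n",
    "gamma_toy = 0.9\n",
    "# P_pi[s, s_next]: transition probabilities under a fixed policy\n",
    "P_pi = np.array([[0.1, 0.8, 0.1],\n",
    "                 [0.0, 0.2, 0.8],\n",
    "                 [0.0, 0.0, 1.0]])\n",
    "# r_pi[s]: expected immediate reward under the policy; the absorbing state earns nothing\n",
    "r_pi = np.array([-0.02, -0.02, 0.0])\n",
    "\n",
    "# Exact evaluation: solve (I - gamma * P_pi) V = r_pi\n",
    "V_exact = np.linalg.solve(np.eye(3) - gamma_toy * P_pi, r_pi)\n",
    "\n",
    "# Iterative evaluation: repeat the Bellman expectation backup until the values settle\n",
    "V_iter = np.zeros(3)\n",
    "for _ in range(1000):\n",
    "    V_new = r_pi + gamma_toy * P_pi @ V_iter\n",
    "    converged_toy = np.max(np.abs(V_new - V_iter)) < 1e-10\n",
    "    V_iter = V_new\n",
    "    if converged_toy:\n",
    "        break\n",
    "\n",
    "print(V_exact, V_iter)"
   ]
  },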
   1039   {
   1040    "cell_type": "markdown",
   1041    "metadata": {},
   1042    "source": [
    1043     "The algorithm alternates between policy evaluation, here a single sweep of Bellman updates over the states in `states_to_update`, and greedy policy improvement until the policy stops changing:"
   1044    ]
   1045   },
   1046   {
   1047    "cell_type": "code",
   1048    "execution_count": 42,
   1049    "metadata": {},
   1050    "outputs": [
   1051     {
   1052      "data": {
   1053       "text/plain": [
   1054        "'# Iterations 3 | Time 0.0059'"
   1055       ]
   1056      },
   1057      "execution_count": 42,
   1058      "metadata": {},
   1059      "output_type": "execute_result"
   1060     }
   1061    ],
   1062    "source": [
   1063     "iterations = 0\n",
   1064     "start = process_time()\n",
   1065     "converged = False\n",
   1066     "while not converged:\n",
   1067     "    pi_ = np.copy(pi)\n",
    1068     "    for state in states_to_update:  # evaluation sweep for the current policy pi\n",
    1069     "        action = pi[state]\n",
    1070     "        V[state] = np.dot(transitions[action, state], rewards[action, state] + gamma * V)\n",
    1071     "    pi = policy_improvement(V.copy(), transitions)  # copy so the terminal rewards don't overwrite V\n",
   1072     "    if np.array_equal(pi_, pi):\n",
   1073     "        converged = True\n",
   1074     "    iterations += 1\n",
   1075     "\n",
   1076     "f'# Iterations {iterations} | Time {process_time() - start:.4f}'"
   1077    ]
   1078   },
   1079   {
   1080    "cell_type": "markdown",
   1081    "metadata": {},
   1082    "source": [
    1083     "Policy iteration converges after only three iterations. However, the policy stabilizes before the algorithm has found the optimal value function, and the resulting policy differs slightly from the value iteration result: most notably, it suggests moving up instead of the safer left in the field next to the negative terminal state. Tightening the convergence criteria can avoid this, for example, by requiring the policy to remain stable for several rounds or by adding a threshold on the change in the value function."
   1084    ]
   1085   },
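  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of such a tightened stopping rule, the hypothetical `policy_iteration` helper below (not part of the original notebook) stops only when the policy has been unchanged for `patience` consecutive rounds and the last evaluation sweep moved no state value by more than `tol`. It runs on a randomly generated MDP rather than the GridWorld, and, unlike `policy_improvement`, its greedy step includes the rewards and the discount factor. With a single evaluation sweep per round, this variant behaves much like value iteration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def policy_iteration(P, R, gamma=0.9, patience=3, tol=1e-8, max_iter=10000):\n",
    "    # Hypothetical helper, not from the original notebook.\n",
    "    # P[a, s, s_next]: transition probabilities, R[a, s, s_next]: rewards.\n",
    "    # Stops only when the policy has been unchanged for `patience` consecutive\n",
    "    # rounds AND the last evaluation sweep moved no value by more than `tol`.\n",
    "    n_states = P.shape[1]\n",
    "    states = np.arange(n_states)\n",
    "    V = np.zeros(n_states)\n",
    "    pi = np.zeros(n_states, dtype=int)\n",
    "    stable_rounds = 0\n",
    "    for iteration in range(1, max_iter + 1):\n",
    "        # Evaluation: one synchronous Bellman backup under the current policy\n",
    "        V_new = np.sum(P[pi, states] * (R[pi, states] + gamma * V), axis=1)\n",
    "        delta = np.max(np.abs(V_new - V))\n",
    "        V = V_new\n",
    "        # Improvement: greedy action for expected reward plus discounted next-state value\n",
    "        pi_new = np.argmax(np.sum(P * (R + gamma * V), axis=2), axis=0)\n",
    "        stable_rounds = stable_rounds + 1 if np.array_equal(pi_new, pi) else 0\n",
    "        pi = pi_new\n",
    "        if stable_rounds >= patience and delta < tol:\n",
    "            break\n",
    "    return pi, V, iteration\n",
    "\n",
    "# Hypothetical random MDP for illustration: 3 actions, 5 states\n",
    "rng_mdp = np.random.RandomState(42)\n",
    "P_rand = rng_mdp.rand(3, 5, 5)\n",
    "P_rand /= P_rand.sum(axis=2, keepdims=True)\n",
    "R_rand = rng_mdp.randn(3, 5, 5)\n",
    "\n",
    "pi_rand, V_rand, n_rounds = policy_iteration(P_rand, R_rand)\n",
    "print(pi_rand, n_rounds)"
   ]
  },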
   1086   {
   1087    "cell_type": "code",
   1088    "execution_count": 43,
   1089    "metadata": {},
   1090    "outputs": [
   1091     {
   1092      "data": {
   1093       "text/html": [
   1094        "<div>\n",
   1095        "<style scoped>\n",
   1096        "    .dataframe tbody tr th:only-of-type {\n",
   1097        "        vertical-align: middle;\n",
   1098        "    }\n",
   1099        "\n",
   1100        "    .dataframe tbody tr th {\n",
   1101        "        vertical-align: top;\n",
   1102        "    }\n",
   1103        "\n",
   1104        "    .dataframe thead th {\n",
   1105        "        text-align: right;\n",
   1106        "    }\n",
   1107        "</style>\n",
   1108        "<table border=\"1\" class=\"dataframe\">\n",
   1109        "  <thead>\n",
   1110        "    <tr style=\"text-align: right;\">\n",
   1111        "      <th></th>\n",
   1112        "      <th>0</th>\n",
   1113        "      <th>1</th>\n",
   1114        "      <th>2</th>\n",
   1115        "      <th>3</th>\n",
   1116        "    </tr>\n",
   1117        "  </thead>\n",
   1118        "  <tbody>\n",
   1119        "    <tr>\n",
   1120        "      <th>0</th>\n",
   1121        "      <td>R</td>\n",
   1122        "      <td>R</td>\n",
   1123        "      <td>R</td>\n",
   1124        "      <td>L</td>\n",
   1125        "    </tr>\n",
   1126        "    <tr>\n",
   1127        "      <th>1</th>\n",
   1128        "      <td>U</td>\n",
   1129        "      <td>L</td>\n",
   1130        "      <td>U</td>\n",
   1131        "      <td>L</td>\n",
   1132        "    </tr>\n",
   1133        "    <tr>\n",
   1134        "      <th>2</th>\n",
   1135        "      <td>U</td>\n",
   1136        "      <td>L</td>\n",
   1137        "      <td>L</td>\n",
   1138        "      <td>L</td>\n",
   1139        "    </tr>\n",
   1140        "  </tbody>\n",
   1141        "</table>\n",
   1142        "</div>"
   1143       ],
   1144       "text/plain": [
   1145        "   0  1  2  3\n",
   1146        "0  R  R  R  L\n",
   1147        "1  U  L  U  L\n",
   1148        "2  U  L  L  L"
   1149       ]
   1150      },
   1151      "execution_count": 43,
   1152      "metadata": {},
   1153      "output_type": "execute_result"
   1154     }
   1155    ],
   1156    "source": [
   1157     "pd.DataFrame(pi.reshape(grid_size)).replace(dict(enumerate(actions)))"
   1158    ]
   1159   },
   1160   {
   1161    "cell_type": "code",
   1162    "execution_count": 44,
   1163    "metadata": {
   1164     "scrolled": true
   1165    },
   1166    "outputs": [
   1167     {
   1168      "data": {
   1169       "text/html": [
   1170        "<div>\n",
   1171        "<style scoped>\n",
   1172        "    .dataframe tbody tr th:only-of-type {\n",
   1173        "        vertical-align: middle;\n",
   1174        "    }\n",
   1175        "\n",
   1176        "    .dataframe tbody tr th {\n",
   1177        "        vertical-align: top;\n",
   1178        "    }\n",
   1179        "\n",
   1180        "    .dataframe thead th {\n",
   1181        "        text-align: right;\n",
   1182        "    }\n",
   1183        "</style>\n",
   1184        "<table border=\"1\" class=\"dataframe\">\n",
   1185        "  <thead>\n",
   1186        "    <tr style=\"text-align: right;\">\n",
   1187        "      <th></th>\n",
   1188        "      <th>0</th>\n",
   1189        "      <th>1</th>\n",
   1190        "      <th>2</th>\n",
   1191        "      <th>3</th>\n",
   1192        "    </tr>\n",
   1193        "  </thead>\n",
   1194        "  <tbody>\n",
   1195        "    <tr>\n",
   1196        "      <th>0</th>\n",
   1197        "      <td>0.765729</td>\n",
   1198        "      <td>0.874981</td>\n",
   1199        "      <td>0.923802</td>\n",
   1200        "      <td>0.000000</td>\n",
   1201        "    </tr>\n",
   1202        "    <tr>\n",
   1203        "      <th>1</th>\n",
   1204        "      <td>0.697293</td>\n",
   1205        "      <td>0.000000</td>\n",
   1206        "      <td>0.405803</td>\n",
   1207        "      <td>0.000000</td>\n",
   1208        "    </tr>\n",
   1209        "    <tr>\n",
   1210        "      <th>2</th>\n",
   1211        "      <td>0.625634</td>\n",
   1212        "      <td>0.563798</td>\n",
   1213        "      <td>0.506690</td>\n",
   1214        "      <td>0.305024</td>\n",
   1215        "    </tr>\n",
   1216        "  </tbody>\n",
   1217        "</table>\n",
   1218        "</div>"
   1219       ],
   1220       "text/plain": [
   1221        "          0         1         2         3\n",
   1222        "0  0.765729  0.874981  0.923802  0.000000\n",
   1223        "1  0.697293  0.000000  0.405803  0.000000\n",
   1224        "2  0.625634  0.563798  0.506690  0.305024"
   1225       ]
   1226      },
   1227      "execution_count": 44,
   1228      "metadata": {},
   1229      "output_type": "execute_result"
   1230     }
   1231    ],
   1232    "source": [
   1233     "pd.DataFrame(V.reshape(grid_size))"
   1234    ]
   1235   }
   1236  ],
   1237  "metadata": {
   1238   "kernelspec": {
   1239    "display_name": "Python 3",
   1240    "language": "python",
   1241    "name": "python3"
   1242   },
   1243   "language_info": {
   1244    "codemirror_mode": {
   1245     "name": "ipython",
   1246     "version": 3
   1247    },
   1248    "file_extension": ".py",
   1249    "mimetype": "text/x-python",
   1250    "name": "python",
   1251    "nbconvert_exporter": "python",
   1252    "pygments_lexer": "ipython3",
   1253    "version": "3.6.8"
   1254   },
   1255   "toc": {
   1256    "base_numbering": 1,
   1257    "nav_menu": {},
   1258    "number_sections": true,
   1259    "sideBar": true,
   1260    "skip_h1_title": true,
   1261    "title_cell": "Table of Contents",
   1262    "title_sidebar": "Contents",
   1263    "toc_cell": false,
   1264    "toc_position": {},
   1265    "toc_section_display": true,
   1266    "toc_window_display": true
   1267   }
   1268  },
   1269  "nbformat": 4,
   1270  "nbformat_minor": 2
   1271 }