From b75f0617a22ec9f93f93a55849685b9ee9ac9488 Mon Sep 17 00:00:00 2001 From: CodeWizard Date: Tue, 1 Aug 2023 00:52:11 +0530 Subject: [PATCH] Created using Colaboratory --- .../W3D3_OptimalControl/W3D3_Tutorial1.ipynb | 5150 +++++++++-------- 1 file changed, 2610 insertions(+), 2540 deletions(-) diff --git a/tutorials/W3D3_OptimalControl/W3D3_Tutorial1.ipynb b/tutorials/W3D3_OptimalControl/W3D3_Tutorial1.ipynb index 095180a0d2..786c08cbfd 100644 --- a/tutorials/W3D3_OptimalControl/W3D3_Tutorial1.ipynb +++ b/tutorials/W3D3_OptimalControl/W3D3_Tutorial1.ipynb @@ -1,2541 +1,2611 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "execution": {}, - "id": "view-in-github" - }, - "source": [ - "\"Open   \"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "# Tutorial 1: Optimal Control for Discrete States\n", - "\n", - "**Week 3, Day 3: Optimal Control**\n", - "\n", - "**By Neuromatch Academy**\n", - "\n", - "**Content creators:** Zhengwei Wu, Itzel Olivos Castillo, Shreya Saxena, Xaq Pitkow\n", - "\n", - "**Content reviewers:** Karolina Stosio, Roozbeh Farhoodi, Saeed Salehi, Ella Batty, Spiros Chavlis, Matt Krause, Michael Waskom, Melisa Maidana Capitan\n", - "\n", - "**Production editors:** Spiros Chavlis" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "---\n", - "# Tutorial Objectives\n", - "\n", - "*Estimated timing of tutorial: 60 min*\n", - "\n", - "In this tutorial, we will implement a **binary control** task: a Partially Observable Markov Decision Process (POMDP) that describes fishing. The agent (you) seeks reward from two fishing sites without directly observing where the school of fish is (yes, a group of fish is called a school!). This makes the world a Hidden Markov Model (HMM), just like in the *Hidden Dynamics* day. Based on when and where you catch fish, you keep updating your belief about the fish location, i.e., the posterior of the fish given past observations. You should control your position to get the most fish while minimizing the cost of switching sides.\n", - "\n", - "You've already learned about stochastic dynamics, latent states, and measurements. These first exercises largely repeat your previous work. Now we introduce **actions**, based on the new concepts of **control, utility, and policy**. This general structure provides a foundational model for the brain's computations because it includes a perception-action loop where the animal can gather information, draw inferences about its environment, and select actions with the greatest benefit. *How*, mechanistically, the neurons could actually implement these calculations is a separate question we don't address in this lesson.\n", - "\n", - "In this tutorial, you will:\n", - "* Use the Hidden Markov Models you learned about previously to model the world state.\n", - "* Use the observations (fish caught) to build beliefs (posterior distributions) about the fish location.\n", - "* Evaluate the quality of different control policies for choosing actions.\n", - "* Discover the policy that maximizes utility." 
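The tutorial above describes the world state as a binary switching (telegraph) process. As a quick reference, here is a minimal, self-contained sketch of such dynamics; it is not part of the tutorial's own helper code (which draws from a shared, fixed random sequence defined later), and the function name `simulate_telegraph` and its arguments are illustrative assumptions only.

```python
# Minimal sketch of a telegraph (binary switching) process.
# Not the notebook's own helper; `simulate_telegraph` and its arguments are illustrative.
import numpy as np

def simulate_telegraph(p_stay, T, seed=0, s0=0):
    """Binary latent state that keeps its value with probability p_stay each step."""
    rng = np.random.default_rng(seed)
    s = np.zeros(T, dtype=int)
    s[0] = s0
    for t in range(1, T):
        s[t] = s[t - 1] if rng.random() < p_stay else 1 - s[t - 1]
    return s

# Example: fish that rarely switch sides
# print(simulate_telegraph(p_stay=0.95, T=20))
```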
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {}, - "tags": [] - }, - "outputs": [], - "source": [ - "# @title Tutorial slides\n", - "# @markdown These are the slides for all videos in this tutorial.\n", - "from IPython.display import IFrame\n", - "link_id = \"8j5rs\"\n", - "print(f\"If you want to download the slides: https://osf.io/download/{link_id}/\")\n", - "IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/{link_id}/?direct%26mode=render%26action=download%26mode=render\", width=854, height=480)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "---\n", - "## Setup\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {} - }, - "outputs": [], - "source": [ - "# @title Install and import feedback gadget\n", - "\n", - "!pip3 install vibecheck datatops --quiet\n", - "\n", - "from vibecheck import DatatopsContentReviewContainer\n", - "def content_review(notebook_section: str):\n", - " return DatatopsContentReviewContainer(\n", - " \"\", # No text prompt\n", - " notebook_section,\n", - " {\n", - " \"url\": \"https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab\",\n", - " \"name\": \"neuromatch_cn\",\n", - " \"user_key\": \"y1x3mpx5\",\n", - " },\n", - " ).render()\n", - "\n", - "\n", - "feedback_prefix = \"W3D3_T1\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": {}, - "tags": [] - }, - "outputs": [], - "source": [ - "# Imports\n", - "import numpy as np\n", - "from math import isclose\n", - "import matplotlib.pyplot as plt" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {}, - "tags": [] - }, - "outputs": [], - "source": [ - "# @title Figure Settings\n", - "import logging\n", - "logging.getLogger('matplotlib.font_manager').disabled = True\n", - "\n", - "import ipywidgets as widgets\n", - "from IPython.display import HTML\n", - "%config InlineBackend.figure_format = 'retina'\n", - "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {}, - "tags": [] - }, - "outputs": [], - "source": [ - "# @title Plotting Functions\n", - "\n", - "def plot_fish(fish_state, ax=None, show=True):\n", - " \"\"\"\n", - " Plot the fish dynamics (states across time)\n", - " \"\"\"\n", - " T = len(fish_state)\n", - "\n", - " offset = 3\n", - "\n", - " if not ax:\n", - " fig, ax = plt.subplots(1, 1, figsize=(12, 3.5))\n", - "\n", - " x = np.arange(0, T, 1)\n", - " y = offset * (fish_state*2 - 1)\n", - "\n", - " ax.plot(y, color='cornflowerblue', markersize=10, linewidth=3.0, zorder=0)\n", - " ax.fill_between(x, y, color='cornflowerblue', alpha=.3)\n", - "\n", - " ax.set_xlabel('time')\n", - " ax.set_ylabel('fish location')\n", - "\n", - " ax.set_xlim([0, T])\n", - " ax.set_xticks([])\n", - " ax.xaxis.set_label_coords(1.05, .54)\n", - "\n", - " ax.set_ylim([-(offset+.5), offset+.5])\n", - " ax.set_yticks([-offset, offset])\n", - " ax.set_yticklabels(['left', 'right'])\n", - "\n", - " ax.spines['bottom'].set_position('center')\n", - " if show:\n", - " plt.show()\n", - "\n", - "\n", - "def plot_measurement(measurement, ax=None, show=True):\n", - " \"\"\"\n", - " Plot the measurements\n", - " \"\"\"\n", - " T = len(measurement)\n", - 
"\n", - " rel_pos = 3\n", - " red_y = []\n", - " blue_y = []\n", - " for idx, value in enumerate(measurement):\n", - " if value == 0:\n", - " blue_y.append([idx, -rel_pos])\n", - " else:\n", - " red_y.append([idx, rel_pos])\n", - "\n", - " red_y = np.asarray(red_y)\n", - " blue_y = np.asarray(blue_y)\n", - "\n", - " if not ax:\n", - " fig, ax = plt.subplots(1, 1, figsize=(12, 3.5))\n", - "\n", - " if len(red_y) > 0:\n", - " ax.plot(red_y[:, 0], red_y[:, 1], '*', markersize=8, color='crimson')\n", - "\n", - " if len(blue_y) > 0:\n", - " ax.plot(blue_y[:, 0], blue_y[:, 1], '*', markersize=8, color='royalblue')\n", - "\n", - " ax.set_xlabel('time', fontsize=18)\n", - " ax.set_ylabel('Caught fish?')\n", - "\n", - " ax.set_xlim([0, T])\n", - " ax.set_xticks([])\n", - " ax.xaxis.set_label_coords(1.05, .54)\n", - "\n", - " ax.set_ylim([-rel_pos - .5, rel_pos + .5])\n", - " ax.set_yticks([-rel_pos, rel_pos])\n", - " ax.set_yticklabels(['no', 'yes!'])\n", - "\n", - " ax.spines['bottom'].set_position('center')\n", - " if show:\n", - " plt.show()\n", - "\n", - "\n", - "def plot_act_loc(loc, act, ax_loc=None, show=True):\n", - " \"\"\"\n", - " Plot the action and location of T time points\n", - " \"\"\"\n", - " T = len(act)\n", - "\n", - " if not ax_loc:\n", - " fig, ax_loc = plt.subplots(1, 1, figsize=(12, 2.5))\n", - "\n", - " loc = loc*2 - 1\n", - " act_down = []\n", - " act_up = []\n", - " for t in range(1, T):\n", - " if loc[t-1] == -1 and loc[t] == 1:\n", - " act_up.append([t - 0.5, 0])\n", - " if loc[t-1] == 1 and loc[t] == -1:\n", - " act_down.append([t - 0.5, 0])\n", - "\n", - " act_down = np.array(act_down)\n", - " act_up = np.array(act_up)\n", - "\n", - " ax_loc.plot(loc, 'g.-', markersize=8, linewidth=5)\n", - "\n", - " if len(act_down) > 0:\n", - " ax_loc.plot(act_down[:, 0], act_down[:, 1], 'rv', markersize=18, zorder=10, label='switch')\n", - "\n", - " if len(act_up) > 0:\n", - " ax_loc.plot(act_up[:, 0], act_up[:, 1], 'r^', markersize=18, zorder=10)\n", - "\n", - " ax_loc.set_xlabel('time')\n", - " ax_loc.set_ylabel('Your state')\n", - "\n", - " ax_loc.set_xlim([0, T])\n", - " ax_loc.set_xticks([])\n", - " ax_loc.xaxis.set_label_coords(1.05, .54)\n", - "\n", - " if len(act_down) > 0:\n", - " ax_loc.legend(loc=\"upper right\")\n", - " elif len(act_down) == 0 and len(act_up) > 0:\n", - " ax_loc.plot(act_up[:, 0], act_up[:, 1], 'r^', markersize=18, zorder=10, label='switch')\n", - " ax_loc.legend(loc=\"upper right\")\n", - "\n", - " ax_loc.set_ylim([-1.1, 1.1])\n", - " ax_loc.set_yticks([-1, 1])\n", - "\n", - " ax_loc.tick_params(axis='both', which='major')\n", - " ax_loc.set_yticklabels(['left', 'right'])\n", - "\n", - " ax_loc.spines['bottom'].set_position('center')\n", - "\n", - " if show:\n", - " plt.show()\n", - "\n", - "\n", - "def plot_belief(belief, ax1=None, choose_policy=None, show=True):\n", - " \"\"\"\n", - " Plot the belief dynamics of T time points\n", - " \"\"\"\n", - "\n", - " T = belief.shape[1]\n", - "\n", - " if not ax1:\n", - " fig, ax1 = plt.subplots(1, 1, figsize=(12, 2.5))\n", - "\n", - " ax1.plot(belief[1, :], color='midnightblue', markersize=10, linewidth=3.0)\n", - "\n", - " ax1.set_xlabel('time')\n", - " ax1.set_ylabel('Belief (right)')\n", - "\n", - " ax1.set_xlim([0, T])\n", - " ax1.set_xticks([])\n", - " ax1.xaxis.set_label_coords(1.05, 0.05)\n", - "\n", - " ax1.set_yticks([0, 1])\n", - " ax1.set_ylim([0, 1.1])\n", - "\n", - " labels = [item.get_text() for item in ax1.get_yticklabels()]\n", - " ax1.set_yticklabels([' 0', ' 1'])\n", - "\n", - " \"\"\"\n", - " 
if choose_policy == \"threshold\":\n", - " ax2 = ax1.twinx()\n", - " ax2.plot(time_range, threshold * np.ones(time_range.shape), 'r--')\n", - " ax2.plot(time_range, (1 - threshold) * np.ones(time_range.shape), 'c--')\n", - " ax2.set_yticks([threshold, 1 - threshold])\n", - " ax2.set_ylim([0, 1.1])\n", - " ax2.tick_params(axis='both', which='major', labelsize=18)\n", - " labels = [item.get_text() for item in ax2.get_yticklabels()]\n", - " labels[0] = 'threshold to switch \\n from left to right'\n", - " labels[-1] = 'threshold to switch \\n from right to left'\n", - " ax2.set_yticklabels(labels)\n", - " \"\"\"\n", - " if show:\n", - " plt.show()\n", - "\n", - "\n", - "def plot_dynamics(belief, loc, act, meas, fish_state, choose_policy):\n", - " \"\"\"\n", - " Plot the dynamics of T time points\n", - " \"\"\"\n", - " if choose_policy == 'threshold':\n", - " fig, [ax0, ax_bel, ax_loc, ax1] = plt.subplots(4, 1, figsize=(12, 9))\n", - " plot_fish(fish_state, ax=ax0, show=False)\n", - " plot_belief(belief, ax1=ax_bel, show=False)\n", - " plot_measurement(meas, ax=ax1, show=False)\n", - " plot_act_loc(loc, act, ax_loc=ax_loc)\n", - " else:\n", - " fig, [ax0, ax_bel, ax1] = plt.subplots(3, 1, figsize=(12, 7))\n", - " plot_fish(fish_state, ax=ax0, show=False)\n", - " plot_belief(belief, ax1=ax_bel, show=False)\n", - " plot_measurement(meas, ax=ax1, show=False)\n", - "\n", - " plt.tight_layout()\n", - " plt.show()\n", - "\n", - "\n", - "def belief_histogram(belief, bins=100):\n", - " \"\"\"\n", - " Plot the histogram of belief states\n", - " \"\"\"\n", - " fig, ax = plt.subplots(1, 1, figsize=(8, 6))\n", - " ax.hist(belief, bins)\n", - " ax.set_xlabel('belief', fontsize=18)\n", - " ax.set_ylabel('count', fontsize=18)\n", - " plt.show()\n", - "\n", - "\n", - "def plot_value_threshold(threshold_array, value_array):\n", - " \"\"\"\n", - " Helper function to plot the value function and threshold\n", - " \"\"\"\n", - " yrange = np.max(value_array) - np.min(value_array)\n", - " star_loc = np.argmax(value_array)\n", - "\n", - " fig_, ax = plt.subplots(1, 1, figsize=(8, 6))\n", - " ax.plot(threshold_array, value_array, 'b')\n", - " ax.vlines(threshold_array[star_loc],\n", - " min(value_array) - yrange * .1, max(value_array),\n", - " colors='red', ls='--')\n", - " ax.plot(threshold_array[star_loc],\n", - " value_array[star_loc],\n", - " '*', color='crimson',\n", - " markersize=20)\n", - "\n", - " ax.set_ylim([np.min(value_array) - yrange * .1,\n", - " np.max(value_array) + yrange * .1])\n", - " ax.set_title(f'threshold vs value with switching cost c = {cost_sw:.2f}',\n", - " fontsize=20)\n", - " ax.set_xlabel('threshold', fontsize=16)\n", - " ax.set_ylabel('value', fontsize=16)\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {}, - "tags": [] - }, - "outputs": [], - "source": [ - "# @title Helper Functions\n", - "\n", - "# To generate a binomial with fixed \"noise\",\n", - "# we generate a sequence of T numbers uniformly at random\n", - "T = 100\n", - "\n", - "rnd_tele = np.random.uniform(0, 1, T)\n", - "rnd_high_rwd = np.random.uniform(0, 1, T)\n", - "rnd_low_rwd = np.random.uniform(0, 1, T)\n", - "\n", - "\n", - "def get_randomness(T):\n", - " global rnd_tele\n", - " global rnd_high_rwd\n", - " global rnd_low_rwd\n", - "\n", - " rnd_tele = np.random.uniform(0, 1, T)\n", - " rnd_high_rwd = np.random.uniform(0, 1, T)\n", - " rnd_low_rwd = np.random.uniform(0, 1, T)\n", - "\n", - "\n", - "def binomial_tele(p):\n", - " return 
np.array([1 if p > rnd_tele[i] else 0 for i in range(T)])\n", - "\n", - "\n", - "def getRandomness(p, largeT):\n", - " global rnd_tele\n", - " global rnd_high_rwd\n", - " global rnd_low_rwd\n", - "\n", - " rnd_tele = np.random.uniform(0, 1, largeT)\n", - " rnd_high_rwd = np.random.uniform(0, 1, largeT)\n", - " rnd_low_rwd = np.random.uniform(0, 1, largeT)\n", - "\n", - " return [np.array([1 if p > rnd_tele[i] else 0 for i in range(T)]),\n", - " rnd_high_rwd, rnd_low_rwd]\n", - "\n", - "# def binomial_high_rwd(p):\n", - "# return np.array([1 if p > rnd_high_rwd[i] else 0 for i in range(T)])\n", - "\n", - "# def binomial_low_rwd(p):\n", - "# return np.array([1 if p > rnd_low_rwd[i] else 0 for i in range(T)])\n", - "\n", - "\n", - "class ExcerciseError(AssertionError):\n", - " pass\n", - "\n", - "\n", - "class binaryHMM():\n", - "\n", - " def __init__(self, params, fish_initial=0, loc_initial=0):\n", - " self.params = params\n", - " self.fish_initial = fish_initial\n", - " self.loc_initial = loc_initial\n", - "\n", - " def fish_dynamics(self):\n", - " \"\"\"\n", - " fish state dynamics according to telegraph process\n", - "\n", - " Returns:\n", - " fish_state (numpy array of int)\n", - " \"\"\"\n", - " p_stay, _, _, _ = self.params\n", - " fish_state = np.zeros(T, int) # 0: left side and 1: right side\n", - "\n", - " # initialization\n", - " fish_state[0] = self.fish_initial\n", - " tele_operations = binomial_tele(p_stay) # 0: switch and 1: stay\n", - "\n", - " for t in range(1, T):\n", - " # we use logical operation NOT XOR to determine the next state\n", - " fish_state[t] = int(not(fish_state[t-1] ^ tele_operations[t]))\n", - "\n", - " return fish_state\n", - "\n", - " def generate_process_lazy(self):\n", - " \"\"\"\n", - " fish dynamics and rewards if you always stay in the initial location\n", - " without changing sides\n", - "\n", - " Returns:\n", - " fish_state (numpy array of int): locations of the fish\n", - " loc (numpy array of int): left or right site, 0 for left, and 1 for right\n", - " rwd (numpy array of binary): whether a fish was catched or not\n", - " \"\"\"\n", - "\n", - " _, p_low_rwd, p_high_rwd, _ = self.params\n", - "\n", - " fish_state = self.fish_dynamics()\n", - " rwd = np.zeros(T, int) # 0: no food, 1: get food\n", - "\n", - " for t in range(0, T):\n", - " # new measurement\n", - " if fish_state[t] != self.loc_initial:\n", - " rwd[t] = 1 if p_low_rwd > rnd_low_rwd[t] else 0\n", - " else:\n", - " rwd[t] = 1 if p_high_rwd > rnd_high_rwd[t] else 0\n", - "\n", - " # rwd[t] = binomial(1, p_rwd_vector[(fish_state[t] == loc[t]) * 1])\n", - " return fish_state, self.loc_initial*np.ones(T), rwd\n", - "\n", - "\n", - "class binaryHMM_belief(binaryHMM):\n", - "\n", - " def __init__(self, params,\n", - " fish_initial=0, loc_initial=1,\n", - " choose_policy='threshold'):\n", - "\n", - " binaryHMM.__init__(self, params, fish_initial, loc_initial)\n", - " self.choose_policy = choose_policy\n", - "\n", - " def generate_process(self):\n", - " \"\"\"\n", - " fish dynamics and measurements based on the chosen policy\n", - "\n", - " Returns:\n", - " belief (numpy array of float): belief on the states of the two sites\n", - " act (numpy array of string): actions over time\n", - " loc (numpy array of int): left or right site\n", - " measurement (numpy array of binary): whether a reward is obtained\n", - " fish_state (numpy array of int): fish locations\n", - " \"\"\"\n", - "\n", - " p_stay, low_rew_p, high_rew_p, threshold = self.params\n", - " fish_state = self.fish_dynamics() # 0: left 
side; 1: right side\n", - " loc = np.zeros(T, int) # 0: left side, 1: right side\n", - " measurement = np.zeros(T, int) # 0: no food, 1: get food\n", - " act = np.empty(T, dtype='object') # \"stay\", or \"switch\"\n", - " belief = np.zeros((2, T), float) # the probability that the fish is on the left (1st element)\n", - " # or on the right (2nd element),\n", - " # the beliefs on the two boxes sum up to be 1\n", - "\n", - " rew_prob = np.array([low_rew_p, high_rew_p])\n", - "\n", - " # initialization\n", - " loc[0] = self.loc_initial\n", - " measurement[0] = 0\n", - " belief_0 = np.random.random(1)[0]\n", - " belief[:, 0] = np.array([belief_0, 1 - belief_0])\n", - " act[0] = self.policy(threshold, belief[:, 0], loc[0])\n", - "\n", - " for t in range(1, T):\n", - " if act[t - 1] == \"stay\":\n", - " loc[t] = loc[t - 1]\n", - " else:\n", - " loc[t] = int(not(loc[t - 1] ^ 0))\n", - "\n", - " # new measurement\n", - " # measurement[t] = binomial(1, rew_prob[(fish_state[t] == loc[t]) * 1])\n", - " if fish_state[t] != loc[t]:\n", - " measurement[t] = 1 if low_rew_p > rnd_low_rwd[t] else 0\n", - " else:\n", - " measurement[t] = 1 if high_rew_p > rnd_high_rwd[t] else 0\n", - "\n", - " belief[0, t] = self.belief_update(belief[0, t - 1] , loc[t],\n", - " measurement[t], p_stay,\n", - " high_rew_p, low_rew_p)\n", - " belief[1, t] = 1 - belief[0, t]\n", - "\n", - " act[t] = self.policy(threshold, belief[:, t], loc[t])\n", - "\n", - " return belief, loc, act, measurement, fish_state\n", - "\n", - " def policy(self, threshold, belief, loc):\n", - " \"\"\"\n", - " chooses policy based on whether it is lazy policy\n", - " or a threshold-based policy\n", - "\n", - " Args:\n", - " threshold (float): the threshold of belief on the current site,\n", - " when the belief is lower than the threshold, switch side\n", - " belief (numpy array of float): the belief on the two sites\n", - " loc (int) : the location of the agent\n", - "\n", - " Returns:\n", - " act (string): \"stay\" or \"switch\"\n", - " \"\"\"\n", - " if self.choose_policy == \"threshold\":\n", - " act = policy_threshold(threshold, belief, loc)\n", - " if self.choose_policy == \"lazy\":\n", - " act = policy_lazy(belief, loc)\n", - "\n", - " return act\n", - "\n", - " def belief_update(self, belief_past, loc, measurement, p_stay,\n", - " high_rew_p, low_rew_p):\n", - " \"\"\"\n", - " using PAST belief on the LEFT box, CURRENT location and\n", - " and measurement to update belief\n", - " \"\"\"\n", - " rew_prob_matrix = np.array([[1 - high_rew_p, high_rew_p],\n", - " [1 - low_rew_p, low_rew_p]])\n", - "\n", - " # update belief posterior, p(s[t] | measurement(0-t), act(0-t-1))\n", - " belief_0 = (belief_past * p_stay + (1 - belief_past) * (1 - p_stay)) *\\\n", - " rew_prob_matrix[(loc + 1) // 2, measurement]\n", - " belief_1 = ((1 - belief_past) * p_stay + belief_past * (1 - p_stay)) *\\\n", - " rew_prob_matrix[1-(loc + 1) // 2, measurement]\n", - "\n", - " belief_0 = belief_0 / (belief_0 + belief_1)\n", - "\n", - " return belief_0\n", - "\n", - "\n", - "def policy_lazy(belief, loc):\n", - " \"\"\"\n", - " This function is a lazy policy where stay is also taken\n", - " \"\"\"\n", - " act = \"stay\"\n", - "\n", - " return act\n", - "\n", - "\n", - "def test_policy_threshold():\n", - " well_done = True\n", - " for loc in [-1, 1]:\n", - " threshold = 0.4\n", - " belief = np.array([.2, .3])\n", - " if policy_threshold(threshold, belief, loc) != \"switch\":\n", - " raise ExcerciseError(\"'policy_threshold' function is not correctly implemented!\")\n", - " 
for loc in [1, -1]:\n", - " threshold = 0.6\n", - " belief = np.array([.7, .8])\n", - " if policy_threshold(threshold, belief, loc) != \"stay\":\n", - " raise ExcerciseError(\"'policy_threshold' function is not correctly implemented!\")\n", - " print(\"Well Done!\")\n", - "\n", - "\n", - "def test_policy_threshold():\n", - " for loc in [-1, 1]:\n", - " threshold = 0.4\n", - " belief = np.ones(2) * (threshold + 0.1)\n", - " belief[(loc + 1) // 2] = threshold - 0.1\n", - "\n", - " if policy_threshold(threshold, belief, loc) != \"switch\":\n", - " raise ExcerciseError(\"'policy_threshold' function is not correctly implemented!\")\n", - " if policy_threshold(threshold, belief, -1 * loc) != \"stay\":\n", - " raise ExcerciseError(\"'policy_threshold' function is not correctly implemented!\")\n", - "\n", - " print(\"Well Done!\")\n", - "\n", - "\n", - "def test_value_function():\n", - " measurement = np.array([0, 0, 0, 1, 0, 0, 0, 0, 1, 1])\n", - " act = np.array([\"switch\", \"stay\", \"switch\", \"stay\", \"stay\",\n", - " \"stay\", \"switch\", \"switch\", \"stay\", \"stay\"])\n", - " cost_sw = .5\n", - " if not isclose(get_value(measurement, act, cost_sw), .1):\n", - " raise ExcerciseError(\"'value_function' function is not correctly implemented!\")\n", - " print(\"Well Done!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "---\n", - "# Section 1: Analyzing the Problem" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {}, - "tags": [] - }, - "outputs": [], - "source": [ - "# @title Video 1: Gone fishing\n", - "from ipywidgets import widgets\n", - "from IPython.display import YouTubeVideo\n", - "from IPython.display import IFrame\n", - "from IPython.display import display\n", - "\n", - "\n", - "class PlayVideo(IFrame):\n", - " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", - " self.id = id\n", - " if source == 'Bilibili':\n", - " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", - " elif source == 'Osf':\n", - " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", - " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", - "\n", - "\n", - "def display_videos(video_ids, W=400, H=300, fs=1):\n", - " tab_contents = []\n", - " for i, video_id in enumerate(video_ids):\n", - " out = widgets.Output()\n", - " with out:\n", - " if video_ids[i][0] == 'Youtube':\n", - " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", - " height=H, fs=fs, rel=0)\n", - " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", - " else:\n", - " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", - " height=H, fs=fs, autoplay=False)\n", - " if video_ids[i][0] == 'Bilibili':\n", - " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", - " elif video_ids[i][0] == 'Osf':\n", - " print(f'Video available at https://osf.io/{video.id}')\n", - " display(video)\n", - " tab_contents.append(out)\n", - " return tab_contents\n", - "\n", - "\n", - "video_ids = [('Youtube', '3oIwUFpolVA'), ('Bilibili', 'BV1FL411p7o5')]\n", - "tab_contents = display_videos(video_ids, W=854, H=480)\n", - "tabs = widgets.Tab()\n", - "tabs.children = tab_contents\n", - "for i in range(len(tab_contents)):\n", - " tabs.set_title(i, video_ids[i][0])\n", - "display(tabs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - 
"execution": {} - }, - "outputs": [], - "source": [ - "# @title Submit your feedback\n", - "content_review(f\"{feedback_prefix}_Gone_fishing_Video\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "**Problem Setting**\n", - "\n", - "*1. State dynamics:* There are two possible locations for the fish: Left and Right. Secretly, at each time step, the fish may switch sides with a certain probability $p_{\\rm sw} = 1 - p_{\\rm stay}$. This is the binary switching model (*Telegraph process*) that you've seen in the *Linear Systems* day. The fish location, $s^{\\rm fish}$, is latent; you get measurements about it when you try to catch fish, like in the *Hidden Dynamics* day. This gives you a *belief* or posterior probability of the current location given your history of measurements.\n", - "\n", - "*2. Actions:* Unlike past days, you can now **act** on the process! You may stay on your current location (Left or Right), or switch to the other side.\n", - "\n", - "*3. Rewards and Costs:* You get rewarded for each fish you catch (one fish is worth 1 \"point\"). If you're on the same side as the fish, you'll catch more, with probability $q_{\\rm high}$ per discrete time step. Otherwise, you may still catch some fish with probability $q_{\\rm low}$.\n", - "\n", - "You pay a price of $C$ points for switching to the other side. So you better decide wisely!\n", - "\n", - "
\n", - "\n", - "**Maximizing Utility**\n", - "\n", - "To decide \"wisely\" and maximize your total utility (total points), you will follow a **policy** that prescribes what to do in any situation. Here the situation is determined by your location and your **belief** $b_t$ (posterior) about the fish location (remember that the fish location is a latent variable).\n", - "\n", - "In optimal control theory, the belief is the posterior probability over the latent variable given all the past measurements. It can be shown that maximizing the expected utility with respect to this posterior is optimal.\n", - "\n", - "In our problem, the belief can be represented by a single number because the fish are either on the left or the right side. So we write:\n", - "\n", - "\\begin{equation}\n", - "b_t = p(s^{\\rm fish}_t = {\\rm Right}\\ |\\ m_{0:t}, a_{0:t-1})\n", - "\\end{equation}\n", - "\n", - "where $m_{0:t}$ are the measurements and $a_{0:t-1}$ are the actions (stay or switch).\n", - "\n", - "Finally, we will parameterize the policy by a simple threshold on beliefs: when your belief that fish are on your current side falls below a threshold $\\theta$, you switch to the other side.\n", - "\n", - "In this tutorial, you will discover that if you pick the right threshold, this simple policy happens to be optimal!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "## Interactive Demo 1: Examining fish dynamics\n", - "\n", - "In this demo, we will look at the dynamics of the fish moving from side to side while you stay in one place. Play around with the probability `stay_prob` of fish staying in the same location, and observe the resulting dynamics of the fish." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "**Thinking questions:**\n", - "\n", - "* If the fish have already been on one side for a long time, does that change the chances of them switching sides?\n", - "* For what values of p_stay is the fish location most and least predictable?" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {} - }, - "outputs": [], - "source": [ - "# @markdown Execute this cell to enable the demo.\n", - "display(HTML(''''''))\n", - "\n", - "@widgets.interact(p_stay=widgets.FloatSlider(.9, description=\"stay_prob\", min=0., max=1., step=0.01))\n", - "\n", - "def update_ex_1(p_stay):\n", - " \"\"\"\n", - " T: Length of timeline\n", - " p_stay: probability that the fish do not swim to the other side at time t\n", - " \"\"\"\n", - " params = [p_stay, _, _, _]\n", - "\n", - " # initial condition: fish [fish_initial] start at the left location (-1)\n", - " binaryHMM_test = binaryHMM(params=params, fish_initial=1)\n", - "\n", - " fish_state = binaryHMM_test.fish_dynamics()\n", - " plot_fish(fish_state)\n", - "\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": {} - }, - "outputs": [], - "source": [ - "# to_remove explanation\n", - "\n", - "\"\"\"\n", - "In Interactive Demo 1, you should see the school of fish switch sides less often when `stay_prob` is high.\n", - "\n", - "* If the fish have already been on one side for a long time, does that change the chances of them switching sides?\n", - "\n", - " No. 
The telegraph process or binary switching process is Markovian.\n", - " That means that the probabilities of changes depend only on the *current* state.\n", - " States from further in the past do not matter for the chances of switching sides.\n", - " Staying longer in one side is not a statement about the current state, but rather about the past,\n", - " so it's irrelevant for the chances of switching.\n", - "\n", - "\n", - "* For what values of `p_stay` is the fish location most and least predictable?\n", - "\n", - " When `p_stay` is 1 then the fish never move. But when `p_stay` is 0 then the fish *always* move,\n", - " oscillating back and forth deterministically every discrete time step.\n", - "\"\"\";" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {} - }, - "outputs": [], - "source": [ - "# @title Submit your feedback\n", - "content_review(f\"{feedback_prefix}_Examining_fish_dynamics_Interactive_Demo_and_Discussion\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "---\n", - "# Section 2: Catching fish" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {}, - "tags": [] - }, - "outputs": [], - "source": [ - "# @title Video 2: Catch some fish\n", - "from ipywidgets import widgets\n", - "from IPython.display import YouTubeVideo\n", - "from IPython.display import IFrame\n", - "from IPython.display import display\n", - "\n", - "\n", - "class PlayVideo(IFrame):\n", - " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", - " self.id = id\n", - " if source == 'Bilibili':\n", - " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", - " elif source == 'Osf':\n", - " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", - " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", - "\n", - "\n", - "def display_videos(video_ids, W=400, H=300, fs=1):\n", - " tab_contents = []\n", - " for i, video_id in enumerate(video_ids):\n", - " out = widgets.Output()\n", - " with out:\n", - " if video_ids[i][0] == 'Youtube':\n", - " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", - " height=H, fs=fs, rel=0)\n", - " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", - " else:\n", - " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", - " height=H, fs=fs, autoplay=False)\n", - " if video_ids[i][0] == 'Bilibili':\n", - " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", - " elif video_ids[i][0] == 'Osf':\n", - " print(f'Video available at https://osf.io/{video.id}')\n", - " display(video)\n", - " tab_contents.append(out)\n", - " return tab_contents\n", - "\n", - "\n", - "video_ids = [('Youtube', 'ZjB2_SAY2uE'), ('Bilibili', 'BV1kD4y1m7Lo')]\n", - "tab_contents = display_videos(video_ids, W=854, H=480)\n", - "tabs = widgets.Tab()\n", - "tabs.children = tab_contents\n", - "for i in range(len(tab_contents)):\n", - " tabs.set_title(i, video_ids[i][0])\n", - "display(tabs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {} - }, - "outputs": [], - "source": [ - "# @title Submit your feedback\n", - "content_review(f\"{feedback_prefix}_Catch_some_fish_Video\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "## Interactive Demo 2: Examining the reward 
function\n", - "\n", - "In this second demo, you control your location by a button, but we fix the fish's location by setting `stay_prob = 1`. Now that the fish are serenely swimming in one location, we can visually inspect the rewards when you're on the same side as the fish or on the other side.\n", - "\n", - "When you're on the same side as the fish, you should have a higher probability of catching them (but watch out, since technically, you are _allowed_ to adjust the sliders to other conditions!).\n", - "\n", - "Play around with the sliders `high_rew_prob` (high reward probability when you're on the fish's side) and `low_rew_prob` (low reward probability when you're on the other side). The button (same location *vs.* different location) determines which probability describes how often you catch fish." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "**Thinking questions:**\n", - "\n", - "* What happens when the fish and the agent (you!) are on the same or different locations?\n", - "* Where do you catch the most fish?\n", - "* Why isn't `low_rew_prob + high_rew_prob = 1`? What do these probabilities mean in the fishing story?\n", - "* You _can_ move the sliders so `low_rew_prob > high_rew_prob`. This doesn't change the math, but it can change whether the math is a reasonable model of the physical problem. Why?" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {} - }, - "outputs": [], - "source": [ - "# @markdown Execute this cell to enable the demo.\n", - "display(HTML(''''''))\n", - "\n", - "@widgets.interact(locs=widgets.RadioButtons(options=['same location', 'different locations'],\n", - " description='Fish and agent:',\n", - " disabled=False,\n", - " layout={'width': 'max-content'}),\n", - " p_low_rwd=widgets.FloatSlider(.1, description=\"low_rew_prob:\",\n", - " min=0., max=1.),\n", - " p_high_rwd=widgets.FloatSlider(.9, description=\"high_rew_prob:\",\n", - " min=0., max=1.))\n", - "\n", - "def update_ex_2(locs, p_low_rwd, p_high_rwd):\n", - " \"\"\"\n", - " p_stay: probability of fish staying at current side at time t\n", - " p_low_rwd: probability of catching fish when you're NOT on the side where the fish are swimming\n", - " p_high_rwd: probability of catching fish when you're on the side where the fish are swimming\n", - " fish_initial: initial side of fish (-1 left, 1 right)\n", - " agent_initial: initial side of the agent (YOU!) (-1 left, 1 right)\n", - " \"\"\"\n", - " p_stay = 1\n", - " params = [p_stay, p_low_rwd, p_high_rwd, _]\n", - "\n", - " # initial condition for fish [fish_initial] and you [loc_initial]\n", - " if locs == 'same location':\n", - " binaryHMM_test = binaryHMM(params, fish_initial=0, loc_initial=0)\n", - " else:\n", - " binaryHMM_test = binaryHMM(params, fish_initial=1, loc_initial=0)\n", - "\n", - " fish_state, loc, measurement = binaryHMM_test.generate_process_lazy()\n", - " plot_measurement(measurement)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": {} - }, - "outputs": [], - "source": [ - "# to_remove explanation\n", - "\n", - "\"\"\"\n", - "* What happens when the fish and the agent (you!) are on the same or different locations?\n", - " You catch fish with different probabilities.\n", - "\n", - "* Where do you catch the most fish?\n", - " When you're on the same side as the fish -- as long as high_rew_prob > low_rew_prob.\n", - "\n", - "* Why isn't low_rew_prob + high_rew_prob = 1? 
What do these probabilities mean in the fishing story?\n", - " These are not probabilities of mutually exclusive events. They are chances of one event (you catch fish)\n", - " under two different conditions (you and the school of fish are on the same side or different sides).\n", - "\n", - "* You _can_ move the sliders so `low_rew_prob > high_rew_prob`. This doesn't change the math,\n", - " but it can change whether the math is a reasonable model of the physical problem. Why?\n", - " It would be weird if you caught less fish when you're on the same side as the fish.\n", - " But hey, maybe the fish warn each other when they're in a school together! Then they'd be harder to catch...\n", - "\"\"\";" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {} - }, - "outputs": [], - "source": [ - "# @title Submit your feedback\n", - "content_review(f\"{feedback_prefix}_Examining_the_reward_function_Interactive_Demo_and_Discussion\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "---\n", - "# Section 3: Belief dynamics and belief distributions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {}, - "tags": [] - }, - "outputs": [], - "source": [ - "# @title Video 3: Where are the fish?\n", - "from ipywidgets import widgets\n", - "from IPython.display import YouTubeVideo\n", - "from IPython.display import IFrame\n", - "from IPython.display import display\n", - "\n", - "\n", - "class PlayVideo(IFrame):\n", - " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", - " self.id = id\n", - " if source == 'Bilibili':\n", - " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", - " elif source == 'Osf':\n", - " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", - " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", - "\n", - "\n", - "def display_videos(video_ids, W=400, H=300, fs=1):\n", - " tab_contents = []\n", - " for i, video_id in enumerate(video_ids):\n", - " out = widgets.Output()\n", - " with out:\n", - " if video_ids[i][0] == 'Youtube':\n", - " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", - " height=H, fs=fs, rel=0)\n", - " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", - " else:\n", - " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", - " height=H, fs=fs, autoplay=False)\n", - " if video_ids[i][0] == 'Bilibili':\n", - " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", - " elif video_ids[i][0] == 'Osf':\n", - " print(f'Video available at https://osf.io/{video.id}')\n", - " display(video)\n", - " tab_contents.append(out)\n", - " return tab_contents\n", - "\n", - "\n", - "video_ids = [('Youtube', 'rmETVsRFYGk'), ('Bilibili', 'BV19t4y1Q7VH')]\n", - "tab_contents = display_videos(video_ids, W=854, H=480)\n", - "tabs = widgets.Tab()\n", - "tabs.children = tab_contents\n", - "for i in range(len(tab_contents)):\n", - " tabs.set_title(i, video_ids[i][0])\n", - "display(tabs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {} - }, - "outputs": [], - "source": [ - "# @title Submit your feedback\n", - "content_review(f\"{feedback_prefix}_Where_are_the_fish_Video\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "## Interactive 
Demo 3: Examining the beliefs\n",
- "\n",
- "Now it's time to get an intuition on how beliefs are calculated. Here we define your belief about the fish location as the posterior probability of that location given your measurements, $p(s_t|m_{0:t})$. Note that this is just what you did in the day covering Hidden Dynamics!\n",
- "\n",
- "In this exercise, you'll always stay on the LEFT side, but the fish will move around. They'll stay on the same side with probability `stay_prob`. You only get to see fish you catch, not where the school of fish is. You have to use those measurements to infer the location of the school.\n",
- "\n",
- "In this demo, play around with the sliders `high_rew_prob`, `low_rew_prob`, and `stay_prob`.\n",
- "\n",
- "**Thinking questions:**\n",
- "\n",
- "* Manipulate the slider for `stay_prob`. How well does the belief explain the dynamics of the fish as you adjust the probability of the fish staying in one location (`stay_prob`)?\n",
- "\n",
- "* Explore the extreme case where `high_rew_prob = 1` and `low_rew_prob = 0`. How accurate is the belief as these parameters change?\n",
- "\n",
- "* Under what conditions is it informative to catch a fish? What about *not* catching a fish?"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "execution": {},
- "tags": []
- },
- "outputs": [],
- "source": [
- "# @markdown Execute this cell to enable the demo.\n",
- "display(HTML(''''''))\n",
- "\n",
- "@widgets.interact(p_stay=widgets.FloatSlider(.96, description=\"stay_prob\",\n",
- " min=.8, max=1., step=.01),\n",
- " p_low_rwd=widgets.FloatSlider(.1, description=\"low_rew_prob\",\n",
- " min=0., max=1., step=.01),\n",
- " p_high_rwd=widgets.FloatSlider(.3, description=\"high_rew_prob\",\n",
- " min=0., max=1., step=.01))\n",
- "\n",
- "def update_ex_3(p_stay, p_low_rwd, p_high_rwd):\n",
- " \"\"\"\n",
- " T: Length of timeline\n",
- " p_stay: probability of fish staying at current side at time t\n",
- " p_high_rwd: probability of catching fish when you're on the side where the fish are swimming\n",
- " p_low_rwd: probability of catching fish when you're NOT on the side where the fish are swimming\n",
- " fish_initial: initial side of fish (0 left, 1 right)\n",
- " agent_initial: initial side of the agent (YOU!) (0 left, 1 right)\n",
- " threshold: threshold of belief below which the action is switching\n",
- " \"\"\"\n",
- " threshold = 0.2\n",
- " params = [p_stay, p_low_rwd, p_high_rwd, threshold]\n",
- "\n",
- " binaryHMM_test = binaryHMM_belief(params, choose_policy=\"lazy\",\n",
- " fish_initial=0, loc_initial=0)\n",
- "\n",
- " belief, loc, act, measurement, fish_state = binaryHMM_test.generate_process()\n",
- " plot_dynamics(belief, loc, act, measurement, fish_state,\n",
- " binaryHMM_test.choose_policy)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "execution": {}
- },
- "outputs": [],
- "source": [
- "# to_remove explanation\n",
- "\n",
- "\"\"\"\n",
- "* Manipulate the slider for `stay_prob`. How well does the belief explain the dynamics of the fish as\n",
- " you adjust the probability of the fish staying in one location (`stay_prob`)?\n",
- "\n",
- " The parameter (`stay_prob`) determines fish dynamics. If it is low, the fish are moving fast\n",
- " and you don't have much time to collect observations that might decrease your uncertainty about\n",
- " the actual location of the school. 
If it is high, you have more time to integrate evidence\n", - " and the belief explains better the dynamics of the fish.\n", - "\n", - "* Explore the extreme case where `high_rew_prob = 1` and `low_rew_prob = 0`.\n", - " Now play around with these sliders. How accurate is the belief as these parameters change?\n", - "\n", - " In the extreme case, the belief explains the dynamics of the fish perfectly because\n", - " our observations are perfect, i.e., catching a fish indicates with certainty the presence of the school.\n", - " If the chances of catching a fish are very different between the two sides, then you get a lot of information\n", - " for each fish you catch. The belief will then rise and fall steeply with each observation.\n", - " If the two probabilities are similar, then the belief will change slowly even if the fish move quickly.\n", - "\n", - "* Under what conditions is it informative to catch a fish? What about to *not* catch a fish?\n", - "\n", - " The bigger the difference in the two probabilities, the more information you get from measurements.\n", - " If both probabilities are low (and different), then you learn a lot from catching a fish.\n", - " But you still learn a little if you don't catch anything, particularly when catching a fish is probable in one case.\n", - "\"\"\";" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {} - }, - "outputs": [], - "source": [ - "# @title Submit your feedback\n", - "content_review(f\"{feedback_prefix}_Examining_the_beliefs_Interactive_Demo_and_Discussion\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "---\n", - "# Section 4: Implementing a threshold policy" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {}, - "tags": [] - }, - "outputs": [], - "source": [ - "# @title Video 4: How should you act?\n", - "from ipywidgets import widgets\n", - "from IPython.display import YouTubeVideo\n", - "from IPython.display import IFrame\n", - "from IPython.display import display\n", - "\n", - "\n", - "class PlayVideo(IFrame):\n", - " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", - " self.id = id\n", - " if source == 'Bilibili':\n", - " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", - " elif source == 'Osf':\n", - " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", - " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", - "\n", - "\n", - "def display_videos(video_ids, W=400, H=300, fs=1):\n", - " tab_contents = []\n", - " for i, video_id in enumerate(video_ids):\n", - " out = widgets.Output()\n", - " with out:\n", - " if video_ids[i][0] == 'Youtube':\n", - " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", - " height=H, fs=fs, rel=0)\n", - " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", - " else:\n", - " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", - " height=H, fs=fs, autoplay=False)\n", - " if video_ids[i][0] == 'Bilibili':\n", - " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", - " elif video_ids[i][0] == 'Osf':\n", - " print(f'Video available at https://osf.io/{video.id}')\n", - " display(video)\n", - " tab_contents.append(out)\n", - " return tab_contents\n", - "\n", - "\n", - "video_ids = [('Youtube', 'cTzaQl2Vxn4'), ('Bilibili', 'BV1ri4y137cj')]\n", - 
"tab_contents = display_videos(video_ids, W=854, H=480)\n", - "tabs = widgets.Tab()\n", - "tabs.children = tab_contents\n", - "for i in range(len(tab_contents)):\n", - " tabs.set_title(i, video_ids[i][0])\n", - "display(tabs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {} - }, - "outputs": [], - "source": [ - "# @title Submit your feedback\n", - "content_review(f\"{feedback_prefix}_How_should_you_act_Video\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "## Coding Exercise 4: dynamics following a **threshold-based** policy\n", - "\n", - "Now we'll switch the policy from the 'lazy' policy used above to a threshold policy that you need to write. You'll change your location whenever your belief is low enough that you're on the best side. You'll update the function `policy_threshold(threshold, belief, loc)`. This policy takes three inputs:\n", - "\n", - "1. The `belief` about the fish state. For convenience, we will represent the belief at time *t* using a 2-dimensional vector. The first element is the belief that the fish are on the left, and the second element is the belief the fish are on the right. At every time step, these elements sum to 1.\n", - "\n", - "2. Your location `loc`, represented as \"Left\" = -1 and \"Right\" = 1.\n", - "\n", - "3. A belief `threshold` that determines when to switch. When your belief that you are on the same side as the fish drops below this threshold, you should move to the other location, and otherwise stay.\n", - "\n", - "Your function should return an action for each time *t*, which takes the value of \"stay\" or \"switch\"." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "both", - "execution": {}, - "tags": [] - }, - "outputs": [], - "source": [ - "def policy_threshold(threshold, belief, loc):\n", - " \"\"\"\n", - " chooses whether to switch side based on whether the belief\n", - " on the current site drops below the threshold\n", - "\n", - " Args:\n", - " threshold (float): the threshold of belief on the current site,\n", - " when the belief is lower than the threshold, switch side\n", - " belief (numpy array of float, 2-dimensional): the belief on the\n", - " two sites at a certain time\n", - " loc (int) : the location of the agent at a certain time\n", - " -1 for left side, 1 for right side\n", - "\n", - " Returns:\n", - " act (string): \"stay\" or \"switch\"\n", - " \"\"\"\n", - "\n", - " ############################################################################\n", - " ## 1. Modify the code below to generate actions (stay or switch)\n", - " ## for current belief and location\n", - " ##\n", - " ## Belief is a 2d vector: first element = Prob(fish on Left | measurements)\n", - " ## second element = Prob(fish on Right | measurements)\n", - " ## Returns \"switch\" if Belief that fish are in your current location < threshold\n", - " ## \"stay\" otherwise\n", - " ##\n", - " ## Hint: use loc value to determine which row of belief you need to use\n", - " ## see the docstring for more information about loc\n", - " ##\n", - " ## 2. 
After completing the function, comment this line:\n",
- " raise NotImplementedError(\"Student exercise: Please complete the code\")\n",
- " ############################################################################\n",
- " # Write the if statement\n",
- " if ...:\n",
- " # action below threshold\n",
- " act = ...\n",
- " else:\n",
- " # action above threshold\n",
- " act = ...\n",
- "\n",
- " return act\n",
- "\n",
- "\n",
- "# Next line tests your function\n",
- "test_policy_threshold()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "execution": {}
- },
- "source": [
- "You will see\n",
- "\n",
- "```Well Done!```"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "execution": {},
- "tags": []
- },
- "outputs": [],
- "source": [
- "# to_remove solution\n",
- "def policy_threshold(threshold, belief, loc):\n",
- " \"\"\"\n",
- " chooses whether to switch side based on whether the belief\n",
- " on the current site drops below the threshold\n",
- "\n",
- " Args:\n",
- " threshold (float): the threshold of belief on the current site,\n",
- " when the belief is lower than the threshold, switch side\n",
- " belief (numpy array of float, 2-dimensional): the belief on the\n",
- " two sites at a certain time\n",
- " loc (int) : the location of the agent at a certain time\n",
- " -1 for left side, 1 for right side\n",
- "\n",
- " Returns:\n",
- " act (string): \"stay\" or \"switch\"\n",
- " \"\"\"\n",
- " # Write the if statement\n",
- " if belief[(loc + 1) // 2] <= threshold:\n",
- " # action below threshold\n",
- " act = \"switch\"\n",
- " else:\n",
- " # action above threshold\n",
- " act = \"stay\"\n",
- "\n",
- " return act\n",
- "\n",
- "\n",
- "# Next line tests your function\n",
- "test_policy_threshold()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "execution": {}
- },
- "outputs": [],
- "source": [
- "# @title Submit your feedback\n",
- "content_review(f\"{feedback_prefix}_Dynamics_threshold_based_policy_Exercise\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "execution": {},
- "tags": []
- },
- "source": [
- "## Interactive Demo 4: Dynamics with different thresholds\n",
- "\n",
- "The following demo uses the policy you just built! Play around with the slider and observe the dynamics controlled by your policy.\n",
- "\n",
- "(The code specifies `stay_prob=0.95`, `high_rew_prob=0.3`, and `low_rew_prob=0.1`. You can change these, but these are reasonable parameters. Note: to see the gradual change with threshold, keep reusing the same random seed; to see different examples, refresh the seed.\n",
- ")\n",
- "\n",
- "**Thinking questions:**\n",
- "* Qualitatively, how well does this policy follow the fish? What does it miss, and why?\n",
- "* How can you characterize the fishing strategy if the threshold is very low, or very high?"
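As a rough preview of what the demo in the next cell is doing, here is a self-contained sketch that combines a Bayesian belief update with the threshold policy and counts how often the agent switches. It does not use the notebook's `binaryHMM_belief` machinery; the helper names (`belief_update_right`, `count_switches`) and parameter values are illustrative assumptions.

```python
# Illustrative sketch only: simplified belief update + threshold policy,
# not the notebook's binaryHMM_belief class. Parameter values are assumptions.
import numpy as np

def belief_update_right(b_right, caught, on_right, p_stay, q_high, q_low):
    """One Bayesian update of P(fish on Right) after one fishing attempt."""
    # Prediction step: the fish may have switched sides.
    prior = b_right * p_stay + (1 - b_right) * (1 - p_stay)
    # Likelihood of the observation under each hypothesis about the fish.
    q_if_right = q_high if on_right else q_low   # fish on Right
    q_if_left = q_low if on_right else q_high    # fish on Left
    like_right = q_if_right if caught else 1 - q_if_right
    like_left = q_if_left if caught else 1 - q_if_left
    return like_right * prior / (like_right * prior + like_left * (1 - prior))

def count_switches(threshold, T=1000, p_stay=0.95, q_high=0.3, q_low=0.1, seed=0):
    """Simulate the threshold policy and count how often the agent switches."""
    rng = np.random.default_rng(seed)
    fish_right, agent_right, b_right = 0, 0, 0.5
    switches = 0
    for _ in range(T):
        if rng.random() > p_stay:                     # fish may switch sides
            fish_right = 1 - fish_right
        q = q_high if fish_right == agent_right else q_low
        caught = rng.random() < q
        b_right = belief_update_right(b_right, caught, agent_right,
                                      p_stay, q_high, q_low)
        b_here = b_right if agent_right else 1 - b_right
        if b_here <= threshold:                       # threshold policy
            agent_right = 1 - agent_right
            switches += 1
    return switches

# Low thresholds switch rarely, high thresholds switch constantly:
# for th in (0.1, 0.3, 0.45):
#     print(th, count_switches(th))
```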
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "execution": {},
- "tags": []
- },
- "outputs": [],
- "source": [
- "# @markdown Execute this cell to enable the demo.\n",
- "display(HTML(''''''))\n",
- "\n",
- "@widgets.interact(threshold=widgets.FloatSlider(.2, description=\"threshold\", min=0., max=1., step=.01),\n",
- " new_seed=widgets.ToggleButtons(options=['Reusing', 'Refreshing'],\n",
- " description='Random seed:',\n",
- " disabled=False,\n",
- " button_style='', # 'success', 'info', 'warning', 'danger' or '',\n",
- " icons=['check'] * 2\n",
- " ))\n",
- "def update_ex_4(threshold, new_seed):\n",
- " \"\"\"\n",
- " p_stay: probability fish stay\n",
- " high_rew_p: p(catch fish) when you're on their side\n",
- " low_rew_p : p(catch fish) when you're on other side\n",
- " threshold: threshold of belief below which switching is taken\n",
- "\n",
- " \"\"\"\n",
- " if new_seed == \"Refreshing\":\n",
- " get_randomness(T)\n",
- "\n",
- " stay_prob=.95\n",
- " high_rew_p=.3\n",
- " low_rew_p=.1\n",
- "\n",
- " # parameter order must match binaryHMM_belief: (p_stay, low_rew_p, high_rew_p, threshold)\n",
- " params = [stay_prob, low_rew_p, high_rew_p, threshold]\n",
- "\n",
- " # initial condition for fish [fish_initial] and you [loc_initial]\n",
- " binaryHMM_test = binaryHMM_belief(params, fish_initial=0, loc_initial=0, choose_policy=\"threshold\")\n",
- "\n",
- " belief, loc, act, measurement, fish_state = binaryHMM_test.generate_process()\n",
- " plot_dynamics(belief, loc, act, measurement,\n",
- " fish_state, binaryHMM_test.choose_policy)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "execution": {}
- },
- "outputs": [],
- "source": [
- "# to_remove explanation\n",
- "\n",
- "\"\"\"\n",
- "* Qualitatively, how well does this policy follow the fish? What does it miss, and why?\n",
- "\n",
- " You generally follow the fish, but there can be a substantial difference in location.\n",
- " The belief is not generally very confident when the probabilities of catching fish on the\n",
- " two sides are not very different. Depending on your threshold, you might leave just\n",
- " from some unlucky times when you're still on the right side. 
Or you might stay even\n", - " though you have not caught many fish, in the hopes that the fish haven't moved.\n", - "\n", - "* How can you characterize the fishing strategy if the threshold is very low, or very high?\n", - "\n", - " If the threshold is low, you only switch when you have a very low belief that you're on the right side.\n", - " Then you switch very rarely.\n", - " If the threshold is high, then you switch whenever you're not extremely confident,\n", - " so you change sides all the time.\n", - "\"\"\";" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {} - }, - "outputs": [], - "source": [ - "# @title Submit your feedback\n", - "content_review(f\"{feedback_prefix}_Dynamics_with_different_thresholds_Interactive_Demo_and_Discussion\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "---\n", - "# Section 5: Implementing a value function" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {}, - "tags": [] - }, - "outputs": [], - "source": [ - "# @title Video 5: Evaluate policy\n", - "from ipywidgets import widgets\n", - "from IPython.display import YouTubeVideo\n", - "from IPython.display import IFrame\n", - "from IPython.display import display\n", - "\n", - "\n", - "class PlayVideo(IFrame):\n", - " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", - " self.id = id\n", - " if source == 'Bilibili':\n", - " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", - " elif source == 'Osf':\n", - " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", - " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", - "\n", - "\n", - "def display_videos(video_ids, W=400, H=300, fs=1):\n", - " tab_contents = []\n", - " for i, video_id in enumerate(video_ids):\n", - " out = widgets.Output()\n", - " with out:\n", - " if video_ids[i][0] == 'Youtube':\n", - " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", - " height=H, fs=fs, rel=0)\n", - " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", - " else:\n", - " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", - " height=H, fs=fs, autoplay=False)\n", - " if video_ids[i][0] == 'Bilibili':\n", - " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", - " elif video_ids[i][0] == 'Osf':\n", - " print(f'Video available at https://osf.io/{video.id}')\n", - " display(video)\n", - " tab_contents.append(out)\n", - " return tab_contents\n", - "\n", - "\n", - "video_ids = [('Youtube', 'aJhffROC74w'), ('Bilibili', 'BV1TD4y1D7K3')]\n", - "tab_contents = display_videos(video_ids, W=854, H=480)\n", - "tabs = widgets.Tab()\n", - "tabs.children = tab_contents\n", - "for i in range(len(tab_contents)):\n", - " tabs.set_title(i, video_ids[i][0])\n", - "display(tabs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {} - }, - "outputs": [], - "source": [ - "# @title Submit your feedback\n", - "content_review(f\"{feedback_prefix}_Evaluate_policy_Video\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "## Coding Exercise 5.1: Implementing a value function\n", - "\n", - "Let's find out how good our threshold is. For that, we will calculate a **value function** that quantifies our utility (total points). 
We will use this value to compare different thresholds; remember, our goal is to maximize the amount of fish we catch while minimizing the effort involved in changing locations.\n", - "\n", - "The value is the total expected utility per unit time.\n", - "\n", - "\\begin{equation}\n", - "V(\\theta) = \\frac{1}{T} \\sum_t \\left( R(s_t) - C(a_t) \\right)\n", - "\\end{equation}\n", - "\n", - "where $R(s_t)$ is the instantaneous reward we get at location $s_t$ and $C(a_t)$ is the cost we pay for the chosen action. Remember, we receive one point for each fish caught and pay `cost_sw` points for switching to the other location.\n", - "\n", - "We could take this average mathematically over the probabilities of rewards and actions. However, we can get the same answer by simply averaging the _actual_ rewards and costs over a long time. This is what you are going to do.\n", - "\n", - "\n", - "**Instructions**: Fill in the function `get_value(rewards, actions, cost_sw)`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "both", - "execution": {}, - "tags": [] - }, - "outputs": [], - "source": [ - "def get_value(rewards, actions, cost_sw):\n", - " \"\"\"\n", - " value function\n", - "\n", - " Args:\n", - " rewards (numpy array of length T): whether a reward is obtained (1) or not (0) at each time step\n", - " actions (numpy array of length T): action, \"stay\" or \"switch\", taken at each time step.\n", - " cost_sw (float): the cost of switching to the other location\n", - "\n", - " Returns:\n", - " value (float): expected utility per unit time\n", - " \"\"\"\n", - " actions_int = (actions == \"switch\").astype(int)\n", - "\n", - " ############################################################################\n", - " ## 1. Modify the code below to compute the value function (equation V(theta))\n", - " ##\n", - " ## 2. 
After completing the function, comment this line:\n", - " raise NotImplementedError(\"Student exercise: Please complete the code\")\n", - " ############################################################################\n", - " # Calculate the value function\n", - " value = ...\n", - "\n", - " return value\n", - "\n", - "\n", - "# Test your function\n", - "test_value_function()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "You will see\n", - "\n", - "```Well Done!```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": {}, - "tags": [] - }, - "outputs": [], - "source": [ - "# to_remove solution\n", - "\n", - "def get_value(rewards, actions, cost_sw):\n", - " \"\"\"\n", - " Args:\n", - " rewards (numpy array of length T): whether a reward is obtained (1) or not (0) at each time step\n", - " actions (numpy array of length T): action, \"stay\" or \"switch\", taken at each time step.\n", - " cost_sw (float): the cost of switching to the other location\n", - "\n", - " Returns:\n", - " value (float): expected utility per unit time\n", - " \"\"\"\n", - " actions_int = (actions == \"switch\").astype(int)\n", - "\n", - " # Calculate the value function\n", - " value = np.sum(rewards - actions_int * cost_sw) / len(rewards)\n", - "\n", - " return value\n", - "\n", - "\n", - "# Test your function\n", - "test_value_function()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {} - }, - "outputs": [], - "source": [ - "# @title Submit your feedback\n", - "content_review(f\"{feedback_prefix}_Implementing_a_value_function_Exercise\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "## Coding Exercise 5.2: Run the policy\n", - "\n", - "Now that you have a mechanism to find out how good a threshold is, we will use a brute force approach to **compute the optimal threshold**: we'll just try all thresholds, simulate the value of each, and pick the best one. Complete the function `get_optimal_threshold(p_stay, low_rew_p, high_rew_p, cost_sw)`. We provide the code to visualize the output of your function. Observe on this plot which threshold has maximal utility.\n", - "\n", - "**Thinking questions:**\n", - "\n", - "* Try a very high switching cost. What is the best threshold? How does that make sense?\n", - "* Try a zero switching cost. What's different?\n", - "* Generally, how does the best threshold change with the switching cost?" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": {} - }, - "outputs": [], - "source": [ - "def run_policy(threshold, p_stay, low_rew_p, high_rew_p):\n", - " \"\"\"\n", - " This function executes the policy (fully parameterized by the threshold) and\n", - " returns two arrays:\n", - " The sequence of actions taken from time 0 to T\n", - " The sequence of rewards obtained from time 0 to T\n", - " \"\"\"\n", - " params = [p_stay, low_rew_p, high_rew_p, threshold]\n", - " binaryHMM_test = binaryHMM_belief(params, choose_policy=\"threshold\")\n", - " _, _, actions, rewards, _ = binaryHMM_test.generate_process()\n", - "\n", - " return actions, rewards\n", - "\n", - "\n", - "def get_optimal_threshold(p_stay, low_rew_p, high_rew_p, cost_sw):\n", - " \"\"\"\n", - " Args:\n", - " p_stay (float): probability of fish staying in their current location\n", - " low_rew_p (float): probability of catching fish when you and the fish are in different locations.\n", - " high_rew_p (float): probability of catching fish when you and the fish are in the same location.\n", - " cost_sw (float): the cost of switching to the other location\n", - "\n", - " Returns:\n", - " threshold_array (numpy array of float): the candidate thresholds\n", - " value_array (numpy array of float): the value (expected utility per unit time) of each candidate threshold\n", - " \"\"\"\n", - " ############################################################################\n", - " ## 1. Modify the code below to find the best threshold using brute force\n", - " ##\n", - " ## 2. After completing the function, comment this line:\n", - " raise NotImplementedError(\"Student exercise: Please complete the code\")\n", - " ############################################################################\n", - " global T\n", - " T = 10000 # Setting a large time horizon\n", - " get_randomness(T)\n", - "\n", - " # Create an array of 20 equally spaced candidate thresholds (min = 0., max=1.):\n", - " threshold_array = ...\n", - "\n", - " # Using the function get_value() that you coded before and\n", - " # the function run_policy() that we provide, compute the value of your\n", - " # candidate thresholds:\n", - "\n", - " # Create an array to store the value of each of your candidates:\n", - " value_array = ...\n", - "\n", - " for i in ...:\n", - " actions, rewards = ...\n", - " value_array[i] = ...\n", - "\n", - " # Return the array of candidate thresholds and their respective values\n", - "\n", - " return threshold_array, value_array\n", - "\n", - "\n", - "# Feel free to change these parameters\n", - "stay_prob = .9\n", - "low_rew_prob = 0.1\n", - "high_rew_prob = 0.2\n", - "cost_sw = .1\n", - "\n", - "# Visually determine the threshold that obtains the maximum utility\n", - "threshold_array, value_array = get_optimal_threshold(stay_prob, low_rew_prob, high_rew_prob, cost_sw)\n", - "plot_value_threshold(threshold_array, value_array)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": {}, - "tags": [] - }, - "outputs": [], - "source": [ - "# to_remove solution\n", - "\n", - "def run_policy(threshold, p_stay, low_rew_p, high_rew_p):\n", - " \"\"\"\n", - " This function executes the policy (fully parameterized by the threshold) and\n", - " returns two arrays:\n", - " The sequence of actions taken from time 0 to T\n", - " The sequence of rewards obtained from time 0 to T\n", - " \"\"\"\n", - " params = [p_stay, low_rew_p, high_rew_p, threshold]\n", - " binaryHMM_test = binaryHMM_belief(params, choose_policy=\"threshold\")\n", - " _, _, actions, rewards, _ = binaryHMM_test.generate_process()\n", - " return 
actions, rewards\n", - "\n", - "\n", - "def get_optimal_threshold(p_stay, low_rew_p, high_rew_p, cost_sw):\n", - " \"\"\"\n", - " Args:\n", - " p_stay (float): probability of fish staying in their current location\n", - " low_rew_p (float): probability of catching fish when you and the fish are in different locations.\n", - " high_rew_p (float): probability of catching fish when you and the fish are in the same location.\n", - " cost_sw (float): the cost of switching to the other location\n", - "\n", - " Returns:\n", - " threshold_array (numpy array of float): the candidate thresholds\n", - " value_array (numpy array of float): the value (expected utility per unit time) of each candidate threshold\n", - " \"\"\"\n", - " global T\n", - " T = 10000 # Setting a large time horizon\n", - " get_randomness(T)\n", - "\n", - " # Create an array of 20 equally spaced candidate thresholds (min = 0., max=1.):\n", - " threshold_array = np.linspace(0., 1., 20)\n", - "\n", - " # Using the function get_value() that you coded before and\n", - " # the function run_policy() that we provide, compute the value of your\n", - " # candidate thresholds:\n", - "\n", - " # Create an array to store the value of each of your candidates:\n", - " value_array = np.zeros(len(threshold_array))\n", - "\n", - " for i in range(len(threshold_array)):\n", - " actions, rewards = run_policy(threshold_array[i], p_stay, low_rew_p, high_rew_p)\n", - " value_array[i] = get_value(rewards, actions, cost_sw)\n", - "\n", - " # Return the array of candidate thresholds and their respective values\n", - "\n", - " return threshold_array, value_array\n", - "\n", - "\n", - "# Feel free to change these parameters\n", - "stay_prob = .9\n", - "low_rew_prob = 0.1\n", - "high_rew_prob = 0.2\n", - "cost_sw = .1\n", - "\n", - "# Visually determine the threshold that obtains the maximum utility\n", - "threshold_array, value_array = get_optimal_threshold(stay_prob, low_rew_prob, high_rew_prob, cost_sw)\n", - "with plt.xkcd():\n", - " plot_value_threshold(threshold_array, value_array)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": {} - }, - "outputs": [], - "source": [ - "# to_remove explanation\n", - "\n", - "\"\"\"\n", - "* Try a very high switching cost. What is the best threshold? How does that make sense?\n", - "\n", - " You should see that there is a best threshold:\n", - " If it is too small, then you never move, missing opportunities to follow the fish.\n", - " If it is too large, then you move too often and pay a large cost for the switching.\n", - " When the switching cost is extremely high, it's never worth moving, so the optimal threshold is at zero.\n", - "\n", - "* Try a zero switching cost. What's different?\n", - "\n", - " When the switching cost is zero, it's not best to always switch, but rather to follow\n", - " the optimal inference about the fish location.\n", - "\n", - "* Generally, how does the best threshold change with the switching cost?\n", - "\n", - " As the switching cost rises, the threshold should fall because\n", - " you have even more incentive to avoid switches.\n", - "\"\"\";" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {} - }, - "outputs": [], - "source": [ - "# @title Submit your feedback\n", - "content_review(f\"{feedback_prefix}_Run_the_policy_Exercise_and_Discussion\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "---\n", - "# Summary\n", - "\n", - "In this tutorial, you combined Hidden Markov Models with actions to solve an optimal control problem! 
This showed us the core formalism of the *Partially Observable Markov Decision Process* (POMDP).\n", - "\n", - "Using observations (fish caught), you built beliefs (posterior distributions) that helped you estimate where the fish were. Next, you computed a value function that helped you evaluate the quality of different policies. Finally, using a brute force approach, you discovered an optimal policy that allowed you to catch as many fish as possible while minimizing the effort of switching your location.\n", - "\n", - "The following tutorial will use continuous states and actions instead of the binary ones we used here. In continuous control, we can still use a POMDP, but we'll focus on control in the *fully* observed case, a Markov Decision Process (MDP), since the policy is still illuminating." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {}, - "tags": [] - }, - "outputs": [], - "source": [ - "# @title Video 6: From discrete to continuous control\n", - "from ipywidgets import widgets\n", - "from IPython.display import YouTubeVideo\n", - "from IPython.display import IFrame\n", - "from IPython.display import display\n", - "\n", - "\n", - "class PlayVideo(IFrame):\n", - " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", - " self.id = id\n", - " if source == 'Bilibili':\n", - " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", - " elif source == 'Osf':\n", - " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", - " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", - "\n", - "\n", - "def display_videos(video_ids, W=400, H=300, fs=1):\n", - " tab_contents = []\n", - " for i, video_id in enumerate(video_ids):\n", - " out = widgets.Output()\n", - " with out:\n", - " if video_ids[i][0] == 'Youtube':\n", - " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", - " height=H, fs=fs, rel=0)\n", - " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", - " else:\n", - " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", - " height=H, fs=fs, autoplay=False)\n", - " if video_ids[i][0] == 'Bilibili':\n", - " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", - " elif video_ids[i][0] == 'Osf':\n", - " print(f'Video available at https://osf.io/{video.id}')\n", - " display(video)\n", - " tab_contents.append(out)\n", - " return tab_contents\n", - "\n", - "\n", - "video_ids = [('Youtube', 'ndCMgdjv9Gg'), ('Bilibili', 'BV1JA411v7jy')]\n", - "tab_contents = display_videos(video_ids, W=854, H=480)\n", - "tabs = widgets.Tab()\n", - "tabs.children = tab_contents\n", - "for i in range(len(tab_contents)):\n", - " tabs.set_title(i, video_ids[i][0])\n", - "display(tabs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {} - }, - "outputs": [], - "source": [ - "# @title Submit your feedback\n", - "content_review(f\"{feedback_prefix}_From_discrete_to_continuous_control_Video\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "---\n", - "# Bonus" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "## Bonus Section 1: How does the optimal policy depend on the task?" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {}, - "tags": [] - }, - "outputs": [], - "source": [ - "# @title Video 7: Sensitivity of optimal policy\n", - "from ipywidgets import widgets\n", - "from IPython.display import YouTubeVideo\n", - "from IPython.display import IFrame\n", - "from IPython.display import display\n", - "\n", - "\n", - "class PlayVideo(IFrame):\n", - " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", - " self.id = id\n", - " if source == 'Bilibili':\n", - " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", - " elif source == 'Osf':\n", - " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", - " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", - "\n", - "\n", - "def display_videos(video_ids, W=400, H=300, fs=1):\n", - " tab_contents = []\n", - " for i, video_id in enumerate(video_ids):\n", - " out = widgets.Output()\n", - " with out:\n", - " if video_ids[i][0] == 'Youtube':\n", - " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", - " height=H, fs=fs, rel=0)\n", - " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", - " else:\n", - " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", - " height=H, fs=fs, autoplay=False)\n", - " if video_ids[i][0] == 'Bilibili':\n", - " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", - " elif video_ids[i][0] == 'Osf':\n", - " print(f'Video available at https://osf.io/{video.id}')\n", - " display(video)\n", - " tab_contents.append(out)\n", - " return tab_contents\n", - "\n", - "\n", - "video_ids = [('Youtube', 'wd8IVsKoEfA'), ('Bilibili', 'BV1QK4y1e7N9')]\n", - "tab_contents = display_videos(video_ids, W=854, H=480)\n", - "tabs = widgets.Tab()\n", - "tabs.children = tab_contents\n", - "for i in range(len(tab_contents)):\n", - " tabs.set_title(i, video_ids[i][0])\n", - "display(tabs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {} - }, - "outputs": [], - "source": [ - "# @title Submit your feedback\n", - "content_review(f\"{feedback_prefix}_Sensitivity_of_optimal_policy_Bonus_Video\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": {} - }, - "source": [ - "### Bonus Interactive Demo 1: Explore task parameters\n", - "\n", - "In this demo, you can play with various task parameters. Observe how the optimal threshold changes when you adjust:\n", - "* The switching cost\n", - "* The fish dynamics (`p(stay)`)\n", - "* The probability of catching fish on each side, `p(high_rwd)` and `p(low_rwd)`\n", - "\n", - "Can you explain why the optimal threshold changes with these parameters:\n", - "\n", - "* lower/higher switching cost?\n", - "* faster fish dynamics (_i.e._, low `p_stay`)?\n", - "* rarer fish caught (_i.e._, low `p(high_rwd)` and low `p(low_rwd)`)?\n", - "\n", - "Note that it may require long simulations to see subtle changes in values of different policies, so look for coarse trends first." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {}, - "tags": [] - }, - "outputs": [], - "source": [ - "# @markdown Make sure you execute this cell to enable the widget!\n", - "display(HTML(''''''))\n", - "\n", - "@widgets.interact(p_stay=widgets.FloatSlider(.95, description=\"p(stay)\",\n", - " min=0., max=1.),\n", - " p_high_rwd=widgets.FloatSlider(.4, description=\"p(high_rwd)\",\n", - " min=0., max=1.),\n", - " p_low_rwd=widgets.FloatSlider(.1, description=\"p(low_rwd)\",\n", - " min=0., max=1.),\n", - " cost_sw=widgets.FloatSlider(.2, description=\"switching cost\",\n", - " min=0., max=1., step=.01))\n", - "\n", - "\n", - "def update_ex_bonus(p_stay, p_high_rwd, p_low_rwd, cost_sw):\n", - " \"\"\"\n", - " p_stay: probability fish stay\n", - " high_rew_p: p(catch fish) when you're on their side\n", - " low_rew_p : p(catch fish) when you're on other side\n", - " cost_sw: switching cost\n", - " \"\"\"\n", - "\n", - " threshold_array, value_array = get_optimal_threshold(p_stay,\n", - " p_low_rwd,\n", - " p_high_rwd,\n", - " cost_sw)\n", - " plot_value_threshold(threshold_array, value_array)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": {}, - "tags": [] - }, - "outputs": [], - "source": [ - "# to_remove explanation\n", - "\n", - "\"\"\"\n", - "* lower/higher switching cost?\n", - "\n", - " High switching cost means that you should be more certain that the other side\n", - " is better before committing to change sides. This means that beliefs must fall\n", - " below a threshold before acting. Conversely, a lower switching cost allows you\n", - " more flexibility to switch at less stringent thresholds. In the limit of _zero_\n", - " switching cost, you should always switch whenever you think the other side is\n", - " better, even if it's just 51%, and even if you switch every time step.\n", - "\n", - "* faster fish dynamics (i.e., low p_stay)?\n", - "\n", - " Faster fish dynamics (lower `p_stay`) make the fish location harder to predict, because\n", - " you cannot plan as far into the future. In that case you must base your decisions\n", - " on more immediate evidence, but since you still pay the same switching cost, that\n", - " cost is a higher fraction of your predictable rewards. 
Thus, you should be more\n", - " conservative and switch only when you are more confident.\n", - "\n", - "* rarer fish caught (i.e., low p(high_rwd) and low p(low_rwd))?\n", - "\n", - " When `high_rew_p` and/or `low_rew_p` decreases, your predictions become less reliable,\n", - " again encouraging you to require more confidence before committing to a switch.\n", - "\"\"\";" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "execution": {} - }, - "outputs": [], - "source": [ - "# @title Submit your feedback\n", - "content_review(f\"{feedback_prefix}_Explore_task_parameters_Bonus_Interactive_Demo_and_Discussion\")" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "include_colab_link": true, - "name": "W3D3_Tutorial1", - "provenance": [], - "toc_visible": true - }, - "kernel": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.17" - }, - "pycharm": { - "stem_cell": { - "cell_type": "raw", - "metadata": { - "collapsed": false - }, - "source": [ - "# Tutorial 1- Optimal Control for Discrete State\n", - "\n", - "Please execute the cell below to initialize the notebook environment.\n", - "\n", - "import numpy as np # import numpy\n", - "import scipy # import scipy\n", - "import random # import basic random number generator functions\n", - "from scipy.linalg import inv\n", - "\n", - "import matplotlib.pyplot as plt # import matplotlib\n", - "\n", - "---\n", - "\n", - "## Tutorial objectives\n", - "\n", - "In this tutorial, we will implement a binary HMM task.\n", - "\n", - "---\n", - "\n", - "## Task Description\n", - "\n", - "There are two boxes. The box can be in a high-rewarding state ($s=1$), which means that a reward will be delivered with high probabilty $q_{high}$; or the box can be in low-rewarding state ($s=0$), then the reward will be delivered with low probabilty $q_{low}$.\n", - "\n", - "The states of the two boxes are latent. At a certain time, only one of the sites can be in high-rewarding state, and the other box will be the opposite. The states of the two boxes switches with a certain probability $p_{sw}$. \n", - "\n", - "![alt text](switching.png \"Title\")\n", - "\n", - "\n", - "The agent may stay at one site for sometime. As the agent accumulates evidence about the state of the box on that site, it may choose to stay or switch to the other side with a switching cost $c$. The agent keeps beliefs on the states of the boxes, which is the posterior probability of the state being high-rewarding given all the past observations. Consider the belief on the state of the left box, we have \n", - "\n", - "$$b(s_t) = p(s_t = 1 | o_{0:t}, l_{0:t}, a_{0:t-1})$$\n", - "\n", - "where $o$ is the observation that whether a reward is obtained, $l$ is the location of the agent, $a$ is the action of staying ($a=0$) or switching($a=1$). \n", - "\n", - "Since the two boxes are completely anti-correlated, i.e. only one of the boxes is high-rewarded at a certain time, the the other one is low-rewarded, the belief on the two boxes should sum up to be 1. As a result, we only need to track the belief on one of the boxes. 
\n", - "\n", - "The policy of the agent depends on a threshold on beliefs. When the belief on the box on the other side gets higher than the threshold $\\theta$, the agent will switch to the other side. In other words, the agent will choose to switch when it is confident enough that the other side is high rewarding. \n", - "\n", - "The value function can be defined as the reward rate during a single trial.\n", - "\n", - "$$v(\\theta) = \\sum_t r_t - c\\cdot 1_{a_t = 1}$$ \n", - "\n", - "we would like to see the relation between the threshold and the value function. \n", - "\n", - "### Exercise 1: Control for binary HMM\n", - "In this excercise, we generate the dynamics for the binary HMM task as described above. \n", - "\n", - "# This function is the policy based on threshold\n", - "\n", - "def policy(threshold, bel, loc):\n", - " if loc == 0:\n", - " if bel[1] >= threshold:\n", - " act = 1\n", - " else:\n", - " act = 0\n", - " else: # loc = 1\n", - " if bel[0] >= threshold:\n", - " act = 1\n", - " else:\n", - " act = 0\n", - "\n", - " return act\n", - "\n", - "# This function generates the dynamics\n", - "\n", - "def generateProcess(params):\n", - "\n", - " T, p_sw, q_high, q_low, cost_sw, threshold = params\n", - " world_state = np.zeros((2, T), int) # value :1: good box; 0: bad box\n", - " loc = np.zeros(T, int) # 0: left box 1: right box\n", - " obs = np.zeros(T, int) # 0: did not get food 1: get food\n", - " act = np.zeros(T, int) # 0 : stay 1: switch and get food from the other side\n", - " bel = np.zeros((2, T), float) # the probability that the left box has food,\n", - " # then the probability that the second box has food is 1-b\n", - "\n", - "\n", - " p = np.array([1 - p_sw, p_sw]) # transition probability to good state\n", - " q = np.array([q_low, q_high])\n", - " q_mat = np.array([[1 - q_high, q_high], [1 - q_low, q_low]])\n", - "\n", - " for t in range(T):\n", - " if t == 0:\n", - " world_state[0, t] = 1 # good box\n", - " world_state[1, t] = 1 - world_state[0, t]\n", - " loc[t] = 0\n", - " obs[t] = 0\n", - " bel_0 = np.random.random(1)[0]\n", - " bel[:, t] = np.array([bel_0, 1-bel_0])\n", - "\n", - " act[t] = policy(threshold, bel[:, t], loc[t])\n", - "\n", - " else:\n", - " world_state[0, t] = np.random.binomial(1, p[world_state[0, t - 1]])\n", - " world_state[1, t] = 1 - world_state[0, t]\n", - "\n", - " if act[t - 1] == 0:\n", - " loc[t] = loc[t - 1]\n", - " else: # after weitching, open the new box, deplete if any; then wait a usualy time\n", - " loc[t] = 1 - loc[t - 1]\n", - "\n", - " # new observation\n", - " obs[t] = np.random.binomial(1, q[world_state[loc[t], t-1]])\n", - "\n", - " # update belief posterior, p(s[t] | obs(0-t), act(0-t-1))\n", - " bel_0 = (bel[0, t-1] * p_sw + bel[1, t-1] * (1 - p_sw)) * q_mat[loc[t], obs[t]]\n", - " bel_1 = (bel[1, t - 1] * p_sw + bel[0, t - 1] * (1 - p_sw)) * q_mat[1-loc[t], obs[t]]\n", - "\n", - " bel[0, t] = bel_0 / (bel_0 + bel_1)\n", - " bel[1, t] = bel_1 / (bel_0 + bel_1)\n", - "\n", - " act[t] = policy(threshold, bel[:, t], loc[t])\n", - "\n", - " return bel, obs, act, world_state, loc\n", - "\n", - "# value function \n", - "def value_function(obs, act, cost_sw, discount):\n", - " T = len(obs)\n", - " discount_time = np.array([discount ** t for t in range(T)])\n", - "\n", - " #value = (np.sum(obs) - np.sum(act) * cost_sw) / T\n", - " value = (np.sum(np.multiply(obs, discount_time)) - np.sum(np.multiply(act, discount_time)) * cost_sw) / T\n", - "\n", - " return value\n", - "\n", - "def switch_int(obs, act):\n", - " sw_t = 
np.where(act == 1)[0]\n", - " sw_int = sw_t[1:] - sw_t[:-1]\n", - "\n", - " return sw_int\n", - "\n", - "#Plotting \n", - "def plot_dynamics(bel, obs, act, world_state, loc):\n", - " T = len(obs)\n", - "\n", - " showlen = min(T, 100)\n", - " startT = 0\n", - "\n", - " endT = startT + showlen\n", - " showT = range(startT, endT)\n", - " time_range = np.linspace(0, showlen - 1)\n", - "\n", - " fig_posterior, [ax0, ax1, ax_loc, ax2, ax3] = plt.subplots(5, 1, figsize=(15, 10))\n", - "\n", - " ax0.plot(world_state[0, showT], color='dodgerblue', markersize=10, linewidth=3.0)\n", - " ax0.set_ylabel('Left box', rotation=360, fontsize=22)\n", - " ax0.yaxis.set_label_coords(-0.1, 0.25)\n", - " ax0.set_xticks(np.arange(0, showlen, 10))\n", - " ax0.tick_params(axis='both', which='major', labelsize=18)\n", - " ax0.set_xlim([0, showlen])\n", - "\n", - "\n", - " ax3.plot(world_state[1, showT], color='dodgerblue', markersize=10, linewidth=3.0)\n", - " ax3.set_ylabel('Right box', rotation=360, fontsize=22)\n", - " ax3.yaxis.set_label_coords(-0.1, 0.25)\n", - " ax3.tick_params(axis='both', which='major', labelsize=18)\n", - " ax3.set_xlim([0, showlen])\n", - " ax3.set_xticks(np.arange(0, showlen, 10))\n", - "\n", - " ax1.plot(bel[0, showT], color='dodgerblue', markersize=10, linewidth=3.0)\n", - " ax1.plot(time_range, threshold * np.ones(time_range.shape), 'r--')\n", - " ax1.yaxis.set_label_coords(-0.1, 0.25)\n", - " ax1.set_ylabel('Belief on \\n left box', rotation=360, fontsize=22)\n", - " ax1.tick_params(axis='both', which='major', labelsize=18)\n", - " ax1.set_xlim([0, showlen])\n", - " ax1.set_ylim([0, 1])\n", - " ax1.set_xticks(np.arange(0, showlen, 10))\n", - "\n", - "\n", - " ax_loc.plot(1 - loc[showT], 'g.-', markersize=12, linewidth=5, label = 'location')\n", - " ax_loc.plot((act[showT] - .1) * .8, 'v', markersize=10, label = 'action')\n", - " ax_loc.plot(obs[showT] * .5, '*', markersize=5, label = 'reward')\n", - " ax_loc.legend(loc=\"upper right\")\n", - " ax_loc.set_xlim([0, showlen])\n", - " ax_loc.set_ylim([0, 1])\n", - " #ax_loc.set_yticks([])\n", - " ax_loc.set_xticks([0, showlen])\n", - " ax_loc.tick_params(axis='both', which='major', labelsize=18)\n", - " labels = [item.get_text() for item in ax_loc.get_yticklabels()]\n", - " labels[0] = 'Right'\n", - " labels[-1] = 'Left'\n", - " ax_loc.set_yticklabels(labels)\n", - "\n", - " ax2.plot(bel[1, showT], color='dodgerblue', markersize=10, linewidth=3.0)\n", - " ax2.plot(time_range, threshold * np.ones(time_range.shape), 'r--')\n", - " ax2.set_xlabel('time', fontsize=18)\n", - " ax2.yaxis.set_label_coords(-0.1, 0.25)\n", - " ax2.set_ylabel('Belief on \\n right box', rotation=360, fontsize=22)\n", - " ax2.tick_params(axis='both', which='major', labelsize=18)\n", - " ax2.set_xlim([0, showlen])\n", - " ax2.set_ylim([0, 1])\n", - " ax2.set_xticks(np.arange(0, showlen, 10))\n", - "\n", - " plt.show()\n", - "\n", - "def plot_val_thre(threshold_array, value_array):\n", - " fig_, ax = plt.subplots(1, 1, figsize=(10, 10))\n", - " ax.plot(threshold_array, value_array)\n", - " ax.set_ylim([np.min(value_array), np.max(value_array)])\n", - " ax.set_title('threshold vs value')\n", - " ax.set_xlabel('threshold')\n", - " ax.set_ylabel('value')\n", - " plt.show()\n", - "\n", - "T = 5000\n", - "p_sw = .95 # state transiton probability\n", - "q_high = .7\n", - "q_low = 0 #.2\n", - "cost_sw = 1 #int(1/(1-p_sw)) - 5\n", - "threshold = .8 # threshold of belief for switching\n", - "discount = 1\n", - "\n", - "step = 0.1\n", - "threshold_array = np.arange(0, 1 + step, 
step)\n", - "value_array = np.zeros(threshold_array.shape)\n", - "\n", - "for i in range(len(threshold_array)):\n", - " threshold = threshold_array[i]\n", - " params = [T, p_sw, q_high, q_low, cost_sw, threshold]\n", - " bel, obs, act, world_state, loc = generateProcess(params)\n", - " value_array[i] = value_function(obs, act, cost_sw, discount)\n", - " sw_int = switch_int(obs, act)\n", - " #print(np.mean(sw_int))\n", - "\n", - " if threshold == 0.8:\n", - " plot_dynamics(bel, obs, act, world_state, loc)\n", - "\n", - "plot_val_thre(threshold_array, value_array)\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ] - } - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "kjpX2qMZZX_J" + }, + "source": [ + "# Tutorial 1: Optimal Control for Discrete States\n", + "\n", + "**Week 3, Day 3: Optimal Control**\n", + "\n", + "**By Neuromatch Academy**\n", + "\n", + "**Content creators:** Zhengwei Wu, Itzel Olivos Castillo, Shreya Saxena, Xaq Pitkow\n", + "\n", + "**Content reviewers:** Karolina Stosio, Roozbeh Farhoodi, Saeed Salehi, Ella Batty, Spiros Chavlis, Matt Krause, Michael Waskom, Melisa Maidana Capitan\n", + "\n", + "**Production editors:** Spiros Chavlis\n", + "\n", + "**Post-Production editors**: Gagana B, Spiros Chavlis" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "SX1RJ04PZX_N" + }, + "source": [ + "---\n", + "# Tutorial Objectives\n", + "\n", + "*Estimated timing of tutorial: 60 min*\n", + "\n", + "In this tutorial, we will implement a **binary control** task: a Partially Observable Markov Decision Process (POMDP) that describes fishing. The agent (you) seeks reward from two fishing sites without directly observing where the school of fish is (yes, a group of fish is called a school!). This makes the world a Hidden Markov Model (HMM), just like in the *Hidden Dynamics* day. Based on when and where you catch fish, you keep updating your belief about the fish location, i.e., the posterior of the fish given past observations. You should control your position to get the most fish while minimizing the cost of switching sides.\n", + "\n", + "You've already learned about stochastic dynamics, latent states, and measurements. These first exercises largely repeat your previous work. Now we introduce **actions**, based on the new concepts of **control, utility, and policy**. This general structure provides a foundational model for the brain's computations because it includes a perception-action loop where the animal can gather information, draw inferences about its environment, and select actions with the greatest benefit. *How*, mechanistically, the neurons could actually implement these calculations is a separate question we don't address in this lesson.\n", + "\n", + "In this tutorial, you will:\n", + "* Use the Hidden Markov Models you learned about previously to model the world state.\n", + "* Use the observations (fish caught) to build beliefs (posterior distributions) about the fish location.\n", + "* Evaluate the quality of different control policies for choosing actions.\n", + "* Discover the policy that maximizes utility." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "tags": [], + "id": "5XfqRLMRZX_P" + }, + "outputs": [], + "source": [ + "# @title Tutorial slides\n", + "# @markdown These are the slides for all videos in this tutorial.\n", + "from IPython.display import IFrame\n", + "link_id = \"8j5rs\"\n", + "print(f\"If you want to download the slides: https://osf.io/download/{link_id}/\")\n", + "IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/{link_id}/?direct%26mode=render%26action=download%26mode=render\", width=854, height=480)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "b857uF6aZX_Q" + }, + "source": [ + "---\n", + "## Setup\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "id": "qePPOIv4ZX_R" + }, + "outputs": [], + "source": [ + "# @title Install and import feedback gadget\n", + "\n", + "!pip3 install vibecheck datatops --quiet\n", + "\n", + "from vibecheck import DatatopsContentReviewContainer\n", + "def content_review(notebook_section: str):\n", + " return DatatopsContentReviewContainer(\n", + " \"\", # No text prompt\n", + " notebook_section,\n", + " {\n", + " \"url\": \"https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab\",\n", + " \"name\": \"neuromatch_cn\",\n", + " \"user_key\": \"y1x3mpx5\",\n", + " },\n", + " ).render()\n", + "\n", + "\n", + "feedback_prefix = \"W3D3_T1\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": {}, + "tags": [], + "id": "LoIEPZzFZX_R" + }, + "outputs": [], + "source": [ + "# Imports\n", + "import numpy as np\n", + "from math import isclose\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "tags": [], + "id": "erZHk4hWZX_T" + }, + "outputs": [], + "source": [ + "# @title Figure Settings\n", + "import logging\n", + "logging.getLogger('matplotlib.font_manager').disabled = True\n", + "\n", + "import ipywidgets as widgets\n", + "from IPython.display import HTML\n", + "%config InlineBackend.figure_format = 'retina'\n", + "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "tags": [], + "id": "QwYWNsotZX_U" + }, + "outputs": [], + "source": [ + "# @title Plotting Functions\n", + "\n", + "def plot_fish(fish_state, ax=None, show=True):\n", + " \"\"\"\n", + " Plot the fish dynamics (states across time)\n", + " \"\"\"\n", + " T = len(fish_state)\n", + "\n", + " offset = 3\n", + "\n", + " if not ax:\n", + " fig, ax = plt.subplots(1, 1, figsize=(12, 3.5))\n", + "\n", + " x = np.arange(0, T, 1)\n", + " y = offset * (fish_state*2 - 1)\n", + "\n", + " ax.plot(y, color='cornflowerblue', markersize=10, linewidth=3.0, zorder=0)\n", + " ax.fill_between(x, y, color='cornflowerblue', alpha=.3)\n", + "\n", + " ax.set_xlabel('time')\n", + " ax.set_ylabel('fish location')\n", + "\n", + " ax.set_xlim([0, T])\n", + " ax.set_xticks([])\n", + " ax.xaxis.set_label_coords(1.05, .54)\n", + "\n", + " ax.set_ylim([-(offset+.5), offset+.5])\n", + " ax.set_yticks([-offset, offset])\n", + " ax.set_yticklabels(['left', 'right'])\n", + "\n", + " ax.spines['bottom'].set_position('center')\n", + " if show:\n", + " plt.show()\n", + "\n", + "\n", + "def 
plot_measurement(measurement, ax=None, show=True):\n", + " \"\"\"\n", + " Plot the measurements\n", + " \"\"\"\n", + " T = len(measurement)\n", + "\n", + " rel_pos = 3\n", + " red_y = []\n", + " blue_y = []\n", + " for idx, value in enumerate(measurement):\n", + " if value == 0:\n", + " blue_y.append([idx, -rel_pos])\n", + " else:\n", + " red_y.append([idx, rel_pos])\n", + "\n", + " red_y = np.asarray(red_y)\n", + " blue_y = np.asarray(blue_y)\n", + "\n", + " if not ax:\n", + " fig, ax = plt.subplots(1, 1, figsize=(12, 3.5))\n", + "\n", + " if len(red_y) > 0:\n", + " ax.plot(red_y[:, 0], red_y[:, 1], '*', markersize=8, color='crimson')\n", + "\n", + " if len(blue_y) > 0:\n", + " ax.plot(blue_y[:, 0], blue_y[:, 1], '*', markersize=8, color='royalblue')\n", + "\n", + " ax.set_xlabel('time', fontsize=18)\n", + " ax.set_ylabel('Caught fish?')\n", + "\n", + " ax.set_xlim([0, T])\n", + " ax.set_xticks([])\n", + " ax.xaxis.set_label_coords(1.05, .54)\n", + "\n", + " ax.set_ylim([-rel_pos - .5, rel_pos + .5])\n", + " ax.set_yticks([-rel_pos, rel_pos])\n", + " ax.set_yticklabels(['no', 'yes!'])\n", + "\n", + " ax.spines['bottom'].set_position('center')\n", + " if show:\n", + " plt.show()\n", + "\n", + "\n", + "def plot_act_loc(loc, act, ax_loc=None, show=True):\n", + " \"\"\"\n", + " Plot the action and location of T time points\n", + " \"\"\"\n", + " T = len(act)\n", + "\n", + " if not ax_loc:\n", + " fig, ax_loc = plt.subplots(1, 1, figsize=(12, 2.5))\n", + "\n", + " loc = loc*2 - 1\n", + " act_down = []\n", + " act_up = []\n", + " for t in range(1, T):\n", + " if loc[t-1] == -1 and loc[t] == 1:\n", + " act_up.append([t - 0.5, 0])\n", + " if loc[t-1] == 1 and loc[t] == -1:\n", + " act_down.append([t - 0.5, 0])\n", + "\n", + " act_down = np.array(act_down)\n", + " act_up = np.array(act_up)\n", + "\n", + " ax_loc.plot(loc, 'g.-', markersize=8, linewidth=5)\n", + "\n", + " if len(act_down) > 0:\n", + " ax_loc.plot(act_down[:, 0], act_down[:, 1], 'rv', markersize=18, zorder=10, label='switch')\n", + "\n", + " if len(act_up) > 0:\n", + " ax_loc.plot(act_up[:, 0], act_up[:, 1], 'r^', markersize=18, zorder=10)\n", + "\n", + " ax_loc.set_xlabel('time')\n", + " ax_loc.set_ylabel('Your state')\n", + "\n", + " ax_loc.set_xlim([0, T])\n", + " ax_loc.set_xticks([])\n", + " ax_loc.xaxis.set_label_coords(1.05, .54)\n", + "\n", + " if len(act_down) > 0:\n", + " ax_loc.legend(loc=\"upper right\")\n", + " elif len(act_down) == 0 and len(act_up) > 0:\n", + " ax_loc.plot(act_up[:, 0], act_up[:, 1], 'r^', markersize=18, zorder=10, label='switch')\n", + " ax_loc.legend(loc=\"upper right\")\n", + "\n", + " ax_loc.set_ylim([-1.1, 1.1])\n", + " ax_loc.set_yticks([-1, 1])\n", + "\n", + " ax_loc.tick_params(axis='both', which='major')\n", + " ax_loc.set_yticklabels(['left', 'right'])\n", + "\n", + " ax_loc.spines['bottom'].set_position('center')\n", + "\n", + " if show:\n", + " plt.show()\n", + "\n", + "\n", + "def plot_belief(belief, ax1=None, choose_policy=None, show=True):\n", + " \"\"\"\n", + " Plot the belief dynamics of T time points\n", + " \"\"\"\n", + "\n", + " T = belief.shape[1]\n", + "\n", + " if not ax1:\n", + " fig, ax1 = plt.subplots(1, 1, figsize=(12, 2.5))\n", + "\n", + " ax1.plot(belief[1, :], color='midnightblue', markersize=10, linewidth=3.0)\n", + "\n", + " ax1.set_xlabel('time')\n", + " ax1.set_ylabel('Belief (right)')\n", + "\n", + " ax1.set_xlim([0, T])\n", + " ax1.set_xticks([])\n", + " ax1.xaxis.set_label_coords(1.05, 0.05)\n", + "\n", + " ax1.set_yticks([0, 1])\n", + " ax1.set_ylim([0, 1.1])\n", 
+ "\n", + " labels = [item.get_text() for item in ax1.get_yticklabels()]\n", + " ax1.set_yticklabels([' 0', ' 1'])\n", + "\n", + " \"\"\"\n", + " if choose_policy == \"threshold\":\n", + " ax2 = ax1.twinx()\n", + " ax2.plot(time_range, threshold * np.ones(time_range.shape), 'r--')\n", + " ax2.plot(time_range, (1 - threshold) * np.ones(time_range.shape), 'c--')\n", + " ax2.set_yticks([threshold, 1 - threshold])\n", + " ax2.set_ylim([0, 1.1])\n", + " ax2.tick_params(axis='both', which='major', labelsize=18)\n", + " labels = [item.get_text() for item in ax2.get_yticklabels()]\n", + " labels[0] = 'threshold to switch \\n from left to right'\n", + " labels[-1] = 'threshold to switch \\n from right to left'\n", + " ax2.set_yticklabels(labels)\n", + " \"\"\"\n", + " if show:\n", + " plt.show()\n", + "\n", + "\n", + "def plot_dynamics(belief, loc, act, meas, fish_state, choose_policy):\n", + " \"\"\"\n", + " Plot the dynamics of T time points\n", + " \"\"\"\n", + " if choose_policy == 'threshold':\n", + " fig, [ax0, ax_bel, ax_loc, ax1] = plt.subplots(4, 1, figsize=(12, 9))\n", + " plot_fish(fish_state, ax=ax0, show=False)\n", + " plot_belief(belief, ax1=ax_bel, show=False)\n", + " plot_measurement(meas, ax=ax1, show=False)\n", + " plot_act_loc(loc, act, ax_loc=ax_loc)\n", + " else:\n", + " fig, [ax0, ax_bel, ax1] = plt.subplots(3, 1, figsize=(12, 7))\n", + " plot_fish(fish_state, ax=ax0, show=False)\n", + " plot_belief(belief, ax1=ax_bel, show=False)\n", + " plot_measurement(meas, ax=ax1, show=False)\n", + "\n", + " plt.tight_layout()\n", + " plt.show()\n", + "\n", + "\n", + "def belief_histogram(belief, bins=100):\n", + " \"\"\"\n", + " Plot the histogram of belief states\n", + " \"\"\"\n", + " fig, ax = plt.subplots(1, 1, figsize=(8, 6))\n", + " ax.hist(belief, bins)\n", + " ax.set_xlabel('belief', fontsize=18)\n", + " ax.set_ylabel('count', fontsize=18)\n", + " plt.show()\n", + "\n", + "\n", + "def plot_value_threshold(threshold_array, value_array):\n", + " \"\"\"\n", + " Helper function to plot the value function and threshold\n", + " \"\"\"\n", + " yrange = np.max(value_array) - np.min(value_array)\n", + " star_loc = np.argmax(value_array)\n", + "\n", + " fig_, ax = plt.subplots(1, 1, figsize=(8, 6))\n", + " ax.plot(threshold_array, value_array, 'b')\n", + " ax.vlines(threshold_array[star_loc],\n", + " min(value_array) - yrange * .1, max(value_array),\n", + " colors='red', ls='--')\n", + " ax.plot(threshold_array[star_loc],\n", + " value_array[star_loc],\n", + " '*', color='crimson',\n", + " markersize=20)\n", + "\n", + " ax.set_ylim([np.min(value_array) - yrange * .1,\n", + " np.max(value_array) + yrange * .1])\n", + " ax.set_title(f'threshold vs value with switching cost c = {cost_sw:.2f}',\n", + " fontsize=20)\n", + " ax.set_xlabel('threshold', fontsize=16)\n", + " ax.set_ylabel('value', fontsize=16)\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "tags": [], + "id": "EwPYbXxMZX_X" + }, + "outputs": [], + "source": [ + "# @title Helper Functions\n", + "\n", + "# To generate a binomial with fixed \"noise\",\n", + "# we generate a sequence of T numbers uniformly at random\n", + "T = 100\n", + "\n", + "rnd_tele = np.random.uniform(0, 1, T)\n", + "rnd_high_rwd = np.random.uniform(0, 1, T)\n", + "rnd_low_rwd = np.random.uniform(0, 1, T)\n", + "\n", + "\n", + "def get_randomness(T):\n", + " global rnd_tele\n", + " global rnd_high_rwd\n", + " global rnd_low_rwd\n", + "\n", + " rnd_tele = 
np.random.uniform(0, 1, T)\n", + " rnd_high_rwd = np.random.uniform(0, 1, T)\n", + " rnd_low_rwd = np.random.uniform(0, 1, T)\n", + "\n", + "\n", + "def binomial_tele(p):\n", + " return np.array([1 if p > rnd_tele[i] else 0 for i in range(T)])\n", + "\n", + "\n", + "def getRandomness(p, largeT):\n", + " global rnd_tele\n", + " global rnd_high_rwd\n", + " global rnd_low_rwd\n", + "\n", + " rnd_tele = np.random.uniform(0, 1, largeT)\n", + " rnd_high_rwd = np.random.uniform(0, 1, largeT)\n", + " rnd_low_rwd = np.random.uniform(0, 1, largeT)\n", + "\n", + " return [np.array([1 if p > rnd_tele[i] else 0 for i in range(T)]),\n", + " rnd_high_rwd, rnd_low_rwd]\n", + "\n", + "# def binomial_high_rwd(p):\n", + "# return np.array([1 if p > rnd_high_rwd[i] else 0 for i in range(T)])\n", + "\n", + "# def binomial_low_rwd(p):\n", + "# return np.array([1 if p > rnd_low_rwd[i] else 0 for i in range(T)])\n", + "\n", + "\n", + "class ExcerciseError(AssertionError):\n", + " pass\n", + "\n", + "\n", + "class binaryHMM():\n", + "\n", + " def __init__(self, params, fish_initial=0, loc_initial=0):\n", + " self.params = params\n", + " self.fish_initial = fish_initial\n", + " self.loc_initial = loc_initial\n", + "\n", + " def fish_dynamics(self):\n", + " \"\"\"\n", + " fish state dynamics according to telegraph process\n", + "\n", + " Returns:\n", + " fish_state (numpy array of int)\n", + " \"\"\"\n", + " p_stay, _, _, _ = self.params\n", + " fish_state = np.zeros(T, int) # 0: left side and 1: right side\n", + "\n", + " # initialization\n", + " fish_state[0] = self.fish_initial\n", + " tele_operations = binomial_tele(p_stay) # 0: switch and 1: stay\n", + "\n", + " for t in range(1, T):\n", + " # we use logical operation NOT XOR to determine the next state\n", + " fish_state[t] = int(not(fish_state[t-1] ^ tele_operations[t]))\n", + "\n", + " return fish_state\n", + "\n", + " def generate_process_lazy(self):\n", + " \"\"\"\n", + " fish dynamics and rewards if you always stay in the initial location\n", + " without changing sides\n", + "\n", + " Returns:\n", + " fish_state (numpy array of int): locations of the fish\n", + " loc (numpy array of int): left or right site, 0 for left, and 1 for right\n", + " rwd (numpy array of binary): whether a fish was catched or not\n", + " \"\"\"\n", + "\n", + " _, p_low_rwd, p_high_rwd, _ = self.params\n", + "\n", + " fish_state = self.fish_dynamics()\n", + " rwd = np.zeros(T, int) # 0: no food, 1: get food\n", + "\n", + " for t in range(0, T):\n", + " # new measurement\n", + " if fish_state[t] != self.loc_initial:\n", + " rwd[t] = 1 if p_low_rwd > rnd_low_rwd[t] else 0\n", + " else:\n", + " rwd[t] = 1 if p_high_rwd > rnd_high_rwd[t] else 0\n", + "\n", + " # rwd[t] = binomial(1, p_rwd_vector[(fish_state[t] == loc[t]) * 1])\n", + " return fish_state, self.loc_initial*np.ones(T), rwd\n", + "\n", + "\n", + "class binaryHMM_belief(binaryHMM):\n", + "\n", + " def __init__(self, params,\n", + " fish_initial=0, loc_initial=1,\n", + " choose_policy='threshold'):\n", + "\n", + " binaryHMM.__init__(self, params, fish_initial, loc_initial)\n", + " self.choose_policy = choose_policy\n", + "\n", + " def generate_process(self):\n", + " \"\"\"\n", + " fish dynamics and measurements based on the chosen policy\n", + "\n", + " Returns:\n", + " belief (numpy array of float): belief on the states of the two sites\n", + " act (numpy array of string): actions over time\n", + " loc (numpy array of int): left or right site\n", + " measurement (numpy array of binary): whether a reward is obtained\n", + " 
fish_state (numpy array of int): fish locations\n", + " \"\"\"\n", + "\n", + " p_stay, low_rew_p, high_rew_p, threshold = self.params\n", + " fish_state = self.fish_dynamics() # 0: left side; 1: right side\n", + " loc = np.zeros(T, int) # 0: left side, 1: right side\n", + " measurement = np.zeros(T, int) # 0: no food, 1: get food\n", + " act = np.empty(T, dtype='object') # \"stay\", or \"switch\"\n", + " belief = np.zeros((2, T), float) # the probability that the fish is on the left (1st element)\n", + " # or on the right (2nd element),\n", + " # the beliefs on the two boxes sum up to be 1\n", + "\n", + " rew_prob = np.array([low_rew_p, high_rew_p])\n", + "\n", + " # initialization\n", + " loc[0] = self.loc_initial\n", + " measurement[0] = 0\n", + " belief_0 = np.random.random(1)[0]\n", + " belief[:, 0] = np.array([belief_0, 1 - belief_0])\n", + " act[0] = self.policy(threshold, belief[:, 0], loc[0])\n", + "\n", + " for t in range(1, T):\n", + " if act[t - 1] == \"stay\":\n", + " loc[t] = loc[t - 1]\n", + " else:\n", + " loc[t] = int(not(loc[t - 1] ^ 0))\n", + "\n", + " # new measurement\n", + " # measurement[t] = binomial(1, rew_prob[(fish_state[t] == loc[t]) * 1])\n", + " if fish_state[t] != loc[t]:\n", + " measurement[t] = 1 if low_rew_p > rnd_low_rwd[t] else 0\n", + " else:\n", + " measurement[t] = 1 if high_rew_p > rnd_high_rwd[t] else 0\n", + "\n", + " belief[0, t] = self.belief_update(belief[0, t - 1] , loc[t],\n", + " measurement[t], p_stay,\n", + " high_rew_p, low_rew_p)\n", + " belief[1, t] = 1 - belief[0, t]\n", + "\n", + " act[t] = self.policy(threshold, belief[:, t], loc[t])\n", + "\n", + " return belief, loc, act, measurement, fish_state\n", + "\n", + " def policy(self, threshold, belief, loc):\n", + " \"\"\"\n", + " chooses policy based on whether it is lazy policy\n", + " or a threshold-based policy\n", + "\n", + " Args:\n", + " threshold (float): the threshold of belief on the current site,\n", + " when the belief is lower than the threshold, switch side\n", + " belief (numpy array of float): the belief on the two sites\n", + " loc (int) : the location of the agent\n", + "\n", + " Returns:\n", + " act (string): \"stay\" or \"switch\"\n", + " \"\"\"\n", + " if self.choose_policy == \"threshold\":\n", + " act = policy_threshold(threshold, belief, loc)\n", + " if self.choose_policy == \"lazy\":\n", + " act = policy_lazy(belief, loc)\n", + "\n", + " return act\n", + "\n", + " def belief_update(self, belief_past, loc, measurement, p_stay,\n", + " high_rew_p, low_rew_p):\n", + " \"\"\"\n", + " using PAST belief on the LEFT box, CURRENT location and\n", + " and measurement to update belief\n", + " \"\"\"\n", + " rew_prob_matrix = np.array([[1 - high_rew_p, high_rew_p],\n", + " [1 - low_rew_p, low_rew_p]])\n", + "\n", + " # update belief posterior, p(s[t] | measurement(0-t), act(0-t-1))\n", + " belief_0 = (belief_past * p_stay + (1 - belief_past) * (1 - p_stay)) *\\\n", + " rew_prob_matrix[(loc + 1) // 2, measurement]\n", + " belief_1 = ((1 - belief_past) * p_stay + belief_past * (1 - p_stay)) *\\\n", + " rew_prob_matrix[1-(loc + 1) // 2, measurement]\n", + "\n", + " belief_0 = belief_0 / (belief_0 + belief_1)\n", + "\n", + " return belief_0\n", + "\n", + "\n", + "def policy_lazy(belief, loc):\n", + " \"\"\"\n", + " This function is a lazy policy where stay is also taken\n", + " \"\"\"\n", + " act = \"stay\"\n", + "\n", + " return act\n", + "\n", + "\n", + "def test_policy_threshold():\n", + " well_done = True\n", + " for loc in [-1, 1]:\n", + " threshold = 0.4\n", + " belief = 
np.array([.2, .3])\n", + " if policy_threshold(threshold, belief, loc) != \"switch\":\n", + " raise ExcerciseError(\"'policy_threshold' function is not correctly implemented!\")\n", + " for loc in [1, -1]:\n", + " threshold = 0.6\n", + " belief = np.array([.7, .8])\n", + " if policy_threshold(threshold, belief, loc) != \"stay\":\n", + " raise ExcerciseError(\"'policy_threshold' function is not correctly implemented!\")\n", + " print(\"Well Done!\")\n", + "\n", + "\n", + "def test_policy_threshold():\n", + " for loc in [-1, 1]:\n", + " threshold = 0.4\n", + " belief = np.ones(2) * (threshold + 0.1)\n", + " belief[(loc + 1) // 2] = threshold - 0.1\n", + "\n", + " if policy_threshold(threshold, belief, loc) != \"switch\":\n", + " raise ExcerciseError(\"'policy_threshold' function is not correctly implemented!\")\n", + " if policy_threshold(threshold, belief, -1 * loc) != \"stay\":\n", + " raise ExcerciseError(\"'policy_threshold' function is not correctly implemented!\")\n", + "\n", + " print(\"Well Done!\")\n", + "\n", + "\n", + "def test_value_function():\n", + " measurement = np.array([0, 0, 0, 1, 0, 0, 0, 0, 1, 1])\n", + " act = np.array([\"switch\", \"stay\", \"switch\", \"stay\", \"stay\",\n", + " \"stay\", \"switch\", \"switch\", \"stay\", \"stay\"])\n", + " cost_sw = .5\n", + " if not isclose(get_value(measurement, act, cost_sw), .1):\n", + " raise ExcerciseError(\"'value_function' function is not correctly implemented!\")\n", + " print(\"Well Done!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "iny6xbAQZX_a" + }, + "source": [ + "---\n", + "# Section 1: Analyzing the Problem" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "tags": [], + "id": "ei-9lgxtZX_b" + }, + "outputs": [], + "source": [ + "# @title Video 1: Gone fishing\n", + "from ipywidgets import widgets\n", + "from IPython.display import YouTubeVideo\n", + "from IPython.display import IFrame\n", + "from IPython.display import display\n", + "\n", + "\n", + "class PlayVideo(IFrame):\n", + " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", + " self.id = id\n", + " if source == 'Bilibili':\n", + " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", + " elif source == 'Osf':\n", + " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", + " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", + "\n", + "\n", + "def display_videos(video_ids, W=400, H=300, fs=1):\n", + " tab_contents = []\n", + " for i, video_id in enumerate(video_ids):\n", + " out = widgets.Output()\n", + " with out:\n", + " if video_ids[i][0] == 'Youtube':\n", + " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", + " height=H, fs=fs, rel=0)\n", + " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", + " else:\n", + " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", + " height=H, fs=fs, autoplay=False)\n", + " if video_ids[i][0] == 'Bilibili':\n", + " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", + " elif video_ids[i][0] == 'Osf':\n", + " print(f'Video available at https://osf.io/{video.id}')\n", + " display(video)\n", + " tab_contents.append(out)\n", + " return tab_contents\n", + "\n", + "\n", + "video_ids = [('Youtube', '3oIwUFpolVA'), ('Bilibili', 'BV1FL411p7o5')]\n", + "tab_contents = display_videos(video_ids, W=854, H=480)\n", + "tabs = widgets.Tab()\n", + 
"tabs.children = tab_contents\n", + "for i in range(len(tab_contents)):\n", + " tabs.set_title(i, video_ids[i][0])\n", + "display(tabs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "id": "qNeebiBfZX_b" + }, + "outputs": [], + "source": [ + "# @title Submit your feedback\n", + "content_review(f\"{feedback_prefix}_Gone_fishing_Video\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "3_UecR9lZX_c" + }, + "source": [ + "**Problem Setting**\n", + "\n", + "*1. State dynamics:* There are two possible locations for the fish: Left and Right. Secretly, at each time step, the fish may switch sides with a certain probability $p_{\\rm sw} = 1 - p_{\\rm stay}$. This is the binary switching model (*Telegraph process*) that you've seen in the *Linear Systems* day. The fish location, $s^{\\rm fish}$, is latent; you get measurements about it when you try to catch fish, like in the *Hidden Dynamics* day. This gives you a *belief* or posterior probability of the current location given your history of measurements.\n", + "\n", + "*2. Actions:* Unlike past days, you can now **act** on the process! You may stay on your current location (Left or Right), or switch to the other side.\n", + "\n", + "*3. Rewards and Costs:* You get rewarded for each fish you catch (one fish is worth 1 \"point\"). If you're on the same side as the fish, you'll catch more, with probability $q_{\\rm high}$ per discrete time step. Otherwise, you may still catch some fish with probability $q_{\\rm low}$.\n", + "\n", + "You pay a price of $C$ points for switching to the other side. So you better decide wisely!\n", + "\n", + "
\n", + "\n", + "**Maximizing Utility**\n", + "\n", + "To decide \"wisely\" and maximize your total utility (total points), you will follow a **policy** that prescribes what to do in any situation. Here the situation is determined by your location and your **belief** $b_t$ (posterior) about the fish location (remember that the fish location is a latent variable).\n", + "\n", + "In optimal control theory, the belief is the posterior probability over the latent variable given all the past measurements. It can be shown that maximizing the expected utility with respect to this posterior is optimal.\n", + "\n", + "In our problem, the belief can be represented by a single number because the fish are either on the left or the right side. So we write:\n", + "\n", + "\\begin{equation}\n", + "b_t = p(s^{\\rm fish}_t = {\\rm Right}\\ |\\ m_{0:t}, a_{0:t-1})\n", + "\\end{equation}\n", + "\n", + "where $m_{0:t}$ are the measurements and $a_{0:t-1}$ are the actions (stay or switch).\n", + "\n", + "Finally, we will parameterize the policy by a simple threshold on beliefs: when your belief that fish are on your current side falls below a threshold $\\theta$, you switch to the other side.\n", + "\n", + "In this tutorial, you will discover that if you pick the right threshold, this simple policy happens to be optimal!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "1Ald4YcTZX_d" + }, + "source": [ + "## Interactive Demo 1: Examining fish dynamics\n", + "\n", + "In this demo, we will look at the dynamics of the fish moving from side to side while you stay in one place. Play around with the probability `stay_prob` of fish staying in the same location, and observe the resulting dynamics of the fish." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "kDLHGeJnZX_d" + }, + "source": [ + "**Thinking questions:**\n", + "\n", + "* If the fish have already been on one side for a long time, does that change the chances of them switching sides?\n", + "* For what values of p_stay is the fish location most and least predictable?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "id": "qtwaoJPhZX_d" + }, + "outputs": [], + "source": [ + "# @markdown Execute this cell to enable the demo.\n", + "display(HTML(''''''))\n", + "\n", + "@widgets.interact(p_stay=widgets.FloatSlider(.9, description=\"stay_prob\", min=0., max=1., step=0.01))\n", + "\n", + "def update_ex_1(p_stay):\n", + " \"\"\"\n", + " T: Length of timeline\n", + " p_stay: probability that the fish do not swim to the other side at time t\n", + " \"\"\"\n", + " params = [p_stay, _, _, _]\n", + "\n", + " # initial condition: fish [fish_initial] start at the left location (-1)\n", + " binaryHMM_test = binaryHMM(params=params, fish_initial=1)\n", + "\n", + " fish_state = binaryHMM_test.fish_dynamics()\n", + " plot_fish(fish_state)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": {}, + "id": "RvAWfBVGZX_e" + }, + "outputs": [], + "source": [ + "# to_remove explanation\n", + "\n", + "\"\"\"\n", + "In Interactive Demo 1, you should see the school of fish switch sides less often when `stay_prob` is high.\n", + "\n", + "* If the fish have already been on one side for a long time, does that change the chances of them switching sides?\n", + "\n", + " No. 
The telegraph process or binary switching process is Markovian.\n", + " That means that the probabilities of changes depend only on the *current* state.\n", + " States from further in the past do not matter for the chances of switching sides.\n", + " Staying longer in one side is not a statement about the current state, but rather about the past,\n", + " so it's irrelevant for the chances of switching.\n", + "\n", + "\n", + "* For what values of `p_stay` is the fish location most and least predictable?\n", + "\n", + " When `p_stay` is 1 then the fish never move. But when `p_stay` is 0 then the fish *always* move,\n", + " oscillating back and forth deterministically every discrete time step.\n", + "\"\"\";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "id": "6AbJXjIOZX_f" + }, + "outputs": [], + "source": [ + "# @title Submit your feedback\n", + "content_review(f\"{feedback_prefix}_Examining_fish_dynamics_Interactive_Demo_and_Discussion\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "Fe2agTiQZX_h" + }, + "source": [ + "---\n", + "# Section 2: Catching fish" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "tags": [], + "id": "Qu92q0MeZX_h" + }, + "outputs": [], + "source": [ + "# @title Video 2: Catch some fish\n", + "from ipywidgets import widgets\n", + "from IPython.display import YouTubeVideo\n", + "from IPython.display import IFrame\n", + "from IPython.display import display\n", + "\n", + "\n", + "class PlayVideo(IFrame):\n", + " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", + " self.id = id\n", + " if source == 'Bilibili':\n", + " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", + " elif source == 'Osf':\n", + " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", + " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", + "\n", + "\n", + "def display_videos(video_ids, W=400, H=300, fs=1):\n", + " tab_contents = []\n", + " for i, video_id in enumerate(video_ids):\n", + " out = widgets.Output()\n", + " with out:\n", + " if video_ids[i][0] == 'Youtube':\n", + " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", + " height=H, fs=fs, rel=0)\n", + " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", + " else:\n", + " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", + " height=H, fs=fs, autoplay=False)\n", + " if video_ids[i][0] == 'Bilibili':\n", + " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", + " elif video_ids[i][0] == 'Osf':\n", + " print(f'Video available at https://osf.io/{video.id}')\n", + " display(video)\n", + " tab_contents.append(out)\n", + " return tab_contents\n", + "\n", + "\n", + "video_ids = [('Youtube', 'ZjB2_SAY2uE'), ('Bilibili', 'BV1kD4y1m7Lo')]\n", + "tab_contents = display_videos(video_ids, W=854, H=480)\n", + "tabs = widgets.Tab()\n", + "tabs.children = tab_contents\n", + "for i in range(len(tab_contents)):\n", + " tabs.set_title(i, video_ids[i][0])\n", + "display(tabs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "id": "BD4tS9w4ZX_i" + }, + "outputs": [], + "source": [ + "# @title Submit your feedback\n", + "content_review(f\"{feedback_prefix}_Catch_some_fish_Video\")" + ] + }, + { + "cell_type": "markdown", + 
"metadata": { + "execution": {}, + "id": "P2SR4FVGZX_j" + }, + "source": [ + "## Interactive Demo 2: Examining the reward function\n", + "\n", + "In this second demo, you control your location by a button, but we fix the fish's location by setting `stay_prob = 1`. Now that the fish are serenely swimming in one location, we can visually inspect the rewards when you're on the same side as the fish or on the other side.\n", + "\n", + "When you're on the same side as the fish, you should have a higher probability of catching them (but watch out, since technically, you are _allowed_ to adjust the sliders to other conditions!).\n", + "\n", + "Play around with the sliders `high_rew_prob` (high reward probability when you're on the fish's side) and `low_rew_prob` (low reward probability when you're on the other side). The button (same location *vs.* different location) determines which probability describes how often you catch fish." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "jwzOxMiRZX_j" + }, + "source": [ + "**Thinking questions:**\n", + "\n", + "* What happens when the fish and the agent (you!) are on the same or different locations?\n", + "* Where do you catch the most fish?\n", + "* Why isn't `low_rew_prob + high_rew_prob = 1`? What do these probabilities mean in the fishing story?\n", + "* You _can_ move the sliders so `low_rew_prob > high_rew_prob`. This doesn't change the math, but it can change whether the math is a reasonable model of the physical problem. Why?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "id": "W6HJ-SzDZX_j" + }, + "outputs": [], + "source": [ + "# @markdown Execute this cell to enable the demo.\n", + "display(HTML(''''''))\n", + "\n", + "@widgets.interact(locs=widgets.RadioButtons(options=['same location', 'different locations'],\n", + " description='Fish and agent:',\n", + " disabled=False,\n", + " layout={'width': 'max-content'}),\n", + " p_low_rwd=widgets.FloatSlider(.1, description=\"low_rew_prob:\",\n", + " min=0., max=1.),\n", + " p_high_rwd=widgets.FloatSlider(.9, description=\"high_rew_prob:\",\n", + " min=0., max=1.))\n", + "\n", + "def update_ex_2(locs, p_low_rwd, p_high_rwd):\n", + " \"\"\"\n", + " p_stay: probability of fish staying at current side at time t\n", + " p_low_rwd: probability of catching fish when you're NOT on the side where the fish are swimming\n", + " p_high_rwd: probability of catching fish when you're on the side where the fish are swimming\n", + " fish_initial: initial side of fish (-1 left, 1 right)\n", + " agent_initial: initial side of the agent (YOU!) (-1 left, 1 right)\n", + " \"\"\"\n", + " p_stay = 1\n", + " params = [p_stay, p_low_rwd, p_high_rwd, _]\n", + "\n", + " # initial condition for fish [fish_initial] and you [loc_initial]\n", + " if locs == 'same location':\n", + " binaryHMM_test = binaryHMM(params, fish_initial=0, loc_initial=0)\n", + " else:\n", + " binaryHMM_test = binaryHMM(params, fish_initial=1, loc_initial=0)\n", + "\n", + " fish_state, loc, measurement = binaryHMM_test.generate_process_lazy()\n", + " plot_measurement(measurement)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": {}, + "id": "N2fRHCjlZX_k" + }, + "outputs": [], + "source": [ + "# to_remove explanation\n", + "\n", + "\"\"\"\n", + "* What happens when the fish and the agent (you!) 
are on the same or different locations?\n", + " You catch fish with different probabilities.\n", + "\n", + "* Where do you catch the most fish?\n", + " When you're on the same side as the fish -- as long as high_rew_prob > low_rew_prob.\n", + "\n", + "* Why isn't low_rew_prob + high_rew_prob = 1? What do these probabilities mean in the fishing story?\n", + " These are not probabilities of mutually exclusive events. They are chances of one event (you catch fish)\n", + " under two different conditions (you and the school of fish are on the same side or different sides).\n", + "\n", + "* You _can_ move the sliders so `low_rew_prob > high_rew_prob`. This doesn't change the math,\n", + " but it can change whether the math is a reasonable model of the physical problem. Why?\n", + " It would be weird if you caught less fish when you're on the same side as the fish.\n", + " But hey, maybe the fish warn each other when they're in a school together! Then they'd be harder to catch...\n", + "\"\"\";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "id": "puwQAEo-ZX_l" + }, + "outputs": [], + "source": [ + "# @title Submit your feedback\n", + "content_review(f\"{feedback_prefix}_Examining_the_reward_function_Interactive_Demo_and_Discussion\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "UWh8Mh6fZX_l" + }, + "source": [ + "---\n", + "# Section 3: Belief dynamics and belief distributions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "tags": [], + "id": "9UvOcb1MZX_m" + }, + "outputs": [], + "source": [ + "# @title Video 3: Where are the fish?\n", + "from ipywidgets import widgets\n", + "from IPython.display import YouTubeVideo\n", + "from IPython.display import IFrame\n", + "from IPython.display import display\n", + "\n", + "\n", + "class PlayVideo(IFrame):\n", + " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", + " self.id = id\n", + " if source == 'Bilibili':\n", + " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", + " elif source == 'Osf':\n", + " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", + " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", + "\n", + "\n", + "def display_videos(video_ids, W=400, H=300, fs=1):\n", + " tab_contents = []\n", + " for i, video_id in enumerate(video_ids):\n", + " out = widgets.Output()\n", + " with out:\n", + " if video_ids[i][0] == 'Youtube':\n", + " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", + " height=H, fs=fs, rel=0)\n", + " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", + " else:\n", + " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", + " height=H, fs=fs, autoplay=False)\n", + " if video_ids[i][0] == 'Bilibili':\n", + " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", + " elif video_ids[i][0] == 'Osf':\n", + " print(f'Video available at https://osf.io/{video.id}')\n", + " display(video)\n", + " tab_contents.append(out)\n", + " return tab_contents\n", + "\n", + "\n", + "video_ids = [('Youtube', 'rmETVsRFYGk'), ('Bilibili', 'BV19t4y1Q7VH')]\n", + "tab_contents = display_videos(video_ids, W=854, H=480)\n", + "tabs = widgets.Tab()\n", + "tabs.children = tab_contents\n", + "for i in range(len(tab_contents)):\n", + " tabs.set_title(i, video_ids[i][0])\n", + "display(tabs)" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "id": "Zg0JFHLNZX_n" + }, + "outputs": [], + "source": [ + "# @title Submit your feedback\n", + "content_review(f\"{feedback_prefix}_Where_are_the_fish_Video\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "YsD4xGQWZX_o" + }, + "source": [ + "## Interactive Demo 3: Examining the beliefs\n", + "\n", + "Now it's time to get an intuition on how beliefs are calculated. Here, your belief about the fish location is just the posterior probability of that location given your measurements, $p(s_t|m_{0:t})$. Note that this is just what you did in the day covering Hidden Dynamics!\n", + "\n", + "In this exercise, you'll always stay on the LEFT side, but the fish will move around. They'll stay on the same side with probability `stay_prob`. You only get to see fish you catch, not where the school of fish is. You have to use those measurements to infer the location of the school.\n", + "\n", + "In this demo, play around with the sliders `high_rew_prob`, `low_rew_prob`, and `stay_prob`.\n", + "\n", + "**Thinking questions:**\n", + "\n", + "* Manipulate the slider for `stay_prob`. How well does the belief explain the dynamics of the fish as you adjust the probability of the fish staying in one location (`stay_prob`)?\n", + "\n", + "* Explore the extreme case where `high_rew_prob = 1` and `low_rew_prob = 0`. How accurate is the belief as these parameters change?\n", + "\n", + "* Under what conditions is it informative to catch a fish? What about *not* catching a fish?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "tags": [], + "id": "DUUGXDSCZX_o" + }, + "outputs": [], + "source": [ + "# @markdown Execute this cell to enable the demo.\n", + "display(HTML(''''''))\n", + "\n", + "@widgets.interact(p_stay=widgets.FloatSlider(.96, description=\"stay_prob\",\n", + " min=.8, max=1., step=.01),\n", + " p_low_rwd=widgets.FloatSlider(.1, description=\"low_rew_prob\",\n", + " min=0., max=1., step=.01),\n", + " p_high_rwd=widgets.FloatSlider(.3, description=\"high_rew_prob\",\n", + " min=0., max=1., step=.01))\n", + "\n", + "def update_ex_2(p_stay, p_low_rwd, p_high_rwd):\n", + " \"\"\"\n", + " T: Length of timeline\n", + " p_stay: probability of fish staying at current side at time t\n", + " p_high_rwd: probability of catching fish when you're on the side where the fish are swimming\n", + " p_low_rwd: probability of catching fish when you're NOT on the side where the fish are swimming\n", + " fish_initial: initial side of fish (0 left, 1 right)\n", + " agent_initial: initial side of the agent (YOU!) (0 left, 1 right)\n", + " threshold: threshold of belief below which the action is switching\n", + " \"\"\"\n", + " threshold = 0.2\n", + " params = [p_stay, p_low_rwd, p_high_rwd, threshold]\n", + "\n", + " binaryHMM_test = binaryHMM_belief(params, choose_policy=\"lazy\",\n", + " fish_initial=0, loc_initial=0)\n", + "\n", + " belief, loc, act, measurement, fish_state = binaryHMM_test.generate_process()\n", + " plot_dynamics(belief, loc, act, measurement, fish_state,\n", + " binaryHMM_test.choose_policy)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": {}, + "id": "G-hxtlQZZX_p" + }, + "outputs": [], + "source": [ + "# to_remove explanation\n", + "\n", + "\"\"\"\n", + "* Manipulate the slider for `stay_prob`. 
How well does the belief explain the dynamics of the fish as\n", + " you adjust the probability of the fish staying in one location (`stay_prob`)?\n", + "\n", + " The parameter (`stay_prob`) determines fish dynamics. If it is low, the fish are moving fast\n", + " and you don't have much time to collect observations that might decrease your uncertainty about\n", + " the actual location of the school. If it is high, you have more time to integrate evidence\n", + " and the belief explains better the dynamics of the fish.\n", + "\n", + "* Explore the extreme case where `high_rew_prob = 1` and `low_rew_prob = 0`.\n", + " Now play around with these sliders. How accurate is the belief as these parameters change?\n", + "\n", + " In the extreme case, the belief explains the dynamics of the fish perfectly because\n", + " our observations are perfect, i.e., catching a fish indicates with certainty the presence of the school.\n", + " If the chances of catching a fish are very different between the two sides, then you get a lot of information\n", + " for each fish you catch. The belief will then rise and fall steeply with each observation.\n", + " If the two probabilities are similar, then the belief will change slowly even if the fish move quickly.\n", + "\n", + "* Under what conditions is it informative to catch a fish? What about to *not* catch a fish?\n", + "\n", + " The bigger the difference in the two probabilities, the more information you get from measurements.\n", + " If both probabilities are low (and different), then you learn a lot from catching a fish.\n", + " But you still learn a little if you don't catch anything, particularly when catching a fish is probable in one case.\n", + "\"\"\";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "id": "JSDGClShZX_q" + }, + "outputs": [], + "source": [ + "# @title Submit your feedback\n", + "content_review(f\"{feedback_prefix}_Examining_the_beliefs_Interactive_Demo_and_Discussion\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "u44RChOQZX_r" + }, + "source": [ + "---\n", + "# Section 4: Implementing a threshold policy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "tags": [], + "id": "bcfzTNjBZX_r" + }, + "outputs": [], + "source": [ + "# @title Video 4: How should you act?\n", + "from ipywidgets import widgets\n", + "from IPython.display import YouTubeVideo\n", + "from IPython.display import IFrame\n", + "from IPython.display import display\n", + "\n", + "\n", + "class PlayVideo(IFrame):\n", + " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", + " self.id = id\n", + " if source == 'Bilibili':\n", + " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", + " elif source == 'Osf':\n", + " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", + " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", + "\n", + "\n", + "def display_videos(video_ids, W=400, H=300, fs=1):\n", + " tab_contents = []\n", + " for i, video_id in enumerate(video_ids):\n", + " out = widgets.Output()\n", + " with out:\n", + " if video_ids[i][0] == 'Youtube':\n", + " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", + " height=H, fs=fs, rel=0)\n", + " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", + " else:\n", + " video = PlayVideo(id=video_ids[i][1], 
source=video_ids[i][0], width=W,\n", + " height=H, fs=fs, autoplay=False)\n", + " if video_ids[i][0] == 'Bilibili':\n", + " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", + " elif video_ids[i][0] == 'Osf':\n", + " print(f'Video available at https://osf.io/{video.id}')\n", + " display(video)\n", + " tab_contents.append(out)\n", + " return tab_contents\n", + "\n", + "\n", + "video_ids = [('Youtube', 'cTzaQl2Vxn4'), ('Bilibili', 'BV1ri4y137cj')]\n", + "tab_contents = display_videos(video_ids, W=854, H=480)\n", + "tabs = widgets.Tab()\n", + "tabs.children = tab_contents\n", + "for i in range(len(tab_contents)):\n", + " tabs.set_title(i, video_ids[i][0])\n", + "display(tabs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "id": "RPrNgALtZX_r" + }, + "outputs": [], + "source": [ + "# @title Submit your feedback\n", + "content_review(f\"{feedback_prefix}_How_should_you_act_Video\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "DJUFdKV8ZX_r" + }, + "source": [ + "## Coding Exercise 4: dynamics following a **threshold-based** policy\n", + "\n", + "Now we'll switch the policy from the 'lazy' policy used above to a threshold policy that you need to write. You'll change your location whenever your belief is low enough that you're on the best side. You'll update the function `policy_threshold(threshold, belief, loc)`. This policy takes three inputs:\n", + "\n", + "1. The `belief` about the fish state. For convenience, we will represent the belief at time *t* using a 2-dimensional vector. The first element is the belief that the fish are on the left, and the second element is the belief the fish are on the right. At every time step, these elements sum to 1.\n", + "\n", + "2. Your location `loc`, represented as \"Left\" = -1 and \"Right\" = 1.\n", + "\n", + "3. A belief `threshold` that determines when to switch. When your belief that you are on the same side as the fish drops below this threshold, you should move to the other location, and otherwise stay.\n", + "\n", + "Your function should return an action for each time *t*, which takes the value of \"stay\" or \"switch\"." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "both", + "execution": {}, + "tags": [], + "id": "_qb1nUP8ZX_s" + }, + "outputs": [], + "source": [ + "def policy_threshold(threshold, belief, loc):\n", + " \"\"\"\n", + " chooses whether to switch side based on whether the belief\n", + " on the current site drops below the threshold\n", + "\n", + " Args:\n", + " threshold (float): the threshold of belief on the current site,\n", + " when the belief is lower than the threshold, switch side\n", + " belief (numpy array of float, 2-dimensional): the belief on the\n", + " two sites at a certain time\n", + " loc (int) : the location of the agent at a certain time\n", + " -1 for left side, 1 for right side\n", + "\n", + " Returns:\n", + " act (string): \"stay\" or \"switch\"\n", + " \"\"\"\n", + "\n", + " ############################################################################\n", + " ## 1. 
Modify the code below to generate actions (stay or switch)\n", + " ## for current belief and location\n", + " ##\n", + " ## Belief is a 2d vector: first element = Prob(fish on Left | measurements)\n", + " ## second element = Prob(fish on Right | measurements)\n", + " ## Returns \"switch\" if Belief that fish are in your current location < threshold\n", + " ## \"stay\" otherwise\n", + " ##\n", + " ## Hint: use loc value to determine which row of belief you need to use\n", + " ## see the docstring for more information about loc\n", + " ##\n", + " ## 2. After completing the function, comment this line:\n", + " raise NotImplementedError(\"Student exercise: Please complete the code\")\n", + " ############################################################################\n", + " # Write the if statement\n", + " if ...:\n", + " # action below threshold\n", + " act = ...\n", + " else:\n", + " # action above threshold\n", + " act = ...\n", + "\n", + " return act\n", + "\n", + "\n", + "# Next line tests your function\n", + "test_policy_threshold()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "NWXnHA61ZX_t" + }, + "source": [ + "You should see\n", + "\n", + "```Well Done!```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": {}, + "tags": [], + "id": "DA1iGUWMZX_u" + }, + "outputs": [], + "source": [ + "# to_remove solution\n", + "def policy_threshold(threshold, belief, loc):\n", + " \"\"\"\n", + " chooses whether to switch side based on whether the belief\n", + " on the current site drops below the threshold\n", + "\n", + " Args:\n", + " threshold (float): the threshold of belief on the current site,\n", + " when the belief is lower than the threshold, switch side\n", + " belief (numpy array of float, 2-dimensional): the belief on the\n", + " two sites at a certain time\n", + " loc (int) : the location of the agent at a certain time\n", + " -1 for left side, 1 for right side\n", + "\n", + " Returns:\n", + " act (string): \"stay\" or \"switch\"\n", + " \"\"\"\n", + " # Write the if statement\n", + " if belief[(loc + 1) // 2] <= threshold:\n", + " # action below threshold\n", + " act = \"switch\"\n", + " else:\n", + " # action above threshold\n", + " act = \"stay\"\n", + "\n", + " return act\n", + "\n", + "\n", + "# Next line tests your function\n", + "test_policy_threshold()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "id": "U3kp8UdaZX_u" + }, + "outputs": [], + "source": [ + "# @title Submit your feedback\n", + "content_review(f\"{feedback_prefix}_Dynamics_threshold_based_policy_Exercise\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "tags": [], + "id": "59jtWGLoZX_v" + }, + "source": [ + "## Interactive Demo 4: Dynamics with different thresholds\n", + "\n", + "The following demo uses the policy you just built! Play around with the slider and observe the dynamics controlled by your policy.\n", + "\n", + "(The code specifies `stay_prob=0.95`, `high_rew_prob=0.3`, and `low_rew_prob=0.1`. You can change these, but they are reasonable parameters. Note: to see the gradual change with threshold, keep reusing the same random seed; to see different examples, refresh the seed.)\n", + "\n", + "**Thinking questions:**\n", + "* Qualitatively, how well does this policy follow the fish? What does it miss, and why?\n", + "* How can you characterize the fishing strategy if the threshold is very low, or very high?"
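, + "\n", + "Optionally, before running the demo, you can sanity-check your `policy_threshold` implementation with a tiny snippet like the one below (the belief values here are made up purely for illustration):\n", + "\n", + "```python\n", + "belief = np.array([0.45, 0.55])  # slightly more confident the fish are on the Right\n", + "for threshold in (0.05, 0.5, 0.95):\n", + "    print(threshold, policy_threshold(threshold, belief, loc=-1))\n", + "```\n", + "\n", + "For this example belief, only the lowest threshold tells you to stay."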
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "tags": [], + "id": "i4npwrrXZX_v" + }, + "outputs": [], + "source": [ + "# @markdown Execute this cell to enable the demo.\n", + "display(HTML(''''''))\n", + "\n", + "@widgets.interact(threshold=widgets.FloatSlider(.2, description=\"threshold\", min=0., max=1., step=.01),\n", + " new_seed=widgets.ToggleButtons(options=['Reusing', 'Refreshing'],\n", + " description='Random seed:',\n", + " disabled=False,\n", + " button_style='', # 'success', 'info', 'warning', 'danger' or '',\n", + " icons=['check'] * 2\n", + " ))\n", + "def update_ex_4(threshold, new_seed):\n", + " \"\"\"\n", + " p_stay: probability fish stay\n", + " high_rew_p: p(catch fish) when you're on their side\n", + " low_rew_p : p(catch fish) when you're on other side\n", + " threshold: threshold of belief below which switching is taken\n", + "\n", + " \"\"\"\n", + " if new_seed == \"Refreshing\":\n", + " get_randomness(T)\n", + "\n", + " stay_prob=.95\n", + " high_rew_p=.3\n", + " low_rew_p=.1\n", + "\n", + " # parameter order: [p_stay, low_rew_p, high_rew_p, threshold], as in the other demos\n", + " params = [stay_prob, low_rew_p, high_rew_p, threshold]\n", + "\n", + " # initial condition for fish [fish_initial] and you [loc_initial]\n", + " binaryHMM_test = binaryHMM_belief(params, fish_initial=0, loc_initial=0, choose_policy=\"threshold\")\n", + "\n", + " belief, loc, act, measurement, fish_state = binaryHMM_test.generate_process()\n", + " plot_dynamics(belief, loc, act, measurement,\n", + " fish_state, binaryHMM_test.choose_policy)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": {}, + "id": "FZKl18j3ZX_w" + }, + "outputs": [], + "source": [ + "# to_remove explanation\n", + "\n", + "\"\"\"\n", + "* Qualitatively, how well does this policy follow the fish? What does it miss, and why?\n", + "\n", + " You generally follow the fish, but there can be a substantial difference in location.\n", + " The belief is not generally very confident when the probabilities of catching fish on the\n", + " two sides are not very different. Depending on your threshold, you might leave just\n", + " because of a few unlucky time steps when you're actually still on the correct side. 
Or you might stay even\n", + " though you have not caught many fish, in the hopes that the fish haven't moved.\n", + "\n", + "* How can you characterize the fishing strategy if the threshold is very low, or very high?\n", + "\n", + " If the threshold is low, you only switch when you have a very low belief that you're on the right side.\n", + " Then you switch very rarely.\n", + " If the threshold is high, then you switch whenever you're not extremely confident,\n", + " so you change sides all the time.\n", + "\"\"\";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "id": "ly9f4wzAZX_x" + }, + "outputs": [], + "source": [ + "# @title Submit your feedback\n", + "content_review(f\"{feedback_prefix}_Dynamics_with_different_thresholds_Interactive_Demo_and_Discussion\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "r9sU7kn0ZX_x" + }, + "source": [ + "---\n", + "# Section 5: Implementing a value function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "tags": [], + "id": "rXmtS7F8ZX_y" + }, + "outputs": [], + "source": [ + "# @title Video 5: Evaluate policy\n", + "from ipywidgets import widgets\n", + "from IPython.display import YouTubeVideo\n", + "from IPython.display import IFrame\n", + "from IPython.display import display\n", + "\n", + "\n", + "class PlayVideo(IFrame):\n", + " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", + " self.id = id\n", + " if source == 'Bilibili':\n", + " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", + " elif source == 'Osf':\n", + " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", + " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", + "\n", + "\n", + "def display_videos(video_ids, W=400, H=300, fs=1):\n", + " tab_contents = []\n", + " for i, video_id in enumerate(video_ids):\n", + " out = widgets.Output()\n", + " with out:\n", + " if video_ids[i][0] == 'Youtube':\n", + " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", + " height=H, fs=fs, rel=0)\n", + " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", + " else:\n", + " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", + " height=H, fs=fs, autoplay=False)\n", + " if video_ids[i][0] == 'Bilibili':\n", + " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", + " elif video_ids[i][0] == 'Osf':\n", + " print(f'Video available at https://osf.io/{video.id}')\n", + " display(video)\n", + " tab_contents.append(out)\n", + " return tab_contents\n", + "\n", + "\n", + "video_ids = [('Youtube', 'aJhffROC74w'), ('Bilibili', 'BV1TD4y1D7K3')]\n", + "tab_contents = display_videos(video_ids, W=854, H=480)\n", + "tabs = widgets.Tab()\n", + "tabs.children = tab_contents\n", + "for i in range(len(tab_contents)):\n", + " tabs.set_title(i, video_ids[i][0])\n", + "display(tabs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "id": "eSKZm0N0ZX_y" + }, + "outputs": [], + "source": [ + "# @title Submit your feedback\n", + "content_review(f\"{feedback_prefix}_Evaluate_policy_Video\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "3jSU5Z0iZX_z" + }, + "source": [ + "## Coding Exercise 5.1: Implementing a value function\n", + "\n", + "Let's find out how 
good our threshold is. For that, we will calculate a **value function** that quantifies our utility (total points). We will use this value to compare different thresholds; remember, our goal is to maximize the amount of fish we catch while minimizing the effort involved in changing locations.\n", + "\n", + "The value is the total expected utility per unit time.\n", + "\n", + "\\begin{equation}\n", + "V(\\theta) = \\frac{1}{T}\\left( \\sum_t R(s_t) - C(a_t) \\right)\n", + "\\end{equation}\n", + "\n", + "where $R(s_t)$ is the instantaneous reward we get at location $s_t$ and $C(a_t)$ is the cost we paid for the chosen action. Remember, we receive one point for fish caught and pay `cost_sw` points for switching to the other location.\n", + "\n", + "We could take this average mathematically over the probabilities of rewards and actions. However, we can get the same answer by simply averaging the _actual_ rewards and costs over a long time. This is what you are going to do.\n", + "\n", + "\n", + "**Instructions**: Fill in the function `get_value(rewards, actions, cost_sw)`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "both", + "execution": {}, + "tags": [], + "id": "qqMs5PWmZX_z" + }, + "outputs": [], + "source": [ + "def get_value(rewards, actions, cost_sw):\n", + " \"\"\"\n", + " value function\n", + "\n", + " Args:\n", + " rewards (numpy array of length T): whether a reward is obtained (1) or not (0) at each time step\n", + " actions (numpy array of length T): action, \"stay\" or \"switch\", taken at each time step.\n", + " cost_sw (float): the cost of switching to the other location\n", + "\n", + " Returns:\n", + " value (float): expected utility per unit time\n", + " \"\"\"\n", + " actions_int = (actions == \"switch\").astype(int)\n", + "\n", + " ############################################################################\n", + " ## 1. Modify the code below to compute the value function (equation V(theta))\n", + " ##\n", + " ## 2. 
After completing the function, comment this line:\n", + " raise NotImplementedError(\"Student exercise: Please complete the code\")\n", + " ############################################################################\n", + " # Calculate the value function\n", + " value = ...\n", + "\n", + " return value\n", + "\n", + "\n", + "# Test your function\n", + "test_value_function()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "mJcc8w3xZX_3" + }, + "source": [ + "You will see\n", + "\n", + "```Well Done!```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": {}, + "tags": [], + "id": "HfzZWCxqZX_4" + }, + "outputs": [], + "source": [ + "# to_remove solution\n", + "\n", + "def get_value(rewards, actions, cost_sw):\n", + " \"\"\"\n", + " Args:\n", + " rewards (numpy array of length T): whether a reward is obtained (1) or not (0) at each time step\n", + " actions (numpy array of length T): action, \"stay\" or \"switch\", taken at each time step.\n", + " cost_sw (float): the cost of switching to the other location\n", + "\n", + " Returns:\n", + " value (float): expected utility per unit time\n", + " \"\"\"\n", + " actions_int = (actions == \"switch\").astype(int)\n", + "\n", + " # Calculate the value function\n", + " value = np.sum(rewards - actions_int * cost_sw) / len(rewards)\n", + "\n", + " return value\n", + "\n", + "\n", + "# Test your function\n", + "test_value_function()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "id": "yjNP7Qu4ZX_5" + }, + "outputs": [], + "source": [ + "# @title Submit your feedback\n", + "content_review(f\"{feedback_prefix}_Implementing_a_value_function_Exercise\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "qHSutxRkZX_5" + }, + "source": [ + "## Coding Exercise 5.2: Run the policy\n", + "\n", + "Now that you have a mechanism to find out how good a threshold is, we will use a brute force approach to **compute the optimal threshold**: we'll just try all thresholds, simulate the value of each, and pick the best one. Complete the function `get_optimal_threshold(p_stay, low_rew_p, high_rew_p, cost_sw)`. We provide the code to visualize the output of your function. Observe on this plot which threshold has maximal utility.\n", + "\n", + "**Thinking questions:**\n", + "\n", + "* Try a very high switching cost. What is the best threshold? How does that make sense?\n", + "* Try a zero switching cost. What's different?\n", + "* Generally, how does the best threshold change with the switching cost?" 
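, + "\n", + "A rough sense of scale, assuming the default parameters used below (`high_rew_prob = 0.2`, `low_rew_prob = 0.1`): even on the fish's side you expect only about $0.2$ points per time step, and on the wrong side about $0.1$, so the differences in value between nearby thresholds are tiny fractions of a point per step. That is why the simulation uses a long horizon ($T = 10000$): over a short run, sampling noise in the estimated values can easily swamp such small differences."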
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": {}, + "id": "cuBr6rfAZX_6" + }, + "outputs": [], + "source": [ + "def run_policy(threshold, p_stay, low_rew_p, high_rew_p):\n", + " \"\"\"\n", + " This function executes the policy (fully parameterized by the threshold) and\n", + " returns two arrays:\n", + " The sequence of actions taken from time 0 to T\n", + " The sequence of rewards obtained from time 0 to T\n", + " \"\"\"\n", + " params = [p_stay, low_rew_p, high_rew_p, threshold]\n", + " binaryHMM_test = binaryHMM_belief(params, choose_policy=\"threshold\")\n", + " _, _, actions, rewards, _ = binaryHMM_test.generate_process()\n", + "\n", + " return actions, rewards\n", + "\n", + "\n", + "def get_optimal_threshold(p_stay, low_rew_p, high_rew_p, cost_sw):\n", + " \"\"\"\n", + " Args:\n", + " p_stay (float): probability of fish staying in their current location\n", + " low_rew_p (float): probability of catching fish when you and the fish are in different locations.\n", + " high_rew_p (float): probability of catching fish when you and the fish are in the same location.\n", + " cost_sw (float): the cost of switching to the other location\n", + "\n", + " Returns:\n", + " threshold_array (numpy array of float): the candidate thresholds\n", + " value_array (numpy array of float): estimated expected utility per unit time for each candidate threshold\n", + " \"\"\"\n", + " ############################################################################\n", + " ## 1. Modify the code below to find the best threshold using brute force\n", + " ##\n", + " ## 2. After completing the function, comment this line:\n", + " raise NotImplementedError(\"Student exercise: Please complete the code\")\n", + " ############################################################################\n", + " global T\n", + " T = 10000 # Setting a large time horizon\n", + " get_randomness(T)\n", + "\n", + " # Create an array of 20 equally spaced candidate thresholds (min = 0., max=1.):\n", + " threshold_array = ...\n", + "\n", + " # Using the function get_value() that you coded before and\n", + " # the function run_policy() that we provide, compute the value of your\n", + " # candidate thresholds:\n", + "\n", + " # Create an array to store the value of each of your candidates:\n", + " value_array = ...\n", + "\n", + " for i in ...:\n", + " actions, rewards = ...\n", + " value_array[i] = ...\n", + "\n", + " # Return the array of candidate thresholds and their respective values\n", + "\n", + " return threshold_array, value_array\n", + "\n", + "\n", + "# Feel free to change these parameters\n", + "stay_prob = .9\n", + "low_rew_prob = 0.1\n", + "high_rew_prob = 0.2\n", + "cost_sw = .1\n", + "\n", + "# Visually determine the threshold that obtains the maximum utility\n", + "threshold_array, value_array = get_optimal_threshold(stay_prob, low_rew_prob, high_rew_prob, cost_sw)\n", + "plot_value_threshold(threshold_array, value_array)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": {}, + "tags": [], + "id": "Tsjo2cfiZX_7" + }, + "outputs": [], + "source": [ + "# to_remove solution\n", + "\n", + "def run_policy(threshold, p_stay, low_rew_p, high_rew_p):\n", + " \"\"\"\n", + " This function executes the policy (fully parameterized by the threshold) and\n", + " returns two arrays:\n", + " The sequence of actions taken from time 0 to T\n", + " The sequence of rewards obtained from time 0 to T\n", + " \"\"\"\n", + " params = [p_stay, low_rew_p, high_rew_p, threshold]\n", + " binaryHMM_test = binaryHMM_belief(params, choose_policy=\"threshold\")\n", + " _, _, actions, rewards, _ = 
binaryHMM_test.generate_process()\n", + " return actions, rewards\n", + "\n", + "\n", + "def get_optimal_threshold(p_stay, low_rew_p, high_rew_p, cost_sw):\n", + " \"\"\"\n", + " Args:\n", + " p_stay (float): probability of fish staying in their current location\n", + " low_rew_p (float): probability of catching fish when you and the fish are in different locations.\n", + " high_rew_p (float): probability of catching fish when you and the fish are in the same location.\n", + " cost_sw (float): the cost of switching to the other location\n", + "\n", + " Returns:\n", + " threshold_array (numpy array of float): the candidate thresholds\n", + " value_array (numpy array of float): estimated expected utility per unit time for each candidate threshold\n", + " \"\"\"\n", + " global T\n", + " T = 10000 # Setting a large time horizon\n", + " get_randomness(T)\n", + "\n", + " # Create an array of 20 equally spaced candidate thresholds (min = 0., max=1.):\n", + " threshold_array = np.linspace(0., 1., 20)\n", + "\n", + " # Using the function get_value() that you coded before and\n", + " # the function run_policy() that we provide, compute the value of your\n", + " # candidate thresholds:\n", + "\n", + " # Create an array to store the value of each of your candidates:\n", + " value_array = np.zeros(len(threshold_array))\n", + "\n", + " for i in range(len(threshold_array)):\n", + " actions, rewards = run_policy(threshold_array[i], p_stay, low_rew_p, high_rew_p)\n", + " value_array[i] = get_value(rewards, actions, cost_sw)\n", + "\n", + " # Return the array of candidate thresholds and their respective values\n", + "\n", + " return threshold_array, value_array\n", + "\n", + "\n", + "# Feel free to change these parameters\n", + "stay_prob = .9\n", + "low_rew_prob = 0.1\n", + "high_rew_prob = 0.2\n", + "cost_sw = .1\n", + "\n", + "# Visually determine the threshold that obtains the maximum utility\n", + "threshold_array, value_array = get_optimal_threshold(stay_prob, low_rew_prob, high_rew_prob, cost_sw)\n", + "with plt.xkcd():\n", + " plot_value_threshold(threshold_array, value_array)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": {}, + "id": "r2ybbRl9ZX_7" + }, + "outputs": [], + "source": [ + "# to_remove explanation\n", + "\n", + "\"\"\"\n", + "* Try a very high switching cost. What is the best threshold? How does that make sense?\n", + "\n", + " You should see that there is a best threshold:\n", + " If it is too small, then you never move, missing opportunities to follow the fish.\n", + " If it is too large, then you move too often and pay a large cost for the switching.\n", + " When the switching cost is extremely high, it's never worth moving, so the optimal threshold is at zero.\n", + "\n", + "* Try a zero switching cost. 
What's different?\n", + "\n", + " When the switching cost is zero, it's not best to always switch, but rather to follow\n", + " the optimal inference about the fish location.\n", + "\n", + "* Generally, how does the best threshold change with the switching cost?\n", + "\n", + " As the switching cost rises, the threshold should fall because\n", + " you have even more incentive to avoid switches.\n", + "\"\"\";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "id": "x5S_0gYzZX_8" + }, + "outputs": [], + "source": [ + "# @title Submit your feedback\n", + "content_review(f\"{feedback_prefix}_Run_the_policy_Exercise_and_Discussion\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "OXh8EVeaZX_8" + }, + "source": [ + "---\n", + "# Summary\n", + "\n", + "In this tutorial, you combined Hidden Markov Models with actions to solve an optimal control problem! This showed us the core formalism of the *Partially Observable Markov Decision Process* (POMDP).\n", + "\n", + "Using observations (fish caught), you built beliefs (posterior distributions) that helped you estimate where the fish were. Next, you computed a value function that helped you evaluate the quality of different policies. Finally, using a brute force approach, you discovered an optimal policy that allowed you to catch as many fish as possible while minimizing the effort of switching your location.\n", + "\n", + "The following tutorial will use continuous states and actions instead of the binary ones we used here. In continuous control, we can still use a POMDP, but we'll focus on control in the *fully* observed case, a Markov Decision Process (MDP), since the policy is still illuminating." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "tags": [], + "id": "csEXQLRfZX_9" + }, + "outputs": [], + "source": [ + "# @title Video 6: From discrete to continuous control\n", + "from ipywidgets import widgets\n", + "from IPython.display import YouTubeVideo\n", + "from IPython.display import IFrame\n", + "from IPython.display import display\n", + "\n", + "\n", + "class PlayVideo(IFrame):\n", + " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", + " self.id = id\n", + " if source == 'Bilibili':\n", + " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", + " elif source == 'Osf':\n", + " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", + " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", + "\n", + "\n", + "def display_videos(video_ids, W=400, H=300, fs=1):\n", + " tab_contents = []\n", + " for i, video_id in enumerate(video_ids):\n", + " out = widgets.Output()\n", + " with out:\n", + " if video_ids[i][0] == 'Youtube':\n", + " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", + " height=H, fs=fs, rel=0)\n", + " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", + " else:\n", + " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", + " height=H, fs=fs, autoplay=False)\n", + " if video_ids[i][0] == 'Bilibili':\n", + " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", + " elif video_ids[i][0] == 'Osf':\n", + " print(f'Video available at https://osf.io/{video.id}')\n", + " display(video)\n", + " tab_contents.append(out)\n", + " return tab_contents\n", + "\n", + "\n", + "video_ids = 
[('Youtube', 'ndCMgdjv9Gg'), ('Bilibili', 'BV1JA411v7jy')]\n", + "tab_contents = display_videos(video_ids, W=854, H=480)\n", + "tabs = widgets.Tab()\n", + "tabs.children = tab_contents\n", + "for i in range(len(tab_contents)):\n", + " tabs.set_title(i, video_ids[i][0])\n", + "display(tabs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "id": "CNAnzg-yZX_9" + }, + "outputs": [], + "source": [ + "# @title Submit your feedback\n", + "content_review(f\"{feedback_prefix}_From_discrete_to_continuous_control_Video\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "UbY5Z1tWZX_-" + }, + "source": [ + "---\n", + "# Bonus" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "HCuV34bEZX_-" + }, + "source": [ + "## Bonus Section 1: How does the optimal policy depend on the task?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "tags": [], + "id": "gS7ybHcoZX__" + }, + "outputs": [], + "source": [ + "# @title Video 7: Sensitivity of optimal policy\n", + "from ipywidgets import widgets\n", + "from IPython.display import YouTubeVideo\n", + "from IPython.display import IFrame\n", + "from IPython.display import display\n", + "\n", + "\n", + "class PlayVideo(IFrame):\n", + " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", + " self.id = id\n", + " if source == 'Bilibili':\n", + " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", + " elif source == 'Osf':\n", + " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", + " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", + "\n", + "\n", + "def display_videos(video_ids, W=400, H=300, fs=1):\n", + " tab_contents = []\n", + " for i, video_id in enumerate(video_ids):\n", + " out = widgets.Output()\n", + " with out:\n", + " if video_ids[i][0] == 'Youtube':\n", + " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", + " height=H, fs=fs, rel=0)\n", + " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", + " else:\n", + " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", + " height=H, fs=fs, autoplay=False)\n", + " if video_ids[i][0] == 'Bilibili':\n", + " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", + " elif video_ids[i][0] == 'Osf':\n", + " print(f'Video available at https://osf.io/{video.id}')\n", + " display(video)\n", + " tab_contents.append(out)\n", + " return tab_contents\n", + "\n", + "\n", + "video_ids = [('Youtube', 'wd8IVsKoEfA'), ('Bilibili', 'BV1QK4y1e7N9')]\n", + "tab_contents = display_videos(video_ids, W=854, H=480)\n", + "tabs = widgets.Tab()\n", + "tabs.children = tab_contents\n", + "for i in range(len(tab_contents)):\n", + " tabs.set_title(i, video_ids[i][0])\n", + "display(tabs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "id": "W9I0XWlKZX__" + }, + "outputs": [], + "source": [ + "# @title Submit your feedback\n", + "content_review(f\"{feedback_prefix}_Sensitivity_of_optimal_policy_Bonus_Video\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": {}, + "id": "UoULh0WpZYAA" + }, + "source": [ + "### Bonus Interactive Demo 1: Explore task parameters\n", + "\n", + "In this demo, you can play with various task parameters. 
Observe how the optimal threshold changes when you adjust:\n", + "* The switching cost\n", + "* The fish dynamics (`p(stay)`)\n", + "* The probability of catching fish on each side, `p(high_rwd)` and `p(low_rwd)`\n", + "\n", + "Can you explain why the optimal threshold changes with these parameters:\n", + "\n", + "* lower/higher switching cost?\n", + "* faster fish dynamics (_i.e._, low `p_stay`)?\n", + "* rarer fish caught (_i.e._, low `p(high_rwd)` and low `p(low_rwd)`)?\n", + "\n", + "Note that it may require long simulations to see subtle changes in values of different policies, so look for coarse trends first." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "tags": [], + "id": "VzCv_RxJZYAB" + }, + "outputs": [], + "source": [ + "# @markdown Make sure you execute this cell to enable the widget!\n", + "display(HTML(''''''))\n", + "\n", + "@widgets.interact(p_stay=widgets.FloatSlider(.95, description=\"p(stay)\",\n", + " min=0., max=1.),\n", + " p_high_rwd=widgets.FloatSlider(.4, description=\"p(high_rwd)\",\n", + " min=0., max=1.),\n", + " p_low_rwd=widgets.FloatSlider(.1, description=\"p(low_rwd)\",\n", + " min=0., max=1.),\n", + " cost_sw=widgets.FloatSlider(.2, description=\"switching cost\",\n", + " min=0., max=1., step=.01))\n", + "\n", + "\n", + "def update_ex_bonus(p_stay, p_high_rwd, p_low_rwd, cost_sw):\n", + " \"\"\"\n", + " p_stay: probability fish stay\n", + " high_rew_p: p(catch fish) when you're on their side\n", + " low_rew_p : p(catch fish) when you're on other side\n", + " cost_sw: switching cost\n", + " \"\"\"\n", + "\n", + " threshold_array, value_array = get_optimal_threshold(p_stay,\n", + " p_low_rwd,\n", + " p_high_rwd,\n", + " cost_sw)\n", + " globals()['cost_sw'] = cost_sw\n", + " plot_value_threshold(threshold_array, value_array)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": {}, + "tags": [], + "id": "HkFcxBmUZYAB" + }, + "outputs": [], + "source": [ + "# to_remove explanation\n", + "\n", + "\"\"\"\n", + "* lower/higher switching cost?\n", + "\n", + " High switching cost means that you should be more certain that the other side\n", + " is better before committing to change sides. This means that beliefs must fall\n", + " below a threshold before acting. Conversely, a lower switching cost allows you\n", + " more flexibility to switch at less stringent thresholds. In the limit of _zero_\n", + " switching cost, you should always switch whenever you think the other side is\n", + " better, even if it's just 51%, and even if you switch every time step.\n", + "\n", + "* faster fish dynamics (i.e., low p_stay)?\n", + "\n", + " Faster fish dynamics (lower `p_stay`) also promote faster switching because\n", + " you cannot plan as far into the future. In that case you must base your decisions\n", + " on more immediate evidence, but since you still pay the same switching cost that\n", + " cost is a higher fraction of your predictable rewards. 
Thus, you should be more\n", + " conservative and switch only when you are more confident.\n", + "\n", + "* rarer fish caught (i.e., low p(high_rwd) and low p(low_rwd))?\n", + "\n", + " When `high_rew_p` and/or `low_rew_p` decreases, your predictions become less reliable,\n", + " again encouraging you to require more confidence before committing to a switch.\n", + "\"\"\";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "execution": {}, + "id": "Axd44zMqZYAB" + }, + "outputs": [], + "source": [ + "# @title Submit your feedback\n", + "content_review(f\"{feedback_prefix}_Explore_task_parameters_Bonus_Interactive_Demo_and_Discussion\")" + ] + } + ], + "metadata": { + "colab": { + "name": "W3D3_Tutorial1", + "provenance": [], + "toc_visible": true, + "include_colab_link": true + }, + "kernel": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.17" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [ + "# Tutorial 1- Optimal Control for Discrete State\n", + "\n", + "Please execute the cell below to initialize the notebook environment.\n", + "\n", + "import numpy as np # import numpy\n", + "import scipy # import scipy\n", + "import random # import basic random number generator functions\n", + "from scipy.linalg import inv\n", + "\n", + "import matplotlib.pyplot as plt # import matplotlib\n", + "\n", + "---\n", + "\n", + "## Tutorial objectives\n", + "\n", + "In this tutorial, we will implement a binary HMM task.\n", + "\n", + "---\n", + "\n", + "## Task Description\n", + "\n", + "There are two boxes. The box can be in a high-rewarding state ($s=1$), which means that a reward will be delivered with high probabilty $q_{high}$; or the box can be in low-rewarding state ($s=0$), then the reward will be delivered with low probabilty $q_{low}$.\n", + "\n", + "The states of the two boxes are latent. At a certain time, only one of the sites can be in high-rewarding state, and the other box will be the opposite. The states of the two boxes switches with a certain probability $p_{sw}$. \n", + "\n", + "![alt text](switching.png \"Title\")\n", + "\n", + "\n", + "The agent may stay at one site for sometime. As the agent accumulates evidence about the state of the box on that site, it may choose to stay or switch to the other side with a switching cost $c$. The agent keeps beliefs on the states of the boxes, which is the posterior probability of the state being high-rewarding given all the past observations. Consider the belief on the state of the left box, we have \n", + "\n", + "$$b(s_t) = p(s_t = 1 | o_{0:t}, l_{0:t}, a_{0:t-1})$$\n", + "\n", + "where $o$ is the observation that whether a reward is obtained, $l$ is the location of the agent, $a$ is the action of staying ($a=0$) or switching($a=1$). \n", + "\n", + "Since the two boxes are completely anti-correlated, i.e. only one of the boxes is high-rewarded at a certain time, the the other one is low-rewarded, the belief on the two boxes should sum up to be 1. As a result, we only need to track the belief on one of the boxes. 
\n", + "\n", + "The policy of the agent depends on a threshold on beliefs. When the belief on the box on the other side gets higher than the threshold $\\theta$, the agent will switch to the other side. In other words, the agent will choose to switch when it is confident enough that the other side is high rewarding. \n", + "\n", + "The value function can be defined as the reward rate during a single trial.\n", + "\n", + "$$v(\\theta) = \\sum_t r_t - c\\cdot 1_{a_t = 1}$$ \n", + "\n", + "we would like to see the relation between the threshold and the value function. \n", + "\n", + "### Exercise 1: Control for binary HMM\n", + "In this excercise, we generate the dynamics for the binary HMM task as described above. \n", + "\n", + "# This function is the policy based on threshold\n", + "\n", + "def policy(threshold, bel, loc):\n", + " if loc == 0:\n", + " if bel[1] >= threshold:\n", + " act = 1\n", + " else:\n", + " act = 0\n", + " else: # loc = 1\n", + " if bel[0] >= threshold:\n", + " act = 1\n", + " else:\n", + " act = 0\n", + "\n", + " return act\n", + "\n", + "# This function generates the dynamics\n", + "\n", + "def generateProcess(params):\n", + "\n", + " T, p_sw, q_high, q_low, cost_sw, threshold = params\n", + " world_state = np.zeros((2, T), int) # value :1: good box; 0: bad box\n", + " loc = np.zeros(T, int) # 0: left box 1: right box\n", + " obs = np.zeros(T, int) # 0: did not get food 1: get food\n", + " act = np.zeros(T, int) # 0 : stay 1: switch and get food from the other side\n", + " bel = np.zeros((2, T), float) # the probability that the left box has food,\n", + " # then the probability that the second box has food is 1-b\n", + "\n", + "\n", + " p = np.array([1 - p_sw, p_sw]) # transition probability to good state\n", + " q = np.array([q_low, q_high])\n", + " q_mat = np.array([[1 - q_high, q_high], [1 - q_low, q_low]])\n", + "\n", + " for t in range(T):\n", + " if t == 0:\n", + " world_state[0, t] = 1 # good box\n", + " world_state[1, t] = 1 - world_state[0, t]\n", + " loc[t] = 0\n", + " obs[t] = 0\n", + " bel_0 = np.random.random(1)[0]\n", + " bel[:, t] = np.array([bel_0, 1-bel_0])\n", + "\n", + " act[t] = policy(threshold, bel[:, t], loc[t])\n", + "\n", + " else:\n", + " world_state[0, t] = np.random.binomial(1, p[world_state[0, t - 1]])\n", + " world_state[1, t] = 1 - world_state[0, t]\n", + "\n", + " if act[t - 1] == 0:\n", + " loc[t] = loc[t - 1]\n", + " else: # after weitching, open the new box, deplete if any; then wait a usualy time\n", + " loc[t] = 1 - loc[t - 1]\n", + "\n", + " # new observation\n", + " obs[t] = np.random.binomial(1, q[world_state[loc[t], t-1]])\n", + "\n", + " # update belief posterior, p(s[t] | obs(0-t), act(0-t-1))\n", + " bel_0 = (bel[0, t-1] * p_sw + bel[1, t-1] * (1 - p_sw)) * q_mat[loc[t], obs[t]]\n", + " bel_1 = (bel[1, t - 1] * p_sw + bel[0, t - 1] * (1 - p_sw)) * q_mat[1-loc[t], obs[t]]\n", + "\n", + " bel[0, t] = bel_0 / (bel_0 + bel_1)\n", + " bel[1, t] = bel_1 / (bel_0 + bel_1)\n", + "\n", + " act[t] = policy(threshold, bel[:, t], loc[t])\n", + "\n", + " return bel, obs, act, world_state, loc\n", + "\n", + "# value function \n", + "def value_function(obs, act, cost_sw, discount):\n", + " T = len(obs)\n", + " discount_time = np.array([discount ** t for t in range(T)])\n", + "\n", + " #value = (np.sum(obs) - np.sum(act) * cost_sw) / T\n", + " value = (np.sum(np.multiply(obs, discount_time)) - np.sum(np.multiply(act, discount_time)) * cost_sw) / T\n", + "\n", + " return value\n", + "\n", + "def switch_int(obs, act):\n", + " sw_t = 
+ "    sw_int = sw_t[1:] - sw_t[:-1]\n",
+ "\n",
+ "    return sw_int\n",
+ "\n",
+ "# Plotting functions\n",
+ "def plot_dynamics(bel, obs, act, world_state, loc):\n",
+ "    T = len(obs)\n",
+ "\n",
+ "    showlen = min(T, 100)\n",
+ "    startT = 0\n",
+ "\n",
+ "    endT = startT + showlen\n",
+ "    showT = range(startT, endT)\n",
+ "    time_range = np.linspace(0, showlen - 1)\n",
+ "\n",
+ "    fig_posterior, [ax0, ax1, ax_loc, ax2, ax3] = plt.subplots(5, 1, figsize=(15, 10))\n",
+ "\n",
+ "    ax0.plot(world_state[0, showT], color='dodgerblue', markersize=10, linewidth=3.0)\n",
+ "    ax0.set_ylabel('Left box', rotation=360, fontsize=22)\n",
+ "    ax0.yaxis.set_label_coords(-0.1, 0.25)\n",
+ "    ax0.set_xticks(np.arange(0, showlen, 10))\n",
+ "    ax0.tick_params(axis='both', which='major', labelsize=18)\n",
+ "    ax0.set_xlim([0, showlen])\n",
+ "\n",
+ "    ax3.plot(world_state[1, showT], color='dodgerblue', markersize=10, linewidth=3.0)\n",
+ "    ax3.set_ylabel('Right box', rotation=360, fontsize=22)\n",
+ "    ax3.yaxis.set_label_coords(-0.1, 0.25)\n",
+ "    ax3.tick_params(axis='both', which='major', labelsize=18)\n",
+ "    ax3.set_xlim([0, showlen])\n",
+ "    ax3.set_xticks(np.arange(0, showlen, 10))\n",
+ "\n",
+ "    ax1.plot(bel[0, showT], color='dodgerblue', markersize=10, linewidth=3.0)\n",
+ "    ax1.plot(time_range, threshold * np.ones(time_range.shape), 'r--')\n",
+ "    ax1.yaxis.set_label_coords(-0.1, 0.25)\n",
+ "    ax1.set_ylabel('Belief on \\n left box', rotation=360, fontsize=22)\n",
+ "    ax1.tick_params(axis='both', which='major', labelsize=18)\n",
+ "    ax1.set_xlim([0, showlen])\n",
+ "    ax1.set_ylim([0, 1])\n",
+ "    ax1.set_xticks(np.arange(0, showlen, 10))\n",
+ "\n",
+ "    ax_loc.plot(1 - loc[showT], 'g.-', markersize=12, linewidth=5, label='location')\n",
+ "    ax_loc.plot((act[showT] - .1) * .8, 'v', markersize=10, label='action')\n",
+ "    ax_loc.plot(obs[showT] * .5, '*', markersize=5, label='reward')\n",
+ "    ax_loc.legend(loc=\"upper right\")\n",
+ "    ax_loc.set_xlim([0, showlen])\n",
+ "    ax_loc.set_ylim([0, 1])\n",
+ "    # ax_loc.set_yticks([])\n",
+ "    ax_loc.set_xticks([0, showlen])\n",
+ "    ax_loc.tick_params(axis='both', which='major', labelsize=18)\n",
+ "    labels = [item.get_text() for item in ax_loc.get_yticklabels()]\n",
+ "    labels[0] = 'Right'\n",
+ "    labels[-1] = 'Left'\n",
+ "    ax_loc.set_yticklabels(labels)\n",
+ "\n",
+ "    ax2.plot(bel[1, showT], color='dodgerblue', markersize=10, linewidth=3.0)\n",
+ "    ax2.plot(time_range, threshold * np.ones(time_range.shape), 'r--')\n",
+ "    ax2.set_xlabel('time', fontsize=18)\n",
+ "    ax2.yaxis.set_label_coords(-0.1, 0.25)\n",
+ "    ax2.set_ylabel('Belief on \\n right box', rotation=360, fontsize=22)\n",
+ "    ax2.tick_params(axis='both', which='major', labelsize=18)\n",
+ "    ax2.set_xlim([0, showlen])\n",
+ "    ax2.set_ylim([0, 1])\n",
+ "    ax2.set_xticks(np.arange(0, showlen, 10))\n",
+ "\n",
+ "    plt.show()\n",
+ "\n",
+ "def plot_val_thre(threshold_array, value_array):\n",
+ "    fig_, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
+ "    ax.plot(threshold_array, value_array)\n",
+ "    ax.set_ylim([np.min(value_array), np.max(value_array)])\n",
+ "    ax.set_title('threshold vs value')\n",
+ "    ax.set_xlabel('threshold')\n",
+ "    ax.set_ylabel('value')\n",
+ "    plt.show()\n",
+ "\n",
+ "T = 5000          # number of time steps to simulate\n",
+ "p_sw = .95        # probability that the box states stay the same between time steps\n",
+ "q_high = .7       # reward probability of the high-rewarding box\n",
+ "q_low = 0         # reward probability of the low-rewarding box (try .2)\n",
+ "cost_sw = 1       # switching cost\n",
+ "threshold = .8    # threshold of belief for switching\n",
+ "discount = 1      # discount factor\n",
+ "\n",
+ "step = 0.1\n",
+ "threshold_array = np.arange(0, 1 + step, step)\n",
step)\n", + "value_array = np.zeros(threshold_array.shape)\n", + "\n", + "for i in range(len(threshold_array)):\n", + " threshold = threshold_array[i]\n", + " params = [T, p_sw, q_high, q_low, cost_sw, threshold]\n", + " bel, obs, act, world_state, loc = generateProcess(params)\n", + " value_array[i] = value_function(obs, act, cost_sw, discount)\n", + " sw_int = switch_int(obs, act)\n", + " #print(np.mean(sw_int))\n", + "\n", + " if threshold == 0.8:\n", + " plot_dynamics(bel, obs, act, world_state, loc)\n", + "\n", + "plot_val_thre(threshold_array, value_array)\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file