# %% Imports and a listing of the Kaggle input files
import os

import numpy as np   # array maths; holds the Q-table further down
import pandas as pd  # used at the end to write and re-read the CSV

# Walk the Kaggle input directory and echo the full path of every file found.
for folder, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(folder, name))

# %% Create the Taxi environment and show its starting position
import gym

# `.env` takes the inner environment object out of whatever gym.make returns.
env = gym.make("Taxi-v3").env

env.reset()   # put the environment into a new, random state
env.render()  # draw the grid for that state
# %% Report the action and observation spaces
for template, space in (
    ("Action Space {}", env.action_space),
    ("State Space {}", env.observation_space),
):
    print(template.format(space))

# %% Encode one hand-picked configuration as a state id and draw it
# Arguments, in order: taxi row, taxi column, passenger index, destination index.
taxi_row, taxi_col, passenger_idx, destination_idx = 3, 1, 2, 0
state = env.encode(taxi_row, taxi_col, passenger_idx, destination_idx)
print("State:", state)

# Point the environment at that state so render() draws it.
env.s = state
env.render()
# %% Transition/reward table for state 328 (the state encoded and rendered above)
# The result is a dict keyed by action. Each value is a list of 4-tuples that
# appear to be (probability, next_state, reward, done) -- TODO confirm the
# field order against the gym documentation.
env.P[328]

# %% Q-table: one row per state, one column per action, initialised to zero
import numpy as np

q_table = np.zeros((env.observation_space.n, env.action_space.n))
q_table

# %% Hyperparameters
alpha = 0.1    # learning rate: scales each temporal-difference correction
gamma = 0.6    # discount factor applied to the best next-state value
epsilon = 0.1  # probability of taking a random (exploratory) action

# The name was originally misspelled; the alias is kept because the training
# cell reads the exploration rate through `epsiolon`.
epsiolon = epsilon
# %% Train the Q-table with epsilon-greedy tabular Q-learning
import random

max_epochs = 1000000  # upper bound on the number of episodes
max_iterations = -1   # upper bound on total steps across all episodes; -1 disables it

iteration = 0  # total environment steps so far, summed over every episode
for epoch in range(max_epochs):
    done = False
    new_state = env.reset()

    # One episode: run until the environment reports `done` or the global
    # step budget (if enabled) is used up.
    while not done and (iteration < max_iterations or max_iterations == -1):
        iteration += 1
        old_state = new_state

        # Epsilon-greedy selection. The exploration rate is read from the
        # name `epsiolon` (sic), as defined in the hyperparameter cell.
        if random.uniform(0, 1) < epsiolon:
            action = env.action_space.sample()
        else:
            action = np.argmax(q_table[old_state])

        # step() already advances the environment's own state, so the former
        # manual `env.s = new_state` assignment was redundant and is dropped.
        new_state, reward, done, info = env.step(action)

        # A terminal transition has no future return, so the target must not
        # bootstrap from the next state's values once `done` is set.
        future_value = 0.0 if done else np.max(q_table[new_state])
        td_target = reward + gamma * future_value
        q_table[old_state, action] += alpha * (td_target - q_table[old_state, action])

    # Leave the episode loop as well once the global step budget is exhausted.
    if iteration >= max_iterations and max_iterations != -1:
        break
print("Epoch: {} Iterations: {}\r".format(epoch, iteration))

# %% Show the learned Q-table
q_table
# %% Example game: follow the greedy policy from a random start and render each step
new_state = env.reset()

done = False
max_iterations = 100  # safety cap on steps so a poor policy cannot loop forever; -1 disables it
iteration = 0         # steps taken in this rollout

# `<` (not `<=`) so at most `max_iterations` steps are taken, matching the
# convention used by the training loop.
while not done and (iteration < max_iterations or max_iterations == -1):
    old_state = new_state
    # Always exploit: pick the highest-valued action for the current state.
    action = np.argmax(q_table[old_state])
    # step() advances the environment's own state, so no manual `env.s`
    # assignment is needed before rendering.
    new_state, reward, done, info = env.step(action)
    env.render()
    iteration += 1
"end_time": "2020-10-23T07:10:37.659523", "exception": false, "start_time": "2020-10-23T07:10:37.245616", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
IdValue
010.00
120.00
230.00
340.00
450.00
.........
299529962.36
299629975.60
2997299811.00
29982999-3.40
29993000-3.40
\n", "

3000 rows × 2 columns

\n", "
# %% Write the Q-table to q_table.csv and verify the file that was written
# ravel() flattens the table row by row, so each state's action values are
# written consecutively. Ids start at 1 and the CSV header is "Id,Value".
df = pd.DataFrame(
    {"Value": q_table.ravel()},
    index=pd.RangeIndex(start=1, stop=q_table.size + 1, name="Id"),
)
df.to_csv("q_table.csv")

# Load the file back and actually check its format instead of only eyeballing it.
submission_check = pd.read_csv("q_table.csv")
assert list(submission_check.columns) == ["Id", "Value"], "unexpected CSV header"
assert len(submission_check) == q_table.size, "row count does not match Q-table size"
submission_check