Files
UU_ai_course/assignment_5.ipynb

712 lines
20 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
"_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
"execution": {
"iopub.execute_input": "2020-10-23T07:01:24.034522Z",
"iopub.status.busy": "2020-10-23T07:01:24.033481Z",
"iopub.status.idle": "2020-10-23T07:01:24.040184Z",
"shell.execute_reply": "2020-10-23T07:01:24.039289Z"
},
"papermill": {
"duration": 0.026939,
"end_time": "2020-10-23T07:01:24.040385",
"exception": false,
"start_time": "2020-10-23T07:01:24.013446",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/input/rl-assignment/samplesubmission.csv\n"
]
}
],
"source": [
"# This Python 3 environment comes with many helpful analytics libraries installed\n",
"# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n",
"# For example, here's several helpful packages to load\n",
"\n",
"import numpy as np # linear algebra\n",
"import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
"\n",
"# Input data files are available in the read-only \"../input/\" directory\n",
"# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n",
"\n",
"import os\n",
"for dirname, _, filenames in os.walk('/kaggle/input'):\n",
" for filename in filenames:\n",
" print(os.path.join(dirname, filename))\n",
"\n",
"# You can write up to 5GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n",
"# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
"_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a",
"execution": {
"iopub.execute_input": "2020-10-23T07:01:24.069878Z",
"iopub.status.busy": "2020-10-23T07:01:24.068972Z",
"iopub.status.idle": "2020-10-23T07:01:25.495962Z",
"shell.execute_reply": "2020-10-23T07:01:25.494876Z"
},
"papermill": {
"duration": 1.44331,
"end_time": "2020-10-23T07:01:25.496099",
"exception": false,
"start_time": "2020-10-23T07:01:24.052789",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+---------+\n",
"|\u001b[34;1mR\u001b[0m: |\u001b[43m \u001b[0m: :\u001b[35mG\u001b[0m|\n",
"| : | : : |\n",
"| : : : : |\n",
"| | : | : |\n",
"|Y| : |B: |\n",
"+---------+\n",
"\n"
]
}
],
"source": [
"import gym\n",
"\n",
"env = gym.make(\"Taxi-v3\").env\n",
"\n",
"env.reset() # reset environment to a new, random state\n",
"env.render()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2020-10-23T07:01:25.526480Z",
"iopub.status.busy": "2020-10-23T07:01:25.525487Z",
"iopub.status.idle": "2020-10-23T07:01:25.530094Z",
"shell.execute_reply": "2020-10-23T07:01:25.529110Z"
},
"papermill": {
"duration": 0.022214,
"end_time": "2020-10-23T07:01:25.530278",
"exception": false,
"start_time": "2020-10-23T07:01:25.508064",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Action Space Discrete(6)\n",
"State Space Discrete(500)\n"
]
}
],
"source": [
"# action space\n",
"print(\"Action Space {}\".format(env.action_space))\n",
"\n",
"print(\"State Space {}\".format(env.observation_space))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2020-10-23T07:01:25.561002Z",
"iopub.status.busy": "2020-10-23T07:01:25.560068Z",
"iopub.status.idle": "2020-10-23T07:01:25.563851Z",
"shell.execute_reply": "2020-10-23T07:01:25.564479Z"
},
"papermill": {
"duration": 0.022719,
"end_time": "2020-10-23T07:01:25.564666",
"exception": false,
"start_time": "2020-10-23T07:01:25.541947",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"State: 328\n",
"+---------+\n",
"|\u001b[35mR\u001b[0m: | : :G|\n",
"| : | : : |\n",
"| : : : : |\n",
"| |\u001b[43m \u001b[0m: | : |\n",
"|\u001b[34;1mY\u001b[0m| : |B: |\n",
"+---------+\n",
"\n"
]
}
],
"source": [
"# state encoding\n",
"\n",
"state = env.encode(3, 1, 2, 0) # (taxi row, taxi column, passenger index, destination index)\n",
"print(\"State:\", state)\n",
"\n",
"env.s = state\n",
"env.render()\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2020-10-23T07:01:25.599957Z",
"iopub.status.busy": "2020-10-23T07:01:25.599051Z",
"iopub.status.idle": "2020-10-23T07:01:25.604442Z",
"shell.execute_reply": "2020-10-23T07:01:25.603711Z"
},
"papermill": {
"duration": 0.027519,
"end_time": "2020-10-23T07:01:25.604618",
"exception": false,
"start_time": "2020-10-23T07:01:25.577099",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"{0: [(1.0, 428, -1, False)],\n",
" 1: [(1.0, 228, -1, False)],\n",
" 2: [(1.0, 348, -1, False)],\n",
" 3: [(1.0, 328, -1, False)],\n",
" 4: [(1.0, 328, -10, False)],\n",
" 5: [(1.0, 328, -10, False)]}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# rewards table\n",
"env.P[328]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2020-10-23T07:01:25.639997Z",
"iopub.status.busy": "2020-10-23T07:01:25.638924Z",
"iopub.status.idle": "2020-10-23T07:01:25.643756Z",
"shell.execute_reply": "2020-10-23T07:01:25.642988Z"
},
"papermill": {
"duration": 0.025446,
"end_time": "2020-10-23T07:01:25.643881",
"exception": false,
"start_time": "2020-10-23T07:01:25.618435",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"array([[0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" ...,\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Defining q table\n",
"\n",
"import numpy as np\n",
"q_table = np.zeros([env.observation_space.n, env.action_space.n])\n",
"q_table"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2020-10-23T07:01:25.677831Z",
"iopub.status.busy": "2020-10-23T07:01:25.676776Z",
"iopub.status.idle": "2020-10-23T07:01:25.680343Z",
"shell.execute_reply": "2020-10-23T07:01:25.679723Z"
},
"papermill": {
"duration": 0.022979,
"end_time": "2020-10-23T07:01:25.680493",
"exception": false,
"start_time": "2020-10-23T07:01:25.657514",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"# hyperparameters\n",
"alpha = 0.1\n",
"gamma = 0.6\n",
"epsiolon = 0.1"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2020-10-23T07:01:25.765603Z",
"iopub.status.busy": "2020-10-23T07:01:25.726090Z",
"iopub.status.idle": "2020-10-23T07:10:37.138058Z",
"shell.execute_reply": "2020-10-23T07:10:37.138886Z"
},
"papermill": {
"duration": 551.443876,
"end_time": "2020-10-23T07:10:37.139105",
"exception": false,
"start_time": "2020-10-23T07:01:25.695229",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 999999 Iterations: 14810058\r\n"
]
}
],
"source": [
"import random\n",
"\n",
"max_epochs = 1000000 # stops after x epochs or max_iterations\n",
"max_iterations = -1 # 1000000 # -1 to not stop based on iterations\n",
"\n",
"iteration = 0\n",
"for epoch in range(max_epochs):\n",
" done = False\n",
" new_state = env.reset()\n",
" \n",
" while not done and (iteration < max_iterations or max_iterations == -1):\n",
" iteration += 1\n",
" old_state = new_state\n",
" if random.uniform(0,1) < epsiolon:\n",
" action = env.action_space.sample()\n",
" else:\n",
" action = np.argmax(q_table[old_state])\n",
" \n",
" new_state, reward, done, info = env.step(action)\n",
" env.s = new_state # update state of the environment (not sure if )\n",
" \n",
" q_table[old_state, action] = q_table[old_state][action] + alpha * (reward + gamma * np.max(q_table[new_state]) - q_table[old_state][action])\n",
" \n",
" if iteration >= max_iterations and max_iterations != -1:\n",
" break\n",
"print(\"Epoch: {} Iterations: {}\\r\".format(epoch, iteration))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2020-10-23T07:10:37.176308Z",
"iopub.status.busy": "2020-10-23T07:10:37.175458Z",
"iopub.status.idle": "2020-10-23T07:10:37.179942Z",
"shell.execute_reply": "2020-10-23T07:10:37.180484Z"
},
"papermill": {
"duration": 0.026804,
"end_time": "2020-10-23T07:10:37.180659",
"exception": false,
"start_time": "2020-10-23T07:10:37.153855",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0. , 0. , 0. , 0. ,\n",
" 0. , 0. ],\n",
" [ -2.41837066, -2.3639511 , -2.41837066, -2.3639511 ,\n",
" -2.27325184, -11.3639511 ],\n",
" [ -1.870144 , -1.45024 , -1.870144 , -1.45024 ,\n",
" -0.7504 , -10.45024 ],\n",
" ...,\n",
" [ -0.75955193, 0.416 , -0.75945955, -1.44283176,\n",
" -9.1295566 , -9.26328213],\n",
" [ -2.26879216, -2.1220864 , -2.24661946, -2.1220864 ,\n",
" -10.7816137 , -10.43290413],\n",
" [ 5.6 , 2.36 , 5.6 , 11. ,\n",
" -3.4 , -3.4 ]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"q_table"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2020-10-23T07:10:37.220686Z",
"iopub.status.busy": "2020-10-23T07:10:37.219859Z",
"iopub.status.idle": "2020-10-23T07:10:37.228862Z",
"shell.execute_reply": "2020-10-23T07:10:37.229575Z"
},
"papermill": {
"duration": 0.033495,
"end_time": "2020-10-23T07:10:37.229780",
"exception": false,
"start_time": "2020-10-23T07:10:37.196285",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+---------+\n",
"|R: | : :\u001b[35mG\u001b[0m|\n",
"| : | : : |\n",
"| :\u001b[43m \u001b[0m: : : |\n",
"| | : | : |\n",
"|Y| : |\u001b[34;1mB\u001b[0m: |\n",
"+---------+\n",
" (East)\n",
"+---------+\n",
"|R: | : :\u001b[35mG\u001b[0m|\n",
"| : | : : |\n",
"| : :\u001b[43m \u001b[0m: : |\n",
"| | : | : |\n",
"|Y| : |\u001b[34;1mB\u001b[0m: |\n",
"+---------+\n",
" (East)\n",
"+---------+\n",
"|R: | : :\u001b[35mG\u001b[0m|\n",
"| : | : : |\n",
"| : : :\u001b[43m \u001b[0m: |\n",
"| | : | : |\n",
"|Y| : |\u001b[34;1mB\u001b[0m: |\n",
"+---------+\n",
" (East)\n",
"+---------+\n",
"|R: | : :\u001b[35mG\u001b[0m|\n",
"| : | : : |\n",
"| : : : : |\n",
"| | : |\u001b[43m \u001b[0m: |\n",
"|Y| : |\u001b[34;1mB\u001b[0m: |\n",
"+---------+\n",
" (South)\n",
"+---------+\n",
"|R: | : :\u001b[35mG\u001b[0m|\n",
"| : | : : |\n",
"| : : : : |\n",
"| | : | : |\n",
"|Y| : |\u001b[34;1m\u001b[43mB\u001b[0m\u001b[0m: |\n",
"+---------+\n",
" (South)\n",
"+---------+\n",
"|R: | : :\u001b[35mG\u001b[0m|\n",
"| : | : : |\n",
"| : : : : |\n",
"| | : | : |\n",
"|Y| : |\u001b[42mB\u001b[0m: |\n",
"+---------+\n",
" (Pickup)\n",
"+---------+\n",
"|R: | : :\u001b[35mG\u001b[0m|\n",
"| : | : : |\n",
"| : : : : |\n",
"| | : |\u001b[42m_\u001b[0m: |\n",
"|Y| : |B: |\n",
"+---------+\n",
" (North)\n",
"+---------+\n",
"|R: | : :\u001b[35mG\u001b[0m|\n",
"| : | : : |\n",
"| : : :\u001b[42m_\u001b[0m: |\n",
"| | : | : |\n",
"|Y| : |B: |\n",
"+---------+\n",
" (North)\n",
"+---------+\n",
"|R: | : :\u001b[35mG\u001b[0m|\n",
"| : | :\u001b[42m_\u001b[0m: |\n",
"| : : : : |\n",
"| | : | : |\n",
"|Y| : |B: |\n",
"+---------+\n",
" (North)\n",
"+---------+\n",
"|R: | :\u001b[42m_\u001b[0m:\u001b[35mG\u001b[0m|\n",
"| : | : : |\n",
"| : : : : |\n",
"| | : | : |\n",
"|Y| : |B: |\n",
"+---------+\n",
" (North)\n",
"+---------+\n",
"|R: | : :\u001b[35m\u001b[42mG\u001b[0m\u001b[0m|\n",
"| : | : : |\n",
"| : : : : |\n",
"| | : | : |\n",
"|Y| : |B: |\n",
"+---------+\n",
" (East)\n",
"+---------+\n",
"|R: | : :\u001b[35m\u001b[34;1m\u001b[43mG\u001b[0m\u001b[0m\u001b[0m|\n",
"| : | : : |\n",
"| : : : : |\n",
"| | : | : |\n",
"|Y| : |B: |\n",
"+---------+\n",
" (Dropoff)\n"
]
}
],
"source": [
"# example game\n",
"\n",
"new_state = env.reset()\n",
"\n",
"done = False\n",
"max_iterations = 100\n",
"iteration = 0\n",
"while not done and (iteration <= max_iterations or max_iterations == -1):\n",
" old_state = new_state\n",
" action = np.argmax(q_table[old_state])\n",
" new_state, reward, done, info = env.step(action)\n",
" env.s = new_state\n",
" env.render()\n",
" iteration += 1\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2020-10-23T07:10:37.269410Z",
"iopub.status.busy": "2020-10-23T07:10:37.268454Z",
"iopub.status.idle": "2020-10-23T07:10:37.659363Z",
"shell.execute_reply": "2020-10-23T07:10:37.658561Z"
},
"papermill": {
"duration": 0.413907,
"end_time": "2020-10-23T07:10:37.659523",
"exception": false,
"start_time": "2020-10-23T07:10:37.245616",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Id</th>\n",
" <th>Value</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>0.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>0.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>0.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>0.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>0.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2995</th>\n",
" <td>2996</td>\n",
" <td>2.36</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2996</th>\n",
" <td>2997</td>\n",
" <td>5.60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2997</th>\n",
" <td>2998</td>\n",
" <td>11.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2998</th>\n",
" <td>2999</td>\n",
" <td>-3.40</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2999</th>\n",
" <td>3000</td>\n",
" <td>-3.40</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>3000 rows × 2 columns</p>\n",
"</div>"
],
"text/plain": [
" Id Value\n",
"0 1 0.00\n",
"1 2 0.00\n",
"2 3 0.00\n",
"3 4 0.00\n",
"4 5 0.00\n",
"... ... ...\n",
"2995 2996 2.36\n",
"2996 2997 5.60\n",
"2997 2998 11.00\n",
"2998 2999 -3.40\n",
"2999 3000 -3.40\n",
"\n",
"[3000 rows x 2 columns]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#generate output file\n",
"df = pd.DataFrame(q_table.ravel())\n",
"\n",
"df.index += 1\n",
"df.to_csv(\"q_table.csv\", index_label=\"Id\", header=[\"Value\"])\n",
"df\n",
"\n",
"# load data again to test correct file format\n",
"pd.read_csv(\"q_table.csv\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
},
"papermill": {
"duration": 559.137382,
"end_time": "2020-10-23T07:10:37.784395",
"environment_variables": {},
"exception": null,
"input_path": "__notebook__.ipynb",
"output_path": "__notebook__.ipynb",
"parameters": {},
"start_time": "2020-10-23T07:01:18.647013",
"version": "2.1.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}