From bd35d67ee02e4a0f8d2afe040a00e9a146c80a6f Mon Sep 17 00:00:00 2001 From: Hannes Kuchelmeister Date: Fri, 23 Oct 2020 16:28:25 +0200 Subject: [PATCH] add ipynb for assignments in python --- assignment_4.ipynb | 1040 ++++++++++++++++++++++++++++++++++++++++++++ assignment_5.ipynb | 711 ++++++++++++++++++++++++++++++ 2 files changed, 1751 insertions(+) create mode 100644 assignment_4.ipynb create mode 100644 assignment_5.ipynb diff --git a/assignment_4.ipynb b/assignment_4.ipynb new file mode 100644 index 0000000..33e09ba --- /dev/null +++ b/assignment_4.ipynb @@ -0,0 +1,1040 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", + "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", + "execution": { + "iopub.execute_input": "2020-10-23T07:12:26.661165Z", + "iopub.status.busy": "2020-10-23T07:12:26.660203Z", + "iopub.status.idle": "2020-10-23T07:12:26.665167Z", + "shell.execute_reply": "2020-10-23T07:12:26.664447Z" + }, + "papermill": { + "duration": 0.019587, + "end_time": "2020-10-23T07:12:26.665296", + "exception": false, + "start_time": "2020-10-23T07:12:26.645709", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/kaggle/input/nn-assignment/testset.csv\n", + "/kaggle/input/nn-assignment/samplesubmission.csv\n", + "/kaggle/input/nn-assignment/trainset.csv\n" + ] + } + ], + "source": [ + "# This Python 3 environment comes with many helpful analytics libraries installed\n", + "# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n", + "# For example, here's several helpful packages to load\n", + "\n", + "import numpy as np # linear algebra\n", + "import pandas as pd # data processing, CSV file I/O (e.g. 
pd.read_csv)\n", + "\n", + "# Input data files are available in the read-only \"../input/\" directory\n", + "# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n", + "\n", + "import os\n", + "for dirname, _, filenames in os.walk('/kaggle/input'):\n", + " for filename in filenames:\n", + " print(os.path.join(dirname, filename))\n", + "\n", + "# You can write up to 5GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n", + "# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:12:26.685951Z", + "iopub.status.busy": "2020-10-23T07:12:26.685373Z", + "iopub.status.idle": "2020-10-23T07:12:28.979037Z", + "shell.execute_reply": "2020-10-23T07:12:28.978440Z" + }, + "papermill": { + "duration": 2.30553, + "end_time": "2020-10-23T07:12:28.979146", + "exception": false, + "start_time": "2020-10-23T07:12:26.673616", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
labelpixel0pixel1pixel2pixel3pixel4pixel5pixel6pixel7pixel8...pixel774pixel775pixel776pixel777pixel778pixel779pixel780pixel781pixel782pixel783
1887990000000255255...0000000000
88202000000000...0000000000
23937700000000255...0000000000
14765000000000...0000000000
4345000000000...0000000000
..................................................................
106841000000000...0000000000
152644000000000...25525525525525500000
193487000000000...0000000000
235799000000000...255255255255255255255000
140541000000000...0000000000
\n", + "

28000 rows × 785 columns

\n", + "
" + ], + "text/plain": [ + " label pixel0 pixel1 pixel2 pixel3 pixel4 pixel5 pixel6 pixel7 \\\n", + "18879 9 0 0 0 0 0 0 0 255 \n", + "8820 2 0 0 0 0 0 0 0 0 \n", + "23937 7 0 0 0 0 0 0 0 0 \n", + "1476 5 0 0 0 0 0 0 0 0 \n", + "434 5 0 0 0 0 0 0 0 0 \n", + "... ... ... ... ... ... ... ... ... ... \n", + "10684 1 0 0 0 0 0 0 0 0 \n", + "15264 4 0 0 0 0 0 0 0 0 \n", + "19348 7 0 0 0 0 0 0 0 0 \n", + "23579 9 0 0 0 0 0 0 0 0 \n", + "14054 1 0 0 0 0 0 0 0 0 \n", + "\n", + " pixel8 ... pixel774 pixel775 pixel776 pixel777 pixel778 \\\n", + "18879 255 ... 0 0 0 0 0 \n", + "8820 0 ... 0 0 0 0 0 \n", + "23937 255 ... 0 0 0 0 0 \n", + "1476 0 ... 0 0 0 0 0 \n", + "434 0 ... 0 0 0 0 0 \n", + "... ... ... ... ... ... ... ... \n", + "10684 0 ... 0 0 0 0 0 \n", + "15264 0 ... 255 255 255 255 255 \n", + "19348 0 ... 0 0 0 0 0 \n", + "23579 0 ... 255 255 255 255 255 \n", + "14054 0 ... 0 0 0 0 0 \n", + "\n", + " pixel779 pixel780 pixel781 pixel782 pixel783 \n", + "18879 0 0 0 0 0 \n", + "8820 0 0 0 0 0 \n", + "23937 0 0 0 0 0 \n", + "1476 0 0 0 0 0 \n", + "434 0 0 0 0 0 \n", + "... ... ... ... ... ... 
\n", + "10684 0 0 0 0 0 \n", + "15264 0 0 0 0 0 \n", + "19348 0 0 0 0 0 \n", + "23579 255 255 0 0 0 \n", + "14054 0 0 0 0 0 \n", + "\n", + "[28000 rows x 785 columns]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = pd.read_csv(\"/kaggle/input/nn-assignment/trainset.csv\")\n", + "#randomly shuffle data\n", + "data = data.sample(frac = 1) \n", + "data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:12:29.001507Z", + "iopub.status.busy": "2020-10-23T07:12:29.000681Z", + "iopub.status.idle": "2020-10-23T07:12:29.004182Z", + "shell.execute_reply": "2020-10-23T07:12:29.004645Z" + }, + "papermill": { + "duration": 0.01646, + "end_time": "2020-10-23T07:12:29.004783", + "exception": false, + "start_time": "2020-10-23T07:12:28.988323", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "5600" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "validation_data_percentage = 0.2\n", + "validation_data_amount = round(validation_data_percentage * len(data))\n", + "validation_data_amount" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:12:29.087417Z", + "iopub.status.busy": "2020-10-23T07:12:29.084803Z", + "iopub.status.idle": "2020-10-23T07:12:29.202140Z", + "shell.execute_reply": "2020-10-23T07:12:29.201577Z" + }, + "papermill": { + "duration": 0.188986, + "end_time": "2020-10-23T07:12:29.202254", + "exception": false, + "start_time": "2020-10-23T07:12:29.013268", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def reshape_input(input_data):\n", + " return np.reshape(input_data, (len(input_data), 28, 28, 1))\n", + "\n", + "def normalize(input_data):\n", + " return input_data / 255.0\n", + "\n", + "def 
preprocess(input_data):\n", + " return normalize(reshape_input(input_data))\n", + "\n", + "X = data.drop(columns='label').values\n", + "Y = data.loc[:,'label'].values\n", + "\n", + "X = preprocess(X)\n", + "\n", + "validation_X = X[:validation_data_amount]\n", + "validation_Y = Y[:validation_data_amount]\n", + "train_X = X[validation_data_amount:]\n", + "train_Y = Y[validation_data_amount:]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:12:29.225316Z", + "iopub.status.busy": "2020-10-23T07:12:29.224728Z", + "iopub.status.idle": "2020-10-23T07:12:29.565835Z", + "shell.execute_reply": "2020-10-23T07:12:29.566259Z" + }, + "papermill": { + "duration": 0.355532, + "end_time": "2020-10-23T07:12:29.566416", + "exception": false, + "start_time": "2020-10-23T07:12:29.210884", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAFlCAYAAADyArMXAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAASZUlEQVR4nO3df5BV9XnH8c8jIIhCZUNwFtQgKkZoAeMWndCktGggdqzaViNNHZjYwWmkFX90akxbnUnNOC0SO451ukQqZBTjKFQn42RFakIN1gqEn90Ilq6pLoVaTYEhEhee/nEPzWa9y97v3nvO3efu+zXD3Lvffe49z5nDfvxy9n6/mrsLABDPKfVuAADQPwQ4AARFgANAUAQ4AARFgANAUAQ4AAQ1tMiDnWrDfYROL/KQABDeIb3/rrt/vOd4oQE+QqfrMptT5CEBILyX/Jm3yo1zCwUAgiLAASAoAhwAgiLAASAoAhwAgiLAASAoAhwAgiLAASAoAhwAgiLAASCoQpfSDyRtndvq3QIGkLnjp9e7BSAZM3AACIoAB4CgCHAACIoAB4CgCHAACIoAB4CgCHAACIoAB4CgCHAACIoAB4CgBu1SepZOA4iOGTgABEWAA0BQBDgABEWAA0BQBDgABEWAA0BQBDgABNVngJvZOWb2spm1m9kuM7stG7/PzN4xs63Zn6vybxcAcEIlC3m6JN3p7lvMbJSkzWa2LvveN9x9aX7tAQB602eAu/s+Sfuy54fMrF3ShLwbAwCcXNI9cDObKOkSSa9lQ4vNbLuZrTCzMTXuDQBwEhUHuJmdIelZSUvc/aCkRyWdL2mGSjP0B3t53SIz22Rmmz7U0Rq0DACQKgxwMxumUng/4e5rJMnd97v7MXc/Lmm5pJnlXuvure7e4u4twzS8Vn0DwKBXyadQTNJjktrdfVm38eZuZddJ2ln79gAAvankUyizJN0kaYeZbc3G7pE038xmSHJJHZJuyaVDAEBZlXwK5RVJVuZbL9S+HQBApViJCQBBE
eAAEBQBDgBBEeAAEBQBDgBBEeAAEBQBDgBBVbKQB6i5ts5t9W4BA8jc8dPr3UJIzMABICgCHACCIsABICgCHACCIsABICgCHACCIsABICgCHACCIsABICgCHACCGrBL6Vlqje4ePzguqf6Lozpz6qQY6z8YmVT/8JWfT6rv2tuRVI+BiRk4AARFgANAUAQ4AARFgANAUAQ4AARFgANAUAQ4AARFgANAUAQ4AARFgANAUAQ4AARl7l7YwUZbk19mcwo7Horz7w9dnlTffv0jOXVSctG6RUn1Uyem7Z2y5sLvJNXnbeqqP06qP+8rG3PqJH+DcZ+kIc17Nrt7S89xZuAAEBQBDgBBEeAAEBQBDgBBEeAAEBQBDgBBEeAAEBQBDgBBEeAAEBQBDgBBEeAAEFShe6G0TB/h/9p2bmHHi2Tu+On1buEX7Fl5aVL9riseTaofqiFJ9Z/ZdmNS/ZlX702qtyFp/RydPS2p/va/W51UP++0g0n1XTqWVP/Ze5ck1Td9c+DsncJeKD/HDBwAgiLAASAoAhwAgiLAASAoAhwAgiLAASCoPgPczM4xs5fNrN3MdpnZbdl4k5mtM7M92eOY/NsFAJxQyQy8S9Kd7n6xpMsl3WpmUyTdLWm9u18oaX32NQCgIH0GuLvvc/ct2fNDktolTZB0jaSVWdlKSdfm1SQA4KOS7oGb2URJl0h6TdJZ7r5PKoW8pHG1bg4A0LuhlRaa2RmSnpW0xN0Pmlmlr1skaZEknTuh4sOhzl6f83BS/VANz6mTkocufiqpfvnGX0+q3/2TtPnH6D/oSKp/qGNOUv28i9cm1aduTeCBP74w0LadKMaesqMVXUYzG6ZSeD/h7muy4f1m1px9v1nSgXKvdfdWd29x95aPfyztLxkAoHeVfArFJD0mqd3dl3X71vOSFmTPF0h6rvbtAQB6U8k9jVmSbpK0w8y2ZmP3SHpA0tNmdrOkH0u6Pp8WAQDl9Bng7v6KpN5ueKfd2AMA1EzgX2UAwOBGgANAUAQ4AARFgANAUAQ4AARFgANAUAQ4AARV6OYku7ePHKT7GKBal56aWH/295Pqp2xYnFQ/8r/3JtVL5ybW58uueTftBa359IHqMAMHgKAIcAAIigAHgKAIcAAIigAHgKAIcAAIigAHgKAIcAAIigAHgKAIcAAIigAHgKAK3QsFcVz9p3ck1S/7+iNJ9al7m6SaveOGpPoL7t2WVH88qVoavuS0tBe0JR4g0Yqpq5Lq79LlOXWSv7bOtGs7EA1pLj/ODBwAgiLAASAoAhwAgiLAASAoAhwAgiLAASAoAhwAgiLAASAoAhwAgiLAASCoQpfST552RG1t8Ze1DgRzx0/P9f1Hrf6XpPolwxYn1d/556uT6pfePz+pfuxLbyXVdx05klSf6ke3npnr+6N3ef+sFGNP2VFm4AAQFAEOAEER4AAQFAEOAEER4AAQFAEOAEER4AAQFAEOAEER4AAQFAEOAEER4AAQVKF7oezePrJB9iVAT7+06tWk+m+umphUf6bS3r8rqVo6/IXLk+r/a1ba+39l9vNpL8jZKeb1bgE1wAwcAIIiwAEgKAIcAIIiwAEgKAIcAIIiwAEgKAIcAILqM8DNbIWZHTCznd3G7jOzd8xsa/bnqnzbBAD0VMkM/HFJ88qMf8PdZ2R/XqhtWwCAvvQZ4O6+QdJ7BfQCAEhQzT3wxWa2PbvFMqa3IjNbZGabzGzThzpaxeEAAN31dy+URyV9TZJnjw9K+lK5QndvldQqSaOtiQ0Y0C9DL5iUVN/0rbR/NC6fsDSpvnnIaUn1A81xt3q3gBro1wzc3fe7+zF3Py5puaSZtW0LANCXfgW4mTV3+/I6STt7qwUA5KPPWyhmtlrSbEljzextSfdKmm1mM1S6hdIh6ZYcewQAlNFngLv7/DLDj+XQCwAgASsxASAoAhwAgiLAASAoAhwAgiLAASAoAhwAgurvUvp+mTztiNrathV5SAxQv/Urv5lUP
+HJ/Un1j0zYkFQvxV4an+rsoWm7WnS9dG5S/bCvnplU769tT6pHCTNwAAiKAAeAoAhwAAiKAAeAoAhwAAiKAAeAoAhwAAiKAAeAoAhwAAiKAAeAoAhwAAjK3NP2RKjGaGvyy2xOYcdDcYZOGJ9UP2Ht/ybVp+9tkmbhW1ck1T8x8Xv5NIJw5o6fnvsxXvJnNrt7S89xZuAAEBQBDgBBEeAAEBQBDgBBEeAAEBQBDgBBEeAAEBQBDgBBEeAAEBQBDgBBEeAAENTQejeAgWnoBZOS6sd+63+S6lP3NvnB0WFJ9bfvvCGpfvwfpvU/+Y4/SqrfddPDSfWptv4sbU+jZZ1zk+pXTXwxqf794x8k1c/69l0V106669Wk925kzMABICgCHACCIsABICgCHACCIsABICgCHACCIsABICgCHACCIsABICgCHACCYik9yvrE6s6k+r8d/4OcOin58pYvJtWf83s7kup91Kik+lVfyHdpfKov/9WfJNWPffKHSfVT/+FLSfXbPrM8qf6HNz5Uce2n992e9N7ND25Mqo+EGTgABEWAA0BQBDgABEWAA0BQBDgABEWAA0BQBDgABNVngJvZCjM7YGY7u401mdk6M9uTPY7Jt00AQE+VzMAflzSvx9jdkta7+4WS1mdfAwAK1GeAu/sGSe/1GL5G0srs+UpJ19a4LwBAH/p7D/wsd98nSdnjuN4KzWyRmW0ys00f6mg/DwcA6Cn3vVDcvVVSqySNtibP+3ioja81/1PiK4YnVS9864qk+okL/yOp/lhStfTG/VOT6i899XuJR0hzW+espPpx33kzqb7rpz9Nqj/vxm1J9TO+viSpfuHV6yuuHfEeMXJCf2fg+82sWZKyxwO1awkAUIn+BvjzkhZkzxdIeq427QAAKlXJxwhXS3pV0kVm9raZ3SzpAUlXmtkeSVdmXwMACtTnPXB3n9/Lt+bUuBcAQAJWYgJAUAQ4AARFgANAUAQ4AARFgANAUAQ4AASV+1J6DAynXDIlrV6vJtU/fbjX7XDK+snvpi29P3b4/aT6I79zWVL907/9cFJ96tznoKftA/TKtz+VVN+8f2NSfd4m3pPWzyutn6y4dkxH2t/NVG2dadsGFGFIc/lxZuAAEBQBDgBBEeAAEBQBDgBBEeAAEBQBDgBBEeAAEBQBDgBBEeAAEBQBDgBBEeAAEFShe6FMnnZEbW0Db5+BwWDSi5cm1Z9hpybVHzp+WlK9HzqcVH/06plJ9X/5NyuS6qcNy3cuM3PNHUn1FywdWHub5K2r4616t/D/5o6fXu8WythTdpQZOAAERYADQFAEOAAERYADQFAEOAAERYADQFAEOAAERYADQFAEOAAERYADQFAEOAAEVeheKLu3jxyg+ww0vqZbhifVH72iK6n+5tFpe1nc/Ebq3hffT6zP19OHxyXVX/DUBzl1gsGMGTgABEWAA0BQBDgABEWAA0BQBDgABEWAA0BQBDgABEWAA0BQBDgABEWAA0BQhS6lR/187O83JtUf/YtjSfXDg/9VumjdLUn1n1x6KO0AO7em1QMVYAYOAEER4AAQFAEOAEER4AAQFAEOAEER4AAQFAEOAEFV9eFdM+uQdEjSMUld7t5Si6YAAH2rxeqL33D3d2vwPgCABNxCAYCgqg1wl/SimW02s0W1aAgAUJlqb6HMcvdOMxsnaZ2Z/cjdN3QvyIJ9kSSN0MgqD4fB6qsHfjWp/rtPXJ5Uf/7rHyTVH9v5RlI9kIeqZuDu3pk9HpC0VtLMMjWt7t7i7i3DNLyawwEAuul3gJvZ6WY26sRzSZ+TtLNWjQEATq6aWyhnSVprZife50l3/25NugIA9KnfAe7ueyVNr2EvAIAEfIwQAIIiwAEgKAIcAIIiwAEgKAIcAIIiwAEgKAIcAIKqxXayaEC/f/an691CD8eTqpu1Mac+gF/U1rkt92MMaS4/zgwcAIIiwAEgKAIcAIIiwAEgKAIcAIIiwAEgKAIcAIIiwAEgK
AIcAIIiwAEgqEKX0k+edkRtbfkvO0X15o7n/5YHVKKYn5U9ZUeZgQNAUAQ4AARFgANAUAQ4AARFgANAUAQ4AARFgANAUAQ4AARFgANAUAQ4AARFgANAUIXuhbJ7+0j22ACAGmEGDgBBEeAAEBQBDgBBEeAAEBQBDgBBEeAAEBQBDgBBEeAAEBQBDgBBEeAAEBQBDgBBEeAAEBQBDgBBEeAAEBQBDgBBEeAAEBQBDgBBEeAAEFRVAW5m88zsDTN708zurlVTAIC+9TvAzWyIpEckfV7SFEnzzWxKrRoDAJxcNTPwmZLedPe97v4zSU9JuqY2bQEA+lJNgE+Q9J/dvn47GwMAFGBoFa+1MmP+kSKzRZIWZV8efcmf2VnFMaMZK+ndejdRoMF0voPpXCXOt94+UW6wmgB/W9I53b4+W1JnzyJ3b5XUKklmtsndW6o4Ziicb+MaTOcqcb4DVTW3UF6XdKGZnWdmp0q6UdLztWkLANCXfs/A3b3LzBZLapM0RNIKd99Vs84AACdVzS0UufsLkl5IeElrNccLiPNtXIPpXCXOd0Ay94/83hEAEABL6QEgqEICfLAtuTezDjPbYWZbzWxTvfupNTNbYWYHzGxnt7EmM1tnZnuyxzH17LGWejnf+8zsnewabzWzq+rZY62Y2Tlm9rKZtZvZLjO7LRtvyOt7kvMNcX1zv4WSLbnfLelKlT56+Lqk+e7+b7keuI7MrENSi7sPpM+R1oyZfVbSYUmr3P2Xs7G/lvSeuz+Q/Ud6jLv/WT37rJVezvc+SYfdfWk9e6s1M2uW1OzuW8xslKTNkq6VtFANeH1Pcr43KMD1LWIGzpL7BuPuGyS912P4Gkkrs+crVfohaAi9nG9Dcvd97r4le35IUrtKK6wb8vqe5HxDKCLAB+OSe5f0opltzlaiDgZnufs+qfRDIWlcnfspwmIz257dYmmIWwrdmdlESZdIek2D4Pr2OF8pwPUtIsArWnLfYGa5+6dU2qnx1uyf4Ggsj0o6X9IMSfskPVjfdmrLzM6Q9KykJe5+sN795K3M+Ya4vkUEeEVL7huJu3dmjwckrVXpNlKj25/dTzxxX/FAnfvJlbvvd/dj7n5c0nI10DU2s2EqhdkT7r4mG27Y61vufKNc3yICfFAtuTez07NfhsjMTpf0OUmDYQOv5yUtyJ4vkPRcHXvJ3Ykwy1ynBrnGZmaSHpPU7u7Lun2rIa9vb+cb5foWspAn+wjOQ/r5kvv7cz9onZjZJJVm3VJppeuTjXa+ZrZa0myVdmzbL+leSf8o6WlJ50r6saTr3b0hfvHXy/nOVumf1y6pQ9ItJ+4RR2ZmvybpnyXtkHQ8G75HpfvCDXd9T3K+8xXg+rISEwCCYiUmAARFgANAUAQ4AARFgANAUAQ4AARFgANAUAQ4AARFgANAUP8HxDUa8nyHYAYAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "plt.figure(figsize=(6,6))\n", + "plot_data = np.reshape(data.drop(columns='label').values, (len(data.drop(columns='label').values), 28, 28))\n", + "plt.pcolor(plot_data[6][::-1][:])\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:12:29.594754Z", + "iopub.status.busy": "2020-10-23T07:12:29.593939Z", + "iopub.status.idle": "2020-10-23T07:12:35.838442Z", + "shell.execute_reply": "2020-10-23T07:12:35.837675Z" + }, + "papermill": { + "duration": 6.26273, + "end_time": "2020-10-23T07:12:35.838594", + "exception": false, + "start_time": "2020-10-23T07:12:29.575864", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"sequential\"\n", + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "conv2d (Conv2D) (None, 26, 26, 32) 320 \n", + "_________________________________________________________________\n", + "max_pooling2d (MaxPooling2D) (None, 13, 13, 32) 0 \n", + "_________________________________________________________________\n", + "conv2d_1 (Conv2D) (None, 11, 11, 64) 18496 \n", + "_________________________________________________________________\n", + "max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64) 0 \n", + "_________________________________________________________________\n", + "conv2d_2 (Conv2D) (None, 3, 3, 64) 36928 \n", + "_________________________________________________________________\n", + "flatten (Flatten) (None, 576) 0 \n", + "_________________________________________________________________\n", + "dense (Dense) (None, 128) 73856 \n", + 
"_________________________________________________________________\n", + "dropout (Dropout) (None, 128) 0 \n", + "_________________________________________________________________\n", + "dense_1 (Dense) (None, 64) 8256 \n", + "_________________________________________________________________\n", + "dropout_1 (Dropout) (None, 64) 0 \n", + "_________________________________________________________________\n", + "dense_2 (Dense) (None, 10) 650 \n", + "=================================================================\n", + "Total params: 138,506\n", + "Trainable params: 138,506\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "from tensorflow.keras import datasets, layers, models\n", + "import keras\n", + "\n", + "shape = train_X[0].shape\n", + "\n", + "\n", + "model = models.Sequential()\n", + "model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=shape))\n", + "model.add(layers.MaxPooling2D((2, 2)))\n", + "model.add(layers.Conv2D(64, (3, 3), activation='relu'))\n", + "model.add(layers.MaxPooling2D((2, 2)))\n", + "model.add(layers.Conv2D(64, (3, 3), activation='relu'))\n", + "model.add(layers.Flatten())\n", + "model.add(layers.Dense(128, activation = 'relu'))\n", + "model.add(layers.Dropout(0.25))\n", + "model.add(layers.Dense(64, activation='relu'))\n", + "model.add(layers.Dropout(0.5))\n", + "model.add(layers.Dense(10, activation='softmax'))\n", + "\n", + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:12:35.873662Z", + "iopub.status.busy": "2020-10-23T07:12:35.872894Z", + "iopub.status.idle": "2020-10-23T07:15:27.735039Z", + "shell.execute_reply": "2020-10-23T07:15:27.734394Z" + }, + "papermill": { + "duration": 171.885215, + "end_time": "2020-10-23T07:15:27.735182", + "exception": false, + "start_time": "2020-10-23T07:12:35.849967", + "status": "completed" + }, + 
"tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/15\n", + "700/700 [==============================] - 12s 17ms/step - loss: 1.7676 - accuracy: 0.6979 - val_loss: 1.5428 - val_accuracy: 0.9195\n", + "Epoch 2/15\n", + "700/700 [==============================] - 11s 16ms/step - loss: 1.5487 - accuracy: 0.9154 - val_loss: 1.5261 - val_accuracy: 0.9352\n", + "Epoch 3/15\n", + "700/700 [==============================] - 11s 16ms/step - loss: 1.5251 - accuracy: 0.9375 - val_loss: 1.5113 - val_accuracy: 0.9493\n", + "Epoch 4/15\n", + "700/700 [==============================] - 11s 16ms/step - loss: 1.5136 - accuracy: 0.9487 - val_loss: 1.5112 - val_accuracy: 0.9507\n", + "Epoch 5/15\n", + "700/700 [==============================] - 11s 16ms/step - loss: 1.5088 - accuracy: 0.9531 - val_loss: 1.4994 - val_accuracy: 0.9620\n", + "Epoch 6/15\n", + "700/700 [==============================] - 12s 17ms/step - loss: 1.5021 - accuracy: 0.9598 - val_loss: 1.5121 - val_accuracy: 0.9482\n", + "Epoch 7/15\n", + "700/700 [==============================] - 11s 16ms/step - loss: 1.4991 - accuracy: 0.9625 - val_loss: 1.5026 - val_accuracy: 0.9582\n", + "Epoch 8/15\n", + "700/700 [==============================] - 11s 16ms/step - loss: 1.4964 - accuracy: 0.9650 - val_loss: 1.5044 - val_accuracy: 0.9566\n", + "Epoch 9/15\n", + "700/700 [==============================] - 11s 16ms/step - loss: 1.4988 - accuracy: 0.9623 - val_loss: 1.4946 - val_accuracy: 0.9661\n", + "Epoch 10/15\n", + "700/700 [==============================] - 11s 16ms/step - loss: 1.4977 - accuracy: 0.9633 - val_loss: 1.4984 - val_accuracy: 0.9620\n", + "Epoch 11/15\n", + "700/700 [==============================] - 12s 17ms/step - loss: 1.4931 - accuracy: 0.9682 - val_loss: 1.4963 - val_accuracy: 0.9645\n", + "Epoch 12/15\n", + "700/700 [==============================] - 11s 16ms/step - loss: 1.4914 - accuracy: 0.9698 - val_loss: 1.4914 - val_accuracy: 0.9702\n", + 
"Epoch 13/15\n", + "700/700 [==============================] - 11s 16ms/step - loss: 1.4924 - accuracy: 0.9690 - val_loss: 1.4924 - val_accuracy: 0.9684\n", + "Epoch 14/15\n", + "700/700 [==============================] - 12s 17ms/step - loss: 1.4876 - accuracy: 0.9740 - val_loss: 1.4899 - val_accuracy: 0.9711\n", + "Epoch 15/15\n", + "700/700 [==============================] - 11s 16ms/step - loss: 1.4883 - accuracy: 0.9729 - val_loss: 1.5007 - val_accuracy: 0.9602\n" + ] + } + ], + "source": [ + "from keras import losses\n", + "# Read more https://keras.io/api/optimizers/\n", + "optimizer = 'adam' # sdg, rmsprop, adam, adadelta, adagrad, adamax, nadam, ftrl\n", + "# Read more https://www.tutorialspoint.com/keras/keras_model_compilation.htm\n", + "loss = losses.SparseCategoricalCrossentropy(from_logits=True)\n", + "# Read more at https://www.tutorialspoint.com/keras/keras_model_compilation.htm\n", + "metrics = ['accuracy'] # accuracy, binary_accuracy, categorical_accuracy, ...\n", + "\n", + "model.compile(optimizer=optimizer,\n", + " loss=loss,\n", + " metrics=metrics)\n", + "\n", + "\n", + "epochs = 15\n", + "shuffle = True # True, False or \"batch\"\n", + "batch_size = 32 # Integer or None -> default 32\n", + "callbacks = [] # List of callbacks during training\n", + "\n", + "history = model.fit(train_X, train_Y, \n", + " validation_data=(validation_X, validation_Y),\n", + " epochs=epochs,\n", + " shuffle=shuffle,\n", + " batch_size=batch_size,\n", + " callbacks=callbacks)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:15:29.295413Z", + "iopub.status.busy": "2020-10-23T07:15:29.294471Z", + "iopub.status.idle": "2020-10-23T07:15:29.443107Z", + "shell.execute_reply": "2020-10-23T07:15:29.442237Z" + }, + "papermill": { + "duration": 0.951328, + "end_time": "2020-10-23T07:15:29.443239", + "exception": false, + "start_time": "2020-10-23T07:15:28.491911", + "status": "completed" + }, 
+ "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(history.history['accuracy'], label='accuracy')\n", + "plt.plot(history.history['val_accuracy'], label = 'val_accuracy')\n", + "plt.xlabel('Epoch')\n", + "plt.ylabel('Accuracy')\n", + "plt.ylim([0.8, 1])\n", + "plt.legend(loc='lower right')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:15:30.950554Z", + "iopub.status.busy": "2020-10-23T07:15:30.949956Z", + "iopub.status.idle": "2020-10-23T07:15:33.901103Z", + "shell.execute_reply": "2020-10-23T07:15:33.901654Z" + }, + "papermill": { + "duration": 3.706087, + "end_time": "2020-10-23T07:15:33.901808", + "exception": false, + "start_time": "2020-10-23T07:15:30.195721", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "test_data_unprocessed = pd.read_csv(\"/kaggle/input/nn-assignment/testset.csv\")\n", + "test_data_X = preprocess(test_data_unprocessed.values)\n", + "\n", + "prediction = model.predict(test_data_X)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:15:35.421782Z", + "iopub.status.busy": "2020-10-23T07:15:35.420230Z", + "iopub.status.idle": "2020-10-23T07:15:36.302679Z", + "shell.execute_reply": "2020-10-23T07:15:36.303099Z" + }, + "papermill": { + "duration": 1.650328, + "end_time": "2020-10-23T07:15:36.303252", + "exception": false, + "start_time": "2020-10-23T07:15:34.652924", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import random\n", + "plot_data = np.reshape(test_data_unprocessed.values, (len(test_data_unprocessed.values), 28, 28))\n", + "\n", + "subplots_x = 4\n", + "subplots_y = 4\n", + "\n", + "\n", + "\n", + "plt.figure(figsize=(10,10))\n", + "fig, ax = plt.subplots(subplots_y, subplots_x)\n", + "fig.tight_layout()\n", + "\n", + "randomlist = random.sample(range(0, len(plot_data)), subplots_x * subplots_y)\n", + "r_index = 0\n", + "\n", + "for x in range(subplots_x):\n", + " for y in range(subplots_y):\n", + " index = randomlist[r_index]\n", + " ax[y][x].axes.xaxis.set_visible(False)\n", + " ax[y][x].axes.yaxis.set_visible(False)\n", + " ax[y][x].pcolor(plot_data[index][::-1][:])\n", + " ax[y][x].set_title('Prediction: {}'.format(np.argmax(prediction[index])))\n", + " r_index += 1\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:15:37.845380Z", + "iopub.status.busy": "2020-10-23T07:15:37.844722Z", + "iopub.status.idle": "2020-10-23T07:15:38.098061Z", + "shell.execute_reply": "2020-10-23T07:15:38.097439Z" + }, + "papermill": { + "duration": 1.045857, + "end_time": "2020-10-23T07:15:38.098169", + "exception": false, + "start_time": "2020-10-23T07:15:37.052312", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ImageIdLabel
013
121
233
348
450
.........
13995139960
13996139971
13997139987
13998139996
13999140009
\n", + "

14000 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " ImageId Label\n", + "0 1 3\n", + "1 2 1\n", + "2 3 3\n", + "3 4 8\n", + "4 5 0\n", + "... ... ...\n", + "13995 13996 0\n", + "13996 13997 1\n", + "13997 13998 7\n", + "13998 13999 6\n", + "13999 14000 9\n", + "\n", + "[14000 rows x 2 columns]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# generating the result file\n", + "result = np.argmax(prediction[:], axis = 1 )\n", + "\n", + "df = pd.DataFrame(result)\n", + "df.index += 1\n", + "df.to_csv(\"prediction.csv\", index_label=\"ImageId\", header=[\"Label\"])\n", + "\n", + "# load data again to test correct file format\n", + "pd.read_csv(\"prediction.csv\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + }, + "papermill": { + "duration": 196.572998, + "end_time": "2020-10-23T07:15:38.971256", + "environment_variables": {}, + "exception": null, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2020-10-23T07:12:22.398258", + "version": "2.1.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/assignment_5.ipynb b/assignment_5.ipynb new file mode 100644 index 0000000..e1cb6ae --- /dev/null +++ b/assignment_5.ipynb @@ -0,0 +1,711 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", + "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", + "execution": { + "iopub.execute_input": "2020-10-23T07:01:24.034522Z", + "iopub.status.busy": "2020-10-23T07:01:24.033481Z", + "iopub.status.idle": "2020-10-23T07:01:24.040184Z", + "shell.execute_reply": 
"2020-10-23T07:01:24.039289Z" + }, + "papermill": { + "duration": 0.026939, + "end_time": "2020-10-23T07:01:24.040385", + "exception": false, + "start_time": "2020-10-23T07:01:24.013446", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/kaggle/input/rl-assignment/samplesubmission.csv\n" + ] + } + ], + "source": [ + "# This Python 3 environment comes with many helpful analytics libraries installed\n", + "# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n", + "# For example, here's several helpful packages to load\n", + "\n", + "import numpy as np # linear algebra\n", + "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n", + "\n", + "# Input data files are available in the read-only \"../input/\" directory\n", + "# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n", + "\n", + "import os\n", + "for dirname, _, filenames in os.walk('/kaggle/input'):\n", + " for filename in filenames:\n", + " print(os.path.join(dirname, filename))\n", + "\n", + "# You can write up to 5GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n", + "# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0", + "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a", + "execution": { + "iopub.execute_input": "2020-10-23T07:01:24.069878Z", + "iopub.status.busy": "2020-10-23T07:01:24.068972Z", + "iopub.status.idle": "2020-10-23T07:01:25.495962Z", + "shell.execute_reply": "2020-10-23T07:01:25.494876Z" + }, + "papermill": { + "duration": 1.44331, + "end_time": "2020-10-23T07:01:25.496099", + "exception": false, + "start_time": 
"2020-10-23T07:01:24.052789", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+---------+\n", + "|\u001b[34;1mR\u001b[0m: |\u001b[43m \u001b[0m: :\u001b[35mG\u001b[0m|\n", + "| : | : : |\n", + "| : : : : |\n", + "| | : | : |\n", + "|Y| : |B: |\n", + "+---------+\n", + "\n" + ] + } + ], + "source": [ + "import gym\n", + "\n", + "env = gym.make(\"Taxi-v3\").env\n", + "\n", + "env.reset() # reset environment to a new, random state\n", + "env.render()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:01:25.526480Z", + "iopub.status.busy": "2020-10-23T07:01:25.525487Z", + "iopub.status.idle": "2020-10-23T07:01:25.530094Z", + "shell.execute_reply": "2020-10-23T07:01:25.529110Z" + }, + "papermill": { + "duration": 0.022214, + "end_time": "2020-10-23T07:01:25.530278", + "exception": false, + "start_time": "2020-10-23T07:01:25.508064", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Action Space Discrete(6)\n", + "State Space Discrete(500)\n" + ] + } + ], + "source": [ + "# action space\n", + "print(\"Action Space {}\".format(env.action_space))\n", + "\n", + "print(\"State Space {}\".format(env.observation_space))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:01:25.561002Z", + "iopub.status.busy": "2020-10-23T07:01:25.560068Z", + "iopub.status.idle": "2020-10-23T07:01:25.563851Z", + "shell.execute_reply": "2020-10-23T07:01:25.564479Z" + }, + "papermill": { + "duration": 0.022719, + "end_time": "2020-10-23T07:01:25.564666", + "exception": false, + "start_time": "2020-10-23T07:01:25.541947", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "State: 328\n", + "+---------+\n", 
+ "|\u001b[35mR\u001b[0m: | : :G|\n", + "| : | : : |\n", + "| : : : : |\n", + "| |\u001b[43m \u001b[0m: | : |\n", + "|\u001b[34;1mY\u001b[0m| : |B: |\n", + "+---------+\n", + "\n" + ] + } + ], + "source": [ + "# state encoding\n", + "\n", + "state = env.encode(3, 1, 2, 0) # (taxi row, taxi column, passenger index, destination index)\n", + "print(\"State:\", state)\n", + "\n", + "env.s = state\n", + "env.render()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:01:25.599957Z", + "iopub.status.busy": "2020-10-23T07:01:25.599051Z", + "iopub.status.idle": "2020-10-23T07:01:25.604442Z", + "shell.execute_reply": "2020-10-23T07:01:25.603711Z" + }, + "papermill": { + "duration": 0.027519, + "end_time": "2020-10-23T07:01:25.604618", + "exception": false, + "start_time": "2020-10-23T07:01:25.577099", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{0: [(1.0, 428, -1, False)],\n", + " 1: [(1.0, 228, -1, False)],\n", + " 2: [(1.0, 348, -1, False)],\n", + " 3: [(1.0, 328, -1, False)],\n", + " 4: [(1.0, 328, -10, False)],\n", + " 5: [(1.0, 328, -10, False)]}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# transition/reward table for state 328: {action: [(probability, next_state, reward, done)]}\n", + "env.P[328]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:01:25.639997Z", + "iopub.status.busy": "2020-10-23T07:01:25.638924Z", + "iopub.status.idle": "2020-10-23T07:01:25.643756Z", + "shell.execute_reply": "2020-10-23T07:01:25.642988Z" + }, + "papermill": { + "duration": 0.025446, + "end_time": "2020-10-23T07:01:25.643881", + "exception": false, + "start_time": "2020-10-23T07:01:25.618435", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 
0., 0., 0., 0.],\n", + " ...,\n", + " [0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0.]])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Defining q table\n", + "\n", + "import numpy as np\n", + "q_table = np.zeros([env.observation_space.n, env.action_space.n])\n", + "q_table" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:01:25.677831Z", + "iopub.status.busy": "2020-10-23T07:01:25.676776Z", + "iopub.status.idle": "2020-10-23T07:01:25.680343Z", + "shell.execute_reply": "2020-10-23T07:01:25.679723Z" + }, + "papermill": { + "duration": 0.022979, + "end_time": "2020-10-23T07:01:25.680493", + "exception": false, + "start_time": "2020-10-23T07:01:25.657514", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# hyperparameters\n", + "alpha = 0.1\n", + "gamma = 0.6\n", + "epsiolon = 0.1" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:01:25.765603Z", + "iopub.status.busy": "2020-10-23T07:01:25.726090Z", + "iopub.status.idle": "2020-10-23T07:10:37.138058Z", + "shell.execute_reply": "2020-10-23T07:10:37.138886Z" + }, + "papermill": { + "duration": 551.443876, + "end_time": "2020-10-23T07:10:37.139105", + "exception": false, + "start_time": "2020-10-23T07:01:25.695229", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 999999 Iterations: 14810058\r\n" + ] + } + ], + "source": [ + "import random\n", + "\n", + "max_epochs = 1000000 # stops after x epochs or max_iterations\n", + "max_iterations = -1 # 1000000 # -1 to not stop based on iterations\n", + "\n", + "iteration = 0\n", + "for epoch in range(max_epochs):\n", + " done = False\n", + " new_state = env.reset()\n", + " \n", + " while not 
done and (iteration < max_iterations or max_iterations == -1):\n", + " iteration += 1\n", + " old_state = new_state\n", + " if random.uniform(0,1) < epsiolon:\n", + " action = env.action_space.sample()\n", + " else:\n", + " action = np.argmax(q_table[old_state])\n", + " \n", + " new_state, reward, done, info = env.step(action)\n", + " env.s = new_state # keep env.s in sync with the returned state (likely redundant: env.step() should already set env.s -- TODO confirm)\n", + " \n", + " q_table[old_state, action] = q_table[old_state][action] + alpha * (reward + gamma * np.max(q_table[new_state]) - q_table[old_state][action])\n", + " \n", + " if iteration >= max_iterations and max_iterations != -1:\n", + " break\n", + "print(\"Epoch: {} Iterations: {}\\r\".format(epoch, iteration))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:10:37.176308Z", + "iopub.status.busy": "2020-10-23T07:10:37.175458Z", + "iopub.status.idle": "2020-10-23T07:10:37.179942Z", + "shell.execute_reply": "2020-10-23T07:10:37.180484Z" + }, + "papermill": { + "duration": 0.026804, + "end_time": "2020-10-23T07:10:37.180659", + "exception": false, + "start_time": "2020-10-23T07:10:37.153855", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 0. , 0. , 0. , 0. ,\n", + " 0. , 0. ],\n", + " [ -2.41837066, -2.3639511 , -2.41837066, -2.3639511 ,\n", + " -2.27325184, -11.3639511 ],\n", + " [ -1.870144 , -1.45024 , -1.870144 , -1.45024 ,\n", + " -0.7504 , -10.45024 ],\n", + " ...,\n", + " [ -0.75955193, 0.416 , -0.75945955, -1.44283176,\n", + " -9.1295566 , -9.26328213],\n", + " [ -2.26879216, -2.1220864 , -2.24661946, -2.1220864 ,\n", + " -10.7816137 , -10.43290413],\n", + " [ 5.6 , 2.36 , 5.6 , 11. 
,\n", + " -3.4 , -3.4 ]])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q_table" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:10:37.220686Z", + "iopub.status.busy": "2020-10-23T07:10:37.219859Z", + "iopub.status.idle": "2020-10-23T07:10:37.228862Z", + "shell.execute_reply": "2020-10-23T07:10:37.229575Z" + }, + "papermill": { + "duration": 0.033495, + "end_time": "2020-10-23T07:10:37.229780", + "exception": false, + "start_time": "2020-10-23T07:10:37.196285", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+---------+\n", + "|R: | : :\u001b[35mG\u001b[0m|\n", + "| : | : : |\n", + "| :\u001b[43m \u001b[0m: : : |\n", + "| | : | : |\n", + "|Y| : |\u001b[34;1mB\u001b[0m: |\n", + "+---------+\n", + " (East)\n", + "+---------+\n", + "|R: | : :\u001b[35mG\u001b[0m|\n", + "| : | : : |\n", + "| : :\u001b[43m \u001b[0m: : |\n", + "| | : | : |\n", + "|Y| : |\u001b[34;1mB\u001b[0m: |\n", + "+---------+\n", + " (East)\n", + "+---------+\n", + "|R: | : :\u001b[35mG\u001b[0m|\n", + "| : | : : |\n", + "| : : :\u001b[43m \u001b[0m: |\n", + "| | : | : |\n", + "|Y| : |\u001b[34;1mB\u001b[0m: |\n", + "+---------+\n", + " (East)\n", + "+---------+\n", + "|R: | : :\u001b[35mG\u001b[0m|\n", + "| : | : : |\n", + "| : : : : |\n", + "| | : |\u001b[43m \u001b[0m: |\n", + "|Y| : |\u001b[34;1mB\u001b[0m: |\n", + "+---------+\n", + " (South)\n", + "+---------+\n", + "|R: | : :\u001b[35mG\u001b[0m|\n", + "| : | : : |\n", + "| : : : : |\n", + "| | : | : |\n", + "|Y| : |\u001b[34;1m\u001b[43mB\u001b[0m\u001b[0m: |\n", + "+---------+\n", + " (South)\n", + "+---------+\n", + "|R: | : :\u001b[35mG\u001b[0m|\n", + "| : | : : |\n", + "| : : : : |\n", + "| | : | : |\n", + "|Y| : |\u001b[42mB\u001b[0m: |\n", + "+---------+\n", + " (Pickup)\n", + "+---------+\n", + "|R: | : 
:\u001b[35mG\u001b[0m|\n", + "| : | : : |\n", + "| : : : : |\n", + "| | : |\u001b[42m_\u001b[0m: |\n", + "|Y| : |B: |\n", + "+---------+\n", + " (North)\n", + "+---------+\n", + "|R: | : :\u001b[35mG\u001b[0m|\n", + "| : | : : |\n", + "| : : :\u001b[42m_\u001b[0m: |\n", + "| | : | : |\n", + "|Y| : |B: |\n", + "+---------+\n", + " (North)\n", + "+---------+\n", + "|R: | : :\u001b[35mG\u001b[0m|\n", + "| : | :\u001b[42m_\u001b[0m: |\n", + "| : : : : |\n", + "| | : | : |\n", + "|Y| : |B: |\n", + "+---------+\n", + " (North)\n", + "+---------+\n", + "|R: | :\u001b[42m_\u001b[0m:\u001b[35mG\u001b[0m|\n", + "| : | : : |\n", + "| : : : : |\n", + "| | : | : |\n", + "|Y| : |B: |\n", + "+---------+\n", + " (North)\n", + "+---------+\n", + "|R: | : :\u001b[35m\u001b[42mG\u001b[0m\u001b[0m|\n", + "| : | : : |\n", + "| : : : : |\n", + "| | : | : |\n", + "|Y| : |B: |\n", + "+---------+\n", + " (East)\n", + "+---------+\n", + "|R: | : :\u001b[35m\u001b[34;1m\u001b[43mG\u001b[0m\u001b[0m\u001b[0m|\n", + "| : | : : |\n", + "| : : : : |\n", + "| | : | : |\n", + "|Y| : |B: |\n", + "+---------+\n", + " (Dropoff)\n" + ] + } + ], + "source": [ + "# example game\n", + "\n", + "new_state = env.reset()\n", + "\n", + "done = False\n", + "max_iterations = 100\n", + "iteration = 0\n", + "while not done and (iteration <= max_iterations or max_iterations == -1):\n", + " old_state = new_state\n", + " action = np.argmax(q_table[old_state])\n", + " new_state, reward, done, info = env.step(action)\n", + " env.s = new_state\n", + " env.render()\n", + " iteration += 1\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-23T07:10:37.269410Z", + "iopub.status.busy": "2020-10-23T07:10:37.268454Z", + "iopub.status.idle": "2020-10-23T07:10:37.659363Z", + "shell.execute_reply": "2020-10-23T07:10:37.658561Z" + }, + "papermill": { + "duration": 0.413907, + "end_time": "2020-10-23T07:10:37.659523", + "exception": false, + 
"start_time": "2020-10-23T07:10:37.245616", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
IdValue
010.00
120.00
230.00
340.00
450.00
.........
299529962.36
299629975.60
2997299811.00
29982999-3.40
29993000-3.40
\n", + "

3000 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " Id Value\n", + "0 1 0.00\n", + "1 2 0.00\n", + "2 3 0.00\n", + "3 4 0.00\n", + "4 5 0.00\n", + "... ... ...\n", + "2995 2996 2.36\n", + "2996 2997 5.60\n", + "2997 2998 11.00\n", + "2998 2999 -3.40\n", + "2999 3000 -3.40\n", + "\n", + "[3000 rows x 2 columns]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#generate output file\n", + "df = pd.DataFrame(q_table.ravel())\n", + "\n", + "df.index += 1\n", + "df.to_csv(\"q_table.csv\", index_label=\"Id\", header=[\"Value\"])\n", + "df\n", + "\n", + "# load data again to test correct file format\n", + "pd.read_csv(\"q_table.csv\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + }, + "papermill": { + "duration": 559.137382, + "end_time": "2020-10-23T07:10:37.784395", + "environment_variables": {}, + "exception": null, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2020-10-23T07:01:18.647013", + "version": "2.1.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}