Created
March 21, 2018 21:29
-
-
Save tonygentilcore/d41c8b631c96ed76c024983b73c8bece to your computer and use it in GitHub Desktop.
Titanic Kaggle w/ fast.ai
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Reload edited modules automatically before each cell runs and render plots inline.\n", | |
| "%reload_ext autoreload\n", | |
| "%autoreload 2\n", | |
| "%matplotlib inline" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import sys\n", | |
| "sys.path.append(\"/home/tonygentilcore/fastai/courses/dl1/\")\n", | |
| "from fastai.structured import *\n", | |
| "from fastai.column_data import *\n", | |
| "np.set_printoptions(threshold=50, edgeitems=20)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "PATH = '/home/tonygentilcore/.kaggle/competitions/titanic'" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Explore data" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "gender_submission.csv models test.csv tmp train.csv\r\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# List the files in the competition data directory.\n", | |
| "!ls {PATH}" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "train = pd.read_csv(f'{PATH}/train.csv')\n", | |
| "test = pd.read_csv(f'{PATH}/test.csv')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>PassengerId</th>\n", | |
| " <th>Survived</th>\n", | |
| " <th>Pclass</th>\n", | |
| " <th>Age</th>\n", | |
| " <th>SibSp</th>\n", | |
| " <th>Parch</th>\n", | |
| " <th>Fare</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>count</th>\n", | |
| " <td>891.000000</td>\n", | |
| " <td>891.000000</td>\n", | |
| " <td>891.000000</td>\n", | |
| " <td>714.000000</td>\n", | |
| " <td>891.000000</td>\n", | |
| " <td>891.000000</td>\n", | |
| " <td>891.000000</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>mean</th>\n", | |
| " <td>446.000000</td>\n", | |
| " <td>0.383838</td>\n", | |
| " <td>2.308642</td>\n", | |
| " <td>29.699118</td>\n", | |
| " <td>0.523008</td>\n", | |
| " <td>0.381594</td>\n", | |
| " <td>32.204208</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>std</th>\n", | |
| " <td>257.353842</td>\n", | |
| " <td>0.486592</td>\n", | |
| " <td>0.836071</td>\n", | |
| " <td>14.526497</td>\n", | |
| " <td>1.102743</td>\n", | |
| " <td>0.806057</td>\n", | |
| " <td>49.693429</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>min</th>\n", | |
| " <td>1.000000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>1.000000</td>\n", | |
| " <td>0.420000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>25%</th>\n", | |
| " <td>223.500000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>2.000000</td>\n", | |
| " <td>20.125000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>7.910400</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>50%</th>\n", | |
| " <td>446.000000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>3.000000</td>\n", | |
| " <td>28.000000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>14.454200</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>75%</th>\n", | |
| " <td>668.500000</td>\n", | |
| " <td>1.000000</td>\n", | |
| " <td>3.000000</td>\n", | |
| " <td>38.000000</td>\n", | |
| " <td>1.000000</td>\n", | |
| " <td>0.000000</td>\n", | |
| " <td>31.000000</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>max</th>\n", | |
| " <td>891.000000</td>\n", | |
| " <td>1.000000</td>\n", | |
| " <td>3.000000</td>\n", | |
| " <td>80.000000</td>\n", | |
| " <td>8.000000</td>\n", | |
| " <td>6.000000</td>\n", | |
| " <td>512.329200</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " PassengerId Survived Pclass Age SibSp \\\n", | |
| "count 891.000000 891.000000 891.000000 714.000000 891.000000 \n", | |
| "mean 446.000000 0.383838 2.308642 29.699118 0.523008 \n", | |
| "std 257.353842 0.486592 0.836071 14.526497 1.102743 \n", | |
| "min 1.000000 0.000000 1.000000 0.420000 0.000000 \n", | |
| "25% 223.500000 0.000000 2.000000 20.125000 0.000000 \n", | |
| "50% 446.000000 0.000000 3.000000 28.000000 0.000000 \n", | |
| "75% 668.500000 1.000000 3.000000 38.000000 1.000000 \n", | |
| "max 891.000000 1.000000 3.000000 80.000000 8.000000 \n", | |
| "\n", | |
| " Parch Fare \n", | |
| "count 891.000000 891.000000 \n", | |
| "mean 0.381594 32.204208 \n", | |
| "std 0.806057 49.693429 \n", | |
| "min 0.000000 0.000000 \n", | |
| "25% 0.000000 7.910400 \n", | |
| "50% 0.000000 14.454200 \n", | |
| "75% 0.000000 31.000000 \n", | |
| "max 6.000000 512.329200 " | |
| ] | |
| }, | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Summary statistics of the numeric columns in the training set.\n", | |
| "train.describe()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>PassengerId</th>\n", | |
| " <th>Survived</th>\n", | |
| " <th>Pclass</th>\n", | |
| " <th>Name</th>\n", | |
| " <th>Sex</th>\n", | |
| " <th>Age</th>\n", | |
| " <th>SibSp</th>\n", | |
| " <th>Parch</th>\n", | |
| " <th>Ticket</th>\n", | |
| " <th>Fare</th>\n", | |
| " <th>Cabin</th>\n", | |
| " <th>Embarked</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>3</td>\n", | |
| " <td>Braund, Mr. Owen Harris</td>\n", | |
| " <td>male</td>\n", | |
| " <td>22.0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>A/5 21171</td>\n", | |
| " <td>7.2500</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>S</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>2</td>\n", | |
| " <td>1</td>\n", | |
| " <td>1</td>\n", | |
| " <td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td>\n", | |
| " <td>female</td>\n", | |
| " <td>38.0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>PC 17599</td>\n", | |
| " <td>71.2833</td>\n", | |
| " <td>C85</td>\n", | |
| " <td>C</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>3</td>\n", | |
| " <td>1</td>\n", | |
| " <td>3</td>\n", | |
| " <td>Heikkinen, Miss. Laina</td>\n", | |
| " <td>female</td>\n", | |
| " <td>26.0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>STON/O2. 3101282</td>\n", | |
| " <td>7.9250</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>S</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>4</td>\n", | |
| " <td>1</td>\n", | |
| " <td>1</td>\n", | |
| " <td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td>\n", | |
| " <td>female</td>\n", | |
| " <td>35.0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>113803</td>\n", | |
| " <td>53.1000</td>\n", | |
| " <td>C123</td>\n", | |
| " <td>S</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>5</td>\n", | |
| " <td>0</td>\n", | |
| " <td>3</td>\n", | |
| " <td>Allen, Mr. William Henry</td>\n", | |
| " <td>male</td>\n", | |
| " <td>35.0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>373450</td>\n", | |
| " <td>8.0500</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>S</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " PassengerId Survived Pclass \\\n", | |
| "0 1 0 3 \n", | |
| "1 2 1 1 \n", | |
| "2 3 1 3 \n", | |
| "3 4 1 1 \n", | |
| "4 5 0 3 \n", | |
| "\n", | |
| " Name Sex Age SibSp \\\n", | |
| "0 Braund, Mr. Owen Harris male 22.0 1 \n", | |
| "1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n", | |
| "2 Heikkinen, Miss. Laina female 26.0 0 \n", | |
| "3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n", | |
| "4 Allen, Mr. William Henry male 35.0 0 \n", | |
| "\n", | |
| " Parch Ticket Fare Cabin Embarked \n", | |
| "0 0 A/5 21171 7.2500 NaN S \n", | |
| "1 0 PC 17599 71.2833 C85 C \n", | |
| "2 0 STON/O2. 3101282 7.9250 NaN S \n", | |
| "3 0 113803 53.1000 C123 S \n", | |
| "4 0 373450 8.0500 NaN S " | |
| ] | |
| }, | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# First rows of the raw training set.\n", | |
| "train.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>PassengerId</th>\n", | |
| " <th>Pclass</th>\n", | |
| " <th>Name</th>\n", | |
| " <th>Sex</th>\n", | |
| " <th>Age</th>\n", | |
| " <th>SibSp</th>\n", | |
| " <th>Parch</th>\n", | |
| " <th>Ticket</th>\n", | |
| " <th>Fare</th>\n", | |
| " <th>Cabin</th>\n", | |
| " <th>Embarked</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>892</td>\n", | |
| " <td>3</td>\n", | |
| " <td>Kelly, Mr. James</td>\n", | |
| " <td>male</td>\n", | |
| " <td>34.5</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>330911</td>\n", | |
| " <td>7.8292</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>Q</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>893</td>\n", | |
| " <td>3</td>\n", | |
| " <td>Wilkes, Mrs. James (Ellen Needs)</td>\n", | |
| " <td>female</td>\n", | |
| " <td>47.0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>363272</td>\n", | |
| " <td>7.0000</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>S</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>894</td>\n", | |
| " <td>2</td>\n", | |
| " <td>Myles, Mr. Thomas Francis</td>\n", | |
| " <td>male</td>\n", | |
| " <td>62.0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>240276</td>\n", | |
| " <td>9.6875</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>Q</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>895</td>\n", | |
| " <td>3</td>\n", | |
| " <td>Wirz, Mr. Albert</td>\n", | |
| " <td>male</td>\n", | |
| " <td>27.0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>315154</td>\n", | |
| " <td>8.6625</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>S</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>896</td>\n", | |
| " <td>3</td>\n", | |
| " <td>Hirvonen, Mrs. Alexander (Helga E Lindqvist)</td>\n", | |
| " <td>female</td>\n", | |
| " <td>22.0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>1</td>\n", | |
| " <td>3101298</td>\n", | |
| " <td>12.2875</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>S</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " PassengerId Pclass Name Sex \\\n", | |
| "0 892 3 Kelly, Mr. James male \n", | |
| "1 893 3 Wilkes, Mrs. James (Ellen Needs) female \n", | |
| "2 894 2 Myles, Mr. Thomas Francis male \n", | |
| "3 895 3 Wirz, Mr. Albert male \n", | |
| "4 896 3 Hirvonen, Mrs. Alexander (Helga E Lindqvist) female \n", | |
| "\n", | |
| " Age SibSp Parch Ticket Fare Cabin Embarked \n", | |
| "0 34.5 0 0 330911 7.8292 NaN Q \n", | |
| "1 47.0 1 0 363272 7.0000 NaN S \n", | |
| "2 62.0 0 0 240276 9.6875 NaN Q \n", | |
| "3 27.0 0 0 315154 8.6625 NaN S \n", | |
| "4 22.0 1 1 3101298 12.2875 NaN S " | |
| ] | |
| }, | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# First rows of the raw test set (it has no Survived column).\n", | |
| "test.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Prepare data" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "index = 'PassengerId'\n", | |
| "dep = 'Survived'\n", | |
| "cat_vars = ['Pclass', 'Sex', 'SibSp', 'Parch', 'Embarked']\n", | |
| "contin_vars = ['Age']\n", | |
| "drop_vars = ['Name', 'Ticket', 'Cabin', 'Fare']\n", | |
| "\n", | |
| "test.set_index(index)\n", | |
| "train.set_index(index)\n", | |
| "\n", | |
| "for v in cat_vars:\n", | |
| " test[v] = test[v].astype('category').cat.as_ordered()\n", | |
| " train[v] = train[v].astype('category').cat.as_ordered()\n", | |
| "\n", | |
| "for v in contin_vars:\n", | |
| " test[v] = test[v].astype('float32')\n", | |
| " train[v] = train[v].astype('float32')\n", | |
| " \n", | |
| "for v in drop_vars:\n", | |
| " if v in test:\n", | |
| " test.drop(v, axis=1, inplace=True)\n", | |
| " train.drop(v, axis=1, inplace=True)\n", | |
| "\n", | |
| "test[dep] = np.nan\n", | |
| " \n", | |
| "apply_cats(test, train)\n", | |
| "\n", | |
| "df, y, nas, mapper = proc_df(train, dep, do_scale=True, skip_flds=[index])\n", | |
| "df_test, _, nas, mapper = proc_df(test, dep, do_scale=True, skip_flds=[index], mapper=mapper, na_dict=nas)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>Pclass</th>\n", | |
| " <th>Sex</th>\n", | |
| " <th>Age</th>\n", | |
| " <th>SibSp</th>\n", | |
| " <th>Parch</th>\n", | |
| " <th>Embarked</th>\n", | |
| " <th>Age_na</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>3</td>\n", | |
| " <td>2</td>\n", | |
| " <td>-0.565736</td>\n", | |
| " <td>2</td>\n", | |
| " <td>1</td>\n", | |
| " <td>3</td>\n", | |
| " <td>-0.497895</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>1</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0.663861</td>\n", | |
| " <td>2</td>\n", | |
| " <td>1</td>\n", | |
| " <td>1</td>\n", | |
| " <td>-0.497895</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>3</td>\n", | |
| " <td>1</td>\n", | |
| " <td>-0.258337</td>\n", | |
| " <td>1</td>\n", | |
| " <td>1</td>\n", | |
| " <td>3</td>\n", | |
| " <td>-0.497895</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>1</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0.433312</td>\n", | |
| " <td>2</td>\n", | |
| " <td>1</td>\n", | |
| " <td>3</td>\n", | |
| " <td>-0.497895</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>3</td>\n", | |
| " <td>2</td>\n", | |
| " <td>0.433312</td>\n", | |
| " <td>1</td>\n", | |
| " <td>1</td>\n", | |
| " <td>3</td>\n", | |
| " <td>-0.497895</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " Pclass Sex Age SibSp Parch Embarked Age_na\n", | |
| "0 3 2 -0.565736 2 1 3 -0.497895\n", | |
| "1 1 1 0.663861 2 1 1 -0.497895\n", | |
| "2 3 1 -0.258337 1 1 3 -0.497895\n", | |
| "3 1 1 0.433312 2 1 3 -0.497895\n", | |
| "4 3 2 0.433312 1 1 3 -0.497895" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Processed training features: categorical codes, scaled Age and the Age_na flag.\n", | |
| "df.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>Pclass</th>\n", | |
| " <th>Sex</th>\n", | |
| " <th>Age</th>\n", | |
| " <th>SibSp</th>\n", | |
| " <th>Parch</th>\n", | |
| " <th>Embarked</th>\n", | |
| " <th>Age_na</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>3</td>\n", | |
| " <td>2</td>\n", | |
| " <td>0.394887</td>\n", | |
| " <td>1</td>\n", | |
| " <td>1</td>\n", | |
| " <td>2</td>\n", | |
| " <td>-0.497895</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>3</td>\n", | |
| " <td>1</td>\n", | |
| " <td>1.355510</td>\n", | |
| " <td>2</td>\n", | |
| " <td>1</td>\n", | |
| " <td>3</td>\n", | |
| " <td>-0.497895</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>2</td>\n", | |
| " <td>2</td>\n", | |
| " <td>2.508257</td>\n", | |
| " <td>1</td>\n", | |
| " <td>1</td>\n", | |
| " <td>2</td>\n", | |
| " <td>-0.497895</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>3</td>\n", | |
| " <td>2</td>\n", | |
| " <td>-0.181487</td>\n", | |
| " <td>1</td>\n", | |
| " <td>1</td>\n", | |
| " <td>3</td>\n", | |
| " <td>-0.497895</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>3</td>\n", | |
| " <td>1</td>\n", | |
| " <td>-0.565736</td>\n", | |
| " <td>2</td>\n", | |
| " <td>2</td>\n", | |
| " <td>3</td>\n", | |
| " <td>-0.497895</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " Pclass Sex Age SibSp Parch Embarked Age_na\n", | |
| "0 3 2 0.394887 1 1 2 -0.497895\n", | |
| "1 3 1 1.355510 2 1 3 -0.497895\n", | |
| "2 2 2 2.508257 1 1 2 -0.497895\n", | |
| "3 3 2 -0.181487 1 1 3 -0.497895\n", | |
| "4 3 1 -0.565736 2 2 3 -0.497895" | |
| ] | |
| }, | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Processed test features, encoded with the training mapper and na_dict.\n", | |
| "df_test.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Create model/learner" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "cat_sz = [(c, len(train[c].cat.categories)+1) for c in cat_vars]\n", | |
| "emb_szs = [(c, min(50, (c+1)//2)) for _,c in cat_sz]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Indices of the held-out validation rows. Presumably a random sample with a\n", | |
| "# fixed default seed; confirm against fastai's get_cv_idxs.\n", | |
| "n = len(train)\n", | |
| "val_idxs = get_cv_idxs(n)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Wrap the processed frames for the tabular model; y (0/1 Survived) is cast to float32.\n", | |
| "# NOTE(review): with default settings this presumably trains as regression on a\n", | |
| "# binary target; consider a classification setup for Survived - confirm.\n", | |
| "md = ColumnarModelData.from_data_frame(PATH, val_idxs, df, y.astype(np.float32),\n", | |
| " cat_flds=cat_vars, bs=128, test_df=df_test)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Positional args presumably map to (emb_szs, n_cont, emb_drop, out_sz, hidden layer\n", | |
| "# sizes, hidden layer dropouts); confirm against ColumnarModelData.get_learner.\n", | |
| "# n_cont = number of non-categorical feature columns (here Age and Age_na).\n", | |
| "# y_range=[0, 1] presumably bounds the output to the range of the 0/1 target.\n", | |
| "m = md.get_learner(emb_szs, len(df.columns)-len(cat_vars),\n", | |
| " 0.04, 1, [1000,500], [0.001,0.01], y_range=[0, 1])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Train" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "8b2f29c343ee4ceb9de09bac7e4074dd", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/html": [ | |
| "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n", | |
| "<p>\n", | |
| " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", | |
| " that the widgets JavaScript is still loading. If this message persists, it\n", | |
| " likely means that the widgets JavaScript library is either not installed or\n", | |
| " not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n", | |
| " Widgets Documentation</a> for setup instructions.\n", | |
| "</p>\n", | |
| "<p>\n", | |
| " If you're reading this message in another frontend (for example, a static\n", | |
| " rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n", | |
| " it may mean that your frontend doesn't currently support widgets.\n", | |
| "</p>\n" | |
| ], | |
| "text/plain": [ | |
| "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "epoch trn_loss val_loss \n", | |
| " 0 0.26849 0.589888 \n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZQAAAEOCAYAAACuOOGFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvFvnyVgAAEz5JREFUeJzt3X2wXHV9x/H3B1KxVoyAASEhhhasE8enumCp2qFVIjrVqFCJbW1U2vSJanVqi7UdFZ2KT7U+1whWaqug0GIEa3hQqm0VcoNSCEpJox1SUKJBCqIw6Ld/7EGW2725m9zfvZubvF8zO3vO7/zOOd+998753N85u2dTVUiSNFP7jLsASdKewUCRJDVhoEiSmjBQJElNGCiSpCYMFElSEwaKJKkJA0WS1ISBIklqwkCRJDWxYNwFzKWHPexhtWzZsnGXIUnzysaNG79dVYum67dXBcqyZcuYmJgYdxmSNK8k+e9R+nnKS5LUhIEiSWrCQJEkNWGgSJKaMFAkSU0YKJKkJgwUSVITBookqQkDRZLUhIEiSWrCQJEkNWGgSJKaMFAkSU0YKJKkJgwUSVITBookqQkDRZLUhIEiSWrCQJEkNWGgSJKaMFAkSU0YKJKkJgwUSVITYw2UJCckuT7J5iSnDVm+X5Jzu+VXJFk2afnSJHck+eO5qlmSNNzYAiXJvsB7gWcCy4EXJlk+qdspwK1VdSTwDuDNk5a/A/jn2a5VkjS9cY5QjgE2V9WWqrobOAdYOanPSuDsbvo84GlJApDkucAWYNMc1StJ2oFxBspi4MaB+a1d29A+VXUPcBtwUJKfAv4UeP0c1ClJGsE4AyVD2mrEPq8H3lFVd0y7k2RNkokkE9u2bduFMiVJo1gwxn1vBQ4fmF8C3DRFn61JFgALge3Ak4CTkrwFeCjwoyQ/qKr3TN5JVa0F1gL0er3JgSVJamScgbIBOCrJEcD/AKuAX5vUZx2wGvgicBLw2aoq4Kn3dkjyOuCOYWEiSZo7YwuUqronyanAemBf4ENVtSnJ6cBEVa0DzgI+kmQz/ZHJqnHVK0nasfT/4d879Hq9mpiYGHcZkjSvJNlYVb3p+vlJeUlSEwaKJKkJA0WS1ISBIklqwkCRJDVhoEiSmjBQJElNGCiSpCYMFElSEwaKJKkJA0WS1ISBIklqwkCRJDVhoEiSmjBQJElNGCiSpCYMFElSEwaKJKkJA0WS1ISBIklqwkCRJDVhoEiSmjBQJElNGCiSpCYMFElSEwaKJKkJA0WS1ISBIklqwkCRJDVhoEiSmjBQJElNjDVQkpyQ5Pokm5OcNmT5fknO7ZZfkWRZ1358ko1Jrumef3mua5ck3d/YAiXJvsB7gWcCy4EXJlk+qdspwK1VdSTwDuDNXfu3gWdX1WOA1cBH5qZqSdJUxjlCOQbYXFVbqupu4Bxg5aQ+K4Gzu+nzgKclSVV9uapu6to3AQ9Mst+cVC1JGmqcgbIYuHFgfmvXNrRPVd0D3AYcNKnPicCXq+quWapTkjSCBWPcd4a01c70SfJo+qfBVky5k2QNsAZg6dKlO1+lJGkk4xyhbAUOH5hfAtw0VZ8kC4CFwPZufgnwT8BvVtV/TbWTqlpbVb2q6i1atKhh+ZKkQeMMlA3AUUmOSPIAYBWwblKfdfQvugOcBHy2qirJQ4GLgFdX1b/NWcWSpCmNLVC6ayKnAuuBrwIfr6pNSU5P8pyu21nAQUk2A68E7n1r8anAkcBfJPlK9zh4jl+CJGlAqiZftthz9Xq9mpiYGHcZkjSvJNlYVb3p+vlJeUlSEwaKJKkJA0WS1ISBIklqwkCRJDVhoEiSmjBQJElNGCiSpCYMFElSEwaKJKkJA0WS1ISBIklqwkCRJDVhoEiSmjBQJElNGCiSpCYMFElSEwaKJKkJA0WS1ISBIklqwkCRJDVhoEiSmjBQJElNGCiSpCYMFElSEwaKJKkJA0WS
1MRIgZLk5Ukekr6zklyVZMVsFydJmj9GHaG8tKr+F1gBLAJeApwxa1VJkuadUQMl3fOzgL+tqqsH2iRJGjlQNia5mH6grE+yP/Cjme48yQlJrk+yOclpQ5bvl+TcbvkVSZYNLHt11359kmfMtBZJ0swsGLHfKcDjgS1VdWeSA+mf9tplSfYF3gscD2wFNiRZV1XXTdrvrVV1ZJJVwJuBk5MsB1YBjwYOAy5N8siq+uFMapIk7bpRRyjHAtdX1XeT/Abw58BtM9z3McDmqtpSVXcD5wArJ/VZCZzdTZ8HPC1JuvZzququqvo6sLnbniRpTEYNlPcDdyZ5HPAnwH8DfzfDfS8GbhyY39q1De1TVffQD7GDRlxXkjSHRg2Ue6qq6I8M3llV7wT2n+G+h13UrxH7jLJufwPJmiQTSSa2bdu2kyVKkkY1aqDcnuTVwIuAi7rrHz8xw31vBQ4fmF8C3DRVnyQLgIXA9hHXBaCq1lZVr6p6ixYtmmHJkqSpjBooJwN30f88yjfpn1566wz3vQE4KskRSR5A/yL7ukl91gGru+mTgM92I6V1wKruXWBHAEcBV86wHknSDIz0Lq+q+maSfwCOTvIrwJVVNaNrKFV1T5JTgfXAvsCHqmpTktOBiapaB5wFfCTJZvojk1XdupuSfBy4DrgH+APf4SVJ45X+P/zTdEpeQH9Ecjn96xdPBV5VVefNanWN9Xq9mpiYGHcZkjSvJNlYVb3p+o36OZTXAEdX1S3dxhcBl9J/K68kSSNfQ9nn3jDpfGcn1pUk7QVGHaF8Jsl64GPd/MnAp2enJEnSfDTqRflXJTkReDL9ayhrq+qfZrUySdK8MuoIhao6Hzh/FmuRJM1jOwyUJLcz/BPoAaqqHjIrVUmS5p0dBkpVzfT2KpKkvYTv1JIkNWGgSJKaMFAkSU0YKJKkJgwUSVITBookqQkDRZLUhIEiSWrCQJEkNWGgSJKaMFAkSU0YKJKkJgwUSVITBookqQkDRZLUhIEiSWrCQJEkNWGgSJKaMFAkSU0YKJKkJgwUSVITBookqQkDRZLUhIEiSWpiLIGS5MAklyS5oXs+YIp+q7s+NyRZ3bU9KMlFSb6WZFOSM+a2eknSMOMaoZwGXFZVRwGXdfP3k+RA4LXAk4BjgNcOBM/bqupRwBOAJyd55tyULUmayrgCZSVwdjd9NvDcIX2eAVxSVdur6lbgEuCEqrqzqj4HUFV3A1cBS+agZknSDowrUA6pqpsBuueDh/RZDNw4ML+1a/uxJA8Fnk1/lCNJGqMFs7XhJJcCDx+y6DWjbmJIWw1sfwHwMeBdVbVlB3WsAdYALF26dMRdS5J21qwFSlU9faplSb6V5NCqujnJocAtQ7ptBY4bmF8CXD4wvxa4oar+epo61nZ96fV6taO+kqRdN65TXuuA1d30auCTQ/qsB1YkOaC7GL+iayPJG4GFwB/NQa2SpBGMK1DOAI5PcgNwfDdPkl6SMwGqajvwBmBD9zi9qrYnWUL/tNly4KokX0nyW+N4EZKk+6Rq7zkL1Ov1amJiYtxlSNK8kmRjVfWm6+cn5SVJTRgokqQmDBRJUhMGiiSpCQNFktSEgSJJasJAkSQ1YaBIkpowUCRJTRgokqQmDBRJUhMGiiSpCQNFktSEgSJJasJAkSQ1YaBIkpowUCRJTRgokqQmDBRJUhMGiiSpCQNFktSEgSJJasJAkSQ1YaBIkpowUCRJTRgokqQmDBRJUhMGiiSpCQNFktSEgSJJamIsgZLkwCSXJLmhez5gin6ruz43JFk9ZPm6JNfOfsWSpOmMa4RyGnBZVR0FXNbN30+SA4HXAk8CjgFeOxg8SZ4P3DE35UqSpjOuQFkJnN1Nnw08d0ifZwCXVNX2qroVuAQ4ASDJg4FXAm+cg1olSSMYV6AcUlU3A3TPBw/psxi4cWB+a9cG8Abg7cCds1mkJGl0C2Zrw0kuBR4+ZNFrRt3EkLZK8njgyKp6RZJlI9SxBlgD
sHTp0hF3LUnaWbMWKFX19KmWJflWkkOr6uYkhwK3DOm2FThuYH4JcDlwLPDEJN+gX//BSS6vquMYoqrWAmsBer1e7fwrkSSNYlynvNYB975razXwySF91gMrkhzQXYxfAayvqvdX1WFVtQx4CvCfU4WJJGnujCtQzgCOT3IDcHw3T5JekjMBqmo7/WslG7rH6V2bJGk3lKq95yxQr9eriYmJcZchSfNKko1V1Zuun5+UlyQ1YaBIkpowUCRJTRgokqQmDBRJUhMGiiSpCQNFktSEgSJJasJAkSQ1YaBIkpowUCRJTRgokqQmDBRJUhMGiiSpCQNFktSEgSJJasJAkSQ1YaBIkpowUCRJTRgokqQmDBRJUhMGiiSpCQNFktSEgSJJasJAkSQ1kaoadw1zJsk24LvAbbuw+sOAb7etSDuwkF37Pe3OdtfXNK66Znu/rbffansz2c6urjvT49cjqmrRdJ32qkABSLK2qtbswnoTVdWbjZr0/+3q72l3tru+pnHVNdv7bb39VtubyXZ29+PX3njK61PjLkAj2RN/T7vraxpXXbO939bbb7W9mWxnd/0bAvbCEcqucoQiab5yhLL7WTvuAiRpF83J8csRiiSpCUcokqQmDBRJUhMGiiSpCQNlFyX5qSRnJ/lgkl8fdz2SNKokP53krCTntdyugTIgyYeS3JLk2kntJyS5PsnmJKd1zc8Hzquq3waeM+fFStKAnTl+VdWWqjqldQ0Gyv19GDhhsCHJvsB7gWcCy4EXJlkOLAFu7Lr9cA5rlKRhPszox69ZYaAMqKrPA9snNR8DbO4S/W7gHGAlsJV+qIA/R0ljtpPHr1nhgXB6i7lvJAL9IFkM/CNwYpL3s5vfDkHSXmvo8SvJQUn+BnhCkle32tmCVhvag2VIW1XV94CXzHUxkrQTpjp+fQf43dY7c4Qyva3A4QPzS4CbxlSLJO2MOT1+GSjT2wAcleSIJA8AVgHrxlyTJI1iTo9fBsqAJB8Dvgj8bJKtSU6pqnuAU4H1wFeBj1fVpnHWKUmT7Q7HL28OKUlqwhGKJKkJA0WS1ISBIklqwkCRJDVhoEiSmjBQJElNGCjabSW5Yw728ZyBrySYE0mOS/ILu7DeE5Kc2U2/OMl72le385Ism3zL9CF9FiX5zFzVpPEwULTH627hPVRVrauqM2Zhnzu6T95xwE4HCvBnwLt3qaAxq6ptwM1JnjzuWjR7DBTNC0lelWRDkv9I8vqB9guSbEyyKcmagfY7kpye5Arg2CTfSPL6JFcluSbJo7p+P/5PP8mHk7wryb8n2ZLkpK59nyTv6/ZxYZJP37tsUo2XJ/nLJP8CvDzJs5NckeTLSS5NckiSZfRvyveKJF9J8tTuv/fzu9e3YdhBN8n+wGOr6uohyx6R5LLuZ3NZkqVd+88k+VK3zdOHjfi6bx69KMnVSa5NcnLXfnT3c7g6yZVJ9u9GIl/ofoZXDRtlJdk3yVsHfle/M7D4AsBvN92TVZUPH7vlA7ije14BrKV/59R9gAuBX+yWHdg9/yRwLXBQN1/ACwa29Q3gD7vp3wfO7KZfDLynm/4w8IluH8vpf48EwEnAp7v2hwO3AicNqfdy4H0D8wdw390ofgt4ezf9OuCPB/p9FHhKN70U+OqQbf8ScP7A/GDdnwJWd9MvBS7opi8EXthN/+69P89J2z0R+ODA/ELgAcAW4Oiu7SH070z+IOCBXdtRwEQ3vQy4tpteA/x5N70fMAEc0c0vBq4Z99+Vj9l7ePt6zQcruseXu/kH0z+gfR54WZLnde2Hd+3fof8tmudP2s4/ds8b6X+F8zAXVNWPgOuSHNK1PQX4RNf+zSSf20Gt5w5MLwHOTXIo/YP016dY5+nA8uTHdxp/SJL9q+r2gT6HAtumWP/YgdfzEeAtA+3P7aY/CrxtyLrXAG9L8mbgwqr6QpLHADdX1QaAqvpf6I9mgPckeTz9n+8jh2xvBfDYgRHcQvq/k68DtwCHTfEa
tAcwUDQfBHhTVX3gfo3JcfQPxsdW1Z1JLgce2C3+QVVN/mrmu7rnHzL13/5dA9OZ9DyK7w1Mvxv4q6pa19X6uinW2Yf+a/j+Drb7fe57bdMZ+QZ9VfWfSZ4IPAt4U5KL6Z+aGraNVwDfAh7X1fyDIX1CfyS4fsiyB9J/HdpDeQ1F88F64KVJHgyQZHGSg+n/93trFyaPAn5+lvb/r/S/nXOfbtRy3IjrLQT+p5tePdB+O7D/wPzF9O8IC0A3Apjsq8CRU+zn3+nflhz61yj+tZv+Ev1TWgwsv58khwF3VtXf0x/B/BzwNeCwJEd3ffbv3mSwkP7I5UfAi4Bhb3ZYD/xekp/o1n1kN7KB/ohmh+8G0/xmoGi3V1UX0z9l88Uk1wDn0T8gfwZYkOQ/gDfQP4DOhvPpf1HRtcAHgCuA20ZY73XAJ5J8Afj2QPungOfde1EeeBnQ6y5iX8eQb9Krqq8BC7uL85O9DHhJ93N4EfDyrv2PgFcmuZL+KbNhNT8GuDLJV4DXAG+s/nePnwy8O8nVwCX0RxfvA1Yn+RL9cPjekO2dCVwHXNW9lfgD3Dca/CXgoiHraA/h7eulESR5cFXdkeQg4ErgyVX1zTmu4RXA7VV15oj9HwR8v6oqySr6F+hXzmqRO67n88DKqrp1XDVodnkNRRrNhUkeSv/i+hvmOkw67wd+dSf6P5H+RfQA36X/DrCxSLKI/vUkw2QP5ghFktSE11AkSU0YKJKkJgwUSVITBookqQkDRZLUhIEiSWri/wAuNq8iO7eWkwAAAABJRU5ErkJggg==\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# Run the learning-rate finder, then plot loss against learning rate.\n", | |
| "# NOTE(review): plot(100) presumably skips the first 100 iterations. With roughly\n", | |
| "# (891 - validation rows) / 128, i.e. about 6, batches per epoch, that may leave\n", | |
| "# nothing to plot - verify.\n", | |
| "m.lr_find()\n", | |
| "m.sched.plot(100)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Learning rate used by m.fit below.\n", | |
| "lr = 1e-3" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "04b6976ddcbd4800a815ff98ee348383", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/html": [ | |
| "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n", | |
| "<p>\n", | |
| " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", | |
| " that the widgets JavaScript is still loading. If this message persists, it\n", | |
| " likely means that the widgets JavaScript library is either not installed or\n", | |
| " not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n", | |
| " Widgets Documentation</a> for setup instructions.\n", | |
| "</p>\n", | |
| "<p>\n", | |
| " If you're reading this message in another frontend (for example, a static\n", | |
| " rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n", | |
| " it may mean that your frontend doesn't currently support widgets.\n", | |
| "</p>\n" | |
| ], | |
| "text/plain": [ | |
| "HBox(children=(IntProgress(value=0, description='Epoch', max=28), HTML(value='')))" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "epoch trn_loss val_loss \n", | |
| " 0 0.29958 0.289071 \n", | |
| " 1 0.266482 0.225575 \n", | |
| " 2 0.243882 0.230261 \n", | |
| " 3 0.230642 0.224803 \n", | |
| " 4 0.220913 0.205747 \n", | |
| " 5 0.210903 0.209014 \n", | |
| " 6 0.202561 0.203815 \n", | |
| " 7 0.194041 0.199481 \n", | |
| " 8 0.18833 0.194539 \n", | |
| " 9 0.182456 0.188606 \n", | |
| " 10 0.177253 0.184244 \n", | |
| " 11 0.173428 0.183238 \n", | |
| " 12 0.169191 0.14376 \n", | |
| " 13 0.164396 0.151117 \n", | |
| " 14 0.159777 0.136738 \n", | |
| " 15 0.155043 0.135023 \n", | |
| " 16 0.151094 0.135569 \n", | |
| " 17 0.147484 0.138293 \n", | |
| " 18 0.144015 0.136753 \n", | |
| " 19 0.141021 0.138247 \n", | |
| " 20 0.137857 0.136863 \n", | |
| " 21 0.135251 0.137623 \n", | |
| " 22 0.133364 0.137839 \n", | |
| " 23 0.13165 0.137135 \n", | |
| " 24 0.129502 0.137729 \n", | |
| " 25 0.127662 0.137821 \n", | |
| " 26 0.126263 0.137893 \n", | |
| " 27 0.124427 0.137869 \n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[0.13786882]" | |
| ] | |
| }, | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ |
| "# 3 SGDR cycles of 4, 8 and 16 epochs (cycle_mult=2) -> 28 epochs in total.\n", |
| "m.fit(lr, n_cycle=3, cycle_len=4, cycle_mult=2)" |
| ] |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ |
| "# Checkpoint the trained weights under the name 'val0'.\n", |
| "m.save('val0')" |
| ] |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ |
| "# Reload the 'val0' checkpoint. Redundant right after saving, but lets a later\n", |
| "# session skip the fit cell above.\n", |
| "m.load('val0')" |
| ] |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "**Prepare submission**" |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ |
| "# Predictions and targets for the held-out set (presumably the validation split).\n", |
| "x, y = m.predict_with_targs()" |
| ] |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "178" | |
| ] | |
| }, | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ |
| "# Number of held-out targets.\n", |
| "len(y)" |
| ] |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ |
| "# Raw model outputs for the Kaggle test set.\n", |
| "pred_test = m.predict(is_test=True)" |
| ] |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 24, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ |
| "# Attach the raw (continuous) predictions to the test frame as the target column.\n", |
| "test[dep] = pred_test" |
| ] |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 25, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ |
| "# Output path for the submission file.\n", |
| "csv_fn = f'{PATH}/tmp/sub.csv'" |
| ] |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 26, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>PassengerId</th>\n", | |
| " <th>Pclass</th>\n", | |
| " <th>Sex</th>\n", | |
| " <th>Age</th>\n", | |
| " <th>SibSp</th>\n", | |
| " <th>Parch</th>\n", | |
| " <th>Embarked</th>\n", | |
| " <th>Survived</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>892</td>\n", | |
| " <td>3</td>\n", | |
| " <td>male</td>\n", | |
| " <td>34.5</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>Q</td>\n", | |
| " <td>0.124961</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>893</td>\n", | |
| " <td>3</td>\n", | |
| " <td>female</td>\n", | |
| " <td>47.0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>S</td>\n", | |
| " <td>0.326619</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>894</td>\n", | |
| " <td>2</td>\n", | |
| " <td>male</td>\n", | |
| " <td>62.0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>Q</td>\n", | |
| " <td>0.191005</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>895</td>\n", | |
| " <td>3</td>\n", | |
| " <td>male</td>\n", | |
| " <td>27.0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>S</td>\n", | |
| " <td>0.127235</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>896</td>\n", | |
| " <td>3</td>\n", | |
| " <td>female</td>\n", | |
| " <td>22.0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>1</td>\n", | |
| " <td>S</td>\n", | |
| " <td>0.355572</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " PassengerId Pclass Sex Age SibSp Parch Embarked Survived\n", | |
| "0 892 3 male 34.5 0 0 Q 0.124961\n", | |
| "1 893 3 female 47.0 1 0 S 0.326619\n", | |
| "2 894 2 male 62.0 0 0 Q 0.191005\n", | |
| "3 895 3 male 27.0 0 0 S 0.127235\n", | |
| "4 896 3 female 22.0 1 1 S 0.355572" | |
| ] | |
| }, | |
| "execution_count": 26, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ |
| "# Inspect the first rows with the prediction column appended.\n", |
| "test.head()" |
| ] |
| }, | |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Kaggle scores Titanic on accuracy and expects Survived to be a 0/1 label, so\n", |
| "# threshold the model's continuous outputs instead of submitting them raw.\n", |
| "# NOTE(review): assumes outputs are on a 0-1 scale with 0.5 as the decision boundary.\n", |
| "# .copy() makes sub independent of test, avoiding SettingWithCopyWarning below.\n", |
| "sub = test[[index, dep]].copy()\n", |
| "sub[dep] = (sub[dep] > 0.5).astype(int)\n", |
| "sub.head()" |
| ] |
| }, |
| { | |
| "cell_type": "code", | |
| "execution_count": 28, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<a href='/home/tonygentilcore/.kaggle/competitions/titanic/tmp/sub.csv' target='_blank'>/home/tonygentilcore/.kaggle/competitions/titanic/tmp/sub.csv</a><br>" | |
| ], | |
| "text/plain": [ | |
| "/home/tonygentilcore/.kaggle/competitions/titanic/tmp/sub.csv" | |
| ] | |
| }, | |
| "execution_count": 28, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ |
| "# index=False keeps only the PassengerId/Survived columns in the written file.\n", |
| "sub.to_csv(path_or_buf=csv_fn, index=False)\n", |
| "FileLink(csv_fn)" |
| ] |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.4" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment