Last active
October 1, 2020 03:23
Athena SQL Model for NYC Taxi Data Set (us-east-1)
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Athena SQL Model\n", | |
| "\n", | |
| "This example creates an Athena table for the [Jan 2017 taxi dataset](https://aws.amazon.com/blogs/big-data/build-a-data-lake-foundation-with-aws-glue-and-amazon-s3/). You can improve query performance by converting the data to Parquet format.\n", | |
| "\n", | |
| "Configure your notebook role with permissions to [query data from Athena](https://aws.amazon.com/blogs/machine-learning/run-sql-queries-from-your-sagemaker-notebooks-using-amazon-athena/) and to access the S3 staging bucket." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Install libraries\n", | |
| "\n", | |
| "Install the [Athena library](https://pypi.org/project/PyAthena/) for Python and [tqdm](https://tqdm.github.io/)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "\u001b[33mWARNING: You are using pip version 19.3.1; however, version 20.0.2 is available.\n", | |
| "You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n", | |
| "\u001b[33mWARNING: You are using pip version 19.3.1; however, version 20.0.2 is available.\n", | |
| "You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import sys\n", | |
| "!{sys.executable} -m pip install -q PyAthena\n", | |
| "!{sys.executable} -m pip install -q tqdm" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Import Data\n", | |
| "\n", | |
| "Create an Athena database and external table for the imported NYC taxi dataset." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "s3 staging dir: s3://sagemaker-us-east-1-691313291965/athena\n", | |
| "athena table: nyc_taxi.taxi_csv\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import boto3\n", | |
| "import sagemaker\n", | |
| "\n", | |
| "# Initialize the boto session in us-east-1 region\n", | |
| "boto_session = boto3.session.Session(region_name='us-east-1')\n", | |
| "region = boto_session.region_name\n", | |
| "bucket = sagemaker.session.Session(boto_session).default_bucket()\n", | |
| "\n", | |
| "# Get the Athena staging dir and table\n", | |
| "s3_staging_dir = 's3://{}/athena'.format(bucket)\n", | |
| "db_name = 'nyc_taxi'\n", | |
| "table_name = '{}.taxi_csv'.format(db_name)\n", | |
| "\n", | |
| "print('s3 staging dir: {}'.format(s3_staging_dir))\n", | |
| "print('athena table: {}'.format(table_name))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Make the bucket if it doesn't exist" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "make_bucket: sagemaker-us-east-1-691313291965\r\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "!aws s3 mb s3://$bucket --region $region" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Query the NYC taxi dataset using [PandasCursor](https://pypi.org/project/PyAthena/#pandascursor) for improved performance" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from pyathena import connect\n", | |
| "from pyathena.pandas_cursor import PandasCursor\n", | |
| "import pandas as pd\n", | |
| "\n", | |
| "cursor = connect(s3_staging_dir=s3_staging_dir,\n", | |
| " region_name=region,\n", | |
| " cursor_class=PandasCursor).cursor()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Status: SUCCEEDED, Run time: 0.27s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "sql_create_database = 'CREATE DATABASE IF NOT EXISTS {};'.format(db_name)\n", | |
| "\n", | |
| "cursor.execute(sql_create_database)\n", | |
| "print('Status: {}, Run time: {:.2f}s'.format(cursor.state, \n", | |
| " cursor.execution_time_in_millis/1000.0))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Status: SUCCEEDED, Run time: 0.55s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "sql_create_table = '''\n", | |
| "CREATE EXTERNAL TABLE IF NOT EXISTS `{}` (\n", | |
| " `vendorid` bigint, \n", | |
| " `lpep_pickup_datetime` string, \n", | |
| " `lpep_dropoff_datetime` string, \n", | |
| " `store_and_fwd_flag` string, \n", | |
| " `ratecodeid` bigint, \n", | |
| " `pulocationid` bigint, \n", | |
| " `dolocationid` bigint, \n", | |
| " `passenger_count` bigint, \n", | |
| " `trip_distance` double, \n", | |
| " `fare_amount` double, \n", | |
| " `extra` double, \n", | |
| " `mta_tax` double, \n", | |
| " `tip_amount` double, \n", | |
| " `tolls_amount` double, \n", | |
| " `ehail_fee` string, \n", | |
| " `improvement_surcharge` double, \n", | |
| " `total_amount` double, \n", | |
| " `payment_type` bigint, \n", | |
| " `trip_type` bigint)\n", | |
| "ROW FORMAT DELIMITED \n", | |
| " FIELDS TERMINATED BY ',' \n", | |
| "STORED AS INPUTFORMAT \n", | |
| " 'org.apache.hadoop.mapred.TextInputFormat' \n", | |
| "OUTPUTFORMAT \n", | |
| " 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'\n", | |
| "LOCATION\n", | |
| " 's3://aws-bigdata-blog/artifacts/glue-data-lake/data/'\n", | |
| "TBLPROPERTIES (\n", | |
| " 'columnsOrdered'='true', \n", | |
| " 'compressionType'='none', \n", | |
| " 'skip.header.line.count'='1')\n", | |
| "'''.format(table_name)\n", | |
| "\n", | |
| "cursor.execute(sql_create_table)\n", | |
| "print('Status: {}, Run time: {:.2f}s'.format(cursor.state, \n", | |
| " cursor.execution_time_in_millis/1000.0))" | |
| ] | |
| }, | |
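| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "As noted in the introduction, converting the data to Parquet can improve query performance. One possible approach, sketched here rather than taken from the original notebook, is an Athena `CREATE TABLE AS SELECT` (CTAS) statement. The `taxi_parquet` table name and the `parquet/taxi/` prefix are hypothetical, and the target S3 location is assumed to be empty before the query runs." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch only: `taxi_parquet` and the `parquet/taxi/` prefix are hypothetical names\n", | |
| "parquet_table = '{}.taxi_parquet'.format(db_name)\n", | |
| "parquet_location = 's3://{}/parquet/taxi/'.format(bucket)\n", | |
| "\n", | |
| "# CTAS writes the query result to S3 as Parquet and registers a new table\n", | |
| "sql_ctas = '''\n", | |
| "CREATE TABLE {}\n", | |
| "WITH (\n", | |
| "  format = 'PARQUET',\n", | |
| "  external_location = '{}'\n", | |
| ") AS\n", | |
| "SELECT * FROM {}\n", | |
| "'''.format(parquet_table, parquet_location, table_name)\n", | |
| "\n", | |
| "cursor.execute(sql_ctas)\n", | |
| "print('Status: {}'.format(cursor.state))" | |
| ] | |
| }, | |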
| { | |
| "cell_type": "code", | |
| "execution_count": 34, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Querying... \n", | |
| "SELECT \n", | |
| " total_amount, fare_amount, lpep_pickup_datetime, lpep_dropoff_datetime, trip_distance \n", | |
| "FROM nyc_taxi.taxi_csv;\n", | |
| "\n", | |
| "Status: SUCCEEDED, Run time: 6.09s, Data scanned: 91.34MB, Records: 1,070,262\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>total_amount</th>\n", | |
| " <th>fare_amount</th>\n", | |
| " <th>lpep_pickup_datetime</th>\n", | |
| " <th>lpep_dropoff_datetime</th>\n", | |
| " <th>trip_distance</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>20.30</td>\n", | |
| " <td>16.5</td>\n", | |
| " <td>2017-01-22 21:49:27</td>\n", | |
| " <td>2017-01-22 22:07:02</td>\n", | |
| " <td>4.74</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>26.16</td>\n", | |
| " <td>20.5</td>\n", | |
| " <td>2017-01-22 21:52:32</td>\n", | |
| " <td>2017-01-22 22:15:40</td>\n", | |
| " <td>5.56</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>10.56</td>\n", | |
| " <td>7.5</td>\n", | |
| " <td>2017-01-22 21:07:23</td>\n", | |
| " <td>2017-01-22 21:14:19</td>\n", | |
| " <td>1.61</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>12.96</td>\n", | |
| " <td>9.5</td>\n", | |
| " <td>2017-01-22 21:37:01</td>\n", | |
| " <td>2017-01-22 21:46:48</td>\n", | |
| " <td>2.28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>11.16</td>\n", | |
| " <td>8.0</td>\n", | |
| " <td>2017-01-22 21:55:06</td>\n", | |
| " <td>2017-01-22 22:03:13</td>\n", | |
| " <td>1.71</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " total_amount fare_amount lpep_pickup_datetime lpep_dropoff_datetime \\\n", | |
| "0 20.30 16.5 2017-01-22 21:49:27 2017-01-22 22:07:02 \n", | |
| "1 26.16 20.5 2017-01-22 21:52:32 2017-01-22 22:15:40 \n", | |
| "2 10.56 7.5 2017-01-22 21:07:23 2017-01-22 21:14:19 \n", | |
| "3 12.96 9.5 2017-01-22 21:37:01 2017-01-22 21:46:48 \n", | |
| "4 11.16 8.0 2017-01-22 21:55:06 2017-01-22 22:03:13 \n", | |
| "\n", | |
| " trip_distance \n", | |
| "0 4.74 \n", | |
| "1 5.56 \n", | |
| "2 1.61 \n", | |
| "3 2.28 \n", | |
| "4 1.71 " | |
| ] | |
| }, | |
| "execution_count": 34, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "data_sql = '''\n", | |
| "SELECT \n", | |
| " total_amount, fare_amount, lpep_pickup_datetime, lpep_dropoff_datetime, trip_distance \n", | |
| "FROM {};\n", | |
| "'''.format(table_name)\n", | |
| "print('Querying...', data_sql)\n", | |
| "\n", | |
| "data_df = cursor.execute(data_sql).as_pandas()\n", | |
| "print('Status: {}, Run time: {:.2f}s, Data scanned: {:.2f}MB, Records: {:,}'.format(cursor.state, \n", | |
| " cursor.execution_time_in_millis/1000.0, cursor.data_scanned_in_bytes/1024.0/1024.0, data_df.shape[0]))\n", | |
| "\n", | |
| "data_df.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Perform some simple feature engineering" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 43, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Add some date features\n", | |
| "data_df['lpep_pickup_datetime'] = data_df['lpep_pickup_datetime'].astype('datetime64[ns]')\n", | |
| "data_df['lpep_dropoff_datetime'] = data_df['lpep_dropoff_datetime'].astype('datetime64[ns]')\n", | |
| "data_df['duration_minutes'] = (data_df['lpep_dropoff_datetime'] - data_df['lpep_pickup_datetime']).dt.seconds/60\n", | |
| "data_df['hour_of_day'] = data_df['lpep_pickup_datetime'].dt.hour\n", | |
| "data_df['day_of_week'] = data_df['lpep_pickup_datetime'].dt.dayofweek\n", | |
| "data_df['week_of_year'] = data_df['lpep_pickup_datetime'].dt.weekofyear\n", | |
| "data_df['month_of_year'] = data_df['lpep_pickup_datetime'].dt.month" | |
| ] | |
| }, | |
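| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Note that `duration_minutes` above is computed with `.dt.seconds`, which returns only the seconds component of each timedelta and ignores whole days, so a trip recorded as longer than 24 hours would look short. A minimal sketch with made-up durations (not rows from the dataset) compares it with `.dt.total_seconds()`:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import pandas as pd\n", | |
| "\n", | |
| "# Made-up durations: one ordinary trip and one lasting just over a day\n", | |
| "delta = pd.Series(pd.to_timedelta(['0 days 00:17:35', '1 days 00:05:00']))\n", | |
| "\n", | |
| "# .dt.seconds keeps only the seconds component and drops whole days\n", | |
| "print(delta.dt.seconds / 60)\n", | |
| "\n", | |
| "# .dt.total_seconds() covers the full duration, days included\n", | |
| "print(delta.dt.total_seconds() / 60)" | |
| ] | |
| }, | |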
| { | |
| "cell_type": "code", | |
| "execution_count": 44, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(1046381, 10)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>total_amount</th>\n", | |
| " <th>fare_amount</th>\n", | |
| " <th>lpep_pickup_datetime</th>\n", | |
| " <th>lpep_dropoff_datetime</th>\n", | |
| " <th>trip_distance</th>\n", | |
| " <th>duration_minutes</th>\n", | |
| " <th>hour_of_day</th>\n", | |
| " <th>day_of_week</th>\n", | |
| " <th>week_of_year</th>\n", | |
| " <th>month_of_year</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>20.30</td>\n", | |
| " <td>16.5</td>\n", | |
| " <td>2017-01-22 21:49:27</td>\n", | |
| " <td>2017-01-22 22:07:02</td>\n", | |
| " <td>4.74</td>\n", | |
| " <td>17.583333</td>\n", | |
| " <td>21.0</td>\n", | |
| " <td>6.0</td>\n", | |
| " <td>3.0</td>\n", | |
| " <td>1.0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>26.16</td>\n", | |
| " <td>20.5</td>\n", | |
| " <td>2017-01-22 21:52:32</td>\n", | |
| " <td>2017-01-22 22:15:40</td>\n", | |
| " <td>5.56</td>\n", | |
| " <td>23.133333</td>\n", | |
| " <td>21.0</td>\n", | |
| " <td>6.0</td>\n", | |
| " <td>3.0</td>\n", | |
| " <td>1.0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>10.56</td>\n", | |
| " <td>7.5</td>\n", | |
| " <td>2017-01-22 21:07:23</td>\n", | |
| " <td>2017-01-22 21:14:19</td>\n", | |
| " <td>1.61</td>\n", | |
| " <td>6.933333</td>\n", | |
| " <td>21.0</td>\n", | |
| " <td>6.0</td>\n", | |
| " <td>3.0</td>\n", | |
| " <td>1.0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>12.96</td>\n", | |
| " <td>9.5</td>\n", | |
| " <td>2017-01-22 21:37:01</td>\n", | |
| " <td>2017-01-22 21:46:48</td>\n", | |
| " <td>2.28</td>\n", | |
| " <td>9.783333</td>\n", | |
| " <td>21.0</td>\n", | |
| " <td>6.0</td>\n", | |
| " <td>3.0</td>\n", | |
| " <td>1.0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>11.16</td>\n", | |
| " <td>8.0</td>\n", | |
| " <td>2017-01-22 21:55:06</td>\n", | |
| " <td>2017-01-22 22:03:13</td>\n", | |
| " <td>1.71</td>\n", | |
| " <td>8.116667</td>\n", | |
| " <td>21.0</td>\n", | |
| " <td>6.0</td>\n", | |
| " <td>3.0</td>\n", | |
| " <td>1.0</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " total_amount fare_amount lpep_pickup_datetime lpep_dropoff_datetime \\\n", | |
| "0 20.30 16.5 2017-01-22 21:49:27 2017-01-22 22:07:02 \n", | |
| "1 26.16 20.5 2017-01-22 21:52:32 2017-01-22 22:15:40 \n", | |
| "2 10.56 7.5 2017-01-22 21:07:23 2017-01-22 21:14:19 \n", | |
| "3 12.96 9.5 2017-01-22 21:37:01 2017-01-22 21:46:48 \n", | |
| "4 11.16 8.0 2017-01-22 21:55:06 2017-01-22 22:03:13 \n", | |
| "\n", | |
| " trip_distance duration_minutes hour_of_day day_of_week week_of_year \\\n", | |
| "0 4.74 17.583333 21.0 6.0 3.0 \n", | |
| "1 5.56 23.133333 21.0 6.0 3.0 \n", | |
| "2 1.61 6.933333 21.0 6.0 3.0 \n", | |
| "3 2.28 9.783333 21.0 6.0 3.0 \n", | |
| "4 1.71 8.116667 21.0 6.0 3.0 \n", | |
| "\n", | |
| " month_of_year \n", | |
| "0 1.0 \n", | |
| "1 1.0 \n", | |
| "2 1.0 \n", | |
| "3 1.0 \n", | |
| "4 1.0 " | |
| ] | |
| }, | |
| "execution_count": 44, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Exclude any outliers\n", | |
| "data_df = data_df[(data_df.total_amount > 0) & (data_df.total_amount < 200) & \n", | |
| " (data_df.duration_minutes > 0) & (data_df.duration_minutes < 120) & \n", | |
| " (data_df.trip_distance > 0) & (data_df.trip_distance < 1000)].dropna()\n", | |
| "print(data_df.shape)\n", | |
| "data_df.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Visualise Data\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 45, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Setup plotting defaults\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "import matplotlib as mpl\n", | |
| "import seaborn as sns\n", | |
| "%matplotlib inline\n", | |
| "\n", | |
| "mpl.style.use('seaborn')\n", | |
| "\n", | |
| "mpl.rcParams['figure.figsize'] = [12.0, 6.0]\n", | |
| "mpl.rcParams['figure.dpi'] = 80\n", | |
| "mpl.rcParams['savefig.dpi'] = 100\n", | |
| "\n", | |
| "mpl.rcParams['font.size'] = 12\n", | |
| "mpl.rcParams['legend.fontsize'] = 'medium'\n", | |
| "mpl.rcParams['figure.titlesize'] = 'medium'\n", | |
| "\n", | |
| "sample_df = data_df.sample(1000)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Inspect the distribution of data by date, and then cost/time metrics" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 46, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[<matplotlib.axes._subplots.AxesSubplot object at 0x7fd22bc659e8>,\n", | |
| " <matplotlib.axes._subplots.AxesSubplot object at 0x7fd22bc0db38>],\n", | |
| " [<matplotlib.axes._subplots.AxesSubplot object at 0x7fd22bc3a4e0>,\n", | |
| " <matplotlib.axes._subplots.AxesSubplot object at 0x7fd22bbdee48>]],\n", | |
| " dtype=object)" | |
| ] | |
| }, | |
| "execution_count": 46, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 960x480 with 4 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "date_cols = ['hour_of_day', 'day_of_week', 'week_of_year', 'month_of_year']\n", | |
| "sample_df[date_cols].hist(bins=100)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 48, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[<matplotlib.axes._subplots.AxesSubplot object at 0x7fd22bb01ac8>,\n", | |
| " <matplotlib.axes._subplots.AxesSubplot object at 0x7fd22bb16160>],\n", | |
| " [<matplotlib.axes._subplots.AxesSubplot object at 0x7fd2868722b0>,\n", | |
| " <matplotlib.axes._subplots.AxesSubplot object at 0x7fd22b70fc18>]],\n", | |
| " dtype=object)" | |
| ] | |
| }, | |
| "execution_count": 48, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 960x480 with 4 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "cost_cols = ['total_amount', 'fare_amount', 'duration_minutes', 'trip_distance']\n", | |
| "sample_df[cost_cols].hist(bins=100)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "See whether the fare amount correlates with trip duration, trip distance, and hour of day" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 49, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<matplotlib.axes._subplots.AxesSubplot at 0x7fd22b515c50>" | |
| ] | |
| }, | |
| "execution_count": 49, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 960x480 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "sample_df.plot.scatter(x='duration_minutes', y='fare_amount')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 50, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<matplotlib.axes._subplots.AxesSubplot at 0x7fd22b2ae9e8>" | |
| ] | |
| }, | |
| "execution_count": 50, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 960x480 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "sample_df.plot.scatter(x='trip_distance', y='fare_amount')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 51, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<matplotlib.axes._subplots.AxesSubplot at 0x7fd22b1d0048>" | |
| ] | |
| }, | |
| "execution_count": 51, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 960x480 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "sample_df.plot.scatter(x='hour_of_day', y='fare_amount')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Train Model\n", | |
| "\n", | |
| "Build an XGBoost model to predict the total amount from trip duration, trip distance, and hour of day" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "bucket: sagemaker-us-east-1-691313291965, prefix: nyc-taxi\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import boto3 \n", | |
| "import sagemaker\n", | |
| "\n", | |
| "sagemaker_session = sagemaker.session.Session(boto_session)\n", | |
| "role = sagemaker.get_execution_role()\n", | |
| "prefix = 'nyc-taxi'\n", | |
| "\n", | |
| "print('bucket: {}, prefix: {}'.format(bucket, prefix))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "split train: 837104, val: 104638, test: 104639 \n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Train/validation/test split (80/10/10)\n", | |
| "from sklearn.model_selection import train_test_split\n", | |
| "\n", | |
| "train_cols = ['total_amount', 'duration_minutes', 'trip_distance', 'hour_of_day']\n", | |
| "train_df, val_df = train_test_split(data_df[train_cols], test_size=0.20, random_state=42)\n", | |
| "val_df, test_df = train_test_split(val_df, test_size=0.50, random_state=42)\n", | |
| "\n", | |
| "print('split train: {}, val: {}, test: {} '.format(train_df.shape[0], val_df.shape[0], test_df.shape[0]))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Reset the index of each split\n", | |
| "train_df = train_df.reset_index(drop=True)\n", | |
| "val_df = val_df.reset_index(drop=True)\n", | |
| "test_df = test_df.reset_index(drop=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Upload Data\n", | |
| "\n", | |
| "Save the training and validation sets as CSV with `total_amount` as the first column and no headers" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Write CSVs without index or header; total_amount stays the first column\n", | |
| "train_df.to_csv('train.csv', index=False, header=False)\n", | |
| "val_df.to_csv('validation.csv', index=False, header=False)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 214 ms, sys: 30.8 ms, total: 245 ms\n", | |
| "Wall time: 997 ms\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%time\n", | |
| "\n", | |
| "# Upload the files to S3\n", | |
| "s3_train_uri = sagemaker_session.upload_data('train.csv', bucket, prefix + '/data/training')\n", | |
| "s3_val_uri = sagemaker_session.upload_data('validation.csv', bucket, prefix + '/data/validation')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Validate that we have uploaded these files successfully" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2020-03-23 05:25:34 23700354 train.csv\n", | |
| "2020-03-23 05:25:35 2966559 validation.csv\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "!aws s3 ls $s3_train_uri \n", | |
| "!aws s3 ls $s3_val_uri" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Get estimator" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "container: 683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:0.90-1-cpu-py3\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "from sagemaker.amazon.amazon_estimator import get_image_uri\n", | |
| "container = get_image_uri(region, 'xgboost', '0.90-1')\n", | |
| "print('container: {}'.format(container))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "output: s3://sagemaker-us-east-1-691313291965/nyc-taxi/output\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "output_path = 's3://{}/{}/output'.format(bucket, prefix)\n", | |
| "print('output: {}'.format(output_path))\n", | |
| "\n", | |
| "xgb = sagemaker.estimator.Estimator(container,\n", | |
| " role,\n", | |
| " train_instance_count=1,\n", | |
| " train_instance_type='ml.m4.xlarge',\n", | |
| " output_path=output_path,\n", | |
| " sagemaker_session=sagemaker_session)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 24, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2020-03-23 05:25:36 Starting - Starting the training job...\n", | |
| "2020-03-23 05:25:38 Starting - Launching requested ML instances.........\n", | |
| "2020-03-23 05:27:09 Starting - Preparing the instances for training......\n", | |
| "2020-03-23 05:28:36 Downloading - Downloading input data\n", | |
| "2020-03-23 05:28:36 Training - Downloading the training image...\n", | |
| "2020-03-23 05:28:58 Training - Training image download completed. Training in progress..\u001b[34mINFO:sagemaker-containers:Imported framework sagemaker_xgboost_container.training\u001b[0m\n", | |
| "\u001b[34mINFO:sagemaker-containers:Failed to parse hyperparameter objective value reg:linear to Json.\u001b[0m\n", | |
| "\u001b[34mReturning the value itself\u001b[0m\n", | |
| "\u001b[34mINFO:sagemaker-containers:No GPUs detected (normal if no gpus installed)\u001b[0m\n", | |
| "\u001b[34mINFO:sagemaker_xgboost_container.training:Running XGBoost Sagemaker in algorithm mode\u001b[0m\n", | |
| "\u001b[34mINFO:root:Determined delimiter of CSV input is ','\u001b[0m\n", | |
| "\u001b[34mINFO:root:Determined delimiter of CSV input is ','\u001b[0m\n", | |
| "\u001b[34mINFO:root:Determined delimiter of CSV input is ','\u001b[0m\n", | |
| "\u001b[34m[05:29:00] 837104x3 matrix with 2511312 entries loaded from /opt/ml/input/data/train?format=csv&label_column=0&delimiter=,\u001b[0m\n", | |
| "\u001b[34mINFO:root:Determined delimiter of CSV input is ','\u001b[0m\n", | |
| "\u001b[34m[05:29:00] 104638x3 matrix with 313914 entries loaded from /opt/ml/input/data/validation?format=csv&label_column=0&delimiter=,\u001b[0m\n", | |
| "\u001b[34mINFO:root:Single node training.\u001b[0m\n", | |
| "\u001b[34mINFO:root:Train matrix has 837104 rows\u001b[0m\n", | |
| "\u001b[34mINFO:root:Validation matrix has 104638 rows\u001b[0m\n", | |
| "\u001b[34m[05:29:00] WARNING: /workspace/src/objective/regression_obj.cu:152: reg:linear is now deprecated in favor of reg:squarederror.\u001b[0m\n", | |
| "\u001b[34m[0]#011train-rmse:13.0692#011validation-rmse:13.1612\u001b[0m\n", | |
| "\u001b[34m[1]#011train-rmse:10.6111#011validation-rmse:10.7012\u001b[0m\n", | |
| "\u001b[34m[2]#011train-rmse:8.67642#011validation-rmse:8.76889\u001b[0m\n", | |
| "\u001b[34m[3]#011train-rmse:7.16695#011validation-rmse:7.26401\u001b[0m\n", | |
| "\u001b[34m[4]#011train-rmse:6.00262#011validation-rmse:6.10651\u001b[0m\n", | |
| "\u001b[34m[5]#011train-rmse:5.11893#011validation-rmse:5.23096\u001b[0m\n", | |
| "\u001b[34m[6]#011train-rmse:4.46191#011validation-rmse:4.58335\u001b[0m\n", | |
| "\u001b[34m[7]#011train-rmse:3.9839#011validation-rmse:4.11368\u001b[0m\n", | |
| "\u001b[34m[8]#011train-rmse:3.6433#011validation-rmse:3.7812\u001b[0m\n", | |
| "\u001b[34m[9]#011train-rmse:3.40639#011validation-rmse:3.5513\u001b[0m\n", | |
| "\u001b[34m[10]#011train-rmse:3.24428#011validation-rmse:3.39454\u001b[0m\n", | |
| "\u001b[34m[11]#011train-rmse:3.13648#011validation-rmse:3.29043\u001b[0m\n", | |
| "\u001b[34m[12]#011train-rmse:3.06445#011validation-rmse:3.22099\u001b[0m\n", | |
| "\u001b[34m[13]#011train-rmse:3.01457#011validation-rmse:3.17258\u001b[0m\n", | |
| "\u001b[34m[14]#011train-rmse:2.98203#011validation-rmse:3.14265\u001b[0m\n", | |
| "\u001b[34m[15]#011train-rmse:2.96079#011validation-rmse:3.12184\u001b[0m\n", | |
| "\u001b[34m[16]#011train-rmse:2.94609#011validation-rmse:3.10747\u001b[0m\n", | |
| "\u001b[34m[17]#011train-rmse:2.93603#011validation-rmse:3.09802\u001b[0m\n", | |
| "\u001b[34m[18]#011train-rmse:2.92937#011validation-rmse:3.09116\u001b[0m\n", | |
| "\u001b[34m[19]#011train-rmse:2.92451#011validation-rmse:3.08665\u001b[0m\n", | |
| "\u001b[34m[20]#011train-rmse:2.9212#011validation-rmse:3.08414\u001b[0m\n", | |
| "\u001b[34m[21]#011train-rmse:2.91907#011validation-rmse:3.08233\u001b[0m\n", | |
| "\u001b[34m[22]#011train-rmse:2.91687#011validation-rmse:3.08143\u001b[0m\n", | |
| "\u001b[34m[23]#011train-rmse:2.91541#011validation-rmse:3.08056\u001b[0m\n", | |
| "\u001b[34m[24]#011train-rmse:2.91476#011validation-rmse:3.0801\u001b[0m\n", | |
| "\u001b[34m[25]#011train-rmse:2.91364#011validation-rmse:3.07958\u001b[0m\n", | |
| "\u001b[34m[26]#011train-rmse:2.91242#011validation-rmse:3.07885\u001b[0m\n", | |
| "\u001b[34m[27]#011train-rmse:2.91151#011validation-rmse:3.07778\u001b[0m\n", | |
| "\u001b[34m[28]#011train-rmse:2.91076#011validation-rmse:3.07718\u001b[0m\n", | |
| "\u001b[34m[29]#011train-rmse:2.91033#011validation-rmse:3.0771\u001b[0m\n", | |
| "\u001b[34m[30]#011train-rmse:2.90999#011validation-rmse:3.07715\u001b[0m\n", | |
| "\u001b[34m[31]#011train-rmse:2.90978#011validation-rmse:3.07716\u001b[0m\n", | |
| "\u001b[34m[32]#011train-rmse:2.9095#011validation-rmse:3.07682\u001b[0m\n", | |
| "\u001b[34m[33]#011train-rmse:2.90914#011validation-rmse:3.07649\u001b[0m\n", | |
| "\u001b[34m[34]#011train-rmse:2.90873#011validation-rmse:3.07659\u001b[0m\n", | |
| "\u001b[34m[35]#011train-rmse:2.90838#011validation-rmse:3.0766\u001b[0m\n", | |
| "\u001b[34m[36]#011train-rmse:2.90784#011validation-rmse:3.07625\u001b[0m\n", | |
| "\u001b[34m[37]#011train-rmse:2.90713#011validation-rmse:3.07665\u001b[0m\n", | |
| "\u001b[34m[38]#011train-rmse:2.90699#011validation-rmse:3.07682\u001b[0m\n", | |
| "\u001b[34m[39]#011train-rmse:2.90651#011validation-rmse:3.07634\u001b[0m\n", | |
| "\u001b[34m[40]#011train-rmse:2.9063#011validation-rmse:3.07625\u001b[0m\n", | |
| "\u001b[34m[41]#011train-rmse:2.90587#011validation-rmse:3.07616\u001b[0m\n", | |
| "\u001b[34m[42]#011train-rmse:2.90548#011validation-rmse:3.07558\u001b[0m\n", | |
| "\u001b[34m[43]#011train-rmse:2.90489#011validation-rmse:3.07631\u001b[0m\n", | |
| "\u001b[34m[44]#011train-rmse:2.90473#011validation-rmse:3.07629\u001b[0m\n", | |
| "\u001b[34m[45]#011train-rmse:2.90467#011validation-rmse:3.07646\u001b[0m\n", | |
| "\u001b[34m[46]#011train-rmse:2.90442#011validation-rmse:3.0759\u001b[0m\n", | |
| "\u001b[34m[47]#011train-rmse:2.90406#011validation-rmse:3.07541\u001b[0m\n", | |
| "\u001b[34m[48]#011train-rmse:2.90367#011validation-rmse:3.07583\u001b[0m\n", | |
| "\u001b[34m[49]#011train-rmse:2.90351#011validation-rmse:3.07563\u001b[0m\n", | |
| "\u001b[34m[50]#011train-rmse:2.90282#011validation-rmse:3.07573\u001b[0m\n", | |
| "\n", | |
| "2020-03-23 05:29:36 Uploading - Uploading generated training model\n", | |
| "2020-03-23 05:29:36 Completed - Training job completed\n", | |
| "\u001b[34m[51]#011train-rmse:2.9025#011validation-rmse:3.07599\u001b[0m\n", | |
| "\u001b[34m[52]#011train-rmse:2.90232#011validation-rmse:3.07636\u001b[0m\n", | |
| "\u001b[34m[53]#011train-rmse:2.90185#011validation-rmse:3.07594\u001b[0m\n", | |
| "\u001b[34m[54]#011train-rmse:2.9015#011validation-rmse:3.0765\u001b[0m\n", | |
| "\u001b[34m[55]#011train-rmse:2.90112#011validation-rmse:3.07619\u001b[0m\n", | |
| "\u001b[34m[56]#011train-rmse:2.90091#011validation-rmse:3.07628\u001b[0m\n", | |
| "\u001b[34m[57]#011train-rmse:2.90062#011validation-rmse:3.07634\u001b[0m\n", | |
| "Training seconds: 85\n", | |
| "Billable seconds: 85\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "xgb.set_hyperparameters(max_depth=9,\n", | |
| " eta=0.2, \n", | |
| " gamma=4,\n", | |
| " min_child_weight=300,\n", | |
| " subsample=0.8,\n", | |
| " silent=0,\n", | |
| " objective='reg:linear',\n", | |
| " early_stopping_rounds=10,\n", | |
| " num_round=10000)\n", | |
| "\n", | |
| "s3_input_train = sagemaker.s3_input(s3_data=s3_train_uri, content_type='csv')\n", | |
| "s3_input_val = sagemaker.s3_input(s3_data=s3_val_uri, content_type='csv')\n", | |
| "\n", | |
| "xgb.fit({'train': s3_input_train, 'validation': s3_input_val})" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Deploy model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 25, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "-------------!" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "xgb_predictor = xgb.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Evaluate Model\n", | |
| "\n", | |
| "Get predictions for the validation set." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 26, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from sagemaker.predictor import csv_serializer\n", | |
| "\n", | |
| "xgb_predictor.content_type = 'text/csv'\n", | |
| "xgb_predictor.serializer = csv_serializer\n", | |
| "xgb_predictor.deserializer = None" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 27, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 210/210 [00:04<00:00, 45.49it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 1.47 s, sys: 18.4 ms, total: 1.49 s\n", | |
| "Wall time: 4.67 s\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%time\n", | |
| "\n", | |
| "import numpy as np\n", | |
| "from tqdm import tqdm\n", | |
| "\n", | |
| "def predict(data, rows=500):\n", | |
| " split_array = np.array_split(data, int(data.shape[0] / float(rows) + 1))\n", | |
| " predictions = ''\n", | |
| " for array in tqdm(split_array):\n", | |
| " predictions = ','.join([predictions, xgb_predictor.predict(array).decode('utf-8')])\n", | |
| " return np.fromstring(predictions[1:], sep=',')\n", | |
| "\n", | |
| "# Get predictions and store in df\n", | |
| "predictions = predict(val_df[train_cols[1:]].values)\n", | |
| "predictions = pd.DataFrame({'total_amount_predictions': predictions })" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 28, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>total_amount</th>\n", | |
| " <th>duration_minutes</th>\n", | |
| " <th>trip_distance</th>\n", | |
| " <th>hour_of_day</th>\n", | |
| " <th>total_amount_predictions</th>\n", | |
| " <th>error</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>80926</th>\n", | |
| " <td>160.00</td>\n", | |
| " <td>0.050000</td>\n", | |
| " <td>0.08</td>\n", | |
| " <td>19.0</td>\n", | |
| " <td>20.669760</td>\n", | |
| " <td>139.330240</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>36441</th>\n", | |
| " <td>150.00</td>\n", | |
| " <td>25.450000</td>\n", | |
| " <td>1.56</td>\n", | |
| " <td>3.0</td>\n", | |
| " <td>17.868708</td>\n", | |
| " <td>132.131292</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>31496</th>\n", | |
| " <td>120.00</td>\n", | |
| " <td>0.283333</td>\n", | |
| " <td>0.08</td>\n", | |
| " <td>4.0</td>\n", | |
| " <td>7.605297</td>\n", | |
| " <td>112.394703</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2357</th>\n", | |
| " <td>118.30</td>\n", | |
| " <td>17.833333</td>\n", | |
| " <td>4.51</td>\n", | |
| " <td>23.0</td>\n", | |
| " <td>19.775103</td>\n", | |
| " <td>98.524897</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>51305</th>\n", | |
| " <td>197.61</td>\n", | |
| " <td>118.533333</td>\n", | |
| " <td>53.10</td>\n", | |
| " <td>14.0</td>\n", | |
| " <td>104.829735</td>\n", | |
| " <td>92.780265</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>77732</th>\n", | |
| " <td>141.30</td>\n", | |
| " <td>46.416667</td>\n", | |
| " <td>12.14</td>\n", | |
| " <td>16.0</td>\n", | |
| " <td>48.566212</td>\n", | |
| " <td>92.733788</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>7371</th>\n", | |
| " <td>100.50</td>\n", | |
| " <td>0.500000</td>\n", | |
| " <td>0.12</td>\n", | |
| " <td>23.0</td>\n", | |
| " <td>8.080048</td>\n", | |
| " <td>92.419952</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>132</th>\n", | |
| " <td>178.50</td>\n", | |
| " <td>41.366667</td>\n", | |
| " <td>33.62</td>\n", | |
| " <td>6.0</td>\n", | |
| " <td>95.918167</td>\n", | |
| " <td>82.581833</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>21145</th>\n", | |
| " <td>91.00</td>\n", | |
| " <td>10.316667</td>\n", | |
| " <td>1.38</td>\n", | |
| " <td>18.0</td>\n", | |
| " <td>10.686583</td>\n", | |
| " <td>80.313417</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>56765</th>\n", | |
| " <td>100.00</td>\n", | |
| " <td>0.083333</td>\n", | |
| " <td>0.06</td>\n", | |
| " <td>20.0</td>\n", | |
| " <td>21.311424</td>\n", | |
| " <td>78.688576</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " total_amount duration_minutes trip_distance hour_of_day \\\n", | |
| "80926 160.00 0.050000 0.08 19.0 \n", | |
| "36441 150.00 25.450000 1.56 3.0 \n", | |
| "31496 120.00 0.283333 0.08 4.0 \n", | |
| "2357 118.30 17.833333 4.51 23.0 \n", | |
| "51305 197.61 118.533333 53.10 14.0 \n", | |
| "77732 141.30 46.416667 12.14 16.0 \n", | |
| "7371 100.50 0.500000 0.12 23.0 \n", | |
| "132 178.50 41.366667 33.62 6.0 \n", | |
| "21145 91.00 10.316667 1.38 18.0 \n", | |
| "56765 100.00 0.083333 0.06 20.0 \n", | |
| "\n", | |
| " total_amount_predictions error \n", | |
| "80926 20.669760 139.330240 \n", | |
| "36441 17.868708 132.131292 \n", | |
| "31496 7.605297 112.394703 \n", | |
| "2357 19.775103 98.524897 \n", | |
| "51305 104.829735 92.780265 \n", | |
| "77732 48.566212 92.733788 \n", | |
| "7371 8.080048 92.419952 \n", | |
| "132 95.918167 82.581833 \n", | |
| "21145 10.686583 80.313417 \n", | |
| "56765 21.311424 78.688576 " | |
| ] | |
| }, | |
| "execution_count": 28, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Compute the absolute error between actual and predicted totals\n", | |
| "pred_df = val_df.join(predictions)\n", | |
| "pred_df['error'] = abs(pred_df['total_amount']-pred_df['total_amount_predictions'])\n", | |
| "pred_df.sort_values('error', ascending=False).head(10)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Print the `RMSE` validation metric" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 29, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "3.0754101574135873" | |
| ] | |
| }, | |
| "execution_count": 29, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "from math import sqrt\n", | |
| "from sklearn.metrics import mean_squared_error\n", | |
| "\n", | |
| "def rmse(pred_df):\n", | |
| " return sqrt(mean_squared_error(pred_df['total_amount'], pred_df['total_amount_predictions']))\n", | |
| "\n", | |
| "rmse(pred_df)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Visualise the outliers vs predicted values" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 52, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<matplotlib.axes._subplots.AxesSubplot at 0x7fd22b19dda0>" | |
| ] | |
| }, | |
| "execution_count": 52, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 960x480 with 2 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "pred_df.tail(1000).plot.scatter(x='total_amount_predictions', y='total_amount', c='error', title='actual amount (y) vs predicted (x)')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Create Athena UDF \n", | |
| "\n", | |
| "Create a [User Defined Function](https://aws.amazon.com/blogs/big-data/prepare-data-for-model-training-and-invoke-machine-learning-models-with-amazon-athena/) for the deployed endpoint so you can query directly in Athena." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 31, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "endpoint: sagemaker-xgboost-2020-03-23-05-25-35-962\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "endpoint_name = xgb_predictor.endpoint\n", | |
| "print('endpoint: {}'.format(endpoint_name))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "`NOTE`: Athena ML is [in preview](https://aws.amazon.com/athena/faqs/#Preview_features). To enable this preview feature, create an Athena workgroup named `AmazonAthenaPreviewFunctionality` and run from that workgroup any query that federates to a connector, uses a UDF, or calls SageMaker inference." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 32, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "\r\n", | |
| "An error occurred (InvalidRequestException) when calling the CreateWorkGroup operation: WorkGroup AmazonAthenaPreviewFunctionality is already created\r\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "workgroup_name = 'AmazonAthenaPreviewFunctionality'\n", | |
| "\n", | |
| "!aws athena create-work-group --name $workgroup_name --region $region" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Use Presto [datetime](https://prestodb.io/docs/0.172/functions/datetime.html) functions in an inline query to derive the model features, call the UDF on them, and rank the results by absolute error." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 35, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Querying... \n", | |
| "USING FUNCTION predict_total(\n", | |
| " duration_minutes DOUBLE, \n", | |
| " trip_distance DOUBLE, \n", | |
| " hour_of_day DOUBLE) returns DOUBLE type SAGEMAKER_INVOKE_ENDPOINT\n", | |
| "WITH (sagemaker_endpoint='sagemaker-xgboost-2020-03-23-05-25-35-962')\n", | |
| "\n", | |
| "SELECT \n", | |
| " *, ABS(predicted_total_amount-total_amount) as error\n", | |
| "FROM ( \n", | |
| " SELECT\n", | |
| " *,\n", | |
| " predict_total(duration_minutes, trip_distance, hour_of_day) as predicted_total_amount\n", | |
| " FROM \n", | |
| " (\n", | |
| " SELECT \n", | |
| " total_amount,\n", | |
| " CAST(date_diff('minute', \n", | |
| " CAST(lpep_pickup_datetime as timestamp), \n", | |
| " CAST(lpep_dropoff_datetime as timestamp)) as DOUBLE) as duration_minutes,\n", | |
| " CAST(trip_distance as DOUBLE) as trip_distance,\n", | |
| " CAST(hour(CAST(lpep_pickup_datetime as timestamp)) as double) as hour_of_day\n", | |
| " FROM nyc_taxi.taxi_csv\n", | |
| " WHERE DAY(CAST(lpep_pickup_datetime as timestamp)) = 1 -- Filter by day\n", | |
| " )\n", | |
| ")\n", | |
| "ORDER BY error DESC\n", | |
| "LIMIT 10;\n", | |
| "\n", | |
| "Status: SUCCEEDED, Run time: 5.93s, Data scanned: 91.34MB, Records: 10\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>total_amount</th>\n", | |
| " <th>duration_minutes</th>\n", | |
| " <th>trip_distance</th>\n", | |
| " <th>hour_of_day</th>\n", | |
| " <th>predicted_total_amount</th>\n", | |
| " <th>error</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>240.00</td>\n", | |
| " <td>0.0</td>\n", | |
| " <td>0.00</td>\n", | |
| " <td>4.0</td>\n", | |
| " <td>18.295891</td>\n", | |
| " <td>221.704109</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>276.64</td>\n", | |
| " <td>49.0</td>\n", | |
| " <td>48.20</td>\n", | |
| " <td>8.0</td>\n", | |
| " <td>96.595856</td>\n", | |
| " <td>180.044144</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>140.80</td>\n", | |
| " <td>1.0</td>\n", | |
| " <td>0.20</td>\n", | |
| " <td>10.0</td>\n", | |
| " <td>4.583008</td>\n", | |
| " <td>136.216992</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>203.16</td>\n", | |
| " <td>43.0</td>\n", | |
| " <td>38.59</td>\n", | |
| " <td>11.0</td>\n", | |
| " <td>96.354904</td>\n", | |
| " <td>106.805096</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>7.00</td>\n", | |
| " <td>193.0</td>\n", | |
| " <td>52.96</td>\n", | |
| " <td>22.0</td>\n", | |
| " <td>103.960861</td>\n", | |
| " <td>96.960861</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>5</th>\n", | |
| " <td>141.30</td>\n", | |
| " <td>46.0</td>\n", | |
| " <td>12.14</td>\n", | |
| " <td>16.0</td>\n", | |
| " <td>48.747108</td>\n", | |
| " <td>92.552892</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>6</th>\n", | |
| " <td>-68.31</td>\n", | |
| " <td>0.0</td>\n", | |
| " <td>0.01</td>\n", | |
| " <td>17.0</td>\n", | |
| " <td>22.999298</td>\n", | |
| " <td>91.309298</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>7</th>\n", | |
| " <td>10.00</td>\n", | |
| " <td>52.0</td>\n", | |
| " <td>26.49</td>\n", | |
| " <td>7.0</td>\n", | |
| " <td>96.595856</td>\n", | |
| " <td>86.595856</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>8</th>\n", | |
| " <td>104.00</td>\n", | |
| " <td>0.0</td>\n", | |
| " <td>0.00</td>\n", | |
| " <td>6.0</td>\n", | |
| " <td>18.295891</td>\n", | |
| " <td>85.704109</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>9</th>\n", | |
| " <td>178.50</td>\n", | |
| " <td>41.0</td>\n", | |
| " <td>33.62</td>\n", | |
| " <td>6.0</td>\n", | |
| " <td>95.805733</td>\n", | |
| " <td>82.694267</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " total_amount duration_minutes trip_distance hour_of_day \\\n", | |
| "0 240.00 0.0 0.00 4.0 \n", | |
| "1 276.64 49.0 48.20 8.0 \n", | |
| "2 140.80 1.0 0.20 10.0 \n", | |
| "3 203.16 43.0 38.59 11.0 \n", | |
| "4 7.00 193.0 52.96 22.0 \n", | |
| "5 141.30 46.0 12.14 16.0 \n", | |
| "6 -68.31 0.0 0.01 17.0 \n", | |
| "7 10.00 52.0 26.49 7.0 \n", | |
| "8 104.00 0.0 0.00 6.0 \n", | |
| "9 178.50 41.0 33.62 6.0 \n", | |
| "\n", | |
| " predicted_total_amount error \n", | |
| "0 18.295891 221.704109 \n", | |
| "1 96.595856 180.044144 \n", | |
| "2 4.583008 136.216992 \n", | |
| "3 96.354904 106.805096 \n", | |
| "4 103.960861 96.960861 \n", | |
| "5 48.747108 92.552892 \n", | |
| "6 22.999298 91.309298 \n", | |
| "7 96.595856 86.595856 \n", | |
| "8 18.295891 85.704109 \n", | |
| "9 95.805733 82.694267 " | |
| ] | |
| }, | |
| "execution_count": 35, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "query_sql = '''\n", | |
| "USING FUNCTION predict_total(\n", | |
| " duration_minutes DOUBLE, \n", | |
| " trip_distance DOUBLE, \n", | |
| " hour_of_day DOUBLE) returns DOUBLE type SAGEMAKER_INVOKE_ENDPOINT\n", | |
| "WITH (sagemaker_endpoint='{}')\n", | |
| "\n", | |
| "SELECT \n", | |
| " *, ABS(predicted_total_amount-total_amount) as error\n", | |
| "FROM ( \n", | |
| " SELECT\n", | |
| " *,\n", | |
| " predict_total(duration_minutes, trip_distance, hour_of_day) as predicted_total_amount\n", | |
| " FROM \n", | |
| " (\n", | |
| " SELECT \n", | |
| " total_amount,\n", | |
| " CAST(date_diff('minute', \n", | |
| " CAST(lpep_pickup_datetime as timestamp), \n", | |
| " CAST(lpep_dropoff_datetime as timestamp)) as DOUBLE) as duration_minutes,\n", | |
| " CAST(trip_distance as DOUBLE) as trip_distance,\n", | |
| " CAST(hour(CAST(lpep_pickup_datetime as timestamp)) as double) as hour_of_day\n", | |
| " FROM {}\n", | |
| " WHERE DAY(CAST(lpep_pickup_datetime as timestamp)) = {} -- Filter by day\n", | |
| " )\n", | |
| ")\n", | |
| "ORDER BY error DESC\n", | |
| "LIMIT {};\n", | |
| "'''.format(endpoint_name, table_name, 1, 10)\n", | |
| "print('Querying...', query_sql)\n", | |
| "\n", | |
| "query_df = cursor.execute(query_sql, work_group=workgroup_name).as_pandas()\n", | |
| "print('Status: {}, Run time: {:.2f}s, Data scanned: {:.2f}MB, Records: {:,}'.format(cursor.state, \n", | |
| " cursor.execution_time_in_millis/1000.0, cursor.data_scanned_in_bytes/1024.0/1024.0, query_df.shape[0]))\n", | |
| "\n", | |
| "query_df" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "conda_tensorflow_p36", | |
| "language": "python", | |
| "name": "conda_tensorflow_p36" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.6" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |
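The evaluation cells in the notebook send the validation rows to the endpoint in batches, then compute the absolute error and the RMSE. Below is a minimal sketch of that flow that runs without a SageMaker endpoint. The names `predict_in_batches` and `fake_endpoint` and the toy fare formula are assumptions made for illustration only; they are not part of the notebook or of the SageMaker SDK.

```python
import numpy as np

def predict_in_batches(data, predict_fn, rows=500):
    """Split `data` into chunks of at most `rows` rows and concatenate the predictions."""
    # At least one chunk, so an empty input still yields a valid (empty) result.
    n_chunks = max(1, int(np.ceil(data.shape[0] / float(rows))))
    chunks = np.array_split(data, n_chunks)
    return np.concatenate([np.asarray(predict_fn(chunk), dtype=float) for chunk in chunks])

def rmse(actual, predicted):
    """Root mean squared error between two equal-length sequences."""
    actual = np.asarray(actual, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    return float(np.sqrt(np.mean((actual - predicted) ** 2)))

def fake_endpoint(batch):
    """Hypothetical stand-in for the deployed model: fare = 2.5 + 2 * trip_distance."""
    return 2.5 + 2.0 * batch[:, 1]

# Columns follow the notebook's feature order: duration_minutes, trip_distance, hour_of_day.
features = np.array([[10.0, 1.0, 8.0],
                     [20.0, 3.0, 9.0],
                     [5.0, 0.5, 23.0]])
actual = np.array([5.0, 8.0, 3.5])

predicted = predict_in_batches(features, fake_endpoint, rows=2)
abs_error = np.abs(actual - predicted)

print('predictions:', predicted)
print('absolute error:', abs_error)
print('rmse:', rmse(actual, predicted))
```

Replacing `fake_endpoint` with a function that calls the real predictor and parses its CSV response should give the same structure as the notebook's `predict` helper. Rounding the chunk count up with `ceil` avoids creating an extra batch when the row count is an exact multiple of `rows`.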