-
-
Save RichardScottOZ/89e054ca5c4c9fbbbf8d5d6499df856c to your computer and use it in GitHub Desktop.
vectorized `sklearn` with `xarray`
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# vectorized `sklearn` with `xarray`\n", | |
| "\n", | |
| "Run a `sklearn` classifier on every point of a grid (longitude/X, latitude/Y, lead_time, ...) in a single call.\n", | |
| "\n", | |
| "This may be slow because `vectorize=True` loops over the grid points in Python, but the code is short.\n", | |
| "\n", | |
| "Inspired by and based on https://renkulab.io/gitlab/lluis.palma/s2s-ai-challenge-bsc/-/blob/submission-ML_models/notebooks/S2S_ML_models.ipynb\n", | |
| "\n", | |
| "This also answers https://discourse.pangeo.io/t/vectorized-sklearn/1444" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## import" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import matplotlib.pyplot as plt\n", | |
| "\n", | |
| "import xarray as xr\n", | |
| "xr.set_options(display_style='text')\n", | |
| "\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "from sklearn.linear_model import LogisticRegression" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## data" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<pre><xarray.Dataset>\n", | |
| "Dimensions: (X: 5, Y: 5, lead_time: 2, week: 53, year: 20)\n", | |
| "Coordinates:\n", | |
| " * lead_time (lead_time) int64 1 2\n", | |
| " * year (year) int64 2000 2001 2002 2003 2004 ... 2016 2017 2018 2019\n", | |
| " * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n", | |
| " * X (X) int64 0 1 2 3 4\n", | |
| " * Y (Y) int64 0 1 2 3 4\n", | |
| "Data variables:\n", | |
| " t2m (lead_time, year, week, X, Y) float64 0.885 0.61 ... 0.1928\n", | |
| " tp (lead_time, year, week, X, Y) float64 0.0597 0.7052 ... 0.3623\n", | |
| " msl (lead_time, year, week, X, Y) float64 0.5728 0.8126 ... 0.2536</pre>" | |
| ], | |
| "text/plain": [ | |
| "<xarray.Dataset>\n", | |
| "Dimensions: (X: 5, Y: 5, lead_time: 2, week: 53, year: 20)\n", | |
| "Coordinates:\n", | |
| " * lead_time (lead_time) int64 1 2\n", | |
| " * year (year) int64 2000 2001 2002 2003 2004 ... 2016 2017 2018 2019\n", | |
| " * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n", | |
| " * X (X) int64 0 1 2 3 4\n", | |
| " * Y (Y) int64 0 1 2 3 4\n", | |
| "Data variables:\n", | |
| " t2m (lead_time, year, week, X, Y) float64 0.885 0.61 ... 0.1928\n", | |
| " tp (lead_time, year, week, X, Y) float64 0.0597 0.7052 ... 0.3623\n", | |
| " msl (lead_time, year, week, X, Y) float64 0.5728 0.8126 ... 0.2536" | |
| ] | |
| }, | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# synthetic data: competition on a 5x5 grid
# raw forecasts; the generator is seeded so Restart & Run All reproduces the same numbers
rng = np.random.default_rng(0)
X_train = xr.DataArray(
    rng.random((2, 20, 53, 5, 5, 3)),
    dims=['lead_time', 'year', 'week', 'X', 'Y', 'variable'],
    coords={'lead_time': [1, 2],
            'year': range(2000, 2020),
            'week': range(53),
            'X': range(5),
            'Y': range(5),
            'variable': ['t2m', 'tp', 'msl']},
).to_dataset(dim='variable')
X_train
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<pre><xarray.Dataset>\n", | |
| "Dimensions: (X: 5, Y: 5, lead_time: 2, week: 53, year: 2)\n", | |
| "Coordinates:\n", | |
| " * lead_time (lead_time) int64 1 2\n", | |
| " * year (year) int64 2018 2019\n", | |
| " * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n", | |
| " * X (X) int64 0 1 2 3 4\n", | |
| " * Y (Y) int64 0 1 2 3 4\n", | |
| "Data variables:\n", | |
| " t2m (lead_time, year, week, X, Y) float64 0.8516 0.4321 ... 0.1928\n", | |
| " tp (lead_time, year, week, X, Y) float64 0.9754 0.6478 ... 0.3623\n", | |
| " msl (lead_time, year, week, X, Y) float64 0.9741 0.05569 ... 0.2536</pre>" | |
| ], | |
| "text/plain": [ | |
| "<xarray.Dataset>\n", | |
| "Dimensions: (X: 5, Y: 5, lead_time: 2, week: 53, year: 2)\n", | |
| "Coordinates:\n", | |
| " * lead_time (lead_time) int64 1 2\n", | |
| " * year (year) int64 2018 2019\n", | |
| " * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n", | |
| " * X (X) int64 0 1 2 3 4\n", | |
| " * Y (Y) int64 0 1 2 3 4\n", | |
| "Data variables:\n", | |
| " t2m (lead_time, year, week, X, Y) float64 0.8516 0.4321 ... 0.1928\n", | |
| " tp (lead_time, year, week, X, Y) float64 0.9754 0.6478 ... 0.3623\n", | |
| " msl (lead_time, year, week, X, Y) float64 0.9741 0.05569 ... 0.2536" | |
| ] | |
| }, | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# hold out the final two years as the test period
X_test = X_train.isel(year=slice(-2, None))
# NOTE: this rebinds X_train, so re-running the cell would shrink it again; run it once
X_train = X_train.isel(year=slice(None, -2))
X_test
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<pre><xarray.Dataset>\n", | |
| "Dimensions: (X: 5, Y: 5, lead_time: 2, week: 53, year: 18)\n", | |
| "Coordinates:\n", | |
| " * lead_time (lead_time) int64 1 2\n", | |
| " * year (year) int64 2000 2001 2002 2003 2004 ... 2014 2015 2016 2017\n", | |
| " * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n", | |
| " * X (X) int64 0 1 2 3 4\n", | |
| " * Y (Y) int64 0 1 2 3 4\n", | |
| "Data variables:\n", | |
| " t2m (lead_time, year, week, X, Y) float64 2.0 1.0 0.0 ... 0.0 1.0 2.0\n", | |
| " tp (lead_time, year, week, X, Y) float64 0.0 2.0 0.0 ... 2.0 0.0 2.0\n", | |
| " msl (lead_time, year, week, X, Y) float64 1.0 2.0 1.0 ... 2.0 1.0 1.0</pre>" | |
| ], | |
| "text/plain": [ | |
| "<xarray.Dataset>\n", | |
| "Dimensions: (X: 5, Y: 5, lead_time: 2, week: 53, year: 18)\n", | |
| "Coordinates:\n", | |
| " * lead_time (lead_time) int64 1 2\n", | |
| " * year (year) int64 2000 2001 2002 2003 2004 ... 2014 2015 2016 2017\n", | |
| " * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n", | |
| " * X (X) int64 0 1 2 3 4\n", | |
| " * Y (Y) int64 0 1 2 3 4\n", | |
| "Data variables:\n", | |
| " t2m (lead_time, year, week, X, Y) float64 2.0 1.0 0.0 ... 0.0 1.0 2.0\n", | |
| " tp (lead_time, year, week, X, Y) float64 0.0 2.0 0.0 ... 2.0 0.0 2.0\n", | |
| " msl (lead_time, year, week, X, Y) float64 1.0 2.0 1.0 ... 2.0 1.0 1.0" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# categorized observations\n", | |
| "y_train = xr.concat([\n", | |
| " 0*xr.ones_like(X_train).where(X_train < 1/3, other=0),\n", | |
| " 1*xr.ones_like(X_train).where((X_train > 1/3) & (X_train < 2/3), other=0),\n", | |
| " 2*xr.ones_like(X_train).where(X_train > 2/3, other=0)\n", | |
| "],'category').sum('category')\n", | |
| "y_train" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## config" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
sample_dims = ['year', 'week']  # dimensions used as samples
features = ['t2m', 'tp', 'msl']  # variables used as features
target_var = 't2m'  # var to predict

# sklearn method: an unfitted template estimator.
# `multi_class` is left at its default ('auto'); passing it explicitly is deprecated
# in recent scikit-learn releases.
# NOTE(review): recent scikit-learn also deprecates the 'liblinear' solver for
# multiclass targets; consider wrapping in OneVsRestClassifier — confirm against the
# installed version.
clf = LogisticRegression(penalty='l2',
                         solver='liblinear',
                         random_state=0)
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## train" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
def atomic_function_training_LR(X_train, y_train, clf):
    """Fit an independent copy of ``clf`` on the samples of one grid point.

    Called once per grid point by ``xr.apply_ufunc(..., vectorize=True)``.

    Parameters
    ----------
    X_train : np.ndarray
        Features with the feature (``variable``) axis last; all leading axes
        are flattened into samples.
    y_train : np.ndarray
        Targets whose elements line up with the leading axes of ``X_train``.
    clf : estimator
        Template estimator exposing ``fit``. It is deep-copied before fitting,
        so the template itself is never modified.

    Returns
    -------
    The fitted copy of ``clf``, or ``None`` (with a warning) if fitting raised.
    """
    import copy
    import warnings

    feature_size = X_train.shape[-1]
    # int(): np.prod(()) is the float 1.0, which reshape rejects
    sample_size = int(np.prod(X_train.shape[:-1]))
    # flatten the sample dims: (samples, features) and (samples,)
    X_flat = X_train.reshape(sample_size, feature_size)
    y_flat = y_train.reshape(sample_size)
    # np.vectorize hands the SAME clf object to every call and fit() returns self,
    # so fitting clf directly would leave every grid point holding one estimator
    # fitted on the last grid point. Fit a private copy instead.
    model = copy.deepcopy(clf)
    try:
        return model.fit(X_flat, y_flat)
    except Exception as err:  # best effort: e.g. a single class at this grid point
        warnings.warn(f'fit failed ({type(err).__name__}: {err}); returning None')
        return None
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 93 ms, sys: 2.35 ms, total: 95.4 ms\n", | |
| "Wall time: 103 ms\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%time\n", | |
# one classifier per remaining grid point: sample dims and 'variable' are consumed
training_features = X_train[features].to_array().transpose(..., 'variable')  # variable last
training_target = y_train[target_var]
all_classifiers = xr.apply_ufunc(
    atomic_function_training_LR,
    training_features,
    training_target,
    clf,
    input_core_dims=[sample_dims + ['variable'], sample_dims, []],  # add variable if needed
    vectorize=True,
    dask='parallelized',
    output_dtypes=[object],
).compute()
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## predict" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 1.54 s, sys: 29.9 ms, total: 1.57 s\n", | |
| "Wall time: 1.68 s\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<pre><xarray.DataArray (lead_time: 2, X: 5, Y: 5, week: 53, category: 3)>\n", | |
| "array([[[[[2.84430615e-03, 2.48019553e-01, 7.49136141e-01],\n", | |
| " [1.59098678e-02, 4.24464371e-01, 5.59625761e-01],\n", | |
| " [6.43132743e-01, 3.47622040e-01, 9.24521704e-03],\n", | |
| " ...,\n", | |
| " [2.88347752e-02, 2.97344648e-01, 6.73820577e-01],\n", | |
| " [6.99026288e-01, 2.92805754e-01, 8.16795839e-03],\n", | |
| " [6.84561350e-01, 3.09447865e-01, 5.99078411e-03]],\n", | |
| "\n", | |
| " [[3.14927916e-01, 5.73638100e-01, 1.11433984e-01],\n", | |
| " [1.82388659e-03, 2.94081405e-01, 7.04094708e-01],\n", | |
| " [3.03359651e-01, 5.03803615e-01, 1.92836733e-01],\n", | |
| " ...,\n", | |
| " [1.25277388e-03, 2.33315838e-01, 7.65431388e-01],\n", | |
| " [9.91532734e-02, 5.16850748e-01, 3.83995978e-01],\n", | |
| " [7.18120246e-01, 2.78537920e-01, 3.34183389e-03]],\n", | |
| "\n", | |
| " [[1.22389131e-02, 2.89980880e-01, 6.97780207e-01],\n", | |
| " [7.48693273e-01, 2.49804366e-01, 1.50236071e-03],\n", | |
| " [7.40388030e-01, 2.56002389e-01, 3.60958146e-03],\n", | |
| " ...,\n", | |
| "...\n", | |
| " ...,\n", | |
| " [1.59798246e-03, 2.40037927e-01, 7.58364091e-01],\n", | |
| " [5.78522467e-03, 2.42150429e-01, 7.52064346e-01],\n", | |
| " [1.93062439e-01, 4.20348759e-01, 3.86588802e-01]],\n", | |
| "\n", | |
| " [[4.88716204e-03, 2.38253500e-01, 7.56859338e-01],\n", | |
| " [3.58386812e-01, 5.16826179e-01, 1.24787008e-01],\n", | |
| " [6.64768310e-01, 3.30687396e-01, 4.54429314e-03],\n", | |
| " ...,\n", | |
| " [5.13090849e-02, 4.65301898e-01, 4.83389018e-01],\n", | |
| " [6.93597202e-01, 3.05439592e-01, 9.63206169e-04],\n", | |
| " [6.90890885e-01, 3.07015807e-01, 2.09330822e-03]],\n", | |
| "\n", | |
| " [[8.97167636e-03, 3.26590274e-01, 6.64438049e-01],\n", | |
| " [7.34400789e-01, 2.62958955e-01, 2.64025675e-03],\n", | |
| " [1.81597565e-03, 2.58011987e-01, 7.40172038e-01],\n", | |
| " ...,\n", | |
| " [4.70693953e-01, 4.48879066e-01, 8.04269809e-02],\n", | |
| " [3.60043640e-03, 2.39024250e-01, 7.57375314e-01],\n", | |
| " [6.55709623e-01, 3.37543667e-01, 6.74671012e-03]]]]])\n", | |
| "Coordinates:\n", | |
| " * lead_time (lead_time) int64 1 2\n", | |
| " * X (X) int64 0 1 2 3 4\n", | |
| " * Y (Y) int64 0 1 2 3 4\n", | |
| " * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n", | |
| " * category (category) float64 0.0 1.0 2.0</pre>" | |
| ], | |
| "text/plain": [ | |
| "<xarray.DataArray (lead_time: 2, X: 5, Y: 5, week: 53, category: 3)>\n", | |
| "array([[[[[2.84430615e-03, 2.48019553e-01, 7.49136141e-01],\n", | |
| " [1.59098678e-02, 4.24464371e-01, 5.59625761e-01],\n", | |
| " [6.43132743e-01, 3.47622040e-01, 9.24521704e-03],\n", | |
| " ...,\n", | |
| " [2.88347752e-02, 2.97344648e-01, 6.73820577e-01],\n", | |
| " [6.99026288e-01, 2.92805754e-01, 8.16795839e-03],\n", | |
| " [6.84561350e-01, 3.09447865e-01, 5.99078411e-03]],\n", | |
| "\n", | |
| " [[3.14927916e-01, 5.73638100e-01, 1.11433984e-01],\n", | |
| " [1.82388659e-03, 2.94081405e-01, 7.04094708e-01],\n", | |
| " [3.03359651e-01, 5.03803615e-01, 1.92836733e-01],\n", | |
| " ...,\n", | |
| " [1.25277388e-03, 2.33315838e-01, 7.65431388e-01],\n", | |
| " [9.91532734e-02, 5.16850748e-01, 3.83995978e-01],\n", | |
| " [7.18120246e-01, 2.78537920e-01, 3.34183389e-03]],\n", | |
| "\n", | |
| " [[1.22389131e-02, 2.89980880e-01, 6.97780207e-01],\n", | |
| " [7.48693273e-01, 2.49804366e-01, 1.50236071e-03],\n", | |
| " [7.40388030e-01, 2.56002389e-01, 3.60958146e-03],\n", | |
| " ...,\n", | |
| "...\n", | |
| " ...,\n", | |
| " [1.59798246e-03, 2.40037927e-01, 7.58364091e-01],\n", | |
| " [5.78522467e-03, 2.42150429e-01, 7.52064346e-01],\n", | |
| " [1.93062439e-01, 4.20348759e-01, 3.86588802e-01]],\n", | |
| "\n", | |
| " [[4.88716204e-03, 2.38253500e-01, 7.56859338e-01],\n", | |
| " [3.58386812e-01, 5.16826179e-01, 1.24787008e-01],\n", | |
| " [6.64768310e-01, 3.30687396e-01, 4.54429314e-03],\n", | |
| " ...,\n", | |
| " [5.13090849e-02, 4.65301898e-01, 4.83389018e-01],\n", | |
| " [6.93597202e-01, 3.05439592e-01, 9.63206169e-04],\n", | |
| " [6.90890885e-01, 3.07015807e-01, 2.09330822e-03]],\n", | |
| "\n", | |
| " [[8.97167636e-03, 3.26590274e-01, 6.64438049e-01],\n", | |
| " [7.34400789e-01, 2.62958955e-01, 2.64025675e-03],\n", | |
| " [1.81597565e-03, 2.58011987e-01, 7.40172038e-01],\n", | |
| " ...,\n", | |
| " [4.70693953e-01, 4.48879066e-01, 8.04269809e-02],\n", | |
| " [3.60043640e-03, 2.39024250e-01, 7.57375314e-01],\n", | |
| " [6.55709623e-01, 3.37543667e-01, 6.74671012e-03]]]]])\n", | |
| "Coordinates:\n", | |
| " * lead_time (lead_time) int64 1 2\n", | |
| " * X (X) int64 0 1 2 3 4\n", | |
| " * Y (Y) int64 0 1 2 3 4\n", | |
| " * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n", | |
| " * category (category) float64 0.0 1.0 2.0" | |
| ] | |
| }, | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%%time\n", | |
def atomic_function_prediction_lr(classifiers, X_test, n_classes=None):
    """Return class probabilities from one grid point's classifier.

    Called once per grid point by ``xr.apply_ufunc(..., vectorize=True)``.

    Parameters
    ----------
    classifiers : estimator or None
        Fitted model exposing ``predict_proba`` and ``classes_``; ``None`` when
        training failed at this grid point.
    X_test : array-like
        Features with the feature (``variable``) axis last. A 1-D array is
        treated as a single sample; with more than two axes, the leading axes
        are flattened into samples.
    n_classes : int, optional
        Number of categories for the climatology fallback, used only when
        ``classifiers`` has no ``classes_`` (e.g. it is ``None``).

    Returns
    -------
    np.ndarray
        1-D probabilities (one per class) for the FIRST sample of ``X_test``
        only. No coordinates are attached because ``vectorize=True`` discards
        them; label the ``category`` dim after ``apply_ufunc``. If prediction
        fails, a uniform climatology ``1/n`` is returned instead.

    Raises
    ------
    Exception
        The original prediction error, when prediction fails and the number of
        classes for the fallback cannot be determined.
    """
    import warnings

    X_test = np.asarray(X_test)
    try:
        if X_test.ndim != 2:
            feature_size = X_test.shape[-1]
            # int(): np.prod(()) is the float 1.0, which reshape rejects
            sample_size = int(np.prod(X_test.shape[:-1]))
            X_test = X_test.reshape(sample_size, feature_size)
        # NOTE: only the first sample is kept; to predict several samples, make them a
        # loop dimension of apply_ufunc so each call sees exactly one.
        return np.asarray(classifiers.predict_proba(X_test)[0])
    except Exception as err:  # set climatology instead
        classes = getattr(classifiers, 'classes_', None)
        n = len(classes) if classes is not None else n_classes
        if not n:
            raise  # cannot size the fallback; surface the original error
        warnings.warn(f'{type(err).__name__}: {err}; using climatology')
        return np.full(n, 1.0 / n)
| "\n", | |
# Feed one test sample per call: with `year` as a core dim the function only returns
# the first year's probabilities and the other test years are dropped. A length-1
# `sample` core dim keeps each call's input 2-D (samples, features), while `year`
# becomes a loop dimension, so every test year gets its own prediction.
X_test_features = (X_test[features].to_array()
                   .transpose(..., 'variable')
                   .expand_dims('sample'))

predictions = xr.apply_ufunc(atomic_function_prediction_lr,
                             all_classifiers,
                             X_test_features,
                             input_core_dims=[[], ['sample', 'variable']],
                             vectorize=True,
                             dask='parallelized',
                             output_core_dims=[['category']],  # new dim for predict_proba
                             output_dtypes=[float],
                             ).compute()

# vectorize drops coords created inside the function, so label categories here.
# Use the first successfully fitted classifier (training returns None on failure);
# assumes every grid point was fitted on the same set of classes.
fitted_classifiers = [c for c in all_classifiers.values.ravel() if c is not None]
predictions = predictions.assign_coords(category=fitted_classifiers[0].classes_)
predictions
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "xr", | |
| "language": "python", | |
| "name": "xr" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.7.8" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment