Created
March 9, 2021 18:23
-
-
Save ottonemo/f8771c3a7f0f6abf6afb8ae157b673ba to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| [[package]] | |
| name = "atomicwrites" | |
| version = "1.4.0" | |
| description = "Atomic file writes." | |
| category = "dev" | |
| optional = false | |
| python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" | |
| [[package]] | |
| name = "attrs" | |
| version = "20.3.0" | |
| description = "Classes Without Boilerplate" | |
| category = "dev" | |
| optional = false | |
| python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" | |
| [package.extras] | |
| dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "furo", "sphinx", "pre-commit"] | |
| docs = ["furo", "sphinx", "zope.interface"] | |
| tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"] | |
| tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six"] | |
| [[package]] | |
| name = "colorama" | |
| version = "0.4.4" | |
| description = "Cross-platform colored terminal text." | |
| category = "dev" | |
| optional = false | |
| python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" | |
| [[package]] | |
| name = "more-itertools" | |
| version = "8.7.0" | |
| description = "More routines for operating on iterables, beyond itertools" | |
| category = "dev" | |
| optional = false | |
| python-versions = ">=3.5" | |
| [[package]] | |
| name = "packaging" | |
| version = "20.9" | |
| description = "Core utilities for Python packages" | |
| category = "dev" | |
| optional = false | |
| python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" | |
| [package.dependencies] | |
| pyparsing = ">=2.0.2" | |
| [[package]] | |
| name = "pluggy" | |
| version = "0.13.1" | |
| description = "plugin and hook calling mechanisms for python" | |
| category = "dev" | |
| optional = false | |
| python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" | |
| [package.extras] | |
| dev = ["pre-commit", "tox"] | |
| [[package]] | |
| name = "py" | |
| version = "1.10.0" | |
| description = "library with cross-python path, ini-parsing, io, code, log facilities" | |
| category = "dev" | |
| optional = false | |
| python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" | |
| [[package]] | |
| name = "pyparsing" | |
| version = "2.4.7" | |
| description = "Python parsing module" | |
| category = "dev" | |
| optional = false | |
| python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" | |
| [[package]] | |
| name = "pytest" | |
| version = "5.4.3" | |
| description = "pytest: simple powerful testing with Python" | |
| category = "dev" | |
| optional = false | |
| python-versions = ">=3.5" | |
| [package.dependencies] | |
| atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} | |
| attrs = ">=17.4.0" | |
| colorama = {version = "*", markers = "sys_platform == \"win32\""} | |
| more-itertools = ">=4.0.0" | |
| packaging = "*" | |
| pluggy = ">=0.12,<1.0" | |
| py = ">=1.5.0" | |
| wcwidth = "*" | |
| [package.extras] | |
| checkqa-mypy = ["mypy (==v0.761)"] | |
| testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] | |
| [[package]] | |
| name = "wcwidth" | |
| version = "0.2.5" | |
| description = "Measures the displayed width of unicode strings in a terminal" | |
| category = "dev" | |
| optional = false | |
| python-versions = "*" | |
| [metadata] | |
| lock-version = "1.1" | |
| python-versions = "^3.8" | |
| content-hash = "c27944f25b55067b06883f1cea204be7d97841a4b8228fab69b91895347494ad" | |
| [metadata.files] | |
| atomicwrites = [ | |
| {file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"}, | |
| {file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"}, | |
| ] | |
| attrs = [ | |
| {file = "attrs-20.3.0-py2.py3-none-any.whl", hash = "sha256:31b2eced602aa8423c2aea9c76a724617ed67cf9513173fd3a4f03e3a929c7e6"}, | |
| {file = "attrs-20.3.0.tar.gz", hash = "sha256:832aa3cde19744e49938b91fea06d69ecb9e649c93ba974535d08ad92164f700"}, | |
| ] | |
| colorama = [ | |
| {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, | |
| {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, | |
| ] | |
| more-itertools = [ | |
| {file = "more-itertools-8.7.0.tar.gz", hash = "sha256:c5d6da9ca3ff65220c3bfd2a8db06d698f05d4d2b9be57e1deb2be5a45019713"}, | |
| {file = "more_itertools-8.7.0-py3-none-any.whl", hash = "sha256:5652a9ac72209ed7df8d9c15daf4e1aa0e3d2ccd3c87f8265a0673cd9cbc9ced"}, | |
| ] | |
| packaging = [ | |
| {file = "packaging-20.9-py2.py3-none-any.whl", hash = "sha256:67714da7f7bc052e064859c05c595155bd1ee9f69f76557e21f051443c20947a"}, | |
| {file = "packaging-20.9.tar.gz", hash = "sha256:5b327ac1320dc863dca72f4514ecc086f31186744b84a230374cc1fd776feae5"}, | |
| ] | |
| pluggy = [ | |
| {file = "pluggy-0.13.1-py2.py3-none-any.whl", hash = "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d"}, | |
| {file = "pluggy-0.13.1.tar.gz", hash = "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0"}, | |
| ] | |
| py = [ | |
| {file = "py-1.10.0-py2.py3-none-any.whl", hash = "sha256:3b80836aa6d1feeaa108e046da6423ab8f6ceda6468545ae8d02d9d58d18818a"}, | |
| {file = "py-1.10.0.tar.gz", hash = "sha256:21b81bda15b66ef5e1a777a21c4dcd9c20ad3efd0b3f817e7a809035269e1bd3"}, | |
| ] | |
| pyparsing = [ | |
| {file = "pyparsing-2.4.7-py2.py3-none-any.whl", hash = "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"}, | |
| {file = "pyparsing-2.4.7.tar.gz", hash = "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1"}, | |
| ] | |
| pytest = [ | |
| {file = "pytest-5.4.3-py3-none-any.whl", hash = "sha256:5c0db86b698e8f170ba4582a492248919255fcd4c79b1ee64ace34301fb589a1"}, | |
| {file = "pytest-5.4.3.tar.gz", hash = "sha256:7979331bfcba207414f5e1263b5a0f8f521d0f457318836a7355531ed1a4c7d8"}, | |
| ] | |
| wcwidth = [ | |
| {file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"}, | |
| {file = "wcwidth-0.2.5.tar.gz", hash = "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83"}, | |
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Example from https://github.com/dreamquark-ai/tabnet/blob/develop/census_example.ipynb" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# tabnet example code" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from sklearn.preprocessing import LabelEncoder\n", | |
| "from sklearn.metrics import roc_auc_score\n", | |
| "\n", | |
| "from pytorch_tabnet.tab_model import TabNetClassifier\n", | |
| "\n", | |
| "import torch\n", | |
| "import pandas as pd\n", | |
| "import numpy as np\n", | |
| "np.random.seed(0)\n", | |
| "\n", | |
| "import os\n", | |
| "import wget\n", | |
| "from pathlib import Path\n", | |
| "\n", | |
| "from matplotlib import pyplot as plt\n", | |
| "%matplotlib inline" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\"\n", | |
| "dataset_name = 'census-income'\n", | |
| "out = Path(os.getcwd()+'/data/'+dataset_name+'.csv')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "File already exists.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "out.parent.mkdir(parents=True, exist_ok=True)\n", | |
| "if out.exists():\n", | |
| " print(\"File already exists.\")\n", | |
| "else:\n", | |
| " print(\"Downloading file...\")\n", | |
| " wget.download(url, out.as_posix())" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "train = pd.read_csv(out)\n", | |
| "target = ' <=50K'\n", | |
| "if \"Set\" not in train.columns:\n", | |
| " train[\"Set\"] = np.random.choice([\"train\", \"valid\", \"test\"], p =[.8, .1, .1], size=(train.shape[0],))\n", | |
| "\n", | |
| "train_indices = train[train.Set==\"train\"].index\n", | |
| "valid_indices = train[train.Set==\"valid\"].index\n", | |
| "test_indices = train[train.Set==\"test\"].index" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "39 73\n", | |
| " State-gov 9\n", | |
| " Bachelors 16\n", | |
| " 13 16\n", | |
| " Never-married 7\n", | |
| " Adm-clerical 15\n", | |
| " Not-in-family 6\n", | |
| " White 5\n", | |
| " Male 2\n", | |
| " 2174 119\n", | |
| " 0 92\n", | |
| " 40 94\n", | |
| " United-States 42\n", | |
| " <=50K 2\n", | |
| "Set 3\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "nunique = train.nunique()\n", | |
| "types = train.dtypes\n", | |
| "\n", | |
| "categorical_columns = []\n", | |
| "categorical_dims = {}\n", | |
| "for col in train.columns:\n", | |
| " if types[col] == 'object' or nunique[col] < 200:\n", | |
| " print(col, train[col].nunique())\n", | |
| " l_enc = LabelEncoder()\n", | |
| " train[col] = train[col].fillna(\"VV_likely\")\n", | |
| " train[col] = l_enc.fit_transform(train[col].values)\n", | |
| " categorical_columns.append(col)\n", | |
| " categorical_dims[col] = len(l_enc.classes_)\n", | |
| " else:\n", | |
| " train.fillna(train.loc[train_indices, col].mean(), inplace=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# check that pipeline accepts strings\n", | |
| "train.loc[train[target]==0, target] = \"wealthy\"\n", | |
| "train.loc[train[target]==1, target] = \"not_wealthy\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "unused_feat = ['Set']\n", | |
| "\n", | |
| "features = [ col for col in train.columns if col not in unused_feat+[target]] \n", | |
| "\n", | |
| "cat_idxs = [ i for i, f in enumerate(features) if f in categorical_columns]\n", | |
| "\n", | |
| "cat_dims = [ categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Device used : cuda\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "clf = TabNetClassifier(cat_idxs=cat_idxs,\n", | |
| " cat_dims=cat_dims,\n", | |
| " cat_emb_dim=1,\n", | |
| " optimizer_fn=torch.optim.Adam,\n", | |
| " optimizer_params=dict(lr=2e-2),\n", | |
| " scheduler_params={\"step_size\":50, # how to use learning rate scheduler\n", | |
| " \"gamma\":0.9},\n", | |
| " scheduler_fn=torch.optim.lr_scheduler.StepLR,\n", | |
| " mask_type='entmax' # \"sparsemax\"\n", | |
| " )" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "X_train = train[features].values[train_indices]\n", | |
| "y_train = train[target].values[train_indices]\n", | |
| "\n", | |
| "X_valid = train[features].values[valid_indices]\n", | |
| "y_valid = train[target].values[valid_indices]\n", | |
| "\n", | |
| "X_test = train[features].values[test_indices]\n", | |
| "y_test = train[target].values[test_indices]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "max_epochs = 1000 if not os.getenv(\"CI\", False) else 2" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "epoch 0 | loss: 0.668 | train_auc: 0.75705 | valid_auc: 0.7551 | 0:00:02s\n", | |
| "epoch 1 | loss: 0.52031 | train_auc: 0.81912 | valid_auc: 0.82696 | 0:00:04s\n", | |
| "epoch 2 | loss: 0.47527 | train_auc: 0.84816 | valid_auc: 0.85195 | 0:00:06s\n", | |
| "epoch 3 | loss: 0.45715 | train_auc: 0.86756 | valid_auc: 0.86571 | 0:00:08s\n", | |
| "epoch 4 | loss: 0.43029 | train_auc: 0.88064 | valid_auc: 0.87487 | 0:00:10s\n", | |
| "epoch 5 | loss: 0.41997 | train_auc: 0.89128 | valid_auc: 0.8849 | 0:00:12s\n", | |
| "epoch 6 | loss: 0.40586 | train_auc: 0.898 | valid_auc: 0.88995 | 0:00:14s\n", | |
| "epoch 7 | loss: 0.40141 | train_auc: 0.90266 | valid_auc: 0.89769 | 0:00:16s\n", | |
| "epoch 8 | loss: 0.39187 | train_auc: 0.90459 | valid_auc: 0.8956 | 0:00:18s\n", | |
| "epoch 9 | loss: 0.37791 | train_auc: 0.91019 | valid_auc: 0.90593 | 0:00:20s\n", | |
| "epoch 10 | loss: 0.37631 | train_auc: 0.91394 | valid_auc: 0.90945 | 0:00:22s\n", | |
| "epoch 11 | loss: 0.36412 | train_auc: 0.91093 | valid_auc: 0.90707 | 0:00:24s\n", | |
| "epoch 12 | loss: 0.3587 | train_auc: 0.91243 | valid_auc: 0.90965 | 0:00:26s\n", | |
| "epoch 13 | loss: 0.35557 | train_auc: 0.915 | valid_auc: 0.90905 | 0:00:29s\n", | |
| "epoch 14 | loss: 0.34672 | train_auc: 0.9182 | valid_auc: 0.91487 | 0:00:31s\n", | |
| "epoch 15 | loss: 0.35145 | train_auc: 0.92211 | valid_auc: 0.91805 | 0:00:33s\n", | |
| "epoch 16 | loss: 0.34199 | train_auc: 0.92471 | valid_auc: 0.92013 | 0:00:35s\n", | |
| "epoch 17 | loss: 0.3372 | train_auc: 0.9272 | valid_auc: 0.92226 | 0:00:37s\n", | |
| "epoch 18 | loss: 0.34344 | train_auc: 0.92886 | valid_auc: 0.92452 | 0:00:39s\n", | |
| "epoch 19 | loss: 0.34549 | train_auc: 0.92919 | valid_auc: 0.92233 | 0:00:41s\n", | |
| "epoch 20 | loss: 0.33269 | train_auc: 0.93105 | valid_auc: 0.92654 | 0:00:43s\n", | |
| "epoch 21 | loss: 0.32923 | train_auc: 0.93199 | valid_auc: 0.92505 | 0:00:45s\n", | |
| "epoch 22 | loss: 0.33069 | train_auc: 0.93208 | valid_auc: 0.92693 | 0:00:47s\n", | |
| "epoch 23 | loss: 0.3301 | train_auc: 0.93287 | valid_auc: 0.92766 | 0:00:49s\n", | |
| "epoch 24 | loss: 0.33326 | train_auc: 0.93347 | valid_auc: 0.92745 | 0:00:51s\n", | |
| "epoch 25 | loss: 0.32665 | train_auc: 0.93452 | valid_auc: 0.92802 | 0:00:53s\n", | |
| "epoch 26 | loss: 0.32089 | train_auc: 0.93444 | valid_auc: 0.92747 | 0:00:55s\n", | |
| "epoch 27 | loss: 0.32657 | train_auc: 0.93284 | valid_auc: 0.92749 | 0:00:57s\n", | |
| "epoch 28 | loss: 0.32863 | train_auc: 0.93331 | valid_auc: 0.92529 | 0:00:59s\n", | |
| "epoch 29 | loss: 0.32456 | train_auc: 0.93459 | valid_auc: 0.92775 | 0:01:01s\n", | |
| "epoch 30 | loss: 0.3245 | train_auc: 0.93506 | valid_auc: 0.92776 | 0:01:03s\n", | |
| "epoch 31 | loss: 0.31973 | train_auc: 0.93558 | valid_auc: 0.92732 | 0:01:05s\n", | |
| "epoch 32 | loss: 0.32807 | train_auc: 0.9334 | valid_auc: 0.92574 | 0:01:07s\n", | |
| "epoch 33 | loss: 0.32806 | train_auc: 0.93508 | valid_auc: 0.92774 | 0:01:09s\n", | |
| "epoch 34 | loss: 0.31981 | train_auc: 0.93656 | valid_auc: 0.93014 | 0:01:11s\n", | |
| "epoch 35 | loss: 0.31738 | train_auc: 0.93678 | valid_auc: 0.92766 | 0:01:13s\n", | |
| "epoch 36 | loss: 0.3209 | train_auc: 0.93637 | valid_auc: 0.92766 | 0:01:15s\n", | |
| "epoch 37 | loss: 0.31531 | train_auc: 0.93336 | valid_auc: 0.92297 | 0:01:17s\n", | |
| "epoch 38 | loss: 0.3231 | train_auc: 0.93368 | valid_auc: 0.92438 | 0:01:19s\n", | |
| "epoch 39 | loss: 0.31914 | train_auc: 0.93741 | valid_auc: 0.92685 | 0:01:22s\n", | |
| "epoch 40 | loss: 0.31784 | train_auc: 0.93709 | valid_auc: 0.92647 | 0:01:24s\n", | |
| "epoch 41 | loss: 0.32154 | train_auc: 0.93775 | valid_auc: 0.92521 | 0:01:26s\n", | |
| "epoch 42 | loss: 0.31726 | train_auc: 0.93814 | valid_auc: 0.92743 | 0:01:28s\n", | |
| "epoch 43 | loss: 0.31768 | train_auc: 0.93822 | valid_auc: 0.9265 | 0:01:30s\n", | |
| "epoch 44 | loss: 0.31297 | train_auc: 0.93664 | valid_auc: 0.92333 | 0:01:32s\n", | |
| "epoch 45 | loss: 0.31219 | train_auc: 0.93833 | valid_auc: 0.92682 | 0:01:34s\n", | |
| "epoch 46 | loss: 0.31816 | train_auc: 0.93877 | valid_auc: 0.92526 | 0:01:36s\n", | |
| "epoch 47 | loss: 0.3168 | train_auc: 0.93903 | valid_auc: 0.92521 | 0:01:38s\n", | |
| "epoch 48 | loss: 0.31014 | train_auc: 0.93864 | valid_auc: 0.92364 | 0:01:40s\n", | |
| "epoch 49 | loss: 0.31637 | train_auc: 0.93793 | valid_auc: 0.92628 | 0:01:42s\n", | |
| "epoch 50 | loss: 0.31441 | train_auc: 0.9398 | valid_auc: 0.92782 | 0:01:44s\n", | |
| "epoch 51 | loss: 0.30673 | train_auc: 0.94062 | valid_auc: 0.92624 | 0:01:46s\n", | |
| "epoch 52 | loss: 0.30835 | train_auc: 0.94006 | valid_auc: 0.92509 | 0:01:48s\n", | |
| "epoch 53 | loss: 0.30838 | train_auc: 0.94081 | valid_auc: 0.92882 | 0:01:50s\n", | |
| "epoch 54 | loss: 0.31133 | train_auc: 0.94049 | valid_auc: 0.92622 | 0:01:52s\n", | |
| "\n", | |
| "Early stopping occurred at epoch 54 with best_epoch = 34 and best_valid_auc = 0.93014\n", | |
| "Best weights from best epoch are automatically used!\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "clf.fit(\n", | |
| " X_train=X_train, y_train=y_train,\n", | |
| " eval_set=[(X_train, y_train), (X_valid, y_valid)],\n", | |
| " eval_name=['train', 'valid'],\n", | |
| " eval_metric=['auc'],\n", | |
| " max_epochs=max_epochs , patience=20,\n", | |
| " batch_size=1024, virtual_batch_size=128,\n", | |
| " num_workers=0,\n", | |
| " weights=1,\n", | |
| " drop_last=False\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "raw", | |
| "metadata": {}, | |
| "source": [ | |
| "y_pred = clf.predict(X_train)\n", | |
| "y_pred_enc = label_encoder.transform(y_pred)\n", | |
| "roc_auc_score(y_train_enc, y_pred_enc)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "---" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# skorch tabnet port" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 61, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import sklearn\n", | |
| "\n", | |
| "import skorch\n", | |
| "from skorch.helper import predefined_split\n", | |
| "\n", | |
| "import pytorch_tabnet\n", | |
| "from pytorch_tabnet.multiclass_utils import infer_output_dim\n", | |
| "from pytorch_tabnet.tab_network import TabNet\n", | |
| "from pytorch_tabnet.utils import (\n", | |
| " PredictDataset,\n", | |
| " create_explain_matrix,\n", | |
| " validate_eval_set,\n", | |
| " create_dataloaders,\n", | |
| " define_device,\n", | |
| " ComplexEncoder,\n", | |
| ")\n", | |
| "\n", | |
| "from torch.nn import CrossEntropyLoss\n", | |
| "\n", | |
| "from scipy.sparse import csc_matrix" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 62, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class SkorchTabModel(skorch.NeuralNet):\n", | |
| " def __init__(\n", | |
| " self,\n", | |
| " criterion,\n", | |
| " module=TabNet, \n", | |
| " module__input_dim=100, \n", | |
| " module__output_dim=5,\n", | |
| " **kwargs,\n", | |
| " ):\n", | |
| " super().__init__(\n", | |
| " module,\n", | |
| " criterion,\n", | |
| " module__input_dim=module__input_dim,\n", | |
| " module__output_dim=module__output_dim,\n", | |
| " **kwargs,\n", | |
| " )\n", | |
| " \n", | |
| " def initialize_module(self):\n", | |
| " \"\"\"Setup the network and explain matrix.\"\"\"\n", | |
| " kwargs = self.get_params_for('module')\n", | |
| "\n", | |
| " self.module_ = TabNet(**kwargs).to(self.device)\n", | |
| "\n", | |
| " self.reducing_matrix_ = create_explain_matrix(\n", | |
| " self.module_.input_dim,\n", | |
| " self.module_.cat_emb_dim,\n", | |
| " self.module_.cat_idxs,\n", | |
| " self.module_.post_embed_dim,\n", | |
| " )\n", | |
| " \n", | |
| " def compute_feature_importances(self, X):\n", | |
| " \"\"\"Compute global feature importance.\"\"\" \n", | |
| " feature_importances_ = np.zeros((self.module_.post_embed_dim))\n", | |
| " \n", | |
| " for (M_explain, masks) in self.forward_masks_iter(X):\n", | |
| " feature_importances_ += M_explain.sum(dim=0).cpu().detach().numpy()\n", | |
| "\n", | |
| " feature_importances_ = csc_matrix.dot(\n", | |
| " feature_importances_, self.reducing_matrix_,\n", | |
| " )\n", | |
| " return feature_importances_ / np.sum(feature_importances_)\n", | |
| " \n", | |
| " def on_train_end(self, net, X, **kwargs):\n", | |
| " self.feature_importances_ = self.compute_feature_importances(X)\n", | |
| " super().on_train_end(net, X=X, **kwargs)\n", | |
| " \n", | |
| " def forward_masks_iter(self, X, training=False, device='cpu'):\n", | |
| " dataset = self.get_dataset(X)\n", | |
| " iterator = self.get_iterator(dataset, training=training)\n", | |
| " for data in iterator:\n", | |
| " Xi = skorch.dataset.unpack_data(data)[0]\n", | |
| " Xi = skorch.utils.to_device(Xi, self.device)\n", | |
| " with torch.set_grad_enabled(False):\n", | |
| " yp = self.module_.forward_masks(Xi)\n", | |
| " yield skorch.utils.to_device(yp, device=device)\n", | |
| " \n", | |
| " def explain(self, X):\n", | |
| " \"\"\"\n", | |
| " Return local explanation\n", | |
| "\n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " X : tensor: `torch.Tensor`\n", | |
| " Input data\n", | |
| "\n", | |
| " Returns\n", | |
| " -------\n", | |
| " M_explain : matrix\n", | |
| " Importance per sample, per columns.\n", | |
| " masks : matrix\n", | |
| " Sparse matrix showing attention masks used by network.\n", | |
| " \"\"\"\n", | |
| " res_explain = []\n", | |
| " \n", | |
| " for i, (M_explain, masks) in enumerate(self.forward_masks_iter(X)):\n", | |
| " for key, value in masks.items():\n", | |
| " masks[key] = csc_matrix.dot(\n", | |
| " value.cpu().detach().numpy(), self.reducing_matrix_\n", | |
| " )\n", | |
| "\n", | |
| " res_explain.append(\n", | |
| " csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix_)\n", | |
| " )\n", | |
| "\n", | |
| " if i == 0:\n", | |
| " res_masks = masks\n", | |
| " else:\n", | |
| " for key, value in masks.items():\n", | |
| " res_masks[key] = np.vstack([res_masks[key], value])\n", | |
| " \n", | |
| " res_explain = np.vstack(res_explain)\n", | |
| " return res_explain, res_masks\n", | |
| " \n", | |
| " def predict(self, X):\n", | |
| " y_proba = self.predict_proba(X)\n", | |
| " return y_proba.argmax(-1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 63, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "LabelEncoder()" | |
| ] | |
| }, | |
| "execution_count": 63, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "label_encoder = LabelEncoder()\n", | |
| "label_encoder.fit(y_train)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 64, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "y_train_enc = label_encoder.transform(y_train)\n", | |
| "y_valid_enc = label_encoder.transform(y_valid)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 65, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class CrossEntropySparsityLoss(torch.nn.CrossEntropyLoss):\n", | |
| " def __init__(self, lambda_sparse=1e-3):\n", | |
| " super().__init__()\n", | |
| " self.lambda_sparse = lambda_sparse\n", | |
| " \n", | |
| " def forward(self, y_pred, y_true):\n", | |
| " output, M_loss = y_pred\n", | |
| "\n", | |
| " loss = super().forward(output, y_true)\n", | |
| " \n", | |
| " # Add the overall sparsity loss\n", | |
| " loss -= self.lambda_sparse * M_loss\n", | |
| " \n", | |
| " return loss" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 66, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "skorch_clf = SkorchTabModel(\n", | |
| " criterion=CrossEntropySparsityLoss,\n", | |
| " \n", | |
| " module__input_dim=X_train.shape[-1],\n", | |
| " module__output_dim=infer_output_dim(y_train)[0],\n", | |
| " module__cat_idxs=cat_idxs,\n", | |
| " module__cat_dims=cat_dims,\n", | |
| " module__cat_emb_dim=1,\n", | |
| " module__mask_type='entmax', # \"sparsemax\"\n", | |
| " module__virtual_batch_size=128,\n", | |
| " \n", | |
| " optimizer=torch.optim.Adam,\n", | |
| " optimizer__lr=2e-2,\n", | |
| " \n", | |
| " batch_size=1024,\n", | |
| " iterator_train__num_workers=0,\n", | |
| " iterator_train__drop_last=False,\n", | |
| " iterator_valid__num_workers=0,\n", | |
| " iterator_valid__drop_last=False,\n", | |
| " \n", | |
| " callbacks=[\n", | |
| " skorch.callbacks.LRScheduler(\n", | |
| " policy=torch.optim.lr_scheduler.StepLR,\n", | |
| " step_size=50,\n", | |
| " gamma=0.9,\n", | |
| " ),\n", | |
| " skorch.callbacks.EarlyStopping(patience=20),\n", | |
| " skorch.callbacks.GradientNormClipping(gradient_clip_value=1.),\n", | |
| " skorch.callbacks.EpochScoring('roc_auc'),\n", | |
| " ],\n", | |
| " train_split=predefined_split(skorch.dataset.Dataset(X_valid, y_valid_enc)),\n", | |
| " max_epochs=max_epochs,\n", | |
| " \n", | |
| " device='cuda',\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 67, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "skorch_clf.initialize()\n", | |
| "skorch_clf.load_params('ble.pt')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 68, | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Automatic pdb calling has been turned ON\n", | |
| "Re-initializing optimizer because the following parameters were re-set: lr.\n", | |
| " epoch roc_auc train_loss valid_loss lr dur\n", | |
| "------- --------- ------------ ------------ ------ ------\n", | |
| " 1 \u001b[36m0.7798\u001b[0m \u001b[32m0.5105\u001b[0m \u001b[35m0.6061\u001b[0m 0.0200 1.6014\n", | |
| " 2 0.8398 \u001b[32m0.4094\u001b[0m \u001b[35m0.5071\u001b[0m 0.0200 1.6054\n", | |
| " 3 0.8557 \u001b[32m0.3883\u001b[0m \u001b[35m0.4586\u001b[0m 0.0200 1.5067\n", | |
| " 4 0.8740 \u001b[32m0.3762\u001b[0m \u001b[35m0.4217\u001b[0m 0.0200 1.5827\n", | |
| " 5 0.8743 \u001b[32m0.3664\u001b[0m \u001b[35m0.3860\u001b[0m 0.0200 1.5149\n", | |
| " 6 0.8841 \u001b[32m0.3616\u001b[0m \u001b[35m0.3518\u001b[0m 0.0200 1.5379\n", | |
| " 7 0.8913 \u001b[32m0.3562\u001b[0m \u001b[35m0.3398\u001b[0m 0.0200 1.5068\n", | |
| " 8 0.8944 \u001b[32m0.3425\u001b[0m 0.3400 0.0200 1.5603\n", | |
| " 9 0.8916 \u001b[32m0.3336\u001b[0m 0.4340 0.0200 1.5260\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<class '__main__.SkorchTabModel'>[initialized](\n", | |
| " module_=TabNet(\n", | |
| " (embedder): EmbeddingGenerator(\n", | |
| " (embeddings): ModuleList(\n", | |
| " (0): Embedding(73, 1)\n", | |
| " (1): Embedding(9, 1)\n", | |
| " (2): Embedding(16, 1)\n", | |
| " (3): Embedding(16, 1)\n", | |
| " (4): Embedding(7, 1)\n", | |
| " (5): Embedding(15, 1)\n", | |
| " (6): Embedding(6, 1)\n", | |
| " (7): Embedding(5, 1)\n", | |
| " (8): Embedding(2, 1)\n", | |
| " (9): Embedding(119, 1)\n", | |
| " (10): Embedding(92, 1)\n", | |
| " (11): Embedding(94, 1)\n", | |
| " (12): Embedding(42, 1)\n", | |
| " )\n", | |
| " )\n", | |
| " (tabnet): TabNetNoEmbeddings(\n", | |
| " (initial_bn): BatchNorm1d(14, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)\n", | |
| " (encoder): TabNetEncoder(\n", | |
| " (initial_bn): BatchNorm1d(14, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)\n", | |
| " (initial_splitter): FeatTransformer(\n", | |
| " (shared): GLU_Block(\n", | |
| " (shared_layers): ModuleList(\n", | |
| " (0): Linear(in_features=14, out_features=32, bias=False)\n", | |
| " (1): Linear(in_features=16, out_features=32, bias=False)\n", | |
| " )\n", | |
| " (glu_layers): ModuleList(\n", | |
| " (0): GLU_Layer(\n", | |
| " (fc): Linear(in_features=14, out_features=32, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " )\n", | |
| " (1): GLU_Layer(\n", | |
| " (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " (specifics): GLU_Block(\n", | |
| " (glu_layers): ModuleList(\n", | |
| " (0): GLU_Layer(\n", | |
| " (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " )\n", | |
| " (1): GLU_Layer(\n", | |
| " (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " (feat_transformers): ModuleList(\n", | |
| " (0): FeatTransformer(\n", | |
| " (shared): GLU_Block(\n", | |
| " (shared_layers): ModuleList(\n", | |
| " (0): Linear(in_features=14, out_features=32, bias=False)\n", | |
| " (1): Linear(in_features=16, out_features=32, bias=False)\n", | |
| " )\n", | |
| " (glu_layers): ModuleList(\n", | |
| " (0): GLU_Layer(\n", | |
| " (fc): Linear(in_features=14, out_features=32, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " )\n", | |
| " (1): GLU_Layer(\n", | |
| " (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " (specifics): GLU_Block(\n", | |
| " (glu_layers): ModuleList(\n", | |
| " (0): GLU_Layer(\n", | |
| " (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " )\n", | |
| " (1): GLU_Layer(\n", | |
| " (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " (1): FeatTransformer(\n", | |
| " (shared): GLU_Block(\n", | |
| " (shared_layers): ModuleList(\n", | |
| " (0): Linear(in_features=14, out_features=32, bias=False)\n", | |
| " (1): Linear(in_features=16, out_features=32, bias=False)\n", | |
| " )\n", | |
| " (glu_layers): ModuleList(\n", | |
| " (0): GLU_Layer(\n", | |
| " (fc): Linear(in_features=14, out_features=32, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " )\n", | |
| " (1): GLU_Layer(\n", | |
| " (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " (specifics): GLU_Block(\n", | |
| " (glu_layers): ModuleList(\n", | |
| " (0): GLU_Layer(\n", | |
| " (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " )\n", | |
| " (1): GLU_Layer(\n", | |
| " (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " (2): FeatTransformer(\n", | |
| " (shared): GLU_Block(\n", | |
| " (shared_layers): ModuleList(\n", | |
| " (0): Linear(in_features=14, out_features=32, bias=False)\n", | |
| " (1): Linear(in_features=16, out_features=32, bias=False)\n", | |
| " )\n", | |
| " (glu_layers): ModuleList(\n", | |
| " (0): GLU_Layer(\n", | |
| " (fc): Linear(in_features=14, out_features=32, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " )\n", | |
| " (1): GLU_Layer(\n", | |
| " (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " (specifics): GLU_Block(\n", | |
| " (glu_layers): ModuleList(\n", | |
| " (0): GLU_Layer(\n", | |
| " (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " )\n", | |
| " (1): GLU_Layer(\n", | |
| " (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " (att_transformers): ModuleList(\n", | |
| " (0): AttentiveTransformer(\n", | |
| " (fc): Linear(in_features=8, out_features=14, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(14, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " (selector): Entmax15()\n", | |
| " )\n", | |
| " (1): AttentiveTransformer(\n", | |
| " (fc): Linear(in_features=8, out_features=14, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(14, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " (selector): Entmax15()\n", | |
| " )\n", | |
| " (2): AttentiveTransformer(\n", | |
| " (fc): Linear(in_features=8, out_features=14, bias=False)\n", | |
| " (bn): GBN(\n", | |
| " (bn): BatchNorm1d(14, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
| " )\n", | |
| " (selector): Entmax15()\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " (final_mapping): Linear(in_features=8, out_features=2, bias=False)\n", | |
| " )\n", | |
| " ),\n", | |
| ")" | |
| ] | |
| }, | |
| "execution_count": 68, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%pdb on\n", | |
| "skorch_clf.fit(\n", | |
| " X_train, \n", | |
| " y_train_enc,\n", | |
| " #weights=1,\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 69, | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.8094909975251289" | |
| ] | |
| }, | |
| "execution_count": 69, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "sklearn.metrics.roc_auc_score(y_train_enc, skorch_clf.predict(X_train))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 70, | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.809482523487759" | |
| ] | |
| }, | |
| "execution_count": 70, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "sklearn.metrics.roc_auc_score(y_valid_enc, skorch_clf.predict(X_valid))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 71, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(array([[0.00000000e+00, 3.16206515e-01, 0.00000000e+00, ...,\n", | |
| " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", | |
| " [0.00000000e+00, 4.00576532e-01, 0.00000000e+00, ...,\n", | |
| " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", | |
| " [0.00000000e+00, 4.18307304e-01, 0.00000000e+00, ...,\n", | |
| " 0.00000000e+00, 1.17498322e-03, 0.00000000e+00],\n", | |
| " ...,\n", | |
| " [0.00000000e+00, 3.26685965e-01, 0.00000000e+00, ...,\n", | |
| " 1.23882025e-01, 5.39707905e-03, 0.00000000e+00],\n", | |
| " [1.30514791e-02, 2.11704329e-01, 0.00000000e+00, ...,\n", | |
| " 1.28650689e+00, 3.36281955e-03, 1.37906140e-02],\n", | |
| " [1.27431333e-01, 1.56031281e-01, 1.24305105e-02, ...,\n", | |
| " 2.07836628e-01, 1.33876745e-02, 9.61675271e-02]]),\n", | |
| " {0: array([[0.00000000e+00, 9.32369530e-02, 0.00000000e+00, ...,\n", | |
| " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", | |
| " [0.00000000e+00, 4.57307428e-01, 0.00000000e+00, ...,\n", | |
| " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", | |
| " [0.00000000e+00, 6.30757153e-01, 0.00000000e+00, ...,\n", | |
| " 0.00000000e+00, 1.77173363e-03, 0.00000000e+00],\n", | |
| " ...,\n", | |
| " [0.00000000e+00, 1.68548390e-01, 0.00000000e+00, ...,\n", | |
| " 6.39149472e-02, 2.78453645e-03, 0.00000000e+00],\n", | |
| " [2.71768179e-02, 1.00498842e-02, 0.00000000e+00, ...,\n", | |
| " 1.84161370e-04, 7.00232806e-03, 2.87159029e-02],\n", | |
| " [1.20953389e-01, 1.07856750e-01, 1.24875447e-02, ...,\n", | |
| " 1.23726524e-01, 1.34490998e-02, 9.66087654e-02]]),\n", | |
| " 1: array([[0. , 0. , 0. , ..., 0. , 0. ,\n", | |
| " 0. ],\n", | |
| " [0. , 0. , 0. , ..., 0. , 0. ,\n", | |
| " 0. ],\n", | |
| " [0. , 0. , 0. , ..., 0. , 0. ,\n", | |
| " 0. ],\n", | |
| " ...,\n", | |
| " [0. , 0. , 0. , ..., 0. , 0. ,\n", | |
| " 0. ],\n", | |
| " [0. , 0.02491164, 0. , ..., 0.15490678, 0. ,\n", | |
| " 0. ],\n", | |
| " [0.01362148, 0.09429348, 0. , ..., 0.1640597 , 0. ,\n", | |
| " 0. ]]),\n", | |
| " 2: array([[0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [0., 0., 0., ..., 0., 0., 0.],\n", | |
| " ...,\n", | |
| " [0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [0., 0., 0., ..., 0., 0., 0.]])})" | |
| ] | |
| }, | |
| "execution_count": 71, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "skorch_clf.explain(X_valid)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "raw", | |
| "metadata": {}, | |
| "source": [ | |
| "skorch_clf.save_params(f_params='ble.pt')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.8.6" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |