Created
September 8, 2021 08:25
-
-
Save mlaves/c98cd4e6bcb9dbd4d0c03b34bacb0f65 to your computer and use it in GitHub Desktop.
Kernel SVM.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 5, | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.8.8" | |
| }, | |
| "colab": { | |
| "name": "Kernel SVM.ipynb", | |
| "provenance": [], | |
| "include_colab_link": true | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/mlaves/c98cd4e6bcb9dbd4d0c03b34bacb0f65/kernel-svm.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "db32bdaf" | |
| }, | |
| "source": [ | |
| "# Kernel Support-Vector Machine in PyTorch\n", | |
| "\n", | |
| "In this notebook, we will implement support-vector machines for classification with linear and Gaussian radial basis function kernels." | |
| ], | |
| "id": "db32bdaf" | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "24654e3c" | |
| }, | |
| "source": [ | |
| "from matplotlib import pyplot as plt\n", | |
| "import seaborn as sns\n", | |
| "import torch\n", | |
| "import numpy as np\n", | |
| "from sklearn.datasets import make_moons\n", | |
| "sns.set()\n", | |
| "# Seed every RNG in play: NumPy drives make_moons, PyTorch drives the\n", | |
| "# random initialisation of the models' linear layers.\n", | |
| "torch.manual_seed(0)\n", | |
| "np.random.seed(0)" | |
| ], | |
| "id": "24654e3c", | |
| "execution_count": 1, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "8deeac3c" | |
| }, | |
| "source": [ | |
| "## Training Data\n", | |
| "\n", | |
| "We use the simple two-moons toy dataset showing two interleaving half circles." | |
| ], | |
| "id": "8deeac3c" | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 284 | |
| }, | |
| "id": "f33a3ed6", | |
| "outputId": "7e8f4527-a850-4e5e-c74a-0d4331d221dd" | |
| }, | |
| "source": [ | |
| "# Draw the toy dataset and relabel the classes from {0, 1} to {-1, +1}.\n", | |
| "x, y = make_moons(20, noise=0.1)\n", | |
| "y[y == 0] = -1\n", | |
| "is_neg, is_pos = (y == -1), (y == 1)\n", | |
| "\n", | |
| "fig, ax = plt.subplots()\n", | |
| "ax.scatter(x[is_neg, 0], x[is_neg, 1], label='Class 1')\n", | |
| "ax.scatter(x[is_pos, 0], x[is_pos, 1], label='Class 2')\n", | |
| "ax.set_title('Training data')\n", | |
| "ax.legend();\n", | |
| "\n", | |
| "# From here on x and y are float32 tensors for the PyTorch models.\n", | |
| "x = torch.FloatTensor(x)\n", | |
| "y = torch.FloatTensor(y)" | |
| ], | |
| "id": "f33a3ed6", | |
| "execution_count": 2, | |
| "outputs": [ | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| } | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "da37fc6e" | |
| }, | |
| "source": [ | |
| "## The Hinge Loss\n", | |
| "\n", | |
| "The SVM is trained by minimizing the mean hinge loss over the weights $\\mathbf{w}$ and the bias $b$:\n", | |
| "$$\\min_{\\mathbf{w},\\, b} \\left[ \\frac{1}{n} \\sum_{i=1}^{n} \\max\\left(0, 1-y_{i}(\\mathbf{w}^{T}\\mathbf{x}_{i}-b)\\right) \\right] .$$" | |
| ], | |
| "id": "da37fc6e" | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "115d16fe" | |
| }, | |
| "source": [ | |
| "def hinge_loss(x, y):\n", | |
| "    \"\"\"Return the mean hinge loss of the scores `x` against the labels `y`.\n", | |
| "\n", | |
| "    Computes mean(max(0, 1 - y * x)) element-wise; `x` and `y` must be\n", | |
| "    broadcastable tensors.\n", | |
| "    \"\"\"\n", | |
| "    margin_violation = 1 - y * x\n", | |
| "    floor = torch.zeros_like(y)\n", | |
| "    return torch.max(floor, margin_violation).mean()" | |
| ], | |
| "id": "115d16fe", | |
| "execution_count": 3, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "dfe49048" | |
| }, | |
| "source": [ | |
| "class KernelSVM(torch.nn.Module):\n", | |
| "    \"\"\"Support-vector classifier: a linear layer on top of kernel features.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        train_data_x: 2-D tensor of training inputs, (n_samples, n_features).\n", | |
| "            With the RBF kernel these points are the kernel centres.\n", | |
| "        kernel: 'linear' or 'rbf'.\n", | |
| "        gamma_init: initial value of the RBF parameter gamma.\n", | |
| "        train_gamma: if True, gamma receives gradients and is optimised\n", | |
| "            together with the linear weights.\n", | |
| "    \"\"\"\n", | |
| "\n", | |
| "    def __init__(self, train_data_x, kernel='rbf',\n", | |
| "                 gamma_init=1.0, train_gamma=True):\n", | |
| "        super().__init__()\n", | |
| "        assert kernel in ['linear', 'rbf']\n", | |
| "        self._train_data_x = train_data_x\n", | |
| "\n", | |
| "        if kernel == 'linear':\n", | |
| "            self._kernel = self.linear\n", | |
| "            # One weight per input feature (was hard-coded to 2).\n", | |
| "            self._num_c = train_data_x.size(1)\n", | |
| "        else:\n", | |
| "            self._kernel = self.rbf\n", | |
| "            # One weight per training point. Size it from the constructor\n", | |
| "            # argument, not from a notebook-global `x`, so the class works\n", | |
| "            # with any training set.\n", | |
| "            self._num_c = train_data_x.size(0)\n", | |
| "            self._gamma = torch.nn.Parameter(\n", | |
| "                torch.tensor([gamma_init], dtype=torch.float32),\n", | |
| "                requires_grad=train_gamma)\n", | |
| "\n", | |
| "        self._w = torch.nn.Linear(in_features=self._num_c, out_features=1)\n", | |
| "\n", | |
| "    def rbf(self, x, gamma=1):\n", | |
| "        \"\"\"Gaussian kernel between `x` and the stored training points.\n", | |
| "\n", | |
| "        Returns a (len(x), n_train) tensor exp(-gamma * ||x_i - t_j||^2)\n", | |
| "        computed with `self._gamma`. The `gamma` argument is ignored; it is\n", | |
| "        kept only so the signature stays backward-compatible.\n", | |
| "        \"\"\"\n", | |
| "        # Broadcasting (n, 1, d) - (m, d) -> (n, m, d) gives the pairwise\n", | |
| "        # differences without materialising a repeated copy of the data.\n", | |
| "        sq_dist = ((x[:, None] - self._train_data_x) ** 2).sum(dim=2)\n", | |
| "        return torch.exp(-self._gamma * sq_dist)\n", | |
| "\n", | |
| "    @staticmethod\n", | |
| "    def linear(x):\n", | |
| "        \"\"\"Identity feature map: the linear SVM acts on the raw inputs.\"\"\"\n", | |
| "        return x\n", | |
| "\n", | |
| "    def forward(self, x):\n", | |
| "        \"\"\"Return the raw decision scores, shape (len(x), 1).\"\"\"\n", | |
| "        return self._w(self._kernel(x))" | |
| ], | |
| "id": "dfe49048", | |
| "execution_count": 4, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "8de892eb" | |
| }, | |
| "source": [ | |
| "# Both classifiers are fitted with plain SGD at the same step size.\n", | |
| "step_size = 0.1\n", | |
| "\n", | |
| "model_linear = KernelSVM(x, kernel='linear')\n", | |
| "model_kernel = KernelSVM(x, kernel='rbf')\n", | |
| "\n", | |
| "opt_linear, opt_kernel = (\n", | |
| "    torch.optim.SGD(model.parameters(), lr=step_size)\n", | |
| "    for model in (model_linear, model_kernel)\n", | |
| ")" | |
| ], | |
| "id": "8de892eb", | |
| "execution_count": 5, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "b2ba690b", | |
| "outputId": "89d15f2b-0e00-4f36-b58b-6dad80747ce9" | |
| }, | |
| "source": [ | |
| "# The label column never changes, so build it once outside the loop.\n", | |
| "targets = y.unsqueeze(1)\n", | |
| "\n", | |
| "for _ in range(1000):\n", | |
| "    # One full-batch SGD step for the linear model ...\n", | |
| "    opt_linear.zero_grad()\n", | |
| "    loss_linear = hinge_loss(model_linear(x), targets)\n", | |
| "    loss_linear.backward()\n", | |
| "    opt_linear.step()\n", | |
| "\n", | |
| "    # ... and one for the RBF model; the two models are separate instances.\n", | |
| "    opt_kernel.zero_grad()\n", | |
| "    loss_kernel = hinge_loss(model_kernel(x), targets)\n", | |
| "    loss_kernel.backward()\n", | |
| "    opt_kernel.step()\n", | |
| "\n", | |
| "print(\"loss linear model\", loss_linear.item())\n", | |
| "print(\"loss kernel model\", loss_kernel.item())" | |
| ], | |
| "id": "b2ba690b", | |
| "execution_count": 6, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "loss linear model 0.34402042627334595\n", | |
| "loss kernel model 0.0\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "f41f3df9" | |
| }, | |
| "source": [ | |
| "# Square evaluation grid: both axes run from 1.1*x.min() to 1.1*x.max() in steps of 0.1.\n", | |
| "axis_ticks = torch.arange(x.min()*1.1, x.max()*1.1, step=0.1)\n", | |
| "grid_x, grid_y = torch.meshgrid(axis_ticks, axis_ticks)\n", | |
| "x_test = torch.stack((grid_x.flatten(), grid_y.flatten()), dim=1)\n", | |
| "\n", | |
| "\n", | |
| "def decision_surface(model):\n", | |
| "    \"\"\"Score every grid point with `model`; return a NumPy array shaped like the grid.\"\"\"\n", | |
| "    scores = model(x_test).detach()\n", | |
| "    return scores.reshape(grid_x.shape).numpy()\n", | |
| "\n", | |
| "\n", | |
| "y_test_linear = decision_surface(model_linear)\n", | |
| "y_test_kernel = decision_surface(model_kernel)" | |
| ], | |
| "id": "f41f3df9", | |
| "execution_count": 7, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 229 | |
| }, | |
| "id": "61712d80", | |
| "outputId": "5ce4eaed-1fcc-4c02-d552-8a4c26202b29" | |
| }, | |
| "source": [ | |
| "fig, ax = plt.subplots(1,2, figsize=(8,3))\n", | |
| "\n", | |
| "neg_idx, pos_idx = np.where(y==-1), np.where(y==1)\n", | |
| "panels = [('Linear Kernel', y_test_linear), ('RBF Kernel', y_test_kernel)]\n", | |
| "filled, boundaries = [], []\n", | |
| "\n", | |
| "for axis, (title, scores) in zip(ax, panels):\n", | |
| "    # Filled contours show the score; the green level-0 line is the decision boundary.\n", | |
| "    surface = axis.contourf(grid_x.numpy(), grid_y.numpy(), scores)\n", | |
| "    boundaries.append(\n", | |
| "        axis.contour(surface, '--', levels=[0], colors='tab:green', linewidths=2))\n", | |
| "    # Empty line whose only purpose is to provide a legend entry for the boundary.\n", | |
| "    axis.plot(np.nan, label='decision boundary', color='tab:green')\n", | |
| "    axis.scatter(x[neg_idx,0], x[neg_idx,1])\n", | |
| "    axis.scatter(x[pos_idx,0], x[pos_idx,1])\n", | |
| "    axis.set_title(title)\n", | |
| "    filled.append(surface)\n", | |
| "\n", | |
| "# Only the left panel carries the legend.\n", | |
| "ax[0].legend()\n", | |
| "(cs0, cs1), cs11 = filled, boundaries[1]\n", | |
| "\n", | |
| "# Colour bar on the right, built from the RBF panel, with its boundary level marked.\n", | |
| "fig.subplots_adjust(wspace=0.2, hspace=0.1,right=0.8)\n", | |
| "cbar_ax = fig.add_axes([0.82, 0.13, 0.02, 0.67])\n", | |
| "cbar = fig.colorbar(cs1, cax=cbar_ax)\n", | |
| "cbar.add_lines(cs11)" | |
| ], | |
| "id": "61712d80", | |
| "execution_count": 8, | |
| "outputs": [ | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 576x216 with 3 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| } | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "b39fbbc4" | |
| }, | |
| "source": [ | |
| "" | |
| ], | |
| "id": "b39fbbc4", | |
| "execution_count": 8, | |
| "outputs": [] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment