Created
October 30, 2024 15:48
-
-
Save andreasgrv/8a30363bdb7ab42db486ab7873f24774 to your computer and use it in GitHub Desktop.
Softmax-Permutohedron.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "authorship_tag": "ABX9TyOJOOYZxOQxu9pSQduDkRIT", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/andreasgrv/8a30363bdb7ab42db486ab7873f24774/softmax-permutohedron.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
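| "# Low-Rank Softmax and Shadows of the Permutohedron\n", | |
| "\n", | |
| "A linear softmax layer scores $N$ classes with logits $W x$, where $W \\in \\mathbb{R}^{N \\times D}$ and $x \\in \\mathbb{R}^{D}$. Because the logits lie in a subspace of dimension at most $D$, when $N > D + 1$ some orderings of the $N$ classes cannot be produced by any input $x$.\n", | |
| "\n", | |
| "## Problem: when $N > D + 1$, some rankings cannot be predicted\n", | |
| "\n", | |
| "Below we sample many random inputs $x$ and count how many of the $N!$ possible class rankings actually occur.\n" | |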
| ], | |
| "metadata": { | |
| "id": "muKYMKF2_yZI" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "from math import factorial\n", | |
| "from scipy.special import softmax\n", | |
| "\n", | |
| "\n", | |
| "np.set_printoptions(precision=2)\n", | |
| "\n", | |
| "\n", | |
| "N_SAMPLES = 50000\n", | |
| "N = 4\n", | |
| "D = 2\n", | |
| "# Fix the seed so that W, the sampled features and the printed outputs are reproducible\n", | |
| "SEED = 0\n", | |
| "rng = np.random.default_rng(SEED)\n", | |
| "\n", | |
| "# Feature representations: N_SAMPLES random points in the D-dimensional feature space\n", | |
| "xx = rng.normal(0, 1, (N_SAMPLES, D))\n", | |
| "\n", | |
| "## ==== The parameters of the Linear Softmax Layer ====\n", | |
| "# One D-dimensional weight vector per class (rows of W)\n", | |
| "W = rng.normal(0, 1, (N, D))\n", | |
| "\n", | |
| "print('\\nSoftmax W is:\\n', W, '\\n')\n", | |
| "\n", | |
| "logits = xx.dot(W.T)\n", | |
| "probs = softmax(logits, axis=1)\n", | |
| "\n", | |
| "# Softmax is strictly increasing, so ranking the logits gives the same order as ranking\n", | |
| "# the probabilities while avoiding spurious ties from floating-point rounding in exp.\n", | |
| "# np.argsort is ascending: each row lists classes from least to most probable.\n", | |
| "rankings = np.argsort(logits, axis=1)\n", | |
| "# Collect the distinct rankings that occurred across all sampled inputs\n", | |
| "pred_rankings = set(tuple(r) for r in rankings)\n", | |
| "print('%d/%d rankings were predicted' % (len(pred_rankings), factorial(N)))\n", | |
| "print('The predicted rankings of classes are:')\n", | |
| "for r in sorted(pred_rankings):\n", | |
| "    print('\\t', r)\n" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "4Oq_08oh_8sl", | |
| "outputId": "d3acb39c-fe18-4c48-c4dc-ddb26018f0fb" | |
| }, | |
| "execution_count": 27, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "\n", | |
| "Softmax W is:\n", | |
| " [[ 0.99 0.28]\n", | |
| " [-1.43 0.57]\n", | |
| " [-0.57 -1.51]\n", | |
| " [-0.24 0.82]] \n", | |
| "\n", | |
| "12/24 rankings were predicted\n", | |
| "The predicted rankings of classes are:\n", | |
| "\t (0, 2, 3, 1)\n", | |
| "\t (0, 3, 1, 2)\n", | |
| "\t (0, 3, 2, 1)\n", | |
| "\t (1, 2, 3, 0)\n", | |
| "\t (1, 3, 0, 2)\n", | |
| "\t (1, 3, 2, 0)\n", | |
| "\t (2, 0, 1, 3)\n", | |
| "\t (2, 0, 3, 1)\n", | |
| "\t (2, 1, 0, 3)\n", | |
| "\t (2, 1, 3, 0)\n", | |
| "\t (3, 0, 1, 2)\n", | |
| "\t (3, 1, 0, 2)\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
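| "## Which rankings are feasible?\n", | |
| "\n", | |
| "Projecting the vertices of the permutohedron (all permutations of $0, \\dots, N-1$) through $W$ gives a \"shadow\" in $\\mathbb{R}^{D}$. A ranking can be predicted exactly when its vertex remains a vertex of the convex hull of this shadow; the code below checks that this matches the sampled rankings above." | |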
| ], | |
| "metadata": { | |
| "id": "6-VDyJ9cFrsb" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "from scipy.spatial import ConvexHull\n", | |
| "from itertools import permutations\n", | |
| "\n", | |
| "\n", | |
| "def permutohedron(n):\n", | |
| "    \"\"\"Return the n! vertices of the permutohedron as rows: every permutation of 0..n-1, in itertools order.\"\"\"\n", | |
| "    vertices = [perm for perm in permutations(np.arange(n))]\n", | |
| "    return np.array(vertices)\n", | |
| "\n", | |
| "\n", | |
| "pp = permutohedron(N)\n", | |
| "# Project every vertex into the D-dimensional feature space: row k is sum_i pp[k, i] * W[i]\n", | |
| "pp_proj = pp.dot(W)\n", | |
| "\n", | |
| "# For an input x, a vertex p maximises x . (p W) = sum_i p_i * logit_i exactly when p gives\n", | |
| "# larger values to classes with larger logits, so the reachable rankings are the projected\n", | |
| "# points that remain vertices of the convex hull.\n", | |
| "c_hull = ConvexHull(pp_proj)\n", | |
| "# argsort of a vertex lists classes from lowest to highest assigned value, matching the\n", | |
| "# ascending np.argsort ordering used for pred_rankings\n", | |
| "surviving_vertices = np.argsort(pp[c_hull.vertices], axis=1)\n", | |
| "\n", | |
| "proj_rankings = set(tuple(p) for p in surviving_vertices)\n", | |
| "print('\\nSurviving vertices of the permutohedron, are:')\n", | |
| "for p in sorted(proj_rankings):\n", | |
| "    print('\\t', p)\n", | |
| "\n", | |
| "\n", | |
| "# Sampled rankings should be exactly the hull vertices. A mismatch can occur if a reachable\n", | |
| "# ranking occupies a very thin cone of inputs that N_SAMPLES random draws never hit.\n", | |
| "assert proj_rankings == pred_rankings, (\n", | |
| "    'hull-only: %s, sampled-only: %s'\n", | |
| "    % (sorted(proj_rankings - pred_rankings), sorted(pred_rankings - proj_rankings)))" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "ZDUQIKZhAgUc", | |
| "outputId": "8b5f4b9d-5275-4061-9fc8-72d1cbb9a0a1" | |
| }, | |
| "execution_count": 28, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "\n", | |
| "Surviving vertices of the permutohedron, are:\n", | |
| "\t (0, 2, 3, 1)\n", | |
| "\t (0, 3, 1, 2)\n", | |
| "\t (0, 3, 2, 1)\n", | |
| "\t (1, 2, 3, 0)\n", | |
| "\t (1, 3, 0, 2)\n", | |
| "\t (1, 3, 2, 0)\n", | |
| "\t (2, 0, 1, 3)\n", | |
| "\t (2, 0, 3, 1)\n", | |
| "\t (2, 1, 0, 3)\n", | |
| "\t (2, 1, 3, 0)\n", | |
| "\t (3, 0, 1, 2)\n", | |
| "\t (3, 1, 0, 2)\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "import matplotlib.pyplot as plt\n", | |
| "\n", | |
| "# Red: every projected permutohedron vertex; green (drawn on top): convex-hull vertices,\n", | |
| "# i.e. the feasible rankings, each labelled with its class ranking (least to most probable).\n", | |
| "fig, ax = plt.subplots(figsize=(10, 10))\n", | |
| "ax.scatter(*pp_proj.T, s=50, color='red', label='projected permutohedron vertex')\n", | |
| "ax.scatter(*pp_proj[c_hull.vertices].T, s=50, color='green', label='convex-hull vertex (feasible ranking)')\n", | |
| "for idx in c_hull.vertices:\n", | |
| "    ax.text(*pp_proj[idx], ', '.join(map(str, np.argsort(pp[idx]))))\n", | |
| "ax.set(title='Permutohedron projected by W (N=%d classes, D=%d)' % (N, D),\n", | |
| "       xlabel='feature dimension 1', ylabel='feature dimension 2')\n", | |
| "ax.legend()\n", | |
| "plt.show()" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 830 | |
| }, | |
| "id": "MuUSus1nCpku", | |
| "outputId": "e6d706b9-41d0-4fc1-d9b7-71d880e004b6" | |
| }, | |
| "execution_count": 29, | |
| "outputs": [ | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "<Figure size 1000x1000 with 1 Axes>" | |
| ], | |
| "image/png": "\n" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Interactive visualisation: [viz.unargmaxable.ai/softmax](https://viz.unargmaxable.ai/softmax/)" | |
| ], | |
| "metadata": { | |
| "id": "NWcLLp4UC7Bw" | |
| } | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment