Created
April 30, 2018 23:20
-
-
Save dnlcrl/9e9ee03105a8dc60a5ae7dced0837ee4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[ 0., 1., 2.],\n", | |
| " [ 3., 4., 5.]])" | |
| ] | |
| }, | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "import torch\n", | |
| "a = torch.arange(6).reshape(2, 3)\n", | |
| "a" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "TRANSPOSE" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2.73 µs ± 75.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[ 0., 3.],\n", | |
| " [ 1., 4.],\n", | |
| " [ 2., 5.]])" | |
| ] | |
| }, | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.einsum('ij->ji', [a])\n", | |
| "torch.einsum('ij->ji', [a])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "1.22 µs ± 29.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[ 0., 3.],\n", | |
| " [ 1., 4.],\n", | |
| " [ 2., 5.]])" | |
| ] | |
| }, | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit a.t()\n", | |
| "a.t()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "SUM" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "7.31 µs ± 208 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(15.)" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.einsum('ij->', [a])\n", | |
| "torch.einsum('ij->', [a])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "1.58 µs ± 44.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(15.)" | |
| ] | |
| }, | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit a.sum()\n", | |
| "a.sum()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "COLUMN SUM" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "5.42 µs ± 49.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([ 3., 5., 7.])" | |
| ] | |
| }, | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.einsum('ij->j', [a])\n", | |
| "torch.einsum('ij->j', [a])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2.42 µs ± 64.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([ 3., 5., 7.])" | |
| ] | |
| }, | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.sum(a, 0)\n", | |
| "torch.sum(a, 0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "ROW SUM" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "4.81 µs ± 27.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([ 3., 12.])" | |
| ] | |
| }, | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.einsum('ij->i', [a])\n", | |
| "torch.einsum('ij->i', [a])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "1.92 µs ± 103 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([ 3., 12.])" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.sum(a, 1)\n", | |
| "torch.sum(a, 1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "MATRIX-MATRIX MULTIPLICATION" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "a = torch.arange(6).reshape(2, 3)\n", | |
| "b = torch.arange(15).reshape(3, 5)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "23.5 µs ± 715 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[ 25., 28., 31., 34., 37.],\n", | |
| " [ 70., 82., 94., 106., 118.]])" | |
| ] | |
| }, | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.einsum('ik,kj->ij', [a, b])\n", | |
| "torch.einsum('ik,kj->ij', [a, b])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "1.77 µs ± 12.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[ 25., 28., 31., 34., 37.],\n", | |
| " [ 70., 82., 94., 106., 118.]])" | |
| ] | |
| }, | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.mm(a, b)\n", | |
| "torch.mm(a, b)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "MATRIX-VECTOR MULTIPLICATION" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "a = torch.arange(6).reshape(2, 3)\n", | |
| "b = torch.arange(3)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "19.3 µs ± 799 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([ 5., 14.])" | |
| ] | |
| }, | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.einsum('ik,k->i', [a, b])\n", | |
| "torch.einsum('ik,k->i', [a, b])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "1.51 µs ± 20.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([ 5., 14.])" | |
| ] | |
| }, | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.mv(a,b)\n", | |
| "torch.mv(a,b)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "DOT PRODUCT" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Vector dot product" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "a = torch.arange(3)\n", | |
| "b = torch.arange(3,6) # -- a vector of length 3 containing [3, 4, 5]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "17.1 µs ± 913 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(14.)" | |
| ] | |
| }, | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.einsum('i,i->', [a, b])\n", | |
| "torch.einsum('i,i->', [a, b])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "1.92 µs ± 102 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(14.)" | |
| ] | |
| }, | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.dot(a,b)\n", | |
| "torch.dot(a,b)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Matrix inner product (sum of element-wise products)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "a = torch.arange(6).reshape(2, 3)\n", | |
| "b = torch.arange(6,12).reshape(2, 3)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "18.6 µs ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(145.)" | |
| ] | |
| }, | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.einsum('ij,ij->', [a, b])\n", | |
| "torch.einsum('ij,ij->', [a, b])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "6.13 µs ± 85.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(145.)" | |
| ] | |
| }, | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.dot(a.view(-1),b.view(-1))\n", | |
| "torch.dot(a.view(-1),b.view(-1))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "HADAMARD PRODUCT" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "a = torch.arange(6).reshape(2, 3)\n", | |
| "b = torch.arange(6,12).reshape(2, 3)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "5.56 µs ± 72.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[ 0., 7., 16.],\n", | |
| " [ 27., 40., 55.]])" | |
| ] | |
| }, | |
| "execution_count": 23, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.einsum('ij,ij->ij', [a, b])\n", | |
| "torch.einsum('ij,ij->ij', [a, b])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 24, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "1.31 µs ± 14.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[ 0., 7., 16.],\n", | |
| " [ 27., 40., 55.]])" | |
| ] | |
| }, | |
| "execution_count": 24, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit a*b\n", | |
| "a*b" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "OUTER PRODUCT" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 25, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "a = torch.arange(3)\n", | |
| "b = torch.arange(3,7) # -- a vector of length 4 containing [3, 4, 5, 6]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 26, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "10.8 µs ± 622 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[ 0., 0., 0., 0.],\n", | |
| " [ 3., 4., 5., 6.],\n", | |
| " [ 6., 8., 10., 12.]])" | |
| ] | |
| }, | |
| "execution_count": 26, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.einsum('i,j->ij', [a, b])\n", | |
| "torch.einsum('i,j->ij', [a, b])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 27, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "1.64 µs ± 51 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[ 0., 0., 0., 0.],\n", | |
| " [ 3., 4., 5., 6.],\n", | |
| " [ 6., 8., 10., 12.]])" | |
| ] | |
| }, | |
| "execution_count": 27, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.ger(a, b)\n", | |
| "torch.ger(a, b)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "BATCH MATRIX MULTIPLICATION" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 28, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "a = torch.randn(3,2,5)\n", | |
| "b = torch.randn(3,5,3)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 29, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "24 µs ± 422 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[[-3.6068, 3.6341, 3.4859],\n", | |
| " [ 2.3148, 2.5504, 3.8194]],\n", | |
| "\n", | |
| " [[ 2.3448, 2.5390, -0.1359],\n", | |
| " [ 3.4580, 3.4026, 0.0316]],\n", | |
| "\n", | |
| " [[-2.1875, -3.7540, 4.1446],\n", | |
| " [ 1.5737, -0.2249, -0.2547]]])" | |
| ] | |
| }, | |
| "execution_count": 29, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.einsum('ijk,ikl->ijl', [a, b])\n", | |
| "torch.einsum('ijk,ikl->ijl', [a, b])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 30, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "4.81 µs ± 150 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[[-3.6068, 3.6341, 3.4859],\n", | |
| " [ 2.3148, 2.5504, 3.8194]],\n", | |
| "\n", | |
| " [[ 2.3448, 2.5390, -0.1359],\n", | |
| " [ 3.4580, 3.4026, 0.0316]],\n", | |
| "\n", | |
| " [[-2.1875, -3.7540, 4.1446],\n", | |
| " [ 1.5737, -0.2249, -0.2547]]])" | |
| ] | |
| }, | |
| "execution_count": 30, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit a.bmm(b)\n", | |
| "a.bmm(b)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "TENSOR MULTIPLICATION" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 31, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "a = torch.randn(2,3,5,7)\n", | |
| "b = torch.randn(11,13,3,17,5)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 32, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "210 µs ± 27.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([2, 7, 11, 13, 17])" | |
| ] | |
| }, | |
| "execution_count": 32, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape\n", | |
| "torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 33, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "83.7 µs ± 5.52 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([2, 7, 11, 13, 17])" | |
| ] | |
| }, | |
| "execution_count": 33, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.mm(a.transpose(1,3).transpose(2,3).reshape(2*7, 3*5), \\\n", | |
| "b.transpose(4, 3).transpose(1, 3).transpose(0, 2).reshape(3*5, 11*13*17)).reshape(2, 7, 11, 13, 17).shape\n", | |
| "\n", | |
| "torch.mm(a.transpose(1,3).transpose(2,3).reshape(2*7, 3*5), \\\n", | |
| "b.transpose(4, 3).transpose(1, 3).transpose(0, 2).reshape(3*5, 11*13*17)).reshape(2, 7, 11, 13, 17).shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 34, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(1, dtype=torch.uint8)" | |
| ] | |
| }, | |
| "execution_count": 34, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "(torch.einsum('pqrs,tuqvr->pstuv', [a, b]) == torch.mm(a.transpose(1,3).transpose(2,3).reshape(2*7, 3*5), \\\n", | |
| "b.transpose(4, 3).transpose(1, 3).transpose(0, 2).reshape(3*5, 11*13*17)).reshape(2, 7, 11, 13, 17)).all()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "BILINEAR TRANSFORMATION" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 35, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "a = torch.randn(2,3)\n", | |
| "b = torch.randn(5,3,7)\n", | |
| "c = torch.randn(2,7)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 36, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "51.3 µs ± 2.25 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[-7.8614, 1.6084, 1.8052, 2.3681, 1.1696],\n", | |
| " [ 5.7942, -1.5822, -4.0773, 1.1712, 0.1531]])" | |
| ] | |
| }, | |
| "execution_count": 36, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.einsum('ik,jkl,il->ij', [a, b, c])\n", | |
| "torch.einsum('ik,jkl,il->ij', [a, b, c])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 37, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "37 µs ± 1.26 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[-7.8614, 1.6084, 1.8052, 2.3681, 1.1696],\n", | |
| " [ 5.7942, -1.5822, -4.0773, 1.1712, 0.1531]])" | |
| ] | |
| }, | |
| "execution_count": 37, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.mm(torch.mm(a, b.transpose(0, 1).reshape(3, -1)).reshape(-1, 7), c.transpose(1, 0))\\\n", | |
| ".view(-1).gather(0, torch.stack([torch.range(0, 9, 2), torch.range(11, 19, 2)]).view(-1).long()).reshape(2, 5)\n", | |
| "\n", | |
| "torch.mm(torch.mm(a, b.transpose(0, 1).reshape(3, -1)).reshape(-1, 7), c.transpose(1, 0))\\\n", | |
| ".view(-1).gather(0, torch.stack([torch.range(0, 9, 2), torch.range(11, 19, 2)]).view(-1).long()).reshape(2, 5)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
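| "Same computation using torch.arange instead of the deprecated torch.range: arange excludes its end value, so the upper bounds change from 9 and 19 to 10 and 20 while the gathered indices stay the same" | |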
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 44, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "34.8 µs ± 929 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[-7.8614, 1.6084, 1.8052, 2.3681, 1.1696],\n", | |
| " [ 5.7942, -1.5822, -4.0773, 1.1712, 0.1531]])" | |
| ] | |
| }, | |
| "execution_count": 44, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%timeit torch.mm(torch.mm(a, b.transpose(0, 1).reshape(3, -1)).reshape(-1, 7), c.transpose(1, 0))\\\n", | |
| ".view(-1).gather(0, torch.stack([torch.arange(0, 10, 2), torch.arange(11, 20, 2)]).view(-1).long()).reshape(2, 5)\n", | |
| "\n", | |
| "torch.mm(torch.mm(a, b.transpose(0, 1).reshape(3, -1)).reshape(-1, 7), c.transpose(1, 0))\\\n", | |
| ".view(-1).gather(0, torch.stack([torch.arange(0, 10, 2), torch.arange(11, 20, 2)]).view(-1).long()).reshape(2, 5)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.3" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
source: https://rockt.github.io/2018/04/30/einsum#fn.1