{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
| "To perform this operation, we first resize our `attention_mask` tensor:" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 128])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
| "attention_mask = tokens['attention_mask']\n", | |
| "attention_mask.shape" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 128, 768])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
| "mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()\n", | |
| "mask.shape" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1., 1., 1., ..., 1., 1., 1.],\n",
       "         [1., 1., 1., ..., 1., 1., 1.],\n",
       "         [1., 1., 1., ..., 1., 1., 1.],\n",
       "         ...,\n",
       "         [0., 0., 0., ..., 0., 0., 0.],\n",
       "         [0., 0., 0., ..., 0., 0., 0.],\n",
       "         [0., 0., 0., ..., 0., 0., 0.]],\n",
       "\n",
       "        [[1., 1., 1., ..., 1., 1., 1.],\n",
       "         [1., 1., 1., ..., 1., 1., 1.],\n",
       "         [1., 1., 1., ..., 1., 1., 1.],\n",
       "         ...,\n",
       "         [0., 0., 0., ..., 0., 0., 0.],\n",
       "         [0., 0., 0., ..., 0., 0., 0.],\n",
       "         [0., 0., 0., ..., 0., 0., 0.]],\n",
       "\n",
       "        [[1., 1., 1., ..., 1., 1., 1.],\n",
       "         [1., 1., 1., ..., 1., 1., 1.],\n",
       "         [1., 1., 1., ..., 1., 1., 1.],\n",
       "         ...,\n",
       "         [0., 0., 0., ..., 0., 0., 0.],\n",
       "         [0., 0., 0., ..., 0., 0., 0.],\n",
       "         [0., 0., 0., ..., 0., 0., 0.]],\n",
       "\n",
       "        [[1., 1., 1., ..., 1., 1., 1.],\n",
       "         [1., 1., 1., ..., 1., 1., 1.],\n",
       "         [1., 1., 1., ..., 1., 1., 1.],\n",
       "         ...,\n",
       "         [0., 0., 0., ..., 0., 0., 0.],\n",
       "         [0., 0., 0., ..., 0., 0., 0.],\n",
       "         [0., 0., 0., ..., 0., 0., 0.]]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mask"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
| "Each vector above represents a single token attention mask - each token now has a vector of size 768 representing it's *attention_mask* status. Then we multiply the two tensors to apply the attention mask:" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 128, 768])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
| "masked_embeddings = embeddings * mask\n", | |
| "masked_embeddings.shape" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-0.0692, 0.6230, 0.0354, ..., 0.8033, 1.6314, 0.3281],\n",
       "         [ 0.0367, 0.6842, 0.1946, ..., 0.0848, 1.4747, -0.3008],\n",
       "         [-0.0121, 0.6543, -0.0727, ..., -0.0326, 1.7717, -0.6812],\n",
       "         ...,\n",
       "         [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, -0.0000],\n",
       "         [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, -0.0000],\n",
       "         [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, -0.0000]],\n",
       "\n",
       "        [[-0.3212, 0.8251, 1.0554, ..., -0.1855, 0.1517, 0.3937],\n",
       "         [-0.7146, 1.0297, 1.1217, ..., 0.0331, 0.2382, -0.1563],\n",
       "         [-0.2352, 1.1353, 0.8594, ..., -0.4310, -0.0272, -0.2968],\n",
       "         ...,\n",
       "         [-0.0000, 0.0000, 0.0000, ..., 0.0000, -0.0000, 0.0000],\n",
       "         [-0.0000, 0.0000, 0.0000, ..., 0.0000, -0.0000, 0.0000],\n",
       "         [-0.0000, 0.0000, 0.0000, ..., 0.0000, -0.0000, 0.0000]],\n",
       "\n",
       "        [[-0.7576, 0.8399, -0.3792, ..., 0.1271, 1.2514, 0.1365],\n",
       "         [-0.6591, 0.7613, -0.4662, ..., 0.2259, 1.1289, -0.3611],\n",
       "         [-0.9007, 0.6791, -0.3778, ..., 0.1142, 0.9080, -0.1830],\n",
       "         ...,\n",
       "         [-0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, -0.0000],\n",
       "         [-0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, -0.0000],\n",
       "         [-0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, -0.0000]],\n",
       "\n",
       "        [[-0.2362, 0.8551, -0.8040, ..., 0.6122, 0.3003, -0.1492],\n",
       "         [-0.0868, 0.9531, -0.6419, ..., 0.7867, 0.2960, -0.7350],\n",
       "         [-0.3016, 1.0148, -0.3380, ..., 0.8634, 0.0463, -0.3623],\n",
       "         ...,\n",
       "         [-0.0000, 0.0000, -0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
       "         [ 0.0000, 0.0000, -0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
       "         [-0.0000, 0.0000, -0.0000, ..., 0.0000, 0.0000, -0.0000]]],\n",
       "       grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "masked_embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
| "Then we sum the remained of the embeddings along axis `1`:" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 768])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
| "summed = torch.sum(masked_embeddings, 1)\n", | |
| "summed.shape" | |
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
| "Then sum the number of values that must be given attention in each position of the tensor:" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 768])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
| "summed_mask = torch.clamp(mask.sum(1), min=1e-9)\n", | |
| "summed_mask.shape" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[15., 15., 15., ..., 15., 15., 15.],\n",
       "        [22., 22., 22., ..., 22., 22., 22.],\n",
       "        [15., 15., 15., ..., 15., 15., 15.],\n",
       "        [14., 14., 14., ..., 14., 14., 14.]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "summed_mask"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
| "Finally, we calculate the mean as the sum of the embedding activations `summed` divided by the number of values that should be given attention in each position `summed_mask`:" | |
   ]
  },
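  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written out, dimension $j$ of each sentence's pooled vector is\n",
    "\n",
    "$$v_j = \\frac{\\sum_{i=1}^{L} m_i \\, e_{ij}}{\\max\\left(\\sum_{i=1}^{L} m_i,\\; \\epsilon\\right)}$$\n",
    "\n",
    "where $e_{ij}$ is the embedding of token $i$, $m_i \\in \\{0, 1\\}$ is its attention mask value, $L = 128$ is the sequence length, and $\\epsilon = 10^{-9}$ is the clamp floor guarding against division by zero."
   ]
  },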
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
| "mean_pooled = summed / summed_mask" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0745, 0.8637, 0.1795, ..., 0.7734, 1.7247, -0.1803],\n",
       "        [-0.3715, 0.9729, 1.0840, ..., -0.2552, -0.2759, 0.0358],\n",
       "        [-0.5030, 0.7950, -0.1240, ..., 0.1441, 0.9704, -0.1791],\n",
       "        [-0.2131, 1.0175, -0.8833, ..., 0.7371, 0.1947, -0.3011]],\n",
       "       grad_fn=<DivBackward0>)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mean_pooled"
   ]
  },
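  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sketch of how the pooled vectors are typically used (the `sklearn` dependency here is an assumption, not part of the pooling steps above), we can compare the four sentence vectors with cosine similarity:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "\n",
    "# detach from the autograd graph and convert to numpy before comparing\n",
    "mean_pooled_np = mean_pooled.detach().numpy()\n",
    "\n",
    "# 4x4 matrix of pairwise similarities between the sentence vectors\n",
    "cosine_similarity(mean_pooled_np, mean_pooled_np)"
   ]
  }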
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ML",
   "language": "python",
   "name": "ml"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}