@jamescalam
Created May 2, 2021 10:26
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To perform this operation, we first expand our `attention_mask` tensor to match the dimensions of our embeddings:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 128])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"attention_mask = tokens['attention_mask']\n",
"attention_mask.shape"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 128, 768])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()\n",
"mask.shape"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[1., 1., 1., ..., 1., 1., 1.],\n",
" [1., 1., 1., ..., 1., 1., 1.],\n",
" [1., 1., 1., ..., 1., 1., 1.],\n",
" ...,\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]],\n",
"\n",
" [[1., 1., 1., ..., 1., 1., 1.],\n",
" [1., 1., 1., ..., 1., 1., 1.],\n",
" [1., 1., 1., ..., 1., 1., 1.],\n",
" ...,\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]],\n",
"\n",
" [[1., 1., 1., ..., 1., 1., 1.],\n",
" [1., 1., 1., ..., 1., 1., 1.],\n",
" [1., 1., 1., ..., 1., 1., 1.],\n",
" ...,\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]],\n",
"\n",
" [[1., 1., 1., ..., 1., 1., 1.],\n",
" [1., 1., 1., ..., 1., 1., 1.],\n",
" [1., 1., 1., ..., 1., 1., 1.],\n",
" ...,\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mask"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each vector above represents the attention mask for a single token - each token now has a vector of size 768 representing its *attention_mask* status. We then multiply the two tensors to apply the attention mask, zeroing out the embeddings of padding tokens:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 128, 768])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"masked_embeddings = embeddings * mask\n",
"masked_embeddings.shape"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[-0.0692, 0.6230, 0.0354, ..., 0.8033, 1.6314, 0.3281],\n",
" [ 0.0367, 0.6842, 0.1946, ..., 0.0848, 1.4747, -0.3008],\n",
" [-0.0121, 0.6543, -0.0727, ..., -0.0326, 1.7717, -0.6812],\n",
" ...,\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, -0.0000],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, -0.0000],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, -0.0000]],\n",
"\n",
" [[-0.3212, 0.8251, 1.0554, ..., -0.1855, 0.1517, 0.3937],\n",
" [-0.7146, 1.0297, 1.1217, ..., 0.0331, 0.2382, -0.1563],\n",
" [-0.2352, 1.1353, 0.8594, ..., -0.4310, -0.0272, -0.2968],\n",
" ...,\n",
" [-0.0000, 0.0000, 0.0000, ..., 0.0000, -0.0000, 0.0000],\n",
" [-0.0000, 0.0000, 0.0000, ..., 0.0000, -0.0000, 0.0000],\n",
" [-0.0000, 0.0000, 0.0000, ..., 0.0000, -0.0000, 0.0000]],\n",
"\n",
" [[-0.7576, 0.8399, -0.3792, ..., 0.1271, 1.2514, 0.1365],\n",
" [-0.6591, 0.7613, -0.4662, ..., 0.2259, 1.1289, -0.3611],\n",
" [-0.9007, 0.6791, -0.3778, ..., 0.1142, 0.9080, -0.1830],\n",
" ...,\n",
" [-0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, -0.0000],\n",
" [-0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, -0.0000],\n",
" [-0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, -0.0000]],\n",
"\n",
" [[-0.2362, 0.8551, -0.8040, ..., 0.6122, 0.3003, -0.1492],\n",
" [-0.0868, 0.9531, -0.6419, ..., 0.7867, 0.2960, -0.7350],\n",
" [-0.3016, 1.0148, -0.3380, ..., 0.8634, 0.0463, -0.3623],\n",
" ...,\n",
" [-0.0000, 0.0000, -0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000, -0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [-0.0000, 0.0000, -0.0000, ..., 0.0000, 0.0000, -0.0000]]],\n",
" grad_fn=<MulBackward0>)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"masked_embeddings"
]
},
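{
"cell_type": "markdown",
"metadata": {},
"source": [
"(As a side note, the explicit `expand` above is mainly for illustration - PyTorch broadcasting can apply the mask directly. A minimal sketch, assuming the `embeddings` and `attention_mask` tensors defined earlier:)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# equivalent masking via broadcasting - unsqueeze adds a trailing dimension\n",
"# of size 1, which broadcasts across the 768 embedding dimensions\n",
"masked_alt = embeddings * attention_mask.unsqueeze(-1).float()\n",
"masked_alt.shape  # torch.Size([4, 128, 768])"
]
},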
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we sum the remaining embeddings along axis `1`:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 768])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"summed = torch.sum(masked_embeddings, 1)\n",
"summed.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we count the number of tokens that should be given attention in each position by summing the mask along axis `1` (clamped to a small minimum value to avoid division by zero):"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 768])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"summed_mask = torch.clamp(mask.sum(1), min=1e-9)\n",
"summed_mask.shape"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[15., 15., 15., ..., 15., 15., 15.],\n",
" [22., 22., 22., ..., 22., 22., 22.],\n",
" [15., 15., 15., ..., 15., 15., 15.],\n",
" [14., 14., 14., ..., 14., 14., 14.]])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"summed_mask"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we calculate the mean as the sum of the embedding activations `summed` divided by the number of values that should be given attention in each position `summed_mask`:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"mean_pooled = summed / summed_mask"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.0745, 0.8637, 0.1795, ..., 0.7734, 1.7247, -0.1803],\n",
" [-0.3715, 0.9729, 1.0840, ..., -0.2552, -0.2759, 0.0358],\n",
" [-0.5030, 0.7950, -0.1240, ..., 0.1441, 0.9704, -0.1791],\n",
" [-0.2131, 1.0175, -0.8833, ..., 0.7371, 0.1947, -0.3011]],\n",
" grad_fn=<DivBackward0>)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mean_pooled"
]
}
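,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The steps above can be wrapped into a single helper - a minimal sketch (the name `mean_pooling` is our own choice here, not a library function):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def mean_pooling(embeddings, attention_mask):\n",
"    # expand the mask to the embedding dimensions and zero out padding tokens\n",
"    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()\n",
"    summed = torch.sum(embeddings * mask, 1)\n",
"    # clamp avoids division by zero if a sequence had no attended tokens\n",
"    counts = torch.clamp(mask.sum(1), min=1e-9)\n",
"    return summed / counts\n",
"\n",
"# should match the mean_pooled tensor computed step by step above\n",
"mean_pooling(embeddings, attention_mask).shape"
]
}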
],
"metadata": {
"kernelspec": {
"display_name": "ML",
"language": "python",
"name": "ml"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}