{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModel\n",
"import torch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First we initialize our model and tokenizer:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')\n",
"model = AutoModel.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we tokenize the sentences just as before:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"sentences = [\n",
" \"Three years later, the coffin was still full of Jello.\",\n",
" \"The fish dreamed of escaping the fishbowl and into the toilet where he saw his friend go.\",\n",
" \"The person box was packed with jelly many dozens of months later.\",\n",
" \"He found a leprechaun in his walnut shell.\"\n",
"]\n",
"\n",
"# initialize dictionary to store tokenized sentences\n",
"tokens = {'input_ids': [], 'attention_mask': []}\n",
"\n",
"for sentence in sentences:\n",
" # encode each sentence and append to dictionary\n",
" new_tokens = tokenizer.encode_plus(sentence, max_length=128,\n",
" truncation=True, padding='max_length',\n",
" return_tensors='pt')\n",
" tokens['input_ids'].append(new_tokens['input_ids'][0])\n",
" tokens['attention_mask'].append(new_tokens['attention_mask'][0])\n",
"\n",
"# reformat list of tensors into single tensor\n",
"tokens['input_ids'] = torch.stack(tokens['input_ids'])\n",
"tokens['attention_mask'] = torch.stack(tokens['attention_mask'])"
]
},
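{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside, the tokenizer can also encode the whole list in one call rather than looping. A minimal sketch (the `batch` variable is just an illustrative name); it should produce the same tensors as the loop above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sketch: batch-encode all sentences in a single call\n",
"batch = tokenizer(sentences, max_length=128, truncation=True,\n",
"                  padding='max_length', return_tensors='pt')\n",
"\n",
"# the batched tensors should match the stacked ones built above\n",
"assert torch.equal(batch['input_ids'], tokens['input_ids'])\n",
"assert torch.equal(batch['attention_mask'], tokens['attention_mask'])"
]
},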
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We process these tokens through our model:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"odict_keys(['last_hidden_state', 'pooler_output'])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"outputs = model(**tokens)\n",
"outputs.keys()"
]
},
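{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since we are only running inference here, the forward pass could be wrapped in `torch.no_grad()` to skip gradient tracking and save memory; the `grad_fn` visible in the tensor below shows that gradients were tracked in this run. A minimal sketch (`no_grad_outputs` is an illustrative name):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sketch: the same forward pass without gradient tracking\n",
"with torch.no_grad():\n",
"    no_grad_outputs = model(**tokens)\n",
"\n",
"# the resulting hidden states carry no gradient information\n",
"no_grad_outputs.last_hidden_state.requires_grad"
]
},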
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The dense vector representations of our `text` are contained within the `outputs` **'last_hidden_state'** tensor, which we access like so:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[-0.0692, 0.6230, 0.0354, ..., 0.8033, 1.6314, 0.3281],\n",
" [ 0.0367, 0.6842, 0.1946, ..., 0.0848, 1.4747, -0.3008],\n",
" [-0.0121, 0.6543, -0.0727, ..., -0.0326, 1.7717, -0.6812],\n",
" ...,\n",
" [ 0.1953, 1.1085, 0.3390, ..., 1.2826, 1.0114, -0.0728],\n",
" [ 0.0902, 1.0288, 0.3297, ..., 1.2940, 0.9865, -0.1113],\n",
" [ 0.1240, 0.9737, 0.3933, ..., 1.1359, 0.8768, -0.1043]],\n",
"\n",
" [[-0.3212, 0.8251, 1.0554, ..., -0.1855, 0.1517, 0.3937],\n",
" [-0.7146, 1.0297, 1.1217, ..., 0.0331, 0.2382, -0.1563],\n",
" [-0.2352, 1.1353, 0.8594, ..., -0.4310, -0.0272, -0.2968],\n",
" ...,\n",
" [-0.5400, 0.3236, 0.7839, ..., 0.0022, -0.2994, 0.2659],\n",
" [-0.5643, 0.3187, 0.9576, ..., 0.0342, -0.3030, 0.1878],\n",
" [-0.5172, 0.3599, 0.9336, ..., 0.0243, -0.2232, 0.1672]],\n",
"\n",
" [[-0.7576, 0.8399, -0.3792, ..., 0.1271, 1.2514, 0.1365],\n",
" [-0.6591, 0.7613, -0.4662, ..., 0.2259, 1.1289, -0.3611],\n",
" [-0.9007, 0.6791, -0.3778, ..., 0.1142, 0.9080, -0.1830],\n",
" ...,\n",
" [-0.2158, 0.5463, 0.3117, ..., 0.1802, 0.7169, -0.0672],\n",
" [-0.3092, 0.4833, 0.3021, ..., 0.2289, 0.6656, -0.0932],\n",
" [-0.2940, 0.4678, 0.3095, ..., 0.2782, 0.5144, -0.1021]],\n",
"\n",
" [[-0.2362, 0.8551, -0.8040, ..., 0.6122, 0.3003, -0.1492],\n",
" [-0.0868, 0.9531, -0.6419, ..., 0.7867, 0.2960, -0.7350],\n",
" [-0.3016, 1.0148, -0.3380, ..., 0.8634, 0.0463, -0.3623],\n",
" ...,\n",
" [-0.1090, 0.6320, -0.8433, ..., 0.7485, 0.1025, 0.0149],\n",
" [ 0.0072, 0.7347, -0.7689, ..., 0.6064, 0.1287, 0.0331],\n",
" [-0.1108, 0.7605, -0.4447, ..., 0.6719, 0.1059, -0.0034]]],\n",
" grad_fn=<NativeLayerNormBackward>)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"embeddings = outputs.last_hidden_state\n",
"embeddings"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 128, 768])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"embeddings.shape"
]
}
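,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These are token-level embeddings: one 768-dimensional vector for each of the 128 token positions in each of the 4 sentences. To reduce them to one vector per sentence, the usual approach for this `mean-tokens` model is mean pooling over the real (non-padding) tokens, using the attention mask to exclude padding. A minimal sketch (`mask`, `summed`, `counts`, and `mean_pooled` are illustrative names):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# expand the attention mask to the embedding dimension\n",
"mask = tokens['attention_mask'].unsqueeze(-1).expand(embeddings.size()).float()\n",
"\n",
"# sum the embeddings of real tokens, zeroing out padding positions\n",
"summed = torch.sum(embeddings * mask, dim=1)\n",
"\n",
"# divide by the number of real tokens in each sentence\n",
"counts = torch.clamp(mask.sum(dim=1), min=1e-9)\n",
"mean_pooled = summed / counts\n",
"mean_pooled.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This should give a `[4, 768]` tensor, one dense vector per sentence, ready to compare with a similarity measure such as cosine similarity. The `sentence-transformers` library wraps this pooling step up for you when the same model is loaded through its `SentenceTransformer` class."
]
}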
],
"metadata": {
"kernelspec": {
"display_name": "ML",
"language": "python",
"name": "ml"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}