Instantly share code, notes, and snippets.
Created
May 2, 2021 10:07
-
Star
2
(2)
You must be signed in to star a gist -
Fork
0
(0)
You must be signed in to fork a gist
-
-
Save jamescalam/6d973995d7d41ab1324f58ad14b071d7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from transformers import AutoTokenizer, AutoModel\n", | |
| "import torch" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "First we initialize our model and tokenizer:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# both halves of the sentence-transformer come from the same checkpoint\n", | |
| "model_name = 'sentence-transformers/bert-base-nli-mean-tokens'\n", | |
| "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", | |
| "model = AutoModel.from_pretrained(model_name)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Then we tokenize the sentences just as before:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "sentences = [\n", | |
| "    \"Three years later, the coffin was still full of Jello.\",\n", | |
| "    \"The fish dreamed of escaping the fishbowl and into the toilet where he saw his friend go.\",\n", | |
| "    \"The person box was packed with jelly many dozens of months later.\",\n", | |
| "    \"The person box was packed with jelly many dozens of months later.\"\n" if False else \"He found a leprechaun in his walnut shell.\"\n", | |
| "]\n", | |
| "\n", | |
| "# tokenize the whole batch in one call: the tokenizer pads/truncates every\n", | |
| "# sentence to max_length and returns already-stacked 'input_ids' and\n", | |
| "# 'attention_mask' tensors, replacing the per-sentence encode_plus loop\n", | |
| "# (deprecated) and the manual torch.stack\n", | |
| "tokens = tokenizer(sentences, max_length=128, truncation=True,\n", | |
| "                   padding='max_length', return_tensors='pt')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "We process these tokens through our model:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "odict_keys(['last_hidden_state', 'pooler_output'])" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# run the batch through BERT; the output exposes 'last_hidden_state'\n", | |
| "# and 'pooler_output'\n", | |
| "# NOTE(review): this is inference only — consider wrapping the call in\n", | |
| "# torch.no_grad() so autograd state isn't tracked (the tensor printed\n", | |
| "# below carries a grad_fn); outputs would change slightly if applied\n", | |
| "outputs = model(**tokens)\n", | |
| "outputs.keys()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The dense vector representations of our `sentences` are contained within the `outputs` **'last_hidden_state'** tensor, which we access like so:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[[-0.0692, 0.6230, 0.0354, ..., 0.8033, 1.6314, 0.3281],\n", | |
| " [ 0.0367, 0.6842, 0.1946, ..., 0.0848, 1.4747, -0.3008],\n", | |
| " [-0.0121, 0.6543, -0.0727, ..., -0.0326, 1.7717, -0.6812],\n", | |
| " ...,\n", | |
| " [ 0.1953, 1.1085, 0.3390, ..., 1.2826, 1.0114, -0.0728],\n", | |
| " [ 0.0902, 1.0288, 0.3297, ..., 1.2940, 0.9865, -0.1113],\n", | |
| " [ 0.1240, 0.9737, 0.3933, ..., 1.1359, 0.8768, -0.1043]],\n", | |
| "\n", | |
| " [[-0.3212, 0.8251, 1.0554, ..., -0.1855, 0.1517, 0.3937],\n", | |
| " [-0.7146, 1.0297, 1.1217, ..., 0.0331, 0.2382, -0.1563],\n", | |
| " [-0.2352, 1.1353, 0.8594, ..., -0.4310, -0.0272, -0.2968],\n", | |
| " ...,\n", | |
| " [-0.5400, 0.3236, 0.7839, ..., 0.0022, -0.2994, 0.2659],\n", | |
| " [-0.5643, 0.3187, 0.9576, ..., 0.0342, -0.3030, 0.1878],\n", | |
| " [-0.5172, 0.3599, 0.9336, ..., 0.0243, -0.2232, 0.1672]],\n", | |
| "\n", | |
| " [[-0.7576, 0.8399, -0.3792, ..., 0.1271, 1.2514, 0.1365],\n", | |
| " [-0.6591, 0.7613, -0.4662, ..., 0.2259, 1.1289, -0.3611],\n", | |
| " [-0.9007, 0.6791, -0.3778, ..., 0.1142, 0.9080, -0.1830],\n", | |
| " ...,\n", | |
| " [-0.2158, 0.5463, 0.3117, ..., 0.1802, 0.7169, -0.0672],\n", | |
| " [-0.3092, 0.4833, 0.3021, ..., 0.2289, 0.6656, -0.0932],\n", | |
| " [-0.2940, 0.4678, 0.3095, ..., 0.2782, 0.5144, -0.1021]],\n", | |
| "\n", | |
| " [[-0.2362, 0.8551, -0.8040, ..., 0.6122, 0.3003, -0.1492],\n", | |
| " [-0.0868, 0.9531, -0.6419, ..., 0.7867, 0.2960, -0.7350],\n", | |
| " [-0.3016, 1.0148, -0.3380, ..., 0.8634, 0.0463, -0.3623],\n", | |
| " ...,\n", | |
| " [-0.1090, 0.6320, -0.8433, ..., 0.7485, 0.1025, 0.0149],\n", | |
| " [ 0.0072, 0.7347, -0.7689, ..., 0.6064, 0.1287, 0.0331],\n", | |
| " [-0.1108, 0.7605, -0.4447, ..., 0.6719, 0.1059, -0.0034]]],\n", | |
| " grad_fn=<NativeLayerNormBackward>)" | |
| ] | |
| }, | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# pull the per-token dense vectors out of the model output; key access\n", | |
| "# on a transformers ModelOutput is equivalent to attribute access\n", | |
| "embeddings = outputs['last_hidden_state']\n", | |
| "embeddings" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([4, 128, 768])" | |
| ] | |
| }, | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "embeddings.shape" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "ML", | |
| "language": "python", | |
| "name": "ml" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.8.5" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment