Created
May 2, 2021 10:31
-
-
Save jamescalam/03614a1e1b2d6538c1576e8bff6f4937 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from sklearn.metrics.pairwise import cosine_similarity" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Let's calculate the cosine similarity between sentence `0` and each of the other sentences (`1` to `3`):" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[0.33088905, 0.7219259 , 0.55483633]], dtype=float32)" | |
| ] | |
| }, | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# NOTE: `mean_pooled` (one mean-pooled embedding per sentence) comes from an\n", | |
| "# earlier step that is not part of this notebook -- TODO confirm its shape.\n", | |
| "\n", | |
| "# convert from PyTorch tensor to numpy array: detach from the autograd graph\n", | |
| "# and move to CPU first (`.numpy()` fails on CUDA tensors). Bind to a new name\n", | |
| "# so re-running this cell doesn't call `.detach()` on a numpy array.\n", | |
| "mean_pooled_np = mean_pooled.detach().cpu().numpy()\n", | |
| "\n", | |
| "# similarity of sentence 0 against every other sentence;\n", | |
| "# `[:1]` keeps the query 2-D, as cosine_similarity expects\n", | |
| "cosine_similarity(\n", | |
| " mean_pooled_np[:1],\n", | |
| " mean_pooled_np[1:]\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Each score is the cosine similarity between sentence `0` and the sentence in that row (higher means more similar):\n", | |
| "\n", | |
| "| Index | Sentence | Similarity |\n", | |
| "| --- | --- | --- |\n", | |
| "| 1 | \"The fish dreamed of escaping the fishbowl and into the toilet where he saw his friend go.\" | 0.3309 |\n", | |
| "| 2 | \"The person box was packed with jelly many dozens of months later.\" | 0.7219 |\n", | |
| "| 3 | \"He found a leprechaun in his walnut shell.\" | 0.5548 |" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "ML", | |
| "language": "python", | |
| "name": "ml" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.8.5" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment