Last active
December 1, 2022 07:38
-
-
Save p208p2002/af81c348d7e83bd37a63a1f4cce3141b to your computer and use it in GitHub Desktop.
#blog #bart-model #nlp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.9" | |
| }, | |
| "colab": { | |
| "name": "af81c348d7e83bd37a63a1f4cce3141b", | |
| "provenance": [] | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "HIFU79dL1CXw" | |
| }, | |
| "source": [ | |
| "## Import" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "MqR3IoOx1MCU" | |
| }, | |
| "source": [ | |
| "# Pinned install; the %pip magic installs into the environment of the running kernel.\n", | |
| "%pip install transformers==4.5.1" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "zuDAaR6o1CX2" | |
| }, | |
| "source": [ | |
| "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n", | |
| "import torch" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "p_uBHFJ61CX4" | |
| }, | |
| "source": [ | |
| "# Load the tokenizer and the seq2seq model from the same pretrained checkpoint.\n", | |
| "checkpoint = 'facebook/bart-base'\n", | |
| "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n", | |
| "model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "mkK9U64E1CX4" | |
| }, | |
| "source": [ | |
| "## Prepare data" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "mxDneSrk1CX4" | |
| }, | |
| "source": [ | |
| "# The single training pair: `context` is the model input, `label` is the target text.\n", | |
| "context = 'Harry Potter is a series of'\n", | |
| "# The EOS token is appended by hand to mark the end of the target sequence.\n", | |
| "label = ''.join(['seven fantasy novels', tokenizer.eos_token])" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "gdd6UnS01CX5" | |
| }, | |
| "source": [ | |
| "def convert_to_features(context, label, max_length=50):\n", | |
| "    \"\"\"Tokenize one (context, label) pair into batched model inputs.\n", | |
| "\n", | |
| "    Both texts are padded and truncated to exactly `max_length` tokens.\n", | |
| "    The label is tokenized without special tokens, and every padding\n", | |
| "    position in it is replaced by -100 (the label index that Hugging Face\n", | |
| "    models are documented to ignore when computing the loss).\n", | |
| "\n", | |
| "    Args:\n", | |
| "        context: source text fed to the encoder.\n", | |
| "        label: target text the decoder should produce.\n", | |
| "        max_length: fixed sequence length for both inputs and labels.\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        dict with 'input_ids', 'attention_mask' and 'labels', each a\n", | |
| "        LongTensor of shape (1, max_length).\n", | |
| "    \"\"\"\n", | |
| "    # `padding='max_length'` replaces the deprecated `pad_to_max_length=True`.\n", | |
| "    input_encodings = tokenizer(context, padding='max_length', max_length=max_length, truncation=True)\n", | |
| "    label_encodings = tokenizer(label, padding='max_length', max_length=max_length, truncation=True, add_special_tokens=False)\n", | |
| "\n", | |
| "    pad_token_id = tokenizer.pad_token_id\n", | |
| "    # Mask padding positions so they do not contribute to the loss.\n", | |
| "    labels = [\n", | |
| "        token_id if token_id != pad_token_id else -100\n", | |
| "        for token_id in label_encodings['input_ids']\n", | |
| "    ]\n", | |
| "\n", | |
| "    # unsqueeze(0) adds the batch dimension expected by the model.\n", | |
| "    return {\n", | |
| "        'input_ids': torch.LongTensor(input_encodings['input_ids']).unsqueeze(0),\n", | |
| "        'attention_mask': torch.LongTensor(input_encodings['attention_mask']).unsqueeze(0),\n", | |
| "        'labels': torch.LongTensor(labels).unsqueeze(0)\n", | |
| "    }" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "5KfkgMHc1CX5", | |
| "outputId": "06556e5a-bc7f-49c9-e026-c50a0dc6e1ad" | |
| }, | |
| "source": [ | |
| "# Build the single training example, then display it as the cell output.\n", | |
| "model_input = convert_to_features(context=context, label=label)\n", | |
| "model_input" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "/usr/local/lib/python3.7/dist-packages/transformers/tokenization_utils_base.py:2079: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n", | |
| " FutureWarning,\n" | |
| ], | |
| "name": "stderr" | |
| }, | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |
| " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |
| " 0, 0]]),\n", | |
| " 'input_ids': tensor([[ 0, 29345, 10997, 16, 10, 651, 9, 2, 1, 1,\n", | |
| " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", | |
| " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", | |
| " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", | |
| " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),\n", | |
| " 'labels': tensor([[17723, 8235, 19405, 2, -100, -100, -100, -100, -100, -100,\n", | |
| " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n", | |
| " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n", | |
| " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n", | |
| " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100]])}" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 18 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "DQzWBQiA1CX6" | |
| }, | |
| "source": [ | |
| "## Fine-tuning" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "GUzQzIfw1CX6" | |
| }, | |
| "source": [ | |
| "from transformers import AdamW\n", | |
| "\n", | |
| "# Use the GPU when one is available, otherwise fall back to the CPU.\n", | |
| "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", | |
| "model.to(device)\n", | |
| "model.train()  # switch the model to training mode before optimizing\n", | |
| "optim = AdamW(model.parameters(), lr=5e-5)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "fSnU9L0b1CX7" | |
| }, | |
| "source": [ | |
| "# Move every tensor of the training example onto the same device as the model.\n", | |
| "for name, tensor in model_input.items():\n", | |
| "    model_input[name] = tensor.to(device)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "3DjT_G5M1CX7", | |
| "outputId": "db67d271-10d9-4f22-e03e-827a9e67530b" | |
| }, | |
| "source": [ | |
| "# Repeatedly fit the single example until the loss falls below the threshold.\n", | |
| "# The epoch cap keeps Run All from hanging forever if the loss plateaus.\n", | |
| "max_epochs = 1000\n", | |
| "loss_threshold = 1e-3\n", | |
| "for epoch in range(1, max_epochs + 1):\n", | |
| "    optim.zero_grad()\n", | |
| "    outputs = model(**model_input, return_dict=True)\n", | |
| "    loss = outputs['loss']\n", | |
| "    loss.backward()\n", | |
| "    optim.step()\n", | |
| "    # Read the scalar once; it is used for both logging and the stop test.\n", | |
| "    loss_value = loss.item()\n", | |
| "    print('epoch:%d' % epoch, 'loss:%3.5f' % loss_value, end='\\r')\n", | |
| "    if loss_value < loss_threshold:\n", | |
| "        break\n", | |
| "else:\n", | |
| "    # Runs only when the loop finished without reaching the threshold.\n", | |
| "    print('\\nstopped after %d epochs without reaching loss < %g' % (max_epochs, loss_threshold))" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "Mus4XQGo1CX7" | |
| }, | |
| "source": [ | |
| "## Overfitting test" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "FKIPRIYV1CX7" | |
| }, | |
| "source": [ | |
| "# Encode the prompt as PyTorch tensors and place its token ids on the model's device.\n", | |
| "context = 'Harry Potter is a series of'\n", | |
| "encoded_context = tokenizer(context, return_tensors='pt')\n", | |
| "input_ids = encoded_context['input_ids'].to(device)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "ci-_llC01CX8", | |
| "outputId": "3dddd25e-8730-484d-d272-d72d9458842c" | |
| }, | |
| "source": [ | |
| "model.eval()  # leave training mode before generating\n", | |
| "\n", | |
| "# introduction of `model.generate`\n", | |
| "# https://huggingface.co/blog/how-to-generate\n", | |
| "generation_kwargs = dict(\n", | |
| "    do_sample=False,\n", | |
| "    max_length=10,\n", | |
| "    top_k=1,\n", | |
| "    num_return_sequences=1,\n", | |
| "    early_stopping=True,\n", | |
| ")\n", | |
| "sample_outputs = model.generate(input_ids, **generation_kwargs)\n", | |
| "\n", | |
| "# Print each generated sequence as text, with special tokens removed.\n", | |
| "print(\"Output:\\n\" + 100 * '-')\n", | |
| "for i, sample_output in enumerate(sample_outputs):\n", | |
| "    decoded_text = tokenizer.decode(sample_output, skip_special_tokens=True)\n", | |
| "    print(\"{}: {}\".format(i, decoded_text))" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Output:\n", | |
| "----------------------------------------------------------------------------------------------------\n", | |
| "0: seven fantasy novels\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment