Skip to content

Instantly share code, notes, and snippets.

@p208p2002
Last active December 1, 2022 07:38
Show Gist options
  • Select an option

  • Save p208p2002/af81c348d7e83bd37a63a1f4cce3141b to your computer and use it in GitHub Desktop.

Select an option

Save p208p2002/af81c348d7e83bd37a63a1f4cce3141b to your computer and use it in GitHub Desktop.
#blog #bart-model #nlp
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
},
"colab": {
"name": "af81c348d7e83bd37a63a1f4cce3141b",
"provenance": []
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "HIFU79dL1CXw"
},
"source": [
"## Import"
]
},
{
"cell_type": "code",
"metadata": {
"id": "MqR3IoOx1MCU"
},
"source": [
"! pip install transformers==4.5.1"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "zuDAaR6o1CX2"
},
"source": [
"from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
"import torch"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "p_uBHFJ61CX4"
},
"source": [
"tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')\n",
"model = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-base')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "mkK9U64E1CX4"
},
"source": [
"## Prepare data"
]
},
{
"cell_type": "code",
"metadata": {
"id": "mxDneSrk1CX4"
},
"source": [
"context = 'Harry Potter is a series of'\n",
"label = 'seven fantasy novels' + tokenizer.eos_token"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "gdd6UnS01CX5"
},
"source": [
"def convert_to_features(context,label):\n",
" input_encodings = tokenizer(context, pad_to_max_length=True, max_length=50, truncation=True)\n",
" label_encodings = tokenizer(label, pad_to_max_length=True, max_length=50, truncation=True, add_special_tokens=False)\n",
" \n",
" pad_token_id = tokenizer.pad_token_id\n",
" labels = []\n",
" for label_encoding_id in label_encodings['input_ids']:\n",
" if label_encoding_id != pad_token_id:\n",
" labels.append(label_encoding_id)\n",
" else:\n",
" labels.append(-100)\n",
" \n",
" return {\n",
" 'input_ids':torch.LongTensor(input_encodings['input_ids']).unsqueeze(0),\n",
" 'attention_mask':torch.LongTensor(input_encodings['attention_mask']).unsqueeze(0),\n",
" 'labels': torch.LongTensor(labels).unsqueeze(0)\n",
" }"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5KfkgMHc1CX5",
"outputId": "06556e5a-bc7f-49c9-e026-c50a0dc6e1ad"
},
"source": [
"model_input = convert_to_features(context,label)\n",
"model_input"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/transformers/tokenization_utils_base.py:2079: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n",
" FutureWarning,\n"
],
"name": "stderr"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0]]),\n",
" 'input_ids': tensor([[ 0, 29345, 10997, 16, 10, 651, 9, 2, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),\n",
" 'labels': tensor([[17723, 8235, 19405, 2, -100, -100, -100, -100, -100, -100,\n",
" -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
" -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
" -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
" -100, -100, -100, -100, -100, -100, -100, -100, -100, -100]])}"
]
},
"metadata": {
"tags": []
},
"execution_count": 18
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DQzWBQiA1CX6"
},
"source": [
"## Fine-tuning"
]
},
{
"cell_type": "code",
"metadata": {
"id": "GUzQzIfw1CX6"
},
"source": [
"from transformers import AdamW\n",
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
"model.to(device)\n",
"model.train()\n",
"optim = AdamW(model.parameters(), lr=5e-5)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "fSnU9L0b1CX7"
},
"source": [
"for key in model_input.keys():\n",
" model_input[key] = model_input[key].to(device)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3DjT_G5M1CX7",
"outputId": "db67d271-10d9-4f22-e03e-827a9e67530b"
},
"source": [
"epoch = 0\n",
"while True:\n",
" optim.zero_grad()\n",
" outputs = model(**model_input,return_dict=True)\n",
" loss = outputs['loss']\n",
" loss.backward()\n",
" optim.step()\n",
" #\n",
" epoch += 1\n",
" print('epoch:%d'%epoch,'loss:%3.5f'%loss,end='\\r')\n",
" if loss.item() < 1e-3: break"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
""
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mus4XQGo1CX7"
},
"source": [
"## Overfitting test"
]
},
{
"cell_type": "code",
"metadata": {
"id": "FKIPRIYV1CX7"
},
"source": [
"context = 'Harry Potter is a series of'\n",
"input_ids = tokenizer(context,return_tensors='pt')['input_ids'].to(device)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ci-_llC01CX8",
"outputId": "3dddd25e-8730-484d-d272-d72d9458842c"
},
"source": [
"model.eval()\n",
"# introduction of `model.generate`\n",
"# https://huggingface.co/blog/how-to-generate\n",
"sample_outputs = model.generate(\n",
" input_ids,\n",
" do_sample=False, \n",
" max_length=10, \n",
" top_k=1, \n",
" num_return_sequences=1,\n",
" early_stopping = True\n",
")\n",
"\n",
"print(\"Output:\\n\" + 100 * '-')\n",
"for i, sample_output in enumerate(sample_outputs):\n",
" print(\"{}: {}\".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Output:\n",
"----------------------------------------------------------------------------------------------------\n",
"0: seven fantasy novels\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment