Skip to content

Instantly share code, notes, and snippets.

@alanvww
Last active April 4, 2024 16:10
Show Gist options
  • Select an option

  • Save alanvww/bc5694db65752502e69750dafb0b9d37 to your computer and use it in GitHub Desktop.

Select an option

Save alanvww/bc5694db65752502e69750dafb0b9d37 to your computer and use it in GitHub Desktop.
HW5
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "c4ea461b-a931-4339-88eb-eb431f036bd5",
"metadata": {},
"outputs": [],
"source": [
"from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "e83d9190-84c5-4d04-8277-0a980f44108d",
"metadata": {},
"outputs": [],
"source": [
"import tracery\n",
"from tracery.modifiers import base_english\n",
"import random"
]
},
{
"cell_type": "code",
"execution_count": 71,
"id": "74ef367e-6153-4560-8cb7-95e9b507a99c",
"metadata": {},
"outputs": [],
"source": [
"model_name = \"distilgpt2\" # Or any other suitable model\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = AutoModelForCausalLM.from_pretrained(model_name)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "d220cde0-28e3-4bda-b9b4-886c9dde42a5",
"metadata": {},
"outputs": [],
"source": [
"generator = pipeline(\"text-generation\", model=model, tokenizer=tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": 77,
"id": "0804c92c-2d99-45e8-95ab-b51788bffb09",
"metadata": {},
"outputs": [],
"source": [
"def generate_diary_text(prompt):\n",
" generated_text = generator(prompt, max_length=100, num_return_sequences=1)[0][\"generated_text\"]\n",
" return generated_text"
]
},
{
"cell_type": "code",
"execution_count": 96,
"id": "6627d7f1-26ff-40af-9f8a-4f60f11ec3ec",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
}
],
"source": [
"prompts = [\n",
" \"Today, I saw\",\n",
" \"I remember when\",\n",
" \"It was a\",\n",
" \"The weather was\",\n",
" \"I felt\",\n",
" 'I guess'\n",
"]\n",
"\n",
"rules = {\n",
" \"diary_entry\": \"#date#\\n\\n#greeting#\\n\\n#generated_text#\\n\\n#closing#\\n#name#\",\n",
" \"date\": random.choice([\n",
" \"January\", \"February\", \"March\", \"April\", \"May\", \"June\",\n",
" \"July\", \"August\", \"September\", \"October\", \"November\", \"December\"\n",
" ]) + \" \" + str(random.randint(1, 28)) + \", \" + str(random.randint(1950, 2024)),\n",
" \"greeting\": [\"Long time no see!\", \"Hey there!\", \"Hello, my unseen friend!\", \"To my future self:\"],\n",
" \"closing\": [\"Yours truly,\", \"Best regards,\", \"Warmly,\", \"Until next time,\", \"With love,\", \"Sincerely,\", \"Fondly,\", \"With best wishes,\", \"Take care,\", \"Yours sincerely,\"],\n",
" \"name\": [\"Alex\", \"Jordan\", \"Taylor\", \"Morgan\", \"Casey\", \"Charlie\", \"Hayden\", \"Ellis\"],\n",
" \"generated_text\": generate_diary_text(random.choice(prompts))\n",
"}\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 82,
"id": "b395f1f3-0a0d-4cd0-b0d8-dd72bbeb26e6",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"July 24, 1993\n",
"\n",
"Hey there!\n",
"Today, I saw my first day in the company, and felt it would be an unforgettable opportunity to make some great friends and fellow students. But in the end, I was a jerk and I just turned my back and I was on my own. The only things that stand out by my behavior were the words I used when we talked down about a few other people on the show for the first time this year. It was a relief to learn something about me.\n",
"\n",
"\n",
"I'm not proud\n",
"\n",
"Until next time,\n",
"\n"
]
}
],
"source": [
"# Interesting Output\n",
"grammar = tracery.Grammar(rules)\n",
"grammar.add_modifiers(base_english)\n",
"\n",
"diary_entry = grammar.flatten(\"#diary_entry#\")\n",
"print(diary_entry)\n",
"print()"
]
},
{
"cell_type": "code",
"execution_count": 97,
"id": "7b7d1af4-f961-46d5-880d-39236567f7fd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"May 22, 2017\n",
"\n",
"Hello, my unseen friend!\n",
"\n",
"Today, I saw three new products on sale at the time: the Mini-Mini, the Mini-Mini M-Z, and the Mini-Man. I went to the official site for this announcement, but the product was not available until yesterday. (Check it out. It's an \"innate\" price for its \"unbranded\")\n",
"\n",
"\n",
"\n",
"Before a couple of weeks of testing on this product, my review was fairly standard, but we were quickly getting to know each other\n",
"\n",
"Fondly,\n",
"Jordan\n",
"\n"
]
}
],
"source": [
"grammar = tracery.Grammar(rules)\n",
"grammar.add_modifiers(base_english)\n",
"\n",
"diary_entry = grammar.flatten(\"#diary_entry#\")\n",
"print(diary_entry)\n",
"print()"
]
},
{
"cell_type": "markdown",
"id": "c85dc253-af87-4879-bdfb-cabb08d937cf",
"metadata": {},
"source": [
"## Fine-tuning"
]
},
{
"cell_type": "code",
"execution_count": 98,
"id": "e68f672f-65e8-47f3-a8d3-7f87a9e4f1a5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: datasets in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (2.18.0)\n",
"Requirement already satisfied: filelock in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from datasets) (3.13.3)\n",
"Requirement already satisfied: numpy>=1.17 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from datasets) (1.26.4)\n",
"Requirement already satisfied: pyarrow>=12.0.0 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from datasets) (15.0.2)\n",
"Requirement already satisfied: pyarrow-hotfix in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from datasets) (0.6)\n",
"Requirement already satisfied: dill<0.3.9,>=0.3.0 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from datasets) (0.3.8)\n",
"Requirement already satisfied: pandas in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from datasets) (2.2.1)\n",
"Requirement already satisfied: requests>=2.19.0 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from datasets) (2.31.0)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from datasets) (4.66.2)\n",
"Requirement already satisfied: xxhash in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from datasets) (3.4.1)\n",
"Requirement already satisfied: multiprocess in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from datasets) (0.70.16)\n",
"Requirement already satisfied: fsspec<=2024.2.0,>=2023.1.0 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from fsspec[http]<=2024.2.0,>=2023.1.0->datasets) (2024.2.0)\n",
"Requirement already satisfied: aiohttp in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from datasets) (3.9.3)\n",
"Requirement already satisfied: huggingface-hub>=0.19.4 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from datasets) (0.22.1)\n",
"Requirement already satisfied: packaging in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from datasets) (24.0)\n",
"Requirement already satisfied: pyyaml>=5.1 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from datasets) (6.0.1)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from aiohttp->datasets) (1.3.1)\n",
"Requirement already satisfied: attrs>=17.3.0 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from aiohttp->datasets) (23.2.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from aiohttp->datasets) (1.4.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from aiohttp->datasets) (6.0.5)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from aiohttp->datasets) (1.9.4)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from huggingface-hub>=0.19.4->datasets) (4.10.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from requests>=2.19.0->datasets) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from requests>=2.19.0->datasets) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from requests>=2.19.0->datasets) (2.2.1)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from requests>=2.19.0->datasets) (2024.2.2)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from pandas->datasets) (2.9.0.post0)\n",
"Requirement already satisfied: pytz>=2020.1 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from pandas->datasets) (2024.1)\n",
"Requirement already satisfied: tzdata>=2022.7 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from pandas->datasets) (2024.1)\n",
"Requirement already satisfied: six>=1.5 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n"
]
}
],
"source": [
"import sys\n",
"!{sys.executable} -m pip install datasets"
]
},
{
"cell_type": "code",
"execution_count": 103,
"id": "1cdc1399-635e-4b6d-a4a5-d0d55db2d140",
"metadata": {},
"outputs": [],
"source": [
"import datasets"
]
},
{
"cell_type": "code",
"execution_count": 104,
"id": "221cee0d-58ca-4da9-9d1a-1ea8277267af",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bc16b9cf2d8f49b0b19c071868a53ad3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0 examples [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"training_data = datasets.load_dataset('text', data_files=\"nobody-1.txt\")"
]
},
{
"cell_type": "code",
"execution_count": 105,
"id": "2efca906-45a2-488c-a927-5be5d6b7273d",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c4501a066b5b46a0adec86da41780eaa",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/4669 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tokenizer.pad_token = tokenizer.eos_token\n",
"tokenized_training_data = training_data.map(\n",
" lambda x: tokenizer(x['text']),\n",
" remove_columns=[\"text\"]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 106,
"id": "269dc97b-b18c-4da1-af76-d81e81e7a793",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "16118a7a4cfe475a968e104603215ec7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/4669 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"block_size = 64\n",
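"# Concatenate the tokenized examples in each batch, drop the tail that does not fill a whole block,\n",
"# and split the rest into chunks of block_size tokens; labels are a copy of input_ids.\n",
"# Adapted from https://github.com/huggingface/notebooks/blob/master/examples/language_modeling.ipynb\n",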
"def group_texts(examples):\n",
" concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n",
" total_length = len(concatenated_examples[list(examples.keys())[0]])\n",
" total_length = (total_length // block_size) * block_size\n",
" result = {\n",
" k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n",
" for k, t in concatenated_examples.items()\n",
" }\n",
" result[\"labels\"] = result[\"input_ids\"].copy()\n",
" return result\n",
"lm_training_data = tokenized_training_data.map(\n",
" group_texts,\n",
" batched=True,\n",
" batch_size=200\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 114,
"id": "2ef4fcaa-37d5-4be1-b8e6-754dd05e8c17",
"metadata": {},
"outputs": [],
"source": [
"from transformers import Trainer, TrainingArguments"
]
},
{
"cell_type": "code",
"execution_count": 115,
"id": "5af9718f-b7f8-4a1c-a69f-99441859c69d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: accelerate in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (0.28.0)\n",
"Requirement already satisfied: numpy>=1.17 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from accelerate) (1.26.4)\n",
"Requirement already satisfied: packaging>=20.0 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from accelerate) (24.0)\n",
"Requirement already satisfied: psutil in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from accelerate) (5.9.8)\n",
"Requirement already satisfied: pyyaml in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from accelerate) (6.0.1)\n",
"Requirement already satisfied: torch>=1.10.0 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from accelerate) (2.2.2)\n",
"Requirement already satisfied: huggingface-hub in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from accelerate) (0.22.1)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from accelerate) (0.4.2)\n",
"Requirement already satisfied: filelock in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from torch>=1.10.0->accelerate) (3.13.3)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from torch>=1.10.0->accelerate) (4.10.0)\n",
"Requirement already satisfied: sympy in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from torch>=1.10.0->accelerate) (1.12)\n",
"Requirement already satisfied: networkx in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from torch>=1.10.0->accelerate) (3.2.1)\n",
"Requirement already satisfied: jinja2 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from torch>=1.10.0->accelerate) (3.1.3)\n",
"Requirement already satisfied: fsspec in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from torch>=1.10.0->accelerate) (2024.2.0)\n",
"Requirement already satisfied: requests in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from huggingface-hub->accelerate) (2.31.0)\n",
"Requirement already satisfied: tqdm>=4.42.1 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from huggingface-hub->accelerate) (4.66.2)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.5)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from requests->huggingface-hub->accelerate) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from requests->huggingface-hub->accelerate) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from requests->huggingface-hub->accelerate) (2.2.1)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from requests->huggingface-hub->accelerate) (2024.2.2)\n",
"Requirement already satisfied: mpmath>=0.19 in /Users/ajr/.pyenv/versions/3.12.2/lib/python3.12/site-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n"
]
}
],
"source": [
"import sys\n",
"!{sys.executable} -m pip install accelerate -U"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3b64463d-5663-4623-bb08-232dcd3a5cb6",
"metadata": {},
"outputs": [],
"source": [
"my_tokenizer = AutoTokenizer.from_pretrained('distilgpt2-finetune-nobody')\n",
"my_model = AutoModelForCausalLM.from_pretrained('distilgpt2-finetune-nobody')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f88de1a0-4823-41ce-9fff-22cdf8883acb",
"metadata": {},
"outputs": [],
"source": [
"my_generator = pipeline(\"text-generation\", model=my_model, tokenizer=my_tokenizer)\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "92dfd05c-2ae1-4115-820c-e25e0cfef2be",
"metadata": {},
"outputs": [],
"source": [
"import tracery\n",
"from tracery.modifiers import base_english\n",
"import random\n",
"prompts = [\n",
" \"Today, I saw\",\n",
" \"I remember when\",\n",
" \"It was a\",\n",
" \"The weather was\",\n",
" \"I felt\",\n",
" 'I guess'\n",
"]\n",
"\n",
"rules = {\n",
" \"diary_entry\": \"#date#\\n\\n#greeting#\\n\\n#generated_text#\\n\\n#closing#\\n#name#\",\n",
" \"date\": random.choice([\n",
" \"January\", \"February\", \"March\", \"April\", \"May\", \"June\",\n",
" \"July\", \"August\", \"September\", \"October\", \"November\", \"December\"\n",
" ]) + \" \" + str(random.randint(1, 28)) + \", \" + str(random.randint(1950, 2024)),\n",
" \"greeting\": [\"Long time no see!\", \"Hey there!\", \"Hello, my unseen friend!\", \"To my future self:\"],\n",
" \"closing\": [\"Yours truly,\", \"Best regards,\", \"Warmly,\", \"Until next time,\", \"With love,\", \"Sincerely,\", \"Fondly,\", \"With best wishes,\", \"Take care,\", \"Yours sincerely,\"],\n",
" \"name\": [\"Alex\", \"Jordan\", \"Taylor\", \"Morgan\", \"Casey\", \"Charlie\", \"Hayden\", \"Ellis\"],\n",
" \"generated_text\": my_generator(random.choice(prompts))[0]['generated_text']\n",
"}\n",
"\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "8dd1779a-c9fe-4a7c-b6c4-969295aeb593",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'The weather was like a dream, and the only thing I could tell it was that the lights were light, and then the fog appeared, and it was so bright, so bright, andso foggy, andI was so afraid of them I'"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import random\n",
"my_generator(random.choice(prompts))[0]['generated_text']\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "e6b34370-0815-42aa-929e-d32fb40f3316",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"March 10, 1974\n",
"\n",
"To my future self:\n",
"\n",
"I guess it’s time for him to have some fun, but I rather fear he shall not do so. If he wishesto be seen again in the future, I am so grateful that you will make a friend.In the meantime\n",
"\n",
"Yours sincerely,\n",
"Jordan\n",
"\n"
]
}
],
"source": [
"grammar = tracery.Grammar(rules)\n",
"grammar.add_modifiers(base_english)\n",
"\n",
"diary_entry = grammar.flatten(\"#diary_entry#\")\n",
"print(diary_entry)\n",
"print()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7d4c730-f2d6-404d-ad92-6dbd30f4c90a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment