from transformers import (
    AutoTokenizer,
    LEDForConditionalGeneration,
)
from datasets import load_dataset
import random
import torch

random.seed(2)
BATCH_SIZE = 4

tokenizer = AutoTokenizer.from_pretrained("allenai/PRIMERA")
model = LEDForConditionalGeneration.from_pretrained("allenai/PRIMERA")
model.to("cuda")
model.eval()

hotpot = load_dataset("hotpot_qa", "distractor")

# select 32 random examples from the training split
rands = random.sample(range(len(hotpot["train"])), 32)

contexts = []
answers = []
for i in rands:
    example = hotpot["train"][i]
    # concatenate the title + paragraph pairs into a single long string,
    # with the <doc-sep> special token between documents
    context_str = ""
    for title, par in zip(example["context"]["title"], example["context"]["sentences"]):
        context_str += title + " " + "".join(par) + tokenizer.additional_special_tokens[0]
    # prepend the question, again separated by <doc-sep>
    full_context = example["question"] + tokenizer.additional_special_tokens[0] + context_str
    contexts.append(full_context)
    answers.append(example["answer"])

# break contexts into batches of BATCH_SIZE
context_batch = [contexts[i : i + BATCH_SIZE] for i in range(0, len(contexts), BATCH_SIZE)]

generations = []
for context in context_batch:
    # truncation=True is needed for max_length to take effect;
    # 4096 is LED's maximum encoder input length
    inputs = tokenizer(
        context, return_tensors="pt", padding="longest", truncation=True, max_length=4096
    )
    inputs = inputs.to("cuda")
    # global attention on the <doc-sep> tokens
    global_attn_mask = (inputs.input_ids == tokenizer.additional_special_tokens_ids[0]).long()
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            global_attention_mask=global_attn_mask,
            max_length=20,
        )
    for ids in generated_ids:
        generations.append(tokenizer.decode(ids, skip_special_tokens=True))
print(generations)
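Since the script collects the gold answers alongside the model generations, a quick sanity check is to score them with a normalized exact match. The snippet below is a minimal sketch meant to be appended after the loop above; the normalize helper is a hypothetical, loose SQuAD-style normalization, not part of the original gist or any PRIMERA evaluation code.

import re
import string

def normalize(text):
    # hypothetical helper: lowercase, strip punctuation, drop articles, collapse whitespace
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

# exact-match rate of generations against the gold HotpotQA answers
em = sum(normalize(g) == normalize(a) for g, a in zip(generations, answers)) / len(answers)
print(f"exact match: {em:.2%}")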