@armancohan
Created March 29, 2022 00:51
from transformers import (
    AutoTokenizer,
    LEDConfig,
    LEDForConditionalGeneration,
)
from datasets import load_dataset
import random
random.seed(2)
BATCH_SIZE = 4
tokenizer = AutoTokenizer.from_pretrained("allenai/PRIMERA")
config = LEDConfig.from_pretrained("allenai/PRIMERA")
model = LEDForConditionalGeneration.from_pretrained("allenai/PRIMERA")
model.to("cuda")
model.eval()
hotpot = load_dataset("hotpot_qa", "distractor")
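# each "distractor" example has a question, a short answer string, and a "context"
# dict of parallel lists: context["title"][j] is the j-th paragraph's title and
# context["sentences"][j] is that paragraph's list of sentences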
# select 32 random example indices from the training split
rands = random.sample(range(len(hotpot["train"])), 32)
contexts = []
answers = []
for i in rands:
    example = hotpot["train"][i]
    context_str = ""
    # join the paragraphs into one long string, inserting a <doc-sep> token after each
    for title, par in zip(example["context"]["title"], example["context"]["sentences"]):
        context_str += title + " " + "".join(par) + tokenizer.additional_special_tokens[0]
    # prepend the question, also followed by <doc-sep>
    full_context = example["question"] + tokenizer.additional_special_tokens[0] + context_str
    answer = example["answer"]
    contexts.append(full_context)
    answers.append(answer)
# break contexts into smaller batches
context_batch = [contexts[i : i + BATCH_SIZE] for i in range(0, len(contexts), BATCH_SIZE)]
answer_batch = [answers[i : i + BATCH_SIZE] for i in range(0, len(answers), BATCH_SIZE)]
generations = []
# generate batch by batch; the gold answers are not needed at generation time
# (they are compared against the decoded outputs below)
for context_chunk in context_batch:
    # truncation=True is required for max_length to actually truncate;
    # PRIMERA's LED encoder accepts inputs up to 4096 tokens
    inputs = tokenizer.batch_encode_plus(
        context_chunk, return_tensors="pt", padding="longest", truncation=True, max_length=4096
    )
    inputs = inputs.to("cuda")
    # put global attention on the <doc-sep> tokens (LED expects a 0/1 mask)
    global_attn_mask = (inputs.input_ids == tokenizer.additional_special_tokens_ids[0]).long()
    generated_ids = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        global_attention_mask=global_attn_mask,
        max_length=20,
    )
    for ids in generated_ids:
        generations.append(tokenizer.decode(ids, skip_special_tokens=True))
print(generations)
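
The gist collects the gold answers but never compares them with the model output. A rough sanity check along the following lines could be appended; the token-level F1 helper is an illustrative re-implementation in the spirit of the HotpotQA answer metric, not the official evaluation script.

# Sketch: score the generations against the gold answers collected above.
from collections import Counter

def token_f1(prediction, gold):
    # token-overlap F1 between a predicted and a gold answer string
    pred_tokens = prediction.lower().split()
    gold_tokens = gold.lower().split()
    common = Counter(pred_tokens) & Counter(gold_tokens)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)

scores = [token_f1(gen, gold) for gen, gold in zip(generations, answers)]
for gen, gold, f1 in zip(generations, answers, scores):
    print(f"pred: {gen!r}  gold: {gold!r}  F1: {f1:.2f}")
print(f"mean token F1 over {len(scores)} examples: {sum(scores) / len(scores):.3f}")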