@kevmo314
Last active June 2, 2025 07:45
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
>>> beam_search_coverage(stable=tokenizer.encode("Deep"), unstable=" Recurre", k=8)
' Recurrent'
A token is considered to cover position i of the unstable text if it matches the remaining suffix unstable[i:] in either direction, i.e. the token is a prefix of that suffix or that suffix is a prefix of the token:

for i in range(len(unstable)):
    unstable[i:].startswith(token) or token.startswith(unstable[i:])
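A minimal standalone sketch of that check, assuming some made-up token strings (these are illustrative examples, not actual Qwen2.5 vocabulary entries):

unstable = " Recurre"
for token in [" Recurrent", " Rec", "urre", "xyz"]:
    covered = any(
        unstable[i:].startswith(token) or token.startswith(unstable[i:])
        for i in range(len(unstable))
    )
    print(repr(token), covered)
# ' Recurrent' True   (extends past the end of the unstable text)
# ' Rec'       True   (prefix of the unstable text)
# 'urre'       True   (matches the suffix starting at position 4)
# 'xyz'        False  (never lines up with any suffix)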
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B")

def beam_search_coverage(stable: list[int], unstable: str, k=8):
    # generate the coverage table
    coverage_table = torch.empty((len(unstable), model.vocab_size), dtype=torch.bool)
    for i in range(len(unstable)):
        for j in range(model.vocab_size):
            # this cell of the coverage table is true if the token is valid
            token = tokenizer.decode(j, skip_special_tokens=True)
            coverage_table[i, j] = len(token) > 0 and (
                unstable[i:].startswith(token) or token.startswith(unstable[i:])
            )
    # maintain a list of beams that we are searching over and their probabilities
    beams = torch.tensor([stable], dtype=torch.int32)
    beam_unstable = [""]
    beam_probs = torch.tensor([1.0])
    # maintain the best beam we've seen so far
    best_beam = None
    best_prob = 0.0
    while len(beams) > 0:
        logits = model.generate(
            input_ids=beams,
            attention_mask=torch.ones_like(beams, dtype=torch.int32),
            max_new_tokens=1,
            output_logits=True,
            return_dict_in_generate=True,
        ).logits[0]
        # convert logits to probabilities and find the next top candidates
        token_probs = torch.nn.functional.softmax(logits, dim=-1)
        # mask out tokens that are not valid for the current prefix
        token_probs[~coverage_table[[len(x) for x in beam_unstable]]] = 0.0
        # find the next top k candidate beams
        candidates = (beam_probs.unsqueeze(-1) * token_probs).flatten().topk(k)
        source, token = torch.unravel_index(
            candidates.indices[candidates.values > best_prob],
            token_probs.shape,
        )
        beams = torch.hstack((beams[source], token.unsqueeze(-1)))
        beam_probs = beam_probs[source] * token_probs[source, token]
        beam_unstable = [tokenizer.decode(beam[len(stable):]) for beam in beams]
        # find any beams that exceed the max length
        i = 0
        while i < len(beam_unstable):
            if len(beam_unstable[i]) >= len(unstable):
                # this beam is complete, see if it's better than the best beam
                if beam_probs[i] > best_prob:
                    best_prob = beam_probs[i]
                    best_beam = beam_unstable[i]
                # remove the completed beam from the search
                beam_unstable.pop(i)
                beams = torch.cat((beams[:i], beams[i + 1:]), dim=0)
                beam_probs = torch.cat((beam_probs[:i], beam_probs[i + 1:]), dim=0)
            else:
                i += 1
    return best_beam
A variant of the masking step: instead of zeroing probabilities after the softmax, set the invalid logits to -inf before the softmax, so the probability mass is renormalized over only the tokens that still cover the unstable text:

# mask out tokens that are not valid for the current prefix
logits[~coverage_table[[len(x) for x in beam_unstable]]] = float('-inf')
# convert logits to probabilities and find the next top candidates
token_probs = torch.nn.functional.softmax(logits, dim=-1)
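A toy illustration of the difference, independent of the gist's model (the numeric values below are illustrative only): zeroing after the softmax keeps the model's original probabilities for the surviving tokens, while masking the logits first renormalizes them to sum to 1 over the valid set.

import torch

logits = torch.tensor([2.0, 1.0, 0.0])
valid = torch.tensor([True, False, True])

# zero out after the softmax: surviving tokens keep their original probabilities
probs_after = torch.softmax(logits, dim=-1)
probs_after[~valid] = 0.0
print(probs_after)   # ~[0.665, 0.000, 0.090], sums to ~0.755

# mask the logits before the softmax: probabilities renormalize over valid tokens
masked = logits.clone()
masked[~valid] = float("-inf")
probs_before = torch.softmax(masked, dim=-1)
print(probs_before)  # ~[0.881, 0.000, 0.119], sums to 1.0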