Example usage with Qwen2.5-3B: given the stable token ids for "Deep" and the unstable text " Recurre", the search returns the most probable completion that covers the unstable text.

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
>>> beam_search_coverage(stable=tokenizer.encode("Deep"), unstable=" Recurre", k=8)
' Recurrent'
The coverage condition: a token is valid at offset i of the unstable text if it is a prefix of the remaining text, or the remaining text is a prefix of the token.

for i in range(len(unstable)):
    unstable[i:].startswith(token) or token.startswith(unstable[i:])
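As a concrete illustration of that predicate, here is a small self-contained check against the unstable text " Recurre". The candidate token strings below are ones I picked for the example, not actual Qwen2.5 vocabulary entries.

unstable = " Recurre"
for token in [" Recurrent", " Rec", "urre", "xyz"]:
    offsets = [
        i for i in range(len(unstable))
        if unstable[i:].startswith(token) or token.startswith(unstable[i:])
    ]
    print(repr(token), "covers offsets", offsets)
# ' Recurrent' -> [0]  (extends the whole unstable text)
# ' Rec'       -> [0]  (prefix of the unstable text)
# 'urre'       -> [4]  (matches the tail starting at index 4)
# 'xyz'        -> []   (never valid)

The full listing below puts this predicate together with the beam search.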
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B")


def beam_search_coverage(stable: list[int], unstable: str, k=8):
    # generate the coverage table
    coverage_table = torch.empty((len(unstable), model.vocab_size), dtype=torch.bool)
    for i in range(len(unstable)):
        for j in range(model.vocab_size):
            # this cell of the coverage table is true if the token is valid
            token = tokenizer.decode(j, skip_special_tokens=True)
            coverage_table[i, j] = len(token) > 0 and (
                unstable[i:].startswith(token) or token.startswith(unstable[i:])
            )
    # maintain a list of beams that we are searching over and their probabilities
    beams = torch.tensor([stable], dtype=torch.int32)
    beam_unstable = [""]
    beam_probs = torch.tensor([1.0])
    # maintain the best beam we've seen so far
    best_beam = None
    best_prob = 0.0
    while len(beams) > 0:
        logits = model.generate(
            input_ids=beams,
            attention_mask=torch.ones_like(beams, dtype=torch.int32),
            max_new_tokens=1,
            output_logits=True,
            return_dict_in_generate=True,
        ).logits[0]
        # convert logits to probabilities and find the next top candidates
        token_probs = torch.nn.functional.softmax(logits, dim=-1)
        # mask out tokens that are not valid for the current prefix
        token_probs[~coverage_table[[len(x) for x in beam_unstable]]] = 0.0
        # find the next top k candidate beams
        candidates = (beam_probs.unsqueeze(-1) * token_probs).flatten().topk(k)
        source, token = torch.unravel_index(
            candidates.indices[candidates.values > best_prob],
            token_probs.shape,
        )
        beams = torch.hstack((beams[source], token.unsqueeze(-1)))
        beam_probs = beam_probs[source] * token_probs[source, token]
        beam_unstable = [tokenizer.decode(beam[len(stable):]) for beam in beams]
        # find any beams that exceed the max length
        i = 0
        while i < len(beam_unstable):
            if len(beam_unstable[i]) >= len(unstable):
                # this beam is complete, see if it's better than the best beam
                if beam_probs[i] > best_prob:
                    best_prob = beam_probs[i]
                    best_beam = beam_unstable[i]
                # remove the completed beam from the search
                beam_unstable.pop(i)
                beams = torch.cat((beams[:i], beams[i + 1:]), dim=0)
                beam_probs = torch.cat((beam_probs[:i], beam_probs[i + 1:]), dim=0)
            else:
                i += 1
    return best_beam
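One note on the coverage table: the nested loop above calls tokenizer.decode once per (offset, token id) pair, i.e. len(unstable) × vocab_size times. A minimal sketch of an equivalent construction that decodes each token id only once; the helper name build_coverage_table and the cached vocab list are mine, not part of the gist.

def build_coverage_table(unstable: str) -> torch.Tensor:
    # decode the whole vocabulary a single time and reuse it for every offset
    vocab = [tokenizer.decode(j, skip_special_tokens=True) for j in range(model.vocab_size)]
    table = torch.zeros((len(unstable), model.vocab_size), dtype=torch.bool)
    for i in range(len(unstable)):
        suffix = unstable[i:]
        for j, token in enumerate(vocab):
            # same predicate as above: the token extends the suffix or is a prefix of it
            table[i, j] = len(token) > 0 and (
                suffix.startswith(token) or token.startswith(suffix)
            )
    return table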
A variant of the masking step: applying the mask to the logits before the softmax, rather than zeroing probabilities afterwards, renormalizes the distribution over only the tokens that cover the unstable text.

# mask out tokens that are not valid for the current prefix
logits[~coverage_table[[len(x) for x in beam_unstable]]] = float('-inf')
# convert logits to probabilities and find the next top candidates
token_probs = torch.nn.functional.softmax(logits, dim=-1)
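A tiny self-contained illustration of the difference, with made-up logits and a made-up validity mask:

import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
valid = torch.tensor([True, False, True, False])

# zeroing after the softmax: the surviving probabilities keep their original
# values, so the distribution no longer sums to 1
p_after = torch.softmax(logits, dim=-1)
p_after[~valid] = 0.0
print(p_after.sum())  # < 1

# masking before the softmax: probability mass is redistributed over the valid tokens
p_before = torch.softmax(logits.masked_fill(~valid, float("-inf")), dim=-1)
print(p_before.sum())  # 1.0

Which behavior is preferable is a design choice: the first keeps beam scores tied to the model's unconstrained confidence, while the second normalizes them over only the allowed continuations.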