@shrubb
Created May 28, 2017 19:42
Butyrka: beam search for Nikita's autoencoder

import numpy as np

# embed_fn, decode_fn, encode_fn and the token/index tables used below are
# assumed to be defined by the accompanying autoencoder notebook.


def get_next_token_probabilities(curr_token_sequence, z, embed_fn, decode_fn):
    """Return the decoder's distribution over the next token, given latent code z."""
    curr_tokens_embeddings = embed_fn([curr_token_sequence])
    return decode_fn(curr_tokens_embeddings, z)[0][-1]

def beam_search_step(current_token_sequence, get_next_token_probabilities, depth=3, width=4):
    """
    Look `depth` tokens ahead, keeping only the `width` most probable
    candidate continuations at every step.

    returns: the next token guess (first token of the best candidate)
    """
    candidates = [(1.0, [])]  # (probability, candidate_appendix)
    for d in range(depth):
        next_candidates = []  # ...appendices of length d+1
        # try guessing the next token for every surviving candidate
        for candidate_proba, candidate in candidates:
            next_token_proba = get_next_token_probabilities(current_token_sequence + candidate)
            # take the `width` most probable next tokens
            most_probable_tokens = np.argpartition(next_token_proba, -width)[-width:]
            most_probable_probs = next_token_proba[most_probable_tokens]
            next_candidates += [(p * candidate_proba, candidate + [token])
                                for p, token in zip(most_probable_probs, most_probable_tokens)]
        # prune the beam: keep the `width` most probable candidates
        candidates = sorted(next_candidates, key=lambda c: c[0], reverse=True)[:width]
    return candidates[0][1][0]
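
# A minimal, self-contained sanity check of beam_search_step (not part of the
# original gist). `fake_next_token_probabilities` is a made-up stand-in for
# the decoder over a 10-token vocabulary: token (last + 1) % 10 gets
# probability 0.55 and the other nine tokens share the remaining 0.45.
def fake_next_token_probabilities(sequence):
    probs = np.full(10, 0.05)
    probs[(sequence[-1] + 1) % 10] = 0.55
    return probs

# Starting from token 0, the most probable continuation is 1, 2, 3, ...,
# so the next-token guess should be 1.
assert beam_search_step([0], fake_next_token_probabilities, depth=3, width=4) == 1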

############# Usage example ################

# Decode from a convex combination of two texts' latent codes.
text_1_ind = 100
text_2_ind = 200
print("".join(tokenized_texts[text_1_ind]))
print()
print("".join(tokenized_texts[text_2_ind]))

X_test = X[[text_1_ind, text_2_ind]]
h_test = encode_fn(X_test)
# Interpolate between the two 1024-dimensional latent vectors:
# h = alpha * h_1 + (1 - alpha) * h_2
alpha = 0.8
h_test = (alpha * h_test[0] + (1 - alpha) * h_test[1]).reshape((1, 1024))
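
# Illustrative check of the blend above (not from the original gist): with
# alpha = 0.8 the mixture sits 80% of the way toward the first latent vector.
assert np.allclose(0.8 * np.array([1.0, 0.0]) + 0.2 * np.array([0.0, 1.0]), [0.8, 0.2])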

print('Without beam search:')
# Greedy decoding: always append the single most probable next token.
curr_tokens = [tokens_indices[START_TOKEN]]
while curr_tokens[-1] != tokens_indices[END_TOKEN]:
    emb_test = embed_fn([curr_tokens])
    probs = decode_fn(emb_test, h_test)[0][-1]
    curr_tokens.append(np.argmax(probs))
print(''.join(indices_tokens[t] for t in curr_tokens))

# *********************************************************
print()
print('With beam search:')
curr_tokens = [tokens_indices[START_TOKEN]]
get_next_token_probabilities_AE = \
    lambda c_t_s: get_next_token_probabilities(c_t_s, h_test, embed_fn, decode_fn)
while curr_tokens[-1] != tokens_indices[END_TOKEN]:
    curr_tokens.append(
        beam_search_step(curr_tokens, get_next_token_probabilities_AE, depth=5, width=5))
    # report progress whenever a full line has been generated
    if curr_tokens[-1] == tokens_indices['\n']:
        print('One more line...')
print()
print(''.join(indices_tokens[t] for t in curr_tokens))