Butyrka: beam search for Nikita's autoencoder
import numpy as np


def get_next_token_probabilities(curr_token_sequence, z, embed_fn, decode_fn):
    # Embed the tokens generated so far and run the decoder conditioned on the
    # latent code `z`; return the distribution over the next token.
    curr_tokens_embeddings = embed_fn([curr_token_sequence])
    return decode_fn(curr_tokens_embeddings, z)[0][-1]


def beam_search_step(current_token_sequence, get_next_token_probabilities, depth=3, width=4):
    """
    Look `depth` tokens ahead, keeping the `width` most probable continuations
    at each step.

    returns: next token guess
    """
    candidates = [(1.0, [])]  # (probability, candidate_appendix)
    for d in range(depth):
        next_candidates = []  # ...candidate appendices of length d+1
        # try guessing the next token for every surviving candidate
        for candidate_proba, candidate in candidates:
            next_token_proba = get_next_token_probabilities(current_token_sequence + candidate)
            # take the `width` most probable next tokens
            most_probable_tokens = np.argpartition(next_token_proba, -width)[-width:]
            most_probable_probs = next_token_proba[most_probable_tokens]
            next_candidates += \
                [(p * candidate_proba, candidate + [token])
                 for p, token in zip(most_probable_probs, most_probable_tokens)]
        # keep only the `width` best candidates for the next lookahead step
        candidates = next_candidates
        candidates.sort(reverse=True)
        candidates = candidates[:width]
    # commit to the first token of the best candidate
    return candidates[0][1][0]
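# --- Illustration (not part of the original gist) -------------------------
# A minimal sketch of how `beam_search_step` behaves on a toy next-token
# distribution. `toy_next_token_probabilities` is a made-up stand-in for the
# autoencoder decoder, used only to check the search mechanics.
def toy_next_token_probabilities(token_sequence):
    # Vocabulary of 4 tokens; this "model" simply prefers token (last + 1) % 4.
    probs = np.full(4, 0.1)
    last = token_sequence[-1] if token_sequence else 0
    probs[(last + 1) % 4] = 0.7
    return probs / probs.sum()

# Starting from token 0, the most probable continuation is 1, 2, 3, ...,
# so the first beam-search guess should be token 1.
assert beam_search_step([0], toy_next_token_probabilities, depth=3, width=2) == 1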
############# Usage example ################

text_1_ind = 100
text_2_ind = 200

print("".join(tokenized_texts[text_1_ind]))
print()
print("".join(tokenized_texts[text_2_ind]))

# Encode both texts and interpolate between their latent codes
X_test = X[[text_1_ind, text_2_ind]]
h_test = encode_fn(X_test)

alpha = 0.8
h_test = (h_test[0] * alpha + (1 - alpha) * h_test[1]).reshape((1, 1024))

print('Without beam search:')
curr_tokens = [tokens_indices[START_TOKEN]]
while curr_tokens[-1] != tokens_indices[END_TOKEN]:
    # greedy decoding: always pick the single most probable next token
    emb_test = embed_fn([curr_tokens])
    probs = decode_fn(emb_test, h_test)[0][-1]
    curr_tokens.append(np.argmax(probs))
print(''.join(indices_tokens[t] for t in curr_tokens))

# *********************************************************

print()
print('With beam search:')
curr_tokens = [tokens_indices[START_TOKEN]]
get_next_token_probabilities_AE = \
    lambda c_t_s: get_next_token_probabilities(c_t_s, h_test, embed_fn, decode_fn)
while curr_tokens[-1] != tokens_indices[END_TOKEN]:
    curr_tokens.append(beam_search_step(curr_tokens, get_next_token_probabilities_AE, depth=5, width=5))
    if curr_tokens[-1] == tokens_indices['\n']:
        print('One more line...')
print()
print(''.join(indices_tokens[t] for t in curr_tokens))
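# --- Optional variant (not in the original gist) --------------------------
# With depth=5 a candidate's score is a product of five probabilities, which
# can become very small and underflow. A common alternative is to accumulate
# log-probabilities instead; this sketch assumes the same
# `get_next_token_probabilities` interface as `beam_search_step` above.
def beam_search_step_log(current_token_sequence, get_next_token_probabilities, depth=3, width=4):
    candidates = [(0.0, [])]  # (log-probability, candidate_appendix)
    for _ in range(depth):
        next_candidates = []
        for candidate_logp, candidate in candidates:
            log_proba = np.log(get_next_token_probabilities(current_token_sequence + candidate) + 1e-12)
            best_tokens = np.argpartition(log_proba, -width)[-width:]
            next_candidates += [(candidate_logp + log_proba[t], candidate + [t]) for t in best_tokens]
        next_candidates.sort(reverse=True)
        candidates = next_candidates[:width]
    return candidates[0][1][0]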