| """ | |
| Beam decoder for tensorflow | |
| Sample usage: | |
| ``` | |
| from tf_beam_decoder import beam_decoder | |
| decoded_sparse, decoded_logprobs = beam_decoder( | |
| cell=cell, | |
| beam_size=7, | |
| stop_token=2, | |
| initial_state=initial_state, | |
| initial_input=initial_input, | |
| tokens_to_inputs_fn=lambda tokens: tf.nn.embedding_lookup(my_embedding, tokens), | |
| ) | |
| ``` | |
| See the `beam_decoder` function for complete documentation. (Only the | |
| `beam_decoder` function is part of the public API here.) | |
| """ | |
| import tensorflow as tf | |
| import numpy as np | |
| from tensorflow.python.util import nest | |
| # %% | |
| def nest_map(func, nested): | |
| if not nest.is_sequence(nested): | |
| return func(nested) | |
| flat = nest.flatten(nested) | |
| return nest.pack_sequence_as(nested, list(map(func, flat))) | |
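
# Illustrative note (not in the original gist): nest_map applies `func` to every
# tensor in a (possibly nested) structure and repacks the result with the same
# nesting. For example, with an LSTM-style state tuple:
#   state = (c, h)
#   nest_map(lambda t: t * 2.0, state)  # -> (c * 2.0, h * 2.0), still a tuple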
# %%

def sparse_boolean_mask(tensor, mask):
    """
    Creates a sparse tensor from masked elements of `tensor`

    Inputs:
      tensor: a 2-D tensor, [batch_size, T]
      mask: a 2-D mask, [batch_size, T]

    Output: a 2-D sparse tensor
    """
    mask_lens = tf.reduce_sum(tf.cast(mask, tf.int32), -1, keep_dims=True)
    mask_shape = tf.shape(mask)
    left_shifted_mask = tf.tile(
        tf.expand_dims(tf.range(mask_shape[1]), 0),
        [mask_shape[0], 1]
    ) < mask_lens
    return tf.SparseTensor(
        indices=tf.where(left_shifted_mask),
        values=tf.boolean_mask(tensor, mask),
        shape=tf.cast(tf.pack([mask_shape[0], tf.reduce_max(mask_lens)]), tf.int64)  # For 2D only
    )
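
# Worked example (not in the original gist): the masked values of each row are
# packed to the left, and the dense shape is trimmed to the longest row.
#   tensor = [[4, 5, 6], [7, 8, 9]]
#   mask   = [[True, True, False], [True, False, False]]
# yields a SparseTensor with dense shape [2, 2] that densifies to
#   [[4, 5], [7, 0]]   # values [4, 5, 7]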
# %%

def flat_batch_gather(flat_params, indices, validate_indices=None,
                      batch_size=None,
                      options_size=None):
    """
    Gather slices from `flat_params` according to `indices`, separately for each
    example in a batch.

    output[(b * indices_size + i), :, ..., :] = flat_params[(b * options_size + indices[b, i]), :, ..., :]

    The arguments `batch_size` and `options_size`, if provided, are used instead
    of looking up the shape from the inputs. This may help avoid redundant
    computation (TODO: figure out if tensorflow's optimizer can do this automatically)

    Args:
      flat_params: A `Tensor`, [batch_size * options_size, ...]
      indices: A `Tensor`, [batch_size, indices_size]
      validate_indices: An optional `bool`. Defaults to `True`
      batch_size: (optional) an integer or scalar tensor representing the batch size
      options_size: (optional) an integer or scalar Tensor representing the number of options to choose from
    """
    if batch_size is None:
        batch_size = indices.get_shape()[0].value
        if batch_size is None:
            batch_size = tf.shape(indices)[0]

    if options_size is None:
        options_size = flat_params.get_shape()[0].value
        if options_size is None:
            options_size = tf.shape(flat_params)[0] // batch_size
        else:
            options_size = options_size // batch_size

    indices_offsets = tf.reshape(tf.range(batch_size) * options_size, [-1] + [1] * (len(indices.get_shape())-1))
    indices_into_flat = indices + tf.cast(indices_offsets, indices.dtype)
    flat_indices_into_flat = tf.reshape(indices_into_flat, [-1])

    return tf.gather(flat_params, flat_indices_into_flat, validate_indices=validate_indices)
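
# Worked example (not in the original gist): with batch_size=2, options_size=3,
#   flat_params = [p0, p1, p2, p3, p4, p5]   # rows 0-2: example 0; rows 3-5: example 1
#   indices     = [[2, 0], [1, 1]]
# per-example offsets [0, 3] give flat indices [2, 0, 4, 4], so the output
# rows are [p2, p0, p4, p4].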
def batch_gather(params, indices, validate_indices=None,
                 batch_size=None,
                 options_size=None):
    """
    Gather slices from `params` according to `indices`, separately for each
    example in a batch.

    output[b, i, ..., j, :, ..., :] = params[b, indices[b, i, ..., j], :, ..., :]

    The arguments `batch_size` and `options_size`, if provided, are used instead
    of looking up the shape from the inputs. This may help avoid redundant
    computation (TODO: figure out if tensorflow's optimizer can do this automatically)

    Args:
      params: A `Tensor`, [batch_size, options_size, ...]
      indices: A `Tensor`, [batch_size, ...]
      validate_indices: An optional `bool`. Defaults to `True`
      batch_size: (optional) an integer or scalar tensor representing the batch size
      options_size: (optional) an integer or scalar Tensor representing the number of options to choose from
    """
    if batch_size is None:
        batch_size = params.get_shape()[0].merge_with(indices.get_shape()[0]).value
        if batch_size is None:
            batch_size = tf.shape(indices)[0]

    if options_size is None:
        options_size = params.get_shape()[1].value
        if options_size is None:
            options_size = tf.shape(params)[1]

    batch_size_times_options_size = batch_size * options_size

    # TODO(nikita): consider using gather_nd. However as of 1/9/2017 gather_nd
    # has no gradients implemented.
    flat_params = tf.reshape(params, tf.concat(0, [[batch_size_times_options_size], tf.shape(params)[2:]]))

    indices_offsets = tf.reshape(tf.range(batch_size) * options_size, [-1] + [1] * (len(indices.get_shape())-1))
    indices_into_flat = indices + tf.cast(indices_offsets, indices.dtype)

    return tf.gather(flat_params, indices_into_flat, validate_indices=validate_indices)
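
# Worked example (not in the original gist): with
#   params  = [[a, b, c], [d, e, f]]   # [batch_size=2, options_size=3]
#   indices = [[2, 2], [0, 1]]
# params is flattened to [a, b, c, d, e, f], the second row's indices are
# shifted by 3 to [3, 4], and the result is [[c, c], [d, e]].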
# %%

class BeamFlattenWrapper(tf.nn.rnn_cell.RNNCell):
    def __init__(self, cell, beam_size):
        self.cell = cell
        self.beam_size = beam_size

    def merge_batch_beam(self, tensor):
        remaining_shape = tf.shape(tensor)[2:]
        res = tf.reshape(tensor, tf.concat(0, [[-1], remaining_shape]))
        res.set_shape(tf.TensorShape((None,)).concatenate(tensor.get_shape()[2:]))
        return res

    def unmerge_batch_beam(self, tensor):
        remaining_shape = tf.shape(tensor)[1:]
        res = tf.reshape(tensor, tf.concat(0, [[-1, self.beam_size], remaining_shape]))
        res.set_shape(tf.TensorShape((None, self.beam_size)).concatenate(tensor.get_shape()[1:]))
        return res

    def prepend_beam_size(self, element):
        return tf.TensorShape(self.beam_size).concatenate(element)

    def tile_along_beam(self, state):
        if nest.is_sequence(state):
            return nest_map(
                lambda val: self.tile_along_beam(val),
                state
            )

        if not isinstance(state, tf.Tensor):
            raise ValueError("State should be a sequence or tensor")

        tensor = state
        tensor_shape = tensor.get_shape().with_rank_at_least(1)
        new_tensor_shape = tensor_shape[:1].concatenate(self.beam_size).concatenate(tensor_shape[1:])

        dynamic_tensor_shape = tf.unpack(tf.shape(tensor))
        res = tf.expand_dims(tensor, 1)
        res = tf.tile(res, [1, self.beam_size] + [1] * (tensor_shape.ndims-1))
        res = tf.reshape(res, [-1, self.beam_size] + list(dynamic_tensor_shape[1:]))
        res.set_shape(new_tensor_shape)
        return res

    def __call__(self, inputs, state, scope=None):
        flat_inputs = nest_map(self.merge_batch_beam, inputs)
        flat_state = nest_map(self.merge_batch_beam, state)

        flat_output, flat_next_state = self.cell(flat_inputs, flat_state, scope=scope)

        output = nest_map(self.unmerge_batch_beam, flat_output)
        next_state = nest_map(self.unmerge_batch_beam, flat_next_state)
        return output, next_state

    @property
    def state_size(self):
        return nest_map(self.prepend_beam_size, self.cell.state_size)

    @property
    def output_size(self):
        return nest_map(self.prepend_beam_size, self.cell.output_size)
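
# Shape sketch (not in the original gist): BeamFlattenWrapper folds the beam
# axis into the batch axis around each call to the wrapped cell:
#   inputs [batch, beam, d] --merge--> [batch*beam, d] --cell-->
#   output [batch*beam, h]  --unmerge--> [batch, beam, h]
# This is only valid for cells that treat batch rows independently.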
# %%

class BeamReplicateWrapper(tf.nn.rnn_cell.RNNCell):
    def __init__(self, cell, beam_size):
        self.cell = cell
        self.beam_size = beam_size

    def prepend_beam_size(self, element):
        return tf.TensorShape(self.beam_size).concatenate(element)

    def tile_along_beam(self, state):
        if nest.is_sequence(state):
            return nest_map(
                lambda val: self.tile_along_beam(val),
                state
            )

        if not isinstance(state, tf.Tensor):
            raise ValueError("State should be a sequence or tensor")

        tensor = state
        tensor_shape = tensor.get_shape().with_rank_at_least(1)
        new_tensor_shape = tensor_shape[:1].concatenate(self.beam_size).concatenate(tensor_shape[1:])

        dynamic_tensor_shape = tf.unpack(tf.shape(tensor))
        res = tf.expand_dims(tensor, 1)
        res = tf.tile(res, [1, self.beam_size] + [1] * (tensor_shape.ndims-1))
        res = tf.reshape(res, [-1, self.beam_size] + list(dynamic_tensor_shape[1:]))
        res.set_shape(new_tensor_shape)
        return res

    def __call__(self, inputs, state, scope=None):
        varscope = scope or tf.get_variable_scope()

        flat_inputs = nest.flatten(inputs)
        flat_state = nest.flatten(state)

        flat_inputs_unstacked = list(zip(*[tf.unstack(tensor, num=self.beam_size, axis=1) for tensor in flat_inputs]))
        flat_state_unstacked = list(zip(*[tf.unstack(tensor, num=self.beam_size, axis=1) for tensor in flat_state]))

        flat_output_unstacked = []
        flat_next_state_unstacked = []

        output_sample = None
        next_state_sample = None

        for i, (inputs_k, state_k) in enumerate(zip(flat_inputs_unstacked, flat_state_unstacked)):
            inputs_k = nest.pack_sequence_as(inputs, inputs_k)
            state_k = nest.pack_sequence_as(state, state_k)

            # TODO(nikita): is this scope stuff correct?
            if i == 0:
                output_k, next_state_k = self.cell(inputs_k, state_k, scope=scope)
            else:
                with tf.variable_scope(varscope, reuse=True):
                    output_k, next_state_k = self.cell(inputs_k, state_k, scope=varscope if scope is not None else None)

            flat_output_unstacked.append(nest.flatten(output_k))
            flat_next_state_unstacked.append(nest.flatten(next_state_k))

            output_sample = output_k
            next_state_sample = next_state_k

        flat_output = [tf.stack(tensors, axis=1) for tensors in zip(*flat_output_unstacked)]
        flat_next_state = [tf.stack(tensors, axis=1) for tensors in zip(*flat_next_state_unstacked)]

        output = nest.pack_sequence_as(output_sample, flat_output)
        next_state = nest.pack_sequence_as(next_state_sample, flat_next_state)

        return output, next_state

    @property
    def state_size(self):
        return nest_map(self.prepend_beam_size, self.cell.state_size)

    @property
    def output_size(self):
        return nest_map(self.prepend_beam_size, self.cell.output_size)
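
# Note (not in the original gist): in contrast to BeamFlattenWrapper, this
# wrapper unstacks the beam axis and calls the wrapped cell beam_size times on
# ordinary [batch, ...] slices, reusing the same variables for every call. This
# stays correct for cells that are sensitive to batch positions (e.g. attention
# over a batch-aligned memory), at the cost of beam_size copies of the graph.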
# %%

class BeamSearchHelper(object):
    # Our beam scores are stored in a fixed-size tensor, but sometimes the
    # tensor size is greater than the number of elements actually on the beam.
    # The invalid elements are assigned a highly negative score.
    # However, top_k errors if any of the inputs have a score of -inf, so we use
    # a large negative constant instead
    INVALID_SCORE = -1e18

    def __init__(self, cell, beam_size, stop_token, initial_state, initial_input,
                 score_upper_bound=None,
                 max_len=100,
                 outputs_to_score_fn=None,
                 tokens_to_inputs_fn=None,
                 cell_transform='default',
                 scope=None
                 ):
        self.beam_size = beam_size
        self.stop_token = stop_token
        self.max_len = max_len
        self.scope = scope

        if score_upper_bound is None and outputs_to_score_fn is None:
            self.score_upper_bound = 0.0
        elif score_upper_bound is None or score_upper_bound > 3e38:
            # Note: 3e38 is just a little smaller than the largest float32
            # Second condition allows for Infinity as a synonym for None
            self.score_upper_bound = None
        else:
            self.score_upper_bound = float(score_upper_bound)

        if self.max_len is None and self.score_upper_bound is None:
            raise ValueError("Beam search needs a stopping criterion. Please provide max_len or score_upper_bound.")

        if cell_transform == 'default':
            if type(cell) in [tf.nn.rnn_cell.LSTMCell,
                              tf.nn.rnn_cell.GRUCell,
                              tf.nn.rnn_cell.BasicLSTMCell,
                              tf.nn.rnn_cell.BasicRNNCell]:
                cell_transform = 'flatten'
            else:
                cell_transform = 'replicate'

        if cell_transform == 'none':
            self.cell = cell
            self.initial_state = initial_state
            self.initial_input = initial_input
        elif cell_transform == 'flatten':
            self.cell = BeamFlattenWrapper(cell, self.beam_size)
            self.initial_state = self.cell.tile_along_beam(initial_state)
            self.initial_input = self.cell.tile_along_beam(initial_input)
        elif cell_transform == 'replicate':
            self.cell = BeamReplicateWrapper(cell, self.beam_size)
            self.initial_state = self.cell.tile_along_beam(initial_state)
            self.initial_input = self.cell.tile_along_beam(initial_input)
        else:
            raise ValueError("cell_transform must be one of: 'default', 'flatten', 'replicate', 'none'")

        self._cell_transform_used = cell_transform

        if outputs_to_score_fn is not None:
            self.outputs_to_score_fn = outputs_to_score_fn
        if tokens_to_inputs_fn is not None:
            self.tokens_to_inputs_fn = tokens_to_inputs_fn

        batch_size = tf.Dimension(None)
        if not nest.is_sequence(self.initial_state):
            batch_size = batch_size.merge_with(self.initial_state.get_shape()[0])
        else:
            for tensor in nest.flatten(self.initial_state):
                batch_size = batch_size.merge_with(tensor.get_shape()[0])
        if not nest.is_sequence(self.initial_input):
            batch_size = batch_size.merge_with(self.initial_input.get_shape()[0])
        else:
            for tensor in nest.flatten(self.initial_input):
                batch_size = batch_size.merge_with(tensor.get_shape()[0])

        self.inferred_batch_size = batch_size.value
        if self.inferred_batch_size is not None:
            self.batch_size = self.inferred_batch_size
        else:
            if not nest.is_sequence(self.initial_state):
                self.batch_size = tf.shape(self.initial_state)[0]
            else:
                self.batch_size = tf.shape(list(nest.flatten(self.initial_state))[0])[0]

        self.inferred_batch_size_times_beam_size = None
        if self.inferred_batch_size is not None:
            self.inferred_batch_size_times_beam_size = self.inferred_batch_size * self.beam_size

        self.batch_size_times_beam_size = self.batch_size * self.beam_size

    def outputs_to_score_fn(self, cell_output):
        return tf.nn.log_softmax(cell_output)

    def tokens_to_inputs_fn(self, symbols):
        return tf.expand_dims(symbols, -1)

    def beam_setup(self, time):
        emit_output = None
        next_cell_state = self.initial_state
        next_input = self.initial_input

        # Set up the beam search tracking state
        cand_symbols = tf.fill([self.batch_size, 0], tf.constant(self.stop_token, dtype=tf.int32))
        cand_logprobs = tf.ones((self.batch_size,), dtype=tf.float32) * -float('inf')

        first_in_beam_mask = tf.equal(tf.range(self.batch_size_times_beam_size) % self.beam_size, 0)

        beam_symbols = tf.fill([self.batch_size_times_beam_size, 0], tf.constant(self.stop_token, dtype=tf.int32))
        beam_logprobs = tf.select(
            first_in_beam_mask,
            tf.fill([self.batch_size_times_beam_size], 0.0),
            tf.fill([self.batch_size_times_beam_size], self.INVALID_SCORE)
        )

        # Set up correct dimensions for maintaining loop invariants.
        # Note that the last dimension (initialized to zero) is not a loop invariant,
        # so we need to clear it. TODO(nikita): is there a public API for clearing shape
        # inference so that _shape is not necessary?
        cand_symbols._shape = tf.TensorShape((self.inferred_batch_size, None))
        cand_logprobs._shape = tf.TensorShape((self.inferred_batch_size,))
        beam_symbols._shape = tf.TensorShape((self.inferred_batch_size_times_beam_size, None))
        beam_logprobs._shape = tf.TensorShape((self.inferred_batch_size_times_beam_size,))

        next_loop_state = (
            cand_symbols,
            cand_logprobs,
            beam_symbols,
            beam_logprobs,
        )

        emit_output = tf.zeros(self.cell.output_size)
        elements_finished = tf.zeros([self.batch_size], dtype=tf.bool)

        return (elements_finished, next_input, next_cell_state,
                emit_output, next_loop_state)

    def beam_loop(self, time, cell_output, cell_state, loop_state):
        (
            past_cand_symbols,   # [batch_size, time-1]
            past_cand_logprobs,  # [batch_size]
            past_beam_symbols,   # [batch_size*beam_size, time-1], right-aligned
            past_beam_logprobs,  # [batch_size*beam_size]
        ) = loop_state
        # We don't actually use this, but emit_output is required to match the
        # cell output size specification. Otherwise we would leave this as None.
        emit_output = cell_output

        # 1. Get scores for all candidate sequences
        logprobs = self.outputs_to_score_fn(cell_output)

        try:
            num_classes = int(logprobs.get_shape()[-1])
        except:
            # Shape inference failed
            num_classes = tf.shape(logprobs)[-1]

        logprobs_batched = tf.reshape(logprobs + tf.expand_dims(tf.reshape(past_beam_logprobs, [self.batch_size, self.beam_size]), 2),
                                      [self.batch_size, self.beam_size * num_classes])

        # 2. Determine which states to pass to next iteration
        # TODO(nikita): consider using slice+fill+concat instead of adding a mask
        nondone_mask = tf.reshape(
            tf.cast(tf.equal(tf.range(num_classes), self.stop_token), tf.float32) * self.INVALID_SCORE,
            [1, 1, num_classes])
        nondone_mask = tf.reshape(tf.tile(nondone_mask, [1, self.beam_size, 1]),
                                  [-1, self.beam_size * num_classes])

        beam_logprobs, indices = tf.nn.top_k(logprobs_batched + nondone_mask, self.beam_size)
        beam_logprobs = tf.reshape(beam_logprobs, [-1])

        # For continuing to the next symbols
        symbols = indices % num_classes       # [batch_size, self.beam_size]
        parent_refs = indices // num_classes  # [batch_size, self.beam_size]
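        # Worked example (comment not in the original gist): with beam_size=2
        # and num_classes=3, the flattened axis has 2*3=6 scores per batch
        # element; a top-k index of 4 decodes to parent beam 4 // 3 = 1 and
        # symbol 4 % 3 = 1, i.e. extend beam 1 with token 1.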
        symbols_history = flat_batch_gather(past_beam_symbols, parent_refs, batch_size=self.batch_size, options_size=self.beam_size)
        beam_symbols = tf.concat(1, [symbols_history, tf.reshape(symbols, [-1, 1])])

        # Handle the output and the cell state shuffling
        next_cell_state = nest_map(
            lambda element: batch_gather(element, parent_refs, batch_size=self.batch_size, options_size=self.beam_size),
            cell_state
        )
        next_input = self.tokens_to_inputs_fn(tf.reshape(symbols, [-1, self.beam_size]))

        # 3. Update the candidate pool to include entries that just ended with a stop token
        logprobs_done = tf.reshape(logprobs_batched, [-1, self.beam_size, num_classes])[:, :, self.stop_token]
        done_parent_refs = tf.argmax(logprobs_done, 1)
        done_symbols = flat_batch_gather(past_beam_symbols, done_parent_refs, batch_size=self.batch_size, options_size=self.beam_size)

        logprobs_done_max = tf.reduce_max(logprobs_done, 1)

        cand_symbols_unpadded = tf.select(logprobs_done_max > past_cand_logprobs,
                                          done_symbols,
                                          past_cand_symbols)
        cand_logprobs = tf.maximum(logprobs_done_max, past_cand_logprobs)

        cand_symbols = tf.concat(1, [cand_symbols_unpadded, tf.fill([self.batch_size, 1], self.stop_token)])

        # 4. Check the stopping criteria
        if self.max_len is not None:
            elements_finished_clip = (time >= self.max_len)
        if self.score_upper_bound is not None:
            elements_finished_bound = tf.reduce_max(tf.reshape(beam_logprobs, [-1, self.beam_size]), 1) < (cand_logprobs - self.score_upper_bound)

        if self.max_len is not None and self.score_upper_bound is not None:
            elements_finished = elements_finished_clip | elements_finished_bound
        elif self.score_upper_bound is not None:
            elements_finished = elements_finished_bound
        elif self.max_len is not None:
            # this broadcasts elements_finished_clip to the correct shape
            elements_finished = tf.zeros([self.batch_size], dtype=tf.bool) | elements_finished_clip
        else:
            assert False, "Lack of stopping criterion should have been caught in constructor"

        # 5. Prepare return values
        # While loops require strict shape invariants, so we manually set shapes
        # in case the automatic shape inference can't calculate these. Even when
        # this is redundant, it has the benefit of helping catch shape bugs.
        for tensor in list(nest.flatten(next_input)) + list(nest.flatten(next_cell_state)):
            tensor.set_shape(tf.TensorShape((self.inferred_batch_size, self.beam_size)).concatenate(tensor.get_shape()[2:]))

        for tensor in [cand_symbols, cand_logprobs, elements_finished]:
            tensor.set_shape(tf.TensorShape((self.inferred_batch_size,)).concatenate(tensor.get_shape()[1:]))

        for tensor in [beam_symbols, beam_logprobs]:
            tensor.set_shape(tf.TensorShape((self.inferred_batch_size_times_beam_size,)).concatenate(tensor.get_shape()[1:]))

        next_loop_state = (
            cand_symbols,
            cand_logprobs,
            beam_symbols,
            beam_logprobs,
        )

        return (elements_finished, next_input, next_cell_state,
                emit_output, next_loop_state)

    def loop_fn(self, time, cell_output, cell_state, loop_state):
        if cell_output is None:
            return self.beam_setup(time)
        else:
            return self.beam_loop(time, cell_output, cell_state, loop_state)

    def decode_dense(self):
        emit_ta, final_state, final_loop_state = tf.nn.raw_rnn(self.cell, self.loop_fn, scope=self.scope)
        cand_symbols, cand_logprobs, beam_symbols, beam_logprobs = final_loop_state
        return cand_symbols, cand_logprobs

    def decode_sparse(self, include_stop_tokens=True):
        dense_symbols, logprobs = self.decode_dense()
        mask = tf.not_equal(dense_symbols, self.stop_token)
        if include_stop_tokens:
            mask = tf.concat(1, [tf.ones_like(mask[:, :1]), mask[:, :-1]])
        return sparse_boolean_mask(dense_symbols, mask), logprobs
# %%

def beam_decoder(
        cell,
        beam_size,
        stop_token,
        initial_state,
        initial_input,
        tokens_to_inputs_fn,
        outputs_to_score_fn=None,
        score_upper_bound=None,
        max_len=None,
        cell_transform='default',
        output_dense=False,
        scope=None
        ):
    """Beam search decoder

    Args:
      cell: tf.nn.rnn_cell.RNNCell defining the cell to use
      beam_size: the beam size for this search
      stop_token: the index of the symbol used to indicate the end of the
        output
      initial_state: initial cell state for the decoder
      initial_input: initial input into the decoder (typically the embedding
        of a START token)
      tokens_to_inputs_fn: function to go from token numbers to cell inputs.
        A typical implementation would look up the tokens in an embedding
        matrix.
        (signature: [batch_size, beam_size, num_classes] int32 -> [batch_size, beam_size, ...])
      outputs_to_score_fn: function to go from RNN cell outputs to scores for
        different tokens. If left unset, log-softmax is used (i.e. the cell
        outputs are treated as unnormalized logits).
        Inputs to the function are cell outputs, i.e. a possibly nested
        structure of items with shape [batch_size, beam_size, ...].
        Must return a single Tensor with shape [batch_size, beam_size, num_classes]
      score_upper_bound: (float or None). An upper bound on sequence scores.
        Used to determine a stopping criterion for beam search: the search
        stops if the highest-scoring complete sequence found so far beats
        anything on the beam by at least score_upper_bound. For typical
        sequence decoder models, outputs_to_score_fn returns normalized
        logits and this upper bound should be set to 0. Defaults to 0 if
        outputs_to_score_fn is not provided, otherwise defaults to None.
      max_len: (default None) maximum length after which to abort beam search.
        This provides an alternative stopping criterion.
      cell_transform: 'flatten', 'replicate', 'none', or 'default'. Most RNN
        primitives require inputs/outputs/states to have a shape that starts
        with [batch_size]. Beam search instead relies on shapes that start
        with [batch_size, beam_size]. This parameter controls how the arguments
        cell/initial_state/initial_input are transformed to comply with this.
        * 'flatten' creates a virtual batch of size batch_size*beam_size, and
          uses the cell with such a batch size. This transformation is only
          valid for cells that do not rely on the batch ordering in any way.
          (This is true of most RNNCells, but notably excludes cells that
          use attention.)
          The values of initial_state and initial_input are expanded and
          tiled along the beam_size dimension.
        * 'replicate' creates beam_size virtual replicas of the cell, each
          one of which is applied to batch_size elements. This should yield
          correct results (even for models with attention), but may not have
          ideal performance.
          The values of initial_state and initial_input are expanded and
          tiled along the beam_size dimension.
        * 'none' passes along cell/initial_state/initial_input as-is.
          Note that this requires initial_state and initial_input to already
          have a shape [batch_size, beam_size, ...] and a custom cell type
          that can handle this.
        * 'default' selects 'flatten' for LSTMCell, GRUCell, BasicLSTMCell,
          and BasicRNNCell. For all other cell types, it selects 'replicate'.
      output_dense: (default False) toggles returning the decoded sequence as
        a dense tensor.
      scope: VariableScope for the created subgraph; defaults to "RNN".

    Returns:
      A tuple of the form (decoded, log_probabilities) where:
        decoded: A SparseTensor (or dense Tensor if output_dense=True), of
          underlying shape [batch_size, ?] that contains the decoded sequence
          for each batch element
        log_probability: a [batch_size] tensor containing sequence
          log-probabilities
    """
    with tf.variable_scope(scope or "RNN") as varscope:
        helper = BeamSearchHelper(
            cell=cell,
            beam_size=beam_size,
            stop_token=stop_token,
            initial_state=initial_state,
            initial_input=initial_input,
            tokens_to_inputs_fn=tokens_to_inputs_fn,
            outputs_to_score_fn=outputs_to_score_fn,
            score_upper_bound=score_upper_bound,
            max_len=max_len,
            cell_transform=cell_transform,
            scope=varscope
        )
        if output_dense:
            return helper.decode_dense()
        else:
            return helper.decode_sparse()
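
# %%
# Illustrative usage sketch (not part of the original gist): decoding from a
# BasicLSTMCell with an embedding matrix and an output projection. All names
# and sizes below (vocab_size, embedding, proj_w, proj_b, GO/EOS ids, ...) are
# hypothetical; the calls use the same TF API generation as the code above.
#
#   vocab_size, embed_dim, hidden_dim, batch_size = 1000, 64, 128, 32
#   embedding = tf.get_variable('embedding', [vocab_size, embed_dim])
#   proj_w = tf.get_variable('proj_w', [hidden_dim, vocab_size])
#   proj_b = tf.get_variable('proj_b', [vocab_size])
#   cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_dim)
#
#   def scores_fn(cell_output):  # [batch, beam, hidden] -> [batch, beam, vocab]
#       flat = tf.reshape(cell_output, [-1, hidden_dim])
#       logits = tf.matmul(flat, proj_w) + proj_b
#       return tf.nn.log_softmax(tf.reshape(logits, [batch_size, -1, vocab_size]))
#
#   decoded_sparse, decoded_logprobs = beam_decoder(
#       cell=cell,
#       beam_size=5,
#       stop_token=2,  # assumed EOS id
#       initial_state=cell.zero_state(batch_size, tf.float32),
#       initial_input=tf.nn.embedding_lookup(embedding, tf.fill([batch_size], 1)),  # assumed GO id
#       tokens_to_inputs_fn=lambda toks: tf.nn.embedding_lookup(embedding, toks),
#       outputs_to_score_fn=scores_fn,
#       max_len=50,
#   )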
# %% Test file (a separate file in the original gist)

import tensorflow as tf
from tensorflow.python.platform import test
import numpy as np

from tf_beam_decoder import beam_decoder, BeamSearchHelper

# %%
sess = tf.InteractiveSession()

# %%

class MarkovChainCell(tf.nn.rnn_cell.RNNCell):
    """
    This cell type is only used for testing the beam decoder.

    It represents a Markov chain characterized by a probability table p(x_t|x_{t-1},x_{t-2}).
    """
    def __init__(self, table):
        """
        table[a,b,c] = p(x_t=c|x_{t-1}=b,x_{t-2}=a)
        """
        assert len(table.shape) == 3 and table.shape[0] == table.shape[1] == table.shape[2]
        with np.errstate(divide='ignore'):  # ignore warning for log(0)
            self.log_table = np.log(np.asarray(table, dtype=np.float32))
        self.log_table_var = None
        self._output_size = table.shape[0]

    def __call__(self, inputs, state, scope=None):
        """
        inputs: [batch_size, 1] int tensor
        state: [batch_size, 1] int tensor
        """
        # Simulate variable creation, to ensure scoping works correctly
        log_table = tf.get_variable('log_table',
                                    shape=(3, 3, 3),
                                    dtype=tf.float32,
                                    initializer=tf.constant_initializer(self.log_table))
        if self.log_table_var is None:
            self.log_table_var = log_table
        else:
            assert self.log_table_var == log_table

        logits = tf.reshape(log_table, [-1, self.output_size])
        indices = state[0] * self.output_size + inputs
        return tf.gather(logits, tf.reshape(indices, [-1])), (inputs,)

    @property
    def state_size(self):
        return (1,)

    @property
    def output_size(self):
        return self._output_size
class BeamSearchTest(test.TestCase):
    def test1(self):
        """
        test correct decode in sequence
        """
        with self.test_session() as sess:
            table = np.array([[[0.0, 0.6, 0.4],
                               [0.0, 0.4, 0.6],
                               [0.0, 0.0, 1.0]]] * 3)
            for cell_transform in ['default', 'flatten', 'replicate']:
                cell = MarkovChainCell(table)
                initial_state = cell.zero_state(1, tf.int32)
                initial_input = initial_state[0]

                with tf.variable_scope('test1_{}'.format(cell_transform)):
                    best_sparse, best_logprobs = beam_decoder(
                        cell=cell,
                        beam_size=7,
                        stop_token=2,
                        initial_state=initial_state,
                        initial_input=initial_input,
                        tokens_to_inputs_fn=lambda x: tf.expand_dims(x, -1),
                        max_len=5,
                        cell_transform=cell_transform,
                        output_dense=False,
                    )

                tf.variables_initializer([cell.log_table_var]).run()
                assert all(best_sparse.eval().values == [2])
                assert np.isclose(np.exp(best_logprobs.eval())[0], 0.4)
    def test2(self):
        """
        test correct intermediate beam states
        """
        with self.test_session() as sess:
            table = np.array([[[0.9, 0.1, 0],
                               [0, 0.9, 0.1],
                               [0, 0, 1.0]]] * 3)
            for cell_transform in ['default', 'flatten', 'replicate']:
                cell = MarkovChainCell(table)
                initial_state = cell.zero_state(1, tf.int32)
                initial_input = initial_state[0]

                with tf.variable_scope('test2_{}'.format(cell_transform)):
                    helper = BeamSearchHelper(
                        cell=cell,
                        beam_size=10,
                        stop_token=2,
                        initial_state=initial_state,
                        initial_input=initial_input,
                        tokens_to_inputs_fn=lambda x: tf.expand_dims(x, -1),
                        max_len=3,
                        cell_transform=cell_transform
                    )
                    _, _, final_loop_state = tf.nn.raw_rnn(helper.cell, helper.loop_fn)
                    _, _, beam_symbols, beam_logprobs = final_loop_state

                tf.variables_initializer([cell.log_table_var]).run()
                candidates, candidate_logprobs = sess.run((beam_symbols, beam_logprobs))

                assert all(candidates[0, :] == [0, 0, 0])
                assert np.isclose(np.exp(candidate_logprobs[0]), 0.9 * 0.9 * 0.9)
                # Note that these three candidates all have the same score, and the sort order
                # may change in the future
                assert all(candidates[1, :] == [0, 0, 1])
                assert np.isclose(np.exp(candidate_logprobs[1]), 0.9 * 0.9 * 0.1)
                assert all(candidates[2, :] == [0, 1, 1])
                assert np.isclose(np.exp(candidate_logprobs[2]), 0.9 * 0.1 * 0.9)
                assert all(candidates[3, :] == [1, 1, 1])
                assert np.isclose(np.exp(candidate_logprobs[3]), 0.1 * 0.9 * 0.9)
                assert all(np.isclose(np.exp(candidate_logprobs[4:]), 0.0))
    def test3(self):
        """
        test that variable reuse works as expected
        """
        with self.test_session() as sess:
            table = np.array([[[0.0, 0.6, 0.4],
                               [0.0, 0.4, 0.6],
                               [0.0, 0.0, 1.0]]] * 3)
            for cell_transform in ['default', 'flatten', 'replicate']:
                cell = MarkovChainCell(table)
                initial_state = cell.zero_state(1, tf.int32)
                initial_input = initial_state[0]

                with tf.variable_scope('test3_{}'.format(cell_transform)) as scope:
                    best_sparse, best_logprobs = beam_decoder(
                        cell=cell,
                        beam_size=7,
                        stop_token=2,
                        initial_state=initial_state,
                        initial_input=initial_input,
                        tokens_to_inputs_fn=lambda x: tf.expand_dims(x, -1),
                        max_len=5,
                        cell_transform=cell_transform,
                        output_dense=False,
                        scope=scope
                    )

                tf.variables_initializer([cell.log_table_var]).run()

                with tf.variable_scope(scope, reuse=True) as varscope:
                    best_sparse_2, best_logprobs_2 = beam_decoder(
                        cell=cell,
                        beam_size=7,
                        stop_token=2,
                        initial_state=initial_state,
                        initial_input=initial_input,
                        tokens_to_inputs_fn=lambda x: tf.expand_dims(x, -1),
                        max_len=5,
                        cell_transform=cell_transform,
                        output_dense=False,
                        scope=varscope
                    )

                assert all(sess.run(tf.equal(best_sparse.values, best_sparse_2.values)))
                assert np.isclose(*sess.run((best_logprobs, best_logprobs_2)))
    def test4(self):
        """
        test batching, with statically unknown batch size
        """
        with self.test_session() as sess:
            table = np.array([[[0.9, 0.1, 0],
                               [0, 0.9, 0.1],
                               [0, 0, 1.0]]] * 3)
            for cell_transform in ['default', 'flatten', 'replicate']:
                cell = MarkovChainCell(table)
                initial_state = (tf.constant([[2], [0]]),)
                initial_input = initial_state[0]
                initial_input._shape = tf.TensorShape([None, 1])

                with tf.variable_scope('test4_{}'.format(cell_transform)):
                    helper = BeamSearchHelper(
                        cell=cell,
                        beam_size=10,
                        stop_token=2,
                        initial_state=initial_state,
                        initial_input=initial_input,
                        tokens_to_inputs_fn=lambda x: tf.expand_dims(x, -1),
                        max_len=3,
                        cell_transform=cell_transform
                    )
                    _, _, final_loop_state = tf.nn.raw_rnn(helper.cell, helper.loop_fn)
                    _, _, beam_symbols, beam_logprobs = final_loop_state

                tf.variables_initializer([cell.log_table_var]).run()
                candidates, candidate_logprobs = sess.run((beam_symbols, beam_logprobs))

                assert all(candidates[10, :] == [0, 0, 0])
                assert np.isclose(np.exp(candidate_logprobs[10]), 0.9 * 0.9 * 0.9)
                # Note that these three candidates all have the same score, and the sort order
                # may change in the future
                assert all(candidates[11, :] == [0, 0, 1])
                assert np.isclose(np.exp(candidate_logprobs[11]), 0.9 * 0.9 * 0.1)
                assert all(candidates[12, :] == [0, 1, 1])
                assert np.isclose(np.exp(candidate_logprobs[12]), 0.9 * 0.1 * 0.9)
                assert all(candidates[13, :] == [1, 1, 1])
                assert np.isclose(np.exp(candidate_logprobs[13]), 0.1 * 0.9 * 0.9)
                assert all(np.isclose(np.exp(candidate_logprobs[14:]), 0.0))

if __name__ == '__main__':
    test.main()
@nikitakit I tried this with version 1.1.0. After adapting to the API changes, I'm stuck on this error:

ValueError: Attempt to reuse RNNCell <tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.BasicLSTMCell object at 0x7f16c0ca3588> with a different variable scope than its first use. First use of cell was with scope 'seq2seq_att/social_bot/Decoder/decoder/decoder/basic_lstm_cell', this attempt is with scope 'seq2seq_att/social_bot/Decoder/decoder/rnn/basic_lstm_cell'

at this line: `output_k, next_state_k = self.cell(inputs_k, state_k, scope=scope)`
@nikitakit I have some questions about outputs_to_score_fn and the return value.
- What are the expected input and output shapes of this function?
- The output should be the log-softmax results rather than the pre-softmax logits, right?
- If I want to get an n-best list, what should I do?

Thanks!
@avostryakov did you solve your problem? I think I have exactly the same problem as you: cand_symbols is just a sequence of stop tokens...
@nikitakit hey! does this work with TF 1.0? thanks!