diff --git a/copynet.py b/copynet.py index 8ad3a07..e9373a2 100644 --- a/copynet.py +++ b/copynet.py @@ -22,7 +22,7 @@ def with_same_shape(old, new): class CopyNetWrapper(tf.nn.rnn_cell.RNNCell): def __init__(self, cell, encoder_states, encoder_input_ids, vocab_size, - gen_vocab_size=None, encoder_state_size=None, initial_cell_state=None, name=None): + hparams=None, encoder_state_size=None, initial_cell_state=None, name=None): """ Args: cell: @@ -34,12 +34,17 @@ def __init__(self, cell, encoder_states, encoder_input_ids, vocab_size, initial_cell_state: """ super(CopyNetWrapper, self).__init__(name=name) + self._hparams = hparams self._cell = cell self._vocab_size = vocab_size - self._gen_vocab_size = gen_vocab_size or vocab_size - + self._gen_vocab_size = self._hparams.gen_vocab_size or vocab_size self._encoder_input_ids = encoder_input_ids self._encoder_states = encoder_states + + if self._hparams.beam_width > 0: + self._encoder_input_ids = tf.contrib.seq2seq.tile_batch(encoder_input_ids, multiplier=self._hparams.beam_width) + self._encoder_states = tf.contrib.seq2seq.tile_batch(encoder_states, multiplier=self._hparams.beam_width) + if encoder_state_size is None: encoder_state_size = self._encoder_states.shape[-1].value if encoder_state_size is None: @@ -58,20 +63,20 @@ def __call__(self, inputs, state, scope=None): prob_c = state.prob_c cell_state = state.cell_state - mask = tf.cast(tf.equal(tf.expand_dims(last_ids, 1), self._encoder_input_ids), tf.float32) + mask = tf.cast(tf.equal(tf.expand_dims(last_ids, 1), self._encoder_input_ids), tf.float32) mask_sum = tf.reduce_sum(mask, axis=1) mask = tf.where(tf.less(mask_sum, 1e-7), mask, mask / tf.expand_dims(mask_sum, 1)) rou = mask * prob_c selective_read = tf.einsum("ijk,ij->ik", self._encoder_states, rou) inputs = tf.concat([inputs, selective_read], 1) - outputs, cell_state = self._cell(inputs, cell_state, scope) + outputs, cell_state = self._cell(inputs, cell_state, scope) # St, generate_score = self._projection(outputs) copy_score = tf.einsum("ijk,km->ijm", self._encoder_states, self._copy_weight) copy_score = tf.nn.tanh(copy_score) - copy_score = tf.einsum("ijm,im->ij", copy_score, outputs) + encoder_input_mask = tf.one_hot(self._encoder_input_ids, self._vocab_size) expanded_copy_score = tf.einsum("ijn,ij->ij", encoder_input_mask, copy_score) @@ -111,6 +116,7 @@ def zero_state(self, batch_size, dtype): cell_state = self._initial_cell_state else: cell_state = self._cell.zero_state(batch_size, dtype) + last_ids = tf.zeros([batch_size], tf.int32) - 1 prob_c = tf.zeros([batch_size, tf.shape(self._encoder_states)[1]], tf.float32) return CopyNetWrapperState(cell_state=cell_state, last_ids=last_ids, prob_c=prob_c) \ No newline at end of file diff --git a/nmt/nmt/model.py b/nmt/nmt/model.py index 888b13f..d5f16d6 100644 --- a/nmt/nmt/model.py +++ b/nmt/nmt/model.py @@ -380,11 +380,11 @@ def _build_decoder(self, encoder_outputs, encoder_state, hparams): ## Decoder. with tf.variable_scope("decoder") as decoder_scope: - cell, decoder_initial_state = self._build_decoder_cell( + + cell, decoder_initial_state = self._build_decoder_cell( # Here call the AttentionModel._build_decoder_cell hparams, encoder_outputs, encoder_state, iterator.source_sequence_length) - - # CopyNetMechanism + if hparams.copynet: # Ensure memory is batch-major if self.time_major: @@ -392,12 +392,20 @@ def _build_decoder(self, encoder_outputs, encoder_state, hparams): encoder_state_size = cell.output_size if hparams.encoder_type == "bi": encoder_state_size *= 2 + # cell = CopyNetWrapper(cell, encoder_outputs, self.iterator.source, + # self.src_vocab_size, hparams.gen_vocab_size, + # encoder_state_size=encoder_state_size) + cell = CopyNetWrapper(cell, encoder_outputs, self.iterator.source, - self.src_vocab_size, hparams.gen_vocab_size, - encoder_state_size=encoder_state_size) + self.src_vocab_size, hparams=hparams, + encoder_state_size=encoder_state_size) self.output_layer = None - decoder_initial_state = cell.zero_state(self.batch_size, - tf.float32).clone(cell_state=decoder_initial_state) + if hparams.beam_width > 0: + decoder_initial_state = cell.zero_state(self.batch_size*hparams.beam_width, + tf.float32).clone(cell_state=decoder_initial_state) + else: + decoder_initial_state = cell.zero_state(self.batch_size, + tf.float32).clone(cell_state=decoder_initial_state) ## Train or eval if self.mode != tf.contrib.learn.ModeKeys.INFER: @@ -439,6 +447,7 @@ def _build_decoder(self, encoder_outputs, encoder_state, hparams): else: logits = outputs.rnn_output + ## Inference else: beam_width = hparams.beam_width @@ -565,6 +574,8 @@ def decode(self, sess): return sample_words, infer_summary + + class Model(BaseModel): """Sequence-to-sequence dynamic model.