Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions copynet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def with_same_shape(old, new):
class CopyNetWrapper(tf.nn.rnn_cell.RNNCell):

def __init__(self, cell, encoder_states, encoder_input_ids, vocab_size,
gen_vocab_size=None, encoder_state_size=None, initial_cell_state=None, name=None):
hparams=None, encoder_state_size=None, initial_cell_state=None, name=None):
"""
Args:
cell:
Expand All @@ -34,12 +34,17 @@ def __init__(self, cell, encoder_states, encoder_input_ids, vocab_size,
initial_cell_state:
"""
super(CopyNetWrapper, self).__init__(name=name)
self._hparams = hparams
self._cell = cell
self._vocab_size = vocab_size
self._gen_vocab_size = gen_vocab_size or vocab_size

self._gen_vocab_size = self._hparams.gen_vocab_size or vocab_size
self._encoder_input_ids = encoder_input_ids
self._encoder_states = encoder_states

if self._hparams.beam_width > 0:
self._encoder_input_ids = tf.contrib.seq2seq.tile_batch(encoder_input_ids, multiplier=self._hparams.beam_width)
self._encoder_states = tf.contrib.seq2seq.tile_batch(encoder_states, multiplier=self._hparams.beam_width)

if encoder_state_size is None:
encoder_state_size = self._encoder_states.shape[-1].value
if encoder_state_size is None:
Expand All @@ -58,20 +63,20 @@ def __call__(self, inputs, state, scope=None):
prob_c = state.prob_c
cell_state = state.cell_state

mask = tf.cast(tf.equal(tf.expand_dims(last_ids, 1), self._encoder_input_ids), tf.float32)
mask = tf.cast(tf.equal(tf.expand_dims(last_ids, 1), self._encoder_input_ids), tf.float32)
mask_sum = tf.reduce_sum(mask, axis=1)
mask = tf.where(tf.less(mask_sum, 1e-7), mask, mask / tf.expand_dims(mask_sum, 1))
rou = mask * prob_c
selective_read = tf.einsum("ijk,ij->ik", self._encoder_states, rou)
inputs = tf.concat([inputs, selective_read], 1)

outputs, cell_state = self._cell(inputs, cell_state, scope)
outputs, cell_state = self._cell(inputs, cell_state, scope) # St,
generate_score = self._projection(outputs)

copy_score = tf.einsum("ijk,km->ijm", self._encoder_states, self._copy_weight)
copy_score = tf.nn.tanh(copy_score)

copy_score = tf.einsum("ijm,im->ij", copy_score, outputs)

encoder_input_mask = tf.one_hot(self._encoder_input_ids, self._vocab_size)
expanded_copy_score = tf.einsum("ijn,ij->ij", encoder_input_mask, copy_score)

Expand Down Expand Up @@ -111,6 +116,7 @@ def zero_state(self, batch_size, dtype):
cell_state = self._initial_cell_state
else:
cell_state = self._cell.zero_state(batch_size, dtype)

last_ids = tf.zeros([batch_size], tf.int32) - 1
prob_c = tf.zeros([batch_size, tf.shape(self._encoder_states)[1]], tf.float32)
return CopyNetWrapperState(cell_state=cell_state, last_ids=last_ids, prob_c=prob_c)
25 changes: 18 additions & 7 deletions nmt/nmt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,24 +380,32 @@ def _build_decoder(self, encoder_outputs, encoder_state, hparams):

## Decoder.
with tf.variable_scope("decoder") as decoder_scope:
cell, decoder_initial_state = self._build_decoder_cell(

cell, decoder_initial_state = self._build_decoder_cell( # Here call the AttentionModel._build_decoder_cell
hparams, encoder_outputs, encoder_state,
iterator.source_sequence_length)

# CopyNetMechanism

if hparams.copynet:
# Ensure memory is batch-major
if self.time_major:
encoder_outputs = tf.transpose(encoder_outputs, [1, 0, 2])
encoder_state_size = cell.output_size
if hparams.encoder_type == "bi":
encoder_state_size *= 2
# cell = CopyNetWrapper(cell, encoder_outputs, self.iterator.source,
# self.src_vocab_size, hparams.gen_vocab_size,
# encoder_state_size=encoder_state_size)

cell = CopyNetWrapper(cell, encoder_outputs, self.iterator.source,
self.src_vocab_size, hparams.gen_vocab_size,
encoder_state_size=encoder_state_size)
self.src_vocab_size, hparams=hparams,
encoder_state_size=encoder_state_size)
self.output_layer = None
decoder_initial_state = cell.zero_state(self.batch_size,
tf.float32).clone(cell_state=decoder_initial_state)
if hparams.beam_width > 0:
decoder_initial_state = cell.zero_state(self.batch_size*hparams.beam_width,
tf.float32).clone(cell_state=decoder_initial_state)
else:
decoder_initial_state = cell.zero_state(self.batch_size,
tf.float32).clone(cell_state=decoder_initial_state)

## Train or eval
if self.mode != tf.contrib.learn.ModeKeys.INFER:
Expand Down Expand Up @@ -439,6 +447,7 @@ def _build_decoder(self, encoder_outputs, encoder_state, hparams):
else:
logits = outputs.rnn_output


## Inference
else:
beam_width = hparams.beam_width
Expand Down Expand Up @@ -565,6 +574,8 @@ def decode(self, sess):
return sample_words, infer_summary




class Model(BaseModel):
"""Sequence-to-sequence dynamic model.

Expand Down