It works perfectly fine with the greedy decoder. Here is the code:
TensorFlow: 1.8.0
# --- Encoder: embed the input ids and run a GRU over them. ---
encoder_emb_inp = tf.nn.embedding_lookup(embeddings, x)
encoder_cell = rnn.GRUCell(rnn_size, name='encoder')
encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
    encoder_cell,
    encoder_emb_inp,
    sequence_length=len_docs,
    dtype=tf.float32)

# --- Tile every encoder-side tensor along the batch axis for beam search. ---
# Everything the decoder cell reads per step must have a leading dimension of
# batch_size * beam_width, so each tensor is tiled beam_width times.
tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
    encoder_outputs, multiplier=beam_width)
tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
    len_docs, multiplier=beam_width)
tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
    encoder_state, multiplier=beam_width)
tiled_t = tf.contrib.seq2seq.tile_batch(t, multiplier=beam_width)

# start_tokens is deliberately NOT tiled: it is built with shape [batch_size].
start_tokens = tf.constant(word2int['SOS'], shape=[batch_size])

# --- Decoder cell: GRU wrapped with Luong attention over the tiled memory. ---
decoder_cell = rnn.GRUCell(rnn_size, name='decoder')
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    rnn_size,
    tiled_encoder_outputs,
    memory_sequence_length=tiled_sequence_length)
decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    decoder_cell,
    attention_mechanism,
    attention_layer_size=rnn_size)

# Attention zero state at the tiled batch size, seeded with the tiled encoder
# final state as the wrapped cell's state.
initial_state = decoder_cell.zero_state(
    batch_size * beam_width, dtype=tf.float32
).clone(cell_state=tiled_encoder_final_state)

# --- Copy mechanism on top of the attention cell. ---
# CopyNetWrapper is defined outside this snippet; the meaning of its
# positional arguments cannot be confirmed from here.
# NOTE(review): the reported ValueError is raised inside BeamSearchDecoder's
# _split_batch_beams; presumably one of the state tensors produced by this
# wrapper has a static shape that disagrees with what the wrapper declares in
# its state_size -- TODO confirm against the CopyNetWrapper source.
decoder_cell = CopyNetWrapper(
    decoder_cell,
    tiled_encoder_outputs,
    tiled_t,
    len(set(delta).union(words)),
    vocab_size)
initial_state = decoder_cell.zero_state(
    batch_size * beam_width, dtype=tf.float32
).clone(cell_state=initial_state)

# --- Beam search decoding. ---
# Fix: the decoder must be bound to a name; the original built it and threw
# the result away, so the dynamic_decode(decoder) call below hit a NameError.
decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell=decoder_cell,
    embedding=embeddings,
    start_tokens=start_tokens,
    end_token=word2int['EOS'],
    initial_state=initial_state,
    beam_width=beam_width,
    output_layer=None,
    length_penalty_weight=0.0)
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
It works perfectly fine with the greedy decoder; the failure only appears with the beam search decoder.
TensorFlow: 1.8.0
The error is:
File "/home/usr/.local/lib/python2.7/site-packages/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py", line 531, in _split_batch_beams
reshaped_t.set_shape(expected_reshaped_shape)
File "/home/usr/.local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 538, in set_shape
raise ValueError(str(e))
ValueError: Dimension 2 in both shapes must be equal, but are 38253 and 4. Shapes are [1,1,38253] and [1,1,4].