Add a PREDICT mode and a FLAGS.do_test pass that writes pooled embeddings.

* model_fn: build an EstimatorSpec for ModeKeys.PREDICT whose predictions
  carry input_ids, info, model.get_sequence_output() and
  model.all_encoder_layers[-2].
* main: when FLAGS.do_test is set, run estimator.predict and, per example,
  average the last_layer_output vectors of the positions that precede the
  first input id equal to 0 (presumably padding -- confirm), then write
  "<info[0]> <v1> <v2> ..." to <checkpointDir>/test_results.txt.
* The running mean now starts from a zero vector, so an example whose first
  input id is 0 produces a zero row instead of raising TypeError.

NOTE(review): indentation in this patch was reconstructed from a
whitespace-collapsed copy; check the context lines against run.py before
applying.
NOTE(review): FLAGS.do_test and test_input_files are not defined in these
hunks; confirm they are declared elsewhere in run.py.

diff --git a/run.py b/run.py
index 4dd9424..006d757 100755
--- a/run.py
+++ b/run.py
@@ -262,7 +262,9 @@ def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
         # all_user_and_item = model.get_embedding_table()
         # item_ids = [i for i in range(0, item_size + 1)]
         # softmax_output_embedding = tf.nn.embedding_lookup(all_user_and_item, item_ids)
-
+        # Encoder outputs exported through the PREDICT-mode predictions.
+        encoder_last_layer = model.get_sequence_output()
+        encoder_last2_layer = model.all_encoder_layers[-2]
         (masked_lm_loss, masked_lm_example_loss,
          masked_lm_log_probs) = get_masked_lm_output(
              bert_config,
@@ -348,6 +350,19 @@ def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                 loss=total_loss,
                 eval_metric_ops=eval_metrics,
                 scaffold=scaffold_fn)
+
+        elif mode == tf.estimator.ModeKeys.PREDICT:
+            # Return the raw inputs next to the captured encoder outputs.
+            predictions = {
+                "input_ids": input_ids,
+                "info": info,
+                "last_layer_output": encoder_last_layer,
+                "last2_layer_output": encoder_last2_layer,
+            }
+            output_spec = tf.estimator.EstimatorSpec(
+                mode=mode,
+                predictions=predictions,
+                scaffold=scaffold_fn)
         else:
-            raise ValueError("Only TRAIN and EVAL modes are supported: %s" %
-                             (mode))
+            raise ValueError(
+                "Only TRAIN, EVAL and PREDICT modes are supported: %s" % (mode))
@@ -594,6 +609,33 @@ def main(_):
                 tf.logging.info(" %s = %s", key, str(result[key]))
                 writer.write("%s = %s\n" % (key, str(result[key])))
 
+    if FLAGS.do_test:
+        tf.logging.info("***** Running test *****")
+        tf.logging.info(" Batch size = %d", FLAGS.batch_size)
+
+        test_input_fn = input_fn_builder(
+            input_files=test_input_files,
+            max_seq_length=FLAGS.max_seq_length,
+            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
+            is_training=False)
+        output_test_file = os.path.join(FLAGS.checkpointDir,
+                                        "test_results.txt")
+        with tf.gfile.Open(output_test_file, 'w') as writer:
+            for result in estimator.predict(test_input_fn, yield_single_examples=True):
+                print('result', result['info'], result['input_ids'])
+                last_layer = result['last_layer_output']
+                # Zero vector (not a 0-d scalar) so an example whose first
+                # id is 0 still yields a writable, correctly sized row.
+                avg = np.zeros(last_layer.shape[-1], dtype=last_layer.dtype)
+                for i, token_id in enumerate(result['input_ids']):
+                    if token_id == 0:
+                        print('early stop', i)
+                        break
+                    avg = (avg * i + last_layer[i]) / (i + 1)
+                print('avg', avg)
+                vector = ' '.join(str(x) for x in avg.tolist())
+                writer.write(str(result['info'][0]) + ' ' + vector + '\n')
+
 
 if __name__ == "__main__":
     flags.mark_flag_as_required("bert_config_file")