From fa50fc5c7aa0db8b42006e8035678b72e866b671 Mon Sep 17 00:00:00 2001 From: laiyongkui Date: Fri, 26 Jun 2020 22:01:22 +0800 Subject: [PATCH 1/6] Add joint learning, MDP v0.2 --- main.py | 83 +++--- models/few_shot_learner.py | 276 ++++++++++++++++++++ models/few_shot_seq_labeler.py | 139 ++-------- models/few_shot_text_classifier.py | 145 ++--------- models/modules/context_embedder_base.py | 2 - models/modules/similarity_scorer_base.py | 2 +- readme.md | 4 +- scripts/run_bert_sc+sl.sh | 284 +++++++++++++++++++++ scripts/run_bert_sc.sh | 2 +- scripts/run_electra_sc+sl.sh | 284 +++++++++++++++++++++ scripts/run_electra_sc.sh | 8 +- utils/data_loader.py | 11 - utils/device_helper.py | 3 + utils/iter_helper.py | 29 ++- utils/model_helper.py | 128 +++++----- utils/opt.py | 12 +- utils/preprocessor.py | 308 ++++++++++++----------- utils/tester.py | 146 ++++++----- utils/trainer.py | 140 ++++++----- 19 files changed, 1367 insertions(+), 639 deletions(-) create mode 100644 models/few_shot_learner.py create mode 100644 scripts/run_bert_sc+sl.sh create mode 100644 scripts/run_electra_sc+sl.sh diff --git a/main.py b/main.py index bee4d50..9fb4df5 100644 --- a/main.py +++ b/main.py @@ -29,49 +29,54 @@ def get_training_data_and_feature(opt, data_loader, preprocessor): """ prepare feature and data """ if opt.load_feature: try: - train_features, train_label2id, train_id2label = load_feature(opt.train_path.replace('.json', '.saved.pk')) - dev_features, dev_label2id, dev_id2label = load_feature(opt.dev_path.replace('.json', '.saved.pk')) + train_features, train_label2id_map, train_id2label_map = \ + load_feature(opt.train_path.replace('.json', '.saved.pk')) + dev_features, dev_label2id_map, dev_id2label_map = load_feature(opt.dev_path.replace('.json', '.saved.pk')) except FileNotFoundError: opt.load_feature, opt.save_feature = False, True # Not a saved feature file yet, make it - train_features, train_label2id, train_id2label, dev_features, dev_label2id, dev_id2label =\ + 
train_features, train_label2id_map, train_id2label_map, dev_features, dev_label2id_map, dev_id2label_map =\ get_training_data_and_feature(opt, data_loader, preprocessor) opt.load_feature, opt.save_feature = True, False # restore option else: train_examples, train_max_len, train_max_support_size = data_loader.load_data(path=opt.train_path) dev_examples, dev_max_len, dev_max_support_size = data_loader.load_data(path=opt.dev_path) - train_label2id, train_id2label = make_dict(opt, train_examples) - dev_label2id, dev_id2label = make_dict(opt, dev_examples) + train_label2id_map, train_id2label_map = make_dict(opt, train_examples) + dev_label2id_map, dev_id2label_map = make_dict(opt, dev_examples) logger.info(' Finish train dev prepare dict ') train_features = preprocessor.construct_feature( - train_examples, train_max_support_size, train_label2id, train_id2label) - dev_features = preprocessor.construct_feature(dev_examples, dev_max_support_size, dev_label2id, dev_id2label) + train_examples, train_max_support_size, train_label2id_map, train_id2label_map) + dev_features = preprocessor.construct_feature(dev_examples, dev_max_support_size, + dev_label2id_map, dev_id2label_map) logger.info(' Finish prepare train dev features ') if opt.save_feature: - save_feature(opt.train_path.replace('.json', '.saved.pk'), train_features, train_label2id, train_id2label) - save_feature(opt.dev_path.replace('.json', '.saved.pk'), dev_features, dev_label2id, dev_id2label) - return train_features, train_label2id, train_id2label, dev_features, dev_label2id, dev_id2label + save_feature(opt.train_path.replace('.json', '.saved.pk'), + train_features, train_label2id_map, train_id2label_map) + save_feature(opt.dev_path.replace('.json', '.saved.pk'), dev_features, dev_label2id_map, dev_id2label_map) + return train_features, train_label2id_map, train_id2label_map, dev_features, dev_label2id_map, dev_id2label_map def get_testing_data_feature(opt, data_loader, preprocessor): """ prepare feature and data 
""" if opt.load_feature: try: - test_features, test_label2id, test_id2label = load_feature(opt.test_path.replace('.json', '.saved.pk')) + test_features, test_label2id_map, test_id2label_map = \ + load_feature(opt.test_path.replace('.json', '.saved.pk')) except FileNotFoundError: opt.load_feature, opt.save_feature = False, True # Not a saved feature file yet, make it - test_features, test_label2id, test_id2label = get_testing_data_feature(opt, data_loader, preprocessor) + test_features, test_label2id_map, test_id2label_map = get_testing_data_feature(opt, data_loader, preprocessor) opt.load_feature, opt.save_feature = True, False # restore option else: test_examples, test_max_len, test_max_support_size = data_loader.load_data(path=opt.test_path) - test_label2id, test_id2label = make_dict(opt, test_examples) + test_label2id_map, test_id2label_map = make_dict(opt, test_examples) logger.info(' Finish prepare test dict') test_features = preprocessor.construct_feature( - test_examples, test_max_support_size, test_label2id, test_id2label) + test_examples, test_max_support_size, test_label2id_map, test_id2label_map) logger.info(' Finish prepare test feature') if opt.save_feature: - save_feature(opt.test_path.replace('.json', '.saved.pk'), test_features, test_label2id, test_id2label) - return test_features, test_label2id, test_id2label + save_feature(opt.test_path.replace('.json', '.saved.pk'), + test_features, test_label2id_map, test_id2label_map) + return test_features, test_label2id_map, test_id2label_map def main(): @@ -92,30 +97,32 @@ def main(): data_loader = FewShotRawDataLoader(opt) preprocessor = make_preprocessor(opt) if opt.do_train: - train_features, train_label2id, train_id2label, dev_features, dev_label2id, dev_id2label = \ + train_features, train_label2id_map, train_id2label_map, dev_features, dev_label2id_map, dev_id2label_map = \ get_training_data_and_feature(opt, data_loader, preprocessor) - if opt.mask_transition and opt.task == 'sl': - 
opt.train_label_mask = make_label_mask(opt, opt.train_path, train_label2id) - opt.dev_label_mask = make_label_mask(opt, opt.dev_path, dev_label2id) + if opt.mask_transition and 'sl' in opt.task: + opt.train_label_mask = make_label_mask(opt, opt.train_path, train_label2id_map['sl']) + opt.dev_label_mask = make_label_mask(opt, opt.dev_path, dev_label2id_map['sl']) else: - train_features, train_label2id, train_id2label, dev_features, dev_label2id, dev_id2label = [None] * 6 - if opt.mask_transition and opt.task == 'sl': + train_features, train_label2id_map, train_id2label_map, dev_features, dev_label2id_map, dev_id2label_map = \ + [None] * 6 + if opt.mask_transition and 'sl' in opt.task: opt.train_label_mask = None opt.dev_label_mask = None + if opt.do_predict: - test_features, test_label2id, test_id2label = get_testing_data_feature(opt, data_loader, preprocessor) - if opt.mask_transition and opt.task == 'sl': - opt.test_label_mask = make_label_mask(opt, opt.test_path, test_label2id) + test_features, test_label2id_map, test_id2label_map = get_testing_data_feature(opt, data_loader, preprocessor) + if opt.mask_transition and 'sl' in opt.task: + opt.test_label_mask = make_label_mask(opt, opt.test_path, test_label2id_map['sl']) else: - test_features, test_label2id, test_id2label = [None] * 3 - if opt.mask_transition and opt.task == 'sl': + test_features, test_label2id_map, test_id2label_map = [None] * 3 + if opt.mask_transition and 'sl' in opt.task: opt.test_label_mask = None ''' over fitting test ''' if opt.do_overfit_test: - test_features, test_label2id, test_id2label = train_features, train_label2id, train_id2label - dev_features, dev_label2id, dev_id2label = train_features, train_label2id, train_id2label + test_features, test_label2id_map, test_id2label_map = train_features, train_label2id_map, train_id2label_map + dev_features, dev_label2id_map, dev_id2label_map = train_features, train_label2id_map, train_id2label_map ''' select training & testing mode ''' 
trainer_class = SchemaFewShotTrainer if opt.use_schema else FewShotTrainer @@ -130,9 +137,10 @@ def main(): opt = training_model.opt opt.warmup_epoch = -1 else: - training_model = make_model(opt, config={'num_tags': len(train_label2id)}) + training_model = make_model(opt, config={ + 'num_tags': len(train_label2id_map['sl']) if 'sl' in train_label2id_map else 0}) training_model = prepare_model(opt, training_model, device, n_gpu) - if opt.mask_transition and opt.task == 'sl': + if opt.mask_transition and 'sl' in opt.task: training_model.label_mask = opt.train_label_mask.to(device) # prepare a set of name subseuqence/mark to use different learning rate for part of params upper_structures = [ @@ -153,12 +161,13 @@ def main(): print('========== Warmup training finished! ==========') trained_model, best_dev_score, test_score = trainer.do_train( training_model, train_features, opt.num_train_epochs, - dev_features, dev_id2label, test_features, test_id2label, best_dev_score_now=0) + dev_features, dev_id2label_map, test_features, test_id2label_map, best_dev_score_now=0) # decide the best model if not opt.eval_when_train: # select best among check points best_model, best_score, test_score_then = trainer.select_model_from_check_point( - train_id2label, dev_features, dev_id2label, test_features, test_id2label, rm_cpt=opt.delete_checkpoint) + train_id2label_map, dev_features, dev_id2label_map, test_features, test_id2label_map, + rm_cpt=opt.delete_checkpoint) else: # best model is selected during training best_model = trained_model logger.info('dev:{}, test:{}'.format(best_dev_score, test_score)) @@ -173,17 +182,17 @@ def main(): if not opt.saved_model_path or not os.path.exists(opt.saved_model_path): raise ValueError("No model trained and no trained model file given (or not exist)") if os.path.isdir(opt.saved_model_path): # eval a list of checkpoints - max_score = eval_check_points(opt, tester, test_features, test_id2label, device) + max_score = eval_check_points(opt, tester, 
test_features, test_id2label_map, device) print('best check points scores:{}'.format(max_score)) exit(0) else: best_model = load_model(opt.saved_model_path) ''' test the best model ''' - testing_model = tester.clone_model(best_model, test_id2label) # copy reusable params - if opt.mask_transition and opt.task == 'sl': + testing_model = tester.clone_model(best_model, test_id2label_map) # copy reusable params + if opt.mask_transition and 'sl' in opt.task: testing_model.label_mask = opt.test_label_mask.to(device) - test_score = tester.do_test(testing_model, test_features, test_id2label, log_mark='test_pred') + test_score = tester.do_test(testing_model, test_features, test_id2label_map, log_mark='test_pred') logger.info('test:{}'.format(test_score)) print('test:{}'.format(test_score)) diff --git a/models/few_shot_learner.py b/models/few_shot_learner.py new file mode 100644 index 0000000..9ec0e87 --- /dev/null +++ b/models/few_shot_learner.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python +import torch +from typing import Tuple, Dict, List +from models.modules.context_embedder_base import ContextEmbedderBase +from models.modules.emission_scorer_base import EmissionScorerBase +from models.modules.transition_scorer import TransitionScorerBase +from models.few_shot_seq_labeler import FewShotSeqLabeler, SchemaFewShotSeqLabeler +from models.few_shot_text_classifier import FewShotTextClassifier, SchemaFewShotTextClassifier + + +class FewShotLearner(torch.nn.Module): + + def __init__(self, + opt, + context_embedder: ContextEmbedderBase, + # emission_scorer_map: Dict[str, EmissionScorerBase], + # decoder_map: Dict[str, torch.nn.Module], + model_map: Dict[str, torch.nn.Module], + # transition_scorer: TransitionScorerBase = None, + config: dict = None, # store necessary setting or none-torch params + emb_log: str = None): + super(FewShotLearner, self).__init__() + self.opt = opt + self.context_embedder = context_embedder + # self.emission_scorer_map = emission_scorer_map + # 
self.transition_scorer = transition_scorer + # self.decoder_map = decoder_map + self.no_embedder_grad = opt.no_embedder_grad + self.label_mask = None + self.config = config + self.emb_log = emb_log + + # self.task_lst = decoder_map.keys() + self.model_map = model_map + # for task in self.task_lst: + # if task == 'sl': + # self.model_map[task] = FewShotSeqLabeler(opt=opt, + # context_embedder=context_embedder, + # emission_scorer=emission_scorer_map[task], + # decoder=decoder_map[task], + # transition_scorer=transition_scorer, + # label_mask=self.label_mask, + # config=config, + # emb_log=emb_log) + # elif task == 'sc': + # self.model_map[task] = FewShotTextClassifier(opt=opt, + # context_embedder=context_embedder, + # emission_scorer=emission_scorer_map[task], + # decoder=decoder_map[task], + # config=config, + # emb_log=emb_log) + + def forward( + self, + test_token_ids: torch.Tensor, + test_segment_ids: torch.Tensor, + test_nwp_index: torch.Tensor, + test_input_mask: torch.Tensor, + test_output_mask_map: Dict[str, torch.Tensor], + support_token_ids: torch.Tensor, + support_segment_ids: torch.Tensor, + support_nwp_index: torch.Tensor, + support_input_mask: torch.Tensor, + support_output_mask_map: Dict[str, torch.Tensor], + test_target_map: Dict[str, torch.Tensor], + support_target_map: Dict[str, torch.Tensor], + support_num: torch.Tensor, + ): + """ + :param test_token_ids: (batch_size, test_len) + :param test_segment_ids: (batch_size, test_len) + :param test_nwp_index: (batch_size, test_len) + :param test_input_mask: (batch_size, test_len) + :param test_output_mask_map: A dict of (batch_size, test_len) + :param support_token_ids: (batch_size, support_size, support_len) + :param support_segment_ids: (batch_size, support_size, support_len) + :param support_nwp_index: (batch_size, support_size, support_len) + :param support_input_mask: (batch_size, support_size, support_len) + :param support_output_mask_map: A dict of (batch_size, support_size, support_len) + :param 
test_target_map: A dict of index targets (batch_size, test_len) + :param support_target_map: A dict of one-hot targets (batch_size, support_size, support_len, num_tags) + :param support_num: (batch_size, 1) + :return: + """ + # reps for tokens: (batch_size, support_size, nwp_sent_len, emb_len) + seq_test_reps, seq_support_reps, tc_test_reps, tc_support_reps = self.get_context_reps( + test_token_ids, test_segment_ids, test_nwp_index, test_input_mask, support_token_ids, support_segment_ids, + support_nwp_index, support_input_mask + ) + + reps_map = {'sl': {'test': seq_test_reps, 'support': seq_support_reps}, + 'sc': {'test': tc_test_reps, 'support': tc_support_reps}} + + if self.training: + loss = 0. + for task in self.opt.task: + loss += self.model_map[task](reps_map[task]['test'], test_output_mask_map[task], + reps_map[task]['support'], support_output_mask_map[task], + test_target_map[task], support_target_map[task], support_num, + self.training) + return loss + else: + prediction_map = {} + for task in self.opt.task: + prediction = self.model_map[task](reps_map[task]['test'], test_output_mask_map[task], + reps_map[task]['support'], support_output_mask_map[task], + test_target_map[task], support_target_map[task], support_num, + self.training) + prediction_map[task] = prediction + return prediction_map + + def get_context_reps( + self, + test_token_ids: torch.Tensor, + test_segment_ids: torch.Tensor, + test_nwp_index: torch.Tensor, + test_input_mask: torch.Tensor, + support_token_ids: torch.Tensor, + support_segment_ids: torch.Tensor, + support_nwp_index: torch.Tensor, + support_input_mask: torch.Tensor, + ): + if self.no_embedder_grad: + self.context_embedder.eval() # to avoid the dropout effect of reps model + self.context_embedder.requires_grad = False + else: + self.context_embedder.train() # to avoid the dropout effect of reps model + self.context_embedder.requires_grad = True + seq_test_reps, seq_support_reps, tc_test_reps, tc_support_reps = 
self.context_embedder( + test_token_ids, test_segment_ids, test_nwp_index, test_input_mask, support_token_ids, + support_segment_ids, + support_nwp_index, support_input_mask + ) + if self.no_embedder_grad: + seq_test_reps = seq_test_reps.detach() # detach the reps part from graph + seq_support_reps = seq_support_reps.detach() # detach the reps part from graph + tc_test_reps = tc_test_reps.detach() # detach the reps part from graph + tc_support_reps = tc_support_reps.detach() # detach the reps part from graph + return seq_test_reps, seq_support_reps, tc_test_reps, tc_support_reps + + +class SchemaFewShotLearner(FewShotLearner): + def __init__( + self, + opt, + context_embedder: ContextEmbedderBase, + # emission_scorer_map: Dict[str, EmissionScorerBase], + # decoder_map: Dict[str, torch.nn.Module], + model_map: Dict[str, torch.nn.Module], + # transition_scorer: TransitionScorerBase = None, + config: dict = None, # store necessary setting or none-torch params + emb_log: str = None + ): + # super(SchemaFewShotLearner, self).__init__( + # opt, context_embedder, emission_scorer_map, decoder_map, transition_scorer, config, emb_log) + super(SchemaFewShotLearner, self).__init__(opt, context_embedder, model_map, config, emb_log) + + # self.task_lst = decoder_map.keys() + self.model_map = model_map + # for task in self.task_lst: + # if task == 'sl': + # self.model_map[task] = SchemaFewShotSeqLabeler(opt=opt, + # context_embedder=context_embedder, + # emission_scorer=emission_scorer_map[task], + # decoder=decoder_map[task], + # transition_scorer=transition_scorer, + # label_mask=self.label_mask, + # config=config, + # emb_log=emb_log) + # elif task == 'sc': + # self.model_map[task] = SchemaFewShotTextClassifier(opt=opt, + # context_embedder=context_embedder, + # emission_scorer=emission_scorer_map[task], + # decoder=decoder_map[task], + # config=config, + # emb_log=emb_log) + + def forward( + self, + test_token_ids: torch.Tensor, + test_segment_ids: torch.Tensor, + 
test_nwp_index: torch.Tensor, + test_input_mask: torch.Tensor, + test_output_mask_map: Dict[str, torch.Tensor], + support_token_ids: torch.Tensor, + support_segment_ids: torch.Tensor, + support_nwp_index: torch.Tensor, + support_input_mask: torch.Tensor, + support_output_mask_map: Dict[str, torch.Tensor], + test_target_map: Dict[str, torch.Tensor], + support_target_map: Dict[str, torch.Tensor], + support_num: torch.Tensor, + label_token_ids_map: torch.Tensor = None, + label_segment_ids_map: torch.Tensor = None, + label_nwp_index_map: torch.Tensor = None, + label_input_mask_map: torch.Tensor = None, + label_output_mask_map: torch.Tensor = None, + ): + """ + few-shot sequence labeler using schema information + :param test_token_ids: (batch_size, test_len) + :param test_segment_ids: (batch_size, test_len) + :param test_nwp_index: (batch_size, test_len) + :param test_input_mask: (batch_size, test_len) + :param test_output_mask_map: A dict of (batch_size, test_len) + :param support_token_ids: (batch_size, support_size, support_len) + :param support_segment_ids: (batch_size, support_size, support_len) + :param support_nwp_index: (batch_size, support_size, support_len) + :param support_input_mask: (batch_size, support_size, support_len) + :param support_output_mask_map: A dict of (batch_size, support_size, support_len) + :param test_target_map: A dict of index targets (batch_size, test_len) + :param support_target_map: A dict of one-hot targets (batch_size, support_size, support_len, num_tags) + :param support_num: (batch_size, 1) + :param label_token_ids_map: A dict of tensor which + if label_reps=cat: + (batch_size, label_num * label_des_len) + elif: + (batch_size, label_num, label_des_len) + :param label_segment_ids_map: A dict of tensor which is same to label token ids + :param label_nwp_index_map: A dict of tensor which is same to label token ids + :param label_input_mask_map: A dict of tensor which is same to label token ids + :param label_output_mask_map: A dict of 
tensor which is same to label token ids + :return: + """ + # reps for tokens: (batch_size, support_size, nwp_sent_len, emb_len) + seq_test_reps, seq_support_reps, tc_test_reps, tc_support_reps = self.get_context_reps( + test_token_ids, test_segment_ids, test_nwp_index, test_input_mask, support_token_ids, support_segment_ids, + support_nwp_index, support_input_mask + ) + + reps_map = {'sl': {'test': seq_test_reps, 'support': seq_support_reps}, + 'sc': {'test': tc_test_reps, 'support': tc_support_reps}} + + # get label reps, shape (batch_size, max_label_num, emb_dim) + label_reps_map = {} + for task in self.opt.task: + label_reps_map[task] = self.get_label_reps( + label_token_ids_map[task], label_segment_ids_map[task], + label_nwp_index_map[task], label_input_mask_map[task] + ) + + if self.training: + loss = 0. + for task in self.opt.task: + loss += self.model_map[task](reps_map[task]['test'], test_output_mask_map[task], + reps_map[task]['support'], support_output_mask_map[task], + test_target_map[task], support_target_map[task], + support_num, label_reps_map[task], self.training) + return loss + else: + prediction_map = {} + for task in self.opt.task: + prediction = self.model_map[task](reps_map[task]['test'], test_output_mask_map[task], + reps_map[task]['support'], support_output_mask_map[task], + test_target_map[task], support_target_map[task], + support_num, label_reps_map[task], self.training) + prediction_map[task] = prediction + return prediction_map + + def get_label_reps( + self, + label_token_ids: torch.Tensor, + label_segment_ids: torch.Tensor, + label_nwp_index: torch.Tensor, + label_input_mask: torch.Tensor, + ) -> torch.Tensor: + """ + :param label_token_ids: + :param label_segment_ids: + :param label_nwp_index: + :param label_input_mask: + :return: shape (batch_size, label_num, label_des_len) + """ + return self.context_embedder( + label_token_ids, label_segment_ids, label_nwp_index, label_input_mask, reps_type='label') + diff --git 
a/models/few_shot_seq_labeler.py b/models/few_shot_seq_labeler.py index a170206..9a6b84b 100644 --- a/models/few_shot_seq_labeler.py +++ b/models/few_shot_seq_labeler.py @@ -15,6 +15,7 @@ def __init__(self, emission_scorer: EmissionScorerBase, decoder: torch.nn.Module, transition_scorer: TransitionScorerBase = None, + label_mask: torch.Tensor = None, config: dict = None, # store necessary setting or none-torch params emb_log: str = None): super(FewShotSeqLabeler, self).__init__() @@ -24,48 +25,32 @@ def __init__(self, self.transition_scorer = transition_scorer self.decoder = decoder self.no_embedder_grad = opt.no_embedder_grad - self.label_mask = None + self.label_mask = label_mask self.config = config self.emb_log = emb_log def forward( self, - test_token_ids: torch.Tensor, - test_segment_ids: torch.Tensor, - test_nwp_index: torch.Tensor, - test_input_mask: torch.Tensor, + test_reps: torch.Tensor, test_output_mask: torch.Tensor, - support_token_ids: torch.Tensor, - support_segment_ids: torch.Tensor, - support_nwp_index: torch.Tensor, - support_input_mask: torch.Tensor, + support_reps: torch.Tensor, support_output_mask: torch.Tensor, test_target: torch.Tensor, support_target: torch.Tensor, support_num: torch.Tensor, + is_training: bool = True, ): """ - :param test_token_ids: (batch_size, test_len) - :param test_segment_ids: (batch_size, test_len) - :param test_nwp_index: (batch_size, test_len) - :param test_input_mask: (batch_size, test_len) + :param test_reps: (batch_size, test_len, emb_dim) :param test_output_mask: (batch_size, test_len) - :param support_token_ids: (batch_size, support_size, support_len) - :param support_segment_ids: (batch_size, support_size, support_len) - :param support_nwp_index: (batch_size, support_size, support_len) - :param support_input_mask: (batch_size, support_size, support_len) + :param support_reps: (batch_size, support_size, support_len, emb_dim) :param support_output_mask: (batch_size, support_size, support_len) :param 
test_target: index targets (batch_size, test_len) :param support_target: one-hot targets (batch_size, support_size, support_len, num_tags) :param support_num: (batch_size, 1) + :param is_training: the training mode :return: """ - # reps for tokens: (batch_size, support_size, nwp_sent_len, emb_len) - test_reps, support_reps = self.get_context_reps( - test_token_ids, test_segment_ids, test_nwp_index, test_input_mask, support_token_ids, support_segment_ids, - support_nwp_index, support_input_mask - ) - # calculate emission: shape(batch_size, test_len, no_pad_num_tag) emission = self.emission_scorer(test_reps, support_reps, test_output_mask, support_output_mask, support_target) @@ -82,7 +67,7 @@ def forward( transitions = self.mask_transition(transitions, self.label_mask) self.decoder: ConditionalRandomField - if self.training: + if is_training: # the CRF staff llh = self.decoder.forward( inputs=logits, @@ -104,7 +89,7 @@ def forward( prediction = self.add_back_pad_label(prediction) else: self.decoder: SequenceLabeler - if self.training: + if is_training: loss = self.decoder.forward(logits=logits, tags=test_target, mask=test_output_mask) @@ -112,37 +97,11 @@ def forward( prediction = self.decoder.decode(logits=logits, masks=test_output_mask) # we block pad label(id=0) before by - 1, here, we add 1 back prediction = self.add_back_pad_label(prediction) - if self.training: + if is_training: return loss else: return prediction - def get_context_reps( - self, - test_token_ids: torch.Tensor, - test_segment_ids: torch.Tensor, - test_nwp_index: torch.Tensor, - test_input_mask: torch.Tensor, - support_token_ids: torch.Tensor, - support_segment_ids: torch.Tensor, - support_nwp_index: torch.Tensor, - support_input_mask: torch.Tensor, - ): - if self.no_embedder_grad: - self.context_embedder.eval() # to avoid the dropout effect of reps model - self.context_embedder.requires_grad = False - else: - self.context_embedder.train() # to avoid the dropout effect of reps model - 
self.context_embedder.requires_grad = True - test_reps, support_reps, _, _ = self.context_embedder( - test_token_ids, test_segment_ids, test_nwp_index, test_input_mask, support_token_ids, support_segment_ids, - support_nwp_index, support_input_mask - ) - if self.no_embedder_grad: - test_reps = test_reps.detach() # detach the reps part from graph - support_reps = support_reps.detach() # detach the reps part from graph - return test_reps, support_reps - def add_back_pad_label(self, predictions: List[List[int]]): for pred in predictions: for ind, l_id in enumerate(pred): @@ -163,73 +122,43 @@ def __init__( emission_scorer: EmissionScorerBase, decoder: torch.nn.Module, transition_scorer: TransitionScorerBase = None, + label_mask: torch.Tensor = None, config: dict = None, # store necessary setting or none-torch params emb_log: str = None ): super(SchemaFewShotSeqLabeler, self).__init__( - opt, context_embedder, emission_scorer, decoder, transition_scorer, config, emb_log) + opt, context_embedder, emission_scorer, decoder, transition_scorer, label_mask, config, emb_log) def forward( self, - test_token_ids: torch.Tensor, - test_segment_ids: torch.Tensor, - test_nwp_index: torch.Tensor, - test_input_mask: torch.Tensor, + test_reps: torch.Tensor, test_output_mask: torch.Tensor, - support_token_ids: torch.Tensor, - support_segment_ids: torch.Tensor, - support_nwp_index: torch.Tensor, - support_input_mask: torch.Tensor, + support_reps: torch.Tensor, support_output_mask: torch.Tensor, test_target: torch.Tensor, support_target: torch.Tensor, support_num: torch.Tensor, - label_token_ids: torch.Tensor = None, - label_segment_ids: torch.Tensor = None, - label_nwp_index: torch.Tensor = None, - label_input_mask: torch.Tensor = None, - label_output_mask: torch.Tensor = None, + label_reps: torch.Tensor = None, + is_training: bool = True, ): """ few-shot sequence labeler using schema information - :param test_token_ids: (batch_size, test_len) - :param test_segment_ids: (batch_size, 
test_len) - :param test_nwp_index: (batch_size, test_len) - :param test_input_mask: (batch_size, test_len) + :param test_reps: (batch_size, test_len, emb_dim) :param test_output_mask: (batch_size, test_len) - :param support_token_ids: (batch_size, support_size, support_len) - :param support_segment_ids: (batch_size, support_size, support_len) - :param support_nwp_index: (batch_size, support_size, support_len) - :param support_input_mask: (batch_size, support_size, support_len) + :param support_reps: (batch_size, support_size, support_len, emb_dim) :param support_output_mask: (batch_size, support_size, support_len) :param test_target: index targets (batch_size, test_len) :param support_target: one-hot targets (batch_size, support_size, support_len, num_tags) :param support_num: (batch_size, 1) - :param label_token_ids: - if label_reps=cat: - (batch_size, label_num * label_des_len) - elif: - (batch_size, label_num, label_des_len) - :param label_segment_ids: same to label token ids - :param label_nwp_index: same to label token ids - :param label_input_mask: same to label token ids - :param label_output_mask: same to label token ids + :param label_reps: (batch_size, label_num, emb_dim) + :param is_training: the training mode :return: """ - test_reps, support_reps = self.get_context_reps( - test_token_ids, test_segment_ids, test_nwp_index, test_input_mask, - support_token_ids, support_segment_ids, support_nwp_index, support_input_mask - ) - # get label reps, shape (batch_size, max_label_num, emb_dim) - label_reps = self.get_label_reps( - label_token_ids, label_segment_ids, label_nwp_index, label_input_mask, - ) # calculate emission: shape(batch_size, test_len, no_pad_num_tag) - # todo: Design new emission here emission = self.emission_scorer(test_reps, support_reps, test_output_mask, support_output_mask, support_target, label_reps) - if not self.training and self.emb_log: + if not is_training and self.emb_log: self.emb_log.write('\n'.join(['test_target\t' + 
'\t'.join(map(str, one_target)) for one_target in test_target.tolist()]) + '\n') @@ -239,7 +168,6 @@ def forward( test_target = torch.nn.functional.relu(test_target - 1) loss, prediction = torch.FloatTensor([0]).to(test_target.device), None - # todo: Design new transition here if self.transition_scorer: transitions, start_transitions, end_transitions = self.transition_scorer(test_reps, support_target, label_reps[0]) @@ -247,7 +175,7 @@ def forward( transitions = self.mask_transition(transitions, self.label_mask) self.decoder: ConditionalRandomField - if self.training: + if is_training: # the CRF staff llh = self.decoder.forward( inputs=logits, @@ -269,7 +197,7 @@ def forward( prediction = self.add_back_pad_label(prediction) else: self.decoder: SequenceLabeler - if self.training: + if is_training: loss = self.decoder.forward(logits=logits, tags=test_target, mask=test_output_mask) @@ -277,28 +205,11 @@ def forward( prediction = self.decoder.decode(logits=logits, masks=test_output_mask) # we block pad label(id=0) before by - 1, here, we add 1 back prediction = self.add_back_pad_label(prediction) - if self.training: + if is_training: return loss else: return prediction - def get_label_reps( - self, - label_token_ids: torch.Tensor, - label_segment_ids: torch.Tensor, - label_nwp_index: torch.Tensor, - label_input_mask: torch.Tensor, - ) -> torch.Tensor: - """ - :param label_token_ids: - :param label_segment_ids: - :param label_nwp_index: - :param label_input_mask: - :return: shape (batch_size, label_num, label_des_len) - """ - return self.context_embedder( - label_token_ids, label_segment_ids, label_nwp_index, label_input_mask, reps_type='label') - def main(): pass diff --git a/models/few_shot_text_classifier.py b/models/few_shot_text_classifier.py index b0b76ea..6fd7689 100644 --- a/models/few_shot_text_classifier.py +++ b/models/few_shot_text_classifier.py @@ -25,50 +25,26 @@ def __init__(self, def forward( self, - test_token_ids: torch.Tensor, - test_segment_ids: 
torch.Tensor, - test_nwp_index: torch.Tensor, - test_input_mask: torch.Tensor, + test_reps: torch.Tensor, test_output_mask: torch.Tensor, - support_token_ids: torch.Tensor, - support_segment_ids: torch.Tensor, - support_nwp_index: torch.Tensor, - support_input_mask: torch.Tensor, + support_reps: torch.Tensor, support_output_mask: torch.Tensor, test_target: torch.Tensor, support_target: torch.Tensor, support_num: torch.Tensor, - support_sentence_feature: torch.Tensor = None, - test_sentence_feature: torch.Tensor = None, - support_sentence_target: torch.Tensor = None, - test_sentence_target: torch.Tensor = None + is_training: bool = True, ): """ - :param test_token_ids: (batch_size, test_len) - :param test_segment_ids: (batch_size, test_len) - :param test_nwp_index: (batch_size, test_len) - :param test_input_mask: (batch_size, test_len) + :param test_reps: (batch_size, test_len, emb_dim) :param test_output_mask: (batch_size, test_len) - :param support_token_ids: (batch_size, support_size, support_len) - :param support_segment_ids: (batch_size, support_size, support_len) - :param support_nwp_index: (batch_size, support_size, support_len) - :param support_input_mask: (batch_size, support_size, support_len) + :param support_reps: (batch_size, support_size, support_len, emb_dim) :param support_output_mask: (batch_size, support_size, support_len) :param test_target: index targets (batch_size, multi-label_num) :param support_target: one-hot targets (batch_size, support_size, multi-label_num, num_tags) :param support_num: (batch_size, 1) - :param support_sentence_feature: same to label token ids - :param test_sentence_feature: same to label token ids - :param support_sentence_target: same to label token ids - :param test_sentence_target: same to label token ids + :param is_training: the training mode :return: """ - # reps for whole sentences: (batch_size, support_size, 1, emb_len) - test_reps, support_reps = self.get_context_reps( - test_token_ids, test_segment_ids, 
test_nwp_index, test_input_mask, support_token_ids, support_segment_ids, - support_nwp_index, support_input_mask - ) - # calculate emission: shape(batch_size, 1, no_pad_num_tag) test_output_mask = torch.ones(test_output_mask.shape[0], 1).to(test_output_mask.device) # for sc, each test has only 1 output emission = self.emission_scorer(test_reps, support_reps, test_output_mask, support_output_mask, support_target) @@ -78,44 +54,18 @@ def forward( test_target = torch.nn.functional.relu(test_target - 1) loss, prediction = torch.FloatTensor(0).to(test_target.device), None - if self.training: + if is_training: loss = self.decoder.forward(logits=logits, mask=test_output_mask, tags=test_target) else: prediction = self.decoder.decode(logits=logits) # we block pad label(id=0) before by - 1, here, we add 1 back prediction = self.add_back_pad_label(prediction) - if self.training: + if is_training: return loss else: return prediction - def get_context_reps( - self, - test_token_ids: torch.Tensor, - test_segment_ids: torch.Tensor, - test_nwp_index: torch.Tensor, - test_input_mask: torch.Tensor, - support_token_ids: torch.Tensor, - support_segment_ids: torch.Tensor, - support_nwp_index: torch.Tensor, - support_input_mask: torch.Tensor, - ): - if self.no_embedder_grad: - self.context_embedder.eval() # to avoid the dropout effect of reps model - self.context_embedder.requires_grad = False - else: - self.context_embedder.train() # to avoid the dropout effect of reps model - self.context_embedder.requires_grad = True - _, _, test_reps, support_reps = self.context_embedder( - test_token_ids, test_segment_ids, test_nwp_index, test_input_mask, support_token_ids, support_segment_ids, - support_nwp_index, support_input_mask - ) - if self.no_embedder_grad: - test_reps = test_reps.detach() # detach the reps part from graph - support_reps = support_reps.detach() # detach the reps part from graph - return test_reps, support_reps - def add_back_pad_label(self, predictions: List[List[int]]): 
for pred in predictions: for ind, l_id in enumerate(pred): @@ -137,73 +87,33 @@ def __init__( def forward( self, - test_token_ids: torch.Tensor, - test_segment_ids: torch.Tensor, - test_nwp_index: torch.Tensor, - test_input_mask: torch.Tensor, + test_reps: torch.Tensor, test_output_mask: torch.Tensor, - support_token_ids: torch.Tensor, - support_segment_ids: torch.Tensor, - support_nwp_index: torch.Tensor, - support_input_mask: torch.Tensor, + support_reps: torch.Tensor, support_output_mask: torch.Tensor, test_target: torch.Tensor, support_target: torch.Tensor, support_num: torch.Tensor, - label_token_ids: torch.Tensor = None, - label_segment_ids: torch.Tensor = None, - label_nwp_index: torch.Tensor = None, - label_input_mask: torch.Tensor = None, - label_output_mask: torch.Tensor = None, - support_sentence_feature: torch.Tensor = None, - test_sentence_feature: torch.Tensor = None, - support_sentence_target: torch.Tensor = None, - test_sentence_target: torch.Tensor = None + label_reps: torch.Tensor = None, + is_training: bool = True, ): """ few-shot sequence labeler using schema information - :param test_token_ids: (batch_size, test_len) - :param test_segment_ids: (batch_size, test_len) - :param test_nwp_index: (batch_size, test_len) - :param test_input_mask: (batch_size, test_len) + :param test_reps: (batch_size, test_len, emb_dim) :param test_output_mask: (batch_size, test_len) - :param support_token_ids: (batch_size, support_size, support_len) - :param support_segment_ids: (batch_size, support_size, support_len) - :param support_nwp_index: (batch_size, support_size, support_len) - :param support_input_mask: (batch_size, support_size, support_len) + :param support_reps: (batch_size, support_size, support_len, emb_dim) :param support_output_mask: (batch_size, support_size, support_len) :param test_target: index targets (batch_size, test_len) :param support_target: one-hot targets (batch_size, support_size, support_len, num_tags) :param support_num: (batch_size, 1) 
- :param label_token_ids: - if label_reps=cat: - (batch_size, label_num * label_des_len) - elif: - (batch_size, label_num, label_des_len) - :param label_segment_ids: same to label token ids - :param label_nwp_index: same to label token ids - :param label_input_mask: same to label token ids - :param label_output_mask: same to label token ids - :param support_sentence_feature: same to label token ids - :param test_sentence_feature: same to label token ids - :param support_sentence_target: same to label token ids - :param test_sentence_target: same to label token ids + :param label_reps: (batch_size, label_num, emb_dim) + :param is_training: the training mode :return: """ - test_reps, support_reps = self.get_context_reps( - test_token_ids, test_segment_ids, test_nwp_index, test_input_mask, - support_token_ids, support_segment_ids, support_nwp_index, support_input_mask - ) - - # get label reps, shape (batch_size, max_label_num, emb_dim) - label_reps = self.get_label_reps( - label_token_ids, label_segment_ids, label_nwp_index, label_input_mask, - ) - # calculate emission: shape(batch_size, test_len, no_pad_num_tag) emission = self.emission_scorer(test_reps, support_reps, test_output_mask, support_output_mask, support_target, label_reps) - if not self.training and self.emb_log: + if not is_training and self.emb_log: self.emb_log.write('\n'.join(['test_target\t' + '\t'.join(map(str, one_target)) for one_target in test_target.tolist()]) + '\n') @@ -214,34 +124,17 @@ def forward( loss, prediction = torch.FloatTensor([0]).to(test_target.device), None - if self.training: + if is_training: loss = self.decoder.forward(logits=logits, mask=test_output_mask, tags=test_target) else: prediction = self.decoder.decode(logits=logits) # we block pad label(id=0) before by - 1, here, we add 1 back prediction = self.add_back_pad_label(prediction) - if self.training: + if is_training: return loss else: return prediction - def get_label_reps( - self, - label_token_ids: torch.Tensor, - 
label_segment_ids: torch.Tensor, - label_nwp_index: torch.Tensor, - label_input_mask: torch.Tensor, - ) -> torch.Tensor: - """ - :param label_token_ids: - :param label_segment_ids: - :param label_nwp_index: - :param label_input_mask: - :return: shape (batch_size, label_num, label_des_len) - """ - return self.context_embedder( - label_token_ids, label_segment_ids, label_nwp_index, label_input_mask, reps_type='label') - def main(): pass diff --git a/models/modules/context_embedder_base.py b/models/modules/context_embedder_base.py index c00a7f0..c45057f 100644 --- a/models/modules/context_embedder_base.py +++ b/models/modules/context_embedder_base.py @@ -264,7 +264,6 @@ def forward( def get_label_reps(self, test_token_ids, test_segment_ids, test_nwp_index, test_input_mask): batch_size = test_token_ids.shape[0] if self.opt.label_reps == 'cat': - # todo: use label mask to represent a label with only in domain info reps = self.single_reps(test_token_ids, test_segment_ids, test_nwp_index, test_input_mask, ) elif self.opt.label_reps in ['sep', 'sep_sum']: input_ids, segment_ids, input_mask = self.flatten_input(test_token_ids, test_segment_ids, @@ -280,7 +279,6 @@ def get_label_reps(self, test_token_ids, test_segment_ids, test_nwp_index, test_ elif self.opt.label_reps == 'sep_sum': # average all label reps as reps reps = sequence_output emb_mask = self.expand_mask(test_input_mask, 2, reps_size) - # todo: use mask to get sum of single embedding raise NotImplementedError else: raise ValueError("Wrong label_reps choice ") diff --git a/models/modules/similarity_scorer_base.py b/models/modules/similarity_scorer_base.py index 8135457..ca62d05 100644 --- a/models/modules/similarity_scorer_base.py +++ b/models/modules/similarity_scorer_base.py @@ -336,7 +336,7 @@ def forward( num_tags = support_targets.shape[-1] no_pad_num_tags = num_tags - 1 - if no_pad_num_tags > len(self.anchor_reps) and (not label_reps or self.random_init): + if no_pad_num_tags > len(self.anchor_reps) and 
(label_reps is None or self.random_init): raise RuntimeError("Too few anchors") if label_reps is None and not self.random_init: diff --git a/readme.md b/readme.md index 02decc6..7dc01b7 100644 --- a/readme.md +++ b/readme.md @@ -70,7 +70,7 @@ There are some params for you to control the generation process: ##### few-shot/meta-episode style -```json +``` { "domain_name": [ { // episode @@ -130,4 +130,4 @@ There are many parameters to control the train & test process, but there are som ## Information -The platform is developed by [HIT-SCIR](http://ir.hit.edu.cn/). If you have any question and advice for it, please contact us(Yutai Hou - [ythou@ir.hit.edu.cn]() or Yongkui Lai - [yklai@ir.hit.edu.cn]()). +The platform is developed by [HIT-SCIR](http://ir.hit.edu.cn/). If you have any question and advice for it, please contact us(Yutai Hou - [ythou@ir.hit.edu.cn](mailto:ythou@ir.hit.edu.cn) or Yongkui Lai - [yklai@ir.hit.edu.cn](mailto:yklai@ir.hit.edu.cn)). diff --git a/scripts/run_bert_sc+sl.sh b/scripts/run_bert_sc+sl.sh new file mode 100644 index 0000000..bb23192 --- /dev/null +++ b/scripts/run_bert_sc+sl.sh @@ -0,0 +1,284 @@ +#!/usr/bin/env bash +echo usage: pass gpu id list as param, split with , +echo eg: source run_bert_siamese.sh 3,4 stanford + +gpu_list=$1 + +# Comment one of follow 2 to switch debugging status +do_debug=--do_debug +#do_debug= + +#restore=--restore_cpt +restore= + +task=sc\ sl +#task=sl + +use_schema=--use_schema +#use_schema= + +#label_num_schema=--label_num_schema +label_num_schema= + + +# ======= dataset setting ====== +dataset_lst=($2 $3) +support_shots_lst=(3) + +query_shot=8 + +episode=100 + +cross_data_id=0 # for smp + +# ====== train & test setting ====== +seed_lst=(0) +#seed_lst=(6150 6151 6152) + +#lr_lst=(0.000001 0.000005 0.00005) +lr_lst=(0.00001) + +clip_grad=5 + +decay_lr_lst=(0.5) +#decay_lr_lst=(-1) + +#upper_lr_lst=( 0.5 0.01 0.005) +#upper_lr_lst=(0.01) +upper_lr_lst=(0.001) +#upper_lr_lst=(0.0001) +#upper_lr_lst=(0.0005) 
+#upper_lr_lst=(0.1) +#upper_lr_lst=(0.001 0.1) + +#fix_embd_epoch_lst=(1) +fix_embd_epoch_lst=(-1) +#fix_embd_epoch_lst=(1 2) + +warmup_epoch=2 + +train_batch_size_lst=(4) +test_batch_size=2 +#grad_acc=2 +grad_acc=4 +epoch=3 + +# ==== model setting ========= +# ---- encoder setting ----- + +#embedder=electra +embedder=bert +#embedder=sep_bert + + +# --------- emission setting -------- +emission=proto_with_label\ tapnet # match with task + + +#similarity=cosine +#similarity=l2 +similarity=dot + +emission_normalizer=none +#emission_normalizer=softmax +#emission_normalizer=norm + +#emission_scaler=none +#emission_scaler=fix +emission_scaler=learn +#emission_scaler=relu +#emission_scaler=exp + + +do_div_emission=-dbt +#do_div_emission= + +ems_scale_rate_lst=(0.01) +#ems_scale_rate_lst=(0.01 0.02 0.05 0.005) + +label_reps=sep +#label_reps=cat + +ple_normalizer=none +ple_scaler=fix +#ple_scale_r=0.5 +ple_scale_r_lst=(0.5) +#ple_scale_r=1 +#ple_scale_r=0.01 + +tap_random_init=--tap_random_init +tap_random_init_r=0.5 +tap_proto=--tap_proto +tap_proto_r=0.3 +tap_mlp= +emb_log= + +# ------ decoder setting ------- +#decoder_lst=(rule) +#decoder_lst=(sms) +decoder_lst=(crf) +#decoder_lst=(crf sms) + + +#trans_init_lst=(fix rand) +trans_init_lst=(rand) + +mask_trans=-mk_tr +#mask_trans= + +trans_scaler=fix +#trans_scale_rate_lst=(10) +trans_scale_rate_lst=(1) + +trans_rate=1 +#trans_rate=0.8 + +trans_normalizer=none +#trans_normalizer=softmax +#trans_normalizer=norm + +trans_scaler=none +#trans_scaler=fix +#trans_scaler=learn +#trans_scaler=relu +#trans_scaler=exp + +transition=learn + +# -------- SC decoder setting -------- + + +# ======= default path (for quick distribution) ========== +# bert base path +pretrained_model_path=/users4/yklai/corpus/BERT/pytorch/chinese_L-12_H-768_A-12 +pretrained_vocab_path=/users4/yklai/corpus/BERT/pytorch/chinese_L-12_H-768_A-12/vocab.txt + +# electra small path 
+#pretrained_model_path=/users4/yklai/corpus/electra/chinese_electra_small_discriminator +#pretrained_vocab_path=/users4/yklai/corpus/electra/chinese_electra_small_discriminator +#pretrained_model_path=/Users/lyk/Code/model/chinese_electra_small_discriminator_pytorch +#pretrained_vocab_path=/Users/lyk/Code/model/chinese_electra_small_discriminator_pytorch + +base_data_dir=/users4/yklai/code/Dialogue/FewShot/MetaDial/data/SmpMetaData/ +#base_data_dir=/Users/lyk/Work/Dialogue/FewShot/SMP/smp2/ + +echo [START] set jobs on dataset [ ${dataset_lst[@]} ] on gpu [ ${gpu_list} ] +# === Loop for all case and run === +for seed in ${seed_lst[@]} +do + for dataset in ${dataset_lst[@]} + do + for support_shots in ${support_shots_lst[@]} + do + for train_batch_size in ${train_batch_size_lst[@]} + do + for decay_lr in ${decay_lr_lst[@]} + do + for fix_embd_epoch in ${fix_embd_epoch_lst[@]} + do + for lr in ${lr_lst[@]} + do + for upper_lr in ${upper_lr_lst[@]} + do + + for trans_init in ${trans_init_lst[@]} + do + for ems_scale_r in ${ems_scale_rate_lst[@]} + do + for trans_scale_r in ${trans_scale_rate_lst[@]} + do + for decoder in ${decoder_lst[@]} + do + for ple_scale_r in ${ple_scale_r_lst[@]} + do + # model names + model_name=joint_sc_sl.ga_${grad_acc}_ple_${ple_scale_r}.bs_${train_batch_size}.electra.sim_${similarity}.ems_${emission_normalizer}.${use_schema}${label_num_schema}--fix_dev_spt${do_debug} + + data_dir=${base_data_dir}${dataset}.${cross_data_id}.spt_s_${support_shots}.q_s_${query_shot}.ep_${episode}${use_schema}--fix_dev_spt/ + file_mark=${dataset}.shots_${support_shots}.cross_id_${cross_data_id}.m_seed_${seed} + train_file_name=train.json + dev_file_name=dev.json + test_file_name=test.json + + echo [CLI] + echo Model: ${model_name} + echo Task: ${file_mark} + echo [CLI] + export OMP_NUM_THREADS=2 # threads num for each task + CUDA_VISIBLE_DEVICES=${gpu_list} python3 main.py ${do_debug} \ + --task ${task} \ + --seed ${seed} \ + --do_train \ + --do_predict \ + 
--train_path ${data_dir}${train_file_name} \ + --dev_path ${data_dir}${dev_file_name} \ + --test_path ${data_dir}${test_file_name} \ + --output_dir ${data_dir}${model_name}.DATA.${file_mark} \ + --bert_path ${pretrained_model_path} \ + --bert_vocab ${pretrained_vocab_path} \ + --train_batch_size ${train_batch_size} \ + --cpt_per_epoch 4 \ + --delete_checkpoint \ + --gradient_accumulation_steps ${grad_acc} \ + --num_train_epochs ${epoch} \ + --learning_rate ${lr} \ + --decay_lr ${decay_lr} \ + --upper_lr ${upper_lr} \ + --clip_grad ${clip_grad} \ + --fix_embed_epoch ${fix_embd_epoch} \ + --warmup_epoch ${warmup_epoch} \ + --test_batch_size ${test_batch_size} \ + --context_emb ${embedder} \ + ${use_schema} \ + ${label_num_schema} \ + --label_reps ${label_reps} \ + --projection_layer none \ + --emission ${emission} \ + --similarity ${similarity} \ + -e_nm ${emission_normalizer} \ + -e_scl ${emission_scaler} \ + --ems_scale_r ${ems_scale_r} \ + -ple_nm ${ple_normalizer} \ + -ple_scl ${ple_scaler} \ + --ple_scale_r ${ple_scale_r} \ + ${tap_random_init} \ + --tap_random_init_r ${tap_random_init_r} \ + --tap_proto ${tap_proto} \ + --tap_proto_r ${tap_proto_r} \ + ${tap_mlp} \ + ${emb_log} \ + ${do_div_emission} \ + --transition ${transition} \ + --backoff_init ${trans_init} \ + --trans_r ${trans_rate} \ + -t_nm ${trans_normalizer} \ + -t_scl ${trans_scaler} \ + --trans_scale_r ${trans_scale_r} \ + ${mask_trans} \ + --load_feature > ./joint/${model_name}.DATA.${file_mark}.log + echo [CLI] + echo Model: ${model_name} + echo Task: ${file_mark} + echo [CLI] + done + done + done + done + done + done + done + done + done + done + done + done +done + +# Other candidate option: +# -doft \ +# --delete_checkpoint \ +# --fp16 \ + + + +echo [FINISH] set jobs on dataset [ ${dataset_lst[@]} ] on gpu [ ${gpu_list} ] diff --git a/scripts/run_bert_sc.sh b/scripts/run_bert_sc.sh index 2b69891..a5b7df6 100644 --- a/scripts/run_bert_sc.sh +++ b/scripts/run_bert_sc.sh @@ -11,7 +11,7 @@ 
do_debug=--do_debug #restore=--restore_cpt restore= -task=sc +task=sc\ sl #task=sl use_schema=--use_schema diff --git a/scripts/run_electra_sc+sl.sh b/scripts/run_electra_sc+sl.sh new file mode 100644 index 0000000..4c2ddb7 --- /dev/null +++ b/scripts/run_electra_sc+sl.sh @@ -0,0 +1,284 @@ +#!/usr/bin/env bash +echo usage: pass gpu id list as param, split with , +echo eg: source run_bert_siamese.sh 3,4 stanford + +gpu_list=$1 + +# Comment one of follow 2 to switch debugging status +do_debug=--do_debug +#do_debug= + +#restore=--restore_cpt +restore= + +task=sc\ sl +#task=sl + +use_schema=--use_schema +#use_schema= + +#label_num_schema=--label_num_schema +label_num_schema= + + +# ======= dataset setting ====== +dataset_lst=($2 $3) +support_shots_lst=(3) + +query_shot=8 + +episode=100 + +cross_data_id=0 # for smp + +# ====== train & test setting ====== +seed_lst=(0) +#seed_lst=(6150 6151 6152) + +#lr_lst=(0.000001 0.000005 0.00005) +lr_lst=(0.00001) + +clip_grad=5 + +decay_lr_lst=(0.5) +#decay_lr_lst=(-1) + +#upper_lr_lst=( 0.5 0.01 0.005) +#upper_lr_lst=(0.01) +upper_lr_lst=(0.001) +#upper_lr_lst=(0.0001) +#upper_lr_lst=(0.0005) +#upper_lr_lst=(0.1) +#upper_lr_lst=(0.001 0.1) + +#fix_embd_epoch_lst=(1) +fix_embd_epoch_lst=(-1) +#fix_embd_epoch_lst=(1 2) + +warmup_epoch=2 + +train_batch_size_lst=(4) +test_batch_size=2 +#grad_acc=2 +grad_acc=4 +epoch=3 + +# ==== model setting ========= +# ---- encoder setting ----- + +embedder=electra +#embedder=bert +#embedder=sep_bert + + +# --------- emission setting -------- +emission=proto_with_label\ tapnet # match with task + + +#similarity=cosine +#similarity=l2 +similarity=dot + +emission_normalizer=none +#emission_normalizer=softmax +#emission_normalizer=norm + +#emission_scaler=none +#emission_scaler=fix +emission_scaler=learn +#emission_scaler=relu +#emission_scaler=exp + + +do_div_emission=-dbt +#do_div_emission= + +ems_scale_rate_lst=(0.01) +#ems_scale_rate_lst=(0.01 0.02 0.05 0.005) + +label_reps=sep +#label_reps=cat + 
+ple_normalizer=none +ple_scaler=fix +#ple_scale_r=0.5 +ple_scale_r_lst=(0.5) +#ple_scale_r=1 +#ple_scale_r=0.01 + +tap_random_init=--tap_random_init +tap_random_init_r=0.5 +tap_proto=--tap_proto +tap_proto_r=0.3 +tap_mlp= +emb_log= + +# ------ decoder setting ------- +#decoder_lst=(rule) +#decoder_lst=(sms) +decoder_lst=(crf) +#decoder_lst=(crf sms) + + +#trans_init_lst=(fix rand) +trans_init_lst=(rand) + +mask_trans=-mk_tr +#mask_trans= + +trans_scaler=fix +#trans_scale_rate_lst=(10) +trans_scale_rate_lst=(1) + +trans_rate=1 +#trans_rate=0.8 + +trans_normalizer=none +#trans_normalizer=softmax +#trans_normalizer=norm + +trans_scaler=none +#trans_scaler=fix +#trans_scaler=learn +#trans_scaler=relu +#trans_scaler=exp + +transition=learn + +# -------- SC decoder setting -------- + + +# ======= default path (for quick distribution) ========== +# bert base path +#pretrained_model_path=/users4/yklai/corpus/BERT/pytorch/chinese_L-12_H-768_A-12 +#pretrained_vocab_path=/users4/yklai/corpus/BERT/pytorch/chinese_L-12_H-768_A-12/vocab.txt + +# electra small path +pretrained_model_path=/users4/yklai/corpus/electra/chinese_electra_small_discriminator +pretrained_vocab_path=/users4/yklai/corpus/electra/chinese_electra_small_discriminator +#pretrained_model_path=/Users/lyk/Code/model/chinese_electra_small_discriminator_pytorch +#pretrained_vocab_path=/Users/lyk/Code/model/chinese_electra_small_discriminator_pytorch + +base_data_dir=/users4/yklai/code/Dialogue/FewShot/MetaDial/data/SmpMetaData/ +#base_data_dir=/Users/lyk/Work/Dialogue/FewShot/SMP/smp2/ + +echo [START] set jobs on dataset [ ${dataset_lst[@]} ] on gpu [ ${gpu_list} ] +# === Loop for all case and run === +for seed in ${seed_lst[@]} +do + for dataset in ${dataset_lst[@]} + do + for support_shots in ${support_shots_lst[@]} + do + for train_batch_size in ${train_batch_size_lst[@]} + do + for decay_lr in ${decay_lr_lst[@]} + do + for fix_embd_epoch in ${fix_embd_epoch_lst[@]} + do + for lr in ${lr_lst[@]} + do + for 
upper_lr in ${upper_lr_lst[@]} + do + + for trans_init in ${trans_init_lst[@]} + do + for ems_scale_r in ${ems_scale_rate_lst[@]} + do + for trans_scale_r in ${trans_scale_rate_lst[@]} + do + for decoder in ${decoder_lst[@]} + do + for ple_scale_r in ${ple_scale_r_lst[@]} + do + # model names + model_name=joint_sc_sl.ga_${grad_acc}_ple_${ple_scale_r}.bs_${train_batch_size}.electra.sim_${similarity}.ems_${emission_normalizer}.${use_schema}${label_num_schema}--fix_dev_spt${do_debug} + + data_dir=${base_data_dir}${dataset}.${cross_data_id}.spt_s_${support_shots}.q_s_${query_shot}.ep_${episode}${use_schema}--fix_dev_spt/ + file_mark=${dataset}.shots_${support_shots}.cross_id_${cross_data_id}.m_seed_${seed} + train_file_name=train.json + dev_file_name=dev.json + test_file_name=test.json + + echo [CLI] + echo Model: ${model_name} + echo Task: ${file_mark} + echo [CLI] + export OMP_NUM_THREADS=2 # threads num for each task + CUDA_VISIBLE_DEVICES=${gpu_list} python3 main.py ${do_debug} \ + --task ${task} \ + --seed ${seed} \ + --do_train \ + --do_predict \ + --train_path ${data_dir}${train_file_name} \ + --dev_path ${data_dir}${dev_file_name} \ + --test_path ${data_dir}${test_file_name} \ + --output_dir ${data_dir}${model_name}.DATA.${file_mark} \ + --bert_path ${pretrained_model_path} \ + --bert_vocab ${pretrained_vocab_path} \ + --train_batch_size ${train_batch_size} \ + --cpt_per_epoch 4 \ + --delete_checkpoint \ + --gradient_accumulation_steps ${grad_acc} \ + --num_train_epochs ${epoch} \ + --learning_rate ${lr} \ + --decay_lr ${decay_lr} \ + --upper_lr ${upper_lr} \ + --clip_grad ${clip_grad} \ + --fix_embed_epoch ${fix_embd_epoch} \ + --warmup_epoch ${warmup_epoch} \ + --test_batch_size ${test_batch_size} \ + --context_emb ${embedder} \ + ${use_schema} \ + ${label_num_schema} \ + --label_reps ${label_reps} \ + --projection_layer none \ + --emission ${emission} \ + --similarity ${similarity} \ + -e_nm ${emission_normalizer} \ + -e_scl ${emission_scaler} \ + 
--ems_scale_r ${ems_scale_r} \ + -ple_nm ${ple_normalizer} \ + -ple_scl ${ple_scaler} \ + --ple_scale_r ${ple_scale_r} \ + ${tap_random_init} \ + --tap_random_init_r ${tap_random_init_r} \ + --tap_proto ${tap_proto} \ + --tap_proto_r ${tap_proto_r} \ + ${tap_mlp} \ + ${emb_log} \ + ${do_div_emission} \ + --transition ${transition} \ + --backoff_init ${trans_init} \ + --trans_r ${trans_rate} \ + -t_nm ${trans_normalizer} \ + -t_scl ${trans_scaler} \ + --trans_scale_r ${trans_scale_r} \ + ${mask_trans} \ + --load_feature > ./joint/${model_name}.DATA.${file_mark}.log + echo [CLI] + echo Model: ${model_name} + echo Task: ${file_mark} + echo [CLI] + done + done + done + done + done + done + done + done + done + done + done + done +done + +# Other candidate option: +# -doft \ +# --delete_checkpoint \ +# --fp16 \ + + + +echo [FINISH] set jobs on dataset [ ${dataset_lst[@]} ] on gpu [ ${gpu_list} ] diff --git a/scripts/run_electra_sc.sh b/scripts/run_electra_sc.sh index c475c20..a700646 100644 --- a/scripts/run_electra_sc.sh +++ b/scripts/run_electra_sc.sh @@ -14,6 +14,7 @@ restore= task=sc #task=sl + use_schema=--use_schema #use_schema= @@ -147,8 +148,11 @@ emb_log= # electra small path pretrained_model_path=/users4/yklai/corpus/electra/chinese_electra_small_discriminator pretrained_vocab_path=/users4/yklai/corpus/electra/chinese_electra_small_discriminator +#pretrained_model_path=/Users/lyk/Code/model/chinese_electra_small_discriminator_pytorch +#pretrained_vocab_path=/Users/lyk/Code/model/chinese_electra_small_discriminator_pytorch base_data_dir=/users4/yklai/code/Dialogue/FewShot/MetaDial/data/SmpMetaData/ +#base_data_dir=/Users/lyk/Work/Dialogue/FewShot/SMP/smp2/ echo [START] set jobs on dataset [ ${dataset_lst[@]} ] on gpu [ ${gpu_list} ] # === Loop for all case and run === @@ -188,7 +192,7 @@ do echo Task: ${file_mark} echo [CLI] export OMP_NUM_THREADS=2 # threads num for each task - CUDA_VISIBLE_DEVICES=${gpu_list} python main.py ${do_debug} \ + 
CUDA_VISIBLE_DEVICES=${gpu_list} python3 main.py ${do_debug} \ --task ${task} \ --seed ${seed} \ --do_train \ @@ -229,7 +233,7 @@ do ${emb_log} \ ${do_div_emission} \ --transition learn \ - --load_feature > ./sclog/${model_name}.DATA.${file_mark}.log + --load_feature > ./sclog/${model_name}.DATA.${file_mark}.log echo [CLI] echo Model: ${model_name} echo Task: ${file_mark} diff --git a/utils/data_loader.py b/utils/data_loader.py index 0e43b1e..619ed01 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -110,17 +110,6 @@ def batch2data_items(self, batch: dict) -> (List[DataItem], List[DataItem]): def get_data_items(self, parts: dict) -> List[DataItem]: data_item_lst = [] for seq_in, seq_out, label in zip(parts['seq_ins'], parts['seq_outs'], parts['labels']): - # todo: move word-piecing into preprocessing module - # label = token_label if self.opt.task == 'ml' else sent_label # decide label type according to task data_item = DataItem(seq_in=seq_in, seq_out=seq_out, label=label) data_item_lst.append(data_item) return data_item_lst - - # def get_data_items(self, parts: dict) -> List[DataItem]: - # data_item_lst = [] - # for text, label, wp_text, wp_label, wp_mark in zip( - # parts['seq_ins'], parts['seq_outs'], - # parts['tokenized_texts'], parts['word_piece_labels'], parts['word_piece_marks']): - # data_item = DataItem(text=text, label=label, wp_text=wp_text, wp_label=wp_label, wp_mark=wp_mark) - # data_item_lst.append(data_item) - # return data_item_lst diff --git a/utils/device_helper.py b/utils/device_helper.py index b2f2248..484aae3 100644 --- a/utils/device_helper.py +++ b/utils/device_helper.py @@ -40,7 +40,10 @@ def prepare_model(args, model, device, n_gpu): """ Set device to use """ if args.fp16: model.half() + # TODO: smarter way model.to(device) + for task in model.model_map.keys(): + model.model_map[task].to(device) if args.local_rank != -1: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 
output_device=args.local_rank) diff --git a/utils/iter_helper.py b/utils/iter_helper.py index c186790..6548501 100644 --- a/utils/iter_helper.py +++ b/utils/iter_helper.py @@ -80,16 +80,25 @@ def pad_collate(self, batch: List[List[torch.Tensor]]) -> List[torch.Tensor]: ret = [] for item_idx in range(len(batch[0])): # pad each data item # find longest sequence - max_len = max(map(lambda x: x[item_idx].shape[self.get_dim(item_idx)], batch)) - # if item_idx in self.sp_item_idx: - # print('debug padding:', "item_idx", item_idx, "max_len", - # max_len,'batch items', [x[item_idx].shape for x in batch]) - # pad according to max_len - padded_item_lst = list(map( - lambda x: pad_tensor(x[item_idx], pad=max_len, dim=self.get_dim(item_idx)), batch)) - # stack all - padded_item_lst = torch.stack(padded_item_lst, dim=0) - ret.append(padded_item_lst) + if isinstance(batch[0][item_idx], dict): + padded_item_lst_map = {} + for key in batch[0][item_idx].keys(): + max_len = max(map(lambda x: x[item_idx][key].shape[self.get_dim(item_idx)], batch)) + # pad according to max_len + padded_item_lst = list(map( + lambda x: pad_tensor(x[item_idx][key], pad=max_len, dim=self.get_dim(item_idx)), batch)) + # stack all + padded_item_lst = torch.stack(padded_item_lst, dim=0) + padded_item_lst_map[key] = padded_item_lst + + else: + max_len = max(map(lambda x: x[item_idx].shape[self.get_dim(item_idx)], batch)) + # pad according to max_len + padded_item_lst = list(map( + lambda x: pad_tensor(x[item_idx], pad=max_len, dim=self.get_dim(item_idx)), batch)) + # stack all + padded_item_lst_map = torch.stack(padded_item_lst, dim=0) + ret.append(padded_item_lst_map) return ret def get_dim(self, item_idx): diff --git a/utils/model_helper.py b/utils/model_helper.py index 1ef9e8e..c8edf40 100644 --- a/utils/model_helper.py +++ b/utils/model_helper.py @@ -19,11 +19,10 @@ from models.modules.text_classifier import SingleLabelTextClassifier from models.modules.conditional_random_field import 
ConditionalRandomField, allowed_transitions +from models.few_shot_learner import FewShotLearner, SchemaFewShotLearner from models.few_shot_seq_labeler import FewShotSeqLabeler, SchemaFewShotSeqLabeler from models.few_shot_text_classifier import FewShotTextClassifier, SchemaFewShotTextClassifier - from models.modules.scale_controller import build_scale_controller, ScaleControllerBase - from utils.device_helper import prepare_model @@ -67,8 +66,8 @@ def make_model(opt, config): ''' Create log file to record testing data ''' if opt.emb_log: emb_log = open(os.path.join(opt.output_dir, 'emb.log'), 'w') - if 'id2label' in config: - emb_log.write('id2label\t' + '\t'.join([str(k) + ':' + str(v) for k, v in config['id2label'].items()]) + '\n') + if 'id2label_map' in config: + emb_log.write('id2label_map\t' + '\t'.join([str(k) + ':' + str(v) for k, v in config['id2label_map'].items()])+'\n') else: emb_log = None @@ -86,38 +85,45 @@ def make_model(opt, config): else: raise TypeError('wrong component type') - if opt.emission == 'mnet': - similarity_scorer = MatchingSimilarityScorer(sim_func=sim_func, emb_log=emb_log) - emission_scorer = MNetEmissionScorer(similarity_scorer, ems_scaler, opt.div_by_tag_num) - elif opt.emission == 'proto': - similarity_scorer = PrototypeSimilarityScorer(sim_func=sim_func, emb_log=emb_log) - emission_scorer = PrototypeEmissionScorer(similarity_scorer, ems_scaler) - elif opt.emission == 'proto_with_label': - similarity_scorer = ProtoWithLabelSimilarityScorer(sim_func=sim_func, scaler=opt.ple_scale_r, emb_log=emb_log) - emission_scorer = ProtoWithLabelEmissionScorer(similarity_scorer, ems_scaler) - elif opt.emission == 'tapnet': - # set num of anchors: - # (1) if provided in config, use it (usually in load model case.) 
- # (2) *3 is used to ensure enough anchors ( > num_tags of unseen domains ) - num_anchors = config['num_anchors'] if 'num_anchors' in config else config['num_tags'] * 3 - config['num_anchors'] = num_anchors - anchor_dim = 256 if opt.context_emb == 'electra' else 768 - similarity_scorer = TapNetSimilarityScorer( - sim_func=sim_func, num_anchors=num_anchors, mlp_out_dim=opt.tap_mlp_out_dim, - random_init=opt.tap_random_init, random_init_r=opt.tap_random_init_r, - mlp=opt.tap_mlp, emb_log=emb_log, tap_proto=opt.tap_proto, tap_proto_r=opt.tap_proto_r, - anchor_dim=anchor_dim) - emission_scorer = TapNetEmissionScorer(similarity_scorer, ems_scaler) - else: - raise TypeError('wrong component type') + assert len(opt.emission) == len(opt.task), "the emission list should match with task list" + + emission_scorer_map = {} + for task, emission in zip(opt.task, opt.emission): + if emission == 'mnet': + similarity_scorer = MatchingSimilarityScorer(sim_func=sim_func, emb_log=emb_log) + emission_scorer = MNetEmissionScorer(similarity_scorer, ems_scaler, opt.div_by_tag_num) + elif emission == 'proto': + similarity_scorer = PrototypeSimilarityScorer(sim_func=sim_func, emb_log=emb_log) + emission_scorer = PrototypeEmissionScorer(similarity_scorer, ems_scaler) + elif emission == 'proto_with_label': + similarity_scorer = ProtoWithLabelSimilarityScorer(sim_func=sim_func, scaler=opt.ple_scale_r, emb_log=emb_log) + emission_scorer = ProtoWithLabelEmissionScorer(similarity_scorer, ems_scaler) + elif emission == 'tapnet': + # set num of anchors: + # (1) if provided in config, use it (usually in load model case.) 
+ # (2) *3 is used to ensure enough anchors ( > num_tags of unseen domains ) + print('config: {}'.format(config)) + num_anchors = config['num_anchors'] if 'num_anchors' in config else config['num_tags'] * 3 + config['num_anchors'] = num_anchors + anchor_dim = 256 if opt.context_emb == 'electra' else 768 + similarity_scorer = TapNetSimilarityScorer( + sim_func=sim_func, num_anchors=num_anchors, mlp_out_dim=opt.tap_mlp_out_dim, + random_init=opt.tap_random_init, random_init_r=opt.tap_random_init_r, + mlp=opt.tap_mlp, emb_log=emb_log, tap_proto=opt.tap_proto, tap_proto_r=opt.tap_proto_r, + anchor_dim=anchor_dim) + emission_scorer = TapNetEmissionScorer(similarity_scorer, ems_scaler) + else: + raise TypeError('wrong component type') + emission_scorer_map[task] = emission_scorer ''' Build decoder ''' - if opt.task == 'sl': # for sequence labeling + model_map = {} + decoder_map = {} + transition_scorer = None + if 'sl' in opt.task: # for sequence labeling if opt.decoder == 'sms': - transition_scorer = None decoder = SequenceLabeler() elif opt.decoder == 'rule': - transition_scorer = None decoder = RuleSequenceLabeler(config['id2label']) elif opt.decoder == 'crf': # logger.info('We only support back-off trans training now!') @@ -132,7 +138,7 @@ def make_model(opt, config): elif opt.transition == 'learn_with_label': label_trans_normalizer = build_scale_controller(name=opt.label_trans_normalizer) label_trans_scaler = build_scale_controller(name=opt.label_trans_scaler, kwargs=make_scaler_args( - opt.label_trans_scaler, label_trans_normalizer, opt.label_trans_scale_r)) + opt.label_trans_scaler, label_trans_normalizer, opt.label_trans_scale_r)) transition_scorer = FewShotTransitionScorerFromLabel( num_tags=config['num_tags'], normalizer=trans_normalizer, scaler=trans_scaler, r=opt.trans_r, backoff_init=opt.backoff_init, label_scaler=label_trans_scaler) @@ -149,35 +155,41 @@ def make_model(opt, config): num_tags=transition_scorer.num_tags, constraints=constraints) # accurate 
tags else: raise TypeError('wrong component type') - elif opt.task == 'sc': # for single-label text classification task + + # decoder_map['sl'] = decoder + seq_labeler = SchemaFewShotSeqLabeler if opt.use_schema else FewShotSeqLabeler + model_map['sl'] = seq_labeler(opt=opt, + context_embedder=context_embedder, + emission_scorer=emission_scorer_map['sl'], + decoder=decoder, + transition_scorer=transition_scorer, + config=config, + emb_log=emb_log) + + if 'sc' in opt.task: # for single-label text classification task decoder = SingleLabelTextClassifier() - else: - raise TypeError('wrong task type') + # decoder_map['sc'] = decoder + text_classifier = SchemaFewShotTextClassifier if opt.use_schema else FewShotTextClassifier + model_map['sc'] = text_classifier(opt=opt, + context_embedder=context_embedder, + emission_scorer=emission_scorer_map['sc'], + decoder=decoder, + config=config, + emb_log=emb_log) ''' Build the whole model ''' - if opt.task == 'sl': - seq_labeler = SchemaFewShotSeqLabeler if opt.use_schema else FewShotSeqLabeler - model = seq_labeler( - opt=opt, - context_embedder=context_embedder, - emission_scorer=emission_scorer, - decoder=decoder, - transition_scorer=transition_scorer, - config=config, - emb_log=emb_log - ) - elif opt.task == 'sc': - text_classifier = SchemaFewShotTextClassifier if opt.use_schema else FewShotTextClassifier - model = text_classifier( - opt=opt, - context_embedder=context_embedder, - emission_scorer=emission_scorer, - decoder=decoder, - config=config, - emb_log=emb_log - ) - else: - raise TypeError('wrong task type') + few_shot_learner = SchemaFewShotLearner if opt.use_schema else FewShotLearner + model = few_shot_learner( + opt=opt, + context_embedder=context_embedder, + # emission_scorer_map=emission_scorer_map, + # decoder_map=decoder_map, + # transition_scorer=transition_scorer, + model_map=model_map, + config=config, + emb_log=emb_log + ) + return model diff --git a/utils/opt.py b/utils/opt.py index 816e005..96fed9b 100644 ---
a/utils/opt.py +++ b/utils/opt.py @@ -39,8 +39,10 @@ def basic_args(parser): help="path to embedding cache dir. if use pytorch nlp, use this path to avoid downloading") group = parser.add_argument_group('Function') - parser.add_argument("--task", default='sc', choices=['sl', 'sc'], - help="Task: sl:sequence labeling, sc:single label sent classify") + # parser.add_argument("--task", default='sc', choices=['sl', 'sc'], + # help="Task: sl:sequence labeling, sc:single label sent classify") + parser.add_argument("--task", nargs='+', type=str, + help="Task: sl:sequence labeling, sc:single label sent classify. We can specify multiple tasks") group.add_argument('--allow_override', default=False, action='store_true', help='allow override experiment file') group.add_argument('--load_feature', default=False, action='store_true', help='load feature from file') group.add_argument('--save_feature', default=False, action='store_true', help='save feature to file') @@ -154,9 +156,9 @@ def model_args(parser): help="decode method") # ===== emission layer setting ========= - group.add_argument("--emission", default='mnet', type=str, - choices=['mnet', 'rank', 'proto', 'proto_with_label', 'tapnet'], - help="Method for calculate emission score") + group.add_argument("--emission", nargs="+", type=str, + # choices=['mnet', 'rank', 'proto', 'proto_with_label', 'tapnet'], + help="Method for calculate emission score, match with task list") group.add_argument("-e_nm", "--emission_normalizer", type=str, default='', choices=['softmax', 'norm', 'none'], help="normalize emission into 1-0") group.add_argument("-e_scl", "--emission_scaler", type=str, default=None, diff --git a/utils/preprocessor.py b/utils/preprocessor.py index 25bc4ff..1547a65 100644 --- a/utils/preprocessor.py +++ b/utils/preprocessor.py @@ -15,13 +15,13 @@ "FeatureItem", [ "tokens", # tokens corresponding to input token ids, eg: word_piece tokens with [CLS], [SEP] - "labels", # labels for all input position, eg; label for 
word_piece tokens + "label_map", # labels for all input position, eg; label for word_piece tokens "data_item", "token_ids", "segment_ids", "nwp_index", "input_mask", - "output_mask" + "output_mask_map" ] ) @@ -32,7 +32,7 @@ "segment_ids", # bert [SEP] ids "nwp_index", # non-word-piece word index to extract non-word-piece tokens' reps (only useful for bert). "input_mask", # [1] * len(sent), 1 for valid (tokens, cls, sep, word piece), 0 is padding in batch construction - "output_mask", # [1] * len(sent), 1 for valid output, 0 for padding, eg: 1 for original tokens in sl task + "output_mask_map", # [1] * len(sent), 1 for valid output, 0 for padding, eg: 1 for original tokens in sl task ] ) @@ -49,10 +49,10 @@ def __init__( test_feature_item: FeatureItem, support_input: ModelInput, support_feature_items: List[FeatureItem], - test_target: torch.Tensor, # 1) CRF, shape: (1, test_len) 2)SMS, shape: (support_size, t_len, s_len) - support_target: torch.Tensor, # 1) shape: (support_len, label_size) - label_input=None, - label_items=None, + test_target_map: Dict[str, List[torch.Tensor]], + support_target_map: Dict[str, List[torch.Tensor]], + label_input_map=None, + label_item_map=None, ): self.gid = gid self.test_gid = test_gid @@ -61,15 +61,13 @@ def __init__( self.test_input = test_input # shape: (1, test_len) self.support_input = support_input # shape: (support_size, support_len) # output: - # 1)CRF, shape: (1, test_len) - # 2)SMS, shape: (support_size, test_len, support_len) - self.test_target = test_target - self.support_target = support_target + self.test_target_map = test_target_map + self.support_target_map = support_target_map ''' raw feature ''' self.test_feature_item = test_feature_item self.support_feature_items = support_feature_items - self.label_input = label_input - self.label_items = label_items + self.label_input_map = label_input_map + self.label_item_map = label_item_map def __str__(self): return self.__repr__() @@ -82,20 +80,24 @@ class InputBuilderBase: 
def __init__(self, tokenizer): self.tokenizer = tokenizer - def __call__(self, example, max_support_size, label2id - ) -> (FeatureItem, ModelInput, List[FeatureItem], ModelInput): + def __call__(self, example, max_support_size, label2id_map + ) -> (FeatureItem, ModelInput, List[FeatureItem], ModelInput): raise NotImplementedError def data_item2feature_item(self, data_item: DataItem, seg_id: int) -> FeatureItem: raise NotImplementedError def get_test_model_input(self, feature_item: FeatureItem) -> ModelInput: + if isinstance(feature_item.output_mask_map, dict): + output_mask_map = {key: torch.LongTensor(map_item) for key, map_item in feature_item.output_mask_map.items()} + else: + output_mask_map = torch.LongTensor(feature_item.output_mask_map) ret = ModelInput( token_ids=torch.LongTensor(feature_item.token_ids), segment_ids=torch.LongTensor(feature_item.segment_ids), nwp_index=torch.LongTensor(feature_item.nwp_index), input_mask=torch.LongTensor(feature_item.input_mask), - output_mask=torch.LongTensor(feature_item.output_mask) + output_mask_map=output_mask_map ) return ret @@ -105,13 +107,20 @@ def get_support_model_input(self, feature_items: List[FeatureItem], max_support_ segment_ids = self.pad_support_set([f.segment_ids for f in feature_items], 0, max_support_size) nwp_index = self.pad_support_set([f.nwp_index for f in feature_items], [0], max_support_size) input_mask = self.pad_support_set([f.input_mask for f in feature_items], 0, max_support_size) - output_mask = self.pad_support_set([f.output_mask for f in feature_items], 0, max_support_size) + if isinstance(feature_items[0].output_mask_map, dict): + output_mask_map = {key: self.pad_support_set([f.output_mask_map[key] for f in feature_items], + 0, max_support_size) + for key in feature_items[0].output_mask_map.keys()} + output_mask_map = {key: torch.LongTensor(output_mask) for key, output_mask in output_mask_map.items()} + else: + output_mask_map = self.pad_support_set([f.output_mask_map for f in 
feature_items], 0, max_support_size) + output_mask_map = torch.LongTensor(output_mask_map) ret = ModelInput( token_ids=torch.LongTensor(token_ids), segment_ids=torch.LongTensor(segment_ids), nwp_index=torch.LongTensor(nwp_index), input_mask=torch.LongTensor(input_mask), - output_mask=torch.LongTensor(output_mask) + output_mask_map=output_mask_map ) return ret @@ -156,7 +165,7 @@ def __init__(self, tokenizer, opt): self.support_seg_id = 0 if opt.context_emb == 'sep_bert' else 1 # 1 to cat support and query to get reps self.seq_ins = {} - def __call__(self, example, max_support_size, label2id) -> (FeatureItem, ModelInput, List[FeatureItem], ModelInput): + def __call__(self, example, max_support_size, label2id_map) -> (FeatureItem, ModelInput, List[FeatureItem], ModelInput): test_feature_item, test_input = self.prepare_test(example) support_feature_items, support_input = self.prepare_support(example, max_support_size) return test_feature_item, test_input, support_feature_items, support_input @@ -172,33 +181,38 @@ def prepare_support(self, example, max_support_size): support_input = self.get_support_model_input(support_feature_items, max_support_size) return support_feature_items, support_input - def data_item2feature_item(self, data_item: DataItem, seg_id: int) -> FeatureItem: + def data_item2feature_item(self, data_item: DataItem, seg_id: int, share_feature: bool = True) -> FeatureItem: """ get feature_item for bert, steps: 1. do digitalizing 2. make mask """ wp_mark, wp_text = self.tokenizing(data_item) - if self.opt.task == 'sl': # use word-level labels [opt.label_wp is supported by model now.] 
- labels = self.get_wp_label(data_item.seq_out, wp_text, wp_mark) if self.opt.label_wp else data_item.seq_out - else: # use sentence level labels - labels = data_item.label - if 'None' not in labels: - # transfer label to index type such as `label_1` - if self.opt.index_label: - labels = [self.opt.label2index_type[label] for label in labels] - if self.opt.unused_label: - labels = [self.opt.label2unused_type[label] for label in labels] + if share_feature: + label_map = {} + output_mask_map = {} + if 'sl' in self.opt.task: # use word-level labels [opt.label_wp is supported by model now.] + labels = self.get_wp_label(data_item.seq_out, wp_text, wp_mark) if self.opt.label_wp else data_item.seq_out + label_map['sl'] = labels + output_mask_map['sl'] = [1] * len(labels) # For sl: it is original tokens; + + if 'sc' in self.opt.task: # use sentence level labels + labels = data_item.label + label_map['sc'] = labels + output_mask_map['sc'] = [1] * len(labels) # For sc: it is labels + else: + label_map = data_item.label + output_mask_map = [1] * len(label_map) + tokens = ['[CLS]'] + wp_text + ['[SEP]'] if seg_id == 0 else wp_text + ['[SEP]'] token_ids, segment_ids = self.digitizing_input(tokens=tokens, seg_id=seg_id) nwp_index = self.get_nwp_index(wp_mark) input_mask = [1] * len(token_ids) - output_mask = [1] * len(labels) # For sl: it is original tokens; For sc: it is labels ret = FeatureItem( tokens=tokens, - labels=labels, + label_map=label_map, data_item=data_item, token_ids=token_ids, segment_ids=segment_ids, nwp_index=nwp_index, input_mask=input_mask, - output_mask=output_mask, + output_mask_map=output_mask_map, ) return ret @@ -234,45 +248,54 @@ class SchemaInputBuilder(BertInputBuilder): def __init__(self, tokenizer, opt): super(SchemaInputBuilder, self).__init__(tokenizer, opt) - def __call__(self, example, max_support_size, label2id) -> (FeatureItem, ModelInput, List[FeatureItem], ModelInput): + def __call__(self, example, max_support_size, label2id_map) -> 
(FeatureItem, ModelInput, List[FeatureItem], ModelInput): test_feature_item, test_input = self.prepare_test(example) support_feature_items, support_input = self.prepare_support(example, max_support_size) if self.opt.label_reps in ['cat']: # represent labels by concat all all labels - label_input, label_items = self.prepare_label_feature(label2id) + label_input, label_items = self.prepare_label_feature(label2id_map) elif self.opt.label_reps in ['sep', 'sep_sum']: # represent each label independently - label_input, label_items = self.prepare_sep_label_feature(label2id) + label_input, label_items = self.prepare_sep_label_feature(label2id_map) + else: + raise TypeError('the label_reps should be one of cat & sep & sep_sum') return test_feature_item, test_input, support_feature_items, support_input, label_items, label_input, - def prepare_label_feature(self, label2id: dict): + def prepare_label_feature(self, label2id_map: Dict[str, Dict[str, int]]): """ prepare digital input for label feature in concatenate style """ text, wp_text, label, wp_label, wp_mark = [], [], [], [], [] - sorted_labels = sorted(label2id.items(), key=lambda x: x[1]) - for label_name, label_id in sorted_labels: - if label_name == '[PAD]': - continue - tmp_text = self.convert_label_name(label_name) - tmp_wp_text = self.tokenizer.tokenize(' '.join(tmp_text)) - text.extend(tmp_text) - wp_text.extend(tmp_wp_text) - label.extend(['O'] * len(tmp_text)) - wp_label.extend(['O'] * len(tmp_wp_text)) - wp_mark.extend([0] + [1] * (len(tmp_wp_text) - 1)) - label_item = self.data_item2feature_item(DataItem(text, label, wp_text, wp_label, wp_mark), 0) - label_input = self.get_test_model_input(label_item) - return label_input, label_item - - def prepare_sep_label_feature(self, label2id): + sorted_label_map = {task: sorted(label2id.items(), key=lambda x: x[1]) + for task, label2id in label2id_map.items()} + label_item_map, label_input_map = {}, {} + for task, sorted_labels in sorted_label_map.items(): + for
label_name, label_id in sorted_labels: + if label_name == '[PAD]': + continue + tmp_text = self.convert_label_name(label_name) + tmp_wp_text = self.tokenizer.tokenize(' '.join(tmp_text)) + text.extend(tmp_text) + wp_text.extend(tmp_wp_text) + label.extend(['O'] * len(tmp_text)) + wp_label.extend(['O'] * len(tmp_wp_text)) + wp_mark.extend([0] + [1] * (len(tmp_wp_text) - 1)) + label_item_map[task] = self.data_item2feature_item(DataItem(text, label, wp_text, wp_label, wp_mark), 0, + share_feature=False) + label_input_map[task] = self.get_test_model_input(label_item_map[task]) + return label_input_map, label_item_map + + def prepare_sep_label_feature(self, label2id_map): """ prepare digital input for label feature separately """ - label_items = [] - for label_name in label2id: - if label_name == '[PAD]': - continue - seq_in = self.convert_label_name(label_name) - seq_out = ['None'] * len(seq_in) - label = ['None'] - label_items.append(self.data_item2feature_item(DataItem(seq_in, seq_out, label), 0)) - label_input = self.get_support_model_input(label_items, len(label2id) - 1) # no pad, so - 1 - return label_input, label_items + label_item_map = {task: [] for task in label2id_map.keys()} + for task, label2id in label2id_map.items(): + for label_name in label2id: + if label_name == '[PAD]': + continue + seq_in = self.convert_label_name(label_name) + seq_out = ['None'] * len(seq_in) + label = ['None'] + label_item_map[task].append(self.data_item2feature_item(DataItem(seq_in, seq_out, label), + 0, share_feature=False)) + label_input_map = {task: self.get_support_model_input(label_items, len(label2id_map[task]) - 1) + for task, label_items in label_item_map.items()} # no pad, so - 1 + return label_input_map, label_item_map def convert_label_name(self, name): text = [] @@ -326,8 +349,9 @@ def convert_label_name(self, name): class NormalInputBuilder(InputBuilderBase): - def __init__(self, tokenizer): + def __init__(self, tokenizer, opt): super(NormalInputBuilder, 
self).__init__(tokenizer) + self.opt = opt def __call__(self, example, max_support_size, label2id) -> (FeatureItem, ModelInput, List[FeatureItem], ModelInput): test_feature_item = self.data_item2feature_item(data_item=example.test_data_item, seg_id=0) @@ -339,20 +363,31 @@ def __call__(self, example, max_support_size, label2id) -> (FeatureItem, ModelIn def data_item2feature_item(self, data_item: DataItem, seg_id: int) -> FeatureItem: """ get feature_item for bert, steps: 1. do padding 2. do digitalizing 3. make mask """ - tokens, labels = data_item.seq_in, data_item.seq_out + tokens = data_item.seq_in + label_map = {} + output_mask_map = {} + if 'sl' in self.opt.task: + labels = data_item.seq_out + label_map['sl'] = labels + output_mask_map['sl'] = [1] * len(labels) + + if 'sc' in self.opt.task: + labels = data_item.label + label_map['sc'] = labels + output_mask_map['sc'] = [1] * len(labels) + token_ids, segment_ids = self.digitizing_input(tokens=tokens, seg_id=seg_id) nwp_index = [[i] for i in range(len(token_ids))] input_mask = [1] * len(token_ids) - output_mask = [1] * len(data_item.seq_in) ret = FeatureItem( tokens=tokens, - labels=labels, + label_map=label_map, data_item=data_item, token_ids=token_ids, segment_ids=segment_ids, nwp_index=nwp_index, input_mask=input_mask, - output_mask=output_mask, + output_mask_map=output_mask_map, ) return ret @@ -396,18 +431,23 @@ def __init__(self): super(FewShotOutputBuilder, self).__init__() def __call__(self, test_feature_item: FeatureItem, support_feature_items: FeatureItem, - label2id: dict, max_support_size: int): - test_target = self.item2label_ids(test_feature_item, label2id) - # to estimate emission, the support target is one-hot here - support_target = [self.item2label_onehot(f_item, label2id) for f_item in support_feature_items] - support_target = self.pad_support_set(support_target, self.label2onehot('[PAD]', label2id), max_support_size) - return torch.LongTensor(test_target), torch.LongTensor(support_target) - - 
def item2label_ids(self, f_item: FeatureItem, label2id: dict): - return [label2id[lb] for lb in f_item.labels] - - def item2label_onehot(self, f_item: FeatureItem, label2id: dict): - return [self.label2onehot(lb, label2id) for lb in f_item.labels] + label2id_map: Dict[str, Dict[str, int]], max_support_size: int): + test_target_map, support_target_map = {}, {} + for task, label2id in label2id_map.items(): + test_target_map[task] = self.item2label_ids(test_feature_item, label2id, task) + # to estimate emission, the support target is one-hot here + support_target_map[task] = [self.item2label_onehot(f_item, label2id, task) + for f_item in support_feature_items] + support_target_map[task] = self.pad_support_set(support_target_map[task], + self.label2onehot('[PAD]', label2id), max_support_size) + return {task: torch.LongTensor(item_test_target) for task, item_test_target in test_target_map.items()}, \ + {task: torch.LongTensor(item_support_target) for task, item_support_target in support_target_map.items()} + + def item2label_ids(self, f_item: FeatureItem, label2id: dict, task: str): + return [label2id[lb] for lb in f_item.label_map[task]] + + def item2label_onehot(self, f_item: FeatureItem, label2id: dict, task): + return [self.label2onehot(lb, label2id) for lb in f_item.label_map[task]] def label2onehot(self, label: str, label2id: dict): onehot = [0 for _ in range(len(label2id))] @@ -440,12 +480,12 @@ def construct_feature( self, examples: List[FewShotExample], max_support_size: int, - label2id: dict, - id2label: dict, + label2id_map: Dict[str, Dict[str, int]], + id2label_map: Dict[str, Dict[int, str]] ) -> List[FewShotFeature]: all_features = [] for example in examples: - feature = self.example2feature(example, max_support_size, label2id, id2label) + feature = self.example2feature(example, max_support_size, label2id_map, id2label_map) all_features.append(feature) return all_features @@ -453,13 +493,13 @@ def example2feature( self, example: FewShotExample, 
max_support_size: int, - label2id: dict, - id2label: dict + label2id_map: Dict[str, Dict[str, int]], + id2label_map: Dict[str, Dict[int, str]] ) -> FewShotFeature: test_feature_item, test_input, support_feature_items, support_input = self.input_builder( - example, max_support_size, label2id) - test_target, support_target = self.output_builder( - test_feature_item, support_feature_items, label2id, max_support_size) + example, max_support_size, label2id_map) + test_target_map, support_target_map = self.output_builder( + test_feature_item, support_feature_items, label2id_map, max_support_size) ret = FewShotFeature( gid=example.gid, test_gid=example.test_id, @@ -468,8 +508,8 @@ def example2feature( test_feature_item=test_feature_item, support_input=support_input, support_feature_items=support_feature_items, - test_target=test_target, - support_target=support_target, + test_target_map=test_target_map, + support_target_map=support_target_map, ) return ret @@ -489,9 +529,9 @@ def example2feature( label2id: dict, id2label: dict ) -> FewShotFeature: - test_feature_item, test_input, support_feature_items, support_input, label_items, label_input = \ + test_feature_item, test_input, support_feature_items, support_input, label_item_map, label_input_map = \ self.input_builder(example, max_support_size, label2id) - test_target, support_target = self.output_builder( + test_target_map, support_target_map = self.output_builder( test_feature_item, support_feature_items, label2id, max_support_size) ret = FewShotFeature( gid=example.gid, @@ -501,10 +541,10 @@ def example2feature( test_feature_item=test_feature_item, support_input=support_input, support_feature_items=support_feature_items, - test_target=test_target, - support_target=support_target, - label_input=label_input, - label_items=label_items, + test_target_map=test_target_map, + support_target_map=support_target_map, + label_input_map=label_input_map, + label_item_map=label_item_map, ) return ret @@ -528,59 +568,35 @@ def 
purify(l): return set([item.replace('B-', '').replace('I-', '') for item in l]) ''' collect all label from: all test set & all support set ''' - all_labels = [] - label2id = {} + all_label_map = {key: [] for key in opt.task} + label2id_map = {key: {'[PAD]': 0} for key in opt.task} # '[PAD]' in first position and id is 0 for example in examples: - if opt.task == 'sl': - all_labels.append(example.test_data_item.seq_out) - all_labels.extend([data_item.seq_out for data_item in example.support_data_items]) - else: - all_labels.append(example.test_data_item.label) - all_labels.extend([data_item.label for data_item in example.support_data_items]) + if 'sl' in opt.task: + all_label_map['sl'].append(example.test_data_item.seq_out) + all_label_map['sl'].extend([data_item.seq_out for data_item in example.support_data_items]) + + if 'sc' in opt.task: + all_label_map['sc'].append(example.test_data_item.label) + all_label_map['sc'].extend([data_item.label for data_item in example.support_data_items]) ''' collect label word set ''' - label_set = sorted(list(purify(set(flatten(all_labels))))) # sort to make embedding id fixed - # transfer label to index type such as `label_1` - if opt.index_label: - if 'label2index_type' not in opt: - opt.label2index_type = {} - for idx, label in enumerate(label_set): - opt.label2index_type[label] = 'label_' + str(idx) - else: - max_label_idx = max([int(value.replace('label_', '')) for value in opt.label2index_type.values()]) - for label in label_set: - if label not in opt.label2index_type: - max_label_idx += 1 - opt.label2index_type[label] = 'label_' + str(max_label_idx) - label_set = [opt.label2index_type[label] for label in label_set] - elif opt.unused_label: - if 'label2unused_type' not in opt: - opt.label2unused_type = {} - for idx, label in enumerate(label_set): - opt.label2unused_type[label] = '[unused' + str(idx) + ']' - else: - max_label_idx = max([int(value.replace('[unused', '').replace(']', '')) for value in 
opt.label2unused_type.values()]) - for label in label_set: - if label not in opt.label2unused_type: - max_label_idx += 1 - opt.label2unused_type[label] = '[unused' + str(max_label_idx) + ']' - label_set = [opt.label2unused_type[label] for label in label_set] - else: - pass + # sort to make embedding id fixed + label_set_map = {key: sorted(list(purify(set(flatten(item_label))))) for key, item_label in all_label_map.items()} ''' build dict ''' - label2id['[PAD]'] = len(label2id) # '[PAD]' in first position and id is 0 - if opt.task == 'sl': - label2id['O'] = len(label2id) - for label in label_set: + if 'sl' in opt.task: + label2id_map['sl']['O'] = len(label2id_map['sl']) + for label in label_set_map['sl']: if label == 'O': continue - label2id['B-' + label] = len(label2id) - label2id['I-' + label] = len(label2id) - else: # sc - for label in label_set: - label2id[label] = len(label2id) + label2id_map['sl']['B-' + label] = len(label2id_map['sl']) + label2id_map['sl']['I-' + label] = len(label2id_map['sl']) + + if 'sc' in opt.task: + for label in label_set_map['sc']: + label2id_map['sc'][label] = len(label2id_map['sc']) + ''' reverse the label2id ''' - id2label = dict([(idx, label) for label, idx in label2id.items()]) - return label2id, id2label + id2label_map = {key: dict([(idx, label) for label, idx in label2id_map[key].items()]) for key in opt.task} + return label2id_map, id2label_map def make_word_dict(all_files: List[str]) -> (Dict[str, int], Dict[int, str]): @@ -607,12 +623,12 @@ def make_mask(token_ids: torch.Tensor, label_ids: torch.Tensor) -> (torch.Tensor return input_mask, output_mask -def save_feature(path, features, label2id, id2label): +def save_feature(path, features, label2id_map, id2label_map): with open(path, 'wb') as writer: saved_features = { 'features': features, - 'label2id': label2id, - 'id2label': id2label, + 'label2id': label2id_map, + 'id2label': id2label_map, } pickle.dump(saved_features, writer) diff --git a/utils/tester.py b/utils/tester.py 
index 5ce5e4b..18df5e2 100644 --- a/utils/tester.py +++ b/utils/tester.py @@ -16,8 +16,7 @@ from utils.preprocessor import FewShotFeature, ModelInput from utils.device_helper import prepare_model from utils.model_helper import make_model, load_model -from models.modules.transition_scorer import FewShotTransitionScorer -from models.few_shot_seq_labeler import FewShotSeqLabeler +from models.few_shot_learner import FewShotLearner logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', @@ -48,25 +47,32 @@ def __init__(self, opt, device, n_gpu): self.device = device self.n_gpu = n_gpu - def do_test(self, model: torch.nn.Module, test_features: List[FewShotFeature], id2label: dict, - log_mark: str = 'test_pred'): + def do_test(self, model: torch.nn.Module, test_features: List[FewShotFeature], id2label_map: dict, + log_mark: str = 'test_pred')->Dict[str, float]: logger.info("***** Running eval *****") # print("***** Running eval *****") logger.info(" Num features = %d", len(test_features)) logger.info(" Batch size = %d", self.batch_size) - all_results = [] + task_lst = id2label_map.keys() + all_results = [] # {task: [] for task in task_lst} model.eval() data_loader = self.get_data_loader(test_features) for batch in tqdm(data_loader, desc="Eval-Batch Progress"): - batch = tuple(t.to(self.device) for t in batch) # multi-gpu does scattering it-self + if self.n_gpu == 1: + # multi-gpu does scattering it-self + batch = tuple(t.to(self.device) if not isinstance(t, dict) + else {task: item.to(self.device) for task, item in t.items()} for t in batch) with torch.no_grad(): predictions = self.do_forward(batch, model) for i, feature_gid in enumerate(batch[0]): # iter over feature global id - prediction = predictions[i] - feature = test_features[feature_gid.item()] - all_results.append(RawResult(feature=feature, prediction=prediction)) + tmp_dict = {} + for task in task_lst: + prediction = predictions[task][i] + feature = test_features[feature_gid.item()] + 
tmp_dict[task] = RawResult(feature=feature, prediction=prediction) + all_results.append(tmp_dict) if model.emb_log: model.emb_log.write('text_' + str(feature_gid.item()) + '\t' + '\t'.join(feature.test_feature_item.data_item.seq_in) + '\n') @@ -75,8 +81,8 @@ def do_test(self, model: torch.nn.Module, test_features: List[FewShotFeature], i if model.emb_log: model.emb_log.close() - scores = self.eval_predictions(all_results, id2label, log_mark) - return scores + scores_map = self.eval_predictions(all_results, id2label_map, log_mark) + return scores_map def get_data_loader(self, features): dataset = TensorDataset([self.unpack_feature(f) for f in features]) @@ -122,26 +128,35 @@ def get_data_loader(self, features): data_loader = DataLoader(dataset, sampler=sampler, batch_size=self.batch_size, collate_fn=pad_collate) return data_loader - def eval_predictions(self, all_results: List[RawResult], id2label: dict, log_mark: str) -> float: + def eval_predictions(self, all_results: List[Dict[str, RawResult]], id2label_map: Dict[str, Dict[int, str]], + log_mark: str) -> Dict[str, float]: """ Our result score is average score of all few-shot batches. 
""" all_batches = self.reform_few_shot_batch(all_results) all_scores = [] for b_id, fs_batch in all_batches: - f1 = self.eval_one_few_shot_batch(b_id, fs_batch, id2label, log_mark) - all_scores.append(f1) - return sum(all_scores) * 1.0 / len(all_scores) - - def eval_one_few_shot_batch(self, b_id, fs_batch: List[RawResult], id2label: dict, log_mark: str) -> float: - pred_file_name = '{}.{}.txt'.format(log_mark, b_id) - output_prediction_file = os.path.join(self.opt.output_dir, pred_file_name) - if self.opt.task == 'sl': - self.writing_sl_prediction(fs_batch, output_prediction_file, id2label) + f1_map = self.eval_one_few_shot_batch(b_id, fs_batch, id2label_map, log_mark) + all_scores.append(f1_map) + return {task: sum([item[task] for item in all_scores]) * 1.0 / len(all_scores) for task in id2label_map.keys()} + + def eval_one_few_shot_batch(self, b_id, fs_batch: List[Dict[str, RawResult]], id2label_map: Dict[str, Dict[int, str]], + log_mark: str) -> Dict[str, float]: + f1_map = {} + if 'sl' in self.opt.task: + pred_file_name = 'sl.{}.{}.txt'.format(log_mark, b_id) + output_prediction_file = os.path.join(self.opt.output_dir, pred_file_name) + sl_fs_batch = [item['sl'] for item in fs_batch] + self.writing_sl_prediction(sl_fs_batch, output_prediction_file, id2label_map['sl']) precision, recall, f1 = self.eval_with_script(output_prediction_file) - elif self.opt.task == 'sc': - precision, recall, f1 = self.writing_sc_prediction(fs_batch, output_prediction_file, id2label) - else: - raise ValueError("Wrong task.") - return f1 + f1_map['sl'] = f1 + + if 'sc' in self.opt.task: + pred_file_name = 'sc.{}.{}.txt'.format(log_mark, b_id) + output_prediction_file = os.path.join(self.opt.output_dir, pred_file_name) + sc_fs_batch = [item['sc'] for item in fs_batch] + precision, recall, f1 = self.writing_sc_prediction(sc_fs_batch, output_prediction_file, id2label_map['sc']) + f1_map['sc'] = f1 + + return f1_map def writing_sc_prediction(self, fs_batch: List[RawResult], 
output_prediction_file: str, id2label: dict): tp, fp, fn = 0, 0, 0 @@ -212,14 +227,17 @@ def eval_with_script(self, output_prediction_file): f1 = float(std_results[7].replace('%;', '').replace("\\n'", '')) return precision, recall, f1 - def reform_few_shot_batch(self, all_results: List[RawResult]) -> List[List[Tuple[int, RawResult]]]: + def reform_few_shot_batch(self, all_results: List[Dict[str, RawResult]] + ) -> List[List[Tuple[int, Dict[str, RawResult]]]]: """ Our result score is average score of all few-shot batches. So here, we classify all result according to few-shot batch id. """ all_batches = {} + task_lst = all_results[0].keys() + has_task = list(task_lst)[0] for result in all_results: - b_id = result.feature.batch_gid + b_id = result[has_task].feature.batch_gid if b_id not in all_batches: all_batches[b_id] = [result] else: @@ -242,8 +260,8 @@ def unpack_feature(self, feature: FewShotFeature) -> List[torch.Tensor]: feature.support_input.input_mask, feature.support_input.output_mask, # target - feature.test_target, - feature.support_target, + feature.test_target_map, + feature.support_target_map, # Special torch.LongTensor([len(feature.support_feature_items)]), # support num ] @@ -268,7 +286,6 @@ def do_forward(self, batch, model): ) = batch prediction = model( - # loss, prediction = model( test_token_ids, test_segment_ids, test_nwp_index, @@ -292,19 +309,19 @@ def get_value_from_order_dict(self, order_dict, key): return v return [] - def clone_model(self, model, id2label): + def clone_model(self, model, id2label_map): """ clone only part of params """ # deal with data parallel model - new_model: FewShotSeqLabeler - old_model: FewShotSeqLabeler + new_model: FewShotLearner + old_model: FewShotLearner if self.opt.local_rank != -1 or self.n_gpu > 1 and hasattr(model, 'module'): # the model is parallel class here old_model = model.module else: old_model = model - emission_dict = old_model.emission_scorer.state_dict() - old_num_tags = 
len(self.get_value_from_order_dict(emission_dict, 'label_reps')) + # emission_dict = old_model.emission_scorer.state_dict() + # old_num_tags = len(self.get_value_from_order_dict(emission_dict, 'label_reps')) - config = {'num_tags': len(id2label), 'id2label': id2label} + config = {'num_tags': len(id2label_map['sl']) if 'sl' in id2label_map else 0, 'id2label_map': id2label_map} if 'num_anchors' in old_model.config: config['num_anchors'] = old_model.config['num_anchors'] # Use previous model's random anchors. # get a new instance for different domain @@ -315,13 +332,14 @@ def clone_model(self, model, id2label): else: sub_new_model = new_model ''' copy weights and stuff ''' - if old_model.opt.task == 'sl' and old_model.transition_scorer: + if 'sl' in old_model.opt.task and old_model.model_map['sl'].transition_scorer: # copy one-by-one because target transition and decoder will be left un-assigned sub_new_model.context_embedder.load_state_dict(old_model.context_embedder.state_dict()) - sub_new_model.emission_scorer.load_state_dict(old_model.emission_scorer.state_dict()) + sub_new_model.model_map['sl'].emission_scorer.load_state_dict( + old_model.model_map['sl'].emission_scorer.state_dict()) for param_name in ['backoff_trans_mat', 'backoff_start_trans_mat', 'backoff_end_trans_mat']: - sub_new_model.transition_scorer.state_dict()[param_name].copy_( - old_model.transition_scorer.state_dict()[param_name].data) + sub_new_model.model_map['sl'].transition_scorer.state_dict()[param_name].copy_( + old_model.model_map['sl'].transition_scorer.state_dict()[param_name].data) else: sub_new_model.load_state_dict(old_model.state_dict()) @@ -345,30 +363,30 @@ def get_data_loader(self, features): def unpack_feature(self, feature: FewShotFeature) -> List[torch.Tensor]: ret = [ - torch.LongTensor([feature.gid]), + torch.LongTensor([feature.gid]), # 1 # test - feature.test_input.token_ids, - feature.test_input.segment_ids, - feature.test_input.nwp_index, - feature.test_input.input_mask, - 
feature.test_input.output_mask, + feature.test_input.token_ids, # 2 + feature.test_input.segment_ids, # 3 + feature.test_input.nwp_index, # 4 + feature.test_input.input_mask, # 5 + feature.test_input.output_mask_map, # 6 # support - feature.support_input.token_ids, - feature.support_input.segment_ids, - feature.support_input.nwp_index, - feature.support_input.input_mask, - feature.support_input.output_mask, + feature.support_input.token_ids, # 7 + feature.support_input.segment_ids, # 8 + feature.support_input.nwp_index, # 9 + feature.support_input.input_mask, # 10 + feature.support_input.output_mask_map, # 11 # target - feature.test_target, - feature.support_target, + feature.test_target_map, # 11 + feature.support_target_map, # 12 # Special - torch.LongTensor([len(feature.support_feature_items)]), # support num + torch.LongTensor([len(feature.support_feature_items)]), # 13, support num # label feature - feature.label_input.token_ids, - feature.label_input.segment_ids, - feature.label_input.nwp_index, - feature.label_input.input_mask, - feature.label_input.output_mask, + {key: label_input.token_ids for key, label_input in feature.label_input_map.items()}, # 14 + {key: label_input.segment_ids for key, label_input in feature.label_input_map.items()}, # 15 + {key: label_input.nwp_index for key, label_input in feature.label_input_map.items()}, # 16 + {key: label_input.input_mask for key, label_input in feature.label_input_map.items()}, # 17 + {key: label_input.output_mask_map for key, label_input in feature.label_input_map.items()}, # 18 ] return ret @@ -420,17 +438,17 @@ def do_forward(self, batch, model): return prediction -def eval_check_points(opt, tester, test_features, test_id2label, device): +def eval_check_points(opt, tester, test_features, test_id2label_map, device): all_cpt_file = list(filter(lambda x: '.cpt.pl' in x, os.listdir(opt.saved_model_path))) all_cpt_file = sorted(all_cpt_file, key=lambda x: int(x.replace('model.step', '').replace('.cpt.pl', ''))) 
max_score = 0 for cpt_file in all_cpt_file: cpt_model = load_model(os.path.join(opt.saved_model_path, cpt_file)) - testing_model = tester.clone_model(cpt_model, test_id2label) - if opt.mask_transition and opt.task == 'sl': + testing_model = tester.clone_model(cpt_model, test_id2label_map) + if opt.mask_transition and 'sl' in opt.task: testing_model.label_mask = opt.test_label_mask.to(device) - test_score = tester.do_test(testing_model, test_features, test_id2label, log_mark='test_pred') + test_score = tester.do_test(testing_model, test_features, test_id2label_map, log_mark='test_pred') if test_score > max_score: max_score = test_score logger.info('cpt_file:{} - test:{}'.format(cpt_file, test_score)) diff --git a/utils/trainer.py b/utils/trainer.py index e355139..3f26150 100644 --- a/utils/trainer.py +++ b/utils/trainer.py @@ -7,24 +7,20 @@ import os import copy from transformers import AdamW, get_linear_schedule_with_warmup -from pytorch_pretrained_bert.optimization import BertAdam from tqdm import tqdm, trange from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler # My Staff from utils.iter_helper import PadCollate, FewShotDataset, SimilarLengthSampler from utils.preprocessor import FewShotFeature, ModelInput -from utils.device_helper import prepare_model from utils.model_helper import make_model, load_model -from models.few_shot_seq_labeler import FewShotSeqLabeler +from models.few_shot_learner import FewShotLearner logging.basicConfig( format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO, - # stream=sys.stderr - # stream=sys.stdout ) logger = logging.getLogger(__name__) @@ -75,8 +71,8 @@ def __init__(self, opt, optimizer, scheduler, param_to_optimize, device, n_gpu, self.n_gpu = n_gpu def do_train(self, model, train_features, num_train_epochs, - dev_features=None, dev_id2label=None, - test_features=None, 
test_id2label=None, + dev_features=None, dev_id2label_map=None, + test_features=None, test_id2label_map=None, best_dev_score_now=0): """ do training and dev model selection @@ -100,7 +96,7 @@ def do_train(self, model, train_features, num_train_epochs, total_step = 0 best_dev_score_now = best_dev_score_now best_model_now = model - test_score = None + test_score_map = None min_loss = 100000000000000 loss_now = 0 no_new_best_dev_num = 0 @@ -115,7 +111,9 @@ def do_train(self, model, train_features, num_train_epochs, for epoch_id in trange(int(num_train_epochs), desc="Epoch"): for step, batch in enumerate(tqdm(data_loader, desc="Train-Batch Progress")): if self.n_gpu == 1: - batch = tuple(t.to(self.device) for t in batch) # multi-gpu does scattering it-self + # multi-gpu does scattering it-self + batch = tuple(t.to(self.device) if not isinstance(t, dict) + else {task: item.to(self.device) for task, item in t.items()} for t in batch) ''' loss ''' loss = self.do_forward(batch, model, epoch_id, step) loss = self.process_special_loss(loss) # for parallel process, split batch and so on @@ -131,11 +129,11 @@ def do_train(self, model, train_features, num_train_epochs, if self.time_to_make_check_point(total_step, data_loader): if self.tester and self.opt.eval_when_train: # this is not suit for training big model print("Start dev eval.") - dev_score, test_score, copied_best_model = self.model_selection( - model, best_dev_score_now, dev_features, dev_id2label, test_features, test_id2label) + dev_score_map, test_score_map, copied_best_model = self.model_selection( + model, best_dev_score_now, dev_features, dev_id2label_map, test_features, test_id2label_map) - if dev_score > best_dev_score_now: - best_dev_score_now = dev_score + if self.is_bigger(dev_score_map, best_dev_score_now): + best_dev_score_now = dev_score_map best_model_now = copied_best_model no_new_best_dev_num = 0 else: @@ -168,7 +166,7 @@ def do_train(self, model, train_features, num_train_epochs, break print(" --- The 
{} epoch Finish --- ".format(epoch_id)) - return best_model_now, best_dev_score_now, test_score + return best_model_now, best_dev_score_now, test_score_map def time_to_make_check_point(self, step, data_loader): interval_size = int(len(data_loader) / self.opt.cpt_per_epoch) @@ -248,20 +246,40 @@ def make_check_point_(self, model, step): time.sleep(300) self.make_check_point_(model, step) - def model_selection(self, model, best_score, dev_features, dev_id2label, test_features=None, test_id2label=None): + def is_bigger(self, first_score: Dict[str, float], second_score: Dict[str, float], + task_weight: Dict[str, float] = None)->bool: + task_lst = first_score.keys() + if isinstance(second_score, dict): + assert task_lst == second_score.keys(), "the two scores should have same keys" + elif isinstance(second_score, int) or isinstance(second_score, float): + second_score = {task: second_score for task in task_lst} + else: + raise TypeError('the second_score should be int or a dict to assign every task a int or float value') + + if task_weight: + assert sum(task_weight.values()) == 1.0, "the weight for all tasks should sum to 1" + else: + task_weight = {task: 1.0 / len(task_lst) for task in task_lst} + + first_avg_score = [score * task_weight[task] for task, score in first_score.items()] + second_avg_score = [score * task_weight[task] for task, score in second_score.items()] + + return first_avg_score > second_avg_score + + def model_selection(self, model, best_score, dev_features, dev_id2label_map, test_features=None, test_id2label_map=None): """ do model selection during training""" print("Start dev model selection.") # do dev eval at every dev_interval point and every end of epoch - dev_model = self.tester.clone_model(model, dev_id2label) # copy reusable params, for a different domain - if self.opt.mask_transition and self.opt.task == 'sl': - dev_model.label_mask = self.opt.dev_label_mask.to(self.device) + dev_model = self.tester.clone_model(model, dev_id2label_map) # 
copy reusable params, for a different domain + if self.opt.mask_transition and 'sl' in self.opt.task: + dev_model.model_map['sl'].label_mask = self.opt.dev_label_mask.to(self.device) - dev_score = self.tester.do_test(dev_model, dev_features, dev_id2label, log_mark='dev_pred') - logger.info(" dev score(F1) = {}".format(dev_score)) - print(" dev score(F1) = {}".format(dev_score)) + dev_score_map = self.tester.do_test(dev_model, dev_features, dev_id2label_map, log_mark='dev_pred') + logger.info(" dev score(F1) = {}".format(dev_score_map)) + print(" dev score(F1) = {}".format(dev_score_map)) best_model = None - test_score = None - if dev_score > best_score: + test_score_map = None + if self.is_bigger(dev_score_map, best_score): logger.info(" === Found new best!! === ") ''' store new best model ''' best_model = self.clone_model(model) # copy model to avoid writen by latter training @@ -271,16 +289,17 @@ def model_selection(self, model, best_score, dev_features, dev_id2label, test_fe ''' get current best model's test score ''' if test_features: - test_model = self.tester.clone_model(model, test_id2label) # copy reusable params for different domain - if self.opt.mask_transition and self.opt.task == 'sl': - test_model.label_mask = self.opt.test_label_mask.to(self.device) - - test_score = self.tester.do_test(test_model, test_features, test_id2label, log_mark='test_pred') - logger.info(" test score(F1) = {}".format(test_score)) - print(" test score(F1) = {}".format(test_score)) + # copy reusable params for different domain + test_model = self.tester.clone_model(model, test_id2label_map) + if self.opt.mask_transition and 'sl' in self.opt.task: + test_model.model_map['sl'].label_mask = self.opt.test_label_mask.to(self.device) + + test_score_map = self.tester.do_test(test_model, test_features, test_id2label_map, log_mark='test_pred') + logger.info(" test score(F1) = {}".format(test_score_map)) + print(" test score(F1) = {}".format(test_score_map)) # reset the model status 
model.train() - return dev_score, test_score, best_model + return dev_score_map, test_score_map, best_model def check_point_content(self, model): """ necessary staff for rebuild the model """ @@ -288,7 +307,8 @@ def check_point_content(self, model): return model.state_dict() def select_model_from_check_point( - self, train_id2label, dev_features, dev_id2label, test_features=None, test_id2label=None, rm_cpt=True): + self, train_id2label_map, dev_features, dev_id2label_map, + test_features=None, test_id2label_map=None, rm_cpt=True): all_cpt_file = list(filter(lambda x: '.cpt.pl' in x, os.listdir(self.opt.output_dir))) best_score = 0 test_score_then = 0 @@ -297,11 +317,11 @@ def select_model_from_check_point( for cpt_file in all_cpt_file: logger.info('testing check point: {}'.format(cpt_file)) model = load_model(os.path.join(self.opt.output_dir, cpt_file)) - dev_score, test_score, copied_model = self.model_selection( - model, best_score, dev_features, dev_id2label, test_features, test_id2label) - if dev_score > best_score: - best_score = dev_score - test_score_then = test_score + dev_score_map, test_score_map, copied_model = self.model_selection( + model, best_score, dev_features, dev_id2label_map, test_features, test_id2label_map) + if self.is_bigger(dev_score_map, best_score): + best_score = dev_score_map + test_score_then = test_score_map best_model = copied_model if rm_cpt: # delete all check point for cpt_file in all_cpt_file: @@ -400,8 +420,8 @@ def unpack_feature(self, feature: FewShotFeature) -> List[torch.Tensor]: feature.support_input.input_mask, feature.support_input.output_mask, # target - feature.test_target, - feature.support_target, + feature.test_target_map, + feature.support_target_map, # Special torch.LongTensor([len(feature.support_feature_items)]), # support num ] @@ -462,8 +482,8 @@ def get_value_from_order_dict(self, order_dict, key): def clone_model(self, model): # deal with data parallel model - best_model: FewShotSeqLabeler - old_model: 
FewShotSeqLabeler + best_model: FewShotLearner + old_model: FewShotLearner if self.opt.local_rank != -1 or self.n_gpu > 1: # the model is parallel class here old_model = model.module else: @@ -490,30 +510,30 @@ def get_data_loader(self, dataset, sampler): def unpack_feature(self, feature: FewShotFeature) -> List[torch.Tensor]: ret = [ - torch.LongTensor([feature.gid]), + torch.LongTensor([feature.gid]), # 0 # test - feature.test_input.token_ids, - feature.test_input.segment_ids, - feature.test_input.nwp_index, - feature.test_input.input_mask, - feature.test_input.output_mask, + feature.test_input.token_ids, # 1 + feature.test_input.segment_ids, # 2 + feature.test_input.nwp_index, # 3 + feature.test_input.input_mask, # 4 + feature.test_input.output_mask_map, # 5 # support - feature.support_input.token_ids, - feature.support_input.segment_ids, - feature.support_input.nwp_index, - feature.support_input.input_mask, - feature.support_input.output_mask, + feature.support_input.token_ids, # 6 + feature.support_input.segment_ids, # 7 + feature.support_input.nwp_index, # 8 + feature.support_input.input_mask, # 9 + feature.support_input.output_mask_map, # 10 # target - feature.test_target, - feature.support_target, + feature.test_target_map, # 11 + feature.support_target_map, # 12 # Special - torch.LongTensor([len(feature.support_feature_items)]), # support num + torch.LongTensor([len(feature.support_feature_items)]), # 13, support num # label feature - feature.label_input.token_ids, - feature.label_input.segment_ids, - feature.label_input.nwp_index, - feature.label_input.input_mask, - feature.label_input.output_mask, + {key: label_input.token_ids for key, label_input in feature.label_input_map.items()}, # 14 + {key: label_input.segment_ids for key, label_input in feature.label_input_map.items()}, # 15 + {key: label_input.nwp_index for key, label_input in feature.label_input_map.items()}, # 16 + {key: label_input.input_mask for key, label_input in 
feature.label_input_map.items()}, # 17 + {key: label_input.output_mask_map for key, label_input in feature.label_input_map.items()}, # 18 ] return ret From 64d1fb4275f9fd72338a3078817936debb8800cc Mon Sep 17 00:00:00 2001 From: laiyongkui Date: Fri, 26 Jun 2020 22:14:25 +0800 Subject: [PATCH 2/6] del some notes, complement readme --- models/few_shot_learner.py | 48 -------------------------------------- readme.md | 10 +++++--- utils/model_helper.py | 6 ----- 3 files changed, 7 insertions(+), 57 deletions(-) diff --git a/models/few_shot_learner.py b/models/few_shot_learner.py index 9ec0e87..f299534 100644 --- a/models/few_shot_learner.py +++ b/models/few_shot_learner.py @@ -13,42 +13,18 @@ class FewShotLearner(torch.nn.Module): def __init__(self, opt, context_embedder: ContextEmbedderBase, - # emission_scorer_map: Dict[str, EmissionScorerBase], - # decoder_map: Dict[str, torch.nn.Module], model_map: Dict[str, torch.nn.Module], - # transition_scorer: TransitionScorerBase = None, config: dict = None, # store necessary setting or none-torch params emb_log: str = None): super(FewShotLearner, self).__init__() self.opt = opt self.context_embedder = context_embedder - # self.emission_scorer_map = emission_scorer_map - # self.transition_scorer = transition_scorer - # self.decoder_map = decoder_map self.no_embedder_grad = opt.no_embedder_grad self.label_mask = None self.config = config self.emb_log = emb_log - # self.task_lst = decoder_map.keys() self.model_map = model_map - # for task in self.task_lst: - # if task == 'sl': - # self.model_map[task] = FewShotSeqLabeler(opt=opt, - # context_embedder=context_embedder, - # emission_scorer=emission_scorer_map[task], - # decoder=decoder_map[task], - # transition_scorer=transition_scorer, - # label_mask=self.label_mask, - # config=config, - # emb_log=emb_log) - # elif task == 'sc': - # self.model_map[task] = FewShotTextClassifier(opt=opt, - # context_embedder=context_embedder, - # emission_scorer=emission_scorer_map[task], - # 
decoder=decoder_map[task], - # config=config, - # emb_log=emb_log) def forward( self, @@ -144,36 +120,12 @@ def __init__( self, opt, context_embedder: ContextEmbedderBase, - # emission_scorer_map: Dict[str, EmissionScorerBase], - # decoder_map: Dict[str, torch.nn.Module], model_map: Dict[str, torch.nn.Module], - # transition_scorer: TransitionScorerBase = None, config: dict = None, # store necessary setting or none-torch params emb_log: str = None ): - # super(SchemaFewShotLearner, self).__init__( - # opt, context_embedder, emission_scorer_map, decoder_map, transition_scorer, config, emb_log) super(SchemaFewShotLearner, self).__init__(opt, context_embedder, model_map, config, emb_log) - - # self.task_lst = decoder_map.keys() self.model_map = model_map - # for task in self.task_lst: - # if task == 'sl': - # self.model_map[task] = SchemaFewShotSeqLabeler(opt=opt, - # context_embedder=context_embedder, - # emission_scorer=emission_scorer_map[task], - # decoder=decoder_map[task], - # transition_scorer=transition_scorer, - # label_mask=self.label_mask, - # config=config, - # emb_log=emb_log) - # elif task == 'sc': - # self.model_map[task] = SchemaFewShotTextClassifier(opt=opt, - # context_embedder=context_embedder, - # emission_scorer=emission_scorer_map[task], - # decoder=decoder_map[task], - # config=config, - # emb_log=emb_log) def forward( self, diff --git a/readme.md b/readme.md index 7dc01b7..94da181 100644 --- a/readme.md +++ b/readme.md @@ -1,6 +1,6 @@ # Meta Dialog Platform (MDP) -Meta Dialog Platform, is a tool for Few-Shot Learning. The platform now can be used for `classification` and `sequence labeling`. +Meta Dialog Platform, is a tool for Few-Shot Learning. The platform now can be used for `classification` and `sequence labeling` and `joint learning`(which only shares embedding). As [Electra](https://openreview.net/forum?id=r1xMH1BtvB) proposed, we can use `Electra small model` to speed up our exploitation because it's so small. 
And if you want to get the chinese version, you can go [Chinese-Electra](https://github.com/ymcui/Chinese-ELECTRA). @@ -115,18 +115,22 @@ We provide some bash scripts for convenience. As mentioned, we provide two scrip - sequence labeling - `run_electra_sl.sh` - `run_bert_sl.sh` +- joint learning + - `run_electra_sc+sl.sh` + - `run_bert_sc+sl.sh` #### parameters There are many parameters to control the train & test process, but there are some main parameters you should change and know. - `do_debug`: set for debug model -- `task`: specify the target task +- `task`: specify the target task which is a list split with space(should with backslash in bash script) +- `emission`: specify the emission choice which is a list match with task list, others is same with `task` - dataset name: `support_shots_lst` & `query_shot` & `episode` & `cross_id` are used for specify the dataset path - `embedder`: specify the embedder type, the choices are `bert` & `electra` & `sep_bert` & `glove`, in which `sep_bert` are not pair-wise embedding while others do. - `pretrained_model_path`: the pre-trained model path - `pretrained_vocab_path`: the vocabulary of pre-trained model, you can specify the file or its parent folder (default find the `vocab.txt` in it) - `base_data_dir`: the path of your dataset, the `base` mean that in which there are sub-folders (the sub-folders is named with `dataset name` mentioned at second item). If not, you can adjust the script. 
- + ## Information diff --git a/utils/model_helper.py b/utils/model_helper.py index c8edf40..64f0796 100644 --- a/utils/model_helper.py +++ b/utils/model_helper.py @@ -118,7 +118,6 @@ def make_model(opt, config): ''' Build decoder ''' model_map = {} - decoder_map = {} transition_scorer = None if 'sl' in opt.task: # for sequence labeling if opt.decoder == 'sms': @@ -156,7 +155,6 @@ def make_model(opt, config): else: raise TypeError('wrong component type') - # decoder_map['sl'] = decoder seq_laber = SchemaFewShotSeqLabeler if opt.use_schema else FewShotSeqLabeler model_map['sl'] = seq_laber(opt=opt, context_embedder=context_embedder, @@ -168,7 +166,6 @@ def make_model(opt, config): if 'sc' in opt.task: # for single-label text classification task decoder = SingleLabelTextClassifier() - # decoder_map['sc'] = decoder text_classifier = SchemaFewShotTextClassifier if opt.use_schema else FewShotTextClassifier model_map['sc'] = text_classifier(opt=opt, context_embedder=context_embedder, @@ -182,9 +179,6 @@ def make_model(opt, config): model = few_shot_learner( opt=opt, context_embedder=context_embedder, - # emission_scorer_map=emission_scorer_map, - # decoder_map=decoder_map, - # transition_scorer=transition_scorer, model_map=model_map, config=config, emb_log=emb_log From ec87b255cc2d214fc14c73cab26561b49401af48 Mon Sep 17 00:00:00 2001 From: laiyongkui Date: Wed, 1 Jul 2020 16:17:54 +0800 Subject: [PATCH 3/6] debug v0.2 multi-task learning, normal for single-task learning when using multi-task framwork --- main.py | 1 - models/few_shot_learner.py | 79 ++++++++--- models/few_shot_seq_labeler.py | 221 ++++++++++++++++++----------- models/few_shot_text_classifier.py | 109 +++++++++----- models/modules/text_classifier.py | 2 +- scripts/run_bert_sc+sl.sh | 27 ++-- scripts/run_bert_sc.sh | 15 +- scripts/run_electra_sc+sl.sh | 30 ++-- scripts/run_electra_sc.sh | 12 +- scripts/run_electra_sl.sh | 18 +-- utils/device_helper.py | 3 - utils/opt.py | 2 + utils/preprocessor.py | 49 
++++--- utils/tester.py | 59 ++++++-- utils/trainer.py | 32 +++-- 15 files changed, 413 insertions(+), 246 deletions(-) diff --git a/main.py b/main.py index 9fb4df5..841e023 100644 --- a/main.py +++ b/main.py @@ -48,7 +48,6 @@ def get_training_data_and_feature(opt, data_loader, preprocessor): dev_features = preprocessor.construct_feature(dev_examples, dev_max_support_size, dev_label2id_map, dev_id2label_map) logger.info(' Finish prepare train dev features ') - if opt.save_feature: save_feature(opt.train_path.replace('.json', '.saved.pk'), train_features, train_label2id_map, train_id2label_map) diff --git a/models/few_shot_learner.py b/models/few_shot_learner.py index f299534..c3963d4 100644 --- a/models/few_shot_learner.py +++ b/models/few_shot_learner.py @@ -2,10 +2,6 @@ import torch from typing import Tuple, Dict, List from models.modules.context_embedder_base import ContextEmbedderBase -from models.modules.emission_scorer_base import EmissionScorerBase -from models.modules.transition_scorer import TransitionScorerBase -from models.few_shot_seq_labeler import FewShotSeqLabeler, SchemaFewShotSeqLabeler -from models.few_shot_text_classifier import FewShotTextClassifier, SchemaFewShotTextClassifier class FewShotLearner(torch.nn.Module): @@ -20,11 +16,17 @@ def __init__(self, self.opt = opt self.context_embedder = context_embedder self.no_embedder_grad = opt.no_embedder_grad - self.label_mask = None self.config = config self.emb_log = emb_log - self.model_map = model_map + self.label_mask = None + + # self.model_map = model_map + if 'sl' in model_map: + self.seq_labeler_model = model_map['sl'] + + if 'sc' in model_map: + self.classifier_model = model_map['sc'] def forward( self, @@ -70,18 +72,34 @@ def forward( if self.training: loss = 0. 
for task in self.opt.task: - loss += self.model_map[task](reps_map[task]['test'], test_output_mask_map[task], - reps_map[task]['support'], support_output_mask_map[task], - test_target_map[task], support_target_map[task], support_num, - self.training) + if self.opt.do_debug: + print('task: {} - test_target: {} - {}'.format(task, test_target_map[task].size(), test_target_map[task])) + if task == 'sl': + loss += self.seq_labeler_model(reps_map[task]['test'], test_output_mask_map[task], + reps_map[task]['support'], support_output_mask_map[task], + test_target_map[task], support_target_map[task], self.label_mask) + elif task == 'sc': + loss += self.classifier_model(reps_map[task]['test'], test_output_mask_map[task], + reps_map[task]['support'], support_output_mask_map[task], + test_target_map[task], support_target_map[task], self.label_mask) + else: + raise ValueError return loss else: prediction_map = {} for task in self.opt.task: - prediction = self.model_map[task](reps_map[task]['test'], test_output_mask_map[task], - reps_map[task]['support'], support_output_mask_map[task], - test_target_map[task], support_target_map[task], support_num, - self.training) + if task == 'sl': + prediction = self.seq_labeler_model.decode(reps_map[task]['test'], test_output_mask_map[task], + reps_map[task]['support'], support_output_mask_map[task], + test_target_map[task], support_target_map[task], + self.label_mask) + elif task == 'sc': + prediction = self.classifier_model.decode(reps_map[task]['test'], test_output_mask_map[task], + reps_map[task]['support'], support_output_mask_map[task], + test_target_map[task], support_target_map[task], + self.label_mask) + else: + raise ValueError prediction_map[task] = prediction return prediction_map @@ -125,7 +143,6 @@ def __init__( emb_log: str = None ): super(SchemaFewShotLearner, self).__init__(opt, context_embedder, model_map, config, emb_log) - self.model_map = model_map def forward( self, @@ -194,18 +211,34 @@ def forward( if self.training: 
loss = 0. for task in self.opt.task: - loss += self.model_map[task](reps_map[task]['test'], test_output_mask_map[task], - reps_map[task]['support'], support_output_mask_map[task], - test_target_map[task], support_target_map[task], - support_num, label_reps_map[task], self.training) + if task == 'sl': + loss += self.seq_labeler_model(reps_map[task]['test'], test_output_mask_map[task], + reps_map[task]['support'], support_output_mask_map[task], + test_target_map[task], support_target_map[task], + label_reps_map[task], self.label_mask) + elif task == 'sc': + loss += self.classifier_model(reps_map[task]['test'], test_output_mask_map[task], + reps_map[task]['support'], support_output_mask_map[task], + test_target_map[task], support_target_map[task], + label_reps_map[task], self.label_mask) + else: + raise ValueError return loss else: prediction_map = {} for task in self.opt.task: - prediction = self.model_map[task](reps_map[task]['test'], test_output_mask_map[task], - reps_map[task]['support'], support_output_mask_map[task], - test_target_map[task], support_target_map[task], - support_num, label_reps_map[task], self.training) + if task == 'sl': + prediction = self.seq_labeler_model.decode(reps_map[task]['test'], test_output_mask_map[task], + reps_map[task]['support'], support_output_mask_map[task], + test_target_map[task], support_target_map[task], + label_reps_map[task], self.label_mask) + elif task == 'sc': + prediction = self.classifier_model.decode(reps_map[task]['test'], test_output_mask_map[task], + reps_map[task]['support'], support_output_mask_map[task], + test_target_map[task], support_target_map[task], + label_reps_map[task], self.label_mask) + else: + raise ValueError prediction_map[task] = prediction return prediction_map diff --git a/models/few_shot_seq_labeler.py b/models/few_shot_seq_labeler.py index 9a6b84b..9824321 100644 --- a/models/few_shot_seq_labeler.py +++ b/models/few_shot_seq_labeler.py @@ -15,7 +15,6 @@ def __init__(self, emission_scorer: 
EmissionScorerBase, decoder: torch.nn.Module, transition_scorer: TransitionScorerBase = None, - label_mask: torch.Tensor = None, config: dict = None, # store necessary setting or none-torch params emb_log: str = None): super(FewShotSeqLabeler, self).__init__() @@ -24,8 +23,6 @@ def __init__(self, self.emission_scorer = emission_scorer self.transition_scorer = transition_scorer self.decoder = decoder - self.no_embedder_grad = opt.no_embedder_grad - self.label_mask = label_mask self.config = config self.emb_log = emb_log @@ -37,8 +34,7 @@ def forward( support_output_mask: torch.Tensor, test_target: torch.Tensor, support_target: torch.Tensor, - support_num: torch.Tensor, - is_training: bool = True, + label_mask: torch.Tensor = None, ): """ :param test_reps: (batch_size, test_len, emb_dim) @@ -47,60 +43,87 @@ def forward( :param support_output_mask: (batch_size, support_size, support_len) :param test_target: index targets (batch_size, test_len) :param support_target: one-hot targets (batch_size, support_size, support_len, num_tags) - :param support_num: (batch_size, 1) - :param is_training: the training mode + :param label_mask: the output label mask :return: """ # calculate emission: shape(batch_size, test_len, no_pad_num_tag) emission = self.emission_scorer(test_reps, support_reps, test_output_mask, support_output_mask, support_target) - logits = emission # as we remove pad label (id = 0), so all label id sub 1. 
And relu is used to avoid -1 index test_target = torch.nn.functional.relu(test_target - 1) - loss, prediction = torch.FloatTensor(0).to(test_target.device), None if self.transition_scorer: transitions, start_transitions, end_transitions = self.transition_scorer(test_reps, support_target) - if self.label_mask is not None: - transitions = self.mask_transition(transitions, self.label_mask) + if label_mask is not None: + transitions = self.mask_transition(transitions, label_mask) self.decoder: ConditionalRandomField - if is_training: - # the CRF staff - llh = self.decoder.forward( - inputs=logits, - transitions=transitions, - start_transitions=start_transitions, - end_transitions=end_transitions, - tags=test_target, - mask=test_output_mask) - loss = -1 * llh - else: - best_paths = self.decoder.viterbi_tags(logits=logits, - transitions_without_constrain=transitions, - start_transitions=start_transitions, - end_transitions=end_transitions, - mask=test_output_mask) - # split path and score - prediction, path_score = zip(*best_paths) - # we block pad label(id=0) before by - 1, here, we add 1 back - prediction = self.add_back_pad_label(prediction) + # the CRF staff + llh = self.decoder.forward( + inputs=logits, + transitions=transitions, + start_transitions=start_transitions, + end_transitions=end_transitions, + tags=test_target, + mask=test_output_mask) + loss = -1 * llh + else: self.decoder: SequenceLabeler - if is_training: - loss = self.decoder.forward(logits=logits, - tags=test_target, - mask=test_output_mask) - else: - prediction = self.decoder.decode(logits=logits, masks=test_output_mask) - # we block pad label(id=0) before by - 1, here, we add 1 back - prediction = self.add_back_pad_label(prediction) - if is_training: - return loss + loss = self.decoder.forward(logits=logits, + tags=test_target, + mask=test_output_mask) + + return loss + + def decode( + self, + test_reps: torch.Tensor, + test_output_mask: torch.Tensor, + support_reps: torch.Tensor, + 
support_output_mask: torch.Tensor, + test_target: torch.Tensor, + support_target: torch.Tensor, + label_mask: torch.Tensor = None, + ): + """ + :param test_reps: (batch_size, test_len, emb_dim) + :param test_output_mask: (batch_size, test_len) + :param support_reps: (batch_size, support_size, support_len, emb_dim) + :param support_output_mask: (batch_size, support_size, support_len) + :param test_target: index targets (batch_size, test_len) + :param support_target: one-hot targets (batch_size, support_size, support_len, num_tags) + :param label_mask: the output label mask + :return: + """ + # calculate emission: shape(batch_size, test_len, no_pad_num_tag) + emission = self.emission_scorer(test_reps, support_reps, test_output_mask, support_output_mask, support_target) + logits = emission + + if self.transition_scorer: + transitions, start_transitions, end_transitions = self.transition_scorer(test_reps, support_target) + + if label_mask is not None: + transitions = self.mask_transition(transitions, label_mask) + + self.decoder: ConditionalRandomField + best_paths = self.decoder.viterbi_tags(logits=logits, + transitions_without_constrain=transitions, + start_transitions=start_transitions, + end_transitions=end_transitions, + mask=test_output_mask) + # split path and score + prediction, path_score = zip(*best_paths) + # we block pad label(id=0) before by - 1, here, we add 1 back + prediction = self.add_back_pad_label(prediction) else: - return prediction + self.decoder: SequenceLabeler + prediction = self.decoder.decode(logits=logits, masks=test_output_mask) + # we block pad label(id=0) before by - 1, here, we add 1 back + prediction = self.add_back_pad_label(prediction) + return prediction def add_back_pad_label(self, predictions: List[List[int]]): for pred in predictions: @@ -122,12 +145,11 @@ def __init__( emission_scorer: EmissionScorerBase, decoder: torch.nn.Module, transition_scorer: TransitionScorerBase = None, - label_mask: torch.Tensor = None, config: dict = 
None, # store necessary setting or none-torch params emb_log: str = None ): super(SchemaFewShotSeqLabeler, self).__init__( - opt, context_embedder, emission_scorer, decoder, transition_scorer, label_mask, config, emb_log) + opt, context_embedder, emission_scorer, decoder, transition_scorer, config, emb_log) def forward( self, @@ -137,9 +159,8 @@ def forward( support_output_mask: torch.Tensor, test_target: torch.Tensor, support_target: torch.Tensor, - support_num: torch.Tensor, label_reps: torch.Tensor = None, - is_training: bool = True, + label_mask: torch.Tensor = None, ): """ few-shot sequence labeler using schema information @@ -149,19 +170,14 @@ def forward( :param support_output_mask: (batch_size, support_size, support_len) :param test_target: index targets (batch_size, test_len) :param support_target: one-hot targets (batch_size, support_size, support_len, num_tags) - :param support_num: (batch_size, 1) :param label_reps: (batch_size, label_num, emb_dim) - :param is_training: the training mode + :param label_mask: the output label mask :return: """ # calculate emission: shape(batch_size, test_len, no_pad_num_tag) emission = self.emission_scorer(test_reps, support_reps, test_output_mask, support_output_mask, support_target, label_reps) - if not is_training and self.emb_log: - self.emb_log.write('\n'.join(['test_target\t' + '\t'.join(map(str, one_target)) - for one_target in test_target.tolist()]) + '\n') - logits = emission # block pad of label_id = 0, so all label id sub 1. 
And relu is used to avoid -1 index @@ -171,44 +187,79 @@ def forward( if self.transition_scorer: transitions, start_transitions, end_transitions = self.transition_scorer(test_reps, support_target, label_reps[0]) - if self.label_mask is not None: - transitions = self.mask_transition(transitions, self.label_mask) + if label_mask is not None: + transitions = self.mask_transition(transitions, label_mask) self.decoder: ConditionalRandomField - if is_training: - # the CRF staff - llh = self.decoder.forward( - inputs=logits, - transitions=transitions, - start_transitions=start_transitions, - end_transitions=end_transitions, - tags=test_target, - mask=test_output_mask) - loss = -1 * llh - else: - best_paths = self.decoder.viterbi_tags(logits=logits, - transitions_without_constrain=transitions, - start_transitions=start_transitions, - end_transitions=end_transitions, - mask=test_output_mask) - # split path and score - prediction, path_score = zip(*best_paths) - # we block pad label(id=0) before by - 1, here, we add 1 back - prediction = self.add_back_pad_label(prediction) + # the CRF staff + llh = self.decoder.forward( + inputs=logits, + transitions=transitions, + start_transitions=start_transitions, + end_transitions=end_transitions, + tags=test_target, + mask=test_output_mask) + loss = -1 * llh else: self.decoder: SequenceLabeler - if is_training: - loss = self.decoder.forward(logits=logits, - tags=test_target, - mask=test_output_mask) - else: - prediction = self.decoder.decode(logits=logits, masks=test_output_mask) - # we block pad label(id=0) before by - 1, here, we add 1 back - prediction = self.add_back_pad_label(prediction) - if is_training: - return loss + loss = self.decoder.forward(logits=logits, + tags=test_target, + mask=test_output_mask) + return loss + + def decode( + self, + test_reps: torch.Tensor, + test_output_mask: torch.Tensor, + support_reps: torch.Tensor, + support_output_mask: torch.Tensor, + test_target: torch.Tensor, + support_target: torch.Tensor, 
+ label_reps: torch.Tensor = None, + label_mask: torch.Tensor = None, + ): + """ + few-shot sequence labeler using schema information + :param test_reps: (batch_size, test_len, emb_dim) + :param test_output_mask: (batch_size, test_len) + :param support_reps: (batch_size, support_size, support_len, emb_dim) + :param support_output_mask: (batch_size, support_size, support_len) + :param test_target: index targets (batch_size, test_len) + :param support_target: one-hot targets (batch_size, support_size, support_len, num_tags) + :param label_reps: (batch_size, label_num, emb_dim) + :param label_mask: the output label mask + :return: + """ + + # calculate emission: shape(batch_size, test_len, no_pad_num_tag) + emission = self.emission_scorer(test_reps, support_reps, test_output_mask, support_output_mask, support_target, + label_reps) + logits = emission + + if self.transition_scorer: + transitions, start_transitions, end_transitions = self.transition_scorer(test_reps, support_target, label_reps[0]) + + if label_mask is not None: + transitions = self.mask_transition(transitions, label_mask) + + self.decoder: ConditionalRandomField + + best_paths = self.decoder.viterbi_tags(logits=logits, + transitions_without_constrain=transitions, + start_transitions=start_transitions, + end_transitions=end_transitions, + mask=test_output_mask) + # split path and score + prediction, path_score = zip(*best_paths) + # we block pad label(id=0) before by - 1, here, we add 1 back + prediction = self.add_back_pad_label(prediction) else: - return prediction + self.decoder: SequenceLabeler + + prediction = self.decoder.decode(logits=logits, masks=test_output_mask) + # we block pad label(id=0) before by - 1, here, we add 1 back + prediction = self.add_back_pad_label(prediction) + return prediction def main(): diff --git a/models/few_shot_text_classifier.py b/models/few_shot_text_classifier.py index 6fd7689..868c038 100644 --- a/models/few_shot_text_classifier.py +++ 
b/models/few_shot_text_classifier.py @@ -19,7 +19,6 @@ def __init__(self, self.context_embedder = context_embedder self.emission_scorer = emission_scorer self.decoder = decoder - self.no_embedder_grad = opt.no_embedder_grad self.config = config self.emb_log = emb_log @@ -31,8 +30,7 @@ def forward( support_output_mask: torch.Tensor, test_target: torch.Tensor, support_target: torch.Tensor, - support_num: torch.Tensor, - is_training: bool = True, + label_mask: torch.Tensor = None, ): """ :param test_reps: (batch_size, test_len, emb_dim) @@ -41,8 +39,7 @@ def forward( :param support_output_mask: (batch_size, support_size, support_len) :param test_target: index targets (batch_size, multi-label_num) :param support_target: one-hot targets (batch_size, support_size, multi-label_num, num_tags) - :param support_num: (batch_size, 1) - :param is_training: the training mode + :param label_mask: the output label mask :return: """ # calculate emission: shape(batch_size, 1, no_pad_num_tag) @@ -52,19 +49,41 @@ def forward( # as we remove pad label (id = 0), so all label id sub 1. 
And relu is used to avoid -1 index test_target = torch.nn.functional.relu(test_target - 1) - loss, prediction = torch.FloatTensor(0).to(test_target.device), None + loss = self.decoder.forward(logits=logits, mask=test_output_mask, tags=test_target) + return loss - if is_training: - loss = self.decoder.forward(logits=logits, mask=test_output_mask, tags=test_target) - else: + def decode( + self, + test_reps: torch.Tensor, + test_output_mask: torch.Tensor, + support_reps: torch.Tensor, + support_output_mask: torch.Tensor, + test_target: torch.Tensor, + support_target: torch.Tensor, + label_mask: torch.Tensor = None, + ): + """ + :param test_reps: (batch_size, test_len, emb_dim) + :param test_output_mask: (batch_size, test_len) + :param support_reps: (batch_size, support_size, support_len, emb_dim) + :param support_output_mask: (batch_size, support_size, support_len) + :param test_target: index targets (batch_size, multi-label_num) + :param support_target: one-hot targets (batch_size, support_size, multi-label_num, num_tags) + :param label_mask: the output label mask + :return: + """ + # calculate emission: shape(batch_size, 1, no_pad_num_tag) + test_output_mask = torch.ones(test_output_mask.shape[0], 1).to(test_output_mask.device) # for sc, each test has only 1 output + emission = self.emission_scorer(test_reps, support_reps, test_output_mask, support_output_mask, support_target) + logits = emission + + # as we remove pad label (id = 0), so all label id sub 1. 
And relu is used to avoid -1 index + test_target = torch.nn.functional.relu(test_target - 1) - prediction = self.decoder.decode(logits=logits) - # we block pad label(id=0) before by - 1, here, we add 1 back - prediction = self.add_back_pad_label(prediction) - if is_training: - return loss - else: - return prediction + prediction = self.decoder.decode(logits=logits) + # we block pad label(id=0) before by - 1, here, we add 1 back + prediction = self.add_back_pad_label(prediction) + return prediction def add_back_pad_label(self, predictions: List[List[int]]): for pred in predictions: @@ -93,9 +112,8 @@ def forward( support_output_mask: torch.Tensor, test_target: torch.Tensor, support_target: torch.Tensor, - support_num: torch.Tensor, label_reps: torch.Tensor = None, - is_training: bool = True, + label_mask: torch.Tensor = None, ): """ few-shot sequence labeler using schema information @@ -105,35 +123,56 @@ def forward( :param support_output_mask: (batch_size, support_size, support_len) :param test_target: index targets (batch_size, test_len) :param support_target: one-hot targets (batch_size, support_size, support_len, num_tags) - :param support_num: (batch_size, 1) :param label_reps: (batch_size, label_num, emb_dim) - :param is_training: the training mode + :param label_mask: the output label mask :return: """ # calculate emission: shape(batch_size, test_len, no_pad_num_tag) emission = self.emission_scorer(test_reps, support_reps, test_output_mask, support_output_mask, support_target, label_reps) - if not is_training and self.emb_log: - self.emb_log.write('\n'.join(['test_target\t' + '\t'.join(map(str, one_target)) - for one_target in test_target.tolist()]) + '\n') + logits = emission + + # block pad of label_id = 0, so all label id sub 1. 
And relu is used to avoid -1 index + test_target = torch.nn.functional.relu(test_target - 1) + loss = self.decoder.forward(logits=logits, mask=test_output_mask, tags=test_target) + return loss + + def decode( + self, + test_reps: torch.Tensor, + test_output_mask: torch.Tensor, + support_reps: torch.Tensor, + support_output_mask: torch.Tensor, + test_target: torch.Tensor, + support_target: torch.Tensor, + label_reps: torch.Tensor = None, + label_mask: torch.Tensor = None, + ): + """ + few-shot sequence labeler using schema information + :param test_reps: (batch_size, test_len, emb_dim) + :param test_output_mask: (batch_size, test_len) + :param support_reps: (batch_size, support_size, support_len, emb_dim) + :param support_output_mask: (batch_size, support_size, support_len) + :param test_target: index targets (batch_size, test_len) + :param support_target: one-hot targets (batch_size, support_size, support_len, num_tags) + :param label_reps: (batch_size, label_num, emb_dim) + :param label_mask: the output label mask + :return: + """ + # calculate emission: shape(batch_size, test_len, no_pad_num_tag) + emission = self.emission_scorer(test_reps, support_reps, test_output_mask, support_output_mask, support_target, + label_reps) logits = emission # block pad of label_id = 0, so all label id sub 1. 
And relu is used to avoid -1 index test_target = torch.nn.functional.relu(test_target - 1) - loss, prediction = torch.FloatTensor([0]).to(test_target.device), None - - if is_training: - loss = self.decoder.forward(logits=logits, mask=test_output_mask, tags=test_target) - else: - prediction = self.decoder.decode(logits=logits) - # we block pad label(id=0) before by - 1, here, we add 1 back - prediction = self.add_back_pad_label(prediction) - if is_training: - return loss - else: - return prediction + prediction = self.decoder.decode(logits=logits) + # we block pad label(id=0) before by - 1, here, we add 1 back + prediction = self.add_back_pad_label(prediction) + return prediction def main(): diff --git a/models/modules/text_classifier.py b/models/modules/text_classifier.py index 0c14379..fad1d78 100644 --- a/models/modules/text_classifier.py +++ b/models/modules/text_classifier.py @@ -21,7 +21,7 @@ def forward(self, :return: """ logits = logits.squeeze(-2) - tags = tags.squeeze(-2) + tags = tags.squeeze(-1) loss = self.criterion(logits, tags) return loss diff --git a/scripts/run_bert_sc+sl.sh b/scripts/run_bert_sc+sl.sh index bb23192..b1eb25d 100644 --- a/scripts/run_bert_sc+sl.sh +++ b/scripts/run_bert_sc+sl.sh @@ -5,8 +5,8 @@ echo eg: source run_bert_siamese.sh 3,4 stanford gpu_list=$1 # Comment one of follow 2 to switch debugging status -do_debug=--do_debug -#do_debug= +#do_debug=--do_debug +do_debug= #restore=--restore_cpt restore= @@ -17,23 +17,19 @@ task=sc\ sl use_schema=--use_schema #use_schema= -#label_num_schema=--label_num_schema -label_num_schema= - # ======= dataset setting ====== dataset_lst=($2 $3) support_shots_lst=(3) -query_shot=8 - -episode=100 +query_shot=4 +episode=20 cross_data_id=0 # for smp # ====== train & test setting ====== -seed_lst=(0) -#seed_lst=(6150 6151 6152) +#seed_lst=(0) +seed_lst=(6150 6151 6152) #lr_lst=(0.000001 0.000005 0.00005) lr_lst=(0.00001) @@ -63,6 +59,9 @@ test_batch_size=2 grad_acc=4 epoch=3 
+judge_joint_success=--judge_joint_success +#judge_joint_success= + # ==== model setting ========= # ---- encoder setting ----- @@ -159,7 +158,7 @@ pretrained_vocab_path=/users4/yklai/corpus/BERT/pytorch/chinese_L-12_H-768_A-12/ #pretrained_model_path=/Users/lyk/Code/model/chinese_electra_small_discriminator_pytorch #pretrained_vocab_path=/Users/lyk/Code/model/chinese_electra_small_discriminator_pytorch -base_data_dir=/users4/yklai/code/Dialogue/FewShot/MetaDial/data/SmpMetaData/ +base_data_dir=/users4/yklai/code/Dialogue/FewShot/MetaDial/data/smp/ #base_data_dir=/Users/lyk/Work/Dialogue/FewShot/SMP/smp2/ echo [START] set jobs on dataset [ ${dataset_lst[@]} ] on gpu [ ${gpu_list} ] @@ -192,7 +191,7 @@ do for ple_scale_r in ${ple_scale_r_lst[@]} do # model names - model_name=joint_sc_sl.ga_${grad_acc}_ple_${ple_scale_r}.bs_${train_batch_size}.electra.sim_${similarity}.ems_${emission_normalizer}.${use_schema}${label_num_schema}--fix_dev_spt${do_debug} + model_name=joint_sc_sl.ga_${grad_acc}_ple_${ple_scale_r}.bs_${train_batch_size}.bert.sim_${similarity}.ems_${emission_normalizer}.${use_schema}${judge_joint_success}--fix_dev_spt${do_debug} data_dir=${base_data_dir}${dataset}.${cross_data_id}.spt_s_${support_shots}.q_s_${query_shot}.ep_${episode}${use_schema}--fix_dev_spt/ file_mark=${dataset}.shots_${support_shots}.cross_id_${cross_data_id}.m_seed_${seed} @@ -230,7 +229,6 @@ do --test_batch_size ${test_batch_size} \ --context_emb ${embedder} \ ${use_schema} \ - ${label_num_schema} \ --label_reps ${label_reps} \ --projection_layer none \ --emission ${emission} \ @@ -255,7 +253,8 @@ do -t_scl ${trans_scaler} \ --trans_scale_r ${trans_scale_r} \ ${mask_trans} \ - --load_feature > ./joint/${model_name}.DATA.${file_mark}.log + ${judge_joint_success} \ + --load_feature > ./joint/${model_name}.DATA.${file_mark}.log echo [CLI] echo Model: ${model_name} echo Task: ${file_mark} diff --git a/scripts/run_bert_sc.sh b/scripts/run_bert_sc.sh index a5b7df6..3fe4740 100644 --- 
a/scripts/run_bert_sc.sh +++ b/scripts/run_bert_sc.sh @@ -11,7 +11,7 @@ do_debug=--do_debug #restore=--restore_cpt restore= -task=sc\ sl +task=sc #task=sl use_schema=--use_schema @@ -22,9 +22,8 @@ use_schema=--use_schema dataset_lst=($2 $3) support_shots_lst=(3) -query_shot=8 - -episode=100 +query_shot=4 +episode=50 cross_data_id=0 # for smp @@ -127,10 +126,6 @@ emb_log= #decoder_lst=(sms) #decoder_lst=(crf) #decoder_lst=(crf sms) -#decoder_lst=(mlc) -#decoder_lst=(eamlc) -#decoder_lst=(msmlc) -#decoder_lst=(krnmsmlc) # -------- SC decoder setting -------- @@ -223,8 +218,8 @@ do ${tap_mlp} \ ${emb_log} \ ${do_div_emission} \ - --transition learn \ - --load_feature > ./sclog/${model_name}.DATA.${file_mark}.log + --transition learn > ./sclog/${model_name}.DATA.${file_mark}.log + # --load_feature > ./sclog/${model_name}.DATA.${file_mark}.log echo [CLI] echo Model: ${model_name} echo Task: ${file_mark} diff --git a/scripts/run_electra_sc+sl.sh b/scripts/run_electra_sc+sl.sh index 4c2ddb7..346a249 100644 --- a/scripts/run_electra_sc+sl.sh +++ b/scripts/run_electra_sc+sl.sh @@ -17,17 +17,12 @@ task=sc\ sl use_schema=--use_schema #use_schema= -#label_num_schema=--label_num_schema -label_num_schema= - - # ======= dataset setting ====== dataset_lst=($2 $3) support_shots_lst=(3) -query_shot=8 - -episode=100 +query_shot=4 +episode=50 cross_data_id=0 # for smp @@ -63,6 +58,9 @@ test_batch_size=2 grad_acc=4 epoch=3 +judge_joint_success=--judge_joint_success +#judge_joint_success= + # ==== model setting ========= # ---- encoder setting ----- @@ -154,13 +152,13 @@ transition=learn #pretrained_vocab_path=/users4/yklai/corpus/BERT/pytorch/chinese_L-12_H-768_A-12/vocab.txt # electra small path -pretrained_model_path=/users4/yklai/corpus/electra/chinese_electra_small_discriminator -pretrained_vocab_path=/users4/yklai/corpus/electra/chinese_electra_small_discriminator -#pretrained_model_path=/Users/lyk/Code/model/chinese_electra_small_discriminator_pytorch 
-#pretrained_vocab_path=/Users/lyk/Code/model/chinese_electra_small_discriminator_pytorch +#pretrained_model_path=/users4/yklai/corpus/electra/chinese_electra_small_discriminator +#pretrained_vocab_path=/users4/yklai/corpus/electra/chinese_electra_small_discriminator +pretrained_model_path=/Users/lyk/Code/model/chinese_electra_small_discriminator_pytorch +pretrained_vocab_path=/Users/lyk/Code/model/chinese_electra_small_discriminator_pytorch -base_data_dir=/users4/yklai/code/Dialogue/FewShot/MetaDial/data/SmpMetaData/ -#base_data_dir=/Users/lyk/Work/Dialogue/FewShot/SMP/smp2/ +#base_data_dir=/users4/yklai/code/Dialogue/FewShot/MetaDial/data/smp/ +base_data_dir=/Users/lyk/Work/Dialogue/FewShot/SMP/smp/ echo [START] set jobs on dataset [ ${dataset_lst[@]} ] on gpu [ ${gpu_list} ] # === Loop for all case and run === @@ -192,7 +190,7 @@ do for ple_scale_r in ${ple_scale_r_lst[@]} do # model names - model_name=joint_sc_sl.ga_${grad_acc}_ple_${ple_scale_r}.bs_${train_batch_size}.electra.sim_${similarity}.ems_${emission_normalizer}.${use_schema}${label_num_schema}--fix_dev_spt${do_debug} + model_name=joint_sc_sl.ga_${grad_acc}_ple_${ple_scale_r}.bs_${train_batch_size}.electra.sim_${similarity}.ems_${emission_normalizer}.${use_schema}${judge_joint_success}--fix_dev_spt${do_debug} data_dir=${base_data_dir}${dataset}.${cross_data_id}.spt_s_${support_shots}.q_s_${query_shot}.ep_${episode}${use_schema}--fix_dev_spt/ file_mark=${dataset}.shots_${support_shots}.cross_id_${cross_data_id}.m_seed_${seed} @@ -230,7 +228,6 @@ do --test_batch_size ${test_batch_size} \ --context_emb ${embedder} \ ${use_schema} \ - ${label_num_schema} \ --label_reps ${label_reps} \ --projection_layer none \ --emission ${emission} \ @@ -255,7 +252,8 @@ do -t_scl ${trans_scaler} \ --trans_scale_r ${trans_scale_r} \ ${mask_trans} \ - --load_feature > ./joint/${model_name}.DATA.${file_mark}.log + ${judge_joint_success} \ + --load_feature # > ./joint/${model_name}.DATA.${file_mark}.log echo [CLI] echo Model: 
${model_name} echo Task: ${file_mark} diff --git a/scripts/run_electra_sc.sh b/scripts/run_electra_sc.sh index a700646..c9e79df 100644 --- a/scripts/run_electra_sc.sh +++ b/scripts/run_electra_sc.sh @@ -26,9 +26,8 @@ label_num_schema= dataset_lst=($2 $3) support_shots_lst=(3) -query_shot=8 - -episode=100 +query_shot=4 +episode=50 cross_data_id=0 # for smp @@ -131,11 +130,6 @@ emb_log= #decoder_lst=(sms) #decoder_lst=(crf) #decoder_lst=(crf sms) -#decoder_lst=(mlc) -#decoder_lst=(eamlc) -#decoder_lst=(msmlc) -#decoder_lst=(krnmsmlc) - # -------- SC decoder setting -------- @@ -151,7 +145,7 @@ pretrained_vocab_path=/users4/yklai/corpus/electra/chinese_electra_small_discrim #pretrained_model_path=/Users/lyk/Code/model/chinese_electra_small_discriminator_pytorch #pretrained_vocab_path=/Users/lyk/Code/model/chinese_electra_small_discriminator_pytorch -base_data_dir=/users4/yklai/code/Dialogue/FewShot/MetaDial/data/SmpMetaData/ +base_data_dir=/users4/yklai/code/Dialogue/FewShot/MetaDial/data/smp_true/ #base_data_dir=/Users/lyk/Work/Dialogue/FewShot/SMP/smp2/ echo [START] set jobs on dataset [ ${dataset_lst[@]} ] on gpu [ ${gpu_list} ] diff --git a/scripts/run_electra_sl.sh b/scripts/run_electra_sl.sh index 3a0af0c..4bad247 100644 --- a/scripts/run_electra_sl.sh +++ b/scripts/run_electra_sl.sh @@ -8,8 +8,8 @@ echo log file path ../sllog/ gpu_list=$1 # Comment one of follow 2 to switch debugging status -do_debug=--do_debug -#do_debug= +#do_debug=--do_debug +do_debug= task=sl @@ -41,14 +41,15 @@ upper_lr_lst=(0.001) fix_embd_epoch_lst=(-1) -warmup_epoch=1 +#warmup_epoch=1 +warmup_epoch=-1 train_batch_size_lst=(4) test_batch_size=4 grad_acc=2 #grad_acc=4 # if the GPU-memory is not enough, use bigger gradient accumulate -epoch=3 +epoch=1 # ==== model setting ========= # ---- encoder setting ----- @@ -139,7 +140,8 @@ transition=learn pretrained_model_path=/users4/yklai/corpus/electra/chinese_electra_small_discriminator 
pretrained_vocab_path=/users4/yklai/corpus/electra/chinese_electra_small_discriminator -base_data_dir=/users4/yklai/code/Dialogue/FewShot/MetaDial/data/smp/ +#base_data_dir=/users4/yklai/code/Dialogue/FewShot/MetaDial/data/smp_true/ +base_data_dir=/users4/yklai/code/Dialogue/FewShot/MetaDial/data/smp_sl/ echo [START] set jobs on dataset [ ${dataset_lst[@]} ] on gpu [ ${gpu_list} ] @@ -173,7 +175,7 @@ do for cross_data_id in ${cross_data_id_lst[@]} do # model names - model_name=sl.electra.dec_${decoder}.enc_${embedder}.ems_${emission}${do_div_emission}.mlp_${tap_mlp}_random_${tap_random_init_r}.e_scl_${emission_scaler}${ems_scale_r}_${emission_normalizer}.lb_${label_reps}_scl_${ple_scaler}${ple_scale_r}.t_scl_${trans_scaler}${trans_scale_r}_${trans_normalizer}.t_i_${trans_init}.${mask_trans}_.sim_${similarity}.lr_${lr}.up_lr_${upper_lr}.bs_${train_batch_size}_${test_batch_size}.sp_b_${grad_acc}.w_ep_${warmup_epoch}.ep_${epoch}--fix_dev_spt${do_debug} + model_name=2r_sl.electra.dec_${decoder}.enc_${embedder}.ems_${emission}${do_div_emission}.e_scl_${emission_scaler}${ems_scale_r}_${emission_normalizer}.lb_${label_reps}_scl_${ple_scaler}${ple_scale_r}.t_scl_${trans_scaler}${trans_scale_r}_${trans_normalizer}.t_i_${trans_init}.${mask_trans}_.sim_${similarity}.lr_${lr}.up_lr_${upper_lr}.bs_${train_batch_size}_${test_batch_size}.sp_b_${grad_acc}.w_ep_${warmup_epoch}.ep_${epoch}--fix_dev_spt${do_debug} data_dir=${base_data_dir}${dataset}.${cross_data_id}.spt_s_${support_shots}.q_s_${query_shot}.ep_${episode}${use_schema}--fix_dev_spt/ file_mark=${dataset}.shots_${support_shots}.cross_id_${cross_data_id}.m_seed_${seed} @@ -236,8 +238,8 @@ do -t_nm ${trans_normalizer} \ -t_scl ${trans_scaler} \ --trans_scale_r ${trans_scale_r} \ - ${mask_trans} \ - --load_feature > ./sllog/${model_name}.DATA.${file_mark}.log + ${mask_trans} > ./sllog/${model_name}.DATA.${file_mark}.log + # --load_feature > ./sllog/${model_name}.DATA.${file_mark}.log echo [CLI] echo Model: ${model_name} echo 
Task: ${file_mark} diff --git a/utils/device_helper.py b/utils/device_helper.py index 484aae3..b2f2248 100644 --- a/utils/device_helper.py +++ b/utils/device_helper.py @@ -40,10 +40,7 @@ def prepare_model(args, model, device, n_gpu): """ Set device to use """ if args.fp16: model.half() - # TODO: smarter way model.to(device) - for task in model.model_map.keys(): - model.model_map[task].to(device) if args.local_rank != -1: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) diff --git a/utils/opt.py b/utils/opt.py index 96fed9b..633aa44 100644 --- a/utils/opt.py +++ b/utils/opt.py @@ -131,6 +131,8 @@ def test_args(parser): group = parser.add_argument_group('Test') group.add_argument("--test_batch_size", default=2, type=int, help="Must same to few-shot batch size now") group.add_argument("--test_on_cpu", default=False, action='store_true', help="eval on cpu") + group.add_argument("--judge_joint_success", default=False, action='store_true', + help="set the eval method to judge joint score with joint success") return parser diff --git a/utils/preprocessor.py b/utils/preprocessor.py index 1547a65..ae4b4a7 100644 --- a/utils/preprocessor.py +++ b/utils/preprocessor.py @@ -89,7 +89,7 @@ def data_item2feature_item(self, data_item: DataItem, seg_id: int) -> FeatureIte def get_test_model_input(self, feature_item: FeatureItem) -> ModelInput: if isinstance(feature_item.output_mask_map, dict): - output_mask_map = {key: torch.LongTensor(map_item) for key, map_item in feature_item.output_mask_map.items()} + output_mask_map = {task: torch.LongTensor(map_item) for task, map_item in feature_item.output_mask_map.items()} else: output_mask_map = torch.LongTensor(feature_item.output_mask_map) ret = ModelInput( @@ -108,10 +108,10 @@ def get_support_model_input(self, feature_items: List[FeatureItem], max_support_ nwp_index = self.pad_support_set([f.nwp_index for f in feature_items], [0], max_support_size) input_mask = 
self.pad_support_set([f.input_mask for f in feature_items], 0, max_support_size) if isinstance(feature_items[0].output_mask_map, dict): - output_mask_map = {key: self.pad_support_set([f.output_mask_map[key] for f in feature_items], - 0, max_support_size) - for key in feature_items[0].output_mask_map.keys()} - output_mask_map = {key: torch.LongTensor(output_mask) for key, output_mask in output_mask_map.items()} + output_mask_map = {task: self.pad_support_set([f.output_mask_map[task] for f in feature_items], + 0, max_support_size) + for task in feature_items[0].output_mask_map.keys()} + output_mask_map = {task: torch.LongTensor(output_mask) for task, output_mask in output_mask_map.items()} else: output_mask_map = self.pad_support_set([f.output_mask_map for f in feature_items], 0, max_support_size) output_mask_map = torch.LongTensor(output_mask_map) @@ -181,10 +181,19 @@ def prepare_support(self, example, max_support_size): support_input = self.get_support_model_input(support_feature_items, max_support_size) return support_feature_items, support_input - def data_item2feature_item(self, data_item: DataItem, seg_id: int, share_feature: bool = True) -> FeatureItem: + def data_item2feature_item(self, data_item: DataItem, seg_id: int, specify_task: str = None) -> FeatureItem: """ get feature_item for bert, steps: 1. do digitalizing 2. make mask """ wp_mark, wp_text = self.tokenizing(data_item) - if share_feature: + if specify_task is not None: + if specify_task == 'sl': # use word-level labels [opt.label_wp is supported by model now.] 
+ label_map = self.get_wp_label(data_item.seq_out, wp_text, wp_mark) if self.opt.label_wp else data_item.seq_out + output_mask_map = [1] * len(label_map) # For sl: it is original tokens; + elif specify_task == 'sc': # use sentence level labels + label_map = data_item.label + output_mask_map = [1] * len(label_map) # For sc: it is labels + else: + raise TypeError('the specify task should be: `sl` or `sc`') + else: label_map = {} output_mask_map = {} if 'sl' in self.opt.task: # use word-level labels [opt.label_wp is supported by model now.] @@ -196,9 +205,6 @@ def data_item2feature_item(self, data_item: DataItem, seg_id: int, share_feature labels = data_item.label label_map['sc'] = labels output_mask_map['sc'] = [1] * len(labels) # For sc: it is labels - else: - label_map = data_item.label - output_mask_map = [1] * len(label_map) tokens = ['[CLS]'] + wp_text + ['[SEP]'] if seg_id == 0 else wp_text + ['[SEP]'] token_ids, segment_ids = self.digitizing_input(tokens=tokens, seg_id=seg_id) @@ -256,7 +262,7 @@ def __call__(self, example, max_support_size, label2id_map) -> (FeatureItem, Mod elif self.opt.label_reps in ['sep', 'sep_sum']: # represent each label independently label_input, label_items = self.prepare_sep_label_feature(label2id_map) else: - raise TypeError('the label_reps should be one of cat & set & set_num') + raise TypeError('the label_reps should be one of cat & set & sep_num') return test_feature_item, test_input, support_feature_items, support_input, label_items, label_input, def prepare_label_feature(self, label2id_map: Dict[str, Dict[str, int]]): @@ -277,7 +283,7 @@ def prepare_label_feature(self, label2id_map: Dict[str, Dict[str, int]]): wp_label.extend(['O'] * len(tmp_wp_text)) wp_mark.extend([0] + [1] * (len(tmp_wp_text) - 1)) label_item_map[task] = self.data_item2feature_item(DataItem(text, label, wp_text, wp_label, wp_mark), 0, - share_feature=False) + task) label_input_map[task] = self.get_test_model_input(label_item_map[task]) return 
label_input_map, label_item_map @@ -291,8 +297,7 @@ def prepare_sep_label_feature(self, label2id_map): seq_in = self.convert_label_name(label_name) seq_out = ['None'] * len(seq_in) label = ['None'] - label_item_map[task].append(self.data_item2feature_item(DataItem(seq_in, seq_out, label), - 0, share_feature=False)) + label_item_map[task].append(self.data_item2feature_item(DataItem(seq_in, seq_out, label), 0, task)) label_input_map = {task: self.get_support_model_input(label_items, len(label2id_map[task]) - 1) for task, label_items in label_item_map.items()} # no pad, so - 1 return label_input_map, label_item_map @@ -526,13 +531,13 @@ def example2feature( self, example: FewShotExample, max_support_size: int, - label2id: dict, - id2label: dict + label2id_map: Dict[str, Dict[str, int]], + id2label_map: Dict[str, Dict[str, int]] ) -> FewShotFeature: test_feature_item, test_input, support_feature_items, support_input, label_item_map, label_input_map = \ - self.input_builder(example, max_support_size, label2id) + self.input_builder(example, max_support_size, label2id_map) test_target_map, support_target_map = self.output_builder( - test_feature_item, support_feature_items, label2id, max_support_size) + test_feature_item, support_feature_items, label2id_map, max_support_size) ret = FewShotFeature( gid=example.gid, test_gid=example.test_id, @@ -568,8 +573,8 @@ def purify(l): return set([item.replace('B-', '').replace('I-', '') for item in l]) ''' collect all label from: all test set & all support set ''' - all_label_map = {key: [] for key in opt.task} - label2id_map = {key: {'[PAD]': 0} for key in opt.task} # '[PAD]' in first position and id is 0 + all_label_map = {task: [] for task in opt.task} + label2id_map = {task: {'[PAD]': 0} for task in opt.task} # '[PAD]' in first position and id is 0 for example in examples: if 'sl' in opt.task: all_label_map['sl'].append(example.test_data_item.seq_out) @@ -580,7 +585,7 @@ def purify(l): all_label_map['sc'].extend([data_item.label 
for data_item in example.support_data_items]) ''' collect label word set ''' # sort to make embedding id fixed - label_set_map = {key: sorted(list(purify(set(flatten(item_label))))) for key, item_label in all_label_map.items()} + label_set_map = {task: sorted(list(purify(set(flatten(item_label))))) for task, item_label in all_label_map.items()} ''' build dict ''' if 'sl' in opt.task: label2id_map['sl']['O'] = len(label2id_map['sl']) @@ -595,7 +600,7 @@ def purify(l): label2id_map['sc'][label] = len(label2id_map['sc']) ''' reverse the label2id ''' - id2label_map = {key: dict([(idx, label) for label, idx in label2id_map[key].items()]) for key in opt.task} + id2label_map = {task: dict([(idx, label) for label, idx in label2id_map[task].items()]) for task in opt.task} return label2id_map, id2label_map diff --git a/utils/tester.py b/utils/tester.py index 18df5e2..909c3be 100644 --- a/utils/tester.py +++ b/utils/tester.py @@ -8,6 +8,7 @@ import json import collections import subprocess +import numpy as np from tqdm import tqdm, trange from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler @@ -136,7 +137,48 @@ def eval_predictions(self, all_results: List[Dict[str, RawResult]], id2label_map for b_id, fs_batch in all_batches: f1_map = self.eval_one_few_shot_batch(b_id, fs_batch, id2label_map, log_mark) all_scores.append(f1_map) - return {task: sum([item[task] for item in all_scores]) * 1.0 / len(all_scores) for task in id2label_map.keys()} + + res = {task: sum([item[task] for item in all_scores]) * 1.0 / len(all_scores) for task in id2label_map.keys()} + + # TODO: change to more general + if 'sc' in self.opt.task and 'sl' in self.opt.task: + all_intents = [[item['sc'] for item in fs_batch] for b_id, fs_batch in all_batches] + all_slots = [[item['sl'] for item in fs_batch] for b_id, fs_batch in all_batches] + + # prediction is directly the predict ids [pad is removed in decoder] + 
all_intent_pred_ids = [[result.prediction for result in batch_intents] for batch_intents in all_intents] + all_intent_features = [[result.feature for result in batch_intents] for batch_intents in all_intents] + all_intent_pred_labels = [[id2label_map['sc'][pred_ids[0]] for pred_ids in batch_pred_ids] + for batch_pred_ids in all_intent_pred_ids] + all_intent_target_labels = [[feature.test_feature_item.data_item.label[0] for feature in batch_features] + for batch_features in all_intent_features] + all_intent_pred_labels = np.array(all_intent_pred_labels) + all_intent_target_labels = np.array(all_intent_target_labels) + success = (all_intent_pred_labels == all_intent_target_labels) + # print('success: {} - {}'.format(np.mean(success), success)) + + all_slot_pred_ids = [[result.prediction for result in batch_slots] for batch_slots in all_slots] + all_slot_features = [[result.feature for result in batch_slots] for batch_slots in all_slots] + all_slot_pred_labels = [[[id2label_map['sl'][pred_id] for pred_id in pred_ids] + for pred_ids in batch_pred_ids] for batch_pred_ids in all_slot_pred_ids] + all_slot_target_labels = [[feature.test_feature_item.data_item.seq_out for feature in batch_features] + for batch_features in all_slot_features] + + for b_idx, (b_pred_labels, b_target_labels) in enumerate(zip(all_slot_pred_labels, all_slot_target_labels)): + for i_idx, (pred_labels, target_labels) in enumerate(zip(b_pred_labels, b_target_labels)): + for p_label, t_label in zip(pred_labels, target_labels): + if p_label != t_label: + success[b_idx][i_idx] = False + break + + success = success.astype(float) + success = np.mean(success) + + res['success'] = success + + # print('res: {}'.format(res)) + + return res def eval_one_few_shot_batch(self, b_id, fs_batch: List[Dict[str, RawResult]], id2label_map: Dict[str, Dict[int, str]], log_mark: str) -> Dict[str, float]: @@ -225,6 +267,7 @@ def eval_with_script(self, output_prediction_file): precision = float(std_results[3].replace('%;', 
'')) recall = float(std_results[5].replace('%;', '')) f1 = float(std_results[7].replace('%;', '').replace("\\n'", '')) + f1 = f1 / 100 # normalize to [0, 1] return precision, recall, f1 def reform_few_shot_batch(self, all_results: List[Dict[str, RawResult]] @@ -252,13 +295,13 @@ def unpack_feature(self, feature: FewShotFeature) -> List[torch.Tensor]: feature.test_input.segment_ids, feature.test_input.nwp_index, feature.test_input.input_mask, - feature.test_input.output_mask, + feature.test_input.output_mask_map, # support feature.support_input.token_ids, feature.support_input.segment_ids, feature.support_input.nwp_index, feature.support_input.input_mask, - feature.support_input.output_mask, + feature.support_input.output_mask_map, # target feature.test_target_map, feature.support_target_map, @@ -332,14 +375,14 @@ def clone_model(self, model, id2label_map): else: sub_new_model = new_model ''' copy weights and stuff ''' - if 'sl' in old_model.opt.task and old_model.model_map['sl'].transition_scorer: + if 'sl' in old_model.opt.task and old_model.seq_labeler_model.transition_scorer: # copy one-by-one because target transition and decoder will be left un-assigned sub_new_model.context_embedder.load_state_dict(old_model.context_embedder.state_dict()) - sub_new_model.model_map['sl'].emission_scorer.load_state_dict( - old_model.model_map['sl'].emission_scorer.state_dict()) + sub_new_model.seq_labeler_model.emission_scorer.load_state_dict( + old_model.seq_labeler_model.emission_scorer.state_dict()) for param_name in ['backoff_trans_mat', 'backoff_start_trans_mat', 'backoff_end_trans_mat']: - sub_new_model.model_map['sl'].transition_scorer.state_dict()[param_name].copy_( - old_model.model_map['sl'].transition_scorer.state_dict()[param_name].data) + sub_new_model.seq_labeler_model.transition_scorer.state_dict()[param_name].copy_( + old_model.seq_labeler_model.transition_scorer.state_dict()[param_name].data) else: sub_new_model.load_state_dict(old_model.state_dict()) diff 
--git a/utils/trainer.py b/utils/trainer.py index 3f26150..7e2c644 100644 --- a/utils/trainer.py +++ b/utils/trainer.py @@ -78,10 +78,11 @@ def do_train(self, model, train_features, num_train_epochs, do training and dev model selection :param model: :param train_features: + :param num_train_epochs: :param dev_features: - :param dev_id2label: + :param dev_id2label_map: :param test_features: - :param test_id2label: + :param test_id2label_map: :param best_dev_score_now: :return: """ @@ -248,11 +249,13 @@ def make_check_point_(self, model, step): def is_bigger(self, first_score: Dict[str, float], second_score: Dict[str, float], task_weight: Dict[str, float] = None)->bool: - task_lst = first_score.keys() + task_lst = [task_name for task_name in first_score.keys() if task_name != 'success'] if isinstance(second_score, dict): - assert task_lst == second_score.keys(), "the two scores should have same keys" + assert first_score.keys() == second_score.keys(), "the two scores should have same keys" elif isinstance(second_score, int) or isinstance(second_score, float): second_score = {task: second_score for task in task_lst} + if 'success' in first_score: + second_score['success'] = 0 else: raise TypeError('the second_score should be int or a dict to assign every task a int or float value') @@ -261,10 +264,17 @@ def is_bigger(self, first_score: Dict[str, float], second_score: Dict[str, float else: task_weight = {task: 1.0 / len(task_lst) for task in task_lst} - first_avg_score = [score * task_weight[task] for task, score in first_score.items()] - second_avg_score = [score * task_weight[task] for task, score in second_score.items()] + print('task_weight: {}'.format(task_weight)) - return first_avg_score > second_avg_score + first_avg_score = sum([score * task_weight[task] for task, score in first_score.items() if task in task_lst]) + second_avg_score = sum([score * task_weight[task] for task, score in second_score.items() if task in task_lst]) + + if 
self.opt.judge_joint_success and len(task_lst) > 1: + bigger = first_score['success'] > second_score['success'] + else: + bigger = first_avg_score > second_avg_score + + return bigger def model_selection(self, model, best_score, dev_features, dev_id2label_map, test_features=None, test_id2label_map=None): """ do model selection during training""" @@ -272,7 +282,7 @@ def model_selection(self, model, best_score, dev_features, dev_id2label_map, tes # do dev eval at every dev_interval point and every end of epoch dev_model = self.tester.clone_model(model, dev_id2label_map) # copy reusable params, for a different domain if self.opt.mask_transition and 'sl' in self.opt.task: - dev_model.model_map['sl'].label_mask = self.opt.dev_label_mask.to(self.device) + dev_model.label_mask = self.opt.dev_label_mask.to(self.device) dev_score_map = self.tester.do_test(dev_model, dev_features, dev_id2label_map, log_mark='dev_pred') logger.info(" dev score(F1) = {}".format(dev_score_map)) @@ -292,7 +302,7 @@ def model_selection(self, model, best_score, dev_features, dev_id2label_map, tes # copy reusable params for different domain test_model = self.tester.clone_model(model, test_id2label_map) if self.opt.mask_transition and 'sl' in self.opt.task: - test_model.model_map['sl'].label_mask = self.opt.test_label_mask.to(self.device) + test_model.label_mask = self.opt.test_label_mask.to(self.device) test_score_map = self.tester.do_test(test_model, test_features, test_id2label_map, log_mark='test_pred') logger.info(" test score(F1) = {}".format(test_score_map)) @@ -412,13 +422,13 @@ def unpack_feature(self, feature: FewShotFeature) -> List[torch.Tensor]: feature.test_input.segment_ids, feature.test_input.nwp_index, feature.test_input.input_mask, - feature.test_input.output_mask, + feature.test_input.output_mask_map, # support feature.support_input.token_ids, feature.support_input.segment_ids, feature.support_input.nwp_index, feature.support_input.input_mask, - feature.support_input.output_mask, 
+ feature.support_input.output_mask_map, # target feature.test_target_map, feature.support_target_map, From 22ab669fd6f3097890dd60688519991836fe1e9a Mon Sep 17 00:00:00 2001 From: laiyongkui Date: Wed, 1 Jul 2020 18:41:43 +0800 Subject: [PATCH 4/6] change smp data generator to more general --- scripts/gen_meta_data.sh | 4 +- .../generate_meta_dataset.py | 21 ++- .../meta_dataset_generator/raw_data_loader.py | 124 ++++++++++++------ 3 files changed, 100 insertions(+), 49 deletions(-) diff --git a/scripts/gen_meta_data.sh b/scripts/gen_meta_data.sh index 5350f2b..34182a1 100644 --- a/scripts/gen_meta_data.sh +++ b/scripts/gen_meta_data.sh @@ -44,8 +44,8 @@ use_fix_support=--use_fix_support # ======= default path (for quick distribution) ========== #input_dir=/Users/lyk/Work/Dialogue/FewShot/SMP正式数据集/ #output_dir=/Users/lyk/Work/Dialogue/FewShot/SMP正式数据集/SmpMetaData/ -input_dir=/Users/lyk/Work/Dialogue/FewShot/SMP/ -output_dir=/Users/lyk/Work/Dialogue/FewShot/SMP/smp2/ +input_dir=/Users/lyk/Work/Dialogue/FewShot/SMP/smp_data/ +output_dir=/Users/lyk/Work/Dialogue/FewShot/SMP/smp_data/smp_episode_data/ echo \[START\] set jobs on dataset \[ ${dataset_lst[@]} \] # === Loop for all case and run === diff --git a/scripts/other_tool/meta_dataset_generator/generate_meta_dataset.py b/scripts/other_tool/meta_dataset_generator/generate_meta_dataset.py index e5106b2..ddccb2b 100644 --- a/scripts/other_tool/meta_dataset_generator/generate_meta_dataset.py +++ b/scripts/other_tool/meta_dataset_generator/generate_meta_dataset.py @@ -229,18 +229,25 @@ def main(): print('Train: Few_shot_data gathered and start to dump data') few_shot_data_statistic(opt, train_meta_data) - """dev & test""" + """dev""" + dev_meta_data = generator.gen_data(raw_data['dev']) + print('Dev: Few_shot_data gathered and start to dump data') + few_shot_data_statistic(opt, dev_meta_data) + + """test""" if opt.use_fix_support: domains = raw_data['support'].keys() - dev_meta_data = {} + test_meta_data = {} for domain 
in domains: - dev_meta_data[domain] = [{'support': raw_data['support'][domain], 'query': raw_data['dev'][domain]}] + test_meta_data[domain] = [{'support': raw_data['support'][domain], + 'query': raw_data['test'][domain]}] else: - dev_meta_data = generator.gen_data(raw_data['dev']) - + test_meta_data = generator.gen_data(raw_data['test']) print('Dev: Few_shot_data gathered and start to dump data') - few_shot_data_statistic(opt, dev_meta_data) - meta_data = {'train': train_meta_data, 'dev': dev_meta_data, 'test': dev_meta_data} + few_shot_data_statistic(opt, test_meta_data) + + """meta data""" + meta_data = {'train': train_meta_data, 'dev': dev_meta_data, 'test': test_meta_data} else: meta_data = generator.gen_data(raw_data) print('Few_shot_data gathered and start to dump data') diff --git a/scripts/other_tool/meta_dataset_generator/raw_data_loader.py b/scripts/other_tool/meta_dataset_generator/raw_data_loader.py index 1bc977d..76ccb1b 100644 --- a/scripts/other_tool/meta_dataset_generator/raw_data_loader.py +++ b/scripts/other_tool/meta_dataset_generator/raw_data_loader.py @@ -343,11 +343,52 @@ def load_data(self, path: str, with_dev: bool = True): a dict store support data: {"partition/domain name" : { 'seq_ins':[], 'labels'[]:, 'seq_outs':[]}} """ print('Start loading SMP (Chinese) data from: ', path) + train_data = self.load_normal_data(os.path.join(path, 'train')) + dev_data = self.load_normal_data(os.path.join(path, 'dev')) + support_data, test_data = self.load_support_test_data(os.path.join(path, 'test')) + return {'train': train_data, 'dev': dev_data, 'support': support_data, 'test': test_data} + + # all_data = {} + # + # all_files = [os.path.join(path, 'train', filename) for filename in path if filename.endswith('.json')] + # + # for one_file in all_files: + # part_data = self.unpack_train_data(one_file) + # + # for domain, data in part_data.items(): + # if domain not in all_data: + # all_data[domain] = {"seq_ins": [], "seq_outs": [], "labels": []} + # 
all_data[domain]['seq_ins'].extend(part_data[domain]['seq_ins']) + # all_data[domain]['seq_outs'].extend(part_data[domain]['seq_outs']) + # all_data[domain]['labels'].extend(part_data[domain]['labels']) + # return all_data + + # dev_data, support_data = {}, {} + # if with_dev: + # dev_support_files = [os.path.join(path, 'dev/support', filename) + # for filename in os.listdir(os.path.join(path, 'dev/support')) + # if filename.endswith('.json')] + # support_data, support_text_set = self.unpack_support_data(dev_support_files) + # + # dev_all_files = [os.path.join(path, 'dev/correct', filename) + # for filename in os.listdir(os.path.join(path, 'dev/correct')) + # if filename.endswith('.json')] + # for one_file in dev_all_files: + # part_data = self.unpack_train_data(one_file, support_text_set) + # + # for domain, data in part_data.items(): + # if domain not in all_data: + # dev_data[domain] = {"seq_ins": [], "seq_outs": [], "labels": []} + # dev_data[domain]['seq_ins'].extend(part_data[domain]['seq_ins']) + # dev_data[domain]['seq_outs'].extend(part_data[domain]['seq_outs']) + # dev_data[domain]['labels'].extend(part_data[domain]['labels']) + # + # return {'train': all_data, 'dev': dev_data, 'support': support_data} + + def load_normal_data(self, path: str): all_data = {} - - all_files = [os.path.join(path, 'train', filename) - for filename in os.listdir(os.path.join(path, 'train')) if filename.endswith('.json')] - + all_files = [os.path.join(path, filename) for filename in os.listdir(path) if filename.endswith('.json')] + print('all_files: {} - {}'.format(path, all_files)) for one_file in all_files: part_data = self.unpack_train_data(one_file) @@ -357,28 +398,30 @@ def load_data(self, path: str, with_dev: bool = True): all_data[domain]['seq_ins'].extend(part_data[domain]['seq_ins']) all_data[domain]['seq_outs'].extend(part_data[domain]['seq_outs']) all_data[domain]['labels'].extend(part_data[domain]['labels']) + return all_data + + def load_support_test_data(self, 
path, support_folder_name='support', test_folder_name='test'): + support_files = [os.path.join(path, support_folder_name, filename) + for filename in os.listdir(os.path.join(path, support_folder_name)) + if filename.endswith('.json')] + support_data, support_text_set = self.unpack_support_data(support_files) - dev_data, support_data = {}, {} - if with_dev: - dev_support_files = [os.path.join(path, 'dev/support', filename) - for filename in os.listdir(os.path.join(path, 'dev/support')) - if filename.endswith('.json')] - support_data, support_text_set = self.unpack_support_data(dev_support_files) + test_files = [os.path.join(path, test_folder_name, filename) + for filename in os.listdir(os.path.join(path, test_folder_name)) + if filename.endswith('.json')] - dev_all_files = [os.path.join(path, 'dev/correct', filename) - for filename in os.listdir(os.path.join(path, 'dev/correct')) - if filename.endswith('.json')] - for one_file in dev_all_files: - part_data = self.unpack_train_data(one_file, support_text_set) + test_data = {} + for one_file in test_files: + part_data = self.unpack_train_data(one_file) - for domain, data in part_data.items(): - if domain not in all_data: - dev_data[domain] = {"seq_ins": [], "seq_outs": [], "labels": []} - dev_data[domain]['seq_ins'].extend(part_data[domain]['seq_ins']) - dev_data[domain]['seq_outs'].extend(part_data[domain]['seq_outs']) - dev_data[domain]['labels'].extend(part_data[domain]['labels']) + for domain, data in part_data.items(): + if domain not in test_data: + test_data[domain] = {"seq_ins": [], "seq_outs": [], "labels": []} + test_data[domain]['seq_ins'].extend(part_data[domain]['seq_ins']) + test_data[domain]['seq_outs'].extend(part_data[domain]['seq_outs']) + test_data[domain]['labels'].extend(part_data[domain]['labels']) - return {'train': all_data, 'dev': dev_data, 'support': support_data} + return support_data, test_data def unpack_support_data(self, all_data_path): support_data = {} @@ -428,24 +471,25 @@ def 
handle_one_utterance(self, item): text = item['text'].replace(' ', '') seq_in = list(text) - slots = item['slots'] seq_out = ['O'] * len(seq_in) - for slot_key, slot_value in slots.items(): - if not isinstance(slot_value, list): - slot_value = [slot_value] - for s_val in slot_value: - s_val = s_val.replace(' ', '') - if s_val in text: - s_idx = text.index(s_val) - s_end = s_idx + len(s_val) - seq_out[s_idx] = 'B-' + slot_key - for idx in range(s_idx + 1, s_end): - seq_out[idx] = 'I-' + slot_key - else: - print('text: {}'.format(text)) - print(' slot_key: {} - slot_value: {}'.format(slot_key, s_val)) + if 'slots' in item: + slots = item['slots'] + for slot_key, slot_value in slots.items(): + if not isinstance(slot_value, list): + slot_value = [slot_value] + for s_val in slot_value: + s_val = s_val.replace(' ', '') + if s_val in text: + s_idx = text.index(s_val) + s_end = s_idx + len(s_val) + seq_out[s_idx] = 'B-' + slot_key + for idx in range(s_idx + 1, s_end): + seq_out[idx] = 'I-' + slot_key + else: + print('text: {}'.format(text)) + print(' slot_key: {} - slot_value: {}'.format(slot_key, s_val)) - label = item['intent'] + label = item['intent'] if 'intent' in item else 'O' return seq_in, seq_out, label @@ -460,11 +504,11 @@ def handle_one_utterance(self, item): opt.dataset = 'smp' opt.label_type = 'intent' - smp_path = '/Users/lyk/Work/Dialogue/FewShot/SMP/' + smp_path = '/Users/lyk/Work/Dialogue/FewShot/SMP/smp_data/' smp_loader = SMPDataLoader(opt) smp_data = smp_loader.load_data(path=smp_path) - train_data, dev_data, support_data = smp_data['train'], smp_data['dev'], smp_data['support'] + train_data, dev_data, support_data, test_data = smp_data['train'], smp_data['dev'], smp_data['support'], smp_data['test'] print("train: smp domain number: {}".format(len(train_data))) print("train: all smp domain: {}".format(train_data.keys())) From cd2bdd6a459d59b82c015addc2f98abf6b2cb627 Mon Sep 17 00:00:00 2001 From: laiyongkui Date: Wed, 11 Nov 2020 19:39:38 +0800 
Subject: [PATCH 5/6] change to load new smp data --- main.py | 5 +++++ models/few_shot_learner.py | 2 -- scripts/other_tool/meta_dataset_generator/raw_data_loader.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 841e023..2f6f3dd 100644 --- a/main.py +++ b/main.py @@ -48,6 +48,9 @@ def get_training_data_and_feature(opt, data_loader, preprocessor): dev_features = preprocessor.construct_feature(dev_examples, dev_max_support_size, dev_label2id_map, dev_id2label_map) logger.info(' Finish prepare train dev features ') + if opt.do_debug: + print('train_label2id_map: {}'.format(train_label2id_map)) + print('dev_label2id_map: {}'.format(dev_label2id_map)) if opt.save_feature: save_feature(opt.train_path.replace('.json', '.saved.pk'), train_features, train_label2id_map, train_id2label_map) @@ -72,6 +75,8 @@ def get_testing_data_feature(opt, data_loader, preprocessor): test_features = preprocessor.construct_feature( test_examples, test_max_support_size, test_label2id_map, test_id2label_map) logger.info(' Finish prepare test feature') + if opt.do_debug: + print('test_label2id_map: {}'.format(test_label2id_map)) if opt.save_feature: save_feature(opt.test_path.replace('.json', '.saved.pk'), test_features, test_label2id_map, test_id2label_map) diff --git a/models/few_shot_learner.py b/models/few_shot_learner.py index c3963d4..899ec58 100644 --- a/models/few_shot_learner.py +++ b/models/few_shot_learner.py @@ -72,8 +72,6 @@ def forward( if self.training: loss = 0. 
for task in self.opt.task: - if self.opt.do_debug: - print('task: {} - test_target: {} - {}'.format(task, test_target_map[task].size(), test_target_map[task])) if task == 'sl': loss += self.seq_labeler_model(reps_map[task]['test'], test_output_mask_map[task], reps_map[task]['support'], support_output_mask_map[task], diff --git a/scripts/other_tool/meta_dataset_generator/raw_data_loader.py b/scripts/other_tool/meta_dataset_generator/raw_data_loader.py index 76ccb1b..8d26a2a 100644 --- a/scripts/other_tool/meta_dataset_generator/raw_data_loader.py +++ b/scripts/other_tool/meta_dataset_generator/raw_data_loader.py @@ -400,7 +400,7 @@ def load_normal_data(self, path: str): all_data[domain]['labels'].extend(part_data[domain]['labels']) return all_data - def load_support_test_data(self, path, support_folder_name='support', test_folder_name='test'): + def load_support_test_data(self, path, support_folder_name='support', test_folder_name='correct'): support_files = [os.path.join(path, support_folder_name, filename) for filename in os.listdir(os.path.join(path, support_folder_name)) if filename.endswith('.json')] From 4a8f8cb592dc46f50caf6316471d92f207cbe4e0 Mon Sep 17 00:00:00 2001 From: laiyongkui Date: Fri, 20 Nov 2020 11:30:21 +0800 Subject: [PATCH 6/6] fix bug: select true base vector in TapNet --- models/modules/similarity_scorer_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/modules/similarity_scorer_base.py b/models/modules/similarity_scorer_base.py index ca62d05..94207b8 100644 --- a/models/modules/similarity_scorer_base.py +++ b/models/modules/similarity_scorer_base.py @@ -403,9 +403,9 @@ def forward( s.append(s_) vh.append(vh_) s, vh = torch.stack(s, dim=0), torch.stack(vh, dim=0) - s_sum = (s >= 1e-13).sum(dim=1) + s_sum = max((s >= 1e-13).sum(dim=1)) # shape (batch_size, emb_dim, D) - M = torch.stack([torch.transpose(vh[i][s_sum[i]:].clone(), 0, 1) for i in range(batch_size)], dim=0) + M = torch.stack([vh[i][:, 
s_sum:].clone() for i in range(batch_size)], dim=0) '''get final test data reps''' # project query data embedding with a MLP