89 changes: 51 additions & 38 deletions main.py
@@ -29,49 +29,58 @@ def get_training_data_and_feature(opt, data_loader, preprocessor):
     """ prepare feature and data """
     if opt.load_feature:
         try:
-            train_features, train_label2id, train_id2label = load_feature(opt.train_path.replace('.json', '.saved.pk'))
-            dev_features, dev_label2id, dev_id2label = load_feature(opt.dev_path.replace('.json', '.saved.pk'))
+            train_features, train_label2id_map, train_id2label_map = \
+                load_feature(opt.train_path.replace('.json', '.saved.pk'))
+            dev_features, dev_label2id_map, dev_id2label_map = load_feature(opt.dev_path.replace('.json', '.saved.pk'))
         except FileNotFoundError:
             opt.load_feature, opt.save_feature = False, True  # Not a saved feature file yet, make it
-            train_features, train_label2id, train_id2label, dev_features, dev_label2id, dev_id2label =\
+            train_features, train_label2id_map, train_id2label_map, dev_features, dev_label2id_map, dev_id2label_map =\
                 get_training_data_and_feature(opt, data_loader, preprocessor)
             opt.load_feature, opt.save_feature = True, False  # restore option
     else:
         train_examples, train_max_len, train_max_support_size = data_loader.load_data(path=opt.train_path)
         dev_examples, dev_max_len, dev_max_support_size = data_loader.load_data(path=opt.dev_path)
-        train_label2id, train_id2label = make_dict(opt, train_examples)
-        dev_label2id, dev_id2label = make_dict(opt, dev_examples)
+        train_label2id_map, train_id2label_map = make_dict(opt, train_examples)
+        dev_label2id_map, dev_id2label_map = make_dict(opt, dev_examples)
         logger.info(' Finish train dev prepare dict ')
         train_features = preprocessor.construct_feature(
-            train_examples, train_max_support_size, train_label2id, train_id2label)
-        dev_features = preprocessor.construct_feature(dev_examples, dev_max_support_size, dev_label2id, dev_id2label)
+            train_examples, train_max_support_size, train_label2id_map, train_id2label_map)
+        dev_features = preprocessor.construct_feature(dev_examples, dev_max_support_size,
+                                                      dev_label2id_map, dev_id2label_map)
         logger.info(' Finish prepare train dev features ')

+    if opt.do_debug:
+        print('train_label2id_map: {}'.format(train_label2id_map))
+        print('dev_label2id_map: {}'.format(dev_label2id_map))
     if opt.save_feature:
-        save_feature(opt.train_path.replace('.json', '.saved.pk'), train_features, train_label2id, train_id2label)
-        save_feature(opt.dev_path.replace('.json', '.saved.pk'), dev_features, dev_label2id, dev_id2label)
-    return train_features, train_label2id, train_id2label, dev_features, dev_label2id, dev_id2label
+        save_feature(opt.train_path.replace('.json', '.saved.pk'),
+                     train_features, train_label2id_map, train_id2label_map)
+        save_feature(opt.dev_path.replace('.json', '.saved.pk'), dev_features, dev_label2id_map, dev_id2label_map)
+    return train_features, train_label2id_map, train_id2label_map, dev_features, dev_label2id_map, dev_id2label_map


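Review note: both prepare functions follow the same cache-or-rebuild pattern — try to load a `.saved.pk` feature file, and on `FileNotFoundError` temporarily flip `load_feature`/`save_feature` and recurse so the features are built once and cached. The sketch below is a minimal standalone illustration of that pattern, not the repo's code: `get_features`, `build_fn`, and `opt.cache_path` are hypothetical names, and the pickle-based `load_feature`/`save_feature` bodies are an assumption about what the real helpers roughly do.

```python
import pickle


def load_feature(path):
    # Hypothetical stand-in for the repo's load_feature helper.
    with open(path, 'rb') as f:  # raises FileNotFoundError on a cache miss
        return pickle.load(f)


def save_feature(path, payload):
    # Hypothetical stand-in for the repo's save_feature helper.
    with open(path, 'wb') as f:
        pickle.dump(payload, f)


def get_features(opt, build_fn):
    """Cache-or-rebuild: return cached features, or build and cache them once."""
    if opt.load_feature:
        try:
            return load_feature(opt.cache_path)
        except FileNotFoundError:
            opt.load_feature, opt.save_feature = False, True   # no cache yet: build it
            result = get_features(opt, build_fn)
            opt.load_feature, opt.save_feature = True, False   # restore options
            return result
    result = build_fn(opt)           # expensive feature construction
    if opt.save_feature:
        save_feature(opt.cache_path, result)
    return result
```

The recursion is only ever one level deep: the inner call takes the `else` branch, builds, saves, and returns.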
 def get_testing_data_feature(opt, data_loader, preprocessor):
     """ prepare feature and data """
     if opt.load_feature:
         try:
-            test_features, test_label2id, test_id2label = load_feature(opt.test_path.replace('.json', '.saved.pk'))
+            test_features, test_label2id_map, test_id2label_map = \
+                load_feature(opt.test_path.replace('.json', '.saved.pk'))
         except FileNotFoundError:
             opt.load_feature, opt.save_feature = False, True  # Not a saved feature file yet, make it
-            test_features, test_label2id, test_id2label = get_testing_data_feature(opt, data_loader, preprocessor)
+            test_features, test_label2id_map, test_id2label_map = get_testing_data_feature(opt, data_loader, preprocessor)
             opt.load_feature, opt.save_feature = True, False  # restore option
     else:
         test_examples, test_max_len, test_max_support_size = data_loader.load_data(path=opt.test_path)
-        test_label2id, test_id2label = make_dict(opt, test_examples)
+        test_label2id_map, test_id2label_map = make_dict(opt, test_examples)
         logger.info(' Finish prepare test dict')
         test_features = preprocessor.construct_feature(
-            test_examples, test_max_support_size, test_label2id, test_id2label)
+            test_examples, test_max_support_size, test_label2id_map, test_id2label_map)
         logger.info(' Finish prepare test feature')
+    if opt.do_debug:
+        print('test_label2id_map: {}'.format(test_label2id_map))
     if opt.save_feature:
-        save_feature(opt.test_path.replace('.json', '.saved.pk'), test_features, test_label2id, test_id2label)
-    return test_features, test_label2id, test_id2label
+        save_feature(opt.test_path.replace('.json', '.saved.pk'),
+                     test_features, test_label2id_map, test_id2label_map)
+    return test_features, test_label2id_map, test_id2label_map


 def main():
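Review note: the `label2id`/`id2label` pairs are renamed to `*_label2id_map`/`*_id2label_map` here, and the next hunk indexes them with a task key (`train_label2id_map['sl']`). That suggests `make_dict` now returns one label mapping per task rather than a single flat dict. A minimal sketch of that assumed shape — the task names and labels below are purely illustrative, not taken from this PR:

```python
# Assumed per-task shape of the renamed maps; 'sl' = slot labeling.
# The 'intent' entry is a hypothetical second task used only for illustration.
label2id_map = {
    'sl': {'O': 0, 'B-loc': 1, 'I-loc': 2},
    'intent': {'inform': 0, 'request': 1},
}

# id2label is simply the inverse mapping, built per task.
id2label_map = {
    task: {idx: label for label, idx in mapping.items()}
    for task, mapping in label2id_map.items()
}

assert id2label_map['sl'][1] == 'B-loc'
print(label2id_map['sl'])  # {'O': 0, 'B-loc': 1, 'I-loc': 2}
```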
@@ -92,30 +101,32 @@ def main():
     data_loader = FewShotRawDataLoader(opt)
     preprocessor = make_preprocessor(opt)
     if opt.do_train:
-        train_features, train_label2id, train_id2label, dev_features, dev_label2id, dev_id2label = \
+        train_features, train_label2id_map, train_id2label_map, dev_features, dev_label2id_map, dev_id2label_map = \
             get_training_data_and_feature(opt, data_loader, preprocessor)

-        if opt.mask_transition and opt.task == 'sl':
-            opt.train_label_mask = make_label_mask(opt, opt.train_path, train_label2id)
-            opt.dev_label_mask = make_label_mask(opt, opt.dev_path, dev_label2id)
+        if opt.mask_transition and 'sl' in opt.task:
+            opt.train_label_mask = make_label_mask(opt, opt.train_path, train_label2id_map['sl'])
+            opt.dev_label_mask = make_label_mask(opt, opt.dev_path, dev_label2id_map['sl'])
     else:
-        train_features, train_label2id, train_id2label, dev_features, dev_label2id, dev_id2label = [None] * 6
-        if opt.mask_transition and opt.task == 'sl':
+        train_features, train_label2id_map, train_id2label_map, dev_features, dev_label2id_map, dev_id2label_map = \
+            [None] * 6
+        if opt.mask_transition and 'sl' in opt.task:
             opt.train_label_mask = None
             opt.dev_label_mask = None

     if opt.do_predict:
-        test_features, test_label2id, test_id2label = get_testing_data_feature(opt, data_loader, preprocessor)
-        if opt.mask_transition and opt.task == 'sl':
-            opt.test_label_mask = make_label_mask(opt, opt.test_path, test_label2id)
+        test_features, test_label2id_map, test_id2label_map = get_testing_data_feature(opt, data_loader, preprocessor)
+        if opt.mask_transition and 'sl' in opt.task:
+            opt.test_label_mask = make_label_mask(opt, opt.test_path, test_label2id_map['sl'])
     else:
-        test_features, test_label2id, test_id2label = [None] * 3
-        if opt.mask_transition and opt.task == 'sl':
+        test_features, test_label2id_map, test_id2label_map = [None] * 3
+        if opt.mask_transition and 'sl' in opt.task:
             opt.test_label_mask = None

     ''' over fitting test '''
     if opt.do_overfit_test:
-        test_features, test_label2id, test_id2label = train_features, train_label2id, train_id2label
-        dev_features, dev_label2id, dev_id2label = train_features, train_label2id, train_id2label
+        test_features, test_label2id_map, test_id2label_map = train_features, train_label2id_map, train_id2label_map
+        dev_features, dev_label2id_map, dev_id2label_map = train_features, train_label2id_map, train_id2label_map

     ''' select training & testing mode '''
     trainer_class = SchemaFewShotTrainer if opt.use_schema else FewShotTrainer
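Review note: the guard changes from `opt.task == 'sl'` to `'sl' in opt.task`, presumably so `opt.task` can name more than one task at once (for example a joint slot-labeling plus classification setting). A quick illustration of the membership test with made-up task values — note that on a plain string it is a substring check:

```python
# 'sl' in opt.task behaves differently depending on how opt.task is stored.
for task in ['sl', 'sl+sc', ['sl', 'sc'], 'sc', ['sc']]:
    print(repr(task), '->', 'sl' in task)
# 'sl'          -> True   (exact match)
# 'sl+sc'       -> True   (substring of a joint-task string)
# ['sl', 'sc']  -> True   (list membership)
# 'sc'          -> False
# ['sc']        -> False
```

If a future task name ever contained `sl` as a substring, an exact membership check on a list of task names would be safer; with the values used here the substring check appears harmless.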
@@ -130,9 +141,10 @@ def main():
             opt = training_model.opt
             opt.warmup_epoch = -1
         else:
-            training_model = make_model(opt, config={'num_tags': len(train_label2id)})
+            training_model = make_model(opt, config={
+                'num_tags': len(train_label2id_map['sl']) if 'sl' in train_label2id_map else 0})
         training_model = prepare_model(opt, training_model, device, n_gpu)
-        if opt.mask_transition and opt.task == 'sl':
+        if opt.mask_transition and 'sl' in opt.task:
             training_model.label_mask = opt.train_label_mask.to(device)
         # prepare a set of name subsequences/marks used to apply a different learning rate to part of the params
         upper_structures = [
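Review note on the `make_model` change in this hunk: `num_tags` is now taken from the `'sl'` entry of the per-task label map, with a fallback of 0 when no slot-labeling task is configured. A tiny worked example, using illustrative label maps (the real ones come from `make_dict`):

```python
# Illustrative per-task label maps, not taken from this PR.
with_sl = {'sl': {'O': 0, 'B-loc': 1, 'I-loc': 2}}
without_sl = {'intent': {'inform': 0, 'request': 1}}

for label2id_map in (with_sl, without_sl):
    num_tags = len(label2id_map['sl']) if 'sl' in label2id_map else 0
    print(num_tags)
# -> 3, then 0
```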
@@ -153,12 +165,13 @@ def main():
             print('========== Warmup training finished! ==========')
         trained_model, best_dev_score, test_score = trainer.do_train(
             training_model, train_features, opt.num_train_epochs,
-            dev_features, dev_id2label, test_features, test_id2label, best_dev_score_now=0)
+            dev_features, dev_id2label_map, test_features, test_id2label_map, best_dev_score_now=0)

         # decide the best model
         if not opt.eval_when_train:  # select best among check points
             best_model, best_score, test_score_then = trainer.select_model_from_check_point(
-                train_id2label, dev_features, dev_id2label, test_features, test_id2label, rm_cpt=opt.delete_checkpoint)
+                train_id2label_map, dev_features, dev_id2label_map, test_features, test_id2label_map,
+                rm_cpt=opt.delete_checkpoint)
         else:  # best model is selected during training
             best_model = trained_model
         logger.info('dev:{}, test:{}'.format(best_dev_score, test_score))
@@ -173,17 +186,17 @@ def main():
             if not opt.saved_model_path or not os.path.exists(opt.saved_model_path):
                 raise ValueError("No model trained and no trained model file given (or not exist)")
             if os.path.isdir(opt.saved_model_path):  # eval a list of checkpoints
-                max_score = eval_check_points(opt, tester, test_features, test_id2label, device)
+                max_score = eval_check_points(opt, tester, test_features, test_id2label_map, device)
                 print('best check points scores:{}'.format(max_score))
                 exit(0)
             else:
                 best_model = load_model(opt.saved_model_path)

         ''' test the best model '''
-        testing_model = tester.clone_model(best_model, test_id2label)  # copy reusable params
-        if opt.mask_transition and opt.task == 'sl':
+        testing_model = tester.clone_model(best_model, test_id2label_map)  # copy reusable params
+        if opt.mask_transition and 'sl' in opt.task:
             testing_model.label_mask = opt.test_label_mask.to(device)
-        test_score = tester.do_test(testing_model, test_features, test_id2label, log_mark='test_pred')
+        test_score = tester.do_test(testing_model, test_features, test_id2label_map, log_mark='test_pred')
         logger.info('test:{}'.format(test_score))
         print('test:{}'.format(test_score))
