Patch summary (free text ahead of the first diff header; patch tools skip it):
  * scripts/gen_meta_data.sh: switch generation to the snips dataset
    (5 support shots, 100 episodes, 20 query shots, split_basis=sent_label),
    add config_dir, and run the generator once per id in eval_config_id_lst,
    passing the matching config file and id.
  * generate_meta_dataset.py: split_eval_set_with_label reads opt.eval_config
    and iterates raw_data.items().
Review fixes folded into the added lines:
  * the generator is invoked as generate_meta_dataset.py (the file this patch
    also modifies), not the misspelled generate_meta_datase.py;
  * the log file name carries cfg_<eval_config_id>, so the '>' redirect inside
    the new loop no longer truncates the log of the previous config id.
NOTE(review): leading indentation and the position of blank context lines were
not recoverable and are reconstructed to satisfy the hunk line counts. Verify
against the target files or apply with whitespace-insensitive matching. The
index hashes below are those of the original patch.

diff --git a/scripts/gen_meta_data.sh b/scripts/gen_meta_data.sh
index ac3bdb8..5cd477c 100644
--- a/scripts/gen_meta_data.sh
+++ b/scripts/gen_meta_data.sh
@@ -3,16 +3,16 @@ echo usage: pass dataset list as param, split with space
 echo eg: source gen_mate_data.sh atis
 
 
-dataset_lst=(smp)
+dataset_lst=(snips)
 #dataset_lst=(toursg)
 #dataset_lst=(stanford)
 #dataset_lst=(atis stanford toursg)
 
 # ======= size setting ======
-support_shots_lst=(1)
-#support_shots_lst=(5)
-episode_num=50  # We could over generation and select part of for each epoch
-query_shot=4
+#support_shots_lst=(1)
+support_shots_lst=(5)
+episode_num=100  # We could over generation and select part of for each epoch
+query_shot=20
 word_piece_data=True
 way=-1
 
@@ -34,9 +34,9 @@ allow_override=--allow_override
 check=--check
 
 # ====== train & test setting ======
-split_basis=domain
+split_basis=sent_label
 
-#eval_confif_id_lst=(1)  # for snips
+eval_config_id_lst=(0 1 2 3 4 5 6)  # for snips
 #eval_config_id_lst=(0 1 2 3 4 5)  # for toursg
 
 label_type_lst=(attribute)
@@ -46,8 +46,9 @@ use_fix_support=--use_fix_support
 # ======= default path (for quick distribution) ==========
 #input_dir=/Users/lyk/Work/Dialogue/FewShot/SMP正式数据集/
 #output_dir=/Users/lyk/Work/Dialogue/FewShot/SMP正式数据集/SmpMetaData/
-input_dir=/Users/lyk/Work/Dialogue/FewShot/SMP/SMP_Final_Origin2
-output_dir=/Users/lyk/Work/Dialogue/FewShot/SMP/smp_final/
+input_dir=/Users/niloofarasadi/PycharmProjects/MetaDialog/data/original/snips/
+output_dir=/Users/niloofarasadi/PycharmProjects/MetaDialog/data/k-shot/
+config_dir=/Users/niloofarasadi/PycharmProjects/MetaDialog/data/config/
 
 now=$(date +%s)
 
@@ -59,30 +60,30 @@ do
     do
         for support_shots in ${support_shots_lst[@]}
         do
+            for eval_config_id in ${eval_config_id_lst[@]}
+            do
             echo \[CLI\] generate with \[ ${use_fix_support} \]
             input_path=${input_dir}
             mark=try
 
-            export OMP_NUM_THREADS=2  # threads num for each task
-            python3 ./other_tool/meta_dataset_generator/generate_meta_dataset.py \
-                --input_path ${input_path} \
-                --output_dir ${output_dir} \
-                --dataset ${dataset} \
-                --episode_num ${episode_num} \
-                --support_shots ${support_shots} \
-                --query_shot ${query_shot} \
-                --way ${way} \
-                --task ${task} \
-                --seed ${seed} \
-                --split_basis ${split_basis} \
-                --remove_rate ${remove_rate} \
-                ${use_fix_support} \
-                --mark ${mark} ${dup_query} ${allow_override} ${check} > ${output_dir}${dataset}.spt_s_${support_shots}.q_s_${query_shot}.ep_${episode_num}${use_fix_support}-${now}.log
+                export OMP_NUM_THREADS=2  # threads num for each task
+                python3 ./other_tool/meta_dataset_generator/generate_meta_dataset.py \
+                    --input_path ${input_path} \
+                    --output_dir ${output_dir} \
+                    --dataset ${dataset} \
+                    --episode_num ${episode_num} \
+                    --support_shots ${support_shots} \
+                    --query_shot ${query_shot} \
+                    --way ${way} \
+                    --task ${task} \
+                    --seed ${seed} \
+                    --split_basis ${split_basis} \
+                    --remove_rate ${remove_rate} \
+                    ${use_fix_support} \
+                    --eval_config ${config_dir}config${eval_config_id}.json \
+                    --eval_config_id ${eval_config_id} \
+                    --mark ${mark} ${dup_query} ${allow_override} ${check} > ${output_dir}${dataset}.spt_s_${support_shots}.q_s_${query_shot}.ep_${episode_num}.cfg_${eval_config_id}${use_fix_support}-${now}.log
+            done
         done
     done
 done
-
-
-
-
-
diff --git a/scripts/other_tool/meta_dataset_generator/generate_meta_dataset.py b/scripts/other_tool/meta_dataset_generator/generate_meta_dataset.py
index 0e34fd0..d8233fe 100644
--- a/scripts/other_tool/meta_dataset_generator/generate_meta_dataset.py
+++ b/scripts/other_tool/meta_dataset_generator/generate_meta_dataset.py
@@ -98,13 +98,13 @@ def split_eval_set_with_label(opt, raw_data):
     """ Both input and output are raw data format """
     try:
         # load config:
-        with open(opt.eval_labels, 'r') as reader:
+        with open(opt.eval_config, 'r') as reader:
             config = json.load(reader)
         # split labels
         train_data = {'seq_ins': [], 'labels': [], 'seq_outs': []}
         dev_data = {'seq_ins': [], 'labels': [], 'seq_outs': []}
         test_data = {'seq_ins': [], 'labels': [], 'seq_outs': []}
-        for old_domain_name, old_domain in raw_data:
+        for old_domain_name, old_domain in raw_data.items():
             for ind in range(len(old_domain['label'])):
                 seq_ins, seq_outs, labels = old_domain['seq_ins'][ind], old_domain['seq_outs'][ind], old_domain['labels'][ind]
                 # for label in labels: