Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 30 additions & 29 deletions scripts/gen_meta_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
echo usage: pass dataset list as param, split with space
echo eg: source gen_mate_data.sh atis

dataset_lst=(smp)
dataset_lst=(snips)
#dataset_lst=(toursg)
#dataset_lst=(stanford)
#dataset_lst=(atis stanford toursg)

# ======= size setting ======
support_shots_lst=(1)
#support_shots_lst=(5)
episode_num=50 # We could over generation and select part of for each epoch
query_shot=4
#support_shots_lst=(1)
support_shots_lst=(5)
episode_num=100 # We could over generation and select part of for each epoch
query_shot=20
word_piece_data=True
way=-1

Expand All @@ -34,9 +34,9 @@ allow_override=--allow_override
check=--check

# ====== train & test setting ======
split_basis=domain
split_basis=sent_label

#eval_confif_id_lst=(1) # for snips
eval_config_id_lst=(0 1 2 3 4 5 6) # for snips
#eval_config_id_lst=(0 1 2 3 4 5) # for toursg
label_type_lst=(attribute)

Expand All @@ -46,8 +46,9 @@ use_fix_support=--use_fix_support
# ======= default path (for quick distribution) ==========
#input_dir=/Users/lyk/Work/Dialogue/FewShot/SMP正式数据集/
#output_dir=/Users/lyk/Work/Dialogue/FewShot/SMP正式数据集/SmpMetaData/
input_dir=/Users/lyk/Work/Dialogue/FewShot/SMP/SMP_Final_Origin2
output_dir=/Users/lyk/Work/Dialogue/FewShot/SMP/smp_final/
input_dir=/Users/niloofarasadi/PycharmProjects/MetaDialog/data/original/snips/
output_dir=/Users/niloofarasadi/PycharmProjects/MetaDialog/data/k-shot/
config_dir=/Users/niloofarasadi/PycharmProjects/MetaDialog/data/config/

now=$(date +%s)

Expand All @@ -59,30 +60,30 @@ do
do
for support_shots in ${support_shots_lst[@]}
do
for eval_config_id in ${eval_config_id_lst[@]}
do
echo \[CLI\] generate with \[ ${use_fix_support} \]
input_path=${input_dir}
mark=try

export OMP_NUM_THREADS=2 # threads num for each task
python3 ./other_tool/meta_dataset_generator/generate_meta_dataset.py \
--input_path ${input_path} \
--output_dir ${output_dir} \
--dataset ${dataset} \
--episode_num ${episode_num} \
--support_shots ${support_shots} \
--query_shot ${query_shot} \
--way ${way} \
--task ${task} \
--seed ${seed} \
--split_basis ${split_basis} \
--remove_rate ${remove_rate} \
${use_fix_support} \
--mark ${mark} ${dup_query} ${allow_override} ${check} > ${output_dir}${dataset}.spt_s_${support_shots}.q_s_${query_shot}.ep_${episode_num}${use_fix_support}-${now}.log
export OMP_NUM_THREADS=2 # threads num for each task
python3 ./other_tool/meta_dataset_generator/generate_meta_datase.py \
--input_path ${input_path} \
--output_dir ${output_dir} \
--dataset ${dataset} \
--episode_num ${episode_num} \
--support_shots ${support_shots} \
--query_shot ${query_shot} \
--way ${way} \
--task ${task} \
--seed ${seed} \
--split_basis ${split_basis} \
--remove_rate ${remove_rate} \
${use_fix_support} \
--eval_config ${config_dir}config${eval_config_id}.json\
--eval_config_id ${eval_config_id}\
--mark ${mark} ${dup_query} ${allow_override} ${check} > ${output_dir}${dataset}.spt_s_${support_shots}.q_s_${query_shot}.ep_${episode_num}${use_fix_support}-${now}.log
done
done
done
done





Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,13 @@ def split_eval_set_with_label(opt, raw_data):
""" Both input and output are raw data format """
try:
# load config:
with open(opt.eval_labels, 'r') as reader:
with open(opt.eval_config, 'r') as reader:
config = json.load(reader)
# split labels
train_data = {'seq_ins': [], 'labels': [], 'seq_outs': []}
dev_data = {'seq_ins': [], 'labels': [], 'seq_outs': []}
test_data = {'seq_ins': [], 'labels': [], 'seq_outs': []}
for old_domain_name, old_domain in raw_data:
for old_domain_name, old_domain in raw_data.items():
for ind in range(len(old_domain['label'])):
seq_ins, seq_outs, labels = old_domain['seq_ins'][ind], old_domain['seq_outs'][ind], old_domain['labels'][ind]
# for label in labels:
Expand Down