-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain_debug.py
More file actions
58 lines (52 loc) · 2.65 KB
/
Copy pathmain_debug.py
File metadata and controls
58 lines (52 loc) · 2.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
main_debug.py is a debugging script: it runs a quick training pass (fast_dev_run) to smoke-test the model and data pipeline.
e.g. https://github.com/wisdomify/wisdomify/blob/main/main_train.py
"""
import torch
import random
import argparse
import numpy as np
import pytorch_lightning as pl
from transformers import BertTokenizer, BertModel
from BERT.datamodules import AnmSourceNERDataModule, SourceNERDataModule, AnmNERDataModule
from BERT.labels import ANM_LABELS, SOURCE_LABELS
from BERT.loaders import load_config
from BERT.models import BiLabelNER, MonoLabelNER, BiLabelNERWithBiLSTM
def main():
    """Run a quick debugging pass of NER model training.

    Reads ``--model`` and ``--ver`` from the command line, loads the matching
    config via ``load_config``, seeds torch / random / numpy from
    ``config['seed']``, builds the model and datamodule selected by
    ``config['model']`` (and ``config['label_type']`` for the mono-label
    model), and fits them with a ``fast_dev_run`` trainer.

    Raises:
        ValueError: if ``config['model']`` is not a known model name, or if
            the mono-label model is selected with a ``label_type`` other
            than ``"anm"`` or ``"source"``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--model", type=str, default="mono_label_ner")
    arg_parser.add_argument("--ver", type=str, default="test_anm")
    cli_args = arg_parser.parse_args()

    config = load_config(cli_args.model, cli_args.ver)
    # Record the command-line arguments in the config as well.
    config.update(vars(cli_args))

    # --- fix random seeds --- #
    seed = config['seed']
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    bert_name = config['bert']
    tokenizer = BertTokenizer.from_pretrained(bert_name)
    bert = BertModel.from_pretrained(bert_name)

    model_name = config['model']
    if model_name == BiLabelNER.name or model_name == BiLabelNERWithBiLSTM.name:
        # Both bi-label variants take the same constructor arguments and
        # share the same datamodule; only the model class differs.
        if model_name == BiLabelNER.name:
            bi_label_cls = BiLabelNER
        else:
            bi_label_cls = BiLabelNERWithBiLSTM
        model = bi_label_cls(
            bert=bert,
            lr=float(config['lr']),
            num_labels_pair=(len(ANM_LABELS), len(SOURCE_LABELS)),
        )
        datamodule = AnmSourceNERDataModule(config, tokenizer)
    elif model_name == MonoLabelNER.name:
        # Resolve the label set and datamodule first so that an unsupported
        # label_type is rejected before any model is constructed.
        label_type = config['label_type']
        if label_type == "anm":
            labels, datamodule_cls = ANM_LABELS, AnmNERDataModule
        elif label_type == "source":
            labels, datamodule_cls = SOURCE_LABELS, SourceNERDataModule
        else:
            raise ValueError(f"Invalid label_type: {config['label_type']}")
        model = MonoLabelNER(
            bert=bert,
            lr=float(config['lr']),
            num_labels=len(labels),
            hidden_size=bert.config.hidden_size,
        )
        datamodule = datamodule_cls(config, tokenizer)
    else:
        raise ValueError(f"Invalid model: {config['model']}")

    trainer = pl.Trainer(
        # Debug mode: only a minimal pass is run (a single batch per the
        # PyTorch Lightning docs) and no model checkpoint is saved.
        fast_dev_run=True,
        # Use every CUDA device torch can see (0 when none is available).
        gpus=torch.cuda.device_count(),
    )
    trainer.fit(model=model, datamodule=datamodule)
# Run the debugging pass only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()