-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathdemo_relation.py
More file actions
71 lines (61 loc) · 2.4 KB
/
demo_relation.py
File metadata and controls
71 lines (61 loc) · 2.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from Text_Annotation.annotate import annotate, find_relation
from Text_Annotation.train import train_relation
from keras.models import load_model
import numpy as np
import os
# Absolute directory containing this script; all model/data paths are built from it.
DIR = os.path.dirname(os.path.abspath(__file__))

# Settings forwarded to annotate() through **params. The 'model' entry is also
# used to pick the sub-directory the annotation model is loaded from.
params = {
    'model': 'crf',
    'num_units': 128,
    'num_layers': 2,
    'num_tags': 10,
}

# Optional training step, kept for reference. Uncomment to (re)build the
# relation model before running the demo.
# train_x = np.load(DIR + '/data/train_x.npy')
# train_y = np.load(DIR + '/data/train_y.npy')
#
# # SVM variant:
# # model = train_relation(x=train_x, y=train_y, method='SVM',
# #                        model_path=DIR + '/model/relation/SVM.model')
# # Deep-learning variant:
# model = train_relation(x=train_x, y=train_y, num_tag=3,
#                        batchsize=64, epoch=1,
#                        method='DL', model_path=DIR + '/model/relation/DL.h5')

# Relation-classification backend and its saved Keras model.
method = 'DL'
relation_model_path = DIR + '/model/relation/DL.h5'
model = load_model(relation_model_path)

# Run one throw-away prediction on a (1, 512) array of ones right after
# loading: without using the Keras model once up front, later calls fail with
# a computation-graph error.
m = model.predict(np.ones((1, 512)))

# Each entry pairs a category label with a list of tag ids.
# NOTE(review): presumably maps annotation tag ids (1-9) to coarse word
# classes consumed by find_relation -- confirm against Text_Annotation.annotate.
regulations = [
    ['n', [1]],
    ['n', [2, 3, 4]],
    ['v', [5]],
    ['v', [6, 7, 8]],
    ['U', [9]],
]

# Ordered pairs of category labels handed to find_relation.
# NOTE(review): looks like the entity-category pairings that are eligible for
# a relation -- confirm against find_relation.
regular = [['v', 'n'], ['n', 'v']]
# Interactive demo: read a sentence from stdin, annotate it, then extract the
# relations between the recognised entities and print both.
#
# The annotation model locations depend only on the configured model type, so
# they are resolved once here rather than on every loop iteration.
pos_model_dir = DIR + '/model/%s/model_pos/' % (params['model'])
data_process_path = DIR + '/model/%s/model_pos/data_process.pkl' % (params['model'])

while True:
    print('\n使用前请确保有模型。输入文本,quit=离开;\n请输入命令:')
    try:
        text = input()
    except (EOFError, KeyboardInterrupt):
        # stdin was closed (e.g. piped input ran out) or the user pressed
        # Ctrl-C: leave the same way an explicit "quit" does instead of
        # dying with a traceback.
        print('\n再见!')
        break

    command = text.strip()
    if command == 'quit':
        print('\n再见!')
        break
    if not command:
        # Nothing to analyse; prompt again rather than feeding an empty
        # sentence to the models.
        continue

    # Example sentence to try: 医疗机构变更单位名称、法定代表人或负责人
    y_predict, output_fb = annotate(text=text,
                                    batchsize=1,
                                    data_process_path=data_process_path,
                                    model_path=pos_model_dir,
                                    train=False,
                                    **params)

    # Only element 0 of each annotate() output is used; the call above runs
    # with batchsize=1 on a single input text.
    result = find_relation(text=text,
                           sentence_vector=output_fb[0],
                           regulations=regulations,
                           regular=regular,
                           annotation=y_predict[0],
                           model=model,
                           # use the script-level setting instead of
                           # repeating the literal
                           method=method,
                           tags=['null', '主谓', '动宾'])

    print('\n实体识别:\n')
    for entity in result['entities']:
        print(entity)

    print('\n关系抽取:\n')
    for relation in result['relations']:
        print('实体1', relation['entity1'])
        print('实体2', relation['entity2'])
        print('关系', relation['relation'])
        print('\n')
    # To inspect the raw result dict, print `result` here.