-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathquery_processor.py
More file actions
308 lines (255 loc) · 11.5 KB
/
query_processor.py
File metadata and controls
308 lines (255 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
#!/usr/bin/env python3
"""
智能查询预处理器
解决中文分词、查询扩展和实体识别问题
"""
import re
import warnings
# 过滤 pkg_resources 弃用警告
warnings.filterwarnings("ignore", category=DeprecationWarning, module=".*pkg_resources.*")
warnings.filterwarnings("ignore", message=".*pkg_resources.*")
import jieba
from typing import List, Dict, Any, Set, Optional
from dataclasses import dataclass
from loguru import logger
@dataclass
class ProcessedQuery:
    """Structured result of query preprocessing.

    Produced by ChineseQueryProcessor.process_query; downstream search
    code consumes the keyword/entity/expansion fields to build queries.
    """
    original: str  # raw query text exactly as received
    normalized: str  # whitespace-collapsed, punctuation-unified, de-colloquialized form
    keywords: List[str]  # jieba-segmented terms with stop words and punctuation removed
    entities: List[str]  # regex-matched entity strings (English proper nouns, Chinese orgs/tech)
    expanded_terms: List[str]  # keywords/entities plus synonyms and related vocabulary
    query_type: str  # heuristic intent label, e.g. 'definition', 'method', 'greeting'
    confidence: float  # confidence of the query_type classification, in [0, 1]
class ChineseQueryProcessor:
    """Chinese query preprocessor.

    Pipeline: normalize colloquial phrasing -> classify intent ->
    extract keywords (jieba) and entities (regex) -> expand with
    synonyms and related vocabulary to improve retrieval recall.
    """

    def __init__(self):
        """Build the static pattern/vocabulary tables and warm up jieba."""
        # Load jieba's dictionary eagerly so the first query is not slow.
        jieba.initialize()

        # Question patterns: regex -> intent label, tried in insertion
        # order by _identify_query_type.
        self.question_patterns = {
            r'(.+?)是什么': 'definition',
            r'(.+?)是啥': 'definition',
            r'什么是(.+?)': 'definition',
            r'(.+?)怎么样': 'evaluation',
            r'(.+?)如何': 'method',
            r'(.+?)的作用': 'function',
            r'(.+?)的特点': 'feature',
            # Compound / enumeration-style question patterns.
            r'(.+?)包含(.+?)方面': 'specific_inquiry',
            r'(.+?)具体包含(.+?)': 'specific_inquiry',
            r'(.+?)有哪些(.+?)': 'enumeration',
            r'(.+?)分为(.+?)': 'classification',
            r'(.+?)承诺(.+?)': 'commitment_inquiry',
            r'(.+?)哪几个(.+?)': 'enumeration',
            r'(.+?)几个方面(.+?)': 'specific_inquiry',
            r'(.+?)方面的(.+?)': 'specific_inquiry',
            r'(.+?)包括(.+?)': 'enumeration',
            r'(.+?)涉及(.+?)': 'specific_inquiry',
            r'(.+?)覆盖(.+?)': 'specific_inquiry',
            r'(.+?)服务(.+?)': 'service_inquiry',
            r'(.+?)保障(.+?)': 'service_inquiry',
        }

        # Synonym dictionary used for query expansion.
        self.synonyms = {
            '是啥': ['是什么', '是', '定义', '含义'],
            '怎么样': ['如何', '怎样', '效果'],
            '作用': ['功能', '用途', '目的'],
            '特点': ['特征', '性质', '属性'],
        }

        # Entity-category suffix vocabulary. NOTE(review): currently only
        # informational — _expand_query hard-codes its own suffix lists.
        self.entity_types = {
            '公司': ['公司', '企业', '集团', '有限公司', '股份有限公司'],
            '技术': ['技术', '算法', '方法', '框架', '系统'],
            '产品': ['产品', '服务', '平台', '工具'],
            '概念': ['概念', '理论', '模型', '原理'],
        }

        # Chinese stop words removed from the keyword list.
        self.stop_words = {
            '的', '了', '在', '是', '我', '有', '和', '就', '不', '人',
            '都', '一', '一个', '上', '也', '很', '到', '说', '要', '去',
            '你', '会', '着', '没有', '看', '好', '自己', '这'
        }

    def process_query(self, query: str) -> ProcessedQuery:
        """Run the full preprocessing pipeline on a raw query.

        Never raises: on any internal failure a minimal ProcessedQuery
        built from the stripped input is returned instead.
        """
        try:
            # 1. Normalize -> 2. classify intent -> 3. keywords ->
            # 4. entities -> 5. expansion.
            normalized = self._normalize_query(query)
            query_type, confidence = self._identify_query_type(normalized)
            keywords = self._extract_keywords(normalized)
            entities = self._extract_entities(normalized)
            expanded_terms = self._expand_query(keywords, entities)
            return ProcessedQuery(
                original=query,
                normalized=normalized,
                keywords=keywords,
                entities=entities,
                expanded_terms=expanded_terms,
                query_type=query_type,
                confidence=confidence
            )
        except Exception as e:
            logger.error(f"查询处理失败: {e}")
            # Degrade gracefully: treat the whole query as one keyword.
            fallback = query.strip()
            return ProcessedQuery(
                original=query,
                normalized=fallback,
                keywords=[fallback],
                entities=[],
                expanded_terms=[fallback],
                query_type='unknown',
                confidence=0.5
            )

    def _normalize_query(self, query: str) -> str:
        """Normalize whitespace, punctuation and colloquial phrasing."""
        # Collapse runs of whitespace.
        normalized = re.sub(r'\s+', ' ', query.strip())
        # Fold full-width ？/！ (U+FF1F / U+FF01) to their ASCII forms.
        # Written with escapes: the previous literal form read as replacing
        # a character with itself (a no-op), defeating the intent of
        # unifying punctuation.
        normalized = normalized.replace('\uff1f', '?').replace('\uff01', '!')
        # Map colloquial expressions to their standard written forms.
        replacements = {
            '是啥': '是什么',
            '咋样': '怎么样',
            '咋办': '怎么办',
            '啥意思': '什么意思',
        }
        for colloquial, standard in replacements.items():
            normalized = normalized.replace(colloquial, standard)
        return normalized

    @staticmethod
    def _contains_token(lowered_query: str, tokens: List[str]) -> bool:
        """True if any token occurs in the query.

        ASCII tokens must appear as whole words (word-boundary match) to
        avoid substring false positives such as "hi" inside "china" or
        "test" inside "latest"; Chinese tokens use substring matching
        since Chinese text has no word delimiters.
        """
        for token in tokens:
            if token.isascii():
                if re.search(r'\b' + re.escape(token) + r'\b', lowered_query):
                    return True
            elif token in lowered_query:
                return True
        return False

    def _identify_query_type(self, query: str) -> tuple[str, float]:
        """Classify query intent.

        Returns:
            (intent_label, confidence) with confidence a heuristic in [0, 1].
        """
        lowered = query.lower()

        # Greeting detection (checked first so greetings never fall
        # through to the content patterns).
        greetings = ['你好', 'hello', 'hi', '您好', '早上好', '下午好', '晚上好', 'nihao']
        if self._contains_token(lowered, greetings):
            return 'greeting', 0.95

        # Short throwaway queries ("test", "随便", ...).
        meaningless = ['测试', 'test', '试试', '看看', '随便', '没事']
        if self._contains_token(lowered, meaningless) and len(query) < 10:
            return 'meaningless', 0.9

        # Structured question patterns, in declaration order.
        for pattern, query_type in self.question_patterns.items():
            if re.search(pattern, query):
                return query_type, 0.9

        # Keyword-based fallbacks, roughly most to least specific.
        if any(word in query for word in ['什么', '是', '定义']):
            return 'definition', 0.7
        if any(word in query for word in ['如何', '怎么', '方法']):
            return 'method', 0.7
        if any(word in query for word in ['为什么', '原因']):
            return 'reason', 0.7
        if any(word in query for word in ['包含', '具体', '哪些', '哪几个']):
            return 'specific_inquiry', 0.8
        if any(word in query for word in ['承诺', '保证', '确保']):
            return 'commitment_inquiry', 0.8
        # A question mark (ASCII or full-width U+FF1F) still signals a question.
        if '?' in query or '\uff1f' in query:
            return 'question', 0.6
        return 'general', 0.5

    def _extract_keywords(self, query: str) -> List[str]:
        """Segment the query with jieba, dropping stop words, single
        characters and pure-punctuation tokens."""
        keywords = []
        for word in jieba.cut(query):
            word = word.strip()
            if (len(word) > 1 and
                    word not in self.stop_words and
                    not re.match(r'^[?!。,、;:""''()【】 \t\n\r\f\v]+$', word)):
                keywords.append(word)
        return keywords

    def _extract_entities(self, query: str) -> List[str]:
        """Extract candidate entities via regex patterns.

        Pattern order matters: more specific patterns run first so their
        matches lead in the order-preserving, deduplicated result.
        """
        patterns = [
            r'([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)',  # English proper-noun runs
            r'([\u4e00-\u9fff]{2,}(?:公司|企业|集团|技术|系统|平台))',  # Chinese org/tech entities
            r'([\u4e00-\u9fff]{2,})',  # generic Chinese word runs
        ]
        entities = []
        for pattern in patterns:
            entities.extend(re.findall(pattern, query))

        # Order-preserving dedupe with an O(1) seen-set; drop 1-char matches.
        unique_entities = []
        seen = set()
        for entity in entities:
            if len(entity) > 1 and entity not in seen:
                seen.add(entity)
                unique_entities.append(entity)
        return unique_entities

    def _expand_query(self, keywords: List[str], entities: List[str]) -> List[str]:
        """Expand keywords/entities with synonyms and related vocabulary.

        Returns a sorted list so the expansion is deterministic across
        runs (plain ``list(set(...))`` order varies with hash seeding,
        which leaked into generate_search_queries output).
        """
        expanded = set(keywords) | set(entities)

        # Synonym expansion for known colloquial/abstract keywords.
        for keyword in keywords:
            expanded.update(self.synonyms.get(keyword, ()))

        for entity in entities:
            if any(suffix in entity for suffix in ['公司', '企业', '集团']):
                # Company-like entity: add business-related terms.
                expanded.update(['业务', '服务', '产品'])
            elif any(suffix in entity for suffix in ['技术', '系统', '平台']):
                # Technology-like entity: add usage-related terms.
                expanded.update(['应用', '功能', '特点'])

        return sorted(expanded)

    def generate_search_queries(self, processed_query: ProcessedQuery) -> List[str]:
        """Generate search-query variants, most faithful first.

        Order: original query, normalized form (if different), keyword
        bag, individual long-enough entities, then up to three
        expansion-only terms; duplicates removed preserving order.
        """
        queries = [processed_query.original]

        if processed_query.normalized != processed_query.original:
            queries.append(processed_query.normalized)

        if len(processed_query.keywords) > 1:
            queries.append(' '.join(processed_query.keywords))

        # Entities of <=2 chars are too ambiguous to search on their own.
        queries.extend(e for e in processed_query.entities if len(e) > 2)

        if processed_query.expanded_terms:
            # Only expansion terms not already present as keywords.
            extra_terms = [term for term in processed_query.expanded_terms
                           if len(term) > 2 and term not in processed_query.keywords]
            if extra_terms:
                queries.append(' '.join(extra_terms[:3]))  # cap at 3 expansion terms

        # Order-preserving dedupe (dicts keep insertion order).
        return list(dict.fromkeys(queries))

    def should_use_fuzzy_search(self, processed_query: ProcessedQuery) -> bool:
        """Heuristic: recommend fuzzy search for short, low-confidence
        or keyword-poor queries."""
        return (len(processed_query.original) < 5 or
                processed_query.confidence < 0.6 or
                len(processed_query.keywords) < 2)
# Manual smoke test: run a few representative queries through the
# processor and print every intermediate artifact.
if __name__ == "__main__":
    qp = ChineseQueryProcessor()

    sample_queries = [
        "铁塔是啥",
        "中国铁塔公司",
        "nihao",
        "什么是人工智能",
        "AgenticX框架怎么样",
    ]

    for q in sample_queries:
        print(f"\n原始查询: {q}")
        outcome = qp.process_query(q)
        print(f"标准化: {outcome.normalized}")
        print(f"关键词: {outcome.keywords}")
        print(f"实体: {outcome.entities}")
        print(f"扩展词: {outcome.expanded_terms}")
        print(f"查询类型: {outcome.query_type} (置信度: {outcome.confidence})")
        candidates = qp.generate_search_queries(outcome)
        print(f"搜索查询: {candidates}")
        print(f"建议模糊搜索: {qp.should_use_fuzzy_search(outcome)}")