QASystemOnMedicalKG/question_classifier.py
2018-10-04 23:28:23 +08:00

195 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# coding: utf-8
# File: question_classifier.py
# Author: lhy<lhy_in_blcu@126.com,https://huangyong.github.io>
# Date: 18-10-4
import os
import ahocorasick
class QuestionClassifier:
    """Rule-based classifier for (Chinese) medical questions.

    Entity words are located in a question with an Aho-Corasick automaton
    built from the dictionary files under ``dict/``; the question type is
    then chosen from the entity types found plus the presence of question
    keywords.
    """

    def __init__(self):
        """Load the entity dictionaries and build the lookup structures.

        Reads the ``dict/*.txt`` files located next to this module (one word
        per line, blank lines ignored), builds the automaton over all entity
        words and the word -> entity-type mapping, and defines the keyword
        lists used by ``classify``.
        """
        # os.path.dirname is portable; splitting __file__ on '/' is not.
        cur_dir = os.path.dirname(os.path.abspath(__file__))
        # Feature-word dictionary paths
        self.disease_path = os.path.join(cur_dir, 'dict/disease.txt')
        self.department_path = os.path.join(cur_dir, 'dict/department.txt')
        self.check_path = os.path.join(cur_dir, 'dict/check.txt')
        self.drug_path = os.path.join(cur_dir, 'dict/drug.txt')
        self.food_path = os.path.join(cur_dir, 'dict/food.txt')
        self.producer_path = os.path.join(cur_dir, 'dict/producer.txt')
        self.symptom_path = os.path.join(cur_dir, 'dict/symptom.txt')
        self.deny_path = os.path.join(cur_dir, 'dict/deny.txt')
        # Load feature words
        self.disease_wds = self._load_words(self.disease_path)
        self.department_wds = self._load_words(self.department_path)
        self.check_wds = self._load_words(self.check_path)
        self.drug_wds = self._load_words(self.drug_path)
        self.food_wds = self._load_words(self.food_path)
        self.producer_wds = self._load_words(self.producer_path)
        self.symptom_wds = self._load_words(self.symptom_path)
        self.region_words = set(
            self.department_wds + self.disease_wds + self.check_wds
            + self.drug_wds + self.food_wds + self.producer_wds
            + self.symptom_wds
        )
        # Negation words; used to tell "what to eat" from "what not to eat"
        self.deny_words = self._load_words(self.deny_path)
        # Build the domain automaton over every entity word
        self.region_tree = self.build_actree(list(self.region_words))
        # Build the word -> entity-type mapping
        self.wdtype_dict = self.build_wdtype_dict()
        # Question keywords
        self.symptom_qwds = ['症状', '表征', '现象', '症候', '表现']
        self.cause_qwds = ['原因', '为什么', '怎么会', '怎样才', '咋样才', '怎样会', '如何会', '为啥', '为何', '如何', '怎么才会', '会导致', '会造成']
        self.acompany_qwds = ['并发症', '并发', '一起发生', '一并发生', '一起出现', '一并出现', '一同发生', '一同出现', '伴随发生', '伴随']
        # NOTE(review): these lists used to contain empty strings ('') here,
        # which matched every question. They look like single-character
        # keywords lost in an encoding/copy step; restore the intended
        # characters. '视频' ("video") also looks out of place in a food
        # list -- confirm.
        self.food_qwds = ['饮食', '饮用', '伙食', '膳食', '视频', '忌口', '补品', '保健品']
        self.drug_qwds = ['药品', '用药']
        self.prevent_qwds = ['预防', '防范', '抵制', '抵御', '防止',
                             '怎样才能不', '怎么才能不', '咋样才能不', '咋才能不', '如何才能不',
                             '怎样才不', '怎么才不', '咋样才不', '咋才不', '如何才不',
                             '怎样才可以不', '怎么才可以不', '咋样才可以不', '咋才可以不', '如何可以不',
                             '怎样才可不', '怎么才可不', '咋样才可不', '咋才可不', '如何可不']
        self.lasttime_qwds = ['周期', '多久', '多长时间', '多少时间', '几天', '几年', '多少天', '多少小时', '几个小时', '多少年']
        self.cureway_qwds = ['怎么治疗', '如何医治', '怎么医治', '怎么治', '怎么医', '如何治', '医治方式', '疗法', '咋治', '怎么办']
        self.cureprob_qwds = ['多大概率能治好', '多大几率能治好', '治好希望大么']
        self.easyget_qwds = ['易感人群', '容易感染']
        self.check_qwds = ['检查', '检查项目']
        # Not referenced by classify() in this file.
        self.belong_qwds = ['属于什么科', '属于', '什么科']
        self.cure_qwds = ['治疗什么', '治啥', '治疗啥', '医治啥', '治愈啥', '主治啥', '主治什么', '有什么用', '有何用', '用处', '用途']
        print('model init finished ......')

    @staticmethod
    def _load_words(path):
        """Return the stripped, non-blank lines of the text file at ``path``.

        The file is closed deterministically and read as UTF-8 explicitly
        rather than with the platform default encoding (assumes the
        dictionary files are UTF-8 encoded).
        """
        with open(path, encoding='utf-8') as f:
            return [line.strip() for line in f if line.strip()]

    def classify(self, question):
        """Classify ``question`` into a question type.

        Returns ``{}`` when no dictionary entity occurs in the question.
        Otherwise returns ``{'args': <check_medical result>,
        'question_type': <str>}``. Rules are evaluated in order and a later
        matching rule overrides an earlier one; ``'others'`` is used when no
        rule matches.
        """
        medical_dict = self.check_medical(question)
        if not medical_dict:
            return {}
        data = {'args': medical_dict}
        # Collect every entity type mentioned in the question
        types = set()
        for type_list in medical_dict.values():
            types.update(type_list)
        has_disease = 'disease' in types
        has_symptom = 'symptom' in types

        question_type = 'others'
        # Symptoms of a disease
        if has_disease and self.check_words(self.symptom_qwds, question):
            question_type = 'disease_symptom'
        # Causes of a disease
        if has_disease and self.check_words(self.cause_qwds, question):
            question_type = 'disease_cause'
        # Complications
        if (has_disease or has_symptom) and self.check_words(self.acompany_qwds, question):
            question_type = 'disease_acompany'
        # Food advice; a negation word turns "what to eat" into "what to avoid"
        if (has_disease or has_symptom) and self.check_words(self.food_qwds, question):
            if self.check_words(self.deny_words, question):
                question_type = 'disease_not_food'
            else:
                question_type = 'disease_do_food'

        # Rules shared by diseases and symptoms:
        # (keywords, type when a disease is mentioned, type when a symptom is
        # mentioned). When both are mentioned the symptom variant wins.
        shared_rules = [
            (self.drug_qwds, 'disease_drug', 'symptom_disease_drug'),              # recommended drugs
            (self.prevent_qwds, 'disease_prevent', 'symptom_disease_prevent'),     # prevention
            (self.lasttime_qwds, 'disease_lasttime', 'symptom_disease_lasttime'),  # treatment duration
            (self.cureway_qwds, 'disease_cureway', 'symptom_disease_cureway'),     # treatment methods
            (self.cureprob_qwds, 'disease_cureprob', 'symptom_disease_cureprob'),  # chance of cure
            (self.easyget_qwds, 'disease_easyget', 'symptom_disease_easyget'),     # susceptible population
            (self.check_qwds, 'disease_check', 'symptom_disease_check'),           # examinations
        ]
        for qwds, disease_type, symptom_type in shared_rules:
            if self.check_words(qwds, question):
                if has_symptom:
                    question_type = symptom_type
                elif has_disease:
                    question_type = disease_type

        # What a drug (or a drug product) treats
        if ('drug' in types or 'producer' in types) and self.check_words(self.cure_qwds, question):
            question_type = 'drug_disease'

        data['question_type'] = question_type
        return data

    def build_wdtype_dict(self):
        """Map every word in ``self.region_words`` to its list of entity types.

        Types are listed in the fixed order disease, department, check, drug,
        food, symptom, producer; a word present in several dictionaries gets
        several types.
        """
        # Sets give O(1) membership tests; list lookups here were O(n) per
        # test over tens of thousands of words.
        categories = [
            ('disease', set(self.disease_wds)),
            ('department', set(self.department_wds)),
            ('check', set(self.check_wds)),
            ('drug', set(self.drug_wds)),
            ('food', set(self.food_wds)),
            ('symptom', set(self.symptom_wds)),
            ('producer', set(self.producer_wds)),
        ]
        return {
            wd: [label for label, words in categories if wd in words]
            for wd in self.region_words
        }

    def build_actree(self, wordlist):
        """Build an Aho-Corasick automaton whose payload for each word is (index, word)."""
        actree = ahocorasick.Automaton()
        for index, word in enumerate(wordlist):
            actree.add_word(word, (index, word))
        actree.make_automaton()
        return actree

    def check_medical(self, question):
        """Return ``{entity_word: [entity types]}`` for the entities found in ``question``.

        A matched word that is a substring of another (different) matched
        word is dropped, so the longer, more specific match is kept.
        """
        # Each hit is (end_index, payload); the payload is (index, word) as
        # stored by build_actree.
        region_wds = [hit[1][1] for hit in self.region_tree.iter(question)]
        stop_wds = {
            wd1 for wd1 in region_wds for wd2 in region_wds
            if wd1 != wd2 and wd1 in wd2
        }
        final_wds = [wd for wd in region_wds if wd not in stop_wds]
        return {wd: self.wdtype_dict.get(wd) for wd in final_wds}

    def check_words(self, wds, sent):
        """Return True if any non-empty keyword in ``wds`` is a substring of ``sent``.

        Empty keywords are ignored: ``''`` is a substring of every string, so
        a single empty entry would otherwise make every sentence match.
        """
        return any(wd and wd in sent for wd in wds)
if __name__ == '__main__':
    # Interactive loop: classify each typed question and print the result.
    handler = QuestionClassifier()
    while True:
        try:
            question = input('input an question:')
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C / closed stdin ends the session without a traceback.
            print()
            break
        data = handler.classify(question)
        print(data)