67 lines
2.2 KiB
Python
67 lines
2.2 KiB
Python
import logging
|
|
from typing import List, Dict, Any
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TextClassifier:
    """Rule-based text classifier.

    Each text is assigned one topic label by checking whether any keyword of
    a rule occurs as a substring of the lower-cased text. No trained model is
    involved; the confidences are fixed per rule.
    """

    # Ordered (label, confidence, keywords) rules. Order matters: the first
    # rule that has a keyword present in the text wins.
    _RULES = (
        ("技术", 0.9, ("技术", "科技", "编程", "代码")),
        ("体育", 0.85, ("体育", "足球", "篮球", "运动")),
        ("娱乐", 0.8, ("电影", "音乐", "娱乐", "游戏")),
        ("美食", 0.85, ("美食", "餐厅", "烹饪", "食物")),
        ("政治", 0.9, ("政治", "新闻", "政府", "政策")),
    )

    # Result used when no rule matches.
    _FALLBACK_LABEL = "其他"
    _FALLBACK_CONFIDENCE = 0.7

    def __init__(self):
        """Initialize the classifier (only logs; there is no state to set up)."""
        logger.info("初始化文本分类器")
        # A pretrained model could be loaded here; this example relies on
        # simple keyword rules instead.

    def classify(self, texts: List[str], params: Dict[str, Any] = None) -> List[Dict[str, Any]]:
        """Classify every text in ``texts``.

        Args:
            texts: Texts to classify.
            params: Classification parameters. Accepted for interface
                compatibility; the rule-based implementation does not consult
                any key (a "threshold" entry has no effect on the output).

        Returns:
            One dict per input text, in input order, with the keys "text",
            "label" and "confidence".
        """
        results = []
        for text in texts:
            outcome = self._simple_classify(text)
            results.append({
                "text": text,
                "label": outcome["label"],
                "confidence": outcome["confidence"],
            })
        return results

    def _simple_classify(self, text: str) -> Dict[str, Any]:
        """Classify a single text with the keyword rules.

        Args:
            text: Text to classify.

        Returns:
            A dict with "label" and "confidence" taken from the first matching
            rule, or the fallback label and confidence when no keyword occurs
            in the text.
        """
        lowered = text.lower()
        for label, confidence, keywords in self._RULES:
            if any(keyword in lowered for keyword in keywords):
                return {"label": label, "confidence": confidence}
        return {
            "label": self._FALLBACK_LABEL,
            "confidence": self._FALLBACK_CONFIDENCE,
        }
|