新增监狱人脸识别的biz
This commit is contained in:
@@ -20,6 +20,7 @@ from algorithm.face_recognition_algorithm import FaceRecognitionAlgorithm
|
|||||||
from biz.base_face_biz import BaseFaceBiz
|
from biz.base_face_biz import BaseFaceBiz
|
||||||
from biz.video_check_biz import VideoCheckBiz
|
from biz.video_check_biz import VideoCheckBiz
|
||||||
from biz.video_face_biz import VideoFaceBiz
|
from biz.video_face_biz import VideoFaceBiz
|
||||||
|
from biz.video_face_prison_biz import VideoFacePrisonBiz
|
||||||
from repositories.video_check_repository import VideoCheckTaskRepository
|
from repositories.video_check_repository import VideoCheckTaskRepository
|
||||||
|
|
||||||
# 创建路由器
|
# 创建路由器
|
||||||
@@ -34,6 +35,9 @@ face_algorithm_for_rtsp = FaceRecognitionAlgorithm(use_gpu=settings.FACE_USE_GPU
|
|||||||
# 初始化RTSP专用VideoFaceBiz实例
|
# 初始化RTSP专用VideoFaceBiz实例
|
||||||
video_face_biz = VideoFaceBiz(face_algorithm_for_rtsp.get_app())
|
video_face_biz = VideoFaceBiz(face_algorithm_for_rtsp.get_app())
|
||||||
|
|
||||||
|
# 初始化RTSP专用VideoFacePrisonBiz实例
|
||||||
|
video_face_prison_biz = VideoFacePrisonBiz(face_algorithm_for_rtsp.get_app())
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -827,5 +831,210 @@ async def get_videofacebiz_status():
|
|||||||
logger.error(f"获取VideoFaceBiz状态失败: {e}")
|
logger.error(f"获取VideoFaceBiz状态失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
def sync_videofaceprisonbiz_params():
|
||||||
|
"""
|
||||||
|
同步VideoFacePrisonBiz的参数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with db_manager.get_session() as session:
|
||||||
|
# 查询人脸识别配置(根据实际表结构)
|
||||||
|
config_records = session.query(SurConfigBase).filter(
|
||||||
|
SurConfigBase.config_type == settings.SUR_CONFIG_TYPE_FACE
|
||||||
|
).all()
|
||||||
|
|
||||||
|
# 构建配置参数字典
|
||||||
|
config_params = {}
|
||||||
|
for record in config_records:
|
||||||
|
if record.config_key and record.config_value:
|
||||||
|
config_params[record.config_key] = record.config_value
|
||||||
|
|
||||||
|
# 配置键映射关系
|
||||||
|
config_mapping = {
|
||||||
|
"face.list_mode": "list_mode",
|
||||||
|
"face.clarity_threshold": "clarity_threshold",
|
||||||
|
"face.min_face_size": "min_face_size",
|
||||||
|
"face.pitch_threshold": "pitch_threshold",
|
||||||
|
"face.yaw_threshold": "yaw_threshold",
|
||||||
|
"face.similarity_threshold": "similarity_threshold"
|
||||||
|
}
|
||||||
|
|
||||||
|
updated_count = 0
|
||||||
|
|
||||||
|
for config_key, param_name in config_mapping.items():
|
||||||
|
if config_key in config_params:
|
||||||
|
config_value = config_params[config_key]
|
||||||
|
|
||||||
|
# 根据参数类型进行转换和设置
|
||||||
|
if param_name == "list_mode":
|
||||||
|
if config_value in ["0", "1"]:
|
||||||
|
video_face_prison_biz.set_list_mode(config_value)
|
||||||
|
updated_count += 1
|
||||||
|
elif param_name == "clarity_threshold":
|
||||||
|
try:
|
||||||
|
threshold = float(config_value)
|
||||||
|
video_face_prison_biz.set_clarity_threshold(threshold)
|
||||||
|
updated_count += 1
|
||||||
|
except ValueError:
|
||||||
|
logger.error(f"无效的清晰度阈值: {config_value}")
|
||||||
|
elif param_name == "min_face_size":
|
||||||
|
try:
|
||||||
|
size = int(config_value)
|
||||||
|
video_face_prison_biz.set_min_face_size(size)
|
||||||
|
updated_count += 1
|
||||||
|
except ValueError:
|
||||||
|
logger.error(f"无效的最小人脸尺寸: {config_value}")
|
||||||
|
elif param_name == "pitch_threshold":
|
||||||
|
try:
|
||||||
|
threshold = float(config_value)
|
||||||
|
video_face_prison_biz.set_pitch_threshold(threshold)
|
||||||
|
updated_count += 1
|
||||||
|
except ValueError:
|
||||||
|
logger.error(f"无效的俯仰角阈值: {config_value}")
|
||||||
|
elif param_name == "yaw_threshold":
|
||||||
|
try:
|
||||||
|
threshold = float(config_value)
|
||||||
|
video_face_prison_biz.set_yaw_threshold(threshold)
|
||||||
|
updated_count += 1
|
||||||
|
except ValueError:
|
||||||
|
logger.error(f"无效的偏航角阈值: {config_value}")
|
||||||
|
elif param_name == "similarity_threshold":
|
||||||
|
try:
|
||||||
|
threshold = float(config_value)
|
||||||
|
video_face_prison_biz.set_similarity_threshold(threshold)
|
||||||
|
updated_count += 1
|
||||||
|
except ValueError:
|
||||||
|
logger.error(f"无效的相似度阈值: {config_value}")
|
||||||
|
|
||||||
|
logger.info(f"✅ 同步VideoFacePrisonBiz参数完成,更新了 {updated_count} 个参数")
|
||||||
|
return updated_count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 同步VideoFacePrisonBiz参数失败: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def sync_videofaceprisonbiz_blacklist():
|
||||||
|
"""
|
||||||
|
同步VideoFacePrisonBiz的黑名单
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with db_manager.get_session() as session:
|
||||||
|
# 查询启用的黑名单人员
|
||||||
|
blacklist_persons = session.query(SurPersonBlacklist).filter(
|
||||||
|
SurPersonBlacklist.status == 1
|
||||||
|
).all()
|
||||||
|
|
||||||
|
if not blacklist_persons:
|
||||||
|
logger.info("⚠️ 黑名单为空,清空当前黑名单")
|
||||||
|
video_face_prison_biz.set_registered_faces({})
|
||||||
|
return 0
|
||||||
|
|
||||||
|
person_ids = [person.person_id for person in blacklist_persons]
|
||||||
|
|
||||||
|
# 查询对应的人脸特征
|
||||||
|
face_features = session.query(SurFaceFeature).filter(
|
||||||
|
SurFaceFeature.person_id.in_(person_ids),
|
||||||
|
SurFaceFeature.feature_type == settings.FACE_MODEL_VERSION,
|
||||||
|
SurFaceFeature.status == 2 # 计算成功的特征
|
||||||
|
).all()
|
||||||
|
|
||||||
|
# 构建特征字典
|
||||||
|
registered_faces = {}
|
||||||
|
loaded_count = 0
|
||||||
|
|
||||||
|
for feature in face_features:
|
||||||
|
if feature.feature_data:
|
||||||
|
try:
|
||||||
|
# 将bytea转换为numpy数组
|
||||||
|
import numpy as np
|
||||||
|
feature_array = np.frombuffer(feature.feature_data, dtype=np.float32)
|
||||||
|
|
||||||
|
# 使用person_id作为标识符
|
||||||
|
person_name = f"blacklist_{feature.person_id}"
|
||||||
|
registered_faces[person_name] = feature_array
|
||||||
|
loaded_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 解析黑名单人员 {feature.person_id} 的特征数据失败: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 设置黑名单
|
||||||
|
success = video_face_prison_biz.set_registered_faces(registered_faces)
|
||||||
|
if success:
|
||||||
|
logger.info(f"✅ 同步黑名单完成,加载了 {loaded_count} 个黑名单人员")
|
||||||
|
else:
|
||||||
|
logger.error("❌ 设置黑名单失败")
|
||||||
|
|
||||||
|
return loaded_count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 同步黑名单失败: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@router.post("/sync-videofaceprisonbiz-params", summary="同步VideoFacePrisonBiz参数")
|
||||||
|
async def sync_videofaceprisonbiz_params_endpoint():
|
||||||
|
"""
|
||||||
|
同步VideoFacePrisonBiz的参数
|
||||||
|
|
||||||
|
从sur_config表同步参数到VideoFacePrisonBiz实例
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
updated_count = sync_videofaceprisonbiz_params()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": f"同步参数完成,更新了 {updated_count} 个参数",
|
||||||
|
"updated_count": updated_count
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"同步VideoFacePrisonBiz参数失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"同步参数失败: {str(e)}")
|
||||||
|
|
||||||
|
@router.post("/sync-videofaceprisonbiz-blacklist", summary="同步VideoFacePrisonBiz黑名单")
|
||||||
|
async def sync_videofaceprisonbiz_blacklist_endpoint():
|
||||||
|
"""
|
||||||
|
同步VideoFacePrisonBiz的黑名单
|
||||||
|
|
||||||
|
从sur_person_blacklist表同步黑名单到VideoFacePrisonBiz实例
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
loaded_count = sync_videofaceprisonbiz_blacklist()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": f"同步黑名单完成,加载了 {loaded_count} 个黑名单人员",
|
||||||
|
"loaded_count": loaded_count
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"同步VideoFacePrisonBiz黑名单失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"同步黑名单失败: {str(e)}")
|
||||||
|
|
||||||
|
@router.get("/videofaceprisonbiz-status", summary="获取VideoFacePrisonBiz状态")
|
||||||
|
async def get_videofaceprisonbiz_status():
|
||||||
|
"""
|
||||||
|
获取VideoFacePrisonBiz的当前状态
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
status = {
|
||||||
|
"list_mode": video_face_prison_biz.get_list_mode(),
|
||||||
|
"clarity_threshold": video_face_prison_biz.get_clarity_threshold(),
|
||||||
|
"min_face_size": video_face_prison_biz.get_min_face_size(),
|
||||||
|
"pitch_threshold": video_face_prison_biz.get_pitch_threshold(),
|
||||||
|
"yaw_threshold": video_face_prison_biz.get_yaw_threshold(),
|
||||||
|
"similarity_threshold": video_face_prison_biz.get_similarity_threshold(),
|
||||||
|
"blacklist_count": video_face_prison_biz.get_registered_face_count()
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": status
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取VideoFacePrisonBiz状态失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}")
|
||||||
|
|
||||||
# 导出路由器
|
# 导出路由器
|
||||||
__all__ = ["router"]
|
__all__ = ["router", "sync_videofacebiz_params", "sync_videofacebiz_blacklist", "sync_videofaceprisonbiz_params", "sync_videofaceprisonbiz_blacklist"]
|
||||||
12
app.py
12
app.py
@@ -18,7 +18,8 @@ from fastapi.openapi.docs import (
|
|||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
from api.routes import face_features
|
from api.routes import face_features
|
||||||
from api.routes.algorithm_router import router as algorithm_router, sync_videofacebiz_params, sync_videofacebiz_blacklist
|
from api.routes.algorithm_router import router as algorithm_router, sync_videofacebiz_params, \
|
||||||
|
sync_videofacebiz_blacklist, sync_videofaceprisonbiz_params, sync_videofaceprisonbiz_blacklist
|
||||||
from api.errors import (
|
from api.errors import (
|
||||||
APIError,
|
APIError,
|
||||||
validation_exception_handler,
|
validation_exception_handler,
|
||||||
@@ -72,6 +73,15 @@ async def lifespan(app: FastAPI):
|
|||||||
else:
|
else:
|
||||||
print("⚠️ RTSP 服务未启用")
|
print("⚠️ RTSP 服务未启用")
|
||||||
|
|
||||||
|
# 自动同步VideoFacePrisonBiz参数和黑名单
|
||||||
|
print("🔄 自动同步VideoFacePrisonBiz参数和黑名单...")
|
||||||
|
try:
|
||||||
|
params_updated = sync_videofaceprisonbiz_params()
|
||||||
|
blacklist_loaded = sync_videofaceprisonbiz_blacklist()
|
||||||
|
print(f"✅ 自动同步完成 - 参数更新: {params_updated}个, 黑名单加载: {blacklist_loaded}个")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ 自动同步失败: {e}")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# 关闭时
|
# 关闭时
|
||||||
|
|||||||
67
biz/video_face_prison_biz.py
Normal file
67
biz/video_face_prison_biz.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""
|
||||||
|
视频检查业务类 - RTSP专用
|
||||||
|
专门处理RTSP视频流中的人脸识别和检测
|
||||||
|
"""
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from typing import Optional, List, Dict
|
||||||
|
import time
|
||||||
|
from insightface.app import FaceAnalysis
|
||||||
|
from biz.base_face_biz import BaseFaceBiz
|
||||||
|
|
||||||
|
class VideoFacePrisonBiz(BaseFaceBiz):
|
||||||
|
"""
|
||||||
|
视频检查业务类 - RTSP专用
|
||||||
|
专门处理RTSP视频流中的人脸识别和检测
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, face_analysis: FaceAnalysis):
|
||||||
|
"""
|
||||||
|
初始化视频检查业务类
|
||||||
|
|
||||||
|
参数:
|
||||||
|
face_analysis: 已初始化好的FaceAnalysis实例
|
||||||
|
"""
|
||||||
|
super().__init__(face_analysis)
|
||||||
|
|
||||||
|
def draw_detections(self, frame: np.ndarray, results: List[Dict]) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
重写绘制检测结果方法
|
||||||
|
只在检测到黑名单匹配时用红色绘制人脸框
|
||||||
|
|
||||||
|
参数:
|
||||||
|
frame: 原始帧图像
|
||||||
|
results: 检测结果列表
|
||||||
|
|
||||||
|
返回:
|
||||||
|
绘制后的帧图像
|
||||||
|
"""
|
||||||
|
for result in results:
|
||||||
|
# 只在黑名单匹配时绘制
|
||||||
|
if result['is_match']:
|
||||||
|
bbox = result['bbox']
|
||||||
|
|
||||||
|
# 使用红色绘制人脸框
|
||||||
|
x1, y1, x2, y2 = bbox
|
||||||
|
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 2)
|
||||||
|
|
||||||
|
# 添加简单的匹配信息
|
||||||
|
best_match = result['best_match']
|
||||||
|
similarity = result['similarity']
|
||||||
|
|
||||||
|
# 绘制匹配信息
|
||||||
|
text = f"{best_match}: {similarity:.3f}"
|
||||||
|
text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
|
||||||
|
|
||||||
|
# 绘制文本背景
|
||||||
|
cv2.rectangle(frame, (x1, y1 - text_size[1] - 5),
|
||||||
|
(x1 + text_size[0], y1), (0, 0, 0), -1)
|
||||||
|
|
||||||
|
# 绘制文本
|
||||||
|
cv2.putText(frame, text, (x1, y1 - 5),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
|
||||||
|
|
||||||
|
return frame
|
||||||
|
|
||||||
|
|
||||||
@@ -8,6 +8,22 @@ from sqlalchemy.ext.declarative import declarative_base
|
|||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
class SurConfig(Base):
|
||||||
|
"""配置表"""
|
||||||
|
__tablename__ = "sur_config"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
scope = Column(SmallInteger, nullable=False, comment="作用域:0=全局,1=房间,2=摄像头")
|
||||||
|
target_id = Column(Integer, comment="根据作用域,摄像头id或房间id或其他")
|
||||||
|
description = Column(Text, comment="描述")
|
||||||
|
created_time = Column(DateTime, default=func.now(), comment="创建时间")
|
||||||
|
updated_time = Column(DateTime, default=func.now(), comment="更新时间")
|
||||||
|
created_by = Column(Integer, comment="创建人")
|
||||||
|
updated_by = Column(Integer, comment="更新人")
|
||||||
|
config_type = Column(SmallInteger, nullable=False, comment="配置类型:0=人脸识别")
|
||||||
|
config_group_id = Column(Integer, comment="配置组id")
|
||||||
|
|
||||||
|
|
||||||
class SurConfigBase(Base):
|
class SurConfigBase(Base):
|
||||||
"""配置基础表"""
|
"""配置基础表"""
|
||||||
__tablename__ = "sur_config_base"
|
__tablename__ = "sur_config_base"
|
||||||
|
|||||||
Reference in New Issue
Block a user