diff --git a/api/routes/algorithm_router.py b/api/routes/algorithm_router.py index 0cf2e13..2723050 100644 --- a/api/routes/algorithm_router.py +++ b/api/routes/algorithm_router.py @@ -20,6 +20,7 @@ from algorithm.face_recognition_algorithm import FaceRecognitionAlgorithm from biz.base_face_biz import BaseFaceBiz from biz.video_check_biz import VideoCheckBiz from biz.video_face_biz import VideoFaceBiz +from biz.video_face_prison_biz import VideoFacePrisonBiz from repositories.video_check_repository import VideoCheckTaskRepository # 创建路由器 @@ -34,6 +35,9 @@ face_algorithm_for_rtsp = FaceRecognitionAlgorithm(use_gpu=settings.FACE_USE_GPU # 初始化RTSP专用VideoFaceBiz实例 video_face_biz = VideoFaceBiz(face_algorithm_for_rtsp.get_app()) +# 初始化RTSP专用VideoFacePrisonBiz实例 +video_face_prison_biz = VideoFacePrisonBiz(face_algorithm_for_rtsp.get_app()) + logger = logging.getLogger(__name__) @@ -827,5 +831,210 @@ async def get_videofacebiz_status(): logger.error(f"获取VideoFaceBiz状态失败: {e}") raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") + +def sync_videofaceprisonbiz_params(): + """ + 同步VideoFacePrisonBiz的参数 + """ + try: + with db_manager.get_session() as session: + # 查询人脸识别配置(根据实际表结构) + config_records = session.query(SurConfigBase).filter( + SurConfigBase.config_type == settings.SUR_CONFIG_TYPE_FACE + ).all() + + # 构建配置参数字典 + config_params = {} + for record in config_records: + if record.config_key and record.config_value: + config_params[record.config_key] = record.config_value + + # 配置键映射关系 + config_mapping = { + "face.list_mode": "list_mode", + "face.clarity_threshold": "clarity_threshold", + "face.min_face_size": "min_face_size", + "face.pitch_threshold": "pitch_threshold", + "face.yaw_threshold": "yaw_threshold", + "face.similarity_threshold": "similarity_threshold" + } + + updated_count = 0 + + for config_key, param_name in config_mapping.items(): + if config_key in config_params: + config_value = config_params[config_key] + + # 根据参数类型进行转换和设置 + if param_name == "list_mode": + if config_value in ["0", "1"]: + video_face_prison_biz.set_list_mode(config_value) + updated_count += 1 + elif param_name == "clarity_threshold": + try: + threshold = float(config_value) + video_face_prison_biz.set_clarity_threshold(threshold) + updated_count += 1 + except ValueError: + logger.error(f"无效的清晰度阈值: {config_value}") + elif param_name == "min_face_size": + try: + size = int(config_value) + video_face_prison_biz.set_min_face_size(size) + updated_count += 1 + except ValueError: + logger.error(f"无效的最小人脸尺寸: {config_value}") + elif param_name == "pitch_threshold": + try: + threshold = float(config_value) + video_face_prison_biz.set_pitch_threshold(threshold) + updated_count += 1 + except ValueError: + logger.error(f"无效的俯仰角阈值: {config_value}") + elif param_name == "yaw_threshold": + try: + threshold = float(config_value) + video_face_prison_biz.set_yaw_threshold(threshold) + updated_count += 1 + except ValueError: + logger.error(f"无效的偏航角阈值: {config_value}") + elif param_name == "similarity_threshold": + try: + threshold = float(config_value) + video_face_prison_biz.set_similarity_threshold(threshold) + updated_count += 1 + except ValueError: + logger.error(f"无效的相似度阈值: {config_value}") + + logger.info(f"✅ 同步VideoFacePrisonBiz参数完成,更新了 {updated_count} 个参数") + return updated_count + + except Exception as e: + logger.error(f"❌ 同步VideoFacePrisonBiz参数失败: {e}") + return 0 + +def sync_videofaceprisonbiz_blacklist(): + """ + 同步VideoFacePrisonBiz的黑名单 + """ + try: + with db_manager.get_session() as session: + # 查询启用的黑名单人员 + blacklist_persons = session.query(SurPersonBlacklist).filter( + SurPersonBlacklist.status == 1 + ).all() + + if not blacklist_persons: + logger.info("⚠️ 黑名单为空,清空当前黑名单") + video_face_prison_biz.set_registered_faces({}) + return 0 + + person_ids = [person.person_id for person in blacklist_persons] + + # 查询对应的人脸特征 + face_features = session.query(SurFaceFeature).filter( + SurFaceFeature.person_id.in_(person_ids), + SurFaceFeature.feature_type == settings.FACE_MODEL_VERSION, + SurFaceFeature.status == 2 # 计算成功的特征 + ).all() + + # 构建特征字典 + registered_faces = {} + loaded_count = 0 + + for feature in face_features: + if feature.feature_data: + try: + # 将bytea转换为numpy数组 + import numpy as np + feature_array = np.frombuffer(feature.feature_data, dtype=np.float32) + + # 使用person_id作为标识符 + person_name = f"blacklist_{feature.person_id}" + registered_faces[person_name] = feature_array + loaded_count += 1 + + except Exception as e: + logger.error(f"❌ 解析黑名单人员 {feature.person_id} 的特征数据失败: {e}") + continue + + # 设置黑名单 + success = video_face_prison_biz.set_registered_faces(registered_faces) + if success: + logger.info(f"✅ 同步黑名单完成,加载了 {loaded_count} 个黑名单人员") + else: + logger.error("❌ 设置黑名单失败") + + return loaded_count + + except Exception as e: + logger.error(f"❌ 同步黑名单失败: {e}") + return 0 + +@router.post("/sync-videofaceprisonbiz-params", summary="同步VideoFacePrisonBiz参数") +async def sync_videofaceprisonbiz_params_endpoint(): + """ + 同步VideoFacePrisonBiz的参数 + + 从sur_config表同步参数到VideoFacePrisonBiz实例 + """ + try: + updated_count = sync_videofaceprisonbiz_params() + + return { + "success": True, + "message": f"同步参数完成,更新了 {updated_count} 个参数", + "updated_count": updated_count + } + + except Exception as e: + logger.error(f"同步VideoFacePrisonBiz参数失败: {e}") + raise HTTPException(status_code=500, detail=f"同步参数失败: {str(e)}") + +@router.post("/sync-videofaceprisonbiz-blacklist", summary="同步VideoFacePrisonBiz黑名单") +async def sync_videofaceprisonbiz_blacklist_endpoint(): + """ + 同步VideoFacePrisonBiz的黑名单 + + 从sur_person_blacklist表同步黑名单到VideoFacePrisonBiz实例 + """ + try: + loaded_count = sync_videofaceprisonbiz_blacklist() + + return { + "success": True, + "message": f"同步黑名单完成,加载了 {loaded_count} 个黑名单人员", + "loaded_count": loaded_count + } + + except Exception as e: + logger.error(f"同步VideoFacePrisonBiz黑名单失败: {e}") + raise HTTPException(status_code=500, detail=f"同步黑名单失败: {str(e)}") + +@router.get("/videofaceprisonbiz-status", summary="获取VideoFacePrisonBiz状态") +async def get_videofaceprisonbiz_status(): + """ + 获取VideoFacePrisonBiz的当前状态 + """ + try: + status = { + "list_mode": video_face_prison_biz.get_list_mode(), + "clarity_threshold": video_face_prison_biz.get_clarity_threshold(), + "min_face_size": video_face_prison_biz.get_min_face_size(), + "pitch_threshold": video_face_prison_biz.get_pitch_threshold(), + "yaw_threshold": video_face_prison_biz.get_yaw_threshold(), + "similarity_threshold": video_face_prison_biz.get_similarity_threshold(), + "blacklist_count": video_face_prison_biz.get_registered_face_count() + } + + return { + "success": True, + "data": status + } + + except Exception as e: + logger.error(f"获取VideoFacePrisonBiz状态失败: {e}") + raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") + # 导出路由器 -__all__ = ["router"] \ No newline at end of file +__all__ = ["router", "sync_videofacebiz_params", "sync_videofacebiz_blacklist", "sync_videofaceprisonbiz_params", "sync_videofaceprisonbiz_blacklist"] \ No newline at end of file diff --git a/app.py b/app.py index 17377d3..040f573 100644 --- a/app.py +++ b/app.py @@ -18,7 +18,8 @@ from fastapi.openapi.docs import ( from fastapi.staticfiles import StaticFiles from api.routes import face_features -from api.routes.algorithm_router import router as algorithm_router, sync_videofacebiz_params, sync_videofacebiz_blacklist +from api.routes.algorithm_router import router as algorithm_router, sync_videofacebiz_params, \ + sync_videofacebiz_blacklist, sync_videofaceprisonbiz_params, sync_videofaceprisonbiz_blacklist from api.errors import ( APIError, validation_exception_handler, @@ -72,6 +73,15 @@ async def lifespan(app: FastAPI): else: print("⚠️ RTSP 服务未启用") + # 自动同步VideoFacePrisonBiz参数和黑名单 + print("🔄 自动同步VideoFacePrisonBiz参数和黑名单...") + try: + params_updated = sync_videofaceprisonbiz_params() + blacklist_loaded = sync_videofaceprisonbiz_blacklist() + print(f"✅ 自动同步完成 - 参数更新: {params_updated}个, 黑名单加载: {blacklist_loaded}个") + except Exception as e: + print(f"⚠️ 自动同步失败: {e}") + yield # 关闭时 diff --git a/biz/video_face_prison_biz.py b/biz/video_face_prison_biz.py new file mode 100644 index 0000000..bc345c6 --- /dev/null +++ b/biz/video_face_prison_biz.py @@ -0,0 +1,67 @@ +""" +视频检查业务类 - RTSP专用 +专门处理RTSP视频流中的人脸识别和检测 +""" + +import cv2 +import numpy as np +from typing import Optional, List, Dict +import time +from insightface.app import FaceAnalysis +from biz.base_face_biz import BaseFaceBiz + +class VideoFacePrisonBiz(BaseFaceBiz): + """ + 视频检查业务类 - RTSP专用 + 专门处理RTSP视频流中的人脸识别和检测 + """ + + def __init__(self, face_analysis: FaceAnalysis): + """ + 初始化视频检查业务类 + + 参数: + face_analysis: 已初始化好的FaceAnalysis实例 + """ + super().__init__(face_analysis) + + def draw_detections(self, frame: np.ndarray, results: List[Dict]) -> np.ndarray: + """ + 重写绘制检测结果方法 + 只在检测到黑名单匹配时用红色绘制人脸框 + + 参数: + frame: 原始帧图像 + results: 检测结果列表 + + 返回: + 绘制后的帧图像 + """ + for result in results: + # 只在黑名单匹配时绘制 + if result['is_match']: + bbox = result['bbox'] + + # 使用红色绘制人脸框 + x1, y1, x2, y2 = bbox + cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 2) + + # 添加简单的匹配信息 + best_match = result['best_match'] + similarity = result['similarity'] + + # 绘制匹配信息 + text = f"{best_match}: {similarity:.3f}" + text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0] + + # 绘制文本背景 + cv2.rectangle(frame, (x1, y1 - text_size[1] - 5), + (x1 + text_size[0], y1), (0, 0, 0), -1) + + # 绘制文本 + cv2.putText(frame, text, (x1, y1 - 5), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) + + return frame + + diff --git a/models/sur_config.py b/models/sur_config.py index 6b6fc62..4b0d9a1 100644 --- a/models/sur_config.py +++ b/models/sur_config.py @@ -8,6 +8,22 @@ from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() +class SurConfig(Base): + """配置表""" + __tablename__ = "sur_config" + + id = Column(Integer, primary_key=True, index=True) + scope = Column(SmallInteger, nullable=False, comment="作用域:0=全局,1=房间,2=摄像头") + target_id = Column(Integer, comment="根据作用域,摄像头id或房间id或其他") + description = Column(Text, comment="描述") + created_time = Column(DateTime, default=func.now(), comment="创建时间") + updated_time = Column(DateTime, default=func.now(), comment="更新时间") + created_by = Column(Integer, comment="创建人") + updated_by = Column(Integer, comment="更新人") + config_type = Column(SmallInteger, nullable=False, comment="配置类型:0=人脸识别") + config_group_id = Column(Integer, comment="配置组id") + + class SurConfigBase(Base): """配置基础表""" __tablename__ = "sur_config_base"