diff --git a/src/api/routes/algorithm_router.py b/src/api/routes/algorithm_router.py index 2b6eb8a..951ff21 100644 --- a/src/api/routes/algorithm_router.py +++ b/src/api/routes/algorithm_router.py @@ -17,6 +17,7 @@ from src.repositories.face_feature_repository import FaceFeatureRepository from src.algorithm.face_recognition_algorithm import FaceRecognitionAlgorithm from src.biz.base_face_biz import BaseFaceBiz from src.biz.video_check_biz import VideoCheckBiz +from src.biz.video_face_biz import VideoFaceBiz from src.repositories.video_check_repository import VideoCheckTaskRepository # 创建路由器 @@ -25,6 +26,13 @@ router = APIRouter(prefix="/algorithm", tags=["algorithm"]) # 初始化人脸识别算法 face_algorithm = FaceRecognitionAlgorithm(use_gpu=settings.FACE_USE_GPU, use_npu=settings.FACE_USE_NPU) +# 初始化RTSP专用人脸识别算法 +face_algorithm_for_rtsp = FaceRecognitionAlgorithm(use_gpu=settings.FACE_USE_GPU, use_npu=settings.FACE_USE_NPU) + +# 初始化RTSP专用VideoFaceBiz实例 +video_face_biz = VideoFaceBiz(face_algorithm_for_rtsp.get_app()) + + logger = logging.getLogger(__name__) diff --git a/src/biz/video_face_biz.py b/src/biz/video_face_biz.py new file mode 100644 index 0000000..c1d73b4 --- /dev/null +++ b/src/biz/video_face_biz.py @@ -0,0 +1,67 @@ +""" +视频检查业务类 - RTSP专用 +专门处理RTSP视频流中的人脸识别和检测 +""" + +import cv2 +import numpy as np +from typing import Optional, List, Dict +import time +from insightface.app import FaceAnalysis +from src.biz.base_face_biz import BaseFaceBiz + +class VideoFaceBiz(BaseFaceBiz): + """ + 视频检查业务类 - RTSP专用 + 专门处理RTSP视频流中的人脸识别和检测 + """ + + def __init__(self, face_analysis: FaceAnalysis): + """ + 初始化视频检查业务类 + + 参数: + face_analysis: 已初始化好的FaceAnalysis实例 + """ + super().__init__(face_analysis) + + def draw_detections(self, frame: np.ndarray, results: List[Dict]) -> np.ndarray: + """ + 重写绘制检测结果方法 + 只在检测到黑名单匹配时用红色绘制人脸框 + + 参数: + frame: 原始帧图像 + results: 检测结果列表 + + 返回: + 绘制后的帧图像 + """ + for result in results: + # 只在黑名单匹配时绘制 + if result['is_match']: + bbox = result['bbox'] + + # 使用红色绘制人脸框 + x1, y1, x2, y2 = bbox + cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 2) + + # 添加简单的匹配信息 + best_match = result['best_match'] + similarity = result['similarity'] + + # 绘制匹配信息 + text = f"{best_match}: {similarity:.3f}" + text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0] + + # 绘制文本背景 + cv2.rectangle(frame, (x1, y1 - text_size[1] - 5), + (x1 + text_size[0], y1), (0, 0, 0), -1) + + # 绘制文本 + cv2.putText(frame, text, (x1, y1 - 5), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) + + return frame + + diff --git a/src/rtsp_service_ws_1217.py b/src/rtsp_service_ws_1217.py index 23ba7ae..8f0cb9e 100644 --- a/src/rtsp_service_ws_1217.py +++ b/src/rtsp_service_ws_1217.py @@ -21,11 +21,10 @@ import torch # 导入人脸识别算法 try: - from algorithm.face_recognition_algorithm import FaceRecognitionAlgorithm + from src.api.routes.algorithm_router import video_face_biz print("[INFO] 成功导入人脸识别算法") except Exception as e: - FaceRecognitionAlgorithm = None print(f"[WARN] 无法导入人脸识别算法: {e}") from yolox.tracker.byte_tracker import BYTETracker @@ -124,8 +123,6 @@ sess_suspect = None input_name_sup = None input_name_sus = None -# 人脸识别算法实例 -face_algorithm = None # ========================= @@ -133,7 +130,7 @@ face_algorithm = None # ========================= def init_models_once(): - global yolo_model, sess_supervisor, sess_suspect, input_name_sup, input_name_sus, face_algorithm + global yolo_model, sess_supervisor, sess_suspect, input_name_sup, input_name_sus # YOLO if YOLOv8_ONNX is None: @@ -191,30 +188,6 @@ def init_models_once(): input_name_sus = sess_suspect.get_inputs()[0].name print(f"[INFO] 被监护人模型输入: {input_name_sus}") - # ----------------------------- - # 人脸识别算法初始化 - # ----------------------------- - if FACE_RECOGNITION_ENABLED and FaceRecognitionAlgorithm is not None: - try: - # 根据实际情况选择设备(这里使用GPU,如果需要NPU可修改) - face_algorithm = FaceRecognitionAlgorithm(use_gpu=True, use_npu=False) - - # 设置黑名单模式 - face_algorithm.set_list_mode("blacklist") - - # 加载注册人脸 - if os.path.exists(FACE_REGISTER_DIR): - face_algorithm.load_registered_faces(FACE_REGISTER_DIR) - print(f"[INFO] 人脸识别算法初始化完成,加载了 {face_algorithm.get_registered_face_count()} 张注册人脸") - else: - print(f"[WARN] 人脸注册目录不存在: {FACE_REGISTER_DIR},人脸识别功能可能受限") - - except Exception as e: - print(f"[ERROR] 人脸识别算法初始化失败: {e}") - face_algorithm = None - else: - print("[INFO] 人脸识别功能未启用或不可用") - # 只初始化一次 init_models_once() @@ -475,7 +448,6 @@ class FrameProcessorWorker(threading.Thread): self.last_process_ts: Dict[int, float] = {} # 人脸识别相关初始化 - self.face_algorithm = face_algorithm self.face_last_alert = face_last_alert self.current_face_alert = None # 存储当前帧的人脸告警信息 @@ -548,10 +520,10 @@ class FrameProcessorWorker(threading.Thread): # 2) 进行人脸识别(如果启用) face_results = [] face_processing_time = 0 - if self.face_algorithm is not None and FACE_RECOGNITION_ENABLED: + if video_face_biz is not None and FACE_RECOGNITION_ENABLED: try: # 处理当前帧 - 获取人脸识别结果 - processed_frame_for_face, face_results, face_processing_time = self.face_algorithm.process_frame( + processed_frame_for_face, face_results, face_processing_time = video_face_biz.process_frame( frame) # 检查是否有黑名单匹配 @@ -664,8 +636,8 @@ class FrameProcessorWorker(threading.Thread): # 绘制人脸识别结果 - if self.face_algorithm is not None and face_results: - result_img = self.face_algorithm.draw_detections(result_img, face_results) + if video_face_biz is not None and face_results: + result_img = video_face_biz.draw_detections(result_img, face_results) # 添加人脸识别统计信息 match_count = sum(1 for r in face_results if r['is_match'])