From 4a7cfc8ccfcb30186664f70015f171fb03c575a6 Mon Sep 17 00:00:00 2001 From: zqc <835569504@qq.com> Date: Sun, 21 Dec 2025 16:00:32 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E5=BC=80=E5=A7=8B=E5=A4=84?= =?UTF-8?q?=E7=90=86=E8=A7=86=E9=A2=91http=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/routes/algorithm_router.py | 318 +++++++++++++++++++++ src/config.py | 1 + src/models/video_check_task.py | 55 ++++ src/repositories/video_check_repository.py | 105 +++++++ 4 files changed, 479 insertions(+) create mode 100644 src/models/video_check_task.py create mode 100644 src/repositories/video_check_repository.py diff --git a/src/api/routes/algorithm_router.py b/src/api/routes/algorithm_router.py index b9eb220..e9507d1 100644 --- a/src/api/routes/algorithm_router.py +++ b/src/api/routes/algorithm_router.py @@ -14,9 +14,12 @@ from sqlalchemy.orm import Session from src.config import settings from src.database.connection import db_manager from src.models.face_feature import SurFaceFeature, FeatureStatus +from src.models.video_check_task import SurVideoCheckTask from src.repositories.face_feature_repository import FaceFeatureRepository from src.face_recognition_algorithm import FaceRecognitionAlgorithm from src.base_face_biz import BaseFaceBiz +from src.video_check_biz import VideoCheckBiz +from src.repositories.video_check_repository import VideoCheckTaskRepository # 创建路由器 router = APIRouter(prefix="/algorithm", tags=["algorithm"]) @@ -292,5 +295,320 @@ async def calculate_single_feature(feature_id: int): raise HTTPException(status_code=500, detail=f"计算单个特征失败: {str(e)}") +def process_video_check_task(task_id: int) -> bool: + """ + 处理单个视频检查任务 + + 参数: + task_id: 任务ID + + 返回: + 是否成功处理 + """ + try: + with db_manager.get_session() as session: + repository = VideoCheckTaskRepository(session) + + # 获取任务记录 + task = repository.get_by_id(task_id) + if not task: + logger.error(f"视频检查任务不存在: {task_id}") + return False + + # 检查是否已经处理完成 + if task.status in [2, 3, 5]: # 完成、取消、失败 + logger.info(f"视频检查任务已处理完成: {task_id}, 状态: {task.status}") + return True + + # 检查是否超时 + if task.status == 1: # 处理中 + if task.start_time: + timeout_duration = timedelta(hours=settings.FACE_CAL_FEATURE_TIMEOUT_HOURS) + if datetime.now() - task.start_time > timeout_duration: + # 超时处理 + repository.update_task_status( + task_id, 5, finish_time=datetime.now(), + result=0, result_data={"error": "任务超时"} + ) + logger.warning(f"视频检查任务超时: {task_id}") + return False + else: + # 没有开始时间,重置状态 + repository.update_task_status(task_id, 0) + + # 处理未开始的任务 + if task.status == 0: # 等待 + # 设置状态为处理中 + repository.update_task_status(task_id, 1, start_time=datetime.now()) + logger.info(f"开始处理视频检查任务: {task_id}") + + # 创建VideoCheckBiz实例 + video_biz = VideoCheckBiz(face_algorithm.get_app()) + + # 获取配置参数 + config_dict = repository.get_config_by_group_id(task.config_id) + + # 设置VideoCheckBiz参数 + if config_dict: + param_mapping = { + "face.list_mode": "list_mode", + "face.clarity_threshold": "clarity_threshold", + "face.min_face_size": "min_face_size", + "face.pitch_threshold": "pitch_threshold", + "face.yaw_threshold": "yaw_threshold", + "face.similarity_threshold": "similarity_threshold", + "face.skip_frame": "frame_skip" + } + + for config_key, biz_param in param_mapping.items(): + if config_key in config_dict: + try: + value = config_dict[config_key] + if biz_param == "list_mode": + video_biz.set_list_mode(value) + elif biz_param == "clarity_threshold": + video_biz.set_clarity_threshold(float(value)) + elif biz_param == "min_face_size": + video_biz.set_min_face_size(int(value)) + elif biz_param == "pitch_threshold": + video_biz.set_pitch_threshold(float(value)) + elif biz_param == "yaw_threshold": + video_biz.set_yaw_threshold(float(value)) + elif biz_param == "similarity_threshold": + video_biz.set_similarity_threshold(float(value)) + elif biz_param == "frame_skip": + # frame_skip作为参数传递给方法,不设置到实例 + pass + except (ValueError, TypeError) as e: + logger.warning(f"参数设置失败 {config_key}: {value}, 错误: {str(e)}") + + # 获取目标视频路径 + target_video = repository.get_video_by_id(int(task.target_video_id)) + if not target_video: + logger.error(f"目标视频不存在: {task.target_video_id}") + repository.update_task_status( + task_id, 5, finish_time=datetime.now(), + result=0, result_data={"error": "目标视频不存在"} + ) + return False + + target_video_path = os.path.join(settings.VIDEO_RESOURCE_DIR, target_video.video_name_on_server) + + if not os.path.exists(target_video_path): + logger.error(f"目标视频文件不存在: {target_video_path}") + repository.update_task_status( + task_id, 5, finish_time=datetime.now(), + result=0, result_data={"error": "目标视频文件不存在"} + ) + return False + + # 提取最佳人脸特征 + frame_skip = int(config_dict.get("face.skip_frame", 10)) + best_feature = video_biz.extract_best_face_from_video(target_video_path, frame_skip) + + if best_feature is None: + logger.error(f"无法从目标视频中提取人脸特征: {task_id}") + repository.update_task_status( + task_id, 5, finish_time=datetime.now(), + result=0, result_data={"error": "无法从目标视频中提取人脸特征"} + ) + return False + + # 将特征值保存到数据库 + feature_bytes = best_feature.tobytes() + + # 设置黑名单(使用提取的特征) + video_biz.set_registered_faces({"target_person": best_feature}) + + # 获取待检查的视频列表 + video_ids = [int(vid.strip()) for vid in task.video_id_list.split(",") if vid.strip()] + video_list = repository.get_videos_by_ids(video_ids) + + if not video_list: + logger.error(f"待检查视频列表为空: {task_id}") + repository.update_task_status( + task_id, 5, finish_time=datetime.now(), + result=0, result_data={"error": "待检查视频列表为空"} + ) + return False + + # 构建视频路径列表 + video_paths = [] + for video in video_list: + video_path = os.path.join(settings.VIDEO_RESOURCE_DIR, video.video_name_on_server) + if os.path.exists(video_path): + video_paths.append(video_path) + else: + logger.warning(f"视频文件不存在,跳过: {video_path}") + + if not video_paths: + logger.error(f"所有待检查视频文件都不存在: {task_id}") + repository.update_task_status( + task_id, 5, finish_time=datetime.now(), + result=0, result_data={"error": "所有待检查视频文件都不存在"} + ) + return False + + # 批量处理视频进行黑名单检测 + results = video_biz.batch_process_videos_with_blacklist_detection( + video_paths, frame_skip, "_checked" + ) + + # 分析结果 + has_blacklist_match = any(result.get('has_blacklist_match', False) for result in results) + total_detections = sum(result.get('detection_count', 0) for result in results) + + # 更新任务状态 + result_status = 1 if has_blacklist_match else 2 # 1=找到,2=未找到 + + repository.update_task_status( + task_id, 2, finish_time=datetime.now(), + result=result_status, result_data={ + "has_blacklist_match": has_blacklist_match, + "total_detections": total_detections, + "video_results": results, + "target_video": target_video_path, + "checked_videos": len(video_paths) + }, + feature_data=feature_bytes + ) + + logger.info(f"视频检查任务完成: {task_id}, 结果: {'找到' if has_blacklist_match else '未找到'}") + return True + + return True + + except Exception as e: + logger.error(f"处理视频检查任务时发生异常: {task_id}, 错误: {str(e)}") + try: + with db_manager.get_session() as session: + repository = VideoCheckTaskRepository(session) + repository.update_task_status( + task_id, 5, finish_time=datetime.now(), + result=0, result_data={"error": str(e)} + ) + except Exception: + pass + return False + + +async def process_pending_video_checks(): + """ + 异步处理所有待处理的视频检查任务 + """ + try: + with db_manager.get_session() as session: + repository = VideoCheckTaskRepository(session) + + # 查找需要处理的任务(status=0) + pending_tasks = repository.get_pending_tasks() + + # 查找可能超时的任务(status=1且超时) + timeout_tasks = repository.get_timeout_tasks(settings.FACE_CAL_FEATURE_TIMEOUT_HOURS) + + total_pending = len(pending_tasks) + total_timeout = len(timeout_tasks) + + logger.info(f"发现待处理视频检查任务: {total_pending}个, 超时任务: {total_timeout}个") + + # 处理超时任务 + for task in timeout_tasks: + repository.update_task_status( + task.id, 5, finish_time=datetime.now(), + result=0, result_data={"error": "任务超时"} + ) + + if timeout_tasks: + session.commit() + + # 处理待处理任务 + processed_count = 0 + + for task in pending_tasks: + processed_count += 1 + process_video_check_task(task.id) + + # 每处理5个任务输出一次进度 + if processed_count % 5 == 0: + logger.info(f"视频检查处理进度: {processed_count}/{total_pending}") + + logger.info(f"视频检查任务处理完成: 共处理 {processed_count} 个任务") + + except Exception as e: + logger.error(f"批量处理视频检查任务时发生异常: {str(e)}") + + +@router.post("/start-video-check", summary="开始视频检查") +async def start_video_check(background_tasks: BackgroundTasks): + """ + 开始处理视频检查任务 + + 此接口会: + 1. 查找所有status为0的视频检查任务 + 2. 将状态改为1,设置开始时间 + 3. 提取目标视频中最佳人脸特征 + 4. 进行黑名单检测 + 5. 对于status为1且超时的任务,标记为失败 + + 返回处理结果统计 + """ + try: + # 在后台任务中异步处理,避免阻塞请求 + background_tasks.add_task(process_pending_video_checks) + + return { + "success": True, + "message": "收到视频检查请求" + } + + except Exception as e: + logger.error(f"启动视频检查失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"启动视频检查失败: {str(e)}") + + +@router.get("/video-check-status", summary="获取视频检查状态") +async def get_video_check_status(): + """ + 获取当前视频检查任务的状态统计 + """ + try: + with db_manager.get_session() as session: + repository = VideoCheckTaskRepository(session) + + # 获取统计信息 + total_tasks = len(repository.session.query(SurVideoCheckTask).all()) + pending_tasks = len(repository.get_pending_tasks()) + processing_tasks = len(repository.get_processing_tasks()) + + # 获取已完成任务的统计 + completed_tasks = repository.session.query(SurVideoCheckTask).filter( + SurVideoCheckTask.status == 2 + ).all() + + found_count = sum(1 for task in completed_tasks if task.result == 1) + not_found_count = sum(1 for task in completed_tasks if task.result == 2) + failed_count = len(repository.session.query(SurVideoCheckTask).filter( + SurVideoCheckTask.status == 5 + ).all()) + + return { + "success": True, + "data": { + "total_tasks": total_tasks, + "pending_tasks": pending_tasks, + "processing_tasks": processing_tasks, + "completed_tasks": len(completed_tasks), + "found_count": found_count, + "not_found_count": not_found_count, + "failed_count": failed_count, + "timeout_hours": settings.FACE_CAL_FEATURE_TIMEOUT_HOURS + } + } + + except Exception as e: + logger.error(f"获取视频检查状态失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取视频检查状态失败: {str(e)}") + + # 导出路由器 __all__ = ["router"] \ No newline at end of file diff --git a/src/config.py b/src/config.py index 8c8a718..b722354 100644 --- a/src/config.py +++ b/src/config.py @@ -54,6 +54,7 @@ class Settings(BaseSettings): FACE_MODEL_VERSION: int = 0 #insight_face_buffalo_l FACE_USE_GPU: bool = True FACE_USE_NPU: bool = False + SUR_CONFIG_TYPE_FACE: int = 0 # JWT配置(预留) SECRET_KEY: str = "your-secret-key-here-change-in-production" diff --git a/src/models/video_check_task.py b/src/models/video_check_task.py new file mode 100644 index 0000000..2e67d7f --- /dev/null +++ b/src/models/video_check_task.py @@ -0,0 +1,55 @@ +""" +视频检查任务模型 +""" + +from datetime import datetime +from typing import Optional, Dict, Any +from sqlalchemy import Column, Integer, String, DateTime, Text, JSON +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + + +class SurVideoCheckTask(Base): + """视频检查任务表""" + __tablename__ = "sur_video_check_task" + + id = Column(Integer, primary_key=True, comment="主键") + video_id_list = Column(Text, comment="视频id list,用,分隔") + target_video_id = Column(Text, comment="被查询人员所在video id") + config_id = Column(Integer, comment="配置id") + feature_data = Column(Text, comment="特征向量") + status = Column(Integer, comment="任务状态:0=等待,1=正在处理,2=完成,3=取消,5=失败") + progress = Column(Integer, default=0, comment="进度,1000满") + result = Column(Integer, comment="结果:0=未出结果,1=找到,2=未找到") + result_data = Column(JSON, comment="结果数据") + created_time = Column(DateTime, default=datetime.now, comment="创建时间") + created_by = Column(Integer, comment="创建人") + start_time = Column(DateTime, comment="任务开始时间") + finish_time = Column(DateTime, comment="任务结束时间") + + +class SurVideo(Base): + """视频表""" + __tablename__ = "sur_video" + + id = Column(Integer, primary_key=True, comment="主键") + video_name = Column(Text, comment="文件名") + created_time = Column(DateTime, default=datetime.now, comment="创建时间") + video_name_on_server = Column(Text, comment="服务器上的文件名") + + +class SurConfigBase(Base): + """配置基础表""" + __tablename__ = "sur_config_base" + + id = Column(Integer, primary_key=True, comment="主键") + config_type = Column(Integer, nullable=False, comment="配置类型:0=人脸识别") + group_id = Column(Integer, nullable=False, comment="组id") + config_key = Column(Text, nullable=False, comment="键") + config_value = Column(Text, comment="值") + description = Column(Text, comment="备注") + created_time = Column(DateTime, comment="创建时间") + updated_time = Column(DateTime, comment="修改时间") + created_by = Column(Integer, comment="创建人") + updated_by = Column(Integer, comment="修改人") \ No newline at end of file diff --git a/src/repositories/video_check_repository.py b/src/repositories/video_check_repository.py new file mode 100644 index 0000000..cecb5bf --- /dev/null +++ b/src/repositories/video_check_repository.py @@ -0,0 +1,105 @@ +""" +视频检查任务数据访问层 +""" + +from typing import List, Optional, Dict, Any +from sqlalchemy.orm import Session +from sqlalchemy import and_ + +from src.models.video_check_task import SurVideoCheckTask, SurVideo, SurConfigBase +from src.config import settings + + +class VideoCheckTaskRepository: + """视频检查任务数据访问类""" + + def __init__(self, session: Session): + self.session = session + + def get_pending_tasks(self) -> List[SurVideoCheckTask]: + """获取待处理的任务(status=0)""" + return self.session.query(SurVideoCheckTask).filter( + SurVideoCheckTask.status == 0 + ).all() + + def get_by_id(self, task_id: int) -> Optional[SurVideoCheckTask]: + """根据ID获取任务""" + return self.session.query(SurVideoCheckTask).filter( + SurVideoCheckTask.id == task_id + ).first() + + def update_task_status(self, task_id: int, status: int, + start_time=None, finish_time=None, + progress=None, result=None, result_data=None, + feature_data=None) -> bool: + """更新任务状态""" + try: + task = self.get_by_id(task_id) + if not task: + return False + + task.status = status + if start_time: + task.start_time = start_time + if finish_time: + task.finish_time = finish_time + if progress is not None: + task.progress = progress + if result is not None: + task.result = result + if result_data is not None: + task.result_data = result_data + if feature_data is not None: + task.feature_data = feature_data + + self.session.commit() + return True + except Exception: + self.session.rollback() + return False + + def get_config_by_group_id(self, group_id: int) -> Dict[str, str]: + """根据组ID获取配置""" + configs = self.session.query(SurConfigBase).filter( + and_( + SurConfigBase.config_type == settings.SUR_CONFIG_TYPE_FACE, + SurConfigBase.group_id == group_id + ) + ).all() + + config_dict = {} + for config in configs: + config_dict[config.config_key] = config.config_value + + return config_dict + + def get_video_by_id(self, video_id: int) -> Optional[SurVideo]: + """根据ID获取视频信息""" + return self.session.query(SurVideo).filter( + SurVideo.id == video_id + ).first() + + def get_videos_by_ids(self, video_ids: List[int]) -> List[SurVideo]: + """根据ID列表获取视频信息""" + return self.session.query(SurVideo).filter( + SurVideo.id.in_(video_ids) + ).all() + + def get_processing_tasks(self) -> List[SurVideoCheckTask]: + """获取正在处理的任务(status=1)""" + return self.session.query(SurVideoCheckTask).filter( + SurVideoCheckTask.status == 1 + ).all() + + def get_timeout_tasks(self, timeout_hours: int) -> List[SurVideoCheckTask]: + """获取超时的任务""" + from datetime import datetime, timedelta + + timeout_time = datetime.now() - timedelta(hours=timeout_hours) + + return self.session.query(SurVideoCheckTask).filter( + and_( + SurVideoCheckTask.status == 1, + SurVideoCheckTask.start_time < timeout_time + ) + ).all() \ No newline at end of file