diff --git a/src/api/routes/algorithm_router.py b/src/api/routes/algorithm_router.py index 951ff21..3804649 100644 --- a/src/api/routes/algorithm_router.py +++ b/src/api/routes/algorithm_router.py @@ -13,6 +13,8 @@ from src.config import settings from src.database.connection import db_manager from src.models.face_feature import FeatureStatus from src.models.video_check_task import SurVideoCheckTask +from src.models.sur_config import SurConfigBase +from src.models.sur_person import SurPersonBlacklist, SurFaceFeature from src.repositories.face_feature_repository import FaceFeatureRepository from src.algorithm.face_recognition_algorithm import FaceRecognitionAlgorithm from src.biz.base_face_biz import BaseFaceBiz @@ -621,5 +623,209 @@ async def get_video_check_status(): raise HTTPException(status_code=500, detail=f"获取视频检查状态失败: {str(e)}") +def sync_videofacebiz_params(): + """ + 同步VideoFaceBiz的参数 + """ + try: + with db_manager.get_session() as session: + # 查询人脸识别配置(根据实际表结构) + config_records = session.query(SurConfigBase).filter( + SurConfigBase.config_type == settings.SUR_CONFIG_TYPE_FACE + ).all() + + # 构建配置参数字典 + config_params = {} + for record in config_records: + if record.config_key and record.config_value: + config_params[record.config_key] = record.config_value + + # 配置键映射关系 + config_mapping = { + "face.list_mode": "list_mode", + "face.clarity_threshold": "clarity_threshold", + "face.min_face_size": "min_face_size", + "face.pitch_threshold": "pitch_threshold", + "face.yaw_threshold": "yaw_threshold", + "face.similarity_threshold": "similarity_threshold" + } + + updated_count = 0 + + for config_key, param_name in config_mapping.items(): + if config_key in config_params: + config_value = config_params[config_key] + + # 根据参数类型进行转换和设置 + if param_name == "list_mode": + if config_value in ["0", "1"]: + video_face_biz.set_list_mode(config_value) + updated_count += 1 + elif param_name == "clarity_threshold": + try: + threshold = float(config_value) + video_face_biz.set_clarity_threshold(threshold) + updated_count += 1 + except ValueError: + logger.error(f"无效的清晰度阈值: {config_value}") + elif param_name == "min_face_size": + try: + size = int(config_value) + video_face_biz.set_min_face_size(size) + updated_count += 1 + except ValueError: + logger.error(f"无效的最小人脸尺寸: {config_value}") + elif param_name == "pitch_threshold": + try: + threshold = float(config_value) + video_face_biz.set_pitch_threshold(threshold) + updated_count += 1 + except ValueError: + logger.error(f"无效的俯仰角阈值: {config_value}") + elif param_name == "yaw_threshold": + try: + threshold = float(config_value) + video_face_biz.set_yaw_threshold(threshold) + updated_count += 1 + except ValueError: + logger.error(f"无效的偏航角阈值: {config_value}") + elif param_name == "similarity_threshold": + try: + threshold = float(config_value) + video_face_biz.set_similarity_threshold(threshold) + updated_count += 1 + except ValueError: + logger.error(f"无效的相似度阈值: {config_value}") + + logger.info(f"✅ 同步VideoFaceBiz参数完成,更新了 {updated_count} 个参数") + return updated_count + + except Exception as e: + logger.error(f"❌ 同步VideoFaceBiz参数失败: {e}") + return 0 + +def sync_videofacebiz_blacklist(): + """ + 同步VideoFaceBiz的黑名单 + """ + try: + with db_manager.get_session() as session: + # 查询启用的黑名单人员 + blacklist_persons = session.query(SurPersonBlacklist).filter( + SurPersonBlacklist.status == 1 + ).all() + + if not blacklist_persons: + logger.info("⚠️ 黑名单为空,清空当前黑名单") + video_face_biz.set_registered_faces({}) + return 0 + + person_ids = [person.person_id for person in blacklist_persons] + + # 查询对应的人脸特征 + face_features = session.query(SurFaceFeature).filter( + SurFaceFeature.person_id.in_(person_ids), + SurFaceFeature.feature_type == settings.FACE_MODEL_VERSION, + SurFaceFeature.status == 2 # 计算成功的特征 + ).all() + + # 构建特征字典 + registered_faces = {} + loaded_count = 0 + + for feature in face_features: + if feature.feature_data: + try: + # 将bytea转换为numpy数组 + import numpy as np + feature_array = np.frombuffer(feature.feature_data, dtype=np.float32) + + # 使用person_id作为标识符 + person_name = f"blacklist_{feature.person_id}" + registered_faces[person_name] = feature_array + loaded_count += 1 + + except Exception as e: + logger.error(f"❌ 解析黑名单人员 {feature.person_id} 的特征数据失败: {e}") + continue + + # 设置黑名单 + success = video_face_biz.set_registered_faces(registered_faces) + if success: + logger.info(f"✅ 同步黑名单完成,加载了 {loaded_count} 个黑名单人员") + else: + logger.error("❌ 设置黑名单失败") + + return loaded_count + + except Exception as e: + logger.error(f"❌ 同步黑名单失败: {e}") + return 0 + +@router.post("/sync-videofacebiz-params", summary="同步VideoFaceBiz参数") +async def sync_videofacebiz_params_endpoint(): + """ + 同步VideoFaceBiz的参数 + + 从sur_config表同步参数到VideoFaceBiz实例 + """ + try: + updated_count = sync_videofacebiz_params() + + return { + "success": True, + "message": f"同步参数完成,更新了 {updated_count} 个参数", + "updated_count": updated_count + } + + except Exception as e: + logger.error(f"同步VideoFaceBiz参数失败: {e}") + raise HTTPException(status_code=500, detail=f"同步参数失败: {str(e)}") + +@router.post("/sync-videofacebiz-blacklist", summary="同步VideoFaceBiz黑名单") +async def sync_videofacebiz_blacklist_endpoint(): + """ + 同步VideoFaceBiz的黑名单 + + 从sur_person_blacklist表同步黑名单到VideoFaceBiz实例 + """ + try: + loaded_count = sync_videofacebiz_blacklist() + + return { + "success": True, + "message": f"同步黑名单完成,加载了 {loaded_count} 个黑名单人员", + "loaded_count": loaded_count + } + + except Exception as e: + logger.error(f"同步VideoFaceBiz黑名单失败: {e}") + raise HTTPException(status_code=500, detail=f"同步黑名单失败: {str(e)}") + +@router.get("/videofacebiz-status", summary="获取VideoFaceBiz状态") +async def get_videofacebiz_status(): + """ + 获取VideoFaceBiz的当前状态 + """ + try: + status = { + "list_mode": video_face_biz.get_list_mode(), + "clarity_threshold": video_face_biz.get_clarity_threshold(), + "min_face_size": video_face_biz.get_min_face_size(), + "pitch_threshold": video_face_biz.get_pitch_threshold(), + "yaw_threshold": video_face_biz.get_yaw_threshold(), + "similarity_threshold": video_face_biz.get_similarity_threshold(), + "blacklist_count": video_face_biz.get_registered_face_count() + } + + return { + "success": True, + "data": status + } + + except Exception as e: + logger.error(f"获取VideoFaceBiz状态失败: {e}") + raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") + # 导出路由器 __all__ = ["router"] \ No newline at end of file diff --git a/src/app.py b/src/app.py index c0e6aa8..c005491 100644 --- a/src/app.py +++ b/src/app.py @@ -18,7 +18,7 @@ from fastapi.openapi.docs import ( from fastapi.staticfiles import StaticFiles from src.api.routes import face_features -from src.api.routes.algorithm_router import router as algorithm_router +from src.api.routes.algorithm_router import router as algorithm_router, sync_videofacebiz_params, sync_videofacebiz_blacklist from src.api.errors import ( APIError, validation_exception_handler, @@ -60,6 +60,15 @@ async def lifespan(app: FastAPI): rtsp_server.start() # 将 RTSP 服务实例保存到应用状态 app.state.rtsp_server = rtsp_server + + # 自动同步VideoFaceBiz参数和黑名单 + print("🔄 自动同步VideoFaceBiz参数和黑名单...") + try: + params_updated = sync_videofacebiz_params() + blacklist_loaded = sync_videofacebiz_blacklist() + print(f"✅ 自动同步完成 - 参数更新: {params_updated}个, 黑名单加载: {blacklist_loaded}个") + except Exception as e: + print(f"⚠️ 自动同步失败: {e}") else: print("⚠️ RTSP 服务未启用") diff --git a/src/config.py b/src/config.py index dac011a..48a4bb6 100644 --- a/src/config.py +++ b/src/config.py @@ -57,6 +57,7 @@ class Settings(BaseSettings): FACE_USE_GPU: bool = True FACE_USE_NPU: bool = False SUR_CONFIG_TYPE_FACE: int = 0 + SUR_CONFIG_SCOPE_GLOBAL: int = 0 # JWT配置(预留) SECRET_KEY: str = "your-secret-key-here-change-in-production" diff --git a/src/models/sur_config.py b/src/models/sur_config.py new file mode 100644 index 0000000..6b6fc62 --- /dev/null +++ b/src/models/sur_config.py @@ -0,0 +1,24 @@ +""" +配置表模型 +""" + +from sqlalchemy import Column, Integer, String, Text, DateTime, func, SmallInteger +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + + +class SurConfigBase(Base): + """配置基础表""" + __tablename__ = "sur_config_base" + + id = Column(Integer, primary_key=True, index=True) + config_type = Column(SmallInteger, nullable=False, comment="配置类型:0=人脸识别") + group_id = Column(Integer, nullable=False, comment="组id") + config_key = Column(Text, nullable=False, comment="键") + config_value = Column(Text, comment="值") + description = Column(Text, comment="备注") + created_time = Column(DateTime, comment="创建时间") + updated_time = Column(DateTime, comment="修改时间") + created_by = Column(Integer, comment="创建人") + updated_by = Column(Integer, comment="修改人") \ No newline at end of file diff --git a/src/models/sur_person.py b/src/models/sur_person.py new file mode 100644 index 0000000..bdf119b --- /dev/null +++ b/src/models/sur_person.py @@ -0,0 +1,34 @@ +""" +人员相关表模型 +""" + +from sqlalchemy import Column, Integer, String, DateTime, func, Text +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + + +class SurPersonBlacklist(Base): + """人员黑名单表""" + __tablename__ = "sur_person_blacklist" + + id = Column(Integer, primary_key=True, index=True) + person_id = Column(Integer, nullable=False, comment="人员ID") + status = Column(Integer, nullable=False, default=1, comment="状态:0=禁用,1=启用") + created_time = Column(DateTime, default=func.now(), comment="创建时间") + updated_time = Column(DateTime, default=func.now(), onupdate=func.now(), comment="更新时间") + + +class SurFaceFeature(Base): + """人脸特征表""" + __tablename__ = "sur_face_feature" + + id = Column(Integer, primary_key=True, index=True) + person_id = Column(Integer, nullable=False, comment="人员ID") + feature_type = Column(Integer, comment="模型版本") + feature_data = Column(Text, comment="特征值") + created_time = Column(DateTime, default=func.now(), comment="创建时间") + pic_id = Column(String(255), comment="图片ID") + status = Column(Integer, default=0, comment="人脸特征值计算状态:0=未开始,1=计算中,2=计算成功,3=计算失败") + start_time = Column(DateTime, comment="特征计算开始时间") + finish_time = Column(DateTime, comment="特征计算结束时间") \ No newline at end of file