完成黑名单、参数同步
This commit is contained in:
@@ -13,6 +13,8 @@ from src.config import settings
|
||||
from src.database.connection import db_manager
|
||||
from src.models.face_feature import FeatureStatus
|
||||
from src.models.video_check_task import SurVideoCheckTask
|
||||
from src.models.sur_config import SurConfigBase
|
||||
from src.models.sur_person import SurPersonBlacklist, SurFaceFeature
|
||||
from src.repositories.face_feature_repository import FaceFeatureRepository
|
||||
from src.algorithm.face_recognition_algorithm import FaceRecognitionAlgorithm
|
||||
from src.biz.base_face_biz import BaseFaceBiz
|
||||
@@ -621,5 +623,209 @@ async def get_video_check_status():
|
||||
raise HTTPException(status_code=500, detail=f"获取视频检查状态失败: {str(e)}")
|
||||
|
||||
|
||||
def sync_videofacebiz_params():
|
||||
"""
|
||||
同步VideoFaceBiz的参数
|
||||
"""
|
||||
try:
|
||||
with db_manager.get_session() as session:
|
||||
# 查询人脸识别配置(根据实际表结构)
|
||||
config_records = session.query(SurConfigBase).filter(
|
||||
SurConfigBase.config_type == settings.SUR_CONFIG_TYPE_FACE
|
||||
).all()
|
||||
|
||||
# 构建配置参数字典
|
||||
config_params = {}
|
||||
for record in config_records:
|
||||
if record.config_key and record.config_value:
|
||||
config_params[record.config_key] = record.config_value
|
||||
|
||||
# 配置键映射关系
|
||||
config_mapping = {
|
||||
"face.list_mode": "list_mode",
|
||||
"face.clarity_threshold": "clarity_threshold",
|
||||
"face.min_face_size": "min_face_size",
|
||||
"face.pitch_threshold": "pitch_threshold",
|
||||
"face.yaw_threshold": "yaw_threshold",
|
||||
"face.similarity_threshold": "similarity_threshold"
|
||||
}
|
||||
|
||||
updated_count = 0
|
||||
|
||||
for config_key, param_name in config_mapping.items():
|
||||
if config_key in config_params:
|
||||
config_value = config_params[config_key]
|
||||
|
||||
# 根据参数类型进行转换和设置
|
||||
if param_name == "list_mode":
|
||||
if config_value in ["0", "1"]:
|
||||
video_face_biz.set_list_mode(config_value)
|
||||
updated_count += 1
|
||||
elif param_name == "clarity_threshold":
|
||||
try:
|
||||
threshold = float(config_value)
|
||||
video_face_biz.set_clarity_threshold(threshold)
|
||||
updated_count += 1
|
||||
except ValueError:
|
||||
logger.error(f"无效的清晰度阈值: {config_value}")
|
||||
elif param_name == "min_face_size":
|
||||
try:
|
||||
size = int(config_value)
|
||||
video_face_biz.set_min_face_size(size)
|
||||
updated_count += 1
|
||||
except ValueError:
|
||||
logger.error(f"无效的最小人脸尺寸: {config_value}")
|
||||
elif param_name == "pitch_threshold":
|
||||
try:
|
||||
threshold = float(config_value)
|
||||
video_face_biz.set_pitch_threshold(threshold)
|
||||
updated_count += 1
|
||||
except ValueError:
|
||||
logger.error(f"无效的俯仰角阈值: {config_value}")
|
||||
elif param_name == "yaw_threshold":
|
||||
try:
|
||||
threshold = float(config_value)
|
||||
video_face_biz.set_yaw_threshold(threshold)
|
||||
updated_count += 1
|
||||
except ValueError:
|
||||
logger.error(f"无效的偏航角阈值: {config_value}")
|
||||
elif param_name == "similarity_threshold":
|
||||
try:
|
||||
threshold = float(config_value)
|
||||
video_face_biz.set_similarity_threshold(threshold)
|
||||
updated_count += 1
|
||||
except ValueError:
|
||||
logger.error(f"无效的相似度阈值: {config_value}")
|
||||
|
||||
logger.info(f"✅ 同步VideoFaceBiz参数完成,更新了 {updated_count} 个参数")
|
||||
return updated_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 同步VideoFaceBiz参数失败: {e}")
|
||||
return 0
|
||||
|
||||
def sync_videofacebiz_blacklist():
|
||||
"""
|
||||
同步VideoFaceBiz的黑名单
|
||||
"""
|
||||
try:
|
||||
with db_manager.get_session() as session:
|
||||
# 查询启用的黑名单人员
|
||||
blacklist_persons = session.query(SurPersonBlacklist).filter(
|
||||
SurPersonBlacklist.status == 1
|
||||
).all()
|
||||
|
||||
if not blacklist_persons:
|
||||
logger.info("⚠️ 黑名单为空,清空当前黑名单")
|
||||
video_face_biz.set_registered_faces({})
|
||||
return 0
|
||||
|
||||
person_ids = [person.person_id for person in blacklist_persons]
|
||||
|
||||
# 查询对应的人脸特征
|
||||
face_features = session.query(SurFaceFeature).filter(
|
||||
SurFaceFeature.person_id.in_(person_ids),
|
||||
SurFaceFeature.feature_type == settings.FACE_MODEL_VERSION,
|
||||
SurFaceFeature.status == 2 # 计算成功的特征
|
||||
).all()
|
||||
|
||||
# 构建特征字典
|
||||
registered_faces = {}
|
||||
loaded_count = 0
|
||||
|
||||
for feature in face_features:
|
||||
if feature.feature_data:
|
||||
try:
|
||||
# 将bytea转换为numpy数组
|
||||
import numpy as np
|
||||
feature_array = np.frombuffer(feature.feature_data, dtype=np.float32)
|
||||
|
||||
# 使用person_id作为标识符
|
||||
person_name = f"blacklist_{feature.person_id}"
|
||||
registered_faces[person_name] = feature_array
|
||||
loaded_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 解析黑名单人员 {feature.person_id} 的特征数据失败: {e}")
|
||||
continue
|
||||
|
||||
# 设置黑名单
|
||||
success = video_face_biz.set_registered_faces(registered_faces)
|
||||
if success:
|
||||
logger.info(f"✅ 同步黑名单完成,加载了 {loaded_count} 个黑名单人员")
|
||||
else:
|
||||
logger.error("❌ 设置黑名单失败")
|
||||
|
||||
return loaded_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 同步黑名单失败: {e}")
|
||||
return 0
|
||||
|
||||
@router.post("/sync-videofacebiz-params", summary="同步VideoFaceBiz参数")
|
||||
async def sync_videofacebiz_params_endpoint():
|
||||
"""
|
||||
同步VideoFaceBiz的参数
|
||||
|
||||
从sur_config表同步参数到VideoFaceBiz实例
|
||||
"""
|
||||
try:
|
||||
updated_count = sync_videofacebiz_params()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"同步参数完成,更新了 {updated_count} 个参数",
|
||||
"updated_count": updated_count
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"同步VideoFaceBiz参数失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"同步参数失败: {str(e)}")
|
||||
|
||||
@router.post("/sync-videofacebiz-blacklist", summary="同步VideoFaceBiz黑名单")
|
||||
async def sync_videofacebiz_blacklist_endpoint():
|
||||
"""
|
||||
同步VideoFaceBiz的黑名单
|
||||
|
||||
从sur_person_blacklist表同步黑名单到VideoFaceBiz实例
|
||||
"""
|
||||
try:
|
||||
loaded_count = sync_videofacebiz_blacklist()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"同步黑名单完成,加载了 {loaded_count} 个黑名单人员",
|
||||
"loaded_count": loaded_count
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"同步VideoFaceBiz黑名单失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"同步黑名单失败: {str(e)}")
|
||||
|
||||
@router.get("/videofacebiz-status", summary="获取VideoFaceBiz状态")
|
||||
async def get_videofacebiz_status():
|
||||
"""
|
||||
获取VideoFaceBiz的当前状态
|
||||
"""
|
||||
try:
|
||||
status = {
|
||||
"list_mode": video_face_biz.get_list_mode(),
|
||||
"clarity_threshold": video_face_biz.get_clarity_threshold(),
|
||||
"min_face_size": video_face_biz.get_min_face_size(),
|
||||
"pitch_threshold": video_face_biz.get_pitch_threshold(),
|
||||
"yaw_threshold": video_face_biz.get_yaw_threshold(),
|
||||
"similarity_threshold": video_face_biz.get_similarity_threshold(),
|
||||
"blacklist_count": video_face_biz.get_registered_face_count()
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": status
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取VideoFaceBiz状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}")
|
||||
|
||||
# 导出路由器
|
||||
__all__ = ["router"]
|
||||
11
src/app.py
11
src/app.py
@@ -18,7 +18,7 @@ from fastapi.openapi.docs import (
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from src.api.routes import face_features
|
||||
from src.api.routes.algorithm_router import router as algorithm_router
|
||||
from src.api.routes.algorithm_router import router as algorithm_router, sync_videofacebiz_params, sync_videofacebiz_blacklist
|
||||
from src.api.errors import (
|
||||
APIError,
|
||||
validation_exception_handler,
|
||||
@@ -60,6 +60,15 @@ async def lifespan(app: FastAPI):
|
||||
rtsp_server.start()
|
||||
# 将 RTSP 服务实例保存到应用状态
|
||||
app.state.rtsp_server = rtsp_server
|
||||
|
||||
# 自动同步VideoFaceBiz参数和黑名单
|
||||
print("🔄 自动同步VideoFaceBiz参数和黑名单...")
|
||||
try:
|
||||
params_updated = sync_videofacebiz_params()
|
||||
blacklist_loaded = sync_videofacebiz_blacklist()
|
||||
print(f"✅ 自动同步完成 - 参数更新: {params_updated}个, 黑名单加载: {blacklist_loaded}个")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 自动同步失败: {e}")
|
||||
else:
|
||||
print("⚠️ RTSP 服务未启用")
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ class Settings(BaseSettings):
|
||||
FACE_USE_GPU: bool = True
|
||||
FACE_USE_NPU: bool = False
|
||||
SUR_CONFIG_TYPE_FACE: int = 0
|
||||
SUR_CONFIG_SCOPE_GLOBAL: int = 0
|
||||
|
||||
# JWT配置(预留)
|
||||
SECRET_KEY: str = "your-secret-key-here-change-in-production"
|
||||
|
||||
24
src/models/sur_config.py
Normal file
24
src/models/sur_config.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
配置表模型
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, func, SmallInteger
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class SurConfigBase(Base):
|
||||
"""配置基础表"""
|
||||
__tablename__ = "sur_config_base"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
config_type = Column(SmallInteger, nullable=False, comment="配置类型:0=人脸识别")
|
||||
group_id = Column(Integer, nullable=False, comment="组id")
|
||||
config_key = Column(Text, nullable=False, comment="键")
|
||||
config_value = Column(Text, comment="值")
|
||||
description = Column(Text, comment="备注")
|
||||
created_time = Column(DateTime, comment="创建时间")
|
||||
updated_time = Column(DateTime, comment="修改时间")
|
||||
created_by = Column(Integer, comment="创建人")
|
||||
updated_by = Column(Integer, comment="修改人")
|
||||
34
src/models/sur_person.py
Normal file
34
src/models/sur_person.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
人员相关表模型
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, func, Text
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class SurPersonBlacklist(Base):
|
||||
"""人员黑名单表"""
|
||||
__tablename__ = "sur_person_blacklist"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
person_id = Column(Integer, nullable=False, comment="人员ID")
|
||||
status = Column(Integer, nullable=False, default=1, comment="状态:0=禁用,1=启用")
|
||||
created_time = Column(DateTime, default=func.now(), comment="创建时间")
|
||||
updated_time = Column(DateTime, default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
|
||||
class SurFaceFeature(Base):
|
||||
"""人脸特征表"""
|
||||
__tablename__ = "sur_face_feature"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
person_id = Column(Integer, nullable=False, comment="人员ID")
|
||||
feature_type = Column(Integer, comment="模型版本")
|
||||
feature_data = Column(Text, comment="特征值")
|
||||
created_time = Column(DateTime, default=func.now(), comment="创建时间")
|
||||
pic_id = Column(String(255), comment="图片ID")
|
||||
status = Column(Integer, default=0, comment="人脸特征值计算状态:0=未开始,1=计算中,2=计算成功,3=计算失败")
|
||||
start_time = Column(DateTime, comment="特征计算开始时间")
|
||||
finish_time = Column(DateTime, comment="特征计算结束时间")
|
||||
Reference in New Issue
Block a user