修改路径,从src放到根目录

This commit is contained in:
zqc
2026-01-08 10:32:36 +08:00
parent 96589ebdbd
commit f86effd63c
37 changed files with 51 additions and 410 deletions

View File

@@ -0,0 +1,676 @@
"""
人脸特征数据仓库
数据访问层,处理所有数据库操作
"""
import logging
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any, Tuple
from contextlib import contextmanager
from sqlalchemy import select, update, delete, func, and_, or_, desc, asc
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.exc import SQLAlchemyError, IntegrityError
from models.face_feature import SurFaceFeature
from schemas.face_feature import (
FaceFeatureCreate,
FaceFeatureUpdate,
FaceFeatureQuery,
FeatureStatus
)
from models.face_feature import FeatureStatusEnum
from utils.logger import setup_logger
logger = setup_logger(__name__)
class FaceFeatureRepository:
"""人脸特征数据仓库"""
def __init__(self, session: Session):
"""
初始化仓库
Args:
session: SQLAlchemy会话对象
"""
self.session = session
# ===== 创建操作 =====
def create(self, feature_data: FaceFeatureCreate) -> SurFaceFeature:
"""
创建特征记录
Args:
feature_data: 特征数据
Returns:
创建的SurFaceFeature对象
Raises:
IntegrityError: 违反唯一约束时抛出
"""
try:
# 转换为模型字典
feature_dict = feature_data.model_dump(exclude_unset=True)
# 创建模型实例
feature = SurFaceFeature(**feature_dict)
# 添加到会话
self.session.add(feature)
self.session.flush() # 立即执行,但不提交
logger.info(f"Created face feature record: id={feature.id}, person_id={feature.person_id}")
return feature
except IntegrityError as e:
logger.error(f"Integrity error creating face feature: {e}")
self.session.rollback()
raise ValueError(f"Duplicate feature record for person_id={feature_data.person_id}, "
f"feature_type={feature_data.feature_type}")
except SQLAlchemyError as e:
logger.error(f"Database error creating face feature: {e}")
self.session.rollback()
raise
def create_batch(self, features_data: List[FaceFeatureCreate]) -> List[SurFaceFeature]:
"""
批量创建特征记录
Args:
features_data: 特征数据列表
Returns:
创建的SurFaceFeature对象列表
"""
try:
features = []
for feature_data in features_data:
feature_dict = feature_data.model_dump(exclude_unset=True)
feature = SurFaceFeature(**feature_dict)
features.append(feature)
# 批量添加
self.session.add_all(features)
self.session.flush()
logger.info(f"Created {len(features)} face feature records in batch")
return features
except IntegrityError as e:
logger.error(f"Integrity error creating batch face features: {e}")
self.session.rollback()
raise ValueError("Duplicate feature record in batch")
except SQLAlchemyError as e:
logger.error(f"Database error creating batch face features: {e}")
self.session.rollback()
raise
# ===== 查询操作 =====
def get_by_id(self, feature_id: int) -> Optional[SurFaceFeature]:
"""
根据ID获取特征记录
Args:
feature_id: 特征记录ID
Returns:
SurFaceFeature对象或None
"""
try:
stmt = select(SurFaceFeature).where(SurFaceFeature.id == feature_id)
result = self.session.execute(stmt)
feature = result.scalar_one_or_none()
if feature:
logger.debug(f"Retrieved face feature by id: {feature_id}")
else:
logger.debug(f"Face feature not found by id: {feature_id}")
return feature
except SQLAlchemyError as e:
logger.error(f"Database error getting face feature by id: {e}")
raise
def get_by_person_and_type(self, person_id: int, feature_type: int) -> Optional[SurFaceFeature]:
"""
根据人员ID和特征类型获取特征记录
Args:
person_id: 人员ID
feature_type: 特征类型
Returns:
SurFaceFeature对象或None
"""
try:
stmt = select(SurFaceFeature).where(
and_(
SurFaceFeature.person_id == person_id,
SurFaceFeature.feature_type == feature_type
)
)
result = self.session.execute(stmt)
feature = result.scalar_one_or_none()
if feature:
logger.debug(f"Retrieved face feature: person_id={person_id}, feature_type={feature_type}")
else:
logger.debug(f"Face feature not found: person_id={person_id}, feature_type={feature_type}")
return feature
except SQLAlchemyError as e:
logger.error(f"Database error getting face feature by person and type: {e}")
raise
def get_by_person(self, person_id: int, limit: int = 100) -> List[SurFaceFeature]:
"""
根据人员ID获取特征记录列表
Args:
person_id: 人员ID
limit: 返回数量限制
Returns:
SurFaceFeature对象列表
"""
try:
stmt = (
select(SurFaceFeature)
.where(SurFaceFeature.person_id == person_id)
.order_by(desc(SurFaceFeature.created_time))
.limit(limit)
)
result = self.session.execute(stmt)
features = list(result.scalars().all())
logger.debug(f"Retrieved {len(features)} face features for person_id={person_id}")
return features
except SQLAlchemyError as e:
logger.error(f"Database error getting face features by person: {e}")
raise
def query_features(
self,
query: FaceFeatureQuery,
page: int = 1,
page_size: int = 20,
order_by: str = "created_time",
desc_order: bool = True
) -> Tuple[List[SurFaceFeature], int]:
"""
查询特征记录(带分页)
Args:
query: 查询条件
page: 页码从1开始
page_size: 每页数量
order_by: 排序字段
desc_order: 是否降序
Returns:
(特征记录列表, 总记录数)
"""
try:
# 构建查询条件
conditions = []
query_dict = query.model_dump(exclude_unset=True, exclude_none=True)
# 处理查询条件
if "person_id" in query_dict:
conditions.append(SurFaceFeature.person_id == query_dict["person_id"])
if "feature_type" in query_dict:
conditions.append(SurFaceFeature.feature_type == query_dict["feature_type"])
if "status" in query_dict:
conditions.append(SurFaceFeature.status == query_dict["status"])
if "start_date" in query_dict:
conditions.append(SurFaceFeature.created_time >= query_dict["start_date"])
if "end_date" in query_dict:
conditions.append(SurFaceFeature.created_time <= query_dict["end_date"])
if "has_feature_data" in query_dict:
if query_dict["has_feature_data"]:
conditions.append(SurFaceFeature.feature_data.isnot(None))
else:
conditions.append(SurFaceFeature.feature_data.is_(None))
# 基础查询
stmt = select(SurFaceFeature)
if conditions:
stmt = stmt.where(and_(*conditions))
# 获取总数
count_stmt = select(func.count()).select_from(stmt.subquery())
total_result = self.session.execute(count_stmt)
total = total_result.scalar_one()
# 排序
order_column = getattr(SurFaceFeature, order_by, SurFaceFeature.created_time)
if desc_order:
stmt = stmt.order_by(desc(order_column))
else:
stmt = stmt.order_by(asc(order_column))
# 分页
offset = (page - 1) * page_size
stmt = stmt.offset(offset).limit(page_size)
# 执行查询
result = self.session.execute(stmt)
features = list(result.scalars().all())
logger.debug(f"Query returned {len(features)} features (total: {total})")
return features, total
except SQLAlchemyError as e:
logger.error(f"Database error querying face features: {e}")
raise
# ===== 更新操作 =====
def update(self, feature_id: int, update_data: FaceFeatureUpdate) -> Optional[SurFaceFeature]:
"""
更新特征记录
Args:
feature_id: 特征记录ID
update_data: 更新数据
Returns:
更新后的SurFaceFeature对象或None如果不存在
"""
try:
# 先检查是否存在
feature = self.get_by_id(feature_id)
if not feature:
logger.warning(f"Cannot update non-existent face feature: id={feature_id}")
return None
# 转换为字典
update_dict = update_data.model_dump(exclude_unset=True, exclude_none=True)
# 更新字段
for key, value in update_dict.items():
setattr(feature, key, value)
# 刷新到数据库
self.session.flush()
logger.info(f"Updated face feature: id={feature_id}")
return feature
except IntegrityError as e:
logger.error(f"Integrity error updating face feature: {e}")
self.session.rollback()
raise ValueError("Update would create duplicate record")
except SQLAlchemyError as e:
logger.error(f"Database error updating face feature: {e}")
self.session.rollback()
raise
def update_feature_data(self, feature_id: int, feature_data: bytes) -> bool:
"""
更新特征数据
Args:
feature_id: 特征记录ID
feature_data: 特征数据(二进制)
Returns:
是否成功更新
"""
try:
stmt = (
update(SurFaceFeature)
.where(SurFaceFeature.id == feature_id)
.values(feature_data=feature_data)
.returning(SurFaceFeature.id)
)
result = self.session.execute(stmt)
updated_id = result.scalar_one_or_none()
if updated_id:
logger.info(f"Updated feature data for face feature: id={feature_id}")
return True
else:
logger.warning(f"Cannot update feature data for non-existent face feature: id={feature_id}")
return False
except SQLAlchemyError as e:
logger.error(f"Database error updating feature data: {e}")
self.session.rollback()
raise
def update_status(
self,
feature_id: int,
status: FeatureStatus,
start_time: Optional[datetime] = None,
finish_time: Optional[datetime] = None
) -> bool:
"""
更新计算状态
Args:
feature_id: 特征记录ID
status: 新状态
start_time: 开始时间(可选)
finish_time: 结束时间(可选)
Returns:
是否成功更新
"""
try:
update_values = {"status": status.value if isinstance(status, FeatureStatus) else status}
if start_time:
update_values["start_time"] = start_time
if finish_time:
update_values["finish_time"] = finish_time
stmt = (
update(SurFaceFeature)
.where(SurFaceFeature.id == feature_id)
.values(**update_values)
.returning(SurFaceFeature.id)
)
result = self.session.execute(stmt)
updated_id = result.scalar_one_or_none()
if updated_id:
logger.info(f"Updated status to {status} for face feature: id={feature_id}")
return True
else:
logger.warning(f"Cannot update status for non-existent face feature: id={feature_id}")
return False
except SQLAlchemyError as e:
logger.error(f"Database error updating status: {e}")
self.session.rollback()
raise
# ===== 删除操作 =====
def delete(self, feature_id: int) -> bool:
"""
删除特征记录
Args:
feature_id: 特征记录ID
Returns:
是否成功删除
"""
try:
stmt = delete(SurFaceFeature).where(SurFaceFeature.id == feature_id)
result = self.session.execute(stmt)
deleted_count = result.rowcount
if deleted_count > 0:
logger.info(f"Deleted face feature: id={feature_id}")
return True
else:
logger.warning(f"Cannot delete non-existent face feature: id={feature_id}")
return False
except SQLAlchemyError as e:
logger.error(f"Database error deleting face feature: {e}")
self.session.rollback()
raise
def delete_by_person(self, person_id: int) -> int:
"""
删除指定人员的所有特征记录
Args:
person_id: 人员ID
Returns:
删除的记录数
"""
try:
stmt = delete(SurFaceFeature).where(SurFaceFeature.person_id == person_id)
result = self.session.execute(stmt)
deleted_count = result.rowcount
logger.info(f"Deleted {deleted_count} face features for person_id={person_id}")
return deleted_count
except SQLAlchemyError as e:
logger.error(f"Database error deleting face features by person: {e}")
self.session.rollback()
raise
# ===== 统计操作 =====
def get_stats(self) -> Dict[str, Any]:
"""
获取特征记录统计信息
Returns:
统计信息字典
"""
try:
# 总记录数
total_stmt = select(func.count()).select_from(SurFaceFeature)
total_result = self.session.execute(total_stmt)
total_count = total_result.scalar_one()
# 按状态统计
status_stmt = (
select(SurFaceFeature.status, func.count())
.group_by(SurFaceFeature.status)
)
status_result = self.session.execute(status_stmt)
status_stats = {str(status): count for status, count in status_result}
# 按特征类型统计
type_stmt = (
select(SurFaceFeature.feature_type, func.count())
.where(SurFaceFeature.feature_type.isnot(None))
.group_by(SurFaceFeature.feature_type)
)
type_result = self.session.execute(type_stmt)
type_stats = {str(feature_type): count for feature_type, count in type_result}
# 平均处理时间(仅计算成功和失败的)
time_stmt = (
select(
func.avg(
func.extract('epoch', SurFaceFeature.finish_time - SurFaceFeature.start_time)
)
)
.where(
and_(
SurFaceFeature.start_time.isnot(None),
SurFaceFeature.finish_time.isnot(None),
SurFaceFeature.status.in_([FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED])
)
)
)
time_result = self.session.execute(time_stmt)
avg_time = time_result.scalar_one()
stats = {
"total_count": total_count,
"by_status": status_stats,
"by_feature_type": type_stats,
"avg_processing_time": float(avg_time) if avg_time else None
}
logger.debug(f"Generated face feature statistics")
return stats
except SQLAlchemyError as e:
logger.error(f"Database error getting statistics: {e}")
raise
# ===== 批量操作 =====
def mark_for_processing(self, limit: int = 100) -> List[SurFaceFeature]:
"""
标记待处理的特征记录为计算中
Args:
limit: 最大处理数量
Returns:
标记为处理中的特征记录列表
"""
try:
# 查找待处理的记录
pending_stmt = (
select(SurFaceFeature)
.where(SurFaceFeature.status == FeatureStatusEnum.NOT_STARTED)
.order_by(SurFaceFeature.created_time)
.limit(limit)
.with_for_update(skip_locked=True) # 跳过被锁定的行
)
result = self.session.execute(pending_stmt)
pending_features = list(result.scalars().all())
# 更新状态
feature_ids = [f.id for f in pending_features]
if feature_ids:
update_stmt = (
update(SurFaceFeature)
.where(SurFaceFeature.id.in_(feature_ids))
.values(
status=FeatureStatusEnum.PROCESSING,
start_time=datetime.now()
)
)
self.session.execute(update_stmt)
logger.info(f"Marked {len(pending_features)} face features for processing")
return pending_features
except SQLAlchemyError as e:
logger.error(f"Database error marking features for processing: {e}")
self.session.rollback()
raise
def cleanup_old_features(self, days: int = 30) -> int:
"""
清理旧的特征记录
Args:
days: 保留天数
Returns:
删除的记录数
"""
try:
cutoff_date = datetime.now() - timedelta(days=days)
# 只删除已完成(成功或失败)的旧记录
stmt = delete(SurFaceFeature).where(
and_(
SurFaceFeature.created_time < cutoff_date,
SurFaceFeature.status.in_([FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED])
)
)
result = self.session.execute(stmt)
deleted_count = result.rowcount
logger.info(f"Cleaned up {deleted_count} old face features (older than {days} days)")
return deleted_count
except SQLAlchemyError as e:
logger.error(f"Database error cleaning up old features: {e}")
self.session.rollback()
raise
# ===== 新增方法:支持算法路由 =====
def get_features_by_type_and_status(
self,
feature_type: int,
status: Optional[int] = None
) -> List[SurFaceFeature]:
"""
根据特征类型和状态获取特征记录
Args:
feature_type: 特征类型
status: 状态(可选)
Returns:
特征记录列表
"""
try:
conditions = [SurFaceFeature.feature_type == feature_type]
if status is not None:
conditions.append(SurFaceFeature.status == status)
stmt = (
select(SurFaceFeature)
.where(and_(*conditions))
.order_by(asc(SurFaceFeature.created_time))
)
result = self.session.execute(stmt)
features = list(result.scalars().all())
logger.debug(f"Retrieved {len(features)} features by type={feature_type}, status={status}")
return features
except SQLAlchemyError as e:
logger.error(f"Database error getting features by type and status: {e}")
raise
def count_by_type_and_status(
self,
feature_type: int,
status: Optional[int] = None
) -> int:
"""
统计指定特征类型和状态的记录数量
Args:
feature_type: 特征类型
status: 状态(可选)
Returns:
记录数量
"""
try:
conditions = [SurFaceFeature.feature_type == feature_type]
if status is not None:
conditions.append(SurFaceFeature.status == status)
stmt = select(func.count()).select_from(SurFaceFeature).where(and_(*conditions))
result = self.session.execute(stmt)
count = result.scalar_one()
return count
except SQLAlchemyError as e:
logger.error(f"Database error counting features by type and status: {e}")
raise
def get_statistics(self) -> Dict[str, Any]:
"""
获取统计信息(兼容性方法)
Returns:
统计信息字典
"""
return self.get_stats()

View File

@@ -0,0 +1,74 @@
"""
配置相关数据访问层
"""
from typing import List, Dict, Optional
from sqlalchemy.orm import Session
from models.sur_config import SurConfig, SurConfigBase
from config import settings
class SurConfigRepository:
"""配置数据访问类"""
def __init__(self, db: Session):
self.db = db
def get_face_config_params(self) -> Dict[str, str]:
"""
获取人脸识别配置参数
返回:
配置参数字典 {config_key: config_value}
"""
try:
# 查询全局人脸配置
config_records = self.db.query(SurConfig).filter(
SurConfig.scope == settings.SUR_CONFIG_SCOPE_GLOBAL,
SurConfig.config_type == settings.SUR_CONFIG_TYPE_FACE
).all()
# 查询配置组对应的基础配置
config_group_ids = [record.config_group_id for record in config_records if record.config_group_id]
if config_group_ids:
base_configs = self.db.query(SurConfigBase).filter(
SurConfigBase.group_id.in_(config_group_ids)
).all()
# 合并配置
config_dict = {}
for record in config_records:
if record.config_key and record.config_value:
config_dict[record.config_key] = record.config_value
return config_dict
return {}
except Exception as e:
print(f"获取人脸配置参数失败: {e}")
return {}
def get_face_config_value(self, config_key: str) -> Optional[str]:
"""
获取指定配置键的值
参数:
config_key: 配置键
返回:
配置值如果不存在返回None
"""
try:
config_record = self.db.query(SurConfig).filter(
SurConfig.scope == settings.SUR_CONFIG_SCOPE_GLOBAL,
SurConfig.config_type == settings.SUR_CONFIG_TYPE_FACE,
SurConfig.config_key == config_key
).first()
return config_record.config_value if config_record else None
except Exception as e:
print(f"获取配置值失败 {config_key}: {e}")
return None

View File

@@ -0,0 +1,89 @@
"""
人员相关数据访问层
"""
from typing import List, Dict, Optional
from sqlalchemy.orm import Session
from models.sur_person import SurPersonBlacklist, SurFaceFeature
from config import settings
class SurPersonRepository:
"""人员数据访问类"""
def __init__(self, db: Session):
self.db = db
def get_blacklist_face_features(self) -> Dict[int, str]:
"""
获取黑名单人员的人脸特征
返回:
字典 {person_id: feature_data}
"""
try:
# 查询启用的黑名单人员
blacklist_persons = self.db.query(SurPersonBlacklist).filter(
SurPersonBlacklist.status == 1
).all()
if not blacklist_persons:
return {}
person_ids = [person.person_id for person in blacklist_persons]
# 查询对应的人脸特征
face_features = self.db.query(SurFaceFeature).filter(
SurFaceFeature.person_id.in_(person_ids),
SurFaceFeature.feature_type == settings.FACE_MODEL_VERSION,
SurFaceFeature.status == 2 # 计算成功的特征
).all()
# 构建特征字典
feature_dict = {}
for feature in face_features:
if feature.feature_data:
feature_dict[feature.person_id] = feature.feature_data
return feature_dict
except Exception as e:
print(f"获取黑名单人脸特征失败: {e}")
return {}
def get_blacklist_person_count(self) -> int:
"""获取黑名单人员数量"""
try:
count = self.db.query(SurPersonBlacklist).filter(
SurPersonBlacklist.status == 1
).count()
return count
except Exception as e:
print(f"获取黑名单人员数量失败: {e}")
return 0
def get_blacklist_face_feature_count(self) -> int:
"""获取有特征数据的黑名单人员数量"""
try:
# 查询启用的黑名单人员
blacklist_persons = self.db.query(SurPersonBlacklist).filter(
SurPersonBlacklist.status == 1
).all()
if not blacklist_persons:
return 0
person_ids = [person.person_id for person in blacklist_persons]
# 查询对应的人脸特征数量
count = self.db.query(SurFaceFeature).filter(
SurFaceFeature.person_id.in_(person_ids),
SurFaceFeature.feature_type == settings.FACE_MODEL_VERSION,
SurFaceFeature.status == 2 # 计算成功的特征
).count()
return count
except Exception as e:
print(f"获取黑名单人脸特征数量失败: {e}")
return 0

View File

@@ -0,0 +1,105 @@
"""
视频检查任务数据访问层
"""
from typing import List, Optional, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import and_
from models.video_check_task import SurVideoCheckTask, SurVideo, SurConfigBase
from config import settings
class VideoCheckTaskRepository:
"""视频检查任务数据访问类"""
def __init__(self, session: Session):
self.session = session
def get_pending_tasks(self) -> List[SurVideoCheckTask]:
"""获取待处理的任务status=0"""
return self.session.query(SurVideoCheckTask).filter(
SurVideoCheckTask.status == 0
).all()
def get_by_id(self, task_id: int) -> Optional[SurVideoCheckTask]:
"""根据ID获取任务"""
return self.session.query(SurVideoCheckTask).filter(
SurVideoCheckTask.id == task_id
).first()
def update_task_status(self, task_id: int, status: int,
start_time=None, finish_time=None,
progress=None, result=None, result_data=None,
feature_data=None) -> bool:
"""更新任务状态"""
try:
task = self.get_by_id(task_id)
if not task:
return False
task.status = status
if start_time:
task.start_time = start_time
if finish_time:
task.finish_time = finish_time
if progress is not None:
task.progress = progress
if result is not None:
task.result = result
if result_data is not None:
task.result_data = result_data
if feature_data is not None:
task.feature_data = feature_data
self.session.commit()
return True
except Exception:
self.session.rollback()
return False
def get_config_by_group_id(self, group_id: int) -> Dict[str, str]:
"""根据组ID获取配置"""
configs = self.session.query(SurConfigBase).filter(
and_(
SurConfigBase.config_type == settings.SUR_CONFIG_TYPE_FACE,
SurConfigBase.group_id == group_id
)
).all()
config_dict = {}
for config in configs:
config_dict[config.config_key] = config.config_value
return config_dict
def get_video_by_id(self, video_id: int) -> Optional[SurVideo]:
"""根据ID获取视频信息"""
return self.session.query(SurVideo).filter(
SurVideo.id == video_id
).first()
def get_videos_by_ids(self, video_ids: List[int]) -> List[SurVideo]:
"""根据ID列表获取视频信息"""
return self.session.query(SurVideo).filter(
SurVideo.id.in_(video_ids)
).all()
def get_processing_tasks(self) -> List[SurVideoCheckTask]:
"""获取正在处理的任务status=1"""
return self.session.query(SurVideoCheckTask).filter(
SurVideoCheckTask.status == 1
).all()
def get_timeout_tasks(self, timeout_hours: int) -> List[SurVideoCheckTask]:
"""获取超时的任务"""
from datetime import datetime, timedelta
timeout_time = datetime.now() - timedelta(hours=timeout_hours)
return self.session.query(SurVideoCheckTask).filter(
and_(
SurVideoCheckTask.status == 1,
SurVideoCheckTask.start_time < timeout_time
)
).all()