修改路径,从src放到根目录
This commit is contained in:
676
repositories/face_feature_repository.py
Normal file
676
repositories/face_feature_repository.py
Normal file
@@ -0,0 +1,676 @@
|
||||
"""
|
||||
人脸特征数据仓库
|
||||
数据访问层,处理所有数据库操作
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from contextlib import contextmanager
|
||||
|
||||
from sqlalchemy import select, update, delete, func, and_, or_, desc, asc
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy.exc import SQLAlchemyError, IntegrityError
|
||||
|
||||
from models.face_feature import SurFaceFeature
|
||||
from schemas.face_feature import (
|
||||
FaceFeatureCreate,
|
||||
FaceFeatureUpdate,
|
||||
FaceFeatureQuery,
|
||||
FeatureStatus
|
||||
)
|
||||
from models.face_feature import FeatureStatusEnum
|
||||
from utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
class FaceFeatureRepository:
|
||||
"""人脸特征数据仓库"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
"""
|
||||
初始化仓库
|
||||
|
||||
Args:
|
||||
session: SQLAlchemy会话对象
|
||||
"""
|
||||
self.session = session
|
||||
|
||||
# ===== 创建操作 =====
|
||||
|
||||
def create(self, feature_data: FaceFeatureCreate) -> SurFaceFeature:
|
||||
"""
|
||||
创建特征记录
|
||||
|
||||
Args:
|
||||
feature_data: 特征数据
|
||||
|
||||
Returns:
|
||||
创建的SurFaceFeature对象
|
||||
|
||||
Raises:
|
||||
IntegrityError: 违反唯一约束时抛出
|
||||
"""
|
||||
try:
|
||||
# 转换为模型字典
|
||||
feature_dict = feature_data.model_dump(exclude_unset=True)
|
||||
|
||||
# 创建模型实例
|
||||
feature = SurFaceFeature(**feature_dict)
|
||||
|
||||
# 添加到会话
|
||||
self.session.add(feature)
|
||||
self.session.flush() # 立即执行,但不提交
|
||||
|
||||
logger.info(f"Created face feature record: id={feature.id}, person_id={feature.person_id}")
|
||||
return feature
|
||||
|
||||
except IntegrityError as e:
|
||||
logger.error(f"Integrity error creating face feature: {e}")
|
||||
self.session.rollback()
|
||||
raise ValueError(f"Duplicate feature record for person_id={feature_data.person_id}, "
|
||||
f"feature_type={feature_data.feature_type}")
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error creating face feature: {e}")
|
||||
self.session.rollback()
|
||||
raise
|
||||
|
||||
def create_batch(self, features_data: List[FaceFeatureCreate]) -> List[SurFaceFeature]:
|
||||
"""
|
||||
批量创建特征记录
|
||||
|
||||
Args:
|
||||
features_data: 特征数据列表
|
||||
|
||||
Returns:
|
||||
创建的SurFaceFeature对象列表
|
||||
"""
|
||||
try:
|
||||
features = []
|
||||
for feature_data in features_data:
|
||||
feature_dict = feature_data.model_dump(exclude_unset=True)
|
||||
feature = SurFaceFeature(**feature_dict)
|
||||
features.append(feature)
|
||||
|
||||
# 批量添加
|
||||
self.session.add_all(features)
|
||||
self.session.flush()
|
||||
|
||||
logger.info(f"Created {len(features)} face feature records in batch")
|
||||
return features
|
||||
|
||||
except IntegrityError as e:
|
||||
logger.error(f"Integrity error creating batch face features: {e}")
|
||||
self.session.rollback()
|
||||
raise ValueError("Duplicate feature record in batch")
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error creating batch face features: {e}")
|
||||
self.session.rollback()
|
||||
raise
|
||||
|
||||
# ===== 查询操作 =====
|
||||
|
||||
def get_by_id(self, feature_id: int) -> Optional[SurFaceFeature]:
|
||||
"""
|
||||
根据ID获取特征记录
|
||||
|
||||
Args:
|
||||
feature_id: 特征记录ID
|
||||
|
||||
Returns:
|
||||
SurFaceFeature对象或None
|
||||
"""
|
||||
try:
|
||||
stmt = select(SurFaceFeature).where(SurFaceFeature.id == feature_id)
|
||||
result = self.session.execute(stmt)
|
||||
feature = result.scalar_one_or_none()
|
||||
|
||||
if feature:
|
||||
logger.debug(f"Retrieved face feature by id: {feature_id}")
|
||||
else:
|
||||
logger.debug(f"Face feature not found by id: {feature_id}")
|
||||
|
||||
return feature
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error getting face feature by id: {e}")
|
||||
raise
|
||||
|
||||
def get_by_person_and_type(self, person_id: int, feature_type: int) -> Optional[SurFaceFeature]:
|
||||
"""
|
||||
根据人员ID和特征类型获取特征记录
|
||||
|
||||
Args:
|
||||
person_id: 人员ID
|
||||
feature_type: 特征类型
|
||||
|
||||
Returns:
|
||||
SurFaceFeature对象或None
|
||||
"""
|
||||
try:
|
||||
stmt = select(SurFaceFeature).where(
|
||||
and_(
|
||||
SurFaceFeature.person_id == person_id,
|
||||
SurFaceFeature.feature_type == feature_type
|
||||
)
|
||||
)
|
||||
result = self.session.execute(stmt)
|
||||
feature = result.scalar_one_or_none()
|
||||
|
||||
if feature:
|
||||
logger.debug(f"Retrieved face feature: person_id={person_id}, feature_type={feature_type}")
|
||||
else:
|
||||
logger.debug(f"Face feature not found: person_id={person_id}, feature_type={feature_type}")
|
||||
|
||||
return feature
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error getting face feature by person and type: {e}")
|
||||
raise
|
||||
|
||||
def get_by_person(self, person_id: int, limit: int = 100) -> List[SurFaceFeature]:
|
||||
"""
|
||||
根据人员ID获取特征记录列表
|
||||
|
||||
Args:
|
||||
person_id: 人员ID
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
SurFaceFeature对象列表
|
||||
"""
|
||||
try:
|
||||
stmt = (
|
||||
select(SurFaceFeature)
|
||||
.where(SurFaceFeature.person_id == person_id)
|
||||
.order_by(desc(SurFaceFeature.created_time))
|
||||
.limit(limit)
|
||||
)
|
||||
result = self.session.execute(stmt)
|
||||
features = list(result.scalars().all())
|
||||
|
||||
logger.debug(f"Retrieved {len(features)} face features for person_id={person_id}")
|
||||
return features
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error getting face features by person: {e}")
|
||||
raise
|
||||
|
||||
def query_features(
|
||||
self,
|
||||
query: FaceFeatureQuery,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
order_by: str = "created_time",
|
||||
desc_order: bool = True
|
||||
) -> Tuple[List[SurFaceFeature], int]:
|
||||
"""
|
||||
查询特征记录(带分页)
|
||||
|
||||
Args:
|
||||
query: 查询条件
|
||||
page: 页码(从1开始)
|
||||
page_size: 每页数量
|
||||
order_by: 排序字段
|
||||
desc_order: 是否降序
|
||||
|
||||
Returns:
|
||||
(特征记录列表, 总记录数)
|
||||
"""
|
||||
try:
|
||||
# 构建查询条件
|
||||
conditions = []
|
||||
query_dict = query.model_dump(exclude_unset=True, exclude_none=True)
|
||||
|
||||
# 处理查询条件
|
||||
if "person_id" in query_dict:
|
||||
conditions.append(SurFaceFeature.person_id == query_dict["person_id"])
|
||||
|
||||
if "feature_type" in query_dict:
|
||||
conditions.append(SurFaceFeature.feature_type == query_dict["feature_type"])
|
||||
|
||||
if "status" in query_dict:
|
||||
conditions.append(SurFaceFeature.status == query_dict["status"])
|
||||
|
||||
if "start_date" in query_dict:
|
||||
conditions.append(SurFaceFeature.created_time >= query_dict["start_date"])
|
||||
|
||||
if "end_date" in query_dict:
|
||||
conditions.append(SurFaceFeature.created_time <= query_dict["end_date"])
|
||||
|
||||
if "has_feature_data" in query_dict:
|
||||
if query_dict["has_feature_data"]:
|
||||
conditions.append(SurFaceFeature.feature_data.isnot(None))
|
||||
else:
|
||||
conditions.append(SurFaceFeature.feature_data.is_(None))
|
||||
|
||||
# 基础查询
|
||||
stmt = select(SurFaceFeature)
|
||||
if conditions:
|
||||
stmt = stmt.where(and_(*conditions))
|
||||
|
||||
# 获取总数
|
||||
count_stmt = select(func.count()).select_from(stmt.subquery())
|
||||
total_result = self.session.execute(count_stmt)
|
||||
total = total_result.scalar_one()
|
||||
|
||||
# 排序
|
||||
order_column = getattr(SurFaceFeature, order_by, SurFaceFeature.created_time)
|
||||
if desc_order:
|
||||
stmt = stmt.order_by(desc(order_column))
|
||||
else:
|
||||
stmt = stmt.order_by(asc(order_column))
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
stmt = stmt.offset(offset).limit(page_size)
|
||||
|
||||
# 执行查询
|
||||
result = self.session.execute(stmt)
|
||||
features = list(result.scalars().all())
|
||||
|
||||
logger.debug(f"Query returned {len(features)} features (total: {total})")
|
||||
return features, total
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error querying face features: {e}")
|
||||
raise
|
||||
|
||||
# ===== 更新操作 =====
|
||||
|
||||
def update(self, feature_id: int, update_data: FaceFeatureUpdate) -> Optional[SurFaceFeature]:
|
||||
"""
|
||||
更新特征记录
|
||||
|
||||
Args:
|
||||
feature_id: 特征记录ID
|
||||
update_data: 更新数据
|
||||
|
||||
Returns:
|
||||
更新后的SurFaceFeature对象或None(如果不存在)
|
||||
"""
|
||||
try:
|
||||
# 先检查是否存在
|
||||
feature = self.get_by_id(feature_id)
|
||||
if not feature:
|
||||
logger.warning(f"Cannot update non-existent face feature: id={feature_id}")
|
||||
return None
|
||||
|
||||
# 转换为字典
|
||||
update_dict = update_data.model_dump(exclude_unset=True, exclude_none=True)
|
||||
|
||||
# 更新字段
|
||||
for key, value in update_dict.items():
|
||||
setattr(feature, key, value)
|
||||
|
||||
# 刷新到数据库
|
||||
self.session.flush()
|
||||
|
||||
logger.info(f"Updated face feature: id={feature_id}")
|
||||
return feature
|
||||
|
||||
except IntegrityError as e:
|
||||
logger.error(f"Integrity error updating face feature: {e}")
|
||||
self.session.rollback()
|
||||
raise ValueError("Update would create duplicate record")
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error updating face feature: {e}")
|
||||
self.session.rollback()
|
||||
raise
|
||||
|
||||
def update_feature_data(self, feature_id: int, feature_data: bytes) -> bool:
|
||||
"""
|
||||
更新特征数据
|
||||
|
||||
Args:
|
||||
feature_id: 特征记录ID
|
||||
feature_data: 特征数据(二进制)
|
||||
|
||||
Returns:
|
||||
是否成功更新
|
||||
"""
|
||||
try:
|
||||
stmt = (
|
||||
update(SurFaceFeature)
|
||||
.where(SurFaceFeature.id == feature_id)
|
||||
.values(feature_data=feature_data)
|
||||
.returning(SurFaceFeature.id)
|
||||
)
|
||||
|
||||
result = self.session.execute(stmt)
|
||||
updated_id = result.scalar_one_or_none()
|
||||
|
||||
if updated_id:
|
||||
logger.info(f"Updated feature data for face feature: id={feature_id}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Cannot update feature data for non-existent face feature: id={feature_id}")
|
||||
return False
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error updating feature data: {e}")
|
||||
self.session.rollback()
|
||||
raise
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
feature_id: int,
|
||||
status: FeatureStatus,
|
||||
start_time: Optional[datetime] = None,
|
||||
finish_time: Optional[datetime] = None
|
||||
) -> bool:
|
||||
"""
|
||||
更新计算状态
|
||||
|
||||
Args:
|
||||
feature_id: 特征记录ID
|
||||
status: 新状态
|
||||
start_time: 开始时间(可选)
|
||||
finish_time: 结束时间(可选)
|
||||
|
||||
Returns:
|
||||
是否成功更新
|
||||
"""
|
||||
try:
|
||||
update_values = {"status": status.value if isinstance(status, FeatureStatus) else status}
|
||||
|
||||
if start_time:
|
||||
update_values["start_time"] = start_time
|
||||
if finish_time:
|
||||
update_values["finish_time"] = finish_time
|
||||
|
||||
stmt = (
|
||||
update(SurFaceFeature)
|
||||
.where(SurFaceFeature.id == feature_id)
|
||||
.values(**update_values)
|
||||
.returning(SurFaceFeature.id)
|
||||
)
|
||||
|
||||
result = self.session.execute(stmt)
|
||||
updated_id = result.scalar_one_or_none()
|
||||
|
||||
if updated_id:
|
||||
logger.info(f"Updated status to {status} for face feature: id={feature_id}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Cannot update status for non-existent face feature: id={feature_id}")
|
||||
return False
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error updating status: {e}")
|
||||
self.session.rollback()
|
||||
raise
|
||||
|
||||
# ===== 删除操作 =====
|
||||
|
||||
def delete(self, feature_id: int) -> bool:
|
||||
"""
|
||||
删除特征记录
|
||||
|
||||
Args:
|
||||
feature_id: 特征记录ID
|
||||
|
||||
Returns:
|
||||
是否成功删除
|
||||
"""
|
||||
try:
|
||||
stmt = delete(SurFaceFeature).where(SurFaceFeature.id == feature_id)
|
||||
result = self.session.execute(stmt)
|
||||
|
||||
deleted_count = result.rowcount
|
||||
if deleted_count > 0:
|
||||
logger.info(f"Deleted face feature: id={feature_id}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Cannot delete non-existent face feature: id={feature_id}")
|
||||
return False
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error deleting face feature: {e}")
|
||||
self.session.rollback()
|
||||
raise
|
||||
|
||||
def delete_by_person(self, person_id: int) -> int:
|
||||
"""
|
||||
删除指定人员的所有特征记录
|
||||
|
||||
Args:
|
||||
person_id: 人员ID
|
||||
|
||||
Returns:
|
||||
删除的记录数
|
||||
"""
|
||||
try:
|
||||
stmt = delete(SurFaceFeature).where(SurFaceFeature.person_id == person_id)
|
||||
result = self.session.execute(stmt)
|
||||
|
||||
deleted_count = result.rowcount
|
||||
logger.info(f"Deleted {deleted_count} face features for person_id={person_id}")
|
||||
return deleted_count
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error deleting face features by person: {e}")
|
||||
self.session.rollback()
|
||||
raise
|
||||
|
||||
# ===== 统计操作 =====
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取特征记录统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
try:
|
||||
# 总记录数
|
||||
total_stmt = select(func.count()).select_from(SurFaceFeature)
|
||||
total_result = self.session.execute(total_stmt)
|
||||
total_count = total_result.scalar_one()
|
||||
|
||||
# 按状态统计
|
||||
status_stmt = (
|
||||
select(SurFaceFeature.status, func.count())
|
||||
.group_by(SurFaceFeature.status)
|
||||
)
|
||||
status_result = self.session.execute(status_stmt)
|
||||
status_stats = {str(status): count for status, count in status_result}
|
||||
|
||||
# 按特征类型统计
|
||||
type_stmt = (
|
||||
select(SurFaceFeature.feature_type, func.count())
|
||||
.where(SurFaceFeature.feature_type.isnot(None))
|
||||
.group_by(SurFaceFeature.feature_type)
|
||||
)
|
||||
type_result = self.session.execute(type_stmt)
|
||||
type_stats = {str(feature_type): count for feature_type, count in type_result}
|
||||
|
||||
# 平均处理时间(仅计算成功和失败的)
|
||||
time_stmt = (
|
||||
select(
|
||||
func.avg(
|
||||
func.extract('epoch', SurFaceFeature.finish_time - SurFaceFeature.start_time)
|
||||
)
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
SurFaceFeature.start_time.isnot(None),
|
||||
SurFaceFeature.finish_time.isnot(None),
|
||||
SurFaceFeature.status.in_([FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED])
|
||||
)
|
||||
)
|
||||
)
|
||||
time_result = self.session.execute(time_stmt)
|
||||
avg_time = time_result.scalar_one()
|
||||
|
||||
stats = {
|
||||
"total_count": total_count,
|
||||
"by_status": status_stats,
|
||||
"by_feature_type": type_stats,
|
||||
"avg_processing_time": float(avg_time) if avg_time else None
|
||||
}
|
||||
|
||||
logger.debug(f"Generated face feature statistics")
|
||||
return stats
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error getting statistics: {e}")
|
||||
raise
|
||||
|
||||
# ===== 批量操作 =====
|
||||
|
||||
def mark_for_processing(self, limit: int = 100) -> List[SurFaceFeature]:
|
||||
"""
|
||||
标记待处理的特征记录为计算中
|
||||
|
||||
Args:
|
||||
limit: 最大处理数量
|
||||
|
||||
Returns:
|
||||
标记为处理中的特征记录列表
|
||||
"""
|
||||
try:
|
||||
# 查找待处理的记录
|
||||
pending_stmt = (
|
||||
select(SurFaceFeature)
|
||||
.where(SurFaceFeature.status == FeatureStatusEnum.NOT_STARTED)
|
||||
.order_by(SurFaceFeature.created_time)
|
||||
.limit(limit)
|
||||
.with_for_update(skip_locked=True) # 跳过被锁定的行
|
||||
)
|
||||
|
||||
result = self.session.execute(pending_stmt)
|
||||
pending_features = list(result.scalars().all())
|
||||
|
||||
# 更新状态
|
||||
feature_ids = [f.id for f in pending_features]
|
||||
if feature_ids:
|
||||
update_stmt = (
|
||||
update(SurFaceFeature)
|
||||
.where(SurFaceFeature.id.in_(feature_ids))
|
||||
.values(
|
||||
status=FeatureStatusEnum.PROCESSING,
|
||||
start_time=datetime.now()
|
||||
)
|
||||
)
|
||||
self.session.execute(update_stmt)
|
||||
|
||||
logger.info(f"Marked {len(pending_features)} face features for processing")
|
||||
return pending_features
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error marking features for processing: {e}")
|
||||
self.session.rollback()
|
||||
raise
|
||||
|
||||
def cleanup_old_features(self, days: int = 30) -> int:
|
||||
"""
|
||||
清理旧的特征记录
|
||||
|
||||
Args:
|
||||
days: 保留天数
|
||||
|
||||
Returns:
|
||||
删除的记录数
|
||||
"""
|
||||
try:
|
||||
cutoff_date = datetime.now() - timedelta(days=days)
|
||||
|
||||
# 只删除已完成(成功或失败)的旧记录
|
||||
stmt = delete(SurFaceFeature).where(
|
||||
and_(
|
||||
SurFaceFeature.created_time < cutoff_date,
|
||||
SurFaceFeature.status.in_([FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED])
|
||||
)
|
||||
)
|
||||
|
||||
result = self.session.execute(stmt)
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info(f"Cleaned up {deleted_count} old face features (older than {days} days)")
|
||||
return deleted_count
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error cleaning up old features: {e}")
|
||||
self.session.rollback()
|
||||
raise
|
||||
|
||||
# ===== 新增方法:支持算法路由 =====
|
||||
|
||||
def get_features_by_type_and_status(
|
||||
self,
|
||||
feature_type: int,
|
||||
status: Optional[int] = None
|
||||
) -> List[SurFaceFeature]:
|
||||
"""
|
||||
根据特征类型和状态获取特征记录
|
||||
|
||||
Args:
|
||||
feature_type: 特征类型
|
||||
status: 状态(可选)
|
||||
|
||||
Returns:
|
||||
特征记录列表
|
||||
"""
|
||||
try:
|
||||
conditions = [SurFaceFeature.feature_type == feature_type]
|
||||
|
||||
if status is not None:
|
||||
conditions.append(SurFaceFeature.status == status)
|
||||
|
||||
stmt = (
|
||||
select(SurFaceFeature)
|
||||
.where(and_(*conditions))
|
||||
.order_by(asc(SurFaceFeature.created_time))
|
||||
)
|
||||
|
||||
result = self.session.execute(stmt)
|
||||
features = list(result.scalars().all())
|
||||
|
||||
logger.debug(f"Retrieved {len(features)} features by type={feature_type}, status={status}")
|
||||
return features
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error getting features by type and status: {e}")
|
||||
raise
|
||||
|
||||
def count_by_type_and_status(
|
||||
self,
|
||||
feature_type: int,
|
||||
status: Optional[int] = None
|
||||
) -> int:
|
||||
"""
|
||||
统计指定特征类型和状态的记录数量
|
||||
|
||||
Args:
|
||||
feature_type: 特征类型
|
||||
status: 状态(可选)
|
||||
|
||||
Returns:
|
||||
记录数量
|
||||
"""
|
||||
try:
|
||||
conditions = [SurFaceFeature.feature_type == feature_type]
|
||||
|
||||
if status is not None:
|
||||
conditions.append(SurFaceFeature.status == status)
|
||||
|
||||
stmt = select(func.count()).select_from(SurFaceFeature).where(and_(*conditions))
|
||||
result = self.session.execute(stmt)
|
||||
count = result.scalar_one()
|
||||
|
||||
return count
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error counting features by type and status: {e}")
|
||||
raise
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取统计信息(兼容性方法)
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
return self.get_stats()
|
||||
74
repositories/sur_config_repository.py
Normal file
74
repositories/sur_config_repository.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
配置相关数据访问层
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from models.sur_config import SurConfig, SurConfigBase
|
||||
from config import settings
|
||||
|
||||
|
||||
class SurConfigRepository:
|
||||
"""配置数据访问类"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_face_config_params(self) -> Dict[str, str]:
|
||||
"""
|
||||
获取人脸识别配置参数
|
||||
|
||||
返回:
|
||||
配置参数字典 {config_key: config_value}
|
||||
"""
|
||||
try:
|
||||
# 查询全局人脸配置
|
||||
config_records = self.db.query(SurConfig).filter(
|
||||
SurConfig.scope == settings.SUR_CONFIG_SCOPE_GLOBAL,
|
||||
SurConfig.config_type == settings.SUR_CONFIG_TYPE_FACE
|
||||
).all()
|
||||
|
||||
# 查询配置组对应的基础配置
|
||||
config_group_ids = [record.config_group_id for record in config_records if record.config_group_id]
|
||||
|
||||
if config_group_ids:
|
||||
base_configs = self.db.query(SurConfigBase).filter(
|
||||
SurConfigBase.group_id.in_(config_group_ids)
|
||||
).all()
|
||||
|
||||
# 合并配置
|
||||
config_dict = {}
|
||||
for record in config_records:
|
||||
if record.config_key and record.config_value:
|
||||
config_dict[record.config_key] = record.config_value
|
||||
|
||||
return config_dict
|
||||
|
||||
return {}
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取人脸配置参数失败: {e}")
|
||||
return {}
|
||||
|
||||
def get_face_config_value(self, config_key: str) -> Optional[str]:
|
||||
"""
|
||||
获取指定配置键的值
|
||||
|
||||
参数:
|
||||
config_key: 配置键
|
||||
|
||||
返回:
|
||||
配置值,如果不存在返回None
|
||||
"""
|
||||
try:
|
||||
config_record = self.db.query(SurConfig).filter(
|
||||
SurConfig.scope == settings.SUR_CONFIG_SCOPE_GLOBAL,
|
||||
SurConfig.config_type == settings.SUR_CONFIG_TYPE_FACE,
|
||||
SurConfig.config_key == config_key
|
||||
).first()
|
||||
|
||||
return config_record.config_value if config_record else None
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取配置值失败 {config_key}: {e}")
|
||||
return None
|
||||
89
repositories/sur_person_repository.py
Normal file
89
repositories/sur_person_repository.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
人员相关数据访问层
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from models.sur_person import SurPersonBlacklist, SurFaceFeature
|
||||
from config import settings
|
||||
|
||||
|
||||
class SurPersonRepository:
|
||||
"""人员数据访问类"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_blacklist_face_features(self) -> Dict[int, str]:
|
||||
"""
|
||||
获取黑名单人员的人脸特征
|
||||
|
||||
返回:
|
||||
字典 {person_id: feature_data}
|
||||
"""
|
||||
try:
|
||||
# 查询启用的黑名单人员
|
||||
blacklist_persons = self.db.query(SurPersonBlacklist).filter(
|
||||
SurPersonBlacklist.status == 1
|
||||
).all()
|
||||
|
||||
if not blacklist_persons:
|
||||
return {}
|
||||
|
||||
person_ids = [person.person_id for person in blacklist_persons]
|
||||
|
||||
# 查询对应的人脸特征
|
||||
face_features = self.db.query(SurFaceFeature).filter(
|
||||
SurFaceFeature.person_id.in_(person_ids),
|
||||
SurFaceFeature.feature_type == settings.FACE_MODEL_VERSION,
|
||||
SurFaceFeature.status == 2 # 计算成功的特征
|
||||
).all()
|
||||
|
||||
# 构建特征字典
|
||||
feature_dict = {}
|
||||
for feature in face_features:
|
||||
if feature.feature_data:
|
||||
feature_dict[feature.person_id] = feature.feature_data
|
||||
|
||||
return feature_dict
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取黑名单人脸特征失败: {e}")
|
||||
return {}
|
||||
|
||||
def get_blacklist_person_count(self) -> int:
|
||||
"""获取黑名单人员数量"""
|
||||
try:
|
||||
count = self.db.query(SurPersonBlacklist).filter(
|
||||
SurPersonBlacklist.status == 1
|
||||
).count()
|
||||
return count
|
||||
except Exception as e:
|
||||
print(f"获取黑名单人员数量失败: {e}")
|
||||
return 0
|
||||
|
||||
def get_blacklist_face_feature_count(self) -> int:
|
||||
"""获取有特征数据的黑名单人员数量"""
|
||||
try:
|
||||
# 查询启用的黑名单人员
|
||||
blacklist_persons = self.db.query(SurPersonBlacklist).filter(
|
||||
SurPersonBlacklist.status == 1
|
||||
).all()
|
||||
|
||||
if not blacklist_persons:
|
||||
return 0
|
||||
|
||||
person_ids = [person.person_id for person in blacklist_persons]
|
||||
|
||||
# 查询对应的人脸特征数量
|
||||
count = self.db.query(SurFaceFeature).filter(
|
||||
SurFaceFeature.person_id.in_(person_ids),
|
||||
SurFaceFeature.feature_type == settings.FACE_MODEL_VERSION,
|
||||
SurFaceFeature.status == 2 # 计算成功的特征
|
||||
).count()
|
||||
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取黑名单人脸特征数量失败: {e}")
|
||||
return 0
|
||||
105
repositories/video_check_repository.py
Normal file
105
repositories/video_check_repository.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
视频检查任务数据访问层
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_
|
||||
|
||||
from models.video_check_task import SurVideoCheckTask, SurVideo, SurConfigBase
|
||||
from config import settings
|
||||
|
||||
|
||||
class VideoCheckTaskRepository:
|
||||
"""视频检查任务数据访问类"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
def get_pending_tasks(self) -> List[SurVideoCheckTask]:
|
||||
"""获取待处理的任务(status=0)"""
|
||||
return self.session.query(SurVideoCheckTask).filter(
|
||||
SurVideoCheckTask.status == 0
|
||||
).all()
|
||||
|
||||
def get_by_id(self, task_id: int) -> Optional[SurVideoCheckTask]:
|
||||
"""根据ID获取任务"""
|
||||
return self.session.query(SurVideoCheckTask).filter(
|
||||
SurVideoCheckTask.id == task_id
|
||||
).first()
|
||||
|
||||
def update_task_status(self, task_id: int, status: int,
|
||||
start_time=None, finish_time=None,
|
||||
progress=None, result=None, result_data=None,
|
||||
feature_data=None) -> bool:
|
||||
"""更新任务状态"""
|
||||
try:
|
||||
task = self.get_by_id(task_id)
|
||||
if not task:
|
||||
return False
|
||||
|
||||
task.status = status
|
||||
if start_time:
|
||||
task.start_time = start_time
|
||||
if finish_time:
|
||||
task.finish_time = finish_time
|
||||
if progress is not None:
|
||||
task.progress = progress
|
||||
if result is not None:
|
||||
task.result = result
|
||||
if result_data is not None:
|
||||
task.result_data = result_data
|
||||
if feature_data is not None:
|
||||
task.feature_data = feature_data
|
||||
|
||||
self.session.commit()
|
||||
return True
|
||||
except Exception:
|
||||
self.session.rollback()
|
||||
return False
|
||||
|
||||
def get_config_by_group_id(self, group_id: int) -> Dict[str, str]:
|
||||
"""根据组ID获取配置"""
|
||||
configs = self.session.query(SurConfigBase).filter(
|
||||
and_(
|
||||
SurConfigBase.config_type == settings.SUR_CONFIG_TYPE_FACE,
|
||||
SurConfigBase.group_id == group_id
|
||||
)
|
||||
).all()
|
||||
|
||||
config_dict = {}
|
||||
for config in configs:
|
||||
config_dict[config.config_key] = config.config_value
|
||||
|
||||
return config_dict
|
||||
|
||||
def get_video_by_id(self, video_id: int) -> Optional[SurVideo]:
|
||||
"""根据ID获取视频信息"""
|
||||
return self.session.query(SurVideo).filter(
|
||||
SurVideo.id == video_id
|
||||
).first()
|
||||
|
||||
def get_videos_by_ids(self, video_ids: List[int]) -> List[SurVideo]:
|
||||
"""根据ID列表获取视频信息"""
|
||||
return self.session.query(SurVideo).filter(
|
||||
SurVideo.id.in_(video_ids)
|
||||
).all()
|
||||
|
||||
def get_processing_tasks(self) -> List[SurVideoCheckTask]:
|
||||
"""获取正在处理的任务(status=1)"""
|
||||
return self.session.query(SurVideoCheckTask).filter(
|
||||
SurVideoCheckTask.status == 1
|
||||
).all()
|
||||
|
||||
def get_timeout_tasks(self, timeout_hours: int) -> List[SurVideoCheckTask]:
|
||||
"""获取超时的任务"""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
timeout_time = datetime.now() - timedelta(hours=timeout_hours)
|
||||
|
||||
return self.session.query(SurVideoCheckTask).filter(
|
||||
and_(
|
||||
SurVideoCheckTask.status == 1,
|
||||
SurVideoCheckTask.start_time < timeout_time
|
||||
)
|
||||
).all()
|
||||
Reference in New Issue
Block a user