676 lines
22 KiB
Python
676 lines
22 KiB
Python
"""
|
||
人脸特征数据仓库
|
||
数据访问层,处理所有数据库操作
|
||
"""
|
||
|
||
import logging
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple

from sqlalchemy import (
    and_,
    asc,
    delete,
    desc,
    func,
    inspect,
    or_,
    select,
    update,
)
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.orm import Session, joinedload

from models.face_feature import FeatureStatusEnum, SurFaceFeature
from schemas.face_feature import (
    FaceFeatureCreate,
    FaceFeatureQuery,
    FaceFeatureUpdate,
    FeatureStatus,
)
from utils.logger import setup_logger
|
||
|
||
# Module-level logger, obtained from the project's setup_logger helper and
# keyed by this module's name.
logger = setup_logger(__name__)
|
||
|
||
|
||
class FaceFeatureRepository:
    """Data-access layer for ``SurFaceFeature`` rows.

    Every method runs against the SQLAlchemy session supplied at construction
    time. Write methods flush or execute statements but never commit, so the
    caller owns the transaction boundary. When a write method hits a database
    error it rolls the session back before raising.
    """

    def __init__(self, session: Session):
        """Bind the repository to a session.

        Args:
            session: SQLAlchemy session used for every statement.
        """
        self.session = session

    # ===== Create =====

    def create(self, feature_data: FaceFeatureCreate) -> SurFaceFeature:
        """Insert a single feature record.

        The row is flushed to the database but not committed.

        Args:
            feature_data: Payload; only explicitly-set fields are passed to
                the model constructor.

        Returns:
            The newly created ``SurFaceFeature`` instance.

        Raises:
            ValueError: If the insert violates an integrity constraint. The
                underlying ``IntegrityError`` is attached as ``__cause__``.
            SQLAlchemyError: On any other database error.
        """
        try:
            feature = SurFaceFeature(**feature_data.model_dump(exclude_unset=True))
            self.session.add(feature)
            # Flush now so constraint violations surface inside this try block.
            self.session.flush()

            logger.info(f"Created face feature record: id={feature.id}, person_id={feature.person_id}")
            return feature

        except IntegrityError as e:
            logger.error(f"Integrity error creating face feature: {e}")
            self.session.rollback()
            raise ValueError(f"Duplicate feature record for person_id={feature_data.person_id}, "
                             f"feature_type={feature_data.feature_type}") from e
        except SQLAlchemyError as e:
            logger.error(f"Database error creating face feature: {e}")
            self.session.rollback()
            raise

    def create_batch(self, features_data: List[FaceFeatureCreate]) -> List[SurFaceFeature]:
        """Insert several feature records with a single flush.

        Args:
            features_data: Payloads to insert; only explicitly-set fields of
                each payload are passed to the model constructor.

        Returns:
            The created ``SurFaceFeature`` instances, in input order.

        Raises:
            ValueError: If any row violates an integrity constraint. The
                underlying ``IntegrityError`` is attached as ``__cause__``.
            SQLAlchemyError: On any other database error.
        """
        try:
            features = [
                SurFaceFeature(**item.model_dump(exclude_unset=True))
                for item in features_data
            ]
            self.session.add_all(features)
            self.session.flush()

            logger.info(f"Created {len(features)} face feature records in batch")
            return features

        except IntegrityError as e:
            logger.error(f"Integrity error creating batch face features: {e}")
            self.session.rollback()
            raise ValueError("Duplicate feature record in batch") from e
        except SQLAlchemyError as e:
            logger.error(f"Database error creating batch face features: {e}")
            self.session.rollback()
            raise

    # ===== Read =====

    def get_by_id(self, feature_id: int) -> Optional[SurFaceFeature]:
        """Fetch one record by its id.

        Args:
            feature_id: Id of the feature record.

        Returns:
            The matching ``SurFaceFeature``, or ``None`` if there is none.

        Raises:
            SQLAlchemyError: On database error (logged, then re-raised).
        """
        try:
            stmt = select(SurFaceFeature).where(SurFaceFeature.id == feature_id)
            feature = self.session.execute(stmt).scalar_one_or_none()

            if feature:
                logger.debug(f"Retrieved face feature by id: {feature_id}")
            else:
                logger.debug(f"Face feature not found by id: {feature_id}")
            return feature

        except SQLAlchemyError as e:
            logger.error(f"Database error getting face feature by id: {e}")
            raise

    def get_by_person_and_type(self, person_id: int, feature_type: int) -> Optional[SurFaceFeature]:
        """Fetch the record for a given person and feature type.

        Args:
            person_id: Person id to match.
            feature_type: Feature type to match.

        Returns:
            The matching ``SurFaceFeature``, or ``None`` if there is none.

        Raises:
            SQLAlchemyError: On database error, including the case where
                ``scalar_one_or_none`` finds more than one matching row.
        """
        try:
            stmt = select(SurFaceFeature).where(
                and_(
                    SurFaceFeature.person_id == person_id,
                    SurFaceFeature.feature_type == feature_type,
                )
            )
            feature = self.session.execute(stmt).scalar_one_or_none()

            if feature:
                logger.debug(f"Retrieved face feature: person_id={person_id}, feature_type={feature_type}")
            else:
                logger.debug(f"Face feature not found: person_id={person_id}, feature_type={feature_type}")
            return feature

        except SQLAlchemyError as e:
            logger.error(f"Database error getting face feature by person and type: {e}")
            raise

    def get_by_person(self, person_id: int, limit: int = 100) -> List[SurFaceFeature]:
        """List a person's feature records, newest first.

        Args:
            person_id: Person id to match.
            limit: Maximum number of rows to return.

        Returns:
            Matching records ordered by ``created_time`` descending.

        Raises:
            SQLAlchemyError: On database error (logged, then re-raised).
        """
        try:
            stmt = (
                select(SurFaceFeature)
                .where(SurFaceFeature.person_id == person_id)
                .order_by(desc(SurFaceFeature.created_time))
                .limit(limit)
            )
            features = list(self.session.execute(stmt).scalars().all())

            logger.debug(f"Retrieved {len(features)} face features for person_id={person_id}")
            return features

        except SQLAlchemyError as e:
            logger.error(f"Database error getting face features by person: {e}")
            raise

    @staticmethod
    def _build_query_conditions(query: FaceFeatureQuery) -> list:
        """Translate the set, non-None fields of *query* into WHERE clauses."""
        filters = query.model_dump(exclude_unset=True, exclude_none=True)
        conditions = []

        if "person_id" in filters:
            conditions.append(SurFaceFeature.person_id == filters["person_id"])
        if "feature_type" in filters:
            conditions.append(SurFaceFeature.feature_type == filters["feature_type"])
        if "status" in filters:
            conditions.append(SurFaceFeature.status == filters["status"])
        if "start_date" in filters:
            conditions.append(SurFaceFeature.created_time >= filters["start_date"])
        if "end_date" in filters:
            conditions.append(SurFaceFeature.created_time <= filters["end_date"])
        if "has_feature_data" in filters:
            if filters["has_feature_data"]:
                conditions.append(SurFaceFeature.feature_data.isnot(None))
            else:
                conditions.append(SurFaceFeature.feature_data.is_(None))

        return conditions

    @staticmethod
    def _resolve_sort_column(order_by: str):
        """Return the mapped column attribute named *order_by*, else ``created_time``."""
        # Only mapped column attributes are accepted, so a caller-supplied name
        # cannot select an arbitrary class attribute (method, metadata, ...).
        if order_by in inspect(SurFaceFeature).column_attrs.keys():
            return getattr(SurFaceFeature, order_by)
        logger.warning(f"Unknown sort field '{order_by}', falling back to created_time")
        return SurFaceFeature.created_time

    def query_features(
        self,
        query: FaceFeatureQuery,
        page: int = 1,
        page_size: int = 20,
        order_by: str = "created_time",
        desc_order: bool = True
    ) -> Tuple[List[SurFaceFeature], int]:
        """Run a filtered, sorted, paginated query.

        Args:
            query: Filter criteria; unset and ``None`` fields are ignored.
            page: 1-based page number. Values below 1 are treated as 1.
            page_size: Rows per page. Negative values are treated as 0.
            order_by: Name of a mapped column attribute to sort by. Any other
                name falls back to ``created_time``.
            desc_order: Sort descending when true, ascending otherwise.

        Returns:
            ``(rows_on_this_page, total_matching_rows)``.

        Raises:
            SQLAlchemyError: On database error (logged, then re-raised).
        """
        try:
            conditions = self._build_query_conditions(query)

            stmt = select(SurFaceFeature)
            # Count against the table directly with the same filters.
            count_stmt = select(func.count()).select_from(SurFaceFeature)
            if conditions:
                where_clause = and_(*conditions)
                stmt = stmt.where(where_clause)
                count_stmt = count_stmt.where(where_clause)

            total = self.session.execute(count_stmt).scalar_one()

            direction = desc if desc_order else asc
            sort_column = self._resolve_sort_column(order_by)
            ordering = [direction(sort_column)]
            if sort_column is not SurFaceFeature.id:
                # Tie-break on id so rows with equal sort keys keep a stable
                # position across pages.
                ordering.append(direction(SurFaceFeature.id))
            stmt = stmt.order_by(*ordering)

            # Clamp so that out-of-range input cannot yield a negative
            # OFFSET or LIMIT.
            page = max(page, 1)
            page_size = max(page_size, 0)
            stmt = stmt.offset((page - 1) * page_size).limit(page_size)

            features = list(self.session.execute(stmt).scalars().all())

            logger.debug(f"Query returned {len(features)} features (total: {total})")
            return features, total

        except SQLAlchemyError as e:
            logger.error(f"Database error querying face features: {e}")
            raise

    # ===== Update =====

    def update(self, feature_id: int, update_data: FaceFeatureUpdate) -> Optional[SurFaceFeature]:
        """Apply a partial update to one record.

        Fields that are unset or ``None`` in *update_data* are skipped, so this
        method cannot be used to clear a field to ``None``.

        Args:
            feature_id: Id of the record to update.
            update_data: Fields to change.

        Returns:
            The updated ``SurFaceFeature``, or ``None`` if the id does not exist.

        Raises:
            ValueError: If the update violates an integrity constraint. The
                underlying ``IntegrityError`` is attached as ``__cause__``.
            SQLAlchemyError: On any other database error.
        """
        try:
            feature = self.get_by_id(feature_id)
            if not feature:
                logger.warning(f"Cannot update non-existent face feature: id={feature_id}")
                return None

            changes = update_data.model_dump(exclude_unset=True, exclude_none=True)
            for field_name, new_value in changes.items():
                setattr(feature, field_name, new_value)

            # Send the UPDATE now so constraint violations surface here.
            self.session.flush()

            logger.info(f"Updated face feature: id={feature_id}")
            return feature

        except IntegrityError as e:
            logger.error(f"Integrity error updating face feature: {e}")
            self.session.rollback()
            raise ValueError("Update would create duplicate record") from e
        except SQLAlchemyError as e:
            logger.error(f"Database error updating face feature: {e}")
            self.session.rollback()
            raise

    def update_feature_data(self, feature_id: int, feature_data: bytes) -> bool:
        """Overwrite the binary feature data of one record.

        Args:
            feature_id: Id of the record to update.
            feature_data: New feature payload (binary).

        Returns:
            ``True`` if a row was updated, ``False`` if the id does not exist.

        Raises:
            SQLAlchemyError: On database error (after rollback).
        """
        try:
            stmt = (
                update(SurFaceFeature)
                .where(SurFaceFeature.id == feature_id)
                .values(feature_data=feature_data)
                .returning(SurFaceFeature.id)
            )
            updated_id = self.session.execute(stmt).scalar_one_or_none()

            # Compare against None explicitly: an id of 0 is still a hit.
            if updated_id is not None:
                logger.info(f"Updated feature data for face feature: id={feature_id}")
                return True

            logger.warning(f"Cannot update feature data for non-existent face feature: id={feature_id}")
            return False

        except SQLAlchemyError as e:
            logger.error(f"Database error updating feature data: {e}")
            self.session.rollback()
            raise

    def update_status(
        self,
        feature_id: int,
        status: FeatureStatus,
        start_time: Optional[datetime] = None,
        finish_time: Optional[datetime] = None
    ) -> bool:
        """Set the processing status of one record, optionally with timestamps.

        Args:
            feature_id: Id of the record to update.
            status: New status. A ``FeatureStatus`` member is stored by its
                ``.value``; anything else is stored as given.
            start_time: If given, also written to ``start_time``.
            finish_time: If given, also written to ``finish_time``.

        Returns:
            ``True`` if a row was updated, ``False`` if the id does not exist.

        Raises:
            SQLAlchemyError: On database error (after rollback).
        """
        try:
            update_values = {"status": status.value if isinstance(status, FeatureStatus) else status}
            if start_time is not None:
                update_values["start_time"] = start_time
            if finish_time is not None:
                update_values["finish_time"] = finish_time

            stmt = (
                update(SurFaceFeature)
                .where(SurFaceFeature.id == feature_id)
                .values(**update_values)
                .returning(SurFaceFeature.id)
            )
            updated_id = self.session.execute(stmt).scalar_one_or_none()

            # Compare against None explicitly: an id of 0 is still a hit.
            if updated_id is not None:
                logger.info(f"Updated status to {status} for face feature: id={feature_id}")
                return True

            logger.warning(f"Cannot update status for non-existent face feature: id={feature_id}")
            return False

        except SQLAlchemyError as e:
            logger.error(f"Database error updating status: {e}")
            self.session.rollback()
            raise

    # ===== Delete =====

    def delete(self, feature_id: int) -> bool:
        """Delete one record by id.

        Args:
            feature_id: Id of the record to delete.

        Returns:
            ``True`` if a row was deleted, ``False`` if the id does not exist.

        Raises:
            SQLAlchemyError: On database error (after rollback).
        """
        try:
            stmt = delete(SurFaceFeature).where(SurFaceFeature.id == feature_id)
            deleted_count = self.session.execute(stmt).rowcount

            if deleted_count > 0:
                logger.info(f"Deleted face feature: id={feature_id}")
                return True

            logger.warning(f"Cannot delete non-existent face feature: id={feature_id}")
            return False

        except SQLAlchemyError as e:
            logger.error(f"Database error deleting face feature: {e}")
            self.session.rollback()
            raise

    def delete_by_person(self, person_id: int) -> int:
        """Delete every record belonging to one person.

        Args:
            person_id: Person id whose records are removed.

        Returns:
            Number of rows deleted.

        Raises:
            SQLAlchemyError: On database error (after rollback).
        """
        try:
            stmt = delete(SurFaceFeature).where(SurFaceFeature.person_id == person_id)
            deleted_count = self.session.execute(stmt).rowcount

            logger.info(f"Deleted {deleted_count} face features for person_id={person_id}")
            return deleted_count

        except SQLAlchemyError as e:
            logger.error(f"Database error deleting face features by person: {e}")
            self.session.rollback()
            raise

    # ===== Statistics =====

    def get_stats(self) -> Dict[str, Any]:
        """Collect aggregate statistics over all feature records.

        Returns:
            A dict with the keys:

            * ``total_count``: number of rows.
            * ``by_status``: ``{str(status): count}``.
            * ``by_feature_type``: ``{str(feature_type): count}`` for rows with
              a non-null feature type.
            * ``avg_processing_time``: mean of ``finish_time - start_time`` as
              returned by ``extract('epoch', ...)``, over SUCCESS/FAILED rows
              that have both timestamps; ``None`` when no such row exists.

        Raises:
            SQLAlchemyError: On database error (logged, then re-raised).
        """
        try:
            total_stmt = select(func.count()).select_from(SurFaceFeature)
            total_count = self.session.execute(total_stmt).scalar_one()

            status_stmt = (
                select(SurFaceFeature.status, func.count())
                .group_by(SurFaceFeature.status)
            )
            status_stats = {
                str(status): count
                for status, count in self.session.execute(status_stmt)
            }

            type_stmt = (
                select(SurFaceFeature.feature_type, func.count())
                .where(SurFaceFeature.feature_type.isnot(None))
                .group_by(SurFaceFeature.feature_type)
            )
            type_stats = {
                str(feature_type): count
                for feature_type, count in self.session.execute(type_stmt)
            }

            # NOTE(review): extract('epoch', interval) is dialect-dependent
            # (PostgreSQL-style) - confirm against the deployed database.
            time_stmt = (
                select(
                    func.avg(
                        func.extract('epoch', SurFaceFeature.finish_time - SurFaceFeature.start_time)
                    )
                )
                .where(
                    and_(
                        SurFaceFeature.start_time.isnot(None),
                        SurFaceFeature.finish_time.isnot(None),
                        SurFaceFeature.status.in_([FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED]),
                    )
                )
            )
            avg_time = self.session.execute(time_stmt).scalar_one()

            stats = {
                "total_count": total_count,
                "by_status": status_stats,
                "by_feature_type": type_stats,
                # AVG over zero rows is NULL; a genuine average of 0 must
                # still be reported as 0.0, hence the explicit None check.
                "avg_processing_time": float(avg_time) if avg_time is not None else None,
            }

            logger.debug("Generated face feature statistics")
            return stats

        except SQLAlchemyError as e:
            logger.error(f"Database error getting statistics: {e}")
            raise

    # ===== Bulk operations =====

    def mark_for_processing(self, limit: int = 100) -> List[SurFaceFeature]:
        """Claim pending records and flag them as PROCESSING.

        Selects up to *limit* NOT_STARTED rows, oldest first, using
        ``FOR UPDATE SKIP LOCKED``, then sets their status to PROCESSING and
        ``start_time`` to the current local time.

        Args:
            limit: Maximum number of rows to claim.

        Returns:
            The claimed ``SurFaceFeature`` instances (possibly empty).

        Raises:
            SQLAlchemyError: On database error (after rollback).
        """
        try:
            pending_stmt = (
                select(SurFaceFeature)
                .where(SurFaceFeature.status == FeatureStatusEnum.NOT_STARTED)
                .order_by(SurFaceFeature.created_time)
                .limit(limit)
                # Skip rows already locked by another transaction.
                .with_for_update(skip_locked=True)
            )
            pending_features = list(self.session.execute(pending_stmt).scalars().all())

            feature_ids = [feature.id for feature in pending_features]
            if feature_ids:
                # NOTE(review): the returned instances were loaded before this
                # bulk UPDATE; whether their in-memory status/start_time
                # reflect it depends on the session's synchronize_session
                # behaviour - confirm for the SQLAlchemy version in use.
                update_stmt = (
                    update(SurFaceFeature)
                    .where(SurFaceFeature.id.in_(feature_ids))
                    .values(
                        status=FeatureStatusEnum.PROCESSING,
                        start_time=datetime.now(),
                    )
                )
                self.session.execute(update_stmt)

            logger.info(f"Marked {len(pending_features)} face features for processing")
            return pending_features

        except SQLAlchemyError as e:
            logger.error(f"Database error marking features for processing: {e}")
            self.session.rollback()
            raise

    def cleanup_old_features(self, days: int = 30) -> int:
        """Delete finished records older than a retention window.

        Only rows whose status is SUCCESS or FAILED and whose ``created_time``
        is earlier than ``now - days`` are removed.

        Args:
            days: Retention period in days.

        Returns:
            Number of rows deleted.

        Raises:
            SQLAlchemyError: On database error (after rollback).
        """
        try:
            cutoff_date = datetime.now() - timedelta(days=days)

            stmt = delete(SurFaceFeature).where(
                and_(
                    SurFaceFeature.created_time < cutoff_date,
                    SurFaceFeature.status.in_([FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED]),
                )
            )
            deleted_count = self.session.execute(stmt).rowcount

            logger.info(f"Cleaned up {deleted_count} old face features (older than {days} days)")
            return deleted_count

        except SQLAlchemyError as e:
            logger.error(f"Database error cleaning up old features: {e}")
            self.session.rollback()
            raise

    # ===== Algorithm-routing support =====

    @staticmethod
    def _type_status_conditions(feature_type: int, status: Optional[int]) -> list:
        """Build WHERE clauses for a feature type and an optional status."""
        conditions = [SurFaceFeature.feature_type == feature_type]
        if status is not None:
            conditions.append(SurFaceFeature.status == status)
        return conditions

    def get_features_by_type_and_status(
        self,
        feature_type: int,
        status: Optional[int] = None,
        limit: Optional[int] = None
    ) -> List[SurFaceFeature]:
        """List records of one feature type, optionally filtered by status.

        Args:
            feature_type: Feature type to match.
            status: If given, only rows with this status are returned.
            limit: If given, maximum number of rows to return; by default the
                result is unbounded.

        Returns:
            Matching records ordered by ``created_time`` ascending.

        Raises:
            SQLAlchemyError: On database error (logged, then re-raised).
        """
        try:
            stmt = (
                select(SurFaceFeature)
                .where(and_(*self._type_status_conditions(feature_type, status)))
                .order_by(asc(SurFaceFeature.created_time))
            )
            if limit is not None:
                stmt = stmt.limit(limit)

            features = list(self.session.execute(stmt).scalars().all())

            logger.debug(f"Retrieved {len(features)} features by type={feature_type}, status={status}")
            return features

        except SQLAlchemyError as e:
            logger.error(f"Database error getting features by type and status: {e}")
            raise

    def count_by_type_and_status(
        self,
        feature_type: int,
        status: Optional[int] = None
    ) -> int:
        """Count records of one feature type, optionally filtered by status.

        Args:
            feature_type: Feature type to match.
            status: If given, only rows with this status are counted.

        Returns:
            Number of matching rows.

        Raises:
            SQLAlchemyError: On database error (logged, then re-raised).
        """
        try:
            stmt = (
                select(func.count())
                .select_from(SurFaceFeature)
                .where(and_(*self._type_status_conditions(feature_type, status)))
            )
            return self.session.execute(stmt).scalar_one()

        except SQLAlchemyError as e:
            logger.error(f"Database error counting features by type and status: {e}")
            raise

    def get_statistics(self) -> Dict[str, Any]:
        """Return the same statistics as :meth:`get_stats` (compatibility alias)."""
        return self.get_stats()