""" 人脸特征数据仓库 数据访问层,处理所有数据库操作 """ import logging from datetime import datetime, timedelta from typing import Optional, List, Dict, Any, Tuple from contextlib import contextmanager from sqlalchemy import select, update, delete, func, and_, or_, desc, asc from sqlalchemy.orm import Session, joinedload from sqlalchemy.exc import SQLAlchemyError, IntegrityError from models.face_feature import SurFaceFeature from schemas.face_feature import ( FaceFeatureCreate, FaceFeatureUpdate, FaceFeatureQuery, FeatureStatus ) from models.face_feature import FeatureStatusEnum from utils.logger import setup_logger logger = setup_logger(__name__) class FaceFeatureRepository: """人脸特征数据仓库""" def __init__(self, session: Session): """ 初始化仓库 Args: session: SQLAlchemy会话对象 """ self.session = session # ===== 创建操作 ===== def create(self, feature_data: FaceFeatureCreate) -> SurFaceFeature: """ 创建特征记录 Args: feature_data: 特征数据 Returns: 创建的SurFaceFeature对象 Raises: IntegrityError: 违反唯一约束时抛出 """ try: # 转换为模型字典 feature_dict = feature_data.model_dump(exclude_unset=True) # 创建模型实例 feature = SurFaceFeature(**feature_dict) # 添加到会话 self.session.add(feature) self.session.flush() # 立即执行,但不提交 logger.info(f"Created face feature record: id={feature.id}, person_id={feature.person_id}") return feature except IntegrityError as e: logger.error(f"Integrity error creating face feature: {e}") self.session.rollback() raise ValueError(f"Duplicate feature record for person_id={feature_data.person_id}, " f"feature_type={feature_data.feature_type}") except SQLAlchemyError as e: logger.error(f"Database error creating face feature: {e}") self.session.rollback() raise def create_batch(self, features_data: List[FaceFeatureCreate]) -> List[SurFaceFeature]: """ 批量创建特征记录 Args: features_data: 特征数据列表 Returns: 创建的SurFaceFeature对象列表 """ try: features = [] for feature_data in features_data: feature_dict = feature_data.model_dump(exclude_unset=True) feature = SurFaceFeature(**feature_dict) features.append(feature) # 批量添加 self.session.add_all(features) self.session.flush() logger.info(f"Created {len(features)} face feature records in batch") return features except IntegrityError as e: logger.error(f"Integrity error creating batch face features: {e}") self.session.rollback() raise ValueError("Duplicate feature record in batch") except SQLAlchemyError as e: logger.error(f"Database error creating batch face features: {e}") self.session.rollback() raise # ===== 查询操作 ===== def get_by_id(self, feature_id: int) -> Optional[SurFaceFeature]: """ 根据ID获取特征记录 Args: feature_id: 特征记录ID Returns: SurFaceFeature对象或None """ try: stmt = select(SurFaceFeature).where(SurFaceFeature.id == feature_id) result = self.session.execute(stmt) feature = result.scalar_one_or_none() if feature: logger.debug(f"Retrieved face feature by id: {feature_id}") else: logger.debug(f"Face feature not found by id: {feature_id}") return feature except SQLAlchemyError as e: logger.error(f"Database error getting face feature by id: {e}") raise def get_by_person_and_type(self, person_id: int, feature_type: int) -> Optional[SurFaceFeature]: """ 根据人员ID和特征类型获取特征记录 Args: person_id: 人员ID feature_type: 特征类型 Returns: SurFaceFeature对象或None """ try: stmt = select(SurFaceFeature).where( and_( SurFaceFeature.person_id == person_id, SurFaceFeature.feature_type == feature_type ) ) result = self.session.execute(stmt) feature = result.scalar_one_or_none() if feature: logger.debug(f"Retrieved face feature: person_id={person_id}, feature_type={feature_type}") else: logger.debug(f"Face feature not found: person_id={person_id}, feature_type={feature_type}") return feature except SQLAlchemyError as e: logger.error(f"Database error getting face feature by person and type: {e}") raise def get_by_person(self, person_id: int, limit: int = 100) -> List[SurFaceFeature]: """ 根据人员ID获取特征记录列表 Args: person_id: 人员ID limit: 返回数量限制 Returns: SurFaceFeature对象列表 """ try: stmt = ( select(SurFaceFeature) .where(SurFaceFeature.person_id == person_id) .order_by(desc(SurFaceFeature.created_time)) .limit(limit) ) result = self.session.execute(stmt) features = list(result.scalars().all()) logger.debug(f"Retrieved {len(features)} face features for person_id={person_id}") return features except SQLAlchemyError as e: logger.error(f"Database error getting face features by person: {e}") raise def query_features( self, query: FaceFeatureQuery, page: int = 1, page_size: int = 20, order_by: str = "created_time", desc_order: bool = True ) -> Tuple[List[SurFaceFeature], int]: """ 查询特征记录(带分页) Args: query: 查询条件 page: 页码(从1开始) page_size: 每页数量 order_by: 排序字段 desc_order: 是否降序 Returns: (特征记录列表, 总记录数) """ try: # 构建查询条件 conditions = [] query_dict = query.model_dump(exclude_unset=True, exclude_none=True) # 处理查询条件 if "person_id" in query_dict: conditions.append(SurFaceFeature.person_id == query_dict["person_id"]) if "feature_type" in query_dict: conditions.append(SurFaceFeature.feature_type == query_dict["feature_type"]) if "status" in query_dict: conditions.append(SurFaceFeature.status == query_dict["status"]) if "start_date" in query_dict: conditions.append(SurFaceFeature.created_time >= query_dict["start_date"]) if "end_date" in query_dict: conditions.append(SurFaceFeature.created_time <= query_dict["end_date"]) if "has_feature_data" in query_dict: if query_dict["has_feature_data"]: conditions.append(SurFaceFeature.feature_data.isnot(None)) else: conditions.append(SurFaceFeature.feature_data.is_(None)) # 基础查询 stmt = select(SurFaceFeature) if conditions: stmt = stmt.where(and_(*conditions)) # 获取总数 count_stmt = select(func.count()).select_from(stmt.subquery()) total_result = self.session.execute(count_stmt) total = total_result.scalar_one() # 排序 order_column = getattr(SurFaceFeature, order_by, SurFaceFeature.created_time) if desc_order: stmt = stmt.order_by(desc(order_column)) else: stmt = stmt.order_by(asc(order_column)) # 分页 offset = (page - 1) * page_size stmt = stmt.offset(offset).limit(page_size) # 执行查询 result = self.session.execute(stmt) features = list(result.scalars().all()) logger.debug(f"Query returned {len(features)} features (total: {total})") return features, total except SQLAlchemyError as e: logger.error(f"Database error querying face features: {e}") raise # ===== 更新操作 ===== def update(self, feature_id: int, update_data: FaceFeatureUpdate) -> Optional[SurFaceFeature]: """ 更新特征记录 Args: feature_id: 特征记录ID update_data: 更新数据 Returns: 更新后的SurFaceFeature对象或None(如果不存在) """ try: # 先检查是否存在 feature = self.get_by_id(feature_id) if not feature: logger.warning(f"Cannot update non-existent face feature: id={feature_id}") return None # 转换为字典 update_dict = update_data.model_dump(exclude_unset=True, exclude_none=True) # 更新字段 for key, value in update_dict.items(): setattr(feature, key, value) # 刷新到数据库 self.session.flush() logger.info(f"Updated face feature: id={feature_id}") return feature except IntegrityError as e: logger.error(f"Integrity error updating face feature: {e}") self.session.rollback() raise ValueError("Update would create duplicate record") except SQLAlchemyError as e: logger.error(f"Database error updating face feature: {e}") self.session.rollback() raise def update_feature_data(self, feature_id: int, feature_data: bytes) -> bool: """ 更新特征数据 Args: feature_id: 特征记录ID feature_data: 特征数据(二进制) Returns: 是否成功更新 """ try: stmt = ( update(SurFaceFeature) .where(SurFaceFeature.id == feature_id) .values(feature_data=feature_data) .returning(SurFaceFeature.id) ) result = self.session.execute(stmt) updated_id = result.scalar_one_or_none() if updated_id: logger.info(f"Updated feature data for face feature: id={feature_id}") return True else: logger.warning(f"Cannot update feature data for non-existent face feature: id={feature_id}") return False except SQLAlchemyError as e: logger.error(f"Database error updating feature data: {e}") self.session.rollback() raise def update_status( self, feature_id: int, status: FeatureStatus, start_time: Optional[datetime] = None, finish_time: Optional[datetime] = None ) -> bool: """ 更新计算状态 Args: feature_id: 特征记录ID status: 新状态 start_time: 开始时间(可选) finish_time: 结束时间(可选) Returns: 是否成功更新 """ try: update_values = {"status": status.value if isinstance(status, FeatureStatus) else status} if start_time: update_values["start_time"] = start_time if finish_time: update_values["finish_time"] = finish_time stmt = ( update(SurFaceFeature) .where(SurFaceFeature.id == feature_id) .values(**update_values) .returning(SurFaceFeature.id) ) result = self.session.execute(stmt) updated_id = result.scalar_one_or_none() if updated_id: logger.info(f"Updated status to {status} for face feature: id={feature_id}") return True else: logger.warning(f"Cannot update status for non-existent face feature: id={feature_id}") return False except SQLAlchemyError as e: logger.error(f"Database error updating status: {e}") self.session.rollback() raise # ===== 删除操作 ===== def delete(self, feature_id: int) -> bool: """ 删除特征记录 Args: feature_id: 特征记录ID Returns: 是否成功删除 """ try: stmt = delete(SurFaceFeature).where(SurFaceFeature.id == feature_id) result = self.session.execute(stmt) deleted_count = result.rowcount if deleted_count > 0: logger.info(f"Deleted face feature: id={feature_id}") return True else: logger.warning(f"Cannot delete non-existent face feature: id={feature_id}") return False except SQLAlchemyError as e: logger.error(f"Database error deleting face feature: {e}") self.session.rollback() raise def delete_by_person(self, person_id: int) -> int: """ 删除指定人员的所有特征记录 Args: person_id: 人员ID Returns: 删除的记录数 """ try: stmt = delete(SurFaceFeature).where(SurFaceFeature.person_id == person_id) result = self.session.execute(stmt) deleted_count = result.rowcount logger.info(f"Deleted {deleted_count} face features for person_id={person_id}") return deleted_count except SQLAlchemyError as e: logger.error(f"Database error deleting face features by person: {e}") self.session.rollback() raise # ===== 统计操作 ===== def get_stats(self) -> Dict[str, Any]: """ 获取特征记录统计信息 Returns: 统计信息字典 """ try: # 总记录数 total_stmt = select(func.count()).select_from(SurFaceFeature) total_result = self.session.execute(total_stmt) total_count = total_result.scalar_one() # 按状态统计 status_stmt = ( select(SurFaceFeature.status, func.count()) .group_by(SurFaceFeature.status) ) status_result = self.session.execute(status_stmt) status_stats = {str(status): count for status, count in status_result} # 按特征类型统计 type_stmt = ( select(SurFaceFeature.feature_type, func.count()) .where(SurFaceFeature.feature_type.isnot(None)) .group_by(SurFaceFeature.feature_type) ) type_result = self.session.execute(type_stmt) type_stats = {str(feature_type): count for feature_type, count in type_result} # 平均处理时间(仅计算成功和失败的) time_stmt = ( select( func.avg( func.extract('epoch', SurFaceFeature.finish_time - SurFaceFeature.start_time) ) ) .where( and_( SurFaceFeature.start_time.isnot(None), SurFaceFeature.finish_time.isnot(None), SurFaceFeature.status.in_([FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED]) ) ) ) time_result = self.session.execute(time_stmt) avg_time = time_result.scalar_one() stats = { "total_count": total_count, "by_status": status_stats, "by_feature_type": type_stats, "avg_processing_time": float(avg_time) if avg_time else None } logger.debug(f"Generated face feature statistics") return stats except SQLAlchemyError as e: logger.error(f"Database error getting statistics: {e}") raise # ===== 批量操作 ===== def mark_for_processing(self, limit: int = 100) -> List[SurFaceFeature]: """ 标记待处理的特征记录为计算中 Args: limit: 最大处理数量 Returns: 标记为处理中的特征记录列表 """ try: # 查找待处理的记录 pending_stmt = ( select(SurFaceFeature) .where(SurFaceFeature.status == FeatureStatusEnum.NOT_STARTED) .order_by(SurFaceFeature.created_time) .limit(limit) .with_for_update(skip_locked=True) # 跳过被锁定的行 ) result = self.session.execute(pending_stmt) pending_features = list(result.scalars().all()) # 更新状态 feature_ids = [f.id for f in pending_features] if feature_ids: update_stmt = ( update(SurFaceFeature) .where(SurFaceFeature.id.in_(feature_ids)) .values( status=FeatureStatusEnum.PROCESSING, start_time=datetime.now() ) ) self.session.execute(update_stmt) logger.info(f"Marked {len(pending_features)} face features for processing") return pending_features except SQLAlchemyError as e: logger.error(f"Database error marking features for processing: {e}") self.session.rollback() raise def cleanup_old_features(self, days: int = 30) -> int: """ 清理旧的特征记录 Args: days: 保留天数 Returns: 删除的记录数 """ try: cutoff_date = datetime.now() - timedelta(days=days) # 只删除已完成(成功或失败)的旧记录 stmt = delete(SurFaceFeature).where( and_( SurFaceFeature.created_time < cutoff_date, SurFaceFeature.status.in_([FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED]) ) ) result = self.session.execute(stmt) deleted_count = result.rowcount logger.info(f"Cleaned up {deleted_count} old face features (older than {days} days)") return deleted_count except SQLAlchemyError as e: logger.error(f"Database error cleaning up old features: {e}") self.session.rollback() raise # ===== 新增方法:支持算法路由 ===== def get_features_by_type_and_status( self, feature_type: int, status: Optional[int] = None ) -> List[SurFaceFeature]: """ 根据特征类型和状态获取特征记录 Args: feature_type: 特征类型 status: 状态(可选) Returns: 特征记录列表 """ try: conditions = [SurFaceFeature.feature_type == feature_type] if status is not None: conditions.append(SurFaceFeature.status == status) stmt = ( select(SurFaceFeature) .where(and_(*conditions)) .order_by(asc(SurFaceFeature.created_time)) ) result = self.session.execute(stmt) features = list(result.scalars().all()) logger.debug(f"Retrieved {len(features)} features by type={feature_type}, status={status}") return features except SQLAlchemyError as e: logger.error(f"Database error getting features by type and status: {e}") raise def count_by_type_and_status( self, feature_type: int, status: Optional[int] = None ) -> int: """ 统计指定特征类型和状态的记录数量 Args: feature_type: 特征类型 status: 状态(可选) Returns: 记录数量 """ try: conditions = [SurFaceFeature.feature_type == feature_type] if status is not None: conditions.append(SurFaceFeature.status == status) stmt = select(func.count()).select_from(SurFaceFeature).where(and_(*conditions)) result = self.session.execute(stmt) count = result.scalar_one() return count except SQLAlchemyError as e: logger.error(f"Database error counting features by type and status: {e}") raise def get_statistics(self) -> Dict[str, Any]: """ 获取统计信息(兼容性方法) Returns: 统计信息字典 """ return self.get_stats()