Files
SupervisorAI/repositories/face_feature_repository.py

676 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
人脸特征数据仓库
数据访问层,处理所有数据库操作
"""
import logging
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any, Tuple
from contextlib import contextmanager
from sqlalchemy import select, update, delete, func, and_, or_, desc, asc
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.exc import SQLAlchemyError, IntegrityError
from models.face_feature import SurFaceFeature
from schemas.face_feature import (
FaceFeatureCreate,
FaceFeatureUpdate,
FaceFeatureQuery,
FeatureStatus
)
from models.face_feature import FeatureStatusEnum
from utils.logger import setup_logger
logger = setup_logger(__name__)
class FaceFeatureRepository:
"""人脸特征数据仓库"""
def __init__(self, session: Session):
    """
    Bind the repository to a SQLAlchemy session.

    Args:
        session: SQLAlchemy session object. It is stored on the instance
            and used by the repository methods for all database access.
    """
    self.session = session
# ===== 创建操作 =====
def create(self, feature_data: FaceFeatureCreate) -> SurFaceFeature:
    """
    Create a single feature record.

    The new row is flushed to the database immediately but the
    transaction is not committed here.

    Args:
        feature_data: Payload for the new record. Only fields that were
            explicitly set are passed to the model constructor.

    Returns:
        The newly created SurFaceFeature instance.

    Raises:
        ValueError: If the insert raises an IntegrityError. Every integrity
            violation is reported as a duplicate record; the original
            IntegrityError is attached as ``__cause__``.
        SQLAlchemyError: On any other database error (logged, session
            rolled back, then re-raised unchanged).
    """
    try:
        feature = SurFaceFeature(**feature_data.model_dump(exclude_unset=True))
        self.session.add(feature)
        # flush() executes the INSERT now without committing.
        self.session.flush()
        logger.info(f"Created face feature record: id={feature.id}, person_id={feature.person_id}")
        return feature
    except IntegrityError as e:
        logger.error(f"Integrity error creating face feature: {e}")
        self.session.rollback()
        # Chain the cause so the underlying constraint name stays visible
        # in tracebacks.
        raise ValueError(
            f"Duplicate feature record for person_id={feature_data.person_id}, "
            f"feature_type={feature_data.feature_type}"
        ) from e
    except SQLAlchemyError as e:
        logger.error(f"Database error creating face feature: {e}")
        self.session.rollback()
        raise
def create_batch(self, features_data: List[FaceFeatureCreate]) -> List[SurFaceFeature]:
    """
    Create several feature records in one flush.

    Args:
        features_data: Payloads for the new records. For each payload only
            the explicitly set fields are passed to the model constructor.

    Returns:
        The created SurFaceFeature instances, in input order.

    Raises:
        ValueError: If the flush raises an IntegrityError; the original
            IntegrityError is attached as ``__cause__``.
        SQLAlchemyError: On any other database error (logged, session
            rolled back, then re-raised unchanged).
    """
    try:
        features = [
            SurFaceFeature(**feature_data.model_dump(exclude_unset=True))
            for feature_data in features_data
        ]
        self.session.add_all(features)
        # Flush once for the whole batch; nothing is committed here.
        self.session.flush()
        logger.info(f"Created {len(features)} face feature records in batch")
        return features
    except IntegrityError as e:
        logger.error(f"Integrity error creating batch face features: {e}")
        self.session.rollback()
        # Chain the cause so callers can still inspect the database error.
        raise ValueError("Duplicate feature record in batch") from e
    except SQLAlchemyError as e:
        logger.error(f"Database error creating batch face features: {e}")
        self.session.rollback()
        raise
# ===== 查询操作 =====
def get_by_id(self, feature_id: int) -> Optional[SurFaceFeature]:
    """
    Fetch one feature record by its ID.

    Args:
        feature_id: ID of the feature record.

    Returns:
        The matching SurFaceFeature, or None when no row has this ID.

    Raises:
        SQLAlchemyError: On database errors (logged, then re-raised).
    """
    try:
        lookup = select(SurFaceFeature).where(SurFaceFeature.id == feature_id)
        record = self.session.execute(lookup).scalar_one_or_none()
        if not record:
            logger.debug(f"Face feature not found by id: {feature_id}")
        else:
            logger.debug(f"Retrieved face feature by id: {feature_id}")
        return record
    except SQLAlchemyError as exc:
        logger.error(f"Database error getting face feature by id: {exc}")
        raise
def get_by_person_and_type(self, person_id: int, feature_type: int) -> Optional[SurFaceFeature]:
    """
    Fetch the feature record for a given person and feature type.

    Args:
        person_id: ID of the person.
        feature_type: Feature type to look up.

    Returns:
        The matching SurFaceFeature, or None when there is no such row.

    Raises:
        SQLAlchemyError: On database errors (logged, then re-raised).
    """
    try:
        # Two chained where() calls are AND-ed together.
        lookup = (
            select(SurFaceFeature)
            .where(SurFaceFeature.person_id == person_id)
            .where(SurFaceFeature.feature_type == feature_type)
        )
        record = self.session.execute(lookup).scalar_one_or_none()
        if not record:
            logger.debug(f"Face feature not found: person_id={person_id}, feature_type={feature_type}")
        else:
            logger.debug(f"Retrieved face feature: person_id={person_id}, feature_type={feature_type}")
        return record
    except SQLAlchemyError as exc:
        logger.error(f"Database error getting face feature by person and type: {exc}")
        raise
def get_by_person(self, person_id: int, limit: int = 100) -> List[SurFaceFeature]:
    """
    List the feature records of one person, newest first.

    Args:
        person_id: ID of the person.
        limit: Maximum number of records to return.

    Returns:
        SurFaceFeature instances ordered by created_time descending.

    Raises:
        SQLAlchemyError: On database errors (logged, then re-raised).
    """
    try:
        newest_first = SurFaceFeature.created_time.desc()
        listing = select(SurFaceFeature).where(SurFaceFeature.person_id == person_id)
        listing = listing.order_by(newest_first).limit(limit)
        rows = list(self.session.execute(listing).scalars().all())
        logger.debug(f"Retrieved {len(rows)} face features for person_id={person_id}")
        return rows
    except SQLAlchemyError as exc:
        logger.error(f"Database error getting face features by person: {exc}")
        raise
def query_features(
    self,
    query: FaceFeatureQuery,
    page: int = 1,
    page_size: int = 20,
    order_by: str = "created_time",
    desc_order: bool = True
) -> Tuple[List[SurFaceFeature], int]:
    """
    Query feature records with filtering, sorting and pagination.

    Args:
        query: Filter object. Only fields that are set and not None are
            applied. Recognised keys: person_id, feature_type, status,
            start_date / end_date (inclusive bounds on created_time) and
            has_feature_data (True -> feature_data IS NOT NULL,
            False -> feature_data IS NULL).
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Rows per page; negative values are treated as 0.
        order_by: Name of a mapped column attribute of SurFaceFeature to
            sort by. Any other value falls back to created_time.
        desc_order: Sort descending when True, ascending otherwise.

    Returns:
        A tuple of (records on the requested page, total matching rows).

    Raises:
        SQLAlchemyError: On database errors (logged, then re-raised).
    """
    # Imported locally because the module-level imports do not include it.
    from sqlalchemy import inspect as sa_inspect

    try:
        filters = query.model_dump(exclude_unset=True, exclude_none=True)
        conditions = []
        if "person_id" in filters:
            conditions.append(SurFaceFeature.person_id == filters["person_id"])
        if "feature_type" in filters:
            conditions.append(SurFaceFeature.feature_type == filters["feature_type"])
        if "status" in filters:
            conditions.append(SurFaceFeature.status == filters["status"])
        if "start_date" in filters:
            conditions.append(SurFaceFeature.created_time >= filters["start_date"])
        if "end_date" in filters:
            conditions.append(SurFaceFeature.created_time <= filters["end_date"])
        if "has_feature_data" in filters:
            if filters["has_feature_data"]:
                conditions.append(SurFaceFeature.feature_data.isnot(None))
            else:
                conditions.append(SurFaceFeature.feature_data.is_(None))

        stmt = select(SurFaceFeature)
        if conditions:
            stmt = stmt.where(and_(*conditions))

        # Count the filtered rows before ordering and paging are applied.
        count_stmt = select(func.count()).select_from(stmt.subquery())
        total = self.session.execute(count_stmt).scalar_one()

        # Only mapped column attributes are valid sort keys. A bare getattr()
        # would also accept methods, relationships or class internals.
        if order_by in sa_inspect(SurFaceFeature).column_attrs:
            order_column = getattr(SurFaceFeature, order_by)
        else:
            logger.warning(
                f"Unsupported order_by field '{order_by}', falling back to created_time"
            )
            order_column = SurFaceFeature.created_time

        direction = desc if desc_order else asc
        # The id tie-breaker keeps page boundaries stable when many rows
        # share the same value in the sort column.
        stmt = stmt.order_by(direction(order_column), direction(SurFaceFeature.id))

        # Clamp paging inputs so a bad page number cannot produce a
        # negative OFFSET / LIMIT.
        page = max(page, 1)
        page_size = max(page_size, 0)
        stmt = stmt.offset((page - 1) * page_size).limit(page_size)

        features = list(self.session.execute(stmt).scalars().all())
        logger.debug(f"Query returned {len(features)} features (total: {total})")
        return features, total
    except SQLAlchemyError as e:
        logger.error(f"Database error querying face features: {e}")
        raise
# ===== 更新操作 =====
def update(self, feature_id: int, update_data: FaceFeatureUpdate) -> Optional[SurFaceFeature]:
    """
    Update an existing feature record.

    Args:
        feature_id: ID of the feature record to update.
        update_data: Fields to change. Fields that are unset or None are
            skipped, so this method cannot set a column to NULL.

    Returns:
        The updated SurFaceFeature, or None if no record has this ID.

    Raises:
        ValueError: If the flush raises an IntegrityError; the original
            IntegrityError is attached as ``__cause__``.
        SQLAlchemyError: On any other database error (logged, session
            rolled back, then re-raised unchanged).
    """
    try:
        feature = self.get_by_id(feature_id)
        if feature is None:
            logger.warning(f"Cannot update non-existent face feature: id={feature_id}")
            return None

        changes = update_data.model_dump(exclude_unset=True, exclude_none=True)
        for field_name, new_value in changes.items():
            setattr(feature, field_name, new_value)

        # Send the UPDATE now; nothing is committed here.
        self.session.flush()
        logger.info(f"Updated face feature: id={feature_id}")
        return feature
    except IntegrityError as e:
        logger.error(f"Integrity error updating face feature: {e}")
        self.session.rollback()
        # Chain the cause so callers can still inspect the database error.
        raise ValueError("Update would create duplicate record") from e
    except SQLAlchemyError as e:
        logger.error(f"Database error updating face feature: {e}")
        self.session.rollback()
        raise
def update_feature_data(self, feature_id: int, feature_data: bytes) -> bool:
    """
    Overwrite the binary feature data of one record.

    Issues a single UPDATE ... RETURNING statement, so the database
    backend must support RETURNING.

    Args:
        feature_id: ID of the feature record.
        feature_data: New feature data (binary).

    Returns:
        True if a row was updated, False if no record has this ID.

    Raises:
        SQLAlchemyError: On database errors (logged, session rolled back,
            then re-raised).
    """
    try:
        stmt = (
            update(SurFaceFeature)
            .where(SurFaceFeature.id == feature_id)
            .values(feature_data=feature_data)
            .returning(SurFaceFeature.id)
        )
        updated_id = self.session.execute(stmt).scalar_one_or_none()
        # Compare against None explicitly: a returned id of 0 is falsy but
        # still means a row was updated.
        if updated_id is None:
            logger.warning(f"Cannot update feature data for non-existent face feature: id={feature_id}")
            return False
        logger.info(f"Updated feature data for face feature: id={feature_id}")
        return True
    except SQLAlchemyError as e:
        logger.error(f"Database error updating feature data: {e}")
        self.session.rollback()
        raise
def update_status(
    self,
    feature_id: int,
    status: FeatureStatus,
    start_time: Optional[datetime] = None,
    finish_time: Optional[datetime] = None
) -> bool:
    """
    Update the computation status of one record.

    Issues a single UPDATE ... RETURNING statement, so the database
    backend must support RETURNING.

    Args:
        feature_id: ID of the feature record.
        status: New status. A FeatureStatus member is stored as its
            ``.value``; any other value is stored as given.
        start_time: If not None, also written to start_time.
        finish_time: If not None, also written to finish_time.

    Returns:
        True if a row was updated, False if no record has this ID.

    Raises:
        SQLAlchemyError: On database errors (logged, session rolled back,
            then re-raised).
    """
    try:
        update_values = {
            "status": status.value if isinstance(status, FeatureStatus) else status
        }
        # Explicit None checks: "not provided" is the only reason to leave
        # a timestamp untouched.
        if start_time is not None:
            update_values["start_time"] = start_time
        if finish_time is not None:
            update_values["finish_time"] = finish_time

        stmt = (
            update(SurFaceFeature)
            .where(SurFaceFeature.id == feature_id)
            .values(**update_values)
            .returning(SurFaceFeature.id)
        )
        updated_id = self.session.execute(stmt).scalar_one_or_none()
        # A returned id of 0 is falsy but still means a row was updated.
        if updated_id is None:
            logger.warning(f"Cannot update status for non-existent face feature: id={feature_id}")
            return False
        logger.info(f"Updated status to {status} for face feature: id={feature_id}")
        return True
    except SQLAlchemyError as e:
        logger.error(f"Database error updating status: {e}")
        self.session.rollback()
        raise
# ===== 删除操作 =====
def delete(self, feature_id: int) -> bool:
    """
    Delete one feature record by ID.

    Args:
        feature_id: ID of the feature record.

    Returns:
        True if at least one row was deleted, False otherwise.

    Raises:
        SQLAlchemyError: On database errors (logged, session rolled back,
            then re-raised).
    """
    try:
        removal = delete(SurFaceFeature).where(SurFaceFeature.id == feature_id)
        removed_rows = self.session.execute(removal).rowcount
        if not removed_rows > 0:
            logger.warning(f"Cannot delete non-existent face feature: id={feature_id}")
            return False
        logger.info(f"Deleted face feature: id={feature_id}")
        return True
    except SQLAlchemyError as exc:
        logger.error(f"Database error deleting face feature: {exc}")
        self.session.rollback()
        raise
def delete_by_person(self, person_id: int) -> int:
    """
    Delete every feature record that belongs to one person.

    Args:
        person_id: ID of the person.

    Returns:
        The row count reported by the DELETE statement.

    Raises:
        SQLAlchemyError: On database errors (logged, session rolled back,
            then re-raised).
    """
    try:
        removal = delete(SurFaceFeature).where(SurFaceFeature.person_id == person_id)
        removed_rows = self.session.execute(removal).rowcount
        logger.info(f"Deleted {removed_rows} face features for person_id={person_id}")
        return removed_rows
    except SQLAlchemyError as exc:
        logger.error(f"Database error deleting face features by person: {exc}")
        self.session.rollback()
        raise
# ===== 统计操作 =====
def get_stats(self) -> Dict[str, Any]:
    """
    Collect summary statistics over all feature records.

    Returns:
        A dict with the keys:
            total_count: Number of rows in the table.
            by_status: Mapping of str(status) -> row count.
            by_feature_type: Mapping of str(feature_type) -> row count,
                ignoring rows whose feature_type is NULL.
            avg_processing_time: Average of (finish_time - start_time)
                over SUCCESS/FAILED rows that have both timestamps, as a
                float, or None when no such row exists.

    Raises:
        SQLAlchemyError: On database errors (logged, then re-raised).
    """
    try:
        # Total number of records.
        total_stmt = select(func.count()).select_from(SurFaceFeature)
        total_count = self.session.execute(total_stmt).scalar_one()

        # Row count per status.
        status_stmt = (
            select(SurFaceFeature.status, func.count())
            .group_by(SurFaceFeature.status)
        )
        status_stats = {
            str(status): count
            for status, count in self.session.execute(status_stmt)
        }

        # Row count per feature type (NULL types excluded).
        type_stmt = (
            select(SurFaceFeature.feature_type, func.count())
            .where(SurFaceFeature.feature_type.isnot(None))
            .group_by(SurFaceFeature.feature_type)
        )
        type_stats = {
            str(feature_type): count
            for feature_type, count in self.session.execute(type_stmt)
        }

        # Average processing time over finished (success or failed) rows.
        # NOTE(review): EXTRACT('epoch' ...) on an interval is dialect
        # specific (presumably PostgreSQL, result in seconds) - confirm.
        time_stmt = (
            select(
                func.avg(
                    func.extract('epoch', SurFaceFeature.finish_time - SurFaceFeature.start_time)
                )
            )
            .where(
                and_(
                    SurFaceFeature.start_time.isnot(None),
                    SurFaceFeature.finish_time.isnot(None),
                    SurFaceFeature.status.in_([FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED])
                )
            )
        )
        avg_time = self.session.execute(time_stmt).scalar_one()

        stats = {
            "total_count": total_count,
            "by_status": status_stats,
            "by_feature_type": type_stats,
            # AVG yields NULL only when no row qualifies; an average of
            # exactly 0 is a real value and must not be turned into None.
            "avg_processing_time": float(avg_time) if avg_time is not None else None
        }
        logger.debug("Generated face feature statistics")
        return stats
    except SQLAlchemyError as e:
        logger.error(f"Database error getting statistics: {e}")
        raise
# ===== 批量操作 =====
def mark_for_processing(self, limit: int = 100) -> List[SurFaceFeature]:
    """
    Claim pending feature records and mark them as being processed.

    Selects up to ``limit`` rows whose status is NOT_STARTED, oldest
    created_time first, using SELECT ... FOR UPDATE SKIP LOCKED so rows
    locked by another transaction are skipped. The selected rows are then
    set to PROCESSING with start_time set to the current local time.

    Args:
        limit: Maximum number of records to claim.

    Returns:
        The SurFaceFeature instances that were selected (possibly empty).

    Raises:
        SQLAlchemyError: On database errors (logged, session rolled back,
            then re-raised).
    """
    try:
        # Step 1: lock and load the candidates.
        claim_query = select(SurFaceFeature).where(
            SurFaceFeature.status == FeatureStatusEnum.NOT_STARTED
        )
        claim_query = claim_query.order_by(SurFaceFeature.created_time).limit(limit)
        claim_query = claim_query.with_for_update(skip_locked=True)
        claimed = list(self.session.execute(claim_query).scalars().all())

        # Step 2: flip the status of exactly those rows, if there are any.
        claimed_ids = [row.id for row in claimed]
        if claimed_ids:
            self.session.execute(
                update(SurFaceFeature)
                .where(SurFaceFeature.id.in_(claimed_ids))
                .values(
                    status=FeatureStatusEnum.PROCESSING,
                    start_time=datetime.now()
                )
            )

        logger.info(f"Marked {len(claimed)} face features for processing")
        return claimed
    except SQLAlchemyError as exc:
        logger.error(f"Database error marking features for processing: {exc}")
        self.session.rollback()
        raise
def cleanup_old_features(self, days: int = 30) -> int:
    """
    Delete finished feature records older than a retention window.

    Only rows whose status is SUCCESS or FAILED and whose created_time is
    earlier than (now - days) are removed.

    Args:
        days: Number of days of records to keep.

    Returns:
        The row count reported by the DELETE statement.

    Raises:
        SQLAlchemyError: On database errors (logged, session rolled back,
            then re-raised).
    """
    try:
        threshold = datetime.now() - timedelta(days=days)
        finished_states = [FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED]
        # Chained where() calls are AND-ed together.
        purge = (
            delete(SurFaceFeature)
            .where(SurFaceFeature.created_time < threshold)
            .where(SurFaceFeature.status.in_(finished_states))
        )
        removed_rows = self.session.execute(purge).rowcount
        logger.info(f"Cleaned up {removed_rows} old face features (older than {days} days)")
        return removed_rows
    except SQLAlchemyError as exc:
        logger.error(f"Database error cleaning up old features: {exc}")
        self.session.rollback()
        raise
# ===== 新增方法:支持算法路由 =====
def get_features_by_type_and_status(
    self,
    feature_type: int,
    status: Optional[int] = None
) -> List[SurFaceFeature]:
    """
    List feature records of one type, optionally narrowed by status.

    Args:
        feature_type: Feature type to match.
        status: Status to match; when None, no status filter is applied.

    Returns:
        Matching SurFaceFeature instances ordered by created_time
        ascending.

    Raises:
        SQLAlchemyError: On database errors (logged, then re-raised).
    """
    try:
        listing = select(SurFaceFeature).where(SurFaceFeature.feature_type == feature_type)
        if status is not None:
            # A second where() is AND-ed with the first.
            listing = listing.where(SurFaceFeature.status == status)
        listing = listing.order_by(SurFaceFeature.created_time.asc())
        rows = list(self.session.execute(listing).scalars().all())
        logger.debug(f"Retrieved {len(rows)} features by type={feature_type}, status={status}")
        return rows
    except SQLAlchemyError as exc:
        logger.error(f"Database error getting features by type and status: {exc}")
        raise
def count_by_type_and_status(
    self,
    feature_type: int,
    status: Optional[int] = None
) -> int:
    """
    Count feature records of one type, optionally narrowed by status.

    Args:
        feature_type: Feature type to match.
        status: Status to match; when None, no status filter is applied.

    Returns:
        Number of matching records.

    Raises:
        SQLAlchemyError: On database errors (logged, then re-raised).
    """
    try:
        counter = (
            select(func.count())
            .select_from(SurFaceFeature)
            .where(SurFaceFeature.feature_type == feature_type)
        )
        if status is not None:
            # A second where() is AND-ed with the first.
            counter = counter.where(SurFaceFeature.status == status)
        return self.session.execute(counter).scalar_one()
    except SQLAlchemyError as exc:
        logger.error(f"Database error counting features by type and status: {exc}")
        raise
def get_statistics(self) -> Dict[str, Any]:
    """
    Return statistics (compatibility alias).

    Delegates to ``get_stats()`` and returns its result unchanged.

    Returns:
        The statistics dict produced by ``get_stats()``.
    """
    stats = self.get_stats()
    return stats