修改路径,从src放到根目录

This commit is contained in:
zqc
2026-01-08 10:32:36 +08:00
parent 96589ebdbd
commit f86effd63c
37 changed files with 51 additions and 410 deletions

View File

@@ -0,0 +1,676 @@
"""
人脸特征数据仓库
数据访问层,处理所有数据库操作
"""
import logging
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any, Tuple
from contextlib import contextmanager
from sqlalchemy import select, update, delete, func, and_, or_, desc, asc
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.exc import SQLAlchemyError, IntegrityError
from models.face_feature import SurFaceFeature
from schemas.face_feature import (
FaceFeatureCreate,
FaceFeatureUpdate,
FaceFeatureQuery,
FeatureStatus
)
from models.face_feature import FeatureStatusEnum
from utils.logger import setup_logger
logger = setup_logger(__name__)
class FaceFeatureRepository:
    """Repository (data-access layer) for face feature records.

    Every method issues its database work through the SQLAlchemy session
    handed to the constructor.
    """

    def __init__(self, session: Session):
        """Bind the repository to a database session.

        Args:
            session: SQLAlchemy session object used by all repository methods.
        """
        self.session = session
# ===== 创建操作 =====
def create(self, feature_data: FaceFeatureCreate) -> SurFaceFeature:
    """Create a single face feature record.

    The new row is added to the session and flushed so the INSERT is sent
    to the database immediately; the transaction itself is NOT committed
    here.

    Args:
        feature_data: Payload for the new record. Only fields that were
            explicitly set on the schema object are passed to the model.

    Returns:
        The newly created SurFaceFeature object.

    Raises:
        ValueError: If the flush fails with an IntegrityError. The session
            is rolled back first and the original IntegrityError is kept
            as ``__cause__``.
        SQLAlchemyError: Any other database error, re-raised unchanged
            after the session has been rolled back.
    """
    try:
        # Convert the schema object into keyword arguments for the model.
        feature_dict = feature_data.model_dump(exclude_unset=True)
        feature = SurFaceFeature(**feature_dict)
        self.session.add(feature)
        # Execute the INSERT now, but leave committing to the caller.
        self.session.flush()
        logger.info(f"Created face feature record: id={feature.id}, person_id={feature.person_id}")
        return feature
    except IntegrityError as e:
        logger.error(f"Integrity error creating face feature: {e}")
        self.session.rollback()
        # Chain the cause so the violated constraint stays visible in tracebacks.
        raise ValueError(f"Duplicate feature record for person_id={feature_data.person_id}, "
                         f"feature_type={feature_data.feature_type}") from e
    except SQLAlchemyError as e:
        logger.error(f"Database error creating face feature: {e}")
        self.session.rollback()
        raise
def create_batch(self, features_data: List[FaceFeatureCreate]) -> List[SurFaceFeature]:
    """Create several face feature records in one flush.

    All records are added to the session together and flushed once; the
    transaction is not committed here.

    Args:
        features_data: Payloads for the records to create. Only fields that
            were explicitly set on each schema object are passed to the model.

    Returns:
        The list of created SurFaceFeature objects, in input order.

    Raises:
        ValueError: If the flush fails with an IntegrityError. The session
            is rolled back first and the original IntegrityError is kept
            as ``__cause__``.
        SQLAlchemyError: Any other database error, re-raised unchanged
            after the session has been rolled back.
    """
    try:
        features = [
            SurFaceFeature(**feature_data.model_dump(exclude_unset=True))
            for feature_data in features_data
        ]
        self.session.add_all(features)
        self.session.flush()
        logger.info(f"Created {len(features)} face feature records in batch")
        return features
    except IntegrityError as e:
        logger.error(f"Integrity error creating batch face features: {e}")
        self.session.rollback()
        # Chain the cause so the violated constraint stays visible in tracebacks.
        raise ValueError("Duplicate feature record in batch") from e
    except SQLAlchemyError as e:
        logger.error(f"Database error creating batch face features: {e}")
        self.session.rollback()
        raise
# ===== 查询操作 =====
def get_by_id(self, feature_id: int) -> Optional[SurFaceFeature]:
    """Look up one feature record by its ID.

    Args:
        feature_id: ID of the feature record.

    Returns:
        The matching SurFaceFeature object, or None when no row matches.

    Raises:
        SQLAlchemyError: Database errors are logged and re-raised.
    """
    try:
        lookup = select(SurFaceFeature).where(SurFaceFeature.id == feature_id)
        record = self.session.execute(lookup).scalar_one_or_none()
        if not record:
            logger.debug(f"Face feature not found by id: {feature_id}")
            return record
        logger.debug(f"Retrieved face feature by id: {feature_id}")
        return record
    except SQLAlchemyError as e:
        logger.error(f"Database error getting face feature by id: {e}")
        raise
def get_by_person_and_type(self, person_id: int, feature_type: int) -> Optional[SurFaceFeature]:
    """Look up the feature record for a person / feature-type pair.

    Args:
        person_id: ID of the person.
        feature_type: Feature type to match.

    Returns:
        The matching SurFaceFeature object, or None when no row matches.

    Raises:
        SQLAlchemyError: Database errors are logged and re-raised.
    """
    try:
        same_person = SurFaceFeature.person_id == person_id
        same_type = SurFaceFeature.feature_type == feature_type
        lookup = select(SurFaceFeature).where(and_(same_person, same_type))
        record = self.session.execute(lookup).scalar_one_or_none()
        if not record:
            logger.debug(f"Face feature not found: person_id={person_id}, feature_type={feature_type}")
            return record
        logger.debug(f"Retrieved face feature: person_id={person_id}, feature_type={feature_type}")
        return record
    except SQLAlchemyError as e:
        logger.error(f"Database error getting face feature by person and type: {e}")
        raise
def get_by_person(self, person_id: int, limit: int = 100) -> List[SurFaceFeature]:
    """List the feature records that belong to one person.

    Rows are ordered by ``created_time`` descending and capped at ``limit``.

    Args:
        person_id: ID of the person.
        limit: Maximum number of rows to return.

    Returns:
        A list of SurFaceFeature objects (possibly empty).

    Raises:
        SQLAlchemyError: Database errors are logged and re-raised.
    """
    try:
        newest_first = desc(SurFaceFeature.created_time)
        listing = select(SurFaceFeature).where(SurFaceFeature.person_id == person_id)
        listing = listing.order_by(newest_first).limit(limit)
        rows = list(self.session.execute(listing).scalars().all())
        logger.debug(f"Retrieved {len(rows)} face features for person_id={person_id}")
        return rows
    except SQLAlchemyError as e:
        logger.error(f"Database error getting face features by person: {e}")
        raise
def query_features(
    self,
    query: FaceFeatureQuery,
    page: int = 1,
    page_size: int = 20,
    order_by: str = "created_time",
    desc_order: bool = True
) -> Tuple[List[SurFaceFeature], int]:
    """Query feature records with filtering, sorting and pagination.

    Only filter fields that are set and not None on ``query`` are applied.
    Recognised fields: ``person_id``, ``feature_type``, ``status``,
    ``start_date`` / ``end_date`` (inclusive bounds on ``created_time``)
    and ``has_feature_data`` (whether ``feature_data`` is NULL).

    Args:
        query: Filter conditions.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Rows per page; negative values are treated as 0.
        order_by: Name of a mapped column attribute of SurFaceFeature to
            sort by. Any other name falls back to ``created_time``.
        desc_order: Sort descending when True, ascending otherwise.

    Returns:
        A tuple ``(records, total)`` where ``total`` is the number of rows
        matching the filters before pagination.

    Raises:
        SQLAlchemyError: Database errors are logged and re-raised.
    """
    # Function-scope import: only this method needs mapper introspection.
    from sqlalchemy import inspect as sa_inspect

    # Clamp paging inputs so OFFSET / LIMIT can never become negative.
    page = max(page, 1)
    page_size = max(page_size, 0)
    try:
        conditions = []
        query_dict = query.model_dump(exclude_unset=True, exclude_none=True)
        if "person_id" in query_dict:
            conditions.append(SurFaceFeature.person_id == query_dict["person_id"])
        if "feature_type" in query_dict:
            conditions.append(SurFaceFeature.feature_type == query_dict["feature_type"])
        if "status" in query_dict:
            conditions.append(SurFaceFeature.status == query_dict["status"])
        if "start_date" in query_dict:
            conditions.append(SurFaceFeature.created_time >= query_dict["start_date"])
        if "end_date" in query_dict:
            conditions.append(SurFaceFeature.created_time <= query_dict["end_date"])
        if "has_feature_data" in query_dict:
            if query_dict["has_feature_data"]:
                conditions.append(SurFaceFeature.feature_data.isnot(None))
            else:
                conditions.append(SurFaceFeature.feature_data.is_(None))

        stmt = select(SurFaceFeature)
        if conditions:
            stmt = stmt.where(and_(*conditions))

        # Count the filtered rows before ordering / paging is applied.
        count_stmt = select(func.count()).select_from(stmt.subquery())
        total = self.session.execute(count_stmt).scalar_one()

        # Accept only real mapped columns: a bare getattr() would also return
        # methods, relationships or metadata, which are not valid ORDER BY
        # targets and may come from caller-controlled input.
        if order_by in sa_inspect(SurFaceFeature).column_attrs:
            order_column = getattr(SurFaceFeature, order_by)
        else:
            order_by = "created_time"
            order_column = SurFaceFeature.created_time
        direction = desc if desc_order else asc
        ordering = [direction(order_column)]
        if order_by != "id":
            # Tie-breaker so rows with equal sort values keep a stable
            # position from one page to the next.
            ordering.append(direction(SurFaceFeature.id))
        stmt = stmt.order_by(*ordering)

        offset = (page - 1) * page_size
        stmt = stmt.offset(offset).limit(page_size)

        features = list(self.session.execute(stmt).scalars().all())
        logger.debug(f"Query returned {len(features)} features (total: {total})")
        return features, total
    except SQLAlchemyError as e:
        logger.error(f"Database error querying face features: {e}")
        raise
# ===== 更新操作 =====
def update(self, feature_id: int, update_data: FaceFeatureUpdate) -> Optional[SurFaceFeature]:
    """Update an existing feature record.

    Only fields that are explicitly set and not None on ``update_data`` are
    written; this method therefore cannot reset a column to NULL. Changes
    are flushed but not committed.

    Args:
        feature_id: ID of the feature record.
        update_data: New field values.

    Returns:
        The updated SurFaceFeature object, or None if the record does not
        exist.

    Raises:
        ValueError: If the flush fails with an IntegrityError. The session
            is rolled back first and the original IntegrityError is kept
            as ``__cause__``.
        SQLAlchemyError: Any other database error, re-raised unchanged
            after the session has been rolled back.
    """
    try:
        # Load first so a missing record is reported instead of silently ignored.
        feature = self.get_by_id(feature_id)
        if not feature:
            logger.warning(f"Cannot update non-existent face feature: id={feature_id}")
            return None
        update_dict = update_data.model_dump(exclude_unset=True, exclude_none=True)
        for key, value in update_dict.items():
            setattr(feature, key, value)
        # Push the pending changes to the database without committing.
        self.session.flush()
        logger.info(f"Updated face feature: id={feature_id}")
        return feature
    except IntegrityError as e:
        logger.error(f"Integrity error updating face feature: {e}")
        self.session.rollback()
        # Chain the cause so the violated constraint stays visible in tracebacks.
        raise ValueError("Update would create duplicate record") from e
    except SQLAlchemyError as e:
        logger.error(f"Database error updating face feature: {e}")
        self.session.rollback()
        raise
def update_feature_data(self, feature_id: int, feature_data: bytes) -> bool:
    """Overwrite the binary feature data of one record.

    Issues a single UPDATE ... RETURNING statement; nothing is committed
    here.

    Args:
        feature_id: ID of the feature record.
        feature_data: New feature data (binary).

    Returns:
        True if a row was updated, False if no record has that ID.

    Raises:
        SQLAlchemyError: Database errors are logged and re-raised after the
            session has been rolled back.
    """
    try:
        stmt = (
            update(SurFaceFeature)
            .where(SurFaceFeature.id == feature_id)
            .values(feature_data=feature_data)
            .returning(SurFaceFeature.id)
        )
        result = self.session.execute(stmt)
        updated_id = result.scalar_one_or_none()
        # Compare against None explicitly: a returned id of 0 is still a hit.
        if updated_id is not None:
            logger.info(f"Updated feature data for face feature: id={feature_id}")
            return True
        logger.warning(f"Cannot update feature data for non-existent face feature: id={feature_id}")
        return False
    except SQLAlchemyError as e:
        logger.error(f"Database error updating feature data: {e}")
        self.session.rollback()
        raise
def update_status(
    self,
    feature_id: int,
    status: FeatureStatus,
    start_time: Optional[datetime] = None,
    finish_time: Optional[datetime] = None
) -> bool:
    """Update the computation status of one record.

    Issues a single UPDATE ... RETURNING statement; nothing is committed
    here. ``start_time`` / ``finish_time`` are only written when provided.

    Args:
        feature_id: ID of the feature record.
        status: New status. A FeatureStatus member is stored via its
            ``.value``; any other object is stored as given.
        start_time: Processing start time (optional).
        finish_time: Processing finish time (optional).

    Returns:
        True if a row was updated, False if no record has that ID.

    Raises:
        SQLAlchemyError: Database errors are logged and re-raised after the
            session has been rolled back.
    """
    try:
        update_values = {"status": status.value if isinstance(status, FeatureStatus) else status}
        if start_time is not None:
            update_values["start_time"] = start_time
        if finish_time is not None:
            update_values["finish_time"] = finish_time
        stmt = (
            update(SurFaceFeature)
            .where(SurFaceFeature.id == feature_id)
            .values(**update_values)
            .returning(SurFaceFeature.id)
        )
        result = self.session.execute(stmt)
        updated_id = result.scalar_one_or_none()
        # Compare against None explicitly: a returned id of 0 is still a hit.
        if updated_id is not None:
            logger.info(f"Updated status to {status} for face feature: id={feature_id}")
            return True
        logger.warning(f"Cannot update status for non-existent face feature: id={feature_id}")
        return False
    except SQLAlchemyError as e:
        logger.error(f"Database error updating status: {e}")
        self.session.rollback()
        raise
# ===== 删除操作 =====
def delete(self, feature_id: int) -> bool:
    """Delete one feature record by ID.

    The DELETE is executed on the session; nothing is committed here.

    Args:
        feature_id: ID of the feature record.

    Returns:
        True if at least one row was deleted, False otherwise.

    Raises:
        SQLAlchemyError: Database errors are logged and re-raised after the
            session has been rolled back.
    """
    try:
        removal = delete(SurFaceFeature).where(SurFaceFeature.id == feature_id)
        rows_removed = self.session.execute(removal).rowcount
        if not rows_removed > 0:
            logger.warning(f"Cannot delete non-existent face feature: id={feature_id}")
            return False
        logger.info(f"Deleted face feature: id={feature_id}")
        return True
    except SQLAlchemyError as e:
        logger.error(f"Database error deleting face feature: {e}")
        self.session.rollback()
        raise
def delete_by_person(self, person_id: int) -> int:
    """Delete every feature record that belongs to one person.

    The DELETE is executed on the session; nothing is committed here.

    Args:
        person_id: ID of the person.

    Returns:
        The row count reported by the DELETE statement.

    Raises:
        SQLAlchemyError: Database errors are logged and re-raised after the
            session has been rolled back.
    """
    try:
        removal = delete(SurFaceFeature).where(SurFaceFeature.person_id == person_id)
        rows_removed = self.session.execute(removal).rowcount
        logger.info(f"Deleted {rows_removed} face features for person_id={person_id}")
        return rows_removed
    except SQLAlchemyError as e:
        logger.error(f"Database error deleting face features by person: {e}")
        self.session.rollback()
        raise
# ===== 统计操作 =====
def get_stats(self) -> Dict[str, Any]:
    """Collect aggregate statistics over all feature records.

    Returns:
        A dictionary with the keys:

        * ``total_count``: total number of records.
        * ``by_status``: record count per status, keyed by ``str(status)``.
        * ``by_feature_type``: record count per non-NULL feature type,
          keyed by ``str(feature_type)``.
        * ``avg_processing_time``: mean of ``finish_time - start_time`` in
          seconds over records in SUCCESS or FAILED status that have both
          timestamps, as a float; None when no such record exists.

    Raises:
        SQLAlchemyError: Database errors are logged and re-raised.
    """
    # Function-scope import: sqlalchemy.extract renders the standard
    # EXTRACT(field FROM expr) form, unlike the generic func.extract(...)
    # call, which renders a comma-separated argument list.
    from sqlalchemy import extract

    try:
        # Total number of records.
        total_stmt = select(func.count()).select_from(SurFaceFeature)
        total_count = self.session.execute(total_stmt).scalar_one()

        # Count per status.
        status_stmt = (
            select(SurFaceFeature.status, func.count())
            .group_by(SurFaceFeature.status)
        )
        status_result = self.session.execute(status_stmt)
        status_stats = {str(status): count for status, count in status_result}

        # Count per feature type, ignoring rows without a type.
        type_stmt = (
            select(SurFaceFeature.feature_type, func.count())
            .where(SurFaceFeature.feature_type.isnot(None))
            .group_by(SurFaceFeature.feature_type)
        )
        type_result = self.session.execute(type_stmt)
        type_stats = {str(feature_type): count for feature_type, count in type_result}

        # Average processing time, finished (successful or failed) records only.
        # NOTE(review): subtracting timestamps and extracting 'epoch' assumes a
        # PostgreSQL-style interval type - confirm against the target database.
        elapsed_seconds = extract('epoch', SurFaceFeature.finish_time - SurFaceFeature.start_time)
        time_stmt = (
            select(func.avg(elapsed_seconds))
            .where(
                and_(
                    SurFaceFeature.start_time.isnot(None),
                    SurFaceFeature.finish_time.isnot(None),
                    SurFaceFeature.status.in_([FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED])
                )
            )
        )
        avg_time = self.session.execute(time_stmt).scalar_one()

        stats = {
            "total_count": total_count,
            "by_status": status_stats,
            "by_feature_type": type_stats,
            # A genuine average of 0 seconds must stay 0.0, not become None.
            "avg_processing_time": float(avg_time) if avg_time is not None else None
        }
        logger.debug("Generated face feature statistics")
        return stats
    except SQLAlchemyError as e:
        logger.error(f"Database error getting statistics: {e}")
        raise
# ===== 批量操作 =====
def mark_for_processing(self, limit: int = 100) -> List[SurFaceFeature]:
    """Claim pending feature records and flag them as processing.

    Selects up to ``limit`` NOT_STARTED records, oldest ``created_time``
    first, using ``FOR UPDATE`` with ``skip_locked`` so rows locked by
    another transaction are skipped. The selected rows are then updated to
    PROCESSING with ``start_time`` set to the current local time. Nothing
    is committed here.

    Args:
        limit: Maximum number of records to claim.

    Returns:
        The list of selected SurFaceFeature objects (possibly empty).

    Raises:
        SQLAlchemyError: Database errors are logged and re-raised after the
            session has been rolled back.
    """
    try:
        # Step 1: lock and fetch the pending rows.
        claim_query = select(SurFaceFeature).where(
            SurFaceFeature.status == FeatureStatusEnum.NOT_STARTED
        )
        claim_query = claim_query.order_by(SurFaceFeature.created_time).limit(limit)
        claim_query = claim_query.with_for_update(skip_locked=True)
        claimed = list(self.session.execute(claim_query).scalars().all())

        # Step 2: flip the status of exactly those rows.
        claimed_ids = [record.id for record in claimed]
        if claimed_ids:
            self.session.execute(
                update(SurFaceFeature)
                .where(SurFaceFeature.id.in_(claimed_ids))
                .values(
                    status=FeatureStatusEnum.PROCESSING,
                    start_time=datetime.now()
                )
            )
        logger.info(f"Marked {len(claimed)} face features for processing")
        return claimed
    except SQLAlchemyError as e:
        logger.error(f"Database error marking features for processing: {e}")
        self.session.rollback()
        raise
def cleanup_old_features(self, days: int = 30) -> int:
    """Delete old, finished feature records.

    Removes records whose ``created_time`` is more than ``days`` days in
    the past and whose status is SUCCESS or FAILED. Nothing is committed
    here.

    Args:
        days: Retention period in days; must be non-negative.

    Returns:
        The row count reported by the DELETE statement.

    Raises:
        ValueError: If ``days`` is negative.
        SQLAlchemyError: Database errors are logged and re-raised after the
            session has been rolled back.
    """
    # Checked before any database work: a negative retention period would
    # move the cutoff into the future and delete every finished record.
    if days < 0:
        raise ValueError(f"days must be non-negative, got {days}")
    try:
        cutoff_date = datetime.now() - timedelta(days=days)
        # Only finished (successful or failed) records are eligible.
        stmt = delete(SurFaceFeature).where(
            and_(
                SurFaceFeature.created_time < cutoff_date,
                SurFaceFeature.status.in_([FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED])
            )
        )
        result = self.session.execute(stmt)
        deleted_count = result.rowcount
        logger.info(f"Cleaned up {deleted_count} old face features (older than {days} days)")
        return deleted_count
    except SQLAlchemyError as e:
        logger.error(f"Database error cleaning up old features: {e}")
        self.session.rollback()
        raise
# ===== 新增方法:支持算法路由 =====
def get_features_by_type_and_status(
    self,
    feature_type: int,
    status: Optional[int] = None
) -> List[SurFaceFeature]:
    """List feature records of one type, optionally narrowed by status.

    Rows are ordered by ``created_time`` ascending.

    Args:
        feature_type: Feature type to match.
        status: Status to match; no status filter is applied when None.

    Returns:
        A list of SurFaceFeature objects (possibly empty).

    Raises:
        SQLAlchemyError: Database errors are logged and re-raised.
    """
    try:
        type_filter = SurFaceFeature.feature_type == feature_type
        if status is None:
            filters = [type_filter]
        else:
            filters = [type_filter, SurFaceFeature.status == status]
        listing = select(SurFaceFeature).where(and_(*filters))
        listing = listing.order_by(asc(SurFaceFeature.created_time))
        rows = list(self.session.execute(listing).scalars().all())
        logger.debug(f"Retrieved {len(rows)} features by type={feature_type}, status={status}")
        return rows
    except SQLAlchemyError as e:
        logger.error(f"Database error getting features by type and status: {e}")
        raise
def count_by_type_and_status(
    self,
    feature_type: int,
    status: Optional[int] = None
) -> int:
    """Count feature records of one type, optionally narrowed by status.

    Args:
        feature_type: Feature type to match.
        status: Status to match; no status filter is applied when None.

    Returns:
        The number of matching records.

    Raises:
        SQLAlchemyError: Database errors are logged and re-raised.
    """
    try:
        type_filter = SurFaceFeature.feature_type == feature_type
        if status is None:
            filters = [type_filter]
        else:
            filters = [type_filter, SurFaceFeature.status == status]
        counting = select(func.count()).select_from(SurFaceFeature).where(and_(*filters))
        return self.session.execute(counting).scalar_one()
    except SQLAlchemyError as e:
        logger.error(f"Database error counting features by type and status: {e}")
        raise
def get_statistics(self) -> Dict[str, Any]:
    """Return the statistics dictionary (compatibility alias).

    Delegates directly to ``get_stats``.

    Returns:
        Whatever ``get_stats`` returns.
    """
    stats = self.get_stats()
    return stats