引入http服务,引入数据库
This commit is contained in:
67
src/api/dependencies.py
Normal file
67
src/api/dependencies.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""
|
||||||
|
FastAPI依赖注入模块
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Generator, Optional
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from src.database.connection import db_manager, get_db
|
||||||
|
from src.repositories.face_feature_repository import FaceFeatureRepository
|
||||||
|
from src.services.face_feature_service import FaceFeatureService
|
||||||
|
|
||||||
|
|
||||||
|
def get_face_feature_repository(
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
) -> FaceFeatureRepository:
|
||||||
|
"""
|
||||||
|
获取人脸特征仓库依赖
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FaceFeatureRepository实例
|
||||||
|
"""
|
||||||
|
return FaceFeatureRepository(db)
|
||||||
|
|
||||||
|
|
||||||
|
def get_face_feature_service(
|
||||||
|
repository: FaceFeatureRepository = Depends(get_face_feature_repository)
|
||||||
|
) -> FaceFeatureService:
|
||||||
|
"""
|
||||||
|
获取人脸特征服务依赖
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repository: 人脸特征仓库
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FaceFeatureService实例
|
||||||
|
"""
|
||||||
|
return FaceFeatureService(repository)
|
||||||
|
|
||||||
|
|
||||||
|
def get_face_feature_by_id(
|
||||||
|
feature_id: int,
|
||||||
|
service: FaceFeatureService = Depends(get_face_feature_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
根据ID获取人脸特征记录的依赖
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_id: 特征记录ID
|
||||||
|
service: 人脸特征服务
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
人脸特征记录
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 如果记录不存在
|
||||||
|
"""
|
||||||
|
feature = service.get_feature(feature_id)
|
||||||
|
if not feature:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"特征记录不存在 (ID: {feature_id})"
|
||||||
|
)
|
||||||
|
return feature
|
||||||
152
src/api/errors.py
Normal file
152
src/api/errors.py
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
"""
|
||||||
|
API错误处理模块
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.requests import Request
|
||||||
|
|
||||||
|
|
||||||
|
class APIError(HTTPException):
|
||||||
|
"""API自定义错误"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
status_code: int = status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail: Any = None,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
error_code: Optional[str] = None,
|
||||||
|
):
|
||||||
|
super().__init__(status_code=status_code, detail=detail, headers=headers)
|
||||||
|
self.error_code = error_code or f"ERR_{status_code}"
|
||||||
|
|
||||||
|
|
||||||
|
class FaceFeatureProcessingError(APIError):
|
||||||
|
"""人脸特征处理错误"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
detail: str = "人脸特征处理失败",
|
||||||
|
feature_id: Optional[int] = None,
|
||||||
|
):
|
||||||
|
if feature_id:
|
||||||
|
detail = f"人脸特征处理失败 (特征ID: {feature_id}): {detail}"
|
||||||
|
super().__init__(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=detail,
|
||||||
|
error_code="FACE_FEATURE_PROCESSING_ERROR"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureNotFoundError(APIError):
|
||||||
|
"""特征记录不存在错误"""
|
||||||
|
|
||||||
|
def __init__(self, feature_id: int):
|
||||||
|
super().__init__(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"特征记录不存在 (ID: {feature_id})",
|
||||||
|
error_code="FEATURE_NOT_FOUND"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DuplicateFeatureError(APIError):
|
||||||
|
"""重复特征记录错误"""
|
||||||
|
|
||||||
|
def __init__(self, person_id: int, feature_type: int):
|
||||||
|
super().__init__(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=f"特征记录已存在 (人员ID: {person_id}, 特征类型: {feature_type})",
|
||||||
|
error_code="DUPLICATE_FEATURE"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def validation_exception_handler(
|
||||||
|
request: Request,
|
||||||
|
exc: RequestValidationError
|
||||||
|
) -> JSONResponse:
|
||||||
|
"""
|
||||||
|
请求验证异常处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 请求对象
|
||||||
|
exc: 验证异常
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON响应
|
||||||
|
"""
|
||||||
|
errors = []
|
||||||
|
for error in exc.errors():
|
||||||
|
field = ".".join(str(loc) for loc in error["loc"] if loc != "body")
|
||||||
|
errors.append({
|
||||||
|
"field": field or "body",
|
||||||
|
"message": error["msg"],
|
||||||
|
"type": error["type"]
|
||||||
|
})
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"code": "VALIDATION_ERROR",
|
||||||
|
"message": "请求参数验证失败",
|
||||||
|
"details": errors
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def api_error_handler(
|
||||||
|
request: Request,
|
||||||
|
exc: APIError
|
||||||
|
) -> JSONResponse:
|
||||||
|
"""
|
||||||
|
API错误处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 请求对象
|
||||||
|
exc: API错误
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON响应
|
||||||
|
"""
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"code": exc.error_code,
|
||||||
|
"message": exc.detail
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def generic_exception_handler(
|
||||||
|
request: Request,
|
||||||
|
exc: Exception
|
||||||
|
) -> JSONResponse:
|
||||||
|
"""
|
||||||
|
通用异常处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 请求对象
|
||||||
|
exc: 异常
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON响应
|
||||||
|
"""
|
||||||
|
# 记录异常到日志
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.error(f"未处理的异常: {exc}", exc_info=True)
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"code": "INTERNAL_SERVER_ERROR",
|
||||||
|
"message": "服务器内部错误"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
417
src/api/routes/face_features.py
Normal file
417
src/api/routes/face_features.py
Normal file
@@ -0,0 +1,417 @@
|
|||||||
|
"""
|
||||||
|
人脸特征API路由
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Query, BackgroundTasks
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from src.schemas.face_feature import (
|
||||||
|
FaceFeatureCreate,
|
||||||
|
FaceFeatureUpdate,
|
||||||
|
FaceFeatureQuery,
|
||||||
|
FaceFeatureResponse,
|
||||||
|
FaceFeatureListResponse,
|
||||||
|
FaceFeatureStatsResponse,
|
||||||
|
BatchFaceFeatureCreate,
|
||||||
|
FeatureStatus
|
||||||
|
)
|
||||||
|
from src.api.dependencies import (
|
||||||
|
get_face_feature_service,
|
||||||
|
get_face_feature_by_id
|
||||||
|
)
|
||||||
|
from src.services.face_feature_service import FaceFeatureService
|
||||||
|
from src.api.errors import (
|
||||||
|
FaceFeatureProcessingError,
|
||||||
|
FeatureNotFoundError,
|
||||||
|
DuplicateFeatureError
|
||||||
|
)
|
||||||
|
from src.config import settings
|
||||||
|
|
||||||
|
# 创建路由器
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/face-features",
|
||||||
|
tags=["人脸特征管理"],
|
||||||
|
responses={
|
||||||
|
404: {"description": "资源不存在"},
|
||||||
|
400: {"description": "请求参数错误"},
|
||||||
|
500: {"description": "服务器内部错误"}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/",
|
||||||
|
response_model=FaceFeatureResponse,
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
summary="创建人脸特征记录",
|
||||||
|
description="创建新的人脸特征记录"
|
||||||
|
)
|
||||||
|
async def create_face_feature(
|
||||||
|
feature_data: FaceFeatureCreate,
|
||||||
|
service: FaceFeatureService = Depends(get_face_feature_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
创建人脸特征记录
|
||||||
|
|
||||||
|
- **person_id**: 人员ID (必须,大于0)
|
||||||
|
- **feature_type**: 特征类型 (可选,大于等于0)
|
||||||
|
- **pic_id**: 图片ID (可选)
|
||||||
|
- **status**: 计算状态 (默认: NOT_STARTED)
|
||||||
|
- **feature_data**: 特征数据 (可选,二进制)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return service.create_feature(feature_data)
|
||||||
|
except ValueError as e:
|
||||||
|
if "already exists" in str(e):
|
||||||
|
# 解析错误信息中的person_id和feature_type
|
||||||
|
raise DuplicateFeatureError(
|
||||||
|
person_id=feature_data.person_id,
|
||||||
|
feature_type=feature_data.feature_type or 0
|
||||||
|
)
|
||||||
|
raise FaceFeatureProcessingError(detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/{feature_id}",
|
||||||
|
response_model=FaceFeatureResponse,
|
||||||
|
summary="获取人脸特征记录",
|
||||||
|
description="根据ID获取人脸特征记录"
|
||||||
|
)
|
||||||
|
async def get_face_feature(
|
||||||
|
feature: FaceFeatureResponse = Depends(get_face_feature_by_id)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
根据ID获取人脸特征记录
|
||||||
|
|
||||||
|
- **feature_id**: 特征记录ID (路径参数)
|
||||||
|
"""
|
||||||
|
return feature
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/",
|
||||||
|
response_model=FaceFeatureListResponse,
|
||||||
|
summary="查询人脸特征记录",
|
||||||
|
description="查询人脸特征记录列表,支持分页和过滤"
|
||||||
|
)
|
||||||
|
async def list_face_features(
|
||||||
|
person_id: Optional[int] = Query(None, description="人员ID", gt=0),
|
||||||
|
feature_type: Optional[int] = Query(None, description="特征类型", ge=0),
|
||||||
|
status: Optional[FeatureStatus] = Query(None, description="计算状态"),
|
||||||
|
start_date: Optional[datetime] = Query(None, description="开始时间"),
|
||||||
|
end_date: Optional[datetime] = Query(None, description="结束时间"),
|
||||||
|
has_feature_data: Optional[bool] = Query(None, description="是否有特征数据"),
|
||||||
|
page: int = Query(1, description="页码", ge=1),
|
||||||
|
page_size: int = Query(20, description="每页数量", ge=1, le=100),
|
||||||
|
service: FaceFeatureService = Depends(get_face_feature_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
查询人脸特征记录
|
||||||
|
|
||||||
|
- **person_id**: 按人员ID过滤 (可选)
|
||||||
|
- **feature_type**: 按特征类型过滤 (可选)
|
||||||
|
- **status**: 按计算状态过滤 (可选)
|
||||||
|
- **start_date**: 开始时间过滤 (可选)
|
||||||
|
- **end_date**: 结束时间过滤 (可选)
|
||||||
|
- **has_feature_data**: 是否有特征数据过滤 (可选)
|
||||||
|
- **page**: 页码 (默认: 1)
|
||||||
|
- **page_size**: 每页数量 (默认: 20, 最大: 100)
|
||||||
|
"""
|
||||||
|
# 构建查询参数
|
||||||
|
query = FaceFeatureQuery(
|
||||||
|
person_id=person_id,
|
||||||
|
feature_type=feature_type,
|
||||||
|
status=status,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
has_feature_data=has_feature_data
|
||||||
|
)
|
||||||
|
|
||||||
|
return service.query_features(
|
||||||
|
query=query,
|
||||||
|
page=page,
|
||||||
|
page_size=page_size,
|
||||||
|
order_by="created_time",
|
||||||
|
desc_order=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put(
|
||||||
|
"/{feature_id}",
|
||||||
|
response_model=FaceFeatureResponse,
|
||||||
|
summary="更新人脸特征记录",
|
||||||
|
description="更新指定ID的人脸特征记录"
|
||||||
|
)
|
||||||
|
async def update_face_feature(
|
||||||
|
feature_id: int,
|
||||||
|
update_data: FaceFeatureUpdate,
|
||||||
|
service: FaceFeatureService = Depends(get_face_feature_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
更新人脸特征记录
|
||||||
|
|
||||||
|
- **feature_id**: 特征记录ID (路径参数)
|
||||||
|
- **update_data**: 更新数据 (请求体)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = service.update_feature(feature_id, update_data)
|
||||||
|
if not result:
|
||||||
|
raise FeatureNotFoundError(feature_id)
|
||||||
|
return result
|
||||||
|
except ValueError as e:
|
||||||
|
raise FaceFeatureProcessingError(detail=str(e), feature_id=feature_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
|
||||||
|
"/{feature_id}",
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
summary="删除人脸特征记录",
|
||||||
|
description="删除指定ID的人脸特征记录"
|
||||||
|
)
|
||||||
|
async def delete_face_feature(
|
||||||
|
feature_id: int,
|
||||||
|
service: FaceFeatureService = Depends(get_face_feature_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
删除人脸特征记录
|
||||||
|
|
||||||
|
- **feature_id**: 特征记录ID (路径参数)
|
||||||
|
"""
|
||||||
|
success = service.delete_feature(feature_id)
|
||||||
|
if not success:
|
||||||
|
raise FeatureNotFoundError(feature_id)
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
content=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/{feature_id}/start-processing",
|
||||||
|
response_model=FaceFeatureResponse,
|
||||||
|
summary="开始处理人脸特征",
|
||||||
|
description="开始计算指定ID的人脸特征值"
|
||||||
|
)
|
||||||
|
async def start_face_feature_processing(
|
||||||
|
feature_id: int,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
service: FaceFeatureService = Depends(get_face_feature_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
开始处理人脸特征计算
|
||||||
|
|
||||||
|
- **feature_id**: 特征记录ID (路径参数)
|
||||||
|
|
||||||
|
注意:这是一个异步处理接口,会立即返回开始状态,
|
||||||
|
实际特征计算可能在后台进行。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 先获取特征记录
|
||||||
|
feature = service.get_feature(feature_id)
|
||||||
|
if not feature:
|
||||||
|
raise FeatureNotFoundError(feature_id)
|
||||||
|
|
||||||
|
# 检查是否可以开始处理
|
||||||
|
if feature.status != FeatureStatus.NOT_STARTED:
|
||||||
|
raise FaceFeatureProcessingError(
|
||||||
|
detail=f"特征记录状态为 {feature.status_name},无法开始处理",
|
||||||
|
feature_id=feature_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 开始处理
|
||||||
|
success = service.start_processing(feature_id)
|
||||||
|
if not success:
|
||||||
|
raise FaceFeatureProcessingError(
|
||||||
|
detail="开始处理失败",
|
||||||
|
feature_id=feature_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 异步任务:模拟特征计算过程
|
||||||
|
# 在实际应用中,这里应该调用实际的特征计算服务
|
||||||
|
background_tasks.add_task(
|
||||||
|
simulate_feature_processing,
|
||||||
|
feature_id=feature_id,
|
||||||
|
service=service
|
||||||
|
)
|
||||||
|
|
||||||
|
# 返回更新后的特征记录
|
||||||
|
return service.get_feature(feature_id)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
raise FaceFeatureProcessingError(detail=str(e), feature_id=feature_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/{feature_id}/finish-processing",
|
||||||
|
response_model=FaceFeatureResponse,
|
||||||
|
summary="完成人脸特征处理",
|
||||||
|
description="完成指定ID的人脸特征值计算"
|
||||||
|
)
|
||||||
|
async def finish_face_feature_processing(
|
||||||
|
feature_id: int,
|
||||||
|
success: bool = Query(True, description="是否成功完成"),
|
||||||
|
service: FaceFeatureService = Depends(get_face_feature_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
完成人脸特征计算
|
||||||
|
|
||||||
|
- **feature_id**: 特征记录ID (路径参数)
|
||||||
|
- **success**: 是否成功完成 (查询参数,默认: true)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 检查特征记录
|
||||||
|
feature = service.get_feature(feature_id)
|
||||||
|
if not feature:
|
||||||
|
raise FeatureNotFoundError(feature_id)
|
||||||
|
|
||||||
|
# 检查是否可以完成处理
|
||||||
|
if feature.status != FeatureStatus.PROCESSING:
|
||||||
|
raise FaceFeatureProcessingError(
|
||||||
|
detail=f"特征记录状态为 {feature.status_name},无法完成处理",
|
||||||
|
feature_id=feature_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 完成处理
|
||||||
|
finish_success = service.finish_processing(feature_id, success)
|
||||||
|
if not finish_success:
|
||||||
|
raise FaceFeatureProcessingError(
|
||||||
|
detail="完成处理失败",
|
||||||
|
feature_id=feature_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return service.get_feature(feature_id)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
raise FaceFeatureProcessingError(detail=str(e), feature_id=feature_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/batch",
|
||||||
|
response_model=List[FaceFeatureResponse],
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
summary="批量创建人脸特征记录",
|
||||||
|
description="批量创建多个人脸特征记录"
|
||||||
|
)
|
||||||
|
async def batch_create_face_features(
|
||||||
|
batch_data: BatchFaceFeatureCreate,
|
||||||
|
service: FaceFeatureService = Depends(get_face_feature_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
批量创建人脸特征记录
|
||||||
|
|
||||||
|
- **items**: 特征记录列表 (必须,1-1000条)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return service.create_features_batch(batch_data)
|
||||||
|
except ValueError as e:
|
||||||
|
raise FaceFeatureProcessingError(detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/person/{person_id}",
|
||||||
|
response_model=List[FaceFeatureResponse],
|
||||||
|
summary="获取人员的人脸特征记录",
|
||||||
|
description="根据人员ID获取所有相关的人脸特征记录"
|
||||||
|
)
|
||||||
|
async def get_face_features_by_person(
|
||||||
|
person_id: int,
|
||||||
|
limit: int = Query(100, description="返回数量限制", ge=1, le=1000),
|
||||||
|
service: FaceFeatureService = Depends(get_face_feature_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取人员的人脸特征记录
|
||||||
|
|
||||||
|
- **person_id**: 人员ID (路径参数)
|
||||||
|
- **limit**: 返回数量限制 (查询参数,默认: 100, 最大: 1000)
|
||||||
|
"""
|
||||||
|
return service.list_features_by_person(person_id, limit)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/stats/summary",
|
||||||
|
response_model=FaceFeatureStatsResponse,
|
||||||
|
summary="获取特征记录统计信息",
|
||||||
|
description="获取人脸特征记录的统计摘要"
|
||||||
|
)
|
||||||
|
async def get_face_features_stats(
|
||||||
|
service: FaceFeatureService = Depends(get_face_feature_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取特征记录统计信息
|
||||||
|
"""
|
||||||
|
return service.get_statistics()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/person/{person_id}/stats",
|
||||||
|
summary="获取人员特征统计信息",
|
||||||
|
description="获取指定人员的特征记录统计信息"
|
||||||
|
)
|
||||||
|
async def get_person_face_features_stats(
|
||||||
|
person_id: int,
|
||||||
|
service: FaceFeatureService = Depends(get_face_feature_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取人员特征统计信息
|
||||||
|
|
||||||
|
- **person_id**: 人员ID (路径参数)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
stats = service.get_person_statistics(person_id)
|
||||||
|
return {
|
||||||
|
"person_id": person_id,
|
||||||
|
"total_features": stats["total_features"],
|
||||||
|
"status_summary": stats["status_summary"],
|
||||||
|
"feature_types": stats["feature_types"],
|
||||||
|
"avg_processing_time": stats["avg_processing_time"],
|
||||||
|
"successful_count": stats["successful_count"]
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
raise FaceFeatureProcessingError(detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
async def simulate_feature_processing(
|
||||||
|
feature_id: int,
|
||||||
|
service: FaceFeatureService
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
模拟人脸特征计算过程
|
||||||
|
|
||||||
|
在实际应用中,这里应该调用实际的特征计算算法
|
||||||
|
例如:使用InsightFace、OpenCV等库进行人脸特征提取
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_id: 特征记录ID
|
||||||
|
service: 人脸特征服务
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import random
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 模拟计算延迟 (3-10秒)
|
||||||
|
delay = random.uniform(3, 10)
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
# 模拟成功或失败 (90%成功率)
|
||||||
|
success = random.random() < 0.9
|
||||||
|
|
||||||
|
# 完成处理
|
||||||
|
service.finish_processing(feature_id, success)
|
||||||
|
|
||||||
|
# 如果成功,添加模拟的特征数据
|
||||||
|
if success:
|
||||||
|
# 生成模拟的512维特征向量 (float32)
|
||||||
|
import numpy as np
|
||||||
|
feature_data = np.random.randn(512).astype(np.float32).tobytes()
|
||||||
|
service.update_feature_data(feature_id, feature_data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 如果发生异常,标记为失败
|
||||||
|
service.finish_processing(feature_id, False)
|
||||||
|
# 记录日志
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.error(f"特征计算失败 (ID: {feature_id}): {e}")
|
||||||
188
src/app.py
Normal file
188
src/app.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
"""
|
||||||
|
FastAPI主应用
|
||||||
|
将原来的main.py重命名为app.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.openapi.docs import (
|
||||||
|
get_swagger_ui_html,
|
||||||
|
get_swagger_ui_oauth2_redirect_html,
|
||||||
|
get_redoc_html,
|
||||||
|
)
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
|
from src.api.routes import face_features
|
||||||
|
from src.api.errors import (
|
||||||
|
APIError,
|
||||||
|
validation_exception_handler,
|
||||||
|
api_error_handler,
|
||||||
|
generic_exception_handler
|
||||||
|
)
|
||||||
|
from src.config import settings
|
||||||
|
from src.database.connection import init_database
|
||||||
|
from src.database.connection import db_manager
|
||||||
|
|
||||||
|
|
||||||
|
# 生命周期管理
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""
|
||||||
|
应用生命周期管理
|
||||||
|
|
||||||
|
- 启动时:初始化数据库
|
||||||
|
- 关闭时:清理资源
|
||||||
|
"""
|
||||||
|
# 启动时
|
||||||
|
print("🚀 start algorithm service...")
|
||||||
|
print(f"📊 db: {settings.DATABASE_NAME}")
|
||||||
|
print(f"🔧 debug mode: {settings.DEBUG}")
|
||||||
|
|
||||||
|
# 初始化数据库
|
||||||
|
init_database()
|
||||||
|
|
||||||
|
# 数据库健康检查
|
||||||
|
if db_manager.health_check():
|
||||||
|
print("✅ 数据库连接正常")
|
||||||
|
else:
|
||||||
|
print("❌ 数据库连接失败")
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# 关闭时
|
||||||
|
print("🛑 algorithm service stopped...")
|
||||||
|
db_manager.close()
|
||||||
|
|
||||||
|
|
||||||
|
# 创建FastAPI应用
|
||||||
|
app = FastAPI(
|
||||||
|
title=settings.PROJECT_NAME,
|
||||||
|
version=settings.PROJECT_VERSION,
|
||||||
|
description=settings.PROJECT_DESCRIPTION,
|
||||||
|
openapi_url=f"{settings.API_V1_PREFIX}/openapi.json",
|
||||||
|
docs_url=None, # 自定义docs
|
||||||
|
redoc_url=None, # 自定义redoc
|
||||||
|
lifespan=lifespan
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 自定义API文档页面
|
||||||
|
@app.get("/docs", include_in_schema=False)
|
||||||
|
async def custom_swagger_ui_html():
|
||||||
|
return get_swagger_ui_html(
|
||||||
|
openapi_url=app.openapi_url,
|
||||||
|
title=f"{app.title} - Swagger UI",
|
||||||
|
oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
|
||||||
|
swagger_js_url="https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js",
|
||||||
|
swagger_css_url="https://unpkg.com/swagger-ui-dist@5/swagger-ui.css",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get(app.swagger_ui_oauth2_redirect_url, include_in_schema=False)
|
||||||
|
async def swagger_ui_redirect():
|
||||||
|
return get_swagger_ui_oauth2_redirect_html()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/redoc", include_in_schema=False)
|
||||||
|
async def redoc_html():
|
||||||
|
return get_redoc_html(
|
||||||
|
openapi_url=app.openapi_url,
|
||||||
|
title=f"{app.title} - ReDoc",
|
||||||
|
redoc_js_url="https://unpkg.com/redoc@next/bundles/redoc.standalone.js",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 中间件配置
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=settings.BACKEND_CORS_ORIGINS,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
TrustedHostMiddleware,
|
||||||
|
allowed_hosts=["*"] if settings.DEBUG else ["localhost", "127.0.0.1"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 请求计时中间件
|
||||||
|
@app.middleware("http")
|
||||||
|
async def add_process_time_header(request: Request, call_next):
|
||||||
|
"""
|
||||||
|
添加请求处理时间头
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
response = await call_next(request)
|
||||||
|
process_time = time.time() - start_time
|
||||||
|
response.headers["X-Process-Time"] = str(process_time)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
# 异常处理器
|
||||||
|
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||||
|
app.add_exception_handler(APIError, api_error_handler)
|
||||||
|
app.add_exception_handler(Exception, generic_exception_handler)
|
||||||
|
|
||||||
|
|
||||||
|
# 根路由
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
"""
|
||||||
|
根路径
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"message": "algorithm service",
|
||||||
|
"version": settings.PROJECT_VERSION,
|
||||||
|
"docs": "/docs",
|
||||||
|
"api_prefix": settings.API_V1_PREFIX
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
"""
|
||||||
|
健康检查端点
|
||||||
|
"""
|
||||||
|
# 检查数据库连接
|
||||||
|
db_healthy = db_manager.health_check()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "healthy" if db_healthy else "unhealthy",
|
||||||
|
"database": "connected" if db_healthy else "disconnected",
|
||||||
|
"timestamp": time.time()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 注册API路由
|
||||||
|
app.include_router(
|
||||||
|
face_features.router,
|
||||||
|
prefix=settings.API_V1_PREFIX
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 自定义404处理器
|
||||||
|
@app.exception_handler(404)
|
||||||
|
async def not_found_handler(request: Request, exc):
|
||||||
|
"""
|
||||||
|
自定义404错误处理器
|
||||||
|
"""
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=404,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"code": "NOT_FOUND",
|
||||||
|
"message": f"请求的资源不存在: {request.url.path}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 导出应用实例
|
||||||
|
__all__ = ["app"]
|
||||||
91
src/config.py
Normal file
91
src/config.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
"""
|
||||||
|
数据库配置模块
|
||||||
|
使用pydantic进行配置验证和管理
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, List
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from functools import lru_cache
|
||||||
|
from pydantic import PostgresDsn, field_validator
|
||||||
|
from pydantic_core.core_schema import FieldValidationInfo
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
"""应用配置类"""
|
||||||
|
|
||||||
|
# API配置
|
||||||
|
API_V1_PREFIX: str = "/api/v1"
|
||||||
|
PROJECT_NAME: str = "algorithm-service"
|
||||||
|
PROJECT_VERSION: str = "1.0.0"
|
||||||
|
PROJECT_DESCRIPTION: str = "algorithm-service"
|
||||||
|
BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000", "http://localhost:8000"]
|
||||||
|
|
||||||
|
|
||||||
|
# 数据库配置
|
||||||
|
DATABASE_HOST: str = "localhost"
|
||||||
|
DATABASE_PORT: int = 5432
|
||||||
|
DATABASE_USER: str = "postgres"
|
||||||
|
DATABASE_PASSWORD: str = "yipai123"
|
||||||
|
DATABASE_NAME: str = "pmms"
|
||||||
|
DATABASE_SCHEMA: str = "public"
|
||||||
|
|
||||||
|
# 连接池配置
|
||||||
|
DATABASE_POOL_SIZE: int = 10
|
||||||
|
DATABASE_MAX_OVERFLOW: int = 20
|
||||||
|
DATABASE_POOL_RECYCLE: int = 3600 # 连接回收时间(秒)
|
||||||
|
DATABASE_ECHO: bool = False # SQL日志,生产环境设为False
|
||||||
|
|
||||||
|
# 应用配置
|
||||||
|
APP_NAME: str = "SurFaceFeature API"
|
||||||
|
APP_VERSION: str = "1.0.0"
|
||||||
|
DEBUG: bool = False
|
||||||
|
|
||||||
|
# 日志配置
|
||||||
|
LOG_LEVEL: str = "INFO"
|
||||||
|
LOG_FILE: Optional[str] = None
|
||||||
|
|
||||||
|
# 异步配置
|
||||||
|
ASYNC_MODE: bool = False
|
||||||
|
|
||||||
|
# JWT配置(预留)
|
||||||
|
SECRET_KEY: str = "your-secret-key-here-change-in-production"
|
||||||
|
ALGORITHM: str = "HS256"
|
||||||
|
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||||
|
|
||||||
|
@property
|
||||||
|
def DATABASE_URL(self) -> str:
|
||||||
|
"""构建数据库连接URL"""
|
||||||
|
return f"postgresql://{self.DATABASE_USER}:{self.DATABASE_PASSWORD}@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ASYNC_DATABASE_URL(self) -> str:
|
||||||
|
"""构建异步数据库连接URL"""
|
||||||
|
return f"postgresql+asyncpg://{self.DATABASE_USER}:{self.DATABASE_PASSWORD}@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}"
|
||||||
|
|
||||||
|
@field_validator("DATABASE_POOL_SIZE")
|
||||||
|
def validate_pool_size(cls, v):
|
||||||
|
"""验证连接池大小"""
|
||||||
|
if v < 1:
|
||||||
|
raise ValueError("DATABASE_POOL_SIZE must be at least 1")
|
||||||
|
if v > 100:
|
||||||
|
raise ValueError("DATABASE_POOL_SIZE cannot exceed 100")
|
||||||
|
return v
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
env_file_encoding = "utf-8"
|
||||||
|
case_sensitive = False
|
||||||
|
extra = "ignore"
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
"""
|
||||||
|
获取配置单例
|
||||||
|
使用lru_cache避免重复加载.env文件
|
||||||
|
"""
|
||||||
|
return Settings()
|
||||||
|
|
||||||
|
|
||||||
|
# 导出配置实例
|
||||||
|
settings = get_settings()
|
||||||
66
src/database/base.py
Normal file
66
src/database/base.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""
|
||||||
|
数据库模型基类
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict
|
||||||
|
from sqlalchemy import Column, DateTime, Integer
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.sql import func
|
||||||
|
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModel(Base):
|
||||||
|
"""抽象基类,为所有模型提供通用字段"""
|
||||||
|
__abstract__ = True
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
|
||||||
|
created_time = Column(DateTime(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
nullable=False,
|
||||||
|
comment="创建时间")
|
||||||
|
|
||||||
|
def to_dict(self, exclude: list = None) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
将模型实例转换为字典
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exclude: 要排除的字段列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含模型字段的字典
|
||||||
|
"""
|
||||||
|
exclude = exclude or []
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
for column in self.__table__.columns:
|
||||||
|
if column.name in exclude:
|
||||||
|
continue
|
||||||
|
|
||||||
|
value = getattr(self, column.name)
|
||||||
|
|
||||||
|
# 处理特殊类型
|
||||||
|
if isinstance(value, datetime):
|
||||||
|
value = value.isoformat()
|
||||||
|
elif isinstance(value, bytes):
|
||||||
|
value = value.hex() if value else None
|
||||||
|
|
||||||
|
result[column.name] = value
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def update_from_dict(self, data: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
从字典更新模型字段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 包含要更新字段的字典
|
||||||
|
"""
|
||||||
|
for key, value in data.items():
|
||||||
|
if hasattr(self, key) and key != 'id':
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""模型表示"""
|
||||||
|
return f"<{self.__class__.__name__}(id={self.id})>"
|
||||||
252
src/database/connection.py
Normal file
252
src/database/connection.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
"""
|
||||||
|
数据库连接管理模块
|
||||||
|
支持同步和异步两种模式
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Generator, AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager, contextmanager
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
from sqlalchemy.orm import sessionmaker, Session
|
||||||
|
from sqlalchemy.ext.asyncio import (
|
||||||
|
AsyncEngine,
|
||||||
|
AsyncSession,
|
||||||
|
create_async_engine,
|
||||||
|
async_sessionmaker
|
||||||
|
)
|
||||||
|
|
||||||
|
from src.config import settings
|
||||||
|
from src.utils.logger import setup_logger
|
||||||
|
|
||||||
|
logger = setup_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseManager:
|
||||||
|
"""数据库管理器(同步)"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._engine: Optional[Engine] = None
|
||||||
|
self._session_factory: Optional[sessionmaker] = None
|
||||||
|
|
||||||
|
def init_engine(self) -> None:
|
||||||
|
"""初始化数据库引擎"""
|
||||||
|
if self._engine is not None:
|
||||||
|
logger.warning("Database engine already initialized")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._engine = create_engine(
|
||||||
|
settings.DATABASE_URL,
|
||||||
|
pool_size=settings.DATABASE_POOL_SIZE,
|
||||||
|
max_overflow=settings.DATABASE_MAX_OVERFLOW,
|
||||||
|
pool_recycle=settings.DATABASE_POOL_RECYCLE,
|
||||||
|
echo=settings.DATABASE_ECHO,
|
||||||
|
pool_pre_ping=True, # 连接前进行ping检查
|
||||||
|
connect_args={
|
||||||
|
"options": f"-csearch_path={settings.DATABASE_SCHEMA}"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self._session_factory = sessionmaker(
|
||||||
|
autocommit=False,
|
||||||
|
autoflush=False,
|
||||||
|
bind=self._engine,
|
||||||
|
expire_on_commit=False # 避免延迟加载问题
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Database engine initialized successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize database engine: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@property
|
||||||
|
def engine(self) -> Engine:
|
||||||
|
"""获取数据库引擎"""
|
||||||
|
if self._engine is None:
|
||||||
|
self.init_engine()
|
||||||
|
return self._engine
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session_factory(self) -> sessionmaker:
|
||||||
|
"""获取会话工厂"""
|
||||||
|
if self._session_factory is None:
|
||||||
|
self.init_engine()
|
||||||
|
return self._session_factory
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def get_session(self) -> Generator[Session, None, None]:
|
||||||
|
"""
|
||||||
|
获取数据库会话的上下文管理器
|
||||||
|
|
||||||
|
使用示例:
|
||||||
|
with db_manager.get_session() as session:
|
||||||
|
result = session.query(User).all()
|
||||||
|
"""
|
||||||
|
session = self.session_factory()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
session.commit()
|
||||||
|
logger.debug("Session committed successfully")
|
||||||
|
except Exception as e:
|
||||||
|
session.rollback()
|
||||||
|
logger.error(f"Session rollback due to error: {e}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
logger.debug("Session closed")
|
||||||
|
|
||||||
|
def execute_raw_sql(self, sql: str, params: Optional[dict] = None) -> list:
|
||||||
|
"""执行原始SQL查询"""
|
||||||
|
with self.get_session() as session:
|
||||||
|
result = session.execute(text(sql), params or {})
|
||||||
|
return [dict(row._mapping) for row in result]
|
||||||
|
|
||||||
|
def health_check(self) -> bool:
|
||||||
|
"""数据库健康检查"""
|
||||||
|
try:
|
||||||
|
with self.engine.connect() as conn:
|
||||||
|
conn.execute(text("SELECT 1"))
|
||||||
|
logger.info("Database health check passed")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Database health check failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""关闭数据库连接"""
|
||||||
|
if self._engine:
|
||||||
|
self._engine.dispose()
|
||||||
|
self._engine = None
|
||||||
|
self._session_factory = None
|
||||||
|
logger.info("Database connections closed")
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncDatabaseManager:
|
||||||
|
"""异步数据库管理器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._engine: Optional[AsyncEngine] = None
|
||||||
|
self._async_session_factory: Optional[async_sessionmaker] = None
|
||||||
|
|
||||||
|
async def init_engine(self) -> None:
|
||||||
|
"""初始化异步数据库引擎"""
|
||||||
|
if self._engine is not None:
|
||||||
|
logger.warning("Async database engine already initialized")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._engine = create_async_engine(
|
||||||
|
settings.ASYNC_DATABASE_URL,
|
||||||
|
pool_size=settings.DATABASE_POOL_SIZE,
|
||||||
|
max_overflow=settings.DATABASE_MAX_OVERFLOW,
|
||||||
|
pool_recycle=settings.DATABASE_POOL_RECYCLE,
|
||||||
|
echo=settings.DATABASE_ECHO,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
connect_args={
|
||||||
|
"server_settings": {
|
||||||
|
"search_path": settings.DATABASE_SCHEMA
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self._async_session_factory = async_sessionmaker(
|
||||||
|
self._engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False,
|
||||||
|
autoflush=False,
|
||||||
|
autocommit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Async database engine initialized successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize async database engine: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@property
|
||||||
|
async def engine(self) -> AsyncEngine:
|
||||||
|
"""获取异步数据库引擎"""
|
||||||
|
if self._engine is None:
|
||||||
|
await self.init_engine()
|
||||||
|
return self._engine
|
||||||
|
|
||||||
|
@property
|
||||||
|
async def async_session_factory(self) -> async_sessionmaker:
|
||||||
|
"""获取异步会话工厂"""
|
||||||
|
if self._async_session_factory is None:
|
||||||
|
await self.init_engine()
|
||||||
|
return self._async_session_factory
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""
|
||||||
|
获取异步数据库会话的上下文管理器
|
||||||
|
|
||||||
|
使用示例:
|
||||||
|
async with async_db_manager.get_session() as session:
|
||||||
|
result = await session.execute(query)
|
||||||
|
"""
|
||||||
|
if self._async_session_factory is None:
|
||||||
|
await self.init_engine()
|
||||||
|
|
||||||
|
session = self._async_session_factory()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
await session.commit()
|
||||||
|
logger.debug("Async session committed successfully")
|
||||||
|
except Exception as e:
|
||||||
|
await session.rollback()
|
||||||
|
logger.error(f"Async session rollback due to error: {e}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
logger.debug("Async session closed")
|
||||||
|
|
||||||
|
async def health_check(self) -> bool:
|
||||||
|
"""异步数据库健康检查"""
|
||||||
|
try:
|
||||||
|
async with self._engine.connect() as conn:
|
||||||
|
await conn.execute(text("SELECT 1"))
|
||||||
|
logger.info("Async database health check passed")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Async database health check failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""关闭异步数据库连接"""
|
||||||
|
if self._engine:
|
||||||
|
await self._engine.dispose()
|
||||||
|
self._engine = None
|
||||||
|
self._async_session_factory = None
|
||||||
|
logger.info("Async database connections closed")
|
||||||
|
|
||||||
|
|
||||||
|
# 创建全局数据库管理器实例
|
||||||
|
db_manager = DatabaseManager()
|
||||||
|
async_db_manager = AsyncDatabaseManager()
|
||||||
|
|
||||||
|
|
||||||
|
def init_database() -> None:
|
||||||
|
"""初始化数据库(同步)"""
|
||||||
|
db_manager.init_engine()
|
||||||
|
|
||||||
|
|
||||||
|
async def init_async_database() -> None:
|
||||||
|
"""初始化异步数据库"""
|
||||||
|
await async_db_manager.init_engine()
|
||||||
|
|
||||||
|
|
||||||
|
def get_db() -> Generator[Session, None, None]:
|
||||||
|
"""依赖注入:获取数据库会话(用于FastAPI等框架)"""
|
||||||
|
with db_manager.get_session() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""依赖注入:获取异步数据库会话(用于FastAPI等框架)"""
|
||||||
|
async with async_db_manager.get_session() as session:
|
||||||
|
yield session
|
||||||
207
src/main.py
Normal file
207
src/main.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
"""
|
||||||
|
主程序示例
|
||||||
|
演示如何使用各个模块
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from src.config import settings
|
||||||
|
from src.database.connection import db_manager, init_database
|
||||||
|
from src.models.face_feature import SurFaceFeature
|
||||||
|
from src.repositories.face_feature_repository import FaceFeatureRepository
|
||||||
|
from src.services.face_feature_service import FaceFeatureService
|
||||||
|
from src.schemas.face_feature import (
|
||||||
|
FaceFeatureCreate,
|
||||||
|
FaceFeatureUpdate,
|
||||||
|
FaceFeatureQuery,
|
||||||
|
FeatureStatus,
|
||||||
|
BatchFaceFeatureCreate
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def demo_sync_operations():
|
||||||
|
"""演示同步操作"""
|
||||||
|
print("=== 同步操作演示 ===")
|
||||||
|
|
||||||
|
# 初始化数据库
|
||||||
|
init_database()
|
||||||
|
|
||||||
|
# 创建仓库和服务
|
||||||
|
with db_manager.get_session() as session:
|
||||||
|
repository = FaceFeatureRepository(session)
|
||||||
|
service = FaceFeatureService(repository)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. 创建特征记录(先检查是否存在)
|
||||||
|
print("\n1. 创建特征记录")
|
||||||
|
|
||||||
|
# 固定的测试ID
|
||||||
|
test_person_id = 1001
|
||||||
|
test_feature_type = 1
|
||||||
|
|
||||||
|
# 检查是否已存在
|
||||||
|
existing = service.get_feature_by_person_and_type(test_person_id, test_feature_type)
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
print(f"记录已存在: ID={existing.id}, 状态={existing.status_name}")
|
||||||
|
feature_id = existing.id
|
||||||
|
else:
|
||||||
|
feature_data = FaceFeatureCreate(
|
||||||
|
person_id=test_person_id,
|
||||||
|
feature_type=test_feature_type,
|
||||||
|
pic_id="test_image_001.jpg",
|
||||||
|
status=FeatureStatus.NOT_STARTED
|
||||||
|
)
|
||||||
|
|
||||||
|
feature_response = service.create_feature(feature_data)
|
||||||
|
print(f"创建成功: ID={feature_response.id}, 人员ID={feature_response.person_id}")
|
||||||
|
feature_id = feature_response.id
|
||||||
|
|
||||||
|
# 2. 开始处理(如果未开始)
|
||||||
|
print("\n2. 检查并开始处理特征计算")
|
||||||
|
feature = service.get_feature(feature_id)
|
||||||
|
if feature and feature.status == FeatureStatus.NOT_STARTED:
|
||||||
|
if service.start_processing(feature_id):
|
||||||
|
print(f"已开始处理: ID={feature_id}")
|
||||||
|
else:
|
||||||
|
print(f"特征计算已开始或已完成: 状态={feature.status_name if feature else '未知'}")
|
||||||
|
|
||||||
|
# 3. 完成处理(如果还在处理中)
|
||||||
|
print("\n3. 检查并完成处理特征计算")
|
||||||
|
feature = service.get_feature(feature_id)
|
||||||
|
if feature and feature.status == FeatureStatus.PROCESSING:
|
||||||
|
if service.finish_processing(feature_id, success=True):
|
||||||
|
print(f"已完成处理: ID={feature_id}")
|
||||||
|
else:
|
||||||
|
print(f"特征计算已完成或未开始: 状态={feature.status_name if feature else '未知'}")
|
||||||
|
|
||||||
|
# 4. 查询特征
|
||||||
|
print("\n4. 查询特征记录")
|
||||||
|
retrieved = service.get_feature(feature_id)
|
||||||
|
if retrieved:
|
||||||
|
print(f"查询成功: ID={retrieved.id}, 状态={retrieved.status_name}")
|
||||||
|
print(f"处理时间: {retrieved.processing_time}秒")
|
||||||
|
print(f"是否有特征数据: {retrieved.has_feature_data}")
|
||||||
|
|
||||||
|
# 5. 批量创建 - 跳过已存在的
|
||||||
|
print("\n5. 批量创建特征记录")
|
||||||
|
batch_items = []
|
||||||
|
for i in range(3):
|
||||||
|
person_id = 2000 + i
|
||||||
|
# 检查是否已存在
|
||||||
|
existing = service.get_feature_by_person_and_type(person_id, 1)
|
||||||
|
if not existing:
|
||||||
|
batch_items.append(FaceFeatureCreate(person_id=person_id, feature_type=1))
|
||||||
|
|
||||||
|
if batch_items:
|
||||||
|
batch_data = BatchFaceFeatureCreate(items=batch_items)
|
||||||
|
try:
|
||||||
|
batch_result = service.create_features_batch(batch_data)
|
||||||
|
print(f"批量创建成功: {len(batch_result)}条记录")
|
||||||
|
except ValueError as e:
|
||||||
|
print(f"批量创建失败: {e}")
|
||||||
|
else:
|
||||||
|
print("所有记录已存在,跳过批量创建")
|
||||||
|
|
||||||
|
# 6. 查询列表
|
||||||
|
print("\n6. 查询特征记录列表")
|
||||||
|
query = FaceFeatureQuery(
|
||||||
|
feature_type=1,
|
||||||
|
start_date=datetime.now() - timedelta(days=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = service.query_features(query, page=1, page_size=10)
|
||||||
|
print(f"查询结果: 共{result.total}条记录,本页{len(result.items)}条")
|
||||||
|
|
||||||
|
if result.items:
|
||||||
|
print(f"第一笔记录: ID={result.items[0].id}, 人员ID={result.items[0].person_id}")
|
||||||
|
|
||||||
|
# 7. 获取统计信息
|
||||||
|
print("\n7. 获取统计信息")
|
||||||
|
stats = service.get_statistics()
|
||||||
|
print(f"总记录数: {stats.total_count}")
|
||||||
|
print(f"状态分布: {stats.by_status}")
|
||||||
|
print(f"特征类型分布: {stats.by_feature_type}")
|
||||||
|
|
||||||
|
# 8. 更新特征数据
|
||||||
|
print("\n8. 更新特征数据")
|
||||||
|
test_feature_data = b"test_feature_data_12345"
|
||||||
|
success = service.update_feature_data(feature_id, test_feature_data)
|
||||||
|
print(f"更新特征数据: {'成功' if success else '失败'}")
|
||||||
|
|
||||||
|
# 重新查询查看更新后的数据
|
||||||
|
updated_feature = service.get_feature(feature_id)
|
||||||
|
if updated_feature:
|
||||||
|
print(f"更新后是否有特征数据: {updated_feature.has_feature_data}")
|
||||||
|
|
||||||
|
# 9. 演示删除操作
|
||||||
|
print("\n9. 演示删除操作")
|
||||||
|
# 先创建一个要删除的记录
|
||||||
|
delete_person_id = 9999 # 使用一个不存在的ID
|
||||||
|
delete_feature_data = FaceFeatureCreate(
|
||||||
|
person_id=delete_person_id,
|
||||||
|
feature_type=3,
|
||||||
|
pic_id="to_delete.jpg"
|
||||||
|
)
|
||||||
|
delete_feature = service.create_feature(delete_feature_data)
|
||||||
|
print(f"创建待删除记录: ID={delete_feature.id}")
|
||||||
|
|
||||||
|
# 删除记录
|
||||||
|
delete_success = service.delete_feature(delete_feature.id)
|
||||||
|
print(f"删除记录: {'成功' if delete_success else '失败'}")
|
||||||
|
|
||||||
|
# 验证删除
|
||||||
|
deleted_check = service.get_feature(delete_feature.id)
|
||||||
|
print(f"验证删除: {'记录不存在' if deleted_check is None else '记录还存在'}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"操作失败: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
# 关闭数据库连接
|
||||||
|
db_manager.close()
|
||||||
|
print("\n数据库连接已关闭")
|
||||||
|
|
||||||
|
|
||||||
|
async def demo_async_operations():
|
||||||
|
"""演示异步操作"""
|
||||||
|
print("\n=== 异步操作演示 ===")
|
||||||
|
|
||||||
|
# 注意:异步操作需要异步数据库管理器
|
||||||
|
# 这里仅展示结构,实际使用时需要配置异步数据库
|
||||||
|
|
||||||
|
print("异步操作示例代码已准备,需要配置异步数据库连接")
|
||||||
|
print("请使用 async_db_manager 和异步版本的repository")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数"""
|
||||||
|
print("人脸特征管理系统演示")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# 显示配置信息
|
||||||
|
print(f"应用名称: {settings.APP_NAME}")
|
||||||
|
print(f"数据库: {settings.DATABASE_NAME}")
|
||||||
|
print(f"连接池大小: {settings.DATABASE_POOL_SIZE}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 执行同步演示
|
||||||
|
demo_sync_operations()
|
||||||
|
|
||||||
|
# 如果需要异步演示,可以取消下面的注释
|
||||||
|
# asyncio.run(demo_async_operations())
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n程序被用户中断")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"程序执行出错: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
print("\n演示完成!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
197
src/models/face_feature.py
Normal file
197
src/models/face_feature.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
"""
|
||||||
|
人脸特征数据模型
|
||||||
|
对应数据库表:sur_face_feature
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import IntEnum
|
||||||
|
from sqlalchemy import (
|
||||||
|
Column,
|
||||||
|
Integer,
|
||||||
|
SmallInteger,
|
||||||
|
LargeBinary,
|
||||||
|
DateTime,
|
||||||
|
Text,
|
||||||
|
Index,
|
||||||
|
UniqueConstraint
|
||||||
|
)
|
||||||
|
from sqlalchemy.dialects.postgresql import BYTEA
|
||||||
|
from sqlalchemy.sql import func
|
||||||
|
|
||||||
|
from src.database.base import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureStatus(IntEnum):
|
||||||
|
"""人脸特征值计算状态枚举"""
|
||||||
|
NOT_STARTED = 0 # 未开始
|
||||||
|
PROCESSING = 1 # 计算中
|
||||||
|
SUCCESS = 2 # 计算成功
|
||||||
|
FAILED = 3 # 计算失败
|
||||||
|
|
||||||
|
|
||||||
|
# 导出别名以保持向后兼容性
|
||||||
|
FeatureStatusEnum = FeatureStatus
|
||||||
|
|
||||||
|
|
||||||
|
class SurFaceFeature(BaseModel):
|
||||||
|
"""
|
||||||
|
人脸特征值表模型
|
||||||
|
|
||||||
|
对应表结构:
|
||||||
|
- id: 主键
|
||||||
|
- person_id: 人员ID
|
||||||
|
- feature_type: 模型版本
|
||||||
|
- feature_data: 特征值(二进制)
|
||||||
|
- created_time: 创建时间
|
||||||
|
- pic_id: 图片ID
|
||||||
|
- status: 计算状态
|
||||||
|
- start_time: 计算开始时间
|
||||||
|
- finish_time: 计算结束时间
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "sur_face_feature"
|
||||||
|
__table_args__ = (
|
||||||
|
# 唯一约束:person_id + feature_type
|
||||||
|
UniqueConstraint("person_id", "feature_type", name="sur_face_feature_unique"),
|
||||||
|
# 为常用查询字段创建索引
|
||||||
|
Index("ix_sur_face_feature_person_id", "person_id"),
|
||||||
|
Index("ix_sur_face_feature_feature_type", "feature_type"),
|
||||||
|
Index("ix_sur_face_feature_status", "status"),
|
||||||
|
Index("ix_sur_face_feature_created_time", "created_time"),
|
||||||
|
{"schema": "public", "comment": "人脸特征值表"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 主键(自增序列)
|
||||||
|
id = Column(
|
||||||
|
Integer,
|
||||||
|
primary_key=True,
|
||||||
|
index=True,
|
||||||
|
comment="主键"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 人员ID(必填)
|
||||||
|
person_id = Column(
|
||||||
|
Integer,
|
||||||
|
nullable=False,
|
||||||
|
comment="人员id"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 模型版本(特征类型)
|
||||||
|
feature_type = Column(
|
||||||
|
SmallInteger,
|
||||||
|
nullable=True, # 根据SQL,允许NULL
|
||||||
|
comment="模型版本"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 特征值(二进制数据)
|
||||||
|
feature_data = Column(
|
||||||
|
BYTEA, # PostgreSQL的二进制类型
|
||||||
|
nullable=True,
|
||||||
|
comment="特征值"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建时间(自动设置)
|
||||||
|
created_time = Column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
nullable=False,
|
||||||
|
comment="创建时间"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 图片ID
|
||||||
|
pic_id = Column(
|
||||||
|
Text,
|
||||||
|
nullable=True,
|
||||||
|
comment="图片id"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计算状态
|
||||||
|
status = Column(
|
||||||
|
SmallInteger,
|
||||||
|
default=FeatureStatusEnum.NOT_STARTED,
|
||||||
|
nullable=False,
|
||||||
|
comment="人脸特征值计算状态:0=未开始,1=计算中,2=计算成功,3=计算失败"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计算开始时间
|
||||||
|
start_time = Column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
nullable=True,
|
||||||
|
comment="特征计算开始时间"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计算结束时间
|
||||||
|
finish_time = Column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
nullable=True,
|
||||||
|
comment="特征计算结束时间"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 属性方法
|
||||||
|
@property
|
||||||
|
def status_name(self) -> str:
|
||||||
|
"""获取状态名称"""
|
||||||
|
try:
|
||||||
|
return FeatureStatusEnum(self.status).name
|
||||||
|
except ValueError:
|
||||||
|
return f"未知状态({self.status})"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_completed(self) -> bool:
|
||||||
|
"""是否完成计算"""
|
||||||
|
return self.status in [FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def processing_time(self) -> Optional[float]:
|
||||||
|
"""计算处理时间(秒)"""
|
||||||
|
if self.start_time and self.finish_time:
|
||||||
|
return (self.finish_time - self.start_time).total_seconds()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def start_processing(self) -> None:
|
||||||
|
"""开始处理"""
|
||||||
|
self.status = FeatureStatusEnum.PROCESSING
|
||||||
|
self.start_time = datetime.now()
|
||||||
|
self.finish_time = None
|
||||||
|
|
||||||
|
def finish_processing(self, success: bool = True) -> None:
|
||||||
|
"""结束处理"""
|
||||||
|
self.status = FeatureStatusEnum.SUCCESS if success else FeatureStatusEnum.FAILED
|
||||||
|
self.finish_time = datetime.now()
|
||||||
|
|
||||||
|
def to_dict(self, exclude: list = None) -> dict:
|
||||||
|
"""
|
||||||
|
重写to_dict方法,处理二进制数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exclude: 要排除的字段列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
转换后的字典
|
||||||
|
"""
|
||||||
|
exclude = exclude or []
|
||||||
|
|
||||||
|
# 默认排除二进制数据(太大)
|
||||||
|
default_exclude = ["feature_data"]
|
||||||
|
final_exclude = list(set(exclude + default_exclude))
|
||||||
|
|
||||||
|
result = super().to_dict(final_exclude)
|
||||||
|
|
||||||
|
# 添加额外属性
|
||||||
|
result["status_name"] = self.status_name
|
||||||
|
result["is_completed"] = self.is_completed
|
||||||
|
result["processing_time"] = self.processing_time
|
||||||
|
|
||||||
|
# 如果有feature_data,添加一个标识
|
||||||
|
if self.feature_data and "feature_data" not in exclude:
|
||||||
|
result["has_feature_data"] = True
|
||||||
|
result["feature_data_length"] = len(self.feature_data)
|
||||||
|
else:
|
||||||
|
result["has_feature_data"] = False
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (f"<SurFaceFeature(id={self.id}, person_id={self.person_id}, "
|
||||||
|
f"feature_type={self.feature_type}, status={self.status_name})>")
|
||||||
597
src/repositories/face_feature_repository.py
Normal file
597
src/repositories/face_feature_repository.py
Normal file
@@ -0,0 +1,597 @@
|
|||||||
|
"""
|
||||||
|
人脸特征数据仓库
|
||||||
|
数据访问层,处理所有数据库操作
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional, List, Dict, Any, Tuple
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from sqlalchemy import select, update, delete, func, and_, or_, desc, asc
|
||||||
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError, IntegrityError
|
||||||
|
|
||||||
|
from src.models.face_feature import SurFaceFeature
|
||||||
|
from src.schemas.face_feature import (
|
||||||
|
FaceFeatureCreate,
|
||||||
|
FaceFeatureUpdate,
|
||||||
|
FaceFeatureQuery,
|
||||||
|
FeatureStatus
|
||||||
|
)
|
||||||
|
from src.models.face_feature import FeatureStatusEnum
|
||||||
|
from src.utils.logger import setup_logger
|
||||||
|
|
||||||
|
logger = setup_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FaceFeatureRepository:
|
||||||
|
"""人脸特征数据仓库"""
|
||||||
|
|
||||||
|
def __init__(self, session: Session):
|
||||||
|
"""
|
||||||
|
初始化仓库
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: SQLAlchemy会话对象
|
||||||
|
"""
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
# ===== 创建操作 =====
|
||||||
|
|
||||||
|
def create(self, feature_data: FaceFeatureCreate) -> SurFaceFeature:
|
||||||
|
"""
|
||||||
|
创建特征记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_data: 特征数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
创建的SurFaceFeature对象
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
IntegrityError: 违反唯一约束时抛出
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 转换为模型字典
|
||||||
|
feature_dict = feature_data.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
|
# 创建模型实例
|
||||||
|
feature = SurFaceFeature(**feature_dict)
|
||||||
|
|
||||||
|
# 添加到会话
|
||||||
|
self.session.add(feature)
|
||||||
|
self.session.flush() # 立即执行,但不提交
|
||||||
|
|
||||||
|
logger.info(f"Created face feature record: id={feature.id}, person_id={feature.person_id}")
|
||||||
|
return feature
|
||||||
|
|
||||||
|
except IntegrityError as e:
|
||||||
|
logger.error(f"Integrity error creating face feature: {e}")
|
||||||
|
self.session.rollback()
|
||||||
|
raise ValueError(f"Duplicate feature record for person_id={feature_data.person_id}, "
|
||||||
|
f"feature_type={feature_data.feature_type}")
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Database error creating face feature: {e}")
|
||||||
|
self.session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
def create_batch(self, features_data: List[FaceFeatureCreate]) -> List[SurFaceFeature]:
|
||||||
|
"""
|
||||||
|
批量创建特征记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features_data: 特征数据列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
创建的SurFaceFeature对象列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
features = []
|
||||||
|
for feature_data in features_data:
|
||||||
|
feature_dict = feature_data.model_dump(exclude_unset=True)
|
||||||
|
feature = SurFaceFeature(**feature_dict)
|
||||||
|
features.append(feature)
|
||||||
|
|
||||||
|
# 批量添加
|
||||||
|
self.session.add_all(features)
|
||||||
|
self.session.flush()
|
||||||
|
|
||||||
|
logger.info(f"Created {len(features)} face feature records in batch")
|
||||||
|
return features
|
||||||
|
|
||||||
|
except IntegrityError as e:
|
||||||
|
logger.error(f"Integrity error creating batch face features: {e}")
|
||||||
|
self.session.rollback()
|
||||||
|
raise ValueError("Duplicate feature record in batch")
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Database error creating batch face features: {e}")
|
||||||
|
self.session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
# ===== 查询操作 =====
|
||||||
|
|
||||||
|
def get_by_id(self, feature_id: int) -> Optional[SurFaceFeature]:
|
||||||
|
"""
|
||||||
|
根据ID获取特征记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_id: 特征记录ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SurFaceFeature对象或None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
stmt = select(SurFaceFeature).where(SurFaceFeature.id == feature_id)
|
||||||
|
result = self.session.execute(stmt)
|
||||||
|
feature = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if feature:
|
||||||
|
logger.debug(f"Retrieved face feature by id: {feature_id}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"Face feature not found by id: {feature_id}")
|
||||||
|
|
||||||
|
return feature
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Database error getting face feature by id: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_by_person_and_type(self, person_id: int, feature_type: int) -> Optional[SurFaceFeature]:
|
||||||
|
"""
|
||||||
|
根据人员ID和特征类型获取特征记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
person_id: 人员ID
|
||||||
|
feature_type: 特征类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SurFaceFeature对象或None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
stmt = select(SurFaceFeature).where(
|
||||||
|
and_(
|
||||||
|
SurFaceFeature.person_id == person_id,
|
||||||
|
SurFaceFeature.feature_type == feature_type
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = self.session.execute(stmt)
|
||||||
|
feature = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if feature:
|
||||||
|
logger.debug(f"Retrieved face feature: person_id={person_id}, feature_type={feature_type}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"Face feature not found: person_id={person_id}, feature_type={feature_type}")
|
||||||
|
|
||||||
|
return feature
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Database error getting face feature by person and type: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_by_person(self, person_id: int, limit: int = 100) -> List[SurFaceFeature]:
|
||||||
|
"""
|
||||||
|
根据人员ID获取特征记录列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
person_id: 人员ID
|
||||||
|
limit: 返回数量限制
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SurFaceFeature对象列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
stmt = (
|
||||||
|
select(SurFaceFeature)
|
||||||
|
.where(SurFaceFeature.person_id == person_id)
|
||||||
|
.order_by(desc(SurFaceFeature.created_time))
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
result = self.session.execute(stmt)
|
||||||
|
features = list(result.scalars().all())
|
||||||
|
|
||||||
|
logger.debug(f"Retrieved {len(features)} face features for person_id={person_id}")
|
||||||
|
return features
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Database error getting face features by person: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def query_features(
|
||||||
|
self,
|
||||||
|
query: FaceFeatureQuery,
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 20,
|
||||||
|
order_by: str = "created_time",
|
||||||
|
desc_order: bool = True
|
||||||
|
) -> Tuple[List[SurFaceFeature], int]:
|
||||||
|
"""
|
||||||
|
查询特征记录(带分页)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 查询条件
|
||||||
|
page: 页码(从1开始)
|
||||||
|
page_size: 每页数量
|
||||||
|
order_by: 排序字段
|
||||||
|
desc_order: 是否降序
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(特征记录列表, 总记录数)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 构建查询条件
|
||||||
|
conditions = []
|
||||||
|
query_dict = query.model_dump(exclude_unset=True, exclude_none=True)
|
||||||
|
|
||||||
|
# 处理查询条件
|
||||||
|
if "person_id" in query_dict:
|
||||||
|
conditions.append(SurFaceFeature.person_id == query_dict["person_id"])
|
||||||
|
|
||||||
|
if "feature_type" in query_dict:
|
||||||
|
conditions.append(SurFaceFeature.feature_type == query_dict["feature_type"])
|
||||||
|
|
||||||
|
if "status" in query_dict:
|
||||||
|
conditions.append(SurFaceFeature.status == query_dict["status"])
|
||||||
|
|
||||||
|
if "start_date" in query_dict:
|
||||||
|
conditions.append(SurFaceFeature.created_time >= query_dict["start_date"])
|
||||||
|
|
||||||
|
if "end_date" in query_dict:
|
||||||
|
conditions.append(SurFaceFeature.created_time <= query_dict["end_date"])
|
||||||
|
|
||||||
|
if "has_feature_data" in query_dict:
|
||||||
|
if query_dict["has_feature_data"]:
|
||||||
|
conditions.append(SurFaceFeature.feature_data.isnot(None))
|
||||||
|
else:
|
||||||
|
conditions.append(SurFaceFeature.feature_data.is_(None))
|
||||||
|
|
||||||
|
# 基础查询
|
||||||
|
stmt = select(SurFaceFeature)
|
||||||
|
if conditions:
|
||||||
|
stmt = stmt.where(and_(*conditions))
|
||||||
|
|
||||||
|
# 获取总数
|
||||||
|
count_stmt = select(func.count()).select_from(stmt.subquery())
|
||||||
|
total_result = self.session.execute(count_stmt)
|
||||||
|
total = total_result.scalar_one()
|
||||||
|
|
||||||
|
# 排序
|
||||||
|
order_column = getattr(SurFaceFeature, order_by, SurFaceFeature.created_time)
|
||||||
|
if desc_order:
|
||||||
|
stmt = stmt.order_by(desc(order_column))
|
||||||
|
else:
|
||||||
|
stmt = stmt.order_by(asc(order_column))
|
||||||
|
|
||||||
|
# 分页
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
stmt = stmt.offset(offset).limit(page_size)
|
||||||
|
|
||||||
|
# 执行查询
|
||||||
|
result = self.session.execute(stmt)
|
||||||
|
features = list(result.scalars().all())
|
||||||
|
|
||||||
|
logger.debug(f"Query returned {len(features)} features (total: {total})")
|
||||||
|
return features, total
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Database error querying face features: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# ===== 更新操作 =====
|
||||||
|
|
||||||
|
def update(self, feature_id: int, update_data: FaceFeatureUpdate) -> Optional[SurFaceFeature]:
|
||||||
|
"""
|
||||||
|
更新特征记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_id: 特征记录ID
|
||||||
|
update_data: 更新数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
更新后的SurFaceFeature对象或None(如果不存在)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 先检查是否存在
|
||||||
|
feature = self.get_by_id(feature_id)
|
||||||
|
if not feature:
|
||||||
|
logger.warning(f"Cannot update non-existent face feature: id={feature_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 转换为字典
|
||||||
|
update_dict = update_data.model_dump(exclude_unset=True, exclude_none=True)
|
||||||
|
|
||||||
|
# 更新字段
|
||||||
|
for key, value in update_dict.items():
|
||||||
|
setattr(feature, key, value)
|
||||||
|
|
||||||
|
# 刷新到数据库
|
||||||
|
self.session.flush()
|
||||||
|
|
||||||
|
logger.info(f"Updated face feature: id={feature_id}")
|
||||||
|
return feature
|
||||||
|
|
||||||
|
except IntegrityError as e:
|
||||||
|
logger.error(f"Integrity error updating face feature: {e}")
|
||||||
|
self.session.rollback()
|
||||||
|
raise ValueError("Update would create duplicate record")
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Database error updating face feature: {e}")
|
||||||
|
self.session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
def update_feature_data(self, feature_id: int, feature_data: bytes) -> bool:
|
||||||
|
"""
|
||||||
|
更新特征数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_id: 特征记录ID
|
||||||
|
feature_data: 特征数据(二进制)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功更新
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
stmt = (
|
||||||
|
update(SurFaceFeature)
|
||||||
|
.where(SurFaceFeature.id == feature_id)
|
||||||
|
.values(feature_data=feature_data)
|
||||||
|
.returning(SurFaceFeature.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.session.execute(stmt)
|
||||||
|
updated_id = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if updated_id:
|
||||||
|
logger.info(f"Updated feature data for face feature: id={feature_id}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning(f"Cannot update feature data for non-existent face feature: id={feature_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Database error updating feature data: {e}")
|
||||||
|
self.session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
def update_status(
|
||||||
|
self,
|
||||||
|
feature_id: int,
|
||||||
|
status: FeatureStatus,
|
||||||
|
start_time: Optional[datetime] = None,
|
||||||
|
finish_time: Optional[datetime] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
更新计算状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_id: 特征记录ID
|
||||||
|
status: 新状态
|
||||||
|
start_time: 开始时间(可选)
|
||||||
|
finish_time: 结束时间(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功更新
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
update_values = {"status": status.value if isinstance(status, FeatureStatus) else status}
|
||||||
|
|
||||||
|
if start_time:
|
||||||
|
update_values["start_time"] = start_time
|
||||||
|
if finish_time:
|
||||||
|
update_values["finish_time"] = finish_time
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
update(SurFaceFeature)
|
||||||
|
.where(SurFaceFeature.id == feature_id)
|
||||||
|
.values(**update_values)
|
||||||
|
.returning(SurFaceFeature.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.session.execute(stmt)
|
||||||
|
updated_id = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if updated_id:
|
||||||
|
logger.info(f"Updated status to {status} for face feature: id={feature_id}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning(f"Cannot update status for non-existent face feature: id={feature_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Database error updating status: {e}")
|
||||||
|
self.session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
# ===== 删除操作 =====
|
||||||
|
|
||||||
|
def delete(self, feature_id: int) -> bool:
|
||||||
|
"""
|
||||||
|
删除特征记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_id: 特征记录ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功删除
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
stmt = delete(SurFaceFeature).where(SurFaceFeature.id == feature_id)
|
||||||
|
result = self.session.execute(stmt)
|
||||||
|
|
||||||
|
deleted_count = result.rowcount
|
||||||
|
if deleted_count > 0:
|
||||||
|
logger.info(f"Deleted face feature: id={feature_id}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning(f"Cannot delete non-existent face feature: id={feature_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Database error deleting face feature: {e}")
|
||||||
|
self.session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
def delete_by_person(self, person_id: int) -> int:
|
||||||
|
"""
|
||||||
|
删除指定人员的所有特征记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
person_id: 人员ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
删除的记录数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
stmt = delete(SurFaceFeature).where(SurFaceFeature.person_id == person_id)
|
||||||
|
result = self.session.execute(stmt)
|
||||||
|
|
||||||
|
deleted_count = result.rowcount
|
||||||
|
logger.info(f"Deleted {deleted_count} face features for person_id={person_id}")
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Database error deleting face features by person: {e}")
|
||||||
|
self.session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
# ===== 统计操作 =====
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取特征记录统计信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
统计信息字典
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 总记录数
|
||||||
|
total_stmt = select(func.count()).select_from(SurFaceFeature)
|
||||||
|
total_result = self.session.execute(total_stmt)
|
||||||
|
total_count = total_result.scalar_one()
|
||||||
|
|
||||||
|
# 按状态统计
|
||||||
|
status_stmt = (
|
||||||
|
select(SurFaceFeature.status, func.count())
|
||||||
|
.group_by(SurFaceFeature.status)
|
||||||
|
)
|
||||||
|
status_result = self.session.execute(status_stmt)
|
||||||
|
status_stats = {str(status): count for status, count in status_result}
|
||||||
|
|
||||||
|
# 按特征类型统计
|
||||||
|
type_stmt = (
|
||||||
|
select(SurFaceFeature.feature_type, func.count())
|
||||||
|
.where(SurFaceFeature.feature_type.isnot(None))
|
||||||
|
.group_by(SurFaceFeature.feature_type)
|
||||||
|
)
|
||||||
|
type_result = self.session.execute(type_stmt)
|
||||||
|
type_stats = {str(feature_type): count for feature_type, count in type_result}
|
||||||
|
|
||||||
|
# 平均处理时间(仅计算成功和失败的)
|
||||||
|
time_stmt = (
|
||||||
|
select(
|
||||||
|
func.avg(
|
||||||
|
func.extract('epoch', SurFaceFeature.finish_time - SurFaceFeature.start_time)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
SurFaceFeature.start_time.isnot(None),
|
||||||
|
SurFaceFeature.finish_time.isnot(None),
|
||||||
|
SurFaceFeature.status.in_([FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
time_result = self.session.execute(time_stmt)
|
||||||
|
avg_time = time_result.scalar_one()
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
"total_count": total_count,
|
||||||
|
"by_status": status_stats,
|
||||||
|
"by_feature_type": type_stats,
|
||||||
|
"avg_processing_time": float(avg_time) if avg_time else None
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(f"Generated face feature statistics")
|
||||||
|
return stats
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Database error getting statistics: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# ===== 批量操作 =====
|
||||||
|
|
||||||
|
def mark_for_processing(self, limit: int = 100) -> List[SurFaceFeature]:
|
||||||
|
"""
|
||||||
|
标记待处理的特征记录为计算中
|
||||||
|
|
||||||
|
Args:
|
||||||
|
limit: 最大处理数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
标记为处理中的特征记录列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 查找待处理的记录
|
||||||
|
pending_stmt = (
|
||||||
|
select(SurFaceFeature)
|
||||||
|
.where(SurFaceFeature.status == FeatureStatusEnum.NOT_STARTED)
|
||||||
|
.order_by(SurFaceFeature.created_time)
|
||||||
|
.limit(limit)
|
||||||
|
.with_for_update(skip_locked=True) # 跳过被锁定的行
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.session.execute(pending_stmt)
|
||||||
|
pending_features = list(result.scalars().all())
|
||||||
|
|
||||||
|
# 更新状态
|
||||||
|
feature_ids = [f.id for f in pending_features]
|
||||||
|
if feature_ids:
|
||||||
|
update_stmt = (
|
||||||
|
update(SurFaceFeature)
|
||||||
|
.where(SurFaceFeature.id.in_(feature_ids))
|
||||||
|
.values(
|
||||||
|
status=FeatureStatusEnum.PROCESSING,
|
||||||
|
start_time=datetime.now()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.session.execute(update_stmt)
|
||||||
|
|
||||||
|
logger.info(f"Marked {len(pending_features)} face features for processing")
|
||||||
|
return pending_features
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Database error marking features for processing: {e}")
|
||||||
|
self.session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
def cleanup_old_features(self, days: int = 30) -> int:
|
||||||
|
"""
|
||||||
|
清理旧的特征记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
days: 保留天数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
删除的记录数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cutoff_date = datetime.now() - timedelta(days=days)
|
||||||
|
|
||||||
|
# 只删除已完成(成功或失败)的旧记录
|
||||||
|
stmt = delete(SurFaceFeature).where(
|
||||||
|
and_(
|
||||||
|
SurFaceFeature.created_time < cutoff_date,
|
||||||
|
SurFaceFeature.status.in_([FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.session.execute(stmt)
|
||||||
|
deleted_count = result.rowcount
|
||||||
|
|
||||||
|
logger.info(f"Cleaned up {deleted_count} old face features (older than {days} days)")
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Database error cleaning up old features: {e}")
|
||||||
|
self.session.rollback()
|
||||||
|
raise
|
||||||
86
src/run.py
Normal file
86
src/run.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
FastAPI应用启动脚本
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
import argparse
|
||||||
|
from src.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""
|
||||||
|
主函数:解析命令行参数并启动服务器
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(description="algorithm service")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--host",
|
||||||
|
type=str,
|
||||||
|
default="0.0.0.0",
|
||||||
|
help="监听主机 (默认: 0.0.0.0)"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--port",
|
||||||
|
type=int,
|
||||||
|
default=8000,
|
||||||
|
help="监听端口 (默认: 8000)"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--reload",
|
||||||
|
action="store_true",
|
||||||
|
help="启用热重载 (开发模式)"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--workers",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="工作进程数 (生产模式)"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-level",
|
||||||
|
type=str,
|
||||||
|
default="info",
|
||||||
|
choices=["debug", "info", "warning", "error", "critical"],
|
||||||
|
help="日志级别"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# 根据环境选择配置
|
||||||
|
if settings.DEBUG:
|
||||||
|
print("🔧 开发模式")
|
||||||
|
uvicorn_config = {
|
||||||
|
"host": args.host,
|
||||||
|
"port": args.port,
|
||||||
|
"reload": True,
|
||||||
|
"log_level": "debug",
|
||||||
|
"workers": 1
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
print("🚀 生产模式")
|
||||||
|
uvicorn_config = {
|
||||||
|
"host": args.host,
|
||||||
|
"port": args.port,
|
||||||
|
"reload": False,
|
||||||
|
"log_level": args.log_level,
|
||||||
|
"workers": args.workers
|
||||||
|
}
|
||||||
|
|
||||||
|
# 启动服务器
|
||||||
|
print(f"🌐 启动服务器: http://{args.host}:{args.port}")
|
||||||
|
print(f"📚 API文档: http://{args.host}:{args.port}/docs")
|
||||||
|
print(f"📊 健康检查: http://{args.host}:{args.port}/health")
|
||||||
|
|
||||||
|
uvicorn.run(
|
||||||
|
"src.app:app",
|
||||||
|
**uvicorn_config
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
213
src/schemas/face_feature.py
Normal file
213
src/schemas/face_feature.py
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
"""
|
||||||
|
人脸特征值的Pydantic模型
|
||||||
|
用于数据验证和序列化
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from pydantic import BaseModel, Field, field_validator, ConfigDict, model_validator
|
||||||
|
from enum import IntEnum
|
||||||
|
|
||||||
|
|
||||||
|
# 枚举定义(与数据库模型一致)
|
||||||
|
class FeatureStatus(IntEnum):
|
||||||
|
NOT_STARTED = 0
|
||||||
|
PROCESSING = 1
|
||||||
|
SUCCESS = 2
|
||||||
|
FAILED = 3
|
||||||
|
|
||||||
|
|
||||||
|
# 基础模型
|
||||||
|
class FaceFeatureBase(BaseModel):
|
||||||
|
"""基础模型,包含所有字段"""
|
||||||
|
person_id: int = Field(..., description="人员ID", gt=0)
|
||||||
|
feature_type: Optional[int] = Field(None, description="模型版本", ge=0)
|
||||||
|
feature_data: Optional[bytes] = Field(None, description="特征值(二进制)")
|
||||||
|
pic_id: Optional[str] = Field(None, description="图片ID", max_length=255)
|
||||||
|
status: FeatureStatus = Field(
|
||||||
|
default=FeatureStatus.NOT_STARTED,
|
||||||
|
description="计算状态"
|
||||||
|
)
|
||||||
|
start_time: Optional[datetime] = Field(None, description="计算开始时间")
|
||||||
|
finish_time: Optional[datetime] = Field(None, description="计算结束时间")
|
||||||
|
|
||||||
|
@field_validator('feature_data', mode='before')
|
||||||
|
@classmethod
|
||||||
|
def validate_feature_data(cls, v):
|
||||||
|
"""验证特征数据"""
|
||||||
|
if v is not None and not isinstance(v, bytes):
|
||||||
|
if isinstance(v, str):
|
||||||
|
# 尝试从hex字符串转换
|
||||||
|
try:
|
||||||
|
return bytes.fromhex(v)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError("feature_data must be valid hex string or bytes")
|
||||||
|
else:
|
||||||
|
raise ValueError("feature_data must be bytes or hex string")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
# 创建模型
|
||||||
|
class FaceFeatureCreate(FaceFeatureBase):
|
||||||
|
"""创建特征记录模型"""
|
||||||
|
# 创建时不指定ID和时间
|
||||||
|
model_config = ConfigDict(
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"person_id": 1001,
|
||||||
|
"feature_type": 1,
|
||||||
|
"pic_id": "img_20250101_001",
|
||||||
|
"status": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FaceFeatureUpdate(BaseModel):
|
||||||
|
"""更新特征记录模型"""
|
||||||
|
feature_type: Optional[int] = Field(None, description="模型版本", ge=0)
|
||||||
|
feature_data: Optional[bytes] = Field(None, description="特征值(二进制)")
|
||||||
|
pic_id: Optional[str] = Field(None, description="图片ID", max_length=255)
|
||||||
|
status: Optional[FeatureStatus] = Field(None, description="计算状态")
|
||||||
|
start_time: Optional[datetime] = Field(None, description="计算开始时间")
|
||||||
|
finish_time: Optional[datetime] = Field(None, description="计算结束时间")
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"status": 2,
|
||||||
|
"finish_time": "2024-01-01T12:00:00Z"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 查询参数模型
|
||||||
|
class FaceFeatureQuery(BaseModel):
|
||||||
|
"""特征记录查询参数"""
|
||||||
|
person_id: Optional[int] = Field(None, description="人员ID", gt=0)
|
||||||
|
feature_type: Optional[int] = Field(None, description="模型版本", ge=0)
|
||||||
|
status: Optional[FeatureStatus] = Field(None, description="计算状态")
|
||||||
|
start_date: Optional[datetime] = Field(None, description="开始时间")
|
||||||
|
end_date: Optional[datetime] = Field(None, description="结束时间")
|
||||||
|
has_feature_data: Optional[bool] = Field(None, description="是否有特征数据")
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"person_id": 1001,
|
||||||
|
"status": 2,
|
||||||
|
"start_date": "2024-01-01T00:00:00Z"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 响应模型
|
||||||
|
class FaceFeatureResponse(FaceFeatureBase):
|
||||||
|
"""特征记录响应模型"""
|
||||||
|
id: int
|
||||||
|
created_time: datetime
|
||||||
|
|
||||||
|
# 计算字段(将在验证后设置)
|
||||||
|
status_name: Optional[str] = None
|
||||||
|
is_completed: Optional[bool] = None
|
||||||
|
processing_time: Optional[float] = None
|
||||||
|
has_feature_data: Optional[bool] = None
|
||||||
|
|
||||||
|
@model_validator(mode='after')
|
||||||
|
def set_computed_fields(self):
|
||||||
|
"""设置所有计算字段"""
|
||||||
|
# 状态名称
|
||||||
|
try:
|
||||||
|
self.status_name = FeatureStatus(self.status).name
|
||||||
|
except ValueError:
|
||||||
|
self.status_name = f"未知状态({self.status})"
|
||||||
|
|
||||||
|
# 是否完成
|
||||||
|
self.is_completed = self.status in [FeatureStatus.SUCCESS, FeatureStatus.FAILED]
|
||||||
|
|
||||||
|
# 处理时间
|
||||||
|
if self.start_time and self.finish_time:
|
||||||
|
self.processing_time = (self.finish_time - self.start_time).total_seconds()
|
||||||
|
|
||||||
|
# 是否有特征数据
|
||||||
|
self.has_feature_data = self.feature_data is not None and len(self.feature_data) > 0
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
from_attributes=True,
|
||||||
|
populate_by_name=True,
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"id": 1,
|
||||||
|
"person_id": 1001,
|
||||||
|
"feature_type": 1,
|
||||||
|
"status": 2,
|
||||||
|
"status_name": "SUCCESS",
|
||||||
|
"created_time": "2024-01-01T10:00:00Z",
|
||||||
|
"pic_id": "img_20250101_001",
|
||||||
|
"start_time": "2024-01-01T10:00:00Z",
|
||||||
|
"finish_time": "2024-01-01T10:00:05Z",
|
||||||
|
"is_completed": True,
|
||||||
|
"processing_time": 5.0,
|
||||||
|
"has_feature_data": True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FaceFeatureListResponse(BaseModel):
|
||||||
|
"""特征记录列表响应"""
|
||||||
|
total: int = Field(..., description="总记录数")
|
||||||
|
items: List[FaceFeatureResponse] = Field(..., description="记录列表")
|
||||||
|
page: Optional[int] = Field(None, description="当前页码")
|
||||||
|
page_size: Optional[int] = Field(None, description="每页数量")
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"total": 100,
|
||||||
|
"page": 1,
|
||||||
|
"page_size": 20,
|
||||||
|
"items": []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 批量操作模型
|
||||||
|
class BatchFaceFeatureCreate(BaseModel):
|
||||||
|
"""批量创建特征记录"""
|
||||||
|
items: List[FaceFeatureCreate] = Field(..., description="特征记录列表", min_items=1, max_items=1000)
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"items": [
|
||||||
|
{"person_id": 1001, "feature_type": 1},
|
||||||
|
{"person_id": 1002, "feature_type": 1}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FaceFeatureStatsResponse(BaseModel):
|
||||||
|
"""特征记录统计响应"""
|
||||||
|
total_count: int = Field(..., description="总记录数")
|
||||||
|
by_status: Dict[str, int] = Field(..., description="按状态统计")
|
||||||
|
by_feature_type: Dict[str, int] = Field(..., description="按特征类型统计")
|
||||||
|
avg_processing_time: Optional[float] = Field(None, description="平均处理时间(秒)")
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"total_count": 1000,
|
||||||
|
"by_status": {"SUCCESS": 800, "PROCESSING": 100, "FAILED": 100},
|
||||||
|
"by_feature_type": {"1": 500, "2": 500},
|
||||||
|
"avg_processing_time": 5.2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
562
src/services/face_feature_service.py
Normal file
562
src/services/face_feature_service.py
Normal file
@@ -0,0 +1,562 @@
|
|||||||
|
"""
|
||||||
|
人脸特征业务逻辑服务层
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from src.repositories.face_feature_repository import FaceFeatureRepository
|
||||||
|
from src.schemas.face_feature import (
|
||||||
|
FaceFeatureCreate,
|
||||||
|
FaceFeatureUpdate,
|
||||||
|
FaceFeatureQuery,
|
||||||
|
FaceFeatureResponse,
|
||||||
|
FaceFeatureListResponse,
|
||||||
|
FaceFeatureStatsResponse,
|
||||||
|
BatchFaceFeatureCreate,
|
||||||
|
FeatureStatus
|
||||||
|
)
|
||||||
|
from src.models.face_feature import FeatureStatusEnum
|
||||||
|
from src.utils.logger import setup_logger
|
||||||
|
|
||||||
|
logger = setup_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FaceFeatureService:
|
||||||
|
"""人脸特征业务服务"""
|
||||||
|
|
||||||
|
def __init__(self, repository: FaceFeatureRepository):
|
||||||
|
"""
|
||||||
|
初始化服务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repository: 特征仓库实例
|
||||||
|
"""
|
||||||
|
self.repository = repository
|
||||||
|
|
||||||
|
# ===== CRUD操作 =====
|
||||||
|
|
||||||
|
def create_feature(self, feature_data: FaceFeatureCreate) -> FaceFeatureResponse:
|
||||||
|
"""
|
||||||
|
创建特征记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_data: 特征数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
创建的特征记录响应
|
||||||
|
"""
|
||||||
|
logger.info(f"Creating face feature for person_id={feature_data.person_id}")
|
||||||
|
|
||||||
|
# 业务逻辑验证
|
||||||
|
self._validate_feature_data(feature_data)
|
||||||
|
|
||||||
|
# 检查是否已存在相同记录
|
||||||
|
if feature_data.feature_type is not None:
|
||||||
|
existing = self.repository.get_by_person_and_type(
|
||||||
|
feature_data.person_id,
|
||||||
|
feature_data.feature_type
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
raise ValueError(
|
||||||
|
f"Feature record already exists for person_id={feature_data.person_id}, "
|
||||||
|
f"feature_type={feature_data.feature_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建记录
|
||||||
|
feature = self.repository.create(feature_data)
|
||||||
|
|
||||||
|
# 转换为响应模型
|
||||||
|
return FaceFeatureResponse.model_validate(feature)
|
||||||
|
|
||||||
|
def create_features_batch(self, batch_data: BatchFaceFeatureCreate) -> List[FaceFeatureResponse]:
|
||||||
|
"""
|
||||||
|
批量创建特征记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_data: 批量特征数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
创建的特征记录响应列表
|
||||||
|
"""
|
||||||
|
logger.info(f"Creating {len(batch_data.items)} face features in batch")
|
||||||
|
|
||||||
|
# 验证所有数据
|
||||||
|
for feature_data in batch_data.items:
|
||||||
|
self._validate_feature_data(feature_data)
|
||||||
|
|
||||||
|
# 批量创建
|
||||||
|
features = self.repository.create_batch(batch_data.items)
|
||||||
|
|
||||||
|
# 转换为响应模型
|
||||||
|
return [FaceFeatureResponse.model_validate(f) for f in features]
|
||||||
|
|
||||||
|
def get_feature(self, feature_id: int) -> Optional[FaceFeatureResponse]:
|
||||||
|
"""
|
||||||
|
获取特征记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_id: 特征记录ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
特征记录响应或None
|
||||||
|
"""
|
||||||
|
logger.debug(f"Getting face feature: id={feature_id}")
|
||||||
|
|
||||||
|
feature = self.repository.get_by_id(feature_id)
|
||||||
|
if not feature:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return FaceFeatureResponse.model_validate(feature)
|
||||||
|
|
||||||
|
def get_feature_by_person_and_type(
|
||||||
|
self,
|
||||||
|
person_id: int,
|
||||||
|
feature_type: int
|
||||||
|
) -> Optional[FaceFeatureResponse]:
|
||||||
|
"""
|
||||||
|
根据人员ID和特征类型获取特征记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
person_id: 人员ID
|
||||||
|
feature_type: 特征类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
特征记录响应或None
|
||||||
|
"""
|
||||||
|
logger.debug(f"Getting face feature: person_id={person_id}, feature_type={feature_type}")
|
||||||
|
|
||||||
|
feature = self.repository.get_by_person_and_type(person_id, feature_type)
|
||||||
|
if not feature:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return FaceFeatureResponse.model_validate(feature)
|
||||||
|
|
||||||
|
def list_features_by_person(
|
||||||
|
self,
|
||||||
|
person_id: int,
|
||||||
|
limit: int = 100
|
||||||
|
) -> List[FaceFeatureResponse]:
|
||||||
|
"""
|
||||||
|
获取人员的特征记录列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
person_id: 人员ID
|
||||||
|
limit: 返回数量限制
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
特征记录响应列表
|
||||||
|
"""
|
||||||
|
logger.debug(f"Listing face features for person_id={person_id}")
|
||||||
|
|
||||||
|
features = self.repository.get_by_person(person_id, limit)
|
||||||
|
return [FaceFeatureResponse.model_validate(f) for f in features]
|
||||||
|
|
||||||
|
def query_features(
|
||||||
|
self,
|
||||||
|
query: FaceFeatureQuery,
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 20,
|
||||||
|
order_by: str = "created_time",
|
||||||
|
desc_order: bool = True
|
||||||
|
) -> FaceFeatureListResponse:
|
||||||
|
"""
|
||||||
|
查询特征记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 查询条件
|
||||||
|
page: 页码
|
||||||
|
page_size: 每页数量
|
||||||
|
order_by: 排序字段
|
||||||
|
desc_order: 是否降序
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
特征记录列表响应
|
||||||
|
"""
|
||||||
|
logger.debug(f"Querying face features with filters: {query.model_dump(exclude_unset=True)}")
|
||||||
|
|
||||||
|
features, total = self.repository.query_features(
|
||||||
|
query, page, page_size, order_by, desc_order
|
||||||
|
)
|
||||||
|
|
||||||
|
items = [FaceFeatureResponse.model_validate(f) for f in features]
|
||||||
|
|
||||||
|
return FaceFeatureListResponse(
|
||||||
|
total=total,
|
||||||
|
items=items,
|
||||||
|
page=page,
|
||||||
|
page_size=page_size
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_feature(
|
||||||
|
self,
|
||||||
|
feature_id: int,
|
||||||
|
update_data: FaceFeatureUpdate
|
||||||
|
) -> Optional[FaceFeatureResponse]:
|
||||||
|
"""
|
||||||
|
更新特征记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_id: 特征记录ID
|
||||||
|
update_data: 更新数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
更新后的特征记录响应或None
|
||||||
|
"""
|
||||||
|
logger.info(f"Updating face feature: id={feature_id}")
|
||||||
|
|
||||||
|
# 业务逻辑验证
|
||||||
|
if update_data.status is not None:
|
||||||
|
self._validate_status_transition(feature_id, update_data.status)
|
||||||
|
|
||||||
|
# 更新记录
|
||||||
|
feature = self.repository.update(feature_id, update_data)
|
||||||
|
if not feature:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return FaceFeatureResponse.model_validate(feature)
|
||||||
|
|
||||||
|
def update_feature_data(self, feature_id: int, feature_data: bytes) -> bool:
|
||||||
|
"""
|
||||||
|
更新特征数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_id: 特征记录ID
|
||||||
|
feature_data: 特征数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功更新
|
||||||
|
"""
|
||||||
|
logger.info(f"Updating feature data for face feature: id={feature_id}")
|
||||||
|
|
||||||
|
# 验证特征数据
|
||||||
|
if not feature_data or len(feature_data) == 0:
|
||||||
|
raise ValueError("Feature data cannot be empty")
|
||||||
|
|
||||||
|
if len(feature_data) > 1024 * 1024: # 1MB限制
|
||||||
|
raise ValueError("Feature data is too large (max 1MB)")
|
||||||
|
|
||||||
|
return self.repository.update_feature_data(feature_id, feature_data)
|
||||||
|
|
||||||
|
def update_status(
|
||||||
|
self,
|
||||||
|
feature_id: int,
|
||||||
|
status: FeatureStatus,
|
||||||
|
start_time: Optional[datetime] = None,
|
||||||
|
finish_time: Optional[datetime] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
更新计算状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_id: 特征记录ID
|
||||||
|
status: 新状态
|
||||||
|
start_time: 开始时间
|
||||||
|
finish_time: 结束时间
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功更新
|
||||||
|
"""
|
||||||
|
logger.info(f"Updating status to {status} for face feature: id={feature_id}")
|
||||||
|
|
||||||
|
# 验证状态转换
|
||||||
|
self._validate_status_transition(feature_id, status)
|
||||||
|
|
||||||
|
return self.repository.update_status(feature_id, status, start_time, finish_time)
|
||||||
|
|
||||||
|
def delete_feature(self, feature_id: int) -> bool:
|
||||||
|
"""
|
||||||
|
删除特征记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_id: 特征记录ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功删除
|
||||||
|
"""
|
||||||
|
logger.info(f"Deleting face feature: id={feature_id}")
|
||||||
|
|
||||||
|
return self.repository.delete(feature_id)
|
||||||
|
|
||||||
|
def delete_features_by_person(self, person_id: int) -> int:
|
||||||
|
"""
|
||||||
|
删除指定人员的所有特征记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
person_id: 人员ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
删除的记录数
|
||||||
|
"""
|
||||||
|
logger.info(f"Deleting all face features for person_id={person_id}")
|
||||||
|
|
||||||
|
return self.repository.delete_by_person(person_id)
|
||||||
|
|
||||||
|
# ===== 业务操作 =====
|
||||||
|
|
||||||
|
def start_processing(self, feature_id: int) -> bool:
|
||||||
|
"""
|
||||||
|
开始处理特征计算
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_id: 特征记录ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功开始
|
||||||
|
"""
|
||||||
|
logger.info(f"Starting processing for face feature: id={feature_id}")
|
||||||
|
|
||||||
|
# 获取当前特征
|
||||||
|
feature = self.repository.get_by_id(feature_id)
|
||||||
|
if not feature:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 验证状态
|
||||||
|
if feature.status != FeatureStatusEnum.NOT_STARTED:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot start processing for feature with status {feature.status_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新状态
|
||||||
|
return self.repository.update_status(
|
||||||
|
feature_id,
|
||||||
|
FeatureStatus.PROCESSING,
|
||||||
|
start_time=datetime.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
def finish_processing(self, feature_id: int, success: bool = True) -> bool:
|
||||||
|
"""
|
||||||
|
完成特征计算
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_id: 特征记录ID
|
||||||
|
success: 是否成功
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功完成
|
||||||
|
"""
|
||||||
|
logger.info(f"Finishing processing for face feature: id={feature_id}, success={success}")
|
||||||
|
|
||||||
|
# 获取当前特征
|
||||||
|
feature = self.repository.get_by_id(feature_id)
|
||||||
|
if not feature:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 验证状态
|
||||||
|
if feature.status != FeatureStatusEnum.PROCESSING:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot finish processing for feature with status {feature.status_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新状态
|
||||||
|
status = FeatureStatus.SUCCESS if success else FeatureStatus.FAILED
|
||||||
|
return self.repository.update_status(
|
||||||
|
feature_id,
|
||||||
|
status,
|
||||||
|
finish_time=datetime.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
def process_pending_features(self, limit: int = 100) -> List[FaceFeatureResponse]:
|
||||||
|
"""
|
||||||
|
处理待计算的特征记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
limit: 最大处理数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理中的特征记录列表
|
||||||
|
"""
|
||||||
|
logger.info(f"Processing up to {limit} pending face features")
|
||||||
|
|
||||||
|
features = self.repository.mark_for_processing(limit)
|
||||||
|
return [FaceFeatureResponse.model_validate(f) for f in features]
|
||||||
|
|
||||||
|
# ===== 统计和分析 =====
|
||||||
|
|
||||||
|
def get_statistics(self) -> FaceFeatureStatsResponse:
|
||||||
|
"""
|
||||||
|
获取特征记录统计信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
统计信息响应
|
||||||
|
"""
|
||||||
|
logger.debug("Getting face feature statistics")
|
||||||
|
|
||||||
|
stats = self.repository.get_stats()
|
||||||
|
|
||||||
|
# 转换状态枚举名称
|
||||||
|
by_status = {}
|
||||||
|
for status_value, count in stats["by_status"].items():
|
||||||
|
try:
|
||||||
|
status_name = FeatureStatusEnum(int(status_value)).name
|
||||||
|
by_status[status_name] = count
|
||||||
|
except ValueError:
|
||||||
|
by_status[f"未知({status_value})"] = count
|
||||||
|
|
||||||
|
return FaceFeatureStatsResponse(
|
||||||
|
total_count=stats["total_count"],
|
||||||
|
by_status=by_status,
|
||||||
|
by_feature_type=stats["by_feature_type"],
|
||||||
|
avg_processing_time=stats["avg_processing_time"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_person_statistics(self, person_id: int) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取人员的特征统计信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
person_id: 人员ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
人员统计信息
|
||||||
|
"""
|
||||||
|
logger.debug(f"Getting statistics for person_id={person_id}")
|
||||||
|
|
||||||
|
# 获取该人员的所有特征
|
||||||
|
features = self.repository.get_by_person(person_id, limit=1000)
|
||||||
|
|
||||||
|
if not features:
|
||||||
|
return {
|
||||||
|
"person_id": person_id,
|
||||||
|
"total_features": 0,
|
||||||
|
"status_summary": {},
|
||||||
|
"feature_types": []
|
||||||
|
}
|
||||||
|
|
||||||
|
# 统计信息
|
||||||
|
status_summary = {}
|
||||||
|
feature_types = set()
|
||||||
|
successful_features = []
|
||||||
|
|
||||||
|
for feature in features:
|
||||||
|
# 状态统计
|
||||||
|
status_name = feature.status_name
|
||||||
|
status_summary[status_name] = status_summary.get(status_name, 0) + 1
|
||||||
|
|
||||||
|
# 特征类型
|
||||||
|
if feature.feature_type is not None:
|
||||||
|
feature_types.add(feature.feature_type)
|
||||||
|
|
||||||
|
# 成功完成的特征
|
||||||
|
if feature.status == FeatureStatusEnum.SUCCESS:
|
||||||
|
successful_features.append(feature)
|
||||||
|
|
||||||
|
# 计算平均处理时间(仅成功记录)
|
||||||
|
total_time = 0
|
||||||
|
valid_count = 0
|
||||||
|
|
||||||
|
for feature in successful_features:
|
||||||
|
if feature.processing_time:
|
||||||
|
total_time += feature.processing_time
|
||||||
|
valid_count += 1
|
||||||
|
|
||||||
|
avg_time = total_time / valid_count if valid_count > 0 else None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"person_id": person_id,
|
||||||
|
"total_features": len(features),
|
||||||
|
"status_summary": status_summary,
|
||||||
|
"feature_types": sorted(list(feature_types)),
|
||||||
|
"avg_processing_time": avg_time,
|
||||||
|
"successful_count": len(successful_features)
|
||||||
|
}
|
||||||
|
|
||||||
|
def cleanup_old_records(self, days: int = 30) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
清理旧的特征记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
days: 保留天数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
清理结果
|
||||||
|
"""
|
||||||
|
logger.info(f"Cleaning up face features older than {days} days")
|
||||||
|
|
||||||
|
# 先获取清理前的统计
|
||||||
|
before_stats = self.repository.get_stats()
|
||||||
|
|
||||||
|
# 执行清理
|
||||||
|
deleted_count = self.repository.cleanup_old_features(days)
|
||||||
|
|
||||||
|
# 获取清理后的统计
|
||||||
|
after_stats = self.repository.get_stats()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"days_retained": days,
|
||||||
|
"deleted_count": deleted_count,
|
||||||
|
"before_total": before_stats["total_count"],
|
||||||
|
"after_total": after_stats["total_count"],
|
||||||
|
"reduction_percentage": (
|
||||||
|
(before_stats["total_count"] - after_stats["total_count"]) /
|
||||||
|
before_stats["total_count"] * 100
|
||||||
|
if before_stats["total_count"] > 0 else 0
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
# ===== 私有方法 =====
|
||||||
|
|
||||||
|
def _validate_feature_data(self, feature_data: FaceFeatureCreate) -> None:
|
||||||
|
"""
|
||||||
|
验证特征数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_data: 特征数据
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果数据无效
|
||||||
|
"""
|
||||||
|
# 验证人员ID
|
||||||
|
if feature_data.person_id <= 0:
|
||||||
|
raise ValueError("person_id must be greater than 0")
|
||||||
|
|
||||||
|
# 验证特征类型
|
||||||
|
if feature_data.feature_type is not None and feature_data.feature_type < 0:
|
||||||
|
raise ValueError("feature_type must be non-negative")
|
||||||
|
|
||||||
|
# 验证特征数据大小
|
||||||
|
if feature_data.feature_data and len(feature_data.feature_data) > 1024 * 1024:
|
||||||
|
raise ValueError("feature_data is too large (max 1MB)")
|
||||||
|
|
||||||
|
# 验证状态
|
||||||
|
if feature_data.status:
|
||||||
|
try:
|
||||||
|
FeatureStatus(feature_data.status)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f"Invalid status value: {feature_data.status}")
|
||||||
|
|
||||||
|
def _validate_status_transition(self, feature_id: int, new_status: FeatureStatus) -> None:
|
||||||
|
"""
|
||||||
|
验证状态转换是否有效
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_id: 特征记录ID
|
||||||
|
new_status: 新状态
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果状态转换无效
|
||||||
|
"""
|
||||||
|
# 获取当前特征
|
||||||
|
feature = self.repository.get_by_id(feature_id)
|
||||||
|
if not feature:
|
||||||
|
return
|
||||||
|
|
||||||
|
current_status = FeatureStatusEnum(feature.status)
|
||||||
|
|
||||||
|
# 定义允许的状态转换
|
||||||
|
allowed_transitions = {
|
||||||
|
FeatureStatusEnum.NOT_STARTED: [FeatureStatusEnum.PROCESSING],
|
||||||
|
FeatureStatusEnum.PROCESSING: [FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED],
|
||||||
|
FeatureStatusEnum.SUCCESS: [],
|
||||||
|
FeatureStatusEnum.FAILED: [FeatureStatusEnum.PROCESSING] # 允许重试
|
||||||
|
}
|
||||||
|
|
||||||
|
new_status_enum = FeatureStatusEnum(new_status.value if isinstance(new_status, FeatureStatus) else new_status)
|
||||||
|
|
||||||
|
# 检查转换是否允许
|
||||||
|
if new_status_enum not in allowed_transitions.get(current_status, []):
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot transition from {current_status.name} to {new_status_enum.name}"
|
||||||
|
)
|
||||||
88
src/utils/logger.py
Normal file
88
src/utils/logger.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
"""
|
||||||
|
日志配置模块
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from typing import Optional
|
||||||
|
from logging.handlers import RotatingFileHandler
|
||||||
|
|
||||||
|
from src.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logger(
|
||||||
|
name: str,
|
||||||
|
level: Optional[str] = None,
|
||||||
|
log_file: Optional[str] = None
|
||||||
|
) -> logging.Logger:
|
||||||
|
"""
|
||||||
|
配置和获取logger
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: logger名称
|
||||||
|
level: 日志级别
|
||||||
|
log_file: 日志文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
配置好的logger实例
|
||||||
|
"""
|
||||||
|
# 获取日志级别
|
||||||
|
if level is None:
|
||||||
|
level = settings.LOG_LEVEL
|
||||||
|
|
||||||
|
log_level = getattr(logging, level.upper(), logging.INFO)
|
||||||
|
|
||||||
|
# 创建logger
|
||||||
|
logger = logging.getLogger(name)
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
|
# 避免重复添加handler
|
||||||
|
if logger.handlers:
|
||||||
|
return logger
|
||||||
|
|
||||||
|
# 创建formatter
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
datefmt='%Y-%m-%d %H:%M:%S'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 控制台handler
|
||||||
|
console_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
console_handler.setLevel(log_level)
|
||||||
|
console_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
# 文件handler(如果配置了日志文件)
|
||||||
|
if log_file or settings.LOG_FILE:
|
||||||
|
file_path = log_file or settings.LOG_FILE
|
||||||
|
try:
|
||||||
|
file_handler = RotatingFileHandler(
|
||||||
|
file_path,
|
||||||
|
maxBytes=10 * 1024 * 1024, # 10MB
|
||||||
|
backupCount=5,
|
||||||
|
encoding='utf-8'
|
||||||
|
)
|
||||||
|
file_handler.setLevel(log_level)
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to create log file handler: {e}")
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
# 创建根logger
|
||||||
|
root_logger = setup_logger("sur_face_feature")
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name: str) -> logging.Logger:
|
||||||
|
"""
|
||||||
|
获取指定名称的logger
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: logger名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
logger实例
|
||||||
|
"""
|
||||||
|
return setup_logger(name)
|
||||||
Reference in New Issue
Block a user