diff --git a/src/api/dependencies.py b/src/api/dependencies.py new file mode 100644 index 0000000..b457472 --- /dev/null +++ b/src/api/dependencies.py @@ -0,0 +1,67 @@ +""" +FastAPI依赖注入模块 +""" + +from typing import Generator, Optional +from fastapi import Depends, HTTPException, status +from sqlalchemy.orm import Session + +from src.database.connection import db_manager, get_db +from src.repositories.face_feature_repository import FaceFeatureRepository +from src.services.face_feature_service import FaceFeatureService + + +def get_face_feature_repository( + db: Session = Depends(get_db) +) -> FaceFeatureRepository: + """ + 获取人脸特征仓库依赖 + + Args: + db: 数据库会话 + + Returns: + FaceFeatureRepository实例 + """ + return FaceFeatureRepository(db) + + +def get_face_feature_service( + repository: FaceFeatureRepository = Depends(get_face_feature_repository) +) -> FaceFeatureService: + """ + 获取人脸特征服务依赖 + + Args: + repository: 人脸特征仓库 + + Returns: + FaceFeatureService实例 + """ + return FaceFeatureService(repository) + + +def get_face_feature_by_id( + feature_id: int, + service: FaceFeatureService = Depends(get_face_feature_service) +): + """ + 根据ID获取人脸特征记录的依赖 + + Args: + feature_id: 特征记录ID + service: 人脸特征服务 + + Returns: + 人脸特征记录 + + Raises: + HTTPException: 如果记录不存在 + """ + feature = service.get_feature(feature_id) + if not feature: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"特征记录不存在 (ID: {feature_id})" + ) + return feature \ No newline at end of file diff --git a/src/api/errors.py b/src/api/errors.py new file mode 100644 index 0000000..ee131c4 --- /dev/null +++ b/src/api/errors.py @@ -0,0 +1,152 @@ +""" +API错误处理模块 +""" + +from typing import Any, Dict, Optional +from fastapi import HTTPException, status +from fastapi.responses import JSONResponse +from fastapi.exceptions import RequestValidationError +from fastapi.requests import Request + + +class APIError(HTTPException): + """API自定义错误""" + + def __init__( + self, + status_code: int = 
status.HTTP_400_BAD_REQUEST, + detail: Any = None, + headers: Optional[Dict[str, str]] = None, + error_code: Optional[str] = None, + ): + super().__init__(status_code=status_code, detail=detail, headers=headers) + self.error_code = error_code or f"ERR_{status_code}" + + +class FaceFeatureProcessingError(APIError): + """人脸特征处理错误""" + + def __init__( + self, + detail: str = "人脸特征处理失败", + feature_id: Optional[int] = None, + ): + if feature_id: + detail = f"人脸特征处理失败 (特征ID: {feature_id}): {detail}" + super().__init__( + status_code=status.HTTP_400_BAD_REQUEST, + detail=detail, + error_code="FACE_FEATURE_PROCESSING_ERROR" + ) + + +class FeatureNotFoundError(APIError): + """特征记录不存在错误""" + + def __init__(self, feature_id: int): + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"特征记录不存在 (ID: {feature_id})", + error_code="FEATURE_NOT_FOUND" + ) + + +class DuplicateFeatureError(APIError): + """重复特征记录错误""" + + def __init__(self, person_id: int, feature_type: int): + super().__init__( + status_code=status.HTTP_409_CONFLICT, + detail=f"特征记录已存在 (人员ID: {person_id}, 特征类型: {feature_type})", + error_code="DUPLICATE_FEATURE" + ) + + +async def validation_exception_handler( + request: Request, + exc: RequestValidationError +) -> JSONResponse: + """ + 请求验证异常处理器 + + Args: + request: 请求对象 + exc: 验证异常 + + Returns: + JSON响应 + """ + errors = [] + for error in exc.errors(): + field = ".".join(str(loc) for loc in error["loc"] if loc != "body") + errors.append({ + "field": field or "body", + "message": error["msg"], + "type": error["type"] + }) + + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + content={ + "error": { + "code": "VALIDATION_ERROR", + "message": "请求参数验证失败", + "details": errors + } + } + ) + + +async def api_error_handler( + request: Request, + exc: APIError +) -> JSONResponse: + """ + API错误处理器 + + Args: + request: 请求对象 + exc: API错误 + + Returns: + JSON响应 + """ + return JSONResponse( + status_code=exc.status_code, + content={ + 
"error": { + "code": exc.error_code, + "message": exc.detail + } + } + ) + + +async def generic_exception_handler( + request: Request, + exc: Exception +) -> JSONResponse: + """ + 通用异常处理器 + + Args: + request: 请求对象 + exc: 异常 + + Returns: + JSON响应 + """ + # 记录异常到日志 + import logging + logger = logging.getLogger(__name__) + logger.error(f"未处理的异常: {exc}", exc_info=True) + + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={ + "error": { + "code": "INTERNAL_SERVER_ERROR", + "message": "服务器内部错误" + } + } + ) \ No newline at end of file diff --git a/src/api/routes/face_features.py b/src/api/routes/face_features.py new file mode 100644 index 0000000..f7f6401 --- /dev/null +++ b/src/api/routes/face_features.py @@ -0,0 +1,417 @@ +""" +人脸特征API路由 +""" + +from typing import List, Optional +from datetime import datetime +from fastapi import APIRouter, Depends, HTTPException, status, Query, BackgroundTasks +from fastapi.responses import JSONResponse + +from src.schemas.face_feature import ( + FaceFeatureCreate, + FaceFeatureUpdate, + FaceFeatureQuery, + FaceFeatureResponse, + FaceFeatureListResponse, + FaceFeatureStatsResponse, + BatchFaceFeatureCreate, + FeatureStatus +) +from src.api.dependencies import ( + get_face_feature_service, + get_face_feature_by_id +) +from src.services.face_feature_service import FaceFeatureService +from src.api.errors import ( + FaceFeatureProcessingError, + FeatureNotFoundError, + DuplicateFeatureError +) +from src.config import settings + +# 创建路由器 +router = APIRouter( + prefix="/face-features", + tags=["人脸特征管理"], + responses={ + 404: {"description": "资源不存在"}, + 400: {"description": "请求参数错误"}, + 500: {"description": "服务器内部错误"} + } +) + + +@router.post( + "/", + response_model=FaceFeatureResponse, + status_code=status.HTTP_201_CREATED, + summary="创建人脸特征记录", + description="创建新的人脸特征记录" +) +async def create_face_feature( + feature_data: FaceFeatureCreate, + service: FaceFeatureService = Depends(get_face_feature_service) 
+): + """ + 创建人脸特征记录 + + - **person_id**: 人员ID (必须,大于0) + - **feature_type**: 特征类型 (可选,大于等于0) + - **pic_id**: 图片ID (可选) + - **status**: 计算状态 (默认: NOT_STARTED) + - **feature_data**: 特征数据 (可选,二进制) + """ + try: + return service.create_feature(feature_data) + except ValueError as e: + if "already exists" in str(e): + # 解析错误信息中的person_id和feature_type + raise DuplicateFeatureError( + person_id=feature_data.person_id, + feature_type=feature_data.feature_type or 0 + ) + raise FaceFeatureProcessingError(detail=str(e)) + + +@router.get( + "/{feature_id}", + response_model=FaceFeatureResponse, + summary="获取人脸特征记录", + description="根据ID获取人脸特征记录" +) +async def get_face_feature( + feature: FaceFeatureResponse = Depends(get_face_feature_by_id) +): + """ + 根据ID获取人脸特征记录 + + - **feature_id**: 特征记录ID (路径参数) + """ + return feature + + +@router.get( + "/", + response_model=FaceFeatureListResponse, + summary="查询人脸特征记录", + description="查询人脸特征记录列表,支持分页和过滤" +) +async def list_face_features( + person_id: Optional[int] = Query(None, description="人员ID", gt=0), + feature_type: Optional[int] = Query(None, description="特征类型", ge=0), + status: Optional[FeatureStatus] = Query(None, description="计算状态"), + start_date: Optional[datetime] = Query(None, description="开始时间"), + end_date: Optional[datetime] = Query(None, description="结束时间"), + has_feature_data: Optional[bool] = Query(None, description="是否有特征数据"), + page: int = Query(1, description="页码", ge=1), + page_size: int = Query(20, description="每页数量", ge=1, le=100), + service: FaceFeatureService = Depends(get_face_feature_service) +): + """ + 查询人脸特征记录 + + - **person_id**: 按人员ID过滤 (可选) + - **feature_type**: 按特征类型过滤 (可选) + - **status**: 按计算状态过滤 (可选) + - **start_date**: 开始时间过滤 (可选) + - **end_date**: 结束时间过滤 (可选) + - **has_feature_data**: 是否有特征数据过滤 (可选) + - **page**: 页码 (默认: 1) + - **page_size**: 每页数量 (默认: 20, 最大: 100) + """ + # 构建查询参数 + query = FaceFeatureQuery( + person_id=person_id, + feature_type=feature_type, + status=status, + 
start_date=start_date, + end_date=end_date, + has_feature_data=has_feature_data + ) + + return service.query_features( + query=query, + page=page, + page_size=page_size, + order_by="created_time", + desc_order=True + ) + + +@router.put( + "/{feature_id}", + response_model=FaceFeatureResponse, + summary="更新人脸特征记录", + description="更新指定ID的人脸特征记录" +) +async def update_face_feature( + feature_id: int, + update_data: FaceFeatureUpdate, + service: FaceFeatureService = Depends(get_face_feature_service) +): + """ + 更新人脸特征记录 + + - **feature_id**: 特征记录ID (路径参数) + - **update_data**: 更新数据 (请求体) + """ + try: + result = service.update_feature(feature_id, update_data) + if not result: + raise FeatureNotFoundError(feature_id) + return result + except ValueError as e: + raise FaceFeatureProcessingError(detail=str(e), feature_id=feature_id) + + +@router.delete( + "/{feature_id}", + status_code=status.HTTP_204_NO_CONTENT, + summary="删除人脸特征记录", + description="删除指定ID的人脸特征记录" +) +async def delete_face_feature( + feature_id: int, + service: FaceFeatureService = Depends(get_face_feature_service) +): + """ + 删除人脸特征记录 + + - **feature_id**: 特征记录ID (路径参数) + """ + success = service.delete_feature(feature_id) + if not success: + raise FeatureNotFoundError(feature_id) + + return JSONResponse( + status_code=status.HTTP_204_NO_CONTENT, + content=None + ) + + +@router.post( + "/{feature_id}/start-processing", + response_model=FaceFeatureResponse, + summary="开始处理人脸特征", + description="开始计算指定ID的人脸特征值" +) +async def start_face_feature_processing( + feature_id: int, + background_tasks: BackgroundTasks, + service: FaceFeatureService = Depends(get_face_feature_service) +): + """ + 开始处理人脸特征计算 + + - **feature_id**: 特征记录ID (路径参数) + + 注意:这是一个异步处理接口,会立即返回开始状态, + 实际特征计算可能在后台进行。 + """ + try: + # 先获取特征记录 + feature = service.get_feature(feature_id) + if not feature: + raise FeatureNotFoundError(feature_id) + + # 检查是否可以开始处理 + if feature.status != FeatureStatus.NOT_STARTED: + raise FaceFeatureProcessingError( + 
detail=f"特征记录状态为 {feature.status_name},无法开始处理", + feature_id=feature_id + ) + + # 开始处理 + success = service.start_processing(feature_id) + if not success: + raise FaceFeatureProcessingError( + detail="开始处理失败", + feature_id=feature_id + ) + + # 异步任务:模拟特征计算过程 + # 在实际应用中,这里应该调用实际的特征计算服务 + background_tasks.add_task( + simulate_feature_processing, + feature_id=feature_id, + service=service + ) + + # 返回更新后的特征记录 + return service.get_feature(feature_id) + + except ValueError as e: + raise FaceFeatureProcessingError(detail=str(e), feature_id=feature_id) + + +@router.post( + "/{feature_id}/finish-processing", + response_model=FaceFeatureResponse, + summary="完成人脸特征处理", + description="完成指定ID的人脸特征值计算" +) +async def finish_face_feature_processing( + feature_id: int, + success: bool = Query(True, description="是否成功完成"), + service: FaceFeatureService = Depends(get_face_feature_service) +): + """ + 完成人脸特征计算 + + - **feature_id**: 特征记录ID (路径参数) + - **success**: 是否成功完成 (查询参数,默认: true) + """ + try: + # 检查特征记录 + feature = service.get_feature(feature_id) + if not feature: + raise FeatureNotFoundError(feature_id) + + # 检查是否可以完成处理 + if feature.status != FeatureStatus.PROCESSING: + raise FaceFeatureProcessingError( + detail=f"特征记录状态为 {feature.status_name},无法完成处理", + feature_id=feature_id + ) + + # 完成处理 + finish_success = service.finish_processing(feature_id, success) + if not finish_success: + raise FaceFeatureProcessingError( + detail="完成处理失败", + feature_id=feature_id + ) + + return service.get_feature(feature_id) + + except ValueError as e: + raise FaceFeatureProcessingError(detail=str(e), feature_id=feature_id) + + +@router.post( + "/batch", + response_model=List[FaceFeatureResponse], + status_code=status.HTTP_201_CREATED, + summary="批量创建人脸特征记录", + description="批量创建多个人脸特征记录" +) +async def batch_create_face_features( + batch_data: BatchFaceFeatureCreate, + service: FaceFeatureService = Depends(get_face_feature_service) +): + """ + 批量创建人脸特征记录 + + - **items**: 特征记录列表 (必须,1-1000条) + """ + try: 
+ return service.create_features_batch(batch_data) + except ValueError as e: + raise FaceFeatureProcessingError(detail=str(e)) + + +@router.get( + "/person/{person_id}", + response_model=List[FaceFeatureResponse], + summary="获取人员的人脸特征记录", + description="根据人员ID获取所有相关的人脸特征记录" +) +async def get_face_features_by_person( + person_id: int, + limit: int = Query(100, description="返回数量限制", ge=1, le=1000), + service: FaceFeatureService = Depends(get_face_feature_service) +): + """ + 获取人员的人脸特征记录 + + - **person_id**: 人员ID (路径参数) + - **limit**: 返回数量限制 (查询参数,默认: 100, 最大: 1000) + """ + return service.list_features_by_person(person_id, limit) + + +@router.get( + "/stats/summary", + response_model=FaceFeatureStatsResponse, + summary="获取特征记录统计信息", + description="获取人脸特征记录的统计摘要" +) +async def get_face_features_stats( + service: FaceFeatureService = Depends(get_face_feature_service) +): + """ + 获取特征记录统计信息 + """ + return service.get_statistics() + + +@router.get( + "/person/{person_id}/stats", + summary="获取人员特征统计信息", + description="获取指定人员的特征记录统计信息" +) +async def get_person_face_features_stats( + person_id: int, + service: FaceFeatureService = Depends(get_face_feature_service) +): + """ + 获取人员特征统计信息 + + - **person_id**: 人员ID (路径参数) + """ + try: + stats = service.get_person_statistics(person_id) + return { + "person_id": person_id, + "total_features": stats["total_features"], + "status_summary": stats["status_summary"], + "feature_types": stats["feature_types"], + "avg_processing_time": stats["avg_processing_time"], + "successful_count": stats["successful_count"] + } + except Exception as e: + raise FaceFeatureProcessingError(detail=str(e)) + + +async def simulate_feature_processing( + feature_id: int, + service: FaceFeatureService +): + """ + 模拟人脸特征计算过程 + + 在实际应用中,这里应该调用实际的特征计算算法 + 例如:使用InsightFace、OpenCV等库进行人脸特征提取 + + Args: + feature_id: 特征记录ID + service: 人脸特征服务 + """ + import asyncio + import random + + try: + # 模拟计算延迟 (3-10秒) + delay = random.uniform(3, 10) + await 
asyncio.sleep(delay) + + # 模拟成功或失败 (90%成功率) + success = random.random() < 0.9 + + # 完成处理 + service.finish_processing(feature_id, success) + + # 如果成功,添加模拟的特征数据 + if success: + # 生成模拟的512维特征向量 (float32) + import numpy as np + feature_data = np.random.randn(512).astype(np.float32).tobytes() + service.update_feature_data(feature_id, feature_data) + + except Exception as e: + # 如果发生异常,标记为失败 + service.finish_processing(feature_id, False) + # 记录日志 + import logging + logger = logging.getLogger(__name__) + logger.error(f"特征计算失败 (ID: {feature_id}): {e}") \ No newline at end of file diff --git a/src/app.py b/src/app.py new file mode 100644 index 0000000..9d79a13 --- /dev/null +++ b/src/app.py @@ -0,0 +1,188 @@ +""" +FastAPI主应用 +将原来的main.py重命名为app.py +""" + +import time +from contextlib import asynccontextmanager +from fastapi import FastAPI, Request +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware.trustedhost import TrustedHostMiddleware +from fastapi.responses import JSONResponse +from fastapi.openapi.docs import ( + get_swagger_ui_html, + get_swagger_ui_oauth2_redirect_html, + get_redoc_html, +) +from fastapi.staticfiles import StaticFiles + +from src.api.routes import face_features +from src.api.errors import ( + APIError, + validation_exception_handler, + api_error_handler, + generic_exception_handler +) +from src.config import settings +from src.database.connection import init_database +from src.database.connection import db_manager + + +# 生命周期管理 +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + 应用生命周期管理 + + - 启动时:初始化数据库 + - 关闭时:清理资源 + """ + # 启动时 + print("🚀 start algorithm service...") + print(f"📊 db: {settings.DATABASE_NAME}") + print(f"🔧 debug mode: {settings.DEBUG}") + + # 初始化数据库 + init_database() + + # 数据库健康检查 + if db_manager.health_check(): + print("✅ 数据库连接正常") + else: + print("❌ 数据库连接失败") + + yield + + # 关闭时 + print("🛑 algorithm service stopped...") + 
db_manager.close() + + +# 创建FastAPI应用 +app = FastAPI( + title=settings.PROJECT_NAME, + version=settings.PROJECT_VERSION, + description=settings.PROJECT_DESCRIPTION, + openapi_url=f"{settings.API_V1_PREFIX}/openapi.json", + docs_url=None, # 自定义docs + redoc_url=None, # 自定义redoc + lifespan=lifespan +) + + +# 自定义API文档页面 +@app.get("/docs", include_in_schema=False) +async def custom_swagger_ui_html(): + return get_swagger_ui_html( + openapi_url=app.openapi_url, + title=f"{app.title} - Swagger UI", + oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url, + swagger_js_url="https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js", + swagger_css_url="https://unpkg.com/swagger-ui-dist@5/swagger-ui.css", + ) + + +@app.get(app.swagger_ui_oauth2_redirect_url, include_in_schema=False) +async def swagger_ui_redirect(): + return get_swagger_ui_oauth2_redirect_html() + + +@app.get("/redoc", include_in_schema=False) +async def redoc_html(): + return get_redoc_html( + openapi_url=app.openapi_url, + title=f"{app.title} - ReDoc", + redoc_js_url="https://unpkg.com/redoc@next/bundles/redoc.standalone.js", + ) + + +# 中间件配置 +app.add_middleware( + CORSMiddleware, + allow_origins=settings.BACKEND_CORS_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +app.add_middleware( + TrustedHostMiddleware, + allowed_hosts=["*"] if settings.DEBUG else ["localhost", "127.0.0.1"] +) + + +# 请求计时中间件 +@app.middleware("http") +async def add_process_time_header(request: Request, call_next): + """ + 添加请求处理时间头 + """ + start_time = time.time() + response = await call_next(request) + process_time = time.time() - start_time + response.headers["X-Process-Time"] = str(process_time) + return response + + +# 异常处理器 +app.add_exception_handler(RequestValidationError, validation_exception_handler) +app.add_exception_handler(APIError, api_error_handler) +app.add_exception_handler(Exception, generic_exception_handler) + + +# 根路由 +@app.get("/") +async def root(): + """ + 根路径 + """ + return 
{ + "message": "algorithm service", + "version": settings.PROJECT_VERSION, + "docs": "/docs", + "api_prefix": settings.API_V1_PREFIX + } + + +@app.get("/health") +async def health_check(): + """ + 健康检查端点 + """ + # 检查数据库连接 + db_healthy = db_manager.health_check() + + return { + "status": "healthy" if db_healthy else "unhealthy", + "database": "connected" if db_healthy else "disconnected", + "timestamp": time.time() + } + + +# 注册API路由 +app.include_router( + face_features.router, + prefix=settings.API_V1_PREFIX +) + + +# 自定义404处理器 +@app.exception_handler(404) +async def not_found_handler(request: Request, exc): + """ + 自定义404错误处理器 + """ + return JSONResponse( + status_code=404, + content={ + "error": { + "code": "NOT_FOUND", + "message": f"请求的资源不存在: {request.url.path}" + } + } + ) + + +# 导出应用实例 +__all__ = ["app"] \ No newline at end of file diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..81d1cce --- /dev/null +++ b/src/config.py @@ -0,0 +1,91 @@ +""" +数据库配置模块 +使用pydantic进行配置验证和管理 +""" + +from typing import Optional, List +from pydantic_settings import BaseSettings +from functools import lru_cache +from pydantic import PostgresDsn, field_validator +from pydantic_core.core_schema import FieldValidationInfo + + +class Settings(BaseSettings): + """应用配置类""" + + # API配置 + API_V1_PREFIX: str = "/api/v1" + PROJECT_NAME: str = "algorithm-service" + PROJECT_VERSION: str = "1.0.0" + PROJECT_DESCRIPTION: str = "algorithm-service" + BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000", "http://localhost:8000"] + + + # 数据库配置 + DATABASE_HOST: str = "localhost" + DATABASE_PORT: int = 5432 + DATABASE_USER: str = "postgres" + DATABASE_PASSWORD: str = "yipai123" + DATABASE_NAME: str = "pmms" + DATABASE_SCHEMA: str = "public" + + # 连接池配置 + DATABASE_POOL_SIZE: int = 10 + DATABASE_MAX_OVERFLOW: int = 20 + DATABASE_POOL_RECYCLE: int = 3600 # 连接回收时间(秒) + DATABASE_ECHO: bool = False # SQL日志,生产环境设为False + + # 应用配置 + APP_NAME: str = "SurFaceFeature API" + 
APP_VERSION: str = "1.0.0" + DEBUG: bool = False + + # 日志配置 + LOG_LEVEL: str = "INFO" + LOG_FILE: Optional[str] = None + + # 异步配置 + ASYNC_MODE: bool = False + + # JWT配置(预留) + SECRET_KEY: str = "your-secret-key-here-change-in-production" + ALGORITHM: str = "HS256" + ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 + + @property + def DATABASE_URL(self) -> str: + """构建数据库连接URL""" + return f"postgresql://{self.DATABASE_USER}:{self.DATABASE_PASSWORD}@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}" + + @property + def ASYNC_DATABASE_URL(self) -> str: + """构建异步数据库连接URL""" + return f"postgresql+asyncpg://{self.DATABASE_USER}:{self.DATABASE_PASSWORD}@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}" + + @field_validator("DATABASE_POOL_SIZE") + def validate_pool_size(cls, v): + """验证连接池大小""" + if v < 1: + raise ValueError("DATABASE_POOL_SIZE must be at least 1") + if v > 100: + raise ValueError("DATABASE_POOL_SIZE cannot exceed 100") + return v + + class Config: + env_file = ".env" + env_file_encoding = "utf-8" + case_sensitive = False + extra = "ignore" + + +@lru_cache() +def get_settings() -> Settings: + """ + 获取配置单例 + 使用lru_cache避免重复加载.env文件 + """ + return Settings() + + +# 导出配置实例 +settings = get_settings() \ No newline at end of file diff --git a/src/database/base.py b/src/database/base.py new file mode 100644 index 0000000..8692abb --- /dev/null +++ b/src/database/base.py @@ -0,0 +1,66 @@ +""" +数据库模型基类 +""" + +from datetime import datetime +from typing import Any, Dict +from sqlalchemy import Column, DateTime, Integer +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.sql import func + +Base = declarative_base() + + +class BaseModel(Base): + """抽象基类,为所有模型提供通用字段""" + __abstract__ = True + + id = Column(Integer, primary_key=True, index=True, autoincrement=True) + created_time = Column(DateTime(timezone=True), + server_default=func.now(), + nullable=False, + comment="创建时间") + + def to_dict(self, exclude: list = None) -> Dict[str, 
Any]: + """ + 将模型实例转换为字典 + + Args: + exclude: 要排除的字段列表 + + Returns: + 包含模型字段的字典 + """ + exclude = exclude or [] + result = {} + + for column in self.__table__.columns: + if column.name in exclude: + continue + + value = getattr(self, column.name) + + # 处理特殊类型 + if isinstance(value, datetime): + value = value.isoformat() + elif isinstance(value, bytes): + value = value.hex() if value else None + + result[column.name] = value + + return result + + def update_from_dict(self, data: Dict[str, Any]) -> None: + """ + 从字典更新模型字段 + + Args: + data: 包含要更新字段的字典 + """ + for key, value in data.items(): + if hasattr(self, key) and key != 'id': + setattr(self, key, value) + + def __repr__(self) -> str: + """模型表示""" + return f"<{self.__class__.__name__}(id={self.id})>" \ No newline at end of file diff --git a/src/database/connection.py b/src/database/connection.py new file mode 100644 index 0000000..a27f62a --- /dev/null +++ b/src/database/connection.py @@ -0,0 +1,252 @@ +""" +数据库连接管理模块 +支持同步和异步两种模式 +""" + +import logging +from typing import Optional, Generator, AsyncGenerator +from contextlib import asynccontextmanager, contextmanager + +from sqlalchemy import create_engine, text +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + create_async_engine, + async_sessionmaker +) + +from src.config import settings +from src.utils.logger import setup_logger + +logger = setup_logger(__name__) + + +class DatabaseManager: + """数据库管理器(同步)""" + + def __init__(self): + self._engine: Optional[Engine] = None + self._session_factory: Optional[sessionmaker] = None + + def init_engine(self) -> None: + """初始化数据库引擎""" + if self._engine is not None: + logger.warning("Database engine already initialized") + return + + try: + self._engine = create_engine( + settings.DATABASE_URL, + pool_size=settings.DATABASE_POOL_SIZE, + max_overflow=settings.DATABASE_MAX_OVERFLOW, + 
pool_recycle=settings.DATABASE_POOL_RECYCLE, + echo=settings.DATABASE_ECHO, + pool_pre_ping=True, # 连接前进行ping检查 + connect_args={ + "options": f"-csearch_path={settings.DATABASE_SCHEMA}" + } + ) + + self._session_factory = sessionmaker( + autocommit=False, + autoflush=False, + bind=self._engine, + expire_on_commit=False # 避免延迟加载问题 + ) + + logger.info("Database engine initialized successfully") + + except Exception as e: + logger.error(f"Failed to initialize database engine: {e}") + raise + + @property + def engine(self) -> Engine: + """获取数据库引擎""" + if self._engine is None: + self.init_engine() + return self._engine + + @property + def session_factory(self) -> sessionmaker: + """获取会话工厂""" + if self._session_factory is None: + self.init_engine() + return self._session_factory + + @contextmanager + def get_session(self) -> Generator[Session, None, None]: + """ + 获取数据库会话的上下文管理器 + + 使用示例: + with db_manager.get_session() as session: + result = session.query(User).all() + """ + session = self.session_factory() + try: + yield session + session.commit() + logger.debug("Session committed successfully") + except Exception as e: + session.rollback() + logger.error(f"Session rollback due to error: {e}") + raise + finally: + session.close() + logger.debug("Session closed") + + def execute_raw_sql(self, sql: str, params: Optional[dict] = None) -> list: + """执行原始SQL查询""" + with self.get_session() as session: + result = session.execute(text(sql), params or {}) + return [dict(row._mapping) for row in result] + + def health_check(self) -> bool: + """数据库健康检查""" + try: + with self.engine.connect() as conn: + conn.execute(text("SELECT 1")) + logger.info("Database health check passed") + return True + except Exception as e: + logger.error(f"Database health check failed: {e}") + return False + + def close(self) -> None: + """关闭数据库连接""" + if self._engine: + self._engine.dispose() + self._engine = None + self._session_factory = None + logger.info("Database connections closed") + + +class 
AsyncDatabaseManager: + """异步数据库管理器""" + + def __init__(self): + self._engine: Optional[AsyncEngine] = None + self._async_session_factory: Optional[async_sessionmaker] = None + + async def init_engine(self) -> None: + """初始化异步数据库引擎""" + if self._engine is not None: + logger.warning("Async database engine already initialized") + return + + try: + self._engine = create_async_engine( + settings.ASYNC_DATABASE_URL, + pool_size=settings.DATABASE_POOL_SIZE, + max_overflow=settings.DATABASE_MAX_OVERFLOW, + pool_recycle=settings.DATABASE_POOL_RECYCLE, + echo=settings.DATABASE_ECHO, + pool_pre_ping=True, + connect_args={ + "server_settings": { + "search_path": settings.DATABASE_SCHEMA + } + } + ) + + self._async_session_factory = async_sessionmaker( + self._engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False + ) + + logger.info("Async database engine initialized successfully") + + except Exception as e: + logger.error(f"Failed to initialize async database engine: {e}") + raise + + @property + async def engine(self) -> AsyncEngine: + """获取异步数据库引擎""" + if self._engine is None: + await self.init_engine() + return self._engine + + @property + async def async_session_factory(self) -> async_sessionmaker: + """获取异步会话工厂""" + if self._async_session_factory is None: + await self.init_engine() + return self._async_session_factory + + @asynccontextmanager + async def get_session(self) -> AsyncGenerator[AsyncSession, None]: + """ + 获取异步数据库会话的上下文管理器 + + 使用示例: + async with async_db_manager.get_session() as session: + result = await session.execute(query) + """ + if self._async_session_factory is None: + await self.init_engine() + + session = self._async_session_factory() + try: + yield session + await session.commit() + logger.debug("Async session committed successfully") + except Exception as e: + await session.rollback() + logger.error(f"Async session rollback due to error: {e}") + raise + finally: + await session.close() + logger.debug("Async 
session closed") + + async def health_check(self) -> bool: + """异步数据库健康检查""" + try: + async with self._engine.connect() as conn: + await conn.execute(text("SELECT 1")) + logger.info("Async database health check passed") + return True + except Exception as e: + logger.error(f"Async database health check failed: {e}") + return False + + async def close(self) -> None: + """关闭异步数据库连接""" + if self._engine: + await self._engine.dispose() + self._engine = None + self._async_session_factory = None + logger.info("Async database connections closed") + + +# 创建全局数据库管理器实例 +db_manager = DatabaseManager() +async_db_manager = AsyncDatabaseManager() + + +def init_database() -> None: + """初始化数据库(同步)""" + db_manager.init_engine() + + +async def init_async_database() -> None: + """初始化异步数据库""" + await async_db_manager.init_engine() + + +def get_db() -> Generator[Session, None, None]: + """依赖注入:获取数据库会话(用于FastAPI等框架)""" + with db_manager.get_session() as session: + yield session + + +async def get_async_db() -> AsyncGenerator[AsyncSession, None]: + """依赖注入:获取异步数据库会话(用于FastAPI等框架)""" + async with async_db_manager.get_session() as session: + yield session \ No newline at end of file diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..96a4536 --- /dev/null +++ b/src/main.py @@ -0,0 +1,207 @@ +""" +主程序示例 +演示如何使用各个模块 +""" + +import asyncio +from datetime import datetime, timedelta + +from src.config import settings +from src.database.connection import db_manager, init_database +from src.models.face_feature import SurFaceFeature +from src.repositories.face_feature_repository import FaceFeatureRepository +from src.services.face_feature_service import FaceFeatureService +from src.schemas.face_feature import ( + FaceFeatureCreate, + FaceFeatureUpdate, + FaceFeatureQuery, + FeatureStatus, + BatchFaceFeatureCreate +) + + +def demo_sync_operations(): + """演示同步操作""" + print("=== 同步操作演示 ===") + + # 初始化数据库 + init_database() + + # 创建仓库和服务 + with db_manager.get_session() as 
session: + repository = FaceFeatureRepository(session) + service = FaceFeatureService(repository) + + try: + # 1. 创建特征记录(先检查是否存在) + print("\n1. 创建特征记录") + + # 固定的测试ID + test_person_id = 1001 + test_feature_type = 1 + + # 检查是否已存在 + existing = service.get_feature_by_person_and_type(test_person_id, test_feature_type) + + if existing: + print(f"记录已存在: ID={existing.id}, 状态={existing.status_name}") + feature_id = existing.id + else: + feature_data = FaceFeatureCreate( + person_id=test_person_id, + feature_type=test_feature_type, + pic_id="test_image_001.jpg", + status=FeatureStatus.NOT_STARTED + ) + + feature_response = service.create_feature(feature_data) + print(f"创建成功: ID={feature_response.id}, 人员ID={feature_response.person_id}") + feature_id = feature_response.id + + # 2. 开始处理(如果未开始) + print("\n2. 检查并开始处理特征计算") + feature = service.get_feature(feature_id) + if feature and feature.status == FeatureStatus.NOT_STARTED: + if service.start_processing(feature_id): + print(f"已开始处理: ID={feature_id}") + else: + print(f"特征计算已开始或已完成: 状态={feature.status_name if feature else '未知'}") + + # 3. 完成处理(如果还在处理中) + print("\n3. 检查并完成处理特征计算") + feature = service.get_feature(feature_id) + if feature and feature.status == FeatureStatus.PROCESSING: + if service.finish_processing(feature_id, success=True): + print(f"已完成处理: ID={feature_id}") + else: + print(f"特征计算已完成或未开始: 状态={feature.status_name if feature else '未知'}") + + # 4. 查询特征 + print("\n4. 查询特征记录") + retrieved = service.get_feature(feature_id) + if retrieved: + print(f"查询成功: ID={retrieved.id}, 状态={retrieved.status_name}") + print(f"处理时间: {retrieved.processing_time}秒") + print(f"是否有特征数据: {retrieved.has_feature_data}") + + # 5. 批量创建 - 跳过已存在的 + print("\n5. 
批量创建特征记录") + batch_items = [] + for i in range(3): + person_id = 2000 + i + # 检查是否已存在 + existing = service.get_feature_by_person_and_type(person_id, 1) + if not existing: + batch_items.append(FaceFeatureCreate(person_id=person_id, feature_type=1)) + + if batch_items: + batch_data = BatchFaceFeatureCreate(items=batch_items) + try: + batch_result = service.create_features_batch(batch_data) + print(f"批量创建成功: {len(batch_result)}条记录") + except ValueError as e: + print(f"批量创建失败: {e}") + else: + print("所有记录已存在,跳过批量创建") + + # 6. 查询列表 + print("\n6. 查询特征记录列表") + query = FaceFeatureQuery( + feature_type=1, + start_date=datetime.now() - timedelta(days=1) + ) + + result = service.query_features(query, page=1, page_size=10) + print(f"查询结果: 共{result.total}条记录,本页{len(result.items)}条") + + if result.items: + print(f"第一笔记录: ID={result.items[0].id}, 人员ID={result.items[0].person_id}") + + # 7. 获取统计信息 + print("\n7. 获取统计信息") + stats = service.get_statistics() + print(f"总记录数: {stats.total_count}") + print(f"状态分布: {stats.by_status}") + print(f"特征类型分布: {stats.by_feature_type}") + + # 8. 更新特征数据 + print("\n8. 更新特征数据") + test_feature_data = b"test_feature_data_12345" + success = service.update_feature_data(feature_id, test_feature_data) + print(f"更新特征数据: {'成功' if success else '失败'}") + + # 重新查询查看更新后的数据 + updated_feature = service.get_feature(feature_id) + if updated_feature: + print(f"更新后是否有特征数据: {updated_feature.has_feature_data}") + + # 9. 演示删除操作 + print("\n9. 
async def demo_async_operations():
    """Placeholder for the asynchronous demo.

    No database work happens here: the coroutine only prints a notice that
    an async database manager and async repositories have to be configured
    before a real async demo can run.
    """
    notices = (
        "\n=== 异步操作演示 ===",
        "异步操作示例代码已准备,需要配置异步数据库连接",
        "请使用 async_db_manager 和异步版本的repository",
    )
    # Each notice goes out through its own print() call, one line per message.
    for notice in notices:
        print(notice)
class SurFaceFeature(BaseModel):
    """
    ORM model for the ``sur_face_feature`` table (face feature vectors).

    Columns:
    - id: primary key
    - person_id: person the feature belongs to
    - feature_type: model version that produced the feature
    - feature_data: feature vector (binary)
    - created_time: row creation time
    - pic_id: source picture id
    - status: computation status (a FeatureStatus value)
    - start_time: when the computation started
    - finish_time: when the computation finished
    """

    __tablename__ = "sur_face_feature"
    __table_args__ = (
        # At most one row per (person_id, feature_type) pair.
        UniqueConstraint("person_id", "feature_type", name="sur_face_feature_unique"),
        # Indexes on the columns used as query filters.
        Index("ix_sur_face_feature_person_id", "person_id"),
        Index("ix_sur_face_feature_feature_type", "feature_type"),
        Index("ix_sur_face_feature_status", "status"),
        Index("ix_sur_face_feature_created_time", "created_time"),
        {"schema": "public", "comment": "人脸特征值表"}
    )

    # Primary key
    id = Column(
        Integer,
        primary_key=True,
        index=True,
        comment="主键"
    )

    # Person id (required)
    person_id = Column(
        Integer,
        nullable=False,
        comment="人员id"
    )

    # Model version (feature type); nullable
    feature_type = Column(
        SmallInteger,
        nullable=True,
        comment="模型版本"
    )

    # Feature vector stored as PostgreSQL BYTEA
    feature_data = Column(
        BYTEA,
        nullable=True,
        comment="特征值"
    )

    # Creation time, filled in by the database server
    created_time = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
        comment="创建时间"
    )

    # Picture id
    pic_id = Column(
        Text,
        nullable=True,
        comment="图片id"
    )

    # Computation status
    status = Column(
        SmallInteger,
        default=FeatureStatusEnum.NOT_STARTED,
        nullable=False,
        comment="人脸特征值计算状态:0=未开始,1=计算中,2=计算成功,3=计算失败"
    )

    # Computation start time
    start_time = Column(
        DateTime(timezone=True),
        nullable=True,
        comment="特征计算开始时间"
    )

    # Computation finish time
    finish_time = Column(
        DateTime(timezone=True),
        nullable=True,
        comment="特征计算结束时间"
    )

    @property
    def status_name(self) -> str:
        """Name of the FeatureStatus member for ``status``, or a placeholder for unknown values."""
        try:
            return FeatureStatusEnum(self.status).name
        except ValueError:
            return f"未知状态({self.status})"

    @property
    def is_completed(self) -> bool:
        """True once the computation has ended, whether it succeeded or failed."""
        return self.status in (FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED)

    @property
    def processing_time(self) -> Optional[float]:
        """Seconds between ``start_time`` and ``finish_time``; None unless both are set."""
        if self.start_time and self.finish_time:
            return (self.finish_time - self.start_time).total_seconds()
        return None

    def start_processing(self) -> None:
        """Switch the record to PROCESSING, stamp the start time and clear the finish time."""
        self.status = FeatureStatusEnum.PROCESSING
        # The columns are timezone-aware; a naive value here could not be
        # subtracted from an aware one in ``processing_time``.
        self.start_time = datetime.now().astimezone()
        self.finish_time = None

    def finish_processing(self, success: bool = True) -> None:
        """
        Switch the record to SUCCESS or FAILED and stamp the finish time.

        Args:
            success: True marks the record SUCCESS, False marks it FAILED.
        """
        self.status = FeatureStatusEnum.SUCCESS if success else FeatureStatusEnum.FAILED
        # Aware timestamp, for the same reason as in start_processing().
        self.finish_time = datetime.now().astimezone()

    def to_dict(self, exclude: Optional[list] = None) -> dict:
        """
        Serialize the row to a dict without the raw binary feature data.

        Args:
            exclude: Extra field names to leave out (any iterable of names).

        Returns:
            The base-class dict plus ``status_name``, ``is_completed``,
            ``processing_time`` and ``has_feature_data`` (with
            ``feature_data_length`` when data is present).
        """
        exclude = list(exclude) if exclude else []

        # The binary payload is always dropped: it is too large to serialize.
        final_exclude = list({*exclude, "feature_data"})

        result = super().to_dict(final_exclude)

        # Derived attributes
        result["status_name"] = self.status_name
        result["is_completed"] = self.is_completed
        result["processing_time"] = self.processing_time

        # Report presence and size of the payload, unless the caller
        # explicitly excluded feature_data.
        if self.feature_data and "feature_data" not in exclude:
            result["has_feature_data"] = True
            result["feature_data_length"] = len(self.feature_data)
        else:
            result["has_feature_data"] = False

        return result

    def __repr__(self) -> str:
        return (
            f"<SurFaceFeature(id={self.id}, person_id={self.person_id}, "
            f"feature_type={self.feature_type}, status={self.status_name})>"
        )
def create(self, feature_data: FaceFeatureCreate) -> SurFaceFeature:
    """
    Insert one feature record.

    The row is flushed so that its id is available, but the transaction is
    not committed here.

    Args:
        feature_data: Validated payload; only explicitly set fields are
            copied onto the model.

    Returns:
        The newly added SurFaceFeature instance.

    Raises:
        ValueError: The insert violated an integrity constraint (reported
            as a duplicate person_id / feature_type pair). The original
            IntegrityError is attached as the cause.
        SQLAlchemyError: Any other database error, re-raised after rollback.
    """
    try:
        feature = SurFaceFeature(**feature_data.model_dump(exclude_unset=True))

        self.session.add(feature)
        self.session.flush()  # emit the INSERT now, without committing

        logger.info(f"Created face feature record: id={feature.id}, person_id={feature.person_id}")
        return feature

    except IntegrityError as e:
        logger.error(f"Integrity error creating face feature: {e}")
        self.session.rollback()
        # Chain explicitly so the database error stays attached as the cause.
        raise ValueError(f"Duplicate feature record for person_id={feature_data.person_id}, "
                         f"feature_type={feature_data.feature_type}") from e
    except SQLAlchemyError as e:
        logger.error(f"Database error creating face feature: {e}")
        self.session.rollback()
        raise
def get_by_person_and_type(self, person_id: int, feature_type: int) -> Optional[SurFaceFeature]:
    """
    Look up the record for one (person, feature type) pair.

    Args:
        person_id: Person id to match.
        feature_type: Feature type (model version) to match.

    Returns:
        The matching SurFaceFeature, or None when there is no such row.
    """
    try:
        # Several criteria passed to where() are combined with AND.
        lookup = select(SurFaceFeature).where(
            SurFaceFeature.person_id == person_id,
            SurFaceFeature.feature_type == feature_type,
        )
        match = self.session.execute(lookup).scalar_one_or_none()

        if not match:
            logger.debug(f"Face feature not found: person_id={person_id}, feature_type={feature_type}")
        else:
            logger.debug(f"Retrieved face feature: person_id={person_id}, feature_type={feature_type}")

        return match

    except SQLAlchemyError as e:
        logger.error(f"Database error getting face feature by person and type: {e}")
        raise
def query_features(
    self,
    query: FaceFeatureQuery,
    page: int = 1,
    page_size: int = 20,
    order_by: str = "created_time",
    desc_order: bool = True
) -> Tuple[List[SurFaceFeature], int]:
    """
    Run a filtered, sorted and paginated query over feature records.

    Args:
        query: Filter values; fields that are unset or None are ignored.
        page: 1-based page number. Values below 1 are treated as 1.
        page_size: Rows per page. Values below 1 are treated as 1.
        order_by: Name of the table column to sort on. Anything that is
            not a column of the table falls back to ``created_time``.
        desc_order: Sort descending when True, ascending otherwise.

    Returns:
        A tuple ``(rows of the requested page, total matching rows)``.

    Raises:
        SQLAlchemyError: Re-raised after logging when the database fails.
    """
    # A negative OFFSET or LIMIT is rejected by the database, so clamp both.
    page = max(page, 1)
    page_size = max(page_size, 1)

    try:
        filters = query.model_dump(exclude_unset=True, exclude_none=True)
        conditions = []

        # Plain equality filters
        for field in ("person_id", "feature_type", "status"):
            if field in filters:
                conditions.append(getattr(SurFaceFeature, field) == filters[field])

        # Creation-time range (inclusive on both ends)
        if "start_date" in filters:
            conditions.append(SurFaceFeature.created_time >= filters["start_date"])
        if "end_date" in filters:
            conditions.append(SurFaceFeature.created_time <= filters["end_date"])

        # Presence or absence of the binary payload
        if "has_feature_data" in filters:
            if filters["has_feature_data"]:
                conditions.append(SurFaceFeature.feature_data.isnot(None))
            else:
                conditions.append(SurFaceFeature.feature_data.is_(None))

        stmt = select(SurFaceFeature)
        # Count directly on the table instead of wrapping the full-row
        # select (binary column included) in a subquery.
        count_stmt = select(func.count()).select_from(SurFaceFeature)
        if conditions:
            stmt = stmt.where(*conditions)
            count_stmt = count_stmt.where(*conditions)

        total = self.session.execute(count_stmt).scalar_one()

        # Only real table columns may be sorted on; a bare getattr() would
        # also accept methods and properties of the model class.
        if isinstance(order_by, str) and order_by in SurFaceFeature.__table__.columns:
            order_column = getattr(SurFaceFeature, order_by)
        else:
            order_column = SurFaceFeature.created_time

        direction = desc if desc_order else asc
        # The id tiebreaker keeps page contents stable when the sort column
        # has duplicate values.
        stmt = stmt.order_by(direction(order_column), direction(SurFaceFeature.id))

        stmt = stmt.offset((page - 1) * page_size).limit(page_size)

        features = list(self.session.execute(stmt).scalars().all())

        logger.debug(f"Query returned {len(features)} features (total: {total})")
        return features, total

    except SQLAlchemyError as e:
        logger.error(f"Database error querying face features: {e}")
        raise
def update_status(
    self,
    feature_id: int,
    status: FeatureStatus,
    start_time: Optional[datetime] = None,
    finish_time: Optional[datetime] = None
) -> bool:
    """
    Write a new computation status, and optionally the timestamps, for one row.

    Args:
        feature_id: Id of the record to update.
        status: New status; a FeatureStatus member is stored as its value,
            anything else is stored as given.
        start_time: Stored as ``start_time`` when provided.
        finish_time: Stored as ``finish_time`` when provided.

    Returns:
        True when a row was updated, False when no row has that id.

    Raises:
        SQLAlchemyError: Re-raised after rollback when the database fails.
    """
    try:
        new_values = {"status": status.value if isinstance(status, FeatureStatus) else status}
        # Timestamps are only written when the caller supplied them.
        optional_times = {"start_time": start_time, "finish_time": finish_time}
        new_values.update(
            {column: moment for column, moment in optional_times.items() if moment}
        )

        # RETURNING hands back the id of the touched row, if there was one.
        row_id = self.session.execute(
            update(SurFaceFeature)
            .where(SurFaceFeature.id == feature_id)
            .values(**new_values)
            .returning(SurFaceFeature.id)
        ).scalar_one_or_none()

    except SQLAlchemyError as e:
        logger.error(f"Database error updating status: {e}")
        self.session.rollback()
        raise

    if not row_id:
        logger.warning(f"Cannot update status for non-existent face feature: id={feature_id}")
        return False

    logger.info(f"Updated status to {status} for face feature: id={feature_id}")
    return True
def get_stats(self) -> Dict[str, Any]:
    """
    Collect aggregate statistics over all feature records.

    Returns:
        A dict with:
        - ``total_count``: number of rows in the table
        - ``by_status``: row count per status value (keys are ``str(status)``)
        - ``by_feature_type``: row count per non-NULL feature type
          (keys are ``str(feature_type)``)
        - ``avg_processing_time``: mean of ``finish_time - start_time`` in
          seconds over SUCCESS/FAILED rows that have both timestamps, or
          None when there is no such row

    Raises:
        SQLAlchemyError: Re-raised after logging when the database fails.
    """
    try:
        # Total number of rows
        total_count = self.session.execute(
            select(func.count()).select_from(SurFaceFeature)
        ).scalar_one()

        # Rows per status
        status_rows = self.session.execute(
            select(SurFaceFeature.status, func.count())
            .group_by(SurFaceFeature.status)
        )
        status_stats = {str(status): count for status, count in status_rows}

        # Rows per feature type (NULL types are left out)
        type_rows = self.session.execute(
            select(SurFaceFeature.feature_type, func.count())
            .where(SurFaceFeature.feature_type.isnot(None))
            .group_by(SurFaceFeature.feature_type)
        )
        type_stats = {str(feature_type): count for feature_type, count in type_rows}

        # Average duration of finished computations, in seconds
        avg_time = self.session.execute(
            select(
                func.avg(
                    func.extract('epoch', SurFaceFeature.finish_time - SurFaceFeature.start_time)
                )
            )
            .where(
                and_(
                    SurFaceFeature.start_time.isnot(None),
                    SurFaceFeature.finish_time.isnot(None),
                    SurFaceFeature.status.in_([FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED])
                )
            )
        ).scalar_one()

        stats = {
            "total_count": total_count,
            "by_status": status_stats,
            "by_feature_type": type_stats,
            # Compare against None: an average of exactly 0 seconds is a
            # real value and must not be reported as missing.
            "avg_processing_time": float(avg_time) if avg_time is not None else None
        }

        logger.debug("Generated face feature statistics")
        return stats

    except SQLAlchemyError as e:
        logger.error(f"Database error getting statistics: {e}")
        raise
def main():
    """
    Parse the command-line options and start the uvicorn server.

    With ``settings.DEBUG`` enabled the server always runs a single
    auto-reloading worker at debug log level. Otherwise the command-line
    values are used; ``--reload`` is honoured there as well and forces a
    single worker, because uvicorn does not combine reloading with
    multiple workers.
    """
    parser = argparse.ArgumentParser(description="algorithm service")

    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="监听主机 (默认: 0.0.0.0)"
    )

    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="监听端口 (默认: 8000)"
    )

    parser.add_argument(
        "--reload",
        action="store_true",
        help="启用热重载 (开发模式)"
    )

    parser.add_argument(
        "--workers",
        type=int,
        default=1,
        help="工作进程数 (生产模式)"
    )

    parser.add_argument(
        "--log-level",
        type=str,
        default="info",
        choices=["debug", "info", "warning", "error", "critical"],
        help="日志级别"
    )

    args = parser.parse_args()

    # Pick the server configuration for the current environment.
    if settings.DEBUG:
        print("🔧 开发模式")
        uvicorn_config = {
            "host": args.host,
            "port": args.port,
            "reload": True,
            "log_level": "debug",
            "workers": 1
        }
    else:
        print("🚀 生产模式")
        uvicorn_config = {
            "host": args.host,
            "port": args.port,
            # Previously hard-coded to False, which silently ignored --reload.
            "reload": args.reload,
            "log_level": args.log_level,
            # The workers setting is ignored by uvicorn while reloading.
            "workers": 1 if args.reload else args.workers
        }

    # Start the server
    print(f"🌐 启动服务器: http://{args.host}:{args.port}")
    print(f"📚 API文档: http://{args.host}:{args.port}/docs")
    print(f"📊 健康检查: http://{args.host}:{args.port}/health")

    uvicorn.run(
        "src.app:app",
        **uvicorn_config
    )
class FaceFeatureBase(BaseModel):
    """Fields shared by the create and response schemas of a face feature record."""
    person_id: int = Field(..., description="人员ID", gt=0)
    feature_type: Optional[int] = Field(None, description="模型版本", ge=0)
    feature_data: Optional[bytes] = Field(None, description="特征值(二进制)")
    pic_id: Optional[str] = Field(None, description="图片ID", max_length=255)
    status: FeatureStatus = Field(
        default=FeatureStatus.NOT_STARTED,
        description="计算状态"
    )
    start_time: Optional[datetime] = Field(None, description="计算开始时间")
    finish_time: Optional[datetime] = Field(None, description="计算结束时间")

    @field_validator('feature_data', mode='before')
    @classmethod
    def validate_feature_data(cls, v):
        """Accept None or bytes unchanged; decode a hex string; reject anything else."""
        # Nothing to convert.
        if v is None or isinstance(v, bytes):
            return v

        # Only strings can still be turned into bytes.
        if not isinstance(v, str):
            raise ValueError("feature_data must be bytes or hex string")

        try:
            return bytes.fromhex(v)
        except ValueError:
            raise ValueError("feature_data must be valid hex string or bytes")
"2024-01-01T12:00:00Z" + } + } + ) + + +# 查询参数模型 +class FaceFeatureQuery(BaseModel): + """特征记录查询参数""" + person_id: Optional[int] = Field(None, description="人员ID", gt=0) + feature_type: Optional[int] = Field(None, description="模型版本", ge=0) + status: Optional[FeatureStatus] = Field(None, description="计算状态") + start_date: Optional[datetime] = Field(None, description="开始时间") + end_date: Optional[datetime] = Field(None, description="结束时间") + has_feature_data: Optional[bool] = Field(None, description="是否有特征数据") + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "person_id": 1001, + "status": 2, + "start_date": "2024-01-01T00:00:00Z" + } + } + ) + + +# 响应模型 +class FaceFeatureResponse(FaceFeatureBase): + """特征记录响应模型""" + id: int + created_time: datetime + + # 计算字段(将在验证后设置) + status_name: Optional[str] = None + is_completed: Optional[bool] = None + processing_time: Optional[float] = None + has_feature_data: Optional[bool] = None + + @model_validator(mode='after') + def set_computed_fields(self): + """设置所有计算字段""" + # 状态名称 + try: + self.status_name = FeatureStatus(self.status).name + except ValueError: + self.status_name = f"未知状态({self.status})" + + # 是否完成 + self.is_completed = self.status in [FeatureStatus.SUCCESS, FeatureStatus.FAILED] + + # 处理时间 + if self.start_time and self.finish_time: + self.processing_time = (self.finish_time - self.start_time).total_seconds() + + # 是否有特征数据 + self.has_feature_data = self.feature_data is not None and len(self.feature_data) > 0 + + return self + + model_config = ConfigDict( + from_attributes=True, + populate_by_name=True, + json_schema_extra={ + "example": { + "id": 1, + "person_id": 1001, + "feature_type": 1, + "status": 2, + "status_name": "SUCCESS", + "created_time": "2024-01-01T10:00:00Z", + "pic_id": "img_20250101_001", + "start_time": "2024-01-01T10:00:00Z", + "finish_time": "2024-01-01T10:00:05Z", + "is_completed": True, + "processing_time": 5.0, + "has_feature_data": True + } + } + ) + + +class 
class BatchFaceFeatureCreate(BaseModel):
    """Request body for creating several feature records in one call (1 to 1000 items)."""
    # Pydantic v2 names the list-size constraints min_length / max_length;
    # min_items / max_items are deprecated v1 spellings.
    items: List[FaceFeatureCreate] = Field(..., description="特征记录列表", min_length=1, max_length=1000)

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "items": [
                    {"person_id": 1001, "feature_type": 1},
                    {"person_id": 1002, "feature_type": 1}
                ]
            }
        }
    )
def create_feature(self, feature_data: FaceFeatureCreate) -> FaceFeatureResponse:
    """
    Validate and store one feature record.

    Args:
        feature_data: Payload of the record to create.

    Returns:
        The stored record converted to a FaceFeatureResponse.

    Raises:
        ValueError: A record for the same person_id and feature_type
            already exists (only checked when feature_type is given).
    """
    logger.info(f"Creating face feature for person_id={feature_data.person_id}")

    # Business-rule validation comes before any database access.
    self._validate_feature_data(feature_data)

    person_id, feature_type = feature_data.person_id, feature_data.feature_type
    # The duplicate lookup runs only when a feature type was supplied.
    duplicate = (
        feature_type is not None
        and self.repository.get_by_person_and_type(person_id, feature_type)
    )
    if duplicate:
        raise ValueError(
            f"Feature record already exists for person_id={feature_data.person_id}, "
            f"feature_type={feature_data.feature_type}"
        )

    stored = self.repository.create(feature_data)
    return FaceFeatureResponse.model_validate(stored)
def query_features(
    self,
    query: FaceFeatureQuery,
    page: int = 1,
    page_size: int = 20,
    order_by: str = "created_time",
    desc_order: bool = True
) -> FaceFeatureListResponse:
    """
    Query feature records and wrap one page of results in a list response.

    Args:
        query: Filter values.
        page: Page number, passed through to the repository.
        page_size: Rows per page, passed through to the repository.
        order_by: Sort field name.
        desc_order: Sort descending when True.

    Returns:
        A FaceFeatureListResponse holding the total count, the converted
        rows and the paging values that were requested.
    """
    logger.debug(f"Querying face features with filters: {query.model_dump(exclude_unset=True)}")

    rows, matched = self.repository.query_features(
        query, page, page_size, order_by, desc_order
    )

    return FaceFeatureListResponse(
        total=matched,
        items=[FaceFeatureResponse.model_validate(row) for row in rows],
        page=page,
        page_size=page_size
    )
def update_feature_data(self, feature_id: int, feature_data: bytes) -> bool:
    """
    Replace the binary feature data of one record.

    Args:
        feature_id: Id of the record to update.
        feature_data: New payload; must be non-empty and at most 1 MB.

    Returns:
        Whatever the repository reports: True when a row was updated.

    Raises:
        ValueError: The payload is empty or larger than 1 MB.
    """
    logger.info(f"Updating feature data for face feature: id={feature_id}")

    # An empty payload (None or zero bytes) is falsy.
    if not feature_data:
        raise ValueError("Feature data cannot be empty")

    max_bytes = 1024 * 1024  # 1 MB cap
    if len(feature_data) > max_bytes:
        raise ValueError("Feature data is too large (max 1MB)")

    return self.repository.update_feature_data(feature_id, feature_data)
feature: id={feature_id}") + + # 获取当前特征 + feature = self.repository.get_by_id(feature_id) + if not feature: + return False + + # 验证状态 + if feature.status != FeatureStatusEnum.NOT_STARTED: + raise ValueError( + f"Cannot start processing for feature with status {feature.status_name}" + ) + + # 更新状态 + return self.repository.update_status( + feature_id, + FeatureStatus.PROCESSING, + start_time=datetime.now() + ) + + def finish_processing(self, feature_id: int, success: bool = True) -> bool: + """ + 完成特征计算 + + Args: + feature_id: 特征记录ID + success: 是否成功 + + Returns: + 是否成功完成 + """ + logger.info(f"Finishing processing for face feature: id={feature_id}, success={success}") + + # 获取当前特征 + feature = self.repository.get_by_id(feature_id) + if not feature: + return False + + # 验证状态 + if feature.status != FeatureStatusEnum.PROCESSING: + raise ValueError( + f"Cannot finish processing for feature with status {feature.status_name}" + ) + + # 更新状态 + status = FeatureStatus.SUCCESS if success else FeatureStatus.FAILED + return self.repository.update_status( + feature_id, + status, + finish_time=datetime.now() + ) + + def process_pending_features(self, limit: int = 100) -> List[FaceFeatureResponse]: + """ + 处理待计算的特征记录 + + Args: + limit: 最大处理数量 + + Returns: + 处理中的特征记录列表 + """ + logger.info(f"Processing up to {limit} pending face features") + + features = self.repository.mark_for_processing(limit) + return [FaceFeatureResponse.model_validate(f) for f in features] + + # ===== 统计和分析 ===== + + def get_statistics(self) -> FaceFeatureStatsResponse: + """ + 获取特征记录统计信息 + + Returns: + 统计信息响应 + """ + logger.debug("Getting face feature statistics") + + stats = self.repository.get_stats() + + # 转换状态枚举名称 + by_status = {} + for status_value, count in stats["by_status"].items(): + try: + status_name = FeatureStatusEnum(int(status_value)).name + by_status[status_name] = count + except ValueError: + by_status[f"未知({status_value})"] = count + + return FaceFeatureStatsResponse( + 
total_count=stats["total_count"], + by_status=by_status, + by_feature_type=stats["by_feature_type"], + avg_processing_time=stats["avg_processing_time"] + ) + + def get_person_statistics(self, person_id: int) -> Dict[str, Any]: + """ + 获取人员的特征统计信息 + + Args: + person_id: 人员ID + + Returns: + 人员统计信息 + """ + logger.debug(f"Getting statistics for person_id={person_id}") + + # 获取该人员的所有特征 + features = self.repository.get_by_person(person_id, limit=1000) + + if not features: + return { + "person_id": person_id, + "total_features": 0, + "status_summary": {}, + "feature_types": [] + } + + # 统计信息 + status_summary = {} + feature_types = set() + successful_features = [] + + for feature in features: + # 状态统计 + status_name = feature.status_name + status_summary[status_name] = status_summary.get(status_name, 0) + 1 + + # 特征类型 + if feature.feature_type is not None: + feature_types.add(feature.feature_type) + + # 成功完成的特征 + if feature.status == FeatureStatusEnum.SUCCESS: + successful_features.append(feature) + + # 计算平均处理时间(仅成功记录) + total_time = 0 + valid_count = 0 + + for feature in successful_features: + if feature.processing_time: + total_time += feature.processing_time + valid_count += 1 + + avg_time = total_time / valid_count if valid_count > 0 else None + + return { + "person_id": person_id, + "total_features": len(features), + "status_summary": status_summary, + "feature_types": sorted(list(feature_types)), + "avg_processing_time": avg_time, + "successful_count": len(successful_features) + } + + def cleanup_old_records(self, days: int = 30) -> Dict[str, Any]: + """ + 清理旧的特征记录 + + Args: + days: 保留天数 + + Returns: + 清理结果 + """ + logger.info(f"Cleaning up face features older than {days} days") + + # 先获取清理前的统计 + before_stats = self.repository.get_stats() + + # 执行清理 + deleted_count = self.repository.cleanup_old_features(days) + + # 获取清理后的统计 + after_stats = self.repository.get_stats() + + return { + "days_retained": days, + "deleted_count": deleted_count, + "before_total": 
before_stats["total_count"], + "after_total": after_stats["total_count"], + "reduction_percentage": ( + (before_stats["total_count"] - after_stats["total_count"]) / + before_stats["total_count"] * 100 + if before_stats["total_count"] > 0 else 0 + ) + } + + # ===== 私有方法 ===== + + def _validate_feature_data(self, feature_data: FaceFeatureCreate) -> None: + """ + 验证特征数据 + + Args: + feature_data: 特征数据 + + Raises: + ValueError: 如果数据无效 + """ + # 验证人员ID + if feature_data.person_id <= 0: + raise ValueError("person_id must be greater than 0") + + # 验证特征类型 + if feature_data.feature_type is not None and feature_data.feature_type < 0: + raise ValueError("feature_type must be non-negative") + + # 验证特征数据大小 + if feature_data.feature_data and len(feature_data.feature_data) > 1024 * 1024: + raise ValueError("feature_data is too large (max 1MB)") + + # 验证状态 + if feature_data.status: + try: + FeatureStatus(feature_data.status) + except ValueError: + raise ValueError(f"Invalid status value: {feature_data.status}") + + def _validate_status_transition(self, feature_id: int, new_status: FeatureStatus) -> None: + """ + 验证状态转换是否有效 + + Args: + feature_id: 特征记录ID + new_status: 新状态 + + Raises: + ValueError: 如果状态转换无效 + """ + # 获取当前特征 + feature = self.repository.get_by_id(feature_id) + if not feature: + return + + current_status = FeatureStatusEnum(feature.status) + + # 定义允许的状态转换 + allowed_transitions = { + FeatureStatusEnum.NOT_STARTED: [FeatureStatusEnum.PROCESSING], + FeatureStatusEnum.PROCESSING: [FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED], + FeatureStatusEnum.SUCCESS: [], + FeatureStatusEnum.FAILED: [FeatureStatusEnum.PROCESSING] # 允许重试 + } + + new_status_enum = FeatureStatusEnum(new_status.value if isinstance(new_status, FeatureStatus) else new_status) + + # 检查转换是否允许 + if new_status_enum not in allowed_transitions.get(current_status, []): + raise ValueError( + f"Cannot transition from {current_status.name} to {new_status_enum.name}" + ) \ No newline at end of file diff --git 
"""
Logging configuration module.
"""

import logging
import os
import sys
from typing import Optional
from logging.handlers import RotatingFileHandler

from src.config import settings


def setup_logger(
    name: str,
    level: Optional[str] = None,
    log_file: Optional[str] = None
) -> logging.Logger:
    """
    Configure and return a logger.

    A logger that already has handlers is returned as is (after its level is
    updated), so repeated calls do not attach duplicate handlers.

    Args:
        name: Logger name.
        level: Log level name; defaults to ``settings.LOG_LEVEL``. Names that
            do not resolve to a numeric logging level fall back to INFO.
        log_file: Log file path; defaults to ``settings.LOG_FILE``. When
            neither is set, only the console handler is attached.

    Returns:
        The configured logger instance.
    """
    # Resolve the log level.
    if level is None:
        level = settings.LOG_LEVEL

    # Accept only real numeric levels: getattr alone would also return
    # non-level attributes such as logging.BASIC_FORMAT, which setLevel rejects.
    log_level = getattr(logging, level.upper(), None)
    if not isinstance(log_level, int):
        log_level = logging.INFO

    # Create (or fetch) the logger.
    logger = logging.getLogger(name)
    logger.setLevel(log_level)

    # Avoid attaching handlers twice.
    if logger.handlers:
        return logger

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Console handler.
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(log_level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # File handler (only when a log file is configured).
    file_path = log_file or settings.LOG_FILE
    if file_path:
        try:
            # Create the parent directory so a fresh deployment does not lose
            # file logging just because the directory is missing.
            log_dir = os.path.dirname(file_path)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)

            # NOTE(review): every named logger gets its own rotating handler
            # on the same path; confirm rotation behaves as expected when
            # several loggers share one file.
            file_handler = RotatingFileHandler(
                file_path,
                maxBytes=10 * 1024 * 1024,  # 10 MB
                backupCount=5,
                encoding='utf-8'
            )
            file_handler.setLevel(log_level)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
        except Exception as e:
            # File logging is best effort: keep the console handler and warn.
            logger.warning(f"Failed to create log file handler: {e}")

    return logger


# Create the application's base logger.
root_logger = setup_logger("sur_face_feature")


def get_logger(name: str) -> logging.Logger:
    """
    Return the logger with the given name, configuring it on first use.

    Args:
        name: Logger name.

    Returns:
        The logger instance.
    """
    return setup_logger(name)