引入http服务,引入数据库

This commit is contained in:
zqc
2025-12-20 18:09:18 +08:00
parent 713ad3f3e4
commit 66890ff094
14 changed files with 3183 additions and 0 deletions

67
src/api/dependencies.py Normal file
View File

@@ -0,0 +1,67 @@
"""
FastAPI依赖注入模块
"""
from typing import Generator, Optional
from fastapi import Depends, HTTPException, status
from sqlalchemy.orm import Session
from src.database.connection import db_manager, get_db
from src.repositories.face_feature_repository import FaceFeatureRepository
from src.services.face_feature_service import FaceFeatureService
def get_face_feature_repository(
    db: Session = Depends(get_db),
) -> FaceFeatureRepository:
    """Provide a face-feature repository bound to the injected session.

    Args:
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        A ``FaceFeatureRepository`` constructed with ``db``.
    """
    repository = FaceFeatureRepository(db)
    return repository
def get_face_feature_service(
    repository: FaceFeatureRepository = Depends(get_face_feature_repository),
) -> FaceFeatureService:
    """Provide a face-feature service built on the injected repository.

    Args:
        repository: Repository supplied by ``get_face_feature_repository``.

    Returns:
        A ``FaceFeatureService`` constructed with ``repository``.
    """
    service = FaceFeatureService(repository)
    return service
def get_face_feature_by_id(
    feature_id: int,
    service: FaceFeatureService = Depends(get_face_feature_service),
):
    """Dependency that resolves a face-feature record by its ID.

    Args:
        feature_id: ID of the feature record.
        service: Face-feature service used for the lookup.

    Returns:
        Whatever ``service.get_feature`` returns for ``feature_id``.

    Raises:
        HTTPException: 404 when the lookup result is falsy.
    """
    record = service.get_feature(feature_id)
    if record:
        return record
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"特征记录不存在 (ID: {feature_id})",
    )

152
src/api/errors.py Normal file
View File

@@ -0,0 +1,152 @@
"""
API错误处理模块
"""
from typing import Any, Dict, Optional
from fastapi import HTTPException, status
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from fastapi.requests import Request
class APIError(HTTPException):
    """Base HTTP error that also carries a machine-readable ``error_code``."""

    def __init__(
        self,
        status_code: int = status.HTTP_400_BAD_REQUEST,
        detail: Any = None,
        headers: Optional[Dict[str, str]] = None,
        error_code: Optional[str] = None,
    ):
        """Initialise the HTTP error and store its error code.

        Args:
            status_code: HTTP status code of the response.
            detail: Payload describing the error.
            headers: Optional headers passed to ``HTTPException``.
            error_code: Application error code; when falsy, a code derived
                from ``status_code`` is used instead.
        """
        super().__init__(status_code=status_code, detail=detail, headers=headers)
        if error_code:
            self.error_code = error_code
        else:
            self.error_code = f"ERR_{status_code}"
class FaceFeatureProcessingError(APIError):
    """Error raised when processing a face feature fails (HTTP 400)."""

    def __init__(
        self,
        detail: str = "人脸特征处理失败",
        feature_id: Optional[int] = None,
    ):
        """Build the error, optionally prefixing the feature ID.

        Args:
            detail: Human-readable failure description.
            feature_id: ID of the affected feature record, if known.
        """
        # Compare against None so that an ID of 0 is still reported.
        if feature_id is not None:
            detail = f"人脸特征处理失败 (特征ID: {feature_id}): {detail}"
        super().__init__(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=detail,
            error_code="FACE_FEATURE_PROCESSING_ERROR",
        )
class FeatureNotFoundError(APIError):
    """Error raised when a feature record does not exist (HTTP 404)."""

    def __init__(self, feature_id: int):
        """Build the error for the given feature record ID."""
        message = f"特征记录不存在 (ID: {feature_id})"
        super().__init__(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=message,
            error_code="FEATURE_NOT_FOUND",
        )
class DuplicateFeatureError(APIError):
    """Error raised when a feature record already exists (HTTP 409)."""

    def __init__(self, person_id: int, feature_type: int):
        """Build the error for the conflicting person / feature-type pair."""
        message = f"特征记录已存在 (人员ID: {person_id}, 特征类型: {feature_type})"
        super().__init__(
            status_code=status.HTTP_409_CONFLICT,
            detail=message,
            error_code="DUPLICATE_FEATURE",
        )
async def validation_exception_handler(
    request: Request,
    exc: RequestValidationError,
) -> JSONResponse:
    """Convert a request-validation failure into a structured 422 response.

    Args:
        request: Incoming request (unused).
        exc: Validation exception.

    Returns:
        JSON response listing one entry per validation error.
    """
    details = []
    for item in exc.errors():
        # Drop the leading "body" marker so the field path reads naturally.
        path_parts = [str(part) for part in item["loc"] if part != "body"]
        details.append({
            "field": ".".join(path_parts) or "body",
            "message": item["msg"],
            "type": item["type"],
        })

    payload = {
        "error": {
            "code": "VALIDATION_ERROR",
            "message": "请求参数验证失败",
            "details": details,
        }
    }
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content=payload,
    )
async def api_error_handler(
    request: Request,
    exc: APIError,
) -> JSONResponse:
    """Render an ``APIError`` as a JSON error envelope.

    Args:
        request: Incoming request (unused).
        exc: The API error being handled.

    Returns:
        JSON response carrying the error's status code, error code, detail
        and any headers attached to the exception.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": {
                "code": exc.error_code,
                "message": exc.detail,
            }
        },
        # Forward headers set on the exception instead of dropping them.
        headers=getattr(exc, "headers", None),
    )
async def generic_exception_handler(
    request: Request,
    exc: Exception,
) -> JSONResponse:
    """Catch-all handler: log the exception and return a generic 500.

    Args:
        request: Incoming request (unused).
        exc: The unhandled exception.

    Returns:
        JSON response with a fixed internal-server-error payload.
    """
    import logging

    # Log with traceback; the client only receives a generic message.
    logging.getLogger(__name__).error(f"未处理的异常: {exc}", exc_info=True)

    body = {
        "error": {
            "code": "INTERNAL_SERVER_ERROR",
            "message": "服务器内部错误",
        }
    }
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content=body,
    )

View File

@@ -0,0 +1,417 @@
"""
人脸特征API路由
"""
from typing import List, Optional
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, status, Query, BackgroundTasks
from fastapi.responses import JSONResponse
from src.schemas.face_feature import (
FaceFeatureCreate,
FaceFeatureUpdate,
FaceFeatureQuery,
FaceFeatureResponse,
FaceFeatureListResponse,
FaceFeatureStatsResponse,
BatchFaceFeatureCreate,
FeatureStatus
)
from src.api.dependencies import (
get_face_feature_service,
get_face_feature_by_id
)
from src.services.face_feature_service import FaceFeatureService
from src.api.errors import (
FaceFeatureProcessingError,
FeatureNotFoundError,
DuplicateFeatureError
)
from src.config import settings
# Router for all face-feature endpoints, with shared response descriptions.
router = APIRouter(
    tags=["人脸特征管理"],
    prefix="/face-features",
    responses={
        404: {"description": "资源不存在"},
        400: {"description": "请求参数错误"},
        500: {"description": "服务器内部错误"},
    },
)
@router.post(
    "/",
    response_model=FaceFeatureResponse,
    status_code=status.HTTP_201_CREATED,
    summary="创建人脸特征记录",
    description="创建新的人脸特征记录"
)
async def create_face_feature(
    feature_data: FaceFeatureCreate,
    service: FaceFeatureService = Depends(get_face_feature_service)
):
    """
    Create a face-feature record.

    - **person_id**: person ID (must be greater than 0)
    - **feature_type**: feature type (optional, greater than or equal to 0)
    - **pic_id**: picture ID (optional)
    - **status**: computation status (default: NOT_STARTED)
    - **feature_data**: feature data (optional, binary)

    Raises:
        DuplicateFeatureError: when the service reports the record already exists.
        FaceFeatureProcessingError: for any other ``ValueError`` from the service.
    """
    try:
        return service.create_feature(feature_data)
    except ValueError as e:
        # The conflict is detected from the error text raised by the service.
        if "already exists" in str(e):
            raise DuplicateFeatureError(
                person_id=feature_data.person_id,
                feature_type=feature_data.feature_type or 0
            ) from e
        # Chain the original error so its traceback is preserved.
        raise FaceFeatureProcessingError(detail=str(e)) from e
@router.get(
    "/{feature_id}",
    summary="获取人脸特征记录",
    description="根据ID获取人脸特征记录",
    response_model=FaceFeatureResponse,
)
async def get_face_feature(
    feature: FaceFeatureResponse = Depends(get_face_feature_by_id),
):
    """
    Return the face-feature record resolved by the dependency.

    - **feature_id**: feature record ID (path parameter)
    """
    return feature
@router.get(
    "/",
    summary="查询人脸特征记录",
    description="查询人脸特征记录列表,支持分页和过滤",
    response_model=FaceFeatureListResponse,
)
async def list_face_features(
    person_id: Optional[int] = Query(None, description="人员ID", gt=0),
    feature_type: Optional[int] = Query(None, description="特征类型", ge=0),
    status: Optional[FeatureStatus] = Query(None, description="计算状态"),
    start_date: Optional[datetime] = Query(None, description="开始时间"),
    end_date: Optional[datetime] = Query(None, description="结束时间"),
    has_feature_data: Optional[bool] = Query(None, description="是否有特征数据"),
    page: int = Query(1, description="页码", ge=1),
    page_size: int = Query(20, description="每页数量", ge=1, le=100),
    service: FaceFeatureService = Depends(get_face_feature_service),
):
    """
    Query face-feature records with optional filters and pagination.

    - **person_id**: filter by person ID (optional)
    - **feature_type**: filter by feature type (optional)
    - **status**: filter by computation status (optional)
    - **start_date**: start-time filter (optional)
    - **end_date**: end-time filter (optional)
    - **has_feature_data**: filter on presence of feature data (optional)
    - **page**: page number (default: 1)
    - **page_size**: page size (default: 20, max: 100)
    """
    # Note: the ``status`` parameter shadows the module-level ``status``
    # import inside this function.
    filters = {
        "person_id": person_id,
        "feature_type": feature_type,
        "status": status,
        "start_date": start_date,
        "end_date": end_date,
        "has_feature_data": has_feature_data,
    }
    criteria = FaceFeatureQuery(**filters)

    # Results are ordered by creation time, descending.
    return service.query_features(
        query=criteria,
        page=page,
        page_size=page_size,
        order_by="created_time",
        desc_order=True,
    )
@router.put(
    "/{feature_id}",
    response_model=FaceFeatureResponse,
    summary="更新人脸特征记录",
    description="更新指定ID的人脸特征记录"
)
async def update_face_feature(
    feature_id: int,
    update_data: FaceFeatureUpdate,
    service: FaceFeatureService = Depends(get_face_feature_service)
):
    """
    Update a face-feature record.

    - **feature_id**: feature record ID (path parameter)
    - **update_data**: fields to update (request body)

    Raises:
        FeatureNotFoundError: when the service returns a falsy result.
        FaceFeatureProcessingError: when the service raises ``ValueError``.
    """
    try:
        result = service.update_feature(feature_id, update_data)
    except ValueError as e:
        # Chain the original error so its traceback is preserved.
        raise FaceFeatureProcessingError(detail=str(e), feature_id=feature_id) from e
    if not result:
        raise FeatureNotFoundError(feature_id)
    return result
@router.delete(
    "/{feature_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="删除人脸特征记录",
    description="删除指定ID的人脸特征记录"
)
async def delete_face_feature(
    feature_id: int,
    service: FaceFeatureService = Depends(get_face_feature_service)
):
    """
    Delete a face-feature record.

    - **feature_id**: feature record ID (path parameter)

    Raises:
        FeatureNotFoundError: when the service reports nothing was deleted.
    """
    # Imported locally because ``Response`` is not imported at module level.
    from fastapi import Response

    deleted = service.delete_feature(feature_id)
    if not deleted:
        raise FeatureNotFoundError(feature_id)
    # A 204 response must not carry a body; JSONResponse(content=None) would
    # serialise the JSON literal ``null``, so return a bare Response instead.
    return Response(status_code=status.HTTP_204_NO_CONTENT)
@router.post(
    "/{feature_id}/start-processing",
    response_model=FaceFeatureResponse,
    summary="开始处理人脸特征",
    description="开始计算指定ID的人脸特征值"
)
async def start_face_feature_processing(
    feature_id: int,
    background_tasks: BackgroundTasks,
    service: FaceFeatureService = Depends(get_face_feature_service)
):
    """
    Start processing a face feature.

    - **feature_id**: feature record ID (path parameter)

    The record is switched to processing, a background task is scheduled to
    carry out the (simulated) computation, and the refreshed record is returned.

    Raises:
        FeatureNotFoundError: when the record does not exist.
        FaceFeatureProcessingError: when the record is not in NOT_STARTED,
            the service refuses to start, or the service raises ``ValueError``.
    """
    try:
        # Look up the record first.
        feature = service.get_feature(feature_id)
        if not feature:
            raise FeatureNotFoundError(feature_id)

        # Only records that have not been started may enter processing.
        if feature.status != FeatureStatus.NOT_STARTED:
            raise FaceFeatureProcessingError(
                detail=f"特征记录状态为 {feature.status_name},无法开始处理",
                feature_id=feature_id
            )

        started = service.start_processing(feature_id)
        if not started:
            raise FaceFeatureProcessingError(
                detail="开始处理失败",
                feature_id=feature_id
            )

        # Schedule the simulated computation; a real deployment would call
        # the actual feature-extraction service here.
        # NOTE(review): ``service`` comes from a request-scoped dependency;
        # confirm its database session is still usable when the background
        # task runs.
        background_tasks.add_task(
            simulate_feature_processing,
            feature_id=feature_id,
            service=service
        )

        # Return the record as it stands after the status change.
        return service.get_feature(feature_id)
    except ValueError as e:
        # Chain the original error so its traceback is preserved.
        raise FaceFeatureProcessingError(detail=str(e), feature_id=feature_id) from e
@router.post(
    "/{feature_id}/finish-processing",
    response_model=FaceFeatureResponse,
    summary="完成人脸特征处理",
    description="完成指定ID的人脸特征值计算"
)
async def finish_face_feature_processing(
    feature_id: int,
    success: bool = Query(True, description="是否成功完成"),
    service: FaceFeatureService = Depends(get_face_feature_service)
):
    """
    Finish processing a face feature.

    - **feature_id**: feature record ID (path parameter)
    - **success**: whether processing succeeded (query parameter, default: true)

    Raises:
        FeatureNotFoundError: when the record does not exist.
        FaceFeatureProcessingError: when the record is not in PROCESSING,
            the service refuses to finish, or the service raises ``ValueError``.
    """
    try:
        feature = service.get_feature(feature_id)
        if not feature:
            raise FeatureNotFoundError(feature_id)

        # Only records currently being processed can be finished.
        if feature.status != FeatureStatus.PROCESSING:
            raise FaceFeatureProcessingError(
                detail=f"特征记录状态为 {feature.status_name},无法完成处理",
                feature_id=feature_id
            )

        finished = service.finish_processing(feature_id, success)
        if not finished:
            raise FaceFeatureProcessingError(
                detail="完成处理失败",
                feature_id=feature_id
            )
        return service.get_feature(feature_id)
    except ValueError as e:
        # Chain the original error so its traceback is preserved.
        raise FaceFeatureProcessingError(detail=str(e), feature_id=feature_id) from e
@router.post(
    "/batch",
    response_model=List[FaceFeatureResponse],
    status_code=status.HTTP_201_CREATED,
    summary="批量创建人脸特征记录",
    description="批量创建多个人脸特征记录"
)
async def batch_create_face_features(
    batch_data: BatchFaceFeatureCreate,
    service: FaceFeatureService = Depends(get_face_feature_service)
):
    """
    Create several face-feature records in one call.

    - **items**: list of feature records (1-1000 entries)

    Raises:
        FaceFeatureProcessingError: when the service raises ``ValueError``.
    """
    try:
        return service.create_features_batch(batch_data)
    except ValueError as e:
        # Chain the original error so its traceback is preserved.
        raise FaceFeatureProcessingError(detail=str(e)) from e
@router.get(
    "/person/{person_id}",
    summary="获取人员的人脸特征记录",
    description="根据人员ID获取所有相关的人脸特征记录",
    response_model=List[FaceFeatureResponse],
)
async def get_face_features_by_person(
    person_id: int,
    limit: int = Query(100, description="返回数量限制", ge=1, le=1000),
    service: FaceFeatureService = Depends(get_face_feature_service),
):
    """
    List the face-feature records of one person.

    - **person_id**: person ID (path parameter)
    - **limit**: maximum number of records (query parameter, default: 100, max: 1000)
    """
    records = service.list_features_by_person(person_id, limit)
    return records
@router.get(
    "/stats/summary",
    summary="获取特征记录统计信息",
    description="获取人脸特征记录的统计摘要",
    response_model=FaceFeatureStatsResponse,
)
async def get_face_features_stats(
    service: FaceFeatureService = Depends(get_face_feature_service),
):
    """
    Return summary statistics for face-feature records.
    """
    summary = service.get_statistics()
    return summary
@router.get(
    "/person/{person_id}/stats",
    summary="获取人员特征统计信息",
    description="获取指定人员的特征记录统计信息"
)
async def get_person_face_features_stats(
    person_id: int,
    service: FaceFeatureService = Depends(get_face_feature_service)
):
    """
    Return feature statistics for one person.

    - **person_id**: person ID (path parameter)

    Raises:
        FaceFeatureProcessingError: when computing the statistics fails with
            anything other than an HTTP error.
    """
    stat_keys = (
        "total_features",
        "status_summary",
        "feature_types",
        "avg_processing_time",
        "successful_count",
    )
    try:
        stats = service.get_person_statistics(person_id)
        payload = {"person_id": person_id}
        payload.update({key: stats[key] for key in stat_keys})
        return payload
    except HTTPException:
        # Errors that already carry an HTTP status must not be re-wrapped as 400.
        raise
    except Exception as e:
        # Chain the original error so its traceback is preserved.
        raise FaceFeatureProcessingError(detail=str(e)) from e
async def simulate_feature_processing(
    feature_id: int,
    service: FaceFeatureService
):
    """Simulate the face-feature computation for one record.

    Placeholder for a real feature-extraction algorithm: it waits a random
    3-10 seconds, succeeds with 90% probability, records the outcome through
    ``service.finish_processing`` and, on success, stores a random 512-value
    float32 vector through ``service.update_feature_data``.

    Args:
        feature_id: ID of the feature record.
        service: Face-feature service used to persist the outcome.
    """
    import asyncio
    import logging
    import random

    logger = logging.getLogger(__name__)
    try:
        # Simulated computation latency (3-10 seconds).
        delay = random.uniform(3, 10)
        await asyncio.sleep(delay)

        # Simulated outcome (90% success rate).
        success = random.random() < 0.9
        service.finish_processing(feature_id, success)

        if success:
            # Fake 512-dimensional float32 feature vector.
            import numpy as np
            feature_data = np.random.randn(512).astype(np.float32).tobytes()
            service.update_feature_data(feature_id, feature_data)
    except Exception as e:
        # Log first, with traceback, so the cause is never lost.
        logger.error(f"特征计算失败 (ID: {feature_id}): {e}", exc_info=True)
        try:
            service.finish_processing(feature_id, False)
        except Exception:
            # Marking the record as failed is best-effort; it must not raise
            # out of the background task.
            logger.exception(f"无法将特征记录标记为失败 (ID: {feature_id})")

188
src/app.py Normal file
View File

@@ -0,0 +1,188 @@
"""
FastAPI主应用
将原来的main.py重命名为app.py
"""
import time
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import JSONResponse
from fastapi.openapi.docs import (
get_swagger_ui_html,
get_swagger_ui_oauth2_redirect_html,
get_redoc_html,
)
from fastapi.staticfiles import StaticFiles
from src.api.routes import face_features
from src.api.errors import (
APIError,
validation_exception_handler,
api_error_handler,
generic_exception_handler
)
from src.config import settings
from src.database.connection import init_database
from src.database.connection import db_manager
# Application lifespan management
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manage application startup and shutdown.

    - Startup: initialise the database and report the connection state.
    - Shutdown: close the database manager.
    """
    # Startup
    print("🚀 start algorithm service...")
    print(f"📊 db: {settings.DATABASE_NAME}")
    print(f"🔧 debug mode: {settings.DEBUG}")

    init_database()

    # A failed check is only reported; startup continues regardless.
    if db_manager.health_check():
        print("✅ 数据库连接正常")
    else:
        print("❌ 数据库连接失败")

    try:
        yield
    finally:
        # Run cleanup even when the context exits through an exception.
        print("🛑 algorithm service stopped...")
        db_manager.close()
# FastAPI application instance. The built-in docs pages are disabled
# (``docs_url`` / ``redoc_url`` set to None) because custom ones are
# registered separately.
app = FastAPI(
    lifespan=lifespan,
    title=settings.PROJECT_NAME,
    description=settings.PROJECT_DESCRIPTION,
    version=settings.PROJECT_VERSION,
    openapi_url=f"{settings.API_V1_PREFIX}/openapi.json",
    docs_url=None,
    redoc_url=None,
)
# Custom API documentation pages
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    """Serve Swagger UI with assets loaded from unpkg."""
    page_title = f"{app.title} - Swagger UI"
    return get_swagger_ui_html(
        title=page_title,
        openapi_url=app.openapi_url,
        oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
        swagger_css_url="https://unpkg.com/swagger-ui-dist@5/swagger-ui.css",
        swagger_js_url="https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js",
    )
@app.get(app.swagger_ui_oauth2_redirect_url, include_in_schema=False)
async def swagger_ui_redirect():
    """Serve the OAuth2 redirect page used by Swagger UI."""
    redirect_page = get_swagger_ui_oauth2_redirect_html()
    return redirect_page
@app.get("/redoc", include_in_schema=False)
async def redoc_html():
    """Serve the ReDoc page with its bundle loaded from unpkg."""
    # NOTE(review): the bundle URL uses the ``@next`` tag rather than a
    # pinned version — confirm this is intended.
    page_title = f"{app.title} - ReDoc"
    return get_redoc_html(
        title=page_title,
        openapi_url=app.openapi_url,
        redoc_js_url="https://unpkg.com/redoc@next/bundles/redoc.standalone.js",
    )
# Middleware configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.BACKEND_CORS_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# In debug mode any Host header is accepted; otherwise only local hosts.
# NOTE(review): confirm the non-debug host list matches the deployment.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["localhost", "127.0.0.1"] if not settings.DEBUG else ["*"],
)
# Request timing middleware
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """
    Attach the request processing time (seconds) as ``X-Process-Time``.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments the way time.time() can.
    started = time.perf_counter()
    response = await call_next(request)
    elapsed = time.perf_counter() - started
    response.headers["X-Process-Time"] = str(elapsed)
    return response
# Exception handlers: catch-all, API errors and request-validation errors.
app.add_exception_handler(Exception, generic_exception_handler)
app.add_exception_handler(APIError, api_error_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
# Root route
@app.get("/")
async def root():
    """
    Root path: basic service information.
    """
    info = {
        "message": "algorithm service",
        "version": settings.PROJECT_VERSION,
        "docs": "/docs",
        "api_prefix": settings.API_V1_PREFIX,
    }
    return info
@app.get("/health")
async def health_check():
    """
    Health-check endpoint.

    Returns the same payload as before; when the database check fails the
    response status is 503 so that probes relying on the status code see
    the failure.
    """
    db_healthy = db_manager.health_check()
    payload = {
        "status": "healthy" if db_healthy else "unhealthy",
        "database": "connected" if db_healthy else "disconnected",
        "timestamp": time.time()
    }
    if db_healthy:
        return payload
    # Signal the failure through the HTTP status as well as the body.
    return JSONResponse(status_code=503, content=payload)
# Register the face-feature API routes under the versioned prefix.
app.include_router(face_features.router, prefix=settings.API_V1_PREFIX)
# Custom 404 handler
@app.exception_handler(404)
async def not_found_handler(request: Request, exc):
    """
    Custom handler for 404 responses.

    A status-code handler is consulted for every HTTPException with that
    status, including ``APIError`` subclasses raised by the routes. Those are
    delegated to ``api_error_handler`` so their error code and detail are not
    overwritten; other 404s keep any custom detail they carry and otherwise
    get the generic "resource not found" message.
    """
    if isinstance(exc, APIError):
        return await api_error_handler(request, exc)

    detail = getattr(exc, "detail", None)
    # "Not Found" is the default detail of a plain 404 (e.g. unknown route).
    if detail and detail != "Not Found":
        message = detail
    else:
        message = f"请求的资源不存在: {request.url.path}"

    return JSONResponse(
        status_code=404,
        content={
            "error": {
                "code": "NOT_FOUND",
                "message": message
            }
        }
    )
# Public API of this module: the application instance.
__all__ = ["app"]

91
src/config.py Normal file
View File

@@ -0,0 +1,91 @@
"""
数据库配置模块
使用pydantic进行配置验证和管理
"""
from typing import Optional, List
from pydantic_settings import BaseSettings
from functools import lru_cache
from pydantic import PostgresDsn, field_validator
from pydantic_core.core_schema import FieldValidationInfo
class Settings(BaseSettings):
    """Application settings, loaded from the environment and ``.env``."""

    # API settings
    API_V1_PREFIX: str = "/api/v1"
    PROJECT_NAME: str = "algorithm-service"
    PROJECT_VERSION: str = "1.0.0"
    PROJECT_DESCRIPTION: str = "algorithm-service"
    BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000", "http://localhost:8000"]

    # Database settings
    # NOTE(review): a real-looking password is hard-coded as the default;
    # consider supplying it only through the environment.
    DATABASE_HOST: str = "localhost"
    DATABASE_PORT: int = 5432
    DATABASE_USER: str = "postgres"
    DATABASE_PASSWORD: str = "yipai123"
    DATABASE_NAME: str = "pmms"
    DATABASE_SCHEMA: str = "public"

    # Connection-pool settings
    DATABASE_POOL_SIZE: int = 10
    DATABASE_MAX_OVERFLOW: int = 20
    DATABASE_POOL_RECYCLE: int = 3600  # connection recycle time (seconds)
    DATABASE_ECHO: bool = False  # SQL logging; keep False in production

    # Application settings
    APP_NAME: str = "SurFaceFeature API"
    APP_VERSION: str = "1.0.0"
    DEBUG: bool = False

    # Logging settings
    LOG_LEVEL: str = "INFO"
    LOG_FILE: Optional[str] = None

    # Async settings
    ASYNC_MODE: bool = False

    # JWT settings (reserved)
    # NOTE(review): placeholder secret; must be overridden before real use.
    SECRET_KEY: str = "your-secret-key-here-change-in-production"
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30

    def _build_database_url(self, scheme: str) -> str:
        """Assemble a connection URL for ``scheme`` with escaped credentials."""
        from urllib.parse import quote

        # Percent-encode user and password so characters such as '@', ':'
        # or '/' cannot corrupt the URL structure.
        user = quote(self.DATABASE_USER, safe="")
        password = quote(self.DATABASE_PASSWORD, safe="")
        return (
            f"{scheme}://{user}:{password}"
            f"@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}"
        )

    @property
    def DATABASE_URL(self) -> str:
        """Build the synchronous database connection URL."""
        return self._build_database_url("postgresql")

    @property
    def ASYNC_DATABASE_URL(self) -> str:
        """Build the asynchronous (asyncpg) database connection URL."""
        return self._build_database_url("postgresql+asyncpg")

    @field_validator("DATABASE_POOL_SIZE")
    @classmethod
    def validate_pool_size(cls, v):
        """Ensure the pool size lies between 1 and 100 inclusive."""
        if v < 1:
            raise ValueError("DATABASE_POOL_SIZE must be at least 1")
        if v > 100:
            raise ValueError("DATABASE_POOL_SIZE cannot exceed 100")
        return v

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"
        case_sensitive = False
        extra = "ignore"
@lru_cache()
def get_settings() -> Settings:
    """
    Return the cached settings instance.

    ``lru_cache`` ensures ``Settings()`` is constructed only once, so the
    ``.env`` file is not reloaded on every call.
    """
    instance = Settings()
    return instance


# Module-level settings instance for convenient import.
settings = get_settings()

66
src/database/base.py Normal file
View File

@@ -0,0 +1,66 @@
"""
数据库模型基类
"""
from datetime import datetime
from typing import Any, Dict
from sqlalchemy import Column, DateTime, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import func
Base = declarative_base()
class BaseModel(Base):
    """Abstract base class giving every model the shared columns and helpers."""

    __abstract__ = True

    id = Column(Integer, primary_key=True, index=True, autoincrement=True)
    created_time = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
        comment="创建时间",
    )

    def to_dict(self, exclude: list = None) -> Dict[str, Any]:
        """
        Convert the model instance to a dictionary keyed by column name.

        Args:
            exclude: Column names to leave out.

        Returns:
            Mapping of column name to value. Datetimes are rendered with
            ``isoformat()``; bytes are rendered as hex, with empty bytes
            becoming ``None``.
        """
        skipped = exclude or []
        payload = {}
        for col in self.__table__.columns:
            name = col.name
            if name in skipped:
                continue
            raw = getattr(self, name)
            # Convert types that are not directly serialisable.
            if isinstance(raw, datetime):
                payload[name] = raw.isoformat()
            elif isinstance(raw, bytes):
                payload[name] = raw.hex() if raw else None
            else:
                payload[name] = raw
        return payload

    def update_from_dict(self, data: Dict[str, Any]) -> None:
        """
        Update attributes from a dictionary.

        Keys that are not existing attributes are ignored, and ``id`` is
        never overwritten.

        Args:
            data: Mapping of attribute name to new value.
        """
        for name, new_value in data.items():
            if not hasattr(self, name) or name == 'id':
                continue
            setattr(self, name, new_value)

    def __repr__(self) -> str:
        """Return a short representation with the class name and ID."""
        cls_name = self.__class__.__name__
        return f"<{cls_name}(id={self.id})>"

252
src/database/connection.py Normal file
View File

@@ -0,0 +1,252 @@
"""
数据库连接管理模块
支持同步和异步两种模式
"""
import logging
from typing import Optional, Generator, AsyncGenerator
from contextlib import asynccontextmanager, contextmanager
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
create_async_engine,
async_sessionmaker
)
from src.config import settings
from src.utils.logger import setup_logger
logger = setup_logger(__name__)
class DatabaseManager:
    """Synchronous database manager owning one engine and its session factory."""

    def __init__(self):
        self._engine: Optional[Engine] = None
        self._session_factory: Optional[sessionmaker] = None

    def init_engine(self) -> None:
        """Create the engine and session factory from ``settings``.

        Does nothing (apart from a warning) when already initialised.

        Raises:
            Exception: Any error raised during creation is logged and re-raised.
        """
        if self._engine is not None:
            logger.warning("Database engine already initialized")
            return
        try:
            self._engine = create_engine(
                settings.DATABASE_URL,
                pool_size=settings.DATABASE_POOL_SIZE,
                max_overflow=settings.DATABASE_MAX_OVERFLOW,
                pool_recycle=settings.DATABASE_POOL_RECYCLE,
                echo=settings.DATABASE_ECHO,
                pool_pre_ping=True,  # ping connections before use
                connect_args={
                    "options": f"-csearch_path={settings.DATABASE_SCHEMA}"
                }
            )
            self._session_factory = sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=self._engine,
                expire_on_commit=False  # keep loaded attributes usable after commit
            )
            logger.info("Database engine initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize database engine: {e}")
            raise

    @property
    def engine(self) -> Engine:
        """Return the engine, initialising it on first access."""
        if self._engine is None:
            self.init_engine()
        return self._engine

    @property
    def session_factory(self) -> sessionmaker:
        """Return the session factory, initialising it on first access."""
        if self._session_factory is None:
            self.init_engine()
        return self._session_factory

    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """
        Context manager yielding a session.

        Commits when the block completes, rolls back and re-raises on any
        exception, and always closes the session.

        Example:
            with db_manager.get_session() as session:
                result = session.query(User).all()
        """
        session = self.session_factory()
        try:
            yield session
            session.commit()
            logger.debug("Session committed successfully")
        except Exception as e:
            session.rollback()
            logger.error(f"Session rollback due to error: {e}")
            raise
        finally:
            session.close()
            logger.debug("Session closed")

    def execute_raw_sql(self, sql: str, params: Optional[dict] = None) -> list:
        """Execute a raw SQL statement and return its rows as dictionaries.

        Args:
            sql: SQL text; use bound parameters rather than string formatting.
            params: Optional bound-parameter values.

        Returns:
            One dict per row, or an empty list for statements that return
            no rows.
        """
        with self.get_session() as session:
            result = session.execute(text(sql), params or {})
            # Statements such as INSERT/UPDATE without RETURNING have no rows
            # to iterate; iterating them would raise.
            if not result.returns_rows:
                return []
            return [dict(row._mapping) for row in result]

    def health_check(self) -> bool:
        """Return True when a ``SELECT 1`` round-trip succeeds, else False."""
        try:
            with self.engine.connect() as conn:
                conn.execute(text("SELECT 1"))
            logger.info("Database health check passed")
            return True
        except Exception as e:
            logger.error(f"Database health check failed: {e}")
            return False

    def close(self) -> None:
        """Dispose of the engine and reset the manager to uninitialised."""
        if self._engine:
            self._engine.dispose()
            self._engine = None
            self._session_factory = None
            logger.info("Database connections closed")
class AsyncDatabaseManager:
    """Asynchronous database manager owning one async engine and session factory."""

    def __init__(self):
        self._engine: Optional[AsyncEngine] = None
        self._async_session_factory: Optional[async_sessionmaker] = None

    async def init_engine(self) -> None:
        """Create the async engine and session factory from ``settings``.

        Does nothing (apart from a warning) when already initialised.

        Raises:
            Exception: Any error raised during creation is logged and re-raised.
        """
        if self._engine is not None:
            logger.warning("Async database engine already initialized")
            return
        try:
            self._engine = create_async_engine(
                settings.ASYNC_DATABASE_URL,
                pool_size=settings.DATABASE_POOL_SIZE,
                max_overflow=settings.DATABASE_MAX_OVERFLOW,
                pool_recycle=settings.DATABASE_POOL_RECYCLE,
                echo=settings.DATABASE_ECHO,
                pool_pre_ping=True,
                connect_args={
                    "server_settings": {
                        "search_path": settings.DATABASE_SCHEMA
                    }
                }
            )
            self._async_session_factory = async_sessionmaker(
                self._engine,
                class_=AsyncSession,
                expire_on_commit=False,
                autoflush=False,
                autocommit=False
            )
            logger.info("Async database engine initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize async database engine: {e}")
            raise

    @property
    async def engine(self) -> AsyncEngine:
        """Return the async engine, initialising it on first access.

        This is an async property: accessing it yields a coroutine that
        must be awaited.
        """
        if self._engine is None:
            await self.init_engine()
        return self._engine

    @property
    async def async_session_factory(self) -> async_sessionmaker:
        """Return the async session factory, initialising it on first access.

        This is an async property: accessing it yields a coroutine that
        must be awaited.
        """
        if self._async_session_factory is None:
            await self.init_engine()
        return self._async_session_factory

    @asynccontextmanager
    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """
        Async context manager yielding a session.

        Commits when the block completes, rolls back and re-raises on any
        exception, and always closes the session.

        Example:
            async with async_db_manager.get_session() as session:
                result = await session.execute(query)
        """
        if self._async_session_factory is None:
            await self.init_engine()
        session = self._async_session_factory()
        try:
            yield session
            await session.commit()
            logger.debug("Async session committed successfully")
        except Exception as e:
            await session.rollback()
            logger.error(f"Async session rollback due to error: {e}")
            raise
        finally:
            await session.close()
            logger.debug("Async session closed")

    async def health_check(self) -> bool:
        """Return True when a ``SELECT 1`` round-trip succeeds, else False."""
        try:
            # Initialise lazily, as the synchronous manager does, so a check
            # before explicit initialisation tests the database rather than
            # failing on a missing engine.
            if self._engine is None:
                await self.init_engine()
            async with self._engine.connect() as conn:
                await conn.execute(text("SELECT 1"))
            logger.info("Async database health check passed")
            return True
        except Exception as e:
            logger.error(f"Async database health check failed: {e}")
            return False

    async def close(self) -> None:
        """Dispose of the async engine and reset the manager to uninitialised."""
        if self._engine:
            await self._engine.dispose()
            self._engine = None
            self._async_session_factory = None
            logger.info("Async database connections closed")
# Module-level manager instances (engines are created on init / first use).
db_manager = DatabaseManager()
async_db_manager = AsyncDatabaseManager()
def init_database() -> None:
    """Initialise the synchronous database engine on the global manager."""
    manager = db_manager
    manager.init_engine()
async def init_async_database() -> None:
    """Initialise the asynchronous database engine on the global manager."""
    manager = async_db_manager
    await manager.init_engine()
def get_db() -> Generator[Session, None, None]:
    """Yield a session from the global manager, for dependency injection.

    The session follows ``db_manager.get_session`` semantics: commit when
    the consumer finishes, rollback on error, always closed.
    """
    with db_manager.get_session() as db_session:
        yield db_session
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield an async session from the global manager, for dependency injection.

    The session follows ``async_db_manager.get_session`` semantics: commit
    when the consumer finishes, rollback on error, always closed.
    """
    async with async_db_manager.get_session() as db_session:
        yield db_session

207
src/main.py Normal file
View File

@@ -0,0 +1,207 @@
"""
主程序示例
演示如何使用各个模块
"""
import asyncio
from datetime import datetime, timedelta
from src.config import settings
from src.database.connection import db_manager, init_database
from src.models.face_feature import SurFaceFeature
from src.repositories.face_feature_repository import FaceFeatureRepository
from src.services.face_feature_service import FaceFeatureService
from src.schemas.face_feature import (
FaceFeatureCreate,
FaceFeatureUpdate,
FaceFeatureQuery,
FeatureStatus,
BatchFaceFeatureCreate
)
def demo_sync_operations():
    """Run the synchronous demo: create, process, query, update and delete.

    The whole demo runs inside one session. Any failure is caught outside
    the session context, so the session rolls back instead of committing a
    half-finished run, and the database manager is always closed.
    """
    import traceback

    print("=== 同步操作演示 ===")

    # Initialise the database engine.
    init_database()

    try:
        with db_manager.get_session() as session:
            repository = FaceFeatureRepository(session)
            service = FaceFeatureService(repository)

            # 1. Create a feature record (reuse it when it already exists).
            print("\n1. 创建特征记录")
            test_person_id = 1001
            test_feature_type = 1

            existing = service.get_feature_by_person_and_type(test_person_id, test_feature_type)
            if existing:
                print(f"记录已存在: ID={existing.id}, 状态={existing.status_name}")
                feature_id = existing.id
            else:
                feature_data = FaceFeatureCreate(
                    person_id=test_person_id,
                    feature_type=test_feature_type,
                    pic_id="test_image_001.jpg",
                    status=FeatureStatus.NOT_STARTED
                )
                feature_response = service.create_feature(feature_data)
                print(f"创建成功: ID={feature_response.id}, 人员ID={feature_response.person_id}")
                feature_id = feature_response.id

            # 2. Start processing when the record has not been started.
            print("\n2. 检查并开始处理特征计算")
            feature = service.get_feature(feature_id)
            if feature and feature.status == FeatureStatus.NOT_STARTED:
                if service.start_processing(feature_id):
                    print(f"已开始处理: ID={feature_id}")
            else:
                print(f"特征计算已开始或已完成: 状态={feature.status_name if feature else '未知'}")

            # 3. Finish processing when the record is still processing.
            print("\n3. 检查并完成处理特征计算")
            feature = service.get_feature(feature_id)
            if feature and feature.status == FeatureStatus.PROCESSING:
                if service.finish_processing(feature_id, success=True):
                    print(f"已完成处理: ID={feature_id}")
            else:
                print(f"特征计算已完成或未开始: 状态={feature.status_name if feature else '未知'}")

            # 4. Fetch the record.
            print("\n4. 查询特征记录")
            retrieved = service.get_feature(feature_id)
            if retrieved:
                print(f"查询成功: ID={retrieved.id}, 状态={retrieved.status_name}")
                print(f"处理时间: {retrieved.processing_time}")
                print(f"是否有特征数据: {retrieved.has_feature_data}")

            # 5. Batch create, skipping records that already exist.
            print("\n5. 批量创建特征记录")
            batch_items = []
            for i in range(3):
                person_id = 2000 + i
                existing = service.get_feature_by_person_and_type(person_id, 1)
                if not existing:
                    batch_items.append(FaceFeatureCreate(person_id=person_id, feature_type=1))

            if batch_items:
                batch_data = BatchFaceFeatureCreate(items=batch_items)
                try:
                    batch_result = service.create_features_batch(batch_data)
                    print(f"批量创建成功: {len(batch_result)}条记录")
                except ValueError as e:
                    print(f"批量创建失败: {e}")
            else:
                print("所有记录已存在,跳过批量创建")

            # 6. Query a page of records.
            print("\n6. 查询特征记录列表")
            query = FaceFeatureQuery(
                feature_type=1,
                start_date=datetime.now() - timedelta(days=1)
            )
            result = service.query_features(query, page=1, page_size=10)
            print(f"查询结果: 共{result.total}条记录,本页{len(result.items)}")
            if result.items:
                print(f"第一笔记录: ID={result.items[0].id}, 人员ID={result.items[0].person_id}")

            # 7. Statistics.
            print("\n7. 获取统计信息")
            stats = service.get_statistics()
            print(f"总记录数: {stats.total_count}")
            print(f"状态分布: {stats.by_status}")
            print(f"特征类型分布: {stats.by_feature_type}")

            # 8. Update the feature data, then re-read to confirm.
            print("\n8. 更新特征数据")
            test_feature_data = b"test_feature_data_12345"
            success = service.update_feature_data(feature_id, test_feature_data)
            print(f"更新特征数据: {'成功' if success else '失败'}")

            updated_feature = service.get_feature(feature_id)
            if updated_feature:
                print(f"更新后是否有特征数据: {updated_feature.has_feature_data}")

            # 9. Deletion: create a throwaway record, delete it, verify.
            print("\n9. 演示删除操作")
            delete_person_id = 9999
            delete_feature_data = FaceFeatureCreate(
                person_id=delete_person_id,
                feature_type=3,
                pic_id="to_delete.jpg"
            )
            delete_feature = service.create_feature(delete_feature_data)
            print(f"创建待删除记录: ID={delete_feature.id}")

            delete_success = service.delete_feature(delete_feature.id)
            print(f"删除记录: {'成功' if delete_success else '失败'}")

            deleted_check = service.get_feature(delete_feature.id)
            print(f"验证删除: {'记录不存在' if deleted_check is None else '记录还存在'}")
    except Exception as e:
        # Caught outside the session context so the session has already
        # rolled back by the time the failure is reported.
        print(f"操作失败: {e}")
        traceback.print_exc()
    finally:
        # Always release database connections.
        db_manager.close()
        print("\n数据库连接已关闭")
async def demo_async_operations():
    """Placeholder for the asynchronous demo; it only prints guidance."""
    # The async flow needs the async database manager and an async
    # repository; this function only shows the intended structure.
    messages = (
        "\n=== 异步操作演示 ===",
        "异步操作示例代码已准备,需要配置异步数据库连接",
        "请使用 async_db_manager 和异步版本的repository",
    )
    for line in messages:
        print(line)
def main():
    """Entry point: print the configuration and run the synchronous demo."""
    import traceback

    print("人脸特征管理系统演示")
    print("=" * 50)

    # Show the active configuration.
    for line in (
        f"应用名称: {settings.APP_NAME}",
        f"数据库: {settings.DATABASE_NAME}",
        f"连接池大小: {settings.DATABASE_POOL_SIZE}",
    ):
        print(line)

    try:
        demo_sync_operations()
        # To run the async demo as well, uncomment the next line:
        # asyncio.run(demo_async_operations())
    except KeyboardInterrupt:
        print("\n程序被用户中断")
    except Exception as err:
        print(f"程序执行出错: {err}")
        traceback.print_exc()

    print("\n演示完成!")


if __name__ == "__main__":
    main()

197
src/models/face_feature.py Normal file
View File

@@ -0,0 +1,197 @@
"""
人脸特征数据模型
对应数据库表sur_face_feature
"""
from typing import Optional
from datetime import datetime
from enum import IntEnum
from sqlalchemy import (
Column,
Integer,
SmallInteger,
LargeBinary,
DateTime,
Text,
Index,
UniqueConstraint
)
from sqlalchemy.dialects.postgresql import BYTEA
from sqlalchemy.sql import func
from src.database.base import BaseModel
class FeatureStatus(IntEnum):
    """Computation status of a face-feature value."""

    # Not started
    NOT_STARTED = 0
    # Computation in progress
    PROCESSING = 1
    # Computation succeeded
    SUCCESS = 2
    # Computation failed
    FAILED = 3


# Alias kept for backward compatibility.
FeatureStatusEnum = FeatureStatus
class SurFaceFeature(BaseModel):
    """
    ORM model for the face feature table ``public.sur_face_feature``.

    Columns:
    - id: primary key
    - person_id: person ID
    - feature_type: model version
    - feature_data: feature vector (binary)
    - created_time: creation time
    - pic_id: picture ID
    - status: computation status
    - start_time: computation start time
    - finish_time: computation end time
    """

    __tablename__ = "sur_face_feature"
    __table_args__ = (
        # Unique constraint on (person_id, feature_type)
        UniqueConstraint("person_id", "feature_type", name="sur_face_feature_unique"),
        # Indexes on frequently queried columns
        Index("ix_sur_face_feature_person_id", "person_id"),
        Index("ix_sur_face_feature_feature_type", "feature_type"),
        Index("ix_sur_face_feature_status", "status"),
        Index("ix_sur_face_feature_created_time", "created_time"),
        {"schema": "public", "comment": "人脸特征值表"}
    )

    # Primary key (auto-increment sequence)
    id = Column(
        Integer,
        primary_key=True,
        index=True,
        comment="主键"
    )

    # Person ID (required)
    person_id = Column(
        Integer,
        nullable=False,
        comment="人员id"
    )

    # Model version (feature type)
    feature_type = Column(
        SmallInteger,
        nullable=True,  # NULL allowed, per the SQL definition
        comment="模型版本"
    )

    # Feature vector (binary data)
    feature_data = Column(
        BYTEA,  # PostgreSQL binary type
        nullable=True,
        comment="特征值"
    )

    # Creation time (set by the database server)
    created_time = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
        comment="创建时间"
    )

    # Picture ID
    pic_id = Column(
        Text,
        nullable=True,
        comment="图片id"
    )

    # Computation status
    status = Column(
        SmallInteger,
        default=FeatureStatusEnum.NOT_STARTED,
        nullable=False,
        comment="人脸特征值计算状态0=未开始1=计算中2=计算成功3=计算失败"
    )

    # Computation start time
    start_time = Column(
        DateTime(timezone=True),
        nullable=True,
        comment="特征计算开始时间"
    )

    # Computation end time
    finish_time = Column(
        DateTime(timezone=True),
        nullable=True,
        comment="特征计算结束时间"
    )

    # ----- Derived properties -----

    @property
    def status_name(self) -> str:
        """Return the enum member name for ``status``, or a placeholder if unknown."""
        try:
            return FeatureStatusEnum(self.status).name
        except ValueError:
            return f"未知状态({self.status})"

    @property
    def is_completed(self) -> bool:
        """Return True when the computation ended, successfully or not."""
        return self.status in [FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED]

    @property
    def processing_time(self) -> Optional[float]:
        """Return the elapsed seconds between start and finish, or None if either is unset."""
        if self.start_time and self.finish_time:
            return (self.finish_time - self.start_time).total_seconds()
        return None

    def start_processing(self) -> None:
        """Mark the record as processing and stamp the start time.

        The timestamp is timezone-aware because the column is declared with
        ``timezone=True``; a naive value could not be subtracted from an aware
        one in ``processing_time``.
        """
        self.status = FeatureStatusEnum.PROCESSING
        self.start_time = datetime.now().astimezone()
        self.finish_time = None

    def finish_processing(self, success: bool = True) -> None:
        """Mark the record as SUCCESS or FAILED and stamp the finish time.

        Args:
            success: True to record SUCCESS, False to record FAILED.
        """
        self.status = FeatureStatusEnum.SUCCESS if success else FeatureStatusEnum.FAILED
        # Timezone-aware for the same reason as in start_processing.
        self.finish_time = datetime.now().astimezone()

    def to_dict(self, exclude: Optional[list] = None) -> dict:
        """
        Convert the row to a dict without the binary payload.

        ``feature_data`` is always excluded (it is large); presence and length
        indicators are added in its place, together with the derived
        properties.

        Args:
            exclude: Extra field names to leave out.

        Returns:
            The dict produced by the base class's ``to_dict`` plus
            ``status_name``, ``is_completed``, ``processing_time`` and
            ``has_feature_data`` (and ``feature_data_length`` when data exists).
        """
        exclude = exclude or []
        # The binary payload is excluded by default because of its size
        default_exclude = ["feature_data"]
        final_exclude = list(set(exclude + default_exclude))
        result = super().to_dict(final_exclude)

        # Derived attributes
        result["status_name"] = self.status_name
        result["is_completed"] = self.is_completed
        result["processing_time"] = self.processing_time

        # NOTE(review): when the caller lists "feature_data" in ``exclude``,
        # has_feature_data is reported as False even if data exists — confirm
        # this is intended.
        if self.feature_data and "feature_data" not in exclude:
            result["has_feature_data"] = True
            result["feature_data_length"] = len(self.feature_data)
        else:
            result["has_feature_data"] = False
        return result

    def __repr__(self) -> str:
        return (f"<SurFaceFeature(id={self.id}, person_id={self.person_id}, "
                f"feature_type={self.feature_type}, status={self.status_name})>")

View File

@@ -0,0 +1,597 @@
"""
人脸特征数据仓库
数据访问层,处理所有数据库操作
"""
import logging
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any, Tuple
from contextlib import contextmanager
from sqlalchemy import select, update, delete, func, and_, or_, desc, asc
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.exc import SQLAlchemyError, IntegrityError
from src.models.face_feature import SurFaceFeature
from src.schemas.face_feature import (
FaceFeatureCreate,
FaceFeatureUpdate,
FaceFeatureQuery,
FeatureStatus
)
from src.models.face_feature import FeatureStatusEnum
from src.utils.logger import setup_logger
logger = setup_logger(__name__)
class FaceFeatureRepository:
    """Data-access layer for ``SurFaceFeature`` rows.

    Every method works through the SQLAlchemy session supplied at
    construction time. Nothing is committed here, so the caller owns the
    transaction boundary. Write methods roll the session back before
    re-raising a database error.
    """

    # Column names accepted by ``query_features(order_by=...)``. An explicit
    # whitelist keeps caller-supplied strings from resolving to methods,
    # properties or other non-column attributes of the model.
    _SORTABLE_COLUMNS = frozenset({
        "id",
        "person_id",
        "feature_type",
        "feature_data",
        "created_time",
        "pic_id",
        "status",
        "start_time",
        "finish_time",
    })

    def __init__(self, session: Session):
        """
        Initialise the repository.

        Args:
            session: SQLAlchemy session used for all operations.
        """
        self.session = session

    # ===== Create =====

    def create(self, feature_data: FaceFeatureCreate) -> SurFaceFeature:
        """
        Insert one feature record and flush it so its primary key is assigned.

        Args:
            feature_data: Creation payload; only explicitly set fields are
                copied onto the new row.

        Returns:
            The newly added ``SurFaceFeature`` (flushed, not committed).

        Raises:
            ValueError: If the insert violates an integrity constraint.
            SQLAlchemyError: On any other database error.
        """
        try:
            feature_dict = feature_data.model_dump(exclude_unset=True)
            feature = SurFaceFeature(**feature_dict)
            self.session.add(feature)
            self.session.flush()  # execute now, but do not commit
            logger.info(f"Created face feature record: id={feature.id}, person_id={feature.person_id}")
            return feature
        except IntegrityError as e:
            logger.error(f"Integrity error creating face feature: {e}")
            self.session.rollback()
            raise ValueError(f"Duplicate feature record for person_id={feature_data.person_id}, "
                             f"feature_type={feature_data.feature_type}") from e
        except SQLAlchemyError as e:
            logger.error(f"Database error creating face feature: {e}")
            self.session.rollback()
            raise

    def create_batch(self, features_data: List[FaceFeatureCreate]) -> List[SurFaceFeature]:
        """
        Insert several feature records in one flush.

        Args:
            features_data: Creation payloads.

        Returns:
            The newly added ``SurFaceFeature`` objects, in input order.

        Raises:
            ValueError: If any insert violates an integrity constraint.
            SQLAlchemyError: On any other database error.
        """
        try:
            features = [
                SurFaceFeature(**feature_data.model_dump(exclude_unset=True))
                for feature_data in features_data
            ]
            self.session.add_all(features)
            self.session.flush()
            logger.info(f"Created {len(features)} face feature records in batch")
            return features
        except IntegrityError as e:
            logger.error(f"Integrity error creating batch face features: {e}")
            self.session.rollback()
            raise ValueError("Duplicate feature record in batch") from e
        except SQLAlchemyError as e:
            logger.error(f"Database error creating batch face features: {e}")
            self.session.rollback()
            raise

    # ===== Read =====

    def get_by_id(self, feature_id: int) -> Optional[SurFaceFeature]:
        """
        Fetch a feature record by primary key.

        Args:
            feature_id: Record ID.

        Returns:
            The matching ``SurFaceFeature`` or None.
        """
        try:
            stmt = select(SurFaceFeature).where(SurFaceFeature.id == feature_id)
            result = self.session.execute(stmt)
            feature = result.scalar_one_or_none()
            if feature:
                logger.debug(f"Retrieved face feature by id: {feature_id}")
            else:
                logger.debug(f"Face feature not found by id: {feature_id}")
            return feature
        except SQLAlchemyError as e:
            logger.error(f"Database error getting face feature by id: {e}")
            raise

    def get_by_person_and_type(self, person_id: int, feature_type: int) -> Optional[SurFaceFeature]:
        """
        Fetch the record for a (person, feature type) pair.

        Args:
            person_id: Person ID.
            feature_type: Feature type (model version).

        Returns:
            The matching ``SurFaceFeature`` or None.
        """
        try:
            stmt = select(SurFaceFeature).where(
                and_(
                    SurFaceFeature.person_id == person_id,
                    SurFaceFeature.feature_type == feature_type
                )
            )
            result = self.session.execute(stmt)
            feature = result.scalar_one_or_none()
            if feature:
                logger.debug(f"Retrieved face feature: person_id={person_id}, feature_type={feature_type}")
            else:
                logger.debug(f"Face feature not found: person_id={person_id}, feature_type={feature_type}")
            return feature
        except SQLAlchemyError as e:
            logger.error(f"Database error getting face feature by person and type: {e}")
            raise

    def get_by_person(self, person_id: int, limit: int = 100) -> List[SurFaceFeature]:
        """
        List a person's feature records, newest first.

        Args:
            person_id: Person ID.
            limit: Maximum number of rows to return.

        Returns:
            List of ``SurFaceFeature`` objects ordered by ``created_time`` descending.
        """
        try:
            stmt = (
                select(SurFaceFeature)
                .where(SurFaceFeature.person_id == person_id)
                .order_by(desc(SurFaceFeature.created_time))
                .limit(limit)
            )
            result = self.session.execute(stmt)
            features = list(result.scalars().all())
            logger.debug(f"Retrieved {len(features)} face features for person_id={person_id}")
            return features
        except SQLAlchemyError as e:
            logger.error(f"Database error getting face features by person: {e}")
            raise

    def query_features(
        self,
        query: FaceFeatureQuery,
        page: int = 1,
        page_size: int = 20,
        order_by: str = "created_time",
        desc_order: bool = True
    ) -> Tuple[List[SurFaceFeature], int]:
        """
        Query feature records with filters, ordering and pagination.

        Args:
            query: Filter values; unset and None fields are ignored.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page.
            order_by: Column name to sort by. Names outside the sortable
                column whitelist fall back to ``created_time``.
            desc_order: Sort descending when True, ascending otherwise.

        Returns:
            ``(rows for the requested page, total number of matching rows)``.
        """
        try:
            conditions = []
            query_dict = query.model_dump(exclude_unset=True, exclude_none=True)

            # Translate each supplied filter into a WHERE condition
            if "person_id" in query_dict:
                conditions.append(SurFaceFeature.person_id == query_dict["person_id"])
            if "feature_type" in query_dict:
                conditions.append(SurFaceFeature.feature_type == query_dict["feature_type"])
            if "status" in query_dict:
                conditions.append(SurFaceFeature.status == query_dict["status"])
            if "start_date" in query_dict:
                conditions.append(SurFaceFeature.created_time >= query_dict["start_date"])
            if "end_date" in query_dict:
                conditions.append(SurFaceFeature.created_time <= query_dict["end_date"])
            if "has_feature_data" in query_dict:
                if query_dict["has_feature_data"]:
                    conditions.append(SurFaceFeature.feature_data.isnot(None))
                else:
                    conditions.append(SurFaceFeature.feature_data.is_(None))

            stmt = select(SurFaceFeature)
            if conditions:
                stmt = stmt.where(and_(*conditions))

            # Total is counted before ordering and pagination are applied
            count_stmt = select(func.count()).select_from(stmt.subquery())
            total_result = self.session.execute(count_stmt)
            total = total_result.scalar_one()

            # Only whitelisted column names may reach ORDER BY; a bare getattr
            # would also resolve methods and properties of the model.
            if order_by in self._SORTABLE_COLUMNS:
                order_column = getattr(SurFaceFeature, order_by)
            else:
                logger.warning(f"Unsupported order_by {order_by!r}; falling back to created_time")
                order_column = SurFaceFeature.created_time
            if desc_order:
                stmt = stmt.order_by(desc(order_column))
            else:
                stmt = stmt.order_by(asc(order_column))

            # Clamp so that page < 1 cannot produce a negative OFFSET
            offset = max(page - 1, 0) * page_size
            stmt = stmt.offset(offset).limit(page_size)

            result = self.session.execute(stmt)
            features = list(result.scalars().all())
            logger.debug(f"Query returned {len(features)} features (total: {total})")
            return features, total
        except SQLAlchemyError as e:
            logger.error(f"Database error querying face features: {e}")
            raise

    # ===== Update =====

    def update(self, feature_id: int, update_data: FaceFeatureUpdate) -> Optional[SurFaceFeature]:
        """
        Apply a partial update to a feature record.

        Fields that are unset or None in ``update_data`` are left unchanged.

        Args:
            feature_id: Record ID.
            update_data: Fields to change.

        Returns:
            The updated ``SurFaceFeature``, or None if the record does not exist.

        Raises:
            ValueError: If the update violates an integrity constraint.
            SQLAlchemyError: On any other database error.
        """
        try:
            feature = self.get_by_id(feature_id)
            if not feature:
                logger.warning(f"Cannot update non-existent face feature: id={feature_id}")
                return None

            update_dict = update_data.model_dump(exclude_unset=True, exclude_none=True)
            for key, value in update_dict.items():
                setattr(feature, key, value)

            self.session.flush()
            logger.info(f"Updated face feature: id={feature_id}")
            return feature
        except IntegrityError as e:
            logger.error(f"Integrity error updating face feature: {e}")
            self.session.rollback()
            raise ValueError("Update would create duplicate record") from e
        except SQLAlchemyError as e:
            logger.error(f"Database error updating face feature: {e}")
            self.session.rollback()
            raise

    def update_feature_data(self, feature_id: int, feature_data: bytes) -> bool:
        """
        Replace the binary feature vector of a record.

        Args:
            feature_id: Record ID.
            feature_data: New feature vector (binary).

        Returns:
            True if a row was updated, False if the record does not exist.
        """
        try:
            stmt = (
                update(SurFaceFeature)
                .where(SurFaceFeature.id == feature_id)
                .values(feature_data=feature_data)
                .returning(SurFaceFeature.id)
            )
            result = self.session.execute(stmt)
            updated_id = result.scalar_one_or_none()
            if updated_id:
                logger.info(f"Updated feature data for face feature: id={feature_id}")
                return True
            else:
                logger.warning(f"Cannot update feature data for non-existent face feature: id={feature_id}")
                return False
        except SQLAlchemyError as e:
            logger.error(f"Database error updating feature data: {e}")
            self.session.rollback()
            raise

    def update_status(
        self,
        feature_id: int,
        status: FeatureStatus,
        start_time: Optional[datetime] = None,
        finish_time: Optional[datetime] = None
    ) -> bool:
        """
        Set the computation status and, optionally, the timing columns.

        Args:
            feature_id: Record ID.
            status: New status.
            start_time: Start time to store; left untouched when omitted.
            finish_time: Finish time to store; left untouched when omitted.

        Returns:
            True if a row was updated, False if the record does not exist.
        """
        try:
            update_values = {"status": status.value if isinstance(status, FeatureStatus) else status}
            if start_time:
                update_values["start_time"] = start_time
            if finish_time:
                update_values["finish_time"] = finish_time

            stmt = (
                update(SurFaceFeature)
                .where(SurFaceFeature.id == feature_id)
                .values(**update_values)
                .returning(SurFaceFeature.id)
            )
            result = self.session.execute(stmt)
            updated_id = result.scalar_one_or_none()
            if updated_id:
                logger.info(f"Updated status to {status} for face feature: id={feature_id}")
                return True
            else:
                logger.warning(f"Cannot update status for non-existent face feature: id={feature_id}")
                return False
        except SQLAlchemyError as e:
            logger.error(f"Database error updating status: {e}")
            self.session.rollback()
            raise

    # ===== Delete =====

    def delete(self, feature_id: int) -> bool:
        """
        Delete a feature record by primary key.

        Args:
            feature_id: Record ID.

        Returns:
            True if a row was deleted, False if the record does not exist.
        """
        try:
            stmt = delete(SurFaceFeature).where(SurFaceFeature.id == feature_id)
            result = self.session.execute(stmt)
            deleted_count = result.rowcount
            if deleted_count > 0:
                logger.info(f"Deleted face feature: id={feature_id}")
                return True
            else:
                logger.warning(f"Cannot delete non-existent face feature: id={feature_id}")
                return False
        except SQLAlchemyError as e:
            logger.error(f"Database error deleting face feature: {e}")
            self.session.rollback()
            raise

    def delete_by_person(self, person_id: int) -> int:
        """
        Delete every feature record belonging to a person.

        Args:
            person_id: Person ID.

        Returns:
            Number of rows deleted.
        """
        try:
            stmt = delete(SurFaceFeature).where(SurFaceFeature.person_id == person_id)
            result = self.session.execute(stmt)
            deleted_count = result.rowcount
            logger.info(f"Deleted {deleted_count} face features for person_id={person_id}")
            return deleted_count
        except SQLAlchemyError as e:
            logger.error(f"Database error deleting face features by person: {e}")
            self.session.rollback()
            raise

    # ===== Statistics =====

    def get_stats(self) -> Dict[str, Any]:
        """
        Compute aggregate statistics over all feature records.

        Returns:
            Dict with ``total_count``, ``by_status`` and ``by_feature_type``
            (both keyed by the stringified value) and ``avg_processing_time``
            in seconds (None when no finished record has both timestamps).
        """
        try:
            # Total number of rows
            total_stmt = select(func.count()).select_from(SurFaceFeature)
            total_result = self.session.execute(total_stmt)
            total_count = total_result.scalar_one()

            # Row count per status
            status_stmt = (
                select(SurFaceFeature.status, func.count())
                .group_by(SurFaceFeature.status)
            )
            status_result = self.session.execute(status_stmt)
            status_stats = {str(status): count for status, count in status_result}

            # Row count per feature type (NULL types are skipped)
            type_stmt = (
                select(SurFaceFeature.feature_type, func.count())
                .where(SurFaceFeature.feature_type.isnot(None))
                .group_by(SurFaceFeature.feature_type)
            )
            type_result = self.session.execute(type_stmt)
            type_stats = {str(feature_type): count for feature_type, count in type_result}

            # Average duration over finished (SUCCESS or FAILED) rows only
            time_stmt = (
                select(
                    func.avg(
                        func.extract('epoch', SurFaceFeature.finish_time - SurFaceFeature.start_time)
                    )
                )
                .where(
                    and_(
                        SurFaceFeature.start_time.isnot(None),
                        SurFaceFeature.finish_time.isnot(None),
                        SurFaceFeature.status.in_([FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED])
                    )
                )
            )
            time_result = self.session.execute(time_stmt)
            avg_time = time_result.scalar_one()

            stats = {
                "total_count": total_count,
                "by_status": status_stats,
                "by_feature_type": type_stats,
                # Compare against None so an average of exactly 0 is kept
                "avg_processing_time": float(avg_time) if avg_time is not None else None
            }
            logger.debug("Generated face feature statistics")
            return stats
        except SQLAlchemyError as e:
            logger.error(f"Database error getting statistics: {e}")
            raise

    # ===== Batch operations =====

    def mark_for_processing(self, limit: int = 100) -> List[SurFaceFeature]:
        """
        Claim pending records by switching them to PROCESSING.

        Rows are selected oldest first with ``FOR UPDATE SKIP LOCKED`` and
        then updated in a single statement.

        Args:
            limit: Maximum number of records to claim.

        Returns:
            The records that were selected for processing.
        """
        try:
            pending_stmt = (
                select(SurFaceFeature)
                .where(SurFaceFeature.status == FeatureStatusEnum.NOT_STARTED)
                .order_by(SurFaceFeature.created_time)
                .limit(limit)
                .with_for_update(skip_locked=True)  # skip rows that are already locked
            )
            result = self.session.execute(pending_stmt)
            pending_features = list(result.scalars().all())

            feature_ids = [f.id for f in pending_features]
            if feature_ids:
                update_stmt = (
                    update(SurFaceFeature)
                    .where(SurFaceFeature.id.in_(feature_ids))
                    .values(
                        status=FeatureStatusEnum.PROCESSING,
                        # Timezone-aware: the column is declared timezone=True
                        start_time=datetime.now().astimezone()
                    )
                )
                self.session.execute(update_stmt)

            logger.info(f"Marked {len(pending_features)} face features for processing")
            return pending_features
        except SQLAlchemyError as e:
            logger.error(f"Database error marking features for processing: {e}")
            self.session.rollback()
            raise

    def cleanup_old_features(self, days: int = 30) -> int:
        """
        Delete finished records older than the retention window.

        Only SUCCESS and FAILED records are removed.

        Args:
            days: Number of days to retain.

        Returns:
            Number of rows deleted.
        """
        try:
            # Timezone-aware cutoff, so the comparison against the
            # timezone-aware created_time column is unambiguous.
            cutoff_date = datetime.now().astimezone() - timedelta(days=days)
            stmt = delete(SurFaceFeature).where(
                and_(
                    SurFaceFeature.created_time < cutoff_date,
                    SurFaceFeature.status.in_([FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED])
                )
            )
            result = self.session.execute(stmt)
            deleted_count = result.rowcount
            logger.info(f"Cleaned up {deleted_count} old face features (older than {days} days)")
            return deleted_count
        except SQLAlchemyError as e:
            logger.error(f"Database error cleaning up old features: {e}")
            self.session.rollback()
            raise

86
src/run.py Normal file
View File

@@ -0,0 +1,86 @@
#!/usr/bin/env python3
"""
FastAPI应用启动脚本
"""
import uvicorn
import argparse
from src.config import settings
def main():
    """
    Parse command-line arguments and start the uvicorn server.

    When ``settings.DEBUG`` is true the server always runs in development
    mode: hot reload on, debug log level, a single worker. Otherwise the
    command-line values are used; ``--reload`` switches hot reload on and, in
    that case, forces a single worker.
    """
    parser = argparse.ArgumentParser(description="algorithm service")
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="监听主机 (默认: 0.0.0.0)"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="监听端口 (默认: 8000)"
    )
    parser.add_argument(
        "--reload",
        action="store_true",
        help="启用热重载 (开发模式)"
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=1,
        help="工作进程数 (生产模式)"
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="info",
        choices=["debug", "info", "warning", "error", "critical"],
        help="日志级别"
    )
    args = parser.parse_args()

    # Pick the uvicorn configuration for the current environment
    if settings.DEBUG:
        print("🔧 开发模式")
        uvicorn_config = {
            "host": args.host,
            "port": args.port,
            "reload": True,
            "log_level": "debug",
            "workers": 1
        }
    else:
        print("🚀 生产模式")
        uvicorn_config = {
            "host": args.host,
            "port": args.port,
            # Honour --reload; it was previously parsed but never used.
            "reload": args.reload,
            "log_level": args.log_level,
            # Reload and multiple workers do not combine, so run one worker
            # whenever reload is requested.
            "workers": 1 if args.reload else args.workers
        }

    # Start the server
    print(f"🌐 启动服务器: http://{args.host}:{args.port}")
    print(f"📚 API文档: http://{args.host}:{args.port}/docs")
    print(f"📊 健康检查: http://{args.host}:{args.port}/health")
    uvicorn.run(
        "src.app:app",
        **uvicorn_config
    )


if __name__ == "__main__":
    main()

213
src/schemas/face_feature.py Normal file
View File

@@ -0,0 +1,213 @@
"""
人脸特征值的Pydantic模型
用于数据验证和序列化
"""
from datetime import datetime
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field, field_validator, ConfigDict, model_validator
from enum import IntEnum
# 枚举定义(与数据库模型一致)
class FeatureStatus(IntEnum):
    """Computation status values used by the API schemas."""
    NOT_STARTED = 0  # not started
    PROCESSING = 1  # computation in progress
    SUCCESS = 2  # computation succeeded
    FAILED = 3  # computation failed
# 基础模型
class FaceFeatureBase(BaseModel):
    """Shared field set for face feature schemas.

    ``feature_data`` accepts either raw bytes or a hexadecimal string, which
    is decoded to bytes before normal field validation runs.
    """
    person_id: int = Field(..., description="人员ID", gt=0)
    feature_type: Optional[int] = Field(None, description="模型版本", ge=0)
    feature_data: Optional[bytes] = Field(None, description="特征值(二进制)")
    pic_id: Optional[str] = Field(None, description="图片ID", max_length=255)
    status: FeatureStatus = Field(
        default=FeatureStatus.NOT_STARTED,
        description="计算状态"
    )
    start_time: Optional[datetime] = Field(None, description="计算开始时间")
    finish_time: Optional[datetime] = Field(None, description="计算结束时间")

    @field_validator('feature_data', mode='before')
    @classmethod
    def validate_feature_data(cls, v):
        """Pass through None and bytes; decode a hex string; reject anything else."""
        if v is None or isinstance(v, bytes):
            return v
        if not isinstance(v, str):
            raise ValueError("feature_data must be bytes or hex string")
        # Strings are interpreted as hexadecimal text
        try:
            decoded = bytes.fromhex(v)
        except ValueError:
            raise ValueError("feature_data must be valid hex string or bytes")
        return decoded
# 创建模型
class FaceFeatureCreate(FaceFeatureBase):
    """Payload for creating a feature record; all fields come from ``FaceFeatureBase``."""
    # ID and created_time are not part of the creation payload
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "person_id": 1001,
                "feature_type": 1,
                "pic_id": "img_20250101_001",
                "status": 0
            }
        }
    )
class FaceFeatureUpdate(BaseModel):
    """Payload for partially updating a feature record; every field is optional.

    ``feature_data`` accepts raw bytes or a hexadecimal string, the same input
    forms as ``FaceFeatureBase``.
    """
    feature_type: Optional[int] = Field(None, description="模型版本", ge=0)
    feature_data: Optional[bytes] = Field(None, description="特征值(二进制)")
    pic_id: Optional[str] = Field(None, description="图片ID", max_length=255)
    status: Optional[FeatureStatus] = Field(None, description="计算状态")
    start_time: Optional[datetime] = Field(None, description="计算开始时间")
    finish_time: Optional[datetime] = Field(None, description="计算结束时间")

    @field_validator('feature_data', mode='before')
    @classmethod
    def validate_feature_data(cls, v):
        """Pass through None and bytes; decode a hex string; reject anything else.

        Without this hook a hex string sent on update would be coerced to the
        UTF-8 bytes of its text instead of being decoded, unlike on create.
        """
        if v is None or isinstance(v, bytes):
            return v
        if not isinstance(v, str):
            raise ValueError("feature_data must be bytes or hex string")
        try:
            return bytes.fromhex(v)
        except ValueError as exc:
            raise ValueError("feature_data must be valid hex string or bytes") from exc

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "status": 2,
                "finish_time": "2024-01-01T12:00:00Z"
            }
        }
    )
# 查询参数模型
class FaceFeatureQuery(BaseModel):
    """Optional filters for querying feature records; every field defaults to None."""
    person_id: Optional[int] = Field(None, description="人员ID", gt=0)
    feature_type: Optional[int] = Field(None, description="模型版本", ge=0)
    status: Optional[FeatureStatus] = Field(None, description="计算状态")
    start_date: Optional[datetime] = Field(None, description="开始时间")
    end_date: Optional[datetime] = Field(None, description="结束时间")
    has_feature_data: Optional[bool] = Field(None, description="是否有特征数据")
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "person_id": 1001,
                "status": 2,
                "start_date": "2024-01-01T00:00:00Z"
            }
        }
    )
# 响应模型
class FaceFeatureResponse(FaceFeatureBase):
    """Feature record as returned to clients.

    Extends the base fields with ``id`` and ``created_time`` plus four values
    derived after validation: ``status_name``, ``is_completed``,
    ``processing_time`` and ``has_feature_data``.
    """
    # NOTE(review): feature_data is inherited as raw bytes — confirm how the
    # API layer serialises it to JSON for non-text payloads.
    id: int
    created_time: datetime

    # Derived fields (filled in by set_computed_fields after validation)
    status_name: Optional[str] = None
    is_completed: Optional[bool] = None
    processing_time: Optional[float] = None
    has_feature_data: Optional[bool] = None

    @model_validator(mode='after')
    def set_computed_fields(self):
        """Populate every derived field from the validated values."""
        current_status = self.status

        # Human-readable status name, with a placeholder for unknown values
        try:
            label = FeatureStatus(current_status).name
        except ValueError:
            label = f"未知状态({current_status})"
        self.status_name = label

        # Finished means the computation ended, successfully or not
        self.is_completed = current_status in (FeatureStatus.SUCCESS, FeatureStatus.FAILED)

        # Elapsed seconds, only when both timestamps are present
        started, finished = self.start_time, self.finish_time
        if started and finished:
            self.processing_time = (finished - started).total_seconds()

        # True only for a non-empty payload
        self.has_feature_data = bool(self.feature_data)
        return self

    model_config = ConfigDict(
        from_attributes=True,
        populate_by_name=True,
        json_schema_extra={
            "example": {
                "id": 1,
                "person_id": 1001,
                "feature_type": 1,
                "status": 2,
                "status_name": "SUCCESS",
                "created_time": "2024-01-01T10:00:00Z",
                "pic_id": "img_20250101_001",
                "start_time": "2024-01-01T10:00:00Z",
                "finish_time": "2024-01-01T10:00:05Z",
                "is_completed": True,
                "processing_time": 5.0,
                "has_feature_data": True
            }
        }
    )
class FaceFeatureListResponse(BaseModel):
    """Paginated list of feature records with the total match count."""
    total: int = Field(..., description="总记录数")
    items: List[FaceFeatureResponse] = Field(..., description="记录列表")
    page: Optional[int] = Field(None, description="当前页码")
    page_size: Optional[int] = Field(None, description="每页数量")
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "total": 100,
                "page": 1,
                "page_size": 20,
                "items": []
            }
        }
    )
# 批量操作模型
class BatchFaceFeatureCreate(BaseModel):
    """Payload for creating several feature records at once (1 to 1000 items)."""
    # Pydantic v2 names the list-size constraints min_length / max_length;
    # min_items / max_items are the deprecated v1 spellings.
    items: List[FaceFeatureCreate] = Field(..., description="特征记录列表", min_length=1, max_length=1000)
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "items": [
                    {"person_id": 1001, "feature_type": 1},
                    {"person_id": 1002, "feature_type": 1}
                ]
            }
        }
    )
class FaceFeatureStatsResponse(BaseModel):
    """Aggregate statistics over feature records."""
    total_count: int = Field(..., description="总记录数")
    by_status: Dict[str, int] = Field(..., description="按状态统计")
    by_feature_type: Dict[str, int] = Field(..., description="按特征类型统计")
    avg_processing_time: Optional[float] = Field(None, description="平均处理时间(秒)")
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "total_count": 1000,
                "by_status": {"SUCCESS": 800, "PROCESSING": 100, "FAILED": 100},
                "by_feature_type": {"1": 500, "2": 500},
                "avg_processing_time": 5.2
            }
        }
    )

View File

@@ -0,0 +1,562 @@
"""
人脸特征业务逻辑服务层
"""
import logging
from datetime import datetime
from typing import Optional, List, Dict, Any
from contextlib import contextmanager
from src.repositories.face_feature_repository import FaceFeatureRepository
from src.schemas.face_feature import (
FaceFeatureCreate,
FaceFeatureUpdate,
FaceFeatureQuery,
FaceFeatureResponse,
FaceFeatureListResponse,
FaceFeatureStatsResponse,
BatchFaceFeatureCreate,
FeatureStatus
)
from src.models.face_feature import FeatureStatusEnum
from src.utils.logger import setup_logger
logger = setup_logger(__name__)
class FaceFeatureService:
"""人脸特征业务服务"""
def __init__(self, repository: FaceFeatureRepository):
"""
初始化服务
Args:
repository: 特征仓库实例
"""
self.repository = repository
# ===== CRUD操作 =====
def create_feature(self, feature_data: FaceFeatureCreate) -> FaceFeatureResponse:
"""
创建特征记录
Args:
feature_data: 特征数据
Returns:
创建的特征记录响应
"""
logger.info(f"Creating face feature for person_id={feature_data.person_id}")
# 业务逻辑验证
self._validate_feature_data(feature_data)
# 检查是否已存在相同记录
if feature_data.feature_type is not None:
existing = self.repository.get_by_person_and_type(
feature_data.person_id,
feature_data.feature_type
)
if existing:
raise ValueError(
f"Feature record already exists for person_id={feature_data.person_id}, "
f"feature_type={feature_data.feature_type}"
)
# 创建记录
feature = self.repository.create(feature_data)
# 转换为响应模型
return FaceFeatureResponse.model_validate(feature)
def create_features_batch(self, batch_data: BatchFaceFeatureCreate) -> List[FaceFeatureResponse]:
"""
批量创建特征记录
Args:
batch_data: 批量特征数据
Returns:
创建的特征记录响应列表
"""
logger.info(f"Creating {len(batch_data.items)} face features in batch")
# 验证所有数据
for feature_data in batch_data.items:
self._validate_feature_data(feature_data)
# 批量创建
features = self.repository.create_batch(batch_data.items)
# 转换为响应模型
return [FaceFeatureResponse.model_validate(f) for f in features]
def get_feature(self, feature_id: int) -> Optional[FaceFeatureResponse]:
"""
获取特征记录
Args:
feature_id: 特征记录ID
Returns:
特征记录响应或None
"""
logger.debug(f"Getting face feature: id={feature_id}")
feature = self.repository.get_by_id(feature_id)
if not feature:
return None
return FaceFeatureResponse.model_validate(feature)
def get_feature_by_person_and_type(
self,
person_id: int,
feature_type: int
) -> Optional[FaceFeatureResponse]:
"""
根据人员ID和特征类型获取特征记录
Args:
person_id: 人员ID
feature_type: 特征类型
Returns:
特征记录响应或None
"""
logger.debug(f"Getting face feature: person_id={person_id}, feature_type={feature_type}")
feature = self.repository.get_by_person_and_type(person_id, feature_type)
if not feature:
return None
return FaceFeatureResponse.model_validate(feature)
def list_features_by_person(
self,
person_id: int,
limit: int = 100
) -> List[FaceFeatureResponse]:
"""
获取人员的特征记录列表
Args:
person_id: 人员ID
limit: 返回数量限制
Returns:
特征记录响应列表
"""
logger.debug(f"Listing face features for person_id={person_id}")
features = self.repository.get_by_person(person_id, limit)
return [FaceFeatureResponse.model_validate(f) for f in features]
def query_features(
self,
query: FaceFeatureQuery,
page: int = 1,
page_size: int = 20,
order_by: str = "created_time",
desc_order: bool = True
) -> FaceFeatureListResponse:
"""
查询特征记录
Args:
query: 查询条件
page: 页码
page_size: 每页数量
order_by: 排序字段
desc_order: 是否降序
Returns:
特征记录列表响应
"""
logger.debug(f"Querying face features with filters: {query.model_dump(exclude_unset=True)}")
features, total = self.repository.query_features(
query, page, page_size, order_by, desc_order
)
items = [FaceFeatureResponse.model_validate(f) for f in features]
return FaceFeatureListResponse(
total=total,
items=items,
page=page,
page_size=page_size
)
def update_feature(
self,
feature_id: int,
update_data: FaceFeatureUpdate
) -> Optional[FaceFeatureResponse]:
"""
更新特征记录
Args:
feature_id: 特征记录ID
update_data: 更新数据
Returns:
更新后的特征记录响应或None
"""
logger.info(f"Updating face feature: id={feature_id}")
# 业务逻辑验证
if update_data.status is not None:
self._validate_status_transition(feature_id, update_data.status)
# 更新记录
feature = self.repository.update(feature_id, update_data)
if not feature:
return None
return FaceFeatureResponse.model_validate(feature)
def update_feature_data(self, feature_id: int, feature_data: bytes) -> bool:
"""
更新特征数据
Args:
feature_id: 特征记录ID
feature_data: 特征数据
Returns:
是否成功更新
"""
logger.info(f"Updating feature data for face feature: id={feature_id}")
# 验证特征数据
if not feature_data or len(feature_data) == 0:
raise ValueError("Feature data cannot be empty")
if len(feature_data) > 1024 * 1024: # 1MB限制
raise ValueError("Feature data is too large (max 1MB)")
return self.repository.update_feature_data(feature_id, feature_data)
def update_status(
self,
feature_id: int,
status: FeatureStatus,
start_time: Optional[datetime] = None,
finish_time: Optional[datetime] = None
) -> bool:
"""
更新计算状态
Args:
feature_id: 特征记录ID
status: 新状态
start_time: 开始时间
finish_time: 结束时间
Returns:
是否成功更新
"""
logger.info(f"Updating status to {status} for face feature: id={feature_id}")
# 验证状态转换
self._validate_status_transition(feature_id, status)
return self.repository.update_status(feature_id, status, start_time, finish_time)
def delete_feature(self, feature_id: int) -> bool:
"""
删除特征记录
Args:
feature_id: 特征记录ID
Returns:
是否成功删除
"""
logger.info(f"Deleting face feature: id={feature_id}")
return self.repository.delete(feature_id)
def delete_features_by_person(self, person_id: int) -> int:
"""
删除指定人员的所有特征记录
Args:
person_id: 人员ID
Returns:
删除的记录数
"""
logger.info(f"Deleting all face features for person_id={person_id}")
return self.repository.delete_by_person(person_id)
# ===== 业务操作 =====
def start_processing(self, feature_id: int) -> bool:
"""
开始处理特征计算
Args:
feature_id: 特征记录ID
Returns:
是否成功开始
"""
logger.info(f"Starting processing for face feature: id={feature_id}")
# 获取当前特征
feature = self.repository.get_by_id(feature_id)
if not feature:
return False
# 验证状态
if feature.status != FeatureStatusEnum.NOT_STARTED:
raise ValueError(
f"Cannot start processing for feature with status {feature.status_name}"
)
# 更新状态
return self.repository.update_status(
feature_id,
FeatureStatus.PROCESSING,
start_time=datetime.now()
)
def finish_processing(self, feature_id: int, success: bool = True) -> bool:
"""
完成特征计算
Args:
feature_id: 特征记录ID
success: 是否成功
Returns:
是否成功完成
"""
logger.info(f"Finishing processing for face feature: id={feature_id}, success={success}")
# 获取当前特征
feature = self.repository.get_by_id(feature_id)
if not feature:
return False
# 验证状态
if feature.status != FeatureStatusEnum.PROCESSING:
raise ValueError(
f"Cannot finish processing for feature with status {feature.status_name}"
)
# 更新状态
status = FeatureStatus.SUCCESS if success else FeatureStatus.FAILED
return self.repository.update_status(
feature_id,
status,
finish_time=datetime.now()
)
def process_pending_features(self, limit: int = 100) -> List[FaceFeatureResponse]:
"""
处理待计算的特征记录
Args:
limit: 最大处理数量
Returns:
处理中的特征记录列表
"""
logger.info(f"Processing up to {limit} pending face features")
features = self.repository.mark_for_processing(limit)
return [FaceFeatureResponse.model_validate(f) for f in features]
# ===== 统计和分析 =====
def get_statistics(self) -> FaceFeatureStatsResponse:
"""
获取特征记录统计信息
Returns:
统计信息响应
"""
logger.debug("Getting face feature statistics")
stats = self.repository.get_stats()
# 转换状态枚举名称
by_status = {}
for status_value, count in stats["by_status"].items():
try:
status_name = FeatureStatusEnum(int(status_value)).name
by_status[status_name] = count
except ValueError:
by_status[f"未知({status_value})"] = count
return FaceFeatureStatsResponse(
total_count=stats["total_count"],
by_status=by_status,
by_feature_type=stats["by_feature_type"],
avg_processing_time=stats["avg_processing_time"]
)
def get_person_statistics(self, person_id: int) -> Dict[str, Any]:
"""
获取人员的特征统计信息
Args:
person_id: 人员ID
Returns:
人员统计信息
"""
logger.debug(f"Getting statistics for person_id={person_id}")
# 获取该人员的所有特征
features = self.repository.get_by_person(person_id, limit=1000)
if not features:
return {
"person_id": person_id,
"total_features": 0,
"status_summary": {},
"feature_types": []
}
# 统计信息
status_summary = {}
feature_types = set()
successful_features = []
for feature in features:
# 状态统计
status_name = feature.status_name
status_summary[status_name] = status_summary.get(status_name, 0) + 1
# 特征类型
if feature.feature_type is not None:
feature_types.add(feature.feature_type)
# 成功完成的特征
if feature.status == FeatureStatusEnum.SUCCESS:
successful_features.append(feature)
# 计算平均处理时间(仅成功记录)
total_time = 0
valid_count = 0
for feature in successful_features:
if feature.processing_time:
total_time += feature.processing_time
valid_count += 1
avg_time = total_time / valid_count if valid_count > 0 else None
return {
"person_id": person_id,
"total_features": len(features),
"status_summary": status_summary,
"feature_types": sorted(list(feature_types)),
"avg_processing_time": avg_time,
"successful_count": len(successful_features)
}
def cleanup_old_records(self, days: int = 30) -> Dict[str, Any]:
"""
清理旧的特征记录
Args:
days: 保留天数
Returns:
清理结果
"""
logger.info(f"Cleaning up face features older than {days} days")
# 先获取清理前的统计
before_stats = self.repository.get_stats()
# 执行清理
deleted_count = self.repository.cleanup_old_features(days)
# 获取清理后的统计
after_stats = self.repository.get_stats()
return {
"days_retained": days,
"deleted_count": deleted_count,
"before_total": before_stats["total_count"],
"after_total": after_stats["total_count"],
"reduction_percentage": (
(before_stats["total_count"] - after_stats["total_count"]) /
before_stats["total_count"] * 100
if before_stats["total_count"] > 0 else 0
)
}
# ===== 私有方法 =====
def _validate_feature_data(self, feature_data: FaceFeatureCreate) -> None:
    """
    Validate the payload used to create a feature record.

    Args:
        feature_data: Incoming feature payload.

    Raises:
        ValueError: If ``person_id`` is not positive, ``feature_type`` is
            negative, ``feature_data`` is longer than 1024 * 1024 items,
            or ``status`` is not a valid ``FeatureStatus`` value.
    """
    # Person ID must be a positive identifier.
    if feature_data.person_id <= 0:
        raise ValueError("person_id must be greater than 0")

    # Feature type is optional, but may not be negative when given.
    if feature_data.feature_type is not None and feature_data.feature_type < 0:
        raise ValueError("feature_type must be non-negative")

    # Size cap uses len(); presumably feature_data is bytes-like so this is
    # a byte count -- TODO confirm against the schema.
    if feature_data.feature_data and len(feature_data.feature_data) > 1024 * 1024:
        raise ValueError("feature_data is too large (max 1MB)")

    # Status is only checked when truthy (falsy values skip validation).
    if feature_data.status:
        try:
            FeatureStatus(feature_data.status)
        except ValueError as exc:
            # Chain explicitly so the traceback shows the enum lookup as
            # the direct cause rather than an unrelated second failure.
            raise ValueError(f"Invalid status value: {feature_data.status}") from exc
def _validate_status_transition(self, feature_id: int, new_status: FeatureStatus) -> None:
    """
    Check that a record may move from its current status to ``new_status``.

    Permitted moves:
        NOT_STARTED -> PROCESSING
        PROCESSING  -> SUCCESS | FAILED
        FAILED      -> PROCESSING (retry)
        SUCCESS     -> (none)

    Args:
        feature_id: ID of the feature record to check.
        new_status: Requested status, either a ``FeatureStatus`` member or
            a raw value accepted by ``FeatureStatusEnum``.

    Raises:
        ValueError: If the move is not in the permitted table.
    """
    record = self.repository.get_by_id(feature_id)
    if not record:
        # No record found: nothing is validated here and no error is raised.
        return

    source = FeatureStatusEnum(record.status)

    transitions = {
        FeatureStatusEnum.NOT_STARTED: [FeatureStatusEnum.PROCESSING],
        FeatureStatusEnum.PROCESSING: [FeatureStatusEnum.SUCCESS, FeatureStatusEnum.FAILED],
        FeatureStatusEnum.SUCCESS: [],
        FeatureStatusEnum.FAILED: [FeatureStatusEnum.PROCESSING],  # retry allowed
    }

    # Normalise the requested status to the same enum as the stored one.
    if isinstance(new_status, FeatureStatus):
        raw_target = new_status.value
    else:
        raw_target = new_status
    target = FeatureStatusEnum(raw_target)

    permitted = transitions.get(source, [])
    if target in permitted:
        return

    raise ValueError(
        f"Cannot transition from {source.name} to {target.name}"
    )

88
src/utils/logger.py Normal file
View File

@@ -0,0 +1,88 @@
"""
日志配置模块
"""
import logging
import sys
from typing import Optional
from logging.handlers import RotatingFileHandler
from src.config import settings
def setup_logger(
    name: str,
    level: Optional[str] = None,
    log_file: Optional[str] = None
) -> logging.Logger:
    """
    Configure and return the logger called ``name``.

    A console handler (stdout) is always attached on first configuration;
    a rotating file handler (10 MB per file, 5 backups, UTF-8) is attached
    when ``log_file`` or ``settings.LOG_FILE`` is set. If the logger already
    has handlers, only its level is refreshed and no handler is added or
    changed.

    Args:
        name: Logger name.
        level: Level name such as "INFO" (case-insensitive). Defaults to
            ``settings.LOG_LEVEL``; unknown names fall back to INFO.
        log_file: Log file path. Defaults to ``settings.LOG_FILE``.

    Returns:
        The configured logger instance.
    """
    # Local import keeps this block self-contained; the module itself does
    # not import os at the top.
    import os

    if level is None:
        level = settings.LOG_LEVEL
    log_level = getattr(logging, level.upper(), logging.INFO)
    # getattr can hit a non-level attribute of the logging module (for
    # example BASIC_FORMAT); treat that like any other unknown name.
    if not isinstance(log_level, int):
        log_level = logging.INFO

    logger = logging.getLogger(name)
    logger.setLevel(log_level)

    # Avoid attaching duplicate handlers on repeated calls.
    if logger.handlers:
        return logger

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(log_level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    file_path = log_file or settings.LOG_FILE
    if file_path:
        try:
            # Create the parent directory first; otherwise a missing
            # directory silently disables file logging.
            log_dir = os.path.dirname(os.path.abspath(file_path))
            os.makedirs(log_dir, exist_ok=True)
            file_handler = RotatingFileHandler(
                file_path,
                maxBytes=10 * 1024 * 1024,  # 10MB
                backupCount=5,
                encoding='utf-8'
            )
        except Exception as e:
            # File logging is best-effort: keep console logging alive and
            # report the problem instead of failing at import time.
            logger.warning(f"Failed to create log file handler: {e}")
        else:
            file_handler.setLevel(log_level)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

    return logger
# Application-wide base logger, created when this module is imported using the
# defaults from settings. Despite the variable name, this is the named logger
# "sur_face_feature", not the logging module's root logger.
root_logger = setup_logger("sur_face_feature")
def get_logger(name: str) -> logging.Logger:
    """
    Return the logger called ``name``, configuring it on first use.

    This is a thin convenience wrapper that delegates to ``setup_logger``
    with the default level and log file.

    Args:
        name: Logger name.

    Returns:
        The logger instance.
    """
    configured_logger = setup_logger(name)
    return configured_logger