补充提交

This commit is contained in:
zqc
2026-01-08 10:33:53 +08:00
parent f86effd63c
commit a183d650b3
4 changed files with 410 additions and 0 deletions

0
__init__.py Normal file
View File

219
app.py Normal file
View File

@@ -0,0 +1,219 @@
"""
FastAPI主应用
将原来的main.py重命名为app.py
"""
import time
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import JSONResponse
from fastapi.openapi.docs import (
get_swagger_ui_html,
get_swagger_ui_oauth2_redirect_html,
get_redoc_html,
)
from fastapi.staticfiles import StaticFiles
from api.routes import face_features
from api.routes.algorithm_router import router as algorithm_router, sync_videofacebiz_params, sync_videofacebiz_blacklist
from api.errors import (
APIError,
validation_exception_handler,
api_error_handler,
generic_exception_handler
)
from config import settings
from database.connection import init_database
from database.connection import db_manager
from rtsp.service import rtsp_server
# 生命周期管理
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
应用生命周期管理
- 启动时:初始化数据库
- 关闭时:清理资源
"""
# 启动时
print("🚀 start algorithm service...")
print(f"📊 db: {settings.DATABASE_NAME}")
print(f"🔧 debug mode: {settings.DEBUG}")
# 初始化数据库
init_database()
# 数据库健康检查
if db_manager.health_check():
print("✅ 数据库连接正常")
else:
print("❌ 数据库连接失败")
# 启动 RTSP 服务(如果启用)
if settings.RTSP_ENABLED:
print("📹 启动 RTSP 服务...")
rtsp_server.start()
# 将 RTSP 服务实例保存到应用状态
app.state.rtsp_server = rtsp_server
# 自动同步VideoFaceBiz参数和黑名单
print("🔄 自动同步VideoFaceBiz参数和黑名单...")
try:
params_updated = sync_videofacebiz_params()
blacklist_loaded = sync_videofacebiz_blacklist()
print(f"✅ 自动同步完成 - 参数更新: {params_updated}个, 黑名单加载: {blacklist_loaded}")
except Exception as e:
print(f"⚠️ 自动同步失败: {e}")
else:
print("⚠️ RTSP 服务未启用")
yield
# 关闭时
print("🛑 algorithm service stopped...")
# 停止 RTSP 服务
if settings.RTSP_ENABLED:
print("🛑 停止 RTSP 服务...")
rtsp_server.stop()
db_manager.close()
# 创建FastAPI应用
app = FastAPI(
title=settings.PROJECT_NAME,
version=settings.PROJECT_VERSION,
description=settings.PROJECT_DESCRIPTION,
openapi_url=f"{settings.API_V1_PREFIX}/openapi.json",
docs_url=None, # 自定义docs
redoc_url=None, # 自定义redoc
lifespan=lifespan
)
# 自定义API文档页面
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
return get_swagger_ui_html(
openapi_url=app.openapi_url,
title=f"{app.title} - Swagger UI",
oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
swagger_js_url="https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js",
swagger_css_url="https://unpkg.com/swagger-ui-dist@5/swagger-ui.css",
)
@app.get(app.swagger_ui_oauth2_redirect_url, include_in_schema=False)
async def swagger_ui_redirect():
return get_swagger_ui_oauth2_redirect_html()
@app.get("/redoc", include_in_schema=False)
async def redoc_html():
return get_redoc_html(
openapi_url=app.openapi_url,
title=f"{app.title} - ReDoc",
redoc_js_url="https://unpkg.com/redoc@next/bundles/redoc.standalone.js",
)
# 中间件配置
app.add_middleware(
CORSMiddleware,
allow_origins=settings.BACKEND_CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=["*"] if settings.DEBUG else ["localhost", "127.0.0.1", "0.0.0.0"]
)
# 请求计时中间件
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
"""
添加请求处理时间头
"""
start_time = time.time()
response = await call_next(request)
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
return response
# 异常处理器
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(APIError, api_error_handler)
app.add_exception_handler(Exception, generic_exception_handler)
# 根路由
@app.get("/")
async def root():
"""
根路径
"""
return {
"message": "algorithm service",
"version": settings.PROJECT_VERSION,
"docs": "/docs",
"api_prefix": settings.API_V1_PREFIX
}
@app.get("/health")
async def health_check():
"""
健康检查端点
"""
# 检查数据库连接
db_healthy = db_manager.health_check()
return {
"status": "healthy" if db_healthy else "unhealthy",
"database": "connected" if db_healthy else "disconnected",
"timestamp": time.time()
}
# 注册API路由
app.include_router(
face_features.router,
prefix=settings.API_V1_PREFIX
)
app.include_router(
algorithm_router,
prefix=settings.API_V1_PREFIX
)
# 自定义404处理器
@app.exception_handler(404)
async def not_found_handler(request: Request, exc):
"""
自定义404错误处理器
"""
return JSONResponse(
status_code=404,
content={
"error": {
"code": "NOT_FOUND",
"message": f"请求的资源不存在: {request.url.path}"
}
}
)
# 导出应用实例
__all__ = ["app"]

103
config.py Normal file
View File

@@ -0,0 +1,103 @@
"""
数据库配置模块
使用pydantic进行配置验证和管理
"""
from typing import Optional, List
from pydantic_settings import BaseSettings
from functools import lru_cache
from pydantic import PostgresDsn, field_validator
from pydantic_core.core_schema import FieldValidationInfo
class Settings(BaseSettings):
"""应用配置类"""
RTSP_ENABLED: bool = True
# API配置
API_V1_PREFIX: str = "/api/v1"
PROJECT_NAME: str = "algorithm-service"
PROJECT_VERSION: str = "1.0.0"
PROJECT_DESCRIPTION: str = "algorithm-service"
BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000", "http://localhost:8000"]
# 数据库配置
DATABASE_HOST: str = "localhost"
DATABASE_PORT: int = 5432
DATABASE_USER: str = "postgres"
DATABASE_PASSWORD: str = "yipai123"
DATABASE_NAME: str = "pmms"
DATABASE_SCHEMA: str = "public"
# 连接池配置
DATABASE_POOL_SIZE: int = 10
DATABASE_MAX_OVERFLOW: int = 20
DATABASE_POOL_RECYCLE: int = 3600 # 连接回收时间(秒)
DATABASE_ECHO: bool = False # SQL日志生产环境设为False
# 应用配置
APP_NAME: str = "SurFaceFeature API"
APP_VERSION: str = "1.0.0"
DEBUG: bool = False
# 日志配置
LOG_LEVEL: str = "INFO"
LOG_FILE: Optional[str] = None
# 异步配置
ASYNC_MODE: bool = False
# 资源文件夹配置
FACE_REGISTER_IMAGE_RESOURCE_DIR: str = "D:/ruoyi/uploadPath/face"
VIDEO_RESOURCE_DIR: str = "D:/ruoyi/uploadPath/video"
FACE_CAL_FEATURE_TIMEOUT_HOURS: int = 10
FACE_MODEL_VERSION: int = 0 #insight_face_buffalo_l
FACE_USE_GPU: bool = True
FACE_USE_NPU: bool = False
SUR_CONFIG_TYPE_FACE: int = 0
SUR_CONFIG_SCOPE_GLOBAL: int = 0
# JWT配置预留
SECRET_KEY: str = "your-secret-key-here-change-in-production"
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
@property
def DATABASE_URL(self) -> str:
"""构建数据库连接URL"""
return f"postgresql://{self.DATABASE_USER}:{self.DATABASE_PASSWORD}@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}"
@property
def ASYNC_DATABASE_URL(self) -> str:
"""构建异步数据库连接URL"""
return f"postgresql+asyncpg://{self.DATABASE_USER}:{self.DATABASE_PASSWORD}@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}"
@field_validator("DATABASE_POOL_SIZE")
def validate_pool_size(cls, v):
"""验证连接池大小"""
if v < 1:
raise ValueError("DATABASE_POOL_SIZE must be at least 1")
if v > 100:
raise ValueError("DATABASE_POOL_SIZE cannot exceed 100")
return v
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
case_sensitive = False
extra = "ignore"
@lru_cache()
def get_settings() -> Settings:
"""
获取配置单例
使用lru_cache避免重复加载.env文件
"""
return Settings()
# 导出配置实例
settings = get_settings()

88
utils/logger.py Normal file
View File

@@ -0,0 +1,88 @@
"""
日志配置模块
"""
import logging
import sys
from typing import Optional
from logging.handlers import RotatingFileHandler
from config import settings
def setup_logger(
name: str,
level: Optional[str] = None,
log_file: Optional[str] = None
) -> logging.Logger:
"""
配置和获取logger
Args:
name: logger名称
level: 日志级别
log_file: 日志文件路径
Returns:
配置好的logger实例
"""
# 获取日志级别
if level is None:
level = settings.LOG_LEVEL
log_level = getattr(logging, level.upper(), logging.INFO)
# 创建logger
logger = logging.getLogger(name)
logger.setLevel(log_level)
# 避免重复添加handler
if logger.handlers:
return logger
# 创建formatter
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# 控制台handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(log_level)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# 文件handler如果配置了日志文件
if log_file or settings.LOG_FILE:
file_path = log_file or settings.LOG_FILE
try:
file_handler = RotatingFileHandler(
file_path,
maxBytes=10 * 1024 * 1024, # 10MB
backupCount=5,
encoding='utf-8'
)
file_handler.setLevel(log_level)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
except Exception as e:
logger.warning(f"Failed to create log file handler: {e}")
return logger
# 创建根logger
root_logger = setup_logger("sur_face_feature")
def get_logger(name: str) -> logging.Logger:
"""
获取指定名称的logger
Args:
name: logger名称
Returns:
logger实例
"""
return setup_logger(name)