"""
Database connection management module.

Supports both synchronous and asynchronous modes.
"""
import logging
from typing import Optional, Generator, AsyncGenerator
from contextlib import asynccontextmanager, contextmanager

from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    create_async_engine,
    async_sessionmaker,
)

from config import settings
from utils.logger import setup_logger

logger = setup_logger(__name__)


class DatabaseManager:
    """Database manager (synchronous).

    Lazily builds one SQLAlchemy ``Engine`` plus a ``sessionmaker`` from the
    values in ``settings`` and hands out transactional sessions.
    """

    def __init__(self):
        self._engine: Optional[Engine] = None
        self._session_factory: Optional[sessionmaker] = None

    def init_engine(self) -> None:
        """Initialize the database engine and session factory.

        Calling this again after a successful initialization only logs a
        warning. The engine and the session factory are stored together, and
        only after both were built, so a failure never leaves the manager
        half-initialized (engine set, factory still ``None``).

        Raises:
            Exception: whatever ``create_engine`` / ``sessionmaker`` raised;
                it is logged and re-raised.
        """
        if self._engine is not None:
            logger.warning("Database engine already initialized")
            return

        engine: Optional[Engine] = None
        try:
            engine = create_engine(
                settings.DATABASE_URL,
                pool_size=settings.DATABASE_POOL_SIZE,
                max_overflow=settings.DATABASE_MAX_OVERFLOW,
                pool_recycle=settings.DATABASE_POOL_RECYCLE,
                echo=settings.DATABASE_ECHO,
                pool_pre_ping=True,  # test pooled connections for liveness on checkout
                # NOTE(review): "-csearch_path=..." is libpq/psycopg2 "options"
                # syntax, so DATABASE_URL presumably targets PostgreSQL — confirm.
                connect_args={
                    "options": f"-csearch_path={settings.DATABASE_SCHEMA}"
                },
            )
            session_factory = sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=engine,
                # Keep loaded attributes usable after commit, avoiding
                # lazy-load/refresh problems once the session is closed.
                expire_on_commit=False,
            )
        except Exception as e:
            logger.error(f"Failed to initialize database engine: {e}")
            if engine is not None:
                # Release the pool of the engine we are abandoning.
                engine.dispose()
            raise

        self._engine = engine
        self._session_factory = session_factory
        logger.info("Database engine initialized successfully")

    @property
    def engine(self) -> Engine:
        """Return the database engine, initializing it on first access."""
        if self._engine is None:
            self.init_engine()
        return self._engine

    @property
    def session_factory(self) -> sessionmaker:
        """Return the session factory, initializing the engine on first access."""
        if self._session_factory is None:
            self.init_engine()
        return self._session_factory

    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """Context manager yielding a database session.

        The session is committed when the ``with`` body finishes normally,
        rolled back (and the exception re-raised) when it raises an
        ``Exception``, and always closed.

        Example:
            with db_manager.get_session() as session:
                result = session.query(User).all()
        """
        session = self.session_factory()
        try:
            yield session
            session.commit()
            logger.debug("Session committed successfully")
        except Exception as e:
            session.rollback()
            logger.error(f"Session rollback due to error: {e}")
            raise
        finally:
            session.close()
            logger.debug("Session closed")

    def execute_raw_sql(self, sql: str, params: Optional[dict] = None) -> list:
        """Execute a raw SQL statement inside a transactional session.

        Args:
            sql: SQL text. Pass values through ``params`` (``:name``
                placeholders) rather than formatting them into the string.
            params: bind parameters for the statement; ``None`` means none.

        Returns:
            One dict per result row (column name -> value), or an empty list
            for statements that return no rows (e.g. UPDATE/DELETE/DDL).
        """
        with self.get_session() as session:
            result = session.execute(text(sql), params or {})
            if not result.returns_rows:
                # Iterating a row-less result raises ResourceClosedError.
                return []
            return [dict(row._mapping) for row in result]

    def health_check(self) -> bool:
        """Run ``SELECT 1`` and report whether the database is reachable.

        Initialization or connection errors are logged and reported as
        ``False`` instead of being raised.
        """
        try:
            with self.engine.connect() as conn:
                conn.execute(text("SELECT 1"))
            logger.info("Database health check passed")
            return True
        except Exception as e:
            logger.error(f"Database health check failed: {e}")
            return False

    def close(self) -> None:
        """Dispose of the engine's connection pool and reset the manager."""
        if self._engine is not None:
            self._engine.dispose()
            self._engine = None
            self._session_factory = None
            logger.info("Database connections closed")


class AsyncDatabaseManager:
    """Database manager (asynchronous).

    Lazily builds one SQLAlchemy ``AsyncEngine`` plus an
    ``async_sessionmaker`` from the values in ``settings``.
    """

    def __init__(self):
        self._engine: Optional[AsyncEngine] = None
        self._async_session_factory: Optional[async_sessionmaker] = None

    async def init_engine(self) -> None:
        """Initialize the async database engine and session factory.

        Calling this again after a successful initialization only logs a
        warning. The engine and the session factory are stored together, and
        only after both were built, so a failure never leaves the manager
        half-initialized.

        Raises:
            Exception: whatever ``create_async_engine`` /
                ``async_sessionmaker`` raised; it is logged and re-raised.
        """
        if self._engine is not None:
            logger.warning("Async database engine already initialized")
            return

        engine: Optional[AsyncEngine] = None
        try:
            engine = create_async_engine(
                settings.ASYNC_DATABASE_URL,
                pool_size=settings.DATABASE_POOL_SIZE,
                max_overflow=settings.DATABASE_MAX_OVERFLOW,
                pool_recycle=settings.DATABASE_POOL_RECYCLE,
                echo=settings.DATABASE_ECHO,
                pool_pre_ping=True,
                # NOTE(review): "server_settings" is an asyncpg connect()
                # argument, so ASYNC_DATABASE_URL presumably uses asyncpg — confirm.
                connect_args={
                    "server_settings": {
                        "search_path": settings.DATABASE_SCHEMA
                    }
                },
            )
            session_factory = async_sessionmaker(
                engine,
                class_=AsyncSession,
                expire_on_commit=False,
                autoflush=False,
                autocommit=False,
            )
        except Exception as e:
            logger.error(f"Failed to initialize async database engine: {e}")
            if engine is not None:
                # Release the pool of the engine we are abandoning.
                await engine.dispose()
            raise

        self._engine = engine
        self._async_session_factory = session_factory
        logger.info("Async database engine initialized successfully")

    @property
    async def engine(self) -> AsyncEngine:
        """Return the async engine, initializing it on first access.

        This is an async property: use ``await async_db_manager.engine``.
        """
        if self._engine is None:
            await self.init_engine()
        return self._engine

    @property
    async def async_session_factory(self) -> async_sessionmaker:
        """Return the async session factory, initializing on first access.

        This is an async property: use
        ``await async_db_manager.async_session_factory``.
        """
        if self._async_session_factory is None:
            await self.init_engine()
        return self._async_session_factory

    @asynccontextmanager
    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Async context manager yielding a database session.

        The session is committed when the ``async with`` body finishes
        normally, rolled back (and the exception re-raised) when it raises an
        ``Exception``, and always closed.

        Example:
            async with async_db_manager.get_session() as session:
                result = await session.execute(query)
        """
        if self._async_session_factory is None:
            await self.init_engine()
        session = self._async_session_factory()
        try:
            yield session
            await session.commit()
            logger.debug("Async session committed successfully")
        except Exception as e:
            await session.rollback()
            logger.error(f"Async session rollback due to error: {e}")
            raise
        finally:
            await session.close()
            logger.debug("Async session closed")

    async def health_check(self) -> bool:
        """Run ``SELECT 1`` and report whether the database is reachable.

        The engine is initialized lazily, matching the synchronous manager,
        instead of failing on a not-yet-created engine. Initialization or
        connection errors are logged and reported as ``False``.
        """
        try:
            engine = await self.engine
            async with engine.connect() as conn:
                await conn.execute(text("SELECT 1"))
            logger.info("Async database health check passed")
            return True
        except Exception as e:
            logger.error(f"Async database health check failed: {e}")
            return False

    async def close(self) -> None:
        """Dispose of the async engine's connection pool and reset the manager."""
        if self._engine is not None:
            await self._engine.dispose()
            self._engine = None
            self._async_session_factory = None
            logger.info("Async database connections closed")


# Global database manager instances
db_manager = DatabaseManager()
async_db_manager = AsyncDatabaseManager()


def init_database() -> None:
    """Initialize the synchronous database engine."""
    db_manager.init_engine()


async def init_async_database() -> None:
    """Initialize the asynchronous database engine."""
    await async_db_manager.init_engine()


def get_db() -> Generator[Session, None, None]:
    """Dependency provider yielding a sync session (for FastAPI and similar frameworks)."""
    with db_manager.get_session() as session:
        yield session


async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    """Dependency provider yielding an async session (for FastAPI and similar frameworks)."""
    async with async_db_manager.get_session() as session:
        yield session