252 lines
7.9 KiB
Python
252 lines
7.9 KiB
Python
"""
|
||
数据库连接管理模块
|
||
支持同步和异步两种模式
|
||
"""
|
||
|
||
import logging
|
||
from typing import Optional, Generator, AsyncGenerator
|
||
from contextlib import asynccontextmanager, contextmanager
|
||
|
||
from sqlalchemy import create_engine, text
|
||
from sqlalchemy.engine import Engine
|
||
from sqlalchemy.orm import sessionmaker, Session
|
||
from sqlalchemy.ext.asyncio import (
|
||
AsyncEngine,
|
||
AsyncSession,
|
||
create_async_engine,
|
||
async_sessionmaker
|
||
)
|
||
|
||
from config import settings
|
||
from utils.logger import setup_logger
|
||
|
||
# Module-wide logger used by both database managers and helpers in this file.
logger = setup_logger(__name__)
|
||
|
||
|
||
class DatabaseManager:
    """Synchronous database manager.

    Owns one SQLAlchemy ``Engine`` and a ``sessionmaker`` bound to it. Both are
    built from ``settings`` lazily on first use (or eagerly via
    :meth:`init_engine`) and torn down by :meth:`close`.
    """

    def __init__(self):
        # Both stay None until init_engine() succeeds, and again after close().
        self._engine: Optional[Engine] = None
        self._session_factory: Optional[sessionmaker] = None

    def init_engine(self) -> None:
        """Create the engine and session factory from ``settings``.

        If an engine already exists, this only logs a warning and returns.

        Raises:
            Exception: any error raised while building the engine or the
                session factory is logged and re-raised; the manager is then
                left uninitialised.
        """
        if self._engine is not None:
            logger.warning("Database engine already initialized")
            return

        try:
            engine = create_engine(
                settings.DATABASE_URL,
                pool_size=settings.DATABASE_POOL_SIZE,
                max_overflow=settings.DATABASE_MAX_OVERFLOW,
                pool_recycle=settings.DATABASE_POOL_RECYCLE,
                echo=settings.DATABASE_ECHO,
                pool_pre_ping=True,  # test pooled connections before handing them out
                connect_args={
                    # Forwarded to the DBAPI connect(); "-csearch_path=..." is a
                    # libpq startup option (presumably psycopg — confirm).
                    "options": f"-csearch_path={settings.DATABASE_SCHEMA}"
                },
            )
            session_factory = sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=engine,
                expire_on_commit=False,  # keep loaded attributes usable after commit
            )
        except Exception as e:
            logger.error(f"Failed to initialize database engine: {e}")
            raise

        # Publish both together: assigning the engine first could leave an
        # engine with no session factory, making `session_factory` return None.
        self._engine = engine
        self._session_factory = session_factory
        logger.info("Database engine initialized successfully")

    @property
    def engine(self) -> Engine:
        """Return the engine, initialising it on first access."""
        if self._engine is None:
            self.init_engine()
        return self._engine

    @property
    def session_factory(self) -> sessionmaker:
        """Return the session factory, initialising it on first access."""
        if self._session_factory is None:
            self.init_engine()
        return self._session_factory

    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """Context manager yielding a session wrapped in a transaction.

        The session is committed if the ``with`` body completes normally, rolled
        back (and the exception re-raised) if the body raises an ``Exception``,
        and closed in every case.

        Example::

            with db_manager.get_session() as session:
                result = session.query(User).all()
        """
        session = self.session_factory()
        try:
            yield session
            session.commit()
            logger.debug("Session committed successfully")
        except Exception as e:
            session.rollback()
            logger.error(f"Session rollback due to error: {e}")
            raise
        finally:
            session.close()
            logger.debug("Session closed")

    def execute_raw_sql(self, sql: str, params: Optional[dict] = None) -> list:
        """Execute one raw SQL statement in its own committed transaction.

        Args:
            sql: SQL text. Pass values through ``:name`` placeholders and
                ``params`` rather than formatting them into the string.
            params: values for the bound placeholders; ``None`` means none.

        Returns:
            One ``dict`` (column name -> value) per result row. Statements that
            produce no rows (e.g. INSERT/UPDATE/DELETE without RETURNING)
            return an empty list.
        """
        with self.get_session() as session:
            result = session.execute(text(sql), params or {})
            # Iterating a row-less result raises ResourceClosedError, which
            # would make get_session roll back an otherwise successful write.
            if not result.returns_rows:
                return []
            return [dict(row) for row in result.mappings()]

    def health_check(self) -> bool:
        """Run ``SELECT 1`` on a fresh connection.

        Initialises the engine if needed. Returns True on success; on any
        error it logs the failure and returns False instead of raising.
        """
        try:
            with self.engine.connect() as conn:
                conn.execute(text("SELECT 1"))
            logger.info("Database health check passed")
            return True
        except Exception as e:
            logger.error(f"Database health check failed: {e}")
            return False

    def close(self) -> None:
        """Dispose the connection pool and reset the manager.

        Does nothing if no engine has been created.
        """
        if self._engine is not None:
            self._engine.dispose()
            self._engine = None
            self._session_factory = None
            logger.info("Database connections closed")
|
||
|
||
|
||
class AsyncDatabaseManager:
    """Asynchronous database manager built on SQLAlchemy's asyncio extension.

    Owns one ``AsyncEngine`` and an ``async_sessionmaker`` bound to it. Both
    are built from ``settings`` lazily on first use (or eagerly via
    :meth:`init_engine`) and released by :meth:`close`.
    """

    def __init__(self):
        # Both stay None until init_engine() succeeds, and again after close().
        self._engine: Optional[AsyncEngine] = None
        self._async_session_factory: Optional[async_sessionmaker] = None

    async def init_engine(self) -> None:
        """Create the async engine and session factory from ``settings``.

        If an engine already exists, this only logs a warning and returns.

        Raises:
            Exception: any error raised while building the engine or the
                session factory is logged and re-raised; the manager is then
                left uninitialised.
        """
        if self._engine is not None:
            logger.warning("Async database engine already initialized")
            return

        try:
            engine = create_async_engine(
                settings.ASYNC_DATABASE_URL,
                pool_size=settings.DATABASE_POOL_SIZE,
                max_overflow=settings.DATABASE_MAX_OVERFLOW,
                pool_recycle=settings.DATABASE_POOL_RECYCLE,
                echo=settings.DATABASE_ECHO,
                pool_pre_ping=True,
                connect_args={
                    # Forwarded to the async driver's connect(); "server_settings"
                    # is presumably asyncpg's option — confirm against the URL.
                    "server_settings": {
                        "search_path": settings.DATABASE_SCHEMA
                    }
                },
            )
            session_factory = async_sessionmaker(
                engine,
                class_=AsyncSession,
                expire_on_commit=False,
                autoflush=False,
                autocommit=False,
            )
        except Exception as e:
            logger.error(f"Failed to initialize async database engine: {e}")
            raise

        # Publish both together so a failure above never leaves an engine
        # without a session factory. There is no await between the None check
        # and here, so callers on one event loop cannot interleave.
        self._engine = engine
        self._async_session_factory = session_factory
        logger.info("Async database engine initialized successfully")

    @property
    async def engine(self) -> AsyncEngine:
        """Return the engine, initialising it on first access.

        This is an async property: use ``await manager.engine``.
        """
        if self._engine is None:
            await self.init_engine()
        return self._engine

    @property
    async def async_session_factory(self) -> async_sessionmaker:
        """Return the session factory, initialising it on first access.

        This is an async property: use ``await manager.async_session_factory``.
        """
        if self._async_session_factory is None:
            await self.init_engine()
        return self._async_session_factory

    @asynccontextmanager
    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Async context manager yielding a session wrapped in a transaction.

        The session is committed if the body completes normally, rolled back
        (and the exception re-raised) if the body raises an ``Exception``, and
        closed in every case.

        Example::

            async with async_db_manager.get_session() as session:
                result = await session.execute(query)
        """
        session_factory = await self.async_session_factory
        session = session_factory()
        try:
            yield session
            await session.commit()
            logger.debug("Async session committed successfully")
        except Exception as e:
            await session.rollback()
            logger.error(f"Async session rollback due to error: {e}")
            raise
        finally:
            await session.close()
            logger.debug("Async session closed")

    async def health_check(self) -> bool:
        """Run ``SELECT 1`` on a fresh async connection.

        Initialises the engine if needed. Returns True on success; on any
        error it logs the failure and returns False instead of raising.
        """
        try:
            # Go through the lazy property: reading self._engine directly made
            # the check always fail (None.connect) before init_engine() ran.
            engine = await self.engine
            async with engine.connect() as conn:
                await conn.execute(text("SELECT 1"))
            logger.info("Async database health check passed")
            return True
        except Exception as e:
            logger.error(f"Async database health check failed: {e}")
            return False

    async def close(self) -> None:
        """Dispose the async connection pool and reset the manager.

        Does nothing if no engine has been created.
        """
        if self._engine is not None:
            await self._engine.dispose()
            self._engine = None
            self._async_session_factory = None
            logger.info("Async database connections closed")
|
||
|
||
|
||
# Process-wide singleton managers: the sync one backs init_database/get_db,
# the async one backs init_async_database/get_async_db.
db_manager, async_db_manager = DatabaseManager(), AsyncDatabaseManager()
|
||
|
||
|
||
def init_database() -> None:
    """Eagerly create the engine of the shared synchronous ``db_manager``.

    Calling it again after a successful initialisation only logs a warning.
    """
    manager = db_manager
    manager.init_engine()
|
||
|
||
|
||
async def init_async_database() -> None:
    """Eagerly create the engine of the shared ``async_db_manager``.

    Calling it again after a successful initialisation only logs a warning.
    """
    manager = async_db_manager
    await manager.init_engine()
|
||
|
||
|
||
def get_db() -> Generator[Session, None, None]:
    """Dependency-injection generator (e.g. for FastAPI) yielding one session.

    Resuming the generator after use commits; an exception thrown into it
    rolls back and propagates; the session is closed in every case.
    """
    session_cm = db_manager.get_session()
    with session_cm as db_session:
        yield db_session
|
||
|
||
|
||
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    """Async dependency-injection generator (e.g. for FastAPI) yielding one session.

    Resuming the generator after use commits; an exception thrown into it
    rolls back and propagates; the session is closed in every case.
    """
    session_cm = async_db_manager.get_session()
    async with session_cm as db_session:
        yield db_session