Files
SupervisorAI/src/database/connection.py
2026-01-04 09:45:46 +08:00

252 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Database connection management module.
Supports both synchronous and asynchronous modes.
"""
import logging
from typing import Optional, Generator, AsyncGenerator
from contextlib import asynccontextmanager, contextmanager
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
create_async_engine,
async_sessionmaker
)
from config import settings
from utils.logger import setup_logger
# Module-wide logger, obtained from the project's setup_logger helper.
logger = setup_logger(__name__)
class DatabaseManager:
    """Database manager (synchronous).

    Holds one SQLAlchemy ``Engine`` and the ``sessionmaker`` bound to it.
    Both are created by :meth:`init_engine` (explicitly, or lazily on first
    use through the properties) and discarded by :meth:`close`.
    """

    def __init__(self):
        # Both attributes stay None until init_engine() has succeeded.
        self._engine: Optional[Engine] = None
        self._session_factory: Optional[sessionmaker] = None

    def init_engine(self) -> None:
        """Create the engine and the session factory from ``settings``.

        Calling this while an engine already exists only logs a warning.

        Raises:
            Exception: whatever engine or factory construction raised, after
                it has been logged. Nothing is stored on the manager in that
                case, so a later call can retry from a clean state.
        """
        if self._engine is not None:
            logger.warning("Database engine already initialized")
            return
        try:
            engine = create_engine(
                settings.DATABASE_URL,
                pool_size=settings.DATABASE_POOL_SIZE,
                max_overflow=settings.DATABASE_MAX_OVERFLOW,
                pool_recycle=settings.DATABASE_POOL_RECYCLE,
                echo=settings.DATABASE_ECHO,
                pool_pre_ping=True,  # ping the connection before handing it out
                connect_args={
                    # Point the connection's search_path at the configured schema.
                    "options": f"-csearch_path={settings.DATABASE_SCHEMA}"
                },
            )
            session_factory = sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=engine,
                expire_on_commit=False,  # avoid lazy-load problems after commit
            )
            # Publish both together: a failure above must not leave an engine
            # without a factory, which the early return would never repair.
            self._engine = engine
            self._session_factory = session_factory
            logger.info("Database engine initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize database engine: {e}")
            raise

    @property
    def engine(self) -> Engine:
        """Return the engine, initializing it on first access."""
        if self._engine is None:
            self.init_engine()
        return self._engine

    @property
    def session_factory(self) -> sessionmaker:
        """Return the session factory, initializing the engine on first access."""
        if self._session_factory is None:
            self.init_engine()
        return self._session_factory

    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """Context manager that yields a database session.

        The session is committed when the ``with`` body finishes normally.
        If the body raises, the session is rolled back, the error is logged
        and the exception is re-raised. The session is closed in every case.

        Example:
            with db_manager.get_session() as session:
                result = session.query(User).all()
        """
        session = self.session_factory()
        try:
            yield session
            session.commit()
            logger.debug("Session committed successfully")
        except Exception as e:
            session.rollback()
            logger.error(f"Session rollback due to error: {e}")
            raise
        finally:
            session.close()
            logger.debug("Session closed")

    def execute_raw_sql(self, sql: str, params: Optional[dict] = None) -> list:
        """Execute a raw SQL statement inside a managed session.

        Args:
            sql: statement text; it is wrapped with ``text()`` before execution.
            params: values for the statement's bind parameters. ``None`` is
                treated as "no parameters".

        Returns:
            One ``dict`` per result row. Statements that produce no rows
            (for example an UPDATE without RETURNING) yield an empty list.
        """
        with self.get_session() as session:
            result = session.execute(text(sql), params or {})
            # Iterating a result that returns no rows raises, which would roll
            # back the statement that just ran. getattr keeps this safe for
            # result types that do not expose ``returns_rows``.
            if not getattr(result, "returns_rows", True):
                return []
            return [dict(row._mapping) for row in result]

    def health_check(self) -> bool:
        """Run ``SELECT 1`` against the database.

        Returns:
            True when the query succeeds; False (after logging the error)
            when connecting or executing fails.
        """
        try:
            with self.engine.connect() as conn:
                conn.execute(text("SELECT 1"))
            logger.info("Database health check passed")
            return True
        except Exception as e:
            logger.error(f"Database health check failed: {e}")
            return False

    def close(self) -> None:
        """Dispose of the engine and forget the session factory.

        Does nothing when no engine has been created. Afterwards the manager
        can be initialized again.
        """
        if self._engine is not None:
            self._engine.dispose()
            self._engine = None
            self._session_factory = None
            logger.info("Database connections closed")
class AsyncDatabaseManager:
    """Asynchronous database manager.

    Holds one SQLAlchemy ``AsyncEngine`` and the ``async_sessionmaker`` bound
    to it. Both are created by :meth:`init_engine` (explicitly, or lazily on
    first use) and discarded by :meth:`close`.
    """

    def __init__(self):
        # Both attributes stay None until init_engine() has succeeded.
        self._engine: Optional[AsyncEngine] = None
        self._async_session_factory: Optional[async_sessionmaker] = None

    async def init_engine(self) -> None:
        """Create the async engine and session factory from ``settings``.

        Calling this while an engine already exists only logs a warning.

        Raises:
            Exception: whatever engine or factory construction raised, after
                it has been logged. Nothing is stored on the manager in that
                case, so a later call can retry from a clean state.
        """
        if self._engine is not None:
            logger.warning("Async database engine already initialized")
            return
        try:
            engine = create_async_engine(
                settings.ASYNC_DATABASE_URL,
                pool_size=settings.DATABASE_POOL_SIZE,
                max_overflow=settings.DATABASE_MAX_OVERFLOW,
                pool_recycle=settings.DATABASE_POOL_RECYCLE,
                echo=settings.DATABASE_ECHO,
                pool_pre_ping=True,
                connect_args={
                    # Point the connection's search_path at the configured schema.
                    "server_settings": {
                        "search_path": settings.DATABASE_SCHEMA
                    }
                },
            )
            session_factory = async_sessionmaker(
                engine,
                class_=AsyncSession,
                expire_on_commit=False,
                autoflush=False,
                autocommit=False,
            )
            # Publish both together: a failure above must not leave an engine
            # without a factory, which the early return would never repair.
            self._engine = engine
            self._async_session_factory = session_factory
            logger.info("Async database engine initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize async database engine: {e}")
            raise

    @property
    async def engine(self) -> AsyncEngine:
        """Return the async engine, initializing it on first access.

        This is an ``async`` property: attribute access yields a coroutine,
        so callers must write ``await manager.engine``.
        """
        if self._engine is None:
            await self.init_engine()
        return self._engine

    @property
    async def async_session_factory(self) -> async_sessionmaker:
        """Return the async session factory, initializing on first access.

        This is an ``async`` property: attribute access yields a coroutine,
        so callers must write ``await manager.async_session_factory``.
        """
        if self._async_session_factory is None:
            await self.init_engine()
        return self._async_session_factory

    @asynccontextmanager
    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Async context manager that yields a database session.

        The session is committed when the ``async with`` body finishes
        normally. If the body raises, the session is rolled back, the error
        is logged and the exception is re-raised. The session is closed in
        every case.

        Example:
            async with async_db_manager.get_session() as session:
                result = await session.execute(query)
        """
        if self._async_session_factory is None:
            await self.init_engine()
        session = self._async_session_factory()
        try:
            yield session
            await session.commit()
            logger.debug("Async session committed successfully")
        except Exception as e:
            await session.rollback()
            logger.error(f"Async session rollback due to error: {e}")
            raise
        finally:
            await session.close()
            logger.debug("Async session closed")

    async def health_check(self) -> bool:
        """Run ``SELECT 1`` against the database.

        The engine is created on demand, so the check also works on a manager
        that has not been initialized yet.

        Returns:
            True when the query succeeds; False (after logging the error)
            when initialization, connecting or executing fails.
        """
        try:
            # Without this, an uninitialized manager would fail on
            # ``None.connect`` and be reported as an unhealthy database.
            if self._engine is None:
                await self.init_engine()
            async with self._engine.connect() as conn:
                await conn.execute(text("SELECT 1"))
            logger.info("Async database health check passed")
            return True
        except Exception as e:
            logger.error(f"Async database health check failed: {e}")
            return False

    async def close(self) -> None:
        """Dispose of the async engine and forget the session factory.

        Does nothing when no engine has been created. Afterwards the manager
        can be initialized again.
        """
        if self._engine is not None:
            await self._engine.dispose()
            self._engine = None
            self._async_session_factory = None
            logger.info("Async database connections closed")
# Create the global database manager instances
db_manager = DatabaseManager()
async_db_manager = AsyncDatabaseManager()
def init_database() -> None:
    """Initialize the synchronous engine of the global ``db_manager``."""
    db_manager.init_engine()
async def init_async_database() -> None:
    """Initialize the asynchronous engine of the global ``async_db_manager``."""
    await async_db_manager.init_engine()
def get_db() -> Generator[Session, None, None]:
    """Yield a managed session for dependency injection (FastAPI and similar).

    The session comes from ``db_manager.get_session()``, which finalizes it
    once the consumer is done with the generator.
    """
    with db_manager.get_session() as db_session:
        yield db_session
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a managed async session for dependency injection (FastAPI and similar).

    The session comes from ``async_db_manager.get_session()``, which
    finalizes it once the consumer is done with the generator.
    """
    async with async_db_manager.get_session() as db_session:
        yield db_session