修改路径,从src放到根目录
This commit is contained in:
252
database/connection.py
Normal file
252
database/connection.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""
|
||||
数据库连接管理模块
|
||||
支持同步和异步两种模式
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Generator, AsyncGenerator
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
create_async_engine,
|
||||
async_sessionmaker
|
||||
)
|
||||
|
||||
from config import settings
|
||||
from utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""数据库管理器(同步)"""
|
||||
|
||||
def __init__(self):
|
||||
self._engine: Optional[Engine] = None
|
||||
self._session_factory: Optional[sessionmaker] = None
|
||||
|
||||
def init_engine(self) -> None:
|
||||
"""初始化数据库引擎"""
|
||||
if self._engine is not None:
|
||||
logger.warning("Database engine already initialized")
|
||||
return
|
||||
|
||||
try:
|
||||
self._engine = create_engine(
|
||||
settings.DATABASE_URL,
|
||||
pool_size=settings.DATABASE_POOL_SIZE,
|
||||
max_overflow=settings.DATABASE_MAX_OVERFLOW,
|
||||
pool_recycle=settings.DATABASE_POOL_RECYCLE,
|
||||
echo=settings.DATABASE_ECHO,
|
||||
pool_pre_ping=True, # 连接前进行ping检查
|
||||
connect_args={
|
||||
"options": f"-csearch_path={settings.DATABASE_SCHEMA}"
|
||||
}
|
||||
)
|
||||
|
||||
self._session_factory = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=self._engine,
|
||||
expire_on_commit=False # 避免延迟加载问题
|
||||
)
|
||||
|
||||
logger.info("Database engine initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize database engine: {e}")
|
||||
raise
|
||||
|
||||
@property
|
||||
def engine(self) -> Engine:
|
||||
"""获取数据库引擎"""
|
||||
if self._engine is None:
|
||||
self.init_engine()
|
||||
return self._engine
|
||||
|
||||
@property
|
||||
def session_factory(self) -> sessionmaker:
|
||||
"""获取会话工厂"""
|
||||
if self._session_factory is None:
|
||||
self.init_engine()
|
||||
return self._session_factory
|
||||
|
||||
@contextmanager
|
||||
def get_session(self) -> Generator[Session, None, None]:
|
||||
"""
|
||||
获取数据库会话的上下文管理器
|
||||
|
||||
使用示例:
|
||||
with db_manager.get_session() as session:
|
||||
result = session.query(User).all()
|
||||
"""
|
||||
session = self.session_factory()
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
logger.debug("Session committed successfully")
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"Session rollback due to error: {e}")
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
logger.debug("Session closed")
|
||||
|
||||
def execute_raw_sql(self, sql: str, params: Optional[dict] = None) -> list:
|
||||
"""执行原始SQL查询"""
|
||||
with self.get_session() as session:
|
||||
result = session.execute(text(sql), params or {})
|
||||
return [dict(row._mapping) for row in result]
|
||||
|
||||
def health_check(self) -> bool:
|
||||
"""数据库健康检查"""
|
||||
try:
|
||||
with self.engine.connect() as conn:
|
||||
conn.execute(text("SELECT 1"))
|
||||
logger.info("Database health check passed")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Database health check failed: {e}")
|
||||
return False
|
||||
|
||||
def close(self) -> None:
|
||||
"""关闭数据库连接"""
|
||||
if self._engine:
|
||||
self._engine.dispose()
|
||||
self._engine = None
|
||||
self._session_factory = None
|
||||
logger.info("Database connections closed")
|
||||
|
||||
|
||||
class AsyncDatabaseManager:
|
||||
"""异步数据库管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self._engine: Optional[AsyncEngine] = None
|
||||
self._async_session_factory: Optional[async_sessionmaker] = None
|
||||
|
||||
async def init_engine(self) -> None:
|
||||
"""初始化异步数据库引擎"""
|
||||
if self._engine is not None:
|
||||
logger.warning("Async database engine already initialized")
|
||||
return
|
||||
|
||||
try:
|
||||
self._engine = create_async_engine(
|
||||
settings.ASYNC_DATABASE_URL,
|
||||
pool_size=settings.DATABASE_POOL_SIZE,
|
||||
max_overflow=settings.DATABASE_MAX_OVERFLOW,
|
||||
pool_recycle=settings.DATABASE_POOL_RECYCLE,
|
||||
echo=settings.DATABASE_ECHO,
|
||||
pool_pre_ping=True,
|
||||
connect_args={
|
||||
"server_settings": {
|
||||
"search_path": settings.DATABASE_SCHEMA
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
self._async_session_factory = async_sessionmaker(
|
||||
self._engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False
|
||||
)
|
||||
|
||||
logger.info("Async database engine initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize async database engine: {e}")
|
||||
raise
|
||||
|
||||
@property
|
||||
async def engine(self) -> AsyncEngine:
|
||||
"""获取异步数据库引擎"""
|
||||
if self._engine is None:
|
||||
await self.init_engine()
|
||||
return self._engine
|
||||
|
||||
@property
|
||||
async def async_session_factory(self) -> async_sessionmaker:
|
||||
"""获取异步会话工厂"""
|
||||
if self._async_session_factory is None:
|
||||
await self.init_engine()
|
||||
return self._async_session_factory
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
获取异步数据库会话的上下文管理器
|
||||
|
||||
使用示例:
|
||||
async with async_db_manager.get_session() as session:
|
||||
result = await session.execute(query)
|
||||
"""
|
||||
if self._async_session_factory is None:
|
||||
await self.init_engine()
|
||||
|
||||
session = self._async_session_factory()
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
logger.debug("Async session committed successfully")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Async session rollback due to error: {e}")
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
logger.debug("Async session closed")
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""异步数据库健康检查"""
|
||||
try:
|
||||
async with self._engine.connect() as conn:
|
||||
await conn.execute(text("SELECT 1"))
|
||||
logger.info("Async database health check passed")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Async database health check failed: {e}")
|
||||
return False
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭异步数据库连接"""
|
||||
if self._engine:
|
||||
await self._engine.dispose()
|
||||
self._engine = None
|
||||
self._async_session_factory = None
|
||||
logger.info("Async database connections closed")
|
||||
|
||||
|
||||
# 创建全局数据库管理器实例
|
||||
db_manager = DatabaseManager()
|
||||
async_db_manager = AsyncDatabaseManager()
|
||||
|
||||
|
||||
def init_database() -> None:
|
||||
"""初始化数据库(同步)"""
|
||||
db_manager.init_engine()
|
||||
|
||||
|
||||
async def init_async_database() -> None:
|
||||
"""初始化异步数据库"""
|
||||
await async_db_manager.init_engine()
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
"""依赖注入:获取数据库会话(用于FastAPI等框架)"""
|
||||
with db_manager.get_session() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""依赖注入:获取异步数据库会话(用于FastAPI等框架)"""
|
||||
async with async_db_manager.get_session() as session:
|
||||
yield session
|
||||
Reference in New Issue
Block a user