修改路径,从src放到根目录

This commit is contained in:
zqc
2026-01-08 10:32:36 +08:00
parent 96589ebdbd
commit f86effd63c
37 changed files with 51 additions and 410 deletions

252
database/connection.py Normal file
View File

@@ -0,0 +1,252 @@
"""
数据库连接管理模块
支持同步和异步两种模式
"""
import logging
from typing import Optional, Generator, AsyncGenerator
from contextlib import asynccontextmanager, contextmanager
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
create_async_engine,
async_sessionmaker
)
from config import settings
from utils.logger import setup_logger
logger = setup_logger(__name__)
class DatabaseManager:
"""数据库管理器(同步)"""
def __init__(self):
self._engine: Optional[Engine] = None
self._session_factory: Optional[sessionmaker] = None
def init_engine(self) -> None:
"""初始化数据库引擎"""
if self._engine is not None:
logger.warning("Database engine already initialized")
return
try:
self._engine = create_engine(
settings.DATABASE_URL,
pool_size=settings.DATABASE_POOL_SIZE,
max_overflow=settings.DATABASE_MAX_OVERFLOW,
pool_recycle=settings.DATABASE_POOL_RECYCLE,
echo=settings.DATABASE_ECHO,
pool_pre_ping=True, # 连接前进行ping检查
connect_args={
"options": f"-csearch_path={settings.DATABASE_SCHEMA}"
}
)
self._session_factory = sessionmaker(
autocommit=False,
autoflush=False,
bind=self._engine,
expire_on_commit=False # 避免延迟加载问题
)
logger.info("Database engine initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize database engine: {e}")
raise
@property
def engine(self) -> Engine:
"""获取数据库引擎"""
if self._engine is None:
self.init_engine()
return self._engine
@property
def session_factory(self) -> sessionmaker:
"""获取会话工厂"""
if self._session_factory is None:
self.init_engine()
return self._session_factory
@contextmanager
def get_session(self) -> Generator[Session, None, None]:
"""
获取数据库会话的上下文管理器
使用示例:
with db_manager.get_session() as session:
result = session.query(User).all()
"""
session = self.session_factory()
try:
yield session
session.commit()
logger.debug("Session committed successfully")
except Exception as e:
session.rollback()
logger.error(f"Session rollback due to error: {e}")
raise
finally:
session.close()
logger.debug("Session closed")
def execute_raw_sql(self, sql: str, params: Optional[dict] = None) -> list:
"""执行原始SQL查询"""
with self.get_session() as session:
result = session.execute(text(sql), params or {})
return [dict(row._mapping) for row in result]
def health_check(self) -> bool:
"""数据库健康检查"""
try:
with self.engine.connect() as conn:
conn.execute(text("SELECT 1"))
logger.info("Database health check passed")
return True
except Exception as e:
logger.error(f"Database health check failed: {e}")
return False
def close(self) -> None:
"""关闭数据库连接"""
if self._engine:
self._engine.dispose()
self._engine = None
self._session_factory = None
logger.info("Database connections closed")
class AsyncDatabaseManager:
"""异步数据库管理器"""
def __init__(self):
self._engine: Optional[AsyncEngine] = None
self._async_session_factory: Optional[async_sessionmaker] = None
async def init_engine(self) -> None:
"""初始化异步数据库引擎"""
if self._engine is not None:
logger.warning("Async database engine already initialized")
return
try:
self._engine = create_async_engine(
settings.ASYNC_DATABASE_URL,
pool_size=settings.DATABASE_POOL_SIZE,
max_overflow=settings.DATABASE_MAX_OVERFLOW,
pool_recycle=settings.DATABASE_POOL_RECYCLE,
echo=settings.DATABASE_ECHO,
pool_pre_ping=True,
connect_args={
"server_settings": {
"search_path": settings.DATABASE_SCHEMA
}
}
)
self._async_session_factory = async_sessionmaker(
self._engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False
)
logger.info("Async database engine initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize async database engine: {e}")
raise
@property
async def engine(self) -> AsyncEngine:
"""获取异步数据库引擎"""
if self._engine is None:
await self.init_engine()
return self._engine
@property
async def async_session_factory(self) -> async_sessionmaker:
"""获取异步会话工厂"""
if self._async_session_factory is None:
await self.init_engine()
return self._async_session_factory
@asynccontextmanager
async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
"""
获取异步数据库会话的上下文管理器
使用示例:
async with async_db_manager.get_session() as session:
result = await session.execute(query)
"""
if self._async_session_factory is None:
await self.init_engine()
session = self._async_session_factory()
try:
yield session
await session.commit()
logger.debug("Async session committed successfully")
except Exception as e:
await session.rollback()
logger.error(f"Async session rollback due to error: {e}")
raise
finally:
await session.close()
logger.debug("Async session closed")
async def health_check(self) -> bool:
"""异步数据库健康检查"""
try:
async with self._engine.connect() as conn:
await conn.execute(text("SELECT 1"))
logger.info("Async database health check passed")
return True
except Exception as e:
logger.error(f"Async database health check failed: {e}")
return False
async def close(self) -> None:
"""关闭异步数据库连接"""
if self._engine:
await self._engine.dispose()
self._engine = None
self._async_session_factory = None
logger.info("Async database connections closed")
# 创建全局数据库管理器实例
db_manager = DatabaseManager()
async_db_manager = AsyncDatabaseManager()
def init_database() -> None:
"""初始化数据库(同步)"""
db_manager.init_engine()
async def init_async_database() -> None:
"""初始化异步数据库"""
await async_db_manager.init_engine()
def get_db() -> Generator[Session, None, None]:
"""依赖注入获取数据库会话用于FastAPI等框架"""
with db_manager.get_session() as session:
yield session
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
"""依赖注入获取异步数据库会话用于FastAPI等框架"""
async with async_db_manager.get_session() as session:
yield session