good version for web

This commit is contained in:
2026-02-18 09:36:18 +08:00
parent 62ea5d36a5
commit 72ab0c0b56
42 changed files with 1305 additions and 1515 deletions

View File

@@ -5,13 +5,14 @@
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import Optional
from app.models.database import get_db from app.models.database import get_db
from app.services.user import UserService from app.services.user import UserService
from app.schemas.user import UserResponse from app.schemas.user import UserResponse
# OAuth2密码Bearer # OAuth2密码Bearer
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/users/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/users/login", auto_error=False)
async def get_current_active_user(db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)): async def get_current_active_user(db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)):
@@ -25,4 +26,19 @@ async def get_current_active_user(db: Session = Depends(get_db), token: str = De
) )
if hasattr(user, 'status') and user.status != "active": if hasattr(user, 'status') and user.status != "active":
raise HTTPException(status_code=400, detail="Inactive user") raise HTTPException(status_code=400, detail="Inactive user")
return user return user
async def get_current_active_user_optional(db: Session = Depends(get_db), token: Optional[str] = Depends(oauth2_scheme)):
"""获取当前活跃用户可选未登录返回None"""
if not token:
return None
try:
user = UserService.get_current_user(db, token)
if not user:
return None
if hasattr(user, 'status') and user.status != "active":
return None
return user
except:
return None

Binary file not shown.

View File

@@ -237,6 +237,21 @@ async def delete_version(
# 算法调用相关路由 # 算法调用相关路由
@router.post("/call/public", response_model=AlgorithmCallResult)
async def call_algorithm_public(
call: AlgorithmCallCreate,
db: Session = Depends(get_db)
):
"""公开调用算法(不需要认证,用于演示页面)"""
# 使用匿名用户ID进行调用
anonymous_user_id = "anonymous"
# 执行算法
result = AlgorithmCallService.execute_algorithm(db, anonymous_user_id, call)
return result
@router.post("/call", response_model=AlgorithmCallResult) @router.post("/call", response_model=AlgorithmCallResult)
async def call_algorithm( async def call_algorithm(
call: AlgorithmCallCreate, call: AlgorithmCallCreate,

View File

@@ -1,6 +1,6 @@
"""数据管理路由,提供输入数据、输出结果和元数据的管理功能""" """数据管理路由,提供输入数据、输出结果和元数据的管理功能"""
from fastapi import APIRouter, HTTPException, status, Depends, UploadFile, File from fastapi import APIRouter, HTTPException, status, Depends, UploadFile, File, Form
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -8,7 +8,7 @@ import json
from app.services.data_manager import data_manager from app.services.data_manager import data_manager
from app.models.database import get_db from app.models.database import get_db
from app.dependencies import get_current_active_user from app.dependencies import get_current_active_user, get_current_active_user_optional
router = APIRouter(prefix="/data", tags=["data-management"]) router = APIRouter(prefix="/data", tags=["data-management"])
@@ -176,19 +176,18 @@ async def get_user_outputs(
@router.post("/media/upload") @router.post("/media/upload")
async def upload_media_file( async def upload_media_file(
file: UploadFile = File(...), file: UploadFile = File(...),
algorithm_id: str = None, algorithm_id: str = Form(...)
current_user: dict = Depends(get_current_active_user)
): ):
"""上传媒体文件(如图片、视频等)""" """上传媒体文件(如图片、视频等)- 公开API不需要认证"""
if not algorithm_id: if not algorithm_id:
raise HTTPException(status_code=400, detail="algorithm_id is required") raise HTTPException(status_code=400, detail="algorithm_id is required")
# 读取文件内容 # 读取文件内容
file_content = await file.read() file_content = await file.read()
# 保存到数据管理器 # 保存到数据管理器(使用匿名用户)
file_path = data_manager.save_media_file( file_path = data_manager.save_media_file(
user_id=current_user.get("id"), user_id="anonymous",
algorithm_id=algorithm_id, algorithm_id=algorithm_id,
file_content=file_content, file_content=file_content,
file_name=file.filename file_name=file.filename
@@ -209,10 +208,28 @@ async def upload_media_file(
@router.get("/media/{file_path:path}") @router.get("/media/{file_path:path}")
async def get_media_file( async def get_media_file(
file_path: str, file_path: str,
current_user: dict = Depends(get_current_active_user) current_user: Optional[dict] = Depends(get_current_active_user_optional)
): ):
"""获取媒体文件""" """获取媒体文件"""
# 检查用户权限 - 确保用户只能访问自己的文件或公共文件 # results 目录下的文件公开访问
if file_path.startswith("results/"):
content = data_manager.get_media_file(file_path)
if content:
# 从完整路径中提取文件名来获取正确的MIME类型
import mimetypes
filename = file_path.split('/')[-1]
content_type, _ = mimetypes.guess_type(filename)
if content_type is None:
content_type = "application/octet-stream"
from fastapi.responses import Response
return Response(content=content, media_type=content_type)
else:
raise HTTPException(status_code=404, detail="Media file not found")
# 其他文件需要用户权限
if current_user is None:
raise HTTPException(status_code=401, detail="Not authenticated")
if current_user.get("role") != "admin" and not file_path.startswith(f"media/{current_user.get('id')}/"): if current_user.get("role") != "admin" and not file_path.startswith(f"media/{current_user.get('id')}/"):
raise HTTPException(status_code=403, detail="Insufficient permissions") raise HTTPException(status_code=403, detail="Insufficient permissions")

View File

@@ -132,7 +132,7 @@ class BatchOperationResponse(BaseModel):
# 初始化服务组件 # 初始化服务组件
project_analyzer = ProjectAnalyzer() project_analyzer = ProjectAnalyzer()
service_generator = ServiceGenerator() service_generator = ServiceGenerator()
service_orchestrator = ServiceOrchestrator(deployment_mode=settings.DEPLOYMENT_MODE) service_orchestrator = ServiceOrchestrator(deployment_mode=settings.DEPLOYMENT_MODE, db_url=settings.DATABASE_URL)
@router.post("/register", status_code=status.HTTP_201_CREATED) @router.post("/register", status_code=status.HTTP_201_CREATED)

View File

@@ -4,6 +4,10 @@ from sqlalchemy.orm import Session
import uuid import uuid
import requests import requests
import time import time
import logging
import os
logger = logging.getLogger(__name__)
from app.models.models import Algorithm, AlgorithmVersion, AlgorithmCall from app.models.models import Algorithm, AlgorithmVersion, AlgorithmCall
from app.schemas.algorithm import AlgorithmCreate, AlgorithmUpdate, AlgorithmVersionCreate, AlgorithmVersionUpdate, AlgorithmCallCreate from app.schemas.algorithm import AlgorithmCreate, AlgorithmUpdate, AlgorithmVersionCreate, AlgorithmVersionUpdate, AlgorithmCallCreate
@@ -97,14 +101,39 @@ class AlgorithmService:
@staticmethod @staticmethod
def get_algorithms(db: Session, skip: int = 0, limit: int = 100, algorithm_type: Optional[str] = None) -> List[Algorithm]: def get_algorithms(db: Session, skip: int = 0, limit: int = 100, algorithm_type: Optional[str] = None) -> List[Algorithm]:
"""获取算法列表""" """获取算法列表,优先显示已注册的服务"""
from app.models.models import AlgorithmService as Service
# 获取所有已注册的服务
services = db.query(Service).all()
# 获取所有算法
query = db.query(Algorithm) query = db.query(Algorithm)
# 如果指定了算法类型,进行过滤 # 如果指定了算法类型,进行过滤
if algorithm_type: if algorithm_type:
query = query.filter(Algorithm.type == algorithm_type) query = query.filter(Algorithm.type == algorithm_type)
return query.offset(skip).limit(limit).all() algorithms = query.all()
# 创建服务名称集合,用于快速查找
service_names = {service.name for service in services}
# 将算法分为两类:已注册的服务和普通算法
registered_algorithms = []
normal_algorithms = []
for algo in algorithms:
if algo.name in service_names:
registered_algorithms.append(algo)
else:
normal_algorithms.append(algo)
# 合并结果:已注册的服务在前,普通算法在后
merged_algorithms = registered_algorithms + normal_algorithms
# 应用分页
return merged_algorithms[skip:skip+limit]
@staticmethod @staticmethod
def update_algorithm(db: Session, algorithm_id: str, algorithm_update: AlgorithmUpdate) -> Optional[Algorithm]: def update_algorithm(db: Session, algorithm_id: str, algorithm_update: AlgorithmUpdate) -> Optional[Algorithm]:
@@ -358,14 +387,37 @@ class AlgorithmCallService:
# 记录开始时间 # 记录开始时间
start_time = time.time() start_time = time.time()
# 处理视频路径 - 如果是MinIO路径下载到本地
from app.utils.file import file_storage
video_path = processed_input_data.get('video', '')
if video_path and video_path.startswith('media/'):
# 从MinIO下载视频到本地
video_content = file_storage.get_object(video_path)
if video_content:
# 保存到临时文件
import tempfile
import uuid
suffix = '.' + video_path.split('.')[-1] if '.' in video_path else '.mp4'
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
tmp_file.write(video_content)
local_video_path = tmp_file.name
# 使用本地路径
processed_input_data['video'] = local_video_path
else:
db_call.status = "failed"
db_call.error_message = "无法下载视频文件"
db.commit()
return db_call
# 调用算法API # 调用算法API
print(f"[DEBUG] 调用算法API: {version.url}, 输入: {processed_input_data}")
response = requests.post( response = requests.post(
version.url, version.url,
json={ json={
"input_data": processed_input_data, "input_data": processed_input_data,
"params": call.params "params": call.params
}, },
timeout=30 timeout=120
) )
# 计算响应时间 # 计算响应时间
@@ -374,7 +426,35 @@ class AlgorithmCallService:
# 处理响应 # 处理响应
if response.status_code == 200: if response.status_code == 200:
output_data = response.json() output_data = response.json()
print(f"[DEBUG] 算法响应: {output_data}")
# 处理算法返回的视频文件
result = output_data.get('result', {})
if result and isinstance(result, dict):
# 如果返回了本地视频路径,需要处理
if 'video' in result and result['video'] and result['video'].startswith('/'):
# 这是本地路径需要转换为可访问的URL
local_video = result['video']
if os.path.exists(local_video):
# 上传到MinIO并替换路径
with open(local_video, 'rb') as f:
video_content = f.read()
video_filename = f"results/{user_id}/{uuid.uuid4().hex[:12]}.mp4"
from app.utils.file import file_storage
success = file_storage.upload_from_bytes(video_content, video_filename)
if success:
result['video'] = video_filename
result['video_url'] = f"/api/v1/data/media/{video_filename}"
# 删除临时文件
try:
os.remove(local_video)
except:
pass
db_call.status = "success" db_call.status = "success"
print(f"[DEBUG] 保存output_data: {output_data}")
db_call.output_data = output_data db_call.output_data = output_data
db_call.response_time = response_time db_call.response_time = response_time
else: else:

View File

@@ -235,7 +235,7 @@ class DataManager:
def get_media_file(self, file_path: str) -> Optional[bytes]: def get_media_file(self, file_path: str) -> Optional[bytes]:
"""获取媒体文件内容""" """获取媒体文件内容"""
try: try:
content = file_storage.download_file(file_path) content = file_storage.get_object(file_path)
return content return content
except Exception as e: except Exception as e:
logger.error(f"Error getting media file: {str(e)}") logger.error(f"Error getting media file: {str(e)}")

View File

@@ -11,18 +11,41 @@ import psutil
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
from docker.errors import DockerException, NotFound from docker.errors import DockerException, NotFound
# 数据库相关导入
try:
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
DATABASE_AVAILABLE = True
except ImportError:
DATABASE_AVAILABLE = False
class ServiceOrchestrator: class ServiceOrchestrator:
"""服务编排服务""" """服务编排服务"""
def __init__(self, deployment_mode="local"): def __init__(self, deployment_mode="local", db_url=None):
"""初始化服务编排器 """初始化服务编排器
Args: Args:
deployment_mode: 部署模式,支持"docker""local" deployment_mode: 部署模式,支持"docker""local"
db_url: 数据库连接URL用于重新加载服务信息
""" """
self.deployment_mode = deployment_mode self.deployment_mode = deployment_mode
self.processes = {} # 存储本地进程信息 self.processes = {} # 存储本地进程信息
self.db_url = db_url
self.db_engine = None
self.db_session = None
# 初始化数据库连接
if db_url and DATABASE_AVAILABLE:
try:
self.db_engine = create_engine(db_url)
self.db_session = sessionmaker(bind=self.db_engine)()
print("数据库连接成功")
except Exception as e:
print(f"数据库连接失败: {e}")
self.db_engine = None
self.db_session = None
if deployment_mode == "docker": if deployment_mode == "docker":
try: try:
@@ -178,15 +201,17 @@ class ServiceOrchestrator:
# 本地进程启动 # 本地进程启动
if service_id not in self.processes: if service_id not in self.processes:
# 服务不在进程列表中,可能是服务重启导致的 # 服务不在进程列表中,可能是服务重启导致的
# 这种情况下,需要从外部重新注册服务 # 尝试从数据库重新加载服务信息
# 暂时返回错误,建议用户重新注册服务 print(f"服务 {service_id} 不在进程列表中,尝试从数据库重新加载")
print(f"服务 {service_id} 不在进程列表中,无法启动") service_info = self.reload_service_from_db(service_id)
return {
"success": False, if not service_info:
"error": "服务不存在,请重新注册服务", return {
"service_id": service_id, "success": False,
"status": "error" "error": "服务不存在,请重新注册服务",
} "service_id": service_id,
"status": "error"
}
process_info = self.processes[service_id] process_info = self.processes[service_id]
@@ -209,11 +234,76 @@ class ServiceOrchestrator:
project_info = process_info["project_info"] project_info = process_info["project_info"]
service_config = process_info["service_config"] service_config = process_info["service_config"]
print(f"准备启动服务 {service_id}")
print(f"服务目录: {service_dir}")
print(f"服务配置: {service_config}")
# 检查服务目录是否存在如果不存在则从Gitea克隆
if not os.path.exists(service_dir):
print(f"服务目录不存在: {service_dir}尝试从Gitea克隆代码")
repository_id = service_config.get("repository_id")
if not repository_id:
return {
"success": False,
"error": "无法获取仓库ID无法克隆代码",
"service_id": service_id,
"status": "error"
}
# 从数据库获取仓库信息
try:
from app.models.database import SessionLocal
from app.models.models import AlgorithmRepository
db = SessionLocal()
repository = db.query(AlgorithmRepository).filter(AlgorithmRepository.id == repository_id).first()
db.close()
if not repository:
return {
"success": False,
"error": f"仓库不存在: {repository_id}",
"service_id": service_id,
"status": "error"
}
# 克隆仓库
from app.gitea.service import GiteaService
gitea_service = GiteaService()
clone_success = gitea_service.clone_repository(
repository.repo_url,
service_id,
repository.branch or "main"
)
if not clone_success:
return {
"success": False,
"error": f"克隆仓库失败: {repository.repo_url}",
"service_id": service_id,
"status": "error"
}
print(f"成功从Gitea克隆仓库到: {service_dir}")
except Exception as e:
print(f"克隆仓库时出错: {str(e)}")
return {
"success": False,
"error": f"克隆仓库时出错: {str(e)}",
"service_id": service_id,
"status": "error"
}
# 启动服务进程 # 启动服务进程
print(f"开始启动服务进程...")
new_process_info = self._start_local_service_process(service_id, service_dir, project_info, service_config) new_process_info = self._start_local_service_process(service_id, service_dir, project_info, service_config)
print(f"服务进程启动完成: {new_process_info}")
# 验证服务启动 # 验证服务启动
print(f"开始验证服务启动...")
if not self._verify_local_service_startup(service_id, service_config): if not self._verify_local_service_startup(service_id, service_config):
print(f"服务启动验证失败")
return { return {
"success": False, "success": False,
"error": "服务启动验证失败", "error": "服务启动验证失败",
@@ -221,6 +311,7 @@ class ServiceOrchestrator:
"status": "error" "status": "error"
} }
print(f"服务启动成功!")
return { return {
"success": True, "success": True,
"service_id": service_id, "service_id": service_id,
@@ -610,6 +701,76 @@ class ServiceOrchestrator:
"health": "unknown" "health": "unknown"
} }
def reload_service_from_db(self, service_id: str) -> Optional[Dict[str, Any]]:
"""从数据库重新加载服务信息
Args:
service_id: 服务ID
Returns:
服务信息字典如果未找到则返回None
"""
if not self.db_session:
print("数据库连接不可用,无法重新加载服务信息")
return None
try:
# 从数据库查询服务信息
query = text("""
SELECT service_id, api_url, status, config, host, port
FROM algorithm_services
WHERE service_id = :service_id
""")
result = self.db_session.execute(query, {"service_id": service_id}).fetchone()
if not result:
print(f"数据库中未找到服务 {service_id}")
return None
# 构建服务信息字典
# config字段已经是JSON类型不需要再解析
config_data = result[3] if result[3] else {}
# 将host和port添加到config中
config_data["host"] = result[4] if result[4] else "localhost"
config_data["port"] = result[5] if result[5] else 8000
service_info = {
"service_id": result[0],
"api_url": result[1],
"status": result[2],
"config": config_data
}
# 从config中提取容器ID
container_id = service_info["config"].get("container_id")
if container_id:
service_info["container_id"] = container_id
# 更新本地进程缓存(仅用于本地模式)
if self.deployment_mode == "local":
# 从config中获取项目类型如果没有则默认为python
project_type = service_info["config"].get("project_type", "python")
self.processes[service_id] = {
"service_dir": f"/tmp/algorithms/{service_id}",
"project_info": {
"project_type": project_type,
"name": service_info["config"].get("name", ""),
"version": service_info["config"].get("version", "1.0.0"),
"description": service_info["config"].get("description", "")
},
"service_config": service_info["config"],
"pid": None # PID需要重新获取
}
print(f"成功从数据库重新加载服务 {service_id} 的信息")
return service_info
except Exception as e:
print(f"从数据库重新加载服务信息失败: {e}")
return None
def get_service_logs(self, container_id: str, lines: int = 100) -> Dict[str, Any]: def get_service_logs(self, container_id: str, lines: int = 100) -> Dict[str, Any]:
"""获取服务日志 """获取服务日志
@@ -1429,7 +1590,10 @@ def main(data):
import requests import requests
# 构建健康检查URL # 构建健康检查URL
# 使用localhost而不是0.0.0.0,因为健康检查是在本地执行的
host = service_config.get("host", "localhost") host = service_config.get("host", "localhost")
if host == "0.0.0.0":
host = "localhost"
port = service_config.get("port", 8000) port = service_config.get("port", 8000)
health_check_url = f"http://{host}:{port}/health" health_check_url = f"http://{host}:{port}/health"
@@ -1437,5 +1601,6 @@ def main(data):
response = requests.get(health_check_url, timeout=10) response = requests.get(health_check_url, timeout=10)
return response.status_code == 200 return response.status_code == 200
except: except Exception as e:
print(f"健康检查失败: {e}")
return False return False

View File

@@ -59,6 +59,27 @@ class MinioClient:
logging.warning(f"MinIO upload error: {e}") logging.warning(f"MinIO upload error: {e}")
return False return False
def upload_from_bytes(self, data: bytes, object_name: str) -> bool:
"""从字节数据上传文件"""
if not self.is_connected:
logging.warning("MinIO is not connected. Upload skipped.")
return False
try:
import io
file_obj = io.BytesIO(data)
self.client.put_object(
self.bucket_name,
object_name,
file_obj,
length=len(data),
part_size=10*1024*1024
)
return True
except S3Error as e:
logging.warning(f"MinIO upload error: {e}")
return False
def upload_fileobj(self, file_obj: io.BytesIO, object_name: str, content_type: str = "application/octet-stream") -> bool: def upload_fileobj(self, file_obj: io.BytesIO, object_name: str, content_type: str = "application/octet-stream") -> bool:
"""上传文件对象""" """上传文件对象"""
if not self.is_connected: if not self.is_connected:

View File

@@ -1,2 +0,0 @@
INFO: Will watch for changes in these directories: ['/Users/duguoyou/MLFlow/algorithm-showcase/backend']
ERROR: [Errno 48] Address already in use

View File

@@ -1,79 +0,0 @@
#!/usr/bin/env python3
"""
检查并重置管理员账号
"""
from sqlalchemy.orm import Session
from app.models.database import engine, Base, SessionLocal
from app.models.models import User, Role
from app.services.user import UserService
def check_admin():
"""检查并重置管理员账号"""
db = SessionLocal()
try:
# 检查是否存在管理员账号
print("检查管理员账号...")
admin_user = db.query(User).filter(User.username == "admin").first()
if not admin_user:
print("⚠️ 管理员账号不存在,创建新的管理员账号...")
# 初始化默认角色
UserService.init_default_roles(db)
# 获取默认管理员角色
admin_role = UserService.get_role_by_name(db, "admin")
if not admin_role:
print("❌ 管理员角色不存在,创建失败")
return
# 创建默认管理员账号
admin_user = User(
id="user-admin",
username="admin",
email="admin@example.com",
password_hash=UserService.get_password_hash("admin123"),
role_id=admin_role.id,
status="active"
)
db.add(admin_user)
db.commit()
db.refresh(admin_user)
print("✅ 管理员账号创建成功")
else:
print("✅ 管理员账号存在,重置密码...")
# 重置管理员密码
admin_user.password_hash = UserService.get_password_hash("admin123")
db.commit()
print("✅ 管理员密码重置成功")
# 显示管理员账号信息
print("\n管理员账号信息:")
print(f"用户名: {admin_user.username}")
print(f"密码: admin123")
print(f"邮箱: {admin_user.email}")
print(f"状态: {admin_user.status}")
# 检查角色信息
if admin_user.role:
print(f"角色: {admin_user.role.name}")
else:
print("⚠️ 管理员角色信息缺失")
# 尝试修复角色关联
admin_role = UserService.get_role_by_name(db, "admin")
if admin_role:
admin_user.role_id = admin_role.id
db.commit()
print("✅ 管理员角色关联修复成功")
except Exception as e:
print(f"❌ 操作失败: {e}")
finally:
db.close()
if __name__ == "__main__":
check_admin()

View File

@@ -1,38 +0,0 @@
#!/usr/bin/env python3
"""检查算法数据"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from app.models.database import SessionLocal
from app.models.models import Algorithm
def check_algorithms():
"""检查算法数据"""
db = SessionLocal()
try:
algorithms = db.query(Algorithm).all()
print(f"数据库中共有 {len(algorithms)} 个算法:\n")
for algo in algorithms:
print(f"算法名称: {algo.name}")
print(f" ID: {algo.id}")
print(f" 类型: {algo.type}")
print(f" 技术分类: {algo.tech_category}")
print(f" 输出类型: {algo.output_type}")
print(f" 描述: {algo.description}")
print(f" 状态: {algo.status}")
print(f" 版本数: {len(algo.versions)}")
print()
except Exception as e:
print(f"检查算法数据失败: {e}")
sys.exit(1)
finally:
db.close()
if __name__ == "__main__":
check_algorithms()

View File

@@ -1,57 +0,0 @@
#!/usr/bin/env python3
"""检查用户角色信息"""
import requests
def check_user_role():
"""检查用户角色"""
base_url = "http://localhost:8001/api/v1"
# 登录
print("步骤1: 登录")
login_data = {
"username": "admin",
"password": "admin123"
}
try:
response = requests.post(f"{base_url}/users/login", json=login_data)
print(f"状态码: {response.status_code}")
if response.status_code != 200:
print(f"登录失败: {response.text}")
return
data = response.json()
access_token = data.get('access_token')
print(f"登录成功!")
# 获取用户信息
print("\n步骤2: 获取用户信息")
headers = {"Authorization": f"Bearer {access_token}"}
user_response = requests.get(f"{base_url}/users/me", headers=headers)
print(f"状态码: {user_response.status_code}")
if user_response.status_code == 200:
user_data = user_response.json()
print(f"\n用户信息:")
print(f" 用户名: {user_data.get('username', 'N/A')}")
print(f" 邮箱: {user_data.get('email', 'N/A')}")
print(f" 角色ID: {user_data.get('role_id', 'N/A')}")
print(f" 角色名称: {user_data.get('role_name', 'N/A')}")
print(f" 角色对象: {user_data.get('role', 'N/A')}")
# 检查是否是管理员
role_name = user_data.get('role_name')
if role_name == 'admin':
print(f"\n✅ 用户是管理员,应该显示后台管理页面")
else:
print(f"\n❌ 用户不是管理员,角色名称是: {role_name}")
else:
print(f"获取用户信息失败: {user_response.text}")
except Exception as e:
print(f"错误: {e}")
if __name__ == "__main__":
check_user_role()

View File

@@ -1,54 +0,0 @@
#!/usr/bin/env python3
"""检查数据库中的用户信息"""
import sys
sys.path.insert(0, '/Users/duguoyou/MLFlow/algorithm-showcase/backend')
from app.models.database import SessionLocal
from app.models.models import User
from app.services.user import UserService
def check_users():
"""检查用户"""
db = SessionLocal()
try:
# 获取所有用户
users = db.query(User).all()
print(f"数据库中的用户数量: {len(users)}")
for user in users:
print(f"\n用户ID: {user.id}")
print(f"用户名: {user.username}")
print(f"邮箱: {user.email}")
print(f"状态: {user.status}")
print(f"角色ID: {user.role_id}")
print(f"密码哈希: {user.password_hash[:50]}...")
# 测试admin用户认证
print("\n\n测试admin用户认证:")
admin_user = UserService.get_user_by_username(db, 'admin')
if admin_user:
print(f"找到admin用户: {admin_user.id}")
print(f"密码哈希: {admin_user.password_hash[:50]}...")
# 测试密码验证
test_password = 'admin123'
is_valid = UserService.verify_password(test_password, admin_user.password_hash)
print(f"密码 '{test_password}' 验证结果: {is_valid}")
# 尝试认证
authenticated_user = UserService.authenticate_user(db, 'admin', test_password)
if authenticated_user:
print(f"认证成功: {authenticated_user.id}")
else:
print("认证失败")
else:
print("未找到admin用户")
finally:
db.close()
if __name__ == "__main__":
check_users()

View File

@@ -1,151 +0,0 @@
#!/usr/bin/env python3
"""创建示例算法数据"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from app.models.database import SessionLocal
from app.models.models import Algorithm, AlgorithmVersion
from datetime import datetime
import uuid
def create_sample_algorithms():
"""创建示例算法"""
db = SessionLocal()
try:
# 示例算法数据
algorithms_data = [
{
"name": "目标检测",
"description": "识别图像中的物体位置和类别,支持人脸、车辆、物品等多种目标检测",
"type": "computer_vision",
"tech_category": "computer_vision",
"output_type": "image",
"versions": [
{
"version": "1.0.0",
"url": "http://0.0.0.0:8001",
"is_default": True
}
]
},
{
"name": "视频分析",
"description": "分析视频内容,提取关键帧、识别动作、追踪物体等",
"type": "computer_vision",
"tech_category": "video_processing",
"output_type": "video",
"versions": [
{
"version": "1.0.0",
"url": "http://0.0.0.0:8002",
"is_default": True
}
]
},
{
"name": "图像增强",
"description": "提升图像质量,包括去噪、超分辨率、色彩校正等功能",
"type": "computer_vision",
"tech_category": "computer_vision",
"output_type": "image",
"versions": [
{
"version": "1.0.0",
"url": "http://0.0.0.0:8003",
"is_default": True
}
]
},
{
"name": "文本分类",
"description": "对文本内容进行分类,支持新闻分类、情感分析、垃圾邮件识别等",
"type": "nlp",
"tech_category": "nlp",
"output_type": "text",
"versions": [
{
"version": "1.0.0",
"url": "http://0.0.0.0:8004",
"is_default": True
}
]
},
{
"name": "异常检测",
"description": "检测数据中的异常模式,适用于工业监控、金融风控等场景",
"type": "ml",
"tech_category": "ml",
"output_type": "json",
"versions": [
{
"version": "1.0.0",
"url": "http://0.0.0.0:8005",
"is_default": True
}
]
},
{
"name": "医学影像分析",
"description": "分析医学影像辅助医生进行疾病诊断支持CT、MRI等多种影像格式",
"type": "medical",
"tech_category": "computer_vision",
"output_type": "image",
"versions": [
{
"version": "1.0.0",
"url": "http://0.0.0.0:8006",
"is_default": True
}
]
}
]
# 创建算法
for algo_data in algorithms_data:
# 检查算法是否已存在
existing_algo = db.query(Algorithm).filter(Algorithm.name == algo_data["name"]).first()
if existing_algo:
print(f"✓ 算法 '{algo_data['name']}' 已存在,跳过")
continue
# 创建算法
algorithm = Algorithm(
id=str(uuid.uuid4()),
name=algo_data["name"],
description=algo_data["description"],
type=algo_data["type"],
tech_category=algo_data["tech_category"],
output_type=algo_data["output_type"],
status="active"
)
db.add(algorithm)
db.flush() # 获取算法ID
# 创建版本
for version_data in algo_data["versions"]:
version = AlgorithmVersion(
id=str(uuid.uuid4()),
algorithm_id=algorithm.id,
version=version_data["version"],
url=version_data["url"],
is_default=version_data["is_default"]
)
db.add(version)
print(f"✓ 已创建算法: {algo_data['name']}")
db.commit()
print("\n示例算法创建完成!")
except Exception as e:
db.rollback()
print(f"创建示例算法失败: {e}")
sys.exit(1)
finally:
db.close()
if __name__ == "__main__":
create_sample_algorithms()

View File

View File

@@ -1,45 +0,0 @@
#!/usr/bin/env python3
"""数据库迁移脚本:添加技术分类和输出类型字段"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from sqlalchemy import text
from app.models.database import engine
def migrate():
"""执行数据库迁移"""
try:
with engine.connect() as conn:
# 检查字段是否已存在PostgreSQL语法
result = conn.execute(text("""
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'algorithms'
"""))
columns = [row[0] for row in result.fetchall()]
# 添加 tech_category 字段
if 'tech_category' not in columns:
conn.execute(text("ALTER TABLE algorithms ADD COLUMN tech_category VARCHAR(50) DEFAULT 'computer_vision'"))
print("✓ 已添加 tech_category 字段")
else:
print("✓ tech_category 字段已存在")
# 添加 output_type 字段
if 'output_type' not in columns:
conn.execute(text("ALTER TABLE algorithms ADD COLUMN output_type VARCHAR(50) DEFAULT 'image'"))
print("✓ 已添加 output_type 字段")
else:
print("✓ output_type 字段已存在")
conn.commit()
print("\n数据库迁移完成!")
except Exception as e:
print(f"数据库迁移失败: {e}")
sys.exit(1)
if __name__ == "__main__":
migrate()

View File

@@ -1,45 +0,0 @@
#!/usr/bin/env python3
"""数据库迁移脚本为algorithm_services表添加技术分类和输出类型字段"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from sqlalchemy import text
from app.models.database import engine
def migrate():
"""执行数据库迁移"""
try:
with engine.connect() as conn:
# 检查字段是否已存在PostgreSQL语法
result = conn.execute(text("""
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'algorithm_services'
"""))
columns = [row[0] for row in result.fetchall()]
# 添加 tech_category 字段
if 'tech_category' not in columns:
conn.execute(text("ALTER TABLE algorithm_services ADD COLUMN tech_category VARCHAR(50) DEFAULT 'computer_vision'"))
print("✓ 已添加 tech_category 字段到 algorithm_services 表")
else:
print("✓ tech_category 字段已存在于 algorithm_services 表")
# 添加 output_type 字段
if 'output_type' not in columns:
conn.execute(text("ALTER TABLE algorithm_services ADD COLUMN output_type VARCHAR(50) DEFAULT 'image'"))
print("✓ 已添加 output_type 字段到 algorithm_services 表")
else:
print("✓ output_type 字段已存在于 algorithm_services 表")
conn.commit()
print("\n数据库迁移完成!")
except Exception as e:
print(f"数据库迁移失败: {e}")
sys.exit(1)
if __name__ == "__main__":
migrate()

View File

@@ -1,48 +0,0 @@
#!/usr/bin/env python3
"""测试所有API端点"""
import requests
def test_apis():
"""测试API端点"""
base_url = "http://localhost:8001/api/v1"
# 测试算法列表(不需要认证)
print("1. 测试算法列表(不需要认证):")
try:
response = requests.get(f"{base_url}/algorithms/")
print(f" 状态码: {response.status_code}")
if response.status_code == 200:
data = response.json()
print(f" 成功获取 {len(data.get('algorithms', []))} 个算法")
else:
print(f" 失败: {response.text}")
except Exception as e:
print(f" 错误: {e}")
# 测试用户信息(需要认证)
print("\n2. 测试用户信息(需要认证):")
try:
response = requests.get(f"{base_url}/users/me")
print(f" 状态码: {response.status_code}")
if response.status_code == 401:
print(f" 需要认证(正常)")
else:
print(f" 响应: {response.text}")
except Exception as e:
print(f" 错误: {e}")
# 测试服务列表(需要认证)
print("\n3. 测试服务列表(需要认证):")
try:
response = requests.get(f"{base_url}/services")
print(f" 状态码: {response.status_code}")
if response.status_code == 401:
print(f" 需要认证(正常)")
else:
print(f" 响应: {response.text[:200]}")
except Exception as e:
print(f" 错误: {e}")
if __name__ == "__main__":
test_apis()

View File

@@ -1,53 +0,0 @@
#!/usr/bin/env python3
"""测试前端API调用"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import requests
def test_api():
"""测试API"""
try:
# 调用算法列表API
response = requests.get('http://localhost:8001/api/v1/algorithms/')
if response.status_code == 200:
data = response.json()
algorithms = data.get('algorithms', [])
print(f"成功获取 {len(algorithms)} 个算法\n")
# 检查每个算法的字段
for algo in algorithms:
print(f"算法: {algo['name']}")
print(f" 技术分类: {algo.get('tech_category', 'N/A')}")
print(f" 输出类型: {algo.get('output_type', 'N/A')}")
print()
# 测试筛选
print("测试筛选功能:")
# 按技术分类筛选
cv_algorithms = [a for a in algorithms if a.get('tech_category') == 'computer_vision']
print(f" 计算机视觉算法: {len(cv_algorithms)}")
# 按输出类型筛选
image_algorithms = [a for a in algorithms if a.get('output_type') == 'image']
print(f" 图片输出算法: {len(image_algorithms)}")
# 按名称搜索
search_results = [a for a in algorithms if '视频' in a.get('name', '')]
print(f" 包含'视频'的算法: {len(search_results)}")
else:
print(f"API调用失败: {response.status_code}")
print(response.text)
except Exception as e:
print(f"测试失败: {e}")
sys.exit(1)
if __name__ == "__main__":
test_api()

View File

@@ -1,32 +0,0 @@
#!/usr/bin/env python3
"""测试前端代理配置"""
import requests
def test_frontend_proxy():
"""测试前端代理"""
try:
# 测试前端代理
response = requests.get('http://localhost:3000/api/algorithms')
print(f"状态码: {response.status_code}")
if response.status_code == 200:
data = response.json()
print(f"成功获取 {len(data.get('algorithms', []))} 个算法")
# 检查第一个算法的字段
if data.get('algorithms'):
first_algo = data['algorithms'][0]
print(f"\n第一个算法:")
print(f" 名称: {first_algo.get('name')}")
print(f" 技术分类: {first_algo.get('tech_category')}")
print(f" 输出类型: {first_algo.get('output_type')}")
else:
print(f"请求失败: {response.text}")
except Exception as e:
print(f"测试失败: {e}")
if __name__ == "__main__":
test_frontend_proxy()

View File

@@ -1,48 +0,0 @@
#!/usr/bin/env python3
"""测试完整的登录流程"""
import requests
def test_full_login_flow():
"""测试完整的登录流程"""
base_url = "http://localhost:8001/api/v1"
# 步骤1: 登录
print("步骤1: 登录")
login_data = {
"username": "admin",
"password": "admin123"
}
try:
response = requests.post(f"{base_url}/users/login", json=login_data)
print(f"状态码: {response.status_code}")
if response.status_code != 200:
print(f"登录失败: {response.text}")
return
data = response.json()
access_token = data.get('access_token')
print(f"登录成功!")
print(f"Token: {access_token[:50]}...")
# 步骤2: 使用token获取用户信息
print("\n步骤2: 获取用户信息")
headers = {"Authorization": f"Bearer {access_token}"}
user_response = requests.get(f"{base_url}/users/me", headers=headers)
print(f"状态码: {user_response.status_code}")
if user_response.status_code == 200:
user_data = user_response.json()
print(f"用户名: {user_data.get('username', 'N/A')}")
print(f"邮箱: {user_data.get('email', 'N/A')}")
print(f"角色: {user_data.get('role_name', 'N/A')}")
else:
print(f"获取用户信息失败: {user_response.text}")
except Exception as e:
print(f"错误: {e}")
if __name__ == "__main__":
test_full_login_flow()

View File

@@ -1,45 +0,0 @@
#!/usr/bin/env python3
"""测试登录功能"""
import requests
def test_login():
"""测试登录"""
base_url = "http://localhost:8001/api/v1"
# 测试登录
print("测试登录功能:")
login_data = {
"username": "admin",
"password": "admin123"
}
try:
response = requests.post(f"{base_url}/users/login", json=login_data)
print(f"状态码: {response.status_code}")
if response.status_code == 200:
data = response.json()
print(f"登录成功!")
print(f"访问令牌: {data.get('access_token', 'N/A')[:50]}...")
print(f"令牌类型: {data.get('token_type', 'N/A')}")
# 测试使用令牌访问受保护的API
if data.get('access_token'):
headers = {"Authorization": f"Bearer {data['access_token']}"}
user_response = requests.get(f"{base_url}/users/me", headers=headers)
print(f"\n测试用户信息API:")
print(f"状态码: {user_response.status_code}")
if user_response.status_code == 200:
user_data = user_response.json()
print(f"用户名: {user_data.get('username', 'N/A')}")
print(f"邮箱: {user_data.get('email', 'N/A')}")
else:
print(f"失败: {user_response.text}")
else:
print(f"登录失败: {response.text}")
except Exception as e:
print(f"错误: {e}")
if __name__ == "__main__":
test_login()

View File

@@ -1,53 +0,0 @@
#!/usr/bin/env python3
"""直接测试登录API"""
import requests
def test_login_api():
"""测试登录API"""
base_url = "http://localhost:8001/api/v1"
# 测试1: 使用JSON格式
print("测试1: 使用JSON格式")
login_data = {
"username": "admin",
"password": "admin123"
}
try:
response = requests.post(f"{base_url}/users/login", json=login_data)
print(f"状态码: {response.status_code}")
print(f"响应头: {dict(response.headers)}")
print(f"响应内容: {response.text[:500]}")
if response.status_code == 200:
data = response.json()
print(f"✅ 登录成功!")
print(f"Token: {data.get('access_token', 'N/A')[:50]}...")
else:
print(f"❌ 登录失败")
except Exception as e:
print(f"错误: {e}")
# 测试2: 使用form-data格式
print("\n\n测试2: 使用form-data格式")
form_data = {
"username": "admin",
"password": "admin123"
}
try:
response = requests.post(f"{base_url}/users/login", data=form_data)
print(f"状态码: {response.status_code}")
print(f"响应内容: {response.text[:500]}")
if response.status_code == 200:
data = response.json()
print(f"✅ 登录成功!")
else:
print(f"❌ 登录失败")
except Exception as e:
print(f"错误: {e}")
if __name__ == "__main__":
test_login_api()

View File

@@ -1,232 +0,0 @@
import requests
import json
import time
from typing import Dict, Any, List
class SystemTester:
def __init__(self, base_url: str = "http://localhost:8001/api/v1"):
self.base_url = base_url
self.session = requests.Session()
self.token = None
self.user_id = None
def login(self, username: str = "admin", password: str = "admin123") -> bool:
"""登录系统"""
try:
response = self.session.post(
f"{self.base_url}/users/login",
json={"username": username, "password": password}
)
if response.status_code == 200:
data = response.json()
self.token = data.get("access_token")
self.user_id = data.get("user_id")
self.session.headers.update({"Authorization": f"Bearer {self.token}"})
print(f"✓ 登录成功: {username}")
return True
else:
print(f"✗ 登录失败: {response.status_code} - {response.text}")
return False
except Exception as e:
print(f"✗ 登录异常: {str(e)}")
return False
def test_config_endpoints(self) -> bool:
"""测试配置管理API"""
print("\n=== 测试配置管理API ===")
success = True
try:
# 测试获取所有配置
response = self.session.get(f"{self.base_url}/config/")
if response.status_code == 200:
print("✓ 获取所有配置成功")
configs = response.json().get("configs", [])
print(f" 当前配置数量: {len(configs)}")
else:
print(f"✗ 获取所有配置失败: {response.status_code}")
success = False
# 测试添加配置
test_config = {
"value": "test_value_123",
"type": "system",
"service_id": None,
"description": "测试配置"
}
response = self.session.post(f"{self.base_url}/config/test_config_key", json=test_config)
if response.status_code == 200:
print("✓ 添加配置成功")
else:
print(f"✗ 添加配置失败: {response.status_code} - {response.text}")
success = False
# 测试获取单个配置
response = self.session.get(f"{self.base_url}/config/test_config_key")
if response.status_code == 200:
print("✓ 获取单个配置成功")
config_data = response.json()
print(f" 配置值: {config_data.get('value')}")
else:
print(f"✗ 获取单个配置失败: {response.status_code}")
success = False
# 测试删除配置
response = self.session.delete(f"{self.base_url}/config/test_config_key")
if response.status_code == 200:
print("✓ 删除配置成功")
else:
print(f"✗ 删除配置失败: {response.status_code}")
success = False
return success
except Exception as e:
print(f"✗ 配置管理API测试异常: {str(e)}")
return False
def test_comparison_endpoints(self) -> bool:
"""测试算法比较API"""
print("\n=== 测试算法比较API ===")
success = True
try:
# 测试算法比较(使用模拟数据)
test_data = {
"input_data": {"text": "这是一段测试文本"},
"algorithm_configs": [
{
"algorithm_id": "test_algo_1",
"algorithm_name": "测试算法1",
"version": "1.0.0",
"config": "{}"
},
{
"algorithm_id": "test_algo_2",
"algorithm_name": "测试算法2",
"version": "1.0.0",
"config": "{}"
}
]
}
response = self.session.post(f"{self.base_url}/comparison/compare-algorithms", json=test_data)
if response.status_code == 200:
print("✓ 算法比较API调用成功")
result = response.json()
print(f" 比较状态: {result.get('success')}")
if result.get('results'):
print(f" 结果数量: {len(result.get('results'))}")
else:
print(f"✗ 算法比较失败: {response.status_code} - {response.text}")
success = False
return success
except Exception as e:
print(f"✗ 算法比较API测试异常: {str(e)}")
return False
def test_existing_endpoints(self) -> bool:
"""测试现有API端点"""
print("\n=== 测试现有API端点 ===")
success = True
try:
# 测试健康检查
response = self.session.get(f"{self.base_url.replace('/api/v1', '')}/health")
if response.status_code == 200:
print("✓ 健康检查通过")
else:
print(f"✗ 健康检查失败: {response.status_code}")
success = False
# 测试获取当前用户
response = self.session.get(f"{self.base_url}/users/me")
if response.status_code == 200:
print("✓ 获取当前用户成功")
user_data = response.json()
print(f" 用户名: {user_data.get('username')}")
else:
print(f"✗ 获取当前用户失败: {response.status_code}")
success = False
# 测试获取算法列表
response = self.session.get(f"{self.base_url}/algorithms/")
if response.status_code == 200:
print("✓ 获取算法列表成功")
algorithms = response.json()
print(f" 算法数量: {len(algorithms) if isinstance(algorithms, list) else 0}")
else:
print(f"✗ 获取算法列表失败: {response.status_code}")
success = False
# 测试获取服务列表
response = self.session.get(f"{self.base_url}/services")
if response.status_code == 200:
print("✓ 获取服务列表成功")
services = response.json()
print(f" 服务数量: {len(services) if isinstance(services, list) else 0}")
else:
print(f"✗ 获取服务列表失败: {response.status_code}")
success = False
return success
except Exception as e:
print(f"✗ 现有API端点测试异常: {str(e)}")
return False
def run_all_tests(self) -> Dict[str, bool]:
"""运行所有测试"""
print("=" * 50)
print("开始系统自动化测试")
print("=" * 50)
results = {}
# 登录
if not self.login():
print("\n✗ 登录失败,无法继续测试")
return {"login": False}
results["login"] = True
# 测试现有端点
results["existing_endpoints"] = self.test_existing_endpoints()
# 测试配置管理API
results["config_endpoints"] = self.test_config_endpoints()
# 测试算法比较API
results["comparison_endpoints"] = self.test_comparison_endpoints()
# 输出测试结果
print("\n" + "=" * 50)
print("测试结果汇总")
print("=" * 50)
for test_name, result in results.items():
status = "✓ 通过" if result else "✗ 失败"
print(f"{test_name}: {status}")
total_tests = len(results)
passed_tests = sum(1 for result in results.values() if result)
print(f"\n总计: {passed_tests}/{total_tests} 测试通过")
if passed_tests == total_tests:
print("🎉 所有测试通过!")
else:
print("⚠️ 部分测试失败,请检查日志")
return results
def main():
"""主函数"""
tester = SystemTester()
results = tester.run_all_tests()
# 返回退出码
exit_code = 0 if all(results.values()) else 1
return exit_code
if __name__ == "__main__":
exit(main())

View File

@@ -1,127 +0,0 @@
#!/usr/bin/env python3
"""
更新数据库结构删除api_keys表添加roles表修改users表
"""
from sqlalchemy.orm import Session
from sqlalchemy import text
from app.models.database import engine, Base, SessionLocal
from app.models.models import User, Role
from app.services.user import UserService
def update_db():
"""更新数据库结构"""
db = SessionLocal()
try:
# 1. 删除api_keys表
print("删除api_keys表...")
try:
db.execute(text("DROP TABLE IF EXISTS api_keys CASCADE"))
db.commit()
print("✅ api_keys表删除成功")
except Exception as e:
print(f"⚠️ 删除api_keys表时出错: {e}")
db.rollback()
# 2. 创建roles表
print("\n创建roles表...")
try:
# 直接执行SQL创建表避免依赖模型的顺序
db.execute(text("""
CREATE TABLE IF NOT EXISTS roles (
id VARCHAR PRIMARY KEY,
name VARCHAR UNIQUE NOT NULL,
description TEXT DEFAULT '',
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
)
"""))
db.commit()
print("✅ roles表创建成功")
except Exception as e:
print(f"⚠️ 创建roles表时出错: {e}")
db.rollback()
# 3. 修改users表添加role_id字段删除role字段
print("\n修改users表...")
try:
# 检查是否存在role_id字段
result = db.execute(text("SELECT column_name FROM information_schema.columns WHERE table_name = 'users' AND column_name = 'role_id'")).fetchone()
if not result:
# 添加role_id字段
db.execute(text("ALTER TABLE users ADD COLUMN role_id VARCHAR"))
print("✅ 添加role_id字段成功")
# 初始化默认角色
UserService.init_default_roles(db)
print("✅ 默认角色初始化成功")
# 获取默认角色
admin_role = UserService.get_role_by_name(db, "admin")
user_role = UserService.get_role_by_name(db, "user")
if admin_role and user_role:
# 更新现有用户的role_id字段
db.execute(text(f"UPDATE users SET role_id = CASE WHEN role = 'admin' THEN '{admin_role.id}' ELSE '{user_role.id}' END"))
print("✅ 更新用户role_id字段成功")
# 删除role字段
result = db.execute(text("SELECT column_name FROM information_schema.columns WHERE table_name = 'users' AND column_name = 'role'")).fetchone()
if result:
db.execute(text("ALTER TABLE users DROP COLUMN role"))
print("✅ 删除role字段成功")
# 添加外键约束
db.execute(text("ALTER TABLE users ADD CONSTRAINT fk_users_role FOREIGN KEY (role_id) REFERENCES roles(id)"))
print("✅ 添加外键约束成功")
db.commit()
except Exception as e:
print(f"⚠️ 修改users表时出错: {e}")
db.rollback()
# 4. 检查并创建默认管理员账号
print("\n检查默认管理员账号...")
try:
# 检查是否已存在管理员账号
admin_user = db.query(User).filter(User.username == "admin").first()
if not admin_user:
# 获取默认管理员角色
admin_role = UserService.get_role_by_name(db, "admin")
if admin_role:
# 创建默认管理员账号
admin_user = User(
id="user-admin",
username="admin",
email="admin@example.com",
password_hash=UserService.get_password_hash("admin123"),
role_id=admin_role.id,
status="active"
)
db.add(admin_user)
db.commit()
db.refresh(admin_user)
print("✅ 默认管理员账号创建成功")
print(f"用户名: admin")
print(f"密码: admin123")
else:
print("❌ 无法创建管理员账号因为admin角色不存在")
else:
print("⚠️ 管理员账号已存在")
except Exception as e:
print(f"⚠️ 检查管理员账号时出错: {e}")
db.rollback()
print("\n✅ 数据库结构更新完成")
finally:
db.close()
if __name__ == "__main__":
update_db()

File diff suppressed because it is too large Load Diff