Files
algorithm/backend/app/routes/algorithm.py

405 lines
13 KiB
Python

from fastapi import APIRouter, Depends, HTTPException, status, Body, UploadFile, File
from sqlalchemy.orm import Session
from typing import List, Optional
import os
import uuid
from app.utils.file import file_storage
from app.models.database import get_db
from app.schemas.algorithm import AlgorithmCreate, AlgorithmUpdate, AlgorithmResponse, AlgorithmListResponse, AlgorithmVersionCreate, AlgorithmVersionUpdate, AlgorithmVersionResponse, AlgorithmCallCreate, AlgorithmCallResult
from app.models.models import AlgorithmCall
from app.services.algorithm import AlgorithmService, AlgorithmVersionService, AlgorithmCallService
from app.dependencies import get_current_active_user
# Router holding every endpoint in this module; the paths declared below are
# relative to the "/algorithms" prefix and grouped under the "algorithms" tag.
router = APIRouter(tags=["algorithms"], prefix="/algorithms")
@router.post("", response_model=AlgorithmResponse)
async def create_algorithm(
    algorithm: AlgorithmCreate,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Create an algorithm; only users whose ``role`` is ``"admin"`` may do so.

    NOTE(review): ``current_user`` is annotated ``dict`` but is read by
    attribute (``.role``); the annotation looks inaccurate.

    Raises:
        HTTPException: 403 when the caller is not an admin.

    Returns:
        The object returned by ``AlgorithmService.create_algorithm``,
        serialized through ``AlgorithmResponse``.
    """
    is_admin = current_user.role == "admin"
    if not is_admin:
        raise HTTPException(status_code=403, detail="Not enough permissions")
    return AlgorithmService.create_algorithm(db, algorithm)
@router.get("", response_model=AlgorithmListResponse)
async def get_algorithms_no_slash(
    skip: int = 0,
    limit: int = 100,
    type: Optional[str] = None,
    db: Session = Depends(get_db)
):
    """List algorithms at the bare prefix (no trailing slash).

    Same behaviour as the ``"/"`` route, so both spellings of the collection
    path are served. ``type`` is forwarded as ``algorithm_type`` to the service
    filter; ``skip``/``limit`` are forwarded unchanged.
    """
    page = AlgorithmService.get_algorithms(
        db,
        skip=skip,
        limit=limit,
        algorithm_type=type,
    )
    # NOTE(review): "total" is the length of this page, not the overall number
    # of matching algorithms — confirm what clients expect.
    return {"algorithms": page, "total": len(page)}
@router.get("/", response_model=AlgorithmListResponse)
async def get_algorithms(
    skip: int = 0,
    limit: int = 100,
    type: Optional[str] = None,
    db: Session = Depends(get_db)
):
    """List algorithms at the trailing-slash collection path.

    ``skip``/``limit`` page the results and ``type`` (if given) is passed to
    the service as ``algorithm_type``.
    """
    filters = {"skip": skip, "limit": limit, "algorithm_type": type}
    found = AlgorithmService.get_algorithms(db, **filters)
    # NOTE(review): "total" counts only the returned page — confirm intent.
    response = {"algorithms": found, "total": len(found)}
    return response
def _require_admin(current_user) -> None:
    """Raise 403 unless ``current_user.role`` is ``"admin"``."""
    if current_user.role != "admin":
        raise HTTPException(status_code=403, detail="Not enough permissions")


def _require_algorithm(db: Session, algorithm_id: str) -> None:
    """Raise 404 when ``AlgorithmService.get_algorithm_by_id`` finds nothing."""
    if not AlgorithmService.get_algorithm_by_id(db, algorithm_id):
        raise HTTPException(status_code=404, detail="Algorithm not found")


def _get_version_of_algorithm(db: Session, algorithm_id: str, version_id: str):
    """Load version ``version_id`` and check that it belongs to ``algorithm_id``.

    Raises:
        HTTPException: 404 if the algorithm or the version does not exist;
            400 if the version's ``algorithm_id`` differs from the path's.

    Returns:
        The version object from ``AlgorithmVersionService.get_version_by_id``.
    """
    _require_algorithm(db, algorithm_id)
    version = AlgorithmVersionService.get_version_by_id(db, version_id)
    if not version:
        raise HTTPException(status_code=404, detail="Version not found")
    if version.algorithm_id != algorithm_id:
        raise HTTPException(
            status_code=400, detail="Version does not belong to this algorithm"
        )
    return version


# Algorithm call routes.
# These are declared BEFORE the "/{algorithm_id}" routes on purpose: FastAPI
# matches routes in declaration order, so "GET /{algorithm_id}" would otherwise
# capture "GET /calls" (as algorithm_id="calls") and get_call_history would be
# unreachable.
@router.post("/call", response_model=AlgorithmCallResult)
async def call_algorithm(
    call: AlgorithmCallCreate,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Execute an algorithm on behalf of the caller.

    Delegates to ``AlgorithmCallService.execute_algorithm`` with the caller's
    ``id`` and returns its result.
    """
    return AlgorithmCallService.execute_algorithm(db, current_user.id, call)


@router.get("/calls/{call_id}", response_model=AlgorithmCallResult)
async def get_call_result(
    call_id: str,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Return a single algorithm call record.

    Admins may read any record; other users only records whose ``user_id``
    equals their own ``id``.

    Raises:
        HTTPException: 404 if no call has ``call_id``; 403 if a non-admin asks
            for another user's call.
    """
    call = AlgorithmCallService.get_call_by_id(db, call_id)
    if not call:
        raise HTTPException(status_code=404, detail="Call not found")
    if current_user.role != "admin" and current_user.id != call.user_id:
        raise HTTPException(status_code=403, detail="Not enough permissions")
    return call


@router.get("/calls", response_model=List[AlgorithmCallResult])
async def get_call_history(
    skip: int = 0,
    limit: int = 100,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Return one page (``skip``/``limit``) of algorithm call records.

    Admins page over every ``AlgorithmCall`` row; other users get
    ``AlgorithmCallService.get_calls_by_user_id`` for their own ``id``.
    """
    if current_user.role == "admin":
        # No per-user filter for admins; only offset/limit are applied.
        return db.query(AlgorithmCall).offset(skip).limit(limit).all()
    return AlgorithmCallService.get_calls_by_user_id(
        db, current_user.id, skip=skip, limit=limit
    )


# Algorithm CRUD routes.
@router.get("/{algorithm_id}", response_model=AlgorithmResponse)
async def get_algorithm(
    algorithm_id: str,
    db: Session = Depends(get_db)
):
    """Return one algorithm by id.

    Raises:
        HTTPException: 404 if no algorithm has ``algorithm_id``.
    """
    algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
    if not algorithm:
        raise HTTPException(status_code=404, detail="Algorithm not found")
    return algorithm


@router.put("/{algorithm_id}", response_model=AlgorithmResponse)
async def update_algorithm(
    algorithm_id: str,
    algorithm_update: AlgorithmUpdate,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Update an algorithm (admin only).

    Raises:
        HTTPException: 403 for non-admins; 404 if the service reports no
            algorithm with ``algorithm_id``.
    """
    _require_admin(current_user)
    algorithm = AlgorithmService.update_algorithm(db, algorithm_id, algorithm_update)
    if not algorithm:
        raise HTTPException(status_code=404, detail="Algorithm not found")
    return algorithm


@router.delete("/{algorithm_id}", response_model=dict)
async def delete_algorithm(
    algorithm_id: str,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Delete an algorithm (admin only).

    Raises:
        HTTPException: 403 for non-admins; 404 if the service reports failure.
    """
    _require_admin(current_user)
    if not AlgorithmService.delete_algorithm(db, algorithm_id):
        raise HTTPException(status_code=404, detail="Algorithm not found")
    return {"message": "Algorithm deleted successfully"}


# Algorithm version routes.
@router.post("/{algorithm_id}/versions", response_model=AlgorithmVersionResponse)
async def create_version(
    algorithm_id: str,
    version: AlgorithmVersionCreate,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Create a version of an existing algorithm (admin only).

    Raises:
        HTTPException: 403 for non-admins; 404 if the algorithm is missing.
    """
    _require_admin(current_user)
    _require_algorithm(db, algorithm_id)
    # The path is authoritative: overwrite whatever algorithm_id the body carried.
    version.algorithm_id = algorithm_id
    return AlgorithmVersionService.create_version(db, version)


@router.get("/{algorithm_id}/versions", response_model=List[AlgorithmVersionResponse])
async def get_versions(
    algorithm_id: str,
    db: Session = Depends(get_db)
):
    """List the versions of an algorithm.

    Raises:
        HTTPException: 404 if the algorithm is missing.
    """
    _require_algorithm(db, algorithm_id)
    return AlgorithmVersionService.get_versions_by_algorithm_id(db, algorithm_id)


@router.get("/{algorithm_id}/versions/{version_id}", response_model=AlgorithmVersionResponse)
async def get_version(
    algorithm_id: str,
    version_id: str,
    db: Session = Depends(get_db)
):
    """Return one version of an algorithm.

    Raises:
        HTTPException: 404 if the algorithm or version is missing; 400 if the
            version belongs to another algorithm.
    """
    return _get_version_of_algorithm(db, algorithm_id, version_id)


@router.put("/{algorithm_id}/versions/{version_id}", response_model=AlgorithmVersionResponse)
async def update_version(
    algorithm_id: str,
    version_id: str,
    version_update: AlgorithmVersionUpdate,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Update one version of an algorithm (admin only).

    Raises:
        HTTPException: 403 for non-admins; 404 if the algorithm or version is
            missing; 400 if the version belongs to another algorithm.
    """
    _require_admin(current_user)
    _get_version_of_algorithm(db, algorithm_id, version_id)
    updated_version = AlgorithmVersionService.update_version(db, version_id, version_update)
    if not updated_version:
        # Guard in case the service returns nothing (e.g. the row vanished
        # after the checks above) instead of returning None as the response.
        raise HTTPException(status_code=404, detail="Version not found")
    return updated_version


@router.delete("/{algorithm_id}/versions/{version_id}", response_model=dict)
async def delete_version(
    algorithm_id: str,
    version_id: str,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Delete one version of an algorithm (admin only).

    Raises:
        HTTPException: 403 for non-admins; 404 if the algorithm or version is
            missing; 400 if the version belongs to another algorithm or the
            service reports the delete failed.
    """
    _require_admin(current_user)
    _get_version_of_algorithm(db, algorithm_id, version_id)
    if not AlgorithmVersionService.delete_version(db, version_id):
        raise HTTPException(status_code=400, detail="Failed to delete version")
    return {"message": "Version deleted successfully"}
# Code execution route.
@router.post("/execute-code")
async def execute_code(
    code: str = Body(..., description="Python代码"),
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Run a Python snippet via ``AlgorithmCallService.execute_python_code``.

    ``current_user`` and ``db`` are not read in the body. The user dependency
    presumably enforces authentication — confirm in ``app.dependencies``.

    NOTE(review): SECURITY — the request body is arbitrary code. Unlike the
    create/update/delete endpoints in this module, there is no admin check
    here. Whether this is safe depends entirely on how
    ``execute_python_code`` isolates execution, which is not visible from this
    file. Confirm that it is sandboxed, or restrict this endpoint.
    """
    return AlgorithmCallService.execute_python_code(code)
# Model file upload route.
@router.post("/upload-model")
async def upload_model(
    file: UploadFile = File(..., description="模型文件"),
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Upload a model file to object storage under ``models/``.

    The filename must end (case-insensitively) with one of the allowed
    suffixes. Suffixes are matched against the end of the name rather than
    taken from ``os.path.splitext``. That function returns only the last
    suffix, so it made ``.tar.gz`` uploads impossible: it yields ``.gz``.
    The stored object gets a random name that keeps the matched suffix.

    ``current_user`` and ``db`` are not read. The user dependency presumably
    enforces authentication — confirm in ``app.dependencies``.

    Returns:
        ``{"success": True, "file_path": <object name>, "message": ...}`` on
        success, otherwise ``{"success": False, "message": ...}``. Failures are
        reported in the response body, not as HTTP error statuses.
    """
    import io

    # A tuple keeps the order shown in the error message stable; a set of
    # strings iterates in an order that varies between processes.
    allowed_extensions = (
        ".pt", ".pth", ".h5", ".hdf5", ".onnx", ".pb", ".tflite",
        ".joblib", ".pkl", ".zip", ".tar.gz",
    )
    # UploadFile.filename may be None; treat that as "no extension" (rejected).
    filename = (file.filename or "").lower()
    # No allowed suffix is itself a suffix of another, so at most one matches.
    file_extension = next(
        (ext for ext in allowed_extensions if filename.endswith(ext)), None
    )
    if file_extension is None:
        return {
            "success": False,
            "message": f"不支持的文件类型,支持的类型:{', '.join(allowed_extensions)}"
        }
    try:
        # Random object name under models/, keeping the validated lower-case suffix.
        unique_filename = f"models/{uuid.uuid4().hex[:8]}{file_extension}"
        # NOTE(review): the whole upload is buffered in memory before storing.
        file_content = await file.read()
        success = file_storage.upload_fileobj(
            io.BytesIO(file_content), unique_filename, file.content_type
        )
    except Exception as e:
        # Deliberate best-effort: report read/storage errors in the body.
        return {
            "success": False,
            "message": f"模型文件上传失败:{str(e)}"
        }
    if success:
        return {
            "success": True,
            "file_path": unique_filename,
            "message": "模型文件上传成功"
        }
    return {
        "success": False,
        "message": "模型文件上传失败"
    }
# Video file upload route.
@router.post("/upload-video")
async def upload_video(
    file: UploadFile = File(..., description="视频文件"),
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Upload a video file to object storage under ``videos/<user id>/``.

    The file extension (last suffix, lower-cased) must be in the allow-list.
    The stored object gets a random name that keeps that extension. ``db`` is
    not read.

    Returns:
        ``{"success": True, "file_path": <object name>, "message": ...}`` on
        success, otherwise ``{"success": False, "message": ...}``. Failures are
        reported in the response body, not as HTTP error statuses.
    """
    import io

    # A tuple keeps the order shown in the error message stable; a set of
    # strings iterates in an order that varies between processes.
    allowed_extensions = (
        ".mp4", ".avi", ".mov", ".wmv", ".flv", ".mkv", ".webm",
    )
    # UploadFile.filename may be None; splitext("") gives "" and is rejected
    # below instead of raising TypeError.
    file_extension = os.path.splitext(file.filename or "")[1].lower()
    if file_extension not in allowed_extensions:
        return {
            "success": False,
            "message": f"不支持的视频文件类型,支持的类型:{', '.join(allowed_extensions)}"
        }
    try:
        # Objects are grouped per uploader. current_user is read by attribute
        # (.id / .role) in every other handler of this module, so it is read
        # the same way here. Indexing it like a dict would raise inside this
        # try and be silently turned into a "failed" response.
        unique_filename = f"videos/{current_user.id}/{uuid.uuid4().hex[:12]}{file_extension}"
        # NOTE(review): the whole upload is buffered in memory before storing.
        file_content = await file.read()
        success = file_storage.upload_fileobj(
            io.BytesIO(file_content), unique_filename, file.content_type
        )
    except Exception as e:
        # Deliberate best-effort: report read/storage errors in the body.
        return {
            "success": False,
            "message": f"视频文件上传失败:{str(e)}"
        }
    if success:
        return {
            "success": True,
            "file_path": unique_filename,
            "message": "视频文件上传成功"
        }
    return {
        "success": False,
        "message": "视频文件上传失败"
    }