393 lines
13 KiB
Python
393 lines
13 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, status, Body, UploadFile, File
|
|
from sqlalchemy.orm import Session
|
|
from typing import List, Optional
|
|
import os
|
|
import uuid
|
|
from app.utils.file import file_storage
|
|
|
|
from app.models.database import get_db
|
|
from app.schemas.algorithm import AlgorithmCreate, AlgorithmUpdate, AlgorithmResponse, AlgorithmListResponse, AlgorithmVersionCreate, AlgorithmVersionUpdate, AlgorithmVersionResponse, AlgorithmCallCreate, AlgorithmCallResult
|
|
from app.models.models import AlgorithmCall
|
|
from app.services.algorithm import AlgorithmService, AlgorithmVersionService, AlgorithmCallService
|
|
from app.dependencies import get_current_active_user
|
|
|
|
# Router that groups every algorithm-related endpoint under the /algorithms prefix.
router = APIRouter(
    tags=["algorithms"],
    prefix="/algorithms",
)
|
|
|
|
|
|
@router.post("", response_model=AlgorithmResponse)
async def create_algorithm(
    algorithm: AlgorithmCreate,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Create a new algorithm (admin only).

    Raises:
        HTTPException: 403 when the caller's role is not "admin".
    """
    caller_is_admin = current_user.role == "admin"
    if not caller_is_admin:
        raise HTTPException(status_code=403, detail="Not enough permissions")

    # Creation is delegated to the service layer; its result is the response.
    return AlgorithmService.create_algorithm(db, algorithm)
|
|
|
|
|
|
@router.get("", response_model=AlgorithmListResponse)
async def get_algorithms(
    skip: int = 0,
    limit: int = 100,
    type: Optional[str] = None,
    db: Session = Depends(get_db)
):
    """Return one page of algorithms, optionally filtered by type.

    The ``total`` field is the length of the list returned for this page
    (``len`` of the service result), not a count over all stored rows.
    """
    page = AlgorithmService.get_algorithms(
        db,
        skip=skip,
        limit=limit,
        algorithm_type=type,
    )
    return {
        "algorithms": page,
        "total": len(page),
    }
|
|
|
|
|
|
@router.get("/{algorithm_id}", response_model=AlgorithmResponse)
async def get_algorithm(
    algorithm_id: str,
    db: Session = Depends(get_db)
):
    """Return a single algorithm by its id.

    Raises:
        HTTPException: 404 when the service returns a falsy result.
    """
    found = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
    if found:
        return found

    raise HTTPException(status_code=404, detail="Algorithm not found")
|
|
|
|
|
|
@router.put("/{algorithm_id}", response_model=AlgorithmResponse)
async def update_algorithm(
    algorithm_id: str,
    algorithm_update: AlgorithmUpdate,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Update an existing algorithm (admin only).

    Raises:
        HTTPException: 403 when the caller's role is not "admin";
            404 when the service returns a falsy result for the update.
    """
    caller_is_admin = current_user.role == "admin"
    if not caller_is_admin:
        raise HTTPException(status_code=403, detail="Not enough permissions")

    updated = AlgorithmService.update_algorithm(db, algorithm_id, algorithm_update)
    if updated:
        return updated

    raise HTTPException(status_code=404, detail="Algorithm not found")
|
|
|
|
|
|
@router.delete("/{algorithm_id}", response_model=dict)
async def delete_algorithm(
    algorithm_id: str,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Delete an algorithm (admin only).

    Raises:
        HTTPException: 403 when the caller's role is not "admin";
            404 when the service reports that nothing was deleted.
    """
    caller_is_admin = current_user.role == "admin"
    if not caller_is_admin:
        raise HTTPException(status_code=403, detail="Not enough permissions")

    # A falsy result from the service is reported to the client as "not found".
    if not AlgorithmService.delete_algorithm(db, algorithm_id):
        raise HTTPException(status_code=404, detail="Algorithm not found")

    return {"message": "Algorithm deleted successfully"}
|
|
|
|
|
|
# 算法版本相关路由
|
|
@router.post("/{algorithm_id}/versions", response_model=AlgorithmVersionResponse)
async def create_version(
    algorithm_id: str,
    version: AlgorithmVersionCreate,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Create a new version under an existing algorithm (admin only).

    Raises:
        HTTPException: 403 when the caller's role is not "admin";
            404 when the parent algorithm is not found.
    """
    caller_is_admin = current_user.role == "admin"
    if not caller_is_admin:
        raise HTTPException(status_code=403, detail="Not enough permissions")

    parent = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
    if not parent:
        raise HTTPException(status_code=404, detail="Algorithm not found")

    # The path parameter wins over whatever algorithm_id the payload carried.
    version.algorithm_id = algorithm_id

    return AlgorithmVersionService.create_version(db, version)
|
|
|
|
|
|
@router.get("/{algorithm_id}/versions", response_model=List[AlgorithmVersionResponse])
async def get_versions(
    algorithm_id: str,
    db: Session = Depends(get_db)
):
    """List the versions of one algorithm.

    Raises:
        HTTPException: 404 when the algorithm is not found.
    """
    parent = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
    if not parent:
        raise HTTPException(status_code=404, detail="Algorithm not found")

    return AlgorithmVersionService.get_versions_by_algorithm_id(db, algorithm_id)
|
|
|
|
|
|
@router.get("/{algorithm_id}/versions/{version_id}", response_model=AlgorithmVersionResponse)
async def get_version(
    algorithm_id: str,
    version_id: str,
    db: Session = Depends(get_db)
):
    """Return one version of an algorithm.

    Raises:
        HTTPException: 404 when the algorithm or the version is not found;
            400 when the version's algorithm_id differs from the path's.
    """
    parent = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
    if not parent:
        raise HTTPException(status_code=404, detail="Algorithm not found")

    found = AlgorithmVersionService.get_version_by_id(db, version_id)
    if not found:
        raise HTTPException(status_code=404, detail="Version not found")

    # Reject a version id that exists but hangs off a different algorithm.
    if found.algorithm_id != algorithm_id:
        raise HTTPException(
            status_code=400,
            detail="Version does not belong to this algorithm",
        )

    return found
|
|
|
|
|
|
@router.put("/{algorithm_id}/versions/{version_id}", response_model=AlgorithmVersionResponse)
async def update_version(
    algorithm_id: str,
    version_id: str,
    version_update: AlgorithmVersionUpdate,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Update one version of an algorithm (admin only).

    Raises:
        HTTPException: 403 when the caller's role is not "admin";
            404 when the algorithm or the version is not found;
            400 when the version's algorithm_id differs from the path's.
    """
    caller_is_admin = current_user.role == "admin"
    if not caller_is_admin:
        raise HTTPException(status_code=403, detail="Not enough permissions")

    parent = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
    if not parent:
        raise HTTPException(status_code=404, detail="Algorithm not found")

    existing = AlgorithmVersionService.get_version_by_id(db, version_id)
    if not existing:
        raise HTTPException(status_code=404, detail="Version not found")

    # Reject a version id that exists but hangs off a different algorithm.
    if existing.algorithm_id != algorithm_id:
        raise HTTPException(
            status_code=400,
            detail="Version does not belong to this algorithm",
        )

    # The service result is returned as-is, without a further falsy check.
    return AlgorithmVersionService.update_version(db, version_id, version_update)
|
|
|
|
|
|
@router.delete("/{algorithm_id}/versions/{version_id}", response_model=dict)
async def delete_version(
    algorithm_id: str,
    version_id: str,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Delete one version of an algorithm (admin only).

    Raises:
        HTTPException: 403 when the caller's role is not "admin";
            404 when the algorithm or the version is not found;
            400 when the version's algorithm_id differs from the path's,
            or when the service reports that the deletion failed.
    """
    caller_is_admin = current_user.role == "admin"
    if not caller_is_admin:
        raise HTTPException(status_code=403, detail="Not enough permissions")

    parent = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
    if not parent:
        raise HTTPException(status_code=404, detail="Algorithm not found")

    existing = AlgorithmVersionService.get_version_by_id(db, version_id)
    if not existing:
        raise HTTPException(status_code=404, detail="Version not found")

    # Reject a version id that exists but hangs off a different algorithm.
    if existing.algorithm_id != algorithm_id:
        raise HTTPException(
            status_code=400,
            detail="Version does not belong to this algorithm",
        )

    deleted = AlgorithmVersionService.delete_version(db, version_id)
    if not deleted:
        raise HTTPException(status_code=400, detail="Failed to delete version")

    return {"message": "Version deleted successfully"}
|
|
|
|
|
|
# 算法调用相关路由
|
|
@router.post("/call", response_model=AlgorithmCallResult)
async def call_algorithm(
    call: AlgorithmCallCreate,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Run an algorithm on behalf of the authenticated user.

    The call is handed to ``AlgorithmCallService.execute_algorithm`` together
    with the caller's id, and the service result is returned unchanged.
    """
    caller_id = current_user.id
    return AlgorithmCallService.execute_algorithm(db, caller_id, call)
|
|
|
|
|
|
@router.get("/calls/{call_id}", response_model=AlgorithmCallResult)
async def get_call_result(
    call_id: str,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Return the stored result of one algorithm call.

    Admins may read any call record; other users only their own.

    Raises:
        HTTPException: 404 when the call record is not found;
            403 when a non-admin asks for someone else's record.
    """
    record = AlgorithmCallService.get_call_by_id(db, call_id)
    if not record:
        raise HTTPException(status_code=404, detail="Call not found")

    # Ownership is only compared when the caller is not an admin.
    may_view = current_user.role == "admin" or current_user.id == record.user_id
    if not may_view:
        raise HTTPException(status_code=403, detail="Not enough permissions")

    return record
|
|
|
|
|
|
@router.get("/calls", response_model=List[AlgorithmCallResult])
async def get_call_history(
    skip: int = 0,
    limit: int = 100,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Return a page of algorithm call records.

    Admins get records across all users; other users get only their own.
    Both branches apply ``skip``/``limit``.
    """
    # NOTE(review): this route is registered after GET "/{algorithm_id}" in this
    # file. FastAPI matches routes in registration order, so GET /calls is
    # probably captured by that handler with algorithm_id="calls" -- confirm,
    # and if so register this route before the parameterised one.
    if current_user.role != "admin":
        return AlgorithmCallService.get_calls_by_user_id(
            db,
            current_user.id,
            skip=skip,
            limit=limit,
        )

    # Admin path queries the call table directly, with no per-user filter.
    all_calls = db.query(AlgorithmCall)
    return all_calls.offset(skip).limit(limit).all()
|
|
|
|
|
|
# 代码执行相关路由
|
|
@router.post("/execute-code")
async def execute_code(
    code: str = Body(..., description="Python代码"),
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Execute caller-supplied Python code and return the service's result.

    The request body is passed unchanged to
    ``AlgorithmCallService.execute_python_code``.
    """
    # NOTE(review): security -- any authenticated user (no role check here) can
    # submit arbitrary Python for execution. Whether the service sandboxes the
    # code is not visible from this module; confirm before exposing this route.
    outcome = AlgorithmCallService.execute_python_code(code)
    return outcome
|
|
|
|
|
|
# 模型文件上传路由
|
|
@router.post("/upload-model")
async def upload_model(
    file: UploadFile = File(..., description="模型文件"),
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Upload a model file through ``file_storage``.

    The file is accepted only when its name ends with one of the allowed
    model extensions. It is stored under ``models/`` with a random 8-hex-digit
    name that keeps the matched extension.

    Returns:
        A dict with ``success`` and ``message``; on success it also carries
        ``file_path`` (the generated storage key). Failures are reported in
        the body rather than raised.
    """
    import io

    # A tuple keeps the order stable so the error message is deterministic.
    allowed_extensions = (
        ".pt", ".pth", ".h5", ".hdf5", ".onnx", ".pb", ".tflite",
        ".joblib", ".pkl", ".zip", ".tar.gz",
    )

    # UploadFile.filename can be None; treat that as "no extension".
    lowered_name = (file.filename or "").lower()

    # Match by suffix instead of os.path.splitext: splitext("m.tar.gz") yields
    # ".gz", so the ".tar.gz" entry above could never be accepted that way.
    file_extension = next(
        (ext for ext in allowed_extensions if lowered_name.endswith(ext)),
        None,
    )
    if file_extension is None:
        return {
            "success": False,
            "message": f"不支持的文件类型,支持的类型:{', '.join(allowed_extensions)}"
        }

    try:
        # Random storage key; the original filename is not reused.
        unique_filename = f"models/{uuid.uuid4().hex[:8]}{file_extension}"

        # The whole upload is read into memory and handed over as a file object.
        file_content = await file.read()
        file_obj = io.BytesIO(file_content)
        success = file_storage.upload_fileobj(file_obj, unique_filename, file.content_type)
    except Exception as e:
        # Deliberate best-effort: storage errors become a failure payload.
        return {
            "success": False,
            "message": f"模型文件上传失败:{str(e)}"
        }

    if not success:
        return {
            "success": False,
            "message": "模型文件上传失败"
        }

    return {
        "success": True,
        "file_path": unique_filename,
        "message": "模型文件上传成功"
    }
|
|
|
|
|
|
# 视频文件上传路由
|
|
@router.post("/upload-video")
async def upload_video(
    file: UploadFile = File(..., description="视频文件"),
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Upload a video file through ``file_storage``.

    The file is accepted only when its extension is one of the allowed video
    types. It is stored under ``videos/<user id>/`` with a random
    12-hex-digit name that keeps the original extension.

    Returns:
        A dict with ``success`` and ``message``; on success it also carries
        ``file_path`` (the generated storage key). Failures are reported in
        the body rather than raised.
    """
    import io

    # A tuple keeps the order stable so the error message is deterministic.
    allowed_extensions = (
        ".mp4", ".avi", ".mov", ".wmv", ".flv", ".mkv", ".webm",
    )

    # UploadFile.filename can be None; splitext(None) would raise before the
    # try block below and surface as an unhandled server error.
    file_extension = os.path.splitext(file.filename or "")[1].lower()
    if file_extension not in allowed_extensions:
        return {
            "success": False,
            "message": f"不支持的视频文件类型,支持的类型:{', '.join(allowed_extensions)}"
        }

    try:
        # Every other handler in this module reads the user via attribute
        # access (current_user.id / .role). Subscripting such an object raises
        # TypeError, which the broad except below would turn into a permanent
        # "upload failed". A real dict is still accepted for compatibility.
        if isinstance(current_user, dict):
            user_id = current_user["id"]
        else:
            user_id = current_user.id

        # Random storage key, namespaced by the uploading user.
        unique_filename = f"videos/{user_id}/{uuid.uuid4().hex[:12]}{file_extension}"

        # The whole upload is read into memory and handed over as a file object.
        file_content = await file.read()
        file_obj = io.BytesIO(file_content)
        success = file_storage.upload_fileobj(file_obj, unique_filename, file.content_type)
    except Exception as e:
        # Deliberate best-effort: storage errors become a failure payload.
        return {
            "success": False,
            "message": f"视频文件上传失败:{str(e)}"
        }

    if not success:
        return {
            "success": False,
            "message": "视频文件上传失败"
        }

    return {
        "success": True,
        "file_path": unique_filename,
        "message": "视频文件上传成功"
    }
|