first commit
This commit is contained in:
392
backend/app/routes/algorithm.py
Normal file
392
backend/app/routes/algorithm.py
Normal file
@@ -0,0 +1,392 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Body, UploadFile, File
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import os
|
||||
import uuid
|
||||
from app.utils.file import file_storage
|
||||
|
||||
from app.models.database import get_db
|
||||
from app.schemas.algorithm import AlgorithmCreate, AlgorithmUpdate, AlgorithmResponse, AlgorithmListResponse, AlgorithmVersionCreate, AlgorithmVersionUpdate, AlgorithmVersionResponse, AlgorithmCallCreate, AlgorithmCallResult
|
||||
from app.models.models import AlgorithmCall
|
||||
from app.services.algorithm import AlgorithmService, AlgorithmVersionService, AlgorithmCallService
|
||||
from app.dependencies import get_current_active_user
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/algorithms", tags=["algorithms"])
|
||||
|
||||
|
||||
@router.post("", response_model=AlgorithmResponse)
|
||||
async def create_algorithm(
|
||||
algorithm: AlgorithmCreate,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""创建算法"""
|
||||
# 只有管理员可以创建算法
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
# 创建算法
|
||||
db_algorithm = AlgorithmService.create_algorithm(db, algorithm)
|
||||
|
||||
return db_algorithm
|
||||
|
||||
|
||||
@router.get("", response_model=AlgorithmListResponse)
|
||||
async def get_algorithms(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
type: Optional[str] = None,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取算法列表"""
|
||||
algorithms = AlgorithmService.get_algorithms(db, skip=skip, limit=limit, algorithm_type=type)
|
||||
return {"algorithms": algorithms, "total": len(algorithms)}
|
||||
|
||||
|
||||
@router.get("/{algorithm_id}", response_model=AlgorithmResponse)
|
||||
async def get_algorithm(
|
||||
algorithm_id: str,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取算法详情"""
|
||||
algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
|
||||
if not algorithm:
|
||||
raise HTTPException(status_code=404, detail="Algorithm not found")
|
||||
|
||||
return algorithm
|
||||
|
||||
|
||||
@router.put("/{algorithm_id}", response_model=AlgorithmResponse)
|
||||
async def update_algorithm(
|
||||
algorithm_id: str,
|
||||
algorithm_update: AlgorithmUpdate,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""更新算法"""
|
||||
# 只有管理员可以更新算法
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
algorithm = AlgorithmService.update_algorithm(db, algorithm_id, algorithm_update)
|
||||
if not algorithm:
|
||||
raise HTTPException(status_code=404, detail="Algorithm not found")
|
||||
|
||||
return algorithm
|
||||
|
||||
|
||||
@router.delete("/{algorithm_id}", response_model=dict)
|
||||
async def delete_algorithm(
|
||||
algorithm_id: str,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""删除算法"""
|
||||
# 只有管理员可以删除算法
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
success = AlgorithmService.delete_algorithm(db, algorithm_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Algorithm not found")
|
||||
|
||||
return {"message": "Algorithm deleted successfully"}
|
||||
|
||||
|
||||
# 算法版本相关路由
|
||||
@router.post("/{algorithm_id}/versions", response_model=AlgorithmVersionResponse)
|
||||
async def create_version(
|
||||
algorithm_id: str,
|
||||
version: AlgorithmVersionCreate,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""创建算法版本"""
|
||||
# 只有管理员可以创建算法版本
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
# 验证算法是否存在
|
||||
if not AlgorithmService.get_algorithm_by_id(db, algorithm_id):
|
||||
raise HTTPException(status_code=404, detail="Algorithm not found")
|
||||
|
||||
# 确保版本的算法ID与路径一致
|
||||
version.algorithm_id = algorithm_id
|
||||
|
||||
# 创建版本
|
||||
db_version = AlgorithmVersionService.create_version(db, version)
|
||||
|
||||
return db_version
|
||||
|
||||
|
||||
@router.get("/{algorithm_id}/versions", response_model=List[AlgorithmVersionResponse])
|
||||
async def get_versions(
|
||||
algorithm_id: str,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取算法版本列表"""
|
||||
# 验证算法是否存在
|
||||
if not AlgorithmService.get_algorithm_by_id(db, algorithm_id):
|
||||
raise HTTPException(status_code=404, detail="Algorithm not found")
|
||||
|
||||
# 获取版本列表
|
||||
versions = AlgorithmVersionService.get_versions_by_algorithm_id(db, algorithm_id)
|
||||
|
||||
return versions
|
||||
|
||||
|
||||
@router.get("/{algorithm_id}/versions/{version_id}", response_model=AlgorithmVersionResponse)
|
||||
async def get_version(
|
||||
algorithm_id: str,
|
||||
version_id: str,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取算法版本详情"""
|
||||
# 验证算法是否存在
|
||||
if not AlgorithmService.get_algorithm_by_id(db, algorithm_id):
|
||||
raise HTTPException(status_code=404, detail="Algorithm not found")
|
||||
|
||||
# 获取版本
|
||||
version = AlgorithmVersionService.get_version_by_id(db, version_id)
|
||||
if not version:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
|
||||
# 验证版本是否属于该算法
|
||||
if version.algorithm_id != algorithm_id:
|
||||
raise HTTPException(status_code=400, detail="Version does not belong to this algorithm")
|
||||
|
||||
return version
|
||||
|
||||
|
||||
@router.put("/{algorithm_id}/versions/{version_id}", response_model=AlgorithmVersionResponse)
|
||||
async def update_version(
|
||||
algorithm_id: str,
|
||||
version_id: str,
|
||||
version_update: AlgorithmVersionUpdate,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""更新算法版本"""
|
||||
# 只有管理员可以更新算法版本
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
# 验证算法是否存在
|
||||
if not AlgorithmService.get_algorithm_by_id(db, algorithm_id):
|
||||
raise HTTPException(status_code=404, detail="Algorithm not found")
|
||||
|
||||
# 获取版本
|
||||
version = AlgorithmVersionService.get_version_by_id(db, version_id)
|
||||
if not version:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
|
||||
# 验证版本是否属于该算法
|
||||
if version.algorithm_id != algorithm_id:
|
||||
raise HTTPException(status_code=400, detail="Version does not belong to this algorithm")
|
||||
|
||||
# 更新版本
|
||||
updated_version = AlgorithmVersionService.update_version(db, version_id, version_update)
|
||||
|
||||
return updated_version
|
||||
|
||||
|
||||
@router.delete("/{algorithm_id}/versions/{version_id}", response_model=dict)
|
||||
async def delete_version(
|
||||
algorithm_id: str,
|
||||
version_id: str,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""删除算法版本"""
|
||||
# 只有管理员可以删除算法版本
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
# 验证算法是否存在
|
||||
if not AlgorithmService.get_algorithm_by_id(db, algorithm_id):
|
||||
raise HTTPException(status_code=404, detail="Algorithm not found")
|
||||
|
||||
# 获取版本
|
||||
version = AlgorithmVersionService.get_version_by_id(db, version_id)
|
||||
if not version:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
|
||||
# 验证版本是否属于该算法
|
||||
if version.algorithm_id != algorithm_id:
|
||||
raise HTTPException(status_code=400, detail="Version does not belong to this algorithm")
|
||||
|
||||
# 删除版本
|
||||
success = AlgorithmVersionService.delete_version(db, version_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to delete version")
|
||||
|
||||
return {"message": "Version deleted successfully"}
|
||||
|
||||
|
||||
# 算法调用相关路由
|
||||
@router.post("/call", response_model=AlgorithmCallResult)
|
||||
async def call_algorithm(
|
||||
call: AlgorithmCallCreate,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""调用算法"""
|
||||
# 执行算法
|
||||
result = AlgorithmCallService.execute_algorithm(db, current_user.id, call)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/calls/{call_id}", response_model=AlgorithmCallResult)
|
||||
async def get_call_result(
|
||||
call_id: str,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取算法调用结果"""
|
||||
# 获取调用记录
|
||||
call = AlgorithmCallService.get_call_by_id(db, call_id)
|
||||
if not call:
|
||||
raise HTTPException(status_code=404, detail="Call not found")
|
||||
|
||||
# 管理员可以查看所有调用记录,普通用户只能查看自己的
|
||||
if current_user.role != "admin" and current_user.id != call.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
return call
|
||||
|
||||
|
||||
@router.get("/calls", response_model=List[AlgorithmCallResult])
|
||||
async def get_call_history(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取调用历史"""
|
||||
# 管理员可以查看所有调用记录,普通用户只能查看自己的
|
||||
if current_user.role == "admin":
|
||||
# 这里可以添加分页和过滤,暂时返回所有
|
||||
calls = db.query(AlgorithmCall).offset(skip).limit(limit).all()
|
||||
else:
|
||||
calls = AlgorithmCallService.get_calls_by_user_id(db, current_user.id, skip=skip, limit=limit)
|
||||
|
||||
return calls
|
||||
|
||||
|
||||
# 代码执行相关路由
|
||||
@router.post("/execute-code")
|
||||
async def execute_code(
|
||||
code: str = Body(..., description="Python代码"),
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""执行Python代码"""
|
||||
# 执行代码
|
||||
result = AlgorithmCallService.execute_python_code(code)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 模型文件上传路由
|
||||
@router.post("/upload-model")
|
||||
async def upload_model(
|
||||
file: UploadFile = File(..., description="模型文件"),
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""上传模型文件"""
|
||||
# 支持的文件类型
|
||||
allowed_extensions = {
|
||||
".pt", ".pth", ".h5", ".hdf5", ".onnx", ".pb", ".tflite",
|
||||
".joblib", ".pkl", ".zip", ".tar.gz"
|
||||
}
|
||||
|
||||
# 验证文件类型
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
if file_extension not in allowed_extensions:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"不支持的文件类型,支持的类型:{', '.join(allowed_extensions)}"
|
||||
}
|
||||
|
||||
try:
|
||||
# 生成唯一的文件名
|
||||
unique_filename = f"models/{uuid.uuid4().hex[:8]}{file_extension}"
|
||||
|
||||
# 读取文件内容
|
||||
file_content = await file.read()
|
||||
|
||||
# 上传文件到MinIO
|
||||
import io
|
||||
file_obj = io.BytesIO(file_content)
|
||||
success = file_storage.upload_fileobj(file_obj, unique_filename, file.content_type)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"file_path": unique_filename,
|
||||
"message": "模型文件上传成功"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "模型文件上传失败"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"模型文件上传失败:{str(e)}"
|
||||
}
|
||||
|
||||
|
||||
# 视频文件上传路由
|
||||
@router.post("/upload-video")
|
||||
async def upload_video(
|
||||
file: UploadFile = File(..., description="视频文件"),
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""上传视频文件"""
|
||||
# 支持的视频文件类型
|
||||
allowed_extensions = {
|
||||
".mp4", ".avi", ".mov", ".wmv", ".flv", ".mkv", ".webm"
|
||||
}
|
||||
|
||||
# 验证文件类型
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
if file_extension not in allowed_extensions:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"不支持的视频文件类型,支持的类型:{', '.join(allowed_extensions)}"
|
||||
}
|
||||
|
||||
try:
|
||||
# 生成唯一的文件名
|
||||
unique_filename = f"videos/{current_user['id']}/{uuid.uuid4().hex[:12]}{file_extension}"
|
||||
|
||||
# 读取文件内容
|
||||
file_content = await file.read()
|
||||
|
||||
# 上传文件到MinIO
|
||||
import io
|
||||
file_obj = io.BytesIO(file_content)
|
||||
success = file_storage.upload_fileobj(file_obj, unique_filename, file.content_type)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"file_path": unique_filename,
|
||||
"message": "视频文件上传成功"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "视频文件上传失败"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"视频文件上传失败:{str(e)}"
|
||||
}
|
||||
Reference in New Issue
Block a user