"""HTTP routes for algorithms, algorithm versions, algorithm calls and uploads.

All routes are mounted under the ``/algorithms`` prefix.

Route registration order matters: FastAPI matches path operations in the
order they are declared, so every fixed path (``/calls``, ``/call``,
``/execute-code``, ...) is declared *before* the parameterised
``/{algorithm_id}`` routes. Otherwise ``GET /algorithms/calls`` would be
captured by ``GET /algorithms/{algorithm_id}`` with ``algorithm_id="calls"``.
"""

import io
import logging
import os
import uuid
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, status, Body, UploadFile, File
from sqlalchemy.orm import Session

from app.utils.file import file_storage
from app.models.database import get_db
from app.schemas.algorithm import (
    AlgorithmCreate,
    AlgorithmUpdate,
    AlgorithmResponse,
    AlgorithmListResponse,
    AlgorithmVersionCreate,
    AlgorithmVersionUpdate,
    AlgorithmVersionResponse,
    AlgorithmCallCreate,
    AlgorithmCallResult,
)
from app.models.models import AlgorithmCall
from app.services.algorithm import (
    AlgorithmService,
    AlgorithmVersionService,
    AlgorithmCallService,
)
from app.dependencies import get_current_active_user

logger = logging.getLogger(__name__)

# Create the router
router = APIRouter(prefix="/algorithms", tags=["algorithms"])

# File extensions accepted by the upload endpoints (compared lower-cased).
MODEL_EXTENSIONS = frozenset({
    ".pt", ".pth", ".h5", ".hdf5", ".onnx", ".pb",
    ".tflite", ".joblib", ".pkl", ".zip", ".tar.gz",
})
VIDEO_EXTENSIONS = frozenset({
    ".mp4", ".avi", ".mov", ".wmv", ".flv", ".mkv", ".webm",
})


# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------

def _require_admin(current_user) -> None:
    """Raise HTTP 403 unless ``current_user.role`` is ``"admin"``."""
    if current_user.role != "admin":
        raise HTTPException(status_code=403, detail="Not enough permissions")


def _require_algorithm(db: Session, algorithm_id: str) -> None:
    """Raise HTTP 404 if no algorithm with ``algorithm_id`` is found."""
    if not AlgorithmService.get_algorithm_by_id(db, algorithm_id):
        raise HTTPException(status_code=404, detail="Algorithm not found")


def _get_version_of_algorithm(db: Session, algorithm_id: str, version_id: str):
    """Return the version ``version_id`` after checking it belongs to ``algorithm_id``.

    Raises:
        HTTPException: 404 if the algorithm or the version is not found,
            400 if the version's ``algorithm_id`` differs from the path value.
    """
    _require_algorithm(db, algorithm_id)
    version = AlgorithmVersionService.get_version_by_id(db, version_id)
    if not version:
        raise HTTPException(status_code=404, detail="Version not found")
    if version.algorithm_id != algorithm_id:
        raise HTTPException(
            status_code=400,
            detail="Version does not belong to this algorithm",
        )
    return version


def _list_algorithms(db: Session, skip: int, limit: int, algorithm_type: Optional[str]) -> dict:
    """Fetch one page of algorithms and wrap it in the list-response shape."""
    algorithms = AlgorithmService.get_algorithms(
        db, skip=skip, limit=limit, algorithm_type=algorithm_type
    )
    # NOTE(review): "total" is the length of the returned page, not the number
    # of matching rows overall -- confirm this is what clients expect.
    return {"algorithms": algorithms, "total": len(algorithms)}


def _match_extension(filename: Optional[str], allowed) -> Optional[str]:
    """Return the entry of ``allowed`` that ``filename`` ends with, or ``None``.

    Suffix matching (instead of ``os.path.splitext``) is required so that
    multi-part extensions such as ``.tar.gz`` can match; ``splitext`` would
    only ever yield ``.gz``. Longer extensions are tried first, and a name
    consisting of nothing but the extension is rejected.
    """
    lowered = (filename or "").lower()
    for extension in sorted(allowed, key=len, reverse=True):
        if len(lowered) > len(extension) and lowered.endswith(extension):
            return extension
    return None


async def _store_upload(file: UploadFile, object_name: str):
    """Read the whole upload into memory and hand it to ``file_storage``.

    Returns whatever ``file_storage.upload_fileobj`` returns (callers treat it
    as a truthy/falsy success flag).
    """
    content = await file.read()
    return file_storage.upload_fileobj(
        io.BytesIO(content), object_name, file.content_type
    )


# ---------------------------------------------------------------------------
# Algorithm collection routes
# ---------------------------------------------------------------------------

@router.post("", response_model=AlgorithmResponse)
async def create_algorithm(
    algorithm: AlgorithmCreate,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Create an algorithm (admin only)."""
    _require_admin(current_user)
    return AlgorithmService.create_algorithm(db, algorithm)


@router.get("", response_model=AlgorithmListResponse)
async def get_algorithms_no_slash(
    skip: int = 0,
    limit: int = 100,
    type: Optional[str] = None,  # query-parameter name; kept for API compatibility
    db: Session = Depends(get_db)
):
    """List algorithms (path without trailing slash)."""
    return _list_algorithms(db, skip, limit, type)


@router.get("/", response_model=AlgorithmListResponse)
async def get_algorithms(
    skip: int = 0,
    limit: int = 100,
    type: Optional[str] = None,  # query-parameter name; kept for API compatibility
    db: Session = Depends(get_db)
):
    """List algorithms."""
    return _list_algorithms(db, skip, limit, type)


# ---------------------------------------------------------------------------
# Algorithm call routes (fixed paths: must precede "/{algorithm_id}")
# ---------------------------------------------------------------------------

@router.post("/call/public", response_model=AlgorithmCallResult)
async def call_algorithm_public(
    call: AlgorithmCallCreate,
    db: Session = Depends(get_db)
):
    """Call an algorithm without authentication (used by the demo page)."""
    # NOTE(review): this endpoint is unauthenticated; the call is attributed
    # to the fixed user id "anonymous".
    anonymous_user_id = "anonymous"
    return AlgorithmCallService.execute_algorithm(db, anonymous_user_id, call)


@router.post("/call", response_model=AlgorithmCallResult)
async def call_algorithm(
    call: AlgorithmCallCreate,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Call an algorithm on behalf of the authenticated user."""
    return AlgorithmCallService.execute_algorithm(db, current_user.id, call)


@router.get("/calls", response_model=List[AlgorithmCallResult])
async def get_call_history(
    skip: int = 0,
    limit: int = 100,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Return the call history.

    Admins get a page of all call records; other users only get their own.
    """
    if current_user.role == "admin":
        # Filtering could be added here; for now only skip/limit paging applies.
        return db.query(AlgorithmCall).offset(skip).limit(limit).all()
    return AlgorithmCallService.get_calls_by_user_id(
        db, current_user.id, skip=skip, limit=limit
    )


@router.get("/calls/{call_id}", response_model=AlgorithmCallResult)
async def get_call_result(
    call_id: str,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Return a single call record.

    Raises:
        HTTPException: 404 if the call is not found, 403 if the caller is
            neither an admin nor the user recorded on the call.
    """
    call = AlgorithmCallService.get_call_by_id(db, call_id)
    if not call:
        raise HTTPException(status_code=404, detail="Call not found")

    # Admins may view every call record; other users only their own.
    if current_user.role != "admin" and current_user.id != call.user_id:
        raise HTTPException(status_code=403, detail="Not enough permissions")

    return call


# ---------------------------------------------------------------------------
# Code execution route
# ---------------------------------------------------------------------------

@router.post("/execute-code")
async def execute_code(
    code: str = Body(..., description="Python代码"),
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Execute the submitted Python code and return the service's result."""
    # SECURITY(review): this passes caller-supplied source code to
    # AlgorithmCallService.execute_python_code for any authenticated user.
    # Confirm that the service sandboxes execution, or restrict this route.
    return AlgorithmCallService.execute_python_code(code)


# ---------------------------------------------------------------------------
# Upload routes
# ---------------------------------------------------------------------------

@router.post("/upload-model")
async def upload_model(
    file: UploadFile = File(..., description="模型文件"),
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Upload a model file to file storage.

    Returns a dict with ``success`` and ``message`` keys, plus ``file_path``
    (the generated object name) on success. Failures are reported in the
    body rather than through an HTTP error status.
    """
    file_extension = _match_extension(file.filename, MODEL_EXTENSIONS)
    if file_extension is None:
        return {
            "success": False,
            "message": f"不支持的文件类型,支持的类型:{', '.join(sorted(MODEL_EXTENSIONS))}"
        }

    # Generate a unique object name
    unique_filename = f"models/{uuid.uuid4().hex[:8]}{file_extension}"

    try:
        success = await _store_upload(file, unique_filename)
    except Exception as e:
        # Boundary handler: keep the best-effort response, but log the cause.
        logger.exception("Model upload failed for %s", unique_filename)
        return {
            "success": False,
            "message": f"模型文件上传失败:{e}"
        }

    if success:
        return {
            "success": True,
            "file_path": unique_filename,
            "message": "模型文件上传成功"
        }
    return {
        "success": False,
        "message": "模型文件上传失败"
    }


@router.post("/upload-video")
async def upload_video(
    file: UploadFile = File(..., description="视频文件"),
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Upload a video file to file storage under the caller's user id.

    Returns a dict with ``success`` and ``message`` keys, plus ``file_path``
    (the generated object name) on success. Failures are reported in the
    body rather than through an HTTP error status.
    """
    file_extension = _match_extension(file.filename, VIDEO_EXTENSIONS)
    if file_extension is None:
        return {
            "success": False,
            "message": f"不支持的视频文件类型,支持的类型:{', '.join(sorted(VIDEO_EXTENSIONS))}"
        }

    # Generate a unique object name. Attribute access matches how every other
    # handler in this module reads the current user (``.id`` / ``.role``).
    unique_filename = (
        f"videos/{current_user.id}/{uuid.uuid4().hex[:12]}{file_extension}"
    )

    try:
        success = await _store_upload(file, unique_filename)
    except Exception as e:
        # Boundary handler: keep the best-effort response, but log the cause.
        logger.exception("Video upload failed for %s", unique_filename)
        return {
            "success": False,
            "message": f"视频文件上传失败:{e}"
        }

    if success:
        return {
            "success": True,
            "file_path": unique_filename,
            "message": "视频文件上传成功"
        }
    return {
        "success": False,
        "message": "视频文件上传失败"
    }


# ---------------------------------------------------------------------------
# Single-algorithm routes (parameterised: declared after all fixed paths)
# ---------------------------------------------------------------------------

@router.get("/{algorithm_id}", response_model=AlgorithmResponse)
async def get_algorithm(
    algorithm_id: str,
    db: Session = Depends(get_db)
):
    """Return one algorithm, or HTTP 404 if it is not found."""
    algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
    if not algorithm:
        raise HTTPException(status_code=404, detail="Algorithm not found")
    return algorithm


@router.put("/{algorithm_id}", response_model=AlgorithmResponse)
async def update_algorithm(
    algorithm_id: str,
    algorithm_update: AlgorithmUpdate,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Update an algorithm (admin only); HTTP 404 if it is not found."""
    _require_admin(current_user)
    algorithm = AlgorithmService.update_algorithm(db, algorithm_id, algorithm_update)
    if not algorithm:
        raise HTTPException(status_code=404, detail="Algorithm not found")
    return algorithm


@router.delete("/{algorithm_id}", response_model=dict)
async def delete_algorithm(
    algorithm_id: str,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Delete an algorithm (admin only); HTTP 404 if the service reports failure."""
    _require_admin(current_user)
    if not AlgorithmService.delete_algorithm(db, algorithm_id):
        raise HTTPException(status_code=404, detail="Algorithm not found")
    return {"message": "Algorithm deleted successfully"}


# ---------------------------------------------------------------------------
# Algorithm version routes
# ---------------------------------------------------------------------------

@router.post("/{algorithm_id}/versions", response_model=AlgorithmVersionResponse)
async def create_version(
    algorithm_id: str,
    version: AlgorithmVersionCreate,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Create a version for an algorithm (admin only)."""
    _require_admin(current_user)
    _require_algorithm(db, algorithm_id)

    # The path parameter wins over whatever algorithm id the body carried.
    version.algorithm_id = algorithm_id

    return AlgorithmVersionService.create_version(db, version)


@router.get("/{algorithm_id}/versions", response_model=List[AlgorithmVersionResponse])
async def get_versions(
    algorithm_id: str,
    db: Session = Depends(get_db)
):
    """List the versions of an algorithm; HTTP 404 if the algorithm is not found."""
    _require_algorithm(db, algorithm_id)
    return AlgorithmVersionService.get_versions_by_algorithm_id(db, algorithm_id)


@router.get("/{algorithm_id}/versions/{version_id}", response_model=AlgorithmVersionResponse)
async def get_version(
    algorithm_id: str,
    version_id: str,
    db: Session = Depends(get_db)
):
    """Return one version after checking it belongs to the algorithm in the path."""
    return _get_version_of_algorithm(db, algorithm_id, version_id)


@router.put("/{algorithm_id}/versions/{version_id}", response_model=AlgorithmVersionResponse)
async def update_version(
    algorithm_id: str,
    version_id: str,
    version_update: AlgorithmVersionUpdate,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Update a version (admin only) after validating algorithm and ownership."""
    _require_admin(current_user)
    _get_version_of_algorithm(db, algorithm_id, version_id)
    return AlgorithmVersionService.update_version(db, version_id, version_update)


@router.delete("/{algorithm_id}/versions/{version_id}", response_model=dict)
async def delete_version(
    algorithm_id: str,
    version_id: str,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Delete a version (admin only) after validating algorithm and ownership.

    Raises:
        HTTPException: 403 for non-admins, 404/400 from validation, and 400
            if the service reports that the deletion failed.
    """
    _require_admin(current_user)
    _get_version_of_algorithm(db, algorithm_id, version_id)

    if not AlgorithmVersionService.delete_version(db, version_id):
        raise HTTPException(status_code=400, detail="Failed to delete version")

    return {"message": "Version deleted successfully"}