"""API管理路由,处理API端点的封装和管理""" from fastapi import APIRouter, Depends, HTTPException, status from pydantic import BaseModel from typing import List, Optional, Dict, Any from sqlalchemy.orm import Session import logging from datetime import datetime from app.models.database import get_db from app.models.models import Algorithm, AlgorithmVersion, AlgorithmService, User from app.models.api import ApiEndpoint, ApiCallLog from app.schemas.user import UserResponse from app.routes.user import get_current_active_user router = APIRouter(prefix="/api-management", tags=["api-management"]) logger = logging.getLogger(__name__) class ApiEndpointCreate(BaseModel): """创建API端点请求模型""" name: str description: str = "" path: str method: str = "POST" algorithm_id: str version_id: str service_id: Optional[str] = None requires_auth: bool = True allowed_roles: List[str] = [] rate_limit: Optional[Dict[str, Any]] = None is_public: bool = False config: Dict[str, Any] = {} class ApiEndpointUpdate(BaseModel): """更新API端点请求模型""" name: Optional[str] = None description: Optional[str] = None path: Optional[str] = None method: Optional[str] = None requires_auth: Optional[bool] = None allowed_roles: Optional[List[str]] = None rate_limit: Optional[Dict[str, Any]] = None is_public: Optional[bool] = None config: Optional[Dict[str, Any]] = None status: Optional[str] = None class ApiEndpointResponse(BaseModel): """API端点响应模型""" id: str name: str description: str path: str method: str algorithm_id: str algorithm_name: str version_id: str version: str service_id: Optional[str] status: str is_public: bool call_count: str success_count: str error_count: str avg_response_time: str created_at: datetime updated_at: Optional[datetime] last_called_at: Optional[datetime] class ApiEndpointListResponse(BaseModel): """API端点列表响应模型""" endpoints: List[ApiEndpointResponse] total: int class ApiStatsResponse(BaseModel): """API统计响应模型""" total_endpoints: int active_endpoints: int total_calls: str total_success: str total_errors: str avg_response_time: str @router.get("/endpoints", response_model=ApiEndpointListResponse) async def get_api_endpoints( skip: int = 0, limit: int = 100, algorithm_id: Optional[str] = None, status: Optional[str] = None, db: Session = Depends(get_db), current_user: UserResponse = Depends(get_current_active_user) ): """获取API端点列表""" try: # 构建查询 query = db.query(ApiEndpoint) # 筛选条件 if algorithm_id: query = query.filter(ApiEndpoint.algorithm_id == algorithm_id) if status: query = query.filter(ApiEndpoint.status == status) # 分页 endpoints = query.offset(skip).limit(limit).all() total = query.count() # 构建响应 endpoint_responses = [] for endpoint in endpoints: # 获取关联的算法和版本信息 algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first() version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first() endpoint_responses.append({ "id": endpoint.id, "name": endpoint.name, "description": endpoint.description, "path": endpoint.path, "method": endpoint.method, "algorithm_id": endpoint.algorithm_id, "algorithm_name": algorithm.name if algorithm else "", "version_id": endpoint.version_id, "version": version.version if version else "", "service_id": endpoint.service_id, "status": endpoint.status, "is_public": endpoint.is_public, "call_count": endpoint.call_count, "success_count": endpoint.success_count, "error_count": endpoint.error_count, "avg_response_time": endpoint.avg_response_time, "created_at": endpoint.created_at, "updated_at": endpoint.updated_at, "last_called_at": endpoint.last_called_at }) return { "endpoints": endpoint_responses, "total": total } except Exception as e: logger.error(f"获取API端点列表失败: {str(e)}") raise HTTPException(status_code=500, detail=f"获取API端点列表失败: {str(e)}") @router.get("/endpoints/{endpoint_id}", response_model=ApiEndpointResponse) async def get_api_endpoint( endpoint_id: str, db: Session = Depends(get_db), current_user: UserResponse = Depends(get_current_active_user) ): """获取API端点详情""" try: endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first() if not endpoint: raise HTTPException(status_code=404, detail="API端点不存在") # 获取关联的算法和版本信息 algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first() version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first() return { "id": endpoint.id, "name": endpoint.name, "description": endpoint.description, "path": endpoint.path, "method": endpoint.method, "algorithm_id": endpoint.algorithm_id, "algorithm_name": algorithm.name if algorithm else "", "version_id": endpoint.version_id, "version": version.version if version else "", "service_id": endpoint.service_id, "status": endpoint.status, "is_public": endpoint.is_public, "call_count": endpoint.call_count, "success_count": endpoint.success_count, "error_count": endpoint.error_count, "avg_response_time": endpoint.avg_response_time, "created_at": endpoint.created_at, "updated_at": endpoint.updated_at, "last_called_at": endpoint.last_called_at } except HTTPException: raise except Exception as e: logger.error(f"获取API端点详情失败: {str(e)}") raise HTTPException(status_code=500, detail=f"获取API端点详情失败: {str(e)}") @router.post("/endpoints", response_model=ApiEndpointResponse, status_code=status.HTTP_201_CREATED) async def create_api_endpoint( request: ApiEndpointCreate, db: Session = Depends(get_db), current_user: UserResponse = Depends(get_current_active_user) ): """创建API端点""" try: # 检查用户权限 if not hasattr(current_user, 'role_name') or current_user.role_name != "admin": raise HTTPException(status_code=403, detail="Insufficient permissions") # 验证算法和版本是否存在 algorithm = db.query(Algorithm).filter(Algorithm.id == request.algorithm_id).first() if not algorithm: raise HTTPException(status_code=404, detail="算法不存在") version = db.query(AlgorithmVersion).filter( AlgorithmVersion.id == request.version_id ).first() if not version or version.algorithm_id != request.algorithm_id: raise HTTPException(status_code=404, detail="算法版本不存在") # 如果指定了服务ID,验证服务是否存在 if request.service_id: service = db.query(AlgorithmService).filter( AlgorithmService.service_id == request.service_id ).first() if not service: raise HTTPException(status_code=404, detail="服务不存在") # 检查API路径是否已存在 existing_endpoint = db.query(ApiEndpoint).filter( ApiEndpoint.path == request.path ).first() if existing_endpoint: raise HTTPException(status_code=400, detail="API路径已存在") # 创建API端点 new_endpoint = ApiEndpoint( name=request.name, description=request.description, path=request.path, method=request.method, algorithm_id=request.algorithm_id, version_id=request.version_id, service_id=request.service_id, requires_auth=request.requires_auth, allowed_roles=request.allowed_roles, rate_limit=request.rate_limit, is_public=request.is_public, config=request.config, status="active", call_count="0", success_count="0", error_count="0", avg_response_time="0.0" ) db.add(new_endpoint) db.commit() db.refresh(new_endpoint) # 返回创建的API端点 return { "id": new_endpoint.id, "name": new_endpoint.name, "description": new_endpoint.description, "path": new_endpoint.path, "method": new_endpoint.method, "algorithm_id": new_endpoint.algorithm_id, "algorithm_name": algorithm.name, "version_id": new_endpoint.version_id, "version": version.version, "service_id": new_endpoint.service_id, "status": new_endpoint.status, "is_public": new_endpoint.is_public, "call_count": new_endpoint.call_count, "success_count": new_endpoint.success_count, "error_count": new_endpoint.error_count, "avg_response_time": new_endpoint.avg_response_time, "created_at": new_endpoint.created_at, "updated_at": new_endpoint.updated_at, "last_called_at": new_endpoint.last_called_at } except HTTPException: raise except Exception as e: logger.error(f"创建API端点失败: {str(e)}") db.rollback() raise HTTPException(status_code=500, detail=f"创建API端点失败: {str(e)}") @router.put("/endpoints/{endpoint_id}", response_model=ApiEndpointResponse) async def update_api_endpoint( endpoint_id: str, request: ApiEndpointUpdate, db: Session = Depends(get_db), current_user: UserResponse = Depends(get_current_active_user) ): """更新API端点""" try: # 检查用户权限 if not hasattr(current_user, 'role_name') or current_user.role_name != "admin": raise HTTPException(status_code=403, detail="Insufficient permissions") # 查询API端点 endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first() if not endpoint: raise HTTPException(status_code=404, detail="API端点不存在") # 更新字段 if request.name is not None: endpoint.name = request.name if request.description is not None: endpoint.description = request.description if request.path is not None: # 检查新路径是否已被其他端点使用 existing_endpoint = db.query(ApiEndpoint).filter( ApiEndpoint.path == request.path, ApiEndpoint.id != endpoint_id ).first() if existing_endpoint: raise HTTPException(status_code=400, detail="API路径已存在") endpoint.path = request.path if request.method is not None: endpoint.method = request.method if request.requires_auth is not None: endpoint.requires_auth = request.requires_auth if request.allowed_roles is not None: endpoint.allowed_roles = request.allowed_roles if request.rate_limit is not None: endpoint.rate_limit = request.rate_limit if request.is_public is not None: endpoint.is_public = request.is_public if request.config is not None: endpoint.config = request.config if request.status is not None: endpoint.status = request.status db.commit() db.refresh(endpoint) # 获取关联的算法和版本信息 algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first() version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first() return { "id": endpoint.id, "name": endpoint.name, "description": endpoint.description, "path": endpoint.path, "method": endpoint.method, "algorithm_id": endpoint.algorithm_id, "algorithm_name": algorithm.name if algorithm else "", "version_id": endpoint.version_id, "version": version.version if version else "", "service_id": endpoint.service_id, "status": endpoint.status, "is_public": endpoint.is_public, "call_count": endpoint.call_count, "success_count": endpoint.success_count, "error_count": endpoint.error_count, "avg_response_time": endpoint.avg_response_time, "created_at": endpoint.created_at, "updated_at": endpoint.updated_at, "last_called_at": endpoint.last_called_at } except HTTPException: raise except Exception as e: logger.error(f"更新API端点失败: {str(e)}") db.rollback() raise HTTPException(status_code=500, detail=f"更新API端点失败: {str(e)}") @router.delete("/endpoints/{endpoint_id}") async def delete_api_endpoint( endpoint_id: str, db: Session = Depends(get_db), current_user: UserResponse = Depends(get_current_active_user) ): """删除API端点""" try: # 检查用户权限 if not hasattr(current_user, 'role_name') or current_user.role_name != "admin": raise HTTPException(status_code=403, detail="Insufficient permissions") # 查询API端点 endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first() if not endpoint: raise HTTPException(status_code=404, detail="API端点不存在") # 删除API端点 db.delete(endpoint) db.commit() return { "success": True, "message": "API端点删除成功" } except HTTPException: raise except Exception as e: logger.error(f"删除API端点失败: {str(e)}") db.rollback() raise HTTPException(status_code=500, detail=f"删除API端点失败: {str(e)}") @router.get("/stats", response_model=ApiStatsResponse) async def get_api_stats( db: Session = Depends(get_db), current_user: UserResponse = Depends(get_current_active_user) ): """获取API统计信息""" try: # 检查用户权限 if not hasattr(current_user, 'role_name') or current_user.role_name != "admin": raise HTTPException(status_code=403, detail="Insufficient permissions") # 统计API端点 total_endpoints = db.query(ApiEndpoint).count() active_endpoints = db.query(ApiEndpoint).filter(ApiEndpoint.status == "active").count() # 统计调用次数 endpoints = db.query(ApiEndpoint).all() total_calls = sum(int(e.call_count or 0) for e in endpoints) total_success = sum(int(e.success_count or 0) for e in endpoints) total_errors = sum(int(e.error_count or 0) for e in endpoints) # 计算平均响应时间 avg_response_times = [float(e.avg_response_time or 0) for e in endpoints if float(e.avg_response_time or 0) > 0] avg_response_time = sum(avg_response_times) / len(avg_response_times) if avg_response_times else 0.0 return { "total_endpoints": total_endpoints, "active_endpoints": active_endpoints, "total_calls": str(total_calls), "total_success": str(total_success), "total_errors": str(total_errors), "avg_response_time": f"{avg_response_time:.2f}" } except HTTPException: raise except Exception as e: logger.error(f"获取API统计信息失败: {str(e)}") raise HTTPException(status_code=500, detail=f"获取API统计信息失败: {str(e)}") @router.post("/endpoints/{endpoint_id}/test") async def test_api_endpoint( endpoint_id: str, payload: Dict[str, Any], db: Session = Depends(get_db), current_user: UserResponse = Depends(get_current_active_user) ): """测试API端点""" try: # 查询API端点 endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first() if not endpoint: raise HTTPException(status_code=404, detail="API端点不存在") # 检查API端点状态 if endpoint.status != "active": raise HTTPException(status_code=400, detail="API端点未激活") # 查询关联的服务 if endpoint.service_id: service = db.query(AlgorithmService).filter( AlgorithmService.service_id == endpoint.service_id ).first() if not service or service.status != "running": raise HTTPException(status_code=400, detail="关联服务未运行") # 调用服务 import httpx import time service_url = service.api_url if not service_url.endswith("/"): service_url += "/" call_url = f"{service_url}predict" start_time = time.time() async with httpx.AsyncClient(timeout=30.0) as client: response = await client.post( call_url, json=payload, headers={"Content-Type": "application/json"} ) response_time = time.time() - start_time if response.status_code == 200: return { "success": True, "result": response.json(), "response_time": response_time, "message": "API调用成功" } else: return { "success": False, "error": f"服务返回错误: HTTP {response.status_code}", "response_time": response_time } else: raise HTTPException(status_code=400, detail="API端点未关联服务") except HTTPException: raise except Exception as e: logger.error(f"测试API端点失败: {str(e)}") raise HTTPException(status_code=500, detail=f"测试API端点失败: {str(e)}")