Files
algorithm/backend/app/routes/api_management.py

510 lines
19 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""API管理路由处理API端点的封装和管理"""
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
from sqlalchemy.orm import Session
import logging
from datetime import datetime
from app.models.database import get_db
from app.models.models import Algorithm, AlgorithmVersion, AlgorithmService, User
from app.models.api import ApiEndpoint, ApiCallLog
from app.schemas.user import UserResponse
from app.routes.user import get_current_active_user
logger = logging.getLogger(__name__)
# Every route below is mounted under /api-management and tagged
# "api-management" in the generated OpenAPI schema.
router = APIRouter(tags=["api-management"], prefix="/api-management")
class ApiEndpointCreate(BaseModel):
    """Request body for creating an API endpoint."""
    name: str
    description: str = ""
    # URL path of the endpoint; create_api_endpoint rejects a path already used
    # by another endpoint.
    path: str
    # HTTP method; stored as given (no normalization or validation here).
    method: str = "POST"
    # Algorithm and version exposed by the endpoint. create_api_endpoint checks
    # that both exist and that the version belongs to the algorithm.
    algorithm_id: str
    version_id: str
    # Optional backing service; when set it must match an existing
    # AlgorithmService.service_id.
    service_id: Optional[str] = None
    requires_auth: bool = True
    allowed_roles: List[str] = []
    # Free-form rate-limit settings; their schema is not defined in this module.
    rate_limit: Optional[Dict[str, Any]] = None
    is_public: bool = False
    # Free-form endpoint configuration; its schema is not defined in this module.
    config: Dict[str, Any] = {}
class ApiEndpointUpdate(BaseModel):
    """Request body for partially updating an API endpoint.

    Every field is optional. update_api_endpoint only applies fields whose
    value is not None, so a field cannot be cleared to None through this model.
    """
    name: Optional[str] = None
    description: Optional[str] = None
    # New URL path; rejected if another endpoint already uses it.
    path: Optional[str] = None
    method: Optional[str] = None
    requires_auth: Optional[bool] = None
    allowed_roles: Optional[List[str]] = None
    rate_limit: Optional[Dict[str, Any]] = None
    is_public: Optional[bool] = None
    config: Optional[Dict[str, Any]] = None
    # New endpoint status. Other routes in this module treat "active" as the
    # enabled state; the value itself is not validated here.
    status: Optional[str] = None
class ApiEndpointResponse(BaseModel):
    """Response model describing a single API endpoint.

    The routes in this module fill ``algorithm_name`` and ``version`` from the
    linked Algorithm / AlgorithmVersion rows, or with "" when a row is missing.
    """
    id: str
    name: str
    description: str
    path: str
    method: str
    algorithm_id: str
    # Name of the linked algorithm ("" when the algorithm row is missing).
    algorithm_name: str
    version_id: str
    # Version label of the linked AlgorithmVersion ("" when the row is missing).
    version: str
    service_id: Optional[str]
    status: str
    is_public: bool
    # Call statistics are typed as strings: create_api_endpoint stores them as
    # "0" / "0.0" and get_api_stats parses them back with int() / float().
    # NOTE(review): confirm the DB columns are string-typed; numeric values
    # might not validate against these str fields depending on pydantic version.
    call_count: str
    success_count: str
    error_count: str
    avg_response_time: str
    created_at: datetime
    updated_at: Optional[datetime]
    last_called_at: Optional[datetime]
class ApiEndpointListResponse(BaseModel):
    """Response model for a paginated list of API endpoints."""
    # The requested page of endpoints.
    endpoints: List[ApiEndpointResponse]
    # Number of endpoints matching the filters, ignoring skip/limit.
    total: int
class ApiStatsResponse(BaseModel):
    """Response model for aggregate statistics over all API endpoints."""
    # Number of endpoints overall / with status "active".
    total_endpoints: int
    active_endpoints: int
    # Call counters summed over all endpoints, serialized as decimal strings.
    total_calls: str
    total_success: str
    total_errors: str
    # Mean response time formatted with two decimals (units not defined here).
    avg_response_time: str
@router.get("/endpoints", response_model=ApiEndpointListResponse)
async def get_api_endpoints(
    skip: int = 0,
    limit: int = 100,
    algorithm_id: Optional[str] = None,
    status: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: UserResponse = Depends(get_current_active_user)
):
    """List API endpoints with optional filtering and pagination.

    Any active user may call this route.

    Args:
        skip: Number of matching endpoints to skip (SQL OFFSET).
        limit: Maximum number of endpoints to return (SQL LIMIT).
        algorithm_id: When given, only endpoints bound to this algorithm.
        status: When given, only endpoints with exactly this status.
        db: Database session.
        current_user: The authenticated, active user.

    Returns:
        ``{"endpoints": [...], "total": n}``. ``total`` counts every endpoint
        matching the filters, independent of ``skip``/``limit``.

    Raises:
        HTTPException: 500 on any unexpected error.
    """
    try:
        query = db.query(ApiEndpoint)
        if algorithm_id:
            query = query.filter(ApiEndpoint.algorithm_id == algorithm_id)
        if status:
            query = query.filter(ApiEndpoint.status == status)

        endpoints = query.offset(skip).limit(limit).all()
        # offset()/limit() return a new Query, so this counts the filtered but
        # un-paginated result set.
        total = query.count()

        # Resolve algorithm names and version labels with one query each for
        # the whole page, instead of two queries per endpoint (N+1 pattern).
        algorithm_ids = list({e.algorithm_id for e in endpoints if e.algorithm_id is not None})
        version_ids = list({e.version_id for e in endpoints if e.version_id is not None})

        algorithm_names: Dict[Any, Any] = {}
        if algorithm_ids:
            algorithm_names = {
                algorithm.id: algorithm.name
                for algorithm in db.query(Algorithm).filter(Algorithm.id.in_(algorithm_ids)).all()
            }

        version_labels: Dict[Any, Any] = {}
        if version_ids:
            version_labels = {
                version.id: version.version
                for version in db.query(AlgorithmVersion).filter(AlgorithmVersion.id.in_(version_ids)).all()
            }

        endpoint_responses = [
            {
                "id": endpoint.id,
                "name": endpoint.name,
                "description": endpoint.description,
                "path": endpoint.path,
                "method": endpoint.method,
                "algorithm_id": endpoint.algorithm_id,
                # "" when the linked algorithm row no longer exists.
                "algorithm_name": algorithm_names.get(endpoint.algorithm_id, ""),
                "version_id": endpoint.version_id,
                # "" when the linked version row no longer exists.
                "version": version_labels.get(endpoint.version_id, ""),
                "service_id": endpoint.service_id,
                "status": endpoint.status,
                "is_public": endpoint.is_public,
                "call_count": endpoint.call_count,
                "success_count": endpoint.success_count,
                "error_count": endpoint.error_count,
                "avg_response_time": endpoint.avg_response_time,
                "created_at": endpoint.created_at,
                "updated_at": endpoint.updated_at,
                "last_called_at": endpoint.last_called_at
            }
            for endpoint in endpoints
        ]
        return {
            "endpoints": endpoint_responses,
            "total": total
        }
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"获取API端点列表失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"获取API端点列表失败: {str(e)}") from e
@router.get("/endpoints/{endpoint_id}", response_model=ApiEndpointResponse)
async def get_api_endpoint(
    endpoint_id: str,
    db: Session = Depends(get_db),
    current_user: UserResponse = Depends(get_current_active_user)
):
    """Return the details of a single API endpoint.

    Any active user may call this route. The response includes the linked
    algorithm's name and the linked version's label, each "" when the
    corresponding row is missing.

    Raises:
        HTTPException: 404 when no endpoint has ``endpoint_id``; 500 on any
            other error.
    """
    try:
        found = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
        if not found:
            raise HTTPException(status_code=404, detail="API端点不存在")

        # Look up the display labels of the linked algorithm and version.
        algo = db.query(Algorithm).filter(Algorithm.id == found.algorithm_id).first()
        ver = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == found.version_id).first()
        algorithm_label = algo.name if algo else ""
        version_label = ver.version if ver else ""

        payload = {
            "id": found.id,
            "name": found.name,
            "description": found.description,
            "path": found.path,
            "method": found.method,
            "algorithm_id": found.algorithm_id,
            "algorithm_name": algorithm_label,
            "version_id": found.version_id,
            "version": version_label,
            "service_id": found.service_id,
            "status": found.status,
            "is_public": found.is_public,
            "call_count": found.call_count,
            "success_count": found.success_count,
            "error_count": found.error_count,
            "avg_response_time": found.avg_response_time,
            "created_at": found.created_at,
            "updated_at": found.updated_at,
            "last_called_at": found.last_called_at
        }
        return payload
    except HTTPException:
        # Deliberate HTTP errors (the 404 above) pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"获取API端点详情失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"获取API端点详情失败: {str(e)}")
@router.post("/endpoints", response_model=ApiEndpointResponse, status_code=status.HTTP_201_CREATED)
async def create_api_endpoint(
    request: ApiEndpointCreate,
    db: Session = Depends(get_db),
    current_user: UserResponse = Depends(get_current_active_user)
):
    """Create a new API endpoint bound to an algorithm version (admin only).

    Checks are performed in this order:
    1. The caller has the admin role.
    2. The algorithm exists.
    3. The version exists and belongs to that algorithm.
    4. The optional service exists.
    5. The path is not already taken.

    The new endpoint starts with status "active" and zeroed call statistics.

    Returns:
        The created endpoint, shaped like ``ApiEndpointResponse``.

    Raises:
        HTTPException:
            - 403 for non-admin users.
            - 404 when the algorithm, version or service is missing.
            - 400 when the path is already taken.
            - 500 on any other error, after rolling the session back.
    """
    try:
        # A caller without a role_name attribute is treated as non-admin.
        if getattr(current_user, "role_name", None) != "admin":
            raise HTTPException(status_code=403, detail="Insufficient permissions")

        algorithm = (
            db.query(Algorithm)
            .filter(Algorithm.id == request.algorithm_id)
            .first()
        )
        if not algorithm:
            raise HTTPException(status_code=404, detail="算法不存在")

        version = (
            db.query(AlgorithmVersion)
            .filter(AlgorithmVersion.id == request.version_id)
            .first()
        )
        # The version must exist AND belong to the requested algorithm.
        if not (version and version.algorithm_id == request.algorithm_id):
            raise HTTPException(status_code=404, detail="算法版本不存在")

        # A service is optional; validate it only when one was supplied.
        if request.service_id:
            linked_service = (
                db.query(AlgorithmService)
                .filter(AlgorithmService.service_id == request.service_id)
                .first()
            )
            if not linked_service:
                raise HTTPException(status_code=404, detail="服务不存在")

        path_owner = (
            db.query(ApiEndpoint)
            .filter(ApiEndpoint.path == request.path)
            .first()
        )
        if path_owner:
            raise HTTPException(status_code=400, detail="API路径已存在")

        # New endpoints start with zeroed, string-typed call statistics.
        zeroed_stats = {
            "call_count": "0",
            "success_count": "0",
            "error_count": "0",
            "avg_response_time": "0.0",
        }
        new_endpoint = ApiEndpoint(
            name=request.name,
            description=request.description,
            path=request.path,
            method=request.method,
            algorithm_id=request.algorithm_id,
            version_id=request.version_id,
            service_id=request.service_id,
            requires_auth=request.requires_auth,
            allowed_roles=request.allowed_roles,
            rate_limit=request.rate_limit,
            is_public=request.is_public,
            config=request.config,
            status="active",
            **zeroed_stats,
        )
        db.add(new_endpoint)
        db.commit()
        # Reload DB-generated columns (e.g. id, created_at) after the commit.
        db.refresh(new_endpoint)

        created = {
            "id": new_endpoint.id,
            "name": new_endpoint.name,
            "description": new_endpoint.description,
            "path": new_endpoint.path,
            "method": new_endpoint.method,
            "algorithm_id": new_endpoint.algorithm_id,
            "algorithm_name": algorithm.name,
            "version_id": new_endpoint.version_id,
            "version": version.version,
            "service_id": new_endpoint.service_id,
            "status": new_endpoint.status,
            "is_public": new_endpoint.is_public,
            "call_count": new_endpoint.call_count,
            "success_count": new_endpoint.success_count,
            "error_count": new_endpoint.error_count,
            "avg_response_time": new_endpoint.avg_response_time,
            "created_at": new_endpoint.created_at,
            "updated_at": new_endpoint.updated_at,
            "last_called_at": new_endpoint.last_called_at
        }
        return created
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"创建API端点失败: {str(e)}")
        db.rollback()
        raise HTTPException(status_code=500, detail=f"创建API端点失败: {str(e)}")
@router.put("/endpoints/{endpoint_id}", response_model=ApiEndpointResponse)
async def update_api_endpoint(
    endpoint_id: str,
    request: ApiEndpointUpdate,
    db: Session = Depends(get_db),
    current_user: UserResponse = Depends(get_current_active_user)
):
    """Partially update an API endpoint (admin only).

    Only fields of ``request`` whose value is not None are applied; None
    means "leave unchanged".

    Returns:
        The updated endpoint, shaped like ``ApiEndpointResponse``.

    Raises:
        HTTPException: 403 for non-admin users, 404 if the endpoint does not
            exist, 400 if the new path is used by another endpoint, 500 on any
            other error (the session is rolled back).
    """
    try:
        if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
            raise HTTPException(status_code=403, detail="Insufficient permissions")

        endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
        if not endpoint:
            raise HTTPException(status_code=404, detail="API端点不存在")

        # Validate the new path BEFORE mutating the ORM object. The uniqueness
        # query may autoflush pending attribute changes, so validating first
        # guarantees a rejected request leaves nothing half-applied.
        if request.path is not None:
            existing_endpoint = db.query(ApiEndpoint).filter(
                ApiEndpoint.path == request.path,
                ApiEndpoint.id != endpoint_id
            ).first()
            if existing_endpoint:
                raise HTTPException(status_code=400, detail="API路径已存在")

        # Copy every supplied field onto the endpoint; None means "unchanged".
        updatable_fields = (
            "name", "description", "path", "method", "requires_auth",
            "allowed_roles", "rate_limit", "is_public", "config", "status",
        )
        for field_name in updatable_fields:
            new_value = getattr(request, field_name)
            if new_value is not None:
                setattr(endpoint, field_name, new_value)

        db.commit()
        db.refresh(endpoint)

        # Resolve display labels of the linked algorithm and version.
        algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first()
        version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first()
        return {
            "id": endpoint.id,
            "name": endpoint.name,
            "description": endpoint.description,
            "path": endpoint.path,
            "method": endpoint.method,
            "algorithm_id": endpoint.algorithm_id,
            "algorithm_name": algorithm.name if algorithm else "",
            "version_id": endpoint.version_id,
            "version": version.version if version else "",
            "service_id": endpoint.service_id,
            "status": endpoint.status,
            "is_public": endpoint.is_public,
            "call_count": endpoint.call_count,
            "success_count": endpoint.success_count,
            "error_count": endpoint.error_count,
            "avg_response_time": endpoint.avg_response_time,
            "created_at": endpoint.created_at,
            "updated_at": endpoint.updated_at,
            "last_called_at": endpoint.last_called_at
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"更新API端点失败: {str(e)}")
        db.rollback()
        raise HTTPException(status_code=500, detail=f"更新API端点失败: {str(e)}") from e
@router.delete("/endpoints/{endpoint_id}")
async def delete_api_endpoint(
    endpoint_id: str,
    db: Session = Depends(get_db),
    current_user: UserResponse = Depends(get_current_active_user)
):
    """Delete an API endpoint (admin only).

    Returns:
        ``{"success": True, "message": ...}`` once the deletion is committed.

    Raises:
        HTTPException: 403 for non-admin users, 404 if the endpoint does not
            exist, 500 on any other error (the session is rolled back).
    """
    try:
        # A caller without a role_name attribute is treated as non-admin.
        if getattr(current_user, "role_name", None) != "admin":
            raise HTTPException(status_code=403, detail="Insufficient permissions")

        target = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
        if not target:
            raise HTTPException(status_code=404, detail="API端点不存在")

        db.delete(target)
        db.commit()

        outcome = {"success": True, "message": "API端点删除成功"}
        return outcome
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"删除API端点失败: {str(e)}")
        db.rollback()
        raise HTTPException(status_code=500, detail=f"删除API端点失败: {str(e)}")
@router.get("/stats", response_model=ApiStatsResponse)
async def get_api_stats(
    db: Session = Depends(get_db),
    current_user: UserResponse = Depends(get_current_active_user)
):
    """Return aggregate call statistics over all API endpoints (admin only).

    Returns:
        A dict shaped like ``ApiStatsResponse``:
        - Counters are summed across endpoints and returned as strings.
        - ``avg_response_time`` is formatted with two decimals.

    Raises:
        HTTPException: 403 for non-admin users, 500 on any other error.
    """
    try:
        if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
            raise HTTPException(status_code=403, detail="Insufficient permissions")

        # Counted in SQL so the comparison keeps the database's string semantics.
        active_endpoints = db.query(ApiEndpoint).filter(ApiEndpoint.status == "active").count()

        # Load every endpoint once. The total comes from this list, so no
        # separate COUNT query is needed.
        endpoints = db.query(ApiEndpoint).all()
        total_endpoints = len(endpoints)

        # Statistics are stored as strings; missing/empty values count as 0.
        total_calls = sum(int(e.call_count or 0) for e in endpoints)
        total_success = sum(int(e.success_count or 0) for e in endpoints)
        total_errors = sum(int(e.error_count or 0) for e in endpoints)

        # Parse each stored average once and ignore endpoints without a
        # positive recorded time.
        parsed_times = (float(e.avg_response_time or 0) for e in endpoints)
        positive_times = [t for t in parsed_times if t > 0]
        # NOTE(review): this is an unweighted mean of per-endpoint averages;
        # rarely-called endpoints weigh as much as busy ones.
        avg_response_time = sum(positive_times) / len(positive_times) if positive_times else 0.0

        return {
            "total_endpoints": total_endpoints,
            "active_endpoints": active_endpoints,
            "total_calls": str(total_calls),
            "total_success": str(total_success),
            "total_errors": str(total_errors),
            "avg_response_time": f"{avg_response_time:.2f}"
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取API统计信息失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"获取API统计信息失败: {str(e)}") from e
@router.post("/endpoints/{endpoint_id}/test")
async def test_api_endpoint(
    endpoint_id: str,
    payload: Dict[str, Any],
    db: Session = Depends(get_db),
    current_user: UserResponse = Depends(get_current_active_user)
):
    """Forward ``payload`` to the service behind an endpoint and report the result.

    The payload is POSTed as JSON to ``<service.api_url>/predict`` with a
    30-second timeout.

    NOTE(review): ``endpoint.method`` and ``endpoint.path`` are not used here;
    the call is always a POST to ``predict``.

    Returns:
        On HTTP 200: ``{"success": True, "result": <JSON body>,
        "response_time": <seconds>, "message": ...}``. On any other status:
        ``{"success": False, "error": ..., "response_time": <seconds>}``.

    Raises:
        HTTPException:
            - 404 if the endpoint does not exist.
            - 400 if it is not active, has no linked service, or its service
              is missing or not running.
            - 500 on any other failure (e.g. network errors or a non-JSON
              response body).
    """
    try:
        endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
        if not endpoint:
            raise HTTPException(status_code=404, detail="API端点不存在")
        if endpoint.status != "active":
            raise HTTPException(status_code=400, detail="API端点未激活")
        if not endpoint.service_id:
            raise HTTPException(status_code=400, detail="API端点未关联服务")

        service = db.query(AlgorithmService).filter(
            AlgorithmService.service_id == endpoint.service_id
        ).first()
        if not service or service.status != "running":
            raise HTTPException(status_code=400, detail="关联服务未运行")

        # Imported lazily: only needed once a running service has been found.
        import httpx
        import time

        service_url = service.api_url
        if not service_url.endswith("/"):
            service_url += "/"
        call_url = f"{service_url}predict"

        # perf_counter is monotonic, so wall-clock adjustments cannot skew
        # (or make negative) the measured response time, unlike time.time().
        start_time = time.perf_counter()
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                call_url,
                json=payload,
                headers={"Content-Type": "application/json"}
            )
        response_time = time.perf_counter() - start_time

        if response.status_code == 200:
            return {
                "success": True,
                "result": response.json(),
                "response_time": response_time,
                "message": "API调用成功"
            }
        return {
            "success": False,
            "error": f"服务返回错误: HTTP {response.status_code}",
            "response_time": response_time
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"测试API端点失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"测试API端点失败: {str(e)}") from e