510 lines
19 KiB
Python
510 lines
19 KiB
Python
"""API管理路由,处理API端点的封装和管理"""
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, status
|
||
from pydantic import BaseModel
|
||
from typing import List, Optional, Dict, Any
|
||
from sqlalchemy.orm import Session
|
||
import logging
|
||
from datetime import datetime
|
||
|
||
from app.models.database import get_db
|
||
from app.models.models import Algorithm, AlgorithmVersion, AlgorithmService, User
|
||
from app.models.api import ApiEndpoint, ApiCallLog
|
||
from app.schemas.user import UserResponse
|
||
from app.routes.user import get_current_active_user
|
||
|
||
# Module-level logger shared by every route handler below.
logger = logging.getLogger(__name__)

# All API-endpoint management routes are mounted under /api-management.
router = APIRouter(prefix="/api-management", tags=["api-management"])
|
||
|
||
|
||
class ApiEndpointCreate(BaseModel):
    """Request body for creating an API endpoint.

    The fields are copied one-to-one onto the new ``ApiEndpoint`` row by
    ``create_api_endpoint``. Declaration order defines the generated schema,
    so it is kept as-is.
    """

    # Human-readable name of the endpoint.
    name: str
    # Free-text description; empty by default.
    description: str = ""
    # Route path; create_api_endpoint rejects a path that is already taken.
    path: str
    # HTTP method, stored exactly as sent.
    method: str = "POST"
    # Algorithm exposed by the endpoint.
    algorithm_id: str
    # Version of that algorithm; create_api_endpoint checks it belongs to algorithm_id.
    version_id: str
    # Optional backing service, matched against AlgorithmService.service_id.
    service_id: Optional[str] = None
    # Access-control settings persisted with the endpoint.
    requires_auth: bool = True
    allowed_roles: List[str] = []  # pydantic copies field defaults per instance
    rate_limit: Optional[Dict[str, Any]] = None
    is_public: bool = False
    # Arbitrary endpoint configuration persisted with the endpoint.
    config: Dict[str, Any] = {}
|
||
|
||
|
||
class ApiEndpointUpdate(BaseModel):
    """Request body for partially updating an API endpoint.

    Every field is optional. update_api_endpoint only copies fields that are
    not None, so None means "leave this column unchanged".
    """

    # Descriptive attributes.
    name: Optional[str] = None
    description: Optional[str] = None
    # Routing; a new path is checked for conflicts with other endpoints.
    path: Optional[str] = None
    method: Optional[str] = None
    # Access-control settings.
    requires_auth: Optional[bool] = None
    allowed_roles: Optional[List[str]] = None
    rate_limit: Optional[Dict[str, Any]] = None
    is_public: Optional[bool] = None
    # Free-form configuration.
    config: Optional[Dict[str, Any]] = None
    # Lifecycle status; test_api_endpoint only calls endpoints whose status is "active".
    status: Optional[str] = None
|
||
|
||
|
||
class ApiEndpointResponse(BaseModel):
    """Serialized view of one API endpoint.

    Besides the endpoint's own columns, it carries the name of the bound
    algorithm and the label of the bound version, which the route handlers
    look up separately.
    """

    # Identity and routing.
    id: str
    name: str
    description: str
    path: str
    method: str
    # Bound algorithm (id plus resolved display name).
    algorithm_id: str
    algorithm_name: str
    # Bound algorithm version (id plus resolved version label).
    version_id: str
    version: str
    # Backing service, if any.
    service_id: Optional[str]
    status: str
    is_public: bool
    # Usage counters; the routes store and return them as strings.
    call_count: str
    success_count: str
    error_count: str
    avg_response_time: str
    # Timestamps.
    created_at: datetime
    updated_at: Optional[datetime]
    last_called_at: Optional[datetime]
|
||
|
||
|
||
class ApiEndpointListResponse(BaseModel):
    """One page of API endpoints plus the total number of matches."""

    # Endpoints on the requested page.
    endpoints: List[ApiEndpointResponse]
    # Count of all endpoints matching the filters, independent of skip/limit.
    total: int
|
||
|
||
|
||
class ApiStatsResponse(BaseModel):
    """Aggregate statistics across all API endpoints."""

    # Endpoint counts.
    total_endpoints: int
    active_endpoints: int
    # Call totals, serialized as decimal strings.
    total_calls: str
    total_success: str
    total_errors: str
    # Mean response time formatted with two decimals.
    avg_response_time: str
|
||
|
||
|
||
@router.get("/endpoints", response_model=ApiEndpointListResponse)
async def get_api_endpoints(
    skip: int = 0,
    limit: int = 100,
    algorithm_id: Optional[str] = None,
    status: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: UserResponse = Depends(get_current_active_user)
):
    """List API endpoints with optional filtering and offset/limit paging.

    Any active user may call this route; ``current_user`` is only used for
    authentication.

    Args:
        skip: Number of matching endpoints to skip.
        limit: Maximum number of endpoints to return.
        algorithm_id: If given, only endpoints bound to this algorithm.
        status: If given, only endpoints with this status. Inside this
            function the parameter shadows the module-level ``fastapi.status``
            import; the name is kept because it is the public query parameter.
        db: Database session.
        current_user: The authenticated user.

    Returns:
        ``{"endpoints": [...], "total": n}``, where ``total`` counts every
        matching endpoint, not just the returned page.

    Raises:
        HTTPException: 500 on any unexpected failure.
    """
    try:
        query = db.query(ApiEndpoint)
        if algorithm_id:
            query = query.filter(ApiEndpoint.algorithm_id == algorithm_id)
        if status:
            query = query.filter(ApiEndpoint.status == status)

        # Count every match before paging is applied.
        total = query.count()

        # OFFSET/LIMIT without ORDER BY has no stable row order, so consecutive
        # pages could overlap or skip rows; order by primary key for stability.
        endpoints = (
            query.order_by(ApiEndpoint.id)
            .offset(skip)
            .limit(limit)
            .all()
        )

        # Resolve algorithm names and version labels with one query each,
        # instead of two queries per endpoint (N+1).
        algorithm_ids = {endpoint.algorithm_id for endpoint in endpoints}
        version_ids = {endpoint.version_id for endpoint in endpoints}

        algorithm_names: Dict[Any, Any] = {}
        if algorithm_ids:
            algorithm_names = {
                algorithm.id: algorithm.name
                for algorithm in db.query(Algorithm)
                .filter(Algorithm.id.in_(list(algorithm_ids)))
                .all()
            }

        version_labels: Dict[Any, Any] = {}
        if version_ids:
            version_labels = {
                version.id: version.version
                for version in db.query(AlgorithmVersion)
                .filter(AlgorithmVersion.id.in_(list(version_ids)))
                .all()
            }

        # A dangling algorithm/version reference is reported as "".
        endpoint_responses = [
            {
                "id": endpoint.id,
                "name": endpoint.name,
                "description": endpoint.description,
                "path": endpoint.path,
                "method": endpoint.method,
                "algorithm_id": endpoint.algorithm_id,
                "algorithm_name": algorithm_names.get(endpoint.algorithm_id, ""),
                "version_id": endpoint.version_id,
                "version": version_labels.get(endpoint.version_id, ""),
                "service_id": endpoint.service_id,
                "status": endpoint.status,
                "is_public": endpoint.is_public,
                "call_count": endpoint.call_count,
                "success_count": endpoint.success_count,
                "error_count": endpoint.error_count,
                "avg_response_time": endpoint.avg_response_time,
                "created_at": endpoint.created_at,
                "updated_at": endpoint.updated_at,
                "last_called_at": endpoint.last_called_at
            }
            for endpoint in endpoints
        ]

        return {
            "endpoints": endpoint_responses,
            "total": total
        }
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception("获取API端点列表失败: %s", e)
        raise HTTPException(status_code=500, detail=f"获取API端点列表失败: {str(e)}") from e
|
||
|
||
|
||
@router.get("/endpoints/{endpoint_id}", response_model=ApiEndpointResponse)
async def get_api_endpoint(
    endpoint_id: str,
    db: Session = Depends(get_db),
    current_user: UserResponse = Depends(get_current_active_user)
):
    """Return the details of one API endpoint.

    The payload includes the bound algorithm's name and the bound version's
    label. Either is "" when the referenced row cannot be found.

    Raises:
        HTTPException: 404 if no endpoint has ``endpoint_id``; 500 on any
            other failure.
    """
    try:
        found = db.query(ApiEndpoint).filter_by(id=endpoint_id).first()
        if not found:
            raise HTTPException(status_code=404, detail="API端点不存在")

        algo = db.query(Algorithm).filter_by(id=found.algorithm_id).first()
        ver = db.query(AlgorithmVersion).filter_by(id=found.version_id).first()

        if algo:
            algorithm_name = algo.name
        else:
            algorithm_name = ""

        if ver:
            version_label = ver.version
        else:
            version_label = ""

        detail = {
            "id": found.id,
            "name": found.name,
            "description": found.description,
            "path": found.path,
            "method": found.method,
            "algorithm_id": found.algorithm_id,
            "algorithm_name": algorithm_name,
            "version_id": found.version_id,
            "version": version_label,
            "service_id": found.service_id,
            "status": found.status,
            "is_public": found.is_public,
            "call_count": found.call_count,
            "success_count": found.success_count,
            "error_count": found.error_count,
            "avg_response_time": found.avg_response_time,
            "created_at": found.created_at,
            "updated_at": found.updated_at,
            "last_called_at": found.last_called_at
        }
        return detail
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 404 above) pass through untouched.
        raise
    except Exception as e:
        logger.error(f"获取API端点详情失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"获取API端点详情失败: {str(e)}")
|
||
|
||
|
||
@router.post("/endpoints", response_model=ApiEndpointResponse, status_code=status.HTTP_201_CREATED)
async def create_api_endpoint(
    request: ApiEndpointCreate,
    db: Session = Depends(get_db),
    current_user: UserResponse = Depends(get_current_active_user)
):
    """Create a new API endpoint (admin only).

    Validation happens in this order: caller role, algorithm exists, version
    exists and belongs to the algorithm, optional service exists, path is not
    already used. The endpoint starts out "active" with zeroed counters.

    Raises:
        HTTPException: 403 for non-admin callers; 404 for a missing
            algorithm, version or service; 400 if the path is taken; 500 on
            any other failure (after rolling the session back).
    """
    try:
        is_admin = hasattr(current_user, 'role_name') and current_user.role_name == "admin"
        if not is_admin:
            raise HTTPException(status_code=403, detail="Insufficient permissions")

        algorithm = db.query(Algorithm).filter_by(id=request.algorithm_id).first()
        if not algorithm:
            raise HTTPException(status_code=404, detail="算法不存在")

        version = db.query(AlgorithmVersion).filter_by(id=request.version_id).first()
        version_belongs = bool(version) and version.algorithm_id == request.algorithm_id
        if not version_belongs:
            raise HTTPException(status_code=404, detail="算法版本不存在")

        # The service link is optional; only verify it when one is supplied.
        if request.service_id:
            linked_service = db.query(AlgorithmService).filter_by(
                service_id=request.service_id
            ).first()
            if not linked_service:
                raise HTTPException(status_code=404, detail="服务不存在")

        path_taken = db.query(ApiEndpoint).filter_by(path=request.path).first()
        if path_taken:
            raise HTTPException(status_code=400, detail="API路径已存在")

        created = ApiEndpoint(
            name=request.name,
            description=request.description,
            path=request.path,
            method=request.method,
            algorithm_id=request.algorithm_id,
            version_id=request.version_id,
            service_id=request.service_id,
            requires_auth=request.requires_auth,
            allowed_roles=request.allowed_roles,
            rate_limit=request.rate_limit,
            is_public=request.is_public,
            config=request.config,
            status="active",
            # Usage counters start at zero and are stored as strings.
            call_count="0",
            success_count="0",
            error_count="0",
            avg_response_time="0.0"
        )

        db.add(created)
        db.commit()
        # Reload so database-generated columns (id, timestamps) are populated.
        db.refresh(created)

        return {
            "id": created.id,
            "name": created.name,
            "description": created.description,
            "path": created.path,
            "method": created.method,
            "algorithm_id": created.algorithm_id,
            "algorithm_name": algorithm.name,
            "version_id": created.version_id,
            "version": version.version,
            "service_id": created.service_id,
            "status": created.status,
            "is_public": created.is_public,
            "call_count": created.call_count,
            "success_count": created.success_count,
            "error_count": created.error_count,
            "avg_response_time": created.avg_response_time,
            "created_at": created.created_at,
            "updated_at": created.updated_at,
            "last_called_at": created.last_called_at
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"创建API端点失败: {str(e)}")
        db.rollback()
        raise HTTPException(status_code=500, detail=f"创建API端点失败: {str(e)}")
|
||
|
||
|
||
@router.put("/endpoints/{endpoint_id}", response_model=ApiEndpointResponse)
async def update_api_endpoint(
    endpoint_id: str,
    request: ApiEndpointUpdate,
    db: Session = Depends(get_db),
    current_user: UserResponse = Depends(get_current_active_user)
):
    """Partially update an API endpoint (admin only).

    Each field of ``request`` that is not None overwrites the matching
    attribute of the endpoint. None means "leave unchanged", so this route
    cannot clear a column to NULL.

    Raises:
        HTTPException: 403 for non-admin callers; 404 if the endpoint does
            not exist; 400 if the new ``path`` is used by another endpoint;
            500 on any other failure (after rolling the session back).
    """
    # Request fields that are copied onto the endpoint when not None.
    updatable_fields = (
        "name",
        "description",
        "path",
        "method",
        "requires_auth",
        "allowed_roles",
        "rate_limit",
        "is_public",
        "config",
        "status",
    )
    try:
        if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
            raise HTTPException(status_code=403, detail="Insufficient permissions")

        endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
        if not endpoint:
            raise HTTPException(status_code=404, detail="API端点不存在")

        # Check for a path conflict BEFORE modifying the endpoint. If other
        # attributes were assigned first, this query could autoflush those
        # partial changes (with SQLAlchemy's default autoflush), and the
        # HTTPException branch below re-raises without rolling back.
        if request.path is not None:
            conflicting = db.query(ApiEndpoint).filter(
                ApiEndpoint.path == request.path,
                ApiEndpoint.id != endpoint_id
            ).first()
            if conflicting:
                raise HTTPException(status_code=400, detail="API路径已存在")

        for field_name in updatable_fields:
            value = getattr(request, field_name)
            if value is not None:
                setattr(endpoint, field_name, value)

        db.commit()
        db.refresh(endpoint)

        # Look up display values for the response; "" if the row is gone.
        algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first()
        version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first()

        return {
            "id": endpoint.id,
            "name": endpoint.name,
            "description": endpoint.description,
            "path": endpoint.path,
            "method": endpoint.method,
            "algorithm_id": endpoint.algorithm_id,
            "algorithm_name": algorithm.name if algorithm else "",
            "version_id": endpoint.version_id,
            "version": version.version if version else "",
            "service_id": endpoint.service_id,
            "status": endpoint.status,
            "is_public": endpoint.is_public,
            "call_count": endpoint.call_count,
            "success_count": endpoint.success_count,
            "error_count": endpoint.error_count,
            "avg_response_time": endpoint.avg_response_time,
            "created_at": endpoint.created_at,
            "updated_at": endpoint.updated_at,
            "last_called_at": endpoint.last_called_at
        }
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception("更新API端点失败: %s", e)
        db.rollback()
        raise HTTPException(status_code=500, detail=f"更新API端点失败: {str(e)}") from e
|
||
|
||
|
||
@router.delete("/endpoints/{endpoint_id}")
async def delete_api_endpoint(
    endpoint_id: str,
    db: Session = Depends(get_db),
    current_user: UserResponse = Depends(get_current_active_user)
):
    """Delete an API endpoint (admin only).

    Returns:
        ``{"success": True, "message": ...}`` once the row is deleted and
        committed.

    Raises:
        HTTPException: 403 for non-admin callers; 404 if the endpoint does
            not exist; 500 on any other failure (after rolling back).
    """
    try:
        if not (hasattr(current_user, 'role_name') and current_user.role_name == "admin"):
            raise HTTPException(status_code=403, detail="Insufficient permissions")

        target = db.query(ApiEndpoint).filter_by(id=endpoint_id).first()
        if not target:
            raise HTTPException(status_code=404, detail="API端点不存在")

        db.delete(target)
        db.commit()

        outcome = {"success": True, "message": "API端点删除成功"}
        return outcome
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"删除API端点失败: {str(e)}")
        db.rollback()
        raise HTTPException(status_code=500, detail=f"删除API端点失败: {str(e)}")
|
||
|
||
|
||
@router.get("/stats", response_model=ApiStatsResponse)
async def get_api_stats(
    db: Session = Depends(get_db),
    current_user: UserResponse = Depends(get_current_active_user)
):
    """Aggregate call statistics over all API endpoints (admin only).

    The counters are stored as strings per endpoint; they are parsed and
    summed here. ``avg_response_time`` is the unweighted mean of the
    per-endpoint averages, considering only endpoints whose average is
    positive.

    Raises:
        HTTPException: 403 for non-admin callers; 500 on any other failure.
    """
    try:
        has_admin_role = hasattr(current_user, 'role_name') and current_user.role_name == "admin"
        if not has_admin_role:
            raise HTTPException(status_code=403, detail="Insufficient permissions")

        base_query = db.query(ApiEndpoint)
        total_endpoints = base_query.count()
        active_endpoints = base_query.filter(ApiEndpoint.status == "active").count()

        rows = base_query.all()

        def _as_int(raw):
            # Missing/None counters count as zero.
            return int(raw or 0)

        total_calls = sum(_as_int(row.call_count) for row in rows)
        total_success = sum(_as_int(row.success_count) for row in rows)
        total_errors = sum(_as_int(row.error_count) for row in rows)

        # Endpoints that were never called report 0 and are left out of the mean.
        parsed_times = (float(row.avg_response_time or 0) for row in rows)
        positive_times = [t for t in parsed_times if t > 0]
        if positive_times:
            mean_time = sum(positive_times) / len(positive_times)
        else:
            mean_time = 0.0

        return {
            "total_endpoints": total_endpoints,
            "active_endpoints": active_endpoints,
            "total_calls": str(total_calls),
            "total_success": str(total_success),
            "total_errors": str(total_errors),
            "avg_response_time": f"{mean_time:.2f}"
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取API统计信息失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"获取API统计信息失败: {str(e)}")
|
||
|
||
|
||
@router.post("/endpoints/{endpoint_id}/test")
async def test_api_endpoint(
    endpoint_id: str,
    payload: Dict[str, Any],
    db: Session = Depends(get_db),
    current_user: UserResponse = Depends(get_current_active_user)
):
    """Send ``payload`` to an endpoint's backing service and report the outcome.

    The request is POSTed as JSON to ``<service api_url>/predict`` with a
    30-second timeout. This route performs no role check; any active user
    may call it.

    Returns:
        On HTTP 200: ``{"success": True, "result": <service JSON>,
        "response_time": <seconds>, "message": ...}``.
        On any other HTTP status: ``{"success": False, "error": ...,
        "response_time": <seconds>}``.

    Raises:
        HTTPException: 404 if the endpoint does not exist; 400 if it is not
            active, has no linked service, the service is not running, or
            the service has no API URL; 500 on any other failure (including
            network errors from the outgoing call).
    """
    try:
        endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
        if not endpoint:
            raise HTTPException(status_code=404, detail="API端点不存在")

        if endpoint.status != "active":
            raise HTTPException(status_code=400, detail="API端点未激活")

        if not endpoint.service_id:
            raise HTTPException(status_code=400, detail="API端点未关联服务")

        service = db.query(AlgorithmService).filter(
            AlgorithmService.service_id == endpoint.service_id
        ).first()
        if not service or service.status != "running":
            raise HTTPException(status_code=400, detail="关联服务未运行")

        # Without this check a missing URL fails with an opaque
        # AttributeError on None.endswith(...) that surfaces as a 500.
        if not service.api_url:
            raise HTTPException(status_code=400, detail="关联服务未配置API地址")

        # httpx is only used by this route, so it is imported locally.
        import httpx
        import time

        service_url = service.api_url
        if not service_url.endswith("/"):
            service_url += "/"
        call_url = f"{service_url}predict"

        # perf_counter is monotonic; time.time() can jump when the system
        # clock is adjusted, producing bogus or negative durations.
        start_time = time.perf_counter()
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                call_url,
                json=payload,
                headers={"Content-Type": "application/json"}
            )
            response_time = time.perf_counter() - start_time

        if response.status_code == 200:
            return {
                "success": True,
                "result": response.json(),
                "response_time": response_time,
                "message": "API调用成功"
            }

        return {
            "success": False,
            "error": f"服务返回错误: HTTP {response.status_code}",
            "response_time": response_time
        }
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception("测试API端点失败: %s", e)
        raise HTTPException(status_code=500, detail=f"测试API端点失败: {str(e)}") from e