113 lines
4.2 KiB
Python
113 lines
4.2 KiB
Python
"""API网关路由,处理算法调用的统一入口"""
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||
from typing import Dict, Any
|
||
import time
|
||
import logging
|
||
|
||
from app.gateway import api_gateway, call_algorithm_gateway
|
||
from app.models.database import get_db
|
||
from app.services.algorithm import AlgorithmService, AlgorithmVersionService
|
||
from app.schemas.algorithm import AlgorithmCallCreate, AlgorithmCallResult
|
||
|
||
router = APIRouter(prefix="/gateway", tags=["gateway"])
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@router.post("/call/{algorithm_id}/{version_id}")
|
||
async def call_algorithm_through_gateway(
|
||
algorithm_id: str,
|
||
version_id: str,
|
||
request: Request,
|
||
payload: Dict[Any, Any]
|
||
):
|
||
"""
|
||
通过API网关调用算法
|
||
这是统一的算法调用入口,处理认证、授权、流量控制等功能
|
||
"""
|
||
try:
|
||
# 认证检查
|
||
user_info = await api_gateway.authenticate_request(request)
|
||
if not user_info:
|
||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
|
||
|
||
# 验证算法和版本是否存在
|
||
db = next(get_db())
|
||
algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
|
||
if not algorithm:
|
||
raise HTTPException(status_code=404, detail="Algorithm not found")
|
||
|
||
version = AlgorithmVersionService.get_version_by_id(db, version_id)
|
||
if not version or version.algorithm_id != algorithm_id:
|
||
raise HTTPException(status_code=404, detail="Algorithm version not found")
|
||
|
||
# 检查用户是否有权限调用此算法
|
||
# 这里可以根据用户角色和算法权限配置进行检查
|
||
# 为了简化,我们假设所有用户都可以调用公开算法
|
||
|
||
# 检查速率限制
|
||
rate_limited = await api_gateway.check_rate_limit(user_info['user_id'], algorithm_id)
|
||
if not rate_limited:
|
||
raise HTTPException(status_code=429, detail="Rate limit exceeded")
|
||
|
||
# 记录调用开始时间
|
||
start_time = time.time()
|
||
|
||
# 路由请求到算法服务
|
||
result = await api_gateway.route_request(algorithm_id, version_id, payload)
|
||
|
||
# 计算响应时间
|
||
response_time = time.time() - start_time
|
||
|
||
# 记录调用日志
|
||
logger.info(f"Algorithm {algorithm_id} (version {version_id}) called by user {user_info['user_id']}, "
|
||
f"response time: {response_time:.2f}s")
|
||
|
||
# 这里可以添加调用记录到数据库的逻辑
|
||
# AlgorithmCallService.create_call_record(...)
|
||
|
||
return {
|
||
"success": True,
|
||
"result": result,
|
||
"algorithm_id": algorithm_id,
|
||
"version_id": version_id,
|
||
"response_time": response_time,
|
||
"timestamp": time.time()
|
||
}
|
||
|
||
except HTTPException:
|
||
# 重新抛出HTTP异常
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"Gateway error when calling algorithm {algorithm_id}: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"Gateway error: {str(e)}")
|
||
|
||
|
||
@router.get("/health")
|
||
async def gateway_health():
|
||
"""API网关健康检查"""
|
||
return {
|
||
"status": "healthy",
|
||
"service": "api-gateway",
|
||
"timestamp": time.time()
|
||
}
|
||
|
||
|
||
@router.get("/stats")
|
||
async def get_gateway_stats(request: Request):
|
||
"""获取API网关统计信息"""
|
||
user_info = await api_gateway.authenticate_request(request)
|
||
if not user_info or user_info.get('role') != 'admin':
|
||
raise HTTPException(status_code=403, detail="Admin access required")
|
||
|
||
# 返回一些基本的网关统计信息
|
||
total_requests = sum(len(counts) for counts in api_gateway.request_counts.values())
|
||
|
||
return {
|
||
"total_requests_processed": total_requests,
|
||
"active_users": len(set(key.split(':')[0] for key in api_gateway.request_counts.keys())),
|
||
"algorithms_accessed": len(set(key.split(':')[1] for key in api_gateway.request_counts.keys())),
|
||
"rate_limit_blocks": 0, # 在实际实现中,这里应该跟踪被阻止的请求数
|
||
"uptime": "N/A" # 在实际实现中,这里应该是自启动以来的运行时间
|
||
} |