159 lines
6.2 KiB
Python
159 lines
6.2 KiB
Python
"""API网关模块,负责请求路由、认证授权、流量控制等功能"""
|
||
|
||
from fastapi import Request, HTTPException, status, Depends
|
||
from fastapi.responses import JSONResponse
|
||
from typing import Dict, Any, Optional
|
||
import time
|
||
import asyncio
|
||
import logging
|
||
from urllib.parse import urljoin
|
||
import httpx
|
||
|
||
from app.config.settings import settings
|
||
# Note: get_current_active_user is not used in this module, removing the import
|
||
from app.models.database import get_db
|
||
from app.services.algorithm import AlgorithmVersionService
|
||
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class APIGateway:
|
||
"""API网关类,处理请求路由、认证授权、流量控制等功能"""
|
||
|
||
def __init__(self):
|
||
self.request_counts = {} # 存储请求计数,实际生产环境应使用Redis
|
||
|
||
async def authenticate_request(self, request: Request) -> Optional[Dict[str, Any]]:
|
||
"""验证请求的认证信息"""
|
||
try:
|
||
# 从请求头获取token
|
||
token = request.headers.get("Authorization", "").replace("Bearer ", "")
|
||
if not token:
|
||
return None
|
||
|
||
# 验证token有效性(这里只是示例,实际应该调用认证服务)
|
||
# 在实际实现中,可能需要连接数据库验证API密钥
|
||
# 或者解析JWT token
|
||
return {"user_id": "temp_user_id", "role": "user"} # 临时返回值
|
||
except Exception as e:
|
||
logger.error(f"Authentication error: {str(e)}")
|
||
return None
|
||
|
||
async def check_rate_limit(self, user_id: str, algorithm_id: str) -> bool:
|
||
"""检查用户对特定算法的请求频率限制"""
|
||
key = f"{user_id}:{algorithm_id}"
|
||
current_time = time.time()
|
||
|
||
if key not in self.request_counts:
|
||
self.request_counts[key] = []
|
||
|
||
# 清除超过时间窗口的请求记录
|
||
self.request_counts[key] = [
|
||
req_time for req_time in self.request_counts[key]
|
||
if current_time - req_time < settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 # 30分钟窗口
|
||
]
|
||
|
||
# 检查是否超过限制(例如每分钟最多10次请求)
|
||
if len(self.request_counts[key]) >= 10:
|
||
return False
|
||
|
||
# 添加当前请求记录
|
||
self.request_counts[key].append(current_time)
|
||
return True
|
||
|
||
async def route_request(self, algorithm_id: str, version_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""路由请求到相应的算法服务"""
|
||
try:
|
||
# 获取算法版本信息
|
||
db = next(get_db())
|
||
version_info = AlgorithmVersionService.get_version_by_id(db, version_id)
|
||
|
||
if not version_info:
|
||
raise HTTPException(status_code=404, detail="Algorithm version not found")
|
||
|
||
# 尝试从算法版本获取URL,如果没有则尝试从服务表获取
|
||
algorithm_url = None
|
||
|
||
# 首先检查版本信息中是否有URL
|
||
if hasattr(version_info, 'url') and version_info.url:
|
||
algorithm_url = version_info.url
|
||
else:
|
||
# 如果版本信息中没有URL,尝试从服务表获取
|
||
from app.models.models import AlgorithmService
|
||
service = db.query(AlgorithmService).filter(
|
||
AlgorithmService.algorithm_name == algorithm_id
|
||
).first()
|
||
|
||
if service and service.api_url:
|
||
algorithm_url = service.api_url
|
||
else:
|
||
# 如果都没有,使用默认的本地端点
|
||
algorithm_url = f"http://localhost:8001/algorithms/{algorithm_id}/execute"
|
||
|
||
# 使用httpx调用算法服务
|
||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||
response = await client.post(
|
||
algorithm_url,
|
||
json=payload,
|
||
headers={"Content-Type": "application/json"}
|
||
)
|
||
|
||
if response.status_code != 200:
|
||
raise HTTPException(
|
||
status_code=response.status_code,
|
||
detail=f"Algorithm service error: {response.text}"
|
||
)
|
||
|
||
return response.json()
|
||
|
||
except httpx.RequestError as e:
|
||
logger.error(f"Request to algorithm service failed: {str(e)}")
|
||
raise HTTPException(status_code=502, detail="Algorithm service unavailable")
|
||
except Exception as e:
|
||
logger.error(f"Routing error: {str(e)}")
|
||
raise HTTPException(status_code=500, detail="Internal server error")
|
||
|
||
|
||
# 全局API网关实例
|
||
api_gateway = APIGateway()
|
||
|
||
|
||
async def gateway_middleware(request: Request):
|
||
"""网关中间件,处理认证和路由"""
|
||
# 认证检查
|
||
user_info = await api_gateway.authenticate_request(request)
|
||
if not user_info:
|
||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
|
||
|
||
# 这里可以添加其他中间件逻辑,如权限检查、流量控制等
|
||
return user_info
|
||
|
||
|
||
async def call_algorithm_gateway(
|
||
algorithm_id: str,
|
||
version_id: str,
|
||
payload: Dict[str, Any],
|
||
user_info: Dict[str, Any] = Depends(gateway_middleware)
|
||
):
|
||
"""通过网关调用算法的主函数"""
|
||
# 检查速率限制
|
||
rate_limited = await api_gateway.check_rate_limit(user_info['user_id'], algorithm_id)
|
||
if not rate_limited:
|
||
raise HTTPException(status_code=429, detail="Rate limit exceeded")
|
||
|
||
# 路由请求到算法服务
|
||
result = await api_gateway.route_request(algorithm_id, version_id, payload)
|
||
|
||
return result
|
||
|
||
|
||
# 辅助函数:获取用户信息而不强制要求认证
|
||
async def get_optional_user(request: Request):
|
||
"""获取用户信息,如果未认证则返回None"""
|
||
auth_header = request.headers.get("Authorization", "")
|
||
if auth_header.startswith("Bearer "):
|
||
token = auth_header.replace("Bearer ", "")
|
||
# 这里应该是实际的token验证逻辑
|
||
return {"user_id": "temp_user_id", "role": "user"} # 临时返回值
|
||
return None |