Files
algorithm/backend/app/gateway.py
2026-02-08 14:42:58 +08:00

144 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""API网关模块负责请求路由、认证授权、流量控制等功能"""
from fastapi import Request, HTTPException, status, Depends
from fastapi.responses import JSONResponse
from typing import Dict, Any, Optional
import time
import asyncio
import logging
from urllib.parse import urljoin
import httpx
from app.config.settings import settings
# Note: get_current_active_user is not used in this module, removing the import
from app.models.database import get_db
from app.services.algorithm import AlgorithmVersionService
logger = logging.getLogger(__name__)
class APIGateway:
"""API网关类处理请求路由、认证授权、流量控制等功能"""
def __init__(self):
self.request_counts = {} # 存储请求计数实际生产环境应使用Redis
async def authenticate_request(self, request: Request) -> Optional[Dict[str, Any]]:
"""验证请求的认证信息"""
try:
# 从请求头获取token
token = request.headers.get("Authorization", "").replace("Bearer ", "")
if not token:
return None
# 验证token有效性这里只是示例实际应该调用认证服务
# 在实际实现中可能需要连接数据库验证API密钥
# 或者解析JWT token
return {"user_id": "temp_user_id", "role": "user"} # 临时返回值
except Exception as e:
logger.error(f"Authentication error: {str(e)}")
return None
async def check_rate_limit(self, user_id: str, algorithm_id: str) -> bool:
"""检查用户对特定算法的请求频率限制"""
key = f"{user_id}:{algorithm_id}"
current_time = time.time()
if key not in self.request_counts:
self.request_counts[key] = []
# 清除超过时间窗口的请求记录
self.request_counts[key] = [
req_time for req_time in self.request_counts[key]
if current_time - req_time < settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 # 30分钟窗口
]
# 检查是否超过限制例如每分钟最多10次请求
if len(self.request_counts[key]) >= 10:
return False
# 添加当前请求记录
self.request_counts[key].append(current_time)
return True
async def route_request(self, algorithm_id: str, version_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
"""路由请求到相应的算法服务"""
try:
# 获取算法版本信息
db = next(get_db())
version_info = AlgorithmVersionService.get_version_by_id(db, version_id)
if not version_info:
raise HTTPException(status_code=404, detail="Algorithm version not found")
# 在实际实现中这里会根据version_info.url将请求转发到对应的算法服务
# 现在我们模拟调用过程
algorithm_url = version_info.url if hasattr(version_info, 'url') else f"http://localhost:8001/algorithms/{algorithm_id}/execute"
# 使用httpx调用算法服务
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
algorithm_url,
json=payload,
headers={"Content-Type": "application/json"}
)
if response.status_code != 200:
raise HTTPException(
status_code=response.status_code,
detail=f"Algorithm service error: {response.text}"
)
return response.json()
except httpx.RequestError as e:
logger.error(f"Request to algorithm service failed: {str(e)}")
raise HTTPException(status_code=502, detail="Algorithm service unavailable")
except Exception as e:
logger.error(f"Routing error: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
# 全局API网关实例
api_gateway = APIGateway()
async def gateway_middleware(request: Request):
"""网关中间件,处理认证和路由"""
# 认证检查
user_info = await api_gateway.authenticate_request(request)
if not user_info:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
# 这里可以添加其他中间件逻辑,如权限检查、流量控制等
return user_info
async def call_algorithm_gateway(
algorithm_id: str,
version_id: str,
payload: Dict[str, Any],
user_info: Dict[str, Any] = Depends(gateway_middleware)
):
"""通过网关调用算法的主函数"""
# 检查速率限制
rate_limited = await api_gateway.check_rate_limit(user_info['user_id'], algorithm_id)
if not rate_limited:
raise HTTPException(status_code=429, detail="Rate limit exceeded")
# 路由请求到算法服务
result = await api_gateway.route_request(algorithm_id, version_id, payload)
return result
# 辅助函数:获取用户信息而不强制要求认证
async def get_optional_user(request: Request):
"""获取用户信息如果未认证则返回None"""
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
token = auth_header.replace("Bearer ", "")
# 这里应该是实际的token验证逻辑
return {"user_id": "temp_user_id", "role": "user"} # 临时返回值
return None