"""API网关模块,负责请求路由、认证授权、流量控制等功能""" from fastapi import Request, HTTPException, status, Depends from fastapi.responses import JSONResponse from typing import Dict, Any, Optional import time import asyncio import logging from urllib.parse import urljoin import httpx from app.config.settings import settings # Note: get_current_active_user is not used in this module, removing the import from app.models.database import get_db from app.services.algorithm import AlgorithmVersionService logger = logging.getLogger(__name__) class APIGateway: """API网关类,处理请求路由、认证授权、流量控制等功能""" def __init__(self): self.request_counts = {} # 存储请求计数,实际生产环境应使用Redis async def authenticate_request(self, request: Request) -> Optional[Dict[str, Any]]: """验证请求的认证信息""" try: # 从请求头获取token token = request.headers.get("Authorization", "").replace("Bearer ", "") if not token: return None # 验证token有效性(这里只是示例,实际应该调用认证服务) # 在实际实现中,可能需要连接数据库验证API密钥 # 或者解析JWT token return {"user_id": "temp_user_id", "role": "user"} # 临时返回值 except Exception as e: logger.error(f"Authentication error: {str(e)}") return None async def check_rate_limit(self, user_id: str, algorithm_id: str) -> bool: """检查用户对特定算法的请求频率限制""" key = f"{user_id}:{algorithm_id}" current_time = time.time() if key not in self.request_counts: self.request_counts[key] = [] # 清除超过时间窗口的请求记录 self.request_counts[key] = [ req_time for req_time in self.request_counts[key] if current_time - req_time < settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 # 30分钟窗口 ] # 检查是否超过限制(例如每分钟最多10次请求) if len(self.request_counts[key]) >= 10: return False # 添加当前请求记录 self.request_counts[key].append(current_time) return True async def route_request(self, algorithm_id: str, version_id: str, payload: Dict[str, Any]) -> Dict[str, Any]: """路由请求到相应的算法服务""" try: # 获取算法版本信息 db = next(get_db()) version_info = AlgorithmVersionService.get_version_by_id(db, version_id) if not version_info: raise HTTPException(status_code=404, detail="Algorithm version not found") # 在实际实现中,这里会根据version_info.url将请求转发到对应的算法服务 # 现在我们模拟调用过程 algorithm_url = version_info.url if hasattr(version_info, 'url') else f"http://localhost:8001/algorithms/{algorithm_id}/execute" # 使用httpx调用算法服务 async with httpx.AsyncClient(timeout=30.0) as client: response = await client.post( algorithm_url, json=payload, headers={"Content-Type": "application/json"} ) if response.status_code != 200: raise HTTPException( status_code=response.status_code, detail=f"Algorithm service error: {response.text}" ) return response.json() except httpx.RequestError as e: logger.error(f"Request to algorithm service failed: {str(e)}") raise HTTPException(status_code=502, detail="Algorithm service unavailable") except Exception as e: logger.error(f"Routing error: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error") # 全局API网关实例 api_gateway = APIGateway() async def gateway_middleware(request: Request): """网关中间件,处理认证和路由""" # 认证检查 user_info = await api_gateway.authenticate_request(request) if not user_info: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized") # 这里可以添加其他中间件逻辑,如权限检查、流量控制等 return user_info async def call_algorithm_gateway( algorithm_id: str, version_id: str, payload: Dict[str, Any], user_info: Dict[str, Any] = Depends(gateway_middleware) ): """通过网关调用算法的主函数""" # 检查速率限制 rate_limited = await api_gateway.check_rate_limit(user_info['user_id'], algorithm_id) if not rate_limited: raise HTTPException(status_code=429, detail="Rate limit exceeded") # 路由请求到算法服务 result = await api_gateway.route_request(algorithm_id, version_id, payload) return result # 辅助函数:获取用户信息而不强制要求认证 async def get_optional_user(request: Request): """获取用户信息,如果未认证则返回None""" auth_header = request.headers.get("Authorization", "") if auth_header.startswith("Bearer "): token = auth_header.replace("Bearer ", "") # 这里应该是实际的token验证逻辑 return {"user_id": "temp_user_id", "role": "user"} # 临时返回值 return None