Files
algorithm/backend/app/routes/gateway.py
2026-02-08 14:42:58 +08:00

113 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""API网关路由处理算法调用的统一入口"""
from fastapi import APIRouter, Depends, HTTPException, status, Request
from typing import Dict, Any
import time
import logging
from app.gateway import api_gateway, call_algorithm_gateway
from app.models.database import get_db
from app.services.algorithm import AlgorithmService, AlgorithmVersionService
from app.schemas.algorithm import AlgorithmCallCreate, AlgorithmCallResult
# Router for all API-gateway endpoints: every path below is mounted under
# the "/gateway" prefix and tagged "gateway".
router = APIRouter(prefix="/gateway", tags=["gateway"])
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _ensure_algorithm_version_exists(algorithm_id: str, version_id: str) -> None:
    """Raise HTTP 404 unless the algorithm exists and owns the given version.

    ``get_db`` is consumed here as an iterator (outside FastAPI's dependency
    injection), so the generator is closed explicitly in ``finally``. Closing
    it runs whatever cleanup ``get_db`` performs after its ``yield``
    (presumably closing the DB session -- confirm in app.models.database)
    instead of leaking the session on every call.

    Raises:
        HTTPException: 404 if the algorithm is unknown, or if the version is
            unknown or belongs to a different algorithm.
    """
    db_gen = get_db()
    db = next(db_gen)
    try:
        algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
        if not algorithm:
            raise HTTPException(status_code=404, detail="Algorithm not found")
        version = AlgorithmVersionService.get_version_by_id(db, version_id)
        if not version or version.algorithm_id != algorithm_id:
            raise HTTPException(status_code=404, detail="Algorithm version not found")
    finally:
        db_gen.close()


@router.post("/call/{algorithm_id}/{version_id}")
async def call_algorithm_through_gateway(
    algorithm_id: str,
    version_id: str,
    request: Request,
    payload: Dict[Any, Any]
):
    """Call an algorithm version through the API gateway.

    This is the unified entry point for algorithm calls. In order it:
    authenticates the caller, verifies that the algorithm and version exist,
    applies the gateway rate limit, routes the payload to the algorithm
    service, and logs the call latency.

    Args:
        algorithm_id: ID of the algorithm to call (path parameter).
        version_id: ID of the version to call; must belong to ``algorithm_id``.
        request: The incoming request, used for authentication.
        payload: Request body, forwarded as-is to ``api_gateway.route_request``.

    Returns:
        A dict with ``success``, the routed ``result``, the echoed
        ``algorithm_id``/``version_id``, ``response_time`` (seconds spent in
        the routed call) and a Unix ``timestamp``.

    Raises:
        HTTPException: 401 if authentication fails, 404 if the algorithm or
            version is not found, 429 if the rate limit rejects the call, and
            500 (wrapping the original error) for anything else.
    """
    try:
        user_info = await api_gateway.authenticate_request(request)
        if not user_info:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")

        _ensure_algorithm_version_exists(algorithm_id, version_id)

        # No authorization check is performed yet: any authenticated user may
        # call any algorithm.
        # TODO: check the user's role against the algorithm's permission config.

        # check_rate_limit returns a truthy value when the call may proceed.
        allowed = await api_gateway.check_rate_limit(user_info['user_id'], algorithm_id)
        if not allowed:
            raise HTTPException(status_code=429, detail="Rate limit exceeded")

        # perf_counter is monotonic, so the measured latency cannot be skewed
        # by wall-clock adjustments (NTP sync, manual changes).
        start = time.perf_counter()
        result = await api_gateway.route_request(algorithm_id, version_id, payload)
        response_time = time.perf_counter() - start

        logger.info(
            "Algorithm %s (version %s) called by user %s, response time: %.2fs",
            algorithm_id, version_id, user_info['user_id'], response_time,
        )
        # TODO: persist a call record to the database (e.g. a call-record service).

        return {
            "success": True,
            "result": result,
            "algorithm_id": algorithm_id,
            "version_id": version_id,
            "response_time": response_time,
            "timestamp": time.time(),
        }
    except HTTPException:
        # Deliberate HTTP errors (401/404/429) pass through unchanged.
        raise
    except Exception as e:
        # logger.exception records the traceback; chaining keeps the root cause.
        logger.exception("Gateway error when calling algorithm %s: %s", algorithm_id, e)
        raise HTTPException(status_code=500, detail=f"Gateway error: {str(e)}") from e
@router.get("/health")
async def gateway_health():
    """Liveness probe for the API gateway.

    Always reports the gateway as healthy, together with its service name and
    the current Unix time.
    """
    report = {"status": "healthy", "service": "api-gateway"}
    report["timestamp"] = time.time()
    return report
@router.get("/stats")
async def get_gateway_stats(request: Request):
    """Return basic API-gateway usage statistics (admin only).

    The figures are derived from ``api_gateway.request_counts``. This code
    assumes, without verifying it here, that keys look like
    ``"<user_id>:<algorithm_id>"`` and that each value is a sized collection
    of recorded requests. TODO: confirm against app.gateway. If that store
    only keeps a rate-limit window, the "total" reflects that window rather
    than lifetime traffic.

    Raises:
        HTTPException: 401 if the request is not authenticated, 403 if the
            authenticated user is not an admin.
    """
    user_info = await api_gateway.authenticate_request(request)
    if not user_info:
        # Consistent with the call endpoint: missing/invalid auth is a 401.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
    if user_info.get('role') != 'admin':
        raise HTTPException(status_code=403, detail="Admin access required")

    request_counts = api_gateway.request_counts
    total_requests = sum(len(counts) for counts in request_counts.values())

    # Parse every key once, collecting distinct users and algorithms.
    users = set()
    algorithms = set()
    for key in request_counts:
        parts = key.split(':')
        users.add(parts[0])
        # A key without ':' carries no algorithm id; skip it rather than
        # failing the whole request with an IndexError.
        if len(parts) > 1:
            algorithms.add(parts[1])

    return {
        "total_requests_processed": total_requests,
        "active_users": len(users),
        "algorithms_accessed": len(algorithms),
        "rate_limit_blocks": 0,  # TODO: track the number of rate-limited (blocked) requests
        "uptime": "N/A",  # TODO: report time elapsed since gateway start-up
    }