good version for 算法注册
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from fastapi import APIRouter
|
||||
from app.routes import user, algorithm, history, gateway, monitoring, openai, deployment
|
||||
from app.routes import user, algorithm, history, gateway, monitoring, openai, deployment, config, comparison
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -11,3 +11,5 @@ api_router.include_router(gateway.router, prefix="/gateway", tags=["gateway"])
|
||||
api_router.include_router(monitoring.router, prefix="/monitoring", tags=["monitoring"])
|
||||
api_router.include_router(openai.router, prefix="/openai", tags=["openai"])
|
||||
api_router.include_router(deployment.router, tags=["deployment"])
|
||||
api_router.include_router(config.router, tags=["config"])
|
||||
api_router.include_router(comparison.router, tags=["comparison"])
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
backend/app/routes/__pycache__/api_management.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/api_management.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/comparison.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/comparison.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/config.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/config.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -33,6 +33,18 @@ async def create_algorithm(
|
||||
|
||||
|
||||
@router.get("", response_model=AlgorithmListResponse)
|
||||
async def get_algorithms_no_slash(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
type: Optional[str] = None,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取算法列表(不带末尾斜杠)"""
|
||||
algorithms = AlgorithmService.get_algorithms(db, skip=skip, limit=limit, algorithm_type=type)
|
||||
return {"algorithms": algorithms, "total": len(algorithms)}
|
||||
|
||||
|
||||
@router.get("/", response_model=AlgorithmListResponse)
|
||||
async def get_algorithms(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
|
||||
510
backend/app/routes/api_management.py
Normal file
510
backend/app/routes/api_management.py
Normal file
@@ -0,0 +1,510 @@
|
||||
"""API管理路由,处理API端点的封装和管理"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.database import get_db
|
||||
from app.models.models import Algorithm, AlgorithmVersion, AlgorithmService, User
|
||||
from app.models.api import ApiEndpoint, ApiCallLog
|
||||
from app.schemas.user import UserResponse
|
||||
from app.routes.user import get_current_active_user
|
||||
|
||||
router = APIRouter(prefix="/api-management", tags=["api-management"])
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApiEndpointCreate(BaseModel):
|
||||
"""创建API端点请求模型"""
|
||||
name: str
|
||||
description: str = ""
|
||||
path: str
|
||||
method: str = "POST"
|
||||
algorithm_id: str
|
||||
version_id: str
|
||||
service_id: Optional[str] = None
|
||||
requires_auth: bool = True
|
||||
allowed_roles: List[str] = []
|
||||
rate_limit: Optional[Dict[str, Any]] = None
|
||||
is_public: bool = False
|
||||
config: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class ApiEndpointUpdate(BaseModel):
|
||||
"""更新API端点请求模型"""
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
path: Optional[str] = None
|
||||
method: Optional[str] = None
|
||||
requires_auth: Optional[bool] = None
|
||||
allowed_roles: Optional[List[str]] = None
|
||||
rate_limit: Optional[Dict[str, Any]] = None
|
||||
is_public: Optional[bool] = None
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
status: Optional[str] = None
|
||||
|
||||
|
||||
class ApiEndpointResponse(BaseModel):
|
||||
"""API端点响应模型"""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
path: str
|
||||
method: str
|
||||
algorithm_id: str
|
||||
algorithm_name: str
|
||||
version_id: str
|
||||
version: str
|
||||
service_id: Optional[str]
|
||||
status: str
|
||||
is_public: bool
|
||||
call_count: str
|
||||
success_count: str
|
||||
error_count: str
|
||||
avg_response_time: str
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
last_called_at: Optional[datetime]
|
||||
|
||||
|
||||
class ApiEndpointListResponse(BaseModel):
|
||||
"""API端点列表响应模型"""
|
||||
endpoints: List[ApiEndpointResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class ApiStatsResponse(BaseModel):
|
||||
"""API统计响应模型"""
|
||||
total_endpoints: int
|
||||
active_endpoints: int
|
||||
total_calls: str
|
||||
total_success: str
|
||||
total_errors: str
|
||||
avg_response_time: str
|
||||
|
||||
|
||||
@router.get("/endpoints", response_model=ApiEndpointListResponse)
|
||||
async def get_api_endpoints(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
algorithm_id: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取API端点列表"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = db.query(ApiEndpoint)
|
||||
|
||||
# 筛选条件
|
||||
if algorithm_id:
|
||||
query = query.filter(ApiEndpoint.algorithm_id == algorithm_id)
|
||||
if status:
|
||||
query = query.filter(ApiEndpoint.status == status)
|
||||
|
||||
# 分页
|
||||
endpoints = query.offset(skip).limit(limit).all()
|
||||
total = query.count()
|
||||
|
||||
# 构建响应
|
||||
endpoint_responses = []
|
||||
for endpoint in endpoints:
|
||||
# 获取关联的算法和版本信息
|
||||
algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first()
|
||||
version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first()
|
||||
|
||||
endpoint_responses.append({
|
||||
"id": endpoint.id,
|
||||
"name": endpoint.name,
|
||||
"description": endpoint.description,
|
||||
"path": endpoint.path,
|
||||
"method": endpoint.method,
|
||||
"algorithm_id": endpoint.algorithm_id,
|
||||
"algorithm_name": algorithm.name if algorithm else "",
|
||||
"version_id": endpoint.version_id,
|
||||
"version": version.version if version else "",
|
||||
"service_id": endpoint.service_id,
|
||||
"status": endpoint.status,
|
||||
"is_public": endpoint.is_public,
|
||||
"call_count": endpoint.call_count,
|
||||
"success_count": endpoint.success_count,
|
||||
"error_count": endpoint.error_count,
|
||||
"avg_response_time": endpoint.avg_response_time,
|
||||
"created_at": endpoint.created_at,
|
||||
"updated_at": endpoint.updated_at,
|
||||
"last_called_at": endpoint.last_called_at
|
||||
})
|
||||
|
||||
return {
|
||||
"endpoints": endpoint_responses,
|
||||
"total": total
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取API端点列表失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取API端点列表失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/endpoints/{endpoint_id}", response_model=ApiEndpointResponse)
|
||||
async def get_api_endpoint(
|
||||
endpoint_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取API端点详情"""
|
||||
try:
|
||||
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
|
||||
|
||||
if not endpoint:
|
||||
raise HTTPException(status_code=404, detail="API端点不存在")
|
||||
|
||||
# 获取关联的算法和版本信息
|
||||
algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first()
|
||||
version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first()
|
||||
|
||||
return {
|
||||
"id": endpoint.id,
|
||||
"name": endpoint.name,
|
||||
"description": endpoint.description,
|
||||
"path": endpoint.path,
|
||||
"method": endpoint.method,
|
||||
"algorithm_id": endpoint.algorithm_id,
|
||||
"algorithm_name": algorithm.name if algorithm else "",
|
||||
"version_id": endpoint.version_id,
|
||||
"version": version.version if version else "",
|
||||
"service_id": endpoint.service_id,
|
||||
"status": endpoint.status,
|
||||
"is_public": endpoint.is_public,
|
||||
"call_count": endpoint.call_count,
|
||||
"success_count": endpoint.success_count,
|
||||
"error_count": endpoint.error_count,
|
||||
"avg_response_time": endpoint.avg_response_time,
|
||||
"created_at": endpoint.created_at,
|
||||
"updated_at": endpoint.updated_at,
|
||||
"last_called_at": endpoint.last_called_at
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取API端点详情失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取API端点详情失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/endpoints", response_model=ApiEndpointResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_api_endpoint(
|
||||
request: ApiEndpointCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""创建API端点"""
|
||||
try:
|
||||
# 检查用户权限
|
||||
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 验证算法和版本是否存在
|
||||
algorithm = db.query(Algorithm).filter(Algorithm.id == request.algorithm_id).first()
|
||||
if not algorithm:
|
||||
raise HTTPException(status_code=404, detail="算法不存在")
|
||||
|
||||
version = db.query(AlgorithmVersion).filter(
|
||||
AlgorithmVersion.id == request.version_id
|
||||
).first()
|
||||
if not version or version.algorithm_id != request.algorithm_id:
|
||||
raise HTTPException(status_code=404, detail="算法版本不存在")
|
||||
|
||||
# 如果指定了服务ID,验证服务是否存在
|
||||
if request.service_id:
|
||||
service = db.query(AlgorithmService).filter(
|
||||
AlgorithmService.service_id == request.service_id
|
||||
).first()
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="服务不存在")
|
||||
|
||||
# 检查API路径是否已存在
|
||||
existing_endpoint = db.query(ApiEndpoint).filter(
|
||||
ApiEndpoint.path == request.path
|
||||
).first()
|
||||
if existing_endpoint:
|
||||
raise HTTPException(status_code=400, detail="API路径已存在")
|
||||
|
||||
# 创建API端点
|
||||
new_endpoint = ApiEndpoint(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
path=request.path,
|
||||
method=request.method,
|
||||
algorithm_id=request.algorithm_id,
|
||||
version_id=request.version_id,
|
||||
service_id=request.service_id,
|
||||
requires_auth=request.requires_auth,
|
||||
allowed_roles=request.allowed_roles,
|
||||
rate_limit=request.rate_limit,
|
||||
is_public=request.is_public,
|
||||
config=request.config,
|
||||
status="active",
|
||||
call_count="0",
|
||||
success_count="0",
|
||||
error_count="0",
|
||||
avg_response_time="0.0"
|
||||
)
|
||||
|
||||
db.add(new_endpoint)
|
||||
db.commit()
|
||||
db.refresh(new_endpoint)
|
||||
|
||||
# 返回创建的API端点
|
||||
return {
|
||||
"id": new_endpoint.id,
|
||||
"name": new_endpoint.name,
|
||||
"description": new_endpoint.description,
|
||||
"path": new_endpoint.path,
|
||||
"method": new_endpoint.method,
|
||||
"algorithm_id": new_endpoint.algorithm_id,
|
||||
"algorithm_name": algorithm.name,
|
||||
"version_id": new_endpoint.version_id,
|
||||
"version": version.version,
|
||||
"service_id": new_endpoint.service_id,
|
||||
"status": new_endpoint.status,
|
||||
"is_public": new_endpoint.is_public,
|
||||
"call_count": new_endpoint.call_count,
|
||||
"success_count": new_endpoint.success_count,
|
||||
"error_count": new_endpoint.error_count,
|
||||
"avg_response_time": new_endpoint.avg_response_time,
|
||||
"created_at": new_endpoint.created_at,
|
||||
"updated_at": new_endpoint.updated_at,
|
||||
"last_called_at": new_endpoint.last_called_at
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"创建API端点失败: {str(e)}")
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"创建API端点失败: {str(e)}")
|
||||
|
||||
|
||||
@router.put("/endpoints/{endpoint_id}", response_model=ApiEndpointResponse)
|
||||
async def update_api_endpoint(
|
||||
endpoint_id: str,
|
||||
request: ApiEndpointUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""更新API端点"""
|
||||
try:
|
||||
# 检查用户权限
|
||||
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 查询API端点
|
||||
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
|
||||
if not endpoint:
|
||||
raise HTTPException(status_code=404, detail="API端点不存在")
|
||||
|
||||
# 更新字段
|
||||
if request.name is not None:
|
||||
endpoint.name = request.name
|
||||
if request.description is not None:
|
||||
endpoint.description = request.description
|
||||
if request.path is not None:
|
||||
# 检查新路径是否已被其他端点使用
|
||||
existing_endpoint = db.query(ApiEndpoint).filter(
|
||||
ApiEndpoint.path == request.path,
|
||||
ApiEndpoint.id != endpoint_id
|
||||
).first()
|
||||
if existing_endpoint:
|
||||
raise HTTPException(status_code=400, detail="API路径已存在")
|
||||
endpoint.path = request.path
|
||||
if request.method is not None:
|
||||
endpoint.method = request.method
|
||||
if request.requires_auth is not None:
|
||||
endpoint.requires_auth = request.requires_auth
|
||||
if request.allowed_roles is not None:
|
||||
endpoint.allowed_roles = request.allowed_roles
|
||||
if request.rate_limit is not None:
|
||||
endpoint.rate_limit = request.rate_limit
|
||||
if request.is_public is not None:
|
||||
endpoint.is_public = request.is_public
|
||||
if request.config is not None:
|
||||
endpoint.config = request.config
|
||||
if request.status is not None:
|
||||
endpoint.status = request.status
|
||||
|
||||
db.commit()
|
||||
db.refresh(endpoint)
|
||||
|
||||
# 获取关联的算法和版本信息
|
||||
algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first()
|
||||
version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first()
|
||||
|
||||
return {
|
||||
"id": endpoint.id,
|
||||
"name": endpoint.name,
|
||||
"description": endpoint.description,
|
||||
"path": endpoint.path,
|
||||
"method": endpoint.method,
|
||||
"algorithm_id": endpoint.algorithm_id,
|
||||
"algorithm_name": algorithm.name if algorithm else "",
|
||||
"version_id": endpoint.version_id,
|
||||
"version": version.version if version else "",
|
||||
"service_id": endpoint.service_id,
|
||||
"status": endpoint.status,
|
||||
"is_public": endpoint.is_public,
|
||||
"call_count": endpoint.call_count,
|
||||
"success_count": endpoint.success_count,
|
||||
"error_count": endpoint.error_count,
|
||||
"avg_response_time": endpoint.avg_response_time,
|
||||
"created_at": endpoint.created_at,
|
||||
"updated_at": endpoint.updated_at,
|
||||
"last_called_at": endpoint.last_called_at
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新API端点失败: {str(e)}")
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"更新API端点失败: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/endpoints/{endpoint_id}")
|
||||
async def delete_api_endpoint(
|
||||
endpoint_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""删除API端点"""
|
||||
try:
|
||||
# 检查用户权限
|
||||
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 查询API端点
|
||||
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
|
||||
if not endpoint:
|
||||
raise HTTPException(status_code=404, detail="API端点不存在")
|
||||
|
||||
# 删除API端点
|
||||
db.delete(endpoint)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "API端点删除成功"
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"删除API端点失败: {str(e)}")
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"删除API端点失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/stats", response_model=ApiStatsResponse)
|
||||
async def get_api_stats(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取API统计信息"""
|
||||
try:
|
||||
# 检查用户权限
|
||||
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 统计API端点
|
||||
total_endpoints = db.query(ApiEndpoint).count()
|
||||
active_endpoints = db.query(ApiEndpoint).filter(ApiEndpoint.status == "active").count()
|
||||
|
||||
# 统计调用次数
|
||||
endpoints = db.query(ApiEndpoint).all()
|
||||
total_calls = sum(int(e.call_count or 0) for e in endpoints)
|
||||
total_success = sum(int(e.success_count or 0) for e in endpoints)
|
||||
total_errors = sum(int(e.error_count or 0) for e in endpoints)
|
||||
|
||||
# 计算平均响应时间
|
||||
avg_response_times = [float(e.avg_response_time or 0) for e in endpoints if float(e.avg_response_time or 0) > 0]
|
||||
avg_response_time = sum(avg_response_times) / len(avg_response_times) if avg_response_times else 0.0
|
||||
|
||||
return {
|
||||
"total_endpoints": total_endpoints,
|
||||
"active_endpoints": active_endpoints,
|
||||
"total_calls": str(total_calls),
|
||||
"total_success": str(total_success),
|
||||
"total_errors": str(total_errors),
|
||||
"avg_response_time": f"{avg_response_time:.2f}"
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取API统计信息失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取API统计信息失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/endpoints/{endpoint_id}/test")
|
||||
async def test_api_endpoint(
|
||||
endpoint_id: str,
|
||||
payload: Dict[str, Any],
|
||||
db: Session = Depends(get_db),
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""测试API端点"""
|
||||
try:
|
||||
# 查询API端点
|
||||
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
|
||||
if not endpoint:
|
||||
raise HTTPException(status_code=404, detail="API端点不存在")
|
||||
|
||||
# 检查API端点状态
|
||||
if endpoint.status != "active":
|
||||
raise HTTPException(status_code=400, detail="API端点未激活")
|
||||
|
||||
# 查询关联的服务
|
||||
if endpoint.service_id:
|
||||
service = db.query(AlgorithmService).filter(
|
||||
AlgorithmService.service_id == endpoint.service_id
|
||||
).first()
|
||||
if not service or service.status != "running":
|
||||
raise HTTPException(status_code=400, detail="关联服务未运行")
|
||||
|
||||
# 调用服务
|
||||
import httpx
|
||||
import time
|
||||
|
||||
service_url = service.api_url
|
||||
if not service_url.endswith("/"):
|
||||
service_url += "/"
|
||||
|
||||
call_url = f"{service_url}predict"
|
||||
|
||||
start_time = time.time()
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(
|
||||
call_url,
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
response_time = time.time() - start_time
|
||||
|
||||
if response.status_code == 200:
|
||||
return {
|
||||
"success": True,
|
||||
"result": response.json(),
|
||||
"response_time": response_time,
|
||||
"message": "API调用成功"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"服务返回错误: HTTP {response.status_code}",
|
||||
"response_time": response_time
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="API端点未关联服务")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"测试API端点失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"测试API端点失败: {str(e)}")
|
||||
64
backend/app/routes/comparison.py
Normal file
64
backend/app/routes/comparison.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from app.services.comparison_service import ComparisonService
|
||||
from app.routes.user import get_current_active_user
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/comparison", tags=["comparison"])
|
||||
|
||||
# 创建对比服务实例
|
||||
comparison_service = ComparisonService()
|
||||
|
||||
|
||||
@router.post("/compare-algorithms", response_model=dict)
|
||||
async def compare_algorithms(
|
||||
request_data: Dict[str, Any],
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""比较多个算法的效果
|
||||
|
||||
Args:
|
||||
request_data: 请求数据,包含input_data和algorithm_configs
|
||||
current_user: 当前活跃用户
|
||||
|
||||
Returns:
|
||||
对比结果
|
||||
"""
|
||||
input_data = request_data.get("input_data")
|
||||
algorithm_configs = request_data.get("algorithm_configs")
|
||||
|
||||
if not input_data:
|
||||
raise HTTPException(status_code=400, detail="缺少 input_data 参数")
|
||||
|
||||
if not algorithm_configs or not isinstance(algorithm_configs, list):
|
||||
raise HTTPException(status_code=400, detail="缺少 algorithm_configs 参数或格式错误")
|
||||
|
||||
result = await comparison_service.compare_algorithms(input_data, algorithm_configs)
|
||||
|
||||
if not result["success"]:
|
||||
raise HTTPException(status_code=500, detail=result.get("error", "对比失败"))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/generate-report", response_model=dict)
|
||||
async def generate_comparison_report(
|
||||
comparison_results: Dict[str, Any],
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""生成对比报告
|
||||
|
||||
Args:
|
||||
comparison_results: 对比结果
|
||||
current_user: 当前活跃用户
|
||||
|
||||
Returns:
|
||||
对比报告
|
||||
"""
|
||||
report = comparison_service.generate_comparison_report(comparison_results)
|
||||
|
||||
if not report["success"]:
|
||||
raise HTTPException(status_code=500, detail=report.get("error", "生成报告失败"))
|
||||
|
||||
return report
|
||||
124
backend/app/routes/config.py
Normal file
124
backend/app/routes/config.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from app.models.database import get_db
|
||||
from app.services.config_service import ConfigService
|
||||
from app.routes.user import get_current_active_user
|
||||
|
||||
router = APIRouter(prefix="/config", tags=["config"])
|
||||
|
||||
|
||||
@router.get("/{config_key}")
|
||||
async def get_config(
|
||||
config_key: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取配置
|
||||
|
||||
Args:
|
||||
config_key: 配置键
|
||||
db: 数据库会话
|
||||
current_user: 当前活跃用户
|
||||
|
||||
Returns:
|
||||
配置信息
|
||||
"""
|
||||
config = ConfigService.get_config(db, config_key)
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="配置不存在")
|
||||
return {"key": config_key, "value": config}
|
||||
|
||||
|
||||
@router.post("/{config_key}")
|
||||
async def set_config(
|
||||
config_key: str,
|
||||
config_data: Dict[str, Any],
|
||||
db: Session = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""设置配置
|
||||
|
||||
Args:
|
||||
config_key: 配置键
|
||||
config_data: 配置数据,包含value、type、service_id、description等字段
|
||||
db: 数据库会话
|
||||
current_user: 当前活跃用户
|
||||
|
||||
Returns:
|
||||
设置结果
|
||||
"""
|
||||
success = ConfigService.set_config(
|
||||
db=db,
|
||||
config_key=config_key,
|
||||
config_value=config_data.get("value"),
|
||||
config_type=config_data.get("type", "system"),
|
||||
service_id=config_data.get("service_id"),
|
||||
description=config_data.get("description", "")
|
||||
)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="设置配置失败")
|
||||
return {"message": "设置配置成功"}
|
||||
|
||||
|
||||
@router.get("/service/{service_id}")
|
||||
async def get_service_configs(
|
||||
service_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取服务配置
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
db: 数据库会话
|
||||
current_user: 当前活跃用户
|
||||
|
||||
Returns:
|
||||
服务配置列表
|
||||
"""
|
||||
configs = ConfigService.get_service_configs(db, service_id)
|
||||
return {"service_id": service_id, "configs": configs}
|
||||
|
||||
|
||||
@router.delete("/{config_key}")
|
||||
async def delete_config(
|
||||
config_key: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""删除配置
|
||||
|
||||
Args:
|
||||
config_key: 配置键
|
||||
db: 数据库会话
|
||||
current_user: 当前活跃用户
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
success = ConfigService.delete_config(db, config_key)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="删除配置失败")
|
||||
return {"message": "删除配置成功"}
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def get_all_configs(
|
||||
config_type: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取所有配置
|
||||
|
||||
Args:
|
||||
config_type: 配置类型,可选
|
||||
db: 数据库会话
|
||||
current_user: 当前活跃用户
|
||||
|
||||
Returns:
|
||||
配置列表
|
||||
"""
|
||||
configs = ConfigService.get_all_configs(db, config_type)
|
||||
return {"configs": configs}
|
||||
@@ -14,6 +14,7 @@ from app.schemas.user import UserResponse
|
||||
from app.services.project_analyzer import ProjectAnalyzer
|
||||
from app.services.service_generator import ServiceGenerator
|
||||
from app.services.service_orchestrator import ServiceOrchestrator
|
||||
from app.gitea.service import gitea_service
|
||||
|
||||
router = APIRouter(prefix="/services", tags=["services"])
|
||||
|
||||
@@ -23,6 +24,9 @@ class RegisterServiceRequest(BaseModel):
|
||||
repository_id: str
|
||||
name: str
|
||||
version: str = "1.0.0"
|
||||
description: Optional[str] = ""
|
||||
tech_category: str = "computer_vision"
|
||||
output_type: str = "image"
|
||||
service_type: str = "http"
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
@@ -154,31 +158,24 @@ async def register_service(
|
||||
# 记录仓库信息
|
||||
print(f"仓库信息: {repo.name}, {repo.description}, {repo.repo_url}")
|
||||
|
||||
# 2. 分析项目
|
||||
repo_path = f"/tmp/repository_{request.repository_id}"
|
||||
# 注意:在实际实现中,应该从算法仓库中获取项目文件
|
||||
# 这里简化处理,创建一个模拟的项目结构
|
||||
os.makedirs(repo_path, exist_ok=True)
|
||||
# 2. 从Gitea仓库克隆代码到本地
|
||||
repo_path = f"/tmp/algorithms/{request.repository_id}"
|
||||
|
||||
# 创建模拟的算法文件
|
||||
with open(os.path.join(repo_path, "algorithm.py"), "w") as f:
|
||||
f.write("""
|
||||
def predict(data):
|
||||
return {"result": "Prediction result", "input": data}
|
||||
|
||||
def run(data):
|
||||
return {"result": "Run result", "input": data}
|
||||
|
||||
def main(data):
|
||||
return {"result": "Main result", "input": data}
|
||||
""")
|
||||
# 使用Gitea服务克隆仓库
|
||||
clone_success = gitea_service.clone_repository(repo.repo_url, request.repository_id, repo.branch or "main")
|
||||
if not clone_success:
|
||||
raise HTTPException(status_code=400, detail=f"克隆仓库失败: {repo.repo_url}")
|
||||
|
||||
# 分析项目
|
||||
print(f"仓库克隆成功: {repo_path}")
|
||||
|
||||
# 3. 分析项目
|
||||
project_info = project_analyzer.analyze_project(repo_path)
|
||||
if not project_info["success"]:
|
||||
raise HTTPException(status_code=400, detail=f"项目分析失败: {project_info['error']}")
|
||||
|
||||
# 3. 生成服务包装器
|
||||
print(f"项目分析成功: {project_info}")
|
||||
|
||||
# 4. 生成服务包装器
|
||||
service_config = {
|
||||
"name": request.name,
|
||||
"version": request.version,
|
||||
@@ -194,24 +191,31 @@ def main(data):
|
||||
if not generate_result["success"]:
|
||||
raise HTTPException(status_code=400, detail=f"服务生成失败: {generate_result['error']}")
|
||||
|
||||
# 4. 部署服务
|
||||
print(f"服务生成成功: {generate_result}")
|
||||
|
||||
# 5. 部署服务
|
||||
service_id = str(uuid.uuid4())
|
||||
deploy_result = service_orchestrator.deploy_service(service_id, service_config, project_info)
|
||||
deploy_result = service_orchestrator.deploy_service(service_id, service_config, project_info, repo_path)
|
||||
if not deploy_result["success"]:
|
||||
raise HTTPException(status_code=400, detail=f"服务部署失败: {deploy_result['error']}")
|
||||
|
||||
# 5. 保存服务信息到数据库
|
||||
print(f"服务部署成功: {deploy_result}")
|
||||
|
||||
# 6. 保存服务信息到数据库
|
||||
new_service = AlgorithmService(
|
||||
id=str(uuid.uuid4()),
|
||||
service_id=service_id,
|
||||
name=request.name,
|
||||
algorithm_name=repo.name, # 使用仓库名称作为算法名称
|
||||
version=request.version,
|
||||
tech_category=request.tech_category,
|
||||
output_type=request.output_type,
|
||||
host=request.host,
|
||||
port=request.port,
|
||||
api_url=deploy_result["api_url"],
|
||||
status=deploy_result["status"],
|
||||
config={
|
||||
"repository_id": request.repository_id, # 保存仓库ID
|
||||
"service_type": request.service_type,
|
||||
"timeout": request.timeout,
|
||||
"health_check_path": request.health_check_path,
|
||||
@@ -352,8 +356,64 @@ async def start_service(
|
||||
|
||||
# 启动服务
|
||||
start_result = service_orchestrator.start_service(service_id, container_id)
|
||||
|
||||
# 如果启动失败,尝试从数据库重新注册服务
|
||||
if not start_result["success"]:
|
||||
raise HTTPException(status_code=400, detail=f"服务启动失败: {start_result['error']}")
|
||||
print(f"服务启动失败: {start_result['error']},尝试从数据库重新注册服务")
|
||||
|
||||
# 获取仓库信息
|
||||
repository_id = service.config.get("repository_id")
|
||||
if not repository_id:
|
||||
raise HTTPException(status_code=400, detail="Repository ID not found in service config")
|
||||
|
||||
repository = db.query(AlgorithmRepository).filter(AlgorithmRepository.id == repository_id).first()
|
||||
if not repository:
|
||||
raise HTTPException(status_code=404, detail="Repository not found")
|
||||
|
||||
# 从Gitea克隆仓库
|
||||
clone_success = gitea_service.clone_repository(
|
||||
repository.repo_url,
|
||||
service_id,
|
||||
repository.branch or "main"
|
||||
)
|
||||
if not clone_success:
|
||||
raise HTTPException(status_code=400, detail="Failed to clone repository")
|
||||
|
||||
# 仓库路径
|
||||
repo_path = f"/tmp/algorithms/{service_id}"
|
||||
|
||||
# 分析项目
|
||||
project_info = project_analyzer.analyze_project(repo_path)
|
||||
if not project_info:
|
||||
raise HTTPException(status_code=400, detail="Failed to analyze project")
|
||||
|
||||
# 生成服务
|
||||
service_config = {
|
||||
"name": service.name,
|
||||
"version": service.version,
|
||||
"host": service.host,
|
||||
"port": service.port,
|
||||
"timeout": service.config.get("timeout", 30),
|
||||
"health_check_path": service.config.get("health_check_path", "/health"),
|
||||
"environment": service.config.get("environment", {})
|
||||
}
|
||||
|
||||
# 部署服务
|
||||
deploy_result = service_orchestrator.deploy_service(service_id, project_info, service_config, repo_path)
|
||||
if not deploy_result["success"]:
|
||||
raise HTTPException(status_code=400, detail=f"服务部署失败: {deploy_result['error']}")
|
||||
|
||||
# 更新服务配置
|
||||
service.config["container_id"] = deploy_result["container_id"]
|
||||
service.api_url = deploy_result["api_url"]
|
||||
db.commit()
|
||||
|
||||
start_result = {
|
||||
"success": True,
|
||||
"service_id": service_id,
|
||||
"status": "running",
|
||||
"error": None
|
||||
}
|
||||
|
||||
# 更新服务状态
|
||||
service.status = start_result["status"]
|
||||
@@ -1065,3 +1125,108 @@ async def batch_delete_services(
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
class ServiceCallRequest(BaseModel):
|
||||
"""服务调用请求"""
|
||||
service_id: str
|
||||
payload: Dict[str, Any]
|
||||
|
||||
|
||||
class ServiceCallResponse(BaseModel):
|
||||
"""服务调用响应"""
|
||||
success: bool
|
||||
result: Dict[str, Any]
|
||||
service_id: str
|
||||
execution_time: float
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/call")
|
||||
async def call_service(
|
||||
request: ServiceCallRequest,
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""直接调用注册的服务"""
|
||||
import time
|
||||
import httpx
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询服务
|
||||
service = db.query(AlgorithmService).filter(
|
||||
AlgorithmService.service_id == request.service_id
|
||||
).first()
|
||||
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="服务不存在")
|
||||
|
||||
# 检查服务状态
|
||||
if service.status != "running":
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=f"服务未运行,当前状态: {service.status}"
|
||||
)
|
||||
|
||||
# 调用服务
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 构建服务URL
|
||||
service_url = service.api_url
|
||||
|
||||
# 如果URL没有路径,添加默认路径
|
||||
if not service_url.endswith("/"):
|
||||
service_url += "/"
|
||||
|
||||
# 添加调用端点
|
||||
call_url = f"{service_url}predict"
|
||||
|
||||
# 使用httpx调用服务
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(
|
||||
call_url,
|
||||
json=request.payload,
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
if response.status_code == 200:
|
||||
return ServiceCallResponse(
|
||||
success=True,
|
||||
result=response.json(),
|
||||
service_id=request.service_id,
|
||||
execution_time=execution_time
|
||||
)
|
||||
else:
|
||||
return ServiceCallResponse(
|
||||
success=False,
|
||||
result={},
|
||||
service_id=request.service_id,
|
||||
execution_time=execution_time,
|
||||
error=f"服务返回错误: HTTP {response.status_code} - {response.text}"
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
execution_time = time.time() - start_time
|
||||
return ServiceCallResponse(
|
||||
success=False,
|
||||
result={},
|
||||
service_id=request.service_id,
|
||||
execution_time=execution_time,
|
||||
error=f"无法连接到服务: {str(e)}"
|
||||
)
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
return ServiceCallResponse(
|
||||
success=False,
|
||||
result={},
|
||||
service_id=request.service_id,
|
||||
execution_time=execution_time,
|
||||
error=f"服务调用异常: {str(e)}"
|
||||
)
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -167,6 +167,12 @@ async def read_users_me(current_user: UserResponse = Depends(get_current_active_
|
||||
return current_user
|
||||
|
||||
|
||||
@router.get("/me/", response_model=UserResponse)
|
||||
async def read_users_me_with_slash(current_user: UserResponse = Depends(get_current_active_user)):
|
||||
"""获取当前用户信息(带末尾斜杠)"""
|
||||
return current_user
|
||||
|
||||
|
||||
@router.get("/", response_model=UserListResponse)
|
||||
async def get_users(
|
||||
skip: int = 0,
|
||||
|
||||
Reference in New Issue
Block a user