good version for 算法注册

This commit is contained in:
2026-02-15 21:23:28 +08:00
parent 3c03777b97
commit 62ea5d36a5
115 changed files with 9566 additions and 1576 deletions

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter
from app.routes import user, algorithm, history, gateway, monitoring, openai, deployment
from app.routes import user, algorithm, history, gateway, monitoring, openai, deployment, config, comparison
api_router = APIRouter()
@@ -11,3 +11,5 @@ api_router.include_router(gateway.router, prefix="/gateway", tags=["gateway"])
api_router.include_router(monitoring.router, prefix="/monitoring", tags=["monitoring"])
api_router.include_router(openai.router, prefix="/openai", tags=["openai"])
api_router.include_router(deployment.router, tags=["deployment"])
api_router.include_router(config.router, tags=["config"])
api_router.include_router(comparison.router, tags=["comparison"])

Binary file not shown.

View File

@@ -33,6 +33,18 @@ async def create_algorithm(
@router.get("", response_model=AlgorithmListResponse)
async def get_algorithms_no_slash(
skip: int = 0,
limit: int = 100,
type: Optional[str] = None,
db: Session = Depends(get_db)
):
"""获取算法列表(不带末尾斜杠)"""
algorithms = AlgorithmService.get_algorithms(db, skip=skip, limit=limit, algorithm_type=type)
return {"algorithms": algorithms, "total": len(algorithms)}
@router.get("/", response_model=AlgorithmListResponse)
async def get_algorithms(
skip: int = 0,
limit: int = 100,

View File

@@ -0,0 +1,510 @@
"""API管理路由处理API端点的封装和管理"""
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
from sqlalchemy.orm import Session
import logging
from datetime import datetime
from app.models.database import get_db
from app.models.models import Algorithm, AlgorithmVersion, AlgorithmService, User
from app.models.api import ApiEndpoint, ApiCallLog
from app.schemas.user import UserResponse
from app.routes.user import get_current_active_user
router = APIRouter(prefix="/api-management", tags=["api-management"])
logger = logging.getLogger(__name__)
class ApiEndpointCreate(BaseModel):
"""创建API端点请求模型"""
name: str
description: str = ""
path: str
method: str = "POST"
algorithm_id: str
version_id: str
service_id: Optional[str] = None
requires_auth: bool = True
allowed_roles: List[str] = []
rate_limit: Optional[Dict[str, Any]] = None
is_public: bool = False
config: Dict[str, Any] = {}
class ApiEndpointUpdate(BaseModel):
"""更新API端点请求模型"""
name: Optional[str] = None
description: Optional[str] = None
path: Optional[str] = None
method: Optional[str] = None
requires_auth: Optional[bool] = None
allowed_roles: Optional[List[str]] = None
rate_limit: Optional[Dict[str, Any]] = None
is_public: Optional[bool] = None
config: Optional[Dict[str, Any]] = None
status: Optional[str] = None
class ApiEndpointResponse(BaseModel):
"""API端点响应模型"""
id: str
name: str
description: str
path: str
method: str
algorithm_id: str
algorithm_name: str
version_id: str
version: str
service_id: Optional[str]
status: str
is_public: bool
call_count: str
success_count: str
error_count: str
avg_response_time: str
created_at: datetime
updated_at: Optional[datetime]
last_called_at: Optional[datetime]
class ApiEndpointListResponse(BaseModel):
"""API端点列表响应模型"""
endpoints: List[ApiEndpointResponse]
total: int
class ApiStatsResponse(BaseModel):
"""API统计响应模型"""
total_endpoints: int
active_endpoints: int
total_calls: str
total_success: str
total_errors: str
avg_response_time: str
@router.get("/endpoints", response_model=ApiEndpointListResponse)
async def get_api_endpoints(
skip: int = 0,
limit: int = 100,
algorithm_id: Optional[str] = None,
status: Optional[str] = None,
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""获取API端点列表"""
try:
# 构建查询
query = db.query(ApiEndpoint)
# 筛选条件
if algorithm_id:
query = query.filter(ApiEndpoint.algorithm_id == algorithm_id)
if status:
query = query.filter(ApiEndpoint.status == status)
# 分页
endpoints = query.offset(skip).limit(limit).all()
total = query.count()
# 构建响应
endpoint_responses = []
for endpoint in endpoints:
# 获取关联的算法和版本信息
algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first()
version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first()
endpoint_responses.append({
"id": endpoint.id,
"name": endpoint.name,
"description": endpoint.description,
"path": endpoint.path,
"method": endpoint.method,
"algorithm_id": endpoint.algorithm_id,
"algorithm_name": algorithm.name if algorithm else "",
"version_id": endpoint.version_id,
"version": version.version if version else "",
"service_id": endpoint.service_id,
"status": endpoint.status,
"is_public": endpoint.is_public,
"call_count": endpoint.call_count,
"success_count": endpoint.success_count,
"error_count": endpoint.error_count,
"avg_response_time": endpoint.avg_response_time,
"created_at": endpoint.created_at,
"updated_at": endpoint.updated_at,
"last_called_at": endpoint.last_called_at
})
return {
"endpoints": endpoint_responses,
"total": total
}
except Exception as e:
logger.error(f"获取API端点列表失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取API端点列表失败: {str(e)}")
@router.get("/endpoints/{endpoint_id}", response_model=ApiEndpointResponse)
async def get_api_endpoint(
endpoint_id: str,
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""获取API端点详情"""
try:
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
if not endpoint:
raise HTTPException(status_code=404, detail="API端点不存在")
# 获取关联的算法和版本信息
algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first()
version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first()
return {
"id": endpoint.id,
"name": endpoint.name,
"description": endpoint.description,
"path": endpoint.path,
"method": endpoint.method,
"algorithm_id": endpoint.algorithm_id,
"algorithm_name": algorithm.name if algorithm else "",
"version_id": endpoint.version_id,
"version": version.version if version else "",
"service_id": endpoint.service_id,
"status": endpoint.status,
"is_public": endpoint.is_public,
"call_count": endpoint.call_count,
"success_count": endpoint.success_count,
"error_count": endpoint.error_count,
"avg_response_time": endpoint.avg_response_time,
"created_at": endpoint.created_at,
"updated_at": endpoint.updated_at,
"last_called_at": endpoint.last_called_at
}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取API端点详情失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取API端点详情失败: {str(e)}")
@router.post("/endpoints", response_model=ApiEndpointResponse, status_code=status.HTTP_201_CREATED)
async def create_api_endpoint(
request: ApiEndpointCreate,
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""创建API端点"""
try:
# 检查用户权限
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
raise HTTPException(status_code=403, detail="Insufficient permissions")
# 验证算法和版本是否存在
algorithm = db.query(Algorithm).filter(Algorithm.id == request.algorithm_id).first()
if not algorithm:
raise HTTPException(status_code=404, detail="算法不存在")
version = db.query(AlgorithmVersion).filter(
AlgorithmVersion.id == request.version_id
).first()
if not version or version.algorithm_id != request.algorithm_id:
raise HTTPException(status_code=404, detail="算法版本不存在")
# 如果指定了服务ID验证服务是否存在
if request.service_id:
service = db.query(AlgorithmService).filter(
AlgorithmService.service_id == request.service_id
).first()
if not service:
raise HTTPException(status_code=404, detail="服务不存在")
# 检查API路径是否已存在
existing_endpoint = db.query(ApiEndpoint).filter(
ApiEndpoint.path == request.path
).first()
if existing_endpoint:
raise HTTPException(status_code=400, detail="API路径已存在")
# 创建API端点
new_endpoint = ApiEndpoint(
name=request.name,
description=request.description,
path=request.path,
method=request.method,
algorithm_id=request.algorithm_id,
version_id=request.version_id,
service_id=request.service_id,
requires_auth=request.requires_auth,
allowed_roles=request.allowed_roles,
rate_limit=request.rate_limit,
is_public=request.is_public,
config=request.config,
status="active",
call_count="0",
success_count="0",
error_count="0",
avg_response_time="0.0"
)
db.add(new_endpoint)
db.commit()
db.refresh(new_endpoint)
# 返回创建的API端点
return {
"id": new_endpoint.id,
"name": new_endpoint.name,
"description": new_endpoint.description,
"path": new_endpoint.path,
"method": new_endpoint.method,
"algorithm_id": new_endpoint.algorithm_id,
"algorithm_name": algorithm.name,
"version_id": new_endpoint.version_id,
"version": version.version,
"service_id": new_endpoint.service_id,
"status": new_endpoint.status,
"is_public": new_endpoint.is_public,
"call_count": new_endpoint.call_count,
"success_count": new_endpoint.success_count,
"error_count": new_endpoint.error_count,
"avg_response_time": new_endpoint.avg_response_time,
"created_at": new_endpoint.created_at,
"updated_at": new_endpoint.updated_at,
"last_called_at": new_endpoint.last_called_at
}
except HTTPException:
raise
except Exception as e:
logger.error(f"创建API端点失败: {str(e)}")
db.rollback()
raise HTTPException(status_code=500, detail=f"创建API端点失败: {str(e)}")
@router.put("/endpoints/{endpoint_id}", response_model=ApiEndpointResponse)
async def update_api_endpoint(
endpoint_id: str,
request: ApiEndpointUpdate,
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""更新API端点"""
try:
# 检查用户权限
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
raise HTTPException(status_code=403, detail="Insufficient permissions")
# 查询API端点
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
if not endpoint:
raise HTTPException(status_code=404, detail="API端点不存在")
# 更新字段
if request.name is not None:
endpoint.name = request.name
if request.description is not None:
endpoint.description = request.description
if request.path is not None:
# 检查新路径是否已被其他端点使用
existing_endpoint = db.query(ApiEndpoint).filter(
ApiEndpoint.path == request.path,
ApiEndpoint.id != endpoint_id
).first()
if existing_endpoint:
raise HTTPException(status_code=400, detail="API路径已存在")
endpoint.path = request.path
if request.method is not None:
endpoint.method = request.method
if request.requires_auth is not None:
endpoint.requires_auth = request.requires_auth
if request.allowed_roles is not None:
endpoint.allowed_roles = request.allowed_roles
if request.rate_limit is not None:
endpoint.rate_limit = request.rate_limit
if request.is_public is not None:
endpoint.is_public = request.is_public
if request.config is not None:
endpoint.config = request.config
if request.status is not None:
endpoint.status = request.status
db.commit()
db.refresh(endpoint)
# 获取关联的算法和版本信息
algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first()
version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first()
return {
"id": endpoint.id,
"name": endpoint.name,
"description": endpoint.description,
"path": endpoint.path,
"method": endpoint.method,
"algorithm_id": endpoint.algorithm_id,
"algorithm_name": algorithm.name if algorithm else "",
"version_id": endpoint.version_id,
"version": version.version if version else "",
"service_id": endpoint.service_id,
"status": endpoint.status,
"is_public": endpoint.is_public,
"call_count": endpoint.call_count,
"success_count": endpoint.success_count,
"error_count": endpoint.error_count,
"avg_response_time": endpoint.avg_response_time,
"created_at": endpoint.created_at,
"updated_at": endpoint.updated_at,
"last_called_at": endpoint.last_called_at
}
except HTTPException:
raise
except Exception as e:
logger.error(f"更新API端点失败: {str(e)}")
db.rollback()
raise HTTPException(status_code=500, detail=f"更新API端点失败: {str(e)}")
@router.delete("/endpoints/{endpoint_id}")
async def delete_api_endpoint(
endpoint_id: str,
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""删除API端点"""
try:
# 检查用户权限
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
raise HTTPException(status_code=403, detail="Insufficient permissions")
# 查询API端点
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
if not endpoint:
raise HTTPException(status_code=404, detail="API端点不存在")
# 删除API端点
db.delete(endpoint)
db.commit()
return {
"success": True,
"message": "API端点删除成功"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"删除API端点失败: {str(e)}")
db.rollback()
raise HTTPException(status_code=500, detail=f"删除API端点失败: {str(e)}")
@router.get("/stats", response_model=ApiStatsResponse)
async def get_api_stats(
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""获取API统计信息"""
try:
# 检查用户权限
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
raise HTTPException(status_code=403, detail="Insufficient permissions")
# 统计API端点
total_endpoints = db.query(ApiEndpoint).count()
active_endpoints = db.query(ApiEndpoint).filter(ApiEndpoint.status == "active").count()
# 统计调用次数
endpoints = db.query(ApiEndpoint).all()
total_calls = sum(int(e.call_count or 0) for e in endpoints)
total_success = sum(int(e.success_count or 0) for e in endpoints)
total_errors = sum(int(e.error_count or 0) for e in endpoints)
# 计算平均响应时间
avg_response_times = [float(e.avg_response_time or 0) for e in endpoints if float(e.avg_response_time or 0) > 0]
avg_response_time = sum(avg_response_times) / len(avg_response_times) if avg_response_times else 0.0
return {
"total_endpoints": total_endpoints,
"active_endpoints": active_endpoints,
"total_calls": str(total_calls),
"total_success": str(total_success),
"total_errors": str(total_errors),
"avg_response_time": f"{avg_response_time:.2f}"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取API统计信息失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取API统计信息失败: {str(e)}")
@router.post("/endpoints/{endpoint_id}/test")
async def test_api_endpoint(
endpoint_id: str,
payload: Dict[str, Any],
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""测试API端点"""
try:
# 查询API端点
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
if not endpoint:
raise HTTPException(status_code=404, detail="API端点不存在")
# 检查API端点状态
if endpoint.status != "active":
raise HTTPException(status_code=400, detail="API端点未激活")
# 查询关联的服务
if endpoint.service_id:
service = db.query(AlgorithmService).filter(
AlgorithmService.service_id == endpoint.service_id
).first()
if not service or service.status != "running":
raise HTTPException(status_code=400, detail="关联服务未运行")
# 调用服务
import httpx
import time
service_url = service.api_url
if not service_url.endswith("/"):
service_url += "/"
call_url = f"{service_url}predict"
start_time = time.time()
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
call_url,
json=payload,
headers={"Content-Type": "application/json"}
)
response_time = time.time() - start_time
if response.status_code == 200:
return {
"success": True,
"result": response.json(),
"response_time": response_time,
"message": "API调用成功"
}
else:
return {
"success": False,
"error": f"服务返回错误: HTTP {response.status_code}",
"response_time": response_time
}
else:
raise HTTPException(status_code=400, detail="API端点未关联服务")
except HTTPException:
raise
except Exception as e:
logger.error(f"测试API端点失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"测试API端点失败: {str(e)}")

View File

@@ -0,0 +1,64 @@
from fastapi import APIRouter, Depends, HTTPException
from typing import Dict, Any, List
from app.services.comparison_service import ComparisonService
from app.routes.user import get_current_active_user
# 创建路由器
router = APIRouter(prefix="/comparison", tags=["comparison"])
# 创建对比服务实例
comparison_service = ComparisonService()
@router.post("/compare-algorithms", response_model=dict)
async def compare_algorithms(
request_data: Dict[str, Any],
current_user: dict = Depends(get_current_active_user)
):
"""比较多个算法的效果
Args:
request_data: 请求数据包含input_data和algorithm_configs
current_user: 当前活跃用户
Returns:
对比结果
"""
input_data = request_data.get("input_data")
algorithm_configs = request_data.get("algorithm_configs")
if not input_data:
raise HTTPException(status_code=400, detail="缺少 input_data 参数")
if not algorithm_configs or not isinstance(algorithm_configs, list):
raise HTTPException(status_code=400, detail="缺少 algorithm_configs 参数或格式错误")
result = await comparison_service.compare_algorithms(input_data, algorithm_configs)
if not result["success"]:
raise HTTPException(status_code=500, detail=result.get("error", "对比失败"))
return result
@router.post("/generate-report", response_model=dict)
async def generate_comparison_report(
comparison_results: Dict[str, Any],
current_user: dict = Depends(get_current_active_user)
):
"""生成对比报告
Args:
comparison_results: 对比结果
current_user: 当前活跃用户
Returns:
对比报告
"""
report = comparison_service.generate_comparison_report(comparison_results)
if not report["success"]:
raise HTTPException(status_code=500, detail=report.get("error", "生成报告失败"))
return report

View File

@@ -0,0 +1,124 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Dict, Any, List, Optional
from app.models.database import get_db
from app.services.config_service import ConfigService
from app.routes.user import get_current_active_user
router = APIRouter(prefix="/config", tags=["config"])
@router.get("/{config_key}")
async def get_config(
config_key: str,
db: Session = Depends(get_db),
current_user: dict = Depends(get_current_active_user)
):
"""获取配置
Args:
config_key: 配置键
db: 数据库会话
current_user: 当前活跃用户
Returns:
配置信息
"""
config = ConfigService.get_config(db, config_key)
if not config:
raise HTTPException(status_code=404, detail="配置不存在")
return {"key": config_key, "value": config}
@router.post("/{config_key}")
async def set_config(
config_key: str,
config_data: Dict[str, Any],
db: Session = Depends(get_db),
current_user: dict = Depends(get_current_active_user)
):
"""设置配置
Args:
config_key: 配置键
config_data: 配置数据包含value、type、service_id、description等字段
db: 数据库会话
current_user: 当前活跃用户
Returns:
设置结果
"""
success = ConfigService.set_config(
db=db,
config_key=config_key,
config_value=config_data.get("value"),
config_type=config_data.get("type", "system"),
service_id=config_data.get("service_id"),
description=config_data.get("description", "")
)
if not success:
raise HTTPException(status_code=400, detail="设置配置失败")
return {"message": "设置配置成功"}
@router.get("/service/{service_id}")
async def get_service_configs(
service_id: str,
db: Session = Depends(get_db),
current_user: dict = Depends(get_current_active_user)
):
"""获取服务配置
Args:
service_id: 服务ID
db: 数据库会话
current_user: 当前活跃用户
Returns:
服务配置列表
"""
configs = ConfigService.get_service_configs(db, service_id)
return {"service_id": service_id, "configs": configs}
@router.delete("/{config_key}")
async def delete_config(
config_key: str,
db: Session = Depends(get_db),
current_user: dict = Depends(get_current_active_user)
):
"""删除配置
Args:
config_key: 配置键
db: 数据库会话
current_user: 当前活跃用户
Returns:
删除结果
"""
success = ConfigService.delete_config(db, config_key)
if not success:
raise HTTPException(status_code=400, detail="删除配置失败")
return {"message": "删除配置成功"}
@router.get("/")
async def get_all_configs(
config_type: Optional[str] = None,
db: Session = Depends(get_db),
current_user: dict = Depends(get_current_active_user)
):
"""获取所有配置
Args:
config_type: 配置类型,可选
db: 数据库会话
current_user: 当前活跃用户
Returns:
配置列表
"""
configs = ConfigService.get_all_configs(db, config_type)
return {"configs": configs}

View File

@@ -14,6 +14,7 @@ from app.schemas.user import UserResponse
from app.services.project_analyzer import ProjectAnalyzer
from app.services.service_generator import ServiceGenerator
from app.services.service_orchestrator import ServiceOrchestrator
from app.gitea.service import gitea_service
router = APIRouter(prefix="/services", tags=["services"])
@@ -23,6 +24,9 @@ class RegisterServiceRequest(BaseModel):
repository_id: str
name: str
version: str = "1.0.0"
description: Optional[str] = ""
tech_category: str = "computer_vision"
output_type: str = "image"
service_type: str = "http"
host: str = "0.0.0.0"
port: int = 8000
@@ -154,31 +158,24 @@ async def register_service(
# 记录仓库信息
print(f"仓库信息: {repo.name}, {repo.description}, {repo.repo_url}")
# 2. 分析项目
repo_path = f"/tmp/repository_{request.repository_id}"
# 注意:在实际实现中,应该从算法仓库中获取项目文件
# 这里简化处理,创建一个模拟的项目结构
os.makedirs(repo_path, exist_ok=True)
# 2. 从Gitea仓库克隆代码到本地
repo_path = f"/tmp/algorithms/{request.repository_id}"
# 创建模拟的算法文件
with open(os.path.join(repo_path, "algorithm.py"), "w") as f:
f.write("""
def predict(data):
return {"result": "Prediction result", "input": data}
def run(data):
return {"result": "Run result", "input": data}
def main(data):
return {"result": "Main result", "input": data}
""")
# 使用Gitea服务克隆仓库
clone_success = gitea_service.clone_repository(repo.repo_url, request.repository_id, repo.branch or "main")
if not clone_success:
raise HTTPException(status_code=400, detail=f"克隆仓库失败: {repo.repo_url}")
# 分析项目
print(f"仓库克隆成功: {repo_path}")
# 3. 分析项目
project_info = project_analyzer.analyze_project(repo_path)
if not project_info["success"]:
raise HTTPException(status_code=400, detail=f"项目分析失败: {project_info['error']}")
# 3. 生成服务包装器
print(f"项目分析成功: {project_info}")
# 4. 生成服务包装器
service_config = {
"name": request.name,
"version": request.version,
@@ -194,24 +191,31 @@ def main(data):
if not generate_result["success"]:
raise HTTPException(status_code=400, detail=f"服务生成失败: {generate_result['error']}")
# 4. 部署服务
print(f"服务生成成功: {generate_result}")
# 5. 部署服务
service_id = str(uuid.uuid4())
deploy_result = service_orchestrator.deploy_service(service_id, service_config, project_info)
deploy_result = service_orchestrator.deploy_service(service_id, service_config, project_info, repo_path)
if not deploy_result["success"]:
raise HTTPException(status_code=400, detail=f"服务部署失败: {deploy_result['error']}")
# 5. 保存服务信息到数据库
print(f"服务部署成功: {deploy_result}")
# 6. 保存服务信息到数据库
new_service = AlgorithmService(
id=str(uuid.uuid4()),
service_id=service_id,
name=request.name,
algorithm_name=repo.name, # 使用仓库名称作为算法名称
version=request.version,
tech_category=request.tech_category,
output_type=request.output_type,
host=request.host,
port=request.port,
api_url=deploy_result["api_url"],
status=deploy_result["status"],
config={
"repository_id": request.repository_id, # 保存仓库ID
"service_type": request.service_type,
"timeout": request.timeout,
"health_check_path": request.health_check_path,
@@ -352,8 +356,64 @@ async def start_service(
# 启动服务
start_result = service_orchestrator.start_service(service_id, container_id)
# 如果启动失败,尝试从数据库重新注册服务
if not start_result["success"]:
raise HTTPException(status_code=400, detail=f"服务启动失败: {start_result['error']}")
print(f"服务启动失败: {start_result['error']},尝试从数据库重新注册服务")
# 获取仓库信息
repository_id = service.config.get("repository_id")
if not repository_id:
raise HTTPException(status_code=400, detail="Repository ID not found in service config")
repository = db.query(AlgorithmRepository).filter(AlgorithmRepository.id == repository_id).first()
if not repository:
raise HTTPException(status_code=404, detail="Repository not found")
# 从Gitea克隆仓库
clone_success = gitea_service.clone_repository(
repository.repo_url,
service_id,
repository.branch or "main"
)
if not clone_success:
raise HTTPException(status_code=400, detail="Failed to clone repository")
# 仓库路径
repo_path = f"/tmp/algorithms/{service_id}"
# 分析项目
project_info = project_analyzer.analyze_project(repo_path)
if not project_info:
raise HTTPException(status_code=400, detail="Failed to analyze project")
# 生成服务
service_config = {
"name": service.name,
"version": service.version,
"host": service.host,
"port": service.port,
"timeout": service.config.get("timeout", 30),
"health_check_path": service.config.get("health_check_path", "/health"),
"environment": service.config.get("environment", {})
}
# 部署服务
deploy_result = service_orchestrator.deploy_service(service_id, project_info, service_config, repo_path)
if not deploy_result["success"]:
raise HTTPException(status_code=400, detail=f"服务部署失败: {deploy_result['error']}")
# 更新服务配置
service.config["container_id"] = deploy_result["container_id"]
service.api_url = deploy_result["api_url"]
db.commit()
start_result = {
"success": True,
"service_id": service_id,
"status": "running",
"error": None
}
# 更新服务状态
service.status = start_result["status"]
@@ -1065,3 +1125,108 @@ async def batch_delete_services(
)
finally:
db.close()
class ServiceCallRequest(BaseModel):
"""服务调用请求"""
service_id: str
payload: Dict[str, Any]
class ServiceCallResponse(BaseModel):
"""服务调用响应"""
success: bool
result: Dict[str, Any]
service_id: str
execution_time: float
error: Optional[str] = None
@router.post("/call")
async def call_service(
request: ServiceCallRequest,
current_user: UserResponse = Depends(get_current_active_user)
):
"""直接调用注册的服务"""
import time
import httpx
# 创建数据库会话
db = SessionLocal()
try:
# 查询服务
service = db.query(AlgorithmService).filter(
AlgorithmService.service_id == request.service_id
).first()
if not service:
raise HTTPException(status_code=404, detail="服务不存在")
# 检查服务状态
if service.status != "running":
raise HTTPException(
status_code=503,
detail=f"服务未运行,当前状态: {service.status}"
)
# 调用服务
start_time = time.time()
try:
# 构建服务URL
service_url = service.api_url
# 如果URL没有路径添加默认路径
if not service_url.endswith("/"):
service_url += "/"
# 添加调用端点
call_url = f"{service_url}predict"
# 使用httpx调用服务
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
call_url,
json=request.payload,
headers={"Content-Type": "application/json"}
)
execution_time = time.time() - start_time
if response.status_code == 200:
return ServiceCallResponse(
success=True,
result=response.json(),
service_id=request.service_id,
execution_time=execution_time
)
else:
return ServiceCallResponse(
success=False,
result={},
service_id=request.service_id,
execution_time=execution_time,
error=f"服务返回错误: HTTP {response.status_code} - {response.text}"
)
except httpx.RequestError as e:
execution_time = time.time() - start_time
return ServiceCallResponse(
success=False,
result={},
service_id=request.service_id,
execution_time=execution_time,
error=f"无法连接到服务: {str(e)}"
)
except Exception as e:
execution_time = time.time() - start_time
return ServiceCallResponse(
success=False,
result={},
service_id=request.service_id,
execution_time=execution_time,
error=f"服务调用异常: {str(e)}"
)
finally:
db.close()

View File

@@ -167,6 +167,12 @@ async def read_users_me(current_user: UserResponse = Depends(get_current_active_
return current_user
@router.get("/me/", response_model=UserResponse)
async def read_users_me_with_slash(current_user: UserResponse = Depends(get_current_active_user)):
"""获取当前用户信息(带末尾斜杠)"""
return current_user
@router.get("/", response_model=UserListResponse)
async def get_users(
skip: int = 0,