final version
This commit is contained in:
@@ -5,10 +5,12 @@ from typing import List, Dict, Any, Optional
|
||||
from pydantic import BaseModel
|
||||
import uuid
|
||||
import os
|
||||
import logging
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.models.models import AlgorithmService, ServiceGroup, AlgorithmRepository
|
||||
from app.models.models import AlgorithmService, AlgorithmRepository, Algorithm, AlgorithmVersion
|
||||
from app.models.database import SessionLocal
|
||||
from app.models.api import ApiEndpoint
|
||||
from app.routes.user import get_current_active_user
|
||||
from app.schemas.user import UserResponse
|
||||
from app.services.project_analyzer import ProjectAnalyzer
|
||||
@@ -17,6 +19,7 @@ from app.services.service_orchestrator import ServiceOrchestrator
|
||||
from app.gitea.service import gitea_service
|
||||
|
||||
router = APIRouter(prefix="/services", tags=["services"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RegisterServiceRequest(BaseModel):
|
||||
@@ -28,7 +31,7 @@ class RegisterServiceRequest(BaseModel):
|
||||
tech_category: str = "computer_vision"
|
||||
output_type: str = "image"
|
||||
service_type: str = "http"
|
||||
host: str = "0.0.0.0"
|
||||
host: str = "localhost"
|
||||
port: int = 8000
|
||||
timeout: int = 30
|
||||
health_check_path: str = "/health"
|
||||
@@ -89,34 +92,6 @@ class RepositoryAlgorithmsResponse(BaseModel):
|
||||
algorithms: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class ServiceGroupRequest(BaseModel):
|
||||
"""服务分组请求"""
|
||||
name: str
|
||||
description: str = ""
|
||||
|
||||
|
||||
class ServiceGroupResponse(BaseModel):
|
||||
"""服务分组响应"""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
status: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class ServiceGroupListResponse(BaseModel):
|
||||
"""服务分组列表响应"""
|
||||
success: bool
|
||||
groups: List[ServiceGroupResponse]
|
||||
|
||||
|
||||
class ServiceGroupDetailResponse(BaseModel):
|
||||
"""服务分组详情响应"""
|
||||
success: bool
|
||||
group: ServiceGroupResponse
|
||||
|
||||
|
||||
class BatchOperationRequest(BaseModel):
|
||||
"""批量操作请求"""
|
||||
service_ids: List[str]
|
||||
@@ -228,7 +203,62 @@ async def register_service(
|
||||
db.commit()
|
||||
db.refresh(new_service)
|
||||
|
||||
# 6. 返回响应
|
||||
# 7. 自动创建API端点
|
||||
try:
|
||||
algorithm = db.query(Algorithm).filter(Algorithm.name == repo.name).first()
|
||||
if not algorithm:
|
||||
algorithm = Algorithm(
|
||||
id=str(uuid.uuid4()),
|
||||
name=repo.name,
|
||||
description=request.description or f"算法服务: {request.name}",
|
||||
type=request.tech_category,
|
||||
tech_category=request.tech_category,
|
||||
output_type=request.output_type
|
||||
)
|
||||
db.add(algorithm)
|
||||
db.commit()
|
||||
db.refresh(algorithm)
|
||||
|
||||
version = db.query(AlgorithmVersion).filter(
|
||||
AlgorithmVersion.algorithm_id == algorithm.id,
|
||||
AlgorithmVersion.version == request.version
|
||||
).first()
|
||||
if not version:
|
||||
version = AlgorithmVersion(
|
||||
id=str(uuid.uuid4()),
|
||||
algorithm_id=algorithm.id,
|
||||
version=request.version,
|
||||
url=request.service_url if hasattr(request, 'service_url') else ""
|
||||
)
|
||||
db.add(version)
|
||||
db.commit()
|
||||
db.refresh(version)
|
||||
|
||||
api_endpoint = ApiEndpoint(
|
||||
id=str(uuid.uuid4()),
|
||||
name=request.name,
|
||||
description=request.description or f"{request.name} API端点",
|
||||
path=f"/api/v1/algorithms/{algorithm.id}/call",
|
||||
method="POST",
|
||||
algorithm_id=algorithm.id,
|
||||
version_id=version.id,
|
||||
service_id=service_id,
|
||||
requires_auth=False,
|
||||
is_public=True,
|
||||
status="active",
|
||||
config={
|
||||
"service_url": deploy_result["api_url"],
|
||||
"timeout": request.timeout,
|
||||
"health_check_path": request.health_check_path
|
||||
}
|
||||
)
|
||||
db.add(api_endpoint)
|
||||
db.commit()
|
||||
logger.info(f"API端点创建成功: {api_endpoint.name}, 路径: {api_endpoint.path}")
|
||||
except Exception as e:
|
||||
logger.error(f"创建API端点失败: {str(e)}")
|
||||
|
||||
# 8. 返回响应
|
||||
return {
|
||||
"success": True,
|
||||
"message": "服务注册成功",
|
||||
@@ -537,6 +567,12 @@ async def delete_service(
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
|
||||
# 先删除关联的API端点
|
||||
db.query(ApiEndpoint).filter(ApiEndpoint.service_id == service_id).delete()
|
||||
|
||||
# 获取算法名称,用于后续删除算法记录
|
||||
algorithm_name = service.algorithm_name
|
||||
|
||||
# 获取容器ID和镜像名称
|
||||
container_id = service.config.get("container_id")
|
||||
image_name = f"algorithm-service-{service_id}:{service.version}"
|
||||
@@ -549,6 +585,17 @@ async def delete_service(
|
||||
|
||||
# 删除数据库记录
|
||||
db.delete(service)
|
||||
|
||||
# 删除关联的算法记录(通过算法名称匹配)
|
||||
if algorithm_name:
|
||||
algorithm = db.query(Algorithm).filter(Algorithm.name == algorithm_name).first()
|
||||
if algorithm:
|
||||
# 先删除关联的算法版本
|
||||
db.query(AlgorithmVersion).filter(AlgorithmVersion.algorithm_id == algorithm.id).delete()
|
||||
# 再删除算法记录
|
||||
db.query(AlgorithmCall).filter(AlgorithmCall.algorithm_id == algorithm.id).delete()
|
||||
db.delete(algorithm)
|
||||
|
||||
db.commit()
|
||||
|
||||
# 返回响应
|
||||
@@ -677,202 +724,6 @@ async def get_repository_algorithms(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# 服务分组管理API
|
||||
|
||||
@router.post("/groups", status_code=status.HTTP_201_CREATED)
|
||||
async def create_service_group(
|
||||
request: ServiceGroupRequest,
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""创建服务分组"""
|
||||
# 检查用户权限
|
||||
if current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 生成唯一ID
|
||||
group_id = str(uuid.uuid4())
|
||||
|
||||
# 创建分组实例
|
||||
group = ServiceGroup(
|
||||
id=group_id,
|
||||
name=request.name,
|
||||
description=request.description
|
||||
)
|
||||
|
||||
# 保存到数据库
|
||||
db.add(group)
|
||||
db.commit()
|
||||
db.refresh(group)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "服务分组创建成功",
|
||||
"group": {
|
||||
"id": group.id,
|
||||
"name": group.name,
|
||||
"description": group.description,
|
||||
"status": group.status,
|
||||
"created_at": group.created_at.isoformat(),
|
||||
"updated_at": group.updated_at.isoformat()
|
||||
}
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.get("/groups", response_model=ServiceGroupListResponse)
|
||||
async def list_service_groups(
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取服务分组列表"""
|
||||
# 检查用户权限
|
||||
if current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询分组列表
|
||||
groups = db.query(ServiceGroup).all()
|
||||
|
||||
# 转换为响应格式
|
||||
group_list = []
|
||||
for group in groups:
|
||||
group_list.append(ServiceGroupResponse(
|
||||
id=group.id,
|
||||
name=group.name,
|
||||
description=group.description,
|
||||
status=group.status,
|
||||
created_at=group.created_at.isoformat(),
|
||||
updated_at=group.updated_at.isoformat()
|
||||
))
|
||||
|
||||
return ServiceGroupListResponse(
|
||||
success=True,
|
||||
groups=group_list
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.get("/groups/{group_id}", response_model=ServiceGroupDetailResponse)
|
||||
async def get_service_group(
|
||||
group_id: str,
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取服务分组详情"""
|
||||
# 检查用户权限
|
||||
if current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询分组
|
||||
group = db.query(ServiceGroup).filter(ServiceGroup.id == group_id).first()
|
||||
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="Service group not found")
|
||||
|
||||
return ServiceGroupDetailResponse(
|
||||
success=True,
|
||||
group=ServiceGroupResponse(
|
||||
id=group.id,
|
||||
name=group.name,
|
||||
description=group.description,
|
||||
status=group.status,
|
||||
created_at=group.created_at.isoformat(),
|
||||
updated_at=group.updated_at.isoformat()
|
||||
)
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.put("/groups/{group_id}")
|
||||
async def update_service_group(
|
||||
group_id: str,
|
||||
request: ServiceGroupRequest,
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""更新服务分组"""
|
||||
# 检查用户权限
|
||||
if current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询分组
|
||||
group = db.query(ServiceGroup).filter(ServiceGroup.id == group_id).first()
|
||||
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="Service group not found")
|
||||
|
||||
# 更新分组信息
|
||||
group.name = request.name
|
||||
group.description = request.description
|
||||
|
||||
# 保存到数据库
|
||||
db.commit()
|
||||
db.refresh(group)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "服务分组更新成功",
|
||||
"group": {
|
||||
"id": group.id,
|
||||
"name": group.name,
|
||||
"description": group.description,
|
||||
"status": group.status,
|
||||
"created_at": group.created_at.isoformat(),
|
||||
"updated_at": group.updated_at.isoformat()
|
||||
}
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.delete("/groups/{group_id}")
|
||||
async def delete_service_group(
|
||||
group_id: str,
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""删除服务分组"""
|
||||
# 检查用户权限
|
||||
if current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询分组
|
||||
group = db.query(ServiceGroup).filter(ServiceGroup.id == group_id).first()
|
||||
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="Service group not found")
|
||||
|
||||
# 检查分组是否有服务
|
||||
services_count = db.query(AlgorithmService).filter(AlgorithmService.group_id == group_id).count()
|
||||
if services_count > 0:
|
||||
raise HTTPException(status_code=400, detail=f"该分组下还有{services_count}个服务,无法删除")
|
||||
|
||||
# 删除分组
|
||||
db.delete(group)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "服务分组删除成功",
|
||||
"group_id": group_id
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# 批量服务操作API
|
||||
|
||||
@router.post("/batch/start")
|
||||
@@ -1230,3 +1081,85 @@ async def call_service(
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.post("/sync-api-endpoints")
|
||||
async def sync_api_endpoints(
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""同步所有服务到API端点"""
|
||||
if current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="权限不足")
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
services = db.query(AlgorithmService).all()
|
||||
synced_count = 0
|
||||
|
||||
for service in services:
|
||||
existing_endpoint = db.query(ApiEndpoint).filter(
|
||||
(ApiEndpoint.service_id == service.service_id) |
|
||||
(ApiEndpoint.path == f"/api/v1/algorithms/{service.algorithm_name}/call")
|
||||
).first()
|
||||
|
||||
if existing_endpoint:
|
||||
continue
|
||||
|
||||
algorithm = db.query(Algorithm).filter(Algorithm.name == service.algorithm_name).first()
|
||||
if not algorithm:
|
||||
algorithm = Algorithm(
|
||||
id=str(uuid.uuid4()),
|
||||
name=service.algorithm_name,
|
||||
description=f"算法服务: {service.name}",
|
||||
type=service.tech_category or "computer_vision",
|
||||
tech_category=service.tech_category or "computer_vision",
|
||||
output_type=service.output_type or "image"
|
||||
)
|
||||
db.add(algorithm)
|
||||
db.commit()
|
||||
db.refresh(algorithm)
|
||||
|
||||
version = db.query(AlgorithmVersion).filter(
|
||||
AlgorithmVersion.algorithm_id == algorithm.id
|
||||
).first()
|
||||
if not version:
|
||||
version = AlgorithmVersion(
|
||||
id=str(uuid.uuid4()),
|
||||
algorithm_id=algorithm.id,
|
||||
version=service.version or "1.0.0",
|
||||
url=service.api_url
|
||||
)
|
||||
db.add(version)
|
||||
db.commit()
|
||||
db.refresh(version)
|
||||
|
||||
api_endpoint = ApiEndpoint(
|
||||
id=str(uuid.uuid4()),
|
||||
name=service.name,
|
||||
description=f"{service.name} API端点",
|
||||
path=f"/api/v1/algorithms/{algorithm.id}/call/{service.service_id[:8]}",
|
||||
method="POST",
|
||||
algorithm_id=algorithm.id,
|
||||
version_id=version.id,
|
||||
service_id=service.service_id,
|
||||
requires_auth=False,
|
||||
is_public=True,
|
||||
status=service.status or "active",
|
||||
config={
|
||||
"service_url": service.api_url,
|
||||
"timeout": service.config.get("timeout") if service.config else 30
|
||||
}
|
||||
)
|
||||
db.add(api_endpoint)
|
||||
synced_count += 1
|
||||
|
||||
db.commit()
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"同步完成,共同步 {synced_count} 个API端点"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"同步API端点失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"同步失败: {str(e)}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
Reference in New Issue
Block a user