good version for 算法注册

This commit is contained in:
2026-02-15 21:23:28 +08:00
parent 3c03777b97
commit 62ea5d36a5
115 changed files with 9566 additions and 1576 deletions

View File

@@ -1,5 +1,6 @@
from pydantic_settings import BaseSettings
from typing import Optional
from typing import Optional, Dict, Any
from sqlalchemy.orm import Session
class Settings(BaseSettings):
@@ -46,11 +47,56 @@ class Settings(BaseSettings):
GITEA_DEFAULT_OWNER: str = ""
GITEA_REPO_PREFIX: str = "AI"
# 服务管理配置
SERVICE_MANAGEMENT: Dict[str, Any] = {
"mode": "supervisor", # 服务管理模式local, docker, supervisor
"service_root_dir": "/opt/ai-services",
"supervisor_config_dir": "/etc/supervisor/conf.d",
}
class Config:
env_file = ".env"
case_sensitive = True
extra = "allow" # 允许额外的环境变量
def get_config(self, config_key: str, default: Any = None) -> Any:
"""获取配置,优先级:环境变量 > 数据库 > 文件默认值
Args:
config_key: 配置键
default: 默认值
Returns:
配置值
"""
# 1. 先从环境变量获取
env_key = config_key.upper().replace('.', '_')
env_value = getattr(self, env_key, None)
if env_value is not None:
return env_value
# 2. 从数据库获取
try:
from app.models.database import SessionLocal
from app.models.models import ServiceConfig
db: Session = SessionLocal()
try:
config = db.query(ServiceConfig).filter_by(
config_key=config_key,
status="active"
).first()
if config:
return config.config_value
finally:
db.close()
except Exception as e:
# 数据库连接失败时,返回默认值
print(f"Failed to load config from database: {str(e)}")
# 3. 返回默认值
return default
# 创建全局配置实例
settings = Settings()

View File

@@ -72,9 +72,24 @@ class APIGateway:
if not version_info:
raise HTTPException(status_code=404, detail="Algorithm version not found")
# 在实际实现中这里会根据version_info.url将请求转发到对应的算法服务
# 现在我们模拟调用过程
algorithm_url = version_info.url if hasattr(version_info, 'url') else f"http://localhost:8001/algorithms/{algorithm_id}/execute"
# 尝试从算法版本获取URL如果没有则尝试从服务表获取
algorithm_url = None
# 首先检查版本信息中是否有URL
if hasattr(version_info, 'url') and version_info.url:
algorithm_url = version_info.url
else:
# 如果版本信息中没有URL尝试从服务表获取
from app.models.models import AlgorithmService
service = db.query(AlgorithmService).filter(
AlgorithmService.algorithm_name == algorithm_id
).first()
if service and service.api_url:
algorithm_url = service.api_url
else:
# 如果都没有,使用默认的本地端点
algorithm_url = f"http://localhost:8001/algorithms/{algorithm_id}/execute"
# 使用httpx调用算法服务
async with httpx.AsyncClient(timeout=30.0) as client:

View File

@@ -4,7 +4,8 @@ from fastapi.middleware.trustedhost import TrustedHostMiddleware
from app.config.settings import settings
from app.models.database import engine, Base
from app.routes import user, algorithm, openai, gateway, services, data_management, monitoring, permissions, history, deployment, gitea, repositories
from app.models import models, api # 导入所有模型以确保表被创建
from app.routes import user, algorithm, openai, gateway, services, data_management, monitoring, permissions, history, deployment, gitea, repositories, config, comparison, api_management
# 创建数据库表
Base.metadata.create_all(bind=engine)
@@ -48,6 +49,9 @@ app.include_router(history.router, prefix=settings.API_V1_STR)
app.include_router(deployment.router)
app.include_router(gitea.router, prefix=settings.API_V1_STR)
app.include_router(repositories.router, prefix=settings.API_V1_STR)
app.include_router(config.router, prefix=settings.API_V1_STR)
app.include_router(comparison.router, prefix=settings.API_V1_STR)
app.include_router(api_management.router, prefix=settings.API_V1_STR)
@app.get("/")

Binary file not shown.

74
backend/app/models/api.py Normal file
View File

@@ -0,0 +1,74 @@
"""API封装模型"""
from sqlalchemy import Column, String, DateTime, Text, Boolean, ForeignKey
from sqlalchemy.dialects.postgresql import JSON
from sqlalchemy.orm import relationship
from datetime import datetime
from app.models.database import Base
import uuid
class ApiEndpoint(Base):
"""API端点模型"""
__tablename__ = "api_endpoints"
id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
name = Column(String, nullable=False, index=True) # API名称
description = Column(Text, default="") # API描述
path = Column(String, nullable=False, unique=True, index=True) # API路径如 /api/image-classification
method = Column(String, default="POST") # HTTP方法GET, POST, PUT, DELETE
algorithm_id = Column(String, ForeignKey("algorithms.id"), nullable=False, index=True) # 关联的算法ID
version_id = Column(String, ForeignKey("algorithm_versions.id"), nullable=False, index=True) # 关联的算法版本ID
service_id = Column(String, ForeignKey("algorithm_services.service_id"), nullable=True, index=True) # 关联的服务ID
# API配置
config = Column(JSON, nullable=False, default={}) # API配置超时、重试等
request_schema = Column(JSON, nullable=True) # 请求参数schema
response_schema = Column(JSON, nullable=True) # 响应参数schema
# 权限配置
requires_auth = Column(Boolean, default=True) # 是否需要认证
allowed_roles = Column(JSON, default=[]) # 允许的角色列表
rate_limit = Column(JSON, nullable=True) # 限流配置,如 {"max_requests": 100, "window": 60}
# 状态
status = Column(String, default="active", index=True) # 状态active, inactive, deprecated
is_public = Column(Boolean, default=False) # 是否公开
# 统计信息
call_count = Column(String, default="0") # 调用次数
success_count = Column(String, default="0") # 成功次数
error_count = Column(String, default="0") # 错误次数
avg_response_time = Column(String, default="0.0") # 平均响应时间
# 时间戳
created_at = Column(DateTime(timezone=True), default=datetime.utcnow)
updated_at = Column(DateTime(timezone=True), onupdate=datetime.utcnow)
last_called_at = Column(DateTime(timezone=True), nullable=True) # 最后调用时间
class ApiCallLog(Base):
"""API调用日志模型"""
__tablename__ = "api_call_logs"
id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
api_endpoint_id = Column(String, ForeignKey("api_endpoints.id"), nullable=False, index=True) # API端点ID
user_id = Column(String, ForeignKey("users.id"), nullable=True, index=True) # 调用用户ID
# 请求信息
request_method = Column(String, nullable=False) # 请求方法
request_path = Column(String, nullable=False) # 请求路径
request_headers = Column(JSON, nullable=True) # 请求头
request_body = Column(JSON, nullable=True) # 请求体
# 响应信息
response_status = Column(String, nullable=False) # 响应状态码
response_body = Column(JSON, nullable=True) # 响应体
response_time = Column(String, nullable=False) # 响应时间(秒)
# 错误信息
error_message = Column(Text, nullable=True) # 错误信息
error_type = Column(String, nullable=True) # 错误类型
# 时间戳
created_at = Column(DateTime(timezone=True), default=datetime.utcnow, index=True) # 调用时间

View File

@@ -13,6 +13,8 @@ class Algorithm(Base):
name = Column(String, nullable=False, index=True)
description = Column(Text, nullable=False)
type = Column(String, nullable=False, index=True) # computer_vision, nlp, ml, edge_computing, medical, autonomous_driving等
tech_category = Column(String, nullable=False, default="computer_vision") # 技术分类:计算机视觉、视频处理、自然语言处理等
output_type = Column(String, nullable=False, default="image") # 输出类型图片、视频、文本、JSON等
status = Column(String, default="active", index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
@@ -159,6 +161,8 @@ class AlgorithmService(Base):
name = Column(String, nullable=False, index=True) # 服务名称
algorithm_name = Column(String, nullable=False) # 算法名称
version = Column(String, nullable=False) # 版本
tech_category = Column(String, nullable=False, default="computer_vision") # 技术分类:计算机视觉、视频处理、自然语言处理等
output_type = Column(String, nullable=False, default="image") # 输出类型图片、视频、文本、JSON等
host = Column(String, nullable=False) # 主机地址
port = Column(Integer, nullable=False) # 端口
api_url = Column(String, nullable=False) # API地址
@@ -172,3 +176,18 @@ class AlgorithmService(Base):
# 添加Algorithm模型的repository关系
Algorithm.repository = relationship("AlgorithmRepository", back_populates="algorithm", uselist=False)
class ServiceConfig(Base):
"""服务配置模型"""
__tablename__ = "service_configs"
id = Column(String, primary_key=True, index=True)
config_key = Column(String, nullable=False, unique=True, index=True) # 配置键
config_value = Column(JSON, nullable=False) # 配置值JSON格式
config_type = Column(String, nullable=False) # 配置类型system, service, user
service_id = Column(String, nullable=True, index=True) # 服务ID可为空系统配置
description = Column(Text, default="") # 配置描述
status = Column(String, default="active") # 状态
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter
from app.routes import user, algorithm, history, gateway, monitoring, openai, deployment
from app.routes import user, algorithm, history, gateway, monitoring, openai, deployment, config, comparison
api_router = APIRouter()
@@ -11,3 +11,5 @@ api_router.include_router(gateway.router, prefix="/gateway", tags=["gateway"])
api_router.include_router(monitoring.router, prefix="/monitoring", tags=["monitoring"])
api_router.include_router(openai.router, prefix="/openai", tags=["openai"])
api_router.include_router(deployment.router, tags=["deployment"])
api_router.include_router(config.router, tags=["config"])
api_router.include_router(comparison.router, tags=["comparison"])

Binary file not shown.

View File

@@ -33,6 +33,18 @@ async def create_algorithm(
@router.get("", response_model=AlgorithmListResponse)
async def get_algorithms_no_slash(
skip: int = 0,
limit: int = 100,
type: Optional[str] = None,
db: Session = Depends(get_db)
):
"""获取算法列表(不带末尾斜杠)"""
algorithms = AlgorithmService.get_algorithms(db, skip=skip, limit=limit, algorithm_type=type)
return {"algorithms": algorithms, "total": len(algorithms)}
@router.get("/", response_model=AlgorithmListResponse)
async def get_algorithms(
skip: int = 0,
limit: int = 100,

View File

@@ -0,0 +1,510 @@
"""API管理路由处理API端点的封装和管理"""
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
from sqlalchemy.orm import Session
import logging
from datetime import datetime
from app.models.database import get_db
from app.models.models import Algorithm, AlgorithmVersion, AlgorithmService, User
from app.models.api import ApiEndpoint, ApiCallLog
from app.schemas.user import UserResponse
from app.routes.user import get_current_active_user
router = APIRouter(prefix="/api-management", tags=["api-management"])
logger = logging.getLogger(__name__)
class ApiEndpointCreate(BaseModel):
"""创建API端点请求模型"""
name: str
description: str = ""
path: str
method: str = "POST"
algorithm_id: str
version_id: str
service_id: Optional[str] = None
requires_auth: bool = True
allowed_roles: List[str] = []
rate_limit: Optional[Dict[str, Any]] = None
is_public: bool = False
config: Dict[str, Any] = {}
class ApiEndpointUpdate(BaseModel):
"""更新API端点请求模型"""
name: Optional[str] = None
description: Optional[str] = None
path: Optional[str] = None
method: Optional[str] = None
requires_auth: Optional[bool] = None
allowed_roles: Optional[List[str]] = None
rate_limit: Optional[Dict[str, Any]] = None
is_public: Optional[bool] = None
config: Optional[Dict[str, Any]] = None
status: Optional[str] = None
class ApiEndpointResponse(BaseModel):
"""API端点响应模型"""
id: str
name: str
description: str
path: str
method: str
algorithm_id: str
algorithm_name: str
version_id: str
version: str
service_id: Optional[str]
status: str
is_public: bool
call_count: str
success_count: str
error_count: str
avg_response_time: str
created_at: datetime
updated_at: Optional[datetime]
last_called_at: Optional[datetime]
class ApiEndpointListResponse(BaseModel):
"""API端点列表响应模型"""
endpoints: List[ApiEndpointResponse]
total: int
class ApiStatsResponse(BaseModel):
"""API统计响应模型"""
total_endpoints: int
active_endpoints: int
total_calls: str
total_success: str
total_errors: str
avg_response_time: str
@router.get("/endpoints", response_model=ApiEndpointListResponse)
async def get_api_endpoints(
skip: int = 0,
limit: int = 100,
algorithm_id: Optional[str] = None,
status: Optional[str] = None,
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""获取API端点列表"""
try:
# 构建查询
query = db.query(ApiEndpoint)
# 筛选条件
if algorithm_id:
query = query.filter(ApiEndpoint.algorithm_id == algorithm_id)
if status:
query = query.filter(ApiEndpoint.status == status)
# 分页
endpoints = query.offset(skip).limit(limit).all()
total = query.count()
# 构建响应
endpoint_responses = []
for endpoint in endpoints:
# 获取关联的算法和版本信息
algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first()
version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first()
endpoint_responses.append({
"id": endpoint.id,
"name": endpoint.name,
"description": endpoint.description,
"path": endpoint.path,
"method": endpoint.method,
"algorithm_id": endpoint.algorithm_id,
"algorithm_name": algorithm.name if algorithm else "",
"version_id": endpoint.version_id,
"version": version.version if version else "",
"service_id": endpoint.service_id,
"status": endpoint.status,
"is_public": endpoint.is_public,
"call_count": endpoint.call_count,
"success_count": endpoint.success_count,
"error_count": endpoint.error_count,
"avg_response_time": endpoint.avg_response_time,
"created_at": endpoint.created_at,
"updated_at": endpoint.updated_at,
"last_called_at": endpoint.last_called_at
})
return {
"endpoints": endpoint_responses,
"total": total
}
except Exception as e:
logger.error(f"获取API端点列表失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取API端点列表失败: {str(e)}")
@router.get("/endpoints/{endpoint_id}", response_model=ApiEndpointResponse)
async def get_api_endpoint(
endpoint_id: str,
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""获取API端点详情"""
try:
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
if not endpoint:
raise HTTPException(status_code=404, detail="API端点不存在")
# 获取关联的算法和版本信息
algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first()
version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first()
return {
"id": endpoint.id,
"name": endpoint.name,
"description": endpoint.description,
"path": endpoint.path,
"method": endpoint.method,
"algorithm_id": endpoint.algorithm_id,
"algorithm_name": algorithm.name if algorithm else "",
"version_id": endpoint.version_id,
"version": version.version if version else "",
"service_id": endpoint.service_id,
"status": endpoint.status,
"is_public": endpoint.is_public,
"call_count": endpoint.call_count,
"success_count": endpoint.success_count,
"error_count": endpoint.error_count,
"avg_response_time": endpoint.avg_response_time,
"created_at": endpoint.created_at,
"updated_at": endpoint.updated_at,
"last_called_at": endpoint.last_called_at
}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取API端点详情失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取API端点详情失败: {str(e)}")
@router.post("/endpoints", response_model=ApiEndpointResponse, status_code=status.HTTP_201_CREATED)
async def create_api_endpoint(
request: ApiEndpointCreate,
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""创建API端点"""
try:
# 检查用户权限
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
raise HTTPException(status_code=403, detail="Insufficient permissions")
# 验证算法和版本是否存在
algorithm = db.query(Algorithm).filter(Algorithm.id == request.algorithm_id).first()
if not algorithm:
raise HTTPException(status_code=404, detail="算法不存在")
version = db.query(AlgorithmVersion).filter(
AlgorithmVersion.id == request.version_id
).first()
if not version or version.algorithm_id != request.algorithm_id:
raise HTTPException(status_code=404, detail="算法版本不存在")
# 如果指定了服务ID验证服务是否存在
if request.service_id:
service = db.query(AlgorithmService).filter(
AlgorithmService.service_id == request.service_id
).first()
if not service:
raise HTTPException(status_code=404, detail="服务不存在")
# 检查API路径是否已存在
existing_endpoint = db.query(ApiEndpoint).filter(
ApiEndpoint.path == request.path
).first()
if existing_endpoint:
raise HTTPException(status_code=400, detail="API路径已存在")
# 创建API端点
new_endpoint = ApiEndpoint(
name=request.name,
description=request.description,
path=request.path,
method=request.method,
algorithm_id=request.algorithm_id,
version_id=request.version_id,
service_id=request.service_id,
requires_auth=request.requires_auth,
allowed_roles=request.allowed_roles,
rate_limit=request.rate_limit,
is_public=request.is_public,
config=request.config,
status="active",
call_count="0",
success_count="0",
error_count="0",
avg_response_time="0.0"
)
db.add(new_endpoint)
db.commit()
db.refresh(new_endpoint)
# 返回创建的API端点
return {
"id": new_endpoint.id,
"name": new_endpoint.name,
"description": new_endpoint.description,
"path": new_endpoint.path,
"method": new_endpoint.method,
"algorithm_id": new_endpoint.algorithm_id,
"algorithm_name": algorithm.name,
"version_id": new_endpoint.version_id,
"version": version.version,
"service_id": new_endpoint.service_id,
"status": new_endpoint.status,
"is_public": new_endpoint.is_public,
"call_count": new_endpoint.call_count,
"success_count": new_endpoint.success_count,
"error_count": new_endpoint.error_count,
"avg_response_time": new_endpoint.avg_response_time,
"created_at": new_endpoint.created_at,
"updated_at": new_endpoint.updated_at,
"last_called_at": new_endpoint.last_called_at
}
except HTTPException:
raise
except Exception as e:
logger.error(f"创建API端点失败: {str(e)}")
db.rollback()
raise HTTPException(status_code=500, detail=f"创建API端点失败: {str(e)}")
@router.put("/endpoints/{endpoint_id}", response_model=ApiEndpointResponse)
async def update_api_endpoint(
endpoint_id: str,
request: ApiEndpointUpdate,
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""更新API端点"""
try:
# 检查用户权限
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
raise HTTPException(status_code=403, detail="Insufficient permissions")
# 查询API端点
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
if not endpoint:
raise HTTPException(status_code=404, detail="API端点不存在")
# 更新字段
if request.name is not None:
endpoint.name = request.name
if request.description is not None:
endpoint.description = request.description
if request.path is not None:
# 检查新路径是否已被其他端点使用
existing_endpoint = db.query(ApiEndpoint).filter(
ApiEndpoint.path == request.path,
ApiEndpoint.id != endpoint_id
).first()
if existing_endpoint:
raise HTTPException(status_code=400, detail="API路径已存在")
endpoint.path = request.path
if request.method is not None:
endpoint.method = request.method
if request.requires_auth is not None:
endpoint.requires_auth = request.requires_auth
if request.allowed_roles is not None:
endpoint.allowed_roles = request.allowed_roles
if request.rate_limit is not None:
endpoint.rate_limit = request.rate_limit
if request.is_public is not None:
endpoint.is_public = request.is_public
if request.config is not None:
endpoint.config = request.config
if request.status is not None:
endpoint.status = request.status
db.commit()
db.refresh(endpoint)
# 获取关联的算法和版本信息
algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first()
version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first()
return {
"id": endpoint.id,
"name": endpoint.name,
"description": endpoint.description,
"path": endpoint.path,
"method": endpoint.method,
"algorithm_id": endpoint.algorithm_id,
"algorithm_name": algorithm.name if algorithm else "",
"version_id": endpoint.version_id,
"version": version.version if version else "",
"service_id": endpoint.service_id,
"status": endpoint.status,
"is_public": endpoint.is_public,
"call_count": endpoint.call_count,
"success_count": endpoint.success_count,
"error_count": endpoint.error_count,
"avg_response_time": endpoint.avg_response_time,
"created_at": endpoint.created_at,
"updated_at": endpoint.updated_at,
"last_called_at": endpoint.last_called_at
}
except HTTPException:
raise
except Exception as e:
logger.error(f"更新API端点失败: {str(e)}")
db.rollback()
raise HTTPException(status_code=500, detail=f"更新API端点失败: {str(e)}")
@router.delete("/endpoints/{endpoint_id}")
async def delete_api_endpoint(
endpoint_id: str,
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""删除API端点"""
try:
# 检查用户权限
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
raise HTTPException(status_code=403, detail="Insufficient permissions")
# 查询API端点
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
if not endpoint:
raise HTTPException(status_code=404, detail="API端点不存在")
# 删除API端点
db.delete(endpoint)
db.commit()
return {
"success": True,
"message": "API端点删除成功"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"删除API端点失败: {str(e)}")
db.rollback()
raise HTTPException(status_code=500, detail=f"删除API端点失败: {str(e)}")
@router.get("/stats", response_model=ApiStatsResponse)
async def get_api_stats(
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""获取API统计信息"""
try:
# 检查用户权限
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
raise HTTPException(status_code=403, detail="Insufficient permissions")
# 统计API端点
total_endpoints = db.query(ApiEndpoint).count()
active_endpoints = db.query(ApiEndpoint).filter(ApiEndpoint.status == "active").count()
# 统计调用次数
endpoints = db.query(ApiEndpoint).all()
total_calls = sum(int(e.call_count or 0) for e in endpoints)
total_success = sum(int(e.success_count or 0) for e in endpoints)
total_errors = sum(int(e.error_count or 0) for e in endpoints)
# 计算平均响应时间
avg_response_times = [float(e.avg_response_time or 0) for e in endpoints if float(e.avg_response_time or 0) > 0]
avg_response_time = sum(avg_response_times) / len(avg_response_times) if avg_response_times else 0.0
return {
"total_endpoints": total_endpoints,
"active_endpoints": active_endpoints,
"total_calls": str(total_calls),
"total_success": str(total_success),
"total_errors": str(total_errors),
"avg_response_time": f"{avg_response_time:.2f}"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取API统计信息失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取API统计信息失败: {str(e)}")
@router.post("/endpoints/{endpoint_id}/test")
async def test_api_endpoint(
endpoint_id: str,
payload: Dict[str, Any],
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""测试API端点"""
try:
# 查询API端点
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
if not endpoint:
raise HTTPException(status_code=404, detail="API端点不存在")
# 检查API端点状态
if endpoint.status != "active":
raise HTTPException(status_code=400, detail="API端点未激活")
# 查询关联的服务
if endpoint.service_id:
service = db.query(AlgorithmService).filter(
AlgorithmService.service_id == endpoint.service_id
).first()
if not service or service.status != "running":
raise HTTPException(status_code=400, detail="关联服务未运行")
# 调用服务
import httpx
import time
service_url = service.api_url
if not service_url.endswith("/"):
service_url += "/"
call_url = f"{service_url}predict"
start_time = time.time()
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
call_url,
json=payload,
headers={"Content-Type": "application/json"}
)
response_time = time.time() - start_time
if response.status_code == 200:
return {
"success": True,
"result": response.json(),
"response_time": response_time,
"message": "API调用成功"
}
else:
return {
"success": False,
"error": f"服务返回错误: HTTP {response.status_code}",
"response_time": response_time
}
else:
raise HTTPException(status_code=400, detail="API端点未关联服务")
except HTTPException:
raise
except Exception as e:
logger.error(f"测试API端点失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"测试API端点失败: {str(e)}")

View File

@@ -0,0 +1,64 @@
from fastapi import APIRouter, Depends, HTTPException
from typing import Dict, Any, List
from app.services.comparison_service import ComparisonService
from app.routes.user import get_current_active_user
# 创建路由器
router = APIRouter(prefix="/comparison", tags=["comparison"])
# 创建对比服务实例
comparison_service = ComparisonService()
@router.post("/compare-algorithms", response_model=dict)
async def compare_algorithms(
request_data: Dict[str, Any],
current_user: dict = Depends(get_current_active_user)
):
"""比较多个算法的效果
Args:
request_data: 请求数据包含input_data和algorithm_configs
current_user: 当前活跃用户
Returns:
对比结果
"""
input_data = request_data.get("input_data")
algorithm_configs = request_data.get("algorithm_configs")
if not input_data:
raise HTTPException(status_code=400, detail="缺少 input_data 参数")
if not algorithm_configs or not isinstance(algorithm_configs, list):
raise HTTPException(status_code=400, detail="缺少 algorithm_configs 参数或格式错误")
result = await comparison_service.compare_algorithms(input_data, algorithm_configs)
if not result["success"]:
raise HTTPException(status_code=500, detail=result.get("error", "对比失败"))
return result
@router.post("/generate-report", response_model=dict)
async def generate_comparison_report(
comparison_results: Dict[str, Any],
current_user: dict = Depends(get_current_active_user)
):
"""生成对比报告
Args:
comparison_results: 对比结果
current_user: 当前活跃用户
Returns:
对比报告
"""
report = comparison_service.generate_comparison_report(comparison_results)
if not report["success"]:
raise HTTPException(status_code=500, detail=report.get("error", "生成报告失败"))
return report

View File

@@ -0,0 +1,124 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Dict, Any, List, Optional
from app.models.database import get_db
from app.services.config_service import ConfigService
from app.routes.user import get_current_active_user
router = APIRouter(prefix="/config", tags=["config"])
@router.get("/{config_key}")
async def get_config(
config_key: str,
db: Session = Depends(get_db),
current_user: dict = Depends(get_current_active_user)
):
"""获取配置
Args:
config_key: 配置键
db: 数据库会话
current_user: 当前活跃用户
Returns:
配置信息
"""
config = ConfigService.get_config(db, config_key)
if not config:
raise HTTPException(status_code=404, detail="配置不存在")
return {"key": config_key, "value": config}
@router.post("/{config_key}")
async def set_config(
config_key: str,
config_data: Dict[str, Any],
db: Session = Depends(get_db),
current_user: dict = Depends(get_current_active_user)
):
"""设置配置
Args:
config_key: 配置键
config_data: 配置数据包含value、type、service_id、description等字段
db: 数据库会话
current_user: 当前活跃用户
Returns:
设置结果
"""
success = ConfigService.set_config(
db=db,
config_key=config_key,
config_value=config_data.get("value"),
config_type=config_data.get("type", "system"),
service_id=config_data.get("service_id"),
description=config_data.get("description", "")
)
if not success:
raise HTTPException(status_code=400, detail="设置配置失败")
return {"message": "设置配置成功"}
@router.get("/service/{service_id}")
async def get_service_configs(
service_id: str,
db: Session = Depends(get_db),
current_user: dict = Depends(get_current_active_user)
):
"""获取服务配置
Args:
service_id: 服务ID
db: 数据库会话
current_user: 当前活跃用户
Returns:
服务配置列表
"""
configs = ConfigService.get_service_configs(db, service_id)
return {"service_id": service_id, "configs": configs}
@router.delete("/{config_key}")
async def delete_config(
config_key: str,
db: Session = Depends(get_db),
current_user: dict = Depends(get_current_active_user)
):
"""删除配置
Args:
config_key: 配置键
db: 数据库会话
current_user: 当前活跃用户
Returns:
删除结果
"""
success = ConfigService.delete_config(db, config_key)
if not success:
raise HTTPException(status_code=400, detail="删除配置失败")
return {"message": "删除配置成功"}
@router.get("/")
async def get_all_configs(
config_type: Optional[str] = None,
db: Session = Depends(get_db),
current_user: dict = Depends(get_current_active_user)
):
"""获取所有配置
Args:
config_type: 配置类型,可选
db: 数据库会话
current_user: 当前活跃用户
Returns:
配置列表
"""
configs = ConfigService.get_all_configs(db, config_type)
return {"configs": configs}

View File

@@ -14,6 +14,7 @@ from app.schemas.user import UserResponse
from app.services.project_analyzer import ProjectAnalyzer
from app.services.service_generator import ServiceGenerator
from app.services.service_orchestrator import ServiceOrchestrator
from app.gitea.service import gitea_service
router = APIRouter(prefix="/services", tags=["services"])
@@ -23,6 +24,9 @@ class RegisterServiceRequest(BaseModel):
repository_id: str
name: str
version: str = "1.0.0"
description: Optional[str] = ""
tech_category: str = "computer_vision"
output_type: str = "image"
service_type: str = "http"
host: str = "0.0.0.0"
port: int = 8000
@@ -154,31 +158,24 @@ async def register_service(
# 记录仓库信息
print(f"仓库信息: {repo.name}, {repo.description}, {repo.repo_url}")
# 2. 分析项目
repo_path = f"/tmp/repository_{request.repository_id}"
# 注意:在实际实现中,应该从算法仓库中获取项目文件
# 这里简化处理,创建一个模拟的项目结构
os.makedirs(repo_path, exist_ok=True)
# 2. 从Gitea仓库克隆代码到本地
repo_path = f"/tmp/algorithms/{request.repository_id}"
# 创建模拟的算法文件
with open(os.path.join(repo_path, "algorithm.py"), "w") as f:
f.write("""
def predict(data):
return {"result": "Prediction result", "input": data}
def run(data):
return {"result": "Run result", "input": data}
def main(data):
return {"result": "Main result", "input": data}
""")
# 使用Gitea服务克隆仓库
clone_success = gitea_service.clone_repository(repo.repo_url, request.repository_id, repo.branch or "main")
if not clone_success:
raise HTTPException(status_code=400, detail=f"克隆仓库失败: {repo.repo_url}")
# 分析项目
print(f"仓库克隆成功: {repo_path}")
# 3. 分析项目
project_info = project_analyzer.analyze_project(repo_path)
if not project_info["success"]:
raise HTTPException(status_code=400, detail=f"项目分析失败: {project_info['error']}")
# 3. 生成服务包装器
print(f"项目分析成功: {project_info}")
# 4. 生成服务包装器
service_config = {
"name": request.name,
"version": request.version,
@@ -194,24 +191,31 @@ def main(data):
if not generate_result["success"]:
raise HTTPException(status_code=400, detail=f"服务生成失败: {generate_result['error']}")
# 4. 部署服务
print(f"服务生成成功: {generate_result}")
# 5. 部署服务
service_id = str(uuid.uuid4())
deploy_result = service_orchestrator.deploy_service(service_id, service_config, project_info)
deploy_result = service_orchestrator.deploy_service(service_id, service_config, project_info, repo_path)
if not deploy_result["success"]:
raise HTTPException(status_code=400, detail=f"服务部署失败: {deploy_result['error']}")
# 5. 保存服务信息到数据库
print(f"服务部署成功: {deploy_result}")
# 6. 保存服务信息到数据库
new_service = AlgorithmService(
id=str(uuid.uuid4()),
service_id=service_id,
name=request.name,
algorithm_name=repo.name, # 使用仓库名称作为算法名称
version=request.version,
tech_category=request.tech_category,
output_type=request.output_type,
host=request.host,
port=request.port,
api_url=deploy_result["api_url"],
status=deploy_result["status"],
config={
"repository_id": request.repository_id, # 保存仓库ID
"service_type": request.service_type,
"timeout": request.timeout,
"health_check_path": request.health_check_path,
@@ -352,8 +356,64 @@ async def start_service(
# 启动服务
start_result = service_orchestrator.start_service(service_id, container_id)
# 如果启动失败,尝试从数据库重新注册服务
if not start_result["success"]:
raise HTTPException(status_code=400, detail=f"服务启动失败: {start_result['error']}")
print(f"服务启动失败: {start_result['error']},尝试从数据库重新注册服务")
# 获取仓库信息
repository_id = service.config.get("repository_id")
if not repository_id:
raise HTTPException(status_code=400, detail="Repository ID not found in service config")
repository = db.query(AlgorithmRepository).filter(AlgorithmRepository.id == repository_id).first()
if not repository:
raise HTTPException(status_code=404, detail="Repository not found")
# 从Gitea克隆仓库
clone_success = gitea_service.clone_repository(
repository.repo_url,
service_id,
repository.branch or "main"
)
if not clone_success:
raise HTTPException(status_code=400, detail="Failed to clone repository")
# 仓库路径
repo_path = f"/tmp/algorithms/{service_id}"
# 分析项目
project_info = project_analyzer.analyze_project(repo_path)
if not project_info:
raise HTTPException(status_code=400, detail="Failed to analyze project")
# 生成服务
service_config = {
"name": service.name,
"version": service.version,
"host": service.host,
"port": service.port,
"timeout": service.config.get("timeout", 30),
"health_check_path": service.config.get("health_check_path", "/health"),
"environment": service.config.get("environment", {})
}
# 部署服务
deploy_result = service_orchestrator.deploy_service(service_id, project_info, service_config, repo_path)
if not deploy_result["success"]:
raise HTTPException(status_code=400, detail=f"服务部署失败: {deploy_result['error']}")
# 更新服务配置
service.config["container_id"] = deploy_result["container_id"]
service.api_url = deploy_result["api_url"]
db.commit()
start_result = {
"success": True,
"service_id": service_id,
"status": "running",
"error": None
}
# 更新服务状态
service.status = start_result["status"]
@@ -1065,3 +1125,108 @@ async def batch_delete_services(
)
finally:
db.close()
class ServiceCallRequest(BaseModel):
"""服务调用请求"""
service_id: str
payload: Dict[str, Any]
class ServiceCallResponse(BaseModel):
"""服务调用响应"""
success: bool
result: Dict[str, Any]
service_id: str
execution_time: float
error: Optional[str] = None
@router.post("/call")
async def call_service(
request: ServiceCallRequest,
current_user: UserResponse = Depends(get_current_active_user)
):
"""直接调用注册的服务"""
import time
import httpx
# 创建数据库会话
db = SessionLocal()
try:
# 查询服务
service = db.query(AlgorithmService).filter(
AlgorithmService.service_id == request.service_id
).first()
if not service:
raise HTTPException(status_code=404, detail="服务不存在")
# 检查服务状态
if service.status != "running":
raise HTTPException(
status_code=503,
detail=f"服务未运行,当前状态: {service.status}"
)
# 调用服务
start_time = time.time()
try:
# 构建服务URL
service_url = service.api_url
# 如果URL没有路径添加默认路径
if not service_url.endswith("/"):
service_url += "/"
# 添加调用端点
call_url = f"{service_url}predict"
# 使用httpx调用服务
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
call_url,
json=request.payload,
headers={"Content-Type": "application/json"}
)
execution_time = time.time() - start_time
if response.status_code == 200:
return ServiceCallResponse(
success=True,
result=response.json(),
service_id=request.service_id,
execution_time=execution_time
)
else:
return ServiceCallResponse(
success=False,
result={},
service_id=request.service_id,
execution_time=execution_time,
error=f"服务返回错误: HTTP {response.status_code} - {response.text}"
)
except httpx.RequestError as e:
execution_time = time.time() - start_time
return ServiceCallResponse(
success=False,
result={},
service_id=request.service_id,
execution_time=execution_time,
error=f"无法连接到服务: {str(e)}"
)
except Exception as e:
execution_time = time.time() - start_time
return ServiceCallResponse(
success=False,
result={},
service_id=request.service_id,
execution_time=execution_time,
error=f"服务调用异常: {str(e)}"
)
finally:
db.close()

View File

@@ -167,6 +167,12 @@ async def read_users_me(current_user: UserResponse = Depends(get_current_active_
return current_user
@router.get("/me/", response_model=UserResponse)
async def read_users_me_with_slash(current_user: UserResponse = Depends(get_current_active_user)):
"""获取当前用户信息(带末尾斜杠)"""
return current_user
@router.get("/", response_model=UserListResponse)
async def get_users(
skip: int = 0,

View File

@@ -9,6 +9,8 @@ class AlgorithmBase(BaseModel):
name: str = Field(..., description="算法名称")
description: str = Field(..., description="算法描述")
type: str = Field(..., description="算法类型")
tech_category: str = Field(default="computer_vision", description="技术分类")
output_type: str = Field(default="image", description="输出类型")
class AlgorithmCreate(AlgorithmBase):

View File

@@ -0,0 +1,165 @@
from typing import Dict, Any, List
import asyncio
import httpx
import logging
logger = logging.getLogger(__name__)
class ComparisonService:
"""效果对比服务"""
async def compare_algorithms(
self,
input_data: Dict[str, Any],
algorithm_configs: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""比较多个算法的效果
Args:
input_data: 输入数据
algorithm_configs: 算法配置列表每个配置包含服务URL、参数等
Returns:
对比结果
"""
try:
# 异步执行所有算法
tasks = []
for config in algorithm_configs:
task = self._execute_algorithm(config, input_data)
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理结果
comparison_results = []
for i, result in enumerate(results):
if isinstance(result, Exception):
comparison_results.append({
"algorithm_id": algorithm_configs[i].get("id"),
"algorithm_name": algorithm_configs[i].get("name"),
"success": False,
"error": str(result),
"output": None,
"execution_time": 0
})
else:
comparison_results.append({
"algorithm_id": algorithm_configs[i].get("id"),
"algorithm_name": algorithm_configs[i].get("name"),
"success": True,
"error": None,
"output": result.get("output"),
"execution_time": result.get("execution_time", 0)
})
return {
"success": True,
"results": comparison_results,
"input_data": input_data
}
except Exception as e:
logger.error(f"Comparison error: {str(e)}")
return {
"success": False,
"error": str(e),
"results": []
}
async def _execute_algorithm(
self,
config: Dict[str, Any],
input_data: Dict[str, Any]
) -> Dict[str, Any]:
"""执行单个算法
Args:
config: 算法配置
input_data: 输入数据
Returns:
执行结果
"""
import time
start_time = time.time()
try:
url = config.get("url")
params = config.get("params", {})
if not url:
raise ValueError("缺少算法服务URL")
# 构建请求数据
request_data = {
"input_data": input_data.get("input_data", input_data),
"params": params
}
# 发送请求
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(f"{url}/predict", json=request_data)
response.raise_for_status()
result = response.json()
execution_time = time.time() - start_time
return {
"output": result,
"execution_time": execution_time
}
except Exception as e:
logger.error(f"Algorithm execution error: {str(e)}")
raise e
def generate_comparison_report(self, comparison_results: Dict[str, Any]) -> Dict[str, Any]:
"""生成对比报告
Args:
comparison_results: 对比结果
Returns:
对比报告
"""
try:
if not comparison_results.get("success"):
return {
"success": False,
"error": comparison_results.get("error", "对比失败")
}
results = comparison_results.get("results", [])
# 分析结果
successful_algorithms = [r for r in results if r.get("success")]
failed_algorithms = [r for r in results if not r.get("success")]
# 计算平均执行时间
if successful_algorithms:
avg_execution_time = sum(r.get("execution_time", 0) for r in successful_algorithms) / len(successful_algorithms)
else:
avg_execution_time = 0
# 生成报告
report = {
"summary": {
"total_algorithms": len(results),
"successful_algorithms": len(successful_algorithms),
"failed_algorithms": len(failed_algorithms),
"average_execution_time": round(avg_execution_time, 2)
},
"details": results,
"input_data": comparison_results.get("input_data")
}
return {
"success": True,
"report": report
}
except Exception as e:
logger.error(f"Report generation error: {str(e)}")
return {
"success": False,
"error": str(e)
}

View File

@@ -0,0 +1,165 @@
from typing import Optional, Dict, Any, List
from sqlalchemy.orm import Session
from app.models.models import ServiceConfig
import uuid
import logging
logger = logging.getLogger(__name__)
class ConfigService:
"""配置服务"""
@staticmethod
def get_config(db: Session, config_key: str) -> Optional[Dict[str, Any]]:
"""获取配置
Args:
db: 数据库会话
config_key: 配置键
Returns:
配置值如果不存在返回None
"""
config = db.query(ServiceConfig).filter_by(
config_key=config_key,
status="active"
).first()
if config:
return config.config_value
return None
@staticmethod
def set_config(db: Session, config_key: str, config_value: Dict[str, Any],
config_type: str = "system", service_id: Optional[str] = None,
description: str = "") -> bool:
"""设置配置
Args:
db: 数据库会话
config_key: 配置键
config_value: 配置值
config_type: 配置类型,默认为"system"
service_id: 服务ID系统配置可为None
description: 配置描述
Returns:
是否设置成功
"""
try:
# 检查是否存在
existing_config = db.query(ServiceConfig).filter_by(
config_key=config_key
).first()
if existing_config:
# 更新现有配置
existing_config.config_value = config_value
existing_config.config_type = config_type
existing_config.service_id = service_id
existing_config.description = description
existing_config.status = "active"
else:
# 创建新配置
new_config = ServiceConfig(
id=f"config-{uuid.uuid4()}",
config_key=config_key,
config_value=config_value,
config_type=config_type,
service_id=service_id,
description=description,
status="active"
)
db.add(new_config)
db.commit()
return True
except Exception as e:
logger.error(f"Failed to set config: {str(e)}")
db.rollback()
return False
@staticmethod
def get_service_configs(db: Session, service_id: str) -> List[Dict[str, Any]]:
"""获取服务的所有配置
Args:
db: 数据库会话
service_id: 服务ID
Returns:
服务配置列表
"""
configs = db.query(ServiceConfig).filter_by(
service_id=service_id,
status="active"
).all()
return [
{
"key": config.config_key,
"value": config.config_value,
"type": config.config_type,
"description": config.description
}
for config in configs
]
@staticmethod
def delete_config(db: Session, config_key: str) -> bool:
"""删除配置
Args:
db: 数据库会话
config_key: 配置键
Returns:
是否删除成功
"""
try:
config = db.query(ServiceConfig).filter_by(
config_key=config_key
).first()
if config:
config.status = "inactive"
db.commit()
return True
except Exception as e:
logger.error(f"Failed to delete config: {str(e)}")
db.rollback()
return False
@staticmethod
def get_all_configs(db: Session, config_type: Optional[str] = None) -> List[Dict[str, Any]]:
"""获取所有配置
Args:
db: 数据库会话
config_type: 配置类型,可选
Returns:
配置列表
"""
query = db.query(ServiceConfig).filter_by(status="active")
if config_type:
query = query.filter_by(config_type=config_type)
configs = query.all()
return [
{
"id": config.id,
"key": config.config_key,
"value": config.config_value,
"type": config.config_type,
"service_id": config.service_id,
"description": config.description,
"created_at": config.created_at,
"updated_at": config.updated_at
}
for config in configs
]

View File

@@ -63,22 +63,41 @@ class ProjectAnalyzer:
Returns:
项目类型,如 "python", "java", "nodejs"
"""
# 检查Python项目
# 检查Python项目 - 先检查根目录
if os.path.exists(os.path.join(repo_path, "requirements.txt")) or \
os.path.exists(os.path.join(repo_path, "pyproject.toml")) or \
any(file.endswith(".py") for file in os.listdir(repo_path)):
return "python"
# 检查Java项目
# 检查Python项目 - 递归检查子目录
for root, dirs, files in os.walk(repo_path):
if "requirements.txt" in files or "pyproject.toml" in files:
return "python"
if any(file.endswith(".py") for file in files):
return "python"
# 检查Java项目 - 先检查根目录
if os.path.exists(os.path.join(repo_path, "pom.xml")) or \
os.path.exists(os.path.join(repo_path, "build.gradle")) or \
os.path.exists(os.path.join(repo_path, "src")):
return "java"
# 检查Node.js项目
# 检查Java项目 - 递归检查子目录
for root, dirs, files in os.walk(repo_path):
if "pom.xml" in files or "build.gradle" in files:
return "java"
if "src" in dirs:
return "java"
# 检查Node.js项目 - 先检查根目录
if os.path.exists(os.path.join(repo_path, "package.json")):
return "nodejs"
# 检查Node.js项目 - 递归检查子目录
for root, dirs, files in os.walk(repo_path):
if "package.json" in files:
return "nodejs"
# 检查其他项目类型
if os.path.exists(os.path.join(repo_path, "CMakeLists.txt")):
return "c++"

View File

@@ -38,13 +38,14 @@ class ServiceOrchestrator:
self.client = None
print("使用本地进程部署模式")
def deploy_service(self, service_id: str, service_config: Dict[str, Any], project_info: Dict[str, Any]) -> Dict[str, Any]:
def deploy_service(self, service_id: str, service_config: Dict[str, Any], project_info: Dict[str, Any], repo_path: str = None) -> Dict[str, Any]:
"""部署服务
Args:
service_id: 服务ID
service_config: 服务配置
project_info: 项目信息
repo_path: 仓库路径(用于复制真实的算法文件)
Returns:
部署结果
@@ -95,7 +96,7 @@ class ServiceOrchestrator:
service_dir = self._create_service_directory(service_id)
# 2. 生成服务包装器
self._generate_local_service_wrapper(service_dir, project_info, service_config)
self._generate_local_service_wrapper(service_dir, project_info, service_config, repo_path)
# 3. 启动服务进程
process_info = self._start_local_service_process(service_id, service_dir, project_info, service_config)
@@ -176,9 +177,13 @@ class ServiceOrchestrator:
else:
# 本地进程启动
if service_id not in self.processes:
# 服务不在进程列表中,可能是服务重启导致的
# 这种情况下,需要从外部重新注册服务
# 暂时返回错误,建议用户重新注册服务
print(f"服务 {service_id} 不在进程列表中,无法启动")
return {
"success": False,
"error": "服务不存在",
"error": "服务不存在,请重新注册服务",
"service_id": service_id,
"status": "error"
}
@@ -272,11 +277,18 @@ class ServiceOrchestrator:
else:
# 本地进程停止
if service_id not in self.processes:
# 服务不在进程列表中,可能是服务重启导致的
# 尝试通过端口查找并停止进程
print(f"服务 {service_id} 不在进程列表中,尝试通过端口查找进程")
# 从服务配置中获取端口信息
# 这里需要从外部传入服务配置,或者从数据库查询
# 暂时返回成功,因为服务可能已经停止了
return {
"success": False,
"error": "服务不存在",
"success": True,
"service_id": service_id,
"status": "error"
"status": "stopped",
"error": None
}
process_info = self.processes[service_id]
@@ -1271,13 +1283,14 @@ json
os.makedirs(service_dir, exist_ok=True)
return service_dir
def _generate_local_service_wrapper(self, service_dir: str, project_info: Dict[str, Any], service_config: Dict[str, Any]):
def _generate_local_service_wrapper(self, service_dir: str, project_info: Dict[str, Any], service_config: Dict[str, Any], repo_path: str = None):
"""生成本地服务包装器
Args:
service_dir: 服务目录
project_info: 项目信息
service_config: 服务配置
repo_path: 仓库路径(用于复制真实的算法文件)
"""
# 生成服务包装器
service_wrapper_content = self._generate_service_wrapper(project_info, service_config)
@@ -1285,7 +1298,44 @@ json
with open(os.path.join(service_dir, f"service_wrapper{wrapper_extension}"), "w") as f:
f.write(service_wrapper_content)
# 创建模拟的算法文件
# 复制真实的算法文件
if repo_path and project_info["project_type"] == "python":
# 尝试找到并复制主要的算法文件
entry_point = project_info.get("entry_point")
if entry_point:
source_file = os.path.join(repo_path, entry_point)
if os.path.exists(source_file):
# 复制算法文件到服务目录
import shutil
shutil.copy2(source_file, os.path.join(service_dir, "algorithm.py"))
print(f"已复制算法文件: {source_file} -> {os.path.join(service_dir, 'algorithm.py')}")
return
# 如果没有找到入口点尝试复制所有Python文件
if os.path.exists(repo_path):
import shutil
for root, dirs, files in os.walk(repo_path):
for file in files:
if file.endswith(".py") and not file.startswith("_"):
source_file = os.path.join(root, file)
dest_file = os.path.join(service_dir, file)
shutil.copy2(source_file, dest_file)
print(f"已复制Python文件: {source_file} -> {dest_file}")
# 如果有algorithm.py就使用它否则创建一个模拟的
if not os.path.exists(os.path.join(service_dir, "algorithm.py")):
print("未找到algorithm.py创建模拟算法文件")
self._create_mock_algorithm(service_dir)
else:
# 创建模拟的算法文件
self._create_mock_algorithm(service_dir)
def _create_mock_algorithm(self, service_dir: str):
"""创建模拟的算法文件
Args:
service_dir: 服务目录
"""
algorithm_content = """
def predict(data):
return {"result": "Prediction result", "input": data}
@@ -1316,9 +1366,9 @@ def main(data):
# 构建启动命令
if project_info["project_type"] == "python":
cmd = ["python", f"service_wrapper.py"]
cmd = ["python", "service_wrapper.py"]
else:
cmd = ["node", f"service_wrapper.js"]
cmd = ["node", "service_wrapper.js"]
# 设置环境变量
env = os.environ.copy()