good version for 算法注册

This commit is contained in:
2026-02-15 21:23:28 +08:00
parent 3c03777b97
commit 62ea5d36a5
115 changed files with 9566 additions and 1576 deletions

View File

View File

@@ -1,5 +1,6 @@
from pydantic_settings import BaseSettings
from typing import Optional
from typing import Optional, Dict, Any
from sqlalchemy.orm import Session
class Settings(BaseSettings):
@@ -46,11 +47,56 @@ class Settings(BaseSettings):
GITEA_DEFAULT_OWNER: str = ""
GITEA_REPO_PREFIX: str = "AI"
# 服务管理配置
SERVICE_MANAGEMENT: Dict[str, Any] = {
"mode": "supervisor", # 服务管理模式local, docker, supervisor
"service_root_dir": "/opt/ai-services",
"supervisor_config_dir": "/etc/supervisor/conf.d",
}
class Config:
env_file = ".env"
case_sensitive = True
extra = "allow" # 允许额外的环境变量
def get_config(self, config_key: str, default: Any = None) -> Any:
"""获取配置,优先级:环境变量 > 数据库 > 文件默认值
Args:
config_key: 配置键
default: 默认值
Returns:
配置值
"""
# 1. 先从环境变量获取
env_key = config_key.upper().replace('.', '_')
env_value = getattr(self, env_key, None)
if env_value is not None:
return env_value
# 2. 从数据库获取
try:
from app.models.database import SessionLocal
from app.models.models import ServiceConfig
db: Session = SessionLocal()
try:
config = db.query(ServiceConfig).filter_by(
config_key=config_key,
status="active"
).first()
if config:
return config.config_value
finally:
db.close()
except Exception as e:
# 数据库连接失败时,返回默认值
print(f"Failed to load config from database: {str(e)}")
# 3. 返回默认值
return default
# 创建全局配置实例
settings = Settings()

View File

@@ -72,9 +72,24 @@ class APIGateway:
if not version_info:
raise HTTPException(status_code=404, detail="Algorithm version not found")
# 在实际实现中这里会根据version_info.url将请求转发到对应的算法服务
# 现在我们模拟调用过程
algorithm_url = version_info.url if hasattr(version_info, 'url') else f"http://localhost:8001/algorithms/{algorithm_id}/execute"
# 尝试从算法版本获取URL如果没有则尝试从服务表获取
algorithm_url = None
# 首先检查版本信息中是否有URL
if hasattr(version_info, 'url') and version_info.url:
algorithm_url = version_info.url
else:
# 如果版本信息中没有URL尝试从服务表获取
from app.models.models import AlgorithmService
service = db.query(AlgorithmService).filter(
AlgorithmService.algorithm_name == algorithm_id
).first()
if service and service.api_url:
algorithm_url = service.api_url
else:
# 如果都没有,使用默认的本地端点
algorithm_url = f"http://localhost:8001/algorithms/{algorithm_id}/execute"
# 使用httpx调用算法服务
async with httpx.AsyncClient(timeout=30.0) as client:

View File

@@ -4,7 +4,8 @@ from fastapi.middleware.trustedhost import TrustedHostMiddleware
from app.config.settings import settings
from app.models.database import engine, Base
from app.routes import user, algorithm, openai, gateway, services, data_management, monitoring, permissions, history, deployment, gitea, repositories
from app.models import models, api # 导入所有模型以确保表被创建
from app.routes import user, algorithm, openai, gateway, services, data_management, monitoring, permissions, history, deployment, gitea, repositories, config, comparison, api_management
# 创建数据库表
Base.metadata.create_all(bind=engine)
@@ -48,6 +49,9 @@ app.include_router(history.router, prefix=settings.API_V1_STR)
app.include_router(deployment.router)
app.include_router(gitea.router, prefix=settings.API_V1_STR)
app.include_router(repositories.router, prefix=settings.API_V1_STR)
app.include_router(config.router, prefix=settings.API_V1_STR)
app.include_router(comparison.router, prefix=settings.API_V1_STR)
app.include_router(api_management.router, prefix=settings.API_V1_STR)
@app.get("/")

Binary file not shown.

74
backend/app/models/api.py Normal file
View File

@@ -0,0 +1,74 @@
"""API封装模型"""
from sqlalchemy import Column, String, DateTime, Text, Boolean, ForeignKey
from sqlalchemy.dialects.postgresql import JSON
from sqlalchemy.orm import relationship
from datetime import datetime
from app.models.database import Base
import uuid
class ApiEndpoint(Base):
"""API端点模型"""
__tablename__ = "api_endpoints"
id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
name = Column(String, nullable=False, index=True) # API名称
description = Column(Text, default="") # API描述
path = Column(String, nullable=False, unique=True, index=True) # API路径如 /api/image-classification
method = Column(String, default="POST") # HTTP方法GET, POST, PUT, DELETE
algorithm_id = Column(String, ForeignKey("algorithms.id"), nullable=False, index=True) # 关联的算法ID
version_id = Column(String, ForeignKey("algorithm_versions.id"), nullable=False, index=True) # 关联的算法版本ID
service_id = Column(String, ForeignKey("algorithm_services.service_id"), nullable=True, index=True) # 关联的服务ID
# API配置
config = Column(JSON, nullable=False, default={}) # API配置超时、重试等
request_schema = Column(JSON, nullable=True) # 请求参数schema
response_schema = Column(JSON, nullable=True) # 响应参数schema
# 权限配置
requires_auth = Column(Boolean, default=True) # 是否需要认证
allowed_roles = Column(JSON, default=[]) # 允许的角色列表
rate_limit = Column(JSON, nullable=True) # 限流配置,如 {"max_requests": 100, "window": 60}
# 状态
status = Column(String, default="active", index=True) # 状态active, inactive, deprecated
is_public = Column(Boolean, default=False) # 是否公开
# 统计信息
call_count = Column(String, default="0") # 调用次数
success_count = Column(String, default="0") # 成功次数
error_count = Column(String, default="0") # 错误次数
avg_response_time = Column(String, default="0.0") # 平均响应时间
# 时间戳
created_at = Column(DateTime(timezone=True), default=datetime.utcnow)
updated_at = Column(DateTime(timezone=True), onupdate=datetime.utcnow)
last_called_at = Column(DateTime(timezone=True), nullable=True) # 最后调用时间
class ApiCallLog(Base):
"""API调用日志模型"""
__tablename__ = "api_call_logs"
id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
api_endpoint_id = Column(String, ForeignKey("api_endpoints.id"), nullable=False, index=True) # API端点ID
user_id = Column(String, ForeignKey("users.id"), nullable=True, index=True) # 调用用户ID
# 请求信息
request_method = Column(String, nullable=False) # 请求方法
request_path = Column(String, nullable=False) # 请求路径
request_headers = Column(JSON, nullable=True) # 请求头
request_body = Column(JSON, nullable=True) # 请求体
# 响应信息
response_status = Column(String, nullable=False) # 响应状态码
response_body = Column(JSON, nullable=True) # 响应体
response_time = Column(String, nullable=False) # 响应时间(秒)
# 错误信息
error_message = Column(Text, nullable=True) # 错误信息
error_type = Column(String, nullable=True) # 错误类型
# 时间戳
created_at = Column(DateTime(timezone=True), default=datetime.utcnow, index=True) # 调用时间

View File

@@ -13,6 +13,8 @@ class Algorithm(Base):
name = Column(String, nullable=False, index=True)
description = Column(Text, nullable=False)
type = Column(String, nullable=False, index=True) # computer_vision, nlp, ml, edge_computing, medical, autonomous_driving等
tech_category = Column(String, nullable=False, default="computer_vision") # 技术分类:计算机视觉、视频处理、自然语言处理等
output_type = Column(String, nullable=False, default="image") # 输出类型图片、视频、文本、JSON等
status = Column(String, default="active", index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
@@ -159,6 +161,8 @@ class AlgorithmService(Base):
name = Column(String, nullable=False, index=True) # 服务名称
algorithm_name = Column(String, nullable=False) # 算法名称
version = Column(String, nullable=False) # 版本
tech_category = Column(String, nullable=False, default="computer_vision") # 技术分类:计算机视觉、视频处理、自然语言处理等
output_type = Column(String, nullable=False, default="image") # 输出类型图片、视频、文本、JSON等
host = Column(String, nullable=False) # 主机地址
port = Column(Integer, nullable=False) # 端口
api_url = Column(String, nullable=False) # API地址
@@ -172,3 +176,18 @@ class AlgorithmService(Base):
# 添加Algorithm模型的repository关系
Algorithm.repository = relationship("AlgorithmRepository", back_populates="algorithm", uselist=False)
class ServiceConfig(Base):
"""服务配置模型"""
__tablename__ = "service_configs"
id = Column(String, primary_key=True, index=True)
config_key = Column(String, nullable=False, unique=True, index=True) # 配置键
config_value = Column(JSON, nullable=False) # 配置值JSON格式
config_type = Column(String, nullable=False) # 配置类型system, service, user
service_id = Column(String, nullable=True, index=True) # 服务ID可为空系统配置
description = Column(Text, default="") # 配置描述
status = Column(String, default="active") # 状态
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter
from app.routes import user, algorithm, history, gateway, monitoring, openai, deployment
from app.routes import user, algorithm, history, gateway, monitoring, openai, deployment, config, comparison
api_router = APIRouter()
@@ -11,3 +11,5 @@ api_router.include_router(gateway.router, prefix="/gateway", tags=["gateway"])
api_router.include_router(monitoring.router, prefix="/monitoring", tags=["monitoring"])
api_router.include_router(openai.router, prefix="/openai", tags=["openai"])
api_router.include_router(deployment.router, tags=["deployment"])
api_router.include_router(config.router, tags=["config"])
api_router.include_router(comparison.router, tags=["comparison"])

Binary file not shown.

View File

@@ -33,6 +33,18 @@ async def create_algorithm(
@router.get("", response_model=AlgorithmListResponse)
async def get_algorithms_no_slash(
skip: int = 0,
limit: int = 100,
type: Optional[str] = None,
db: Session = Depends(get_db)
):
"""获取算法列表(不带末尾斜杠)"""
algorithms = AlgorithmService.get_algorithms(db, skip=skip, limit=limit, algorithm_type=type)
return {"algorithms": algorithms, "total": len(algorithms)}
@router.get("/", response_model=AlgorithmListResponse)
async def get_algorithms(
skip: int = 0,
limit: int = 100,

View File

@@ -0,0 +1,510 @@
"""API管理路由处理API端点的封装和管理"""
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
from sqlalchemy.orm import Session
import logging
from datetime import datetime
from app.models.database import get_db
from app.models.models import Algorithm, AlgorithmVersion, AlgorithmService, User
from app.models.api import ApiEndpoint, ApiCallLog
from app.schemas.user import UserResponse
from app.routes.user import get_current_active_user
router = APIRouter(prefix="/api-management", tags=["api-management"])
logger = logging.getLogger(__name__)
class ApiEndpointCreate(BaseModel):
"""创建API端点请求模型"""
name: str
description: str = ""
path: str
method: str = "POST"
algorithm_id: str
version_id: str
service_id: Optional[str] = None
requires_auth: bool = True
allowed_roles: List[str] = []
rate_limit: Optional[Dict[str, Any]] = None
is_public: bool = False
config: Dict[str, Any] = {}
class ApiEndpointUpdate(BaseModel):
"""更新API端点请求模型"""
name: Optional[str] = None
description: Optional[str] = None
path: Optional[str] = None
method: Optional[str] = None
requires_auth: Optional[bool] = None
allowed_roles: Optional[List[str]] = None
rate_limit: Optional[Dict[str, Any]] = None
is_public: Optional[bool] = None
config: Optional[Dict[str, Any]] = None
status: Optional[str] = None
class ApiEndpointResponse(BaseModel):
"""API端点响应模型"""
id: str
name: str
description: str
path: str
method: str
algorithm_id: str
algorithm_name: str
version_id: str
version: str
service_id: Optional[str]
status: str
is_public: bool
call_count: str
success_count: str
error_count: str
avg_response_time: str
created_at: datetime
updated_at: Optional[datetime]
last_called_at: Optional[datetime]
class ApiEndpointListResponse(BaseModel):
"""API端点列表响应模型"""
endpoints: List[ApiEndpointResponse]
total: int
class ApiStatsResponse(BaseModel):
"""API统计响应模型"""
total_endpoints: int
active_endpoints: int
total_calls: str
total_success: str
total_errors: str
avg_response_time: str
@router.get("/endpoints", response_model=ApiEndpointListResponse)
async def get_api_endpoints(
skip: int = 0,
limit: int = 100,
algorithm_id: Optional[str] = None,
status: Optional[str] = None,
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""获取API端点列表"""
try:
# 构建查询
query = db.query(ApiEndpoint)
# 筛选条件
if algorithm_id:
query = query.filter(ApiEndpoint.algorithm_id == algorithm_id)
if status:
query = query.filter(ApiEndpoint.status == status)
# 分页
endpoints = query.offset(skip).limit(limit).all()
total = query.count()
# 构建响应
endpoint_responses = []
for endpoint in endpoints:
# 获取关联的算法和版本信息
algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first()
version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first()
endpoint_responses.append({
"id": endpoint.id,
"name": endpoint.name,
"description": endpoint.description,
"path": endpoint.path,
"method": endpoint.method,
"algorithm_id": endpoint.algorithm_id,
"algorithm_name": algorithm.name if algorithm else "",
"version_id": endpoint.version_id,
"version": version.version if version else "",
"service_id": endpoint.service_id,
"status": endpoint.status,
"is_public": endpoint.is_public,
"call_count": endpoint.call_count,
"success_count": endpoint.success_count,
"error_count": endpoint.error_count,
"avg_response_time": endpoint.avg_response_time,
"created_at": endpoint.created_at,
"updated_at": endpoint.updated_at,
"last_called_at": endpoint.last_called_at
})
return {
"endpoints": endpoint_responses,
"total": total
}
except Exception as e:
logger.error(f"获取API端点列表失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取API端点列表失败: {str(e)}")
@router.get("/endpoints/{endpoint_id}", response_model=ApiEndpointResponse)
async def get_api_endpoint(
endpoint_id: str,
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""获取API端点详情"""
try:
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
if not endpoint:
raise HTTPException(status_code=404, detail="API端点不存在")
# 获取关联的算法和版本信息
algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first()
version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first()
return {
"id": endpoint.id,
"name": endpoint.name,
"description": endpoint.description,
"path": endpoint.path,
"method": endpoint.method,
"algorithm_id": endpoint.algorithm_id,
"algorithm_name": algorithm.name if algorithm else "",
"version_id": endpoint.version_id,
"version": version.version if version else "",
"service_id": endpoint.service_id,
"status": endpoint.status,
"is_public": endpoint.is_public,
"call_count": endpoint.call_count,
"success_count": endpoint.success_count,
"error_count": endpoint.error_count,
"avg_response_time": endpoint.avg_response_time,
"created_at": endpoint.created_at,
"updated_at": endpoint.updated_at,
"last_called_at": endpoint.last_called_at
}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取API端点详情失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取API端点详情失败: {str(e)}")
@router.post("/endpoints", response_model=ApiEndpointResponse, status_code=status.HTTP_201_CREATED)
async def create_api_endpoint(
request: ApiEndpointCreate,
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""创建API端点"""
try:
# 检查用户权限
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
raise HTTPException(status_code=403, detail="Insufficient permissions")
# 验证算法和版本是否存在
algorithm = db.query(Algorithm).filter(Algorithm.id == request.algorithm_id).first()
if not algorithm:
raise HTTPException(status_code=404, detail="算法不存在")
version = db.query(AlgorithmVersion).filter(
AlgorithmVersion.id == request.version_id
).first()
if not version or version.algorithm_id != request.algorithm_id:
raise HTTPException(status_code=404, detail="算法版本不存在")
# 如果指定了服务ID验证服务是否存在
if request.service_id:
service = db.query(AlgorithmService).filter(
AlgorithmService.service_id == request.service_id
).first()
if not service:
raise HTTPException(status_code=404, detail="服务不存在")
# 检查API路径是否已存在
existing_endpoint = db.query(ApiEndpoint).filter(
ApiEndpoint.path == request.path
).first()
if existing_endpoint:
raise HTTPException(status_code=400, detail="API路径已存在")
# 创建API端点
new_endpoint = ApiEndpoint(
name=request.name,
description=request.description,
path=request.path,
method=request.method,
algorithm_id=request.algorithm_id,
version_id=request.version_id,
service_id=request.service_id,
requires_auth=request.requires_auth,
allowed_roles=request.allowed_roles,
rate_limit=request.rate_limit,
is_public=request.is_public,
config=request.config,
status="active",
call_count="0",
success_count="0",
error_count="0",
avg_response_time="0.0"
)
db.add(new_endpoint)
db.commit()
db.refresh(new_endpoint)
# 返回创建的API端点
return {
"id": new_endpoint.id,
"name": new_endpoint.name,
"description": new_endpoint.description,
"path": new_endpoint.path,
"method": new_endpoint.method,
"algorithm_id": new_endpoint.algorithm_id,
"algorithm_name": algorithm.name,
"version_id": new_endpoint.version_id,
"version": version.version,
"service_id": new_endpoint.service_id,
"status": new_endpoint.status,
"is_public": new_endpoint.is_public,
"call_count": new_endpoint.call_count,
"success_count": new_endpoint.success_count,
"error_count": new_endpoint.error_count,
"avg_response_time": new_endpoint.avg_response_time,
"created_at": new_endpoint.created_at,
"updated_at": new_endpoint.updated_at,
"last_called_at": new_endpoint.last_called_at
}
except HTTPException:
raise
except Exception as e:
logger.error(f"创建API端点失败: {str(e)}")
db.rollback()
raise HTTPException(status_code=500, detail=f"创建API端点失败: {str(e)}")
@router.put("/endpoints/{endpoint_id}", response_model=ApiEndpointResponse)
async def update_api_endpoint(
endpoint_id: str,
request: ApiEndpointUpdate,
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""更新API端点"""
try:
# 检查用户权限
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
raise HTTPException(status_code=403, detail="Insufficient permissions")
# 查询API端点
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
if not endpoint:
raise HTTPException(status_code=404, detail="API端点不存在")
# 更新字段
if request.name is not None:
endpoint.name = request.name
if request.description is not None:
endpoint.description = request.description
if request.path is not None:
# 检查新路径是否已被其他端点使用
existing_endpoint = db.query(ApiEndpoint).filter(
ApiEndpoint.path == request.path,
ApiEndpoint.id != endpoint_id
).first()
if existing_endpoint:
raise HTTPException(status_code=400, detail="API路径已存在")
endpoint.path = request.path
if request.method is not None:
endpoint.method = request.method
if request.requires_auth is not None:
endpoint.requires_auth = request.requires_auth
if request.allowed_roles is not None:
endpoint.allowed_roles = request.allowed_roles
if request.rate_limit is not None:
endpoint.rate_limit = request.rate_limit
if request.is_public is not None:
endpoint.is_public = request.is_public
if request.config is not None:
endpoint.config = request.config
if request.status is not None:
endpoint.status = request.status
db.commit()
db.refresh(endpoint)
# 获取关联的算法和版本信息
algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first()
version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first()
return {
"id": endpoint.id,
"name": endpoint.name,
"description": endpoint.description,
"path": endpoint.path,
"method": endpoint.method,
"algorithm_id": endpoint.algorithm_id,
"algorithm_name": algorithm.name if algorithm else "",
"version_id": endpoint.version_id,
"version": version.version if version else "",
"service_id": endpoint.service_id,
"status": endpoint.status,
"is_public": endpoint.is_public,
"call_count": endpoint.call_count,
"success_count": endpoint.success_count,
"error_count": endpoint.error_count,
"avg_response_time": endpoint.avg_response_time,
"created_at": endpoint.created_at,
"updated_at": endpoint.updated_at,
"last_called_at": endpoint.last_called_at
}
except HTTPException:
raise
except Exception as e:
logger.error(f"更新API端点失败: {str(e)}")
db.rollback()
raise HTTPException(status_code=500, detail=f"更新API端点失败: {str(e)}")
@router.delete("/endpoints/{endpoint_id}")
async def delete_api_endpoint(
endpoint_id: str,
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""删除API端点"""
try:
# 检查用户权限
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
raise HTTPException(status_code=403, detail="Insufficient permissions")
# 查询API端点
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
if not endpoint:
raise HTTPException(status_code=404, detail="API端点不存在")
# 删除API端点
db.delete(endpoint)
db.commit()
return {
"success": True,
"message": "API端点删除成功"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"删除API端点失败: {str(e)}")
db.rollback()
raise HTTPException(status_code=500, detail=f"删除API端点失败: {str(e)}")
@router.get("/stats", response_model=ApiStatsResponse)
async def get_api_stats(
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""获取API统计信息"""
try:
# 检查用户权限
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
raise HTTPException(status_code=403, detail="Insufficient permissions")
# 统计API端点
total_endpoints = db.query(ApiEndpoint).count()
active_endpoints = db.query(ApiEndpoint).filter(ApiEndpoint.status == "active").count()
# 统计调用次数
endpoints = db.query(ApiEndpoint).all()
total_calls = sum(int(e.call_count or 0) for e in endpoints)
total_success = sum(int(e.success_count or 0) for e in endpoints)
total_errors = sum(int(e.error_count or 0) for e in endpoints)
# 计算平均响应时间
avg_response_times = [float(e.avg_response_time or 0) for e in endpoints if float(e.avg_response_time or 0) > 0]
avg_response_time = sum(avg_response_times) / len(avg_response_times) if avg_response_times else 0.0
return {
"total_endpoints": total_endpoints,
"active_endpoints": active_endpoints,
"total_calls": str(total_calls),
"total_success": str(total_success),
"total_errors": str(total_errors),
"avg_response_time": f"{avg_response_time:.2f}"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取API统计信息失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取API统计信息失败: {str(e)}")
@router.post("/endpoints/{endpoint_id}/test")
async def test_api_endpoint(
endpoint_id: str,
payload: Dict[str, Any],
db: Session = Depends(get_db),
current_user: UserResponse = Depends(get_current_active_user)
):
"""测试API端点"""
try:
# 查询API端点
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
if not endpoint:
raise HTTPException(status_code=404, detail="API端点不存在")
# 检查API端点状态
if endpoint.status != "active":
raise HTTPException(status_code=400, detail="API端点未激活")
# 查询关联的服务
if endpoint.service_id:
service = db.query(AlgorithmService).filter(
AlgorithmService.service_id == endpoint.service_id
).first()
if not service or service.status != "running":
raise HTTPException(status_code=400, detail="关联服务未运行")
# 调用服务
import httpx
import time
service_url = service.api_url
if not service_url.endswith("/"):
service_url += "/"
call_url = f"{service_url}predict"
start_time = time.time()
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
call_url,
json=payload,
headers={"Content-Type": "application/json"}
)
response_time = time.time() - start_time
if response.status_code == 200:
return {
"success": True,
"result": response.json(),
"response_time": response_time,
"message": "API调用成功"
}
else:
return {
"success": False,
"error": f"服务返回错误: HTTP {response.status_code}",
"response_time": response_time
}
else:
raise HTTPException(status_code=400, detail="API端点未关联服务")
except HTTPException:
raise
except Exception as e:
logger.error(f"测试API端点失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"测试API端点失败: {str(e)}")

View File

@@ -0,0 +1,64 @@
from fastapi import APIRouter, Depends, HTTPException
from typing import Dict, Any, List
from app.services.comparison_service import ComparisonService
from app.routes.user import get_current_active_user
# 创建路由器
router = APIRouter(prefix="/comparison", tags=["comparison"])
# 创建对比服务实例
comparison_service = ComparisonService()
@router.post("/compare-algorithms", response_model=dict)
async def compare_algorithms(
request_data: Dict[str, Any],
current_user: dict = Depends(get_current_active_user)
):
"""比较多个算法的效果
Args:
request_data: 请求数据包含input_data和algorithm_configs
current_user: 当前活跃用户
Returns:
对比结果
"""
input_data = request_data.get("input_data")
algorithm_configs = request_data.get("algorithm_configs")
if not input_data:
raise HTTPException(status_code=400, detail="缺少 input_data 参数")
if not algorithm_configs or not isinstance(algorithm_configs, list):
raise HTTPException(status_code=400, detail="缺少 algorithm_configs 参数或格式错误")
result = await comparison_service.compare_algorithms(input_data, algorithm_configs)
if not result["success"]:
raise HTTPException(status_code=500, detail=result.get("error", "对比失败"))
return result
@router.post("/generate-report", response_model=dict)
async def generate_comparison_report(
comparison_results: Dict[str, Any],
current_user: dict = Depends(get_current_active_user)
):
"""生成对比报告
Args:
comparison_results: 对比结果
current_user: 当前活跃用户
Returns:
对比报告
"""
report = comparison_service.generate_comparison_report(comparison_results)
if not report["success"]:
raise HTTPException(status_code=500, detail=report.get("error", "生成报告失败"))
return report

View File

@@ -0,0 +1,124 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Dict, Any, List, Optional
from app.models.database import get_db
from app.services.config_service import ConfigService
from app.routes.user import get_current_active_user
router = APIRouter(prefix="/config", tags=["config"])
@router.get("/{config_key}")
async def get_config(
config_key: str,
db: Session = Depends(get_db),
current_user: dict = Depends(get_current_active_user)
):
"""获取配置
Args:
config_key: 配置键
db: 数据库会话
current_user: 当前活跃用户
Returns:
配置信息
"""
config = ConfigService.get_config(db, config_key)
if not config:
raise HTTPException(status_code=404, detail="配置不存在")
return {"key": config_key, "value": config}
@router.post("/{config_key}")
async def set_config(
config_key: str,
config_data: Dict[str, Any],
db: Session = Depends(get_db),
current_user: dict = Depends(get_current_active_user)
):
"""设置配置
Args:
config_key: 配置键
config_data: 配置数据包含value、type、service_id、description等字段
db: 数据库会话
current_user: 当前活跃用户
Returns:
设置结果
"""
success = ConfigService.set_config(
db=db,
config_key=config_key,
config_value=config_data.get("value"),
config_type=config_data.get("type", "system"),
service_id=config_data.get("service_id"),
description=config_data.get("description", "")
)
if not success:
raise HTTPException(status_code=400, detail="设置配置失败")
return {"message": "设置配置成功"}
@router.get("/service/{service_id}")
async def get_service_configs(
service_id: str,
db: Session = Depends(get_db),
current_user: dict = Depends(get_current_active_user)
):
"""获取服务配置
Args:
service_id: 服务ID
db: 数据库会话
current_user: 当前活跃用户
Returns:
服务配置列表
"""
configs = ConfigService.get_service_configs(db, service_id)
return {"service_id": service_id, "configs": configs}
@router.delete("/{config_key}")
async def delete_config(
config_key: str,
db: Session = Depends(get_db),
current_user: dict = Depends(get_current_active_user)
):
"""删除配置
Args:
config_key: 配置键
db: 数据库会话
current_user: 当前活跃用户
Returns:
删除结果
"""
success = ConfigService.delete_config(db, config_key)
if not success:
raise HTTPException(status_code=400, detail="删除配置失败")
return {"message": "删除配置成功"}
@router.get("/")
async def get_all_configs(
config_type: Optional[str] = None,
db: Session = Depends(get_db),
current_user: dict = Depends(get_current_active_user)
):
"""获取所有配置
Args:
config_type: 配置类型,可选
db: 数据库会话
current_user: 当前活跃用户
Returns:
配置列表
"""
configs = ConfigService.get_all_configs(db, config_type)
return {"configs": configs}

View File

@@ -14,6 +14,7 @@ from app.schemas.user import UserResponse
from app.services.project_analyzer import ProjectAnalyzer
from app.services.service_generator import ServiceGenerator
from app.services.service_orchestrator import ServiceOrchestrator
from app.gitea.service import gitea_service
router = APIRouter(prefix="/services", tags=["services"])
@@ -23,6 +24,9 @@ class RegisterServiceRequest(BaseModel):
repository_id: str
name: str
version: str = "1.0.0"
description: Optional[str] = ""
tech_category: str = "computer_vision"
output_type: str = "image"
service_type: str = "http"
host: str = "0.0.0.0"
port: int = 8000
@@ -154,31 +158,24 @@ async def register_service(
# 记录仓库信息
print(f"仓库信息: {repo.name}, {repo.description}, {repo.repo_url}")
# 2. 分析项目
repo_path = f"/tmp/repository_{request.repository_id}"
# 注意:在实际实现中,应该从算法仓库中获取项目文件
# 这里简化处理,创建一个模拟的项目结构
os.makedirs(repo_path, exist_ok=True)
# 2. 从Gitea仓库克隆代码到本地
repo_path = f"/tmp/algorithms/{request.repository_id}"
# 创建模拟的算法文件
with open(os.path.join(repo_path, "algorithm.py"), "w") as f:
f.write("""
def predict(data):
return {"result": "Prediction result", "input": data}
def run(data):
return {"result": "Run result", "input": data}
def main(data):
return {"result": "Main result", "input": data}
""")
# 使用Gitea服务克隆仓库
clone_success = gitea_service.clone_repository(repo.repo_url, request.repository_id, repo.branch or "main")
if not clone_success:
raise HTTPException(status_code=400, detail=f"克隆仓库失败: {repo.repo_url}")
# 分析项目
print(f"仓库克隆成功: {repo_path}")
# 3. 分析项目
project_info = project_analyzer.analyze_project(repo_path)
if not project_info["success"]:
raise HTTPException(status_code=400, detail=f"项目分析失败: {project_info['error']}")
# 3. 生成服务包装器
print(f"项目分析成功: {project_info}")
# 4. 生成服务包装器
service_config = {
"name": request.name,
"version": request.version,
@@ -194,24 +191,31 @@ def main(data):
if not generate_result["success"]:
raise HTTPException(status_code=400, detail=f"服务生成失败: {generate_result['error']}")
# 4. 部署服务
print(f"服务生成成功: {generate_result}")
# 5. 部署服务
service_id = str(uuid.uuid4())
deploy_result = service_orchestrator.deploy_service(service_id, service_config, project_info)
deploy_result = service_orchestrator.deploy_service(service_id, service_config, project_info, repo_path)
if not deploy_result["success"]:
raise HTTPException(status_code=400, detail=f"服务部署失败: {deploy_result['error']}")
# 5. 保存服务信息到数据库
print(f"服务部署成功: {deploy_result}")
# 6. 保存服务信息到数据库
new_service = AlgorithmService(
id=str(uuid.uuid4()),
service_id=service_id,
name=request.name,
algorithm_name=repo.name, # 使用仓库名称作为算法名称
version=request.version,
tech_category=request.tech_category,
output_type=request.output_type,
host=request.host,
port=request.port,
api_url=deploy_result["api_url"],
status=deploy_result["status"],
config={
"repository_id": request.repository_id, # 保存仓库ID
"service_type": request.service_type,
"timeout": request.timeout,
"health_check_path": request.health_check_path,
@@ -352,8 +356,64 @@ async def start_service(
# 启动服务
start_result = service_orchestrator.start_service(service_id, container_id)
# 如果启动失败,尝试从数据库重新注册服务
if not start_result["success"]:
raise HTTPException(status_code=400, detail=f"服务启动失败: {start_result['error']}")
print(f"服务启动失败: {start_result['error']},尝试从数据库重新注册服务")
# 获取仓库信息
repository_id = service.config.get("repository_id")
if not repository_id:
raise HTTPException(status_code=400, detail="Repository ID not found in service config")
repository = db.query(AlgorithmRepository).filter(AlgorithmRepository.id == repository_id).first()
if not repository:
raise HTTPException(status_code=404, detail="Repository not found")
# 从Gitea克隆仓库
clone_success = gitea_service.clone_repository(
repository.repo_url,
service_id,
repository.branch or "main"
)
if not clone_success:
raise HTTPException(status_code=400, detail="Failed to clone repository")
# 仓库路径
repo_path = f"/tmp/algorithms/{service_id}"
# 分析项目
project_info = project_analyzer.analyze_project(repo_path)
if not project_info:
raise HTTPException(status_code=400, detail="Failed to analyze project")
# 生成服务
service_config = {
"name": service.name,
"version": service.version,
"host": service.host,
"port": service.port,
"timeout": service.config.get("timeout", 30),
"health_check_path": service.config.get("health_check_path", "/health"),
"environment": service.config.get("environment", {})
}
# 部署服务
deploy_result = service_orchestrator.deploy_service(service_id, project_info, service_config, repo_path)
if not deploy_result["success"]:
raise HTTPException(status_code=400, detail=f"服务部署失败: {deploy_result['error']}")
# 更新服务配置
service.config["container_id"] = deploy_result["container_id"]
service.api_url = deploy_result["api_url"]
db.commit()
start_result = {
"success": True,
"service_id": service_id,
"status": "running",
"error": None
}
# 更新服务状态
service.status = start_result["status"]
@@ -1065,3 +1125,108 @@ async def batch_delete_services(
)
finally:
db.close()
class ServiceCallRequest(BaseModel):
"""服务调用请求"""
service_id: str
payload: Dict[str, Any]
class ServiceCallResponse(BaseModel):
"""服务调用响应"""
success: bool
result: Dict[str, Any]
service_id: str
execution_time: float
error: Optional[str] = None
@router.post("/call")
async def call_service(
request: ServiceCallRequest,
current_user: UserResponse = Depends(get_current_active_user)
):
"""直接调用注册的服务"""
import time
import httpx
# 创建数据库会话
db = SessionLocal()
try:
# 查询服务
service = db.query(AlgorithmService).filter(
AlgorithmService.service_id == request.service_id
).first()
if not service:
raise HTTPException(status_code=404, detail="服务不存在")
# 检查服务状态
if service.status != "running":
raise HTTPException(
status_code=503,
detail=f"服务未运行,当前状态: {service.status}"
)
# 调用服务
start_time = time.time()
try:
# 构建服务URL
service_url = service.api_url
# 如果URL没有路径添加默认路径
if not service_url.endswith("/"):
service_url += "/"
# 添加调用端点
call_url = f"{service_url}predict"
# 使用httpx调用服务
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
call_url,
json=request.payload,
headers={"Content-Type": "application/json"}
)
execution_time = time.time() - start_time
if response.status_code == 200:
return ServiceCallResponse(
success=True,
result=response.json(),
service_id=request.service_id,
execution_time=execution_time
)
else:
return ServiceCallResponse(
success=False,
result={},
service_id=request.service_id,
execution_time=execution_time,
error=f"服务返回错误: HTTP {response.status_code} - {response.text}"
)
except httpx.RequestError as e:
execution_time = time.time() - start_time
return ServiceCallResponse(
success=False,
result={},
service_id=request.service_id,
execution_time=execution_time,
error=f"无法连接到服务: {str(e)}"
)
except Exception as e:
execution_time = time.time() - start_time
return ServiceCallResponse(
success=False,
result={},
service_id=request.service_id,
execution_time=execution_time,
error=f"服务调用异常: {str(e)}"
)
finally:
db.close()

View File

@@ -167,6 +167,12 @@ async def read_users_me(current_user: UserResponse = Depends(get_current_active_
return current_user
@router.get("/me/", response_model=UserResponse)
async def read_users_me_with_slash(current_user: UserResponse = Depends(get_current_active_user)):
"""获取当前用户信息(带末尾斜杠)"""
return current_user
@router.get("/", response_model=UserListResponse)
async def get_users(
skip: int = 0,

View File

@@ -9,6 +9,8 @@ class AlgorithmBase(BaseModel):
name: str = Field(..., description="算法名称")
description: str = Field(..., description="算法描述")
type: str = Field(..., description="算法类型")
tech_category: str = Field(default="computer_vision", description="技术分类")
output_type: str = Field(default="image", description="输出类型")
class AlgorithmCreate(AlgorithmBase):

View File

@@ -0,0 +1,165 @@
from typing import Dict, Any, List
import asyncio
import httpx
import logging
logger = logging.getLogger(__name__)
class ComparisonService:
"""效果对比服务"""
async def compare_algorithms(
self,
input_data: Dict[str, Any],
algorithm_configs: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""比较多个算法的效果
Args:
input_data: 输入数据
algorithm_configs: 算法配置列表每个配置包含服务URL、参数等
Returns:
对比结果
"""
try:
# 异步执行所有算法
tasks = []
for config in algorithm_configs:
task = self._execute_algorithm(config, input_data)
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理结果
comparison_results = []
for i, result in enumerate(results):
if isinstance(result, Exception):
comparison_results.append({
"algorithm_id": algorithm_configs[i].get("id"),
"algorithm_name": algorithm_configs[i].get("name"),
"success": False,
"error": str(result),
"output": None,
"execution_time": 0
})
else:
comparison_results.append({
"algorithm_id": algorithm_configs[i].get("id"),
"algorithm_name": algorithm_configs[i].get("name"),
"success": True,
"error": None,
"output": result.get("output"),
"execution_time": result.get("execution_time", 0)
})
return {
"success": True,
"results": comparison_results,
"input_data": input_data
}
except Exception as e:
logger.error(f"Comparison error: {str(e)}")
return {
"success": False,
"error": str(e),
"results": []
}
async def _execute_algorithm(
self,
config: Dict[str, Any],
input_data: Dict[str, Any]
) -> Dict[str, Any]:
"""执行单个算法
Args:
config: 算法配置
input_data: 输入数据
Returns:
执行结果
"""
import time
start_time = time.time()
try:
url = config.get("url")
params = config.get("params", {})
if not url:
raise ValueError("缺少算法服务URL")
# 构建请求数据
request_data = {
"input_data": input_data.get("input_data", input_data),
"params": params
}
# 发送请求
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(f"{url}/predict", json=request_data)
response.raise_for_status()
result = response.json()
execution_time = time.time() - start_time
return {
"output": result,
"execution_time": execution_time
}
except Exception as e:
logger.error(f"Algorithm execution error: {str(e)}")
raise e
def generate_comparison_report(self, comparison_results: Dict[str, Any]) -> Dict[str, Any]:
"""生成对比报告
Args:
comparison_results: 对比结果
Returns:
对比报告
"""
try:
if not comparison_results.get("success"):
return {
"success": False,
"error": comparison_results.get("error", "对比失败")
}
results = comparison_results.get("results", [])
# 分析结果
successful_algorithms = [r for r in results if r.get("success")]
failed_algorithms = [r for r in results if not r.get("success")]
# 计算平均执行时间
if successful_algorithms:
avg_execution_time = sum(r.get("execution_time", 0) for r in successful_algorithms) / len(successful_algorithms)
else:
avg_execution_time = 0
# 生成报告
report = {
"summary": {
"total_algorithms": len(results),
"successful_algorithms": len(successful_algorithms),
"failed_algorithms": len(failed_algorithms),
"average_execution_time": round(avg_execution_time, 2)
},
"details": results,
"input_data": comparison_results.get("input_data")
}
return {
"success": True,
"report": report
}
except Exception as e:
logger.error(f"Report generation error: {str(e)}")
return {
"success": False,
"error": str(e)
}

View File

@@ -0,0 +1,165 @@
from typing import Optional, Dict, Any, List
from sqlalchemy.orm import Session
from app.models.models import ServiceConfig
import uuid
import logging
logger = logging.getLogger(__name__)
class ConfigService:
"""配置服务"""
@staticmethod
def get_config(db: Session, config_key: str) -> Optional[Dict[str, Any]]:
"""获取配置
Args:
db: 数据库会话
config_key: 配置键
Returns:
配置值如果不存在返回None
"""
config = db.query(ServiceConfig).filter_by(
config_key=config_key,
status="active"
).first()
if config:
return config.config_value
return None
@staticmethod
def set_config(db: Session, config_key: str, config_value: Dict[str, Any],
config_type: str = "system", service_id: Optional[str] = None,
description: str = "") -> bool:
"""设置配置
Args:
db: 数据库会话
config_key: 配置键
config_value: 配置值
config_type: 配置类型,默认为"system"
service_id: 服务ID系统配置可为None
description: 配置描述
Returns:
是否设置成功
"""
try:
# 检查是否存在
existing_config = db.query(ServiceConfig).filter_by(
config_key=config_key
).first()
if existing_config:
# 更新现有配置
existing_config.config_value = config_value
existing_config.config_type = config_type
existing_config.service_id = service_id
existing_config.description = description
existing_config.status = "active"
else:
# 创建新配置
new_config = ServiceConfig(
id=f"config-{uuid.uuid4()}",
config_key=config_key,
config_value=config_value,
config_type=config_type,
service_id=service_id,
description=description,
status="active"
)
db.add(new_config)
db.commit()
return True
except Exception as e:
logger.error(f"Failed to set config: {str(e)}")
db.rollback()
return False
@staticmethod
def get_service_configs(db: Session, service_id: str) -> List[Dict[str, Any]]:
"""获取服务的所有配置
Args:
db: 数据库会话
service_id: 服务ID
Returns:
服务配置列表
"""
configs = db.query(ServiceConfig).filter_by(
service_id=service_id,
status="active"
).all()
return [
{
"key": config.config_key,
"value": config.config_value,
"type": config.config_type,
"description": config.description
}
for config in configs
]
@staticmethod
def delete_config(db: Session, config_key: str) -> bool:
"""删除配置
Args:
db: 数据库会话
config_key: 配置键
Returns:
是否删除成功
"""
try:
config = db.query(ServiceConfig).filter_by(
config_key=config_key
).first()
if config:
config.status = "inactive"
db.commit()
return True
except Exception as e:
logger.error(f"Failed to delete config: {str(e)}")
db.rollback()
return False
@staticmethod
def get_all_configs(db: Session, config_type: Optional[str] = None) -> List[Dict[str, Any]]:
"""获取所有配置
Args:
db: 数据库会话
config_type: 配置类型,可选
Returns:
配置列表
"""
query = db.query(ServiceConfig).filter_by(status="active")
if config_type:
query = query.filter_by(config_type=config_type)
configs = query.all()
return [
{
"id": config.id,
"key": config.config_key,
"value": config.config_value,
"type": config.config_type,
"service_id": config.service_id,
"description": config.description,
"created_at": config.created_at,
"updated_at": config.updated_at
}
for config in configs
]

View File

@@ -63,22 +63,41 @@ class ProjectAnalyzer:
Returns:
项目类型,如 "python", "java", "nodejs"
"""
# 检查Python项目
# 检查Python项目 - 先检查根目录
if os.path.exists(os.path.join(repo_path, "requirements.txt")) or \
os.path.exists(os.path.join(repo_path, "pyproject.toml")) or \
any(file.endswith(".py") for file in os.listdir(repo_path)):
return "python"
# 检查Java项目
# 检查Python项目 - 递归检查子目录
for root, dirs, files in os.walk(repo_path):
if "requirements.txt" in files or "pyproject.toml" in files:
return "python"
if any(file.endswith(".py") for file in files):
return "python"
# 检查Java项目 - 先检查根目录
if os.path.exists(os.path.join(repo_path, "pom.xml")) or \
os.path.exists(os.path.join(repo_path, "build.gradle")) or \
os.path.exists(os.path.join(repo_path, "src")):
return "java"
# 检查Node.js项目
# 检查Java项目 - 递归检查子目录
for root, dirs, files in os.walk(repo_path):
if "pom.xml" in files or "build.gradle" in files:
return "java"
if "src" in dirs:
return "java"
# 检查Node.js项目 - 先检查根目录
if os.path.exists(os.path.join(repo_path, "package.json")):
return "nodejs"
# 检查Node.js项目 - 递归检查子目录
for root, dirs, files in os.walk(repo_path):
if "package.json" in files:
return "nodejs"
# 检查其他项目类型
if os.path.exists(os.path.join(repo_path, "CMakeLists.txt")):
return "c++"

View File

@@ -38,13 +38,14 @@ class ServiceOrchestrator:
self.client = None
print("使用本地进程部署模式")
def deploy_service(self, service_id: str, service_config: Dict[str, Any], project_info: Dict[str, Any]) -> Dict[str, Any]:
def deploy_service(self, service_id: str, service_config: Dict[str, Any], project_info: Dict[str, Any], repo_path: str = None) -> Dict[str, Any]:
"""部署服务
Args:
service_id: 服务ID
service_config: 服务配置
project_info: 项目信息
repo_path: 仓库路径(用于复制真实的算法文件)
Returns:
部署结果
@@ -95,7 +96,7 @@ class ServiceOrchestrator:
service_dir = self._create_service_directory(service_id)
# 2. 生成服务包装器
self._generate_local_service_wrapper(service_dir, project_info, service_config)
self._generate_local_service_wrapper(service_dir, project_info, service_config, repo_path)
# 3. 启动服务进程
process_info = self._start_local_service_process(service_id, service_dir, project_info, service_config)
@@ -176,9 +177,13 @@ class ServiceOrchestrator:
else:
# 本地进程启动
if service_id not in self.processes:
# 服务不在进程列表中,可能是服务重启导致的
# 这种情况下,需要从外部重新注册服务
# 暂时返回错误,建议用户重新注册服务
print(f"服务 {service_id} 不在进程列表中,无法启动")
return {
"success": False,
"error": "服务不存在",
"error": "服务不存在,请重新注册服务",
"service_id": service_id,
"status": "error"
}
@@ -272,11 +277,18 @@ class ServiceOrchestrator:
else:
# 本地进程停止
if service_id not in self.processes:
# 服务不在进程列表中,可能是服务重启导致的
# 尝试通过端口查找并停止进程
print(f"服务 {service_id} 不在进程列表中,尝试通过端口查找进程")
# 从服务配置中获取端口信息
# 这里需要从外部传入服务配置,或者从数据库查询
# 暂时返回成功,因为服务可能已经停止了
return {
"success": False,
"error": "服务不存在",
"success": True,
"service_id": service_id,
"status": "error"
"status": "stopped",
"error": None
}
process_info = self.processes[service_id]
@@ -1271,13 +1283,14 @@ json
os.makedirs(service_dir, exist_ok=True)
return service_dir
def _generate_local_service_wrapper(self, service_dir: str, project_info: Dict[str, Any], service_config: Dict[str, Any]):
def _generate_local_service_wrapper(self, service_dir: str, project_info: Dict[str, Any], service_config: Dict[str, Any], repo_path: str = None):
"""生成本地服务包装器
Args:
service_dir: 服务目录
project_info: 项目信息
service_config: 服务配置
repo_path: 仓库路径(用于复制真实的算法文件)
"""
# 生成服务包装器
service_wrapper_content = self._generate_service_wrapper(project_info, service_config)
@@ -1285,7 +1298,44 @@ json
with open(os.path.join(service_dir, f"service_wrapper{wrapper_extension}"), "w") as f:
f.write(service_wrapper_content)
# 创建模拟的算法文件
# 复制真实的算法文件
if repo_path and project_info["project_type"] == "python":
# 尝试找到并复制主要的算法文件
entry_point = project_info.get("entry_point")
if entry_point:
source_file = os.path.join(repo_path, entry_point)
if os.path.exists(source_file):
# 复制算法文件到服务目录
import shutil
shutil.copy2(source_file, os.path.join(service_dir, "algorithm.py"))
print(f"已复制算法文件: {source_file} -> {os.path.join(service_dir, 'algorithm.py')}")
return
# 如果没有找到入口点尝试复制所有Python文件
if os.path.exists(repo_path):
import shutil
for root, dirs, files in os.walk(repo_path):
for file in files:
if file.endswith(".py") and not file.startswith("_"):
source_file = os.path.join(root, file)
dest_file = os.path.join(service_dir, file)
shutil.copy2(source_file, dest_file)
print(f"已复制Python文件: {source_file} -> {dest_file}")
# 如果有algorithm.py就使用它否则创建一个模拟的
if not os.path.exists(os.path.join(service_dir, "algorithm.py")):
print("未找到algorithm.py创建模拟算法文件")
self._create_mock_algorithm(service_dir)
else:
# 创建模拟的算法文件
self._create_mock_algorithm(service_dir)
def _create_mock_algorithm(self, service_dir: str):
"""创建模拟的算法文件
Args:
service_dir: 服务目录
"""
algorithm_content = """
def predict(data):
return {"result": "Prediction result", "input": data}
@@ -1316,9 +1366,9 @@ def main(data):
# 构建启动命令
if project_info["project_type"] == "python":
cmd = ["python", f"service_wrapper.py"]
cmd = ["python", "service_wrapper.py"]
else:
cmd = ["node", f"service_wrapper.js"]
cmd = ["node", "service_wrapper.js"]
# 设置环境变量
env = os.environ.copy()

2
backend/backend.log Normal file
View File

@@ -0,0 +1,2 @@
INFO: Will watch for changes in these directories: ['/Users/duguoyou/MLFlow/algorithm-showcase/backend']
ERROR: [Errno 48] Address already in use

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
"""检查算法数据"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from app.models.database import SessionLocal
from app.models.models import Algorithm
def check_algorithms():
"""检查算法数据"""
db = SessionLocal()
try:
algorithms = db.query(Algorithm).all()
print(f"数据库中共有 {len(algorithms)} 个算法:\n")
for algo in algorithms:
print(f"算法名称: {algo.name}")
print(f" ID: {algo.id}")
print(f" 类型: {algo.type}")
print(f" 技术分类: {algo.tech_category}")
print(f" 输出类型: {algo.output_type}")
print(f" 描述: {algo.description}")
print(f" 状态: {algo.status}")
print(f" 版本数: {len(algo.versions)}")
print()
except Exception as e:
print(f"检查算法数据失败: {e}")
sys.exit(1)
finally:
db.close()
if __name__ == "__main__":
check_algorithms()

View File

@@ -0,0 +1,57 @@
#!/usr/bin/env python3
"""检查用户角色信息"""
import requests
def check_user_role():
"""检查用户角色"""
base_url = "http://localhost:8001/api/v1"
# 登录
print("步骤1: 登录")
login_data = {
"username": "admin",
"password": "admin123"
}
try:
response = requests.post(f"{base_url}/users/login", json=login_data)
print(f"状态码: {response.status_code}")
if response.status_code != 200:
print(f"登录失败: {response.text}")
return
data = response.json()
access_token = data.get('access_token')
print(f"登录成功!")
# 获取用户信息
print("\n步骤2: 获取用户信息")
headers = {"Authorization": f"Bearer {access_token}"}
user_response = requests.get(f"{base_url}/users/me", headers=headers)
print(f"状态码: {user_response.status_code}")
if user_response.status_code == 200:
user_data = user_response.json()
print(f"\n用户信息:")
print(f" 用户名: {user_data.get('username', 'N/A')}")
print(f" 邮箱: {user_data.get('email', 'N/A')}")
print(f" 角色ID: {user_data.get('role_id', 'N/A')}")
print(f" 角色名称: {user_data.get('role_name', 'N/A')}")
print(f" 角色对象: {user_data.get('role', 'N/A')}")
# 检查是否是管理员
role_name = user_data.get('role_name')
if role_name == 'admin':
print(f"\n✅ 用户是管理员,应该显示后台管理页面")
else:
print(f"\n❌ 用户不是管理员,角色名称是: {role_name}")
else:
print(f"获取用户信息失败: {user_response.text}")
except Exception as e:
print(f"错误: {e}")
if __name__ == "__main__":
check_user_role()

View File

@@ -1,39 +1,54 @@
#!/usr/bin/env python3
"""
检查数据库中的用户信息
"""
"""检查数据库中的用户信息"""
import sys
sys.path.insert(0, '/Users/duguoyou/MLFlow/algorithm-showcase/backend')
from sqlalchemy.orm import Session
from app.models.database import SessionLocal
from app.models.models import User
from app.services.user import UserService
def check_users():
"""检查数据库中的用户信息"""
"""检查用户"""
db = SessionLocal()
try:
# 获取所有用户
users = db.query(User).all()
if users:
print("数据库中的用户信息:")
print("-" * 50)
print(f"数据库中的用户数量: {len(users)}")
for user in users:
print(f"\n用户ID: {user.id}")
print(f"用户名: {user.username}")
print(f"邮箱: {user.email}")
print(f"状态: {user.status}")
print(f"角色ID: {user.role_id}")
print(f"密码哈希: {user.password_hash[:50]}...")
# 测试admin用户认证
print("\n\n测试admin用户认证:")
admin_user = UserService.get_user_by_username(db, 'admin')
if admin_user:
print(f"找到admin用户: {admin_user.id}")
print(f"密码哈希: {admin_user.password_hash[:50]}...")
for user in users:
print(f"用户ID: {user.id}")
print(f"用户名: {user.username}")
print(f"邮箱: {user.email}")
print(f"角色: {user.role}")
print(f"状态: {user.status}")
print(f"创建时间: {user.created_at}")
print("-" * 50)
# 测试密码验证
test_password = 'admin123'
is_valid = UserService.verify_password(test_password, admin_user.password_hash)
print(f"密码 '{test_password}' 验证结果: {is_valid}")
# 尝试认证
authenticated_user = UserService.authenticate_user(db, 'admin', test_password)
if authenticated_user:
print(f"认证成功: {authenticated_user.id}")
else:
print("认证失败")
else:
print("数据库中没有用户信息")
print("未找到admin用户")
finally:
db.close()
if __name__ == "__main__":
check_users()
check_users()

View File

@@ -0,0 +1,151 @@
#!/usr/bin/env python3
"""创建示例算法数据"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from app.models.database import SessionLocal
from app.models.models import Algorithm, AlgorithmVersion
from datetime import datetime
import uuid
def create_sample_algorithms():
"""创建示例算法"""
db = SessionLocal()
try:
# 示例算法数据
algorithms_data = [
{
"name": "目标检测",
"description": "识别图像中的物体位置和类别,支持人脸、车辆、物品等多种目标检测",
"type": "computer_vision",
"tech_category": "computer_vision",
"output_type": "image",
"versions": [
{
"version": "1.0.0",
"url": "http://0.0.0.0:8001",
"is_default": True
}
]
},
{
"name": "视频分析",
"description": "分析视频内容,提取关键帧、识别动作、追踪物体等",
"type": "computer_vision",
"tech_category": "video_processing",
"output_type": "video",
"versions": [
{
"version": "1.0.0",
"url": "http://0.0.0.0:8002",
"is_default": True
}
]
},
{
"name": "图像增强",
"description": "提升图像质量,包括去噪、超分辨率、色彩校正等功能",
"type": "computer_vision",
"tech_category": "computer_vision",
"output_type": "image",
"versions": [
{
"version": "1.0.0",
"url": "http://0.0.0.0:8003",
"is_default": True
}
]
},
{
"name": "文本分类",
"description": "对文本内容进行分类,支持新闻分类、情感分析、垃圾邮件识别等",
"type": "nlp",
"tech_category": "nlp",
"output_type": "text",
"versions": [
{
"version": "1.0.0",
"url": "http://0.0.0.0:8004",
"is_default": True
}
]
},
{
"name": "异常检测",
"description": "检测数据中的异常模式,适用于工业监控、金融风控等场景",
"type": "ml",
"tech_category": "ml",
"output_type": "json",
"versions": [
{
"version": "1.0.0",
"url": "http://0.0.0.0:8005",
"is_default": True
}
]
},
{
"name": "医学影像分析",
"description": "分析医学影像辅助医生进行疾病诊断支持CT、MRI等多种影像格式",
"type": "medical",
"tech_category": "computer_vision",
"output_type": "image",
"versions": [
{
"version": "1.0.0",
"url": "http://0.0.0.0:8006",
"is_default": True
}
]
}
]
# 创建算法
for algo_data in algorithms_data:
# 检查算法是否已存在
existing_algo = db.query(Algorithm).filter(Algorithm.name == algo_data["name"]).first()
if existing_algo:
print(f"✓ 算法 '{algo_data['name']}' 已存在,跳过")
continue
# 创建算法
algorithm = Algorithm(
id=str(uuid.uuid4()),
name=algo_data["name"],
description=algo_data["description"],
type=algo_data["type"],
tech_category=algo_data["tech_category"],
output_type=algo_data["output_type"],
status="active"
)
db.add(algorithm)
db.flush() # 获取算法ID
# 创建版本
for version_data in algo_data["versions"]:
version = AlgorithmVersion(
id=str(uuid.uuid4()),
algorithm_id=algorithm.id,
version=version_data["version"],
url=version_data["url"],
is_default=version_data["is_default"]
)
db.add(version)
print(f"✓ 已创建算法: {algo_data['name']}")
db.commit()
print("\n示例算法创建完成!")
except Exception as e:
db.rollback()
print(f"创建示例算法失败: {e}")
sys.exit(1)
finally:
db.close()
if __name__ == "__main__":
create_sample_algorithms()

View File

@@ -0,0 +1,45 @@
#!/usr/bin/env python3
"""数据库迁移脚本:添加技术分类和输出类型字段"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from sqlalchemy import text
from app.models.database import engine
def migrate():
"""执行数据库迁移"""
try:
with engine.connect() as conn:
# 检查字段是否已存在PostgreSQL语法
result = conn.execute(text("""
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'algorithms'
"""))
columns = [row[0] for row in result.fetchall()]
# 添加 tech_category 字段
if 'tech_category' not in columns:
conn.execute(text("ALTER TABLE algorithms ADD COLUMN tech_category VARCHAR(50) DEFAULT 'computer_vision'"))
print("✓ 已添加 tech_category 字段")
else:
print("✓ tech_category 字段已存在")
# 添加 output_type 字段
if 'output_type' not in columns:
conn.execute(text("ALTER TABLE algorithms ADD COLUMN output_type VARCHAR(50) DEFAULT 'image'"))
print("✓ 已添加 output_type 字段")
else:
print("✓ output_type 字段已存在")
conn.commit()
print("\n数据库迁移完成!")
except Exception as e:
print(f"数据库迁移失败: {e}")
sys.exit(1)
if __name__ == "__main__":
migrate()

View File

@@ -0,0 +1,45 @@
#!/usr/bin/env python3
"""数据库迁移脚本为algorithm_services表添加技术分类和输出类型字段"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from sqlalchemy import text
from app.models.database import engine
def migrate():
"""执行数据库迁移"""
try:
with engine.connect() as conn:
# 检查字段是否已存在PostgreSQL语法
result = conn.execute(text("""
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'algorithm_services'
"""))
columns = [row[0] for row in result.fetchall()]
# 添加 tech_category 字段
if 'tech_category' not in columns:
conn.execute(text("ALTER TABLE algorithm_services ADD COLUMN tech_category VARCHAR(50) DEFAULT 'computer_vision'"))
print("✓ 已添加 tech_category 字段到 algorithm_services 表")
else:
print("✓ tech_category 字段已存在于 algorithm_services 表")
# 添加 output_type 字段
if 'output_type' not in columns:
conn.execute(text("ALTER TABLE algorithm_services ADD COLUMN output_type VARCHAR(50) DEFAULT 'image'"))
print("✓ 已添加 output_type 字段到 algorithm_services 表")
else:
print("✓ output_type 字段已存在于 algorithm_services 表")
conn.commit()
print("\n数据库迁移完成!")
except Exception as e:
print(f"数据库迁移失败: {e}")
sys.exit(1)
if __name__ == "__main__":
migrate()

48
backend/test_all_apis.py Normal file
View File

@@ -0,0 +1,48 @@
#!/usr/bin/env python3
"""测试所有API端点"""
import requests
def test_apis():
"""测试API端点"""
base_url = "http://localhost:8001/api/v1"
# 测试算法列表(不需要认证)
print("1. 测试算法列表(不需要认证):")
try:
response = requests.get(f"{base_url}/algorithms/")
print(f" 状态码: {response.status_code}")
if response.status_code == 200:
data = response.json()
print(f" 成功获取 {len(data.get('algorithms', []))} 个算法")
else:
print(f" 失败: {response.text}")
except Exception as e:
print(f" 错误: {e}")
# 测试用户信息(需要认证)
print("\n2. 测试用户信息(需要认证):")
try:
response = requests.get(f"{base_url}/users/me")
print(f" 状态码: {response.status_code}")
if response.status_code == 401:
print(f" 需要认证(正常)")
else:
print(f" 响应: {response.text}")
except Exception as e:
print(f" 错误: {e}")
# 测试服务列表(需要认证)
print("\n3. 测试服务列表(需要认证):")
try:
response = requests.get(f"{base_url}/services")
print(f" 状态码: {response.status_code}")
if response.status_code == 401:
print(f" 需要认证(正常)")
else:
print(f" 响应: {response.text[:200]}")
except Exception as e:
print(f" 错误: {e}")
if __name__ == "__main__":
test_apis()

53
backend/test_api.py Normal file
View File

@@ -0,0 +1,53 @@
#!/usr/bin/env python3
"""测试前端API调用"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import requests
def test_api():
"""测试API"""
try:
# 调用算法列表API
response = requests.get('http://localhost:8001/api/v1/algorithms/')
if response.status_code == 200:
data = response.json()
algorithms = data.get('algorithms', [])
print(f"成功获取 {len(algorithms)} 个算法\n")
# 检查每个算法的字段
for algo in algorithms:
print(f"算法: {algo['name']}")
print(f" 技术分类: {algo.get('tech_category', 'N/A')}")
print(f" 输出类型: {algo.get('output_type', 'N/A')}")
print()
# 测试筛选
print("测试筛选功能:")
# 按技术分类筛选
cv_algorithms = [a for a in algorithms if a.get('tech_category') == 'computer_vision']
print(f" 计算机视觉算法: {len(cv_algorithms)}")
# 按输出类型筛选
image_algorithms = [a for a in algorithms if a.get('output_type') == 'image']
print(f" 图片输出算法: {len(image_algorithms)}")
# 按名称搜索
search_results = [a for a in algorithms if '视频' in a.get('name', '')]
print(f" 包含'视频'的算法: {len(search_results)}")
else:
print(f"API调用失败: {response.status_code}")
print(response.text)
except Exception as e:
print(f"测试失败: {e}")
sys.exit(1)
if __name__ == "__main__":
test_api()

View File

@@ -0,0 +1,32 @@
#!/usr/bin/env python3
"""测试前端代理配置"""
import requests
def test_frontend_proxy():
"""测试前端代理"""
try:
# 测试前端代理
response = requests.get('http://localhost:3000/api/algorithms')
print(f"状态码: {response.status_code}")
if response.status_code == 200:
data = response.json()
print(f"成功获取 {len(data.get('algorithms', []))} 个算法")
# 检查第一个算法的字段
if data.get('algorithms'):
first_algo = data['algorithms'][0]
print(f"\n第一个算法:")
print(f" 名称: {first_algo.get('name')}")
print(f" 技术分类: {first_algo.get('tech_category')}")
print(f" 输出类型: {first_algo.get('output_type')}")
else:
print(f"请求失败: {response.text}")
except Exception as e:
print(f"测试失败: {e}")
if __name__ == "__main__":
test_frontend_proxy()

View File

@@ -0,0 +1,48 @@
#!/usr/bin/env python3
"""测试完整的登录流程"""
import requests
def test_full_login_flow():
"""测试完整的登录流程"""
base_url = "http://localhost:8001/api/v1"
# 步骤1: 登录
print("步骤1: 登录")
login_data = {
"username": "admin",
"password": "admin123"
}
try:
response = requests.post(f"{base_url}/users/login", json=login_data)
print(f"状态码: {response.status_code}")
if response.status_code != 200:
print(f"登录失败: {response.text}")
return
data = response.json()
access_token = data.get('access_token')
print(f"登录成功!")
print(f"Token: {access_token[:50]}...")
# 步骤2: 使用token获取用户信息
print("\n步骤2: 获取用户信息")
headers = {"Authorization": f"Bearer {access_token}"}
user_response = requests.get(f"{base_url}/users/me", headers=headers)
print(f"状态码: {user_response.status_code}")
if user_response.status_code == 200:
user_data = user_response.json()
print(f"用户名: {user_data.get('username', 'N/A')}")
print(f"邮箱: {user_data.get('email', 'N/A')}")
print(f"角色: {user_data.get('role_name', 'N/A')}")
else:
print(f"获取用户信息失败: {user_response.text}")
except Exception as e:
print(f"错误: {e}")
if __name__ == "__main__":
test_full_login_flow()

45
backend/test_login.py Normal file
View File

@@ -0,0 +1,45 @@
#!/usr/bin/env python3
"""测试登录功能"""
import requests
def test_login():
"""测试登录"""
base_url = "http://localhost:8001/api/v1"
# 测试登录
print("测试登录功能:")
login_data = {
"username": "admin",
"password": "admin123"
}
try:
response = requests.post(f"{base_url}/users/login", json=login_data)
print(f"状态码: {response.status_code}")
if response.status_code == 200:
data = response.json()
print(f"登录成功!")
print(f"访问令牌: {data.get('access_token', 'N/A')[:50]}...")
print(f"令牌类型: {data.get('token_type', 'N/A')}")
# 测试使用令牌访问受保护的API
if data.get('access_token'):
headers = {"Authorization": f"Bearer {data['access_token']}"}
user_response = requests.get(f"{base_url}/users/me", headers=headers)
print(f"\n测试用户信息API:")
print(f"状态码: {user_response.status_code}")
if user_response.status_code == 200:
user_data = user_response.json()
print(f"用户名: {user_data.get('username', 'N/A')}")
print(f"邮箱: {user_data.get('email', 'N/A')}")
else:
print(f"失败: {user_response.text}")
else:
print(f"登录失败: {response.text}")
except Exception as e:
print(f"错误: {e}")
if __name__ == "__main__":
test_login()

53
backend/test_login_api.py Normal file
View File

@@ -0,0 +1,53 @@
#!/usr/bin/env python3
"""直接测试登录API"""
import requests
def test_login_api():
"""测试登录API"""
base_url = "http://localhost:8001/api/v1"
# 测试1: 使用JSON格式
print("测试1: 使用JSON格式")
login_data = {
"username": "admin",
"password": "admin123"
}
try:
response = requests.post(f"{base_url}/users/login", json=login_data)
print(f"状态码: {response.status_code}")
print(f"响应头: {dict(response.headers)}")
print(f"响应内容: {response.text[:500]}")
if response.status_code == 200:
data = response.json()
print(f"✅ 登录成功!")
print(f"Token: {data.get('access_token', 'N/A')[:50]}...")
else:
print(f"❌ 登录失败")
except Exception as e:
print(f"错误: {e}")
# 测试2: 使用form-data格式
print("\n\n测试2: 使用form-data格式")
form_data = {
"username": "admin",
"password": "admin123"
}
try:
response = requests.post(f"{base_url}/users/login", data=form_data)
print(f"状态码: {response.status_code}")
print(f"响应内容: {response.text[:500]}")
if response.status_code == 200:
data = response.json()
print(f"✅ 登录成功!")
else:
print(f"❌ 登录失败")
except Exception as e:
print(f"错误: {e}")
if __name__ == "__main__":
test_login_api()

232
backend/test_system.py Normal file
View File

@@ -0,0 +1,232 @@
import requests
import json
import time
from typing import Dict, Any, List
class SystemTester:
def __init__(self, base_url: str = "http://localhost:8001/api/v1"):
self.base_url = base_url
self.session = requests.Session()
self.token = None
self.user_id = None
def login(self, username: str = "admin", password: str = "admin123") -> bool:
"""登录系统"""
try:
response = self.session.post(
f"{self.base_url}/users/login",
json={"username": username, "password": password}
)
if response.status_code == 200:
data = response.json()
self.token = data.get("access_token")
self.user_id = data.get("user_id")
self.session.headers.update({"Authorization": f"Bearer {self.token}"})
print(f"✓ 登录成功: {username}")
return True
else:
print(f"✗ 登录失败: {response.status_code} - {response.text}")
return False
except Exception as e:
print(f"✗ 登录异常: {str(e)}")
return False
def test_config_endpoints(self) -> bool:
"""测试配置管理API"""
print("\n=== 测试配置管理API ===")
success = True
try:
# 测试获取所有配置
response = self.session.get(f"{self.base_url}/config/")
if response.status_code == 200:
print("✓ 获取所有配置成功")
configs = response.json().get("configs", [])
print(f" 当前配置数量: {len(configs)}")
else:
print(f"✗ 获取所有配置失败: {response.status_code}")
success = False
# 测试添加配置
test_config = {
"value": "test_value_123",
"type": "system",
"service_id": None,
"description": "测试配置"
}
response = self.session.post(f"{self.base_url}/config/test_config_key", json=test_config)
if response.status_code == 200:
print("✓ 添加配置成功")
else:
print(f"✗ 添加配置失败: {response.status_code} - {response.text}")
success = False
# 测试获取单个配置
response = self.session.get(f"{self.base_url}/config/test_config_key")
if response.status_code == 200:
print("✓ 获取单个配置成功")
config_data = response.json()
print(f" 配置值: {config_data.get('value')}")
else:
print(f"✗ 获取单个配置失败: {response.status_code}")
success = False
# 测试删除配置
response = self.session.delete(f"{self.base_url}/config/test_config_key")
if response.status_code == 200:
print("✓ 删除配置成功")
else:
print(f"✗ 删除配置失败: {response.status_code}")
success = False
return success
except Exception as e:
print(f"✗ 配置管理API测试异常: {str(e)}")
return False
def test_comparison_endpoints(self) -> bool:
"""测试算法比较API"""
print("\n=== 测试算法比较API ===")
success = True
try:
# 测试算法比较(使用模拟数据)
test_data = {
"input_data": {"text": "这是一段测试文本"},
"algorithm_configs": [
{
"algorithm_id": "test_algo_1",
"algorithm_name": "测试算法1",
"version": "1.0.0",
"config": "{}"
},
{
"algorithm_id": "test_algo_2",
"algorithm_name": "测试算法2",
"version": "1.0.0",
"config": "{}"
}
]
}
response = self.session.post(f"{self.base_url}/comparison/compare-algorithms", json=test_data)
if response.status_code == 200:
print("✓ 算法比较API调用成功")
result = response.json()
print(f" 比较状态: {result.get('success')}")
if result.get('results'):
print(f" 结果数量: {len(result.get('results'))}")
else:
print(f"✗ 算法比较失败: {response.status_code} - {response.text}")
success = False
return success
except Exception as e:
print(f"✗ 算法比较API测试异常: {str(e)}")
return False
def test_existing_endpoints(self) -> bool:
"""测试现有API端点"""
print("\n=== 测试现有API端点 ===")
success = True
try:
# 测试健康检查
response = self.session.get(f"{self.base_url.replace('/api/v1', '')}/health")
if response.status_code == 200:
print("✓ 健康检查通过")
else:
print(f"✗ 健康检查失败: {response.status_code}")
success = False
# 测试获取当前用户
response = self.session.get(f"{self.base_url}/users/me")
if response.status_code == 200:
print("✓ 获取当前用户成功")
user_data = response.json()
print(f" 用户名: {user_data.get('username')}")
else:
print(f"✗ 获取当前用户失败: {response.status_code}")
success = False
# 测试获取算法列表
response = self.session.get(f"{self.base_url}/algorithms/")
if response.status_code == 200:
print("✓ 获取算法列表成功")
algorithms = response.json()
print(f" 算法数量: {len(algorithms) if isinstance(algorithms, list) else 0}")
else:
print(f"✗ 获取算法列表失败: {response.status_code}")
success = False
# 测试获取服务列表
response = self.session.get(f"{self.base_url}/services")
if response.status_code == 200:
print("✓ 获取服务列表成功")
services = response.json()
print(f" 服务数量: {len(services) if isinstance(services, list) else 0}")
else:
print(f"✗ 获取服务列表失败: {response.status_code}")
success = False
return success
except Exception as e:
print(f"✗ 现有API端点测试异常: {str(e)}")
return False
def run_all_tests(self) -> Dict[str, bool]:
"""运行所有测试"""
print("=" * 50)
print("开始系统自动化测试")
print("=" * 50)
results = {}
# 登录
if not self.login():
print("\n✗ 登录失败,无法继续测试")
return {"login": False}
results["login"] = True
# 测试现有端点
results["existing_endpoints"] = self.test_existing_endpoints()
# 测试配置管理API
results["config_endpoints"] = self.test_config_endpoints()
# 测试算法比较API
results["comparison_endpoints"] = self.test_comparison_endpoints()
# 输出测试结果
print("\n" + "=" * 50)
print("测试结果汇总")
print("=" * 50)
for test_name, result in results.items():
status = "✓ 通过" if result else "✗ 失败"
print(f"{test_name}: {status}")
total_tests = len(results)
passed_tests = sum(1 for result in results.values() if result)
print(f"\n总计: {passed_tests}/{total_tests} 测试通过")
if passed_tests == total_tests:
print("🎉 所有测试通过!")
else:
print("⚠️ 部分测试失败,请检查日志")
return results
def main():
"""主函数"""
tester = SystemTester()
results = tester.run_all_tests()
# 返回退出码
exit_code = 0 if all(results.values()) else 1
return exit_code
if __name__ == "__main__":
exit(main())