good version for 算法注册
This commit is contained in:
0
backend/algorithm_showcase.db
Normal file
0
backend/algorithm_showcase.db
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,5 +1,6 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
@@ -46,11 +47,56 @@ class Settings(BaseSettings):
|
||||
GITEA_DEFAULT_OWNER: str = ""
|
||||
GITEA_REPO_PREFIX: str = "AI"
|
||||
|
||||
# 服务管理配置
|
||||
SERVICE_MANAGEMENT: Dict[str, Any] = {
|
||||
"mode": "supervisor", # 服务管理模式:local, docker, supervisor
|
||||
"service_root_dir": "/opt/ai-services",
|
||||
"supervisor_config_dir": "/etc/supervisor/conf.d",
|
||||
}
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = True
|
||||
extra = "allow" # 允许额外的环境变量
|
||||
|
||||
def get_config(self, config_key: str, default: Any = None) -> Any:
|
||||
"""获取配置,优先级:环境变量 > 数据库 > 文件默认值
|
||||
|
||||
Args:
|
||||
config_key: 配置键
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
配置值
|
||||
"""
|
||||
# 1. 先从环境变量获取
|
||||
env_key = config_key.upper().replace('.', '_')
|
||||
env_value = getattr(self, env_key, None)
|
||||
if env_value is not None:
|
||||
return env_value
|
||||
|
||||
# 2. 从数据库获取
|
||||
try:
|
||||
from app.models.database import SessionLocal
|
||||
from app.models.models import ServiceConfig
|
||||
|
||||
db: Session = SessionLocal()
|
||||
try:
|
||||
config = db.query(ServiceConfig).filter_by(
|
||||
config_key=config_key,
|
||||
status="active"
|
||||
).first()
|
||||
if config:
|
||||
return config.config_value
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
# 数据库连接失败时,返回默认值
|
||||
print(f"Failed to load config from database: {str(e)}")
|
||||
|
||||
# 3. 返回默认值
|
||||
return default
|
||||
|
||||
|
||||
# 创建全局配置实例
|
||||
settings = Settings()
|
||||
|
||||
@@ -72,9 +72,24 @@ class APIGateway:
|
||||
if not version_info:
|
||||
raise HTTPException(status_code=404, detail="Algorithm version not found")
|
||||
|
||||
# 在实际实现中,这里会根据version_info.url将请求转发到对应的算法服务
|
||||
# 现在我们模拟调用过程
|
||||
algorithm_url = version_info.url if hasattr(version_info, 'url') else f"http://localhost:8001/algorithms/{algorithm_id}/execute"
|
||||
# 尝试从算法版本获取URL,如果没有则尝试从服务表获取
|
||||
algorithm_url = None
|
||||
|
||||
# 首先检查版本信息中是否有URL
|
||||
if hasattr(version_info, 'url') and version_info.url:
|
||||
algorithm_url = version_info.url
|
||||
else:
|
||||
# 如果版本信息中没有URL,尝试从服务表获取
|
||||
from app.models.models import AlgorithmService
|
||||
service = db.query(AlgorithmService).filter(
|
||||
AlgorithmService.algorithm_name == algorithm_id
|
||||
).first()
|
||||
|
||||
if service and service.api_url:
|
||||
algorithm_url = service.api_url
|
||||
else:
|
||||
# 如果都没有,使用默认的本地端点
|
||||
algorithm_url = f"http://localhost:8001/algorithms/{algorithm_id}/execute"
|
||||
|
||||
# 使用httpx调用算法服务
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
|
||||
@@ -4,7 +4,8 @@ from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.models.database import engine, Base
|
||||
from app.routes import user, algorithm, openai, gateway, services, data_management, monitoring, permissions, history, deployment, gitea, repositories
|
||||
from app.models import models, api # 导入所有模型以确保表被创建
|
||||
from app.routes import user, algorithm, openai, gateway, services, data_management, monitoring, permissions, history, deployment, gitea, repositories, config, comparison, api_management
|
||||
|
||||
# 创建数据库表
|
||||
Base.metadata.create_all(bind=engine)
|
||||
@@ -48,6 +49,9 @@ app.include_router(history.router, prefix=settings.API_V1_STR)
|
||||
app.include_router(deployment.router)
|
||||
app.include_router(gitea.router, prefix=settings.API_V1_STR)
|
||||
app.include_router(repositories.router, prefix=settings.API_V1_STR)
|
||||
app.include_router(config.router, prefix=settings.API_V1_STR)
|
||||
app.include_router(comparison.router, prefix=settings.API_V1_STR)
|
||||
app.include_router(api_management.router, prefix=settings.API_V1_STR)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
|
||||
BIN
backend/app/models/__pycache__/api.cpython-312.pyc
Normal file
BIN
backend/app/models/__pycache__/api.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
74
backend/app/models/api.py
Normal file
74
backend/app/models/api.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""API封装模型"""
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, Text, Boolean, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from app.models.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class ApiEndpoint(Base):
|
||||
"""API端点模型"""
|
||||
__tablename__ = "api_endpoints"
|
||||
|
||||
id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
|
||||
name = Column(String, nullable=False, index=True) # API名称
|
||||
description = Column(Text, default="") # API描述
|
||||
path = Column(String, nullable=False, unique=True, index=True) # API路径,如 /api/image-classification
|
||||
method = Column(String, default="POST") # HTTP方法:GET, POST, PUT, DELETE
|
||||
algorithm_id = Column(String, ForeignKey("algorithms.id"), nullable=False, index=True) # 关联的算法ID
|
||||
version_id = Column(String, ForeignKey("algorithm_versions.id"), nullable=False, index=True) # 关联的算法版本ID
|
||||
service_id = Column(String, ForeignKey("algorithm_services.service_id"), nullable=True, index=True) # 关联的服务ID
|
||||
|
||||
# API配置
|
||||
config = Column(JSON, nullable=False, default={}) # API配置(超时、重试等)
|
||||
request_schema = Column(JSON, nullable=True) # 请求参数schema
|
||||
response_schema = Column(JSON, nullable=True) # 响应参数schema
|
||||
|
||||
# 权限配置
|
||||
requires_auth = Column(Boolean, default=True) # 是否需要认证
|
||||
allowed_roles = Column(JSON, default=[]) # 允许的角色列表
|
||||
rate_limit = Column(JSON, nullable=True) # 限流配置,如 {"max_requests": 100, "window": 60}
|
||||
|
||||
# 状态
|
||||
status = Column(String, default="active", index=True) # 状态:active, inactive, deprecated
|
||||
is_public = Column(Boolean, default=False) # 是否公开
|
||||
|
||||
# 统计信息
|
||||
call_count = Column(String, default="0") # 调用次数
|
||||
success_count = Column(String, default="0") # 成功次数
|
||||
error_count = Column(String, default="0") # 错误次数
|
||||
avg_response_time = Column(String, default="0.0") # 平均响应时间
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime(timezone=True), default=datetime.utcnow)
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=datetime.utcnow)
|
||||
last_called_at = Column(DateTime(timezone=True), nullable=True) # 最后调用时间
|
||||
|
||||
|
||||
class ApiCallLog(Base):
|
||||
"""API调用日志模型"""
|
||||
__tablename__ = "api_call_logs"
|
||||
|
||||
id = Column(String, primary_key=True, index=True, default=lambda: str(uuid.uuid4()))
|
||||
api_endpoint_id = Column(String, ForeignKey("api_endpoints.id"), nullable=False, index=True) # API端点ID
|
||||
user_id = Column(String, ForeignKey("users.id"), nullable=True, index=True) # 调用用户ID
|
||||
|
||||
# 请求信息
|
||||
request_method = Column(String, nullable=False) # 请求方法
|
||||
request_path = Column(String, nullable=False) # 请求路径
|
||||
request_headers = Column(JSON, nullable=True) # 请求头
|
||||
request_body = Column(JSON, nullable=True) # 请求体
|
||||
|
||||
# 响应信息
|
||||
response_status = Column(String, nullable=False) # 响应状态码
|
||||
response_body = Column(JSON, nullable=True) # 响应体
|
||||
response_time = Column(String, nullable=False) # 响应时间(秒)
|
||||
|
||||
# 错误信息
|
||||
error_message = Column(Text, nullable=True) # 错误信息
|
||||
error_type = Column(String, nullable=True) # 错误类型
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime(timezone=True), default=datetime.utcnow, index=True) # 调用时间
|
||||
@@ -13,6 +13,8 @@ class Algorithm(Base):
|
||||
name = Column(String, nullable=False, index=True)
|
||||
description = Column(Text, nullable=False)
|
||||
type = Column(String, nullable=False, index=True) # computer_vision, nlp, ml, edge_computing, medical, autonomous_driving等
|
||||
tech_category = Column(String, nullable=False, default="computer_vision") # 技术分类:计算机视觉、视频处理、自然语言处理等
|
||||
output_type = Column(String, nullable=False, default="image") # 输出类型:图片、视频、文本、JSON等
|
||||
status = Column(String, default="active", index=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
@@ -159,6 +161,8 @@ class AlgorithmService(Base):
|
||||
name = Column(String, nullable=False, index=True) # 服务名称
|
||||
algorithm_name = Column(String, nullable=False) # 算法名称
|
||||
version = Column(String, nullable=False) # 版本
|
||||
tech_category = Column(String, nullable=False, default="computer_vision") # 技术分类:计算机视觉、视频处理、自然语言处理等
|
||||
output_type = Column(String, nullable=False, default="image") # 输出类型:图片、视频、文本、JSON等
|
||||
host = Column(String, nullable=False) # 主机地址
|
||||
port = Column(Integer, nullable=False) # 端口
|
||||
api_url = Column(String, nullable=False) # API地址
|
||||
@@ -172,3 +176,18 @@ class AlgorithmService(Base):
|
||||
|
||||
# 添加Algorithm模型的repository关系
|
||||
Algorithm.repository = relationship("AlgorithmRepository", back_populates="algorithm", uselist=False)
|
||||
|
||||
|
||||
class ServiceConfig(Base):
|
||||
"""服务配置模型"""
|
||||
__tablename__ = "service_configs"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
config_key = Column(String, nullable=False, unique=True, index=True) # 配置键
|
||||
config_value = Column(JSON, nullable=False) # 配置值(JSON格式)
|
||||
config_type = Column(String, nullable=False) # 配置类型:system, service, user
|
||||
service_id = Column(String, nullable=True, index=True) # 服务ID(可为空,系统配置)
|
||||
description = Column(Text, default="") # 配置描述
|
||||
status = Column(String, default="active") # 状态
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi import APIRouter
|
||||
from app.routes import user, algorithm, history, gateway, monitoring, openai, deployment
|
||||
from app.routes import user, algorithm, history, gateway, monitoring, openai, deployment, config, comparison
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -11,3 +11,5 @@ api_router.include_router(gateway.router, prefix="/gateway", tags=["gateway"])
|
||||
api_router.include_router(monitoring.router, prefix="/monitoring", tags=["monitoring"])
|
||||
api_router.include_router(openai.router, prefix="/openai", tags=["openai"])
|
||||
api_router.include_router(deployment.router, tags=["deployment"])
|
||||
api_router.include_router(config.router, tags=["config"])
|
||||
api_router.include_router(comparison.router, tags=["comparison"])
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
backend/app/routes/__pycache__/api_management.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/api_management.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/comparison.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/comparison.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/config.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/config.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -33,6 +33,18 @@ async def create_algorithm(
|
||||
|
||||
|
||||
@router.get("", response_model=AlgorithmListResponse)
|
||||
async def get_algorithms_no_slash(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
type: Optional[str] = None,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取算法列表(不带末尾斜杠)"""
|
||||
algorithms = AlgorithmService.get_algorithms(db, skip=skip, limit=limit, algorithm_type=type)
|
||||
return {"algorithms": algorithms, "total": len(algorithms)}
|
||||
|
||||
|
||||
@router.get("/", response_model=AlgorithmListResponse)
|
||||
async def get_algorithms(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
|
||||
510
backend/app/routes/api_management.py
Normal file
510
backend/app/routes/api_management.py
Normal file
@@ -0,0 +1,510 @@
|
||||
"""API管理路由,处理API端点的封装和管理"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.database import get_db
|
||||
from app.models.models import Algorithm, AlgorithmVersion, AlgorithmService, User
|
||||
from app.models.api import ApiEndpoint, ApiCallLog
|
||||
from app.schemas.user import UserResponse
|
||||
from app.routes.user import get_current_active_user
|
||||
|
||||
router = APIRouter(prefix="/api-management", tags=["api-management"])
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApiEndpointCreate(BaseModel):
|
||||
"""创建API端点请求模型"""
|
||||
name: str
|
||||
description: str = ""
|
||||
path: str
|
||||
method: str = "POST"
|
||||
algorithm_id: str
|
||||
version_id: str
|
||||
service_id: Optional[str] = None
|
||||
requires_auth: bool = True
|
||||
allowed_roles: List[str] = []
|
||||
rate_limit: Optional[Dict[str, Any]] = None
|
||||
is_public: bool = False
|
||||
config: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class ApiEndpointUpdate(BaseModel):
|
||||
"""更新API端点请求模型"""
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
path: Optional[str] = None
|
||||
method: Optional[str] = None
|
||||
requires_auth: Optional[bool] = None
|
||||
allowed_roles: Optional[List[str]] = None
|
||||
rate_limit: Optional[Dict[str, Any]] = None
|
||||
is_public: Optional[bool] = None
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
status: Optional[str] = None
|
||||
|
||||
|
||||
class ApiEndpointResponse(BaseModel):
|
||||
"""API端点响应模型"""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
path: str
|
||||
method: str
|
||||
algorithm_id: str
|
||||
algorithm_name: str
|
||||
version_id: str
|
||||
version: str
|
||||
service_id: Optional[str]
|
||||
status: str
|
||||
is_public: bool
|
||||
call_count: str
|
||||
success_count: str
|
||||
error_count: str
|
||||
avg_response_time: str
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
last_called_at: Optional[datetime]
|
||||
|
||||
|
||||
class ApiEndpointListResponse(BaseModel):
|
||||
"""API端点列表响应模型"""
|
||||
endpoints: List[ApiEndpointResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class ApiStatsResponse(BaseModel):
|
||||
"""API统计响应模型"""
|
||||
total_endpoints: int
|
||||
active_endpoints: int
|
||||
total_calls: str
|
||||
total_success: str
|
||||
total_errors: str
|
||||
avg_response_time: str
|
||||
|
||||
|
||||
@router.get("/endpoints", response_model=ApiEndpointListResponse)
|
||||
async def get_api_endpoints(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
algorithm_id: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取API端点列表"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = db.query(ApiEndpoint)
|
||||
|
||||
# 筛选条件
|
||||
if algorithm_id:
|
||||
query = query.filter(ApiEndpoint.algorithm_id == algorithm_id)
|
||||
if status:
|
||||
query = query.filter(ApiEndpoint.status == status)
|
||||
|
||||
# 分页
|
||||
endpoints = query.offset(skip).limit(limit).all()
|
||||
total = query.count()
|
||||
|
||||
# 构建响应
|
||||
endpoint_responses = []
|
||||
for endpoint in endpoints:
|
||||
# 获取关联的算法和版本信息
|
||||
algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first()
|
||||
version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first()
|
||||
|
||||
endpoint_responses.append({
|
||||
"id": endpoint.id,
|
||||
"name": endpoint.name,
|
||||
"description": endpoint.description,
|
||||
"path": endpoint.path,
|
||||
"method": endpoint.method,
|
||||
"algorithm_id": endpoint.algorithm_id,
|
||||
"algorithm_name": algorithm.name if algorithm else "",
|
||||
"version_id": endpoint.version_id,
|
||||
"version": version.version if version else "",
|
||||
"service_id": endpoint.service_id,
|
||||
"status": endpoint.status,
|
||||
"is_public": endpoint.is_public,
|
||||
"call_count": endpoint.call_count,
|
||||
"success_count": endpoint.success_count,
|
||||
"error_count": endpoint.error_count,
|
||||
"avg_response_time": endpoint.avg_response_time,
|
||||
"created_at": endpoint.created_at,
|
||||
"updated_at": endpoint.updated_at,
|
||||
"last_called_at": endpoint.last_called_at
|
||||
})
|
||||
|
||||
return {
|
||||
"endpoints": endpoint_responses,
|
||||
"total": total
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取API端点列表失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取API端点列表失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/endpoints/{endpoint_id}", response_model=ApiEndpointResponse)
|
||||
async def get_api_endpoint(
|
||||
endpoint_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取API端点详情"""
|
||||
try:
|
||||
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
|
||||
|
||||
if not endpoint:
|
||||
raise HTTPException(status_code=404, detail="API端点不存在")
|
||||
|
||||
# 获取关联的算法和版本信息
|
||||
algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first()
|
||||
version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first()
|
||||
|
||||
return {
|
||||
"id": endpoint.id,
|
||||
"name": endpoint.name,
|
||||
"description": endpoint.description,
|
||||
"path": endpoint.path,
|
||||
"method": endpoint.method,
|
||||
"algorithm_id": endpoint.algorithm_id,
|
||||
"algorithm_name": algorithm.name if algorithm else "",
|
||||
"version_id": endpoint.version_id,
|
||||
"version": version.version if version else "",
|
||||
"service_id": endpoint.service_id,
|
||||
"status": endpoint.status,
|
||||
"is_public": endpoint.is_public,
|
||||
"call_count": endpoint.call_count,
|
||||
"success_count": endpoint.success_count,
|
||||
"error_count": endpoint.error_count,
|
||||
"avg_response_time": endpoint.avg_response_time,
|
||||
"created_at": endpoint.created_at,
|
||||
"updated_at": endpoint.updated_at,
|
||||
"last_called_at": endpoint.last_called_at
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取API端点详情失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取API端点详情失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/endpoints", response_model=ApiEndpointResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_api_endpoint(
|
||||
request: ApiEndpointCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""创建API端点"""
|
||||
try:
|
||||
# 检查用户权限
|
||||
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 验证算法和版本是否存在
|
||||
algorithm = db.query(Algorithm).filter(Algorithm.id == request.algorithm_id).first()
|
||||
if not algorithm:
|
||||
raise HTTPException(status_code=404, detail="算法不存在")
|
||||
|
||||
version = db.query(AlgorithmVersion).filter(
|
||||
AlgorithmVersion.id == request.version_id
|
||||
).first()
|
||||
if not version or version.algorithm_id != request.algorithm_id:
|
||||
raise HTTPException(status_code=404, detail="算法版本不存在")
|
||||
|
||||
# 如果指定了服务ID,验证服务是否存在
|
||||
if request.service_id:
|
||||
service = db.query(AlgorithmService).filter(
|
||||
AlgorithmService.service_id == request.service_id
|
||||
).first()
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="服务不存在")
|
||||
|
||||
# 检查API路径是否已存在
|
||||
existing_endpoint = db.query(ApiEndpoint).filter(
|
||||
ApiEndpoint.path == request.path
|
||||
).first()
|
||||
if existing_endpoint:
|
||||
raise HTTPException(status_code=400, detail="API路径已存在")
|
||||
|
||||
# 创建API端点
|
||||
new_endpoint = ApiEndpoint(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
path=request.path,
|
||||
method=request.method,
|
||||
algorithm_id=request.algorithm_id,
|
||||
version_id=request.version_id,
|
||||
service_id=request.service_id,
|
||||
requires_auth=request.requires_auth,
|
||||
allowed_roles=request.allowed_roles,
|
||||
rate_limit=request.rate_limit,
|
||||
is_public=request.is_public,
|
||||
config=request.config,
|
||||
status="active",
|
||||
call_count="0",
|
||||
success_count="0",
|
||||
error_count="0",
|
||||
avg_response_time="0.0"
|
||||
)
|
||||
|
||||
db.add(new_endpoint)
|
||||
db.commit()
|
||||
db.refresh(new_endpoint)
|
||||
|
||||
# 返回创建的API端点
|
||||
return {
|
||||
"id": new_endpoint.id,
|
||||
"name": new_endpoint.name,
|
||||
"description": new_endpoint.description,
|
||||
"path": new_endpoint.path,
|
||||
"method": new_endpoint.method,
|
||||
"algorithm_id": new_endpoint.algorithm_id,
|
||||
"algorithm_name": algorithm.name,
|
||||
"version_id": new_endpoint.version_id,
|
||||
"version": version.version,
|
||||
"service_id": new_endpoint.service_id,
|
||||
"status": new_endpoint.status,
|
||||
"is_public": new_endpoint.is_public,
|
||||
"call_count": new_endpoint.call_count,
|
||||
"success_count": new_endpoint.success_count,
|
||||
"error_count": new_endpoint.error_count,
|
||||
"avg_response_time": new_endpoint.avg_response_time,
|
||||
"created_at": new_endpoint.created_at,
|
||||
"updated_at": new_endpoint.updated_at,
|
||||
"last_called_at": new_endpoint.last_called_at
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"创建API端点失败: {str(e)}")
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"创建API端点失败: {str(e)}")
|
||||
|
||||
|
||||
@router.put("/endpoints/{endpoint_id}", response_model=ApiEndpointResponse)
|
||||
async def update_api_endpoint(
|
||||
endpoint_id: str,
|
||||
request: ApiEndpointUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""更新API端点"""
|
||||
try:
|
||||
# 检查用户权限
|
||||
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 查询API端点
|
||||
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
|
||||
if not endpoint:
|
||||
raise HTTPException(status_code=404, detail="API端点不存在")
|
||||
|
||||
# 更新字段
|
||||
if request.name is not None:
|
||||
endpoint.name = request.name
|
||||
if request.description is not None:
|
||||
endpoint.description = request.description
|
||||
if request.path is not None:
|
||||
# 检查新路径是否已被其他端点使用
|
||||
existing_endpoint = db.query(ApiEndpoint).filter(
|
||||
ApiEndpoint.path == request.path,
|
||||
ApiEndpoint.id != endpoint_id
|
||||
).first()
|
||||
if existing_endpoint:
|
||||
raise HTTPException(status_code=400, detail="API路径已存在")
|
||||
endpoint.path = request.path
|
||||
if request.method is not None:
|
||||
endpoint.method = request.method
|
||||
if request.requires_auth is not None:
|
||||
endpoint.requires_auth = request.requires_auth
|
||||
if request.allowed_roles is not None:
|
||||
endpoint.allowed_roles = request.allowed_roles
|
||||
if request.rate_limit is not None:
|
||||
endpoint.rate_limit = request.rate_limit
|
||||
if request.is_public is not None:
|
||||
endpoint.is_public = request.is_public
|
||||
if request.config is not None:
|
||||
endpoint.config = request.config
|
||||
if request.status is not None:
|
||||
endpoint.status = request.status
|
||||
|
||||
db.commit()
|
||||
db.refresh(endpoint)
|
||||
|
||||
# 获取关联的算法和版本信息
|
||||
algorithm = db.query(Algorithm).filter(Algorithm.id == endpoint.algorithm_id).first()
|
||||
version = db.query(AlgorithmVersion).filter(AlgorithmVersion.id == endpoint.version_id).first()
|
||||
|
||||
return {
|
||||
"id": endpoint.id,
|
||||
"name": endpoint.name,
|
||||
"description": endpoint.description,
|
||||
"path": endpoint.path,
|
||||
"method": endpoint.method,
|
||||
"algorithm_id": endpoint.algorithm_id,
|
||||
"algorithm_name": algorithm.name if algorithm else "",
|
||||
"version_id": endpoint.version_id,
|
||||
"version": version.version if version else "",
|
||||
"service_id": endpoint.service_id,
|
||||
"status": endpoint.status,
|
||||
"is_public": endpoint.is_public,
|
||||
"call_count": endpoint.call_count,
|
||||
"success_count": endpoint.success_count,
|
||||
"error_count": endpoint.error_count,
|
||||
"avg_response_time": endpoint.avg_response_time,
|
||||
"created_at": endpoint.created_at,
|
||||
"updated_at": endpoint.updated_at,
|
||||
"last_called_at": endpoint.last_called_at
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新API端点失败: {str(e)}")
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"更新API端点失败: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/endpoints/{endpoint_id}")
|
||||
async def delete_api_endpoint(
|
||||
endpoint_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""删除API端点"""
|
||||
try:
|
||||
# 检查用户权限
|
||||
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 查询API端点
|
||||
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
|
||||
if not endpoint:
|
||||
raise HTTPException(status_code=404, detail="API端点不存在")
|
||||
|
||||
# 删除API端点
|
||||
db.delete(endpoint)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "API端点删除成功"
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"删除API端点失败: {str(e)}")
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"删除API端点失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/stats", response_model=ApiStatsResponse)
|
||||
async def get_api_stats(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取API统计信息"""
|
||||
try:
|
||||
# 检查用户权限
|
||||
if not hasattr(current_user, 'role_name') or current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 统计API端点
|
||||
total_endpoints = db.query(ApiEndpoint).count()
|
||||
active_endpoints = db.query(ApiEndpoint).filter(ApiEndpoint.status == "active").count()
|
||||
|
||||
# 统计调用次数
|
||||
endpoints = db.query(ApiEndpoint).all()
|
||||
total_calls = sum(int(e.call_count or 0) for e in endpoints)
|
||||
total_success = sum(int(e.success_count or 0) for e in endpoints)
|
||||
total_errors = sum(int(e.error_count or 0) for e in endpoints)
|
||||
|
||||
# 计算平均响应时间
|
||||
avg_response_times = [float(e.avg_response_time or 0) for e in endpoints if float(e.avg_response_time or 0) > 0]
|
||||
avg_response_time = sum(avg_response_times) / len(avg_response_times) if avg_response_times else 0.0
|
||||
|
||||
return {
|
||||
"total_endpoints": total_endpoints,
|
||||
"active_endpoints": active_endpoints,
|
||||
"total_calls": str(total_calls),
|
||||
"total_success": str(total_success),
|
||||
"total_errors": str(total_errors),
|
||||
"avg_response_time": f"{avg_response_time:.2f}"
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取API统计信息失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取API统计信息失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/endpoints/{endpoint_id}/test")
|
||||
async def test_api_endpoint(
|
||||
endpoint_id: str,
|
||||
payload: Dict[str, Any],
|
||||
db: Session = Depends(get_db),
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""测试API端点"""
|
||||
try:
|
||||
# 查询API端点
|
||||
endpoint = db.query(ApiEndpoint).filter(ApiEndpoint.id == endpoint_id).first()
|
||||
if not endpoint:
|
||||
raise HTTPException(status_code=404, detail="API端点不存在")
|
||||
|
||||
# 检查API端点状态
|
||||
if endpoint.status != "active":
|
||||
raise HTTPException(status_code=400, detail="API端点未激活")
|
||||
|
||||
# 查询关联的服务
|
||||
if endpoint.service_id:
|
||||
service = db.query(AlgorithmService).filter(
|
||||
AlgorithmService.service_id == endpoint.service_id
|
||||
).first()
|
||||
if not service or service.status != "running":
|
||||
raise HTTPException(status_code=400, detail="关联服务未运行")
|
||||
|
||||
# 调用服务
|
||||
import httpx
|
||||
import time
|
||||
|
||||
service_url = service.api_url
|
||||
if not service_url.endswith("/"):
|
||||
service_url += "/"
|
||||
|
||||
call_url = f"{service_url}predict"
|
||||
|
||||
start_time = time.time()
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(
|
||||
call_url,
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
response_time = time.time() - start_time
|
||||
|
||||
if response.status_code == 200:
|
||||
return {
|
||||
"success": True,
|
||||
"result": response.json(),
|
||||
"response_time": response_time,
|
||||
"message": "API调用成功"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"服务返回错误: HTTP {response.status_code}",
|
||||
"response_time": response_time
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="API端点未关联服务")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"测试API端点失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"测试API端点失败: {str(e)}")
|
||||
64
backend/app/routes/comparison.py
Normal file
64
backend/app/routes/comparison.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from app.services.comparison_service import ComparisonService
|
||||
from app.routes.user import get_current_active_user
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/comparison", tags=["comparison"])
|
||||
|
||||
# 创建对比服务实例
|
||||
comparison_service = ComparisonService()
|
||||
|
||||
|
||||
@router.post("/compare-algorithms", response_model=dict)
|
||||
async def compare_algorithms(
|
||||
request_data: Dict[str, Any],
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""比较多个算法的效果
|
||||
|
||||
Args:
|
||||
request_data: 请求数据,包含input_data和algorithm_configs
|
||||
current_user: 当前活跃用户
|
||||
|
||||
Returns:
|
||||
对比结果
|
||||
"""
|
||||
input_data = request_data.get("input_data")
|
||||
algorithm_configs = request_data.get("algorithm_configs")
|
||||
|
||||
if not input_data:
|
||||
raise HTTPException(status_code=400, detail="缺少 input_data 参数")
|
||||
|
||||
if not algorithm_configs or not isinstance(algorithm_configs, list):
|
||||
raise HTTPException(status_code=400, detail="缺少 algorithm_configs 参数或格式错误")
|
||||
|
||||
result = await comparison_service.compare_algorithms(input_data, algorithm_configs)
|
||||
|
||||
if not result["success"]:
|
||||
raise HTTPException(status_code=500, detail=result.get("error", "对比失败"))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/generate-report", response_model=dict)
|
||||
async def generate_comparison_report(
|
||||
comparison_results: Dict[str, Any],
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""生成对比报告
|
||||
|
||||
Args:
|
||||
comparison_results: 对比结果
|
||||
current_user: 当前活跃用户
|
||||
|
||||
Returns:
|
||||
对比报告
|
||||
"""
|
||||
report = comparison_service.generate_comparison_report(comparison_results)
|
||||
|
||||
if not report["success"]:
|
||||
raise HTTPException(status_code=500, detail=report.get("error", "生成报告失败"))
|
||||
|
||||
return report
|
||||
124
backend/app/routes/config.py
Normal file
124
backend/app/routes/config.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from app.models.database import get_db
|
||||
from app.services.config_service import ConfigService
|
||||
from app.routes.user import get_current_active_user
|
||||
|
||||
router = APIRouter(prefix="/config", tags=["config"])
|
||||
|
||||
|
||||
@router.get("/{config_key}")
|
||||
async def get_config(
|
||||
config_key: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取配置
|
||||
|
||||
Args:
|
||||
config_key: 配置键
|
||||
db: 数据库会话
|
||||
current_user: 当前活跃用户
|
||||
|
||||
Returns:
|
||||
配置信息
|
||||
"""
|
||||
config = ConfigService.get_config(db, config_key)
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="配置不存在")
|
||||
return {"key": config_key, "value": config}
|
||||
|
||||
|
||||
@router.post("/{config_key}")
|
||||
async def set_config(
|
||||
config_key: str,
|
||||
config_data: Dict[str, Any],
|
||||
db: Session = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""设置配置
|
||||
|
||||
Args:
|
||||
config_key: 配置键
|
||||
config_data: 配置数据,包含value、type、service_id、description等字段
|
||||
db: 数据库会话
|
||||
current_user: 当前活跃用户
|
||||
|
||||
Returns:
|
||||
设置结果
|
||||
"""
|
||||
success = ConfigService.set_config(
|
||||
db=db,
|
||||
config_key=config_key,
|
||||
config_value=config_data.get("value"),
|
||||
config_type=config_data.get("type", "system"),
|
||||
service_id=config_data.get("service_id"),
|
||||
description=config_data.get("description", "")
|
||||
)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="设置配置失败")
|
||||
return {"message": "设置配置成功"}
|
||||
|
||||
|
||||
@router.get("/service/{service_id}")
|
||||
async def get_service_configs(
|
||||
service_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取服务配置
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
db: 数据库会话
|
||||
current_user: 当前活跃用户
|
||||
|
||||
Returns:
|
||||
服务配置列表
|
||||
"""
|
||||
configs = ConfigService.get_service_configs(db, service_id)
|
||||
return {"service_id": service_id, "configs": configs}
|
||||
|
||||
|
||||
@router.delete("/{config_key}")
|
||||
async def delete_config(
|
||||
config_key: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""删除配置
|
||||
|
||||
Args:
|
||||
config_key: 配置键
|
||||
db: 数据库会话
|
||||
current_user: 当前活跃用户
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
success = ConfigService.delete_config(db, config_key)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="删除配置失败")
|
||||
return {"message": "删除配置成功"}
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def get_all_configs(
|
||||
config_type: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取所有配置
|
||||
|
||||
Args:
|
||||
config_type: 配置类型,可选
|
||||
db: 数据库会话
|
||||
current_user: 当前活跃用户
|
||||
|
||||
Returns:
|
||||
配置列表
|
||||
"""
|
||||
configs = ConfigService.get_all_configs(db, config_type)
|
||||
return {"configs": configs}
|
||||
@@ -14,6 +14,7 @@ from app.schemas.user import UserResponse
|
||||
from app.services.project_analyzer import ProjectAnalyzer
|
||||
from app.services.service_generator import ServiceGenerator
|
||||
from app.services.service_orchestrator import ServiceOrchestrator
|
||||
from app.gitea.service import gitea_service
|
||||
|
||||
router = APIRouter(prefix="/services", tags=["services"])
|
||||
|
||||
@@ -23,6 +24,9 @@ class RegisterServiceRequest(BaseModel):
|
||||
repository_id: str
|
||||
name: str
|
||||
version: str = "1.0.0"
|
||||
description: Optional[str] = ""
|
||||
tech_category: str = "computer_vision"
|
||||
output_type: str = "image"
|
||||
service_type: str = "http"
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
@@ -154,31 +158,24 @@ async def register_service(
|
||||
# 记录仓库信息
|
||||
print(f"仓库信息: {repo.name}, {repo.description}, {repo.repo_url}")
|
||||
|
||||
# 2. 分析项目
|
||||
repo_path = f"/tmp/repository_{request.repository_id}"
|
||||
# 注意:在实际实现中,应该从算法仓库中获取项目文件
|
||||
# 这里简化处理,创建一个模拟的项目结构
|
||||
os.makedirs(repo_path, exist_ok=True)
|
||||
# 2. 从Gitea仓库克隆代码到本地
|
||||
repo_path = f"/tmp/algorithms/{request.repository_id}"
|
||||
|
||||
# 创建模拟的算法文件
|
||||
with open(os.path.join(repo_path, "algorithm.py"), "w") as f:
|
||||
f.write("""
|
||||
def predict(data):
|
||||
return {"result": "Prediction result", "input": data}
|
||||
|
||||
def run(data):
|
||||
return {"result": "Run result", "input": data}
|
||||
|
||||
def main(data):
|
||||
return {"result": "Main result", "input": data}
|
||||
""")
|
||||
# 使用Gitea服务克隆仓库
|
||||
clone_success = gitea_service.clone_repository(repo.repo_url, request.repository_id, repo.branch or "main")
|
||||
if not clone_success:
|
||||
raise HTTPException(status_code=400, detail=f"克隆仓库失败: {repo.repo_url}")
|
||||
|
||||
# 分析项目
|
||||
print(f"仓库克隆成功: {repo_path}")
|
||||
|
||||
# 3. 分析项目
|
||||
project_info = project_analyzer.analyze_project(repo_path)
|
||||
if not project_info["success"]:
|
||||
raise HTTPException(status_code=400, detail=f"项目分析失败: {project_info['error']}")
|
||||
|
||||
# 3. 生成服务包装器
|
||||
print(f"项目分析成功: {project_info}")
|
||||
|
||||
# 4. 生成服务包装器
|
||||
service_config = {
|
||||
"name": request.name,
|
||||
"version": request.version,
|
||||
@@ -194,24 +191,31 @@ def main(data):
|
||||
if not generate_result["success"]:
|
||||
raise HTTPException(status_code=400, detail=f"服务生成失败: {generate_result['error']}")
|
||||
|
||||
# 4. 部署服务
|
||||
print(f"服务生成成功: {generate_result}")
|
||||
|
||||
# 5. 部署服务
|
||||
service_id = str(uuid.uuid4())
|
||||
deploy_result = service_orchestrator.deploy_service(service_id, service_config, project_info)
|
||||
deploy_result = service_orchestrator.deploy_service(service_id, service_config, project_info, repo_path)
|
||||
if not deploy_result["success"]:
|
||||
raise HTTPException(status_code=400, detail=f"服务部署失败: {deploy_result['error']}")
|
||||
|
||||
# 5. 保存服务信息到数据库
|
||||
print(f"服务部署成功: {deploy_result}")
|
||||
|
||||
# 6. 保存服务信息到数据库
|
||||
new_service = AlgorithmService(
|
||||
id=str(uuid.uuid4()),
|
||||
service_id=service_id,
|
||||
name=request.name,
|
||||
algorithm_name=repo.name, # 使用仓库名称作为算法名称
|
||||
version=request.version,
|
||||
tech_category=request.tech_category,
|
||||
output_type=request.output_type,
|
||||
host=request.host,
|
||||
port=request.port,
|
||||
api_url=deploy_result["api_url"],
|
||||
status=deploy_result["status"],
|
||||
config={
|
||||
"repository_id": request.repository_id, # 保存仓库ID
|
||||
"service_type": request.service_type,
|
||||
"timeout": request.timeout,
|
||||
"health_check_path": request.health_check_path,
|
||||
@@ -352,8 +356,64 @@ async def start_service(
|
||||
|
||||
# 启动服务
|
||||
start_result = service_orchestrator.start_service(service_id, container_id)
|
||||
|
||||
# 如果启动失败,尝试从数据库重新注册服务
|
||||
if not start_result["success"]:
|
||||
raise HTTPException(status_code=400, detail=f"服务启动失败: {start_result['error']}")
|
||||
print(f"服务启动失败: {start_result['error']},尝试从数据库重新注册服务")
|
||||
|
||||
# 获取仓库信息
|
||||
repository_id = service.config.get("repository_id")
|
||||
if not repository_id:
|
||||
raise HTTPException(status_code=400, detail="Repository ID not found in service config")
|
||||
|
||||
repository = db.query(AlgorithmRepository).filter(AlgorithmRepository.id == repository_id).first()
|
||||
if not repository:
|
||||
raise HTTPException(status_code=404, detail="Repository not found")
|
||||
|
||||
# 从Gitea克隆仓库
|
||||
clone_success = gitea_service.clone_repository(
|
||||
repository.repo_url,
|
||||
service_id,
|
||||
repository.branch or "main"
|
||||
)
|
||||
if not clone_success:
|
||||
raise HTTPException(status_code=400, detail="Failed to clone repository")
|
||||
|
||||
# 仓库路径
|
||||
repo_path = f"/tmp/algorithms/{service_id}"
|
||||
|
||||
# 分析项目
|
||||
project_info = project_analyzer.analyze_project(repo_path)
|
||||
if not project_info:
|
||||
raise HTTPException(status_code=400, detail="Failed to analyze project")
|
||||
|
||||
# 生成服务
|
||||
service_config = {
|
||||
"name": service.name,
|
||||
"version": service.version,
|
||||
"host": service.host,
|
||||
"port": service.port,
|
||||
"timeout": service.config.get("timeout", 30),
|
||||
"health_check_path": service.config.get("health_check_path", "/health"),
|
||||
"environment": service.config.get("environment", {})
|
||||
}
|
||||
|
||||
# 部署服务
|
||||
deploy_result = service_orchestrator.deploy_service(service_id, project_info, service_config, repo_path)
|
||||
if not deploy_result["success"]:
|
||||
raise HTTPException(status_code=400, detail=f"服务部署失败: {deploy_result['error']}")
|
||||
|
||||
# 更新服务配置
|
||||
service.config["container_id"] = deploy_result["container_id"]
|
||||
service.api_url = deploy_result["api_url"]
|
||||
db.commit()
|
||||
|
||||
start_result = {
|
||||
"success": True,
|
||||
"service_id": service_id,
|
||||
"status": "running",
|
||||
"error": None
|
||||
}
|
||||
|
||||
# 更新服务状态
|
||||
service.status = start_result["status"]
|
||||
@@ -1065,3 +1125,108 @@ async def batch_delete_services(
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
class ServiceCallRequest(BaseModel):
|
||||
"""服务调用请求"""
|
||||
service_id: str
|
||||
payload: Dict[str, Any]
|
||||
|
||||
|
||||
class ServiceCallResponse(BaseModel):
|
||||
"""服务调用响应"""
|
||||
success: bool
|
||||
result: Dict[str, Any]
|
||||
service_id: str
|
||||
execution_time: float
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/call")
|
||||
async def call_service(
|
||||
request: ServiceCallRequest,
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""直接调用注册的服务"""
|
||||
import time
|
||||
import httpx
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询服务
|
||||
service = db.query(AlgorithmService).filter(
|
||||
AlgorithmService.service_id == request.service_id
|
||||
).first()
|
||||
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="服务不存在")
|
||||
|
||||
# 检查服务状态
|
||||
if service.status != "running":
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=f"服务未运行,当前状态: {service.status}"
|
||||
)
|
||||
|
||||
# 调用服务
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 构建服务URL
|
||||
service_url = service.api_url
|
||||
|
||||
# 如果URL没有路径,添加默认路径
|
||||
if not service_url.endswith("/"):
|
||||
service_url += "/"
|
||||
|
||||
# 添加调用端点
|
||||
call_url = f"{service_url}predict"
|
||||
|
||||
# 使用httpx调用服务
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(
|
||||
call_url,
|
||||
json=request.payload,
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
if response.status_code == 200:
|
||||
return ServiceCallResponse(
|
||||
success=True,
|
||||
result=response.json(),
|
||||
service_id=request.service_id,
|
||||
execution_time=execution_time
|
||||
)
|
||||
else:
|
||||
return ServiceCallResponse(
|
||||
success=False,
|
||||
result={},
|
||||
service_id=request.service_id,
|
||||
execution_time=execution_time,
|
||||
error=f"服务返回错误: HTTP {response.status_code} - {response.text}"
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
execution_time = time.time() - start_time
|
||||
return ServiceCallResponse(
|
||||
success=False,
|
||||
result={},
|
||||
service_id=request.service_id,
|
||||
execution_time=execution_time,
|
||||
error=f"无法连接到服务: {str(e)}"
|
||||
)
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
return ServiceCallResponse(
|
||||
success=False,
|
||||
result={},
|
||||
service_id=request.service_id,
|
||||
execution_time=execution_time,
|
||||
error=f"服务调用异常: {str(e)}"
|
||||
)
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -167,6 +167,12 @@ async def read_users_me(current_user: UserResponse = Depends(get_current_active_
|
||||
return current_user
|
||||
|
||||
|
||||
@router.get("/me/", response_model=UserResponse)
|
||||
async def read_users_me_with_slash(current_user: UserResponse = Depends(get_current_active_user)):
|
||||
"""获取当前用户信息(带末尾斜杠)"""
|
||||
return current_user
|
||||
|
||||
|
||||
@router.get("/", response_model=UserListResponse)
|
||||
async def get_users(
|
||||
skip: int = 0,
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -9,6 +9,8 @@ class AlgorithmBase(BaseModel):
|
||||
name: str = Field(..., description="算法名称")
|
||||
description: str = Field(..., description="算法描述")
|
||||
type: str = Field(..., description="算法类型")
|
||||
tech_category: str = Field(default="computer_vision", description="技术分类")
|
||||
output_type: str = Field(default="image", description="输出类型")
|
||||
|
||||
|
||||
class AlgorithmCreate(AlgorithmBase):
|
||||
|
||||
Binary file not shown.
BIN
backend/app/services/__pycache__/config_service.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/config_service.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
165
backend/app/services/comparison_service.py
Normal file
165
backend/app/services/comparison_service.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from typing import Dict, Any, List
|
||||
import asyncio
|
||||
import httpx
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ComparisonService:
|
||||
"""效果对比服务"""
|
||||
|
||||
async def compare_algorithms(
|
||||
self,
|
||||
input_data: Dict[str, Any],
|
||||
algorithm_configs: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""比较多个算法的效果
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
algorithm_configs: 算法配置列表,每个配置包含服务URL、参数等
|
||||
|
||||
Returns:
|
||||
对比结果
|
||||
"""
|
||||
try:
|
||||
# 异步执行所有算法
|
||||
tasks = []
|
||||
for config in algorithm_configs:
|
||||
task = self._execute_algorithm(config, input_data)
|
||||
tasks.append(task)
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 处理结果
|
||||
comparison_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
comparison_results.append({
|
||||
"algorithm_id": algorithm_configs[i].get("id"),
|
||||
"algorithm_name": algorithm_configs[i].get("name"),
|
||||
"success": False,
|
||||
"error": str(result),
|
||||
"output": None,
|
||||
"execution_time": 0
|
||||
})
|
||||
else:
|
||||
comparison_results.append({
|
||||
"algorithm_id": algorithm_configs[i].get("id"),
|
||||
"algorithm_name": algorithm_configs[i].get("name"),
|
||||
"success": True,
|
||||
"error": None,
|
||||
"output": result.get("output"),
|
||||
"execution_time": result.get("execution_time", 0)
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"results": comparison_results,
|
||||
"input_data": input_data
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Comparison error: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"results": []
|
||||
}
|
||||
|
||||
async def _execute_algorithm(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
input_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""执行单个算法
|
||||
|
||||
Args:
|
||||
config: 算法配置
|
||||
input_data: 输入数据
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
url = config.get("url")
|
||||
params = config.get("params", {})
|
||||
|
||||
if not url:
|
||||
raise ValueError("缺少算法服务URL")
|
||||
|
||||
# 构建请求数据
|
||||
request_data = {
|
||||
"input_data": input_data.get("input_data", input_data),
|
||||
"params": params
|
||||
}
|
||||
|
||||
# 发送请求
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(f"{url}/predict", json=request_data)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
"output": result,
|
||||
"execution_time": execution_time
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Algorithm execution error: {str(e)}")
|
||||
raise e
|
||||
|
||||
def generate_comparison_report(self, comparison_results: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""生成对比报告
|
||||
|
||||
Args:
|
||||
comparison_results: 对比结果
|
||||
|
||||
Returns:
|
||||
对比报告
|
||||
"""
|
||||
try:
|
||||
if not comparison_results.get("success"):
|
||||
return {
|
||||
"success": False,
|
||||
"error": comparison_results.get("error", "对比失败")
|
||||
}
|
||||
|
||||
results = comparison_results.get("results", [])
|
||||
|
||||
# 分析结果
|
||||
successful_algorithms = [r for r in results if r.get("success")]
|
||||
failed_algorithms = [r for r in results if not r.get("success")]
|
||||
|
||||
# 计算平均执行时间
|
||||
if successful_algorithms:
|
||||
avg_execution_time = sum(r.get("execution_time", 0) for r in successful_algorithms) / len(successful_algorithms)
|
||||
else:
|
||||
avg_execution_time = 0
|
||||
|
||||
# 生成报告
|
||||
report = {
|
||||
"summary": {
|
||||
"total_algorithms": len(results),
|
||||
"successful_algorithms": len(successful_algorithms),
|
||||
"failed_algorithms": len(failed_algorithms),
|
||||
"average_execution_time": round(avg_execution_time, 2)
|
||||
},
|
||||
"details": results,
|
||||
"input_data": comparison_results.get("input_data")
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"report": report
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Report generation error: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
165
backend/app/services/config_service.py
Normal file
165
backend/app/services/config_service.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from typing import Optional, Dict, Any, List
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.models import ServiceConfig
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConfigService:
|
||||
"""配置服务"""
|
||||
|
||||
@staticmethod
|
||||
def get_config(db: Session, config_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取配置
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_key: 配置键
|
||||
|
||||
Returns:
|
||||
配置值,如果不存在返回None
|
||||
"""
|
||||
config = db.query(ServiceConfig).filter_by(
|
||||
config_key=config_key,
|
||||
status="active"
|
||||
).first()
|
||||
|
||||
if config:
|
||||
return config.config_value
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def set_config(db: Session, config_key: str, config_value: Dict[str, Any],
|
||||
config_type: str = "system", service_id: Optional[str] = None,
|
||||
description: str = "") -> bool:
|
||||
"""设置配置
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_key: 配置键
|
||||
config_value: 配置值
|
||||
config_type: 配置类型,默认为"system"
|
||||
service_id: 服务ID,系统配置可为None
|
||||
description: 配置描述
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
# 检查是否存在
|
||||
existing_config = db.query(ServiceConfig).filter_by(
|
||||
config_key=config_key
|
||||
).first()
|
||||
|
||||
if existing_config:
|
||||
# 更新现有配置
|
||||
existing_config.config_value = config_value
|
||||
existing_config.config_type = config_type
|
||||
existing_config.service_id = service_id
|
||||
existing_config.description = description
|
||||
existing_config.status = "active"
|
||||
else:
|
||||
# 创建新配置
|
||||
new_config = ServiceConfig(
|
||||
id=f"config-{uuid.uuid4()}",
|
||||
config_key=config_key,
|
||||
config_value=config_value,
|
||||
config_type=config_type,
|
||||
service_id=service_id,
|
||||
description=description,
|
||||
status="active"
|
||||
)
|
||||
db.add(new_config)
|
||||
|
||||
db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to set config: {str(e)}")
|
||||
db.rollback()
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_service_configs(db: Session, service_id: str) -> List[Dict[str, Any]]:
|
||||
"""获取服务的所有配置
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
service_id: 服务ID
|
||||
|
||||
Returns:
|
||||
服务配置列表
|
||||
"""
|
||||
configs = db.query(ServiceConfig).filter_by(
|
||||
service_id=service_id,
|
||||
status="active"
|
||||
).all()
|
||||
|
||||
return [
|
||||
{
|
||||
"key": config.config_key,
|
||||
"value": config.config_value,
|
||||
"type": config.config_type,
|
||||
"description": config.description
|
||||
}
|
||||
for config in configs
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def delete_config(db: Session, config_key: str) -> bool:
|
||||
"""删除配置
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_key: 配置键
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
config = db.query(ServiceConfig).filter_by(
|
||||
config_key=config_key
|
||||
).first()
|
||||
|
||||
if config:
|
||||
config.status = "inactive"
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete config: {str(e)}")
|
||||
db.rollback()
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_all_configs(db: Session, config_type: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""获取所有配置
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_type: 配置类型,可选
|
||||
|
||||
Returns:
|
||||
配置列表
|
||||
"""
|
||||
query = db.query(ServiceConfig).filter_by(status="active")
|
||||
|
||||
if config_type:
|
||||
query = query.filter_by(config_type=config_type)
|
||||
|
||||
configs = query.all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": config.id,
|
||||
"key": config.config_key,
|
||||
"value": config.config_value,
|
||||
"type": config.config_type,
|
||||
"service_id": config.service_id,
|
||||
"description": config.description,
|
||||
"created_at": config.created_at,
|
||||
"updated_at": config.updated_at
|
||||
}
|
||||
for config in configs
|
||||
]
|
||||
@@ -63,22 +63,41 @@ class ProjectAnalyzer:
|
||||
Returns:
|
||||
项目类型,如 "python", "java", "nodejs" 等
|
||||
"""
|
||||
# 检查Python项目
|
||||
# 检查Python项目 - 先检查根目录
|
||||
if os.path.exists(os.path.join(repo_path, "requirements.txt")) or \
|
||||
os.path.exists(os.path.join(repo_path, "pyproject.toml")) or \
|
||||
any(file.endswith(".py") for file in os.listdir(repo_path)):
|
||||
return "python"
|
||||
|
||||
# 检查Java项目
|
||||
# 检查Python项目 - 递归检查子目录
|
||||
for root, dirs, files in os.walk(repo_path):
|
||||
if "requirements.txt" in files or "pyproject.toml" in files:
|
||||
return "python"
|
||||
if any(file.endswith(".py") for file in files):
|
||||
return "python"
|
||||
|
||||
# 检查Java项目 - 先检查根目录
|
||||
if os.path.exists(os.path.join(repo_path, "pom.xml")) or \
|
||||
os.path.exists(os.path.join(repo_path, "build.gradle")) or \
|
||||
os.path.exists(os.path.join(repo_path, "src")):
|
||||
return "java"
|
||||
|
||||
# 检查Node.js项目
|
||||
# 检查Java项目 - 递归检查子目录
|
||||
for root, dirs, files in os.walk(repo_path):
|
||||
if "pom.xml" in files or "build.gradle" in files:
|
||||
return "java"
|
||||
if "src" in dirs:
|
||||
return "java"
|
||||
|
||||
# 检查Node.js项目 - 先检查根目录
|
||||
if os.path.exists(os.path.join(repo_path, "package.json")):
|
||||
return "nodejs"
|
||||
|
||||
# 检查Node.js项目 - 递归检查子目录
|
||||
for root, dirs, files in os.walk(repo_path):
|
||||
if "package.json" in files:
|
||||
return "nodejs"
|
||||
|
||||
# 检查其他项目类型
|
||||
if os.path.exists(os.path.join(repo_path, "CMakeLists.txt")):
|
||||
return "c++"
|
||||
|
||||
@@ -38,13 +38,14 @@ class ServiceOrchestrator:
|
||||
self.client = None
|
||||
print("使用本地进程部署模式")
|
||||
|
||||
def deploy_service(self, service_id: str, service_config: Dict[str, Any], project_info: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def deploy_service(self, service_id: str, service_config: Dict[str, Any], project_info: Dict[str, Any], repo_path: str = None) -> Dict[str, Any]:
|
||||
"""部署服务
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
service_config: 服务配置
|
||||
project_info: 项目信息
|
||||
repo_path: 仓库路径(用于复制真实的算法文件)
|
||||
|
||||
Returns:
|
||||
部署结果
|
||||
@@ -95,7 +96,7 @@ class ServiceOrchestrator:
|
||||
service_dir = self._create_service_directory(service_id)
|
||||
|
||||
# 2. 生成服务包装器
|
||||
self._generate_local_service_wrapper(service_dir, project_info, service_config)
|
||||
self._generate_local_service_wrapper(service_dir, project_info, service_config, repo_path)
|
||||
|
||||
# 3. 启动服务进程
|
||||
process_info = self._start_local_service_process(service_id, service_dir, project_info, service_config)
|
||||
@@ -176,9 +177,13 @@ class ServiceOrchestrator:
|
||||
else:
|
||||
# 本地进程启动
|
||||
if service_id not in self.processes:
|
||||
# 服务不在进程列表中,可能是服务重启导致的
|
||||
# 这种情况下,需要从外部重新注册服务
|
||||
# 暂时返回错误,建议用户重新注册服务
|
||||
print(f"服务 {service_id} 不在进程列表中,无法启动")
|
||||
return {
|
||||
"success": False,
|
||||
"error": "服务不存在",
|
||||
"error": "服务不存在,请重新注册服务",
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
}
|
||||
@@ -272,11 +277,18 @@ class ServiceOrchestrator:
|
||||
else:
|
||||
# 本地进程停止
|
||||
if service_id not in self.processes:
|
||||
# 服务不在进程列表中,可能是服务重启导致的
|
||||
# 尝试通过端口查找并停止进程
|
||||
print(f"服务 {service_id} 不在进程列表中,尝试通过端口查找进程")
|
||||
|
||||
# 从服务配置中获取端口信息
|
||||
# 这里需要从外部传入服务配置,或者从数据库查询
|
||||
# 暂时返回成功,因为服务可能已经停止了
|
||||
return {
|
||||
"success": False,
|
||||
"error": "服务不存在",
|
||||
"success": True,
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
"status": "stopped",
|
||||
"error": None
|
||||
}
|
||||
|
||||
process_info = self.processes[service_id]
|
||||
@@ -1271,13 +1283,14 @@ json
|
||||
os.makedirs(service_dir, exist_ok=True)
|
||||
return service_dir
|
||||
|
||||
def _generate_local_service_wrapper(self, service_dir: str, project_info: Dict[str, Any], service_config: Dict[str, Any]):
|
||||
def _generate_local_service_wrapper(self, service_dir: str, project_info: Dict[str, Any], service_config: Dict[str, Any], repo_path: str = None):
|
||||
"""生成本地服务包装器
|
||||
|
||||
Args:
|
||||
service_dir: 服务目录
|
||||
project_info: 项目信息
|
||||
service_config: 服务配置
|
||||
repo_path: 仓库路径(用于复制真实的算法文件)
|
||||
"""
|
||||
# 生成服务包装器
|
||||
service_wrapper_content = self._generate_service_wrapper(project_info, service_config)
|
||||
@@ -1285,7 +1298,44 @@ json
|
||||
with open(os.path.join(service_dir, f"service_wrapper{wrapper_extension}"), "w") as f:
|
||||
f.write(service_wrapper_content)
|
||||
|
||||
# 创建模拟的算法文件
|
||||
# 复制真实的算法文件
|
||||
if repo_path and project_info["project_type"] == "python":
|
||||
# 尝试找到并复制主要的算法文件
|
||||
entry_point = project_info.get("entry_point")
|
||||
if entry_point:
|
||||
source_file = os.path.join(repo_path, entry_point)
|
||||
if os.path.exists(source_file):
|
||||
# 复制算法文件到服务目录
|
||||
import shutil
|
||||
shutil.copy2(source_file, os.path.join(service_dir, "algorithm.py"))
|
||||
print(f"已复制算法文件: {source_file} -> {os.path.join(service_dir, 'algorithm.py')}")
|
||||
return
|
||||
|
||||
# 如果没有找到入口点,尝试复制所有Python文件
|
||||
if os.path.exists(repo_path):
|
||||
import shutil
|
||||
for root, dirs, files in os.walk(repo_path):
|
||||
for file in files:
|
||||
if file.endswith(".py") and not file.startswith("_"):
|
||||
source_file = os.path.join(root, file)
|
||||
dest_file = os.path.join(service_dir, file)
|
||||
shutil.copy2(source_file, dest_file)
|
||||
print(f"已复制Python文件: {source_file} -> {dest_file}")
|
||||
|
||||
# 如果有algorithm.py,就使用它,否则创建一个模拟的
|
||||
if not os.path.exists(os.path.join(service_dir, "algorithm.py")):
|
||||
print("未找到algorithm.py,创建模拟算法文件")
|
||||
self._create_mock_algorithm(service_dir)
|
||||
else:
|
||||
# 创建模拟的算法文件
|
||||
self._create_mock_algorithm(service_dir)
|
||||
|
||||
def _create_mock_algorithm(self, service_dir: str):
|
||||
"""创建模拟的算法文件
|
||||
|
||||
Args:
|
||||
service_dir: 服务目录
|
||||
"""
|
||||
algorithm_content = """
|
||||
def predict(data):
|
||||
return {"result": "Prediction result", "input": data}
|
||||
@@ -1316,9 +1366,9 @@ def main(data):
|
||||
|
||||
# 构建启动命令
|
||||
if project_info["project_type"] == "python":
|
||||
cmd = ["python", f"service_wrapper.py"]
|
||||
cmd = ["python", "service_wrapper.py"]
|
||||
else:
|
||||
cmd = ["node", f"service_wrapper.js"]
|
||||
cmd = ["node", "service_wrapper.js"]
|
||||
|
||||
# 设置环境变量
|
||||
env = os.environ.copy()
|
||||
|
||||
2
backend/backend.log
Normal file
2
backend/backend.log
Normal file
@@ -0,0 +1,2 @@
|
||||
INFO: Will watch for changes in these directories: ['/Users/duguoyou/MLFlow/algorithm-showcase/backend']
|
||||
ERROR: [Errno 48] Address already in use
|
||||
38
backend/check_algorithms.py
Normal file
38
backend/check_algorithms.py
Normal file
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python3
|
||||
"""检查算法数据"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from app.models.database import SessionLocal
|
||||
from app.models.models import Algorithm
|
||||
|
||||
def check_algorithms():
|
||||
"""检查算法数据"""
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
algorithms = db.query(Algorithm).all()
|
||||
|
||||
print(f"数据库中共有 {len(algorithms)} 个算法:\n")
|
||||
|
||||
for algo in algorithms:
|
||||
print(f"算法名称: {algo.name}")
|
||||
print(f" ID: {algo.id}")
|
||||
print(f" 类型: {algo.type}")
|
||||
print(f" 技术分类: {algo.tech_category}")
|
||||
print(f" 输出类型: {algo.output_type}")
|
||||
print(f" 描述: {algo.description}")
|
||||
print(f" 状态: {algo.status}")
|
||||
print(f" 版本数: {len(algo.versions)}")
|
||||
print()
|
||||
|
||||
except Exception as e:
|
||||
print(f"检查算法数据失败: {e}")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_algorithms()
|
||||
57
backend/check_user_role.py
Normal file
57
backend/check_user_role.py
Normal file
@@ -0,0 +1,57 @@
|
||||
#!/usr/bin/env python3
|
||||
"""检查用户角色信息"""
|
||||
|
||||
import requests
|
||||
|
||||
def check_user_role():
|
||||
"""检查用户角色"""
|
||||
base_url = "http://localhost:8001/api/v1"
|
||||
|
||||
# 登录
|
||||
print("步骤1: 登录")
|
||||
login_data = {
|
||||
"username": "admin",
|
||||
"password": "admin123"
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(f"{base_url}/users/login", json=login_data)
|
||||
print(f"状态码: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
print(f"登录失败: {response.text}")
|
||||
return
|
||||
|
||||
data = response.json()
|
||||
access_token = data.get('access_token')
|
||||
print(f"登录成功!")
|
||||
|
||||
# 获取用户信息
|
||||
print("\n步骤2: 获取用户信息")
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
user_response = requests.get(f"{base_url}/users/me", headers=headers)
|
||||
print(f"状态码: {user_response.status_code}")
|
||||
|
||||
if user_response.status_code == 200:
|
||||
user_data = user_response.json()
|
||||
print(f"\n用户信息:")
|
||||
print(f" 用户名: {user_data.get('username', 'N/A')}")
|
||||
print(f" 邮箱: {user_data.get('email', 'N/A')}")
|
||||
print(f" 角色ID: {user_data.get('role_id', 'N/A')}")
|
||||
print(f" 角色名称: {user_data.get('role_name', 'N/A')}")
|
||||
print(f" 角色对象: {user_data.get('role', 'N/A')}")
|
||||
|
||||
# 检查是否是管理员
|
||||
role_name = user_data.get('role_name')
|
||||
if role_name == 'admin':
|
||||
print(f"\n✅ 用户是管理员,应该显示后台管理页面")
|
||||
else:
|
||||
print(f"\n❌ 用户不是管理员,角色名称是: {role_name}")
|
||||
else:
|
||||
print(f"获取用户信息失败: {user_response.text}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"错误: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_user_role()
|
||||
@@ -1,39 +1,54 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
检查数据库中的用户信息
|
||||
"""
|
||||
"""检查数据库中的用户信息"""
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, '/Users/duguoyou/MLFlow/algorithm-showcase/backend')
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.database import SessionLocal
|
||||
from app.models.models import User
|
||||
|
||||
from app.services.user import UserService
|
||||
|
||||
def check_users():
|
||||
"""检查数据库中的用户信息"""
|
||||
"""检查用户"""
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
# 获取所有用户
|
||||
users = db.query(User).all()
|
||||
|
||||
if users:
|
||||
print("数据库中的用户信息:")
|
||||
print("-" * 50)
|
||||
print(f"数据库中的用户数量: {len(users)}")
|
||||
|
||||
for user in users:
|
||||
print(f"\n用户ID: {user.id}")
|
||||
print(f"用户名: {user.username}")
|
||||
print(f"邮箱: {user.email}")
|
||||
print(f"状态: {user.status}")
|
||||
print(f"角色ID: {user.role_id}")
|
||||
print(f"密码哈希: {user.password_hash[:50]}...")
|
||||
|
||||
# 测试admin用户认证
|
||||
print("\n\n测试admin用户认证:")
|
||||
admin_user = UserService.get_user_by_username(db, 'admin')
|
||||
if admin_user:
|
||||
print(f"找到admin用户: {admin_user.id}")
|
||||
print(f"密码哈希: {admin_user.password_hash[:50]}...")
|
||||
|
||||
for user in users:
|
||||
print(f"用户ID: {user.id}")
|
||||
print(f"用户名: {user.username}")
|
||||
print(f"邮箱: {user.email}")
|
||||
print(f"角色: {user.role}")
|
||||
print(f"状态: {user.status}")
|
||||
print(f"创建时间: {user.created_at}")
|
||||
print("-" * 50)
|
||||
# 测试密码验证
|
||||
test_password = 'admin123'
|
||||
is_valid = UserService.verify_password(test_password, admin_user.password_hash)
|
||||
print(f"密码 '{test_password}' 验证结果: {is_valid}")
|
||||
|
||||
# 尝试认证
|
||||
authenticated_user = UserService.authenticate_user(db, 'admin', test_password)
|
||||
if authenticated_user:
|
||||
print(f"认证成功: {authenticated_user.id}")
|
||||
else:
|
||||
print("认证失败")
|
||||
else:
|
||||
print("数据库中没有用户信息")
|
||||
print("未找到admin用户")
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_users()
|
||||
check_users()
|
||||
151
backend/create_sample_algorithms.py
Normal file
151
backend/create_sample_algorithms.py
Normal file
@@ -0,0 +1,151 @@
|
||||
#!/usr/bin/env python3
|
||||
"""创建示例算法数据"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from app.models.database import SessionLocal
|
||||
from app.models.models import Algorithm, AlgorithmVersion
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
def create_sample_algorithms():
|
||||
"""创建示例算法"""
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
# 示例算法数据
|
||||
algorithms_data = [
|
||||
{
|
||||
"name": "目标检测",
|
||||
"description": "识别图像中的物体位置和类别,支持人脸、车辆、物品等多种目标检测",
|
||||
"type": "computer_vision",
|
||||
"tech_category": "computer_vision",
|
||||
"output_type": "image",
|
||||
"versions": [
|
||||
{
|
||||
"version": "1.0.0",
|
||||
"url": "http://0.0.0.0:8001",
|
||||
"is_default": True
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "视频分析",
|
||||
"description": "分析视频内容,提取关键帧、识别动作、追踪物体等",
|
||||
"type": "computer_vision",
|
||||
"tech_category": "video_processing",
|
||||
"output_type": "video",
|
||||
"versions": [
|
||||
{
|
||||
"version": "1.0.0",
|
||||
"url": "http://0.0.0.0:8002",
|
||||
"is_default": True
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "图像增强",
|
||||
"description": "提升图像质量,包括去噪、超分辨率、色彩校正等功能",
|
||||
"type": "computer_vision",
|
||||
"tech_category": "computer_vision",
|
||||
"output_type": "image",
|
||||
"versions": [
|
||||
{
|
||||
"version": "1.0.0",
|
||||
"url": "http://0.0.0.0:8003",
|
||||
"is_default": True
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "文本分类",
|
||||
"description": "对文本内容进行分类,支持新闻分类、情感分析、垃圾邮件识别等",
|
||||
"type": "nlp",
|
||||
"tech_category": "nlp",
|
||||
"output_type": "text",
|
||||
"versions": [
|
||||
{
|
||||
"version": "1.0.0",
|
||||
"url": "http://0.0.0.0:8004",
|
||||
"is_default": True
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "异常检测",
|
||||
"description": "检测数据中的异常模式,适用于工业监控、金融风控等场景",
|
||||
"type": "ml",
|
||||
"tech_category": "ml",
|
||||
"output_type": "json",
|
||||
"versions": [
|
||||
{
|
||||
"version": "1.0.0",
|
||||
"url": "http://0.0.0.0:8005",
|
||||
"is_default": True
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "医学影像分析",
|
||||
"description": "分析医学影像,辅助医生进行疾病诊断,支持CT、MRI等多种影像格式",
|
||||
"type": "medical",
|
||||
"tech_category": "computer_vision",
|
||||
"output_type": "image",
|
||||
"versions": [
|
||||
{
|
||||
"version": "1.0.0",
|
||||
"url": "http://0.0.0.0:8006",
|
||||
"is_default": True
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# 创建算法
|
||||
for algo_data in algorithms_data:
|
||||
# 检查算法是否已存在
|
||||
existing_algo = db.query(Algorithm).filter(Algorithm.name == algo_data["name"]).first()
|
||||
if existing_algo:
|
||||
print(f"✓ 算法 '{algo_data['name']}' 已存在,跳过")
|
||||
continue
|
||||
|
||||
# 创建算法
|
||||
algorithm = Algorithm(
|
||||
id=str(uuid.uuid4()),
|
||||
name=algo_data["name"],
|
||||
description=algo_data["description"],
|
||||
type=algo_data["type"],
|
||||
tech_category=algo_data["tech_category"],
|
||||
output_type=algo_data["output_type"],
|
||||
status="active"
|
||||
)
|
||||
db.add(algorithm)
|
||||
db.flush() # 获取算法ID
|
||||
|
||||
# 创建版本
|
||||
for version_data in algo_data["versions"]:
|
||||
version = AlgorithmVersion(
|
||||
id=str(uuid.uuid4()),
|
||||
algorithm_id=algorithm.id,
|
||||
version=version_data["version"],
|
||||
url=version_data["url"],
|
||||
is_default=version_data["is_default"]
|
||||
)
|
||||
db.add(version)
|
||||
|
||||
print(f"✓ 已创建算法: {algo_data['name']}")
|
||||
|
||||
db.commit()
|
||||
print("\n示例算法创建完成!")
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(f"创建示例算法失败: {e}")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_sample_algorithms()
|
||||
45
backend/migrate_add_algorithm_fields.py
Normal file
45
backend/migrate_add_algorithm_fields.py
Normal file
@@ -0,0 +1,45 @@
|
||||
#!/usr/bin/env python3
|
||||
"""数据库迁移脚本:添加技术分类和输出类型字段"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from sqlalchemy import text
|
||||
from app.models.database import engine
|
||||
|
||||
def migrate():
|
||||
"""执行数据库迁移"""
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
# 检查字段是否已存在(PostgreSQL语法)
|
||||
result = conn.execute(text("""
|
||||
SELECT column_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = 'algorithms'
|
||||
"""))
|
||||
columns = [row[0] for row in result.fetchall()]
|
||||
|
||||
# 添加 tech_category 字段
|
||||
if 'tech_category' not in columns:
|
||||
conn.execute(text("ALTER TABLE algorithms ADD COLUMN tech_category VARCHAR(50) DEFAULT 'computer_vision'"))
|
||||
print("✓ 已添加 tech_category 字段")
|
||||
else:
|
||||
print("✓ tech_category 字段已存在")
|
||||
|
||||
# 添加 output_type 字段
|
||||
if 'output_type' not in columns:
|
||||
conn.execute(text("ALTER TABLE algorithms ADD COLUMN output_type VARCHAR(50) DEFAULT 'image'"))
|
||||
print("✓ 已添加 output_type 字段")
|
||||
else:
|
||||
print("✓ output_type 字段已存在")
|
||||
|
||||
conn.commit()
|
||||
print("\n数据库迁移完成!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"数据库迁移失败: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
migrate()
|
||||
45
backend/migrate_add_service_fields.py
Normal file
45
backend/migrate_add_service_fields.py
Normal file
@@ -0,0 +1,45 @@
|
||||
#!/usr/bin/env python3
|
||||
"""数据库迁移脚本:为algorithm_services表添加技术分类和输出类型字段"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from sqlalchemy import text
|
||||
from app.models.database import engine
|
||||
|
||||
def migrate():
|
||||
"""执行数据库迁移"""
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
# 检查字段是否已存在(PostgreSQL语法)
|
||||
result = conn.execute(text("""
|
||||
SELECT column_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = 'algorithm_services'
|
||||
"""))
|
||||
columns = [row[0] for row in result.fetchall()]
|
||||
|
||||
# 添加 tech_category 字段
|
||||
if 'tech_category' not in columns:
|
||||
conn.execute(text("ALTER TABLE algorithm_services ADD COLUMN tech_category VARCHAR(50) DEFAULT 'computer_vision'"))
|
||||
print("✓ 已添加 tech_category 字段到 algorithm_services 表")
|
||||
else:
|
||||
print("✓ tech_category 字段已存在于 algorithm_services 表")
|
||||
|
||||
# 添加 output_type 字段
|
||||
if 'output_type' not in columns:
|
||||
conn.execute(text("ALTER TABLE algorithm_services ADD COLUMN output_type VARCHAR(50) DEFAULT 'image'"))
|
||||
print("✓ 已添加 output_type 字段到 algorithm_services 表")
|
||||
else:
|
||||
print("✓ output_type 字段已存在于 algorithm_services 表")
|
||||
|
||||
conn.commit()
|
||||
print("\n数据库迁移完成!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"数据库迁移失败: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
migrate()
|
||||
48
backend/test_all_apis.py
Normal file
48
backend/test_all_apis.py
Normal file
@@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python3
|
||||
"""测试所有API端点"""
|
||||
|
||||
import requests
|
||||
|
||||
def test_apis():
|
||||
"""测试API端点"""
|
||||
base_url = "http://localhost:8001/api/v1"
|
||||
|
||||
# 测试算法列表(不需要认证)
|
||||
print("1. 测试算法列表(不需要认证):")
|
||||
try:
|
||||
response = requests.get(f"{base_url}/algorithms/")
|
||||
print(f" 状态码: {response.status_code}")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
print(f" 成功获取 {len(data.get('algorithms', []))} 个算法")
|
||||
else:
|
||||
print(f" 失败: {response.text}")
|
||||
except Exception as e:
|
||||
print(f" 错误: {e}")
|
||||
|
||||
# 测试用户信息(需要认证)
|
||||
print("\n2. 测试用户信息(需要认证):")
|
||||
try:
|
||||
response = requests.get(f"{base_url}/users/me")
|
||||
print(f" 状态码: {response.status_code}")
|
||||
if response.status_code == 401:
|
||||
print(f" 需要认证(正常)")
|
||||
else:
|
||||
print(f" 响应: {response.text}")
|
||||
except Exception as e:
|
||||
print(f" 错误: {e}")
|
||||
|
||||
# 测试服务列表(需要认证)
|
||||
print("\n3. 测试服务列表(需要认证):")
|
||||
try:
|
||||
response = requests.get(f"{base_url}/services")
|
||||
print(f" 状态码: {response.status_code}")
|
||||
if response.status_code == 401:
|
||||
print(f" 需要认证(正常)")
|
||||
else:
|
||||
print(f" 响应: {response.text[:200]}")
|
||||
except Exception as e:
|
||||
print(f" 错误: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_apis()
|
||||
53
backend/test_api.py
Normal file
53
backend/test_api.py
Normal file
@@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env python3
|
||||
"""测试前端API调用"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
import requests
|
||||
|
||||
def test_api():
|
||||
"""测试API"""
|
||||
try:
|
||||
# 调用算法列表API
|
||||
response = requests.get('http://localhost:8001/api/v1/algorithms/')
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
algorithms = data.get('algorithms', [])
|
||||
|
||||
print(f"成功获取 {len(algorithms)} 个算法\n")
|
||||
|
||||
# 检查每个算法的字段
|
||||
for algo in algorithms:
|
||||
print(f"算法: {algo['name']}")
|
||||
print(f" 技术分类: {algo.get('tech_category', 'N/A')}")
|
||||
print(f" 输出类型: {algo.get('output_type', 'N/A')}")
|
||||
print()
|
||||
|
||||
# 测试筛选
|
||||
print("测试筛选功能:")
|
||||
|
||||
# 按技术分类筛选
|
||||
cv_algorithms = [a for a in algorithms if a.get('tech_category') == 'computer_vision']
|
||||
print(f" 计算机视觉算法: {len(cv_algorithms)} 个")
|
||||
|
||||
# 按输出类型筛选
|
||||
image_algorithms = [a for a in algorithms if a.get('output_type') == 'image']
|
||||
print(f" 图片输出算法: {len(image_algorithms)} 个")
|
||||
|
||||
# 按名称搜索
|
||||
search_results = [a for a in algorithms if '视频' in a.get('name', '')]
|
||||
print(f" 包含'视频'的算法: {len(search_results)} 个")
|
||||
|
||||
else:
|
||||
print(f"API调用失败: {response.status_code}")
|
||||
print(response.text)
|
||||
|
||||
except Exception as e:
|
||||
print(f"测试失败: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_api()
|
||||
32
backend/test_frontend_proxy.py
Normal file
32
backend/test_frontend_proxy.py
Normal file
@@ -0,0 +1,32 @@
|
||||
#!/usr/bin/env python3
|
||||
"""测试前端代理配置"""
|
||||
|
||||
import requests
|
||||
|
||||
def test_frontend_proxy():
|
||||
"""测试前端代理"""
|
||||
try:
|
||||
# 测试前端代理
|
||||
response = requests.get('http://localhost:3000/api/algorithms')
|
||||
|
||||
print(f"状态码: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
print(f"成功获取 {len(data.get('algorithms', []))} 个算法")
|
||||
|
||||
# 检查第一个算法的字段
|
||||
if data.get('algorithms'):
|
||||
first_algo = data['algorithms'][0]
|
||||
print(f"\n第一个算法:")
|
||||
print(f" 名称: {first_algo.get('name')}")
|
||||
print(f" 技术分类: {first_algo.get('tech_category')}")
|
||||
print(f" 输出类型: {first_algo.get('output_type')}")
|
||||
else:
|
||||
print(f"请求失败: {response.text}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"测试失败: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_frontend_proxy()
|
||||
48
backend/test_full_login.py
Normal file
48
backend/test_full_login.py
Normal file
@@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python3
|
||||
"""测试完整的登录流程"""
|
||||
|
||||
import requests
|
||||
|
||||
def test_full_login_flow():
|
||||
"""测试完整的登录流程"""
|
||||
base_url = "http://localhost:8001/api/v1"
|
||||
|
||||
# 步骤1: 登录
|
||||
print("步骤1: 登录")
|
||||
login_data = {
|
||||
"username": "admin",
|
||||
"password": "admin123"
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(f"{base_url}/users/login", json=login_data)
|
||||
print(f"状态码: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
print(f"登录失败: {response.text}")
|
||||
return
|
||||
|
||||
data = response.json()
|
||||
access_token = data.get('access_token')
|
||||
print(f"登录成功!")
|
||||
print(f"Token: {access_token[:50]}...")
|
||||
|
||||
# 步骤2: 使用token获取用户信息
|
||||
print("\n步骤2: 获取用户信息")
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
user_response = requests.get(f"{base_url}/users/me", headers=headers)
|
||||
print(f"状态码: {user_response.status_code}")
|
||||
|
||||
if user_response.status_code == 200:
|
||||
user_data = user_response.json()
|
||||
print(f"用户名: {user_data.get('username', 'N/A')}")
|
||||
print(f"邮箱: {user_data.get('email', 'N/A')}")
|
||||
print(f"角色: {user_data.get('role_name', 'N/A')}")
|
||||
else:
|
||||
print(f"获取用户信息失败: {user_response.text}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"错误: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_full_login_flow()
|
||||
45
backend/test_login.py
Normal file
45
backend/test_login.py
Normal file
@@ -0,0 +1,45 @@
|
||||
#!/usr/bin/env python3
|
||||
"""测试登录功能"""
|
||||
|
||||
import requests
|
||||
|
||||
def test_login():
|
||||
"""测试登录"""
|
||||
base_url = "http://localhost:8001/api/v1"
|
||||
|
||||
# 测试登录
|
||||
print("测试登录功能:")
|
||||
login_data = {
|
||||
"username": "admin",
|
||||
"password": "admin123"
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(f"{base_url}/users/login", json=login_data)
|
||||
print(f"状态码: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
print(f"登录成功!")
|
||||
print(f"访问令牌: {data.get('access_token', 'N/A')[:50]}...")
|
||||
print(f"令牌类型: {data.get('token_type', 'N/A')}")
|
||||
|
||||
# 测试使用令牌访问受保护的API
|
||||
if data.get('access_token'):
|
||||
headers = {"Authorization": f"Bearer {data['access_token']}"}
|
||||
user_response = requests.get(f"{base_url}/users/me", headers=headers)
|
||||
print(f"\n测试用户信息API:")
|
||||
print(f"状态码: {user_response.status_code}")
|
||||
if user_response.status_code == 200:
|
||||
user_data = user_response.json()
|
||||
print(f"用户名: {user_data.get('username', 'N/A')}")
|
||||
print(f"邮箱: {user_data.get('email', 'N/A')}")
|
||||
else:
|
||||
print(f"失败: {user_response.text}")
|
||||
else:
|
||||
print(f"登录失败: {response.text}")
|
||||
except Exception as e:
|
||||
print(f"错误: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_login()
|
||||
53
backend/test_login_api.py
Normal file
53
backend/test_login_api.py
Normal file
@@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env python3
|
||||
"""直接测试登录API"""
|
||||
|
||||
import requests
|
||||
|
||||
def test_login_api():
|
||||
"""测试登录API"""
|
||||
base_url = "http://localhost:8001/api/v1"
|
||||
|
||||
# 测试1: 使用JSON格式
|
||||
print("测试1: 使用JSON格式")
|
||||
login_data = {
|
||||
"username": "admin",
|
||||
"password": "admin123"
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(f"{base_url}/users/login", json=login_data)
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"响应头: {dict(response.headers)}")
|
||||
print(f"响应内容: {response.text[:500]}")
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
print(f"✅ 登录成功!")
|
||||
print(f"Token: {data.get('access_token', 'N/A')[:50]}...")
|
||||
else:
|
||||
print(f"❌ 登录失败")
|
||||
except Exception as e:
|
||||
print(f"错误: {e}")
|
||||
|
||||
# 测试2: 使用form-data格式
|
||||
print("\n\n测试2: 使用form-data格式")
|
||||
form_data = {
|
||||
"username": "admin",
|
||||
"password": "admin123"
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(f"{base_url}/users/login", data=form_data)
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"响应内容: {response.text[:500]}")
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
print(f"✅ 登录成功!")
|
||||
else:
|
||||
print(f"❌ 登录失败")
|
||||
except Exception as e:
|
||||
print(f"错误: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_login_api()
|
||||
232
backend/test_system.py
Normal file
232
backend/test_system.py
Normal file
@@ -0,0 +1,232 @@
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, Any, List
|
||||
|
||||
class SystemTester:
|
||||
def __init__(self, base_url: str = "http://localhost:8001/api/v1"):
|
||||
self.base_url = base_url
|
||||
self.session = requests.Session()
|
||||
self.token = None
|
||||
self.user_id = None
|
||||
|
||||
def login(self, username: str = "admin", password: str = "admin123") -> bool:
|
||||
"""登录系统"""
|
||||
try:
|
||||
response = self.session.post(
|
||||
f"{self.base_url}/users/login",
|
||||
json={"username": username, "password": password}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
self.token = data.get("access_token")
|
||||
self.user_id = data.get("user_id")
|
||||
self.session.headers.update({"Authorization": f"Bearer {self.token}"})
|
||||
print(f"✓ 登录成功: {username}")
|
||||
return True
|
||||
else:
|
||||
print(f"✗ 登录失败: {response.status_code} - {response.text}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"✗ 登录异常: {str(e)}")
|
||||
return False
|
||||
|
||||
def test_config_endpoints(self) -> bool:
|
||||
"""测试配置管理API"""
|
||||
print("\n=== 测试配置管理API ===")
|
||||
success = True
|
||||
|
||||
try:
|
||||
# 测试获取所有配置
|
||||
response = self.session.get(f"{self.base_url}/config/")
|
||||
if response.status_code == 200:
|
||||
print("✓ 获取所有配置成功")
|
||||
configs = response.json().get("configs", [])
|
||||
print(f" 当前配置数量: {len(configs)}")
|
||||
else:
|
||||
print(f"✗ 获取所有配置失败: {response.status_code}")
|
||||
success = False
|
||||
|
||||
# 测试添加配置
|
||||
test_config = {
|
||||
"value": "test_value_123",
|
||||
"type": "system",
|
||||
"service_id": None,
|
||||
"description": "测试配置"
|
||||
}
|
||||
response = self.session.post(f"{self.base_url}/config/test_config_key", json=test_config)
|
||||
if response.status_code == 200:
|
||||
print("✓ 添加配置成功")
|
||||
else:
|
||||
print(f"✗ 添加配置失败: {response.status_code} - {response.text}")
|
||||
success = False
|
||||
|
||||
# 测试获取单个配置
|
||||
response = self.session.get(f"{self.base_url}/config/test_config_key")
|
||||
if response.status_code == 200:
|
||||
print("✓ 获取单个配置成功")
|
||||
config_data = response.json()
|
||||
print(f" 配置值: {config_data.get('value')}")
|
||||
else:
|
||||
print(f"✗ 获取单个配置失败: {response.status_code}")
|
||||
success = False
|
||||
|
||||
# 测试删除配置
|
||||
response = self.session.delete(f"{self.base_url}/config/test_config_key")
|
||||
if response.status_code == 200:
|
||||
print("✓ 删除配置成功")
|
||||
else:
|
||||
print(f"✗ 删除配置失败: {response.status_code}")
|
||||
success = False
|
||||
|
||||
return success
|
||||
except Exception as e:
|
||||
print(f"✗ 配置管理API测试异常: {str(e)}")
|
||||
return False
|
||||
|
||||
def test_comparison_endpoints(self) -> bool:
|
||||
"""测试算法比较API"""
|
||||
print("\n=== 测试算法比较API ===")
|
||||
success = True
|
||||
|
||||
try:
|
||||
# 测试算法比较(使用模拟数据)
|
||||
test_data = {
|
||||
"input_data": {"text": "这是一段测试文本"},
|
||||
"algorithm_configs": [
|
||||
{
|
||||
"algorithm_id": "test_algo_1",
|
||||
"algorithm_name": "测试算法1",
|
||||
"version": "1.0.0",
|
||||
"config": "{}"
|
||||
},
|
||||
{
|
||||
"algorithm_id": "test_algo_2",
|
||||
"algorithm_name": "测试算法2",
|
||||
"version": "1.0.0",
|
||||
"config": "{}"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = self.session.post(f"{self.base_url}/comparison/compare-algorithms", json=test_data)
|
||||
if response.status_code == 200:
|
||||
print("✓ 算法比较API调用成功")
|
||||
result = response.json()
|
||||
print(f" 比较状态: {result.get('success')}")
|
||||
if result.get('results'):
|
||||
print(f" 结果数量: {len(result.get('results'))}")
|
||||
else:
|
||||
print(f"✗ 算法比较失败: {response.status_code} - {response.text}")
|
||||
success = False
|
||||
|
||||
return success
|
||||
except Exception as e:
|
||||
print(f"✗ 算法比较API测试异常: {str(e)}")
|
||||
return False
|
||||
|
||||
def test_existing_endpoints(self) -> bool:
|
||||
"""测试现有API端点"""
|
||||
print("\n=== 测试现有API端点 ===")
|
||||
success = True
|
||||
|
||||
try:
|
||||
# 测试健康检查
|
||||
response = self.session.get(f"{self.base_url.replace('/api/v1', '')}/health")
|
||||
if response.status_code == 200:
|
||||
print("✓ 健康检查通过")
|
||||
else:
|
||||
print(f"✗ 健康检查失败: {response.status_code}")
|
||||
success = False
|
||||
|
||||
# 测试获取当前用户
|
||||
response = self.session.get(f"{self.base_url}/users/me")
|
||||
if response.status_code == 200:
|
||||
print("✓ 获取当前用户成功")
|
||||
user_data = response.json()
|
||||
print(f" 用户名: {user_data.get('username')}")
|
||||
else:
|
||||
print(f"✗ 获取当前用户失败: {response.status_code}")
|
||||
success = False
|
||||
|
||||
# 测试获取算法列表
|
||||
response = self.session.get(f"{self.base_url}/algorithms/")
|
||||
if response.status_code == 200:
|
||||
print("✓ 获取算法列表成功")
|
||||
algorithms = response.json()
|
||||
print(f" 算法数量: {len(algorithms) if isinstance(algorithms, list) else 0}")
|
||||
else:
|
||||
print(f"✗ 获取算法列表失败: {response.status_code}")
|
||||
success = False
|
||||
|
||||
# 测试获取服务列表
|
||||
response = self.session.get(f"{self.base_url}/services")
|
||||
if response.status_code == 200:
|
||||
print("✓ 获取服务列表成功")
|
||||
services = response.json()
|
||||
print(f" 服务数量: {len(services) if isinstance(services, list) else 0}")
|
||||
else:
|
||||
print(f"✗ 获取服务列表失败: {response.status_code}")
|
||||
success = False
|
||||
|
||||
return success
|
||||
except Exception as e:
|
||||
print(f"✗ 现有API端点测试异常: {str(e)}")
|
||||
return False
|
||||
|
||||
def run_all_tests(self) -> Dict[str, bool]:
|
||||
"""运行所有测试"""
|
||||
print("=" * 50)
|
||||
print("开始系统自动化测试")
|
||||
print("=" * 50)
|
||||
|
||||
results = {}
|
||||
|
||||
# 登录
|
||||
if not self.login():
|
||||
print("\n✗ 登录失败,无法继续测试")
|
||||
return {"login": False}
|
||||
|
||||
results["login"] = True
|
||||
|
||||
# 测试现有端点
|
||||
results["existing_endpoints"] = self.test_existing_endpoints()
|
||||
|
||||
# 测试配置管理API
|
||||
results["config_endpoints"] = self.test_config_endpoints()
|
||||
|
||||
# 测试算法比较API
|
||||
results["comparison_endpoints"] = self.test_comparison_endpoints()
|
||||
|
||||
# 输出测试结果
|
||||
print("\n" + "=" * 50)
|
||||
print("测试结果汇总")
|
||||
print("=" * 50)
|
||||
|
||||
for test_name, result in results.items():
|
||||
status = "✓ 通过" if result else "✗ 失败"
|
||||
print(f"{test_name}: {status}")
|
||||
|
||||
total_tests = len(results)
|
||||
passed_tests = sum(1 for result in results.values() if result)
|
||||
|
||||
print(f"\n总计: {passed_tests}/{total_tests} 测试通过")
|
||||
|
||||
if passed_tests == total_tests:
|
||||
print("🎉 所有测试通过!")
|
||||
else:
|
||||
print("⚠️ 部分测试失败,请检查日志")
|
||||
|
||||
return results
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
tester = SystemTester()
|
||||
results = tester.run_all_tests()
|
||||
|
||||
# 返回退出码
|
||||
exit_code = 0 if all(results.values()) else 1
|
||||
return exit_code
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
Reference in New Issue
Block a user