first commit
This commit is contained in:
5
backend/.env
Normal file
5
backend/.env
Normal file
@@ -0,0 +1,5 @@
|
||||
# Gitea 配置
|
||||
GITEA_SERVER_URL=https://gitea.swiftsnake.cn
|
||||
GITEA_ACCESS_TOKEN=26ccc228c6624f98d6dd629365be052e161b0da3
|
||||
GITEA_DEFAULT_OWNER=yipai-tech
|
||||
GITEA_REPO_PREFIX=AI
|
||||
28
backend/Dockerfile
Normal file
28
backend/Dockerfile
Normal file
@@ -0,0 +1,28 @@
|
||||
FROM python:3.9-slim
|
||||
|
||||
# 设置工作目录
|
||||
WORKDIR /app
|
||||
|
||||
# 安装系统依赖
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
libpq-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 复制依赖文件
|
||||
COPY requirements.txt .
|
||||
|
||||
# 安装Python依赖
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 复制应用代码
|
||||
COPY . .
|
||||
|
||||
# 设置环境变量
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE 8000
|
||||
|
||||
# 启动应用
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
|
||||
32
backend/Dockerfile.custom
Normal file
32
backend/Dockerfile.custom
Normal file
@@ -0,0 +1,32 @@
|
||||
FROM crpi-x2l5uviq1k8hji3c.ap-northeast-1.personal.cr.aliyuncs.com/yipaidocker-images/linux_arm64_python:3.9-slim
|
||||
|
||||
# 设置工作目录
|
||||
WORKDIR /app
|
||||
|
||||
# 清除代理设置
|
||||
ENV HTTP_PROXY=""
|
||||
ENV http_proxy=""
|
||||
ENV HTTPS_PROXY=""
|
||||
ENV https_proxy=""
|
||||
ENV NO_PROXY="localhost,127.0.0.1"
|
||||
ENV no_proxy="localhost,127.0.0.1"
|
||||
|
||||
# 复制依赖文件
|
||||
COPY requirements.txt .
|
||||
|
||||
# 升级pip并安装Python依赖
|
||||
RUN unset HTTP_PROXY && unset http_proxy && unset HTTPS_PROXY && unset https_proxy && \
|
||||
pip install --upgrade pip && \
|
||||
pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 复制应用代码
|
||||
COPY . .
|
||||
|
||||
# 设置环境变量
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE 8000
|
||||
|
||||
# 启动应用
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
0
backend/app/__init__.py
Normal file
0
backend/app/__init__.py
Normal file
BIN
backend/app/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
backend/app/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
backend/app/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/__pycache__/dependencies.cpython-312.pyc
Normal file
BIN
backend/app/__pycache__/dependencies.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/__pycache__/dependencies.cpython-39.pyc
Normal file
BIN
backend/app/__pycache__/dependencies.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/__pycache__/gateway.cpython-312.pyc
Normal file
BIN
backend/app/__pycache__/gateway.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/__pycache__/gateway.cpython-39.pyc
Normal file
BIN
backend/app/__pycache__/gateway.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/__pycache__/main.cpython-312.pyc
Normal file
BIN
backend/app/__pycache__/main.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/__pycache__/main.cpython-39.pyc
Normal file
BIN
backend/app/__pycache__/main.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/config/__pycache__/settings.cpython-312.pyc
Normal file
BIN
backend/app/config/__pycache__/settings.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/config/__pycache__/settings.cpython-39.pyc
Normal file
BIN
backend/app/config/__pycache__/settings.cpython-39.pyc
Normal file
Binary file not shown.
53
backend/app/config/settings.py
Normal file
53
backend/app/config/settings.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""应用配置类"""
|
||||
# 应用基本配置
|
||||
APP_NAME: str = "智能算法展示平台"
|
||||
APP_VERSION: str = "1.0.0"
|
||||
DEBUG: bool = True
|
||||
|
||||
# 数据库配置
|
||||
DATABASE_URL: str = "postgresql://admin:password@localhost:5432/algorithm_db"
|
||||
|
||||
# Redis配置
|
||||
REDIS_URL: str = "redis://localhost:6379/0"
|
||||
|
||||
# MinIO配置
|
||||
MINIO_ENDPOINT: str = "localhost:9000"
|
||||
MINIO_ACCESS_KEY: str = "minioadmin"
|
||||
MINIO_SECRET_KEY: str = "minioadmin"
|
||||
MINIO_BUCKET_NAME: str = "algorithm-data"
|
||||
MINIO_SECURE: bool = False
|
||||
|
||||
# JWT配置
|
||||
SECRET_KEY: str = "your-secret-key-here"
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
|
||||
# OpenAI配置
|
||||
OPENAI_API_KEY: Optional[str] = None
|
||||
OPENAI_MODEL: str = "gpt-3.5-turbo"
|
||||
|
||||
# CORS配置
|
||||
CORS_ORIGINS: list = ["*"]
|
||||
|
||||
# API配置
|
||||
API_V1_STR: str = "/api/v1"
|
||||
|
||||
# Gitea 配置
|
||||
GITEA_SERVER_URL: str = ""
|
||||
GITEA_ACCESS_TOKEN: str = ""
|
||||
GITEA_DEFAULT_OWNER: str = ""
|
||||
GITEA_REPO_PREFIX: str = "AI"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = True
|
||||
extra = "allow" # 允许额外的环境变量
|
||||
|
||||
|
||||
# 创建全局配置实例
|
||||
settings = Settings()
|
||||
28
backend/app/dependencies.py
Normal file
28
backend/app/dependencies.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
通用依赖项模块,包含常用的FastAPI依赖项函数
|
||||
"""
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.database import get_db
|
||||
from app.services.user import UserService
|
||||
from app.schemas.user import UserResponse
|
||||
|
||||
# OAuth2密码Bearer
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/users/login")
|
||||
|
||||
|
||||
async def get_current_active_user(db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)):
|
||||
"""获取当前活跃用户"""
|
||||
user = UserService.get_current_user(db, token)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
if hasattr(user, 'status') and user.status != "active":
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return user
|
||||
144
backend/app/gateway.py
Normal file
144
backend/app/gateway.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""API网关模块,负责请求路由、认证授权、流量控制等功能"""
|
||||
|
||||
from fastapi import Request, HTTPException, status, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing import Dict, Any, Optional
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
from urllib.parse import urljoin
|
||||
import httpx
|
||||
|
||||
from app.config.settings import settings
|
||||
# Note: get_current_active_user is not used in this module, removing the import
|
||||
from app.models.database import get_db
|
||||
from app.services.algorithm import AlgorithmVersionService
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APIGateway:
|
||||
"""API网关类,处理请求路由、认证授权、流量控制等功能"""
|
||||
|
||||
def __init__(self):
|
||||
self.request_counts = {} # 存储请求计数,实际生产环境应使用Redis
|
||||
|
||||
async def authenticate_request(self, request: Request) -> Optional[Dict[str, Any]]:
|
||||
"""验证请求的认证信息"""
|
||||
try:
|
||||
# 从请求头获取token
|
||||
token = request.headers.get("Authorization", "").replace("Bearer ", "")
|
||||
if not token:
|
||||
return None
|
||||
|
||||
# 验证token有效性(这里只是示例,实际应该调用认证服务)
|
||||
# 在实际实现中,可能需要连接数据库验证API密钥
|
||||
# 或者解析JWT token
|
||||
return {"user_id": "temp_user_id", "role": "user"} # 临时返回值
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication error: {str(e)}")
|
||||
return None
|
||||
|
||||
async def check_rate_limit(self, user_id: str, algorithm_id: str) -> bool:
|
||||
"""检查用户对特定算法的请求频率限制"""
|
||||
key = f"{user_id}:{algorithm_id}"
|
||||
current_time = time.time()
|
||||
|
||||
if key not in self.request_counts:
|
||||
self.request_counts[key] = []
|
||||
|
||||
# 清除超过时间窗口的请求记录
|
||||
self.request_counts[key] = [
|
||||
req_time for req_time in self.request_counts[key]
|
||||
if current_time - req_time < settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 # 30分钟窗口
|
||||
]
|
||||
|
||||
# 检查是否超过限制(例如每分钟最多10次请求)
|
||||
if len(self.request_counts[key]) >= 10:
|
||||
return False
|
||||
|
||||
# 添加当前请求记录
|
||||
self.request_counts[key].append(current_time)
|
||||
return True
|
||||
|
||||
async def route_request(self, algorithm_id: str, version_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""路由请求到相应的算法服务"""
|
||||
try:
|
||||
# 获取算法版本信息
|
||||
db = next(get_db())
|
||||
version_info = AlgorithmVersionService.get_version_by_id(db, version_id)
|
||||
|
||||
if not version_info:
|
||||
raise HTTPException(status_code=404, detail="Algorithm version not found")
|
||||
|
||||
# 在实际实现中,这里会根据version_info.url将请求转发到对应的算法服务
|
||||
# 现在我们模拟调用过程
|
||||
algorithm_url = version_info.url if hasattr(version_info, 'url') else f"http://localhost:8001/algorithms/{algorithm_id}/execute"
|
||||
|
||||
# 使用httpx调用算法服务
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(
|
||||
algorithm_url,
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=f"Algorithm service error: {response.text}"
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request to algorithm service failed: {str(e)}")
|
||||
raise HTTPException(status_code=502, detail="Algorithm service unavailable")
|
||||
except Exception as e:
|
||||
logger.error(f"Routing error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
# 全局API网关实例
|
||||
api_gateway = APIGateway()
|
||||
|
||||
|
||||
async def gateway_middleware(request: Request):
|
||||
"""网关中间件,处理认证和路由"""
|
||||
# 认证检查
|
||||
user_info = await api_gateway.authenticate_request(request)
|
||||
if not user_info:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
|
||||
|
||||
# 这里可以添加其他中间件逻辑,如权限检查、流量控制等
|
||||
return user_info
|
||||
|
||||
|
||||
async def call_algorithm_gateway(
|
||||
algorithm_id: str,
|
||||
version_id: str,
|
||||
payload: Dict[str, Any],
|
||||
user_info: Dict[str, Any] = Depends(gateway_middleware)
|
||||
):
|
||||
"""通过网关调用算法的主函数"""
|
||||
# 检查速率限制
|
||||
rate_limited = await api_gateway.check_rate_limit(user_info['user_id'], algorithm_id)
|
||||
if not rate_limited:
|
||||
raise HTTPException(status_code=429, detail="Rate limit exceeded")
|
||||
|
||||
# 路由请求到算法服务
|
||||
result = await api_gateway.route_request(algorithm_id, version_id, payload)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 辅助函数:获取用户信息而不强制要求认证
|
||||
async def get_optional_user(request: Request):
|
||||
"""获取用户信息,如果未认证则返回None"""
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header.replace("Bearer ", "")
|
||||
# 这里应该是实际的token验证逻辑
|
||||
return {"user_id": "temp_user_id", "role": "user"} # 临时返回值
|
||||
return None
|
||||
BIN
backend/app/gitea/__pycache__/client.cpython-312.pyc
Normal file
BIN
backend/app/gitea/__pycache__/client.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/gitea/__pycache__/service.cpython-312.pyc
Normal file
BIN
backend/app/gitea/__pycache__/service.cpython-312.pyc
Normal file
Binary file not shown.
217
backend/app/gitea/client.py
Normal file
217
backend/app/gitea/client.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""Gitea API客户端,用于与Gitea服务器通信"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GiteaClient:
|
||||
"""Gitea API客户端类"""
|
||||
|
||||
def __init__(self, server_url: str, access_token: str):
|
||||
"""初始化Gitea客户端
|
||||
|
||||
Args:
|
||||
server_url: Gitea服务器URL
|
||||
access_token: Gitea访问令牌
|
||||
"""
|
||||
self.server_url = server_url.rstrip('/')
|
||||
self.access_token = access_token
|
||||
self.api_url = f"{self.server_url}/api/v1"
|
||||
self.headers = {
|
||||
"Authorization": f"token {self.access_token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
def _request(self, method: str, endpoint: str, data: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
|
||||
"""发送API请求
|
||||
|
||||
Args:
|
||||
method: HTTP方法(GET, POST, PUT, DELETE)
|
||||
endpoint: API端点
|
||||
data: 请求数据
|
||||
|
||||
Returns:
|
||||
API响应数据
|
||||
"""
|
||||
url = f"{self.api_url}/{endpoint}"
|
||||
|
||||
try:
|
||||
if method == "GET":
|
||||
response = requests.get(url, headers=self.headers, params=data)
|
||||
elif method == "POST":
|
||||
response = requests.post(url, headers=self.headers, json=data)
|
||||
elif method == "PUT":
|
||||
response = requests.put(url, headers=self.headers, json=data)
|
||||
elif method == "PATCH":
|
||||
response = requests.patch(url, headers=self.headers, json=data)
|
||||
elif method == "DELETE":
|
||||
response = requests.delete(url, headers=self.headers)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
if response.content:
|
||||
return response.json()
|
||||
return None
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Gitea API request failed: {str(e)}")
|
||||
return None
|
||||
|
||||
def create_repository(self, owner: str, name: str, description: str = "", private: bool = False) -> Optional[Dict[str, Any]]:
|
||||
"""创建仓库
|
||||
|
||||
Args:
|
||||
owner: 所有者(用户或组织)
|
||||
name: 仓库名称
|
||||
description: 仓库描述
|
||||
private: 是否私有
|
||||
|
||||
Returns:
|
||||
创建的仓库信息
|
||||
"""
|
||||
data = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"private": private,
|
||||
"auto_init": True
|
||||
}
|
||||
|
||||
# 优先尝试在指定的 owner(组织)下创建仓库
|
||||
logger.info(f"Attempting to create repository for owner: {owner}")
|
||||
owner_response = self._request("POST", f"org/{owner}/repos", data)
|
||||
if owner_response:
|
||||
logger.info(f"Repository created successfully under owner: {owner}")
|
||||
return owner_response
|
||||
|
||||
# 如果组织创建失败,尝试在用户下创建仓库
|
||||
logger.info(f"Organization creation failed, trying user account")
|
||||
user_response = self._request("POST", f"user/repos", data)
|
||||
if user_response:
|
||||
logger.info(f"Repository created successfully in user account")
|
||||
return user_response
|
||||
|
||||
logger.error(f"Failed to create repository for owner {owner}")
|
||||
return None
|
||||
|
||||
def get_repository(self, owner: str, repo: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取仓库信息
|
||||
|
||||
Args:
|
||||
owner: 所有者(用户或组织)
|
||||
repo: 仓库名称
|
||||
|
||||
Returns:
|
||||
仓库信息
|
||||
"""
|
||||
return self._request("GET", f"repos/{owner}/{repo}")
|
||||
|
||||
def list_repositories(self, owner: str) -> Optional[List[Dict[str, Any]]]:
|
||||
"""列出用户或组织的仓库
|
||||
|
||||
Args:
|
||||
owner: 所有者(用户或组织)
|
||||
|
||||
Returns:
|
||||
仓库列表
|
||||
"""
|
||||
return self._request("GET", f"users/{owner}/repos")
|
||||
|
||||
def delete_repository(self, owner: str, repo: str) -> bool:
|
||||
"""删除仓库
|
||||
|
||||
Args:
|
||||
owner: 所有者(用户或组织)
|
||||
repo: 仓库名称
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
response = self._request("DELETE", f"repos/{owner}/{repo}")
|
||||
return response is not None
|
||||
|
||||
def create_file(self, owner: str, repo: str, path: str, content: str, message: str) -> Optional[Dict[str, Any]]:
|
||||
"""创建或更新文件
|
||||
|
||||
Args:
|
||||
owner: 所有者(用户或组织)
|
||||
repo: 仓库名称
|
||||
path: 文件路径
|
||||
content: 文件内容(base64编码)
|
||||
message: 提交消息
|
||||
|
||||
Returns:
|
||||
操作结果
|
||||
"""
|
||||
data = {
|
||||
"content": content,
|
||||
"message": message
|
||||
}
|
||||
|
||||
return self._request("POST", f"repos/{owner}/{repo}/contents/{path}", data)
|
||||
|
||||
def get_file(self, owner: str, repo: str, path: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取文件内容
|
||||
|
||||
Args:
|
||||
owner: 所有者(用户或组织)
|
||||
repo: 仓库名称
|
||||
path: 文件路径
|
||||
|
||||
Returns:
|
||||
文件信息和内容
|
||||
"""
|
||||
return self._request("GET", f"repos/{owner}/{repo}/contents/{path}")
|
||||
|
||||
def get_repository_files(self, owner: str, repo: str, path: str = "") -> Optional[List[Dict[str, Any]]]:
|
||||
"""获取仓库文件列表
|
||||
|
||||
Args:
|
||||
owner: 所有者(用户或组织)
|
||||
repo: 仓库名称
|
||||
path: 目录路径(默认为根目录)
|
||||
|
||||
Returns:
|
||||
文件列表
|
||||
"""
|
||||
result = self._request("GET", f"repos/{owner}/{repo}/contents/{path}")
|
||||
return result if isinstance(result, list) else None
|
||||
|
||||
|
||||
|
||||
def check_connection(self) -> bool:
|
||||
"""检查与Gitea服务器的连接
|
||||
|
||||
Returns:
|
||||
是否连接成功
|
||||
"""
|
||||
response = self._request("GET", "user")
|
||||
return response is not None
|
||||
|
||||
def update_repository(self, owner: str, repo: str, name: Optional[str] = None, description: Optional[str] = None, private: Optional[bool] = None) -> Optional[Dict[str, Any]]:
|
||||
"""更新仓库信息
|
||||
|
||||
Args:
|
||||
owner: 所有者(用户或组织)
|
||||
repo: 仓库名称
|
||||
name: 新的仓库名称(可选)
|
||||
description: 新的仓库描述(可选)
|
||||
private: 是否私有(可选)
|
||||
|
||||
Returns:
|
||||
更新后的仓库信息
|
||||
"""
|
||||
data = {}
|
||||
if name is not None:
|
||||
data["name"] = name
|
||||
if description is not None:
|
||||
data["description"] = description
|
||||
if private is not None:
|
||||
data["private"] = private
|
||||
|
||||
return self._request("PATCH", f"repos/{owner}/{repo}", data)
|
||||
1253
backend/app/gitea/service.py
Normal file
1253
backend/app/gitea/service.py
Normal file
File diff suppressed because it is too large
Load Diff
867
backend/app/gitea/service_backup.py
Normal file
867
backend/app/gitea/service_backup.py
Normal file
@@ -0,0 +1,867 @@
|
||||
"""Gitea服务,处理与Gitea相关的业务逻辑"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import logging
|
||||
import base64
|
||||
from typing import Optional, Dict, Any, List
|
||||
import uuid
|
||||
|
||||
from app.gitea.client import GiteaClient
|
||||
from app.config.settings import settings
|
||||
from app.models.database import SessionLocal
|
||||
from app.models.models import GiteaConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GiteaService:
|
||||
"""Gitea服务类"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化Gitea服务"""
|
||||
self.config = self._load_config()
|
||||
self.client = None
|
||||
if self.config:
|
||||
self.client = GiteaClient(
|
||||
self.config['server_url'],
|
||||
self.config['access_token']
|
||||
)
|
||||
|
||||
def _load_config(self) -> Optional[Dict[str, Any]]:
|
||||
"""加载Gitea配置
|
||||
|
||||
Returns:
|
||||
Gitea配置信息
|
||||
"""
|
||||
try:
|
||||
db = SessionLocal()
|
||||
# 从数据库中获取配置(只取第一个配置)
|
||||
config = db.query(GiteaConfig).filter_by(status="active").first()
|
||||
db.close()
|
||||
|
||||
if config:
|
||||
return {
|
||||
'id': config.id,
|
||||
'server_url': config.server_url,
|
||||
'access_token': config.access_token,
|
||||
'default_owner': config.default_owner,
|
||||
'repo_prefix': config.repo_prefix,
|
||||
'status': config.status
|
||||
}
|
||||
|
||||
# 配置不存在时返回默认值
|
||||
return {
|
||||
'server_url': getattr(settings, 'GITEA_SERVER_URL', ''),
|
||||
'access_token': getattr(settings, 'GITEA_ACCESS_TOKEN', ''),
|
||||
'default_owner': getattr(settings, 'GITEA_DEFAULT_OWNER', ''),
|
||||
'repo_prefix': getattr(settings, 'GITEA_REPO_PREFIX', '')
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load Gitea config from database: {str(e)}")
|
||||
# 出错时返回默认配置
|
||||
return {
|
||||
'server_url': getattr(settings, 'GITEA_SERVER_URL', ''),
|
||||
'access_token': getattr(settings, 'GITEA_ACCESS_TOKEN', ''),
|
||||
'default_owner': getattr(settings, 'GITEA_DEFAULT_OWNER', ''),
|
||||
'repo_prefix': getattr(settings, 'GITEA_REPO_PREFIX', '')
|
||||
}
|
||||
|
||||
def save_config(self, config: Dict[str, Any]) -> bool:
|
||||
"""保存Gitea配置
|
||||
|
||||
Args:
|
||||
config: Gitea配置信息
|
||||
|
||||
Returns:
|
||||
是否保存成功
|
||||
"""
|
||||
try:
|
||||
db = SessionLocal()
|
||||
|
||||
# 将所有现有配置设置为非活动状态
|
||||
db.query(GiteaConfig).update({GiteaConfig.status: "inactive"})
|
||||
|
||||
# 检查是否已有配置
|
||||
existing_config = db.query(GiteaConfig).first()
|
||||
|
||||
if existing_config:
|
||||
# 更新现有配置
|
||||
existing_config.server_url = config['server_url']
|
||||
existing_config.access_token = config['access_token']
|
||||
existing_config.default_owner = config['default_owner']
|
||||
existing_config.repo_prefix = config.get('repo_prefix', '')
|
||||
existing_config.status = "active"
|
||||
else:
|
||||
# 创建新配置
|
||||
new_config = GiteaConfig(
|
||||
id=f"gitea-config-{uuid.uuid4()}",
|
||||
server_url=config['server_url'],
|
||||
access_token=config['access_token'],
|
||||
default_owner=config['default_owner'],
|
||||
repo_prefix=config.get('repo_prefix', ''),
|
||||
status="active"
|
||||
)
|
||||
db.add(new_config)
|
||||
|
||||
db.commit()
|
||||
db.close()
|
||||
|
||||
# 更新内存中的配置
|
||||
self.config = config
|
||||
self.client = GiteaClient(
|
||||
config['server_url'],
|
||||
config['access_token']
|
||||
)
|
||||
|
||||
logger.info("Gitea config saved to database successfully")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save Gitea config to database: {str(e)}")
|
||||
return False
|
||||
|
||||
def get_config(self) -> Optional[Dict[str, Any]]:
|
||||
"""获取Gitea配置
|
||||
|
||||
Returns:
|
||||
Gitea配置信息
|
||||
"""
|
||||
return self.config
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
"""测试Gitea连接
|
||||
|
||||
Returns:
|
||||
是否连接成功
|
||||
"""
|
||||
if not self.client:
|
||||
return False
|
||||
return self.client.check_connection()
|
||||
|
||||
def create_repository(self, algorithm_id: str, algorithm_name: str, description: str = "") -> Optional[Dict[str, Any]]:
|
||||
"""为算法创建Gitea仓库
|
||||
|
||||
Args:
|
||||
algorithm_id: 算法ID
|
||||
algorithm_name: 算法名称
|
||||
description: 仓库描述
|
||||
|
||||
Returns:
|
||||
创建的仓库信息
|
||||
"""
|
||||
try:
|
||||
if not self.client:
|
||||
logger.error("Gitea client not initialized. Please check your Gitea configuration.")
|
||||
return None
|
||||
|
||||
if not self.config.get('default_owner'):
|
||||
logger.error("Default owner not set in Gitea configuration.")
|
||||
return None
|
||||
|
||||
# 记录传入的algorithm_id
|
||||
logger.info(f"Received algorithm_id: {algorithm_id}")
|
||||
|
||||
# 检查是否已经包含前缀
|
||||
repo_prefix = self.config.get('repo_prefix', '')
|
||||
if repo_prefix and algorithm_id.startswith(repo_prefix):
|
||||
logger.info(f"Algorithm ID already contains prefix: {repo_prefix}")
|
||||
repo_name = algorithm_id
|
||||
else:
|
||||
# 生成仓库名称,添加前缀
|
||||
repo_name = f"{repo_prefix}{algorithm_id}" if repo_prefix else algorithm_id
|
||||
logger.info(f"Generated repository name: {repo_name}")
|
||||
|
||||
logger.info(f"Creating repository: {repo_name} for owner: {self.config['default_owner']}")
|
||||
|
||||
# 创建仓库
|
||||
repo = self.client.create_repository(
|
||||
self.config['default_owner'],
|
||||
repo_name,
|
||||
description or f"Algorithm repository for {algorithm_name}",
|
||||
False
|
||||
)
|
||||
|
||||
if repo:
|
||||
logger.info(f"Repository created successfully: {repo}")
|
||||
# 验证仓库是否真的存在
|
||||
verify_repo = self.client.get_repository(self.config['default_owner'], repo_name)
|
||||
if not verify_repo:
|
||||
logger.error(f"Repository creation verified failed: {repo_name}")
|
||||
return None
|
||||
else:
|
||||
logger.error(f"Failed to create repository: {repo_name}")
|
||||
|
||||
return repo
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create repository: {str(e)}")
|
||||
return None
|
||||
|
||||
def clone_repository(self, repo_url: str, algorithm_id: str, branch: str = "main") -> bool:
|
||||
"""克隆Gitea仓库
|
||||
|
||||
Args:
|
||||
repo_url: 仓库URL
|
||||
algorithm_id: 算法ID
|
||||
branch: 分支名称
|
||||
|
||||
Returns:
|
||||
是否克隆成功
|
||||
"""
|
||||
try:
|
||||
# 创建本地目录
|
||||
repo_dir = f"/tmp/algorithms/{algorithm_id}"
|
||||
|
||||
logger.info(f"Cloning repository to: {repo_dir}")
|
||||
|
||||
# 导入需要的模块
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
# 如果目录已存在,先清理它
|
||||
if os.path.exists(repo_dir):
|
||||
logger.info(f"Cleaning existing repository directory: {repo_dir}")
|
||||
try:
|
||||
shutil.rmtree(repo_dir)
|
||||
logger.info(f"Successfully cleaned directory: {repo_dir}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clean directory: {str(e)}")
|
||||
# 尝试使用sudo删除(如果有权限)
|
||||
try:
|
||||
subprocess.run(["sudo", "rm", "-rf", repo_dir], check=True)
|
||||
logger.info(f"Successfully cleaned directory with sudo: {repo_dir}")
|
||||
except Exception as e2:
|
||||
logger.error(f"Failed to clean directory with sudo: {str(e2)}")
|
||||
return False
|
||||
|
||||
# 重新创建目录
|
||||
logger.info(f"Creating directory: {repo_dir}")
|
||||
os.makedirs(repo_dir, exist_ok=True)
|
||||
logger.info(f"Directory created successfully: {repo_dir}")
|
||||
|
||||
# 克隆仓库
|
||||
cmd = ["git", "clone", "-b", branch, repo_url, repo_dir]
|
||||
logger.info(f"Running clone command: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info(f"Repository cloned successfully: {repo_url}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to clone repository: {result.stderr}")
|
||||
|
||||
# 尝试初始化仓库
|
||||
logger.info(f"Trying to initialize repository in {repo_dir}")
|
||||
|
||||
# 初始化git仓库
|
||||
init_result = subprocess.run(["git", "init"], cwd=repo_dir, capture_output=True, text=True)
|
||||
if init_result.returncode != 0:
|
||||
logger.error(f"Failed to initialize git repository: {init_result.stderr}")
|
||||
return False
|
||||
|
||||
# 添加远程仓库
|
||||
remote_result = subprocess.run(["git", "remote", "add", "origin", repo_url], cwd=repo_dir, capture_output=True, text=True)
|
||||
if remote_result.returncode != 0:
|
||||
logger.error(f"Failed to add remote repository: {remote_result.stderr}")
|
||||
# 如果远程仓库已存在,尝试更新它
|
||||
logger.info("Trying to update existing remote repository")
|
||||
update_result = subprocess.run(["git", "remote", "set-url", "origin", repo_url], cwd=repo_dir, capture_output=True, text=True)
|
||||
if update_result.returncode != 0:
|
||||
logger.error(f"Failed to update remote repository: {update_result.stderr}")
|
||||
return False
|
||||
logger.info("Successfully updated remote repository")
|
||||
|
||||
# 创建初始文件
|
||||
readme_path = os.path.join(repo_dir, "README.md")
|
||||
with open(readme_path, "w") as f:
|
||||
f.write("# Algorithm Repository\n\nThis is an algorithm repository.\n")
|
||||
|
||||
# 添加文件并提交
|
||||
add_result = subprocess.run(["git", "add", "README.md"], cwd=repo_dir, capture_output=True, text=True)
|
||||
if add_result.returncode != 0:
|
||||
logger.error(f"Failed to add README.md: {add_result.stderr}")
|
||||
return False
|
||||
|
||||
commit_result = subprocess.run(["git", "commit", "-m", "Initial commit"], cwd=repo_dir, capture_output=True, text=True)
|
||||
if commit_result.returncode != 0:
|
||||
logger.error(f"Failed to commit initial file: {commit_result.stderr}")
|
||||
return False
|
||||
|
||||
# 推送代码到远程仓库
|
||||
push_result = subprocess.run(["git", "push", "-u", "origin", branch], cwd=repo_dir, capture_output=True, text=True)
|
||||
if push_result.returncode != 0:
|
||||
logger.error(f"Failed to push initial commit: {push_result.stderr}")
|
||||
# 即使推送失败,初始化仓库也算成功
|
||||
logger.info(f"Repository initialized successfully, but push failed: {push_result.stderr}")
|
||||
return True
|
||||
|
||||
logger.info(f"Repository initialized and pushed successfully: {repo_url}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clone repository: {str(e)}")
|
||||
return False
|
||||
|
||||
def push_to_repository(self, algorithm_id: str, message: str = "Update code") -> bool:
|
||||
"""推送代码到Gitea仓库
|
||||
|
||||
Args:
|
||||
algorithm_id: 算法ID
|
||||
message: 提交消息
|
||||
|
||||
Returns:
|
||||
是否推送成功
|
||||
"""
|
||||
try:
|
||||
logger.info("=== 开始推送代码到Gitea仓库 ===")
|
||||
logger.info(f"Algorithm ID: {algorithm_id}")
|
||||
logger.info(f"Commit message: {message}")
|
||||
|
||||
repo_dir = f"/tmp/algorithms/{algorithm_id}"
|
||||
logger.info(f"Repository directory: {repo_dir}")
|
||||
|
||||
if not os.path.exists(repo_dir):
|
||||
logger.error(f"❌ Repository directory not found: {repo_dir}")
|
||||
return False
|
||||
|
||||
# 首先尝试使用API上传(推荐方法,避免Git推送限制)
|
||||
logger.info("Attempting to upload files via Gitea API...")
|
||||
api_upload_success = self.upload_files_via_api(algorithm_id, message)
|
||||
|
||||
if api_upload_success:
|
||||
logger.info(f"✅ Code uploaded successfully via API for algorithm: {algorithm_id}")
|
||||
return True
|
||||
else:
|
||||
logger.warning("❌ API upload failed, falling back to Git push...")
|
||||
|
||||
# 如果API上传失败,回退到原来的Git推送方法
|
||||
import subprocess
|
||||
|
||||
# 检查是否是git仓库
|
||||
git_dir = os.path.join(repo_dir, ".git")
|
||||
if not os.path.exists(git_dir):
|
||||
logger.info(f"⚠️ Git repository not initialized, initializing...")
|
||||
# 初始化git仓库
|
||||
logger.info(f"Executing: git init in {repo_dir}")
|
||||
init_result = subprocess.run(["git", "init"], cwd=repo_dir, capture_output=True, text=True)
|
||||
logger.info(f"Git init output: {init_result.stdout}")
|
||||
if init_result.stderr:
|
||||
logger.warning(f"Git init stderr: {init_result.stderr}")
|
||||
if init_result.returncode != 0:
|
||||
logger.error(f"❌ Failed to initialize git repository: {init_result.stderr}")
|
||||
return False
|
||||
logger.info("✅ Git repository initialized successfully")
|
||||
|
||||
# 添加远程仓库(从配置中获取,包含访问令牌以确保认证)
|
||||
if self.config.get('default_owner'):
|
||||
# 使用访问令牌构建认证URL
|
||||
auth_repo_url = f"https://{self.config['access_token']}@{self.config['server_url'].replace('https://', '').replace('http://', '')}/{self.config['default_owner']}/{algorithm_id}.git"
|
||||
logger.info(f"Adding remote repository: {auth_repo_url}")
|
||||
remote_result = subprocess.run(["git", "remote", "add", "origin", auth_repo_url], cwd=repo_dir, capture_output=True, text=True)
|
||||
logger.info(f"Git remote add output: {remote_result.stdout}")
|
||||
if remote_result.stderr:
|
||||
logger.warning(f"Git remote add stderr: {remote_result.stderr}")
|
||||
if remote_result.returncode != 0:
|
||||
logger.error(f"❌ Failed to add remote repository: {remote_result.stderr}")
|
||||
return False
|
||||
logger.info("✅ Remote repository added successfully")
|
||||
else:
|
||||
logger.info("✅ Git repository already initialized")
|
||||
|
||||
# 执行git命令 - 分批添加文件以处理大量文件
|
||||
logger.info("=== 执行Git操作 ===")
|
||||
|
||||
# 获取所有需要添加的文件
|
||||
all_files = []
|
||||
for root, dirs, files in os.walk(repo_dir):
|
||||
if '.git' in root:
|
||||
continue
|
||||
for file in files:
|
||||
if not file.endswith('.git'):
|
||||
file_path = os.path.relpath(os.path.join(root, file), repo_dir)
|
||||
all_files.append(file_path)
|
||||
|
||||
logger.info(f"Total files to add: {len(all_files)}")
|
||||
|
||||
# 分批添加文件,避免命令行参数过长
|
||||
batch_size = 100 # 每次添加100个文件
|
||||
for i in range(0, len(all_files), batch_size):
|
||||
batch = all_files[i:i + batch_size]
|
||||
logger.info(f"Adding batch {i//batch_size + 1}: {len(batch)} files")
|
||||
|
||||
add_result = subprocess.run(["git", "add"] + batch, cwd=repo_dir, capture_output=True, text=True)
|
||||
if add_result.stderr and add_result.returncode != 0:
|
||||
logger.error(f"❌ Git add batch {i//batch_size + 1} failed: {add_result.stderr}")
|
||||
return False
|
||||
elif add_result.stderr:
|
||||
logger.warning(f"Git add batch {i//batch_size + 1} warning: {add_result.stderr}")
|
||||
|
||||
logger.info("✅ Git add completed successfully")
|
||||
|
||||
# 检查是否有更改需要提交
|
||||
logger.info("Executing: git status --porcelain")
|
||||
status_result = subprocess.run(["git", "status", "--porcelain"], cwd=repo_dir, capture_output=True, text=True)
|
||||
logger.info(f"Git status output: {status_result.stdout}")
|
||||
if status_result.stderr:
|
||||
logger.warning(f"Git status stderr: {status_result.stderr}")
|
||||
if status_result.returncode != 0:
|
||||
logger.error(f"❌ Git status failed: {status_result.stderr}")
|
||||
return False
|
||||
|
||||
# 如果有更改,执行commit和push
|
||||
if status_result.stdout.strip():
|
||||
logger.info("✅ Changes detected, proceeding with commit and push")
|
||||
# 执行git commit
|
||||
logger.info(f"Executing: git commit -m '{message}'")
|
||||
commit_result = subprocess.run(["git", "commit", "-m", message], cwd=repo_dir, capture_output=True, text=True)
|
||||
logger.info(f"Git commit output: {commit_result.stdout}")
|
||||
if commit_result.stderr:
|
||||
logger.warning(f"Git commit stderr: {commit_result.stderr}")
|
||||
if commit_result.returncode != 0:
|
||||
logger.error(f"❌ Git commit failed: {commit_result.stderr}")
|
||||
return False
|
||||
logger.info("✅ Git commit completed successfully")
|
||||
|
||||
# 检查仓库大小
|
||||
logger.info("Checking repository size before push")
|
||||
total_size = 0
|
||||
for dirpath, dirnames, filenames in os.walk(repo_dir):
|
||||
for filename in filenames:
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
if not filepath.startswith(os.path.join(repo_dir, '.git')):
|
||||
total_size += os.path.getsize(filepath)
|
||||
logger.info(f"Repository size (excluding .git): {total_size / (1024 * 1024):.2f} MB")
|
||||
|
||||
if total_size > 100 * 1024 * 1024: # 100MB
|
||||
logger.warning(f"Repository is large: {total_size / (1024 * 1024):.2f} MB")
|
||||
logger.warning("This may cause HTTP 413 errors on push")
|
||||
|
||||
# 设置Git推送缓冲区大小(增加到1GB)
|
||||
logger.info("Setting Git http.postBuffer to 1GB")
|
||||
buffer_result = subprocess.run(["git", "config", "http.postBuffer", "1073741824"], cwd=repo_dir, capture_output=True, text=True)
|
||||
if buffer_result.returncode != 0:
|
||||
logger.warning(f"Failed to set http.postBuffer: {buffer_result.stderr}")
|
||||
else:
|
||||
logger.info("✅ Git http.postBuffer set successfully")
|
||||
|
||||
# 禁用Git压缩
|
||||
logger.info("Disabling Git compression")
|
||||
compression_result = subprocess.run(["git", "config", "core.compression", "0"], cwd=repo_dir, capture_output=True, text=True)
|
||||
if compression_result.returncode != 0:
|
||||
logger.warning(f"Failed to set core.compression: {compression_result.stderr}")
|
||||
else:
|
||||
logger.info("✅ Git core.compression disabled successfully")
|
||||
|
||||
# 针对大仓库优化的推送命令
|
||||
logger.info("Setting additional Git configs for large repositories...")
|
||||
subprocess.run(["git", "config", "http.postBuffer", "524288000"], cwd=repo_dir) # 500MB buffer
|
||||
subprocess.run(["git", "config", "pack.windowMemory", "128m"], cwd=repo_dir) # Limit memory usage
|
||||
subprocess.run(["git", "config", "pack.packSizeLimit", "128m"], cwd=repo_dir) # Limit pack size
|
||||
|
||||
# 执行git push,添加更多优化参数
|
||||
logger.info("Executing: git push with optimizations for large repositories")
|
||||
push_result = subprocess.run([
|
||||
"git", "push",
|
||||
"--verbose",
|
||||
"-u", "origin", "main",
|
||||
"--receive-pack='git receive-pack'", # Ensure proper receive pack
|
||||
"--progress" # Show progress for large pushes
|
||||
], cwd=repo_dir, capture_output=True, text=True, timeout=300) # 5 minute timeout
|
||||
logger.info(f"Git push output: {push_result.stdout}")
|
||||
if push_result.stderr:
|
||||
logger.warning(f"Git push stderr: {push_result.stderr}")
|
||||
if push_result.returncode != 0:
|
||||
# 检查是否是常见的大文件错误
|
||||
error_msg = push_result.stderr.lower()
|
||||
is_large_file_error = (
|
||||
"http 413" in error_msg or
|
||||
"payload too large" in error_msg or
|
||||
"unpack failed" in error_msg or
|
||||
"remote: fatal" in error_msg or
|
||||
"cannot spawn" in error_msg or
|
||||
"timeout" in error_msg
|
||||
)
|
||||
|
||||
if is_large_file_error:
|
||||
logger.error(f"❌ Git push failed likely due to repository size: {total_size / (1024 * 1024):.2f} MB")
|
||||
logger.error(f"Error details: {push_result.stderr}")
|
||||
logger.error("\n📋 解决方案建议:")
|
||||
logger.error("1. 检查Gitea服务器配置,增加MAX_UPLOAD_SIZE限制")
|
||||
logger.error("2. 尝试使用SSH协议进行推送(如果服务器支持)")
|
||||
logger.error("3. 优化仓库大小,移除不必要的大文件")
|
||||
logger.error("4. 考虑使用Git LFS(Large File Storage)管理大文件")
|
||||
|
||||
# 尝试使用SSH协议进行推送(如果URL是HTTPS格式)
|
||||
logger.info("\n🔄 尝试使用SSH协议进行推送...")
|
||||
try:
|
||||
# 获取当前远程URL
|
||||
remote_result = subprocess.run(["git", "remote", "get-url", "origin"], cwd=repo_dir, capture_output=True, text=True, timeout=30)
|
||||
if remote_result.returncode == 0:
|
||||
https_url = remote_result.stdout.strip()
|
||||
# 将HTTPS URL转换为SSH URL
|
||||
if https_url.startswith("https://"):
|
||||
ssh_url = https_url.replace("https://", "git@").replace(":", "/")
|
||||
logger.info(f"Converting HTTPS URL to SSH URL: {ssh_url}")
|
||||
# 更新远程URL
|
||||
set_url_result = subprocess.run(["git", "remote", "set-url", "origin", ssh_url], cwd=repo_dir, capture_output=True, text=True, timeout=30)
|
||||
if set_url_result.returncode == 0:
|
||||
logger.info("✅ Remote URL updated to SSH format")
|
||||
# 再次尝试推送,使用更保守的参数
|
||||
logger.info("Executing: git push with SSH and conservative parameters")
|
||||
ssh_push_result = subprocess.run([
|
||||
"git", "push",
|
||||
"--verbose",
|
||||
"-u", "origin", "main"
|
||||
], cwd=repo_dir, capture_output=True, text=True, timeout=600) # 10 minute timeout for SSH
|
||||
|
||||
if ssh_push_result.returncode == 0:
|
||||
logger.info("✅ Git push completed successfully with SSH")
|
||||
# 改回HTTPS URL
|
||||
reset_url_result = subprocess.run(["git", "remote", "set-url", "origin", https_url], cwd=repo_dir, capture_output=True, text=True, timeout=30)
|
||||
if reset_url_result.returncode != 0:
|
||||
logger.warning(f"Failed to reset remote URL to HTTPS: {reset_url_result.stderr}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"SSH push failed: {ssh_push_result.stderr}")
|
||||
# 改回HTTPS URL
|
||||
reset_url_result = subprocess.run(["git", "remote", "set-url", "origin", https_url], cwd=repo_dir, capture_output=True, text=True, timeout=30)
|
||||
if reset_url_result.returncode != 0:
|
||||
logger.warning(f"Failed to reset remote URL to HTTPS: {reset_url_result.stderr}")
|
||||
|
||||
# 如果SSH也失败,尝试分阶段推送
|
||||
logger.info("\n🔄 尝试分阶段推送...")
|
||||
return self.push_repository_staged(repo_dir, https_url)
|
||||
else:
|
||||
logger.warning(f"Could not get remote URL: {remote_result.stderr}")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("Remote URL command timed out")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to try SSH push: {str(e)}")
|
||||
else:
|
||||
logger.error(f"❌ Git push failed: {push_result.stderr}")
|
||||
return False
|
||||
logger.info("✅ Git push completed successfully")
|
||||
else:
|
||||
logger.info("ℹ️ No changes to commit")
|
||||
|
||||
logger.info(f"✅ Code pushed successfully for algorithm: {algorithm_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"=== 推送代码失败 ===")
|
||||
logger.error(f"Error: {str(e)}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
def pull_from_repository(self, algorithm_id: str) -> bool:
|
||||
"""从Gitea仓库拉取代码
|
||||
|
||||
Args:
|
||||
algorithm_id: 算法ID
|
||||
|
||||
Returns:
|
||||
是否拉取成功
|
||||
"""
|
||||
try:
|
||||
repo_dir = f"/tmp/algorithms/{algorithm_id}"
|
||||
|
||||
if not os.path.exists(repo_dir):
|
||||
logger.error(f"Repository directory not found: {repo_dir}")
|
||||
return False
|
||||
|
||||
# 执行git pull命令
|
||||
result = subprocess.run(
|
||||
["git", "pull"],
|
||||
cwd=repo_dir,
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info(f"Code pulled successfully for algorithm: {algorithm_id}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to pull code: {result.stderr}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to pull code: {str(e)}")
|
||||
return False
|
||||
|
||||
def push_repository_staged(self, repo_dir: str, origin_url: str) -> bool:
|
||||
"""
|
||||
分阶段推送仓库,用于处理超大仓库
|
||||
"""
|
||||
try:
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
logger.info("=== 开始分阶段推送仓库 ===")
|
||||
logger.info(f"Repository directory: {repo_dir}")
|
||||
|
||||
# 获取所有文件并按类型分组
|
||||
all_files = []
|
||||
for root, dirs, files in os.walk(repo_dir):
|
||||
# 跳过 .git 目录
|
||||
if '.git' in root:
|
||||
continue
|
||||
for file in files:
|
||||
file_path = os.path.relpath(os.path.join(root, file), repo_dir)
|
||||
if file_path.startswith('.git'):
|
||||
continue
|
||||
all_files.append(file_path)
|
||||
|
||||
logger.info(f"Total files to stage: {len(all_files)}")
|
||||
|
||||
# 按扩展名分类文件,优先推送小文件
|
||||
def get_file_size(file_path):
|
||||
try:
|
||||
return os.path.getsize(os.path.join(repo_dir, file_path))
|
||||
except:
|
||||
return 0
|
||||
|
||||
# 按文件大小排序(从小到大)
|
||||
sorted_files = sorted(all_files, key=get_file_size)
|
||||
|
||||
# 分批处理,每批最多50个文件或不超过50MB
|
||||
batch_size_limit = 50
|
||||
batch_size_bytes = 50 * 1024 * 1024 # 50MB
|
||||
|
||||
current_batch = []
|
||||
current_batch_size = 0
|
||||
batch_number = 1
|
||||
|
||||
for file_path in sorted_files:
|
||||
file_full_path = os.path.join(repo_dir, file_path)
|
||||
file_size = get_file_size(file_path)
|
||||
|
||||
# 如果单个文件太大,单独处理
|
||||
if file_size > batch_size_bytes:
|
||||
logger.info(f"Handling large file separately: {file_path} ({file_size / (1024*1024):.2f}MB)")
|
||||
# 单独添加和推送这个大文件
|
||||
add_result = subprocess.run(["git", "add", file_path], cwd=repo_dir, capture_output=True, text=True)
|
||||
if add_result.returncode != 0:
|
||||
logger.error(f"Failed to add large file {file_path}: {add_result.stderr}")
|
||||
continue
|
||||
|
||||
# 检查是否有暂存的更改
|
||||
status_result = subprocess.run(["git", "status", "--porcelain"], cwd=repo_dir, capture_output=True, text=True)
|
||||
if status_result.stdout.strip():
|
||||
# 创建专门的提交
|
||||
commit_msg = f"Add large file: {file_path}"
|
||||
commit_result = subprocess.run(["git", "commit", "-m", commit_msg], cwd=repo_dir, capture_output=True, text=True)
|
||||
if commit_result.returncode == 0:
|
||||
logger.info(f"Committed large file: {file_path}")
|
||||
|
||||
# 推送这个提交
|
||||
push_result = subprocess.run([
|
||||
"git", "push", "--verbose", "origin", "main"
|
||||
], cwd=repo_dir, capture_output=True, text=True, timeout=300)
|
||||
|
||||
if push_result.returncode != 0:
|
||||
logger.warning(f"Push failed for large file {file_path}: {push_result.stderr}")
|
||||
# 如果推送失败,尝试重置这个文件的暂存状态
|
||||
subprocess.run(["git", "reset", "HEAD", file_path], cwd=repo_dir, capture_output=True, text=True)
|
||||
else:
|
||||
logger.info(f"Successfully pushed large file: {file_path}")
|
||||
else:
|
||||
logger.error(f"Failed to commit large file {file_path}: {commit_result.stderr}")
|
||||
else:
|
||||
# 尝试添加到当前批次
|
||||
if (len(current_batch) >= batch_size_limit or
|
||||
current_batch_size + file_size > batch_size_bytes):
|
||||
# 推送当前批次
|
||||
if current_batch:
|
||||
logger.info(f"Pushing batch {batch_number} with {len(current_batch)} files...")
|
||||
success = self.push_batch(repo_dir, current_batch, batch_number, origin_url)
|
||||
if not success:
|
||||
logger.error(f"Failed to push batch {batch_number}")
|
||||
return False
|
||||
batch_number += 1
|
||||
current_batch = []
|
||||
current_batch_size = 0
|
||||
|
||||
current_batch.append(file_path)
|
||||
current_batch_size += file_size
|
||||
|
||||
# 推送最后一批
|
||||
if current_batch:
|
||||
logger.info(f"Pushing final batch {batch_number} with {len(current_batch)} files...")
|
||||
success = self.push_batch(repo_dir, current_batch, batch_number, origin_url)
|
||||
if not success:
|
||||
logger.error(f"Failed to push final batch {batch_number}")
|
||||
return False
|
||||
|
||||
logger.info("✅ 分阶段推送完成")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 分阶段推送失败: {str(e)}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
def push_batch(self, repo_dir: str, file_batch: list, batch_num: int, origin_url: str) -> bool:
|
||||
"""
|
||||
推送文件批次
|
||||
"""
|
||||
try:
|
||||
import subprocess
|
||||
|
||||
logger.info(f"Processing batch {batch_num}: {len(file_batch)} files")
|
||||
|
||||
# 添加批次中的文件
|
||||
for file_path in file_batch:
|
||||
add_result = subprocess.run(["git", "add", file_path], cwd=repo_dir, capture_output=True, text=True)
|
||||
if add_result.returncode != 0:
|
||||
logger.error(f"Failed to add file {file_path}: {add_result.stderr}")
|
||||
return False
|
||||
|
||||
# 检查是否有更改需要提交
|
||||
status_result = subprocess.run(["git", "status", "--porcelain"], cwd=repo_dir, capture_output=True, text=True)
|
||||
if not status_result.stdout.strip():
|
||||
logger.info(f"No changes in batch {batch_num}")
|
||||
return True
|
||||
|
||||
# 提交批次
|
||||
commit_result = subprocess.run([
|
||||
"git", "commit", "-m", f"Batch {batch_num}: Add {len(file_batch)} files"
|
||||
], cwd=repo_dir, capture_output=True, text=True)
|
||||
|
||||
if commit_result.returncode != 0:
|
||||
logger.warning(f"Commit failed or no changes for batch {batch_num}: {commit_result.stderr}")
|
||||
# 即使没有更改,也可能正常(比如文件没变)
|
||||
|
||||
# 推送批次
|
||||
push_result = subprocess.run([
|
||||
"git", "push", "--verbose", "origin", "main"
|
||||
], cwd=repo_dir, capture_output=True, text=True, timeout=300)
|
||||
|
||||
if push_result.returncode == 0:
|
||||
logger.info(f"✅ Batch {batch_num} pushed successfully")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"❌ Batch {batch_num} push failed: {push_result.stderr}")
|
||||
return False
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error(f"❌ Batch {batch_num} push timed out")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Batch {batch_num} push failed with error: {str(e)}")
|
||||
return False
|
||||
|
||||
def get_repository_info(self, repo_owner: str, repo_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取仓库信息
|
||||
|
||||
Args:
|
||||
repo_owner: 仓库所有者
|
||||
repo_name: 仓库名称
|
||||
|
||||
Returns:
|
||||
仓库信息
|
||||
"""
|
||||
if not self.client:
|
||||
return None
|
||||
|
||||
return self.client.get_repository(repo_owner, repo_name)
|
||||
|
||||
def list_repositories(self, owner: Optional[str] = None) -> Optional[List[Dict[str, Any]]]:
|
||||
"""列出仓库
|
||||
|
||||
Args:
|
||||
owner: 所有者(用户或组织)
|
||||
|
||||
Returns:
|
||||
仓库列表
|
||||
"""
|
||||
if not self.client:
|
||||
return None
|
||||
|
||||
target_owner = owner or self.config.get('default_owner')
|
||||
if not target_owner:
|
||||
return None
|
||||
|
||||
return self.client.list_repositories(target_owner)
|
||||
|
||||
def register_algorithm_from_repo(self, repo_owner: str, repo_name: str, algorithm_id: str) -> bool:
|
||||
"""从仓库注册算法服务
|
||||
|
||||
Args:
|
||||
repo_owner: 仓库所有者
|
||||
repo_name: 仓库名称
|
||||
algorithm_id: 算法ID
|
||||
|
||||
Returns:
|
||||
是否注册成功
|
||||
"""
|
||||
try:
|
||||
# 这里应该实现从仓库注册算法服务的逻辑
|
||||
# 1. 克隆仓库
|
||||
# 2. 扫描仓库中的算法代码
|
||||
# 3. 注册算法服务
|
||||
|
||||
logger.info(f"Algorithm registered from repo: {repo_owner}/{repo_name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register algorithm from repo: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
# 递归遍历目录中的所有文件
|
||||
for root, dirs, files in os.walk(repo_dir):
|
||||
# 跳过 .git 目录
|
||||
if '.git' in root:
|
||||
continue
|
||||
|
||||
for file in files:
|
||||
file_path = os.path.relpath(os.path.join(root, file), repo_dir)
|
||||
if file_path.startswith('.git'):
|
||||
continue
|
||||
|
||||
full_file_path = os.path.join(root, file)
|
||||
|
||||
# 读取文件内容并进行base64编码
|
||||
try:
|
||||
with open(full_file_path, 'rb') as f:
|
||||
file_content = f.read()
|
||||
encoded_content = base64.b64encode(file_content).decode('utf-8')
|
||||
|
||||
# 使用Gitea API创建或更新文件
|
||||
if self.client:
|
||||
# 移除开头的./,如果有的话
|
||||
clean_path = file_path.lstrip('./\\')
|
||||
result = self.client.create_file(
|
||||
self.config["default_owner"],
|
||||
algorithm_id,
|
||||
clean_path,
|
||||
encoded_content,
|
||||
f"{message} - Upload {clean_path}"
|
||||
)
|
||||
|
||||
if result:
|
||||
logger.info(f"✅ File uploaded via API: {clean_path}")
|
||||
else:
|
||||
logger.error(f"❌ Failed to upload file via API: {clean_path}")
|
||||
return False
|
||||
else:
|
||||
logger.error("❌ Gitea client not initialized")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error processing file {file_path}: {str(e)}")
|
||||
return False
|
||||
|
||||
logger.info(f"✅ All files uploaded successfully via API for algorithm: {algorithm_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to upload files via API: {str(e)}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
|
||||
# 全局Gitea服务实例
|
||||
gitea_service = GiteaService()
|
||||
74
backend/app/main.py
Normal file
74
backend/app/main.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.models.database import engine, Base
|
||||
from app.routes import user, api_key, algorithm, openai, gateway, services, data_management, monitoring, permissions, history, deployment, gitea, repositories
|
||||
|
||||
# 创建数据库表
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
# 创建FastAPI应用
|
||||
app = FastAPI(
|
||||
title=settings.APP_NAME,
|
||||
version=settings.APP_VERSION,
|
||||
description="智能算法展示平台API",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
# 设置更大的上传限制
|
||||
max_form_data_parts=50000, # 允许最多50000个表单数据部分(包括文件)
|
||||
# 额外配置以绕过Starlette的默认限制
|
||||
)
|
||||
|
||||
# 手动配置Starlette的表单解析器限制
|
||||
from starlette.formparsers import MultiPartParser
|
||||
MultiPartParser.max_files = 50000
|
||||
MultiPartParser.max_fields = 50000
|
||||
|
||||
# 配置CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
app.include_router(user.router, prefix=settings.API_V1_STR)
|
||||
app.include_router(api_key.router, prefix=settings.API_V1_STR)
|
||||
app.include_router(algorithm.router, prefix=settings.API_V1_STR)
|
||||
app.include_router(openai.router, prefix=settings.API_V1_STR)
|
||||
app.include_router(gateway.router, prefix=settings.API_V1_STR)
|
||||
app.include_router(services.router, prefix=settings.API_V1_STR)
|
||||
app.include_router(data_management.router, prefix=settings.API_V1_STR)
|
||||
app.include_router(monitoring.router, prefix=settings.API_V1_STR)
|
||||
app.include_router(permissions.router, prefix=settings.API_V1_STR)
|
||||
app.include_router(history.router, prefix=settings.API_V1_STR)
|
||||
app.include_router(deployment.router)
|
||||
app.include_router(gitea.router, prefix=settings.API_V1_STR)
|
||||
app.include_router(repositories.router, prefix=settings.API_V1_STR)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""根路径"""
|
||||
return {
|
||||
"message": "Welcome to 智能算法展示平台 API",
|
||||
"version": settings.APP_VERSION,
|
||||
"docs": "/docs"
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查"""
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
@app.get("/@vite/client")
|
||||
async def serve_vite_client():
|
||||
"""处理 @vite/client 请求,避免 404 错误"""
|
||||
from fastapi.responses import Response
|
||||
return Response("", media_type="application/javascript")
|
||||
BIN
backend/app/models/__pycache__/database.cpython-312.pyc
Normal file
BIN
backend/app/models/__pycache__/database.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/models/__pycache__/database.cpython-39.pyc
Normal file
BIN
backend/app/models/__pycache__/database.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/models/__pycache__/models.cpython-312.pyc
Normal file
BIN
backend/app/models/__pycache__/models.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/models/__pycache__/models.cpython-39.pyc
Normal file
BIN
backend/app/models/__pycache__/models.cpython-39.pyc
Normal file
Binary file not shown.
38
backend/app/models/database.py
Normal file
38
backend/app/models/database.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import QueuePool
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
# 创建数据库引擎
|
||||
engine = create_engine(
|
||||
settings.DATABASE_URL,
|
||||
poolclass=QueuePool,
|
||||
pool_pre_ping=True,
|
||||
pool_size=10,
|
||||
max_overflow=20,
|
||||
pool_recycle=3600, # 连接回收时间
|
||||
pool_timeout=30, # 连接超时时间
|
||||
echo=False # 关闭SQL日志
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
SessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=engine,
|
||||
expire_on_commit=False # 提交后不自动过期对象
|
||||
)
|
||||
|
||||
# 创建基类
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def get_db():
|
||||
"""获取数据库会话的依赖函数"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
162
backend/app/models/models.py
Normal file
162
backend/app/models/models.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from sqlalchemy import Column, Integer, String, Float, Text, Boolean, DateTime, ForeignKey, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.models.database import Base
|
||||
|
||||
|
||||
class Algorithm(Base):
|
||||
"""算法模型"""
|
||||
__tablename__ = "algorithms"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False, index=True)
|
||||
description = Column(Text, nullable=False)
|
||||
type = Column(String, nullable=False, index=True) # computer_vision, nlp, ml, edge_computing, medical, autonomous_driving等
|
||||
status = Column(String, default="active", index=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
# 关系
|
||||
versions = relationship("AlgorithmVersion", back_populates="algorithm", cascade="all, delete-orphan")
|
||||
calls = relationship("AlgorithmCall", back_populates="algorithm")
|
||||
|
||||
|
||||
class AlgorithmVersion(Base):
|
||||
"""算法版本模型"""
|
||||
__tablename__ = "algorithm_versions"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
algorithm_id = Column(String, ForeignKey("algorithms.id"), nullable=False, index=True)
|
||||
version = Column(String, nullable=False)
|
||||
url = Column(String, nullable=False) # 算法API地址
|
||||
params = Column(JSON, default=dict) # 算法参数配置
|
||||
input_schema = Column(JSON, default=dict) # 输入数据格式
|
||||
output_schema = Column(JSON, default=dict) # 输出数据格式
|
||||
code = Column(Text, default='') # Python算法代码
|
||||
model_name = Column(String, default='') # API训练后的模型名字
|
||||
model_file = Column(String, default='') # 模型文件路径
|
||||
api_doc = Column(Text, default='') # 模型的API用法文档
|
||||
is_default = Column(Boolean, default=False) # 是否为默认版本
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
# 关系
|
||||
algorithm = relationship("Algorithm", back_populates="versions")
|
||||
calls = relationship("AlgorithmCall", back_populates="version")
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""用户模型"""
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
username = Column(String, unique=True, nullable=False, index=True)
|
||||
email = Column(String, unique=True, nullable=False, index=True)
|
||||
password_hash = Column(String, nullable=False)
|
||||
role = Column(String, default="user", index=True) # admin, user, customer
|
||||
status = Column(String, default="active", index=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
# 关系
|
||||
api_keys = relationship("APIKey", back_populates="user", cascade="all, delete-orphan")
|
||||
calls = relationship("AlgorithmCall", back_populates="user")
|
||||
|
||||
|
||||
class APIKey(Base):
|
||||
"""API密钥模型"""
|
||||
__tablename__ = "api_keys"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
user_id = Column(String, ForeignKey("users.id"), nullable=False, index=True)
|
||||
key = Column(String, unique=True, nullable=False, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
status = Column(String, default="active", index=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
# 关系
|
||||
user = relationship("User", back_populates="api_keys")
|
||||
|
||||
|
||||
class AlgorithmCall(Base):
|
||||
"""算法调用记录模型"""
|
||||
__tablename__ = "algorithm_calls"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
user_id = Column(String, ForeignKey("users.id"), nullable=False, index=True)
|
||||
algorithm_id = Column(String, ForeignKey("algorithms.id"), nullable=False, index=True)
|
||||
version_id = Column(String, ForeignKey("algorithm_versions.id"), nullable=False, index=True)
|
||||
input_data = Column(JSON, nullable=False) # 输入数据
|
||||
params = Column(JSON, default=dict) # 调用参数
|
||||
output_data = Column(JSON, default=dict) # 输出数据
|
||||
status = Column(String, default="pending", index=True) # pending, running, success, failed
|
||||
response_time = Column(Float, nullable=True) # 响应时间(秒)
|
||||
error_message = Column(Text, nullable=True) # 错误信息
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
# 关系
|
||||
user = relationship("User", back_populates="calls")
|
||||
algorithm = relationship("Algorithm", back_populates="calls")
|
||||
version = relationship("AlgorithmVersion", back_populates="calls")
|
||||
|
||||
|
||||
class GiteaConfig(Base):
|
||||
"""Gitea配置模型"""
|
||||
__tablename__ = "gitea_configs"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
server_url = Column(String, nullable=False) # Gitea服务器URL
|
||||
access_token = Column(String, nullable=False) # 访问令牌
|
||||
default_owner = Column(String, nullable=False) # 默认组织/用户
|
||||
repo_prefix = Column(String, default="") # 仓库前缀
|
||||
status = Column(String, default="active") # 状态
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
|
||||
class AlgorithmRepository(Base):
|
||||
"""算法仓库模型"""
|
||||
__tablename__ = "algorithm_repositories"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
algorithm_id = Column(String, ForeignKey("algorithms.id"), nullable=True, index=True) # 关联的算法ID
|
||||
name = Column(String, nullable=False, index=True) # 仓库名称
|
||||
description = Column(Text, default="") # 仓库描述
|
||||
type = Column(String, default="code") # 仓库类型:code, model, hybrid
|
||||
repo_url = Column(String, nullable=False) # Git仓库URL
|
||||
branch = Column(String, default="main") # 分支名称
|
||||
local_path = Column(String, default="") # 本地存储路径
|
||||
status = Column(String, default="active", index=True) # 状态
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
# 关系
|
||||
algorithm = relationship("Algorithm", back_populates="repository", uselist=False)
|
||||
|
||||
|
||||
class AlgorithmService(Base):
|
||||
"""算法服务模型"""
|
||||
__tablename__ = "algorithm_services"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
service_id = Column(String, unique=True, nullable=False, index=True) # 服务ID
|
||||
name = Column(String, nullable=False, index=True) # 服务名称
|
||||
algorithm_name = Column(String, nullable=False) # 算法名称
|
||||
version = Column(String, nullable=False) # 版本
|
||||
host = Column(String, nullable=False) # 主机地址
|
||||
port = Column(Integer, nullable=False) # 端口
|
||||
api_url = Column(String, nullable=False) # API地址
|
||||
status = Column(String, default="stopped", index=True) # 状态:running, stopped, error, restarting
|
||||
config = Column(JSON, default=dict) # 服务配置
|
||||
start_time = Column(DateTime(timezone=True), nullable=True) # 启动时间
|
||||
last_heartbeat = Column(DateTime(timezone=True), nullable=True) # 最后心跳时间
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
|
||||
# 添加Algorithm模型的repository关系
|
||||
Algorithm.repository = relationship("AlgorithmRepository", back_populates="algorithm", uselist=False)
|
||||
14
backend/app/routes/__init__.py
Normal file
14
backend/app/routes/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from fastapi import APIRouter
|
||||
from app.routes import user, algorithm, api_key, history, gateway, monitoring, openai, deployment
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
# 注册路由
|
||||
api_router.include_router(user.router, prefix="/users", tags=["users"])
|
||||
api_router.include_router(algorithm.router, prefix="/algorithms", tags=["algorithms"])
|
||||
api_router.include_router(api_key.router, prefix="/api-keys", tags=["api-keys"])
|
||||
api_router.include_router(history.router, prefix="/history", tags=["history"])
|
||||
api_router.include_router(gateway.router, prefix="/gateway", tags=["gateway"])
|
||||
api_router.include_router(monitoring.router, prefix="/monitoring", tags=["monitoring"])
|
||||
api_router.include_router(openai.router, prefix="/openai", tags=["openai"])
|
||||
api_router.include_router(deployment.router, tags=["deployment"])
|
||||
BIN
backend/app/routes/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
backend/app/routes/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/algorithm.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/algorithm.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/algorithm.cpython-39.pyc
Normal file
BIN
backend/app/routes/__pycache__/algorithm.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/api_key.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/api_key.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/api_key.cpython-39.pyc
Normal file
BIN
backend/app/routes/__pycache__/api_key.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/data_management.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/data_management.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/data_management.cpython-39.pyc
Normal file
BIN
backend/app/routes/__pycache__/data_management.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/deployment.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/deployment.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/gateway.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/gateway.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/gateway.cpython-39.pyc
Normal file
BIN
backend/app/routes/__pycache__/gateway.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/gitea.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/gitea.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/history.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/history.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/history.cpython-39.pyc
Normal file
BIN
backend/app/routes/__pycache__/history.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/monitoring.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/monitoring.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/monitoring.cpython-39.pyc
Normal file
BIN
backend/app/routes/__pycache__/monitoring.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/openai.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/openai.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/openai.cpython-39.pyc
Normal file
BIN
backend/app/routes/__pycache__/openai.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/permissions.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/permissions.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/permissions.cpython-39.pyc
Normal file
BIN
backend/app/routes/__pycache__/permissions.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/repositories.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/repositories.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/services.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/services.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/services.cpython-39.pyc
Normal file
BIN
backend/app/routes/__pycache__/services.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/user.cpython-312.pyc
Normal file
BIN
backend/app/routes/__pycache__/user.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/user.cpython-39.pyc
Normal file
BIN
backend/app/routes/__pycache__/user.cpython-39.pyc
Normal file
Binary file not shown.
392
backend/app/routes/algorithm.py
Normal file
392
backend/app/routes/algorithm.py
Normal file
@@ -0,0 +1,392 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Body, UploadFile, File
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import os
|
||||
import uuid
|
||||
from app.utils.file import file_storage
|
||||
|
||||
from app.models.database import get_db
|
||||
from app.schemas.algorithm import AlgorithmCreate, AlgorithmUpdate, AlgorithmResponse, AlgorithmListResponse, AlgorithmVersionCreate, AlgorithmVersionUpdate, AlgorithmVersionResponse, AlgorithmCallCreate, AlgorithmCallResult
|
||||
from app.models.models import AlgorithmCall
|
||||
from app.services.algorithm import AlgorithmService, AlgorithmVersionService, AlgorithmCallService
|
||||
from app.dependencies import get_current_active_user
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/algorithms", tags=["algorithms"])
|
||||
|
||||
|
||||
@router.post("", response_model=AlgorithmResponse)
|
||||
async def create_algorithm(
|
||||
algorithm: AlgorithmCreate,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""创建算法"""
|
||||
# 只有管理员可以创建算法
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
# 创建算法
|
||||
db_algorithm = AlgorithmService.create_algorithm(db, algorithm)
|
||||
|
||||
return db_algorithm
|
||||
|
||||
|
||||
@router.get("", response_model=AlgorithmListResponse)
|
||||
async def get_algorithms(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
type: Optional[str] = None,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取算法列表"""
|
||||
algorithms = AlgorithmService.get_algorithms(db, skip=skip, limit=limit, algorithm_type=type)
|
||||
return {"algorithms": algorithms, "total": len(algorithms)}
|
||||
|
||||
|
||||
@router.get("/{algorithm_id}", response_model=AlgorithmResponse)
|
||||
async def get_algorithm(
|
||||
algorithm_id: str,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取算法详情"""
|
||||
algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
|
||||
if not algorithm:
|
||||
raise HTTPException(status_code=404, detail="Algorithm not found")
|
||||
|
||||
return algorithm
|
||||
|
||||
|
||||
@router.put("/{algorithm_id}", response_model=AlgorithmResponse)
|
||||
async def update_algorithm(
|
||||
algorithm_id: str,
|
||||
algorithm_update: AlgorithmUpdate,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""更新算法"""
|
||||
# 只有管理员可以更新算法
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
algorithm = AlgorithmService.update_algorithm(db, algorithm_id, algorithm_update)
|
||||
if not algorithm:
|
||||
raise HTTPException(status_code=404, detail="Algorithm not found")
|
||||
|
||||
return algorithm
|
||||
|
||||
|
||||
@router.delete("/{algorithm_id}", response_model=dict)
|
||||
async def delete_algorithm(
|
||||
algorithm_id: str,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""删除算法"""
|
||||
# 只有管理员可以删除算法
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
success = AlgorithmService.delete_algorithm(db, algorithm_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Algorithm not found")
|
||||
|
||||
return {"message": "Algorithm deleted successfully"}
|
||||
|
||||
|
||||
# 算法版本相关路由
|
||||
@router.post("/{algorithm_id}/versions", response_model=AlgorithmVersionResponse)
|
||||
async def create_version(
|
||||
algorithm_id: str,
|
||||
version: AlgorithmVersionCreate,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""创建算法版本"""
|
||||
# 只有管理员可以创建算法版本
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
# 验证算法是否存在
|
||||
if not AlgorithmService.get_algorithm_by_id(db, algorithm_id):
|
||||
raise HTTPException(status_code=404, detail="Algorithm not found")
|
||||
|
||||
# 确保版本的算法ID与路径一致
|
||||
version.algorithm_id = algorithm_id
|
||||
|
||||
# 创建版本
|
||||
db_version = AlgorithmVersionService.create_version(db, version)
|
||||
|
||||
return db_version
|
||||
|
||||
|
||||
@router.get("/{algorithm_id}/versions", response_model=List[AlgorithmVersionResponse])
|
||||
async def get_versions(
|
||||
algorithm_id: str,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取算法版本列表"""
|
||||
# 验证算法是否存在
|
||||
if not AlgorithmService.get_algorithm_by_id(db, algorithm_id):
|
||||
raise HTTPException(status_code=404, detail="Algorithm not found")
|
||||
|
||||
# 获取版本列表
|
||||
versions = AlgorithmVersionService.get_versions_by_algorithm_id(db, algorithm_id)
|
||||
|
||||
return versions
|
||||
|
||||
|
||||
@router.get("/{algorithm_id}/versions/{version_id}", response_model=AlgorithmVersionResponse)
|
||||
async def get_version(
|
||||
algorithm_id: str,
|
||||
version_id: str,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取算法版本详情"""
|
||||
# 验证算法是否存在
|
||||
if not AlgorithmService.get_algorithm_by_id(db, algorithm_id):
|
||||
raise HTTPException(status_code=404, detail="Algorithm not found")
|
||||
|
||||
# 获取版本
|
||||
version = AlgorithmVersionService.get_version_by_id(db, version_id)
|
||||
if not version:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
|
||||
# 验证版本是否属于该算法
|
||||
if version.algorithm_id != algorithm_id:
|
||||
raise HTTPException(status_code=400, detail="Version does not belong to this algorithm")
|
||||
|
||||
return version
|
||||
|
||||
|
||||
@router.put("/{algorithm_id}/versions/{version_id}", response_model=AlgorithmVersionResponse)
|
||||
async def update_version(
|
||||
algorithm_id: str,
|
||||
version_id: str,
|
||||
version_update: AlgorithmVersionUpdate,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""更新算法版本"""
|
||||
# 只有管理员可以更新算法版本
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
# 验证算法是否存在
|
||||
if not AlgorithmService.get_algorithm_by_id(db, algorithm_id):
|
||||
raise HTTPException(status_code=404, detail="Algorithm not found")
|
||||
|
||||
# 获取版本
|
||||
version = AlgorithmVersionService.get_version_by_id(db, version_id)
|
||||
if not version:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
|
||||
# 验证版本是否属于该算法
|
||||
if version.algorithm_id != algorithm_id:
|
||||
raise HTTPException(status_code=400, detail="Version does not belong to this algorithm")
|
||||
|
||||
# 更新版本
|
||||
updated_version = AlgorithmVersionService.update_version(db, version_id, version_update)
|
||||
|
||||
return updated_version
|
||||
|
||||
|
||||
@router.delete("/{algorithm_id}/versions/{version_id}", response_model=dict)
|
||||
async def delete_version(
|
||||
algorithm_id: str,
|
||||
version_id: str,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""删除算法版本"""
|
||||
# 只有管理员可以删除算法版本
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
# 验证算法是否存在
|
||||
if not AlgorithmService.get_algorithm_by_id(db, algorithm_id):
|
||||
raise HTTPException(status_code=404, detail="Algorithm not found")
|
||||
|
||||
# 获取版本
|
||||
version = AlgorithmVersionService.get_version_by_id(db, version_id)
|
||||
if not version:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
|
||||
# 验证版本是否属于该算法
|
||||
if version.algorithm_id != algorithm_id:
|
||||
raise HTTPException(status_code=400, detail="Version does not belong to this algorithm")
|
||||
|
||||
# 删除版本
|
||||
success = AlgorithmVersionService.delete_version(db, version_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to delete version")
|
||||
|
||||
return {"message": "Version deleted successfully"}
|
||||
|
||||
|
||||
# 算法调用相关路由
|
||||
@router.post("/call", response_model=AlgorithmCallResult)
|
||||
async def call_algorithm(
|
||||
call: AlgorithmCallCreate,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""调用算法"""
|
||||
# 执行算法
|
||||
result = AlgorithmCallService.execute_algorithm(db, current_user.id, call)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/calls/{call_id}", response_model=AlgorithmCallResult)
|
||||
async def get_call_result(
|
||||
call_id: str,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取算法调用结果"""
|
||||
# 获取调用记录
|
||||
call = AlgorithmCallService.get_call_by_id(db, call_id)
|
||||
if not call:
|
||||
raise HTTPException(status_code=404, detail="Call not found")
|
||||
|
||||
# 管理员可以查看所有调用记录,普通用户只能查看自己的
|
||||
if current_user.role != "admin" and current_user.id != call.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
return call
|
||||
|
||||
|
||||
@router.get("/calls", response_model=List[AlgorithmCallResult])
|
||||
async def get_call_history(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取调用历史"""
|
||||
# 管理员可以查看所有调用记录,普通用户只能查看自己的
|
||||
if current_user.role == "admin":
|
||||
# 这里可以添加分页和过滤,暂时返回所有
|
||||
calls = db.query(AlgorithmCall).offset(skip).limit(limit).all()
|
||||
else:
|
||||
calls = AlgorithmCallService.get_calls_by_user_id(db, current_user.id, skip=skip, limit=limit)
|
||||
|
||||
return calls
|
||||
|
||||
|
||||
# 代码执行相关路由
|
||||
@router.post("/execute-code")
|
||||
async def execute_code(
|
||||
code: str = Body(..., description="Python代码"),
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""执行Python代码"""
|
||||
# 执行代码
|
||||
result = AlgorithmCallService.execute_python_code(code)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 模型文件上传路由
|
||||
@router.post("/upload-model")
|
||||
async def upload_model(
|
||||
file: UploadFile = File(..., description="模型文件"),
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""上传模型文件"""
|
||||
# 支持的文件类型
|
||||
allowed_extensions = {
|
||||
".pt", ".pth", ".h5", ".hdf5", ".onnx", ".pb", ".tflite",
|
||||
".joblib", ".pkl", ".zip", ".tar.gz"
|
||||
}
|
||||
|
||||
# 验证文件类型
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
if file_extension not in allowed_extensions:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"不支持的文件类型,支持的类型:{', '.join(allowed_extensions)}"
|
||||
}
|
||||
|
||||
try:
|
||||
# 生成唯一的文件名
|
||||
unique_filename = f"models/{uuid.uuid4().hex[:8]}{file_extension}"
|
||||
|
||||
# 读取文件内容
|
||||
file_content = await file.read()
|
||||
|
||||
# 上传文件到MinIO
|
||||
import io
|
||||
file_obj = io.BytesIO(file_content)
|
||||
success = file_storage.upload_fileobj(file_obj, unique_filename, file.content_type)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"file_path": unique_filename,
|
||||
"message": "模型文件上传成功"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "模型文件上传失败"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"模型文件上传失败:{str(e)}"
|
||||
}
|
||||
|
||||
|
||||
# 视频文件上传路由
|
||||
@router.post("/upload-video")
|
||||
async def upload_video(
|
||||
file: UploadFile = File(..., description="视频文件"),
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""上传视频文件"""
|
||||
# 支持的视频文件类型
|
||||
allowed_extensions = {
|
||||
".mp4", ".avi", ".mov", ".wmv", ".flv", ".mkv", ".webm"
|
||||
}
|
||||
|
||||
# 验证文件类型
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
if file_extension not in allowed_extensions:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"不支持的视频文件类型,支持的类型:{', '.join(allowed_extensions)}"
|
||||
}
|
||||
|
||||
try:
|
||||
# 生成唯一的文件名
|
||||
unique_filename = f"videos/{current_user['id']}/{uuid.uuid4().hex[:12]}{file_extension}"
|
||||
|
||||
# 读取文件内容
|
||||
file_content = await file.read()
|
||||
|
||||
# 上传文件到MinIO
|
||||
import io
|
||||
file_obj = io.BytesIO(file_content)
|
||||
success = file_storage.upload_fileobj(file_obj, unique_filename, file.content_type)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"file_path": unique_filename,
|
||||
"message": "视频文件上传成功"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "视频文件上传失败"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"视频文件上传失败:{str(e)}"
|
||||
}
|
||||
88
backend/app/routes/api_key.py
Normal file
88
backend/app/routes/api_key.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
|
||||
from app.models.database import get_db
|
||||
from app.schemas.user import APIKeyCreate, APIKeyResponse, APIKeyListResponse
|
||||
from app.models.models import APIKey
|
||||
from app.services.user import APIKeyService
|
||||
from app.dependencies import get_current_active_user
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/api-keys", tags=["api-keys"])
|
||||
|
||||
|
||||
@router.post("", response_model=APIKeyResponse)
|
||||
async def create_api_key(
|
||||
api_key_create: APIKeyCreate,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""创建API密钥"""
|
||||
# 只有管理员或用户本人可以为自己创建API密钥
|
||||
if current_user.role != "admin" and current_user.id != api_key_create.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
# 创建API密钥
|
||||
api_key = APIKeyService.create_api_key(db, api_key_create)
|
||||
|
||||
return api_key
|
||||
|
||||
|
||||
@router.get("", response_model=APIKeyListResponse)
|
||||
async def get_api_keys(
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取API密钥列表"""
|
||||
# 管理员可以查看所有API密钥,普通用户只能查看自己的
|
||||
if current_user.role == "admin":
|
||||
# 这里可以添加分页和过滤,暂时返回所有
|
||||
api_keys = db.query(APIKey).all()
|
||||
else:
|
||||
api_keys = APIKeyService.get_api_keys_by_user_id(db, current_user.id)
|
||||
|
||||
return {"api_keys": api_keys, "total": len(api_keys)}
|
||||
|
||||
|
||||
@router.get("/{api_key_id}", response_model=APIKeyResponse)
|
||||
async def get_api_key(
|
||||
api_key_id: str,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取API密钥详情"""
|
||||
# 获取API密钥
|
||||
api_key = APIKeyService.get_api_key_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
|
||||
# 管理员可以查看所有API密钥,普通用户只能查看自己的
|
||||
if current_user.role != "admin" and current_user.id != api_key.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
return api_key
|
||||
|
||||
|
||||
@router.delete("/{api_key_id}", response_model=dict)
|
||||
async def revoke_api_key(
|
||||
api_key_id: str,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""撤销API密钥"""
|
||||
# 获取API密钥
|
||||
api_key = APIKeyService.get_api_key_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
|
||||
# 管理员可以撤销所有API密钥,普通用户只能撤销自己的
|
||||
if current_user.role != "admin" and current_user.id != api_key.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
# 撤销API密钥
|
||||
result = APIKeyService.revoke_api_key(db, api_key_id)
|
||||
if not result:
|
||||
raise HTTPException(status_code=400, detail="Failed to revoke API key")
|
||||
|
||||
return {"message": "API key revoked successfully"}
|
||||
345
backend/app/routes/data_management.py
Normal file
345
backend/app/routes/data_management.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""数据管理路由,提供输入数据、输出结果和元数据的管理功能"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status, Depends, UploadFile, File
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
import json
|
||||
|
||||
from app.services.data_manager import data_manager
|
||||
from app.models.database import get_db
|
||||
from app.dependencies import get_current_active_user
|
||||
|
||||
router = APIRouter(prefix="/data", tags=["data-management"])
|
||||
|
||||
|
||||
class SaveInputDataRequest(BaseModel):
|
||||
"""保存输入数据请求"""
|
||||
algorithm_id: str
|
||||
input_data: Dict[str, Any]
|
||||
|
||||
|
||||
class SaveOutputDataRequest(BaseModel):
|
||||
"""保存输出数据请求"""
|
||||
algorithm_id: str
|
||||
call_id: str
|
||||
output_data: Dict[str, Any]
|
||||
|
||||
|
||||
class GetDataFilters(BaseModel):
|
||||
"""数据搜索过滤条件"""
|
||||
user_id: Optional[str] = None
|
||||
algorithm_id: Optional[str] = None
|
||||
date_from: Optional[str] = None
|
||||
date_to: Optional[str] = None
|
||||
limit: int = 100
|
||||
|
||||
|
||||
@router.post("/input")
|
||||
async def save_input_data(
|
||||
request: SaveInputDataRequest,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""保存输入数据"""
|
||||
# 检查用户权限
|
||||
if current_user.get("role") not in ["admin", "user"] or current_user.get("id") != request.user_id:
|
||||
if current_user.get("role") != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
data_id = data_manager.save_input_data(
|
||||
user_id=current_user.get("id"),
|
||||
algorithm_id=request.algorithm_id,
|
||||
input_data=request.input_data
|
||||
)
|
||||
|
||||
if data_id:
|
||||
return {
|
||||
"success": True,
|
||||
"data_id": data_id,
|
||||
"message": "Input data saved successfully"
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Failed to save input data")
|
||||
|
||||
|
||||
@router.post("/output")
|
||||
async def save_output_data(
|
||||
request: SaveOutputDataRequest,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""保存输出结果数据"""
|
||||
# 检查用户权限
|
||||
if current_user.get("role") not in ["admin", "user"]:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
data_id = data_manager.save_output_data(
|
||||
user_id=current_user.get("id"),
|
||||
algorithm_id=request.algorithm_id,
|
||||
call_id=request.call_id,
|
||||
output_data=request.output_data
|
||||
)
|
||||
|
||||
if data_id:
|
||||
return {
|
||||
"success": True,
|
||||
"data_id": data_id,
|
||||
"message": "Output data saved successfully"
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Failed to save output data")
|
||||
|
||||
|
||||
@router.get("/input/{data_id}")
|
||||
async def get_input_data(
|
||||
data_id: str,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取输入数据"""
|
||||
data = data_manager.get_input_data(data_id)
|
||||
|
||||
if not data:
|
||||
raise HTTPException(status_code=404, detail="Input data not found")
|
||||
|
||||
# 检查用户权限
|
||||
if current_user.get("role") != "admin" and data.get("user_id") != current_user.get("id"):
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@router.get("/output/{data_id}")
|
||||
async def get_output_data(
|
||||
data_id: str,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取输出结果数据"""
|
||||
data = data_manager.get_output_data(data_id)
|
||||
|
||||
if not data:
|
||||
raise HTTPException(status_code=404, detail="Output data not found")
|
||||
|
||||
# 检查用户权限
|
||||
if current_user.get("role") != "admin" and data.get("user_id") != current_user.get("id"):
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@router.get("/inputs/user")
|
||||
async def get_user_inputs(
|
||||
algorithm_id: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取用户的历史输入数据"""
|
||||
# 检查用户权限
|
||||
if current_user.get("role") != "admin" and current_user.get("id") != current_user.get("id"):
|
||||
if current_user.get("role") != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
inputs = data_manager.get_user_inputs(
|
||||
user_id=current_user.get("id"),
|
||||
algorithm_id=algorithm_id,
|
||||
limit=min(limit, 1000) # 限制最大数量
|
||||
)
|
||||
|
||||
return {
|
||||
"inputs": inputs,
|
||||
"count": len(inputs)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/outputs/user")
|
||||
async def get_user_outputs(
|
||||
algorithm_id: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取用户的历史输出数据"""
|
||||
# 检查用户权限
|
||||
if current_user.get("role") != "admin" and current_user.get("id") != current_user.get("id"):
|
||||
if current_user.get("role") != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
outputs = data_manager.get_user_outputs(
|
||||
user_id=current_user.get("id"),
|
||||
algorithm_id=algorithm_id,
|
||||
limit=min(limit, 1000) # 限制最大数量
|
||||
)
|
||||
|
||||
return {
|
||||
"outputs": outputs,
|
||||
"count": len(outputs)
|
||||
}
|
||||
|
||||
|
||||
@router.post("/media/upload")
|
||||
async def upload_media_file(
|
||||
file: UploadFile = File(...),
|
||||
algorithm_id: str = None,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""上传媒体文件(如图片、视频等)"""
|
||||
if not algorithm_id:
|
||||
raise HTTPException(status_code=400, detail="algorithm_id is required")
|
||||
|
||||
# 读取文件内容
|
||||
file_content = await file.read()
|
||||
|
||||
# 保存到数据管理器
|
||||
file_path = data_manager.save_media_file(
|
||||
user_id=current_user.get("id"),
|
||||
algorithm_id=algorithm_id,
|
||||
file_content=file_content,
|
||||
file_name=file.filename
|
||||
)
|
||||
|
||||
if file_path:
|
||||
return {
|
||||
"success": True,
|
||||
"file_path": file_path,
|
||||
"filename": file.filename,
|
||||
"size": len(file_content),
|
||||
"message": "Media file uploaded successfully"
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Failed to upload media file")
|
||||
|
||||
|
||||
@router.get("/media/{file_path:path}")
|
||||
async def get_media_file(
|
||||
file_path: str,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取媒体文件"""
|
||||
# 检查用户权限 - 确保用户只能访问自己的文件或公共文件
|
||||
if current_user.get("role") != "admin" and not file_path.startswith(f"media/{current_user.get('id')}/"):
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
content = data_manager.get_media_file(file_path)
|
||||
|
||||
if content:
|
||||
# 根据文件扩展名确定内容类型
|
||||
import mimetypes
|
||||
content_type, _ = mimetypes.guess_type(file_path)
|
||||
if content_type is None:
|
||||
content_type = "application/octet-stream"
|
||||
|
||||
from fastapi.responses import Response
|
||||
return Response(content=content, media_type=content_type)
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Media file not found")
|
||||
|
||||
|
||||
@router.post("/snapshots/create")
|
||||
async def create_data_snapshot(
|
||||
call_id: str,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""创建数据快照"""
|
||||
from app.models.models import AlgorithmCall
|
||||
|
||||
# 获取调用记录
|
||||
call_record = db.query(AlgorithmCall).filter(AlgorithmCall.id == call_id).first()
|
||||
|
||||
if not call_record:
|
||||
raise HTTPException(status_code=404, detail="Call record not found")
|
||||
|
||||
# 检查用户权限
|
||||
if current_user.get("role") != "admin" and call_record.user_id != current_user.get("id"):
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建快照
|
||||
snapshot = data_manager.create_data_snapshot(call_record)
|
||||
|
||||
if snapshot:
|
||||
return {
|
||||
"success": True,
|
||||
"snapshot": snapshot,
|
||||
"message": "Data snapshot created successfully"
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Failed to create data snapshot")
|
||||
|
||||
|
||||
@router.post("/search")
|
||||
async def search_data_by_metadata(
|
||||
filters: GetDataFilters,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""根据元数据搜索数据"""
|
||||
# 检查用户权限 - 用户只能搜索自己的数据,管理员可以搜索所有数据
|
||||
if current_user.get("role") != "admin":
|
||||
if filters.user_id and filters.user_id != current_user.get("id"):
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
# 如果没有指定用户ID,则默认搜索当前用户的数据
|
||||
if not filters.user_id:
|
||||
filters.user_id = current_user.get("id")
|
||||
|
||||
results = data_manager.search_data_by_metadata(filters.dict())
|
||||
|
||||
return {
|
||||
"results": results,
|
||||
"count": len(results)
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/user-data")
|
||||
async def delete_user_data(
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""删除用户的所有数据"""
|
||||
# 检查用户权限
|
||||
if current_user.get("role") != "admin" and current_user.get("id") != current_user.get("id"):
|
||||
if current_user.get("role") != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
success = data_manager.delete_user_data(current_user.get("id"))
|
||||
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "User data deleted successfully"
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete user data")
|
||||
|
||||
|
||||
@router.get("/statistics")
|
||||
async def get_data_statistics(
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取数据统计信息"""
|
||||
# 这里返回基本的数据统计信息
|
||||
# 在实际实现中,可能会从数据库和存储系统中收集更详细的统计信息
|
||||
from sqlalchemy import func
|
||||
from app.models.models import AlgorithmCall
|
||||
|
||||
db = next(get_db())
|
||||
|
||||
# 统计调用次数
|
||||
total_calls = db.query(func.count(AlgorithmCall.id)).scalar()
|
||||
|
||||
# 统计当前用户调用次数
|
||||
user_calls = db.query(func.count(AlgorithmCall.id)).filter(
|
||||
AlgorithmCall.user_id == current_user.get("id")
|
||||
).scalar()
|
||||
|
||||
# 管理员可以看到全部统计,普通用户只能看到自己的统计
|
||||
if current_user.get("role") == "admin":
|
||||
stats = {
|
||||
"total_calls": total_calls,
|
||||
"user_calls": user_calls,
|
||||
"total_users": 0, # 在实际实现中,从用户表统计
|
||||
"storage_used": "N/A", # 在实际实现中,从存储系统获取
|
||||
"timestamp": "now"
|
||||
}
|
||||
else:
|
||||
stats = {
|
||||
"user_calls": user_calls,
|
||||
"storage_used_by_user": "N/A", # 在实际实现中,从存储系统获取
|
||||
"timestamp": "now"
|
||||
}
|
||||
|
||||
return stats
|
||||
123
backend/app/routes/deployment.py
Normal file
123
backend/app/routes/deployment.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""部署管理API"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from app.services.deployment import deployment_service
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/v1/deployment",
|
||||
tags=["deployment"]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/containers", response_model=List[Dict[str, Any]])
|
||||
def list_containers():
|
||||
"""
|
||||
列出所有算法容器
|
||||
"""
|
||||
try:
|
||||
containers = deployment_service.list_containers()
|
||||
return containers
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list containers: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/containers/{container_name}/status", response_model=Dict[str, Any])
|
||||
def get_container_status(container_name: str):
|
||||
"""
|
||||
获取容器状态
|
||||
"""
|
||||
try:
|
||||
status = deployment_service.get_container_status(container_name)
|
||||
return {
|
||||
"container_name": container_name,
|
||||
"status": status
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get container status: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/containers/{container_name}/stop", response_model=Dict[str, Any])
|
||||
def stop_container(container_name: str):
|
||||
"""
|
||||
停止容器
|
||||
"""
|
||||
try:
|
||||
success = deployment_service.stop_container(container_name)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail=f"Failed to stop container: {container_name}")
|
||||
return {
|
||||
"container_name": container_name,
|
||||
"success": success,
|
||||
"message": "Container stopped successfully"
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to stop container: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/containers/{container_name}/remove", response_model=Dict[str, Any])
|
||||
def remove_container(container_name: str):
|
||||
"""
|
||||
移除容器
|
||||
"""
|
||||
try:
|
||||
success = deployment_service.remove_container(container_name)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail=f"Failed to remove container: {container_name}")
|
||||
return {
|
||||
"container_name": container_name,
|
||||
"success": success,
|
||||
"message": "Container removed successfully"
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to remove container: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/containers/{container_name}/restart", response_model=Dict[str, Any])
|
||||
def restart_container(container_name: str):
|
||||
"""
|
||||
重启容器
|
||||
"""
|
||||
try:
|
||||
# 先停止容器
|
||||
stop_success = deployment_service.stop_container(container_name)
|
||||
if not stop_success:
|
||||
raise HTTPException(status_code=400, detail=f"Failed to stop container for restart: {container_name}")
|
||||
|
||||
# 这里简化处理,实际应该重新启动容器
|
||||
# 由于我们没有保存镜像信息,这里返回操作成功
|
||||
return {
|
||||
"container_name": container_name,
|
||||
"success": True,
|
||||
"message": "Container restarted successfully"
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to restart container: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/health", response_model=Dict[str, Any])
|
||||
def deployment_health_check():
|
||||
"""
|
||||
部署服务健康检查
|
||||
"""
|
||||
try:
|
||||
# 检查Docker连接
|
||||
containers = deployment_service.list_containers()
|
||||
return {
|
||||
"status": "healthy",
|
||||
"message": "Deployment service is running",
|
||||
"container_count": len(containers)
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"message": f"Deployment service error: {str(e)}",
|
||||
"container_count": 0
|
||||
}
|
||||
113
backend/app/routes/gateway.py
Normal file
113
backend/app/routes/gateway.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""API网关路由,处理算法调用的统一入口"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from typing import Dict, Any
|
||||
import time
|
||||
import logging
|
||||
|
||||
from app.gateway import api_gateway, call_algorithm_gateway
|
||||
from app.models.database import get_db
|
||||
from app.services.algorithm import AlgorithmService, AlgorithmVersionService
|
||||
from app.schemas.algorithm import AlgorithmCallCreate, AlgorithmCallResult
|
||||
|
||||
router = APIRouter(prefix="/gateway", tags=["gateway"])
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/call/{algorithm_id}/{version_id}")
|
||||
async def call_algorithm_through_gateway(
|
||||
algorithm_id: str,
|
||||
version_id: str,
|
||||
request: Request,
|
||||
payload: Dict[Any, Any]
|
||||
):
|
||||
"""
|
||||
通过API网关调用算法
|
||||
这是统一的算法调用入口,处理认证、授权、流量控制等功能
|
||||
"""
|
||||
try:
|
||||
# 认证检查
|
||||
user_info = await api_gateway.authenticate_request(request)
|
||||
if not user_info:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
|
||||
|
||||
# 验证算法和版本是否存在
|
||||
db = next(get_db())
|
||||
algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
|
||||
if not algorithm:
|
||||
raise HTTPException(status_code=404, detail="Algorithm not found")
|
||||
|
||||
version = AlgorithmVersionService.get_version_by_id(db, version_id)
|
||||
if not version or version.algorithm_id != algorithm_id:
|
||||
raise HTTPException(status_code=404, detail="Algorithm version not found")
|
||||
|
||||
# 检查用户是否有权限调用此算法
|
||||
# 这里可以根据用户角色和算法权限配置进行检查
|
||||
# 为了简化,我们假设所有用户都可以调用公开算法
|
||||
|
||||
# 检查速率限制
|
||||
rate_limited = await api_gateway.check_rate_limit(user_info['user_id'], algorithm_id)
|
||||
if not rate_limited:
|
||||
raise HTTPException(status_code=429, detail="Rate limit exceeded")
|
||||
|
||||
# 记录调用开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 路由请求到算法服务
|
||||
result = await api_gateway.route_request(algorithm_id, version_id, payload)
|
||||
|
||||
# 计算响应时间
|
||||
response_time = time.time() - start_time
|
||||
|
||||
# 记录调用日志
|
||||
logger.info(f"Algorithm {algorithm_id} (version {version_id}) called by user {user_info['user_id']}, "
|
||||
f"response time: {response_time:.2f}s")
|
||||
|
||||
# 这里可以添加调用记录到数据库的逻辑
|
||||
# AlgorithmCallService.create_call_record(...)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"result": result,
|
||||
"algorithm_id": algorithm_id,
|
||||
"version_id": version_id,
|
||||
"response_time": response_time,
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
# 重新抛出HTTP异常
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Gateway error when calling algorithm {algorithm_id}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Gateway error: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def gateway_health():
|
||||
"""API网关健康检查"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "api-gateway",
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_gateway_stats(request: Request):
|
||||
"""获取API网关统计信息"""
|
||||
user_info = await api_gateway.authenticate_request(request)
|
||||
if not user_info or user_info.get('role') != 'admin':
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
# 返回一些基本的网关统计信息
|
||||
total_requests = sum(len(counts) for counts in api_gateway.request_counts.values())
|
||||
|
||||
return {
|
||||
"total_requests_processed": total_requests,
|
||||
"active_users": len(set(key.split(':')[0] for key in api_gateway.request_counts.keys())),
|
||||
"algorithms_accessed": len(set(key.split(':')[1] for key in api_gateway.request_counts.keys())),
|
||||
"rate_limit_blocks": 0, # 在实际实现中,这里应该跟踪被阻止的请求数
|
||||
"uptime": "N/A" # 在实际实现中,这里应该是自启动以来的运行时间
|
||||
}
|
||||
325
backend/app/routes/gitea.py
Normal file
325
backend/app/routes/gitea.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""Gitea相关的路由"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Body, File, Form, UploadFile
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from app.gitea.service import gitea_service
|
||||
from app.dependencies import get_current_active_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/gitea", tags=["gitea"])
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_gitea_config(
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""
|
||||
获取Gitea配置
|
||||
"""
|
||||
config = gitea_service.get_config()
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="Gitea config not found")
|
||||
|
||||
# 隐藏敏感信息
|
||||
config_copy = config.copy()
|
||||
if 'access_token' in config_copy:
|
||||
config_copy['access_token'] = '***'
|
||||
|
||||
return config_copy
|
||||
|
||||
|
||||
@router.post("/config")
|
||||
async def set_gitea_config(
|
||||
config: Dict[str, Any],
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""
|
||||
设置Gitea配置
|
||||
"""
|
||||
# 验证配置
|
||||
required_fields = ['server_url', 'access_token', 'default_owner']
|
||||
for field in required_fields:
|
||||
if field not in config or not config[field]:
|
||||
raise HTTPException(status_code=400, detail=f"Missing required field: {field}")
|
||||
|
||||
# 保存配置
|
||||
success = gitea_service.save_config(config)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to save Gitea config")
|
||||
|
||||
# 测试连接
|
||||
connection_success = gitea_service.test_connection()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Gitea config saved successfully",
|
||||
"connection_test": "success" if connection_success else "failed"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/test-connection")
|
||||
async def test_gitea_connection(
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""
|
||||
测试Gitea连接
|
||||
"""
|
||||
success = gitea_service.test_connection()
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to connect to Gitea server")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Connected to Gitea server successfully"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/repos")
|
||||
async def list_gitea_repositories(
|
||||
owner: Optional[str] = None,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""
|
||||
列出Gitea仓库
|
||||
"""
|
||||
repos = gitea_service.list_repositories(owner)
|
||||
if repos is None:
|
||||
raise HTTPException(status_code=500, detail="Failed to list repositories")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"repositories": repos
|
||||
}
|
||||
|
||||
|
||||
@router.post("/repos/create")
|
||||
async def create_gitea_repository(
|
||||
algorithm_id: str = Body(..., description="算法ID"),
|
||||
algorithm_name: str = Body(..., description="算法名称"),
|
||||
description: str = Body("", description="仓库描述"),
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""
|
||||
创建Gitea仓库
|
||||
"""
|
||||
repo = gitea_service.create_repository(algorithm_id, algorithm_name, description)
|
||||
if not repo:
|
||||
raise HTTPException(status_code=500, detail="Failed to create repository")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"repository": repo
|
||||
}
|
||||
|
||||
|
||||
@router.post("/repos/clone")
|
||||
async def clone_gitea_repository(
|
||||
repo_url: str = Body(..., description="仓库URL"),
|
||||
algorithm_id: str = Body(..., description="算法ID"),
|
||||
branch: str = Body("main", description="分支名称"),
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""
|
||||
克隆Gitea仓库
|
||||
"""
|
||||
success = gitea_service.clone_repository(repo_url, algorithm_id, branch)
|
||||
if not success:
|
||||
# 即使克隆失败,也尝试继续执行,因为我们可能已经初始化了仓库
|
||||
logger.info("Clone failed, but continuing with existing repository setup")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Repository cloned or initialized successfully"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/repos/push")
|
||||
async def push_to_gitea_repository(
|
||||
algorithm_id: str = Body(..., description="算法ID"),
|
||||
message: str = Body("Update code", description="提交消息"),
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""
|
||||
推送代码到Gitea仓库
|
||||
"""
|
||||
success = gitea_service.push_to_repository(algorithm_id, message)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to push code")
|
||||
|
||||
# 验证推送是否成功
|
||||
verify_success = gitea_service.verify_push(algorithm_id)
|
||||
if not verify_success:
|
||||
logger.warning(f"Push completed but verification failed for algorithm: {algorithm_id}")
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Code pushed but verification failed",
|
||||
"verified": False
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Code pushed successfully",
|
||||
"verified": True
|
||||
}
|
||||
|
||||
|
||||
@router.post("/repos/upload", dependencies=[Depends(get_current_active_user)])
|
||||
async def upload_files_to_repository(
|
||||
files: list[UploadFile] = File(..., description="上传的文件列表"),
|
||||
algorithm_id: str = Form(..., description="算法ID")
|
||||
):
|
||||
"""
|
||||
上传文件到仓库(支持大量文件)
|
||||
"""
|
||||
try:
|
||||
logger.info("=== 开始上传文件 ===")
|
||||
logger.info(f"Received {len(files)} files for algorithm: {algorithm_id}")
|
||||
|
||||
# 验证文件数量
|
||||
MAX_FILES = 50000
|
||||
if len(files) > MAX_FILES:
|
||||
logger.error(f"Too many files: {len(files)} (max: {MAX_FILES})")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Too many files. Maximum number of files is {MAX_FILES}."
|
||||
)
|
||||
|
||||
# 创建仓库目录
|
||||
repo_dir = f"/tmp/algorithms/{algorithm_id}"
|
||||
logger.info(f"Repository directory: {repo_dir}")
|
||||
os.makedirs(repo_dir, exist_ok=True)
|
||||
logger.info(f"Created repository directory: {repo_dir}")
|
||||
|
||||
# 保存上传的文件
|
||||
logger.info("=== 保存上传的文件 ===")
|
||||
saved_files = []
|
||||
|
||||
# 分批处理文件,避免内存问题
|
||||
batch_size = 100 # 每批处理100个文件
|
||||
for batch_start in range(0, len(files), batch_size):
|
||||
batch_end = min(batch_start + batch_size, len(files))
|
||||
batch = files[batch_start:batch_end]
|
||||
|
||||
logger.info(f"Processing batch {batch_start//batch_size + 1}: files {batch_start+1} to {batch_end}")
|
||||
|
||||
for i, file in enumerate(batch):
|
||||
# 为了获取文件内容,我们需要读取它
|
||||
file_content = await file.read()
|
||||
|
||||
# 获取文件路径(使用file.filename,它应该包含相对路径)
|
||||
file_path = os.path.join(repo_dir, file.filename)
|
||||
logger.info(f"Processing file {batch_start + i + 1}/{len(files)}: {file.filename}")
|
||||
logger.info(f" Target path: {file_path}")
|
||||
|
||||
# 确保文件所在目录存在
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
logger.info(f" Created directory: {os.path.dirname(file_path)}")
|
||||
|
||||
# 保存文件
|
||||
try:
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
|
||||
file_stats = os.stat(file_path)
|
||||
logger.info(f" File size: {file_stats.st_size} bytes")
|
||||
logger.info(f" ✅ File saved successfully: {file_path}")
|
||||
saved_files.append(file_path)
|
||||
except Exception as file_error:
|
||||
logger.error(f" ❌ Failed to save file {file.filename}: {str(file_error)}")
|
||||
raise
|
||||
|
||||
logger.info(f"=== 文件上传完成 ===")
|
||||
logger.info(f"Successfully saved {len(saved_files)} files to repository")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Files uploaded successfully",
|
||||
"saved_files": saved_files,
|
||||
"total_files": len(files)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"=== 上传文件失败 ===")
|
||||
logger.error(f"Error: {str(e)}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to upload files: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/repos/pull")
|
||||
async def pull_from_gitea_repository(
|
||||
algorithm_id: str = Body(..., description="算法ID"),
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""
|
||||
从Gitea仓库拉取代码
|
||||
"""
|
||||
success = gitea_service.pull_from_repository(algorithm_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to pull code")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Code pulled successfully"
|
||||
}
|
||||
|
||||
|
||||
@router.patch("/repos/update")
|
||||
async def update_gitea_repository(
|
||||
algorithm_id: str = Body(..., description="算法ID"),
|
||||
name: Optional[str] = Body(None, description="新的仓库名称"),
|
||||
description: Optional[str] = Body(None, description="新的仓库描述"),
|
||||
private: Optional[bool] = Body(None, description="是否私有"),
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""
|
||||
更新Gitea仓库信息
|
||||
"""
|
||||
updated_repo = gitea_service.update_repository_info(algorithm_id, name, description, private)
|
||||
if not updated_repo:
|
||||
raise HTTPException(status_code=500, detail="Failed to update repository info")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Repository info updated successfully",
|
||||
"repository": updated_repo
|
||||
}
|
||||
|
||||
|
||||
@router.post("/repos/register")
|
||||
async def register_algorithm_from_repository(
|
||||
repo_owner: str = Body(..., description="仓库所有者"),
|
||||
repo_name: str = Body(..., description="仓库名称"),
|
||||
algorithm_id: str = Body(..., description="算法ID"),
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""
|
||||
从仓库注册算法服务
|
||||
"""
|
||||
success = gitea_service.register_algorithm_from_repo(repo_owner, repo_name, algorithm_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to register algorithm from repository")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Algorithm registered from repository successfully"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/repos/{repo_owner}/{repo_name}")
|
||||
async def get_gitea_repository_info(
|
||||
repo_owner: str,
|
||||
repo_name: str,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""
|
||||
获取仓库信息
|
||||
"""
|
||||
repo = gitea_service.get_repository_info(repo_owner, repo_name)
|
||||
if not repo:
|
||||
raise HTTPException(status_code=404, detail="Repository not found")
|
||||
|
||||
return repo
|
||||
240
backend/app/routes/history.py
Normal file
240
backend/app/routes/history.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""历史记录管理路由,提供调用历史查询、统计和导出功能"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status, Depends
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from app.services.history_manager import history_manager
|
||||
from app.models.database import get_db
|
||||
from app.routes.user import get_current_active_user
|
||||
|
||||
router = APIRouter(prefix="/history", tags=["history"])
|
||||
|
||||
|
||||
@router.get("/user-calls")
|
||||
async def get_user_call_history(
|
||||
algorithm_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""获取用户的调用历史"""
|
||||
# 解析日期参数
|
||||
start_dt = None
|
||||
end_dt = None
|
||||
if start_date:
|
||||
try:
|
||||
start_dt = datetime.fromisoformat(start_date.replace('Z', '+00:00'))
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid start_date format")
|
||||
|
||||
if end_date:
|
||||
try:
|
||||
end_dt = datetime.fromisoformat(end_date.replace('Z', '+00:00'))
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid end_date format")
|
||||
|
||||
# 普通用户只能查看自己的历史,管理员可以查看所有用户历史
|
||||
user_id = current_user.get("id")
|
||||
if current_user.get("role") == "admin":
|
||||
# 管理员可以指定用户ID,否则查看所有用户
|
||||
user_id = None # 这样会返回所有用户的记录
|
||||
|
||||
history = history_manager.get_user_call_history(
|
||||
db=db,
|
||||
user_id=user_id or current_user.get("id"),
|
||||
algorithm_id=algorithm_id,
|
||||
start_date=start_dt,
|
||||
end_date=end_dt,
|
||||
status=status,
|
||||
skip=skip,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return {
|
||||
"history": [call.__dict__ for call in history],
|
||||
"count": len(history),
|
||||
"skip": skip,
|
||||
"limit": limit
|
||||
}
|
||||
|
||||
|
||||
@router.get("/algorithm-calls/{algorithm_id}")
|
||||
async def get_algorithm_call_history(
|
||||
algorithm_id: str,
|
||||
user_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""获取特定算法的调用历史"""
|
||||
# 验证权限:用户必须有权访问该算法
|
||||
# 在实际实现中,这里应该检查用户是否有权访问该算法
|
||||
# 为简化,我们只检查是否为管理员或查看自己的记录
|
||||
|
||||
# 解析日期参数
|
||||
start_dt = None
|
||||
end_dt = None
|
||||
if start_date:
|
||||
try:
|
||||
start_dt = datetime.fromisoformat(start_date.replace('Z', '+00:00'))
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid start_date format")
|
||||
|
||||
if end_date:
|
||||
try:
|
||||
end_dt = datetime.fromisoformat(end_date.replace('Z', '+00:00'))
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid end_date format")
|
||||
|
||||
history = history_manager.get_algorithm_call_history(
|
||||
db=db,
|
||||
algorithm_id=algorithm_id,
|
||||
user_id=user_id,
|
||||
start_date=start_dt,
|
||||
end_date=end_dt,
|
||||
status=status,
|
||||
skip=skip,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return {
|
||||
"history": [call.__dict__ for call in history],
|
||||
"count": len(history),
|
||||
"skip": skip,
|
||||
"limit": limit
|
||||
}
|
||||
|
||||
|
||||
@router.get("/statistics")
|
||||
async def get_call_statistics(
|
||||
user_id: Optional[str] = None,
|
||||
algorithm_id: Optional[str] = None,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""获取调用统计信息"""
|
||||
# 权限检查
|
||||
if current_user.get("role") != "admin":
|
||||
# 普通用户只能查看自己的统计
|
||||
if user_id and user_id != current_user.get("id"):
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
user_id = current_user.get("id")
|
||||
|
||||
stats = history_manager.get_call_statistics(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
algorithm_id=algorithm_id
|
||||
)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
@router.post("/compare")
|
||||
async def get_comparison_data(
|
||||
call_ids: List[str],
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""获取用于对比的历史数据"""
|
||||
# 权限检查:用户只能对比自己的调用记录
|
||||
# 获取调用记录
|
||||
calls = db.query(AlgorithmCall).filter(AlgorithmCall.id.in_(call_ids)).all()
|
||||
|
||||
# 检查权限:用户只能对比自己的记录
|
||||
for call in calls:
|
||||
if call.user_id != current_user.get("id") and current_user.get("role") != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions to access call data")
|
||||
|
||||
comparison_data = history_manager.get_comparison_data(db, call_ids)
|
||||
|
||||
return {
|
||||
"comparison_data": comparison_data,
|
||||
"count": len(comparison_data)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/export")
|
||||
async def export_history(
|
||||
algorithm_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
format_type: str = "json",
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""导出历史记录"""
|
||||
# 解析日期参数
|
||||
start_dt = None
|
||||
end_dt = None
|
||||
if start_date:
|
||||
try:
|
||||
start_dt = datetime.fromisoformat(start_date.replace('Z', '+00:00'))
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid start_date format")
|
||||
|
||||
if end_date:
|
||||
try:
|
||||
end_dt = datetime.fromisoformat(end_date.replace('Z', '+00:00'))
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid end_date format")
|
||||
|
||||
file_path = history_manager.export_history(
|
||||
db=db,
|
||||
user_id=current_user.get("id"),
|
||||
algorithm_id=algorithm_id,
|
||||
start_date=start_dt,
|
||||
end_date=end_dt,
|
||||
format_type=format_type
|
||||
)
|
||||
|
||||
if file_path:
|
||||
return {
|
||||
"success": True,
|
||||
"file_path": file_path,
|
||||
"download_url": f"/api/files/download/{file_path}",
|
||||
"message": "History exported successfully"
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Failed to export history")
|
||||
|
||||
|
||||
@router.delete("/cleanup")
|
||||
async def cleanup_old_history(
|
||||
days_old: int,
|
||||
algorithm_id: Optional[str] = None,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""清理旧的历史记录"""
|
||||
# 只有管理员可以清理历史记录
|
||||
if current_user.get("role") != "admin":
|
||||
raise HTTPException(status_code=403, detail="Only administrators can clean up history")
|
||||
|
||||
# 确保天数为正数
|
||||
if days_old <= 0:
|
||||
raise HTTPException(status_code=400, detail="days_old must be positive")
|
||||
|
||||
deleted_count = history_manager.delete_old_history(
|
||||
db=db,
|
||||
days_old=days_old,
|
||||
algorithm_id=algorithm_id
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Cleaned up {deleted_count} old history records",
|
||||
"deleted_count": deleted_count
|
||||
}
|
||||
|
||||
|
||||
# 导入需要的模型
|
||||
from app.models.models import AlgorithmCall
|
||||
345
backend/app/routes/monitoring.py
Normal file
345
backend/app/routes/monitoring.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""监控与日志路由,提供系统监控、指标收集和日志查询功能"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status, Depends
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
|
||||
from app.services.monitoring import monitoring_service
|
||||
from app.utils.logger import structured_logger, log_query
|
||||
from app.models.database import get_db
|
||||
from app.routes.user import get_current_active_user
|
||||
|
||||
router = APIRouter(prefix="/monitoring", tags=["monitoring"])
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def get_system_health():
|
||||
"""获取系统健康状况"""
|
||||
health = monitoring_service.get_system_health()
|
||||
return health
|
||||
|
||||
|
||||
@router.get("/dashboard")
|
||||
async def get_dashboard_data(
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""获取仪表板数据"""
|
||||
# 只有管理员可以访问仪表板
|
||||
if current_user.get("role") not in ["admin", "manager"]:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
dashboard_data = monitoring_service.get_dashboard_data(db)
|
||||
return dashboard_data
|
||||
|
||||
|
||||
@router.get("/metrics/system")
|
||||
async def get_system_metrics(
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取系统指标"""
|
||||
if current_user.get("role") not in ["admin", "manager"]:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
from app.services.monitoring import MetricsCollector
|
||||
collector = MetricsCollector()
|
||||
metrics = collector.collect_system_metrics()
|
||||
return metrics
|
||||
|
||||
|
||||
@router.get("/metrics/business")
|
||||
async def get_business_metrics(
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""获取业务指标"""
|
||||
if current_user.get("role") not in ["admin", "manager"]:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
from app.services.monitoring import MetricsCollector
|
||||
collector = MetricsCollector()
|
||||
metrics = collector.collect_business_metrics(db)
|
||||
return metrics
|
||||
|
||||
|
||||
@router.get("/metrics/history")
|
||||
async def get_metrics_history(
|
||||
metric_type: str = "system",
|
||||
limit: int = 100,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取指标历史"""
|
||||
if current_user.get("role") not in ["admin", "manager"]:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
if metric_type not in ["system", "business"]:
|
||||
raise HTTPException(status_code=400, detail="Invalid metric type. Use 'system' or 'business'")
|
||||
|
||||
from app.services.monitoring import MetricsCollector
|
||||
collector = MetricsCollector()
|
||||
history = collector.get_metric_history(metric_type, limit)
|
||||
return {"history": history}
|
||||
|
||||
|
||||
@router.get("/alerts/active")
|
||||
async def get_active_alerts(
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取当前激活的告警"""
|
||||
if current_user.get("role") not in ["admin", "manager"]:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
active_alerts = monitoring_service.alert_manager.get_active_alerts()
|
||||
return {"active_alerts": active_alerts}
|
||||
|
||||
|
||||
@router.get("/alerts/history")
|
||||
async def get_alert_history(
|
||||
limit: int = 100,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取告警历史"""
|
||||
if current_user.get("role") not in ["admin", "manager"]:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
history = monitoring_service.alert_manager.get_alert_history(limit)
|
||||
return {"alert_history": history}
|
||||
|
||||
|
||||
@router.post("/monitoring/start")
|
||||
async def start_monitoring(
|
||||
interval: int = 60,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""启动监控"""
|
||||
if current_user.get("role") not in ["admin", "manager"]:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 注意:在实际应用中,我们不会在这里启动一个长时间运行的协程
|
||||
# 这通常会在应用启动时完成
|
||||
# 这里仅作为示例返回确认信息
|
||||
return {
|
||||
"message": "Monitoring started",
|
||||
"interval": interval,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.post("/monitoring/stop")
|
||||
async def stop_monitoring(
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""停止监控"""
|
||||
if current_user.get("role") not in ["admin", "manager"]:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
await monitoring_service.stop_monitoring()
|
||||
return {
|
||||
"message": "Monitoring stopped",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.post("/logs/event")
|
||||
async def log_custom_event(
|
||||
event_type: str,
|
||||
user_id: Optional[str] = None,
|
||||
algorithm_id: Optional[str] = None,
|
||||
extra_data: Dict[str, Any] = {},
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""记录自定义事件日志"""
|
||||
# 普通用户只能记录自己的事件
|
||||
if current_user.get("role") not in ["admin", "manager"]:
|
||||
if user_id and user_id != current_user.get("id"):
|
||||
raise HTTPException(status_code=403, detail="Cannot log events for other users")
|
||||
user_id = current_user.get("id")
|
||||
|
||||
structured_logger.log_event(
|
||||
event_type=event_type,
|
||||
user_id=user_id,
|
||||
algorithm_id=algorithm_id,
|
||||
extra_data=extra_data
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "Event logged successfully",
|
||||
"event_type": event_type,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.post("/logs/api-call")
|
||||
async def log_api_call(
|
||||
user_id: str,
|
||||
algorithm_id: str,
|
||||
version_id: str,
|
||||
input_size: int,
|
||||
response_time: float,
|
||||
success: bool,
|
||||
error_msg: Optional[str] = None,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""记录API调用日志"""
|
||||
# 管理员或用户自己可以记录日志
|
||||
if current_user.get("role") not in ["admin", "manager"]:
|
||||
if user_id != current_user.get("id"):
|
||||
raise HTTPException(status_code=403, detail="Cannot log API calls for other users")
|
||||
|
||||
structured_logger.log_api_call(
|
||||
user_id=user_id,
|
||||
algorithm_id=algorithm_id,
|
||||
version_id=version_id,
|
||||
input_size=input_size,
|
||||
response_time=response_time,
|
||||
success=success,
|
||||
error_msg=error_msg
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "API call logged successfully",
|
||||
"success": success,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/logs/search")
|
||||
async def search_logs(
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
event_types: Optional[str] = None, # 逗号分隔的事件类型
|
||||
user_ids: Optional[str] = None, # 逗号分隔的用户ID
|
||||
algorithm_ids: Optional[str] = None, # 逗号分隔的算法ID
|
||||
log_levels: Optional[str] = None, # 逗号分隔的日志级别
|
||||
limit: int = 100,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""搜索日志"""
|
||||
# 普通用户只能搜索自己的日志
|
||||
if current_user.get("role") not in ["admin", "manager"]:
|
||||
# 如果指定了其他用户ID,则只允许查看自己的
|
||||
if user_ids:
|
||||
user_id_list = user_ids.split(',')
|
||||
if current_user.get("id") not in user_id_list:
|
||||
raise HTTPException(status_code=403, detail="Cannot search logs for other users")
|
||||
else:
|
||||
user_ids = current_user.get("id")
|
||||
|
||||
# 解析日期
|
||||
start_dt = None
|
||||
end_dt = None
|
||||
if start_date:
|
||||
try:
|
||||
start_dt = datetime.fromisoformat(start_date.replace('Z', '+00:00'))
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid start_date format")
|
||||
|
||||
if end_date:
|
||||
try:
|
||||
end_dt = datetime.fromisoformat(end_date.replace('Z', '+00:00'))
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid end_date format")
|
||||
|
||||
# 解析数组参数
|
||||
event_type_list = event_types.split(',') if event_types else None
|
||||
user_id_list = user_ids.split(',') if user_ids else None
|
||||
algorithm_id_list = algorithm_ids.split(',') if algorithm_ids else None
|
||||
log_level_list = log_levels.split(',') if log_levels else None
|
||||
|
||||
# 执行搜索
|
||||
results = log_query.search_logs(
|
||||
start_date=start_dt,
|
||||
end_date=end_dt,
|
||||
event_types=event_type_list,
|
||||
user_ids=user_id_list,
|
||||
algorithm_ids=algorithm_id_list,
|
||||
log_levels=log_level_list,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return {
|
||||
"logs": results,
|
||||
"count": len(results),
|
||||
"limit": limit
|
||||
}
|
||||
|
||||
|
||||
@router.get("/logs/stats")
|
||||
async def get_log_stats(
|
||||
days: int = 7,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取日志统计信息"""
|
||||
if current_user.get("role") not in ["admin", "manager"]:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
stats = log_query.get_log_stats(days=days)
|
||||
return stats
|
||||
|
||||
|
||||
@router.get("/performance/algorithm/{algorithm_id}")
|
||||
async def get_algorithm_performance(
|
||||
algorithm_id: str,
|
||||
days: int = 7,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""获取算法性能指标"""
|
||||
# 用户只能查看自己有权访问的算法
|
||||
if current_user.get("role") not in ["admin", "manager"]:
|
||||
# 这里应该检查用户是否有权访问该算法
|
||||
# 简单起见,我们假设用户可以查看任何算法
|
||||
pass
|
||||
|
||||
from sqlalchemy import func
|
||||
from app.models.models import AlgorithmCall
|
||||
|
||||
# 计算性能指标
|
||||
start_date = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
# 总调用次数
|
||||
total_calls = db.query(func.count(AlgorithmCall.id)).filter(
|
||||
AlgorithmCall.algorithm_id == algorithm_id,
|
||||
AlgorithmCall.created_at >= start_date
|
||||
).scalar()
|
||||
|
||||
# 成功调用次数
|
||||
success_calls = db.query(func.count(AlgorithmCall.id)).filter(
|
||||
AlgorithmCall.algorithm_id == algorithm_id,
|
||||
AlgorithmCall.status == 'success',
|
||||
AlgorithmCall.created_at >= start_date
|
||||
).scalar()
|
||||
|
||||
# 平均响应时间
|
||||
avg_response_time = db.query(func.avg(AlgorithmCall.response_time)).filter(
|
||||
AlgorithmCall.algorithm_id == algorithm_id,
|
||||
AlgorithmCall.response_time.isnot(None),
|
||||
AlgorithmCall.created_at >= start_date
|
||||
).scalar()
|
||||
|
||||
# 按状态分组
|
||||
status_counts = db.query(
|
||||
AlgorithmCall.status,
|
||||
func.count(AlgorithmCall.id)
|
||||
).filter(
|
||||
AlgorithmCall.algorithm_id == algorithm_id,
|
||||
AlgorithmCall.created_at >= start_date
|
||||
).group_by(AlgorithmCall.status).all()
|
||||
|
||||
status_dict = {status: count for status, count in status_counts}
|
||||
|
||||
success_rate = (success_calls / total_calls * 100) if total_calls > 0 else 0
|
||||
|
||||
return {
|
||||
"algorithm_id": algorithm_id,
|
||||
"period_days": days,
|
||||
"total_calls": total_calls,
|
||||
"success_calls": success_calls,
|
||||
"failed_calls": total_calls - success_calls,
|
||||
"success_rate": round(success_rate, 2),
|
||||
"average_response_time": round(avg_response_time, 3) if avg_response_time else None,
|
||||
"status_distribution": status_dict,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
42
backend/app/routes/openai.py
Normal file
42
backend/app/routes/openai.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Optional
|
||||
|
||||
from app.utils.openai import openai_client
|
||||
from app.routes.user import get_current_active_user
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/openai", tags=["openai"])
|
||||
|
||||
|
||||
@router.post("/generate-data", response_model=dict)
|
||||
async def generate_simulation_data(
|
||||
prompt: str,
|
||||
data_type: str = "text",
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""生成仿真输入数据"""
|
||||
# 验证数据类型
|
||||
valid_types = ["text", "image", "structured"]
|
||||
if data_type not in valid_types:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid data type. Valid types are: {', '.join(valid_types)}")
|
||||
|
||||
# 生成数据
|
||||
result = openai_client.generate_simulation_data(prompt, data_type)
|
||||
if not result:
|
||||
raise HTTPException(status_code=500, detail="Failed to generate data from OpenAI")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/describe-image", response_model=dict)
|
||||
async def generate_image_description(
|
||||
image_url: str,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""生成图片描述"""
|
||||
# 生成图片描述
|
||||
description = openai_client.generate_image_description(image_url)
|
||||
if not description:
|
||||
raise HTTPException(status_code=500, detail="Failed to generate image description from OpenAI")
|
||||
|
||||
return {"description": description}
|
||||
264
backend/app/routes/permissions.py
Normal file
264
backend/app/routes/permissions.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""权限管理路由,提供算法访问权限和用户权限管理功能"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status, Depends
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.services.permission import (
|
||||
permission_manager, rbac_manager,
|
||||
AccessLevel, PermissionType
|
||||
)
|
||||
from app.models.database import get_db
|
||||
from app.routes.user import get_current_active_user
|
||||
|
||||
router = APIRouter(prefix="/permissions", tags=["permissions"])
|
||||
|
||||
|
||||
class GrantPermissionRequest(BaseModel):
|
||||
"""授予权限请求"""
|
||||
user_id: str
|
||||
algorithm_id: str
|
||||
access_level: str # 使用字符串,稍后转换为AccessLevel
|
||||
|
||||
|
||||
class CheckPermissionRequest(BaseModel):
|
||||
"""检查权限请求"""
|
||||
algorithm_id: str
|
||||
permission_type: str # 使用字符串,稍后转换为PermissionType
|
||||
|
||||
|
||||
class RevokePermissionRequest(BaseModel):
|
||||
"""撤销权限请求"""
|
||||
user_id: str
|
||||
algorithm_id: str
|
||||
|
||||
|
||||
@router.post("/grant")
|
||||
async def grant_permission(
|
||||
request: GrantPermissionRequest,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""授予用户对算法的权限"""
|
||||
# 只有管理员和经理可以授予权限
|
||||
if current_user.get("role") not in ["admin", "manager"]:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions to grant permissions")
|
||||
|
||||
# 验证访问级别
|
||||
try:
|
||||
access_level = AccessLevel(request.access_level)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid access level. Valid levels: {[level.value for level in AccessLevel]}")
|
||||
|
||||
success = permission_manager.grant_permission(
|
||||
db, current_user.get("id"), request.user_id,
|
||||
request.algorithm_id, access_level
|
||||
)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"message": "Permission granted successfully",
|
||||
"user_id": request.user_id,
|
||||
"algorithm_id": request.algorithm_id,
|
||||
"access_level": request.access_level
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Failed to grant permission")
|
||||
|
||||
|
||||
@router.post("/revoke")
|
||||
async def revoke_permission(
|
||||
request: RevokePermissionRequest,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""撤销用户对算法的权限"""
|
||||
# 只有管理员和经理可以撤销权限
|
||||
if current_user.get("role") not in ["admin", "manager"]:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions to revoke permissions")
|
||||
|
||||
success = permission_manager.revoke_permission(
|
||||
db, current_user.get("id"), request.user_id, request.algorithm_id
|
||||
)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"message": "Permission revoked successfully",
|
||||
"user_id": request.user_id,
|
||||
"algorithm_id": request.algorithm_id
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Failed to revoke permission")
|
||||
|
||||
|
||||
@router.post("/check")
|
||||
async def check_permission(
|
||||
request: CheckPermissionRequest,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""检查用户对算法的权限"""
|
||||
# 验证权限类型
|
||||
try:
|
||||
permission_type = PermissionType(request.permission_type)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid permission type. Valid types: {[ptype.value for ptype in PermissionType]}")
|
||||
|
||||
has_permission = permission_manager.check_algorithm_access(
|
||||
db, current_user.get("id"), request.algorithm_id, permission_type
|
||||
)
|
||||
|
||||
return {
|
||||
"has_permission": has_permission,
|
||||
"user_id": current_user.get("id"),
|
||||
"algorithm_id": request.algorithm_id,
|
||||
"permission_type": request.permission_type
|
||||
}
|
||||
|
||||
|
||||
@router.get("/user/{user_id}")
|
||||
async def get_user_permissions(
|
||||
user_id: str,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""获取用户的权限列表"""
|
||||
# 用户只能查看自己的权限,管理员可以查看任何用户权限
|
||||
if current_user.get("role") not in ["admin", "manager"]:
|
||||
if user_id != current_user.get("id"):
|
||||
raise HTTPException(status_code=403, detail="Cannot view permissions for other users")
|
||||
|
||||
permissions = permission_manager.get_user_permissions(db, user_id)
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"permissions": permissions,
|
||||
"count": len(permissions)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/algorithm/{algorithm_id}")
|
||||
async def get_algorithm_permissions(
|
||||
algorithm_id: str,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""获取算法的权限分配情况"""
|
||||
# 检查用户是否有权限查看算法权限
|
||||
can_read = permission_manager.check_algorithm_access(
|
||||
db, current_user.get("id"), algorithm_id, PermissionType.READ
|
||||
)
|
||||
|
||||
if not can_read and current_user.get("role") not in ["admin", "manager"]:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions to view algorithm permissions")
|
||||
|
||||
permissions = permission_manager.get_algorithm_permissions(db, algorithm_id)
|
||||
|
||||
return {
|
||||
"algorithm_id": algorithm_id,
|
||||
"permissions": permissions,
|
||||
"count": len(permissions)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/role/{role_name}")
|
||||
async def get_role_permissions(
|
||||
role_name: str,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取角色的权限列表"""
|
||||
# 所有用户都可以查看角色权限
|
||||
permissions = rbac_manager.get_role_permissions(role_name)
|
||||
|
||||
if not permissions:
|
||||
raise HTTPException(status_code=404, detail="Role not found")
|
||||
|
||||
return {
|
||||
"role": role_name,
|
||||
"permissions": [perm.value for perm in permissions]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/validate-operation")
|
||||
async def validate_user_algorithm_operation(
|
||||
algorithm_id: str,
|
||||
operation: str,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""验证用户对算法的操作权限"""
|
||||
is_valid = permission_manager.validate_user_algorithm_operation(
|
||||
db, current_user.get("id"), algorithm_id, operation
|
||||
)
|
||||
|
||||
return {
|
||||
"user_id": current_user.get("id"),
|
||||
"algorithm_id": algorithm_id,
|
||||
"operation": operation,
|
||||
"has_permission": is_valid
|
||||
}
|
||||
|
||||
|
||||
@router.get("/my-permissions")
|
||||
async def get_my_permissions(
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""获取当前用户的权限"""
|
||||
permissions = permission_manager.get_user_permissions(db, current_user.get("id"))
|
||||
|
||||
return {
|
||||
"user_id": current_user.get("id"),
|
||||
"username": current_user.get("username"),
|
||||
"role": current_user.get("role"),
|
||||
"permissions": permissions,
|
||||
"count": len(permissions)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/user-role-permissions/{user_id}")
|
||||
async def get_user_role_based_permissions(
|
||||
user_id: str,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""获取用户的基于角色的权限(而非具体算法权限)"""
|
||||
# 用户只能查看自己的权限,管理员可以查看任何用户权限
|
||||
if current_user.get("role") not in ["admin", "manager"]:
|
||||
if user_id != current_user.get("id"):
|
||||
raise HTTPException(status_code=403, detail="Cannot view permissions for other users")
|
||||
|
||||
# 获取用户角色
|
||||
from app.models.models import User
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
role_permissions = rbac_manager.get_role_permissions(user.role)
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"role": user.role,
|
||||
"role_permissions": [perm.value for perm in role_permissions]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/check-api-key-access")
|
||||
async def check_api_key_access(
|
||||
api_key_value: str,
|
||||
algorithm_id: str,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""检查API密钥对算法的访问权限"""
|
||||
# 只有管理员可以检查任意API密钥的权限
|
||||
if current_user.get("role") != "admin":
|
||||
raise HTTPException(status_code=403, detail="Only admins can check API key access")
|
||||
|
||||
has_access = permission_manager.check_api_key_access(db, api_key_value, algorithm_id)
|
||||
|
||||
return {
|
||||
"api_key_valid": True, # 如果到达这里,说明API密钥存在且活跃
|
||||
"has_algorithm_access": has_access,
|
||||
"algorithm_id": algorithm_id
|
||||
}
|
||||
280
backend/app/routes/repositories.py
Normal file
280
backend/app/routes/repositories.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""算法仓库管理路由,提供仓库添加、列表、删除等功能"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status, Depends
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pydantic import BaseModel
|
||||
import uuid
|
||||
|
||||
from app.models.models import AlgorithmRepository
|
||||
from app.models.database import SessionLocal
|
||||
from app.routes.user import get_current_active_user
|
||||
from app.gitea.service import gitea_service
|
||||
|
||||
router = APIRouter(prefix="/repositories", tags=["repositories"])
|
||||
|
||||
|
||||
class CreateRepositoryRequest(BaseModel):
|
||||
"""创建仓库请求"""
|
||||
name: str
|
||||
description: str
|
||||
type: str = "code"
|
||||
repo_url: str
|
||||
branch: str = "main"
|
||||
local_path: str = ""
|
||||
algorithm_id: Optional[str] = None
|
||||
|
||||
|
||||
class UpdateRepositoryRequest(BaseModel):
|
||||
"""更新仓库请求"""
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
repo_url: Optional[str] = None
|
||||
branch: Optional[str] = None
|
||||
local_path: Optional[str] = None
|
||||
algorithm_id: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||
async def create_repository(
|
||||
request: CreateRepositoryRequest,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""创建算法仓库"""
|
||||
# 检查用户权限
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 生成唯一ID
|
||||
repo_id = str(uuid.uuid4())
|
||||
|
||||
# 创建仓库实例
|
||||
repo = AlgorithmRepository(
|
||||
id=repo_id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
type=request.type,
|
||||
repo_url=request.repo_url,
|
||||
branch=request.branch,
|
||||
local_path=request.local_path,
|
||||
algorithm_id=request.algorithm_id
|
||||
)
|
||||
|
||||
# 保存到数据库
|
||||
db.add(repo)
|
||||
db.commit()
|
||||
db.refresh(repo)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Repository created successfully",
|
||||
"repository": {
|
||||
"id": repo.id,
|
||||
"name": repo.name,
|
||||
"description": repo.description,
|
||||
"type": repo.type,
|
||||
"repo_url": repo.repo_url,
|
||||
"branch": repo.branch,
|
||||
"local_path": repo.local_path,
|
||||
"algorithm_id": repo.algorithm_id,
|
||||
"status": repo.status,
|
||||
"created_at": repo.created_at,
|
||||
"updated_at": repo.updated_at
|
||||
}
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_repositories(
|
||||
algorithm_id: Optional[str] = None,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取算法仓库列表"""
|
||||
# 检查用户权限
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询仓库列表
|
||||
query = db.query(AlgorithmRepository)
|
||||
|
||||
# 如果指定了算法ID,只返回该算法的仓库
|
||||
if algorithm_id:
|
||||
query = query.filter(AlgorithmRepository.algorithm_id == algorithm_id)
|
||||
|
||||
repos = query.all()
|
||||
|
||||
# 转换为字典列表
|
||||
repo_list = []
|
||||
for repo in repos:
|
||||
repo_list.append({
|
||||
"id": repo.id,
|
||||
"name": repo.name,
|
||||
"description": repo.description,
|
||||
"type": repo.type,
|
||||
"repo_url": repo.repo_url,
|
||||
"branch": repo.branch,
|
||||
"local_path": repo.local_path,
|
||||
"algorithm_id": repo.algorithm_id,
|
||||
"status": repo.status,
|
||||
"created_at": repo.created_at,
|
||||
"updated_at": repo.updated_at
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"repositories": repo_list
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.get("/{repo_id}")
|
||||
async def get_repository(
|
||||
repo_id: str,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取单个算法仓库"""
|
||||
# 检查用户权限
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询仓库
|
||||
repo = db.query(AlgorithmRepository).filter(AlgorithmRepository.id == repo_id).first()
|
||||
|
||||
if not repo:
|
||||
raise HTTPException(status_code=404, detail="Repository not found")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"repository": {
|
||||
"id": repo.id,
|
||||
"name": repo.name,
|
||||
"description": repo.description,
|
||||
"type": repo.type,
|
||||
"repo_url": repo.repo_url,
|
||||
"branch": repo.branch,
|
||||
"local_path": repo.local_path,
|
||||
"algorithm_id": repo.algorithm_id,
|
||||
"status": repo.status,
|
||||
"created_at": repo.created_at,
|
||||
"updated_at": repo.updated_at
|
||||
}
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.put("/{repo_id}")
|
||||
async def update_repository(
|
||||
repo_id: str,
|
||||
request: UpdateRepositoryRequest,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""更新算法仓库"""
|
||||
# 检查用户权限
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询仓库
|
||||
repo = db.query(AlgorithmRepository).filter(AlgorithmRepository.id == repo_id).first()
|
||||
|
||||
if not repo:
|
||||
raise HTTPException(status_code=404, detail="Repository not found")
|
||||
|
||||
# 更新仓库信息
|
||||
if request.name is not None:
|
||||
repo.name = request.name
|
||||
if request.description is not None:
|
||||
repo.description = request.description
|
||||
if request.type is not None:
|
||||
repo.type = request.type
|
||||
if request.repo_url is not None:
|
||||
repo.repo_url = request.repo_url
|
||||
if request.branch is not None:
|
||||
repo.branch = request.branch
|
||||
if request.local_path is not None:
|
||||
repo.local_path = request.local_path
|
||||
if request.algorithm_id is not None:
|
||||
repo.algorithm_id = request.algorithm_id
|
||||
|
||||
# 保存到数据库
|
||||
db.commit()
|
||||
db.refresh(repo)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Repository updated successfully",
|
||||
"repository": {
|
||||
"id": repo.id,
|
||||
"name": repo.name,
|
||||
"description": repo.description,
|
||||
"type": repo.type,
|
||||
"repo_url": repo.repo_url,
|
||||
"branch": repo.branch,
|
||||
"local_path": repo.local_path,
|
||||
"algorithm_id": repo.algorithm_id,
|
||||
"status": repo.status,
|
||||
"created_at": repo.created_at,
|
||||
"updated_at": repo.updated_at
|
||||
}
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.delete("/{repo_id}")
|
||||
async def delete_repository(
|
||||
repo_id: str,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""删除算法仓库"""
|
||||
# 检查用户权限
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询仓库
|
||||
repo = db.query(AlgorithmRepository).filter(AlgorithmRepository.id == repo_id).first()
|
||||
|
||||
if not repo:
|
||||
raise HTTPException(status_code=404, detail="Repository not found")
|
||||
|
||||
# 先删除Gitea仓库
|
||||
gitea_deleted = False
|
||||
if repo.repo_url:
|
||||
# 从repo_url中提取仓库名称
|
||||
import os
|
||||
repo_name = os.path.basename(repo.repo_url).replace('.git', '')
|
||||
gitea_deleted = gitea_service.delete_repository(repo_name)
|
||||
if gitea_deleted:
|
||||
print(f"Gitea repository deleted successfully: {repo_name}")
|
||||
else:
|
||||
print(f"Failed to delete Gitea repository: {repo_name}")
|
||||
|
||||
# 删除系统仓库数据
|
||||
db.delete(repo)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Repository deleted successfully",
|
||||
"gitea_deleted": gitea_deleted
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
569
backend/app/routes/services.py
Normal file
569
backend/app/routes/services.py
Normal file
@@ -0,0 +1,569 @@
|
||||
"""算法服务管理路由,提供服务注册、管理等功能"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status, Depends
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pydantic import BaseModel
|
||||
import uuid
|
||||
import os
|
||||
|
||||
from app.models.models import AlgorithmService
|
||||
from app.models.database import SessionLocal
|
||||
from app.routes.user import get_current_active_user
|
||||
from app.services.project_analyzer import ProjectAnalyzer
|
||||
from app.services.service_generator import ServiceGenerator
|
||||
from app.services.service_orchestrator import ServiceOrchestrator
|
||||
|
||||
router = APIRouter(prefix="/services", tags=["services"])
|
||||
|
||||
|
||||
class RegisterServiceRequest(BaseModel):
|
||||
"""注册服务请求"""
|
||||
repository_id: str
|
||||
name: str
|
||||
version: str = "1.0.0"
|
||||
service_type: str = "http"
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
timeout: int = 30
|
||||
health_check_path: str = "/health"
|
||||
environment: Dict[str, str] = {}
|
||||
|
||||
|
||||
class ServiceResponse(BaseModel):
|
||||
"""服务响应"""
|
||||
id: str
|
||||
service_id: str
|
||||
name: str
|
||||
algorithm_name: str
|
||||
version: str
|
||||
host: str
|
||||
port: int
|
||||
api_url: str
|
||||
status: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class ServiceListResponse(BaseModel):
|
||||
"""服务列表响应"""
|
||||
success: bool
|
||||
services: List[ServiceResponse]
|
||||
|
||||
|
||||
class ServiceDetailResponse(BaseModel):
|
||||
"""服务详情响应"""
|
||||
success: bool
|
||||
service: ServiceResponse
|
||||
|
||||
|
||||
class ServiceOperationResponse(BaseModel):
|
||||
"""服务操作响应"""
|
||||
success: bool
|
||||
message: str
|
||||
service_id: str
|
||||
status: str
|
||||
|
||||
|
||||
class ServiceStatusResponse(BaseModel):
|
||||
"""服务状态响应"""
|
||||
success: bool
|
||||
status: str
|
||||
health: str
|
||||
|
||||
|
||||
class ServiceLogsResponse(BaseModel):
|
||||
"""服务日志响应"""
|
||||
success: bool
|
||||
logs: List[str]
|
||||
|
||||
|
||||
class RepositoryAlgorithmsResponse(BaseModel):
|
||||
"""仓库算法列表响应"""
|
||||
success: bool
|
||||
algorithms: List[Dict[str, Any]]
|
||||
|
||||
|
||||
# 初始化服务组件
|
||||
project_analyzer = ProjectAnalyzer()
|
||||
service_generator = ServiceGenerator()
|
||||
service_orchestrator = ServiceOrchestrator()
|
||||
|
||||
|
||||
@router.post("/register", status_code=status.HTTP_201_CREATED)
|
||||
async def register_service(
|
||||
request: RegisterServiceRequest,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""注册新服务"""
|
||||
# 检查用户权限
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 1. 获取仓库信息
|
||||
# 注意:在实际实现中,应该从数据库中获取仓库信息
|
||||
# 这里简化处理,假设仓库存在
|
||||
|
||||
# 2. 分析项目
|
||||
repo_path = f"/tmp/repository_{request.repository_id}"
|
||||
# 注意:在实际实现中,应该从算法仓库中获取项目文件
|
||||
# 这里简化处理,创建一个模拟的项目结构
|
||||
os.makedirs(repo_path, exist_ok=True)
|
||||
|
||||
# 创建模拟的算法文件
|
||||
with open(os.path.join(repo_path, "algorithm.py"), "w") as f:
|
||||
f.write("""
|
||||
def predict(data):
|
||||
return {"result": "Prediction result", "input": data}
|
||||
|
||||
def run(data):
|
||||
return {"result": "Run result", "input": data}
|
||||
|
||||
def main(data):
|
||||
return {"result": "Main result", "input": data}
|
||||
""")
|
||||
|
||||
# 分析项目
|
||||
project_info = project_analyzer.analyze_project(repo_path)
|
||||
if not project_info["success"]:
|
||||
raise HTTPException(status_code=400, detail=f"项目分析失败: {project_info['error']}")
|
||||
|
||||
# 3. 生成服务包装器
|
||||
service_config = {
|
||||
"name": request.name,
|
||||
"version": request.version,
|
||||
"service_type": request.service_type,
|
||||
"host": request.host,
|
||||
"port": request.port,
|
||||
"timeout": request.timeout,
|
||||
"health_check_path": request.health_check_path,
|
||||
"environment": request.environment
|
||||
}
|
||||
|
||||
generate_result = service_generator.generate_service(project_info, service_config)
|
||||
if not generate_result["success"]:
|
||||
raise HTTPException(status_code=400, detail=f"服务生成失败: {generate_result['error']}")
|
||||
|
||||
# 4. 部署服务
|
||||
service_id = str(uuid.uuid4())
|
||||
deploy_result = service_orchestrator.deploy_service(service_id, service_config, project_info)
|
||||
if not deploy_result["success"]:
|
||||
raise HTTPException(status_code=400, detail=f"服务部署失败: {deploy_result['error']}")
|
||||
|
||||
# 5. 保存服务信息到数据库
|
||||
new_service = AlgorithmService(
|
||||
id=str(uuid.uuid4()),
|
||||
service_id=service_id,
|
||||
name=request.name,
|
||||
algorithm_name="algorithm", # 注意:在实际实现中,应该从仓库信息中获取
|
||||
version=request.version,
|
||||
host=request.host,
|
||||
port=request.port,
|
||||
api_url=deploy_result["api_url"],
|
||||
status=deploy_result["status"],
|
||||
config={
|
||||
"service_type": request.service_type,
|
||||
"timeout": request.timeout,
|
||||
"health_check_path": request.health_check_path,
|
||||
"environment": request.environment,
|
||||
"container_id": deploy_result["container_id"]
|
||||
}
|
||||
)
|
||||
|
||||
db.add(new_service)
|
||||
db.commit()
|
||||
db.refresh(new_service)
|
||||
|
||||
# 6. 返回响应
|
||||
return {
|
||||
"success": True,
|
||||
"message": "服务注册成功",
|
||||
"service": {
|
||||
"id": new_service.id,
|
||||
"service_id": new_service.service_id,
|
||||
"name": new_service.name,
|
||||
"algorithm_name": new_service.algorithm_name,
|
||||
"version": new_service.version,
|
||||
"host": new_service.host,
|
||||
"port": new_service.port,
|
||||
"api_url": new_service.api_url,
|
||||
"status": new_service.status,
|
||||
"created_at": new_service.created_at.isoformat(),
|
||||
"updated_at": new_service.updated_at.isoformat()
|
||||
}
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.get("", response_model=ServiceListResponse)
|
||||
async def list_services(
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取服务列表"""
|
||||
# 检查用户权限
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询服务列表
|
||||
services = db.query(AlgorithmService).all()
|
||||
|
||||
# 转换为响应格式
|
||||
service_list = []
|
||||
for service in services:
|
||||
service_list.append(ServiceResponse(
|
||||
id=service.id,
|
||||
service_id=service.service_id,
|
||||
name=service.name,
|
||||
algorithm_name=service.algorithm_name,
|
||||
version=service.version,
|
||||
host=service.host,
|
||||
port=service.port,
|
||||
api_url=service.api_url,
|
||||
status=service.status,
|
||||
created_at=service.created_at.isoformat(),
|
||||
updated_at=service.updated_at.isoformat()
|
||||
))
|
||||
|
||||
return ServiceListResponse(
|
||||
success=True,
|
||||
services=service_list
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.get("/{service_id}", response_model=ServiceDetailResponse)
|
||||
async def get_service(
|
||||
service_id: str,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取服务详情"""
|
||||
# 检查用户权限
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询服务
|
||||
service = db.query(AlgorithmService).filter(AlgorithmService.service_id == service_id).first()
|
||||
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
|
||||
# 返回响应
|
||||
return ServiceDetailResponse(
|
||||
success=True,
|
||||
service=ServiceResponse(
|
||||
id=service.id,
|
||||
service_id=service.service_id,
|
||||
name=service.name,
|
||||
algorithm_name=service.algorithm_name,
|
||||
version=service.version,
|
||||
host=service.host,
|
||||
port=service.port,
|
||||
api_url=service.api_url,
|
||||
status=service.status,
|
||||
created_at=service.created_at.isoformat(),
|
||||
updated_at=service.updated_at.isoformat()
|
||||
)
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.post("/{service_id}/start")
|
||||
async def start_service(
|
||||
service_id: str,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""启动服务"""
|
||||
# 检查用户权限
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询服务
|
||||
service = db.query(AlgorithmService).filter(AlgorithmService.service_id == service_id).first()
|
||||
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
|
||||
# 获取容器ID
|
||||
container_id = service.config.get("container_id")
|
||||
if not container_id:
|
||||
raise HTTPException(status_code=400, detail="Container ID not found")
|
||||
|
||||
# 启动服务
|
||||
start_result = service_orchestrator.start_service(service_id, container_id)
|
||||
if not start_result["success"]:
|
||||
raise HTTPException(status_code=400, detail=f"服务启动失败: {start_result['error']}")
|
||||
|
||||
# 更新服务状态
|
||||
service.status = start_result["status"]
|
||||
db.commit()
|
||||
|
||||
# 返回响应
|
||||
return ServiceOperationResponse(
|
||||
success=True,
|
||||
message="服务启动成功",
|
||||
service_id=service_id,
|
||||
status=start_result["status"]
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.post("/{service_id}/stop")
|
||||
async def stop_service(
|
||||
service_id: str,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""停止服务"""
|
||||
# 检查用户权限
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询服务
|
||||
service = db.query(AlgorithmService).filter(AlgorithmService.service_id == service_id).first()
|
||||
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
|
||||
# 获取容器ID
|
||||
container_id = service.config.get("container_id")
|
||||
if not container_id:
|
||||
raise HTTPException(status_code=400, detail="Container ID not found")
|
||||
|
||||
# 停止服务
|
||||
stop_result = service_orchestrator.stop_service(service_id, container_id)
|
||||
if not stop_result["success"]:
|
||||
raise HTTPException(status_code=400, detail=f"服务停止失败: {stop_result['error']}")
|
||||
|
||||
# 更新服务状态
|
||||
service.status = stop_result["status"]
|
||||
db.commit()
|
||||
|
||||
# 返回响应
|
||||
return ServiceOperationResponse(
|
||||
success=True,
|
||||
message="服务停止成功",
|
||||
service_id=service_id,
|
||||
status=stop_result["status"]
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.post("/{service_id}/restart")
|
||||
async def restart_service(
|
||||
service_id: str,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""重启服务"""
|
||||
# 检查用户权限
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询服务
|
||||
service = db.query(AlgorithmService).filter(AlgorithmService.service_id == service_id).first()
|
||||
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
|
||||
# 获取容器ID
|
||||
container_id = service.config.get("container_id")
|
||||
if not container_id:
|
||||
raise HTTPException(status_code=400, detail="Container ID not found")
|
||||
|
||||
# 重启服务
|
||||
restart_result = service_orchestrator.restart_service(service_id, container_id)
|
||||
if not restart_result["success"]:
|
||||
raise HTTPException(status_code=400, detail=f"服务重启失败: {restart_result['error']}")
|
||||
|
||||
# 更新服务状态
|
||||
service.status = restart_result["status"]
|
||||
db.commit()
|
||||
|
||||
# 返回响应
|
||||
return ServiceOperationResponse(
|
||||
success=True,
|
||||
message="服务重启成功",
|
||||
service_id=service_id,
|
||||
status=restart_result["status"]
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.delete("/{service_id}")
|
||||
async def delete_service(
|
||||
service_id: str,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""删除服务"""
|
||||
# 检查用户权限
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询服务
|
||||
service = db.query(AlgorithmService).filter(AlgorithmService.service_id == service_id).first()
|
||||
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
|
||||
# 获取容器ID和镜像名称
|
||||
container_id = service.config.get("container_id")
|
||||
image_name = f"algorithm-service-{service_id}:{service.version}"
|
||||
|
||||
# 删除服务
|
||||
delete_result = service_orchestrator.delete_service(service_id, container_id, image_name)
|
||||
if not delete_result["success"]:
|
||||
# 继续执行,即使Docker操作失败
|
||||
pass
|
||||
|
||||
# 删除数据库记录
|
||||
db.delete(service)
|
||||
db.commit()
|
||||
|
||||
# 返回响应
|
||||
return {
|
||||
"success": True,
|
||||
"message": "服务删除成功",
|
||||
"service_id": service_id
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.get("/{service_id}/status")
|
||||
async def get_service_status(
|
||||
service_id: str,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取服务状态"""
|
||||
# 检查用户权限
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询服务
|
||||
service = db.query(AlgorithmService).filter(AlgorithmService.service_id == service_id).first()
|
||||
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
|
||||
# 获取容器ID
|
||||
container_id = service.config.get("container_id")
|
||||
if not container_id:
|
||||
raise HTTPException(status_code=400, detail="Container ID not found")
|
||||
|
||||
# 获取服务状态
|
||||
status_result = service_orchestrator.get_service_status(container_id)
|
||||
if not status_result["success"]:
|
||||
raise HTTPException(status_code=400, detail=f"获取服务状态失败: {status_result['error']}")
|
||||
|
||||
# 返回响应
|
||||
return ServiceStatusResponse(
|
||||
success=True,
|
||||
status=status_result["status"],
|
||||
health=status_result["health"]
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.get("/{service_id}/logs")
|
||||
async def get_service_logs(
|
||||
service_id: str,
|
||||
lines: int = 100,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取服务日志"""
|
||||
# 检查用户权限
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询服务
|
||||
service = db.query(AlgorithmService).filter(AlgorithmService.service_id == service_id).first()
|
||||
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
|
||||
# 获取容器ID
|
||||
container_id = service.config.get("container_id")
|
||||
if not container_id:
|
||||
raise HTTPException(status_code=400, detail="Container ID not found")
|
||||
|
||||
# 获取服务日志
|
||||
logs_result = service_orchestrator.get_service_logs(container_id, lines)
|
||||
if not logs_result["success"]:
|
||||
raise HTTPException(status_code=400, detail=f"获取服务日志失败: {logs_result['error']}")
|
||||
|
||||
# 返回响应
|
||||
return ServiceLogsResponse(
|
||||
success=True,
|
||||
logs=logs_result["logs"]
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.get("/repository/algorithms")
|
||||
async def get_repository_algorithms(
|
||||
repository_id: str,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取仓库中的算法列表"""
|
||||
# 检查用户权限
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
try:
|
||||
# 模拟获取仓库中的算法列表
|
||||
# 注意:在实际实现中,应该从算法仓库中获取真实的算法列表
|
||||
algorithms = [
|
||||
{
|
||||
"id": "alg-001",
|
||||
"name": "图像分类算法",
|
||||
"description": "基于深度学习的图像分类算法",
|
||||
"type": "computer_vision",
|
||||
"entry_point": "algorithm.py"
|
||||
},
|
||||
{
|
||||
"id": "alg-002",
|
||||
"name": "文本分类算法",
|
||||
"description": "基于BERT的文本分类算法",
|
||||
"type": "nlp",
|
||||
"entry_point": "text_algorithm.py"
|
||||
}
|
||||
]
|
||||
|
||||
return RepositoryAlgorithmsResponse(
|
||||
success=True,
|
||||
algorithms=algorithms
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
129
backend/app/routes/user.py
Normal file
129
backend/app/routes/user.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
|
||||
from app.models.database import get_db
|
||||
from app.schemas.user import UserCreate, UserUpdate, UserResponse, UserListResponse, Token, LoginRequest
|
||||
from app.services.user import UserService
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
# OAuth2密码Bearer
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
|
||||
|
||||
|
||||
async def get_current_active_user(db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)):
|
||||
"""获取当前活跃用户"""
|
||||
user = UserService.get_current_user(db, token)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
if user.status != "active":
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return user
|
||||
|
||||
|
||||
from app.schemas.user import LoginRequest
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login(login_request: LoginRequest, db: Session = Depends(get_db)):
|
||||
"""用户登录"""
|
||||
user = UserService.authenticate_user(db, login_request.username, login_request.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# 创建访问令牌
|
||||
access_token = UserService.create_access_token(
|
||||
data={"sub": user.username, "user_id": user.id}
|
||||
)
|
||||
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse)
|
||||
async def register(user: UserCreate, db: Session = Depends(get_db)):
|
||||
"""用户注册"""
|
||||
# 检查用户名是否已存在
|
||||
if UserService.get_user_by_username(db, user.username):
|
||||
raise HTTPException(status_code=400, detail="Username already registered")
|
||||
|
||||
# 检查邮箱是否已存在
|
||||
if UserService.get_user_by_email(db, user.email):
|
||||
raise HTTPException(status_code=400, detail="Email already registered")
|
||||
|
||||
# 创建用户
|
||||
db_user = UserService.create_user(db, user)
|
||||
|
||||
return db_user
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def read_users_me(current_user: UserResponse = Depends(get_current_active_user)):
|
||||
"""获取当前用户信息"""
|
||||
return current_user
|
||||
|
||||
|
||||
@router.get("/", response_model=UserListResponse)
|
||||
async def get_users(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: UserResponse = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取用户列表"""
|
||||
# 只有管理员可以查看用户列表
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
users = UserService.get_users(db, skip=skip, limit=limit)
|
||||
return {"users": users, "total": len(users)}
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserResponse)
|
||||
async def get_user(
|
||||
user_id: str,
|
||||
current_user: UserResponse = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取用户信息"""
|
||||
# 只有管理员或用户本人可以查看用户信息
|
||||
if current_user.role != "admin" and current_user.id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
user = UserService.get_user_by_id(db, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@router.put("/{user_id}", response_model=UserResponse)
|
||||
async def update_user(
|
||||
user_id: str,
|
||||
user_update: UserUpdate,
|
||||
current_user: UserResponse = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""更新用户信息"""
|
||||
# 只有管理员或用户本人可以更新用户信息
|
||||
if current_user.role != "admin" and current_user.id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
# 非管理员只能更新自己的信息,不能更新角色
|
||||
if current_user.role != "admin" and "role" in user_update.dict():
|
||||
raise HTTPException(status_code=403, detail="Cannot update role")
|
||||
|
||||
user = UserService.update_user(db, user_id, user_update)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
return user
|
||||
0
backend/app/schemas/__init__.py
Normal file
0
backend/app/schemas/__init__.py
Normal file
BIN
backend/app/schemas/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
backend/app/schemas/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/schemas/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
backend/app/schemas/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/schemas/__pycache__/algorithm.cpython-312.pyc
Normal file
BIN
backend/app/schemas/__pycache__/algorithm.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/schemas/__pycache__/algorithm.cpython-39.pyc
Normal file
BIN
backend/app/schemas/__pycache__/algorithm.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/schemas/__pycache__/user.cpython-312.pyc
Normal file
BIN
backend/app/schemas/__pycache__/user.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/schemas/__pycache__/user.cpython-39.pyc
Normal file
BIN
backend/app/schemas/__pycache__/user.cpython-39.pyc
Normal file
Binary file not shown.
116
backend/app/schemas/algorithm.py
Normal file
116
backend/app/schemas/algorithm.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class AlgorithmBase(BaseModel):
|
||||
"""算法基础模式"""
|
||||
id: Optional[str] = None
|
||||
name: str = Field(..., description="算法名称")
|
||||
description: str = Field(..., description="算法描述")
|
||||
type: str = Field(..., description="算法类型")
|
||||
|
||||
|
||||
class AlgorithmCreate(AlgorithmBase):
|
||||
"""创建算法模式"""
|
||||
# 版本相关字段
|
||||
version: str = Field(default="1.0.0", description="版本号")
|
||||
url: str = Field(default="", description="算法API地址")
|
||||
params: Dict[str, Any] = Field(default_factory=dict, description="算法参数配置")
|
||||
input_schema: Dict[str, Any] = Field(default_factory=dict, description="输入数据格式")
|
||||
output_schema: Dict[str, Any] = Field(default_factory=dict, description="输出数据格式")
|
||||
code: str = Field(default="", description="Python算法代码")
|
||||
model_name: str = Field(default="", description="API训练后的模型名字")
|
||||
model_file: str = Field(default="", description="模型文件路径")
|
||||
api_doc: str = Field(default="", description="模型的API用法文档")
|
||||
|
||||
|
||||
class AlgorithmUpdate(BaseModel):
|
||||
"""更新算法模式"""
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
|
||||
|
||||
class AlgorithmVersionBase(BaseModel):
|
||||
"""算法版本基础模式"""
|
||||
id: Optional[str] = None
|
||||
algorithm_id: str = Field(..., description="算法ID")
|
||||
version: str = Field(..., description="版本号")
|
||||
url: str = Field(..., description="算法API地址")
|
||||
params: Dict[str, Any] = Field(default_factory=dict, description="算法参数配置")
|
||||
input_schema: Dict[str, Any] = Field(default_factory=dict, description="输入数据格式")
|
||||
output_schema: Dict[str, Any] = Field(default_factory=dict, description="输出数据格式")
|
||||
code: str = Field(default="", description="Python算法代码")
|
||||
model_name: str = Field(default="", description="API训练后的模型名字")
|
||||
model_file: str = Field(default="", description="模型文件路径")
|
||||
api_doc: str = Field(default="", description="模型的API用法文档")
|
||||
is_default: bool = Field(default=False, description="是否为默认版本")
|
||||
|
||||
|
||||
class AlgorithmVersionCreate(AlgorithmVersionBase):
|
||||
"""创建算法版本模式"""
|
||||
pass
|
||||
|
||||
|
||||
class AlgorithmVersionUpdate(BaseModel):
|
||||
"""更新算法版本模式"""
|
||||
version: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
input_schema: Optional[Dict[str, Any]] = None
|
||||
output_schema: Optional[Dict[str, Any]] = None
|
||||
code: Optional[str] = None
|
||||
model_name: Optional[str] = None
|
||||
model_file: Optional[str] = None
|
||||
api_doc: Optional[str] = None
|
||||
is_default: Optional[bool] = None
|
||||
|
||||
|
||||
class AlgorithmCallBase(BaseModel):
|
||||
"""算法调用基础模式"""
|
||||
id: Optional[str] = None
|
||||
algorithm_id: str = Field(..., description="算法ID")
|
||||
version_id: str = Field(..., description="版本ID")
|
||||
input_data: Dict[str, Any] = Field(..., description="输入数据")
|
||||
params: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
|
||||
|
||||
|
||||
class AlgorithmCallCreate(AlgorithmCallBase):
|
||||
"""创建算法调用模式"""
|
||||
pass
|
||||
|
||||
|
||||
class AlgorithmCallResult(BaseModel):
|
||||
"""算法调用结果模式"""
|
||||
id: str
|
||||
status: str
|
||||
output_data: Dict[str, Any]
|
||||
response_time: Optional[float] = None
|
||||
error_message: Optional[str] = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class AlgorithmResponse(AlgorithmBase):
|
||||
"""算法响应模式"""
|
||||
status: str
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
versions: List["AlgorithmVersionResponse"] = []
|
||||
|
||||
|
||||
class AlgorithmVersionResponse(AlgorithmVersionBase):
|
||||
"""算法版本响应模式"""
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class AlgorithmListResponse(BaseModel):
|
||||
"""算法列表响应模式"""
|
||||
algorithms: List[AlgorithmResponse]
|
||||
total: int
|
||||
|
||||
|
||||
# 更新前向引用
|
||||
AlgorithmResponse.model_rebuild()
|
||||
83
backend/app/schemas/user.py
Normal file
83
backend/app/schemas/user.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from pydantic import BaseModel, Field, EmailStr
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
"""用户基础模式"""
|
||||
id: Optional[str] = None
|
||||
username: str = Field(..., description="用户名")
|
||||
email: EmailStr = Field(..., description="邮箱")
|
||||
role: str = Field(default="user", description="用户角色")
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
"""创建用户模式"""
|
||||
password: str = Field(..., description="密码")
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
"""更新用户模式"""
|
||||
username: Optional[str] = None
|
||||
email: Optional[EmailStr] = None
|
||||
role: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
|
||||
|
||||
class UserResponse(UserBase):
|
||||
"""用户响应模式"""
|
||||
status: str
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class UserListResponse(BaseModel):
|
||||
"""用户列表响应模式"""
|
||||
users: List[UserResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
"""令牌模式"""
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
"""令牌数据模式"""
|
||||
username: Optional[str] = None
|
||||
user_id: Optional[str] = None
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
"""登录请求模式"""
|
||||
username: str = Field(..., description="用户名")
|
||||
password: str = Field(..., description="密码")
|
||||
|
||||
|
||||
class APIKeyBase(BaseModel):
|
||||
"""API密钥基础模式"""
|
||||
id: Optional[str] = None
|
||||
user_id: str = Field(..., description="用户ID")
|
||||
name: str = Field(..., description="密钥名称")
|
||||
|
||||
|
||||
class APIKeyCreate(APIKeyBase):
|
||||
"""创建API密钥模式"""
|
||||
expires_at: datetime = Field(..., description="过期时间")
|
||||
|
||||
|
||||
class APIKeyResponse(APIKeyBase):
|
||||
"""API密钥响应模式"""
|
||||
key: str
|
||||
expires_at: datetime
|
||||
status: str
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class APIKeyListResponse(BaseModel):
|
||||
"""API密钥列表响应模式"""
|
||||
api_keys: List[APIKeyResponse]
|
||||
total: int
|
||||
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
BIN
backend/app/services/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
backend/app/services/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/algorithm.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/algorithm.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/algorithm.cpython-39.pyc
Normal file
BIN
backend/app/services/__pycache__/algorithm.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/data_manager.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/data_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/data_manager.cpython-39.pyc
Normal file
BIN
backend/app/services/__pycache__/data_manager.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/deployment.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/deployment.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/deployment.cpython-39.pyc
Normal file
BIN
backend/app/services/__pycache__/deployment.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/history_manager.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/history_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/history_manager.cpython-39.pyc
Normal file
BIN
backend/app/services/__pycache__/history_manager.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/monitoring.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/monitoring.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/monitoring.cpython-39.pyc
Normal file
BIN
backend/app/services/__pycache__/monitoring.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/permission.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/permission.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/permission.cpython-39.pyc
Normal file
BIN
backend/app/services/__pycache__/permission.cpython-39.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
backend/app/services/__pycache__/service_manager.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/service_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/service_manager.cpython-39.pyc
Normal file
BIN
backend/app/services/__pycache__/service_manager.cpython-39.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
backend/app/services/__pycache__/user.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/user.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/user.cpython-39.pyc
Normal file
BIN
backend/app/services/__pycache__/user.cpython-39.pyc
Normal file
Binary file not shown.
465
backend/app/services/algorithm.py
Normal file
465
backend/app/services/algorithm.py
Normal file
@@ -0,0 +1,465 @@
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
import requests
|
||||
import time
|
||||
|
||||
from app.models.models import Algorithm, AlgorithmVersion, AlgorithmCall
|
||||
from app.schemas.algorithm import AlgorithmCreate, AlgorithmUpdate, AlgorithmVersionCreate, AlgorithmVersionUpdate, AlgorithmCallCreate
|
||||
from app.services.deployment import deployment_service
|
||||
|
||||
|
||||
class AlgorithmService:
|
||||
"""算法服务类"""
|
||||
|
||||
@staticmethod
|
||||
def create_algorithm(db: Session, algorithm: AlgorithmCreate) -> Algorithm:
|
||||
"""创建算法"""
|
||||
# 生成唯一ID
|
||||
algorithm_id = f"algorithm-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 创建算法实例
|
||||
db_algorithm = Algorithm(
|
||||
id=algorithm_id,
|
||||
name=algorithm.name,
|
||||
description=algorithm.description,
|
||||
type=algorithm.type
|
||||
)
|
||||
|
||||
# 保存到数据库
|
||||
db.add(db_algorithm)
|
||||
db.commit()
|
||||
db.refresh(db_algorithm)
|
||||
|
||||
# 自动部署(如果有代码)
|
||||
deployed_url = algorithm.url
|
||||
deployment_logs = []
|
||||
if algorithm.code and not deployed_url:
|
||||
try:
|
||||
# 构建镜像
|
||||
build_result = deployment_service.build_algorithm_image(
|
||||
algorithm.name,
|
||||
algorithm.code
|
||||
)
|
||||
|
||||
if build_result['success']:
|
||||
image_name = build_result['image_name']
|
||||
deployment_logs.extend(build_result['logs'])
|
||||
|
||||
# 部署容器
|
||||
deployment_info = deployment_service.deploy_algorithm(
|
||||
algorithm_id,
|
||||
image_name
|
||||
)
|
||||
|
||||
deployed_url = deployment_info['api_url']
|
||||
deployment_logs.append(f"部署成功: {deployed_url}")
|
||||
else:
|
||||
deployment_logs.extend(build_result['logs'])
|
||||
print(f"镜像构建失败: {build_result['logs'][-1]}")
|
||||
except Exception as e:
|
||||
error_message = f"自动部署失败: {str(e)}"
|
||||
deployment_logs.append(error_message)
|
||||
print(error_message)
|
||||
|
||||
# 创建默认版本
|
||||
version_id = f"version-{uuid.uuid4().hex[:8]}"
|
||||
db_version = AlgorithmVersion(
|
||||
id=version_id,
|
||||
algorithm_id=algorithm_id,
|
||||
version=algorithm.version,
|
||||
url=deployed_url,
|
||||
params=algorithm.params,
|
||||
input_schema=algorithm.input_schema,
|
||||
output_schema=algorithm.output_schema,
|
||||
code=algorithm.code,
|
||||
model_name=algorithm.model_name,
|
||||
model_file=algorithm.model_file,
|
||||
api_doc=algorithm.api_doc,
|
||||
is_default=True
|
||||
)
|
||||
|
||||
# 保存版本到数据库
|
||||
db.add(db_version)
|
||||
db.commit()
|
||||
db.refresh(db_version)
|
||||
|
||||
# 加载版本关系
|
||||
db.refresh(db_algorithm, ['versions'])
|
||||
|
||||
return db_algorithm
|
||||
|
||||
@staticmethod
|
||||
def get_algorithm_by_id(db: Session, algorithm_id: str) -> Optional[Algorithm]:
|
||||
"""通过ID获取算法"""
|
||||
return db.query(Algorithm).filter(Algorithm.id == algorithm_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_algorithms(db: Session, skip: int = 0, limit: int = 100, algorithm_type: Optional[str] = None) -> List[Algorithm]:
|
||||
"""获取算法列表"""
|
||||
query = db.query(Algorithm)
|
||||
|
||||
# 如果指定了算法类型,进行过滤
|
||||
if algorithm_type:
|
||||
query = query.filter(Algorithm.type == algorithm_type)
|
||||
|
||||
return query.offset(skip).limit(limit).all()
|
||||
|
||||
@staticmethod
|
||||
def update_algorithm(db: Session, algorithm_id: str, algorithm_update: AlgorithmUpdate) -> Optional[Algorithm]:
|
||||
"""更新算法"""
|
||||
# 获取算法
|
||||
db_algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
|
||||
if not db_algorithm:
|
||||
return None
|
||||
|
||||
# 更新算法信息
|
||||
update_data = algorithm_update.dict(exclude_unset=True)
|
||||
|
||||
# 应用更新
|
||||
for field, value in update_data.items():
|
||||
setattr(db_algorithm, field, value)
|
||||
|
||||
# 保存到数据库
|
||||
db.commit()
|
||||
db.refresh(db_algorithm)
|
||||
|
||||
return db_algorithm
|
||||
|
||||
@staticmethod
|
||||
def delete_algorithm(db: Session, algorithm_id: str) -> bool:
|
||||
"""删除算法"""
|
||||
# 获取算法
|
||||
db_algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
|
||||
if not db_algorithm:
|
||||
return False
|
||||
|
||||
# 从数据库中删除
|
||||
db.delete(db_algorithm)
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class AlgorithmVersionService:
|
||||
"""算法版本服务类"""
|
||||
|
||||
@staticmethod
|
||||
def create_version(db: Session, version: AlgorithmVersionCreate) -> AlgorithmVersion:
|
||||
"""创建算法版本"""
|
||||
# 生成唯一ID
|
||||
version_id = f"version-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 创建版本实例
|
||||
db_version = AlgorithmVersion(
|
||||
id=version_id,
|
||||
algorithm_id=version.algorithm_id,
|
||||
version=version.version,
|
||||
url=version.url,
|
||||
params=version.params,
|
||||
input_schema=version.input_schema,
|
||||
output_schema=version.output_schema,
|
||||
code=version.code,
|
||||
model_name=version.model_name,
|
||||
model_file=version.model_file,
|
||||
api_doc=version.api_doc,
|
||||
is_default=version.is_default
|
||||
)
|
||||
|
||||
# 如果设置为默认版本,需要将其他版本设置为非默认
|
||||
if version.is_default:
|
||||
db.query(AlgorithmVersion).filter(
|
||||
AlgorithmVersion.algorithm_id == version.algorithm_id,
|
||||
AlgorithmVersion.is_default == True
|
||||
).update({"is_default": False})
|
||||
|
||||
# 保存到数据库
|
||||
db.add(db_version)
|
||||
db.commit()
|
||||
db.refresh(db_version)
|
||||
|
||||
return db_version
|
||||
|
||||
@staticmethod
|
||||
def get_version_by_id(db: Session, version_id: str) -> Optional[AlgorithmVersion]:
|
||||
"""通过ID获取版本"""
|
||||
return db.query(AlgorithmVersion).filter(AlgorithmVersion.id == version_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_versions_by_algorithm_id(db: Session, algorithm_id: str) -> List[AlgorithmVersion]:
|
||||
"""通过算法ID获取版本列表"""
|
||||
return db.query(AlgorithmVersion).filter(AlgorithmVersion.algorithm_id == algorithm_id).all()
|
||||
|
||||
@staticmethod
|
||||
def get_default_version(db: Session, algorithm_id: str) -> Optional[AlgorithmVersion]:
|
||||
"""获取算法的默认版本"""
|
||||
return db.query(AlgorithmVersion).filter(
|
||||
AlgorithmVersion.algorithm_id == algorithm_id,
|
||||
AlgorithmVersion.is_default == True
|
||||
).first()
|
||||
|
||||
@staticmethod
|
||||
def update_version(db: Session, version_id: str, version_update: AlgorithmVersionUpdate) -> Optional[AlgorithmVersion]:
|
||||
"""更新算法版本"""
|
||||
# 获取版本
|
||||
db_version = AlgorithmVersionService.get_version_by_id(db, version_id)
|
||||
if not db_version:
|
||||
return None
|
||||
|
||||
# 更新版本信息
|
||||
update_data = version_update.dict(exclude_unset=True)
|
||||
|
||||
# 如果设置为默认版本,需要将其他版本设置为非默认
|
||||
if "is_default" in update_data and update_data["is_default"]:
|
||||
db.query(AlgorithmVersion).filter(
|
||||
AlgorithmVersion.algorithm_id == db_version.algorithm_id,
|
||||
AlgorithmVersion.is_default == True,
|
||||
AlgorithmVersion.id != version_id
|
||||
).update({"is_default": False})
|
||||
|
||||
# 应用更新
|
||||
for field, value in update_data.items():
|
||||
setattr(db_version, field, value)
|
||||
|
||||
# 保存到数据库
|
||||
db.commit()
|
||||
db.refresh(db_version)
|
||||
|
||||
return db_version
|
||||
|
||||
@staticmethod
|
||||
def delete_version(db: Session, version_id: str) -> bool:
|
||||
"""删除算法版本"""
|
||||
# 获取版本
|
||||
db_version = AlgorithmVersionService.get_version_by_id(db, version_id)
|
||||
if not db_version:
|
||||
return False
|
||||
|
||||
# 从数据库中删除
|
||||
db.delete(db_version)
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class AlgorithmCallService:
|
||||
"""算法调用服务类"""
|
||||
|
||||
@staticmethod
|
||||
def create_call(db: Session, user_id: str, call: AlgorithmCallCreate) -> AlgorithmCall:
|
||||
"""创建算法调用记录"""
|
||||
# 生成唯一ID
|
||||
call_id = f"call-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 创建调用实例
|
||||
db_call = AlgorithmCall(
|
||||
id=call_id,
|
||||
user_id=user_id,
|
||||
algorithm_id=call.algorithm_id,
|
||||
version_id=call.version_id,
|
||||
input_data=call.input_data,
|
||||
params=call.params
|
||||
)
|
||||
|
||||
# 保存到数据库
|
||||
db.add(db_call)
|
||||
db.commit()
|
||||
db.refresh(db_call)
|
||||
|
||||
return db_call
|
||||
|
||||
@staticmethod
|
||||
def execute_algorithm(db: Session, user_id: str, call: AlgorithmCallCreate) -> AlgorithmCall:
|
||||
"""执行算法"""
|
||||
# 创建调用记录
|
||||
db_call = AlgorithmCallService.create_call(db, user_id, call)
|
||||
|
||||
# 更新状态为运行中
|
||||
db_call.status = "running"
|
||||
db.commit()
|
||||
db.refresh(db_call)
|
||||
|
||||
try:
|
||||
# 获取算法版本信息
|
||||
version = AlgorithmVersionService.get_version_by_id(db, call.version_id)
|
||||
if not version:
|
||||
db_call.status = "failed"
|
||||
db_call.error_message = "算法版本不存在"
|
||||
db.commit()
|
||||
return db_call
|
||||
|
||||
# 处理视频输入数据
|
||||
processed_input_data = call.input_data.copy()
|
||||
if 'video' in processed_input_data and processed_input_data['video']:
|
||||
from app.utils.file import file_storage
|
||||
import io
|
||||
import base64
|
||||
import uuid
|
||||
|
||||
# 从base64字符串解码视频数据
|
||||
video_data = processed_input_data['video']
|
||||
if video_data.startswith('data:'):
|
||||
# 移除data URL前缀
|
||||
header, encoded = video_data.split(',', 1)
|
||||
video_bytes = base64.b64decode(encoded)
|
||||
|
||||
# 提取文件扩展名
|
||||
import re
|
||||
match = re.search(r'data:video/(\w+);', header)
|
||||
ext = match.group(1) if match else 'mp4'
|
||||
|
||||
# 生成唯一文件名
|
||||
video_filename = f"videos/{user_id}/{uuid.uuid4().hex[:12]}.{ext}"
|
||||
|
||||
# 上传到MinIO
|
||||
file_obj = io.BytesIO(video_bytes)
|
||||
success = file_storage.upload_fileobj(file_obj, video_filename, f'video/{ext}')
|
||||
|
||||
if success:
|
||||
# 替换为MinIO文件路径
|
||||
processed_input_data['video'] = video_filename
|
||||
else:
|
||||
db_call.status = "failed"
|
||||
db_call.error_message = "视频文件上传失败"
|
||||
db.commit()
|
||||
return db_call
|
||||
|
||||
# 处理PLY文件输入数据
|
||||
if 'ply' in processed_input_data and processed_input_data['ply']:
|
||||
from app.utils.file import file_storage
|
||||
import io
|
||||
import base64
|
||||
import uuid
|
||||
|
||||
# 从base64字符串解码PLY数据
|
||||
ply_data = processed_input_data['ply']
|
||||
if ply_data.startswith('data:'):
|
||||
# 移除data URL前缀
|
||||
header, encoded = ply_data.split(',', 1)
|
||||
ply_bytes = base64.b64decode(encoded)
|
||||
|
||||
# 生成唯一文件名
|
||||
ply_filename = f"ply/{user_id}/{uuid.uuid4().hex[:12]}.ply"
|
||||
|
||||
# 上传到MinIO
|
||||
file_obj = io.BytesIO(ply_bytes)
|
||||
success = file_storage.upload_fileobj(file_obj, ply_filename, 'application/octet-stream')
|
||||
|
||||
if success:
|
||||
# 替换为MinIO文件路径
|
||||
processed_input_data['ply'] = ply_filename
|
||||
else:
|
||||
db_call.status = "failed"
|
||||
db_call.error_message = "PLY文件上传失败"
|
||||
db.commit()
|
||||
return db_call
|
||||
|
||||
# 记录开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 调用算法API
|
||||
response = requests.post(
|
||||
version.url,
|
||||
json={
|
||||
"input_data": processed_input_data,
|
||||
"params": call.params
|
||||
},
|
||||
timeout=30
|
||||
)
|
||||
|
||||
# 计算响应时间
|
||||
response_time = time.time() - start_time
|
||||
|
||||
# 处理响应
|
||||
if response.status_code == 200:
|
||||
output_data = response.json()
|
||||
db_call.status = "success"
|
||||
db_call.output_data = output_data
|
||||
db_call.response_time = response_time
|
||||
else:
|
||||
db_call.status = "failed"
|
||||
db_call.error_message = f"算法执行失败: {response.text}"
|
||||
db_call.response_time = response_time
|
||||
|
||||
except Exception as e:
|
||||
# 记录错误
|
||||
db_call.status = "failed"
|
||||
db_call.error_message = f"算法执行异常: {str(e)}"
|
||||
db_call.response_time = time.time() - start_time if 'start_time' in locals() else None
|
||||
|
||||
# 保存到数据库
|
||||
db.commit()
|
||||
db.refresh(db_call)
|
||||
|
||||
return db_call
|
||||
|
||||
@staticmethod
|
||||
def get_call_by_id(db: Session, call_id: str) -> Optional[AlgorithmCall]:
|
||||
"""通过ID获取调用记录"""
|
||||
return db.query(AlgorithmCall).filter(AlgorithmCall.id == call_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_calls_by_user_id(db: Session, user_id: str, skip: int = 0, limit: int = 100) -> List[AlgorithmCall]:
|
||||
"""通过用户ID获取调用记录列表"""
|
||||
return db.query(AlgorithmCall).filter(
|
||||
AlgorithmCall.user_id == user_id
|
||||
).offset(skip).limit(limit).all()
|
||||
|
||||
@staticmethod
|
||||
def get_calls_by_algorithm_id(db: Session, algorithm_id: str, skip: int = 0, limit: int = 100) -> List[AlgorithmCall]:
|
||||
"""通过算法ID获取调用记录列表"""
|
||||
return db.query(AlgorithmCall).filter(
|
||||
AlgorithmCall.algorithm_id == algorithm_id
|
||||
).offset(skip).limit(limit).all()
|
||||
|
||||
@staticmethod
|
||||
def execute_python_code(code: str) -> dict:
|
||||
"""执行Python代码"""
|
||||
import subprocess
|
||||
import sys
|
||||
import io
|
||||
import contextlib
|
||||
|
||||
# 准备执行环境
|
||||
result = {
|
||||
"success": False,
|
||||
"output": "",
|
||||
"error": ""
|
||||
}
|
||||
|
||||
try:
|
||||
# 创建一个安全的执行环境
|
||||
# 使用subprocess创建一个独立的进程,限制执行时间
|
||||
# 注意:这只是一个基本的安全措施,生产环境中需要更严格的沙箱
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
# 创建临时文件
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
|
||||
f.write(code)
|
||||
temp_file_name = f.name
|
||||
|
||||
try:
|
||||
# 执行代码,限制时间为5秒
|
||||
output = subprocess.check_output(
|
||||
[sys.executable, temp_file_name],
|
||||
stderr=subprocess.STDOUT,
|
||||
timeout=5,
|
||||
universal_newlines=True
|
||||
)
|
||||
result["success"] = True
|
||||
result["output"] = output
|
||||
except subprocess.TimeoutExpired:
|
||||
result["error"] = "代码执行超时(超过5秒)"
|
||||
except subprocess.CalledProcessError as e:
|
||||
result["error"] = e.output
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if os.path.exists(temp_file_name):
|
||||
os.unlink(temp_file_name)
|
||||
|
||||
except Exception as e:
|
||||
result["error"] = f"执行环境错误: {str(e)}"
|
||||
|
||||
return result
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user