good version for web

This commit is contained in:
2026-02-18 09:36:18 +08:00
parent 62ea5d36a5
commit 72ab0c0b56
42 changed files with 1305 additions and 1515 deletions

View File

@@ -5,13 +5,14 @@
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
from typing import Optional
from app.models.database import get_db
from app.services.user import UserService
from app.schemas.user import UserResponse
# OAuth2密码Bearer
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/users/login")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/users/login", auto_error=False)
async def get_current_active_user(db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)):
@@ -25,4 +26,19 @@ async def get_current_active_user(db: Session = Depends(get_db), token: str = De
)
if hasattr(user, 'status') and user.status != "active":
raise HTTPException(status_code=400, detail="Inactive user")
return user
return user
async def get_current_active_user_optional(db: Session = Depends(get_db), token: Optional[str] = Depends(oauth2_scheme)):
"""获取当前活跃用户可选未登录返回None"""
if not token:
return None
try:
user = UserService.get_current_user(db, token)
if not user:
return None
if hasattr(user, 'status') and user.status != "active":
return None
return user
except:
return None

Binary file not shown.

View File

@@ -237,6 +237,21 @@ async def delete_version(
# 算法调用相关路由
@router.post("/call/public", response_model=AlgorithmCallResult)
async def call_algorithm_public(
call: AlgorithmCallCreate,
db: Session = Depends(get_db)
):
"""公开调用算法(不需要认证,用于演示页面)"""
# 使用匿名用户ID进行调用
anonymous_user_id = "anonymous"
# 执行算法
result = AlgorithmCallService.execute_algorithm(db, anonymous_user_id, call)
return result
@router.post("/call", response_model=AlgorithmCallResult)
async def call_algorithm(
call: AlgorithmCallCreate,

View File

@@ -1,6 +1,6 @@
"""数据管理路由,提供输入数据、输出结果和元数据的管理功能"""
from fastapi import APIRouter, HTTPException, status, Depends, UploadFile, File
from fastapi import APIRouter, HTTPException, status, Depends, UploadFile, File, Form
from typing import List, Dict, Any, Optional
from pydantic import BaseModel
from sqlalchemy.orm import Session
@@ -8,7 +8,7 @@ import json
from app.services.data_manager import data_manager
from app.models.database import get_db
from app.dependencies import get_current_active_user
from app.dependencies import get_current_active_user, get_current_active_user_optional
router = APIRouter(prefix="/data", tags=["data-management"])
@@ -176,19 +176,18 @@ async def get_user_outputs(
@router.post("/media/upload")
async def upload_media_file(
file: UploadFile = File(...),
algorithm_id: str = None,
current_user: dict = Depends(get_current_active_user)
algorithm_id: str = Form(...)
):
"""上传媒体文件(如图片、视频等)"""
"""上传媒体文件(如图片、视频等)- 公开API不需要认证"""
if not algorithm_id:
raise HTTPException(status_code=400, detail="algorithm_id is required")
# 读取文件内容
file_content = await file.read()
# 保存到数据管理器
# 保存到数据管理器(使用匿名用户)
file_path = data_manager.save_media_file(
user_id=current_user.get("id"),
user_id="anonymous",
algorithm_id=algorithm_id,
file_content=file_content,
file_name=file.filename
@@ -209,10 +208,28 @@ async def upload_media_file(
@router.get("/media/{file_path:path}")
async def get_media_file(
file_path: str,
current_user: dict = Depends(get_current_active_user)
current_user: Optional[dict] = Depends(get_current_active_user_optional)
):
"""获取媒体文件"""
# 检查用户权限 - 确保用户只能访问自己的文件或公共文件
# results 目录下的文件公开访问
if file_path.startswith("results/"):
content = data_manager.get_media_file(file_path)
if content:
# 从完整路径中提取文件名来获取正确的MIME类型
import mimetypes
filename = file_path.split('/')[-1]
content_type, _ = mimetypes.guess_type(filename)
if content_type is None:
content_type = "application/octet-stream"
from fastapi.responses import Response
return Response(content=content, media_type=content_type)
else:
raise HTTPException(status_code=404, detail="Media file not found")
# 其他文件需要用户权限
if current_user is None:
raise HTTPException(status_code=401, detail="Not authenticated")
if current_user.get("role") != "admin" and not file_path.startswith(f"media/{current_user.get('id')}/"):
raise HTTPException(status_code=403, detail="Insufficient permissions")

View File

@@ -132,7 +132,7 @@ class BatchOperationResponse(BaseModel):
# 初始化服务组件
project_analyzer = ProjectAnalyzer()
service_generator = ServiceGenerator()
service_orchestrator = ServiceOrchestrator(deployment_mode=settings.DEPLOYMENT_MODE)
service_orchestrator = ServiceOrchestrator(deployment_mode=settings.DEPLOYMENT_MODE, db_url=settings.DATABASE_URL)
@router.post("/register", status_code=status.HTTP_201_CREATED)

View File

@@ -4,6 +4,10 @@ from sqlalchemy.orm import Session
import uuid
import requests
import time
import logging
import os
logger = logging.getLogger(__name__)
from app.models.models import Algorithm, AlgorithmVersion, AlgorithmCall
from app.schemas.algorithm import AlgorithmCreate, AlgorithmUpdate, AlgorithmVersionCreate, AlgorithmVersionUpdate, AlgorithmCallCreate
@@ -97,14 +101,39 @@ class AlgorithmService:
@staticmethod
def get_algorithms(db: Session, skip: int = 0, limit: int = 100, algorithm_type: Optional[str] = None) -> List[Algorithm]:
"""获取算法列表"""
"""获取算法列表,优先显示已注册的服务"""
from app.models.models import AlgorithmService as Service
# 获取所有已注册的服务
services = db.query(Service).all()
# 获取所有算法
query = db.query(Algorithm)
# 如果指定了算法类型,进行过滤
if algorithm_type:
query = query.filter(Algorithm.type == algorithm_type)
return query.offset(skip).limit(limit).all()
algorithms = query.all()
# 创建服务名称集合,用于快速查找
service_names = {service.name for service in services}
# 将算法分为两类:已注册的服务和普通算法
registered_algorithms = []
normal_algorithms = []
for algo in algorithms:
if algo.name in service_names:
registered_algorithms.append(algo)
else:
normal_algorithms.append(algo)
# 合并结果:已注册的服务在前,普通算法在后
merged_algorithms = registered_algorithms + normal_algorithms
# 应用分页
return merged_algorithms[skip:skip+limit]
@staticmethod
def update_algorithm(db: Session, algorithm_id: str, algorithm_update: AlgorithmUpdate) -> Optional[Algorithm]:
@@ -358,14 +387,37 @@ class AlgorithmCallService:
# 记录开始时间
start_time = time.time()
# 处理视频路径 - 如果是MinIO路径下载到本地
from app.utils.file import file_storage
video_path = processed_input_data.get('video', '')
if video_path and video_path.startswith('media/'):
# 从MinIO下载视频到本地
video_content = file_storage.get_object(video_path)
if video_content:
# 保存到临时文件
import tempfile
import uuid
suffix = '.' + video_path.split('.')[-1] if '.' in video_path else '.mp4'
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
tmp_file.write(video_content)
local_video_path = tmp_file.name
# 使用本地路径
processed_input_data['video'] = local_video_path
else:
db_call.status = "failed"
db_call.error_message = "无法下载视频文件"
db.commit()
return db_call
# 调用算法API
print(f"[DEBUG] 调用算法API: {version.url}, 输入: {processed_input_data}")
response = requests.post(
version.url,
json={
"input_data": processed_input_data,
"params": call.params
},
timeout=30
timeout=120
)
# 计算响应时间
@@ -374,7 +426,35 @@ class AlgorithmCallService:
# 处理响应
if response.status_code == 200:
output_data = response.json()
print(f"[DEBUG] 算法响应: {output_data}")
# 处理算法返回的视频文件
result = output_data.get('result', {})
if result and isinstance(result, dict):
# 如果返回了本地视频路径,需要处理
if 'video' in result and result['video'] and result['video'].startswith('/'):
# 这是本地路径需要转换为可访问的URL
local_video = result['video']
if os.path.exists(local_video):
# 上传到MinIO并替换路径
with open(local_video, 'rb') as f:
video_content = f.read()
video_filename = f"results/{user_id}/{uuid.uuid4().hex[:12]}.mp4"
from app.utils.file import file_storage
success = file_storage.upload_from_bytes(video_content, video_filename)
if success:
result['video'] = video_filename
result['video_url'] = f"/api/v1/data/media/{video_filename}"
# 删除临时文件
try:
os.remove(local_video)
except:
pass
db_call.status = "success"
print(f"[DEBUG] 保存output_data: {output_data}")
db_call.output_data = output_data
db_call.response_time = response_time
else:

View File

@@ -235,7 +235,7 @@ class DataManager:
def get_media_file(self, file_path: str) -> Optional[bytes]:
"""获取媒体文件内容"""
try:
content = file_storage.download_file(file_path)
content = file_storage.get_object(file_path)
return content
except Exception as e:
logger.error(f"Error getting media file: {str(e)}")

View File

@@ -11,18 +11,41 @@ import psutil
from typing import Dict, Any, Optional
from docker.errors import DockerException, NotFound
# 数据库相关导入
try:
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
DATABASE_AVAILABLE = True
except ImportError:
DATABASE_AVAILABLE = False
class ServiceOrchestrator:
"""服务编排服务"""
def __init__(self, deployment_mode="local"):
def __init__(self, deployment_mode="local", db_url=None):
"""初始化服务编排器
Args:
deployment_mode: 部署模式,支持"docker""local"
db_url: 数据库连接URL用于重新加载服务信息
"""
self.deployment_mode = deployment_mode
self.processes = {} # 存储本地进程信息
self.db_url = db_url
self.db_engine = None
self.db_session = None
# 初始化数据库连接
if db_url and DATABASE_AVAILABLE:
try:
self.db_engine = create_engine(db_url)
self.db_session = sessionmaker(bind=self.db_engine)()
print("数据库连接成功")
except Exception as e:
print(f"数据库连接失败: {e}")
self.db_engine = None
self.db_session = None
if deployment_mode == "docker":
try:
@@ -178,15 +201,17 @@ class ServiceOrchestrator:
# 本地进程启动
if service_id not in self.processes:
# 服务不在进程列表中,可能是服务重启导致的
# 这种情况下,需要从外部重新注册服务
# 暂时返回错误,建议用户重新注册服务
print(f"服务 {service_id} 不在进程列表中,无法启动")
return {
"success": False,
"error": "服务不存在,请重新注册服务",
"service_id": service_id,
"status": "error"
}
# 尝试从数据库重新加载服务信息
print(f"服务 {service_id} 不在进程列表中,尝试从数据库重新加载")
service_info = self.reload_service_from_db(service_id)
if not service_info:
return {
"success": False,
"error": "服务不存在,请重新注册服务",
"service_id": service_id,
"status": "error"
}
process_info = self.processes[service_id]
@@ -209,11 +234,76 @@ class ServiceOrchestrator:
project_info = process_info["project_info"]
service_config = process_info["service_config"]
print(f"准备启动服务 {service_id}")
print(f"服务目录: {service_dir}")
print(f"服务配置: {service_config}")
# 检查服务目录是否存在如果不存在则从Gitea克隆
if not os.path.exists(service_dir):
print(f"服务目录不存在: {service_dir}尝试从Gitea克隆代码")
repository_id = service_config.get("repository_id")
if not repository_id:
return {
"success": False,
"error": "无法获取仓库ID无法克隆代码",
"service_id": service_id,
"status": "error"
}
# 从数据库获取仓库信息
try:
from app.models.database import SessionLocal
from app.models.models import AlgorithmRepository
db = SessionLocal()
repository = db.query(AlgorithmRepository).filter(AlgorithmRepository.id == repository_id).first()
db.close()
if not repository:
return {
"success": False,
"error": f"仓库不存在: {repository_id}",
"service_id": service_id,
"status": "error"
}
# 克隆仓库
from app.gitea.service import GiteaService
gitea_service = GiteaService()
clone_success = gitea_service.clone_repository(
repository.repo_url,
service_id,
repository.branch or "main"
)
if not clone_success:
return {
"success": False,
"error": f"克隆仓库失败: {repository.repo_url}",
"service_id": service_id,
"status": "error"
}
print(f"成功从Gitea克隆仓库到: {service_dir}")
except Exception as e:
print(f"克隆仓库时出错: {str(e)}")
return {
"success": False,
"error": f"克隆仓库时出错: {str(e)}",
"service_id": service_id,
"status": "error"
}
# 启动服务进程
print(f"开始启动服务进程...")
new_process_info = self._start_local_service_process(service_id, service_dir, project_info, service_config)
print(f"服务进程启动完成: {new_process_info}")
# 验证服务启动
print(f"开始验证服务启动...")
if not self._verify_local_service_startup(service_id, service_config):
print(f"服务启动验证失败")
return {
"success": False,
"error": "服务启动验证失败",
@@ -221,6 +311,7 @@ class ServiceOrchestrator:
"status": "error"
}
print(f"服务启动成功!")
return {
"success": True,
"service_id": service_id,
@@ -610,6 +701,76 @@ class ServiceOrchestrator:
"health": "unknown"
}
def reload_service_from_db(self, service_id: str) -> Optional[Dict[str, Any]]:
"""从数据库重新加载服务信息
Args:
service_id: 服务ID
Returns:
服务信息字典如果未找到则返回None
"""
if not self.db_session:
print("数据库连接不可用,无法重新加载服务信息")
return None
try:
# 从数据库查询服务信息
query = text("""
SELECT service_id, api_url, status, config, host, port
FROM algorithm_services
WHERE service_id = :service_id
""")
result = self.db_session.execute(query, {"service_id": service_id}).fetchone()
if not result:
print(f"数据库中未找到服务 {service_id}")
return None
# 构建服务信息字典
# config字段已经是JSON类型不需要再解析
config_data = result[3] if result[3] else {}
# 将host和port添加到config中
config_data["host"] = result[4] if result[4] else "localhost"
config_data["port"] = result[5] if result[5] else 8000
service_info = {
"service_id": result[0],
"api_url": result[1],
"status": result[2],
"config": config_data
}
# 从config中提取容器ID
container_id = service_info["config"].get("container_id")
if container_id:
service_info["container_id"] = container_id
# 更新本地进程缓存(仅用于本地模式)
if self.deployment_mode == "local":
# 从config中获取项目类型如果没有则默认为python
project_type = service_info["config"].get("project_type", "python")
self.processes[service_id] = {
"service_dir": f"/tmp/algorithms/{service_id}",
"project_info": {
"project_type": project_type,
"name": service_info["config"].get("name", ""),
"version": service_info["config"].get("version", "1.0.0"),
"description": service_info["config"].get("description", "")
},
"service_config": service_info["config"],
"pid": None # PID需要重新获取
}
print(f"成功从数据库重新加载服务 {service_id} 的信息")
return service_info
except Exception as e:
print(f"从数据库重新加载服务信息失败: {e}")
return None
def get_service_logs(self, container_id: str, lines: int = 100) -> Dict[str, Any]:
"""获取服务日志
@@ -1429,7 +1590,10 @@ def main(data):
import requests
# 构建健康检查URL
# 使用localhost而不是0.0.0.0,因为健康检查是在本地执行的
host = service_config.get("host", "localhost")
if host == "0.0.0.0":
host = "localhost"
port = service_config.get("port", 8000)
health_check_url = f"http://{host}:{port}/health"
@@ -1437,5 +1601,6 @@ def main(data):
response = requests.get(health_check_url, timeout=10)
return response.status_code == 200
except:
except Exception as e:
print(f"健康检查失败: {e}")
return False

View File

@@ -59,6 +59,27 @@ class MinioClient:
logging.warning(f"MinIO upload error: {e}")
return False
def upload_from_bytes(self, data: bytes, object_name: str) -> bool:
"""从字节数据上传文件"""
if not self.is_connected:
logging.warning("MinIO is not connected. Upload skipped.")
return False
try:
import io
file_obj = io.BytesIO(data)
self.client.put_object(
self.bucket_name,
object_name,
file_obj,
length=len(data),
part_size=10*1024*1024
)
return True
except S3Error as e:
logging.warning(f"MinIO upload error: {e}")
return False
def upload_fileobj(self, file_obj: io.BytesIO, object_name: str, content_type: str = "application/octet-stream") -> bool:
"""上传文件对象"""
if not self.is_connected: