good version for web
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -5,13 +5,14 @@
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
|
||||
from app.models.database import get_db
|
||||
from app.services.user import UserService
|
||||
from app.schemas.user import UserResponse
|
||||
|
||||
# OAuth2密码Bearer
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/users/login")
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/users/login", auto_error=False)
|
||||
|
||||
|
||||
async def get_current_active_user(db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)):
|
||||
@@ -25,4 +26,19 @@ async def get_current_active_user(db: Session = Depends(get_db), token: str = De
|
||||
)
|
||||
if hasattr(user, 'status') and user.status != "active":
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return user
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_active_user_optional(db: Session = Depends(get_db), token: Optional[str] = Depends(oauth2_scheme)):
|
||||
"""获取当前活跃用户(可选,未登录返回None)"""
|
||||
if not token:
|
||||
return None
|
||||
try:
|
||||
user = UserService.get_current_user(db, token)
|
||||
if not user:
|
||||
return None
|
||||
if hasattr(user, 'status') and user.status != "active":
|
||||
return None
|
||||
return user
|
||||
except:
|
||||
return None
|
||||
BIN
backend/app/models/__pycache__/api.cpython-39.pyc
Normal file
BIN
backend/app/models/__pycache__/api.cpython-39.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -237,6 +237,21 @@ async def delete_version(
|
||||
|
||||
|
||||
# 算法调用相关路由
|
||||
@router.post("/call/public", response_model=AlgorithmCallResult)
|
||||
async def call_algorithm_public(
|
||||
call: AlgorithmCallCreate,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""公开调用算法(不需要认证,用于演示页面)"""
|
||||
# 使用匿名用户ID进行调用
|
||||
anonymous_user_id = "anonymous"
|
||||
|
||||
# 执行算法
|
||||
result = AlgorithmCallService.execute_algorithm(db, anonymous_user_id, call)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/call", response_model=AlgorithmCallResult)
|
||||
async def call_algorithm(
|
||||
call: AlgorithmCallCreate,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""数据管理路由,提供输入数据、输出结果和元数据的管理功能"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status, Depends, UploadFile, File
|
||||
from fastapi import APIRouter, HTTPException, status, Depends, UploadFile, File, Form
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -8,7 +8,7 @@ import json
|
||||
|
||||
from app.services.data_manager import data_manager
|
||||
from app.models.database import get_db
|
||||
from app.dependencies import get_current_active_user
|
||||
from app.dependencies import get_current_active_user, get_current_active_user_optional
|
||||
|
||||
router = APIRouter(prefix="/data", tags=["data-management"])
|
||||
|
||||
@@ -176,19 +176,18 @@ async def get_user_outputs(
|
||||
@router.post("/media/upload")
|
||||
async def upload_media_file(
|
||||
file: UploadFile = File(...),
|
||||
algorithm_id: str = None,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
algorithm_id: str = Form(...)
|
||||
):
|
||||
"""上传媒体文件(如图片、视频等)"""
|
||||
"""上传媒体文件(如图片、视频等)- 公开API,不需要认证"""
|
||||
if not algorithm_id:
|
||||
raise HTTPException(status_code=400, detail="algorithm_id is required")
|
||||
|
||||
# 读取文件内容
|
||||
file_content = await file.read()
|
||||
|
||||
# 保存到数据管理器
|
||||
# 保存到数据管理器(使用匿名用户)
|
||||
file_path = data_manager.save_media_file(
|
||||
user_id=current_user.get("id"),
|
||||
user_id="anonymous",
|
||||
algorithm_id=algorithm_id,
|
||||
file_content=file_content,
|
||||
file_name=file.filename
|
||||
@@ -209,10 +208,28 @@ async def upload_media_file(
|
||||
@router.get("/media/{file_path:path}")
|
||||
async def get_media_file(
|
||||
file_path: str,
|
||||
current_user: dict = Depends(get_current_active_user)
|
||||
current_user: Optional[dict] = Depends(get_current_active_user_optional)
|
||||
):
|
||||
"""获取媒体文件"""
|
||||
# 检查用户权限 - 确保用户只能访问自己的文件或公共文件
|
||||
# results 目录下的文件公开访问
|
||||
if file_path.startswith("results/"):
|
||||
content = data_manager.get_media_file(file_path)
|
||||
if content:
|
||||
# 从完整路径中提取文件名来获取正确的MIME类型
|
||||
import mimetypes
|
||||
filename = file_path.split('/')[-1]
|
||||
content_type, _ = mimetypes.guess_type(filename)
|
||||
if content_type is None:
|
||||
content_type = "application/octet-stream"
|
||||
|
||||
from fastapi.responses import Response
|
||||
return Response(content=content, media_type=content_type)
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Media file not found")
|
||||
|
||||
# 其他文件需要用户权限
|
||||
if current_user is None:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
if current_user.get("role") != "admin" and not file_path.startswith(f"media/{current_user.get('id')}/"):
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
|
||||
@@ -132,7 +132,7 @@ class BatchOperationResponse(BaseModel):
|
||||
# 初始化服务组件
|
||||
project_analyzer = ProjectAnalyzer()
|
||||
service_generator = ServiceGenerator()
|
||||
service_orchestrator = ServiceOrchestrator(deployment_mode=settings.DEPLOYMENT_MODE)
|
||||
service_orchestrator = ServiceOrchestrator(deployment_mode=settings.DEPLOYMENT_MODE, db_url=settings.DATABASE_URL)
|
||||
|
||||
|
||||
@router.post("/register", status_code=status.HTTP_201_CREATED)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -4,6 +4,10 @@ from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
import requests
|
||||
import time
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from app.models.models import Algorithm, AlgorithmVersion, AlgorithmCall
|
||||
from app.schemas.algorithm import AlgorithmCreate, AlgorithmUpdate, AlgorithmVersionCreate, AlgorithmVersionUpdate, AlgorithmCallCreate
|
||||
@@ -97,14 +101,39 @@ class AlgorithmService:
|
||||
|
||||
@staticmethod
|
||||
def get_algorithms(db: Session, skip: int = 0, limit: int = 100, algorithm_type: Optional[str] = None) -> List[Algorithm]:
|
||||
"""获取算法列表"""
|
||||
"""获取算法列表,优先显示已注册的服务"""
|
||||
from app.models.models import AlgorithmService as Service
|
||||
|
||||
# 获取所有已注册的服务
|
||||
services = db.query(Service).all()
|
||||
|
||||
# 获取所有算法
|
||||
query = db.query(Algorithm)
|
||||
|
||||
# 如果指定了算法类型,进行过滤
|
||||
if algorithm_type:
|
||||
query = query.filter(Algorithm.type == algorithm_type)
|
||||
|
||||
return query.offset(skip).limit(limit).all()
|
||||
algorithms = query.all()
|
||||
|
||||
# 创建服务名称集合,用于快速查找
|
||||
service_names = {service.name for service in services}
|
||||
|
||||
# 将算法分为两类:已注册的服务和普通算法
|
||||
registered_algorithms = []
|
||||
normal_algorithms = []
|
||||
|
||||
for algo in algorithms:
|
||||
if algo.name in service_names:
|
||||
registered_algorithms.append(algo)
|
||||
else:
|
||||
normal_algorithms.append(algo)
|
||||
|
||||
# 合并结果:已注册的服务在前,普通算法在后
|
||||
merged_algorithms = registered_algorithms + normal_algorithms
|
||||
|
||||
# 应用分页
|
||||
return merged_algorithms[skip:skip+limit]
|
||||
|
||||
@staticmethod
|
||||
def update_algorithm(db: Session, algorithm_id: str, algorithm_update: AlgorithmUpdate) -> Optional[Algorithm]:
|
||||
@@ -358,14 +387,37 @@ class AlgorithmCallService:
|
||||
# 记录开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 处理视频路径 - 如果是MinIO路径,下载到本地
|
||||
from app.utils.file import file_storage
|
||||
video_path = processed_input_data.get('video', '')
|
||||
if video_path and video_path.startswith('media/'):
|
||||
# 从MinIO下载视频到本地
|
||||
video_content = file_storage.get_object(video_path)
|
||||
if video_content:
|
||||
# 保存到临时文件
|
||||
import tempfile
|
||||
import uuid
|
||||
suffix = '.' + video_path.split('.')[-1] if '.' in video_path else '.mp4'
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
|
||||
tmp_file.write(video_content)
|
||||
local_video_path = tmp_file.name
|
||||
# 使用本地路径
|
||||
processed_input_data['video'] = local_video_path
|
||||
else:
|
||||
db_call.status = "failed"
|
||||
db_call.error_message = "无法下载视频文件"
|
||||
db.commit()
|
||||
return db_call
|
||||
|
||||
# 调用算法API
|
||||
print(f"[DEBUG] 调用算法API: {version.url}, 输入: {processed_input_data}")
|
||||
response = requests.post(
|
||||
version.url,
|
||||
json={
|
||||
"input_data": processed_input_data,
|
||||
"params": call.params
|
||||
},
|
||||
timeout=30
|
||||
timeout=120
|
||||
)
|
||||
|
||||
# 计算响应时间
|
||||
@@ -374,7 +426,35 @@ class AlgorithmCallService:
|
||||
# 处理响应
|
||||
if response.status_code == 200:
|
||||
output_data = response.json()
|
||||
print(f"[DEBUG] 算法响应: {output_data}")
|
||||
|
||||
# 处理算法返回的视频文件
|
||||
result = output_data.get('result', {})
|
||||
if result and isinstance(result, dict):
|
||||
# 如果返回了本地视频路径,需要处理
|
||||
if 'video' in result and result['video'] and result['video'].startswith('/'):
|
||||
# 这是本地路径,需要转换为可访问的URL
|
||||
local_video = result['video']
|
||||
if os.path.exists(local_video):
|
||||
# 上传到MinIO并替换路径
|
||||
with open(local_video, 'rb') as f:
|
||||
video_content = f.read()
|
||||
|
||||
video_filename = f"results/{user_id}/{uuid.uuid4().hex[:12]}.mp4"
|
||||
from app.utils.file import file_storage
|
||||
success = file_storage.upload_from_bytes(video_content, video_filename)
|
||||
if success:
|
||||
result['video'] = video_filename
|
||||
result['video_url'] = f"/api/v1/data/media/{video_filename}"
|
||||
|
||||
# 删除临时文件
|
||||
try:
|
||||
os.remove(local_video)
|
||||
except:
|
||||
pass
|
||||
|
||||
db_call.status = "success"
|
||||
print(f"[DEBUG] 保存output_data: {output_data}")
|
||||
db_call.output_data = output_data
|
||||
db_call.response_time = response_time
|
||||
else:
|
||||
|
||||
@@ -235,7 +235,7 @@ class DataManager:
|
||||
def get_media_file(self, file_path: str) -> Optional[bytes]:
|
||||
"""获取媒体文件内容"""
|
||||
try:
|
||||
content = file_storage.download_file(file_path)
|
||||
content = file_storage.get_object(file_path)
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting media file: {str(e)}")
|
||||
|
||||
@@ -11,18 +11,41 @@ import psutil
|
||||
from typing import Dict, Any, Optional
|
||||
from docker.errors import DockerException, NotFound
|
||||
|
||||
# 数据库相关导入
|
||||
try:
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
DATABASE_AVAILABLE = True
|
||||
except ImportError:
|
||||
DATABASE_AVAILABLE = False
|
||||
|
||||
|
||||
class ServiceOrchestrator:
|
||||
"""服务编排服务"""
|
||||
|
||||
def __init__(self, deployment_mode="local"):
|
||||
def __init__(self, deployment_mode="local", db_url=None):
|
||||
"""初始化服务编排器
|
||||
|
||||
Args:
|
||||
deployment_mode: 部署模式,支持"docker"和"local"
|
||||
db_url: 数据库连接URL,用于重新加载服务信息
|
||||
"""
|
||||
self.deployment_mode = deployment_mode
|
||||
self.processes = {} # 存储本地进程信息
|
||||
self.db_url = db_url
|
||||
self.db_engine = None
|
||||
self.db_session = None
|
||||
|
||||
# 初始化数据库连接
|
||||
if db_url and DATABASE_AVAILABLE:
|
||||
try:
|
||||
self.db_engine = create_engine(db_url)
|
||||
self.db_session = sessionmaker(bind=self.db_engine)()
|
||||
print("数据库连接成功")
|
||||
except Exception as e:
|
||||
print(f"数据库连接失败: {e}")
|
||||
self.db_engine = None
|
||||
self.db_session = None
|
||||
|
||||
if deployment_mode == "docker":
|
||||
try:
|
||||
@@ -178,15 +201,17 @@ class ServiceOrchestrator:
|
||||
# 本地进程启动
|
||||
if service_id not in self.processes:
|
||||
# 服务不在进程列表中,可能是服务重启导致的
|
||||
# 这种情况下,需要从外部重新注册服务
|
||||
# 暂时返回错误,建议用户重新注册服务
|
||||
print(f"服务 {service_id} 不在进程列表中,无法启动")
|
||||
return {
|
||||
"success": False,
|
||||
"error": "服务不存在,请重新注册服务",
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
}
|
||||
# 尝试从数据库重新加载服务信息
|
||||
print(f"服务 {service_id} 不在进程列表中,尝试从数据库重新加载")
|
||||
service_info = self.reload_service_from_db(service_id)
|
||||
|
||||
if not service_info:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "服务不存在,请重新注册服务",
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
}
|
||||
|
||||
process_info = self.processes[service_id]
|
||||
|
||||
@@ -209,11 +234,76 @@ class ServiceOrchestrator:
|
||||
project_info = process_info["project_info"]
|
||||
service_config = process_info["service_config"]
|
||||
|
||||
print(f"准备启动服务 {service_id}")
|
||||
print(f"服务目录: {service_dir}")
|
||||
print(f"服务配置: {service_config}")
|
||||
|
||||
# 检查服务目录是否存在,如果不存在则从Gitea克隆
|
||||
if not os.path.exists(service_dir):
|
||||
print(f"服务目录不存在: {service_dir},尝试从Gitea克隆代码")
|
||||
repository_id = service_config.get("repository_id")
|
||||
if not repository_id:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "无法获取仓库ID,无法克隆代码",
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
}
|
||||
|
||||
# 从数据库获取仓库信息
|
||||
try:
|
||||
from app.models.database import SessionLocal
|
||||
from app.models.models import AlgorithmRepository
|
||||
|
||||
db = SessionLocal()
|
||||
repository = db.query(AlgorithmRepository).filter(AlgorithmRepository.id == repository_id).first()
|
||||
db.close()
|
||||
|
||||
if not repository:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"仓库不存在: {repository_id}",
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
}
|
||||
|
||||
# 克隆仓库
|
||||
from app.gitea.service import GiteaService
|
||||
gitea_service = GiteaService()
|
||||
clone_success = gitea_service.clone_repository(
|
||||
repository.repo_url,
|
||||
service_id,
|
||||
repository.branch or "main"
|
||||
)
|
||||
|
||||
if not clone_success:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"克隆仓库失败: {repository.repo_url}",
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
}
|
||||
|
||||
print(f"成功从Gitea克隆仓库到: {service_dir}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"克隆仓库时出错: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"克隆仓库时出错: {str(e)}",
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
}
|
||||
|
||||
# 启动服务进程
|
||||
print(f"开始启动服务进程...")
|
||||
new_process_info = self._start_local_service_process(service_id, service_dir, project_info, service_config)
|
||||
print(f"服务进程启动完成: {new_process_info}")
|
||||
|
||||
# 验证服务启动
|
||||
print(f"开始验证服务启动...")
|
||||
if not self._verify_local_service_startup(service_id, service_config):
|
||||
print(f"服务启动验证失败")
|
||||
return {
|
||||
"success": False,
|
||||
"error": "服务启动验证失败",
|
||||
@@ -221,6 +311,7 @@ class ServiceOrchestrator:
|
||||
"status": "error"
|
||||
}
|
||||
|
||||
print(f"服务启动成功!")
|
||||
return {
|
||||
"success": True,
|
||||
"service_id": service_id,
|
||||
@@ -610,6 +701,76 @@ class ServiceOrchestrator:
|
||||
"health": "unknown"
|
||||
}
|
||||
|
||||
def reload_service_from_db(self, service_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""从数据库重新加载服务信息
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
|
||||
Returns:
|
||||
服务信息字典,如果未找到则返回None
|
||||
"""
|
||||
if not self.db_session:
|
||||
print("数据库连接不可用,无法重新加载服务信息")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 从数据库查询服务信息
|
||||
query = text("""
|
||||
SELECT service_id, api_url, status, config, host, port
|
||||
FROM algorithm_services
|
||||
WHERE service_id = :service_id
|
||||
""")
|
||||
result = self.db_session.execute(query, {"service_id": service_id}).fetchone()
|
||||
|
||||
if not result:
|
||||
print(f"数据库中未找到服务 {service_id}")
|
||||
return None
|
||||
|
||||
# 构建服务信息字典
|
||||
# config字段已经是JSON类型,不需要再解析
|
||||
config_data = result[3] if result[3] else {}
|
||||
|
||||
# 将host和port添加到config中
|
||||
config_data["host"] = result[4] if result[4] else "localhost"
|
||||
config_data["port"] = result[5] if result[5] else 8000
|
||||
|
||||
service_info = {
|
||||
"service_id": result[0],
|
||||
"api_url": result[1],
|
||||
"status": result[2],
|
||||
"config": config_data
|
||||
}
|
||||
|
||||
# 从config中提取容器ID
|
||||
container_id = service_info["config"].get("container_id")
|
||||
if container_id:
|
||||
service_info["container_id"] = container_id
|
||||
|
||||
# 更新本地进程缓存(仅用于本地模式)
|
||||
if self.deployment_mode == "local":
|
||||
# 从config中获取项目类型,如果没有则默认为python
|
||||
project_type = service_info["config"].get("project_type", "python")
|
||||
|
||||
self.processes[service_id] = {
|
||||
"service_dir": f"/tmp/algorithms/{service_id}",
|
||||
"project_info": {
|
||||
"project_type": project_type,
|
||||
"name": service_info["config"].get("name", ""),
|
||||
"version": service_info["config"].get("version", "1.0.0"),
|
||||
"description": service_info["config"].get("description", "")
|
||||
},
|
||||
"service_config": service_info["config"],
|
||||
"pid": None # PID需要重新获取
|
||||
}
|
||||
|
||||
print(f"成功从数据库重新加载服务 {service_id} 的信息")
|
||||
return service_info
|
||||
|
||||
except Exception as e:
|
||||
print(f"从数据库重新加载服务信息失败: {e}")
|
||||
return None
|
||||
|
||||
def get_service_logs(self, container_id: str, lines: int = 100) -> Dict[str, Any]:
|
||||
"""获取服务日志
|
||||
|
||||
@@ -1429,7 +1590,10 @@ def main(data):
|
||||
import requests
|
||||
|
||||
# 构建健康检查URL
|
||||
# 使用localhost而不是0.0.0.0,因为健康检查是在本地执行的
|
||||
host = service_config.get("host", "localhost")
|
||||
if host == "0.0.0.0":
|
||||
host = "localhost"
|
||||
port = service_config.get("port", 8000)
|
||||
health_check_url = f"http://{host}:{port}/health"
|
||||
|
||||
@@ -1437,5 +1601,6 @@ def main(data):
|
||||
response = requests.get(health_check_url, timeout=10)
|
||||
|
||||
return response.status_code == 200
|
||||
except:
|
||||
except Exception as e:
|
||||
print(f"健康检查失败: {e}")
|
||||
return False
|
||||
|
||||
Binary file not shown.
@@ -59,6 +59,27 @@ class MinioClient:
|
||||
logging.warning(f"MinIO upload error: {e}")
|
||||
return False
|
||||
|
||||
def upload_from_bytes(self, data: bytes, object_name: str) -> bool:
|
||||
"""从字节数据上传文件"""
|
||||
if not self.is_connected:
|
||||
logging.warning("MinIO is not connected. Upload skipped.")
|
||||
return False
|
||||
|
||||
try:
|
||||
import io
|
||||
file_obj = io.BytesIO(data)
|
||||
self.client.put_object(
|
||||
self.bucket_name,
|
||||
object_name,
|
||||
file_obj,
|
||||
length=len(data),
|
||||
part_size=10*1024*1024
|
||||
)
|
||||
return True
|
||||
except S3Error as e:
|
||||
logging.warning(f"MinIO upload error: {e}")
|
||||
return False
|
||||
|
||||
def upload_fileobj(self, file_obj: io.BytesIO, object_name: str, content_type: str = "application/octet-stream") -> bool:
|
||||
"""上传文件对象"""
|
||||
if not self.is_connected:
|
||||
|
||||
Reference in New Issue
Block a user