466 lines
17 KiB
Python
466 lines
17 KiB
Python
import base64
import contextlib
import io
import logging
import os
import re
import subprocess
import sys
import tempfile
import time
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any

import requests
from sqlalchemy.orm import Session

from app.models.models import Algorithm, AlgorithmVersion, AlgorithmCall
from app.schemas.algorithm import AlgorithmCreate, AlgorithmUpdate, AlgorithmVersionCreate, AlgorithmVersionUpdate, AlgorithmCallCreate
from app.services.deployment import deployment_service
|
||
|
||
|
||
class AlgorithmService:
    """Service layer for algorithms: create (with optional auto-deploy), query, update, delete."""

    # Logger shared by the static methods below (replaces ad-hoc print calls).
    _logger = logging.getLogger(__name__)

    @staticmethod
    def create_algorithm(db: Session, algorithm: AlgorithmCreate) -> Algorithm:
        """Create an algorithm together with its default version.

        When ``algorithm.code`` is set and ``algorithm.url`` is empty, the code is
        built into an image and deployed through ``deployment_service``; the API
        URL returned by the deployment is stored on the default version. A failed
        build or deployment is logged and creation continues with the URL that
        was originally supplied.

        Args:
            db: SQLAlchemy session. This method commits it (twice).
            algorithm: Creation payload. ``name``/``description``/``type`` go on
                the algorithm; the remaining fields go on its default version.

        Returns:
            The persisted ``Algorithm`` with its ``versions`` relationship refreshed.
        """
        algorithm_id = f"algorithm-{uuid.uuid4().hex[:8]}"
        db_algorithm = Algorithm(
            id=algorithm_id,
            name=algorithm.name,
            description=algorithm.description,
            type=algorithm.type,
        )
        # NOTE(review): the algorithm row is committed before deployment and
        # before its version is inserted, so a failure further down leaves an
        # algorithm without a default version. Left as-is in case
        # deployment_service relies on the committed row -- confirm before
        # merging both inserts into a single transaction.
        db.add(db_algorithm)
        db.commit()
        db.refresh(db_algorithm)

        # An explicitly supplied URL wins; only auto-deploy when there is code
        # and no URL. On deployment failure keep the original (empty) URL.
        deployed_url = algorithm.url
        if algorithm.code and not deployed_url:
            deployed_url = (
                AlgorithmService._auto_deploy(algorithm_id, algorithm.name, algorithm.code)
                or deployed_url
            )

        db_version = AlgorithmVersion(
            id=f"version-{uuid.uuid4().hex[:8]}",
            algorithm_id=algorithm_id,
            version=algorithm.version,
            url=deployed_url,
            params=algorithm.params,
            input_schema=algorithm.input_schema,
            output_schema=algorithm.output_schema,
            code=algorithm.code,
            model_name=algorithm.model_name,
            model_file=algorithm.model_file,
            api_doc=algorithm.api_doc,
            is_default=True,
        )
        db.add(db_version)
        db.commit()
        db.refresh(db_version)

        # Reload the relationship so the returned object includes the new version.
        db.refresh(db_algorithm, ['versions'])
        return db_algorithm

    @staticmethod
    def _auto_deploy(algorithm_id: str, name: str, code: str) -> Optional[str]:
        """Build an image from ``code`` and deploy it; return the API URL, or None on failure.

        Never raises: build and deployment errors are logged instead, so that
        algorithm creation can proceed without a deployed endpoint.
        """
        logger = AlgorithmService._logger
        deployment_logs: List[str] = []
        try:
            build_result = deployment_service.build_algorithm_image(name, code)
            # Tolerate a missing/empty log list instead of failing on logs[-1].
            deployment_logs.extend(build_result.get('logs') or [])
            if not build_result.get('success'):
                last_line = deployment_logs[-1] if deployment_logs else "(无构建日志)"
                logger.error("镜像构建失败: %s", last_line)
                return None

            deployment_info = deployment_service.deploy_algorithm(
                algorithm_id,
                build_result['image_name'],
            )
            deployed_url = deployment_info['api_url']
            deployment_logs.append(f"部署成功: {deployed_url}")
            return deployed_url
        except Exception as e:
            # Best-effort deployment: record the failure and let creation continue.
            deployment_logs.append(f"自动部署失败: {str(e)}")
            logger.exception("自动部署失败: %s", e)
            return None
        finally:
            logger.debug("Deployment logs for %s: %s", algorithm_id, deployment_logs)

    @staticmethod
    def get_algorithm_by_id(db: Session, algorithm_id: str) -> Optional[Algorithm]:
        """Return the algorithm with the given ID, or None if it does not exist."""
        return db.query(Algorithm).filter(Algorithm.id == algorithm_id).first()

    @staticmethod
    def get_algorithms(db: Session, skip: int = 0, limit: int = 100, algorithm_type: Optional[str] = None) -> List[Algorithm]:
        """Return a page of algorithms, optionally filtered by ``type``.

        Args:
            db: SQLAlchemy session.
            skip: Number of rows to skip (offset).
            limit: Maximum number of rows to return.
            algorithm_type: If truthy, only algorithms whose ``type`` equals it.
        """
        query = db.query(Algorithm)
        if algorithm_type:
            query = query.filter(Algorithm.type == algorithm_type)
        # NOTE(review): no ORDER BY, so offset/limit pages are not guaranteed
        # to be stable across requests.
        return query.offset(skip).limit(limit).all()

    @staticmethod
    def update_algorithm(db: Session, algorithm_id: str, algorithm_update: AlgorithmUpdate) -> Optional[Algorithm]:
        """Apply the explicitly-set fields of ``algorithm_update`` to an algorithm.

        Returns:
            The updated algorithm, or None if no algorithm has that ID.
        """
        db_algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
        if not db_algorithm:
            return None

        # exclude_unset: only fields the client actually sent are written.
        update_data = algorithm_update.dict(exclude_unset=True)
        for field, value in update_data.items():
            setattr(db_algorithm, field, value)

        db.commit()
        db.refresh(db_algorithm)
        return db_algorithm

    @staticmethod
    def delete_algorithm(db: Session, algorithm_id: str) -> bool:
        """Delete an algorithm row.

        Returns:
            True if the algorithm existed and was deleted, False otherwise.
        """
        db_algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
        if not db_algorithm:
            return False

        # NOTE(review): deployment_service is not called here, so anything
        # deployed by create_algorithm is left running; whether versions are
        # removed depends on the model's cascade configuration.
        db.delete(db_algorithm)
        db.commit()
        return True
|
||
|
||
|
||
class AlgorithmVersionService:
    """CRUD operations for algorithm versions.

    When a version is created or updated with ``is_default`` set, any other
    default version of the same algorithm is demoted first.
    """

    @staticmethod
    def _unset_other_defaults(db: Session, algorithm_id: str, exclude_version_id: Optional[str] = None) -> None:
        """Bulk-clear ``is_default`` on the algorithm's current default version(s).

        Args:
            db: SQLAlchemy session (not committed here).
            algorithm_id: Algorithm whose versions are affected.
            exclude_version_id: If given, this version keeps its flag.
        """
        defaults = db.query(AlgorithmVersion).filter_by(algorithm_id=algorithm_id, is_default=True)
        if exclude_version_id is not None:
            defaults = defaults.filter(AlgorithmVersion.id != exclude_version_id)
        defaults.update({"is_default": False})

    @staticmethod
    def create_version(db: Session, version: AlgorithmVersionCreate) -> AlgorithmVersion:
        """Insert a new version for ``version.algorithm_id`` and commit it.

        Returns:
            The persisted ``AlgorithmVersion``, refreshed from the database.
        """
        new_version = AlgorithmVersion(
            id=f"version-{uuid.uuid4().hex[:8]}",
            algorithm_id=version.algorithm_id,
            version=version.version,
            url=version.url,
            params=version.params,
            input_schema=version.input_schema,
            output_schema=version.output_schema,
            code=version.code,
            model_name=version.model_name,
            model_file=version.model_file,
            api_doc=version.api_doc,
            is_default=version.is_default,
        )

        # Keep the bulk UPDATE before db.add(): with the new row pending in the
        # session, an autoflush could insert it first and the UPDATE would then
        # clear its flag as well.
        if version.is_default:
            AlgorithmVersionService._unset_other_defaults(db, version.algorithm_id)

        db.add(new_version)
        db.commit()
        db.refresh(new_version)
        return new_version

    @staticmethod
    def get_version_by_id(db: Session, version_id: str) -> Optional[AlgorithmVersion]:
        """Return the version with the given ID, or None."""
        return db.query(AlgorithmVersion).filter_by(id=version_id).first()

    @staticmethod
    def get_versions_by_algorithm_id(db: Session, algorithm_id: str) -> List[AlgorithmVersion]:
        """Return every version belonging to ``algorithm_id``."""
        return db.query(AlgorithmVersion).filter_by(algorithm_id=algorithm_id).all()

    @staticmethod
    def get_default_version(db: Session, algorithm_id: str) -> Optional[AlgorithmVersion]:
        """Return the first version of ``algorithm_id`` flagged as default, or None."""
        return db.query(AlgorithmVersion).filter_by(algorithm_id=algorithm_id, is_default=True).first()

    @staticmethod
    def update_version(db: Session, version_id: str, version_update: AlgorithmVersionUpdate) -> Optional[AlgorithmVersion]:
        """Apply the explicitly-set fields of ``version_update`` to a version.

        Returns:
            The updated version, or None if no version has that ID.
        """
        target = AlgorithmVersionService.get_version_by_id(db, version_id)
        if not target:
            return None

        changes = version_update.dict(exclude_unset=True)

        # Promoting this version: demote every other default of the same
        # algorithm before the new values are applied.
        if changes.get("is_default"):
            AlgorithmVersionService._unset_other_defaults(
                db,
                target.algorithm_id,
                exclude_version_id=version_id,
            )

        for attr_name, new_value in changes.items():
            setattr(target, attr_name, new_value)

        db.commit()
        db.refresh(target)
        return target

    @staticmethod
    def delete_version(db: Session, version_id: str) -> bool:
        """Delete a version row.

        Returns:
            True if the version existed and was deleted, False otherwise.
        """
        target = AlgorithmVersionService.get_version_by_id(db, version_id)
        if not target:
            return False

        db.delete(target)
        db.commit()
        return True
|
||
|
||
|
||
class AlgorithmCallService:
    """Service for recording algorithm calls and executing them."""

    @staticmethod
    def create_call(db: Session, user_id: str, call: AlgorithmCallCreate) -> AlgorithmCall:
        """Persist a new ``AlgorithmCall`` record for ``call`` made by ``user_id``.

        The record is committed immediately and refreshed from the database.
        """
        db_call = AlgorithmCall(
            id=f"call-{uuid.uuid4().hex[:8]}",
            user_id=user_id,
            algorithm_id=call.algorithm_id,
            version_id=call.version_id,
            input_data=call.input_data,
            params=call.params,
        )
        db.add(db_call)
        db.commit()
        db.refresh(db_call)
        return db_call

    @staticmethod
    def _mark_failed(db: Session, db_call: AlgorithmCall, message: str) -> AlgorithmCall:
        """Set ``db_call`` to status "failed" with ``message``, commit, and return it."""
        db_call.status = "failed"
        db_call.error_message = message
        db.commit()
        return db_call

    @staticmethod
    def _decode_data_url(data_url: str):
        """Split a ``data:`` URL into ``(header, payload_bytes)``.

        The payload after the first ',' is base64-decoded.

        Raises:
            ValueError: if there is no ',' separator, or (as ``binascii.Error``)
                if the payload is not valid base64.
        """
        header, encoded = data_url.split(',', 1)
        return header, base64.b64decode(encoded)

    @staticmethod
    def _upload_bytes(payload: bytes, object_name: str, content_type: str) -> bool:
        """Upload ``payload`` to file storage (MinIO per the original code) under ``object_name``.

        Returns:
            True if the storage backend reported success.
        """
        # Imported lazily, as in the original code.
        from app.utils.file import file_storage

        return bool(file_storage.upload_fileobj(io.BytesIO(payload), object_name, content_type))

    @staticmethod
    def execute_algorithm(db: Session, user_id: str, call: AlgorithmCallCreate, timeout: float = 30) -> AlgorithmCall:
        """Record an algorithm call, execute it over HTTP and store the outcome.

        Steps:
          1. Create the call record and commit it with status "running".
          2. Load the requested version. The call fails if the version is missing,
             belongs to a different algorithm than ``call.algorithm_id``, or has no URL.
          3. For the ``video`` and ``ply`` input keys, a base64 ``data:`` URL value
             is uploaded to file storage and replaced with the stored object name;
             other values are passed through unchanged.
          4. POST ``{"input_data": ..., "params": ...}`` to the version URL. On
             HTTP 200 the JSON body becomes ``output_data``; any other status marks
             the call failed with the response text.

        Exceptions raised during steps 2-4 are caught and recorded on the call
        (status "failed") rather than propagated.

        Args:
            db: SQLAlchemy session; committed several times.
            user_id: Calling user; also used in uploaded object names.
            call: Call payload (``algorithm_id``, ``version_id``, ``input_data``, ``params``).
            timeout: Seconds to wait for the algorithm endpoint (default 30).

        Returns:
            The ``AlgorithmCall`` record in its final state.
        """
        db_call = AlgorithmCallService.create_call(db, user_id, call)

        db_call.status = "running"
        db.commit()
        db.refresh(db_call)

        mark_failed = AlgorithmCallService._mark_failed
        start_time: Optional[float] = None
        try:
            version = AlgorithmVersionService.get_version_by_id(db, call.version_id)
            if not version:
                return mark_failed(db, db_call, "算法版本不存在")
            # Guard against recording a call for one algorithm while running another's version.
            if call.algorithm_id and version.algorithm_id != call.algorithm_id:
                return mark_failed(db, db_call, "算法版本与所选算法不匹配")
            if not version.url:
                return mark_failed(db, db_call, "算法版本未配置服务地址")

            processed_input_data = call.input_data.copy()

            video_data = processed_input_data.get('video')
            if isinstance(video_data, str) and video_data.startswith('data:'):
                header, video_bytes = AlgorithmCallService._decode_data_url(video_data)
                # The extension comes from the MIME subtype; fall back to mp4.
                match = re.search(r'data:video/(\w+);', header)
                ext = match.group(1) if match else 'mp4'
                video_filename = f"videos/{user_id}/{uuid.uuid4().hex[:12]}.{ext}"
                if not AlgorithmCallService._upload_bytes(video_bytes, video_filename, f'video/{ext}'):
                    return mark_failed(db, db_call, "视频文件上传失败")
                processed_input_data['video'] = video_filename

            ply_data = processed_input_data.get('ply')
            if isinstance(ply_data, str) and ply_data.startswith('data:'):
                _header, ply_bytes = AlgorithmCallService._decode_data_url(ply_data)
                ply_filename = f"ply/{user_id}/{uuid.uuid4().hex[:12]}.ply"
                if not AlgorithmCallService._upload_bytes(ply_bytes, ply_filename, 'application/octet-stream'):
                    return mark_failed(db, db_call, "PLY文件上传失败")
                processed_input_data['ply'] = ply_filename

            # perf_counter is monotonic, so the duration is immune to clock changes.
            start_time = time.perf_counter()
            response = requests.post(
                version.url,
                json={
                    "input_data": processed_input_data,
                    "params": call.params,
                },
                timeout=timeout,
            )
            db_call.response_time = time.perf_counter() - start_time

            if response.status_code == 200:
                db_call.output_data = response.json()
                db_call.status = "success"
            else:
                db_call.status = "failed"
                db_call.error_message = f"算法执行失败: {response.text}"
        except Exception as e:
            # After a failed flush the session refuses further work until rolled
            # back. Only changes made since the "running" commit are discarded.
            db.rollback()
            db_call.status = "failed"
            db_call.error_message = f"算法执行异常: {str(e)}"
            db_call.response_time = (
                time.perf_counter() - start_time if start_time is not None else None
            )

        db.commit()
        db.refresh(db_call)
        return db_call

    @staticmethod
    def get_call_by_id(db: Session, call_id: str) -> Optional[AlgorithmCall]:
        """Return the call record with the given ID, or None."""
        return db.query(AlgorithmCall).filter(AlgorithmCall.id == call_id).first()

    @staticmethod
    def get_calls_by_user_id(db: Session, user_id: str, skip: int = 0, limit: int = 100) -> List[AlgorithmCall]:
        """Return a page (offset ``skip``, at most ``limit`` rows) of a user's call records."""
        return db.query(AlgorithmCall).filter(
            AlgorithmCall.user_id == user_id
        ).offset(skip).limit(limit).all()

    @staticmethod
    def get_calls_by_algorithm_id(db: Session, algorithm_id: str, skip: int = 0, limit: int = 100) -> List[AlgorithmCall]:
        """Return a page (offset ``skip``, at most ``limit`` rows) of an algorithm's call records."""
        return db.query(AlgorithmCall).filter(
            AlgorithmCall.algorithm_id == algorithm_id
        ).offset(skip).limit(limit).all()

    @staticmethod
    def execute_python_code(code: str, timeout: float = 5) -> dict:
        """Run ``code`` in a separate Python interpreter and capture its output.

        WARNING: this is not a sandbox. The child process runs with the server's
        privileges; the only limit is the wall-clock ``timeout``. Production use
        needs a stricter sandbox.

        Args:
            code: Python source to execute.
            timeout: Seconds before the child process is killed (default 5).

        Returns:
            ``{"success": bool, "output": str, "error": str}``. ``output`` holds
            combined stdout/stderr on success. On a non-zero exit, ``error`` holds
            that combined output; on timeout or setup failure it holds a message.
        """
        result = {
            "success": False,
            "output": "",
            "error": "",
        }

        temp_file_name = None
        try:
            # Python decodes source files as UTF-8 by default (PEP 3120), so write
            # UTF-8 explicitly rather than in the platform locale encoding.
            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
                # Remember the name before writing so a failed write is still cleaned up.
                temp_file_name = f.name
                f.write(code)

            output = subprocess.check_output(
                [sys.executable, temp_file_name],
                stderr=subprocess.STDOUT,
                timeout=timeout,
                universal_newlines=True,
            )
            result["success"] = True
            result["output"] = output
        except subprocess.TimeoutExpired:
            result["error"] = f"代码执行超时(超过{timeout}秒)"
        except subprocess.CalledProcessError as e:
            result["error"] = e.output
        except Exception as e:
            result["error"] = f"执行环境错误: {str(e)}"
        finally:
            # Best-effort cleanup of the temporary script.
            if temp_file_name is not None:
                with contextlib.suppress(OSError):
                    os.unlink(temp_file_name)

        return result
|