Files
algorithm/backend/app/services/algorithm.py
2026-02-08 14:42:58 +08:00

466 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import io
import logging
import re
import time
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any

import requests
from sqlalchemy.orm import Session

from app.models.models import Algorithm, AlgorithmVersion, AlgorithmCall
from app.schemas.algorithm import AlgorithmCreate, AlgorithmUpdate, AlgorithmVersionCreate, AlgorithmVersionUpdate, AlgorithmCallCreate
from app.services.deployment import deployment_service
class AlgorithmService:
    """Service layer for creating, querying, updating and deleting algorithms."""

    _logger = logging.getLogger(__name__)

    @staticmethod
    def create_algorithm(db: Session, algorithm: AlgorithmCreate) -> Algorithm:
        """Create an algorithm together with its default version.

        When ``algorithm.code`` is given and ``algorithm.url`` is empty, the code
        is built into an image and deployed through ``deployment_service``; the
        API URL it reports becomes the default version's ``url``. Deployment
        problems are logged and never abort creation: the version is then
        stored with ``url=algorithm.url``.

        Args:
            db: Active SQLAlchemy session. This method commits it (twice).
            algorithm: Creation payload carrying both the algorithm metadata
                and the fields of its first (default) version.

        Returns:
            The persisted ``Algorithm`` with its ``versions`` relationship
            refreshed.
        """
        algorithm_id = f"algorithm-{uuid.uuid4().hex[:8]}"
        db_algorithm = Algorithm(
            id=algorithm_id,
            name=algorithm.name,
            description=algorithm.description,
            type=algorithm.type,
        )
        # The algorithm row is committed before attempting deployment.
        # NOTE(review): if inserting the default version below fails, this
        # algorithm row stays committed without any version.
        db.add(db_algorithm)
        db.commit()
        db.refresh(db_algorithm)

        deployed_url = algorithm.url
        if algorithm.code and not deployed_url:
            api_url = AlgorithmService._auto_deploy(algorithm_id, algorithm)
            if api_url is not None:
                deployed_url = api_url

        db_version = AlgorithmVersion(
            id=f"version-{uuid.uuid4().hex[:8]}",
            algorithm_id=algorithm_id,
            version=algorithm.version,
            url=deployed_url,
            params=algorithm.params,
            input_schema=algorithm.input_schema,
            output_schema=algorithm.output_schema,
            code=algorithm.code,
            model_name=algorithm.model_name,
            model_file=algorithm.model_file,
            api_doc=algorithm.api_doc,
            is_default=True,
        )
        db.add(db_version)
        db.commit()
        db.refresh(db_version)
        # Reload the relationship so the returned object lists the new version.
        db.refresh(db_algorithm, ['versions'])
        return db_algorithm

    @staticmethod
    def _auto_deploy(algorithm_id: str, algorithm: AlgorithmCreate) -> Optional[str]:
        """Build ``algorithm.code`` into an image and deploy it.

        Returns:
            The API URL reported by ``deployment_service.deploy_algorithm``, or
            ``None`` when the build reports failure or any step raises. Failures
            are logged and never propagated (deployment is best-effort).
        """
        try:
            build_result = deployment_service.build_algorithm_image(
                algorithm.name,
                algorithm.code,
            )
            if not build_result['success']:
                build_logs = build_result['logs']
                # An empty log list would otherwise raise IndexError here and be
                # misreported as a generic deployment failure.
                last_line = build_logs[-1] if build_logs else "(no build logs)"
                AlgorithmService._logger.warning("镜像构建失败: %s", last_line)
                return None
            deployment_info = deployment_service.deploy_algorithm(
                algorithm_id,
                build_result['image_name'],
            )
            api_url = deployment_info['api_url']
            AlgorithmService._logger.info("部署成功: %s", api_url)
            return api_url
        except Exception as exc:  # best-effort: creation continues without a deployment
            AlgorithmService._logger.exception("自动部署失败: %s", exc)
            return None

    @staticmethod
    def get_algorithm_by_id(db: Session, algorithm_id: str) -> Optional[Algorithm]:
        """Return the algorithm with ``algorithm_id``, or ``None`` if absent."""
        return db.query(Algorithm).filter(Algorithm.id == algorithm_id).first()

    @staticmethod
    def get_algorithms(db: Session, skip: int = 0, limit: int = 100, algorithm_type: Optional[str] = None) -> List[Algorithm]:
        """Return one page of algorithms, optionally filtered by ``type``.

        Args:
            db: Active SQLAlchemy session.
            skip: Number of rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
            algorithm_type: When truthy, only algorithms whose ``type`` equals it.

        NOTE(review): no ORDER BY is applied, so page contents are not
        guaranteed to be stable across calls.
        """
        query = db.query(Algorithm)
        if algorithm_type:
            query = query.filter(Algorithm.type == algorithm_type)
        return query.offset(skip).limit(limit).all()

    @staticmethod
    def update_algorithm(db: Session, algorithm_id: str, algorithm_update: AlgorithmUpdate) -> Optional[Algorithm]:
        """Apply the explicitly-set fields of ``algorithm_update`` to an algorithm.

        Only fields present in the request (``exclude_unset=True``) are written.

        Returns:
            The updated algorithm, or ``None`` if ``algorithm_id`` does not exist.
        """
        db_algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
        if not db_algorithm:
            return None
        update_data = algorithm_update.dict(exclude_unset=True)
        for field, value in update_data.items():
            setattr(db_algorithm, field, value)
        db.commit()
        db.refresh(db_algorithm)
        return db_algorithm

    @staticmethod
    def delete_algorithm(db: Session, algorithm_id: str) -> bool:
        """Delete the algorithm with ``algorithm_id``.

        Returns:
            ``True`` if a row was deleted, ``False`` if it did not exist.

        NOTE(review): whether its versions/calls are removed too depends on the
        relationship cascade configured on the ``Algorithm`` model.
        """
        db_algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
        if not db_algorithm:
            return False
        db.delete(db_algorithm)
        db.commit()
        return True
class AlgorithmVersionService:
    """CRUD helpers for ``AlgorithmVersion`` rows.

    Creating or updating a version with ``is_default`` set clears the default
    flag on the algorithm's other versions, so a single default is kept.
    """

    @staticmethod
    def create_version(db: Session, version: AlgorithmVersionCreate) -> AlgorithmVersion:
        """Insert a version built from ``version`` and return it (committed).

        If ``version.is_default`` is truthy, sibling default versions of the
        same algorithm are demoted before the new row is added.
        """
        copied_attrs = (
            "algorithm_id",
            "version",
            "url",
            "params",
            "input_schema",
            "output_schema",
            "code",
            "model_name",
            "model_file",
            "api_doc",
            "is_default",
        )
        new_version = AlgorithmVersion(
            id=f"version-{uuid.uuid4().hex[:8]}",
            **{attr: getattr(version, attr) for attr in copied_attrs},
        )
        if version.is_default:
            demote = db.query(AlgorithmVersion).filter(
                AlgorithmVersion.algorithm_id == version.algorithm_id,
                AlgorithmVersion.is_default == True,  # noqa: E712 - SQL expression
            )
            demote.update({"is_default": False})
        db.add(new_version)
        db.commit()
        db.refresh(new_version)
        return new_version

    @staticmethod
    def get_version_by_id(db: Session, version_id: str) -> Optional[AlgorithmVersion]:
        """Return the version with ``version_id``, or ``None`` if absent."""
        return db.query(AlgorithmVersion).filter_by(id=version_id).first()

    @staticmethod
    def get_versions_by_algorithm_id(db: Session, algorithm_id: str) -> List[AlgorithmVersion]:
        """Return every version belonging to ``algorithm_id``."""
        return db.query(AlgorithmVersion).filter_by(algorithm_id=algorithm_id).all()

    @staticmethod
    def get_default_version(db: Session, algorithm_id: str) -> Optional[AlgorithmVersion]:
        """Return the version of ``algorithm_id`` flagged as default, or ``None``."""
        return (
            db.query(AlgorithmVersion)
            .filter_by(algorithm_id=algorithm_id, is_default=True)
            .first()
        )

    @staticmethod
    def update_version(db: Session, version_id: str, version_update: AlgorithmVersionUpdate) -> Optional[AlgorithmVersion]:
        """Apply the explicitly-set fields of ``version_update`` to a version.

        Promoting the version to default demotes the algorithm's other
        default versions first.

        Returns:
            The updated version, or ``None`` if ``version_id`` does not exist.
        """
        target = AlgorithmVersionService.get_version_by_id(db, version_id)
        if not target:
            return None
        changes = version_update.dict(exclude_unset=True)
        if changes.get("is_default"):
            others = db.query(AlgorithmVersion).filter(
                AlgorithmVersion.algorithm_id == target.algorithm_id,
                AlgorithmVersion.is_default == True,  # noqa: E712 - SQL expression
                AlgorithmVersion.id != version_id,
            )
            others.update({"is_default": False})
        for attr, new_value in changes.items():
            setattr(target, attr, new_value)
        db.commit()
        db.refresh(target)
        return target

    @staticmethod
    def delete_version(db: Session, version_id: str) -> bool:
        """Delete the version with ``version_id``; return whether it existed."""
        doomed = AlgorithmVersionService.get_version_by_id(db, version_id)
        if doomed:
            db.delete(doomed)
            db.commit()
            return True
        return False
class AlgorithmCallService:
    """Service layer for recording and executing algorithm calls."""

    _logger = logging.getLogger(__name__)

    # Extracts the extension from a data-URL header, e.g.
    # "data:video/mp4;base64" -> "mp4".
    _VIDEO_EXT_PATTERN = re.compile(r'data:video/(\w+);')

    @staticmethod
    def create_call(db: Session, user_id: str, call: AlgorithmCallCreate) -> AlgorithmCall:
        """Persist a new call record for ``user_id`` and return it (committed)."""
        db_call = AlgorithmCall(
            id=f"call-{uuid.uuid4().hex[:8]}",
            user_id=user_id,
            algorithm_id=call.algorithm_id,
            version_id=call.version_id,
            input_data=call.input_data,
            params=call.params,
        )
        db.add(db_call)
        db.commit()
        db.refresh(db_call)
        return db_call

    @staticmethod
    def _mark_failed(db: Session, db_call: AlgorithmCall, message: str) -> AlgorithmCall:
        """Set ``db_call`` to ``failed`` with ``message``, commit, and return it."""
        db_call.status = "failed"
        db_call.error_message = message
        db.commit()
        db.refresh(db_call)
        return db_call

    @staticmethod
    def _decode_data_url(value: Any) -> Optional[tuple]:
        """Split a ``data:<mime>;base64,<payload>`` string into ``(header, bytes)``.

        Returns ``None`` when ``value`` is not a data-URL string (for example an
        object-storage path that was already uploaded); such values are passed
        through to the algorithm unchanged. Malformed payloads raise
        (``ValueError`` / ``binascii.Error``).
        """
        if not isinstance(value, str) or not value.startswith('data:'):
            return None
        header, encoded = value.split(',', 1)
        return header, base64.b64decode(encoded)

    @staticmethod
    def _upload_bytes(payload: bytes, object_name: str, content_type: str) -> bool:
        """Upload ``payload`` to object storage under ``object_name``; return success."""
        # Imported lazily, as the original code did (presumably to avoid an
        # import cycle or storage-client setup at module import time).
        from app.utils.file import file_storage
        return bool(file_storage.upload_fileobj(io.BytesIO(payload), object_name, content_type))

    @staticmethod
    def _prepare_input(user_id: str, input_data: Optional[Dict[str, Any]]) -> tuple:
        """Upload inline ``video`` / ``ply`` data URLs and swap in their storage paths.

        Args:
            user_id: Owner of the call; used as a path segment of the object name.
            input_data: Raw call input; it is shallow-copied, never mutated.
                ``None`` is treated as an empty dict.

        Returns:
            ``(processed_input, error_message)`` where ``error_message`` is
            ``None`` on success, or the failure text if an upload returned false.
        """
        processed: Dict[str, Any] = dict(input_data or {})

        video = AlgorithmCallService._decode_data_url(processed.get('video'))
        if video is not None:
            header, video_bytes = video
            match = AlgorithmCallService._VIDEO_EXT_PATTERN.search(header)
            ext = match.group(1) if match else 'mp4'
            video_filename = f"videos/{user_id}/{uuid.uuid4().hex[:12]}.{ext}"
            if not AlgorithmCallService._upload_bytes(video_bytes, video_filename, f'video/{ext}'):
                return processed, "视频文件上传失败"
            processed['video'] = video_filename

        ply = AlgorithmCallService._decode_data_url(processed.get('ply'))
        if ply is not None:
            _, ply_bytes = ply
            ply_filename = f"ply/{user_id}/{uuid.uuid4().hex[:12]}.ply"
            if not AlgorithmCallService._upload_bytes(ply_bytes, ply_filename, 'application/octet-stream'):
                return processed, "PLY文件上传失败"
            processed['ply'] = ply_filename

        return processed, None

    @staticmethod
    def execute_algorithm(db: Session, user_id: str, call: AlgorithmCallCreate, timeout: float = 30) -> AlgorithmCall:
        """Record a call, invoke the version's HTTP endpoint, and store the outcome.

        The call row moves ``running`` -> ``success`` / ``failed``. Inline
        ``video`` / ``ply`` data URLs in ``call.input_data`` are uploaded first
        and replaced by their storage paths. The endpoint receives
        ``{"input_data": ..., "params": ...}`` as JSON. Errors are never raised
        to the caller; they are written to ``error_message``.

        Args:
            db: Active SQLAlchemy session (committed several times).
            user_id: Caller's user id.
            call: Call payload (algorithm id, version id, input data, params).
            timeout: Seconds passed to ``requests.post`` (default 30).

        Returns:
            The refreshed ``AlgorithmCall`` row.
        """
        db_call = AlgorithmCallService.create_call(db, user_id, call)
        call_id = db_call.id  # captured now; the instance may be expired after a rollback
        db_call.status = "running"
        db.commit()
        db.refresh(db_call)

        start_time: Optional[float] = None
        try:
            version = AlgorithmVersionService.get_version_by_id(db, call.version_id)
            if not version:
                return AlgorithmCallService._mark_failed(db, db_call, "算法版本不存在")
            # Reject a version of another algorithm so the call is never recorded
            # against the wrong algorithm.
            if call.algorithm_id and version.algorithm_id != call.algorithm_id:
                return AlgorithmCallService._mark_failed(db, db_call, "算法版本与算法不匹配")
            # Versions may be stored without a URL (e.g. auto-deployment failed).
            if not version.url:
                return AlgorithmCallService._mark_failed(db, db_call, "算法版本未配置服务地址")

            processed_input, upload_error = AlgorithmCallService._prepare_input(user_id, call.input_data)
            if upload_error is not None:
                return AlgorithmCallService._mark_failed(db, db_call, upload_error)

            start_time = time.perf_counter()
            response = requests.post(
                version.url,
                json={
                    "input_data": processed_input,
                    "params": call.params,
                },
                timeout=timeout,
            )
            db_call.response_time = time.perf_counter() - start_time

            if response.status_code == 200:
                db_call.output_data = response.json()
                db_call.status = "success"
            else:
                db_call.status = "failed"
                db_call.error_message = f"算法执行失败: {response.text}"
        except Exception as e:
            # Reset the session first: if the failure came from the database, the
            # commit below would otherwise raise PendingRollbackError.
            db.rollback()
            AlgorithmCallService._logger.exception("算法执行异常 (call_id=%s)", call_id)
            db_call.status = "failed"
            db_call.error_message = f"算法执行异常: {str(e)}"
            db_call.response_time = (
                time.perf_counter() - start_time if start_time is not None else None
            )

        db.commit()
        db.refresh(db_call)
        return db_call

    @staticmethod
    def get_call_by_id(db: Session, call_id: str) -> Optional[AlgorithmCall]:
        """Return the call record with ``call_id``, or ``None`` if absent."""
        return db.query(AlgorithmCall).filter(AlgorithmCall.id == call_id).first()

    @staticmethod
    def get_calls_by_user_id(db: Session, user_id: str, skip: int = 0, limit: int = 100) -> List[AlgorithmCall]:
        """Return one page of ``user_id``'s call records (no ORDER BY applied)."""
        return db.query(AlgorithmCall).filter(
            AlgorithmCall.user_id == user_id
        ).offset(skip).limit(limit).all()

    @staticmethod
    def get_calls_by_algorithm_id(db: Session, algorithm_id: str, skip: int = 0, limit: int = 100) -> List[AlgorithmCall]:
        """Return one page of call records for ``algorithm_id`` (no ORDER BY applied)."""
        return db.query(AlgorithmCall).filter(
            AlgorithmCall.algorithm_id == algorithm_id
        ).offset(skip).limit(limit).all()

    @staticmethod
    def execute_python_code(code: str, timeout: float = 5) -> dict:
        """Run ``code`` as a standalone script in a child Python interpreter.

        SECURITY: this executes arbitrary code with the same OS privileges and
        environment variables as the server process. Only a wall-clock timeout
        is enforced; it is NOT a sandbox. Never expose it to untrusted users
        without real isolation (container, seccomp, separate user, ...).

        Args:
            code: Python source text.
            timeout: Seconds before the child is killed (default 5).

        Returns:
            ``{"success": bool, "output": str, "error": str}``; ``output`` holds
            combined stdout/stderr on success, ``error`` the failure text.
        """
        import os
        import subprocess
        import sys
        import tempfile

        result = {
            "success": False,
            "output": "",
            "error": "",
        }
        temp_file_name = None
        try:
            # delete=False: the file must survive being closed so the child
            # interpreter can read it. UTF-8 because Python decodes source files
            # as UTF-8 by default, whatever the locale.
            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
                temp_file_name = f.name  # record first so cleanup works even if write fails
                f.write(code)
            output = subprocess.check_output(
                [sys.executable, temp_file_name],
                stderr=subprocess.STDOUT,
                timeout=timeout,
                universal_newlines=True,
                errors='replace',  # undecodable child output must not become an environment error
            )
            result["success"] = True
            result["output"] = output
        except subprocess.TimeoutExpired:
            result["error"] = f"代码执行超时超过{timeout}秒"
        except subprocess.CalledProcessError as e:
            result["error"] = e.output
        except Exception as e:
            result["error"] = f"执行环境错误: {str(e)}"
        finally:
            if temp_file_name is not None and os.path.exists(temp_file_name):
                try:
                    os.unlink(temp_file_name)
                except OSError:
                    # Best-effort cleanup; never let it mask the execution result.
                    AlgorithmCallService._logger.warning("无法删除临时文件: %s", temp_file_name)
        return result