"""Service layer for algorithms, their versions, and recorded algorithm calls."""
import base64
import io
import logging
import os
import re
import subprocess
import sys
import tempfile
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import requests
from sqlalchemy.orm import Session

from app.models.models import Algorithm, AlgorithmVersion, AlgorithmCall
from app.schemas.algorithm import (
    AlgorithmCreate,
    AlgorithmUpdate,
    AlgorithmVersionCreate,
    AlgorithmVersionUpdate,
    AlgorithmCallCreate,
)
from app.services.deployment import deployment_service

logger = logging.getLogger(__name__)

# Seconds to wait for an algorithm service to answer a single call.
_ALGORITHM_REQUEST_TIMEOUT = 30

# Pulls the subtype out of a data-URL header such as "data:video/mp4;base64".
_VIDEO_EXT_PATTERN = re.compile(r'data:video/(\w+);')

# NOTE(review): all IDs in this module use 8 random hex chars (32 bits). By the
# birthday bound, collisions become likely after tens of thousands of rows
# (most relevant for call records). Consider a longer suffix if the DB column
# width allows it.


class AlgorithmService:
    """CRUD operations for ``Algorithm`` rows, with auto-deployment on create."""

    @staticmethod
    def create_algorithm(db: Session, algorithm: AlgorithmCreate) -> Algorithm:
        """Create an algorithm together with its default version.

        The ``Algorithm`` row is committed first. If the payload has ``code``
        but no ``url``, an image is built and deployed through
        ``deployment_service``. The returned ``api_url`` becomes the default
        version's ``url``. Deployment failures are logged, not raised; the
        version is still created, keeping ``algorithm.url`` as its URL.

        Args:
            db: Active SQLAlchemy session.
            algorithm: Creation payload (algorithm fields plus the fields of
                the initial version).

        Returns:
            The persisted ``Algorithm``, with its ``versions`` relationship
            refreshed so it includes the new default version.
        """
        algorithm_id = f"algorithm-{uuid.uuid4().hex[:8]}"

        db_algorithm = Algorithm(
            id=algorithm_id,
            name=algorithm.name,
            description=algorithm.description,
            type=algorithm.type,
        )
        db.add(db_algorithm)
        db.commit()
        db.refresh(db_algorithm)

        # Auto-deploy only when code is supplied and no endpoint was given.
        deployed_url = algorithm.url
        if algorithm.code and not deployed_url:
            api_url = AlgorithmService._auto_deploy(algorithm_id, algorithm.name, algorithm.code)
            if api_url is not None:
                deployed_url = api_url

        db_version = AlgorithmVersion(
            id=f"version-{uuid.uuid4().hex[:8]}",
            algorithm_id=algorithm_id,
            version=algorithm.version,
            url=deployed_url,
            params=algorithm.params,
            input_schema=algorithm.input_schema,
            output_schema=algorithm.output_schema,
            code=algorithm.code,
            model_name=algorithm.model_name,
            model_file=algorithm.model_file,
            api_doc=algorithm.api_doc,
            is_default=True,
        )
        db.add(db_version)
        db.commit()
        db.refresh(db_version)

        # Reload the relationship so callers see the version just created.
        db.refresh(db_algorithm, ['versions'])
        return db_algorithm

    @staticmethod
    def _auto_deploy(algorithm_id: str, name: str, code: str) -> Optional[str]:
        """Build and deploy an image for ``code``; return its API URL or None on failure."""
        # NOTE(review): build/deploy logs are only written to the application
        # log; nothing here persists them with the algorithm or version.
        try:
            build_result = deployment_service.build_algorithm_image(name, code)
            if not build_result['success']:
                build_logs = build_result['logs']
                # Guard against an empty log list so a failed build is not
                # misreported as an IndexError from the deploy step.
                last_line = build_logs[-1] if build_logs else ''
                logger.warning("镜像构建失败: %s", last_line)
                return None

            deployment_info = deployment_service.deploy_algorithm(
                algorithm_id, build_result['image_name']
            )
            api_url = deployment_info['api_url']
            logger.info("部署成功: %s", api_url)
            return api_url
        except Exception as e:
            # Deployment is best-effort: the algorithm is created regardless.
            logger.exception("自动部署失败: %s", e)
            return None

    @staticmethod
    def get_algorithm_by_id(db: Session, algorithm_id: str) -> Optional[Algorithm]:
        """Return the algorithm with ``algorithm_id``, or None if it does not exist."""
        return db.query(Algorithm).filter(Algorithm.id == algorithm_id).first()

    @staticmethod
    def get_algorithms(db: Session, skip: int = 0, limit: int = 100,
                       algorithm_type: Optional[str] = None) -> List[Algorithm]:
        """Return up to ``limit`` algorithms after skipping ``skip``.

        If ``algorithm_type`` is given, only algorithms of that type are returned.

        NOTE(review): the query has no ORDER BY, so the row order used for
        offset/limit paging is database-defined.
        """
        query = db.query(Algorithm)
        if algorithm_type:
            query = query.filter(Algorithm.type == algorithm_type)
        return query.offset(skip).limit(limit).all()

    @staticmethod
    def update_algorithm(db: Session, algorithm_id: str,
                         algorithm_update: AlgorithmUpdate) -> Optional[Algorithm]:
        """Apply the explicitly-set fields of ``algorithm_update``.

        Returns:
            The updated algorithm, or None if ``algorithm_id`` does not exist.
        """
        db_algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
        if not db_algorithm:
            return None

        # exclude_unset: only fields the client actually sent are applied.
        update_data = algorithm_update.dict(exclude_unset=True)
        for field, value in update_data.items():
            setattr(db_algorithm, field, value)

        db.commit()
        db.refresh(db_algorithm)
        return db_algorithm

    @staticmethod
    def delete_algorithm(db: Session, algorithm_id: str) -> bool:
        """Delete the algorithm; return False if it does not exist, True otherwise."""
        db_algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
        if not db_algorithm:
            return False

        db.delete(db_algorithm)
        db.commit()
        return True


class AlgorithmVersionService:
    """CRUD operations for ``AlgorithmVersion`` rows."""

    @staticmethod
    def _clear_default_flag(db: Session, algorithm_id: str,
                            exclude_version_id: Optional[str] = None) -> None:
        """Set ``is_default`` to False on this algorithm's default versions (optionally sparing one)."""
        query = db.query(AlgorithmVersion).filter(
            AlgorithmVersion.algorithm_id == algorithm_id,
            # `== True` builds a SQL expression; `is True` would not.
            AlgorithmVersion.is_default == True,  # noqa: E712
        )
        if exclude_version_id is not None:
            query = query.filter(AlgorithmVersion.id != exclude_version_id)
        query.update({"is_default": False})

    @staticmethod
    def create_version(db: Session, version: AlgorithmVersionCreate) -> AlgorithmVersion:
        """Create a version.

        If it is flagged ``is_default``, every other default version of the
        same algorithm is demoted in the same commit.
        """
        db_version = AlgorithmVersion(
            id=f"version-{uuid.uuid4().hex[:8]}",
            algorithm_id=version.algorithm_id,
            version=version.version,
            url=version.url,
            params=version.params,
            input_schema=version.input_schema,
            output_schema=version.output_schema,
            code=version.code,
            model_name=version.model_name,
            model_file=version.model_file,
            api_doc=version.api_doc,
            is_default=version.is_default,
        )

        # The new row has not been added yet, so no exclusion is needed.
        if version.is_default:
            AlgorithmVersionService._clear_default_flag(db, version.algorithm_id)

        db.add(db_version)
        db.commit()
        db.refresh(db_version)
        return db_version

    @staticmethod
    def get_version_by_id(db: Session, version_id: str) -> Optional[AlgorithmVersion]:
        """Return the version with ``version_id``, or None if it does not exist."""
        return db.query(AlgorithmVersion).filter(AlgorithmVersion.id == version_id).first()

    @staticmethod
    def get_versions_by_algorithm_id(db: Session, algorithm_id: str) -> List[AlgorithmVersion]:
        """Return every version that belongs to ``algorithm_id``."""
        return db.query(AlgorithmVersion).filter(AlgorithmVersion.algorithm_id == algorithm_id).all()

    @staticmethod
    def get_default_version(db: Session, algorithm_id: str) -> Optional[AlgorithmVersion]:
        """Return the algorithm's default version, or None if none is flagged."""
        return db.query(AlgorithmVersion).filter(
            AlgorithmVersion.algorithm_id == algorithm_id,
            AlgorithmVersion.is_default == True,  # noqa: E712 - SQL expression
        ).first()

    @staticmethod
    def update_version(db: Session, version_id: str,
                       version_update: AlgorithmVersionUpdate) -> Optional[AlgorithmVersion]:
        """Apply the explicitly-set fields of ``version_update``.

        Setting ``is_default`` to True demotes the algorithm's other default
        versions.

        Returns:
            The updated version, or None if ``version_id`` does not exist.
        """
        db_version = AlgorithmVersionService.get_version_by_id(db, version_id)
        if not db_version:
            return None

        update_data = version_update.dict(exclude_unset=True)

        if update_data.get("is_default"):
            AlgorithmVersionService._clear_default_flag(
                db, db_version.algorithm_id, exclude_version_id=version_id
            )

        for field, value in update_data.items():
            setattr(db_version, field, value)

        db.commit()
        db.refresh(db_version)
        return db_version

    @staticmethod
    def delete_version(db: Session, version_id: str) -> bool:
        """Delete the version; return False if it does not exist, True otherwise.

        Deleting the default version does not promote another version to default.
        """
        db_version = AlgorithmVersionService.get_version_by_id(db, version_id)
        if not db_version:
            return False

        db.delete(db_version)
        db.commit()
        return True


class AlgorithmCallService:
    """Records algorithm invocations and executes them against version endpoints."""

    @staticmethod
    def create_call(db: Session, user_id: str, call: AlgorithmCallCreate) -> AlgorithmCall:
        """Persist a new call record for ``user_id`` and return it.

        ``status`` is not set here; presumably the model supplies a default.
        TODO confirm.
        """
        db_call = AlgorithmCall(
            id=f"call-{uuid.uuid4().hex[:8]}",
            user_id=user_id,
            algorithm_id=call.algorithm_id,
            version_id=call.version_id,
            input_data=call.input_data,
            params=call.params,
        )
        db.add(db_call)
        db.commit()
        db.refresh(db_call)
        return db_call

    @staticmethod
    def _mark_failed(db: Session, db_call: AlgorithmCall, message: str) -> AlgorithmCall:
        """Record ``message`` as the failure reason, commit, and return the call."""
        db_call.status = "failed"
        db_call.error_message = message
        db.commit()
        return db_call

    @staticmethod
    def _decode_data_url(data_url: str) -> Tuple[str, bytes]:
        """Split ``data:<header>,<base64>`` into its header and decoded payload bytes."""
        header, encoded = data_url.split(',', 1)
        return header, base64.b64decode(encoded)

    @staticmethod
    def _upload_bytes(payload: bytes, object_name: str, content_type: str) -> bool:
        """Upload ``payload`` to file storage as ``object_name``; return the storage's success flag."""
        # Imported lazily, only when an inline file actually needs uploading.
        from app.utils.file import file_storage
        return file_storage.upload_fileobj(io.BytesIO(payload), object_name, content_type)

    @staticmethod
    def execute_algorithm(db: Session, user_id: str, call: AlgorithmCallCreate) -> AlgorithmCall:
        """Invoke an algorithm version over HTTP and record the outcome.

        Steps:

        1. Persist a call record and mark it ``running``.
        2. Upload inline ``data:`` URL payloads found under the ``video`` and
           ``ply`` input keys to file storage, replacing each with its stored
           object name.
        3. POST ``{"input_data": ..., "params": ...}`` to the version's ``url``.

        Failures after the record is created are stored on the record, not
        raised.

        Returns:
            The call record. It ends with ``status`` ``"success"`` (with
            ``output_data`` and ``response_time``) or ``"failed"`` (with
            ``error_message``).
        """
        db_call = AlgorithmCallService.create_call(db, user_id, call)

        db_call.status = "running"
        db.commit()
        db.refresh(db_call)

        # Set just before the HTTP request so the error path can tell whether
        # a response time is meaningful.
        start_time: Optional[float] = None
        try:
            version = AlgorithmVersionService.get_version_by_id(db, call.version_id)
            if not version:
                return AlgorithmCallService._mark_failed(db, db_call, "算法版本不存在")
            # A version created after a failed auto-deploy may have no endpoint.
            if not version.url:
                return AlgorithmCallService._mark_failed(db, db_call, "算法版本未配置调用地址")

            processed_input_data = dict(call.input_data or {})

            # Inline video: "data:video/<ext>;base64,<payload>" -> stored object name.
            video_data = processed_input_data.get('video')
            if video_data and video_data.startswith('data:'):
                header, video_bytes = AlgorithmCallService._decode_data_url(video_data)
                match = _VIDEO_EXT_PATTERN.search(header)
                ext = match.group(1) if match else 'mp4'
                video_filename = f"videos/{user_id}/{uuid.uuid4().hex[:12]}.{ext}"
                if not AlgorithmCallService._upload_bytes(video_bytes, video_filename, f'video/{ext}'):
                    return AlgorithmCallService._mark_failed(db, db_call, "视频文件上传失败")
                processed_input_data['video'] = video_filename

            # Inline PLY point cloud -> stored object name.
            ply_data = processed_input_data.get('ply')
            if ply_data and ply_data.startswith('data:'):
                _, ply_bytes = AlgorithmCallService._decode_data_url(ply_data)
                ply_filename = f"ply/{user_id}/{uuid.uuid4().hex[:12]}.ply"
                if not AlgorithmCallService._upload_bytes(ply_bytes, ply_filename, 'application/octet-stream'):
                    return AlgorithmCallService._mark_failed(db, db_call, "PLY文件上传失败")
                processed_input_data['ply'] = ply_filename

            # perf_counter is monotonic, so wall-clock jumps cannot skew the timing.
            start_time = time.perf_counter()
            response = requests.post(
                version.url,
                json={
                    "input_data": processed_input_data,
                    "params": call.params,
                },
                timeout=_ALGORITHM_REQUEST_TIMEOUT,
            )
            response_time = time.perf_counter() - start_time

            if response.status_code == 200:
                db_call.output_data = response.json()
                db_call.status = "success"
            else:
                db_call.status = "failed"
                db_call.error_message = f"算法执行失败: {response.text}"
            db_call.response_time = response_time
        except Exception as e:
            logger.exception("algorithm call %s failed", db_call.id)
            # If the error came from the database, the session must be rolled
            # back before the failure can be committed. Otherwise the commit
            # below raises and the record stays "running".
            db.rollback()
            db_call.status = "failed"
            db_call.error_message = f"算法执行异常: {str(e)}"
            db_call.response_time = (
                time.perf_counter() - start_time if start_time is not None else None
            )

        db.commit()
        db.refresh(db_call)
        return db_call

    @staticmethod
    def get_call_by_id(db: Session, call_id: str) -> Optional[AlgorithmCall]:
        """Return the call record with ``call_id``, or None if it does not exist."""
        return db.query(AlgorithmCall).filter(AlgorithmCall.id == call_id).first()

    @staticmethod
    def get_calls_by_user_id(db: Session, user_id: str, skip: int = 0,
                             limit: int = 100) -> List[AlgorithmCall]:
        """Return up to ``limit`` of ``user_id``'s call records after skipping ``skip``.

        NOTE(review): no ORDER BY, so paging order is database-defined.
        """
        return db.query(AlgorithmCall).filter(
            AlgorithmCall.user_id == user_id
        ).offset(skip).limit(limit).all()

    @staticmethod
    def get_calls_by_algorithm_id(db: Session, algorithm_id: str, skip: int = 0,
                                  limit: int = 100) -> List[AlgorithmCall]:
        """Return up to ``limit`` call records of ``algorithm_id`` after skipping ``skip``.

        NOTE(review): no ORDER BY, so paging order is database-defined.
        """
        return db.query(AlgorithmCall).filter(
            AlgorithmCall.algorithm_id == algorithm_id
        ).offset(skip).limit(limit).all()

    @staticmethod
    def execute_python_code(code: str) -> dict:
        """Run ``code`` in a separate Python interpreter with a 5-second limit.

        SECURITY: this executes arbitrary caller-supplied code with the
        server's privileges. A separate process plus a timeout is NOT a
        sandbox. Never expose this to untrusted users without real isolation
        (container, seccomp, resource limits).

        Returns:
            ``{"success": bool, "output": str, "error": str}``. ``output``
            holds combined stdout/stderr on success. On a non-zero exit,
            ``error`` holds that combined output instead.
        """
        result = {
            "success": False,
            "output": "",
            "error": "",
        }

        temp_file_name = None
        try:
            try:
                # Python reads source files as UTF-8 by default, so write them
                # as UTF-8 regardless of the server locale.
                with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False,
                                                 encoding='utf-8') as f:
                    temp_file_name = f.name
                    f.write(code)

                # Force the child to emit UTF-8 and decode it the same way, so
                # non-ASCII output survives on non-UTF-8 locales.
                child_env = dict(os.environ, PYTHONIOENCODING='utf-8')
                output = subprocess.check_output(
                    [sys.executable, temp_file_name],
                    stderr=subprocess.STDOUT,
                    timeout=5,
                    encoding='utf-8',
                    errors='replace',
                    env=child_env,
                )
                result["success"] = True
                result["output"] = output
            finally:
                # Also runs when the write itself failed, so the file never leaks.
                if temp_file_name and os.path.exists(temp_file_name):
                    os.unlink(temp_file_name)
        except subprocess.TimeoutExpired:
            result["error"] = "代码执行超时(超过5秒)"
        except subprocess.CalledProcessError as e:
            result["error"] = e.output
        except Exception as e:
            result["error"] = f"执行环境错误: {str(e)}"

        return result