"""Service layer for algorithms, algorithm versions and algorithm call records.

Each service class groups static helpers that operate on a SQLAlchemy
``Session`` passed in by the caller; the helpers commit their own changes.
"""

import base64
import io
import logging
import os
import re
import subprocess
import sys
import tempfile
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

import requests
from sqlalchemy.orm import Session

from app.models.models import Algorithm, AlgorithmVersion, AlgorithmCall
from app.schemas.algorithm import (
    AlgorithmCreate,
    AlgorithmUpdate,
    AlgorithmVersionCreate,
    AlgorithmVersionUpdate,
    AlgorithmCallCreate,
)
from app.services.deployment import deployment_service

logger = logging.getLogger(__name__)


class AlgorithmService:
    """CRUD operations for ``Algorithm`` rows."""

    @staticmethod
    def _auto_deploy(algorithm_id: str, algorithm: AlgorithmCreate) -> Optional[str]:
        """Build an image from the algorithm code and deploy it.

        Returns the deployed API URL, or ``None`` when the build or the
        deployment fails. Failures are logged and never propagated, so
        algorithm creation can continue without a URL.
        """
        try:
            build_result = deployment_service.build_algorithm_image(
                algorithm.name, algorithm.code
            )
            build_logs = build_result['logs']
            if not build_result['success']:
                # Guard against an empty log list instead of raising IndexError.
                last_line = build_logs[-1] if build_logs else ""
                logger.error("镜像构建失败: %s", last_line)
                return None

            deployment_info = deployment_service.deploy_algorithm(
                algorithm_id, build_result['image_name']
            )
            deployed_url = deployment_info['api_url']
            logger.info("部署成功: %s", deployed_url)
            return deployed_url
        except Exception as e:
            # Best effort: a broken deployment must not abort the creation.
            logger.exception("自动部署失败: %s", e)
            return None

    @staticmethod
    def create_algorithm(db: Session, algorithm: AlgorithmCreate) -> Algorithm:
        """Create an algorithm together with its default version.

        The algorithm row is committed first. When source code is supplied
        without a URL, an automatic build and deployment is attempted and the
        resulting URL is stored on the default version.

        Args:
            db: Active database session.
            algorithm: Creation payload.

        Returns:
            The persisted algorithm with its ``versions`` relationship loaded.
        """
        algorithm_id = f"algorithm-{uuid.uuid4().hex[:8]}"

        db_algorithm = Algorithm(
            id=algorithm_id,
            name=algorithm.name,
            description=algorithm.description,
            type=algorithm.type
        )
        db.add(db_algorithm)
        db.commit()
        db.refresh(db_algorithm)

        # Deploy automatically only when code is given and no URL was supplied.
        deployed_url = algorithm.url
        if algorithm.code and not deployed_url:
            deployed_url = (
                AlgorithmService._auto_deploy(algorithm_id, algorithm) or deployed_url
            )

        db_version = AlgorithmVersion(
            id=f"version-{uuid.uuid4().hex[:8]}",
            algorithm_id=algorithm_id,
            version=algorithm.version,
            url=deployed_url,
            params=algorithm.params,
            input_schema=algorithm.input_schema,
            output_schema=algorithm.output_schema,
            code=algorithm.code,
            model_name=algorithm.model_name,
            model_file=algorithm.model_file,
            api_doc=algorithm.api_doc,
            is_default=True
        )
        db.add(db_version)
        db.commit()
        db.refresh(db_version)

        # Reload the relationship so the returned object includes the new version.
        db.refresh(db_algorithm, ['versions'])
        return db_algorithm

    @staticmethod
    def get_algorithm_by_id(db: Session, algorithm_id: str) -> Optional[Algorithm]:
        """Return the algorithm with the given id, or ``None`` if absent."""
        return db.query(Algorithm).filter(Algorithm.id == algorithm_id).first()

    @staticmethod
    def get_algorithms(
        db: Session,
        skip: int = 0,
        limit: int = 100,
        algorithm_type: Optional[str] = None,
    ) -> List[Algorithm]:
        """List algorithms, placing those that match a registered service first.

        Args:
            db: Active database session.
            skip: Number of leading results to drop.
            limit: Maximum number of results to return.
            algorithm_type: Optional type filter.

        Returns:
            One page of the merged list. Pagination is applied in Python after
            the reordering, so every matching row is loaded.
        """
        # Imported locally under an alias: the model shares its name with this class.
        from app.models.models import AlgorithmService as Service

        service_names = {service.name for service in db.query(Service).all()}

        query = db.query(Algorithm)
        if algorithm_type:
            query = query.filter(Algorithm.type == algorithm_type)

        # A stable sort on a boolean key keeps the query order within each group
        # while moving registered algorithms (key False) to the front.
        ordered = sorted(query.all(), key=lambda algo: algo.name not in service_names)
        return ordered[skip:skip + limit]

    @staticmethod
    def update_algorithm(
        db: Session, algorithm_id: str, algorithm_update: AlgorithmUpdate
    ) -> Optional[Algorithm]:
        """Apply the explicitly set fields of *algorithm_update*.

        Returns the refreshed algorithm, or ``None`` when the id is unknown.
        """
        db_algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
        if not db_algorithm:
            return None

        # exclude_unset: only fields the client actually sent are written.
        for field, value in algorithm_update.dict(exclude_unset=True).items():
            setattr(db_algorithm, field, value)

        db.commit()
        db.refresh(db_algorithm)
        return db_algorithm

    @staticmethod
    def delete_algorithm(db: Session, algorithm_id: str) -> bool:
        """Delete an algorithm; return ``False`` when the id is unknown."""
        db_algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
        if not db_algorithm:
            return False

        db.delete(db_algorithm)
        db.commit()
        return True


class AlgorithmVersionService:
    """CRUD operations for ``AlgorithmVersion`` rows."""

    @staticmethod
    def _clear_default(
        db: Session, algorithm_id: str, keep_version_id: Optional[str] = None
    ) -> None:
        """Unset ``is_default`` on the algorithm's versions, except *keep_version_id*."""
        query = db.query(AlgorithmVersion).filter(
            AlgorithmVersion.algorithm_id == algorithm_id,
            AlgorithmVersion.is_default == True,  # noqa: E712 - SQL expression
        )
        if keep_version_id is not None:
            query = query.filter(AlgorithmVersion.id != keep_version_id)
        query.update({"is_default": False})

    @staticmethod
    def create_version(db: Session, version: AlgorithmVersionCreate) -> AlgorithmVersion:
        """Create a version; when it is the default, demote the existing default."""
        db_version = AlgorithmVersion(
            id=f"version-{uuid.uuid4().hex[:8]}",
            algorithm_id=version.algorithm_id,
            version=version.version,
            url=version.url,
            params=version.params,
            input_schema=version.input_schema,
            output_schema=version.output_schema,
            code=version.code,
            model_name=version.model_name,
            model_file=version.model_file,
            api_doc=version.api_doc,
            is_default=version.is_default
        )

        # The reset runs before the new row is added, so it cannot touch it.
        if version.is_default:
            AlgorithmVersionService._clear_default(db, version.algorithm_id)

        db.add(db_version)
        db.commit()
        db.refresh(db_version)
        return db_version

    @staticmethod
    def get_version_by_id(db: Session, version_id: str) -> Optional[AlgorithmVersion]:
        """Return the version with the given id, or ``None`` if absent."""
        return db.query(AlgorithmVersion).filter(AlgorithmVersion.id == version_id).first()

    @staticmethod
    def get_versions_by_algorithm_id(db: Session, algorithm_id: str) -> List[AlgorithmVersion]:
        """Return every version that belongs to the given algorithm."""
        return db.query(AlgorithmVersion).filter(
            AlgorithmVersion.algorithm_id == algorithm_id
        ).all()

    @staticmethod
    def get_default_version(db: Session, algorithm_id: str) -> Optional[AlgorithmVersion]:
        """Return the algorithm's default version, or ``None`` if there is none."""
        return db.query(AlgorithmVersion).filter(
            AlgorithmVersion.algorithm_id == algorithm_id,
            AlgorithmVersion.is_default == True,  # noqa: E712 - SQL expression
        ).first()

    @staticmethod
    def update_version(
        db: Session, version_id: str, version_update: AlgorithmVersionUpdate
    ) -> Optional[AlgorithmVersion]:
        """Apply the explicitly set fields of *version_update*.

        Promoting a version to default demotes the algorithm's other default.
        Returns the refreshed version, or ``None`` when the id is unknown.
        """
        db_version = AlgorithmVersionService.get_version_by_id(db, version_id)
        if not db_version:
            return None

        update_data = version_update.dict(exclude_unset=True)

        if update_data.get("is_default"):
            AlgorithmVersionService._clear_default(
                db, db_version.algorithm_id, keep_version_id=version_id
            )

        for field, value in update_data.items():
            setattr(db_version, field, value)

        db.commit()
        db.refresh(db_version)
        return db_version

    @staticmethod
    def delete_version(db: Session, version_id: str) -> bool:
        """Delete a version; return ``False`` when the id is unknown."""
        db_version = AlgorithmVersionService.get_version_by_id(db, version_id)
        if not db_version:
            return False

        db.delete(db_version)
        db.commit()
        return True


class AlgorithmCallService:
    """Recording and execution of algorithm calls."""

    @staticmethod
    def create_call(db: Session, user_id: str, call: AlgorithmCallCreate) -> AlgorithmCall:
        """Persist a new call record for *user_id* and return it."""
        db_call = AlgorithmCall(
            id=f"call-{uuid.uuid4().hex[:8]}",
            user_id=user_id,
            algorithm_id=call.algorithm_id,
            version_id=call.version_id,
            input_data=call.input_data,
            params=call.params
        )
        db.add(db_call)
        db.commit()
        db.refresh(db_call)
        return db_call

    @staticmethod
    def _mark_failed(db: Session, db_call: AlgorithmCall, message: str) -> AlgorithmCall:
        """Flag *db_call* as failed with *message*, commit, and return it."""
        db_call.status = "failed"
        db_call.error_message = message
        db.commit()
        return db_call

    @staticmethod
    def _upload_data_url(storage, data_url: str, object_name: str, content_type: str) -> bool:
        """Decode the base64 payload of a ``data:`` URL and upload it as *object_name*.

        Raises ``ValueError`` when the URL has no comma separating header and payload.
        """
        _header, encoded = data_url.split(',', 1)
        payload = io.BytesIO(base64.b64decode(encoded))
        return bool(storage.upload_fileobj(payload, object_name, content_type))

    @staticmethod
    def _remove_file(path: str) -> None:
        """Delete *path*, logging (not raising) when the file cannot be removed."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning("Could not remove temporary file %s: %s", path, e)

    @staticmethod
    def execute_algorithm(db: Session, user_id: str, call: AlgorithmCallCreate) -> AlgorithmCall:
        """Run an algorithm version over HTTP and record the outcome.

        Steps: create the call record, upload inline ``data:`` video/PLY inputs
        to object storage, download ``media/`` videos to a local temp file,
        POST to the version URL, then store the response (uploading a locally
        written result video when one is returned).

        Args:
            db: Active database session.
            user_id: Owner of the call; also used in storage object names.
            call: Call payload (version id, input data, params).

        Returns:
            The call record with status ``"success"`` or ``"failed"``. Errors
            are captured in ``error_message`` instead of being raised.
        """
        db_call = AlgorithmCallService.create_call(db, user_id, call)
        db_call.status = "running"
        db.commit()
        db.refresh(db_call)

        start_time = None        # set once the request phase starts
        local_video_path = None  # temp copy of the input video, removed in ``finally``
        try:
            version = AlgorithmVersionService.get_version_by_id(db, call.version_id)
            if not version:
                return AlgorithmCallService._mark_failed(db, db_call, "算法版本不存在")

            # Kept function-local, as in the original code.
            from app.utils.file import file_storage

            processed_input_data = call.input_data.copy()

            # Inline video: move the base64 payload into object storage.
            video_data = processed_input_data.get('video')
            if video_data and video_data.startswith('data:'):
                header = video_data.split(',', 1)[0]
                match = re.search(r'data:video/(\w+);', header)
                ext = match.group(1) if match else 'mp4'
                video_filename = f"videos/{user_id}/{uuid.uuid4().hex[:12]}.{ext}"
                uploaded = AlgorithmCallService._upload_data_url(
                    file_storage, video_data, video_filename, f'video/{ext}'
                )
                if not uploaded:
                    return AlgorithmCallService._mark_failed(db, db_call, "视频文件上传失败")
                processed_input_data['video'] = video_filename

            # Inline PLY file: same treatment.
            ply_data = processed_input_data.get('ply')
            if ply_data and ply_data.startswith('data:'):
                ply_filename = f"ply/{user_id}/{uuid.uuid4().hex[:12]}.ply"
                uploaded = AlgorithmCallService._upload_data_url(
                    file_storage, ply_data, ply_filename, 'application/octet-stream'
                )
                if not uploaded:
                    return AlgorithmCallService._mark_failed(db, db_call, "PLY文件上传失败")
                processed_input_data['ply'] = ply_filename

            start_time = time.time()

            # Stored video: download it so the algorithm receives a local path.
            video_path = processed_input_data.get('video', '')
            if video_path and video_path.startswith('media/'):
                video_content = file_storage.get_object(video_path)
                if not video_content:
                    return AlgorithmCallService._mark_failed(db, db_call, "无法下载视频文件")
                suffix = os.path.splitext(video_path)[1] or '.mp4'
                with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
                    tmp_file.write(video_content)
                    local_video_path = tmp_file.name
                processed_input_data['video'] = local_video_path

            logger.debug("[DEBUG] 调用算法API: %s, 输入: %s", version.url, processed_input_data)
            response = requests.post(
                version.url,
                json={
                    "input_data": processed_input_data,
                    "params": call.params
                },
                timeout=120
            )
            response_time = time.time() - start_time

            if response.status_code == 200:
                output_data = response.json()
                logger.debug("[DEBUG] 算法响应: %s", output_data)

                result = output_data.get('result', {})
                if result and isinstance(result, dict):
                    # An absolute path means the algorithm wrote the video to
                    # local disk; publish it through object storage instead.
                    local_video = result.get('video')
                    if local_video and local_video.startswith('/') and os.path.exists(local_video):
                        with open(local_video, 'rb') as f:
                            video_content = f.read()
                        # Uses the module-level ``uuid``: a function-local import
                        # here would leave the name unbound on some code paths.
                        video_filename = f"results/{user_id}/{uuid.uuid4().hex[:12]}.mp4"
                        if file_storage.upload_from_bytes(video_content, video_filename):
                            result['video'] = video_filename
                            result['video_url'] = f"/api/v1/data/media/{video_filename}"
                            # Remove the local copy only after a successful
                            # upload, so the returned path never dangles.
                            AlgorithmCallService._remove_file(local_video)

                db_call.status = "success"
                logger.debug("[DEBUG] 保存output_data: %s", output_data)
                db_call.output_data = output_data
                db_call.response_time = response_time
            else:
                db_call.status = "failed"
                db_call.error_message = f"算法执行失败: {response.text}"
                db_call.response_time = response_time

        except Exception as e:
            logger.exception("算法执行异常: %s", e)
            db_call.status = "failed"
            db_call.error_message = f"算法执行异常: {str(e)}"
            db_call.response_time = (
                time.time() - start_time if start_time is not None else None
            )
        finally:
            # The temp file is created with delete=False; remove it here.
            if local_video_path:
                AlgorithmCallService._remove_file(local_video_path)

        db.commit()
        db.refresh(db_call)
        return db_call

    @staticmethod
    def get_call_by_id(db: Session, call_id: str) -> Optional[AlgorithmCall]:
        """Return the call record with the given id, or ``None`` if absent."""
        return db.query(AlgorithmCall).filter(AlgorithmCall.id == call_id).first()

    @staticmethod
    def get_calls_by_user_id(
        db: Session, user_id: str, skip: int = 0, limit: int = 100
    ) -> List[AlgorithmCall]:
        """Return one page of call records made by *user_id*."""
        return db.query(AlgorithmCall).filter(
            AlgorithmCall.user_id == user_id
        ).offset(skip).limit(limit).all()

    @staticmethod
    def get_calls_by_algorithm_id(
        db: Session, algorithm_id: str, skip: int = 0, limit: int = 100
    ) -> List[AlgorithmCall]:
        """Return one page of call records for *algorithm_id*."""
        return db.query(AlgorithmCall).filter(
            AlgorithmCall.algorithm_id == algorithm_id
        ).offset(skip).limit(limit).all()

    @staticmethod
    def execute_python_code(code: str) -> dict:
        """Run *code* in a separate Python process with a 5-second timeout.

        SECURITY: this executes arbitrary code with the privileges of the
        current process. A child process plus a timeout is not a sandbox; do
        not expose this to untrusted input without real isolation.

        Args:
            code: Python source text to execute.

        Returns:
            A dict with ``success`` (bool), ``output`` (combined stdout and
            stderr on success) and ``error`` (message or captured output on
            failure).
        """
        result = {
            "success": False,
            "output": "",
            "error": ""
        }

        try:
            # Python reads source files as UTF-8 by default, so write the script
            # as UTF-8 rather than in the locale encoding.
            with tempfile.NamedTemporaryFile(
                mode='w', suffix='.py', delete=False, encoding='utf-8'
            ) as f:
                f.write(code)
                temp_file_name = f.name

            try:
                output = subprocess.check_output(
                    [sys.executable, temp_file_name],
                    stderr=subprocess.STDOUT,
                    timeout=5,
                    universal_newlines=True
                )
                result["success"] = True
                result["output"] = output
            except subprocess.TimeoutExpired:
                result["error"] = "代码执行超时(超过5秒)"
            except subprocess.CalledProcessError as e:
                result["error"] = e.output
            finally:
                if os.path.exists(temp_file_name):
                    os.unlink(temp_file_name)

        except Exception as e:
            result["error"] = f"执行环境错误: {str(e)}"

        return result