first commit
This commit is contained in:
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
BIN
backend/app/services/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
backend/app/services/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/algorithm.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/algorithm.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/algorithm.cpython-39.pyc
Normal file
BIN
backend/app/services/__pycache__/algorithm.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/data_manager.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/data_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/data_manager.cpython-39.pyc
Normal file
BIN
backend/app/services/__pycache__/data_manager.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/deployment.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/deployment.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/deployment.cpython-39.pyc
Normal file
BIN
backend/app/services/__pycache__/deployment.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/history_manager.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/history_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/history_manager.cpython-39.pyc
Normal file
BIN
backend/app/services/__pycache__/history_manager.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/monitoring.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/monitoring.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/monitoring.cpython-39.pyc
Normal file
BIN
backend/app/services/__pycache__/monitoring.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/permission.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/permission.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/permission.cpython-39.pyc
Normal file
BIN
backend/app/services/__pycache__/permission.cpython-39.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
backend/app/services/__pycache__/service_manager.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/service_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/service_manager.cpython-39.pyc
Normal file
BIN
backend/app/services/__pycache__/service_manager.cpython-39.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
backend/app/services/__pycache__/user.cpython-312.pyc
Normal file
BIN
backend/app/services/__pycache__/user.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/services/__pycache__/user.cpython-39.pyc
Normal file
BIN
backend/app/services/__pycache__/user.cpython-39.pyc
Normal file
Binary file not shown.
465
backend/app/services/algorithm.py
Normal file
465
backend/app/services/algorithm.py
Normal file
@@ -0,0 +1,465 @@
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
import requests
|
||||
import time
|
||||
|
||||
from app.models.models import Algorithm, AlgorithmVersion, AlgorithmCall
|
||||
from app.schemas.algorithm import AlgorithmCreate, AlgorithmUpdate, AlgorithmVersionCreate, AlgorithmVersionUpdate, AlgorithmCallCreate
|
||||
from app.services.deployment import deployment_service
|
||||
|
||||
|
||||
class AlgorithmService:
|
||||
"""算法服务类"""
|
||||
|
||||
@staticmethod
|
||||
def create_algorithm(db: Session, algorithm: AlgorithmCreate) -> Algorithm:
|
||||
"""创建算法"""
|
||||
# 生成唯一ID
|
||||
algorithm_id = f"algorithm-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 创建算法实例
|
||||
db_algorithm = Algorithm(
|
||||
id=algorithm_id,
|
||||
name=algorithm.name,
|
||||
description=algorithm.description,
|
||||
type=algorithm.type
|
||||
)
|
||||
|
||||
# 保存到数据库
|
||||
db.add(db_algorithm)
|
||||
db.commit()
|
||||
db.refresh(db_algorithm)
|
||||
|
||||
# 自动部署(如果有代码)
|
||||
deployed_url = algorithm.url
|
||||
deployment_logs = []
|
||||
if algorithm.code and not deployed_url:
|
||||
try:
|
||||
# 构建镜像
|
||||
build_result = deployment_service.build_algorithm_image(
|
||||
algorithm.name,
|
||||
algorithm.code
|
||||
)
|
||||
|
||||
if build_result['success']:
|
||||
image_name = build_result['image_name']
|
||||
deployment_logs.extend(build_result['logs'])
|
||||
|
||||
# 部署容器
|
||||
deployment_info = deployment_service.deploy_algorithm(
|
||||
algorithm_id,
|
||||
image_name
|
||||
)
|
||||
|
||||
deployed_url = deployment_info['api_url']
|
||||
deployment_logs.append(f"部署成功: {deployed_url}")
|
||||
else:
|
||||
deployment_logs.extend(build_result['logs'])
|
||||
print(f"镜像构建失败: {build_result['logs'][-1]}")
|
||||
except Exception as e:
|
||||
error_message = f"自动部署失败: {str(e)}"
|
||||
deployment_logs.append(error_message)
|
||||
print(error_message)
|
||||
|
||||
# 创建默认版本
|
||||
version_id = f"version-{uuid.uuid4().hex[:8]}"
|
||||
db_version = AlgorithmVersion(
|
||||
id=version_id,
|
||||
algorithm_id=algorithm_id,
|
||||
version=algorithm.version,
|
||||
url=deployed_url,
|
||||
params=algorithm.params,
|
||||
input_schema=algorithm.input_schema,
|
||||
output_schema=algorithm.output_schema,
|
||||
code=algorithm.code,
|
||||
model_name=algorithm.model_name,
|
||||
model_file=algorithm.model_file,
|
||||
api_doc=algorithm.api_doc,
|
||||
is_default=True
|
||||
)
|
||||
|
||||
# 保存版本到数据库
|
||||
db.add(db_version)
|
||||
db.commit()
|
||||
db.refresh(db_version)
|
||||
|
||||
# 加载版本关系
|
||||
db.refresh(db_algorithm, ['versions'])
|
||||
|
||||
return db_algorithm
|
||||
|
||||
@staticmethod
|
||||
def get_algorithm_by_id(db: Session, algorithm_id: str) -> Optional[Algorithm]:
|
||||
"""通过ID获取算法"""
|
||||
return db.query(Algorithm).filter(Algorithm.id == algorithm_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_algorithms(db: Session, skip: int = 0, limit: int = 100, algorithm_type: Optional[str] = None) -> List[Algorithm]:
|
||||
"""获取算法列表"""
|
||||
query = db.query(Algorithm)
|
||||
|
||||
# 如果指定了算法类型,进行过滤
|
||||
if algorithm_type:
|
||||
query = query.filter(Algorithm.type == algorithm_type)
|
||||
|
||||
return query.offset(skip).limit(limit).all()
|
||||
|
||||
@staticmethod
|
||||
def update_algorithm(db: Session, algorithm_id: str, algorithm_update: AlgorithmUpdate) -> Optional[Algorithm]:
|
||||
"""更新算法"""
|
||||
# 获取算法
|
||||
db_algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
|
||||
if not db_algorithm:
|
||||
return None
|
||||
|
||||
# 更新算法信息
|
||||
update_data = algorithm_update.dict(exclude_unset=True)
|
||||
|
||||
# 应用更新
|
||||
for field, value in update_data.items():
|
||||
setattr(db_algorithm, field, value)
|
||||
|
||||
# 保存到数据库
|
||||
db.commit()
|
||||
db.refresh(db_algorithm)
|
||||
|
||||
return db_algorithm
|
||||
|
||||
@staticmethod
|
||||
def delete_algorithm(db: Session, algorithm_id: str) -> bool:
|
||||
"""删除算法"""
|
||||
# 获取算法
|
||||
db_algorithm = AlgorithmService.get_algorithm_by_id(db, algorithm_id)
|
||||
if not db_algorithm:
|
||||
return False
|
||||
|
||||
# 从数据库中删除
|
||||
db.delete(db_algorithm)
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class AlgorithmVersionService:
|
||||
"""算法版本服务类"""
|
||||
|
||||
@staticmethod
|
||||
def create_version(db: Session, version: AlgorithmVersionCreate) -> AlgorithmVersion:
|
||||
"""创建算法版本"""
|
||||
# 生成唯一ID
|
||||
version_id = f"version-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 创建版本实例
|
||||
db_version = AlgorithmVersion(
|
||||
id=version_id,
|
||||
algorithm_id=version.algorithm_id,
|
||||
version=version.version,
|
||||
url=version.url,
|
||||
params=version.params,
|
||||
input_schema=version.input_schema,
|
||||
output_schema=version.output_schema,
|
||||
code=version.code,
|
||||
model_name=version.model_name,
|
||||
model_file=version.model_file,
|
||||
api_doc=version.api_doc,
|
||||
is_default=version.is_default
|
||||
)
|
||||
|
||||
# 如果设置为默认版本,需要将其他版本设置为非默认
|
||||
if version.is_default:
|
||||
db.query(AlgorithmVersion).filter(
|
||||
AlgorithmVersion.algorithm_id == version.algorithm_id,
|
||||
AlgorithmVersion.is_default == True
|
||||
).update({"is_default": False})
|
||||
|
||||
# 保存到数据库
|
||||
db.add(db_version)
|
||||
db.commit()
|
||||
db.refresh(db_version)
|
||||
|
||||
return db_version
|
||||
|
||||
@staticmethod
|
||||
def get_version_by_id(db: Session, version_id: str) -> Optional[AlgorithmVersion]:
|
||||
"""通过ID获取版本"""
|
||||
return db.query(AlgorithmVersion).filter(AlgorithmVersion.id == version_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_versions_by_algorithm_id(db: Session, algorithm_id: str) -> List[AlgorithmVersion]:
|
||||
"""通过算法ID获取版本列表"""
|
||||
return db.query(AlgorithmVersion).filter(AlgorithmVersion.algorithm_id == algorithm_id).all()
|
||||
|
||||
@staticmethod
|
||||
def get_default_version(db: Session, algorithm_id: str) -> Optional[AlgorithmVersion]:
|
||||
"""获取算法的默认版本"""
|
||||
return db.query(AlgorithmVersion).filter(
|
||||
AlgorithmVersion.algorithm_id == algorithm_id,
|
||||
AlgorithmVersion.is_default == True
|
||||
).first()
|
||||
|
||||
@staticmethod
|
||||
def update_version(db: Session, version_id: str, version_update: AlgorithmVersionUpdate) -> Optional[AlgorithmVersion]:
|
||||
"""更新算法版本"""
|
||||
# 获取版本
|
||||
db_version = AlgorithmVersionService.get_version_by_id(db, version_id)
|
||||
if not db_version:
|
||||
return None
|
||||
|
||||
# 更新版本信息
|
||||
update_data = version_update.dict(exclude_unset=True)
|
||||
|
||||
# 如果设置为默认版本,需要将其他版本设置为非默认
|
||||
if "is_default" in update_data and update_data["is_default"]:
|
||||
db.query(AlgorithmVersion).filter(
|
||||
AlgorithmVersion.algorithm_id == db_version.algorithm_id,
|
||||
AlgorithmVersion.is_default == True,
|
||||
AlgorithmVersion.id != version_id
|
||||
).update({"is_default": False})
|
||||
|
||||
# 应用更新
|
||||
for field, value in update_data.items():
|
||||
setattr(db_version, field, value)
|
||||
|
||||
# 保存到数据库
|
||||
db.commit()
|
||||
db.refresh(db_version)
|
||||
|
||||
return db_version
|
||||
|
||||
@staticmethod
|
||||
def delete_version(db: Session, version_id: str) -> bool:
|
||||
"""删除算法版本"""
|
||||
# 获取版本
|
||||
db_version = AlgorithmVersionService.get_version_by_id(db, version_id)
|
||||
if not db_version:
|
||||
return False
|
||||
|
||||
# 从数据库中删除
|
||||
db.delete(db_version)
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class AlgorithmCallService:
|
||||
"""算法调用服务类"""
|
||||
|
||||
@staticmethod
|
||||
def create_call(db: Session, user_id: str, call: AlgorithmCallCreate) -> AlgorithmCall:
|
||||
"""创建算法调用记录"""
|
||||
# 生成唯一ID
|
||||
call_id = f"call-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 创建调用实例
|
||||
db_call = AlgorithmCall(
|
||||
id=call_id,
|
||||
user_id=user_id,
|
||||
algorithm_id=call.algorithm_id,
|
||||
version_id=call.version_id,
|
||||
input_data=call.input_data,
|
||||
params=call.params
|
||||
)
|
||||
|
||||
# 保存到数据库
|
||||
db.add(db_call)
|
||||
db.commit()
|
||||
db.refresh(db_call)
|
||||
|
||||
return db_call
|
||||
|
||||
@staticmethod
|
||||
def execute_algorithm(db: Session, user_id: str, call: AlgorithmCallCreate) -> AlgorithmCall:
|
||||
"""执行算法"""
|
||||
# 创建调用记录
|
||||
db_call = AlgorithmCallService.create_call(db, user_id, call)
|
||||
|
||||
# 更新状态为运行中
|
||||
db_call.status = "running"
|
||||
db.commit()
|
||||
db.refresh(db_call)
|
||||
|
||||
try:
|
||||
# 获取算法版本信息
|
||||
version = AlgorithmVersionService.get_version_by_id(db, call.version_id)
|
||||
if not version:
|
||||
db_call.status = "failed"
|
||||
db_call.error_message = "算法版本不存在"
|
||||
db.commit()
|
||||
return db_call
|
||||
|
||||
# 处理视频输入数据
|
||||
processed_input_data = call.input_data.copy()
|
||||
if 'video' in processed_input_data and processed_input_data['video']:
|
||||
from app.utils.file import file_storage
|
||||
import io
|
||||
import base64
|
||||
import uuid
|
||||
|
||||
# 从base64字符串解码视频数据
|
||||
video_data = processed_input_data['video']
|
||||
if video_data.startswith('data:'):
|
||||
# 移除data URL前缀
|
||||
header, encoded = video_data.split(',', 1)
|
||||
video_bytes = base64.b64decode(encoded)
|
||||
|
||||
# 提取文件扩展名
|
||||
import re
|
||||
match = re.search(r'data:video/(\w+);', header)
|
||||
ext = match.group(1) if match else 'mp4'
|
||||
|
||||
# 生成唯一文件名
|
||||
video_filename = f"videos/{user_id}/{uuid.uuid4().hex[:12]}.{ext}"
|
||||
|
||||
# 上传到MinIO
|
||||
file_obj = io.BytesIO(video_bytes)
|
||||
success = file_storage.upload_fileobj(file_obj, video_filename, f'video/{ext}')
|
||||
|
||||
if success:
|
||||
# 替换为MinIO文件路径
|
||||
processed_input_data['video'] = video_filename
|
||||
else:
|
||||
db_call.status = "failed"
|
||||
db_call.error_message = "视频文件上传失败"
|
||||
db.commit()
|
||||
return db_call
|
||||
|
||||
# 处理PLY文件输入数据
|
||||
if 'ply' in processed_input_data and processed_input_data['ply']:
|
||||
from app.utils.file import file_storage
|
||||
import io
|
||||
import base64
|
||||
import uuid
|
||||
|
||||
# 从base64字符串解码PLY数据
|
||||
ply_data = processed_input_data['ply']
|
||||
if ply_data.startswith('data:'):
|
||||
# 移除data URL前缀
|
||||
header, encoded = ply_data.split(',', 1)
|
||||
ply_bytes = base64.b64decode(encoded)
|
||||
|
||||
# 生成唯一文件名
|
||||
ply_filename = f"ply/{user_id}/{uuid.uuid4().hex[:12]}.ply"
|
||||
|
||||
# 上传到MinIO
|
||||
file_obj = io.BytesIO(ply_bytes)
|
||||
success = file_storage.upload_fileobj(file_obj, ply_filename, 'application/octet-stream')
|
||||
|
||||
if success:
|
||||
# 替换为MinIO文件路径
|
||||
processed_input_data['ply'] = ply_filename
|
||||
else:
|
||||
db_call.status = "failed"
|
||||
db_call.error_message = "PLY文件上传失败"
|
||||
db.commit()
|
||||
return db_call
|
||||
|
||||
# 记录开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 调用算法API
|
||||
response = requests.post(
|
||||
version.url,
|
||||
json={
|
||||
"input_data": processed_input_data,
|
||||
"params": call.params
|
||||
},
|
||||
timeout=30
|
||||
)
|
||||
|
||||
# 计算响应时间
|
||||
response_time = time.time() - start_time
|
||||
|
||||
# 处理响应
|
||||
if response.status_code == 200:
|
||||
output_data = response.json()
|
||||
db_call.status = "success"
|
||||
db_call.output_data = output_data
|
||||
db_call.response_time = response_time
|
||||
else:
|
||||
db_call.status = "failed"
|
||||
db_call.error_message = f"算法执行失败: {response.text}"
|
||||
db_call.response_time = response_time
|
||||
|
||||
except Exception as e:
|
||||
# 记录错误
|
||||
db_call.status = "failed"
|
||||
db_call.error_message = f"算法执行异常: {str(e)}"
|
||||
db_call.response_time = time.time() - start_time if 'start_time' in locals() else None
|
||||
|
||||
# 保存到数据库
|
||||
db.commit()
|
||||
db.refresh(db_call)
|
||||
|
||||
return db_call
|
||||
|
||||
@staticmethod
|
||||
def get_call_by_id(db: Session, call_id: str) -> Optional[AlgorithmCall]:
|
||||
"""通过ID获取调用记录"""
|
||||
return db.query(AlgorithmCall).filter(AlgorithmCall.id == call_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_calls_by_user_id(db: Session, user_id: str, skip: int = 0, limit: int = 100) -> List[AlgorithmCall]:
|
||||
"""通过用户ID获取调用记录列表"""
|
||||
return db.query(AlgorithmCall).filter(
|
||||
AlgorithmCall.user_id == user_id
|
||||
).offset(skip).limit(limit).all()
|
||||
|
||||
@staticmethod
|
||||
def get_calls_by_algorithm_id(db: Session, algorithm_id: str, skip: int = 0, limit: int = 100) -> List[AlgorithmCall]:
|
||||
"""通过算法ID获取调用记录列表"""
|
||||
return db.query(AlgorithmCall).filter(
|
||||
AlgorithmCall.algorithm_id == algorithm_id
|
||||
).offset(skip).limit(limit).all()
|
||||
|
||||
@staticmethod
|
||||
def execute_python_code(code: str) -> dict:
|
||||
"""执行Python代码"""
|
||||
import subprocess
|
||||
import sys
|
||||
import io
|
||||
import contextlib
|
||||
|
||||
# 准备执行环境
|
||||
result = {
|
||||
"success": False,
|
||||
"output": "",
|
||||
"error": ""
|
||||
}
|
||||
|
||||
try:
|
||||
# 创建一个安全的执行环境
|
||||
# 使用subprocess创建一个独立的进程,限制执行时间
|
||||
# 注意:这只是一个基本的安全措施,生产环境中需要更严格的沙箱
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
# 创建临时文件
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
|
||||
f.write(code)
|
||||
temp_file_name = f.name
|
||||
|
||||
try:
|
||||
# 执行代码,限制时间为5秒
|
||||
output = subprocess.check_output(
|
||||
[sys.executable, temp_file_name],
|
||||
stderr=subprocess.STDOUT,
|
||||
timeout=5,
|
||||
universal_newlines=True
|
||||
)
|
||||
result["success"] = True
|
||||
result["output"] = output
|
||||
except subprocess.TimeoutExpired:
|
||||
result["error"] = "代码执行超时(超过5秒)"
|
||||
except subprocess.CalledProcessError as e:
|
||||
result["error"] = e.output
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if os.path.exists(temp_file_name):
|
||||
os.unlink(temp_file_name)
|
||||
|
||||
except Exception as e:
|
||||
result["error"] = f"执行环境错误: {str(e)}"
|
||||
|
||||
return result
|
||||
310
backend/app/services/data_manager.py
Normal file
310
backend/app/services/data_manager.py
Normal file
@@ -0,0 +1,310 @@
|
||||
"""数据管理服务,负责管理输入数据、输出结果和元数据"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import uuid
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.models import AlgorithmCall
|
||||
from app.utils.file import file_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataManager:
|
||||
"""数据管理器,处理输入数据、输出结果和元数据的存储与检索"""
|
||||
|
||||
def __init__(self):
|
||||
self.storage_path = "data_storage" # 数据存储路径
|
||||
|
||||
def save_input_data(self, user_id: str, algorithm_id: str, input_data: Dict[str, Any]) -> str:
|
||||
"""保存输入数据"""
|
||||
try:
|
||||
# 生成唯一ID
|
||||
data_id = f"input_{uuid.uuid4().hex[:12]}"
|
||||
|
||||
# 准备存储数据
|
||||
storage_data = {
|
||||
"data_id": data_id,
|
||||
"user_id": user_id,
|
||||
"algorithm_id": algorithm_id,
|
||||
"input_data": input_data,
|
||||
"created_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# 存储到数据库或文件系统
|
||||
# 这里我们使用MinIO存储大文件,数据库存储引用
|
||||
file_path = f"inputs/{user_id}/{algorithm_id}/{data_id}.json"
|
||||
|
||||
# 转换为JSON字符串
|
||||
json_str = json.dumps(storage_data, ensure_ascii=False, default=str)
|
||||
|
||||
# 上传到MinIO
|
||||
success = file_storage.upload_from_bytes(json_str.encode('utf-8'), file_path)
|
||||
|
||||
if success:
|
||||
logger.info(f"Input data saved: {data_id}")
|
||||
return data_id
|
||||
else:
|
||||
logger.error(f"Failed to save input data: {data_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving input data: {str(e)}")
|
||||
return None
|
||||
|
||||
def save_output_data(self, user_id: str, algorithm_id: str, call_id: str, output_data: Dict[str, Any]) -> str:
|
||||
"""保存输出结果数据"""
|
||||
try:
|
||||
# 生成唯一ID
|
||||
data_id = f"output_{uuid.uuid4().hex[:12]}"
|
||||
|
||||
# 准备存储数据
|
||||
storage_data = {
|
||||
"data_id": data_id,
|
||||
"user_id": user_id,
|
||||
"algorithm_id": algorithm_id,
|
||||
"call_id": call_id,
|
||||
"output_data": output_data,
|
||||
"created_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# 存储到MinIO
|
||||
file_path = f"outputs/{user_id}/{algorithm_id}/{call_id}_{data_id}.json"
|
||||
|
||||
# 转换为JSON字符串
|
||||
json_str = json.dumps(storage_data, ensure_ascii=False, default=str)
|
||||
|
||||
# 上传到MinIO
|
||||
success = file_storage.upload_from_bytes(json_str.encode('utf-8'), file_path)
|
||||
|
||||
if success:
|
||||
logger.info(f"Output data saved: {data_id}")
|
||||
return data_id
|
||||
else:
|
||||
logger.error(f"Failed to save output data: {data_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving output data: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_input_data(self, data_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取输入数据"""
|
||||
try:
|
||||
# 从MinIO获取数据
|
||||
file_path = f"inputs/{data_id}.json"
|
||||
|
||||
# 下载文件内容
|
||||
content = file_storage.download_file(file_path)
|
||||
if content:
|
||||
data = json.loads(content.decode('utf-8'))
|
||||
return data
|
||||
else:
|
||||
logger.warning(f"Input data not found: {data_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting input data: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_output_data(self, data_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取输出结果数据"""
|
||||
try:
|
||||
# 从MinIO获取数据
|
||||
file_path = f"outputs/{data_id}.json"
|
||||
|
||||
# 下载文件内容
|
||||
content = file_storage.download_file(file_path)
|
||||
if content:
|
||||
data = json.loads(content.decode('utf-8'))
|
||||
return data
|
||||
else:
|
||||
logger.warning(f"Output data not found: {data_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting output data: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_user_inputs(self, user_id: str, algorithm_id: str = None, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""获取用户的历史输入数据"""
|
||||
try:
|
||||
# 构建搜索路径
|
||||
if algorithm_id:
|
||||
search_prefix = f"inputs/{user_id}/{algorithm_id}/"
|
||||
else:
|
||||
search_prefix = f"inputs/{user_id}/"
|
||||
|
||||
# 列出MinIO中的文件
|
||||
files = file_storage.list_files(search_prefix)
|
||||
|
||||
inputs = []
|
||||
for file_info in files[:limit]: # 限制返回数量
|
||||
try:
|
||||
content = file_storage.download_file(file_info.object_name)
|
||||
if content:
|
||||
data = json.loads(content.decode('utf-8'))
|
||||
inputs.append(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing input file {file_info.object_name}: {str(e)}")
|
||||
continue
|
||||
|
||||
return inputs
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user inputs: {str(e)}")
|
||||
return []
|
||||
|
||||
def get_user_outputs(self, user_id: str, algorithm_id: str = None, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""获取用户的历史输出数据"""
|
||||
try:
|
||||
# 构建搜索路径
|
||||
if algorithm_id:
|
||||
search_prefix = f"outputs/{user_id}/{algorithm_id}/"
|
||||
else:
|
||||
search_prefix = f"outputs/{user_id}/"
|
||||
|
||||
# 列出MinIO中的文件
|
||||
files = file_storage.list_files(search_prefix)
|
||||
|
||||
outputs = []
|
||||
for file_info in files[:limit]: # 限制返回数量
|
||||
try:
|
||||
content = file_storage.download_file(file_info.object_name)
|
||||
if content:
|
||||
data = json.loads(content.decode('utf-8'))
|
||||
outputs.append(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing output file {file_info.object_name}: {str(e)}")
|
||||
continue
|
||||
|
||||
return outputs
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user outputs: {str(e)}")
|
||||
return []
|
||||
|
||||
def delete_user_data(self, user_id: str) -> bool:
|
||||
"""删除用户的所有数据"""
|
||||
try:
|
||||
# 删除用户相关的所有输入数据
|
||||
input_prefix = f"inputs/{user_id}/"
|
||||
input_files = file_storage.list_files(input_prefix)
|
||||
for file_info in input_files:
|
||||
file_storage.remove_file(file_info.object_name)
|
||||
|
||||
# 删除用户相关的所有输出数据
|
||||
output_prefix = f"outputs/{user_id}/"
|
||||
output_files = file_storage.list_files(output_prefix)
|
||||
for file_info in output_files:
|
||||
file_storage.remove_file(file_info.object_name)
|
||||
|
||||
logger.info(f"User data deleted: {user_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting user data: {str(e)}")
|
||||
return False
|
||||
|
||||
def save_media_file(self, user_id: str, algorithm_id: str, file_content: bytes, file_name: str) -> Optional[str]:
|
||||
"""保存媒体文件(如图片、视频等)"""
|
||||
try:
|
||||
# 生成唯一文件名
|
||||
unique_name = f"{uuid.uuid4().hex[:12]}_{file_name}"
|
||||
file_path = f"media/{user_id}/{algorithm_id}/{unique_name}"
|
||||
|
||||
# 上传到MinIO
|
||||
success = file_storage.upload_from_bytes(file_content, file_path)
|
||||
|
||||
if success:
|
||||
logger.info(f"Media file saved: {file_path}")
|
||||
return file_path
|
||||
else:
|
||||
logger.error(f"Failed to save media file: {file_path}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving media file: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_media_file(self, file_path: str) -> Optional[bytes]:
|
||||
"""获取媒体文件内容"""
|
||||
try:
|
||||
content = file_storage.download_file(file_path)
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting media file: {str(e)}")
|
||||
return None
|
||||
|
||||
def create_data_snapshot(self, call_record: AlgorithmCall) -> Dict[str, Any]:
|
||||
"""创建数据快照,保存调用时的完整数据状态"""
|
||||
try:
|
||||
snapshot = {
|
||||
"snapshot_id": f"snapshot_{uuid.uuid4().hex[:12]}",
|
||||
"call_id": call_record.id,
|
||||
"user_id": call_record.user_id,
|
||||
"algorithm_id": call_record.algorithm_id,
|
||||
"version_id": call_record.version_id,
|
||||
"input_data": call_record.input_data,
|
||||
"output_data": call_record.output_data,
|
||||
"params": call_record.params,
|
||||
"status": call_record.status,
|
||||
"response_time": call_record.response_time,
|
||||
"created_at": call_record.created_at.isoformat() if call_record.created_at else None,
|
||||
"updated_at": call_record.updated_at.isoformat() if call_record.updated_at else None
|
||||
}
|
||||
|
||||
# 保存快照到MinIO
|
||||
file_path = f"snapshots/{call_record.user_id}/{call_record.algorithm_id}/{call_record.id}.json"
|
||||
json_str = json.dumps(snapshot, ensure_ascii=False, default=str)
|
||||
|
||||
success = file_storage.upload_from_bytes(json_str.encode('utf-8'), file_path)
|
||||
|
||||
if success:
|
||||
logger.info(f"Data snapshot created: {snapshot['snapshot_id']}")
|
||||
return snapshot
|
||||
else:
|
||||
logger.error(f"Failed to create data snapshot: {snapshot['snapshot_id']}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating data snapshot: {str(e)}")
|
||||
return None
|
||||
|
||||
def search_data_by_metadata(self, filters: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""根据元数据搜索数据"""
|
||||
try:
|
||||
# 从数据库中获取匹配的调用记录
|
||||
# 注意:这里仅作示例,实际实现可能需要更复杂的索引和搜索机制
|
||||
results = []
|
||||
|
||||
# 如果需要按用户搜索
|
||||
if filters.get('user_id'):
|
||||
# 这里应该查询数据库中的相关记录
|
||||
pass
|
||||
|
||||
# 如果需要按算法搜索
|
||||
if filters.get('algorithm_id'):
|
||||
# 这里应该查询数据库中的相关记录
|
||||
pass
|
||||
|
||||
# 如果需要按日期范围搜索
|
||||
if filters.get('date_range'):
|
||||
# 这里应该查询数据库中的相关记录
|
||||
pass
|
||||
|
||||
# 返回匹配的结果
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching data by metadata: {str(e)}")
|
||||
return []
|
||||
|
||||
|
||||
# 全局数据管理器实例
|
||||
data_manager = DataManager()
|
||||
582
backend/app/services/deployment.py
Normal file
582
backend/app/services/deployment.py
Normal file
@@ -0,0 +1,582 @@
|
||||
"""部署服务模块,负责Docker容器管理、镜像构建和自动部署"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import json
|
||||
import tempfile
|
||||
import shutil
|
||||
import time
|
||||
from typing import Optional, List, Dict, Any
|
||||
import docker
|
||||
from sqlalchemy.orm import Session
|
||||
import logging
|
||||
|
||||
from app.models.models import Algorithm, AlgorithmVersion
|
||||
from app.schemas.algorithm import AlgorithmCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeploymentService:
|
||||
"""部署服务类,负责Docker容器管理和自动部署"""
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
# 连接Docker守护进程
|
||||
self.client = docker.from_env()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Docker: {str(e)}")
|
||||
self.client = None
|
||||
|
||||
def detect_dependencies(self, code: str) -> List[str]:
|
||||
"""
|
||||
自动检测Python代码中的依赖
|
||||
|
||||
Args:
|
||||
code: Python代码字符串
|
||||
|
||||
Returns:
|
||||
依赖包列表
|
||||
"""
|
||||
dependencies = []
|
||||
|
||||
# 常见的导入语句检测
|
||||
import_patterns = [
|
||||
'import ',
|
||||
'from ',
|
||||
'pip install '
|
||||
]
|
||||
|
||||
# 常见的Python包
|
||||
common_packages = {
|
||||
'numpy', 'pandas', 'scikit-learn', 'tensorflow', 'torch', 'keras',
|
||||
'opencv-python', 'pillow', 'matplotlib', 'seaborn', 'nltk', 'spacy',
|
||||
'transformers', 'fastapi', 'flask', 'requests', 'urllib3', 'beautifulsoup4',
|
||||
'sqlalchemy', 'pymongo', 'redis', 'psycopg2', 'pymysql', 'onnxruntime'
|
||||
}
|
||||
|
||||
for line in code.split('\n'):
|
||||
line = line.strip()
|
||||
for pattern in import_patterns:
|
||||
if line.startswith(pattern):
|
||||
# 提取包名
|
||||
if pattern == 'import ':
|
||||
parts = line.split()[1].split('.')
|
||||
if parts[0] in common_packages:
|
||||
dependencies.append(parts[0])
|
||||
elif pattern == 'from ':
|
||||
parts = line.split()[1].split('.')
|
||||
if parts[0] in common_packages:
|
||||
dependencies.append(parts[0])
|
||||
elif pattern == 'pip install ':
|
||||
package = line.split(' ', 2)[2]
|
||||
dependencies.append(package)
|
||||
|
||||
# 去重
|
||||
return list(set(dependencies))
|
||||
|
||||
def build_algorithm_image(self, algorithm_name: str, code: str, dependencies: Optional[List[str]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
构建算法Docker镜像
|
||||
|
||||
Args:
|
||||
algorithm_name: 算法名称
|
||||
code: Python代码
|
||||
dependencies: 依赖包列表(如果为None则自动检测)
|
||||
|
||||
Returns:
|
||||
包含镜像名称和构建日志的字典
|
||||
"""
|
||||
try:
|
||||
build_logs = []
|
||||
|
||||
# 自动检测依赖
|
||||
build_logs.append(f"开始构建镜像: {algorithm_name}")
|
||||
if dependencies is None:
|
||||
build_logs.append("自动检测依赖...")
|
||||
dependencies = self.detect_dependencies(code)
|
||||
build_logs.append(f"检测到依赖: {', '.join(dependencies)}")
|
||||
else:
|
||||
build_logs.append(f"使用指定依赖: {', '.join(dependencies)}")
|
||||
|
||||
# 创建临时目录
|
||||
build_logs.append("创建构建环境...")
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# 生成Dockerfile
|
||||
dockerfile_content = self._generate_dockerfile(dependencies)
|
||||
dockerfile_path = os.path.join(temp_dir, 'Dockerfile')
|
||||
|
||||
with open(dockerfile_path, 'w') as f:
|
||||
f.write(dockerfile_content)
|
||||
build_logs.append("生成Dockerfile完成")
|
||||
|
||||
# 生成算法代码文件
|
||||
algorithm_path = os.path.join(temp_dir, 'algorithm.py')
|
||||
with open(algorithm_path, 'w') as f:
|
||||
f.write(code)
|
||||
build_logs.append("生成算法代码文件完成")
|
||||
|
||||
# 生成API服务文件
|
||||
api_path = os.path.join(temp_dir, 'app.py')
|
||||
with open(api_path, 'w') as f:
|
||||
f.write(self._generate_api_service())
|
||||
build_logs.append("生成API服务文件完成")
|
||||
|
||||
# 构建镜像
|
||||
image_name = f"algorithm-{algorithm_name.lower().replace(' ', '-')}:latest"
|
||||
build_logs.append(f"开始构建镜像: {image_name}")
|
||||
|
||||
logger.info(f"Building Docker image: {image_name}")
|
||||
logger.info(f"Dependencies: {dependencies}")
|
||||
|
||||
# 使用Docker SDK构建镜像
|
||||
if self.client:
|
||||
image, logs = self.client.images.build(
|
||||
path=temp_dir,
|
||||
tag=image_name,
|
||||
rm=True
|
||||
)
|
||||
|
||||
# 打印构建日志
|
||||
for log in logs:
|
||||
if 'stream' in log:
|
||||
log_message = log['stream'].strip()
|
||||
if log_message:
|
||||
build_logs.append(log_message)
|
||||
logger.info(log_message)
|
||||
else:
|
||||
# 备用方案:使用subprocess
|
||||
build_logs.append("使用subprocess构建镜像...")
|
||||
result = subprocess.run(
|
||||
['docker', 'build', '-t', image_name, temp_dir],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
error_message = f"Docker build failed: {result.stderr}"
|
||||
build_logs.append(error_message)
|
||||
logger.error(error_message)
|
||||
raise Exception(error_message)
|
||||
|
||||
for line in result.stdout.strip().split('\n'):
|
||||
if line:
|
||||
build_logs.append(line)
|
||||
logger.info(line)
|
||||
|
||||
success_message = f"Successfully built image: {image_name}"
|
||||
build_logs.append(success_message)
|
||||
logger.info(success_message)
|
||||
|
||||
return {
|
||||
'image_name': image_name,
|
||||
'logs': build_logs,
|
||||
'success': True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"Failed to build algorithm image: {str(e)}"
|
||||
logger.error(error_message)
|
||||
return {
|
||||
'image_name': None,
|
||||
'logs': [error_message],
|
||||
'success': False
|
||||
}
|
||||
|
||||
def deploy_algorithm(self, algorithm_id: str, image_name: str, port: int = None) -> Dict[str, Any]:
|
||||
"""
|
||||
部署算法容器
|
||||
|
||||
Args:
|
||||
algorithm_id: 算法ID
|
||||
image_name: 镜像名称
|
||||
port: 端口号(如果为None则自动分配)
|
||||
|
||||
Returns:
|
||||
部署信息
|
||||
"""
|
||||
try:
|
||||
# 自动分配端口
|
||||
if port is None:
|
||||
port = self._get_available_port()
|
||||
|
||||
# 容器名称
|
||||
container_name = f"algorithm-{algorithm_id}"
|
||||
|
||||
logger.info(f"Deploying container: {container_name} on port {port}")
|
||||
|
||||
# 停止并移除同名容器
|
||||
self._stop_and_remove_container(container_name)
|
||||
|
||||
# 启动容器
|
||||
if self.client:
|
||||
container = self.client.containers.run(
|
||||
image_name,
|
||||
name=container_name,
|
||||
ports={'8000/tcp': port},
|
||||
detach=True,
|
||||
restart_policy={'Name': 'unless-stopped'}
|
||||
)
|
||||
|
||||
container_id = container.id
|
||||
else:
|
||||
# 备用方案:使用subprocess
|
||||
result = subprocess.run(
|
||||
[
|
||||
'docker', 'run',
|
||||
'--name', container_name,
|
||||
'-p', f'{port}:8000',
|
||||
'-d',
|
||||
'--restart', 'unless-stopped',
|
||||
image_name
|
||||
],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
logger.error(f"Failed to run container: {result.stderr}")
|
||||
raise Exception(f"Failed to run container: {result.stderr}")
|
||||
|
||||
container_id = result.stdout.strip()
|
||||
|
||||
# 等待容器启动
|
||||
time.sleep(2)
|
||||
|
||||
# 验证容器状态
|
||||
container_status = self.get_container_status(container_name)
|
||||
|
||||
deployment_info = {
|
||||
'container_name': container_name,
|
||||
'container_id': container_id,
|
||||
'image_name': image_name,
|
||||
'port': port,
|
||||
'status': container_status,
|
||||
'api_url': f"http://localhost:{port}"
|
||||
}
|
||||
|
||||
logger.info(f"Successfully deployed algorithm: {deployment_info}")
|
||||
return deployment_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deploy algorithm: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_container_status(self, container_name: str) -> str:
|
||||
"""
|
||||
获取容器状态
|
||||
|
||||
Args:
|
||||
container_name: 容器名称
|
||||
|
||||
Returns:
|
||||
容器状态
|
||||
"""
|
||||
try:
|
||||
if self.client:
|
||||
container = self.client.containers.get(container_name)
|
||||
return container.status
|
||||
else:
|
||||
# 备用方案:使用subprocess
|
||||
result = subprocess.run(
|
||||
['docker', 'inspect', '--format', '{{.State.Status}}', container_name],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
return result.stdout.strip()
|
||||
else:
|
||||
return 'unknown'
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get container status: {str(e)}")
|
||||
return 'error'
|
||||
|
||||
def stop_container(self, container_name: str) -> bool:
|
||||
"""
|
||||
停止容器
|
||||
|
||||
Args:
|
||||
container_name: 容器名称
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
try:
|
||||
if self.client:
|
||||
container = self.client.containers.get(container_name)
|
||||
container.stop()
|
||||
else:
|
||||
# 备用方案:使用subprocess
|
||||
result = subprocess.run(
|
||||
['docker', 'stop', container_name],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
logger.error(f"Failed to stop container: {result.stderr}")
|
||||
return False
|
||||
|
||||
logger.info(f"Successfully stopped container: {container_name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop container: {str(e)}")
|
||||
return False
|
||||
|
||||
def remove_container(self, container_name: str) -> bool:
|
||||
"""
|
||||
移除容器
|
||||
|
||||
Args:
|
||||
container_name: 容器名称
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
try:
|
||||
if self.client:
|
||||
container = self.client.containers.get(container_name)
|
||||
container.remove(force=True)
|
||||
else:
|
||||
# 备用方案:使用subprocess
|
||||
result = subprocess.run(
|
||||
['docker', 'rm', '-f', container_name],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
logger.error(f"Failed to remove container: {result.stderr}")
|
||||
return False
|
||||
|
||||
logger.info(f"Successfully removed container: {container_name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to remove container: {str(e)}")
|
||||
return False
|
||||
|
||||
def list_containers(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列出所有算法容器
|
||||
|
||||
Returns:
|
||||
容器列表
|
||||
"""
|
||||
containers = []
|
||||
|
||||
try:
|
||||
if self.client:
|
||||
for container in self.client.containers.list(all=True):
|
||||
if container.name.startswith('algorithm-'):
|
||||
containers.append({
|
||||
'name': container.name,
|
||||
'id': container.id,
|
||||
'image': container.image.tags[0] if container.image.tags else 'unknown',
|
||||
'status': container.status,
|
||||
'ports': container.ports,
|
||||
'created': container.attrs['Created']
|
||||
})
|
||||
else:
|
||||
# 备用方案:使用subprocess
|
||||
result = subprocess.run(
|
||||
['docker', 'ps', '-a', '--format', '{{json .}}'],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
for line in result.stdout.strip().split('\n'):
|
||||
if line:
|
||||
container_info = json.loads(line)
|
||||
if container_info['Names'].startswith('/algorithm-'):
|
||||
containers.append({
|
||||
'name': container_info['Names'].lstrip('/'),
|
||||
'id': container_info['ID'],
|
||||
'image': container_info['Image'],
|
||||
'status': container_info['Status'],
|
||||
'ports': container_info['Ports'],
|
||||
'created': container_info['CreatedAt']
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list containers: {str(e)}")
|
||||
|
||||
return containers
|
||||
|
||||
def _generate_dockerfile(self, dependencies: List[str]) -> str:
|
||||
"""
|
||||
生成Dockerfile
|
||||
|
||||
Args:
|
||||
dependencies: 依赖包列表
|
||||
|
||||
Returns:
|
||||
Dockerfile内容
|
||||
"""
|
||||
dockerfile = """
|
||||
FROM python:3.9-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 安装系统依赖
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
python3-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 安装Python依赖
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 复制应用代码
|
||||
COPY . .
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE 8000
|
||||
|
||||
# 启动服务
|
||||
CMD ["python", "app.py"]
|
||||
"""
|
||||
|
||||
# 生成requirements.txt内容
|
||||
requirements = """
|
||||
fastapi
|
||||
uvicorn
|
||||
pydantic
|
||||
python-multipart
|
||||
numpy
|
||||
"""
|
||||
|
||||
# 为模型部署添加必要的依赖
|
||||
model_deps = {'onnxruntime', 'scikit-learn'}
|
||||
for dep in model_deps:
|
||||
if dep not in dependencies:
|
||||
requirements += f"{dep}\n"
|
||||
|
||||
for dep in dependencies:
|
||||
if dep not in model_deps:
|
||||
requirements += f"{dep}\n"
|
||||
|
||||
# 将requirements.txt内容添加到Dockerfile
|
||||
dockerfile = dockerfile.replace('COPY requirements.txt .', f"RUN echo '{requirements}' > requirements.txt")
|
||||
|
||||
return dockerfile
|
||||
|
||||
def _generate_api_service(self) -> str:
|
||||
"""
|
||||
生成API服务代码
|
||||
|
||||
Returns:
|
||||
API服务代码
|
||||
"""
|
||||
api_code = """
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
import importlib.util
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 导入算法模块
|
||||
spec = importlib.util.spec_from_file_location("algorithm", "algorithm.py")
|
||||
algorithm = importlib.util.module_from_spec(spec)
|
||||
sys.modules["algorithm"] = algorithm
|
||||
spec.loader.exec_module(algorithm)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
class AlgorithmInput(BaseModel):
|
||||
input_data: dict
|
||||
params: dict = {}
|
||||
|
||||
class AlgorithmOutput(BaseModel):
|
||||
success: bool
|
||||
result: dict
|
||||
error: str = ""
|
||||
|
||||
@app.post("/execute", response_model=AlgorithmOutput)
|
||||
def execute_algorithm(input_data: AlgorithmInput):
|
||||
# 执行算法
|
||||
try:
|
||||
# 调用算法的execute函数
|
||||
if hasattr(algorithm, 'execute'):
|
||||
result = algorithm.execute(input_data.input_data, input_data.params)
|
||||
return AlgorithmOutput(
|
||||
success=True,
|
||||
result=result
|
||||
)
|
||||
else:
|
||||
raise Exception("Algorithm module does not have execute function")
|
||||
|
||||
except Exception as e:
|
||||
return AlgorithmOutput(
|
||||
success=False,
|
||||
result={},
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
@app.get("/health")
|
||||
def health_check():
|
||||
# 健康检查
|
||||
return {"status": "healthy"}
|
||||
|
||||
@app.get("/")
|
||||
def root():
|
||||
# 根路径
|
||||
return {"message": "Algorithm API Service"}
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
"""
|
||||
|
||||
return api_code
|
||||
|
||||
def _get_available_port(self) -> int:
|
||||
"""
|
||||
获取可用端口
|
||||
|
||||
Returns:
|
||||
可用端口号
|
||||
"""
|
||||
import socket
|
||||
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('', 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
def _stop_and_remove_container(self, container_name: str):
|
||||
"""
|
||||
停止并移除容器
|
||||
|
||||
Args:
|
||||
container_name: 容器名称
|
||||
"""
|
||||
try:
|
||||
if self.client:
|
||||
# 尝试获取容器
|
||||
try:
|
||||
container = self.client.containers.get(container_name)
|
||||
# 停止容器
|
||||
container.stop()
|
||||
# 移除容器
|
||||
container.remove()
|
||||
except docker.errors.NotFound:
|
||||
pass
|
||||
else:
|
||||
# 备用方案:使用subprocess
|
||||
# 停止容器
|
||||
subprocess.run(['docker', 'stop', container_name], capture_output=True)
|
||||
# 移除容器
|
||||
subprocess.run(['docker', 'rm', '-f', container_name], capture_output=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop and remove container: {str(e)}")
|
||||
|
||||
|
||||
# 全局部署服务实例
|
||||
deployment_service = DeploymentService()
|
||||
257
backend/app/services/history_manager.py
Normal file
257
backend/app/services/history_manager.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""历史记录管理服务,负责管理算法调用历史和用户操作历史"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
import logging
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_, desc
|
||||
|
||||
from app.models.models import AlgorithmCall, User, Algorithm
|
||||
from app.utils.file import file_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HistoryManager:
|
||||
"""历史记录管理器,处理算法调用历史和其他用户操作历史"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_user_call_history(
|
||||
self,
|
||||
db: Session,
|
||||
user_id: str,
|
||||
algorithm_id: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
status: Optional[str] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[AlgorithmCall]:
|
||||
"""获取用户的调用历史"""
|
||||
query = db.query(AlgorithmCall).filter(AlgorithmCall.user_id == user_id)
|
||||
|
||||
# 添加过滤条件
|
||||
if algorithm_id:
|
||||
query = query.filter(AlgorithmCall.algorithm_id == algorithm_id)
|
||||
|
||||
if start_date:
|
||||
query = query.filter(AlgorithmCall.created_at >= start_date)
|
||||
|
||||
if end_date:
|
||||
query = query.filter(AlgorithmCall.created_at <= end_date)
|
||||
|
||||
if status:
|
||||
query = query.filter(AlgorithmCall.status == status)
|
||||
|
||||
# 按创建时间倒序排列
|
||||
query = query.order_by(desc(AlgorithmCall.created_at))
|
||||
|
||||
# 分页
|
||||
history = query.offset(skip).limit(limit).all()
|
||||
|
||||
return history
|
||||
|
||||
def get_algorithm_call_history(
|
||||
self,
|
||||
db: Session,
|
||||
algorithm_id: str,
|
||||
user_id: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
status: Optional[str] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[AlgorithmCall]:
|
||||
"""获取特定算法的调用历史"""
|
||||
query = db.query(AlgorithmCall).filter(AlgorithmCall.algorithm_id == algorithm_id)
|
||||
|
||||
# 添加过滤条件
|
||||
if user_id:
|
||||
query = query.filter(AlgorithmCall.user_id == user_id)
|
||||
|
||||
if start_date:
|
||||
query = query.filter(AlgorithmCall.created_at >= start_date)
|
||||
|
||||
if end_date:
|
||||
query = query.filter(AlgorithmCall.created_at <= end_date)
|
||||
|
||||
if status:
|
||||
query = query.filter(AlgorithmCall.status == status)
|
||||
|
||||
# 按创建时间倒序排列
|
||||
query = query.order_by(desc(AlgorithmCall.created_at))
|
||||
|
||||
# 分页
|
||||
history = query.offset(skip).limit(limit).all()
|
||||
|
||||
return history
|
||||
|
||||
def get_call_statistics(
|
||||
self,
|
||||
db: Session,
|
||||
user_id: Optional[str] = None,
|
||||
algorithm_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""获取调用统计信息"""
|
||||
query = db.query(AlgorithmCall)
|
||||
|
||||
if user_id:
|
||||
query = query.filter(AlgorithmCall.user_id == user_id)
|
||||
|
||||
if algorithm_id:
|
||||
query = query.filter(AlgorithmCall.algorithm_id == algorithm_id)
|
||||
|
||||
# 总调用次数
|
||||
total_calls = query.count()
|
||||
|
||||
# 按状态统计
|
||||
status_counts = db.query(
|
||||
AlgorithmCall.status,
|
||||
db.func.count(AlgorithmCall.id)
|
||||
).filter(
|
||||
AlgorithmCall.user_id == user_id if user_id else AlgorithmCall.algorithm_id == algorithm_id
|
||||
).group_by(AlgorithmCall.status).all()
|
||||
|
||||
status_dict = {status: count for status, count in status_counts}
|
||||
|
||||
# 成功率
|
||||
success_count = status_dict.get('success', 0)
|
||||
success_rate = (success_count / total_calls * 100) if total_calls > 0 else 0
|
||||
|
||||
# 平均响应时间
|
||||
avg_response_time = db.query(
|
||||
db.func.avg(AlgorithmCall.response_time)
|
||||
).filter(
|
||||
AlgorithmCall.response_time.isnot(None),
|
||||
AlgorithmCall.user_id == user_id if user_id else AlgorithmCall.algorithm_id == algorithm_id
|
||||
).scalar()
|
||||
|
||||
# 今日调用次数
|
||||
today_start = datetime.combine(datetime.today().date(), datetime.min.time())
|
||||
today_calls = query.filter(AlgorithmCall.created_at >= today_start).count()
|
||||
|
||||
return {
|
||||
"total_calls": total_calls,
|
||||
"status_counts": status_dict,
|
||||
"success_rate": round(success_rate, 2),
|
||||
"avg_response_time": round(avg_response_time, 3) if avg_response_time else None,
|
||||
"today_calls": today_calls
|
||||
}
|
||||
|
||||
def get_comparison_data(
|
||||
self,
|
||||
db: Session,
|
||||
call_ids: List[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取用于对比的历史数据"""
|
||||
calls = db.query(AlgorithmCall).filter(AlgorithmCall.id.in_(call_ids)).all()
|
||||
|
||||
comparison_data = []
|
||||
for call in calls:
|
||||
comparison_data.append({
|
||||
"id": call.id,
|
||||
"algorithm_id": call.algorithm_id,
|
||||
"algorithm_name": getattr(call.algorithm, 'name', 'Unknown'),
|
||||
"version_id": call.version_id,
|
||||
"input_data": call.input_data,
|
||||
"output_data": call.output_data,
|
||||
"params": call.params,
|
||||
"status": call.status,
|
||||
"response_time": call.response_time,
|
||||
"created_at": call.created_at.isoformat() if call.created_at else None,
|
||||
"error_message": call.error_message
|
||||
})
|
||||
|
||||
return comparison_data
|
||||
|
||||
def export_history(
|
||||
self,
|
||||
db: Session,
|
||||
user_id: str,
|
||||
algorithm_id: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
format_type: str = "json"
|
||||
) -> Optional[str]:
|
||||
"""导出历史记录"""
|
||||
history = self.get_user_call_history(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
algorithm_id=algorithm_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
skip=0,
|
||||
limit=10000 # 限制导出数量
|
||||
)
|
||||
|
||||
# 转换为可序列化的格式
|
||||
export_data = []
|
||||
for call in history:
|
||||
export_data.append({
|
||||
"id": call.id,
|
||||
"user_id": call.user_id,
|
||||
"algorithm_id": call.algorithm_id,
|
||||
"algorithm_name": getattr(call.algorithm, 'name', 'Unknown'),
|
||||
"version_id": call.version_id,
|
||||
"version_number": getattr(call.version, 'version', 'Unknown'),
|
||||
"input_data": call.input_data,
|
||||
"output_data": call.output_data,
|
||||
"params": call.params,
|
||||
"status": call.status,
|
||||
"response_time": call.response_time,
|
||||
"created_at": call.created_at.isoformat() if call.created_at else None,
|
||||
"updated_at": call.updated_at.isoformat() if call.updated_at else None,
|
||||
"error_message": call.error_message
|
||||
})
|
||||
|
||||
try:
|
||||
if format_type.lower() == "json":
|
||||
json_str = json.dumps(export_data, ensure_ascii=False, default=str)
|
||||
file_path = f"exports/history_{user_id}_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.json"
|
||||
|
||||
# 上传到存储
|
||||
success = file_storage.upload_from_bytes(json_str.encode('utf-8'), file_path)
|
||||
|
||||
if success:
|
||||
return file_path
|
||||
else:
|
||||
logger.error("Failed to upload exported history")
|
||||
return None
|
||||
else:
|
||||
logger.error(f"Unsupported export format: {format_type}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error exporting history: {str(e)}")
|
||||
return None
|
||||
|
||||
def delete_old_history(
|
||||
self,
|
||||
db: Session,
|
||||
days_old: int,
|
||||
user_id: Optional[str] = None,
|
||||
algorithm_id: Optional[str] = None
|
||||
) -> int:
|
||||
"""删除旧的历史记录"""
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
|
||||
|
||||
query = db.query(AlgorithmCall).filter(AlgorithmCall.created_at < cutoff_date)
|
||||
|
||||
if user_id:
|
||||
query = query.filter(AlgorithmCall.user_id == user_id)
|
||||
|
||||
if algorithm_id:
|
||||
query = query.filter(AlgorithmCall.algorithm_id == algorithm_id)
|
||||
|
||||
deleted_count = query.delete()
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Deleted {deleted_count} old history records")
|
||||
return deleted_count
|
||||
|
||||
|
||||
# 全局历史记录管理器实例
|
||||
history_manager = HistoryManager()
|
||||
284
backend/app/services/monitoring.py
Normal file
284
backend/app/services/monitoring.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""监控服务,负责系统监控、性能指标收集和告警"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
import psutil
|
||||
import os
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.models.models import AlgorithmCall, User, Algorithm
|
||||
from app.services.service_manager import service_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetricsCollector:
|
||||
"""指标收集器,收集系统和业务指标"""
|
||||
|
||||
def __init__(self):
|
||||
self.metrics_history = defaultdict(lambda: deque(maxlen=1000)) # 保留最近1000个指标
|
||||
self.start_time = datetime.utcnow()
|
||||
|
||||
def collect_system_metrics(self) -> Dict[str, Any]:
|
||||
"""收集系统指标"""
|
||||
cpu_percent = psutil.cpu_percent(interval=1)
|
||||
memory_info = psutil.virtual_memory()
|
||||
disk_usage = psutil.disk_usage('/')
|
||||
|
||||
metrics = {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"cpu_percent": cpu_percent,
|
||||
"memory_percent": memory_info.percent,
|
||||
"memory_available": memory_info.available,
|
||||
"memory_total": memory_info.total,
|
||||
"disk_percent": disk_usage.percent,
|
||||
"disk_free": disk_usage.free,
|
||||
"disk_total": disk_usage.total,
|
||||
"process_count": len(psutil.pids()),
|
||||
"uptime": (datetime.utcnow() - self.start_time).total_seconds()
|
||||
}
|
||||
|
||||
# 存储指标历史
|
||||
self.metrics_history['system'].append(metrics)
|
||||
|
||||
return metrics
|
||||
|
||||
def collect_business_metrics(self, db: Session) -> Dict[str, Any]:
|
||||
"""收集业务指标"""
|
||||
# 算法调用统计
|
||||
total_calls = db.query(func.count(AlgorithmCall.id)).scalar()
|
||||
today_calls = db.query(func.count(AlgorithmCall.id)).filter(
|
||||
AlgorithmCall.created_at >= datetime.utcnow().date()
|
||||
).scalar()
|
||||
|
||||
# 用户统计
|
||||
total_users = db.query(func.count(User.id)).scalar()
|
||||
active_users = db.query(func.count(User.id)).filter(User.status == 'active').scalar()
|
||||
|
||||
# 算法统计
|
||||
total_algorithms = db.query(func.count(Algorithm.id)).scalar()
|
||||
active_algorithms = db.query(func.count(Algorithm.id)).filter(Algorithm.status == 'active').scalar()
|
||||
|
||||
# 按状态统计调用
|
||||
status_counts = db.query(AlgorithmCall.status, func.count(AlgorithmCall.id)).group_by(AlgorithmCall.status).all()
|
||||
status_dict = {status: count for status, count in status_counts}
|
||||
|
||||
# 平均响应时间(最近1小时)
|
||||
recent_calls = db.query(AlgorithmCall.response_time).filter(
|
||||
AlgorithmCall.response_time.isnot(None),
|
||||
AlgorithmCall.created_at >= datetime.utcnow() - timedelta(hours=1)
|
||||
).all()
|
||||
|
||||
avg_response_time = None
|
||||
if recent_calls:
|
||||
response_times = [call.response_time for call in recent_calls if call.response_time is not None]
|
||||
if response_times:
|
||||
avg_response_time = sum(response_times) / len(response_times)
|
||||
|
||||
metrics = {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"business": {
|
||||
"total_calls": total_calls,
|
||||
"today_calls": today_calls,
|
||||
"total_users": total_users,
|
||||
"active_users": active_users,
|
||||
"total_algorithms": total_algorithms,
|
||||
"active_algorithms": active_algorithms,
|
||||
"call_status_counts": status_dict,
|
||||
"avg_response_time_recent_hour": avg_response_time
|
||||
}
|
||||
}
|
||||
|
||||
# 存储指标历史
|
||||
self.metrics_history['business'].append(metrics)
|
||||
|
||||
return metrics
|
||||
|
||||
def get_metric_history(self, metric_type: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""获取指标历史"""
|
||||
history = list(self.metrics_history[metric_type])
|
||||
return history[-limit:] if len(history) > limit else history
|
||||
|
||||
def get_current_metrics(self, db: Session) -> Dict[str, Any]:
|
||||
"""获取当前所有指标"""
|
||||
return {
|
||||
"system": self.collect_system_metrics(),
|
||||
"business": self.collect_business_metrics(db)
|
||||
}
|
||||
|
||||
|
||||
class AlertManager:
|
||||
"""告警管理器,处理阈值告警"""
|
||||
|
||||
def __init__(self):
|
||||
self.alert_rules = []
|
||||
self.active_alerts = {}
|
||||
self.alert_history = deque(maxlen=1000)
|
||||
|
||||
def add_alert_rule(self, name: str, condition_func, severity: str = "warning"):
|
||||
"""添加告警规则"""
|
||||
rule = {
|
||||
"name": name,
|
||||
"condition": condition_func,
|
||||
"severity": severity,
|
||||
"triggered": False
|
||||
}
|
||||
self.alert_rules.append(rule)
|
||||
|
||||
def check_alerts(self, metrics: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""检查告警条件"""
|
||||
triggered_alerts = []
|
||||
|
||||
for rule in self.alert_rules:
|
||||
try:
|
||||
is_triggered = rule["condition"](metrics)
|
||||
|
||||
if is_triggered and not rule["triggered"]:
|
||||
# 告警首次触发
|
||||
alert = {
|
||||
"name": rule["name"],
|
||||
"severity": rule["severity"],
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"metrics": metrics
|
||||
}
|
||||
|
||||
self.active_alerts[rule["name"]] = alert
|
||||
self.alert_history.append(alert)
|
||||
triggered_alerts.append(alert)
|
||||
rule["triggered"] = True
|
||||
|
||||
logger.warning(f"Alert triggered: {rule['name']} - {alert}")
|
||||
|
||||
elif not is_triggered and rule["triggered"]:
|
||||
# 告警解除
|
||||
logger.info(f"Alert cleared: {rule['name']}")
|
||||
rule["triggered"] = False
|
||||
if rule["name"] in self.active_alerts:
|
||||
del self.active_alerts[rule["name"]]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking alert rule {rule['name']}: {str(e)}")
|
||||
|
||||
return triggered_alerts
|
||||
|
||||
def get_active_alerts(self) -> List[Dict[str, Any]]:
|
||||
"""获取当前激活的告警"""
|
||||
return list(self.active_alerts.values())
|
||||
|
||||
def get_alert_history(self, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""获取告警历史"""
|
||||
history = list(self.alert_history)
|
||||
return history[-limit:] if len(history) > limit else history
|
||||
|
||||
|
||||
class MonitoringService:
|
||||
"""监控服务主类"""
|
||||
|
||||
def __init__(self):
|
||||
self.metrics_collector = MetricsCollector()
|
||||
self.alert_manager = AlertManager()
|
||||
self.monitoring_task = None
|
||||
self.is_monitoring = False
|
||||
|
||||
# 添加默认告警规则
|
||||
self._setup_default_alerts()
|
||||
|
||||
def _setup_default_alerts(self):
|
||||
"""设置默认告警规则"""
|
||||
# CPU使用率过高
|
||||
def cpu_high_condition(metrics):
|
||||
cpu_percent = metrics.get("system", {}).get("cpu_percent", 0)
|
||||
return cpu_percent > 80
|
||||
|
||||
# 内存使用率过高
|
||||
def memory_high_condition(metrics):
|
||||
memory_percent = metrics.get("system", {}).get("memory_percent", 0)
|
||||
return memory_percent > 85
|
||||
|
||||
# 调用失败率过高
|
||||
def high_failure_rate_condition(metrics):
|
||||
business = metrics.get("business", {})
|
||||
status_counts = business.get("call_status_counts", {})
|
||||
total = sum(status_counts.values()) if status_counts else 1
|
||||
failed = status_counts.get("failed", 0)
|
||||
failure_rate = failed / total if total > 0 else 0
|
||||
return failure_rate > 0.1 # 失败率超过10%
|
||||
|
||||
self.alert_manager.add_alert_rule("High CPU Usage", cpu_high_condition, "warning")
|
||||
self.alert_manager.add_alert_rule("High Memory Usage", memory_high_condition, "warning")
|
||||
self.alert_manager.add_alert_rule("High Failure Rate", high_failure_rate_condition, "critical")
|
||||
|
||||
async def start_monitoring(self, db: Session, interval: int = 60):
|
||||
"""启动监控"""
|
||||
if self.is_monitoring:
|
||||
logger.warning("Monitoring already started")
|
||||
return
|
||||
|
||||
self.is_monitoring = True
|
||||
logger.info("Starting monitoring...")
|
||||
|
||||
while self.is_monitoring:
|
||||
try:
|
||||
# 收集指标
|
||||
metrics = self.metrics_collector.get_current_metrics(db)
|
||||
|
||||
# 检查告警
|
||||
triggered_alerts = self.alert_manager.check_alerts(metrics)
|
||||
|
||||
# 记录指标到日志
|
||||
logger.info(f"Collected metrics - CPU: {metrics['system']['cpu_percent']:.1f}%, "
|
||||
f"Memory: {metrics['system']['memory_percent']:.1f}%, "
|
||||
f"Total calls: {metrics['business']['business']['total_calls']}")
|
||||
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in monitoring loop: {str(e)}")
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
async def stop_monitoring(self):
|
||||
"""停止监控"""
|
||||
self.is_monitoring = False
|
||||
logger.info("Monitoring stopped")
|
||||
|
||||
def get_system_health(self) -> Dict[str, Any]:
|
||||
"""获取系统健康状况"""
|
||||
system_metrics = self.metrics_collector.collect_system_metrics()
|
||||
|
||||
health_status = "healthy"
|
||||
if system_metrics["cpu_percent"] > 90 or system_metrics["memory_percent"] > 95:
|
||||
health_status = "critical"
|
||||
elif system_metrics["cpu_percent"] > 80 or system_metrics["memory_percent"] > 85:
|
||||
health_status = "warning"
|
||||
|
||||
return {
|
||||
"status": health_status,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"system_metrics": system_metrics,
|
||||
"active_alerts": len(self.alert_manager.active_alerts),
|
||||
"uptime": system_metrics["uptime"]
|
||||
}
|
||||
|
||||
def get_dashboard_data(self, db: Session) -> Dict[str, Any]:
|
||||
"""获取仪表板数据"""
|
||||
current_metrics = self.metrics_collector.get_current_metrics(db)
|
||||
active_alerts = self.alert_manager.get_active_alerts()
|
||||
|
||||
return {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"system": current_metrics["system"],
|
||||
"business": current_metrics["business"]["business"],
|
||||
"active_alerts_count": len(active_alerts),
|
||||
"recent_alerts": active_alerts[-5:], # 最近5个告警
|
||||
"system_health": self.get_system_health()
|
||||
}
|
||||
|
||||
|
||||
# 全局监控服务实例
|
||||
monitoring_service = MonitoringService()
|
||||
289
backend/app/services/permission.py
Normal file
289
backend/app/services/permission.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""权限管理服务,负责算法访问权限和用户权限管理"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_
|
||||
|
||||
from app.models.models import User, Algorithm, APIKey
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PermissionType(Enum):
|
||||
"""权限类型枚举"""
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
EXECUTE = "execute"
|
||||
ADMIN = "admin"
|
||||
|
||||
|
||||
class AccessLevel(Enum):
|
||||
"""访问级别枚举"""
|
||||
NONE = "none"
|
||||
READ = "read"
|
||||
READ_WRITE = "read_write"
|
||||
EXECUTE = "execute"
|
||||
FULL = "full"
|
||||
|
||||
|
||||
class PermissionManager:
|
||||
"""权限管理器,处理用户对算法的访问权限"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def check_algorithm_access(self, db: Session, user_id: str, algorithm_id: str,
|
||||
permission_type: PermissionType) -> bool:
|
||||
"""检查用户对算法的访问权限"""
|
||||
try:
|
||||
# 获取用户和算法信息
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
algorithm = db.query(Algorithm).filter(Algorithm.id == algorithm_id).first()
|
||||
|
||||
if not user or not algorithm:
|
||||
return False
|
||||
|
||||
# 管理员拥有所有权限
|
||||
if user.role == "admin":
|
||||
return True
|
||||
|
||||
# 算法所有者拥有完全权限
|
||||
# 注意:在这个模型中,我们没有直接的算法所有者字段
|
||||
# 实际应用中可能需要添加算法创建者的关联
|
||||
|
||||
# 根据算法状态决定权限
|
||||
if algorithm.status == "inactive":
|
||||
# 非活跃算法只有管理员可以访问
|
||||
return False
|
||||
|
||||
# 根据权限类型检查权限
|
||||
if permission_type in [PermissionType.READ, PermissionType.EXECUTE]:
|
||||
# 读取和执行权限:活跃算法对所有认证用户开放
|
||||
return algorithm.status == "active"
|
||||
elif permission_type == PermissionType.WRITE:
|
||||
# 写入权限:只有特定用户或管理员可以
|
||||
return user.role in ["admin", "manager"]
|
||||
elif permission_type == PermissionType.ADMIN:
|
||||
# 管理权限:只有管理员可以
|
||||
return user.role == "admin"
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking algorithm access: {str(e)}")
|
||||
return False
|
||||
|
||||
def grant_permission(self, db: Session, admin_user_id: str, user_id: str,
|
||||
algorithm_id: str, access_level: AccessLevel) -> bool:
|
||||
"""授予用户对算法的权限"""
|
||||
try:
|
||||
# 验证管理员权限
|
||||
admin_user = db.query(User).filter(
|
||||
User.id == admin_user_id,
|
||||
User.role.in_(["admin", "manager"])
|
||||
).first()
|
||||
|
||||
if not admin_user:
|
||||
logger.warning(f"User {admin_user_id} is not authorized to grant permissions")
|
||||
return False
|
||||
|
||||
# 验证目标用户和算法存在
|
||||
target_user = db.query(User).filter(User.id == user_id).first()
|
||||
algorithm = db.query(Algorithm).filter(Algorithm.id == algorithm_id).first()
|
||||
|
||||
if not target_user or not algorithm:
|
||||
logger.warning("Target user or algorithm not found")
|
||||
return False
|
||||
|
||||
# 在实际实现中,这里应该创建权限记录
|
||||
# 由于当前数据模型中没有专门的权限表,我们记录日志
|
||||
logger.info(f"Permission granted: user {user_id} -> algorithm {algorithm_id}, "
|
||||
f"level: {access_level.value}, by admin: {admin_user_id}")
|
||||
|
||||
# 在实际实现中,应该创建权限记录到数据库
|
||||
# db.add(PermissionRecord(user_id=user_id, algorithm_id=algorithm_id, access_level=access_level))
|
||||
# db.commit()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error granting permission: {str(e)}")
|
||||
return False
|
||||
|
||||
def revoke_permission(self, db: Session, admin_user_id: str, user_id: str,
|
||||
algorithm_id: str) -> bool:
|
||||
"""撤销用户对算法的权限"""
|
||||
try:
|
||||
# 验证管理员权限
|
||||
admin_user = db.query(User).filter(
|
||||
User.id == admin_user_id,
|
||||
User.role.in_(["admin", "manager"])
|
||||
).first()
|
||||
|
||||
if not admin_user:
|
||||
logger.warning(f"User {admin_user_id} is not authorized to revoke permissions")
|
||||
return False
|
||||
|
||||
# 在实际实现中,这里应该删除权限记录
|
||||
logger.info(f"Permission revoked: user {user_id} -> algorithm {algorithm_id}, "
|
||||
f"by admin: {admin_user_id}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error revoking permission: {str(e)}")
|
||||
return False
|
||||
|
||||
def get_user_permissions(self, db: Session, user_id: str) -> List[Dict[str, Any]]:
|
||||
"""获取用户的所有权限"""
|
||||
try:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
return []
|
||||
|
||||
permissions = []
|
||||
|
||||
if user.role == "admin":
|
||||
# 管理员拥有所有算法的完全权限
|
||||
algorithms = db.query(Algorithm).all()
|
||||
for algorithm in algorithms:
|
||||
permissions.append({
|
||||
"algorithm_id": algorithm.id,
|
||||
"algorithm_name": algorithm.name,
|
||||
"access_level": AccessLevel.FULL.value,
|
||||
"granted_at": algorithm.created_at,
|
||||
"granted_by": "system",
|
||||
"permission_type": "administrative"
|
||||
})
|
||||
else:
|
||||
# 普通用户只能访问活跃算法
|
||||
algorithms = db.query(Algorithm).filter(Algorithm.status == "active").all()
|
||||
for algorithm in algorithms:
|
||||
permissions.append({
|
||||
"algorithm_id": algorithm.id,
|
||||
"algorithm_name": algorithm.name,
|
||||
"access_level": AccessLevel.READ.value, # 默认只读权限
|
||||
"granted_at": algorithm.created_at,
|
||||
"granted_by": "system",
|
||||
"permission_type": "public"
|
||||
})
|
||||
|
||||
return permissions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user permissions: {str(e)}")
|
||||
return []
|
||||
|
||||
def get_algorithm_permissions(self, db: Session, algorithm_id: str) -> List[Dict[str, Any]]:
|
||||
"""获取算法的所有权限分配"""
|
||||
try:
|
||||
algorithm = db.query(Algorithm).filter(Algorithm.id == algorithm_id).first()
|
||||
if not algorithm:
|
||||
return []
|
||||
|
||||
permissions = []
|
||||
|
||||
# 如果是活跃算法,所有用户都有读取权限
|
||||
if algorithm.status == "active":
|
||||
# 添加公共访问权限信息
|
||||
permissions.append({
|
||||
"user_id": "*",
|
||||
"user_name": "All Users",
|
||||
"access_level": AccessLevel.READ.value,
|
||||
"granted_at": algorithm.created_at,
|
||||
"granted_by": "system",
|
||||
"permission_type": "public"
|
||||
})
|
||||
|
||||
# 获取特定用户的权限(如果有专门的权限表的话)
|
||||
# 这里我们模拟返回一些示例数据
|
||||
# 在实际实现中,应该查询权限表
|
||||
|
||||
return permissions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting algorithm permissions: {str(e)}")
|
||||
return []
|
||||
|
||||
def check_api_key_access(self, db: Session, api_key_value: str, algorithm_id: str) -> bool:
|
||||
"""检查API密钥对算法的访问权限"""
|
||||
try:
|
||||
# 通过API密钥查找用户
|
||||
api_key = db.query(APIKey).filter(
|
||||
APIKey.key == api_key_value,
|
||||
APIKey.status == "active"
|
||||
).first()
|
||||
|
||||
if not api_key:
|
||||
return False
|
||||
|
||||
# 检查用户对算法的访问权限
|
||||
return self.check_algorithm_access(
|
||||
db, api_key.user_id, algorithm_id, PermissionType.EXECUTE
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking API key access: {str(e)}")
|
||||
return False
|
||||
|
||||
def validate_user_algorithm_operation(self, db: Session, user_id: str, algorithm_id: str,
|
||||
operation: str) -> bool:
|
||||
"""验证用户对算法的操作权限"""
|
||||
try:
|
||||
perm_type = {
|
||||
'read': PermissionType.READ,
|
||||
'execute': PermissionType.EXECUTE,
|
||||
'write': PermissionType.WRITE,
|
||||
'admin': PermissionType.ADMIN
|
||||
}.get(operation.lower(), PermissionType.READ)
|
||||
|
||||
return self.check_algorithm_access(db, user_id, algorithm_id, perm_type)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating user operation: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
class RBACManager:
|
||||
"""基于角色的访问控制管理器"""
|
||||
|
||||
def __init__(self):
|
||||
# 定义角色权限映射
|
||||
self.role_permissions = {
|
||||
"admin": [
|
||||
PermissionType.READ, PermissionType.WRITE,
|
||||
PermissionType.EXECUTE, PermissionType.ADMIN
|
||||
],
|
||||
"manager": [
|
||||
PermissionType.READ, PermissionType.WRITE, PermissionType.EXECUTE
|
||||
],
|
||||
"user": [PermissionType.READ, PermissionType.EXECUTE],
|
||||
"guest": [PermissionType.READ]
|
||||
}
|
||||
|
||||
def get_role_permissions(self, role: str) -> List[PermissionType]:
|
||||
"""获取角色的权限列表"""
|
||||
return self.role_permissions.get(role, [])
|
||||
|
||||
def user_has_permission(self, db: Session, user_id: str, permission_type: PermissionType) -> bool:
|
||||
"""检查用户是否具有某种权限"""
|
||||
try:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
return False
|
||||
|
||||
user_perms = self.get_role_permissions(user.role)
|
||||
return permission_type in user_perms
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking user permission: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
# 全局权限管理器实例
|
||||
permission_manager = PermissionManager()
|
||||
rbac_manager = RBACManager()
|
||||
307
backend/app/services/project_analyzer.py
Normal file
307
backend/app/services/project_analyzer.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""项目分析服务,用于分析算法仓库的结构和特性"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
|
||||
class ProjectAnalyzer:
|
||||
"""项目分析服务"""
|
||||
|
||||
def analyze_project(self, repo_path: str) -> Dict[str, Any]:
|
||||
"""分析项目结构和特性
|
||||
|
||||
Args:
|
||||
repo_path: 仓库路径
|
||||
|
||||
Returns:
|
||||
包含项目分析结果的字典
|
||||
"""
|
||||
try:
|
||||
# 1. 识别项目类型
|
||||
project_type = self._detect_project_type(repo_path)
|
||||
|
||||
# 2. 分析依赖
|
||||
dependencies = self._analyze_dependencies(repo_path, project_type)
|
||||
|
||||
# 3. 识别入口点
|
||||
entry_point = self._detect_entry_point(repo_path, project_type)
|
||||
|
||||
# 4. 分析API模式
|
||||
api_pattern = self._detect_api_pattern(repo_path, project_type)
|
||||
|
||||
# 5. 分析项目结构
|
||||
structure = self._analyze_structure(repo_path)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"project_type": project_type,
|
||||
"dependencies": dependencies,
|
||||
"entry_point": entry_point,
|
||||
"api_pattern": api_pattern,
|
||||
"structure": structure,
|
||||
"error": None
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"project_type": None,
|
||||
"dependencies": None,
|
||||
"entry_point": None,
|
||||
"api_pattern": None,
|
||||
"structure": None
|
||||
}
|
||||
|
||||
def _detect_project_type(self, repo_path: str) -> Optional[str]:
|
||||
"""检测项目类型
|
||||
|
||||
Args:
|
||||
repo_path: 仓库路径
|
||||
|
||||
Returns:
|
||||
项目类型,如 "python", "java", "nodejs" 等
|
||||
"""
|
||||
# 检查Python项目
|
||||
if os.path.exists(os.path.join(repo_path, "requirements.txt")) or \
|
||||
os.path.exists(os.path.join(repo_path, "pyproject.toml")) or \
|
||||
any(file.endswith(".py") for file in os.listdir(repo_path)):
|
||||
return "python"
|
||||
|
||||
# 检查Java项目
|
||||
if os.path.exists(os.path.join(repo_path, "pom.xml")) or \
|
||||
os.path.exists(os.path.join(repo_path, "build.gradle")) or \
|
||||
os.path.exists(os.path.join(repo_path, "src")):
|
||||
return "java"
|
||||
|
||||
# 检查Node.js项目
|
||||
if os.path.exists(os.path.join(repo_path, "package.json")):
|
||||
return "nodejs"
|
||||
|
||||
# 检查其他项目类型
|
||||
if os.path.exists(os.path.join(repo_path, "CMakeLists.txt")):
|
||||
return "c++"
|
||||
|
||||
return None
|
||||
|
||||
def _analyze_dependencies(self, repo_path: str, project_type: Optional[str]) -> List[str]:
|
||||
"""分析项目依赖
|
||||
|
||||
Args:
|
||||
repo_path: 仓库路径
|
||||
project_type: 项目类型
|
||||
|
||||
Returns:
|
||||
依赖列表
|
||||
"""
|
||||
dependencies = []
|
||||
|
||||
if project_type == "python":
|
||||
# 分析requirements.txt
|
||||
req_file = os.path.join(repo_path, "requirements.txt")
|
||||
if os.path.exists(req_file):
|
||||
with open(req_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#"):
|
||||
dependencies.append(line)
|
||||
|
||||
# 分析pyproject.toml
|
||||
pyproject_file = os.path.join(repo_path, "pyproject.toml")
|
||||
if os.path.exists(pyproject_file):
|
||||
with open(pyproject_file, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
# 简单解析依赖部分
|
||||
if "[dependencies]" in content:
|
||||
dep_section = content.split("[dependencies]")[1].split("[")[0]
|
||||
for line in dep_section.strip().split("\n"):
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#"):
|
||||
dependencies.append(line)
|
||||
|
||||
elif project_type == "java":
|
||||
# 分析pom.xml
|
||||
pom_file = os.path.join(repo_path, "pom.xml")
|
||||
if os.path.exists(pom_file):
|
||||
with open(pom_file, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
# 简单解析依赖
|
||||
for match in re.finditer(r'<dependency>.*?</dependency>', content, re.DOTALL):
|
||||
dep = match.group(0)
|
||||
group_id = re.search(r'<groupId>(.*?)</groupId>', dep)
|
||||
artifact_id = re.search(r'<artifactId>(.*?)</artifactId>', dep)
|
||||
version = re.search(r'<version>(.*?)</version>', dep)
|
||||
if group_id and artifact_id:
|
||||
dep_str = f"{group_id.group(1)}:{artifact_id.group(1)}"
|
||||
if version:
|
||||
dep_str += f":{version.group(1)}"
|
||||
dependencies.append(dep_str)
|
||||
|
||||
elif project_type == "nodejs":
|
||||
# 分析package.json
|
||||
package_file = os.path.join(repo_path, "package.json")
|
||||
if os.path.exists(package_file):
|
||||
with open(package_file, "r", encoding="utf-8") as f:
|
||||
package_data = json.load(f)
|
||||
if "dependencies" in package_data:
|
||||
for dep, version in package_data["dependencies"].items():
|
||||
dependencies.append(f"{dep}@{version}")
|
||||
if "devDependencies" in package_data:
|
||||
for dep, version in package_data["devDependencies"].items():
|
||||
dependencies.append(f"{dep}@{version} (dev)")
|
||||
|
||||
return dependencies
|
||||
|
||||
def _detect_entry_point(self, repo_path: str, project_type: Optional[str]) -> Optional[str]:
|
||||
"""检测项目入口点
|
||||
|
||||
Args:
|
||||
repo_path: 仓库路径
|
||||
project_type: 项目类型
|
||||
|
||||
Returns:
|
||||
入口点路径或函数名
|
||||
"""
|
||||
if project_type == "python":
|
||||
# 查找主要的Python文件
|
||||
main_files = ["main.py", "app.py", "run.py", "server.py"]
|
||||
for file in main_files:
|
||||
file_path = os.path.join(repo_path, file)
|
||||
if os.path.exists(file_path):
|
||||
return file
|
||||
|
||||
# 查找包含__main__.py的包
|
||||
for root, dirs, files in os.walk(repo_path):
|
||||
if "__main__.py" in files:
|
||||
return os.path.relpath(os.path.join(root, "__main__.py"), repo_path)
|
||||
|
||||
# 查找包含main函数的文件
|
||||
for root, dirs, files in os.walk(repo_path):
|
||||
for file in files:
|
||||
if file.endswith(".py"):
|
||||
file_path = os.path.join(root, file)
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
if "def main(" in content:
|
||||
return os.path.relpath(file_path, repo_path)
|
||||
except:
|
||||
pass
|
||||
|
||||
elif project_type == "java":
|
||||
# 查找包含main方法的Java文件
|
||||
for root, dirs, files in os.walk(repo_path):
|
||||
for file in files:
|
||||
if file.endswith(".java"):
|
||||
file_path = os.path.join(root, file)
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
if "public static void main(String[] args)" in content:
|
||||
return os.path.relpath(file_path, repo_path)
|
||||
except:
|
||||
pass
|
||||
|
||||
elif project_type == "nodejs":
|
||||
# 检查package.json中的main字段
|
||||
package_file = os.path.join(repo_path, "package.json")
|
||||
if os.path.exists(package_file):
|
||||
with open(package_file, "r", encoding="utf-8") as f:
|
||||
try:
|
||||
package_data = json.load(f)
|
||||
if "main" in package_data:
|
||||
return package_data["main"]
|
||||
elif "scripts" in package_data and "start" in package_data["scripts"]:
|
||||
return f"package.json (start: {package_data['scripts']['start']})"
|
||||
except:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def _detect_api_pattern(self, repo_path: str, project_type: Optional[str]) -> Optional[str]:
|
||||
"""检测API模式
|
||||
|
||||
Args:
|
||||
repo_path: 仓库路径
|
||||
project_type: 项目类型
|
||||
|
||||
Returns:
|
||||
API模式,如 "fastapi", "flask", "express" 等
|
||||
"""
|
||||
if project_type == "python":
|
||||
# 检查FastAPI
|
||||
for root, dirs, files in os.walk(repo_path):
|
||||
for file in files:
|
||||
if file.endswith(".py"):
|
||||
file_path = os.path.join(root, file)
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
if "from fastapi import" in content or "import fastapi" in content:
|
||||
return "fastapi"
|
||||
elif "from flask import" in content or "import flask" in content:
|
||||
return "flask"
|
||||
elif "from django import" in content or "import django" in content:
|
||||
return "django"
|
||||
except:
|
||||
pass
|
||||
|
||||
elif project_type == "nodejs":
|
||||
# 检查Express
|
||||
package_file = os.path.join(repo_path, "package.json")
|
||||
if os.path.exists(package_file):
|
||||
with open(package_file, "r", encoding="utf-8") as f:
|
||||
try:
|
||||
package_data = json.load(f)
|
||||
dependencies = package_data.get("dependencies", {})
|
||||
if "express" in dependencies:
|
||||
return "express"
|
||||
elif "koa" in dependencies:
|
||||
return "koa"
|
||||
elif "nestjs" in dependencies:
|
||||
return "nestjs"
|
||||
except:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def _analyze_structure(self, repo_path: str) -> Dict[str, Any]:
|
||||
"""分析项目结构
|
||||
|
||||
Args:
|
||||
repo_path: 仓库路径
|
||||
|
||||
Returns:
|
||||
项目结构字典
|
||||
"""
|
||||
structure = {
|
||||
"files": [],
|
||||
"directories": [],
|
||||
"size": 0
|
||||
}
|
||||
|
||||
for root, dirs, files in os.walk(repo_path):
|
||||
# 排除隐藏目录和文件
|
||||
dirs[:] = [d for d in dirs if not d.startswith(".")]
|
||||
files = [f for f in files if not f.startswith(".")]
|
||||
|
||||
# 添加目录
|
||||
for dir_name in dirs:
|
||||
dir_path = os.path.join(root, dir_name)
|
||||
structure["directories"].append(os.path.relpath(dir_path, repo_path))
|
||||
|
||||
# 添加文件
|
||||
for file_name in files:
|
||||
file_path = os.path.join(root, file_name)
|
||||
try:
|
||||
file_size = os.path.getsize(file_path)
|
||||
structure["files"].append({
|
||||
"path": os.path.relpath(file_path, repo_path),
|
||||
"size": file_size
|
||||
})
|
||||
structure["size"] += file_size
|
||||
except:
|
||||
pass
|
||||
|
||||
return structure
|
||||
977
backend/app/services/service_generator.py
Normal file
977
backend/app/services/service_generator.py
Normal file
@@ -0,0 +1,977 @@
|
||||
"""服务生成器,用于生成算法服务包装器"""
|
||||
|
||||
import os
|
||||
import jinja2
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
class ServiceGenerator:
|
||||
"""服务生成器"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化服务生成器"""
|
||||
# 初始化Jinja2模板引擎
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(os.path.join(os.path.dirname(__file__), "templates")),
|
||||
autoescape=False
|
||||
)
|
||||
|
||||
# 确保模板目录存在
|
||||
template_dir = os.path.join(os.path.dirname(__file__), "templates")
|
||||
if not os.path.exists(template_dir):
|
||||
os.makedirs(template_dir)
|
||||
# 创建默认模板
|
||||
self._create_default_templates(template_dir)
|
||||
|
||||
def generate_service(self, project_info: Dict[str, Any], service_config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""生成服务包装器
|
||||
|
||||
Args:
|
||||
project_info: 项目分析信息
|
||||
service_config: 服务配置
|
||||
|
||||
Returns:
|
||||
包含生成结果的字典
|
||||
"""
|
||||
try:
|
||||
# 1. 根据项目类型选择模板
|
||||
template_name = self._select_template(project_info["project_type"], service_config["service_type"])
|
||||
|
||||
# 2. 生成服务代码
|
||||
service_code = self._generate_service_code(template_name, project_info, service_config)
|
||||
|
||||
# 3. 生成Dockerfile
|
||||
dockerfile = self._generate_dockerfile(project_info, service_config)
|
||||
|
||||
# 4. 生成配置文件
|
||||
config_files = self._generate_config_files(service_config)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"service_code": service_code,
|
||||
"dockerfile": dockerfile,
|
||||
"config_files": config_files,
|
||||
"error": None
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"service_code": None,
|
||||
"dockerfile": None,
|
||||
"config_files": None
|
||||
}
|
||||
|
||||
def _select_template(self, project_type: str, service_type: str) -> str:
|
||||
"""选择服务模板
|
||||
|
||||
Args:
|
||||
project_type: 项目类型
|
||||
service_type: 服务类型
|
||||
|
||||
Returns:
|
||||
模板名称
|
||||
"""
|
||||
# 根据项目类型和服务类型选择模板
|
||||
template_map = {
|
||||
"python": {
|
||||
"http": "python_http_service.py.j2",
|
||||
"grpc": "python_grpc_service.py.j2",
|
||||
"mq": "python_mq_service.py.j2"
|
||||
},
|
||||
"nodejs": {
|
||||
"http": "nodejs_http_service.js.j2",
|
||||
"grpc": "nodejs_grpc_service.js.j2",
|
||||
"mq": "nodejs_mq_service.js.j2"
|
||||
}
|
||||
}
|
||||
|
||||
return template_map.get(project_type, {}).get(service_type, "python_http_service.py.j2")
|
||||
|
||||
def _generate_service_code(self, template_name: str, project_info: Dict[str, Any], service_config: Dict[str, Any]) -> str:
|
||||
"""生成服务代码
|
||||
|
||||
Args:
|
||||
template_name: 模板名称
|
||||
project_info: 项目信息
|
||||
service_config: 服务配置
|
||||
|
||||
Returns:
|
||||
生成的服务代码
|
||||
"""
|
||||
try:
|
||||
# 加载模板
|
||||
template = self.template_env.get_template(template_name)
|
||||
|
||||
# 准备模板数据
|
||||
template_data = {
|
||||
"project_name": service_config.get("name", "algorithm-service"),
|
||||
"project_type": project_info["project_type"],
|
||||
"entry_point": project_info["entry_point"],
|
||||
"api_pattern": project_info["api_pattern"],
|
||||
"dependencies": project_info["dependencies"],
|
||||
"service_config": service_config,
|
||||
"host": service_config.get("host", "0.0.0.0"),
|
||||
"port": service_config.get("port", 8000),
|
||||
"timeout": service_config.get("timeout", 30),
|
||||
"health_check_path": service_config.get("health_check_path", "/health")
|
||||
}
|
||||
|
||||
# 渲染模板
|
||||
return template.render(**template_data)
|
||||
except Exception as e:
|
||||
# 如果模板不存在,生成默认的Python HTTP服务
|
||||
if project_info["project_type"] == "python":
|
||||
return self._generate_default_python_http_service(project_info, service_config)
|
||||
elif project_info["project_type"] == "nodejs":
|
||||
return self._generate_default_nodejs_http_service(project_info, service_config)
|
||||
else:
|
||||
raise e
|
||||
|
||||
def _generate_dockerfile(self, project_info: Dict[str, Any], service_config: Dict[str, Any]) -> str:
|
||||
"""生成Dockerfile
|
||||
|
||||
Args:
|
||||
project_info: 项目信息
|
||||
service_config: 服务配置
|
||||
|
||||
Returns:
|
||||
Dockerfile内容
|
||||
"""
|
||||
dockerfile_templates = {
|
||||
"python": """
|
||||
FROM python:3.9-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 复制依赖文件
|
||||
COPY requirements.txt .
|
||||
|
||||
# 安装依赖
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 复制项目文件
|
||||
COPY . .
|
||||
|
||||
# 复制生成的服务包装器
|
||||
COPY service_wrapper.py .
|
||||
|
||||
# 设置环境变量
|
||||
ENV HOST={{host}}
|
||||
ENV PORT={{port}}
|
||||
ENV TIMEOUT={{timeout}}
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE {{port}}
|
||||
|
||||
# 启动服务
|
||||
CMD ["python", "service_wrapper.py"]
|
||||
""",
|
||||
"nodejs": """
|
||||
FROM node:16-alpine
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 复制依赖文件
|
||||
COPY package.json package-lock.json* .
|
||||
|
||||
# 安装依赖
|
||||
RUN npm install --production
|
||||
|
||||
# 复制项目文件
|
||||
COPY . .
|
||||
|
||||
# 复制生成的服务包装器
|
||||
COPY service_wrapper.js .
|
||||
|
||||
# 设置环境变量
|
||||
ENV HOST={{host}}
|
||||
ENV PORT={{port}}
|
||||
ENV TIMEOUT={{timeout}}
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE {{port}}
|
||||
|
||||
# 启动服务
|
||||
CMD ["node", "service_wrapper.js"]
|
||||
"""
|
||||
}
|
||||
|
||||
template = dockerfile_templates.get(project_info["project_type"], dockerfile_templates["python"])
|
||||
|
||||
# 替换模板变量
|
||||
template = template.replace("{{host}}", service_config.get("host", "0.0.0.0"))
|
||||
template = template.replace("{{port}}", str(service_config.get("port", 8000)))
|
||||
template = template.replace("{{timeout}}", str(service_config.get("timeout", 30)))
|
||||
|
||||
return template
|
||||
|
||||
def _generate_config_files(self, service_config: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""生成配置文件
|
||||
|
||||
Args:
|
||||
service_config: 服务配置
|
||||
|
||||
Returns:
|
||||
配置文件字典
|
||||
"""
|
||||
config_files = {}
|
||||
|
||||
# 生成环境变量文件
|
||||
env_content = """
|
||||
# Service Configuration
|
||||
HOST={{host}}
|
||||
PORT={{port}}
|
||||
TIMEOUT={{timeout}}
|
||||
|
||||
# Service Metadata
|
||||
SERVICE_NAME={{name}}
|
||||
SERVICE_VERSION={{version}}
|
||||
"""
|
||||
|
||||
env_content = env_content.replace("{{host}}", service_config.get("host", "0.0.0.0"))
|
||||
env_content = env_content.replace("{{port}}", str(service_config.get("port", 8000)))
|
||||
env_content = env_content.replace("{{timeout}}", str(service_config.get("timeout", 30)))
|
||||
env_content = env_content.replace("{{name}}", service_config.get("name", "algorithm-service"))
|
||||
env_content = env_content.replace("{{version}}", service_config.get("version", "1.0.0"))
|
||||
|
||||
config_files[".env"] = env_content
|
||||
|
||||
# 生成docker-compose.yml
|
||||
docker_compose_content = """
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
{{name}}:
|
||||
build: .
|
||||
ports:
|
||||
- "{{port}}:{{port}}"
|
||||
environment:
|
||||
- HOST={{host}}
|
||||
- PORT={{port}}
|
||||
- TIMEOUT={{timeout}}
|
||||
restart: unless-stopped
|
||||
"""
|
||||
|
||||
docker_compose_content = docker_compose_content.replace("{{name}}", service_config.get("name", "algorithm-service"))
|
||||
docker_compose_content = docker_compose_content.replace("{{port}}", str(service_config.get("port", 8000)))
|
||||
docker_compose_content = docker_compose_content.replace("{{host}}", service_config.get("host", "0.0.0.0"))
|
||||
docker_compose_content = docker_compose_content.replace("{{timeout}}", str(service_config.get("timeout", 30)))
|
||||
|
||||
config_files["docker-compose.yml"] = docker_compose_content
|
||||
|
||||
return config_files
|
||||
|
||||
def _generate_default_python_http_service(self, project_info: Dict[str, Any], service_config: Dict[str, Any]) -> str:
|
||||
"""生成默认的Python HTTP服务
|
||||
|
||||
Args:
|
||||
project_info: 项目信息
|
||||
service_config: 服务配置
|
||||
|
||||
Returns:
|
||||
生成的服务代码
|
||||
"""
|
||||
# 使用简单的字符串拼接
|
||||
service_code = "# Python HTTP服务包装器\n"
|
||||
service_code += "\n"
|
||||
service_code += "import os\n"
|
||||
service_code += "import sys\n"
|
||||
service_code += "import json\n"
|
||||
service_code += "import time\n"
|
||||
service_code += "from http.server import HTTPServer, BaseHTTPRequestHandler\n"
|
||||
service_code += "\n"
|
||||
service_code += "# 添加项目路径到Python路径\n"
|
||||
service_code += "sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))\n"
|
||||
service_code += "\n"
|
||||
service_code += "# 尝试导入算法模块\n"
|
||||
service_code += "try:\n"
|
||||
service_code += " # 根据入口点导入算法\n"
|
||||
service_code += " if '{{entry_point}}' == '':\n"
|
||||
service_code += " # 尝试导入主要模块\n"
|
||||
service_code += " import algorithm\n"
|
||||
service_code += " algorithm_module = algorithm\n"
|
||||
service_code += " else:\n"
|
||||
service_code += " # 动态导入入口点\n"
|
||||
service_code += " import importlib.util\n"
|
||||
service_code += " spec = importlib.util.spec_from_file_location('algorithm_module', '{{entry_point}}')\n"
|
||||
service_code += " algorithm_module = importlib.util.module_from_spec(spec)\n"
|
||||
service_code += " spec.loader.exec_module(algorithm_module)\n"
|
||||
service_code += " print('算法模块导入成功')\n"
|
||||
service_code += "except Exception as e:\n"
|
||||
service_code += " print(f'算法模块导入失败: {e}')\n"
|
||||
service_code += " algorithm_module = None\n"
|
||||
service_code += "\n"
|
||||
service_code += "# 服务配置\n"
|
||||
service_code += "HOST = os.environ.get('HOST', '0.0.0.0')\n"
|
||||
service_code += "PORT = int(os.environ.get('PORT', '{{port}}'))\n"
|
||||
service_code += "TIMEOUT = int(os.environ.get('TIMEOUT', '{{timeout}}'))\n"
|
||||
service_code += "\n"
|
||||
service_code += "class AlgorithmRequestHandler(BaseHTTPRequestHandler):\n"
|
||||
service_code += " '''算法请求处理器'''\n"
|
||||
service_code += " \n"
|
||||
service_code += " def do_POST(self):\n"
|
||||
service_code += " '''处理POST请求'''\n"
|
||||
service_code += " try:\n"
|
||||
service_code += " # 读取请求体\n"
|
||||
service_code += " content_length = int(self.headers['Content-Length'])\n"
|
||||
service_code += " post_data = self.rfile.read(content_length)\n"
|
||||
service_code += " \n"
|
||||
service_code += " # 解析请求数据\n"
|
||||
service_code += " request_data = json.loads(post_data.decode('utf-8'))\n"
|
||||
service_code += " \n"
|
||||
service_code += " # 记录请求开始时间\n"
|
||||
service_code += " start_time = time.time()\n"
|
||||
service_code += " \n"
|
||||
service_code += " # 调用算法\n"
|
||||
service_code += " result = self._call_algorithm(request_data)\n"
|
||||
service_code += " \n"
|
||||
service_code += " # 计算响应时间\n"
|
||||
service_code += " response_time = time.time() - start_time\n"
|
||||
service_code += " \n"
|
||||
service_code += " # 构建响应\n"
|
||||
service_code += " response = {\n"
|
||||
service_code += " 'success': True,\n"
|
||||
service_code += " 'result': result,\n"
|
||||
service_code += " 'response_time': round(response_time, 4),\n"
|
||||
service_code += " 'message': '算法执行成功'\n"
|
||||
service_code += " }\n"
|
||||
service_code += " \n"
|
||||
service_code += " # 发送响应\n"
|
||||
service_code += " self.send_response(200)\n"
|
||||
service_code += " self.send_header('Content-type', 'application/json')\n"
|
||||
service_code += " self.end_headers()\n"
|
||||
service_code += " self.wfile.write(json.dumps(response).encode('utf-8'))\n"
|
||||
service_code += " \n"
|
||||
service_code += " except Exception as e:\n"
|
||||
service_code += " # 构建错误响应\n"
|
||||
service_code += " error_response = {\n"
|
||||
service_code += " 'success': False,\n"
|
||||
service_code += " 'error': str(e),\n"
|
||||
service_code += " 'message': '算法执行失败'\n"
|
||||
service_code += " }\n"
|
||||
service_code += " \n"
|
||||
service_code += " # 发送错误响应\n"
|
||||
service_code += " self.send_response(400)\n"
|
||||
service_code += " self.send_header('Content-type', 'application/json')\n"
|
||||
service_code += " self.end_headers()\n"
|
||||
service_code += " self.wfile.write(json.dumps(error_response).encode('utf-8'))\n"
|
||||
service_code += " \n"
|
||||
service_code += " def do_GET(self):\n"
|
||||
service_code += " '''处理GET请求'''\n"
|
||||
service_code += " if self.path == '/health':\n"
|
||||
service_code += " # 健康检查\n"
|
||||
service_code += " self._handle_health_check()\n"
|
||||
service_code += " elif self.path == '/info':\n"
|
||||
service_code += " # 服务信息\n"
|
||||
service_code += " self._handle_info()\n"
|
||||
service_code += " else:\n"
|
||||
service_code += " # 404响应\n"
|
||||
service_code += " self.send_response(404)\n"
|
||||
service_code += " self.send_header('Content-type', 'application/json')\n"
|
||||
service_code += " self.end_headers()\n"
|
||||
service_code += " self.wfile.write(json.dumps({'error': 'Not Found'}).encode('utf-8'))\n"
|
||||
service_code += " \n"
|
||||
service_code += " def _call_algorithm(self, request_data):\n"
|
||||
service_code += " '''调用算法\n"
|
||||
service_code += " \n"
|
||||
service_code += " Args:\n"
|
||||
service_code += " request_data: 请求数据\n"
|
||||
service_code += " \n"
|
||||
service_code += " Returns:\n"
|
||||
service_code += " 算法执行结果\n"
|
||||
service_code += " '''\n"
|
||||
service_code += " if algorithm_module is None:\n"
|
||||
service_code += " raise Exception('算法模块未加载')\n"
|
||||
service_code += " \n"
|
||||
service_code += " # 尝试调用算法的主要函数\n"
|
||||
service_code += " try:\n"
|
||||
service_code += " # 检查是否有predict函数\n"
|
||||
service_code += " if hasattr(algorithm_module, 'predict'):\n"
|
||||
service_code += " return algorithm_module.predict(request_data)\n"
|
||||
service_code += " # 检查是否有run函数\n"
|
||||
service_code += " elif hasattr(algorithm_module, 'run'):\n"
|
||||
service_code += " return algorithm_module.run(request_data)\n"
|
||||
service_code += " # 检查是否有main函数\n"
|
||||
service_code += " elif hasattr(algorithm_module, 'main'):\n"
|
||||
service_code += " return algorithm_module.main(request_data)\n"
|
||||
service_code += " else:\n"
|
||||
service_code += " raise Exception('未找到算法执行函数')\n"
|
||||
service_code += " except Exception as e:\n"
|
||||
service_code += " raise Exception(f'算法执行失败: {e}')\n"
|
||||
service_code += " \n"
|
||||
service_code += " def _handle_health_check(self):\n"
|
||||
service_code += " '''处理健康检查'''\n"
|
||||
service_code += " self.send_response(200)\n"
|
||||
service_code += " self.send_header('Content-type', 'application/json')\n"
|
||||
service_code += " self.end_headers()\n"
|
||||
service_code += " self.wfile.write(json.dumps({'status': 'healthy', 'service': '{{name}}'}).encode('utf-8'))\n"
|
||||
service_code += " \n"
|
||||
service_code += " def _handle_info(self):\n"
|
||||
service_code += " '''处理服务信息请求'''\n"
|
||||
service_code += " info = {\n"
|
||||
service_code += " 'service': '{{name}}',\n"
|
||||
service_code += " 'version': '{{version}}',\n"
|
||||
service_code += " 'host': HOST,\n"
|
||||
service_code += " 'port': PORT,\n"
|
||||
service_code += " 'timeout': TIMEOUT,\n"
|
||||
service_code += " 'algorithm_loaded': algorithm_module is not None\n"
|
||||
service_code += " }\n"
|
||||
service_code += " \n"
|
||||
service_code += " self.send_response(200)\n"
|
||||
service_code += " self.send_header('Content-type', 'application/json')\n"
|
||||
service_code += " self.end_headers()\n"
|
||||
service_code += " self.wfile.write(json.dumps(info).encode('utf-8'))\n"
|
||||
service_code += "\n"
|
||||
service_code += "def run_server():\n"
|
||||
service_code += " '''启动服务'''\n"
|
||||
service_code += " server = HTTPServer((HOST, PORT), AlgorithmRequestHandler)\n"
|
||||
service_code += " print(f'服务启动成功,监听地址: {HOST}:{PORT}')\n"
|
||||
service_code += " print(f'健康检查地址: http://{HOST}:{PORT}/health')\n"
|
||||
service_code += " print(f'服务信息地址: http://{HOST}:{PORT}/info')\n"
|
||||
service_code += " \n"
|
||||
service_code += " try:\n"
|
||||
service_code += " server.serve_forever()\n"
|
||||
service_code += " except KeyboardInterrupt:\n"
|
||||
service_code += " print('服务停止')\n"
|
||||
service_code += " server.shutdown()\n"
|
||||
service_code += "\n"
|
||||
service_code += "if __name__ == '__main__':\n"
|
||||
service_code += " run_server()\n"
|
||||
|
||||
# 替换模板变量
|
||||
service_code = service_code.replace("{{entry_point}}", project_info.get("entry_point", ""))
|
||||
service_code = service_code.replace("{{port}}", str(service_config.get("port", 8000)))
|
||||
service_code = service_code.replace("{{timeout}}", str(service_config.get("timeout", 30)))
|
||||
service_code = service_code.replace("{{name}}", service_config.get("name", "algorithm-service"))
|
||||
service_code = service_code.replace("{{version}}", service_config.get("version", "1.0.0"))
|
||||
|
||||
return service_code
|
||||
|
||||
def _generate_default_nodejs_http_service(self, project_info: Dict[str, Any], service_config: Dict[str, Any]) -> str:
|
||||
"""生成默认的Node.js HTTP服务
|
||||
|
||||
Args:
|
||||
project_info: 项目信息
|
||||
service_config: 服务配置
|
||||
|
||||
Returns:
|
||||
生成的服务代码
|
||||
"""
|
||||
# 使用简单的字符串拼接
|
||||
service_code = "// Node.js HTTP服务包装器\n"
|
||||
service_code += "\n"
|
||||
service_code += "const http = require('http');\n"
|
||||
service_code += "const url = require('url');\n"
|
||||
service_code += "const fs = require('fs');\n"
|
||||
service_code += "const path = require('path');\n"
|
||||
service_code += "\n"
|
||||
service_code += "// 服务配置\n"
|
||||
service_code += "const HOST = process.env.HOST || '0.0.0.0';\n"
|
||||
service_code += "const PORT = process.env.PORT || {{port}};\n"
|
||||
service_code += "const TIMEOUT = process.env.TIMEOUT || {{timeout}};\n"
|
||||
service_code += "\n"
|
||||
service_code += "// 尝试导入算法模块\n"
|
||||
service_code += "let algorithmModule = null;\n"
|
||||
service_code += "try {\n"
|
||||
service_code += " // 根据入口点导入算法\n"
|
||||
service_code += " if ('{{entry_point}}' === '') {\n"
|
||||
service_code += " // 尝试导入主要模块\n"
|
||||
service_code += " algorithmModule = require('./algorithm');\n"
|
||||
service_code += " } else {\n"
|
||||
service_code += " // 导入入口点\n"
|
||||
service_code += " algorithmModule = require('./{{entry_point}}');\n"
|
||||
service_code += " }\n"
|
||||
service_code += " console.log('算法模块导入成功');\n"
|
||||
service_code += "} catch (e) {\n"
|
||||
service_code += " console.error('算法模块导入失败:', e);\n"
|
||||
service_code += " algorithmModule = null;\n"
|
||||
service_code += "}\n"
|
||||
service_code += "\n"
|
||||
service_code += "/**\n"
|
||||
service_code += " * 调用算法\n"
|
||||
service_code += " * @param {Object} requestData 请求数据\n"
|
||||
service_code += " * @returns {Object} 算法执行结果\n"
|
||||
service_code += " */\n"
|
||||
service_code += "function callAlgorithm(requestData) {\n"
|
||||
service_code += " if (!algorithmModule) {\n"
|
||||
service_code += " throw new Error('算法模块未加载');\n"
|
||||
service_code += " }\n"
|
||||
service_code += " \n"
|
||||
service_code += " // 尝试调用算法的主要函数\n"
|
||||
service_code += " try {\n"
|
||||
service_code += " if (algorithmModule.predict) {\n"
|
||||
service_code += " return algorithmModule.predict(requestData);\n"
|
||||
service_code += " } else if (algorithmModule.run) {\n"
|
||||
service_code += " return algorithmModule.run(requestData);\n"
|
||||
service_code += " } else if (algorithmModule.main) {\n"
|
||||
service_code += " return algorithmModule.main(requestData);\n"
|
||||
service_code += " } else {\n"
|
||||
service_code += " throw new Error('未找到算法执行函数');\n"
|
||||
service_code += " }\n"
|
||||
service_code += " } catch (e) {\n"
|
||||
service_code += " throw new Error(`算法执行失败: ${e.message}`);\n"
|
||||
service_code += " }\n"
|
||||
service_code += "}\n"
|
||||
service_code += "\n"
|
||||
service_code += "/**\n"
|
||||
service_code += " * 处理请求\n"
|
||||
service_code += " * @param {http.IncomingMessage} req 请求对象\n"
|
||||
service_code += " * @param {http.ServerResponse} res 响应对象\n"
|
||||
service_code += " */\n"
|
||||
service_code += "function handleRequest(req, res) {\n"
|
||||
service_code += " const parsedUrl = url.parse(req.url, true);\n"
|
||||
service_code += " const pathname = parsedUrl.pathname;\n"
|
||||
service_code += " \n"
|
||||
service_code += " // 设置响应头\n"
|
||||
service_code += " res.setHeader('Content-Type', 'application/json');\n"
|
||||
service_code += " \n"
|
||||
service_code += " if (req.method === 'GET') {\n"
|
||||
service_code += " if (pathname === '/health') {\n"
|
||||
service_code += " // 健康检查\n"
|
||||
service_code += " res.writeHead(200);\n"
|
||||
service_code += " res.end(JSON.stringify({ status: 'healthy', service: '{{name}}' }));\n"
|
||||
service_code += " } else if (pathname === '/info') {\n"
|
||||
service_code += " // 服务信息\n"
|
||||
service_code += " const info = {\n"
|
||||
service_code += " service: '{{name}}',\n"
|
||||
service_code += " version: '{{version}}',\n"
|
||||
service_code += " host: HOST,\n"
|
||||
service_code += " port: PORT,\n"
|
||||
service_code += " timeout: TIMEOUT,\n"
|
||||
service_code += " algorithm_loaded: algorithmModule !== null\n"
|
||||
service_code += " };\n"
|
||||
service_code += " res.writeHead(200);\n"
|
||||
service_code += " res.end(JSON.stringify(info));\n"
|
||||
service_code += " } else {\n"
|
||||
service_code += " // 404响应\n"
|
||||
service_code += " res.writeHead(404);\n"
|
||||
service_code += " res.end(JSON.stringify({ error: 'Not Found' }));\n"
|
||||
service_code += " }\n"
|
||||
service_code += " } else if (req.method === 'POST') {\n"
|
||||
service_code += " let body = '';\n"
|
||||
service_code += " \n"
|
||||
service_code += " // 读取请求体\n"
|
||||
service_code += " req.on('data', chunk => {\n"
|
||||
service_code += " body += chunk.toString();\n"
|
||||
service_code += " });\n"
|
||||
service_code += " \n"
|
||||
service_code += " req.on('end', () => {\n"
|
||||
service_code += " try {\n"
|
||||
service_code += " // 解析请求数据\n"
|
||||
service_code += " const requestData = JSON.parse(body);\n"
|
||||
service_code += " \n"
|
||||
service_code += " // 记录请求开始时间\n"
|
||||
service_code += " const startTime = Date.now();\n"
|
||||
service_code += " \n"
|
||||
service_code += " // 调用算法\n"
|
||||
service_code += " const result = callAlgorithm(requestData);\n"
|
||||
service_code += " \n"
|
||||
service_code += " // 计算响应时间\n"
|
||||
service_code += " const responseTime = (Date.now() - startTime) / 1000;\n"
|
||||
service_code += " \n"
|
||||
service_code += " // 构建响应\n"
|
||||
service_code += " const response = {\n"
|
||||
service_code += " success: true,\n"
|
||||
service_code += " result: result,\n"
|
||||
service_code += " response_time: responseTime.toFixed(4),\n"
|
||||
service_code += " message: '算法执行成功'\n"
|
||||
service_code += " };\n"
|
||||
service_code += " \n"
|
||||
service_code += " // 发送响应\n"
|
||||
service_code += " res.writeHead(200);\n"
|
||||
service_code += " res.end(JSON.stringify(response));\n"
|
||||
service_code += " } catch (e) {\n"
|
||||
service_code += " // 构建错误响应\n"
|
||||
service_code += " const errorResponse = {\n"
|
||||
service_code += " success: false,\n"
|
||||
service_code += " error: e.message,\n"
|
||||
service_code += " message: '算法执行失败'\n"
|
||||
service_code += " };\n"
|
||||
service_code += " \n"
|
||||
service_code += " // 发送错误响应\n"
|
||||
service_code += " res.writeHead(400);\n"
|
||||
service_code += " res.end(JSON.stringify(errorResponse));\n"
|
||||
service_code += " }\n"
|
||||
service_code += " });\n"
|
||||
service_code += " } else {\n"
|
||||
service_code += " // 不支持的方法\n"
|
||||
service_code += " res.writeHead(405);\n"
|
||||
service_code += " res.end(JSON.stringify({ error: 'Method Not Allowed' }));\n"
|
||||
service_code += " }\n"
|
||||
service_code += "}\n"
|
||||
service_code += "\n"
|
||||
service_code += "/**\n"
|
||||
service_code += " * 启动服务\n"
|
||||
service_code += " */\n"
|
||||
service_code += "function startServer() {\n"
|
||||
service_code += " const server = http.createServer(handleRequest);\n"
|
||||
service_code += " \n"
|
||||
service_code += " server.listen(PORT, HOST, () => {\n"
|
||||
service_code += " console.log(`服务启动成功,监听地址: ${HOST}:${PORT}`);\n"
|
||||
service_code += " console.log(`健康检查地址: http://${HOST}:${PORT}/health`);\n"
|
||||
service_code += " console.log(`服务信息地址: http://${HOST}:${PORT}/info`);\n"
|
||||
service_code += " });\n"
|
||||
service_code += " \n"
|
||||
service_code += " server.on('error', (error) => {\n"
|
||||
service_code += " console.error('服务启动失败:', error);\n"
|
||||
service_code += " });\n"
|
||||
service_code += "}\n"
|
||||
service_code += "\n"
|
||||
service_code += "// 启动服务\n"
|
||||
service_code += "startServer();\n"
|
||||
|
||||
# 替换模板变量
|
||||
service_code = service_code.replace("{{entry_point}}", project_info.get("entry_point", ""))
|
||||
service_code = service_code.replace("{{port}}", str(service_config.get("port", 8000)))
|
||||
service_code = service_code.replace("{{timeout}}", str(service_config.get("timeout", 30)))
|
||||
service_code = service_code.replace("{{name}}", service_config.get("name", "algorithm-service"))
|
||||
service_code = service_code.replace("{{version}}", service_config.get("version", "1.0.0"))
|
||||
|
||||
return service_code
|
||||
|
||||
def _create_default_templates(self, template_dir: str):
|
||||
"""创建默认模板
|
||||
|
||||
Args:
|
||||
template_dir: 模板目录
|
||||
"""
|
||||
# 创建Python HTTP服务模板
|
||||
python_http_template = '''
|
||||
# Python HTTP服务包装器
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
|
||||
# 添加项目路径到Python路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
# 尝试导入算法模块
|
||||
try:
|
||||
# 根据入口点导入算法
|
||||
if "{{entry_point}}" == "":
|
||||
# 尝试导入主要模块
|
||||
import algorithm
|
||||
algorithm_module = algorithm
|
||||
else:
|
||||
# 动态导入入口点
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location("algorithm_module", "{{entry_point}}")
|
||||
algorithm_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(algorithm_module)
|
||||
print("算法模块导入成功")
|
||||
except Exception as e:
|
||||
print(f"算法模块导入失败: {e}")
|
||||
algorithm_module = None
|
||||
|
||||
# 服务配置
|
||||
HOST = os.environ.get("HOST", "{{host}}")
|
||||
PORT = int(os.environ.get("PORT", "{{port}}"))
|
||||
TIMEOUT = int(os.environ.get("TIMEOUT", "{{timeout}}"))
|
||||
|
||||
class AlgorithmRequestHandler(BaseHTTPRequestHandler):
|
||||
"""算法请求处理器"""
|
||||
|
||||
def do_POST(self):
|
||||
"""处理POST请求"""
|
||||
try:
|
||||
# 读取请求体
|
||||
content_length = int(self.headers['Content-Length'])
|
||||
post_data = self.rfile.read(content_length)
|
||||
|
||||
# 解析请求数据
|
||||
request_data = json.loads(post_data.decode('utf-8'))
|
||||
|
||||
# 记录请求开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 调用算法
|
||||
result = self._call_algorithm(request_data)
|
||||
|
||||
# 计算响应时间
|
||||
response_time = time.time() - start_time
|
||||
|
||||
# 构建响应
|
||||
response = {
|
||||
"success": True,
|
||||
"result": result,
|
||||
"response_time": round(response_time, 4),
|
||||
"message": "算法执行成功"
|
||||
}
|
||||
|
||||
# 发送响应
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(response).encode('utf-8'))
|
||||
|
||||
except Exception as e:
|
||||
# 构建错误响应
|
||||
error_response = {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": "算法执行失败"
|
||||
}
|
||||
|
||||
# 发送错误响应
|
||||
self.send_response(400)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(error_response).encode('utf-8'))
|
||||
|
||||
def do_GET(self):
|
||||
"""处理GET请求"""
|
||||
if self.path == "{{health_check_path}}":
|
||||
# 健康检查
|
||||
self._handle_health_check()
|
||||
elif self.path == "/info":
|
||||
# 服务信息
|
||||
self._handle_info()
|
||||
else:
|
||||
# 404响应
|
||||
self.send_response(404)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps({"error": "Not Found"}).encode('utf-8'))
|
||||
|
||||
def _call_algorithm(self, request_data):
|
||||
"""调用算法
|
||||
|
||||
Args:
|
||||
request_data: 请求数据
|
||||
|
||||
Returns:
|
||||
算法执行结果
|
||||
"""
|
||||
if algorithm_module is None:
|
||||
raise Exception("算法模块未加载")
|
||||
|
||||
# 尝试调用算法的主要函数
|
||||
try:
|
||||
# 检查是否有predict函数
|
||||
if hasattr(algorithm_module, 'predict'):
|
||||
return algorithm_module.predict(request_data)
|
||||
# 检查是否有run函数
|
||||
elif hasattr(algorithm_module, 'run'):
|
||||
return algorithm_module.run(request_data)
|
||||
# 检查是否有main函数
|
||||
elif hasattr(algorithm_module, 'main'):
|
||||
return algorithm_module.main(request_data)
|
||||
else:
|
||||
raise Exception("未找到算法执行函数")
|
||||
except Exception as e:
|
||||
raise Exception(f"算法执行失败: {e}")
|
||||
|
||||
def _handle_health_check(self):
|
||||
"""处理健康检查"""
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps({"status": "healthy", "service": "{{project_name}}"}).encode('utf-8'))
|
||||
|
||||
def _handle_info(self):
|
||||
"""处理服务信息请求"""
|
||||
info = {
|
||||
"service": "{{project_name}}",
|
||||
"version": "1.0.0",
|
||||
"host": HOST,
|
||||
"port": PORT,
|
||||
"timeout": TIMEOUT,
|
||||
"algorithm_loaded": algorithm_module is not None
|
||||
}
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(info).encode('utf-8'))
|
||||
|
||||
def run_server():
|
||||
"""启动服务"""
|
||||
server = HTTPServer((HOST, PORT), AlgorithmRequestHandler)
|
||||
print(f"服务启动成功,监听地址: {HOST}:{PORT}")
|
||||
print(f"健康检查地址: http://{HOST}:{PORT}{{health_check_path}}")
|
||||
print(f"服务信息地址: http://{HOST}:{PORT}/info")
|
||||
|
||||
try:
|
||||
server.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
print("服务停止")
|
||||
server.shutdown()
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_server()
|
||||
'''
|
||||
|
||||
with open(os.path.join(template_dir, "python_http_service.py.j2"), "w") as f:
|
||||
f.write(python_http_template)
|
||||
|
||||
# 创建Node.js HTTP服务模板
|
||||
nodejs_http_template = '''
|
||||
// Node.js HTTP服务包装器
|
||||
|
||||
const http = require('http');
|
||||
const url = require('url');
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
|
||||
// 服务配置
|
||||
const HOST = process.env.HOST || '{{host}}';
|
||||
const PORT = process.env.PORT || {{port}};
|
||||
const TIMEOUT = process.env.TIMEOUT || {{timeout}};
|
||||
|
||||
// 尝试导入算法模块
|
||||
let algorithmModule = null;
|
||||
try {
|
||||
// 根据入口点导入算法
|
||||
if ("{{entry_point}}" === "") {
|
||||
// 尝试导入主要模块
|
||||
algorithmModule = require('./algorithm');
|
||||
} else {
|
||||
// 导入入口点
|
||||
algorithmModule = require('./{{entry_point}}');
|
||||
}
|
||||
console.log('算法模块导入成功');
|
||||
} catch (e) {
|
||||
console.error('算法模块导入失败:', e);
|
||||
algorithmModule = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 调用算法
|
||||
* @param {Object} requestData 请求数据
|
||||
* @returns {Object} 算法执行结果
|
||||
*/
|
||||
function callAlgorithm(requestData) {
|
||||
if (!algorithmModule) {
|
||||
throw new Error('算法模块未加载');
|
||||
}
|
||||
|
||||
// 尝试调用算法的主要函数
|
||||
try {
|
||||
if (algorithmModule.predict) {
|
||||
return algorithmModule.predict(requestData);
|
||||
} else if (algorithmModule.run) {
|
||||
return algorithmModule.run(requestData);
|
||||
} else if (algorithmModule.main) {
|
||||
return algorithmModule.main(requestData);
|
||||
} else {
|
||||
throw new Error('未找到算法执行函数');
|
||||
}
|
||||
} catch (e) {
|
||||
throw new Error(`算法执行失败: ${e.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理请求
|
||||
* @param {http.IncomingMessage} req 请求对象
|
||||
* @param {http.ServerResponse} res 响应对象
|
||||
*/
|
||||
function handleRequest(req, res) {
|
||||
const parsedUrl = url.parse(req.url, true);
|
||||
const pathname = parsedUrl.pathname;
|
||||
|
||||
// 设置响应头
|
||||
res.setHeader('Content-Type', 'application/json');
|
||||
|
||||
if (req.method === 'GET') {
|
||||
if (pathname === "{{health_check_path}}") {
|
||||
// 健康检查
|
||||
res.writeHead(200);
|
||||
res.end(JSON.stringify({ status: 'healthy', service: '{{project_name}}' }));
|
||||
} else if (pathname === '/info') {
|
||||
// 服务信息
|
||||
const info = {
|
||||
service: '{{project_name}}',
|
||||
version: '1.0.0',
|
||||
host: HOST,
|
||||
port: PORT,
|
||||
timeout: TIMEOUT,
|
||||
algorithm_loaded: algorithmModule !== null
|
||||
};
|
||||
res.writeHead(200);
|
||||
res.end(JSON.stringify(info));
|
||||
} else {
|
||||
// 404响应
|
||||
res.writeHead(404);
|
||||
res.end(JSON.stringify({ error: 'Not Found' }));
|
||||
}
|
||||
} else if (req.method === 'POST') {
|
||||
let body = '';
|
||||
|
||||
// 读取请求体
|
||||
req.on('data', chunk => {
|
||||
body += chunk.toString();
|
||||
});
|
||||
|
||||
req.on('end', () => {
|
||||
try {
|
||||
// 解析请求数据
|
||||
const requestData = JSON.parse(body);
|
||||
|
||||
// 记录请求开始时间
|
||||
const startTime = Date.now();
|
||||
|
||||
// 调用算法
|
||||
const result = callAlgorithm(requestData);
|
||||
|
||||
// 计算响应时间
|
||||
const responseTime = (Date.now() - startTime) / 1000;
|
||||
|
||||
// 构建响应
|
||||
const response = {
|
||||
success: true,
|
||||
result: result,
|
||||
response_time: responseTime.toFixed(4),
|
||||
message: '算法执行成功'
|
||||
};
|
||||
|
||||
// 发送响应
|
||||
res.writeHead(200);
|
||||
res.end(JSON.stringify(response));
|
||||
} catch (e) {
|
||||
// 构建错误响应
|
||||
const errorResponse = {
|
||||
success: false,
|
||||
error: e.message,
|
||||
message: '算法执行失败'
|
||||
};
|
||||
|
||||
// 发送错误响应
|
||||
res.writeHead(400);
|
||||
res.end(JSON.stringify(errorResponse));
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// 不支持的方法
|
||||
res.writeHead(405);
|
||||
res.end(JSON.stringify({ error: 'Method Not Allowed' }));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 启动服务
|
||||
*/
|
||||
function startServer() {
|
||||
const server = http.createServer(handleRequest);
|
||||
|
||||
server.listen(PORT, HOST, () => {
|
||||
console.log(`服务启动成功,监听地址: ${HOST}:${PORT}`);
|
||||
console.log(`健康检查地址: http://${HOST}:${PORT}{{health_check_path}}`);
|
||||
console.log(`服务信息地址: http://${HOST}:${PORT}/info`);
|
||||
});
|
||||
|
||||
server.on('error', (error) => {
|
||||
console.error('服务启动失败:', error);
|
||||
});
|
||||
}
|
||||
|
||||
// 启动服务
|
||||
startServer();
|
||||
'''
|
||||
|
||||
with open(os.path.join(template_dir, "nodejs_http_service.js.j2"), "w") as f:
|
||||
f.write(nodejs_http_template)
|
||||
213
backend/app/services/service_manager.py
Normal file
213
backend/app/services/service_manager.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""服务管理模块,负责管理算法服务的状态和配置"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServiceStatus:
|
||||
"""服务状态枚举"""
|
||||
HEALTHY = "healthy"
|
||||
UNHEALTHY = "unhealthy"
|
||||
UNKNOWN = "unknown"
|
||||
MAINTENANCE = "maintenance"
|
||||
|
||||
|
||||
class ServiceInfo:
|
||||
"""服务信息类"""
|
||||
def __init__(
|
||||
self,
|
||||
service_id: str,
|
||||
name: str,
|
||||
url: str,
|
||||
status: str = ServiceStatus.UNKNOWN,
|
||||
last_heartbeat: Optional[datetime] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
self.service_id = service_id
|
||||
self.name = name
|
||||
self.url = url
|
||||
self.status = status
|
||||
self.last_heartbeat = last_heartbeat or datetime.utcnow()
|
||||
self.metadata = metadata or {}
|
||||
self.created_at = datetime.utcnow()
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"service_id": self.service_id,
|
||||
"name": self.name,
|
||||
"url": self.url,
|
||||
"status": self.status,
|
||||
"last_heartbeat": self.last_heartbeat.isoformat() if self.last_heartbeat else None,
|
||||
"metadata": self.metadata,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat()
|
||||
}
|
||||
|
||||
|
||||
class ServiceManager:
|
||||
"""服务管理器,管理所有注册的服务"""
|
||||
|
||||
def __init__(self):
|
||||
self._services: Dict[str, ServiceInfo] = {}
|
||||
self._health_check_interval = 30 # 健康检查间隔(秒)
|
||||
self._monitoring_task = None
|
||||
|
||||
def register_service(
|
||||
self,
|
||||
service_id: str,
|
||||
name: str,
|
||||
url: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> bool:
|
||||
"""注册服务"""
|
||||
try:
|
||||
service_info = ServiceInfo(
|
||||
service_id=service_id,
|
||||
name=name,
|
||||
url=url,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
self._services[service_id] = service_info
|
||||
logger.info(f"Service registered: {service_id} at {url}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register service {service_id}: {str(e)}")
|
||||
return False
|
||||
|
||||
def unregister_service(self, service_id: str) -> bool:
|
||||
"""注销服务"""
|
||||
if service_id in self._services:
|
||||
del self._services[service_id]
|
||||
logger.info(f"Service unregistered: {service_id}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_service(self, service_id: str) -> Optional[ServiceInfo]:
|
||||
"""获取服务信息"""
|
||||
return self._services.get(service_id)
|
||||
|
||||
def get_all_services(self) -> List[ServiceInfo]:
|
||||
"""获取所有服务信息"""
|
||||
return list(self._services.values())
|
||||
|
||||
def update_service_status(self, service_id: str, status: str) -> bool:
|
||||
"""更新服务状态"""
|
||||
if service_id in self._services:
|
||||
service = self._services[service_id]
|
||||
service.status = status
|
||||
service.updated_at = datetime.utcnow()
|
||||
service.last_heartbeat = datetime.utcnow()
|
||||
return True
|
||||
return False
|
||||
|
||||
def update_service_metadata(self, service_id: str, metadata: Dict[str, Any]) -> bool:
|
||||
"""更新服务元数据"""
|
||||
if service_id in self._services:
|
||||
service = self._services[service_id]
|
||||
service.metadata.update(metadata)
|
||||
service.updated_at = datetime.utcnow()
|
||||
return True
|
||||
return False
|
||||
|
||||
async def health_check_single(self, service_info: ServiceInfo) -> str:
|
||||
"""对单个服务进行健康检查"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# 尝试访问服务的健康检查端点
|
||||
health_url = f"{service_info.url.rstrip('/')}/health"
|
||||
async with session.get(health_url, timeout=aiohttp.ClientTimeout(total=10)) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
# 检查响应中的状态字段
|
||||
if isinstance(data, dict) and data.get("status") == "healthy":
|
||||
return ServiceStatus.HEALTHY
|
||||
return ServiceStatus.HEALTHY
|
||||
else:
|
||||
return ServiceStatus.UNHEALTHY
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Health check timeout for service {service_info.service_id}")
|
||||
return ServiceStatus.UNHEALTHY
|
||||
except Exception as e:
|
||||
logger.warning(f"Health check failed for service {service_info.service_id}: {str(e)}")
|
||||
return ServiceStatus.UNHEALTHY
|
||||
|
||||
async def health_check_all(self):
|
||||
"""对所有服务进行健康检查"""
|
||||
tasks = []
|
||||
for service_info in self._services.values():
|
||||
task = self.health_check_single(service_info)
|
||||
tasks.append(task)
|
||||
|
||||
if tasks:
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for service_info, result in zip(self._services.values(), results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Health check error for {service_info.service_id}: {result}")
|
||||
new_status = ServiceStatus.UNHEALTHY
|
||||
else:
|
||||
new_status = result
|
||||
|
||||
# 如果状态发生变化,则更新
|
||||
if service_info.status != new_status:
|
||||
logger.info(f"Service {service_info.service_id} status changed from {service_info.status} to {new_status}")
|
||||
self.update_service_status(service_info.service_id, new_status)
|
||||
|
||||
async def start_monitoring(self):
|
||||
"""启动服务监控"""
|
||||
if self._monitoring_task and not self._monitoring_task.done():
|
||||
logger.warning("Monitoring task already running")
|
||||
return
|
||||
|
||||
self._monitoring_task = asyncio.create_task(self._monitor_loop())
|
||||
logger.info("Service monitoring started")
|
||||
|
||||
async def stop_monitoring(self):
|
||||
"""停止服务监控"""
|
||||
if self._monitoring_task and not self._monitoring_task.done():
|
||||
self._monitoring_task.cancel()
|
||||
try:
|
||||
await self._monitoring_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("Service monitoring stopped")
|
||||
|
||||
async def _monitor_loop(self):
|
||||
"""监控循环"""
|
||||
while True:
|
||||
try:
|
||||
await self.health_check_all()
|
||||
await asyncio.sleep(self._health_check_interval)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Monitor loop cancelled")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in monitor loop: {str(e)}")
|
||||
await asyncio.sleep(self._health_check_interval)
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""获取服务统计信息"""
|
||||
total_services = len(self._services)
|
||||
healthy_count = sum(1 for s in self._services.values() if s.status == ServiceStatus.HEALTHY)
|
||||
unhealthy_count = sum(1 for s in self._services.values() if s.status == ServiceStatus.UNHEALTHY)
|
||||
unknown_count = sum(1 for s in self._services.values() if s.status == ServiceStatus.UNKNOWN)
|
||||
|
||||
return {
|
||||
"total_services": total_services,
|
||||
"healthy_services": healthy_count,
|
||||
"unhealthy_services": unhealthy_count,
|
||||
"unknown_services": unknown_count,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
# 全局服务管理器实例
|
||||
service_manager = ServiceManager()
|
||||
962
backend/app/services/service_orchestrator.py
Normal file
962
backend/app/services/service_orchestrator.py
Normal file
@@ -0,0 +1,962 @@
|
||||
"""服务编排服务,用于管理算法服务的生命周期"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import docker
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional
|
||||
from docker.errors import DockerException, NotFound
|
||||
|
||||
|
||||
class ServiceOrchestrator:
|
||||
"""服务编排服务"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化服务编排器"""
|
||||
try:
|
||||
# 连接Docker客户端
|
||||
self.client = docker.from_env()
|
||||
# 测试连接
|
||||
self.client.ping()
|
||||
print("Docker连接成功")
|
||||
except DockerException as e:
|
||||
print(f"Docker连接失败: {e}")
|
||||
self.client = None
|
||||
|
||||
def deploy_service(self, service_id: str, service_config: Dict[str, Any], project_info: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""部署服务
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
service_config: 服务配置
|
||||
project_info: 项目信息
|
||||
|
||||
Returns:
|
||||
部署结果
|
||||
"""
|
||||
try:
|
||||
if not self.client:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Docker连接失败",
|
||||
"service_id": service_id,
|
||||
"container_id": None,
|
||||
"status": "error",
|
||||
"api_url": None
|
||||
}
|
||||
|
||||
# 1. 构建Docker镜像
|
||||
image_name = self._build_docker_image(service_id, project_info, service_config)
|
||||
|
||||
# 2. 启动服务容器
|
||||
container_id = self._start_service_container(service_id, image_name, service_config)
|
||||
|
||||
# 3. 验证服务启动
|
||||
if not self._verify_service_startup(container_id, service_config):
|
||||
return {
|
||||
"success": False,
|
||||
"error": "服务启动验证失败",
|
||||
"service_id": service_id,
|
||||
"container_id": container_id,
|
||||
"status": "error",
|
||||
"api_url": None
|
||||
}
|
||||
|
||||
# 4. 构建API URL
|
||||
api_url = f"http://{service_config.get('host', 'localhost')}:{service_config.get('port', 8000)}"
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"service_id": service_id,
|
||||
"container_id": container_id,
|
||||
"status": "running",
|
||||
"api_url": api_url,
|
||||
"error": None
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"service_id": service_id,
|
||||
"container_id": None,
|
||||
"status": "error",
|
||||
"api_url": None
|
||||
}
|
||||
|
||||
def start_service(self, service_id: str, container_id: str) -> Dict[str, Any]:
|
||||
"""启动服务
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
container_id: 容器ID
|
||||
|
||||
Returns:
|
||||
启动结果
|
||||
"""
|
||||
try:
|
||||
if not self.client:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Docker连接失败",
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
}
|
||||
|
||||
# 获取容器
|
||||
container = self.client.containers.get(container_id)
|
||||
|
||||
# 启动容器
|
||||
container.start()
|
||||
|
||||
# 验证服务启动
|
||||
if not self._verify_service_health(container_id):
|
||||
return {
|
||||
"success": False,
|
||||
"error": "服务健康检查失败",
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"service_id": service_id,
|
||||
"status": "running",
|
||||
"error": None
|
||||
}
|
||||
except NotFound:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "容器不存在",
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
}
|
||||
|
||||
def stop_service(self, service_id: str, container_id: str) -> Dict[str, Any]:
|
||||
"""停止服务
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
container_id: 容器ID
|
||||
|
||||
Returns:
|
||||
停止结果
|
||||
"""
|
||||
try:
|
||||
if not self.client:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Docker连接失败",
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
}
|
||||
|
||||
# 获取容器
|
||||
container = self.client.containers.get(container_id)
|
||||
|
||||
# 停止容器
|
||||
container.stop(timeout=30)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"service_id": service_id,
|
||||
"status": "stopped",
|
||||
"error": None
|
||||
}
|
||||
except NotFound:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "容器不存在",
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
}
|
||||
|
||||
def restart_service(self, service_id: str, container_id: str) -> Dict[str, Any]:
|
||||
"""重启服务
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
container_id: 容器ID
|
||||
|
||||
Returns:
|
||||
重启结果
|
||||
"""
|
||||
try:
|
||||
if not self.client:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Docker连接失败",
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
}
|
||||
|
||||
# 获取容器
|
||||
container = self.client.containers.get(container_id)
|
||||
|
||||
# 重启容器
|
||||
container.restart(timeout=30)
|
||||
|
||||
# 验证服务启动
|
||||
if not self._verify_service_health(container_id):
|
||||
return {
|
||||
"success": False,
|
||||
"error": "服务健康检查失败",
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"service_id": service_id,
|
||||
"status": "running",
|
||||
"error": None
|
||||
}
|
||||
except NotFound:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "容器不存在",
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"service_id": service_id,
|
||||
"status": "error"
|
||||
}
|
||||
|
||||
def delete_service(self, service_id: str, container_id: str, image_name: str) -> Dict[str, Any]:
|
||||
"""删除服务
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
container_id: 容器ID
|
||||
image_name: 镜像名称
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
if not self.client:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Docker连接失败",
|
||||
"service_id": service_id
|
||||
}
|
||||
|
||||
# 停止并删除容器
|
||||
if container_id:
|
||||
try:
|
||||
container = self.client.containers.get(container_id)
|
||||
container.stop(timeout=10)
|
||||
container.remove(force=True)
|
||||
except NotFound:
|
||||
pass
|
||||
|
||||
# 删除镜像
|
||||
if image_name:
|
||||
try:
|
||||
self.client.images.remove(image_name, force=True)
|
||||
except:
|
||||
pass
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"service_id": service_id,
|
||||
"error": None
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"service_id": service_id
|
||||
}
|
||||
|
||||
def get_service_status(self, container_id: str) -> Dict[str, Any]:
|
||||
"""获取服务状态
|
||||
|
||||
Args:
|
||||
container_id: 容器ID
|
||||
|
||||
Returns:
|
||||
服务状态
|
||||
"""
|
||||
try:
|
||||
if not self.client:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Docker连接失败",
|
||||
"status": "unknown",
|
||||
"health": "unknown"
|
||||
}
|
||||
|
||||
# 获取容器
|
||||
container = self.client.containers.get(container_id)
|
||||
|
||||
# 获取容器状态
|
||||
status = container.status
|
||||
|
||||
# 检查服务健康状态
|
||||
health = "unknown"
|
||||
if status == "running":
|
||||
if self._verify_service_health(container_id):
|
||||
health = "healthy"
|
||||
else:
|
||||
health = "unhealthy"
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"status": status,
|
||||
"health": health,
|
||||
"error": None
|
||||
}
|
||||
except NotFound:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "容器不存在",
|
||||
"status": "not_found",
|
||||
"health": "unknown"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"status": "unknown",
|
||||
"health": "unknown"
|
||||
}
|
||||
|
||||
def get_service_logs(self, container_id: str, lines: int = 100) -> Dict[str, Any]:
|
||||
"""获取服务日志
|
||||
|
||||
Args:
|
||||
container_id: 容器ID
|
||||
lines: 日志行数
|
||||
|
||||
Returns:
|
||||
服务日志
|
||||
"""
|
||||
try:
|
||||
if not self.client:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Docker连接失败",
|
||||
"logs": []
|
||||
}
|
||||
|
||||
# 获取容器
|
||||
container = self.client.containers.get(container_id)
|
||||
|
||||
# 获取日志
|
||||
logs = container.logs(tail=lines).decode('utf-8').split('\n')
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"logs": logs,
|
||||
"error": None
|
||||
}
|
||||
except NotFound:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "容器不存在",
|
||||
"logs": []
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"logs": []
|
||||
}
|
||||
|
||||
def _build_docker_image(self, service_id: str, project_info: Dict[str, Any], service_config: Dict[str, Any]) -> str:
|
||||
"""构建Docker镜像
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
project_info: 项目信息
|
||||
service_config: 服务配置
|
||||
|
||||
Returns:
|
||||
镜像名称
|
||||
"""
|
||||
# 生成镜像名称
|
||||
image_name = f"algorithm-service-{service_id}:{service_config.get('version', '1.0.0')}"
|
||||
|
||||
# 构建上下文目录
|
||||
build_context = os.path.join("/tmp", f"service-build-{service_id}")
|
||||
os.makedirs(build_context, exist_ok=True)
|
||||
|
||||
try:
|
||||
# 创建Dockerfile
|
||||
dockerfile_content = self._generate_dockerfile(project_info, service_config)
|
||||
with open(os.path.join(build_context, "Dockerfile"), "w") as f:
|
||||
f.write(dockerfile_content)
|
||||
|
||||
# 复制项目文件(这里简化处理,实际应该复制完整项目)
|
||||
# 注意:在实际实现中,应该从算法仓库复制项目文件到构建上下文
|
||||
|
||||
# 创建服务包装器
|
||||
service_wrapper_content = self._generate_service_wrapper(project_info, service_config)
|
||||
wrapper_extension = ".py" if project_info["project_type"] == "python" else ".js"
|
||||
with open(os.path.join(build_context, f"service_wrapper{wrapper_extension}"), "w") as f:
|
||||
f.write(service_wrapper_content)
|
||||
|
||||
# 创建依赖文件
|
||||
self._create_dependency_file(build_context, project_info)
|
||||
|
||||
# 构建镜像
|
||||
print(f"开始构建Docker镜像: {image_name}")
|
||||
image, logs = self.client.images.build(
|
||||
path=build_context,
|
||||
tag=image_name,
|
||||
rm=True,
|
||||
pull=False
|
||||
)
|
||||
|
||||
# 打印构建日志
|
||||
for log in logs:
|
||||
if 'stream' in log:
|
||||
print(log['stream'], end='')
|
||||
|
||||
print(f"Docker镜像构建成功: {image_name}")
|
||||
return image_name
|
||||
finally:
|
||||
# 清理构建上下文
|
||||
import shutil
|
||||
try:
|
||||
shutil.rmtree(build_context)
|
||||
except:
|
||||
pass
|
||||
|
||||
def _start_service_container(self, service_id: str, image_name: str, service_config: Dict[str, Any]) -> str:
|
||||
"""启动服务容器
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
image_name: 镜像名称
|
||||
service_config: 服务配置
|
||||
|
||||
Returns:
|
||||
容器ID
|
||||
"""
|
||||
# 容器名称
|
||||
container_name = f"algorithm-service-{service_id}"
|
||||
|
||||
# 端口映射
|
||||
ports = {
|
||||
f"{service_config.get('port', 8000)}/tcp": service_config.get('port', 8000)
|
||||
}
|
||||
|
||||
# 环境变量
|
||||
environment = {
|
||||
"HOST": service_config.get('host', '0.0.0.0'),
|
||||
"PORT": str(service_config.get('port', 8000)),
|
||||
"TIMEOUT": str(service_config.get('timeout', 30))
|
||||
}
|
||||
|
||||
# 启动容器
|
||||
container = self.client.containers.run(
|
||||
image_name,
|
||||
name=container_name,
|
||||
ports=ports,
|
||||
environment=environment,
|
||||
detach=True,
|
||||
restart_policy={"Name": "unless-stopped"}
|
||||
)
|
||||
|
||||
print(f"容器启动成功: {container.id}")
|
||||
return container.id
|
||||
|
||||
def _verify_service_startup(self, container_id: str, service_config: Dict[str, Any]) -> bool:
|
||||
"""验证服务启动
|
||||
|
||||
Args:
|
||||
container_id: 容器ID
|
||||
service_config: 服务配置
|
||||
|
||||
Returns:
|
||||
是否启动成功
|
||||
"""
|
||||
# 等待服务启动
|
||||
time.sleep(5)
|
||||
|
||||
# 验证服务健康状态
|
||||
return self._verify_service_health(container_id)
|
||||
|
||||
def _verify_service_health(self, container_id: str) -> bool:
|
||||
"""验证服务健康状态
|
||||
|
||||
Args:
|
||||
container_id: 容器ID
|
||||
|
||||
Returns:
|
||||
是否健康
|
||||
"""
|
||||
try:
|
||||
container = self.client.containers.get(container_id)
|
||||
|
||||
# 检查容器状态
|
||||
if container.status != "running":
|
||||
return False
|
||||
|
||||
# 这里可以添加更详细的健康检查,例如发送HTTP请求到/health端点
|
||||
# 简化处理,只检查容器状态
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def _generate_dockerfile(self, project_info: Dict[str, Any], service_config: Dict[str, Any]) -> str:
|
||||
"""生成Dockerfile
|
||||
|
||||
Args:
|
||||
project_info: 项目信息
|
||||
service_config: 服务配置
|
||||
|
||||
Returns:
|
||||
Dockerfile内容
|
||||
"""
|
||||
dockerfile_templates = {
|
||||
"python": """
|
||||
FROM python:3.9-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 复制依赖文件
|
||||
COPY requirements.txt .
|
||||
|
||||
# 安装依赖
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 复制项目文件
|
||||
COPY . .
|
||||
|
||||
# 复制生成的服务包装器
|
||||
COPY service_wrapper.py .
|
||||
|
||||
# 设置环境变量
|
||||
ENV HOST={{host}}
|
||||
ENV PORT={{port}}
|
||||
ENV TIMEOUT={{timeout}}
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE {{port}}
|
||||
|
||||
# 启动服务
|
||||
CMD ["python", "service_wrapper.py"]
|
||||
""",
|
||||
"nodejs": """
|
||||
FROM node:16-alpine
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 复制依赖文件
|
||||
COPY package.json package-lock.json* .
|
||||
|
||||
# 安装依赖
|
||||
RUN npm install --production
|
||||
|
||||
# 复制项目文件
|
||||
COPY . .
|
||||
|
||||
# 复制生成的服务包装器
|
||||
COPY service_wrapper.js .
|
||||
|
||||
# 设置环境变量
|
||||
ENV HOST={{host}}
|
||||
ENV PORT={{port}}
|
||||
ENV TIMEOUT={{timeout}}
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE {{port}}
|
||||
|
||||
# 启动服务
|
||||
CMD ["node", "service_wrapper.js"]
|
||||
"""
|
||||
}
|
||||
|
||||
template = dockerfile_templates.get(project_info["project_type"], dockerfile_templates["python"])
|
||||
|
||||
# 替换模板变量
|
||||
template = template.replace("{{host}}", service_config.get("host", "0.0.0.0"))
|
||||
template = template.replace("{{port}}", str(service_config.get("port", 8000)))
|
||||
template = template.replace("{{timeout}}", str(service_config.get("timeout", 30)))
|
||||
|
||||
return template
|
||||
|
||||
def _generate_service_wrapper(self, project_info: Dict[str, Any], service_config: Dict[str, Any]) -> str:
|
||||
"""生成服务包装器
|
||||
|
||||
Args:
|
||||
project_info: 项目信息
|
||||
service_config: 服务配置
|
||||
|
||||
Returns:
|
||||
服务包装器代码
|
||||
"""
|
||||
if project_info["project_type"] == "python":
|
||||
return '''
|
||||
# Python HTTP服务包装器
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
|
||||
# 添加项目路径到Python路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
# 尝试导入算法模块
|
||||
try:
|
||||
# 尝试导入主要模块
|
||||
import algorithm
|
||||
algorithm_module = algorithm
|
||||
print("算法模块导入成功")
|
||||
except Exception as e:
|
||||
print(f"算法模块导入失败: {e}")
|
||||
algorithm_module = None
|
||||
|
||||
# 服务配置
|
||||
HOST = os.environ.get("HOST", "0.0.0.0")
|
||||
PORT = int(os.environ.get("PORT", "8000"))
|
||||
TIMEOUT = int(os.environ.get("TIMEOUT", "30"))
|
||||
|
||||
class AlgorithmRequestHandler(BaseHTTPRequestHandler):
|
||||
"""算法请求处理器"""
|
||||
|
||||
def do_POST(self):
|
||||
"""处理POST请求"""
|
||||
try:
|
||||
# 读取请求体
|
||||
content_length = int(self.headers['Content-Length'])
|
||||
post_data = self.rfile.read(content_length)
|
||||
|
||||
# 解析请求数据
|
||||
request_data = json.loads(post_data.decode('utf-8'))
|
||||
|
||||
# 记录请求开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 调用算法
|
||||
result = self._call_algorithm(request_data)
|
||||
|
||||
# 计算响应时间
|
||||
response_time = time.time() - start_time
|
||||
|
||||
# 构建响应
|
||||
response = {
|
||||
"success": True,
|
||||
"result": result,
|
||||
"response_time": round(response_time, 4),
|
||||
"message": "算法执行成功"
|
||||
}
|
||||
|
||||
# 发送响应
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(response).encode('utf-8'))
|
||||
|
||||
except Exception as e:
|
||||
# 构建错误响应
|
||||
error_response = {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": "算法执行失败"
|
||||
}
|
||||
|
||||
# 发送错误响应
|
||||
self.send_response(400)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(error_response).encode('utf-8'))
|
||||
|
||||
def do_GET(self):
|
||||
"""处理GET请求"""
|
||||
if self.path == "/health":
|
||||
# 健康检查
|
||||
self._handle_health_check()
|
||||
elif self.path == "/info":
|
||||
# 服务信息
|
||||
self._handle_info()
|
||||
else:
|
||||
# 404响应
|
||||
self.send_response(404)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps({"error": "Not Found"}).encode('utf-8'))
|
||||
|
||||
def _call_algorithm(self, request_data):
|
||||
"""调用算法
|
||||
|
||||
Args:
|
||||
request_data: 请求数据
|
||||
|
||||
Returns:
|
||||
算法执行结果
|
||||
"""
|
||||
if algorithm_module is None:
|
||||
raise Exception("算法模块未加载")
|
||||
|
||||
# 尝试调用算法的主要函数
|
||||
try:
|
||||
# 检查是否有predict函数
|
||||
if hasattr(algorithm_module, 'predict'):
|
||||
return algorithm_module.predict(request_data)
|
||||
# 检查是否有run函数
|
||||
elif hasattr(algorithm_module, 'run'):
|
||||
return algorithm_module.run(request_data)
|
||||
# 检查是否有main函数
|
||||
elif hasattr(algorithm_module, 'main'):
|
||||
return algorithm_module.main(request_data)
|
||||
else:
|
||||
raise Exception("未找到算法执行函数")
|
||||
except Exception as e:
|
||||
raise Exception(f"算法执行失败: {e}")
|
||||
|
||||
def _handle_health_check(self):
|
||||
"""处理健康检查"""
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps({"status": "healthy", "service": "{{service_name}}"}).encode('utf-8'))
|
||||
|
||||
def _handle_info(self):
|
||||
"""处理服务信息请求"""
|
||||
info = {
|
||||
"service": "{{service_name}}",
|
||||
"version": "{{service_version}}",
|
||||
"host": HOST,
|
||||
"port": PORT,
|
||||
"timeout": TIMEOUT,
|
||||
"algorithm_loaded": algorithm_module is not None
|
||||
}
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(info).encode('utf-8'))
|
||||
|
||||
def run_server():
|
||||
"""启动服务"""
|
||||
server = HTTPServer((HOST, PORT), AlgorithmRequestHandler)
|
||||
print(f"服务启动成功,监听地址: {HOST}:{PORT}")
|
||||
print(f"健康检查地址: http://{HOST}:{PORT}/health")
|
||||
print(f"服务信息地址: http://{HOST}:{PORT}/info")
|
||||
|
||||
try:
|
||||
server.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
print("服务停止")
|
||||
server.shutdown()
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_server()
|
||||
'''.replace("{{service_name}}", service_config.get("name", "algorithm-service")).replace("{{service_version}}", service_config.get("version", "1.0.0"))
|
||||
else:
|
||||
return '''
|
||||
// Node.js HTTP服务包装器
|
||||
|
||||
const http = require('http');
|
||||
const url = require('url');
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
|
||||
// 服务配置
|
||||
const HOST = process.env.HOST || '0.0.0.0';
|
||||
const PORT = process.env.PORT || 8000;
|
||||
const TIMEOUT = process.env.TIMEOUT || 30;
|
||||
|
||||
// 尝试导入算法模块
|
||||
let algorithmModule = null;
|
||||
try {
|
||||
// 尝试导入主要模块
|
||||
algorithmModule = require('./algorithm');
|
||||
console.log('算法模块导入成功');
|
||||
} catch (e) {
|
||||
console.error('算法模块导入失败:', e);
|
||||
algorithmModule = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 调用算法
|
||||
* @param {Object} requestData 请求数据
|
||||
* @returns {Object} 算法执行结果
|
||||
*/
|
||||
function callAlgorithm(requestData) {
|
||||
if (!algorithmModule) {
|
||||
throw new Error('算法模块未加载');
|
||||
}
|
||||
|
||||
// 尝试调用算法的主要函数
|
||||
try {
|
||||
if (algorithmModule.predict) {
|
||||
return algorithmModule.predict(requestData);
|
||||
} else if (algorithmModule.run) {
|
||||
return algorithmModule.run(requestData);
|
||||
} else if (algorithmModule.main) {
|
||||
return algorithmModule.main(requestData);
|
||||
} else {
|
||||
throw new Error('未找到算法执行函数');
|
||||
}
|
||||
} catch (e) {
|
||||
throw new Error(`算法执行失败: ${e.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理请求
|
||||
* @param {http.IncomingMessage} req 请求对象
|
||||
* @param {http.ServerResponse} res 响应对象
|
||||
*/
|
||||
function handleRequest(req, res) {
|
||||
const parsedUrl = url.parse(req.url, true);
|
||||
const pathname = parsedUrl.pathname;
|
||||
|
||||
// 设置响应头
|
||||
res.setHeader('Content-Type', 'application/json');
|
||||
|
||||
if (req.method === 'GET') {
|
||||
if (pathname === '/health') {
|
||||
// 健康检查
|
||||
res.writeHead(200);
|
||||
res.end(JSON.stringify({ status: 'healthy', service: '{{service_name}}' }));
|
||||
} else if (pathname === '/info') {
|
||||
// 服务信息
|
||||
const info = {
|
||||
service: '{{service_name}}',
|
||||
version: '{{service_version}}',
|
||||
host: HOST,
|
||||
port: PORT,
|
||||
timeout: TIMEOUT,
|
||||
algorithm_loaded: algorithmModule !== null
|
||||
};
|
||||
res.writeHead(200);
|
||||
res.end(JSON.stringify(info));
|
||||
} else {
|
||||
// 404响应
|
||||
res.writeHead(404);
|
||||
res.end(JSON.stringify({ error: 'Not Found' }));
|
||||
}
|
||||
} else if (req.method === 'POST') {
|
||||
let body = '';
|
||||
|
||||
// 读取请求体
|
||||
req.on('data', chunk => {
|
||||
body += chunk.toString();
|
||||
});
|
||||
|
||||
req.on('end', () => {
|
||||
try {
|
||||
// 解析请求数据
|
||||
const requestData = JSON.parse(body);
|
||||
|
||||
// 记录请求开始时间
|
||||
const startTime = Date.now();
|
||||
|
||||
// 调用算法
|
||||
const result = callAlgorithm(requestData);
|
||||
|
||||
// 计算响应时间
|
||||
const responseTime = (Date.now() - startTime) / 1000;
|
||||
|
||||
// 构建响应
|
||||
const response = {
|
||||
success: true,
|
||||
result: result,
|
||||
response_time: responseTime.toFixed(4),
|
||||
message: '算法执行成功'
|
||||
};
|
||||
|
||||
// 发送响应
|
||||
res.writeHead(200);
|
||||
res.end(JSON.stringify(response));
|
||||
} catch (e) {
|
||||
// 构建错误响应
|
||||
const errorResponse = {
|
||||
success: false,
|
||||
error: e.message,
|
||||
message: '算法执行失败'
|
||||
};
|
||||
|
||||
// 发送错误响应
|
||||
res.writeHead(400);
|
||||
res.end(JSON.stringify(errorResponse));
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// 不支持的方法
|
||||
res.writeHead(405);
|
||||
res.end(JSON.stringify({ error: 'Method Not Allowed' }));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 启动服务
|
||||
*/
|
||||
function startServer() {
|
||||
const server = http.createServer(handleRequest);
|
||||
|
||||
server.listen(PORT, HOST, () => {
|
||||
console.log(`服务启动成功,监听地址: ${HOST}:${PORT}`);
|
||||
console.log(`健康检查地址: http://${HOST}:${PORT}/health`);
|
||||
console.log(`服务信息地址: http://${HOST}:${PORT}/info`);
|
||||
});
|
||||
|
||||
server.on('error', (error) => {
|
||||
console.error('服务启动失败:', error);
|
||||
});
|
||||
}
|
||||
|
||||
// 启动服务
|
||||
startServer();
|
||||
'''.replace("{{service_name}}", service_config.get("name", "algorithm-service")).replace("{{service_version}}", service_config.get("version", "1.0.0"))
|
||||
|
||||
def _create_dependency_file(self, build_context: str, project_info: Dict[str, Any]):
|
||||
"""创建依赖文件
|
||||
|
||||
Args:
|
||||
build_context: 构建上下文目录
|
||||
project_info: 项目信息
|
||||
"""
|
||||
if project_info["project_type"] == "python":
|
||||
# 创建requirements.txt
|
||||
with open(os.path.join(build_context, "requirements.txt"), "w") as f:
|
||||
f.write("""
|
||||
# 基础依赖
|
||||
http.server
|
||||
json
|
||||
|
||||
# 算法依赖
|
||||
# 注意:在实际实现中,应该从项目的requirements.txt复制依赖
|
||||
""")
|
||||
elif project_info["project_type"] == "nodejs":
|
||||
# 创建package.json
|
||||
package_data = {
|
||||
"name": "algorithm-service",
|
||||
"version": "1.0.0",
|
||||
"description": "Algorithm service wrapper",
|
||||
"main": "service_wrapper.js",
|
||||
"scripts": {
|
||||
"start": "node service_wrapper.js"
|
||||
},
|
||||
"dependencies": {
|
||||
# 基础依赖
|
||||
}
|
||||
}
|
||||
with open(os.path.join(build_context, "package.json"), "w") as f:
|
||||
json.dump(package_data, f, indent=2)
|
||||
162
backend/app/services/templates/nodejs_http_service.js.j2
Normal file
162
backend/app/services/templates/nodejs_http_service.js.j2
Normal file
@@ -0,0 +1,162 @@
|
||||
|
||||
// Node.js HTTP服务包装器
|
||||
|
||||
const http = require('http');
|
||||
const url = require('url');
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
|
||||
// 服务配置
|
||||
const HOST = process.env.HOST || '{{host}}';
|
||||
const PORT = process.env.PORT || {{port}};
|
||||
const TIMEOUT = process.env.TIMEOUT || {{timeout}};
|
||||
|
||||
// 尝试导入算法模块
|
||||
let algorithmModule = null;
|
||||
try {
|
||||
// 根据入口点导入算法
|
||||
if ("{{entry_point}}" === "") {
|
||||
// 尝试导入主要模块
|
||||
algorithmModule = require('./algorithm');
|
||||
} else {
|
||||
// 导入入口点
|
||||
algorithmModule = require('./{{entry_point}}');
|
||||
}
|
||||
console.log('算法模块导入成功');
|
||||
} catch (e) {
|
||||
console.error('算法模块导入失败:', e);
|
||||
algorithmModule = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 调用算法
|
||||
* @param {Object} requestData 请求数据
|
||||
* @returns {Object} 算法执行结果
|
||||
*/
|
||||
function callAlgorithm(requestData) {
|
||||
if (!algorithmModule) {
|
||||
throw new Error('算法模块未加载');
|
||||
}
|
||||
|
||||
// 尝试调用算法的主要函数
|
||||
try {
|
||||
if (algorithmModule.predict) {
|
||||
return algorithmModule.predict(requestData);
|
||||
} else if (algorithmModule.run) {
|
||||
return algorithmModule.run(requestData);
|
||||
} else if (algorithmModule.main) {
|
||||
return algorithmModule.main(requestData);
|
||||
} else {
|
||||
throw new Error('未找到算法执行函数');
|
||||
}
|
||||
} catch (e) {
|
||||
throw new Error(`算法执行失败: ${e.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理请求
|
||||
* @param {http.IncomingMessage} req 请求对象
|
||||
* @param {http.ServerResponse} res 响应对象
|
||||
*/
|
||||
function handleRequest(req, res) {
|
||||
const parsedUrl = url.parse(req.url, true);
|
||||
const pathname = parsedUrl.pathname;
|
||||
|
||||
// 设置响应头
|
||||
res.setHeader('Content-Type', 'application/json');
|
||||
|
||||
if (req.method === 'GET') {
|
||||
if (pathname === "{{health_check_path}}") {
|
||||
// 健康检查
|
||||
res.writeHead(200);
|
||||
res.end(JSON.stringify({ status: 'healthy', service: '{{project_name}}' }));
|
||||
} else if (pathname === '/info') {
|
||||
// 服务信息
|
||||
const info = {
|
||||
service: '{{project_name}}',
|
||||
version: '1.0.0',
|
||||
host: HOST,
|
||||
port: PORT,
|
||||
timeout: TIMEOUT,
|
||||
algorithm_loaded: algorithmModule !== null
|
||||
};
|
||||
res.writeHead(200);
|
||||
res.end(JSON.stringify(info));
|
||||
} else {
|
||||
// 404响应
|
||||
res.writeHead(404);
|
||||
res.end(JSON.stringify({ error: 'Not Found' }));
|
||||
}
|
||||
} else if (req.method === 'POST') {
|
||||
let body = '';
|
||||
|
||||
// 读取请求体
|
||||
req.on('data', chunk => {
|
||||
body += chunk.toString();
|
||||
});
|
||||
|
||||
req.on('end', () => {
|
||||
try {
|
||||
// 解析请求数据
|
||||
const requestData = JSON.parse(body);
|
||||
|
||||
// 记录请求开始时间
|
||||
const startTime = Date.now();
|
||||
|
||||
// 调用算法
|
||||
const result = callAlgorithm(requestData);
|
||||
|
||||
// 计算响应时间
|
||||
const responseTime = (Date.now() - startTime) / 1000;
|
||||
|
||||
// 构建响应
|
||||
const response = {
|
||||
success: true,
|
||||
result: result,
|
||||
response_time: responseTime.toFixed(4),
|
||||
message: '算法执行成功'
|
||||
};
|
||||
|
||||
// 发送响应
|
||||
res.writeHead(200);
|
||||
res.end(JSON.stringify(response));
|
||||
} catch (e) {
|
||||
// 构建错误响应
|
||||
const errorResponse = {
|
||||
success: false,
|
||||
error: e.message,
|
||||
message: '算法执行失败'
|
||||
};
|
||||
|
||||
// 发送错误响应
|
||||
res.writeHead(400);
|
||||
res.end(JSON.stringify(errorResponse));
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// 不支持的方法
|
||||
res.writeHead(405);
|
||||
res.end(JSON.stringify({ error: 'Method Not Allowed' }));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 启动服务
|
||||
*/
|
||||
function startServer() {
|
||||
const server = http.createServer(handleRequest);
|
||||
|
||||
server.listen(PORT, HOST, () => {
|
||||
console.log(`服务启动成功,监听地址: ${HOST}:${PORT}`);
|
||||
console.log(`健康检查地址: http://${HOST}:${PORT}{{health_check_path}}`);
|
||||
console.log(`服务信息地址: http://${HOST}:${PORT}/info`);
|
||||
});
|
||||
|
||||
server.on('error', (error) => {
|
||||
console.error('服务启动失败:', error);
|
||||
});
|
||||
}
|
||||
|
||||
// 启动服务
|
||||
startServer();
|
||||
166
backend/app/services/templates/python_http_service.py.j2
Normal file
166
backend/app/services/templates/python_http_service.py.j2
Normal file
@@ -0,0 +1,166 @@
|
||||
|
||||
# Python HTTP服务包装器
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
|
||||
# 添加项目路径到Python路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
# 尝试导入算法模块
|
||||
try:
|
||||
# 根据入口点导入算法
|
||||
if "{{entry_point}}" == "":
|
||||
# 尝试导入主要模块
|
||||
import algorithm
|
||||
algorithm_module = algorithm
|
||||
else:
|
||||
# 动态导入入口点
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location("algorithm_module", "{{entry_point}}")
|
||||
algorithm_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(algorithm_module)
|
||||
print("算法模块导入成功")
|
||||
except Exception as e:
|
||||
print(f"算法模块导入失败: {e}")
|
||||
algorithm_module = None
|
||||
|
||||
# 服务配置
|
||||
HOST = os.environ.get("HOST", "{{host}}")
|
||||
PORT = int(os.environ.get("PORT", "{{port}}"))
|
||||
TIMEOUT = int(os.environ.get("TIMEOUT", "{{timeout}}"))
|
||||
|
||||
class AlgorithmRequestHandler(BaseHTTPRequestHandler):
|
||||
"""算法请求处理器"""
|
||||
|
||||
def do_POST(self):
|
||||
"""处理POST请求"""
|
||||
try:
|
||||
# 读取请求体
|
||||
content_length = int(self.headers['Content-Length'])
|
||||
post_data = self.rfile.read(content_length)
|
||||
|
||||
# 解析请求数据
|
||||
request_data = json.loads(post_data.decode('utf-8'))
|
||||
|
||||
# 记录请求开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 调用算法
|
||||
result = self._call_algorithm(request_data)
|
||||
|
||||
# 计算响应时间
|
||||
response_time = time.time() - start_time
|
||||
|
||||
# 构建响应
|
||||
response = {
|
||||
"success": True,
|
||||
"result": result,
|
||||
"response_time": round(response_time, 4),
|
||||
"message": "算法执行成功"
|
||||
}
|
||||
|
||||
# 发送响应
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(response).encode('utf-8'))
|
||||
|
||||
except Exception as e:
|
||||
# 构建错误响应
|
||||
error_response = {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": "算法执行失败"
|
||||
}
|
||||
|
||||
# 发送错误响应
|
||||
self.send_response(400)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(error_response).encode('utf-8'))
|
||||
|
||||
def do_GET(self):
|
||||
"""处理GET请求"""
|
||||
if self.path == "{{health_check_path}}":
|
||||
# 健康检查
|
||||
self._handle_health_check()
|
||||
elif self.path == "/info":
|
||||
# 服务信息
|
||||
self._handle_info()
|
||||
else:
|
||||
# 404响应
|
||||
self.send_response(404)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps({"error": "Not Found"}).encode('utf-8'))
|
||||
|
||||
def _call_algorithm(self, request_data):
|
||||
"""调用算法
|
||||
|
||||
Args:
|
||||
request_data: 请求数据
|
||||
|
||||
Returns:
|
||||
算法执行结果
|
||||
"""
|
||||
if algorithm_module is None:
|
||||
raise Exception("算法模块未加载")
|
||||
|
||||
# 尝试调用算法的主要函数
|
||||
try:
|
||||
# 检查是否有predict函数
|
||||
if hasattr(algorithm_module, 'predict'):
|
||||
return algorithm_module.predict(request_data)
|
||||
# 检查是否有run函数
|
||||
elif hasattr(algorithm_module, 'run'):
|
||||
return algorithm_module.run(request_data)
|
||||
# 检查是否有main函数
|
||||
elif hasattr(algorithm_module, 'main'):
|
||||
return algorithm_module.main(request_data)
|
||||
else:
|
||||
raise Exception("未找到算法执行函数")
|
||||
except Exception as e:
|
||||
raise Exception(f"算法执行失败: {e}")
|
||||
|
||||
def _handle_health_check(self):
|
||||
"""处理健康检查"""
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps({"status": "healthy", "service": "{{project_name}}"}).encode('utf-8'))
|
||||
|
||||
def _handle_info(self):
|
||||
"""处理服务信息请求"""
|
||||
info = {
|
||||
"service": "{{project_name}}",
|
||||
"version": "1.0.0",
|
||||
"host": HOST,
|
||||
"port": PORT,
|
||||
"timeout": TIMEOUT,
|
||||
"algorithm_loaded": algorithm_module is not None
|
||||
}
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(info).encode('utf-8'))
|
||||
|
||||
def run_server():
|
||||
"""启动服务"""
|
||||
server = HTTPServer((HOST, PORT), AlgorithmRequestHandler)
|
||||
print(f"服务启动成功,监听地址: {HOST}:{PORT}")
|
||||
print(f"健康检查地址: http://{HOST}:{PORT}{{health_check_path}}")
|
||||
print(f"服务信息地址: http://{HOST}:{PORT}/info")
|
||||
|
||||
try:
|
||||
server.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
print("服务停止")
|
||||
server.shutdown()
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_server()
|
||||
238
backend/app/services/user.py
Normal file
238
backend/app/services/user.py
Normal file
@@ -0,0 +1,238 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.models.models import User, APIKey
|
||||
from app.schemas.user import UserCreate, UserUpdate, TokenData, APIKeyCreate
|
||||
from app.utils.cache import cache
|
||||
|
||||
# 密码加密上下文,使用pbkdf2_sha256方案,避免bcrypt的密码长度限制
|
||||
pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
|
||||
|
||||
|
||||
class UserService:
|
||||
"""用户服务类"""
|
||||
|
||||
@staticmethod
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证密码"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
@staticmethod
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""获取密码哈希值"""
|
||||
# 确保密码长度不超过72个字节
|
||||
password = password[:72]
|
||||
return pwd_context.hash(password)
|
||||
|
||||
@staticmethod
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""创建访问令牌"""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
@staticmethod
|
||||
def logout_user(token: str) -> None:
|
||||
"""用户登出,将令牌加入黑名单"""
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
exp = payload.get("exp")
|
||||
if exp:
|
||||
# 计算令牌剩余有效期
|
||||
remaining_time = exp - int(datetime.utcnow().timestamp())
|
||||
if remaining_time > 0:
|
||||
# 将令牌加入黑名单,设置与令牌剩余有效期相同的过期时间
|
||||
cache.set(f"blacklist:{token}", "1", expire=remaining_time)
|
||||
except JWTError:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def is_token_blacklisted(token: str) -> bool:
|
||||
"""检查令牌是否在黑名单中"""
|
||||
return cache.exists(f"blacklist:{token}")
|
||||
|
||||
@staticmethod
|
||||
def get_current_user(db: Session, token: str) -> Optional[User]:
|
||||
"""获取当前用户"""
|
||||
try:
|
||||
# 检查令牌是否在黑名单中
|
||||
if UserService.is_token_blacklisted(token):
|
||||
return None
|
||||
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
return None
|
||||
token_data = TokenData(username=username, user_id=payload.get("user_id"))
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
user = UserService.get_user_by_username(db, username=token_data.username)
|
||||
if user is None:
|
||||
return None
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_username(db: Session, username: str) -> Optional[User]:
|
||||
"""通过用户名获取用户"""
|
||||
return db.query(User).filter(User.username == username).first()
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(db: Session, user_id: str) -> Optional[User]:
|
||||
"""通过ID获取用户"""
|
||||
return db.query(User).filter(User.id == user_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_email(db: Session, email: str) -> Optional[User]:
|
||||
"""通过邮箱获取用户"""
|
||||
return db.query(User).filter(User.email == email).first()
|
||||
|
||||
@staticmethod
|
||||
def get_users(db: Session, skip: int = 0, limit: int = 100) -> List[User]:
|
||||
"""获取用户列表"""
|
||||
return db.query(User).offset(skip).limit(limit).all()
|
||||
|
||||
@staticmethod
|
||||
def create_user(db: Session, user: UserCreate) -> User:
|
||||
"""创建用户"""
|
||||
# 生成唯一ID
|
||||
user_id = f"user-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 创建用户实例
|
||||
db_user = User(
|
||||
id=user_id,
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
password_hash=UserService.get_password_hash(user.password),
|
||||
role=user.role
|
||||
)
|
||||
|
||||
# 保存到数据库
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
|
||||
return db_user
|
||||
|
||||
@staticmethod
|
||||
def update_user(db: Session, user_id: str, user_update: UserUpdate) -> Optional[User]:
|
||||
"""更新用户"""
|
||||
# 获取用户
|
||||
db_user = UserService.get_user_by_id(db, user_id)
|
||||
if not db_user:
|
||||
return None
|
||||
|
||||
# 更新用户信息
|
||||
update_data = user_update.dict(exclude_unset=True)
|
||||
|
||||
# 如果更新密码,需要重新哈希
|
||||
if "password" in update_data:
|
||||
update_data["password_hash"] = UserService.get_password_hash(update_data.pop("password"))
|
||||
|
||||
# 应用更新
|
||||
for field, value in update_data.items():
|
||||
setattr(db_user, field, value)
|
||||
|
||||
# 保存到数据库
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
|
||||
return db_user
|
||||
|
||||
@staticmethod
|
||||
def authenticate_user(db: Session, username: str, password: str) -> Optional[User]:
|
||||
"""认证用户"""
|
||||
user = UserService.get_user_by_username(db, username)
|
||||
if not user:
|
||||
return None
|
||||
if not UserService.verify_password(password, user.password_hash):
|
||||
return None
|
||||
return user
|
||||
|
||||
|
||||
class APIKeyService:
|
||||
"""API密钥服务类"""
|
||||
|
||||
@staticmethod
|
||||
def create_api_key(db: Session, api_key_create: APIKeyCreate) -> APIKey:
|
||||
"""创建API密钥"""
|
||||
# 生成唯一ID和密钥
|
||||
api_key_id = f"key-{uuid.uuid4().hex[:8]}"
|
||||
api_key_value = f"sk_{uuid.uuid4().hex}"
|
||||
|
||||
# 创建API密钥实例
|
||||
db_api_key = APIKey(
|
||||
id=api_key_id,
|
||||
user_id=api_key_create.user_id,
|
||||
name=api_key_create.name,
|
||||
key=api_key_value,
|
||||
expires_at=api_key_create.expires_at
|
||||
)
|
||||
|
||||
# 保存到数据库
|
||||
db.add(db_api_key)
|
||||
db.commit()
|
||||
db.refresh(db_api_key)
|
||||
|
||||
return db_api_key
|
||||
|
||||
@staticmethod
|
||||
def get_api_key_by_id(db: Session, api_key_id: str) -> Optional[APIKey]:
|
||||
"""通过ID获取API密钥"""
|
||||
return db.query(APIKey).filter(APIKey.id == api_key_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_api_key_by_value(db: Session, api_key_value: str) -> Optional[APIKey]:
|
||||
"""通过值获取API密钥"""
|
||||
return db.query(APIKey).filter(APIKey.key == api_key_value).first()
|
||||
|
||||
@staticmethod
|
||||
def get_api_keys_by_user_id(db: Session, user_id: str) -> List[APIKey]:
|
||||
"""通过用户ID获取API密钥列表"""
|
||||
return db.query(APIKey).filter(APIKey.user_id == user_id).all()
|
||||
|
||||
@staticmethod
|
||||
def revoke_api_key(db: Session, api_key_id: str) -> Optional[APIKey]:
|
||||
"""撤销API密钥"""
|
||||
# 获取API密钥
|
||||
db_api_key = APIKeyService.get_api_key_by_id(db, api_key_id)
|
||||
if not db_api_key:
|
||||
return None
|
||||
|
||||
# 更新状态为已撤销
|
||||
db_api_key.status = "revoked"
|
||||
|
||||
# 保存到数据库
|
||||
db.commit()
|
||||
db.refresh(db_api_key)
|
||||
|
||||
return db_api_key
|
||||
|
||||
@staticmethod
|
||||
def validate_api_key(db: Session, api_key_value: str) -> Optional[APIKey]:
|
||||
"""验证API密钥"""
|
||||
# 获取API密钥
|
||||
api_key = APIKeyService.get_api_key_by_value(db, api_key_value)
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
# 检查状态
|
||||
if api_key.status != "active":
|
||||
return None
|
||||
|
||||
# 检查是否过期
|
||||
if api_key.expires_at < datetime.utcnow():
|
||||
return None
|
||||
|
||||
return api_key
|
||||
Reference in New Issue
Block a user