final version
This commit is contained in:
@@ -1,867 +0,0 @@
|
||||
"""Gitea服务,处理与Gitea相关的业务逻辑"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import logging
|
||||
import base64
|
||||
from typing import Optional, Dict, Any, List
|
||||
import uuid
|
||||
|
||||
from app.gitea.client import GiteaClient
|
||||
from app.config.settings import settings
|
||||
from app.models.database import SessionLocal
|
||||
from app.models.models import GiteaConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GiteaService:
|
||||
"""Gitea服务类"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化Gitea服务"""
|
||||
self.config = self._load_config()
|
||||
self.client = None
|
||||
if self.config:
|
||||
self.client = GiteaClient(
|
||||
self.config['server_url'],
|
||||
self.config['access_token']
|
||||
)
|
||||
|
||||
def _load_config(self) -> Optional[Dict[str, Any]]:
|
||||
"""加载Gitea配置
|
||||
|
||||
Returns:
|
||||
Gitea配置信息
|
||||
"""
|
||||
try:
|
||||
db = SessionLocal()
|
||||
# 从数据库中获取配置(只取第一个配置)
|
||||
config = db.query(GiteaConfig).filter_by(status="active").first()
|
||||
db.close()
|
||||
|
||||
if config:
|
||||
return {
|
||||
'id': config.id,
|
||||
'server_url': config.server_url,
|
||||
'access_token': config.access_token,
|
||||
'default_owner': config.default_owner,
|
||||
'repo_prefix': config.repo_prefix,
|
||||
'status': config.status
|
||||
}
|
||||
|
||||
# 配置不存在时返回默认值
|
||||
return {
|
||||
'server_url': getattr(settings, 'GITEA_SERVER_URL', ''),
|
||||
'access_token': getattr(settings, 'GITEA_ACCESS_TOKEN', ''),
|
||||
'default_owner': getattr(settings, 'GITEA_DEFAULT_OWNER', ''),
|
||||
'repo_prefix': getattr(settings, 'GITEA_REPO_PREFIX', '')
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load Gitea config from database: {str(e)}")
|
||||
# 出错时返回默认配置
|
||||
return {
|
||||
'server_url': getattr(settings, 'GITEA_SERVER_URL', ''),
|
||||
'access_token': getattr(settings, 'GITEA_ACCESS_TOKEN', ''),
|
||||
'default_owner': getattr(settings, 'GITEA_DEFAULT_OWNER', ''),
|
||||
'repo_prefix': getattr(settings, 'GITEA_REPO_PREFIX', '')
|
||||
}
|
||||
|
||||
def save_config(self, config: Dict[str, Any]) -> bool:
|
||||
"""保存Gitea配置
|
||||
|
||||
Args:
|
||||
config: Gitea配置信息
|
||||
|
||||
Returns:
|
||||
是否保存成功
|
||||
"""
|
||||
try:
|
||||
db = SessionLocal()
|
||||
|
||||
# 将所有现有配置设置为非活动状态
|
||||
db.query(GiteaConfig).update({GiteaConfig.status: "inactive"})
|
||||
|
||||
# 检查是否已有配置
|
||||
existing_config = db.query(GiteaConfig).first()
|
||||
|
||||
if existing_config:
|
||||
# 更新现有配置
|
||||
existing_config.server_url = config['server_url']
|
||||
existing_config.access_token = config['access_token']
|
||||
existing_config.default_owner = config['default_owner']
|
||||
existing_config.repo_prefix = config.get('repo_prefix', '')
|
||||
existing_config.status = "active"
|
||||
else:
|
||||
# 创建新配置
|
||||
new_config = GiteaConfig(
|
||||
id=f"gitea-config-{uuid.uuid4()}",
|
||||
server_url=config['server_url'],
|
||||
access_token=config['access_token'],
|
||||
default_owner=config['default_owner'],
|
||||
repo_prefix=config.get('repo_prefix', ''),
|
||||
status="active"
|
||||
)
|
||||
db.add(new_config)
|
||||
|
||||
db.commit()
|
||||
db.close()
|
||||
|
||||
# 更新内存中的配置
|
||||
self.config = config
|
||||
self.client = GiteaClient(
|
||||
config['server_url'],
|
||||
config['access_token']
|
||||
)
|
||||
|
||||
logger.info("Gitea config saved to database successfully")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save Gitea config to database: {str(e)}")
|
||||
return False
|
||||
|
||||
def get_config(self) -> Optional[Dict[str, Any]]:
|
||||
"""获取Gitea配置
|
||||
|
||||
Returns:
|
||||
Gitea配置信息
|
||||
"""
|
||||
return self.config
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
"""测试Gitea连接
|
||||
|
||||
Returns:
|
||||
是否连接成功
|
||||
"""
|
||||
if not self.client:
|
||||
return False
|
||||
return self.client.check_connection()
|
||||
|
||||
def create_repository(self, algorithm_id: str, algorithm_name: str, description: str = "") -> Optional[Dict[str, Any]]:
|
||||
"""为算法创建Gitea仓库
|
||||
|
||||
Args:
|
||||
algorithm_id: 算法ID
|
||||
algorithm_name: 算法名称
|
||||
description: 仓库描述
|
||||
|
||||
Returns:
|
||||
创建的仓库信息
|
||||
"""
|
||||
try:
|
||||
if not self.client:
|
||||
logger.error("Gitea client not initialized. Please check your Gitea configuration.")
|
||||
return None
|
||||
|
||||
if not self.config.get('default_owner'):
|
||||
logger.error("Default owner not set in Gitea configuration.")
|
||||
return None
|
||||
|
||||
# 记录传入的algorithm_id
|
||||
logger.info(f"Received algorithm_id: {algorithm_id}")
|
||||
|
||||
# 检查是否已经包含前缀
|
||||
repo_prefix = self.config.get('repo_prefix', '')
|
||||
if repo_prefix and algorithm_id.startswith(repo_prefix):
|
||||
logger.info(f"Algorithm ID already contains prefix: {repo_prefix}")
|
||||
repo_name = algorithm_id
|
||||
else:
|
||||
# 生成仓库名称,添加前缀
|
||||
repo_name = f"{repo_prefix}{algorithm_id}" if repo_prefix else algorithm_id
|
||||
logger.info(f"Generated repository name: {repo_name}")
|
||||
|
||||
logger.info(f"Creating repository: {repo_name} for owner: {self.config['default_owner']}")
|
||||
|
||||
# 创建仓库
|
||||
repo = self.client.create_repository(
|
||||
self.config['default_owner'],
|
||||
repo_name,
|
||||
description or f"Algorithm repository for {algorithm_name}",
|
||||
False
|
||||
)
|
||||
|
||||
if repo:
|
||||
logger.info(f"Repository created successfully: {repo}")
|
||||
# 验证仓库是否真的存在
|
||||
verify_repo = self.client.get_repository(self.config['default_owner'], repo_name)
|
||||
if not verify_repo:
|
||||
logger.error(f"Repository creation verified failed: {repo_name}")
|
||||
return None
|
||||
else:
|
||||
logger.error(f"Failed to create repository: {repo_name}")
|
||||
|
||||
return repo
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create repository: {str(e)}")
|
||||
return None
|
||||
|
||||
def clone_repository(self, repo_url: str, algorithm_id: str, branch: str = "main") -> bool:
|
||||
"""克隆Gitea仓库
|
||||
|
||||
Args:
|
||||
repo_url: 仓库URL
|
||||
algorithm_id: 算法ID
|
||||
branch: 分支名称
|
||||
|
||||
Returns:
|
||||
是否克隆成功
|
||||
"""
|
||||
try:
|
||||
# 创建本地目录
|
||||
repo_dir = f"/tmp/algorithms/{algorithm_id}"
|
||||
|
||||
logger.info(f"Cloning repository to: {repo_dir}")
|
||||
|
||||
# 导入需要的模块
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
# 如果目录已存在,先清理它
|
||||
if os.path.exists(repo_dir):
|
||||
logger.info(f"Cleaning existing repository directory: {repo_dir}")
|
||||
try:
|
||||
shutil.rmtree(repo_dir)
|
||||
logger.info(f"Successfully cleaned directory: {repo_dir}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clean directory: {str(e)}")
|
||||
# 尝试使用sudo删除(如果有权限)
|
||||
try:
|
||||
subprocess.run(["sudo", "rm", "-rf", repo_dir], check=True)
|
||||
logger.info(f"Successfully cleaned directory with sudo: {repo_dir}")
|
||||
except Exception as e2:
|
||||
logger.error(f"Failed to clean directory with sudo: {str(e2)}")
|
||||
return False
|
||||
|
||||
# 重新创建目录
|
||||
logger.info(f"Creating directory: {repo_dir}")
|
||||
os.makedirs(repo_dir, exist_ok=True)
|
||||
logger.info(f"Directory created successfully: {repo_dir}")
|
||||
|
||||
# 克隆仓库
|
||||
cmd = ["git", "clone", "-b", branch, repo_url, repo_dir]
|
||||
logger.info(f"Running clone command: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info(f"Repository cloned successfully: {repo_url}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to clone repository: {result.stderr}")
|
||||
|
||||
# 尝试初始化仓库
|
||||
logger.info(f"Trying to initialize repository in {repo_dir}")
|
||||
|
||||
# 初始化git仓库
|
||||
init_result = subprocess.run(["git", "init"], cwd=repo_dir, capture_output=True, text=True)
|
||||
if init_result.returncode != 0:
|
||||
logger.error(f"Failed to initialize git repository: {init_result.stderr}")
|
||||
return False
|
||||
|
||||
# 添加远程仓库
|
||||
remote_result = subprocess.run(["git", "remote", "add", "origin", repo_url], cwd=repo_dir, capture_output=True, text=True)
|
||||
if remote_result.returncode != 0:
|
||||
logger.error(f"Failed to add remote repository: {remote_result.stderr}")
|
||||
# 如果远程仓库已存在,尝试更新它
|
||||
logger.info("Trying to update existing remote repository")
|
||||
update_result = subprocess.run(["git", "remote", "set-url", "origin", repo_url], cwd=repo_dir, capture_output=True, text=True)
|
||||
if update_result.returncode != 0:
|
||||
logger.error(f"Failed to update remote repository: {update_result.stderr}")
|
||||
return False
|
||||
logger.info("Successfully updated remote repository")
|
||||
|
||||
# 创建初始文件
|
||||
readme_path = os.path.join(repo_dir, "README.md")
|
||||
with open(readme_path, "w") as f:
|
||||
f.write("# Algorithm Repository\n\nThis is an algorithm repository.\n")
|
||||
|
||||
# 添加文件并提交
|
||||
add_result = subprocess.run(["git", "add", "README.md"], cwd=repo_dir, capture_output=True, text=True)
|
||||
if add_result.returncode != 0:
|
||||
logger.error(f"Failed to add README.md: {add_result.stderr}")
|
||||
return False
|
||||
|
||||
commit_result = subprocess.run(["git", "commit", "-m", "Initial commit"], cwd=repo_dir, capture_output=True, text=True)
|
||||
if commit_result.returncode != 0:
|
||||
logger.error(f"Failed to commit initial file: {commit_result.stderr}")
|
||||
return False
|
||||
|
||||
# 推送代码到远程仓库
|
||||
push_result = subprocess.run(["git", "push", "-u", "origin", branch], cwd=repo_dir, capture_output=True, text=True)
|
||||
if push_result.returncode != 0:
|
||||
logger.error(f"Failed to push initial commit: {push_result.stderr}")
|
||||
# 即使推送失败,初始化仓库也算成功
|
||||
logger.info(f"Repository initialized successfully, but push failed: {push_result.stderr}")
|
||||
return True
|
||||
|
||||
logger.info(f"Repository initialized and pushed successfully: {repo_url}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clone repository: {str(e)}")
|
||||
return False
|
||||
|
||||
def push_to_repository(self, algorithm_id: str, message: str = "Update code") -> bool:
|
||||
"""推送代码到Gitea仓库
|
||||
|
||||
Args:
|
||||
algorithm_id: 算法ID
|
||||
message: 提交消息
|
||||
|
||||
Returns:
|
||||
是否推送成功
|
||||
"""
|
||||
try:
|
||||
logger.info("=== 开始推送代码到Gitea仓库 ===")
|
||||
logger.info(f"Algorithm ID: {algorithm_id}")
|
||||
logger.info(f"Commit message: {message}")
|
||||
|
||||
repo_dir = f"/tmp/algorithms/{algorithm_id}"
|
||||
logger.info(f"Repository directory: {repo_dir}")
|
||||
|
||||
if not os.path.exists(repo_dir):
|
||||
logger.error(f"❌ Repository directory not found: {repo_dir}")
|
||||
return False
|
||||
|
||||
# 首先尝试使用API上传(推荐方法,避免Git推送限制)
|
||||
logger.info("Attempting to upload files via Gitea API...")
|
||||
api_upload_success = self.upload_files_via_api(algorithm_id, message)
|
||||
|
||||
if api_upload_success:
|
||||
logger.info(f"✅ Code uploaded successfully via API for algorithm: {algorithm_id}")
|
||||
return True
|
||||
else:
|
||||
logger.warning("❌ API upload failed, falling back to Git push...")
|
||||
|
||||
# 如果API上传失败,回退到原来的Git推送方法
|
||||
import subprocess
|
||||
|
||||
# 检查是否是git仓库
|
||||
git_dir = os.path.join(repo_dir, ".git")
|
||||
if not os.path.exists(git_dir):
|
||||
logger.info(f"⚠️ Git repository not initialized, initializing...")
|
||||
# 初始化git仓库
|
||||
logger.info(f"Executing: git init in {repo_dir}")
|
||||
init_result = subprocess.run(["git", "init"], cwd=repo_dir, capture_output=True, text=True)
|
||||
logger.info(f"Git init output: {init_result.stdout}")
|
||||
if init_result.stderr:
|
||||
logger.warning(f"Git init stderr: {init_result.stderr}")
|
||||
if init_result.returncode != 0:
|
||||
logger.error(f"❌ Failed to initialize git repository: {init_result.stderr}")
|
||||
return False
|
||||
logger.info("✅ Git repository initialized successfully")
|
||||
|
||||
# 添加远程仓库(从配置中获取,包含访问令牌以确保认证)
|
||||
if self.config.get('default_owner'):
|
||||
# 使用访问令牌构建认证URL
|
||||
auth_repo_url = f"https://{self.config['access_token']}@{self.config['server_url'].replace('https://', '').replace('http://', '')}/{self.config['default_owner']}/{algorithm_id}.git"
|
||||
logger.info(f"Adding remote repository: {auth_repo_url}")
|
||||
remote_result = subprocess.run(["git", "remote", "add", "origin", auth_repo_url], cwd=repo_dir, capture_output=True, text=True)
|
||||
logger.info(f"Git remote add output: {remote_result.stdout}")
|
||||
if remote_result.stderr:
|
||||
logger.warning(f"Git remote add stderr: {remote_result.stderr}")
|
||||
if remote_result.returncode != 0:
|
||||
logger.error(f"❌ Failed to add remote repository: {remote_result.stderr}")
|
||||
return False
|
||||
logger.info("✅ Remote repository added successfully")
|
||||
else:
|
||||
logger.info("✅ Git repository already initialized")
|
||||
|
||||
# 执行git命令 - 分批添加文件以处理大量文件
|
||||
logger.info("=== 执行Git操作 ===")
|
||||
|
||||
# 获取所有需要添加的文件
|
||||
all_files = []
|
||||
for root, dirs, files in os.walk(repo_dir):
|
||||
if '.git' in root:
|
||||
continue
|
||||
for file in files:
|
||||
if not file.endswith('.git'):
|
||||
file_path = os.path.relpath(os.path.join(root, file), repo_dir)
|
||||
all_files.append(file_path)
|
||||
|
||||
logger.info(f"Total files to add: {len(all_files)}")
|
||||
|
||||
# 分批添加文件,避免命令行参数过长
|
||||
batch_size = 100 # 每次添加100个文件
|
||||
for i in range(0, len(all_files), batch_size):
|
||||
batch = all_files[i:i + batch_size]
|
||||
logger.info(f"Adding batch {i//batch_size + 1}: {len(batch)} files")
|
||||
|
||||
add_result = subprocess.run(["git", "add"] + batch, cwd=repo_dir, capture_output=True, text=True)
|
||||
if add_result.stderr and add_result.returncode != 0:
|
||||
logger.error(f"❌ Git add batch {i//batch_size + 1} failed: {add_result.stderr}")
|
||||
return False
|
||||
elif add_result.stderr:
|
||||
logger.warning(f"Git add batch {i//batch_size + 1} warning: {add_result.stderr}")
|
||||
|
||||
logger.info("✅ Git add completed successfully")
|
||||
|
||||
# 检查是否有更改需要提交
|
||||
logger.info("Executing: git status --porcelain")
|
||||
status_result = subprocess.run(["git", "status", "--porcelain"], cwd=repo_dir, capture_output=True, text=True)
|
||||
logger.info(f"Git status output: {status_result.stdout}")
|
||||
if status_result.stderr:
|
||||
logger.warning(f"Git status stderr: {status_result.stderr}")
|
||||
if status_result.returncode != 0:
|
||||
logger.error(f"❌ Git status failed: {status_result.stderr}")
|
||||
return False
|
||||
|
||||
# 如果有更改,执行commit和push
|
||||
if status_result.stdout.strip():
|
||||
logger.info("✅ Changes detected, proceeding with commit and push")
|
||||
# 执行git commit
|
||||
logger.info(f"Executing: git commit -m '{message}'")
|
||||
commit_result = subprocess.run(["git", "commit", "-m", message], cwd=repo_dir, capture_output=True, text=True)
|
||||
logger.info(f"Git commit output: {commit_result.stdout}")
|
||||
if commit_result.stderr:
|
||||
logger.warning(f"Git commit stderr: {commit_result.stderr}")
|
||||
if commit_result.returncode != 0:
|
||||
logger.error(f"❌ Git commit failed: {commit_result.stderr}")
|
||||
return False
|
||||
logger.info("✅ Git commit completed successfully")
|
||||
|
||||
# 检查仓库大小
|
||||
logger.info("Checking repository size before push")
|
||||
total_size = 0
|
||||
for dirpath, dirnames, filenames in os.walk(repo_dir):
|
||||
for filename in filenames:
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
if not filepath.startswith(os.path.join(repo_dir, '.git')):
|
||||
total_size += os.path.getsize(filepath)
|
||||
logger.info(f"Repository size (excluding .git): {total_size / (1024 * 1024):.2f} MB")
|
||||
|
||||
if total_size > 100 * 1024 * 1024: # 100MB
|
||||
logger.warning(f"Repository is large: {total_size / (1024 * 1024):.2f} MB")
|
||||
logger.warning("This may cause HTTP 413 errors on push")
|
||||
|
||||
# 设置Git推送缓冲区大小(增加到1GB)
|
||||
logger.info("Setting Git http.postBuffer to 1GB")
|
||||
buffer_result = subprocess.run(["git", "config", "http.postBuffer", "1073741824"], cwd=repo_dir, capture_output=True, text=True)
|
||||
if buffer_result.returncode != 0:
|
||||
logger.warning(f"Failed to set http.postBuffer: {buffer_result.stderr}")
|
||||
else:
|
||||
logger.info("✅ Git http.postBuffer set successfully")
|
||||
|
||||
# 禁用Git压缩
|
||||
logger.info("Disabling Git compression")
|
||||
compression_result = subprocess.run(["git", "config", "core.compression", "0"], cwd=repo_dir, capture_output=True, text=True)
|
||||
if compression_result.returncode != 0:
|
||||
logger.warning(f"Failed to set core.compression: {compression_result.stderr}")
|
||||
else:
|
||||
logger.info("✅ Git core.compression disabled successfully")
|
||||
|
||||
# 针对大仓库优化的推送命令
|
||||
logger.info("Setting additional Git configs for large repositories...")
|
||||
subprocess.run(["git", "config", "http.postBuffer", "524288000"], cwd=repo_dir) # 500MB buffer
|
||||
subprocess.run(["git", "config", "pack.windowMemory", "128m"], cwd=repo_dir) # Limit memory usage
|
||||
subprocess.run(["git", "config", "pack.packSizeLimit", "128m"], cwd=repo_dir) # Limit pack size
|
||||
|
||||
# 执行git push,添加更多优化参数
|
||||
logger.info("Executing: git push with optimizations for large repositories")
|
||||
push_result = subprocess.run([
|
||||
"git", "push",
|
||||
"--verbose",
|
||||
"-u", "origin", "main",
|
||||
"--receive-pack='git receive-pack'", # Ensure proper receive pack
|
||||
"--progress" # Show progress for large pushes
|
||||
], cwd=repo_dir, capture_output=True, text=True, timeout=300) # 5 minute timeout
|
||||
logger.info(f"Git push output: {push_result.stdout}")
|
||||
if push_result.stderr:
|
||||
logger.warning(f"Git push stderr: {push_result.stderr}")
|
||||
if push_result.returncode != 0:
|
||||
# 检查是否是常见的大文件错误
|
||||
error_msg = push_result.stderr.lower()
|
||||
is_large_file_error = (
|
||||
"http 413" in error_msg or
|
||||
"payload too large" in error_msg or
|
||||
"unpack failed" in error_msg or
|
||||
"remote: fatal" in error_msg or
|
||||
"cannot spawn" in error_msg or
|
||||
"timeout" in error_msg
|
||||
)
|
||||
|
||||
if is_large_file_error:
|
||||
logger.error(f"❌ Git push failed likely due to repository size: {total_size / (1024 * 1024):.2f} MB")
|
||||
logger.error(f"Error details: {push_result.stderr}")
|
||||
logger.error("\n📋 解决方案建议:")
|
||||
logger.error("1. 检查Gitea服务器配置,增加MAX_UPLOAD_SIZE限制")
|
||||
logger.error("2. 尝试使用SSH协议进行推送(如果服务器支持)")
|
||||
logger.error("3. 优化仓库大小,移除不必要的大文件")
|
||||
logger.error("4. 考虑使用Git LFS(Large File Storage)管理大文件")
|
||||
|
||||
# 尝试使用SSH协议进行推送(如果URL是HTTPS格式)
|
||||
logger.info("\n🔄 尝试使用SSH协议进行推送...")
|
||||
try:
|
||||
# 获取当前远程URL
|
||||
remote_result = subprocess.run(["git", "remote", "get-url", "origin"], cwd=repo_dir, capture_output=True, text=True, timeout=30)
|
||||
if remote_result.returncode == 0:
|
||||
https_url = remote_result.stdout.strip()
|
||||
# 将HTTPS URL转换为SSH URL
|
||||
if https_url.startswith("https://"):
|
||||
ssh_url = https_url.replace("https://", "git@").replace(":", "/")
|
||||
logger.info(f"Converting HTTPS URL to SSH URL: {ssh_url}")
|
||||
# 更新远程URL
|
||||
set_url_result = subprocess.run(["git", "remote", "set-url", "origin", ssh_url], cwd=repo_dir, capture_output=True, text=True, timeout=30)
|
||||
if set_url_result.returncode == 0:
|
||||
logger.info("✅ Remote URL updated to SSH format")
|
||||
# 再次尝试推送,使用更保守的参数
|
||||
logger.info("Executing: git push with SSH and conservative parameters")
|
||||
ssh_push_result = subprocess.run([
|
||||
"git", "push",
|
||||
"--verbose",
|
||||
"-u", "origin", "main"
|
||||
], cwd=repo_dir, capture_output=True, text=True, timeout=600) # 10 minute timeout for SSH
|
||||
|
||||
if ssh_push_result.returncode == 0:
|
||||
logger.info("✅ Git push completed successfully with SSH")
|
||||
# 改回HTTPS URL
|
||||
reset_url_result = subprocess.run(["git", "remote", "set-url", "origin", https_url], cwd=repo_dir, capture_output=True, text=True, timeout=30)
|
||||
if reset_url_result.returncode != 0:
|
||||
logger.warning(f"Failed to reset remote URL to HTTPS: {reset_url_result.stderr}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"SSH push failed: {ssh_push_result.stderr}")
|
||||
# 改回HTTPS URL
|
||||
reset_url_result = subprocess.run(["git", "remote", "set-url", "origin", https_url], cwd=repo_dir, capture_output=True, text=True, timeout=30)
|
||||
if reset_url_result.returncode != 0:
|
||||
logger.warning(f"Failed to reset remote URL to HTTPS: {reset_url_result.stderr}")
|
||||
|
||||
# 如果SSH也失败,尝试分阶段推送
|
||||
logger.info("\n🔄 尝试分阶段推送...")
|
||||
return self.push_repository_staged(repo_dir, https_url)
|
||||
else:
|
||||
logger.warning(f"Could not get remote URL: {remote_result.stderr}")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("Remote URL command timed out")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to try SSH push: {str(e)}")
|
||||
else:
|
||||
logger.error(f"❌ Git push failed: {push_result.stderr}")
|
||||
return False
|
||||
logger.info("✅ Git push completed successfully")
|
||||
else:
|
||||
logger.info("ℹ️ No changes to commit")
|
||||
|
||||
logger.info(f"✅ Code pushed successfully for algorithm: {algorithm_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"=== 推送代码失败 ===")
|
||||
logger.error(f"Error: {str(e)}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
def pull_from_repository(self, algorithm_id: str) -> bool:
|
||||
"""从Gitea仓库拉取代码
|
||||
|
||||
Args:
|
||||
algorithm_id: 算法ID
|
||||
|
||||
Returns:
|
||||
是否拉取成功
|
||||
"""
|
||||
try:
|
||||
repo_dir = f"/tmp/algorithms/{algorithm_id}"
|
||||
|
||||
if not os.path.exists(repo_dir):
|
||||
logger.error(f"Repository directory not found: {repo_dir}")
|
||||
return False
|
||||
|
||||
# 执行git pull命令
|
||||
result = subprocess.run(
|
||||
["git", "pull"],
|
||||
cwd=repo_dir,
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info(f"Code pulled successfully for algorithm: {algorithm_id}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to pull code: {result.stderr}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to pull code: {str(e)}")
|
||||
return False
|
||||
|
||||
def push_repository_staged(self, repo_dir: str, origin_url: str) -> bool:
|
||||
"""
|
||||
分阶段推送仓库,用于处理超大仓库
|
||||
"""
|
||||
try:
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
logger.info("=== 开始分阶段推送仓库 ===")
|
||||
logger.info(f"Repository directory: {repo_dir}")
|
||||
|
||||
# 获取所有文件并按类型分组
|
||||
all_files = []
|
||||
for root, dirs, files in os.walk(repo_dir):
|
||||
# 跳过 .git 目录
|
||||
if '.git' in root:
|
||||
continue
|
||||
for file in files:
|
||||
file_path = os.path.relpath(os.path.join(root, file), repo_dir)
|
||||
if file_path.startswith('.git'):
|
||||
continue
|
||||
all_files.append(file_path)
|
||||
|
||||
logger.info(f"Total files to stage: {len(all_files)}")
|
||||
|
||||
# 按扩展名分类文件,优先推送小文件
|
||||
def get_file_size(file_path):
|
||||
try:
|
||||
return os.path.getsize(os.path.join(repo_dir, file_path))
|
||||
except:
|
||||
return 0
|
||||
|
||||
# 按文件大小排序(从小到大)
|
||||
sorted_files = sorted(all_files, key=get_file_size)
|
||||
|
||||
# 分批处理,每批最多50个文件或不超过50MB
|
||||
batch_size_limit = 50
|
||||
batch_size_bytes = 50 * 1024 * 1024 # 50MB
|
||||
|
||||
current_batch = []
|
||||
current_batch_size = 0
|
||||
batch_number = 1
|
||||
|
||||
for file_path in sorted_files:
|
||||
file_full_path = os.path.join(repo_dir, file_path)
|
||||
file_size = get_file_size(file_path)
|
||||
|
||||
# 如果单个文件太大,单独处理
|
||||
if file_size > batch_size_bytes:
|
||||
logger.info(f"Handling large file separately: {file_path} ({file_size / (1024*1024):.2f}MB)")
|
||||
# 单独添加和推送这个大文件
|
||||
add_result = subprocess.run(["git", "add", file_path], cwd=repo_dir, capture_output=True, text=True)
|
||||
if add_result.returncode != 0:
|
||||
logger.error(f"Failed to add large file {file_path}: {add_result.stderr}")
|
||||
continue
|
||||
|
||||
# 检查是否有暂存的更改
|
||||
status_result = subprocess.run(["git", "status", "--porcelain"], cwd=repo_dir, capture_output=True, text=True)
|
||||
if status_result.stdout.strip():
|
||||
# 创建专门的提交
|
||||
commit_msg = f"Add large file: {file_path}"
|
||||
commit_result = subprocess.run(["git", "commit", "-m", commit_msg], cwd=repo_dir, capture_output=True, text=True)
|
||||
if commit_result.returncode == 0:
|
||||
logger.info(f"Committed large file: {file_path}")
|
||||
|
||||
# 推送这个提交
|
||||
push_result = subprocess.run([
|
||||
"git", "push", "--verbose", "origin", "main"
|
||||
], cwd=repo_dir, capture_output=True, text=True, timeout=300)
|
||||
|
||||
if push_result.returncode != 0:
|
||||
logger.warning(f"Push failed for large file {file_path}: {push_result.stderr}")
|
||||
# 如果推送失败,尝试重置这个文件的暂存状态
|
||||
subprocess.run(["git", "reset", "HEAD", file_path], cwd=repo_dir, capture_output=True, text=True)
|
||||
else:
|
||||
logger.info(f"Successfully pushed large file: {file_path}")
|
||||
else:
|
||||
logger.error(f"Failed to commit large file {file_path}: {commit_result.stderr}")
|
||||
else:
|
||||
# 尝试添加到当前批次
|
||||
if (len(current_batch) >= batch_size_limit or
|
||||
current_batch_size + file_size > batch_size_bytes):
|
||||
# 推送当前批次
|
||||
if current_batch:
|
||||
logger.info(f"Pushing batch {batch_number} with {len(current_batch)} files...")
|
||||
success = self.push_batch(repo_dir, current_batch, batch_number, origin_url)
|
||||
if not success:
|
||||
logger.error(f"Failed to push batch {batch_number}")
|
||||
return False
|
||||
batch_number += 1
|
||||
current_batch = []
|
||||
current_batch_size = 0
|
||||
|
||||
current_batch.append(file_path)
|
||||
current_batch_size += file_size
|
||||
|
||||
# 推送最后一批
|
||||
if current_batch:
|
||||
logger.info(f"Pushing final batch {batch_number} with {len(current_batch)} files...")
|
||||
success = self.push_batch(repo_dir, current_batch, batch_number, origin_url)
|
||||
if not success:
|
||||
logger.error(f"Failed to push final batch {batch_number}")
|
||||
return False
|
||||
|
||||
logger.info("✅ 分阶段推送完成")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 分阶段推送失败: {str(e)}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
def push_batch(self, repo_dir: str, file_batch: list, batch_num: int, origin_url: str) -> bool:
|
||||
"""
|
||||
推送文件批次
|
||||
"""
|
||||
try:
|
||||
import subprocess
|
||||
|
||||
logger.info(f"Processing batch {batch_num}: {len(file_batch)} files")
|
||||
|
||||
# 添加批次中的文件
|
||||
for file_path in file_batch:
|
||||
add_result = subprocess.run(["git", "add", file_path], cwd=repo_dir, capture_output=True, text=True)
|
||||
if add_result.returncode != 0:
|
||||
logger.error(f"Failed to add file {file_path}: {add_result.stderr}")
|
||||
return False
|
||||
|
||||
# 检查是否有更改需要提交
|
||||
status_result = subprocess.run(["git", "status", "--porcelain"], cwd=repo_dir, capture_output=True, text=True)
|
||||
if not status_result.stdout.strip():
|
||||
logger.info(f"No changes in batch {batch_num}")
|
||||
return True
|
||||
|
||||
# 提交批次
|
||||
commit_result = subprocess.run([
|
||||
"git", "commit", "-m", f"Batch {batch_num}: Add {len(file_batch)} files"
|
||||
], cwd=repo_dir, capture_output=True, text=True)
|
||||
|
||||
if commit_result.returncode != 0:
|
||||
logger.warning(f"Commit failed or no changes for batch {batch_num}: {commit_result.stderr}")
|
||||
# 即使没有更改,也可能正常(比如文件没变)
|
||||
|
||||
# 推送批次
|
||||
push_result = subprocess.run([
|
||||
"git", "push", "--verbose", "origin", "main"
|
||||
], cwd=repo_dir, capture_output=True, text=True, timeout=300)
|
||||
|
||||
if push_result.returncode == 0:
|
||||
logger.info(f"✅ Batch {batch_num} pushed successfully")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"❌ Batch {batch_num} push failed: {push_result.stderr}")
|
||||
return False
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error(f"❌ Batch {batch_num} push timed out")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Batch {batch_num} push failed with error: {str(e)}")
|
||||
return False
|
||||
|
||||
def get_repository_info(self, repo_owner: str, repo_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取仓库信息
|
||||
|
||||
Args:
|
||||
repo_owner: 仓库所有者
|
||||
repo_name: 仓库名称
|
||||
|
||||
Returns:
|
||||
仓库信息
|
||||
"""
|
||||
if not self.client:
|
||||
return None
|
||||
|
||||
return self.client.get_repository(repo_owner, repo_name)
|
||||
|
||||
def list_repositories(self, owner: Optional[str] = None) -> Optional[List[Dict[str, Any]]]:
|
||||
"""列出仓库
|
||||
|
||||
Args:
|
||||
owner: 所有者(用户或组织)
|
||||
|
||||
Returns:
|
||||
仓库列表
|
||||
"""
|
||||
if not self.client:
|
||||
return None
|
||||
|
||||
target_owner = owner or self.config.get('default_owner')
|
||||
if not target_owner:
|
||||
return None
|
||||
|
||||
return self.client.list_repositories(target_owner)
|
||||
|
||||
def register_algorithm_from_repo(self, repo_owner: str, repo_name: str, algorithm_id: str) -> bool:
|
||||
"""从仓库注册算法服务
|
||||
|
||||
Args:
|
||||
repo_owner: 仓库所有者
|
||||
repo_name: 仓库名称
|
||||
algorithm_id: 算法ID
|
||||
|
||||
Returns:
|
||||
是否注册成功
|
||||
"""
|
||||
try:
|
||||
# 这里应该实现从仓库注册算法服务的逻辑
|
||||
# 1. 克隆仓库
|
||||
# 2. 扫描仓库中的算法代码
|
||||
# 3. 注册算法服务
|
||||
|
||||
logger.info(f"Algorithm registered from repo: {repo_owner}/{repo_name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register algorithm from repo: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
# 递归遍历目录中的所有文件
|
||||
for root, dirs, files in os.walk(repo_dir):
|
||||
# 跳过 .git 目录
|
||||
if '.git' in root:
|
||||
continue
|
||||
|
||||
for file in files:
|
||||
file_path = os.path.relpath(os.path.join(root, file), repo_dir)
|
||||
if file_path.startswith('.git'):
|
||||
continue
|
||||
|
||||
full_file_path = os.path.join(root, file)
|
||||
|
||||
# 读取文件内容并进行base64编码
|
||||
try:
|
||||
with open(full_file_path, 'rb') as f:
|
||||
file_content = f.read()
|
||||
encoded_content = base64.b64encode(file_content).decode('utf-8')
|
||||
|
||||
# 使用Gitea API创建或更新文件
|
||||
if self.client:
|
||||
# 移除开头的./,如果有的话
|
||||
clean_path = file_path.lstrip('./\\')
|
||||
result = self.client.create_file(
|
||||
self.config["default_owner"],
|
||||
algorithm_id,
|
||||
clean_path,
|
||||
encoded_content,
|
||||
f"{message} - Upload {clean_path}"
|
||||
)
|
||||
|
||||
if result:
|
||||
logger.info(f"✅ File uploaded via API: {clean_path}")
|
||||
else:
|
||||
logger.error(f"❌ Failed to upload file via API: {clean_path}")
|
||||
return False
|
||||
else:
|
||||
logger.error("❌ Gitea client not initialized")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error processing file {file_path}: {str(e)}")
|
||||
return False
|
||||
|
||||
logger.info(f"✅ All files uploaded successfully via API for algorithm: {algorithm_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to upload files via API: {str(e)}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
|
||||
# 全局Gitea服务实例
|
||||
gitea_service = GiteaService()
|
||||
@@ -140,18 +140,6 @@ class AlgorithmRepository(Base):
|
||||
algorithm = relationship("Algorithm", back_populates="repository", uselist=False)
|
||||
|
||||
|
||||
class ServiceGroup(Base):
|
||||
"""服务分组模型"""
|
||||
__tablename__ = "service_groups"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False, unique=True, index=True) # 分组名称
|
||||
description = Column(Text, default="") # 分组描述
|
||||
status = Column(String, default="active", index=True) # 状态
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
|
||||
class AlgorithmService(Base):
|
||||
"""算法服务模型"""
|
||||
__tablename__ = "algorithm_services"
|
||||
|
||||
@@ -5,10 +5,12 @@ from typing import List, Dict, Any, Optional
|
||||
from pydantic import BaseModel
|
||||
import uuid
|
||||
import os
|
||||
import logging
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.models.models import AlgorithmService, ServiceGroup, AlgorithmRepository
|
||||
from app.models.models import AlgorithmService, AlgorithmRepository, Algorithm, AlgorithmVersion
|
||||
from app.models.database import SessionLocal
|
||||
from app.models.api import ApiEndpoint
|
||||
from app.routes.user import get_current_active_user
|
||||
from app.schemas.user import UserResponse
|
||||
from app.services.project_analyzer import ProjectAnalyzer
|
||||
@@ -17,6 +19,7 @@ from app.services.service_orchestrator import ServiceOrchestrator
|
||||
from app.gitea.service import gitea_service
|
||||
|
||||
router = APIRouter(prefix="/services", tags=["services"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RegisterServiceRequest(BaseModel):
|
||||
@@ -28,7 +31,7 @@ class RegisterServiceRequest(BaseModel):
|
||||
tech_category: str = "computer_vision"
|
||||
output_type: str = "image"
|
||||
service_type: str = "http"
|
||||
host: str = "0.0.0.0"
|
||||
host: str = "localhost"
|
||||
port: int = 8000
|
||||
timeout: int = 30
|
||||
health_check_path: str = "/health"
|
||||
@@ -89,34 +92,6 @@ class RepositoryAlgorithmsResponse(BaseModel):
|
||||
algorithms: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class ServiceGroupRequest(BaseModel):
|
||||
"""服务分组请求"""
|
||||
name: str
|
||||
description: str = ""
|
||||
|
||||
|
||||
class ServiceGroupResponse(BaseModel):
|
||||
"""服务分组响应"""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
status: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class ServiceGroupListResponse(BaseModel):
|
||||
"""服务分组列表响应"""
|
||||
success: bool
|
||||
groups: List[ServiceGroupResponse]
|
||||
|
||||
|
||||
class ServiceGroupDetailResponse(BaseModel):
|
||||
"""服务分组详情响应"""
|
||||
success: bool
|
||||
group: ServiceGroupResponse
|
||||
|
||||
|
||||
class BatchOperationRequest(BaseModel):
|
||||
"""批量操作请求"""
|
||||
service_ids: List[str]
|
||||
@@ -228,7 +203,62 @@ async def register_service(
|
||||
db.commit()
|
||||
db.refresh(new_service)
|
||||
|
||||
# 6. 返回响应
|
||||
# 7. 自动创建API端点
|
||||
try:
|
||||
algorithm = db.query(Algorithm).filter(Algorithm.name == repo.name).first()
|
||||
if not algorithm:
|
||||
algorithm = Algorithm(
|
||||
id=str(uuid.uuid4()),
|
||||
name=repo.name,
|
||||
description=request.description or f"算法服务: {request.name}",
|
||||
type=request.tech_category,
|
||||
tech_category=request.tech_category,
|
||||
output_type=request.output_type
|
||||
)
|
||||
db.add(algorithm)
|
||||
db.commit()
|
||||
db.refresh(algorithm)
|
||||
|
||||
version = db.query(AlgorithmVersion).filter(
|
||||
AlgorithmVersion.algorithm_id == algorithm.id,
|
||||
AlgorithmVersion.version == request.version
|
||||
).first()
|
||||
if not version:
|
||||
version = AlgorithmVersion(
|
||||
id=str(uuid.uuid4()),
|
||||
algorithm_id=algorithm.id,
|
||||
version=request.version,
|
||||
url=request.service_url if hasattr(request, 'service_url') else ""
|
||||
)
|
||||
db.add(version)
|
||||
db.commit()
|
||||
db.refresh(version)
|
||||
|
||||
api_endpoint = ApiEndpoint(
|
||||
id=str(uuid.uuid4()),
|
||||
name=request.name,
|
||||
description=request.description or f"{request.name} API端点",
|
||||
path=f"/api/v1/algorithms/{algorithm.id}/call",
|
||||
method="POST",
|
||||
algorithm_id=algorithm.id,
|
||||
version_id=version.id,
|
||||
service_id=service_id,
|
||||
requires_auth=False,
|
||||
is_public=True,
|
||||
status="active",
|
||||
config={
|
||||
"service_url": deploy_result["api_url"],
|
||||
"timeout": request.timeout,
|
||||
"health_check_path": request.health_check_path
|
||||
}
|
||||
)
|
||||
db.add(api_endpoint)
|
||||
db.commit()
|
||||
logger.info(f"API端点创建成功: {api_endpoint.name}, 路径: {api_endpoint.path}")
|
||||
except Exception as e:
|
||||
logger.error(f"创建API端点失败: {str(e)}")
|
||||
|
||||
# 8. 返回响应
|
||||
return {
|
||||
"success": True,
|
||||
"message": "服务注册成功",
|
||||
@@ -537,6 +567,12 @@ async def delete_service(
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="Service not found")
|
||||
|
||||
# 先删除关联的API端点
|
||||
db.query(ApiEndpoint).filter(ApiEndpoint.service_id == service_id).delete()
|
||||
|
||||
# 获取算法名称,用于后续删除算法记录
|
||||
algorithm_name = service.algorithm_name
|
||||
|
||||
# 获取容器ID和镜像名称
|
||||
container_id = service.config.get("container_id")
|
||||
image_name = f"algorithm-service-{service_id}:{service.version}"
|
||||
@@ -549,6 +585,17 @@ async def delete_service(
|
||||
|
||||
# 删除数据库记录
|
||||
db.delete(service)
|
||||
|
||||
# 删除关联的算法记录(通过算法名称匹配)
|
||||
if algorithm_name:
|
||||
algorithm = db.query(Algorithm).filter(Algorithm.name == algorithm_name).first()
|
||||
if algorithm:
|
||||
# 先删除关联的算法版本
|
||||
db.query(AlgorithmVersion).filter(AlgorithmVersion.algorithm_id == algorithm.id).delete()
|
||||
# 再删除算法记录
|
||||
db.query(AlgorithmCall).filter(AlgorithmCall.algorithm_id == algorithm.id).delete()
|
||||
db.delete(algorithm)
|
||||
|
||||
db.commit()
|
||||
|
||||
# 返回响应
|
||||
@@ -677,202 +724,6 @@ async def get_repository_algorithms(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# 服务分组管理API
|
||||
|
||||
@router.post("/groups", status_code=status.HTTP_201_CREATED)
|
||||
async def create_service_group(
|
||||
request: ServiceGroupRequest,
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""创建服务分组"""
|
||||
# 检查用户权限
|
||||
if current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 生成唯一ID
|
||||
group_id = str(uuid.uuid4())
|
||||
|
||||
# 创建分组实例
|
||||
group = ServiceGroup(
|
||||
id=group_id,
|
||||
name=request.name,
|
||||
description=request.description
|
||||
)
|
||||
|
||||
# 保存到数据库
|
||||
db.add(group)
|
||||
db.commit()
|
||||
db.refresh(group)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "服务分组创建成功",
|
||||
"group": {
|
||||
"id": group.id,
|
||||
"name": group.name,
|
||||
"description": group.description,
|
||||
"status": group.status,
|
||||
"created_at": group.created_at.isoformat(),
|
||||
"updated_at": group.updated_at.isoformat()
|
||||
}
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.get("/groups", response_model=ServiceGroupListResponse)
|
||||
async def list_service_groups(
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取服务分组列表"""
|
||||
# 检查用户权限
|
||||
if current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询分组列表
|
||||
groups = db.query(ServiceGroup).all()
|
||||
|
||||
# 转换为响应格式
|
||||
group_list = []
|
||||
for group in groups:
|
||||
group_list.append(ServiceGroupResponse(
|
||||
id=group.id,
|
||||
name=group.name,
|
||||
description=group.description,
|
||||
status=group.status,
|
||||
created_at=group.created_at.isoformat(),
|
||||
updated_at=group.updated_at.isoformat()
|
||||
))
|
||||
|
||||
return ServiceGroupListResponse(
|
||||
success=True,
|
||||
groups=group_list
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.get("/groups/{group_id}", response_model=ServiceGroupDetailResponse)
|
||||
async def get_service_group(
|
||||
group_id: str,
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""获取服务分组详情"""
|
||||
# 检查用户权限
|
||||
if current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询分组
|
||||
group = db.query(ServiceGroup).filter(ServiceGroup.id == group_id).first()
|
||||
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="Service group not found")
|
||||
|
||||
return ServiceGroupDetailResponse(
|
||||
success=True,
|
||||
group=ServiceGroupResponse(
|
||||
id=group.id,
|
||||
name=group.name,
|
||||
description=group.description,
|
||||
status=group.status,
|
||||
created_at=group.created_at.isoformat(),
|
||||
updated_at=group.updated_at.isoformat()
|
||||
)
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.put("/groups/{group_id}")
|
||||
async def update_service_group(
|
||||
group_id: str,
|
||||
request: ServiceGroupRequest,
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""更新服务分组"""
|
||||
# 检查用户权限
|
||||
if current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询分组
|
||||
group = db.query(ServiceGroup).filter(ServiceGroup.id == group_id).first()
|
||||
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="Service group not found")
|
||||
|
||||
# 更新分组信息
|
||||
group.name = request.name
|
||||
group.description = request.description
|
||||
|
||||
# 保存到数据库
|
||||
db.commit()
|
||||
db.refresh(group)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "服务分组更新成功",
|
||||
"group": {
|
||||
"id": group.id,
|
||||
"name": group.name,
|
||||
"description": group.description,
|
||||
"status": group.status,
|
||||
"created_at": group.created_at.isoformat(),
|
||||
"updated_at": group.updated_at.isoformat()
|
||||
}
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.delete("/groups/{group_id}")
|
||||
async def delete_service_group(
|
||||
group_id: str,
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""删除服务分组"""
|
||||
# 检查用户权限
|
||||
if current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# 创建数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 查询分组
|
||||
group = db.query(ServiceGroup).filter(ServiceGroup.id == group_id).first()
|
||||
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="Service group not found")
|
||||
|
||||
# 检查分组是否有服务
|
||||
services_count = db.query(AlgorithmService).filter(AlgorithmService.group_id == group_id).count()
|
||||
if services_count > 0:
|
||||
raise HTTPException(status_code=400, detail=f"该分组下还有{services_count}个服务,无法删除")
|
||||
|
||||
# 删除分组
|
||||
db.delete(group)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "服务分组删除成功",
|
||||
"group_id": group_id
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# 批量服务操作API
|
||||
|
||||
@router.post("/batch/start")
|
||||
@@ -1230,3 +1081,85 @@ async def call_service(
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.post("/sync-api-endpoints")
|
||||
async def sync_api_endpoints(
|
||||
current_user: UserResponse = Depends(get_current_active_user)
|
||||
):
|
||||
"""同步所有服务到API端点"""
|
||||
if current_user.role_name != "admin":
|
||||
raise HTTPException(status_code=403, detail="权限不足")
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
services = db.query(AlgorithmService).all()
|
||||
synced_count = 0
|
||||
|
||||
for service in services:
|
||||
existing_endpoint = db.query(ApiEndpoint).filter(
|
||||
(ApiEndpoint.service_id == service.service_id) |
|
||||
(ApiEndpoint.path == f"/api/v1/algorithms/{service.algorithm_name}/call")
|
||||
).first()
|
||||
|
||||
if existing_endpoint:
|
||||
continue
|
||||
|
||||
algorithm = db.query(Algorithm).filter(Algorithm.name == service.algorithm_name).first()
|
||||
if not algorithm:
|
||||
algorithm = Algorithm(
|
||||
id=str(uuid.uuid4()),
|
||||
name=service.algorithm_name,
|
||||
description=f"算法服务: {service.name}",
|
||||
type=service.tech_category or "computer_vision",
|
||||
tech_category=service.tech_category or "computer_vision",
|
||||
output_type=service.output_type or "image"
|
||||
)
|
||||
db.add(algorithm)
|
||||
db.commit()
|
||||
db.refresh(algorithm)
|
||||
|
||||
version = db.query(AlgorithmVersion).filter(
|
||||
AlgorithmVersion.algorithm_id == algorithm.id
|
||||
).first()
|
||||
if not version:
|
||||
version = AlgorithmVersion(
|
||||
id=str(uuid.uuid4()),
|
||||
algorithm_id=algorithm.id,
|
||||
version=service.version or "1.0.0",
|
||||
url=service.api_url
|
||||
)
|
||||
db.add(version)
|
||||
db.commit()
|
||||
db.refresh(version)
|
||||
|
||||
api_endpoint = ApiEndpoint(
|
||||
id=str(uuid.uuid4()),
|
||||
name=service.name,
|
||||
description=f"{service.name} API端点",
|
||||
path=f"/api/v1/algorithms/{algorithm.id}/call/{service.service_id[:8]}",
|
||||
method="POST",
|
||||
algorithm_id=algorithm.id,
|
||||
version_id=version.id,
|
||||
service_id=service.service_id,
|
||||
requires_auth=False,
|
||||
is_public=True,
|
||||
status=service.status or "active",
|
||||
config={
|
||||
"service_url": service.api_url,
|
||||
"timeout": service.config.get("timeout") if service.config else 30
|
||||
}
|
||||
)
|
||||
db.add(api_endpoint)
|
||||
synced_count += 1
|
||||
|
||||
db.commit()
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"同步完成,共同步 {synced_count} 个API端点"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"同步API端点失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"同步失败: {str(e)}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -4,15 +4,19 @@ from typing import Optional, Tuple
|
||||
import io
|
||||
import os
|
||||
import logging
|
||||
import shutil
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
class MinioClient:
|
||||
"""MinIO客户端类"""
|
||||
"""MinIO客户端类,支持本地存储作为备选"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化MinIO客户端"""
|
||||
self.local_storage_path = "data_storage/local_uploads"
|
||||
os.makedirs(self.local_storage_path, exist_ok=True)
|
||||
|
||||
try:
|
||||
self.client = Minio(
|
||||
settings.MINIO_ENDPOINT,
|
||||
@@ -21,16 +25,24 @@ class MinioClient:
|
||||
secure=settings.MINIO_SECURE
|
||||
)
|
||||
self.bucket_name = settings.MINIO_BUCKET_NAME
|
||||
self.is_connected = True # 先设置为True,这样在调用其他方法时不会报错
|
||||
|
||||
# 测试真实连接
|
||||
self.client.list_buckets()
|
||||
self.is_connected = True
|
||||
logging.info("MinIO connected successfully")
|
||||
|
||||
# 确保存储桶存在
|
||||
self._ensure_bucket_exists()
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to connect to MinIO: {e}. Running in offline mode.")
|
||||
logging.warning(f"Failed to connect to MinIO: {e}. Using local storage.")
|
||||
self.client = None
|
||||
self.bucket_name = settings.MINIO_BUCKET_NAME
|
||||
self.is_connected = False
|
||||
|
||||
def _get_local_path(self, object_name: str) -> str:
|
||||
"""获取本地存储路径"""
|
||||
return os.path.join(self.local_storage_path, object_name)
|
||||
|
||||
def _ensure_bucket_exists(self):
|
||||
"""确保存储桶存在"""
|
||||
if not self.is_connected:
|
||||
@@ -60,24 +72,32 @@ class MinioClient:
|
||||
return False
|
||||
|
||||
def upload_from_bytes(self, data: bytes, object_name: str) -> bool:
|
||||
"""从字节数据上传文件"""
|
||||
if not self.is_connected:
|
||||
logging.warning("MinIO is not connected. Upload skipped.")
|
||||
return False
|
||||
"""从字节数据上传文件,优先使用MinIO,失败则使用本地存储"""
|
||||
if self.is_connected:
|
||||
try:
|
||||
import io
|
||||
file_obj = io.BytesIO(data)
|
||||
self.client.put_object(
|
||||
self.bucket_name,
|
||||
object_name,
|
||||
file_obj,
|
||||
length=len(data),
|
||||
part_size=10*1024*1024
|
||||
)
|
||||
return True
|
||||
except S3Error as e:
|
||||
logging.warning(f"MinIO upload error: {e}, falling back to local storage")
|
||||
|
||||
# 使用本地存储作为备选
|
||||
try:
|
||||
import io
|
||||
file_obj = io.BytesIO(data)
|
||||
self.client.put_object(
|
||||
self.bucket_name,
|
||||
object_name,
|
||||
file_obj,
|
||||
length=len(data),
|
||||
part_size=10*1024*1024
|
||||
)
|
||||
local_path = self._get_local_path(object_name)
|
||||
os.makedirs(os.path.dirname(local_path), exist_ok=True)
|
||||
with open(local_path, 'wb') as f:
|
||||
f.write(data)
|
||||
logging.info(f"File saved to local storage: {local_path}")
|
||||
return True
|
||||
except S3Error as e:
|
||||
logging.warning(f"MinIO upload error: {e}")
|
||||
except Exception as e:
|
||||
logging.error(f"Local storage save error: {e}")
|
||||
return False
|
||||
|
||||
def upload_fileobj(self, file_obj: io.BytesIO, object_name: str, content_type: str = "application/octet-stream") -> bool:
|
||||
@@ -118,38 +138,54 @@ class MinioClient:
|
||||
return False
|
||||
|
||||
def get_object(self, object_name: str) -> Optional[bytes]:
|
||||
"""获取对象内容"""
|
||||
if not self.is_connected:
|
||||
logging.warning("MinIO is not connected. Get object skipped.")
|
||||
return None
|
||||
"""获取对象内容,优先使用MinIO,失败则使用本地存储"""
|
||||
if self.is_connected:
|
||||
try:
|
||||
response = self.client.get_object(
|
||||
self.bucket_name,
|
||||
object_name
|
||||
)
|
||||
data = response.read()
|
||||
response.close()
|
||||
response.release_conn()
|
||||
return data
|
||||
except S3Error as e:
|
||||
logging.warning(f"MinIO get object error: {e}, falling back to local storage")
|
||||
|
||||
# 使用本地存储作为备选
|
||||
try:
|
||||
response = self.client.get_object(
|
||||
self.bucket_name,
|
||||
object_name
|
||||
)
|
||||
data = response.read()
|
||||
response.close()
|
||||
response.release_conn()
|
||||
return data
|
||||
except S3Error as e:
|
||||
logging.warning(f"MinIO get object error: {e}")
|
||||
local_path = self._get_local_path(object_name)
|
||||
if os.path.exists(local_path):
|
||||
with open(local_path, 'rb') as f:
|
||||
return f.read()
|
||||
else:
|
||||
logging.warning(f"File not found in local storage: {local_path}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Local storage get error: {e}")
|
||||
return None
|
||||
|
||||
def delete_object(self, object_name: str) -> bool:
|
||||
"""删除对象"""
|
||||
if not self.is_connected:
|
||||
logging.warning("MinIO is not connected. Delete object skipped.")
|
||||
return False
|
||||
"""删除对象,优先使用MinIO,失败则使用本地存储"""
|
||||
if self.is_connected:
|
||||
try:
|
||||
self.client.remove_object(
|
||||
self.bucket_name,
|
||||
object_name
|
||||
)
|
||||
return True
|
||||
except S3Error as e:
|
||||
logging.warning(f"MinIO delete error: {e}")
|
||||
|
||||
# 使用本地存储作为备选
|
||||
try:
|
||||
self.client.remove_object(
|
||||
self.bucket_name,
|
||||
object_name
|
||||
)
|
||||
return True
|
||||
except S3Error as e:
|
||||
logging.warning(f"MinIO delete error: {e}")
|
||||
local_path = self._get_local_path(object_name)
|
||||
if os.path.exists(local_path):
|
||||
os.remove(local_path)
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.error(f"Local storage delete error: {e}")
|
||||
return False
|
||||
|
||||
def list_objects(self, prefix: str = "") -> list:
|
||||
|
||||
Reference in New Issue
Block a user