good version for 算法注册
This commit is contained in:
165
backend/app/services/comparison_service.py
Normal file
165
backend/app/services/comparison_service.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from typing import Dict, Any, List
|
||||
import asyncio
|
||||
import httpx
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ComparisonService:
|
||||
"""效果对比服务"""
|
||||
|
||||
async def compare_algorithms(
|
||||
self,
|
||||
input_data: Dict[str, Any],
|
||||
algorithm_configs: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""比较多个算法的效果
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
algorithm_configs: 算法配置列表,每个配置包含服务URL、参数等
|
||||
|
||||
Returns:
|
||||
对比结果
|
||||
"""
|
||||
try:
|
||||
# 异步执行所有算法
|
||||
tasks = []
|
||||
for config in algorithm_configs:
|
||||
task = self._execute_algorithm(config, input_data)
|
||||
tasks.append(task)
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 处理结果
|
||||
comparison_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
comparison_results.append({
|
||||
"algorithm_id": algorithm_configs[i].get("id"),
|
||||
"algorithm_name": algorithm_configs[i].get("name"),
|
||||
"success": False,
|
||||
"error": str(result),
|
||||
"output": None,
|
||||
"execution_time": 0
|
||||
})
|
||||
else:
|
||||
comparison_results.append({
|
||||
"algorithm_id": algorithm_configs[i].get("id"),
|
||||
"algorithm_name": algorithm_configs[i].get("name"),
|
||||
"success": True,
|
||||
"error": None,
|
||||
"output": result.get("output"),
|
||||
"execution_time": result.get("execution_time", 0)
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"results": comparison_results,
|
||||
"input_data": input_data
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Comparison error: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"results": []
|
||||
}
|
||||
|
||||
async def _execute_algorithm(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
input_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""执行单个算法
|
||||
|
||||
Args:
|
||||
config: 算法配置
|
||||
input_data: 输入数据
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
url = config.get("url")
|
||||
params = config.get("params", {})
|
||||
|
||||
if not url:
|
||||
raise ValueError("缺少算法服务URL")
|
||||
|
||||
# 构建请求数据
|
||||
request_data = {
|
||||
"input_data": input_data.get("input_data", input_data),
|
||||
"params": params
|
||||
}
|
||||
|
||||
# 发送请求
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(f"{url}/predict", json=request_data)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
"output": result,
|
||||
"execution_time": execution_time
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Algorithm execution error: {str(e)}")
|
||||
raise e
|
||||
|
||||
def generate_comparison_report(self, comparison_results: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""生成对比报告
|
||||
|
||||
Args:
|
||||
comparison_results: 对比结果
|
||||
|
||||
Returns:
|
||||
对比报告
|
||||
"""
|
||||
try:
|
||||
if not comparison_results.get("success"):
|
||||
return {
|
||||
"success": False,
|
||||
"error": comparison_results.get("error", "对比失败")
|
||||
}
|
||||
|
||||
results = comparison_results.get("results", [])
|
||||
|
||||
# 分析结果
|
||||
successful_algorithms = [r for r in results if r.get("success")]
|
||||
failed_algorithms = [r for r in results if not r.get("success")]
|
||||
|
||||
# 计算平均执行时间
|
||||
if successful_algorithms:
|
||||
avg_execution_time = sum(r.get("execution_time", 0) for r in successful_algorithms) / len(successful_algorithms)
|
||||
else:
|
||||
avg_execution_time = 0
|
||||
|
||||
# 生成报告
|
||||
report = {
|
||||
"summary": {
|
||||
"total_algorithms": len(results),
|
||||
"successful_algorithms": len(successful_algorithms),
|
||||
"failed_algorithms": len(failed_algorithms),
|
||||
"average_execution_time": round(avg_execution_time, 2)
|
||||
},
|
||||
"details": results,
|
||||
"input_data": comparison_results.get("input_data")
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"report": report
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Report generation error: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
Reference in New Issue
Block a user