Files
algorithm/backend/app/services/comparison_service.py

166 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Dict, Any, List
import asyncio
import httpx
import logging
logger = logging.getLogger(__name__)
class ComparisonService:
    """Effect-comparison service.

    Sends the same input to several algorithm services concurrently and
    summarises how each of them performed.
    """

    # Per-request HTTP timeout in seconds, used when a config has no "timeout".
    DEFAULT_TIMEOUT = 30.0

    async def compare_algorithms(
        self,
        input_data: Dict[str, Any],
        algorithm_configs: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Run several algorithms on the same input and collect their results.

        Args:
            input_data: Input payload forwarded to every algorithm.
            algorithm_configs: One config dict per algorithm. Keys read by this
                class: "id", "name", "url", "params" and the optional
                "timeout" (seconds).

        Returns:
            On success ``{"success": True, "results": [...], "input_data": ...}``
            with one entry per config, in config order, each holding
            "algorithm_id", "algorithm_name", "success", "error", "output" and
            "execution_time". A failing algorithm only marks its own entry as
            failed. If the comparison itself cannot run, returns
            ``{"success": False, "error": <message>, "results": []}``.
        """
        try:
            # Materialise once: the configs are iterated for the requests and
            # again when pairing them with their results.
            configs = list(algorithm_configs)
            # return_exceptions=True: one algorithm failing must not abort the
            # others; its exception object is returned in its result slot.
            results = await asyncio.gather(
                *(self._execute_algorithm(config, input_data) for config in configs),
                return_exceptions=True,
            )
            # gather preserves input order, so results line up with configs.
            comparison_results = [
                self._build_result_entry(config, result)
                for config, result in zip(configs, results)
            ]
            return {
                "success": True,
                "results": comparison_results,
                "input_data": input_data
            }
        except Exception as e:
            logger.exception("Comparison error: %s", e)
            return {
                "success": False,
                "error": str(e),
                "results": []
            }

    @staticmethod
    def _build_result_entry(config: Dict[str, Any], result: Any) -> Dict[str, Any]:
        """Convert one ``asyncio.gather`` outcome into a comparison result entry."""
        entry = {
            "algorithm_id": config.get("id"),
            "algorithm_name": config.get("name"),
        }
        # Check BaseException, not Exception: a cancelled sub-task is returned
        # as asyncio.CancelledError, which is not an Exception subclass on
        # Python 3.8+ and would otherwise be treated as a successful dict.
        if isinstance(result, BaseException):
            entry.update({
                "success": False,
                # Some exceptions (e.g. CancelledError) have an empty message;
                # fall back to the class name so the error is never blank.
                "error": str(result) or type(result).__name__,
                "output": None,
                "execution_time": 0,
            })
        else:
            entry.update({
                "success": True,
                "error": None,
                "output": result.get("output"),
                "execution_time": result.get("execution_time", 0),
            })
        return entry

    @staticmethod
    def _build_predict_url(base_url: Any) -> str:
        """Return the ``/predict`` endpoint for *base_url* without a double slash."""
        return f"{str(base_url).rstrip('/')}/predict"

    async def _execute_algorithm(
        self,
        config: Dict[str, Any],
        input_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Call one algorithm service's ``/predict`` endpoint.

        Args:
            config: Algorithm config. "url" is required; "params" (default
                ``{}``) is forwarded to the service; "timeout" (seconds,
                default ``DEFAULT_TIMEOUT``) bounds the HTTP request.
            input_data: Input payload. If it has an "input_data" key, that
                value is sent; otherwise the whole dict is sent.

        Returns:
            ``{"output": <decoded JSON response>, "execution_time": <seconds>}``.

        Raises:
            ValueError: If the config has no URL.
            Exception: Any HTTP, status or JSON-decoding error is logged and
                re-raised unchanged.
        """
        import time

        url = None
        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by wall-clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        try:
            url = config.get("url")
            if not url:
                raise ValueError("缺少算法服务URL")
            request_data = {
                "input_data": input_data.get("input_data", input_data),
                "params": config.get("params", {}),
            }
            timeout = config.get("timeout", self.DEFAULT_TIMEOUT)
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.post(self._build_predict_url(url), json=request_data)
                response.raise_for_status()
                result = response.json()
                execution_time = time.perf_counter() - start_time
            return {
                "output": result,
                "execution_time": execution_time
            }
        except Exception as e:
            logger.error("Algorithm execution error (url=%s): %s", url, e)
            # Bare raise re-raises the same exception with its original traceback.
            raise

    def generate_comparison_report(self, comparison_results: Dict[str, Any]) -> Dict[str, Any]:
        """Summarise the output of ``compare_algorithms`` into a report.

        Args:
            comparison_results: The dict returned by ``compare_algorithms``.

        Returns:
            ``{"success": True, "report": {"summary": ..., "details": ...,
            "input_data": ...}}``. The summary holds the total, successful and
            failed counts, plus the average execution time of the successful
            algorithms rounded to 2 decimals (0 when none succeeded). Returns
            ``{"success": False, "error": ...}`` if the comparison failed or
            the report could not be built.
        """
        try:
            if not comparison_results.get("success"):
                return {
                    "success": False,
                    "error": comparison_results.get("error", "对比失败")
                }
            results = comparison_results.get("results", [])
            successful_algorithms = [r for r in results if r.get("success")]
            failed_count = len(results) - len(successful_algorithms)
            if successful_algorithms:
                avg_execution_time = (
                    sum(r.get("execution_time", 0) for r in successful_algorithms)
                    / len(successful_algorithms)
                )
            else:
                avg_execution_time = 0
            report = {
                "summary": {
                    "total_algorithms": len(results),
                    "successful_algorithms": len(successful_algorithms),
                    "failed_algorithms": failed_count,
                    "average_execution_time": round(avg_execution_time, 2)
                },
                "details": results,
                "input_data": comparison_results.get("input_data")
            }
            return {
                "success": True,
                "report": report
            }
        except Exception as e:
            logger.exception("Report generation error: %s", e)
            return {
                "success": False,
                "error": str(e)
            }