Files
algorithm/backend/create_sample_algorithms.py

151 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""创建示例算法数据"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from app.models.database import SessionLocal
from app.models.models import Algorithm, AlgorithmVersion
from datetime import datetime
import uuid
def create_sample_algorithms():
"""创建示例算法"""
db = SessionLocal()
try:
# 示例算法数据
algorithms_data = [
{
"name": "目标检测",
"description": "识别图像中的物体位置和类别,支持人脸、车辆、物品等多种目标检测",
"type": "computer_vision",
"tech_category": "computer_vision",
"output_type": "image",
"versions": [
{
"version": "1.0.0",
"url": "http://0.0.0.0:8001",
"is_default": True
}
]
},
{
"name": "视频分析",
"description": "分析视频内容,提取关键帧、识别动作、追踪物体等",
"type": "computer_vision",
"tech_category": "video_processing",
"output_type": "video",
"versions": [
{
"version": "1.0.0",
"url": "http://0.0.0.0:8002",
"is_default": True
}
]
},
{
"name": "图像增强",
"description": "提升图像质量,包括去噪、超分辨率、色彩校正等功能",
"type": "computer_vision",
"tech_category": "computer_vision",
"output_type": "image",
"versions": [
{
"version": "1.0.0",
"url": "http://0.0.0.0:8003",
"is_default": True
}
]
},
{
"name": "文本分类",
"description": "对文本内容进行分类,支持新闻分类、情感分析、垃圾邮件识别等",
"type": "nlp",
"tech_category": "nlp",
"output_type": "text",
"versions": [
{
"version": "1.0.0",
"url": "http://0.0.0.0:8004",
"is_default": True
}
]
},
{
"name": "异常检测",
"description": "检测数据中的异常模式,适用于工业监控、金融风控等场景",
"type": "ml",
"tech_category": "ml",
"output_type": "json",
"versions": [
{
"version": "1.0.0",
"url": "http://0.0.0.0:8005",
"is_default": True
}
]
},
{
"name": "医学影像分析",
"description": "分析医学影像辅助医生进行疾病诊断支持CT、MRI等多种影像格式",
"type": "medical",
"tech_category": "computer_vision",
"output_type": "image",
"versions": [
{
"version": "1.0.0",
"url": "http://0.0.0.0:8006",
"is_default": True
}
]
}
]
# 创建算法
for algo_data in algorithms_data:
# 检查算法是否已存在
existing_algo = db.query(Algorithm).filter(Algorithm.name == algo_data["name"]).first()
if existing_algo:
print(f"✓ 算法 '{algo_data['name']}' 已存在,跳过")
continue
# 创建算法
algorithm = Algorithm(
id=str(uuid.uuid4()),
name=algo_data["name"],
description=algo_data["description"],
type=algo_data["type"],
tech_category=algo_data["tech_category"],
output_type=algo_data["output_type"],
status="active"
)
db.add(algorithm)
db.flush() # 获取算法ID
# 创建版本
for version_data in algo_data["versions"]:
version = AlgorithmVersion(
id=str(uuid.uuid4()),
algorithm_id=algorithm.id,
version=version_data["version"],
url=version_data["url"],
is_default=version_data["is_default"]
)
db.add(version)
print(f"✓ 已创建算法: {algo_data['name']}")
db.commit()
print("\n示例算法创建完成!")
except Exception as e:
db.rollback()
print(f"创建示例算法失败: {e}")
sys.exit(1)
finally:
db.close()
if __name__ == "__main__":
create_sample_algorithms()