151 lines
5.2 KiB
Python
151 lines
5.2 KiB
Python
#!/usr/bin/env python3
|
||
"""创建示例算法数据"""
|
||
|
||
import sys
|
||
import os
|
||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||
|
||
from app.models.database import SessionLocal
|
||
from app.models.models import Algorithm, AlgorithmVersion
|
||
from datetime import datetime
|
||
import uuid
|
||
|
||
def create_sample_algorithms():
|
||
"""创建示例算法"""
|
||
db = SessionLocal()
|
||
|
||
try:
|
||
# 示例算法数据
|
||
algorithms_data = [
|
||
{
|
||
"name": "目标检测",
|
||
"description": "识别图像中的物体位置和类别,支持人脸、车辆、物品等多种目标检测",
|
||
"type": "computer_vision",
|
||
"tech_category": "computer_vision",
|
||
"output_type": "image",
|
||
"versions": [
|
||
{
|
||
"version": "1.0.0",
|
||
"url": "http://0.0.0.0:8001",
|
||
"is_default": True
|
||
}
|
||
]
|
||
},
|
||
{
|
||
"name": "视频分析",
|
||
"description": "分析视频内容,提取关键帧、识别动作、追踪物体等",
|
||
"type": "computer_vision",
|
||
"tech_category": "video_processing",
|
||
"output_type": "video",
|
||
"versions": [
|
||
{
|
||
"version": "1.0.0",
|
||
"url": "http://0.0.0.0:8002",
|
||
"is_default": True
|
||
}
|
||
]
|
||
},
|
||
{
|
||
"name": "图像增强",
|
||
"description": "提升图像质量,包括去噪、超分辨率、色彩校正等功能",
|
||
"type": "computer_vision",
|
||
"tech_category": "computer_vision",
|
||
"output_type": "image",
|
||
"versions": [
|
||
{
|
||
"version": "1.0.0",
|
||
"url": "http://0.0.0.0:8003",
|
||
"is_default": True
|
||
}
|
||
]
|
||
},
|
||
{
|
||
"name": "文本分类",
|
||
"description": "对文本内容进行分类,支持新闻分类、情感分析、垃圾邮件识别等",
|
||
"type": "nlp",
|
||
"tech_category": "nlp",
|
||
"output_type": "text",
|
||
"versions": [
|
||
{
|
||
"version": "1.0.0",
|
||
"url": "http://0.0.0.0:8004",
|
||
"is_default": True
|
||
}
|
||
]
|
||
},
|
||
{
|
||
"name": "异常检测",
|
||
"description": "检测数据中的异常模式,适用于工业监控、金融风控等场景",
|
||
"type": "ml",
|
||
"tech_category": "ml",
|
||
"output_type": "json",
|
||
"versions": [
|
||
{
|
||
"version": "1.0.0",
|
||
"url": "http://0.0.0.0:8005",
|
||
"is_default": True
|
||
}
|
||
]
|
||
},
|
||
{
|
||
"name": "医学影像分析",
|
||
"description": "分析医学影像,辅助医生进行疾病诊断,支持CT、MRI等多种影像格式",
|
||
"type": "medical",
|
||
"tech_category": "computer_vision",
|
||
"output_type": "image",
|
||
"versions": [
|
||
{
|
||
"version": "1.0.0",
|
||
"url": "http://0.0.0.0:8006",
|
||
"is_default": True
|
||
}
|
||
]
|
||
}
|
||
]
|
||
|
||
# 创建算法
|
||
for algo_data in algorithms_data:
|
||
# 检查算法是否已存在
|
||
existing_algo = db.query(Algorithm).filter(Algorithm.name == algo_data["name"]).first()
|
||
if existing_algo:
|
||
print(f"✓ 算法 '{algo_data['name']}' 已存在,跳过")
|
||
continue
|
||
|
||
# 创建算法
|
||
algorithm = Algorithm(
|
||
id=str(uuid.uuid4()),
|
||
name=algo_data["name"],
|
||
description=algo_data["description"],
|
||
type=algo_data["type"],
|
||
tech_category=algo_data["tech_category"],
|
||
output_type=algo_data["output_type"],
|
||
status="active"
|
||
)
|
||
db.add(algorithm)
|
||
db.flush() # 获取算法ID
|
||
|
||
# 创建版本
|
||
for version_data in algo_data["versions"]:
|
||
version = AlgorithmVersion(
|
||
id=str(uuid.uuid4()),
|
||
algorithm_id=algorithm.id,
|
||
version=version_data["version"],
|
||
url=version_data["url"],
|
||
is_default=version_data["is_default"]
|
||
)
|
||
db.add(version)
|
||
|
||
print(f"✓ 已创建算法: {algo_data['name']}")
|
||
|
||
db.commit()
|
||
print("\n示例算法创建完成!")
|
||
|
||
except Exception as e:
|
||
db.rollback()
|
||
print(f"创建示例算法失败: {e}")
|
||
sys.exit(1)
|
||
finally:
|
||
db.close()
|
||
|
||
if __name__ == "__main__":
|
||
create_sample_algorithms() |