# Source-viewer listing metadata: 345 lines, 11 KiB, Python
"""Data management routes: manage input data, output results, and metadata."""
import json
import os
from datetime import datetime, timezone
from typing import List, Dict, Any, Optional

from fastapi import APIRouter, HTTPException, status, Depends, UploadFile, File
from pydantic import BaseModel
from sqlalchemy.orm import Session

from app.services.data_manager import data_manager
from app.models.database import get_db
from app.dependencies import get_current_active_user

# Router for every data-management endpoint in this module: routes are
# mounted under "/data" and tagged "data-management" in the OpenAPI schema.
router = APIRouter(
    tags=["data-management"],
    prefix="/data",
)


class SaveInputDataRequest(BaseModel):
    """Request body for saving an algorithm's input data."""

    # Identifier of the algorithm this input is associated with.
    algorithm_id: str
    # The input payload as a JSON object; passed unchanged to
    # data_manager.save_input_data by the POST /data/input endpoint.
    input_data: Dict[str, Any]


class SaveOutputDataRequest(BaseModel):
    """Request body for saving the output of an algorithm call."""

    # Identifier of the algorithm that produced the output.
    algorithm_id: str
    # Identifier of the call that produced this output
    # (NOTE(review): presumably an AlgorithmCall id — confirm).
    call_id: str
    # The output payload as a JSON object; passed unchanged to
    # data_manager.save_output_data by the POST /data/output endpoint.
    output_data: Dict[str, Any]


class GetDataFilters(BaseModel):
    """Filter criteria for metadata search (request body of POST /data/search)."""

    # Restrict results to one user. Non-admin callers may only pass their own
    # id; the search endpoint fills it in when omitted.
    user_id: Optional[str] = None
    # Restrict results to one algorithm.
    algorithm_id: Optional[str] = None
    # Lower/upper date bounds as strings; interpretation is up to
    # data_manager (NOTE(review): confirm the expected format, e.g. ISO 8601).
    date_from: Optional[str] = None
    date_to: Optional[str] = None
    # Requested maximum number of results.
    limit: int = 100


@router.post("/input")
async def save_input_data(
    request: SaveInputDataRequest,
    current_user: dict = Depends(get_current_active_user)
):
    """Save an algorithm input payload on behalf of the current user.

    Args:
        request: Body carrying ``algorithm_id`` and the ``input_data`` payload.
        current_user: Authenticated user resolved by ``get_current_active_user``.

    Returns:
        A dict with ``success``, the new ``data_id`` and a ``message``.

    Raises:
        HTTPException: 403 if the caller's role is neither "admin" nor "user";
            500 if ``data_manager.save_input_data`` returns a falsy id.
    """
    # Only the role needs checking. The data is always stored under the
    # caller's own id below, and SaveInputDataRequest has no user_id field
    # (reading request.user_id would raise AttributeError).
    if current_user.get("role") not in ["admin", "user"]:
        raise HTTPException(status_code=403, detail="Insufficient permissions")

    data_id = data_manager.save_input_data(
        user_id=current_user.get("id"),
        algorithm_id=request.algorithm_id,
        input_data=request.input_data
    )

    if not data_id:
        raise HTTPException(status_code=500, detail="Failed to save input data")

    return {
        "success": True,
        "data_id": data_id,
        "message": "Input data saved successfully"
    }


@router.post("/output")
async def save_output_data(
    request: SaveOutputDataRequest,
    current_user: dict = Depends(get_current_active_user)
):
    """Save the output of an algorithm call on behalf of the current user.

    Args:
        request: Body carrying ``algorithm_id``, ``call_id`` and ``output_data``.
        current_user: Authenticated user resolved by ``get_current_active_user``.

    Returns:
        A dict with ``success``, the new ``data_id`` and a ``message``.

    Raises:
        HTTPException: 403 if the caller's role is neither "admin" nor "user";
            500 if ``data_manager.save_output_data`` returns a falsy id.
    """
    permitted_roles = ["admin", "user"]
    if current_user.get("role") not in permitted_roles:
        raise HTTPException(status_code=403, detail="Insufficient permissions")

    saved_id = data_manager.save_output_data(
        user_id=current_user.get("id"),
        algorithm_id=request.algorithm_id,
        call_id=request.call_id,
        output_data=request.output_data,
    )
    if not saved_id:
        raise HTTPException(status_code=500, detail="Failed to save output data")

    return {
        "success": True,
        "data_id": saved_id,
        "message": "Output data saved successfully",
    }


@router.get("/input/{data_id}")
async def get_input_data(
    data_id: str,
    current_user: dict = Depends(get_current_active_user)
):
    """Fetch one stored input record by id.

    Admins may read any record; other users only records whose ``user_id``
    matches their own id.

    Raises:
        HTTPException: 404 if no record is found; 403 if the caller is neither
            an admin nor the record's owner.
    """
    record = data_manager.get_input_data(data_id)
    if not record:
        raise HTTPException(status_code=404, detail="Input data not found")

    # Admin short-circuits the ownership comparison.
    if current_user.get("role") == "admin" or record.get("user_id") == current_user.get("id"):
        return record

    raise HTTPException(status_code=403, detail="Insufficient permissions")


@router.get("/output/{data_id}")
async def get_output_data(
    data_id: str,
    current_user: dict = Depends(get_current_active_user)
):
    """Fetch one stored output record by id.

    Admins may read any record; other users only records whose ``user_id``
    matches their own id.

    Raises:
        HTTPException: 404 if no record is found; 403 if the caller is neither
            an admin nor the record's owner.
    """
    record = data_manager.get_output_data(data_id)
    if not record:
        raise HTTPException(status_code=404, detail="Output data not found")

    # Admin short-circuits the ownership comparison.
    if current_user.get("role") == "admin" or record.get("user_id") == current_user.get("id"):
        return record

    raise HTTPException(status_code=403, detail="Insufficient permissions")


@router.get("/inputs/user")
async def get_user_inputs(
    algorithm_id: Optional[str] = None,
    limit: int = 100,
    current_user: dict = Depends(get_current_active_user)
):
    """List the current user's historical input records.

    Results are always scoped to the caller's own id, so any authenticated
    active user may call this endpoint.

    Args:
        algorithm_id: Optional filter on the algorithm the inputs belong to.
        limit: Requested maximum number of records; clamped to [0, 1000].
        current_user: Authenticated user resolved by ``get_current_active_user``.

    Returns:
        A dict with ``inputs`` (what ``data_manager.get_user_inputs`` returns)
        and their ``count``.
    """
    # Clamp both ends. The upper bound caps the response size. The lower bound
    # keeps a negative value from reaching data_manager, where (depending on
    # how it applies the limit, e.g. a slice or SQL LIMIT) it could lift the cap.
    effective_limit = max(0, min(limit, 1000))

    inputs = data_manager.get_user_inputs(
        user_id=current_user.get("id"),
        algorithm_id=algorithm_id,
        limit=effective_limit
    )

    return {
        "inputs": inputs,
        "count": len(inputs)
    }


@router.get("/outputs/user")
async def get_user_outputs(
    algorithm_id: Optional[str] = None,
    limit: int = 100,
    current_user: dict = Depends(get_current_active_user)
):
    """List the current user's historical output records.

    Results are always scoped to the caller's own id, so any authenticated
    active user may call this endpoint.

    Args:
        algorithm_id: Optional filter on the algorithm the outputs belong to.
        limit: Requested maximum number of records; clamped to [0, 1000].
        current_user: Authenticated user resolved by ``get_current_active_user``.

    Returns:
        A dict with ``outputs`` (what ``data_manager.get_user_outputs``
        returns) and their ``count``.
    """
    # Clamp both ends. The upper bound caps the response size. The lower bound
    # keeps a negative value from reaching data_manager, where (depending on
    # how it applies the limit, e.g. a slice or SQL LIMIT) it could lift the cap.
    effective_limit = max(0, min(limit, 1000))

    outputs = data_manager.get_user_outputs(
        user_id=current_user.get("id"),
        algorithm_id=algorithm_id,
        limit=effective_limit
    )

    return {
        "outputs": outputs,
        "count": len(outputs)
    }


@router.post("/media/upload")
async def upload_media_file(
    file: UploadFile = File(...),
    algorithm_id: Optional[str] = None,
    current_user: dict = Depends(get_current_active_user)
):
    """Upload a media file (e.g. an image or video) for the current user.

    Args:
        file: The uploaded file; its whole content is read into memory.
        algorithm_id: Algorithm the file belongs to (required).
        current_user: Authenticated user resolved by ``get_current_active_user``.

    Returns:
        A dict with ``success``, the ``file_path`` reported by data_manager,
        the stored ``filename``, the ``size`` in bytes and a ``message``.

    Raises:
        HTTPException: 400 if ``algorithm_id`` is missing or the file name is
            empty or not usable; 500 if ``data_manager.save_media_file``
            returns a falsy path.
    """
    if not algorithm_id:
        raise HTTPException(status_code=400, detail="algorithm_id is required")

    # The file name comes from the client. Keep only its last path component
    # (treating "\" as a separator too) so a name like "../../x" cannot steer
    # the storage location, in case data_manager builds a filesystem path from it.
    # NOTE(review): algorithm_id is also client-controlled; confirm data_manager
    # does not use it unsanitized in a path.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="A valid file name is required")

    file_content = await file.read()

    file_path = data_manager.save_media_file(
        user_id=current_user.get("id"),
        algorithm_id=algorithm_id,
        file_content=file_content,
        file_name=safe_name
    )

    if not file_path:
        raise HTTPException(status_code=500, detail="Failed to upload media file")

    return {
        "success": True,
        "file_path": file_path,
        "filename": safe_name,
        "size": len(file_content),
        "message": "Media file uploaded successfully"
    }


@router.get("/media/{file_path:path}")
async def get_media_file(
    file_path: str,
    current_user: dict = Depends(get_current_active_user)
):
    """Return the raw bytes of a stored media file.

    Admins may read any non-traversing path. Other users may only read paths
    that start with ``media/<their id>/``; there is no public-file exception.

    Args:
        file_path: Storage path captured from the URL (may contain "/").
        current_user: Authenticated user resolved by ``get_current_active_user``.

    Returns:
        A ``Response`` with the file bytes and a content type guessed from the
        file extension (``application/octet-stream`` when unknown).

    Raises:
        HTTPException: 400 if the path is absolute or contains a ".." segment;
            403 if a non-admin requests a path outside their ``media/<id>/``
            prefix; 404 if ``data_manager.get_media_file`` returns no content.
    """
    import mimetypes

    from fastapi.responses import Response

    # A bare startswith() check is not enough: "media/<id>/../<other>/f.png"
    # carries the caller's prefix but names another user's directory if the
    # storage layer resolves "..". Reject absolute paths and any ".." segment
    # for every caller, treating backslashes as separators as well.
    normalized = file_path.replace("\\", "/")
    if normalized.startswith("/") or ".." in normalized.split("/"):
        raise HTTPException(status_code=400, detail="Invalid file path")

    own_prefix = f"media/{current_user.get('id')}/"
    if current_user.get("role") != "admin" and not file_path.startswith(own_prefix):
        raise HTTPException(status_code=403, detail="Insufficient permissions")

    content = data_manager.get_media_file(file_path)
    # NOTE(review): any falsy result (None, but also an empty file) maps to 404.
    if not content:
        raise HTTPException(status_code=404, detail="Media file not found")

    content_type, _ = mimetypes.guess_type(file_path)
    return Response(
        content=content,
        media_type=content_type or "application/octet-stream"
    )


@router.post("/snapshots/create")
async def create_data_snapshot(
    call_id: str,
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Create a data snapshot for one algorithm call.

    Args:
        call_id: Id of the AlgorithmCall record to snapshot.
        current_user: Authenticated user resolved by ``get_current_active_user``.
        db: Database session provided by ``get_db``.

    Returns:
        A dict with ``success``, the ``snapshot`` produced by data_manager and
        a ``message``.

    Raises:
        HTTPException: 404 if the call record does not exist; 403 if a
            non-admin does not own it; 500 if the snapshot is falsy.
    """
    # Imported inside the function, as in the original module layout.
    from app.models.models import AlgorithmCall

    record = (
        db.query(AlgorithmCall)
        .filter(AlgorithmCall.id == call_id)
        .first()
    )
    if not record:
        raise HTTPException(status_code=404, detail="Call record not found")

    if current_user.get("role") != "admin":
        if record.user_id != current_user.get("id"):
            raise HTTPException(status_code=403, detail="Insufficient permissions")

    snapshot = data_manager.create_data_snapshot(record)
    if not snapshot:
        raise HTTPException(status_code=500, detail="Failed to create data snapshot")

    return {
        "success": True,
        "snapshot": snapshot,
        "message": "Data snapshot created successfully",
    }


@router.post("/search")
async def search_data_by_metadata(
    filters: GetDataFilters,
    current_user: dict = Depends(get_current_active_user)
):
    """Search stored data by metadata filters.

    Admins may search any user's data, or all data when ``user_id`` is
    omitted. Non-admins are restricted to their own data: an explicit foreign
    ``user_id`` is rejected and a missing one defaults to the caller's id.

    Args:
        filters: Search criteria; ``limit`` is clamped to [0, 1000] to match
            the other listing endpoints.
        current_user: Authenticated user resolved by ``get_current_active_user``.

    Returns:
        A dict with ``results`` (what ``data_manager.search_data_by_metadata``
        returns) and their ``count``.

    Raises:
        HTTPException: 403 if a non-admin filters on another user's id.
    """
    if current_user.get("role") != "admin":
        caller_id = current_user.get("id")
        if filters.user_id and filters.user_id != caller_id:
            raise HTTPException(status_code=403, detail="Insufficient permissions")
        if not filters.user_id:
            filters.user_id = caller_id

    criteria = filters.dict()
    # The model default is 100 but nothing bounded it; apply the same
    # [0, 1000] clamp as the other listing endpoints so a client cannot request
    # an unbounded (or negative) result size.
    criteria["limit"] = max(0, min(filters.limit, 1000))

    results = data_manager.search_data_by_metadata(criteria)

    return {
        "results": results,
        "count": len(results)
    }


@router.delete("/user-data")
async def delete_user_data(
    current_user: dict = Depends(get_current_active_user)
):
    """Delete all data belonging to the current user.

    Only the caller's own data is targeted (their id is passed to
    ``data_manager.delete_user_data``), so no role check is needed.

    Returns:
        A dict with ``success`` and a ``message``.

    Raises:
        HTTPException: 500 if ``data_manager.delete_user_data`` reports failure.
    """
    success = data_manager.delete_user_data(current_user.get("id"))

    if not success:
        raise HTTPException(status_code=500, detail="Failed to delete user data")

    return {
        "success": True,
        "message": "User data deleted successfully"
    }


@router.get("/statistics")
async def get_data_statistics(
    current_user: dict = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """Return algorithm-call statistics.

    Every caller gets their own call count. Admins additionally get the total
    number of calls across all users. ``total_users`` and the storage figures
    are placeholders that are not computed yet.

    Args:
        current_user: Authenticated user resolved by ``get_current_active_user``.
        db: Database session provided by ``get_db``. Injecting it lets FastAPI
            finish the dependency after the request, whereas ``next(get_db())``
            left the generator suspended so any cleanup after its ``yield``
            never ran.

    Returns:
        A stats dict; ``timestamp`` is the UTC generation time in ISO 8601.
    """
    # Local imports, as in the original (presumably to avoid import cycles).
    from sqlalchemy import func

    from app.models.models import AlgorithmCall

    user_calls = (
        db.query(func.count(AlgorithmCall.id))
        .filter(AlgorithmCall.user_id == current_user.get("id"))
        .scalar()
    )
    generated_at = datetime.now(timezone.utc).isoformat()

    if current_user.get("role") != "admin":
        return {
            "user_calls": user_calls,
            "storage_used_by_user": "N/A",  # placeholder: not read from storage yet
            "timestamp": generated_at
        }

    # The global count is only visible to admins, so only query it for them.
    total_calls = db.query(func.count(AlgorithmCall.id)).scalar()
    return {
        "total_calls": total_calls,
        "user_calls": user_calls,
        "total_users": 0,  # placeholder: not counted from the user table yet
        "storage_used": "N/A",  # placeholder: not read from storage yet
        "timestamp": generated_at
    }