good version for 算法注册
This commit is contained in:
108
services/image-recognition/main.py
Normal file
108
services/image-recognition/main.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File
|
||||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
import json
|
||||
import logging
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from .ai_algorithm import ImageRecognizer
|
||||
from .config import settings
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 初始化FastAPI应用
|
||||
app = FastAPI(
|
||||
title="图像识别服务",
|
||||
description="提供图像识别功能的AI服务",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# 初始化识别器
|
||||
recognizer = ImageRecognizer()
|
||||
|
||||
# 定义请求模型
|
||||
class PredictRequest(BaseModel):
|
||||
input_data: list
|
||||
params: dict = {}
|
||||
|
||||
# 定义响应模型
|
||||
class PredictResponse(BaseModel):
|
||||
predictions: list
|
||||
status: str
|
||||
|
||||
@app.post("/predict", response_model=PredictResponse)
|
||||
async def predict(request: PredictRequest):
|
||||
"""算法预测接口"""
|
||||
try:
|
||||
logger.info(f"Received prediction request for {len(request.input_data)} images")
|
||||
predictions = recognizer.recognize(request.input_data, request.params)
|
||||
logger.info(f"Prediction completed: {predictions}")
|
||||
return PredictResponse(
|
||||
predictions=predictions,
|
||||
status="success"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Prediction error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/predict/file")
|
||||
async def predict_file(file: UploadFile = File(...)):
|
||||
"""通过文件上传进行预测"""
|
||||
try:
|
||||
logger.info(f"Received file upload: {file.filename}")
|
||||
|
||||
# 读取文件内容
|
||||
contents = await file.read()
|
||||
|
||||
# 转换为base64
|
||||
image_base64 = base64.b64encode(contents).decode('utf-8')
|
||||
|
||||
# 调用识别器
|
||||
predictions = recognizer.recognize([image_base64])
|
||||
|
||||
logger.info(f"File prediction completed: {predictions}")
|
||||
return {
|
||||
"predictions": predictions,
|
||||
"status": "success",
|
||||
"filename": file.filename
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"File prediction error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查接口"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "image-recognition",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
@app.get("/info")
|
||||
async def service_info():
|
||||
"""服务信息接口"""
|
||||
return {
|
||||
"name": "图像识别服务",
|
||||
"description": "提供图像识别功能的AI服务",
|
||||
"version": "1.0.0",
|
||||
"endpoints": {
|
||||
"/predict": "POST - 图像识别预测",
|
||||
"/predict/file": "POST - 通过文件上传进行预测",
|
||||
"/health": "GET - 健康检查",
|
||||
"/info": "GET - 服务信息"
|
||||
}
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"main:app",
|
||||
host=settings.HOST,
|
||||
port=settings.PORT,
|
||||
reload=settings.DEBUG
|
||||
)
|
||||
Reference in New Issue
Block a user