"""Image recognition HTTP service built on FastAPI.

Exposes two prediction endpoints backed by a shared ``ImageRecognizer`` plus
health-check and service-info endpoints.
"""

import base64
import json
import logging
import threading
from io import BytesIO

import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field

from .ai_algorithm import ImageRecognizer
from .config import settings

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Service metadata shared by the FastAPI app and the /health and /info
# endpoints, so the advertised values cannot drift apart.
SERVICE_TITLE = "图像识别服务"
SERVICE_DESCRIPTION = "提供图像识别功能的AI服务"
SERVICE_VERSION = "1.0.0"

# Initialize the FastAPI application
app = FastAPI(
    title=SERVICE_TITLE,
    description=SERVICE_DESCRIPTION,
    version=SERVICE_VERSION,
)

# Shared recognizer instance, created once when this module is imported.
recognizer = ImageRecognizer()

# recognize() is synchronous. Calling it directly inside an ``async def``
# endpoint blocks the event loop, which also stalls /health. Such calls are
# therefore pushed to a worker thread. This lock keeps them one-at-a-time, as
# they effectively were when they ran on the loop. Nothing visible here shows
# ImageRecognizer is safe to use from several threads concurrently.
_recognize_lock = threading.Lock()


def _recognize_blocking(*args):
    """Call ``recognizer.recognize(*args)`` while holding ``_recognize_lock``.

    Meant to be run in a worker thread via ``run_in_threadpool``.
    Arguments are forwarded unchanged.
    """
    with _recognize_lock:
        return recognizer.recognize(*args)


# Request body for POST /predict.
class PredictRequest(BaseModel):
    # Items forwarded unchanged to recognizer.recognize(). Presumably
    # base64-encoded images, as /predict/file sends — TODO confirm.
    input_data: list
    # Optional recognizer parameters. default_factory gives every request its
    # own dict instead of a shared literal.
    params: dict = Field(default_factory=dict)


# Response body for POST /predict.
class PredictResponse(BaseModel):
    predictions: list
    status: str


# NOTE: In the endpoint docstrings below, FastAPI publishes the text before the
# "\f" (form feed) as the OpenAPI description. The English notes after it are
# for maintainers only and are not shown in the API docs.
@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    """算法预测接口
    \f
    Run the recognizer on ``request.input_data`` with ``request.params``.

    Returns a ``PredictResponse`` with ``status="success"``.
    Any exception during recognition or response construction is logged
    with its traceback and re-raised as ``HTTPException(500)``, whose
    detail is the exception message.
    """
    try:
        logger.info(
            "Received prediction request for %d images", len(request.input_data)
        )
        predictions = await run_in_threadpool(
            _recognize_blocking, request.input_data, request.params
        )
        logger.info("Prediction completed: %s", predictions)
        return PredictResponse(predictions=predictions, status="success")
    except Exception as e:
        # logger.exception also records the traceback, which logger.error dropped.
        logger.exception("Prediction error: %s", e)
        # NOTE(review): str(e) is sent to the client verbatim and may expose
        # internal details; kept for compatibility with existing callers.
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/predict/file")
async def predict_file(file: UploadFile = File(...)):
    """通过文件上传进行预测
    \f
    Read the uploaded file, base64-encode it, and run the recognizer on it.

    Returns a dict with ``predictions``, ``status`` ("success") and the
    uploaded ``filename``.
    Any exception is logged with its traceback and re-raised as
    ``HTTPException(500)``.
    """
    try:
        logger.info("Received file upload: %s", file.filename)
        contents = await file.read()
        # The recognizer receives a one-element list holding the upload's
        # base64 text; no params argument is passed on this path.
        image_base64 = base64.b64encode(contents).decode('utf-8')
        predictions = await run_in_threadpool(_recognize_blocking, [image_base64])
        logger.info("File prediction completed: %s", predictions)
        return {
            "predictions": predictions,
            "status": "success",
            "filename": file.filename,
        }
    except Exception as e:
        logger.exception("File prediction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/health")
async def health_check():
    """健康检查接口
    \f
    Return a static liveness payload. The recognizer is not exercised.
    """
    return {
        "status": "healthy",
        "service": "image-recognition",
        "version": SERVICE_VERSION,
    }


@app.get("/info")
async def service_info():
    """服务信息接口
    \f
    Return the service metadata and a map of endpoint paths to descriptions.
    """
    return {
        "name": SERVICE_TITLE,
        "description": SERVICE_DESCRIPTION,
        "version": SERVICE_VERSION,
        "endpoints": {
            "/predict": "POST - 图像识别预测",
            "/predict/file": "POST - 通过文件上传进行预测",
            "/health": "GET - 健康检查",
            "/info": "GET - 服务信息",
        },
    }


if __name__ == "__main__":
    # The hard-coded "main:app" made uvicorn import a top-level module named
    # "main", in which the relative imports above cannot resolve. When this
    # file is launched as ``python -m <package>.<module>`` (the only way those
    # relative imports work), __spec__.name holds the real dotted module name.
    app_import_path = f"{__spec__.name}:app" if __spec__ is not None else "main:app"
    uvicorn.run(
        app_import_path,
        host=settings.HOST,
        port=settings.PORT,
        reload=settings.DEBUG,
    )