# Text classification service: FastAPI app exposing /predict, /health and /info.
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel
|
|
import uvicorn
|
|
import json
|
|
import logging
|
|
from .ai_algorithm import TextClassifier
|
|
from .config import settings
|
|
|
|
# Logging setup: records at INFO and above, each prefixed with timestamp,
# logger name and level. (basicConfig is a no-op if the root logger already
# has handlers.)
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)
# The FastAPI application object; title/description/version are its metadata
# ("Text classification service" / "AI service providing text classification").
app = FastAPI(
    version="1.0.0",
    title="文本分类服务",
    description="提供文本分类功能的AI服务",
)

# One classifier instance, built when the module is imported and shared by the
# /predict handler below.
classifier = TextClassifier()
# Request body model for POST /predict.
# (Documented with `#` comments on purpose: a class docstring would be picked
# up by pydantic as the schema description.)
class PredictRequest(BaseModel):
    # Items to classify; passed unchanged as the first argument of
    # classifier.classify().
    input_data: list
    # Optional options passed as the second argument of classifier.classify();
    # defaults to an empty dict. pydantic copies mutable defaults per instance,
    # so this shared `{}` literal is not mutated across requests.
    params: dict = {}
# Response body model for POST /predict.
# (`#` comments instead of a class docstring, which pydantic would expose as
# the schema description.)
class PredictResponse(BaseModel):
    # Whatever classifier.classify() returned for the request's input_data.
    predictions: list
    # Outcome marker; the /predict handler sets it to "success".
    status: str
@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    """Classify ``request.input_data`` with the shared classifier.

    Args:
        request: Validated body; ``input_data`` and ``params`` are passed
            unchanged to ``classifier.classify``.

    Returns:
        PredictResponse carrying the classifier output and ``status="success"``.

    Raises:
        HTTPException: status 500 when classification or building the
            response raises; ``detail`` is the exception message.
    """
    try:
        # Log only the size at INFO: the payload is untrusted client text that
        # may be large or sensitive. The full payload is available at DEBUG.
        logger.info(
            "Received prediction request with %d item(s)", len(request.input_data)
        )
        logger.debug("Prediction request payload: %s", request.input_data)

        # NOTE(review): classify() is called synchronously inside an async
        # handler, so it blocks the event loop while it runs. Offloading it to
        # a threadpool would fix that, but only if TextClassifier is confirmed
        # thread-safe — TODO verify before changing.
        predictions = classifier.classify(request.input_data, request.params)
        logger.info("Prediction completed: %s", predictions)

        return PredictResponse(
            predictions=predictions,
            status="success",
        )
    except Exception as e:
        # logger.exception records the traceback, which the previous
        # logger.error(str(e)) discarded.
        logger.exception("Prediction error: %s", e)
        # NOTE(review): str(e) exposes internal error text to API clients;
        # consider returning a generic message instead.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
async def health_check():
    """Liveness probe: report that the service is up.

    Returns:
        A dict with a fixed status and service name. The version is read from
        the FastAPI app so it cannot drift from the app's configured version.
    """
    return {
        "status": "healthy",
        "service": "text-classification",
        "version": app.version,
    }
@app.get("/info")
async def service_info():
    """Describe the service and list the endpoints defined in this module.

    Returns:
        A dict with name, description and version, all taken from the FastAPI
        app's metadata so this payload always matches the app configuration,
        plus a human-readable map of the routes.
    """
    return {
        "name": app.title,
        "description": app.description,
        "version": app.version,
        "endpoints": {
            "/predict": "POST - 文本分类预测",  # text classification prediction
            "/health": "GET - 健康检查",  # health check
            "/info": "GET - 服务信息",  # service information
        },
    }
if __name__ == "__main__":
    # This module uses relative imports (`from .config import ...`), so it can
    # only be started as a module of its package: `python -m <package>.<module>`.
    # In that case a hard-coded "main:app" makes uvicorn import a top-level
    # `main` module that does not exist. Derive the import path from this
    # module's spec instead, and keep "main:app" as the fallback when no spec
    # is available.
    # An import string (rather than the app object) is required by uvicorn
    # when reload is enabled.
    app_import_path = f"{__spec__.name}:app" if __spec__ is not None else "main:app"
    uvicorn.run(
        app_import_path,
        host=settings.HOST,
        port=settings.PORT,
        reload=settings.DEBUG,
    )