92 lines
3.3 KiB
Python
92 lines
3.3 KiB
Python
import openai
|
||
from typing import Optional, Dict, Any
|
||
|
||
from app.config.settings import settings
|
||
|
||
|
||
class OpenAIClient:
|
||
"""OpenAI客户端类"""
|
||
|
||
def __init__(self):
|
||
"""初始化OpenAI客户端"""
|
||
if settings.OPENAI_API_KEY:
|
||
openai.api_key = settings.OPENAI_API_KEY
|
||
self.model = settings.OPENAI_MODEL
|
||
|
||
def generate_simulation_data(self, prompt: str, data_type: str = "text") -> Optional[Dict[str, Any]]:
|
||
"""生成仿真输入数据"""
|
||
try:
|
||
# 构建系统提示
|
||
system_prompt = f"你是一个数据生成助手,根据用户的描述生成{data_type}类型的仿真输入数据。"
|
||
|
||
# 构建用户提示
|
||
if data_type == "image":
|
||
user_prompt = f"根据以下描述生成一张图片的详细描述,包括颜色、形状、大小等特征:{prompt}"
|
||
elif data_type == "text":
|
||
user_prompt = f"根据以下描述生成一段文本数据:{prompt}"
|
||
elif data_type == "structured":
|
||
user_prompt = f"根据以下描述生成结构化数据,返回JSON格式:{prompt}"
|
||
else:
|
||
user_prompt = prompt
|
||
|
||
# 调用OpenAI API
|
||
response = openai.chat.completions.create(
|
||
model=self.model,
|
||
messages=[
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": user_prompt}
|
||
],
|
||
temperature=0.7,
|
||
max_tokens=1000
|
||
)
|
||
|
||
# 处理响应
|
||
content = response.choices[0].message.content
|
||
|
||
# 根据数据类型返回不同格式的结果
|
||
if data_type == "structured":
|
||
import json
|
||
try:
|
||
return {"data": json.loads(content), "type": data_type}
|
||
except json.JSONDecodeError:
|
||
return {"data": content, "type": data_type}
|
||
else:
|
||
return {"data": content, "type": data_type}
|
||
|
||
except Exception as e:
|
||
print(f"OpenAI API error: {e}")
|
||
return None
|
||
|
||
def generate_image_description(self, image_url: str) -> Optional[str]:
|
||
"""生成图片描述"""
|
||
try:
|
||
response = openai.chat.completions.create(
|
||
model=self.model,
|
||
messages=[
|
||
{
|
||
"role": "user",
|
||
"content": [
|
||
{"type": "text", "text": "请描述这张图片的内容,包括颜色、形状、物体等特征。"},
|
||
{
|
||
"type": "image_url",
|
||
"image_url": {
|
||
"url": image_url
|
||
}
|
||
}
|
||
]
|
||
}
|
||
],
|
||
temperature=0.7,
|
||
max_tokens=500
|
||
)
|
||
|
||
return response.choices[0].message.content
|
||
|
||
except Exception as e:
|
||
print(f"OpenAI API error: {e}")
|
||
return None
|
||
|
||
|
||
# 创建全局OpenAI客户端实例
|
||
openai_client = OpenAIClient()
|