175 lines
6.2 KiB
Python
175 lines
6.2 KiB
Python
import logging
|
||
import os
|
||
from typing import List, Dict, Any, Optional
|
||
import openai
|
||
from .config import settings
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class OpenAIProxy:
|
||
"""OpenAI代理"""
|
||
|
||
def __init__(self):
|
||
"""初始化OpenAI代理"""
|
||
logger.info("初始化OpenAI代理")
|
||
# 设置API密钥
|
||
openai.api_key = settings.API_KEY
|
||
if settings.API_BASE:
|
||
openai.api_base = settings.API_BASE
|
||
|
||
def complete(self, model: str, messages: list, temperature: float = 0.7,
|
||
max_tokens: int = 1000) -> Dict[str, Any]:
|
||
"""完成聊天请求
|
||
|
||
Args:
|
||
model: 模型名称
|
||
messages: 消息列表
|
||
temperature: 温度参数
|
||
max_tokens: 最大令牌数
|
||
|
||
Returns:
|
||
完成结果
|
||
"""
|
||
try:
|
||
response = openai.chat.completions.create(
|
||
model=model,
|
||
messages=messages,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens
|
||
)
|
||
|
||
# 转换为字典格式
|
||
return {
|
||
"id": response.id,
|
||
"object": response.object,
|
||
"created": response.created,
|
||
"model": response.model,
|
||
"choices": [
|
||
{
|
||
"index": choice.index,
|
||
"message": {
|
||
"role": choice.message.role,
|
||
"content": choice.message.content
|
||
},
|
||
"finish_reason": choice.finish_reason
|
||
}
|
||
for choice in response.choices
|
||
],
|
||
"usage": {
|
||
"prompt_tokens": response.usage.prompt_tokens,
|
||
"completion_tokens": response.usage.completion_tokens,
|
||
"total_tokens": response.usage.total_tokens
|
||
}
|
||
}
|
||
except Exception as e:
|
||
logger.error(f"OpenAI completion error: {str(e)}")
|
||
# 返回模拟响应,用于演示
|
||
return self._mock_completion(messages, model)
|
||
|
||
def generate_simulation_input(self, prompt: str, input_type: str = "text") -> Dict[str, Any]:
|
||
"""生成仿真输入数据
|
||
|
||
Args:
|
||
prompt: 用户描述的场景
|
||
input_type: 输入类型,支持 "text", "image", "table"
|
||
|
||
Returns:
|
||
生成的仿真输入数据
|
||
"""
|
||
try:
|
||
# 根据输入类型构建不同的提示词
|
||
if input_type == "text":
|
||
system_prompt = "你是一个文本数据生成器,根据用户描述生成相应的文本数据"
|
||
user_prompt = f"请根据以下描述生成文本数据:{prompt}"
|
||
elif input_type == "image":
|
||
system_prompt = "你是一个图像描述生成器,根据用户描述生成详细的图像描述"
|
||
user_prompt = f"请根据以下描述生成详细的图像描述:{prompt}"
|
||
elif input_type == "table":
|
||
system_prompt = "你是一个表格数据生成器,根据用户描述生成结构化的表格数据"
|
||
user_prompt = f"请根据以下描述生成结构化的表格数据:{prompt}"
|
||
else:
|
||
system_prompt = "你是一个数据生成器,根据用户描述生成相应的数据"
|
||
user_prompt = f"请根据以下描述生成数据:{prompt}"
|
||
|
||
# 调用OpenAI API
|
||
response = openai.chat.completions.create(
|
||
model=settings.MODEL,
|
||
messages=[
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": user_prompt}
|
||
],
|
||
temperature=settings.TEMPERATURE,
|
||
max_tokens=settings.MAX_TOKENS
|
||
)
|
||
|
||
# 处理响应
|
||
generated_content = response.choices[0].message.content
|
||
|
||
return {
|
||
"success": True,
|
||
"data": generated_content,
|
||
"input_type": input_type
|
||
}
|
||
except Exception as e:
|
||
logger.error(f"OpenAI simulation input generation error: {str(e)}")
|
||
# 返回模拟响应,用于演示
|
||
return self._mock_simulation_input(prompt, input_type)
|
||
|
||
def _mock_completion(self, messages: list, model: str) -> Dict[str, Any]:
|
||
"""模拟完成响应,用于演示
|
||
|
||
Args:
|
||
messages: 消息列表
|
||
model: 模型名称
|
||
|
||
Returns:
|
||
模拟的完成结果
|
||
"""
|
||
return {
|
||
"id": "chat-mock-123",
|
||
"object": "chat.completion",
|
||
"created": 1677825464,
|
||
"model": model,
|
||
"choices": [
|
||
{
|
||
"index": 0,
|
||
"message": {
|
||
"role": "assistant",
|
||
"content": "这是一个模拟的响应,用于演示OpenAI代理服务"
|
||
},
|
||
"finish_reason": "stop"
|
||
}
|
||
],
|
||
"usage": {
|
||
"prompt_tokens": 10,
|
||
"completion_tokens": 20,
|
||
"total_tokens": 30
|
||
}
|
||
}
|
||
|
||
def _mock_simulation_input(self, prompt: str, input_type: str) -> Dict[str, Any]:
|
||
"""模拟生成仿真输入数据,用于演示
|
||
|
||
Args:
|
||
prompt: 用户描述的场景
|
||
input_type: 输入类型
|
||
|
||
Returns:
|
||
模拟的生成结果
|
||
"""
|
||
if input_type == "text":
|
||
data = f"这是根据描述生成的文本数据:{prompt}"
|
||
elif input_type == "image":
|
||
data = f"这是根据描述生成的图像描述:{prompt}"
|
||
elif input_type == "table":
|
||
data = f"这是根据描述生成的表格数据:{prompt}"
|
||
else:
|
||
data = f"这是根据描述生成的数据:{prompt}"
|
||
|
||
return {
|
||
"success": True,
|
||
"data": data,
|
||
"input_type": input_type
|
||
}
|