first commit
This commit is contained in:
0
backend/app/utils/__init__.py
Normal file
0
backend/app/utils/__init__.py
Normal file
BIN
backend/app/utils/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
backend/app/utils/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/utils/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
backend/app/utils/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/utils/__pycache__/cache.cpython-312.pyc
Normal file
BIN
backend/app/utils/__pycache__/cache.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/utils/__pycache__/cache.cpython-39.pyc
Normal file
BIN
backend/app/utils/__pycache__/cache.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/utils/__pycache__/file.cpython-312.pyc
Normal file
BIN
backend/app/utils/__pycache__/file.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/utils/__pycache__/file.cpython-39.pyc
Normal file
BIN
backend/app/utils/__pycache__/file.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/utils/__pycache__/logger.cpython-312.pyc
Normal file
BIN
backend/app/utils/__pycache__/logger.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/utils/__pycache__/logger.cpython-39.pyc
Normal file
BIN
backend/app/utils/__pycache__/logger.cpython-39.pyc
Normal file
Binary file not shown.
BIN
backend/app/utils/__pycache__/openai.cpython-312.pyc
Normal file
BIN
backend/app/utils/__pycache__/openai.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/app/utils/__pycache__/openai.cpython-39.pyc
Normal file
BIN
backend/app/utils/__pycache__/openai.cpython-39.pyc
Normal file
Binary file not shown.
80
backend/app/utils/cache.py
Normal file
80
backend/app/utils/cache.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import json
import logging
from typing import Any, Optional

import redis

from app.config.settings import settings
|
||||
|
||||
|
||||
class RedisCache:
    """Best-effort cache on top of Redis, storing values as JSON strings.

    Every operation swallows Redis/serialization errors and returns a
    neutral fallback (None/False) so a cache outage never propagates to
    callers.
    """

    def __init__(self):
        """Build the client from ``settings.REDIS_URL`` (connection is lazy)."""
        self.redis_client = redis.from_url(
            settings.REDIS_URL,
            decode_responses=True,  # commands return str instead of bytes
            socket_connect_timeout=5,
            socket_timeout=5,
            retry_on_timeout=True,
            health_check_interval=30,
            max_connections=50
        )

    def get(self, key: str) -> Optional[Any]:
        """Return the deserialized value for *key*; None on miss or error."""
        try:
            value = self.redis_client.get(key)
            # Compare against None explicitly: a stored falsy payload (e.g.
            # an empty string written by another client) must not be
            # reported as a cache miss.
            if value is not None:
                return json.loads(value)
            return None
        except Exception as e:
            logging.warning(f"Redis get error: {e}")
            return None

    def set(self, key: str, value: Any, expire: Optional[int] = None) -> bool:
        """Store *value* as JSON, optionally expiring after *expire* seconds."""
        try:
            value_str = json.dumps(value)
            if expire:
                return bool(self.redis_client.setex(key, expire, value_str))
            return bool(self.redis_client.set(key, value_str))
        except Exception as e:
            logging.warning(f"Redis set error: {e}")
            return False

    def delete(self, key: str) -> bool:
        """Delete *key*; True when at least one key was removed."""
        try:
            return bool(self.redis_client.delete(key))
        except Exception as e:
            logging.warning(f"Redis delete error: {e}")
            return False

    def exists(self, key: str) -> bool:
        """Return True when *key* is present in Redis."""
        try:
            return bool(self.redis_client.exists(key))
        except Exception as e:
            logging.warning(f"Redis exists error: {e}")
            return False

    def increment(self, key: str, amount: int = 1) -> Optional[int]:
        """Atomically add *amount* to the counter at *key*; None on error."""
        try:
            return self.redis_client.incrby(key, amount)
        except Exception as e:
            logging.warning(f"Redis increment error: {e}")
            return None

    def decrement(self, key: str, amount: int = 1) -> Optional[int]:
        """Atomically subtract *amount* from the counter at *key*; None on error."""
        try:
            return self.redis_client.decrby(key, amount)
        except Exception as e:
            logging.warning(f"Redis decrement error: {e}")
            return None
|
||||
|
||||
|
||||
# Module-level cache singleton shared by the application.
cache = RedisCache()
|
||||
184
backend/app/utils/file.py
Normal file
184
backend/app/utils/file.py
Normal file
@@ -0,0 +1,184 @@
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
from typing import Optional, Tuple
|
||||
import io
|
||||
import os
|
||||
import logging
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
class MinioClient:
    """MinIO object-storage client with a degraded "offline" mode.

    If the server is unreachable at construction time, the instance keeps
    ``is_connected = False`` and every operation short-circuits to a neutral
    value (False / None / []) after logging a warning instead of raising.
    """

    def __init__(self):
        """Connect to MinIO using endpoint/credentials from settings."""
        try:
            self.client = Minio(
                settings.MINIO_ENDPOINT,
                access_key=settings.MINIO_ACCESS_KEY,
                secret_key=settings.MINIO_SECRET_KEY,
                secure=settings.MINIO_SECURE
            )
            self.bucket_name = settings.MINIO_BUCKET_NAME
            # Set optimistically so _ensure_bucket_exists() runs; reset to
            # False below if anything in this block raises.
            self.is_connected = True

            # Make sure the configured bucket exists.
            self._ensure_bucket_exists()
        except Exception as e:
            logging.warning(f"Failed to connect to MinIO: {e}. Running in offline mode.")
            self.client = None
            self.bucket_name = settings.MINIO_BUCKET_NAME
            self.is_connected = False

    def _ensure_bucket_exists(self):
        """Create the configured bucket if it does not already exist."""
        if not self.is_connected:
            return

        try:
            if not self.client.bucket_exists(self.bucket_name):
                self.client.make_bucket(self.bucket_name)
        except S3Error as e:
            logging.warning(f"MinIO bucket error: {e}")

    def upload_file(self, file_path: str, object_name: str) -> bool:
        """Upload the local file at *file_path* as *object_name*; True on success."""
        if not self.is_connected:
            logging.warning("MinIO is not connected. Upload skipped.")
            return False

        try:
            self.client.fput_object(
                self.bucket_name,
                object_name,
                file_path
            )
            return True
        except S3Error as e:
            logging.warning(f"MinIO upload error: {e}")
            return False

    def upload_fileobj(self, file_obj: io.BytesIO, object_name: str, content_type: str = "application/octet-stream") -> bool:
        """Upload an in-memory stream as *object_name*; True on success."""
        if not self.is_connected:
            logging.warning("MinIO is not connected. Upload skipped.")
            return False

        try:
            # length=-1 streams data of unknown size; part_size is then required.
            self.client.put_object(
                self.bucket_name,
                object_name,
                file_obj,
                length=-1,
                part_size=10*1024*1024,
                content_type=content_type
            )
            return True
        except S3Error as e:
            logging.warning(f"MinIO upload error: {e}")
            return False

    def download_file(self, object_name: str, file_path: str) -> bool:
        """Download *object_name* to the local *file_path*; True on success."""
        if not self.is_connected:
            logging.warning("MinIO is not connected. Download skipped.")
            return False

        try:
            self.client.fget_object(
                self.bucket_name,
                object_name,
                file_path
            )
            return True
        except S3Error as e:
            logging.warning(f"MinIO download error: {e}")
            return False

    def get_object(self, object_name: str) -> Optional[bytes]:
        """Return the raw bytes of *object_name*, or None on failure."""
        if not self.is_connected:
            logging.warning("MinIO is not connected. Get object skipped.")
            return None

        response = None
        try:
            response = self.client.get_object(
                self.bucket_name,
                object_name
            )
            return response.read()
        except S3Error as e:
            logging.warning(f"MinIO get object error: {e}")
            return None
        finally:
            # Always release the HTTP connection, even when read() raises —
            # previously cleanup only happened on the happy path, leaking the
            # pooled connection on error.
            if response is not None:
                response.close()
                response.release_conn()

    def delete_object(self, object_name: str) -> bool:
        """Remove *object_name* from the bucket; True on success."""
        if not self.is_connected:
            logging.warning("MinIO is not connected. Delete object skipped.")
            return False

        try:
            self.client.remove_object(
                self.bucket_name,
                object_name
            )
            return True
        except S3Error as e:
            logging.warning(f"MinIO delete error: {e}")
            return False

    def list_objects(self, prefix: str = "") -> list:
        """Return the names of all objects under *prefix* (recursive)."""
        if not self.is_connected:
            logging.warning("MinIO is not connected. List objects skipped.")
            return []

        try:
            return [
                obj.object_name
                for obj in self.client.list_objects(
                    self.bucket_name,
                    prefix=prefix,
                    recursive=True
                )
            ]
        except S3Error as e:
            logging.warning(f"MinIO list objects error: {e}")
            return []

    def get_presigned_url(self, object_name: str, expires: int = 604800) -> Optional[str]:
        """Return a presigned GET URL for *object_name* (default expiry: 7 days).

        NOTE(review): modern minio SDKs (>=7) expect ``expires`` as a
        ``datetime.timedelta``; passing an int may raise there — confirm the
        installed SDK version before relying on this.
        """
        if not self.is_connected:
            logging.warning("MinIO is not connected. Get presigned URL skipped.")
            return None

        try:
            url = self.client.presigned_get_object(
                self.bucket_name,
                object_name,
                expires=expires
            )
            return url
        except S3Error as e:
            logging.warning(f"MinIO presigned url error: {e}")
            return None
|
||||
|
||||
|
||||
# Module-level file-storage singleton.
try:
    file_storage = MinioClient()
except Exception as e:
    # If initialization failed, fall back to a mock object.
    # NOTE(review): MinioClient.__init__ catches its own failures and enters
    # offline mode instead of raising, so this branch looks unreachable —
    # confirm before relying on the mock path.
    class MockFileStorage:
        # Any attribute access yields a stub that logs a warning and returns
        # a neutral value: object-reading methods get None, mutators False.
        def __getattr__(self, name):
            def mock_method(*args, **kwargs):
                logging.warning(f"MinIO is not available. Method '{name}' will not execute.")
                return None if name.startswith('get_') or name == 'list_objects' else False
            return mock_method

    file_storage = MockFileStorage()
    logging.warning(f"Failed to initialize MinIO client: {e}. Using mock instance.")
|
||||
308
backend/app/utils/logger.py
Normal file
308
backend/app/utils/logger.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""日志管理工具,提供结构化日志记录和日志查询功能"""
|
||||
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
class StructuredLogger:
    """Emits application events as JSON payloads through the stdlib logger.

    Events are written to a size-rotated file under *log_dir* in the format
    "time - name - level - json", which LogQuery knows how to parse back.
    """

    def __init__(self, name: str = "algorithm_showcase", log_dir: str = "logs"):
        self.logger = logging.getLogger(name)
        self.logger.setLevel(logging.INFO)

        # Make sure the target directory exists before attaching the handler.
        directory = Path(log_dir)
        directory.mkdir(exist_ok=True)

        rotating = RotatingFileHandler(
            str(directory / f"{name}.log"),
            maxBytes=10 * 1024 * 1024,  # rotate at 10MB
            backupCount=5,
        )
        rotating.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )

        # Avoid stacking duplicate handlers when the logger name is reused.
        if not self.logger.handlers:
            self.logger.addHandler(rotating)

    def log_event(self, event_type: str, user_id: str = None, algorithm_id: str = None,
                  extra_data: Dict[str, Any] = None, level: int = logging.INFO):
        """Serialize one event to JSON and emit it at *level*."""
        payload = json.dumps(
            {
                "timestamp": datetime.utcnow().isoformat(),
                "event_type": event_type,
                "user_id": user_id,
                "algorithm_id": algorithm_id,
                "extra_data": extra_data or {},
            },
            ensure_ascii=False,
            default=str,  # fall back to str() for non-JSON-native values
        )
        self.logger.log(level, payload)

    def log_api_call(self, user_id: str, algorithm_id: str, version_id: str,
                     input_size: int, response_time: float, success: bool,
                     error_msg: str = None):
        """Record one API invocation; failures are emitted at ERROR level."""
        details = {
            "version_id": version_id,
            "input_size": input_size,
            "response_time": response_time,
            "success": success,
        }
        if error_msg:
            details["error"] = error_msg

        self.log_event(
            event_type="api_call",
            user_id=user_id,
            algorithm_id=algorithm_id,
            extra_data=details,
            level=logging.INFO if success else logging.ERROR,
        )

    def log_algorithm_execution(self, user_id: str, algorithm_id: str, version_id: str,
                                input_data: Dict[str, Any], output_data: Dict[str, Any],
                                execution_time: float, success: bool, error_msg: str = None):
        """Record one algorithm run, storing summaries instead of raw payloads."""
        details = {
            "version_id": version_id,
            "execution_time": execution_time,
            "input_summary": self._summarize_data(input_data),
            "output_summary": self._summarize_data(output_data),
            "success": success,
        }
        if error_msg:
            details["error"] = error_msg

        self.log_event(
            event_type="algorithm_execution",
            user_id=user_id,
            algorithm_id=algorithm_id,
            extra_data=details,
            level=logging.INFO if success else logging.ERROR,
        )

    def log_system_event(self, event_subtype: str, severity: str, message: str,
                         extra_data: Dict[str, Any] = None):
        """Record a system-level event tagged with its subtype and severity.

        NOTE: *message* is accepted but not included in the payload.
        """
        extra_data = extra_data or {}
        extra_data["subtype"] = event_subtype
        extra_data["severity"] = severity

        self.log_event(
            event_type="system_event",
            extra_data=extra_data,
            level=self._severity_to_level(severity),
        )

    def _summarize_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Reduce a payload to a privacy-safe summary (first 5 keys, truncated)."""
        if not isinstance(data, dict):
            return {"type": type(data).__name__, "size": len(str(data))}

        # Containers are replaced by a type tag; scalars are capped at 100 chars.
        return {
            key: f"<{type(value).__name__}>" if isinstance(value, (dict, list)) else str(value)[:100]
            for key, value in list(data.items())[:5]
        }

    def _severity_to_level(self, severity: str) -> int:
        """Map a severity name to a stdlib logging level (default: INFO)."""
        return {
            "debug": logging.DEBUG,
            "info": logging.INFO,
            "warning": logging.WARNING,
            "error": logging.ERROR,
            "critical": logging.CRITICAL,
        }.get(severity.lower(), logging.INFO)
|
||||
|
||||
|
||||
class LogQuery:
    """Search and aggregate the rotating log files written by StructuredLogger."""

    # Single source of truth for level-name -> numeric-level conversion
    # (previously duplicated in search_logs and _matches_filters).
    _LEVEL_MAP = {
        "DEBUG": logging.DEBUG,
        "INFO": logging.INFO,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }

    def __init__(self, log_dir: str = "logs"):
        """Remember the directory whose *.log files will be searched."""
        self.log_dir = Path(log_dir)

    def search_logs(self,
                    start_date: datetime = None,
                    end_date: datetime = None,
                    event_types: List[str] = None,
                    user_ids: List[str] = None,
                    algorithm_ids: List[str] = None,
                    log_levels: List[str] = None,
                    limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to *limit* parsed log entries matching all given filters.

        Filters that are None are ignored. *log_levels* takes level names
        ("INFO", ...) which are converted to numeric levels internally.
        """
        results = []

        # Newest files first so recent entries are found before the limit hits.
        log_files = sorted(self.log_dir.glob("*.log"), reverse=True)

        # Convert level names to numeric levels; unknown names are dropped.
        if log_levels:
            log_levels = [self._LEVEL_MAP[level.upper()]
                          for level in log_levels if level.upper() in self._LEVEL_MAP]

        for log_file in log_files:
            try:
                with open(log_file, 'r', encoding='utf-8') as f:
                    for line in f:
                        try:
                            parsed_line = self._parse_log_line(line.strip())
                            if parsed_line and self._matches_filters(
                                parsed_line, start_date, end_date, event_types,
                                user_ids, algorithm_ids, log_levels
                            ):
                                results.append(parsed_line)

                                if len(results) >= limit:
                                    return results
                        except Exception:
                            # Skip lines that cannot be parsed.
                            continue

            except Exception as e:
                logging.warning(f"Error reading log file {log_file}: {e}")

        return results

    def _parse_log_line(self, line: str) -> Optional[Dict[str, Any]]:
        """Parse one "time - name - level - message" line; None if unparseable.

        When the message part is a JSON object (as written by
        StructuredLogger.log_event), its fields are merged into the result.
        """
        try:
            # maxsplit=3 keeps " - " sequences inside the JSON message intact.
            parts = line.split(" - ", 3)
            if len(parts) >= 4:
                timestamp_str, logger_name, level, message = parts

                if message.startswith('{') and message.endswith('}'):
                    log_data = json.loads(message)
                    # Handler-formatted fields win over any same-named JSON keys.
                    log_data["timestamp"] = timestamp_str
                    log_data["logger"] = logger_name
                    log_data["level"] = level
                    return log_data
                # Not JSON: wrap the raw text in a minimal structure.
                return {
                    "timestamp": timestamp_str,
                    "logger": logger_name,
                    "level": level,
                    "message": message,
                    "raw_line": line
                }
        except Exception:
            pass
        return None

    def _matches_filters(self, log_entry: Dict[str, Any],
                         start_date: datetime, end_date: datetime,
                         event_types: List[str], user_ids: List[str],
                         algorithm_ids: List[str], log_levels: List[int]) -> bool:
        """Return True when *log_entry* passes every non-None filter."""
        # Date-range filter.
        # NOTE(review): asctime timestamps use a comma before milliseconds,
        # which datetime.fromisoformat only accepts on newer Pythons; where it
        # raises ValueError, the range check is silently skipped — confirm.
        if start_date or end_date:
            try:
                entry_time = datetime.fromisoformat(log_entry.get("timestamp", "").replace("Z", "+00:00"))
                if start_date and entry_time < start_date:
                    return False
                if end_date and entry_time > end_date:
                    return False
            except ValueError:
                pass  # Unparseable timestamp: skip the date check.

        # Event-type filter.
        if event_types and log_entry.get("event_type") not in event_types:
            return False

        # User-ID filter.
        if user_ids and log_entry.get("user_id") not in user_ids:
            return False

        # Algorithm-ID filter.
        if algorithm_ids and log_entry.get("algorithm_id") not in algorithm_ids:
            return False

        # Numeric-level filter (names were converted in search_logs).
        if log_levels:
            entry_level = self._LEVEL_MAP.get(log_entry.get("level", "").upper())
            if entry_level not in log_levels:
                return False

        return True

    def get_log_stats(self, days: int = 7) -> Dict[str, Any]:
        """Aggregate counts by event type, level and day over the last *days* days."""
        start_date = datetime.utcnow() - timedelta(days=days)

        logs = self.search_logs(start_date=start_date)

        stats = {
            "total_logs": len(logs),
            "by_event_type": {},
            "by_level": {},
            "by_day": {},
            "error_count": 0
        }

        for log in logs:
            # Count per event type.
            event_type = log.get("event_type", "unknown")
            stats["by_event_type"][event_type] = stats["by_event_type"].get(event_type, 0) + 1

            # Count per level.
            level = log.get("level", "UNKNOWN")
            stats["by_level"][level] = stats["by_level"].get(level, 0) + 1

            # The first 10 chars of the timestamp are the YYYY-MM-DD day.
            day = log.get("timestamp", "")[:10]
            if day:
                stats["by_day"][day] = stats["by_day"].get(day, 0) + 1

            # Errors and worse feed the error counter.
            if log.get("level") in ["ERROR", "CRITICAL"]:
                stats["error_count"] += 1

        return stats
|
||||
|
||||
|
||||
# Module-level logger/query singletons shared by the application.
structured_logger = StructuredLogger()
log_query = LogQuery()
|
||||
91
backend/app/utils/openai.py
Normal file
91
backend/app/utils/openai.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import openai
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
class OpenAIClient:
    """Wrapper around the OpenAI chat API for generating simulation data.

    NOTE(review): this sets the legacy module-level ``openai.api_key`` while
    calling the newer ``openai.chat.completions`` interface — whether both
    work together depends on the installed SDK version; confirm.
    """

    def __init__(self):
        """Configure the module-level API key (when set) and pick the model."""
        if settings.OPENAI_API_KEY:
            openai.api_key = settings.OPENAI_API_KEY
        self.model = settings.OPENAI_MODEL

    def generate_simulation_data(self, prompt: str, data_type: str = "text") -> Optional[Dict[str, Any]]:
        """Ask the model to fabricate simulation input data described by *prompt*.

        Returns ``{"data": ..., "type": data_type}`` or None on any failure.
        For ``data_type == "structured"`` the model output is parsed as JSON
        when possible, otherwise returned as raw text.
        """
        try:
            system_prompt = f"你是一个数据生成助手,根据用户的描述生成{data_type}类型的仿真输入数据。"

            # Per-type prompt templates; unknown types pass the prompt through.
            templates = {
                "image": f"根据以下描述生成一张图片的详细描述,包括颜色、形状、大小等特征:{prompt}",
                "text": f"根据以下描述生成一段文本数据:{prompt}",
                "structured": f"根据以下描述生成结构化数据,返回JSON格式:{prompt}",
            }
            user_prompt = templates.get(data_type, prompt)

            response = openai.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.7,
                max_tokens=1000,
            )
            content = response.choices[0].message.content

            # Structured output: try JSON first, fall back to the raw text.
            if data_type == "structured":
                import json
                try:
                    return {"data": json.loads(content), "type": data_type}
                except json.JSONDecodeError:
                    pass
            return {"data": content, "type": data_type}

        except Exception as e:
            print(f"OpenAI API error: {e}")
            return None

    def generate_image_description(self, image_url: str) -> Optional[str]:
        """Ask the vision model to describe the image at *image_url*; None on failure."""
        try:
            vision_message = {
                "role": "user",
                "content": [
                    {"type": "text", "text": "请描述这张图片的内容,包括颜色、形状、物体等特征。"},
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            }
            response = openai.chat.completions.create(
                model=self.model,
                messages=[vision_message],
                temperature=0.7,
                max_tokens=500,
            )
            return response.choices[0].message.content

        except Exception as e:
            print(f"OpenAI API error: {e}")
            return None
|
||||
|
||||
|
||||
# Module-level OpenAI client singleton.
openai_client = OpenAIClient()
|
||||
Reference in New Issue
Block a user