print改为logger打印
This commit is contained in:
@@ -5,6 +5,9 @@ import onnxruntime as ort
|
||||
import os
|
||||
import time
|
||||
|
||||
from utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114)):
|
||||
shape = img.shape[:2] # h, w
|
||||
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
||||
@@ -35,23 +38,23 @@ class YOLOv8_ONNX:
|
||||
|
||||
self.session = ort.InferenceSession(onnx_path, providers=providers)
|
||||
actual_providers = self.session.get_providers()
|
||||
print("YOLO Providers:", actual_providers)
|
||||
logger.info("YOLO Providers:", actual_providers)
|
||||
|
||||
if "CANNExecutionProvider" in actual_providers:
|
||||
print("[INFO] YOLO 使用 CANNExecutionProvider(昇腾 NPU)")
|
||||
logger.info("[INFO] YOLO 使用 CANNExecutionProvider(昇腾 NPU)")
|
||||
elif 'CUDAExecutionProvider' in actual_providers:
|
||||
print("[INFO] YOLO 使用 CUDAExecutionProvider(NVIDIA GPU)")
|
||||
logger.info("[INFO] YOLO 使用 CUDAExecutionProvider(NVIDIA GPU)")
|
||||
else:
|
||||
print("[INFO] YOLO 使用 CPUExecutionProvider")
|
||||
logger.info("[INFO] YOLO 使用 CPUExecutionProvider")
|
||||
|
||||
self.conf_threshold = conf_threshold
|
||||
self.iou_threshold = iou_threshold
|
||||
self.input_name = self.session.get_inputs()[0].name
|
||||
self.input_size = (input_size, input_size) if isinstance(input_size, int) else input_size
|
||||
|
||||
print(f"模型输入名称: {self.input_name}")
|
||||
print(f"模型输入形状: {self.session.get_inputs()[0].shape}")
|
||||
print(f"模型输出形状: {self.session.get_outputs()[0].shape}")
|
||||
logger.info(f"模型输入名称: {self.input_name}")
|
||||
logger.info(f"模型输入形状: {self.session.get_inputs()[0].shape}")
|
||||
logger.info(f"模型输出形状: {self.session.get_outputs()[0].shape}")
|
||||
|
||||
def preprocess(self, img):
|
||||
self.orig_shape = img.shape[:2]
|
||||
|
||||
@@ -9,7 +9,8 @@ import cv2
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
|
||||
from utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# -------------------------------------------------
|
||||
@@ -89,14 +90,14 @@ class YOLOv8_Pose_ONNX:
|
||||
# 获取真实工作 provider
|
||||
actual_providers = self.session.get_providers()
|
||||
|
||||
print("YOLO Providers:", actual_providers)
|
||||
logger.info("YOLO Providers:", actual_providers)
|
||||
|
||||
if "CANNExecutionProvider" in actual_providers:
|
||||
print("[INFO] YOLO 使用 CANNExecutionProvider(昇腾)")
|
||||
logger.info("[INFO] YOLO 使用 CANNExecutionProvider(昇腾)")
|
||||
elif 'CUDAExecutionProvider' in actual_providers:
|
||||
print("[INFO] YOLO 使用 CUDAExecutionProvider(NVIDIA GPU)")
|
||||
logger.info("[INFO] YOLO 使用 CUDAExecutionProvider(NVIDIA GPU)")
|
||||
else:
|
||||
print("[INFO] YOLO 使用 CPUExecutionProvider(非昇腾环境)")
|
||||
logger.info("[INFO] YOLO 使用 CPUExecutionProvider(非昇腾环境)")
|
||||
|
||||
self.conf_threshold = conf_threshold
|
||||
self.iou_threshold = iou_threshold
|
||||
@@ -104,9 +105,9 @@ class YOLOv8_Pose_ONNX:
|
||||
|
||||
self.input_name = self.session.get_inputs()[0].name
|
||||
self.input_size = (input_size, input_size)
|
||||
print(f"模型输入名称: {self.input_name}")
|
||||
print(f"模型输入形状: {self.session.get_inputs()[0].shape}")
|
||||
print(f"模型输出形状: {self.session.get_outputs()[0].shape}")
|
||||
logger.info(f"模型输入名称: {self.input_name}")
|
||||
logger.info(f"模型输入形状: {self.session.get_inputs()[0].shape}")
|
||||
logger.info(f"模型输出形状: {self.session.get_outputs()[0].shape}")
|
||||
|
||||
|
||||
def nms(self, boxes, scores, iou_threshold=0.45):
|
||||
|
||||
Reference in New Issue
Block a user