diff --git a/biz/checkpoint/checkpoint_biz.py b/biz/checkpoint/checkpoint_biz.py index e65d648..4228b15 100644 --- a/biz/checkpoint/checkpoint_biz.py +++ b/biz/checkpoint/checkpoint_biz.py @@ -14,13 +14,10 @@ from algorithm.common.npu_yolo_pose_onnx import YOLOv8_Pose_ONNX from yolox.tracker.byte_tracker import BYTETracker from utils.logger import get_logger +from common.constants import MODEL_ROOT_PATH logger = get_logger(__name__) -# ========================= 配置区 ========================= -# Kadian 模型路径与ROI(可根据实际情况修改) -#DETECT_MODEL_PATH = 'YOLO_Weight/Kadian.onnx' -DETECT_MODEL_PATH = r'D:\Python_Save\PoliceProject\Yolo_Weight\Kadian\Kadian_wanjia_xinkailing.onnx' -#POSE_MODEL_PATH = 'YOLO_Weight/yolov8l-pose.onnx' +DETECT_MODEL_PATH = 'YOLO_Weight/Kadian.onnx' # 默认相对ROI(与原文件一致) #ROI_RELATIVE = np.array([ @@ -63,9 +60,17 @@ class KadianDetector: # 摄像头额外参数 self.params = params if params is not None else {} + # 模型路径:从 params 读取,未配置则使用默认值 DETECT_MODEL_PATH + model_path = self.params.get('model_path') + if model_path: + full_model_path = f"{MODEL_ROOT_PATH}/{model_path}" + else: + full_model_path = DETECT_MODEL_PATH + logger.info(f"Loading model from: {full_model_path}") + # 模型加载 self.detector = YOLOv8_ONNX( - DETECT_MODEL_PATH, + full_model_path, conf_threshold=0.6, iou_threshold=0.65, input_size=PERSON_CAR_INPUT_SIZE