diff --git a/biz/base_detector.py b/biz/base_detector.py new file mode 100644 index 0000000..3c65010 --- /dev/null +++ b/biz/base_detector.py @@ -0,0 +1,62 @@ +from collections import deque +from typing import Optional +import numpy as np + + +class BaseDetector: + """ + 检测器基类 + 提供通用的帧回溯缓存功能,子类可按需使用 + """ + + def __init__(self): + # 帧回溯缓存(子类需要时调用 init_frame_buffer 初始化) + self._frame_buffer: Optional[deque] = None + + def init_frame_buffer(self, buffer_seconds: float, fps: float): + """ + 初始化帧回溯缓存队列 + + Args: + buffer_seconds: 需要缓存的时间长度(秒) + fps: 视频帧率 + """ + maxlen = int(buffer_seconds * fps) + self._frame_buffer = deque(maxlen=maxlen) + + def append_frame(self, frame: np.ndarray, timestamp: float): + """ + 将当前帧入队缓存 + + Args: + frame: 当前帧图像 + timestamp: 当前帧的时间戳 + """ + if self._frame_buffer is not None: + self._frame_buffer.append({ + 'timestamp': timestamp, + 'frame': frame.copy(), + }) + + def find_target_frame(self, target_time_sec: float) -> Optional[np.ndarray]: + """ + 在帧缓存中找到最接近目标时间的帧 + + Args: + target_time_sec: 目标时间戳 + + Returns: + 最接近目标时间的帧图像,缓存为空则返回 None + """ + if self._frame_buffer is None or len(self._frame_buffer) == 0: + return None + + target_frame = None + min_time_diff = float('inf') + for buffered in self._frame_buffer: + time_diff = abs(buffered['timestamp'] - target_time_sec) + if time_diff < min_time_diff: + min_time_diff = time_diff + target_frame = buffered['frame'] + + return target_frame diff --git a/biz/checkpoint/checkpoint_biz.py b/biz/checkpoint/checkpoint_biz.py index 6447757..7aa6f02 100644 --- a/biz/checkpoint/checkpoint_biz.py +++ b/biz/checkpoint/checkpoint_biz.py @@ -1,16 +1,14 @@ import cv2 import numpy as np from typing import Dict, Any -import threading -import queue -from collections import deque from biz.base_frame_processor import BaseFrameProcessorWorker +from biz.base_detector import BaseDetector # -------------------------- Kadian 检测相关导入 -------------------------- from algorithm.common.npu_yolo_onnx_person_car_phone import YOLOv8_ONNX # 主检测模型(人/车/后备箱/手机) -from algorithm.common.npu_yolo_pose_onnx import YOLOv8_Pose_ONNX # Pose 专用模型 +# from algorithm.common.npu_yolo_pose_onnx import YOLOv8_Pose_ONNX # Pose 专用模型 from yolox.tracker.byte_tracker import BYTETracker from utils.logger import get_logger @@ -55,8 +53,9 @@ PERSON_CAR_INPUT_SIZE = 640 RTSP_TARGET_FPS = 10.0 # ========================= Kadian TrafficMonitor(精简版,专为服务设计) ========================= -class KadianDetector: +class KadianDetector(BaseDetector): def __init__(self, params=None): + super().__init__() # 摄像头额外参数 self.params = params if params is not None else {} @@ -147,8 +146,8 @@ class KadianDetector: self.nobody_frames = 0 # 累计无人在场帧数 self.only_one_frames = 0 # 累计单人在场帧数 - self.max_car_frames = int((15.0 + self.TIME_TOLERANCE_CAR) * self.fps) # - self.frame_buffer_ignore_untrunk = deque(maxlen=self.max_car_frames) + buffer_seconds = 15.0 + self.TIME_TOLERANCE_CAR + self.init_frame_buffer(buffer_seconds, self.fps) self.untrunk_rollback_time = 12.0 # 未检查后备箱需要回溯的时间 self.ignored_rollback_time = 12.0 # 漏检需要回溯的时间 @@ -219,21 +218,6 @@ class KadianDetector: x1, y1, x2, y2 = box return x1 < px < x2 and y1 < py < y2 - def find_target_frame(self, target_time_sec): - - target_frame = None - min_time_diff = float('inf') - for buffered in self.frame_buffer_ignore_untrunk: - time_diff = abs(buffered['timestamp'] - target_time_sec) - if time_diff < min_time_diff: - min_time_diff = time_diff - target_frame = buffered['frame'] - # 如果没找到,返回最早的帧 - if target_frame is None and len(self.frame_buffer_ignore_untrunk) > 0: - target_frame = self.frame_buffer_ignore_untrunk[0]['frame'] - - return target_frame - def process_frame(self, frame, camera_id: int, timestamp: float) -> Dict[str, Any]: h, w = frame.shape[:2] self.width, self.height = w, h @@ -402,11 +386,7 @@ class KadianDetector: cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2) # 每帧保存到缓存(移到循环外,确保每帧只写入一次) - self.frame_buffer_ignore_untrunk.append({ - 'frame_idx': self.current_frame_idx, - 'timestamp': current_time_sec, - 'frame': frame.copy(), - }) + self.append_frame(frame, current_time_sec) # ========================================== # 关联分析: 哪个后备箱属于哪辆车?