更新跟踪框匹配逻辑
This commit is contained in:
@@ -6,51 +6,73 @@ from collections import deque
|
|||||||
from biz.base_frame_processor import BaseFrameProcessorWorker
|
from biz.base_frame_processor import BaseFrameProcessorWorker
|
||||||
from algorithm.common.npu_yolo_onnx_person_car_phone import YOLOv8_ONNX
|
from algorithm.common.npu_yolo_onnx_person_car_phone import YOLOv8_ONNX
|
||||||
from yolox.tracker.byte_tracker import BYTETracker
|
from yolox.tracker.byte_tracker import BYTETracker
|
||||||
|
from common.constants import MODEL_ROOT_PATH
|
||||||
|
|
||||||
# ========================= 走廊场景专属配置 =========================
|
# ========================= 走廊场景专属配置 =========================
|
||||||
MODEL_PATH = 'YOLO_Weight/kanshousuo.onnx'
|
DETECT_MODEL_PATH = 'YOLO_Weight/kanshousuo.onnx' # 犯人检测onnx模型路径
|
||||||
INPUT_SIZE = 640
|
INPUT_SIZE = 640 # 模型输入尺寸
|
||||||
RTSP_FPS = 10
|
RTSP_FPS = 10 # 视频流目标FPS
|
||||||
ALERT_PUSH_INTERVAL = 10
|
ALERT_PUSH_INTERVAL = 5 # 相同报警5秒内仅推送1次
|
||||||
ALERT_PUSH_URL = "http://123.57.151.210:10000/picenter/websocket/test/process"
|
ALERT_PUSH_URL = "http://123.57.151.210:10000/picenter/websocket/test/process"
|
||||||
ROI_LOST_FRAMES_THRESH = int(1 * RTSP_FPS)
|
# 消失判定:中心点在ROI内消失后,持续无检测的帧数(1.0秒,可微调)
|
||||||
|
ROI_LOST_FRAMES_THRESH = int(0.5 * RTSP_FPS) # todo: 从frame改为时间
|
||||||
|
|
||||||
# ========================= ROI区域配置 =========================
|
# ========================= 默认ROI区域配置(当config.yaml未配置时使用) =========================
|
||||||
ROI_CONFIG = {
|
DEFAULT_DOOR_ROIS = {
|
||||||
"left": [[0.195, 0.245], [0.42, 0], [0.421, 0.185], [0.248, 0.8]],
|
"left": {
|
||||||
"right": [[0.575, 0.], [0.81, 0.22], [0.78, 0.8], [0.575, 0.185]],
|
"points": [[0.195, 0.245], [0.42, 0], [0.421, 0.185], [0.248, 0.8]],
|
||||||
|
"color": [255, 0, 0]
|
||||||
|
},
|
||||||
|
"right": {
|
||||||
|
"points": [[0.575, 0.], [0.81, 0.22], [0.78, 0.8], [0.575, 0.185]],
|
||||||
|
"color": [255, 0, 0]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# ==================================================================================
|
||||||
class PrisonerDoorDetector:
|
class PrisonerDoorDetector:
|
||||||
def __init__(self, params=None):
|
def __init__(self, params=None):
|
||||||
self.params = params or {}
|
self.params = params or {}
|
||||||
|
|
||||||
# 1. 加载YOLO模型 - 降低阈值提高检测率
|
# 0. 从params解析ROI配置,无则使用默认值
|
||||||
|
door_rois_config = self.params.get('door_rois', DEFAULT_DOOR_ROIS)
|
||||||
|
self.roi_config = {}
|
||||||
|
self.roi_colors = {}
|
||||||
|
for door_name, door_cfg in door_rois_config.items():
|
||||||
|
self.roi_config[door_name] = door_cfg['points']
|
||||||
|
self.roi_colors[door_name] = tuple(door_cfg['color'])
|
||||||
|
|
||||||
|
model_path = self.params.get('model_path')
|
||||||
|
if model_path:
|
||||||
|
full_model_path = f"{MODEL_ROOT_PATH}/{model_path}"
|
||||||
|
else:
|
||||||
|
full_model_path = DETECT_MODEL_PATH
|
||||||
|
|
||||||
self.detector = YOLOv8_ONNX(
|
self.detector = YOLOv8_ONNX(
|
||||||
MODEL_PATH,
|
full_model_path,
|
||||||
conf_threshold=0.57, # 进一步降低,捕获更多检测
|
conf_threshold=0.7, # 置信度阈值,可根据模型精度调整
|
||||||
iou_threshold=0.4,
|
iou_threshold=0.4, # IOU阈值
|
||||||
input_size=INPUT_SIZE
|
input_size=INPUT_SIZE
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. ByteTracker参数优化
|
# 2. 初始化ByteTracker跟踪器(适配走廊单/多犯人跟踪)
|
||||||
class TrackerArgs:
|
class TrackerArgs:
|
||||||
track_thresh = 0.65 # 更低的跟踪阈值
|
track_thresh = 0.65
|
||||||
track_buffer = 60 # 更大的缓冲,应对短暂消失
|
track_buffer = 60 # 减小缓冲避免跟踪漂移
|
||||||
match_thresh = 0.5 # 更宽松的匹配
|
match_thresh = 0.5
|
||||||
mot20 = False
|
mot20 = False
|
||||||
|
|
||||||
self.tracker = BYTETracker(TrackerArgs(), frame_rate=RTSP_FPS)
|
self.tracker = BYTETracker(TrackerArgs(), frame_rate=RTSP_FPS)
|
||||||
|
|
||||||
# 3. 状态变量
|
# 3. 状态变量初始化
|
||||||
self.last_alert_time = 0.0
|
self.last_alert_time = 0.0 # 最后报警时间(防重复推送)
|
||||||
self.frame_width = 0
|
# 犯人跟踪信息:{track_id: {'is_cx_in_roi': 中心点是否在ROI, 'lost_frames': 消失帧数, 'lost_roi': 消失的ROI名称, 'last_cxcy': 最后中心点坐标}}
|
||||||
self.frame_height = 0
|
self.prisoner_track_info = {}
|
||||||
self.roi_abs_cache = {}
|
self.frame_width = 0 # 帧宽度(动态获取)
|
||||||
|
self.frame_height = 0 # 帧高度(动态获取)
|
||||||
|
self.roi_abs_cache = {} # ROI绝对坐标缓存:{roi_name: np.int32数组}
|
||||||
self.entry_frame_cache = {}
|
self.entry_frame_cache = {}
|
||||||
|
|
||||||
# 【核心改进】基于位置的跟踪状态管理
|
# 基于位置的跟踪状态管理
|
||||||
self.active_targets = {} # {target_id: {...}}
|
self.active_targets = {} # {target_id: {...}}
|
||||||
self.next_target_id = 0
|
self.next_target_id = 0
|
||||||
self.position_history = {} # {target_id: deque of positions}
|
self.position_history = {} # {target_id: deque of positions}
|
||||||
@@ -67,33 +89,31 @@ class PrisonerDoorDetector:
|
|||||||
return np.sqrt((cx1 - cx2) ** 2 + (cy1 - cy2) ** 2)
|
return np.sqrt((cx1 - cx2) ** 2 + (cy1 - cy2) ** 2)
|
||||||
|
|
||||||
def compute_iou(self, boxA, boxB):
|
def compute_iou(self, boxA, boxB):
|
||||||
"""IOU计算"""
|
"""IOU计算:匹配跟踪框与犯人检测框,过滤非犯人目标"""
|
||||||
xA = max(boxA[0], boxB[0])
|
xA = max(boxA[0], boxB[0])
|
||||||
yA = max(boxA[1], boxB[1])
|
yA = max(boxA[1], boxB[1])
|
||||||
xB = min(boxA[2], boxB[2])
|
xB = min(boxA[2], boxB[2])
|
||||||
yB = min(boxA[3], boxB[3])
|
yB = min(boxA[3], boxB[3])
|
||||||
|
|
||||||
interW = max(0, xB - xA)
|
interW = max(0, xB - xA)
|
||||||
interH = max(0, yB - yA)
|
interH = max(0, yB - yA)
|
||||||
interArea = interW * interH
|
interArea = interW * interH
|
||||||
|
boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
|
||||||
boxAArea = max(0, (boxA[2] - boxA[0]) * (boxA[3] - boxA[1]))
|
boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
|
||||||
boxBArea = max(0, (boxB[2] - boxB[0]) * (boxB[3] - boxB[1]))
|
|
||||||
|
|
||||||
unionArea = boxAArea + boxBArea - interArea
|
unionArea = boxAArea + boxBArea - interArea
|
||||||
return interArea / unionArea if unionArea > 0 else 0.0
|
return interArea / unionArea if unionArea > 0 else 0.0
|
||||||
|
|
||||||
def _get_roi_abs(self, roi_name):
|
def _get_roi_abs(self, roi_name):
|
||||||
"""相对坐标转绝对像素坐标"""
|
"""相对坐标转绝对像素坐标(适配当前帧分辨率,OpenCV要求int32)"""
|
||||||
if roi_name not in ROI_CONFIG:
|
if roi_name not in self.roi_config:
|
||||||
return None
|
return None
|
||||||
roi_rel = np.array(ROI_CONFIG[roi_name], dtype=np.float64)
|
roi_rel = np.array(self.roi_config[roi_name], dtype=np.float64)
|
||||||
roi_abs = roi_rel * np.array([self.frame_width, self.frame_height])
|
roi_abs = roi_rel * np.array([self.frame_width, self.frame_height])
|
||||||
return roi_abs.astype(np.int32)
|
return roi_abs.astype(np.int32)
|
||||||
|
|
||||||
def is_cxcy_in_roi(self, cx, cy):
|
def is_cxcy_in_roi(self, cx, cy):
|
||||||
"""判断中心点是否在ROI内"""
|
"""判断犯人框**中心点(cx,cy)** 是否在任意ROI内,返回:(是否在ROI, 所在ROI名称)"""
|
||||||
for roi_name, roi_abs in self.roi_abs_cache.items():
|
for roi_name, roi_abs in self.roi_abs_cache.items():
|
||||||
|
# OpenCV点在多边形内判定:>=0 表示在内部/边上
|
||||||
if cv2.pointPolygonTest(roi_abs, (cx, cy), False) >= 0:
|
if cv2.pointPolygonTest(roi_abs, (cx, cy), False) >= 0:
|
||||||
return (True, roi_name)
|
return (True, roi_name)
|
||||||
return (False, "outside")
|
return (False, "outside")
|
||||||
@@ -142,57 +162,63 @@ class PrisonerDoorDetector:
|
|||||||
|
|
||||||
return best_match_id, best_match_score
|
return best_match_id, best_match_score
|
||||||
|
|
||||||
def push_alert(self, camera_id, target_id, lost_roi, last_cxcy, timestamp, entry_frame):
|
# def push_alert(self, camera_id, target_id, lost_roi, last_cxcy, timestamp, entry_frame):
|
||||||
"""报警推送"""
|
# """报警推送"""
|
||||||
current_time = time.time()
|
# current_time = time.time()
|
||||||
if current_time - self.last_alert_time < ALERT_PUSH_INTERVAL:
|
# if current_time - self.last_alert_time < ALERT_PUSH_INTERVAL:
|
||||||
return False
|
# return False
|
||||||
|
#
|
||||||
|
# _, frame_encoded = cv2.imencode('.jpg', entry_frame)
|
||||||
|
# frame_base64 = frame_encoded.tobytes()
|
||||||
|
#
|
||||||
|
# alert_info = {
|
||||||
|
# "camera_id": camera_id,
|
||||||
|
# "alert_type": "prisoner_cx_disappear_in_roi",
|
||||||
|
# "prisoner_track_id": target_id,
|
||||||
|
# "disappear_roi": lost_roi,
|
||||||
|
# "last_cx": round(last_cxcy[0], 2),
|
||||||
|
# "last_cy": round(last_cxcy[1], 2),
|
||||||
|
# "timestamp": timestamp,
|
||||||
|
# "entry_frame_base64": frame_base64,
|
||||||
|
# "details": f"犯人框中心点在{lost_roi}区域内消失"
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# try:
|
||||||
|
# requests.post(ALERT_PUSH_URL, json=alert_info, timeout=3)
|
||||||
|
# print(f"[报警成功] target_id={target_id}, roi={lost_roi}")
|
||||||
|
# self.last_alert_time = current_time
|
||||||
|
# return True
|
||||||
|
# except Exception as e:
|
||||||
|
# print(f"[报警失败] {str(e)}")
|
||||||
|
# return False
|
||||||
|
|
||||||
_, frame_encoded = cv2.imencode('.jpg', entry_frame)
|
|
||||||
frame_base64 = frame_encoded.tobytes()
|
|
||||||
|
|
||||||
alert_info = {
|
|
||||||
"camera_id": camera_id,
|
|
||||||
"alert_type": "prisoner_cx_disappear_in_roi",
|
|
||||||
"prisoner_track_id": target_id,
|
|
||||||
"disappear_roi": lost_roi,
|
|
||||||
"last_cx": round(last_cxcy[0], 2),
|
|
||||||
"last_cy": round(last_cxcy[1], 2),
|
|
||||||
"timestamp": timestamp,
|
|
||||||
"entry_frame_base64": frame_base64,
|
|
||||||
"details": f"犯人框中心点在{lost_roi}区域内消失"
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
requests.post(ALERT_PUSH_URL, json=alert_info, timeout=3)
|
|
||||||
print(f"[报警成功] target_id={target_id}, roi={lost_roi}")
|
|
||||||
self.last_alert_time = current_time
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[报警失败] {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def process_frame(self, frame, camera_id: int, timestamp: float) -> dict:
|
def process_frame(self, frame, camera_id: int, timestamp: float) -> dict:
|
||||||
"""核心帧处理 - 增强检测版"""
|
"""
|
||||||
|
核心帧处理:
|
||||||
|
1. 绘制5个ROI区域 2. 检测+跟踪犯人 3. 判定中心点是否在ROI内
|
||||||
|
4. 中心点在ROI内消失则累计帧数,达到阈值触发报警
|
||||||
|
"""
|
||||||
self.frame_height, self.frame_width = frame.shape[:2]
|
self.frame_height, self.frame_width = frame.shape[:2]
|
||||||
current_frame_alerts = []
|
current_frame_alerts = [] # 本帧报警信息
|
||||||
frame_copy = frame.copy()
|
frame_copy = frame.copy()
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
||||||
# ========================= 1. 绘制ROI区域 =========================
|
# ========================= 1. 初始化ROI绝对坐标并绘制ROI =========================
|
||||||
roi_colors = {"left": (255, 0, 0), "right": (255, 0, 0)}
|
|
||||||
self.roi_abs_cache.clear()
|
self.roi_abs_cache.clear()
|
||||||
for roi_name, _ in ROI_CONFIG.items():
|
for roi_name in self.roi_config:
|
||||||
roi_abs = self._get_roi_abs(roi_name)
|
roi_abs = self._get_roi_abs(roi_name)
|
||||||
if roi_abs is None:
|
if roi_abs is None:
|
||||||
continue
|
continue
|
||||||
self.roi_abs_cache[roi_name] = roi_abs
|
self.roi_abs_cache[roi_name] = roi_abs
|
||||||
roi_draw = roi_abs.reshape((-1, 1, 2))
|
# 绘制ROI多边形(闭合)+ ROI名称标签
|
||||||
cv2.polylines(frame, [roi_draw], isClosed=True, color=roi_colors[roi_name], thickness=2)
|
roi_draw = roi_abs.reshape((-1, 1, 2)) # OpenCV绘制要求形状 (n,1,2)
|
||||||
|
color = self.roi_colors.get(roi_name, (255, 255, 255))
|
||||||
|
cv2.polylines(frame, [roi_draw], isClosed=True, color=color, thickness=2)
|
||||||
cv2.putText(frame, roi_name, (roi_abs[0][0], roi_abs[0][1] - 5),
|
cv2.putText(frame, roi_name, (roi_abs[0][0], roi_abs[0][1] - 5),
|
||||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, roi_colors[roi_name], 2)
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
|
||||||
|
|
||||||
# ========================= 2. 模型推理 =========================
|
# ========================= 2. 模型推理:仅提取犯人检测框 =========================
|
||||||
detect_results = self.detector(frame)
|
detect_results = self.detector(frame)
|
||||||
prisoner_detections = []
|
prisoner_detections = []
|
||||||
|
|
||||||
@@ -222,6 +248,7 @@ class PrisonerDoorDetector:
|
|||||||
else:
|
else:
|
||||||
track_results = []
|
track_results = []
|
||||||
|
|
||||||
|
|
||||||
# ========================= 4. 【核心改进】融合跟踪和检测 =========================
|
# ========================= 4. 【核心改进】融合跟踪和检测 =========================
|
||||||
# 4.1 先处理跟踪结果
|
# 4.1 先处理跟踪结果
|
||||||
tracked_detections = {} # {track_id: detection_box}
|
tracked_detections = {} # {track_id: detection_box}
|
||||||
@@ -368,7 +395,7 @@ class PrisonerDoorDetector:
|
|||||||
current_frame_alerts.append({
|
current_frame_alerts.append({
|
||||||
"time": timestamp,
|
"time": timestamp,
|
||||||
"camera_id": camera_id,
|
"camera_id": camera_id,
|
||||||
"action": "prisoner_cx_disappear_in_door",
|
"action": "Indoor Violation",
|
||||||
"prisoner_track_id": target_id,
|
"prisoner_track_id": target_id,
|
||||||
"disappear_roi": target_info['current_roi'],
|
"disappear_roi": target_info['current_roi'],
|
||||||
"last_cx": round(target_info['last_cxcy'][0], 2),
|
"last_cx": round(target_info['last_cxcy'][0], 2),
|
||||||
@@ -413,9 +440,9 @@ class PrisonerDoorDetector:
|
|||||||
|
|
||||||
# 根据状态选择颜色
|
# 根据状态选择颜色
|
||||||
if in_roi:
|
if in_roi:
|
||||||
color = (0, 0, 255) # 绿色:在ROI内
|
color = (0, 0, 255) # 红色:在ROI内
|
||||||
else:
|
else:
|
||||||
color = (0, 255, 0) # 橙色:不在ROI内
|
color = (0, 255, 0) # 绿色:不在ROI内
|
||||||
|
|
||||||
# 根据来源选择线型
|
# 根据来源选择线型
|
||||||
thickness = 3 if source == 'tracked' else 2
|
thickness = 3 if source == 'tracked' else 2
|
||||||
@@ -442,6 +469,7 @@ class PrisonerDoorDetector:
|
|||||||
|
|
||||||
# ========================= 帧处理线程 =========================
|
# ========================= 帧处理线程 =========================
|
||||||
class FrameProcessorWorker(BaseFrameProcessorWorker):
|
class FrameProcessorWorker(BaseFrameProcessorWorker):
|
||||||
|
"""看守所走廊犯人检测 - 增强跟踪版"""
|
||||||
DETECTOR_FACTORY = lambda params: PrisonerDoorDetector(params)
|
DETECTOR_FACTORY = lambda params: PrisonerDoorDetector(params)
|
||||||
POST_TYPE = 3
|
POST_TYPE = 3
|
||||||
TARGET_FPS = RTSP_FPS
|
TARGET_FPS = RTSP_FPS
|
||||||
Reference in New Issue
Block a user