Files
SupervisorAI/biz/prison/indoor_biz.py

447 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import time
from collections import deque

import cv2
import numpy as np
import requests
from yolox.tracker.byte_tracker import BYTETracker

from algorithm.common.npu_yolo_onnx_person_car_phone import YOLOv8_ONNX
from biz.base_frame_processor import BaseFrameProcessorWorker
# ========================= Corridor-scene configuration =========================
# Path of the ONNX weights handed to the YOLO detector.
MODEL_PATH = 'YOLO_Weight/kanshousuo.onnx'
# Input size passed to the detector.
INPUT_SIZE = 640
# Frame rate given to the tracker and used as the worker's target FPS.
RTSP_FPS = 10
# Minimum gap between two pushed alerts; compared against time.time() deltas,
# so the unit is seconds.
ALERT_PUSH_INTERVAL = 10
# HTTP endpoint that receives the alert payload.
ALERT_PUSH_URL = "http://123.57.151.210:10000/picenter/websocket/test/process"
# Number of frames a target must stay missing inside an ROI before an alert
# fires (one second worth of frames at RTSP_FPS).
ROI_LOST_FRAMES_THRESH = int(RTSP_FPS * 1)

# ========================= ROI configuration =========================
# Each ROI is a polygon given as [x, y] vertices expressed as fractions of the
# frame width / height; they are scaled to pixels at runtime.
ROI_CONFIG = {
    "left": [
        [0.195, 0.245],
        [0.42, 0],
        [0.421, 0.185],
        [0.248, 0.8],
    ],
    "right": [
        [0.575, 0.0],
        [0.81, 0.22],
        [0.78, 0.8],
        [0.575, 0.185],
    ],
}
class PrisonerDoorDetector:
    """Raise an alert when a detected person vanishes inside a door ROI.

    Per frame the detector:

    1. runs the YOLO model and keeps class-1 boxes (treated as prisoners),
    2. feeds them to ByteTrack,
    3. fuses tracker output with the raw detections,
    4. associates the fused boxes with its own position-based targets,
    5. counts consecutive missed frames for targets last seen inside an ROI
       and pushes an alert once ``ROI_LOST_FRAMES_THRESH`` is reached.
    """

    # Seconds without any update after which a target is dropped silently.
    TARGET_TIMEOUT_SECONDS = 5.0
    # Minimum combined score for a detection to be attached to a target.
    MATCH_SCORE_THRESHOLD = 0.3
    # BGR colour used for any ROI without an explicit entry below.
    DEFAULT_ROI_COLOR = (255, 0, 0)
    ROI_COLORS = {"left": (255, 0, 0), "right": (255, 0, 0)}

    def __init__(self, params=None):
        """Load the detector and tracker and reset all tracking state.

        Args:
            params: Optional configuration mapping; stored on ``self.params``
                (an empty dict when falsy). Not read elsewhere in this class.
        """
        self.params = params or {}

        # 1. YOLO detector.
        self.detector = YOLOv8_ONNX(
            MODEL_PATH,
            conf_threshold=0.57,
            iou_threshold=0.4,
            input_size=INPUT_SIZE,
        )

        # 2. ByteTrack settings, passed as an attribute bag.
        #    NOTE(review): attribute meanings follow the ByteTrack reference
        #    implementation - confirm against the vendored yolox version.
        class TrackerArgs:
            track_thresh = 0.65
            track_buffer = 60
            match_thresh = 0.5
            mot20 = False

        self.tracker = BYTETracker(TrackerArgs(), frame_rate=RTSP_FPS)

        # 3. Per-instance state.
        self.last_alert_time = 0.0   # time.time() of the last pushed alert
        self.frame_width = 0
        self.frame_height = 0
        self.roi_abs_cache = {}      # {roi_name: int32 polygon in pixels}
        self.entry_frame_cache = {}  # {target_id: frame when ROI was entered}

        # Position-based target bookkeeping, independent of ByteTrack ids.
        self.active_targets = {}     # {target_id: state dict}
        self.next_target_id = 0
        self.position_history = {}   # {target_id: deque of centre points}

        # Pixel distance at which the distance score drops to zero.
        self.distance_threshold = 100

    # ------------------------------------------------------------------
    # Geometry helpers
    # ------------------------------------------------------------------
    def compute_center_distance(self, box1, box2):
        """Return the Euclidean distance between two box centres.

        Both boxes are ``[x1, y1, x2, y2, ...]`` sequences.
        """
        cx1 = (box1[0] + box1[2]) / 2
        cy1 = (box1[1] + box1[3]) / 2
        cx2 = (box2[0] + box2[2]) / 2
        cy2 = (box2[1] + box2[3]) / 2
        return np.sqrt((cx1 - cx2) ** 2 + (cy1 - cy2) ** 2)

    def compute_iou(self, boxA, boxB):
        """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

        Returns 0.0 when the union area is not positive.
        """
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        yB = min(boxA[3], boxB[3])
        interArea = max(0, xB - xA) * max(0, yB - yA)
        boxAArea = max(0, (boxA[2] - boxA[0]) * (boxA[3] - boxA[1]))
        boxBArea = max(0, (boxB[2] - boxB[0]) * (boxB[3] - boxB[1]))
        unionArea = boxAArea + boxBArea - interArea
        return interArea / unionArea if unionArea > 0 else 0.0

    def _get_roi_abs(self, roi_name):
        """Scale a relative ROI polygon to pixel coordinates.

        Returns an ``int32`` array of vertices, or ``None`` when ``roi_name``
        is not present in ``ROI_CONFIG``.
        """
        if roi_name not in ROI_CONFIG:
            return None
        roi_rel = np.array(ROI_CONFIG[roi_name], dtype=np.float64)
        roi_abs = roi_rel * np.array([self.frame_width, self.frame_height])
        return roi_abs.astype(np.int32)

    def is_cxcy_in_roi(self, cx, cy):
        """Tell whether a point lies inside (or on the edge of) a cached ROI.

        Returns:
            ``(True, roi_name)`` for the first ROI containing the point,
            otherwise ``(False, "outside")``.
        """
        # Plain Python floats: some OpenCV builds reject NumPy scalars here.
        point = (float(cx), float(cy))
        for roi_name, roi_abs in self.roi_abs_cache.items():
            if cv2.pointPolygonTest(roi_abs, point, False) >= 0:
                return (True, roi_name)
        return (False, "outside")

    # ------------------------------------------------------------------
    # Association
    # ------------------------------------------------------------------
    def match_detection_to_target(self, detection_box, detection_conf):
        """Find the active target that best explains a detection box.

        The score blends centre distance (30 %) and IoU (70 %); IoU only
        counts when the target was updated less than one second ago. When the
        target has at least two history points, a linear position prediction
        is blended in (70 % base score, 30 % prediction score).

        Args:
            detection_box: ``[x1, y1, x2, y2]`` of the detection.
            detection_conf: Detection confidence (currently unused).

        Returns:
            ``(target_id, score)``; ``(None, 0)`` when no target scores above
            ``MATCH_SCORE_THRESHOLD``.
        """
        best_match_id = None
        best_match_score = 0
        det_center = np.array([(detection_box[0] + detection_box[2]) / 2,
                               (detection_box[1] + detection_box[3]) / 2])
        now = time.time()

        for target_id, target_info in self.active_targets.items():
            last_box = target_info['last_box']
            last_center = np.array([(last_box[0] + last_box[2]) / 2,
                                    (last_box[1] + last_box[3]) / 2])
            distance = np.linalg.norm(det_center - last_center)

            # IoU is only trusted while the last known box is still fresh.
            recently_updated = now - target_info['last_update_time'] < 1.0
            iou_score = self.compute_iou(detection_box, last_box) if recently_updated else 0

            distance_score = max(0, 1 - distance / self.distance_threshold)
            match_score = 0.3 * distance_score + 0.7 * iou_score

            # Simple constant-velocity prediction from the last two points.
            history = self.position_history.get(target_id)
            if history is not None and len(history) >= 2:
                velocity = history[-1] - history[-2]
                predicted_pos = last_center + velocity
                pred_distance = np.linalg.norm(det_center - predicted_pos)
                pred_score = max(0, 1 - pred_distance / self.distance_threshold)
                match_score = 0.7 * match_score + 0.3 * pred_score

            if match_score > best_match_score and match_score > self.MATCH_SCORE_THRESHOLD:
                best_match_score = match_score
                best_match_id = target_id

        return best_match_id, best_match_score

    # ------------------------------------------------------------------
    # Alerting
    # ------------------------------------------------------------------
    @staticmethod
    def _build_alert_payload(camera_id, target_id, lost_roi, last_cxcy, timestamp, jpeg_bytes):
        """Build the JSON-serializable alert body.

        The JPEG is base64-encoded to text and the coordinates are cast to
        plain floats, because ``json`` rejects both ``bytes`` and NumPy
        ``float32`` scalars.
        """
        return {
            "camera_id": camera_id,
            "alert_type": "prisoner_cx_disappear_in_roi",
            "prisoner_track_id": target_id,
            "disappear_roi": lost_roi,
            "last_cx": round(float(last_cxcy[0]), 2),
            "last_cy": round(float(last_cxcy[1]), 2),
            "timestamp": timestamp,
            "entry_frame_base64": base64.b64encode(jpeg_bytes).decode("ascii"),
            "details": f"犯人框中心点在{lost_roi}区域内消失"
        }

    def push_alert(self, camera_id, target_id, lost_roi, last_cxcy, timestamp, entry_frame):
        """POST an alert to ``ALERT_PUSH_URL``, rate-limited per detector.

        Args:
            camera_id: Identifier of the camera.
            target_id: Internal id of the target that disappeared.
            lost_roi: Name of the ROI in which it disappeared.
            last_cxcy: Last known ``(cx, cy)`` centre of the target.
            timestamp: Frame timestamp forwarded unchanged.
            entry_frame: Image to attach, encoded as JPEG.

        Returns:
            ``True`` when the alert was accepted by the server; ``False`` when
            it was rate-limited, could not be encoded, or the request failed.
        """
        current_time = time.time()
        if current_time - self.last_alert_time < ALERT_PUSH_INTERVAL:
            return False

        # Best-effort boundary: any failure is logged and must not abort
        # frame processing, hence the broad except.
        try:
            encoded_ok, frame_encoded = cv2.imencode('.jpg', entry_frame)
            if not encoded_ok:
                raise ValueError("cv2.imencode could not encode the entry frame")
            alert_info = self._build_alert_payload(
                camera_id, target_id, lost_roi, last_cxcy, timestamp,
                frame_encoded.tobytes())
            response = requests.post(ALERT_PUSH_URL, json=alert_info, timeout=3)
            # Treat HTTP 4xx/5xx as a failed push rather than a success.
            response.raise_for_status()
        except Exception as e:
            print(f"[报警失败] {str(e)}")
            return False

        print(f"[报警成功] target_id={target_id}, roi={lost_roi}")
        self.last_alert_time = current_time
        return True

    # ------------------------------------------------------------------
    # Frame pipeline
    # ------------------------------------------------------------------
    def process_frame(self, frame, camera_id: int, timestamp: float) -> dict:
        """Process one frame: detect, track, update targets, alert, annotate.

        Args:
            frame: Image array; it is annotated in place.
            camera_id: Identifier of the camera.
            timestamp: Timestamp attached to any alert.

        Returns:
            ``{"image": annotated frame, "alerts": list of alert dicts}``.
        """
        self.frame_height, self.frame_width = frame.shape[:2]
        # Clean snapshot taken before any overlay is drawn; used for inference
        # and as the image attached to alerts.
        frame_copy = frame.copy()
        current_time = time.time()

        self._draw_rois(frame)
        # Inference runs on the clean snapshot so ROI overlays cannot disturb
        # the model input.
        prisoner_detections = self._detect_prisoners(frame_copy)
        track_results = self._run_tracker(prisoner_detections)
        tracked_detections = self._fuse_tracks_and_detections(track_results, prisoner_detections)
        current_target_ids = self._update_targets(tracked_detections, frame_copy, current_time)
        current_frame_alerts = self._handle_lost_targets(
            current_target_ids, camera_id, timestamp, frame_copy)
        self._purge_stale_targets(current_time)
        self._draw_targets(frame)
        self._draw_stats(frame, camera_id)
        return {"image": frame, "alerts": current_frame_alerts}

    def _draw_rois(self, frame):
        """Rebuild the pixel ROI cache for the current frame size and draw it."""
        self.roi_abs_cache.clear()
        for roi_name in ROI_CONFIG:
            roi_abs = self._get_roi_abs(roi_name)
            if roi_abs is None:
                continue
            self.roi_abs_cache[roi_name] = roi_abs
            color = self.ROI_COLORS.get(roi_name, self.DEFAULT_ROI_COLOR)
            cv2.polylines(frame, [roi_abs.reshape((-1, 1, 2))], isClosed=True,
                          color=color, thickness=2)
            label_origin = (int(roi_abs[0][0]), int(roi_abs[0][1]) - 5)
            cv2.putText(frame, roi_name, label_origin,
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    def _detect_prisoners(self, image):
        """Run the detector and return clamped class-1 boxes.

        Returns a list of ``[x1, y1, x2, y2, conf, cls_id]``; boxes that are
        degenerate or cover 100 px^2 or less are dropped.
        """
        detect_results = self.detector(image)
        # Works for None, lists and arrays alike (array truthiness would raise).
        if detect_results is None or len(detect_results) == 0:
            return []

        max_x = self.frame_width - 1
        max_y = self.frame_height - 1
        prisoner_detections = []
        for det in detect_results:
            x1, y1, x2, y2, conf, cls_id = det
            # Clamp the box to the image.
            x1 = max(0, min(x1, max_x))
            y1 = max(0, min(y1, max_y))
            x2 = max(0, min(x2, max_x))
            y2 = max(0, min(y2, max_y))
            if cls_id == 1 and x2 > x1 and y2 > y1 and (x2 - x1) * (y2 - y1) > 100:
                prisoner_detections.append([x1, y1, x2, y2, conf, cls_id])
        return prisoner_detections

    def _run_tracker(self, prisoner_detections):
        """Update ByteTrack with this frame's boxes; no-op when there are none."""
        if not prisoner_detections:
            return []
        det_boxes = np.array([det[:5] for det in prisoner_detections], dtype=np.float32)
        frame_size = [self.frame_height, self.frame_width]
        return self.tracker.update(det_boxes, frame_size, list(frame_size))

    def _fuse_tracks_and_detections(self, track_results, prisoner_detections):
        """Merge tracker output with raw detections.

        Each track takes the unused detection with the highest positive IoU
        (source ``'tracked'``) or falls back to its own box with confidence
        0.5 (``'track_only'``). Detections claimed by no track are kept as
        ``'det_only'``.

        Returns:
            ``{key: {'box', 'conf', 'source'}}`` keyed ``track_<id>`` or
            ``det_<index>``.
        """
        tracked_detections = {}
        used_det_indices = set()

        for track in track_results:
            t_box = [float(x) for x in track.tlbr]
            best_iou = 0.0
            best_det_idx = -1
            for det_idx, det in enumerate(prisoner_detections):
                if det_idx in used_det_indices:
                    continue
                iou = self.compute_iou(t_box, det[:4])
                if iou > best_iou:
                    best_iou = iou
                    best_det_idx = det_idx

            key = f"track_{track.track_id}"
            if best_det_idx != -1:
                # Use the detection box for a track that has one.
                matched = prisoner_detections[best_det_idx]
                tracked_detections[key] = {
                    'box': matched[:4],
                    'conf': matched[4],
                    'source': 'tracked'
                }
                used_det_indices.add(best_det_idx)
            else:
                # Keep the tracker's own box with a mid-range confidence.
                tracked_detections[key] = {
                    'box': t_box,
                    'conf': 0.5,
                    'source': 'track_only'
                }

        for det_idx, det in enumerate(prisoner_detections):
            if det_idx not in used_det_indices:
                tracked_detections[f"det_{det_idx}"] = {
                    'box': det[:4],
                    'conf': det[4],
                    'source': 'det_only'
                }
        return tracked_detections

    def _update_targets(self, tracked_detections, frame_copy, current_time):
        """Attach fused boxes to existing targets or create new ones.

        Args:
            tracked_detections: Output of ``_fuse_tracks_and_detections``.
            frame_copy: Clean frame stored when a target enters an ROI. It is
                stored by reference; it is never modified afterwards.
            current_time: ``time.time()`` captured for this frame.

        Returns:
            Set of target ids seen in this frame.
        """
        current_target_ids = set()

        for det_info in tracked_detections.values():
            det_box = det_info['box']
            det_conf = det_info['conf']
            cx = (det_box[0] + det_box[2]) / 2
            cy = (det_box[1] + det_box[3]) / 2

            target_id, _ = self.match_detection_to_target(det_box, det_conf)
            is_cx_in_roi, current_roi = self.is_cxcy_in_roi(cx, cy)

            if target_id is not None:
                target_info = self.active_targets[target_id]
                # Remember the frame on the outside -> inside transition.
                if is_cx_in_roi and not target_info.get('in_roi', False):
                    self.entry_frame_cache[target_id] = frame_copy
                target_info.update({
                    'last_box': det_box,
                    'last_cxcy': (cx, cy),
                    'last_conf': det_conf,
                    'last_update_time': current_time,
                    'in_roi': is_cx_in_roi,
                    # Keep the last ROI name once the target has left it.
                    'current_roi': current_roi if is_cx_in_roi
                    else target_info.get('current_roi', 'outside'),
                    # Any sighting restarts the miss counter, so the alert
                    # threshold counts consecutive missed frames only.
                    'lost_frames': 0,
                    'detection_source': det_info['source']
                })
            else:
                target_id = self.next_target_id
                self.next_target_id += 1
                self.active_targets[target_id] = {
                    'first_seen': current_time,
                    'last_box': det_box,
                    'last_cxcy': (cx, cy),
                    'last_conf': det_conf,
                    'last_update_time': current_time,
                    'in_roi': is_cx_in_roi,
                    'current_roi': current_roi if is_cx_in_roi else 'outside',
                    'lost_frames': 0,
                    'detection_source': det_info['source']
                }
                if is_cx_in_roi:
                    self.entry_frame_cache[target_id] = frame_copy

            history = self.position_history.setdefault(target_id, deque(maxlen=10))
            history.append(np.array([cx, cy]))
            current_target_ids.add(target_id)

        return current_target_ids

    def _forget_target(self, target_id):
        """Drop every piece of state kept for ``target_id``."""
        self.active_targets.pop(target_id, None)
        self.position_history.pop(target_id, None)
        self.entry_frame_cache.pop(target_id, None)

    def _handle_lost_targets(self, current_target_ids, camera_id, timestamp, frame_copy):
        """Age targets missing from this frame and alert on ROI disappearances.

        Targets missing outside an ROI are dropped at once. Targets missing
        inside an ROI are kept until ``ROI_LOST_FRAMES_THRESH`` consecutive
        misses, then reported and dropped.

        Returns:
            List of alert dicts produced for this frame.
        """
        alerts = []
        for target_id in list(self.active_targets):
            if target_id in current_target_ids:
                continue
            target_info = self.active_targets[target_id]

            if target_info['in_roi']:
                target_info['lost_frames'] += 1
                if target_info['lost_frames'] < ROI_LOST_FRAMES_THRESH:
                    continue
                # Fall back to the current frame if no entry frame was cached.
                entry_frame = self.entry_frame_cache.get(target_id, frame_copy)
                self.push_alert(
                    camera_id=camera_id,
                    target_id=target_id,
                    lost_roi=target_info['current_roi'],
                    last_cxcy=target_info['last_cxcy'],
                    timestamp=timestamp,
                    entry_frame=entry_frame
                )
                alerts.append({
                    "time": timestamp,
                    "camera_id": camera_id,
                    "action": "prisoner_cx_disappear_in_door",
                    "prisoner_track_id": target_id,
                    "disappear_roi": target_info['current_roi'],
                    "last_cx": round(float(target_info['last_cxcy'][0]), 2),
                    "last_cy": round(float(target_info['last_cxcy'][1]), 2)
                })

            self._forget_target(target_id)
        return alerts

    def _purge_stale_targets(self, current_time):
        """Drop targets not updated for ``TARGET_TIMEOUT_SECONDS``."""
        for target_id in list(self.active_targets):
            idle = current_time - self.active_targets[target_id]['last_update_time']
            if idle > self.TARGET_TIMEOUT_SECONDS:
                self._forget_target(target_id)

    def _draw_targets(self, frame):
        """Draw box, centre point and label of every active target."""
        for target_id, target_info in self.active_targets.items():
            box = target_info['last_box']
            cx, cy = target_info['last_cxcy']
            source = target_info.get('detection_source', 'unknown')

            # BGR: red while inside an ROI, green otherwise.
            color = (0, 0, 255) if target_info['in_roi'] else (0, 255, 0)
            # Thicker outline for boxes confirmed by both tracker and detector.
            thickness = 3 if source == 'tracked' else 2

            top_left = (int(box[0]), int(box[1]))
            cv2.rectangle(frame, top_left, (int(box[2]), int(box[3])), color, thickness)
            cv2.circle(frame, (int(cx), int(cy)), 5, color, -1)

            status = f"T{target_id}_{target_info['current_roi'][:2]}"
            if source == 'det_only':
                status += "_DET"
            cv2.putText(frame, status, (top_left[0], top_left[1] - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    def _draw_stats(self, frame, camera_id):
        """Overlay the camera id and the number of active targets."""
        cv2.putText(frame, f"Camera: {camera_id}", (20, self.frame_height - 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        cv2.putText(frame, f"Active Targets: {len(self.active_targets)}",
                    (20, self.frame_height - 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 2)
# ========================= Frame-processing thread =========================
class FrameProcessorWorker(BaseFrameProcessorWorker):
    """Worker configuration that plugs PrisonerDoorDetector into the base worker.

    NOTE(review): how BaseFrameProcessorWorker consumes these attributes is
    not visible in this file - confirm against the base class.
    """

    # Frame rate requested for this worker; shared with the tracker setting.
    TARGET_FPS = RTSP_FPS
    POST_TYPE = 3

    # Plain function stored as a class attribute: builds the detector from the
    # parameters it is given.
    def DETECTOR_FACTORY(params):
        return PrisonerDoorDetector(params)