Files
SupervisorAI/biz/prison/indoor_biz.py

487 lines
21 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import time
# import requests
from collections import deque
from biz.base_detector import BaseDetector
from biz.base_frame_processor import BaseFrameProcessorWorker
from algorithm.common.npu_yolo_onnx_person_car_phone import YOLOv8_ONNX
from yolox.tracker.byte_tracker import BYTETracker
from common.constants import MODEL_ROOT_PATH
# ========================= Corridor-scene configuration =========================
DETECT_MODEL_PATH = 'YOLO_Weight/kanshousuo.onnx' # Fallback ONNX weights for the prisoner detector (used when params has no 'model_path')
INPUT_SIZE = 640 # Model input size passed to the detector
RTSP_FPS = 10 # Target FPS of the video stream
ALERT_PUSH_INTERVAL = 5 # Push an identical alert at most once per 5 seconds (only the commented-out push_alert reads this in the code shown here)
# ALERT_PUSH_URL = "http://123.57.151.210:10000/picenter/websocket/test/process"
# Disappearance threshold: number of frames without a detection, counted for a target
# whose centre point was last seen inside an ROI, before an alert is raised.
# Currently 0.5 s worth of frames at RTSP_FPS (tunable).
ROI_LOST_FRAMES_THRESH = int(0.5 * RTSP_FPS) # TODO: express this as a duration instead of a frame count
# ========================= Default ROI polygons (used when params carries no 'door_rois') =========================
# "points": polygon vertices as [x, y] fractions of the frame width / height
#           (they are multiplied by the frame size in _get_roi_abs).
# "color":  colour handed to the OpenCV drawing calls - presumably BGR; confirm against the frame source.
DEFAULT_DOOR_ROIS = {
"left": {
"points": [[0.195, 0.245], [0.42, 0], [0.421, 0.185], [0.248, 0.8]],
"color": [255, 0, 0]
},
"right": {
"points": [[0.575, 0.], [0.81, 0.22], [0.78, 0.8], [0.575, 0.185]],
"color": [255, 0, 0]
}
}
# ==================================================================================
class PrisonerDoorDetector(BaseDetector):
def __init__(self, params=None):
    """Build the detector, the ByteTrack tracker and the per-target state.

    Args:
        params: Optional dict of overrides. Keys read here:
            ``door_rois``  - mapping of ROI name to a dict holding
                ``points`` (polygon vertices as fractions of the frame
                width/height) and, optionally, ``color``. Falls back to
                ``DEFAULT_DOOR_ROIS`` when missing or ``None``.
            ``model_path`` - weights path relative to ``MODEL_ROOT_PATH``.
                Falls back to ``DETECT_MODEL_PATH`` when missing or empty.
    """
    super().__init__()
    self.params = params or {}

    # 0. ROI configuration. A key that is present but null (e.g. an empty
    #    YAML entry) is treated the same as a missing key.
    door_rois_config = self.params.get('door_rois')
    if door_rois_config is None:
        door_rois_config = DEFAULT_DOOR_ROIS
    self.roi_config = {}
    self.roi_colors = {}
    for door_name, door_cfg in door_rois_config.items():
        self.roi_config[door_name] = door_cfg['points']
        # The colour is optional: process_frame falls back to white for
        # any ROI that has no entry in roi_colors.
        color = door_cfg.get('color')
        if color is not None:
            self.roi_colors[door_name] = tuple(color)

    # 1. Detector: a configured path is resolved against MODEL_ROOT_PATH,
    #    otherwise the module-level default is used as-is.
    model_path = self.params.get('model_path')
    if model_path:
        full_model_path = f"{MODEL_ROOT_PATH}/{model_path}"
    else:
        full_model_path = DETECT_MODEL_PATH
    self.detector = YOLOv8_ONNX(
        full_model_path,
        conf_threshold=0.7,  # confidence threshold; tune to the model's accuracy
        iou_threshold=0.4,   # IoU threshold
        input_size=INPUT_SIZE,
    )

    # 2. ByteTrack tracker (single or multiple prisoners in the corridor).
    class TrackerArgs:
        track_thresh = 0.65
        track_buffer = 60  # kept small to limit track drift
        match_thresh = 0.5
        mot20 = False

    self.tracker = BYTETracker(TrackerArgs(), frame_rate=RTSP_FPS)

    # 3. State.
    # self.last_alert_time = 0.0  # time of the last alert (push de-duplication)
    # Legacy per-track bookkeeping; not read by the methods shown in this file.
    self.prisoner_track_info = {}
    self.frame_width = 0    # refreshed from every processed frame
    self.frame_height = 0   # refreshed from every processed frame
    self.roi_abs_cache = {}      # {roi_name: int32 array of pixel vertices}
    self.entry_frame_cache = {}  # {target_id: frame copy taken on ROI entry}

    # Position-based target management.
    self.active_targets = {}     # {target_id: state dict}
    self.next_target_id = 0
    self.position_history = {}   # {target_id: deque of centre points}
    # Pixel distance used to normalise the proximity score when matching
    # a detection against known targets.
    self.distance_threshold = 100

    buffer_seconds = 3  # look back at most 3 seconds
    self.init_frame_buffer(buffer_seconds, RTSP_FPS)
    self.detect_rollback_time = 0.9  # seconds to roll back for the alert frame
def compute_center_distance(self, box1, box2):
    """Return the Euclidean distance between the centres of two boxes.

    Each box is indexed as ``[x1, y1, x2, y2]``.
    """
    centre_a = ((box1[0] + box1[2]) / 2, (box1[1] + box1[3]) / 2)
    centre_b = ((box2[0] + box2[2]) / 2, (box2[1] + box2[3]) / 2)
    dx = centre_a[0] - centre_b[0]
    dy = centre_a[1] - centre_b[1]
    return np.sqrt(dx ** 2 + dy ** 2)
def compute_iou(self, boxA, boxB):
    """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    Used to pair tracker boxes with prisoner detection boxes. Yields 0.0
    when the union has no area.
    """
    # Corners of the overlap rectangle (may be inverted when disjoint).
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[2], boxB[2])
    bottom = min(boxA[3], boxB[3])

    # Clamp each side at zero so disjoint boxes contribute no overlap.
    overlap = max(0, right - left) * max(0, bottom - top)

    area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union = area_a + area_b - overlap

    if union > 0:
        return overlap / union
    return 0.0
def _get_roi_abs(self, roi_name):
    """Scale one ROI's fractional vertices to pixel coordinates.

    Returns an int32 array (the dtype OpenCV expects for polygons) sized
    for the current ``frame_width`` / ``frame_height``, or ``None`` when
    ``roi_name`` is not configured.
    """
    try:
        rel_points = self.roi_config[roi_name]
    except KeyError:
        return None
    frame_size = np.array([self.frame_width, self.frame_height])
    scaled = np.array(rel_points, dtype=np.float64) * frame_size
    return scaled.astype(np.int32)
def is_cxcy_in_roi(self, cx, cy):
    """Tell whether the box centre ``(cx, cy)`` lies inside any cached ROI.

    Args:
        cx, cy: Centre point in pixel coordinates.

    Returns:
        ``(True, roi_name)`` for the first ROI in ``self.roi_abs_cache``
        that contains the point (points on the edge count as inside),
        otherwise ``(False, "outside")``.
    """
    # Callers derive the centre from detector output, so it can arrive as a
    # NumPy scalar. Coerce to built-in floats because the OpenCV binding
    # may refuse NumPy scalar types inside the point tuple.
    point = (float(cx), float(cy))
    for roi_name, roi_abs in self.roi_abs_cache.items():
        # pointPolygonTest with measureDist=False: >= 0 means inside or on the edge.
        if cv2.pointPolygonTest(roi_abs, point, False) >= 0:
            return (True, roi_name)
    return (False, "outside")
def match_detection_to_target(self, detection_box, detection_conf):
    """Pick the active target that best explains a detection box.

    Every target in ``self.active_targets`` is scored from
      * proximity of its last box centre to the detection centre,
        normalised by ``self.distance_threshold``,
      * IoU with its last box, counted only when the target was updated
        less than one second ago,
      * and, when at least two history points exist, proximity to a
        linearly extrapolated position.

    Args:
        detection_box: ``[x1, y1, x2, y2]`` of the detection.
        detection_conf: Detection confidence (not used in the score).

    Returns:
        ``(target_id, score)`` of the best target scoring above 0.3, or
        ``(None, 0)`` when no target qualifies.
    """
    def centre_of(box):
        return np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2])

    min_score = 0.3  # acceptance threshold, tunable
    reach = self.distance_threshold
    det_center = centre_of(detection_box)
    winner_id, winner_score = None, 0

    for target_id, target_info in self.active_targets.items():
        last_box = target_info['last_box']
        last_center = centre_of(last_box)
        gap = np.linalg.norm(det_center - last_center)

        # IoU is only trusted while the stored box is fresh.
        age = time.time() - target_info['last_update_time']
        if age < 1.0:
            overlap = self.compute_iou(detection_box, last_box)
        else:
            overlap = 0

        # Combined score: close centre + high IoU.
        proximity = max(0, 1 - gap / reach)
        score = 0.3 * proximity + 0.7 * overlap

        # Blend in a simple linear motion prediction when history allows.
        trail = self.position_history.get(target_id, ())
        if len(trail) >= 2:
            predicted = last_center + (trail[-1] - trail[-2])
            pred_gap = np.linalg.norm(det_center - predicted)
            score = 0.7 * score + 0.3 * max(0, 1 - pred_gap / reach)

        if score > min_score and score > winner_score:
            winner_id, winner_score = target_id, score

    return winner_id, winner_score
# def push_alert(self, camera_id, target_id, lost_roi, last_cxcy, timestamp, entry_frame):
# """报警推送"""
# current_time = time.time()
# if current_time - self.last_alert_time < ALERT_PUSH_INTERVAL:
# return False
#
# _, frame_encoded = cv2.imencode('.jpg', entry_frame)
# frame_base64 = frame_encoded.tobytes()
#
# alert_info = {
# "camera_id": camera_id,
# "alert_type": "prisoner_cx_disappear_in_roi",
# "prisoner_track_id": target_id,
# "disappear_roi": lost_roi,
# "last_cx": round(last_cxcy[0], 2),
# "last_cy": round(last_cxcy[1], 2),
# "timestamp": timestamp,
# "entry_frame_base64": frame_base64,
# "details": f"犯人框中心点在{lost_roi}区域内消失"
# }
#
# try:
# requests.post(ALERT_PUSH_URL, json=alert_info, timeout=3)
# print(f"[报警成功] target_id={target_id}, roi={lost_roi}")
# self.last_alert_time = current_time
# return True
# except Exception as e:
# print(f"[报警失败] {str(e)}")
# return False
def process_frame(self, frame, camera_id: int, timestamp: float) -> dict:
    """Detect, track and raise ROI-disappearance alerts for one frame.

    Steps:
        1. Scale the configured ROIs to pixels and draw them on ``frame``.
        2. Run the detector on an overlay-free copy and keep class-id-1
           boxes (treated as prisoners here).
        3. Feed those boxes to ByteTrack and fuse tracker output with the
           raw detections.
        4. Associate every fused box with a position-based target in
           ``self.active_targets`` (or create a new target).
        5. For targets last seen inside an ROI that are not observed in
           this frame, count the missed frame; once
           ``ROI_LOST_FRAMES_THRESH`` consecutive frames are missed, emit
           an alert and forget the target.
        6. Draw targets and statistics, then append the annotated frame to
           the frame buffer.

    Args:
        frame: Image array whose first two dimensions are (height, width).
            It is annotated in place.
        camera_id: Camera identifier, echoed into alerts and the overlay.
        timestamp: Frame timestamp, used for alerts and the frame buffer.

    Returns:
        ``{"image": annotated frame, "alerts": list of alert dicts}``.
    """
    self.frame_height, self.frame_width = frame.shape[:2]
    current_frame_alerts = []
    frame_copy = frame.copy()  # pristine pixels, kept free of any overlay
    current_time = time.time()

    # ========================= 1. ROI pixel coordinates + drawing =========================
    self.roi_abs_cache.clear()
    for roi_name in self.roi_config:
        roi_abs = self._get_roi_abs(roi_name)
        if roi_abs is None:
            continue
        self.roi_abs_cache[roi_name] = roi_abs
        roi_draw = roi_abs.reshape((-1, 1, 2))  # OpenCV wants shape (n, 1, 2)
        color = self.roi_colors.get(roi_name, (255, 255, 255))
        cv2.polylines(frame, [roi_draw], isClosed=True, color=color, thickness=2)
        # Built-in ints: the OpenCV binding may refuse NumPy scalars in a point.
        label_origin = (int(roi_abs[0][0]), int(roi_abs[0][1]) - 5)
        cv2.putText(frame, roi_name, label_origin,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    # ========================= 2. Inference: keep prisoner boxes only =========================
    # Run on the clean copy so the ROI outlines and labels drawn above do
    # not become part of the detector's input.
    detect_results = self.detector(frame_copy)
    prisoner_detections = []
    if detect_results:
        max_x = self.frame_width - 1
        max_y = self.frame_height - 1
        for det in detect_results:
            x1, y1, x2, y2, conf, cls_id = det
            # Clamp the box to the image.
            x1 = max(0, min(x1, max_x))
            y1 = max(0, min(y1, max_y))
            x2 = max(0, min(x2, max_x))
            y2 = max(0, min(y2, max_y))
            # Class 1 only; drop degenerate and tiny (<= 100 px^2) boxes.
            if cls_id == 1 and x2 > x1 and y2 > y1 and (x2 - x1) * (y2 - y1) > 100:
                prisoner_detections.append([x1, y1, x2, y2, conf, cls_id])

    # ========================= 3. ByteTrack =========================
    # The tracker is only updated on frames that have detections.
    if prisoner_detections:
        prisoner_det_boxes = np.array(
            [det[:5] for det in prisoner_detections], dtype=np.float32)
        track_results = self.tracker.update(
            prisoner_det_boxes,
            [self.frame_height, self.frame_width],
            [self.frame_height, self.frame_width]
        )
    else:
        track_results = []

    # ========================= 4. Fuse tracker output and detections =========================
    # 4.1 Tracker boxes first: each takes the unused detection it overlaps most.
    tracked_detections = {}  # {key: {'box', 'conf', 'source'}}
    used_det_indices = set()
    for track in track_results:
        track_id = track.track_id
        t_box = [float(x) for x in track.tlbr]
        best_iou = 0.0  # any positive overlap qualifies
        best_det_idx = -1
        for det_idx, det in enumerate(prisoner_detections):
            if det_idx in used_det_indices:
                continue
            iou = self.compute_iou(t_box, det[:4])
            if iou > best_iou:
                best_iou = iou
                best_det_idx = det_idx
        if best_det_idx != -1:
            # Prefer the detection box over the tracker box.
            matched_det = prisoner_detections[best_det_idx]
            tracked_detections[f"track_{track_id}"] = {
                'box': matched_det[:4],
                'conf': matched_det[4],
                'source': 'tracked'
            }
            used_det_indices.add(best_det_idx)
        else:
            # No detection backs this track; keep the tracker box anyway.
            tracked_detections[f"track_{track_id}"] = {
                'box': t_box,
                'conf': 0.5,  # neutral confidence
                'source': 'track_only'
            }
    # 4.2 Detections no track claimed.
    for det_idx, det in enumerate(prisoner_detections):
        if det_idx not in used_det_indices:
            tracked_detections[f"det_{det_idx}"] = {
                'box': det[:4],
                'conf': det[4],
                'source': 'det_only'
            }

    # ========================= 5. Associate with known targets =========================
    # NOTE(review): nothing stops two boxes of the same frame from matching
    # the same target; the later one simply overwrites the earlier update.
    current_target_ids = set()
    for det_info in tracked_detections.values():
        det_box = det_info['box']
        det_conf = det_info['conf']
        cx = (det_box[0] + det_box[2]) / 2
        cy = (det_box[1] + det_box[3]) / 2
        is_cx_in_roi, current_roi = self.is_cxcy_in_roi(cx, cy)

        matched_target_id, match_score = self.match_detection_to_target(det_box, det_conf)
        if matched_target_id is not None and match_score > 0.3:
            # Update an existing target.
            target_id = matched_target_id
            target_info = self.active_targets[target_id]
            self.position_history.setdefault(
                target_id, deque(maxlen=10)).append(np.array([cx, cy]))
            # Cache the clean frame at the moment the target enters an ROI.
            # NOTE(review): entry_frame_cache is written but never read in this file.
            if is_cx_in_roi and not target_info.get('in_roi', False):
                self.entry_frame_cache[target_id] = frame_copy.copy()
            target_info.update({
                'last_box': det_box,
                'last_cxcy': (cx, cy),
                'last_conf': det_conf,
                'last_update_time': current_time,
                'in_roi': is_cx_in_roi,
                # Outside every ROI: keep the last ROI name that was recorded.
                'current_roi': current_roi if is_cx_in_roi else target_info.get('current_roi', 'outside'),
                # The target is visible again, so the miss counter restarts;
                # otherwise misses from separate gaps would add up and
                # trigger an alert early.
                'lost_frames': 0,
                'detection_source': det_info['source']
            })
        else:
            # Create a new target.
            target_id = self.next_target_id
            self.next_target_id += 1
            self.active_targets[target_id] = {
                'first_seen': current_time,
                'last_box': det_box,
                'last_cxcy': (cx, cy),
                'last_conf': det_conf,
                'last_update_time': current_time,
                'in_roi': is_cx_in_roi,
                'current_roi': current_roi if is_cx_in_roi else 'outside',
                'lost_frames': 0,
                'detection_source': det_info['source']
            }
            self.position_history[target_id] = deque(maxlen=10)
            self.position_history[target_id].append(np.array([cx, cy]))
            if is_cx_in_roi:
                self.entry_frame_cache[target_id] = frame_copy.copy()
        current_target_ids.add(target_id)

    # ========================= 6. Disappearances and alerts =========================
    for target_id in list(self.active_targets):
        if target_id in current_target_ids:
            continue  # observed this frame; its miss counter is already 0
        target_info = self.active_targets[target_id]
        if not target_info['in_roi']:
            # Vanished outside every ROI: forget it straight away.
            self._drop_target(target_id)
            continue
        # Vanished while its centre was inside an ROI.
        target_info['lost_frames'] += 1
        if target_info['lost_frames'] < ROI_LOST_FRAMES_THRESH:
            continue
        # Alert image: buffered frame from detect_rollback_time seconds earlier.
        alert_frame = self.find_target_frame(timestamp - self.detect_rollback_time)
        current_frame_alerts.append({
            "time": timestamp,
            "camera_id": camera_id,
            "action": "Indoor Violation",
            'image': alert_frame,
            "prisoner_track_id": target_id,
            "disappear_roi": target_info['current_roi'],
            "last_cx": round(target_info['last_cxcy'][0], 2),
            "last_cy": round(target_info['last_cxcy'][1], 2)
        })
        self._drop_target(target_id)

    # ========================= 7. Drop stale targets =========================
    timeout_threshold = 5.0  # seconds without an update
    for target_id in list(self.active_targets):
        if current_time - self.active_targets[target_id]['last_update_time'] > timeout_threshold:
            self._drop_target(target_id)

    # ========================= 8. Visualisation =========================
    for target_id, target_info in self.active_targets.items():
        box = target_info['last_box']
        cx, cy = target_info['last_cxcy']
        in_roi = target_info['in_roi']
        current_roi = target_info['current_roi']
        source = target_info.get('detection_source', 'unknown')
        # Colour by state: (0, 0, 255) inside an ROI, (0, 255, 0) outside.
        color = (0, 0, 255) if in_roi else (0, 255, 0)
        # Thicker outline when a tracker box was backed by a detection.
        thickness = 3 if source == 'tracked' else 2
        cv2.rectangle(frame, (int(box[0]), int(box[1])),
                      (int(box[2]), int(box[3])), color, thickness)
        cv2.circle(frame, (int(cx), int(cy)), 5, color, -1)
        status = f"T{target_id}_{current_roi[:2]}"
        if source == 'det_only':
            status += "_DET"
        cv2.putText(frame, status, (int(box[0]), int(box[1]) - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    # ========================= 9. Statistics overlay =========================
    cv2.putText(frame, f"Camera: {camera_id}", (20, self.frame_height - 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
    cv2.putText(frame, f"Active Targets: {len(self.active_targets)}",
                (20, self.frame_height - 50),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 2)
    self.append_frame(frame, timestamp)
    return {"image": frame, "alerts": current_frame_alerts}

def _drop_target(self, target_id):
    """Forget a target: its state, its position history and its cached entry frame."""
    self.active_targets.pop(target_id, None)
    self.position_history.pop(target_id, None)
    self.entry_frame_cache.pop(target_id, None)
# ========================= 帧处理线程 =========================
class FrameProcessorWorker(BaseFrameProcessorWorker):
    """Frame-processing worker for the detention-centre corridor scene (enhanced-tracking variant).

    Only class-level settings are declared here; how they are consumed is
    decided by BaseFrameProcessorWorker, which is defined outside this file.
    """

    # Processing rate: the same constant the detector uses for its tracker
    # and its frame buffer.
    TARGET_FPS = RTSP_FPS
    # NOTE(review): the meaning of post type 3 is defined by the base class / backend - confirm.
    POST_TYPE = 3
    # Builds the detector for this worker from the configured params.
    DETECTOR_FACTORY = lambda params: PrisonerDoorDetector(params)