Files
SupervisorAI/biz/prison/indoor_biz.py
2026-04-01 17:16:41 +08:00

244 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import time
import requests
from biz.base_frame_processor import BaseFrameProcessorWorker
from algorithm.common.npu_yolo_onnx_person_car_phone import YOLOv8_ONNX
from yolox.tracker.byte_tracker import BYTETracker
from common.constants import MODEL_ROOT_PATH
# ========================= 走廊场景专属配置 =========================
DETECT_MODEL_PATH = 'YOLO_Weight/kanshousuo.onnx' # 犯人检测onnx模型路径
INPUT_SIZE = 640 # 模型输入尺寸
RTSP_FPS = 10 # 视频流目标FPS
ALERT_PUSH_INTERVAL = 5 # 相同报警5秒内仅推送1次
ALERT_PUSH_URL = "http://123.57.151.210:10000/picenter/websocket/test/process"
# 消失判定中心点在ROI内消失后持续无检测的帧数1.0秒,可微调)
ROI_LOST_FRAMES_THRESH = int(0.5 * RTSP_FPS)
# ========================= 默认ROI区域配置当config.yaml未配置时使用 =========================
DEFAULT_DOOR_ROIS = {
"left_door_1": {
"points": [[0.195, 0.242], [0.265, 0.17], [0.3, 0.63], [0.248, 0.8]],
"color": [255, 0, 0]
}
}
# ==================================================================================
class PrisonerDoorDetector:
def __init__(self, params=None):
self.params = params or {}
# 0. 从params解析ROI配置无则使用默认值
door_rois_config = self.params.get('door_rois', DEFAULT_DOOR_ROIS)
self.roi_config = {}
self.roi_colors = {}
for door_name, door_cfg in door_rois_config.items():
self.roi_config[door_name] = door_cfg['points']
self.roi_colors[door_name] = tuple(door_cfg['color'])
model_path = self.params.get('model_path')
if model_path:
full_model_path = f"{MODEL_ROOT_PATH}/{model_path}"
else:
full_model_path = DETECT_MODEL_PATH
self.detector = YOLOv8_ONNX(
full_model_path,
conf_threshold=0.5, # 置信度阈值,可根据模型精度调整
iou_threshold=0.45, # IOU阈值
input_size=INPUT_SIZE
)
# 2. 初始化ByteTracker跟踪器适配走廊单/多犯人跟踪)
class TrackerArgs:
track_thresh = 0.25
track_buffer = 20 # 减小缓冲避免跟踪漂移
match_thresh = 0.75
mot20 = False
self.tracker = BYTETracker(TrackerArgs(), frame_rate=RTSP_FPS)
# 3. 状态变量初始化
self.last_alert_time = 0.0 # 最后报警时间(防重复推送)
# 犯人跟踪信息:{track_id: {'is_cx_in_roi': 中心点是否在ROI, 'lost_frames': 消失帧数, 'lost_roi': 消失的ROI名称, 'last_cxcy': 最后中心点坐标}}
self.prisoner_track_info = {}
self.frame_width = 0 # 帧宽度(动态获取)
self.frame_height = 0 # 帧高度(动态获取)
self.roi_abs_cache = {} # ROI绝对坐标缓存{roi_name: np.int32数组}
def compute_iou(self, boxA, boxB):
"""IOU计算匹配跟踪框与犯人检测框过滤非犯人目标"""
xA = max(boxA[0], boxB[0])
yA = max(boxA[1], boxB[1])
xB = min(boxA[2], boxB[2])
yB = min(boxA[3], boxB[3])
interW = max(0, xB - xA)
interH = max(0, yB - yA)
interArea = interW * interH
boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
unionArea = boxAArea + boxBArea - interArea
return interArea / unionArea if unionArea > 0 else 0.0
def _get_roi_abs(self, roi_name):
"""相对坐标转绝对像素坐标适配当前帧分辨率OpenCV要求int32"""
if roi_name not in self.roi_config:
return None
roi_rel = np.array(self.roi_config[roi_name], dtype=np.float64)
roi_abs = roi_rel * np.array([self.frame_width, self.frame_height])
return roi_abs.astype(np.int32)
def is_cxcy_in_roi(self, cx, cy):
"""判断犯人框**中心点(cx,cy)** 是否在任意ROI内返回(是否在ROI, 所在ROI名称)"""
for roi_name, roi_abs in self.roi_abs_cache.items():
# OpenCV点在多边形内判定>=0 表示在内部/边上
if cv2.pointPolygonTest(roi_abs, (cx, cy), False) >= 0:
return (True, roi_name)
return (False, "outside")
# def push_alert(self, camera_id, track_id, lost_roi, last_cxcy, timestamp):
# """报警推送带频率限制携带消失ROI、最后中心点坐标"""
# current_time = time.time()
# if current_time - self.last_alert_time < ALERT_PUSH_INTERVAL:
# return False
# # 构造报警信息(可根据平台要求扩展字段)
# alert_info = {
# "camera_id": camera_id,
# "alert_type": "prisoner_cx_disappear_in_roi",
# "prisoner_track_id": track_id,
# "disappear_roi": lost_roi,
# "last_cx": round(last_cxcy[0], 2),
# "last_cy": round(last_cxcy[1], 2),
# "timestamp": timestamp,
# "details": f"犯人框中心点在{lost_roi}区域内消失,触发报警"
# }
# # 推送报警请求
# try:
# requests.post(ALERT_PUSH_URL, json=alert_info, timeout=3)
# print(f"[报警成功] {alert_info}")
# self.last_alert_time = current_time
# return True
# except Exception as e:
# print(f"[报警失败] 原因:{str(e)}")
# return False
def process_frame(self, frame, camera_id: int, timestamp: float) -> dict:
"""
核心帧处理:
1. 绘制5个ROI区域 2. 检测+跟踪犯人 3. 判定中心点是否在ROI内
4. 中心点在ROI内消失则累计帧数达到阈值触发报警
"""
self.frame_height, self.frame_width = frame.shape[:2]
current_frame_alerts = [] # 本帧报警信息
# ========================= 1. 初始化ROI绝对坐标并绘制ROI =========================
self.roi_abs_cache.clear()
for roi_name in self.roi_config:
roi_abs = self._get_roi_abs(roi_name)
if roi_abs is None:
continue
self.roi_abs_cache[roi_name] = roi_abs
# 绘制ROI多边形闭合+ ROI名称标签
roi_draw = roi_abs.reshape((-1, 1, 2)) # OpenCV绘制要求形状 (n,1,2)
color = self.roi_colors.get(roi_name, (255, 255, 255))
cv2.polylines(frame, [roi_draw], isClosed=True, color=color, thickness=2)
cv2.putText(frame, roi_name, (roi_abs[0][0], roi_abs[0][1] - 5),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
# ========================= 2. 模型推理:仅提取犯人检测框 =========================
detect_results = self.detector(frame)
prisoner_dets_xyxy = [] # 仅存犯人检测框 [x1,y1,x2,y2]
dets_for_tracker = [] # 跟踪器输入 [x1,y1,x2,y2,conf]
if detect_results:
for det in detect_results:
x1, y1, x2, y2, conf, cls_id = det
dets_for_tracker.append([x1, y1, x2, y2, conf])
# 替换为你模型中「犯人」的实际类别ID此处默认cls_id=1
if cls_id == 1:
prisoner_dets_xyxy.append([x1, y1, x2, y2])
# ========================= 3. 目标跟踪:更新犯人跟踪结果 =========================
dets_np = np.array(dets_for_tracker, dtype=np.float32) if dets_for_tracker else np.empty((0, 5))
track_results = self.tracker.update(dets_np, [self.frame_height, self.frame_width],
[self.frame_height, self.frame_width])
# ========================= 4. 遍历跟踪结果判定犯人中心点是否在ROI =========================
current_prisoner_tids = set() # 本帧存在的犯人track_id
for track in track_results:
track_id = track.track_id
track_box = list(map(float, track.tlbr)) # 跟踪框 [x1,y1,x2,y2]
# IOU匹配过滤非犯人目标仅保留真正的犯人
is_prisoner = False
for p_box in prisoner_dets_xyxy:
if self.compute_iou(track_box, p_box) > 0.3:
is_prisoner = True
break
if not is_prisoner:
continue
# 计算犯人框**中心点坐标**(核心判定依据)
cx = (track_box[0] + track_box[2]) / 2
cy = (track_box[1] + track_box[3]) / 2
# 判定中心点是否在ROI内返回(是否在ROI, 所在ROI名称)
is_cx_in_roi, current_roi = self.is_cxcy_in_roi(cx, cy)
# 更新犯人跟踪信息记录中心点状态、所在ROI、最后坐标重置消失帧数
self.prisoner_track_info[track_id] = {
"is_cx_in_roi": is_cx_in_roi,
"lost_frames": 0,
"lost_roi": current_roi,
"last_cxcy": (cx, cy)
}
current_prisoner_tids.add(track_id)
# 绘制犯人框+中心点+状态标签(可视化调试)
x1, y1, x2, y2 = map(int, track_box)
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 2) # 红色犯人框
cv2.circle(frame, (int(cx), int(cy)), 5, (0, 255, 255), -1) # 黄色中心点
cv2.putText(frame, f"Prisoner_{track_id}({current_roi})", (x1, y1 - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
# ========================= 5. 核心判定中心点在ROI内消失则报警 =========================
for track_id in list(self.prisoner_track_info.keys()):
if track_id not in current_prisoner_tids:
# 犯人本帧消失,获取其最后状态
track_info = self.prisoner_track_info[track_id]
# 仅处理「**中心点原本在ROI内**」的消失情况
if track_info["is_cx_in_roi"]:
track_info["lost_frames"] += 1 # 累计消失帧数
# 消失帧数达到阈值,触发报警
if track_info["lost_frames"] >= ROI_LOST_FRAMES_THRESH:
# self.push_alert(
# camera_id=camera_id,
# track_id=track_id,
# lost_roi=track_info["lost_roi"],
# last_cxcy=track_info["last_cxcy"],
# timestamp=timestamp
# )
# 记录本帧报警信息
current_frame_alerts.append({
"time": timestamp,
"camera_id": camera_id,
"action": "Indoor Violation",
"prisoner_track_id": track_id,
"disappear_roi": track_info["lost_roi"],
"last_cx": round(track_info["last_cxcy"][0], 2),
"last_cy": round(track_info["last_cxcy"][1], 2)
})
del self.prisoner_track_info[track_id] # 报警后清除状态,避免重复触发
else:
del self.prisoner_track_info[track_id] # 中心点不在ROI的消失直接清除
# ========================= 6. 绘制辅助信息摄像头ID、在押犯人数 =========================
cv2.putText(frame, f"Camera: {camera_id}", (20, self.frame_height - 20),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
cv2.putText(frame, f"Prisoners: {len(current_prisoner_tids)}", (20, self.frame_height - 50),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 2)
return {"image": frame, "alerts": current_frame_alerts}
# ========================= 帧处理线程(对接原有框架,直接复用) =========================
class FrameProcessorWorker(BaseFrameProcessorWorker):
"""看守所走廊犯人检测 - 5ROI+中心点消失判定"""
DETECTOR_FACTORY = lambda params: PrisonerDoorDetector(params)
POST_TYPE = 3 # 与原有业务区分,自定义即可
TARGET_FPS = RTSP_FPS