Files
SupervisorAI/biz/prison/supervision_room_biz.py

229 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# rtsp_service_kadian.py
# Fusion of Kadian_Detect_1221.py + rtsp_service_ws.py
# Supports multi-stream RTSP, frame sampling, segmented MP4 saving, and WebSocket push of images and alerts
import cv2
import numpy as np
import os
import threading
import queue
import yaml
import json
from typing import Dict, Any, Tuple, List
from openxlab.model.commands import Model
from biz.base_frame_processor import BaseFrameProcessorWorker
# -------------------------- Kadian detection imports --------------------------
from algorithm.common.npu_yolo_onnx_person_car_phone import YOLOv8_ONNX # 主检测模型(人/车/后备箱/手机)
from yolox.tracker.byte_tracker import BYTETracker
from rtsp_service_ws_prison import TrackerArgs
from utils.logger import get_logger
logger = get_logger(__name__)
# ========================= Configuration =========================
# RTSP service settings.
RTSP_TARGET_FPS = 10.0

# Alert push rate limit (seconds): the same action is pushed at most once
# per this interval.
ALERT_PUSH_INTERVAL = 5.0

Model_Path = 'YOLO_Weight/zhihuishi.onnx'
Model_size = 640  # YOLO model input size

# Class id -> human-readable label.
Label_Map = {
    -1: 'Unknown',
    0: 'Police',
}

# Class id -> BGR drawing colour.
Color_Map = {
    -1: (255, 255, 255),  # white
    0: (0, 255, 0),       # green
}

# Number of consecutive person-free frames required before a 'Nobody' alert
# fires; the counter resets to zero as soon as a person is detected again.
NOBODY_THRESHOLD = 5.0 * RTSP_TARGET_FPS
class ZhihuishiDetector:
    """Supervision-room ("zhihuishi") detector.

    Runs a YOLOv8 ONNX model on each frame, tracks the detections across
    frames with ByteTrack, and raises a 'Nobody Checking' alert once no
    person of class 0 ('Police') has been seen for NOBODY_THRESHOLD
    consecutive frames.
    """

    def __init__(self, params=None):
        """Create the detector and tracker.

        Args:
            params: Optional configuration (currently unused; kept for the
                DETECTOR_FACTORY interface).
        """
        # YOLOv8 ONNX detector (person / car / trunk / phone weights).
        self.detector = YOLOv8_ONNX(Model_Path, conf_threshold=0.6,
                                    iou_threshold=0.6, input_size=Model_size)

        # ByteTrack hyper-parameters. Named with a leading underscore so it
        # does not shadow the TrackerArgs imported from rtsp_service_ws_prison.
        class _ByteTrackArgs:
            track_thresh = 0.61
            # A target unseen for 3 seconds is considered gone and is dropped
            # from the role cache.
            track_buffer = RTSP_TARGET_FPS * 3
            match_thresh = 0.8
            mot20 = False

        self.fps = RTSP_TARGET_FPS
        # Number of frames processed so far.
        self.current_frame_idx = 0
        # ByteTrack multi-object tracker.
        self.ByteTracker = BYTETracker(_ByteTrackArgs(), frame_rate=self.fps)
        # track_id -> class id; entries are removed once the track disappears.
        self.track_role = {}
        # Consecutive frames without any 'Police' detection.
        self.nobody_frames = 0
        # Use the module logger (lazy %-args) instead of print.
        logger.info("Hyper-parameters: FPS=%.2f, 'Nobody' threshold=%s frames",
                    self.fps, NOBODY_THRESHOLD)

    def compute_iou(self, boxA, boxB):
        """Return the IoU of two [x1, y1, x2, y2] boxes (0.0 when the union
        area is zero, e.g. both boxes are degenerate)."""
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        yB = min(boxA[3], boxB[3])
        interW = max(0, xB - xA)
        interH = max(0, yB - yA)
        interArea = interW * interH
        boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
        boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
        unionArea = boxAArea + boxBArea - interArea
        if unionArea == 0:
            return 0.0
        return interArea / unionArea

    def draw_alert(self, frame, text, color=(0, 0, 255), offset_y=0):
        """Draw alert text in the top-right corner on a black backing box.

        `offset_y` shifts the text vertically so stacked alerts don't
        overlap. NOTE: relies on self.width being set by process_frame,
        so it must only be called from within a frame-processing pass.
        """
        font_scale = 1.5
        thickness = 3
        font = cv2.FONT_HERSHEY_SIMPLEX
        (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
        x = self.width - text_w - 20
        y = 50 + text_h + offset_y
        cv2.rectangle(frame, (x - 10, y - text_h - 10),
                      (x + text_w + 10, y + 10), (0, 0, 0), -1)
        cv2.putText(frame, text, (x, y), font, font_scale, color, thickness)

    def process_frame(self, frame, camera_id: int, timestamp: float) -> Dict[str, Any]:
        """Detect, track, annotate one frame and return any alerts.

        Args:
            frame: BGR image (numpy array); boxes/alerts are drawn in place.
            camera_id: Identifier of the source camera (currently unused;
                kept for the worker interface).
            timestamp: Frame timestamp in seconds, recorded on alerts.

        Returns:
            {"image": annotated frame, "alerts": list of
             {"time": float, "action": str} raised on this frame}.
        """
        # =================== Collect detection results ===================
        h, w = frame.shape[:2]
        self.width, self.height = w, h
        self.current_frame_idx += 1
        current_time_sec = timestamp

        detect_results = self.detector(frame)
        detect_xyxy = []       # [x1, y1, x2, y2] per detection
        detect_roles = []      # class id per detection
        detect_bytetrack = []  # [x1, y1, x2, y2, conf] rows fed to ByteTrack
        # Per-frame detection count for every known label id
        # (avoid shadowing the builtin `id`).
        current_labels_count = {label_id: 0 for label_id in Label_Map}
        # Alerts raised on this frame only; rebuilt from scratch each frame.
        current_frame_alerts = []

        if detect_results:
            for result in detect_results:  # row: x1, y1, x2, y2, conf, cls_id
                detect_xyxy.append(result[:-2])
                detect_roles.append(result[-1])
                detect_bytetrack.append(result[:-1])

        # Update the tracker with this frame's detections.
        tracks = self.ByteTracker.update(
            np.array(detect_bytetrack, dtype=np.float32) if len(detect_bytetrack) else np.empty((0, 5)),
            [self.height, self.width],
            [self.height, self.width],
        )

        # Assign a class to each track by best-IoU matching against the raw
        # detections. ByteTrack keeps a stable track_id across frames, so we
        # only re-identify periodically and for brand-new tracks (per-frame
        # YOLO labels can flicker; the cached per-track role is more stable).
        reidentify_frame_interval = 10  # re-match classes every N frames
        current_track_ids = []
        for track in tracks:
            track_id = track.track_id
            current_track_ids.append(track_id)
            # BUG FIX: the original tested `current_time_sec % 10 == 0` on a
            # float timestamp, which essentially never holds; the interval is
            # meant in frames, so gate on the frame counter instead.
            if (self.current_frame_idx % reidentify_frame_interval == 0) or track_id not in self.track_role:
                best_iou = 0.0
                best_role = -1
                track_box = list(map(float, track.tlbr))
                for i, box in enumerate(detect_xyxy):
                    iou = self.compute_iou(track_box, box)
                    if iou > best_iou:
                        best_iou = iou
                        # Normalize possibly-float class ids to int keys.
                        best_role = int(detect_roles[i])
                self.track_role[track_id] = best_role
            role = self.track_role[track_id]
            # Tolerate class ids outside Label_Map instead of raising KeyError.
            current_labels_count[role] = current_labels_count.get(role, 0) + 1
            # Draw the box only for known, non-'Unknown' roles.
            if role != -1 and role in Label_Map:
                x1, y1, x2, y2 = map(int, track.tlbr)
                cv2.rectangle(frame, (x1, y1), (x2, y2), Color_Map[role], 2)
                cv2.putText(frame, Label_Map[role], (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, Color_Map[role], 2)

        # Drop cached roles for tracks that are no longer alive. Copy the
        # keys first: a dict cannot be mutated while it is being iterated.
        for stale_track_id in list(self.track_role.keys()):
            if stale_track_id not in current_track_ids:
                del self.track_role[stale_track_id]

        # ==================== Business-rule checks =======================
        # 'Nobody' rule: count consecutive frames without a 'Police' (class 0)
        # detection; fire once the run reaches NOBODY_THRESHOLD, and reset
        # the counter the moment a person appears again.
        nobody_alert_flag = False
        if current_labels_count[0] == 0:
            self.nobody_frames += 1
            if self.nobody_frames >= NOBODY_THRESHOLD:
                nobody_alert_flag = True
        else:
            self.nobody_frames = 0

        if nobody_alert_flag:
            action_text = 'Nobody Checking'
            current_frame_alerts.append({
                'time': current_time_sec,
                'action': action_text,
            })
            self.draw_alert(frame, action_text, offset_y=0)

        return {
            "image": frame,
            "alerts": current_frame_alerts,
        }
# ========================= Frame-processing worker =========================
class FrameProcessorWorker(BaseFrameProcessorWorker):
    """Supervision-room frame-processing worker thread."""

    # Subclass configuration consumed by BaseFrameProcessorWorker.
    # Wrapped in staticmethod so that, if the base class looks the factory up
    # through an instance, Python does not bind it as a method and pass the
    # instance itself as `params`. (NOTE(review): the base class's access
    # pattern is not visible here — staticmethod is safe for both class-level
    # and instance-level access.)
    DETECTOR_FACTORY = staticmethod(lambda params: ZhihuishiDetector(params))
    POST_TYPE = 2
    TARGET_FPS = RTSP_TARGET_FPS