229 lines
8.3 KiB
Python
229 lines
8.3 KiB
Python
# rtsp_service_kadian.py
|
||
# 融合 Kadian_Detect_1221.py + rtsp_service_ws.py
|
||
# 支持多路RTSP、抽帧、分段保存MP4、WebSocket推送图像与告警
|
||
|
||
import cv2
|
||
import numpy as np
|
||
import os
|
||
import threading
|
||
import queue
|
||
import yaml
|
||
import json
|
||
|
||
from typing import Dict, Any, Tuple, List
|
||
|
||
from openxlab.model.commands import Model
|
||
|
||
from biz.base_frame_processor import BaseFrameProcessorWorker
|
||
|
||
# -------------------------- Kadian 检测相关导入 --------------------------
|
||
from algorithm.common.npu_yolo_onnx_person_car_phone import YOLOv8_ONNX # 主检测模型(人/车/后备箱/手机)
|
||
|
||
from yolox.tracker.byte_tracker import BYTETracker
|
||
|
||
from rtsp_service_ws_prison import TrackerArgs
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
# ========================= Configuration =========================
# RTSP service settings
RTSP_TARGET_FPS = 10.0

# Alert push rate limit (seconds): the same action is pushed at most once
# within this interval.
ALERT_PUSH_INTERVAL = 5.0  # same action pushed at most once per 5 s

# Path to the ONNX weights of the main YOLO detector.
Model_Path = 'YOLO_Weight/zhihuishi.onnx'

Model_size = 640  # YOLO model input size (square, pixels)

# Label id -> human-readable class name (-1 marks an unmatched/unknown track).
Label_Map = {
    -1: 'Unknown',
    0: 'Police'
}

# Label id -> BGR draw color.
Color_Map = {
    -1: (255, 255, 255),  # white
    0: (0, 255, 0)        # green
}

# Consecutive empty frames before declaring "Nobody": counting starts once no
# person is on screen; the counter is reset immediately when a person is
# detected again.
NOBODY_THRESHOLD = 5.0 * RTSP_TARGET_FPS
||
|
||
|
||
class ZhihuishiDetector:
    """Per-camera detector combining YOLO person detection, ByteTrack
    multi-object tracking, and the "Nobody" (room-unattended) business rule.

    ``process_frame`` is the entry point: it annotates the frame in place
    (boxes, labels, alert banners) and returns the annotated frame plus any
    alerts raised for that frame.
    """

    def __init__(self, params=None):
        """Build the detector, tracker and per-stream state.

        Args:
            params: optional configuration dict (currently unused; kept for
                interface compatibility with the worker factory).
        """
        # Person detection model (ONNX YOLOv8).
        self.detector = YOLOv8_ONNX(Model_Path, conf_threshold=0.6, iou_threshold=0.6, input_size=Model_size)

        # ByteTracker hyper-parameters. NOTE(review): this local class shadows
        # the TrackerArgs imported from rtsp_service_ws_prison — presumably on
        # purpose, so thresholds can differ per service; confirm.
        class TrackerArgs:
            track_thresh = 0.61
            track_buffer = RTSP_TARGET_FPS * 3  # a target unseen for 3 s is considered gone and dropped
            match_thresh = 0.8
            mot20 = False

        self.fps = RTSP_TARGET_FPS

        # Index of the frame currently being processed (monotonically increasing).
        self.current_frame_idx = 0

        # ByteTrack tracker instance.
        self.ByteTracker = BYTETracker(TrackerArgs(), frame_rate=self.fps)

        # track_id -> role (label id) for every currently-alive track; an entry
        # is removed once ByteTrack stops reporting that track.
        self.track_role = {}

        # Count of consecutive frames with no person on screen.
        self.nobody_frames = 0

        print(f"\n超参数设置:")
        print(f" FPS: {self.fps:.2f}")
        print(f" 判定 'Nobody' 需连续: {NOBODY_THRESHOLD} 帧")

    def compute_iou(self, boxA, boxB):
        """Return the intersection-over-union of two corner boxes.

        Args:
            boxA, boxB: [x1, y1, x2, y2] with (x1, y1) top-left and
                (x2, y2) bottom-right.

        Returns:
            IoU in [0, 1]; 0.0 when the union area is zero (degenerate boxes).
        """
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        yB = min(boxA[3], boxB[3])

        # Clamp to zero: disjoint boxes would otherwise yield negative extents.
        interW = max(0, xB - xA)
        interH = max(0, yB - yA)
        interArea = interW * interH

        boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
        boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])

        unionArea = boxAArea + boxBArea - interArea
        if unionArea == 0:
            return 0.0

        return interArea / unionArea

    def draw_alert(self, frame, text, color=(0, 0, 255), offset_y=0):
        """Draw an alert banner in the top-right corner of the frame.

        Args:
            frame: BGR image, modified in place.
            text: banner text.
            color: BGR text color (default red).
            offset_y: vertical offset so stacked banners do not overlap.

        Requires ``self.width`` to be set (done by ``process_frame``).
        """
        font_scale = 1.5
        thickness = 3
        font = cv2.FONT_HERSHEY_SIMPLEX

        (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
        x = self.width - text_w - 20
        y = 50 + text_h + offset_y  # shift down for each additional banner

        # Black background box behind the text for readability.
        cv2.rectangle(frame, (x - 10, y - text_h - 10), (x + text_w + 10, y + 10), (0, 0, 0), -1)
        cv2.putText(frame, text, (x, y), font, font_scale, color, thickness)

    def process_frame(self, frame, camera_id: int, timestamp: float) -> Dict[str, Any]:
        """Run detection + tracking on one frame and evaluate business rules.

        Args:
            frame: BGR image (numpy array); annotated in place.
            camera_id: id of the source camera (currently unused here).
            timestamp: capture timestamp in seconds.

        Returns:
            {"image": annotated frame, "alerts": list of per-frame alert dicts}.
        """
        # ============================ collect detections ============================
        h, w = frame.shape[:2]
        self.width, self.height = w, h
        self.current_frame_idx += 1
        current_time_sec = timestamp  # capture timestamp of this frame

        detect_results = self.detector(frame)

        detect_xyxy = []       # [x1, y1, x2, y2] corner boxes from YOLO
        detect_roles = []      # label id of each detection
        detect_bytetrack = []  # [x1, y1, x2, y2, conf] rows fed to ByteTrack

        # Per-label detection count for the current frame (label id -> count).
        # (renamed comprehension variable: the original shadowed builtin `id`)
        current_labels_count = {label_id: 0 for label_id in Label_Map}

        # Alerts raised for this frame only (cleared every frame).
        current_frame_alerts = []

        if detect_results:
            for result in detect_results:  # each row: x1, y1, x2, y2, conf, cls_id
                detect_xyxy.append(result[:-2])
                detect_roles.append(result[-1])
                detect_bytetrack.append(result[:-1])

        # Update the tracker with this frame's detections.
        tracks = self.ByteTracker.update(
            # np.empty((0, 5)) is a 0-row, 5-column array for the no-detection case.
            np.array(detect_bytetrack, dtype=np.float32) if len(detect_bytetrack) else np.empty((0, 5)),
            [self.height, self.width],
            [self.height, self.width]
        )

        # Assign a role (label) to each track by best-IoU matching against the
        # raw detections. Tracks (not raw YOLO boxes) are used for counting and
        # drawing because ByteTrack provides stable cross-frame track ids,
        # whereas YOLO results are single-frame only.
        current_track_ids = []
        reidentify_frame_interval = 10  # re-match a track's role every N frames
        for track in tracks:
            track_id = track.track_id
            current_track_ids.append(track_id)

            # BUGFIX: the interval is in FRAMES (per the variable name and the
            # original comment), but the original code tested the float
            # timestamp (`current_time_sec % 10 == 0`), which almost never
            # fires — so roles were effectively never refreshed. Use the frame
            # counter instead.
            if (self.current_frame_idx % reidentify_frame_interval == 0) or track_id not in self.track_role:
                best_iou = 0.0
                best_role = -1

                track_box = list(map(float, track.tlbr))

                for i, box in enumerate(detect_xyxy):
                    iou = self.compute_iou(track_box, box)
                    if iou > best_iou:
                        best_iou = iou
                        best_role = detect_roles[i]

                self.track_role[track_id] = best_role

            role = self.track_role[track_id]

            current_labels_count[role] += 1

            # Draw the box unless the role is Unknown (-1).
            if role != -1:
                x1, y1, x2, y2 = map(int, track.tlbr)
                cv2.rectangle(frame, (x1, y1), (x2, y2), Color_Map[role], 2)
                cv2.putText(frame, Label_Map[role], (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, Color_Map[role], 2)

        # Drop cached roles for tracks ByteTrack no longer reports.
        # list() copy: deleting from a dict while iterating it is unsafe.
        # (renamed loop variable: these are track ids, not roles)
        for stale_id in list(self.track_role.keys()):
            if stale_id not in current_track_ids:
                del self.track_role[stale_id]

        # ============================ business rules ============================
        nobody_alter_flag = False  # whether the "nobody present" rule fired

        # "Nobody" rule: count consecutive frames with zero label-0 (Police)
        # tracks; reset the counter as soon as someone reappears.
        if current_labels_count[0] == 0:
            self.nobody_frames += 1

            if self.nobody_frames >= NOBODY_THRESHOLD:
                nobody_alter_flag = True
        else:
            self.nobody_frames = 0

        if nobody_alter_flag:
            action_text = 'Nobody Checking'
            current_frame_alerts.append(
                {
                    'time': current_time_sec,
                    'action': action_text,
                }
            )
            self.draw_alert(frame, action_text, offset_y=0)

        return {
            "image": frame,
            "alerts": current_frame_alerts
        }
|
||
|
||
|
||
# ========================= 帧处理线程 =========================
|
||
class FrameProcessorWorker(BaseFrameProcessorWorker):
    """Frame-processing worker thread for the monitoring-room detector."""

    # Subclass configuration consumed by BaseFrameProcessorWorker.
    # NOTE(review): as a plain class attribute, this lambda behaves like an
    # instance method when accessed through an instance — self.DETECTOR_FACTORY(p)
    # would pass `self` as `params`. Presumably the base class invokes it via
    # the class object (cls.DETECTOR_FACTORY(params)); confirm against
    # BaseFrameProcessorWorker before relying on instance access.
    DETECTOR_FACTORY = lambda params: ZhihuishiDetector(params)
    # Post-processing/push type id expected by the base worker — TODO confirm
    # the meaning of value 2 against BaseFrameProcessorWorker.
    POST_TYPE = 2
    # Target processing rate, shared with the module-level RTSP config.
    TARGET_FPS = RTSP_TARGET_FPS