Compare commits

...

2 Commits

View File

@@ -12,117 +12,75 @@ import json
from typing import Dict, Any, Tuple, List
from openxlab.model.commands import Model
from biz.base_frame_processor import BaseFrameProcessorWorker
# -------------------------- Kadian 检测相关导入 --------------------------
from algorithm.common.npu_yolo_onnx_person_car_phone import YOLOv8_ONNX # 主检测模型(人/车/后备箱/手机)
from algorithm.common.npu_yolo_onnx_person_car_phone import YOLOv8_ONNX # 主检测模型(人/车/后备箱/手机)
from yolox.tracker.byte_tracker import BYTETracker
from rtsp_service_ws_prison import TrackerArgs
from utils.logger import get_logger
logger = get_logger(__name__)
# ========================= 配置区 =========================
Person_Phone_Model = r'YOLO_Weight/person_phone_model.onnx' # 人和手机的检测模型
Smoke_Model = r'YOLO_Weight/smoke_model.onnx' # 抽烟检测模型
person_phone_input_size = 1280 # 模型输入尺寸,与训练时的模型一致
smoke_input_size = 1280 # 模型输入尺寸,与训练时的模型一致
# RTSP 服务配置
RTSP_TARGET_FPS = 5.0
RTSP_TARGET_FPS = 10.0
# 新增:告警推送频率限制(秒)
ALERT_PUSH_INTERVAL = 5.0 # 相同action 5秒内仅推送一次
Model_Path = 'YOLO_Weight/zhihuishi.onnx'
Model_size = 640 # yolo模型尺寸
Label_Map = {
-1: 'Unknown',
0: 'Police'
}
Color_Map = {
-1: (255, 255, 255), # 白
0: (0, 255, 0) # 绿
}
NOBODY_THRESHOLD = 5.0 * RTSP_TARGET_FPS # 当屏幕中的人消失了开始计数如果累计够时长判定Nobody如果中间又检测到了人则立即清空计数
class ZhihuishiDetector:
def __init__(self, params=None):
# 模型加载
# 人和手机检测模型
print(f"加载人和手机检测模型: {Person_Phone_Model}")
self.person_phone_detector = YOLOv8_ONNX(Person_Phone_Model, conf_threshold=0.6, iou_threshold=0.45,
input_size=person_phone_input_size)
# 抽烟检测模型
print(f"加载抽烟检测模型: {Smoke_Model}")
self.smoke_detector = YOLOv8_ONNX(Smoke_Model, conf_threshold=0.4, iou_threshold=0.65,
input_size=smoke_input_size)
# person检测模型
self.detector = YOLOv8_ONNX(Model_Path, conf_threshold=0.6, iou_threshold=0.6, input_size=Model_size)
# ByteTracker
class TrackerArgs:
track_thresh = 0.25
track_buffer = 30
track_thresh = 0.61
track_buffer = RTSP_TARGET_FPS * 3 # 3 秒未见该目标,则判定该目标消失,在字典中删除
match_thresh = 0.8
mot20 = False
self.fps = RTSP_TARGET_FPS
self.person_phone_tracker = BYTETracker(TrackerArgs(), frame_rate=self.fps)
self.smoke_tracker = BYTETracker(TrackerArgs(), frame_rate=self.fps)
# 当前帧的ID
self.current_frame_idx = 0
self.person_phone_track_role = {}
self.smoke_track_role = {}
# 设置 ByteTrack跟踪器
self.ByteTracker = BYTETracker(TrackerArgs(), frame_rate=self.fps)
# 用来保存历史跟踪目标的字典,当目标消失了之后,就在该字典中清除该目标
self.track_role = {}
# ==========================================
# 超参数设置 (Hyperparameters)
# ==========================================
# 1. 业务判定时间阈值
self.TIME_THRESHOLD_NOBODY = 2.0 # 无人在场判定时长
self.TIME_TOLERANCE_NOBODY = 2.0 # 人丢失缓冲时间
self.TIME_THRESHOLD_SMOKE = 1.0 # 抽烟判定时长
self.TIME_TOLERANCE_SMOKE = 0.5 # 烟丢失缓冲时间(防抖动)
self.TIME_THRESHOLD_PHONE = 1.0 # 玩手机判定时长
self.TIME_TOLERANCE_PHONE = 0.5 # 手机丢失缓冲时间(防抖动)
# 无人在场帧数阈值
self.frame_thresh_nobody = int(self.TIME_THRESHOLD_NOBODY * self.fps)
self.frame_buffer_nobody = int(self.TIME_TOLERANCE_NOBODY * self.fps)
# 抽烟检测帧数阈值
self.frame_thresh_smoke = int(self.TIME_THRESHOLD_SMOKE * self.fps)
self.frame_buffer_smoke = int(self.TIME_TOLERANCE_SMOKE * self.fps)
# 手机检测帧数阈值
self.frame_thresh_phone = int(self.TIME_THRESHOLD_PHONE * self.fps)
self.frame_buffer_phone = int(self.TIME_TOLERANCE_PHONE * self.fps)
# 记录无人的帧数
self.nobody_frames = 0
print(f"\n超参数设置:")
print(f" FPS: {self.fps:.2f}")
print(f" 判定 'Nobody' 需连续: {self.frame_thresh_nobody}")
print(f" 判定 'Smoke Detected' 需累计检测: {self.frame_thresh_smoke}")
print(f" 抽烟丢失缓冲帧数: {self.frame_buffer_smoke}")
print(f" 判定 'Phone Detected' 需累计检测: {self.frame_thresh_phone}")
print(f" 手机丢失缓冲帧数: {self.frame_buffer_phone}")
print(f" 判定 'Nobody' 需连续: {NOBODY_THRESHOLD}")
# ==========================================
# 状态变量初始化
# ==========================================
self.current_frame_idx = 0
# 无人在场检测状态变量
self.nobody_detection_frames = 0
self.nobody_missing_frames = 0 # 连续未检测到手机的帧数
self.nobody_alert_active = False # 手机报警是否激活
# 手机检测状态变量
self.phone_detection_frames = 0 # 连续检测到手机的帧数
self.phone_missing_frames = 0 # 连续未检测到手机的帧数
self.phone_alert_active = False # 手机报警是否激活
# 抽烟检测状态变量
self.smoke_detection_frames = 0 # 连续检测到手机的帧数
self.smoke_missing_frames = 0 # 连续未检测到手机的帧数
self.smoke_alert_active = False # 手机报警是否激活
def compute_iou(self,boxA, boxB):
def compute_iou(self, boxA, boxB):
# box = [x1, y1, x2, y2]
xA = max(boxA[0], boxB[0])
yA = max(boxA[1], boxB[1])
@@ -142,7 +100,7 @@ class ZhihuishiDetector:
return interArea / unionArea
def draw_alert(self, frame, text, color=(0, 0, 255), sub_text=None, offset_y=0):
def draw_alert(self, frame, text, color=(0, 0, 255), offset_y=0):
"""在右上角绘制警告文字 (支持垂直偏移,防止文字重叠)"""
font_scale = 1.5
thickness = 3
@@ -155,340 +113,107 @@ class ZhihuishiDetector:
cv2.rectangle(frame, (x - 10, y - text_h - 10), (x + text_w + 10, y + 10), (0, 0, 0), -1)
cv2.putText(frame, text, (x, y), font, font_scale, color, thickness)
if sub_text:
cv2.putText(frame, sub_text, (x, y + 40), font, 0.7, (200, 200, 200), 2)
def process_frame(self, frame, camera_id: int, timestamp: float) -> Dict[str, Any]:
# =================================== 收集检测结果 ===================================
h, w = frame.shape[:2]
self.width, self.height = w, h
self.current_frame_idx += 1
current_time_sec = timestamp # 当前时间戳
current_time_sec = timestamp
# yolo 的检测结果
detect_results = self.detector(frame)
# ========= 人和手机检测 =========
person_phone_results = self.person_phone_detector(frame)
detect_xyxy = [] # 存储 yolo检测出来的所有检测框的角点坐标x1, y1, x2, y2为角点坐标x1 y1为左上角x2 y2为右下角
detect_roles = [] # 存储 yolo检测出来的所有检测框的标签类别用id的形式保存
detect_bytetrack = [] # 从 yolo的检测结果中提取出来用于ByteTrack追踪检测框所需的信息保存在这里面
# ========= 抽烟检测 =========
smoke_results = self.smoke_detector(frame)
# 累计在当前帧里每个标签类别被检测到的次数,存储格式为 类别id:次数
current_labels_count = {id: 0 for id in Label_Map}
person_phone_dets_xyxy = []
person_phone_dets_roles = []
person_phone_dets_for_tracker = []
smoke_dets_xyxy = []
smoke_dets_roles = []
smoke_dets_for_tracker = []
# ========= 当前帧所有警告列表(关键改动)==========
# ========= 存储当前帧所有警告 ==========
current_frame_alerts = [] # 每帧清空,重新收集
# 收集 人和手机的检测结果
if person_phone_results:
for det in person_phone_results:
x1, y1, x2, y2, conf, cls_id = det # x1, y1, x2, y2为角点坐标x1 y1为左上角x2 y2为右下角
person_phone_dets_xyxy.append([x1, y1, x2, y2])
person_phone_dets_for_tracker.append([x1, y1, x2, y2, conf])
if cls_id == 0:
person_phone_dets_roles.append("phone")
elif cls_id == 1:
person_phone_dets_roles.append("police")
# 遍历 yolo 的检测结果,对 detect_xyxy detect_roles detect_bytetrack 进行填充
if detect_results:
for result in detect_results: # yolo检测结果返回 x1, y1, x2, y2, conf, cls_id
detect_xyxy.append(result[:-2])
detect_roles.append(result[-1])
detect_bytetrack.append(result[:-1])
person_phone_dets = np.array(person_phone_dets_for_tracker, dtype=np.float32) if len(
person_phone_dets_for_tracker) else np.empty((0, 5))
person_phone_tracks = self.person_phone_tracker.update(
person_phone_dets,
# 根据收集到的 detect_bytetrack 确定追踪的检测框目标
tracks = self.ByteTracker.update(
np.array(detect_bytetrack, dtype=np.float32) if len(detect_bytetrack) else np.empty((0, 5)),
# np.empty((0,5)) 表示一个 0 行、5 列 的二维空数组
[self.height, self.width],
[self.height, self.width]
)
# 收集 抽烟的检测结果
if smoke_results:
for det in smoke_results:
x1, y1, x2, y2, conf, cls_id = det
smoke_dets_xyxy.append([x1, y1, x2, y2])
smoke_dets_for_tracker.append([x1, y1, x2, y2, conf])
if cls_id == 0:
smoke_dets_roles.append("smoke")
# 匹配每个跟踪目标的正确类别
# 为什么要用track的结果来统计标签类别的出现次数以及绘制检测框而不是仅用yolo的检测结果来统计及绘制是因为yolo的检测结果是针对单帧而bytetrack可以实现跨帧处理bytetrack的track_id会给每个目标设置一个唯一的id
current_track_ids = []
for track in tracks:
track_id = track.track_id
current_track_ids.append(track_id)
smoke_dets = np.array(smoke_dets_for_tracker, dtype=np.float32) if len(
smoke_dets_for_tracker) else np.empty((0, 5))
reIdentify_frame_interval = 10 # 重新匹配每个跟踪目标的类别的帧间隔
if (current_time_sec % reIdentify_frame_interval == 0) or track_id not in self.track_role:
best_iou = 0.0
best_role = -1
smoke_tracks = self.smoke_tracker.update(
smoke_dets,
[self.height, self.width],
[self.height, self.width]
)
track_box = list(map(float, track.tlbr))
# ========= 单帧统计变量 =========
current_person_count = 0
current_phone_count = 0
current_smoke_count = 0
for i, box in enumerate(detect_xyxy):
iou = self.compute_iou(track_box, box)
if iou > best_iou:
best_iou = iou
best_role = detect_roles[i]
# ========= 人和手机检测 =========
for t in person_phone_tracks:
# print("t: {}".format(t))
tid = t.track_id
# cls_id = -1
self.track_role[track_id] = best_role
# IoU 匹配角色
# IoU匹配跟踪ID和类别
REVALIDATE_FRAME_INTERVAL = 10
if (self.current_frame_idx % REVALIDATE_FRAME_INTERVAL == 0) or (tid not in self.person_phone_track_role):
#if tid not in self.person_phone_track_role:
best_iou = 0
best_role = "unknown"
role = self.track_role[track_id]
t_box = list(map(float, t.tlbr)) # [x1,y1,x2,y2]
current_labels_count[role] += 1
for i, box in enumerate(person_phone_dets_xyxy):
iou_val = self.compute_iou(t_box, box)
if iou_val > best_iou:
best_iou = iou_val
best_role = person_phone_dets_roles[i]
if best_iou > 0.1:
self.person_phone_track_role[tid] = best_role
else:
self.person_phone_track_role[tid] = "unknown"
# 当 role 不等于 unknown 的时候,绘制检测框
if role != -1:
x1, y1, x2, y2 = map(int, track.tlbr)
cv2.rectangle(frame, (x1, y1), (x2, y2), Color_Map[role], 2)
cv2.putText(frame, Label_Map[role], (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, Color_Map[role], 2)
role = self.person_phone_track_role.get(tid, "unknown")
cls_id = -1
if role == "phone":
cls_id = 0
elif role == "police":
cls_id = 1
# print("tid: {}, role: {}, cls: {}".format(tid, role,cls_id))
# 处理过期的 track_role 里的role,如果 track_role 里包含 tracks 里没有的 role ,直接删了即可
for role in list(self.track_role.keys()): # 遍历字典的时候不能直接删元素,用 list() 先复制一份 key再遍历删除才安全
if role not in current_track_ids:
del self.track_role[role]
x1, y1, x2, y2 = map(int, t.tlbr)
# ========================= 业务逻辑判断 ===========================
nobody_alter_flag = False # 无人在场业务逻辑是否成立
cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
color = None
label = None
# Nobody 业务逻辑判断
if current_labels_count[0] == 0:
self.nobody_frames += 1
if cls_id == 0: # Person
current_phone_count += 1
color = (255, 0, 255)
label = "Phone"
elif cls_id == 1: # Phone主模型已支持
current_person_count += 1
color = (0, 0, 139)
label = "Person"
else:
color = (255, 255, 255)
label = "Unknown"
# label = f"ID:{tid} IN"
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
# ========= 抽烟检测 =========
for t in smoke_tracks:
# print("t: {}".format(t))
tid = t.track_id
# cls_id = -1
# IoU 匹配角色
# IoU匹配跟踪ID和类别
REVALIDATE_FRAME_INTERVAL = 10
if (self.current_frame_idx % REVALIDATE_FRAME_INTERVAL == 0) or (tid not in self.smoke_track_role):
#if tid not in self.smoke_track_role:
best_iou = 0
best_role = "unknown"
t_box = list(map(float, t.tlbr)) # [x1,y1,x2,y2]
for i, box in enumerate(smoke_dets_xyxy):
iou_val = self.compute_iou(t_box, box)
if iou_val > best_iou:
best_iou = iou_val
best_role = smoke_dets_roles[i]
# self.smoke_track_role[tid] = best_role
if best_iou > 0.1:
self.smoke_track_role[tid] = best_role
else:
self.smoke_track_role[tid] = "unknown"
role = self.smoke_track_role.get(tid, "unknown")
cls_id = -1
if role == "smoke":
cls_id = 0
x1, y1, x2, y2 = map(int, t.tlbr)
cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
color = None
label = None
if cls_id == 0: # 抽烟
current_smoke_count += 1
color = (255, 255, 0)
label = "Smoke"
else:
color = (255, 255, 255)
label = "Unknown"
# label = f"ID:{tid} IN"
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
# ==========================================
# 手机检测
# ==========================================
if current_phone_count > 0:
# 检测到手机框
self.phone_detection_frames += 1
self.phone_missing_frames = 0 # 重置丢失计数器
# 当检测累计达到阈值时,激活报警
if self.phone_detection_frames >= self.frame_thresh_phone:
self.phone_alert_active = True
if self.nobody_frames >= NOBODY_THRESHOLD:
nobody_alter_flag = True
else:
# 未检测到手机框
self.phone_missing_frames += 1
# 如果之前检测到手机,重置检测计数器
if self.phone_detection_frames > 0:
# 只有在连续丢失超过缓冲帧数时才重置
if self.phone_missing_frames >= self.frame_buffer_phone:
self.phone_detection_frames = 0
self.phone_alert_active = False
else:
# 从未检测到手机,保持状态
pass
# ==========================================
# 抽烟检测
# ==========================================
if current_smoke_count > 0:
# 检测到抽烟框
self.smoke_detection_frames += 1
self.smoke_missing_frames = 0 # 重置丢失计数器
# 当检测累计达到阈值时,激活报警
if self.smoke_detection_frames >= self.frame_thresh_smoke:
self.smoke_alert_active = True
else:
# 未检测到抽烟框
self.smoke_missing_frames += 1
# 如果之前检测到抽烟,重置检测计数器
if self.smoke_detection_frames > 0:
# 只有在连续丢失超过缓冲帧数时才重置
if self.smoke_missing_frames >= self.frame_buffer_smoke:
self.smoke_detection_frames = 0
self.smoke_alert_active = False
else:
# 从未检测到抽烟,保持状态
pass
# ==========================================
# 9. 业务逻辑判定 (Only One / Nobody)
# ==========================================
status_text = ""
if current_person_count == 0:
self.nobody_detection_frames += 1
self.nobody_missing_frames = 0
if self.nobody_detection_frames >= self.frame_thresh_nobody:
self.nobody_alert_active = True
else:
self.nobody_missing_frames += 1
if self.nobody_detection_frames > 0:
if self.nobody_missing_frames >= self.frame_buffer_nobody:
self.nobody_detection_frames = 0
self.nobody_alert_active = False
else:
pass
self.nobody_frames = 0
# if current_person_count == 0:
# self.cnt_frame_nobody += 1
# else:
# self.cnt_frame_nobody = 0
# ==========================================
# 10. 收集并生成结构化警告(核心改动)
# ==========================================
alert_offset = 0
# A. Playing Phone
if self.phone_alert_active:
duration_seconds = self.phone_detection_frames / self.fps
if nobody_alter_flag:
action_text = 'Nobody Checking'
current_frame_alerts.append(
{
'time': current_time_sec,
'action': 'Playing Phone',
'confidence': 1.0, # 固定为1.0(规则判定)
'details': f"Detected for {duration_seconds:.1f}s"
'action': action_text,
}
)
# A. Playing Phone
if self.smoke_alert_active:
duration_seconds = self.smoke_detection_frames / self.fps
current_frame_alerts.append(
{
'time': current_time_sec,
'action': 'Smoke',
'confidence': 1.0, # 固定为1.0(规则判定)
'details': f"Detected for {duration_seconds:.1f}s"
}
)
# D. Nobody Checking
if self.nobody_alert_active:
duration_seconds = self.nobody_detection_frames / self.fps
current_frame_alerts.append({
'time': current_time_sec,
'action': 'Nobody Checking',
'confidence': 1.0,
'details': f"Detected for {duration_seconds:.1f}s"
})
# ==========================================
# 11. 统一显示当前帧所有警告(可替换原分层显示)
# ==========================================
debug_info = f"Person: {current_person_count} | Phone: {current_phone_count} | Smoke: {current_smoke_count}"
cv2.putText(frame, debug_info, (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
# 统一警告显示区
alert_y_start = 150
for i, alert in enumerate(current_frame_alerts):
action = alert['action']
details = alert.get('details', '')
color = (0, 0, 255) # 默认红色警告
if action == 'Nobody Checking':
color = (255, 255, 255)
elif action == 'Smoke':
color = (0, 0, 255)
elif action == 'Playing Phone':
color = (255, 0, 0)
main_text = action
if details:
main_text += f" ({details})"
y_pos = alert_y_start + i * 50
cv2.rectangle(frame, (20, y_pos - 40), (900, y_pos + 10), (0, 0, 0), -1)
cv2.putText(frame, main_text, (30, y_pos), cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
self.draw_alert(frame, action_text, offset_y=0)
return {
"image": frame,
"alerts":current_frame_alerts
"alerts": current_frame_alerts
}