Files
SupervisorAI/biz/prison/ab_biz.py
2026-03-06 11:11:09 +08:00

260 lines
9.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import base64
from typing import Dict, Any
import threading
import time
import queue
import requests
from biz.base_frame_processor import BaseFrameProcessorWorker
# -------------------------- Kadian 检测相关导入 --------------------------
from algorithm.common.npu_yolo_onnx_person_car_phone import YOLOv8_ONNX # 主检测模型(人/车/后备箱/手机)
from common.constants import ALERT_PUSH_URL
from yolox.tracker.byte_tracker import BYTETracker
# ========================= Configuration =========================
# Kadian detector weights; change the path to match the deployment.
detector_model_path: str = "YOLO_Weight/bag_model.onnx"
# Model input resolution, passed to YOLOv8_ONNX as ``input_size``.
input_size: int = 640
# Processing frame rate: used as the detector/tracker fps (to turn second-based
# thresholds into frame counts) and as the worker's TARGET_FPS.
RTSP_TARGET_FPS: float = 10.0
# Alert push rate limit in seconds: the same action is pushed at most once per
# interval. NOTE(review): not referenced in the visible part of this file.
ALERT_PUSH_INTERVAL: float = 5.0
class AbDetector:
    """Black-bag / person detector with ByteTrack tracking and a debounced alert.

    Each ``process_frame`` call runs the ONNX detector, feeds the boxes to
    ``BYTETracker``, labels every returned track as ``"black_bag"``, ``"person"``
    or ``"unknown"`` by IoU against this frame's detections, draws an overlay on
    the frame in place, and reports a "Black Bag" alert once bags have been seen
    in at least ``frame_thresh_blackbag`` frames. The detection counter is only
    reset after ``frame_buffer_blackbag`` consecutive frames without a bag.
    """

    # Detector class id -> track role. Any other class id maps to "unknown" so the
    # role list always stays index-aligned with the box list.
    ROLE_BY_CLASS_ID = {0: "black_bag", 1: "person"}
    # A known track's role is re-derived from detections every this many frames.
    REVALIDATE_FRAME_INTERVAL = 10
    # A track adopts a detection's role only if their IoU is above this value.
    ROLE_MATCH_IOU = 0.1

    def __init__(self, params=None):
        # Extra per-camera parameters (stored; not read inside this class).
        self.params = params if params is not None else {}
        # Model loading
        self.detector = YOLOv8_ONNX(detector_model_path, conf_threshold=0.5, iou_threshold=0.45,
                                    input_size=input_size)

        # ByteTrack tracker settings.
        class TrackerArgs:
            track_thresh = 0.25
            track_buffer = 30
            match_thresh = 0.8
            mot20 = False

        # track_id -> role ("black_bag" / "person" / "unknown"); holds only the
        # tracks returned on the most recent frame.
        self.track_role = {}
        self.fps = RTSP_TARGET_FPS
        self.tracker = BYTETracker(TrackerArgs(), frame_rate=self.fps)
        # Frame size; set from every frame in process_frame.
        self.width, self.height = None, None

        # ==========================================
        # Hyperparameters
        # ==========================================
        self.TIME_THRESHOLD_BLACKBAG = 1.0  # seconds of accumulated bag detection before alerting
        self.TIME_TOLERANCE_BLACKBAG = 0.5  # seconds a bag may be missing before the counter resets
        # The same thresholds expressed in frames.
        self.frame_thresh_blackbag = int(self.TIME_THRESHOLD_BLACKBAG * self.fps)
        self.frame_buffer_blackbag = int(self.TIME_TOLERANCE_BLACKBAG * self.fps)
        print(f"\n超参数设置:")
        print(f" FPS: {self.fps:.2f}")
        print(f" 判定 'BlackBag Detected' 需累计检测: {self.frame_thresh_blackbag}")
        print(f" 黑包丢失缓冲帧数: {self.frame_buffer_blackbag}")

        # ==========================================
        # State
        # ==========================================
        self.current_frame_idx = 0
        # Black-bag alert state
        self.blackbag_detection_frames = 0
        self.blackbag_missing_frames = 0
        self.blackbag_alert_active = False
        # Number of person tracks on the last processed frame
        self.current_person_count = 0

    def compute_iou(self, boxA, boxB):
        """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

        Returns 0.0 when the union area is zero.
        """
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        yB = min(boxA[3], boxB[3])
        interW = max(0, xB - xA)
        interH = max(0, yB - yA)
        interArea = interW * interH
        boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
        boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
        unionArea = boxAArea + boxBArea - interArea
        if unionArea == 0:
            return 0.0
        return interArea / unionArea

    def draw_alert(self, frame, text, color=(0, 0, 255), sub_text=None, offset_y=0):
        """Draw ``text`` on a black box in the top-right corner of ``frame`` (in place).

        ``offset_y`` shifts the box down so several alerts can be stacked without
        overlapping. ``sub_text``, if given, is drawn in smaller grey text 40 px below.
        """
        font_scale = 1.5
        thickness = 3
        font = cv2.FONT_HERSHEY_SIMPLEX
        (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
        # Use the frame's own width so this also works before process_frame has run.
        x = frame.shape[1] - text_w - 20
        y = 50 + text_h + offset_y
        cv2.rectangle(frame, (x - 10, y - text_h - 10), (x + text_w + 10, y + 10), (0, 0, 0), -1)
        cv2.putText(frame, text, (x, y), font, font_scale, color, thickness)
        if sub_text:
            cv2.putText(frame, sub_text, (x, y + 40), font, 0.7, (200, 200, 200), 2)

    def _parse_detections(self, detect_results):
        """Split raw ``(x1, y1, x2, y2, conf, cls_id)`` detections into parallel lists.

        Returns ``(boxes, roles, tracker_rows)``, where ``boxes[i]`` is
        ``[x1, y1, x2, y2]``, ``roles[i]`` is its role name (``"unknown"`` for
        unmapped class ids) and ``tracker_rows[i]`` is ``[x1, y1, x2, y2, conf]``.
        The three lists always have the same length.
        """
        boxes, roles, tracker_rows = [], [], []
        if detect_results is None:
            return boxes, roles, tracker_rows
        for x1, y1, x2, y2, conf, cls_id in detect_results:
            boxes.append([x1, y1, x2, y2])
            tracker_rows.append([x1, y1, x2, y2, conf])
            roles.append(self.ROLE_BY_CLASS_ID.get(int(cls_id), "unknown"))
        return boxes, roles, tracker_rows

    def _best_role(self, box, dets_xyxy, dets_roles):
        """Return the role of the detection with the highest IoU against ``box``.

        Returns ``"unknown"`` if that IoU is not above ``ROLE_MATCH_IOU``.
        """
        best_iou = 0
        best_role = "unknown"
        for det_box, det_role in zip(dets_xyxy, dets_roles):
            iou_val = self.compute_iou(box, det_box)
            if iou_val > best_iou:
                best_iou, best_role = iou_val, det_role
        return best_role if best_iou > self.ROLE_MATCH_IOU else "unknown"

    def _assign_track_roles(self, tracks, dets_xyxy, dets_roles):
        """Rebuild ``self.track_role`` for the tracks returned this frame.

        A track's role is (re)computed when the track is new, or on every
        ``REVALIDATE_FRAME_INTERVAL``-th frame. Otherwise its previous role is kept.
        """
        revalidate = self.current_frame_idx % self.REVALIDATE_FRAME_INTERVAL == 0
        roles = {}
        for t in tracks:
            tid = t.track_id
            if revalidate or tid not in self.track_role:
                roles[tid] = self._best_role(list(map(float, t.tlbr)), dets_xyxy, dets_roles)
            else:
                roles[tid] = self.track_role[tid]
        # Keep only ids seen this frame so the mapping cannot grow without bound on
        # a long-running stream; a track that reappears later is simply re-labelled.
        self.track_role = roles

    def _update_blackbag_state(self, bag_seen):
        """Advance the debounced black-bag alert state by one frame."""
        if bag_seen:
            self.blackbag_detection_frames += 1
            self.blackbag_missing_frames = 0
            if self.blackbag_detection_frames >= self.frame_thresh_blackbag:
                self.blackbag_alert_active = True
        else:
            self.blackbag_missing_frames += 1
            if self.blackbag_missing_frames >= self.frame_buffer_blackbag:
                self.blackbag_detection_frames = 0
                self.blackbag_alert_active = False

    def process_frame(self, frame, camera_id: int, timestamp: float) -> Dict[str, Any]:
        """Detect, track and annotate one frame.

        Args:
            frame: image array. ``frame.shape[:2]`` is read as (height, width), and
                the overlay is drawn onto it in place.
            camera_id: accepted for the worker interface; not used here.
            timestamp: stored as the ``'time'`` of any alert raised on this frame.

        Returns:
            ``{"image": frame, "alerts": [...]}``. Each alert is a dict with
            ``'time'``, ``'action'`` and ``'details'`` keys. While the black-bag
            alert is active, one alert is returned on every frame.
        """
        h, w = frame.shape[:2]
        self.width, self.height = w, h
        self.current_frame_idx += 1
        current_time_sec = timestamp

        # ========= Detection (black bag + person) =========
        detect_results = self.detector(frame)
        dets_xyxy, dets_roles, dets_for_tracker = self._parse_detections(detect_results)

        # ========= Tracker update =========
        if dets_for_tracker:
            dets = np.array(dets_for_tracker, dtype=np.float32)
        else:
            dets = np.empty((0, 5))
        # The same (h, w) is passed as image info and image size, presumably so the
        # tracker applies no rescaling -- confirm against BYTETracker.update.
        tracks = self.tracker.update(dets, [h, w], [h, w])
        self._assign_track_roles(tracks, dets_xyxy, dets_roles)

        # ========= Per-frame counts and track drawing =========
        self.current_person_count = 0
        current_blackbag_count = 0
        for t in tracks:
            role = self.track_role.get(t.track_id, "unknown")
            x1, y1, x2, y2 = map(int, t.tlbr)
            color = (255, 255, 255)
            label = "Unknown"
            if role == "person":
                self.current_person_count += 1
                color = (255, 0, 255)  # purple box
                label = "Person"
            elif role == "black_bag":
                current_blackbag_count += 1
                color = (0, 128, 0)  # green box
                label = "Black Bag"
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

        # ========= Black-bag state =========
        self._update_blackbag_state(current_blackbag_count > 0)

        # ========= Alert collection =========
        current_frame_alerts = []
        if self.blackbag_alert_active:
            duration_seconds = self.blackbag_detection_frames / self.fps
            current_frame_alerts.append(
                {
                    'time': current_time_sec,
                    'action': 'Black Bag',
                    'details': f"Detected for {duration_seconds:.1f}s"
                }
            )
            self.draw_alert(frame, "Black Bag Alert", (0, 0, 255), sub_text=f"Detected for {duration_seconds:.1f}s")

        # ========= Overlay =========
        # Live counts in the top-left corner.
        debug_info = f"Person: {self.current_person_count} | BlackBag: {current_blackbag_count}"
        cv2.putText(frame, debug_info, (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        # Alert list on the left, one row per alert (in addition to the top-right banner).
        alert_y_start = 150
        for i, alert in enumerate(current_frame_alerts):
            action = alert['action']
            details = alert.get('details', '')
            color = (0, 0, 255)  # red
            main_text = f"{action} ({details})"
            y_pos = alert_y_start + i * 50
            cv2.rectangle(frame, (20, y_pos - 40), (900, y_pos + 10), (0, 0, 0), -1)
            cv2.putText(frame, main_text, (30, y_pos), cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
        return {
            "image": frame,
            "alerts": current_frame_alerts
        }
# ========================= Frame-processing worker =========================
class FrameProcessorWorker(BaseFrameProcessorWorker):
    """Frame-processing thread that runs ``AbDetector`` on incoming frames.

    In the visible code this subclass only supplies configuration; the
    processing loop itself lives in ``BaseFrameProcessorWorker``.
    """

    # Subclass configuration.
    # staticmethod ensures the factory receives only ``params`` whether it is
    # called through the class or through an instance. A bare lambda stored on
    # a class becomes a bound method on instances and would receive ``self``.
    DETECTOR_FACTORY = staticmethod(lambda params: AbDetector(params))
    # Meaning defined by BaseFrameProcessorWorker (not visible here).
    POST_TYPE = 2
    # Target processing rate; the same constant AbDetector uses as its fps.
    TARGET_FPS = RTSP_TARGET_FPS