# Source: SupervisorAI/biz/checkpoint/checkpoint_biz.py
# Snapshot: 2026-03-09 13:28:38 +08:00 (523 lines, 20 KiB, Python)
#
# Note from the repository viewer: this file contains Unicode characters that
# might be confused with other characters (presumably full-width punctuation in
# the Chinese comments). If this is intentional, the warning can be ignored.

import cv2
import numpy as np
from typing import Dict, Any
import threading
import queue
from biz.base_frame_processor import BaseFrameProcessorWorker
# -------------------------- Kadian 检测相关导入 --------------------------
from algorithm.common.npu_yolo_onnx_person_car_phone import YOLOv8_ONNX # 主检测模型(人/车/后备箱/手机)
from algorithm.common.npu_yolo_pose_onnx import YOLOv8_Pose_ONNX # Pose 专用模型
from yolox.tracker.byte_tracker import BYTETracker
from utils.logger import get_logger
from common.constants import MODEL_ROOT_PATH
# Module-level logger, named after this module.
logger = get_logger(__name__)
# Default detector weights; used when a camera's params carry no 'model_path'.
DETECT_MODEL_PATH = 'YOLO_Weight/Kadian.onnx'

# Default checkpoint ROI polygon. Each row is one (x, y) vertex expressed as a
# fraction of frame width / height (values <= 1.0 are scaled to pixels at run time).
# Earlier calibrations, kept for reference only:
#   [[0.10989583333333333, 0.006481481481481481], [0.421875, 0.005555555555555556],
#    [0.9921875, 0.9888888888888889], [0.3411458333333333, 0.9861111111111112]]
#   [[0.12, 0.0], [0.3, 0.0], [0.5, 0.2], [1.0, 0.95], [1.0, 1.0], [0.42, 1.0]]
ROI_RELATIVE = np.array(
    [
        (0.15, 0.15),
        (0.37, 0.15),
        (0.55, 0.20),
        (0.90, 0.85),
        (0.35, 0.85),
    ]
)

# NOTE(review): presumably the minimum interval (seconds) between alert pushes;
# it is not referenced in the visible code -- confirm where it is consumed.
ALERT_PUSH_INTERVAL = 5.0

# Value passed as `input_size` to the person/car detector.
PERSON_CAR_INPUT_SIZE = 640
# Pose model input size (pose model currently disabled): POSE_INPUT_SIZE = 640

# Target processing frame rate; the detector also uses it to turn its
# second-based thresholds into frame counts.
RTSP_TARGET_FPS = 10.0
# ========================= Kadian checkpoint detector (streamlined TrafficMonitor for the service) =========================
class KadianDetector:
    """Checkpoint ("Kadian") monitor for one camera stream.

    Each frame is run through the YOLO detector, tracked with ByteTrack and
    analysed inside a polygonal ROI:

    * a car that leaves the ROI after less than
      ``TIME_THRESHOLD_CAR_MIN_DURATION`` seconds raises ``"Ignore"``
      (passed too fast);
    * a car that stayed long enough but never had an open trunk inside its box
      for ``TIME_THRESHOLD_TRUNK_OPEN`` seconds (accumulated) raises
      ``"Unchecked Trunk"``;
    * while cars are in the ROI, ``"Nobody"`` / ``"Only One"`` is raised once
      zero / exactly one police officer has been present for the configured
      number of seconds.

    All second-based thresholds are converted to frame counts with
    ``RTSP_TARGET_FPS``; every ``process_frame`` call counts as one frame.
    """

    # Detector class id -> role name (0=Car, 1=OpenTrunk, 2=Passerby, 3=Police).
    # Any other id maps to "unknown" so box and role lists stay index-aligned.
    CLASS_ROLES = {0: "car", 1: "opentrunk", 2: "passerby", 3: "police"}
    # Re-run the track -> class IOU matching for every track every N frames.
    REVALIDATE_FRAME_INTERVAL = 10
    # A track adopts a detection's role only if their IOU exceeds this value.
    ROLE_MATCH_MIN_IOU = 0.1

    def __init__(self, params=None):
        """Load the detector and initialise the tracker and per-stream state.

        Args:
            params: Optional per-camera configuration dict. Keys read here:
                ``model_path`` (joined onto ``MODEL_ROOT_PATH``; falls back to
                ``DETECT_MODEL_PATH``), ``roi_points`` (ROI polygon, relative
                0..1 or absolute pixels; defaults to ``ROI_RELATIVE``) and
                ``draw_alerts`` (draw alert banners on the frame; default False).
        """
        self.params = params if params is not None else {}

        # ---- Model ----
        model_path = self.params.get('model_path')
        if model_path:
            full_model_path = f"{MODEL_ROOT_PATH}/{model_path}"
        else:
            # NOTE(review): unlike a configured path, the default is not
            # prefixed with MODEL_ROOT_PATH -- confirm this is intended.
            full_model_path = DETECT_MODEL_PATH
        logger.info(f"Loading model from: {full_model_path}")
        self.detector = YOLOv8_ONNX(
            full_model_path,
            conf_threshold=0.6,
            iou_threshold=0.65,
            input_size=PERSON_CAR_INPUT_SIZE
        )

        # ---- Tracker ----
        class TrackerArgs:
            track_thresh = 0.61  # must stay >= the detector's conf_threshold (0.6)
            track_buffer = 40
            match_thresh = 0.85
            mot20 = True

        self.fps = RTSP_TARGET_FPS
        self.tracker = BYTETracker(TrackerArgs(), frame_rate=self.fps)
        self.track_role = {}  # track id -> role name
        self.track_last_seen = {}  # track id -> frame index it was last output by the tracker
        # Forget a track's role after it has been absent this many frames, so the
        # dicts above do not grow forever. In the reference yolox ByteTrack, lost
        # tracks survive int(frame_rate / 30 * track_buffer) frames (verify for
        # the vendored version); 2x the larger bound keeps a safety margin.
        self.track_role_ttl_frames = 2 * max(
            TrackerArgs.track_buffer,
            int(self.fps / 30.0 * TrackerArgs.track_buffer),
        )

        # ---- ROI: params first, otherwise the module default ----
        roi_points = self.params.get('roi_points', ROI_RELATIVE)
        self.roi_points = np.array(roi_points, dtype=np.float64) if roi_points is not None else None

        # Alert banners are not drawn unless explicitly enabled.
        self.draw_alert_overlays = bool(self.params.get('draw_alerts', False))

        # ===================== Hyper-parameters (seconds) =====================
        # Accumulated time an open trunk must be seen inside a car's box for
        # the car to count as checked.
        self.TIME_THRESHOLD_TRUNK_OPEN = 0.3
        # Minimum car dwell time; shorter stays count as "passed without check".
        self.TIME_THRESHOLD_CAR_MIN_DURATION = 3.0
        # How long a car may go unseen before it is considered departed.
        self.TIME_TOLERANCE_CAR = 2.0
        # How long a police officer may go unseen before being dropped.
        self.TIME_TOLERANCE_POLICE = 3.0
        # Accumulated time with zero / exactly one officer before alerting.
        self.TIME_THRESHOLD_NOBODY = 10.0
        self.TIME_THRESHOLD_ONLY_ONE = 10.0

        # --- Same thresholds expressed in frames ---
        self.frame_thresh_trunk_valid = int(self.TIME_THRESHOLD_TRUNK_OPEN * self.fps)
        self.frame_thresh_car_min_duration = int(self.TIME_THRESHOLD_CAR_MIN_DURATION * self.fps)
        self.frame_buffer_limit_car = int(self.TIME_TOLERANCE_CAR * self.fps)
        self.frame_buffer_limit_police = int(self.TIME_TOLERANCE_POLICE * self.fps)
        self.frame_thresh_nobody = int(self.TIME_THRESHOLD_NOBODY * self.fps)
        self.frame_thresh_only_one = int(self.TIME_THRESHOLD_ONLY_ONE * self.fps)

        # How long (seconds) each alert type stays active after being raised.
        self.ignore_show_seconds = 0.2  # "Ignore" (passed too fast)
        self.openTrunk_show_seconds = 0.2  # "Unchecked Trunk"
        self.police_show_seconds = 0.2  # "Nobody" / "Only One"

        # ---- Per-stream state ----
        self.current_frame_idx = 0
        self.width = 0
        self.height = 0
        # car track id -> {'first_seen', 'last_seen', 'trunk_frames', 'is_checked'}
        self.roi_car_registry = {}
        # car track id -> frame index until which the alert stays active
        self.unchecked_trunk_alerts = {}
        self.fast_pass_alerts = {}
        # police track id -> {'first_seen', 'last_seen'}
        self.roi_police_registry = {}
        # alert name -> frame index until which the alert stays active
        self.nobody_alerts = {}
        self.only_one_alerts = {}
        # Consecutive frames (with cars present) having zero / exactly one officer.
        self.nobody_frames = 0
        self.only_one_frames = 0

    def _get_roi_points(self, frame_width: int, frame_height: int):
        """Return the ROI polygon in absolute pixels as an int32 array.

        Assumes ``self.roi_points`` holds (N, 2) vertices. If every coordinate
        is <= 1.0 the points are treated as relative and scaled to the frame;
        otherwise they are used as absolute pixels. Recomputed every frame so a
        change of frame size is picked up.

        Raises:
            ValueError: if no ROI is configured.
        """
        if self.roi_points is None:
            raise ValueError("ROI points must be provided; cannot be None.")
        if self.roi_points.max() <= 1.0:
            roi_abs = self.roi_points * np.array([frame_width, frame_height])
        else:
            roi_abs = self.roi_points.copy()
        # int32 is required to avoid OpenCV assertion errors in
        # pointPolygonTest / polylines.
        return roi_abs.astype(np.int32)

    def check_point_in_roi(self, roi_points, point):
        """Return True if ``point`` (x, y) lies inside or on the ROI polygon."""
        # Pass plain floats: some OpenCV versions refuse to parse NumPy scalar
        # coordinates for the Point2f argument.
        return cv2.pointPolygonTest(roi_points, (float(point[0]), float(point[1])), False) >= 0

    def compute_iou(self, boxA, boxB):
        """Return the IOU of two ``[x1, y1, x2, y2]`` boxes (0.0 for a zero union)."""
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        yB = min(boxA[3], boxB[3])
        interArea = max(0, xB - xA) * max(0, yB - yA)
        boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
        boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
        unionArea = boxAArea + boxBArea - interArea
        if unionArea == 0:
            return 0.0
        return interArea / unionArea

    def draw_alert(self, frame, text, color=(0, 0, 255), sub_text=None, offset_y=0):
        """Draw an alert banner in the top-right corner of ``frame``.

        ``offset_y`` shifts the banner down so several banners do not overlap;
        ``sub_text`` is drawn in smaller grey text below the banner.
        """
        font_scale = 1.5
        thickness = 3
        font = cv2.FONT_HERSHEY_SIMPLEX
        (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
        x = self.width - text_w - 20
        y = 50 + text_h + offset_y
        cv2.rectangle(frame, (x - 10, y - text_h - 10), (x + text_w + 10, y + 10), (0, 0, 0), -1)
        cv2.putText(frame, text, (x, y), font, font_scale, color, thickness)
        if sub_text:
            cv2.putText(frame, sub_text, (x, y + 40), font, 0.7, (200, 200, 200), 2)

    def is_point_in_box(self, point, box):
        """Return True if ``point`` (x, y) lies strictly inside ``box`` [x1, y1, x2, y2]."""
        px, py = point
        x1, y1, x2, y2 = box
        return x1 < px < x2 and y1 < py < y2

    def _parse_detections(self, detections):
        """Split raw detector rows into boxes, role names and the tracker input.

        Each row is expected to be ``(x1, y1, x2, y2, conf, cls_id)``; ``None``
        or an empty sequence means no detections.

        Returns:
            ``(dets_xyxy, dets_roles, dets)``: index-aligned lists of boxes and
            role names, and an ``(N, 5)`` ``[x1, y1, x2, y2, conf]`` array for
            the tracker (float32 when N > 0).
        """
        dets_xyxy = []
        dets_roles = []
        dets_for_tracker = []
        # Iterate instead of testing truthiness: a multi-row NumPy array has
        # no truth value.
        for x1, y1, x2, y2, conf, cls_id in (detections if detections is not None else ()):
            dets_xyxy.append([x1, y1, x2, y2])
            dets_for_tracker.append([x1, y1, x2, y2, conf])
            # Always append a role so dets_roles stays aligned with dets_xyxy.
            dets_roles.append(self.CLASS_ROLES.get(int(cls_id), "unknown"))
        if dets_for_tracker:
            dets = np.array(dets_for_tracker, dtype=np.float32)
        else:
            dets = np.empty((0, 5))
        return dets_xyxy, dets_roles, dets

    def _match_track_role(self, tlbr, dets_xyxy, dets_roles):
        """Return the role of the detection overlapping a track box the most.

        Returns ``"unknown"`` unless the best IOU exceeds ``ROLE_MATCH_MIN_IOU``.
        """
        t_box = [float(v) for v in tlbr]
        best_iou = 0
        best_role = "unknown"
        for box, role in zip(dets_xyxy, dets_roles):
            iou_val = self.compute_iou(t_box, box)
            if iou_val > best_iou:
                best_iou, best_role = iou_val, role
        return best_role if best_iou > self.ROLE_MATCH_MIN_IOU else "unknown"

    def _prune_track_roles(self):
        """Forget role/last-seen entries of tracks absent longer than ``track_role_ttl_frames``."""
        stale = [tid for tid, seen in self.track_last_seen.items()
                 if self.current_frame_idx - seen > self.track_role_ttl_frames]
        for tid in stale:
            del self.track_last_seen[tid]
            self.track_role.pop(tid, None)

    def _expire_alerts(self, alerts):
        """Remove entries of an ``{key: end_frame}`` dict whose active window has passed."""
        expired = [key for key, end_frame in alerts.items() if self.current_frame_idx > end_frame]
        for key in expired:
            del alerts[key]

    def _retire_departed_cars(self):
        """Drop cars unseen beyond the buffer, raising departure alerts; return active car ids."""
        active_car_ids = []
        departed = []
        for car_id, info in self.roi_car_registry.items():
            if (self.current_frame_idx - info['last_seen']) <= self.frame_buffer_limit_car:
                active_car_ids.append(car_id)
            else:
                departed.append(car_id)

        for car_id in departed:
            car_info = self.roi_car_registry.pop(car_id)
            duration_frames = car_info['last_seen'] - car_info['first_seen']
            if duration_frames < self.frame_thresh_car_min_duration:
                # Case 1: stayed too briefly -> treated as not checked at all.
                logger.info(f"ALARM: Car {car_id} passed too fast -> Regarded as Ignore Checked!")
                self.fast_pass_alerts[car_id] = self.current_frame_idx + int(self.ignore_show_seconds * self.fps)
            elif not car_info['is_checked']:
                # Case 2: stayed long enough but the trunk was never opened.
                logger.info(f"ALARM: Car {car_id} left without checking trunk!")
                self.unchecked_trunk_alerts[car_id] = self.current_frame_idx + int(
                    self.openTrunk_show_seconds * self.fps)
        return active_car_ids

    def _retire_departed_police(self):
        """Drop officers unseen beyond the buffer; return how many remain active."""
        stale = [pid for pid, info in self.roi_police_registry.items()
                 if (self.current_frame_idx - info['last_seen']) > self.frame_buffer_limit_police]
        for pid in stale:
            del self.roi_police_registry[pid]
        return len(self.roi_police_registry)

    def process_frame(self, frame, camera_id: int, timestamp: float) -> Dict[str, Any]:
        """Run detection, tracking and the checkpoint rules on one frame.

        Args:
            frame: Image array (H x W [x C]); annotated in place with the ROI,
                track boxes and a debug line.
            camera_id: Camera identifier (not used by this method).
            timestamp: Capture time in seconds, copied into each alert.

        Returns:
            ``{"image": frame, "alerts": [...]}``. Each alert is
            ``{'time': timestamp, 'action': name}`` with ``name`` one of
            ``"Unchecked Trunk"``, ``"Ignore"``, ``"Nobody"``, ``"Only One"``.
        """
        h, w = frame.shape[:2]
        self.width, self.height = w, h
        self.current_frame_idx += 1
        current_time_sec = timestamp

        # (N, 2) int32 for pointPolygonTest, (N, 1, 2) for polylines.
        roi_points_int32 = self._get_roi_points(w, h)
        roi_points_draw = roi_points_int32.reshape((-1, 1, 2))

        # ========= Detection + tracking (pose model removed) =========
        dets_xyxy, dets_roles, dets = self._parse_detections(self.detector(frame))
        tracks = self.tracker.update(
            dets,
            [self.height, self.width],
            [self.height, self.width]
        )

        cv2.polylines(frame, [roi_points_draw], isClosed=True, color=(255, 0, 0), thickness=3)

        current_frame_alerts = []  # alerts raised on this frame only
        current_roi_trunk_count = 0
        current_cars = []  # cars inside the ROI: {'id': track_id, 'box': [x1, y1, x2, y2]}
        current_trunks = []  # centres (cx, cy) of open trunks inside the ROI

        # ========= Per-track processing =========
        # NOTE(review): OpenCV draws colour tuples in the frame's channel order;
        # on a BGR frame (255, 165, 0) renders blue rather than the "orange"
        # the original comments described -- confirm the intended colours.
        for t in tracks:
            tid = t.track_id
            self.track_last_seen[tid] = self.current_frame_idx
            # Resolve the class of new tracks immediately and periodically
            # re-check every track against the current detections.
            if (self.current_frame_idx % self.REVALIDATE_FRAME_INTERVAL == 0) or (tid not in self.track_role):
                self.track_role[tid] = self._match_track_role(t.tlbr, dets_xyxy, dets_roles)
            role = self.track_role.get(tid, "unknown")

            x1, y1, x2, y2 = map(int, t.tlbr)
            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
            in_roi = self.check_point_in_roi(roi_points_int32, (cx, cy))

            if role == "car":
                color = (0, 255, 0)
                label = f"Car:{tid}"
                if in_roi:
                    current_cars.append({'id': tid, 'box': [x1, y1, x2, y2]})
                    car = self.roi_car_registry.get(tid)
                    if car is None:
                        self.roi_car_registry[tid] = {
                            'first_seen': self.current_frame_idx,
                            'last_seen': self.current_frame_idx,
                            'trunk_frames': 0,
                            'is_checked': False,
                        }
                    else:
                        car['last_seen'] = self.current_frame_idx
                    label += " IN"
            elif role == "opentrunk":
                color = (255, 165, 0)
                label = "OpenTrunk"
                if in_roi:
                    current_roi_trunk_count += 1
                    current_trunks.append((cx, cy))
                    label += " IN"
            elif role == "passerby":
                # Passers-by are only drawn, never alerted on.
                color = (255, 255, 0)
                label = "Passerby"
            elif role == "police":
                color = (0, 255, 255)
                label = "Police"
                if in_roi:
                    officer = self.roi_police_registry.get(tid)
                    if officer is None:
                        self.roi_police_registry[tid] = {
                            'first_seen': self.current_frame_idx,
                            'last_seen': self.current_frame_idx,
                        }
                    else:
                        officer['last_seen'] = self.current_frame_idx
                    label += " IN"
            else:
                color = (255, 255, 255)
                label = "Unknown"

            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

        self._prune_track_roles()

        # ========= Trunk -> car association =========
        # An open trunk belongs to a car when the trunk's centre is inside the car box.
        for car in current_cars:
            if any(self.is_point_in_box(pt, car['box']) for pt in current_trunks):
                entry = self.roi_car_registry[car['id']]
                entry['trunk_frames'] += 1
                if entry['trunk_frames'] >= self.frame_thresh_trunk_valid:
                    entry['is_checked'] = True

        # ========= Registry maintenance (raises departure alerts) =========
        active_car_ids = self._retire_departed_cars()
        effective_car_count = len(active_car_ids)
        effective_police_count = self._retire_departed_police()

        debug_info = f"Cars: {len(active_car_ids)} | Trunk: {current_roi_trunk_count} | Police: {effective_police_count} | Nobody:{self.nobody_frames}/{self.frame_thresh_nobody} | OnlyOne:{self.only_one_frames}/{self.frame_thresh_only_one}"
        cv2.putText(frame, debug_info, (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

        # ========= Alerts =========
        # Positive "Trunk Checked" events are intentionally not reported.
        alert_offset = 0  # vertical offset so drawn banners do not overlap

        # Car left without its trunk being checked.
        self._expire_alerts(self.unchecked_trunk_alerts)
        if self.unchecked_trunk_alerts:
            current_frame_alerts.append({
                'time': current_time_sec,
                'action': "Unchecked Trunk",
            })
            if self.draw_alert_overlays:
                alert_text = f"Unchecked Trunk! (ID:{list(self.unchecked_trunk_alerts.keys())})"
                self.draw_alert(frame, alert_text, (0, 0, 255), offset_y=alert_offset)
            alert_offset += 100

        # Car passed too fast.
        self._expire_alerts(self.fast_pass_alerts)
        if self.fast_pass_alerts:
            current_frame_alerts.append({
                'time': current_time_sec,
                'action': "Ignore",
            })
            if self.draw_alert_overlays:
                alert_text = f"Ignore: (ID:{list(self.fast_pass_alerts.keys())})"
                self.draw_alert(frame, alert_text, (0, 0, 255), offset_y=alert_offset)
            alert_offset += 100

        # Police presence while cars are in the ROI.
        self._expire_alerts(self.nobody_alerts)
        self._expire_alerts(self.only_one_alerts)

        if effective_car_count > 0 and effective_police_count == 0:
            self.nobody_frames += 1
            self.only_one_frames = 0
        elif effective_car_count > 0 and effective_police_count == 1:
            self.only_one_frames += 1
            self.nobody_frames = 0
        else:
            self.nobody_frames = 0
            self.only_one_frames = 0

        if effective_police_count == 0 and self.nobody_frames >= self.frame_thresh_nobody:
            self.nobody_alerts.setdefault(
                "Nobody", self.current_frame_idx + int(self.police_show_seconds * self.fps))
            current_frame_alerts.append({
                'time': current_time_sec,
                'action': "Nobody",
            })
            if self.draw_alert_overlays:
                self.draw_alert(frame, "Nobody", (0, 0, 255), offset_y=alert_offset)
            alert_offset += 100
        elif effective_police_count == 1 and self.only_one_frames >= self.frame_thresh_only_one:
            self.only_one_alerts.setdefault(
                "Only One", self.current_frame_idx + int(self.police_show_seconds * self.fps))
            current_frame_alerts.append({
                'time': current_time_sec,
                'action': "Only One",
            })
            if self.draw_alert_overlays:
                self.draw_alert(frame, "Only One", (255, 165, 0), offset_y=alert_offset)
            alert_offset += 100

        return {
            "image": frame,
            "alerts": current_frame_alerts,
        }
# ========================= Frame-processing worker =========================
class FrameProcessorWorker(BaseFrameProcessorWorker):
    """Frame-processing thread for checkpoint (Kadian) detection."""

    # Subclass configuration consumed by BaseFrameProcessorWorker.
    # Factory that builds a detector from a camera's params dict. The class is
    # used directly instead of `lambda params: KadianDetector(params)`: a plain
    # function stored as a class attribute becomes a bound method when looked
    # up through an instance (`self.DETECTOR_FACTORY(params)` would then receive
    # `self` as an extra argument), while a class is returned unchanged.
    DETECTOR_FACTORY = KadianDetector
    # NOTE(review): POST_TYPE semantics are defined by the base class / backend;
    # 1 presumably identifies checkpoint alerts -- confirm.
    POST_TYPE = 1
    # Frame rate the worker targets; KadianDetector derives its frame-count
    # thresholds from the same constant.
    TARGET_FPS = RTSP_TARGET_FPS