Files
SupervisorAI/biz/trajectory/trajectory01_biz.py
2026-03-03 11:25:37 +08:00

579 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# rtsp_service_kadian.py
# 融合 Kadian_Detect_1221.py + rtsp_service_ws.py
# 支持多路RTSP、抽帧、分段保存MP4、WebSocket推送图像与告警
import cv2
import numpy as np
import os
import time
import threading
import queue
import yaml
import json
import base64
import asyncio
import websockets
from dataclasses import dataclass
from typing import Dict, Any, Tuple, List
from datetime import datetime
# -------------------------- Kadian 检测相关导入 --------------------------
from algorithm.checkpoint.npu_yolo_onnx_person_car_phone import YOLOv8_ONNX # 主检测模型(人/车/后备箱/手机)
# from rtsp_service_ws_0108 import WS_PORT
from yolox.tracker.byte_tracker import BYTETracker
# ============================ Configuration ============================
# Detector weights (ONNX) used by TrajectoryDetector; the route ROIs are
# configured inside TrajectoryDetector.__init__.
detector_model_path = 'YOLO_Weight/prisoner_model.onnx'
# Input size handed to the detector (YOLOv8_ONNX input_size argument).
input_size = 640

# Frames are processed at most this often per camera; also used as the
# tracker frame rate and to convert second-based thresholds into frames.
RTSP_TARGET_FPS = 10.0

# Alert throttling: an identical action is pushed at most once per this
# many seconds for each camera.
ALERT_PUSH_INTERVAL = 5.0
# NOTE(review): hard-coded endpoint, not referenced in the visible part of
# this file — consider moving it to configuration.
ALERT_PUSH_URL = ("http://123.57.151.210:10000"
                  "/picenter/websocket/test/process")
class TrajectoryDetector:
    """Per-camera police/prisoner detector with a sequential route check.

    For every frame given to :meth:`process_frame` it:

    1. runs the ONNX detector (class id 0 -> "police", 1 -> "prisoner",
       any other id -> "unknown"),
    2. feeds all boxes to ByteTrack and gives each track a role by IoU
       against the raw detections of the frame,
    3. advances a per-prisoner state machine that requires the ROIs in
       ``route_rois`` to be entered in order (entry -> corridor -> exit),
    4. emits "violation", "prisoner" and "finished" alerts and draws boxes,
       ROIs and alert banners onto the frame.
    """

    # Known tracks have their role re-matched once every N processed frames;
    # new track ids are matched immediately.
    REVALIDATE_FRAME_INTERVAL = 10
    # A track adopts the role of its best-overlapping detection only when the
    # IoU is above this value; otherwise it is labelled "unknown".
    ROLE_IOU_THRESHOLD = 0.1
    # Seconds a prisoner that entered "entry" may stay unseen (without having
    # finished the route) before it is reported as a route violation.
    VIOLATION_TIMEOUT_SEC = 2.0
    # Seconds after which bookkeeping of resolved/irrelevant tracks is
    # discarded so per-track dicts do not grow for the lifetime of the stream.
    STALE_TRACK_SEC = 10.0

    def __init__(self):
        # Detection model (police / prisoner).
        self.police_prisoner_detector = YOLOv8_ONNX(detector_model_path, conf_threshold=0.5, iou_threshold=0.45,
                                                    input_size=input_size)

        # ByteTrack options, passed to BYTETracker as an attribute bag.
        class TrackerArgs:
            track_thresh = 0.25
            track_buffer = 30
            match_thresh = 0.8
            mot20 = False

        # track_id -> "police" / "prisoner" / "unknown"
        self.police_prisoner_track_role = {}
        # track_id -> timestamp the tracker last returned this track
        # (lets stale role entries be pruned).
        self.police_prisoner_role_last_seen = {}
        self.fps = RTSP_TARGET_FPS
        self.tracker = BYTETracker(TrackerArgs(), frame_rate=self.fps)

        # ==========================================
        # Hyperparameters (seconds)
        # ==========================================
        self.TIME_THRESHOLD_POLICE = 1.0  # accumulated police presence needed to set police_alert_active
        self.TIME_TOLERANCE_POLICE = 0.5  # police absence tolerated before the state resets (debounce)
        self.TIME_THRESHOLD_PRISONER = 1.0  # accumulated prisoner presence needed to raise the prisoner alert
        self.TIME_TOLERANCE_PRISONER = 1.0  # prisoner absence tolerated before the alert resets (debounce)
        # The same thresholds expressed in processed frames.
        self.frame_thresh_police = int(self.TIME_THRESHOLD_POLICE * self.fps)
        self.frame_buffer_police = int(self.TIME_TOLERANCE_POLICE * self.fps)
        self.frame_thresh_prisoner = int(self.TIME_THRESHOLD_PRISONER * self.fps)
        self.frame_buffer_prisoner = int(self.TIME_TOLERANCE_PRISONER * self.fps)
        print(f"\n超参数设置:")
        print(f" FPS: {self.fps:.2f}")
        print(f" 判定 'police Detected' 需累计检测: {self.frame_thresh_police}")
        print(f" 警察丢失缓冲帧数: {self.frame_buffer_police}")
        print(f" 判定 'prisoner Detected' 需累计检测: {self.frame_thresh_prisoner}")
        print(f" 犯人丢失缓冲帧数: {self.frame_buffer_prisoner}")

        # ==========================================
        # Presence (debounce) state
        # ==========================================
        self.current_frame_idx = 0
        # Police: frames with police seen since the last reset, consecutive
        # frames without police, and the debounced state. NOTE(review):
        # police_alert_active is maintained but no alert is emitted from it.
        self.police_detection_frames = 0
        self.police_missing_frames = 0
        self.police_alert_active = False
        # Prisoner: same bookkeeping; the active state raises the "prisoner" alert.
        self.prisoner_detection_frames = 0
        self.prisoner_missing_frames = 0
        self.prisoner_alert_active = False

        # =========================
        # Route ROIs (must be entered in this order)
        # =========================
        # Polygons use relative (x, y) coordinates in [0, 1], i.e. fractions
        # of the frame width/height. Example: pixel (50, 100) in a 960x480
        # frame -> (50/960, 100/480) ~= (0.052, 0.208).
        self.route_rois = [
            {
                "name": "entry",
                "polygon_rel": [(0.4, 0.05), (0.6, 0.05), (0.6, 0.35), (0.4, 0.35)]
            },
            {
                "name": "corridor",
                "polygon_rel": [(0.4, 0.4), (0.6, 0.4), (0.6, 0.6), (0.4, 0.6)]
            },
            {
                "name": "exit",  # reaching this ROI completes the route
                "polygon_rel": [(0.55, 0.3), (0.75, 0.3), (0.75, 0.7), (0.55, 0.7)]
            }
        ]
        # Current frame size, refreshed by process_frame.
        self.width = 0
        self.height = 0
        print(f"相对坐标 ROI: {self.route_rois}")
        # track_id -> route state dict (see _update_prisoner_route).
        self.prisoner_route_state = {}
        # track_id -> the same state dict objects; scanned by the violation check.
        self.all_prisoner_tracks = {}
        # Track ids that already raised a violation (prevents duplicates).
        self.violated_tracks = set()

    def _get_abs_polygon(self, rel_polygon):
        """Convert a relative (0-1) polygon into integer pixel coordinates of the current frame."""
        return [
            (int(x * self.width), int(y * self.height))
            for x, y in rel_polygon
        ]

    def compute_iou(self, boxA, boxB):
        """Return the IoU of two ``[x1, y1, x2, y2]`` boxes (0.0 if the union area is 0)."""
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        # Bottom edge of the intersection must be clipped against BOTH boxes.
        yB = min(boxA[3], boxB[3])
        interW = max(0, xB - xA)
        interH = max(0, yB - yA)
        interArea = interW * interH
        boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
        boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
        unionArea = boxAArea + boxBArea - interArea
        if unionArea == 0:
            return 0.0
        return interArea / unionArea

    def draw_alert(self, frame, text, color=(0, 0, 255), sub_text=None, offset_y=0):
        """Draw ``text`` on a filled black box near the top-right corner of ``frame``.

        ``sub_text`` (if truthy) is drawn 40 px below the main text, and
        ``offset_y`` shifts the banner down so several banners can be stacked.
        """
        font_scale = 1.5
        thickness = 3
        font = cv2.FONT_HERSHEY_SIMPLEX
        (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
        x = self.width - text_w - 20
        y = 50 + text_h + offset_y
        cv2.rectangle(frame, (x - 10, y - text_h - 10), (x + text_w + 10, y + 10), (0, 0, 0), -1)
        cv2.putText(frame, text, (x, y), font, font_scale, color, thickness)
        if sub_text:
            cv2.putText(frame, sub_text, (x, y + 40), font, 0.7, (200, 200, 200), 2)

    def _point_in_polygon(self, point, polygon):
        """Return True if ``point`` lies inside or on the edge of ``polygon`` (absolute pixels)."""
        return cv2.pointPolygonTest(
            np.array(polygon, dtype=np.int32),
            point,
            False
        ) >= 0

    def _draw_route_rois(self, frame):
        """Outline every route ROI on ``frame`` and label it ``"<index>:<name>"`` (1-based)."""
        for idx, roi in enumerate(self.route_rois):
            abs_polygon = self._get_abs_polygon(roi["polygon_rel"])
            pts = np.array(abs_polygon, np.int32).reshape((-1, 1, 2))
            cv2.polylines(
                frame,
                [pts],
                isClosed=True,
                color=(0, 255, 255),
                thickness=2
            )
            # Label just above the first vertex.
            text_pos = abs_polygon[0]
            cv2.putText(
                frame,
                f"{idx + 1}:{roi['name']}",
                (text_pos[0], text_pos[1] - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.7,
                (0, 255, 255),
                2
            )

    def _update_prisoner_route(self, tid, point, timestamp):
        """Advance the route state machine of prisoner track ``tid``.

        The ROIs in ``route_rois`` must be entered strictly in order: ``point``
        (absolute pixels) only counts when it lies in the ROI of the current
        stage. Once a track has finished or been flagged as a violation, the
        next update for it deletes its state, so a later sighting of the same
        id starts the route again from stage 0.
        """
        if tid not in self.prisoner_route_state:
            self.prisoner_route_state[tid] = {
                "stage": 0,                  # index of the ROI that must be entered next
                "finished": False,           # all ROIs entered in order
                "finished_reported": False,  # "finished" alert already emitted
                "violation": False,          # flagged as a route violation
                "entered_entry": False,      # entered the first ROI ("entry")
                "last_seen": timestamp       # time of the latest update
            }
            # Same dict object, also visible to the violation check.
            self.all_prisoner_tracks[tid] = self.prisoner_route_state[tid]
        state = self.prisoner_route_state[tid]
        state["last_seen"] = timestamp

        # Resolved track: forget it (including its violation mark) and stop.
        if state["finished"] or state["violation"]:
            self.prisoner_route_state.pop(tid, None)
            self.all_prisoner_tracks.pop(tid, None)
            self.violated_tracks.discard(tid)
            return

        current_stage = state["stage"]
        if current_stage >= len(self.route_rois):
            state["finished"] = True
            return

        current_roi_abs = self._get_abs_polygon(self.route_rois[current_stage]["polygon_rel"])
        if self._point_in_polygon(point, current_roi_abs):
            if current_stage == 0:
                state["entered_entry"] = True
            state["stage"] += 1
            # Entering the last ROI completes the route.
            if state["stage"] == len(self.route_rois):
                state["finished"] = True

    def _check_prisoner_violation(self, current_time):
        """Flag prisoners that entered the route but vanished before finishing it.

        A track is reported (once) when it entered the "entry" ROI, has not
        finished the route, and has been unseen for more than
        ``VIOLATION_TIMEOUT_SEC`` seconds.

        Returns:
            list of alert dicts with action ``"violation"``.
        """
        violations = []
        for tid, state in list(self.all_prisoner_tracks.items()):
            if state["finished"] or state["violation"] or not state["entered_entry"]:
                continue
            if current_time - state["last_seen"] > self.VIOLATION_TIMEOUT_SEC and tid not in self.violated_tracks:
                state["violation"] = True
                self.violated_tracks.add(tid)
                violations.append({
                    'time': current_time,
                    'action': 'violation',
                    'confidence': 1.0,
                    'details': f""
                })
        return violations

    def _prune_stale_tracks(self, current_time):
        """Forget per-track bookkeeping that can no longer produce alerts.

        Route state is dropped for tracks unseen for more than
        ``STALE_TRACK_SEC`` seconds that are finished, already flagged, or
        never entered the entry ROI (tracks still awaiting the violation check
        are kept). Role entries unseen for that long are dropped as well.
        """
        for tid, state in list(self.all_prisoner_tracks.items()):
            if current_time - state["last_seen"] <= self.STALE_TRACK_SEC:
                continue
            if state["finished"] or state["violation"] or not state["entered_entry"]:
                self.prisoner_route_state.pop(tid, None)
                self.all_prisoner_tracks.pop(tid, None)
                self.violated_tracks.discard(tid)
        for tid, last_seen in list(self.police_prisoner_role_last_seen.items()):
            if current_time - last_seen > self.STALE_TRACK_SEC:
                del self.police_prisoner_role_last_seen[tid]
                self.police_prisoner_track_role.pop(tid, None)

    def _collect_detections(self, results):
        """Split raw detector output into parallel box / role / tracker-input lists.

        Each detection is ``(x1, y1, x2, y2, conf, cls_id)`` with (x1, y1) the
        top-left and (x2, y2) the bottom-right corner. Every detection yields
        exactly one entry in each list, so indices stay aligned; class ids
        other than 0/1 map to role ``"unknown"``.
        """
        dets_xyxy, dets_roles, dets_for_tracker = [], [], []
        if results is None:
            return dets_xyxy, dets_roles, dets_for_tracker
        for det in results:
            x1, y1, x2, y2, conf, cls_id = det
            dets_xyxy.append([x1, y1, x2, y2])
            dets_for_tracker.append([x1, y1, x2, y2, conf])
            if cls_id == 0:
                dets_roles.append("police")
            elif cls_id == 1:
                dets_roles.append("prisoner")
            else:
                dets_roles.append("unknown")
        return dets_xyxy, dets_roles, dets_for_tracker

    def _match_role(self, track_box, dets_xyxy, dets_roles):
        """Return the role of the detection overlapping ``track_box`` most (or ``"unknown"``)."""
        best_iou = 0
        best_role = "unknown"
        for box, role in zip(dets_xyxy, dets_roles):
            iou_val = self.compute_iou(track_box, box)
            if iou_val > best_iou:
                best_iou, best_role = iou_val, role
        return best_role if best_iou > self.ROLE_IOU_THRESHOLD else "unknown"

    @staticmethod
    def _update_presence(present, detected, missing, active, thresh, buffer):
        """One debounce step; returns the new ``(detected, missing, active)``.

        ``detected`` accumulates frames with the object present and sets
        ``active`` once it reaches ``thresh``; after ``buffer`` consecutive
        absent frames the accumulation and ``active`` are reset.
        """
        if present:
            detected += 1
            missing = 0
            if detected >= thresh:
                active = True
        else:
            missing += 1
            if detected > 0 and missing >= buffer:
                detected = 0
                active = False
        return detected, missing, active

    def process_frame(self, frame, camera_id: int, timestamp: float) -> Dict[str, Any]:
        """Run detection, tracking and the route checks on one frame.

        Args:
            frame: image array (``frame.shape[:2]`` is (height, width)); boxes,
                ROIs and alert banners are drawn onto it in place.
            camera_id: source camera id (not used by this method).
            timestamp: frame time in seconds, used for route timing and the
                violation timeout.

        Returns:
            ``{"image": frame, "alerts": [...]}`` where ``alerts`` holds the
            alert dicts (``time``, ``action``, ``confidence``, ``details``)
            raised on this frame; actions are "violation", "prisoner" and
            "finished".
        """
        h, w = frame.shape[:2]
        self.width, self.height = w, h  # ROI conversion and drawing use the current size
        self.current_frame_idx += 1
        current_time_sec = timestamp
        # Alerts raised on this frame only.
        current_frame_alerts = []

        # ========= detection =========
        police_prisoner_results = self.police_prisoner_detector(frame)
        dets_xyxy, dets_roles, dets_for_tracker = self._collect_detections(police_prisoner_results)

        # ========= tracking =========
        tracker_input = (np.array(dets_for_tracker, dtype=np.float32)
                         if dets_for_tracker else np.empty((0, 5)))
        tracks = self.tracker.update(
            tracker_input,
            [self.height, self.width],
            [self.height, self.width]
        )

        current_police_count = 0
        current_prisoner_count = 0
        revalidate = self.current_frame_idx % self.REVALIDATE_FRAME_INTERVAL == 0
        for t in tracks:
            tid = t.track_id
            # New tracks are matched immediately, known ones periodically.
            if revalidate or tid not in self.police_prisoner_track_role:
                t_box = list(map(float, t.tlbr))  # [x1, y1, x2, y2]
                self.police_prisoner_track_role[tid] = self._match_role(t_box, dets_xyxy, dets_roles)
            self.police_prisoner_role_last_seen[tid] = current_time_sec
            role = self.police_prisoner_track_role.get(tid, "unknown")

            x1, y1, x2, y2 = map(int, t.tlbr)
            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
            if role == "police":
                current_police_count += 1
                color, label = (255, 0, 255), "police"
            elif role == "prisoner":
                current_prisoner_count += 1
                color, label = (0, 0, 139), "prisoner"
                # The box centre drives the route state machine.
                self._update_prisoner_route(
                    tid=tid,
                    point=(cx, cy),
                    timestamp=current_time_sec
                )
            else:
                color, label = (255, 255, 255), "Unknown"
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

        # Prisoners that entered "entry" but vanished before reaching "exit".
        violation_alerts = self._check_prisoner_violation(current_time_sec)
        current_frame_alerts.extend(violation_alerts)
        # Keep per-frame bookkeeping bounded over long-running streams.
        self._prune_stale_tracks(current_time_sec)

        # ========= debounced presence states =========
        (self.prisoner_detection_frames,
         self.prisoner_missing_frames,
         self.prisoner_alert_active) = self._update_presence(
            current_prisoner_count > 0,
            self.prisoner_detection_frames,
            self.prisoner_missing_frames,
            self.prisoner_alert_active,
            self.frame_thresh_prisoner,
            self.frame_buffer_prisoner,
        )
        (self.police_detection_frames,
         self.police_missing_frames,
         self.police_alert_active) = self._update_presence(
            current_police_count > 0,
            self.police_detection_frames,
            self.police_missing_frames,
            self.police_alert_active,
            self.frame_thresh_police,
            self.frame_buffer_police,
        )

        alert_offset = 0
        # A. prisoner present
        if self.prisoner_alert_active:
            duration_seconds = self.prisoner_detection_frames / self.fps
            current_frame_alerts.append(
                {
                    'time': current_time_sec,
                    'action': 'prisoner',
                    'confidence': 1.0,
                    'details': f"Detected for {duration_seconds:.1f}s"
                }
            )
            self.draw_alert(frame, "prisoner", (0, 0, 255), offset_y=alert_offset)
            alert_offset += 100

        # B. route completed: report each finished track exactly once, even if
        # it disappears right after finishing and its state is never revisited.
        for state in self.prisoner_route_state.values():
            if state["finished"] and not state.get("finished_reported", False):
                state["finished_reported"] = True
                current_frame_alerts.append({
                    "time": current_time_sec,
                    "action": "finished",
                    "confidence": 1.0,
                    "details": ""
                })
                self.draw_alert(frame, "finished", (0, 255, 0), offset_y=alert_offset)
                alert_offset += 100

        # C. route violations raised on this frame
        for violation in violation_alerts:
            self.draw_alert(frame, "ROUTE VIOLATION!", (0, 0, 255),
                            sub_text=violation['details'], offset_y=alert_offset)
            alert_offset += 100

        # Route ROIs are always drawn.
        self._draw_route_rois(frame)
        return {
            "image": frame,
            "alerts": current_frame_alerts
        }
# ========================= 帧处理线程 =========================
class FrameProcessorWorker(threading.Thread):
    """Background thread turning raw camera frames into WebSocket messages.

    Items taken from ``raw_frame_queue`` are dicts with ``camera_id``,
    ``timestamp`` and ``frame``. Frames are sampled down to
    ``RTSP_TARGET_FPS`` per camera, run through a per-camera
    :class:`TrajectoryDetector`, and the annotated JPEG (base64) together with
    the throttled alert actions is put on ``ws_send_queue``.
    """

    def __init__(self,
                 raw_frame_queue: "queue.Queue[Dict[str, Any]]",
                 ws_send_queue: "queue.Queue[Dict[str, Any]]",
                 stop_event: threading.Event):
        super().__init__(daemon=True)
        self.raw_queue = raw_frame_queue
        self.ws_queue = ws_send_queue
        self.stop_event = stop_event
        # camera_id -> timestamp of the last processed frame (frame sampling).
        self.last_ts: Dict[int, float] = {}
        # One independent detector (with its own tracker state) per camera, created lazily.
        self.trajectory_detectors: Dict[int, TrajectoryDetector] = {}
        # camera_id -> {action: wall-clock time the action was last pushed}
        self.last_alert_push_time: Dict[int, Dict[str, float]] = {}
        # Minimum seconds between two pushes of the same action for one camera.
        self.alert_push_interval = ALERT_PUSH_INTERVAL

    def _encode_image_to_base64(self, image) -> str:
        """JPEG-encode ``image`` and return it as an ASCII base64 string.

        Raises:
            RuntimeError: if OpenCV fails to encode the image.
        """
        ok, buf = cv2.imencode(".jpg", image)
        if not ok:
            raise RuntimeError("Failed to encode image to JPEG")
        return base64.b64encode(buf.tobytes()).decode("ascii")

    def _select_push_actions(self, cam_id, alerts, now) -> List[str]:
        """Return the alert actions of ``cam_id`` that may be pushed at time ``now``.

        An action passes when it was not pushed for this camera during the
        last ``alert_push_interval`` seconds; its push time is then updated,
        so duplicates within the same frame are dropped too.
        """
        last_push_by_action = self.last_alert_push_time.setdefault(cam_id, {})
        push_actions = []
        for alert in alerts:
            action = alert['action']
            if now - last_push_by_action.get(action, 0) >= self.alert_push_interval:
                push_actions.append(action)
                last_push_by_action[action] = now
        return push_actions

    def _build_frame_message(self, cam_id, ts, push_actions, img_b64) -> Dict[str, Any]:
        """Build the "frame" message sent over the WebSocket for one processed frame."""
        return {
            "msg_type": "frame",
            "camera_id": cam_id,
            "timestamp": ts,
            "result_type": push_actions,
            "image_base64": img_b64,
        }

    def _process_item(self, item, target_interval):
        """Sample, detect, throttle alerts and enqueue the result for one queue item."""
        cam_id = item["camera_id"]
        ts = item["timestamp"]
        frame = item["frame"]

        # Frame sampling: skip frames closer than target_interval to the last processed one.
        if ts - self.last_ts.get(cam_id, 0) < target_interval:
            return
        self.last_ts[cam_id] = ts

        detector = self.trajectory_detectors.get(cam_id)
        if detector is None:
            detector = self.trajectory_detectors[cam_id] = TrajectoryDetector()

        # The detector draws on its input, so hand it a copy.
        result = detector.process_frame(frame.copy(), cam_id, ts)
        push_actions = self._select_push_actions(cam_id, result["alerts"], time.time())

        try:
            img_b64 = self._encode_image_to_base64(result["image"])
        except Exception as e:
            print(f"[ERROR] Encode image failed: {e}")
            return

        msg = self._build_frame_message(cam_id, ts, push_actions, img_b64)
        try:
            self.ws_queue.put(msg, timeout=1.0)
        except queue.Full:
            print("[WARN] ws_send_queue full, drop frame message")

    def run(self):
        """Consume ``raw_queue`` until ``stop_event`` is set.

        ``task_done()`` is called exactly once per item taken from the queue,
        and a failure while processing one frame is logged instead of
        terminating the worker thread.
        """
        import traceback  # only needed on the error path

        target_interval = 1.0 / RTSP_TARGET_FPS
        while not self.stop_event.is_set():
            try:
                item = self.raw_queue.get(timeout=0.5)
            except queue.Empty:
                continue
            try:
                self._process_item(item, target_interval)
            except Exception as e:  # worker boundary: keep processing later frames
                print(f"[ERROR] Frame processing failed: {e}")
                traceback.print_exc()
            finally:
                self.raw_queue.task_done()