Files
SupervisorAI/rtsp_service_ws_kadian.py
2026-02-02 11:32:59 +08:00

1071 lines
41 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# rtsp_service_ws_kadian.py
# 融合 Kadian_Detect_1221.py + rtsp_service_ws.py
# 支持多路RTSP、抽帧、WebSocket推送图像与告警（注：本文件未实现分段保存MP4）
import cv2
import numpy as np
import os
import time
import threading
import queue
import yaml
import json
import base64
import asyncio
import websockets
from dataclasses import dataclass
from typing import Dict, Any
from test_cam import get_camera_preview_url
# -------------------------- Kadian 检测相关导入 --------------------------
from algorithm.checkpoint.npu_yolo_onnx_person_car_phone import YOLOv8_ONNX # 主检测模型(人/车/后备箱/手机)
from algorithm.checkpoint.npu_yolo_pose_onnx import YOLOv8_Pose_ONNX # Pose 专用模型
from yolox.tracker.byte_tracker import BYTETracker
# ========================= Configuration =========================
# Model weights loaded by KadianDetector (adjust paths/ROI per deployment).
DETECT_MODEL_PATH = 'YOLO_Weight/car_opentrunk_person_phone.onnx'  # car / open trunk / person / phone
POSE_MODEL_PATH = 'YOLO_Weight/yolov8l-pose.onnx'  # pose (keypoint) model

# Region of interest: (x, y) vertices relative to frame width / height.
# Previous calibration, kept for reference:
#   [[0.10989583333333333, 0.006481481481481481],
#    [0.421875, 0.005555555555555556],
#    [0.9921875, 0.9888888888888889],
#    [0.3411458333333333, 0.9861111111111112]]
ROI_RELATIVE = np.array(
    [
        (0.15, 0.001),
        (0.50, 0.001),
        (1.00, 0.80),
        (0.35, 1.00),
    ],
    dtype=np.float64,
)

# Minimum seconds between two pushes of the same alert action for one camera.
ALERT_PUSH_INTERVAL = 5.0

# input_size passed to the detection and pose models.
PERSON_CAR_INPUT_SIZE = POSE_INPUT_SIZE = 640

# RTSP / WebSocket service settings.
RTSP_TARGET_FPS = 10.0  # frames analysed per second for each camera
WS_HOST = "0.0.0.0"
WS_PORT, WS_PORT_2 = 8765, 8764  # primary server (every frame) / secondary server (frames with alerts)

# Currently connected clients of each WebSocket server.
ws_clients, ws_clients_2 = set(), set()
# ========================= Data structures =========================
@dataclass(init=True, repr=True, eq=True)
class CameraConfig:
    """One entry of the ``cameras`` list in the service's YAML config."""

    id: int  # numeric camera id, used as the per-camera key when processing frames
    name: str  # human-readable label used in log messages
    index: str  # video id passed to get_camera_preview_url() to fetch a stream URL
    rtsp_url: str  # stream URL as listed in the config file
# ========================= Kadian TrafficMonitor (trimmed version for the service) =========================
class KadianDetector:
    """Checkpoint ("Kadian") rule engine for a single camera stream.

    Each call to :meth:`process_frame` runs the main detector
    (car / open trunk / person / phone) and the pose model, tracks the main
    detections with ByteTrack, updates the per-stream counters and car
    registry, and returns the annotated frame plus the alerts raised for it.
    Because the instance keeps temporal state, use one instance per camera.
    """

    # Class ids of the main detection model and the role names used below.
    _ROLE_BY_CLS_ID = {0: "car", 1: "opentrunk", 2: "person", 3: "phone"}
    _CLS_ID_BY_ROLE = {"car": 0, "opentrunk": 1, "person": 2, "phone": 3}
    # Every N-th frame, the roles of all tracks are re-derived from detection IoU.
    _ROLE_REVALIDATE_INTERVAL = 10

    def __init__(self, roi_points=ROI_RELATIVE):
        """Load both ONNX models and initialise tracker, thresholds and state.

        Args:
            roi_points: ROI polygon as an (N, 2) array-like. If its largest
                value is <= 1.0 it is treated as relative to the frame size,
                otherwise as absolute pixels. ``None`` selects a built-in
                default polygon (see ``_get_roi_points``).
        """
        # --- Models ---
        self.detector = YOLOv8_ONNX(DETECT_MODEL_PATH, conf_threshold=0.25, iou_threshold=0.45,
                                    input_size=PERSON_CAR_INPUT_SIZE)
        self.pose_detector = YOLOv8_Pose_ONNX(POSE_MODEL_PATH, conf_threshold=0.7, iou_threshold=0.6,
                                              input_size=POSE_INPUT_SIZE)

        # Frame rate in which all thresholds below are expressed.
        self.fps = RTSP_TARGET_FPS

        # --- Tracker ---
        class TrackerArgs:
            # Attribute bag passed to BYTETracker as its ``args``.
            track_thresh = 0.25
            track_buffer = 30
            match_thresh = 0.8
            mot20 = False

        # Same rate as the business thresholds (previously a separate hard-coded 10.0).
        self.tracker = BYTETracker(TrackerArgs(), frame_rate=self.fps)
        self.track_role = {}  # track_id -> "car" / "opentrunk" / "person" / "phone" / "unknown"
        # track_id -> frame index it was last returned by the tracker. Used to
        # prune track_role so it does not grow forever on a long-running stream.
        # A pruned id that reappears is simply re-classified like a new track.
        self._track_last_seen = {}
        self._track_role_ttl_frames = max(TrackerArgs.track_buffer, int(10 * self.fps))

        # --- ROI (relative or absolute; converted per frame) ---
        self.roi_points = np.array(roi_points, dtype=np.float64) if roi_points is not None else None

        # ==========================================
        # Hyperparameters (seconds)
        # ==========================================
        # 1. Business-rule durations
        self.TIME_THRESHOLD_ONLY_ONE = 3.0  # single person present this long -> "Only One"
        self.TIME_THRESHOLD_NOBODY = 2.0  # nobody present this long -> "Nobody"
        self.TIME_THRESHOLD_TRUNK_OPEN = 0.5  # accumulated open-trunk time for a car to count as checked
        self.TIME_THRESHOLD_PHONE = 1.0  # phone detected this long -> "Playing Phone"
        self.TIME_TOLERANCE_PHONE = 0.5  # phone may be missed this long before the alert resets (debounce)
        self.TIME_THRESHOLD_UNIFORM = 1.0  # uniform mismatch this long -> uniform alert
        self.TIME_TOLERANCE_UNIFORM = 0.5  # compliant this long before the uniform alert clears
        # Cars present for less than this are treated as "passed too fast" (Nobody checked).
        self.TIME_THRESHOLD_CAR_MIN_DURATION = 3.0
        # 2. Person may be missed this long without breaking the "Only One" streak.
        self.TIME_TOLERANCE_PERSON = 1.0
        # 3. Car may be missed this long before it is considered gone.
        self.TIME_TOLERANCE_CAR = 0.5

        # --- Same thresholds converted to frame counts ---
        self.frame_thresh_one = int(self.TIME_THRESHOLD_ONLY_ONE * self.fps)
        self.frame_thresh_nobody = int(self.TIME_THRESHOLD_NOBODY * self.fps)
        self.frame_thresh_trunk_valid = int(self.TIME_THRESHOLD_TRUNK_OPEN * self.fps)
        self.frame_thresh_phone = int(self.TIME_THRESHOLD_PHONE * self.fps)
        self.frame_buffer_phone = int(self.TIME_TOLERANCE_PHONE * self.fps)
        self.frame_thresh_uniform = int(self.TIME_THRESHOLD_UNIFORM * self.fps)
        self.frame_buffer_uniform = int(self.TIME_TOLERANCE_UNIFORM * self.fps)
        self.frame_thresh_car_min_duration = int(self.TIME_THRESHOLD_CAR_MIN_DURATION * self.fps)
        self.frame_buffer_limit_person = int(self.TIME_TOLERANCE_PERSON * self.fps)
        self.frame_buffer_limit_car = int(self.TIME_TOLERANCE_CAR * self.fps)

        print(f"\n超参数设置:")
        print(f" FPS: {self.fps:.2f}")
        print(f" 判定 'Only One' / 'Nobody' 需连续: {self.frame_thresh_one}")
        print(f" 判定 'Trunk Checked' 需累计检测: {self.frame_thresh_trunk_valid}")
        print(f" 判定 'Phone Detected' 需累计检测: {self.frame_thresh_phone}")
        print(f" 手机丢失缓冲帧数: {self.frame_buffer_phone}")
        print(f" 判定 'Uniform Invalid' 需连续检测: {self.frame_thresh_uniform}")
        print(f" 制服合规恢复缓冲帧数: {self.frame_buffer_uniform}")
        print(f" 判定 'Too Fast' (视为Nobody) 最小停留: {self.frame_thresh_car_min_duration}")

        # --- Per-stream state ---
        self.current_frame_idx = 0
        self.cnt_frame_one_person = 0
        self.cnt_frame_nobody = 0
        self.cnt_missing_buffer_person = 0
        # Phone detection (independent of cars)
        self.phone_detection_frames = 0  # frames with a phone (kept through short gaps)
        self.phone_missing_frames = 0  # consecutive frames without a phone
        self.phone_alert_active = False
        # Uniform compliance
        self.pose_person_count = 0  # people inside the ROI according to the pose model
        self.uniform_alert_active = False
        self.uniform_detection_frames = 0  # frames where pose count exceeds detector person count
        self.uniform_recovery_frames = 0  # consecutive compliant frames
        # Car registry: track_id -> {'first_seen', 'last_seen', 'trunk_frames', 'is_checked'}
        self.roi_car_registry = {}
        # Post-departure alerts: track_id -> frame index until which the alert is shown
        self.unchecked_trunk_alerts = {}  # car left without its trunk being checked
        self.fast_pass_alerts = {}  # car passed too fast (treated as "Nobody")

    def _get_roi_points(self, frame_width: int, frame_height: int):
        """Return the ROI polygon in absolute pixels as an int32 (N, 2) array.

        Relative coordinates (largest value <= 1.0) are scaled by the frame
        size, absolute ones are copied, and ``None`` uses a built-in default
        polygon. int32 is what cv2.pointPolygonTest / cv2.polylines expect.
        """
        if self.roi_points is None:
            default_rel = np.array([
                [0.15, 0.01],
                [0.45, 0.01],
                [0.95, 0.95],
                [0.35, 0.95]
            ], dtype=np.float64)
            roi_abs = default_rel * np.array([frame_width, frame_height])
        elif self.roi_points.max() <= 1.0:
            roi_abs = self.roi_points * np.array([frame_width, frame_height])
        else:
            roi_abs = self.roi_points.copy()
        return roi_abs.astype(np.int32)

    def check_point_in_roi(self, roi_points, point):
        """Return True if *point* lies inside or on the edge of polygon *roi_points*."""
        return cv2.pointPolygonTest(roi_points, point, False) >= 0

    def compute_iou(self, boxA, boxB):
        """Return the IoU of two ``[x1, y1, x2, y2]`` boxes (0.0 if the union area is 0)."""
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        yB = min(boxA[3], boxB[3])
        interArea = max(0, xB - xA) * max(0, yB - yA)
        boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
        boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
        unionArea = boxAArea + boxBArea - interArea
        if unionArea == 0:
            return 0.0
        return interArea / unionArea

    def draw_alert(self, frame, text, color=(0, 0, 255), sub_text=None, offset_y=0):
        """Draw *text* on a black box in the top-right corner of *frame*.

        ``offset_y`` shifts the box down so several alerts can be stacked.
        Uses ``self.width``, which is set by :meth:`process_frame`.
        """
        font_scale = 1.5
        thickness = 3
        font = cv2.FONT_HERSHEY_SIMPLEX
        (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
        x = self.width - text_w - 20
        y = 50 + text_h + offset_y
        cv2.rectangle(frame, (x - 10, y - text_h - 10), (x + text_w + 10, y + 10), (0, 0, 0), -1)
        cv2.putText(frame, text, (x, y), font, font_scale, color, thickness)
        if sub_text:
            cv2.putText(frame, sub_text, (x, y + 40), font, 0.7, (200, 200, 200), 2)

    def is_point_in_box(self, point, box):
        """Return True if *point* is strictly inside ``[x1, y1, x2, y2]`` *box*."""
        px, py = point
        x1, y1, x2, y2 = box
        return x1 < px < x2 and y1 < py < y2

    def process_frame(self, frame, camera_id: int, timestamp: float) -> Dict[str, Any]:
        """Analyse one frame, draw overlays on it and collect the alerts it raises.

        Args:
            frame: image array from OpenCV (``shape[:2]`` is height, width);
                overlays are drawn onto it.
            camera_id: id of the source camera (not used by the rules here).
            timestamp: capture time, copied into every alert's ``'time'``.

        Returns:
            ``{"image": annotated_frame, "alerts": [{"time": ..., "action": ...}, ...]}``.
            The same action may appear more than once (e.g. "Nobody").
        """
        h, w = frame.shape[:2]
        self.width, self.height = w, h
        self.current_frame_idx += 1

        # ROI in absolute int32 pixels for this frame size, plus the (N, 1, 2)
        # layout used by cv2.polylines.
        roi_points_int32 = self._get_roi_points(w, h)
        roi_points_draw = roi_points_int32.reshape((-1, 1, 2))
        current_time_sec = timestamp

        # ---------- 1. Inference ----------
        pose_results = self.pose_detector(frame)
        detections = self.detector(frame)

        # Index-aligned lists: exactly one entry per detection in each.
        dets_xyxy = []
        dets_roles = []
        dets_for_tracker = []
        current_frame_alerts = []  # rebuilt every frame

        if detections:
            for det in detections:
                # (x1, y1) is the top-left corner, (x2, y2) the bottom-right one.
                x1, y1, x2, y2, conf, cls_id = det
                dets_xyxy.append([x1, y1, x2, y2])
                dets_for_tracker.append([x1, y1, x2, y2, conf])
                # Always append, so an unexpected class id cannot shift the
                # roles out of alignment with dets_xyxy.
                dets_roles.append(self._ROLE_BY_CLS_ID.get(cls_id, "unknown"))

        # ---------- 2. Tracking ----------
        dets = np.array(dets_for_tracker, dtype=np.float32) if dets_for_tracker else np.empty((0, 5))
        tracks = self.tracker.update(
            dets,
            [self.height, self.width],
            [self.height, self.width]
        )

        frame = YOLOv8_Pose_ONNX.draw_keypoints(frame, pose_results)
        cv2.polylines(frame, [roi_points_draw], isClosed=True, color=(255, 0, 0), thickness=3)

        # ---------- 3. Per-track role assignment, drawing and ROI counts ----------
        current_roi_person_count = 0
        current_roi_trunk_count = 0
        current_roi_phone_count = 0
        current_cars = []  # [{'id': track_id, 'box': [x1, y1, x2, y2]}]
        current_trunks = []  # [(cx, cy)] centres of open trunks inside the ROI

        for t in tracks:
            tid = t.track_id
            self._track_last_seen[tid] = self.current_frame_idx

            # New tracks, and every track periodically, take the role of the
            # best-overlapping detection (IoU must exceed 0.1).
            if (self.current_frame_idx % self._ROLE_REVALIDATE_INTERVAL == 0) or (tid not in self.track_role):
                best_iou = 0
                best_role = "unknown"
                t_box = list(map(float, t.tlbr))
                for box, det_role in zip(dets_xyxy, dets_roles):
                    iou_val = self.compute_iou(t_box, box)
                    if iou_val > best_iou:
                        best_iou = iou_val
                        best_role = det_role
                self.track_role[tid] = best_role if best_iou > 0.1 else "unknown"

            role = self.track_role.get(tid, "unknown")
            cls_id = self._CLS_ID_BY_ROLE.get(role, -1)

            x1, y1, x2, y2 = map(int, t.tlbr)
            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2

            if self.check_point_in_roi(roi_points_int32, (cx, cy)):
                if cls_id == 0:  # car
                    color = (0, 255, 0)
                    current_cars.append({'id': tid, 'box': [x1, y1, x2, y2]})
                    if tid not in self.roi_car_registry:
                        self.roi_car_registry[tid] = {
                            'first_seen': self.current_frame_idx,
                            'last_seen': self.current_frame_idx,
                            'trunk_frames': 0,
                            'is_checked': False,
                        }
                    else:
                        self.roi_car_registry[tid]['last_seen'] = self.current_frame_idx
                    label = f"Car:{tid} IN"
                elif cls_id == 1:  # open trunk
                    current_roi_trunk_count += 1
                    color = (255, 165, 0)
                    current_trunks.append((cx, cy))
                    label = "OpenTrunk IN"
                elif cls_id == 2:  # person
                    current_roi_person_count += 1
                    color = (255, 0, 255)
                    label = "Person IN"
                elif cls_id == 3:  # phone
                    current_roi_phone_count += 1
                    color = (0, 0, 139)
                    # Previously left as None, so phone boxes got no text label.
                    label = "Phone IN"
                else:
                    color = (255, 255, 255)
                    label = "Unknown"

                # Cars whose trunk has already been checked are drawn in cyan.
                if cls_id == 0 and tid in self.roi_car_registry and self.roi_car_registry[tid]['is_checked']:
                    color = (255, 255, 0)
                    label += " (Checked)"
            else:
                color = (0, 0, 255)
                label = "OUT"

            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

        # Forget roles of tracks that have not been returned for a long time.
        stale_ids = [tid for tid, last in self._track_last_seen.items()
                     if self.current_frame_idx - last > self._track_role_ttl_frames]
        for tid in stale_ids:
            del self._track_last_seen[tid]
            self.track_role.pop(tid, None)

        # ---------- 4. People inside the ROI according to the pose model ----------
        self.pose_person_count = 0
        if pose_results:
            for pose in pose_results:
                bbox = pose['bbox']
                x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
                cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
                if self.check_point_in_roi(roi_points_int32, (cx, cy)):
                    self.pose_person_count += 1

        # ---------- 5. Associate open trunks with cars ----------
        # A trunk belongs to a car when the trunk centre lies inside the car box.
        for car_info in current_cars:
            c_id = car_info['id']
            c_box = car_info['box']
            if any(self.is_point_in_box(t_pt, c_box) for t_pt in current_trunks):
                self.roi_car_registry[c_id]['trunk_frames'] += 1
                if self.roi_car_registry[c_id]['trunk_frames'] >= self.frame_thresh_trunk_valid:
                    self.roi_car_registry[c_id]['is_checked'] = True

        # ---------- 6. Phone detection (not tied to any car) ----------
        if current_roi_phone_count > 0:
            self.phone_detection_frames += 1
            self.phone_missing_frames = 0
            if self.phone_detection_frames >= self.frame_thresh_phone:
                self.phone_alert_active = True
        else:
            self.phone_missing_frames += 1
            # Reset only after the phone has been missing longer than the debounce buffer.
            if self.phone_detection_frames > 0 and self.phone_missing_frames >= self.frame_buffer_phone:
                self.phone_detection_frames = 0
                self.phone_alert_active = False

        # ---------- 7. Uniform check (pose count vs. detector person count) ----------
        # The detector's "person" class is compared with every person the pose
        # model sees; a surplus on the pose side is treated as someone not in uniform.
        if self.pose_person_count > current_roi_person_count:
            self.uniform_detection_frames += 1
            self.uniform_recovery_frames = 0
            if self.uniform_detection_frames >= self.frame_thresh_uniform:
                self.uniform_alert_active = True
        else:
            self.uniform_recovery_frames += 1
            if self.uniform_detection_frames > 0 and self.uniform_recovery_frames >= self.frame_buffer_uniform:
                self.uniform_detection_frames = 0
                self.uniform_alert_active = False

        # ---------- 8. Car registry upkeep & departure alerts ----------
        active_car_ids = []
        cars_to_remove = []
        for car_id, info in self.roi_car_registry.items():
            if (self.current_frame_idx - info['last_seen']) <= self.frame_buffer_limit_car:
                active_car_ids.append(car_id)
            else:
                cars_to_remove.append(car_id)

        for car_id in cars_to_remove:
            car_info = self.roi_car_registry[car_id]
            duration_frames = car_info['last_seen'] - car_info['first_seen']
            if duration_frames < self.frame_thresh_car_min_duration:
                # Stayed too briefly -> regarded as nobody having checked it.
                print(f"ALARM: Car {car_id} passed too fast -> Regarded as Nobody Checked!")
                self.fast_pass_alerts[car_id] = self.current_frame_idx + int(3.0 * self.fps)
            elif not car_info['is_checked']:
                print(f"ALARM: Car {car_id} left without checking trunk!")
                self.unchecked_trunk_alerts[car_id] = self.current_frame_idx + int(3.0 * self.fps)
            del self.roi_car_registry[car_id]

        effective_car_count = len(active_car_ids)

        # ---------- 9. "Only One" / "Nobody" counters (only while a car is present) ----------
        status_text = ""
        if effective_car_count > 0:
            if current_roi_person_count == 1:
                self.cnt_frame_one_person += 1
                self.cnt_missing_buffer_person = 0
                self.cnt_frame_nobody = 0
            elif current_roi_person_count == 0:
                # A briefly missed single person keeps the "Only One" streak alive.
                if self.cnt_frame_one_person > 0 and self.cnt_missing_buffer_person < self.frame_buffer_limit_person:
                    self.cnt_frame_one_person += 1
                    self.cnt_missing_buffer_person += 1
                    self.cnt_frame_nobody = 0
                    status_text = f"Person Buffer ({self.cnt_missing_buffer_person}/{self.frame_buffer_limit_person})"
                else:
                    self.cnt_frame_one_person = 0
                    self.cnt_missing_buffer_person = 0
                    self.cnt_frame_nobody += 1
            else:
                self.cnt_frame_one_person = 0
                self.cnt_missing_buffer_person = 0
                self.cnt_frame_nobody = 0
        else:
            self.cnt_frame_one_person = 0
            self.cnt_missing_buffer_person = 0
            self.cnt_frame_nobody = 0

        # ---------- 10. Overlays and alerts ----------
        phone_status = f"Phone: {current_roi_phone_count}"
        if self.phone_alert_active:
            phone_status += " (ALERT)"
        uniform_status = f"Uniform: Pose={self.pose_person_count}, Model={current_roi_person_count}"
        if self.uniform_alert_active:
            uniform_status += " (INVALID!)"
        debug_info = f"Cars: {len(active_car_ids)} | Person: {current_roi_person_count} | Trunk: {current_roi_trunk_count} | {phone_status}"
        cv2.putText(frame, debug_info, (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        cv2.putText(frame, uniform_status, (20, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

        alert_offset = 0  # vertical stacking offset so alerts do not overlap

        def raise_alert(action, text, color, sub_text=None):
            """Record *action* for this frame and draw *text* below earlier alerts."""
            nonlocal alert_offset
            current_frame_alerts.append({'time': current_time_sec, 'action': action})
            self.draw_alert(frame, text, color, sub_text, offset_y=alert_offset)
            alert_offset += 100

        # --- Layer 1: real-time status ---
        if self.cnt_frame_one_person >= self.frame_thresh_one:
            raise_alert("Only One", "Only One", (0, 255, 255), status_text)
        elif self.cnt_frame_nobody >= self.frame_thresh_nobody:
            raise_alert("Nobody", "Nobody", (0, 0, 255))

        # Shown once if any car still present has had its trunk checked.
        if any(car_id in self.roi_car_registry and self.roi_car_registry[car_id]['is_checked']
               for car_id in active_car_ids):
            raise_alert("Trunk Checked", "Trunk Checked!!", (0, 255, 0))

        if self.phone_alert_active:
            duration_seconds = self.phone_detection_frames / self.fps
            raise_alert("Playing Phone", "Playing Phone", (255, 0, 0), f"Detected for {duration_seconds:.1f}s")

        if self.uniform_alert_active:
            diff = self.pose_person_count - current_roi_person_count
            # NOTE: "Unvaild" is misspelled, but the action string is pushed to
            # WebSocket clients, so it is kept as-is for compatibility.
            raise_alert("Unvaild Uniform!!", "Unvaild Uniform!!", (255, 165, 0), f"Missing {diff} uniform(s)")

        # --- Layer 2: post-departure alerts (shown for a few seconds) ---
        for cid in [cid for cid, end_frame in self.unchecked_trunk_alerts.items()
                    if self.current_frame_idx > end_frame]:
            del self.unchecked_trunk_alerts[cid]
        if self.unchecked_trunk_alerts:
            alert_text = f"Unchecked Trunk! (ID:{list(self.unchecked_trunk_alerts.keys())})"
            raise_alert("Unchecked Trunk", alert_text, (0, 0, 255))

        for cid in [cid for cid, end_frame in self.fast_pass_alerts.items()
                    if self.current_frame_idx > end_frame]:
            del self.fast_pass_alerts[cid]
        if self.fast_pass_alerts:
            alert_text = f"Nobody (ID:{list(self.fast_pass_alerts.keys())})"
            raise_alert("Nobody", alert_text, (0, 0, 255))

        return {
            "image": frame,
            "alerts": current_frame_alerts
        }
# ========================= WebSocket server thread =========================
class WebSocketSender(threading.Thread):
    """Background thread running the primary WebSocket server on WS_PORT.

    Every message taken from ``send_queue`` is JSON-encoded and sent to all
    clients currently registered in ``ws_clients``.
    """

    def __init__(self, send_queue: queue.Queue, stop_event: threading.Event):
        super().__init__(daemon=True)
        self.send_queue = send_queue
        self.stop_event = stop_event

    async def _ws_handler(self, websocket):
        """Keep *websocket* registered for broadcasts until it disconnects; incoming data is ignored."""
        ws_clients.add(websocket)
        try:
            async for _ in websocket:
                pass
        finally:
            ws_clients.discard(websocket)

    async def _broadcaster(self):
        """Forward queued messages to all clients until ``stop_event`` is set."""
        while not self.stop_event.is_set():
            try:
                # The blocking Queue.get runs in a worker thread so the event loop
                # keeps serving connections; the timeout lets us re-check stop_event.
                msg = await asyncio.to_thread(self.send_queue.get, timeout=0.5)
            except queue.Empty:
                continue
            try:
                await self._broadcast(msg)
            finally:
                # Always balance the get(), even if broadcasting failed.
                self.send_queue.task_done()

    async def _broadcast(self, msg):
        """Send *msg* as JSON to every client, dropping clients whose send fails."""
        try:
            data = json.dumps(msg)
        except (TypeError, ValueError) as e:
            print(f"[ERROR] Cannot serialize WebSocket message: {e}")
            return
        dead = []
        for ws in list(ws_clients):
            try:
                await ws.send(data)
            except Exception:
                # Typically a closed connection; do not swallow CancelledError.
                dead.append(ws)
        for ws in dead:
            ws_clients.discard(ws)

    async def _run_async(self):
        async with websockets.serve(self._ws_handler, WS_HOST, WS_PORT):
            print(f"[INFO] WebSocket server started at ws://{WS_HOST}:{WS_PORT}")
            await self._broadcaster()

    def run(self):
        asyncio.run(self._run_async())
# ========================= WebSocket server thread 2 =========================
class WebSocketSender2(threading.Thread):
    """Background thread running the secondary WebSocket server on WS_PORT_2.

    Every message taken from ``send_queue`` is JSON-encoded and sent to all
    clients currently registered in ``ws_clients_2``.
    """

    def __init__(self, send_queue: queue.Queue, stop_event: threading.Event):
        super().__init__(daemon=True)
        self.send_queue = send_queue
        self.stop_event = stop_event

    async def _ws_handler(self, websocket):
        """Keep *websocket* registered for broadcasts until it disconnects; incoming data is ignored."""
        ws_clients_2.add(websocket)
        try:
            async for _ in websocket:
                pass
        finally:
            ws_clients_2.discard(websocket)

    async def _broadcaster(self):
        """Forward queued messages to all clients until ``stop_event`` is set."""
        while not self.stop_event.is_set():
            try:
                # The blocking Queue.get runs in a worker thread so the event loop
                # keeps serving connections; the timeout lets us re-check stop_event.
                msg = await asyncio.to_thread(self.send_queue.get, timeout=0.5)
            except queue.Empty:
                continue
            try:
                await self._broadcast(msg)
            finally:
                # Always balance the get(), even if broadcasting failed.
                self.send_queue.task_done()

    async def _broadcast(self, msg):
        """Send *msg* as JSON to every client, dropping clients whose send fails."""
        try:
            data = json.dumps(msg)
        except (TypeError, ValueError) as e:
            print(f"[ERROR] Cannot serialize WebSocket message: {e}")
            return
        dead = []
        for ws in list(ws_clients_2):
            try:
                await ws.send(data)
            except Exception:
                # Typically a closed connection; do not swallow CancelledError.
                dead.append(ws)
        for ws in dead:
            ws_clients_2.discard(ws)

    async def _run_async(self):
        async with websockets.serve(self._ws_handler, WS_HOST, WS_PORT_2):
            print(f"[INFO] WebSocket server 2 started at ws://{WS_HOST}:{WS_PORT_2}")
            await self._broadcaster()

    def run(self):
        asyncio.run(self._run_async())
# ========================= RTSP capture thread =========================
class RTSPCaptureWorker(threading.Thread):
    """Capture thread for one camera.

    Fetches the stream URL through get_camera_preview_url() (by the camera's
    ``index``), opens it with OpenCV/FFmpeg and pushes every decoded frame
    into ``raw_queue``. It reconnects on failure and refreshes the URL after
    ``max_reconnects`` consecutive failures.
    """

    # FFmpeg options picked up by OpenCV when a capture is opened. They must be
    # in the environment *before* cv2.VideoCapture() is constructed.
    # NOTE(review): newer FFmpeg versions renamed "stimeout" to "timeout" — confirm for the deployed build.
    _FFMPEG_CAPTURE_OPTIONS = "rtsp_transport;tcp|buffer_size;1024000|max_delay;500000|stimeout;2000000"

    def __init__(self, camera_cfg: "CameraConfig", raw_queue: queue.Queue, stop_event: threading.Event):
        super().__init__(daemon=True)
        self.camera_cfg = camera_cfg
        self.raw_queue = raw_queue
        self.stop_event = stop_event
        self.reconnect_count = 0  # consecutive open/capture failures
        self.max_reconnects = 5  # failures before the URL is refreshed
        self.rtsp_url = ""  # filled lazily by refresh_video_url()

    def run(self):
        """(Re)connect and capture until ``stop_event`` is set."""
        while not self.stop_event.is_set():
            try:
                if self.reconnect_count >= self.max_reconnects:
                    print(f"[WARN] RTSP: {self.camera_cfg.name} reach max reconnects, refresh url")
                    self.reconnect_count = 0
                    new_url = self.refresh_video_url()
                    if new_url:
                        self.rtsp_url = new_url
                    else:
                        print("[ERROR] refresh RTSP URL is empty, do nothing")

                if not self.rtsp_url:
                    print("[WARN] RTSP URL is empty, refreshing...")
                    new_url = self.refresh_video_url()
                    if not new_url:
                        print("[ERROR] RTSP URL is still empty, retrying in 5 seconds")
                        time.sleep(5)
                        continue
                    self.rtsp_url = new_url

                # Ask the server for TCP transport via the query string.
                separator = "&" if "?" in self.rtsp_url else "?"
                rtsp_url = f"{self.rtsp_url}{separator}transport=tcp"

                # Set before opening; previously this ran after VideoCapture()
                # and so did not apply to the first connection.
                os.environ["OPENCV_FFMPEG_CAPTURE_OPTIONS"] = self._FFMPEG_CAPTURE_OPTIONS
                cap = cv2.VideoCapture(rtsp_url, cv2.CAP_FFMPEG)
                try:
                    cap.set(cv2.CAP_PROP_BUFFERSIZE, 10)
                    if not cap.isOpened():
                        print(f"[ERROR] Cannot open RTSP: {self.rtsp_url}")
                        time.sleep(2)
                        self.reconnect_count += 1
                        continue
                    print(f"[INFO] Successfully opened RTSP: {self.camera_cfg.name}")
                    self.reconnect_count = 0
                    self._read_loop(cap)
                finally:
                    # Released on every path, including exceptions mid-read.
                    cap.release()
            except Exception as e:
                print(f"[ERROR] Error in RTSP capture for {self.camera_cfg.name}: {e}")
                time.sleep(2)
                self.reconnect_count += 1
                if self.reconnect_count >= self.max_reconnects:
                    # The loop does not stop here; the next iteration refreshes the URL.
                    print(f"[ERROR] Max reconnects reached for {self.camera_cfg.name}, refreshing URL.")

    def _read_loop(self, cap):
        """Push frames from an opened *cap* into raw_queue until a read fails or stop is requested."""
        while not self.stop_event.is_set():
            ret, frame = cap.read()
            if not ret:
                print(f"[WARN] Failed to read frame from {self.camera_cfg.name}")
                time.sleep(0.1)
                return  # caller releases the capture and reconnects
            item = {
                "camera_id": self.camera_cfg.id,
                "camera_name": self.camera_cfg.name,
                "timestamp": time.time(),
                "frame": frame,
            }
            try:
                if self.raw_queue.full():
                    # Drop the oldest queued frame (the queue may be shared by
                    # several cameras) to keep latency bounded.
                    try:
                        self.raw_queue.get_nowait()
                        self.raw_queue.task_done()
                    except queue.Empty:
                        pass
                self.raw_queue.put(item, timeout=0.5)
            except queue.Full:
                print(f"[WARN] Queue full, dropping frame from {self.camera_cfg.name}")
                continue
            # Throttle the read loop (~20 ms between reads).
            time.sleep(0.02)

    def refresh_video_url(self):
        """Fetch a fresh stream URL for this camera via get_camera_preview_url().

        Returns:
            str | None: ``result['data']['url']`` on success, otherwise None
            (errors are logged, never raised).
        """
        try:
            video_id = self.camera_cfg.index
            result = get_camera_preview_url(video_id)
            if 'data' in result and 'url' in result['data']:
                new_url = result['data']['url']
                print(f"[INFO] get rtsp url success, URL: {new_url}")
                return new_url
            print(f"[ERROR] get rtsp url failed: {result}")
            return None
        except Exception as e:
            print(f"[ERROR] get rtsp url error: {str(e)}")
            return None
# ========================= Frame processing thread =========================
class FrameProcessorWorker(threading.Thread):
    """Consumes captured frames, runs KadianDetector and queues WebSocket messages.

    Frames are throttled to RTSP_TARGET_FPS per camera. Every processed frame
    is queued on ``ws_queue``. A frame that carries at least one alert action
    not pushed for that camera within ALERT_PUSH_INTERVAL seconds is also
    queued on ``ws_queue_2``.
    """

    def __init__(self, raw_queue: queue.Queue, ws_queue: queue.Queue, ws_queue_2: queue.Queue, stop_event: threading.Event):
        super().__init__(daemon=True)
        self.raw_queue = raw_queue
        self.ws_queue = ws_queue
        self.ws_queue_2 = ws_queue_2
        self.stop_event = stop_event
        # camera_id -> timestamp of the last frame that was processed
        self.last_ts: Dict[int, float] = {}
        # camera_id -> its own stateful detector (one per camera)
        self.kadian_detectors: Dict[int, KadianDetector] = {}
        # camera_id -> action -> time.time() of its last push
        self.last_alert_push_time: Dict[int, Dict[str, float]] = {}

    def _encode_base64(self, img):
        """Return *img* JPEG-encoded as an ASCII base64 string.

        Raises:
            ValueError: if OpenCV reports that encoding failed.
        """
        ok, buf = cv2.imencode(".jpg", img)
        if not ok:
            raise ValueError("cv2.imencode failed to encode the frame as JPEG")
        return base64.b64encode(buf).decode("ascii")

    def run(self):
        """Process raw frames until ``stop_event`` is set."""
        target_interval = 1.0 / RTSP_TARGET_FPS
        while not self.stop_event.is_set():
            try:
                item = self.raw_queue.get(timeout=0.5)
            except queue.Empty:
                continue
            try:
                self._process_item(item, target_interval)
            except Exception as e:
                # A single failing frame must not kill the only processing thread.
                print(f"[ERROR] Frame processing failed for camera {item.get('camera_id')}: {e}")
            finally:
                self.raw_queue.task_done()

    def _process_item(self, item, target_interval):
        """Throttle, detect, filter alerts and enqueue the message for one captured frame."""
        cam_id = item["camera_id"]
        ts = item["timestamp"]
        frame = item["frame"]

        # Frame-rate throttling per camera.
        if ts - self.last_ts.get(cam_id, 0) < target_interval:
            return
        self.last_ts[cam_id] = ts

        detector = self.kadian_detectors.get(cam_id)
        if detector is None:
            detector = KadianDetector()
            self.kadian_detectors[cam_id] = detector

        result = detector.process_frame(frame.copy(), cam_id, ts)
        push_actions = self._select_push_actions(cam_id, result["alerts"])

        try:
            img_b64 = self._encode_base64(result["image"])
        except Exception as e:
            print(f"[ERROR] Encode image failed: {e}")
            return

        msg = {
            "msg_type": "frame",
            # NOTE(review): hard-coded 0 rather than cam_id; kept as-is since
            # clients may rely on it — confirm before serving several cameras.
            "camera_id": 0,
            "timestamp": ts,
            "result_type": push_actions,
            "image_base64": img_b64,
        }
        try:
            self.ws_queue.put(msg, timeout=1.0)
        except queue.Full:
            print("[WARN] ws_send_queue full, drop frame message")
        # Attempted independently so a backed-up preview queue cannot suppress alerts.
        if push_actions:
            try:
                self.ws_queue_2.put(msg, timeout=1.0)
            except queue.Full:
                print("[WARN] ws_send_queue_2 full, drop alert message")

    def _select_push_actions(self, cam_id, alerts):
        """Return the actions in *alerts* not pushed for *cam_id* within ALERT_PUSH_INTERVAL.

        The push time of every selected action is updated, so duplicates within
        the same frame are reported only once.
        """
        last_push_by_action = self.last_alert_push_time.setdefault(cam_id, {})
        now = time.time()
        push_actions = []
        for alert in alerts:
            action = alert['action']
            if now - last_push_by_action.get(action, 0) >= ALERT_PUSH_INTERVAL:
                push_actions.append(action)
                last_push_by_action[action] = now
        return push_actions
# ========================= Service entry class =========================
class RTSPService:
    """Wires capture, processing and WebSocket threads together for all configured cameras."""

    def __init__(self, config_path: str = "config.yaml"):
        """Read the camera list from *config_path* and build (but not start) all workers.

        The YAML is expected to contain ``cameras: [{id, name?, index, rtsp_url}, ...]``;
        a missing ``name`` defaults to ``cam_<id>``.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            # An empty YAML document loads as None; treat it as "no cameras".
            cfg = yaml.safe_load(f) or {}
        self.cameras = [
            CameraConfig(id=c["id"], name=c.get("name", f"cam_{c['id']}"), index=c["index"], rtsp_url=c["rtsp_url"])
            for c in (cfg.get("cameras") or [])
        ]
        self.stop_event = threading.Event()
        self.raw_queue = queue.Queue(maxsize=500)  # captured frames from all cameras
        self.ws_queue = queue.Queue(maxsize=1000)  # messages for WebSocket server 1
        self.ws_queue_2 = queue.Queue(maxsize=1000)  # messages for WebSocket server 2
        self.capture_workers = []
        self.processor = FrameProcessorWorker(self.raw_queue, self.ws_queue, self.ws_queue_2, self.stop_event)
        self.ws_sender = WebSocketSender(self.ws_queue, self.stop_event)
        self.ws_sender_2 = WebSocketSender2(self.ws_queue_2, self.stop_event)

    def start(self):
        """Start both WebSocket servers, the processor and one capture thread per camera."""
        self.ws_sender.start()
        self.ws_sender_2.start()
        self.processor.start()
        for cam in self.cameras:
            worker = RTSPCaptureWorker(cam, self.raw_queue, self.stop_event)
            worker.start()
            self.capture_workers.append(worker)
        print("[INFO] Kadian RTSP Service started")

    def stop(self):
        """Signal every thread to stop and wait (up to 2 s each) for it to exit.

        The queues are deliberately not join()ed. Every consumer leaves its loop
        as soon as stop_event is set, so items still queued would never be
        marked done and Queue.join() would block forever.
        """
        self.stop_event.set()
        for w in self.capture_workers:
            w.join(timeout=2.0)
        self.processor.join(timeout=2.0)
        self.ws_sender.join(timeout=2.0)
        self.ws_sender_2.join(timeout=2.0)
        print("[INFO] Service stopped")
if __name__ == "__main__":
    # Start every worker thread, then keep the main thread alive until Ctrl+C.
    kadian_service = RTSPService("config.yaml")
    kadian_service.start()
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        # Ctrl+C: ask the service to shut its threads down.
        kadian_service.stop()