Files
SupervisorAI/rtsp_service_ws_Zoulang.py
2026-01-23 09:39:56 +08:00

632 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

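# rtsp_service_ws_Zoulang.py
# Corridor ("Zoulang") police/prisoner detection service, derived from
# Kadian_Detect_1221.py + rtsp_service_ws.py.
# Supports multiple RTSP streams, per-camera frame-rate throttling, and pushing
# annotated frames and alerts over WebSocket (no MP4 recording is done here).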
import cv2
import numpy as np
import os
import time
import threading
import queue
import yaml
import json
import base64
import asyncio
import websockets
from dataclasses import dataclass
from typing import Dict, Any, Tuple, List
from datetime import datetime
# -------------------------- Kadian 检测相关导入 --------------------------
from npu_yolo_onnx_person_car_phone import YOLOv8_ONNX # 主检测模型(人/车/后备箱/手机)
from yolox.tracker.byte_tracker import BYTETracker
# ========================= Configuration =========================
# Detection model location; adjust for the actual deployment.
detector_model_path = 'YOLO_Weight/prisoner_model.onnx'
# Model input size handed to the detector.
input_size = 640

# RTSP / WebSocket service settings.
RTSP_TARGET_FPS = 10.0  # frames per camera actually processed each second
WS_HOST = "0.0.0.0"
WS_PORT = 8769          # server that receives every processed frame
WS_PORT_2 = 8768        # second server: only frames that carry a pushable alert

# Alert throttling: the same action is pushed at most once per this many seconds.
ALERT_PUSH_INTERVAL = 5.0

# Currently connected WebSocket clients, one set per server.
ws_clients, ws_clients_2 = set(), set()
# ========================= Data structures =========================
@dataclass
class CameraConfig:
    """One camera entry from the ``cameras`` list of the YAML config.

    Attributes:
        id: Camera identifier; frames carry it as ``camera_id``.
        name: Human-readable name (RTSPService falls back to ``cam_<id>``).
        rtsp_url: Stream URL opened with OpenCV's FFMPEG backend.
    """

    id: int
    name: str
    rtsp_url: str
class ZoulangDetector:
    """Per-camera police/prisoner detector with tracking and debounced alerts.

    Each :meth:`process_frame` call runs the YOLO detector on the frame,
    feeds the detections to ByteTrack, gives every output track a role
    ("police", "prisoner" or "unknown") by IoU against the current frame's
    detections, draws the annotations onto the frame in place and updates the
    per-role presence counters. Only the prisoner alert is reported in the
    returned ``alerts`` list; the police state is tracked but not emitted.
    """

    # Known tracks get their role re-matched every N processed frames; new
    # track ids are always matched the first time they are seen.
    REVALIDATE_FRAME_INTERVAL = 10
    # A track takes a detection's role only when their IoU exceeds this value.
    ROLE_IOU_THRESHOLD = 0.1
    # Detector class id -> role. Any other id maps to "unknown"; such
    # detections still need an entry so roles stay index-aligned with boxes.
    CLASS_ROLES = {0: "police", 1: "prisoner"}

    def __init__(self):
        # Detection model.
        self.police_prisoner_detector = YOLOv8_ONNX(detector_model_path, conf_threshold=0.5, iou_threshold=0.45,
                                                    input_size=input_size)

        # ByteTrack settings.
        class TrackerArgs:
            track_thresh = 0.25
            track_buffer = 30
            match_thresh = 0.8
            mot20 = False

        # track_id -> role, for the tracks output on the most recent frame.
        self.police_prisoner_track_role = {}
        self.fps = RTSP_TARGET_FPS
        self.tracker = BYTETracker(TrackerArgs(), frame_rate=self.fps)

        # ---- Hyperparameters (seconds) ----
        self.TIME_THRESHOLD_POLICE = 1.0    # presence needed to raise the police state
        self.TIME_TOLERANCE_POLICE = 0.5    # absence tolerated before it is reset (anti-flicker)
        self.TIME_THRESHOLD_PRISONER = 1.0  # presence needed to raise the prisoner alert
        self.TIME_TOLERANCE_PRISONER = 0.5  # absence tolerated before it is reset (anti-flicker)
        # The same thresholds expressed in processed frames.
        self.frame_thresh_police = int(self.TIME_THRESHOLD_POLICE * self.fps)
        self.frame_buffer_police = int(self.TIME_TOLERANCE_POLICE * self.fps)
        self.frame_thresh_prisoner = int(self.TIME_THRESHOLD_PRISONER * self.fps)
        self.frame_buffer_prisoner = int(self.TIME_TOLERANCE_PRISONER * self.fps)
        print(f"\n超参数设置:")
        print(f" FPS: {self.fps:.2f}")
        print(f" 判定 'police Detected' 需累计检测: {self.frame_thresh_police}")
        print(f" 警察丢失缓冲帧数: {self.frame_buffer_police}")
        print(f" 判定 'prisoner Detected' 需累计检测: {self.frame_thresh_prisoner}")
        print(f" 犯人丢失缓冲帧数: {self.frame_buffer_prisoner}")

        # ---- State ----
        self.current_frame_idx = 0
        self.cnt_frame_nobody = 0  # not read by this class; kept for compatibility
        # Frames with police seen; only reset after a gap of frame_buffer_police frames.
        self.police_detection_frames = 0
        self.police_missing_frames = 0  # consecutive frames without police
        self.police_alert_active = False
        # Same bookkeeping for prisoners.
        self.prisoner_detection_frames = 0
        self.prisoner_missing_frames = 0
        self.prisoner_alert_active = False

    def compute_iou(self, boxA, boxB):
        """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

        Returns 0.0 when the union area is zero (e.g. two degenerate boxes).
        """
        inter_w = max(0, min(boxA[2], boxB[2]) - max(boxA[0], boxB[0]))
        inter_h = max(0, min(boxA[3], boxB[3]) - max(boxA[1], boxB[1]))
        inter_area = inter_w * inter_h
        area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
        area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
        union_area = area_a + area_b - inter_area
        if union_area == 0:
            return 0.0
        return inter_area / union_area

    def draw_alert(self, frame, text, color=(0, 0, 255), sub_text=None, offset_y=0):
        """Draw *text* on a black box in the top-right corner of *frame*.

        Args:
            frame: Image drawn on in place.
            text: Main alert text.
            color: Text color; the default (0, 0, 255) is red for BGR frames.
            sub_text: Optional smaller grey line drawn 40 px below the text.
            offset_y: Extra vertical offset so stacked alerts do not overlap.
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 1.5
        thickness = 3
        (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
        # Take the width from the frame itself so this does not depend on
        # process_frame having set self.width first.
        x = frame.shape[1] - text_w - 20
        y = 50 + text_h + offset_y
        cv2.rectangle(frame, (x - 10, y - text_h - 10), (x + text_w + 10, y + 10), (0, 0, 0), -1)
        cv2.putText(frame, text, (x, y), font, font_scale, color, thickness)
        if sub_text:
            cv2.putText(frame, sub_text, (x, y + 40), font, 0.7, (200, 200, 200), 2)

    @classmethod
    def _parse_detections(cls, results):
        """Split detector output into boxes, roles and the ByteTrack input array.

        Each detection is unpacked as ``(x1, y1, x2, y2, conf, cls_id)``.
        Every detection gets a role (``"unknown"`` for unmapped class ids) so
        ``roles[i]`` always describes ``boxes[i]``. ``None`` or an empty
        result yields empty lists and a ``(0, 5)`` array. Iterating instead
        of testing truthiness works for both lists and numpy arrays.
        """
        boxes, roles, tracker_rows = [], [], []
        for det in (results if results is not None else ()):
            x1, y1, x2, y2, conf, cls_id = det
            boxes.append([x1, y1, x2, y2])
            tracker_rows.append([x1, y1, x2, y2, conf])
            roles.append(cls.CLASS_ROLES.get(cls_id, "unknown"))
        if tracker_rows:
            tracker_input = np.array(tracker_rows, dtype=np.float32)
        else:
            tracker_input = np.empty((0, 5), dtype=np.float32)
        return boxes, roles, tracker_input

    def _match_role(self, track_box, dets_xyxy, dets_roles):
        """Return the role of the detection best overlapping *track_box*.

        Returns ``"unknown"`` unless the best IoU exceeds ROLE_IOU_THRESHOLD.
        """
        best_iou, best_role = 0, "unknown"
        for box, role in zip(dets_xyxy, dets_roles):
            iou_val = self.compute_iou(track_box, box)
            if iou_val > best_iou:
                best_iou, best_role = iou_val, role
        return best_role if best_iou > self.ROLE_IOU_THRESHOLD else "unknown"

    def _annotate_tracks(self, frame, tracks, dets_xyxy, dets_roles) -> Tuple[int, int]:
        """Assign roles to *tracks*, draw their boxes and count them per role.

        Roles of tracks that are no longer output are dropped, so the role
        cache stays bounded. A track that reappears is simply re-matched.

        Returns:
            ``(police_count, prisoner_count)`` for this frame.
        """
        revalidate = self.current_frame_idx % self.REVALIDATE_FRAME_INTERVAL == 0
        role_cache = self.police_prisoner_track_role
        police_count = prisoner_count = 0
        active_ids = set()
        for t in tracks:
            tid = t.track_id
            active_ids.add(tid)
            if revalidate or tid not in role_cache:
                role_cache[tid] = self._match_role(list(map(float, t.tlbr)), dets_xyxy, dets_roles)
            role = role_cache.get(tid, "unknown")

            if role == "police":
                police_count += 1
                color, label = (255, 0, 255), "police"
            elif role == "prisoner":
                prisoner_count += 1
                color, label = (0, 0, 139), "prisoner"
            else:
                color, label = (255, 255, 255), "Unknown"

            x1, y1, x2, y2 = map(int, t.tlbr)
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

        for stale_id in role_cache.keys() - active_ids:
            del role_cache[stale_id]
        return police_count, prisoner_count

    @staticmethod
    def _update_presence(present, detection_frames, missing_frames, alert_active, thresh_frames, buffer_frames):
        """Advance one debounced presence state by one frame.

        While present, the detection counter grows and the alert becomes
        active once it reaches *thresh_frames*. While absent, the missing
        counter grows, and the state is reset only after *buffer_frames*
        consecutive absent frames, so brief dropouts are ignored.

        Returns:
            The new ``(detection_frames, missing_frames, alert_active)``.
        """
        if present:
            detection_frames += 1
            missing_frames = 0
            if detection_frames >= thresh_frames:
                alert_active = True
        else:
            missing_frames += 1
            if detection_frames > 0 and missing_frames >= buffer_frames:
                detection_frames = 0
                alert_active = False
        return detection_frames, missing_frames, alert_active

    def _draw_alert_panel(self, frame, prisoner_count, alerts):
        """Draw the prisoner count (top-left) and one line per alert below it."""
        cv2.putText(frame, f" prisoner: {prisoner_count}", (20, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        alert_y_start = 150
        for i, alert in enumerate(alerts):
            action = alert['action']
            details = alert.get('details', '')
            color = (255, 255, 255) if action == 'prisoner' else (0, 0, 255)
            main_text = f"{action} ({details})" if details else action
            y_pos = alert_y_start + i * 50
            cv2.rectangle(frame, (20, y_pos - 40), (900, y_pos + 10), (0, 0, 0), -1)
            cv2.putText(frame, main_text, (30, y_pos), cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)

    def process_frame(self, frame, camera_id: int, timestamp: float) -> Dict[str, Any]:
        """Detect, track and annotate one frame and update the alert state.

        Args:
            frame: Image array whose ``shape[:2]`` is (height, width); the
                annotations are drawn onto it in place.
            camera_id: Source camera (not used here; FrameProcessorWorker
                keeps one detector instance per camera).
            timestamp: Capture time, copied into each alert's ``time`` field.

        Returns:
            ``{"image": frame, "alerts": [...]}`` where each alert is a dict
            with ``time``, ``action`` (only ``"prisoner"`` is produced),
            ``confidence`` (fixed 1.0, rule-based) and ``details``.
        """
        h, w = frame.shape[:2]
        self.width, self.height = w, h
        self.current_frame_idx += 1

        dets_xyxy, dets_roles, tracker_input = self._parse_detections(self.police_prisoner_detector(frame))
        tracks = self.tracker.update(tracker_input, [h, w], [h, w])
        police_count, prisoner_count = self._annotate_tracks(frame, tracks, dets_xyxy, dets_roles)

        (self.prisoner_detection_frames,
         self.prisoner_missing_frames,
         self.prisoner_alert_active) = self._update_presence(
            prisoner_count > 0, self.prisoner_detection_frames, self.prisoner_missing_frames,
            self.prisoner_alert_active, self.frame_thresh_prisoner, self.frame_buffer_prisoner)
        (self.police_detection_frames,
         self.police_missing_frames,
         self.police_alert_active) = self._update_presence(
            police_count > 0, self.police_detection_frames, self.police_missing_frames,
            self.police_alert_active, self.frame_thresh_police, self.frame_buffer_police)

        current_frame_alerts = []  # rebuilt on every frame
        if self.prisoner_alert_active:
            duration_seconds = self.prisoner_detection_frames / self.fps
            current_frame_alerts.append(
                {
                    'time': timestamp,
                    'action': 'prisoner',
                    'confidence': 1.0,  # rule-based decision, not a model score
                    'details': f"Detected for {duration_seconds:.1f}s"
                }
            )
            self.draw_alert(frame, "prisoner", (0, 0, 255), offset_y=0)

        self._draw_alert_panel(frame, prisoner_count, current_frame_alerts)
        return {
            "image": frame,
            "alerts": current_frame_alerts
        }
# ========================= WebSocket server thread =========================
class WebSocketSender(threading.Thread):
    """Daemon thread serving the primary WebSocket endpoint (WS_HOST:WS_PORT).

    Each message taken from ``send_queue`` is JSON-encoded and broadcast to
    every client connected to this server, as tracked in the module-level
    ``ws_clients`` set. The loop exits once ``stop_event`` is set; it polls
    the queue with a 0.5 s timeout, so it checks the event at least that often.
    """

    def __init__(self, send_queue: queue.Queue, stop_event: threading.Event):
        super().__init__(daemon=True)
        self.send_queue = send_queue
        self.stop_event = stop_event

    async def _ws_handler(self, websocket):
        """Keep *websocket* registered while it is connected; its messages are ignored."""
        ws_clients.add(websocket)
        try:
            async for _ in websocket:
                pass
        finally:
            ws_clients.discard(websocket)

    async def _broadcast(self, msg):
        """Send *msg* as JSON to all clients and drop clients whose send fails."""
        try:
            data = json.dumps(msg)
        except (TypeError, ValueError) as e:
            # One bad message must not kill the server thread.
            print(f"[ERROR] Cannot serialize WebSocket message: {e}")
            return
        clients = list(ws_clients)
        if not clients:
            return
        # Send concurrently so a slow client does not delay the others.
        results = await asyncio.gather(*(ws.send(data) for ws in clients), return_exceptions=True)
        for ws, result in zip(clients, results):
            if isinstance(result, Exception):
                ws_clients.discard(ws)

    async def _broadcaster(self):
        """Forward queued messages to the clients until stop_event is set."""
        while not self.stop_event.is_set():
            try:
                # Block in a worker thread so the event loop keeps serving clients.
                msg = await asyncio.to_thread(self.send_queue.get, timeout=0.5)
            except queue.Empty:
                continue
            try:
                await self._broadcast(msg)
            finally:
                # Always acknowledge the item so Queue.join() cannot hang on it.
                self.send_queue.task_done()

    async def _run_async(self):
        async with websockets.serve(self._ws_handler, WS_HOST, WS_PORT):
            print(f"[INFO] WebSocket server started at ws://{WS_HOST}:{WS_PORT}")
            await self._broadcaster()

    def run(self):
        asyncio.run(self._run_async())
# ========================= WebSocket server thread 2 =========================
class WebSocketSender2(threading.Thread):
    """Daemon thread serving the second WebSocket endpoint (WS_HOST:WS_PORT_2).

    Works like the primary sender but uses the module-level ``ws_clients_2``
    set. RTSPService feeds it the queue that FrameProcessorWorker fills only
    with frames carrying a pushable alert. The loop exits once ``stop_event``
    is set; it polls the queue with a 0.5 s timeout.
    """

    def __init__(self, send_queue: queue.Queue, stop_event: threading.Event):
        super().__init__(daemon=True)
        self.send_queue = send_queue
        self.stop_event = stop_event

    async def _ws_handler(self, websocket):
        """Keep *websocket* registered while it is connected; its messages are ignored."""
        ws_clients_2.add(websocket)
        try:
            async for _ in websocket:
                pass
        finally:
            ws_clients_2.discard(websocket)

    async def _broadcast(self, msg):
        """Send *msg* as JSON to all clients and drop clients whose send fails."""
        try:
            data = json.dumps(msg)
        except (TypeError, ValueError) as e:
            # One bad message must not kill the server thread.
            print(f"[ERROR] Cannot serialize WebSocket message: {e}")
            return
        clients = list(ws_clients_2)
        if not clients:
            return
        # Send concurrently so a slow client does not delay the others.
        results = await asyncio.gather(*(ws.send(data) for ws in clients), return_exceptions=True)
        for ws, result in zip(clients, results):
            if isinstance(result, Exception):
                ws_clients_2.discard(ws)

    async def _broadcaster(self):
        """Forward queued messages to the clients until stop_event is set."""
        while not self.stop_event.is_set():
            try:
                # Block in a worker thread so the event loop keeps serving clients.
                msg = await asyncio.to_thread(self.send_queue.get, timeout=0.5)
            except queue.Empty:
                continue
            try:
                await self._broadcast(msg)
            finally:
                # Always acknowledge the item so Queue.join() cannot hang on it.
                self.send_queue.task_done()

    async def _run_async(self):
        async with websockets.serve(self._ws_handler, WS_HOST, WS_PORT_2):
            print(f"[INFO] WebSocket server 2 started at ws://{WS_HOST}:{WS_PORT_2}")
            await self._broadcaster()

    def run(self):
        asyncio.run(self._run_async())
# ========================= RTSP capture thread =========================
class RTSPCaptureWorker(threading.Thread):
    """Daemon thread that reads frames from one RTSP camera into a shared queue.

    Each frame read is queued as a dict with ``camera_id``, ``camera_name``,
    ``timestamp`` (``time.time()`` right after the read) and ``frame``. If the
    stream cannot be opened, or too many consecutive reads fail, the capture
    is released and reopened, so the thread does not give up for good. A
    camera that restarts is therefore picked up again.
    """

    # Consecutive failed reads tolerated before the stream is reopened
    # (about 5 s at READ_RETRY_DELAY apart).
    MAX_READ_FAILURES = 25
    # Pause after a failed read, in seconds.
    READ_RETRY_DELAY = 0.2
    # Pause before retrying a stream that could not be opened, in seconds.
    RECONNECT_DELAY = 5.0
    # Seconds to wait for room in raw_queue before the frame is dropped.
    QUEUE_PUT_TIMEOUT = 1.0

    def __init__(self, camera_cfg: "CameraConfig", raw_queue: queue.Queue, stop_event: threading.Event):
        super().__init__(daemon=True)
        self.camera_cfg = camera_cfg
        self.raw_queue = raw_queue
        self.stop_event = stop_event

    def _open(self):
        """Open the camera stream; return the capture, or None (released) on failure."""
        cap = cv2.VideoCapture(self.camera_cfg.rtsp_url, cv2.CAP_FFMPEG)
        if cap.isOpened():
            print(f"[INFO] Capturing {self.camera_cfg.name} (ID:{self.camera_cfg.id})")
            return cap
        print(f"[ERROR] Cannot open RTSP: {self.camera_cfg.rtsp_url}")
        cap.release()
        return None

    def run(self):
        cap = None
        failures = 0
        try:
            while not self.stop_event.is_set():
                if cap is None:
                    cap = self._open()
                    if cap is None:
                        # Event.wait returns early if a stop is requested.
                        self.stop_event.wait(self.RECONNECT_DELAY)
                        continue
                    failures = 0

                ret, frame = cap.read()
                if not ret:
                    failures += 1
                    if failures >= self.MAX_READ_FAILURES:
                        print(f"[WARN] Lost RTSP stream {self.camera_cfg.name} "
                              f"(ID:{self.camera_cfg.id}), reconnecting")
                        cap.release()
                        cap = None
                    else:
                        self.stop_event.wait(self.READ_RETRY_DELAY)
                    continue
                failures = 0

                item = {
                    "camera_id": self.camera_cfg.id,
                    "camera_name": self.camera_cfg.name,
                    "timestamp": time.time(),
                    "frame": frame,
                }
                try:
                    self.raw_queue.put(item, timeout=self.QUEUE_PUT_TIMEOUT)
                except queue.Full:
                    pass  # processor is behind; dropping keeps the capture live
        finally:
            if cap is not None:
                cap.release()
# ========================= Frame processing thread =========================
class FrameProcessorWorker(threading.Thread):
    """Daemon thread that turns raw camera frames into WebSocket messages.

    Frames are throttled per camera to RTSP_TARGET_FPS and run through a
    per-camera ZoulangDetector. The annotated frame is JPEG/base64 encoded
    and queued on ``ws_send_queue``. When the frame carries at least one
    pushable alert (an action not pushed for this camera within
    ALERT_PUSH_INTERVAL seconds), the same message is also queued on
    ``ws_send_queue_2``.
    """

    # Seconds to wait for room in a WebSocket queue before dropping a message.
    WS_PUT_TIMEOUT = 1.0

    def __init__(self,
                 raw_frame_queue: "queue.Queue[Dict[str, Any]]",
                 ws_send_queue: "queue.Queue[Dict[str, Any]]",
                 ws_send_queue_2: "queue.Queue[Dict[str, Any]]",
                 stop_event: threading.Event):
        super().__init__(daemon=True)
        self.raw_queue = raw_frame_queue
        self.ws_queue = ws_send_queue
        self.ws_queue_2 = ws_send_queue_2  # receives only frames with pushable alerts
        self.stop_event = stop_event
        # camera_id -> capture timestamp of the last processed frame.
        self.last_ts: Dict[int, float] = {}
        # One independent detector per camera.
        self.kadian_detectors: Dict[int, "ZoulangDetector"] = {}
        # camera_id -> {action: time.time() of the last push}.
        self.last_alert_push_time: Dict[int, Dict[str, float]] = {}

    def _encode_image_to_base64(self, image) -> str:
        """Encode *image* as JPEG and return it base64-encoded (ASCII)."""
        ok, buf = cv2.imencode(".jpg", image)
        if not ok:
            raise RuntimeError("Failed to encode image to JPEG")
        return base64.b64encode(buf.tobytes()).decode("ascii")

    def _get_detector(self, cam_id):
        """Return the detector for *cam_id*, creating it on first use."""
        detector = self.kadian_detectors.get(cam_id)
        if detector is None:
            detector = self.kadian_detectors[cam_id] = ZoulangDetector()
        return detector

    def _select_push_actions(self, cam_id, alerts, now, interval):
        """Return the actions from *alerts* that may be pushed at time *now*.

        An action qualifies when it was not pushed for *cam_id* within
        *interval* seconds. Its push time is recorded as it is selected, so a
        duplicate action in the same frame is reported only once.
        """
        last_push_times = self.last_alert_push_time.setdefault(cam_id, {})
        push_actions = []
        for alert in alerts:
            action = alert['action']
            if now - last_push_times.get(action, 0) >= interval:
                push_actions.append(action)
                last_push_times[action] = now
        return push_actions

    @staticmethod
    def _build_message(cam_id, ts, push_actions, img_b64):
        """Build the JSON-serialisable frame message sent to WebSocket clients."""
        return {
            "msg_type": "frame",
            "camera_id": cam_id,  # previously hard-coded to 1 for every camera
            "timestamp": ts,
            "result_type": push_actions,
            "image_base64": img_b64,
        }

    def _offer(self, target_queue, msg, queue_name):
        """Put *msg* on *target_queue*, dropping it with a warning if the queue stays full."""
        try:
            target_queue.put(msg, timeout=self.WS_PUT_TIMEOUT)
        except queue.Full:
            print(f"[WARN] {queue_name} full, drop frame message")

    def _process_item(self, item, target_interval):
        """Throttle, detect, encode and enqueue one raw frame item."""
        cam_id = item["camera_id"]
        ts = item["timestamp"]
        # Frame-rate throttling per camera.
        if ts - self.last_ts.get(cam_id, 0) < target_interval:
            return
        self.last_ts[cam_id] = ts

        result = self._get_detector(cam_id).process_frame(item["frame"].copy(), cam_id, ts)
        push_actions = self._select_push_actions(cam_id, result["alerts"], time.time(), ALERT_PUSH_INTERVAL)

        try:
            img_b64 = self._encode_image_to_base64(result["image"])
        except Exception as e:
            print(f"[ERROR] Encode image failed: {e}")
            return

        msg = self._build_message(cam_id, ts, push_actions, img_b64)
        # Queue independently so a full primary queue cannot drop the alert copy.
        self._offer(self.ws_queue, msg, "ws_send_queue")
        if push_actions:
            self._offer(self.ws_queue_2, msg, "ws_send_queue_2")

    def run(self):
        target_interval = 1.0 / RTSP_TARGET_FPS
        while not self.stop_event.is_set():
            try:
                item = self.raw_queue.get(timeout=0.5)
            except queue.Empty:
                continue
            try:
                self._process_item(item, target_interval)
            except Exception as e:
                # Keep the thread alive: one failing frame must not stop processing.
                print(f"[ERROR] Frame processing failed for camera {item.get('camera_id')}: {e}")
            finally:
                self.raw_queue.task_done()
# ========================= Service entry class =========================
class RTSPService:
    """Wire capture, processing and WebSocket threads together from a YAML config.

    The config's ``cameras`` list provides ``id``, ``rtsp_url`` and an
    optional ``name`` per camera.
    """

    def __init__(self, config_path: str = "config.yaml"):
        with open(config_path, "r", encoding="utf-8") as f:
            # safe_load returns None for an empty file.
            cfg = yaml.safe_load(f) or {}
        self.cameras = [CameraConfig(id=c["id"], name=c.get("name", f"cam_{c['id']}"), rtsp_url=c["rtsp_url"])
                        for c in (cfg.get("cameras") or [])]
        self.stop_event = threading.Event()
        self.raw_queue = queue.Queue(maxsize=500)
        self.ws_queue = queue.Queue(maxsize=1000)
        self.ws_queue_2 = queue.Queue(maxsize=1000)  # alert-only stream
        self.capture_workers = []
        self.processor = FrameProcessorWorker(self.raw_queue, self.ws_queue, self.ws_queue_2, self.stop_event)
        self.ws_sender = WebSocketSender(self.ws_queue, self.stop_event)
        self.ws_sender_2 = WebSocketSender2(self.ws_queue_2, self.stop_event)

    def start(self):
        """Start the WebSocket servers, the processor, then one capture thread per camera."""
        self.ws_sender.start()
        self.ws_sender_2.start()
        self.processor.start()
        for cam in self.cameras:
            w = RTSPCaptureWorker(cam, self.raw_queue, self.stop_event)
            w.start()
            self.capture_workers.append(w)
        print("[INFO] Zoulang RTSP Service started")

    def stop(self, timeout: float = 2.0):
        """Signal every thread to stop and wait up to *timeout* seconds for each.

        Queue.join() is deliberately not used. Every consumer leaves its loop
        as soon as stop_event is set, so an item still queued would never be
        task_done()'d and join() would block forever. Pending frames and
        messages are dropped instead.
        """
        self.stop_event.set()
        # Producers first, then the processor, then the senders.
        for w in self.capture_workers:
            w.join(timeout=timeout)
        self.processor.join(timeout=timeout)
        self.ws_sender.join(timeout=timeout)
        self.ws_sender_2.join(timeout=timeout)
        print("[INFO] Service stopped")
if __name__ == "__main__":
    # Script entry point: start every worker from config.yaml, then park the
    # main thread until Ctrl+C requests a shutdown.
    service = RTSPService("config.yaml")
    service.start()
    try:
        while True:
            # All real work happens in daemon threads; just stay alive here.
            time.sleep(1)
    except KeyboardInterrupt:
        service.stop()