# rtsp_service_kadian.py
# Merges Kadian_Detect_1221.py and rtsp_service_ws.py into one service:
# multi-stream RTSP capture, frame sampling, black-bag / person detection with
# tracking, and WebSocket push of annotated frames and alerts.
# NOTE(review): the original header also listed segmented MP4 saving, but no
# MP4-writing code exists in this file.
import asyncio
import base64
import json
import queue
import threading
import time
from dataclasses import dataclass
from typing import Any, Dict

import cv2
import numpy as np
import websockets
import yaml

# -------------------------- Kadian detection imports --------------------------
from algorithm.common.npu_yolo_onnx_person_car_phone import YOLOv8_ONNX  # main detector class (module name: person/car/trunk/phone)
from yolox.tracker.byte_tracker import BYTETracker

# ========================= Configuration =========================
# Kadian model path (adjust to the deployment).
detector_model_path = 'YOLO_Weight/bag_model.onnx'
# Detector input size in pixels.
input_size = 640

# RTSP service settings
RTSP_TARGET_FPS = 5.0        # frames per second processed per camera; extra frames are skipped
WS_HOST = "0.0.0.0"
WS_PORT = 8767               # WebSocket server 1: every processed frame
WS_PORT_2 = 8766             # WebSocket server 2: only frames carrying a pushed alert
ALERT_PUSH_INTERVAL = 5.0    # min seconds between pushes of the same alert action for one camera

# Connected clients of each WebSocket server.
ws_clients = set()
ws_clients_2 = set()


# ========================= Data structures =========================
@dataclass
class CameraConfig:
    """One RTSP source, built from an entry of the `cameras` list in the config file."""
    id: int
    name: str
    rtsp_url: str


# ========================= Kadian detector (trimmed, service-oriented) =========================
class KadianDetector:
    """Per-camera black-bag / person detector with tracking and a debounced alert.

    Each frame is run through the YOLOv8 ONNX model. Detections are associated
    with ByteTrack tracks and annotations are drawn onto the frame. A
    'Black Bag' alert is raised once black bags have been seen for
    TIME_THRESHOLD_BLACKBAG seconds. It is cleared after they have been absent
    for TIME_TOLERANCE_BLACKBAG seconds.
    """

    # Re-match every visible track to a detection class every N frames so an
    # early wrong role assignment gets corrected.
    REVALIDATE_FRAME_INTERVAL = 10
    # A track adopts a detection's role only above this IoU.
    ROLE_MATCH_MIN_IOU = 0.1
    # Detector class id -> track role; any other class id maps to "unknown".
    CLASS_ROLES = {0: "black_bag", 1: "person"}
    # Track role -> (box colour, label) used when drawing.
    ROLE_STYLES = {
        "person": ((255, 0, 255), "Person"),       # magenta box
        "black_bag": ((0, 128, 0), "Black Bag"),   # green box
    }
    UNKNOWN_STYLE = ((255, 255, 255), "Unknown")

    def __init__(self):
        # Model loading
        self.detector = YOLOv8_ONNX(detector_model_path, conf_threshold=0.5, iou_threshold=0.45, input_size=input_size)

        # ByteTracker settings (attribute bag passed as BYTETracker's `args`)
        class TrackerArgs:
            track_thresh = 0.25
            track_buffer = 30
            match_thresh = 0.8
            mot20 = False

        self.track_role = {}  # track id -> role string
        self.fps = RTSP_TARGET_FPS
        self.tracker = BYTETracker(TrackerArgs(), frame_rate=self.fps)

        # ==========================================
        # Hyper-parameters
        # ==========================================
        self.TIME_THRESHOLD_BLACKBAG = 1.0  # seconds of detection before alerting
        self.TIME_TOLERANCE_BLACKBAG = 0.5  # seconds a black bag may be missing before the alert resets

        # Converted to frame counts at the processing frame rate
        self.frame_thresh_blackbag = int(self.TIME_THRESHOLD_BLACKBAG * self.fps)
        self.frame_buffer_blackbag = int(self.TIME_TOLERANCE_BLACKBAG * self.fps)

        print(f"\n超参数设置:")
        print(f" FPS: {self.fps:.2f}")
        print(f" 判定 'BlackBag Detected' 需累计检测: {self.frame_thresh_blackbag} 帧")
        print(f" 黑包丢失缓冲帧数: {self.frame_buffer_blackbag} 帧")

        # ==========================================
        # State
        # ==========================================
        self.current_frame_idx = 0
        # Frame size; set by process_frame before draw_alert reads it.
        self.width, self.height = 0, 0

        # Black-bag alert state
        self.blackbag_detection_frames = 0
        self.blackbag_missing_frames = 0
        self.blackbag_alert_active = False

        # Person count in the most recent frame
        self.current_person_count = 0

    def compute_iou(self, boxA, boxB):
        """Return the intersection-over-union of two [x1, y1, x2, y2] boxes.

        Returns 0.0 when the union area is not positive (degenerate boxes).
        """
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        yB = min(boxA[3], boxB[3])
        interArea = max(0, xB - xA) * max(0, yB - yA)
        boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
        boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
        unionArea = boxAArea + boxBArea - interArea
        if unionArea <= 0:
            return 0.0
        return interArea / unionArea

    def draw_alert(self, frame, text, color=(0, 0, 255), sub_text=None, offset_y=0):
        """Draw a warning banner in the top-right corner of `frame`.

        `offset_y` shifts the banner down so several banners do not overlap.
        Reads self.width, which process_frame sets from the current frame.
        """
        font_scale = 1.5
        thickness = 3
        font = cv2.FONT_HERSHEY_SIMPLEX
        (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
        x = self.width - text_w - 20
        y = 50 + text_h + offset_y
        cv2.rectangle(frame, (x - 10, y - text_h - 10), (x + text_w + 10, y + 10), (0, 0, 0), -1)
        cv2.putText(frame, text, (x, y), font, font_scale, color, thickness)
        if sub_text:
            cv2.putText(frame, sub_text, (x, y + 40), font, 0.7, (200, 200, 200), 2)

    def _parse_detections(self, detect_results):
        """Split raw detector rows (x1, y1, x2, y2, conf, cls_id) into aligned lists.

        Returns (boxes_xyxy, roles, rows_for_tracker). The three lists always
        have one entry per detection. Unrecognised class ids get the role
        "unknown" so the box and role lists stay index-aligned.
        """
        dets_xyxy, dets_roles, dets_for_tracker = [], [], []
        if detect_results is None:
            return dets_xyxy, dets_roles, dets_for_tracker
        for x1, y1, x2, y2, conf, cls_id in detect_results:
            dets_xyxy.append([x1, y1, x2, y2])
            dets_for_tracker.append([x1, y1, x2, y2, conf])
            dets_roles.append(self.CLASS_ROLES.get(int(cls_id), "unknown"))
        return dets_xyxy, dets_roles, dets_for_tracker

    def _match_role(self, tlbr, dets_xyxy, dets_roles):
        """Return the role of the detection best overlapping `tlbr`, or "unknown" if none overlaps enough."""
        best_iou = 0
        best_role = "unknown"
        t_box = list(map(float, tlbr))
        for box, role in zip(dets_xyxy, dets_roles):
            iou_val = self.compute_iou(t_box, box)
            if iou_val > best_iou:
                best_iou, best_role = iou_val, role
        return best_role if best_iou > self.ROLE_MATCH_MIN_IOU else "unknown"

    def _update_blackbag_state(self, blackbag_seen):
        """Advance the debounced black-bag alert state by one processed frame."""
        if blackbag_seen:
            self.blackbag_detection_frames += 1
            self.blackbag_missing_frames = 0
            if self.blackbag_detection_frames >= self.frame_thresh_blackbag:
                self.blackbag_alert_active = True
        else:
            self.blackbag_missing_frames += 1
            if self.blackbag_missing_frames >= self.frame_buffer_blackbag:
                self.blackbag_detection_frames = 0
                self.blackbag_alert_active = False

    @staticmethod
    def _draw_alert_list(frame, alerts, y_start=150, line_height=50):
        """Draw one red text line per alert, stacked downward from y_start on the left side."""
        for i, alert in enumerate(alerts):
            main_text = f"{alert['action']} ({alert.get('details', '')})"
            y_pos = y_start + i * line_height
            cv2.rectangle(frame, (20, y_pos - 40), (900, y_pos + 10), (0, 0, 0), -1)
            cv2.putText(frame, main_text, (30, y_pos), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)

    def process_frame(self, frame, camera_id: int, timestamp: float) -> Dict[str, Any]:
        """Detect, track and annotate one frame.

        Args:
            frame: Image as read by cv2.VideoCapture; annotated in place.
            camera_id: Id of the source camera (not used by the detector itself).
            timestamp: Capture time in seconds, copied into each alert's 'time'.

        Returns:
            {"image": annotated frame, "alerts": list of dicts with keys
            'time', 'action', 'details'}.
        """
        self.height, self.width = frame.shape[:2]
        self.current_frame_idx += 1

        # ========= Detection (black bag + person) =========
        dets_xyxy, dets_roles, dets_for_tracker = self._parse_detections(self.detector(frame))

        # ========= Tracker update =========
        if dets_for_tracker:
            dets = np.array(dets_for_tracker, dtype=np.float32)
        else:
            dets = np.empty((0, 5))
        tracks = self.tracker.update(dets, [self.height, self.width], [self.height, self.width])

        # ========= Per-frame counters =========
        self.current_person_count = 0
        current_blackbag_count = 0

        # ========= Track roles, drawing and counting =========
        if self.current_frame_idx % self.REVALIDATE_FRAME_INTERVAL == 0:
            # Every visible track is re-matched below anyway, so rebuilding the
            # map from scratch loses nothing. It also stops the map growing with
            # every track id ever seen.
            self.track_role = {}
        for t in tracks:
            tid = t.track_id
            if tid not in self.track_role:
                self.track_role[tid] = self._match_role(t.tlbr, dets_xyxy, dets_roles)
            role = self.track_role[tid]

            if role == "person":
                self.current_person_count += 1
            elif role == "black_bag":
                current_blackbag_count += 1

            color, label = self.ROLE_STYLES.get(role, self.UNKNOWN_STYLE)
            x1, y1, x2, y2 = map(int, t.tlbr)
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

        # ========= Black-bag alert state =========
        self._update_blackbag_state(current_blackbag_count > 0)

        # ========= Alert collection =========
        current_frame_alerts = []
        if self.blackbag_alert_active:
            duration_seconds = self.blackbag_detection_frames / self.fps
            details = f"Detected for {duration_seconds:.1f}s"
            current_frame_alerts.append({'time': timestamp, 'action': 'Black Bag', 'details': details})
            self.draw_alert(frame, "Black Bag Alert", (0, 0, 255), sub_text=details)

        # ========= Overlay =========
        debug_info = f"Person: {self.current_person_count} | BlackBag: {current_blackbag_count}"
        cv2.putText(frame, debug_info, (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        self._draw_alert_list(frame, current_frame_alerts)

        return {"image": frame, "alerts": current_frame_alerts}


# ========================= WebSocket server thread =========================
class WebSocketSender(threading.Thread):
    """Run the WebSocket server on WS_PORT and broadcast queued messages.

    Every message taken from `send_queue` is JSON-encoded and sent to all
    connected clients; clients whose send fails are dropped.
    """

    def __init__(self, send_queue: queue.Queue, stop_event: threading.Event):
        super().__init__(daemon=True)
        self.send_queue = send_queue
        self.stop_event = stop_event

    async def _ws_handler(self, websocket):
        """Register a client for the lifetime of its connection; incoming messages are ignored."""
        ws_clients.add(websocket)
        try:
            async for _ in websocket:
                pass
        finally:
            ws_clients.discard(websocket)

    @staticmethod
    async def _send_or_report(ws, data):
        """Send `data` to one client; return the client if sending failed, else None."""
        try:
            await ws.send(data)
        except Exception:  # closed/broken connection; CancelledError (BaseException) still propagates
            return ws
        return None

    async def _broadcast(self, msg):
        """Serialize `msg` to JSON and send it to every connected client, dropping dead ones."""
        try:
            data = json.dumps(msg)
        except (TypeError, ValueError) as e:
            print(f"[ERROR] Cannot serialize WebSocket message: {e}")
            return
        # Send concurrently so one slow client does not hold up the others.
        results = await asyncio.gather(*(self._send_or_report(ws, data) for ws in list(ws_clients)))
        for dead in results:
            if dead is not None:
                ws_clients.discard(dead)

    async def _broadcaster(self):
        """Forward queued messages to clients until stop_event is set."""
        while not self.stop_event.is_set():
            try:
                # Blocking get runs in a worker thread so the event loop keeps serving clients.
                msg = await asyncio.to_thread(self.send_queue.get, timeout=0.5)
            except queue.Empty:
                continue
            try:
                await self._broadcast(msg)
            finally:
                # Always balance the get() so Queue.join() callers cannot hang.
                self.send_queue.task_done()

    async def _run_async(self):
        async with websockets.serve(self._ws_handler, WS_HOST, WS_PORT):
            print(f"[INFO] WebSocket server started at ws://{WS_HOST}:{WS_PORT}")
            await self._broadcaster()

    def run(self):
        asyncio.run(self._run_async())
class WebSocketSender2(threading.Thread):
    """Run the second WebSocket server on WS_PORT_2 and broadcast queued messages.

    Same mechanics as the first sender but with its own client set
    (ws_clients_2). FrameProcessorWorker feeds it only frames that carry at
    least one pushed alert.
    """

    def __init__(self, send_queue: queue.Queue, stop_event: threading.Event):
        super().__init__(daemon=True)
        self.send_queue = send_queue
        self.stop_event = stop_event

    async def _ws_handler(self, websocket):
        """Register a client for the lifetime of its connection; incoming messages are ignored."""
        ws_clients_2.add(websocket)
        try:
            async for _ in websocket:
                pass
        finally:
            ws_clients_2.discard(websocket)

    @staticmethod
    async def _send_or_report(ws, data):
        """Send `data` to one client; return the client if sending failed, else None."""
        try:
            await ws.send(data)
        except Exception:  # closed/broken connection; CancelledError (BaseException) still propagates
            return ws
        return None

    async def _broadcast(self, msg):
        """Serialize `msg` to JSON and send it to every connected client, dropping dead ones."""
        try:
            data = json.dumps(msg)
        except (TypeError, ValueError) as e:
            print(f"[ERROR] Cannot serialize WebSocket message: {e}")
            return
        # Send concurrently so one slow client does not hold up the others.
        results = await asyncio.gather(*(self._send_or_report(ws, data) for ws in list(ws_clients_2)))
        for dead in results:
            if dead is not None:
                ws_clients_2.discard(dead)

    async def _broadcaster(self):
        """Forward queued messages to clients until stop_event is set."""
        while not self.stop_event.is_set():
            try:
                # Blocking get runs in a worker thread so the event loop keeps serving clients.
                msg = await asyncio.to_thread(self.send_queue.get, timeout=0.5)
            except queue.Empty:
                continue
            try:
                await self._broadcast(msg)
            finally:
                # Always balance the get() so Queue.join() callers cannot hang.
                self.send_queue.task_done()

    async def _run_async(self):
        async with websockets.serve(self._ws_handler, WS_HOST, WS_PORT_2):
            print(f"[INFO] WebSocket server 2 started at ws://{WS_HOST}:{WS_PORT_2}")
            await self._broadcaster()

    def run(self):
        asyncio.run(self._run_async())


# ========================= RTSP capture thread =========================
class RTSPCaptureWorker(threading.Thread):
    """Read frames from one RTSP stream and push them onto the shared raw queue.

    Each queued item is a dict with keys camera_id, camera_name, timestamp
    (time.time() at capture) and frame. The stream is reopened after
    MAX_CONSECUTIVE_READ_FAILURES failed reads in a row. Opening is retried
    every REOPEN_DELAY_SEC until stop_event is set.
    """

    READ_RETRY_DELAY_SEC = 0.2          # pause after a single failed read
    MAX_CONSECUTIVE_READ_FAILURES = 25  # about 25 * 0.2 s of failures before reconnecting
    REOPEN_DELAY_SEC = 5.0              # pause between failed open attempts

    def __init__(self, camera_cfg: "CameraConfig", raw_queue: queue.Queue, stop_event: threading.Event):
        super().__init__(daemon=True)
        self.camera_cfg = camera_cfg
        self.raw_queue = raw_queue
        self.stop_event = stop_event

    def _open_capture(self):
        """Open the stream with the FFmpeg backend; return None if it cannot be opened."""
        cap = cv2.VideoCapture(self.camera_cfg.rtsp_url, cv2.CAP_FFMPEG)
        if cap.isOpened():
            return cap
        cap.release()
        return None

    def run(self):
        cfg = self.camera_cfg
        cap = None
        failures = 0
        try:
            while not self.stop_event.is_set():
                if cap is None:
                    cap = self._open_capture()
                    if cap is None:
                        print(f"[ERROR] Cannot open RTSP: {cfg.rtsp_url}")
                        self.stop_event.wait(self.REOPEN_DELAY_SEC)
                        continue
                    print(f"[INFO] Capturing {cfg.name} (ID:{cfg.id})")
                    failures = 0

                ret, frame = cap.read()
                if not ret:
                    failures += 1
                    if failures >= self.MAX_CONSECUTIVE_READ_FAILURES:
                        # A dropped RTSP session does not recover by itself; reopen it.
                        print(f"[WARN] {cfg.name} (ID:{cfg.id}): {failures} consecutive read failures, reconnecting")
                        cap.release()
                        cap = None
                    else:
                        self.stop_event.wait(self.READ_RETRY_DELAY_SEC)
                    continue
                failures = 0

                item = {
                    "camera_id": cfg.id,
                    "camera_name": cfg.name,
                    "timestamp": time.time(),
                    "frame": frame,
                }
                try:
                    self.raw_queue.put(item, timeout=1.0)
                except queue.Full:
                    pass  # processor is behind; dropping this frame is acceptable
        finally:
            if cap is not None:
                cap.release()


# ========================= Frame processing thread =========================
class FrameProcessorWorker(threading.Thread):
    """Consume raw frames, run per-camera detection and publish results over WebSocket.

    Frames are down-sampled to RTSP_TARGET_FPS per camera. Every processed
    frame goes to ws_send_queue. Frames that carry at least one alert action
    not pushed within ALERT_PUSH_INTERVAL also go to ws_send_queue_2.
    """

    def __init__(self, raw_frame_queue: "queue.Queue[Dict[str, Any]]", ws_send_queue: "queue.Queue[Dict[str, Any]]", ws_send_queue_2: "queue.Queue[Dict[str, Any]]", stop_event: threading.Event):
        super().__init__(daemon=True)
        self.raw_queue = raw_frame_queue
        self.ws_queue = ws_send_queue
        self.ws_queue_2 = ws_send_queue_2
        self.stop_event = stop_event
        self.last_ts: Dict[int, float] = {}  # camera id -> timestamp of the last processed frame
        # One independent KadianDetector per camera (tracker and alert state are per stream).
        self.kadian_detectors: Dict[int, KadianDetector] = {}
        # camera id -> {alert action -> time.time() of its last push}
        self.last_alert_push_time: Dict[int, Dict[str, float]] = {}

    def _encode_image_to_base64(self, image) -> str:
        """JPEG-encode `image` and return it as an ASCII base64 string; raise RuntimeError on failure."""
        ok, buf = cv2.imencode(".jpg", image)
        if not ok:
            raise RuntimeError("Failed to encode image to JPEG")
        return base64.b64encode(buf.tobytes()).decode("ascii")

    def _select_push_actions(self, cam_id, alerts):
        """Return the alert actions that may be pushed now, and record their push time.

        An action is pushed at most once per ALERT_PUSH_INTERVAL seconds per camera.
        """
        last_push = self.last_alert_push_time.setdefault(cam_id, {})
        now = time.time()
        push_actions = []
        for alert in alerts:
            action = alert['action']
            if now - last_push.get(action, 0) >= ALERT_PUSH_INTERVAL:
                push_actions.append(action)
                last_push[action] = now
        return push_actions

    def _publish(self, msg, push_actions):
        """Queue `msg` for server 1, and also for server 2 when it carries pushed alerts.

        Each queue is handled independently, so a full frame queue cannot
        swallow an alert.
        """
        try:
            self.ws_queue.put(msg, timeout=1.0)
        except queue.Full:
            print("[WARN] ws_send_queue full, drop frame message")
        if push_actions:
            try:
                self.ws_queue_2.put(msg, timeout=1.0)
            except queue.Full:
                print("[WARN] ws_send_queue_2 full, drop alert message")

    def _handle_item(self, item, target_interval):
        """Throttle, detect, encode and publish one raw-queue item."""
        cam_id = item["camera_id"]
        ts = item["timestamp"]
        frame = item["frame"]

        # Frame sampling: skip frames arriving faster than RTSP_TARGET_FPS for this camera.
        if ts - self.last_ts.get(cam_id, 0) < target_interval:
            return
        self.last_ts[cam_id] = ts

        detector = self.kadian_detectors.get(cam_id)
        if detector is None:
            detector = self.kadian_detectors[cam_id] = KadianDetector()

        result = detector.process_frame(frame.copy(), cam_id, ts)
        push_actions = self._select_push_actions(cam_id, result["alerts"])

        try:
            img_b64 = self._encode_image_to_base64(result["image"])
        except Exception as e:
            print(f"[ERROR] Encode image failed: {e}")
            return

        msg = {
            "msg_type": "frame",
            "camera_id": cam_id,  # was hard-coded to 1, making cameras indistinguishable
            "timestamp": ts,
            "result_type": push_actions,
            "image_base64": img_b64,
        }
        self._publish(msg, push_actions)

    def run(self):
        target_interval = 1.0 / RTSP_TARGET_FPS
        while not self.stop_event.is_set():
            try:
                item = self.raw_queue.get(timeout=0.5)
            except queue.Empty:
                continue
            try:
                self._handle_item(item, target_interval)
            except Exception as e:
                # One bad frame must not kill processing for every camera.
                print(f"[ERROR] Frame processing failed for camera {item.get('camera_id')}: {e}")
            finally:
                self.raw_queue.task_done()


# ========================= Service =========================
class RTSPService:
    """Wire the capture, processing and WebSocket threads together for all configured cameras."""

    def __init__(self, config_path: str = "config.yaml"):
        with open(config_path, "r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f) or {}  # an empty file loads as None
        self.cameras = [
            CameraConfig(id=c["id"], name=c.get("name", f"cam_{c['id']}"), rtsp_url=c["rtsp_url"])
            for c in cfg.get("cameras", [])
        ]
        self.stop_event = threading.Event()
        self.raw_queue = queue.Queue(maxsize=500)
        self.ws_queue = queue.Queue(maxsize=1000)
        self.ws_queue_2 = queue.Queue(maxsize=1000)  # alert-only queue for WebSocket server 2
        self.capture_workers = []
        self.processor = FrameProcessorWorker(self.raw_queue, self.ws_queue, self.ws_queue_2, self.stop_event)
        self.ws_sender = WebSocketSender(self.ws_queue, self.stop_event)
        self.ws_sender_2 = WebSocketSender2(self.ws_queue_2, self.stop_event)

    def start(self):
        self.ws_sender.start()
        self.ws_sender_2.start()
        self.processor.start()
        for cam in self.cameras:
            w = RTSPCaptureWorker(cam, self.raw_queue, self.stop_event)
            w.start()
            self.capture_workers.append(w)
        print("[INFO] Kadian RTSP Service started")

    def stop(self):
        """Signal all threads to stop and wait briefly for each of them.

        The queues are deliberately not join()ed. Every consumer loop exits as
        soon as stop_event is set, so items still queued would never be
        task_done()'d and Queue.join() would block forever.
        """
        self.stop_event.set()
        for w in self.capture_workers:
            w.join(timeout=2.0)
        self.processor.join(timeout=2.0)
        self.ws_sender.join(timeout=2.0)
        self.ws_sender_2.join(timeout=2.0)
        print("[INFO] Service stopped")


if __name__ == "__main__":
    service = RTSPService("config.yaml")
    service.start()
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        service.stop()