diff --git a/.gitignore b/.gitignore index a81c8ee..6b94b7c 100644 --- a/.gitignore +++ b/.gitignore @@ -136,3 +136,10 @@ dmypy.json # Cython debug symbols cython_debug/ + +/test_data +/videos +/packages +/ONNX_Weight +/YOLO_Weight +/yolox \ No newline at end of file diff --git a/biz/checkpoint/checkpoint_biz.py b/biz/checkpoint/checkpoint_biz.py index 18998e0..bd0d816 100644 --- a/biz/checkpoint/checkpoint_biz.py +++ b/biz/checkpoint/checkpoint_biz.py @@ -1,7 +1,11 @@ import cv2 import numpy as np +import base64 from typing import Dict, Any +import threading +import time +import queue # -------------------------- Kadian 检测相关导入 -------------------------- from algorithm.checkpoint.npu_yolo_onnx_person_car_phone import YOLOv8_ONNX # 主检测模型(人/车/后备箱/手机) @@ -664,4 +668,99 @@ class KadianDetector: "image": frame, "alerts":current_frame_alerts - } \ No newline at end of file + } + + +# ========================= 帧处理线程 ========================= +class FrameProcessorWorker(threading.Thread): + def __init__(self, raw_queue: queue.Queue, ws_queue: queue.Queue, stop_event: threading.Event): + super().__init__(daemon=True) + self.raw_queue = raw_queue + self.ws_queue = ws_queue + self.stop_event = stop_event + + self.last_ts: Dict[int, float] = {} + + # 每个摄像头一个独立的 Kadian 检测器实例 + self.kadian_detectors: Dict[int, KadianDetector] = {} + + self.last_alert_push_time: Dict[int, Dict[str, float]] = {} + + def _encode_base64(self, img): + _, buf = cv2.imencode(".jpg", img) + return base64.b64encode(buf).decode("ascii") + + def run(self): + target_interval = 1.0 / RTSP_TARGET_FPS + while not self.stop_event.is_set(): + try: + item = self.raw_queue.get(timeout=0.5) + except queue.Empty: + continue + + cam_id = item["camera_id"] + ts = item["timestamp"] + frame = item["frame"] + + # 抽帧控制 + if ts - self.last_ts.get(cam_id, 0) < target_interval: + self.raw_queue.task_done() + continue + self.last_ts[cam_id] = ts + + # 获取检测器实例 + if cam_id not in self.kadian_detectors: + self.kadian_detectors[cam_id] = KadianDetector() + detector = self.kadian_detectors[cam_id] + + # 执行检测 + result = detector.process_frame(frame.copy(), cam_id, ts) + + result_img = result["image"] + result_type = result["alerts"] + # print(f"alerts: {result_type}") + + # ========= 核心修改:过滤5秒内重复的action ========= + # 初始化当前摄像头的推送时间记录 + if cam_id not in self.last_alert_push_time: + self.last_alert_push_time[cam_id] = {} + + # 筛选出符合推送条件的action(5秒内未推送过) + push_actions = [] + current_time = time.time() + for alert in result_type: + action = alert['action'] + last_push = self.last_alert_push_time[cam_id].get(action, 0) + # 检查是否超过推送间隔 + if current_time - last_push >= ALERT_PUSH_INTERVAL: + push_actions.append(action) + # 更新该action的最后推送时间 + self.last_alert_push_time[cam_id][action] = current_time + + # 通过 WebSocket 发送帧结果 + try: + img_b64 = self._encode_base64(result_img) + except Exception as e: + print(f"[ERROR] Encode image failed: {e}") + img_b64 = None + + if img_b64 is not None: + # 将abnormal_actions对象数组转换为字符串数组 + # action_names = [action_info['action'] for action_info in push_actions] + + msg = { + "msg_type": "frame", + "camera_id": 0, + "timestamp": ts, + # "result_type": action_names, + "result_type": push_actions, + "image_base64": img_b64, + } + try: + self.ws_queue.put(msg, timeout=1.0) + # if push_actions and len(push_actions) > 0: + # self.ws_queue_2.put(msg, timeout=1.0) + except queue.Full: + print("[WARN] ws_send_queue full, drop frame message") + + self.raw_queue.task_done() diff --git a/rtsp_service_ws_kadian.py b/rtsp_service_ws_kadian.py index cf83e90..9be0261 100644 --- a/rtsp_service_ws_kadian.py +++ b/rtsp_service_ws_kadian.py @@ -9,21 +9,17 @@ import time import threading import queue import yaml -import json -import base64 -import asyncio -import websockets -from dataclasses import dataclass -from typing import Dict, Any -from biz.checkpoint.checkpoint_biz import KadianDetector, RTSP_TARGET_FPS, ALERT_PUSH_INTERVAL +from dataclasses import dataclass + +from biz.checkpoint.checkpoint_biz import KadianDetector, RTSP_TARGET_FPS, ALERT_PUSH_INTERVAL, FrameProcessorWorker from test_cam import get_camera_preview_url +from utils.web_socket_sender import WebSocketSender WS_HOST = "0.0.0.0" WS_PORT = 8765 -# WebSocket 客户端集合 -ws_clients = set() + # ========================= 数据结构 ========================= @@ -35,47 +31,6 @@ class CameraConfig: rtsp_url: str -# ========================= WebSocket 服务线程 ========================= -class WebSocketSender(threading.Thread): - def __init__(self, send_queue: queue.Queue, stop_event: threading.Event): - super().__init__(daemon=True) - self.send_queue = send_queue - self.stop_event = stop_event - - async def _ws_handler(self, websocket): - ws_clients.add(websocket) - try: - async for _ in websocket: - pass - finally: - ws_clients.discard(websocket) - - async def _broadcaster(self): - while not self.stop_event.is_set(): - try: - msg = await asyncio.to_thread(self.send_queue.get, timeout=0.5) - except queue.Empty: - continue - data = json.dumps(msg) - dead = [] - for ws in list(ws_clients): - try: - await ws.send(data) - except: - dead.append(ws) - for ws in dead: - ws_clients.discard(ws) - self.send_queue.task_done() - - async def _run_async(self): - async with websockets.serve(self._ws_handler, WS_HOST, WS_PORT): - print(f"[INFO] WebSocket server started at ws://{WS_HOST}:{WS_PORT}") - await self._broadcaster() - - def run(self): - asyncio.run(self._run_async()) - - # ========================= RTSP 抓流线程 ========================= class RTSPCaptureWorker(threading.Thread): @@ -220,101 +175,6 @@ class RTSPCaptureWorker(threading.Thread): return None -# ========================= 帧处理线程 ========================= -class FrameProcessorWorker(threading.Thread): - def __init__(self, raw_queue: queue.Queue, ws_queue: queue.Queue, stop_event: threading.Event): - super().__init__(daemon=True) - self.raw_queue = raw_queue - self.ws_queue = ws_queue - self.stop_event = stop_event - - self.last_ts: Dict[int, float] = {} - - # 每个摄像头一个独立的 Kadian 检测器实例 - self.kadian_detectors: Dict[int, KadianDetector] = {} - - self.last_alert_push_time: Dict[int,Dict[str,float]]={} - - - def _encode_base64(self, img): - _, buf = cv2.imencode(".jpg", img) - return base64.b64encode(buf).decode("ascii") - - def run(self): - target_interval = 1.0 / RTSP_TARGET_FPS - while not self.stop_event.is_set(): - try: - item = self.raw_queue.get(timeout=0.5) - except queue.Empty: - continue - - cam_id = item["camera_id"] - ts = item["timestamp"] - frame = item["frame"] - - # 抽帧控制 - if ts - self.last_ts.get(cam_id, 0) < target_interval: - self.raw_queue.task_done() - continue - self.last_ts[cam_id] = ts - - # 获取检测器实例 - if cam_id not in self.kadian_detectors: - self.kadian_detectors[cam_id] = KadianDetector() - detector = self.kadian_detectors[cam_id] - - # 执行检测 - result = detector.process_frame(frame.copy(), cam_id, ts) - - result_img = result["image"] - result_type = result["alerts"] - #print(f"alerts: {result_type}") - - # ========= 核心修改:过滤5秒内重复的action ========= - # 初始化当前摄像头的推送时间记录 - if cam_id not in self.last_alert_push_time: - self.last_alert_push_time[cam_id] = {} - - # 筛选出符合推送条件的action(5秒内未推送过) - push_actions = [] - current_time = time.time() - for alert in result_type: - action = alert['action'] - last_push = self.last_alert_push_time[cam_id].get(action, 0) - # 检查是否超过推送间隔 - if current_time - last_push >= ALERT_PUSH_INTERVAL: - push_actions.append(action) - # 更新该action的最后推送时间 - self.last_alert_push_time[cam_id][action] = current_time - - # 通过 WebSocket 发送帧结果 - try: - img_b64 = self._encode_base64(result_img) - except Exception as e: - print(f"[ERROR] Encode image failed: {e}") - img_b64 = None - - if img_b64 is not None: - # 将abnormal_actions对象数组转换为字符串数组 - #action_names = [action_info['action'] for action_info in push_actions] - - msg = { - "msg_type": "frame", - "camera_id": 0, - "timestamp": ts, - #"result_type": action_names, - "result_type": push_actions, - "image_base64": img_b64, - } - try: - self.ws_queue.put(msg, timeout=1.0) - # if push_actions and len(push_actions) > 0: - # self.ws_queue_2.put(msg, timeout=1.0) - except queue.Full: - print("[WARN] ws_send_queue full, drop frame message") - - self.raw_queue.task_done() - # ========================= 服务主类 ========================= class RTSPService: @@ -330,7 +190,7 @@ class RTSPService: self.capture_workers = [] self.processor = FrameProcessorWorker(self.raw_queue, self.ws_queue, self.stop_event) - self.ws_sender = WebSocketSender(self.ws_queue, self.stop_event) + self.ws_sender = WebSocketSender(self.ws_queue, self.stop_event, WS_HOST, WS_PORT) def start(self): self.ws_sender.start() diff --git a/utils/web_socket_sender.py b/utils/web_socket_sender.py new file mode 100644 index 0000000..d90a90b --- /dev/null +++ b/utils/web_socket_sender.py @@ -0,0 +1,49 @@ + +import json +import asyncio +import websockets +import threading +import queue + +# ========================= WebSocket 服务线程 ========================= +class WebSocketSender(threading.Thread): + def __init__(self, send_queue: queue.Queue, stop_event: threading.Event, ws_host: str, ws_port: int): + super().__init__(daemon=True) + self.send_queue = send_queue + self.stop_event = stop_event + self.ws_clients = set() + self.ws_host = ws_host + self.ws_port = ws_port + + async def _ws_handler(self, websocket): + self.ws_clients.add(websocket) + try: + async for _ in websocket: + pass + finally: + self.ws_clients.discard(websocket) + + async def _broadcaster(self): + while not self.stop_event.is_set(): + try: + msg = await asyncio.to_thread(self.send_queue.get, timeout=0.5) + except queue.Empty: + continue + data = json.dumps(msg) + dead = [] + for ws in list(self.ws_clients): + try: + await ws.send(data) + except: + dead.append(ws) + for ws in dead: + self.ws_clients.discard(ws) + self.send_queue.task_done() + + async def _run_async(self): + async with websockets.serve(self._ws_handler, self.ws_host, self.ws_port): + print(f"[INFO] WebSocket server started at ws://{self.ws_host}:{self.ws_port}") + await self._broadcaster() + + def run(self): + asyncio.run(self._run_async())