继续整理卡点代码
This commit is contained in:
7
.gitignore
vendored
7
.gitignore
vendored
@@ -136,3 +136,10 @@ dmypy.json
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
/test_data
|
||||
/videos
|
||||
/packages
|
||||
/ONNX_Weight
|
||||
/YOLO_Weight
|
||||
/yolox
|
||||
@@ -1,7 +1,11 @@
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import base64
|
||||
from typing import Dict, Any
|
||||
import threading
|
||||
import time
|
||||
import queue
|
||||
|
||||
# -------------------------- Kadian 检测相关导入 --------------------------
|
||||
from algorithm.checkpoint.npu_yolo_onnx_person_car_phone import YOLOv8_ONNX # 主检测模型(人/车/后备箱/手机)
|
||||
@@ -664,4 +668,99 @@ class KadianDetector:
|
||||
"image": frame,
|
||||
|
||||
"alerts":current_frame_alerts
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ========================= 帧处理线程 =========================
|
||||
class FrameProcessorWorker(threading.Thread):
|
||||
def __init__(self, raw_queue: queue.Queue, ws_queue: queue.Queue, stop_event: threading.Event):
|
||||
super().__init__(daemon=True)
|
||||
self.raw_queue = raw_queue
|
||||
self.ws_queue = ws_queue
|
||||
self.stop_event = stop_event
|
||||
|
||||
self.last_ts: Dict[int, float] = {}
|
||||
|
||||
# 每个摄像头一个独立的 Kadian 检测器实例
|
||||
self.kadian_detectors: Dict[int, KadianDetector] = {}
|
||||
|
||||
self.last_alert_push_time: Dict[int, Dict[str, float]] = {}
|
||||
|
||||
def _encode_base64(self, img):
|
||||
_, buf = cv2.imencode(".jpg", img)
|
||||
return base64.b64encode(buf).decode("ascii")
|
||||
|
||||
def run(self):
|
||||
target_interval = 1.0 / RTSP_TARGET_FPS
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
item = self.raw_queue.get(timeout=0.5)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
cam_id = item["camera_id"]
|
||||
ts = item["timestamp"]
|
||||
frame = item["frame"]
|
||||
|
||||
# 抽帧控制
|
||||
if ts - self.last_ts.get(cam_id, 0) < target_interval:
|
||||
self.raw_queue.task_done()
|
||||
continue
|
||||
self.last_ts[cam_id] = ts
|
||||
|
||||
# 获取检测器实例
|
||||
if cam_id not in self.kadian_detectors:
|
||||
self.kadian_detectors[cam_id] = KadianDetector()
|
||||
detector = self.kadian_detectors[cam_id]
|
||||
|
||||
# 执行检测
|
||||
result = detector.process_frame(frame.copy(), cam_id, ts)
|
||||
|
||||
result_img = result["image"]
|
||||
result_type = result["alerts"]
|
||||
# print(f"alerts: {result_type}")
|
||||
|
||||
# ========= 核心修改:过滤5秒内重复的action =========
|
||||
# 初始化当前摄像头的推送时间记录
|
||||
if cam_id not in self.last_alert_push_time:
|
||||
self.last_alert_push_time[cam_id] = {}
|
||||
|
||||
# 筛选出符合推送条件的action(5秒内未推送过)
|
||||
push_actions = []
|
||||
current_time = time.time()
|
||||
for alert in result_type:
|
||||
action = alert['action']
|
||||
last_push = self.last_alert_push_time[cam_id].get(action, 0)
|
||||
# 检查是否超过推送间隔
|
||||
if current_time - last_push >= ALERT_PUSH_INTERVAL:
|
||||
push_actions.append(action)
|
||||
# 更新该action的最后推送时间
|
||||
self.last_alert_push_time[cam_id][action] = current_time
|
||||
|
||||
# 通过 WebSocket 发送帧结果
|
||||
try:
|
||||
img_b64 = self._encode_base64(result_img)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Encode image failed: {e}")
|
||||
img_b64 = None
|
||||
|
||||
if img_b64 is not None:
|
||||
# 将abnormal_actions对象数组转换为字符串数组
|
||||
# action_names = [action_info['action'] for action_info in push_actions]
|
||||
|
||||
msg = {
|
||||
"msg_type": "frame",
|
||||
"camera_id": 0,
|
||||
"timestamp": ts,
|
||||
# "result_type": action_names,
|
||||
"result_type": push_actions,
|
||||
"image_base64": img_b64,
|
||||
}
|
||||
try:
|
||||
self.ws_queue.put(msg, timeout=1.0)
|
||||
# if push_actions and len(push_actions) > 0:
|
||||
# self.ws_queue_2.put(msg, timeout=1.0)
|
||||
except queue.Full:
|
||||
print("[WARN] ws_send_queue full, drop frame message")
|
||||
|
||||
self.raw_queue.task_done()
|
||||
|
||||
@@ -9,21 +9,17 @@ import time
|
||||
import threading
|
||||
import queue
|
||||
import yaml
|
||||
import json
|
||||
import base64
|
||||
import asyncio
|
||||
import websockets
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any
|
||||
|
||||
from biz.checkpoint.checkpoint_biz import KadianDetector, RTSP_TARGET_FPS, ALERT_PUSH_INTERVAL
|
||||
from dataclasses import dataclass
|
||||
|
||||
from biz.checkpoint.checkpoint_biz import KadianDetector, RTSP_TARGET_FPS, ALERT_PUSH_INTERVAL, FrameProcessorWorker
|
||||
from test_cam import get_camera_preview_url
|
||||
from utils.web_socket_sender import WebSocketSender
|
||||
|
||||
WS_HOST = "0.0.0.0"
|
||||
WS_PORT = 8765
|
||||
|
||||
# WebSocket 客户端集合
|
||||
ws_clients = set()
|
||||
|
||||
|
||||
|
||||
# ========================= 数据结构 =========================
|
||||
@@ -35,47 +31,6 @@ class CameraConfig:
|
||||
rtsp_url: str
|
||||
|
||||
|
||||
# ========================= WebSocket 服务线程 =========================
|
||||
class WebSocketSender(threading.Thread):
|
||||
def __init__(self, send_queue: queue.Queue, stop_event: threading.Event):
|
||||
super().__init__(daemon=True)
|
||||
self.send_queue = send_queue
|
||||
self.stop_event = stop_event
|
||||
|
||||
async def _ws_handler(self, websocket):
|
||||
ws_clients.add(websocket)
|
||||
try:
|
||||
async for _ in websocket:
|
||||
pass
|
||||
finally:
|
||||
ws_clients.discard(websocket)
|
||||
|
||||
async def _broadcaster(self):
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
msg = await asyncio.to_thread(self.send_queue.get, timeout=0.5)
|
||||
except queue.Empty:
|
||||
continue
|
||||
data = json.dumps(msg)
|
||||
dead = []
|
||||
for ws in list(ws_clients):
|
||||
try:
|
||||
await ws.send(data)
|
||||
except:
|
||||
dead.append(ws)
|
||||
for ws in dead:
|
||||
ws_clients.discard(ws)
|
||||
self.send_queue.task_done()
|
||||
|
||||
async def _run_async(self):
|
||||
async with websockets.serve(self._ws_handler, WS_HOST, WS_PORT):
|
||||
print(f"[INFO] WebSocket server started at ws://{WS_HOST}:{WS_PORT}")
|
||||
await self._broadcaster()
|
||||
|
||||
def run(self):
|
||||
asyncio.run(self._run_async())
|
||||
|
||||
|
||||
|
||||
# ========================= RTSP 抓流线程 =========================
|
||||
class RTSPCaptureWorker(threading.Thread):
|
||||
@@ -220,101 +175,6 @@ class RTSPCaptureWorker(threading.Thread):
|
||||
return None
|
||||
|
||||
|
||||
# ========================= 帧处理线程 =========================
|
||||
class FrameProcessorWorker(threading.Thread):
|
||||
def __init__(self, raw_queue: queue.Queue, ws_queue: queue.Queue, stop_event: threading.Event):
|
||||
super().__init__(daemon=True)
|
||||
self.raw_queue = raw_queue
|
||||
self.ws_queue = ws_queue
|
||||
self.stop_event = stop_event
|
||||
|
||||
self.last_ts: Dict[int, float] = {}
|
||||
|
||||
# 每个摄像头一个独立的 Kadian 检测器实例
|
||||
self.kadian_detectors: Dict[int, KadianDetector] = {}
|
||||
|
||||
self.last_alert_push_time: Dict[int,Dict[str,float]]={}
|
||||
|
||||
|
||||
def _encode_base64(self, img):
|
||||
_, buf = cv2.imencode(".jpg", img)
|
||||
return base64.b64encode(buf).decode("ascii")
|
||||
|
||||
def run(self):
|
||||
target_interval = 1.0 / RTSP_TARGET_FPS
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
item = self.raw_queue.get(timeout=0.5)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
cam_id = item["camera_id"]
|
||||
ts = item["timestamp"]
|
||||
frame = item["frame"]
|
||||
|
||||
# 抽帧控制
|
||||
if ts - self.last_ts.get(cam_id, 0) < target_interval:
|
||||
self.raw_queue.task_done()
|
||||
continue
|
||||
self.last_ts[cam_id] = ts
|
||||
|
||||
# 获取检测器实例
|
||||
if cam_id not in self.kadian_detectors:
|
||||
self.kadian_detectors[cam_id] = KadianDetector()
|
||||
detector = self.kadian_detectors[cam_id]
|
||||
|
||||
# 执行检测
|
||||
result = detector.process_frame(frame.copy(), cam_id, ts)
|
||||
|
||||
result_img = result["image"]
|
||||
result_type = result["alerts"]
|
||||
#print(f"alerts: {result_type}")
|
||||
|
||||
# ========= 核心修改:过滤5秒内重复的action =========
|
||||
# 初始化当前摄像头的推送时间记录
|
||||
if cam_id not in self.last_alert_push_time:
|
||||
self.last_alert_push_time[cam_id] = {}
|
||||
|
||||
# 筛选出符合推送条件的action(5秒内未推送过)
|
||||
push_actions = []
|
||||
current_time = time.time()
|
||||
for alert in result_type:
|
||||
action = alert['action']
|
||||
last_push = self.last_alert_push_time[cam_id].get(action, 0)
|
||||
# 检查是否超过推送间隔
|
||||
if current_time - last_push >= ALERT_PUSH_INTERVAL:
|
||||
push_actions.append(action)
|
||||
# 更新该action的最后推送时间
|
||||
self.last_alert_push_time[cam_id][action] = current_time
|
||||
|
||||
# 通过 WebSocket 发送帧结果
|
||||
try:
|
||||
img_b64 = self._encode_base64(result_img)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Encode image failed: {e}")
|
||||
img_b64 = None
|
||||
|
||||
if img_b64 is not None:
|
||||
# 将abnormal_actions对象数组转换为字符串数组
|
||||
#action_names = [action_info['action'] for action_info in push_actions]
|
||||
|
||||
msg = {
|
||||
"msg_type": "frame",
|
||||
"camera_id": 0,
|
||||
"timestamp": ts,
|
||||
#"result_type": action_names,
|
||||
"result_type": push_actions,
|
||||
"image_base64": img_b64,
|
||||
}
|
||||
try:
|
||||
self.ws_queue.put(msg, timeout=1.0)
|
||||
# if push_actions and len(push_actions) > 0:
|
||||
# self.ws_queue_2.put(msg, timeout=1.0)
|
||||
except queue.Full:
|
||||
print("[WARN] ws_send_queue full, drop frame message")
|
||||
|
||||
self.raw_queue.task_done()
|
||||
|
||||
|
||||
# ========================= 服务主类 =========================
|
||||
class RTSPService:
|
||||
@@ -330,7 +190,7 @@ class RTSPService:
|
||||
|
||||
self.capture_workers = []
|
||||
self.processor = FrameProcessorWorker(self.raw_queue, self.ws_queue, self.stop_event)
|
||||
self.ws_sender = WebSocketSender(self.ws_queue, self.stop_event)
|
||||
self.ws_sender = WebSocketSender(self.ws_queue, self.stop_event, WS_HOST, WS_PORT)
|
||||
|
||||
def start(self):
|
||||
self.ws_sender.start()
|
||||
|
||||
49
utils/web_socket_sender.py
Normal file
49
utils/web_socket_sender.py
Normal file
@@ -0,0 +1,49 @@
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
import websockets
|
||||
import threading
|
||||
import queue
|
||||
|
||||
# ========================= WebSocket 服务线程 =========================
|
||||
class WebSocketSender(threading.Thread):
|
||||
def __init__(self, send_queue: queue.Queue, stop_event: threading.Event, ws_host: str, ws_port: int):
|
||||
super().__init__(daemon=True)
|
||||
self.send_queue = send_queue
|
||||
self.stop_event = stop_event
|
||||
self.ws_clients = set()
|
||||
self.ws_host = ws_host
|
||||
self.ws_port = ws_port
|
||||
|
||||
async def _ws_handler(self, websocket):
|
||||
self.ws_clients.add(websocket)
|
||||
try:
|
||||
async for _ in websocket:
|
||||
pass
|
||||
finally:
|
||||
self.ws_clients.discard(websocket)
|
||||
|
||||
async def _broadcaster(self):
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
msg = await asyncio.to_thread(self.send_queue.get, timeout=0.5)
|
||||
except queue.Empty:
|
||||
continue
|
||||
data = json.dumps(msg)
|
||||
dead = []
|
||||
for ws in list(self.ws_clients):
|
||||
try:
|
||||
await ws.send(data)
|
||||
except:
|
||||
dead.append(ws)
|
||||
for ws in dead:
|
||||
self.ws_clients.discard(ws)
|
||||
self.send_queue.task_done()
|
||||
|
||||
async def _run_async(self):
|
||||
async with websockets.serve(self._ws_handler, self.ws_host, self.ws_port):
|
||||
print(f"[INFO] WebSocket server started at ws://{self.ws_host}:{self.ws_port}")
|
||||
await self._broadcaster()
|
||||
|
||||
def run(self):
|
||||
asyncio.run(self._run_async())
|
||||
Reference in New Issue
Block a user