363 lines
13 KiB
Python
363 lines
13 KiB
Python
# rtsp_service_kadian.py
|
||
# 融合 Kadian_Detect_1221.py + rtsp_service_ws.py
|
||
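# Supports multiple RTSP streams, frame sampling, and WebSocket push of images and alerts (segmented MP4 saving is not implemented in this file)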
|
||
|
||
import cv2
|
||
|
||
import os
|
||
import time
|
||
import threading
|
||
import queue
|
||
import yaml
|
||
import json
|
||
import base64
|
||
import asyncio
|
||
import websockets
|
||
from dataclasses import dataclass
|
||
from typing import Dict, Any
|
||
|
||
from biz.checkpoint.checkpoint_biz import KadianDetector, RTSP_TARGET_FPS, ALERT_PUSH_INTERVAL
|
||
from test_cam import get_camera_preview_url
|
||
|
||
# Bind address and port for the WebSocket push server.
WS_HOST: str = "0.0.0.0"
WS_PORT: int = 8765

# Live WebSocket connections. Added/removed by WebSocketSender._ws_handler and
# pruned by WebSocketSender._broadcaster when sending to a client fails.
ws_clients: set = set()
|
||
|
||
|
||
# ========================= 数据结构 =========================
|
||
@dataclass
class CameraConfig:
    """One camera entry taken from the ``cameras`` list of the YAML config."""

    id: int        # camera identifier; also the per-camera key used by the workers
    name: str      # human-readable label used in log messages
    index: str     # video ID passed to get_camera_preview_url() to obtain a stream URL
    rtsp_url: str  # URL from the config file; NOTE(review): the capture worker does not read it
|
||
|
||
|
||
# ========================= WebSocket 服务线程 =========================
|
||
class WebSocketSender(threading.Thread):
    """Background thread running a WebSocket server that broadcasts queued messages.

    Every item taken from ``send_queue`` is JSON-encoded and sent to all clients in
    the module-level ``ws_clients`` set. The thread runs its own asyncio event loop
    (created by ``asyncio.run`` in :meth:`run`) and exits once ``stop_event`` is set.
    """

    def __init__(self, send_queue: queue.Queue, stop_event: threading.Event):
        super().__init__(daemon=True)
        self.send_queue = send_queue
        self.stop_event = stop_event

    async def _ws_handler(self, websocket):
        """Register a client for the lifetime of its connection.

        Incoming messages are read and discarded; iterating only keeps the
        connection open until the client disconnects.
        """
        ws_clients.add(websocket)
        try:
            async for _ in websocket:
                pass
        finally:
            ws_clients.discard(websocket)

    async def _broadcast(self, msg):
        """JSON-encode *msg* and send it to every connected client concurrently.

        Clients whose send raises are removed from ``ws_clients``.
        """
        clients = list(ws_clients)
        if not clients:
            # Nobody is listening: skip encoding the (large, base64-image) payload.
            return
        data = json.dumps(msg)
        # Send concurrently so a single slow client does not delay all others;
        # return_exceptions keeps one failing client from aborting the rest.
        results = await asyncio.gather(
            *(ws.send(data) for ws in clients), return_exceptions=True
        )
        for ws, result in zip(clients, results):
            if isinstance(result, Exception):
                ws_clients.discard(ws)

    async def _broadcaster(self):
        """Forward messages from ``send_queue`` to all clients until stop is requested."""
        while not self.stop_event.is_set():
            try:
                # The blocking Queue.get runs in a worker thread so the event loop
                # (and the server's connection handlers) keep running; the timeout
                # lets this loop notice stop_event.
                msg = await asyncio.to_thread(self.send_queue.get, timeout=0.5)
            except queue.Empty:
                continue
            try:
                await self._broadcast(msg)
            except Exception as e:
                # One bad message (e.g. not JSON-serialisable) must not kill the
                # broadcaster and silently stop all pushes.
                print(f"[ERROR] WebSocket broadcast failed: {e}")
            finally:
                # Always acknowledge the item so unfinished-task accounting stays correct.
                self.send_queue.task_done()

    async def _run_async(self):
        """Serve WebSocket connections while broadcasting until stopped."""
        async with websockets.serve(self._ws_handler, WS_HOST, WS_PORT):
            print(f"[INFO] WebSocket server started at ws://{WS_HOST}:{WS_PORT}")
            await self._broadcaster()

    def run(self):
        """Thread entry point: run the server on a fresh asyncio event loop."""
        asyncio.run(self._run_async())
|
||
|
||
|
||
|
||
# ========================= RTSP 抓流线程 =========================
|
||
class RTSPCaptureWorker(threading.Thread):
    """Thread that reads frames from one camera's RTSP stream into a shared queue.

    The stream URL is obtained with :meth:`refresh_video_url` (which uses
    ``camera_cfg.index``). After ``max_reconnects`` consecutive open/read failures
    the URL is fetched again before the next attempt.
    """

    # FFmpeg options OpenCV applies when it opens a capture: force TCP transport,
    # enlarge the receive buffer, and bound delays/socket timeouts (max_delay and
    # stimeout are in microseconds per the FFmpeg RTSP demuxer docs).
    _FFMPEG_CAPTURE_OPTIONS = (
        "rtsp_transport;tcp|buffer_size;1024000|max_delay;500000|stimeout;2000000"
    )

    def __init__(self, camera_cfg: CameraConfig, raw_queue: queue.Queue, stop_event: threading.Event):
        super().__init__(daemon=True)
        self.camera_cfg = camera_cfg
        self.raw_queue = raw_queue
        self.stop_event = stop_event
        # Consecutive open/read failures; reaching max_reconnects triggers a URL refresh.
        self.reconnect_count = 0
        self.max_reconnects = 5
        # Current stream URL. It starts empty, so the first iteration fetches it via
        # refresh_video_url(); camera_cfg.rtsp_url is not consulted here.
        self.rtsp_url = ""

    def run(self):
        """Open the stream, read frames, and reconnect until stop_event is set."""
        while not self.stop_event.is_set():
            try:
                if not self._ensure_url():
                    self.stop_event.wait(5)
                    continue

                cap = self._open_capture()
                if not cap.isOpened():
                    cap.release()
                    print(f"[ERROR] Cannot open RTSP: {self.rtsp_url}")
                    self.reconnect_count += 1
                    self.stop_event.wait(2)
                    continue

                # Release the capture even if reading raises.
                try:
                    print(f"[INFO] Successfully opened RTSP: {self.camera_cfg.name}")
                    self._read_frames(cap)
                finally:
                    cap.release()
            except Exception as e:
                # Counted as a failure; the next iteration refreshes the URL once
                # max_reconnects is reached (the thread keeps running).
                print(f"[ERROR] Error in RTSP capture for {self.camera_cfg.name}: {e}")
                self.reconnect_count += 1
                self.stop_event.wait(2)

    def _ensure_url(self):
        """Refresh self.rtsp_url when needed; return False if no URL is available yet."""
        if self.reconnect_count >= self.max_reconnects:
            print(f"[WARN] RTSP: {self.camera_cfg.name} reach max reconnects, refresh url")
            self.reconnect_count = 0
            new_url = self.refresh_video_url()
            if new_url:
                self.rtsp_url = new_url
            else:
                # Keep using the previous URL.
                print("[ERROR] refresh RTSP URL is empty, do nothing")

        if not self.rtsp_url:
            print("[WARN] RTSP URL is empty, refreshing...")
            new_url = self.refresh_video_url()
            if not new_url:
                print("[ERROR] RTSP URL is still empty, retrying in 5 seconds")
                return False
            self.rtsp_url = new_url
        return True

    def _open_capture(self):
        """Create a cv2.VideoCapture (FFmpeg backend) for the current URL."""
        # Must be set BEFORE the capture is created: OpenCV reads this variable when
        # opening the stream, so setting it afterwards had no effect on the first open.
        os.environ["OPENCV_FFMPEG_CAPTURE_OPTIONS"] = self._FFMPEG_CAPTURE_OPTIONS

        url = self.rtsp_url
        # Kept from the original implementation. NOTE(review): this query parameter is
        # sent to the RTSP server; rtsp_transport above already forces TCP — confirm
        # the server needs it.
        url += "&transport=tcp" if "?" in url else "?transport=tcp"

        cap = cv2.VideoCapture(url, cv2.CAP_FFMPEG)
        # NOTE(review): not every backend honours CAP_PROP_BUFFERSIZE — confirm for FFmpeg.
        cap.set(cv2.CAP_PROP_BUFFERSIZE, 10)
        return cap

    def _read_frames(self, cap):
        """Read frames from an opened capture until a read fails or stop is requested."""
        first_frame = True
        while not self.stop_event.is_set():
            ret, frame = cap.read()
            if not ret:
                print(f"[WARN] Failed to read frame from {self.camera_cfg.name}")
                # Count the failure so a stream that opens but never (or no longer)
                # delivers frames eventually triggers a URL refresh.
                self.reconnect_count += 1
                self.stop_event.wait(0.1)
                return

            if first_frame:
                # The stream is really delivering frames: clear the failure count.
                self.reconnect_count = 0
                first_frame = False

            self._enqueue({
                "camera_id": self.camera_cfg.id,
                "camera_name": self.camera_cfg.name,
                "timestamp": time.time(),
                "frame": frame,
            })

            # Throttle reads to at most one frame per ~20 ms; returns early on stop.
            self.stop_event.wait(0.02)

    def _enqueue(self, item):
        """Put *item* on the shared raw queue, evicting the oldest entry when full."""
        if self.raw_queue.full():
            # Drop the oldest queued frame (possibly from another camera) and
            # acknowledge it so the queue's unfinished-task count stays balanced.
            try:
                self.raw_queue.get_nowait()
                self.raw_queue.task_done()
            except queue.Empty:
                pass
        try:
            self.raw_queue.put(item, timeout=0.5)
        except queue.Full:
            print(f"[WARN] Queue full, dropping frame from {self.camera_cfg.name}")

    def refresh_video_url(self):
        """Fetch a fresh stream URL for this camera via get_camera_preview_url.

        ``camera_cfg.index`` is used as the video ID. The response is expected to
        contain the URL at ``result['data']['url']``.

        Returns:
            The new URL string, or None if the lookup failed or raised.
        """
        try:
            result = get_camera_preview_url(self.camera_cfg.index)
            if 'data' in result and 'url' in result['data']:
                new_url = result['data']['url']
                print(f"[INFO] get rtsp url success, URL: {new_url}")
                return new_url
            print(f"[ERROR] get rtsp url failed: {result}")
            return None
        except Exception as e:
            print(f"[ERROR] get rtsp url error: {str(e)}")
            return None
|
||
|
||
|
||
# ========================= 帧处理线程 =========================
|
||
class FrameProcessorWorker(threading.Thread):
    """Consumes raw frames, runs per-camera Kadian detection, and queues WebSocket messages.

    Frames are sampled to at most RTSP_TARGET_FPS per camera. Each detection alert
    carries an ``action``; a given action is reported for a camera at most once per
    ALERT_PUSH_INTERVAL seconds.
    """

    def __init__(self, raw_queue: queue.Queue, ws_queue: queue.Queue, stop_event: threading.Event):
        super().__init__(daemon=True)
        self.raw_queue = raw_queue
        self.ws_queue = ws_queue
        self.stop_event = stop_event

        # camera_id -> timestamp of the last frame processed (frame sampling).
        self.last_ts: Dict[int, float] = {}
        # One independent KadianDetector per camera, created lazily on its first frame.
        self.kadian_detectors: Dict[int, KadianDetector] = {}
        # camera_id -> {action -> wall-clock time it was last pushed}.
        self.last_alert_push_time: Dict[int, Dict[str, float]] = {}

    def _encode_base64(self, img):
        """JPEG-encode *img* and return it as an ASCII base64 string.

        Raises:
            ValueError: if OpenCV fails to encode the image.
        """
        ok, buf = cv2.imencode(".jpg", img)
        if not ok:
            raise ValueError("cv2.imencode failed to encode frame as JPEG")
        return base64.b64encode(buf).decode("ascii")

    def _select_push_actions(self, cam_id, alerts, now, interval):
        """Return the actions in *alerts* that may be pushed for *cam_id* at time *now*.

        An action qualifies if it was not pushed for this camera within the last
        *interval* seconds; qualifying actions get their last-push time set to *now*
        (so a repeated action within one call is reported once when interval > 0).
        Order follows *alerts*.
        """
        pushed = self.last_alert_push_time.setdefault(cam_id, {})
        actions = []
        for alert in alerts:
            action = alert["action"]
            if now - pushed.get(action, 0) >= interval:
                actions.append(action)
                pushed[action] = now
        return actions

    def _process_item(self, item, target_interval):
        """Run detection on one frame (if not sampled out) and queue the result message."""
        cam_id = item["camera_id"]
        ts = item["timestamp"]

        # Frame sampling: skip frames arriving sooner than target_interval after the
        # previously processed frame of the same camera.
        if ts - self.last_ts.get(cam_id, 0) < target_interval:
            return
        self.last_ts[cam_id] = ts

        if cam_id not in self.kadian_detectors:
            self.kadian_detectors[cam_id] = KadianDetector()
        detector = self.kadian_detectors[cam_id]

        result = detector.process_frame(item["frame"].copy(), cam_id, ts)
        push_actions = self._select_push_actions(
            cam_id, result["alerts"], time.time(), ALERT_PUSH_INTERVAL
        )

        try:
            img_b64 = self._encode_base64(result["image"])
        except Exception as e:
            print(f"[ERROR] Encode image failed: {e}")
            return

        msg = {
            "msg_type": "frame",
            # Report the real camera id (was hard-coded to 0, which made multi-camera
            # messages indistinguishable for clients).
            "camera_id": cam_id,
            "timestamp": ts,
            "result_type": push_actions,
            "image_base64": img_b64,
        }
        try:
            self.ws_queue.put(msg, timeout=1.0)
        except queue.Full:
            print("[WARN] ws_send_queue full, drop frame message")

    def run(self):
        """Process raw frames until stop_event is set."""
        target_interval = 1.0 / RTSP_TARGET_FPS
        while not self.stop_event.is_set():
            try:
                item = self.raw_queue.get(timeout=0.5)
            except queue.Empty:
                continue
            try:
                self._process_item(item, target_interval)
            except Exception as e:
                # A failure on one frame (detector error, malformed alert, ...) must not
                # kill the single processing thread shared by all cameras.
                print(f"[ERROR] Frame processing failed for camera {item.get('camera_id')}: {e}")
            finally:
                self.raw_queue.task_done()
|
||
|
||
|
||
# ========================= 服务主类 =========================
|
||
class RTSPService:
    """Wires together the capture, frame-processing, and WebSocket threads."""

    def __init__(self, config_path: str = "config.yaml"):
        """Load camera definitions from *config_path* and create (not start) the workers.

        The YAML file is expected to contain a ``cameras`` list whose entries have
        ``id``, ``index``, ``rtsp_url`` and optionally ``name``.
        """
        self.cameras = self._load_cameras(config_path)

        self.stop_event = threading.Event()
        # Bounded queues: capture -> processor (raw frames), processor -> WebSocket.
        self.raw_queue = queue.Queue(maxsize=500)
        self.ws_queue = queue.Queue(maxsize=1000)

        self.capture_workers = []
        self.processor = FrameProcessorWorker(self.raw_queue, self.ws_queue, self.stop_event)
        self.ws_sender = WebSocketSender(self.ws_queue, self.stop_event)

    @staticmethod
    def _load_cameras(config_path):
        """Parse the ``cameras`` section of the YAML config into CameraConfig objects."""
        with open(config_path, "r", encoding="utf-8") as f:
            # safe_load returns None for an empty file; treat that as "no cameras".
            cfg = yaml.safe_load(f) or {}
        return [
            CameraConfig(
                id=c["id"],
                name=c.get("name", f"cam_{c['id']}"),
                index=c["index"],
                rtsp_url=c["rtsp_url"],
            )
            # `or []` also covers an explicit `cameras:` entry with no value (None).
            for c in cfg.get("cameras") or []
        ]

    def start(self):
        """Start the WebSocket sender, the frame processor, then one capture worker per camera."""
        self.ws_sender.start()
        self.processor.start()
        for cam in self.cameras:
            w = RTSPCaptureWorker(cam, self.raw_queue, self.stop_event)
            w.start()
            self.capture_workers.append(w)
        print("[INFO] Kadian RTSP Service started")

    def stop(self):
        """Signal all threads to stop and wait (bounded) for them to exit.

        Queue.join() is deliberately not used: the consumer threads leave their loops
        as soon as stop_event is set, so items still queued (or put by capture workers
        during shutdown) would never be marked done and join() would block forever.
        """
        self.stop_event.set()
        for w in self.capture_workers:
            w.join(timeout=2.0)
        self.processor.join(timeout=2.0)
        self.ws_sender.join(timeout=2.0)
        print("[INFO] Service stopped")
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: start every service thread, then keep the main thread
    # alive in one-second naps until Ctrl+C raises KeyboardInterrupt.
    rtsp_service = RTSPService("config.yaml")
    rtsp_service.start()

    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        rtsp_service.stop()
|