Initial commit for test
This commit is contained in:
837
AIMonitor/rtsp_service_ws_new.py
Normal file
837
AIMonitor/rtsp_service_ws_new.py
Normal file
@@ -0,0 +1,837 @@
|
||||
# rtsp_service_ws.py (merged with YOLO + MMAction2 logic)
|
||||
import cv2
|
||||
import time
|
||||
import threading
|
||||
import queue
|
||||
import yaml
|
||||
import os
|
||||
import json
|
||||
import base64
|
||||
import asyncio
|
||||
import websockets
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
|
||||
# --- 新增依赖(YOLO/ONNX/跟踪/NumPy) ---
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
|
||||
# Optional dependency: BYTETracker from the yolox package.
# Make sure the yolox package is on PYTHONPATH; if it is missing,
# tracking is disabled (trackers stay None) and only detection runs.
try:
    from yolox.tracker.byte_tracker import BYTETracker
except Exception as e:
    BYTETracker = None
    print(f"[WARN] 无法导入 BYTETracker: {e}. 请确保 yolox 已安装,或提供替代跟踪实现.")

# Optional dependency: YOLOv8 ONNX wrapper from the uploaded npu_yolo_onnx.py.
# The module must be importable from the same directory or PYTHONPATH;
# if it is missing, init_models_once() logs an error and yolo_model stays None.
try:
    from npu_yolo_onnx import YOLOv8_ONNX
except Exception as e:
    YOLOv8_ONNX = None
    print(f"[WARN] 无法导入 YOLOv8_ONNX(npu_yolo_onnx.py):{e}")
|
||||
|
||||
# =========================
# Configuration & data structures
# =========================

@dataclass
class CameraConfig:
    """Static configuration for one RTSP camera (loaded from config.yaml)."""
    id: int        # unique camera id; keys every per-camera state dict below
    name: str      # human-readable camera name
    rtsp_url: str  # RTSP stream URL opened by RTSPCaptureWorker
|
||||
|
||||
|
||||
RTSP_TARGET_FPS = 10.0          # fixed processing rate: 10 frames/second
FRAMES_PER_SEGMENT = 600        # one mp4 segment per 600 frames
VIDEO_OUTPUT_DIR = "./videos"   # output directory for recorded segments
WS_HOST = "0.0.0.0"             # WebSocket server bind address
WS_PORT = 8765                  # WebSocket server port

# Set of currently connected WebSocket clients (managed by WebSocketSender).
ws_clients = set()

# =========================
# YOLO / action recognition / tracking configuration
# =========================

# --- Adjust the ONNX model paths below to match your deployment ---
YOLO_ONNX_PATH = "YOLO_Weight/best.onnx"  # <-- change to actual path
SUPERVISOR_ONNX = "ONNX_Weight/Supervisor.onnx"  # <-- change to actual path
SUSPECT_ONNX = "ONNX_Weight/Suspect.onnx"  # <-- change to actual path

# Action labels (from Ascend_NPU_YOLO_TSM_RealTime.py).
LABELS_SUPERVISOR = {0: 'Normal', 1: 'Push', 2: 'Slap'}
LABELS_SUSPECT = {0: 'Collision', 1: 'Hanging', 2: 'Lyingdown', 3: 'Normal'}

# Hyper-parameters (kept in sync with the Ascend implementation).
CLIP_LEN = 32            # frames buffered per action-recognition clip
SLIDE_STEP = 16          # sliding-window step between consecutive clips
CONF_THRESH = 0.1        # minimum action confidence for alerting/drawing
EXPAND_RATIO = 0.4       # ROI expansion ratio around a tracked box
TARGET_SIZE = 224        # side length of the square ROI crop
YOLO_CONF_THRESH = 0.5   # YOLO detection confidence threshold
YOLO_IOU_THRESH = 0.45   # YOLO NMS IoU threshold
ACTION_COOLDOWN = 0.0    # seconds between repeated alerts for one track

# Tracker / clip buffers / alert bookkeeping, separated per camera.
trackers: Dict[int, Any] = {}  # camera_id -> BYTETracker instance
track_buffers: Dict[int, Dict[int, list]] = {}  # camera_id -> {track_id -> list of cv2 crops}
last_alert: Dict[int, Dict[int, float]] = {}  # camera_id -> {track_id -> last_alert_time}
track_role: Dict[int, Dict[int, str]] = {}  # camera_id -> {track_id -> role}
track_action_result: Dict[int, Dict[int, str]] = {}  # camera_id -> {track_id -> action string}

# Recently triggered actions for on-frame display (per camera; could be extended).
recent_actions: Dict[int, list] = {}  # camera_id -> list of recent actions
MAX_RECENT_ACTIONS = 3
ACTION_DISPLAY_DURATION = 2.0

# YOLO and action-recognition sessions (module-level singletons,
# populated once by init_models_once()).
yolo_model = None
sess_supervisor = None
sess_suspect = None
input_name_sup = None
input_name_sus = None
|
||||
|
||||
# =========================
# Model initialization (import check / ONNX session creation)
# =========================

def init_models_once():
    """Initialize the YOLO detector and both action-recognition ONNX sessions.

    Populates the module-level singletons ``yolo_model``, ``sess_supervisor``,
    ``sess_suspect``, ``input_name_sup`` and ``input_name_sus``. Any failure is
    logged and leaves the corresponding global as None; downstream code checks
    for None before use.
    """
    global yolo_model, sess_supervisor, sess_suspect, input_name_sup, input_name_sus
    # YOLO detector
    if YOLOv8_ONNX is None:
        print("[ERROR] YOLOv8_ONNX 未导入,无法初始化 YOLO 模型")
    else:
        try:
            yolo_model = YOLOv8_ONNX(YOLO_ONNX_PATH, conf_threshold=YOLO_CONF_THRESH, iou_threshold=YOLO_IOU_THRESH)
            print("[INFO] YOLO 模型初始化完成")

        except Exception as e:
            print(f"[ERROR] YOLO 模型初始化失败: {e}")
            yolo_model = None

    # -----------------------------
    # Action-recognition model initialization (the correct way to confirm
    # the provider: check get_providers() after session creation)
    # -----------------------------
    try:
        # Request CANN, but whether it is actually enabled must be
        # verified with get_providers() — the request list is not a guarantee.
        providers = [
            ("CANNExecutionProvider", {
                "device_id": 0,
                "arena_extend_strategy": "kNextPowerOfTwo",
                "npu_mem_limit": 16 * 1024 * 1024 * 1024,
                "precision_mode": "allow_fp32_to_fp16",
                "op_select_impl_mode": "high_precision",
                "enable_cann_graph": True,
            }),
            "CPUExecutionProvider",  # automatic fallback
        ]

        sess_supervisor = ort.InferenceSession(SUPERVISOR_ONNX, providers=providers)
        sess_suspect = ort.InferenceSession(SUSPECT_ONNX, providers=providers)

        sup_prov = sess_supervisor.get_providers()
        sus_prov = sess_suspect.get_providers()

        print("Supervisor Providers:", sup_prov)
        print("Suspect Providers:", sus_prov)

        if "CANNExecutionProvider" in sup_prov:
            print("[INFO] 动作识别模型:使用 CANNExecutionProvider(昇腾)")
        else:
            print("[INFO] 动作识别模型:使用 CPUExecutionProvider(非昇腾环境)")

    except Exception as e:
        print(f"[ERROR] 初始化动作识别模型失败: {e}")
        sess_supervisor = None
        sess_suspect = None

    # Cache the input tensor names so inference calls can build feed dicts.
    if sess_supervisor is not None:
        input_name_sup = sess_supervisor.get_inputs()[0].name
        print(f"[INFO] 监护人模型输入: {input_name_sup}")
    if sess_suspect is not None:
        input_name_sus = sess_suspect.get_inputs()[0].name
        print(f"[INFO] 被监护人模型输入: {input_name_sus}")

# Initialize exactly once, at import time.
init_models_once()
|
||||
|
||||
# =========================
|
||||
# 工具函数(IoU, preprocess_clip)
|
||||
# =========================
|
||||
|
||||
def compute_iou(box1, box2):
    """Intersection-over-union of two axis-aligned boxes given as (x1, y1, x2, y2)."""
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2
    # Overlap rectangle (clamped to zero when the boxes are disjoint).
    ix1, iy1 = max(ax1, bx1), max(ay1, by1)
    ix2, iy2 = min(ax2, bx2), min(ay2, by2)
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    # Individual areas, guarding against degenerate (inverted) boxes.
    area_a = max(0, ax2 - ax1) * max(0, ay2 - ay1)
    area_b = max(0, bx2 - bx1) * max(0, by2 - by1)
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0
|
||||
|
||||
def preprocess_clip(frames):
    """Build a (1, T, C, H, W) float32 clip tensor from a list of BGR crops.

    Mirrors Ascend_NPU_YOLO_TSM_RealTime.py: pad to CLIP_LEN frames, keep every
    second frame, resize so the shorter side is 256, center-crop 224x224,
    convert BGR->RGB, and normalize with ImageNet mean/std in 0-255 scale.
    """
    # Pad a short buffer with its last frame (or a black frame if empty).
    if len(frames) < CLIP_LEN:
        filler = frames[-1] if frames else np.zeros((TARGET_SIZE, TARGET_SIZE, 3), dtype=np.uint8)
        frames = frames + [filler] * (CLIP_LEN - len(frames))

    # Temporal subsampling: indices 0, 2, ..., CLIP_LEN-2 -> CLIP_LEN // 2 frames.
    sampled = [frames[i] for i in range(0, CLIP_LEN, 2)]

    processed = []
    for img in sampled:
        h, w = img.shape[:2]
        ratio = 256.0 / min(h, w)
        resized = cv2.resize(img, (int(w * ratio), int(h * ratio)))
        rh, rw = resized.shape[:2]
        # Center crop to 224x224.
        top, left = (rh - 224) // 2, (rw - 224) // 2
        cropped = resized[top:top + 224, left:left + 224]
        # BGR -> RGB, then HWC -> CHW as float32.
        chw = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB).transpose(2, 0, 1).astype(np.float32)
        processed.append(chw)

    clip = np.stack(processed)[np.newaxis]  # (1, T, C, H, W)

    # ImageNet statistics in 0-255 scale, broadcast over (N, T, C, H, W).
    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32).reshape(1, 1, 3, 1, 1)
    std = np.array([58.395, 57.12, 57.375], dtype=np.float32).reshape(1, 1, 3, 1, 1)

    return ((clip - mean) / std).astype(np.float32)
|
||||
|
||||
# =========================
|
||||
# WebSocket 服务线程(不变)
|
||||
# =========================
|
||||
|
||||
class WebSocketSender(threading.Thread):
    """Daemon thread that runs a WebSocket server and broadcasts queued messages.

    JSON-serializable dicts put on ``send_queue`` are fanned out to every
    client currently registered in the module-level ``ws_clients`` set.
    """

    def __init__(self, send_queue: "queue.Queue[Dict[str, Any]]", stop_event: threading.Event):
        super().__init__(daemon=True)
        self.send_queue = send_queue  # producer side is FrameProcessorWorker
        self.stop_event = stop_event  # cooperative shutdown flag

    async def _ws_handler(self, websocket):
        # Register the client, then drain (and ignore) inbound messages until
        # the client disconnects; actual broadcasting happens in _broadcaster.
        ws_clients.add(websocket)
        try:
            async for _ in websocket:
                pass
        finally:
            ws_clients.discard(websocket)

    async def _broadcaster(self):
        """Forward queued messages to all connected clients until stop is set."""
        while not self.stop_event.is_set():
            try:
                # queue.get blocks, so run it in a worker thread to avoid
                # stalling the event loop; the 0.5s timeout keeps the loop
                # responsive to stop_event.
                msg = await asyncio.to_thread(self.send_queue.get, timeout=0.5)
            except queue.Empty:
                continue

            data = json.dumps(msg)
            dead = []
            for ws in list(ws_clients):
                try:
                    await ws.send(data)
                except Exception:
                    # Send failed (client likely gone); collect for removal.
                    dead.append(ws)
            for ws in dead:
                ws_clients.discard(ws)
            self.send_queue.task_done()

    async def _run_async(self):
        # Serve connections and broadcast concurrently for the server's lifetime.
        async with websockets.serve(self._ws_handler, WS_HOST, WS_PORT):
            print(f"[INFO] WebSocket server started at ws://{WS_HOST}:{WS_PORT}")
            await self._broadcaster()

    def run(self):
        # This thread owns its own event loop via asyncio.run.
        asyncio.run(self._run_async())
|
||||
|
||||
# =========================
# RTSP capture thread
# =========================

class RTSPCaptureWorker(threading.Thread):
    """Daemon thread that reads frames from one RTSP stream into a shared queue."""

    def __init__(
        self,
        camera_cfg: CameraConfig,
        raw_frame_queue: "queue.Queue[Dict[str, Any]]",
        stop_event: threading.Event,
    ):
        super().__init__(daemon=True)
        self.camera_cfg = camera_cfg
        self.raw_frame_queue = raw_frame_queue
        self.stop_event = stop_event

    def run(self):
        """Open the stream and push {camera_id, camera_name, timestamp, frame}
        dicts onto the raw frame queue until stop_event is set."""
        capture = cv2.VideoCapture(self.camera_cfg.rtsp_url, cv2.CAP_FFMPEG)
        if not capture.isOpened():
            print(f"[ERROR] Cannot open RTSP stream: {self.camera_cfg.rtsp_url}")
            return

        print(f"[INFO] Start capturing: id={self.camera_cfg.id}, name={self.camera_cfg.name}")

        while not self.stop_event.is_set():
            success, frame = capture.read()
            if not success:
                # Transient decode/network hiccup: back off briefly and retry.
                print(f"[WARN] Failed to read frame from camera {self.camera_cfg.id}, retrying...")
                time.sleep(0.2)
                continue

            try:
                self.raw_frame_queue.put(
                    {
                        "camera_id": self.camera_cfg.id,
                        "camera_name": self.camera_cfg.name,
                        "timestamp": time.time(),
                        "frame": frame,
                    },
                    timeout=1.0,
                )
            except queue.Full:
                # Consumer is behind; drop this frame rather than block capture.
                print(f"[WARN] Raw frame queue full, drop frame from camera {self.camera_cfg.id}")

        capture.release()
        print(f"[INFO] Stop capturing: id={self.camera_cfg.id}")
|
||||
|
||||
# =========================
# Frame-processing thread (FPS throttling + mp4 recording + user hook + WebSocket messages)
# =========================

class FrameProcessorWorker(threading.Thread):
    """Daemon thread: throttles frames to RTSP_TARGET_FPS, records mp4 segments,
    runs user_process_frame(), and queues frame/alert messages for WebSocket.

    Consumes raw frames (all cameras multiplexed) from ``raw_frame_queue`` and
    produces messages on ``ws_send_queue`` for WebSocketSender.
    """

    def __init__(
        self,
        raw_frame_queue: "queue.Queue[Dict[str, Any]]",
        ws_send_queue: "queue.Queue[Dict[str, Any]]",
        stop_event: threading.Event,
    ):
        super().__init__(daemon=True)
        self.raw_frame_queue = raw_frame_queue
        self.ws_send_queue = ws_send_queue
        self.stop_event = stop_event

        # Per-camera video recording state.
        self.video_writers: Dict[int, cv2.VideoWriter] = {}  # camera_id -> open writer
        self.video_frame_counts: Dict[int, int] = {}  # frames written to the current segment
        self.video_segment_start_ts: Dict[int, float] = {}  # segment start (epoch seconds)
        self.video_segment_filenames: Dict[int, str] = {}  # current segment file path

        os.makedirs(VIDEO_OUTPUT_DIR, exist_ok=True)

        # Per-camera timestamp of the last processed frame (10 fps throttling).
        self.last_process_ts: Dict[int, float] = {}

    def _get_video_writer(self, camera_id: int, frame) -> Tuple[cv2.VideoWriter, str]:
        """Return (writer, filepath) for the camera's active segment, opening a
        new timestamped mp4 segment if none is currently open."""
        writer = self.video_writers.get(camera_id)
        if writer is not None:
            return writer, self.video_segment_filenames[camera_id]

        # Segment dimensions come from the first frame written to it.
        h, w = frame.shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")

        start_ts = time.time()
        self.video_segment_start_ts[camera_id] = start_ts

        ts_str = time.strftime("%Y%m%d_%H%M%S", time.localtime(start_ts))
        filename = f"{ts_str}_cam{camera_id}.mp4"
        filepath = os.path.join(VIDEO_OUTPUT_DIR, filename)

        writer = cv2.VideoWriter(filepath, fourcc, RTSP_TARGET_FPS, (w, h))
        self.video_writers[camera_id] = writer
        self.video_frame_counts[camera_id] = 0
        self.video_segment_filenames[camera_id] = filepath

        print(f"[INFO] Start new segment: camera={camera_id}, file={filepath}")
        return writer, filepath

    def _close_segment_if_needed(self, camera_id: int):
        """Close the camera's segment once FRAMES_PER_SEGMENT frames were
        written; clearing the state makes the next frame open a new segment."""
        count = self.video_frame_counts.get(camera_id, 0)
        if count >= FRAMES_PER_SEGMENT:
            writer = self.video_writers.get(camera_id)
            if writer is not None:
                writer.release()
                print(f"[INFO] Close segment: camera={camera_id}, file={self.video_segment_filenames[camera_id]}")

            self.video_writers.pop(camera_id, None)
            self.video_frame_counts.pop(camera_id, None)
            self.video_segment_start_ts.pop(camera_id, None)
            self.video_segment_filenames.pop(camera_id, None)

    def _encode_image_to_base64(self, image) -> str:
        """JPEG-encode a BGR image and return it as an ASCII base64 string.

        Raises:
            RuntimeError: if cv2 fails to encode the image.
        """
        ok, buf = cv2.imencode(".jpg", image)
        if not ok:
            raise RuntimeError("Failed to encode image to JPEG")
        return base64.b64encode(buf.tobytes()).decode("ascii")

    def run(self):
        print("[INFO] FrameProcessorWorker started")
        target_interval = 1.0 / RTSP_TARGET_FPS  # min seconds between processed frames

        while not self.stop_event.is_set():
            try:
                item = self.raw_frame_queue.get(timeout=0.5)
            except queue.Empty:
                continue

            camera_id = item["camera_id"]
            ts = item["timestamp"]
            frame = item["frame"]

            # Throttle each camera to RTSP_TARGET_FPS: drop frames arriving
            # sooner than target_interval after the last processed one.
            last_ts = self.last_process_ts.get(camera_id, 0.0)
            if ts - last_ts < target_interval:
                self.raw_frame_queue.task_done()
                continue
            self.last_process_ts[camera_id] = ts

            # 1) Append the frame to the camera's current mp4 segment.
            writer, video_filepath = self._get_video_writer(camera_id, frame)
            writer.write(frame)
            self.video_frame_counts[camera_id] = self.video_frame_counts.get(camera_id, 0) + 1

            # 2) Run the user-defined per-frame processing hook.
            result = user_process_frame(frame, camera_id, ts)

            if result is not None and "image" in result and "type" in result:
                result_img = result["image"]
                result_type = int(result["type"])

                # 3) Publish the processed frame over WebSocket.
                try:
                    img_b64 = self._encode_image_to_base64(result_img)
                except Exception as e:
                    print(f"[ERROR] Encode image failed: {e}")
                    img_b64 = None

                if img_b64 is not None:
                    msg = {
                        "msg_type": "frame",
                        "camera_id": camera_id,
                        "timestamp": ts,
                        "result_type": result_type,
                        "image_base64": img_b64,
                    }
                    try:
                        self.ws_send_queue.put(msg, timeout=1.0)
                    except queue.Full:
                        print("[WARN] ws_send_queue full, drop frame message")

                # 4) A non-zero result_type signals an event: publish an alert too.
                if result_type != 0:
                    alert_msg = {
                        "msg_type": "alert",
                        "camera_id": camera_id,
                        "event_type": result_type,
                        "video_file": video_filepath,
                        "timestamp": ts,
                    }
                    try:
                        self.ws_send_queue.put(alert_msg, timeout=1.0)
                    except queue.Full:
                        print("[WARN] ws_send_queue full, drop alert message")

            # 5) Rotate to a new mp4 segment when the current one is full.
            self._close_segment_if_needed(camera_id)

            self.raw_frame_queue.task_done()

        # On shutdown, release every still-open VideoWriter.
        for cam_id, writer in list(self.video_writers.items()):
            writer.release()
            print(f"[INFO] Release writer on exit: camera={cam_id}")
        print("[INFO] FrameProcessorWorker stopped")
|
||||
|
||||
# =========================
# User-defined per-frame processing (YOLO + action recognition + tracking + alerts)
# =========================

def user_process_frame(image, camera_id: int, timestamp: float) -> Dict[str, Any]:
    """Run detection, tracking and action recognition on one frame.

    Pipeline (ported from Ascend_NPU_YOLO_TSM_RealTime.py):
      1. YOLO detection; class 0 is "supervisor", anything else "suspect".
      2. BYTETracker multi-object tracking (one tracker per camera).
      3. Per track: IoU-match a role, expand+crop the ROI, buffer CLIP_LEN
         crops, then run the role-specific ONNX action model on a sliding
         window (SLIDE_STEP).
      4. Alerting with a per-track cooldown, plus on-frame visualization.

    Args:
        image: BGR frame (numpy array) from cv2.
        camera_id: camera identifier; keys all per-camera module state.
        timestamp: capture time in seconds (time.time()).

    Returns:
        {"image": annotated BGR frame, "type": 0 (no alert) / 1 (alert)}.

    Cleanups vs. the original: removed a dead no-op loop over track_role
    items, the unused ``action_text`` variable, a redundant reassignment of
    ``dets_for_tracker``, and a redundant emptiness re-test inside the final
    ``any(...)``. Behavior is unchanged.
    """
    global trackers, track_buffers, last_alert, track_role, track_action_result, recent_actions
    global yolo_model, sess_supervisor, sess_suspect, input_name_sup, input_name_sus

    # --- Lazily create per-camera state on the first frame from this camera ---
    if camera_id not in trackers:
        if BYTETracker is not None:
            # Tracker hyper-parameters taken from the Ascend source.
            class TrackerArgs:
                track_thresh = 0.5
                track_buffer = 30
                match_thresh = 0.8
                mot20 = False
            try:
                trackers[camera_id] = BYTETracker(TrackerArgs(), frame_rate=RTSP_TARGET_FPS)
                print(f"[INFO] 初始化 BYTETracker for camera {camera_id}")
            except Exception as e:
                trackers[camera_id] = None
                print(f"[WARN] 无法初始化 BYTETracker: {e}")
        else:
            trackers[camera_id] = None
            print("[WARN] BYTETracker 未安装,跟踪功能不可用")

    if camera_id not in track_buffers:
        track_buffers[camera_id] = {}
    if camera_id not in last_alert:
        last_alert[camera_id] = {}
    if camera_id not in track_role:
        track_role[camera_id] = {}
    if camera_id not in track_action_result:
        track_action_result[camera_id] = {}
    if camera_id not in recent_actions:
        recent_actions[camera_id] = []

    frame = image  # BGR
    h, w = frame.shape[:2]

    # === 1. YOLO detection ===
    detections = []
    if yolo_model is not None:
        try:
            detections = yolo_model(frame)
            # detections format: list of [x1, y1, x2, y2, conf, cls_id]
        except Exception as e:
            print(f"[WARN] YOLO 推理失败: {e}")
            detections = []
    else:
        # No detector available: return the frame untouched, no alert.
        return {"image": frame, "type": 0}

    dets_xyxy = []
    dets_roles = []
    dets_for_tracker = []
    supervisor_count = 0
    suspect_count = 0

    for det in detections:
        x1, y1, x2, y2, conf, cls_id = det
        dets_xyxy.append([x1, y1, x2, y2])
        if int(cls_id) == 0:
            dets_roles.append("supervisor")
            supervisor_count += 1
        else:
            dets_roles.append("suspect")
            suspect_count += 1
        dets_for_tracker.append([x1, y1, x2, y2, conf])

    # (The original also assigned np.empty((0,5)) in an else-branch that was
    # immediately overwritten by the line below; that dead store is removed.)
    dets_for_tracker = np.array(dets_for_tracker, dtype=np.float32) if len(dets_for_tracker) > 0 else np.empty((0, 5))

    # === 2. Tracking (BYTETracker) ===
    tracker = trackers.get(camera_id)
    if tracker is None:
        # Without a tracker, just draw the raw detections and return (no alerts).
        for det in detections:
            x1, y1, x2, y2, conf, cls_id = det
            role = "supervisor" if int(cls_id) == 0 else "suspect"
            color = (255, 0, 0) if role == "supervisor" else (0, 0, 255)
            cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
            cv2.putText(frame, f"{role} {conf:.2f}", (int(x1) + 5, int(y1) - 6), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
        return {"image": frame, "type": 0}

    if dets_for_tracker.size == 0:
        dets_tensor = torch.zeros((0, 5))
    else:
        dets_tensor = torch.from_numpy(dets_for_tracker).float()

    # tracker.update(dets, ori_img_shape, img_shape), as called in the Ascend source.
    try:
        tracks = tracker.update(dets_tensor, [h, w], [h, w])
    except Exception as e:
        print(f"[WARN] tracker.update 出错: {e}")
        tracks = []

    current_time_sec = timestamp  # externally supplied capture time (seconds)

    current_frame_abnormal_actions = []

    # === 3. Per track: recover role via IoU, then run action recognition ===
    for t in tracks:
        try:
            tid = t.track_id
            x1, y1, x2, y2 = map(int, t.tlbr)
        except Exception:
            # Skip track objects we cannot parse.
            continue

        # Discard degenerate boxes.
        if x2 <= x1 or y2 <= y1:
            continue

        # Recover the role by IoU-matching this frame's detections (only when
        # the track has not been assigned a role yet).
        if tid not in track_role[camera_id] and dets_xyxy:
            best_iou = 0.0
            best_role = "unknown"
            track_box = [x1, y1, x2, y2]
            for i, det_box in enumerate(dets_xyxy):
                iou = compute_iou(track_box, det_box)
                if iou > best_iou:
                    best_iou = iou
                    best_role = dets_roles[i]
            if best_iou > 0.3:
                track_role[camera_id][tid] = best_role

        role = track_role[camera_id].get(tid, "unknown")

        # Expand the box by EXPAND_RATIO (clamped to the image) and crop the ROI.
        dw = int((x2 - x1) * EXPAND_RATIO)
        dh = int((y2 - y1) * EXPAND_RATIO)
        ex1, ey1 = max(0, x1 - dw), max(0, y1 - dh)
        ex2, ey2 = min(w, x2 + dw), min(h, y2 + dh)
        crop = frame[ey1:ey2, ex1:ex2]
        if crop.size == 0:
            continue
        crop = cv2.resize(crop, (TARGET_SIZE, TARGET_SIZE))

        # Append the crop to this track's clip buffer (bounded at CLIP_LEN).
        if tid not in track_buffers[camera_id]:
            track_buffers[camera_id][tid] = []
        track_buffers[camera_id][tid].append(crop)
        if len(track_buffers[camera_id][tid]) > CLIP_LEN:
            track_buffers[camera_id][tid] = track_buffers[camera_id][tid][-CLIP_LEN:]

        # Defaults until a full clip has been classified.
        conf_val = 0.0
        action_name = "Normal"

        # Run action recognition once the buffer holds a full clip.
        if len(track_buffers[camera_id][tid]) >= CLIP_LEN and sess_supervisor is not None and sess_suspect is not None:
            tensor = preprocess_clip(track_buffers[camera_id][tid])
            if tensor.dtype != np.float32:
                tensor = tensor.astype(np.float32)

            pred = None
            labels = None
            # Model selection by role and scene context (kept from the Ascend logic).
            if role == "supervisor" and suspect_count >= 1:
                try:
                    pred = sess_supervisor.run(None, {input_name_sup: tensor})[0]
                    labels = LABELS_SUPERVISOR
                except Exception as e:
                    print(f"[WARN] supervisor 模型推理失败: {e}")
                    pred = None
            elif role == "suspect" and supervisor_count == 0:
                try:
                    pred = sess_suspect.run(None, {input_name_sus: tensor})[0]
                    labels = LABELS_SUSPECT
                except Exception as e:
                    print(f"[WARN] suspect 模型推理失败: {e}")
                    pred = None
            else:
                # Conditions not met: slide the window and move to the next track.
                track_buffers[camera_id][tid] = track_buffers[camera_id][tid][SLIDE_STEP:]
                continue

            if pred is not None:
                idx = int(np.argmax(pred[0]))
                conf_val = float(pred[0][idx])
                action_name = labels.get(idx, "Unknown")

                should_alert = False

                # Role/action pairing (unchanged from the original logic).
                if (action_name == 'Slap' or action_name == 'Push') and role == 'supervisor':
                    should_alert = True
                    track_action_result[camera_id][tid] = f"{action_name}({conf_val:.2f})"
                    print(f"⏰ 时间:{current_time_sec:.2f} | Camera:{camera_id} | ID: {tid} | 动作:{action_name} | 置信度:{conf_val:.2f}")
                elif (action_name == 'Hanging' or action_name == 'Collision' or action_name == 'Lyingdown') and role == 'suspect':
                    should_alert = True
                    track_action_result[camera_id][tid] = f"{action_name}({conf_val:.2f})"
                    print(f"⏰ 时间:{current_time_sec:.2f} | Camera:{camera_id} | ID: {tid} | 动作:{action_name} | 置信度:{conf_val:.2f}")
                else:
                    # Back to normal: clear any previously stored abnormal action.
                    if tid in track_action_result[camera_id]:
                        del track_action_result[camera_id][tid]

                # Alert gating: confidence threshold + per-track cooldown.
                if (should_alert and conf_val >= CONF_THRESH and
                        (tid not in last_alert[camera_id] or current_time_sec - last_alert[camera_id][tid] > ACTION_COOLDOWN)):
                    print(f"[ALERT] Camera:{camera_id} | ID:{tid} ({role}) -> {action_name} ({conf_val:.3f})")
                    last_alert[camera_id][tid] = current_time_sec

                    action_info = {
                        'time': current_time_sec,
                        'camera_id': camera_id,
                        'role': role,
                        'id': tid,
                        'action': action_name,
                        'confidence': conf_val
                    }

                    recent_actions[camera_id].append(action_info)
                    if len(recent_actions[camera_id]) > MAX_RECENT_ACTIONS:
                        recent_actions[camera_id].pop(0)

                    # Remember for this frame's on-screen alert list.
                    current_frame_abnormal_actions.append(action_info)

            # Slide the clip window.
            track_buffers[camera_id][tid] = track_buffers[camera_id][tid][SLIDE_STEP:]

        # Visualization: draw the box only while an abnormal action is stored.
        action_to_show = track_action_result[camera_id].get(tid, None)
        if action_to_show is not None and action_name != "Normal" and conf_val >= CONF_THRESH:
            color = (255, 0, 0) if role == "supervisor" else (0, 0, 255)
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 3)
            overlay = frame.copy()
            cv2.rectangle(overlay, (x1, y1 - 48), (x1 + 420, y1), color, -1)
            cv2.addWeighted(overlay, 0.75, frame, 0.25, 0, frame)
            cv2.putText(frame, f"{role.upper()} ID:{tid}", (x1 + 8, y1 - 25),
                        cv2.FONT_HERSHEY_DUPLEX, 0.8, (255, 255, 255), 2)
            action_color = (0, 0, 255)
            cv2.putText(frame, track_action_result[camera_id][tid], (x1 + 8, y1 - 3),
                        cv2.FONT_HERSHEY_DUPLEX, 0.9, action_color, 2)
            cv2.putText(frame, "ALERT!", (x2 - 130, y1 - 8),
                        cv2.FONT_HERSHEY_COMPLEX, 1.1, (0, 0, 255), 3)

    # === Global status overlay (top-left) ===
    # Count the roles currently known for this camera's tracks.
    # (The original contained a preceding no-op loop over the same items; it
    # was dead code and has been removed.)
    cur_supervisors = 0
    cur_suspects = 0
    for tid, r in track_role.get(camera_id, {}).items():
        if r == "supervisor":
            cur_supervisors += 1
        elif r == "suspect":
            cur_suspects += 1

    info = [
        f"Camera: {camera_id}",
        f"Targets: {len(tracks)}",
        f"Supervisor: {cur_supervisors}",
        f"Suspect: {cur_suspects}"
    ]
    for i, text in enumerate(info):
        cv2.putText(frame, text, (10, 35 + i * 28), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)

    # Action display area below the status lines.
    status_y = 35 + len(info) * 28 + 10
    if len(current_frame_abnormal_actions) > 0:
        cv2.putText(frame, "ACTION DETECTED!", (10, status_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 3)
        for i, action_info in enumerate(current_frame_abnormal_actions):
            role_text = action_info['role'].upper()
            action_display = f"{role_text} ID:{action_info['id']} -> {action_info['action']} ({action_info['confidence']:.2f})"
            color = (255, 0, 0) if action_info['role'] == "supervisor" else (0, 0, 255)
            cv2.putText(frame, action_display, (10, status_y + 40 + i * 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)
    else:
        # No new alert this frame: show the most recent alerts instead.
        if recent_actions[camera_id]:
            for i, action_info in enumerate(recent_actions[camera_id][-MAX_RECENT_ACTIONS:]):
                action_display = f"{action_info['role'].upper()} ID:{action_info['id']} -> {action_info['action']} ({action_info['confidence']:.2f})"
                color = (255, 0, 0) if action_info['role'] == "supervisor" else (0, 0, 255)
                cv2.putText(frame, action_display, (10, status_y + 10 + i * 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)
        else:
            cv2.putText(frame, "Detecting...", (10, status_y), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

    # type: 1 while any recent action meets the confidence threshold, else 0.
    # (The original also re-tested len(recent_actions[...]) > 0 inside the
    # generator, which is implied by iterating a non-empty list — removed.)
    result_type = 1 if any(item.get('confidence', 0) >= CONF_THRESH for item in recent_actions[camera_id]) else 0

    return {
        "image": frame,
        "type": result_type
    }
|
||||
|
||||
# =========================
# Service wrapper
# =========================

class RTSPService:
    """Top-level service wiring capture threads, the frame processor and the
    WebSocket sender together around two shared queues."""

    def __init__(self, config_path: str):
        self.config_path = config_path
        self.cameras = self._load_config()

        self.stop_event = threading.Event()

        # Queues: raw frames (capture -> processor), messages (processor -> ws).
        self.raw_frame_queue: "queue.Queue[Dict[str, Any]]" = queue.Queue(maxsize=500)
        self.ws_send_queue: "queue.Queue[Dict[str, Any]]" = queue.Queue(maxsize=1000)

        # Threads.
        self.capture_workers = []
        self.frame_processor = FrameProcessorWorker(self.raw_frame_queue, self.ws_send_queue, self.stop_event)
        self.ws_sender = WebSocketSender(self.ws_send_queue, self.stop_event)

    def _load_config(self):
        """Read the YAML config and build the CameraConfig list.

        Expects a top-level ``cameras`` list of mappings with ``id``,
        ``rtsp_url`` and optional ``name``.
        """
        with open(self.config_path, "r", encoding="utf-8") as f:
            # `or {}`: yaml.safe_load returns None for an empty file, which
            # would otherwise crash on .get().
            cfg = yaml.safe_load(f) or {}
        cameras_cfg = cfg.get("cameras", [])
        cameras = []
        for c in cameras_cfg:
            cameras.append(
                CameraConfig(
                    id=int(c["id"]),
                    name=str(c.get("name", f"cam_{c['id']}")),
                    rtsp_url=str(c["rtsp_url"]),
                )
            )
        return cameras

    @staticmethod
    def _drain(q: "queue.Queue[Any]"):
        """Remove and task_done() any leftover items so queue bookkeeping ends
        balanced even though the consumer threads have already exited."""
        while True:
            try:
                q.get_nowait()
            except queue.Empty:
                break
            q.task_done()

    def start(self):
        print("[INFO] RTSPService starting...")

        # Start the WebSocket broadcast thread.
        self.ws_sender.start()

        # Start the frame-processing thread.
        self.frame_processor.start()

        # One capture thread per configured camera.
        for cam in self.cameras:
            w = RTSPCaptureWorker(cam, self.raw_frame_queue, self.stop_event)
            w.start()
            self.capture_workers.append(w)

        print("[INFO] RTSPService started")

    def stop(self):
        print("[INFO] RTSPService stopping...")
        self.stop_event.set()

        # Let the worker threads observe the stop flag and exit first.
        for w in self.capture_workers:
            w.join(timeout=1.0)
        self.frame_processor.join(timeout=1.0)
        self.ws_sender.join(timeout=1.0)

        # BUGFIX: the original called raw_frame_queue.join()/ws_send_queue.join()
        # immediately after setting stop_event. Once the consumer threads exit,
        # items still sitting in the queues are never task_done()'d, so join()
        # blocked forever — and the surrounding try/except could not help,
        # because Queue.join() blocks rather than raising. Drain leftovers
        # instead so shutdown always completes.
        self._drain(self.raw_frame_queue)
        self._drain(self.ws_send_queue)

        print("[INFO] RTSPService stopped")
|
||||
|
||||
def main():
    """Entry point: build the service from config.yaml and run until Ctrl-C."""
    svc = RTSPService(config_path="config.yaml")
    svc.start()
    try:
        # Idle here — the daemon worker threads do all the actual work.
        while True:
            time.sleep(1.0)
    except KeyboardInterrupt:
        print("[INFO] KeyboardInterrupt, shutting down...")
    finally:
        # Always attempt an orderly shutdown, even on unexpected errors.
        svc.stop()


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user