838 lines
32 KiB
Python
838 lines
32 KiB
Python
# rtsp_service_ws.py (merged with YOLO + MMAction2 logic)
|
||
import cv2
|
||
import time
|
||
import threading
|
||
import queue
|
||
import yaml
|
||
import os
|
||
import json
|
||
import base64
|
||
import asyncio
|
||
import websockets
|
||
from dataclasses import dataclass
|
||
from typing import Optional, Dict, Any, Tuple
|
||
|
||
# --- 新增依赖(YOLO/ONNX/跟踪/NumPy) ---
|
||
import numpy as np
|
||
import onnxruntime as ort
|
||
import torch
|
||
|
||
# Optional dependency: BYTETracker from the yolox package (yolox must be on PYTHONPATH).
# If the import fails for any reason, BYTETracker is set to None and a warning is printed.
try:
    from yolox.tracker.byte_tracker import BYTETracker
except Exception as exc:
    BYTETracker = None
    print(f"[WARN] 无法导入 BYTETracker: {exc}. 请确保 yolox 已安装,或提供替代跟踪实现.")

# Optional dependency: the YOLOv8_ONNX wrapper from npu_yolo_onnx.py, which must be importable
# from the same directory or PYTHONPATH. On failure YOLOv8_ONNX is set to None.
try:
    from npu_yolo_onnx import YOLOv8_ONNX
except Exception as exc:
    YOLOv8_ONNX = None
    print(f"[WARN] 无法导入 YOLOv8_ONNX(npu_yolo_onnx.py):{exc}")
|
||
|
||
# =========================
# Configuration and data structures
# =========================


@dataclass
class CameraConfig:
    """One camera entry, as built by RTSPService._load_config from the YAML ``cameras`` list."""

    id: int  # camera identifier; used as the camera_id key for all per-camera state
    name: str  # human-readable label printed in log messages
    rtsp_url: str  # stream address handed to cv2.VideoCapture
|
||
|
||
|
||
# Rate (frames per second) at which each camera's frames are sampled for processing and recording.
RTSP_TARGET_FPS = 10.0
# Sampled frames per mp4 segment before a new file is started.
FRAMES_PER_SEGMENT = 600
# Directory receiving the mp4 segments.
VIDEO_OUTPUT_DIR = "./videos"
# WebSocket server bind address and port.
WS_HOST, WS_PORT = "0.0.0.0", 8765

# Currently connected WebSocket clients (added/removed by the server's connection handler).
ws_clients = set()
|
||
|
||
# =========================
# YOLO / action recognition / tracking configuration
# =========================

# --- ONNX model paths: adjust to the actual deployment layout ---
YOLO_ONNX_PATH = "YOLO_Weight/best.onnx"
SUPERVISOR_ONNX = "ONNX_Weight/Supervisor.onnx"
SUSPECT_ONNX = "ONNX_Weight/Suspect.onnx"

# Action labels, class index -> name (taken from Ascend_NPU_YOLO_TSM_RealTime.py).
LABELS_SUPERVISOR = dict(enumerate(["Normal", "Push", "Slap"]))
LABELS_SUSPECT = dict(enumerate(["Collision", "Hanging", "Lyingdown", "Normal"]))

# Hyper-parameters (kept identical to the Ascend reference file).
CLIP_LEN = 32  # crops buffered per track before the action model runs
SLIDE_STEP = 16  # crops dropped from the front of a track buffer after each recognition attempt
CONF_THRESH = 0.1  # minimum action score for alerting / drawing
EXPAND_RATIO = 0.4  # fraction of box width/height added on each side before cropping
TARGET_SIZE = 224  # side length every crop is resized to
YOLO_CONF_THRESH = 0.5  # passed to YOLOv8_ONNX as conf_threshold
YOLO_IOU_THRESH = 0.45  # passed to YOLOv8_ONNX as iou_threshold
ACTION_COOLDOWN = 0.0  # minimum seconds between two alerts for the same track
|
||
|
||
# ---- Per-camera tracking state; the outer key is always camera_id ----
trackers: Dict[int, Any] = dict()  # camera_id -> BYTETracker instance (None when tracking is unavailable)
track_buffers: Dict[int, Dict[int, list]] = dict()  # camera_id -> {track_id -> list of resized crops}
last_alert: Dict[int, Dict[int, float]] = dict()  # camera_id -> {track_id -> timestamp of the last alert}
track_role: Dict[int, Dict[int, str]] = dict()  # camera_id -> {track_id -> "supervisor" / "suspect"}
track_action_result: Dict[int, Dict[int, str]] = dict()  # camera_id -> {track_id -> formatted action text}

# ---- Recent alert history shown on the frame ----
recent_actions: Dict[int, list] = dict()  # camera_id -> list of recent action dicts, oldest first
MAX_RECENT_ACTIONS = 3  # history length kept per camera
ACTION_DISPLAY_DURATION = 2.0  # presumably seconds a recent action stays relevant -- confirm where it is used

# ---- Model handles (singletons), filled in by init_models_once() ----
yolo_model = None
sess_supervisor = sess_suspect = None
input_name_sup = input_name_sus = None
|
||
|
||
# =========================
# Model initialisation (YOLO detector + ONNX action-recognition sessions)
# =========================


def _build_yolo_model():
    """Construct the YOLO detector; print the outcome and return None if construction fails."""
    try:
        model = YOLOv8_ONNX(YOLO_ONNX_PATH, conf_threshold=YOLO_CONF_THRESH, iou_threshold=YOLO_IOU_THRESH)
    except Exception as exc:
        print(f"[ERROR] YOLO 模型初始化失败: {exc}")
        return None
    print("[INFO] YOLO 模型初始化完成")
    return model


def _build_action_sessions():
    """Create the supervisor and suspect ONNX Runtime sessions.

    CANN (Ascend NPU) is requested first with CPU as fallback. Requesting a
    provider does not guarantee it is active, so ``get_providers()`` is
    printed and checked afterwards.

    Returns:
        (supervisor_session, suspect_session), or (None, None) if creating
        either session raises.
    """
    cann_options = {
        "device_id": 0,
        "arena_extend_strategy": "kNextPowerOfTwo",
        "npu_mem_limit": 16 * 1024 * 1024 * 1024,
        "precision_mode": "allow_fp32_to_fp16",
        "op_select_impl_mode": "high_precision",
        "enable_cann_graph": True,
    }
    requested = [("CANNExecutionProvider", cann_options), "CPUExecutionProvider"]

    try:
        supervisor = ort.InferenceSession(SUPERVISOR_ONNX, providers=requested)
        suspect = ort.InferenceSession(SUSPECT_ONNX, providers=requested)

        active_sup = supervisor.get_providers()
        active_sus = suspect.get_providers()
        print("Supervisor Providers:", active_sup)
        print("Suspect Providers:", active_sus)

        if "CANNExecutionProvider" in active_sup:
            print("[INFO] 动作识别模型:使用 CANNExecutionProvider(昇腾)")
        else:
            print("[INFO] 动作识别模型:使用 CPUExecutionProvider(非昇腾环境)")
    except Exception as exc:
        print(f"[ERROR] 初始化动作识别模型失败: {exc}")
        return None, None
    return supervisor, suspect


def init_models_once():
    """Populate the module-level model handles.

    Sets ``yolo_model`` (only attempted when YOLOv8_ONNX was importable),
    ``sess_supervisor`` / ``sess_suspect``, and the input tensor names of
    whichever sessions were created. Failures are printed, not raised.
    """
    global yolo_model, sess_supervisor, sess_suspect, input_name_sup, input_name_sus

    if YOLOv8_ONNX is not None:
        yolo_model = _build_yolo_model()
    else:
        print("[ERROR] YOLOv8_ONNX 未导入,无法初始化 YOLO 模型")

    sess_supervisor, sess_suspect = _build_action_sessions()

    if sess_supervisor is not None:
        input_name_sup = sess_supervisor.get_inputs()[0].name
        print(f"[INFO] 监护人模型输入: {input_name_sup}")
    if sess_suspect is not None:
        input_name_sus = sess_suspect.get_inputs()[0].name
        print(f"[INFO] 被监护人模型输入: {input_name_sus}")


# Initialise exactly once, at import time.
init_models_once()
|
||
|
||
# =========================
# Helpers (IoU, clip preprocessing)
# =========================


def compute_iou(box1, box2):
    """Return the intersection-over-union of two (x1, y1, x2, y2) boxes.

    Negative extents count as zero area; 0 is returned when the union is not positive.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    def _area(left, top, right, bottom):
        return max(0, (right - left)) * max(0, (bottom - top))

    overlap = max(0, min(ax2, bx2) - max(ax1, bx1)) * max(0, min(ay2, by2) - max(ay1, by1))
    union = _area(ax1, ay1, ax2, ay2) + _area(bx1, by1, bx2, by2) - overlap
    if union > 0:
        return overlap / union
    return 0
|
||
|
||
def preprocess_clip(frames):
    """Turn a track's crop buffer into the action-model input tensor.

    Mirrors the preprocessing in Ascend_NPU_YOLO_TSM_RealTime.py:
      * pad to CLIP_LEN frames by repeating the last frame (a black
        TARGET_SIZE x TARGET_SIZE image when ``frames`` is empty);
      * take every second frame among the first CLIP_LEN;
      * resize so the short side is 256, centre-crop 224x224, BGR -> RGB,
        HWC -> CHW;
      * normalise with the per-channel (RGB) mean/std below.

    Args:
        frames: list of BGR images.

    Returns:
        float32 array of shape (1, T, 3, 224, 224), where T = len(range(0, CLIP_LEN, 2)).
    """
    missing = CLIP_LEN - len(frames)
    if missing > 0:
        pad = frames[-1] if frames else np.zeros((TARGET_SIZE, TARGET_SIZE, 3), dtype=np.uint8)
        frames = frames + [pad] * missing

    def _to_chw(img):
        src_h, src_w = img.shape[:2]
        factor = 256.0 / min(src_h, src_w)
        out_w, out_h = int(src_w * factor), int(src_h * factor)
        scaled = cv2.resize(img, (out_w, out_h))
        y0 = (out_h - 224) // 2
        x0 = (out_w - 224) // 2
        patch = scaled[y0:y0 + 224, x0:x0 + 224]
        return cv2.cvtColor(patch, cv2.COLOR_BGR2RGB).transpose(2, 0, 1).astype(np.float32)

    clip = np.stack([_to_chw(frames[i]) for i in range(0, CLIP_LEN, 2)])[np.newaxis]

    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32).reshape(1, 1, 3, 1, 1)
    std = np.array([58.395, 57.12, 57.375], dtype=np.float32).reshape(1, 1, 3, 1, 1)
    return ((clip - mean) / std).astype(np.float32)
|
||
|
||
# =========================
# WebSocket server thread
# =========================


class WebSocketSender(threading.Thread):
    """Daemon thread that hosts the WebSocket server and broadcasts queued messages.

    Message dicts are taken from ``send_queue``, JSON-encoded, and sent to every
    connection in the module-level ``ws_clients`` set. The loop exits once
    ``stop_event`` is set.
    """

    def __init__(self, send_queue: "queue.Queue[Dict[str, Any]]", stop_event: threading.Event):
        super().__init__(daemon=True)
        self.send_queue = send_queue
        self.stop_event = stop_event

    async def _ws_handler(self, websocket, path=None):
        """Register a client for broadcasts until it disconnects.

        ``path`` is accepted and ignored. Older websockets releases call
        handlers as ``handler(websocket, path)``; newer ones pass only the
        connection. The default keeps both working.
        """
        ws_clients.add(websocket)
        try:
            # Incoming messages are ignored; iterating just keeps the connection open.
            async for _ in websocket:
                pass
        finally:
            ws_clients.discard(websocket)

    async def _broadcast(self, msg: Dict[str, Any]) -> None:
        """Send ``msg`` to all clients concurrently, dropping clients whose send fails.

        Concurrent sends stop one slow client from delaying delivery to the others.
        """
        try:
            data = json.dumps(msg)
        except (TypeError, ValueError) as exc:
            print(f"[ERROR] Cannot serialize WebSocket message: {exc}")
            return

        clients = list(ws_clients)
        if not clients:
            return
        outcomes = await asyncio.gather(*(ws.send(data) for ws in clients), return_exceptions=True)
        for ws, outcome in zip(clients, outcomes):
            if isinstance(outcome, BaseException):
                ws_clients.discard(ws)

    async def _broadcaster(self):
        """Forward queued messages until stop_event is set; task_done() is always called per message."""
        while not self.stop_event.is_set():
            try:
                # The blocking Queue.get runs in a worker thread so the event loop stays
                # responsive; the timeout lets the loop re-check stop_event.
                msg = await asyncio.to_thread(self.send_queue.get, timeout=0.5)
            except queue.Empty:
                continue
            try:
                await self._broadcast(msg)
            finally:
                self.send_queue.task_done()

    async def _run_async(self):
        async with websockets.serve(self._ws_handler, WS_HOST, WS_PORT):
            print(f"[INFO] WebSocket server started at ws://{WS_HOST}:{WS_PORT}")
            await self._broadcaster()

    def run(self):
        asyncio.run(self._run_async())
|
||
|
||
# =========================
# RTSP capture thread (one per camera)
# =========================


class RTSPCaptureWorker(threading.Thread):
    """Daemon thread that reads frames from one RTSP camera into a shared queue.

    Each queued item is a dict with keys ``camera_id``, ``camera_name``,
    ``timestamp`` (``time.time()`` when the frame was read) and ``frame``.
    If the stream cannot be opened, or keeps failing to deliver frames, the
    capture is reopened until ``stop_event`` is set. In that situation the
    original implementation either exited at once or retried ``read()``
    forever on a dead handle.
    """

    # Consecutive failed reads (0.2 s apart) tolerated before the stream is reopened.
    MAX_CONSECUTIVE_READ_FAILURES = 25
    # Seconds to wait before retrying after the stream failed to open.
    RECONNECT_DELAY = 2.0

    def __init__(
        self,
        camera_cfg: CameraConfig,
        raw_frame_queue: "queue.Queue[Dict[str, Any]]",
        stop_event: threading.Event,
    ):
        super().__init__(daemon=True)
        self.camera_cfg = camera_cfg
        self.raw_frame_queue = raw_frame_queue
        self.stop_event = stop_event

    def _open_capture(self):
        """Open the camera stream; return the capture object, or None if it could not be opened."""
        cap = cv2.VideoCapture(self.camera_cfg.rtsp_url, cv2.CAP_FFMPEG)
        if cap.isOpened():
            return cap
        cap.release()
        print(f"[ERROR] Cannot open RTSP stream: {self.camera_cfg.rtsp_url}")
        return None

    def run(self):
        cam_id = self.camera_cfg.id
        cap = None
        failures = 0
        try:
            while not self.stop_event.is_set():
                if cap is None:
                    cap = self._open_capture()
                    if cap is None:
                        # Event.wait returns early if a stop is requested during the back-off.
                        self.stop_event.wait(self.RECONNECT_DELAY)
                        continue
                    failures = 0
                    print(f"[INFO] Start capturing: id={cam_id}, name={self.camera_cfg.name}")

                ok, frame = cap.read()
                if not ok:
                    failures += 1
                    print(f"[WARN] Failed to read frame from camera {cam_id}, retrying...")
                    if failures >= self.MAX_CONSECUTIVE_READ_FAILURES:
                        # A broken RTSP session rarely recovers on the same handle; reconnect.
                        print(f"[WARN] Reopening RTSP stream for camera {cam_id}")
                        cap.release()
                        cap = None
                    else:
                        time.sleep(0.2)
                    continue
                failures = 0

                item = {
                    "camera_id": cam_id,
                    "camera_name": self.camera_cfg.name,
                    "timestamp": time.time(),
                    "frame": frame,
                }
                try:
                    self.raw_frame_queue.put(item, timeout=1.0)
                except queue.Full:
                    print(f"[WARN] Raw frame queue full, drop frame from camera {cam_id}")
        finally:
            if cap is not None:
                cap.release()
        print(f"[INFO] Stop capturing: id={cam_id}")
|
||
|
||
# =========================
# Frame processing thread (sampling + mp4 recording + user processing + WebSocket messages)
# =========================


class FrameProcessorWorker(threading.Thread):
    """Daemon thread consuming captured frames from all cameras.

    Each camera is throttled to RTSP_TARGET_FPS. For every sampled frame the worker:
      1. appends the raw frame to the camera's current mp4 segment;
      2. calls ``user_process_frame(frame, camera_id, timestamp)``;
      3. queues a ``"frame"`` message (JPEG, base64) for the WebSocket sender;
      4. queues an ``"alert"`` message when the result type is non-zero;
      5. rolls over to a new segment after FRAMES_PER_SEGMENT frames.

    A failure while processing one frame is logged and does not stop the thread.
    ``task_done()`` is always called, so ``raw_frame_queue`` bookkeeping stays
    consistent.
    """

    def __init__(
        self,
        raw_frame_queue: "queue.Queue[Dict[str, Any]]",
        ws_send_queue: "queue.Queue[Dict[str, Any]]",
        stop_event: threading.Event,
    ):
        super().__init__(daemon=True)
        self.raw_frame_queue = raw_frame_queue
        self.ws_send_queue = ws_send_queue
        self.stop_event = stop_event

        # Per-camera video segment state, keyed by camera_id.
        self.video_writers: Dict[int, cv2.VideoWriter] = {}
        self.video_frame_counts: Dict[int, int] = {}
        self.video_segment_start_ts: Dict[int, float] = {}
        self.video_segment_filenames: Dict[int, str] = {}

        os.makedirs(VIDEO_OUTPUT_DIR, exist_ok=True)

        # camera_id -> capture timestamp of the last processed frame (used for throttling).
        self.last_process_ts: Dict[int, float] = {}

    def _get_video_writer(self, camera_id: int, frame) -> Tuple[cv2.VideoWriter, str]:
        """Return (writer, filepath) for the camera's current segment, opening a new one if needed."""
        writer = self.video_writers.get(camera_id)
        if writer is not None:
            return writer, self.video_segment_filenames[camera_id]

        h, w = frame.shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")

        start_ts = time.time()
        self.video_segment_start_ts[camera_id] = start_ts

        ts_str = time.strftime("%Y%m%d_%H%M%S", time.localtime(start_ts))
        filename = f"{ts_str}_cam{camera_id}.mp4"
        filepath = os.path.join(VIDEO_OUTPUT_DIR, filename)

        writer = cv2.VideoWriter(filepath, fourcc, RTSP_TARGET_FPS, (w, h))
        if not writer.isOpened():
            # Writes to an unopened writer are dropped silently, so surface the failure.
            print(f"[WARN] VideoWriter failed to open: camera={camera_id}, file={filepath}")
        self.video_writers[camera_id] = writer
        self.video_frame_counts[camera_id] = 0
        self.video_segment_filenames[camera_id] = filepath

        print(f"[INFO] Start new segment: camera={camera_id}, file={filepath}")
        return writer, filepath

    def _close_segment_if_needed(self, camera_id: int):
        """Release the camera's writer once its segment reaches FRAMES_PER_SEGMENT frames."""
        count = self.video_frame_counts.get(camera_id, 0)
        if count >= FRAMES_PER_SEGMENT:
            writer = self.video_writers.get(camera_id)
            if writer is not None:
                writer.release()
                print(f"[INFO] Close segment: camera={camera_id}, file={self.video_segment_filenames[camera_id]}")

            self.video_writers.pop(camera_id, None)
            self.video_frame_counts.pop(camera_id, None)
            self.video_segment_start_ts.pop(camera_id, None)
            self.video_segment_filenames.pop(camera_id, None)

    def _encode_image_to_base64(self, image) -> str:
        """JPEG-encode ``image`` and return it as an ASCII base64 string; raise RuntimeError on failure."""
        ok, buf = cv2.imencode(".jpg", image)
        if not ok:
            raise RuntimeError("Failed to encode image to JPEG")
        return base64.b64encode(buf.tobytes()).decode("ascii")

    def _enqueue(self, msg: Dict[str, Any], drop_warning: str) -> None:
        """Queue ``msg`` for the WebSocket sender without blocking; print ``drop_warning`` if the queue is full.

        A blocking put would stall frame processing, and with it the raw frame
        queue and every capture thread, whenever the sender falls behind.
        """
        try:
            self.ws_send_queue.put_nowait(msg)
        except queue.Full:
            print(drop_warning)

    def _publish_result(self, camera_id: int, ts: float, video_filepath: str, result) -> None:
        """Queue the frame message and, for a non-zero result type, the alert message."""
        if result is None or "image" not in result or "type" not in result:
            return
        result_type = int(result["type"])

        try:
            img_b64 = self._encode_image_to_base64(result["image"])
        except Exception as e:
            print(f"[ERROR] Encode image failed: {e}")
        else:
            self._enqueue(
                {
                    "msg_type": "frame",
                    "camera_id": camera_id,
                    "timestamp": ts,
                    "result_type": result_type,
                    "image_base64": img_b64,
                },
                "[WARN] ws_send_queue full, drop frame message",
            )

        if result_type != 0:
            self._enqueue(
                {
                    "msg_type": "alert",
                    "camera_id": camera_id,
                    "event_type": result_type,
                    "video_file": video_filepath,
                    "timestamp": ts,
                },
                "[WARN] ws_send_queue full, drop alert message",
            )

    def _process_item(self, item: Dict[str, Any], target_interval: float) -> None:
        """Throttle, record, analyse and publish one captured frame."""
        camera_id = item["camera_id"]
        ts = item["timestamp"]
        frame = item["frame"]

        # Skip frames arriving sooner than target_interval after the last processed one.
        if ts - self.last_process_ts.get(camera_id, 0.0) < target_interval:
            return
        self.last_process_ts[camera_id] = ts

        # 1) Append to the current mp4 segment (the raw frame, before any annotation).
        writer, video_filepath = self._get_video_writer(camera_id, frame)
        writer.write(frame)
        self.video_frame_counts[camera_id] = self.video_frame_counts.get(camera_id, 0) + 1

        try:
            # 2-4) User processing and WebSocket messages.
            result = user_process_frame(frame, camera_id, ts)
            self._publish_result(camera_id, ts, video_filepath, result)
        finally:
            # 5) Roll the segment over even if processing failed, so it never overshoots.
            self._close_segment_if_needed(camera_id)

    def run(self):
        print("[INFO] FrameProcessorWorker started")
        target_interval = 1.0 / RTSP_TARGET_FPS

        while not self.stop_event.is_set():
            try:
                item = self.raw_frame_queue.get(timeout=0.5)
            except queue.Empty:
                continue
            try:
                self._process_item(item, target_interval)
            except Exception as exc:
                # One bad frame (or model error) must not kill the whole pipeline.
                print(f"[ERROR] Frame processing failed: camera={item.get('camera_id')}, error={exc}")
            finally:
                self.raw_frame_queue.task_done()

        # On exit, release every VideoWriter so the open segments are finalised.
        for cam_id, writer in list(self.video_writers.items()):
            writer.release()
            print(f"[INFO] Release writer on exit: camera={cam_id}")
        print("[INFO] FrameProcessorWorker stopped")
|
||
|
||
# =========================
# Per-frame analysis: YOLO detection + tracking + action recognition + alerts
# =========================

# Seconds a track may go unreported by the tracker before its per-track state
# (crop buffer, role, last alert time, action text) is discarded. Without this,
# every track id ever seen keeps up to CLIP_LEN crops in memory for the life of
# the process. yolox's BYTETracker forgets lost tracks after roughly
# track_buffer * frame_rate / 30 frames, about 1 s with the arguments used in
# _create_tracker (verify against the installed yolox). 5 s therefore leaves
# ample margin for a lost track to be re-identified under the same id.
_TRACK_STATE_TTL = 5.0

# camera_id -> {track_id -> timestamp at which the tracker last reported the track}
_track_last_seen: Dict[int, Dict[int, float]] = {}


def _create_tracker(camera_id: int):
    """Build a BYTETracker for one camera; return None if yolox is missing or construction fails."""
    if BYTETracker is None:
        print("[WARN] BYTETracker 未安装,跟踪功能不可用")
        return None

    class TrackerArgs:
        # Tracker parameters from the Ascend reference implementation.
        track_thresh = 0.5
        track_buffer = 30
        match_thresh = 0.8
        mot20 = False

    try:
        tracker = BYTETracker(TrackerArgs(), frame_rate=RTSP_TARGET_FPS)
    except Exception as e:
        print(f"[WARN] 无法初始化 BYTETracker: {e}")
        return None
    print(f"[INFO] 初始化 BYTETracker for camera {camera_id}")
    return tracker


def _ensure_camera_state(camera_id: int) -> None:
    """Create the camera's tracker and bookkeeping containers on first use."""
    if camera_id not in trackers:
        trackers[camera_id] = _create_tracker(camera_id)
    for store in (track_buffers, last_alert, track_role, track_action_result, _track_last_seen):
        store.setdefault(camera_id, {})
    recent_actions.setdefault(camera_id, [])


def _parse_detections(detections):
    """Split detections into boxes, roles, tracker rows and role counts.

    Each detection is unpacked as (x1, y1, x2, y2, conf, cls_id). Class 0 is
    the supervisor; every other class is the suspect.

    Returns:
        (boxes_xyxy, roles, tracker_rows, supervisor_count, suspect_count)
    """
    boxes, roles, tracker_rows = [], [], []
    for x1, y1, x2, y2, conf, cls_id in detections:
        boxes.append([x1, y1, x2, y2])
        roles.append("supervisor" if int(cls_id) == 0 else "suspect")
        tracker_rows.append([x1, y1, x2, y2, conf])
    supervisor_count = roles.count("supervisor")
    return boxes, roles, tracker_rows, supervisor_count, len(roles) - supervisor_count


def _match_role(track_box, dets_xyxy, dets_roles):
    """Return the role of the detection that best overlaps ``track_box`` (IoU > 0.3), else None."""
    best_iou, best_role = 0.0, "unknown"
    for det_box, det_role in zip(dets_xyxy, dets_roles):
        iou = compute_iou(track_box, det_box)
        if iou > best_iou:
            best_iou, best_role = iou, det_role
    return best_role if best_iou > 0.3 else None


def _role_color(role):
    """BGR drawing colour: blue for supervisors, red for everything else."""
    return (255, 0, 0) if role == "supervisor" else (0, 0, 255)


def _format_action(action_info):
    """One-line description of an action record for the on-frame status panel."""
    return (f"{action_info['role'].upper()} ID:{action_info['id']} -> "
            f"{action_info['action']} ({action_info['confidence']:.2f})")


def _draw_detections(frame, detections):
    """Draw raw detection boxes with role and score; used when no tracker is available."""
    for x1, y1, x2, y2, conf, cls_id in detections:
        role = "supervisor" if int(cls_id) == 0 else "suspect"
        color = _role_color(role)
        cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
        cv2.putText(frame, f"{role} {conf:.2f}", (int(x1) + 5, int(y1) - 6), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)


def _draw_alert_box(frame, box, role, tid, action_text):
    """Draw an alert box with a translucent label bar and an "ALERT!" tag for one track."""
    x1, y1, x2, y2 = box
    color = _role_color(role)
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 3)
    overlay = frame.copy()
    cv2.rectangle(overlay, (x1, y1 - 48), (x1 + 420, y1), color, -1)
    cv2.addWeighted(overlay, 0.75, frame, 0.25, 0, frame)
    cv2.putText(frame, f"{role.upper()} ID:{tid}", (x1 + 8, y1 - 25),
                cv2.FONT_HERSHEY_DUPLEX, 0.8, (255, 255, 255), 2)
    cv2.putText(frame, action_text, (x1 + 8, y1 - 3),
                cv2.FONT_HERSHEY_DUPLEX, 0.9, (0, 0, 255), 2)
    cv2.putText(frame, "ALERT!", (x2 - 130, y1 - 8),
                cv2.FONT_HERSHEY_COMPLEX, 1.1, (0, 0, 255), 3)


def _draw_status_panel(frame, camera_id, num_targets, num_supervisors, num_suspects,
                       current_actions, recent):
    """Draw the camera summary in the top-left corner, followed by the current or recent actions."""
    info = [
        f"Camera: {camera_id}",
        f"Targets: {num_targets}",
        f"Supervisor: {num_supervisors}",
        f"Suspect: {num_suspects}",
    ]
    for i, text in enumerate(info):
        cv2.putText(frame, text, (10, 35 + i * 28), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)

    status_y = 35 + len(info) * 28 + 10
    if current_actions:
        cv2.putText(frame, "ACTION DETECTED!", (10, status_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 3)
        for i, action_info in enumerate(current_actions):
            cv2.putText(frame, _format_action(action_info), (10, status_y + 40 + i * 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, _role_color(action_info['role']), 2)
    elif recent:
        for i, action_info in enumerate(recent):
            cv2.putText(frame, _format_action(action_info), (10, status_y + 10 + i * 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, _role_color(action_info['role']), 2)
    else:
        cv2.putText(frame, "Detecting...", (10, status_y), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)


def _prune_stale_tracks(camera_id: int, now: float) -> None:
    """Drop per-track state for tracks unreported for more than _TRACK_STATE_TTL seconds."""
    last_seen = _track_last_seen[camera_id]
    stale = [tid for tid, seen_at in last_seen.items() if now - seen_at > _TRACK_STATE_TTL]
    for tid in stale:
        del last_seen[tid]
        for store in (track_buffers, last_alert, track_role, track_action_result):
            store[camera_id].pop(tid, None)


def user_process_frame(image, camera_id: int, timestamp: float) -> Dict[str, Any]:
    """Detect, track and classify actions in one frame, then annotate it.

    Pipeline (following Ascend_NPU_YOLO_TSM_RealTime.py):
      1. YOLO detection (class 0 = supervisor, any other class = suspect).
      2. BYTETracker tracking. Without a tracker, only the raw detections are
         drawn and type 0 is returned.
      3. Per track:
         * recover the role by IoU matching when it is not yet known;
         * buffer an expanded, resized crop;
         * once CLIP_LEN crops are buffered, run the supervisor model
           (supervisor track, at least one suspect detected) or the suspect
           model (suspect track, no supervisor detected);
         * record an alert for role-specific abnormal actions whose score is
           at least CONF_THRESH, subject to ACTION_COOLDOWN.
      4. Draw alert boxes and the status panel.

    Crops are taken from the frame before any alert overlay is drawn, so one
    track's overlay never leaks into another track's action-model input.

    Args:
        image: BGR frame; annotated in place and returned.
        camera_id: key for all per-camera state.
        timestamp: capture time in seconds. Used for alert cooldowns, for
            expiring recent actions after ACTION_DISPLAY_DURATION, and for
            pruning state of tracks that have disappeared.

    Returns:
        {"image": frame, "type": t}. t is 1 if an alert was recorded for this
        camera within the last ACTION_DISPLAY_DURATION seconds, otherwise 0.
        Previously any alert ever recorded kept t = 1 forever.
    """
    _ensure_camera_state(camera_id)

    frame = image  # BGR
    h, w = frame.shape[:2]

    # === 1. YOLO detection ===
    if yolo_model is None:
        # Without a detector, return the frame unchanged and raise no alert.
        return {"image": frame, "type": 0}
    try:
        detections = yolo_model(frame)  # expected: iterable of [x1, y1, x2, y2, conf, cls_id]
    except Exception as e:
        print(f"[WARN] YOLO 推理失败: {e}")
        detections = []
    if detections is None:
        detections = []

    dets_xyxy, dets_roles, dets_for_tracker, supervisor_count, suspect_count = _parse_detections(detections)

    # === 2. Tracking (BYTETracker) ===
    tracker = trackers.get(camera_id)
    if tracker is None:
        _draw_detections(frame, detections)
        return {"image": frame, "type": 0}

    if dets_for_tracker:
        dets_tensor = torch.from_numpy(np.array(dets_for_tracker, dtype=np.float32)).float()
    else:
        dets_tensor = torch.zeros((0, 5))
    try:
        # Same call shape as the Ascend reference: (dets, original image shape, input image shape).
        tracks = tracker.update(dets_tensor, [h, w], [h, w])
    except Exception as e:
        print(f"[WARN] tracker.update 出错: {e}")
        tracks = []

    current_time_sec = timestamp
    roles = track_role[camera_id]
    buffers = track_buffers[camera_id]
    action_results = track_action_result[camera_id]
    alerts = last_alert[camera_id]
    last_seen = _track_last_seen[camera_id]

    current_frame_abnormal_actions = []
    active_tids = []
    pending_alert_boxes = []  # drawn after the loop so later crops stay clean

    # === 3. Role matching and action recognition per track ===
    for t in tracks:
        try:
            tid = t.track_id
            x1, y1, x2, y2 = map(int, t.tlbr)
        except Exception:
            continue  # skip track objects that cannot be parsed
        last_seen[tid] = current_time_sec
        active_tids.append(tid)

        if x2 <= x1 or y2 <= y1:
            continue

        if tid not in roles and dets_xyxy:
            matched = _match_role([x1, y1, x2, y2], dets_xyxy, dets_roles)
            if matched is not None:
                roles[tid] = matched
        role = roles.get(tid, "unknown")

        # Expand the box by EXPAND_RATIO on every side, clamp to the frame and crop.
        dw = int((x2 - x1) * EXPAND_RATIO)
        dh = int((y2 - y1) * EXPAND_RATIO)
        ex1, ey1 = max(0, x1 - dw), max(0, y1 - dh)
        ex2, ey2 = min(w, x2 + dw), min(h, y2 + dh)
        crop = frame[ey1:ey2, ex1:ex2]
        if crop.size == 0:
            continue
        crop = cv2.resize(crop, (TARGET_SIZE, TARGET_SIZE))

        buffer = buffers.setdefault(tid, [])
        buffer.append(crop)
        if len(buffer) > CLIP_LEN:
            del buffer[:-CLIP_LEN]  # keep only the newest CLIP_LEN crops

        action_name = "Normal"
        conf_val = 0.0

        if len(buffer) >= CLIP_LEN and sess_supervisor is not None and sess_suspect is not None:
            # Model choice depends on role and scene context (kept from the Ascend logic).
            if role == "supervisor" and suspect_count >= 1:
                session, input_name, labels, model_tag = sess_supervisor, input_name_sup, LABELS_SUPERVISOR, "supervisor"
            elif role == "suspect" and supervisor_count == 0:
                session, input_name, labels, model_tag = sess_suspect, input_name_sus, LABELS_SUSPECT, "suspect"
            else:
                # Preconditions not met: slide the window and skip this track.
                del buffer[:SLIDE_STEP]
                continue

            tensor = preprocess_clip(buffer)
            try:
                pred = session.run(None, {input_name: tensor})[0]
            except Exception as e:
                print(f"[WARN] {model_tag} 模型推理失败: {e}")
                pred = None

            if pred is not None:
                idx = int(np.argmax(pred[0]))
                conf_val = float(pred[0][idx])
                action_name = labels.get(idx, "Unknown")

            # Role/action pairs that count as abnormal (unchanged from the original).
            is_abnormal = (
                (role == 'supervisor' and action_name in ('Slap', 'Push'))
                or (role == 'suspect' and action_name in ('Hanging', 'Collision', 'Lyingdown'))
            )
            if is_abnormal:
                action_results[tid] = f"{action_name}({conf_val:.2f})"
                print(f"⏰ 时间:{current_time_sec:.2f} | Camera:{camera_id} | ID: {tid} | 动作:{action_name} | 置信度:{conf_val:.2f}")
            else:
                action_results.pop(tid, None)

            cooled_down = tid not in alerts or current_time_sec - alerts[tid] > ACTION_COOLDOWN
            if is_abnormal and conf_val >= CONF_THRESH and cooled_down:
                print(f"[ALERT] Camera:{camera_id} | ID:{tid} ({role}) -> {action_name} ({conf_val:.3f})")
                alerts[tid] = current_time_sec

                action_info = {
                    'time': current_time_sec,
                    'camera_id': camera_id,
                    'role': role,
                    'id': tid,
                    'action': action_name,
                    'confidence': conf_val,
                }
                history = recent_actions[camera_id]
                history.append(action_info)
                if len(history) > MAX_RECENT_ACTIONS:
                    history.pop(0)
                current_frame_abnormal_actions.append(action_info)

            del buffer[:SLIDE_STEP]  # slide the window

        action_to_show = action_results.get(tid)
        if action_to_show is not None and action_name != "Normal" and conf_val >= CONF_THRESH:
            pending_alert_boxes.append(((x1, y1, x2, y2), role, tid, action_to_show))

    # === 4. Visualisation ===
    for box, role, tid, text in pending_alert_boxes:
        _draw_alert_box(frame, box, role, tid, text)

    _prune_stale_tracks(camera_id, current_time_sec)

    # Count roles among the tracks reported in this frame, not every track ever seen.
    active_roles = [roles.get(tid) for tid in active_tids]

    # Recent actions only stay relevant (displayed and alerting) for ACTION_DISPLAY_DURATION seconds.
    fresh_actions = [a for a in recent_actions[camera_id]
                     if current_time_sec - a['time'] <= ACTION_DISPLAY_DURATION]

    _draw_status_panel(frame, camera_id, len(tracks), active_roles.count("supervisor"),
                       active_roles.count("suspect"), current_frame_abnormal_actions,
                       fresh_actions[-MAX_RECENT_ACTIONS:])

    # type: 0 = no alert, 1 = alert within the display window
    result_type = 1 if any(a.get('confidence', 0) >= CONF_THRESH for a in fresh_actions) else 0

    return {
        "image": frame,
        "type": result_type,
    }
|
||
|
||
# =========================
# Service wrapper
# =========================


class RTSPService:
    """Wire capture, processing and WebSocket threads together for the cameras in a YAML config."""

    # Seconds stop() waits for each capture worker and for the WebSocket sender.
    JOIN_TIMEOUT = 1.0
    # Seconds stop() waits for the frame processor. It may be mid-inference, and it
    # must reach its exit path to release the open VideoWriters before the daemon
    # thread is killed at interpreter exit (an unreleased mp4 segment is typically
    # left unfinalised).
    PROCESSOR_JOIN_TIMEOUT = 10.0

    def __init__(self, config_path: str):
        self.config_path = config_path
        self.cameras = self._load_config()

        self.stop_event = threading.Event()

        # Queues
        self.raw_frame_queue: "queue.Queue[Dict[str, Any]]" = queue.Queue(maxsize=500)
        self.ws_send_queue: "queue.Queue[Dict[str, Any]]" = queue.Queue(maxsize=1000)

        # Threads
        self.capture_workers = []
        self.frame_processor = FrameProcessorWorker(self.raw_frame_queue, self.ws_send_queue, self.stop_event)
        self.ws_sender = WebSocketSender(self.ws_send_queue, self.stop_event)

    def _load_config(self):
        """Parse the ``cameras`` list from the YAML file into CameraConfig objects.

        An empty file or a missing/empty ``cameras`` key yields no cameras
        instead of crashing on ``None``.
        """
        with open(self.config_path, "r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f) or {}
        cameras_cfg = cfg.get("cameras") or []
        return [
            CameraConfig(
                id=int(c["id"]),
                name=str(c.get("name", f"cam_{c['id']}")),
                rtsp_url=str(c["rtsp_url"]),
            )
            for c in cameras_cfg
        ]

    def start(self):
        """Start the WebSocket sender, the frame processor and one capture worker per camera."""
        print("[INFO] RTSPService starting...")

        self.ws_sender.start()
        self.frame_processor.start()

        for cam in self.cameras:
            w = RTSPCaptureWorker(cam, self.raw_frame_queue, self.stop_event)
            w.start()
            self.capture_workers.append(w)

        print("[INFO] RTSPService started")

    @staticmethod
    def _join(thread: threading.Thread, timeout: float) -> None:
        """Join ``thread`` if it is running; unstarted threads are skipped because join() would raise."""
        if thread.is_alive():
            thread.join(timeout=timeout)

    def stop(self):
        """Signal all threads to stop and wait for them, producers first.

        ``Queue.join()`` is deliberately not used. The consumers exit as soon as
        ``stop_event`` is set and leave unfinished items behind, so it would
        block forever.
        """
        print("[INFO] RTSPService stopping...")
        self.stop_event.set()

        for w in self.capture_workers:
            self._join(w, self.JOIN_TIMEOUT)
        self._join(self.frame_processor, self.PROCESSOR_JOIN_TIMEOUT)
        self._join(self.ws_sender, self.JOIN_TIMEOUT)

        print("[INFO] RTSPService stopped")
|
||
|
||
def _idle_forever(poll_seconds=1.0):
    """Sleep in ``poll_seconds`` steps forever; only an exception such as KeyboardInterrupt ends it."""
    while True:
        time.sleep(poll_seconds)


def main():
    """Start the service from config.yaml, run until Ctrl+C, then stop it."""
    service = RTSPService(config_path="config.yaml")
    service.start()
    try:
        _idle_forever()
    except KeyboardInterrupt:
        print("[INFO] KeyboardInterrupt, shutting down...")
    finally:
        service.stop()


if __name__ == "__main__":
    main()
|