# rtsp_service_ws_new.py (merged with YOLO + MMAction2 + Face Recognition)
from datetime import datetime
import cv2
import time
import threading
import queue
import yaml
import os
import json
import base64
import asyncio
import websockets
from dataclasses import dataclass
from typing import Optional, Dict, Any, Tuple, List
# --- Added dependencies: YOLO / ONNX / tracking / NumPy ---
import numpy as np
import onnxruntime as ort
import torch
# Import the face recognition pipeline; fall back to None if unavailable
try:
    from api.routes.algorithm_router import video_face_biz
    print("[INFO] Face recognition module imported successfully")
except Exception as e:
    print(f"[WARN] Failed to import face recognition module: {e}")
    video_face_biz = None
from yolox.tracker.byte_tracker import BYTETracker
try:
    from npu_yolo_onnx_yolo11n import YOLOv8_ONNX
except Exception as e:
    print(f"[WARN] Failed to import YOLOv8_ONNX: {e}")
    YOLOv8_ONNX = None

# =========================
# Configuration and data structures
# =========================
@dataclass
class CameraConfig:
    id: int
    name: str
    rtsp_url: str


RTSP_TARGET_FPS = 30           # target processing rate (frames/second); also the mp4 writer FPS
FRAMES_PER_SEGMENT = 1800      # one mp4 segment per 1800 frames (60 s at 30 fps)
VIDEO_OUTPUT_DIR = "./videos"  # video output directory
WS_HOST = "0.0.0.0"            # WebSocket server listen address
WS_PORT = 8765                 # WebSocket server port

# Face recognition settings
FACE_RECOGNITION_ENABLED = True           # enable face recognition
FACE_ALERT_COOLDOWN = 5.0                 # cooldown (seconds) between blacklist alerts for the same face
FACE_REGISTER_DIR = "test_data/register"  # face registration directory

# Set of connected WebSocket clients
ws_clients = set()

# =========================
# YOLO / action recognition / tracking configuration
# =========================
# --- Adjust the ONNX model paths below to your actual paths ---
YOLO_ONNX_PATH = "YOLO_Weight/yolov8l.onnx"     # <-- change to actual path
SUPERVISOR_ONNX = "ONNX_Weight/Supervisor.onnx" # <-- change to actual path
SUSPECT_ONNX = "ONNX_Weight/Suspect.onnx"       # <-- change to actual path
# Action labels (from Ascend_NPU_YOLO_TSM_RealTime.py)
LABELS_SUPERVISOR = {0: 'Normal', 1: 'Push', 2: 'Slap'}
LABELS_SUSPECT = {0: 'Collision', 1: 'Hanging', 2: 'Lyingdown', 3: 'Normal'}

# Hyperparameters
Show_Framework = False
Show_MMAction2_Video = False
CLIP_LEN = 16                  # number of buffered frames per clip
FRAME_INTERVAL = 4             # sample one frame every 4 frames
SLIDE_STEP = 4                 # sliding-window step (frames dropped after each inference)
CONF_THRESH = 0.1
ACTION_COOLDOWN = 0.0          # minimum interval (seconds) between repeated alerts for the same target
EXPAND_RATIO = 0.3
LeavingPost_Threshold = 10     # leaving-post threshold (seconds)
LeavingPost_Show_Interval = 3  # interval (seconds) between leaving-post notifications
MMAction_Preprocess_Size = 256
MMAction_TARGET_SIZE = 224
CROP_OFFSET = (MMAction_Preprocess_Size - MMAction_TARGET_SIZE) // 2
YOLO_CONF_THRESH = 0.25        # YOLO detection confidence threshold
YOLO_IOU_THRESH = 0.6          # YOLO NMS IoU threshold
YOLO_INPUT_SIZE_1 = 640        # YOLO preprocessing width
YOLO_INPUT_SIZE_2 = 640        # YOLO preprocessing height

# Per-track buffers (key change)
track_raw_frames = {}          # tid -> list of original full frames
track_raw_coords = {}          # tid -> list of expanded bboxes [x1, y1, x2, y2] (original coordinates)
track_role = {}                # tid -> "supervisor" or "suspect"
last_alert = {}
frame_idx = 0
prev_time = time.time()
track_action_result = {}       # tid -> action result string
recent_actions = []            # most recent action detection results
MAX_RECENT_ACTIONS = 3         # show at most 3 recent actions
ACTION_DISPLAY_DURATION = 1.0  # how long (seconds) to keep showing an action
action_display_start_time = 0  # when the action display started
leavingpost_time = 0
leavingpost_show = 0

# Face recognition state
face_last_alert: Dict[int, Dict[str, float]] = {}  # camera_id -> {person_name -> last_alert_time}

# YOLO and action recognition sessions (singleton style)
yolo_model = None
sess_supervisor = None
sess_suspect = None
input_name_sup = None
input_name_sus = None

# =========================
# Model initialization (import / create sessions once)
# =========================
def init_models_once():
    global yolo_model, sess_supervisor, sess_suspect, input_name_sup, input_name_sus
    # YOLO
    if YOLOv8_ONNX is None:
        print("[ERROR] YOLOv8_ONNX not imported; cannot initialize the YOLO model")
    else:
        try:
            yolo_model = YOLOv8_ONNX(YOLO_ONNX_PATH, conf_threshold=YOLO_CONF_THRESH, iou_threshold=YOLO_IOU_THRESH)
            print("[INFO] YOLO model initialized")
        except Exception as e:
            print(f"[ERROR] YOLO model initialization failed: {e}")
            yolo_model = None
    # -----------------------------
    # Action recognition model initialization (check active providers properly)
    # -----------------------------
    try:
        # Request CANN; whether it is actually active must be checked via get_providers()
        providers = [
            ("CANNExecutionProvider", {
                "device_id": 0,
                "arena_extend_strategy": "kNextPowerOfTwo",
                "npu_mem_limit": 16 * 1024 * 1024 * 1024,
                "precision_mode": "allow_fp32_to_fp16",
                "op_select_impl_mode": "high_precision",
                "enable_cann_graph": True,
            }),
            "CUDAExecutionProvider",
            "CPUExecutionProvider",  # automatic fallback
        ]
        sess_supervisor = ort.InferenceSession(SUPERVISOR_ONNX, providers=providers)
        sess_suspect = ort.InferenceSession(SUSPECT_ONNX, providers=providers)
        sup_prov = sess_supervisor.get_providers()
        sus_prov = sess_suspect.get_providers()
        print("Supervisor Providers:", sup_prov)
        print("Suspect Providers:", sus_prov)
        if "CANNExecutionProvider" in sup_prov:
            print("[INFO] Action recognition models: using CANNExecutionProvider (Ascend)")
        else:
            print("[INFO] Action recognition models: using CPUExecutionProvider (non-Ascend environment)")
    except Exception as e:
        print(f"[ERROR] Failed to initialize action recognition models: {e}")
        sess_supervisor = None
        sess_suspect = None
    if sess_supervisor is not None:
        input_name_sup = sess_supervisor.get_inputs()[0].name
        print(f"[INFO] Supervisor model input: {input_name_sup}")
    if sess_suspect is not None:
        input_name_sus = sess_suspect.get_inputs()[0].name
        print(f"[INFO] Suspect model input: {input_name_sus}")


# Initialize only once
init_models_once()

# =========================
# Utility functions (IoU, preprocess_clip)
# =========================
def compute_iou(boxA, boxB):
    # box = [x1, y1, x2, y2]
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])
    interW = max(0, xB - xA)
    interH = max(0, yB - yA)
    interArea = interW * interH
    boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    unionArea = boxAArea + boxBArea - interArea
    if unionArea == 0:
        return 0.0
    return interArea / unionArea
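
# Worked example: boxA = [0, 0, 10, 10], boxB = [5, 5, 15, 15]
# -> intersection = 5 * 5 = 25, union = 100 + 100 - 25 = 175, IoU = 25 / 175 ~ 0.143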

def compute_stable_crop_region(coord_list, img_w, img_h, extra_ratio=0.05, pad=15, force_square=False):
    """Compute the maximal bounding envelope of all boxes, plus a safety margin."""
    if not coord_list:
        return 0, 0, img_w, img_h
    x1s = [c[0] for c in coord_list]
    y1s = [c[1] for c in coord_list]
    x2s = [c[2] for c in coord_list]
    y2s = [c[3] for c in coord_list]
    ux1 = max(0, min(x1s))
    uy1 = max(0, min(y1s))
    ux2 = min(img_w, max(x2s))
    uy2 = min(img_h, max(y2s))
    w = ux2 - ux1
    h = uy2 - uy1
    extra_w = int(w * extra_ratio)
    extra_h = int(h * extra_ratio)
    final_x1 = max(0, ux1 - extra_w - pad)
    final_y1 = max(0, uy1 - extra_h - pad)
    final_x2 = min(img_w, ux2 + extra_w + pad)
    final_y2 = min(img_h, uy2 + extra_h + pad)
    # Optionally force the region to be square
    if force_square:
        square_size = max(final_x2 - final_x1, final_y2 - final_y1)
        center_x = (final_x1 + final_x2) // 2
        center_y = (final_y1 + final_y2) // 2
        final_x1 = max(0, center_x - square_size // 2)
        final_y1 = max(0, center_y - square_size // 2)
        final_x2 = min(img_w, final_x1 + square_size)
        final_y2 = min(img_h, final_y1 + square_size)
        # Keep width and height equal after clamping at the image border
        size = min(final_x2 - final_x1, final_y2 - final_y1)
        final_x2 = final_x1 + size
        final_y2 = final_y1 + size
    return int(final_x1), int(final_y1), int(final_x2), int(final_y2)
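
# Worked example (extra_ratio=0.05, pad=15): coord_list = [[100, 100, 200, 200], [120, 90, 210, 190]]
# in a 640x480 frame -> envelope (100, 90, 210, 200) with w = h = 110, extra = 5,
# so the returned region is (80, 70, 230, 220).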

def preprocess_clip(frames):
    """Optimized preprocessing for an action-recognition clip."""
    # Frame indices to sample (currently: all CLIP_LEN frames)
    indices = list(range(0, int(CLIP_LEN)))
    h, w = frames[0].shape[:2]
    # Preallocate the result array
    imgs = np.empty((len(indices), 3, h, w), dtype=np.float32)
    for i, index in enumerate(indices):
        f = frames[index]
        # Assign directly to avoid extra copies (frames are already resized by the caller)
        imgs[i] = f.transpose(2, 0, 1).astype(np.float32)
    # Add a batch dimension
    x = imgs[np.newaxis]
    # Precomputed normalization parameters
    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32).reshape(1, 3, 1, 1)
    std = np.array([58.395, 57.12, 57.375], dtype=np.float32).reshape(1, 3, 1, 1)
    # Vectorized normalization
    result = (x - mean) / std
    return result.astype(np.float32)
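
# Shape check (illustrative): for CLIP_LEN = 16 frames of 224x224 BGR input,
# preprocess_clip returns a float32 array of shape (1, 16, 3, 224, 224),
# normalized with the mean/std constants above.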

def save_clip_as_video(clip_frames, tid, role, frame_idx, save_dir="../debug_clips"):
    """
    Save a clip (list[ndarray] of CLIP_LEN frames) as an mp4 whose filename
    encodes the track ID, role, and starting frame index.
    """
    os.makedirs(save_dir, exist_ok=True)
    clip_fps = 30  # can be lowered for shorter, easier-to-inspect debug videos
    h, w = clip_frames[0].shape[:2]
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{save_dir}/Clip_ID{tid:03d}_{role}_F{frame_idx:06d}_{timestamp}.mp4"
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(filename, fourcc, clip_fps, (w, h))
    for frame in clip_frames:
        # Frames are already resized BGR images; write them directly
        out.write(frame)
    out.release()
    print(f"[DEBUG] Saved debug clip -> {filename}")


# ByteTrack
class TrackerArgs:
    track_thresh = 0.5
    track_buffer = 30
    match_thresh = 0.8
    mot20 = False


tracker = BYTETracker(TrackerArgs(), frame_rate=RTSP_TARGET_FPS)
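
# Note (standard ByteTrack parameters, for reference): track_thresh is the
# detection-score threshold for the first association stage, track_buffer is
# how many frames a lost track is kept alive, and match_thresh is the IoU
# threshold used during matching.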

# =========================
# WebSocket server thread (unchanged)
# =========================
class WebSocketSender(threading.Thread):
    def __init__(self, send_queue: "queue.Queue[Dict[str, Any]]", stop_event: threading.Event):
        super().__init__(daemon=True)
        self.send_queue = send_queue
        self.stop_event = stop_event

    async def _ws_handler(self, websocket):
        ws_clients.add(websocket)
        try:
            async for _ in websocket:
                pass
        finally:
            ws_clients.discard(websocket)

    async def _broadcaster(self):
        while not self.stop_event.is_set():
            try:
                msg = await asyncio.to_thread(self.send_queue.get, timeout=0.5)
            except queue.Empty:
                continue
            data = json.dumps(msg)
            dead = []
            for ws in list(ws_clients):
                try:
                    await ws.send(data)
                except Exception:
                    dead.append(ws)
            for ws in dead:
                ws_clients.discard(ws)
            self.send_queue.task_done()

    async def _run_async(self):
        async with websockets.serve(self._ws_handler, WS_HOST, WS_PORT):
            print(f"[INFO] WebSocket server started at ws://{WS_HOST}:{WS_PORT}")
            await self._broadcaster()

    def run(self):
        asyncio.run(self._run_async())
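
# A minimal receiving client, shown as a sketch (run it separately; the message
# schema matches the msg dict built in FrameProcessorWorker.run):
#
#   import asyncio, json, websockets
#
#   async def listen():
#       async with websockets.connect("ws://127.0.0.1:8765") as ws:
#           async for raw in ws:
#               msg = json.loads(raw)
#               # msg: {"msg_type": "frame", "camera_id": ..., "timestamp": ...,
#               #       "result_type": [...], "image_base64": "..."}
#               print(msg["camera_id"], msg["result_type"])
#
#   asyncio.run(listen())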

# =========================
# RTSP capture thread (unchanged)
# =========================
class RTSPCaptureWorker(threading.Thread):
    def __init__(
        self,
        camera_cfg: CameraConfig,
        raw_frame_queue: "queue.Queue[Dict[str, Any]]",
        stop_event: threading.Event,
    ):
        super().__init__(daemon=True)
        self.camera_cfg = camera_cfg
        self.raw_frame_queue = raw_frame_queue
        self.stop_event = stop_event

    def run(self):
        cap = cv2.VideoCapture(self.camera_cfg.rtsp_url, cv2.CAP_FFMPEG)
        if not cap.isOpened():
            print(f"[ERROR] Cannot open RTSP stream: {self.camera_cfg.rtsp_url}")
            return
        print(f"[INFO] Start capturing: id={self.camera_cfg.id}, name={self.camera_cfg.name}")
        while not self.stop_event.is_set():
            ok, frame = cap.read()
            if not ok:
                print(f"[WARN] Failed to read frame from camera {self.camera_cfg.id}, retrying...")
                time.sleep(0.2)
                continue
            ts = time.time()
            item = {
                "camera_id": self.camera_cfg.id,
                "camera_name": self.camera_cfg.name,
                "timestamp": ts,
                "frame": frame,
            }
            try:
                self.raw_frame_queue.put(item, timeout=1.0)
            except queue.Full:
                print(f"[WARN] Raw frame queue full, drop frame from camera {self.camera_cfg.id}")
        cap.release()
        print(f"[INFO] Stop capturing: id={self.camera_cfg.id}")

# =========================
# Frame processing thread (frame sampling + mp4 writing + user callback + WebSocket messages)
# =========================
class FrameProcessorWorker(threading.Thread):
    def __init__(
        self,
        raw_frame_queue: "queue.Queue[Dict[str, Any]]",
        ws_send_queue: "queue.Queue[Dict[str, Any]]",
        stop_event: threading.Event,
    ):
        super().__init__(daemon=True)
        self.raw_frame_queue = raw_frame_queue
        self.ws_send_queue = ws_send_queue
        self.stop_event = stop_event
        # Per-camera video writing state
        self.video_writers: Dict[int, cv2.VideoWriter] = {}
        self.video_frame_counts: Dict[int, int] = {}
        self.video_segment_start_ts: Dict[int, float] = {}
        self.video_segment_filenames: Dict[int, str] = {}
        os.makedirs(VIDEO_OUTPUT_DIR, exist_ok=True)
        # Frame-rate control: last processed timestamp per camera
        self.last_process_ts: Dict[int, float] = {}
        # Face recognition state
        self.face_last_alert = face_last_alert
        self.current_face_alert = None  # face alert info for the current frame

    def _get_video_writer(self, camera_id: int, frame) -> Tuple[cv2.VideoWriter, str]:
        writer = self.video_writers.get(camera_id)
        if writer is not None:
            return writer, self.video_segment_filenames[camera_id]
        h, w = frame.shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        start_ts = time.time()
        self.video_segment_start_ts[camera_id] = start_ts
        ts_str = time.strftime("%Y%m%d_%H%M%S", time.localtime(start_ts))
        filename = f"{ts_str}_cam{camera_id}.mp4"
        filepath = os.path.join(VIDEO_OUTPUT_DIR, filename)
        writer = cv2.VideoWriter(filepath, fourcc, RTSP_TARGET_FPS, (w, h))
        self.video_writers[camera_id] = writer
        self.video_frame_counts[camera_id] = 0
        self.video_segment_filenames[camera_id] = filepath
        print(f"[INFO] Start new segment: camera={camera_id}, file={filepath}")
        return writer, filepath

    def _close_segment_if_needed(self, camera_id: int):
        count = self.video_frame_counts.get(camera_id, 0)
        if count >= FRAMES_PER_SEGMENT:
            writer = self.video_writers.get(camera_id)
            if writer is not None:
                writer.release()
                print(f"[INFO] Close segment: camera={camera_id}, file={self.video_segment_filenames[camera_id]}")
            self.video_writers.pop(camera_id, None)
            self.video_frame_counts.pop(camera_id, None)
            self.video_segment_start_ts.pop(camera_id, None)
            self.video_segment_filenames.pop(camera_id, None)

    def _encode_image_to_base64(self, image) -> str:
        ok, buf = cv2.imencode(".jpg", image)
        if not ok:
            raise RuntimeError("Failed to encode image to JPEG")
        return base64.b64encode(buf.tobytes()).decode("ascii")
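
    # Receiver-side decode counterpart (illustrative):
    #   jpg = base64.b64decode(img_b64)
    #   img = cv2.imdecode(np.frombuffer(jpg, np.uint8), cv2.IMREAD_COLOR)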

    def run(self):
        print("[INFO] FrameProcessorWorker started")
        target_interval = 1.0 / RTSP_TARGET_FPS
        while not self.stop_event.is_set():
            try:
                item = self.raw_frame_queue.get(timeout=0.5)
            except queue.Empty:
                continue
            camera_id = item["camera_id"]
            ts = item["timestamp"]
            frame = item["frame"]
            last_ts = self.last_process_ts.get(camera_id, 0.0)
            if ts - last_ts < target_interval:
                self.raw_frame_queue.task_done()
                continue
            self.last_process_ts[camera_id] = ts
            # 1) Get (or create) the mp4 writer for the current segment
            writer, video_filepath = self._get_video_writer(camera_id, frame)
            # 2) Run face recognition (if enabled)
            face_results = []
            face_processing_time = 0
            if video_face_biz is not None and FACE_RECOGNITION_ENABLED:
                try:
                    # Process the current frame and collect face recognition results
                    processed_frame_for_face, face_results, face_processing_time = video_face_biz.process_frame(
                        frame)
                    # Check for blacklist matches
                    if camera_id not in self.face_last_alert:
                        self.face_last_alert[camera_id] = {}
                    for result in face_results:
                        if result['is_match'] and result['best_match']:
                            person_name = result['best_match']
                            similarity = result['similarity']
                            # Apply the alert cooldown before notifying again
                            last_alert_time = self.face_last_alert[camera_id].get(person_name, 0)
                            if ts - last_alert_time > FACE_ALERT_COOLDOWN:
                                # Record the face alert for the current frame
                                self.current_face_alert = {
                                    "person_name": person_name,
                                    "similarity": similarity,
                                    "timestamp": ts
                                }
                                print(f"[FACE ALERT] Camera:{camera_id} | Person:{person_name} | Similarity:{similarity:.3f}")
                                # Update the last alert time
                                self.face_last_alert[camera_id][person_name] = ts
                except Exception as e:
                    print(f"[WARN] Face recognition failed: {e}")
            # 3) Run the user-defined processing logic (YOLO + action recognition)
            result = user_process_frame(frame, camera_id, ts)
            result_img = frame
            abnormal_actions = []
            # 4) Draw action recognition results on the output image
            if result is not None and "image" in result:
                result_img = result["image"]
                abnormal_actions = result.get("actions", [])
                current_fps = result["FPS"]
                current_person = result["person_count"]
                tracks = result["Targets"]
                # Basic status overlay
                info = [
                    f"FPS: {current_fps:.1f}",
                    f"Targets: {tracks}",
                    f"Person: {current_person}",
                ]
                for i, text in enumerate(info):
                    cv2.putText(result_img, text, (10, 35 + i * 38), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 255), 2)
                # Action detection status (below the basic info)
                status_y = 35 + len(info) * 38 + 20
                has_abnormal_action = len(abnormal_actions) > 0
                if has_abnormal_action:
                    # Show the "ACTION DETECTED!" heading
                    cv2.putText(result_img, "ACTION DETECTED!", (10, status_y),
                                cv2.FONT_HERSHEY_SIMPLEX, 1.1, (0, 0, 255), 3)
                    # List each detected abnormal action
                    for i, action_info in enumerate(abnormal_actions):
                        action_name = action_info['action']
                        conf = action_info['confidence']
                        # Pick a color by role
                        if action_info['role'] == "person":
                            action_color = (255, 0, 0)  # blue
                        else:
                            action_color = (0, 0, 255)  # red
                        action_display = f"{action_name} ({conf:.2f})"
                        cv2.putText(result_img, action_display, (10, status_y + 40 + i * 35),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, action_color, 2)
                else:
                    # Still within the display window for recent actions?
                    if ts - action_display_start_time < ACTION_DISPLAY_DURATION and recent_actions:
                        cv2.putText(result_img, "RECENT ACTIONS:", (10, status_y),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 255, 0), 2)
                        for i, action_info in enumerate(recent_actions[-MAX_RECENT_ACTIONS:]):
                            action_name = action_info['action']
                            conf = action_info['confidence']
                            if action_info['role'] == "person":
                                action_color = (255, 0, 0)  # blue
                            else:
                                action_color = (0, 0, 255)  # red
                            action_display = f"{action_name} ({conf:.2f})"
                            cv2.putText(result_img, action_display, (10, status_y + 40 + i * 35),
                                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, action_color, 2)
                    else:
                        cv2.putText(result_img, "Detecting...", (10, status_y),
                                    cv2.FONT_HERSHEY_SIMPLEX, 1.1, (0, 255, 0), 2)
            # 5) Draw face recognition results
            if video_face_biz is not None and face_results:
                result_img = video_face_biz.draw_detections(result_img, face_results)
                # Add face recognition statistics
                match_count = sum(1 for r in face_results if r['is_match'])
                face_info_text = f"Faces: {len(face_results)} | Matches: {match_count}"
                cv2.putText(result_img, face_info_text, (10, result_img.shape[0] - 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
            # 6) Send the frame result over WebSocket
            try:
                img_b64 = self._encode_image_to_base64(result_img)
            except Exception as e:
                print(f"[ERROR] Encode image failed: {e}")
                img_b64 = None
            if img_b64 is not None:
                # Convert the abnormal_actions objects into a list of action names
                action_names = [action_info['action'] for action_info in abnormal_actions]
                # If this frame triggered a face alert, add it to result_type
                if self.current_face_alert is not None:
                    action_names.append("face")
                msg = {
                    "msg_type": "frame",
                    "camera_id": camera_id,
                    "timestamp": ts,
                    "result_type": action_names,
                    "image_base64": img_b64,
                }
                try:
                    self.ws_send_queue.put(msg, timeout=1.0)
                except queue.Full:
                    print("[WARN] ws_send_queue full, drop frame message")
            # 7) Write the frame, then roll over to the next mp4 segment if needed
            self.current_face_alert = None
            writer.write(result_img)
            self.video_frame_counts[camera_id] = self.video_frame_counts.get(camera_id, 0) + 1
            self._close_segment_if_needed(camera_id)
            self.raw_frame_queue.task_done()
        # On shutdown, release all VideoWriters
        for cam_id, writer in list(self.video_writers.items()):
            writer.release()
            print(f"[INFO] Release writer on exit: camera={cam_id}")
        print("[INFO] FrameProcessorWorker stopped")

# =========================
# User-defined processing (key function: YOLO + action recognition + tracking + alerts)
# =========================
def user_process_frame(image, camera_id: int, timestamp: float) -> Dict[str, Any]:
    """
    Pipeline:
    1. Take a video frame as input.
    2. Run YOLO object detection (Supervisor / Suspect).
    3. For each detected person:
       - crop the ROI
       - preprocess (resize, etc.)
       - choose the action model by role (supervisor / suspect)
       - run the action recognition ONNX inference
       - parse the action class and decide whether to raise an alert
    4. Draw the results (boxes, labels, alerts).
    5. Return the processed image and alert info.
    Note: this intentionally mirrors the original implementation in
    Ascend_NPU_YOLO_TSM_RealTime.py.
    """
    global track_raw_frames, track_raw_coords, last_alert, track_role, \
        track_action_result, recent_actions, frame_idx, prev_time, leavingpost_time, leavingpost_show, action_display_start_time
    global yolo_model, sess_supervisor, sess_suspect, input_name_sup, input_name_sus
    frame = image  # BGR
    h, w = frame.shape[:2]
    # === 1. YOLO detection ===
    detections = []
    if yolo_model is not None:
        try:
            detections = yolo_model(frame)
            # detections format: list of [x1, y1, x2, y2, conf, cls_id]
        except Exception as e:
            print(f"[WARN] YOLO inference failed: {e}")
            detections = []
    else:
        # Without a model, return the original frame with no alerts
        return {"image": frame, "actions": [], "FPS": 0.0, "person_count": 0, "Targets": 0}
    current_time_sec = timestamp  # timestamp (seconds) supplied by the caller
    frame_idx += 1
    dets_xyxy = []
    dets_roles = []
    dets_for_tracker = []
    person_count = 0
    if detections:
        for det in detections:
            x1, y1, x2, y2, conf, cls_id = det  # (x1, y1) top-left corner, (x2, y2) bottom-right corner
            dets_xyxy.append([x1, y1, x2, y2])
            dets_for_tracker.append([x1, y1, x2, y2, conf])
            # Keep dets_roles aligned with dets_xyxy for the IoU matching below
            if cls_id == 0:
                dets_roles.append("person")
                person_count += 1
            else:
                dets_roles.append("unknown")
    dets = np.array(dets_for_tracker, dtype=np.float32) if len(dets_for_tracker) else np.empty((0, 5))
    tracks = tracker.update(
        dets,
        [h, w],
        [h, w]
    )
    current_frame_abnormal_actions = []
    # === 2. For each track: match a role via IoU, then run action recognition ===
    for t in tracks:
        tid = t.track_id
        role = track_role.get(tid, "unknown")
        if tid not in track_role:
            best_iou = 0
            best_role = "unknown"
            t_box = list(map(float, t.tlbr))  # [x1, y1, x2, y2]
            for i, box in enumerate(dets_xyxy):
                iou_val = compute_iou(t_box, box)
                if iou_val > best_iou:
                    best_iou = iou_val
                    best_role = dets_roles[i]
            if best_iou > 0.1:
                track_role[tid] = best_role
            else:
                track_role[tid] = "unknown"
        x1, y1, x2, y2 = map(int, t.tlbr)
        if tid not in track_raw_frames:
            track_raw_frames[tid] = []
            track_raw_coords[tid] = []
        if frame_idx % FRAME_INTERVAL == 0:
            track_raw_coords[tid].append([x1, y1, x2, y2])
            track_raw_frames[tid].append(frame.copy())
        # Cap the buffer length
        if len(track_raw_frames[tid]) > CLIP_LEN:
            track_raw_frames[tid] = track_raw_frames[tid][-CLIP_LEN:]
            track_raw_coords[tid] = track_raw_coords[tid][-CLIP_LEN:]
        # Defaults
        conf = 0.0
        action = ""
        # Once the buffer reaches CLIP_LEN frames, run action recognition
        if len(track_raw_frames[tid]) >= CLIP_LEN:
            fx1, fy1, fx2, fy2 = compute_stable_crop_region(track_raw_coords[tid], w, h, EXPAND_RATIO)
            stable_clip = []
            for raw_f in track_raw_frames[tid]:
                crop = raw_f[fy1:fy2, fx1:fx2]
                if crop.size == 0:
                    crop = np.zeros((fy2 - fy1, fx2 - fx1, 3), dtype=np.uint8)
                crop = cv2.resize(crop, (MMAction_TARGET_SIZE, MMAction_TARGET_SIZE))
                stable_clip.append(crop)
            if Show_MMAction2_Video:
                save_clip_as_video(stable_clip, tid, role, frame_idx)
            tensor = preprocess_clip(stable_clip)
            # Make doubly sure the dtype is correct
            if tensor.dtype != np.float32:
                tensor = tensor.astype(np.float32)
            pred = None
            labels = None
            try:
                if 2 <= person_count <= 3:
                    pred = sess_supervisor.run(None, {input_name_sup: tensor})[0]
                    labels = LABELS_SUPERVISOR
                elif person_count == 1:
                    pred = sess_suspect.run(None, {input_name_sus: tensor})[0]
                    labels = LABELS_SUSPECT
                # Guard against empty or missing predictions
                if pred is None or len(pred) == 0:
                    action = "Unknown"
                    conf = 0.0
                else:
                    # pred already holds class scores
                    idx = int(np.argmax(pred[0]))
                    conf = float(pred[0][idx])
                    action = labels.get(idx, "Unknown")
                action_text = f"{action}({conf:.2f})"
            except Exception as e:
                print(f"[ERROR] Action recognition failed: {str(e)[:50]}")
                action = "RecognitionFailed"
                conf = 0.0
                action_text = f"{action}({conf:.2f})"
            # Check role-action compatibility
            should_alert = False
            if action in ('Slap', 'Push') and role == 'person' and 2 <= person_count <= 3:
                should_alert = True
                # Store the action result
                track_action_result[tid] = f"{action}({conf:.2f})"
            elif action in ('Hanging', 'Collision', 'Lyingdown') and role == 'person' and person_count == 1:
                should_alert = True
                track_action_result[tid] = f"{action}({conf:.2f})"
            # Alert logic: only alert on matching actions
            if (should_alert and action != 'Normal' and conf >= CONF_THRESH and
                    (tid not in last_alert or current_time_sec - last_alert[tid] > ACTION_COOLDOWN)):
                print(f"[ALERT] t={current_time_sec:.2f} | {action} ({conf:.3f})")
                last_alert[tid] = current_time_sec
                # Record the abnormal action
                action_info = {
                    'time': current_time_sec,
                    'role': role,
                    'action': action,
                    'confidence': conf
                }
                # Append to the recent-actions list, capped at MAX_RECENT_ACTIONS
                recent_actions.append(action_info)
                if len(recent_actions) > MAX_RECENT_ACTIONS:
                    recent_actions.pop(0)
                # Start the on-screen action display timer
                action_display_start_time = current_time_sec
                # Add to this frame's abnormal actions
                current_frame_abnormal_actions.append(action_info)
            # Slide the window: drop the oldest SLIDE_STEP frames
            track_raw_frames[tid] = track_raw_frames[tid][SLIDE_STEP:]
            track_raw_coords[tid] = track_raw_coords[tid][SLIDE_STEP:]
        if Show_Framework:
            color = (255, 0, 0)
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 3)
            overlay = frame.copy()
            cv2.rectangle(overlay, (x1, y1 - 48), (x1 + 420, y1), color, -1)
            cv2.addWeighted(overlay, 0.75, frame, 0.25, 0, frame)
            cv2.putText(frame, f"{role.upper()} ID:{tid}", (x1 + 8, y1 - 25),
                        cv2.FONT_HERSHEY_DUPLEX, 0.8, (255, 255, 255), 2)
            action_color = (0, 0, 255)  # abnormal actions in red
            if tid in track_action_result:
                cv2.putText(frame, track_action_result[tid], (x1 + 8, y1 - 3),
                            cv2.FONT_HERSHEY_DUPLEX, 0.9, action_color, 2)
                cv2.putText(frame, "ALERT!", (x2 - 130, y1 - 8),
                            cv2.FONT_HERSHEY_COMPLEX, 1.1, (0, 0, 255), 3)
    # === 3. Leaving-post detection (frame-level, based on person count) ===
    if 1 <= person_count < 3:
        leavingpost_time += 1
        leavingpost_show += 1
        if int(leavingpost_time / RTSP_TARGET_FPS) > LeavingPost_Threshold and (int(leavingpost_show / RTSP_TARGET_FPS) % LeavingPost_Show_Interval) == 0:
            current_frame_abnormal_actions.append(
                {
                    'time': current_time_sec,
                    'role': 'person',
                    'action': "LeavingPost",
                    'confidence': 1.0
                }
            )
            print(f"[ALERT] t={current_time_sec:.2f} | LeavingPost")
    elif person_count == 3:
        leavingpost_show = 0
        leavingpost_time = 0
    # Global stats
    dt = time.time() - prev_time
    current_fps = 1.0 / dt if frame_idx > 1 and dt > 0 else 0.0
    prev_time = time.time()
    # Count the roles present in the current frame
    current_person = 0
    for t in tracks:
        tid = t.track_id
        role = track_role.get(tid, "unknown")
        if role == "person":
            current_person += 1
    return {
        "image": frame,
        "actions": current_frame_abnormal_actions,
        "FPS": current_fps,
        "person_count": person_count,
        "Targets": len(tracks),
    }

# =========================
# Service wrapper (unchanged)
# =========================
class RTSPService:
    def __init__(self, config_path: str):
        self.config_path = config_path
        self.cameras = self._load_config()
        self.stop_event = threading.Event()
        # Queues
        self.raw_frame_queue: "queue.Queue[Dict[str, Any]]" = queue.Queue(maxsize=500)
        self.ws_send_queue: "queue.Queue[Dict[str, Any]]" = queue.Queue(maxsize=1000)
        # Threads
        self.capture_workers = []
        self.frame_processor = FrameProcessorWorker(self.raw_frame_queue, self.ws_send_queue, self.stop_event)
        self.ws_sender = WebSocketSender(self.ws_send_queue, self.stop_event)

    def _load_config(self):
        with open(self.config_path, "r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f)
        cameras_cfg = cfg.get("cameras", [])
        cameras = []
        for c in cameras_cfg:
            cameras.append(
                CameraConfig(
                    id=int(c["id"]),
                    name=str(c.get("name", f"cam_{c['id']}")),
                    rtsp_url=str(c["rtsp_url"]),
                )
            )
        return cameras
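
    # Illustrative config.yaml layout assumed by _load_config (only these three
    # fields are read; the names and URLs here are placeholders):
    #
    #   cameras:
    #     - id: 1
    #       name: "cam_entrance"
    #       rtsp_url: "rtsp://user:pass@192.168.1.10:554/stream1"
    #     - id: 2
    #       name: "cam_corridor"
    #       rtsp_url: "rtsp://user:pass@192.168.1.11:554/stream1"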

    def start(self):
        print("[INFO] RTSPService starting...")
        # Start the WebSocket sender thread
        self.ws_sender.start()
        # Start the frame processor thread
        self.frame_processor.start()
        # Start one capture thread per camera
        for cam in self.cameras:
            w = RTSPCaptureWorker(cam, self.raw_frame_queue, self.stop_event)
            w.start()
            self.capture_workers.append(w)
        print("[INFO] RTSPService started")

    def stop(self):
        print("[INFO] RTSPService stopping...")
        self.stop_event.set()
        # Optionally wait for the queues to drain
        try:
            self.raw_frame_queue.join()
            self.ws_send_queue.join()
        except Exception:
            pass
        for w in self.capture_workers:
            w.join(timeout=1.0)
        self.frame_processor.join(timeout=1.0)
        self.ws_sender.join(timeout=1.0)
        print("[INFO] RTSPService stopped")


def main():
    service = RTSPService(config_path="config.yaml")
    service.start()
    try:
        while True:
            time.sleep(1.0)
    except KeyboardInterrupt:
        print("[INFO] KeyboardInterrupt, shutting down...")
    finally:
        service.stop()


if __name__ == "__main__":
    main()