1050 lines
37 KiB
Python
1050 lines
37 KiB
Python
# rtsp_service_ws_new.py (merged with YOLO + MMAction2 + Face Recognition)
|
||
from datetime import datetime
|
||
|
||
import cv2
|
||
import time
|
||
import threading
|
||
import queue
|
||
import yaml
|
||
import os
|
||
import json
|
||
import base64
|
||
import asyncio
|
||
import websockets
|
||
from dataclasses import dataclass
|
||
from typing import Optional, Dict, Any, Tuple, List
|
||
|
||
# --- 新增依赖(YOLO/ONNX/跟踪/NumPy) ---
|
||
import numpy as np
|
||
import onnxruntime as ort
|
||
import torch
|
||
|
||
# 导入人脸识别算法
|
||
# Import the face-recognition algorithm. On failure, bind the name to None so
# that later `video_face_prison_biz is not None` guards work; the original
# left the name undefined on import failure, causing a NameError at runtime.
try:
    from api.routes.algorithm_router import video_face_prison_biz

    print("[INFO] 成功导入人脸识别算法")
except Exception as e:
    print(f"[WARN] 无法导入人脸识别算法: {e}")
    video_face_prison_biz = None  # sentinel checked by FrameProcessorWorker
|
||
|
||
from yolox.tracker.byte_tracker import BYTETracker
|
||
|
||
|
||
|
||
from npu_yolo_onnx_yolo11n import YOLOv8_ONNX
|
||
|
||
|
||
|
||
# =========================
|
||
# 配置与数据结构
|
||
# =========================
|
||
|
||
@dataclass
class CameraConfig:
    """Static description of one RTSP camera, loaded from config.yaml."""

    id: int        # unique camera identifier
    name: str      # human-readable camera name
    rtsp_url: str  # full RTSP stream URL
|
||
|
||
|
||
RTSP_TARGET_FPS = 30  # target processing frame rate (fps); NOTE(review): original comment said 10 fps while the value is 30 — confirm intent
FRAMES_PER_SEGMENT = 1800  # frames per mp4 segment; NOTE(review): original comment said 600 — confirm intent
VIDEO_OUTPUT_DIR = "./videos"  # directory for recorded mp4 segments
WS_HOST = "0.0.0.0"  # WebSocket server bind address
WS_PORT = 8765  # WebSocket server port

# Face-recognition settings
FACE_RECOGNITION_ENABLED = True  # master switch for face recognition
FACE_ALERT_COOLDOWN = 5.0  # cooldown (s) between blacklist alerts for the same person
FACE_REGISTER_DIR = "test_data/register"  # face registration directory

# Set of currently connected WebSocket clients (shared with WebSocketSender).
ws_clients = set()

# =========================
# YOLO / action recognition / tracking configuration
# =========================

# --- Adjust the ONNX model paths below to your actual paths ---
YOLO_ONNX_PATH = "YOLO_Weight/yolov8l.onnx"  # <-- change to actual path
SUPERVISOR_ONNX = "ONNX_Weight/Supervisor.onnx"  # <-- change to actual path
SUSPECT_ONNX = "ONNX_Weight/Suspect.onnx"  # <-- change to actual path

# Action labels (from Ascend_NPU_YOLO_TSM_RealTime.py)
LABELS_SUPERVISOR = {0: 'Normal', 1: 'Push', 2: 'Slap'}
LABELS_SUSPECT = {0: 'Collision', 1: 'Hanging', 2: 'Lyingdown', 3: 'Normal'}

# Hyper-parameters
Show_Framework = False  # draw per-track boxes/labels onto the frame
Show_MMAction2_Video = False  # dump action clips to mp4 for debugging
CLIP_LEN = 16  # frames cached per action clip; NOTE(review): stale comment said 32
FRAME_INTERVAL = 4  # sample one frame every 4 frames into the clip cache
#SLIDE_STEP = 4  # (earlier value note: sliding step was once 16)
SLIDE_STEP = 4  # sliding-window step applied after each inference
CONF_THRESH = 0.1  # minimum action confidence to raise an alert
ACTION_COOLDOWN = 0.0  # min interval (s) between repeated alerts for the same track
EXPAND_RATIO = 0.3  # bbox expansion ratio for the action ROI crop
LeavingPost_Threshold = 10  # leaving-post threshold (seconds)
LeavingPost_Show_Interval = 3  # interval (s) between leaving-post displays

MMAction_Preprocess_Size = 256  # pre-resize size for MMAction2 input
MMAction_TARGET_SIZE = 224  # final crop size fed to the action model
CROP_OFFSET= (MMAction_Preprocess_Size-MMAction_TARGET_SIZE)//2  # center-crop offset
YOLO_CONF_THRESH = 0.25  # YOLO detection confidence threshold
YOLO_IOU_THRESH = 0.6  # YOLO NMS IoU threshold
YOLO_INPUT_SIZE_1 = 640  # YOLO preprocessing size (width)
YOLO_INPUT_SIZE_2 = 640  # YOLO preprocessing size (height)


# Per-track caches, keyed by tracker ID (mutated by user_process_frame).
track_raw_frames = {}  # tid -> list of original full frames
track_raw_coords = {}  # tid -> list of expanded bbox [x1,y1,x2,y2] (original coords)
track_role = {}  # tid -> "supervisor" or "suspect"
last_alert = {}  # tid -> timestamp of the last alert for that track
frame_idx = 0  # global frame counter
prev_time = time.time()  # previous processing timestamp (for FPS computation)

track_action_result = {}  # tid -> action result string

recent_actions = []  # most recently detected actions (for on-screen display)
MAX_RECENT_ACTIONS = 3  # show at most 3 recent actions
ACTION_DISPLAY_DURATION = 1.0  # how long (s) to keep showing an action
action_display_start_time = 0  # timestamp when the action display started
leavingpost_time = 0  # frame counter feeding leaving-post detection
leavingpost_show = 0  # frame counter throttling leaving-post display

# Face-recognition alert state
face_last_alert: Dict[int, Dict[str, float]] = {}  # camera_id -> {person_name -> last_alert_time}

# YOLO and action-recognition sessions (module-level singletons,
# populated by init_models_once; None means unavailable).
yolo_model = None
sess_supervisor = None
sess_suspect = None
input_name_sup = None
input_name_sus = None
|
||
|
||
|
||
|
||
# =========================
|
||
# 初始化模型(尝试导入/创建 session)
|
||
# =========================
|
||
|
||
def init_models_once():
    """Initialise the YOLO model and the two action-recognition ONNX sessions.

    Populates the module-level singletons ``yolo_model``, ``sess_supervisor``,
    ``sess_suspect`` and their input tensor names. Any failure leaves the
    corresponding global as None; callers must check for None before use.
    """
    global yolo_model, sess_supervisor, sess_suspect, input_name_sup, input_name_sus

    # YOLO
    # NOTE(review): YOLOv8_ONNX is imported unconditionally at module top, so
    # this None branch appears unreachable unless the import is made optional.
    if YOLOv8_ONNX is None:
        print("[ERROR] YOLOv8_ONNX 未导入,无法初始化 YOLO 模型")
    else:
        try:
            yolo_model = YOLOv8_ONNX(YOLO_ONNX_PATH, conf_threshold=YOLO_CONF_THRESH, iou_threshold=YOLO_IOU_THRESH)
            print("[INFO] YOLO 模型初始化完成")

        except Exception as e:
            print(f"[ERROR] YOLO 模型初始化失败: {e}")
            yolo_model = None

    # -----------------------------
    # Action-recognition model init. Whether a requested provider is actually
    # active must be checked via get_providers(), not assumed.
    # -----------------------------
    try:
        # Request CANN first; ORT silently falls back down this list.
        providers = [
            ("CANNExecutionProvider", {
                "device_id": 0,
                "arena_extend_strategy": "kNextPowerOfTwo",
                "npu_mem_limit": 16 * 1024 * 1024 * 1024,
                "precision_mode": "allow_fp32_to_fp16",
                "op_select_impl_mode": "high_precision",
                "enable_cann_graph": True,
            }),
            "CUDAExecutionProvider",
            "CPUExecutionProvider",  # automatic fallback
        ]

        sess_supervisor = ort.InferenceSession(SUPERVISOR_ONNX, providers=providers)
        sess_suspect = ort.InferenceSession(SUSPECT_ONNX, providers=providers)

        # Report which providers were actually selected.
        sup_prov = sess_supervisor.get_providers()
        sus_prov = sess_suspect.get_providers()

        print("Supervisor Providers:", sup_prov)
        print("Suspect Providers:", sus_prov)

        if "CANNExecutionProvider" in sup_prov:
            print("[INFO] 动作识别模型:使用 CANNExecutionProvider(昇腾)")
        else:
            print("[INFO] 动作识别模型:使用 CPUExecutionProvider(非昇腾环境)")

    except Exception as e:
        print(f"[ERROR] 初始化动作识别模型失败: {e}")
        sess_supervisor = None
        sess_suspect = None

    # Cache input-tensor names for later inference calls.
    if sess_supervisor is not None:
        input_name_sup = sess_supervisor.get_inputs()[0].name
        print(f"[INFO] 监护人模型输入: {input_name_sup}")
    if sess_suspect is not None:
        input_name_sus = sess_suspect.get_inputs()[0].name
        print(f"[INFO] 被监护人模型输入: {input_name_sus}")


# Run model initialisation exactly once, at import time.
init_models_once()
|
||
|
||
|
||
# =========================
|
||
# 工具函数(IoU, preprocess_clip)
|
||
# =========================
|
||
|
||
def compute_iou(boxA, boxB):
    """Return the intersection-over-union of two [x1, y1, x2, y2] boxes."""
    # Intersection rectangle corners.
    ix1 = max(boxA[0], boxB[0])
    iy1 = max(boxA[1], boxB[1])
    ix2 = min(boxA[2], boxB[2])
    iy2 = min(boxA[3], boxB[3])

    # Clamp to zero when the boxes do not overlap.
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)

    area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union = area_a + area_b - inter

    # Degenerate boxes yield an empty union; report zero overlap.
    return inter / union if union != 0 else 0.0
|
||
|
||
|
||
def compute_stable_crop_region(coord_list, img_w, img_h, extra_ratio=0.05, pad=15, force_square=False):
    """Union of all boxes in coord_list plus a safety margin, clamped to the image."""
    if not coord_list:
        # Nothing cached yet: fall back to the whole frame.
        return 0, 0, img_w, img_h

    # Envelope of every cached box, clamped to the image bounds.
    ux1 = max(0, min(c[0] for c in coord_list))
    uy1 = max(0, min(c[1] for c in coord_list))
    ux2 = min(img_w, max(c[2] for c in coord_list))
    uy2 = min(img_h, max(c[3] for c in coord_list))

    # Margin proportional to the envelope size, plus a fixed pad.
    extra_w = int((ux2 - ux1) * extra_ratio)
    extra_h = int((uy2 - uy1) * extra_ratio)

    final_x1 = max(0, ux1 - extra_w - pad)
    final_y1 = max(0, uy1 - extra_h - pad)
    final_x2 = min(img_w, ux2 + extra_w + pad)
    final_y2 = min(img_h, uy2 + extra_h + pad)

    if force_square:
        # Grow the shorter side around the centre so the crop becomes square.
        square_size = max(final_x2 - final_x1, final_y2 - final_y1)
        center_x = (final_x1 + final_x2) // 2
        center_y = (final_y1 + final_y2) // 2

        final_x1 = max(0, center_x - square_size // 2)
        final_y1 = max(0, center_y - square_size // 2)
        final_x2 = min(img_w, final_x1 + square_size)
        final_y2 = min(img_h, final_y1 + square_size)

        # Clamping may have shrunk one axis: force equal edge lengths.
        size = min(final_x2 - final_x1, final_y2 - final_y1)
        final_x2 = final_x1 + size
        final_y2 = final_y1 + size

    return int(final_x1), int(final_y1), int(final_x2), int(final_y2)
|
||
|
||
|
||
def preprocess_clip(frames, clip_len=None):
    """Stack, normalise and batch a clip of BGR frames for the action model.

    Args:
        frames: list of HxWx3 uint8 frames, all the same size.
        clip_len: number of leading frames to use; defaults to the module-level
            CLIP_LEN. Added as a parameter (backward compatible) so the
            function is reusable with arbitrary clip lengths.

    Returns:
        float32 ndarray of shape (1, clip_len, 3, H, W), normalised with the
        per-channel mean/std constants below.
    """
    n = CLIP_LEN if clip_len is None else int(clip_len)
    # Use the first n frames in order (temporal sampling happens upstream).
    indices = list(range(0, n))

    h, w = frames[0].shape[:2]
    # Pre-allocate the output buffer to avoid per-frame reallocations.
    imgs = np.empty((len(indices), 3, h, w), dtype=np.float32)

    for i, index in enumerate(indices):
        # HWC -> CHW float32, written straight into the buffer (no extra copy).
        imgs[i] = frames[index].transpose(2, 0, 1).astype(np.float32)

    # Add the batch dimension: (1, n, 3, h, w).
    x = imgs[np.newaxis]

    # Per-channel normalisation constants (broadcast over the clip).
    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32).reshape(1, 3, 1, 1)
    std = np.array([58.395, 57.12, 57.375], dtype=np.float32).reshape(1, 3, 1, 1)

    # Vectorised normalisation of the whole clip at once.
    return ((x - mean) / std).astype(np.float32)
|
||
|
||
def save_clip_as_video(clip_frames, tid, role, frame_idx, save_dir="../debug_clips"):
    """
    Save a clip (list of BGR ndarrays) as an mp4 for debugging.

    The filename encodes the track ID, role, starting frame index and a
    timestamp. Fix: the debug print now reports the actual output path — the
    original printed a literal placeholder instead of the filename.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Assume ~30 fps source; lower this for slower playback of short clips.
    clip_fps = 30

    h, w = clip_frames[0].shape[:2]
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    filename = f"{save_dir}/Clip_ID{tid:03d}_{role}_F{frame_idx:06d}_{timestamp}.mp4"
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(filename, fourcc, clip_fps, (w, h))

    for frame in clip_frames:
        # Frames are already resized BGR images; write them as-is.
        out.write(frame)

    out.release()
    print(f"[DEBUG] 已保存调试视频 → {filename}")
|
||
|
||
# ByteTrack configuration: a plain attribute bag in the shape BYTETracker expects.
class TrackerArgs:
    track_thresh = 0.5   # detection confidence threshold for starting tracks
    track_buffer = 30    # frames a lost track is kept alive
    match_thresh = 0.8   # association (IoU) matching threshold
    mot20 = False        # MOT20-specific behaviour disabled

# Single shared tracker instance; frame_rate should match the processing fps.
tracker = BYTETracker(TrackerArgs(), frame_rate=RTSP_TARGET_FPS)
|
||
|
||
|
||
|
||
# =========================
|
||
# WebSocket 服务线程(不变)
|
||
# =========================
|
||
|
||
class WebSocketSender(threading.Thread):
    """Daemon thread that hosts the WebSocket server and broadcasts messages.

    Dicts placed on ``send_queue`` are JSON-encoded and fanned out to every
    connected client in ``ws_clients``; clients whose send fails are dropped.
    """

    def __init__(self, send_queue: "queue.Queue[Dict[str, Any]]", stop_event: threading.Event):
        super().__init__(daemon=True)
        self.send_queue = send_queue
        self.stop_event = stop_event

    async def _ws_handler(self, websocket):
        # Register the client for broadcasts; incoming messages are ignored.
        ws_clients.add(websocket)
        try:
            async for _ in websocket:
                pass
        finally:
            ws_clients.discard(websocket)

    async def _broadcaster(self):
        # Pump messages from the thread-safe queue out to all clients.
        while not self.stop_event.is_set():
            try:
                msg = await asyncio.to_thread(self.send_queue.get, timeout=0.5)
            except queue.Empty:
                continue

            payload = json.dumps(msg)
            stale = []
            for client in list(ws_clients):
                try:
                    await client.send(payload)
                except Exception:
                    stale.append(client)
            for client in stale:
                ws_clients.discard(client)
            self.send_queue.task_done()

    async def _run_async(self):
        # Serve until the broadcaster loop exits on stop_event.
        async with websockets.serve(self._ws_handler, WS_HOST, WS_PORT):
            print(f"[INFO] WebSocket server started at ws://{WS_HOST}:{WS_PORT}")
            await self._broadcaster()

    def run(self):
        asyncio.run(self._run_async())
|
||
|
||
|
||
# =========================
|
||
# RTSP 抓流线程(不变)
|
||
# =========================
|
||
|
||
class RTSPCaptureWorker(threading.Thread):
    """Daemon thread that reads frames from one RTSP stream into a queue.

    Every captured frame is wrapped in a dict carrying the camera id/name and
    a capture timestamp. Frames are dropped (with a warning) when the queue is
    full so the capture loop never blocks indefinitely.
    """

    def __init__(
        self,
        camera_cfg: CameraConfig,
        raw_frame_queue: "queue.Queue[Dict[str, Any]]",
        stop_event: threading.Event,
    ):
        super().__init__(daemon=True)
        self.camera_cfg = camera_cfg
        self.raw_frame_queue = raw_frame_queue
        self.stop_event = stop_event

    def run(self):
        cfg = self.camera_cfg
        cap = cv2.VideoCapture(cfg.rtsp_url, cv2.CAP_FFMPEG)

        if not cap.isOpened():
            print(f"[ERROR] Cannot open RTSP stream: {cfg.rtsp_url}")
            return

        print(f"[INFO] Start capturing: id={cfg.id}, name={cfg.name}")

        while not self.stop_event.is_set():
            ok, frame = cap.read()
            if not ok:
                # Transient stream hiccup: back off briefly and retry.
                print(f"[WARN] Failed to read frame from camera {cfg.id}, retrying...")
                time.sleep(0.2)
                continue

            try:
                self.raw_frame_queue.put(
                    {
                        "camera_id": cfg.id,
                        "camera_name": cfg.name,
                        "timestamp": time.time(),
                        "frame": frame,
                    },
                    timeout=1.0,
                )
            except queue.Full:
                print(f"[WARN] Raw frame queue full, drop frame from camera {cfg.id}")

        cap.release()
        print(f"[INFO] Stop capturing: id={cfg.id}")
|
||
|
||
|
||
# =========================
|
||
# 帧处理线程(抽帧 + 写mp4 + 调用用户函数 + 发WebSocket消息)
|
||
# =========================
|
||
|
||
class FrameProcessorWorker(threading.Thread):
    """Daemon thread: throttles frames to the target fps, runs face recognition
    and the YOLO+action pipeline, overlays results on the frame, records mp4
    segments and queues annotated frames for WebSocket broadcast.
    """

    def __init__(
        self,
        raw_frame_queue: "queue.Queue[Dict[str, Any]]",
        ws_send_queue: "queue.Queue[Dict[str, Any]]",
        stop_event: threading.Event,
    ):
        super().__init__(daemon=True)
        self.raw_frame_queue = raw_frame_queue
        self.ws_send_queue = ws_send_queue
        self.stop_event = stop_event

        # Per-camera video-recording state.
        self.video_writers: Dict[int, cv2.VideoWriter] = {}
        self.video_frame_counts: Dict[int, int] = {}
        self.video_segment_start_ts: Dict[int, float] = {}
        self.video_segment_filenames: Dict[int, str] = {}

        os.makedirs(VIDEO_OUTPUT_DIR, exist_ok=True)

        # Frame-rate throttling: timestamp of the last processed frame per camera.
        self.last_process_ts: Dict[int, float] = {}

        # Face-recognition state (shared module-level cooldown map).
        self.face_last_alert = face_last_alert
        self.current_face_alert = None  # face alert info for the current frame, or None

    def _get_video_writer(self, camera_id: int, frame) -> Tuple[cv2.VideoWriter, str]:
        # Return the camera's current writer, starting a new segment if needed.
        writer = self.video_writers.get(camera_id)
        if writer is not None:
            return writer, self.video_segment_filenames[camera_id]

        h, w = frame.shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")

        start_ts = time.time()
        self.video_segment_start_ts[camera_id] = start_ts

        # Segment filename is timestamp + camera id.
        ts_str = time.strftime("%Y%m%d_%H%M%S", time.localtime(start_ts))
        filename = f"{ts_str}_cam{camera_id}.mp4"
        filepath = os.path.join(VIDEO_OUTPUT_DIR, filename)

        writer = cv2.VideoWriter(filepath, fourcc, RTSP_TARGET_FPS, (w, h))
        self.video_writers[camera_id] = writer
        self.video_frame_counts[camera_id] = 0
        self.video_segment_filenames[camera_id] = filepath

        print(f"[INFO] Start new segment: camera={camera_id}, file={filepath}")
        return writer, filepath

    def _close_segment_if_needed(self, camera_id: int):
        # Close and forget the segment once it holds FRAMES_PER_SEGMENT frames.
        count = self.video_frame_counts.get(camera_id, 0)
        if count >= FRAMES_PER_SEGMENT:
            writer = self.video_writers.get(camera_id)
            if writer is not None:
                writer.release()
                print(f"[INFO] Close segment: camera={camera_id}, file={self.video_segment_filenames[camera_id]}")

            self.video_writers.pop(camera_id, None)
            self.video_frame_counts.pop(camera_id, None)
            self.video_segment_start_ts.pop(camera_id, None)
            self.video_segment_filenames.pop(camera_id, None)

    def _encode_image_to_base64(self, image) -> str:
        # JPEG-encode a BGR image and return it as an ASCII base64 string.
        ok, buf = cv2.imencode(".jpg", image)
        if not ok:
            raise RuntimeError("Failed to encode image to JPEG")
        return base64.b64encode(buf.tobytes()).decode("ascii")

    def run(self):
        print("[INFO] FrameProcessorWorker started")
        target_interval = 1.0 / RTSP_TARGET_FPS

        while not self.stop_event.is_set():
            try:
                item = self.raw_frame_queue.get(timeout=0.5)
            except queue.Empty:
                continue

            camera_id = item["camera_id"]
            ts = item["timestamp"]
            frame = item["frame"]

            # Throttle to the target fps: skip frames that arrive too soon.
            last_ts = self.last_process_ts.get(camera_id, 0.0)
            if ts - last_ts < target_interval:
                self.raw_frame_queue.task_done()
                continue
            self.last_process_ts[camera_id] = ts

            # 1) Get (or start) the mp4 writer for the current segment.
            writer, video_filepath = self._get_video_writer(camera_id, frame)


            # 2) Run face recognition (if enabled and available).
            face_results = []
            face_processing_time = 0
            if video_face_prison_biz is not None and FACE_RECOGNITION_ENABLED:
                try:
                    # Process the current frame to obtain face match results.
                    processed_frame_for_face, face_results, face_processing_time = video_face_prison_biz.process_frame(
                        frame)

                    # Check for blacklist matches (cooldown-limited alerts).
                    if camera_id not in self.face_last_alert:
                        self.face_last_alert[camera_id] = {}

                    for result in face_results:
                        if result['has_passed']:
                            print(f"[INFO] 犯人带出: {result['passed_person_id']}")

                        if result['is_match'] and result['best_match']:
                            person_name = result['best_match']
                            similarity = result['similarity']

                            # Alert only if the cooldown for this person has elapsed.
                            last_alert_time = self.face_last_alert[camera_id].get(person_name, 0)
                            if ts - last_alert_time > FACE_ALERT_COOLDOWN:
                                # Remember the alert for the current frame.
                                self.current_face_alert = {
                                    "person_name": person_name,
                                    "similarity": similarity,
                                    "timestamp": ts
                                }
                                print(f"[FACE ALERT] Camera:{camera_id} | Person:{person_name} | Similarity:{similarity:.3f}")

                                # Update the last-alert timestamp.
                                self.face_last_alert[camera_id][person_name] = ts
                except Exception as e:
                    print(f"[WARN] 人脸识别处理失败: {e}")

            # 3) Run the user pipeline (YOLO + action recognition).
            result = user_process_frame(frame, camera_id, ts)

            result_img = None
            #self.current_face_alert = None  # (disabled) per-frame reset of the face alert
            # 4) Overlay detection / action results onto the frame.
            if result is not None and "image" in result:
                result_img = result["image"]
                #result_type = int(result["type"]) if "type" in result else 0
                abnormal_actions = result["actions"]
                current_fps = result["FPS"]
                current_person = result["person_count"]
                tracks = result["Targets"]


                #============================================================================================================
                # Draw the action-recognition overlay.
                info = [
                    f"FPS: {current_fps:.1f}",
                    f"Targets: {tracks}",
                    f"Person: {current_person}",
                ]
                for i, text in enumerate(info):
                    cv2.putText(result_img, text, (10, 35 + i * 38), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 255), 2)

                # Action status line goes below the basic info block.
                status_y = 35 + len(info) * 38 + 20

                # Any abnormal actions detected this frame?
                has_abnormal_action = len(abnormal_actions) > 0

                if has_abnormal_action:
                    # Banner for live detections.
                    cv2.putText(result_img, "ACTION DETECTED!", (10, status_y),
                                cv2.FONT_HERSHEY_SIMPLEX, 1.1, (0, 0, 255), 3)

                    # List each detected abnormal action.
                    for i, action_info in enumerate(abnormal_actions):

                        action_name = action_info['action']
                        conf = action_info['confidence']

                        # Colour by role.
                        if action_info['role'] == "person":
                            action_color = (255, 0, 0)  # blue
                        else:
                            action_color = (0, 0, 255)  # red

                        # Render the action line.
                        action_display = f"{action_name} ({conf:.2f})"
                        cv2.putText(result_img, action_display, (10, status_y + 40 + i * 35),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, action_color, 2)
                else:
                    # Still within the display window for recent actions?
                    if ts - action_display_start_time < ACTION_DISPLAY_DURATION and recent_actions:
                        # Show the most recent detections.
                        cv2.putText(result_img, "RECENT ACTIONS:", (10, status_y),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 255, 0), 2)

                        # Render the latest actions.
                        for i, action_info in enumerate(recent_actions[-MAX_RECENT_ACTIONS:]):

                            action_name = action_info['action']
                            conf = action_info['confidence']

                            # Colour by role.
                            if action_info['role'] == "person":
                                action_color = (255, 0, 0)  # blue
                            else:
                                action_color = (0, 0, 255)  # red

                            # Render the action line.
                            action_display = f"{action_name} ({conf:.2f})"
                            cv2.putText(result_img, action_display, (10, status_y + 40 + i * 35),
                                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, action_color, 2)
                    else:
                        # Idle state.
                        cv2.putText(result_img, "Detecting...", (10, status_y),
                                    cv2.FONT_HERSHEY_SIMPLEX, 1.1, (0, 255, 0), 2)
                #============================================================================================================




                # Draw the face-recognition results.
                if video_face_prison_biz is not None and face_results:
                    result_img = video_face_prison_biz.draw_detections(result_img, face_results)

                    # Face-recognition statistics overlay.
                    match_count = sum(1 for r in face_results if r['is_match'])
                    face_info_text = f"Faces: {len(face_results)} | Matches: {match_count}"
                    cv2.putText(result_img, face_info_text, (10, result_img.shape[0] - 20),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)

                # 5) Send the annotated frame over WebSocket.
                try:
                    img_b64 = self._encode_image_to_base64(result_img)
                except Exception as e:
                    print(f"[ERROR] Encode image failed: {e}")
                    img_b64 = None

                if img_b64 is not None:
                    # Flatten abnormal_actions into a list of action names.
                    action_names = [action_info['action'] for action_info in abnormal_actions]

                    # Tag the message when a face alert fired for this frame.
                    if self.current_face_alert is not None:
                        action_names.append("face")

                    msg = {
                        "msg_type": "frame",
                        "camera_id": camera_id,
                        "timestamp": ts,
                        "result_type": action_names,
                        "image_base64": img_b64,
                    }
                    try:
                        self.ws_send_queue.put(msg, timeout=1.0)
                    except queue.Full:
                        print("[WARN] ws_send_queue full, drop frame message")




            # 7) Roll over to the next mp4 segment when the current one is full.
            self.current_face_alert = None
            self._close_segment_if_needed(camera_id)

            self.raw_frame_queue.task_done()

            # NOTE(review): this write happens after the segment-close check, so
            # immediately after a rollover it writes to an already-released
            # writer; and result_img can still be None when user_process_frame
            # returned None or had no "image" key — verify intended ordering.
            writer.write(result_img)
            self.video_frame_counts[camera_id] = self.video_frame_counts.get(camera_id, 0) + 1

        # On shutdown, release every VideoWriter.
        for cam_id, writer in list(self.video_writers.items()):
            writer.release()
            print(f"[INFO] Release writer on exit: camera={cam_id}")
        print("[INFO] FrameProcessorWorker stopped")
|
||
|
||
|
||
# =========================
|
||
# 用户自定义函数(重要:已集成 YOLO + 动作识别 + 跟踪 + 告警)
|
||
# =========================
|
||
|
||
def user_process_frame(image, camera_id: int, timestamp: float) -> Dict[str, Any]:
    """
    Integrated per-frame pipeline:
    1. take one video frame
    2. YOLO object detection (Supervisor / Suspect)
    3. for each tracked person:
       - crop the ROI
       - preprocess (resize etc.)
       - pick the action model by person count (supervisor / suspect)
       - run action-recognition ONNX inference
       - decode the action class and decide whether to alert
    4. draw results (boxes, labels, alerts) when Show_Framework is set
    5. return the processed image plus per-frame stats and alerts
    Note: logic intentionally mirrors Ascend_NPU_YOLO_TSM_RealTime.py.
    """
    global track_raw_frames, track_raw_coords, last_alert, track_role, \
        track_action_result, recent_actions, frame_idx, prev_time, leavingpost_time, leavingpost_show, action_display_start_time
    global yolo_model, sess_supervisor, sess_suspect, input_name_sup, input_name_sus



    frame = image  # BGR
    h, w = frame.shape[:2]

    # === 1. YOLO detection ===
    detections = []
    if yolo_model is not None:
        try:
            detections = yolo_model(frame)
            # detections format: list of [x1,y1,x2,y2, conf, cls_id]
        except Exception as e:
            print(f"[WARN] YOLO 推理失败: {e}")
            detections = []
    else:
        # Without a model, return the frame unchanged (no alert).
        # NOTE(review): this early return lacks the "actions"/"FPS"/
        # "person_count"/"Targets" keys the caller reads after its
        # `"image" in result` check — verify the caller tolerates this shape.
        return {"image": frame, "type": 0}

    current_time_sec = timestamp  # externally supplied timestamp (seconds)
    frame_idx += 1


    dets_xyxy = []
    dets_roles = []
    dets_for_tracker = []

    person_count = 0

    if detections:
        for det in detections:
            x1, y1, x2, y2, conf, cls_id = det  # corner coords: (x1,y1) top-left, (x2,y2) bottom-right
            dets_xyxy.append([x1, y1, x2, y2])
            dets_for_tracker.append([x1, y1, x2, y2, conf])
            if cls_id == 0:
                dets_roles.append("person")
                person_count += 1
    # NOTE(review): dets_roles only gains an entry for cls_id == 0 while
    # dets_xyxy gains one per detection, so the two lists can diverge in
    # length; the IoU loop below indexes dets_roles with a dets_xyxy index
    # and could raise IndexError with non-person detections — verify.

    dets = np.array(dets_for_tracker, dtype=np.float32) if len(dets_for_tracker) else np.empty((0, 5))
    tracks = tracker.update(
        # torch.from_numpy(dets_for_tracker).float() if len(dets_for_tracker) else torch.zeros((0, 5)),
        dets,
        [h, w],
        [h, w]
    )

    current_frame_abnormal_actions = []

    # === 3. Per-track role matching (IoU) and action recognition ===
    for t in tracks:
        tid = t.track_id

        role = track_role.get(tid, "unknown")

        # First sighting of this track: assign a role by best-IoU detection.
        if tid not in track_role:
            best_iou = 0
            best_role = "unknown"

            t_box = list(map(float, t.tlbr))  # [x1,y1,x2,y2]

            for i, box in enumerate(dets_xyxy):
                iou_val = compute_iou(t_box, box)
                if iou_val > best_iou:
                    best_iou = iou_val
                    best_role = dets_roles[i]
            if best_iou > 0.1:
                track_role[tid] = best_role
            else:
                track_role[tid] = "unknown"

        x1, y1, x2, y2 = map(int, t.tlbr)


        if tid not in track_raw_frames:
            track_raw_frames[tid] = []
            track_raw_coords[tid] = []

        # Sample every FRAME_INTERVAL-th frame into the per-track clip cache.
        if frame_idx % FRAME_INTERVAL == 0:
            track_raw_coords[tid].append([x1, y1, x2, y2])
            track_raw_frames[tid].append(frame.copy())

        # Bound the cache length.
        if len(track_raw_frames[tid]) > CLIP_LEN:
            track_raw_frames[tid] = track_raw_frames[tid][-CLIP_LEN:]
            track_raw_coords[tid] = track_raw_coords[tid][-CLIP_LEN:]

        # Defaults
        conf = 0.0
        action = ""

        # Run action recognition once the cache holds CLIP_LEN frames.
        if len(track_raw_frames[tid]) >= CLIP_LEN:

            # Stable crop region over the whole cached clip.
            fx1, fy1, fx2, fy2 = compute_stable_crop_region(track_raw_coords[tid], w, h, EXPAND_RATIO)

            stable_clip = []
            # for raw_f in sameples_frames:
            for raw_f in track_raw_frames[tid]:
                crop = raw_f[fy1:fy2, fx1:fx2]
                if crop.size == 0:
                    # Empty crop (degenerate region): substitute a black patch.
                    crop = np.zeros((fy2 - fy1, fx2 - fx1, 3), dtype=np.uint8)
                crop = cv2.resize(crop, (MMAction_TARGET_SIZE, MMAction_TARGET_SIZE))
                stable_clip.append(crop)

            if Show_MMAction2_Video == True:
                save_clip_as_video(stable_clip, tid, role, frame_idx)

            tensor = preprocess_clip(stable_clip)

            # Ensure float32 once more before inference.
            if tensor.dtype != np.float32:
                tensor = tensor.astype(np.float32)

            pred = None
            labels = None
            try:
                # Model selection keys off the number of detected persons.
                if 2 <= person_count <= 3:
                    pred = sess_supervisor.run(None, {input_name_sup: tensor})[0]
                    labels = LABELS_SUPERVISOR
                elif person_count == 1:
                    pred = sess_suspect.run(None, {input_name_sus: tensor})[0]
                    labels = LABELS_SUSPECT

                # Validate dimensions / emptiness.
                if pred is None or len(pred) == 0:
                    action = "Unknown"
                    conf = 0.0
                else:
                    # pred already holds per-class confidences.
                    idx = int(np.argmax(pred[0]))
                    conf = float(pred[0][idx])
                    action = labels.get(idx, "Unknown")
                    action_text = f"{action}({conf:.2f})"
            except Exception as e:
                print(f"❌ 动作识别失败:{str(e)[:50]}")
                action = "识别失败"
                conf = 0.0
                action_text = f"{action}({conf:.2f})"


            # Role-vs-action consistency check.
            should_alert = False

            if (action == 'Slap' or action == 'Push') and role == 'person' and 2 <= person_count <= 3:
                should_alert = True
                # Store the action result for on-frame display.
                track_action_result[tid] = f"{action}({conf:.2f})"
            elif (
                    action == 'Hanging' or action == 'Collision' or action == 'Lyingdown') and role == 'person' and person_count == 1:
                should_alert = True
                track_action_result[tid] = f"{action}({conf:.2f})"

            # Alerting — only for matched actions, with per-track cooldown.
            if (should_alert and action != 'Normal' and conf >= CONF_THRESH and
                    (tid not in last_alert or current_time_sec - last_alert[tid] > ACTION_COOLDOWN)):
                print(f"【报警】⏰ 时间:{current_time_sec:.2f} | {action} ({conf:.3f})")
                last_alert[tid] = current_time_sec

                # Record the abnormal action.
                action_info = {
                    'time': current_time_sec,
                    'role': role,
                    'action': action,
                    'confidence': conf
                }

                # Append to the recent-actions display list.
                recent_actions.append(action_info)

                # Bound the list length.
                if len(recent_actions) > MAX_RECENT_ACTIONS:
                    recent_actions.pop(0)

                # Start the on-screen display timer.
                action_display_start_time = current_time_sec

                # Also report it for the current frame.
                current_frame_abnormal_actions.append(action_info)

            # Slide the clip window forward (keep the trailing frames).
            track_raw_frames[tid] = track_raw_frames[tid][SLIDE_STEP:]
            track_raw_coords[tid] = track_raw_coords[tid][SLIDE_STEP:]

        if Show_Framework == True:
            # Debug overlay: box, translucent label band, role/ID and action.
            color = (255, 0, 0)
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 3)
            overlay = frame.copy()
            cv2.rectangle(overlay, (x1, y1 - 48), (x1 + 420, y1), color, -1)
            cv2.addWeighted(overlay, 0.75, frame, 0.25, 0, frame)
            cv2.putText(frame, f"{role.upper()} ID:{tid}", (x1 + 8, y1 - 25),
                        cv2.FONT_HERSHEY_DUPLEX, 0.8, (255, 255, 255), 2)
            action_color = (0, 0, 255)  # abnormal actions shown in red
            if tid in track_action_result:
                cv2.putText(frame, track_action_result[tid], (x1 + 8, y1 - 3),
                            cv2.FONT_HERSHEY_DUPLEX, 0.9, action_color, 2)
                cv2.putText(frame, "ALERT!", (x2 - 130, y1 - 8),
                            cv2.FONT_HERSHEY_COMPLEX, 1.1, (0, 0, 255), 3)



    # Leaving-post detection: fewer than 3 persons present counts toward it.
    if 1 <= person_count < 3:
        leavingpost_time += 1
        leavingpost_show += 1
        if int(leavingpost_time / RTSP_TARGET_FPS) > LeavingPost_Threshold and (int(leavingpost_show / RTSP_TARGET_FPS) % LeavingPost_Show_Interval) == 0:
            current_frame_abnormal_actions.append(
                {
                    'time': current_time_sec,
                    'role': 'person',
                    'action': "LeavingPost",
                    'confidence': 1.0
                }
            )
            print(f'【报警】⏰ 时间:{current_time_sec:.2f} | LeavingPost')

    elif person_count == 3:
        # Full attendance resets the leaving-post counters.
        leavingpost_show = 0
        leavingpost_time = 0


    # Global stats (instantaneous FPS from the previous call).
    current_fps = 1.0 / (time.time() - prev_time) if frame_idx > 1 else 0
    prev_time = time.time()

    # Count tracks currently labelled "person".
    current_person = 0

    for t in tracks:
        tid = t.track_id
        role = track_role.get(tid, "unknown")
        if role == "person":
            current_person += 1



    # Return the frame plus per-frame stats and abnormal actions.
    return {
        "image": frame,
        "actions": current_frame_abnormal_actions,
        "FPS": current_fps,
        "person_count": person_count,
        "Targets": len(tracks),
    }
|
||
|
||
|
||
# =========================
|
||
# 服务封装(不变)
|
||
# =========================
|
||
|
||
class RTSPService:
    """Top-level orchestrator: wires capture threads, the frame processor and
    the WebSocket sender together based on a YAML camera configuration."""

    def __init__(self, config_path: str):
        self.config_path = config_path
        self.cameras = self._load_config()

        self.stop_event = threading.Event()

        # Bounded queues decouple the pipeline stages.
        self.raw_frame_queue: "queue.Queue[Dict[str, Any]]" = queue.Queue(maxsize=500)
        self.ws_send_queue: "queue.Queue[Dict[str, Any]]" = queue.Queue(maxsize=1000)

        # Worker threads.
        self.capture_workers = []
        self.frame_processor = FrameProcessorWorker(self.raw_frame_queue, self.ws_send_queue, self.stop_event)
        self.ws_sender = WebSocketSender(self.ws_send_queue, self.stop_event)

    def _load_config(self):
        # Parse the YAML config into CameraConfig records.
        with open(self.config_path, "r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f)
        return [
            CameraConfig(
                id=int(entry["id"]),
                name=str(entry.get("name", f"cam_{entry['id']}")),
                rtsp_url=str(entry["rtsp_url"]),
            )
            for entry in cfg.get("cameras", [])
        ]

    def start(self):
        print("[INFO] RTSPService starting...")

        # WebSocket sender first so clients can connect early.
        self.ws_sender.start()

        # Then the frame-processing thread.
        self.frame_processor.start()

        # Finally one capture thread per configured camera.
        for cam in self.cameras:
            worker = RTSPCaptureWorker(cam, self.raw_frame_queue, self.stop_event)
            worker.start()
            self.capture_workers.append(worker)

        print("[INFO] RTSPService started")

    def stop(self):
        print("[INFO] RTSPService stopping...")
        self.stop_event.set()

        # Best-effort: wait for the queues to drain.
        try:
            self.raw_frame_queue.join()
            self.ws_send_queue.join()
        except Exception:
            pass

        for worker in self.capture_workers:
            worker.join(timeout=1.0)
        self.frame_processor.join(timeout=1.0)
        self.ws_sender.join(timeout=1.0)

        print("[INFO] RTSPService stopped")
|
||
|
||
|
||
def main():
    """Run the RTSP service until interrupted, then shut it down cleanly."""
    service = RTSPService(config_path="config.yaml")
    service.start()
    try:
        # Main thread just idles; all work happens in the worker threads.
        while True:
            time.sleep(1.0)
    except KeyboardInterrupt:
        print("[INFO] KeyboardInterrupt, shutting down...")
    finally:
        service.stop()


if __name__ == "__main__":
    main()