Files
SupervisorAI/biz/base_frame_processor.py

482 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# base_frame_processor.py
# 通用帧处理工作线程基类
# 统一异步POST、告警去重、抽帧控制等公共逻辑
import cv2
import base64
import json
import os
import subprocess
import time
import threading
import queue
import requests
from typing import Dict, Any, Callable
from concurrent.futures import ThreadPoolExecutor
from common import constants
from common.type_mapping import get_alert_label
from utils.logger import get_logger
from utils.hls_utils import get_segments_before_current, parse_segment_info
logger = get_logger(__name__)

# Alert push rate limit: minimum number of seconds between two POSTs of the
# same action for the same camera (used by the alert de-duplication logic).
ALERT_PUSH_INTERVAL: float = 5.0
class BaseFrameProcessorWorker(threading.Thread):
"""
通用帧处理工作线程基类
功能:
- 统一异步POST线程池
- 统一告警去重逻辑
- 统一抽帧控制
- 统一异常处理
子类仅需提供:
- DETECTOR_FACTORY: 检测器工厂函数
- POST_TYPE: POST请求的type值
- TARGET_FPS: 目标帧率
"""
# ========== 子类必须重写的类常量 ==========
DETECTOR_FACTORY: Callable = None # 检测器工厂函数
POST_TYPE: int = 2 # POST type
TARGET_FPS: float = 10.0 # 目标帧率
def __init__(self,
raw_queue: queue.Queue,
ws_queue: queue.Queue,
stop_event: threading.Event,
cameras=None,
post_workers: int = 4):
super().__init__(daemon=True)
self.raw_queue = raw_queue
self.ws_queue = ws_queue
self.stop_event = stop_event
# 将摄像头列表转换为字典key为id方便通过camera_id快速查找
self.cameras = {cam.id: cam for cam in cameras} if cameras is not None else {}
self.last_ts: Dict[int, float] = {}
# 检测器实例缓存
self.detectors: Dict[int, Any] = {}
# 告警去重记录 {camera_id: {action: last_push_time}}
self.last_alert_push_time: Dict[int, Dict[str, float]] = {}
# 异步POST线程池
self.post_executor = ThreadPoolExecutor(
max_workers=post_workers,
thread_name_prefix="alert_post"
)
# MP4缓存 {segment_path: mp4_path}
self._mp4_cache: Dict[str, str] = {}
# 启动视频文件清理线程
self._start_cleanup_thread()
def _encode_image_to_base64(self, img) -> str:
"""图像编码为 Base64"""
ok, buf = cv2.imencode(".jpg", img)
if not ok:
raise RuntimeError("Failed to encode image to JPEG")
return base64.b64encode(buf.tobytes()).decode("ascii")
def _expand_msg_by_ori_alert(self, msg: dict) -> list:
"""
将 msg 中的 ori_alert 数组展开为多个独立的 msg
Args:
msg: 原始消息,包含 ori_alert 数组
original_image_b64: 原始图像的 base64 编码(作为后备)
Returns:
msg 列表,每个 msg 的 result_type 为包含 action_code 和 action_name 的对象
"""
ori_alerts = msg.get("ori_alert", [])
# 如果没有 ori_alert 或为空,直接返回原消息
if not ori_alerts:
new_msg = msg.copy()
new_msg.pop("ori_alert", None)
return [new_msg]
result = []
for alert_item in ori_alerts:
action_code = alert_item.get("action")
if not action_code:
continue
new_msg = msg.copy()
# 处理 image优先使用 ori_alert 中的 image否则使用原来的
alert_image = alert_item.get("image")
if alert_image is not None:
try:
new_msg["image_base64"] = self._encode_image_to_base64(alert_image)
except Exception as e:
logger.warning(f"[WARN] Failed to encode alert image: {e}, using original")
# 设置 result_type
new_msg["result_type"] = {
"action_code": action_code,
"action_name": get_alert_label(action_code)
}
# 移除 ori_alert
new_msg.pop("ori_alert", None)
result.append(new_msg)
return result
def _post_alert(self, msg: dict):
"""异步发送告警 POST 请求(在线程池中执行)- 旧接口,保留备用"""
try:
response = requests.post(constants.ALERT_PUSH_URL, json=msg, timeout=5.0)
if response.status_code == 200:
print(f"[INFO] POST alert sent successfully for actions: {msg.get('result_type')}")
else:
print(f"[WARN] POST alert failed with status: {response.status_code}")
except Exception as e:
print(f"[ERROR] POST alert request failed: {e}")
def _post_alert_with_video(self, msg: dict, video_path: str = None):
"""
异步发送告警 POST 请求带视频multipart/form-data
Args:
msg: 消息内容
video_path: 视频文件路径(可选)
"""
try:
if video_path and os.path.exists(video_path):
# 有视频,使用 multipart/form-data 上传
with open(video_path, 'rb') as f:
files = {
'video': f,
'metadata': (None, json.dumps(msg))
}
response = requests.post(constants.ALERT_PUSH_URL, files=files, timeout=10.0)
else:
# 无视频,也使用 multipart/form-data
files = {
'metadata': (None, json.dumps(msg))
}
response = requests.post(constants.ALERT_PUSH_URL, files=files, timeout=5.0)
if response.status_code == 200:
logger.info(f"[INFO] POST alert sent successfully for actions: {msg.get('result_type')}")
else:
logger.warning(f"[WARN] POST alert failed with status: {response.status_code}")
except Exception as e:
logger.error(f"[ERROR] POST alert request failed: {e}")
def _start_cleanup_thread(self):
"""启动视频文件清理线程"""
def cleanup_loop():
while not self.stop_event.is_set():
try:
self._cleanup_expired_files()
except Exception as e:
logger.error(f"[ERROR] Cleanup thread error: {e}")
# 每10分钟检查一次
self.stop_event.wait(600)
thread = threading.Thread(target=cleanup_loop, daemon=True, name="video_cleanup")
thread.start()
logger.info("[INFO] Video cleanup thread started")
def _cleanup_expired_files(self):
"""清理过期的视频文件"""
output_dir = constants.VIDEO_CLIP_OUTPUT_DIR
if not output_dir or not os.path.exists(output_dir):
return
retention_seconds = constants.VIDEO_CLIP_RETENTION_SECONDS
current_time = time.time()
try:
for filename in os.listdir(output_dir):
if not filename.endswith('.mp4'):
continue
filepath = os.path.join(output_dir, filename)
if os.path.isfile(filepath):
file_mtime = os.path.getmtime(filepath)
if current_time - file_mtime > retention_seconds:
try:
os.remove(filepath)
logger.info(f"[INFO] Cleaned up expired video: {filename}")
except Exception as e:
logger.error(f"[ERROR] Failed to delete {filename}: {e}")
except Exception as e:
logger.error(f"[ERROR] Cleanup expired files error: {e}")
def _create_or_get_video_clip(self, segment_path: str, segment_duration: float = None) -> str | None:
"""
创建或获取视频剪辑
Args:
segment_path: 当前TS分片路径
segment_duration: 当前分片时长(秒)
Returns:
MP4文件路径失败返回 None
"""
if not segment_path:
return None
# 检查缓存
if segment_path in self._mp4_cache:
cached_path = self._mp4_cache[segment_path]
if os.path.exists(cached_path):
return cached_path
else:
# 缓存失效,移除
del self._mp4_cache[segment_path]
# 解析分片信息构建MP4路径
camera_id, timestamp, seq = parse_segment_info(segment_path)
if not camera_id:
logger.warning(f"[WARN] Failed to parse segment info: {segment_path}")
return None
output_dir = constants.VIDEO_CLIP_OUTPUT_DIR
if not output_dir:
logger.warning("[WARN] VIDEO_CLIP_OUTPUT_DIR not configured")
return None
# 确保输出目录存在
os.makedirs(output_dir, exist_ok=True)
# MP4文件名
mp4_filename = f"{camera_id}_{timestamp}_{seq}.mp4"
mp4_path = os.path.join(output_dir, mp4_filename)
# 检查是否已存在
if os.path.exists(mp4_path):
self._mp4_cache[segment_path] = mp4_path
return mp4_path
# 计算需要回溯的分片数量
clip_duration = constants.VIDEO_CLIP_DURATION_SECONDS
default_segment_duration = constants.VIDEO_CLIP_DEFAULT_SEGMENT_DURATION
effective_duration = segment_duration if segment_duration else default_segment_duration
if effective_duration <= 0:
effective_duration = default_segment_duration
n_segments = int(clip_duration / effective_duration) + 1
# 获取需要合并的分片
ts_files = get_segments_before_current(segment_path, n_segments)
if not ts_files:
logger.warning(f"[WARN] No segments found for clip: {segment_path}")
return None
# 合并TS为MP4
if self._merge_ts_to_mp4(ts_files, mp4_path):
self._mp4_cache[segment_path] = mp4_path
return mp4_path
return None
def _merge_ts_to_mp4(self, ts_files: list, output_path: str) -> bool:
"""
使用 ffmpeg 合并 TS 分片为 MP4
Args:
ts_files: TS文件路径列表按时间顺序
output_path: 输出MP4路径
Returns:
是否成功
"""
if not ts_files:
return False
try:
# 构建 concat 字符串
concat_str = "|".join(ts_files)
# ffmpeg 命令
cmd = [
'ffmpeg',
'-i', f'concat:{concat_str}',
'-c', 'copy',
'-y', # 覆盖输出文件
output_path
]
# 执行命令
result = subprocess.run(
cmd,
capture_output=True,
timeout=60 # 60秒超时
)
if result.returncode == 0:
logger.info(f"[INFO] Created video clip: {output_path}")
return True
else:
logger.error(f"[ERROR] ffmpeg failed: {result.stderr.decode()}")
return False
except subprocess.TimeoutExpired:
logger.error(f"[ERROR] ffmpeg timeout for {output_path}")
return False
except FileNotFoundError:
logger.error("[ERROR] ffmpeg not found, please install ffmpeg")
return False
except Exception as e:
logger.error(f"[ERROR] Failed to merge TS files: {e}")
return False
def _process_alert_with_video(self, msg: dict, segment_path: str, segment_duration: float):
"""
处理告警(含视频剪辑)- 在线程池中执行
Args:
msg: 基础消息
segment_path: 当前TS分片路径
segment_duration: 当前分片时长
"""
# 尝试创建/获取视频剪辑
mp4_path = None
if segment_path:
mp4_path = self._create_or_get_video_clip(segment_path, segment_duration)
# 展开 ori_alert
expanded_msgs = self._expand_msg_by_ori_alert(msg)
# 发送每个展开后的消息
for expanded_msg in expanded_msgs:
self._post_alert_with_video(expanded_msg, mp4_path)
def _create_detector(self, params):
"""创建检测器实例"""
# 使用 type(self) 访问类属性,避免 lambda 被绑定 self 参数
factory = type(self).DETECTOR_FACTORY
if factory is None:
raise NotImplementedError("子类必须提供 DETECTOR_FACTORY")
return factory(params)
def _filter_duplicate_alerts(self, cam_id: int, alerts: list, current_time: float) -> list:
"""
过滤5秒内重复的告警
Args:
cam_id: 摄像头ID
alerts: 当前帧的告警列表
current_time: 当前时间戳
Returns:
符合推送条件的action列表
"""
if cam_id not in self.last_alert_push_time:
self.last_alert_push_time[cam_id] = {}
push_actions = []
for alert in alerts:
action = alert['action']
last_push = self.last_alert_push_time[cam_id].get(action, 0)
# 检查是否超过推送间隔
if current_time - last_push >= ALERT_PUSH_INTERVAL:
push_actions.append(action)
# 更新该action的最后推送时间
self.last_alert_push_time[cam_id][action] = current_time
return push_actions
def run(self):
"""主循环 - 模板方法"""
target_interval = 1.0 / self.TARGET_FPS
while not self.stop_event.is_set():
try:
item = self.raw_queue.get(timeout=0.5)
except queue.Empty:
continue
try:
cam_id = item["camera_id"]
ts = item["timestamp"]
frame = item["frame"]
# 抽帧控制
if ts - self.last_ts.get(cam_id, 0) < target_interval:
continue
self.last_ts[cam_id] = ts
# 获取或创建检测器实例
if cam_id not in self.detectors:
camera_config = self.cameras.get(cam_id)
params = camera_config.params if camera_config else None
self.detectors[cam_id] = self._create_detector(params)
detector = self.detectors[cam_id]
# 执行检测
result = detector.process_frame(frame.copy(), cam_id, ts)
result_img = result["image"]
result_alerts = result["alerts"]
# 过滤重复告警
push_actions = self._filter_duplicate_alerts(
cam_id, result_alerts, time.time()
)
# 编码图像
try:
img_b64 = self._encode_image_to_base64(result_img)
except Exception as e:
logger.error(f"[ERROR] Encode image failed: {e}")
img_b64 = None
# 推送结果
if img_b64 is not None:
msg = {
"msg_type": "frame",
"camera_id": item["camera_index"],
"timestamp": ts,
"result_type": push_actions,
"image_base64": img_b64
}
try:
self.ws_queue.put(msg, timeout=1.0)
if push_actions and len(push_actions) > 0:
# 构建消息
post_msg = msg.copy()
post_msg['type'] = self.POST_TYPE
post_msg['ori_alert']: result_alerts
#备用backup
#self.post_executor.submit(self._post_alert, post_msg)
# 获取视频相关信息仅HLS模式有
segment_path = item.get("segment_path")
segment_duration = item.get("segment_duration")
# 提交到线程池执行包含视频剪辑和POST
self.post_executor.submit(
self._process_alert_with_video,
post_msg,
segment_path,
segment_duration
)
except queue.Full:
logger.warning("[WARN] ws_send_queue full, drop frame message")
except Exception as e:
logger.error(
f"[ERROR] Frame processing failed for camera {cam_id if 'cam_id' in locals() else 'unknown'}: {e}")
logger.exception("Exception details:")
finally:
self.raw_queue.task_done()
# 线程退出时关闭线程池
self.post_executor.shutdown(wait=False)