Files
SupervisorAI/biz/base_frame_processor.py

196 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# base_frame_processor.py
# 通用帧处理工作线程基类
# 统一异步POST、告警去重、抽帧控制等公共逻辑
import cv2
import base64
import time
import threading
import queue
import requests
from typing import Dict, Any, Callable
from concurrent.futures import ThreadPoolExecutor
from common import constants
from utils.logger import get_logger
logger = get_logger(__name__)

# Minimum number of seconds between two pushes of the same alert action for
# the same camera (used by BaseFrameProcessorWorker._filter_duplicate_alerts).
ALERT_PUSH_INTERVAL: float = 5.0
class BaseFrameProcessorWorker(threading.Thread):
    """
    Base class for frame-processing worker threads.

    Pulls frame items from ``raw_queue``, throttles each camera to
    ``TARGET_FPS``, runs a per-camera detector on the frame, pushes a result
    message (annotated image as Base64 JPEG) onto ``ws_queue`` and, when
    non-duplicate alerts are present, POSTs that message to
    ``constants.ALERT_PUSH_URL`` from a thread pool so network latency never
    blocks frame processing.

    Shared behaviour provided here:
      - asynchronous POST thread pool
      - per-camera, per-action alert de-duplication
      - per-camera frame-rate throttling
      - uniform exception handling / logging

    Subclasses only need to provide:
      - DETECTOR_FACTORY: callable ``factory(params) -> detector``; the
        detector must expose ``process_frame(frame, cam_id, ts)`` returning a
        dict with ``"image"`` and ``"alerts"`` keys (each alert a dict with an
        ``"action"`` key).
      - POST_TYPE: value of the ``type`` field in the POST payload.
      - TARGET_FPS: maximum processed frames per second per camera
        (a value <= 0 disables throttling).
    """

    # ========== Class constants subclasses are expected to override ==========
    DETECTOR_FACTORY: Callable = None  # factory(params) -> detector instance
    POST_TYPE: int = 2  # "type" field sent in the alert POST payload
    TARGET_FPS: float = 10.0  # per-camera processing rate limit

    def __init__(self,
                 raw_queue: queue.Queue,
                 ws_queue: queue.Queue,
                 stop_event: threading.Event,
                 cameras=None,
                 post_workers: int = 4):
        """
        Args:
            raw_queue: Source of frame items; each item is a dict with
                "camera_id", "camera_index", "timestamp" and "frame" keys.
            ws_queue: Destination for per-frame result messages.
            stop_event: Set by the owner to ask ``run`` to exit.
            cameras: Optional iterable of camera configs exposing ``.id`` and
                ``.params``; ``params`` is passed to DETECTOR_FACTORY.
            post_workers: Size of the thread pool used for alert POSTs.
        """
        super().__init__(daemon=True)
        self.raw_queue = raw_queue
        self.ws_queue = ws_queue
        self.stop_event = stop_event
        # Index camera configs by id for fast lookup by camera_id.
        self.cameras = {cam.id: cam for cam in cameras} if cameras is not None else {}
        # Timestamp of the last processed frame per camera (throttling).
        self.last_ts: Dict[int, float] = {}
        # Lazily created detector instance per camera.
        self.detectors: Dict[int, Any] = {}
        # Alert de-duplication record: {camera_id: {action: last_push_time}}.
        self.last_alert_push_time: Dict[int, Dict[str, float]] = {}
        # Thread pool for fire-and-forget alert POSTs.
        self.post_executor = ThreadPoolExecutor(
            max_workers=post_workers,
            thread_name_prefix="alert_post",
        )

    def _encode_image_to_base64(self, img) -> str:
        """Encode ``img`` as JPEG and return it as an ASCII Base64 string.

        Raises:
            RuntimeError: if ``cv2.imencode`` reports failure.
        """
        ok, buf = cv2.imencode(".jpg", img)
        if not ok:
            raise RuntimeError("Failed to encode image to JPEG")
        return base64.b64encode(buf.tobytes()).decode("ascii")

    def _post_alert(self, msg: dict):
        """
        Send one alert POST; runs on ``post_executor``.

        Failures are logged rather than raised: an exception escaping a
        thread-pool task would only be stored on the discarded Future.
        """
        try:
            response = requests.post(constants.ALERT_PUSH_URL, json=msg, timeout=5.0)
        except Exception as e:
            logger.error(f"[ERROR] POST alert request failed: {e}")
            return
        # Any 2xx (e.g. 201/204) means the receiver accepted the alert.
        if 200 <= response.status_code < 300:
            logger.info(f"[INFO] POST alert sent successfully for actions: {msg.get('result_type')}")
        else:
            logger.warning(f"[WARN] POST alert failed with status: {response.status_code}")

    def _create_detector(self, params):
        """Build a detector via the subclass's DETECTOR_FACTORY.

        Raises:
            NotImplementedError: if the subclass did not set DETECTOR_FACTORY.
        """
        # Read through type(self) so a plain function/lambda stored on the
        # class is not bound to the instance as a method (no implicit self).
        factory = type(self).DETECTOR_FACTORY
        if factory is None:
            raise NotImplementedError("子类必须提供 DETECTOR_FACTORY")
        return factory(params)

    def _get_detector(self, cam_id: int):
        """Return the cached detector for ``cam_id``, creating it on first use."""
        if cam_id not in self.detectors:
            camera_config = self.cameras.get(cam_id)
            params = camera_config.params if camera_config else None
            self.detectors[cam_id] = self._create_detector(params)
        return self.detectors[cam_id]

    def _should_process(self, cam_id: int, ts: float, target_interval: float) -> bool:
        """
        Frame-rate throttle: return True (and record ``ts``) when at least
        ``target_interval`` seconds passed since this camera's last processed
        frame. The first frame seen for a camera is always processed.
        """
        last = self.last_ts.get(cam_id)
        # A None sentinel (instead of 0) keeps the first frame even when the
        # timestamps start near zero (relative / monotonic clocks).
        if last is not None and ts - last < target_interval:
            return False
        self.last_ts[cam_id] = ts
        return True

    def _filter_duplicate_alerts(self, cam_id: int, alerts: list, current_time: float) -> list:
        """
        Drop alert actions already pushed for this camera within
        ALERT_PUSH_INTERVAL seconds.

        Args:
            cam_id: Camera id.
            alerts: Alerts of the current frame (dicts with an "action" key);
                None is treated as no alerts.
            current_time: Current wall-clock timestamp.
        Returns:
            The actions that should be pushed now; their push time is recorded.
        """
        history = self.last_alert_push_time.setdefault(cam_id, {})
        push_actions = []
        for alert in alerts or ():
            action = alert['action']
            if current_time - history.get(action, 0) >= ALERT_PUSH_INTERVAL:
                push_actions.append(action)
                history[action] = current_time
        return push_actions

    def _process_item(self, cam_id: int, item: dict, target_interval: float):
        """Throttle, detect and publish one raw-queue item for ``cam_id``."""
        ts = item["timestamp"]
        frame = item["frame"]
        if not self._should_process(cam_id, ts, target_interval):
            return

        detector = self._get_detector(cam_id)
        result = detector.process_frame(frame.copy(), cam_id, ts)
        result_img = result["image"]
        result_alerts = result["alerts"]

        try:
            img_b64 = self._encode_image_to_base64(result_img)
        except Exception as e:
            # Nothing is published without an image; return *before* the
            # de-dup step so these alerts are not recorded as already pushed.
            logger.error(f"[ERROR] Encode image failed: {e}")
            return

        # Read before de-dup for the same reason: a missing key must not
        # suppress alerts that were never sent.
        camera_index = item["camera_index"]
        push_actions = self._filter_duplicate_alerts(cam_id, result_alerts, time.time())

        msg = {
            "msg_type": "frame",
            "camera_id": camera_index,
            "timestamp": ts,
            "result_type": push_actions,
            "image_base64": img_b64,
        }
        try:
            self.ws_queue.put(msg, timeout=1.0)
        except queue.Full:
            logger.warning("[WARN] ws_send_queue full, drop frame message")

        if push_actions:
            # The alert POST does not depend on the websocket consumer: a
            # backed-up ws_queue must not drop an alert that de-dup has
            # already recorded as pushed.
            post_msg = dict(msg, type=self.POST_TYPE)
            self.post_executor.submit(self._post_alert, post_msg)

    def run(self):
        """
        Main loop: process items from ``raw_queue`` until ``stop_event`` is set.

        Every dequeued item is acknowledged with ``task_done()`` (including
        throttled and failed ones). The POST pool is shut down however the
        loop exits.
        """
        fps = self.TARGET_FPS
        # TARGET_FPS <= 0 disables throttling instead of dividing by zero.
        target_interval = 1.0 / fps if fps and fps > 0 else 0.0
        try:
            while not self.stop_event.is_set():
                try:
                    item = self.raw_queue.get(timeout=0.5)
                except queue.Empty:
                    continue
                # Reset per item so a failure never reports the previous camera.
                cam_id = None
                try:
                    cam_id = item["camera_id"]
                    self._process_item(cam_id, item, target_interval)
                except Exception as e:
                    label = cam_id if cam_id is not None else 'unknown'
                    logger.exception(f"[ERROR] Frame processing failed for camera {label}: {e}")
                finally:
                    self.raw_queue.task_done()
        finally:
            # Pending POSTs still complete; thread exit just doesn't wait.
            self.post_executor.shutdown(wait=False)