439 lines
15 KiB
Python
439 lines
15 KiB
Python
# video_face_recognition.py
|
||
import cv2
|
||
import numpy as np
|
||
import time
|
||
from insightface.app import FaceAnalysis
|
||
from typing import List, Dict, Tuple, Optional
|
||
import os
|
||
import glob
|
||
|
||
#黑白名单
|
||
|
||
class VideoFaceRecognition:
    """Face recognition for video files and live camera streams.

    Every face detected in a frame is compared against a set of registered
    reference embeddings. The registered set is labelled either a blacklist
    or a whitelist (see ``set_list_mode``). The matching rule is the same in
    both modes; the mode only changes the colour used when drawing a face.
    """

    # Glob patterns for registration images; upper-case variants are tried too.
    _IMAGE_PATTERNS = ('*.jpg', '*.jpeg', '*.png', '*.bmp')
    _VALID_LIST_MODES = ('blacklist', 'whitelist')
    # Frame rate used for recordings when the capture reports no usable FPS.
    _FALLBACK_FPS = 30.0
    # Print a progress line every this many processed frames.
    _PROGRESS_INTERVAL = 30

    def __init__(self, model_name: str = 'buffalo_l', use_gpu: bool = True,
                 det_thresh: float = 0.39,
                 det_size: Tuple[int, int] = (640, 640),
                 similarity_threshold: float = 0.3):
        """Load the face analysis model and reset all state.

        Args:
            model_name: Name of the model pack handed to ``FaceAnalysis``.
            use_gpu: When True ``ctx_id=0`` is passed to ``prepare``,
                otherwise ``ctx_id=-1``.
            det_thresh: Detection score threshold passed to ``prepare``.
            det_size: Detector input size passed to ``prepare``.
            similarity_threshold: Minimum cosine similarity for a detected
                face to count as a match with a registered face.
        """
        self.app = FaceAnalysis(name=model_name)
        self.app.prepare(
            ctx_id=0 if use_gpu else -1,
            det_thresh=det_thresh,
            det_size=det_size,
        )

        # List state
        self.list_mode = "blacklist"  # "blacklist" or "whitelist"
        self.registered_faces = {}  # {name: embedding}
        self.similarity_threshold = similarity_threshold

        # Performance statistics
        self.frame_count = 0
        # Per-frame processing time in milliseconds.
        # NOTE(review): grows without bound for long-running camera sessions.
        self.processing_times = []

        print(f"✅ 视频人脸识别系统初始化完成 - GPU: {use_gpu}")

    def set_list_mode(self, mode: str):
        """Switch between "blacklist" and "whitelist" (case-insensitive).

        An unrecognised mode leaves the current mode untouched and only
        prints an error message.
        """
        normalized = mode.lower()
        if normalized in self._VALID_LIST_MODES:
            self.list_mode = normalized
            print(f"✅ 名单模式设置为: {self.list_mode}")
        else:
            print("❌ 无效的名单模式,请使用 'blacklist' 或 'whitelist'")

    def load_registered_faces(self, register_dir: str) -> bool:
        """Register one face per image found directly inside ``register_dir``.

        The file name without its extension becomes the person's name. Images
        that cannot be read or that contain no detectable face are skipped
        with a message. A later image with the same name replaces an earlier
        one.

        Returns:
            True if at least one face was registered, False otherwise.
        """
        if not os.path.exists(register_dir):
            print(f"❌ 注册目录不存在: {register_dir}")
            return False

        image_files = []
        for pattern in self._IMAGE_PATTERNS:
            image_files.extend(glob.glob(os.path.join(register_dir, pattern)))
            image_files.extend(
                glob.glob(os.path.join(register_dir, pattern.upper())))
        # On case-insensitive filesystems the lower- and upper-case patterns
        # hit the same file; drop duplicates while keeping discovery order so
        # no image is processed (and counted) twice.
        image_files = list(dict.fromkeys(image_files))

        if not image_files:
            print(f"❌ 在目录 {register_dir} 中未找到图片文件")
            return False

        loaded_count = 0
        for image_path in image_files:
            person_name = os.path.splitext(os.path.basename(image_path))[0]

            img = cv2.imread(image_path)
            if img is None:
                print(f"❌ 无法读取图片: {image_path}")
                continue

            faces = self.app.get(img)
            if not faces:
                print(f"❌ 图片中未检测到人脸: {image_path}")
                continue

            # Only the first detected face of the image is registered.
            self.registered_faces[person_name] = faces[0].embedding
            loaded_count += 1
            print(f"✅ 加载注册人脸: {person_name}")

        print(f"🎉 成功加载 {loaded_count} 张注册人脸")
        return loaded_count > 0

    def find_best_match(self, embedding: np.ndarray) -> Tuple[Optional[str], float]:
        """Return the registered name most similar to ``embedding``.

        Similarity is the cosine similarity between the query and each
        registered embedding. Only strictly positive similarities are
        considered, so ``(None, 0.0)`` is returned when nothing is registered
        or no registered face has a positive similarity.

        Returns:
            ``(name, similarity)`` of the best candidate. The caller is
            responsible for comparing against ``similarity_threshold``.
        """
        if not self.registered_faces:
            return None, 0.0

        query_norm = np.linalg.norm(embedding)
        # A degenerate query cannot be normalised; treating it as "no match"
        # avoids a division by zero that would only yield NaN similarities.
        if not np.isfinite(query_norm) or query_norm == 0:
            return None, 0.0
        query_emb = embedding / query_norm

        best_name = None
        best_similarity = 0.0
        for name, registered_embedding in self.registered_faces.items():
            reg_norm = np.linalg.norm(registered_embedding)
            if not np.isfinite(reg_norm) or reg_norm == 0:
                continue  # unusable reference embedding
            similarity = float(np.dot(query_emb, registered_embedding / reg_norm))
            if similarity > best_similarity:
                best_similarity = similarity
                best_name = name

        return best_name, best_similarity

    def _is_match(self, best_name: Optional[str], similarity: float) -> bool:
        """Tell whether a candidate counts as a hit against the registered set.

        The rule is the same in blacklist and whitelist mode: a candidate
        exists and reaches ``similarity_threshold``.
        """
        return bool(best_name is not None
                    and similarity >= self.similarity_threshold)

    def process_frame(self, frame: np.ndarray) -> Tuple[np.ndarray, List[Dict]]:
        """Detect, identify and annotate all faces in one frame.

        Also appends the elapsed time in milliseconds to
        ``processing_times`` and increments ``frame_count``.

        Returns:
            ``(annotated_frame, results)`` with one dict per face holding the
            keys ``bbox``, ``similarity``, ``best_match``, ``is_match``,
            ``gender``, ``age`` and ``det_score``.
        """
        start_time = time.time()

        faces = self.app.get(frame)

        results = []
        for face in faces:
            best_name, similarity = self.find_best_match(face.embedding)

            # NOTE(review): gender/age presumably come from an optional model
            # head; fall back to placeholders instead of crashing when the
            # loaded model pack does not provide them - confirm.
            gender = getattr(face, 'gender', None)
            age = getattr(face, 'age', None)
            if gender is None:
                gender_label = 'Unknown'
            else:
                gender_label = 'Male' if gender == 1 else 'Female'

            result = {
                'bbox': face.bbox.astype(int).tolist(),
                'similarity': similarity,
                'best_match': best_name,
                'is_match': self._is_match(best_name, similarity),
                'gender': gender_label,
                'age': int(age) if age is not None else None,
                'det_score': float(face.det_score),
            }
            results.append(result)

            frame = self._draw_detection(frame, result)

        processing_time = (time.time() - start_time) * 1000
        self.processing_times.append(processing_time)
        self.frame_count += 1

        return frame, results

    def _draw_detection(self, frame: np.ndarray, result: Dict) -> np.ndarray:
        """Draw the box, label and gender/age line for one result on ``frame``."""
        similarity = result['similarity']
        is_match = result['is_match']
        best_match = result['best_match']

        red, green = (0, 0, 255), (0, 255, 0)
        if self.list_mode == "blacklist":
            # Blacklist: a hit is a person of interest (red), others are fine.
            color = red if is_match else green
        else:
            # Whitelist: a hit is an allowed person (green), others are strangers.
            color = green if is_match else red

        x1, y1, x2, y2 = result['bbox']
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

        if best_match and similarity >= self.similarity_threshold:
            name_text = f"{best_match}: {similarity:.3f}"
        else:
            name_text = f"Unknown: {similarity:.3f}"

        status = "MATCH" if is_match else "NO MATCH"
        text = f"{status} | {name_text}"

        # Filled rectangle above the box as background for the label.
        text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
        cv2.rectangle(frame, (x1, y1 - text_size[1] - 10),
                      (x1 + text_size[0], y1), color, -1)
        cv2.putText(frame, text, (x1, y1 - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

        # Gender/age line below the box.
        info_text = f"{result['gender']}/{result['age']}"
        cv2.putText(frame, info_text, (x1, y2 + 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        return frame

    def _draw_overlay(self, frame: np.ndarray, headline: str, results: List[Dict]):
        """Draw the status headline and the match counter in the top-left corner."""
        cv2.putText(frame, headline, (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
        match_count = sum(1 for r in results if r['is_match'])
        list_text = f"Match: {match_count}/{len(results)}"
        cv2.putText(frame, list_text, (10, 60),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)

    @staticmethod
    def _ensure_parent_dir(output_path: str):
        """Create the parent directory of ``output_path`` if it names one.

        A bare file name has an empty parent, and ``os.makedirs("")`` raises
        ``FileNotFoundError``, so that case is skipped.
        """
        parent = os.path.dirname(output_path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    def _open_writer(self, output_path: str, fps: float, frame_size: Tuple[int, int]):
        """Open an mp4v ``VideoWriter``; return None (with a warning) on failure."""
        self._ensure_parent_dir(output_path)
        # Captures (cameras in particular) may report 0 or NaN as frame rate.
        if not fps or fps != fps or fps <= 0:
            fps = self._FALLBACK_FPS
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(output_path, fourcc, fps, frame_size)
        if not writer.isOpened():
            print(f"⚠️ 无法创建输出视频: {output_path}")
            writer.release()
            return None
        return writer

    @staticmethod
    def _print_progress(frame_index: int, total_frames: int):
        """Print a progress line; the percentage is omitted when the total is unknown."""
        if total_frames > 0:
            progress = (frame_index / total_frames) * 100
            print(f"📊 处理进度: {progress:.1f}% ({frame_index}/{total_frames})")
        else:
            print(f"📊 处理进度: {frame_index} 帧")

    def process_video_file(self, video_path: str, output_path: Optional[str] = None,
                           skip_frames: int = 0, show_preview: bool = True):
        """Run recognition over a video file.

        Args:
            video_path: Input video path.
            output_path: Optional path of the annotated output video.
            skip_frames: Number of frames to skip after each processed frame
                (0 processes every frame). Negative values are treated as 0.
            show_preview: Show a live preview window; pressing 'q' stops early.
        """
        if not os.path.exists(video_path):
            print(f"❌ 视频文件不存在: {video_path}")
            return

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"❌ 无法打开视频文件: {video_path}")
            cap.release()
            return

        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        print(f"📹 视频信息: {width}x{height}, {fps:.1f}FPS, 总帧数: {total_frames}")
        print(f"🎯 当前模式: {self.list_mode}, 注册人脸数: {len(self.registered_faces)}")

        # One frame out of every `stride` is processed.
        stride = max(0, int(skip_frames)) + 1

        out = None
        frame_index = 0
        processed_frames = 0
        # Only timings recorded from here on belong to this run.
        timings_start = len(self.processing_times)
        start_time = time.time()

        try:
            if output_path:
                # The output holds only processed frames, so its frame rate is
                # reduced accordingly to keep the playback speed.
                out = self._open_writer(output_path, fps / stride, (width, height))

            print("🚀 开始处理视频...")

            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                current_index = frame_index
                frame_index += 1
                if current_index % stride != 0:
                    continue

                processed_frame, results = self.process_frame(frame)
                processed_frames += 1

                if out is not None:
                    out.write(processed_frame)

                if show_preview:
                    # The overlay is drawn after writing, so it is preview-only.
                    headline = (f"Frame: {current_index}/{total_frames} | "
                                f"Faces: {len(results)} | Mode: {self.list_mode}")
                    self._draw_overlay(processed_frame, headline, results)
                    cv2.imshow('Video Face Recognition', processed_frame)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break

                # Counting processed frames (not the raw index) keeps progress
                # output working when frames are skipped.
                if processed_frames % self._PROGRESS_INTERVAL == 0:
                    self._print_progress(frame_index, total_frames)
        finally:
            # Release everything even if processing raised.
            cap.release()
            if out is not None:
                out.release()
            if show_preview:
                cv2.destroyAllWindows()

        total_time = time.time() - start_time
        run_times = self.processing_times[timings_start:]
        avg_processing_time = np.mean(run_times) if run_times else 0
        actual_fps = processed_frames / total_time if total_time > 0 else 0.0

        print(f"\n🎉 视频处理完成!")
        print(f"📊 性能统计:")
        print(f" 总处理帧数: {processed_frames}")
        print(f" 总耗时: {total_time:.1f}秒")
        print(f" 平均每帧: {avg_processing_time:.1f}ms")
        print(f" 实际FPS: {actual_fps:.1f}")
        if output_path:
            print(f" 输出视频: {output_path}")

    def process_webcam(self, camera_id: int = 0, output_path: Optional[str] = None):
        """Run recognition on a live camera stream until 'q' is pressed.

        Args:
            camera_id: Index of the camera passed to ``cv2.VideoCapture``.
            output_path: Optional path for recording the annotated stream.
        """
        cap = cv2.VideoCapture(camera_id)
        if not cap.isOpened():
            print(f"❌ 无法打开摄像头 {camera_id}")
            cap.release()
            return

        # Request 1280x720; the size actually delivered is read back below.
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)

        out = None
        try:
            if output_path:
                fps = cap.get(cv2.CAP_PROP_FPS)
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                out = self._open_writer(output_path, fps, (width, height))

            print(f"🎥 开始摄像头实时识别 - 模式: {self.list_mode} (按 'q' 退出)...")
            print(f"📋 注册人脸数: {len(self.registered_faces)}")

            while True:
                ret, frame = cap.read()
                if not ret:
                    print("❌ 无法读取摄像头帧")
                    break

                processed_frame, results = self.process_frame(frame)

                # Instantaneous rate derived from the last frame's processing
                # time; guard against a zero reading from a coarse clock.
                last_ms = self.processing_times[-1] if self.processing_times else 0
                current_fps = 1000 / last_ms if last_ms > 0 else 0
                headline = (f"FPS: {current_fps:.1f} | Faces: {len(results)} | "
                            f"Mode: {self.list_mode}")
                self._draw_overlay(processed_frame, headline, results)

                if out is not None:
                    out.write(processed_frame)

                cv2.imshow('Real-time Face Recognition', processed_frame)

                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            # Release everything even if processing raised.
            cap.release()
            if out is not None:
                out.release()
            cv2.destroyAllWindows()

        print("✅ 摄像头处理结束")
|
||
|
||
|
||
# 使用示例
|
||
def main():
    """Demo entry point: whitelist recognition on a sample video file.

    The processing mode is fixed to "1" (video file); mode "2" (live camera)
    is kept as an alternative branch.
    """
    recognizer = VideoFaceRecognition(use_gpu=True)

    # Treat the registered faces as a whitelist ("blacklist" is the alternative).
    recognizer.set_list_mode("whitelist")

    # Registration images: one picture per person, named after the person.
    register_dir = "test_data/register"
    if not os.path.exists(register_dir):
        print(f"⚠️ 注册目录不存在: {register_dir}")
    else:
        recognizer.load_registered_faces(register_dir)

    # "1" = process a video file, "2" = live camera.
    choice = "1"

    if choice == "1":
        recognizer.process_video_file(
            video_path="test_data/video/video_1.mp4",
            output_path="test_data/output_video/video_1_white.mp4",
            skip_frames=1,  # process every second frame for speed
            show_preview=True,
        )
        return

    if choice == "2":
        recognizer.process_webcam(
            camera_id=0,
            output_path="webcam_recording.mp4",  # optional recording
        )
        return

    print("❌ 无效选择")
|
||
|
||
|
||
# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()