516 lines
18 KiB
Python
516 lines
18 KiB
Python
"""
|
||
人脸识别业务基类
|
||
处理具体的业务逻辑,包括参数配置、人脸匹配、质量评估等
|
||
"""
|
||
|
||
import cv2
|
||
import numpy as np
|
||
import time
|
||
from typing import List, Dict, Tuple, Optional
|
||
import os
|
||
from insightface.app import FaceAnalysis
|
||
|
||
|
||
class BaseFaceBiz:
|
||
"""
|
||
人脸识别业务基类
|
||
处理具体的业务逻辑,与底层算法分离
|
||
"""
|
||
|
||
def __init__(self, face_analysis: FaceAnalysis):
|
||
"""
|
||
初始化业务类
|
||
|
||
参数:
|
||
face_analysis: 已初始化好的FaceAnalysis实例
|
||
"""
|
||
self.app = face_analysis
|
||
|
||
# 业务参数配置
|
||
self.list_mode = "0" # 0 = blacklist, 1= whitelist
|
||
self.clarity_threshold = 100.0 # 清晰度阈值,低于此值认为人脸模糊
|
||
self.min_face_size = 20 # 最小人脸像素尺寸
|
||
self.pitch_threshold = 90 # 俯仰角阈值
|
||
self.yaw_threshold = 90 # 偏航角阈值
|
||
self.similarity_threshold = 0.3 # 相似度阈值
|
||
|
||
# 名单相关变量
|
||
self.registered_faces = {} # {name: embedding}
|
||
|
||
def set_list_mode(self, mode: str):
|
||
"""设置名单模式"""
|
||
if mode.lower() in ["0", "1"]:
|
||
self.list_mode = mode.lower()
|
||
print(f"✅ 名单模式设置为: {self.list_mode}")
|
||
else:
|
||
print("❌ 无效的名单模式,请使用 '0' 或 '1'")
|
||
|
||
def get_list_mode(self) -> str:
|
||
"""获取当前名单模式"""
|
||
return self.list_mode
|
||
|
||
def set_clarity_threshold(self, threshold: float):
|
||
"""设置清晰度阈值"""
|
||
self.clarity_threshold = threshold
|
||
print(f"✅ 清晰度阈值设置为: {threshold}")
|
||
|
||
def get_clarity_threshold(self) -> float:
|
||
"""获取清晰度阈值"""
|
||
return self.clarity_threshold
|
||
|
||
def set_min_face_size(self, size: int):
|
||
"""设置最小人脸尺寸"""
|
||
self.min_face_size = size
|
||
print(f"✅ 最小人脸尺寸设置为: {size}")
|
||
|
||
def get_min_face_size(self) -> int:
|
||
"""获取最小人脸尺寸"""
|
||
return self.min_face_size
|
||
|
||
def set_pitch_threshold(self, threshold: float):
|
||
"""设置俯仰角阈值"""
|
||
self.pitch_threshold = threshold
|
||
print(f"✅ 俯仰角阈值设置为: {threshold}")
|
||
|
||
def get_pitch_threshold(self) -> float:
|
||
"""获取俯仰角阈值"""
|
||
return self.pitch_threshold
|
||
|
||
def set_yaw_threshold(self, threshold: float):
|
||
"""设置偏航角阈值"""
|
||
self.yaw_threshold = threshold
|
||
print(f"✅ 偏航角阈值设置为: {threshold}")
|
||
|
||
def get_yaw_threshold(self) -> float:
|
||
"""获取偏航角阈值"""
|
||
return self.yaw_threshold
|
||
|
||
def set_similarity_threshold(self, threshold: float):
|
||
"""设置相似度阈值"""
|
||
self.similarity_threshold = threshold
|
||
print(f"✅ 相似度阈值设置为: {threshold}")
|
||
|
||
def get_similarity_threshold(self) -> float:
|
||
"""获取相似度阈值"""
|
||
return self.similarity_threshold
|
||
|
||
def set_registered_faces(self, registered_faces: Dict[str, np.ndarray]):
|
||
"""
|
||
直接设置已注册的人脸数据
|
||
|
||
参数:
|
||
registered_faces: 字典格式 {name: embedding}
|
||
"""
|
||
if not isinstance(registered_faces, dict):
|
||
print("❌ 参数必须是字典格式 {name: embedding}")
|
||
return False
|
||
|
||
# 验证数据格式
|
||
valid_count = 0
|
||
for name, embedding in registered_faces.items():
|
||
if isinstance(embedding, np.ndarray) and embedding.size > 0:
|
||
valid_count += 1
|
||
|
||
if valid_count == 0:
|
||
print("❌ 未找到有效的人脸嵌入数据")
|
||
return False
|
||
|
||
self.registered_faces = registered_faces
|
||
print(f"✅ 成功设置 {valid_count} 个注册人脸")
|
||
return True
|
||
|
||
def get_registered_faces(self) -> Dict[str, np.ndarray]:
|
||
"""获取已注册的人脸数据"""
|
||
return self.registered_faces
|
||
|
||
def get_registered_face_count(self) -> int:
|
||
"""获取已注册人脸数量"""
|
||
return len(self.registered_faces)
|
||
|
||
def load_registered_faces(self, register_dir: str):
|
||
"""
|
||
从目录加载注册的人脸图片
|
||
文件名(去掉后缀)即为人的名字
|
||
"""
|
||
import glob
|
||
|
||
if not os.path.exists(register_dir):
|
||
print(f"❌ 注册目录不存在: {register_dir}")
|
||
return False
|
||
|
||
# 支持的图片格式
|
||
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp']
|
||
image_files = []
|
||
|
||
for ext in image_extensions:
|
||
image_files.extend(glob.glob(os.path.join(register_dir, ext)))
|
||
image_files.extend(glob.glob(os.path.join(register_dir, ext.upper())))
|
||
|
||
if not image_files:
|
||
print(f"❌ 在目录 {register_dir} 中未找到图片文件")
|
||
return False
|
||
|
||
loaded_count = 0
|
||
for image_path in image_files:
|
||
# 获取文件名(不含扩展名)作为人名
|
||
person_name = os.path.splitext(os.path.basename(image_path))[0]
|
||
|
||
# 读取图片并提取人脸特征
|
||
img = cv2.imread(image_path)
|
||
if img is None:
|
||
print(f"❌ 无法读取图片: {image_path}")
|
||
continue
|
||
|
||
faces = self.app.get(img)
|
||
if not faces:
|
||
print(f"❌ 图片中未检测到人脸: {image_path}")
|
||
continue
|
||
|
||
# 使用第一张检测到的人脸
|
||
self.registered_faces[person_name] = faces[0].embedding
|
||
loaded_count += 1
|
||
print(f"✅ 加载注册人脸: {person_name}")
|
||
|
||
print(f"🎉 成功加载 {loaded_count} 张注册人脸")
|
||
return loaded_count > 0
|
||
|
||
def find_best_match(self, embedding: np.ndarray) -> Tuple[Optional[str], float]:
|
||
"""
|
||
在注册人脸中查找最佳匹配
|
||
|
||
返回:
|
||
(匹配的人名, 相似度)
|
||
"""
|
||
if not self.registered_faces:
|
||
return None, 0.0
|
||
|
||
best_similarity = 0.0
|
||
best_name = None
|
||
|
||
# 归一化查询嵌入
|
||
query_emb = embedding / np.linalg.norm(embedding)
|
||
|
||
for name, registered_embedding in self.registered_faces.items():
|
||
# 归一化注册嵌入
|
||
reg_emb = registered_embedding / np.linalg.norm(registered_embedding)
|
||
|
||
# 计算余弦相似度
|
||
similarity = float(np.dot(query_emb, reg_emb))
|
||
|
||
if similarity > best_similarity:
|
||
best_similarity = similarity
|
||
best_name = name
|
||
|
||
return best_name, best_similarity
|
||
|
||
def calculate_clarity(self, face_region: np.ndarray) -> float:
|
||
"""
|
||
计算人脸区域的清晰度/模糊度
|
||
使用拉普拉斯方差方法:值越高表示图像越清晰
|
||
"""
|
||
if len(face_region.shape) == 3:
|
||
gray = cv2.cvtColor(face_region, cv2.COLOR_BGR2GRAY)
|
||
else:
|
||
gray = face_region
|
||
|
||
# 计算拉普拉斯算子的方差
|
||
laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()
|
||
return laplacian_var
|
||
|
||
def is_face_quality_acceptable(self, face, frame: np.ndarray) -> Tuple[bool, Dict]:
|
||
"""
|
||
综合判断人脸质量是否可接受
|
||
|
||
返回:
|
||
(是否可接受, 质量指标字典)
|
||
"""
|
||
quality_metrics = {}
|
||
is_acceptable = True
|
||
|
||
# 1. 检测置信度
|
||
quality_metrics['det_score'] = float(face.det_score)
|
||
|
||
# 2. 人脸姿态角度
|
||
if hasattr(face, 'pose') and face.pose is not None:
|
||
pitch, yaw, roll = face.pose
|
||
quality_metrics['pitch'] = float(pitch)
|
||
quality_metrics['yaw'] = float(yaw)
|
||
quality_metrics['roll'] = float(roll)
|
||
else:
|
||
quality_metrics['pitch'] = 100.0
|
||
quality_metrics['yaw'] = 100.0
|
||
quality_metrics['roll'] = 100.0
|
||
|
||
# 3. 人脸边界框信息
|
||
bbox = face.bbox
|
||
x1, y1, x2, y2 = bbox.astype(int)
|
||
width = x2 - x1
|
||
height = y2 - y1
|
||
quality_metrics['bbox_width'] = width
|
||
quality_metrics['bbox_height'] = height
|
||
quality_metrics['bbox_area'] = width * height
|
||
quality_metrics['aspect_ratio'] = width / height if height > 0 else 0
|
||
|
||
# 4. 图像清晰度检测
|
||
h, w = frame.shape[:2]
|
||
x1_clip = max(0, x1)
|
||
y1_clip = max(0, y1)
|
||
x2_clip = min(w, x2)
|
||
y2_clip = min(h, y2)
|
||
|
||
if x2_clip > x1_clip and y2_clip > y1_clip:
|
||
face_region = frame[y1_clip:y2_clip, x1_clip:x2_clip]
|
||
clarity_score = self.calculate_clarity(face_region)
|
||
quality_metrics['clarity_score'] = clarity_score
|
||
else:
|
||
quality_metrics['clarity_score'] = 0.0
|
||
|
||
# 5. 综合质量评分
|
||
base_score = quality_metrics['det_score']
|
||
|
||
# 清晰度惩罚
|
||
if quality_metrics['clarity_score'] < self.clarity_threshold:
|
||
is_acceptable = False
|
||
|
||
# 姿态惩罚
|
||
if abs(quality_metrics['yaw']) > self.yaw_threshold:
|
||
is_acceptable = False
|
||
if abs(quality_metrics['pitch']) > self.pitch_threshold:
|
||
is_acceptable = False
|
||
|
||
# 尺寸惩罚
|
||
if min(width, height) < self.min_face_size:
|
||
is_acceptable = False
|
||
|
||
quality_metrics['quality_score'] = base_score
|
||
|
||
return is_acceptable, quality_metrics
|
||
|
||
def process_frame(self, frame: np.ndarray) -> Tuple[np.ndarray, List[Dict], float]:
|
||
"""
|
||
处理单帧图像
|
||
|
||
返回:
|
||
(原始帧, 识别结果列表, 处理时间ms)
|
||
"""
|
||
start_time = time.time()
|
||
|
||
# 人脸检测和识别
|
||
faces = self.app.get(frame)
|
||
|
||
results = []
|
||
for face in faces:
|
||
# 检查人脸质量是否可接受
|
||
is_acceptable, quality_metrics = self.is_face_quality_acceptable(face, frame)
|
||
|
||
# 查找最佳匹配
|
||
best_name, similarity = self.find_best_match(face.embedding)
|
||
|
||
# 根据名单模式判断是否匹配
|
||
if self.list_mode == "0":
|
||
# 黑名单模式:在黑名单中即为匹配(需要关注)
|
||
is_match = best_name is not None and similarity >= self.similarity_threshold
|
||
else: # whitelist
|
||
# 白名单模式:在白名单中即为匹配(允许通过)
|
||
is_match = best_name is not None and similarity >= self.similarity_threshold
|
||
|
||
result = {
|
||
'bbox': face.bbox.astype(int).tolist(),
|
||
'similarity': similarity,
|
||
'best_match': best_name,
|
||
'is_match': is_match,
|
||
'det_score': float(face.det_score),
|
||
'quality_metrics': quality_metrics,
|
||
'is_acceptable': is_acceptable
|
||
}
|
||
results.append(result)
|
||
|
||
processing_time = (time.time() - start_time) * 1000
|
||
return frame, results, processing_time
|
||
|
||
def _draw_detection(self, frame: np.ndarray, result: Dict) -> np.ndarray:
|
||
"""在帧上绘制检测结果和质量信息"""
|
||
bbox = result['bbox']
|
||
similarity = result['similarity']
|
||
is_match = result['is_match']
|
||
is_acceptable = result['is_acceptable']
|
||
quality_metrics = result['quality_metrics']
|
||
best_match = result['best_match']
|
||
|
||
# 选择颜色
|
||
if not is_acceptable:
|
||
color = (128, 128, 128) # 灰色 - 质量不可接受
|
||
else:
|
||
# 选择颜色 - 根据名单模式
|
||
if self.list_mode == "0":
|
||
# 黑名单模式:匹配(在黑名单中)显示红色,不匹配显示绿色
|
||
color = (0, 0, 255) if is_match else (0, 255, 0) # 红色-黑名单, 绿色-正常
|
||
else: # whitelist
|
||
# 白名单模式:匹配(在白名单中)显示绿色,不匹配显示红色
|
||
color = (0, 255, 0) if is_match else (0, 0, 255) # 绿色-白名单, 红色-陌生人
|
||
|
||
# 绘制人脸框
|
||
x1, y1, x2, y2 = bbox
|
||
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
|
||
|
||
# 准备显示文本
|
||
text_lines = []
|
||
|
||
# 第一行:匹配状态
|
||
if not is_acceptable:
|
||
text_lines.append("LOW QUALITY")
|
||
else:
|
||
status = f"MATCH: {best_match}: {similarity:.3f}" if is_match else f"NO MATCH: {similarity:.3f}"
|
||
text_lines.append(status)
|
||
|
||
# 第二行:质量得分
|
||
text_lines.append(f"Quality: {quality_metrics['quality_score']:.3f}")
|
||
|
||
# 第三行:检测得分
|
||
text_lines.append(f"DetScore: {quality_metrics['det_score']:.3f}")
|
||
|
||
# 第四行:清晰度
|
||
text_lines.append(f"Clarity: {quality_metrics['clarity_score']:.1f}")
|
||
|
||
# 第五行:姿态角度
|
||
text_lines.append(f"Pitch: {quality_metrics['pitch']:.1f}°")
|
||
text_lines.append(f"Yaw: {quality_metrics['yaw']:.1f}°")
|
||
|
||
# 计算文本区域大小
|
||
max_text_width = 0
|
||
total_text_height = 0
|
||
line_heights = []
|
||
|
||
for line in text_lines:
|
||
(text_width, text_height), baseline = cv2.getTextSize(line, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
|
||
max_text_width = max(max_text_width, text_width)
|
||
line_heights.append(text_height + baseline)
|
||
total_text_height += text_height + baseline + 2
|
||
|
||
# 绘制文本背景
|
||
bg_x1 = x1
|
||
bg_y1 = y1 - total_text_height - 10
|
||
bg_x2 = x1 + max_text_width + 10
|
||
bg_y2 = y1
|
||
|
||
# 如果背景超出图像顶部,调整到框下方
|
||
if bg_y1 < 0:
|
||
bg_y1 = y2
|
||
bg_y2 = y2 + total_text_height + 10
|
||
|
||
# 绘制半透明背景
|
||
overlay = frame.copy()
|
||
cv2.rectangle(overlay, (bg_x1, bg_y1), (bg_x2, bg_y2), (0, 0, 0), -1)
|
||
alpha = 0.6
|
||
cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0, frame)
|
||
|
||
# 绘制文本
|
||
current_y = bg_y1 + 15
|
||
for i, line in enumerate(text_lines):
|
||
# 根据内容选择颜色
|
||
if i == 0: # 状态行
|
||
if not is_acceptable:
|
||
text_color = (128, 128, 128) # 灰色 - 质量差
|
||
elif is_match:
|
||
text_color = (0, 255, 0) # 绿色 - 匹配
|
||
else:
|
||
text_color = (0, 0, 255) # 红色 - 不匹配
|
||
elif i == 3: # 清晰度行
|
||
if quality_metrics['clarity_score'] >= self.clarity_threshold:
|
||
text_color = (255, 255, 255)
|
||
else:
|
||
text_color = (0, 0, 255)
|
||
elif i == 4: # pitch
|
||
if abs(quality_metrics['pitch']) > self.pitch_threshold:
|
||
text_color = (0, 0, 255)
|
||
else:
|
||
text_color = (255, 255, 255)
|
||
elif i == 5: # yaw
|
||
if abs(quality_metrics['yaw']) > self.yaw_threshold:
|
||
text_color = (0, 0, 255)
|
||
else:
|
||
text_color = (255, 255, 255)
|
||
else:
|
||
text_color = (255, 255, 255) # 白色 - 其他信息
|
||
|
||
cv2.putText(frame, line, (x1 + 5, current_y),
|
||
cv2.FONT_HERSHEY_SIMPLEX, 0.4, text_color, 1)
|
||
current_y += line_heights[i]
|
||
|
||
return frame
|
||
|
||
def draw_detections(self, frame: np.ndarray, results: List[Dict]) -> np.ndarray:
|
||
"""绘制所有检测结果"""
|
||
for result in results:
|
||
frame = self._draw_detection(frame, result)
|
||
return frame
|
||
|
||
def extract_face_feature(self, image_path: str) -> Optional[np.ndarray]:
|
||
"""
|
||
从单张图片中提取人脸特征值
|
||
|
||
参数:
|
||
image_path: 人脸图片路径
|
||
|
||
返回:
|
||
numpy数组格式的人脸特征值,如果检测失败返回None
|
||
"""
|
||
if not os.path.exists(image_path):
|
||
print(f"❌ 图片文件不存在: {image_path}")
|
||
return None
|
||
|
||
# 读取图片
|
||
img = cv2.imread(image_path)
|
||
if img is None:
|
||
print(f"❌ 无法读取图片: {image_path}")
|
||
return None
|
||
|
||
# 人脸检测
|
||
faces = self.app.get(img)
|
||
if not faces:
|
||
print(f"❌ 图片中未检测到人脸: {image_path}")
|
||
return None
|
||
|
||
# 使用第一张检测到的人脸
|
||
face = faces[0]
|
||
|
||
# 检查人脸质量
|
||
is_acceptable, quality_metrics = self.is_face_quality_acceptable(face, img)
|
||
|
||
if not is_acceptable:
|
||
print(f"⚠️ 人脸质量不可接受: {image_path}")
|
||
print(f" 质量得分: {quality_metrics['quality_score']:.3f}")
|
||
print(f" 清晰度: {quality_metrics['clarity_score']:.1f}")
|
||
print(f" 姿态角度: pitch={quality_metrics['pitch']:.1f}°, yaw={quality_metrics['yaw']:.1f}°")
|
||
return None
|
||
|
||
print(f"✅ 成功提取人脸特征: {image_path}")
|
||
print(f" 检测得分: {quality_metrics['det_score']:.3f}")
|
||
print(f" 质量得分: {quality_metrics['quality_score']:.3f}")
|
||
|
||
# 返回特征向量
|
||
return face.embedding
|
||
|
||
def extract_face_features_batch(self, image_paths: List[str]) -> Dict[str, Optional[np.ndarray]]:
|
||
"""
|
||
批量从多张图片中提取人脸特征值
|
||
|
||
参数:
|
||
image_paths: 人脸图片路径列表
|
||
|
||
返回:
|
||
字典格式 {图片路径: 特征值},失败的特征值为None
|
||
"""
|
||
results = {}
|
||
|
||
for image_path in image_paths:
|
||
feature = self.extract_face_feature(image_path)
|
||
results[image_path] = feature
|
||
|
||
# 统计结果
|
||
success_count = sum(1 for feature in results.values() if feature is not None)
|
||
total_count = len(image_paths)
|
||
|
||
print(f"🎉 批量提取完成: 成功 {success_count}/{total_count} 张图片")
|
||
|
||
return results |