Files
SupervisorAI/backup/face_recognition_system.py
2025-12-20 18:07:49 +08:00

249 lines
8.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import numpy as np
from insightface.app import FaceAnalysis
from insightface.data import get_image as ins_get_image
import pickle
from typing import List, Tuple, Dict
import json
class FaceRecognitionSystem:
    """Face detection, 1:1 verification and 1:N search built on insightface.

    Registered faces are kept in ``self.face_database`` as
    ``person_id -> {'embedding': np.ndarray, 'image_path': str}`` and are
    persisted to ``self.database_file`` with pickle.
    """

    def __init__(self, model_name: str = 'buffalo_l',
                 database_file: str = "face_database.pkl",
                 ctx_id: int = 0,
                 det_size: Tuple[int, int] = (640, 640)):
        """
        Initialize the face recognition system and load any saved database.

        Args:
            model_name: insightface model pack name, e.g. 'buffalo_l' or 'buffalo_s'.
            database_file: path of the pickle file used to persist registered faces.
            ctx_id: device id forwarded to ``FaceAnalysis.prepare``.
            det_size: detector input size forwarded to ``FaceAnalysis.prepare``.
        """
        self.app = FaceAnalysis(name=model_name, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.app.prepare(ctx_id=ctx_id, det_size=det_size)
        self.face_database = {}  # person_id -> {'embedding': np.ndarray, 'image_path': str}
        self.database_file = database_file
        # Load a previously saved database, if one exists.
        self.load_database()

    @staticmethod
    def _read_image(image_path: str) -> np.ndarray:
        """Read an image with OpenCV; raise ValueError if it cannot be read/decoded."""
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"无法读取图像: {image_path}")
        return img

    def extract_face_features(self, image_path: str) -> List[Dict]:
        """
        Detect all faces in an image and extract their features.

        Args:
            image_path: path of the image file.

        Returns:
            List[Dict]: one dict per detected face with keys 'bbox' (int list),
            'kps' (int list of keypoints), 'embedding' (float list),
            'gender' and 'age' (plain ints, or None if the model pack did not
            provide them).

        Raises:
            ValueError: if the image cannot be read.
        """
        img = self._read_image(image_path)
        faces = self.app.get(img)
        results = []
        for face in faces:
            results.append({
                'bbox': face.bbox.astype(int).tolist(),   # face bounding box
                'kps': face.kps.astype(int).tolist(),     # facial keypoints
                'embedding': face.embedding.tolist(),     # feature vector
                # Convert numpy scalars to plain ints so the whole dict uses
                # builtin types, like the list-converted fields above.
                'gender': None if face.gender is None else int(face.gender),
                'age': None if face.age is None else int(face.age),
            })
        return results

    def register_face(self, image_path: str, person_id: str) -> bool:
        """
        Register the face found in an image under ``person_id`` and save the database.

        If several faces are detected, the first one returned by the detector is
        used. An existing entry with the same ``person_id`` is overwritten.

        Args:
            image_path: path of the image file.
            person_id: identifier to store the face under.

        Returns:
            bool: True if a face was registered, False if no face was detected.

        Raises:
            ValueError: if the image cannot be read.
        """
        faces = self.extract_face_features(image_path)
        if not faces:
            print(f"在图像 {image_path} 中未检测到人脸")
            return False
        if len(faces) > 1:
            print(f"在图像 {image_path} 中检测到多个人脸,将使用第一个人脸")
        # Store only the first face's embedding.
        self.face_database[person_id] = {
            'embedding': np.array(faces[0]['embedding']),
            'image_path': image_path
        }
        self.save_database()
        print(f"成功注册人脸: {person_id}")
        return True

    def compare_faces(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """
        Compute the cosine similarity of two face embeddings.

        Args:
            embedding1: first feature vector.
            embedding2: second feature vector.

        Returns:
            float: similarity in [-1, 1]; larger means more similar. Returns 0.0
            when either vector has zero norm (instead of NaN from 0/0).
        """
        norm_product = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
        if norm_product == 0:
            return 0.0
        return float(np.dot(embedding1, embedding2) / norm_product)

    def one_vs_one(self, image_path1: str, image_path2: str,
                   threshold: float = 0.6) -> Tuple[float, bool]:
        """
        1:1 face verification between the first face of each image.

        Args:
            image_path1: path of the first image.
            image_path2: path of the second image.
            threshold: similarity above which both faces count as the same person
                (tune for your data).

        Returns:
            Tuple[float, bool]: (similarity, is_same_person). Returns (0.0, False)
            if either image contains no detected face.

        Raises:
            ValueError: if either image cannot be read.
        """
        faces1 = self.extract_face_features(image_path1)
        faces2 = self.extract_face_features(image_path2)
        if not faces1 or not faces2:
            return 0.0, False
        embedding1 = np.array(faces1[0]['embedding'])
        embedding2 = np.array(faces2[0]['embedding'])
        similarity = self.compare_faces(embedding1, embedding2)
        return similarity, similarity > threshold

    def one_vs_many(self, image_path: str, threshold: float = 0.6) -> List[Tuple[str, float]]:
        """
        1:N search of the first face in an image against the registered database.

        Args:
            image_path: path of the query image.
            threshold: only matches with similarity strictly above this are returned.

        Returns:
            List[Tuple[str, float]]: (person_id, similarity) pairs sorted by
            descending similarity; empty if no face is detected or nothing matches.

        Raises:
            ValueError: if the image cannot be read.
        """
        faces = self.extract_face_features(image_path)
        if not faces:
            return []
        query_embedding = np.array(faces[0]['embedding'])
        results = []
        for person_id, data in self.face_database.items():
            similarity = self.compare_faces(query_embedding, data['embedding'])
            if similarity > threshold:
                results.append((person_id, similarity))
        # Highest similarity first.
        results.sort(key=lambda x: x[1], reverse=True)
        return results

    def save_database(self):
        """
        Persist the face database to ``self.database_file`` with pickle.

        Embeddings are stored as plain lists. The data is first written to a
        temporary file next to the target and then moved into place with
        ``os.replace``, so an interrupted save cannot leave a truncated database.
        """
        save_data = {
            person_id: {
                'embedding': np.asarray(data['embedding']).tolist(),
                'image_path': data['image_path'],
            }
            for person_id, data in self.face_database.items()
        }
        tmp_path = f"{self.database_file}.tmp"
        with open(tmp_path, 'wb') as f:
            pickle.dump(save_data, f)
        os.replace(tmp_path, self.database_file)

    def load_database(self):
        """
        Load the face database from ``self.database_file`` if the file exists.

        Loaded entries are merged into ``self.face_database`` with embeddings
        converted back to numpy arrays.
        """
        if not os.path.exists(self.database_file):
            return
        # SECURITY: pickle.load can execute arbitrary code; only load database
        # files that come from a trusted source.
        with open(self.database_file, 'rb') as f:
            save_data = pickle.load(f)
        for person_id, data in save_data.items():
            self.face_database[person_id] = {
                'embedding': np.array(data['embedding']),
                'image_path': data['image_path']
            }
        print(f"已加载数据库,包含 {len(self.face_database)} 个人脸")

    def visualize_detection(self, image_path: str, save_path: str = None):
        """
        Draw detection results (box, keypoints, gender/age label) on an image.

        Args:
            image_path: path of the input image.
            save_path: optional output path; its parent directory is created if
                needed.

        Returns:
            The annotated BGR image array.

        Raises:
            ValueError: if the image cannot be read.
        """
        img = self._read_image(image_path)
        faces = self.app.get(img)
        for face in faces:
            # Face bounding box.
            bbox = face.bbox.astype(int)
            cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
            # Keypoints.
            for kp in face.kps.astype(int):
                cv2.circle(img, (kp[0], kp[1]), 2, (0, 0, 255), -1)
            # Gender/age label; clamp y so faces touching the top edge keep a visible label.
            info = f"{'M' if face.gender == 1 else 'F'}/{face.age}"
            text_y = max(int(bbox[1]) - 10, 10)
            cv2.putText(img, info, (int(bbox[0]), text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
        if save_path:
            # cv2.imwrite fails silently (returns False) if the directory is missing.
            out_dir = os.path.dirname(save_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            cv2.imwrite(save_path, img)
        return img
# Usage example
def main(img1: str = "test_data/register/person1.png",
         img2: str = "test_data/query/file___media_Photo_152_IMG_1737876072_134_IMG_20250126_151932.jpg"):
    """
    Run a 1v1 face comparison demo on two sample images.

    Args:
        img1: path of the first (reference) image.
        img2: path of the second (query) image.

    The other features of FaceRecognitionSystem can be exercised the same way:
      - registration:  face_system.register_face(img_path, person_id)
      - 1:N search:    face_system.one_vs_many(query_img)
      - visualization: face_system.visualize_detection(test_img, save_path)
    """
    # Initialize the system (loads the saved face database, if any).
    face_system = FaceRecognitionSystem()

    print("\n=== 1v1人脸比对 ===")
    # Report missing inputs instead of silently skipping the comparison.
    missing = [path for path in (img1, img2) if not os.path.exists(path)]
    if missing:
        for path in missing:
            print(f"未找到图像: {path}")
        return
    similarity, is_same = face_system.one_vs_one(img1, img2)
    print(f"相似度: {similarity:.4f}, 是否同一人: {is_same}")


if __name__ == "__main__":
    main()