249 lines
8.2 KiB
Python
249 lines
8.2 KiB
Python
import os
|
||
import cv2
|
||
import numpy as np
|
||
from insightface.app import FaceAnalysis
|
||
from insightface.data import get_image as ins_get_image
|
||
import pickle
|
||
from typing import List, Tuple, Dict
|
||
import json
|
||
|
||
|
||
class FaceRecognitionSystem:
|
||
def __init__(self, model_name: str = 'buffalo_l'):
|
||
"""
|
||
初始化人脸识别系统
|
||
Args:
|
||
model_name: 模型名称,可选 'buffalo_l', 'buffalo_s' 等
|
||
"""
|
||
self.app = FaceAnalysis(name=model_name, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||
self.app.prepare(ctx_id=0, det_size=(640, 640))
|
||
self.face_database = {} # 存储人脸特征向量
|
||
self.database_file = "face_database.pkl"
|
||
|
||
# 加载已有数据库
|
||
self.load_database()
|
||
|
||
def extract_face_features(self, image_path: str) -> List[Dict]:
|
||
"""
|
||
从图像中提取人脸特征
|
||
Args:
|
||
image_path: 图像路径
|
||
Returns:
|
||
List[Dict]: 包含人脸信息的列表
|
||
"""
|
||
img = cv2.imread(image_path)
|
||
if img is None:
|
||
raise ValueError(f"无法读取图像: {image_path}")
|
||
|
||
faces = self.app.get(img)
|
||
results = []
|
||
|
||
for i, face in enumerate(faces):
|
||
face_info = {
|
||
'bbox': face.bbox.astype(int).tolist(), # 人脸框
|
||
'kps': face.kps.astype(int).tolist(), # 关键点
|
||
'embedding': face.embedding.tolist(), # 特征向量
|
||
'gender': face.gender, # 性别
|
||
'age': face.age # 年龄
|
||
}
|
||
results.append(face_info)
|
||
|
||
return results
|
||
|
||
def register_face(self, image_path: str, person_id: str):
|
||
"""
|
||
注册人脸到数据库
|
||
Args:
|
||
image_path: 图像路径
|
||
person_id: 人员ID
|
||
"""
|
||
faces = self.extract_face_features(image_path)
|
||
|
||
if not faces:
|
||
print(f"在图像 {image_path} 中未检测到人脸")
|
||
return False
|
||
|
||
if len(faces) > 1:
|
||
print(f"在图像 {image_path} 中检测到多个人脸,将使用第一个人脸")
|
||
|
||
# 存储第一个人脸的特征
|
||
self.face_database[person_id] = {
|
||
'embedding': np.array(faces[0]['embedding']),
|
||
'image_path': image_path
|
||
}
|
||
|
||
self.save_database()
|
||
print(f"成功注册人脸: {person_id}")
|
||
return True
|
||
|
||
def compare_faces(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
|
||
"""
|
||
计算两个人脸特征的相似度
|
||
Args:
|
||
embedding1: 特征向量1
|
||
embedding2: 特征向量2
|
||
Returns:
|
||
float: 相似度得分 (0-1之间,越大越相似)
|
||
"""
|
||
# 计算余弦相似度
|
||
similarity = np.dot(embedding1, embedding2) / (
|
||
np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
|
||
)
|
||
return float(similarity)
|
||
|
||
def one_vs_one(self, image_path1: str, image_path2: str) -> Tuple[float, bool]:
|
||
"""
|
||
1v1人脸比对
|
||
Args:
|
||
image_path1: 图像1路径
|
||
image_path2: 图像2路径
|
||
Returns:
|
||
Tuple[float, bool]: (相似度得分, 是否同一人)
|
||
"""
|
||
faces1 = self.extract_face_features(image_path1)
|
||
faces2 = self.extract_face_features(image_path2)
|
||
|
||
if not faces1 or not faces2:
|
||
return 0.0, False
|
||
|
||
embedding1 = np.array(faces1[0]['embedding'])
|
||
embedding2 = np.array(faces2[0]['embedding'])
|
||
|
||
similarity = self.compare_faces(embedding1, embedding2)
|
||
is_same = similarity > 0.6 # 阈值可根据实际情况调整
|
||
|
||
return similarity, is_same
|
||
|
||
def one_vs_many(self, image_path: str, threshold: float = 0.6) -> List[Tuple[str, float]]:
|
||
"""
|
||
1vn人脸检索
|
||
Args:
|
||
image_path: 查询图像路径
|
||
threshold: 相似度阈值
|
||
Returns:
|
||
List[Tuple[str, float]]: 匹配结果 (人员ID, 相似度)
|
||
"""
|
||
faces = self.extract_face_features(image_path)
|
||
if not faces:
|
||
return []
|
||
|
||
query_embedding = np.array(faces[0]['embedding'])
|
||
results = []
|
||
|
||
for person_id, data in self.face_database.items():
|
||
similarity = self.compare_faces(query_embedding, data['embedding'])
|
||
if similarity > threshold:
|
||
results.append((person_id, similarity))
|
||
|
||
# 按相似度降序排序
|
||
results.sort(key=lambda x: x[1], reverse=True)
|
||
return results
|
||
|
||
def save_database(self):
|
||
"""保存人脸数据库到文件"""
|
||
# 将numpy数组转换为列表以便序列化
|
||
save_data = {}
|
||
for person_id, data in self.face_database.items():
|
||
save_data[person_id] = {
|
||
'embedding': data['embedding'].tolist(),
|
||
'image_path': data['image_path']
|
||
}
|
||
|
||
with open(self.database_file, 'wb') as f:
|
||
pickle.dump(save_data, f)
|
||
|
||
def load_database(self):
|
||
"""从文件加载人脸数据库"""
|
||
if os.path.exists(self.database_file):
|
||
with open(self.database_file, 'rb') as f:
|
||
save_data = pickle.load(f)
|
||
|
||
# 将列表转换回numpy数组
|
||
for person_id, data in save_data.items():
|
||
self.face_database[person_id] = {
|
||
'embedding': np.array(data['embedding']),
|
||
'image_path': data['image_path']
|
||
}
|
||
print(f"已加载数据库,包含 {len(self.face_database)} 个人脸")
|
||
|
||
def visualize_detection(self, image_path: str, save_path: str = None):
|
||
"""
|
||
可视化人脸检测结果
|
||
Args:
|
||
image_path: 图像路径
|
||
save_path: 保存路径
|
||
"""
|
||
img = cv2.imread(image_path)
|
||
faces = self.app.get(img)
|
||
|
||
for face in faces:
|
||
# 绘制人脸框
|
||
bbox = face.bbox.astype(int)
|
||
cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
|
||
|
||
# 绘制关键点
|
||
for kp in face.kps.astype(int):
|
||
cv2.circle(img, (kp[0], kp[1]), 2, (0, 0, 255), -1)
|
||
|
||
# 显示性别和年龄
|
||
info = f"{'M' if face.gender == 1 else 'F'}/{face.age}"
|
||
cv2.putText(img, info, (bbox[0], bbox[1] - 10),
|
||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
|
||
|
||
if save_path:
|
||
cv2.imwrite(save_path, img)
|
||
|
||
return img
|
||
|
||
|
||
# 使用示例
|
||
def main():
|
||
# 初始化系统
|
||
face_system = FaceRecognitionSystem()
|
||
|
||
# # 创建测试图像目录
|
||
# os.makedirs("test_images", exist_ok=True)
|
||
#
|
||
# # 示例1: 注册人脸
|
||
# print("=== 注册人脸 ===")
|
||
# # 假设你有一些人脸图像放在 test_images 目录下
|
||
# test_images = {
|
||
# "person_001": "test_images/person1.png",
|
||
# "person_002": "test_images/person2.jpg",
|
||
# # "person_003": "test_images/person3.jpg"
|
||
# }
|
||
#
|
||
# for person_id, img_path in test_images.items():
|
||
# if os.path.exists(img_path):
|
||
# face_system.register_face(img_path, person_id)
|
||
|
||
# 示例2: 1v1比对
|
||
print("\n=== 1v1人脸比对 ===")
|
||
img1 = "test_data/register/person1.png"
|
||
# img1 = "test_data/query/file___media_Photo_103_IMG_1737872809_085_IMG_20250126_142509.jpg"
|
||
img2 = "test_data/query/file___media_Photo_152_IMG_1737876072_134_IMG_20250126_151932.jpg"
|
||
|
||
|
||
if os.path.exists(img1) and os.path.exists(img2):
|
||
similarity, is_same = face_system.one_vs_one(img1, img2)
|
||
print(f"相似度: {similarity:.4f}, 是否同一人: {is_same}")
|
||
|
||
# # 示例3: 1vn检索
|
||
# print("\n=== 1vn人脸检索 ===")
|
||
# query_img = "test_images/query.jpg"
|
||
# if os.path.exists(query_img):
|
||
# results = face_system.one_vs_many(query_img)
|
||
# print("检索结果:")
|
||
# for person_id, score in results:
|
||
# print(f" {person_id}: {score:.4f}")
|
||
#
|
||
# # 示例4: 可视化检测结果
|
||
# print("\n=== 人脸检测可视化 ===")
|
||
# test_img = "test_images/test.jpg"
|
||
# if os.path.exists(test_img):
|
||
# output_img = face_system.visualize_detection(test_img, "output/detection_result.jpg")
|
||
# print("检测结果已保存到 output/detection_result.jpg")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main() |