# Face feature extraction and same-person photo grouping tool (CPU/GPU/NPU).
import argparse
import datetime
import os

import cv2
import insightface
import numpy as np
from insightface.app import FaceAnalysis

# Device -> (onnxruntime providers list, ctx_id) mapping.
# The NPU entry carries the full CANNExecutionProvider option dict supplied by
# the user; main() overwrites "device_id" and "npu_mem_limit" in this dict
# from the CLI arguments before model initialization.
DEVICE_CONFIG = {
    "cpu": (['CPUExecutionProvider'], -1),
    "gpu": (['CUDAExecutionProvider'], 0),
    "npu": (
        [
            (
                "CANNExecutionProvider",
                {
                    # Target Ascend device index (overridden by --npu-device-id).
                    "device_id": 1,
                    "arena_extend_strategy": "kNextPowerOfTwo",
                    # Memory cap in bytes (16 GiB here; overridden by --npu-mem-limit).
                    "npu_mem_limit": 16*1024*1024*1024,
                    "op_select_impl_mode": "high_precision",
                    "precision_mode": "allow_fp32_to_fp16",
                    "enable_cann_graph": True,
                },
            ),
            # CPU fallback provider when CANN cannot serve an op.
            "CPUExecutionProvider"
        ],
        0
    )
}
# allow_fp32_to_fp16
|
||
#核心配置参数
|
||
THRESHOLD = 0.65
|
||
IMAGE_EXTENSIONS = ('.jpg','.jepg','.png','.bmp','.gif')
|
||
NPU_REQUIREMENTS = {
|
||
"依赖包": "onnxruntime-cann(华为官方)+ onnxruntime(基础)",
|
||
"驱动要求": "Ascend CANN Toolkit ≥ 5.0.3",
|
||
"硬件要求": "华为昇腾芯片(如Ascend 310/910)",
|
||
"文档链接": "https://onnxruntime.ai/docs/execution-providers/community-maintained/CANN-ExecutionProvider.html"
|
||
}
|
||
|
||
# Initialize face analysis model
|
||
#app = FaceAnalysis(name='buffalo_l', providers=['CPUExecutionProvider']) # Use 'CUDAExecutionProvider' for GPU
|
||
#app.prepare(ctx_id=-1) # ctx_id=-1 for CPU, 0 for GPU
|
||
|
||
def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Optional list of argument strings; defaults to ``sys.argv[1:]``
            (backward compatible — existing ``parse_args()`` callers work).

    Returns:
        argparse.Namespace with device, threshold, npu_device_id and
        npu_mem_limit attributes.
    """
    parser = argparse.ArgumentParser(description="人脸特征提取与同一人照片分组工具(支持CPU/GPU/NPU)")
    parser.add_argument("-d", "--device", type=str, choices=DEVICE_CONFIG.keys(),
                        default="cpu", help="指定运行设备")
    parser.add_argument("-t", "--threshold", type=float, default=THRESHOLD,
                        help="相似度阈值,值越大匹配越严格")
    # BUG FIX: the defaults were the *strings* "0" and "16". argparse applies
    # type= only to values coming from the command line, never to defaults, so
    # main()'s `args.npu_mem_limit*1024*1024*1024` performed string repetition
    # (building a multi-GB string) whenever the flag was omitted.
    parser.add_argument("--npu-device-id", type=int, default=0,
                        help="覆盖NPU设备ID")
    parser.add_argument("--npu-mem-limit", type=int, default=16,
                        help="NPU内存限制")
    return parser.parse_args(argv)
||
|
||
def get_face_embedding(image_path, app):
    """Return the embedding vector of the first face detected in an image.

    Args:
        image_path: Path to the image file, passed to cv2.imread.
        app: Prepared insightface FaceAnalysis instance.

    Raises:
        ValueError: If the image cannot be read or no face is detected.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not read image: {image_path}")

    detected = app.get(image)
    if not detected:
        raise ValueError("No faces detected in the image")
    if len(detected) > 1:
        # Multiple faces: keep the first one and warn the operator.
        print("Warning: Multiple faces detected. Using first detected face")

    return detected[0].embedding
|
||
|
||
def compare_faces(emb1, emb2, threshold):
    """Compare two embeddings via cosine similarity.

    Args:
        emb1, emb2: Embedding vectors (anything np.dot/np.linalg.norm accept).
        threshold: Similarity cut-off; tune per use case.

    Returns:
        (similarity, is_same): the cosine similarity and whether it
        strictly exceeds *threshold*.
    """
    denominator = np.linalg.norm(emb1) * np.linalg.norm(emb2)
    similarity = np.dot(emb1, emb2) / denominator
    return similarity, similarity > threshold
|
||
|
||
def get_all_images_files(directory=".", extensions=None):
    """Collect all supported image files directly inside *directory*.

    Args:
        directory: Directory to scan (non-recursive). Defaults to CWD.
        extensions: Optional tuple of lower-case suffixes; defaults to the
            module-level IMAGE_EXTENSIONS (backward compatible).

    Returns:
        Sorted list of absolute paths to the matching regular files.
    """
    if extensions is None:
        extensions = IMAGE_EXTENSIONS
    image_files = []
    for filename in os.listdir(directory):
        # BUG FIX: the original used os.path.abspath(filename), which resolves
        # the bare name against the *current working directory*, so any
        # directory other than "." produced wrong (usually nonexistent) paths.
        file_path = os.path.abspath(os.path.join(directory, filename))
        if filename.lower().endswith(tuple(extensions)) and os.path.isfile(file_path):
            image_files.append(file_path)

    return sorted(image_files)
|
||
|
||
def find_same_person_groups(embedding_dict, threshold):
    """Greedily cluster photos whose face embeddings match the same person.

    Each pass takes the first ungrouped photo as an anchor and pulls in every
    remaining photo whose similarity to the anchor passes *threshold*.

    Args:
        embedding_dict: Mapping of filename -> embedding vector.
        threshold: Cosine-similarity cut-off forwarded to compare_faces.

    Returns:
        (multi_groups, single_groups): lists of filename lists; multi_groups
        holds clusters with >= 2 photos, single_groups holds singletons.
    """
    pending = list(embedding_dict.keys())
    multi_groups = []
    single_groups = []

    while pending:
        anchor = pending.pop(0)
        anchor_emb = embedding_dict[anchor]
        group = [anchor]
        matched = []

        for candidate in pending:
            similarity, is_same = compare_faces(anchor_emb, embedding_dict[candidate], threshold)
            if is_same:
                group.append(candidate)
                matched.append(candidate)
                print(f"匹配成功:{os.path.basename(anchor)}和{os.path.basename(candidate)}(相似度:{similarity:.4f})")

        # Remove matched photos only after the scan to avoid mutating the
        # list while iterating over it.
        for name in matched:
            pending.remove(name)

        (multi_groups if len(group) >= 2 else single_groups).append(group)

    return multi_groups, single_groups
|
||
|
||
def generate_log_filename(device):
    """Build a log file name tagged with the device type and a timestamp."""
    stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    return f"face_embeddings_{device}_{stamp}.log"
|
||
def write_embedding_log(log_entries, args, npu_config, log_filename):
    """Write the detailed embedding log file (UTF-8) and print its location.

    Args:
        log_entries: List of pre-formatted per-image log strings.
        args: Parsed CLI namespace (reads .device and .threshold).
        npu_config: Provider option dict, included only when device is "npu".
        log_filename: Destination path for the log file.
    """
    bar = "=" * 80
    now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    header = [
        bar + "\n",
        "人脸特征向量分析日志\n",
        f"生成时间: {now}\n",
        f"运行设备:{args.device}\n",
    ]
    if args.device == "npu":
        header.append(f"NPU配置:{npu_config}\n")
    header += [
        f"相似度阈值:{args.threshold}\n",
        f"支持图片格式:{','.join(IMAGE_EXTENSIONS)}\n",
        f"处理图片总数:{len(log_entries)}\n",
        bar + "\n\n",
    ]

    with open(log_filename, 'w', encoding='utf-8') as f:
        f.writelines(header)
        for entry in log_entries:
            f.write(entry + "\n" + "-" * 60 + "\n\n")

    print(f"\n 日志文件已保存至:{os.path.abspath(log_filename)}")
|
||
|
||
def check_npu_environment():
    """Verify the NPU core dependencies (does not validate env vars).

    Checks only that onnxruntime is importable and exposes the
    CANNExecutionProvider.

    Returns:
        (ok, message): ok is True when the core dependencies are present;
        message is a human-readable status string.
    """
    try:
        import onnxruntime as ort
        providers = ort.get_available_providers()
        if 'CANNExecutionProvider' in providers:
            return True, "NPU核心依赖检查通过"
        return False, "未检测到CANNExecutionProvider(请安装onnxruntime-cann包)"
    except ImportError:
        return False, "未安装onnxruntime(基础依赖)"
    except Exception as e:
        return False, f"NPU环境检查失败: {str(e)}"
|
||
|
||
def main():
    """CLI entry point: embed every image in the current directory, write a
    detailed log, and group photos that belong to the same person.

    Workflow: parse args -> pre-check device environment -> initialize the
    insightface model -> extract embeddings -> write log -> cluster & report.
    """
    args = parse_args()
    device = args.device
    threshold = args.threshold
    npu_device_id = args.npu_device_id
    # NOTE(review): argparse does not apply type= to defaults, and the
    # --npu-mem-limit default is declared as the string "16"; when the flag is
    # omitted this multiplication performs *string repetition* — confirm the
    # default is an int.
    npu_mem_limit = args.npu_mem_limit*1024*1024*1024
    # Build a log filename tagged with device type and timestamp.
    log_filename = generate_log_filename(device)

    # Print the startup banner with the effective configuration.
    print("人脸特征提取与同一人照片分组工具(支持CPU/GPU/NPU)")
    print("="*80)
    print(f"核心配置:")
    print(f"运行设备:{device}")
    print(f"相似度阈值:{threshold}")
    print(f"搜索目录:{os.getcwd()}")
    print(f"支持格式:{','.join(IMAGE_EXTENSIONS)}")
    print(f"日志文件:{log_filename}")
    if device == "npu":
        print(f"NPU设备ID:{npu_device_id}")
        print(f"NPU内存限制:{args.npu_mem_limit}GB")
    print("="*80 + "\n")

    # Device environment pre-check (GPU: warn only; NPU: abort on failure).
    if device == "gpu":
        # NOTE(review): insightface.utils.get_available_providers may not
        # exist in current insightface releases; the usual API is
        # onnxruntime.get_available_providers() — verify before relying on it.
        available_providers = insightface.utils.get_available_providers()
        if 'CUDAExecutionProvider' not in available_providers:
            print(" 警告:未检测到CUDA环境,GPU模式可能运行失败!")
            print(" 解决方案:1.安装CUDA≥11.0 + cuDNN 2.安装onnxruntime-gpu 3. 切换至cpu模式")
    elif device == "npu":
        npu_ok,npu_msg = check_npu_environment()
        print(f" NPU环境检查:{npu_msg}")
        if not npu_ok:
            print(" NPU环境不满足要求,建议按以下步骤配置:")
            for key, value in NPU_REQUIREMENTS.items():
                print(f" - {key}:{value}")
            return

    # Initialize the face-analysis model for the selected device.
    try:
        providers, ctx_id = DEVICE_CONFIG[device]
        npu_config = None

        if device == "npu":
            # Override device id / memory limit from the CLI arguments.
            npu_provider = list(providers[0])
            npu_provider[1]["device_id"] = npu_device_id
            npu_provider[1]["npu_mem_limit"] = npu_mem_limit
            # NOTE(review): this writes back into the module-level
            # DEVICE_CONFIG list, mutating shared state in place.
            providers[0] = tuple(npu_provider)
            npu_config = providers[0][1]
            ctx_id = npu_device_id
            app = FaceAnalysis(name='buffalo_l', providers=providers)
        else:
            app = FaceAnalysis(name='buffalo_l', providers=providers)

        app.prepare(ctx_id=ctx_id)
        print(f" 模型初始化成功(设备:{device}, ctx_id: {ctx_id})")
        if device == "npu":
            print(f" NPU最终配置:")
            for key, value in npu_config.items():
                if key == "npu_mem_limit":
                    # Convert bytes back to GB for display.
                    print(f" - {key}: {value/(1024*1024*1024)}GB")
                else:
                    print(f" - {key}: {value}")
    except Exception as e:
        print(f"模型初始化失败!")
        print(f"Error:{str(e)}")
        print(f" 解决方案:")
        if device == "gpu":
            print(" 1.确认CUDA驱动已安装 2.确认onnxruntime-gpu版本与CUDA匹配 3.尝试切换CPU模式")
        elif device == "npu":
            print(" 1.确认CANN Toolkit已正确安装 2.确认onnxruntime-cann版本兼容 3.检查设备ID和内存限制是否合理")
        return

    # 1. Collect all image files from the current directory.
    image_files = get_all_images_files()
    if not image_files:
        print("未找到任何支持的图片文件(检查目录下是否有jpg/png等格式图片)")
        return
    print(f" 找到{len(image_files)}个图片文件,开始提取特征向量...\n")

    # 2. Extract an embedding per image and build the log entries.
    embedding_dict = {}
    log_entries = []
    for img_path in image_files:
        img_name = os.path.basename(img_path)
        try:
            emb = get_face_embedding(img_path, app)
            embedding_dict[img_name] = emb
            log_entry = f"【文件】:{img_name}\n" \
                        f"【路径】:{img_path}\n" \
                        f"【状态】:成功\n" \
                        f"【特征向量维度】:{len(emb)}\n" \
                        f"【特征向量】:{emb.tolist()}"
            log_entries.append(log_entry)
            print(f" 处理成功: {img_name}(特征向量维度:{len(emb)})")
        except Exception as e:
            # NOTE(review): the leading ":" on the error-info line below looks
            # like a typo, but it is runtime output and left unchanged here.
            log_entry = f"【文件】:{img_name}\n" \
                        f"【路径】:{img_path}\n" \
                        f"【状态】:失败\n" \
                        f":【错误信息】:{str(e)}"
            log_entries.append(log_entry)
            print(f" 处理失败: {img_name} - 原因:{str(e)}")

    # 3. Write the log file (with the dynamically generated filename).
    write_embedding_log(log_entries, args, npu_config, log_filename)

    # 4. Same-person grouping analysis.
    if not embedding_dict:
        print("\n 没有成功提取到人脸特征向量,无法进行分组分析")
        return

    print(f"\n 开始分组分析(有效人脸数:{len(embedding_dict)})...")
    same_groups, single_groups = find_same_person_groups(embedding_dict, threshold)

    # 5. Print the grouping results.
    print("\n" + "="*80)
    print("同一人照片分组结果(每组≥2张):")
    print("="*80)
    if same_groups:
        for i, group in enumerate(same_groups, 1):
            group_names = [os.path.basename(file) for file in group]
            print(f"组{i}:{', '.join(group_names)}")
    else:
        print(f" 未发现同一人的多张照片")

    print("\n" + "="*80)
    print("无匹配的单独照片:")
    #print("\n"*80)
    if single_groups:
        for group in single_groups:
            print(f" - {os.path.basename(group[0])}")
    else:
        print(f"所有照片均已分组(无单独照片)")
    print("\n" + "="*80)
    print(f" 处理完成!详细日志请查看:{log_filename}")
|
||
|
||
# Script entry point.
if __name__ == "__main__":
    main()
|
||
# Paths to your Indian face images
|
||
#image1_path = "path/to/face1.jpg"
|
||
#image2_path = "path/to/face2.jpg"
|
||
|
||
#try:
|
||
# Get embeddings
|
||
# emb1 = get_face_embedding(image1_path)
|
||
# emb2 = get_face_embedding(image2_path)
|
||
|
||
# Compare faces
|
||
# similarity_score, is_same_person = compare_faces(emb1, emb2)
|
||
|
||
# print(f"Similarity Score: {similarity_score:.4f}")
|
||
# print(f"Same person? {'YES' if is_same_person else 'NO'}")
|
||
|
||
#except Exception as e:
|
||
# print(f"Error: {str(e)}")
|