實現功能
讀取圖像數據集,經預訓練模型提取特徵後,將特徵向量保存為二進位格式的 embeddings.bin,並另外轉換為 C 標頭檔 embeddings.h,方便部署測試。
功能的整體框架
"""Script to generate Face Id embeddings"""import argparseimport os.path as pathfrom mtcnn.mtcnn import MTCNNfrom ai85.ai85_adapter import AI85SimulatorAdapterfrom utils import append_db_file_from_path, save_embedding_db, create_embeddings_include_file
# Absolute directory of this script; the model checkpoint and the generated
# files are located relative to it.
CURRENT_DIR = path.dirname(path.abspath(__file__))

# Checkpoint file handed to the AI85 simulator adapter.
MODEL_PATH = path.join(
    CURRENT_DIR,
    'model',
    'ai85-streaming_seqfaceid_nobias_x6.pth.tar',
)
def create_db_from_folder(args):
    """
    Build the embedding database from a folder of face images.

    Creates the MTCNN face detector and the AI85 simulator adapter, runs
    `append_db_file_from_path` over `args.db` and, when at least one embedding
    was produced, writes `<args.db_filename>.bin` next to this script and the
    matching `.h` file into `args.include_path`. When the resulting database is
    empty, a message is printed and nothing is written.

    Args:
        args: parsed command line options; the attributes `db`, `db_filename`
            and `include_path` are read.
    """
    face_detector = MTCNN(image_size=80, margin=0, min_face_size=60,
                          thresholds=[0.6, 0.8, 0.92], factor=0.85,
                          post_process=True, device='cpu')
    ai85_adapter = AI85SimulatorAdapter(MODEL_PATH)

    embedding_db, _ = append_db_file_from_path(args.db, face_detector, ai85_adapter,
                                               db_dict=None, verbose=True)
    if not embedding_db:
        # One f-string instead of two print() arguments: the old form emitted a
        # double space before the folder name.
        print('Cannot create a DB file. No face could be detected from the images in folder '
              f'`{args.db}`')
        return

    # The .h file is generated from the .bin file, so the .bin must be written first.
    save_embedding_db(embedding_db, path.join(CURRENT_DIR, args.db_filename + '.bin'),
                      add_prev_imgs=True)
    create_embeddings_include_file(CURRENT_DIR, args.db_filename, args.include_path)
def parse_arguments():
    """
    Parse the command line options of the script.

    Returns:
        The argparse namespace with `db`, `db_filename` and `include_path`.
    """
    parser = argparse.ArgumentParser(description='Create embedding database file.')

    # (option strings, default value, help text)
    option_specs = (
        (('--db', '-db-path'), 'db', 'path for face images'),
        (('--db-filename',), 'embeddings', 'filename to store embeddings'),
        (('--include-path',), 'embeddings', 'path to include folder'),
    )
    for flags, default_value, help_text in option_specs:
        parser.add_argument(*flags, type=str, default=default_value, help=help_text)

    return parser.parse_args()
def main():
    """
    Script entry point: parse the command line options and generate the
    embedding files from them.
    """
    create_db_from_folder(parse_arguments())
if __name__ == "__main__":
    main()

其中,圖像讀入模型處理函數 append_db_file_from_path 如下:
def _collect_subject_embeddings(folder_path, mtcnn, ai85_adapter):
    """Return ``{subject: {'Embedding_<n>': {'emb', 'img'}}}`` for the images under *folder_path*."""
    db_dict = defaultdict(dict)
    for subject in sorted(os.listdir(folder_path)):
        print(f'Processing subject: {subject}')
        subject_path = os.path.join(folder_path, subject)
        if not os.path.isdir(subject_path):
            continue

        img_id = 0
        # Sorted so that the Embedding_<n> numbering does not depend on the
        # directory order reported by the OS.
        for file in sorted(os.listdir(subject_path)):
            print(f'\tFile: {file}')
            img_path = os.path.join(subject_path, file)
            img_rot = get_image_rotation(img_path)
            img = rotate_image(imread(img_path), img_rot)
            if img.dtype == np.float32:
                # NOTE(review): presumably float images are in [0, 1] - confirm.
                img = (255 * img).astype(np.uint8)

            img = get_face_image(img, mtcnn)
            # Skip images without a usable face crop of the expected 160x120x3 shape.
            if img is None or img.shape != (160, 120, 3):
                continue

            img_id += 1
            img = equalize_hist(img, 99)
            current_embedding = ai85_adapter.get_network_out(img)[:, :, 0, 0]
            current_embedding = current_embedding.astype(np.int8).flatten()
            db_dict[subject]['Embedding_%d' % img_id] = {'emb': current_embedding,
                                                         'img': img}
    return db_dict


def _build_preview(db_dict, max_photo):
    """Tile the stored face crops into one uint8 image, one row of 160x120 cells per subject."""
    preview = 125*np.ones((len(db_dict) * 160, max_photo * 120, 3))
    for row, entries in enumerate(db_dict.values()):
        start_y = row * 160
        start_x = 0
        for entry in entries.values():
            preview[start_y:start_y+160, start_x:start_x+120, :] = entry['img']
            start_x += 120
    return preview.astype(np.uint8)


def _next_embedding_id(entries):
    """Return one more than the highest numeric suffix among the keys of *entries*."""
    highest = 0
    for key in entries:
        if isinstance(key, int):
            number = key
        else:
            suffix = str(key).rsplit('_', 1)[-1]
            if not suffix.isdigit():
                continue
            number = int(suffix)
        highest = max(highest, number)
    return highest + 1


def _merge_embedding_dbs(existing_db_dict, new_db_dict):
    """Return a deep copy of *existing_db_dict* extended with the entries of *new_db_dict*."""
    merged = copy.deepcopy(existing_db_dict)
    for subj, entries in new_db_dict.items():
        if subj not in merged:
            merged[subj] = entries
            continue
        # Known subject: renumber the new entries after the ones already stored
        # so that no existing embedding is overwritten.
        img_id = _next_embedding_id(merged[subj])
        for entry in entries.values():
            merged[subj]['Embedding_%d' % img_id] = entry
            img_id += 1
    return merged


def append_db_file_from_path(folder_path, mtcnn, ai85_adapter, db_dict=None, verbose=True,
                             preview_images=False):
    """Create embeddings for each image in the given folder and append them to the existing
    embedding dictionary.

    Every sub-directory of *folder_path* is treated as one subject and every file inside it
    as one image of that subject. An image contributes an entry only when ``get_face_image``
    returns a crop of shape ``(160, 120, 3)``.

    Args:
        folder_path: directory holding one sub-directory per subject.
        mtcnn: face detector handed to ``get_face_image``.
        ai85_adapter: object whose ``get_network_out(img)`` result is sliced with
            ``[:, :, 0, 0]``, cast to int8 and flattened to form the embedding.
        db_dict: optional existing database to extend; it is deep-copied, not modified.
            NOTE(review): assumed to use the same ``{subject: {'Embedding_<n>': ...}}``
            layout that this function produces - confirm against callers.
        verbose: print a per-subject summary and, when a preview is built, show it with
            matplotlib.
        preview_images: build an image that tiles all accepted face crops.

    Returns:
        ``(db, preview)`` where ``db`` is the merged database when a non-empty *db_dict*
        was given and otherwise only the newly created entries, and ``preview`` is the
        uint8 preview image or ``None`` when *preview_images* is false.
    """
    existing_db_dict = db_dict
    db_dict = _collect_subject_embeddings(folder_path, mtcnn, ai85_adapter)

    if verbose:
        print('New entries for the DB' if existing_db_dict else 'A new DB with')
        for subj, entries in db_dict.items():
            print(f'\t{subj}: {len(entries)} images')
        print('have been appended!' if existing_db_dict else 'has been created!')

    # Widest row of the preview grid: the largest number of images of any subject.
    max_photo = max((len(entries) for entries in db_dict.values()), default=0)

    preview = None
    if preview_images:
        preview = _build_preview(db_dict, max_photo)
        if verbose:
            plt.figure(figsize=(1.5*max_photo, 2*len(db_dict)))
            plt.imshow(preview)
            plt.show()

    if existing_db_dict:
        db_dict = _merge_embedding_dbs(existing_db_dict, db_dict)

    return db_dict, preview


# With the preprocessed images and their embeddings extracted, the key/value
# pairs can be stored in a binary file as follows:
def save_embedding_db(emb_db, db_path, add_prev_imgs=False):
    """
    Write the embedding database to `db_path` in binary format.

    The file starts with a header describing the images and the embeddings;
    all 2-byte fields are big-endian. The data in order:
        1 byte : number of subjects (S)
        2 bytes: length of embeddings (L)
        2 bytes: number of embeddings (N)
        2 bytes: length of image width (W), fixed to 120
        2 bytes: length of image height (H), fixed to 160
        2 bytes: length of subject names (K)
        K bytes: subject names, separated by a single space
        (L+1)*N bytes: embeddings
            1 byte : subject id
            L bytes: embedding
        image bytes: raw bytes of every image returned by `create_data_arrs`,
            written only when `add_prev_imgs` is true
            (presumably (W*H*3) bytes per image - confirm against `create_data_arrs`)
    """
    subject_names, subject_arr, embedding_arr, img_arr = create_data_arrs(emb_db, add_prev_imgs)
    subject_ids = subject_arr.astype(np.uint8)
    embeddings = embedding_arr.astype(np.int8)

    names_bytes = bytearray(map(ord, ' '.join(subject_names)))

    num_subjects = np.unique(subject_ids).size
    num_embeddings, embedding_len = embeddings.shape
    img_width = 120
    img_height = 160

    # Header: one byte for S, followed by the 2-byte fields L, N, W, H, K.
    db_data = bytearray(np.uint8([num_subjects]))
    for field in (embedding_len, num_embeddings, img_width, img_height, len(names_bytes)):
        db_data.extend(field.to_bytes(2, 'big', signed=False))
    db_data.extend(names_bytes)

    # Each record is the subject id followed by the embedding itself.
    for row, embedding in enumerate(embeddings):
        db_data.extend(bytearray([subject_ids[row]]))
        db_data.extend(bytearray(embedding))

    if add_prev_imgs:
        for image in img_arr:
            db_data.extend(bytearray(image))

    with open(db_path, 'wb') as file:
        file.write(db_data)

    print(f'Binary embedding file is saved to "{db_path}".')


# The binary file is also converted to a .h file so that it can be deployed
# easily in the firmware project:
def create_embeddings_include_file(db_folder, db_filename, include_folder):
    """Convert a binary embedding file to a .h file that can be compiled as C code.

    Reads `<db_folder>/<db_filename>.bin` and writes `<include_folder>/<db_filename>.h`,
    which defines the macro `EMBEDDINGS` as a brace-enclosed list of hex byte values,
    18 values per line. Compared to the .bin file:
        * the header fields S, L, N, W, H, K are re-encoded as little-endian,
        * the space-separated subject names become NUL-terminated strings and K is
          updated to the new length,
        * only the (L+1)*N embedding bytes are copied; anything after them is ignored.

    Args:
        db_folder: folder that contains the .bin file.
        db_filename: file name without extension, shared by the .bin and the .h file.
        include_folder: destination folder of the .h file; created when missing.
    """
    db_path = os.path.join(db_folder, db_filename + '.bin')

    with open(db_path, 'rb') as file:
        S = int.from_bytes(file.read(1), byteorder='big', signed=False)
        L, N, W, H, K = (int.from_bytes(file.read(2), byteorder='big', signed=False)
                         for _ in range(5))

        subject_names_list = file.read(K).decode('ascii').split(' ')
        names_bytes = ('\0'.join(subject_names_list) + '\0').encode('ascii')

        # One bulk read instead of (L+1)*N single-byte reads.
        embedding_bytes = file.read((L + 1) * N)

    data_bin = bytearray(S.to_bytes(1, 'little', signed=False))
    for field in (L, N, W, H, len(names_bytes)):
        data_bin.extend(field.to_bytes(2, 'little', signed=False))
    data_bin.extend(names_bytes)
    data_bin.extend(embedding_bytes)

    # Slicing in steps of 18 never yields an empty final chunk, so the macro no
    # longer ends with a dangling ", \" line when the size is a multiple of 18.
    hex_values = [f'0x{value:02x}' for value in data_bin]
    data_arr = [','.join(hex_values[start:start + 18])
                for start in range(0, len(hex_values), 18)]
    data = ', \\\n '.join(data_arr)

    if include_folder:
        os.makedirs(include_folder, exist_ok=True)
    db_h_path = os.path.join(include_folder, db_filename + '.h')
    with open(db_h_path, 'w') as h_file:
        h_file.write('#define EMBEDDINGS { \\\n ')
        h_file.write(data)
        h_file.write(' \\\n}')

    print(f'Embedding file is saved to {db_h_path}')