From 35a6eb819c7d73a339f1effa9a2f0e55d0972080 Mon Sep 17 00:00:00 2001 From: Shivam Singhal Date: Thu, 5 Sep 2024 12:19:20 +0530 Subject: [PATCH 1/5] feat: added batching support in rescaling --- models/utils/detect_face.py | 54 ++++++++++++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 6 deletions(-) diff --git a/models/utils/detect_face.py b/models/utils/detect_face.py index 5d144864..4a9313cc 100644 --- a/models/utils/detect_face.py +++ b/models/utils/detect_face.py @@ -6,6 +6,7 @@ import numpy as np import os import math +from collections import defaultdict # OpenCV is optional, but required if using numpy arrays instead of PIL try: @@ -106,11 +107,14 @@ def detect_face(imgs, minsize, pnet, rnet, onet, threshold, factor, device): # Second stage if len(boxes) > 0: im_data = [] + sizes = [] for k in range(len(y)): if ey[k] > (y[k] - 1) and ex[k] > (x[k] - 1): - img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]].unsqueeze(0) - im_data.append(imresample(img_k, (24, 24))) - im_data = torch.cat(im_data, dim=0) + img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]] + im_data.append(img_k) + sizes.append((ey[k] - y[k] + 1, ex[k] - x[k] + 1)) + + im_data = batch_resample_by_size(im_data, sizes, (24, 24), device) im_data = (im_data - 127.5) * 0.0078125 # This is equivalent to out = rnet(im_data) to avoid GPU out of memory. @@ -135,11 +139,14 @@ def detect_face(imgs, minsize, pnet, rnet, onet, threshold, factor, device): if len(boxes) > 0: y, ey, x, ex = pad(boxes, w, h) im_data = [] + sizes = [] for k in range(len(y)): if ey[k] > (y[k] - 1) and ex[k] > (x[k] - 1): - img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]].unsqueeze(0) - im_data.append(imresample(img_k, (48, 48))) - im_data = torch.cat(im_data, dim=0) + img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]] + im_data.append(img_k) + sizes.append((ey[k] - y[k] + 1, ex[k] - x[k] + 1)) + + im_data = batch_resample_by_size(im_data, sizes, (48, 48), device) im_data = (im_data - 127.5) * 0.0078125 # This is equivalent to out = onet(im_data) to avoid GPU out of memory. @@ -306,6 +313,41 @@ def imresample(img, sz): return im_data +def batch_resample_by_size(imgs, sizes, target_size, device): + """ + Batch resampling function grouping by size while preserving order. + + Args: + imgs (list of torch.Tensor): List of image tensors + sizes (list of tuple): List of original sizes (height, width) + target_size (tuple): Target size for resampling (height, width) + device (torch.device): Device to perform computation on + + Returns: + torch.Tensor: Batch of resampled images in original order + """ + # Group images by size + size_groups = defaultdict(list) + size_to_indices = defaultdict(list) + for i, (img, size) in enumerate(zip(imgs, sizes)): + size_groups[size].append(img) + size_to_indices[size].append(i) + + resampled_imgs = torch.zeros(len(imgs), 3, target_size[0], target_size[1], device=device) + for size, group in size_groups.items(): + # Stack images of the same size + batch = torch.stack(group).to(device) + + # Perform batch resample + resampled = interpolate(batch, size=target_size, mode='area') + + # Put resampled images back in their original positions + for resampled_img, original_idx in zip(resampled, size_to_indices[size]): + resampled_imgs[original_idx] = resampled_img + + return resampled_imgs + + def crop_resize(img, box, image_size): if isinstance(img, np.ndarray): img = img[box[1]:box[3], box[0]:box[2]] From e8ee2e7fb5ce79f46a57f0d8517f39008f6f5ed2 Mon Sep 17 00:00:00 2001 From: Shivam Singhal Date: Thu, 5 Sep 2024 12:24:42 +0530 Subject: [PATCH 2/5] version update --- setup.py | 6 +++--- tests/actions_requirements.txt | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 1864989b..51f5fbc5 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ import setuptools, os PACKAGE_NAME = 'facenet-pytorch' -VERSION = '2.5.2' +VERSION = '2.5.2.dev0' AUTHOR = 'Tim Esler' EMAIL = 'tim.esler@gmail.com' DESCRIPTION = 'Pretrained Pytorch face detection and recognition models' @@ -39,8 +39,8 @@ 'numpy>=1.24.0,<2.0.0', 'Pillow>=10.2.0,<10.3.0', 'requests>=2.0.0,<3.0.0', - 'torch>=2.2.0,<=2.3.0', - 'torchvision>=0.17.0,<=0.18.0', + 'torch>=2.2.0,<=2.4.0', + 'torchvision>=0.17.0,<=0.19.0', 'tqdm>=4.0.0,<5.0.0', ], ) diff --git a/tests/actions_requirements.txt b/tests/actions_requirements.txt index b74924a5..4e66bdb3 100644 --- a/tests/actions_requirements.txt +++ b/tests/actions_requirements.txt @@ -1,7 +1,7 @@ numpy>=1.24.0,<2.0.0 requests>=2.0.0,<3.0.0 -torch>=2.2.0,<2.3.0 -torchvision>=0.17.0,<0.18.0 +torch>=2.2.0,<=2.4.0 +torchvision>=0.17.0,<=0.19.0 Pillow>=10.2.0,<10.3.0 opencv-python>=4.9.0 scipy>=1.10.0,<2.0.0 From 6dddfa5ed107d5366b4cfbb3e2dfcf8a3020fd2c Mon Sep 17 00:00:00 2001 From: Shivam Singhal Date: Thu, 5 Sep 2024 14:32:18 +0530 Subject: [PATCH 3/5] version update --- models/utils/detect_face.py | 19 +++++++++---------- setup.py | 12 ++++++------ 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/models/utils/detect_face.py b/models/utils/detect_face.py index 4a9313cc..6f52ced3 100644 --- a/models/utils/detect_face.py +++ b/models/utils/detect_face.py @@ -22,7 +22,6 @@ def fixed_batch_process(im_data, model): out.append(model(batch)) return tuple(torch.cat(v, dim=0) for v in zip(*out)) - def detect_face(imgs, minsize, pnet, rnet, onet, threshold, factor, device): if isinstance(imgs, (np.ndarray, torch.Tensor)): if isinstance(imgs,np.ndarray): @@ -107,14 +106,12 @@ def detect_face(imgs, minsize, pnet, rnet, onet, threshold, factor, device): # Second stage if len(boxes) > 0: im_data = [] - sizes = [] for k in range(len(y)): if ey[k] > (y[k] - 1) and ex[k] > (x[k] - 1): img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]] im_data.append(img_k) - sizes.append((ey[k] - y[k] + 1, ex[k] - x[k] + 1)) - im_data = batch_resample_by_size(im_data, sizes, (24, 24), device) + im_data = batch_resample_by_size(im_data, (24, 24), device) im_data = (im_data - 127.5) * 0.0078125 # This is equivalent to out = rnet(im_data) to avoid GPU out of memory. @@ -139,14 +136,12 @@ def detect_face(imgs, minsize, pnet, rnet, onet, threshold, factor, device): if len(boxes) > 0: y, ey, x, ex = pad(boxes, w, h) im_data = [] - sizes = [] for k in range(len(y)): if ey[k] > (y[k] - 1) and ex[k] > (x[k] - 1): img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]] im_data.append(img_k) - sizes.append((ey[k] - y[k] + 1, ex[k] - x[k] + 1)) - im_data = batch_resample_by_size(im_data, sizes, (48, 48), device) + im_data = batch_resample_by_size(im_data, (48, 48), device) im_data = (im_data - 127.5) * 0.0078125 # This is equivalent to out = onet(im_data) to avoid GPU out of memory. @@ -313,27 +308,31 @@ def imresample(img, sz): return im_data -def batch_resample_by_size(imgs, sizes, target_size, device): +def batch_resample_by_size(imgs, target_size, device): """ Batch resampling function grouping by size while preserving order. Args: imgs (list of torch.Tensor): List of image tensors - sizes (list of tuple): List of original sizes (height, width) target_size (tuple): Target size for resampling (height, width) device (torch.device): Device to perform computation on Returns: torch.Tensor: Batch of resampled images in original order """ + if not imgs: + return torch.zeros((0, 3, target_size[0], target_size[1]), device=device) + # Group images by size size_groups = defaultdict(list) size_to_indices = defaultdict(list) - for i, (img, size) in enumerate(zip(imgs, sizes)): + for i, img in enumerate(imgs): + size = tuple(img.shape[1:]) size_groups[size].append(img) size_to_indices[size].append(i) resampled_imgs = torch.zeros(len(imgs), 3, target_size[0], target_size[1], device=device) + for size, group in size_groups.items(): # Stack images of the same size batch = torch.stack(group).to(device) diff --git a/setup.py b/setup.py index 51f5fbc5..b76999fe 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,11 @@ import setuptools, os -PACKAGE_NAME = 'facenet-pytorch' -VERSION = '2.5.2.dev0' -AUTHOR = 'Tim Esler' -EMAIL = 'tim.esler@gmail.com' -DESCRIPTION = 'Pretrained Pytorch face detection and recognition models' -GITHUB_URL = 'https://github.com/timesler/facenet-pytorch' +PACKAGE_NAME = 'facenet-pytorch-custom' +VERSION = '2.5.2.dev1' +AUTHOR = 'Shivam Singhal' +EMAIL = 'shivamsinghal1012@gmail.com' +DESCRIPTION = 'Pretrained Pytorch face detection and recognition models original - https://github.com/timesler/facenet-pytorch' +GITHUB_URL = 'https://github.com/ShivamSinghal1/facenet-pytorch' parent_dir = os.path.dirname(os.path.realpath(__file__)) import_name = os.path.basename(parent_dir) From df9c0e52f27d389c0dc51ac8102cb446eed353fb Mon Sep 17 00:00:00 2001 From: Shivam Singhal Date: Thu, 5 Sep 2024 16:05:36 +0530 Subject: [PATCH 4/5] version update --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index b76999fe..6064a5a5 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ import setuptools, os PACKAGE_NAME = 'facenet-pytorch-custom' -VERSION = '2.5.2.dev1' +VERSION = '2.5.2.dev2' AUTHOR = 'Shivam Singhal' EMAIL = 'shivamsinghal1012@gmail.com' DESCRIPTION = 'Pretrained Pytorch face detection and recognition models original - https://github.com/timesler/facenet-pytorch' From a099accf1c08f02125b0f3f439107b18cda0d8e0 Mon Sep 17 00:00:00 2001 From: Shivam Singhal Date: Thu, 5 Sep 2024 20:07:31 +0530 Subject: [PATCH 5/5] minor refactoring --- models/utils/detect_face.py | 15 ++++++++------- setup.py | 12 ++++++------ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/models/utils/detect_face.py b/models/utils/detect_face.py index 6f52ced3..6c0c794f 100644 --- a/models/utils/detect_face.py +++ b/models/utils/detect_face.py @@ -22,6 +22,7 @@ def fixed_batch_process(im_data, model): out.append(model(batch)) return tuple(torch.cat(v, dim=0) for v in zip(*out)) + def detect_face(imgs, minsize, pnet, rnet, onet, threshold, factor, device): if isinstance(imgs, (np.ndarray, torch.Tensor)): if isinstance(imgs,np.ndarray): @@ -310,15 +311,15 @@ def imresample(img, sz): def batch_resample_by_size(imgs, target_size, device): """ - Batch resampling function grouping by size while preserving order. + Batch resampling function grouping by size while preserving order. - Args: - imgs (list of torch.Tensor): List of image tensors - target_size (tuple): Target size for resampling (height, width) - device (torch.device): Device to perform computation on + Args: + imgs (list of torch.Tensor): List of image tensors + target_size (tuple): Target size for resampling (height, width) + device (torch.device): Device to perform computation on - Returns: - torch.Tensor: Batch of resampled images in original order + Returns: + torch.Tensor: Batch of resampled images in original order """ if not imgs: return torch.zeros((0, 3, target_size[0], target_size[1]), device=device) diff --git a/setup.py b/setup.py index 6064a5a5..72d3692d 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,11 @@ import setuptools, os -PACKAGE_NAME = 'facenet-pytorch-custom' -VERSION = '2.5.2.dev2' -AUTHOR = 'Shivam Singhal' -EMAIL = 'shivamsinghal1012@gmail.com' -DESCRIPTION = 'Pretrained Pytorch face detection and recognition models original - https://github.com/timesler/facenet-pytorch' -GITHUB_URL = 'https://github.com/ShivamSinghal1/facenet-pytorch' +PACKAGE_NAME = 'facenet-pytorch' +VERSION = '2.5.4' +AUTHOR = 'Tim Esler' +EMAIL = 'tim.esler@gmail.com' +DESCRIPTION = 'Pretrained Pytorch face detection and recognition models' +GITHUB_URL = 'https://github.com/timesler/facenet-pytorch' parent_dir = os.path.dirname(os.path.realpath(__file__)) import_name = os.path.basename(parent_dir)