diff --git a/python/smqtk/algorithms/descriptor_generator/pytorchdescriptor/__init__.py b/python/smqtk/algorithms/descriptor_generator/pytorchdescriptor/__init__.py
new file mode 100644
index 000000000..4f7f84119
--- /dev/null
+++ b/python/smqtk/algorithms/descriptor_generator/pytorchdescriptor/__init__.py
@@ -0,0 +1,2 @@
+from .pytorch_model_descriptors import PytorchModelDescriptor
+DESCRIPTOR_GENERATOR_CLASS = PytorchModelDescriptor
diff --git a/python/smqtk/algorithms/descriptor_generator/pytorchdescriptor/pytorch_model_descriptors.py b/python/smqtk/algorithms/descriptor_generator/pytorchdescriptor/pytorch_model_descriptors.py
new file mode 100644
index 000000000..b82b92f6e
--- /dev/null
+++ b/python/smqtk/algorithms/descriptor_generator/pytorchdescriptor/pytorch_model_descriptors.py
@@ -0,0 +1,440 @@
+from smqtk.algorithms.descriptor_generator import DescriptorGenerator, \
+    DFLT_DESCRIPTOR_FACTORY
+from smqtk.utils.cli import ProgressReporter
+
+from collections import deque
+import multiprocessing
+import multiprocessing.pool
+import six
+import logging
+
+try:
+    import torch
+    import torchvision
+    from torch.utils.data import DataLoader
+    from torch.autograd import Variable
+    from .utils import PytorchImagedataset
+except ImportError as ex:
+    logging.warning("Failed to import torch/torchvision "
+                    "module: %s", str(ex))
+    torch = None
+    torchvision = None
+
+__all__ = [
+    "PytorchModelDescriptor",
+]
+
+
+class PytorchModelDescriptor (DescriptorGenerator):
+    """
+    Compute images against a PyTorch model, extracting a layer as the content
+    descriptor.
+    """
+
+    @classmethod
+    def is_usable(cls):
+        """
+        :return: True when both torch and torchvision could be imported.
+        :rtype: bool
+        """
+        usable = torch is not None and torchvision is not None
+        if not usable:
+            cls.get_logger().debug("PyTorch or torchvision cannot be imported")
+        return usable
+
+    def truncate_pytorch_model(self, model, t1_model):
+        """
+        Truncate the sub-module named ``return_layer[0]`` so that it ends at
+        its child named ``return_layer[1]``.
+
+        :param model: The pytorch model that owns the sub-module.
+        :type model: torch.nn.Module
+        :param t1_model: The block of layers containing the final return
+            layer.
+        :type t1_model: torch.nn.Module
+
+        :raises KeyError: ``t1_model`` has no child named
+            ``return_layer[1]``.
+
+        :return: ``model``, modified in place so that the named sub-module
+            stops at the return layer.
+        :rtype: torch.nn.Module
+        """
+        child_names = [name for name, _ in t1_model.named_children()]
+        if self.return_layer[1] not in child_names:
+            raise KeyError(self.return_layer[1])
+        sub_pos = child_names.index(self.return_layer[1])
+        # Slice from the front. Dropping "N layers from the end" with a
+        # negative slice yields an empty module when N is 0, i.e. when the
+        # return layer is the last child.
+        model_sub_ = torch.nn.Sequential(
+            *list(t1_model.children())[:sub_pos + 1])
+        setattr(model, self.return_layer[0], model_sub_)
+        return model
+
+    def check_model_truncate(self, model):
+        """
+        Truncate ``model`` so that its output is the configured return layer.
+
+        :param model: Base model to be checked for presence of the layer.
+        :type model: torch.nn.Module
+
+        :raises AssertionError: More than two levels were given in the
+            return layer label.
+        :raises KeyError: The return layer is not part of the model.
+
+        :return: The model truncated at the return layer. The model is
+            returned as is when the label is empty or names the last
+            top-level module.
+        :rtype: torch.nn.Module
+        """
+        # We currently support iterating through only two levels of the
+        # network. Raised explicitly so the check survives ``python -O``.
+        if len(self.return_layer) > 2:
+            raise AssertionError("Return layer may name at most two levels, "
+                                 "given: {}".format(self.return_layer))
+        top_name = self.return_layer[0]
+        if top_name == '':
+            return model
+        try:
+            last_stage = model._modules[top_name]
+            if len(self.return_layer) == 2:
+                # Cut the sub-module first; it keeps its name and position.
+                model = self.truncate_pytorch_model(model, last_stage)
+        except KeyError:
+            self._log.error("Given return layer is "
+                            "invalid:{}".format(self.return_layer))
+            raise
+        module_list = list(model._modules)
+        layer_position = module_list.index(top_name)
+        if layer_position == len(module_list) - 1:
+            # Nothing follows the return layer: keep the model, and with it
+            # its own forward pass.
+            return model
+        return torch.nn.Sequential(
+            *list(model.children())[:layer_position + 1])
+
+    def __init__(self,
+                 model_name='resnet18', return_layer='avgpool',
+                 custom_model_arch=None, weights_filepath=None,
+                 input_dim=(224, 224), norm_mean=[0.485, 0.456, 0.406],
+                 norm_std=[0.229, 0.224, 0.225], use_gpu=True,
+                 batch_size=32, pretrained=True):
+        """
+        Create a PyTorch CNN descriptor generator
+
+        :param model_name: Name of model on PyTorch library,
+            for example: 'resnet50', 'vgg16'.
+        :type model_name: str
+        :param return_layer: The label of the layer we take data from
+            to compose output descriptor vector. A sub-layer is addressed
+            as ``'<module>.<child>'``.
+        :type return_layer: str
+        :param custom_model_arch: A custom Pytorch model instance.
+        :type custom_model_arch: torch.nn.Module
+        :param weights_filepath: Absolute file path to the weights loaded
+            into the model when ``pretrained`` is false.
+        :type weights_filepath: str
+        :param input_dim: Image height and width of an input image.
+        :type input_dim: (int, int)
+        :param norm_mean: Mean for normalizing images across three channels.
+        :type norm_mean: List [float, float, float].
+        :param norm_std: Standard deviation for normalizing images across
+            three channels.
+        :type norm_std: List [float, float, float].
+        :param use_gpu: If PyTorch should try to use the GPU. Falls back to
+            the CPU when CUDA is not available.
+        :type use_gpu: bool
+        :param batch_size: The maximum number of images to process in one feed
+            forward of the network. This is especially important for GPUs since
+            they can only process a batch that will fit in the GPU memory
+            space.
+        :type batch_size: int
+        :param pretrained: The network is loaded with pretrained weights
+            available on torchvision instead of custom weights.
+        :type pretrained: bool
+
+        :raises AssertionError: ``model_name`` is not a torchvision model and
+            no custom model was given.
+        :raises ValueError: Neither pretrained nor custom weights requested.
+        :raises KeyError: ``return_layer`` is not part of the model.
+        """
+        super(PytorchModelDescriptor, self).__init__()
+        self.model_name = model_name
+        self.return_layer = return_layer.split('.')
+        self.custom_model_arch = custom_model_arch
+        self.weights_filepath = weights_filepath
+        self.input_dim = input_dim
+        # Copied so the shared default lists can never be mutated.
+        self.norm_mean = list(norm_mean)
+        self.norm_std = list(norm_std)
+        self.use_gpu = use_gpu
+        self.batch_size = batch_size
+        self.pretrained = pretrained
+        self._setup_network()
+
+    def __getstate__(self):
+        return self.get_config()
+
+    def __setstate__(self, state):
+        # The state is the constructor configuration; re-running the
+        # constructor also rebuilds the transforms and the network, which
+        # are not part of the state.
+        self.__init__(**state)
+
+    def _setup_network(self):
+        """
+        Build the image transforms and the truncated network from the
+        configuration attributes.
+        """
+        self.transforms = torchvision.transforms.Compose([
+            torchvision.transforms.Resize(
+                (self.input_dim[0], self.input_dim[1])),
+            torchvision.transforms.ToTensor(),
+            torchvision.transforms.Normalize(self.norm_mean, self.norm_std)])
+
+        # Check if user wants to load custom model or a model from torchvision
+        model_factory = None
+        if not self.custom_model_arch:
+            model_factory = getattr(torchvision.models,
+                                    str(self.model_name), None)
+            if (str(self.model_name).startswith('_')
+                    or not callable(model_factory)):
+                self._log.error("Invalid model name, model not present "
+                                "in torchvision. Please load network "
+                                "architecture")
+                self._log.info("Available models include:{}".format(
+                    [s for s in vars(torchvision.models) if "__" not in s]))
+                raise AssertionError(
+                    "Invalid model name: {}".format(self.model_name))
+        if (not self.pretrained) and (not self.weights_filepath):
+            self._log.error("Network might be loaded with junk weights")
+            raise ValueError("Either pretrained weights or a weights file "
+                             "path is required")
+
+        if model_factory is not None:
+            # Loading model from torchvision library
+            model = model_factory(pretrained=self.pretrained)
+        else:
+            # If custom architecture
+            model = self.custom_model_arch
+        if not self.pretrained:
+            # Load onto the CPU first so GPU-saved weights also work on
+            # machines without CUDA.
+            checkpoint = torch.load(self.weights_filepath,
+                                    map_location='cpu')
+            if 'state_dict' in checkpoint:
+                checkpoint = checkpoint['state_dict']
+            model.load_state_dict(checkpoint)
+
+        # Check if return_layer1 is present in model and truncate the sub
+        # module containing return_key2.
+        model = self.check_model_truncate(model)
+        model.eval()
+
+        # ``use_gpu`` stays as configured; ``_on_gpu`` records what happened.
+        self._on_gpu = bool(self.use_gpu) and torch.cuda.is_available()
+        if self._on_gpu:
+            model = torch.nn.DataParallel(model.cuda())
+        elif self.use_gpu:
+            self._log.info("Cannot load PyTorch model to GPU, running on CPU")
+        self.model = model
+
+    def get_config(self):
+        """
+        Return a JSON-compliant dictionary that could be passed to this class's
+        ``from_config`` method to produce an instance with identical
+        configuration.
+
+        In the common case, this involves naming the keys of the dictionary
+        based on the initialization argument names as if it were to be passed
+        to the constructor via dictionary expansion.
+
+        :return: JSON type compliant configuration dictionary.
+        :rtype: dict
+        """
+        return {
+            'model_name': self.model_name,
+            'return_layer': '.'.join(self.return_layer),
+            'custom_model_arch': self.custom_model_arch,
+            'weights_filepath': self.weights_filepath,
+            'input_dim': self.input_dim,
+            'norm_mean': self.norm_mean,
+            'norm_std': self.norm_std,
+            'use_gpu': self.use_gpu,
+            'batch_size': self.batch_size,
+            'pretrained': self.pretrained,
+        }
+
+    def valid_content_types(self):
+        """
+        :return: A set valid MIME type content types that this descriptor can
+            handle.
+        :rtype: set[str]
+        """
+        return {
+            'image/bmp',
+            'image/tiff',
+            'image/png',
+            'image/jpeg',
+        }
+
+    def compute_descriptor(self, data, descr_factory=DFLT_DESCRIPTOR_FACTORY,
+                           overwrite=False):
+        """
+        Given some data, return a descriptor element containing a descriptor
+        vector.
+
+        :raises RuntimeError: Descriptor extraction failure of some kind.
+        :raises ValueError: Given data element content was not of a valid type
+            with respect to this descriptor.
+
+        :param data: Some kind of input data for the feature descriptor.
+        :type data: smqtk.representation.DataElement
+        :param descr_factory: Factory instance to produce the wrapping
+            descriptor element instance. The default factory produces
+            ``DescriptorMemoryElement`` instances by default.
+        :type descr_factory: smqtk.representation.DescriptorElementFactory
+        :param overwrite: Whether or not to force re-computation of a descriptor
+            vector for the given data even when there exists a precomputed
+            vector in the generated DescriptorElement as generated from the
+            provided factory. This will overwrite the persistently stored vector
+            if the provided factory produces a DescriptorElement implementation
+            with such storage.
+        :type overwrite: bool
+
+        :return: Result descriptor element. UUID of this output descriptor is
+            the same as the UUID of the input data element.
+        :rtype: smqtk.representation.DescriptorElement
+        """
+        m = self.compute_descriptor_async([data], descr_factory, overwrite)
+        return m[data.uuid()]
+
+    def _compute_descriptor(self, data):
+        raise NotImplementedError("Shouldn't get here as "
+                                  "compute_descriptor[_async] is being "
+                                  "overridden")
+
+    def check_get_uuid(self, descriptor_elem):
+        """
+        Queue the element's UUID for computation when it has no vector yet,
+        or when overwriting was requested.
+
+        :param descriptor_elem: Descriptor element to check.
+        :type descriptor_elem: smqtk.representation.DescriptorElement
+        """
+        if self.overwrite or not descriptor_elem.has_vector():
+            self.uuid4proc.append(descriptor_elem.uuid())
+
+    def compute_descriptor_async(self, data_set,
+                                 descriptor_elem_factory=DFLT_DESCRIPTOR_FACTORY,
+                                 overwrite=False):
+        """
+        Asynchronously compute feature data for multiple data items.
+
+        :param data_set: Iterable of data elements to compute features for.
+            These must have UIDs assigned for feature association in return
+            value.
+        :type data_set: collections.Iterable[smqtk.representation.DataElement]
+        :param descriptor_elem_factory: Factory instance to produce the
+            wrapping descriptor element instance. The default factory
+            produces ``DescriptorMemoryElement`` instances by default.
+        :type descriptor_elem_factory:
+            smqtk.representation.DescriptorElementFactory
+        :param overwrite: Whether or not to force re-computation of a descriptor
+            vectors for the given data even when there exists precomputed
+            vectors in the generated DescriptorElements as generated from the
+            provided factory. This will overwrite the persistently stored
+            vectors if the provided factory produces a DescriptorElement
+            implementation such storage.
+        :type overwrite: bool
+
+        :raises ValueError: An input DataElement was of a content type that we
+            cannot handle, or no data elements were given.
+
+        :return: Mapping of input DataElement UUIDs to the computed descriptor
+            element for that data. DescriptorElement UUID's are congruent with
+            the UUID of the data element it is the descriptor of.
+        :rtype: dict[collections.Hashable,
+            smqtk.representation.DescriptorElement]
+        """
+        data_elements = {}
+        descr_elements = {}
+        valid_types = self.valid_content_types()
+        pr = ProgressReporter(self._log.debug, 1.0).start()
+        for d in data_set:
+            ct = d.content_type()
+            if ct not in valid_types:
+                msg = ("Cannot compute descriptor from content type "
+                       "'%s' data: %s)" % (ct, d))
+                self._log.error(msg)
+                raise ValueError(msg)
+            uuid = d.uuid()
+            data_elements[uuid] = d
+            descr_elements[uuid] = descriptor_elem_factory \
+                .new_descriptor(self.name, uuid)
+            pr.increment_report()
+        pr.report()
+
+        if not data_elements:
+            raise ValueError("No data elements provided")
+        self._log.debug("Given %d unique data elements",
+                        len(data_elements))
+
+        self.overwrite = overwrite
+        self.uuid4proc = deque()
+        procs = min(multiprocessing.cpu_count(), len(data_elements))
+        # Using thread-pool due to in-line function + updating local deque
+        p = multiprocessing.pool.ThreadPool(procs)
+        try:
+            p.map(self.check_get_uuid, six.itervalues(descr_elements))
+        finally:
+            # Always release the worker threads, also when a check fails.
+            p.close()
+            p.join()
+        del p
+        self._log.debug("%d descriptors already computed",
+                        len(data_elements) - len(self.uuid4proc))
+
+        if self.uuid4proc:
+            on_gpu = getattr(self, '_on_gpu', False)
+            # Pinned host memory is only of use for transfers to a GPU.
+            data_loader = DataLoader(
+                PytorchImagedataset(data_elements, list(self.uuid4proc),
+                                    self.transforms),
+                batch_size=self.batch_size, shuffle=False,
+                num_workers=procs, pin_memory=on_gpu)
+            self._log.debug("Extracting PyTorch features")
+            # Inference only: do not record an autograd graph.
+            with torch.no_grad():
+                for (d, uuids) in data_loader:
+                    if on_gpu:
+                        d = d.cuda()
+                    pytorch_f = self.model(d)
+                    # One flat vector per image, whatever the batch size or
+                    # the spatial size of the layer output.
+                    vectors = pytorch_f.reshape(pytorch_f.shape[0], -1) \
+                        .cpu().numpy()
+                    for uuid, vector in zip(uuids, vectors):
+                        descr_elements[uuid].set_vector(vector)
+        self._log.debug("forming output dict")
+        return dict(descr_elements)
+
+
+def _process_load_img_array(image_pil, transforms=None):
+    """
+    Helper function for multiprocessing image data loading
+
+    :param image_pil: Image to process.
+    :type image_pil: PIL.Image.Image
+    :param transforms: Optional transforms applied to the image.
+
+    :return: The transformed image as a PIL image.
+    :rtype: PIL.Image.Image
+    """
+    if transforms:
+        image_pil = transforms(image_pil)
+    if torch.is_tensor(image_pil):
+        # ``ToPILImage`` is a transform class: instantiate, then apply.
+        image_pil = torchvision.transforms.ToPILImage()(image_pil)
+    return image_pil
diff --git a/python/smqtk/algorithms/descriptor_generator/pytorchdescriptor/utils.py b/python/smqtk/algorithms/descriptor_generator/pytorchdescriptor/utils.py
new file mode 100644
index 000000000..defe5fdb1
--- /dev/null
+++ b/python/smqtk/algorithms/descriptor_generator/pytorchdescriptor/utils.py
@@ -0,0 +1,56 @@
+from torch.utils.data import Dataset
+from PIL import Image
+import io
+
+
+class PytorchImagedataset(Dataset):
+    """
+    A Pytorch dataset class that loads images for feature extraction,
+    while maintaining a correspondence between their feature vectors
+    and uuids.
+    """
+
+    def __init__(self, data_elements, uuid4proc, transforms):
+        """
+        Create a Pytorch dataset for feature extraction using CNN.
+
+        :param data_elements: A dictionary of uuids to corresponding
+            smqtk.representation.DataElement
+        :type data_elements: dict[uuid, smqtk.representation.DataElement]
+        :param uuid4proc: The uuids of the elements to load, in dataset
+            order.
+        :type uuid4proc: collections.Iterable[uuid]
+        :param transforms: Augmentations and transforms applied to each
+            image.
+        :type transforms: torchvision.transforms
+        """
+        self.transform = transforms
+        # Copied to a list: positional access is what ``__getitem__`` does,
+        # and callers may hand in a deque.
+        self._uuid4proc = list(uuid4proc)
+        self.data_ele = data_elements
+
+    def __len__(self):
+        """
+        Returns the number of elements to load, which may be fewer than the
+        number of known data elements.
+        """
+        return len(self._uuid4proc)
+
+    def __getitem__(self, idx):
+        """
+        Returns the transformed image tensor and the uuid of the element at
+        the given position.
+
+        :param idx: Position of the element to fetch.
+        :type idx: int
+
+        :return: A tuple of the image tensor and its uuid.
+        :rtype: tuple(torch.tensor, str)
+        """
+        uuid = self._uuid4proc[idx]
+        img = Image.open(io.BytesIO(self.data_ele[uuid].get_bytes()))
+        img = img.convert('RGB')
+        if self.transform:
+            img = self.transform(img)
+        return img, uuid
diff --git a/requirements.txt b/requirements.txt
index 448ead444..91fe810f6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,3 +19,5 @@ scikit-learn==0.20.0
 scipy==1.1.0
 six==1.11.0
 stevedore==1.29.0
+torch==1.4.0
+torchvision==0.2.2
diff --git a/tests/algorithms/descriptor_generator/test_pytorch.py b/tests/algorithms/descriptor_generator/test_pytorch.py
new file mode 100644
index 000000000..2429444cd
--- /dev/null
+++ b/tests/algorithms/descriptor_generator/test_pytorch.py
@@ -0,0 +1,181 @@
+from __future__ import division, print_function
+import os
+import unittest
+
+import six
+import PIL.Image
+import numpy
+
+from smqtk.algorithms.descriptor_generator import DescriptorGenerator
+from smqtk.algorithms.descriptor_generator.pytorchdescriptor.pytorch_model_descriptors import \
+    torch, PytorchModelDescriptor
+from smqtk.representation.data_element.memory_element import DataMemoryElement
+
+from tests import TEST_DATA_DIR
+import pytest
+
+if PytorchModelDescriptor.is_usable():
+
+    class TestPytorchModelDescriptor (unittest.TestCase):
+
+        lenna_image_fp = os.path.join(TEST_DATA_DIR, 'Lenna.png')
+        lenna_torch_res18_avgpool_descr_fp = os.path.join(
+            TEST_DATA_DIR, 'Lenna.resnet18_avgpool_output.npy'
+        )
+
+        model_name_elem = 'resnet18'
+        return_layer_elem = 'avgpool'
+        norm_mean_elem = [0.485, 0.456, 0.406]
+        norm_std_elem = [0.229, 0.224, 0.225]
+        pretrained = True
+        resnet18_avgpool_weights = os.path.join(
+            TEST_DATA_DIR, 'resnet18_avgpool_weights_torch.pth')
+
+        # Dummy pytorch configuration files + weights
+        dummy_model_name = 'dummy_model'
+        dummy_return_layer = 'junk_layer'
+
+        @classmethod
+        def setUpClass(cls):
+            cls.model_name = 'resnet18'
+            cls.return_layer = 'avgpool'
+            cls.input_dim = (224, 224)
+            cls.norm_mean = [0.485, 0.456, 0.406]
+            cls.norm_std = [0.229, 0.224, 0.225]
+            cls.use_gpu = torch.cuda.is_available()
+
+        def test_impl_findable(self):
+            self.assertIn(PytorchModelDescriptor,
+                          DescriptorGenerator.get_impls())
+
+        def test_get_config(self):
+            expected_params = {
+                'model_name': 'resnet18',
+                'return_layer': 'avgpool',
+                'custom_model_arch': False,
+                'weights_filepath': None,
+                'input_dim': (24, 996),
+                'norm_mean': [0, 0, -0.5],
+                'norm_std': [0.2, 0.3, 1],
+                'use_gpu': self.use_gpu,
+                'batch_size': 777,
+                'pretrained': True,
+            }
+            g = PytorchModelDescriptor(**expected_params)
+            self.assertEqual(g.get_config(), expected_params)
+
+        def test_no_internal_compute_descriptor(self):
+            # This implementation's descriptor computation logic sits in async
+            # method override due to Pytorch's natural multi-element computation
+            # interface. Thus, ``_compute_descriptor`` should not be
+            # implemented.
+            g = PytorchModelDescriptor(use_gpu=self.use_gpu)
+            self.assertRaises(
+                NotImplementedError,
+                g._compute_descriptor, None
+            )
+
+        def test_compute_descriptor_dummy_model(self):
+            # Constructing with a model name unknown to torchvision should
+            # raise an AssertionError.
+            self.assertRaises(
+                AssertionError,
+                PytorchModelDescriptor, model_name=self.dummy_model_name)
+
+        def test_model_available_on_cpu(self):
+            # The network must be usable when the GPU is not requested.
+            g = PytorchModelDescriptor(use_gpu=False)
+            self.assertIsNotNone(g.model)
+
+        @unittest.skipUnless(DataMemoryElement.is_usable(),
+                             "Memory element not functional")
+        def test_compute_descriptor_lenna_description(self):
+            # Pytorch ResNet interaction test (Lenna image)
+            # This is a long test since it has to compute descriptors.
+            expected_descr = numpy.load(self.lenna_torch_res18_avgpool_descr_fp)
+            d = PytorchModelDescriptor(
+                self.model_name_elem,
+                self.return_layer_elem,
+                None, None, self.input_dim,
+                self.norm_mean_elem,
+                self.norm_std_elem,
+                self.use_gpu, 1, self.pretrained)
+            im = PIL.Image.open(self.lenna_image_fp)
+            buff = six.BytesIO()
+            im.save(buff, format="bmp")
+            de = DataMemoryElement(buff.getvalue(),
+                                   content_type='image/bmp')
+            descr = d.compute_descriptor(de).vector()
+            numpy.testing.assert_allclose(expected_descr, descr, atol=1e-4)
+
+        @unittest.skipUnless(DataMemoryElement.is_usable(),
+                             "Memory element not functional")
+        def test_load_image_data(self):
+            # A data element built from image bytes must get a uuid
+            # generated automatically.
+            buff = six.BytesIO()
+            im = PIL.Image.open(self.lenna_image_fp)
+            im.save(buff, format="bmp")
+            de = DataMemoryElement(buff.getvalue(),
+                                   content_type='image/bmp')
+            self.assertTrue(de.uuid())
+
+        def test_compute_descriptor_async_no_data(self):
+            # Should get a ValueError when given no descriptors to async method
+            g = PytorchModelDescriptor(
+                self.model_name_elem,
+                self.return_layer_elem,
+                None, None, self.input_dim,
+                self.norm_mean_elem,
+                self.norm_std_elem,
+                self.use_gpu, 32, self.pretrained)
+            self.assertRaises(
+                ValueError,
+                g.compute_descriptor_async, []
+            )
+
+        def test_loading_custom_weights_model(self):
+            # Should get a ValueError when neither pretrained nor custom
+            # weights are loaded to the network.
+            with pytest.raises(ValueError):
+                PytorchModelDescriptor(custom_model_arch=None,
+                                       weights_filepath=None,
+                                       pretrained=False)
+
+        def test_weights_loaded_to_model(self):
+            # Should fail when the network weights with pretrained flag
+            # loaded are not the imagenet pretrained weights.
+            d = PytorchModelDescriptor(
+                self.model_name_elem,
+                self.return_layer_elem,
+                None, None, self.input_dim,
+                self.norm_mean_elem,
+                self.norm_std_elem,
+                self.use_gpu, 1, self.pretrained)
+            imagenet_weights = torch.load(self.resnet18_avgpool_weights,
+                                          map_location='cpu')
+
+            def strip(key):
+                # ``DataParallel`` prefixes every key with "module.".
+                return key[7:] if key.startswith('module.') else key
+
+            expected = {strip(k): v for k, v in imagenet_weights.items()}
+            actual = {strip(k): v.cpu()
+                      for k, v in d.model.state_dict().items()}
+            self.assertEqual(set(actual), set(expected))
+            for key, value in actual.items():
+                numpy.testing.assert_allclose(
+                    value.numpy(), expected[key].cpu().numpy(),
+                    rtol=1e-6, atol=1e-12)
+
+        def test_return_layer_from_network(self):
+            # Should get a KeyError when the network does not contain
+            # the given return layer
+            with pytest.raises(KeyError):
+                PytorchModelDescriptor(
+                    self.model_name_elem,
+                    self.dummy_return_layer,
+                    None, None, self.input_dim,
+                    self.norm_mean_elem,
+                    self.norm_std_elem,
+                    self.use_gpu, 32, True)
diff --git a/tests/data/Lenna.resnet18_avgpool_output.npy b/tests/data/Lenna.resnet18_avgpool_output.npy
new file mode 100644
index 000000000..0abb62499
Binary files /dev/null and b/tests/data/Lenna.resnet18_avgpool_output.npy differ
diff --git a/tests/data/resnet18_avgpool_weights_torch.pth b/tests/data/resnet18_avgpool_weights_torch.pth
new file mode 100644
index 000000000..558c7c250
Binary files /dev/null and b/tests/data/resnet18_avgpool_weights_torch.pth differ