diff --git a/dataloaders/__init__.py b/dataloaders/__init__.py
index 641f5486..ac07999f 100644
--- a/dataloaders/__init__.py
+++ b/dataloaders/__init__.py
@@ -1,4 +1,4 @@
-from dataloaders.datasets import cityscapes, coco, combine_dbs, pascal, sbd
+from dataloaders.datasets import cityscapes, coco, combine_dbs, pascal, sbd, factory
 from torch.utils.data import DataLoader
 
 def make_data_loader(args, **kwargs):
@@ -37,6 +37,15 @@ def make_data_loader(args, **kwargs):
         test_loader = None
         return train_loader, val_loader, test_loader, num_class
 
+    elif args.dataset == 'factory':
+        train_set = factory.FactorySegmentation(args, split='train')
+        val_set = factory.FactorySegmentation(args, split='val')
+        num_class = train_set.NUM_CLASSES
+        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
+        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
+        test_loader = None
+        return train_loader, val_loader, test_loader, num_class
+
     else:
         raise NotImplementedError
 
diff --git a/dataloaders/datasets/factory.py b/dataloaders/datasets/factory.py
new file mode 100644
index 00000000..2e5b0a53
--- /dev/null
+++ b/dataloaders/datasets/factory.py
@@ -0,0 +1,129 @@
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from mypath import Path
+from tqdm import trange
+import os
+from pycocotools import mask
+from torchvision import transforms
+from dataloaders import custom_transforms as tr
+from PIL import Image, ImageFile
+from glob import glob
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+class FactorySegmentation(Dataset):
+    """Segmentation dataset built from RGB colour-coded label images.
+
+    Label images are globbed from ``label_img/`` under ``base_dir``; the input
+    image for each label is ``data_img/NAME.png``, where NAME is the label
+    file name up to its first dot.
+    """
+    NUM_CLASSES = 5
+
+    def __init__(self, args, base_dir=Path.db_root_dir('factory'), split='train'):
+        """
+        :param args: parsed options; not read by this dataset
+        :param base_dir: dataset root holding ``label_img/`` and ``data_img/``
+        :param split: 'train' or 'val'; selects the transform pipeline
+        """
+        super().__init__()
+        # Sorted so that index -> sample is stable across runs and filesystems.
+        self.label_path = sorted(glob(os.path.join(base_dir, "label_img", "*")))
+        names = [os.path.basename(p).split(".")[0] for p in self.label_path]
+        self.img_path = [os.path.join(base_dir, "data_img", name + ".png") for name in names]
+        # NOTE(review): `split` only switches the transforms -- both splits
+        # enumerate the same files, so 'val' overlaps 'train'. Confirm intent.
+        self.split = split
+
+    def __getitem__(self, index):
+        """Load, encode and transform the ``{'image', 'label'}`` sample at ``index``."""
+        _img = Image.open(self.img_path[index]).convert('RGB')
+        _target = np.array(Image.open(self.label_path[index]).convert('RGB'), dtype=np.float32)
+        _target = self._gen_seg_mask(_target)
+        sample = {'image': _img, 'label': _target}
+
+        if self.split == "train":
+            return self.transform_tr(sample)
+        elif self.split == 'val':
+            return self.transform_val(sample)
+        # Fail loudly rather than silently handing None to the DataLoader.
+        raise ValueError("Unsupported split: {!r}".format(self.split))
+
+    def transform_tr(self, sample):
+        """Training pipeline: random Gaussian blur, normalisation, tensor conversion."""
+        composed_transforms = transforms.Compose([
+            tr.RandomGaussianBlur(),
+            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+            tr.ToTensor()
+        ])
+        return composed_transforms(sample)
+
+    def transform_val(self, sample):
+        """Validation pipeline: normalisation and tensor conversion only."""
+        composed_transforms = transforms.Compose([
+            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+            tr.ToTensor()
+        ])
+        return composed_transforms(sample)
+
+    def _gen_seg_mask(self, target):
+        """Turn an RGB label array (values 0-255) into a class-index image.
+
+        Pure black/red/green/blue/white pixels map to classes 0/1/2/3/4.
+        """
+        # Scale out of place so the caller's array is not modified.
+        target = target / 255
+        # R + 2*G + 3*B gives 0/1/2/3 for black/red/green/blue and 6 for white.
+        # NOTE(review): mixed colours collide under these weights (yellow -> 3,
+        # magenta -> 4, cyan -> 5); assumes labels use only the five pure colours.
+        mask = target[:, :, 0] + target[:, :, 1] * 2 + target[:, :, 2] * 3
+        # Only white needs remapping; the other sums already equal their class id.
+        mask = np.where(mask == 6, 4, mask)
+        return Image.fromarray(mask)
+
+    def __len__(self):
+        return len(self.img_path)
+
+
+if __name__ == "__main__":
+    from dataloaders import custom_transforms as tr
+    from dataloaders.utils import decode_segmap
+    from torch.utils.data import DataLoader
+    from torchvision import transforms
+    import matplotlib.pyplot as plt
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    args = parser.parse_args()
+    args.base_size = 513
+    args.crop_size = 513
+
+    factory_val = FactorySegmentation(args, split='val')
+
+    dataloader = DataLoader(factory_val, batch_size=4, shuffle=True, num_workers=0)
+
+    for ii, sample in enumerate(dataloader):
+        # Convert the batch once instead of once per image.
+        img = sample['image'].numpy()
+        gt = sample['label'].numpy()
+        for jj in range(sample["image"].size()[0]):
+            tmp = np.array(gt[jj]).astype(np.uint8)
+            segmap = decode_segmap(tmp, dataset='factory')
+            img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
+            # Undo the Normalize(mean, std) applied by the dataset, for display.
+            img_tmp *= (0.229, 0.224, 0.225)
+            img_tmp += (0.485, 0.456, 0.406)
+            img_tmp *= 255.0
+            img_tmp = img_tmp.astype(np.uint8)
+            plt.figure()
+            plt.title('display')
+            plt.subplot(211)
+            plt.imshow(img_tmp)
+            plt.subplot(212)
+            plt.imshow(segmap)
+
+        if ii == 1:
+            break
+
+    plt.show(block=True)
diff --git a/dataloaders/utils.py b/dataloaders/utils.py
index b4a80ba4..516b81ff 100644
--- a/dataloaders/utils.py
+++ b/dataloaders/utils.py
@@ -27,6 +27,9 @@ def decode_segmap(label_mask, dataset, plot=False):
     elif dataset == 'cityscapes':
         n_classes = 19
         label_colours = get_cityscapes_labels()
+    elif dataset == 'factory':
+        n_classes = 5
+        label_colours = get_factory_labels()
     else:
         raise NotImplementedError
 
@@ -98,4 +101,13 @@ def get_pascal_labels():
                        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
-                       [0, 64, 128]])
\ No newline at end of file
+                       [0, 64, 128]])
+
+def get_factory_labels():
+    """Return one RGB colour per factory class id (black, red, green, blue, white)."""
+    return np.array([
+        [0, 0, 0],
+        [255, 0, 0],
+        [0, 255, 0],
+        [0, 0, 255],
+        [255, 255, 255]])
diff --git a/mypath.py b/mypath.py
index 8354d7da..e4d7e140 100644
--- a/mypath.py
+++ b/mypath.py
@@ -9,6 +9,8 @@ def db_root_dir(dataset):
             return '/path/to/datasets/cityscapes/'     # foler that contains leftImg8bit/
         elif dataset == 'coco':
             return '/path/to/datasets/coco/'
+        elif dataset == 'factory':
+            return '/home/kataoka/dataset/factory_seg/'  # NOTE: machine-specific; point at your local dataset root
         else:
             print('Dataset {} not available.'.format(dataset))
             raise NotImplementedError
diff --git a/train.py b/train.py
index 1c5c4fdd..523052b2 100644
--- a/train.py
+++ b/train.py
@@ -181,10 +181,10 @@ def main():
                         choices=['resnet', 'xception', 'drn', 'mobilenet'],
                         help='backbone name (default: resnet)')
     parser.add_argument('--out-stride', type=int, default=16,
-                        help='network output stride (default: 8)')
-    parser.add_argument('--dataset', type=str, default='pascal',
-                        choices=['pascal', 'coco', 'cityscapes'],
-                        help='dataset name (default: pascal)')
+                        help='network output stride (default: 16)')
+    parser.add_argument('--dataset', type=str, default='factory',
+                        choices=['pascal', 'coco', 'cityscapes', 'factory'],
+                        help='dataset name (default: factory)')
     parser.add_argument('--use-sbd', action='store_true', default=True,
                         help='whether to use SBD dataset (default: True)')
     parser.add_argument('--workers', type=int, default=4,
@@ -267,6 +267,7 @@ def main():
             'coco': 30,
             'cityscapes': 200,
             'pascal': 50,
+            'factory': 50,
         }
         args.epochs = epoches[args.dataset.lower()]
 
@@ -281,6 +282,7 @@ def main():
             'coco': 0.1,
             'cityscapes': 0.01,
             'pascal': 0.007,
+            'factory': 0.01,
         }
         args.lr = lrs[args.dataset.lower()] / (4 * len(args.gpu_ids)) * args.batch_size
 