88from ..utils import FileHandler , fix_duplicates
99
1010try :
11- from ..datasets import SegmentationFolderDataset
11+ from ..datasets import SegmentationFolderDataset , SegmentationHDF5Dataset
12+ from ..datasets .dataset_writers .hdf5_writer import HDF5Writer
1213 from ._basemodule import BaseDataModule
1314 from .downloader import SimpleDownloader
1415except ModuleNotFoundError :
@@ -26,6 +27,7 @@ def __init__(
2627 fold_split : Dict [str , int ],
2728 img_transforms : List [str ],
2829 inst_transforms : List [str ],
30+ dataset_type : str = "folder" ,
2931 normalization : str = None ,
3032 batch_size : int = 8 ,
3133 num_workers : int = 8 ,
@@ -65,6 +67,8 @@ def __init__(
6567 A list containg all the transformations that are applied to only the
6668 instance labelled masks. Allowed ones: "cellpose", "contour", "dist",
6769 "edgeweight", "hovernet", "omnipose", "smooth_dist", "binarize"
70+ dataset_type : str, default="folder"
71+ The dataset type. One of "folder", "hdf5".
6872 normalization : str, optional
6973 Apply img normalization after all the transformations. One of "minmax",
7074 "norm", "percentile", None.
@@ -107,6 +111,14 @@ def __init__(
107111 self .normalization = normalization
108112 self .kwargs = kwargs if kwargs is not None else {}
109113
114+ if dataset_type not in ("folder" , "hdf5" ):
115+ raise ValueError (
116+ f"Illegal `dataset_type` arg. Got { dataset_type } . "
117+ f"Allowed: { ('folder' , 'hdf5' )} "
118+ )
119+
120+ self .dataset_type = dataset_type
121+
110122 @property
111123 def type_classes (self ) -> Dict [str , int ]:
112124 """Pannuke cell type classes."""
@@ -127,7 +139,7 @@ def download(root: str) -> None:
127139 SimpleDownloader .download (url , root )
128140 PannukeDataModule .extract_zips (root , rm = True )
129141
130- def prepare_data (self , rm_orig : bool = True ) -> None :
142+ def prepare_data (self , rm_orig : bool = False ) -> None :
131143 """Prepare the pannuke datasets.
132144
133145 1. Download pannuke folds from:
@@ -167,6 +179,18 @@ def prepare_data(self, rm_orig: bool = True) -> None:
167179 self ._process_pannuke_fold (
168180 fold_paths , save_im_dir , save_mask_dir , fold_ix , phase
169181 )
182+
183+ if self .dataset_type == "hdf5" :
184+ writer = HDF5Writer (
185+ in_dir_im = save_im_dir ,
186+ in_dir_mask = save_mask_dir ,
187+ save_dir = self .save_dir / phase ,
188+ file_name = f"pannuke_{ phase } .h5" ,
189+ patch_size = None ,
190+ stride = None ,
191+ transforms = None ,
192+ )
193+ writer .write (tiling = False , msg = phase )
170194 else :
171195 print (
172196 "Found processed pannuke data. "
@@ -178,31 +202,45 @@ def prepare_data(self, rm_orig: bool = True) -> None:
178202 if "fold" in d .name .lower ():
179203 shutil .rmtree (d )
180204
205+ def _get_path (self , phase : str , dstype : str , is_mask : bool = False ) -> Path :
206+ if dstype == "hdf5" :
207+ p = self .save_dir / phase / f"pannuke_{ phase } .h5"
208+ else :
209+ dtype = "labels" if is_mask else "images"
210+ p = self .save_dir / phase / dtype
211+
212+ return p
213+
181214 def setup (self , stage : Optional [str ] = None ) -> None :
182215 """Set up the train, valid, and test datasets."""
183- self .trainset = SegmentationFolderDataset (
184- path = self .save_dir / "train" / "images" ,
185- mask_path = self .save_dir / "train" / "labels" ,
216+ if self .dataset_type == "hdf5" :
217+ DS = SegmentationHDF5Dataset
218+ else :
219+ DS = SegmentationFolderDataset
220+
221+ self .trainset = DS (
222+ path = self ._get_path ("train" , self .dataset_type , is_mask = False ),
223+ mask_path = self ._get_path ("train" , self .dataset_type , is_mask = True ),
186224 img_transforms = self .img_transforms ,
187225 inst_transforms = self .inst_transforms ,
188226 return_sem = False ,
189227 normalization = self .normalization ,
190228 ** self .kwargs ,
191229 )
192230
193- self .validset = SegmentationFolderDataset (
194- path = self .save_dir / "valid" / "images" ,
195- mask_path = self .save_dir / "valid" / "labels" ,
231+ self .validset = DS (
232+ path = self ._get_path ( "valid" , self . dataset_type , is_mask = False ) ,
233+ mask_path = self ._get_path ( "valid" , self . dataset_type , is_mask = True ) ,
196234 img_transforms = self .img_transforms ,
197235 inst_transforms = self .inst_transforms ,
198236 return_sem = False ,
199237 normalization = self .normalization ,
200238 ** self .kwargs ,
201239 )
202240
203- self .testset = SegmentationFolderDataset (
204- path = self .save_dir / "test" / "images" ,
205- mask_path = self .save_dir / "test" / "labels" ,
241+ self .testset = DS (
242+ path = self ._get_path ( "test" , self . dataset_type , is_mask = False ) ,
243+ mask_path = self ._get_path ( "test" , self . dataset_type , is_mask = True ) ,
206244 img_transforms = self .img_transforms ,
207245 inst_transforms = self .inst_transforms ,
208246 return_sem = False ,
@@ -256,7 +294,7 @@ def _process_pannuke_fold(
256294 inst_map = self ._get_inst_map (temp_mask [..., 0 :5 ])
257295
258296 fn_mask = Path (save_mask_dir / name ).with_suffix (".mat" )
259- FileHandler .write_mask (fn_mask , inst_map , type_map )
297+ FileHandler .write_mat (fn_mask , inst_map , type_map )
260298 pbar .update (1 )
261299
262300 def _get_type_map (self , pannuke_mask : np .ndarray ) -> np .ndarray :
0 commit comments