1+ import warnings
12from abc import ABC , abstractmethod
23from collections import OrderedDict
34from itertools import chain
89import torch
910import torch .nn as nn
1011import yaml
11- from pathos .multiprocessing import ThreadPool as Pool
1212from torch .utils .data import DataLoader
1313from tqdm import tqdm
1414
15- from ..utils import tensor_to_ndarray
16- from ..utils .save_utils import mask2mat
15+ from ..utils import FileHandler , tensor_to_ndarray
1716from .folder_dataset import FolderDataset
1817from .post_processor import PostProcessor
1918from .predictor import Predictor
@@ -33,14 +32,14 @@ def __init__(
3332 normalization : str = None ,
3433 device : str = "cuda" ,
3534 n_devices : int = 1 ,
36- save_masks : bool = True ,
3735 save_intermediate : bool = False ,
3836 save_dir : Union [Path , str ] = None ,
37+ save_format : str = ".mat" ,
3938 checkpoint_path : Union [Path , str ] = None ,
4039 n_images : int = None ,
4140 type_post_proc : Callable = None ,
4241 sem_post_proc : Callable = None ,
43- ** postproc_kwargs ,
42+ ** kwargs ,
4443 ) -> None :
4544 """Inference for an image folder.
4645
@@ -77,16 +76,14 @@ def __init__(
7776 n_devices : int, default=1
7877 Number of devices (cpus/gpus) used for inference.
7978 The model will be copied into these devices.
80- save_masks : bool, default=False
81- If True, the resulting segmentation masks will be saved into `out_masks`
82- variable.
83- save_intermediate : bool, default=False
84- If True, intermediate soft masks will be saved into `soft_masks` var.
8579 save_dir : bool, optional
8680 Path to save directory. If None, no masks will be saved to disk as .mat
87- files. If not None, overrides `save_masks`, thus for every batch the
88- segmentation results are saved into disk and the intermediate results
89- are flushed.
81+ or .json files. Instead the masks will be saved in `self.out_masks`.
82+ save_intermediate : bool, default=False
83+ If True, intermediate soft masks will be saved into `soft_masks` var.
84+ save_format : str, default=".mat"
85+ The file format for the saved output masks. One of (".mat", ".json").
86+ The ".json" option will save masks into geojson format.
9087 checkpoint_path : Path | str, optional
9188 Path to the model weight checkpoints.
9289 n_images : int, optional
@@ -97,8 +94,8 @@ def __init__(
9794 sem_post_proc : Callable, optional
9895 A post-processing function for the semantc seg maps. If not None,
9996 overrides the default.
100- **postproc_kwargs :
101- Arbitrary keyword arguments for the post-processing.
97+ **kwargs :
98+ Arbitrary keyword arguments expecially for post-processing and saving .
10299 """
103100 # basic inits
104101 self .model = model
@@ -109,14 +106,25 @@ def __init__(
109106 self .out_activations = out_activations
110107 self .out_boundary_weights = out_boundary_weights
111108 self .head_kwargs = self ._check_and_set_head_args ()
109+ self .kwargs = kwargs
112110
113111 self .save_dir = Path (save_dir ) if save_dir is not None else None
114- self .save_masks = save_masks
115112 self .save_intermediate = save_intermediate
113+ self .save_format = save_format
116114
117115 # dataloader
118116 self .path = Path (input_folder )
117+
119118 folder_ds = FolderDataset (self .path , n_images = n_images )
119+ if self .save_dir is None and len (folder_ds .fnames ) > 40 :
120+ warnings .warn (
121+ "`save_dir` is None. Thus, the outputs are be saved in `out_masks` "
122+ "class variable. If the input folder contains many images, running "
123+ "inference will likely flood the memory depending on the size and "
124+ "number of the images. Consider saving outputs to disk by providing "
125+ "`save_dir` argument."
126+ )
127+
120128 self .dataloader = DataLoader (
121129 folder_ds , batch_size = batch_size , shuffle = False , pin_memory = True
122130 )
@@ -128,7 +136,7 @@ def __init__(
128136 aux_key = self .model .aux_key ,
129137 type_post_proc = type_post_proc ,
130138 sem_post_proc = sem_post_proc ,
131- ** postproc_kwargs ,
139+ ** kwargs ,
132140 )
133141
134142 # load weights and set devices
@@ -188,10 +196,16 @@ def _infer_batch(self):
188196 def infer (self ) -> None :
189197 """Run inference and post-processing for the images.
190198
191- NOTE: Saves outputs in `self.out_masks` or to disk (.mat) files.
192-
193- `self.out_masks` is a nested dict: E.g.
194- {"image1": {"inst": [H, W], "type": [H, W], "sem": [H, W]}}
199+ NOTE:
200+ - Saves outputs in `self.out_masks` or to disk (.mat/.json) files.
201+ - If `save_intermediate` is set to True, also intermiediate model outputs are
202+ saved to `self.soft_masks`
203+ - `self.out_masks` and `self.soft_masks` are nested dicts: E.g.
204+ {"sample1": {"inst": [H, W], "type": [H, W], "sem": [H, W]}}
205+ - If masks are saved to geojson .json files, more key word arguments
206+ need to be given at class initialization. Namely: `geo_format`,
207+ `classes_type`, `classes_sem`, `offsets`. See more in the
208+ `FileHandler.save_masks` docs.
195209 """
196210 self .soft_masks = {}
197211 self .out_masks = {}
@@ -223,89 +237,25 @@ def infer(self) -> None:
223237 self .soft_masks [n ] = m
224238
225239 if self .save_dir is None :
226- if self .save_masks :
227- for n , m in zip (names , seg_results ):
228- self .out_masks [n ] = m
240+ for n , m in zip (names , seg_results ):
241+ self .out_masks [n ] = m
229242 else :
230243 loader .set_postfix_str ("Saving results to disk" )
231244 if self .batch_size > 1 :
232- self .save_parallel (seg_results , names , self .save_dir )
245+ fnames = [Path (self .save_dir ) / n for n in names ]
246+ FileHandler .save_masks_parallel (
247+ maps = seg_results ,
248+ fnames = fnames ,
249+ ** {** self .kwargs , "format" : self .save_format },
250+ )
233251 else :
234252 for n , m in zip (names , seg_results ):
235- self .save_mask (m , n , self .save_dir )
236-
237- @staticmethod
238- def save_mask (
239- maps : Dict [str , np .ndarray ],
240- fname : str ,
241- save_dir : Union [str , Path ],
242- format : str = ".mat" ,
243- ) -> None :
244- """Save model outputs to .mat or geojson.
245-
246- Parameters
247- ----------
248- maps : Dict[str, np.ndarray]
249- model output names mapped to model outputs.
250- E.g. {"sem": np.ndarray, "type": np.ndarray, "inst": np.ndarray}
251- fname : str
252- Name for the output-file.
253- save_dir : Path or str
254- Path to the save directory.
255- format : str
256- One of ".mat" or "geojson"
257- """
258- allowed = (".mat" , ".json" )
259- if format not in allowed :
260- raise ValueError (
261- f"Illegal file-format. Got: { format } . Allowed formats: { allowed } "
262- )
263-
264- if format == ".mat" :
265- mask2mat (fname , save_dir , ** maps )
266- else :
267- pass
268-
269- return True
270-
271- @staticmethod
272- def save_parallel (
273- maps : List [Dict [str , np .ndarray ]],
274- fnames : List [str ],
275- save_dir : Union [Path , str ],
276- format : str = ".mat" ,
277- progress_bar : bool = False ,
278- ) -> None :
279- """Save the model output masks to a folder. (multi-threaded).
280-
281- Parameters
282- ----------
283- maps : List[Dict[str, np.ndarray]]
284- The model output map dictionaries in a list.
285- fnames : List[str]
286- Name for the output-files. (In the same order with `maps`)
287- save_dir : Path or str
288- Path to the save directory.
289- format : str
290- One of ".mat" or "geojson"
291- progress_bar : bool, default=False
292- If True, a tqdm progress bar is shown.
293- """
294- args = tuple (zip (maps , fnames , [save_dir ] * len (maps ), [format ] * len (maps )))
295-
296- with Pool () as pool :
297- if progress_bar :
298- it = tqdm (pool .imap (BaseInferer ._save_mask , args ), total = len (maps ))
299- else :
300- it = pool .imap (BaseInferer ._save_mask , args )
301-
302- for _ in it :
303- pass
304-
305- @staticmethod
306- def _save_mask (args : Tuple [Dict [str , np .ndarray ], str , str ]) -> None :
307- """Unpacks the args for `save_mask` to enable multi-threading."""
308- return BaseInferer .save_mask (* args )
253+ fname = Path (self .save_dir ) / n
254+ FileHandler .save_masks (
255+ fname = fname ,
256+ maps = m ,
257+ ** {** self .kwargs , "format" : self .save_format },
258+ )
309259
310260 def _strip_state_dict (self , ckpt : Dict ) -> OrderedDict :
311261 """Strip te first 'model.' (generated by lightning) from the state dict keys."""
0 commit comments