77import numpy as np
88import torch
99import torch .nn as nn
10+ import yaml
1011from pathos .multiprocessing import ThreadPool as Pool
1112from torch .utils .data import DataLoader
1213from tqdm import tqdm
@@ -31,6 +32,7 @@ def __init__(
3132 batch_size : int = 8 ,
3233 normalization : str = None ,
3334 device : str = "cuda" ,
35+ n_devices : int = 1 ,
3436 save_masks : bool = True ,
3537 save_intermediate : bool = False ,
3638 save_dir : Union [Path , str ] = None ,
@@ -72,6 +74,9 @@ def __init__(
7274 One of: "dataset", "minmax", "norm", "percentile", None.
7375 device : str, default="cuda"
7476 The device of the input and model. One of: "cuda", "cpu"
77+ n_devices : int, default=1
78+ Number of devices (cpus/gpus) used for inference.
79+ The model will be copied into these devices.
7580 save_masks : bool, default=False
7681 If True, the resulting segmentation masks will be saved into `out_masks`
7782 variable.
@@ -95,6 +100,16 @@ def __init__(
95100 **postproc_kwargs:
96101 Arbitrary keyword arguments for the post-processing.
97102 """
103+ # basic inits
104+ self .model = model
105+ self .out_heads = self ._get_out_info () # the names and num channels of out heads
106+ self .batch_size = batch_size
107+ self .patch_size = patch_size
108+ self .padding = padding
109+ self .out_activations = out_activations
110+ self .out_boundary_weights = out_boundary_weights
111+ self .head_kwargs = self ._check_and_set_head_args ()
112+
98113 self .save_dir = Path (save_dir ) if save_dir is not None else None
99114 self .save_masks = save_masks
100115 self .save_intermediate = save_intermediate
@@ -106,17 +121,17 @@ def __init__(
106121 folder_ds , batch_size = batch_size , shuffle = False , pin_memory = True
107122 )
108123
109- # model and device
110- self .model = model
111- if device == "cpu" :
112- self .model .cpu ()
113- self .device = torch .device ("cpu" )
114- if torch .cuda .is_available () and device == "cuda" :
115- self .model .cuda ()
116- self .device = torch .device ("cuda" )
117-
118- self .model .eval ()
124+ # Set post processor
125+ self .postprocessor = PostProcessor (
126+ instance_postproc ,
127+ inst_key = self .model .inst_key ,
128+ aux_key = self .model .aux_key ,
129+ type_post_proc = type_post_proc ,
130+ sem_post_proc = sem_post_proc ,
131+ ** postproc_kwargs ,
132+ )
119133
134+ # load weights and set devices
120135 if checkpoint_path is not None :
121136 ckpt = torch .load (
122137 checkpoint_path , map_location = lambda storage , loc : storage
@@ -130,30 +145,41 @@ def __init__(
130145 except BaseException as e :
131146 print (e )
132147
133- #
148+ assert device in ("cuda" , "cpu" )
149+ if device == "cpu" :
150+ self .device = torch .device ("cpu" )
151+ if torch .cuda .is_available () and device == "cuda" :
152+ self .device = torch .device ("cuda" )
153+
154+ if torch .cuda .device_count () > 1 and n_devices > 1 :
155+ self .model = nn .DataParallel (self .model , device_ids = range (n_devices ))
156+
157+ self .model .to (self .device )
158+ self .model .eval ()
159+
160+ # Helper class to perform forward + extra processing
134161 self .predictor = Predictor (
135162 model = self .model ,
136163 patch_size = patch_size ,
137164 normalization = normalization ,
138165 device = self .device ,
139166 )
140- self .out_heads = self ._get_out_info () # the names and num channels of out heads
141- self .batch_size = batch_size
142- self .patch_size = patch_size
143- self .padding = padding
144- self .out_activations = out_activations
145- self .out_boundary_weights = out_boundary_weights
146- self .head_kwargs = self ._check_and_set_head_args ()
147167
148- #
149- self .postprocessor = PostProcessor (
150- instance_postproc ,
151- inst_key = self .model .inst_key ,
152- aux_key = self .model .aux_key ,
153- type_post_proc = type_post_proc ,
154- sem_post_proc = sem_post_proc ,
155- ** postproc_kwargs ,
156- )
168+ @classmethod
169+ def from_yaml (cls , model : nn .Module , yaml_path : str ):
170+ """Initialize the inferer from a yaml-file.
171+
172+ Parameters
173+ ----------
174+ model : nn.Module
175+ Initialized segmentation model.
176+ yaml_path : str
177+ Path to the yaml file containing rest of the params
178+ """
179+ with open (yaml_path , "r" ) as stream :
180+ kwargs = yaml .full_load (stream )
181+
182+ return cls (model , ** kwargs )
157183
158184 @abstractmethod
159185 def _infer_batch (self ):
0 commit comments