1212
1313from ..functional import iou
1414
15- __all__ = ["WandbImageCallback" , "WandbClassMetricCallback " ]
15+ __all__ = ["WandbImageCallback" , "WandbClassBarCallback" , "WandbClassLineCallback " ]
1616
1717
1818class WandbImageCallback (pl .Callback ):
@@ -104,26 +104,22 @@ def on_validation_batch_end(
104104 trainer .logger .experiment .log (log_dict )
105105
106106
107- class WandbClassMetricCallback (pl .Callback ):
107+ class WandbIoUCallback (pl .Callback ):
108108 def __init__ (
109109 self ,
110110 type_classes : Dict [str , int ],
111111 sem_classes : Optional [Dict [str , int ]],
112112 freq : int = 100 ,
113- return_series : bool = True ,
114- return_bar : bool = True ,
115- return_table : bool = False ,
116113 ) -> None :
117- """Call back to compute per- class ious and log them to wandb."""
114+ """Create a base class for IoU wandb callbacks ."""
118115 super ().__init__ ()
119116 self .type_classes = type_classes
120117 self .sem_classes = sem_classes
121118 self .freq = freq
122- self .return_series = return_series
123- self .return_bar = return_bar
124- self .return_table = return_table
125- self .cell_ious = np .empty (0 )
126- self .sem_ious = np .empty (0 )
119+
120+ def batch_end (self ) -> None :
121+ """Abstract batch end method."""
122+ raise NotImplementedError
127123
128124 def compute (
129125 self ,
@@ -139,36 +135,47 @@ def compute(
139135 met = iou (pred , target ).mean (dim = 0 )
140136 return met .to ("cpu" ).numpy ()
141137
142- def get_table (
143- self , ious : np .ndarray , x : np .ndarray , classes : Dict [int , str ]
144- ) -> wandb .Table :
145- """Return a wandb Table with step, iou and label values for every step."""
146- batch_data = [
147- [xi * self .freq , c , np .round (ious [xi , i ], 4 )]
148- for i , c , in classes .items ()
149- for xi in x
150- ]
138+ def on_train_batch_end (
139+ self ,
140+ trainer : pl .Trainer ,
141+ pl_module : pl .LightningModule ,
142+ outputs : Dict [str , torch .Tensor ],
143+ batch : Dict [str , torch .Tensor ],
144+ batch_idx : int ,
145+ dataloader_idx : int ,
146+ ) -> None :
147+ """Log the inputs and outputs of the model to wandb."""
148+ self .batch_end (trainer , outputs ["soft_masks" ], batch , batch_idx , phase = "train" )
151149
152- return wandb .Table (data = batch_data , columns = ["step" , "label" , "value" ])
150+ def on_validation_batch_end (
151+ self ,
152+ trainer : pl .Trainer ,
153+ pl_module : pl .LightningModule ,
154+ outputs : Dict [str , torch .Tensor ],
155+ batch : Dict [str , torch .Tensor ],
156+ batch_idx : int ,
157+ dataloader_idx : int ,
158+ ) -> None :
159+ """Log the inputs and outputs of the model to wandb."""
160+ self .batch_end (trainer , outputs ["soft_masks" ], batch , batch_idx , phase = "val" )
161+
162+
163+ class WandbClassBarCallback (WandbIoUCallback ):
164+ def __init__ (
165+ self ,
166+ type_classes : Dict [str , int ],
167+ sem_classes : Optional [Dict [str , int ]],
168+ freq : int = 100 ,
169+ ) -> None :
170+ """Create a wandb callback that logs per-class mIoU at batch ends."""
171+ super ().__init__ (type_classes , sem_classes , freq )
153172
154173 def get_bar (self , iou : np .ndarray , classes : Dict [int , str ], title : str ) -> Any :
155174 """Return a wandb bar plot object of the current per class iou values."""
156175 batch_data = [[lab , val ] for lab , val in zip (list (classes .values ()), iou )]
157176 table = wandb .Table (data = batch_data , columns = ["label" , "value" ])
158177 return wandb .plot .bar (table , "label" , "value" , title = title )
159178
160- def get_series (
161- self , ious : np .ndarray , x : np .ndarray , classes : Dict [int , str ], title : str
162- ) -> Any :
163- """Return a wandb series plot obj of the per class iou values over timesteps."""
164- return wandb .plot .line_series (
165- xs = x .tolist (),
166- ys = [ious [:, c ].tolist () for c in classes .keys ()],
167- keys = list (classes .values ()),
168- title = title ,
169- xname = "step" ,
170- )
171-
172179 def batch_end (
173180 self ,
174181 trainer : pl .Trainer ,
@@ -182,69 +189,54 @@ def batch_end(
182189 log_dict = {}
183190 if "type" in list (batch .keys ()):
184191 iou = self .compute ("type" , outputs , batch )
185- self .cell_ious = np .append (self .cell_ious , iou )
186- cell_ious = self .cell_ious .reshape (- 1 , len (self .type_classes ))
187- x = np .arange (cell_ious .shape [0 ])
188-
189- if self .return_table :
190- log_dict [f"{ phase } /type_ious_table" ] = self .get_table (
191- cell_ious , x , self .type_classes
192- )
193-
194- if self .return_series :
195- log_dict [f"{ phase } /type_ious_per_class" ] = self .get_series (
196- cell_ious , x , self .type_classes , title = "Per type class mIoU"
197- )
198-
199- if self .return_bar :
200- log_dict [f"{ phase } /type_ious_bar" ] = self .get_bar (
201- list (iou ), self .type_classes , title = "Cell class mIoUs"
202- )
192+ log_dict [f"{ phase } /type_ious_bar" ] = self .get_bar (
193+ list (iou ), self .type_classes , title = "Cell class mIoUs"
194+ )
203195
204196 if "sem" in list (batch .keys ()):
205197 iou = self .compute ("sem" , outputs , batch )
206-
207- self .sem_ious = np .append (self .sem_ious , iou )
208- sem_ious = self .sem_ious .reshape (- 1 , len (self .sem_classes ))
209- x = np .arange (sem_ious .shape [0 ])
210-
211- if self .return_table :
212- log_dict [f"{ phase } /sem_ious_table" ] = self .get_table (
213- cell_ious , x , self .type_classes
214- )
215-
216- if self .return_series :
217- log_dict [f"{ phase } /sem_ious_per_class" ] = self .get_series (
218- cell_ious , x , self .type_classes , title = "Per sem class mIoU"
219- )
220-
221- if self .return_bar :
222- log_dict [f"{ phase } /sem_ious_bar" ] = self .get_bar (
223- list (iou ), self .type_classes , title = "Sem class mIoUs"
224- )
198+ log_dict [f"{ phase } /sem_ious_bar" ] = self .get_bar (
199+ list (iou ), self .sem_classes , title = "Sem class mIoUs"
200+ )
225201
226202 trainer .logger .experiment .log (log_dict )
227203
228- def on_train_batch_end (
204+
205+ class WandbClassLineCallback (WandbIoUCallback ):
206+ def __init__ (
229207 self ,
230- trainer : pl .Trainer ,
231- pl_module : pl .LightningModule ,
232- outputs : Dict [str , torch .Tensor ],
233- batch : Dict [str , torch .Tensor ],
234- batch_idx : int ,
235- dataloader_idx : int ,
208+ type_classes : Dict [str , int ],
209+ sem_classes : Optional [Dict [str , int ]],
210+ freq : int = 100 ,
236211 ) -> None :
237- """Log the inputs and outputs of the model to wandb ."""
238- self . batch_end ( trainer , outputs [ "soft_masks" ], batch , batch_idx , phase = "train" )
212+ """Create a wandb callback that logs per-class mIoU at batch ends ."""
213+ super (). __init__ ( type_classes , sem_classes , freq )
239214
240- def on_validation_batch_end (
215+ def get_points (self , iou : np .ndarray , classes : Dict [int , str ]) -> Any :
216+ """Return a wandb bar plot object of the current per class iou values."""
217+ return {lab : val for lab , val in zip (list (classes .values ()), iou )}
218+
219+ def batch_end (
241220 self ,
242221 trainer : pl .Trainer ,
243- pl_module : pl .LightningModule ,
244222 outputs : Dict [str , torch .Tensor ],
245223 batch : Dict [str , torch .Tensor ],
246224 batch_idx : int ,
247- dataloader_idx : int ,
225+ phase : str ,
248226 ) -> None :
249- """Log the inputs and outputs of the model to wandb."""
250- self .batch_end (trainer , outputs ["soft_masks" ], batch , batch_idx , phase = "val" )
227+ """Log metrics at every 100th step to wandb."""
228+ if batch_idx % self .freq == 0 :
229+ log_dict = {}
230+ if "type" in list (batch .keys ()):
231+ iou = self .compute ("type" , outputs , batch )
232+ log_dict [f"{ phase } /type_ious_points" ] = self .get_points (
233+ list (iou ), self .type_classes
234+ )
235+
236+ if "sem" in list (batch .keys ()):
237+ iou = self .compute ("sem" , outputs , batch )
238+ log_dict [f"{ phase } /sem_ious_points" ] = self .get_points (
239+ list (iou ), self .sem_classes
240+ )
241+
242+ trainer .logger .experiment .log (log_dict )
0 commit comments