@@ -305,5 +305,233 @@ def test_patch_inferer_errors(self, inputs, arguments, expected_error):
305305 inferer (inputs = inputs , network = lambda x : x )
306306
307307
308+
309+ # ----------------------------------------------------------------------------
310+ # Error test cases with conditionign
311+ # ----------------------------------------------------------------------------
312+
313+ # no-overlapping 2x2 patches
314+ TEST_CASE_0_TENSOR_c = [
315+ TENSOR_4x4 ,
316+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = AvgMerger ),
317+ lambda x , condition : x + condition ,
318+ TENSOR_4x4 * 2 ,
319+ ]
320+
321+ # no-overlapping 2x2 patches using all default parameters (except for splitter)
322+ TEST_CASE_1_TENSOR_c = [
323+ TENSOR_4x4 ,
324+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 ))),
325+ lambda x , condition : x + condition ,
326+ TENSOR_4x4 * 2 ,
327+ ]
328+
329+ # divisible batch_size
330+ TEST_CASE_2_TENSOR_c = [
331+ TENSOR_4x4 ,
332+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = AvgMerger , batch_size = 2 ),
333+ lambda x , condition : x + condition ,
334+ TENSOR_4x4 * 2 ,
335+ ]
336+
337+ # non-divisible batch_size
338+ TEST_CASE_3_TENSOR_c = [
339+ TENSOR_4x4 ,
340+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = AvgMerger , batch_size = 3 ),
341+ lambda x , condition : x + condition ,
342+ TENSOR_4x4 * 2 ,
343+ ]
344+
345+ # patches that are already split (Splitter should be None)
346+ TEST_CASE_4_SPLIT_LIST_c = [
347+ [
348+ (TENSOR_4x4 [..., :2 , :2 ], (0 , 0 )),
349+ (TENSOR_4x4 [..., :2 , 2 :], (0 , 2 )),
350+ (TENSOR_4x4 [..., 2 :, :2 ], (2 , 0 )),
351+ (TENSOR_4x4 [..., 2 :, 2 :], (2 , 2 )),
352+ ],
353+ dict (splitter = None , merger_cls = AvgMerger , merged_shape = (2 , 3 , 4 , 4 )),
354+ lambda x , condition : x + condition ,
355+ TENSOR_4x4 * 2 ,
356+ ]
357+
358+ # using all default parameters (patches are already split)
359+ TEST_CASE_5_SPLIT_LIST_c = [
360+ [
361+ (TENSOR_4x4 [..., :2 , :2 ], (0 , 0 )),
362+ (TENSOR_4x4 [..., :2 , 2 :], (0 , 2 )),
363+ (TENSOR_4x4 [..., 2 :, :2 ], (2 , 0 )),
364+ (TENSOR_4x4 [..., 2 :, 2 :], (2 , 2 )),
365+ ],
366+ dict (merger_cls = AvgMerger , merged_shape = (2 , 3 , 4 , 4 )),
367+ lambda x , condition : x + condition ,
368+ TENSOR_4x4 * 2 ,
369+ ]
370+
371+ # output smaller than input patches
372+ TEST_CASE_6_SMALLER_c = [
373+ TENSOR_4x4 ,
374+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = AvgMerger ),
375+ lambda x , condition : torch .mean (x , dim = (- 1 , - 2 ), keepdim = True ) + torch .mean (condition , dim = (- 1 , - 2 ), keepdim = True ),
376+ TENSOR_2x2 * 2 ,
377+ ]
378+
379+ # preprocess patches
380+ TEST_CASE_7_PREPROCESS_c = [
381+ TENSOR_4x4 ,
382+ dict (
383+ splitter = SlidingWindowSplitter (patch_size = (2 , 2 )),
384+ merger_cls = AvgMerger ,
385+ preprocessing = lambda x : 2 * x ,
386+ postprocessing = None ,
387+ ),
388+ lambda x , condition : x + condition ,
389+ 2 * TENSOR_4x4 + TENSOR_4x4 ,
390+ ]
391+
392+ # preprocess patches
393+ TEST_CASE_8_POSTPROCESS_c = [
394+ TENSOR_4x4 ,
395+ dict (
396+ splitter = SlidingWindowSplitter (patch_size = (2 , 2 )),
397+ merger_cls = AvgMerger ,
398+ preprocessing = None ,
399+ postprocessing = lambda x : 4 * x ,
400+ ),
401+ lambda x , condition : x + condition ,
402+ 4 * TENSOR_4x4 * 2 ,
403+ ]
404+
405+ # str merger as the class name
406+ TEST_CASE_9_STR_MERGER_c = [
407+ TENSOR_4x4 ,
408+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = "AvgMerger" ),
409+ lambda x , condition : x + condition ,
410+ TENSOR_4x4 * 2 ,
411+ ]
412+
413+ # str merger as dotted patch
414+ TEST_CASE_10_STR_MERGER_c = [
415+ TENSOR_4x4 ,
416+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = "monai.inferers.merger.AvgMerger" ),
417+ lambda x , condition : x + condition ,
418+ TENSOR_4x4 * 2 ,
419+ ]
420+
421+ # non-divisible patch_size leading to larger image (without matching spatial shape)
422+ TEST_CASE_11_PADDING_c = [
423+ TENSOR_4x4 ,
424+ dict (
425+ splitter = SlidingWindowSplitter (patch_size = (2 , 3 ), pad_mode = "constant" , pad_value = 0.0 ),
426+ merger_cls = AvgMerger ,
427+ match_spatial_shape = False ,
428+ ),
429+ lambda x , condition : x + condition ,
430+ pad (TENSOR_4x4 , (0 , 2 ), value = 0.0 ) * 2 ,
431+ ]
432+
433+ # non-divisible patch_size with matching spatial shapes
434+ TEST_CASE_12_MATCHING_c = [
435+ TENSOR_4x4 ,
436+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 3 ), pad_mode = None ), merger_cls = AvgMerger ),
437+ lambda x , condition : x + condition ,
438+ pad (TENSOR_4x4 [..., :3 ], (0 , 1 ), value = float ("nan" )) * 2 ,
439+ ]
440+
441+ # non-divisible patch_size with matching spatial shapes
442+ TEST_CASE_13_PADDING_MATCHING_c = [
443+ TENSOR_4x4 ,
444+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 3 )), merger_cls = AvgMerger ),
445+ lambda x , condition : x + condition ,
446+ TENSOR_4x4 * 2 ,
447+ ]
448+
449+ # multi-threading
450+ TEST_CASE_14_MULTITHREAD_BUFFER_c = [
451+ TENSOR_4x4 ,
452+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = AvgMerger , buffer_size = 2 ),
453+ lambda x , condition : x + condition ,
454+ TENSOR_4x4 * 2 ,
455+ ]
456+
457+ # multi-threading with batch
458+ TEST_CASE_15_MULTITHREADD_BUFFER_c = [
459+ TENSOR_4x4 ,
460+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = AvgMerger , buffer_size = 4 , batch_size = 4 ),
461+ lambda x , condition : x + condition ,
462+ TENSOR_4x4 * 2 ,
463+ ]
464+
465+ # list of tensor output
466+ TEST_CASE_0_LIST_TENSOR_c = [
467+ TENSOR_4x4 ,
468+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = AvgMerger ),
469+ lambda x , condition : (x + condition , x + condition ),
470+ (TENSOR_4x4 * 2 , TENSOR_4x4 * 2 ),
471+ ]
472+
473+ # list of tensor output
474+ TEST_CASE_0_DICT_c = [
475+ TENSOR_4x4 ,
476+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = AvgMerger ),
477+ lambda x , condition : {"model_output" : x + condition },
478+ {"model_output" : TENSOR_4x4 * 2 },
479+ ]
480+
481+
482+
483+ class PatchInfererTests_cond (unittest .TestCase ):
484+ @parameterized .expand (
485+ [
486+ TEST_CASE_0_TENSOR_c ,
487+ TEST_CASE_1_TENSOR_c ,
488+ TEST_CASE_2_TENSOR_c ,
489+ TEST_CASE_3_TENSOR_c ,
490+ TEST_CASE_4_SPLIT_LIST_c ,
491+ TEST_CASE_5_SPLIT_LIST_c ,
492+ TEST_CASE_6_SMALLER_c ,
493+ TEST_CASE_7_PREPROCESS_c ,
494+ TEST_CASE_8_POSTPROCESS_c ,
495+ TEST_CASE_9_STR_MERGER_c ,
496+ TEST_CASE_10_STR_MERGER_c ,
497+ TEST_CASE_11_PADDING_c ,
498+ TEST_CASE_12_MATCHING_c ,
499+ TEST_CASE_13_PADDING_MATCHING_c ,
500+ TEST_CASE_14_MULTITHREAD_BUFFER_c ,
501+ TEST_CASE_15_MULTITHREADD_BUFFER_c ,
502+ ]
503+ )
504+ def test_patch_inferer_tensor (self , inputs , arguments , network , expected ):
505+ if isinstance (inputs , list ): # case 4 and 5
506+ condition = [(x [0 ].clone (), x [1 ]) for x in inputs ]
507+ else :
508+ condition = inputs .clone ()
509+ inferer = PatchInferer (** arguments )
510+ output = inferer (inputs = inputs , network = network , condition = condition )
511+ assert_allclose (output , expected )
512+
513+ @parameterized .expand ([TEST_CASE_0_LIST_TENSOR_c ])
514+ def test_patch_inferer_list_tensor (self , inputs , arguments , network , expected ):
515+ if isinstance (inputs , list ): # case 4 and 5
516+ condition = [(x [0 ].clone (), x [1 ]) for x in inputs ]
517+ else :
518+ condition = inputs .clone ()
519+ inferer = PatchInferer (** arguments )
520+ output = inferer (inputs = inputs , network = network , condition = condition )
521+ for out , exp in zip (output , expected ):
522+ assert_allclose (out , exp )
523+
524+ @parameterized .expand ([TEST_CASE_0_DICT_c ])
525+ def test_patch_inferer_dict (self , inputs , arguments , network , expected ):
526+ if isinstance (inputs , list ): # case 4 and 5
527+ condition = [(x [0 ].clone (), x [1 ]) for x in inputs ]
528+ else :
529+ condition = inputs .clone ()
530+ inferer = PatchInferer (** arguments )
531+ output = inferer (inputs = inputs , network = network , condition = condition )
532+ for k in expected :
533+ assert_allclose (output [k ], expected [k ])
534+
535+
308536if __name__ == "__main__" :
309537 unittest .main ()
0 commit comments