@@ -5032,25 +5032,30 @@ def test_transform(self, make_input):
50325032 check_transform (transforms .CenterCrop (self .OUTPUT_SIZES [0 ]), make_input (self .INPUT_SIZE ))
50335033
50345034 @pytest .mark .parametrize ("output_size" , OUTPUT_SIZES )
5035+ @pytest .mark .parametrize (
5036+ "make_input" ,
5037+ [
5038+ make_image ,
5039+ pytest .param (
5040+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "test requires CVCUDA" )
5041+ ),
5042+ ],
5043+ )
50355044 @pytest .mark .parametrize ("fn" , [F .center_crop , transform_cls_to_functional (transforms .CenterCrop )])
5036- def test_image_correctness (self , output_size , fn ):
5037- image = make_image (self .INPUT_SIZE , dtype = torch .uint8 , device = "cpu" )
5045+ def test_image_correctness (self , output_size , make_input , fn ):
5046+ image = make_input (self .INPUT_SIZE , dtype = torch .uint8 , device = "cpu" )
50385047
50395048 actual = fn (image , output_size )
5040- expected = F .to_image (F .center_crop (F .to_pil_image (image ), output_size = output_size ))
50415049
5042- assert_equal (actual , expected )
5043-
5044- @pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "test requires CVCUDA" )
5045- @pytest .mark .parametrize ("output_size" , OUTPUT_SIZES )
5046- @pytest .mark .parametrize ("fn" , [F .center_crop , transform_cls_to_functional (transforms .CenterCrop )])
5047- def test_cvcuda_correctness (self , output_size , fn ):
5048- image = make_image_cvcuda (self .INPUT_SIZE , dtype = torch .uint8 , device = "cuda" )
5050+ if make_input == make_image_cvcuda :
5051+ actual = F .cvcuda_to_tensor (actual ).to (device = "cpu" )
5052+ actual = actual .squeeze (0 )
5053+ image = F .cvcuda_to_tensor (image ).to (device = "cpu" )
5054+ image = image .squeeze (0 )
50495055
5050- actual = fn (image , output_size )
5051- expected = F .center_crop (F .cvcuda_to_tensor (image ), output_size )
5056+ expected = F .to_image (F .center_crop (F .to_pil_image (image ), output_size = output_size ))
50525057
5053- assert_equal (F . cvcuda_to_tensor ( actual ) , expected )
5058+ assert_equal (actual , expected )
50545059
50555060 def _reference_center_crop_bounding_boxes (self , bounding_boxes , output_size ):
50565061 image_height , image_width = bounding_boxes .canvas_size
0 commit comments