@@ -174,13 +174,29 @@ def test_inference(self):
174174 inputs = self .get_dummy_inputs (device )
175175 image = pipe (** inputs ).images
176176 image_slice = image [0 , - 3 :, - 3 :, - 1 ]
177- print (torch .from_numpy (image_slice .flatten ()))
178177
179178 self .assertEqual (image .shape , (1 , 8 , 8 , 3 ))
180179 expected_slice = np .array ([0.5303 , 0.2658 , 0.7979 , 0.1182 , 0.3304 , 0.4608 , 0.5195 , 0.4261 , 0.4675 ])
181180 max_diff = np .abs (image_slice .flatten () - expected_slice ).max ()
182181 self .assertLessEqual (max_diff , 1e-3 )
183182
183+ def test_inference_non_square_images (self ):
184+ device = "cpu"
185+
186+ components = self .get_dummy_components ()
187+ pipe = self .pipeline_class (** components )
188+ pipe .to (device )
189+ pipe .set_progress_bar_config (disable = None )
190+
191+ inputs = self .get_dummy_inputs (device )
192+ image = pipe (** inputs , height = 32 , width = 48 ).images
193+ image_slice = image [0 , - 3 :, - 3 :, - 1 ]
194+
195+ self .assertEqual (image .shape , (1 , 32 , 48 , 3 ))
196+ expected_slice = np .array ([0.3859 , 0.2987 , 0.2333 , 0.5243 , 0.6721 , 0.4436 , 0.5292 , 0.5373 , 0.4416 ])
197+ max_diff = np .abs (image_slice .flatten () - expected_slice ).max ()
198+ self .assertLessEqual (max_diff , 1e-3 )
199+
184200 def test_inference_with_embeddings_and_multiple_images (self ):
185201 components = self .get_dummy_components ()
186202 pipe = self .pipeline_class (** components )
0 commit comments