119119
120120MODEL_LIST = MODEL_LIST_NO_NASNET + NASNET_LIST
121121
122+ MODELS_UNSUPPORTED_CHANNELS_FIRST = ["ConvNeXt" , "NASNet" , "RegNetX" , "RegNetY" ]
123+ # Add each data format for each model
124+ test_parameters_with_image_data_format = [
125+ (
126+ "{}_{}" .format (model [0 ].__name__ , image_data_format ),
127+ * model ,
128+ image_data_format ,
129+ )
130+ for image_data_format in ["channels_first" , "channels_last" ]
131+ for model in MODEL_LIST
132+ ]
133+
122134# Parameters for loading weights for MobileNetV3.
123135# (class, alpha, minimalistic, include_top)
124136MOBILENET_V3_FOR_WEIGHTS = [
138150
139151
140152class ApplicationsTest (tf .test .TestCase , parameterized .TestCase ):
141- def assertShapeEqual (self , shape1 , shape2 ):
153+ def setUp (self ):
154+ self .original_image_data_format = backend .image_data_format ()
155+
156+ def tearDown (self ):
157+ backend .set_image_data_format (self .original_image_data_format )
158+
159+ @classmethod
160+ def assertShapeEqual (cls , shape1 , shape2 ):
142161 if len (shape1 ) != len (shape2 ):
143162 raise AssertionError (
144163 f"Shapes are different rank: { shape1 } vs { shape2 } "
@@ -147,8 +166,27 @@ def assertShapeEqual(self, shape1, shape2):
147166 if v1 != v2 :
148167 raise AssertionError (f"Shapes differ: { shape1 } vs { shape2 } " )
149168
150- @parameterized .parameters (* MODEL_LIST )
151- def test_application_base (self , app , _ ):
169+ def skip_if_invalid_image_data_format_for_model (
170+ self , app , image_data_format
171+ ):
172+ does_not_support_channels_first = any (
173+ [
174+ unsupported_name .lower () in app .__name__ .lower ()
175+ for unsupported_name in MODELS_UNSUPPORTED_CHANNELS_FIRST
176+ ]
177+ )
178+ if (
179+ image_data_format == "channels_first"
180+ and does_not_support_channels_first
181+ ):
182+ self .skipTest (
183+ "{} does not support channels first" .format (app .__name__ )
184+ )
185+
186+ @parameterized .named_parameters (test_parameters_with_image_data_format )
187+ def test_application_base (self , app , _ , image_data_format ):
188+ self .skip_if_invalid_image_data_format_for_model (app , image_data_format )
189+ backend .set_image_data_format (image_data_format )
152190 # Can be instantiated with default arguments
153191 model = app (weights = None )
154192 # Can be serialized and deserialized
@@ -162,36 +200,55 @@ def test_application_base(self, app, _):
162200 self .assertEqual (len (model .weights ), len (reconstructed_model .weights ))
163201 backend .clear_session ()
164202
165- @parameterized .parameters (* MODEL_LIST )
166- def test_application_notop (self , app , last_dim ):
203+ @parameterized .named_parameters (test_parameters_with_image_data_format )
204+ def test_application_notop (self , app , last_dim , image_data_format ):
205+ self .skip_if_invalid_image_data_format_for_model (app , image_data_format )
206+ backend .set_image_data_format (image_data_format )
207+ if image_data_format == "channels_first" :
208+ input_shape = (3 , None , None )
209+ correct_output_shape = (None , last_dim , None , None )
210+ channels_axis = 1
211+ else :
212+ input_shape = (None , None , 3 )
213+ correct_output_shape = (None , None , None , last_dim )
214+ channels_axis = - 1
215+
167216 if "NASNet" in app .__name__ :
168217 only_check_last_dim = True
169218 else :
170219 only_check_last_dim = False
171- output_shape = _get_output_shape (
172- lambda : app ( weights = None , include_top = False )
173- )
220+ output_shape = app (
221+ weights = None , include_top = False , input_shape = input_shape
222+ ). output_shape
174223 if only_check_last_dim :
175- self .assertEqual (output_shape [- 1 ], last_dim )
224+ self .assertEqual (output_shape [channels_axis ], last_dim )
176225 else :
177- self .assertShapeEqual (output_shape , ( None , None , None , last_dim ) )
226+ self .assertShapeEqual (output_shape , correct_output_shape )
178227 backend .clear_session ()
179228
180- @parameterized .parameters (* MODEL_LIST )
181- def test_application_notop_custom_input_shape (self , app , last_dim ):
182- output_shape = _get_output_shape (
183- lambda : app (
184- weights = "imagenet" , include_top = False , input_shape = (224 , 224 , 3 )
185- )
186- )
229+ @parameterized .named_parameters (test_parameters_with_image_data_format )
230+ def test_application_notop_custom_input_shape (
231+ self , app , last_dim , image_data_format
232+ ):
233+ self .skip_if_invalid_image_data_format_for_model (app , image_data_format )
234+ backend .set_image_data_format (image_data_format )
235+ if image_data_format == "channels_first" :
236+ input_shape = (3 , 224 , 224 )
237+ channels_axis = 1
238+ else :
239+ input_shape = (224 , 224 , 3 )
240+ channels_axis = - 1
241+ output_shape = app (
242+ weights = "imagenet" , include_top = False , input_shape = input_shape
243+ ).output_shape
187244
188- self .assertEqual (output_shape [- 1 ], last_dim )
245+ self .assertEqual (output_shape [channels_axis ], last_dim )
189246
190247 @parameterized .parameters (MODEL_LIST )
191248 def test_application_pooling (self , app , last_dim ):
192- output_shape = _get_output_shape (
193- lambda : app ( weights = None , include_top = False , pooling = "avg" )
194- )
249+ output_shape = app (
250+ weights = None , include_top = False , pooling = "avg"
251+ ). output_shape
195252 self .assertShapeEqual (output_shape , (None , last_dim ))
196253
197254 @parameterized .parameters (MODEL_LIST )
@@ -204,30 +261,34 @@ def test_application_classifier_activation(self, app, _):
204261 last_layer_act = model .layers [- 1 ].activation .__name__
205262 self .assertEqual (last_layer_act , "softmax" )
206263
207- @parameterized .parameters (* MODEL_LIST_NO_NASNET )
208- def test_application_variable_input_channels (self , app , last_dim ):
264+ @parameterized .named_parameters (test_parameters_with_image_data_format )
265+ def test_application_variable_input_channels (
266+ self , app , last_dim , image_data_format
267+ ):
268+ self .skip_if_invalid_image_data_format_for_model (app , image_data_format )
269+ backend .set_image_data_format (image_data_format )
209270 if backend .image_data_format () == "channels_first" :
210271 input_shape = (1 , None , None )
272+ correct_output_shape = (None , last_dim , None , None )
211273 else :
212274 input_shape = (None , None , 1 )
213- output_shape = _get_output_shape (
214- lambda : app (
215- weights = None , include_top = False , input_shape = input_shape
216- )
217- )
218- self .assertShapeEqual (output_shape , ( None , None , None , last_dim ) )
275+ correct_output_shape = ( None , None , None , last_dim )
276+ output_shape = app (
277+ weights = None , include_top = False , input_shape = input_shape
278+ ). output_shape
279+
280+ self .assertShapeEqual (output_shape , correct_output_shape )
219281 backend .clear_session ()
220282
221283 if backend .image_data_format () == "channels_first" :
222284 input_shape = (4 , None , None )
223285 else :
224286 input_shape = (None , None , 4 )
225- output_shape = _get_output_shape (
226- lambda : app (
227- weights = None , include_top = False , input_shape = input_shape
228- )
229- )
230- self .assertShapeEqual (output_shape , (None , None , None , last_dim ))
287+ output_shape = app (
288+ weights = None , include_top = False , input_shape = input_shape
289+ ).output_shape
290+
291+ self .assertShapeEqual (output_shape , correct_output_shape )
231292 backend .clear_session ()
232293
233294 @parameterized .parameters (* MOBILENET_V3_FOR_WEIGHTS )
@@ -242,9 +303,12 @@ def test_mobilenet_v3_load_weights(
242303 include_top = include_top ,
243304 )
244305
245- @parameterized .parameters ( MODEL_LIST )
306+ @parameterized .named_parameters ( test_parameters_with_image_data_format )
246307 @test_utils .run_v2_only
247- def test_model_checkpoint (self , app , _ ):
308+ def test_model_checkpoint (self , app , _ , image_data_format ):
309+ self .skip_if_invalid_image_data_format_for_model (app , image_data_format )
310+ backend .set_image_data_format (image_data_format )
311+
248312 model = app (weights = None )
249313
250314 checkpoint = tf .train .Checkpoint (model = model )
@@ -256,10 +320,5 @@ def test_model_checkpoint(self, app, _):
256320 checkpoint_manager .save (checkpoint_number = 1 )
257321
258322
259- def _get_output_shape (model_fn ):
260- model = model_fn ()
261- return model .output_shape
262-
263-
264323if __name__ == "__main__" :
265324 tf .test .main ()
0 commit comments