122122MODELS_UNSUPPORTED_CHANNELS_FIRST = ["ConvNeXt" , "NASNet" , "RegNetX" , "RegNetY" ]
123123# Add each data format for each model
124124test_parameters_with_image_data_format = [
125- ('{}_{}' .format (model [0 ].__name__ , image_data_format ), * model , image_data_format )
125+ (
126+ "{}_{}" .format (model [0 ].__name__ , image_data_format ),
127+ * model ,
128+ image_data_format ,
129+ )
126130 for image_data_format in ["channels_first" , "channels_last" ]
127131 for model in MODEL_LIST
128132]
@@ -164,11 +168,19 @@ def assertShapeEqual(cls, shape1, shape2):
164168 if v1 != v2 :
165169 raise AssertionError (f"Shapes differ: { shape1 } vs { shape2 } " )
166170
167- def skip_if_invalid_image_data_format_for_model (self , app , image_data_format ):
171+ def skip_if_invalid_image_data_format_for_model (
172+ self , app , image_data_format
173+ ):
168174 does_not_support_channels_first = any (
169- [unsupported_name .lower () in app .__name__ .lower () for unsupported_name in
170- MODELS_UNSUPPORTED_CHANNELS_FIRST ])
171- if image_data_format == "channels_first" and does_not_support_channels_first :
175+ [
176+ unsupported_name .lower () in app .__name__ .lower ()
177+ for unsupported_name in MODELS_UNSUPPORTED_CHANNELS_FIRST
178+ ]
179+ )
180+ if (
181+ image_data_format == "channels_first"
182+ and does_not_support_channels_first
183+ ):
172184 self .skipTest (
173185 "{} does not support channels first" .format (app .__name__ )
174186 )
@@ -207,15 +219,19 @@ def test_application_notop(self, app, last_dim, image_data_format):
207219 only_check_last_dim = True
208220 else :
209221 only_check_last_dim = False
210- output_shape = app (weights = None , include_top = False , input_shape = input_shape ).output_shape
222+ output_shape = app (
223+ weights = None , include_top = False , input_shape = input_shape
224+ ).output_shape
211225 if only_check_last_dim :
212226 self .assertEqual (output_shape [channels_axis ], last_dim )
213227 else :
214228 self .assertShapeEqual (output_shape , correct_output_shape )
215229 backend .clear_session ()
216230
217231 @parameterized .named_parameters (test_parameters_with_image_data_format )
218- def test_application_notop_custom_input_shape (self , app , last_dim , image_data_format ):
232+ def test_application_notop_custom_input_shape (
233+ self , app , last_dim , image_data_format
234+ ):
219235 self .skip_if_invalid_image_data_format_for_model (app , image_data_format )
220236 backend .set_image_data_format (image_data_format )
221237 if image_data_format == "channels_first" :
@@ -224,13 +240,17 @@ def test_application_notop_custom_input_shape(self, app, last_dim, image_data_fo
224240 else :
225241 input_shape = (224 , 224 , 3 )
226242 channels_axis = - 1
227- output_shape = app (weights = "imagenet" , include_top = False , input_shape = input_shape ).output_shape
243+ output_shape = app (
244+ weights = "imagenet" , include_top = False , input_shape = input_shape
245+ ).output_shape
228246
229247 self .assertEqual (output_shape [channels_axis ], last_dim )
230248
231249 @parameterized .parameters (MODEL_LIST )
232250 def test_application_pooling (self , app , last_dim ):
233- output_shape = app (weights = None , include_top = False , pooling = "avg" ).output_shape
251+ output_shape = app (
252+ weights = None , include_top = False , pooling = "avg"
253+ ).output_shape
234254 self .assertShapeEqual (output_shape , (None , last_dim ))
235255
236256 @parameterized .parameters (MODEL_LIST )
@@ -244,7 +264,9 @@ def test_application_classifier_activation(self, app, _):
244264 self .assertEqual (last_layer_act , "softmax" )
245265
246266 @parameterized .named_parameters (test_parameters_with_image_data_format )
247- def test_application_variable_input_channels (self , app , last_dim , image_data_format ):
267+ def test_application_variable_input_channels (
268+ self , app , last_dim , image_data_format
269+ ):
248270 self .skip_if_invalid_image_data_format_for_model (app , image_data_format )
249271 backend .set_image_data_format (image_data_format )
250272 if backend .image_data_format () == "channels_first" :
@@ -253,7 +275,9 @@ def test_application_variable_input_channels(self, app, last_dim, image_data_for
253275 else :
254276 input_shape = (None , None , 1 )
255277 correct_output_shape = (None , None , None , last_dim )
256- output_shape = app (weights = None , include_top = False , input_shape = input_shape ).output_shape
278+ output_shape = app (
279+ weights = None , include_top = False , input_shape = input_shape
280+ ).output_shape
257281
258282 self .assertShapeEqual (output_shape , correct_output_shape )
259283 backend .clear_session ()
@@ -262,7 +286,9 @@ def test_application_variable_input_channels(self, app, last_dim, image_data_for
262286 input_shape = (4 , None , None )
263287 else :
264288 input_shape = (None , None , 4 )
265- output_shape = app (weights = None , include_top = False , input_shape = input_shape ).output_shape
289+ output_shape = app (
290+ weights = None , include_top = False , input_shape = input_shape
291+ ).output_shape
266292
267293 self .assertShapeEqual (output_shape , correct_output_shape )
268294 backend .clear_session ()
0 commit comments