Skip to content

Commit cf9f492

Browse files
Fix code formatting
1 parent faa997d commit cf9f492

File tree

2 files changed

+46
-18
lines changed

2 files changed

+46
-18
lines changed

tf_keras/applications/applications_test.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,11 @@
122122
MODELS_UNSUPPORTED_CHANNELS_FIRST = ["ConvNeXt", "NASNet", "RegNetX", "RegNetY"]
123123
# Add each data format for each model
124124
test_parameters_with_image_data_format = [
125-
('{}_{}'.format(model[0].__name__, image_data_format), *model, image_data_format)
125+
(
126+
"{}_{}".format(model[0].__name__, image_data_format),
127+
*model,
128+
image_data_format,
129+
)
126130
for image_data_format in ["channels_first", "channels_last"]
127131
for model in MODEL_LIST
128132
]
@@ -164,11 +168,19 @@ def assertShapeEqual(cls, shape1, shape2):
164168
if v1 != v2:
165169
raise AssertionError(f"Shapes differ: {shape1} vs {shape2}")
166170

167-
def skip_if_invalid_image_data_format_for_model(self, app, image_data_format):
171+
def skip_if_invalid_image_data_format_for_model(
172+
self, app, image_data_format
173+
):
168174
does_not_support_channels_first = any(
169-
[unsupported_name.lower() in app.__name__.lower() for unsupported_name in
170-
MODELS_UNSUPPORTED_CHANNELS_FIRST])
171-
if image_data_format == "channels_first" and does_not_support_channels_first:
175+
[
176+
unsupported_name.lower() in app.__name__.lower()
177+
for unsupported_name in MODELS_UNSUPPORTED_CHANNELS_FIRST
178+
]
179+
)
180+
if (
181+
image_data_format == "channels_first"
182+
and does_not_support_channels_first
183+
):
172184
self.skipTest(
173185
"{} does not support channels first".format(app.__name__)
174186
)
@@ -207,15 +219,19 @@ def test_application_notop(self, app, last_dim, image_data_format):
207219
only_check_last_dim = True
208220
else:
209221
only_check_last_dim = False
210-
output_shape = app(weights=None, include_top=False, input_shape=input_shape).output_shape
222+
output_shape = app(
223+
weights=None, include_top=False, input_shape=input_shape
224+
).output_shape
211225
if only_check_last_dim:
212226
self.assertEqual(output_shape[channels_axis], last_dim)
213227
else:
214228
self.assertShapeEqual(output_shape, correct_output_shape)
215229
backend.clear_session()
216230

217231
@parameterized.named_parameters(test_parameters_with_image_data_format)
218-
def test_application_notop_custom_input_shape(self, app, last_dim, image_data_format):
232+
def test_application_notop_custom_input_shape(
233+
self, app, last_dim, image_data_format
234+
):
219235
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
220236
backend.set_image_data_format(image_data_format)
221237
if image_data_format == "channels_first":
@@ -224,13 +240,17 @@ def test_application_notop_custom_input_shape(self, app, last_dim, image_data_fo
224240
else:
225241
input_shape = (224, 224, 3)
226242
channels_axis = -1
227-
output_shape = app(weights="imagenet", include_top=False, input_shape=input_shape).output_shape
243+
output_shape = app(
244+
weights="imagenet", include_top=False, input_shape=input_shape
245+
).output_shape
228246

229247
self.assertEqual(output_shape[channels_axis], last_dim)
230248

231249
@parameterized.parameters(MODEL_LIST)
232250
def test_application_pooling(self, app, last_dim):
233-
output_shape = app(weights=None, include_top=False, pooling="avg").output_shape
251+
output_shape = app(
252+
weights=None, include_top=False, pooling="avg"
253+
).output_shape
234254
self.assertShapeEqual(output_shape, (None, last_dim))
235255

236256
@parameterized.parameters(MODEL_LIST)
@@ -244,7 +264,9 @@ def test_application_classifier_activation(self, app, _):
244264
self.assertEqual(last_layer_act, "softmax")
245265

246266
@parameterized.named_parameters(test_parameters_with_image_data_format)
247-
def test_application_variable_input_channels(self, app, last_dim, image_data_format):
267+
def test_application_variable_input_channels(
268+
self, app, last_dim, image_data_format
269+
):
248270
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
249271
backend.set_image_data_format(image_data_format)
250272
if backend.image_data_format() == "channels_first":
@@ -253,7 +275,9 @@ def test_application_variable_input_channels(self, app, last_dim, image_data_for
253275
else:
254276
input_shape = (None, None, 1)
255277
correct_output_shape = (None, None, None, last_dim)
256-
output_shape = app(weights=None, include_top=False, input_shape=input_shape).output_shape
278+
output_shape = app(
279+
weights=None, include_top=False, input_shape=input_shape
280+
).output_shape
257281

258282
self.assertShapeEqual(output_shape, correct_output_shape)
259283
backend.clear_session()
@@ -262,7 +286,9 @@ def test_application_variable_input_channels(self, app, last_dim, image_data_for
262286
input_shape = (4, None, None)
263287
else:
264288
input_shape = (None, None, 4)
265-
output_shape = app(weights=None, include_top=False, input_shape=input_shape).output_shape
289+
output_shape = app(
290+
weights=None, include_top=False, input_shape=input_shape
291+
).output_shape
266292

267293
self.assertShapeEqual(output_shape, correct_output_shape)
268294
backend.clear_session()

tf_keras/applications/efficientnet.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -364,15 +364,17 @@ def round_repeats(repeats):
364364
# original implementation.
365365
# See https://github.com/tensorflow/tensorflow/issues/49930 for more
366366
# details
367-
if backend.image_data_format() == 'channels_first':
367+
if backend.image_data_format() == "channels_first":
368368
shape_for_multiply = [1, 3, 1, 1]
369369
else:
370370
shape_for_multiply = [1, 1, 1, 3]
371-
x = tf.math.multiply(x,
372-
tf.reshape(
373-
[1.0 / math.sqrt(stddev) for stddev in IMAGENET_STDDEV_RGB],
374-
shape_for_multiply
375-
))
371+
x = tf.math.multiply(
372+
x,
373+
tf.reshape(
374+
[1.0 / math.sqrt(stddev) for stddev in IMAGENET_STDDEV_RGB],
375+
shape_for_multiply,
376+
),
377+
)
376378

377379
x = layers.ZeroPadding2D(
378380
padding=imagenet_utils.correct_pad(x, 3), name="stem_conv_pad"

0 commit comments

Comments
 (0)