Skip to content

Commit 3c9a51b

Browse files
Merge pull request #672 from Inquisitive-ME:fix-efficientnet_v2-channels-first
PiperOrigin-RevId: 592265098
2 parents 5359b3c + 6840eb4 commit 3c9a51b

File tree

4 files changed

+121
-51
lines changed

tf_keras/applications/applications_test.py

Lines changed: 101 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,18 @@
119119

120120
MODEL_LIST = MODEL_LIST_NO_NASNET + NASNET_LIST
121121

122+
MODELS_UNSUPPORTED_CHANNELS_FIRST = ["ConvNeXt", "NASNet", "RegNetX", "RegNetY"]
123+
# Pair every model in MODEL_LIST with each image data format ("channels_first" and "channels_last").
124+
test_parameters_with_image_data_format = [
125+
(
126+
"{}_{}".format(model[0].__name__, image_data_format),
127+
*model,
128+
image_data_format,
129+
)
130+
for image_data_format in ["channels_first", "channels_last"]
131+
for model in MODEL_LIST
132+
]
133+
122134
# Parameters for loading weights for MobileNetV3.
123135
# (class, alpha, minimalistic, include_top)
124136
MOBILENET_V3_FOR_WEIGHTS = [
@@ -138,7 +150,14 @@
138150

139151

140152
class ApplicationsTest(tf.test.TestCase, parameterized.TestCase):
141-
def assertShapeEqual(self, shape1, shape2):
153+
def setUp(self):
154+
self.original_image_data_format = backend.image_data_format()
155+
156+
def tearDown(self):
157+
backend.set_image_data_format(self.original_image_data_format)
158+
159+
@classmethod
160+
def assertShapeEqual(cls, shape1, shape2):
142161
if len(shape1) != len(shape2):
143162
raise AssertionError(
144163
f"Shapes are different rank: {shape1} vs {shape2}"
@@ -147,8 +166,27 @@ def assertShapeEqual(self, shape1, shape2):
147166
if v1 != v2:
148167
raise AssertionError(f"Shapes differ: {shape1} vs {shape2}")
149168

150-
@parameterized.parameters(*MODEL_LIST)
151-
def test_application_base(self, app, _):
169+
def skip_if_invalid_image_data_format_for_model(
170+
self, app, image_data_format
171+
):
172+
does_not_support_channels_first = any(
173+
[
174+
unsupported_name.lower() in app.__name__.lower()
175+
for unsupported_name in MODELS_UNSUPPORTED_CHANNELS_FIRST
176+
]
177+
)
178+
if (
179+
image_data_format == "channels_first"
180+
and does_not_support_channels_first
181+
):
182+
self.skipTest(
183+
"{} does not support channels first".format(app.__name__)
184+
)
185+
186+
@parameterized.named_parameters(test_parameters_with_image_data_format)
187+
def test_application_base(self, app, _, image_data_format):
188+
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
189+
backend.set_image_data_format(image_data_format)
152190
# Can be instantiated with default arguments
153191
model = app(weights=None)
154192
# Can be serialized and deserialized
@@ -162,36 +200,55 @@ def test_application_base(self, app, _):
162200
self.assertEqual(len(model.weights), len(reconstructed_model.weights))
163201
backend.clear_session()
164202

165-
@parameterized.parameters(*MODEL_LIST)
166-
def test_application_notop(self, app, last_dim):
203+
@parameterized.named_parameters(test_parameters_with_image_data_format)
204+
def test_application_notop(self, app, last_dim, image_data_format):
205+
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
206+
backend.set_image_data_format(image_data_format)
207+
if image_data_format == "channels_first":
208+
input_shape = (3, None, None)
209+
correct_output_shape = (None, last_dim, None, None)
210+
channels_axis = 1
211+
else:
212+
input_shape = (None, None, 3)
213+
correct_output_shape = (None, None, None, last_dim)
214+
channels_axis = -1
215+
167216
if "NASNet" in app.__name__:
168217
only_check_last_dim = True
169218
else:
170219
only_check_last_dim = False
171-
output_shape = _get_output_shape(
172-
lambda: app(weights=None, include_top=False)
173-
)
220+
output_shape = app(
221+
weights=None, include_top=False, input_shape=input_shape
222+
).output_shape
174223
if only_check_last_dim:
175-
self.assertEqual(output_shape[-1], last_dim)
224+
self.assertEqual(output_shape[channels_axis], last_dim)
176225
else:
177-
self.assertShapeEqual(output_shape, (None, None, None, last_dim))
226+
self.assertShapeEqual(output_shape, correct_output_shape)
178227
backend.clear_session()
179228

180-
@parameterized.parameters(*MODEL_LIST)
181-
def test_application_notop_custom_input_shape(self, app, last_dim):
182-
output_shape = _get_output_shape(
183-
lambda: app(
184-
weights="imagenet", include_top=False, input_shape=(224, 224, 3)
185-
)
186-
)
229+
@parameterized.named_parameters(test_parameters_with_image_data_format)
230+
def test_application_notop_custom_input_shape(
231+
self, app, last_dim, image_data_format
232+
):
233+
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
234+
backend.set_image_data_format(image_data_format)
235+
if image_data_format == "channels_first":
236+
input_shape = (3, 224, 224)
237+
channels_axis = 1
238+
else:
239+
input_shape = (224, 224, 3)
240+
channels_axis = -1
241+
output_shape = app(
242+
weights="imagenet", include_top=False, input_shape=input_shape
243+
).output_shape
187244

188-
self.assertEqual(output_shape[-1], last_dim)
245+
self.assertEqual(output_shape[channels_axis], last_dim)
189246

190247
@parameterized.parameters(MODEL_LIST)
191248
def test_application_pooling(self, app, last_dim):
192-
output_shape = _get_output_shape(
193-
lambda: app(weights=None, include_top=False, pooling="avg")
194-
)
249+
output_shape = app(
250+
weights=None, include_top=False, pooling="avg"
251+
).output_shape
195252
self.assertShapeEqual(output_shape, (None, last_dim))
196253

197254
@parameterized.parameters(MODEL_LIST)
@@ -204,30 +261,34 @@ def test_application_classifier_activation(self, app, _):
204261
last_layer_act = model.layers[-1].activation.__name__
205262
self.assertEqual(last_layer_act, "softmax")
206263

207-
@parameterized.parameters(*MODEL_LIST_NO_NASNET)
208-
def test_application_variable_input_channels(self, app, last_dim):
264+
@parameterized.named_parameters(test_parameters_with_image_data_format)
265+
def test_application_variable_input_channels(
266+
self, app, last_dim, image_data_format
267+
):
268+
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
269+
backend.set_image_data_format(image_data_format)
209270
if backend.image_data_format() == "channels_first":
210271
input_shape = (1, None, None)
272+
correct_output_shape = (None, last_dim, None, None)
211273
else:
212274
input_shape = (None, None, 1)
213-
output_shape = _get_output_shape(
214-
lambda: app(
215-
weights=None, include_top=False, input_shape=input_shape
216-
)
217-
)
218-
self.assertShapeEqual(output_shape, (None, None, None, last_dim))
275+
correct_output_shape = (None, None, None, last_dim)
276+
output_shape = app(
277+
weights=None, include_top=False, input_shape=input_shape
278+
).output_shape
279+
280+
self.assertShapeEqual(output_shape, correct_output_shape)
219281
backend.clear_session()
220282

221283
if backend.image_data_format() == "channels_first":
222284
input_shape = (4, None, None)
223285
else:
224286
input_shape = (None, None, 4)
225-
output_shape = _get_output_shape(
226-
lambda: app(
227-
weights=None, include_top=False, input_shape=input_shape
228-
)
229-
)
230-
self.assertShapeEqual(output_shape, (None, None, None, last_dim))
287+
output_shape = app(
288+
weights=None, include_top=False, input_shape=input_shape
289+
).output_shape
290+
291+
self.assertShapeEqual(output_shape, correct_output_shape)
231292
backend.clear_session()
232293

233294
@parameterized.parameters(*MOBILENET_V3_FOR_WEIGHTS)
@@ -242,9 +303,12 @@ def test_mobilenet_v3_load_weights(
242303
include_top=include_top,
243304
)
244305

245-
@parameterized.parameters(MODEL_LIST)
306+
@parameterized.named_parameters(test_parameters_with_image_data_format)
246307
@test_utils.run_v2_only
247-
def test_model_checkpoint(self, app, _):
308+
def test_model_checkpoint(self, app, _, image_data_format):
309+
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
310+
backend.set_image_data_format(image_data_format)
311+
248312
model = app(weights=None)
249313

250314
checkpoint = tf.train.Checkpoint(model=model)
@@ -256,10 +320,5 @@ def test_model_checkpoint(self, app, _):
256320
checkpoint_manager.save(checkpoint_number=1)
257321

258322

259-
def _get_output_shape(model_fn):
260-
model = model_fn()
261-
return model.output_shape
262-
263-
264323
if __name__ == "__main__":
265324
tf.test.main()

tf_keras/applications/efficientnet.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,9 +364,17 @@ def round_repeats(repeats):
364364
# original implementation.
365365
# See https://github.com/tensorflow/tensorflow/issues/49930 for more
366366
# details
367-
x = layers.Rescaling(
368-
[1.0 / math.sqrt(stddev) for stddev in IMAGENET_STDDEV_RGB]
369-
)(x)
367+
if backend.image_data_format() == "channels_first":
368+
shape_for_multiply = [1, 3, 1, 1]
369+
else:
370+
shape_for_multiply = [1, 1, 1, 3]
371+
x = tf.math.multiply(
372+
x,
373+
tf.reshape(
374+
[1.0 / math.sqrt(stddev) for stddev in IMAGENET_STDDEV_RGB],
375+
shape_for_multiply,
376+
),
377+
)
370378

371379
x = layers.ZeroPadding2D(
372380
padding=imagenet_utils.correct_pad(x, 3), name="stem_conv_pad"

tf_keras/applications/efficientnet_v2.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ def apply(inputs):
656656
strides=1,
657657
kernel_initializer=CONV_KERNEL_INITIALIZER,
658658
padding="same",
659-
data_format="channels_last",
659+
data_format=backend.image_data_format(),
660660
use_bias=False,
661661
name=name + "expand_conv",
662662
)(inputs)
@@ -677,7 +677,7 @@ def apply(inputs):
677677
strides=strides,
678678
depthwise_initializer=CONV_KERNEL_INITIALIZER,
679679
padding="same",
680-
data_format="channels_last",
680+
data_format=backend.image_data_format(),
681681
use_bias=False,
682682
name=name + "dwconv2",
683683
)(x)
@@ -722,7 +722,7 @@ def apply(inputs):
722722
strides=1,
723723
kernel_initializer=CONV_KERNEL_INITIALIZER,
724724
padding="same",
725-
data_format="channels_last",
725+
data_format=backend.image_data_format(),
726726
use_bias=False,
727727
name=name + "project_conv",
728728
)(x)
@@ -771,7 +771,7 @@ def apply(inputs):
771771
kernel_size=kernel_size,
772772
strides=strides,
773773
kernel_initializer=CONV_KERNEL_INITIALIZER,
774-
data_format="channels_last",
774+
data_format=backend.image_data_format(),
775775
padding="same",
776776
use_bias=False,
777777
name=name + "expand_conv",
@@ -1052,7 +1052,7 @@ def EfficientNetV2(
10521052
strides=1,
10531053
kernel_initializer=CONV_KERNEL_INITIALIZER,
10541054
padding="same",
1055-
data_format="channels_last",
1055+
data_format=backend.image_data_format(),
10561056
use_bias=False,
10571057
name="top_conv",
10581058
)(x)

tf_keras/applications/mobilenet_v3.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,10 @@ def MobileNetV3(
269269
input_shape = (cols, rows, 3)
270270
# If both input_shape and input_tensor are None, use the standard shape.
271271
if input_shape is None and input_tensor is None:
272-
input_shape = (None, None, 3)
272+
if backend.image_data_format() == "channels_last":
273+
input_shape = (None, None, 3)
274+
else:
275+
input_shape = (3, None, None)
273276

274277
if backend.image_data_format() == "channels_last":
275278
row_axis, col_axis = (0, 1)

0 commit comments

Comments
 (0)