Commit aec3c96

Revert "Support fused batchnorm with any ndims and axis" (#695)
This reverts commit 65dabe1.
1 parent 08c81ad commit aec3c96
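
For context, a rough sketch of what this revert means for callers of the v2 `tf.keras.layers.BatchNormalization` layer (illustrative only, not part of the commit): with the default `fused=None`, only 4D inputs keep the fused kernel, and an explicit `fused=True` is rejected in `build()` for any other rank.

import tensorflow as tf

# 4D NHWC input: the fused kernel is used (channel axis 3 maps to 'NHWC').
bn = tf.keras.layers.BatchNormalization(axis=3, fused=True)
y = bn(tf.zeros([2, 8, 8, 16]))

# Non-4D input: with the default fused=None the layer silently falls back to
# the unfused implementation; with an explicit fused=True, build() raises.
bn3d = tf.keras.layers.BatchNormalization(fused=True)
# bn3d(tf.zeros([2, 10, 16]))  # ValueError: fused=True only supports 4D inputs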

3 files changed: +62 −234 lines


tensorflow/python/keras/BUILD

Lines changed: 2 additions & 2 deletions
@@ -943,15 +943,15 @@ tf_py_test(
 
 tf_py_test(
     name = "normalization_test",
-    size = "large",
+    size = "medium",
     srcs = ["layers/normalization_test.py"],
     additional_deps = [
         ":keras",
         "@absl_py//absl/testing:parameterized",
         "//third_party/py/numpy",
         "//tensorflow/python:client_testlib",
     ],
-    shard_count = 8,
+    shard_count = 4,
     tags = [
         "no_rocm",
         "notsan",

tensorflow/python/keras/layers/normalization.py

Lines changed: 43 additions & 70 deletions
@@ -19,7 +19,6 @@
 from __future__ import print_function
 
 from tensorflow.python.distribute import distribution_strategy_context
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -187,8 +186,10 @@ def __init__(self,
     if self._USE_V2_BEHAVIOR:
       if fused:
         self._raise_if_fused_cannot_be_used()
-      elif fused is None:
-        fused = self._fused_can_be_used()
+      # We leave fused as None if self._fused_can_be_used()==True, since we
+      # still may set it to False in self.build() if the input rank is not 4.
+      elif fused is None and not self._fused_can_be_used():
+        fused = False
     elif fused is None:
       fused = True
     self.supports_masking = True
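
Under v2 behavior the constructor now defers the final decision: it only forces `fused = False` when the static checks already rule the fused path out, and otherwise leaves `None` so `build()` can look at the input rank. A condensed sketch of that decision as a hypothetical free function (not code from the layer):

def resolve_fused_in_init(fused, can_be_used, use_v2_behavior):
  """Mirrors the restored __init__ handling of the `fused` flag."""
  if use_v2_behavior:
    if fused:
      return True    # explicit request; static checks may raise before this
    if fused is None and not can_be_used():
      return False   # known-unusable (renorm, unsupported axis, virtual batches)
    return fused     # None is kept so build() can check for a 4D input
  if fused is None:
    return True      # v1 behavior: default to fused, re-checked in build()
  return fused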
@@ -209,16 +210,22 @@ def __init__(self,
 
   def _raise_if_fused_cannot_be_used(self):
     """Raises a ValueError if fused implementation cannot be used.
+
+    In addition to the checks done in this function, the input tensors rank must
+    be 4. The input rank check can only be done once the input shape is known.
     """
     # Currently fused batch norm doesn't support renorm. It also only supports a
-    # single axis, when no virtual batch size or adjustment is used.
+    # channel dimension on axis 1 or 3, when no virtual batch size or adjustment
+    # is used.
     if self.renorm:
       raise ValueError('Passing both fused=True and renorm=True is '
                        'unsupported')
     axis = [self.axis] if isinstance(self.axis, int) else self.axis
-    if len(axis) > 1:
-      raise ValueError('Passing fused=True is only supported when operating '
-                       'over a single axis.')
+    # Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, because the
+    # input rank is required to be 4 (which is checked later).
+    if len(axis) > 1 or axis[0] not in (-3, -1, 1, 3):
+      raise ValueError('Passing fused=True is only supported when axis is 1 '
+                       'or 3')
     if self.virtual_batch_size is not None:
       raise ValueError('Passing fused=True is unsupported when '
                        'virtual_batch_size is specified.')
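
The restored check accepts only a single channel axis that resolves to 1 or 3 on a rank-4 input. A standalone sketch of that rule (hypothetical helper, not part of the layer):

def fused_axis_ok(axis):
  """True if `axis` is compatible with fused batch norm on a rank-4 input."""
  axis = [axis] if isinstance(axis, int) else axis
  # -3 is equivalent to 1 and -1 to 3 once the rank-4 requirement holds.
  return len(axis) == 1 and axis[0] in (-3, -1, 1, 3)

assert fused_axis_ok(3)             # NHWC
assert fused_axis_ok(-3)            # same as axis=1 (NCHW) on a 4D input
assert not fused_axis_ok(2)         # a spatial axis is not supported
assert not fused_axis_ok([1, 2])    # multi-axis norm cannot use the fused op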
@@ -262,62 +269,6 @@ def _support_zero_size_input(self):
         distribution_strategy_context.get_strategy().extended,
         'experimental_enable_get_next_as_optional', False)
 
-  def _get_shape_and_axis_for_fused(self, nd_shape, nd_axis):
-    """Compute an equivalent shape and axis that are compatible with the fused
-    implementation.
-
-    The input/output of the layer can be reshaped to/from the shape returned by
-    this function without affecting the correctness of the computation.
-
-    Arguments:
-      nd_shape: Tensor. The original shape of the operation.
-      nd_axis: Integer. The original axis of the operation.
-
-    Returns:
-      shape: Tensor. A 4D shape.
-      axis: Integer. An axis (always 1 or 3).
-    """
-    assert(isinstance(nd_axis, int))
-    ndims = nd_shape.shape[0]
-    shape = nd_shape[:]
-    axis = nd_shape + nd_axis if nd_axis < 0 else nd_axis
-    # First check if the axis needs to be moved.
-    if axis not in (1, ndims - 1):
-      # Move axis to dim 1.
-      if axis == 0:
-        # Transform [C, ...] to [1, C, ...].
-        shape = array_ops.concat([constant_op.constant([1]), shape], axis=0)
-        ndims += 1
-      else:
-        # Merge excess pre-axis dims into first dim.
-        # Transform [N, ..., C, ...] to [product(N, ...), C, ...].
-        product = math_ops.reduce_prod(shape[:axis], keepdims=True)
-        shape = array_ops.concat([product, shape[axis:]], axis=0)
-        ndims -= (axis - 1)
-      axis = 1
-    # Now change shape to 4D.
-    is_channels_last = axis == ndims - 1
-    if ndims < 4:
-      # Insert new dims after existing spatial dim or before channel dim.
-      new_dims = constant_op.constant([1] * (4 - ndims))
-      if is_channels_last:
-        # Transform [..., C] to [..., 1..., C] (ndims=4).
-        shape = array_ops.concat([shape[:-1], new_dims, shape[-1:]], axis=0)
-      else:
-        # Transform [N, C, ...] to [N, C, ..., 1...] (ndims=4).
-        shape = array_ops.concat([shape, new_dims], axis=0)
-    elif ndims > 4:
-      # Merge excess spatial dims into the second spatial dim.
-      # Transform [N, C, H, W, ...] to [N, C, H, product(W, ...)].
-      # Or [N, H, W, ..., C] to [N, H, product(W, ...), C].
-      merge_dim = 2 if is_channels_last else 3
-      product = math_ops.reduce_prod(
-          shape[merge_dim:merge_dim + 1 + (ndims - 4)], keepdims=True)
-      shape = array_ops.concat([shape[:merge_dim], product,
-                                shape[merge_dim + 1 + (ndims - 4):]], axis=0)
-    axis = 3 if is_channels_last else 1
-    return shape, axis
-
   def build(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape)
     if not input_shape.ndims:
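
The deleted helper above reshaped inputs of any rank to an equivalent 4D shape so the fused kernel could still be used; this revert drops that path. Purely as a reference for the equivalence it relied on, a NumPy-only sketch with hypothetical shapes (not TensorFlow code):

import numpy as np

# Merging trailing spatial dims does not change per-channel statistics, e.g.
# a 5D [N, C, D, H, W] tensor versus its 4D [N, C, D, H*W] reshape.
x = np.random.rand(2, 16, 4, 8, 8)
x4d = x.reshape(2, 16, 4, 64)
np.testing.assert_allclose(x.mean(axis=(0, 2, 3, 4)), x4d.mean(axis=(0, 2, 3)))
np.testing.assert_allclose(x.var(axis=(0, 2, 3, 4)), x4d.var(axis=(0, 2, 3)))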
@@ -352,8 +303,33 @@ def build(self, input_shape):
         raise ValueError('When using virtual_batch_size, adjustment cannot '
                          'be specified')
 
-    if self.fused and not self._USE_V2_BEHAVIOR:
-      self.fused = self._fused_can_be_used()
+    if self.fused in (None, True):
+      # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
+      # output back to its original shape accordingly.
+      if self._USE_V2_BEHAVIOR:
+        if self.fused is None:
+          self.fused = (ndims == 4)
+        elif self.fused and ndims != 4:
+          raise ValueError('Batch normalization layers with fused=True only '
+                           'support 4D input tensors.')
+      else:
+        assert self.fused is not None
+        self.fused = (ndims == 4 and self._fused_can_be_used())
+      # TODO(chrisying): fused batch norm is currently not supported for
+      # multi-axis batch norm and by extension virtual batches. In some cases,
+      # it might be possible to use fused batch norm but would require reshaping
+      # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
+      # particularly tricky. A compromise might be to just support the most
+      # common use case (turning 5D w/ virtual batch to NCHW)
+
+    if self.fused:
+      if self.axis == [1]:
+        self._data_format = 'NCHW'
+      elif self.axis == [3]:
+        self._data_format = 'NHWC'
+      else:
+        raise ValueError('Unsupported axis, fused batch norm only supports '
+                         'axis == [1] or axis == [3]')
 
     axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
     for x in axis_to_dim:
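
With the revert, `build()` again maps `axis == [1]` to `'NCHW'` and `axis == [3]` to `'NHWC'` and passes that string straight to the fused op. Roughly what the layer ends up invoking for an NHWC input, sketched against the public `tf.compat.v1.nn.fused_batch_norm` wrapper (shapes are illustrative assumptions):

import tensorflow as tf

x = tf.random.normal([2, 8, 8, 16])   # [N, H, W, C]; channel axis 3 -> 'NHWC'
scale = tf.ones([16])                 # gamma, one value per channel
offset = tf.zeros([16])               # beta
y, batch_mean, batch_var = tf.compat.v1.nn.fused_batch_norm(
    x, scale, offset, epsilon=1e-3, data_format='NHWC', is_training=True)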
@@ -548,7 +524,7 @@ def _fused_batch_norm_training():
           gamma,
           beta,
           epsilon=self.epsilon,
-          data_format=data_format)
+          data_format=self._data_format)
 
     def _fused_batch_norm_inference():
       return nn.fused_batch_norm(
@@ -559,7 +535,7 @@ def _fused_batch_norm_inference():
           variance=self.moving_variance,
           epsilon=self.epsilon,
           is_training=False,
-          data_format=data_format)
+          data_format=self._data_format)
 
     output, mean, variance = tf_utils.smart_cond(
         training, _fused_batch_norm_training, _fused_batch_norm_inference)
@@ -572,9 +548,6 @@ def _fused_batch_norm_inference():
       factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
       variance *= factor
 
-    if original_shape is not None:
-      output = array_ops.reshape(output, original_shape)
-
     training_value = tf_utils.constant_value(training)
     if training_value is None:
       momentum = tf_utils.smart_cond(training,