 from __future__ import print_function

 from tensorflow.python.distribute import distribution_strategy_context
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -187,8 +186,10 @@ def __init__(self,
     if self._USE_V2_BEHAVIOR:
       if fused:
         self._raise_if_fused_cannot_be_used()
-      elif fused is None:
-        fused = self._fused_can_be_used()
+      # We leave fused as None if self._fused_can_be_used()==True, since we
+      # still may set it to False in self.build() if the input rank is not 4.
+      elif fused is None and not self._fused_can_be_used():
+        fused = False
     elif fused is None:
       fused = True
     self.supports_masking = True
@@ -209,16 +210,22 @@ def __init__(self,

   def _raise_if_fused_cannot_be_used(self):
     """Raises a ValueError if fused implementation cannot be used.
+
+    In addition to the checks done in this function, the input tensor's rank
+    must be 4. The input rank check can only be done once the input shape is known.
     """
     # Currently fused batch norm doesn't support renorm. It also only supports a
-    # single axis, when no virtual batch size or adjustment is used.
+    # channel dimension on axis 1 or 3, when no virtual batch size or adjustment
+    # is used.
     if self.renorm:
       raise ValueError('Passing both fused=True and renorm=True is '
                        'unsupported')
     axis = [self.axis] if isinstance(self.axis, int) else self.axis
-    if len(axis) > 1:
-      raise ValueError('Passing fused=True is only supported when operating '
-                       'over a single axis.')
+    # Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, because the
+    # input rank is required to be 4 (which is checked later).
+    if len(axis) > 1 or axis[0] not in (-3, -1, 1, 3):
+      raise ValueError('Passing fused=True is only supported when axis is 1 '
+                       'or 3.')
     if self.virtual_batch_size is not None:
       raise ValueError('Passing fused=True is unsupported when '
                        'virtual_batch_size is specified.')
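
Note (not part of the patch): the check above accepts a single axis that must land on channel position 1 or 3 of the rank-4 input the fused kernel requires. A minimal illustrative sketch of the negative-axis equivalence, assuming a rank-4 input; the loop and names are hypothetical:

    # -3 wraps to 1 (channels_first / NCHW) and -1 wraps to 3 (channels_last / NHWC).
    for candidate_axis in (-3, -1, 1, 3):
        assert candidate_axis % 4 in (1, 3)
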
@@ -262,62 +269,6 @@ def _support_zero_size_input(self):
         distribution_strategy_context.get_strategy().extended,
         'experimental_enable_get_next_as_optional', False)

-  def _get_shape_and_axis_for_fused(self, nd_shape, nd_axis):
-    """Compute an equivalent shape and axis that are compatible with the fused
-    implementation.
-
-    The input/output of the layer can be reshaped to/from the shape returned by
-    this function without affecting the correctness of the computation.
-
-    Arguments:
-      nd_shape: Tensor. The original shape of the operation.
-      nd_axis: Integer. The original axis of the operation.
-
-    Returns:
-      shape: Tensor. A 4D shape.
-      axis: Integer. An axis (always 1 or 3).
-    """
-    assert isinstance(nd_axis, int)
-    ndims = nd_shape.shape[0]
-    shape = nd_shape[:]
-    axis = ndims + nd_axis if nd_axis < 0 else nd_axis
-    # First check if the axis needs to be moved.
-    if axis not in (1, ndims - 1):
-      # Move axis to dim 1.
-      if axis == 0:
-        # Transform [C, ...] to [1, C, ...].
-        shape = array_ops.concat([constant_op.constant([1]), shape], axis=0)
-        ndims += 1
-      else:
-        # Merge excess pre-axis dims into first dim.
-        # Transform [N, ..., C, ...] to [product(N, ...), C, ...].
-        product = math_ops.reduce_prod(shape[:axis], keepdims=True)
-        shape = array_ops.concat([product, shape[axis:]], axis=0)
-        ndims -= (axis - 1)
-      axis = 1
-    # Now change shape to 4D.
-    is_channels_last = axis == ndims - 1
-    if ndims < 4:
-      # Insert new dims after existing spatial dim or before channel dim.
-      new_dims = constant_op.constant([1] * (4 - ndims))
-      if is_channels_last:
-        # Transform [..., C] to [..., 1..., C] (ndims=4).
-        shape = array_ops.concat([shape[:-1], new_dims, shape[-1:]], axis=0)
-      else:
-        # Transform [N, C, ...] to [N, C, ..., 1...] (ndims=4).
-        shape = array_ops.concat([shape, new_dims], axis=0)
-    elif ndims > 4:
-      # Merge excess spatial dims into the second spatial dim.
-      # Transform [N, C, H, W, ...] to [N, C, H, product(W, ...)].
-      # Or [N, H, W, ..., C] to [N, H, product(W, ...), C].
-      merge_dim = 2 if is_channels_last else 3
-      product = math_ops.reduce_prod(
-          shape[merge_dim:merge_dim + 1 + (ndims - 4)], keepdims=True)
-      shape = array_ops.concat([shape[:merge_dim], product,
-                                shape[merge_dim + 1 + (ndims - 4):]], axis=0)
-    axis = 3 if is_channels_last else 1
-    return shape, axis
-
   def build(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape)
     if not input_shape.ndims:
@@ -352,8 +303,33 @@ def build(self, input_shape):
         raise ValueError('When using virtual_batch_size, adjustment cannot '
                          'be specified')

-    if self.fused and not self._USE_V2_BEHAVIOR:
-      self.fused = self._fused_can_be_used()
+    if self.fused in (None, True):
+      # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
+      # output back to its original shape accordingly.
+      if self._USE_V2_BEHAVIOR:
+        if self.fused is None:
+          self.fused = (ndims == 4)
+        elif self.fused and ndims != 4:
+          raise ValueError('Batch normalization layers with fused=True only '
+                           'support 4D input tensors.')
+      else:
+        assert self.fused is not None
+        self.fused = (ndims == 4 and self._fused_can_be_used())
+      # TODO(chrisying): fused batch norm is currently not supported for
+      # multi-axis batch norm and by extension virtual batches. In some cases,
+      # it might be possible to use fused batch norm but would require reshaping
+      # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
+      # particularly tricky. A compromise might be to just support the most
+      # common use case (turning 5D w/ virtual batch to NCHW)
+
+    if self.fused:
+      if self.axis == [1]:
+        self._data_format = 'NCHW'
+      elif self.axis == [3]:
+        self._data_format = 'NHWC'
+      else:
+        raise ValueError('Unsupported axis, fused batch norm only supports '
+                         'axis == [1] or axis == [3]')

     axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
     for x in axis_to_dim:
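
Note (not part of the patch): a short usage sketch of how this build() logic selects the data format; the tf.keras calls and shapes below are assumptions for illustration only:

    import tensorflow as tf

    # axis=1 on a 4D input: the fused path would use data_format='NCHW'.
    bn_cf = tf.keras.layers.BatchNormalization(axis=1, fused=True)
    bn_cf.build((None, 8, 32, 32))

    # axis=-1 (equivalent to 3 for rank 4): the fused path would use 'NHWC'.
    bn_cl = tf.keras.layers.BatchNormalization(axis=-1, fused=True)
    bn_cl.build((None, 32, 32, 8))
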
@@ -548,7 +524,7 @@ def _fused_batch_norm_training():
           gamma,
           beta,
           epsilon=self.epsilon,
-          data_format=data_format)
+          data_format=self._data_format)

     def _fused_batch_norm_inference():
       return nn.fused_batch_norm(
@@ -559,7 +535,7 @@ def _fused_batch_norm_inference():
           variance=self.moving_variance,
           epsilon=self.epsilon,
           is_training=False,
-          data_format=data_format)
+          data_format=self._data_format)

     output, mean, variance = tf_utils.smart_cond(
         training, _fused_batch_norm_training, _fused_batch_norm_inference)
@@ -572,9 +548,6 @@ def _fused_batch_norm_inference():
       factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
       variance *= factor

-    if original_shape is not None:
-      output = array_ops.reshape(output, original_shape)
-
     training_value = tf_utils.constant_value(training)
     if training_value is None:
       momentum = tf_utils.smart_cond(training,
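
Note (not part of the patch): the factor in the last hunk simply rescales a Bessel-corrected variance (divide by N - 1) to the biased form (divide by N). A small NumPy check; the shapes and names are illustrative only:

    import numpy as np

    x = np.random.randn(2, 4, 4, 3)
    sample_size = x.shape[0] * x.shape[1] * x.shape[2]   # N*H*W samples per channel
    factor = (sample_size - 1.0) / sample_size
    unbiased = x.var(axis=(0, 1, 2), ddof=1)             # Bessel-corrected variance
    np.testing.assert_allclose(unbiased * factor, x.var(axis=(0, 1, 2)))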