@@ -231,7 +231,9 @@ def _is_tensor(v):
231231 return True
232232 return False
233233
234- return all (_is_tensor (v ) for v in flat_inputs )
234+ return all (_is_tensor (v ) for v in flat_inputs if v is not None ) and any (
235+ _is_tensor (v ) for v in flat_inputs
236+ )
235237
236238 def __init__ (
237239 self ,
@@ -259,7 +261,7 @@ def __init__(
259261 inputs = pack_x_y_sample_weight (x , y , sample_weights )
260262
261263 num_samples = set (
262- int (i .shape [0 ]) for i in tf .nest .flatten (inputs )
264+ int (i .shape [0 ]) for i in tf .nest .flatten (inputs ) if i is not None
263265 ).pop ()
264266 _check_data_cardinality (inputs )
265267
@@ -386,7 +388,7 @@ def slice_inputs(self, indices_dataset, inputs):
386388
387389 def grab_batch (i , data ):
388390 return tf .nest .map_structure (
389- lambda d : tf .gather (d , i , axis = 0 ), data
391+ lambda d : tf .gather (d , i , axis = 0 ) if d is not None else d , data
390392 )
391393
392394 dataset = dataset .map (grab_batch , num_parallel_calls = tf .data .AUTOTUNE )
@@ -459,7 +461,9 @@ def _is_array_like(v):
459461 if not TensorLikeDataAdapter .can_handle (
460462 x , y
461463 ) and not CompositeTensorDataAdapter .can_handle (x , y ):
462- return all (_is_array_like (v ) for v in flat_inputs )
464+ return all (
465+ _is_array_like (v ) for v in flat_inputs if v is not None
466+ ) and any (v is not None for v in flat_inputs )
463467 else :
464468 return False
465469
@@ -496,7 +500,7 @@ def dynamic_shape_like(t):
496500 shape [0 ] = None
497501 return tuple (shape )
498502
499- flat_dtypes = [inp .dtype for inp in flat_inputs ]
503+ flat_dtypes = [inp .dtype for inp in flat_inputs if inp is not None ]
500504 contiguous = True
501505 if self ._shuffle and self ._shuffle != "batch" :
502506 contiguous = False
@@ -509,15 +513,26 @@ def grab_batch(indices):
509513 # to a Tensor may force it into memory..
510514 def py_method (ind ):
511515 def slice_array (data ):
516+ if data is None :
517+ return None
512518 return training_utils .slice_arrays (
513519 data , ind .numpy (), contiguous = contiguous
514520 )
515521
516- return [slice_array (inp ) for inp in flat_inputs ]
522+ return [
523+ slice_array (inp ) for inp in flat_inputs if inp is not None
524+ ]
517525
518- flat_out = tf .py_function (py_method , [indices ], flat_dtypes )
519- for v , original_inp in zip (flat_out , flat_inputs ):
520- v .set_shape (dynamic_shape_like (original_inp ))
526+ results = tf .py_function (py_method , [indices ], flat_dtypes )
527+ results_it = iter (results )
528+ flat_out = []
529+ for original_inp in flat_inputs :
530+ if original_inp is None :
531+ flat_out .append (None )
532+ else :
533+ v = next (results_it )
534+ v .set_shape (dynamic_shape_like (original_inp ))
535+ flat_out .append (v )
521536 return tf .nest .pack_sequence_as (inputs , flat_out )
522537
523538 dataset = indices_dataset .map (
@@ -608,8 +623,10 @@ def _is_tensor_or_composite(v):
608623 return True
609624 return _is_composite (v )
610625
611- return any (_is_composite (v ) for v in flat_inputs ) and all (
612- _is_tensor_or_composite (v ) for v in flat_inputs
626+ return any (
627+ _is_composite (v ) for v in flat_inputs if v is not None
628+ ) and all (
629+ _is_tensor_or_composite (v ) for v in flat_inputs if v is not None
613630 )
614631
615632 def __init__ (
@@ -1944,14 +1961,18 @@ def single_batch_iterator(
19441961
19451962
19461963def _check_data_cardinality (data ):
1947- num_samples = set (int (i .shape [0 ]) for i in tf .nest .flatten (data ))
1964+ num_samples = set (
1965+ int (i .shape [0 ]) for i in tf .nest .flatten (data ) if i is not None
1966+ )
19481967 if len (num_samples ) > 1 :
19491968 msg = "Data cardinality is ambiguous:\n "
19501969 for label , single_data in zip (["x" , "y" , "sample_weight" ], data ):
19511970 msg += " {} sizes: {}\n " .format (
19521971 label ,
19531972 ", " .join (
1954- str (i .shape [0 ]) for i in tf .nest .flatten (single_data )
1973+ str (i .shape [0 ])
1974+ for i in tf .nest .flatten (single_data )
1975+ if i is not None
19551976 ),
19561977 )
19571978 msg += "Make sure all arrays contain the same number of samples."
0 commit comments