@@ -337,18 +337,33 @@ def step(data):
 embedding_layer = DistributedEmbedding(feature_configs)

 # Add preprocessing to a data input pipeline.
-def train_dataset_generator():
-    for (inputs, weights), labels in iter(train_dataset):
+def preprocessed_dataset_generator(dataset):
+    for (inputs, weights), labels in iter(dataset):
         yield embedding_layer.preprocess(
             inputs, weights, training=True
         ), labels

-preprocessed_train_dataset = train_dataset_generator()
+preprocessed_train_dataset = preprocessed_dataset_generator(train_dataset)
 ```
 This explicit preprocessing stage combines the input and optional weights,
 so the new data can be passed directly into the `inputs` argument of the
 layer or model.

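+For illustration, here is a minimal sketch of a custom training loop that
+consumes the preprocessed dataset; only the layer call is taken from the
+API described above, and the loss/optimizer steps are elided:
+```python
+# Each element already contains preprocessed inputs, so it can be passed
+# straight to the layer's `inputs` argument.
+for preprocessed_inputs, labels in preprocessed_train_dataset:
+    embeddings = embedding_layer(preprocessed_inputs, training=True)
+    # ... compute the loss from `embeddings` and `labels`, apply gradients ...
+```
+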
+**NOTE**: When working in a multi-host setting with data parallelism, the
+data needs to be sharded properly across hosts. If the original dataset is
+of type `tf.data.Dataset`, it will need to be manually sharded _prior_ to
+applying the preprocess generator:
+```python
+# Manually shard the dataset across hosts.
+train_dataset = distribution.distribute_dataset(train_dataset)
+distribution.auto_shard_dataset = False  # Dataset is already sharded.
+
+# Add a preprocessing stage to the distributed data input pipeline.
+train_dataset = preprocessed_dataset_generator(train_dataset)
+```
+If the original dataset is _not_ a `tf.data.Dataset`, it must already be
+pre-sharded across hosts.
+
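+As an illustration only, assuming a JAX backend where `jax.process_index()`
+and `jax.process_count()` identify the current host, a non-`tf.data` source
+could be pre-sharded along these lines before applying the preprocess
+generator:
+```python
+import jax
+
+def shard_for_this_host(examples):
+    # Sketch: keep only the batches assigned to this host (round-robin).
+    # Any partitioning works as long as hosts see disjoint data.
+    for i, example in enumerate(examples):
+        if i % jax.process_count() == jax.process_index():
+            yield example
+
+train_dataset = shard_for_this_host(train_dataset)
+train_dataset = preprocessed_dataset_generator(train_dataset)
+```
+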
 #### Usage in a Keras model

 Once the global distribution is set and the input preprocessing pipeline