
Commit d93b28b

Add note about multi-host dataset sharding. (#131)
1 parent a94d9fc commit d93b28b

File tree

1 file changed: +18 -3 lines changed


keras_rs/src/layers/embedding/base_distributed_embedding.py

Lines changed: 18 additions & 3 deletions
````diff
@@ -337,18 +337,33 @@ def step(data):
     embedding_layer = DistributedEmbedding(feature_configs)
 
     # Add preprocessing to a data input pipeline.
-    def train_dataset_generator():
-        for (inputs, weights), labels in iter(train_dataset):
+    def preprocessed_dataset_generator(dataset):
+        for (inputs, weights), labels in iter(dataset):
             yield embedding_layer.preprocess(
                 inputs, weights, training=True
             ), labels
 
-    preprocessed_train_dataset = train_dataset_generator()
+    preprocessed_train_dataset = preprocessed_dataset_generator(train_dataset)
     ```
     This explicit preprocessing stage combines the input and optional weights,
     so the new data can be passed directly into the `inputs` argument of the
     layer or model.
 
+    **NOTE**: When working in a multi-host setting with data parallelism, the
+    data needs to be sharded properly across hosts. If the original dataset is
+    of type `tf.data.Dataset`, it will need to be manually sharded _prior_ to
+    applying the preprocess generator:
+    ```python
+    # Manually shard the dataset across hosts.
+    train_dataset = distribution.distribute_dataset(train_dataset)
+    distribution.auto_shard_dataset = False  # Dataset is already sharded.
+
+    # Add a preprocessing stage to the distributed data input pipeline.
+    train_dataset = preprocessed_dataset_generator(train_dataset)
+    ```
+    If the original dataset is _not_ a `tf.data.Dataset`, it must already be
+    pre-sharded across hosts.
+
     #### Usage in a Keras model
 
     Once the global distribution is set and the input preprocessing pipeline
````
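The key change in this diff is that the generator now takes the dataset as a parameter instead of closing over a global `train_dataset`, so the same generator can wrap either the original or a manually sharded dataset. A minimal sketch of that pattern, using a hypothetical `preprocess` function and a plain Python list in place of `embedding_layer.preprocess` and a `tf.data.Dataset`:

```python
# Hypothetical stand-in for `embedding_layer.preprocess`; the real method
# lives on keras_rs's DistributedEmbedding layer.
def preprocess(inputs, weights, training=True):
    # Bundle the inputs and optional weights into one structure so the
    # result can be fed directly to the layer's `inputs` argument.
    return {"inputs": inputs, "weights": weights, "training": training}

def preprocessed_dataset_generator(dataset):
    # Parameterized generator, mirroring the rename in the diff: it can be
    # applied to any (sharded or unsharded) dataset passed in.
    for (inputs, weights), labels in iter(dataset):
        yield preprocess(inputs, weights, training=True), labels

# A plain Python iterable standing in for a `tf.data.Dataset`.
train_dataset = [(([1, 2], [0.5, 0.5]), 1), (([3], [1.0]), 0)]
preprocessed_train_dataset = preprocessed_dataset_generator(train_dataset)
batches = list(preprocessed_train_dataset)
```

Because the generator is lazy, preprocessing runs per-batch as the model consumes the pipeline rather than materializing the whole dataset up front.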
