
Commit 4486ba1

[DOC] Add user docs for data deduplication and so forth. (#150)
1. data deduplication; 2. the usage of `hb.embedding_scope`; 3. DataSyncReplicas; 4. the `hb.keras` API. Signed-off-by: langshi.cls <langshi.cls@alibaba-inc.com>
1 parent 02a714b commit 4486ba1

File tree: 3 files changed, +314 −21 lines


docs/data.md

Lines changed: 80 additions & 11 deletions
@@ -1,6 +1,6 @@
# Data Loading

Large-batch training on cloud requires high I/O performance. HybridBackend
supports memory-efficient loading of categorical data.

## 1. Data Frame
@@ -13,18 +13,18 @@ Supported logical data types:

| Name                        | Data Structure                                |
| --------------------------- | --------------------------------------------- |
| Scalar                      | `tf.Tensor` / `hb.data.DataFrame.Value`       |
| Fixed-Length List           | `tf.Tensor` / `hb.data.DataFrame.Value`       |
| Variable-Length List        | `tf.SparseTensor` / `hb.data.DataFrame.Value` |
| Variable-Length Nested List | `tf.SparseTensor` / `hb.data.DataFrame.Value` |

Supported physical data types:

| Category | Types                                            |
| -------- | ------------------------------------------------ |
| Integers | `int64` `uint64` `int32` `uint32` `int8` `uint8` |
| Numerics | `float64` `float32` `float16`                    |
| Text     | `string`                                         |
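
For illustration, a hypothetical set of field declarations covering these
logical types might look as follows (the field names here are made up; the
`hb.data.DataFrame.Field` arguments match the full example in Section 3.2):

```python
import tensorflow as tf
import hybridbackend.tensorflow as hb

fields = [
  hb.data.DataFrame.Field('label', tf.float32),                 # scalar
  hb.data.DataFrame.Field('emb', tf.float32, shape=[16]),       # fixed-length list
  hb.data.DataFrame.Field('history', tf.int64, ragged_rank=1)]  # variable-length list
```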

```{eval-rst}
.. autoclass:: hybridbackend.tensorflow.data.DataFrame
@@ -168,9 +168,78 @@ batch = it.get_next()
...
```

## 3. Deduplication

Feature columns associated with users, such as a user's bio information or
recent behaviour (user-viewed items), usually contain redundant information.
For instance, two records associated with the same user id carry identical
data in the feature column of recently viewed items. HybridBackend provides
a deduplication mechanism to speed up data loading and reduce data storage.

### 3.1 Preparation of deduplicated training data

Currently, it is the user's responsibility to deduplicate the training data
(e.g., in Parquet format). An example Python script is provided in
`hybridbackend/docs/tutorial/ranking/taobao/data/deduplicate.py`.
In general, users shall provide three arguments:

1. `--deduplicated-block-size`: how many rows (records) are involved in each
   deduplication operation. For instance, if deduplication is applied per
   1000 rows, the compressed records are restored to the original 1000
   records during actual training. Theoretically, a larger deduplication
   block size yields a better deduplication ratio, but this also depends on
   the distribution of the duplicated data.

2. `--user-cols`: a list of feature column names (fields). The first feature
   column in the list serves as the `key` to deduplicate on, while the
   remaining feature columns are the values (targets) to compress. Multiple
   `--user-cols` arguments can be passed to deduplicate independently.

3. `--non-user-cols`: the feature columns that are excluded from
   deduplication.

The prepared data shall contain an additional feature column for each
`--user-cols`, which stores the inverse index used to restore the
deduplicated values during training, as sketched below.
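
For intuition, here is a minimal NumPy sketch (not the actual
`deduplicate.py` script) of deduplicating one block, simplified to a single
key column and one value column:

```python
import numpy as np

def deduplicate_block(user_ids, user_feats):
  """Deduplicate one block of rows keyed by `user_ids`."""
  # Sorted unique keys, the row index of each key's first occurrence, and
  # the inverse index mapping every original row back to its unique key.
  unique_ids, first_idx, inverse_idx = np.unique(
      user_ids, return_index=True, return_inverse=True)
  # Keep a single copy of the user-side features per unique key.
  return unique_ids, user_feats[first_idx], inverse_idx

# A block of 6 rows from 2 users compresses into 2 rows plus an index column.
ids = np.array([7, 7, 7, 9, 9, 9])
feats = np.array([[1, 2]] * 3 + [[3, 4]] * 3)  # duplicated per user
uids, dfeats, idx = deduplicate_block(ids, feats)
assert (dfeats[idx] == feats).all()  # the inverse index restores every row
```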

### 3.2 Read and restore deduplicated data

HybridBackend provides an API to read the deduplicated training data prepared
in 3.1.

Example:

```python
import tensorflow as tf
import hybridbackend.tensorflow as hb

# Define data frame fields.
fields = [
  hb.data.DataFrame.Field('user', tf.int64),  # scalar
  hb.data.DataFrame.Field('user-index', tf.int64),  # scalar
  hb.data.DataFrame.Field('user-feat-0', tf.int64, shape=[32]),  # fixed-length list
  hb.data.DataFrame.Field('user-feat-1', tf.int64, ragged_rank=1),  # variable-length list
  hb.data.DataFrame.Field('item-feat-0', tf.int64, ragged_rank=1)]  # variable-length list

# Read from deduplicated parquet files (deduplicated every 1024 rows)
# by specifying the `key` and `value` feature columns.
ds = hb.data.Dataset.from_parquet(
  '/path/to/f1.parquet',
  fields=fields,
  key_idx_field_names=['user-index'],
  value_field_names=[['user', 'user-feat-0', 'user-feat-1']])
ds = ds.batch(1)
ds = ds.prefetch(4)
it = tf.data.make_one_shot_iterator(ds)
batch = it.get_next()
```

Here, `key_idx_field_names` is a list of feature columns that contain the
inverse indices of the key feature columns, and `value_field_names` is a list
of feature-column lists associated with each key feature column. Multiple
`key-value` deduplications are supported. When the `get_next()` method is
called to obtain the batched data, the deduplicated values are internally
restored to their original values.
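
Conceptually, this restore is a gather of each deduplicated value column by
its inverse-index column. The real restore happens inside `hb.data.Dataset`,
but a minimal sketch of the idea looks like this:

```python
import tensorflow as tf

deduped_user_feat = tf.constant([[1, 2], [3, 4]])    # one row per unique user
user_index = tf.constant([0, 0, 0, 1, 1, 1])         # inverse index per record
restored = tf.gather(deduped_user_feat, user_index)  # one row per original record
```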

## 4. Tips

### 4.1 Remove dataset ops in exported saved model

```python
import tensorflow as tf
@@ -195,7 +264,7 @@ with tf.Graph().as_default() as predict_graph:
    outputs=model_outputs)
```

## 5. Benchmark

In a benchmark reading 20k samples from 200 columns of a Parquet file,
`hb.data.Dataset` is about **21.51x faster** than

docs/distributed.md

Lines changed: 47 additions & 10 deletions
@@ -37,11 +37,20 @@ or
HB_GRAD_NBUCKETS=2 python xxx.py
```

### 1.4 Example: Launch workers on a single machine with multiple GPUs

```bash
# Launch a worker for each GPU by reading the environment variable
# `NVIDIA_VISIBLE_DEVICES` or `CUDA_VISIBLE_DEVICES`.
python -m hybridbackend.run python /path/to/main.py
```

### 1.5 Example: Launch workers on multiple machines with multiple GPUs

```bash
# Set the environment variable `TF_CONFIG` on each machine. E.g.,
# TF_CONFIG='{"cluster":{"chief":["x.x.x.x:8860"],"worker":["x.x.x.x:8861"]}, "task":{"type":"chief","index":0}}'
# Then set `NVIDIA_VISIBLE_DEVICES` or `CUDA_VISIBLE_DEVICES` for the GPUs on each machine.
python -m hybridbackend.run python /path/to/main.py
```

@@ -69,12 +78,12 @@ with hb.scope():
  opt = tf.train.GradientDescentOptimizer(learning_rate=lr)
```

## 3. Embedding-Sharded Data Parallelism

HybridBackend provides `hb.embedding_scope` to shard variables and support
embedding-sharded data parallelism.

### 3.1 APIs

```{eval-rst}
.. autofunction:: hybridbackend.tensorflow.metrics.accuracy
@@ -83,7 +92,7 @@ embedding-sharded data parallelism.
.. autofunction:: hybridbackend.tensorflow.train.export
```

### 3.2 Example: Sharding embedding weights within a scope

```python
import tensorflow as tf
@@ -92,7 +101,7 @@ import hybridbackend.tensorflow as hb
def foo():
  # ...
  with hb.scope():
    with hb.embedding_scope():
      embedding_weights = tf.get_variable(
          'emb_weights', shape=[bucket_size, dim_size])
      embedding = tf.nn.embedding_lookup(embedding_weights, ids)
@@ -102,7 +111,7 @@ def foo():
  opt = tf.train.GradientDescentOptimizer(learning_rate=lr)
```

### 3.3 Example: Evaluation

```python
import tensorflow as tf
@@ -119,7 +128,7 @@ with tf.Graph().as_default():
  with hb.scope():
    batch = tf.data.make_one_shot_iterator(train_ds).get_next()
    # ...
    with hb.embedding_scope():
      embedding_weights = tf.get_variable(
          'emb_weights', shape=[bucket_size, dim_size])
      embedding = tf.nn.embedding_lookup(embedding_weights, ids)
@@ -133,7 +142,7 @@ with tf.Graph().as_default():
    sess.run(train_op)
```

### 3.4 Example: Exporting to SavedModel

```python
import tensorflow as tf
@@ -159,3 +168,31 @@ def _on_export():
  checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
  hb.train.export(export_dir_base, checkpoint_path, _on_export)
```

## 4. Sync training with unbalanced data across workers

When training data is distributed across workers, some workers may be assigned
fewer batches of data than others and therefore run out of data earlier.
HybridBackend provides two strategies for handling the remaining training data
on the other workers.

1. Set `data_sync_drop_remainder=True` (the default) in `hb.scope()`:

   ```python
   import tensorflow as tf
   import hybridbackend.tensorflow as hb

   if __name__ == '__main__':
     ...
     with hb.scope(data_sync_drop_remainder=True):
       main()
   ```

   By doing so, whenever one worker has finished its assigned training data,
   HybridBackend drops the remaining training data on the other workers to end
   the training task.

2. Set `data_sync_drop_remainder=False` in `hb.scope()`. As a result, whenever
   a worker has finished its training data, it keeps producing empty tensors
   to join the synchronous training along with the other workers until all
   workers have finished their training data. Note that users shall ensure
   that their customized TF operators and other implementations can handle
   such empty tensors; a minimal sketch follows below.
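
As an illustration of that requirement, here is a minimal sketch, assuming a
custom batch-mean metric: a plain mean yields NaN on a zero-row batch, so the
division is guarded.

```python
import tensorflow as tf

def safe_batch_mean(values):
  # Return 0 instead of NaN when the incoming batch of `values` is empty.
  count = tf.cast(tf.size(values), tf.float32)
  return tf.math.divide_no_nan(tf.reduce_sum(values), count)
```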
