Skip to content

Commit 02a714b

Browse files
authored
[DATA] Implement zero-copied string dtype and accelerate shuffle. (#149)
1. Implement a zero-copy approach to read string data from Arrow into TF.
2. Accelerate the shuffle operation for string-typed columns in ParquetDataset.

Preliminary benchmarking results:

- col=300, `batch_size`=1000
- `Intel(R) Xeon(R) Platinum 8369B CPU @ 2.90GHz` with 128 logical cores.

| Dataset | list type | shuffling | throughput (samples/s) | speedup over TFRecord |
| --- | --- | --- | --- | --- |
| TFRecord | N | N | 1404.23 | 1.0 |
| HbParquet | N | N | 41137.53 | 29.3 |
| HbParquet-ZeroCopy | N | N | 51335.40 | 36.56 |
| TFRecord | N | Y | 1343.10 | 1.0 |
| HbParquet | N | Y | 6629.60 | 4.9 |
| HbParquet-ZeroCopy | N | Y | 10941.25 | 8.1 |
| TFRecord | Y | N | 1352.05 | 1.0 |
| HbParquet | Y | N | 2307.33 | 1.71 |
| HbParquet-ZeroCopy | Y | N | 2869.98 | 2.12 |
| TFRecord | Y | Y | 1367.96 | 1.0 |
| HbParquet | Y | Y | 1080.03 | 0.79 |
| HbParquet-ZeroCopy | Y | Y | 1454.02 | 1.06 |

Signed-off-by: langshi.cls <langshi.cls@alibaba-inc.com>
1 parent 0545159 commit 02a714b

File tree

10 files changed

+639
-167
lines changed

10 files changed

+639
-167
lines changed

build/dockerfiles/Dockerfile.developer-tf1.15-py3.6-cu100-ubuntu18.04

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ ENV HYBRIDBACKEND_WITH_CUDA=ON \
169169
HYBRIDBACKEND_WITH_NCCL=OFF \
170170
HYBRIDBACKEND_WITH_ARROW_ZEROCOPY=ON \
171171
HYBRIDBACKEND_WITH_TENSORFLOW_HALF=OFF \
172-
HYBRIDBACKEND_WITH_TENSORFLOW_DISTRO=1015 \
172+
HYBRIDBACKEND_WITH_TENSORFLOW_DISTRO=77661015 \
173173
HYBRIDBACKEND_USE_CXX11_ABI=0 \
174174
HYBRIDBACKEND_WHEEL_ALIAS=-tf115-cu100 \
175175
HYBRIDBACKEND_WHEEL_REQUIRES="tensorflow_gpu>=1.15,<2.0"

build/dockerfiles/Dockerfile.developer-tf1.15-py3.8-cu121-ubuntu20.04

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,9 @@ COPY --from=devel_tools /opt/tools /usr/local
121121
ENV HYBRIDBACKEND_WITH_CUDA=ON \
122122
HYBRIDBACKEND_WITH_NCCL=ON \
123123
HYBRIDBACKEND_WITH_ARROW_ZEROCOPY=ON \
124-
HYBRIDBACKEND_WITH_TENSORFLOW_DISTRO=1015 \
124+
HYBRIDBACKEND_WITH_TENSORFLOW_DISTRO=77661015 \
125125
HYBRIDBACKEND_USE_CXX11_ABI=0 \
126126
HYBRIDBACKEND_USE_RUFF=1 \
127127
HYBRIDBACKEND_WHEEL_ALIAS=-tf115-cu121 \
128128
TENSORFLOW_INCLUDE=/opt/tensorflow/tensorflow-source \
129-
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib:/usr/local/cuda/lib64
129+
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib:/usr/local/cuda/lib64

hybridbackend/tensorflow/benchmarks/data_benchmark_parquet.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,31 @@ def benchmark(params):
4040
tf.logging.info('Started generating mock file ...')
4141
workspace = tempfile.mkdtemp()
4242
params.filenames = [os.path.join(workspace, 'benchmark.parquet')]
43-
df = pd.DataFrame(
44-
np.random.randint(
45-
0, 100,
46-
size=(params.batch_size * 100, len(params.fields)),
47-
dtype=np.int64),
48-
columns=params.fields)
43+
if params.use_string_data:
44+
df = pd.DataFrame(
45+
np.array([
46+
[
47+
*[
48+
np.array(list(map(str, np.random.randint(
49+
0, 9,
50+
size=(np.random.randint(10, 30),),
51+
dtype=np.int64))))
52+
for _ in xrange(len(params.fields))]]
53+
for _ in xrange(params.batch_size * 100)], dtype=object),
54+
columns=params.fields)
55+
elif params.use_fixed_len_string_data:
56+
df = pd.DataFrame(
57+
np.array([
58+
['abcdefghijklmnoprstu' for _ in xrange(len(params.fields))]
59+
for _ in xrange(params.batch_size * 100)], dtype=np.str),
60+
columns=params.fields)
61+
else:
62+
df = pd.DataFrame(
63+
np.random.randint(
64+
0, 100,
65+
size=(params.batch_size * 100, len(params.fields)),
66+
dtype=np.int64),
67+
columns=params.fields)
4968
df.to_parquet(params.filenames[0])
5069
tf.logging.info(f'Mock file {params.filenames[0]} generated.')
5170
with tf.Graph().as_default():
@@ -66,7 +85,14 @@ def benchmark(params):
6685
ds = ds.batch(params.batch_size, drop_remainder=True)
6786
batch = tf.data.make_one_shot_iterator(ds).get_next()
6887
train_op = tf.group(list(batch.values()) + [step.assign_add(1)])
69-
with tf.train.MonitoredTrainingSession('') as sess:
88+
chief_only_hooks = []
89+
if params.profile_every_n_iter is not None:
90+
chief_only_hooks.append(
91+
tf.train.ProfilerHook(
92+
save_steps=params.profile_every_n_iter,
93+
output_dir=params.output_dir))
94+
with tf.train.MonitoredTrainingSession(
95+
'', chief_only_hooks=chief_only_hooks) as sess:
7096
count = 0
7197
prev_ts = time.time()
7298
try:
@@ -100,8 +126,13 @@ def benchmark(params):
100126
parser = argparse.ArgumentParser()
101127
parser.add_argument('--baseline', default=False, action='store_true')
102128
parser.add_argument('--shuffle', default=False, action='store_true')
129+
parser.add_argument('--use-string-data', default=False, action='store_true')
130+
parser.add_argument(
131+
'--use-fixed-len-string-data', default=False, action='store_true')
103132
parser.add_argument('--batch-size', type=int, default=64000)
104133
parser.add_argument('--num-steps', type=int, default=None)
134+
parser.add_argument('--output-dir', default='./outputs')
135+
parser.add_argument('--profile-every-n-iter', type=int, default=None)
105136
parser.add_argument(
106137
'--fields', nargs='+', default=[f'f{c}' for c in xrange(200)])
107138
parser.add_argument('filenames', nargs='*')

hybridbackend/tensorflow/benchmarks/data_benchmark_tfrecord.py

Lines changed: 77 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,39 +38,89 @@ def benchmark(params):
3838
tf.logging.info('Started generating mock file ...')
3939
workspace = tempfile.mkdtemp()
4040
params.filenames = [os.path.join(workspace, 'benchmark.tfrecord')]
41-
df = pd.DataFrame(
42-
np.random.randint(
43-
0, 100,
44-
size=(params.batch_size * 100, len(params.fields)),
45-
dtype=np.int64),
46-
columns=params.fields)
41+
if params.use_string_data:
42+
df = pd.DataFrame(
43+
np.array([
44+
[
45+
*[
46+
np.array(list(map(str, np.random.randint(
47+
0, 9,
48+
size=(np.random.randint(10, 30),),
49+
dtype=np.int64))))
50+
for _ in xrange(len(params.fields))]]
51+
for _ in xrange(params.batch_size * 100)], dtype=object),
52+
columns=params.fields)
53+
elif params.use_fixed_len_string_data:
54+
df = pd.DataFrame(
55+
np.array([
56+
['abcdefghijklmnoprstu' for _ in xrange(len(params.fields))]
57+
for _ in xrange(params.batch_size * 100)], dtype=np.str),
58+
columns=params.fields)
59+
else:
60+
df = pd.DataFrame(
61+
np.random.randint(
62+
0, 100,
63+
size=(params.batch_size * 100, len(params.fields)),
64+
dtype=np.int64),
65+
columns=params.fields)
4766
writer = tf.python_io.TFRecordWriter(params.filenames[0])
48-
for row in tq(range(params.samples)):
49-
feats = tf.train.Features(
50-
feature={
51-
f: tf.train.Feature(
52-
int64_list=tf.train.Int64List(value=[df[f][row]]))
53-
for f in params.fields})
67+
for row in tq(range(params.batch_size * 100)):
68+
if params.use_string_data or params.use_fixed_len_string_data:
69+
feats = tf.train.Features(
70+
feature={
71+
f: tf.train.Feature(
72+
bytes_list=tf.train.BytesList(
73+
value=[bytes(val, 'utf-8') for val in df[f][row]]))
74+
for f in params.fields})
75+
else:
76+
feats = tf.train.Features(
77+
feature={
78+
f: tf.train.Feature(
79+
int64_list=tf.train.Int64List(value=[df[f][row]]))
80+
for f in params.fields})
5481
example = tf.train.Example(features=feats)
5582
writer.write(example.SerializeToString())
5683
writer.close()
5784
tf.logging.info(f'Mock file {params.filenames[0]} generated.')
5885
with tf.Graph().as_default():
5986
step = tf.train.get_or_create_global_step()
6087
ds = tf.data.TFRecordDataset(params.filenames)
88+
if params.shuffle:
89+
ds = ds.shuffle(params.batch_size * 10)
6190
ds = ds.batch(params.batch_size, drop_remainder=True)
62-
ds = ds.map(
63-
lambda line: tf.parse_example(
64-
line, {f: tf.FixedLenFeature([1], tf.int64) for f in params.fields}))
91+
if params.use_string_data or params.use_fixed_len_string_data:
92+
ds = ds.map(
93+
lambda line: tf.parse_example(
94+
line, {f: tf.VarLenFeature(tf.string) for f in params.fields}))
95+
else:
96+
ds = ds.map(
97+
lambda line: tf.parse_example(
98+
line, {f: tf.FixedLenFeature([1], tf.int64) for f in params.fields}))
6599
batch = tf.data.make_one_shot_iterator(ds).get_next()
66-
train_op = tf.group(batch + [step.assign_add(1)])
67-
with tf.train.MonitoredTrainingSession('') as sess:
100+
train_op = tf.group(list(batch.values()) + [step.assign_add(1)])
101+
chief_only_hooks = []
102+
if params.profile_every_n_iter is not None:
103+
chief_only_hooks.append(
104+
tf.train.ProfilerHook(
105+
save_steps=params.profile_every_n_iter,
106+
output_dir=params.output_dir))
107+
with tf.train.MonitoredTrainingSession(
108+
'', chief_only_hooks=chief_only_hooks) as sess:
68109
count = 0
69110
prev_ts = time.time()
70111
try:
71-
while not sess.should_stop():
72-
sess.run(train_op)
73-
count += 1
112+
with tq() as pbar:
113+
should_stop = False
114+
while not sess.should_stop() and not should_stop:
115+
prev_sess_run = time.time()
116+
sess.run(train_op)
117+
sess_run_duration = time.time() - prev_sess_run
118+
pbar.set_description(
119+
f'{params.batch_size / sess_run_duration:6.2f} samples/sec')
120+
pbar.update(1)
121+
count += 1
122+
if params.num_steps is not None:
123+
should_stop = count >= params.num_steps
74124
except tf.errors.OutOfRangeError:
75125
pass
76126
duration = time.time() - prev_ts
@@ -87,7 +137,14 @@ def benchmark(params):
87137
os.environ['CUDA_VISIBLE_DEVICES'] = ''
88138
tf.logging.set_verbosity(tf.logging.INFO)
89139
parser = argparse.ArgumentParser()
140+
parser.add_argument('--shuffle', default=False, action='store_true')
141+
parser.add_argument('--use-string-data', default=False, action='store_true')
142+
parser.add_argument(
143+
'--use-fixed-len-string-data', default=False, action='store_true')
90144
parser.add_argument('--batch-size', type=int, default=64000)
145+
parser.add_argument('--num-steps', type=int, default=None)
146+
parser.add_argument('--output-dir', default='./outputs')
147+
parser.add_argument('--profile-every-n-iter', type=int, default=None)
91148
parser.add_argument(
92149
'--fields', nargs='+', default=[f'f{c}' for c in xrange(200)])
93150
parser.add_argument('filenames', nargs='*')

hybridbackend/tensorflow/common/arrow.cc

Lines changed: 93 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,80 @@ limitations under the License.
2424
#include <arrow/util/thread_pool.h>
2525
#include <tensorflow/core/framework/allocation_description.pb.h>
2626

27+
#include "hybridbackend/common/env.h"
2728
#include "hybridbackend/tensorflow/common/arrow.h"
2829
#include "hybridbackend/tensorflow/common/eigen.h"
2930
#endif
3031

3132
namespace tensorflow {
3233
namespace hybridbackend {
3334

35+
namespace {
36+
inline bool ZeroCopyStringForRebatchDisabled() {
37+
static const bool kZeroCopyStringForRebatchDisabled =
38+
::hybridbackend::EnvVarGetBool("HB_ZERO_COPY_STRING_FOR_REBATCH_DISABLED",
39+
false);
40+
return kZeroCopyStringForRebatchDisabled;
41+
}
42+
} // namespace
43+
3444
#if HYBRIDBACKEND_ARROW
3545

46+
#if HYBRIDBACKEND_ARROW_ZEROCOPY
47+
#if (TF_MAJOR_VERSION * 1000L + TF_MINOR_VERSION) < 1014L
48+
ArrowStringTensorBuffer::ArrowStringTensorBuffer(
49+
const std::shared_ptr<arrow::Buffer>& value_data_buf,
50+
const std::shared_ptr<arrow::Buffer>& value_offsets_buf,
51+
const uint8_t* raw_data, const int32_t* raw_value_offsets)
52+
: value_data_buf_(value_data_buf),
53+
value_offsets_buf_(value_offsets_buf),
54+
raw_data_(raw_data),
55+
raw_value_offsets_(raw_value_offsets) {}
56+
57+
void ArrowStringTensorBuffer::data() const { return this; }
58+
59+
#else
60+
ArrowStringTensorBuffer::ArrowStringTensorBuffer(
61+
const std::shared_ptr<arrow::Buffer>& value_data_buf,
62+
const std::shared_ptr<arrow::Buffer>& value_offsets_buf,
63+
const uint8_t* raw_data, const int32_t* raw_value_offsets)
64+
: TensorBuffer(this),
65+
value_data_buf_(value_data_buf),
66+
value_offsets_buf_(value_offsets_buf),
67+
raw_data_(raw_data),
68+
raw_value_offsets_(raw_value_offsets) {}
69+
#endif
70+
71+
size_t ArrowStringTensorBuffer::size() const {
72+
LOG(ERROR) << "When using zero copy string for rebatch, please and a "
73+
"hb.data.rebatch(batch_size) following hb.data.ParquetDataset ";
74+
return 0;
75+
}
76+
77+
// Reports this object as its own root buffer.
TensorBuffer* ArrowStringTensorBuffer::root_buffer() {
  return this;
}
78+
79+
// Fills `proto` with a fixed description of this buffer: the allocator name
// "ZerocopyArrowStringTensorBuffer" and a requested size of sizeof(tstring).
// When HYBRIDBACKEND_TENSORFLOW_DISTRO == 1015, the address of this object is
// additionally written into the `ptr` field.
void ArrowStringTensorBuffer::FillAllocationDescription(
    AllocationDescription* proto) const {
  proto->set_allocator_name("ZerocopyArrowStringTensorBuffer");
  proto->set_requested_bytes(sizeof(tstring));
#if HYBRIDBACKEND_TENSORFLOW_DISTRO == 1015
  // NOTE: vanilla tensorflow from community has no `data()` method of
  // class `Tensor`, thus we have to leverage the FillAllocationDescription
  // API to obtain the underlying ArrowStringTensorBuffer.
  const uint64 self_address = reinterpret_cast<uint64>(this);
  proto->set_ptr(self_address);
#endif
}
90+
91+
// Always answers false.
bool ArrowStringTensorBuffer::OwnsMemory() const {
  return false;
}
92+
93+
const uint8_t* ArrowStringTensorBuffer::GetValue(int64_t i,
94+
int32_t* out_length) {
95+
const int32_t pos = raw_value_offsets_[i];
96+
*out_length = raw_value_offsets_[i + 1] - pos;
97+
return raw_data_ + pos;
98+
}
99+
#endif
100+
36101
namespace {
37102
#if HYBRIDBACKEND_ARROW_ZEROCOPY
38103
class ArrowPrimitiveTensorBuffer : public TensorBuffer {
@@ -143,15 +208,34 @@ ::arrow::Status MakeStringTensorFromArrowArray(
143208
&actual_shape))) {
144209
return ::arrow::Status::Invalid("Field shape is not fully defined");
145210
}
146-
147-
*tensor = Tensor(DT_STRING, actual_shape);
148-
auto tensor_vec = tensor->vec<std::string>();
149-
150-
for (auto i = 0; i < total_num_elems; ++i) {
151-
int string_size;
152-
auto string_data = array.GetValue(i, &string_size);
153-
tensor_vec(i).assign(reinterpret_cast<const char*>(string_data),
154-
string_size);
211+
if (ZeroCopyStringForRebatchDisabled()) {
212+
*tensor = Tensor(DT_STRING, actual_shape);
213+
auto tensor_vec = tensor->vec<std::string>();
214+
215+
for (auto i = 0; i < total_num_elems; ++i) {
216+
int string_size;
217+
auto string_data = array.GetValue(i, &string_size);
218+
tensor_vec(i).assign(reinterpret_cast<const char*>(string_data),
219+
string_size);
220+
}
221+
} else {
222+
#if HYBRIDBACKEND_ARROW_ZEROCOPY
223+
ArrowStringTensorBuffer* tensor_buffer = new ArrowStringTensorBuffer(
224+
array.value_data(), array.value_offsets(), array.raw_data(),
225+
array.raw_value_offsets());
226+
core::ScopedUnref unref(tensor_buffer);
227+
*tensor = Tensor(DT_STRING, actual_shape, tensor_buffer);
228+
#else
229+
*tensor = Tensor(DT_STRING, actual_shape);
230+
auto tensor_vec = tensor->vec<std::string>();
231+
232+
for (auto i = 0; i < total_num_elems; ++i) {
233+
int string_size;
234+
auto string_data = array.GetValue(i, &string_size);
235+
tensor_vec(i).assign(reinterpret_cast<const char*>(string_data),
236+
string_size);
237+
}
238+
#endif
155239
}
156240
return ::arrow::Status::OK();
157241
}

hybridbackend/tensorflow/common/arrow.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include <deque>
2020

2121
#if HYBRIDBACKEND_ARROW
22+
#include <arrow/array.h>
2223
#include <arrow/dataset/api.h>
2324
#include <arrow/filesystem/path_util.h>
2425
#include <arrow/record_batch.h>
@@ -34,6 +35,7 @@ limitations under the License.
3435

3536
#include <tensorflow/core/framework/tensor.h>
3637
#include <tensorflow/core/lib/core/errors.h>
38+
#include <tensorflow/core/public/version.h>
3739

3840
#define TF_RETURN_IF_ARROW_ERROR(...) \
3941
do { \
@@ -89,6 +91,31 @@ MATCH_TYPE_AND_ARROW_ENUM(float, ::arrow::Type::FLOAT);
8991
MATCH_TYPE_AND_ARROW_ENUM(double, ::arrow::Type::DOUBLE);
9092
MATCH_TYPE_AND_ARROW_ENUM(string, ::arrow::Type::STRING);
9193

94+
#if HYBRIDBACKEND_ARROW_ZEROCOPY
95+
class ArrowStringTensorBuffer : public TensorBuffer {
96+
public:
97+
ArrowStringTensorBuffer() = delete;
98+
explicit ArrowStringTensorBuffer(
99+
const std::shared_ptr<arrow::Buffer>& value_data_buf,
100+
const std::shared_ptr<arrow::Buffer>& value_offsets_buf,
101+
const uint8_t* raw_data, const int32_t* raw_value_offsets);
102+
#if (TF_MAJOR_VERSION * 1000L + TF_MINOR_VERSION) < 1014L
103+
void* data() const override;
104+
#endif
105+
const uint8_t* GetValue(int64_t i, int32_t* out_length);
106+
size_t size() const override;
107+
TensorBuffer* root_buffer() override;
108+
void FillAllocationDescription(AllocationDescription* proto) const override;
109+
bool OwnsMemory() const override;
110+
111+
private:
112+
std::shared_ptr<::arrow::Buffer> value_data_buf_;
113+
std::shared_ptr<::arrow::Buffer> value_offsets_buf_;
114+
const uint8_t* raw_data_;
115+
const int32_t* raw_value_offsets_;
116+
};
117+
#endif
118+
92119
Status MakeDataTypeAndRaggedRankFromArrowDataType(
93120
const std::shared_ptr<::arrow::DataType>& arrow_dtype, DataType* dtype,
94121
int32* ragged_rank);

0 commit comments

Comments
 (0)