Skip to content

Commit 1e418f0

Browse files
authored
Enable zero-copy transfer of pyarrow buffer to ArrowDataset kernel when in eager mode (tensorflow#413)
1 parent f7f1b9b commit 1e418f0

File tree

3 files changed

+232
-22
lines changed

3 files changed

+232
-22
lines changed

tensorflow_io/arrow/kernels/arrow_dataset_ops.cc

Lines changed: 156 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -583,11 +583,155 @@ class ArrowOpKernelBase : public DatasetOpKernel {
583583
std::vector<PartialTensorShape> output_shapes_;
584584
};
585585

586-
// Op to create an ArrowDataset that consumes Arrow record batches from
587-
// memory in a Python process, or a Pandas DataFrame.
588-
class ArrowDatasetOp : public ArrowOpKernelBase {
586+
// Op to create an ArrowZeroCopyDataset that consumes Arrow record batches
587+
// from a memory buffer address owned in Python.
588+
class ArrowZeroCopyDatasetOp : public ArrowOpKernelBase {
589589
public:
590-
explicit ArrowDatasetOp(OpKernelConstruction* ctx) : ArrowOpKernelBase(ctx) {}
590+
explicit ArrowZeroCopyDatasetOp(OpKernelConstruction* ctx)
591+
: ArrowOpKernelBase(ctx) {}
592+
593+
virtual void MakeArrowDataset(
594+
OpKernelContext* ctx,
595+
const std::vector<int32>& columns,
596+
const int64 batch_size,
597+
const ArrowBatchMode batch_mode,
598+
const DataTypeVector& output_types,
599+
const std::vector<PartialTensorShape>& output_shapes,
600+
ArrowDatasetBase** output) override {
601+
uintptr_t buffer_address;
602+
OP_REQUIRES_OK(
603+
ctx,
604+
ParseScalarArgument<uintptr_t>(ctx, "buffer_address", &buffer_address));
605+
const uint8_t* buffer = reinterpret_cast<const uint8_t*>(buffer_address);
606+
607+
int64_t buffer_size;
608+
OP_REQUIRES_OK(
609+
ctx,
610+
ParseScalarArgument<int64_t>(ctx, "buffer_size", &buffer_size));
611+
*output = new Dataset(
612+
ctx,
613+
buffer,
614+
buffer_size,
615+
columns,
616+
batch_size,
617+
batch_mode,
618+
output_types_,
619+
output_shapes_);
620+
}
621+
622+
private:
623+
class Dataset : public ArrowDatasetBase {
624+
public:
625+
Dataset(OpKernelContext* ctx,
626+
const uint8_t* buffer_ptr,
627+
const int64 buffer_size,
628+
const std::vector<int32>& columns,
629+
const int64 batch_size,
630+
const ArrowBatchMode batch_mode,
631+
const DataTypeVector& output_types,
632+
const std::vector<PartialTensorShape>& output_shapes)
633+
: ArrowDatasetBase(ctx, columns, batch_size, batch_mode,
634+
output_types, output_shapes),
635+
buffer_ptr_(buffer_ptr), buffer_size_(buffer_size) {}
636+
637+
string DebugString() const override {
638+
return "ArrowZeroCopyDatasetOp::Dataset";
639+
}
640+
641+
protected:
642+
Status AsGraphDefInternal(SerializationContext* ctx,
643+
DatasetGraphDefBuilder* b,
644+
Node** output) const override {
645+
Node* buffer = nullptr;
646+
uintptr_t buffer_temp = reinterpret_cast<uintptr_t>(buffer_ptr_);
647+
uint64 buffer_address = buffer_temp;
648+
TF_RETURN_IF_ERROR(b->AddScalar(buffer_address, &buffer));
649+
Node* size = nullptr;
650+
TF_RETURN_IF_ERROR(
651+
b->AddScalar(static_cast<int64>(buffer_size_), &size));
652+
Node* columns = nullptr;
653+
TF_RETURN_IF_ERROR(b->AddVector(columns_, &columns));
654+
Node* batch_size = nullptr;
655+
TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));
656+
Node* batch_mode = nullptr;
657+
string batch_mode_str;
658+
TF_RETURN_IF_ERROR(GetBatchModeStr(batch_mode_, &batch_mode_str));
659+
TF_RETURN_IF_ERROR(b->AddScalar(batch_mode_str, &batch_mode));
660+
TF_RETURN_IF_ERROR(
661+
b->AddDataset(
662+
this,
663+
{buffer, size, columns, batch_size, batch_mode},
664+
output));
665+
return Status::OK();
666+
}
667+
668+
std::unique_ptr<IteratorBase> MakeIteratorInternal(
669+
const string& prefix) const override {
670+
return std::unique_ptr<IteratorBase>(
671+
new Iterator({this, strings::StrCat(prefix, "::Arrow")}));
672+
}
673+
674+
private:
675+
class Iterator : public ArrowBaseIterator<Dataset> {
676+
public:
677+
explicit Iterator(const Params& params)
678+
: ArrowBaseIterator<Dataset>(params) {}
679+
680+
private:
681+
Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
682+
buffer_ = std::make_shared<arrow::Buffer>(
683+
dataset()->buffer_ptr_,
684+
dataset()->buffer_size_);
685+
buffer_reader_ = std::make_shared<arrow::io::BufferReader>(buffer_);
686+
CHECK_ARROW(
687+
arrow::ipc::RecordBatchFileReader::Open(
688+
buffer_reader_.get(),
689+
buffer_->size(),
690+
&reader_));
691+
num_batches_ = reader_->num_record_batches();
692+
if (num_batches_ > 0) {
693+
CHECK_ARROW(
694+
reader_->ReadRecordBatch(current_batch_idx_, &current_batch_));
695+
TF_RETURN_IF_ERROR(CheckBatchColumnTypes(current_batch_));
696+
}
697+
return Status::OK();
698+
}
699+
700+
Status NextStreamLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
701+
ArrowBaseIterator<Dataset>::NextStreamLocked();
702+
if (++current_batch_idx_ < num_batches_) {
703+
CHECK_ARROW(
704+
reader_->ReadRecordBatch(current_batch_idx_, &current_batch_));
705+
}
706+
return Status::OK();
707+
}
708+
709+
void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
710+
ArrowBaseIterator<Dataset>::ResetStreamsLocked();
711+
reader_.reset();
712+
current_batch_idx_ = 0;
713+
num_batches_ = 0;
714+
}
715+
716+
std::shared_ptr<arrow::Buffer> buffer_ GUARDED_BY(mu_);
717+
std::shared_ptr<arrow::io::BufferReader> buffer_reader_ GUARDED_BY(mu_);
718+
std::shared_ptr<arrow::ipc::RecordBatchFileReader> reader_
719+
GUARDED_BY(mu_);
720+
int current_batch_idx_ GUARDED_BY(mu_) = 0;
721+
int num_batches_ GUARDED_BY(mu_) = 0;
722+
};
723+
724+
const uint8_t* buffer_ptr_;
725+
const int64 buffer_size_;
726+
};
727+
};
728+
729+
// Op to create an ArrowSerializedDataset that consumes Arrow record batches
730+
// serialized in a Tensor buffer.
731+
class ArrowSerializedDatasetOp : public ArrowOpKernelBase {
732+
public:
733+
explicit ArrowSerializedDatasetOp(OpKernelConstruction* ctx)
734+
: ArrowOpKernelBase(ctx) {}
591735

592736
virtual void MakeArrowDataset(
593737
OpKernelContext* ctx,
@@ -629,7 +773,9 @@ class ArrowDatasetOp : public ArrowOpKernelBase {
629773
batches_(std::move(batches_tensor)) {
630774
}
631775

632-
string DebugString() const override { return "ArrowDatasetOp::Dataset"; }
776+
string DebugString() const override {
777+
return "ArrowSerializedDatasetOp::Dataset";
778+
}
633779

634780
protected:
635781
Status AsGraphDefInternal(SerializationContext* ctx,
@@ -1009,8 +1155,11 @@ class ArrowStreamDatasetOp : public ArrowOpKernelBase {
10091155
};
10101156
};
10111157

1012-
REGISTER_KERNEL_BUILDER(Name("ArrowDataset").Device(DEVICE_CPU),
1013-
ArrowDatasetOp);
1158+
REGISTER_KERNEL_BUILDER(Name("ArrowZeroCopyDataset").Device(DEVICE_CPU),
1159+
ArrowZeroCopyDatasetOp);
1160+
1161+
REGISTER_KERNEL_BUILDER(Name("ArrowSerializedDataset").Device(DEVICE_CPU),
1162+
ArrowSerializedDatasetOp);
10141163

10151164
REGISTER_KERNEL_BUILDER(Name("ArrowFeatherDataset").Device(DEVICE_CPU),
10161165
ArrowFeatherDatasetOp);

tensorflow_io/arrow/ops/dataset_ops.cc

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,26 @@ limitations under the License.
1919

2020
namespace tensorflow {
2121

22-
REGISTER_OP("ArrowDataset")
22+
REGISTER_OP("ArrowZeroCopyDataset")
23+
.Input("buffer_address: uint64")
24+
.Input("buffer_size: int64")
25+
.Input("columns: int32")
26+
.Input("batch_size: int64")
27+
.Input("batch_mode: string")
28+
.Output("handle: variant")
29+
.Attr("output_types: list(type) >= 1")
30+
.Attr("output_shapes: list(shape) >= 1")
31+
.SetIsStateful()
32+
.SetShapeFn(shape_inference::ScalarShape)
33+
.Doc(R"doc(
34+
Creates a dataset that zero-copy reads data from an Arrow Buffer.
35+
36+
buffer_address: Buffer address as long int with contents as Arrow RecordBatches
37+
in file format.
38+
buffer_size: Buffer size in bytes
39+
)doc");
40+
41+
REGISTER_OP("ArrowSerializedDataset")
2342
.Input("serialized_batches: string")
2443
.Input("columns: int32")
2544
.Input("batch_size: int64")

tensorflow_io/arrow/python/ops/arrow_dataset_ops.py

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,14 @@ def __init__(self,
161161
output_types,
162162
output_shapes=None,
163163
batch_size=None,
164-
batch_mode='keep_remainder'):
164+
batch_mode='keep_remainder',
165+
arrow_buffer=None):
165166
"""Create an ArrowDataset from a Tensor of serialized batches.
166167
This constructor requires pyarrow to be installed.
167168
168169
Args:
169170
serialized_batches: A string Tensor as a serialized buffer containing
170-
Arrow record batches as Arrow file format
171+
Arrow record batches in Arrow File format
171172
columns: A list of column indices to be used in the Dataset
172173
output_types: Tensor dtypes of the output tensors
173174
output_shapes: TensorShapes of the output tensors or None to
@@ -180,9 +181,39 @@ def __init__(self,
180181
"keep_remainder" (default, keeps partial batch data),
181182
"drop_remainder" (discard partial batch data),
182183
"auto" (size to number of records in Arrow record batch)
184+
arrow_buffer: Optional Arrow Buffer containing Arrow record batches in
185+
Arrow File format. This will share the Arrow buffer with
186+
the C++ kernel by address for zero-copy. Only supported if
187+
the kernel process is local, with TensorFlow in eager mode.
188+
If this is used, set `serialized_batches` to `None`.
183189
"""
190+
if serialized_batches is not None:
191+
make_variant_fn = partial(
192+
core_ops.arrow_serialized_dataset,
193+
serialized_batches)
194+
elif arrow_buffer is None:
195+
raise ValueError("Must set either serialzied_batches or arrow_buffer")
196+
elif not tf.executing_eagerly():
197+
raise ValueError("Using arrow_buffer for zero-copy only supported in "
198+
"TensorFlow Eager mode.")
199+
else:
200+
address_int = arrow_buffer.address
201+
buffer_address = tf.convert_to_tensor(
202+
address_int,
203+
dtype=dtypes.uint64,
204+
name="buffer_address")
205+
buffer_size = tf.convert_to_tensor(
206+
arrow_buffer.size,
207+
dtype=dtypes.int64,
208+
name="buffer_size")
209+
make_variant_fn = partial(
210+
core_ops.arrow_zero_copy_dataset,
211+
buffer_address,
212+
buffer_size)
213+
# Keep a reference to the arrow buffers used
214+
self._arrow_buffer_refs = [arrow_buffer]
184215
super(ArrowDataset, self).__init__(
185-
partial(core_ops.arrow_dataset, serialized_batches),
216+
make_variant_fn,
186217
columns,
187218
output_types,
188219
output_shapes,
@@ -221,22 +252,33 @@ def from_record_batches(cls,
221252
if columns is None:
222253
columns = tuple(range(record_batches[0].num_columns))
223254
assert record_batches
224-
buf = io.BytesIO()
225-
writer = pa.RecordBatchFileWriter(buf, record_batches[0].schema)
226-
for batch in record_batches:
227-
writer.write_batch(batch)
228-
writer.close()
229-
serialized_batches = tf.convert_to_tensor(
230-
buf.getvalue(),
231-
dtype=dtypes.string,
232-
name="serialized_batches")
255+
if tf.executing_eagerly():
256+
sink = pa.BufferOutputStream()
257+
writer = pa.RecordBatchFileWriter(sink, record_batches[0].schema)
258+
for batch in record_batches:
259+
writer.write_batch(batch)
260+
writer.close()
261+
serialized_batches = None
262+
arrow_buffer = sink.getvalue()
263+
else:
264+
buf = io.BytesIO()
265+
writer = pa.RecordBatchFileWriter(buf, record_batches[0].schema)
266+
for batch in record_batches:
267+
writer.write_batch(batch)
268+
writer.close()
269+
serialized_batches = tf.convert_to_tensor(
270+
buf.getvalue(),
271+
dtype=dtypes.string,
272+
name="serialized_batches")
273+
arrow_buffer = None
233274
return cls(
234275
serialized_batches,
235276
columns,
236277
output_types,
237278
output_shapes,
238-
batch_size,
239-
batch_mode)
279+
batch_size=batch_size,
280+
batch_mode=batch_mode,
281+
arrow_buffer=arrow_buffer)
240282

241283
@classmethod
242284
def from_pandas(cls,

0 commit comments

Comments
 (0)