@@ -583,11 +583,155 @@ class ArrowOpKernelBase : public DatasetOpKernel {
583
583
std::vector<PartialTensorShape> output_shapes_;
584
584
};
585
585
586
- // Op to create an ArrowDataset that consumes Arrow record batches from
587
- // memory in a Python process, or a Pandas DataFrame .
588
- class ArrowDatasetOp : public ArrowOpKernelBase {
586
+ // Op to create an ArrowZeroCopyDataset that consumes Arrow record batches
587
+ // from a memory buffer address owned in Python .
588
+ class ArrowZeroCopyDatasetOp : public ArrowOpKernelBase {
589
589
public:
590
- explicit ArrowDatasetOp (OpKernelConstruction* ctx) : ArrowOpKernelBase(ctx) {}
590
+ explicit ArrowZeroCopyDatasetOp (OpKernelConstruction* ctx)
591
+ : ArrowOpKernelBase(ctx) {}
592
+
593
+ virtual void MakeArrowDataset (
594
+ OpKernelContext* ctx,
595
+ const std::vector<int32>& columns,
596
+ const int64 batch_size,
597
+ const ArrowBatchMode batch_mode,
598
+ const DataTypeVector& output_types,
599
+ const std::vector<PartialTensorShape>& output_shapes,
600
+ ArrowDatasetBase** output) override {
601
+ uintptr_t buffer_address;
602
+ OP_REQUIRES_OK (
603
+ ctx,
604
+ ParseScalarArgument<uintptr_t >(ctx, " buffer_address" , &buffer_address));
605
+ const uint8_t * buffer = reinterpret_cast <const uint8_t *>(buffer_address);
606
+
607
+ int64_t buffer_size;
608
+ OP_REQUIRES_OK (
609
+ ctx,
610
+ ParseScalarArgument<int64_t >(ctx, " buffer_size" , &buffer_size));
611
+ *output = new Dataset (
612
+ ctx,
613
+ buffer,
614
+ buffer_size,
615
+ columns,
616
+ batch_size,
617
+ batch_mode,
618
+ output_types_,
619
+ output_shapes_);
620
+ }
621
+
622
+ private:
623
+ class Dataset : public ArrowDatasetBase {
624
+ public:
625
+ Dataset (OpKernelContext* ctx,
626
+ const uint8_t * buffer_ptr,
627
+ const int64 buffer_size,
628
+ const std::vector<int32>& columns,
629
+ const int64 batch_size,
630
+ const ArrowBatchMode batch_mode,
631
+ const DataTypeVector& output_types,
632
+ const std::vector<PartialTensorShape>& output_shapes)
633
+ : ArrowDatasetBase(ctx, columns, batch_size, batch_mode,
634
+ output_types, output_shapes),
635
+ buffer_ptr_ (buffer_ptr), buffer_size_(buffer_size) {}
636
+
637
+ string DebugString () const override {
638
+ return " ArrowZeroCopyDatasetOp::Dataset" ;
639
+ }
640
+
641
+ protected:
642
+ Status AsGraphDefInternal (SerializationContext* ctx,
643
+ DatasetGraphDefBuilder* b,
644
+ Node** output) const override {
645
+ Node* buffer = nullptr ;
646
+ uintptr_t buffer_temp = reinterpret_cast <uintptr_t >(buffer_ptr_);
647
+ uint64 buffer_address = buffer_temp;
648
+ TF_RETURN_IF_ERROR (b->AddScalar (buffer_address, &buffer));
649
+ Node* size = nullptr ;
650
+ TF_RETURN_IF_ERROR (
651
+ b->AddScalar (static_cast <int64>(buffer_size_), &size));
652
+ Node* columns = nullptr ;
653
+ TF_RETURN_IF_ERROR (b->AddVector (columns_, &columns));
654
+ Node* batch_size = nullptr ;
655
+ TF_RETURN_IF_ERROR (b->AddScalar (batch_size_, &batch_size));
656
+ Node* batch_mode = nullptr ;
657
+ string batch_mode_str;
658
+ TF_RETURN_IF_ERROR (GetBatchModeStr (batch_mode_, &batch_mode_str));
659
+ TF_RETURN_IF_ERROR (b->AddScalar (batch_mode_str, &batch_mode));
660
+ TF_RETURN_IF_ERROR (
661
+ b->AddDataset (
662
+ this ,
663
+ {buffer, size, columns, batch_size, batch_mode},
664
+ output));
665
+ return Status::OK ();
666
+ }
667
+
668
+ std::unique_ptr<IteratorBase> MakeIteratorInternal (
669
+ const string& prefix) const override {
670
+ return std::unique_ptr<IteratorBase>(
671
+ new Iterator ({this , strings::StrCat (prefix, " ::Arrow" )}));
672
+ }
673
+
674
+ private:
675
+ class Iterator : public ArrowBaseIterator <Dataset> {
676
+ public:
677
+ explicit Iterator (const Params& params)
678
+ : ArrowBaseIterator<Dataset>(params) {}
679
+
680
+ private:
681
+ Status SetupStreamsLocked () EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
682
+ buffer_ = std::make_shared<arrow::Buffer>(
683
+ dataset ()->buffer_ptr_ ,
684
+ dataset ()->buffer_size_ );
685
+ buffer_reader_ = std::make_shared<arrow::io::BufferReader>(buffer_);
686
+ CHECK_ARROW (
687
+ arrow::ipc::RecordBatchFileReader::Open (
688
+ buffer_reader_.get (),
689
+ buffer_->size (),
690
+ &reader_));
691
+ num_batches_ = reader_->num_record_batches ();
692
+ if (num_batches_ > 0 ) {
693
+ CHECK_ARROW (
694
+ reader_->ReadRecordBatch (current_batch_idx_, ¤t_batch_));
695
+ TF_RETURN_IF_ERROR (CheckBatchColumnTypes (current_batch_));
696
+ }
697
+ return Status::OK ();
698
+ }
699
+
700
+ Status NextStreamLocked () EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
701
+ ArrowBaseIterator<Dataset>::NextStreamLocked ();
702
+ if (++current_batch_idx_ < num_batches_) {
703
+ CHECK_ARROW (
704
+ reader_->ReadRecordBatch (current_batch_idx_, ¤t_batch_));
705
+ }
706
+ return Status::OK ();
707
+ }
708
+
709
+ void ResetStreamsLocked () EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
710
+ ArrowBaseIterator<Dataset>::ResetStreamsLocked ();
711
+ reader_.reset ();
712
+ current_batch_idx_ = 0 ;
713
+ num_batches_ = 0 ;
714
+ }
715
+
716
+ std::shared_ptr<arrow::Buffer> buffer_ GUARDED_BY (mu_);
717
+ std::shared_ptr<arrow::io::BufferReader> buffer_reader_ GUARDED_BY (mu_);
718
+ std::shared_ptr<arrow::ipc::RecordBatchFileReader> reader_
719
+ GUARDED_BY (mu_);
720
+ int current_batch_idx_ GUARDED_BY (mu_) = 0;
721
+ int num_batches_ GUARDED_BY (mu_) = 0;
722
+ };
723
+
724
+ const uint8_t * buffer_ptr_;
725
+ const int64 buffer_size_;
726
+ };
727
+ };
728
+
729
+ // Op to create an ArrowSerializedDataset that consumes Arrow record batches
730
+ // serialized in a Tensor buffer.
731
+ class ArrowSerializedDatasetOp : public ArrowOpKernelBase {
732
+ public:
733
+ explicit ArrowSerializedDatasetOp (OpKernelConstruction* ctx)
734
+ : ArrowOpKernelBase(ctx) {}
591
735
592
736
virtual void MakeArrowDataset (
593
737
OpKernelContext* ctx,
@@ -629,7 +773,9 @@ class ArrowDatasetOp : public ArrowOpKernelBase {
629
773
batches_ (std::move(batches_tensor)) {
630
774
}
631
775
632
- string DebugString () const override { return " ArrowDatasetOp::Dataset" ; }
776
+ string DebugString () const override {
777
+ return " ArrowSerializedDatasetOp::Dataset" ;
778
+ }
633
779
634
780
protected:
635
781
Status AsGraphDefInternal (SerializationContext* ctx,
@@ -1009,8 +1155,11 @@ class ArrowStreamDatasetOp : public ArrowOpKernelBase {
1009
1155
};
1010
1156
};
1011
1157
1012
- REGISTER_KERNEL_BUILDER (Name(" ArrowDataset" ).Device(DEVICE_CPU),
1013
- ArrowDatasetOp);
1158
+ REGISTER_KERNEL_BUILDER (Name(" ArrowZeroCopyDataset" ).Device(DEVICE_CPU),
1159
+ ArrowZeroCopyDatasetOp);
1160
+
1161
+ REGISTER_KERNEL_BUILDER (Name(" ArrowSerializedDataset" ).Device(DEVICE_CPU),
1162
+ ArrowSerializedDatasetOp);
1014
1163
1015
1164
REGISTER_KERNEL_BUILDER (Name(" ArrowFeatherDataset" ).Device(DEVICE_CPU),
1016
1165
ArrowFeatherDatasetOp);
0 commit comments