@@ -17,10 +17,12 @@ limitations under the License.
17
17
#include " arrow/adapters/tensorflow/convert.h"
18
18
#include " arrow/ipc/api.h"
19
19
#include " arrow/util/io-util.h"
20
- #include " tensorflow_io/arrow/kernels/arrow_stream_client.h"
21
- #include " tensorflow_io/arrow/kernels/arrow_util.h"
22
20
#include " tensorflow/core/framework/dataset.h"
23
21
#include " tensorflow/core/graph/graph.h"
22
+ #include " tensorflow_io/core/kernels/stream.h"
23
+ #include " tensorflow_io/arrow/kernels/arrow_kernels.h"
24
+ #include " tensorflow_io/arrow/kernels/arrow_stream_client.h"
25
+ #include " tensorflow_io/arrow/kernels/arrow_util.h"
24
26
25
27
#define CHECK_ARROW (arrow_status ) \
26
28
do { \
@@ -31,6 +33,7 @@ limitations under the License.
31
33
} while (false )
32
34
33
35
namespace tensorflow {
36
+ namespace data {
34
37
35
38
enum ArrowBatchMode {
36
39
BATCH_KEEP_REMAINDER,
@@ -294,7 +297,7 @@ class ArrowDatasetBase : public DatasetBase {
294
297
295
298
// If in initial state, setup and read first batch
296
299
if (current_batch_ == nullptr && current_row_idx_ == 0 ) {
297
- TF_RETURN_IF_ERROR (SetupStreamsLocked ());
300
+ TF_RETURN_IF_ERROR (SetupStreamsLocked (ctx-> env () ));
298
301
}
299
302
300
303
std::vector<Tensor>* result_tensors = out_tensors;
@@ -309,7 +312,7 @@ class ArrowDatasetBase : public DatasetBase {
309
312
// Try to go to next batch if consumed all rows in current batch
310
313
if (current_batch_ != nullptr &&
311
314
current_row_idx_ >= current_batch_->num_rows ()) {
312
- TF_RETURN_IF_ERROR (NextStreamLocked ());
315
+ TF_RETURN_IF_ERROR (NextStreamLocked (ctx-> env () ));
313
316
}
314
317
315
318
// Check if reached end of stream
@@ -465,11 +468,12 @@ class ArrowDatasetBase : public DatasetBase {
465
468
}
466
469
467
470
// Setup Arrow record batch consumer and initialze current_batch_
468
- virtual Status SetupStreamsLocked () EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
471
+ virtual Status SetupStreamsLocked (Env* env)
472
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
469
473
470
474
// Get the next Arrow record batch, if available. If not then
471
475
// current_batch_ will be set to nullptr to indicate no further batches.
472
- virtual Status NextStreamLocked () EXCLUSIVE_LOCKS_REQUIRED(mu_) {
476
+ virtual Status NextStreamLocked (Env* env ) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
473
477
current_batch_ = nullptr ;
474
478
current_row_idx_ = 0 ;
475
479
return Status::OK ();
@@ -678,7 +682,8 @@ class ArrowZeroCopyDatasetOp : public ArrowOpKernelBase {
678
682
: ArrowBaseIterator<Dataset>(params) {}
679
683
680
684
private:
681
- Status SetupStreamsLocked () EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
685
+ Status SetupStreamsLocked (Env* env)
686
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
682
687
buffer_ = std::make_shared<arrow::Buffer>(
683
688
dataset ()->buffer_ptr_ ,
684
689
dataset ()->buffer_size_ );
@@ -697,8 +702,9 @@ class ArrowZeroCopyDatasetOp : public ArrowOpKernelBase {
697
702
return Status::OK ();
698
703
}
699
704
700
- Status NextStreamLocked () EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
701
- ArrowBaseIterator<Dataset>::NextStreamLocked ();
705
+ Status NextStreamLocked (Env* env)
706
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
707
+ ArrowBaseIterator<Dataset>::NextStreamLocked (env);
702
708
if (++current_batch_idx_ < num_batches_) {
703
709
CHECK_ARROW (
704
710
reader_->ReadRecordBatch (current_batch_idx_, ¤t_batch_));
@@ -818,7 +824,8 @@ class ArrowSerializedDatasetOp : public ArrowOpKernelBase {
818
824
: ArrowBaseIterator<Dataset>(params) {}
819
825
820
826
private:
821
- Status SetupStreamsLocked () EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
827
+ Status SetupStreamsLocked (Env* env)
828
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
822
829
const string& batches = dataset ()->batches_ .scalar <string>()();
823
830
auto buffer = std::make_shared<arrow::Buffer>(batches);
824
831
auto buffer_reader = std::make_shared<arrow::io::BufferReader>(buffer);
@@ -833,8 +840,9 @@ class ArrowSerializedDatasetOp : public ArrowOpKernelBase {
833
840
return Status::OK ();
834
841
}
835
842
836
- Status NextStreamLocked () EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
837
- ArrowBaseIterator<Dataset>::NextStreamLocked ();
843
+ Status NextStreamLocked (Env* env)
844
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
845
+ ArrowBaseIterator<Dataset>::NextStreamLocked (env);
838
846
if (++current_batch_idx_ < num_batches_) {
839
847
CHECK_ARROW (
840
848
reader_->ReadRecordBatch (current_batch_idx_, ¤t_batch_));
@@ -864,8 +872,6 @@ class ArrowSerializedDatasetOp : public ArrowOpKernelBase {
864
872
// ideal for simple writing of Pandas DataFrames.
865
873
class ArrowFeatherDatasetOp : public ArrowOpKernelBase {
866
874
public:
867
- // using DatasetOpKernel::DatasetOpKernel;
868
-
869
875
explicit ArrowFeatherDatasetOp (OpKernelConstruction* ctx)
870
876
: ArrowOpKernelBase(ctx) {}
871
877
@@ -951,10 +957,22 @@ class ArrowFeatherDatasetOp : public ArrowOpKernelBase {
951
957
: ArrowBaseIterator<Dataset>(params) {}
952
958
953
959
private:
954
- Status SetupStreamsLocked () EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
960
+ Status SetupStreamsLocked (Env* env)
961
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
955
962
const string& filename = dataset ()->filenames_ [current_file_idx_];
956
- std::shared_ptr<arrow::io::ReadableFile> in_file;
957
- CHECK_ARROW (arrow::io::ReadableFile::Open (filename, &in_file));
963
+
964
+ // Init a TF file from the filename and determine size
965
+ // TODO: set optional memory to nullptr until input arg is added
966
+ std::shared_ptr<SizedRandomAccessFile> tf_file (
967
+ new SizedRandomAccessFile (env, filename, nullptr , 0 ));
968
+ uint64 size;
969
+ TF_RETURN_IF_ERROR (tf_file->GetFileSize (&size));
970
+
971
+ // Wrap the TF file in Arrow interface to be used in Feather reader
972
+ std::shared_ptr<ArrowRandomAccessFile> in_file (
973
+ new ArrowRandomAccessFile (tf_file.get (), size));
974
+
975
+ // Create the Feather reader
958
976
std::unique_ptr<arrow::ipc::feather::TableReader> reader;
959
977
CHECK_ARROW (arrow::ipc::feather::TableReader::Open (in_file, &reader));
960
978
@@ -982,14 +1000,15 @@ class ArrowFeatherDatasetOp : public ArrowOpKernelBase {
982
1000
return Status::OK ();
983
1001
}
984
1002
985
- Status NextStreamLocked () EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
986
- ArrowBaseIterator<Dataset>::NextStreamLocked ();
1003
+ Status NextStreamLocked (Env* env)
1004
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
1005
+ ArrowBaseIterator<Dataset>::NextStreamLocked (env);
987
1006
if (++current_batch_idx_ < record_batches_.size ()) {
988
1007
current_batch_ = record_batches_[current_batch_idx_];
989
1008
} else if (++current_file_idx_ < dataset ()->filenames_ .size ()) {
990
1009
current_batch_idx_ = 0 ;
991
1010
record_batches_.clear ();
992
- SetupStreamsLocked ();
1011
+ return SetupStreamsLocked (env );
993
1012
}
994
1013
return Status::OK ();
995
1014
}
@@ -1102,7 +1121,7 @@ class ArrowStreamDatasetOp : public ArrowOpKernelBase {
1102
1121
: ArrowBaseIterator<Dataset>(params) {}
1103
1122
1104
1123
private:
1105
- Status SetupStreamsLocked ()
1124
+ Status SetupStreamsLocked (Env* env )
1106
1125
EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
1107
1126
const string& endpoint = dataset ()->endpoints_ [current_endpoint_idx_];
1108
1127
string endpoint_type;
@@ -1128,13 +1147,14 @@ class ArrowStreamDatasetOp : public ArrowOpKernelBase {
1128
1147
return Status::OK ();
1129
1148
}
1130
1149
1131
- Status NextStreamLocked () EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
1132
- ArrowBaseIterator<Dataset>::NextStreamLocked ();
1150
+ Status NextStreamLocked (Env* env)
1151
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
1152
+ ArrowBaseIterator<Dataset>::NextStreamLocked (env);
1133
1153
CHECK_ARROW (reader_->ReadNext (¤t_batch_));
1134
1154
if (current_batch_ == nullptr &&
1135
1155
++current_endpoint_idx_ < dataset ()->endpoints_ .size ()) {
1136
1156
reader_.reset ();
1137
- SetupStreamsLocked ();
1157
+ SetupStreamsLocked (env );
1138
1158
}
1139
1159
return Status::OK ();
1140
1160
}
@@ -1167,4 +1187,5 @@ REGISTER_KERNEL_BUILDER(Name("ArrowFeatherDataset").Device(DEVICE_CPU),
1167
1187
REGISTER_KERNEL_BUILDER (Name(" ArrowStreamDataset" ).Device(DEVICE_CPU),
1168
1188
ArrowStreamDatasetOp);
1169
1189
1190
+ } // namespace data
1170
1191
} // namespace tensorflow
0 commit comments