
Commit 2c39cb3

BryanCutler authored and yongtang committed
Use TFIO ArrowRandomAccessFile to read Arrow Feather files (tensorflow#418)
* Add feather test with prefix
* Use ArrowRandomAccessFile for reading Feather files
1 parent aa5aa14 commit 2c39cb3

File tree

4 files changed: +63, -24 lines

  tensorflow_io/arrow/kernels/arrow_dataset_ops.cc
  tensorflow_io/arrow/kernels/arrow_kernels.h
  tensorflow_io/core/kernels/stream.h
  tests/test_arrow_eager.py

tensorflow_io/arrow/kernels/arrow_dataset_ops.cc

Lines changed: 45 additions & 24 deletions
@@ -17,10 +17,12 @@ limitations under the License.
 #include "arrow/adapters/tensorflow/convert.h"
 #include "arrow/ipc/api.h"
 #include "arrow/util/io-util.h"
-#include "tensorflow_io/arrow/kernels/arrow_stream_client.h"
-#include "tensorflow_io/arrow/kernels/arrow_util.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/graph/graph.h"
+#include "tensorflow_io/core/kernels/stream.h"
+#include "tensorflow_io/arrow/kernels/arrow_kernels.h"
+#include "tensorflow_io/arrow/kernels/arrow_stream_client.h"
+#include "tensorflow_io/arrow/kernels/arrow_util.h"
 
 #define CHECK_ARROW(arrow_status) \
   do { \
@@ -31,6 +33,7 @@ limitations under the License.
   } while (false)
 
 namespace tensorflow {
+namespace data {
 
 enum ArrowBatchMode {
   BATCH_KEEP_REMAINDER,
@@ -294,7 +297,7 @@ class ArrowDatasetBase : public DatasetBase {
 
       // If in initial state, setup and read first batch
       if (current_batch_ == nullptr && current_row_idx_ == 0) {
-        TF_RETURN_IF_ERROR(SetupStreamsLocked());
+        TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
       }
 
       std::vector<Tensor>* result_tensors = out_tensors;
@@ -309,7 +312,7 @@ class ArrowDatasetBase : public DatasetBase {
       // Try to go to next batch if consumed all rows in current batch
       if (current_batch_ != nullptr &&
           current_row_idx_ >= current_batch_->num_rows()) {
-        TF_RETURN_IF_ERROR(NextStreamLocked());
+        TF_RETURN_IF_ERROR(NextStreamLocked(ctx->env()));
       }
 
       // Check if reached end of stream
@@ -465,11 +468,12 @@ class ArrowDatasetBase : public DatasetBase {
     }
 
     // Setup Arrow record batch consumer and initialze current_batch_
-    virtual Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
+    virtual Status SetupStreamsLocked(Env* env)
+        EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
 
     // Get the next Arrow record batch, if available. If not then
     // current_batch_ will be set to nullptr to indicate no further batches.
-    virtual Status NextStreamLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    virtual Status NextStreamLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      current_batch_ = nullptr;
      current_row_idx_ = 0;
      return Status::OK();
@@ -678,7 +682,8 @@ class ArrowZeroCopyDatasetOp : public ArrowOpKernelBase {
          : ArrowBaseIterator<Dataset>(params) {}
 
     private:
-      Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
+      Status SetupStreamsLocked(Env* env)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
        buffer_ = std::make_shared<arrow::Buffer>(
            dataset()->buffer_ptr_,
            dataset()->buffer_size_);
@@ -697,8 +702,9 @@ class ArrowZeroCopyDatasetOp : public ArrowOpKernelBase {
        return Status::OK();
      }
 
-      Status NextStreamLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
-        ArrowBaseIterator<Dataset>::NextStreamLocked();
+      Status NextStreamLocked(Env* env)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
+        ArrowBaseIterator<Dataset>::NextStreamLocked(env);
        if (++current_batch_idx_ < num_batches_) {
          CHECK_ARROW(
              reader_->ReadRecordBatch(current_batch_idx_, &current_batch_));
@@ -818,7 +824,8 @@ class ArrowSerializedDatasetOp : public ArrowOpKernelBase {
          : ArrowBaseIterator<Dataset>(params) {}
 
     private:
-      Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
+      Status SetupStreamsLocked(Env* env)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
        const string& batches = dataset()->batches_.scalar<string>()();
        auto buffer = std::make_shared<arrow::Buffer>(batches);
        auto buffer_reader = std::make_shared<arrow::io::BufferReader>(buffer);
@@ -833,8 +840,9 @@ class ArrowSerializedDatasetOp : public ArrowOpKernelBase {
        return Status::OK();
      }
 
-      Status NextStreamLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
-        ArrowBaseIterator<Dataset>::NextStreamLocked();
+      Status NextStreamLocked(Env* env)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
+        ArrowBaseIterator<Dataset>::NextStreamLocked(env);
        if (++current_batch_idx_ < num_batches_) {
          CHECK_ARROW(
              reader_->ReadRecordBatch(current_batch_idx_, &current_batch_));
@@ -864,8 +872,6 @@ class ArrowSerializedDatasetOp : public ArrowOpKernelBase {
 // ideal for simple writing of Pandas DataFrames.
 class ArrowFeatherDatasetOp : public ArrowOpKernelBase {
  public:
-  //using DatasetOpKernel::DatasetOpKernel;
-
  explicit ArrowFeatherDatasetOp(OpKernelConstruction* ctx)
      : ArrowOpKernelBase(ctx) {}
 
@@ -951,10 +957,22 @@ class ArrowFeatherDatasetOp : public ArrowOpKernelBase {
          : ArrowBaseIterator<Dataset>(params) {}
 
     private:
-      Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
+      Status SetupStreamsLocked(Env* env)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
        const string& filename = dataset()->filenames_[current_file_idx_];
-        std::shared_ptr<arrow::io::ReadableFile> in_file;
-        CHECK_ARROW(arrow::io::ReadableFile::Open(filename, &in_file));
+
+        // Init a TF file from the filename and determine size
+        // TODO: set optional memory to nullptr until input arg is added
+        std::shared_ptr<SizedRandomAccessFile> tf_file(
+            new SizedRandomAccessFile(env, filename, nullptr, 0));
+        uint64 size;
+        TF_RETURN_IF_ERROR(tf_file->GetFileSize(&size));
+
+        // Wrap the TF file in Arrow interface to be used in Feather reader
+        std::shared_ptr<ArrowRandomAccessFile> in_file(
+            new ArrowRandomAccessFile(tf_file.get(), size));
+
+        // Create the Feather reader
        std::unique_ptr<arrow::ipc::feather::TableReader> reader;
        CHECK_ARROW(arrow::ipc::feather::TableReader::Open(in_file, &reader));
 
@@ -982,14 +1000,15 @@ class ArrowFeatherDatasetOp : public ArrowOpKernelBase {
        return Status::OK();
      }
 
-      Status NextStreamLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
-        ArrowBaseIterator<Dataset>::NextStreamLocked();
+      Status NextStreamLocked(Env* env)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
+        ArrowBaseIterator<Dataset>::NextStreamLocked(env);
        if (++current_batch_idx_ < record_batches_.size()) {
          current_batch_ = record_batches_[current_batch_idx_];
        } else if (++current_file_idx_ < dataset()->filenames_.size()) {
          current_batch_idx_ = 0;
          record_batches_.clear();
-          SetupStreamsLocked();
+          return SetupStreamsLocked(env);
        }
        return Status::OK();
      }
@@ -1102,7 +1121,7 @@ class ArrowStreamDatasetOp : public ArrowOpKernelBase {
          : ArrowBaseIterator<Dataset>(params) {}
 
     private:
-      Status SetupStreamsLocked()
+      Status SetupStreamsLocked(Env* env)
          EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
        const string& endpoint = dataset()->endpoints_[current_endpoint_idx_];
        string endpoint_type;
@@ -1128,13 +1147,14 @@ class ArrowStreamDatasetOp : public ArrowOpKernelBase {
        return Status::OK();
      }
 
-      Status NextStreamLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
-        ArrowBaseIterator<Dataset>::NextStreamLocked();
+      Status NextStreamLocked(Env* env)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
+        ArrowBaseIterator<Dataset>::NextStreamLocked(env);
        CHECK_ARROW(reader_->ReadNext(&current_batch_));
        if (current_batch_ == nullptr &&
            ++current_endpoint_idx_ < dataset()->endpoints_.size()) {
          reader_.reset();
-          SetupStreamsLocked();
+          SetupStreamsLocked(env);
        }
        return Status::OK();
      }
@@ -1167,4 +1187,5 @@ REGISTER_KERNEL_BUILDER(Name("ArrowFeatherDataset").Device(DEVICE_CPU),
 REGISTER_KERNEL_BUILDER(Name("ArrowStreamDataset").Device(DEVICE_CPU),
                         ArrowStreamDatasetOp);
 
+}  // namespace data
 }  // namespace tensorflow
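
For context on why Env* is now threaded through SetupStreamsLocked/NextStreamLocked: TensorFlow's Env is the entry point to its filesystem layer, which is what resolves scheme prefixes such as "file://". Below is a minimal sketch of opening a file that way; it uses only stock TensorFlow APIs (Env::NewRandomAccessFile, Env::GetFileSize), and the helper name OpenThroughTFFileSystem is illustrative, not part of this commit.

// Sketch only (not part of the commit): opening a path through
// TensorFlow's filesystem layer, which dispatches on the path's scheme.
#include <memory>
#include <string>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"

tensorflow::Status OpenThroughTFFileSystem(tensorflow::Env* env,
                                           const std::string& filename) {
  // Env resolves the FileSystem registered for the path's prefix.
  std::unique_ptr<tensorflow::RandomAccessFile> file;
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));

  // The size is needed up front so an Arrow-facing wrapper can answer
  // GetSize()/Seek()-style queries for the Feather reader.
  tensorflow::uint64 size = 0;
  TF_RETURN_IF_ERROR(env->GetFileSize(filename, &size));

  // ... hand `file` and `size` to an arrow::io::RandomAccessFile wrapper ...
  return tensorflow::Status::OK();
}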

tensorflow_io/arrow/kernels/arrow_kernels.h

Lines changed: 5 additions & 0 deletions
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#ifndef TENSORFLOW_IO_ARROW_KERNELS_H_
+#define TENSORFLOW_IO_ARROW_KERNELS_H_
+
 #include "kernels/stream.h"
 #include "arrow/io/api.h"
 #include "arrow/buffer.h"
@@ -78,3 +81,5 @@ class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile {
 };
 }  // namespace data
 }  // namespace tensorflow
+
+#endif  // TENSORFLOW_IO_ARROW_KERNELS_H_

tensorflow_io/core/kernels/stream.h

Lines changed: 5 additions & 0 deletions
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#ifndef TENSORFLOW_IO_CORE_KERNELS_STREAM_H_
+#define TENSORFLOW_IO_CORE_KERNELS_STREAM_H_
+
 #include "tensorflow/core/lib/io/inputstream_interface.h"
 #include "tensorflow/core/lib/io/random_inputstream.h"
 
@@ -69,3 +72,5 @@ class SizedRandomAccessFile : public tensorflow::RandomAccessFile {
 
 }  // namespace data
 }  // namespace tensorflow
+
+#endif  // TENSORFLOW_IO_CORE_KERNELS_STREAM_H_

tests/test_arrow_eager.py

Lines changed: 8 additions & 0 deletions
@@ -282,6 +282,14 @@ def test_arrow_feather_dataset(self):
         truth_data.output_shapes)
     self.run_test_case(dataset, truth_data)
 
+    # test single file with 'file://' prefix
+    dataset = arrow_io.ArrowFeatherDataset(
+        "file://{}".format(f.name),
+        list(range(len(truth_data.output_types))),
+        truth_data.output_types,
+        truth_data.output_shapes)
+    self.run_test_case(dataset, truth_data)
+
     # test multiple files
     dataset = arrow_io.ArrowFeatherDataset(
         [f.name, f.name],