Skip to content

Commit 44c6f48

Browse files
committed
GH-46710: [C++] Fix ownership and lifetime issues in Dataset Writer
1 parent 832bfa1 commit 44c6f48

File tree

2 files changed

+109
-71
lines changed

2 files changed

+109
-71
lines changed

cpp/src/arrow/dataset/dataset_writer.cc

Lines changed: 98 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -131,28 +131,38 @@ Result<std::shared_ptr<FileWriter>> OpenWriter(
131131
{write_options.filesystem, filename});
132132
}
133133

134-
class DatasetWriterFileQueue {
134+
class DatasetWriterFileQueue
135+
: public std::enable_shared_from_this<DatasetWriterFileQueue> {
135136
public:
136137
explicit DatasetWriterFileQueue(const std::shared_ptr<Schema>& schema,
137138
const FileSystemDatasetWriteOptions& options,
138-
DatasetWriterState* writer_state)
139-
: options_(options), schema_(schema), writer_state_(writer_state) {}
139+
std::shared_ptr<DatasetWriterState> writer_state)
140+
: options_(options), schema_(schema), writer_state_(std::move(writer_state)) {}
140141

141-
void Start(util::AsyncTaskScheduler* file_tasks, const std::string& filename) {
142-
file_tasks_ = file_tasks;
142+
void Start(std::unique_ptr<util::ThrottledAsyncTaskScheduler> file_tasks,
143+
std::string filename) {
144+
file_tasks_ = std::move(file_tasks);
143145
// Because the scheduler runs one task at a time we know the writer will
144146
// be opened before any attempt to write
145147
file_tasks_->AddSimpleTask(
146-
[this, filename] {
147-
Executor* io_executor = options_.filesystem->io_context().executor();
148-
return DeferNotOk(io_executor->Submit([this, filename]() {
149-
ARROW_ASSIGN_OR_RAISE(writer_, OpenWriter(options_, schema_, filename));
148+
[self = shared_from_this(), filename = std::move(filename)] {
149+
Executor* io_executor = self->options_.filesystem->io_context().executor();
150+
return DeferNotOk(io_executor->Submit([self, filename = std::move(filename)]() {
151+
ARROW_ASSIGN_OR_RAISE(self->writer_,
152+
OpenWriter(self->options_, self->schema_, filename));
150153
return Status::OK();
151154
}));
152155
},
153156
"DatasetWriter::OpenWriter"sv);
154157
}
155158

159+
void Abort() {
160+
// The scheduler may be keeping this object alive through shared_ptr references
161+
// in async closures. Make sure we break any reference cycles by losing our
162+
// reference to the scheduler.
163+
file_tasks_.reset();
164+
}
165+
156166
Result<std::shared_ptr<RecordBatch>> PopStagedBatch() {
157167
std::vector<std::shared_ptr<RecordBatch>> batches_to_write;
158168
uint64_t num_rows = 0;
@@ -184,7 +194,7 @@ class DatasetWriterFileQueue {
184194

185195
void ScheduleBatch(std::shared_ptr<RecordBatch> batch) {
186196
file_tasks_->AddSimpleTask(
187-
[self = this, batch = std::move(batch)]() {
197+
[self = shared_from_this(), batch = std::move(batch)]() {
188198
return self->WriteNext(std::move(batch));
189199
},
190200
"DatasetWriter::WriteBatch"sv);
@@ -217,21 +227,26 @@ class DatasetWriterFileQueue {
217227
Status Finish() {
218228
writer_state_->staged_rows_count -= rows_currently_staged_;
219229
while (!staged_batches_.empty()) {
220-
RETURN_NOT_OK(PopAndDeliverStagedBatch());
230+
auto st = PopAndDeliverStagedBatch().status();
231+
if (!st.ok()) {
232+
file_tasks_.reset();
233+
return st;
234+
}
221235
}
222236
// At this point all write tasks have been added. Because the scheduler
223237
// is a 1-task FIFO we know this task will run at the very end and can
224238
// add it now.
225-
file_tasks_->AddSimpleTask([this] { return DoFinish(); },
239+
file_tasks_->AddSimpleTask([self = shared_from_this()] { return self->DoFinish(); },
226240
"DatasetWriter::FinishFile"sv);
241+
file_tasks_.reset();
227242
return Status::OK();
228243
}
229244

230245
private:
231246
Future<> WriteNext(std::shared_ptr<RecordBatch> next) {
232247
// May want to prototype / measure someday pushing the async write down further
233248
return DeferNotOk(options_.filesystem->io_context().executor()->Submit(
234-
[self = this, batch = std::move(next)]() {
249+
[self = shared_from_this(), batch = std::move(next)]() {
235250
int64_t rows_to_release = batch->num_rows();
236251
Status status = self->writer_->Write(batch);
237252
self->writer_state_->rows_in_flight_throttle.Release(rows_to_release);
@@ -244,40 +259,48 @@ class DatasetWriterFileQueue {
244259
std::lock_guard<std::mutex> lg(writer_state_->visitors_mutex);
245260
RETURN_NOT_OK(options_.writer_pre_finish(writer_.get()));
246261
}
247-
return writer_->Finish().Then([this]() {
248-
std::lock_guard<std::mutex> lg(writer_state_->visitors_mutex);
249-
return options_.writer_post_finish(writer_.get());
250-
});
262+
return writer_->Finish().Then(
263+
[self = shared_from_this(), writer_post_finish = options_.writer_post_finish]() {
264+
std::lock_guard<std::mutex> lg(self->writer_state_->visitors_mutex);
265+
return writer_post_finish(self->writer_.get());
266+
});
251267
}
252268

253269
const FileSystemDatasetWriteOptions& options_;
254270
const std::shared_ptr<Schema>& schema_;
255-
DatasetWriterState* writer_state_;
271+
std::shared_ptr<DatasetWriterState> writer_state_;
256272
std::shared_ptr<FileWriter> writer_;
257273
// Batches are accumulated here until they are large enough to write out at which
258274
// point they are merged together and added to write_queue_
259275
std::deque<std::shared_ptr<RecordBatch>> staged_batches_;
260276
uint64_t rows_currently_staged_ = 0;
261-
util::AsyncTaskScheduler* file_tasks_ = nullptr;
277+
std::unique_ptr<util::ThrottledAsyncTaskScheduler> file_tasks_;
262278
};
263279

264280
struct WriteTask {
265281
std::string filename;
266282
uint64_t num_rows;
267283
};
268284

269-
class DatasetWriterDirectoryQueue {
285+
class DatasetWriterDirectoryQueue
286+
: public std::enable_shared_from_this<DatasetWriterDirectoryQueue> {
270287
public:
271288
DatasetWriterDirectoryQueue(util::AsyncTaskScheduler* scheduler, std::string directory,
272289
std::string prefix, std::shared_ptr<Schema> schema,
273290
const FileSystemDatasetWriteOptions& write_options,
274-
DatasetWriterState* writer_state)
291+
std::shared_ptr<DatasetWriterState> writer_state)
275292
: scheduler_(std::move(scheduler)),
276293
directory_(std::move(directory)),
277294
prefix_(std::move(prefix)),
278295
schema_(std::move(schema)),
279296
write_options_(write_options),
280-
writer_state_(writer_state) {}
297+
writer_state_(std::move(writer_state)) {}
298+
299+
~DatasetWriterDirectoryQueue() {
300+
if (latest_open_file_) {
301+
latest_open_file_->Abort();
302+
}
303+
}
281304

282305
Result<std::shared_ptr<RecordBatch>> NextWritableChunk(
283306
std::shared_ptr<RecordBatch> batch, std::shared_ptr<RecordBatch>* remainder,
@@ -330,32 +353,27 @@ class DatasetWriterDirectoryQueue {
330353

331354
Status FinishCurrentFile() {
332355
if (latest_open_file_) {
333-
ARROW_RETURN_NOT_OK(latest_open_file_->Finish());
334-
latest_open_file_tasks_.reset();
335-
latest_open_file_ = nullptr;
356+
auto file = std::move(latest_open_file_);
357+
ARROW_RETURN_NOT_OK(file->Finish());
336358
}
337359
rows_written_ = 0;
338360
return GetNextFilename().Value(&current_filename_);
339361
}
340362

341363
Status OpenFileQueue(const std::string& filename) {
342-
auto file_queue =
343-
std::make_unique<DatasetWriterFileQueue>(schema_, write_options_, writer_state_);
344-
latest_open_file_ = file_queue.get();
345-
// Create a dedicated throttle for write jobs to this file and keep it alive until we
346-
// are finished and have closed the file.
347-
auto file_finish_task = [this, file_queue = std::move(file_queue)] {
348-
writer_state_->open_files_throttle.Release(1);
364+
latest_open_file_.reset(
365+
new DatasetWriterFileQueue(schema_, write_options_, writer_state_));
366+
auto file_finish_task = [self = shared_from_this()] {
367+
self->writer_state_->open_files_throttle.Release(1);
349368
return Status::OK();
350369
};
351-
latest_open_file_tasks_ = util::MakeThrottledAsyncTaskGroup(
352-
scheduler_, 1, /*queue=*/nullptr, std::move(file_finish_task));
370+
auto file_tasks = util::MakeThrottledAsyncTaskGroup(scheduler_, 1, /*queue=*/nullptr,
371+
std::move(file_finish_task));
353372
if (init_future_.is_valid()) {
354-
latest_open_file_tasks_->AddSimpleTask(
355-
[init_future = init_future_]() { return init_future; },
356-
"DatasetWriter::WaitForDirectoryInit"sv);
373+
file_tasks->AddSimpleTask([init_future = init_future_]() { return init_future; },
374+
"DatasetWriter::WaitForDirectoryInit"sv);
357375
}
358-
latest_open_file_->Start(latest_open_file_tasks_.get(), filename);
376+
latest_open_file_->Start(std::move(file_tasks), filename);
359377
return Status::OK();
360378
}
361379

@@ -398,41 +416,46 @@ class DatasetWriterDirectoryQueue {
398416
"DatasetWriter::InitializeDirectory"sv);
399417
}
400418

401-
static Result<std::unique_ptr<DatasetWriterDirectoryQueue>> Make(
419+
static Result<std::shared_ptr<DatasetWriterDirectoryQueue>> Make(
402420
util::AsyncTaskScheduler* scheduler,
403421
const FileSystemDatasetWriteOptions& write_options,
404-
DatasetWriterState* writer_state, std::shared_ptr<Schema> schema,
422+
std::shared_ptr<DatasetWriterState> writer_state, std::shared_ptr<Schema> schema,
405423
std::string directory, std::string prefix) {
406-
auto dir_queue = std::make_unique<DatasetWriterDirectoryQueue>(
424+
auto dir_queue = std::make_shared<DatasetWriterDirectoryQueue>(
407425
scheduler, std::move(directory), std::move(prefix), std::move(schema),
408-
write_options, writer_state);
426+
write_options, std::move(writer_state));
409427
dir_queue->PrepareDirectory();
410428
ARROW_ASSIGN_OR_RAISE(dir_queue->current_filename_, dir_queue->GetNextFilename());
411429
return dir_queue;
412430
}
413431

414432
Status Finish() {
415433
if (latest_open_file_) {
416-
ARROW_RETURN_NOT_OK(latest_open_file_->Finish());
417-
latest_open_file_tasks_.reset();
418-
latest_open_file_ = nullptr;
434+
auto file = std::move(latest_open_file_);
435+
ARROW_RETURN_NOT_OK(file->Finish());
419436
}
420437
used_filenames_.clear();
421438
return Status::OK();
422439
}
423440

441+
void Abort() {
442+
if (latest_open_file_) {
443+
latest_open_file_->Abort();
444+
latest_open_file_.reset();
445+
}
446+
}
447+
424448
private:
425449
util::AsyncTaskScheduler* scheduler_ = nullptr;
426450
std::string directory_;
427451
std::string prefix_;
428452
std::shared_ptr<Schema> schema_;
429453
const FileSystemDatasetWriteOptions& write_options_;
430-
DatasetWriterState* writer_state_;
454+
std::shared_ptr<DatasetWriterState> writer_state_;
431455
Future<> init_future_;
432456
std::string current_filename_;
433457
std::unordered_set<std::string> used_filenames_;
434-
DatasetWriterFileQueue* latest_open_file_ = nullptr;
435-
std::unique_ptr<util::ThrottledAsyncTaskScheduler> latest_open_file_tasks_;
458+
std::shared_ptr<DatasetWriterFileQueue> latest_open_file_;
436459
uint64_t rows_written_ = 0;
437460
uint32_t file_counter_ = 0;
438461
};
@@ -520,11 +543,26 @@ class DatasetWriter::DatasetWriterImpl {
520543
return Status::OK();
521544
})),
522545
write_options_(std::move(write_options)),
523-
writer_state_(max_rows_queued, write_options_.max_open_files,
524-
CalculateMaxRowsStaged(max_rows_queued)),
546+
writer_state_(std::make_shared<DatasetWriterState>(
547+
max_rows_queued, write_options_.max_open_files,
548+
CalculateMaxRowsStaged(max_rows_queued))),
525549
pause_callback_(std::move(pause_callback)),
526550
resume_callback_(std::move(resume_callback)) {}
527551

552+
~DatasetWriterImpl() {
553+
// In case something went wrong (e.g. an IO error occurred), some tasks
554+
// may be left dangling in a ThrottledAsyncTaskScheduler and that may
555+
// lead to memory leaks via shared_ptr reference cycles (this can show up
556+
// in some unit tests under Valgrind).
557+
// To prevent this, explicitly break reference cycles at DatasetWriter
558+
// destruction.
559+
// The alternative is to use weak_from_this() thoroughly in async callbacks,
560+
// but that makes for less readable code.
561+
for (const auto& directory_queue : directory_queues_) {
562+
directory_queue.second->Abort();
563+
}
564+
}
565+
528566
Future<> WriteAndCheckBackpressure(std::shared_ptr<RecordBatch> batch,
529567
const std::string& directory,
530568
const std::string& prefix) {
@@ -592,8 +630,10 @@ class DatasetWriter::DatasetWriterImpl {
592630
"DatasetWriter::FinishAll"sv);
593631
// Reset write_tasks_ to signal that we are done adding tasks, this will allow
594632
// us to invoke the finish callback once the tasks wrap up.
595-
std::lock_guard lg(mutex_);
596-
write_tasks_.reset();
633+
{
634+
std::lock_guard lg(mutex_);
635+
write_tasks_.reset();
636+
}
597637
}
598638

599639
protected:
@@ -621,7 +661,7 @@ class DatasetWriter::DatasetWriterImpl {
621661
&directory_queues_, directory + prefix,
622662
[this, &batch, &directory, &prefix](const std::string& key) {
623663
return DatasetWriterDirectoryQueue::Make(scheduler_, write_options_,
624-
&writer_state_, batch->schema(),
664+
writer_state_, batch->schema(),
625665
directory, prefix);
626666
}));
627667
std::shared_ptr<DatasetWriterDirectoryQueue> dir_queue = dir_queue_itr->second;
@@ -643,16 +683,16 @@ class DatasetWriter::DatasetWriterImpl {
643683
continue;
644684
}
645685
backpressure =
646-
writer_state_.rows_in_flight_throttle.Acquire(next_chunk->num_rows());
686+
writer_state_->rows_in_flight_throttle.Acquire(next_chunk->num_rows());
647687
if (!backpressure.is_finished()) {
648688
EVENT_ON_CURRENT_SPAN("DatasetWriter::Backpressure::TooManyRowsQueued");
649689
break;
650690
}
651691
if (will_open_file) {
652-
backpressure = writer_state_.open_files_throttle.Acquire(1);
692+
backpressure = writer_state_->open_files_throttle.Acquire(1);
653693
if (!backpressure.is_finished()) {
654694
EVENT_ON_CURRENT_SPAN("DatasetWriter::Backpressure::TooManyOpenFiles");
655-
writer_state_.rows_in_flight_throttle.Release(next_chunk->num_rows());
695+
writer_state_->rows_in_flight_throttle.Release(next_chunk->num_rows());
656696
RETURN_NOT_OK(TryCloseLargestFile());
657697
break;
658698
}
@@ -664,7 +704,7 @@ class DatasetWriter::DatasetWriterImpl {
664704
//
665705
// `open_files_throttle` will be handed by `DatasetWriterDirectoryQueue`
666706
// so we don't need to release it here.
667-
writer_state_.rows_in_flight_throttle.Release(next_chunk->num_rows());
707+
writer_state_->rows_in_flight_throttle.Release(next_chunk->num_rows());
668708
return s;
669709
}
670710
batch = std::move(remainder);
@@ -685,7 +725,7 @@ class DatasetWriter::DatasetWriterImpl {
685725
std::unique_ptr<util::ThrottledAsyncTaskScheduler> write_tasks_;
686726
Future<> finish_fut_ = Future<>::Make();
687727
FileSystemDatasetWriteOptions write_options_;
688-
DatasetWriterState writer_state_;
728+
std::shared_ptr<DatasetWriterState> writer_state_;
689729
std::function<void()> pause_callback_;
690730
std::function<void()> resume_callback_;
691731
// Map from directory + prefix to the queue for that directory

cpp/src/arrow/util/async_util.cc

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -316,15 +316,11 @@ class ThrottledAsyncTaskSchedulerImpl
316316
#endif
317317
queue_->Push(std::move(task));
318318
lk.unlock();
319-
maybe_backoff->AddCallback(
320-
[weak_self = std::weak_ptr<ThrottledAsyncTaskSchedulerImpl>(
321-
shared_from_this())](const Status& st) {
322-
if (st.ok()) {
323-
if (auto self = weak_self.lock()) {
324-
self->ContinueTasks();
325-
}
326-
}
327-
});
319+
maybe_backoff->AddCallback([weak_self = weak_from_this()](const Status& st) {
320+
if (auto self = weak_self.lock(); self && st.ok()) {
321+
self->ContinueTasks();
322+
}
323+
});
328324
return true;
329325
} else {
330326
lk.unlock();
@@ -350,8 +346,9 @@ class ThrottledAsyncTaskSchedulerImpl
350346
self = shared_from_this()]() mutable -> Result<Future<>> {
351347
ARROW_ASSIGN_OR_RAISE(Future<> inner_fut, (*inner_task)());
352348
if (!inner_fut.TryAddCallback([&] {
353-
return [latched_cost, self = std::move(self)](const Status& st) -> void {
354-
if (st.ok()) {
349+
return [latched_cost,
350+
weak_self = self->weak_from_this()](const Status& st) -> void {
351+
if (auto self = weak_self.lock(); self && st.ok()) {
355352
self->throttle_->Release(latched_cost);
356353
self->ContinueTasks();
357354
}
@@ -360,6 +357,7 @@ class ThrottledAsyncTaskSchedulerImpl
360357
// If the task is already finished then don't run ContinueTasks
361358
// if we are already running it so we can avoid stack overflow
362359
self->throttle_->Release(latched_cost);
360+
inner_task.reset();
363361
if (!in_continue) {
364362
self->ContinueTasks();
365363
}
@@ -377,8 +375,8 @@ class ThrottledAsyncTaskSchedulerImpl
377375
if (maybe_backoff) {
378376
lk.unlock();
379377
if (!maybe_backoff->TryAddCallback([&] {
380-
return [self = shared_from_this()](const Status& st) {
381-
if (st.ok()) {
378+
return [weak_self = weak_from_this()](const Status& st) {
379+
if (auto self = weak_self.lock(); self && st.ok()) {
382380
self->ContinueTasks();
383381
}
384382
};

0 commit comments

Comments
 (0)