Skip to content

Commit

Permalink
GH-41190: [C++] support for single threaded joins (#41125)
Browse files Browse the repository at this point in the history
When I initially added single threading support, I didn't do asof joins and sorted merge joins, because the code for these operations uses threads internally. This is a small check-in to add support for them. Tests run okay in single-threaded, I'm pushing it here to run full tests and check I didn't break the threaded case.

I'm pushing this now because making this work saves adding a load of threading checks in python (this currently breaks single-threaded python i.e. emscripten).
* GitHub Issue: #41190

Lead-authored-by: Joe Marshall <[email protected]>
Co-authored-by: Rossi Sun <[email protected]>
Signed-off-by: Weston Pace <[email protected]>
  • Loading branch information
joemarshall and zanmato1984 authored May 29, 2024
1 parent 774ee0f commit a2453bd
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 33 deletions.
21 changes: 5 additions & 16 deletions cpp/src/arrow/acero/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,8 @@ add_arrow_acero_test(hash_join_node_test SOURCES hash_join_node_test.cc
bloom_filter_test.cc)
add_arrow_acero_test(pivot_longer_node_test SOURCES pivot_longer_node_test.cc)

# asof_join_node and sorted_merge_node use std::thread internally
# and doesn't use ThreadPool so it will
# be broken if threading is turned off
if(ARROW_ENABLE_THREADING)
add_arrow_acero_test(asof_join_node_test SOURCES asof_join_node_test.cc)
add_arrow_acero_test(sorted_merge_node_test SOURCES sorted_merge_node_test.cc)
endif()
add_arrow_acero_test(asof_join_node_test SOURCES asof_join_node_test.cc)
add_arrow_acero_test(sorted_merge_node_test SOURCES sorted_merge_node_test.cc)

add_arrow_acero_test(tpch_node_test SOURCES tpch_node_test.cc)
add_arrow_acero_test(union_node_test SOURCES union_node_test.cc)
Expand Down Expand Up @@ -228,9 +223,7 @@ if(ARROW_BUILD_BENCHMARKS)
add_arrow_acero_benchmark(project_benchmark SOURCES benchmark_util.cc
project_benchmark.cc)

if(ARROW_ENABLE_THREADING)
add_arrow_acero_benchmark(asof_join_benchmark SOURCES asof_join_benchmark.cc)
endif()
add_arrow_acero_benchmark(asof_join_benchmark SOURCES asof_join_benchmark.cc)

add_arrow_acero_benchmark(tpch_benchmark SOURCES tpch_benchmark.cc)

Expand All @@ -253,9 +246,7 @@ if(ARROW_BUILD_BENCHMARKS)
target_link_libraries(arrow-acero-expression-benchmark PUBLIC arrow_acero_static)
target_link_libraries(arrow-acero-filter-benchmark PUBLIC arrow_acero_static)
target_link_libraries(arrow-acero-project-benchmark PUBLIC arrow_acero_static)
if(ARROW_ENABLE_THREADING)
target_link_libraries(arrow-acero-asof-join-benchmark PUBLIC arrow_acero_static)
endif()
target_link_libraries(arrow-acero-asof-join-benchmark PUBLIC arrow_acero_static)
target_link_libraries(arrow-acero-tpch-benchmark PUBLIC arrow_acero_static)
if(ARROW_BUILD_OPENMP_BENCHMARKS)
target_link_libraries(arrow-acero-hash-join-benchmark PUBLIC arrow_acero_static)
Expand All @@ -264,9 +255,7 @@ if(ARROW_BUILD_BENCHMARKS)
target_link_libraries(arrow-acero-expression-benchmark PUBLIC arrow_acero_shared)
target_link_libraries(arrow-acero-filter-benchmark PUBLIC arrow_acero_shared)
target_link_libraries(arrow-acero-project-benchmark PUBLIC arrow_acero_shared)
if(ARROW_ENABLE_THREADING)
target_link_libraries(arrow-acero-asof-join-benchmark PUBLIC arrow_acero_shared)
endif()
target_link_libraries(arrow-acero-asof-join-benchmark PUBLIC arrow_acero_shared)
target_link_libraries(arrow-acero-tpch-benchmark PUBLIC arrow_acero_shared)
if(ARROW_BUILD_OPENMP_BENCHMARKS)
target_link_libraries(arrow-acero-hash-join-benchmark PUBLIC arrow_acero_shared)
Expand Down
85 changes: 77 additions & 8 deletions cpp/src/arrow/acero/asof_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,8 @@ class AsofJoinNode : public ExecNode {
}
}

#ifdef ARROW_ENABLE_THREADING

template <typename Callable>
struct Defer {
Callable callable;
Expand Down Expand Up @@ -1100,6 +1102,7 @@ class AsofJoinNode : public ExecNode {
}

static void ProcessThreadWrapper(AsofJoinNode* node) { node->ProcessThread(); }
#endif

public:
AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector<std::string> input_labels,
Expand Down Expand Up @@ -1131,8 +1134,10 @@ class AsofJoinNode : public ExecNode {
}

virtual ~AsofJoinNode() {
process_.Push(false); // poison pill
#ifdef ARROW_ENABLE_THREADING
PushProcess(false);
process_thread_.join();
#endif
}

const std::vector<col_index_t>& indices_of_on_key() { return indices_of_on_key_; }
Expand Down Expand Up @@ -1410,7 +1415,8 @@ class AsofJoinNode : public ExecNode {
rb->ToString(), DEBUG_MANIP(std::endl));

ARROW_RETURN_NOT_OK(state_.at(k)->Push(rb));
process_.Push(true);
PushProcess(true);

return Status::OK();
}

Expand All @@ -1425,31 +1431,88 @@ class AsofJoinNode : public ExecNode {
// The reason for this is that there are cases at the end of a table where we don't
// know whether the RHS of the join is up-to-date until we know that the table is
// finished.
process_.Push(true);
PushProcess(true);

return Status::OK();
}
void PushProcess(bool value) {
#ifdef ARROW_ENABLE_THREADING
process_.Push(value);
#else
if (value) {
ProcessNonThreaded();
} else if (!process_task_.is_finished()) {
EndFromSingleThread();
}
#endif
}

Status StartProducing() override {
#ifndef ARROW_ENABLE_THREADING
return Status::NotImplemented("ASOF join requires threading enabled");
bool ProcessNonThreaded() {
while (!process_task_.is_finished()) {
Result<std::shared_ptr<RecordBatch>> result = ProcessInner();

if (result.ok()) {
auto out_rb = *result;
if (!out_rb) break;
ExecBatch out_b(*out_rb);
out_b.index = batches_produced_++;
DEBUG_SYNC(this, "produce batch ", out_b.index, ":", DEBUG_MANIP(std::endl),
out_rb->ToString(), DEBUG_MANIP(std::endl));
Status st = output_->InputReceived(this, std::move(out_b));
if (!st.ok()) {
// this isn't really from a thread,
// but we call through to this for consistency
EndFromSingleThread(std::move(st));
return false;
}
} else {
// this isn't really from a thread,
// but we call through to this for consistency
EndFromSingleThread(result.status());
return false;
}
}
auto& lhs = *state_.at(0);
if (lhs.Finished() && !process_task_.is_finished()) {
EndFromSingleThread(Status::OK());
}
return true;
}

void EndFromSingleThread(Status st = Status::OK()) {
process_task_.MarkFinished(st);
if (st.ok()) {
st = output_->InputFinished(this, batches_produced_);
}
for (const auto& s : state_) {
st &= s->ForceShutdown();
}
}

#endif

Status StartProducing() override {
ARROW_ASSIGN_OR_RAISE(process_task_, plan_->query_context()->BeginExternalTask(
"AsofJoinNode::ProcessThread"));
if (!process_task_.is_valid()) {
// Plan has already aborted. Do not start process thread
return Status::OK();
}
#ifdef ARROW_ENABLE_THREADING
process_thread_ = std::thread(&AsofJoinNode::ProcessThreadWrapper, this);
#endif
return Status::OK();
}

void PauseProducing(ExecNode* output, int32_t counter) override {}
void ResumeProducing(ExecNode* output, int32_t counter) override {}

Status StopProducingImpl() override {
#ifdef ARROW_ENABLE_THREADING
process_.Clear();
process_.Push(false);
#endif
PushProcess(false);
return Status::OK();
}

Expand Down Expand Up @@ -1479,11 +1542,13 @@ class AsofJoinNode : public ExecNode {

// Backpressure counter common to all inputs
std::atomic<int32_t> backpressure_counter_;
#ifdef ARROW_ENABLE_THREADING
// Queue for triggering processing of a given input
// (a false value is a poison pill)
ConcurrentQueue<bool> process_;
// Worker thread
std::thread process_thread_;
#endif
Future<> process_task_;

// In-progress batches produced
Expand Down Expand Up @@ -1511,9 +1576,13 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs,
debug_os_(join_options.debug_opts ? join_options.debug_opts->os : nullptr),
debug_mutex_(join_options.debug_opts ? join_options.debug_opts->mutex : nullptr),
#endif
backpressure_counter_(1),
backpressure_counter_(1)
#ifdef ARROW_ENABLE_THREADING
,
process_(),
process_thread_() {
process_thread_()
#endif
{
for (auto& key_hasher : key_hashers_) {
key_hasher->node_ = this;
}
Expand Down
52 changes: 43 additions & 9 deletions cpp/src/arrow/acero/sorted_merge_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,19 +262,22 @@ class SortedMergeNode : public ExecNode {
: ExecNode(plan, inputs, GetInputLabels(inputs), std::move(output_schema)),
ordering_(std::move(new_ordering)),
input_counter(inputs_.size()),
output_counter(inputs_.size()),
process_thread() {
output_counter(inputs_.size())
#ifdef ARROW_ENABLE_THREADING
,
process_thread()
#endif
{
SetLabel("sorted_merge");
}

~SortedMergeNode() override {
process_queue.Push(
kPoisonPill); // poison pill
// We might create a temporary (such as to inspect the output
// schema), in which case there isn't anything to join
PushTask(kPoisonPill);
#ifdef ARROW_ENABLE_THREADING
if (process_thread.joinable()) {
process_thread.join();
}
#endif
}

static arrow::Result<arrow::acero::ExecNode*> Make(
Expand Down Expand Up @@ -355,10 +358,25 @@ class SortedMergeNode : public ExecNode {
// InputState's ConcurrentQueue manages locking
input_counter[index] += rb->num_rows();
ARROW_RETURN_NOT_OK(state[index]->Push(rb));
process_queue.Push(kNewTask);
PushTask(kNewTask);
return Status::OK();
}

void PushTask(bool ok) {
#ifdef ARROW_ENABLE_THREADING
process_queue.Push(ok);
#else
if (process_task.is_finished()) {
return;
}
if (ok == kNewTask) {
PollOnce();
} else {
EndFromProcessThread();
}
#endif
}

arrow::Status InputFinished(arrow::acero::ExecNode* input, int total_batches) override {
ARROW_DCHECK(std_has(inputs_, input));
{
Expand All @@ -368,7 +386,8 @@ class SortedMergeNode : public ExecNode {
state.at(k)->set_total_batches(total_batches);
}
// Trigger a final process call for stragglers
process_queue.Push(kNewTask);
PushTask(kNewTask);

return Status::OK();
}

Expand All @@ -379,13 +398,17 @@ class SortedMergeNode : public ExecNode {
// Plan has already aborted. Do not start process thread
return Status::OK();
}
#ifdef ARROW_ENABLE_THREADING
process_thread = std::thread(&SortedMergeNode::StartPoller, this);
#endif
return Status::OK();
}

arrow::Status StopProducingImpl() override {
#ifdef ARROW_ENABLE_THREADING
process_queue.Clear();
process_queue.Push(kPoisonPill);
#endif
PushTask(kPoisonPill);
return Status::OK();
}

Expand All @@ -408,13 +431,20 @@ class SortedMergeNode : public ExecNode {
<< input_counter[i] << " != " << output_counter[i];
}

#ifdef ARROW_ENABLE_THREADING
ARROW_UNUSED(
plan_->query_context()->executor()->Spawn([this, st = std::move(st)]() mutable {
Defer cleanup([this, &st]() { process_task.MarkFinished(st); });
if (st.ok()) {
st = output_->InputFinished(this, batches_produced);
}
}));
#else
process_task.MarkFinished(st);
if (st.ok()) {
st = output_->InputFinished(this, batches_produced);
}
#endif
}

bool CheckEnded() {
Expand Down Expand Up @@ -552,6 +582,7 @@ class SortedMergeNode : public ExecNode {
return true;
}

#ifdef ARROW_ENABLE_THREADING
void EmitBatches() {
while (true) {
// Implementation note: If the queue is empty, we will block here
Expand All @@ -567,6 +598,7 @@ class SortedMergeNode : public ExecNode {

/// The entry point for processThread
static void StartPoller(SortedMergeNode* node) { node->EmitBatches(); }
#endif

arrow::Ordering ordering_;

Expand All @@ -583,11 +615,13 @@ class SortedMergeNode : public ExecNode {

std::atomic<int32_t> batches_produced{0};

#ifdef ARROW_ENABLE_THREADING
// Queue to trigger processing of a given input. False acts as a poison pill
ConcurrentQueue<bool> process_queue;
// Once StartProducing is called, we initialize this thread to poll the
// input states and emit batches
std::thread process_thread;
#endif
arrow::Future<> process_task;

// Map arg index --> completion counter
Expand Down

0 comments on commit a2453bd

Please sign in to comment.