From c57602849ff665bcc35c52223653ffc350b5d77f Mon Sep 17 00:00:00 2001 From: michalursa Date: Thu, 3 Feb 2022 00:27:13 -0800 Subject: [PATCH 1/2] Faster version of hash join --- cpp/src/arrow/CMakeLists.txt | 2 + cpp/src/arrow/compute/exec/hash_join.cc | 102 +- cpp/src/arrow/compute/exec/hash_join.h | 10 +- .../arrow/compute/exec/hash_join_benchmark.cc | 9 +- cpp/src/arrow/compute/exec/hash_join_node.cc | 54 +- cpp/src/arrow/compute/exec/key_compare.cc | 130 +- cpp/src/arrow/compute/exec/key_compare.h | 40 +- .../arrow/compute/exec/key_compare_avx2.cc | 95 +- cpp/src/arrow/compute/exec/key_encode.cc | 103 +- cpp/src/arrow/compute/exec/key_encode.h | 29 +- cpp/src/arrow/compute/exec/key_hash.cc | 22 + cpp/src/arrow/compute/exec/key_hash.h | 10 + cpp/src/arrow/compute/exec/key_map.cc | 206 +- cpp/src/arrow/compute/exec/key_map.h | 97 +- cpp/src/arrow/compute/exec/key_map_avx2.cc | 54 +- cpp/src/arrow/compute/exec/partition_util.h | 36 + cpp/src/arrow/compute/exec/schema_util.h | 74 +- cpp/src/arrow/compute/exec/swiss_join.cc | 3279 +++++++++++++++++ cpp/src/arrow/compute/exec/swiss_join.h | 874 +++++ cpp/src/arrow/compute/exec/swiss_join_avx2.cc | 198 + cpp/src/arrow/compute/exec/util.h | 49 +- .../arrow/compute/kernels/hash_aggregate.cc | 41 +- 22 files changed, 5127 insertions(+), 387 deletions(-) create mode 100644 cpp/src/arrow/compute/exec/swiss_join.cc create mode 100644 cpp/src/arrow/compute/exec/swiss_join.h create mode 100644 cpp/src/arrow/compute/exec/swiss_join_avx2.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index e9e826097b305..247964afa2a37 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -399,6 +399,7 @@ if(ARROW_COMPUTE) compute/exec/project_node.cc compute/exec/sink_node.cc compute/exec/source_node.cc + compute/exec/swiss_join.cc compute/exec/task_util.cc compute/exec/tpch_node.cc compute/exec/union_node.cc @@ -450,6 +451,7 @@ if(ARROW_COMPUTE) append_avx2_src(compute/exec/key_encode_avx2.cc) append_avx2_src(compute/exec/key_hash_avx2.cc) append_avx2_src(compute/exec/key_map_avx2.cc) + append_avx2_src(compute/exec/swiss_join_avx2.cc) append_avx2_src(compute/exec/util_avx2.cc) list(APPEND ARROW_TESTING_SRCS compute/exec/test_util.cc) diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index 5a9afaa5bdf5f..b6fd0a851882e 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -83,7 +83,8 @@ class HashJoinBasicImpl : public HashJoinImpl { } Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution, - size_t num_threads, HashJoinSchema* schema_mgr, + size_t num_threads, const HashJoinProjectionMaps* proj_map_left, + const HashJoinProjectionMaps* proj_map_right, std::vector key_cmp, Expression filter, OutputBatchCallback output_batch_callback, FinishedCallback finished_callback, @@ -98,7 +99,8 @@ class HashJoinBasicImpl : public HashJoinImpl { ctx_ = ctx; join_type_ = join_type; num_threads_ = num_threads; - schema_mgr_ = schema_mgr; + schema_[0] = proj_map_left; + schema_[1] = proj_map_right; key_cmp_ = std::move(key_cmp); filter_ = std::move(filter); output_batch_callback_ = std::move(output_batch_callback); @@ -139,12 +141,11 @@ class HashJoinBasicImpl : public HashJoinImpl { private: void InitEncoder(int side, HashJoinProjection projection_handle, RowEncoder* encoder) { std::vector data_types; - int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle); + int num_cols = 
schema_[side]->num_cols(projection_handle); data_types.resize(num_cols); for (int icol = 0; icol < num_cols; ++icol) { - data_types[icol] = - ValueDescr(schema_mgr_->proj_maps[side].data_type(projection_handle, icol), - ValueDescr::ARRAY); + data_types[icol] = ValueDescr(schema_[side]->data_type(projection_handle, icol), + ValueDescr::ARRAY); } encoder->Init(data_types, ctx_); encoder->Clear(); @@ -155,8 +156,7 @@ class HashJoinBasicImpl : public HashJoinImpl { ThreadLocalState& local_state = local_states_[thread_index]; if (!local_state.is_initialized) { InitEncoder(0, HashJoinProjection::KEY, &local_state.exec_batch_keys); - bool has_payload = - (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0); + bool has_payload = (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0); if (has_payload) { InitEncoder(0, HashJoinProjection::PAYLOAD, &local_state.exec_batch_payloads); } @@ -168,11 +168,10 @@ class HashJoinBasicImpl : public HashJoinImpl { Status EncodeBatch(int side, HashJoinProjection projection_handle, RowEncoder* encoder, const ExecBatch& batch, ExecBatch* opt_projected_batch = nullptr) { ExecBatch projected({}, batch.length); - int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle); + int num_cols = schema_[side]->num_cols(projection_handle); projected.values.resize(num_cols); - auto to_input = - schema_mgr_->proj_maps[side].map(projection_handle, HashJoinProjection::INPUT); + auto to_input = schema_[side]->map(projection_handle, HashJoinProjection::INPUT); for (int icol = 0; icol < num_cols; ++icol) { projected.values[icol] = batch.values[to_input.get(icol)]; } @@ -235,16 +234,13 @@ class HashJoinBasicImpl : public HashJoinImpl { ExecBatch* opt_left_payload, ExecBatch* opt_right_key, ExecBatch* opt_right_payload) { ExecBatch result({}, batch_size_next); - int num_out_cols_left = - schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT); - int num_out_cols_right = - schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT); + int num_out_cols_left = schema_[0]->num_cols(HashJoinProjection::OUTPUT); + int num_out_cols_right = schema_[1]->num_cols(HashJoinProjection::OUTPUT); result.values.resize(num_out_cols_left + num_out_cols_right); - auto from_key = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT, - HashJoinProjection::KEY); - auto from_payload = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT, - HashJoinProjection::PAYLOAD); + auto from_key = schema_[0]->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY); + auto from_payload = + schema_[0]->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD); for (int icol = 0; icol < num_out_cols_left; ++icol) { bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField()); bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField()); @@ -262,10 +258,9 @@ class HashJoinBasicImpl : public HashJoinImpl { ? 
opt_left_key->values[from_key.get(icol)] : opt_left_payload->values[from_payload.get(icol)]; } - from_key = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT, - HashJoinProjection::KEY); - from_payload = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT, - HashJoinProjection::PAYLOAD); + from_key = schema_[1]->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY); + from_payload = + schema_[1]->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD); for (int icol = 0; icol < num_out_cols_right; ++icol) { bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField()); bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField()); @@ -284,7 +279,7 @@ class HashJoinBasicImpl : public HashJoinImpl { : opt_right_payload->values[from_payload.get(icol)]; } - output_batch_callback_(std::move(result)); + output_batch_callback_(0, std::move(result)); // Update the counter of produced batches // @@ -310,13 +305,13 @@ class HashJoinBasicImpl : public HashJoinImpl { hash_table_keys_.Decode(match_right.size(), match_right.data())); ExecBatch left_payload; - if (!schema_mgr_->LeftPayloadIsEmpty()) { + if (!schema_[0]->is_empty(HashJoinProjection::PAYLOAD)) { ARROW_ASSIGN_OR_RAISE(left_payload, local_state.exec_batch_payloads.Decode( match_left.size(), match_left.data())); } ExecBatch right_payload; - if (!schema_mgr_->RightPayloadIsEmpty()) { + if (!schema_[1]->is_empty(HashJoinProjection::PAYLOAD)) { ARROW_ASSIGN_OR_RAISE(right_payload, hash_table_payloads_.Decode( match_right.size(), match_right.data())); } @@ -336,14 +331,14 @@ class HashJoinBasicImpl : public HashJoinImpl { } }; - SchemaProjectionMap left_to_key = schema_mgr_->proj_maps[0].map( - HashJoinProjection::FILTER, HashJoinProjection::KEY); - SchemaProjectionMap left_to_pay = schema_mgr_->proj_maps[0].map( - HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD); - SchemaProjectionMap right_to_key = schema_mgr_->proj_maps[1].map( - HashJoinProjection::FILTER, HashJoinProjection::KEY); - SchemaProjectionMap right_to_pay = schema_mgr_->proj_maps[1].map( - HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD); + SchemaProjectionMap left_to_key = + schema_[0]->map(HashJoinProjection::FILTER, HashJoinProjection::KEY); + SchemaProjectionMap left_to_pay = + schema_[0]->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD); + SchemaProjectionMap right_to_key = + schema_[1]->map(HashJoinProjection::FILTER, HashJoinProjection::KEY); + SchemaProjectionMap right_to_pay = + schema_[1]->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD); AppendFields(left_to_key, left_to_pay, left_key, left_payload); AppendFields(right_to_key, right_to_pay, right_key, right_payload); @@ -419,15 +414,14 @@ class HashJoinBasicImpl : public HashJoinImpl { bool has_left = (join_type_ != JoinType::RIGHT_SEMI && join_type_ != JoinType::RIGHT_ANTI && - schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT) > 0); + schema_[0]->num_cols(HashJoinProjection::OUTPUT) > 0); bool has_right = (join_type_ != JoinType::LEFT_SEMI && join_type_ != JoinType::LEFT_ANTI && - schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT) > 0); + schema_[1]->num_cols(HashJoinProjection::OUTPUT) > 0); bool has_left_payload = - has_left && (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0); + has_left && (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0); bool has_right_payload = - has_right && - (schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0); + has_right && 
(schema_[1]->num_cols(HashJoinProjection::PAYLOAD) > 0); ThreadLocalState& local_state = local_states_[thread_index]; InitLocalStateIfNeeded(thread_index); @@ -450,7 +444,7 @@ class HashJoinBasicImpl : public HashJoinImpl { ARROW_ASSIGN_OR_RAISE(right_key, hash_table_keys_.Decode(batch_size_next, opt_right_ids)); // Post process build side keys that use dictionary - RETURN_NOT_OK(dict_build_.PostDecode(schema_mgr_->proj_maps[1], &right_key, ctx_)); + RETURN_NOT_OK(dict_build_.PostDecode(*schema_[1], &right_key, ctx_)); } if (has_right_payload) { ARROW_ASSIGN_OR_RAISE(right_payload, @@ -550,8 +544,7 @@ class HashJoinBasicImpl : public HashJoinImpl { RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::KEY, &local_state.exec_batch_keys, batch, &batch_key_for_lookups)); - bool has_left_payload = - (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0); + bool has_left_payload = (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0); if (has_left_payload) { local_state.exec_batch_payloads.Clear(); RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::PAYLOAD, @@ -563,13 +556,13 @@ class HashJoinBasicImpl : public HashJoinImpl { local_state.match_left.clear(); local_state.match_right.clear(); - bool use_key_batch_for_dicts = dict_probe_.BatchRemapNeeded( - thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], ctx_); + bool use_key_batch_for_dicts = + dict_probe_.BatchRemapNeeded(thread_index, *schema_[0], *schema_[1], ctx_); RowEncoder* row_encoder_for_lookups = &local_state.exec_batch_keys; if (use_key_batch_for_dicts) { - RETURN_NOT_OK(dict_probe_.EncodeBatch( - thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], dict_build_, - batch, &row_encoder_for_lookups, &batch_key_for_lookups, ctx_)); + RETURN_NOT_OK(dict_probe_.EncodeBatch(thread_index, *schema_[0], *schema_[1], + dict_build_, batch, &row_encoder_for_lookups, + &batch_key_for_lookups, ctx_)); } // Collect information about all nulls in key columns. 
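Across this file the basic implementation now reads its column mappings from the two HashJoinProjectionMaps handed to Init() and emits results through an output callback whose first argument is a task-group id. The caller-side sketch below shows the reworked wiring; the function name and the synchronous schedule callback are illustrative only, and the continuation signature passed to the scheduler is an assumption, while the argument order follows the benchmark and node changes further below:

#include <functional>
#include <utility>

#include "arrow/compute/exec/hash_join.h"

namespace arrow {
namespace compute {

// Sketch of a caller using the new Init(): the projection maps are passed per
// side and the output callback receives (task_group_id, batch).
Status InitJoinSketch(ExecContext* ctx, JoinType join_type, size_t num_threads,
                      HashJoinSchema* schema_mgr, Expression filter,
                      HashJoinImpl* join) {
  return join->Init(
      ctx, join_type, /*use_sync_execution=*/true, num_threads,
      &schema_mgr->proj_maps[0], &schema_mgr->proj_maps[1], {JoinKeyCmp::EQ},
      std::move(filter),
      /*output_batch_callback=*/[](int64_t /*task_group*/, ExecBatch /*batch*/) {},
      /*finished_callback=*/[](int64_t /*total_num_batches*/) {},
      // Run scheduled tasks inline; assumes the continuation takes a thread index.
      [](std::function<Status(size_t)> task) -> Status {
        return task(/*thread_index=*/0);
      });
}

}  // namespace compute
}  // namespace arrow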
@@ -609,9 +602,8 @@ class HashJoinBasicImpl : public HashJoinImpl { if (batches.empty()) { hash_table_empty_ = true; } else { - dict_build_.InitEncoder(schema_mgr_->proj_maps[1], &hash_table_keys_, ctx_); - bool has_payload = - (schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0); + dict_build_.InitEncoder(*schema_[1], &hash_table_keys_, ctx_); + bool has_payload = (schema_[1]->num_cols(HashJoinProjection::PAYLOAD) > 0); if (has_payload) { InitEncoder(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_); } @@ -626,11 +618,11 @@ class HashJoinBasicImpl : public HashJoinImpl { } else if (hash_table_empty_) { hash_table_empty_ = false; - RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], &batch, ctx_)); + RETURN_NOT_OK(dict_build_.Init(*schema_[1], &batch, ctx_)); } int32_t num_rows_before = hash_table_keys_.num_rows(); - RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, schema_mgr_->proj_maps[1], - batch, &hash_table_keys_, ctx_)); + RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, *schema_[1], batch, + &hash_table_keys_, ctx_)); if (has_payload) { RETURN_NOT_OK( EncodeBatch(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_, batch)); @@ -643,7 +635,7 @@ class HashJoinBasicImpl : public HashJoinImpl { } if (hash_table_empty_) { - RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], nullptr, ctx_)); + RETURN_NOT_OK(dict_build_.Init(*schema_[1], nullptr, ctx_)); } return Status::OK(); @@ -869,7 +861,7 @@ class HashJoinBasicImpl : public HashJoinImpl { ExecContext* ctx_; JoinType join_type_; size_t num_threads_; - HashJoinSchema* schema_mgr_; + const HashJoinProjectionMaps* schema_[2]; std::vector key_cmp_; Expression filter_; std::unique_ptr scheduler_; diff --git a/cpp/src/arrow/compute/exec/hash_join.h b/cpp/src/arrow/compute/exec/hash_join.h index 12455f0c6d021..9aaadfcd5e796 100644 --- a/cpp/src/arrow/compute/exec/hash_join.h +++ b/cpp/src/arrow/compute/exec/hash_join.h @@ -57,6 +57,10 @@ class ARROW_EXPORT HashJoinSchema { const std::string& left_field_name_prefix, const std::string& right_field_name_prefix); + bool HasDictionaries() const; + + bool HasLargeBinary() const; + Result BindFilter(Expression filter, const Schema& left_schema, const Schema& right_schema); std::shared_ptr MakeOutputSchema(const std::string& left_field_name_suffix, @@ -98,12 +102,13 @@ class ARROW_EXPORT HashJoinSchema { class HashJoinImpl { public: - using OutputBatchCallback = std::function; + using OutputBatchCallback = std::function; using FinishedCallback = std::function; virtual ~HashJoinImpl() = default; virtual Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution, - size_t num_threads, HashJoinSchema* schema_mgr, + size_t num_threads, const HashJoinProjectionMaps* proj_map_left, + const HashJoinProjectionMaps* proj_map_right, std::vector key_cmp, Expression filter, OutputBatchCallback output_batch_callback, FinishedCallback finished_callback, @@ -113,6 +118,7 @@ class HashJoinImpl { virtual void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) = 0; static Result> MakeBasic(); + static Result> MakeSwiss(); protected: util::tracing::Span span_; diff --git a/cpp/src/arrow/compute/exec/hash_join_benchmark.cc b/cpp/src/arrow/compute/exec/hash_join_benchmark.cc index 3d4271b6cb9d1..a69de21f92fb4 100644 --- a/cpp/src/arrow/compute/exec/hash_join_benchmark.cc +++ b/cpp/src/arrow/compute/exec/hash_join_benchmark.cc @@ -77,6 +77,8 @@ class JoinBenchmark { build_metadata["null_probability"] = std::to_string(settings.null_percentage); 
build_metadata["min"] = std::to_string(min_build_value); build_metadata["max"] = std::to_string(max_build_value); + build_metadata["min_length"] = "2"; + build_metadata["max_length"] = "20"; std::unordered_map probe_metadata; probe_metadata["null_probability"] = std::to_string(settings.null_percentage); @@ -124,7 +126,7 @@ class JoinBenchmark { DCHECK_OK(schema_mgr_->Init(settings.join_type, *l_batches_.schema, left_keys, *r_batches_.schema, right_keys, filter, "l_", "r_")); - join_ = *HashJoinImpl::MakeBasic(); + join_ = *HashJoinImpl::MakeSwiss(); omp_set_num_threads(settings.num_threads); auto schedule_callback = [](std::function func) -> Status { @@ -135,8 +137,9 @@ class JoinBenchmark { DCHECK_OK(join_->Init( ctx_.get(), settings.join_type, !is_parallel, settings.num_threads, - schema_mgr_.get(), {JoinKeyCmp::EQ}, std::move(filter), [](ExecBatch) {}, - [](int64_t x) {}, schedule_callback)); + &(schema_mgr_->proj_maps[0]), &(schema_mgr_->proj_maps[1]), {JoinKeyCmp::EQ}, + std::move(filter), [](int64_t, ExecBatch) {}, [](int64_t x) {}, + schedule_callback)); } void RunJoin() { diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc index 93e54c6400e57..74259ada37426 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node.cc @@ -453,6 +453,34 @@ Status HashJoinSchema::CollectFilterColumns(std::vector& left_filter, return Status::OK(); } +bool HashJoinSchema::HasDictionaries() const { + for (int side = 0; side <= 1; ++side) { + for (int icol = 0; icol < proj_maps[side].num_cols(HashJoinProjection::INPUT); + ++icol) { + const std::shared_ptr& column_type = + proj_maps[side].data_type(HashJoinProjection::INPUT, icol); + if (column_type->id() == Type::DICTIONARY) { + return true; + } + } + } + return false; +} + +bool HashJoinSchema::HasLargeBinary() const { + for (int side = 0; side <= 1; ++side) { + for (int icol = 0; icol < proj_maps[side].num_cols(HashJoinProjection::INPUT); + ++icol) { + const std::shared_ptr& column_type = + proj_maps[side].data_type(HashJoinProjection::INPUT, icol); + if (is_large_binary_like(column_type->id())) { + return true; + } + } + } + return false; +} + class HashJoinNode : public ExecNode { public: HashJoinNode(ExecPlan* plan, NodeVector inputs, const HashJoinNodeOptions& join_options, @@ -504,8 +532,26 @@ class HashJoinNode : public ExecNode { // Generate output schema std::shared_ptr output_schema = schema_mgr->MakeOutputSchema( join_options.output_suffix_for_left, join_options.output_suffix_for_right); + // Create hash join implementation object - ARROW_ASSIGN_OR_RAISE(std::unique_ptr impl, HashJoinImpl::MakeBasic()); + // SwissJoin does not support: + // a) 64-bit string offsets + // b) residual predicates + // c) dictionaries + // + bool use_swiss_join; +#if ARROW_LITTLE_ENDIAN + use_swiss_join = (filter == literal(true)) && !schema_mgr->HasDictionaries() && + !schema_mgr->HasLargeBinary(); +#else + use_swiss_join = false; +#endif + std::unique_ptr impl; + if (use_swiss_join) { + ARROW_ASSIGN_OR_RAISE(impl, HashJoinImpl::MakeSwiss()); + } else { + ARROW_ASSIGN_OR_RAISE(impl, HashJoinImpl::MakeBasic()); + } return plan->EmplaceNode( plan, inputs, join_options, std::move(output_schema), std::move(schema_mgr), @@ -584,8 +630,10 @@ class HashJoinNode : public ExecNode { RETURN_NOT_OK(impl_->Init( plan_->exec_context(), join_type_, use_sync_execution, num_threads, - schema_mgr_.get(), key_cmp_, filter_, - [this](ExecBatch batch) { 
this->OutputBatchCallback(batch); }, + &(schema_mgr_->proj_maps[0]), &(schema_mgr_->proj_maps[1]), key_cmp_, filter_, + [this](int64_t /*ignored*/, ExecBatch batch) { + this->OutputBatchCallback(batch); + }, [this](int64_t total_num_batches) { this->FinishedCallback(total_num_batches); }, [this](std::function func) -> Status { return this->ScheduleTaskCallback(std::move(func)); diff --git a/cpp/src/arrow/compute/exec/key_compare.cc b/cpp/src/arrow/compute/exec/key_compare.cc index ed94bf72301d9..dfe83b5e06f04 100644 --- a/cpp/src/arrow/compute/exec/key_compare.cc +++ b/cpp/src/arrow/compute/exec/key_compare.cc @@ -30,13 +30,11 @@ namespace arrow { namespace compute { template -void KeyCompare::NullUpdateColumnToRow(uint32_t id_col, uint32_t num_rows_to_compare, - const uint16_t* sel_left_maybe_null, - const uint32_t* left_to_right_map, - KeyEncoder::KeyEncoderContext* ctx, - const KeyEncoder::KeyColumnArray& col, - const KeyEncoder::KeyRowArray& rows, - uint8_t* match_bytevector) { +void KeyCompare::NullUpdateColumnToRow( + uint32_t id_col, uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx, + const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows, + uint8_t* match_bytevector, bool are_cols_in_encoding_order) { if (!rows.has_any_nulls(ctx) && !col.data(0)) { return; } @@ -49,6 +47,9 @@ void KeyCompare::NullUpdateColumnToRow(uint32_t id_col, uint32_t num_rows_to_com } #endif + uint32_t null_bit_id = + are_cols_in_encoding_order ? id_col : rows.metadata().pos_after_encoding(id_col); + if (!col.data(0)) { // Remove rows from the result for which the column value is a null const uint8_t* null_masks = rows.null_masks(); @@ -56,11 +57,12 @@ void KeyCompare::NullUpdateColumnToRow(uint32_t id_col, uint32_t num_rows_to_com for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) { uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i; uint32_t irow_right = left_to_right_map[irow_left]; - int64_t bitid = irow_right * null_mask_num_bytes * 8 + id_col; + int64_t bitid = irow_right * null_mask_num_bytes * 8 + null_bit_id; match_bytevector[i] &= (bit_util::GetBit(null_masks, bitid) ? 0 : 0xff); } } else if (!rows.has_any_nulls(ctx)) { - // Remove rows from the result for which the column value on left side is null + // Remove rows from the result for which the column value on left side is + // null const uint8_t* non_nulls = col.data(0); ARROW_DCHECK(non_nulls); for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) { @@ -76,7 +78,7 @@ void KeyCompare::NullUpdateColumnToRow(uint32_t id_col, uint32_t num_rows_to_com for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) { uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i; uint32_t irow_right = left_to_right_map[irow_left]; - int64_t bitid_right = irow_right * null_mask_num_bytes * 8 + id_col; + int64_t bitid_right = irow_right * null_mask_num_bytes * 8 + null_bit_id; int right_null = bit_util::GetBit(null_masks, bitid_right) ? 0xff : 0; int left_null = bit_util::GetBit(non_nulls, irow_left + col.bit_offset(0)) ? 
0 : 0xff; @@ -228,25 +230,17 @@ void KeyCompare::CompareBinaryColumnToRow( // Overwrites the match_bytevector instead of updating it template -void KeyCompare::CompareVarBinaryColumnToRow( - uint32_t id_varbinary_col, uint32_t num_rows_to_compare, - const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map, - KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col, - const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) { -#if defined(ARROW_HAVE_AVX2) - if (ctx->has_avx2()) { - CompareVarBinaryColumnToRow_avx2( - use_selection, is_first_varbinary_col, id_varbinary_col, num_rows_to_compare, - sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector); - return; - } -#endif - +void KeyCompare::CompareVarBinaryColumnToRowHelper( + uint32_t id_varbinary_col, uint32_t first_row_to_compare, + uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx, + const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows, + uint8_t* match_bytevector) { const uint32_t* offsets_left = col.offsets(); const uint32_t* offsets_right = rows.offsets(); const uint8_t* rows_left = col.data(2); const uint8_t* rows_right = rows.data(2); - for (uint32_t i = 0; i < num_rows_to_compare; ++i) { + for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) { uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i; uint32_t irow_right = left_to_right_map[irow_left]; uint32_t begin_left = offsets_left[irow_left]; @@ -290,6 +284,27 @@ void KeyCompare::CompareVarBinaryColumnToRow( } } +// Overwrites the match_bytevector instead of updating it +template +void KeyCompare::CompareVarBinaryColumnToRow( + uint32_t id_varbinary_col, uint32_t num_rows_to_compare, + const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map, + KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col, + const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) { + uint32_t num_processed = 0; +#if defined(ARROW_HAVE_AVX2) + if (ctx->has_avx2()) { + num_processed = CompareVarBinaryColumnToRow_avx2( + use_selection, is_first_varbinary_col, id_varbinary_col, num_rows_to_compare, + sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector); + } +#endif + + CompareVarBinaryColumnToRowHelper( + id_varbinary_col, num_processed, num_rows_to_compare, sel_left_maybe_null, + left_to_right_map, ctx, col, rows, match_bytevector); +} + void KeyCompare::AndByteVectors(KeyEncoder::KeyEncoderContext* ctx, uint32_t num_elements, uint8_t* bytevector_A, const uint8_t* bytevector_B) { uint32_t num_processed = 0; @@ -306,14 +321,13 @@ void KeyCompare::AndByteVectors(KeyEncoder::KeyEncoderContext* ctx, uint32_t num } } -void KeyCompare::CompareColumnsToRows(uint32_t num_rows_to_compare, - const uint16_t* sel_left_maybe_null, - const uint32_t* left_to_right_map, - KeyEncoder::KeyEncoderContext* ctx, - uint32_t* out_num_rows, - uint16_t* out_sel_left_maybe_same, - const std::vector& cols, - const KeyEncoder::KeyRowArray& rows) { +void KeyCompare::CompareColumnsToRows( + uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx, + uint32_t* out_num_rows, uint16_t* out_sel_left_maybe_same, + const std::vector& cols, + const KeyEncoder::KeyRowArray& rows, bool are_cols_in_encoding_order, + uint8_t* out_match_bitvector_maybe_null) { if (num_rows_to_compare == 0) { *out_num_rows = 0; 
return; @@ -334,6 +348,7 @@ void KeyCompare::CompareColumnsToRows(uint32_t num_rows_to_compare, bool is_first_column = true; for (size_t icol = 0; icol < cols.size(); ++icol) { const KeyEncoder::KeyColumnArray& col = cols[icol]; + if (col.metadata().is_null_type) { // If this null type col is the first column, the match_bytevector_A needs to be // initialized with 0xFF. Otherwise, the calculation can be skipped @@ -342,8 +357,11 @@ void KeyCompare::CompareColumnsToRows(uint32_t num_rows_to_compare, } continue; } - uint32_t offset_within_row = - rows.metadata().encoded_field_offset(static_cast(icol)); + + uint32_t offset_within_row = rows.metadata().encoded_field_offset( + are_cols_in_encoding_order + ? static_cast(icol) + : rows.metadata().pos_after_encoding(static_cast(icol))); if (col.metadata().is_fixed_length) { if (sel_left_maybe_null) { CompareBinaryColumnToRow( @@ -353,7 +371,8 @@ void KeyCompare::CompareColumnsToRows(uint32_t num_rows_to_compare, NullUpdateColumnToRow( static_cast(icol), num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, - is_first_column ? match_bytevector_A : match_bytevector_B); + is_first_column ? match_bytevector_A : match_bytevector_B, + are_cols_in_encoding_order); } else { // Version without using selection vector CompareBinaryColumnToRow( @@ -363,7 +382,8 @@ void KeyCompare::CompareColumnsToRows(uint32_t num_rows_to_compare, NullUpdateColumnToRow( static_cast(icol), num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, - is_first_column ? match_bytevector_A : match_bytevector_B); + is_first_column ? match_bytevector_A : match_bytevector_B, + are_cols_in_encoding_order); } if (!is_first_column) { AndByteVectors(ctx, num_rows_to_compare, match_bytevector_A, match_bytevector_B); @@ -390,7 +410,8 @@ void KeyCompare::CompareColumnsToRows(uint32_t num_rows_to_compare, NullUpdateColumnToRow( static_cast(icol), num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, - is_first_column ? match_bytevector_A : match_bytevector_B); + is_first_column ? match_bytevector_A : match_bytevector_B, + are_cols_in_encoding_order); } else { if (ivarbinary == 0) { CompareVarBinaryColumnToRow( @@ -404,7 +425,8 @@ void KeyCompare::CompareColumnsToRows(uint32_t num_rows_to_compare, NullUpdateColumnToRow( static_cast(icol), num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, - is_first_column ? match_bytevector_A : match_bytevector_B); + is_first_column ? 
match_bytevector_A : match_bytevector_B, + are_cols_in_encoding_order); } if (!is_first_column) { AndByteVectors(ctx, num_rows_to_compare, match_bytevector_A, match_bytevector_B); @@ -416,18 +438,26 @@ void KeyCompare::CompareColumnsToRows(uint32_t num_rows_to_compare, util::bit_util::bytes_to_bits(ctx->hardware_flags, num_rows_to_compare, match_bytevector_A, match_bitvector); - if (sel_left_maybe_null) { - int out_num_rows_int; - util::bit_util::bits_filter_indexes(0, ctx->hardware_flags, num_rows_to_compare, - match_bitvector, sel_left_maybe_null, - &out_num_rows_int, out_sel_left_maybe_same); - *out_num_rows = out_num_rows_int; + + if (out_match_bitvector_maybe_null) { + ARROW_DCHECK(out_num_rows == nullptr); + ARROW_DCHECK(out_sel_left_maybe_same == nullptr); + memcpy(out_match_bitvector_maybe_null, match_bitvector, + bit_util::BytesForBits(num_rows_to_compare)); } else { - int out_num_rows_int; - util::bit_util::bits_to_indexes(0, ctx->hardware_flags, num_rows_to_compare, - match_bitvector, &out_num_rows_int, - out_sel_left_maybe_same); - *out_num_rows = out_num_rows_int; + if (sel_left_maybe_null) { + int out_num_rows_int; + util::bit_util::bits_filter_indexes(0, ctx->hardware_flags, num_rows_to_compare, + match_bitvector, sel_left_maybe_null, + &out_num_rows_int, out_sel_left_maybe_same); + *out_num_rows = out_num_rows_int; + } else { + int out_num_rows_int; + util::bit_util::bits_to_indexes(0, ctx->hardware_flags, num_rows_to_compare, + match_bitvector, &out_num_rows_int, + out_sel_left_maybe_same); + *out_num_rows = out_num_rows_int; + } } } diff --git a/cpp/src/arrow/compute/exec/key_compare.h b/cpp/src/arrow/compute/exec/key_compare.h index aeb5abbdd144d..0aeccdb4a67d3 100644 --- a/cpp/src/arrow/compute/exec/key_compare.h +++ b/cpp/src/arrow/compute/exec/key_compare.h @@ -31,26 +31,23 @@ namespace compute { class KeyCompare { public: // Returns a single 16-bit selection vector of rows that failed comparison. - // If there is input selection on the left, the resulting selection is a filtered image - // of input selection. - static void CompareColumnsToRows(uint32_t num_rows_to_compare, - const uint16_t* sel_left_maybe_null, - const uint32_t* left_to_right_map, - KeyEncoder::KeyEncoderContext* ctx, - uint32_t* out_num_rows, - uint16_t* out_sel_left_maybe_same, - const std::vector& cols, - const KeyEncoder::KeyRowArray& rows); + // If there is input selection on the left, the resulting selection is a + // filtered image of input selection. 
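  // When are_cols_in_encoding_order is true the caller provides the columns
  // already permuted into row-encoding order, so field offsets and null-mask
  // bit positions can be taken directly from the column index; otherwise they
  // are translated through KeyRowMetadata::pos_after_encoding().
  // When out_match_bitvector_maybe_null is non-null, out_num_rows and
  // out_sel_left_maybe_same must be null and the result is written as a match
  // bit-vector instead of a selection vector of rows that failed comparison.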
+ static void CompareColumnsToRows( + uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx, + uint32_t* out_num_rows, uint16_t* out_sel_left_maybe_same, + const std::vector& cols, + const KeyEncoder::KeyRowArray& rows, bool are_cols_in_encoding_order, + uint8_t* out_match_bitvector_maybe_null = NULLPTR); private: template - static void NullUpdateColumnToRow(uint32_t id_col, uint32_t num_rows_to_compare, - const uint16_t* sel_left_maybe_null, - const uint32_t* left_to_right_map, - KeyEncoder::KeyEncoderContext* ctx, - const KeyEncoder::KeyColumnArray& col, - const KeyEncoder::KeyRowArray& rows, - uint8_t* match_bytevector); + static void NullUpdateColumnToRow( + uint32_t id_col, uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx, + const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows, + uint8_t* match_bytevector, bool are_cols_in_encoding_order); template static void CompareBinaryColumnToRowHelper( @@ -67,6 +64,13 @@ class KeyCompare { KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector); + template + static void CompareVarBinaryColumnToRowHelper( + uint32_t id_varlen_col, uint32_t first_row_to_compare, uint32_t num_rows_to_compare, + const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map, + KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col, + const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector); + template static void CompareVarBinaryColumnToRow( uint32_t id_varlen_col, uint32_t num_rows_to_compare, @@ -123,7 +127,7 @@ class KeyCompare { KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector); - static void CompareVarBinaryColumnToRow_avx2( + static uint32_t CompareVarBinaryColumnToRow_avx2( bool use_selection, bool is_first_varbinary_col, uint32_t id_varlen_col, uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx, diff --git a/cpp/src/arrow/compute/exec/key_compare_avx2.cc b/cpp/src/arrow/compute/exec/key_compare_avx2.cc index df13e8cae3c26..db0f0b3fadd96 100644 --- a/cpp/src/arrow/compute/exec/key_compare_avx2.cc +++ b/cpp/src/arrow/compute/exec/key_compare_avx2.cc @@ -45,6 +45,9 @@ uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2( if (!rows.has_any_nulls(ctx) && !col.data(0)) { return num_rows_to_compare; } + + uint32_t null_bit_id = rows.metadata().pos_after_encoding(id_col); + if (!col.data(0)) { // Remove rows from the result for which the column value is a null const uint8_t* null_masks = rows.null_masks(); @@ -64,7 +67,7 @@ uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2( } __m256i bitid = _mm256_mullo_epi32(irow_right, _mm256_set1_epi32(null_mask_num_bytes * 8)); - bitid = _mm256_add_epi32(bitid, _mm256_set1_epi32(id_col)); + bitid = _mm256_add_epi32(bitid, _mm256_set1_epi32(null_bit_id)); __m256i right = _mm256_i32gather_epi32((const int*)null_masks, _mm256_srli_epi32(bitid, 3), 1); right = _mm256_and_si256( @@ -81,7 +84,8 @@ uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2( num_processed = num_rows_to_compare / unroll * unroll; return num_processed; } else if (!rows.has_any_nulls(ctx)) { - // Remove rows from the result for which the column value on left side is null + // Remove 
rows from the result for which the column value on left side is + // null const uint8_t* non_nulls = col.data(0); ARROW_DCHECK(non_nulls); uint32_t num_processed = 0; @@ -146,7 +150,7 @@ uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2( } __m256i bitid = _mm256_mullo_epi32(irow_right, _mm256_set1_epi32(null_mask_num_bytes * 8)); - bitid = _mm256_add_epi32(bitid, _mm256_set1_epi32(id_col)); + bitid = _mm256_add_epi32(bitid, _mm256_set1_epi32(null_bit_id)); __m256i right = _mm256_i32gather_epi32((const int*)null_masks, _mm256_srli_epi32(bitid, 3), 1); right = _mm256_and_si256( @@ -254,22 +258,22 @@ inline uint64_t CompareSelected8_avx2(const uint8_t* left_base, const uint8_t* r int bit_offset = 0) { __m256i left; switch (column_width) { - case 0: + case 0: { irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(bit_offset)); left = _mm256_i32gather_epi32((const int*)left_base, - _mm256_srli_epi32(irow_left, 3), 1); - left = _mm256_and_si256( - _mm256_set1_epi32(1), - _mm256_srlv_epi32(left, _mm256_and_si256(irow_left, _mm256_set1_epi32(7)))); - left = _mm256_mullo_epi32(left, _mm256_set1_epi32(0xff)); - break; + _mm256_srli_epi32(irow_left, 5), 4); + __m256i bit_selection = _mm256_sllv_epi32( + _mm256_set1_epi32(1), _mm256_and_si256(irow_left, _mm256_set1_epi32(31))); + left = _mm256_cmpeq_epi32(bit_selection, _mm256_and_si256(left, bit_selection)); + left = _mm256_and_si256(left, _mm256_set1_epi32(0xff)); + } break; case 1: left = _mm256_i32gather_epi32((const int*)left_base, irow_left, 1); left = _mm256_and_si256(left, _mm256_set1_epi32(0xff)); break; case 2: left = _mm256_i32gather_epi32((const int*)left_base, irow_left, 2); - left = _mm256_and_si256(left, _mm256_set1_epi32(0xff)); + left = _mm256_and_si256(left, _mm256_set1_epi32(0xffff)); break; case 4: left = _mm256_i32gather_epi32((const int*)left_base, irow_left, 4); @@ -313,15 +317,15 @@ inline uint64_t Compare8_avx2(const uint8_t* left_base, const uint8_t* right_bas } break; case 1: left = _mm256_cvtepu8_epi32(_mm_set1_epi64x( - reinterpret_cast(left_base)[irow_left_first / 8])); + *reinterpret_cast(left_base + irow_left_first))); break; case 2: left = _mm256_cvtepu16_epi32(_mm_loadu_si128( - reinterpret_cast(left_base) + irow_left_first / 8)); + reinterpret_cast(left_base + 2 * irow_left_first))); break; case 4: - left = _mm256_loadu_si256(reinterpret_cast(left_base) + - irow_left_first / 8); + left = _mm256_loadu_si256( + reinterpret_cast(left_base + 4 * irow_left_first)); break; default: ARROW_DCHECK(false); @@ -349,19 +353,17 @@ inline uint64_t Compare8_64bit_avx2(const uint8_t* left_base, const uint8_t* rig __m256i offset_right) { auto left_base_i64 = reinterpret_cast(left_base); - __m256i left_lo = - _mm256_i32gather_epi64(left_base_i64, _mm256_castsi256_si128(irow_left), 8); - __m256i left_hi = - _mm256_i32gather_epi64(left_base_i64, _mm256_extracti128_si256(irow_left, 1), 8); + __m256i left_lo, left_hi; if (use_selection) { left_lo = _mm256_i32gather_epi64(left_base_i64, _mm256_castsi256_si128(irow_left), 8); left_hi = _mm256_i32gather_epi64(left_base_i64, _mm256_extracti128_si256(irow_left, 1), 8); } else { - left_lo = _mm256_loadu_si256(reinterpret_cast(left_base) + - irow_left_first / 4); - left_hi = _mm256_loadu_si256(reinterpret_cast(left_base) + - irow_left_first / 4 + 1); + left_lo = _mm256_loadu_si256( + reinterpret_cast(left_base + irow_left_first * sizeof(uint64_t))); + left_hi = _mm256_loadu_si256( + reinterpret_cast(left_base + irow_left_first * sizeof(uint64_t)) + + 1); } auto right_base_i64 = 
reinterpret_cast(right_base); @@ -534,7 +536,7 @@ void KeyCompare::CompareVarBinaryColumnToRowImp_avx2( const __m256i* key_right_ptr = reinterpret_cast(rows_right + begin_right); int32_t j; - // length can be zero + // length is greater than zero for (j = 0; j < (static_cast(length) + 31) / 32 - 1; ++j) { __m256i key_left = _mm256_loadu_si256(key_left_ptr + j); __m256i key_right = _mm256_loadu_si256(key_right_ptr + j); @@ -571,6 +573,15 @@ uint32_t KeyCompare::NullUpdateColumnToRow_avx2( const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) { + int64_t num_rows_safe = + TailSkipForSIMD::FixBitAccess(sizeof(uint32_t), col.length(), col.bit_offset(0)); + if (sel_left_maybe_null) { + num_rows_to_compare = static_cast(TailSkipForSIMD::FixSelection( + num_rows_safe, static_cast(num_rows_to_compare), sel_left_maybe_null)); + } else { + num_rows_to_compare = static_cast(num_rows_safe); + } + if (use_selection) { return NullUpdateColumnToRowImp_avx2(id_col, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, @@ -587,6 +598,29 @@ uint32_t KeyCompare::CompareBinaryColumnToRow_avx2( const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) { + uint32_t col_width = col.metadata().fixed_length; + int64_t num_rows_safe = col.length(); + if (col_width == 0) { + // In this case we will access left column memory 4B at a time + num_rows_safe = + TailSkipForSIMD::FixBitAccess(sizeof(uint32_t), col.length(), col.bit_offset(1)); + } else if (col_width == 1 && col_width == 2) { + // In this case we will access left column memory 4B at a time + num_rows_safe = + TailSkipForSIMD::FixBinaryAccess(sizeof(uint32_t), col.length(), col_width); + } else if (col_width != 4 && col_width != 8) { + // In this case we will access left column memory 32B at a time + num_rows_safe = + TailSkipForSIMD::FixBinaryAccess(sizeof(__m256i), col.length(), col_width); + } + if (sel_left_maybe_null) { + num_rows_to_compare = static_cast(TailSkipForSIMD::FixSelection( + num_rows_safe, static_cast(num_rows_to_compare), sel_left_maybe_null)); + } else { + num_rows_to_compare = static_cast( + std::min(num_rows_safe, static_cast(num_rows_to_compare))); + } + if (use_selection) { return CompareBinaryColumnToRowImp_avx2(offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, @@ -598,12 +632,21 @@ uint32_t KeyCompare::CompareBinaryColumnToRow_avx2( } } -void KeyCompare::CompareVarBinaryColumnToRow_avx2( +uint32_t KeyCompare::CompareVarBinaryColumnToRow_avx2( bool use_selection, bool is_first_varbinary_col, uint32_t id_varlen_col, uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) { + int64_t num_rows_safe = + TailSkipForSIMD::FixVarBinaryAccess(sizeof(__m256i), col.length(), col.offsets()); + if (use_selection) { + num_rows_to_compare = static_cast(TailSkipForSIMD::FixSelection( + num_rows_safe, static_cast(num_rows_to_compare), sel_left_maybe_null)); + } else { + num_rows_to_compare = static_cast(num_rows_safe); + } + if (use_selection) { if (is_first_varbinary_col) { CompareVarBinaryColumnToRowImp_avx2( @@ -625,6 +668,8 
@@ void KeyCompare::CompareVarBinaryColumnToRow_avx2( col, rows, match_bytevector); } } + + return num_rows_to_compare; } #endif diff --git a/cpp/src/arrow/compute/exec/key_encode.cc b/cpp/src/arrow/compute/exec/key_encode.cc index f8bd7c2503ea3..e567b7945161d 100644 --- a/cpp/src/arrow/compute/exec/key_encode.cc +++ b/cpp/src/arrow/compute/exec/key_encode.cc @@ -21,8 +21,10 @@ #include +#include "arrow/compute/exec.h" #include "arrow/compute/exec/util.h" #include "arrow/util/bit_util.h" +#include "arrow/util/checked_cast.h" #include "arrow/util/ubsan.h" namespace arrow { @@ -885,6 +887,10 @@ void KeyEncoder::KeyRowMetadata::FromColumnMetadataVector( } return left < right; }); + inverse_column_order.resize(num_cols); + for (uint32_t i = 0; i < num_cols; ++i) { + inverse_column_order[column_order[i]] = i; + } row_alignment = in_row_alignment; string_alignment = in_string_alignment; @@ -936,9 +942,8 @@ void KeyEncoder::KeyRowMetadata::FromColumnMetadataVector( } } -void KeyEncoder::Init(const std::vector& cols, KeyEncoderContext* ctx, - int row_alignment, int string_alignment) { - ctx_ = ctx; +void KeyEncoder::Init(const std::vector& cols, int row_alignment, + int string_alignment) { row_metadata_.FromColumnMetadataVector(cols, row_alignment, string_alignment); uint32_t num_cols = row_metadata_.num_cols(); uint32_t num_varbinary_cols = row_metadata_.num_varbinary_cols(); @@ -974,18 +979,24 @@ void KeyEncoder::PrepareKeyColumnArrays(int64_t start_row, int64_t num_rows, void KeyEncoder::DecodeFixedLengthBuffers(int64_t start_row_input, int64_t start_row_output, int64_t num_rows, const KeyRowArray& rows, - std::vector* cols) { + std::vector* cols, + int64_t hardware_flags, + util::TempVectorStack* temp_stack) { // Prepare column array vectors PrepareKeyColumnArrays(start_row_output, num_rows, *cols); + KeyEncoderContext ctx; + ctx.hardware_flags = hardware_flags; + ctx.stack = temp_stack; + // Create two temp vectors with 16-bit elements auto temp_buffer_holder_A = - util::TempVectorHolder(ctx_->stack, static_cast(num_rows)); + util::TempVectorHolder(ctx.stack, static_cast(num_rows)); auto temp_buffer_A = KeyColumnArray( KeyColumnMetadata(true, sizeof(uint16_t)), num_rows, nullptr, reinterpret_cast(temp_buffer_holder_A.mutable_data()), nullptr); auto temp_buffer_holder_B = - util::TempVectorHolder(ctx_->stack, static_cast(num_rows)); + util::TempVectorHolder(ctx.stack, static_cast(num_rows)); auto temp_buffer_B = KeyColumnArray( KeyColumnMetadata(true, sizeof(uint16_t)), num_rows, nullptr, reinterpret_cast(temp_buffer_holder_B.mutable_data()), nullptr); @@ -994,7 +1005,7 @@ void KeyEncoder::DecodeFixedLengthBuffers(int64_t start_row_input, if (!is_row_fixed_length) { EncoderOffsets::Decode(static_cast(start_row_input), static_cast(num_rows), rows, &batch_varbinary_cols_, - batch_varbinary_cols_base_offsets_, ctx_); + batch_varbinary_cols_base_offsets_, &ctx); } // Process fixed length columns @@ -1013,13 +1024,13 @@ void KeyEncoder::DecodeFixedLengthBuffers(int64_t start_row_input, EncoderBinary::Decode(static_cast(start_row_input), static_cast(num_rows), row_metadata_.column_offsets[i], rows, &batch_all_cols_[i], - ctx_, &temp_buffer_A); + &ctx, &temp_buffer_A); i += 1; } else { EncoderBinaryPair::Decode( static_cast(start_row_input), static_cast(num_rows), row_metadata_.column_offsets[i], rows, &batch_all_cols_[i], - &batch_all_cols_[i + 1], ctx_, &temp_buffer_A, &temp_buffer_B); + &batch_all_cols_[i + 1], &ctx, &temp_buffer_A, &temp_buffer_B); i += 2; } } @@ -1032,10 +1043,16 @@ void 
KeyEncoder::DecodeFixedLengthBuffers(int64_t start_row_input, void KeyEncoder::DecodeVaryingLengthBuffers(int64_t start_row_input, int64_t start_row_output, int64_t num_rows, const KeyRowArray& rows, - std::vector* cols) { + std::vector* cols, + int64_t hardware_flags, + util::TempVectorStack* temp_stack) { // Prepare column array vectors PrepareKeyColumnArrays(start_row_output, num_rows, *cols); + KeyEncoderContext ctx; + ctx.hardware_flags = hardware_flags; + ctx.stack = temp_stack; + bool is_row_fixed_length = row_metadata_.is_fixed_length; if (!is_row_fixed_length) { for (size_t i = 0; i < batch_varbinary_cols_.size(); ++i) { @@ -1043,7 +1060,7 @@ void KeyEncoder::DecodeVaryingLengthBuffers(int64_t start_row_input, // positions in the output row buffer. EncoderVarBinary::Decode(static_cast(start_row_input), static_cast(num_rows), static_cast(i), - rows, &batch_varbinary_cols_[i], ctx_); + rows, &batch_varbinary_cols_[i], &ctx); } } } @@ -1352,5 +1369,69 @@ Status KeyEncoder::EncodeSelected(KeyRowArray* rows, uint32_t num_selected, return Status::OK(); } +KeyEncoder::KeyColumnMetadata ColumnMetadataFromDataType( + const std::shared_ptr& type) { + if (type->id() == Type::DICTIONARY) { + auto bit_width = + arrow::internal::checked_cast(*type).bit_width(); + ARROW_DCHECK(bit_width % 8 == 0); + return KeyEncoder::KeyColumnMetadata(true, bit_width / 8); + } else if (type->id() == Type::BOOL) { + return KeyEncoder::KeyColumnMetadata(true, 0); + } else if (is_fixed_width(type->id())) { + return KeyEncoder::KeyColumnMetadata( + true, + arrow::internal::checked_cast(*type).bit_width() / 8); + } else if (is_binary_like(type->id())) { + return KeyEncoder::KeyColumnMetadata(false, sizeof(uint32_t)); + } + ARROW_DCHECK(false); + return KeyEncoder::KeyColumnMetadata(true, sizeof(int)); +} + +KeyEncoder::KeyColumnArray ColumnArrayFromArrayData( + const std::shared_ptr& array_data, int start_row, int num_rows) { + KeyEncoder::KeyColumnArray column_array = KeyEncoder::KeyColumnArray( + ColumnMetadataFromDataType(array_data->type), + array_data->offset + start_row + num_rows, + array_data->buffers[0] != NULLPTR ? array_data->buffers[0]->data() : nullptr, + array_data->buffers[1]->data(), + (array_data->buffers.size() > 2 && array_data->buffers[2] != NULLPTR) + ? 
array_data->buffers[2]->data() + : nullptr); + return KeyEncoder::KeyColumnArray(column_array, array_data->offset + start_row, + num_rows); +} + +void ColumnMetadatasFromExecBatch( + const ExecBatch& batch, + std::vector& column_metadatas) { + int num_columns = static_cast(batch.values.size()); + column_metadatas.resize(num_columns); + for (int i = 0; i < num_columns; ++i) { + const Datum& data = batch.values[i]; + ARROW_DCHECK(data.is_array()); + const std::shared_ptr& array_data = data.array(); + column_metadatas[i] = ColumnMetadataFromDataType(array_data->type); + } +} + +void ColumnArraysFromExecBatch(const ExecBatch& batch, int start_row, int num_rows, + std::vector& column_arrays) { + int num_columns = static_cast(batch.values.size()); + column_arrays.resize(num_columns); + for (int i = 0; i < num_columns; ++i) { + const Datum& data = batch.values[i]; + ARROW_DCHECK(data.is_array()); + const std::shared_ptr& array_data = data.array(); + column_arrays[i] = ColumnArrayFromArrayData(array_data, start_row, num_rows); + } +} + +void ColumnArraysFromExecBatch(const ExecBatch& batch, + std::vector& column_arrays) { + ColumnArraysFromExecBatch(batch, 0, static_cast(batch.length), column_arrays); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/key_encode.h b/cpp/src/arrow/compute/exec/key_encode.h index da533434d3957..b5d97c9b5f841 100644 --- a/cpp/src/arrow/compute/exec/key_encode.h +++ b/cpp/src/arrow/compute/exec/key_encode.h @@ -116,6 +116,7 @@ class KeyEncoder { /// Order in which fields are encoded. std::vector column_order; + std::vector inverse_column_order; /// Offsets within a row to fields in their encoding order. std::vector column_offsets; @@ -175,6 +176,10 @@ class KeyEncoder { uint32_t encoded_field_order(uint32_t icol) const { return column_order[icol]; } + uint32_t pos_after_encoding(uint32_t icol) const { + return inverse_column_order[icol]; + } + uint32_t encoded_field_offset(uint32_t icol) const { return column_offsets[icol]; } uint32_t num_cols() const { return static_cast(column_metadatas.size()); } @@ -292,8 +297,8 @@ class KeyEncoder { int bit_offset_[max_buffers_ - 1]; }; - void Init(const std::vector& cols, KeyEncoderContext* ctx, - int row_alignment, int string_alignment); + void Init(const std::vector& cols, int row_alignment, + int string_alignment); const KeyRowMetadata& row_metadata() { return row_metadata_; } @@ -312,11 +317,14 @@ class KeyEncoder { /// length buffers sizes. 
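  /// The decoders no longer rely on a KeyEncoderContext stored inside the
  /// encoder; the caller passes hardware_flags and a TempVectorStack with each
  /// call, so every calling thread can use its own temporary vector stack.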
void DecodeFixedLengthBuffers(int64_t start_row_input, int64_t start_row_output, int64_t num_rows, const KeyRowArray& rows, - std::vector* cols); + std::vector* cols, int64_t hardware_flags, + util::TempVectorStack* temp_stack); void DecodeVaryingLengthBuffers(int64_t start_row_input, int64_t start_row_output, int64_t num_rows, const KeyRowArray& rows, - std::vector* cols); + std::vector* cols, + int64_t hardware_flags, + util::TempVectorStack* temp_stack); const std::vector& GetBatchColumns() const { return batch_all_cols_; } @@ -486,8 +494,6 @@ class KeyEncoder { std::vector* cols); }; - KeyEncoderContext* ctx_; - // Data initialized once, based on data types of key columns KeyRowMetadata row_metadata_; @@ -569,5 +575,16 @@ inline void KeyEncoder::EncoderVarBinary::DecodeHelper( } } +KeyEncoder::KeyColumnMetadata ColumnMetadataFromDataType( + const std::shared_ptr& type); +KeyEncoder::KeyColumnArray ColumnArrayFromArrayData( + const std::shared_ptr& array_data, int start_row, int num_rows); +void ColumnMetadatasFromExecBatch( + const ExecBatch& batch, std::vector& column_metadatas); +void ColumnArraysFromExecBatch(const ExecBatch& batch, int start_row, int num_rows, + std::vector& column_arrays); +void ColumnArraysFromExecBatch(const ExecBatch& batch, + std::vector& column_arrays); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/key_hash.cc b/cpp/src/arrow/compute/exec/key_hash.cc index bc4cae74ddc67..b9410ea3781d3 100644 --- a/cpp/src/arrow/compute/exec/key_hash.cc +++ b/cpp/src/arrow/compute/exec/key_hash.cc @@ -456,6 +456,17 @@ void Hashing32::HashMultiColumn(const std::vector& c } } +void Hashing32::HashBatch(const ExecBatch& key_batch, int start_row, int num_rows, + uint32_t* hashes, + std::vector& column_arrays, + int64_t hardware_flags, util::TempVectorStack* temp_stack) { + ColumnArraysFromExecBatch(key_batch, start_row, num_rows, column_arrays); + KeyEncoder::KeyEncoderContext ctx; + ctx.hardware_flags = hardware_flags; + ctx.stack = temp_stack; + HashMultiColumn(column_arrays, &ctx, hashes); +} + inline uint64_t Hashing64::Avalanche(uint64_t acc) { acc ^= (acc >> 33); acc *= PRIME64_2; @@ -875,5 +886,16 @@ void Hashing64::HashMultiColumn(const std::vector& c } } +void Hashing64::HashBatch(const ExecBatch& key_batch, int start_row, int num_rows, + uint64_t* hashes, + std::vector& column_arrays, + int64_t hardware_flags, util::TempVectorStack* temp_stack) { + ColumnArraysFromExecBatch(key_batch, start_row, num_rows, column_arrays); + KeyEncoder::KeyEncoderContext ctx; + ctx.hardware_flags = hardware_flags; + ctx.stack = temp_stack; + HashMultiColumn(column_arrays, &ctx, hashes); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/key_hash.h b/cpp/src/arrow/compute/exec/key_hash.h index 88f77be1a4fe2..ade61ba0b8e0a 100644 --- a/cpp/src/arrow/compute/exec/key_hash.h +++ b/cpp/src/arrow/compute/exec/key_hash.h @@ -48,6 +48,11 @@ class ARROW_EXPORT Hashing32 { static void HashMultiColumn(const std::vector& cols, KeyEncoder::KeyEncoderContext* ctx, uint32_t* out_hash); + static void HashBatch(const ExecBatch& key_batch, int start_row, int num_rows, + uint32_t* hashes, + std::vector& column_arrays, + int64_t hardware_flags, util::TempVectorStack* temp_stack); + private: static const uint32_t PRIME32_1 = 0x9E3779B1; static const uint32_t PRIME32_2 = 0x85EBCA77; @@ -156,6 +161,11 @@ class ARROW_EXPORT Hashing64 { static void HashMultiColumn(const std::vector& cols, KeyEncoder::KeyEncoderContext* ctx, uint64_t* 
hashes); + static void HashBatch(const ExecBatch& key_batch, int start_row, int num_rows, + uint64_t* hashes, + std::vector& column_arrays, + int64_t hardware_flags, util::TempVectorStack* temp_stack); + private: static const uint64_t PRIME64_1 = 0x9E3779B185EBCA87ULL; static const uint64_t PRIME64_2 = 0xC2B2AE3D27D4EB4FULL; diff --git a/cpp/src/arrow/compute/exec/key_map.cc b/cpp/src/arrow/compute/exec/key_map.cc index fe5ed98bb3e24..a61184e4ca9a8 100644 --- a/cpp/src/arrow/compute/exec/key_map.cc +++ b/cpp/src/arrow/compute/exec/key_map.cc @@ -42,8 +42,8 @@ constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL; // b) first empty slot is encountered, // c) we reach the end of the block. // -// Optionally an index of the first slot to start the search from can be specified. -// In this case slots before it will be ignored. +// Optionally an index of the first slot to start the search from can be specified. In +// this case slots before it will be ignored. // template inline void SwissTable::search_block(uint64_t block, int stamp, int start_slot, @@ -88,29 +88,12 @@ inline void SwissTable::search_block(uint64_t block, int stamp, int start_slot, // We get 0 if there are no matches *out_match_found = (matches == 0 ? 0 : 1); - // Now if we or with the highest bits of the block and scan zero bits in reverse, - // we get 8x slot index that we were looking for. - // This formula works in all three cases a), b) and c). + // Now if we or with the highest bits of the block and scan zero bits in reverse, we get + // 8x slot index that we were looking for. This formula works in all three cases a), b) + // and c). *out_slot = static_cast(CountLeadingZeros(matches | block_high_bits) >> 3); } -inline uint64_t SwissTable::extract_group_id(const uint8_t* block_ptr, int slot, - uint64_t group_id_mask) const { - // Group id values for all 8 slots in the block are bit-packed and follow the status - // bytes. We assume here that the number of bits is rounded up to 8, 16, 32 or 64. In - // that case we can extract group id using aligned 64-bit word access. - int num_group_id_bits = static_cast(ARROW_POPCOUNT64(group_id_mask)); - ARROW_DCHECK(num_group_id_bits == 8 || num_group_id_bits == 16 || - num_group_id_bits == 32 || num_group_id_bits == 64); - - int bit_offset = slot * num_group_id_bits; - const uint64_t* group_id_bytes = - reinterpret_cast(block_ptr) + 1 + (bit_offset >> 6); - uint64_t group_id = (*group_id_bytes >> (bit_offset & 63)) & group_id_mask; - - return group_id; -} - template void SwissTable::extract_group_ids_imp(const int num_keys, const uint16_t* selection, const uint32_t* hashes, const uint8_t* local_slots, @@ -147,14 +130,16 @@ void SwissTable::extract_group_ids(const int num_keys, const uint16_t* optional_ ARROW_DCHECK(num_group_id_bits == 8 || num_group_id_bits == 16 || num_group_id_bits == 32); + int num_processed = 0; + // Optimistically use simplified lookup involving only a start block to find // a single group id candidate for every input. 
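  // The integer pairs passed below describe the block layout in units of the
  // group-id width: each block holds 8 status bytes followed by 8 packed group
  // ids, so for 8-, 16- and 32-bit ids the group-id section starts 8, 4 and 2
  // elements into the block and consecutive blocks are 16, 12 and 10 elements
  // apart respectively.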
#if defined(ARROW_HAVE_AVX2) int num_group_id_bytes = num_group_id_bits / 8; if ((hardware_flags_ & arrow::internal::CpuInfo::AVX2) && !optional_selection) { - extract_group_ids_avx2(num_keys, hashes, local_slots, out_group_ids, sizeof(uint64_t), - 8 + 8 * num_group_id_bytes, num_group_id_bytes); - return; + num_processed = extract_group_ids_avx2(num_keys, hashes, local_slots, out_group_ids, + sizeof(uint64_t), 8 + 8 * num_group_id_bytes, + num_group_id_bytes); } #endif switch (num_group_id_bits) { @@ -163,8 +148,9 @@ void SwissTable::extract_group_ids(const int num_keys, const uint16_t* optional_ extract_group_ids_imp(num_keys, optional_selection, hashes, local_slots, out_group_ids, 8, 16); } else { - extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, - out_group_ids, 8, 16); + extract_group_ids_imp( + num_keys - num_processed, nullptr, hashes + num_processed, + local_slots + num_processed, out_group_ids + num_processed, 8, 16); } break; case 16: @@ -172,8 +158,9 @@ void SwissTable::extract_group_ids(const int num_keys, const uint16_t* optional_ extract_group_ids_imp(num_keys, optional_selection, hashes, local_slots, out_group_ids, 4, 12); } else { - extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, - out_group_ids, 4, 12); + extract_group_ids_imp( + num_keys - num_processed, nullptr, hashes + num_processed, + local_slots + num_processed, out_group_ids + num_processed, 4, 12); } break; case 32: @@ -181,8 +168,9 @@ void SwissTable::extract_group_ids(const int num_keys, const uint16_t* optional_ extract_group_ids_imp(num_keys, optional_selection, hashes, local_slots, out_group_ids, 2, 10); } else { - extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, - out_group_ids, 2, 10); + extract_group_ids_imp( + num_keys - num_processed, nullptr, hashes + num_processed, + local_slots + num_processed, out_group_ids + num_processed, 2, 10); } break; default: @@ -312,24 +300,21 @@ void SwissTable::early_filter(const int num_keys, const uint32_t* hashes, uint8_t* out_local_slots) const { // Optimistically use simplified lookup involving only a start block to find // a single group id candidate for every input. 
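The AVX2 specializations here and in key_compare_avx2.cc now report how many leading rows they handled, and the scalar implementation finishes the remaining tail starting at that offset. A self-contained sketch of the dispatch pattern, with illustrative names that are not part of the patch:

#include <cstdint>

// Stand-in for an AVX2 kernel: it handles only whole groups of 8 rows and
// reports how many rows it consumed.
static uint32_t NegateRows_unrolled(uint32_t num_rows, const int32_t* in, int32_t* out) {
  uint32_t num_processed = num_rows / 8 * 8;
  for (uint32_t i = 0; i < num_processed; ++i) out[i] = -in[i];
  return num_processed;
}

// Dispatcher: run the unrolled kernel on the prefix, then let the scalar loop
// start at num_processed and handle the remaining tail rows.
// (The real code additionally checks hardware_flags_ for AVX2 at run time.)
void NegateRows(uint32_t num_rows, const int32_t* in, int32_t* out) {
  uint32_t num_processed = 0;
#if defined(ARROW_HAVE_AVX2)
  num_processed = NegateRows_unrolled(num_rows, in, out);
#endif
  for (uint32_t i = num_processed; i < num_rows; ++i) out[i] = -in[i];
}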
+ int num_processed = 0; #if defined(ARROW_HAVE_AVX2) if (hardware_flags_ & arrow::internal::CpuInfo::AVX2) { if (log_blocks_ <= 4) { - int tail = num_keys % 32; - int delta = num_keys - tail; - early_filter_imp_avx2_x32(num_keys - tail, hashes, out_match_bitvector, - out_local_slots); - early_filter_imp_avx2_x8(tail, hashes + delta, out_match_bitvector + delta / 8, - out_local_slots + delta); - } else { - early_filter_imp_avx2_x8(num_keys, hashes, out_match_bitvector, out_local_slots); + num_processed = early_filter_imp_avx2_x32(num_keys, hashes, out_match_bitvector, + out_local_slots); } - } else { -#endif - early_filter_imp(num_keys, hashes, out_match_bitvector, out_local_slots); -#if defined(ARROW_HAVE_AVX2) + num_processed += early_filter_imp_avx2_x8( + num_keys - num_processed, hashes + num_processed, + out_match_bitvector + num_processed / 8, out_local_slots + num_processed); } #endif + early_filter_imp(num_keys - num_processed, hashes + num_processed, + out_match_bitvector + num_processed / 8, + out_local_slots + num_processed); } // Input selection may be: @@ -348,10 +333,16 @@ void SwissTable::run_comparisons(const int num_keys, const uint16_t* optional_selection_ids, const uint8_t* optional_selection_bitvector, const uint32_t* groupids, int* out_num_not_equal, - uint16_t* out_not_equal_selection) const { + uint16_t* out_not_equal_selection, + const EqualImpl& equal_impl, void* callback_ctx) const { ARROW_DCHECK(optional_selection_ids || optional_selection_bitvector); ARROW_DCHECK(!optional_selection_ids || !optional_selection_bitvector); + if (num_keys == 0) { + *out_num_not_equal = 0; + return; + } + if (!optional_selection_ids && optional_selection_bitvector) { // Count rows with matches (based on stamp comparison) // and decide based on their percentage whether to call dense or sparse comparison @@ -368,21 +359,22 @@ void SwissTable::run_comparisons(const int num_keys, if (num_inserted_ > 0 && num_matches > 0 && num_matches > 3 * num_keys / 4) { uint32_t out_num; - equal_impl_(num_keys, nullptr, groupids, &out_num, out_not_equal_selection); + equal_impl(num_keys, nullptr, groupids, &out_num, out_not_equal_selection, + callback_ctx); *out_num_not_equal = static_cast(out_num); } else { util::bit_util::bits_to_indexes(1, hardware_flags_, num_keys, optional_selection_bitvector, out_num_not_equal, out_not_equal_selection); uint32_t out_num; - equal_impl_(*out_num_not_equal, out_not_equal_selection, groupids, &out_num, - out_not_equal_selection); + equal_impl(*out_num_not_equal, out_not_equal_selection, groupids, &out_num, + out_not_equal_selection, callback_ctx); *out_num_not_equal = static_cast(out_num); } } else { uint32_t out_num; - equal_impl_(num_keys, optional_selection_ids, groupids, &out_num, - out_not_equal_selection); + equal_impl(num_keys, optional_selection_ids, groupids, &out_num, + out_not_equal_selection, callback_ctx); *out_num_not_equal = static_cast(out_num); } } @@ -432,35 +424,6 @@ bool SwissTable::find_next_stamp_match(const uint32_t hash, const uint32_t in_sl return match_found; } -void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash, - uint32_t group_id) { - const uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); - - // We assume here that the number of bits is rounded up to 8, 16, 32 or 64. - // In that case we can insert group id value using aligned 64-bit word access. 
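The 3/4 threshold in run_comparisons above is a density heuristic: when most rows passed the stamp filter, it is cheaper to run the comparison kernel densely over every row than to first gather a sparse selection vector. Below is a minimal sketch of that decision; the helper names are invented and stand in for util::bit_util::bits_to_indexes and the comparison callbacks.

#include <cstdint>
#include <vector>

// Count matches first (cheap), and only materialize a selection vector when
// the selection is sparse enough to be worth it.
int CountSetBits(const uint8_t* bits, int num_rows) {
  int count = 0;
  for (int i = 0; i < num_rows; ++i) count += (bits[i / 8] >> (i % 8)) & 1;
  return count;
}

int BitsToIndexes(const uint8_t* bits, int num_rows, uint16_t* out_ids) {
  int num_ids = 0;
  for (int i = 0; i < num_rows; ++i) {
    if ((bits[i / 8] >> (i % 8)) & 1) out_ids[num_ids++] = static_cast<uint16_t>(i);
  }
  return num_ids;
}

template <typename DenseFn, typename SparseFn>
void RunDenseOrSparse(const uint8_t* match_bits, int num_rows, DenseFn dense,
                      SparseFn sparse) {
  int num_matches = CountSetBits(match_bits, num_rows);
  if (num_matches > 3 * num_rows / 4) {
    // Dense: process every row; the few non-matches are filtered afterwards.
    dense(num_rows);
  } else {
    // Sparse: gather the ids of matching rows and process only those.
    std::vector<uint16_t> ids(num_rows);
    int num_ids = BitsToIndexes(match_bits, num_rows, ids.data());
    sparse(ids.data(), num_ids);
  }
}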
- ARROW_DCHECK(num_groupid_bits == 8 || num_groupid_bits == 16 || - num_groupid_bits == 32 || num_groupid_bits == 64); - - const uint64_t num_block_bytes = (8 + num_groupid_bits); - constexpr uint64_t stamp_mask = 0x7f; - - int start_slot = (slot_id & 7); - int stamp = - static_cast((hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask); - uint64_t block_id = slot_id >> 3; - uint8_t* blockbase = blocks_ + num_block_bytes * block_id; - - blockbase[7 - start_slot] = static_cast(stamp); - int groupid_bit_offset = static_cast(start_slot * num_groupid_bits); - - // Block status bytes should start at an address aligned to 8 bytes - ARROW_DCHECK((reinterpret_cast(blockbase) & 7) == 0); - uint64_t* ptr = reinterpret_cast(blockbase) + 1 + (groupid_bit_offset >> 6); - *ptr |= (static_cast(group_id) << (groupid_bit_offset & 63)); - - hashes_[slot_id] = hash; -} - // Find method is the continuation of processing from early_filter. // Its input consists of hash values and the output of early_filter. // It updates match bit-vector, clearing it from any false positives @@ -471,7 +434,8 @@ void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash, // void SwissTable::find(const int num_keys, const uint32_t* hashes, uint8_t* inout_match_bitvector, const uint8_t* local_slots, - uint32_t* out_group_ids) const { + uint32_t* out_group_ids, util::TempVectorStack* temp_stack, + const EqualImpl& equal_impl, void* callback_ctx) const { // Temporary selection vector. // It will hold ids of keys for which we do not know yet // if they have a match in hash table or not. @@ -481,12 +445,12 @@ void SwissTable::find(const int num_keys, const uint32_t* hashes, // to array of ids. // ARROW_DCHECK(num_keys <= (1 << log_minibatch_)); - auto ids_buf = util::TempVectorHolder(temp_stack_, num_keys); + auto ids_buf = util::TempVectorHolder(temp_stack, num_keys); uint16_t* ids = ids_buf.mutable_data(); int num_ids; - int64_t num_matches = - arrow::internal::CountSetBits(inout_match_bitvector, /*offset=*/0, num_keys); + int64_t num_matches = arrow::internal::CountSetBits(inout_match_bitvector, + /*offset=*/0, num_keys); // If there is a high density of selected input rows // (majority of them are present in the selection), @@ -498,19 +462,20 @@ void SwissTable::find(const int num_keys, const uint32_t* hashes, if (visit_all) { extract_group_ids(num_keys, nullptr, hashes, local_slots, out_group_ids); run_comparisons(num_keys, nullptr, inout_match_bitvector, out_group_ids, &num_ids, - ids); + ids, equal_impl, callback_ctx); } else { util::bit_util::bits_to_indexes(1, hardware_flags_, num_keys, inout_match_bitvector, &num_ids, ids); extract_group_ids(num_ids, ids, hashes, local_slots, out_group_ids); - run_comparisons(num_ids, ids, nullptr, out_group_ids, &num_ids, ids); + run_comparisons(num_ids, ids, nullptr, out_group_ids, &num_ids, ids, equal_impl, + callback_ctx); } if (num_ids == 0) { return; } - auto slot_ids_buf = util::TempVectorHolder(temp_stack_, num_keys); + auto slot_ids_buf = util::TempVectorHolder(temp_stack, num_keys); uint32_t* slot_ids = slot_ids_buf.mutable_data(); init_slot_ids(num_ids, ids, hashes, local_slots, inout_match_bitvector, slot_ids); @@ -531,9 +496,10 @@ void SwissTable::find(const int num_keys, const uint32_t* hashes, } } - run_comparisons(num_ids, ids, nullptr, out_group_ids, &num_ids, ids); + run_comparisons(num_ids, ids, nullptr, out_group_ids, &num_ids, ids, equal_impl, + callback_ctx); } -} // namespace compute +} // Slow processing of input keys in the most generic case. 
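Note the API shift visible in find above: key equality and row append logic now arrive as per-call callbacks together with an opaque callback_ctx, and scratch space comes from a caller-supplied TempVectorStack, instead of being captured by the table at init time. A hedged sketch of that callback-plus-context shape follows; CallerContext and VerifyCandidates are invented names, and the authoritative EqualImpl signature is the one declared in key_map.h further below.

#include <cstdint>
#include <vector>

// Context the caller threads through the table as a void pointer.
struct CallerContext {
  const std::vector<int64_t>* probe_keys;   // keys of the probing batch, by row id
  const std::vector<int64_t>* stored_keys;  // keys already inserted, by group id
};

// Same shape as the EqualImpl callback: verify candidate matches and report
// the row ids whose keys turned out not to be equal after all.
void VerifyCandidates(int num_ids, const uint16_t* ids, const uint32_t* group_ids,
                      uint32_t* out_num_mismatch, uint16_t* out_mismatch_ids,
                      void* ctx) {
  auto* c = static_cast<CallerContext*>(ctx);
  uint32_t num_mismatch = 0;
  for (int i = 0; i < num_ids; ++i) {
    uint16_t id = ids ? ids[i] : static_cast<uint16_t>(i);
    if ((*c->probe_keys)[id] != (*c->stored_keys)[group_ids[id]]) {
      out_mismatch_ids[num_mismatch++] = id;  // stamp matched, key did not
    }
  }
  *out_num_mismatch = num_mismatch;
}

Keeping the table free of caller state is presumably what allows one SwissTable instance to be used from multiple threads, each bringing its own temp stack and context.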
// Handles inserting new keys. @@ -545,11 +511,11 @@ void SwissTable::find(const int num_keys, const uint32_t* hashes, // Update selection vector to reflect which items have been processed. // Ids in selection vector do not have to be sorted. // -Status SwissTable::map_new_keys_helper(const uint32_t* hashes, - uint32_t* inout_num_selected, - uint16_t* inout_selection, bool* out_need_resize, - uint32_t* out_group_ids, - uint32_t* inout_next_slot_ids) { +Status SwissTable::map_new_keys_helper( + const uint32_t* hashes, uint32_t* inout_num_selected, uint16_t* inout_selection, + bool* out_need_resize, uint32_t* out_group_ids, uint32_t* inout_next_slot_ids, + util::TempVectorStack* temp_stack, const EqualImpl& equal_impl, + const AppendImpl& append_impl, void* callback_ctx) { auto num_groups_limit = num_groups_for_resize(); ARROW_DCHECK(num_inserted_ < num_groups_limit); @@ -560,7 +526,7 @@ Status SwissTable::map_new_keys_helper(const uint32_t* hashes, size_t num_bytes_for_bits = (*inout_num_selected + 7) / 8 + sizeof(uint64_t); auto match_bitvector_buf = util::TempVectorHolder( - temp_stack_, static_cast(num_bytes_for_bits)); + temp_stack, static_cast(num_bytes_for_bits)); uint8_t* match_bitvector = match_bitvector_buf.mutable_data(); memset(match_bitvector, 0xff, num_bytes_for_bits); @@ -580,11 +546,12 @@ Status SwissTable::map_new_keys_helper(const uint32_t* hashes, // out_group_ids[id] = num_inserted_ + num_inserted_new; insert_into_empty_slot(inout_next_slot_ids[id], hashes[id], out_group_ids[id]); + hashes_[inout_next_slot_ids[id]] = hashes[id]; ::arrow::bit_util::ClearBit(match_bitvector, num_processed); ++num_inserted_new; - // We need to break processing and have the caller of this function - // resize hash table if we reach the limit of the number of groups present. + // We need to break processing and have the caller of this function resize hash + // table if we reach the limit of the number of groups present. // if (num_inserted_ + num_inserted_new == num_groups_limit) { ++num_processed; @@ -594,7 +561,7 @@ Status SwissTable::map_new_keys_helper(const uint32_t* hashes, } auto temp_ids_buffer = - util::TempVectorHolder(temp_stack_, *inout_num_selected); + util::TempVectorHolder(temp_stack, *inout_num_selected); uint16_t* temp_ids = temp_ids_buffer.mutable_data(); int num_temp_ids = 0; @@ -603,16 +570,18 @@ Status SwissTable::map_new_keys_helper(const uint32_t* hashes, util::bit_util::bits_filter_indexes(0, hardware_flags_, num_processed, match_bitvector, inout_selection, &num_temp_ids, temp_ids); ARROW_DCHECK(static_cast(num_inserted_new) == num_temp_ids); - RETURN_NOT_OK(append_impl_(num_inserted_new, temp_ids)); + RETURN_NOT_OK(append_impl(num_inserted_new, temp_ids, callback_ctx)); num_inserted_ += num_inserted_new; // Evaluate comparisons and append ids of rows that failed it to the non-match set. util::bit_util::bits_filter_indexes(1, hardware_flags_, num_processed, match_bitvector, inout_selection, &num_temp_ids, temp_ids); - run_comparisons(num_temp_ids, temp_ids, nullptr, out_group_ids, &num_temp_ids, - temp_ids); + run_comparisons(num_temp_ids, temp_ids, nullptr, out_group_ids, &num_temp_ids, temp_ids, + equal_impl, callback_ctx); - memcpy(inout_selection, temp_ids, sizeof(uint16_t) * num_temp_ids); + if (num_temp_ids > 0) { + memcpy(inout_selection, temp_ids, sizeof(uint16_t) * num_temp_ids); + } // Append ids of any unprocessed entries if we aborted processing due to the need // to resize. 
if (num_processed < *inout_num_selected) { @@ -629,7 +598,9 @@ Status SwissTable::map_new_keys_helper(const uint32_t* hashes, // this set). // Status SwissTable::map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* hashes, - uint32_t* group_ids) { + uint32_t* group_ids, util::TempVectorStack* temp_stack, + const EqualImpl& equal_impl, + const AppendImpl& append_impl, void* callback_ctx) { if (num_ids == 0) { return Status::OK(); } @@ -645,7 +616,7 @@ Status SwissTable::map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* ARROW_DCHECK(static_cast(max_id + 1) <= (1 << log_minibatch_)); // Allocate temporary buffers for slot ids and intialize them - auto slot_ids_buf = util::TempVectorHolder(temp_stack_, max_id + 1); + auto slot_ids_buf = util::TempVectorHolder(temp_stack, max_id + 1); uint32_t* slot_ids = slot_ids_buf.mutable_data(); init_slot_ids_for_new_keys(num_ids, ids, hashes, slot_ids); @@ -658,7 +629,8 @@ Status SwissTable::map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* // bigger hash table. bool out_of_capacity; RETURN_NOT_OK(map_new_keys_helper(hashes, &num_ids, ids, &out_of_capacity, group_ids, - slot_ids)); + slot_ids, temp_stack, equal_impl, append_impl, + callback_ctx)); if (out_of_capacity) { RETURN_NOT_OK(grow_double()); // Reset start slot ids for still unprocessed input keys. @@ -803,17 +775,13 @@ Status SwissTable::grow_double() { return Status::OK(); } -Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool, - util::TempVectorStack* temp_stack, int log_minibatch, - EqualImpl equal_impl, AppendImpl append_impl) { +Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool, int log_blocks, + bool no_hash_array) { hardware_flags_ = hardware_flags; pool_ = pool; - temp_stack_ = temp_stack; - log_minibatch_ = log_minibatch; - equal_impl_ = equal_impl; - append_impl_ = append_impl; + log_minibatch_ = util::MiniBatch::kLogMiniBatchLength; - log_blocks_ = 0; + log_blocks_ = log_blocks; int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); num_inserted_ = 0; @@ -829,12 +797,16 @@ Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool, util::SafeStore(blocks_ + i * block_bytes, kHighBitOfEachByte); } - uint64_t num_slots = 1ULL << (log_blocks_ + 3); - const uint64_t hash_size = sizeof(uint32_t); - const uint64_t hash_bytes = hash_size * num_slots + padding_; - uint8_t* hashes8; - RETURN_NOT_OK(pool_->Allocate(hash_bytes, &hashes8)); - hashes_ = reinterpret_cast(hashes8); + if (no_hash_array) { + hashes_ = nullptr; + } else { + uint64_t num_slots = 1ULL << (log_blocks_ + 3); + const uint64_t hash_size = sizeof(uint32_t); + const uint64_t hash_bytes = hash_size * num_slots + padding_; + uint8_t* hashes8; + RETURN_NOT_OK(pool_->Allocate(hash_bytes, &hashes8)); + hashes_ = reinterpret_cast(hashes8); + } return Status::OK(); } diff --git a/cpp/src/arrow/compute/exec/key_map.h b/cpp/src/arrow/compute/exec/key_map.h index 12c1e393c4a3b..edd3b16f0a48a 100644 --- a/cpp/src/arrow/compute/exec/key_map.h +++ b/cpp/src/arrow/compute/exec/key_map.h @@ -28,6 +28,8 @@ namespace arrow { namespace compute { class SwissTable { + friend class SwissTableMerge; + public: SwissTable() = default; ~SwissTable() { cleanup(); } @@ -35,11 +37,12 @@ class SwissTable { using EqualImpl = std::function; - using AppendImpl = std::function; + uint16_t* out_selection_mismatch, void* callback_ctx)>; + using AppendImpl = + std::function; - Status init(int64_t hardware_flags, MemoryPool* pool, util::TempVectorStack* temp_stack, - int 
log_minibatch, EqualImpl equal_impl, AppendImpl append_impl); + Status init(int64_t hardware_flags, MemoryPool* pool, int log_blocks = 0, + bool no_hash_array = false); void cleanup(); @@ -47,10 +50,22 @@ class SwissTable { uint8_t* out_match_bitvector, uint8_t* out_local_slots) const; void find(const int num_keys, const uint32_t* hashes, uint8_t* inout_match_bitvector, - const uint8_t* local_slots, uint32_t* out_group_ids) const; + const uint8_t* local_slots, uint32_t* out_group_ids, + util::TempVectorStack* temp_stack, const EqualImpl& equal_impl, + void* callback_ctx) const; Status map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* hashes, - uint32_t* group_ids); + uint32_t* group_ids, util::TempVectorStack* temp_stack, + const EqualImpl& equal_impl, const AppendImpl& append_impl, + void* callback_ctx); + + int minibatch_size() const { return 1 << log_minibatch_; } + + int64_t num_inserted() const { return num_inserted_; } + + int64_t hardware_flags() const { return hardware_flags_; } + + MemoryPool* pool() const { return pool_; } private: // Lookup helpers @@ -116,21 +131,22 @@ class SwissTable { void early_filter_imp(const int num_keys, const uint32_t* hashes, uint8_t* out_match_bitvector, uint8_t* out_local_slots) const; #if defined(ARROW_HAVE_AVX2) - void early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* hashes, + int early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* hashes, + uint8_t* out_match_bitvector, + uint8_t* out_local_slots) const; + int early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* hashes, uint8_t* out_match_bitvector, uint8_t* out_local_slots) const; - void early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* hashes, - uint8_t* out_match_bitvector, - uint8_t* out_local_slots) const; - void extract_group_ids_avx2(const int num_keys, const uint32_t* hashes, - const uint8_t* local_slots, uint32_t* out_group_ids, - int byte_offset, int byte_multiplier, int byte_size) const; + int extract_group_ids_avx2(const int num_keys, const uint32_t* hashes, + const uint8_t* local_slots, uint32_t* out_group_ids, + int byte_offset, int byte_multiplier, int byte_size) const; #endif void run_comparisons(const int num_keys, const uint16_t* optional_selection_ids, const uint8_t* optional_selection_bitvector, const uint32_t* groupids, int* out_num_not_equal, - uint16_t* out_not_equal_selection) const; + uint16_t* out_not_equal_selection, const EqualImpl& equal_impl, + void* callback_ctx) const; inline bool find_next_stamp_match(const uint32_t hash, const uint32_t in_slot_id, uint32_t* out_slot_id, uint32_t* out_group_id) const; @@ -145,7 +161,10 @@ class SwissTable { // Status map_new_keys_helper(const uint32_t* hashes, uint32_t* inout_num_selected, uint16_t* inout_selection, bool* out_need_resize, - uint32_t* out_group_ids, uint32_t* out_next_slot_ids); + uint32_t* out_group_ids, uint32_t* out_next_slot_ids, + util::TempVectorStack* temp_stack, + const EqualImpl& equal_impl, const AppendImpl& append_impl, + void* callback_ctx); // Resize small hash tables when 50% full (up to 8KB). // Resize large hash tables when 75% full. @@ -198,11 +217,51 @@ class SwissTable { int64_t hardware_flags_; MemoryPool* pool_; - util::TempVectorStack* temp_stack_; - - EqualImpl equal_impl_; - AppendImpl append_impl_; }; +uint64_t SwissTable::extract_group_id(const uint8_t* block_ptr, int slot, + uint64_t group_id_mask) const { + // Group id values for all 8 slots in the block are bit-packed and follow the status + // bytes. 
We assume here that the number of bits is rounded up to 8, 16, 32 or 64. In + // that case we can extract group id using aligned 64-bit word access. + int num_group_id_bits = static_cast(ARROW_POPCOUNT64(group_id_mask)); + ARROW_DCHECK(num_group_id_bits == 8 || num_group_id_bits == 16 || + num_group_id_bits == 32 || num_group_id_bits == 64); + + int bit_offset = slot * num_group_id_bits; + const uint64_t* group_id_bytes = + reinterpret_cast(block_ptr) + 1 + (bit_offset >> 6); + uint64_t group_id = (*group_id_bytes >> (bit_offset & 63)) & group_id_mask; + + return group_id; +} + +void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash, + uint32_t group_id) { + const uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); + + // We assume here that the number of bits is rounded up to 8, 16, 32 or 64. + // In that case we can insert group id value using aligned 64-bit word access. + ARROW_DCHECK(num_groupid_bits == 8 || num_groupid_bits == 16 || + num_groupid_bits == 32 || num_groupid_bits == 64); + + const uint64_t num_block_bytes = (8 + num_groupid_bits); + constexpr uint64_t stamp_mask = 0x7f; + + int start_slot = (slot_id & 7); + int stamp = + static_cast((hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask); + uint64_t block_id = slot_id >> 3; + uint8_t* blockbase = blocks_ + num_block_bytes * block_id; + + blockbase[7 - start_slot] = static_cast(stamp); + int groupid_bit_offset = static_cast(start_slot * num_groupid_bits); + + // Block status bytes should start at an address aligned to 8 bytes + ARROW_DCHECK((reinterpret_cast(blockbase) & 7) == 0); + uint64_t* ptr = reinterpret_cast(blockbase) + 1 + (groupid_bit_offset >> 6); + *ptr |= (static_cast(group_id) << (groupid_bit_offset & 63)); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/key_map_avx2.cc b/cpp/src/arrow/compute/exec/key_map_avx2.cc index 2fca6bf6c10c9..4c77f3af237b5 100644 --- a/cpp/src/arrow/compute/exec/key_map_avx2.cc +++ b/cpp/src/arrow/compute/exec/key_map_avx2.cc @@ -24,21 +24,15 @@ namespace compute { #if defined(ARROW_HAVE_AVX2) -// Why it is OK to round up number of rows internally: -// All of the buffers: hashes, out_match_bitvector, out_group_ids, out_next_slot_ids -// are temporary buffers of group id mapping. -// Temporary buffers are buffers that live only within the boundaries of a single -// minibatch. Temporary buffers add 64B at the end, so that SIMD code does not have to -// worry about reading and writing outside of the end of the buffer up to 64B. If the -// hashes array contains garbage after the last element, it cannot cause computation to -// fail, since any random data is a valid hash for the purpose of lookup. +// This is more or less translation of equivalent scalar code, adjusted for a +// different instruction set (e.g. missing leading zero count instruction). // -// This is more or less translation of equivalent scalar code, adjusted for a different -// instruction set (e.g. missing leading zero count instruction). +// Returns the number of hashes actually processed, which may be less than +// requested due to alignment required by SIMD. 
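As a worked example of the packed layout used by extract_group_id and insert_into_empty_slot above: a block stores 8 status bytes followed by 8 group ids of 8, 16, 32 or 64 bits each, so a block occupies 8 + num_bits bytes and every group id is reachable through a 64-bit word at a computable bit offset. The helpers below are standalone illustrations only; they use memcpy instead of the aligned word access the real code relies on.

#include <cstdint>
#include <cstdio>
#include <cstring>

uint64_t ReadPackedGroupId(const uint8_t* block, int slot, int num_bits) {
  uint64_t mask = (num_bits == 64) ? ~0ULL : ((1ULL << num_bits) - 1);
  int bit_offset = slot * num_bits;
  uint64_t word;
  // Skip the 8 status bytes, then pick the 64-bit word containing the field.
  std::memcpy(&word, block + 8 + (bit_offset / 64) * 8, sizeof(word));
  return (word >> (bit_offset & 63)) & mask;
}

void WritePackedGroupId(uint8_t* block, int slot, int num_bits, uint64_t group_id) {
  int bit_offset = slot * num_bits;
  uint8_t* word_ptr = block + 8 + (bit_offset / 64) * 8;
  uint64_t word;
  std::memcpy(&word, word_ptr, sizeof(word));
  word |= group_id << (bit_offset & 63);  // assumes the slot was zero-initialized
  std::memcpy(word_ptr, &word, sizeof(word));
}

int main() {
  uint8_t block[8 + 16] = {};  // 16-bit group ids: 8 status bytes + 16 id bytes
  WritePackedGroupId(block, /*slot=*/3, /*num_bits=*/16, /*group_id=*/1234);
  std::printf("%llu\n", static_cast<unsigned long long>(
                            ReadPackedGroupId(block, 3, 16)));  // prints 1234
  return 0;
}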
// -void SwissTable::early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* hashes, - uint8_t* out_match_bitvector, - uint8_t* out_local_slots) const { +int SwissTable::early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* hashes, + uint8_t* out_match_bitvector, + uint8_t* out_local_slots) const { // Number of inputs processed together in a loop constexpr int unroll = 8; @@ -46,8 +40,7 @@ void SwissTable::early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* const __m256i* vhash_ptr = reinterpret_cast(hashes); const __m256i vstamp_mask = _mm256_set1_epi32((1 << bits_stamp_) - 1); - // TODO: explain why it is ok to process hashes outside of buffer boundaries - for (int i = 0; i < ((num_hashes + unroll - 1) / unroll); ++i) { + for (int i = 0; i < num_hashes / unroll; ++i) { constexpr uint64_t kEachByteIs8 = 0x0808080808080808ULL; constexpr uint64_t kByteSequenceOfPowersOf2 = 0x8040201008040201ULL; @@ -139,6 +132,8 @@ void SwissTable::early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* out_match_bitvector[i] = _pext_u32(_mm256_movemask_epi8(vmatch_found), 0x11111111); // 0b00010001 repeated 4x } + + return num_hashes - (num_hashes % unroll); } // Take a set of 16 64-bit elements, @@ -173,8 +168,8 @@ inline void split_bytes_avx2(__m256i word0, __m256i word1, __m256i word2, __m256 // k4, o4, l4, p4, ... k7, o7, l7, p7} __m256i byte01 = _mm256_unpacklo_epi32( - a, b); // {a0, e0, b0, f0, i0, m0, j0, n0, a1, e1, b1, f1, i1, m1, j1, n1, c0, g0, - // d0, h0, k0, o0, l0, p0, ...} + a, b); // {a0, e0, b0, f0, i0, m0, j0, n0, a1, e1, b1, f1, i1, m1, j1, n1, + // c0, g0, d0, h0, k0, o0, l0, p0, ...} __m256i shuffle_const = _mm256_setr_epi8(0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15, 0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15); @@ -206,9 +201,13 @@ inline void split_bytes_avx2(__m256i word0, __m256i word1, __m256i word2, __m256 // using a different method. // TODO: Explain the idea behind storing arrays in SIMD registers. // Explain why it is faster with SIMD than using memory loads. -void SwissTable::early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* hashes, - uint8_t* out_match_bitvector, - uint8_t* out_local_slots) const { +// +// Returns the number of hashes actually processed, which may be less than +// requested due to alignment required by SIMD. +// +int SwissTable::early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* hashes, + uint8_t* out_match_bitvector, + uint8_t* out_local_slots) const { constexpr int unroll = 32; // There is a limit on the number of input blocks, @@ -366,12 +365,14 @@ void SwissTable::early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* reinterpret_cast(out_match_bitvector)[i] = _mm256_movemask_epi8(vmatch_found); } + + return num_hashes - (num_hashes % unroll); } -void SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hashes, - const uint8_t* local_slots, - uint32_t* out_group_ids, int byte_offset, - int byte_multiplier, int byte_size) const { +int SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hashes, + const uint8_t* local_slots, + uint32_t* out_group_ids, int byte_offset, + int byte_multiplier, int byte_size) const { ARROW_DCHECK(byte_size == 1 || byte_size == 2 || byte_size == 4); uint32_t mask = byte_size == 1 ? 0xFF : byte_size == 2 ? 
0xFFFF : 0xFFFFFFFF; auto elements = reinterpret_cast(blocks_ + byte_offset); @@ -380,7 +381,7 @@ void SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hash ARROW_DCHECK(byte_size == 1 && byte_offset == 8 && byte_multiplier == 16); __m256i block_group_ids = _mm256_set1_epi64x(reinterpret_cast(blocks_)[1]); - for (int i = 0; i < (num_keys + unroll - 1) / unroll; ++i) { + for (int i = 0; i < num_keys / unroll; ++i) { __m256i local_slot = _mm256_set1_epi64x(reinterpret_cast(local_slots)[i]); __m256i group_id = _mm256_shuffle_epi8(block_group_ids, local_slot); @@ -390,7 +391,7 @@ void SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hash _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, group_id); } } else { - for (int i = 0; i < (num_keys + unroll - 1) / unroll; ++i) { + for (int i = 0; i < num_keys / unroll; ++i) { __m256i hash = _mm256_loadu_si256(reinterpret_cast(hashes) + i); __m256i local_slot = _mm256_set1_epi64x(reinterpret_cast(local_slots)[i]); @@ -406,6 +407,7 @@ void SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hash _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, group_id); } } + return num_keys - (num_keys % unroll); } #endif diff --git a/cpp/src/arrow/compute/exec/partition_util.h b/cpp/src/arrow/compute/exec/partition_util.h index 6efda4aeeb0f3..ba0d932dc600a 100644 --- a/cpp/src/arrow/compute/exec/partition_util.h +++ b/cpp/src/arrow/compute/exec/partition_util.h @@ -115,6 +115,42 @@ class PartitionLocks { /// \brief Release a partition so that other threads can work on it void ReleasePartitionLock(int prtn_id); + template + Status ForEachPartition(int* temp_unprocessed_prtns, IS_PRTN_EMPTY_FN is_prtn_empty_fn, + PROCESS_PRTN_FN process_prtn_fn) { + int num_unprocessed_partitions = 0; + for (int i = 0; i < num_prtns_; ++i) { + bool is_prtn_empty = is_prtn_empty_fn(i); + if (!is_prtn_empty) { + temp_unprocessed_prtns[num_unprocessed_partitions++] = i; + } + } + while (num_unprocessed_partitions > 0) { + int locked_prtn_id; + int locked_prtn_id_pos; + AcquirePartitionLock(num_unprocessed_partitions, temp_unprocessed_prtns, + /*limit_retries=*/false, /*max_retries=*/-1, &locked_prtn_id, + &locked_prtn_id_pos); + { + class AutoReleaseLock { + public: + AutoReleaseLock(PartitionLocks* locks, int prtn_id) + : locks(locks), prtn_id(prtn_id) {} + ~AutoReleaseLock() { locks->ReleasePartitionLock(prtn_id); } + PartitionLocks* locks; + int prtn_id; + } auto_release_lock(this, locked_prtn_id); + ARROW_RETURN_NOT_OK(process_prtn_fn(locked_prtn_id)); + } + if (locked_prtn_id_pos < num_unprocessed_partitions - 1) { + temp_unprocessed_prtns[locked_prtn_id_pos] = + temp_unprocessed_prtns[num_unprocessed_partitions - 1]; + } + --num_unprocessed_partitions; + } + return Status::OK(); + } + private: std::atomic* lock_ptr(int prtn_id); int random_int(int num_values); diff --git a/cpp/src/arrow/compute/exec/schema_util.h b/cpp/src/arrow/compute/exec/schema_util.h index 4e307e238072e..adb949ef7dc1c 100644 --- a/cpp/src/arrow/compute/exec/schema_util.h +++ b/cpp/src/arrow/compute/exec/schema_util.h @@ -24,12 +24,8 @@ #include "arrow/compute/exec/key_encode.h" // for KeyColumnMetadata #include "arrow/type.h" // for DataType, FieldRef, Field and Schema -#include "arrow/util/mutex.h" namespace arrow { - -using internal::checked_cast; - namespace compute { // Identifiers for all different row schemas that are used in a join @@ -79,16 +75,28 @@ class SchemaProjectionMaps { int num_cols(ProjectionIdEnum 
schema_handle) const { int id = schema_id(schema_handle); - return static_cast(schemas_[id].second.size()); + return static_cast(schemas_[id].second.data_types.size()); + } + + bool is_empty(ProjectionIdEnum schema_handle) const { + return num_cols(schema_handle) == 0; } const std::string& field_name(ProjectionIdEnum schema_handle, int field_id) const { - return field(schema_handle, field_id).field_name; + int id = schema_id(schema_handle); + return schemas_[id].second.field_names[field_id]; } const std::shared_ptr& data_type(ProjectionIdEnum schema_handle, int field_id) const { - return field(schema_handle, field_id).data_type; + int id = schema_id(schema_handle); + return schemas_[id].second.data_types[field_id]; + } + + const std::vector>& data_types( + ProjectionIdEnum schema_handle) const { + int id = schema_id(schema_handle); + return schemas_[id].second.data_types; } SchemaProjectionMap map(ProjectionIdEnum from, ProjectionIdEnum to) const { @@ -102,22 +110,24 @@ class SchemaProjectionMaps { } protected: - struct FieldInfo { - int field_path; - std::string field_name; - std::shared_ptr data_type; + struct FieldInfos { + std::vector field_paths; + std::vector field_names; + std::vector> data_types; }; Status RegisterSchema(ProjectionIdEnum handle, const Schema& schema) { - std::vector out_fields; + FieldInfos out_fields; const FieldVector& in_fields = schema.fields(); - out_fields.resize(in_fields.size()); + out_fields.field_paths.resize(in_fields.size()); + out_fields.field_names.resize(in_fields.size()); + out_fields.data_types.resize(in_fields.size()); for (size_t i = 0; i < in_fields.size(); ++i) { const std::string& name = in_fields[i]->name(); const std::shared_ptr& type = in_fields[i]->type(); - out_fields[i].field_path = static_cast(i); - out_fields[i].field_name = name; - out_fields[i].data_type = type; + out_fields.field_paths[i] = static_cast(i); + out_fields.field_names[i] = name; + out_fields.data_types[i] = type; } schemas_.push_back(std::make_pair(handle, out_fields)); return Status::OK(); @@ -126,17 +136,19 @@ class SchemaProjectionMaps { Status RegisterProjectedSchema(ProjectionIdEnum handle, const std::vector& selected_fields, const Schema& full_schema) { - std::vector out_fields; + FieldInfos out_fields; const FieldVector& in_fields = full_schema.fields(); - out_fields.resize(selected_fields.size()); + out_fields.field_paths.resize(selected_fields.size()); + out_fields.field_names.resize(selected_fields.size()); + out_fields.data_types.resize(selected_fields.size()); for (size_t i = 0; i < selected_fields.size(); ++i) { // All fields must be found in schema without ambiguity ARROW_ASSIGN_OR_RAISE(auto match, selected_fields[i].FindOne(full_schema)); const std::string& name = in_fields[match[0]]->name(); const std::shared_ptr& type = in_fields[match[0]]->type(); - out_fields[i].field_path = match[0]; - out_fields[i].field_name = name; - out_fields[i].data_type = type; + out_fields.field_paths[i] = match[0]; + out_fields.field_names[i] = name; + out_fields.data_types[i] = type; } schemas_.push_back(std::make_pair(handle, out_fields)); return Status::OK(); @@ -163,15 +175,9 @@ class SchemaProjectionMaps { return -1; } - const FieldInfo& field(ProjectionIdEnum schema_handle, int field_id) const { - int id = schema_id(schema_handle); - const std::vector& field_infos = schemas_[id].second; - return field_infos[field_id]; - } - void GenerateMapForProjection(int id_proj, int id_base) { - int num_cols_proj = static_cast(schemas_[id_proj].second.size()); - int num_cols_base 
= static_cast(schemas_[id_base].second.size()); + int num_cols_proj = static_cast(schemas_[id_proj].second.data_types.size()); + int num_cols_base = static_cast(schemas_[id_base].second.data_types.size()); std::vector& mapping = mappings_[id_proj]; std::vector& inverse_mapping = inverse_mappings_[id_proj]; @@ -183,15 +189,15 @@ class SchemaProjectionMaps { mapping[i] = inverse_mapping[i] = i; } } else { - const std::vector& fields_proj = schemas_[id_proj].second; - const std::vector& fields_base = schemas_[id_base].second; + const FieldInfos& fields_proj = schemas_[id_proj].second; + const FieldInfos& fields_base = schemas_[id_base].second; for (int i = 0; i < num_cols_base; ++i) { inverse_mapping[i] = SchemaProjectionMap::kMissingField; } for (int i = 0; i < num_cols_proj; ++i) { int field_id = SchemaProjectionMap::kMissingField; for (int j = 0; j < num_cols_base; ++j) { - if (fields_proj[i].field_path == fields_base[j].field_path) { + if (fields_proj.field_paths[i] == fields_base.field_paths[j]) { field_id = j; // If there are multiple matches for the same input field, // it will be mapped to the first match. @@ -206,10 +212,12 @@ class SchemaProjectionMaps { } // vector used as a mapping from ProjectionIdEnum to fields - std::vector>> schemas_; + std::vector> schemas_; std::vector> mappings_; std::vector> inverse_mappings_; }; +using HashJoinProjectionMaps = SchemaProjectionMaps; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/swiss_join.cc b/cpp/src/arrow/compute/exec/swiss_join.cc new file mode 100644 index 0000000000000..62beba35db14d --- /dev/null +++ b/cpp/src/arrow/compute/exec/swiss_join.cc @@ -0,0 +1,3279 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
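Editorial aside on the schema_util.h change just above: SchemaProjectionMaps switched from a vector of per-field structs to per-projection parallel vectors, which is what makes the new data_types() accessor possible without copying. A simplified before/after sketch, with the field type reduced to std::string for brevity:

#include <string>
#include <vector>

// Before: one struct per field, so returning "all data types of a projection"
// would require assembling a temporary vector on every call.
struct FieldInfo {
  int field_path;
  std::string field_name;
  std::string data_type;
};

// After: one struct per projection with parallel vectors, so a whole column of
// the table can be handed out by const reference, as data_types() does above.
struct FieldInfos {
  std::vector<int> field_paths;
  std::vector<std::string> field_names;
  std::vector<std::string> data_types;
};

const std::vector<std::string>& AllDataTypes(const FieldInfos& fields) {
  return fields.data_types;
}

The swiss join presumably needs the full list of key or payload types in one call, so returning a reference avoids a per-call allocation.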
+ +#include "arrow/compute/exec/swiss_join.h" +#include +#include // std::upper_bound +#include +#include +#include +#include "arrow/array/util.h" // MakeArrayFromScalar +#include "arrow/compute/exec/hash_join.h" +#include "arrow/compute/exec/key_compare.h" +#include "arrow/compute/exec/key_encode.h" +#include "arrow/compute/exec/key_hash.h" +#include "arrow/compute/exec/util.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" + +namespace arrow { +namespace compute { + +void ResizableArrayData::Init(const std::shared_ptr& data_type, + MemoryPool* pool, int log_num_rows_min) { +#ifndef NDEBUG + if (num_rows_allocated_ > 0) { + ARROW_DCHECK(data_type_ != NULLPTR); + KeyEncoder::KeyColumnMetadata metadata_before = + ColumnMetadataFromDataType(data_type_); + KeyEncoder::KeyColumnMetadata metadata_after = ColumnMetadataFromDataType(data_type); + ARROW_DCHECK(metadata_before.is_fixed_length == metadata_after.is_fixed_length && + metadata_before.fixed_length == metadata_after.fixed_length); + } +#endif + Clear(/*release_buffers=*/false); + log_num_rows_min_ = log_num_rows_min; + data_type_ = data_type; + pool_ = pool; +} + +void ResizableArrayData::Clear(bool release_buffers) { + num_rows_ = 0; + if (release_buffers) { + non_null_buf_.reset(); + fixed_len_buf_.reset(); + var_len_buf_.reset(); + num_rows_allocated_ = 0; + var_len_buf_size_ = 0; + } +} + +Status ResizableArrayData::ResizeFixedLengthBuffers(int num_rows_new) { + ARROW_DCHECK(num_rows_new >= 0); + if (num_rows_new <= num_rows_allocated_) { + num_rows_ = num_rows_new; + return Status::OK(); + } + + int num_rows_allocated_new = 1 << log_num_rows_min_; + while (num_rows_allocated_new < num_rows_new) { + num_rows_allocated_new *= 2; + } + + KeyEncoder::KeyColumnMetadata column_metadata = ColumnMetadataFromDataType(data_type_); + + if (fixed_len_buf_ == NULLPTR) { + ARROW_DCHECK(non_null_buf_ == NULLPTR && var_len_buf_ == NULLPTR); + + ARROW_ASSIGN_OR_RAISE( + non_null_buf_, + AllocateResizableBuffer( + bit_util::BytesForBits(num_rows_allocated_new) + kNumPaddingBytes, pool_)); + if (column_metadata.is_fixed_length) { + if (column_metadata.fixed_length == 0) { + ARROW_ASSIGN_OR_RAISE( + fixed_len_buf_, + AllocateResizableBuffer( + bit_util::BytesForBits(num_rows_allocated_new) + kNumPaddingBytes, + pool_)); + } else { + ARROW_ASSIGN_OR_RAISE( + fixed_len_buf_, + AllocateResizableBuffer( + num_rows_allocated_new * column_metadata.fixed_length + kNumPaddingBytes, + pool_)); + } + } else { + ARROW_ASSIGN_OR_RAISE( + fixed_len_buf_, + AllocateResizableBuffer( + (num_rows_allocated_new + 1) * sizeof(uint32_t) + kNumPaddingBytes, pool_)); + } + + ARROW_ASSIGN_OR_RAISE(var_len_buf_, AllocateResizableBuffer( + sizeof(uint64_t) + kNumPaddingBytes, pool_)); + + var_len_buf_size_ = sizeof(uint64_t); + } else { + ARROW_DCHECK(non_null_buf_ != NULLPTR && var_len_buf_ != NULLPTR); + + RETURN_NOT_OK(non_null_buf_->Resize(bit_util::BytesForBits(num_rows_allocated_new) + + kNumPaddingBytes)); + + if (column_metadata.is_fixed_length) { + if (column_metadata.fixed_length == 0) { + RETURN_NOT_OK(fixed_len_buf_->Resize( + bit_util::BytesForBits(num_rows_allocated_new) + kNumPaddingBytes)); + } else { + RETURN_NOT_OK(fixed_len_buf_->Resize( + num_rows_allocated_new * column_metadata.fixed_length + kNumPaddingBytes)); + } + } else { + RETURN_NOT_OK(fixed_len_buf_->Resize( + (num_rows_allocated_new + 1) * sizeof(uint32_t) + kNumPaddingBytes)); + } + } + + num_rows_allocated_ = num_rows_allocated_new; + num_rows_ = num_rows_new; + + return 
Status::OK(); +} + +Status ResizableArrayData::ResizeVaryingLengthBuffer() { + KeyEncoder::KeyColumnMetadata column_metadata; + column_metadata = ColumnMetadataFromDataType(data_type_); + + if (!column_metadata.is_fixed_length) { + int min_new_size = static_cast( + reinterpret_cast(fixed_len_buf_->data())[num_rows_]); + ARROW_DCHECK(var_len_buf_size_ > 0); + if (var_len_buf_size_ < min_new_size) { + int new_size = var_len_buf_size_; + while (new_size < min_new_size) { + new_size *= 2; + } + RETURN_NOT_OK(var_len_buf_->Resize(new_size + kNumPaddingBytes)); + var_len_buf_size_ = new_size; + } + } + + return Status::OK(); +} + +KeyEncoder::KeyColumnArray ResizableArrayData::column_array() const { + KeyEncoder::KeyColumnMetadata column_metadata; + column_metadata = ColumnMetadataFromDataType(data_type_); + return KeyEncoder::KeyColumnArray( + column_metadata, num_rows_, non_null_buf_->mutable_data(), + fixed_len_buf_->mutable_data(), var_len_buf_->mutable_data()); +} + +std::shared_ptr ResizableArrayData::array_data() const { + KeyEncoder::KeyColumnMetadata column_metadata; + column_metadata = ColumnMetadataFromDataType(data_type_); + + auto valid_count = arrow::internal::CountSetBits(non_null_buf_->data(), /*offset=*/0, + static_cast(num_rows_)); + int null_count = static_cast(num_rows_) - static_cast(valid_count); + + if (column_metadata.is_fixed_length) { + return ArrayData::Make(data_type_, num_rows_, {non_null_buf_, fixed_len_buf_}, + null_count); + } else { + return ArrayData::Make(data_type_, num_rows_, + {non_null_buf_, fixed_len_buf_, var_len_buf_}, null_count); + } +} + +int ExecBatchBuilder::NumRowsToSkip(const std::shared_ptr& column, + int num_rows, const uint16_t* row_ids, + int num_tail_bytes_to_skip) { +#ifndef NDEBUG + // Ids must be in non-decreasing order + // + for (int i = 1; i < num_rows; ++i) { + ARROW_DCHECK(row_ids[i] >= row_ids[i - 1]); + } +#endif + + KeyEncoder::KeyColumnMetadata column_metadata = + ColumnMetadataFromDataType(column->type); + + int num_rows_left = num_rows; + int num_bytes_skipped = 0; + while (num_rows_left > 0 && num_bytes_skipped < num_tail_bytes_to_skip) { + if (column_metadata.is_fixed_length) { + if (column_metadata.fixed_length == 0) { + num_rows_left = std::max(num_rows_left, 8) - 8; + ++num_bytes_skipped; + } else { + --num_rows_left; + num_bytes_skipped += column_metadata.fixed_length; + } + } else { + --num_rows_left; + int row_id_removed = row_ids[num_rows_left]; + const uint32_t* offsets = + reinterpret_cast(column->buffers[1]->data()); + num_bytes_skipped += offsets[row_id_removed + 1] - offsets[row_id_removed]; + } + } + + return num_rows - num_rows_left; +} + +template +void ExecBatchBuilder::CollectBitsImp(const uint8_t* input_bits, + int64_t input_bits_offset, uint8_t* output_bits, + int64_t output_bits_offset, int num_rows, + const uint16_t* row_ids) { + if (!OUTPUT_BYTE_ALIGNED) { + ARROW_DCHECK(output_bits_offset % 8 > 0); + output_bits[output_bits_offset / 8] &= + static_cast((1 << (output_bits_offset % 8)) - 1); + } else { + ARROW_DCHECK(output_bits_offset % 8 == 0); + } + constexpr int unroll = 8; + for (int i = 0; i < num_rows / unroll; ++i) { + const uint16_t* row_ids_base = row_ids + unroll * i; + uint8_t result; + result = bit_util::GetBit(input_bits, input_bits_offset + row_ids_base[0]) ? 1 : 0; + result |= bit_util::GetBit(input_bits, input_bits_offset + row_ids_base[1]) ? 2 : 0; + result |= bit_util::GetBit(input_bits, input_bits_offset + row_ids_base[2]) ? 
4 : 0; + result |= bit_util::GetBit(input_bits, input_bits_offset + row_ids_base[3]) ? 8 : 0; + result |= bit_util::GetBit(input_bits, input_bits_offset + row_ids_base[4]) ? 16 : 0; + result |= bit_util::GetBit(input_bits, input_bits_offset + row_ids_base[5]) ? 32 : 0; + result |= bit_util::GetBit(input_bits, input_bits_offset + row_ids_base[6]) ? 64 : 0; + result |= bit_util::GetBit(input_bits, input_bits_offset + row_ids_base[7]) ? 128 : 0; + if (OUTPUT_BYTE_ALIGNED) { + output_bits[output_bits_offset / 8 + i] = result; + } else { + output_bits[output_bits_offset / 8 + i] |= + static_cast(result << (output_bits_offset % 8)); + output_bits[output_bits_offset / 8 + i + 1] = + static_cast(result >> (8 - (output_bits_offset % 8))); + } + } + if (num_rows % unroll > 0) { + for (int i = num_rows - (num_rows % unroll); i < num_rows; ++i) { + bit_util::SetBitTo(output_bits, output_bits_offset + i, + bit_util::GetBit(input_bits, input_bits_offset + row_ids[i])); + } + } +} + +void ExecBatchBuilder::CollectBits(const uint8_t* input_bits, int64_t input_bits_offset, + uint8_t* output_bits, int64_t output_bits_offset, + int num_rows, const uint16_t* row_ids) { + if (output_bits_offset % 8 > 0) { + CollectBitsImp(input_bits, input_bits_offset, output_bits, output_bits_offset, + num_rows, row_ids); + } else { + CollectBitsImp(input_bits, input_bits_offset, output_bits, output_bits_offset, + num_rows, row_ids); + } +} + +template +void ExecBatchBuilder::Visit(const std::shared_ptr& column, int num_rows, + const uint16_t* row_ids, PROCESS_VALUE_FN process_value_fn) { + KeyEncoder::KeyColumnMetadata metadata = ColumnMetadataFromDataType(column->type); + + if (!metadata.is_fixed_length) { + const uint8_t* ptr_base = column->buffers[2]->data(); + const uint32_t* offsets = + reinterpret_cast(column->buffers[1]->data()) + column->offset; + for (int i = 0; i < num_rows; ++i) { + uint16_t row_id = row_ids[i]; + const uint8_t* field_ptr = ptr_base + offsets[row_id]; + uint32_t field_length = offsets[row_id + 1] - offsets[row_id]; + process_value_fn(i, field_ptr, field_length); + } + } else { + ARROW_DCHECK(metadata.fixed_length > 0); + for (int i = 0; i < num_rows; ++i) { + uint16_t row_id = row_ids[i]; + const uint8_t* field_ptr = + column->buffers[1]->data() + + (column->offset + row_id) * static_cast(metadata.fixed_length); + process_value_fn(i, field_ptr, metadata.fixed_length); + } + } +} + +Status ExecBatchBuilder::AppendSelected(const std::shared_ptr& source, + ResizableArrayData& target, + int num_rows_to_append, const uint16_t* row_ids, + MemoryPool* pool) { + int num_rows_before = target.num_rows(); + ARROW_DCHECK(num_rows_before >= 0); + int num_rows_after = num_rows_before + num_rows_to_append; + if (target.num_rows() == 0) { + target.Init(source->type, pool, kLogNumRows); + } + RETURN_NOT_OK(target.ResizeFixedLengthBuffers(num_rows_after)); + + KeyEncoder::KeyColumnMetadata column_metadata = + ColumnMetadataFromDataType(source->type); + + if (column_metadata.is_fixed_length) { + // Fixed length column + // + uint32_t fixed_length = column_metadata.fixed_length; + switch (fixed_length) { + case 0: + CollectBits(source->buffers[1]->data(), source->offset, target.mutable_data(1), + num_rows_before, num_rows_to_append, row_ids); + break; + case 1: + Visit(source, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + target.mutable_data(1)[num_rows_before + i] = *ptr; + }); + break; + case 2: + Visit(source, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, 
uint32_t num_bytes) { + reinterpret_cast(target.mutable_data(1))[num_rows_before + i] = + *reinterpret_cast(ptr); + }); + break; + case 4: + Visit(source, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + reinterpret_cast(target.mutable_data(1))[num_rows_before + i] = + *reinterpret_cast(ptr); + }); + break; + case 8: + Visit(source, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + reinterpret_cast(target.mutable_data(1))[num_rows_before + i] = + *reinterpret_cast(ptr); + }); + break; + default: { + int num_rows_to_process = + num_rows_to_append - + NumRowsToSkip(source, num_rows_to_append, row_ids, sizeof(uint64_t)); + Visit(source, num_rows_to_process, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + uint64_t* dst = reinterpret_cast( + target.mutable_data(1) + + static_cast(num_bytes) * (num_rows_before + i)); + const uint64_t* src = reinterpret_cast(ptr); + for (uint32_t word_id = 0; + word_id < bit_util::CeilDiv(num_bytes, sizeof(uint64_t)); + ++word_id) { + util::SafeStore(dst + word_id, util::SafeLoad(src + word_id)); + } + }); + if (num_rows_to_append > num_rows_to_process) { + Visit(source, num_rows_to_append - num_rows_to_process, + row_ids + num_rows_to_process, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + uint64_t* dst = reinterpret_cast( + target.mutable_data(1) + + static_cast(num_bytes) * + (num_rows_before + num_rows_to_process + i)); + const uint64_t* src = reinterpret_cast(ptr); + memcpy(dst, src, num_bytes); + }); + } + } + } + } else { + // Varying length column + // + + // Step 1: calculate target offsets + // + uint32_t* offsets = reinterpret_cast(target.mutable_data(1)); + uint32_t sum = num_rows_before == 0 ? 0 : offsets[num_rows_before]; + Visit(source, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + offsets[num_rows_before + i] = num_bytes; + }); + for (int i = 0; i < num_rows_to_append; ++i) { + uint32_t length = offsets[num_rows_before + i]; + offsets[num_rows_before + i] = sum; + sum += length; + } + offsets[num_rows_before + num_rows_to_append] = sum; + + // Step 2: resize output buffers + // + RETURN_NOT_OK(target.ResizeVaryingLengthBuffer()); + + // Step 3: copy varying-length data + // + int num_rows_to_process = + num_rows_to_append - + NumRowsToSkip(source, num_rows_to_append, row_ids, sizeof(uint64_t)); + Visit(source, num_rows_to_process, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + uint64_t* dst = reinterpret_cast(target.mutable_data(2) + + offsets[num_rows_before + i]); + const uint64_t* src = reinterpret_cast(ptr); + for (uint32_t word_id = 0; + word_id < bit_util::CeilDiv(num_bytes, sizeof(uint64_t)); ++word_id) { + util::SafeStore(dst + word_id, util::SafeLoad(src + word_id)); + } + }); + Visit(source, num_rows_to_append - num_rows_to_process, row_ids + num_rows_to_process, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + uint64_t* dst = reinterpret_cast( + target.mutable_data(2) + + offsets[num_rows_before + num_rows_to_process + i]); + const uint64_t* src = reinterpret_cast(ptr); + memcpy(dst, src, num_bytes); + }); + } + + // Process nulls + // + if (source->buffers[0] == NULLPTR) { + uint8_t* dst = target.mutable_data(0); + dst[num_rows_before / 8] |= static_cast(~0ULL << (num_rows_before & 7)); + for (int i = num_rows_before / 8 + 1; + i < bit_util::BytesForBits(num_rows_before + num_rows_to_append); ++i) { + dst[i] = 0xff; + } + } else { + 
CollectBits(source->buffers[0]->data(), source->offset, target.mutable_data(0), + num_rows_before, num_rows_to_append, row_ids); + } + + return Status::OK(); +} + +Status ExecBatchBuilder::AppendNulls(const std::shared_ptr& type, + ResizableArrayData& target, int num_rows_to_append, + MemoryPool* pool) { + int num_rows_before = target.num_rows(); + int num_rows_after = num_rows_before + num_rows_to_append; + if (target.num_rows() == 0) { + target.Init(type, pool, kLogNumRows); + } + RETURN_NOT_OK(target.ResizeFixedLengthBuffers(num_rows_after)); + + KeyEncoder::KeyColumnMetadata column_metadata = ColumnMetadataFromDataType(type); + + // Process fixed length buffer + // + if (column_metadata.is_fixed_length) { + uint8_t* dst = target.mutable_data(1); + if (column_metadata.fixed_length == 0) { + dst[num_rows_before / 8] &= static_cast((1 << (num_rows_before % 8)) - 1); + int64_t offset_begin = num_rows_before / 8 + 1; + int64_t offset_end = bit_util::BytesForBits(num_rows_after); + if (offset_end > offset_begin) { + memset(dst + offset_begin, 0, offset_end - offset_begin); + } + } else { + memset(dst + num_rows_before * static_cast(column_metadata.fixed_length), + 0, static_cast(column_metadata.fixed_length) * num_rows_to_append); + } + } else { + uint32_t* dst = reinterpret_cast(target.mutable_data(1)); + uint32_t sum = num_rows_before == 0 ? 0 : dst[num_rows_before]; + for (int64_t i = num_rows_before; i <= num_rows_after; ++i) { + dst[i] = sum; + } + } + + // Process nulls + // + uint8_t* dst = target.mutable_data(0); + dst[num_rows_before / 8] &= static_cast((1 << (num_rows_before % 8)) - 1); + int64_t offset_begin = num_rows_before / 8 + 1; + int64_t offset_end = bit_util::BytesForBits(num_rows_after); + if (offset_end > offset_begin) { + memset(dst + offset_begin, 0, offset_end - offset_begin); + } + + return Status::OK(); +} + +Status ExecBatchBuilder::AppendSelected(MemoryPool* pool, const ExecBatch& batch, + int num_rows_to_append, const uint16_t* row_ids, + int num_cols, const int* col_ids) { + if (num_rows_to_append == 0) { + return Status::OK(); + } + // If this is the first time we append rows, then initialize output buffers. + // + if (values_.empty()) { + values_.resize(num_cols); + for (int i = 0; i < num_cols; ++i) { + const Datum& data = batch.values[col_ids ? col_ids[i] : i]; + ARROW_DCHECK(data.is_array()); + const std::shared_ptr& array_data = data.array(); + values_[i].Init(array_data->type, pool, kLogNumRows); + } + } + + for (size_t i = 0; i < values_.size(); ++i) { + const Datum& data = batch.values[col_ids ? 
col_ids[i] : i]; + ARROW_DCHECK(data.is_array()); + const std::shared_ptr& array_data = data.array(); + RETURN_NOT_OK( + AppendSelected(array_data, values_[i], num_rows_to_append, row_ids, pool)); + } + + return Status::OK(); +} + +Status ExecBatchBuilder::AppendSelected(MemoryPool* pool, const ExecBatch& batch, + int num_rows_to_append, const uint16_t* row_ids, + int* num_appended, int num_cols, + const int* col_ids) { + *num_appended = 0; + if (num_rows_to_append == 0) { + return Status::OK(); + } + int num_rows_max = 1 << kLogNumRows; + int num_rows_present = num_rows(); + if (num_rows_present >= num_rows_max) { + return Status::OK(); + } + int num_rows_available = num_rows_max - num_rows_present; + int num_rows_next = std::min(num_rows_available, num_rows_to_append); + RETURN_NOT_OK(AppendSelected(pool, batch, num_rows_next, row_ids, num_cols, col_ids)); + *num_appended = num_rows_next; + return Status::OK(); +} + +Status ExecBatchBuilder::AppendNulls(MemoryPool* pool, + const std::vector>& types, + int num_rows_to_append) { + if (num_rows_to_append == 0) { + return Status::OK(); + } + + // If this is the first time we append rows, then initialize output buffers. + // + if (values_.empty()) { + values_.resize(types.size()); + for (size_t i = 0; i < types.size(); ++i) { + values_[i].Init(types[i], pool, kLogNumRows); + } + } + + for (size_t i = 0; i < values_.size(); ++i) { + RETURN_NOT_OK(AppendNulls(types[i], values_[i], num_rows_to_append, pool)); + } + + return Status::OK(); +} + +Status ExecBatchBuilder::AppendNulls(MemoryPool* pool, + const std::vector>& types, + int num_rows_to_append, int* num_appended) { + *num_appended = 0; + if (num_rows_to_append == 0) { + return Status::OK(); + } + int num_rows_max = 1 << kLogNumRows; + int num_rows_present = num_rows(); + if (num_rows_present >= num_rows_max) { + return Status::OK(); + } + int num_rows_available = num_rows_max - num_rows_present; + int num_rows_next = std::min(num_rows_available, num_rows_to_append); + RETURN_NOT_OK(AppendNulls(pool, types, num_rows_next)); + *num_appended = num_rows_next; + return Status::OK(); +} + +ExecBatch ExecBatchBuilder::Flush() { + ARROW_DCHECK(num_rows() > 0); + ExecBatch out({}, num_rows()); + out.values.resize(values_.size()); + for (size_t i = 0; i < values_.size(); ++i) { + out.values[i] = values_[i].array_data(); + values_[i].Clear(true); + } + return out; +} + +int RowArrayAccessor::VarbinaryColumnId(const KeyEncoder::KeyRowMetadata& row_metadata, + int column_id) { + ARROW_DCHECK(row_metadata.num_cols() > static_cast(column_id)); + ARROW_DCHECK(!row_metadata.is_fixed_length); + ARROW_DCHECK(!row_metadata.column_metadatas[column_id].is_fixed_length); + + int varbinary_column_id = 0; + for (int i = 0; i < column_id; ++i) { + if (!row_metadata.column_metadatas[i].is_fixed_length) { + ++varbinary_column_id; + } + } + return varbinary_column_id; +} + +int RowArrayAccessor::NumRowsToSkip(const KeyEncoder::KeyRowArray& rows, int column_id, + int num_rows, const uint32_t* row_ids, + int num_tail_bytes_to_skip) { + uint32_t num_bytes_skipped = 0; + int num_rows_left = num_rows; + + bool is_fixed_length_column = + rows.metadata().column_metadatas[column_id].is_fixed_length; + + if (!is_fixed_length_column) { + // Varying length column + // + int varbinary_column_id = VarbinaryColumnId(rows.metadata(), column_id); + + while (num_rows_left > 0 && + num_bytes_skipped < static_cast(num_tail_bytes_to_skip)) { + // Find the pointer to the last requested row + // + uint32_t last_row_id = 
row_ids[num_rows_left - 1]; + const uint8_t* row_ptr = rows.data(2) + rows.offsets()[last_row_id]; + + // Find the length of the requested varying length field in that row + // + uint32_t field_offset_within_row, field_length; + if (varbinary_column_id == 0) { + rows.metadata().first_varbinary_offset_and_length( + row_ptr, &field_offset_within_row, &field_length); + } else { + rows.metadata().nth_varbinary_offset_and_length( + row_ptr, varbinary_column_id, &field_offset_within_row, &field_length); + } + + num_bytes_skipped += field_length; + --num_rows_left; + } + } else { + // Fixed length column + // + uint32_t field_length = rows.metadata().column_metadatas[column_id].fixed_length; + uint32_t num_bytes_skipped = 0; + while (num_rows_left > 0 && + num_bytes_skipped < static_cast(num_tail_bytes_to_skip)) { + num_bytes_skipped += field_length; + --num_rows_left; + } + } + + return num_rows - num_rows_left; +} + +template +void RowArrayAccessor::Visit(const KeyEncoder::KeyRowArray& rows, int column_id, + int num_rows, const uint32_t* row_ids, + PROCESS_VALUE_FN process_value_fn) { + bool is_fixed_length_column = + rows.metadata().column_metadatas[column_id].is_fixed_length; + + // There are 4 cases, each requiring different steps: + // 1. Varying length column that is the first varying length column in a row + // 2. Varying length column that is not the first varying length column in a + // row + // 3. Fixed length column in a fixed length row + // 4. Fixed length column in a varying length row + + if (!is_fixed_length_column) { + int varbinary_column_id = VarbinaryColumnId(rows.metadata(), column_id); + const uint8_t* row_ptr_base = rows.data(2); + const uint32_t* row_offsets = rows.offsets(); + uint32_t field_offset_within_row, field_length; + + if (varbinary_column_id == 0) { + // Case 1: This is the first varbinary column + // + for (int i = 0; i < num_rows; ++i) { + uint32_t row_id = row_ids[i]; + const uint8_t* row_ptr = row_ptr_base + row_offsets[row_id]; + rows.metadata().first_varbinary_offset_and_length( + row_ptr, &field_offset_within_row, &field_length); + process_value_fn(i, row_ptr + field_offset_within_row, field_length); + } + } else { + // Case 2: This is second or later varbinary column + // + for (int i = 0; i < num_rows; ++i) { + uint32_t row_id = row_ids[i]; + const uint8_t* row_ptr = row_ptr_base + row_offsets[row_id]; + rows.metadata().nth_varbinary_offset_and_length( + row_ptr, varbinary_column_id, &field_offset_within_row, &field_length); + process_value_fn(i, row_ptr + field_offset_within_row, field_length); + } + } + } + + if (is_fixed_length_column) { + uint32_t field_offset_within_row = rows.metadata().encoded_field_offset( + rows.metadata().pos_after_encoding(column_id)); + uint32_t field_length = rows.metadata().column_metadatas[column_id].fixed_length; + // Bit column is encoded as a single byte + // + if (field_length == 0) { + field_length = 1; + } + uint32_t row_length = rows.metadata().fixed_length; + + bool is_fixed_length_row = rows.metadata().is_fixed_length; + if (is_fixed_length_row) { + // Case 3: This is a fixed length column in a fixed length row + // + const uint8_t* row_ptr_base = rows.data(1) + field_offset_within_row; + for (int i = 0; i < num_rows; ++i) { + uint32_t row_id = row_ids[i]; + const uint8_t* row_ptr = row_ptr_base + row_length * row_id; + process_value_fn(i, row_ptr, field_length); + } + } else { + // Case 4: This is a fixed length column in a varying length row + // + const uint8_t* row_ptr_base = rows.data(2) + 
field_offset_within_row; + const uint32_t* row_offsets = rows.offsets(); + for (int i = 0; i < num_rows; ++i) { + uint32_t row_id = row_ids[i]; + const uint8_t* row_ptr = row_ptr_base + row_offsets[row_id]; + process_value_fn(i, row_ptr, field_length); + } + } + } +} + +template +void RowArrayAccessor::VisitNulls(const KeyEncoder::KeyRowArray& rows, int column_id, + int num_rows, const uint32_t* row_ids, + PROCESS_VALUE_FN process_value_fn) { + const uint8_t* null_masks = rows.null_masks(); + uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row; + uint32_t pos_after_encoding = rows.metadata().pos_after_encoding(column_id); + for (int i = 0; i < num_rows; ++i) { + uint32_t row_id = row_ids[i]; + int64_t bit_id = row_id * null_mask_num_bytes * 8 + pos_after_encoding; + process_value_fn(i, bit_util::GetBit(null_masks, bit_id) ? 0xff : 0); + } +} + +Status RowArray::InitIfNeeded(MemoryPool* pool, + const KeyEncoder::KeyRowMetadata& row_metadata) { + if (is_initialized_) { + return Status::OK(); + } + encoder_.Init(row_metadata.column_metadatas, sizeof(uint64_t), sizeof(uint64_t)); + RETURN_NOT_OK(rows_temp_.Init(pool, row_metadata)); + RETURN_NOT_OK(rows_.Init(pool, row_metadata)); + is_initialized_ = true; + return Status::OK(); +} + +Status RowArray::InitIfNeeded(MemoryPool* pool, const ExecBatch& batch) { + if (is_initialized_) { + return Status::OK(); + } + std::vector column_metadatas; + ColumnMetadatasFromExecBatch(batch, column_metadatas); + KeyEncoder::KeyRowMetadata row_metadata; + row_metadata.FromColumnMetadataVector(column_metadatas, sizeof(uint64_t), + sizeof(uint64_t)); + + return InitIfNeeded(pool, row_metadata); +} + +Status RowArray::AppendBatchSelection( + MemoryPool* pool, const ExecBatch& batch, int begin_row_id, int end_row_id, + int num_row_ids, const uint16_t* row_ids, + std::vector& temp_column_arrays) { + RETURN_NOT_OK(InitIfNeeded(pool, batch)); + ColumnArraysFromExecBatch(batch, begin_row_id, end_row_id - begin_row_id, + temp_column_arrays); + encoder_.PrepareEncodeSelected( + /*start_row=*/0, end_row_id - begin_row_id, temp_column_arrays); + RETURN_NOT_OK(encoder_.EncodeSelected(&rows_temp_, num_row_ids, row_ids)); + RETURN_NOT_OK(rows_.AppendSelectionFrom(rows_temp_, num_row_ids, nullptr)); + return Status::OK(); +} + +void RowArray::Compare(const ExecBatch& batch, int begin_row_id, int end_row_id, + int num_selected, const uint16_t* batch_selection_maybe_null, + const uint32_t* array_row_ids, uint32_t* out_num_not_equal, + uint16_t* out_not_equal_selection, int64_t hardware_flags, + util::TempVectorStack* temp_stack, + std::vector& temp_column_arrays, + uint8_t* out_match_bitvector_maybe_null) { + ColumnArraysFromExecBatch(batch, begin_row_id, end_row_id - begin_row_id, + temp_column_arrays); + + KeyEncoder::KeyEncoderContext ctx; + ctx.hardware_flags = hardware_flags; + ctx.stack = temp_stack; + KeyCompare::CompareColumnsToRows( + num_selected, batch_selection_maybe_null, array_row_ids, &ctx, out_num_not_equal, + out_not_equal_selection, temp_column_arrays, rows_, + /*are_cols_in_encoding_order=*/false, out_match_bitvector_maybe_null); +} + +Status RowArray::DecodeSelected(ResizableArrayData* output, int column_id, + int num_rows_to_append, const uint32_t* row_ids, + MemoryPool* pool) const { + int num_rows_before = output->num_rows(); + RETURN_NOT_OK(output->ResizeFixedLengthBuffers(num_rows_before + num_rows_to_append)); + + // Both input (KeyRowArray) and output (ResizableArrayData) have buffers with + // extra bytes added at the end to 
avoid buffer overruns when using wide load + // instructions. + // + + KeyEncoder::KeyColumnMetadata column_metadata = output->column_metadata(); + + if (column_metadata.is_fixed_length) { + uint32_t fixed_length = column_metadata.fixed_length; + switch (fixed_length) { + case 0: + RowArrayAccessor::Visit(rows_, column_id, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + bit_util::SetBitTo(output->mutable_data(1), + num_rows_before + i, *ptr != 0); + }); + break; + case 1: + RowArrayAccessor::Visit(rows_, column_id, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + output->mutable_data(1)[num_rows_before + i] = *ptr; + }); + break; + case 2: + RowArrayAccessor::Visit( + rows_, column_id, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + reinterpret_cast(output->mutable_data(1))[num_rows_before + i] = + *reinterpret_cast(ptr); + }); + break; + case 4: + RowArrayAccessor::Visit( + rows_, column_id, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + reinterpret_cast(output->mutable_data(1))[num_rows_before + i] = + *reinterpret_cast(ptr); + }); + break; + case 8: + RowArrayAccessor::Visit( + rows_, column_id, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + reinterpret_cast(output->mutable_data(1))[num_rows_before + i] = + *reinterpret_cast(ptr); + }); + break; + default: + RowArrayAccessor::Visit( + rows_, column_id, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + uint64_t* dst = reinterpret_cast( + output->mutable_data(1) + num_bytes * (num_rows_before + i)); + const uint64_t* src = reinterpret_cast(ptr); + for (uint32_t word_id = 0; + word_id < bit_util::CeilDiv(num_bytes, sizeof(uint64_t)); ++word_id) { + util::SafeStore(dst + word_id, util::SafeLoad(src + word_id)); + } + }); + break; + } + } else { + uint32_t* offsets = + reinterpret_cast(output->mutable_data(1)) + num_rows_before; + uint32_t sum = num_rows_before == 0 ? 
0 : offsets[0]; + RowArrayAccessor::Visit( + rows_, column_id, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { offsets[i] = num_bytes; }); + for (int i = 0; i < num_rows_to_append; ++i) { + uint32_t length = offsets[i]; + offsets[i] = sum; + sum += length; + } + offsets[num_rows_to_append] = sum; + RETURN_NOT_OK(output->ResizeVaryingLengthBuffer()); + RowArrayAccessor::Visit( + rows_, column_id, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + uint64_t* dst = reinterpret_cast( + output->mutable_data(2) + + reinterpret_cast( + output->mutable_data(1))[num_rows_before + i]); + const uint64_t* src = reinterpret_cast(ptr); + for (uint32_t word_id = 0; + word_id < bit_util::CeilDiv(num_bytes, sizeof(uint64_t)); ++word_id) { + util::SafeStore(dst + word_id, util::SafeLoad(src + word_id)); + } + }); + } + + // Process nulls + // + RowArrayAccessor::VisitNulls( + rows_, column_id, num_rows_to_append, row_ids, [&](int i, uint8_t value) { + bit_util::SetBitTo(output->mutable_data(0), num_rows_before + i, value == 0); + }); + + return Status::OK(); +} + +void RowArray::DebugPrintToFile(const char* filename, bool print_sorted) const { + FILE* fout; +#if defined(_MSC_VER) && _MSC_VER >= 1400 + fopen_s(&fout, filename, "wt"); +#else + fout = fopen(filename, "wt"); +#endif + if (!fout) { + return; + } + + for (int64_t row_id = 0; row_id < rows_.length(); ++row_id) { + for (uint32_t column_id = 0; column_id < rows_.metadata().num_cols(); ++column_id) { + bool is_null; + uint32_t row_id_cast = static_cast(row_id); + RowArrayAccessor::VisitNulls(rows_, column_id, 1, &row_id_cast, + [&](int i, uint8_t value) { is_null = (value != 0); }); + if (is_null) { + fprintf(fout, "null"); + } else { + RowArrayAccessor::Visit(rows_, column_id, 1, &row_id_cast, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + fprintf(fout, "\""); + for (uint32_t ibyte = 0; ibyte < num_bytes; ++ibyte) { + fprintf(fout, "%02x", ptr[ibyte]); + } + fprintf(fout, "\""); + }); + } + fprintf(fout, "\t"); + } + fprintf(fout, "\n"); + } + fclose(fout); + + if (print_sorted) { + struct stat sb; + if (stat(filename, &sb) == -1) { + ARROW_DCHECK(false); + return; + } + std::vector buffer; + buffer.resize(sb.st_size); + std::vector lines; + FILE* fin; +#if defined(_MSC_VER) && _MSC_VER >= 1400 + fopen_s(&fin, filename, "rt"); +#else + fin = fopen(filename, "rt"); +#endif + if (!fin) { + return; + } + while (fgets(buffer.data(), static_cast(buffer.size()), fin)) { + lines.push_back(std::string(buffer.data())); + } + fclose(fin); + std::sort(lines.begin(), lines.end()); + FILE* fout2; +#if defined(_MSC_VER) && _MSC_VER >= 1400 + fopen_s(&fout2, filename, "wt"); +#else + fout2 = fopen(filename, "wt"); +#endif + if (!fout2) { + return; + } + for (size_t i = 0; i < lines.size(); ++i) { + fprintf(fout2, "%s\n", lines[i].c_str()); + } + fclose(fout2); + } +} + +Status RowArrayMerge::PrepareForMerge(RowArray* target, + const std::vector& sources, + std::vector* first_target_row_id, + MemoryPool* pool) { + ARROW_DCHECK(!sources.empty()); + + ARROW_DCHECK(sources[0]->is_initialized_); + const KeyEncoder::KeyRowMetadata& metadata = sources[0]->rows_.metadata(); + ARROW_DCHECK(!target->is_initialized_); + RETURN_NOT_OK(target->InitIfNeeded(pool, metadata)); + + // Sum the number of rows from all input sources and calculate their total + // size. 
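+  // Illustrative example (hypothetical sizes): for three sources holding
+  // 3, 5 and 2 rows respectively, the loop below produces
+  //   first_target_row_id = {0, 3, 8, 10},
+  // so source i owns target rows [first_target_row_id[i],
+  // first_target_row_id[i + 1]) and the last entry is the total row count.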
+ // + int64_t num_rows = 0; + int64_t num_bytes = 0; + first_target_row_id->resize(sources.size() + 1); + for (size_t i = 0; i < sources.size(); ++i) { + // All input sources must be initialized and have the same row format. + // + ARROW_DCHECK(sources[i]->is_initialized_); + ARROW_DCHECK(metadata.is_compatible(sources[i]->rows_.metadata())); + (*first_target_row_id)[i] = num_rows; + num_rows += sources[i]->rows_.length(); + if (!metadata.is_fixed_length) { + num_bytes += sources[i]->rows_.offsets()[sources[i]->rows_.length()]; + } + } + (*first_target_row_id)[sources.size()] = num_rows; + + // Allocate target memory + // + target->rows_.Clean(); + RETURN_NOT_OK(target->rows_.AppendEmpty(static_cast(num_rows), + static_cast(num_bytes))); + + // In case of varying length rows, + // initialize the first row offset for each range of rows corresponding to a + // single source. + // + if (!metadata.is_fixed_length) { + num_rows = 0; + num_bytes = 0; + for (size_t i = 0; i < sources.size(); ++i) { + target->rows_.mutable_offsets()[num_rows] = static_cast(num_bytes); + num_rows += sources[i]->rows_.length(); + num_bytes += sources[i]->rows_.offsets()[sources[i]->rows_.length()]; + } + target->rows_.mutable_offsets()[num_rows] = static_cast(num_bytes); + } + + return Status::OK(); +} + +void RowArrayMerge::MergeSingle(RowArray* target, const RowArray& source, + int64_t first_target_row_id, + const int64_t* source_rows_permutation) { + // Source and target must: + // - be initialized + // - use the same row format + // - use 64-bit alignment + // + ARROW_DCHECK(source.is_initialized_ && target->is_initialized_); + ARROW_DCHECK(target->rows_.metadata().is_compatible(source.rows_.metadata())); + ARROW_DCHECK(target->rows_.metadata().row_alignment == sizeof(uint64_t)); + + if (target->rows_.metadata().is_fixed_length) { + CopyFixedLength(&target->rows_, source.rows_, first_target_row_id, + source_rows_permutation); + } else { + CopyVaryingLength(&target->rows_, source.rows_, first_target_row_id, + target->rows_.offsets()[first_target_row_id], + source_rows_permutation); + } + CopyNulls(&target->rows_, source.rows_, first_target_row_id, source_rows_permutation); +} + +void RowArrayMerge::CopyFixedLength(KeyEncoder::KeyRowArray* target, + const KeyEncoder::KeyRowArray& source, + int64_t first_target_row_id, + const int64_t* source_rows_permutation) { + int64_t num_source_rows = source.length(); + + int64_t fixed_length = target->metadata().fixed_length; + + // Permutation of source rows is optional. Without permutation all that is + // needed is memcpy. + // + if (!source_rows_permutation) { + memcpy(target->mutable_data(1) + fixed_length * first_target_row_id, source.data(1), + fixed_length * num_source_rows); + } else { + // Row length must be a multiple of 64-bits due to enforced alignment. + // Loop for each output row copying a fixed number of 64-bit words. 
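+    // Illustrative example (hypothetical sizes): with fixed_length == 24 the
+    // loop below copies num_words_per_row == 3 64-bit words per output row,
+    // reading output row i from source row source_rows_permutation[i].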
+ // + ARROW_DCHECK(fixed_length % sizeof(uint64_t) == 0); + + int64_t num_words_per_row = fixed_length / sizeof(uint64_t); + for (int64_t i = 0; i < num_source_rows; ++i) { + int64_t source_row_id = source_rows_permutation[i]; + const uint64_t* source_row_ptr = reinterpret_cast( + source.data(1) + fixed_length * source_row_id); + uint64_t* target_row_ptr = reinterpret_cast( + target->mutable_data(1) + fixed_length * (first_target_row_id + i)); + + for (int64_t word = 0; word < num_words_per_row; ++word) { + target_row_ptr[word] = source_row_ptr[word]; + } + } + } +} + +void RowArrayMerge::CopyVaryingLength(KeyEncoder::KeyRowArray* target, + const KeyEncoder::KeyRowArray& source, + int64_t first_target_row_id, + int64_t first_target_row_offset, + const int64_t* source_rows_permutation) { + int64_t num_source_rows = source.length(); + uint32_t* target_offsets = target->mutable_offsets(); + const uint32_t* source_offsets = source.offsets(); + + // Permutation of source rows is optional. + // + if (!source_rows_permutation) { + int64_t target_row_offset = first_target_row_offset; + for (int64_t i = 0; i < num_source_rows; ++i) { + target_offsets[first_target_row_id + i] = static_cast(target_row_offset); + target_row_offset += source_offsets[i + 1] - source_offsets[i]; + } + // We purposefully skip outputting of N+1 offset, to allow concurrent + // copies of rows done to adjacent ranges in target array. + // It should have already been initialized during preparation for merge. + // + + // We can simply memcpy bytes of rows if their order has not changed. + // + memcpy(target->mutable_data(2) + target_offsets[first_target_row_id], source.data(2), + source_offsets[num_source_rows] - source_offsets[0]); + } else { + int64_t target_row_offset = first_target_row_offset; + uint64_t* target_row_ptr = + reinterpret_cast(target->mutable_data(2) + target_row_offset); + for (int64_t i = 0; i < num_source_rows; ++i) { + int64_t source_row_id = source_rows_permutation[i]; + const uint64_t* source_row_ptr = reinterpret_cast( + source.data(2) + source_offsets[source_row_id]); + uint32_t length = source_offsets[source_row_id + 1] - source_offsets[source_row_id]; + + // Rows should be 64-bit aligned. + // In that case we can copy them using a sequence of 64-bit read/writes. 
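+      // Illustrative example (hypothetical values): with
+      // source_offsets == {0, 16, 40, 48} and permutation {2, 0, 1}, the rows
+      // are copied with lengths 8, 16 and 24 bytes, and the target offsets
+      // written below become base + 0, base + 8 and base + 24, where base is
+      // first_target_row_offset.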
+ // + ARROW_DCHECK(length % sizeof(uint64_t) == 0); + + for (uint32_t word = 0; word < length / sizeof(uint64_t); ++word) { + *target_row_ptr++ = *source_row_ptr++; + } + + target_offsets[first_target_row_id + i] = static_cast(target_row_offset); + target_row_offset += length; + } + } +} + +void RowArrayMerge::CopyNulls(KeyEncoder::KeyRowArray* target, + const KeyEncoder::KeyRowArray& source, + int64_t first_target_row_id, + const int64_t* source_rows_permutation) { + int64_t num_source_rows = source.length(); + int num_bytes_per_row = target->metadata().null_masks_bytes_per_row; + uint8_t* target_nulls = target->null_masks() + num_bytes_per_row * first_target_row_id; + if (!source_rows_permutation) { + memcpy(target_nulls, source.null_masks(), num_bytes_per_row * num_source_rows); + } else { + for (int64_t i = 0; i < num_source_rows; ++i) { + int64_t source_row_id = source_rows_permutation[i]; + const uint8_t* source_nulls = + source.null_masks() + num_bytes_per_row * source_row_id; + for (int64_t byte = 0; byte < num_bytes_per_row; ++byte) { + *target_nulls++ = *source_nulls++; + } + } + } +} + +Status SwissTableMerge::PrepareForMerge(SwissTable* target, + const std::vector& sources, + std::vector* first_target_group_id, + MemoryPool* pool) { + ARROW_DCHECK(!sources.empty()); + + // Each source should correspond to a range of hashes. + // A row belongs to a source with index determined by K highest bits of hash. + // That means that the number of sources must be a power of 2. + // + int log_num_sources = bit_util::Log2(sources.size()); + ARROW_DCHECK((1 << log_num_sources) == static_cast(sources.size())); + + // Determine the number of blocks in the target table. + // We will use max of numbers of blocks in any of the sources multiplied by + // the number of sources. + // + int log_blocks_max = 1; + for (size_t i = 0; i < sources.size(); ++i) { + log_blocks_max = std::max(log_blocks_max, sources[i]->log_blocks_); + } + int log_blocks = log_num_sources + log_blocks_max; + + // Allocate target blocks and mark all slots as empty + // + // We will skip allocating the array of hash values in target table. + // Target will be used in read-only mode and that array is only needed when + // resizing table which may occur only after new inserts. + // + RETURN_NOT_OK(target->init(sources[0]->hardware_flags_, pool, log_blocks, + /*no_hash_array=*/true)); + + // Calculate and output the first group id index for each source. + // + uint32_t num_groups = 0; + first_target_group_id->resize(sources.size()); + for (size_t i = 0; i < sources.size(); ++i) { + (*first_target_group_id)[i] = num_groups; + num_groups += sources[i]->num_inserted_; + } + target->num_inserted_ = num_groups; + + return Status::OK(); +} + +void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* source, + uint32_t partition_id, int num_partition_bits, + uint32_t base_group_id, + std::vector* overflow_group_ids, + std::vector* overflow_hashes) { + // Prepare parameters needed for scanning full slots in source. + // + int source_group_id_bits = + SwissTable::num_groupid_bits_from_log_blocks(source->log_blocks_); + uint64_t source_group_id_mask = ~0ULL >> (64 - source_group_id_bits); + int64_t source_block_bytes = source_group_id_bits + 8; + ARROW_DCHECK(source_block_bytes % sizeof(uint64_t) == 0); + + // Compute index of the last block in target that corresponds to the given + // partition. 
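+  // Illustrative example (hypothetical sizes): with target->log_blocks_ == 10
+  // and num_partition_bits == 2, partition 1 owns target blocks
+  // [1 << 8, (2 << 8) - 1] == [256, 511], so target_max_block_id == 511.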
+ // + ARROW_DCHECK(num_partition_bits <= target->log_blocks_); + int64_t target_max_block_id = + ((partition_id + 1) << (target->log_blocks_ - num_partition_bits)) - 1; + + overflow_group_ids->clear(); + overflow_hashes->clear(); + + // For each source block... + int64_t source_blocks = 1LL << source->log_blocks_; + for (int64_t block_id = 0; block_id < source_blocks; ++block_id) { + uint8_t* block_bytes = source->blocks_ + block_id * source_block_bytes; + uint64_t block = *reinterpret_cast(block_bytes); + + // For each non-empty source slot... + constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL; + constexpr int kSlotsPerBlock = 8; + int num_full_slots = + kSlotsPerBlock - static_cast(ARROW_POPCOUNT64(block & kHighBitOfEachByte)); + for (int local_slot_id = 0; local_slot_id < num_full_slots; ++local_slot_id) { + // Read group id and hash for this slot. + // + uint64_t group_id = + source->extract_group_id(block_bytes, local_slot_id, source_group_id_mask); + int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id; + uint32_t hash = source->hashes_[global_slot_id]; + // Insert partition id into the highest bits of hash, shifting the + // remaining hash bits right. + // + hash >>= num_partition_bits; + hash |= (partition_id << (SwissTable::bits_hash_ - 1 - num_partition_bits) << 1); + // Add base group id + // + group_id += base_group_id; + + // Insert new entry into target. Store in overflow vectors if not + // successful. + // + bool was_inserted = InsertNewGroup(target, group_id, hash, target_max_block_id); + if (!was_inserted) { + overflow_group_ids->push_back(static_cast(group_id)); + overflow_hashes->push_back(hash); + } + } + } +} + +inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint64_t group_id, + uint32_t hash, int64_t max_block_id) { + // Load the first block to visit for this hash + // + int64_t block_id = hash >> (SwissTable::bits_hash_ - target->log_blocks_); + int64_t block_id_mask = ((1LL << target->log_blocks_) - 1); + int num_group_id_bits = + SwissTable::num_groupid_bits_from_log_blocks(target->log_blocks_); + int64_t num_block_bytes = num_group_id_bits + sizeof(uint64_t); + ARROW_DCHECK(num_block_bytes % sizeof(uint64_t) == 0); + uint8_t* block_bytes = target->blocks_ + block_id * num_block_bytes; + uint64_t block = *reinterpret_cast(block_bytes); + + // Search for the first block with empty slots. + // Stop after reaching max block id. 
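+  // In the loop below an empty slot is marked by the high bit of its status
+  // byte, so (block & kHighBitOfEachByte) == 0 means the block is full.
+  // Illustrative example (hypothetical contents): if 3 of the 8 status bytes
+  // have their high bit set, the block holds 5 entries and the first empty
+  // slot is local_slot_id == 8 - 3 == 5, since full slots occupy the lowest
+  // slot indices.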
+ // + constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL; + while ((block & kHighBitOfEachByte) == 0 && block_id < max_block_id) { + block_id = (block_id + 1) & block_id_mask; + block_bytes = target->blocks_ + block_id * num_block_bytes; + block = *reinterpret_cast(block_bytes); + } + if ((block & kHighBitOfEachByte) == 0) { + return false; + } + constexpr int kSlotsPerBlock = 8; + int local_slot_id = + kSlotsPerBlock - static_cast(ARROW_POPCOUNT64(block & kHighBitOfEachByte)); + int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id; + target->insert_into_empty_slot(static_cast(global_slot_id), hash, + static_cast(group_id)); + return true; +} + +void SwissTableMerge::InsertNewGroups(SwissTable* target, + const std::vector& group_ids, + const std::vector& hashes) { + int64_t num_blocks = 1LL << target->log_blocks_; + for (size_t i = 0; i < group_ids.size(); ++i) { + std::ignore = InsertNewGroup(target, group_ids[i], hashes[i], num_blocks); + } +} + +SwissTableWithKeys::Input::Input( + const ExecBatch* in_batch, int in_batch_start_row, int in_batch_end_row, + util::TempVectorStack* in_temp_stack, + std::vector* in_temp_column_arrays) + : batch(in_batch), + batch_start_row(in_batch_start_row), + batch_end_row(in_batch_end_row), + num_selected(0), + selection_maybe_null(nullptr), + temp_stack(in_temp_stack), + temp_column_arrays(in_temp_column_arrays), + temp_group_ids(nullptr) {} + +SwissTableWithKeys::Input::Input( + const ExecBatch* in_batch, util::TempVectorStack* in_temp_stack, + std::vector* in_temp_column_arrays) + : batch(in_batch), + batch_start_row(0), + batch_end_row(static_cast(in_batch->length)), + num_selected(0), + selection_maybe_null(nullptr), + temp_stack(in_temp_stack), + temp_column_arrays(in_temp_column_arrays), + temp_group_ids(nullptr) {} + +SwissTableWithKeys::Input::Input( + const ExecBatch* in_batch, int in_num_selected, const uint16_t* in_selection, + util::TempVectorStack* in_temp_stack, + std::vector* in_temp_column_arrays, + std::vector* in_temp_group_ids) + : batch(in_batch), + batch_start_row(0), + batch_end_row(static_cast(in_batch->length)), + num_selected(in_num_selected), + selection_maybe_null(in_selection), + temp_stack(in_temp_stack), + temp_column_arrays(in_temp_column_arrays), + temp_group_ids(in_temp_group_ids) {} + +SwissTableWithKeys::Input::Input(const Input& base, int num_rows_to_skip, + int num_rows_to_include) + : batch(base.batch), + temp_stack(base.temp_stack), + temp_column_arrays(base.temp_column_arrays), + temp_group_ids(base.temp_group_ids) { + if (base.selection_maybe_null) { + batch_start_row = 0; + batch_end_row = static_cast(batch->length); + ARROW_DCHECK(num_rows_to_skip + num_rows_to_include <= base.num_selected); + num_selected = num_rows_to_include; + selection_maybe_null = base.selection_maybe_null + num_rows_to_skip; + } else { + ARROW_DCHECK(base.batch_start_row + num_rows_to_skip + num_rows_to_include <= + base.batch_end_row); + batch_start_row = base.batch_start_row + num_rows_to_skip; + batch_end_row = base.batch_start_row + num_rows_to_skip + num_rows_to_include; + num_selected = 0; + selection_maybe_null = nullptr; + } +} + +Status SwissTableWithKeys::Init(int64_t hardware_flags, MemoryPool* pool) { + InitCallbacks(); + return swiss_table_.init(hardware_flags, pool); +} + +void SwissTableWithKeys::EqualCallback(int num_keys, const uint16_t* selection_maybe_null, + const uint32_t* group_ids, + uint32_t* out_num_keys_mismatch, + uint16_t* out_selection_mismatch, + void* callback_ctx) { + if (num_keys 
== 0) { + *out_num_keys_mismatch = 0; + return; + } + + ARROW_DCHECK(num_keys <= swiss_table_.minibatch_size()); + + Input* in = reinterpret_cast(callback_ctx); + + int64_t hardware_flags = swiss_table_.hardware_flags(); + + int batch_start_to_use; + int batch_end_to_use; + const uint16_t* selection_to_use; + const uint32_t* group_ids_to_use; + + if (in->selection_maybe_null) { + auto selection_to_use_buf = + util::TempVectorHolder(in->temp_stack, num_keys); + ARROW_DCHECK(in->temp_group_ids); + in->temp_group_ids->resize(in->batch->length); + + if (selection_maybe_null) { + for (int i = 0; i < num_keys; ++i) { + uint16_t local_row_id = selection_maybe_null[i]; + uint16_t global_row_id = in->selection_maybe_null[local_row_id]; + selection_to_use_buf.mutable_data()[i] = global_row_id; + (*in->temp_group_ids)[global_row_id] = group_ids[local_row_id]; + } + selection_to_use = selection_to_use_buf.mutable_data(); + } else { + for (int i = 0; i < num_keys; ++i) { + uint16_t global_row_id = in->selection_maybe_null[i]; + (*in->temp_group_ids)[global_row_id] = group_ids[i]; + } + selection_to_use = in->selection_maybe_null; + } + batch_start_to_use = 0; + batch_end_to_use = static_cast(in->batch->length); + group_ids_to_use = in->temp_group_ids->data(); + + auto match_bitvector_buf = util::TempVectorHolder(in->temp_stack, num_keys); + uint8_t* match_bitvector = match_bitvector_buf.mutable_data(); + + keys_.Compare(*in->batch, batch_start_to_use, batch_end_to_use, num_keys, + selection_to_use, group_ids_to_use, nullptr, nullptr, hardware_flags, + in->temp_stack, *in->temp_column_arrays, match_bitvector); + + if (selection_maybe_null) { + int num_keys_mismatch = 0; + util::bit_util::bits_filter_indexes(0, hardware_flags, num_keys, match_bitvector, + selection_maybe_null, &num_keys_mismatch, + out_selection_mismatch); + *out_num_keys_mismatch = num_keys_mismatch; + } else { + int num_keys_mismatch = 0; + util::bit_util::bits_to_indexes(0, hardware_flags, num_keys, match_bitvector, + &num_keys_mismatch, out_selection_mismatch); + *out_num_keys_mismatch = num_keys_mismatch; + } + + } else { + batch_start_to_use = in->batch_start_row; + batch_end_to_use = in->batch_end_row; + selection_to_use = selection_maybe_null; + group_ids_to_use = group_ids; + keys_.Compare(*in->batch, batch_start_to_use, batch_end_to_use, num_keys, + selection_to_use, group_ids_to_use, out_num_keys_mismatch, + out_selection_mismatch, hardware_flags, in->temp_stack, + *in->temp_column_arrays); + } +} + +Status SwissTableWithKeys::AppendCallback(int num_keys, const uint16_t* selection, + void* callback_ctx) { + ARROW_DCHECK(num_keys <= swiss_table_.minibatch_size()); + ARROW_DCHECK(selection); + + Input* in = reinterpret_cast(callback_ctx); + + int batch_start_to_use; + int batch_end_to_use; + const uint16_t* selection_to_use; + + if (in->selection_maybe_null) { + auto selection_to_use_buf = + util::TempVectorHolder(in->temp_stack, num_keys); + for (int i = 0; i < num_keys; ++i) { + selection_to_use_buf.mutable_data()[i] = in->selection_maybe_null[selection[i]]; + } + batch_start_to_use = 0; + batch_end_to_use = static_cast(in->batch->length); + selection_to_use = selection_to_use_buf.mutable_data(); + + return keys_.AppendBatchSelection(swiss_table_.pool(), *in->batch, batch_start_to_use, + batch_end_to_use, num_keys, selection_to_use, + *in->temp_column_arrays); + } else { + batch_start_to_use = in->batch_start_row; + batch_end_to_use = in->batch_end_row; + selection_to_use = selection; + + return 
keys_.AppendBatchSelection(swiss_table_.pool(), *in->batch, batch_start_to_use, + batch_end_to_use, num_keys, selection_to_use, + *in->temp_column_arrays); + } +} + +void SwissTableWithKeys::InitCallbacks() { + equal_impl_ = [&](int num_keys, const uint16_t* selection_maybe_null, + const uint32_t* group_ids, uint32_t* out_num_keys_mismatch, + uint16_t* out_selection_mismatch, void* callback_ctx) { + EqualCallback(num_keys, selection_maybe_null, group_ids, out_num_keys_mismatch, + out_selection_mismatch, callback_ctx); + }; + append_impl_ = [&](int num_keys, const uint16_t* selection, void* callback_ctx) { + return AppendCallback(num_keys, selection, callback_ctx); + }; +} + +void SwissTableWithKeys::Hash(Input* input, uint32_t* hashes, int64_t hardware_flags) { + // Hashing does not support selection of rows + // + ARROW_DCHECK(input->selection_maybe_null == nullptr); + + Hashing32::HashBatch(*input->batch, input->batch_start_row, + input->batch_end_row - input->batch_start_row, hashes, + *input->temp_column_arrays, hardware_flags, input->temp_stack); +} + +void SwissTableWithKeys::MapReadOnly(Input* input, const uint32_t* hashes, + uint8_t* match_bitvector, uint32_t* key_ids) { + std::ignore = Map(input, /*insert_missing=*/false, hashes, match_bitvector, key_ids); +} + +Status SwissTableWithKeys::MapWithInserts(Input* input, const uint32_t* hashes, + uint32_t* key_ids) { + return Map(input, /*insert_missing=*/true, hashes, nullptr, key_ids); +} + +Status SwissTableWithKeys::Map(Input* input, bool insert_missing, const uint32_t* hashes, + uint8_t* match_bitvector_maybe_null, uint32_t* key_ids) { + util::TempVectorStack* temp_stack = input->temp_stack; + + // Split into smaller mini-batches + // + int minibatch_size = swiss_table_.minibatch_size(); + int num_rows_to_process = input->selection_maybe_null + ? input->num_selected + : input->batch_end_row - input->batch_start_row; + auto hashes_buf = util::TempVectorHolder(temp_stack, minibatch_size); + auto match_bitvector_buf = util::TempVectorHolder( + temp_stack, + static_cast(bit_util::BytesForBits(minibatch_size)) + sizeof(uint64_t)); + for (int minibatch_start = 0; minibatch_start < num_rows_to_process;) { + int minibatch_size_next = + std::min(minibatch_size, num_rows_to_process - minibatch_start); + + // Prepare updated input buffers that represent the current minibatch. + // + Input minibatch_input(*input, minibatch_start, minibatch_size_next); + uint8_t* minibatch_match_bitvector = + insert_missing ? match_bitvector_buf.mutable_data() + : match_bitvector_maybe_null + minibatch_start / 8; + const uint32_t* minibatch_hashes; + if (input->selection_maybe_null) { + minibatch_hashes = hashes_buf.mutable_data(); + for (int i = 0; i < minibatch_size_next; ++i) { + hashes_buf.mutable_data()[i] = hashes[minibatch_input.selection_maybe_null[i]]; + } + } else { + minibatch_hashes = hashes + minibatch_start; + } + uint32_t* minibatch_key_ids = key_ids + minibatch_start; + + // Lookup existing keys. + { + auto slots = util::TempVectorHolder(temp_stack, minibatch_size_next); + swiss_table_.early_filter(minibatch_size_next, minibatch_hashes, + minibatch_match_bitvector, slots.mutable_data()); + swiss_table_.find(minibatch_size_next, minibatch_hashes, minibatch_match_bitvector, + slots.mutable_data(), minibatch_key_ids, temp_stack, equal_impl_, + &minibatch_input); + } + + // Perform inserts of missing keys if required. 
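+    // After the lookup above, a zero bit in minibatch_match_bitvector marks a
+    // key that was not found. bits_to_indexes() called with bit value 0 below
+    // gathers the positions of those keys, and map_new_keys() appends them
+    // (through append_impl_) and fills in their newly assigned key ids.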
+ // + if (insert_missing) { + auto ids_buf = util::TempVectorHolder(temp_stack, minibatch_size_next); + int num_ids; + util::bit_util::bits_to_indexes(0, swiss_table_.hardware_flags(), + minibatch_size_next, minibatch_match_bitvector, + &num_ids, ids_buf.mutable_data()); + + RETURN_NOT_OK(swiss_table_.map_new_keys( + num_ids, ids_buf.mutable_data(), minibatch_hashes, minibatch_key_ids, + temp_stack, equal_impl_, append_impl_, &minibatch_input)); + } + + minibatch_start += minibatch_size_next; + } + + return Status::OK(); +} + +void SwissTableForJoin::Lookup( + const ExecBatch& batch, int start_row, int num_rows, uint8_t* out_has_match_bitvector, + uint32_t* out_key_ids, util::TempVectorStack* temp_stack, + std::vector* temp_column_arrays) { + SwissTableWithKeys::Input input(&batch, start_row, start_row + num_rows, temp_stack, + temp_column_arrays); + + // Split into smaller mini-batches + // + int minibatch_size = map_.swiss_table()->minibatch_size(); + auto hashes_buf = util::TempVectorHolder(temp_stack, minibatch_size); + for (int minibatch_start = 0; minibatch_start < num_rows;) { + uint32_t minibatch_size_next = std::min(minibatch_size, num_rows - minibatch_start); + + SwissTableWithKeys::Input minibatch_input(input, minibatch_start, + minibatch_size_next); + + SwissTableWithKeys::Hash(&minibatch_input, hashes_buf.mutable_data(), + map_.swiss_table()->hardware_flags()); + map_.MapReadOnly(&minibatch_input, hashes_buf.mutable_data(), + out_has_match_bitvector + minibatch_start / 8, + out_key_ids + minibatch_start); + + minibatch_start += minibatch_size_next; + } +} + +uint8_t* SwissTableForJoin::local_has_match(int64_t thread_id) { + int64_t num_rows_hash_table = num_rows(); + if (num_rows_hash_table == 0) { + return nullptr; + } + + ThreadLocalState& local_state = local_states_[thread_id]; + if (local_state.has_match.empty() && num_rows_hash_table > 0) { + local_state.has_match.resize(bit_util::BytesForBits(num_rows_hash_table) + + sizeof(uint64_t)); + memset(local_state.has_match.data(), 0, bit_util::BytesForBits(num_rows_hash_table)); + } + + return local_states_[thread_id].has_match.data(); +} + +void SwissTableForJoin::UpdateHasMatchForKeys(int64_t thread_id, int num_ids, + const uint32_t* key_ids) { + uint8_t* bit_vector = local_has_match(thread_id); + if (num_ids == 0 || !bit_vector) { + return; + } + for (int i = 0; i < num_ids; ++i) { + // Mark row in hash table as having a match + // + bit_util::SetBit(bit_vector, key_ids[i]); + } +} + +void SwissTableForJoin::MergeHasMatch() { + int64_t num_rows_hash_table = num_rows(); + if (num_rows_hash_table == 0) { + return; + } + + has_match_.resize(bit_util::BytesForBits(num_rows_hash_table) + sizeof(uint64_t)); + memset(has_match_.data(), 0, bit_util::BytesForBits(num_rows_hash_table)); + + for (size_t tid = 0; tid < local_states_.size(); ++tid) { + if (!local_states_[tid].has_match.empty()) { + arrow::internal::BitmapOr(has_match_.data(), 0, local_states_[tid].has_match.data(), + 0, num_rows_hash_table, 0, has_match_.data()); + } + } +} + +uint32_t SwissTableForJoin::payload_id_to_key_id(uint32_t payload_id) const { + if (no_duplicate_keys_) { + return payload_id; + } + int64_t num_entries = num_keys(); + const uint32_t* entries = key_to_payload(); + ARROW_DCHECK(entries); + ARROW_DCHECK(entries[num_entries] > payload_id); + const uint32_t* first_greater = + std::upper_bound(entries, entries + num_entries + 1, payload_id); + ARROW_DCHECK(first_greater > entries); + return static_cast(first_greater - entries) - 1; +} + +void 
SwissTableForJoin::payload_ids_to_key_ids(int num_rows, const uint32_t* payload_ids, + uint32_t* key_ids) const { + if (num_rows == 0) { + return; + } + if (no_duplicate_keys_) { + memcpy(key_ids, payload_ids, num_rows * sizeof(uint32_t)); + return; + } + + const uint32_t* entries = key_to_payload(); + uint32_t key_id = payload_id_to_key_id(payload_ids[0]); + key_ids[0] = key_id; + for (int i = 1; i < num_rows; ++i) { + ARROW_DCHECK(payload_ids[i] > payload_ids[i - 1]); + while (entries[key_id + 1] <= payload_ids[i]) { + ++key_id; + ARROW_DCHECK(key_id < num_keys()); + } + key_ids[i] = key_id; + } +} + +Status SwissTableForJoinBuild::Init( + SwissTableForJoin* target, int dop, int64_t num_rows, bool reject_duplicate_keys, + bool no_payload, const std::vector& key_types, + const std::vector& payload_types, MemoryPool* pool, + int64_t hardware_flags) { + target_ = target; + dop_ = dop; + num_rows_ = num_rows; + + // Make sure that we do not use many partitions if there are not enough rows. + // + constexpr int64_t min_num_rows_per_prtn = 1 << 18; + log_num_prtns_ = + std::min(bit_util::Log2(dop_), + bit_util::Log2(bit_util::CeilDiv(num_rows, min_num_rows_per_prtn))); + num_prtns_ = 1 << log_num_prtns_; + + reject_duplicate_keys_ = reject_duplicate_keys; + no_payload_ = no_payload; + pool_ = pool; + hardware_flags_ = hardware_flags; + + prtn_states_.resize(num_prtns_); + thread_states_.resize(dop_); + prtn_locks_.Init(num_prtns_); + + KeyEncoder::KeyRowMetadata key_row_metadata; + key_row_metadata.FromColumnMetadataVector(key_types, + /*row_alignment=*/sizeof(uint64_t), + /*string_alignment=*/sizeof(uint64_t)); + KeyEncoder::KeyRowMetadata payload_row_metadata; + payload_row_metadata.FromColumnMetadataVector(payload_types, + /*row_alignment=*/sizeof(uint64_t), + /*string_alignment=*/sizeof(uint64_t)); + + for (int i = 0; i < num_prtns_; ++i) { + PartitionState& prtn_state = prtn_states_[i]; + RETURN_NOT_OK(prtn_state.keys.Init(hardware_flags_, pool_)); + RETURN_NOT_OK(prtn_state.keys.keys()->InitIfNeeded(pool, key_row_metadata)); + RETURN_NOT_OK(prtn_state.payloads.InitIfNeeded(pool, payload_row_metadata)); + } + + target_->dop_ = dop_; + target_->local_states_.resize(dop_); + target_->no_payload_columns_ = no_payload; + target_->no_duplicate_keys_ = reject_duplicate_keys; + target_->map_.InitCallbacks(); + + return Status::OK(); +} + +Status SwissTableForJoinBuild::PushNextBatch(int64_t thread_id, + const ExecBatch& key_batch, + const ExecBatch* payload_batch_maybe_null, + util::TempVectorStack* temp_stack) { + ARROW_DCHECK(thread_id < dop_); + ThreadState& locals = thread_states_[thread_id]; + + // Compute hash + // + locals.batch_hashes.resize(key_batch.length); + Hashing32::HashBatch(key_batch, /*start_row=*/0, static_cast(key_batch.length), + locals.batch_hashes.data(), locals.temp_column_arrays, + hardware_flags_, temp_stack); + + // Partition on hash + // + locals.batch_prtn_row_ids.resize(locals.batch_hashes.size()); + locals.batch_prtn_ranges.resize(num_prtns_ + 1); + int num_rows = static_cast(locals.batch_hashes.size()); + if (num_prtns_ == 1) { + // We treat single partition case separately to avoid extra checks in row + // partitioning implementation for general case. 
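+    // In the general (multi-partition) path below, the partition id of a row
+    // is simply the top log_num_prtns_ bits of its 32-bit hash. Illustrative
+    // example (hypothetical sizes): with log_num_prtns_ == 2 the expression
+    // hash >> (31 - log_num_prtns_) >> 1 reduces to hash >> 30, i.e. one of
+    // the 4 partitions 0..3.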
+ // + locals.batch_prtn_ranges[0] = 0; + locals.batch_prtn_ranges[1] = num_rows; + for (int i = 0; i < num_rows; ++i) { + locals.batch_prtn_row_ids[i] = i; + } + } else { + PartitionSort::Eval( + static_cast(locals.batch_hashes.size()), num_prtns_, + locals.batch_prtn_ranges.data(), + [this, &locals](int i) { + // SwissTable uses the highest bits of the hash for block index. + // We want each partition to correspond to a range of block indices, + // so we also partition on the highest bits of the hash. + // + return locals.batch_hashes[i] >> (31 - log_num_prtns_) >> 1; + }, + [&locals](int i, int pos) { locals.batch_prtn_row_ids[pos] = i; }); + } + + // Update hashes, shifting left to get rid of the bits that were already used + // for partitioning. + // + for (size_t i = 0; i < locals.batch_hashes.size(); ++i) { + locals.batch_hashes[i] <<= log_num_prtns_; + } + + // For each partition: + // - map keys to unique integers using (this partition's) hash table + // - append payloads (if present) to (this partition's) row array + // + locals.temp_prtn_ids.resize(num_prtns_); + + RETURN_NOT_OK(prtn_locks_.ForEachPartition( + locals.temp_prtn_ids.data(), + /*is_prtn_empty_fn=*/ + [&](int prtn_id) { + return locals.batch_prtn_ranges[prtn_id + 1] == locals.batch_prtn_ranges[prtn_id]; + }, + /*process_prtn_fn=*/ + [&](int prtn_id) { + return ProcessPartition(thread_id, key_batch, payload_batch_maybe_null, + temp_stack, prtn_id); + })); + + return Status::OK(); +} + +Status SwissTableForJoinBuild::ProcessPartition(int64_t thread_id, + const ExecBatch& key_batch, + const ExecBatch* payload_batch_maybe_null, + util::TempVectorStack* temp_stack, + int prtn_id) { + ARROW_DCHECK(thread_id < dop_); + ThreadState& locals = thread_states_[thread_id]; + + int num_rows_new = + locals.batch_prtn_ranges[prtn_id + 1] - locals.batch_prtn_ranges[prtn_id]; + const uint16_t* row_ids = + locals.batch_prtn_row_ids.data() + locals.batch_prtn_ranges[prtn_id]; + PartitionState& prtn_state = prtn_states_[prtn_id]; + size_t num_rows_before = prtn_state.key_ids.size(); + // Insert new keys into hash table associated with the current partition + // and map existing keys to integer ids. + // + prtn_state.key_ids.resize(num_rows_before + num_rows_new); + SwissTableWithKeys::Input input(&key_batch, num_rows_new, row_ids, temp_stack, + &locals.temp_column_arrays, &locals.temp_group_ids); + RETURN_NOT_OK(prtn_state.keys.MapWithInserts( + &input, locals.batch_hashes.data(), prtn_state.key_ids.data() + num_rows_before)); + // Append input batch rows from current partition to an array of payload + // rows for this partition. + // + // The order of payloads is the same as the order of key ids accumulated + // in a vector (we will use the vector of key ids later on to sort + // payload on key ids before merging into the final row array). + // + if (!no_payload_) { + ARROW_DCHECK(payload_batch_maybe_null); + RETURN_NOT_OK(prtn_state.payloads.AppendBatchSelection( + pool_, *payload_batch_maybe_null, 0, + static_cast(payload_batch_maybe_null->length), num_rows_new, row_ids, + locals.temp_column_arrays)); + } + // We do not need to keep track of key ids if we reject rows with + // duplicate keys. + // + if (reject_duplicate_keys_) { + prtn_state.key_ids.clear(); + } + return Status::OK(); +} + +Status SwissTableForJoinBuild::PreparePrtnMerge() { + // There are 4 data structures that require partition merging: + // 1. array of key rows + // 2. SwissTable + // 3. array of payload rows (only when no_payload_ is false) + // 4. 
mapping from key id to first payload id (only when + // reject_duplicate_keys_ is false and there are duplicate keys) + // + + // 1. Array of key rows: + // + std::vector partition_keys; + partition_keys.resize(num_prtns_); + for (int i = 0; i < num_prtns_; ++i) { + partition_keys[i] = prtn_states_[i].keys.keys(); + } + RETURN_NOT_OK(RowArrayMerge::PrepareForMerge(target_->map_.keys(), partition_keys, + &partition_keys_first_row_id_, pool_)); + + // 2. SwissTable: + // + std::vector partition_tables; + partition_tables.resize(num_prtns_); + for (int i = 0; i < num_prtns_; ++i) { + partition_tables[i] = prtn_states_[i].keys.swiss_table(); + } + std::vector partition_first_group_id; + RETURN_NOT_OK(SwissTableMerge::PrepareForMerge( + target_->map_.swiss_table(), partition_tables, &partition_first_group_id, pool_)); + + // 3. Array of payload rows: + // + if (!no_payload_) { + std::vector partition_payloads; + partition_payloads.resize(num_prtns_); + for (int i = 0; i < num_prtns_; ++i) { + partition_payloads[i] = &prtn_states_[i].payloads; + } + RETURN_NOT_OK(RowArrayMerge::PrepareForMerge(&target_->payloads_, partition_payloads, + &partition_payloads_first_row_id_, + pool_)); + } + + // Check if we have duplicate keys + // + int64_t num_keys = partition_keys_first_row_id_[num_prtns_]; + int64_t num_rows = 0; + for (int i = 0; i < num_prtns_; ++i) { + num_rows += static_cast(prtn_states_[i].key_ids.size()); + } + bool no_duplicate_keys = reject_duplicate_keys_ || num_keys == num_rows; + + // 4. Mapping from key id to first payload id: + // + target_->no_duplicate_keys_ = no_duplicate_keys; + if (!no_duplicate_keys) { + target_->row_offset_for_key_.resize(num_keys + 1); + int64_t num_rows = 0; + for (int i = 0; i < num_prtns_; ++i) { + int64_t first_key = partition_keys_first_row_id_[i]; + target_->row_offset_for_key_[first_key] = static_cast(num_rows); + num_rows += static_cast(prtn_states_[i].key_ids.size()); + } + target_->row_offset_for_key_[num_keys] = static_cast(num_rows); + } + + return Status::OK(); +} + +void SwissTableForJoinBuild::PrtnMerge(int prtn_id) { + PartitionState& prtn_state = prtn_states_[prtn_id]; + + // There are 4 data structures that require partition merging: + // 1. array of key rows + // 2. SwissTable + // 3. mapping from key id to first payload id (only when + // reject_duplicate_keys_ is false and there are duplicate keys) + // 4. array of payload rows (only when no_payload_ is false) + // + + // 1. Array of key rows: + // + RowArrayMerge::MergeSingle(target_->map_.keys(), *prtn_state.keys.keys(), + partition_keys_first_row_id_[prtn_id], + /*source_rows_permutation=*/nullptr); + + // 2. SwissTable: + // + SwissTableMerge::MergePartition( + target_->map_.swiss_table(), prtn_state.keys.swiss_table(), prtn_id, log_num_prtns_, + static_cast(partition_keys_first_row_id_[prtn_id]), + &prtn_state.overflow_key_ids, &prtn_state.overflow_hashes); + + std::vector source_payload_ids; + + // 3. mapping from key id to first payload id + // + if (!target_->no_duplicate_keys_) { + // Count for each local (within partition) key id how many times it appears + // in input rows. + // + // For convenience, we use an array in merged hash table mapping key ids to + // first payload ids to collect the counters. 
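+    // Illustrative example (hypothetical input): with local key_ids
+    // {1, 0, 1, 2} the counters become {1, 2, 1}, their inclusive prefix sum
+    // is {1, 3, 4}, and the placement loop below yields the count-sorted
+    // permutation source_payload_ids == {1, 2, 0, 3}, i.e. payload rows
+    // grouped by key id; the counters end up as the exclusive prefix sum
+    // {0, 1, 3} before first_payload is added back.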
+ // + int64_t first_key = partition_keys_first_row_id_[prtn_id]; + int64_t num_keys = partition_keys_first_row_id_[prtn_id + 1] - first_key; + uint32_t* counters = target_->row_offset_for_key_.data() + first_key; + uint32_t first_payload = counters[0]; + for (int64_t i = 0; i < num_keys; ++i) { + counters[i] = 0; + } + for (size_t i = 0; i < prtn_state.key_ids.size(); ++i) { + uint32_t key_id = prtn_state.key_ids[i]; + ++counters[key_id]; + } + + if (!no_payload_) { + // Count sort payloads on key id + // + // Start by computing inclusive cummulative sum of counters. + // + uint32_t sum = 0; + for (int64_t i = 0; i < num_keys; ++i) { + sum += counters[i]; + counters[i] = sum; + } + // Now use cummulative sum of counters to obtain the target position in + // the sorted order for each row. At the end of this process the counters + // will contain exclusive cummulative sum (instead of inclusive that is + // there at the beginning). + // + source_payload_ids.resize(prtn_state.key_ids.size()); + for (size_t i = 0; i < prtn_state.key_ids.size(); ++i) { + uint32_t key_id = prtn_state.key_ids[i]; + int64_t position = --counters[key_id]; + source_payload_ids[position] = static_cast(i); + } + // Add base payload id to all of the counters. + // + for (int64_t i = 0; i < num_keys; ++i) { + counters[i] += first_payload; + } + } else { + // When there is no payload to process, we just need to compute exclusive + // cummulative sum of counters and add the base payload id to all of them. + // + uint32_t sum = 0; + for (int64_t i = 0; i < num_keys; ++i) { + uint32_t sum_next = sum + counters[i]; + counters[i] = sum + first_payload; + sum = sum_next; + } + } + } + + // 4. Array of payload rows: + // + if (!no_payload_) { + // If there are duplicate keys, then we have already initialized permutation + // of payloads for this partition. + // + if (target_->no_duplicate_keys_) { + source_payload_ids.resize(prtn_state.key_ids.size()); + for (size_t i = 0; i < prtn_state.key_ids.size(); ++i) { + uint32_t key_id = prtn_state.key_ids[i]; + source_payload_ids[key_id] = static_cast(i); + } + } + // Merge partition payloads into target array using the permutation. + // + RowArrayMerge::MergeSingle(&target_->payloads_, prtn_state.payloads, + partition_payloads_first_row_id_[prtn_id], + source_payload_ids.data()); + + // TODO: Uncomment for debugging + // prtn_state.payloads.DebugPrintToFile("payload_local.txt", false); + } +} + +void SwissTableForJoinBuild::FinishPrtnMerge(util::TempVectorStack* temp_stack) { + // Process overflow key ids + // + for (int prtn_id = 0; prtn_id < num_prtns_; ++prtn_id) { + SwissTableMerge::InsertNewGroups(target_->map_.swiss_table(), + prtn_states_[prtn_id].overflow_key_ids, + prtn_states_[prtn_id].overflow_hashes); + } + + // Calculate whether we have nulls in hash table keys + // (it is lazily evaluated but since we will be accessing it from multiple + // threads we need to make sure that the value gets calculated here). 
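+  // The returned value is discarded on purpose (std::ignore below): the call
+  // is made only for its side effect of computing and caching the lazily
+  // evaluated null summary, so that it is already materialized before
+  // multiple threads start reading it.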
+ // + KeyEncoder::KeyEncoderContext ctx; + ctx.hardware_flags = hardware_flags_; + ctx.stack = temp_stack; + std::ignore = target_->map_.keys()->rows_.has_any_nulls(&ctx); +} + +void JoinResultMaterialize::Init(MemoryPool* pool, + const HashJoinProjectionMaps* probe_schemas, + const HashJoinProjectionMaps* build_schemas) { + pool_ = pool; + probe_schemas_ = probe_schemas; + build_schemas_ = build_schemas; + num_rows_ = 0; + null_ranges_.clear(); + num_produced_batches_ = 0; + + // Initialize mapping of columns from output batch column index to key and + // payload batch column index. + // + probe_output_to_key_and_payload_.resize( + probe_schemas_->num_cols(HashJoinProjection::OUTPUT)); + int num_key_cols = probe_schemas_->num_cols(HashJoinProjection::KEY); + auto to_key = probe_schemas_->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY); + auto to_payload = + probe_schemas_->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD); + for (int i = 0; static_cast(i) < probe_output_to_key_and_payload_.size(); ++i) { + probe_output_to_key_and_payload_[i] = + to_key.get(i) == SchemaProjectionMap::kMissingField + ? to_payload.get(i) + num_key_cols + : to_key.get(i); + } +} + +void JoinResultMaterialize::SetBuildSide(const RowArray* build_keys, + const RowArray* build_payloads, + bool payload_id_same_as_key_id) { + build_keys_ = build_keys; + build_payloads_ = build_payloads; + payload_id_same_as_key_id_ = payload_id_same_as_key_id; +} + +bool JoinResultMaterialize::HasProbeOutput() const { + return probe_schemas_->num_cols(HashJoinProjection::OUTPUT) > 0; +} + +bool JoinResultMaterialize::HasBuildKeyOutput() const { + auto to_key = build_schemas_->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY); + for (int i = 0; i < build_schemas_->num_cols(HashJoinProjection::OUTPUT); ++i) { + if (to_key.get(i) != SchemaProjectionMap::kMissingField) { + return true; + } + } + return false; +} + +bool JoinResultMaterialize::HasBuildPayloadOutput() const { + auto to_payload = + build_schemas_->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD); + for (int i = 0; i < build_schemas_->num_cols(HashJoinProjection::OUTPUT); ++i) { + if (to_payload.get(i) != SchemaProjectionMap::kMissingField) { + return true; + } + } + return false; +} + +bool JoinResultMaterialize::NeedsKeyId() const { + return HasBuildKeyOutput() || (HasBuildPayloadOutput() && payload_id_same_as_key_id_); +} + +bool JoinResultMaterialize::NeedsPayloadId() const { + return HasBuildPayloadOutput() && !payload_id_same_as_key_id_; +} + +Status JoinResultMaterialize::AppendProbeOnly(const ExecBatch& key_and_payload, + int num_rows_to_append, + const uint16_t* row_ids, + int* num_rows_appended) { + num_rows_to_append = + std::min(ExecBatchBuilder::num_rows_max() - num_rows_, num_rows_to_append); + if (HasProbeOutput()) { + RETURN_NOT_OK(batch_builder_.AppendSelected( + pool_, key_and_payload, num_rows_to_append, row_ids, + static_cast(probe_output_to_key_and_payload_.size()), + probe_output_to_key_and_payload_.data())); + } + if (!null_ranges_.empty() && + null_ranges_.back().first + null_ranges_.back().second == num_rows_) { + // We can extend the last range of null rows on build side. 
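+    // Illustrative example (hypothetical values): null_ranges_ stores
+    // (first row, row count) pairs of output rows whose build-side columns
+    // are null. With num_rows_ == 100 and a trailing range of (90, 10),
+    // appending 5 more probe-only rows extends it to (90, 15); if the last
+    // range did not end exactly at row 100, a new range (100, 5) would be
+    // pushed instead.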
+ // + null_ranges_.back().second += num_rows_to_append; + } else { + null_ranges_.push_back( + std::make_pair(static_cast(num_rows_), num_rows_to_append)); + } + num_rows_ += num_rows_to_append; + *num_rows_appended = num_rows_to_append; + return Status::OK(); +} + +Status JoinResultMaterialize::AppendBuildOnly(int num_rows_to_append, + const uint32_t* key_ids, + const uint32_t* payload_ids, + int* num_rows_appended) { + num_rows_to_append = + std::min(ExecBatchBuilder::num_rows_max() - num_rows_, num_rows_to_append); + if (HasProbeOutput()) { + RETURN_NOT_OK(batch_builder_.AppendNulls( + pool_, probe_schemas_->data_types(HashJoinProjection::OUTPUT), + num_rows_to_append)); + } + if (NeedsKeyId()) { + ARROW_DCHECK(key_ids != nullptr); + key_ids_.resize(num_rows_ + num_rows_to_append); + memcpy(key_ids_.data() + num_rows_, key_ids, num_rows_to_append * sizeof(uint32_t)); + } + if (NeedsPayloadId()) { + ARROW_DCHECK(payload_ids != nullptr); + payload_ids_.resize(num_rows_ + num_rows_to_append); + memcpy(payload_ids_.data() + num_rows_, payload_ids, + num_rows_to_append * sizeof(uint32_t)); + } + num_rows_ += num_rows_to_append; + *num_rows_appended = num_rows_to_append; + return Status::OK(); +} + +Status JoinResultMaterialize::Append(const ExecBatch& key_and_payload, + int num_rows_to_append, const uint16_t* row_ids, + const uint32_t* key_ids, const uint32_t* payload_ids, + int* num_rows_appended) { + num_rows_to_append = + std::min(ExecBatchBuilder::num_rows_max() - num_rows_, num_rows_to_append); + if (HasProbeOutput()) { + RETURN_NOT_OK(batch_builder_.AppendSelected( + pool_, key_and_payload, num_rows_to_append, row_ids, + static_cast(probe_output_to_key_and_payload_.size()), + probe_output_to_key_and_payload_.data())); + } + if (NeedsKeyId()) { + ARROW_DCHECK(key_ids != nullptr); + key_ids_.resize(num_rows_ + num_rows_to_append); + memcpy(key_ids_.data() + num_rows_, key_ids, num_rows_to_append * sizeof(uint32_t)); + } + if (NeedsPayloadId()) { + ARROW_DCHECK(payload_ids != nullptr); + payload_ids_.resize(num_rows_ + num_rows_to_append); + memcpy(payload_ids_.data() + num_rows_, payload_ids, + num_rows_to_append * sizeof(uint32_t)); + } + num_rows_ += num_rows_to_append; + *num_rows_appended = num_rows_to_append; + return Status::OK(); +} + +Result> JoinResultMaterialize::FlushBuildColumn( + const std::shared_ptr& data_type, const RowArray* row_array, int column_id, + uint32_t* row_ids) { + ResizableArrayData output; + output.Init(data_type, pool_, bit_util::Log2(num_rows_)); + + for (size_t i = 0; i <= null_ranges_.size(); ++i) { + int row_id_begin = + i == 0 ? 0 : null_ranges_[i - 1].first + null_ranges_[i - 1].second; + int row_id_end = i == null_ranges_.size() ? num_rows_ : null_ranges_[i].first; + if (row_id_end > row_id_begin) { + RETURN_NOT_OK(row_array->DecodeSelected( + &output, column_id, row_id_end - row_id_begin, row_ids + row_id_begin, pool_)); + } + int num_nulls = i == null_ranges_.size() ? 
0 : null_ranges_[i].second; + if (num_nulls > 0) { + RETURN_NOT_OK(ExecBatchBuilder::AppendNulls(data_type, output, num_nulls, pool_)); + } + } + + return output.array_data(); +} + +Status JoinResultMaterialize::Flush(ExecBatch* out) { + ARROW_DCHECK(num_rows_ > 0); + out->length = num_rows_; + out->values.clear(); + + int num_probe_cols = probe_schemas_->num_cols(HashJoinProjection::OUTPUT); + int num_build_cols = build_schemas_->num_cols(HashJoinProjection::OUTPUT); + out->values.resize(num_probe_cols + num_build_cols); + + if (HasProbeOutput()) { + ExecBatch probe_batch = batch_builder_.Flush(); + ARROW_DCHECK(static_cast(probe_batch.values.size()) == num_probe_cols); + for (size_t i = 0; i < probe_batch.values.size(); ++i) { + out->values[i] = std::move(probe_batch.values[i]); + } + } + auto to_key = build_schemas_->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY); + auto to_payload = + build_schemas_->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD); + for (int i = 0; i < num_build_cols; ++i) { + if (to_key.get(i) != SchemaProjectionMap::kMissingField) { + std::shared_ptr column; + ARROW_ASSIGN_OR_RAISE( + column, + FlushBuildColumn(build_schemas_->data_type(HashJoinProjection::OUTPUT, i), + build_keys_, to_key.get(i), key_ids_.data())); + out->values[num_probe_cols + i] = std::move(column); + } else if (to_payload.get(i) != SchemaProjectionMap::kMissingField) { + std::shared_ptr column; + ARROW_ASSIGN_OR_RAISE( + column, + FlushBuildColumn( + build_schemas_->data_type(HashJoinProjection::OUTPUT, i), build_payloads_, + to_payload.get(i), + payload_id_same_as_key_id_ ? key_ids_.data() : payload_ids_.data())); + out->values[num_probe_cols + i] = std::move(column); + } else { + ARROW_DCHECK(false); + } + } + + num_rows_ = 0; + key_ids_.clear(); + payload_ids_.clear(); + null_ranges_.clear(); + + ++num_produced_batches_; + + return Status::OK(); +} + +void JoinNullFilter::Filter(const ExecBatch& key_batch, int batch_start_row, + int num_batch_rows, const std::vector& cmp, + bool* all_valid, bool and_with_input, + uint8_t* inout_bit_vector) { + // AND together validity vectors for columns that use equality comparison. 
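+  // Only key columns compared with JoinKeyCmp::EQ take part here: for them a
+  // null key can never produce a match, so their validity bitmaps are ANDed
+  // into inout_bit_vector. Columns using the null-safe comparison and columns
+  // without a validity buffer (no nulls) are skipped. Illustrative example
+  // (hypothetical input): with cmp == {EQ, IS}, only the validity bitmap of
+  // key column 0 is applied.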
+ // + bool is_output_initialized = and_with_input; + for (size_t i = 0; i < cmp.size(); ++i) { + // No null filtering if null == null is true + // + if (cmp[i] != JoinKeyCmp::EQ) { + continue; + } + + // No null filtering when there are no nulls + // + const Datum& data = key_batch.values[i]; + ARROW_DCHECK(data.is_array()); + const std::shared_ptr& array_data = data.array(); + if (!array_data->buffers[0]) { + continue; + } + + const uint8_t* non_null_buffer = array_data->buffers[0]->data(); + int64_t offset = array_data->offset + batch_start_row; + + // Filter out nulls for this column + // + if (!is_output_initialized) { + memset(inout_bit_vector, 0xff, bit_util::BytesForBits(num_batch_rows)); + is_output_initialized = true; + } + arrow::internal::BitmapAnd(inout_bit_vector, 0, non_null_buffer, offset, + num_batch_rows, 0, inout_bit_vector); + } + *all_valid = !is_output_initialized; +} + +void JoinMatchIterator::SetLookupResult(int num_batch_rows, int start_batch_row, + const uint8_t* batch_has_match, + const uint32_t* key_ids, bool no_duplicate_keys, + const uint32_t* key_to_payload) { + num_batch_rows_ = num_batch_rows; + start_batch_row_ = start_batch_row; + batch_has_match_ = batch_has_match; + key_ids_ = key_ids; + + no_duplicate_keys_ = no_duplicate_keys; + key_to_payload_ = key_to_payload; + + current_row_ = 0; + current_match_for_row_ = 0; +} + +bool JoinMatchIterator::GetNextBatch(int num_rows_max, int* out_num_rows, + uint16_t* batch_row_ids, uint32_t* key_ids, + uint32_t* payload_ids) { + *out_num_rows = 0; + + if (no_duplicate_keys_) { + // When every input key can have at most one match, + // then we only need to filter according to has match bit vector. + // + // We stop when either we produce a full batch or when we reach the end of + // matches to output. + // + while (current_row_ < num_batch_rows_ && *out_num_rows < num_rows_max) { + batch_row_ids[*out_num_rows] = start_batch_row_ + current_row_; + key_ids[*out_num_rows] = payload_ids[*out_num_rows] = key_ids_[current_row_]; + (*out_num_rows) += bit_util::GetBit(batch_has_match_, current_row_) ? 1 : 0; + ++current_row_; + } + } else { + // When every input key can have zero, one or many matches, + // then we need to filter out ones with no match and + // iterate over all matches for the remaining ones. + // + // We stop when either we produce a full batch or when we reach the end of + // matches to output. 
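+    // The matches of key id k are the payload ids in the range
+    // [key_to_payload_[k], key_to_payload_[k + 1]). Illustrative example
+    // (hypothetical ids): key id 7 with key_to_payload_[7] == 20 and
+    // key_to_payload_[8] == 23 has three matching build rows, 20..22; if only
+    // one output slot is left, payload 20 is emitted now and
+    // current_match_for_row_ records where to resume on the next call.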
+ // + while (current_row_ < num_batch_rows_ && *out_num_rows < num_rows_max) { + if (!bit_util::GetBit(batch_has_match_, current_row_)) { + ++current_row_; + current_match_for_row_ = 0; + continue; + } + uint32_t base_payload_id = key_to_payload_[key_ids_[current_row_]]; + + // Total number of matches for the currently selected input row + // + int num_matches_total = + key_to_payload_[key_ids_[current_row_] + 1] - base_payload_id; + + // Number of remaining matches for the currently selected input row + // + int num_matches_left = num_matches_total - current_match_for_row_; + + // Number of matches for the currently selected input row that will fit + // into the next batch + // + int num_matches_next = std::min(num_matches_left, num_rows_max - *out_num_rows); + + for (int imatch = 0; imatch < num_matches_next; ++imatch) { + batch_row_ids[*out_num_rows] = start_batch_row_ + current_row_; + key_ids[*out_num_rows] = key_ids_[current_row_]; + payload_ids[*out_num_rows] = base_payload_id + current_match_for_row_ + imatch; + ++(*out_num_rows); + } + current_match_for_row_ += num_matches_next; + + if (current_match_for_row_ == num_matches_total) { + ++current_row_; + current_match_for_row_ = 0; + } + } + } + + return (*out_num_rows) > 0; +} + +void JoinProbeProcessor::Init(int num_key_columns, JoinType join_type, + SwissTableForJoin* hash_table, + std::vector materialize, + const std::vector* cmp, + OutputBatchFn output_batch_fn) { + num_key_columns_ = num_key_columns; + join_type_ = join_type; + hash_table_ = hash_table; + materialize_.resize(materialize.size()); + for (size_t i = 0; i < materialize.size(); ++i) { + materialize_[i] = materialize[i]; + } + cmp_ = cmp; + output_batch_fn_ = output_batch_fn; +} + +Status JoinProbeProcessor::OnNextBatch( + int64_t thread_id, const ExecBatch& keypayload_batch, + util::TempVectorStack* temp_stack, + std::vector* temp_column_arrays) { + const SwissTable* swiss_table = hash_table_->keys()->swiss_table(); + int64_t hardware_flags = swiss_table->hardware_flags(); + int minibatch_size = swiss_table->minibatch_size(); + int num_rows = static_cast(keypayload_batch.length); + + ExecBatch key_batch({}, keypayload_batch.length); + key_batch.values.resize(num_key_columns_); + for (int i = 0; i < num_key_columns_; ++i) { + key_batch.values[i] = keypayload_batch.values[i]; + } + + // Break into mini-batches + // + // Start by allocating mini-batch buffers + // + auto hashes_buf = util::TempVectorHolder(temp_stack, minibatch_size); + auto match_bitvector_buf = util::TempVectorHolder( + temp_stack, static_cast(bit_util::BytesForBits(minibatch_size))); + auto key_ids_buf = util::TempVectorHolder(temp_stack, minibatch_size); + auto materialize_batch_ids_buf = + util::TempVectorHolder(temp_stack, minibatch_size); + auto materialize_key_ids_buf = + util::TempVectorHolder(temp_stack, minibatch_size); + auto materialize_payload_ids_buf = + util::TempVectorHolder(temp_stack, minibatch_size); + + for (int minibatch_start = 0; minibatch_start < num_rows;) { + uint32_t minibatch_size_next = std::min(minibatch_size, num_rows - minibatch_start); + + SwissTableWithKeys::Input input(&key_batch, minibatch_start, + minibatch_start + minibatch_size_next, temp_stack, + temp_column_arrays); + hash_table_->keys()->Hash(&input, hashes_buf.mutable_data(), hardware_flags); + hash_table_->keys()->MapReadOnly(&input, hashes_buf.mutable_data(), + match_bitvector_buf.mutable_data(), + key_ids_buf.mutable_data()); + + // AND bit vector with null key filter for join + // + bool ignored; + 
JoinNullFilter::Filter(key_batch, minibatch_start, minibatch_size_next, *cmp_, + &ignored, + /*and_with_input=*/true, match_bitvector_buf.mutable_data()); + // Semi-joins + // + if (join_type_ == JoinType::LEFT_SEMI || join_type_ == JoinType::LEFT_ANTI || + join_type_ == JoinType::RIGHT_SEMI || join_type_ == JoinType::RIGHT_ANTI) { + int num_passing_ids = 0; + util::bit_util::bits_to_indexes( + (join_type_ == JoinType::LEFT_ANTI) ? 0 : 1, hardware_flags, + minibatch_size_next, match_bitvector_buf.mutable_data(), &num_passing_ids, + materialize_batch_ids_buf.mutable_data()); + + // For right-semi, right-anti joins: update has-match flags for the rows + // in hash table. + // + if (join_type_ == JoinType::RIGHT_SEMI || join_type_ == JoinType::RIGHT_ANTI) { + for (int i = 0; i < num_passing_ids; ++i) { + uint16_t id = materialize_batch_ids_buf.mutable_data()[i]; + key_ids_buf.mutable_data()[i] = key_ids_buf.mutable_data()[id]; + } + hash_table_->UpdateHasMatchForKeys(thread_id, num_passing_ids, + key_ids_buf.mutable_data()); + } else { + // For left-semi, left-anti joins: call materialize using match + // bit-vector. + // + + // Add base batch row index. + // + for (int i = 0; i < num_passing_ids; ++i) { + materialize_batch_ids_buf.mutable_data()[i] += + static_cast(minibatch_start); + } + + RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly( + keypayload_batch, num_passing_ids, materialize_batch_ids_buf.mutable_data(), + [&](ExecBatch batch) { output_batch_fn_(thread_id, std::move(batch)); })); + } + } else { + // We need to output matching pairs of rows from both sides of the join. + // Since every hash table lookup for an input row might have multiple + // matches we use a helper class that implements enumerating all of them. + // + bool no_duplicate_keys = (hash_table_->key_to_payload() == nullptr); + bool no_payload_columns = (hash_table_->payloads() == nullptr); + JoinMatchIterator match_iterator; + match_iterator.SetLookupResult( + minibatch_size_next, minibatch_start, match_bitvector_buf.mutable_data(), + key_ids_buf.mutable_data(), no_duplicate_keys, hash_table_->key_to_payload()); + int num_matches_next; + while (match_iterator.GetNextBatch(minibatch_size, &num_matches_next, + materialize_batch_ids_buf.mutable_data(), + materialize_key_ids_buf.mutable_data(), + materialize_payload_ids_buf.mutable_data())) { + const uint16_t* materialize_batch_ids = materialize_batch_ids_buf.mutable_data(); + const uint32_t* materialize_key_ids = materialize_key_ids_buf.mutable_data(); + const uint32_t* materialize_payload_ids = + no_duplicate_keys || no_payload_columns + ? materialize_key_ids_buf.mutable_data() + : materialize_payload_ids_buf.mutable_data(); + + // For right-outer, full-outer joins we need to update has-match flags + // for the rows in hash table. + // + if (join_type_ == JoinType::RIGHT_OUTER || join_type_ == JoinType::FULL_OUTER) { + hash_table_->UpdateHasMatchForKeys(thread_id, num_matches_next, + materialize_key_ids); + } + + // Call materialize for resulting id tuples pointing to matching pairs + // of rows. + // + RETURN_NOT_OK(materialize_[thread_id]->Append( + keypayload_batch, num_matches_next, materialize_batch_ids, + materialize_key_ids, materialize_payload_ids, + [&](ExecBatch batch) { output_batch_fn_(thread_id, std::move(batch)); })); + } + + // For left-outer and full-outer joins output non-matches. + // + // Call materialize. Nulls will be output in all columns that come from + // the other side of the join. 
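+      // Illustrative sketch (standard C++, hypothetical bit vector): the block
+      // below collects the positions of the cleared bits of the match bit
+      // vector, i.e. the probe rows that found no match and must be output
+      // padded with nulls on the build side.
+      //
+      //   uint8_t match_bits = 0x2D;        // 0b00101101 for an 8-row mini-batch
+      //   std::vector<uint16_t> unmatched;
+      //   for (int row = 0; row < 8; ++row) {
+      //     if (((match_bits >> row) & 1) == 0) unmatched.push_back(row);
+      //   }
+      //   // unmatched == {1, 4, 6, 7}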
+ // + if (join_type_ == JoinType::LEFT_OUTER || join_type_ == JoinType::FULL_OUTER) { + int num_passing_ids = 0; + util::bit_util::bits_to_indexes( + /*bit_to_search=*/0, hardware_flags, minibatch_size_next, + match_bitvector_buf.mutable_data(), &num_passing_ids, + materialize_batch_ids_buf.mutable_data()); + + // Add base batch row index. + // + for (int i = 0; i < num_passing_ids; ++i) { + materialize_batch_ids_buf.mutable_data()[i] += + static_cast(minibatch_start); + } + + RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly( + keypayload_batch, num_passing_ids, materialize_batch_ids_buf.mutable_data(), + [&](ExecBatch batch) { output_batch_fn_(thread_id, std::move(batch)); })); + } + } + + minibatch_start += minibatch_size_next; + } + + return Status::OK(); +} + +Status JoinProbeProcessor::OnFinished() { + // Flush all instances of materialize that have non-zero accumulated output + // rows. + // + for (size_t i = 0; i < materialize_.size(); ++i) { + JoinResultMaterialize& materialize = *materialize_[i]; + if (materialize.num_rows() > 0) { + RETURN_NOT_OK(materialize.Flush( + [&](ExecBatch batch) { output_batch_fn_(i, std::move(batch)); })); + } + } + + return Status::OK(); +} + +class ExecBatchQueue { + public: + void Init(int dop); + bool Append(int64_t thread_id, ExecBatch* batch); + void CloseAll(); + int64_t num_shared_rows() const { return num_shared_rows_; } + int64_t num_shared_batches() const { return num_shared_batches_; } + ExecBatch* shared_batch(int64_t batch_id); + + private: + struct ThreadLocalState { + ThreadLocalState() {} + ThreadLocalState(const ThreadLocalState&) {} + std::mutex mutex_; + // Protected by mutex: + bool is_closed_; + std::vector queue_; + int64_t num_rows_; + // Not protected by mutex: + std::vector queue_shared_; + }; + std::vector thread_states_; + int64_t num_shared_rows_; + int64_t num_shared_batches_; +}; + +void ExecBatchQueue::Init(int dop) { + num_shared_rows_ = 0; + num_shared_batches_ = 0; + thread_states_.resize(dop); + for (int i = 0; i < dop; ++i) { + std::lock_guard lock(thread_states_[i].mutex_); + thread_states_[i].is_closed_ = false; + thread_states_[i].num_rows_ = 0; + } +} + +bool ExecBatchQueue::Append(int64_t thread_id, ExecBatch* batch) { + int64_t num_rows = batch->length; + std::lock_guard lock(thread_states_[thread_id].mutex_); + if (thread_states_[thread_id].is_closed_) { + return false; + } + thread_states_[thread_id].queue_.push_back(*batch); + thread_states_[thread_id].num_rows_ += num_rows; + return true; +} + +void ExecBatchQueue::CloseAll() { + for (size_t i = 0; i < thread_states_.size(); ++i) { + std::vector queue_copy; + int64_t num_rows; + { + std::lock_guard lock(thread_states_[i].mutex_); + if (thread_states_[i].is_closed_) { + continue; + } + thread_states_[i].is_closed_ = true; + queue_copy = std::move(thread_states_[i].queue_); + num_rows = thread_states_[i].num_rows_; + } + num_shared_batches_ += queue_copy.size(); + num_shared_rows_ += num_rows; + thread_states_[i].queue_shared_ = std::move(queue_copy); + } +} + +ExecBatch* ExecBatchQueue::shared_batch(int64_t batch_id) { + for (size_t i = 0; i < thread_states_.size(); ++i) { + int64_t queue_size = static_cast(thread_states_[i].queue_shared_.size()); + if (batch_id < queue_size) { + return &(thread_states_[i].queue_shared_[batch_id]); + } + batch_id -= queue_size; + } + // We should never get here + // + ARROW_DCHECK(false); + return nullptr; +} + +class SwissJoin : public HashJoinImpl { + public: + Status Init(ExecContext* ctx, JoinType join_type, bool 
use_sync_execution, + size_t num_threads, const HashJoinProjectionMaps* proj_map_left, + const HashJoinProjectionMaps* proj_map_right, + std::vector key_cmp, Expression filter, + OutputBatchCallback output_batch_callback, + FinishedCallback finished_callback, + TaskScheduler::ScheduleImpl schedule_task_callback) override { + num_threads_ = static_cast(std::max(num_threads, static_cast(1))); + + START_SPAN(span_, "HashJoinBasicImpl", + {{"detail", filter.ToString()}, + {"join.kind", ToString(join_type)}, + {"join.threads", static_cast(num_threads)}}); + + ctx_ = ctx; + hardware_flags_ = ctx->cpu_info()->hardware_flags(); + pool_ = ctx->memory_pool(); + + join_type_ = join_type; + key_cmp_.resize(key_cmp.size()); + for (size_t i = 0; i < key_cmp.size(); ++i) { + key_cmp_[i] = key_cmp[i]; + } + schema_[0] = proj_map_left; + schema_[1] = proj_map_right; + output_batch_callback_ = output_batch_callback; + finished_callback_ = finished_callback; + hash_table_ready_.store(false); + cancelled_.store(false); + { + std::lock_guard lock(state_mutex_); + left_side_finished_ = false; + left_queue_finished_ = false; + right_side_finished_ = false; + error_status_ = Status::OK(); + } + + local_states_.resize(num_threads_); + for (int i = 0; i < num_threads_; ++i) { + local_states_[i].hash_table_ready = false; + local_states_[i].num_output_batches = 0; + RETURN_NOT_OK(CancelIfNotOK(local_states_[i].temp_stack.Init( + pool_, 1024 + 64 * util::MiniBatch::kMiniBatchLength))); + local_states_[i].materialize.Init(pool_, proj_map_left, proj_map_right); + } + + std::vector materialize; + materialize.resize(num_threads_); + for (int i = 0; i < num_threads_; ++i) { + materialize[i] = &local_states_[i].materialize; + } + + probe_processor_.Init(proj_map_left->num_cols(HashJoinProjection::KEY), join_type_, + &hash_table_, materialize, &key_cmp_, output_batch_callback_); + batch_queue_[0].Init(num_threads_); + batch_queue_[1].Init(num_threads_); + + RETURN_NOT_OK(InitScheduler(use_sync_execution, num_threads, schedule_task_callback)); + + return Status::OK(); + } + + Status InitScheduler(bool use_sync_execution, size_t num_threads, + TaskScheduler::ScheduleImpl schedule_task_callback) { + scheduler_ = TaskScheduler::Make(); + task_group_build_ = scheduler_->RegisterTaskGroup( + [this](size_t thread_index, int64_t task_id) -> Status { + return BuildTask(thread_index, task_id); + }, + [this](size_t thread_index) -> Status { return BuildFinished(thread_index); }); + task_group_merge_ = scheduler_->RegisterTaskGroup( + [this](size_t thread_index, int64_t task_id) -> Status { + return MergeTask(thread_index, task_id); + }, + [this](size_t thread_index) -> Status { return MergeFinished(thread_index); }); + task_group_queued_ = scheduler_->RegisterTaskGroup( + [this](size_t thread_index, int64_t task_id) -> Status { + return QueueTask(thread_index, task_id); + }, + [this](size_t thread_index) -> Status { return QueueFinished(thread_index); }); + task_group_scan_ = scheduler_->RegisterTaskGroup( + [this](size_t thread_index, int64_t task_id) -> Status { + return ScanTask(thread_index, task_id); + }, + [this](size_t thread_index) -> Status { return ScanFinished(thread_index); }); + scheduler_->RegisterEnd(); + RETURN_NOT_OK(scheduler_->StartScheduling( + 0 /*thread index*/, std::move(schedule_task_callback), + static_cast(2 * num_threads) /*concurrent tasks*/, use_sync_execution)); + return Status::OK(); + } + + Status InputReceived(size_t thread_index, int side, ExecBatch batch) override { + if (IsCancelled()) { + return 
status(); + } + EVENT(span_, "InputReceived"); + + ExecBatch keypayload_batch; + ARROW_ASSIGN_OR_RAISE(keypayload_batch, KeyPayloadFromInput(side, &batch)); + + if (side == 1) { + // Build side + // + bool result = batch_queue_[side].Append(static_cast(thread_index), + &keypayload_batch); + ARROW_DCHECK(result); + return Status::OK(); + } + + // Probe side + // + ARROW_DCHECK(side == 0); + if (!local_states_[thread_index].hash_table_ready) { + local_states_[thread_index].hash_table_ready = hash_table_ready_.load(); + } + if (!local_states_[thread_index].hash_table_ready) { + if (!batch_queue_[side].Append(static_cast(thread_index), + &keypayload_batch)) { + local_states_[thread_index].hash_table_ready = true; + } else { + return Status::OK(); + } + } + return CancelIfNotOK(probe_processor_.OnNextBatch( + thread_index, keypayload_batch, &local_states_[thread_index].temp_stack, + &local_states_[thread_index].temp_column_arrays)); + } + + Status InputFinished(size_t thread_id, int side) override { + if (IsCancelled()) { + return status(); + } + EVENT(span_, "InputFinished", {{"side", side}}); + + if (side == 0) { + bool proceed; + { + std::lock_guard lock(state_mutex_); + proceed = !left_side_finished_ && left_queue_finished_; + left_side_finished_ = true; + } + if (proceed) { + RETURN_NOT_OK( + CancelIfNotOK(OnLeftSideAndQueueFinished(static_cast(thread_id)))); + } + } else { + bool proceed; + { + std::lock_guard lock(state_mutex_); + proceed = !right_side_finished_; + right_side_finished_ = true; + } + if (proceed) { + RETURN_NOT_OK( + CancelIfNotOK(OnRightSideFinished(static_cast(thread_id)))); + } + } + return Status::OK(); + } + + void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) override { + EVENT(span_, "Abort"); + END_SPAN(span_); + std::ignore = CancelIfNotOK(Status::Cancelled("Hash Join Cancelled")); + scheduler_->Abort(std::move(pos_abort_callback)); + } + + private: + Status OnRightSideFinished(int64_t thread_id) { + return CancelIfNotOK(BuildHashTableAsync(thread_id)); + } + + Status BuildHashTableAsync(int64_t thread_id) { + // Initialize build class instance + // + ExecBatchQueue& batches = batch_queue_[1]; + const HashJoinProjectionMaps* schema = schema_[1]; + bool reject_duplicate_keys = + join_type_ == JoinType::LEFT_SEMI || join_type_ == JoinType::LEFT_ANTI; + bool no_payload = + reject_duplicate_keys || schema->num_cols(HashJoinProjection::PAYLOAD) == 0; + + batches.CloseAll(); + + std::vector key_types; + for (int i = 0; i < schema->num_cols(HashJoinProjection::KEY); ++i) { + key_types.push_back( + ColumnMetadataFromDataType(schema->data_type(HashJoinProjection::KEY, i))); + } + std::vector payload_types; + for (int i = 0; i < schema->num_cols(HashJoinProjection::PAYLOAD); ++i) { + payload_types.push_back( + ColumnMetadataFromDataType(schema->data_type(HashJoinProjection::PAYLOAD, i))); + } + RETURN_NOT_OK(CancelIfNotOK(hash_table_build_.Init( + &hash_table_, num_threads_, batches.num_shared_rows(), reject_duplicate_keys, + no_payload, key_types, payload_types, pool_, hardware_flags_))); + + // Process all input batches + // + return CancelIfNotOK(scheduler_->StartTaskGroup( + static_cast(thread_id), task_group_build_, batches.num_shared_batches())); + } + + Status BuildTask(size_t thread_id, int64_t batch_id) { + if (IsCancelled()) { + return Status::OK(); + } + + ExecBatchQueue& batches = batch_queue_[1]; + const HashJoinProjectionMaps* schema = schema_[1]; + bool no_payload = hash_table_build_.no_payload(); + + ExecBatch* input_batch = 
batches.shared_batch(batch_id); + if (!input_batch || input_batch->length == 0) { + return Status::OK(); + } + + // Split batch into key batch and optional payload batch + // + // Input batch is key-payload batch (key columns followed by payload + // columns). We split it into two separate batches. + // + // TODO: Change SwissTableForJoinBuild interface to use key-payload + // batch instead to avoid this operation, which involves increasing + // shared pointer ref counts. + // + ExecBatch key_batch({}, input_batch->length); + key_batch.values.resize(schema->num_cols(HashJoinProjection::KEY)); + for (size_t icol = 0; icol < key_batch.values.size(); ++icol) { + key_batch.values[icol] = input_batch->values[icol]; + } + ExecBatch payload_batch({}, input_batch->length); + + if (!no_payload) { + payload_batch.values.resize(schema->num_cols(HashJoinProjection::PAYLOAD)); + for (size_t icol = 0; icol < payload_batch.values.size(); ++icol) { + payload_batch.values[icol] = + input_batch->values[schema->num_cols(HashJoinProjection::KEY) + icol]; + } + } + RETURN_NOT_OK(CancelIfNotOK(hash_table_build_.PushNextBatch( + static_cast(thread_id), key_batch, no_payload ? nullptr : &payload_batch, + &local_states_[thread_id].temp_stack))); + + // Release input batch + // + input_batch->values.clear(); + + return Status::OK(); + } + + Status BuildFinished(size_t thread_id) { + RETURN_NOT_OK(status()); + // On a single thread prepare for merging partitions of the resulting hash + // table. + // + RETURN_NOT_OK(CancelIfNotOK(hash_table_build_.PreparePrtnMerge())); + return CancelIfNotOK(scheduler_->StartTaskGroup(thread_id, task_group_merge_, + hash_table_build_.num_prtns())); + } + + Status MergeTask(size_t /*thread_id*/, int64_t prtn_id) { + if (IsCancelled()) { + return Status::OK(); + } + hash_table_build_.PrtnMerge(static_cast(prtn_id)); + return Status::OK(); + } + + Status MergeFinished(size_t thread_id) { + RETURN_NOT_OK(status()); + hash_table_build_.FinishPrtnMerge(&local_states_[thread_id].temp_stack); + return CancelIfNotOK(OnBuildHashTableFinished(static_cast(thread_id))); + } + + Status OnBuildHashTableFinished(int64_t thread_id) { + if (IsCancelled()) { + return status(); + } + + for (int i = 0; i < num_threads_; ++i) { + local_states_[i].materialize.SetBuildSide(hash_table_.keys()->keys(), + hash_table_.payloads(), + hash_table_.key_to_payload() == nullptr); + } + hash_table_ready_.store(true); + batch_queue_[0].CloseAll(); + return ProcessLeftSideQueueAsync(thread_id); + } + + Status ProcessLeftSideQueueAsync(int64_t thread_id) { + if (IsCancelled()) { + return status(); + } + + ExecBatchQueue& batches = batch_queue_[0]; + int64_t num_tasks = batches.num_shared_batches(); + + return CancelIfNotOK(scheduler_->StartTaskGroup(static_cast(thread_id), + task_group_queued_, num_tasks)); + } + + Status QueueTask(size_t thread_id, int64_t batch_id) { + if (IsCancelled()) { + return Status::OK(); + } + + ExecBatchQueue& batches = batch_queue_[0]; + ExecBatch* input_batch = batches.shared_batch(batch_id); + RETURN_NOT_OK(CancelIfNotOK( + probe_processor_.OnNextBatch(static_cast(thread_id), *input_batch, + &local_states_[thread_id].temp_stack, + &local_states_[thread_id].temp_column_arrays))); + // Release input batch + // + input_batch->values.clear(); + return Status::OK(); + } + + Status QueueFinished(size_t thread_id) { + if (IsCancelled()) { + return status(); + } + + bool proceed; + { + std::lock_guard lock(state_mutex_); + proceed = left_side_finished_ && !left_queue_finished_; + 
left_queue_finished_ = true; + } + if (proceed) { + RETURN_NOT_OK( + CancelIfNotOK(OnLeftSideAndQueueFinished(static_cast(thread_id)))); + } + return Status::OK(); + } + + Status OnLeftSideAndQueueFinished(int64_t thread_id) { + return CancelIfNotOK(ScanHashTableAsync(thread_id)); + } + + Status ScanHashTableAsync(int64_t thread_id) { + if (IsCancelled()) { + return status(); + } + + bool need_to_scan = + (join_type_ == JoinType::RIGHT_SEMI || join_type_ == JoinType::RIGHT_ANTI || + join_type_ == JoinType::RIGHT_OUTER || join_type_ == JoinType::FULL_OUTER); + + if (need_to_scan) { + hash_table_.MergeHasMatch(); + int64_t num_tasks = bit_util::CeilDiv(hash_table_.num_rows(), kNumRowsPerScanTask); + + return CancelIfNotOK(scheduler_->StartTaskGroup(static_cast(thread_id), + task_group_scan_, num_tasks)); + } else { + return CancelIfNotOK(OnScanHashTableFinished()); + } + } + + Status ScanTask(size_t thread_id, int64_t task_id) { + if (IsCancelled()) { + return Status::OK(); + } + + // Should we output matches or non-matches? + // + bool bit_to_output = (join_type_ == JoinType::RIGHT_SEMI); + + int64_t start_row = task_id * kNumRowsPerScanTask; + int64_t end_row = + std::min((task_id + 1) * kNumRowsPerScanTask, hash_table_.num_rows()); + // Get thread index and related temp vector stack + // + util::TempVectorStack* temp_stack = &local_states_[thread_id].temp_stack; + + // Split into mini-batches + // + auto payload_ids_buf = + util::TempVectorHolder(temp_stack, util::MiniBatch::kMiniBatchLength); + auto key_ids_buf = + util::TempVectorHolder(temp_stack, util::MiniBatch::kMiniBatchLength); + auto selection_buf = + util::TempVectorHolder(temp_stack, util::MiniBatch::kMiniBatchLength); + for (int64_t mini_batch_start = start_row; mini_batch_start < end_row;) { + // Compute the size of the next mini-batch + // + int64_t mini_batch_size_next = + std::min(end_row - mini_batch_start, + static_cast(util::MiniBatch::kMiniBatchLength)); + + // Get the list of key and payload ids from this mini-batch to output. + // + uint32_t first_key_id = hash_table_.payload_id_to_key_id(mini_batch_start); + uint32_t last_key_id = + hash_table_.payload_id_to_key_id(mini_batch_start + mini_batch_size_next - 1); + int num_output_rows = 0; + for (uint32_t key_id = first_key_id; key_id <= last_key_id; ++key_id) { + if (bit_util::GetBit(hash_table_.has_match(), key_id) == bit_to_output) { + uint32_t first_payload_for_key = + std::max(static_cast(mini_batch_start), + hash_table_.key_to_payload() ? hash_table_.key_to_payload()[key_id] + : key_id); + uint32_t last_payload_for_key = std::min( + static_cast(mini_batch_start + mini_batch_size_next - 1), + hash_table_.key_to_payload() ? hash_table_.key_to_payload()[key_id + 1] - 1 + : key_id); + uint32_t num_payloads_for_key = + last_payload_for_key - first_payload_for_key + 1; + for (uint32_t i = 0; i < num_payloads_for_key; ++i) { + key_ids_buf.mutable_data()[num_output_rows + i] = key_id; + payload_ids_buf.mutable_data()[num_output_rows + i] = + first_payload_for_key + i; + } + num_output_rows += num_payloads_for_key; + } + } + + if (num_output_rows > 0) { + // Materialize (and output whenever buffers get full) hash table + // values according to the generated list of ids. 
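+        // The id list generated above clips each key's payload range to the
+        // current mini-batch window, so a key whose payloads straddle a window
+        // boundary is emitted partly here and partly by the preceding or
+        // following mini-batch. Illustrative sketch (hypothetical values,
+        // standard C++):
+        //
+        //   int64_t window_begin = 1024, window_end = 2048;          // [begin, end)
+        //   uint32_t key_payloads_begin = 1000, key_payloads_end = 1100;
+        //   uint32_t first = std::max<int64_t>(window_begin, key_payloads_begin);   // 1024
+        //   uint32_t last = std::min<int64_t>(window_end - 1, key_payloads_end - 1);  // 1099
+        //   // this mini-batch emits payloads 1024..1099; 1000..1023 were emitted earlier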
+ // + Status status = local_states_[thread_id].materialize.AppendBuildOnly( + num_output_rows, key_ids_buf.mutable_data(), payload_ids_buf.mutable_data(), + [&](ExecBatch batch) { + output_batch_callback_(static_cast(thread_id), std::move(batch)); + }); + RETURN_NOT_OK(CancelIfNotOK(status)); + if (!status.ok()) { + break; + } + } + mini_batch_start += mini_batch_size_next; + } + + return Status::OK(); + } + + Status ScanFinished(size_t thread_id) { + if (IsCancelled()) { + return status(); + } + + return CancelIfNotOK(OnScanHashTableFinished()); + } + + Status OnScanHashTableFinished() { + if (IsCancelled()) { + return status(); + } + END_SPAN(span_); + + // Flush all instances of materialize that have non-zero accumulated output + // rows. + // + RETURN_NOT_OK(CancelIfNotOK(probe_processor_.OnFinished())); + + int64_t num_produced_batches = 0; + for (size_t i = 0; i < local_states_.size(); ++i) { + JoinResultMaterialize& materialize = local_states_[i].materialize; + num_produced_batches += materialize.num_produced_batches(); + } + + finished_callback_(num_produced_batches); + + return Status::OK(); + } + + Result KeyPayloadFromInput(int side, ExecBatch* input) { + ExecBatch projected({}, input->length); + int num_key_cols = schema_[side]->num_cols(HashJoinProjection::KEY); + int num_payload_cols = schema_[side]->num_cols(HashJoinProjection::PAYLOAD); + projected.values.resize(num_key_cols + num_payload_cols); + + auto key_to_input = + schema_[side]->map(HashJoinProjection::KEY, HashJoinProjection::INPUT); + for (int icol = 0; icol < num_key_cols; ++icol) { + const Datum& value_in = input->values[key_to_input.get(icol)]; + if (value_in.is_scalar()) { + ARROW_ASSIGN_OR_RAISE( + projected.values[icol], + MakeArrayFromScalar(*value_in.scalar(), projected.length, pool_)); + } else { + projected.values[icol] = value_in; + } + } + auto payload_to_input = + schema_[side]->map(HashJoinProjection::PAYLOAD, HashJoinProjection::INPUT); + for (int icol = 0; icol < num_payload_cols; ++icol) { + const Datum& value_in = input->values[payload_to_input.get(icol)]; + if (value_in.is_scalar()) { + ARROW_ASSIGN_OR_RAISE( + projected.values[num_key_cols + icol], + MakeArrayFromScalar(*value_in.scalar(), projected.length, pool_)); + } else { + projected.values[num_key_cols + icol] = value_in; + } + } + + return projected; + } + + bool IsCancelled() { return cancelled_.load(); } + + Status status() { + if (IsCancelled()) { + std::lock_guard lock(state_mutex_); + return error_status_; + } + return Status::OK(); + } + + Status CancelIfNotOK(Status status) { + if (!status.ok()) { + { + std::lock_guard lock(state_mutex_); + // Only update the status for the first error encountered. 
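+        // Concurrent failures race to set the shared status; later errors are
+        // dropped, and the atomic cancelled_ flag stored after releasing the
+        // lock is what hot paths such as InputReceived() poll instead of
+        // taking this mutex.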
+ // + if (error_status_.ok()) { + error_status_ = status; + } + } + cancelled_.store(true); + } + return status; + } + + static constexpr int kNumRowsPerScanTask = 512 * 1024; + + ExecContext* ctx_; + int64_t hardware_flags_; + MemoryPool* pool_; + int num_threads_; + JoinType join_type_; + std::vector key_cmp_; + const HashJoinProjectionMaps* schema_[2]; + + // Task scheduling + std::unique_ptr scheduler_; + int task_group_build_; + int task_group_merge_; + int task_group_queued_; + int task_group_scan_; + + // Callbacks + OutputBatchCallback output_batch_callback_; + FinishedCallback finished_callback_; + + struct ThreadLocalState { + JoinResultMaterialize materialize; + util::TempVectorStack temp_stack; + std::vector temp_column_arrays; + int64_t num_output_batches; + bool hash_table_ready; + }; + std::vector local_states_; + + SwissTableForJoin hash_table_; + JoinProbeProcessor probe_processor_; + SwissTableForJoinBuild hash_table_build_; + ExecBatchQueue batch_queue_[2]; + + // Atomic state flags. + // These flags are kept outside of mutex, since they can be queried for every + // batch. + // + // The other flags that follow them, protected by mutex, will be queried or + // updated only a fixed number of times during entire join processing. + // + std::atomic hash_table_ready_; + std::atomic cancelled_; + + // Mutex protecting state flags. + // + std::mutex state_mutex_; + + // Mutex protected state flags. + // + bool left_side_finished_; + bool left_queue_finished_; + bool right_side_finished_; + Status error_status_; +}; + +Result> HashJoinImpl::MakeSwiss() { + std::unique_ptr impl{new SwissJoin()}; + return std::move(impl); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/swiss_join.h b/cpp/src/arrow/compute/exec/swiss_join.h new file mode 100644 index 0000000000000..d419cc0c62e0d --- /dev/null +++ b/cpp/src/arrow/compute/exec/swiss_join.h @@ -0,0 +1,874 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include +#include "arrow/compute/exec/key_encode.h" +#include "arrow/compute/exec/key_map.h" +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/partition_util.h" +#include "arrow/compute/exec/schema_util.h" +#include "arrow/compute/exec/task_util.h" + +namespace arrow { +namespace compute { + +class ResizableArrayData { + public: + ResizableArrayData() + : log_num_rows_min_(0), + pool_(NULLPTR), + num_rows_(0), + num_rows_allocated_(0), + var_len_buf_size_(0) {} + ~ResizableArrayData() { Clear(true); } + void Init(const std::shared_ptr& data_type, MemoryPool* pool, + int log_num_rows_min); + void Clear(bool release_buffers); + Status ResizeFixedLengthBuffers(int num_rows_new); + Status ResizeVaryingLengthBuffer(); + int num_rows() const { return num_rows_; } + KeyEncoder::KeyColumnArray column_array() const; + KeyEncoder::KeyColumnMetadata column_metadata() const { + return ColumnMetadataFromDataType(data_type_); + } + std::shared_ptr array_data() const; + uint8_t* mutable_data(int i) { + return i == 0 ? non_null_buf_->mutable_data() + : i == 1 ? fixed_len_buf_->mutable_data() + : var_len_buf_->mutable_data(); + } + + private: + static constexpr int64_t kNumPaddingBytes = 64; + int log_num_rows_min_; + std::shared_ptr data_type_; + MemoryPool* pool_; + int num_rows_; + int num_rows_allocated_; + int var_len_buf_size_; + std::shared_ptr non_null_buf_; + std::shared_ptr fixed_len_buf_; + std::shared_ptr var_len_buf_; +}; + +class ExecBatchBuilder { + public: + static Status AppendSelected(const std::shared_ptr& source, + ResizableArrayData& target, int num_rows_to_append, + const uint16_t* row_ids, MemoryPool* pool); + + static Status AppendNulls(const std::shared_ptr& type, + ResizableArrayData& target, int num_rows_to_append, + MemoryPool* pool); + + Status AppendSelected(MemoryPool* pool, const ExecBatch& batch, int num_rows_to_append, + const uint16_t* row_ids, int num_cols, + const int* col_ids = NULLPTR); + + Status AppendSelected(MemoryPool* pool, const ExecBatch& batch, int num_rows_to_append, + const uint16_t* row_ids, int* num_appended, int num_cols, + const int* col_ids = NULLPTR); + + Status AppendNulls(MemoryPool* pool, + const std::vector>& types, + int num_rows_to_append); + + Status AppendNulls(MemoryPool* pool, + const std::vector>& types, + int num_rows_to_append, int* num_appended); + + // Should only be called if num_rows() returns non-zero. + // + ExecBatch Flush(); + + int num_rows() const { return values_.empty() ? 0 : values_[0].num_rows(); } + + static int num_rows_max() { return 1 << kLogNumRows; } + + private: + static constexpr int kLogNumRows = 15; + + // Calculate how many rows to skip from the tail of the + // sequence of selected rows, such that the total size of skipped rows is at + // least equal to the size specified by the caller. Skipping of the tail rows + // is used to allow for faster processing by the caller of remaining rows + // without checking buffer bounds (useful with SIMD or fixed size memory loads + // and stores). + // + // The sequence of row_ids provided must be non-decreasing. + // + static int NumRowsToSkip(const std::shared_ptr& column, int num_rows, + const uint16_t* row_ids, int num_tail_bytes_to_skip); + + // The supplied lambda will be called for each row in the given list of rows. + // The arguments given to it will be: + // - index of a row (within the set of selected rows), + // - pointer to the value, + // - byte length of the value. 
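+  // Illustrative sketch (hypothetical lambda, parameter types assumed) of the
+  // expected callback shape:
+  //
+  //   auto process_value_fn = [&](int i, const uint8_t* value, uint32_t num_bytes) {
+  //     // copy num_bytes bytes of the i-th selected value to its destination
+  //   };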
+ // + // The information about nulls (validity bitmap) is not used in this call and + // has to be processed separately. + // + template + static void Visit(const std::shared_ptr& column, int num_rows, + const uint16_t* row_ids, PROCESS_VALUE_FN process_value_fn); + + template + static void CollectBitsImp(const uint8_t* input_bits, int64_t input_bits_offset, + uint8_t* output_bits, int64_t output_bits_offset, + int num_rows, const uint16_t* row_ids); + static void CollectBits(const uint8_t* input_bits, int64_t input_bits_offset, + uint8_t* output_bits, int64_t output_bits_offset, int num_rows, + const uint16_t* row_ids); + + std::vector values_; +}; + +class RowArrayAccessor { + public: + // Find the index of this varbinary column within the sequence of all + // varbinary columns encoded in rows. + // + static int VarbinaryColumnId(const KeyEncoder::KeyRowMetadata& row_metadata, + int column_id); + + // Calculate how many rows to skip from the tail of the + // sequence of selected rows, such that the total size of skipped rows is at + // least equal to the size specified by the caller. Skipping of the tail rows + // is used to allow for faster processing by the caller of remaining rows + // without checking buffer bounds (useful with SIMD or fixed size memory loads + // and stores). + // + static int NumRowsToSkip(const KeyEncoder::KeyRowArray& rows, int column_id, + int num_rows, const uint32_t* row_ids, + int num_tail_bytes_to_skip); + + // The supplied lambda will be called for each row in the given list of rows. + // The arguments given to it will be: + // - index of a row (within the set of selected rows), + // - pointer to the value, + // - byte length of the value. + // + // The information about nulls (validity bitmap) is not used in this call and + // has to be processed separately. + // + template + static void Visit(const KeyEncoder::KeyRowArray& rows, int column_id, int num_rows, + const uint32_t* row_ids, PROCESS_VALUE_FN process_value_fn); + + // The supplied lambda will be called for each row in the given list of rows. + // The arguments given to it will be: + // - index of a row (within the set of selected rows), + // - byte 0xFF if the null is set for the row or 0x00 otherwise. + // + template + static void VisitNulls(const KeyEncoder::KeyRowArray& rows, int column_id, int num_rows, + const uint32_t* row_ids, PROCESS_VALUE_FN process_value_fn); + + private: +#if defined(ARROW_HAVE_AVX2) + // This is equivalent to Visit method, but processing 8 rows at a time in a + // loop. + // Returns the number of processed rows, which may be less than requested (up + // to 7 rows at the end may be skipped). + // + template + static int Visit_avx2(const KeyEncoder::KeyRowArray& rows, int column_id, int num_rows, + const uint32_t* row_ids, PROCESS_8_VALUES_FN process_8_values_fn); + + // This is equivalent to VisitNulls method, but processing 8 rows at a time in + // a loop. Returns the number of processed rows, which may be less than + // requested (up to 7 rows at the end may be skipped). + // + template + static int VisitNulls_avx2(const KeyEncoder::KeyRowArray& rows, int column_id, + int num_rows, const uint32_t* row_ids, + PROCESS_8_VALUES_FN process_8_values_fn); +#endif +}; + +// Write operations (appending batch rows) must not be called by more than one +// thread at the same time. +// +// Read operations (row comparison, column decoding) +// can be called by multiple threads concurrently. 
+// +struct RowArray { + RowArray() : is_initialized_(false) {} + + Status InitIfNeeded(MemoryPool* pool, const ExecBatch& batch); + Status InitIfNeeded(MemoryPool* pool, const KeyEncoder::KeyRowMetadata& row_metadata); + + Status AppendBatchSelection( + MemoryPool* pool, const ExecBatch& batch, int begin_row_id, int end_row_id, + int num_row_ids, const uint16_t* row_ids, + std::vector& temp_column_arrays); + + // This can only be called for a minibatch. + // + void Compare(const ExecBatch& batch, int begin_row_id, int end_row_id, int num_selected, + const uint16_t* batch_selection_maybe_null, const uint32_t* array_row_ids, + uint32_t* out_num_not_equal, uint16_t* out_not_equal_selection, + int64_t hardware_flags, util::TempVectorStack* temp_stack, + std::vector& temp_column_arrays, + uint8_t* out_match_bitvector_maybe_null = NULLPTR); + + // TODO: add AVX2 version + // + Status DecodeSelected(ResizableArrayData* target, int column_id, int num_rows_to_append, + const uint32_t* row_ids, MemoryPool* pool) const; + + void DebugPrintToFile(const char* filename, bool print_sorted) const; + + int64_t num_rows() const { return is_initialized_ ? rows_.length() : 0; } + + bool is_initialized_; + KeyEncoder encoder_; + KeyEncoder::KeyRowArray rows_; + KeyEncoder::KeyRowArray rows_temp_; +}; + +// Implements concatenating multiple row arrays into a single one, using +// potentially multiple threads, each processing a single input row array. +// +class RowArrayMerge { + public: + // Calculate total number of rows and size in bytes for merged sequence of + // rows and allocate memory for it. + // + // If the rows are of varying length, initialize in the offset array the first + // entry for the write area for each input row array. Leave all other + // offsets and buffers uninitialized. + // + // All input sources must be initialized, but they can contain zero rows. + // + // Output in vector the first target row id for each source (exclusive + // cummulative sum of number of rows in sources). + // + static Status PrepareForMerge(RowArray* target, const std::vector& sources, + std::vector* first_target_row_id, + MemoryPool* pool); + + // Copy rows from source array to target array. + // Both arrays must have the same row metadata. + // Target array must already have the memory reserved in all internal buffers + // for the copy of the rows. + // + // Copy of the rows will occupy the same amount of space in the target array + // buffers as in the source array, but in the target array we pick at what row + // position and offset we start writing. + // + // Optionally, the rows may be reordered during copy according to the + // provided permutation, which represents some sorting order of source rows. + // Nth element of the permutation array is the source row index for the Nth + // row written into target array. If permutation is missing (null), then the + // order of source rows will remain unchanged. + // + // In case of varying length rows, we purposefully skip outputting of N+1 (one + // after last) offset, to allow concurrent copies of rows done to adjacent + // ranges in the target array. This offset should already contain the right + // value after calling the method preparing target array for merge (which + // initializes boundary offsets for target row ranges for each source). + // + static void MergeSingle(RowArray* target, const RowArray& source, + int64_t first_target_row_id, + const int64_t* source_rows_permutation); + + private: + // Copy rows from source array to a region of the target array. 
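+  // The per-source target regions are disjoint, which is what makes the
+  // copies safe to run in parallel. Illustrative sketch (hypothetical row
+  // counts) of how PrepareForMerge derives them as an exclusive cumulative
+  // sum of source sizes:
+  //
+  //   std::vector<int64_t> source_num_rows = {100, 40, 0, 60};
+  //   std::vector<int64_t> first_target_row_id(source_num_rows.size());
+  //   int64_t total = 0;
+  //   for (size_t i = 0; i < source_num_rows.size(); ++i) {
+  //     first_target_row_id[i] = total;
+  //     total += source_num_rows[i];
+  //   }
+  //   // first_target_row_id == {0, 100, 140, 140}; merged array has 200 rows
+  //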
+ // This implementation is for fixed length rows. + // Null information needs to be handled separately. + // + static void CopyFixedLength(KeyEncoder::KeyRowArray* target, + const KeyEncoder::KeyRowArray& source, + int64_t first_target_row_id, + const int64_t* source_rows_permutation); + + // Copy rows from source array to a region of the target array. + // This implementation is for varying length rows. + // Null information needs to be handled separately. + // + static void CopyVaryingLength(KeyEncoder::KeyRowArray* target, + const KeyEncoder::KeyRowArray& source, + int64_t first_target_row_id, + int64_t first_target_row_offset, + const int64_t* source_rows_permutation); + + // Copy null information from rows from source array to a region of the target + // array. + // + static void CopyNulls(KeyEncoder::KeyRowArray* target, + const KeyEncoder::KeyRowArray& source, + int64_t first_target_row_id, + const int64_t* source_rows_permutation); +}; + +// Implements merging of multiple SwissTables into a single one, using +// potentially multiple threads, each processing a single input source. +// +// Each source should correspond to a range of original hashes. +// A row belongs to a source with index determined by K highest bits of +// original hash. That means that the number of sources must be a power of 2. +// +// We assume that the hash values used and stored inside source tables +// have K highest bits removed from the original hash in order to avoid huge +// number of hash collisions that would occur otherwise. +// These bits will be reinserted back (original hashes will be used) when +// merging into target. +// +class SwissTableMerge { + public: + // Calculate total number of blocks for merged table. + // Allocate buffers sized accordingly and initialize empty target table. + // + // All input sources must be initialized, but they can be empty. + // + // Output in a vector the first target group id for each source (exclusive + // cummulative sum of number of groups in sources). + // + static Status PrepareForMerge(SwissTable* target, + const std::vector& sources, + std::vector* first_target_group_id, + MemoryPool* pool); + + // Copy all entries from source to a range of blocks (partition) of target. + // + // During copy, adjust group ids from source by adding provided base id. + // + // Skip entries from source that would cross partition boundaries (range of + // blocks) when inserted into target. Save their data in output vector for + // processing later. We postpone inserting these overflow entries in order to + // allow concurrent processing of all partitions. Overflow entries will be + // handled by a single-thread afterwards. + // + static void MergePartition(SwissTable* target, const SwissTable* source, + uint32_t partition_id, int num_partition_bits, + uint32_t base_group_id, + std::vector* overflow_group_ids, + std::vector* overflow_hashes); + + // Single-threaded processing of remaining groups, that could not be + // inserted in partition merge phase + // (due to entries from one partition spilling over due to full blocks into + // the next partition). + // + static void InsertNewGroups(SwissTable* target, const std::vector& group_ids, + const std::vector& hashes); + + private: + // Insert a new group id. + // + // Assumes that there are enough slots in the target + // and there is no need to resize it. + // + // Max block id can be provided, in which case the search for an empty slot to + // insert new entry to will stop after visiting that block. 
+ // + // Max block id value greater or equal to the number of blocks guarantees that + // the search will not be stopped. + // + static inline bool InsertNewGroup(SwissTable* target, uint64_t group_id, uint32_t hash, + int64_t max_block_id); +}; + +struct SwissTableWithKeys { + struct Input { + Input(const ExecBatch* in_batch, int in_batch_start_row, int in_batch_end_row, + util::TempVectorStack* in_temp_stack, + std::vector* in_temp_column_arrays); + + Input(const ExecBatch* in_batch, util::TempVectorStack* in_temp_stack, + std::vector* in_temp_column_arrays); + + Input(const ExecBatch* in_batch, int in_num_selected, const uint16_t* in_selection, + util::TempVectorStack* in_temp_stack, + std::vector* in_temp_column_arrays, + std::vector* in_temp_group_ids); + + Input(const Input& base, int num_rows_to_skip, int num_rows_to_include); + + const ExecBatch* batch; + // Window of the batch to operate on. + // The window information is only used if row selection is null. + // + int batch_start_row; + int batch_end_row; + // Optional selection. + // Used instead of window of the batch if not null. + // + int num_selected; + const uint16_t* selection_maybe_null; + // Thread specific scratch buffers for storing temporary data. + // + util::TempVectorStack* temp_stack; + std::vector* temp_column_arrays; + std::vector* temp_group_ids; + }; + + Status Init(int64_t hardware_flags, MemoryPool* pool); + + void InitCallbacks(); + + static void Hash(Input* input, uint32_t* hashes, int64_t hardware_flags); + + // If input uses selection, then hashes array must have one element for every + // row in the whole (unfiltered and not spliced) input exec batch. Otherwise, + // there must be one element in hashes array for every value in the window of + // the exec batch specified by input. + // + // Output arrays will contain one element for every selected batch row in + // input (selected either by selection vector if provided or input window + // otherwise). + // + void MapReadOnly(Input* input, const uint32_t* hashes, uint8_t* match_bitvector, + uint32_t* key_ids); + Status MapWithInserts(Input* input, const uint32_t* hashes, uint32_t* key_ids); + + SwissTable* swiss_table() { return &swiss_table_; } + const SwissTable* swiss_table() const { return &swiss_table_; } + RowArray* keys() { return &keys_; } + const RowArray* keys() const { return &keys_; } + + private: + void EqualCallback(int num_keys, const uint16_t* selection_maybe_null, + const uint32_t* group_ids, uint32_t* out_num_keys_mismatch, + uint16_t* out_selection_mismatch, void* callback_ctx); + Status AppendCallback(int num_keys, const uint16_t* selection, void* callback_ctx); + Status Map(Input* input, bool insert_missing, const uint32_t* hashes, + uint8_t* match_bitvector_maybe_null, uint32_t* key_ids); + + SwissTable::EqualImpl equal_impl_; + SwissTable::AppendImpl append_impl_; + + SwissTable swiss_table_; + RowArray keys_; +}; + +// Enhances SwissTableWithKeys with the following structures used by hash join: +// - storage of payloads (that unlike keys do not have to be unique) +// - mapping from a key to all inserted payloads corresponding to it (we can +// store multiple rows corresponding to a single key) +// - bit-vectors for keeping track of whether each payload had a match during +// evaluation of join. 
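+// Illustrative sketch (hypothetical values) of the key-to-payload mapping: it
+// is an offset array with one extra trailing entry, so the payload row ids for
+// key k occupy the half-open range
+// [key_to_payload()[k], key_to_payload()[k + 1]).
+//
+//   // 4 distinct keys mapping to 7 payload rows in total:
+//   std::vector<uint32_t> key_to_payload = {0, 1, 4, 5, 7};
+//   // key 0 -> payload 0, key 1 -> payloads 1..3,
+//   // key 2 -> payload 4, key 3 -> payloads 5..6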
+// +class SwissTableForJoin { + friend class SwissTableForJoinBuild; + + public: + void Lookup(const ExecBatch& batch, int start_row, int num_rows, + uint8_t* out_has_match_bitvector, uint32_t* out_key_ids, + util::TempVectorStack* temp_stack, + std::vector* temp_column_arrays); + void UpdateHasMatchForKeys(int64_t thread_id, int num_rows, const uint32_t* key_ids); + void MergeHasMatch(); + + const SwissTableWithKeys* keys() const { return &map_; } + SwissTableWithKeys* keys() { return &map_; } + const RowArray* payloads() const { return no_payload_columns_ ? NULLPTR : &payloads_; } + const uint32_t* key_to_payload() const { + return no_duplicate_keys_ ? NULLPTR : row_offset_for_key_.data(); + } + const uint8_t* has_match() const { + return has_match_.empty() ? NULLPTR : has_match_.data(); + } + int64_t num_keys() const { return map_.keys()->num_rows(); } + int64_t num_rows() const { + return no_duplicate_keys_ ? num_keys() : row_offset_for_key_[num_keys()]; + } + + uint32_t payload_id_to_key_id(uint32_t payload_id) const; + // Input payload ids must form an increasing sequence. + // + void payload_ids_to_key_ids(int num_rows, const uint32_t* payload_ids, + uint32_t* key_ids) const; + + private: + uint8_t* local_has_match(int64_t thread_id); + + // Degree of parallelism (number of threads) + int dop_; + + struct ThreadLocalState { + std::vector has_match; + }; + std::vector local_states_; + std::vector has_match_; + + SwissTableWithKeys map_; + + bool no_duplicate_keys_; + // Not used if no_duplicate_keys_ is true. + std::vector row_offset_for_key_; + + bool no_payload_columns_; + // Not used if no_payload_columns_ is true. + RowArray payloads_; +}; + +// Implements parallel build process for hash table for join from a sequence of +// exec batches with input rows. +// +class SwissTableForJoinBuild { + public: + Status Init(SwissTableForJoin* target, int dop, int64_t num_rows, + bool reject_duplicate_keys, bool no_payload, + const std::vector& key_types, + const std::vector& payload_types, + MemoryPool* pool, int64_t hardware_flags); + + // In the first phase of parallel hash table build, threads pick unprocessed + // exec batches, partition the rows based on hash, and update all of the + // partitions with information related to that batch of rows. + // + Status PushNextBatch(int64_t thread_id, const ExecBatch& key_batch, + const ExecBatch* payload_batch_maybe_null, + util::TempVectorStack* temp_stack); + + // Allocate memory and initialize counters required for parallel merging of + // hash table partitions. + // Single-threaded. + // + Status PreparePrtnMerge(); + + // Second phase of parallel hash table build. + // Each partition can be processed by a different thread. + // Parallel step. + // + void PrtnMerge(int prtn_id); + + // Single-threaded processing of the rows that have been skipped during + // parallel merging phase, due to hash table search resulting in crossing + // partition boundaries. + // + void FinishPrtnMerge(util::TempVectorStack* temp_stack); + + // The number of partitions is the number of parallel tasks to execute during + // the final phase of hash table build process. 
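+  // Illustrative sketch (hypothetical parameters; the exact bit extraction is
+  // an implementation detail): with a power-of-2 partition count, a row's
+  // partition can be read off the high bits of its 32-bit hash.
+  //
+  //   int log_num_prtns = 3;                            // 1 << 3 == 8 partitions
+  //   uint32_t hash = 0xDEADBEEFu;
+  //   uint32_t prtn_id = hash >> (32 - log_num_prtns);  // == 6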
+ // + int num_prtns() const { return num_prtns_; } + + bool no_payload() const { return no_payload_; } + + private: + void InitRowArray(); + Status ProcessPartition(int64_t thread_id, const ExecBatch& key_batch, + const ExecBatch* payload_batch_maybe_null, + util::TempVectorStack* temp_stack, int prtn_id); + + SwissTableForJoin* target_; + // DOP stands for Degree Of Parallelism - the maximum number of participating + // threads. + // + int dop_; + // Partition is a unit of parallel work. + // + // There must be power of 2 partitions (bits of hash will be used to + // identify them). + // + // Pick number of partitions at least equal to the number of threads (degree + // of parallelism). + // + int log_num_prtns_; + int num_prtns_; + int64_t num_rows_; + // Left-semi and left-anti-semi joins do not need more than one copy of the + // same key in the hash table. + // This flag, if set, will result in filtering rows with duplicate keys before + // inserting them into hash table. + // + // Since left-semi and left-anti-semi joins also do not need payload, when + // this flag is set there also will not be any processing of payload. + // + bool reject_duplicate_keys_; + // This flag, when set, will result in skipping any processing of the payload. + // + // The flag for rejecting duplicate keys (which should be set for left-semi + // and left-anti joins), when set, will force this flag to also be set, but + // other join flavors may set it to true as well if no payload columns are + // needed for join output. + // + bool no_payload_; + MemoryPool* pool_; + int64_t hardware_flags_; + + // One per partition. + // + struct PartitionState { + SwissTableWithKeys keys; + RowArray payloads; + std::vector key_ids; + std::vector overflow_key_ids; + std::vector overflow_hashes; + }; + + // One per thread. + // + // Buffers for storing temporary intermediate results when processing input + // batches. + // + struct ThreadState { + std::vector batch_hashes; + std::vector batch_prtn_ranges; + std::vector batch_prtn_row_ids; + std::vector temp_prtn_ids; + std::vector temp_group_ids; + std::vector temp_column_arrays; + }; + + std::vector prtn_states_; + std::vector thread_states_; + PartitionLocks prtn_locks_; + + std::vector partition_keys_first_row_id_; + std::vector partition_payloads_first_row_id_; +}; + +class JoinResultMaterialize { + public: + void Init(MemoryPool* pool, const HashJoinProjectionMaps* probe_schemas, + const HashJoinProjectionMaps* build_schemas); + + void SetBuildSide(const RowArray* build_keys, const RowArray* build_payloads, + bool payload_id_same_as_key_id); + + // Input probe side batches should contain all key columns followed by all + // payload columns. + // + Status AppendProbeOnly(const ExecBatch& key_and_payload, int num_rows_to_append, + const uint16_t* row_ids, int* num_rows_appended); + + Status AppendBuildOnly(int num_rows_to_append, const uint32_t* key_ids, + const uint32_t* payload_ids, int* num_rows_appended); + + Status Append(const ExecBatch& key_and_payload, int num_rows_to_append, + const uint16_t* row_ids, const uint32_t* key_ids, + const uint32_t* payload_ids, int* num_rows_appended); + + // Should only be called if num_rows() returns non-zero. 
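+  // Illustrative usage sketch (the templated Flush overload further down
+  // wraps exactly this pattern; the consumer callback is hypothetical):
+  //
+  //   if (materialize.num_rows() > 0) {
+  //     ExecBatch batch({}, materialize.num_rows());
+  //     RETURN_NOT_OK(materialize.Flush(&batch));
+  //     output_batch_fn(std::move(batch));
+  //   }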
+ // + Status Flush(ExecBatch* out); + + int num_rows() const { return num_rows_; } + + template + Status AppendAndOutput(int num_rows_to_append, const APPEND_ROWS_FN& append_rows_fn, + const OUTPUT_BATCH_FN& output_batch_fn) { + int offset = 0; + for (;;) { + int num_rows_appended = 0; + ARROW_RETURN_NOT_OK(append_rows_fn(num_rows_to_append, offset, &num_rows_appended)); + if (num_rows_appended < num_rows_to_append) { + ExecBatch batch; + ARROW_RETURN_NOT_OK(Flush(&batch)); + output_batch_fn(batch); + num_rows_to_append -= num_rows_appended; + offset += num_rows_appended; + } else { + break; + } + } + return Status::OK(); + } + + template + Status AppendProbeOnly(const ExecBatch& key_and_payload, int num_rows_to_append, + const uint16_t* row_ids, OUTPUT_BATCH_FN output_batch_fn) { + return AppendAndOutput( + num_rows_to_append, + [&](int num_rows_to_append_left, int offset, int* num_rows_appended) { + return AppendProbeOnly(key_and_payload, num_rows_to_append_left, + row_ids + offset, num_rows_appended); + }, + output_batch_fn); + } + + template + Status AppendBuildOnly(int num_rows_to_append, const uint32_t* key_ids, + const uint32_t* payload_ids, OUTPUT_BATCH_FN output_batch_fn) { + return AppendAndOutput( + num_rows_to_append, + [&](int num_rows_to_append_left, int offset, int* num_rows_appended) { + return AppendBuildOnly( + num_rows_to_append_left, key_ids ? key_ids + offset : NULLPTR, + payload_ids ? payload_ids + offset : NULLPTR, num_rows_appended); + }, + output_batch_fn); + } + + template + Status Append(const ExecBatch& key_and_payload, int num_rows_to_append, + const uint16_t* row_ids, const uint32_t* key_ids, + const uint32_t* payload_ids, OUTPUT_BATCH_FN output_batch_fn) { + return AppendAndOutput( + num_rows_to_append, + [&](int num_rows_to_append_left, int offset, int* num_rows_appended) { + return Append(key_and_payload, num_rows_to_append_left, + row_ids ? row_ids + offset : NULLPTR, + key_ids ? key_ids + offset : NULLPTR, + payload_ids ? payload_ids + offset : NULLPTR, num_rows_appended); + }, + output_batch_fn); + } + + template + Status Flush(OUTPUT_BATCH_FN output_batch_fn) { + if (num_rows_ > 0) { + ExecBatch batch({}, num_rows_); + ARROW_RETURN_NOT_OK(Flush(&batch)); + output_batch_fn(std::move(batch)); + } + return Status::OK(); + } + + int64_t num_produced_batches() const { return num_produced_batches_; } + + private: + bool HasProbeOutput() const; + bool HasBuildKeyOutput() const; + bool HasBuildPayloadOutput() const; + bool NeedsKeyId() const; + bool NeedsPayloadId() const; + Result> FlushBuildColumn( + const std::shared_ptr& data_type, const RowArray* row_array, + int column_id, uint32_t* row_ids); + + MemoryPool* pool_; + const HashJoinProjectionMaps* probe_schemas_; + const HashJoinProjectionMaps* build_schemas_; + const RowArray* build_keys_; + // Payload array pointer may be left as null, if no payload columns are + // in the output column set. + // + const RowArray* build_payloads_; + // If true, then ignore updating payload ids and use key ids instead when + // reading. + // + bool payload_id_same_as_key_id_; + std::vector probe_output_to_key_and_payload_; + + // Number of accumulated rows (since last flush) + // + int num_rows_; + // Accumulated output columns from probe side batches. + // + ExecBatchBuilder batch_builder_; + // Accumulated build side row references. 
+ // + std::vector key_ids_; + std::vector payload_ids_; + // Information about ranges of rows from build side, + // that in the accumulated materialized results have all fields set to null. + // + // Each pair contains index of the first output row in the range and the + // length of the range. Only rows outside of these ranges have data present in + // the key_ids_ and payload_ids_ arrays. + // + std::vector> null_ranges_; + + int64_t num_produced_batches_; +}; + +// Implements evaluating filter bit vector eliminating rows that do not have +// join matches due to nulls in key columns. +// +// +class JoinNullFilter { + public: + // The batch for which the filter bit vector will be computed + // needs to start with all key columns but it may contain more columns + // (payload) following them. + // + static void Filter(const ExecBatch& key_batch, int batch_start_row, int num_batch_rows, + const std::vector& cmp, bool* all_valid, + bool and_with_input, uint8_t* out_bit_vector); +}; + +// A helper class that takes hash table lookup results for a range of rows in +// input batch, that is: +// - bit vector marking whether there was a key match in the hash table +// - key id if there was a match +// - mapping from key id to a range of payload ids associated with that key +// (representing multiple matching rows in a hash table for a single row in an +// input batch), and iterates output batches of limited size containing tuples +// describing all matching pairs of rows: +// - input batch row id (only rows that have matches in the hash table are +// included) +// - key id for a match +// - payload id (different one for each matching row in the hash table) +// +class JoinMatchIterator { + public: + void SetLookupResult(int num_batch_rows, int start_batch_row, + const uint8_t* batch_has_match, const uint32_t* key_ids, + bool no_duplicate_keys, const uint32_t* key_to_payload); + bool GetNextBatch(int num_rows_max, int* out_num_rows, uint16_t* batch_row_ids, + uint32_t* key_ids, uint32_t* payload_ids); + + private: + int num_batch_rows_; + int start_batch_row_; + const uint8_t* batch_has_match_; + const uint32_t* key_ids_; + + bool no_duplicate_keys_; + const uint32_t* key_to_payload_; + + // Index of the first not fully processed input row, or number of rows if all + // have been processed. May be pointing to a row with no matches. + // + int current_row_; + // Index of the first unprocessed match for the input row. May be zero if the + // row has no matches. + // + int current_match_for_row_; +}; + +// Implements entire processing of a probe side exec batch, +// provided the join hash table is already built and available. +// +class JoinProbeProcessor { + public: + using OutputBatchFn = std::function; + + void Init(int num_key_columns, JoinType join_type, SwissTableForJoin* hash_table, + std::vector materialize, + const std::vector* cmp, OutputBatchFn output_batch_fn); + Status OnNextBatch(int64_t thread_id, const ExecBatch& keypayload_batch, + util::TempVectorStack* temp_stack, + std::vector* temp_column_arrays); + + // Must be called by a single-thread having exclusive access to the instance + // of this class. The caller is responsible for ensuring that. 
+ // + Status OnFinished(); + + private: + int num_key_columns_; + JoinType join_type_; + + SwissTableForJoin* hash_table_; + // One element per thread + // + std::vector materialize_; + const std::vector* cmp_; + OutputBatchFn output_batch_fn_; +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/swiss_join_avx2.cc b/cpp/src/arrow/compute/exec/swiss_join_avx2.cc new file mode 100644 index 0000000000000..2ddb5983448f8 --- /dev/null +++ b/cpp/src/arrow/compute/exec/swiss_join_avx2.cc @@ -0,0 +1,198 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/compute/exec/swiss_join.h" +#include "arrow/util/bit_util.h" + +namespace arrow { +namespace compute { + +#if defined(ARROW_HAVE_AVX2) + +template +int RowArrayAccessor::Visit_avx2(const KeyEncoder::KeyRowArray& rows, int column_id, + int num_rows, const uint32_t* row_ids, + PROCESS_8_VALUES_FN process_8_values_fn) { + // Number of rows processed together in a single iteration of the loop (single + // call to the provided processing lambda). + // + constexpr int unroll = 8; + + bool is_fixed_length_column = + rows.metadata().column_metadatas[column_id].is_fixed_length; + + // There are 4 cases, each requiring different steps: + // 1. Varying length column that is the first varying length column in a row + // 2. Varying length column that is not the first varying length column in a + // row + // 3. Fixed length column in a fixed length row + // 4. 
Fixed length column in a varying length row + + if (!is_fixed_length_column) { + int varbinary_column_id = VarbinaryColumnId(rows.metadata(), column_id); + const uint8_t* row_ptr_base = rows.data(2); + const uint32_t* row_offsets = rows.offsets(); + uint32_t field_offset_within_row, field_length; + + if (varbinary_column_id == 0) { + // Case 1: This is the first varbinary column + // + __m256i field_offset_within_row = _mm256_set1_epi32(rows.metadata().fixed_length); + __m256i varbinary_end_array_offset = + _mm256_set1_epi32(rows.metadata().varbinary_end_array_offset); + for (int i = 0; i < num_rows / unroll; ++i) { + __m256i row_id = + _mm256_loadu_si256(reinterpret_cast(row_ids) + i); + __m256i row_offset = _mm256_i32gather_epi32( + reinterpret_cast(row_offsets), row_id, sizeof(uint32_t)); + __m256i field_length = _mm256_sub_epi32( + _mm256_i32gather_epi32( + reinterpret_cast(row_ptr_base), + _mm256_add_epi32(row_offset, varbinary_end_array_offset), 1), + field_offset_within_row); + process_8_values_fn(i * unroll, row_ptr_base, + _mm256_add_epi32(row_offset, field_offset_within_row), + field_length); + } + } else { + // Case 2: This is second or later varbinary column + // + __m256i varbinary_end_array_offset = + _mm256_set1_epi32(rows.metadata().varbinary_end_array_offset + + sizeof(uint32_t) * (varbinary_column_id - 1)); + auto row_ptr_base_i64 = + reinterpret_cast(row_ptr_base); + for (int i = 0; i < num_rows / unroll; ++i) { + __m256i row_id = + _mm256_loadu_si256(reinterpret_cast(row_ids) + i); + __m256i row_offset = _mm256_i32gather_epi32( + reinterpret_cast(row_offsets), row_id, sizeof(uint32_t)); + __m256i end_array_offset = + _mm256_add_epi32(row_offset, varbinary_end_array_offset); + + __m256i field_offset_within_row_A = _mm256_i32gather_epi64( + row_ptr_base_i64, _mm256_castsi256_si128(end_array_offset), 1); + __m256i field_offset_within_row_B = _mm256_i32gather_epi64( + row_ptr_base_i64, _mm256_extracti128_si256(end_array_offset, 1), 1); + field_offset_within_row_A = _mm256_permutevar8x32_epi32( + field_offset_within_row_A, _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)); + field_offset_within_row_B = _mm256_permutevar8x32_epi32( + field_offset_within_row_B, _mm256_setr_epi32(1, 3, 5, 7, 0, 2, 4, 6)); + + __m256i field_offset_within_row = _mm256_blend_epi32( + field_offset_within_row_A, field_offset_within_row_B, 0xf0); + + __m256i alignment_padding = + _mm256_andnot_si256(field_offset_within_row, _mm256_set1_epi8(0xff)); + alignment_padding = _mm256_add_epi32(alignment_padding, _mm256_set1_epi32(1)); + alignment_padding = _mm256_and_si256( + alignment_padding, _mm256_set1_epi32(rows.metadata().string_alignment - 1)); + + field_offset_within_row = + _mm256_add_epi32(field_offset_within_row, alignment_padding); + + __m256i field_length = _mm256_blend_epi32(field_offset_within_row_A, + field_offset_within_row_B, 0x0f); + field_length = _mm256_permute4x64_epi64(field_length, + 0x4e); // Swapping low and high 128-bits + field_length = _mm256_sub_epi32(field_length, field_offset_within_row); + + process_8_values_fn(i * unroll, row_ptr_base, + _mm256_add_epi32(row_offset, field_offset_within_row), + field_length); + } + } + } + + if (is_fixed_length_column) { + __m256i field_offset_within_row = + _mm256_set1_epi32(rows.metadata().encoded_field_offset( + rows.metadata().pos_after_encoding(column_id))); + __m256i field_length = + _mm256_set1_epi32(rows.metadata().column_metadatas[column_id].fixed_length); + + bool is_fixed_length_row = rows.metadata().is_fixed_length; + if 
(is_fixed_length_row) { + // Case 3: This is a fixed length column in fixed length row + // + const uint8_t* row_ptr_base = rows.data(1); + for (uint32_t i = 0; i < num_rows / unroll; ++i) { + __m256i row_id = + _mm256_loadu_si256(reinterpret_cast(row_ids) + i); + __m256i row_offset = _mm256_mullo_epi32(row_id, field_length); + __m256i field_offset = _mm256_add_epi32(row_offset, field_offset_within_row); + process_8_values_fn(i * unroll, row_ptr_base, field_offset, field_length); + } + } else { + // Case 4: This is a fixed length column in varying length row + // + const uint8_t* row_ptr_base = rows.data(2); + const uint32_t* row_offsets = rows.offsets(); + for (uint32_t i = 0; i < num_rows / unroll; ++i) { + __m256i row_id = + _mm256_loadu_si256(reinterpret_cast(row_ids) + i); + __m256i row_offset = _mm256_i32gather_epi32( + reinterpret_cast(row_offsets), row_id, sizeof(uint32_t)); + __m256i field_offset = _mm256_add_epi32(row_offset, field_offset_within_row); + process_8_values_fn(i * unroll, row_ptr_base, field_offset, field_length); + } + } + } + + return num_rows - (num_rows % unroll); +} + +template +int RowArrayAccessor::VisitNulls_avx2(const KeyEncoder::KeyRowArray& rows, int column_id, + int num_rows, const uint32_t* row_ids, + PROCESS_8_VALUES_FN process_8_values_fn) { + // Number of rows processed together in a single iteration of the loop (single + // call to the provided processing lambda). + // + constexpr int unroll = 8; + + const uint8_t* null_masks = rows.null_masks(); + __m256i null_bits_per_row = + _mm256_set1_epi32(8 * rows.metadata().null_masks_bytes_per_row); + for (uint32_t i = 0; i < num_rows / unroll; ++i) { + __m256i row_id = _mm256_loadu_si256(reinterpret_cast(row_ids) + i); + __m256i bit_id = _mm256_mullo_epi32(row_id, null_bits_per_row); + bit_id = _mm256_add_epi32(bit_id, _mm256_set1_epi32(column_id)); + __m256i bytes = _mm256_i32gather_epi32(reinterpret_cast(null_masks), + _mm256_srli_epi32(bit_id, 3), 1); + __m256i bit_in_word = _mm256_sllv_epi32( + _mm256_set1_epi32(1), _mm256_and_si256(bit_id, _mm256_set1_epi32(7))); + __m256i result = + _mm256_cmpeq_epi32(_mm256_and_si256(bytes, bit_in_word), bit_in_word); + uint64_t null_bytes = static_cast( + _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(result)))); + null_bytes |= static_cast(_mm256_movemask_epi8( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(result, 1)))) + << 32; + + process_8_values_fn(i * unroll, null_bytes); + } + + return num_rows - (num_rows % unroll); +} + +#endif + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h index b0e423c858053..b1a417e1c370d 100644 --- a/cpp/src/arrow/compute/exec/util.h +++ b/cpp/src/arrow/compute/exec/util.h @@ -77,7 +77,8 @@ using int64_for_gather_t = const long long int; // NOLINT runtime-int // class MiniBatch { public: - static constexpr int kMiniBatchLength = 1024; + static constexpr int kLogMiniBatchLength = 10; + static constexpr int kMiniBatchLength = 1 << kLogMiniBatchLength; }; /// Storage used to allocate temporary vectors of a batch size. @@ -293,5 +294,51 @@ class ThreadIndexer { std::unordered_map id_to_index_; }; +// Helper class to calculate the modified number of rows to process using SIMD. +// +// Some array elements at the end will be skipped in order to avoid buffer +// overrun, when doing memory loads and stores using larger word size than a +// single array element. 
+// +class TailSkipForSIMD { + public: + static int64_t FixBitAccess(int num_bytes_accessed_together, int64_t num_rows, + int bit_offset) { + int64_t num_bytes = bit_util::BytesForBits(num_rows + bit_offset); + int64_t num_bytes_safe = + std::max(static_cast(0LL), num_bytes - num_bytes_accessed_together + 1); + int64_t num_rows_safe = + std::max(static_cast(0LL), 8 * num_bytes_safe - bit_offset); + return std::min(num_rows_safe, num_rows); + } + static int64_t FixBinaryAccess(int num_bytes_accessed_together, int64_t num_rows, + int64_t length) { + int64_t num_rows_to_skip = bit_util::CeilDiv(length, num_bytes_accessed_together); + int64_t num_rows_safe = + std::max(static_cast(0LL), num_rows - num_rows_to_skip); + return num_rows_safe; + } + static int64_t FixVarBinaryAccess(int num_bytes_accessed_together, int64_t num_rows, + const uint32_t* offsets) { + // Do not process rows that could read past the end of the buffer using N + // byte loads/stores. + // + int64_t num_rows_safe = num_rows; + while (num_rows_safe > 0 && + offsets[num_rows_safe] + num_bytes_accessed_together > offsets[num_rows]) { + --num_rows_safe; + } + return num_rows_safe; + } + static int FixSelection(int64_t num_rows_safe, int num_selected, + const uint16_t* selection) { + int num_selected_safe = num_selected; + while (num_selected_safe > 0 && selection[num_selected_safe] >= num_rows_safe) { + --num_selected_safe; + } + return num_selected_safe; + } +}; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index ab8e6cd77d14f..1db9b27731589 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -256,7 +256,7 @@ struct GrouperFastImpl : Grouper { impl->key_types_[icol] = key; } - impl->encoder_.Init(impl->col_metadata_, &impl->encode_ctx_, + impl->encoder_.Init(impl->col_metadata_, /* row_alignment = */ sizeof(uint64_t), /* string_alignment = */ sizeof(uint64_t)); RETURN_NOT_OK(impl->rows_.Init(ctx->memory_pool(), impl->encoder_.row_metadata())); @@ -264,24 +264,24 @@ struct GrouperFastImpl : Grouper { impl->rows_minibatch_.Init(ctx->memory_pool(), impl->encoder_.row_metadata())); impl->minibatch_size_ = impl->minibatch_size_min_; GrouperFastImpl* impl_ptr = impl.get(); - auto equal_func = [impl_ptr]( - int num_keys_to_compare, const uint16_t* selection_may_be_null, - const uint32_t* group_ids, uint32_t* out_num_keys_mismatch, - uint16_t* out_selection_mismatch) { - arrow::compute::KeyCompare::CompareColumnsToRows( - num_keys_to_compare, selection_may_be_null, group_ids, &impl_ptr->encode_ctx_, - out_num_keys_mismatch, out_selection_mismatch, - impl_ptr->encoder_.GetBatchColumns(), impl_ptr->rows_); - }; - auto append_func = [impl_ptr](int num_keys, const uint16_t* selection) { + impl_ptr->map_equal_impl_ = + [impl_ptr](int num_keys_to_compare, const uint16_t* selection_may_be_null, + const uint32_t* group_ids, uint32_t* out_num_keys_mismatch, + uint16_t* out_selection_mismatch, void*) { + arrow::compute::KeyCompare::CompareColumnsToRows( + num_keys_to_compare, selection_may_be_null, group_ids, + &impl_ptr->encode_ctx_, out_num_keys_mismatch, out_selection_mismatch, + impl_ptr->encoder_.GetBatchColumns(), impl_ptr->rows_, + /*are_cols_in_encoding_order=*/true); + }; + impl_ptr->map_append_impl_ = [impl_ptr](int num_keys, const uint16_t* selection, + void*) { RETURN_NOT_OK(impl_ptr->encoder_.EncodeSelected(&impl_ptr->rows_minibatch_, num_keys, selection)); 
return impl_ptr->rows_.AppendSelectionFrom(impl_ptr->rows_minibatch_, num_keys, nullptr); }; - RETURN_NOT_OK(impl->map_.init(impl->encode_ctx_.hardware_flags, ctx->memory_pool(), - impl->encode_ctx_.stack, impl->log_minibatch_max_, - equal_func, append_func)); + RETURN_NOT_OK(impl->map_.init(impl->encode_ctx_.hardware_flags, ctx->memory_pool())); impl->cols_.resize(num_columns); impl->minibatch_hashes_.resize(impl->minibatch_size_max_ + kPaddingForSIMD / sizeof(uint32_t)); @@ -382,7 +382,8 @@ struct GrouperFastImpl : Grouper { match_bitvector.mutable_data(), local_slots.mutable_data()); map_.find(batch_size_next, minibatch_hashes_.data(), match_bitvector.mutable_data(), local_slots.mutable_data(), - reinterpret_cast(group_ids->mutable_data()) + start_row); + reinterpret_cast(group_ids->mutable_data()) + start_row, + &temp_stack_, map_equal_impl_, nullptr); } auto ids = util::TempVectorHolder(&temp_stack_, batch_size_next); int num_ids; @@ -392,7 +393,8 @@ struct GrouperFastImpl : Grouper { RETURN_NOT_OK(map_.map_new_keys( num_ids, ids.mutable_data(), minibatch_hashes_.data(), - reinterpret_cast(group_ids->mutable_data()) + start_row)); + reinterpret_cast(group_ids->mutable_data()) + start_row, + &temp_stack_, map_equal_impl_, map_append_impl_, nullptr)); start_row += batch_size_next; @@ -460,7 +462,7 @@ struct GrouperFastImpl : Grouper { int64_t batch_size_next = std::min(num_groups - start_row, static_cast(minibatch_size_max_)); encoder_.DecodeFixedLengthBuffers(start_row, start_row, batch_size_next, rows_, - &cols_); + &cols_, encode_ctx_.hardware_flags, &temp_stack_); start_row += batch_size_next; } @@ -480,7 +482,8 @@ struct GrouperFastImpl : Grouper { int64_t batch_size_next = std::min(num_groups - start_row, static_cast(minibatch_size_max_)); encoder_.DecodeVaryingLengthBuffers(start_row, start_row, batch_size_next, rows_, - &cols_); + &cols_, encode_ctx_.hardware_flags, + &temp_stack_); start_row += batch_size_next; } } @@ -544,6 +547,8 @@ struct GrouperFastImpl : Grouper { arrow::compute::KeyEncoder::KeyRowArray rows_minibatch_; arrow::compute::KeyEncoder encoder_; arrow::compute::SwissTable map_; + arrow::compute::SwissTable::EqualImpl map_equal_impl_; + arrow::compute::SwissTable::AppendImpl map_append_impl_; }; /// C++ abstract base class for the HashAggregateKernel interface. 
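---

A self-contained sketch (not part of the patch itself) of the tail-skipping pattern behind the TailSkipForSIMD helper added to util.h above. SafePrefix mirrors the logic of FixVarBinaryAccess; CopyValues and kWideCopy are illustrative names, not Arrow APIs. Wide copies that may read past the end of a single value are restricted to the prefix of rows for which they cannot read past the end of the whole buffer, and the remaining rows fall back to exact-length scalar copies.

#include <cstdint>
#include <cstring>

// offsets has num_rows + 1 entries; offsets[num_rows] is the total byte length.
// Returns the largest prefix of rows for which an N-byte load starting inside
// the row cannot run past offsets[num_rows] (same loop as FixVarBinaryAccess).
int64_t SafePrefix(int num_bytes_accessed_together, int64_t num_rows,
                   const uint32_t* offsets) {
  int64_t num_rows_safe = num_rows;
  while (num_rows_safe > 0 &&
         offsets[num_rows_safe] + num_bytes_accessed_together > offsets[num_rows]) {
    --num_rows_safe;
  }
  return num_rows_safe;
}

void CopyValues(const uint8_t* values, const uint32_t* offsets, int64_t num_rows,
                uint8_t* out) {
  constexpr int kWideCopy = 32;  // stand-in for a 32-byte SIMD load/store
  int64_t num_rows_safe = SafePrefix(kWideCopy, num_rows, offsets);
  int64_t i = 0;
  for (; i < num_rows_safe; ++i) {
    // Wide path: copy each value in 32-byte chunks; the last chunk may
    // overshoot this value, but SafePrefix guarantees it stays in the buffer.
    for (uint32_t pos = offsets[i]; pos < offsets[i + 1]; pos += kWideCopy) {
      std::memcpy(out + pos, values + pos, kWideCopy);
    }
  }
  for (; i < num_rows; ++i) {
    // Scalar tail: copy exactly the value length, so no overrun is possible.
    std::memcpy(out + offsets[i], values + offsets[i], offsets[i + 1] - offsets[i]);
  }
}

---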
From e451d774bec92fecb6a224d6cd372341d35ed705 Mon Sep 17 00:00:00 2001 From: michalursa Date: Fri, 8 Apr 2022 00:18:15 -0700 Subject: [PATCH 2/2] Adding consolidated output key to SwissJoin --- cpp/src/arrow/compute/exec/hash_join.cc | 4 +- cpp/src/arrow/compute/exec/hash_join.h | 13 +- .../arrow/compute/exec/hash_join_benchmark.cc | 9 +- cpp/src/arrow/compute/exec/hash_join_node.cc | 139 +++++++++++++++++- cpp/src/arrow/compute/exec/options.h | 14 +- cpp/src/arrow/compute/exec/swiss_join.cc | 36 ++++- cpp/src/arrow/compute/exec/swiss_join.h | 4 +- 7 files changed, 198 insertions(+), 21 deletions(-) diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index b6fd0a851882e..4e7dc7f16678f 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -88,7 +88,9 @@ class HashJoinBasicImpl : public HashJoinImpl { std::vector key_cmp, Expression filter, OutputBatchCallback output_batch_callback, FinishedCallback finished_callback, - TaskScheduler::ScheduleImpl schedule_task_callback) override { + TaskScheduler::ScheduleImpl schedule_task_callback, + OutputKeyProbeCallback output_key_probe_callback, + OutputKeyBuildCallback output_key_build_callback) override { num_threads = std::max(num_threads, static_cast(1)); START_SPAN(span_, "HashJoinBasicImpl", diff --git a/cpp/src/arrow/compute/exec/hash_join.h b/cpp/src/arrow/compute/exec/hash_join.h index 9aaadfcd5e796..790e0fdd8eed6 100644 --- a/cpp/src/arrow/compute/exec/hash_join.h +++ b/cpp/src/arrow/compute/exec/hash_join.h @@ -64,7 +64,8 @@ class ARROW_EXPORT HashJoinSchema { Result BindFilter(Expression filter, const Schema& left_schema, const Schema& right_schema); std::shared_ptr MakeOutputSchema(const std::string& left_field_name_suffix, - const std::string& right_field_name_suffix); + const std::string& right_field_name_suffix, + bool append_consolidated_key = false); bool LeftPayloadIsEmpty() { return PayloadIsEmpty(0); } @@ -100,10 +101,16 @@ class ARROW_EXPORT HashJoinSchema { const std::vector& key); }; +class RowArray; + class HashJoinImpl { public: using OutputBatchCallback = std::function; using FinishedCallback = std::function; + using OutputKeyProbeCallback = + std::function; + using OutputKeyBuildCallback = + std::function; virtual ~HashJoinImpl() = default; virtual Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution, @@ -112,7 +119,9 @@ class HashJoinImpl { std::vector key_cmp, Expression filter, OutputBatchCallback output_batch_callback, FinishedCallback finished_callback, - TaskScheduler::ScheduleImpl schedule_task_callback) = 0; + TaskScheduler::ScheduleImpl schedule_task_callback, + OutputKeyProbeCallback output_key_probe_callback, + OutputKeyBuildCallback output_key_build_callback) = 0; virtual Status InputReceived(size_t thread_index, int side, ExecBatch batch) = 0; virtual Status InputFinished(size_t thread_index, int side) = 0; virtual void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) = 0; diff --git a/cpp/src/arrow/compute/exec/hash_join_benchmark.cc b/cpp/src/arrow/compute/exec/hash_join_benchmark.cc index a69de21f92fb4..cdbccc5d6295c 100644 --- a/cpp/src/arrow/compute/exec/hash_join_benchmark.cc +++ b/cpp/src/arrow/compute/exec/hash_join_benchmark.cc @@ -138,8 +138,13 @@ class JoinBenchmark { DCHECK_OK(join_->Init( ctx_.get(), settings.join_type, !is_parallel, settings.num_threads, &(schema_mgr_->proj_maps[0]), &(schema_mgr_->proj_maps[1]), {JoinKeyCmp::EQ}, - std::move(filter), [](int64_t, ExecBatch) 
{}, [](int64_t x) {}, - schedule_callback)); + std::move(filter), [](int64_t, ExecBatch) {}, [](int64_t x) {}, schedule_callback, + [](int64_t, const ExecBatch&, int, const uint16_t*) -> Status { + return Status::OK(); + }, + [](int64_t, const RowArray&, int, const uint32_t*) -> Status { + return Status::OK(); + })); } void RunJoin() { diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc index 74259ada37426..f80af1acbb457 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node.cc @@ -22,6 +22,7 @@ #include "arrow/compute/exec/hash_join_dict.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/schema_util.h" +#include "arrow/compute/exec/swiss_join.h" #include "arrow/compute/exec/util.h" #include "arrow/util/checked_cast.h" #include "arrow/util/future.h" @@ -34,6 +35,112 @@ using internal::checked_cast; namespace compute { +class ConsolidatedJoinKey { + public: + static void UpdateOutputSchema(const HashJoinProjectionMaps* proj_map_left, + const HashJoinProjectionMaps* proj_map_right, + std::vector>* fields_to_update); + void Init(bool enabled, MemoryPool* pool, size_t num_threads, + const HashJoinProjectionMaps* proj_map_left, + const HashJoinProjectionMaps* proj_map_right); + Status OutputKeyProbe(int64_t thread_index, const ExecBatch& batch_key, int num_rows, + const uint16_t* row_ids); + Status OutputKeyBuild(int64_t thread_index, const RowArray& hash_table_key, + int num_rows, const uint32_t* key_ids); + void UpdateOutputBatch(int64_t thread_index, ExecBatch* batch); + + private: + static constexpr int kLogNumRows = 15; + + bool enabled_; + MemoryPool* pool_; + size_t num_threads_; + const HashJoinProjectionMaps* proj_map_left_; + const HashJoinProjectionMaps* proj_map_right_; + struct ThreadLocalState { + std::vector values_; + }; + std::vector thread_local_states_; +}; + +void ConsolidatedJoinKey::UpdateOutputSchema( + const HashJoinProjectionMaps* proj_map_left, + const HashJoinProjectionMaps* proj_map_right, + std::vector>* fields_to_update) { + size_t num_fields_before = fields_to_update->size(); + int num_key_cols = proj_map_left->num_cols(HashJoinProjection::KEY); + fields_to_update->resize(num_fields_before + num_key_cols); + for (int icol = 0; icol < num_key_cols; ++icol) { + (*fields_to_update)[num_fields_before + icol] = std::make_shared( + proj_map_left->field_name(HashJoinProjection::KEY, icol) + "_" + + proj_map_right->field_name(HashJoinProjection::KEY, icol), + proj_map_left->data_type(HashJoinProjection::KEY, icol), true /*nullable*/); + } +} + +void ConsolidatedJoinKey::Init(bool enabled, MemoryPool* pool, size_t num_threads, + const HashJoinProjectionMaps* proj_map_left, + const HashJoinProjectionMaps* proj_map_right) { + enabled_ = enabled; + pool_ = pool; + num_threads_ = num_threads; + proj_map_left_ = proj_map_left; + proj_map_right_ = proj_map_right; + thread_local_states_.resize(num_threads_); + int num_key_cols = proj_map_left->num_cols(HashJoinProjection::KEY); + for (size_t i = 0; i < num_threads; ++i) { + thread_local_states_[i].values_.resize(num_key_cols); + for (int icol = 0; icol < num_key_cols; ++icol) { + thread_local_states_[i].values_[icol].Init( + proj_map_left->data_type(HashJoinProjection::KEY, icol), pool, kLogNumRows); + } + } +} + +Status ConsolidatedJoinKey::OutputKeyProbe(int64_t thread_index, + const ExecBatch& batch_key, int num_rows, + const uint16_t* row_ids) { + if (!enabled_) { + return Status::OK(); + } + int 
num_key_cols = proj_map_left_->num_cols(HashJoinProjection::KEY); + for (int icol = 0; icol < num_key_cols; ++icol) { + RETURN_NOT_OK(ExecBatchBuilder::AppendSelected( + batch_key[icol].array(), thread_local_states_[thread_index].values_[icol], + num_rows, row_ids, pool_)); + } + return Status::OK(); +} + +Status ConsolidatedJoinKey::OutputKeyBuild(int64_t thread_index, + const RowArray& hash_table_key, int num_rows, + const uint32_t* key_ids) { + if (!enabled_) { + return Status::OK(); + } + int num_key_cols = proj_map_right_->num_cols(HashJoinProjection::KEY); + for (int icol = 0; icol < num_key_cols; ++icol) { + RETURN_NOT_OK( + hash_table_key.DecodeSelected(&thread_local_states_[thread_index].values_[icol], + icol, num_rows, key_ids, pool_)); + } + return Status::OK(); +} + +void ConsolidatedJoinKey::UpdateOutputBatch(int64_t thread_index, ExecBatch* batch) { + if (!enabled_) { + return; + } + int num_key_cols = proj_map_left_->num_cols(HashJoinProjection::KEY); + size_t num_cols_before = batch->values.size(); + batch->values.resize(num_cols_before + num_key_cols); + for (int icol = 0; icol < num_key_cols; ++icol) { + batch->values[num_cols_before + icol] = + thread_local_states_[thread_index].values_[icol].array_data(); + thread_local_states_[thread_index].values_[icol].Clear(true); + } +} + // Check if a type is supported in a join (as either a key or non-key column) bool HashJoinSchema::IsTypeSupported(const DataType& type) { const Type::type id = type.id(); @@ -275,8 +382,8 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc } std::shared_ptr HashJoinSchema::MakeOutputSchema( - const std::string& left_field_name_suffix, - const std::string& right_field_name_suffix) { + const std::string& left_field_name_suffix, const std::string& right_field_name_suffix, + bool append_consolidated_key) { std::vector> fields; int left_size = proj_maps[0].num_cols(HashJoinProjection::OUTPUT); int right_size = proj_maps[1].num_cols(HashJoinProjection::OUTPUT); @@ -330,6 +437,11 @@ std::shared_ptr HashJoinSchema::MakeOutputSchema( std::make_shared(input_field_name, input_data_type, true /*nullable*/); } } + + if (append_consolidated_key) { + ConsolidatedJoinKey::UpdateOutputSchema(&proj_maps[0], &proj_maps[1], &fields); + } + return std::make_shared(std::move(fields)); } @@ -496,6 +608,7 @@ class HashJoinNode : public ExecNode { schema_mgr_(std::move(schema_mgr)), impl_(std::move(impl)) { complete_.store(false); + append_consolidated_key_ = join_options.append_consolidated_key; } static Result Make(ExecPlan* plan, std::vector inputs, @@ -531,7 +644,8 @@ class HashJoinNode : public ExecNode { // Generate output schema std::shared_ptr output_schema = schema_mgr->MakeOutputSchema( - join_options.output_suffix_for_left, join_options.output_suffix_for_right); + join_options.output_suffix_for_left, join_options.output_suffix_for_right, + join_options.append_consolidated_key); // Create hash join implementation object // SwissJoin does not support: @@ -628,15 +742,30 @@ class HashJoinNode : public ExecNode { bool use_sync_execution = !(plan_->exec_context()->executor()); size_t num_threads = use_sync_execution ? 
1 : thread_indexer_.Capacity(); + consolidated_join_key_.Init( + append_consolidated_key_, plan_->exec_context()->memory_pool(), num_threads, + &(schema_mgr_->proj_maps[0]), &(schema_mgr_->proj_maps[1])); + RETURN_NOT_OK(impl_->Init( plan_->exec_context(), join_type_, use_sync_execution, num_threads, &(schema_mgr_->proj_maps[0]), &(schema_mgr_->proj_maps[1]), key_cmp_, filter_, - [this](int64_t /*ignored*/, ExecBatch batch) { + [this](int64_t thread_index, ExecBatch batch) { + consolidated_join_key_.UpdateOutputBatch(thread_index, &batch); this->OutputBatchCallback(batch); }, [this](int64_t total_num_batches) { this->FinishedCallback(total_num_batches); }, [this](std::function func) -> Status { return this->ScheduleTaskCallback(std::move(func)); + }, + [this](int64_t thread_index, const ExecBatch& batch_key, int num_rows, + const uint16_t* row_ids) -> Status { + return consolidated_join_key_.OutputKeyProbe(thread_index, batch_key, num_rows, + row_ids); + }, + [this](int64_t thread_index, const RowArray& hash_table_key, int num_rows, + const uint32_t* key_ids) -> Status { + return consolidated_join_key_.OutputKeyBuild(thread_index, hash_table_key, + num_rows, key_ids); })); return Status::OK(); } @@ -704,6 +833,8 @@ class HashJoinNode : public ExecNode { ThreadIndexer thread_indexer_; std::unique_ptr schema_mgr_; std::unique_ptr impl_; + bool append_consolidated_key_; + ConsolidatedJoinKey consolidated_join_key_; }; namespace internal { diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index d5780753254b7..a485e409a2b80 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -205,11 +205,13 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { JoinType in_join_type, std::vector in_left_keys, std::vector in_right_keys, Expression filter = literal(true), std::string output_suffix_for_left = default_output_suffix_for_left, - std::string output_suffix_for_right = default_output_suffix_for_right) + std::string output_suffix_for_right = default_output_suffix_for_right, + bool append_consolidated_key = false) : join_type(in_join_type), left_keys(std::move(in_left_keys)), right_keys(std::move(in_right_keys)), output_all(true), + append_consolidated_key(append_consolidated_key), output_suffix_for_left(std::move(output_suffix_for_left)), output_suffix_for_right(std::move(output_suffix_for_right)), filter(std::move(filter)) { @@ -223,13 +225,15 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { std::vector right_keys, std::vector left_output, std::vector right_output, Expression filter = literal(true), std::string output_suffix_for_left = default_output_suffix_for_left, - std::string output_suffix_for_right = default_output_suffix_for_right) + std::string output_suffix_for_right = default_output_suffix_for_right, + bool append_consolidated_key = false) : join_type(join_type), left_keys(std::move(left_keys)), right_keys(std::move(right_keys)), output_all(false), left_output(std::move(left_output)), right_output(std::move(right_output)), + append_consolidated_key(append_consolidated_key), output_suffix_for_left(std::move(output_suffix_for_left)), output_suffix_for_right(std::move(output_suffix_for_right)), filter(std::move(filter)) { @@ -244,13 +248,15 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { std::vector right_output, std::vector key_cmp, Expression filter = literal(true), std::string output_suffix_for_left = default_output_suffix_for_left, - std::string 
output_suffix_for_right = default_output_suffix_for_right) + std::string output_suffix_for_right = default_output_suffix_for_right, + bool append_consolidated_key = false) : join_type(join_type), left_keys(std::move(left_keys)), right_keys(std::move(right_keys)), output_all(false), left_output(std::move(left_output)), right_output(std::move(right_output)), + append_consolidated_key(append_consolidated_key), key_cmp(std::move(key_cmp)), output_suffix_for_left(std::move(output_suffix_for_left)), output_suffix_for_right(std::move(output_suffix_for_right)), @@ -269,6 +275,8 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { std::vector left_output; // output fields passed from right input std::vector right_output; + // output consolidated key + bool append_consolidated_key; // key comparison function (determines whether a null key is equal another null // key or not) std::vector key_cmp; diff --git a/cpp/src/arrow/compute/exec/swiss_join.cc b/cpp/src/arrow/compute/exec/swiss_join.cc index 62beba35db14d..5f3cf19ecb060 100644 --- a/cpp/src/arrow/compute/exec/swiss_join.cc +++ b/cpp/src/arrow/compute/exec/swiss_join.cc @@ -2434,7 +2434,8 @@ void JoinProbeProcessor::Init(int num_key_columns, JoinType join_type, Status JoinProbeProcessor::OnNextBatch( int64_t thread_id, const ExecBatch& keypayload_batch, util::TempVectorStack* temp_stack, - std::vector* temp_column_arrays) { + std::vector* temp_column_arrays, + HashJoinImpl::OutputKeyProbeCallback& output_key_probe_callback) { const SwissTable* swiss_table = hash_table_->keys()->swiss_table(); int64_t hardware_flags = swiss_table->hardware_flags(); int minibatch_size = swiss_table->minibatch_size(); @@ -2510,6 +2511,9 @@ Status JoinProbeProcessor::OnNextBatch( static_cast(minibatch_start); } + RETURN_NOT_OK( + output_key_probe_callback(thread_id, keypayload_batch, num_passing_ids, + materialize_batch_ids_buf.mutable_data())); RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly( keypayload_batch, num_passing_ids, materialize_batch_ids_buf.mutable_data(), [&](ExecBatch batch) { output_batch_fn_(thread_id, std::move(batch)); })); @@ -2548,6 +2552,8 @@ Status JoinProbeProcessor::OnNextBatch( // Call materialize for resulting id tuples pointing to matching pairs // of rows. 
// + RETURN_NOT_OK(output_key_probe_callback(thread_id, keypayload_batch, + num_matches_next, materialize_batch_ids)); RETURN_NOT_OK(materialize_[thread_id]->Append( keypayload_batch, num_matches_next, materialize_batch_ids, materialize_key_ids, materialize_payload_ids, @@ -2573,6 +2579,9 @@ Status JoinProbeProcessor::OnNextBatch( static_cast(minibatch_start); } + RETURN_NOT_OK( + output_key_probe_callback(thread_id, keypayload_batch, num_passing_ids, + materialize_batch_ids_buf.mutable_data())); RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly( keypayload_batch, num_passing_ids, materialize_batch_ids_buf.mutable_data(), [&](ExecBatch batch) { output_batch_fn_(thread_id, std::move(batch)); })); @@ -2689,7 +2698,9 @@ class SwissJoin : public HashJoinImpl { std::vector key_cmp, Expression filter, OutputBatchCallback output_batch_callback, FinishedCallback finished_callback, - TaskScheduler::ScheduleImpl schedule_task_callback) override { + TaskScheduler::ScheduleImpl schedule_task_callback, + OutputKeyProbeCallback output_key_probe_callback, + OutputKeyBuildCallback output_key_build_callback) override { num_threads_ = static_cast(std::max(num_threads, static_cast(1))); START_SPAN(span_, "HashJoinBasicImpl", @@ -2742,6 +2753,9 @@ class SwissJoin : public HashJoinImpl { RETURN_NOT_OK(InitScheduler(use_sync_execution, num_threads, schedule_task_callback)); + output_key_probe_callback_ = output_key_probe_callback; + output_key_build_callback_ = output_key_build_callback; + return Status::OK(); } @@ -2809,7 +2823,7 @@ class SwissJoin : public HashJoinImpl { } return CancelIfNotOK(probe_processor_.OnNextBatch( thread_index, keypayload_batch, &local_states_[thread_index].temp_stack, - &local_states_[thread_index].temp_column_arrays)); + &local_states_[thread_index].temp_column_arrays, output_key_probe_callback_)); } Status InputFinished(size_t thread_id, int side) override { @@ -2994,10 +3008,10 @@ class SwissJoin : public HashJoinImpl { ExecBatchQueue& batches = batch_queue_[0]; ExecBatch* input_batch = batches.shared_batch(batch_id); - RETURN_NOT_OK(CancelIfNotOK( - probe_processor_.OnNextBatch(static_cast(thread_id), *input_batch, - &local_states_[thread_id].temp_stack, - &local_states_[thread_id].temp_column_arrays))); + RETURN_NOT_OK(CancelIfNotOK(probe_processor_.OnNextBatch( + static_cast(thread_id), *input_batch, + &local_states_[thread_id].temp_stack, + &local_states_[thread_id].temp_column_arrays, output_key_probe_callback_))); // Release input batch // input_batch->values.clear(); @@ -3108,7 +3122,11 @@ class SwissJoin : public HashJoinImpl { // Materialize (and output whenever buffers get full) hash table // values according to the generated list of ids. 
// - Status status = local_states_[thread_id].materialize.AppendBuildOnly( + Status status; + status = output_key_build_callback_(thread_id, *hash_table_.keys()->keys(), + num_output_rows, key_ids_buf.mutable_data()); + RETURN_NOT_OK(CancelIfNotOK(status)); + status = local_states_[thread_id].materialize.AppendBuildOnly( num_output_rows, key_ids_buf.mutable_data(), payload_ids_buf.mutable_data(), [&](ExecBatch batch) { output_batch_callback_(static_cast(thread_id), std::move(batch)); @@ -3233,6 +3251,8 @@ class SwissJoin : public HashJoinImpl { // Callbacks OutputBatchCallback output_batch_callback_; FinishedCallback finished_callback_; + OutputKeyProbeCallback output_key_probe_callback_; + OutputKeyBuildCallback output_key_build_callback_; struct ThreadLocalState { JoinResultMaterialize materialize; diff --git a/cpp/src/arrow/compute/exec/swiss_join.h b/cpp/src/arrow/compute/exec/swiss_join.h index d419cc0c62e0d..7982e48717d54 100644 --- a/cpp/src/arrow/compute/exec/swiss_join.h +++ b/cpp/src/arrow/compute/exec/swiss_join.h @@ -18,6 +18,7 @@ #pragma once #include +#include "arrow/compute/exec/hash_join.h" #include "arrow/compute/exec/key_encode.h" #include "arrow/compute/exec/key_map.h" #include "arrow/compute/exec/options.h" @@ -851,7 +852,8 @@ class JoinProbeProcessor { const std::vector* cmp, OutputBatchFn output_batch_fn); Status OnNextBatch(int64_t thread_id, const ExecBatch& keypayload_batch, util::TempVectorStack* temp_stack, - std::vector* temp_column_arrays); + std::vector* temp_column_arrays, + HashJoinImpl::OutputKeyProbeCallback& output_key_probe_callback); // Must be called by a single-thread having exclusive access to the instance // of this class. The caller is responsible for ensuring that.
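---

A hedged usage sketch of the new append_consolidated_key flag added to HashJoinNodeOptions in this patch. The key names "l_key"/"r_key" and the suffixes are made up for illustration; the constructor signature follows the first overload shown in options.h above.

#include "arrow/compute/exec/options.h"

// Builds hash join options that request consolidated key output. With
// append_consolidated_key = true, the node's output schema gains, after the
// regular output columns, one trailing column per key pair named
// "<left key name>_<right key name>" (here "l_key_r_key"), as constructed by
// ConsolidatedJoinKey::UpdateOutputSchema above.
arrow::compute::HashJoinNodeOptions MakeFullOuterJoinOptions() {
  using arrow::compute::HashJoinNodeOptions;
  using arrow::compute::JoinType;
  return HashJoinNodeOptions(
      JoinType::FULL_OUTER,
      /*in_left_keys=*/{arrow::FieldRef("l_key")},
      /*in_right_keys=*/{arrow::FieldRef("r_key")},
      /*filter=*/arrow::compute::literal(true),
      /*output_suffix_for_left=*/"_l",
      /*output_suffix_for_right=*/"_r",
      /*append_consolidated_key=*/true);
}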