Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding consolidated output key to SwissJoin #7

Draft
wants to merge 2 commits into
base: ARROW-14182-hash-join2
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ if(ARROW_COMPUTE)
compute/exec/project_node.cc
compute/exec/sink_node.cc
compute/exec/source_node.cc
compute/exec/swiss_join.cc
compute/exec/task_util.cc
compute/exec/tpch_node.cc
compute/exec/union_node.cc
Expand Down Expand Up @@ -450,6 +451,7 @@ if(ARROW_COMPUTE)
append_avx2_src(compute/exec/key_encode_avx2.cc)
append_avx2_src(compute/exec/key_hash_avx2.cc)
append_avx2_src(compute/exec/key_map_avx2.cc)
append_avx2_src(compute/exec/swiss_join_avx2.cc)
append_avx2_src(compute/exec/util_avx2.cc)

list(APPEND ARROW_TESTING_SRCS compute/exec/test_util.cc)
Expand Down
106 changes: 50 additions & 56 deletions cpp/src/arrow/compute/exec/hash_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,14 @@ class HashJoinBasicImpl : public HashJoinImpl {
}

Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution,
size_t num_threads, HashJoinSchema* schema_mgr,
size_t num_threads, const HashJoinProjectionMaps* proj_map_left,
const HashJoinProjectionMaps* proj_map_right,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
OutputBatchCallback output_batch_callback,
FinishedCallback finished_callback,
TaskScheduler::ScheduleImpl schedule_task_callback) override {
TaskScheduler::ScheduleImpl schedule_task_callback,
OutputKeyProbeCallback output_key_probe_callback,
OutputKeyBuildCallback output_key_build_callback) override {
num_threads = std::max(num_threads, static_cast<size_t>(1));

START_SPAN(span_, "HashJoinBasicImpl",
Expand All @@ -98,7 +101,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
ctx_ = ctx;
join_type_ = join_type;
num_threads_ = num_threads;
schema_mgr_ = schema_mgr;
schema_[0] = proj_map_left;
schema_[1] = proj_map_right;
key_cmp_ = std::move(key_cmp);
filter_ = std::move(filter);
output_batch_callback_ = std::move(output_batch_callback);
Expand Down Expand Up @@ -139,12 +143,11 @@ class HashJoinBasicImpl : public HashJoinImpl {
private:
void InitEncoder(int side, HashJoinProjection projection_handle, RowEncoder* encoder) {
std::vector<ValueDescr> data_types;
int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle);
int num_cols = schema_[side]->num_cols(projection_handle);
data_types.resize(num_cols);
for (int icol = 0; icol < num_cols; ++icol) {
data_types[icol] =
ValueDescr(schema_mgr_->proj_maps[side].data_type(projection_handle, icol),
ValueDescr::ARRAY);
data_types[icol] = ValueDescr(schema_[side]->data_type(projection_handle, icol),
ValueDescr::ARRAY);
}
encoder->Init(data_types, ctx_);
encoder->Clear();
Expand All @@ -155,8 +158,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
ThreadLocalState& local_state = local_states_[thread_index];
if (!local_state.is_initialized) {
InitEncoder(0, HashJoinProjection::KEY, &local_state.exec_batch_keys);
bool has_payload =
(schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
bool has_payload = (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_payload) {
InitEncoder(0, HashJoinProjection::PAYLOAD, &local_state.exec_batch_payloads);
}
Expand All @@ -168,11 +170,10 @@ class HashJoinBasicImpl : public HashJoinImpl {
Status EncodeBatch(int side, HashJoinProjection projection_handle, RowEncoder* encoder,
const ExecBatch& batch, ExecBatch* opt_projected_batch = nullptr) {
ExecBatch projected({}, batch.length);
int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle);
int num_cols = schema_[side]->num_cols(projection_handle);
projected.values.resize(num_cols);

auto to_input =
schema_mgr_->proj_maps[side].map(projection_handle, HashJoinProjection::INPUT);
auto to_input = schema_[side]->map(projection_handle, HashJoinProjection::INPUT);
for (int icol = 0; icol < num_cols; ++icol) {
projected.values[icol] = batch.values[to_input.get(icol)];
}
Expand Down Expand Up @@ -235,16 +236,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
ExecBatch* opt_left_payload, ExecBatch* opt_right_key,
ExecBatch* opt_right_payload) {
ExecBatch result({}, batch_size_next);
int num_out_cols_left =
schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT);
int num_out_cols_right =
schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT);
int num_out_cols_left = schema_[0]->num_cols(HashJoinProjection::OUTPUT);
int num_out_cols_right = schema_[1]->num_cols(HashJoinProjection::OUTPUT);

result.values.resize(num_out_cols_left + num_out_cols_right);
auto from_key = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::KEY);
auto from_payload = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::PAYLOAD);
auto from_key = schema_[0]->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY);
auto from_payload =
schema_[0]->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD);
for (int icol = 0; icol < num_out_cols_left; ++icol) {
bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField());
bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField());
Expand All @@ -262,10 +260,9 @@ class HashJoinBasicImpl : public HashJoinImpl {
? opt_left_key->values[from_key.get(icol)]
: opt_left_payload->values[from_payload.get(icol)];
}
from_key = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT,
HashJoinProjection::KEY);
from_payload = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT,
HashJoinProjection::PAYLOAD);
from_key = schema_[1]->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY);
from_payload =
schema_[1]->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD);
for (int icol = 0; icol < num_out_cols_right; ++icol) {
bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField());
bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField());
Expand All @@ -284,7 +281,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
: opt_right_payload->values[from_payload.get(icol)];
}

output_batch_callback_(std::move(result));
output_batch_callback_(0, std::move(result));

// Update the counter of produced batches
//
Expand All @@ -310,13 +307,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
hash_table_keys_.Decode(match_right.size(), match_right.data()));

ExecBatch left_payload;
if (!schema_mgr_->LeftPayloadIsEmpty()) {
if (!schema_[0]->is_empty(HashJoinProjection::PAYLOAD)) {
ARROW_ASSIGN_OR_RAISE(left_payload, local_state.exec_batch_payloads.Decode(
match_left.size(), match_left.data()));
}

ExecBatch right_payload;
if (!schema_mgr_->RightPayloadIsEmpty()) {
if (!schema_[1]->is_empty(HashJoinProjection::PAYLOAD)) {
ARROW_ASSIGN_OR_RAISE(right_payload, hash_table_payloads_.Decode(
match_right.size(), match_right.data()));
}
Expand All @@ -336,14 +333,14 @@ class HashJoinBasicImpl : public HashJoinImpl {
}
};

SchemaProjectionMap left_to_key = schema_mgr_->proj_maps[0].map(
HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap left_to_pay = schema_mgr_->proj_maps[0].map(
HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
SchemaProjectionMap right_to_key = schema_mgr_->proj_maps[1].map(
HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap right_to_pay = schema_mgr_->proj_maps[1].map(
HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
SchemaProjectionMap left_to_key =
schema_[0]->map(HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap left_to_pay =
schema_[0]->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
SchemaProjectionMap right_to_key =
schema_[1]->map(HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap right_to_pay =
schema_[1]->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);

AppendFields(left_to_key, left_to_pay, left_key, left_payload);
AppendFields(right_to_key, right_to_pay, right_key, right_payload);
Expand Down Expand Up @@ -419,15 +416,14 @@ class HashJoinBasicImpl : public HashJoinImpl {

bool has_left =
(join_type_ != JoinType::RIGHT_SEMI && join_type_ != JoinType::RIGHT_ANTI &&
schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT) > 0);
schema_[0]->num_cols(HashJoinProjection::OUTPUT) > 0);
bool has_right =
(join_type_ != JoinType::LEFT_SEMI && join_type_ != JoinType::LEFT_ANTI &&
schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT) > 0);
schema_[1]->num_cols(HashJoinProjection::OUTPUT) > 0);
bool has_left_payload =
has_left && (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
has_left && (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0);
bool has_right_payload =
has_right &&
(schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0);
has_right && (schema_[1]->num_cols(HashJoinProjection::PAYLOAD) > 0);

ThreadLocalState& local_state = local_states_[thread_index];
InitLocalStateIfNeeded(thread_index);
Expand All @@ -450,7 +446,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
ARROW_ASSIGN_OR_RAISE(right_key,
hash_table_keys_.Decode(batch_size_next, opt_right_ids));
// Post process build side keys that use dictionary
RETURN_NOT_OK(dict_build_.PostDecode(schema_mgr_->proj_maps[1], &right_key, ctx_));
RETURN_NOT_OK(dict_build_.PostDecode(*schema_[1], &right_key, ctx_));
}
if (has_right_payload) {
ARROW_ASSIGN_OR_RAISE(right_payload,
Expand Down Expand Up @@ -550,8 +546,7 @@ class HashJoinBasicImpl : public HashJoinImpl {

RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::KEY, &local_state.exec_batch_keys,
batch, &batch_key_for_lookups));
bool has_left_payload =
(schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
bool has_left_payload = (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_left_payload) {
local_state.exec_batch_payloads.Clear();
RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::PAYLOAD,
Expand All @@ -563,13 +558,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
local_state.match_left.clear();
local_state.match_right.clear();

bool use_key_batch_for_dicts = dict_probe_.BatchRemapNeeded(
thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], ctx_);
bool use_key_batch_for_dicts =
dict_probe_.BatchRemapNeeded(thread_index, *schema_[0], *schema_[1], ctx_);
RowEncoder* row_encoder_for_lookups = &local_state.exec_batch_keys;
if (use_key_batch_for_dicts) {
RETURN_NOT_OK(dict_probe_.EncodeBatch(
thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], dict_build_,
batch, &row_encoder_for_lookups, &batch_key_for_lookups, ctx_));
RETURN_NOT_OK(dict_probe_.EncodeBatch(thread_index, *schema_[0], *schema_[1],
dict_build_, batch, &row_encoder_for_lookups,
&batch_key_for_lookups, ctx_));
}

// Collect information about all nulls in key columns.
Expand Down Expand Up @@ -609,9 +604,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
if (batches.empty()) {
hash_table_empty_ = true;
} else {
dict_build_.InitEncoder(schema_mgr_->proj_maps[1], &hash_table_keys_, ctx_);
bool has_payload =
(schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0);
dict_build_.InitEncoder(*schema_[1], &hash_table_keys_, ctx_);
bool has_payload = (schema_[1]->num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_payload) {
InitEncoder(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_);
}
Expand All @@ -626,11 +620,11 @@ class HashJoinBasicImpl : public HashJoinImpl {
} else if (hash_table_empty_) {
hash_table_empty_ = false;

RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], &batch, ctx_));
RETURN_NOT_OK(dict_build_.Init(*schema_[1], &batch, ctx_));
}
int32_t num_rows_before = hash_table_keys_.num_rows();
RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, schema_mgr_->proj_maps[1],
batch, &hash_table_keys_, ctx_));
RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, *schema_[1], batch,
&hash_table_keys_, ctx_));
if (has_payload) {
RETURN_NOT_OK(
EncodeBatch(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_, batch));
Expand All @@ -643,7 +637,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
}

if (hash_table_empty_) {
RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], nullptr, ctx_));
RETURN_NOT_OK(dict_build_.Init(*schema_[1], nullptr, ctx_));
}

return Status::OK();
Expand Down Expand Up @@ -869,7 +863,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
ExecContext* ctx_;
JoinType join_type_;
size_t num_threads_;
HashJoinSchema* schema_mgr_;
const HashJoinProjectionMaps* schema_[2];
std::vector<JoinKeyCmp> key_cmp_;
Expression filter_;
std::unique_ptr<TaskScheduler> scheduler_;
Expand Down
23 changes: 19 additions & 4 deletions cpp/src/arrow/compute/exec/hash_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,15 @@ class ARROW_EXPORT HashJoinSchema {
const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

bool HasDictionaries() const;

bool HasLargeBinary() const;

Result<Expression> BindFilter(Expression filter, const Schema& left_schema,
const Schema& right_schema);
std::shared_ptr<Schema> MakeOutputSchema(const std::string& left_field_name_suffix,
const std::string& right_field_name_suffix);
const std::string& right_field_name_suffix,
bool append_consolidated_key = false);

bool LeftPayloadIsEmpty() { return PayloadIsEmpty(0); }

Expand Down Expand Up @@ -96,23 +101,33 @@ class ARROW_EXPORT HashJoinSchema {
const std::vector<FieldRef>& key);
};

class RowArray;

class HashJoinImpl {
public:
using OutputBatchCallback = std::function<void(ExecBatch)>;
using OutputBatchCallback = std::function<void(int64_t, ExecBatch)>;
using FinishedCallback = std::function<void(int64_t)>;
using OutputKeyProbeCallback =
std::function<Status(int64_t, const ExecBatch&, int, const uint16_t*)>;
using OutputKeyBuildCallback =
std::function<Status(int64_t, const RowArray&, int, const uint32_t*)>;

virtual ~HashJoinImpl() = default;
virtual Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution,
size_t num_threads, HashJoinSchema* schema_mgr,
size_t num_threads, const HashJoinProjectionMaps* proj_map_left,
const HashJoinProjectionMaps* proj_map_right,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
OutputBatchCallback output_batch_callback,
FinishedCallback finished_callback,
TaskScheduler::ScheduleImpl schedule_task_callback) = 0;
TaskScheduler::ScheduleImpl schedule_task_callback,
OutputKeyProbeCallback output_key_probe_callback,
OutputKeyBuildCallback output_key_build_callback) = 0;
virtual Status InputReceived(size_t thread_index, int side, ExecBatch batch) = 0;
virtual Status InputFinished(size_t thread_index, int side) = 0;
virtual void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) = 0;

static Result<std::unique_ptr<HashJoinImpl>> MakeBasic();
static Result<std::unique_ptr<HashJoinImpl>> MakeSwiss();

protected:
util::tracing::Span span_;
Expand Down
14 changes: 11 additions & 3 deletions cpp/src/arrow/compute/exec/hash_join_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class JoinBenchmark {
build_metadata["null_probability"] = std::to_string(settings.null_percentage);
build_metadata["min"] = std::to_string(min_build_value);
build_metadata["max"] = std::to_string(max_build_value);
build_metadata["min_length"] = "2";
build_metadata["max_length"] = "20";

std::unordered_map<std::string, std::string> probe_metadata;
probe_metadata["null_probability"] = std::to_string(settings.null_percentage);
Expand Down Expand Up @@ -124,7 +126,7 @@ class JoinBenchmark {
DCHECK_OK(schema_mgr_->Init(settings.join_type, *l_batches_.schema, left_keys,
*r_batches_.schema, right_keys, filter, "l_", "r_"));

join_ = *HashJoinImpl::MakeBasic();
join_ = *HashJoinImpl::MakeSwiss();

omp_set_num_threads(settings.num_threads);
auto schedule_callback = [](std::function<Status(size_t)> func) -> Status {
Expand All @@ -135,8 +137,14 @@ class JoinBenchmark {

DCHECK_OK(join_->Init(
ctx_.get(), settings.join_type, !is_parallel, settings.num_threads,
schema_mgr_.get(), {JoinKeyCmp::EQ}, std::move(filter), [](ExecBatch) {},
[](int64_t x) {}, schedule_callback));
&(schema_mgr_->proj_maps[0]), &(schema_mgr_->proj_maps[1]), {JoinKeyCmp::EQ},
std::move(filter), [](int64_t, ExecBatch) {}, [](int64_t x) {}, schedule_callback,
[](int64_t, const ExecBatch&, int, const uint16_t*) -> Status {
return Status::OK();
},
[](int64_t, const RowArray&, int, const uint32_t*) -> Status {
return Status::OK();
}));
}

void RunJoin() {
Expand Down
Loading