Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Sum_IF() aggregate function #24174

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions presto-docs/src/main/sphinx/functions/aggregate.rst
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,17 @@ General Aggregate Functions

Returns the sum of all input values.

.. function:: sum_if(x, y) -> [same as y]

Returns the sum of all ``y`` values for rows where ``x`` is ``TRUE``.
This function is equivalent to ``sum(CASE WHEN x THEN y END)``.

.. function:: sum_if(x, y, d) -> [same as y]

Returns the sum of all ``y`` values for rows where ``x`` is ``TRUE``, using ``d``
instead of null when ``x`` is ``FALSE``. This function is equivalent to
``sum(CASE WHEN x THEN y ELSE d END)``.

Bitwise Aggregate Functions
---------------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,15 @@
import com.facebook.presto.operator.aggregation.DoubleHistogramAggregation;
import com.facebook.presto.operator.aggregation.DoubleRegressionAggregation;
import com.facebook.presto.operator.aggregation.DoubleSumAggregation;
import com.facebook.presto.operator.aggregation.DoubleSumIfAggregation;
import com.facebook.presto.operator.aggregation.EntropyAggregation;
import com.facebook.presto.operator.aggregation.GeometricMeanAggregations;
import com.facebook.presto.operator.aggregation.IntervalDayToSecondAverageAggregation;
import com.facebook.presto.operator.aggregation.IntervalDayToSecondSumAggregation;
import com.facebook.presto.operator.aggregation.IntervalYearToMonthAverageAggregation;
import com.facebook.presto.operator.aggregation.IntervalYearToMonthSumAggregation;
import com.facebook.presto.operator.aggregation.LongSumAggregation;
import com.facebook.presto.operator.aggregation.LongSumIfAggregation;
import com.facebook.presto.operator.aggregation.MaxDataSizeForStats;
import com.facebook.presto.operator.aggregation.MergeHyperLogLogAggregation;
import com.facebook.presto.operator.aggregation.MergeQuantileDigestFunction;
Expand Down Expand Up @@ -715,8 +717,10 @@ private List<? extends SqlFunction> getBuiltInFunctions(FunctionsConfig function
.aggregates(BooleanAndAggregation.class)
.aggregates(BooleanOrAggregation.class)
.aggregates(DoubleSumAggregation.class)
.aggregates(DoubleSumIfAggregation.class)
.aggregates(RealSumAggregation.class)
.aggregates(LongSumAggregation.class)
.aggregates(LongSumIfAggregation.class)
.aggregates(IntervalDayToSecondSumAggregation.class)
.aggregates(IntervalYearToMonthSumAggregation.class)
.aggregates(AverageAggregations.class)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.operator.aggregation.state.DoubleState;
import com.facebook.presto.spi.function.AggregationFunction;
import com.facebook.presto.spi.function.AggregationState;
import com.facebook.presto.spi.function.CombineFunction;
import com.facebook.presto.spi.function.InputFunction;
import com.facebook.presto.spi.function.OutputFunction;
import com.facebook.presto.spi.function.SqlType;

import static com.facebook.presto.common.type.DoubleType.DOUBLE;

@AggregationFunction(value = "sum_if", isCalledOnNullInput = true)
public final class DoubleSumIfAggregation
{
private DoubleSumIfAggregation() {}

@InputFunction
public static void input(@AggregationState DoubleState state, @SqlType(StandardTypes.BOOLEAN) boolean value,
@SqlType(StandardTypes.DOUBLE) double sum)
{
if (value) {
state.setDouble(state.getDouble() + sum);
}
}

@InputFunction
public static void input(@AggregationState DoubleState state, @SqlType(StandardTypes.BOOLEAN) boolean value,
@SqlType(StandardTypes.DOUBLE) double sum, @SqlType(StandardTypes.DOUBLE) double defaultValue)
{
if (value) {
state.setDouble(state.getDouble() + sum);
}
else {
state.setDouble(state.getDouble() + defaultValue);
}
}

@CombineFunction
public static void combine(@AggregationState DoubleState state, @AggregationState DoubleState otherState)
{
state.setDouble(state.getDouble() + otherState.getDouble());
}

@OutputFunction(StandardTypes.DOUBLE)
public static void output(@AggregationState DoubleState state, BlockBuilder out)
{
DOUBLE.writeDouble(out, state.getDouble());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.operator.aggregation.state.LongState;
import com.facebook.presto.spi.function.AggregationFunction;
import com.facebook.presto.spi.function.AggregationState;
import com.facebook.presto.spi.function.CombineFunction;
import com.facebook.presto.spi.function.InputFunction;
import com.facebook.presto.spi.function.OutputFunction;
import com.facebook.presto.spi.function.SqlType;

import static com.facebook.presto.common.type.BigintType.BIGINT;

@AggregationFunction(value = "sum_if", isCalledOnNullInput = true)
public final class LongSumIfAggregation
{
private LongSumIfAggregation() {}

@InputFunction
public static void input(@AggregationState LongState state, @SqlType(StandardTypes.BOOLEAN) boolean value,
@SqlType(StandardTypes.BIGINT) long sum)
{
if (value) {
state.setLong(state.getLong() + sum);
}
}

@InputFunction
public static void input(@AggregationState LongState state, @SqlType(StandardTypes.BOOLEAN) boolean value,
@SqlType(StandardTypes.BIGINT) long sum, @SqlType(StandardTypes.BIGINT) long defaultValue)
{
if (value) {
state.setLong(state.getLong() + sum);
}
else {
state.setLong(state.getLong() + defaultValue);
}
}

@CombineFunction
public static void combine(@AggregationState LongState state, @AggregationState LongState otherState)
{
state.setLong(state.getLong() + otherState.getLong());
}

@OutputFunction(StandardTypes.BIGINT)
public static void output(@AggregationState LongState state, BlockBuilder out)
{
BIGINT.writeLong(out, state.getLong());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.StandardTypes;
import com.google.common.collect.ImmutableList;

import java.util.List;

import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;

public class TestDoubleSumIfAggregation
extends AbstractTestAggregationFunction
{
@Override
public Block[] getSequenceBlocks(int start, int length)
{
BlockBuilder conditions = BOOLEAN.createBlockBuilder(null, length);
BlockBuilder values = DOUBLE.createBlockBuilder(null, length);
for (int i = start; i < start + length; i++) {
BOOLEAN.writeBoolean(conditions, i % 2 == 0);
DOUBLE.writeDouble(values, i);
}
return new Block[] {conditions.build(), values.build()};
}

@Override
public Number getExpectedValue(int start, int length)
{
double sum = 0.0;
for (int i = start; i < start + length; i++) {
if (i % 2 == 0) {
sum += i;
}
}
return sum;
}

@Override
protected String getFunctionName()
{
return "sum_if";
}

@Override
protected List<String> getFunctionParameterTypes()
{
return ImmutableList.of(StandardTypes.BOOLEAN, StandardTypes.DOUBLE);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.StandardTypes;
import com.google.common.collect.ImmutableList;

import java.util.List;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;

public class TestLongSumIfAggregation
extends AbstractTestAggregationFunction
{
@Override
public Block[] getSequenceBlocks(int start, int length)
{
BlockBuilder conditions = BOOLEAN.createBlockBuilder(null, length);
BlockBuilder values = BIGINT.createBlockBuilder(null, length);
for (int i = start; i < start + length; i++) {
BOOLEAN.writeBoolean(conditions, i % 2 == 0);
BIGINT.writeLong(values, i);
}
return new Block[] {conditions.build(), values.build()};
}

@Override
public Number getExpectedValue(int start, int length)
{
long sum = 0L;
for (int i = start; i < start + length; i++) {
if (i % 2 == 0) {
sum += i;
}
}
return sum;
}

@Override
protected String getFunctionName()
{
return "sum_if";
}

@Override
protected List<String> getFunctionParameterTypes()
{
return ImmutableList.of(StandardTypes.BOOLEAN, StandardTypes.BIGINT);
}
}
Loading