diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/data/RowData.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/data/RowData.java index 687b804a732bb..65fa790921f08 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/data/RowData.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/data/RowData.java @@ -286,9 +286,6 @@ static FieldGetter createFieldGetter(LogicalType fieldType, int fieldPos) { default: throw new IllegalArgumentException(); } - if (!fieldType.isNullable()) { - return fieldGetter; - } return row -> { if (row.isNullAt(fieldPos)) { return null; diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/data/RowDataTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/data/RowDataTest.java index f568e431a1527..eaaaae8368db4 100644 --- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/data/RowDataTest.java +++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/data/RowDataTest.java @@ -30,8 +30,21 @@ import org.apache.flink.table.runtime.typeutils.MapDataSerializer; import org.apache.flink.table.runtime.typeutils.RawValueDataSerializer; import org.apache.flink.table.runtime.typeutils.RowDataSerializer; +import org.apache.flink.table.types.logical.ArrayType; +import org.apache.flink.table.types.logical.BigIntType; +import org.apache.flink.table.types.logical.BinaryType; +import org.apache.flink.table.types.logical.BooleanType; +import org.apache.flink.table.types.logical.CharType; +import org.apache.flink.table.types.logical.DecimalType; +import org.apache.flink.table.types.logical.DoubleType; +import org.apache.flink.table.types.logical.FloatType; import org.apache.flink.table.types.logical.IntType; +import org.apache.flink.table.types.logical.MapType; import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.SmallIntType; +import org.apache.flink.table.types.logical.TimestampType; +import org.apache.flink.table.types.logical.TinyIntType; +import org.apache.flink.table.types.logical.VarCharType; import org.apache.flink.types.RowKind; import org.junit.jupiter.api.BeforeEach; @@ -209,6 +222,121 @@ void testJoinedRow() { testGetters(new JoinedRowData(row1, row2)); } + @Test + void testFieldGetters() { + RowData row = getBinaryRow(); + + assertThat(RowData.createFieldGetter(new BooleanType(), 0).getFieldOrNull(row)) + .isEqualTo(true); + assertThat(RowData.createFieldGetter(new SmallIntType(), 1).getFieldOrNull(row)) + .isEqualTo((short) 1); + assertThat(RowData.createFieldGetter(new TinyIntType(), 2).getFieldOrNull(row)) + .isEqualTo((byte) 2); + assertThat(RowData.createFieldGetter(new IntType(), 3).getFieldOrNull(row)).isEqualTo(3); + assertThat(RowData.createFieldGetter(new BigIntType(), 4).getFieldOrNull(row)) + .isEqualTo(4L); + assertThat(RowData.createFieldGetter(new FloatType(), 5).getFieldOrNull(row)).isEqualTo(5f); + assertThat(RowData.createFieldGetter(new DoubleType(), 6).getFieldOrNull(row)) + .isEqualTo(6d); + assertThat(RowData.createFieldGetter(new CharType(1), 8).getFieldOrNull(row)) + .isEqualTo(str); + assertThat(RowData.createFieldGetter(new VarCharType(4), 8).getFieldOrNull(row)) + .isEqualTo(str); + assertThat(RowData.createFieldGetter(new DecimalType(5, 0), 10).getFieldOrNull(row)) + .isEqualTo(decimal1); + assertThat(RowData.createFieldGetter(new DecimalType(20, 0), 11).getFieldOrNull(row)) + .isEqualTo(decimal2); + assertThat(RowData.createFieldGetter(new ArrayType(new IntType()), 12).getFieldOrNull(row)) + .isEqualTo(array); + assertThat( + RowData.createFieldGetter(new MapType(new IntType(), new IntType()), 13) + .getFieldOrNull(row)) + .isEqualTo(map); + assertThat( + RowData.createFieldGetter(RowType.of(new IntType(), new IntType()), 14) + .getFieldOrNull(row)) + .isEqualTo(underRow); + assertThat(RowData.createFieldGetter(new BinaryType(3), 15).getFieldOrNull(row)) + .isEqualTo(bytes); + assertThat(RowData.createFieldGetter(new TimestampType(3), 16).getFieldOrNull(row)) + .isEqualTo(timestamp1); + assertThat(RowData.createFieldGetter(new TimestampType(9), 17).getFieldOrNull(row)) + .isEqualTo(timestamp2); + } + + @Test + void testFieldGettersWithNullableTypes() { + testFieldGettersWithNull(true); + } + + @Test + void testFieldGettersWithNonNullableTypes() { + testFieldGettersWithNull(false); + } + + private void testFieldGettersWithNull(boolean nullable) { + RowData row = getNullBinaryRow(); + assertThat(RowData.createFieldGetter(new BooleanType(nullable), 0).getFieldOrNull(row)) + .isNull(); + assertThat(RowData.createFieldGetter(new SmallIntType(nullable), 1).getFieldOrNull(row)) + .isNull(); + assertThat(RowData.createFieldGetter(new TinyIntType(nullable), 2).getFieldOrNull(row)) + .isNull(); + assertThat(RowData.createFieldGetter(new IntType(nullable), 3).getFieldOrNull(row)) + .isNull(); + assertThat(RowData.createFieldGetter(new BigIntType(nullable), 4).getFieldOrNull(row)) + .isNull(); + assertThat(RowData.createFieldGetter(new FloatType(nullable), 5).getFieldOrNull(row)) + .isNull(); + assertThat(RowData.createFieldGetter(new DoubleType(nullable), 6).getFieldOrNull(row)) + .isNull(); + assertThat(RowData.createFieldGetter(new CharType(nullable, 1), 8).getFieldOrNull(row)) + .isNull(); + assertThat(RowData.createFieldGetter(new VarCharType(nullable, 4), 8).getFieldOrNull(row)) + .isNull(); + assertThat( + RowData.createFieldGetter(new DecimalType(nullable, 5, 0), 10) + .getFieldOrNull(row)) + .isNull(); + assertThat( + RowData.createFieldGetter(new DecimalType(nullable, 20, 0), 11) + .getFieldOrNull(row)) + .isNull(); + assertThat( + RowData.createFieldGetter( + new ArrayType(nullable, new IntType(nullable)), 12) + .getFieldOrNull(row)) + .isNull(); + assertThat( + RowData.createFieldGetter( + new MapType( + nullable, + new IntType(nullable), + new IntType(nullable)), + 13) + .getFieldOrNull(row)) + .isNull(); + assertThat( + RowData.createFieldGetter( + RowType.of( + nullable, + new IntType(nullable), + new IntType(nullable)), + 14) + .getFieldOrNull(row)) + .isNull(); + assertThat(RowData.createFieldGetter(new BinaryType(nullable, 3), 15).getFieldOrNull(row)) + .isNull(); + assertThat( + RowData.createFieldGetter(new TimestampType(nullable, 3), 16) + .getFieldOrNull(row)) + .isNull(); + assertThat( + RowData.createFieldGetter(new TimestampType(nullable, 9), 17) + .getFieldOrNull(row)) + .isNull(); + } + private void testGetters(RowData row) { assertThat(row.getArity()).isEqualTo(18); @@ -284,4 +412,13 @@ private void testSetters(RowData row) { setter.setNullAt(0); assertThat(row.isNullAt(0)).isTrue(); } + + private static BinaryRowData getNullBinaryRow() { + BinaryRowData row = new BinaryRowData(18); + BinaryRowWriter binaryRowWriter = new BinaryRowWriter(row); + for (int i = 0; i < row.getArity(); i++) { + binaryRowWriter.setNullAt(i); + } + return row; + } }