diff --git a/benchmarks/src/jmh/kotlin/benchmarks/flow/TakeBenchmark.kt b/benchmarks/src/jmh/kotlin/benchmarks/flow/TakeBenchmark.kt new file mode 100644 index 0000000000..84afca2439 --- /dev/null +++ b/benchmarks/src/jmh/kotlin/benchmarks/flow/TakeBenchmark.kt @@ -0,0 +1,138 @@ +/* + * Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package benchmarks.flow + +import kotlinx.coroutines.* +import kotlinx.coroutines.flow.* +import org.openjdk.jmh.annotations.* +import java.util.concurrent.* +import java.util.concurrent.CancellationException +import kotlin.coroutines.* +import kotlin.coroutines.intrinsics.* +import benchmarks.flow.scrabble.flow as unsafeFlow + +@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS) +@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS) +@Fork(value = 1) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +@State(Scope.Benchmark) +open class TakeBenchmark { + @Param("1", "10", "100", "1000") + private var size: Int = 0 + + private suspend inline fun Flow.consume() = + filter { it % 2L != 0L } + .map { it * it }.count() + + @Benchmark + fun baseline() = runBlocking { + (0L until size).asFlow().consume() + } + + @Benchmark + fun originalTake() = runBlocking { + (0L..Long.MAX_VALUE).asFlow().originalTake(size).consume() + } + + @Benchmark + fun fastPathTake() = runBlocking { + (0L..Long.MAX_VALUE).asFlow().fastPathTake(size).consume() + } + + @Benchmark + fun mergedStateMachine() = runBlocking { + (0L..Long.MAX_VALUE).asFlow().mergedStateMachineTake(size).consume() + } + + internal class StacklessCancellationException() : CancellationException() { + override fun fillInStackTrace(): Throwable = this + } + + public fun Flow.originalTake(count: Int): Flow { + return unsafeFlow { + var consumed = 0 + try { + collect { value -> + emit(value) + if (++consumed == count) { + throw StacklessCancellationException() + } + } + } catch (e: StacklessCancellationException) { + // Nothing, bail out + } + } + } + + private suspend fun FlowCollector.emitAbort(value: T) { + emit(value) + throw StacklessCancellationException() + } + + public fun Flow.fastPathTake(count: Int): Flow { + return unsafeFlow { + var consumed = 0 + try { + collect { value -> + if (++consumed < count) { + return@collect emit(value) + } else { + return@collect emitAbort(value) + } + } + } catch (e: StacklessCancellationException) { + // Nothing, bail out + } + } + } + + + public fun Flow.mergedStateMachineTake(count: Int): Flow { + return unsafeFlow() { + try { + val takeCollector = FlowTakeCollector(count, this) + collect(takeCollector) + } catch (e: StacklessCancellationException) { + // Nothing, bail out + } + } + } + + + private class FlowTakeCollector( + private val count: Int, + downstream: FlowCollector + ) : FlowCollector, Continuation { + private var consumed = 0 + // Workaround for KT-30991 + private val emitFun = run { + val suspendFun: suspend (T) -> Unit = { downstream.emit(it) } + suspendFun as Function2, Any?> + } + + private var caller: Continuation? = null // lateinit + + override val context: CoroutineContext + get() = caller?.context ?: EmptyCoroutineContext + + override fun resumeWith(result: Result) { + val completion = caller!! + if (++consumed == count) completion.resumeWith(Result.failure(StacklessCancellationException())) + else completion.resumeWith(Result.success(Unit)) + } + + override suspend fun emit(value: T) = suspendCoroutineUninterceptedOrReturn sc@{ + // Invoke it in non-suspending way + caller = it + val result = emitFun.invoke(value, this) + if (result !== COROUTINE_SUSPENDED) { + if (++consumed == count) throw StacklessCancellationException() + else return@sc Unit + } + COROUTINE_SUSPENDED + } + } +} diff --git a/kotlinx-coroutines-core/common/src/flow/operators/Limit.kt b/kotlinx-coroutines-core/common/src/flow/operators/Limit.kt index 7f638f9814..1343dad868 100644 --- a/kotlinx-coroutines-core/common/src/flow/operators/Limit.kt +++ b/kotlinx-coroutines-core/common/src/flow/operators/Limit.kt @@ -55,9 +55,10 @@ public fun Flow.take(count: Int): Flow { var consumed = 0 try { collect { value -> - emit(value) - if (++consumed == count) { - throw AbortFlowException() + if (++consumed < count) { + return@collect emit(value) + } else { + return@collect emitAbort(value) } } } catch (e: AbortFlowException) { @@ -66,6 +67,11 @@ public fun Flow.take(count: Int): Flow { } } +private suspend fun FlowCollector.emitAbort(value: T) { + emit(value) + throw AbortFlowException() +} + /** * Returns a flow that contains first elements satisfying the given [predicate]. */ diff --git a/kotlinx-coroutines-core/common/test/flow/operators/TakeTest.kt b/kotlinx-coroutines-core/common/test/flow/operators/TakeTest.kt index 711034969f..8ea137df08 100644 --- a/kotlinx-coroutines-core/common/test/flow/operators/TakeTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/operators/TakeTest.kt @@ -21,6 +21,27 @@ class TakeTest : TestBase() { assertEquals(2, flow.drop(1).take(1).single()) } + @Test + fun testIllegalArgument() { + assertFailsWith { flowOf(1).take(0) } + assertFailsWith { flowOf(1).take(-1) } + } + + @Test + fun testTakeSuspending() = runTest { + val flow = flow { + emit(1) + yield() + emit(2) + yield() + } + + assertEquals(3, flow.take(2).sum()) + assertEquals(3, flow.take(Int.MAX_VALUE).sum()) + assertEquals(1, flow.take(1).single()) + assertEquals(2, flow.drop(1).take(1).single()) + } + @Test fun testEmptyFlow() = runTest { val sum = emptyFlow().take(10).sum()