Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Flow.take #1538

Merged
merged 3 commits into from
Sep 17, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions benchmarks/src/jmh/kotlin/benchmarks/flow/TakeBenchmark.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package benchmarks.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
import org.openjdk.jmh.annotations.*
import java.util.concurrent.*
import java.util.concurrent.CancellationException
import kotlin.coroutines.*
import kotlin.coroutines.intrinsics.*
import benchmarks.flow.scrabble.flow as unsafeFlow

@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
@Fork(value = 1)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@State(Scope.Benchmark)
open class TakeBenchmark {
@Param("1", "10", "100", "1000")
private var size: Int = 0

private suspend inline fun Flow<Long>.consume() =
filter { it % 2L != 0L }
.map { it * it }.count()

@Benchmark
fun baseline() = runBlocking<Int> {
(0L until size).asFlow().consume()
}

@Benchmark
fun originalTake() = runBlocking<Int> {
(0L..Long.MAX_VALUE).asFlow().originalTake(size).consume()
}

@Benchmark
fun fastPathTake() = runBlocking<Int> {
(0L..Long.MAX_VALUE).asFlow().fastPathTake(size).consume()
}

@Benchmark
fun mergedStateMachine() = runBlocking<Int> {
(0L..Long.MAX_VALUE).asFlow().mergedStateMachineTake(size).consume()
}


internal class StacklessCancellationException() : CancellationException() {
override fun fillInStackTrace(): Throwable = this
}

public fun <T> Flow<T>.originalTake(count: Int): Flow<T> {
return unsafeFlow {
var consumed = 0
try {
collect { value ->
emit(value)
if (++consumed == count) {
throw StacklessCancellationException()
}
}
} catch (e: StacklessCancellationException) {
// Nothing, bail out
}
}
}


public fun <T> Flow<T>.fastPathTake(count: Int): Flow<T> {

suspend fun FlowCollector<T>.emitAbort(value: T) {
emit(value)
throw StacklessCancellationException()
}

return unsafeFlow {
var consumed = 0
try {
collect { value ->
if (++consumed < count) {
return@collect emit(value)
} else {
return@collect emitAbort(value)
}
}
} catch (e: StacklessCancellationException) {
// Nothing, bail out
}
}
}


public fun <T> Flow<T>.mergedStateMachineTake(count: Int): Flow<T> {
return unsafeFlow() {
try {
val takeCollector = FlowTakeCollector(count, this)
collect(takeCollector)
} catch (e: StacklessCancellationException) {
// Nothing, bail out
}
}
}


private class FlowTakeCollector<T>(
private val count: Int,
downstream: FlowCollector<T>
) : FlowCollector<T>, Continuation<Unit> {
private var consumed = 0
// Workaround for KT-30991
private val emitFun = run {
val suspendFun: suspend (T) -> Unit = { downstream.emit(it) }
suspendFun as Function2<T, Continuation<Unit>, Any?>
}

private var caller: Continuation<Unit>? = null // lateinit

override val context: CoroutineContext
get() = caller?.context ?: EmptyCoroutineContext

override fun resumeWith(result: Result<Unit>) {
val completion = caller!!
if (++consumed == count) completion.resumeWith(Result.failure(StacklessCancellationException()))
else completion.resumeWith(Result.success(Unit))
}

override suspend fun emit(value: T) = suspendCoroutineUninterceptedOrReturn<Unit> sc@{
// Invoke it in non-suspending way
caller = it
val result = emitFun.invoke(value, this)
if (result !== COROUTINE_SUSPENDED) {
if (++consumed == count) throw StacklessCancellationException()
else return@sc Unit
}
COROUTINE_SUSPENDED
}
}
}
13 changes: 10 additions & 3 deletions kotlinx-coroutines-core/common/src/flow/operators/Limit.kt
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,20 @@ public fun <T> Flow<T>.dropWhile(predicate: suspend (T) -> Boolean): Flow<T> = f
@ExperimentalCoroutinesApi
public fun <T> Flow<T>.take(count: Int): Flow<T> {
require(count > 0) { "Requested element count $count should be positive" }

suspend fun <T> FlowCollector<T>.emitAbort(value: T) {
emit(value)
throw AbortFlowException()
}

return flow {
var consumed = 0
try {
collect { value ->
emit(value)
if (++consumed == count) {
throw AbortFlowException()
if (++consumed < count) {
return@collect emit(value)
} else {
return@collect emitAbort(value)
}
}
} catch (e: AbortFlowException) {
Expand Down
21 changes: 21 additions & 0 deletions kotlinx-coroutines-core/common/test/flow/operators/TakeTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,27 @@ class TakeTest : TestBase() {
assertEquals(2, flow.drop(1).take(1).single())
}

@Test
fun testIllegalArgument() {
assertFailsWith<IllegalArgumentException> { flowOf(1).take(0) }
assertFailsWith<IllegalArgumentException> { flowOf(1).take(-1) }
}

@Test
fun testTakeSuspending() = runTest {
val flow = flow {
emit(1)
yield()
emit(2)
yield()
}

assertEquals(3, flow.take(2).sum())
assertEquals(3, flow.take(Int.MAX_VALUE).sum())
assertEquals(1, flow.take(1).single())
assertEquals(2, flow.drop(1).take(1).single())
}

@Test
fun testEmptyFlow() = runTest {
val sum = emptyFlow<Int>().take(10).sum()
Expand Down