[go: nahoru, domu]

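Close all per-generation channels on Pager invalidation

When PageFetcher starts a new generation, it now calls close() on the
previous Pager. close() cancels the Job that the Pager's pageEventFlow is
built around (via the new cancelableChannelFlow helper), which closes that
flow's channel so collectors of a stale generation's PagingData complete
rather than staying suspended indefinitely.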

Fixes: 147676632
Test: ./gradlew paging:paging-common:test
Change-Id: I8857fce3a8a42181a80d4a58b00b3e4f4b341e01
diff --git a/paging/common/api/3.0.0-alpha01.txt b/paging/common/api/3.0.0-alpha01.txt
index 60851f0..0f0964f 100644
--- a/paging/common/api/3.0.0-alpha01.txt
+++ b/paging/common/api/3.0.0-alpha01.txt
@@ -5,6 +5,9 @@
     method public static <T> kotlinx.coroutines.flow.Flow<androidx.paging.PagingData<T>> cachedIn(kotlinx.coroutines.flow.Flow<androidx.paging.PagingData<T>>, kotlinx.coroutines.CoroutineScope scope);
   }
 
+  public final class CancelableChannelFlowKt {
+  }
+
   public abstract class DataSource<Key, Value> {
     method @AnyThread public void addInvalidatedCallback(androidx.paging.DataSource.InvalidatedCallback onInvalidatedCallback);
     method @AnyThread public final void addInvalidatedCallback(kotlin.jvm.functions.Function0<kotlin.Unit> onInvalidatedCallback);
diff --git a/paging/common/api/current.txt b/paging/common/api/current.txt
index 60851f0..0f0964f 100644
--- a/paging/common/api/current.txt
+++ b/paging/common/api/current.txt
@@ -5,6 +5,9 @@
     method public static <T> kotlinx.coroutines.flow.Flow<androidx.paging.PagingData<T>> cachedIn(kotlinx.coroutines.flow.Flow<androidx.paging.PagingData<T>>, kotlinx.coroutines.CoroutineScope scope);
   }
 
+  public final class CancelableChannelFlowKt {
+  }
+
   public abstract class DataSource<Key, Value> {
     method @AnyThread public void addInvalidatedCallback(androidx.paging.DataSource.InvalidatedCallback onInvalidatedCallback);
     method @AnyThread public final void addInvalidatedCallback(kotlin.jvm.functions.Function0<kotlin.Unit> onInvalidatedCallback);
diff --git a/paging/common/api/public_plus_experimental_3.0.0-alpha01.txt b/paging/common/api/public_plus_experimental_3.0.0-alpha01.txt
index 60851f0..0f0964f 100644
--- a/paging/common/api/public_plus_experimental_3.0.0-alpha01.txt
+++ b/paging/common/api/public_plus_experimental_3.0.0-alpha01.txt
@@ -5,6 +5,9 @@
     method public static <T> kotlinx.coroutines.flow.Flow<androidx.paging.PagingData<T>> cachedIn(kotlinx.coroutines.flow.Flow<androidx.paging.PagingData<T>>, kotlinx.coroutines.CoroutineScope scope);
   }
 
+  public final class CancelableChannelFlowKt {
+  }
+
   public abstract class DataSource<Key, Value> {
     method @AnyThread public void addInvalidatedCallback(androidx.paging.DataSource.InvalidatedCallback onInvalidatedCallback);
     method @AnyThread public final void addInvalidatedCallback(kotlin.jvm.functions.Function0<kotlin.Unit> onInvalidatedCallback);
diff --git a/paging/common/api/public_plus_experimental_current.txt b/paging/common/api/public_plus_experimental_current.txt
index 60851f0..0f0964f 100644
--- a/paging/common/api/public_plus_experimental_current.txt
+++ b/paging/common/api/public_plus_experimental_current.txt
@@ -5,6 +5,9 @@
     method public static <T> kotlinx.coroutines.flow.Flow<androidx.paging.PagingData<T>> cachedIn(kotlinx.coroutines.flow.Flow<androidx.paging.PagingData<T>>, kotlinx.coroutines.CoroutineScope scope);
   }
 
+  public final class CancelableChannelFlowKt {
+  }
+
   public abstract class DataSource<Key, Value> {
     method @AnyThread public void addInvalidatedCallback(androidx.paging.DataSource.InvalidatedCallback onInvalidatedCallback);
     method @AnyThread public final void addInvalidatedCallback(kotlin.jvm.functions.Function0<kotlin.Unit> onInvalidatedCallback);
diff --git a/paging/common/api/restricted_3.0.0-alpha01.txt b/paging/common/api/restricted_3.0.0-alpha01.txt
index a435181..ff59d10 100644
--- a/paging/common/api/restricted_3.0.0-alpha01.txt
+++ b/paging/common/api/restricted_3.0.0-alpha01.txt
@@ -5,6 +5,9 @@
     method public static <T> kotlinx.coroutines.flow.Flow<androidx.paging.PagingData<T>> cachedIn(kotlinx.coroutines.flow.Flow<androidx.paging.PagingData<T>>, kotlinx.coroutines.CoroutineScope scope);
   }
 
+  public final class CancelableChannelFlowKt {
+  }
+
   public abstract class DataSource<Key, Value> {
     method @AnyThread public void addInvalidatedCallback(androidx.paging.DataSource.InvalidatedCallback onInvalidatedCallback);
     method @AnyThread public final void addInvalidatedCallback(kotlin.jvm.functions.Function0<kotlin.Unit> onInvalidatedCallback);
diff --git a/paging/common/api/restricted_current.txt b/paging/common/api/restricted_current.txt
index a435181..ff59d10 100644
--- a/paging/common/api/restricted_current.txt
+++ b/paging/common/api/restricted_current.txt
@@ -5,6 +5,9 @@
     method public static <T> kotlinx.coroutines.flow.Flow<androidx.paging.PagingData<T>> cachedIn(kotlinx.coroutines.flow.Flow<androidx.paging.PagingData<T>>, kotlinx.coroutines.CoroutineScope scope);
   }
 
+  public final class CancelableChannelFlowKt {
+  }
+
   public abstract class DataSource<Key, Value> {
     method @AnyThread public void addInvalidatedCallback(androidx.paging.DataSource.InvalidatedCallback onInvalidatedCallback);
     method @AnyThread public final void addInvalidatedCallback(kotlin.jvm.functions.Function0<kotlin.Unit> onInvalidatedCallback);
diff --git a/paging/common/src/main/kotlin/androidx/paging/CancelableChannelFlow.kt b/paging/common/src/main/kotlin/androidx/paging/CancelableChannelFlow.kt
new file mode 100644
index 0000000..641f0ef
--- /dev/null
+++ b/paging/common/src/main/kotlin/androidx/paging/CancelableChannelFlow.kt
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.paging
+
+import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.Job
+import kotlinx.coroutines.channels.ProducerScope
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.channelFlow
+import kotlin.experimental.ExperimentalTypeInference
+
+/**
+ * Builds a [channelFlow] that runs [block] as its producer and whose channel is closed as soon
+ * as [controller] completes, whether normally, exceptionally, or by cancellation.
+ *
+ * The channel is closed without a cause, so a collector still receives elements sent before the
+ * close and then completes normally; sends attempted from [block] after the close fail (e.g.
+ * with `ClosedSendChannelException`). If [controller] is already complete when collection
+ * starts, the channel is closed before [block] runs.
+ *
+ * NOTE(review): the handle returned by [Job.invokeOnCompletion] is never disposed, so the
+ * handler stays registered on [controller] until it completes, even if collection ends first.
+ *
+ * @param controller Job whose completion closes this flow's channel.
+ * @param block Producer logic, run once per collection with a [ProducerScope] receiver.
+ */
+@UseExperimental(ExperimentalTypeInference::class, ExperimentalCoroutinesApi::class)
+internal fun <T> cancelableChannelFlow(
+    controller: Job,
+    @BuilderInference block: suspend ProducerScope<T>.() -> Unit
+): Flow<T> {
+    return channelFlow {
+        // Invoked immediately if controller has already completed, closing the channel up front.
+        controller.invokeOnCompletion {
+            close()
+        }
+        this.block()
+    }
+}
diff --git a/paging/common/src/main/kotlin/androidx/paging/PageFetcher.kt b/paging/common/src/main/kotlin/androidx/paging/PageFetcher.kt
index 2da1c0e..aa45f2d 100644
--- a/paging/common/src/main/kotlin/androidx/paging/PageFetcher.kt
+++ b/paging/common/src/main/kotlin/androidx/paging/PageFetcher.kt
@@ -56,6 +56,7 @@
             pagingSource.registerInvalidatedCallback(::refresh)
             previousGeneration?.pagingSource?.unregisterInvalidatedCallback(::refresh)
             previousGeneration?.pagingSource?.invalidate() // Note: Invalidate is idempotent.
+            previousGeneration?.close()
 
             Pager(initialKey, pagingSource, config)
         }
diff --git a/paging/common/src/main/kotlin/androidx/paging/Pager.kt b/paging/common/src/main/kotlin/androidx/paging/Pager.kt
index 173aa6c..88a5f2a 100644
--- a/paging/common/src/main/kotlin/androidx/paging/Pager.kt
+++ b/paging/common/src/main/kotlin/androidx/paging/Pager.kt
@@ -29,13 +29,13 @@
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.ExperimentalCoroutinesApi
 import kotlinx.coroutines.FlowPreview
+import kotlinx.coroutines.Job
 import kotlinx.coroutines.channels.BroadcastChannel
 import kotlinx.coroutines.channels.Channel
 import kotlinx.coroutines.channels.Channel.Factory.BUFFERED
 import kotlinx.coroutines.channels.Channel.Factory.CONFLATED
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.asFlow
-import kotlinx.coroutines.flow.channelFlow
 import kotlinx.coroutines.flow.collect
 import kotlinx.coroutines.flow.conflate
 import kotlinx.coroutines.flow.consumeAsFlow
@@ -70,7 +70,8 @@
     private val stateLock = Mutex()
     private val state = PagerState<Key, Value>(config.pageSize, config.maxSize)
 
-    val pageEventFlow: Flow<PageEvent<Value>> = channelFlow {
+    private val pageEventChannelFlowJob = Job()
+    val pageEventFlow: Flow<PageEvent<Value>> = cancelableChannelFlow(pageEventChannelFlowJob) {
         check(pageEventChCollected.compareAndSet(false, true)) {
             "cannot collect twice from pager"
         }
@@ -86,7 +87,8 @@
             retryChannel.consumeAsFlow()
                 .collect {
                     // Handle refresh failure. Re-attempt doInitialLoad if the last attempt failed,
-                    val refreshFailure = stateLock.withLock { state.failedHintsByLoadType[REFRESH] }
+                    val refreshFailure =
+                        stateLock.withLock { state.failedHintsByLoadType[REFRESH] }
                     refreshFailure?.let {
                         stateLock.withLock { state.failedHintsByLoadType.remove(REFRESH) }
                         doInitialLoad(state)
@@ -125,6 +127,10 @@
         retryChannel.offer(Unit)
     }
 
+    /** Cancels the Job controlling [pageEventFlow], which closes that flow's channel. */
+    fun close(): Unit =
+        pageEventChannelFlowJob.cancel()
+
     suspend fun refreshKeyInfo(): RefreshInfo<Key, Value>? {
         return lastHint?.let { hint ->
             stateLock.withLock {
diff --git a/paging/common/src/main/kotlin/androidx/paging/PagerState.kt b/paging/common/src/main/kotlin/androidx/paging/PagerState.kt
index 71d5b80..fe34d0a 100644
--- a/paging/common/src/main/kotlin/androidx/paging/PagerState.kt
+++ b/paging/common/src/main/kotlin/androidx/paging/PagerState.kt
@@ -78,7 +78,7 @@
      * Note: This method should be called after state updated by [insert]
      *
      * TODO: Move this into Pager, which owns pageEventCh, since this logic is sensitive to its
-     * implementation.
+     *  implementation.
      */
     internal fun Page<Key, Value>.toPageEvent(
         loadType: LoadType,
diff --git a/paging/common/src/test/kotlin/androidx/paging/PageFetcherTest.kt b/paging/common/src/test/kotlin/androidx/paging/PageFetcherTest.kt
index 67841b0..9982989 100644
--- a/paging/common/src/test/kotlin/androidx/paging/PageFetcherTest.kt
+++ b/paging/common/src/test/kotlin/androidx/paging/PageFetcherTest.kt
@@ -22,7 +22,11 @@
 import kotlinx.coroutines.FlowPreview
 import kotlinx.coroutines.InternalCoroutinesApi
 import kotlinx.coroutines.Job
+import kotlinx.coroutines.channels.ClosedSendChannelException
+import kotlinx.coroutines.flow.collect
 import kotlinx.coroutines.flow.collectIndexed
+import kotlinx.coroutines.flow.onCompletion
+import kotlinx.coroutines.flow.onStart
 import kotlinx.coroutines.flow.toList
 import kotlinx.coroutines.launch
 import kotlinx.coroutines.test.TestCoroutineScope
@@ -31,6 +35,7 @@
 import org.junit.runner.RunWith
 import org.junit.runners.JUnit4
 import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
 import kotlin.test.assertNotEquals
 import kotlin.test.assertTrue
 
@@ -80,7 +85,7 @@
     }
 
     @Test
-    fun refreshFromPagingSource() = testScope.runBlockingTest {
+    fun refresh_fromPagingSource() = testScope.runBlockingTest {
         var pagingSource: PagingSource<Int, Int>? = null
         val pagingSourceFactory = { TestPagingSource().also { pagingSource = it } }
         val pageFetcher = PageFetcher(pagingSourceFactory, 50, config)
@@ -103,7 +108,7 @@
     }
 
     @Test
-    fun refreshCallsInvalidate() = testScope.runBlockingTest {
+    fun refresh_callsInvalidate() = testScope.runBlockingTest {
         var pagingSource: PagingSource<Int, Int>? = null
         val pagingSourceFactory = { TestPagingSource().also { pagingSource = it } }
         val pageFetcher = PageFetcher(pagingSourceFactory, 50, config)
@@ -127,6 +132,73 @@
     }
 
     @Test
+    fun refresh_closesCollection() = testScope.runBlockingTest {
+        val pageFetcher = PageFetcher(pagingSourceFactory, 50, config)
+        var generations = 0
+        // Set once any generation's PageEvent collection completes.
+        var sawCompletion = false
+
+        val collectJob = launch {
+            pageFetcher.flow.collect { pagingData ->
+                generations += 1
+                // Return immediately to avoid blocking cancellation. This is analogous to
+                // logic which would process a single PageEvent and doesn't suspend
+                // indefinitely, which is what we expect to happen.
+                pagingData.flow.onCompletion { sawCompletion = true }.collect { }
+            }
+        }
+
+        // Let generation 0 start collecting.
+        advanceUntilIdle()
+
+        // Refresh should close generation 0's PageEvent flow and emit generation 1.
+        pageFetcher.refresh()
+        advanceUntilIdle()
+
+        assertEquals(2, generations)
+        assertTrue(sawCompletion)
+        // Stop collecting so runBlockingTest finishes with no active coroutines.
+        collectJob.cancel()
+    }
+
+    @Test
+    fun refresh_closesUncollectedPageEventCh() = testScope.runBlockingTest {
+        val pageFetcher = PageFetcher(pagingSourceFactory, 50, config)
+
+        val pagingDatas = mutableListOf<PagingData<Int>>()
+        // Generation index -> whether collection of that generation's PageEvents has completed.
+        val didFinish = mutableMapOf<Int, Boolean>()
+        val job = launch {
+            pageFetcher.flow.collectIndexed { index, pagingData ->
+                pagingDatas.add(pagingData)
+                // Deliberately never collect generation 1; it is checked after being closed.
+                if (index == 1) return@collectIndexed
+
+                pagingData.flow
+                    .onStart { didFinish[index] = false }
+                    .onCompletion { didFinish[index] = true }
+                    // Return immediately to avoid blocking cancellation. This is analogous to
+                    // logic which would process a single PageEvent and doesn't suspend
+                    // indefinitely, which is what we expect to happen.
+                    .collect { }
+            }
+        }
+
+        advanceUntilIdle()
+
+        // Each refresh closes the previous generation's Pager.
+        pageFetcher.refresh()
+        pageFetcher.refresh()
+        advanceUntilIdle()
+
+        assertEquals(3, pagingDatas.size)
+        assertFailsWith<ClosedSendChannelException> { pagingDatas[1].flow.collect { } }
+        // Generation 0 completed when closed; generation 2 is still being collected.
+        assertEquals(mapOf(0 to true, 2 to false), didFinish)
+        job.cancel()
+    }
+
+    @Test
     fun collectTwice() = testScope.runBlockingTest {
         val pageFetcher = PageFetcher(pagingSourceFactory, 50, config)
         val fetcherState = collectFetcherState(pageFetcher)
diff --git a/paging/common/src/test/kotlin/androidx/paging/PagerTest.kt b/paging/common/src/test/kotlin/androidx/paging/PagerTest.kt
index 6acc0db..ad07948 100644
--- a/paging/common/src/test/kotlin/androidx/paging/PagerTest.kt
+++ b/paging/common/src/test/kotlin/androidx/paging/PagerTest.kt
@@ -33,6 +33,7 @@
 import kotlinx.coroutines.ExperimentalCoroutinesApi
 import kotlinx.coroutines.FlowPreview
 import kotlinx.coroutines.InternalCoroutinesApi
+import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.collect
 import kotlinx.coroutines.launch
 import kotlinx.coroutines.test.TestCoroutineScope
@@ -40,6 +41,8 @@
 import org.junit.Test
 import org.junit.runner.RunWith
 import org.junit.runners.JUnit4
+import kotlin.test.assertTrue
+import kotlin.test.fail
 
 @FlowPreview
 @ExperimentalCoroutinesApi
@@ -527,6 +530,33 @@
     }
 
     @Test
+    fun close_cancelsCollectionBeforeInitialLoad() = testScope.runBlockingTest {
+        // PagingSource whose load outlasts the 1000ms of virtual time this test advances.
+        val pagingSource = object : PagingSource<Int, Int>() {
+            override suspend fun load(params: LoadParams<Int>): LoadResult<Int, Int> {
+                delay(2000)
+                fail("Should never get here")
+            }
+        }
+
+        val pager = Pager(50, pagingSource, config)
+        val job = launch {
+            pager.pageEventFlow
+                // Return immediately to avoid blocking cancellation. This is analogous to
+                // logic which would process a single PageEvent and doesn't suspend
+                // indefinitely, which is what we expect to happen.
+                .collect { }
+        }
+
+        advanceTimeBy(500) // The 2000ms load cannot have completed yet.
+
+        pager.close()
+        advanceTimeBy(500) // 1000ms total: still short of the load, so only close() can end it.
+
+        assertTrue { !job.isActive }
+    }
+
+    @Test
     fun retry() = testScope.runBlockingTest {
         pauseDispatcher {
             val pageSource = pagingSourceFactory()