[go: nahoru, domu]

Model reads during measure stage should survive Frame switches

Previously we were calling currentFrame().observeReads(block), which didn't work in some tricky cases where the Frame is switched within this block: the observer stayed attached to the previous, no-longer-active frame and wasn't notified about the model reads. This issue was reproducible when DrawVector is called inside WithConstraints, because DrawVector causes an additional subcomposition that switches to a new frame.
To fix this I introduced a global thread-local list of read observers that survives frame switches. We fully control when observers are added and removed, because our API only exposes observeAllReads(observer, block), so an observer can't be leaked.

Bug: 144493391
Test: the fix is covered by tests on three different layers.
Change-Id: Ibba72c3f791c1435390ce12abfedac70729a5071
diff --git a/compose/compose-runtime/src/androidMain/kotlin/androidx/compose/ActualJvm.kt b/compose/compose-runtime/src/androidMain/kotlin/androidx/compose/ActualJvm.kt
index dce8cf8..4400cf3 100644
--- a/compose/compose-runtime/src/androidMain/kotlin/androidx/compose/ActualJvm.kt
+++ b/compose/compose-runtime/src/androidMain/kotlin/androidx/compose/ActualJvm.kt
@@ -18,17 +18,20 @@
 
 actual typealias BitSet = java.util.BitSet
 
-actual open class ThreadLocal<T> actual constructor() : java.lang.ThreadLocal<T>() {
-    actual override fun get(): T? {
-        return super.get()
+actual open class ThreadLocal<T> actual constructor(
+    private val initialValue: () -> T
+) : java.lang.ThreadLocal<T>() {
+    @Suppress("UNCHECKED_CAST")
+    actual override fun get(): T {
+        return super.get() as T
     }
 
-    actual override fun set(value: T?) {
+    actual override fun set(value: T) {
         super.set(value)
     }
 
-    actual override fun initialValue(): T? {
-        return super.initialValue()
+    override fun initialValue(): T? {
+        return initialValue.invoke()
     }
 }
 
diff --git a/compose/compose-runtime/src/commonMain/kotlin/androidx/compose/Expect.kt b/compose/compose-runtime/src/commonMain/kotlin/androidx/compose/Expect.kt
index bc96998..b266a4b 100644
--- a/compose/compose-runtime/src/commonMain/kotlin/androidx/compose/Expect.kt
+++ b/compose/compose-runtime/src/commonMain/kotlin/androidx/compose/Expect.kt
@@ -23,12 +23,13 @@
     operator fun get(bitIndex: Int): Boolean
 }
 
-expect open class ThreadLocal<T>() {
-    fun get(): T?
-    fun set(value: T?)
-    protected open fun initialValue(): T?
+expect open class ThreadLocal<T>(initialValue: () -> T) {
+    fun get(): T
+    fun set(value: T)
 }
 
+fun <T> ThreadLocal() = ThreadLocal<T?> { null }
+
 expect class WeakHashMap<K, V>() : MutableMap<K, V>
 
 expect fun identityHashCode(instance: Any?): Int
diff --git a/compose/compose-runtime/src/commonMain/kotlin/androidx/compose/Recomposer.kt b/compose/compose-runtime/src/commonMain/kotlin/androidx/compose/Recomposer.kt
index 7a000b00f..11ade84 100644
--- a/compose/compose-runtime/src/commonMain/kotlin/androidx/compose/Recomposer.kt
+++ b/compose/compose-runtime/src/commonMain/kotlin/androidx/compose/Recomposer.kt
@@ -35,16 +35,13 @@
             require(isMainThread()) {
                 "No Recomposer for this Thread"
             }
-            return threadRecomposer.get() ?: error("No Recomposer for this Thread")
+            return threadRecomposer.get()
         }
 
         internal fun recompose(component: Component, composer: Composer<*>) =
             current().recompose(component, composer)
 
-        // TODO delete the explicit type after https://youtrack.jetbrains.com/issue/KT-20996
-        private val threadRecomposer: ThreadLocal<Recomposer> = object : ThreadLocal<Recomposer>() {
-            override fun initialValue(): Recomposer? = createRecomposer()
-        }
+        private val threadRecomposer = ThreadLocal { createRecomposer() }
     }
 
     private val composers = mutableSetOf<Composer<*>>()
diff --git a/compose/compose-runtime/src/commonMain/kotlin/androidx/compose/frames/Frames.kt b/compose/compose-runtime/src/commonMain/kotlin/androidx/compose/frames/Frames.kt
index b0c0a36..c4f619c 100644
--- a/compose/compose-runtime/src/commonMain/kotlin/androidx/compose/frames/Frames.kt
+++ b/compose/compose-runtime/src/commonMain/kotlin/androidx/compose/frames/Frames.kt
@@ -121,23 +121,22 @@
     /**
      * Observe a frame read
      */
-    readObserver: FrameReadObserver?,
+    internal val readObserver: FrameReadObserver?,
 
     /**
      * Observe a frame write
      */
-    internal val writeObserver: FrameWriteObserver?
+    internal val writeObserver: FrameWriteObserver?,
+
+    /**
+     * The reference to the thread local list of observers from [threadReadObservers].
+     * We store it here to save on an additional ThreadLocal.get() call during
+     * the every model read.
+     */
+    internal var threadReadObservers: MutableList<FrameReadObserver>
 ) {
     internal val modified = if (readOnly) null else HashSet<Framed>()
 
-    internal val readObservers = mutableListOf<FrameReadObserver>()
-
-    init {
-        if (readObserver != null) {
-            readObservers += readObserver
-        }
-    }
-
     /**
      * True if any change to a frame object will throw.
      */
@@ -145,23 +144,31 @@
         get() = modified == null
 
     /**
-     * Add a [FrameReadObserver] during execution of the [block].
-     */
-    fun observeReads(readObserver: FrameReadObserver, block: () -> Unit) {
-        try {
-            readObservers += readObserver
-            block()
-        } finally {
-            readObservers -= readObserver
-        }
-    }
-
-    /**
      * Whether there are any pending changes in this frame.
      */
     fun hasPendingChanges(): Boolean = (modified?.size ?: 0) > 0
 }
 
+/**
+ * Holds the thread local list of [FrameReadObserver]s not associated with any specific [Frame].
+ * They survive [Frame] switches.
+ */
+private val threadReadObservers = ThreadLocal { mutableListOf<FrameReadObserver>() }
+
+/**
+ * [FrameReadObserver] will be called for every frame read that happens on the current
+ * thread during execution of the [block].
+ */
+fun observeAllReads(readObserver: FrameReadObserver, block: () -> Unit) {
+    val observers = threadReadObservers.get()
+    try {
+        observers.add(readObserver)
+        block()
+    } finally {
+        observers.remove(readObserver)
+    }
+}
+
 private fun validateNotInFrame() {
     if (threadFrame.get() != null) throw IllegalStateException("In an existing frame")
 }
@@ -193,6 +200,7 @@
     writeObserver: FrameWriteObserver?
 ): Frame {
     validateNotInFrame()
+    val threadReadObservers = threadReadObservers.get()
     synchronized(sync) {
         val id = maxFrameId++
         val invalid = openFrames
@@ -201,7 +209,8 @@
             invalid = invalid,
             readOnly = readOnly,
             readObserver = readObserver,
-            writeObserver = writeObserver
+            writeObserver = writeObserver,
+            threadReadObservers = threadReadObservers
         )
         openFrames = openFrames.set(id)
         threadFrame.set(frame)
@@ -386,6 +395,7 @@
 fun restore(frame: Frame) {
     validateNotInFrame()
     validateOpen(frame)
+    frame.threadReadObservers = threadReadObservers.get()
     threadFrame.set(frame)
 }
 
@@ -435,11 +445,11 @@
 }
 
 fun <T : Record> T.readable(framed: Framed): T {
-    return this.readable(currentFrame(), framed)
-}
-
-fun <T : Record> T.readable(frame: Frame, framed: Framed): T {
-    frame.readObservers.forEach { it(framed) }
+    val frame = currentFrame()
+    // invoke the observer associated with the current frame.
+    frame.readObserver?.invoke(framed)
+    // invoke the thread local observers.
+    frame.threadReadObservers.forEach { it(framed) }
     return readable(this, frame.id, frame.invalid)
 }
 
diff --git a/compose/compose-runtime/src/commonTest/kotlin/androidx/compose/frames/FramesTests.kt b/compose/compose-runtime/src/commonTest/kotlin/androidx/compose/frames/FramesTests.kt
index e7bc30f..89f5d32 100644
--- a/compose/compose-runtime/src/commonTest/kotlin/androidx/compose/frames/FramesTests.kt
+++ b/compose/compose-runtime/src/commonTest/kotlin/androidx/compose/frames/FramesTests.kt
@@ -340,12 +340,11 @@
         }
         var read: Address? = null
         var otherRead: Address? = null
-        val frame = open({ obj -> read = obj as Address })
+        open({ obj -> read = obj as Address })
         try {
-            frame.observeReads({ obj -> otherRead = obj as Address }) {
+            observeAllReads({ obj -> otherRead = obj as Address }) {
                 assertEquals(OLD_STREET, address.street)
             }
-            assertEquals(1, frame.readObservers.size)
         } finally {
             commitHandler()
         }
@@ -1100,6 +1099,48 @@
             iterator.remove()
         }
     }
+
+    @Test
+    fun testGlobalReadObserverSurvivesFrameSwitch() {
+        val address1 = frame {
+            Address(
+                OLD_STREET,
+                OLD_CITY
+            )
+        }
+        val address2 = frame {
+            Address(
+                NEW_STREET,
+                NEW_CITY
+            )
+        }
+        val address3 = frame {
+            Address(
+                OLD_STREET,
+                NEW_CITY
+            )
+        }
+        val readAddresses = HashSet<Address>()
+
+        observeAllReads({ readAddresses.add(it as Address) }) {
+            frame {
+                // read 1
+                address1.city
+            }
+            frame {
+                // read 2
+                address2.city
+            }
+        }
+        frame {
+            // read 3 outside of observeAllReads
+            address3.city
+        }
+
+        assertTrue(readAddresses.contains(address1))
+        assertTrue(readAddresses.contains(address2))
+        assertFalse(readAddresses.contains(address3))
+    }
 }
 
 fun expectError(block: () -> Unit) {
diff --git a/ui/ui-framework/src/androidTest/java/androidx/ui/core/test/WithConstraintsTest.kt b/ui/ui-framework/src/androidTest/java/androidx/ui/core/test/WithConstraintsTest.kt
new file mode 100644
index 0000000..d1b8b95
--- /dev/null
+++ b/ui/ui-framework/src/androidTest/java/androidx/ui/core/test/WithConstraintsTest.kt
@@ -0,0 +1,82 @@
+/*
+ * Copyright 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.ui.core.test
+
+import androidx.test.filters.SmallTest
+import androidx.test.rule.ActivityTestRule
+import androidx.ui.core.Layout
+import androidx.ui.core.WithConstraints
+import androidx.ui.core.ipx
+import androidx.ui.core.px
+import androidx.ui.core.setContent
+import androidx.ui.framework.test.TestActivity
+import androidx.ui.graphics.vector.DrawVector
+import org.junit.Assert.assertTrue
+import org.junit.Before
+import org.junit.Rule
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.TimeUnit
+
+@SmallTest
+@RunWith(JUnit4::class)
+class WithConstraintsTest {
+
+    @get:Rule
+    val rule = ActivityTestRule<TestActivity>(
+        TestActivity::class.java
+    )
+    private lateinit var activity: TestActivity
+
+    @Before
+    fun setup() {
+        activity = rule.activity
+        activity.hasFocusLatch.await(5, TimeUnit.SECONDS)
+    }
+
+    @Test
+    fun subcomposionInsideWithConstraintsDoesntAffectModelReadsObserving() {
+        val model = ValueModel(0)
+        var latch = CountDownLatch(1)
+
+        rule.runOnUiThreadIR {
+            activity.setContent {
+                WithConstraints {
+                    // this block is called as a subcomposition from LayoutNode.measure()
+                    // DrawVector introduces an additional subcomposition which closes the
+                    // current frame and opens a new one. Previously, model reads during
+                    // measure() didn't survive this frame switch, so the model read
+                    // within the child Layout wasn't recorded.
+                    DrawVector(100.px, 100.px) { _, _ -> }
+                    Layout({}) { _, _ ->
+                        // read the model
+                        model.value
+                        latch.countDown()
+                        layout(10.ipx, 10.ipx) {}
+                    }
+                }
+            }
+        }
+        assertTrue(latch.await(1, TimeUnit.SECONDS))
+
+        latch = CountDownLatch(1)
+        rule.runOnUiThread { model.value++ }
+        assertTrue(latch.await(1, TimeUnit.SECONDS))
+    }
+}
diff --git a/ui/ui-platform/src/androidTest/java/androidx/ui/core/NodeStagesModelObserverTest.kt b/ui/ui-platform/src/androidTest/java/androidx/ui/core/NodeStagesModelObserverTest.kt
index c8b6a2c..866311c 100644
--- a/ui/ui-platform/src/androidTest/java/androidx/ui/core/NodeStagesModelObserverTest.kt
+++ b/ui/ui-platform/src/androidTest/java/androidx/ui/core/NodeStagesModelObserverTest.kt
@@ -16,6 +16,7 @@
 
 package androidx.ui.core
 
+import androidx.compose.FrameManager
 import androidx.compose.annotations.Hide
 import androidx.compose.frames.AbstractRecord
 import androidx.compose.frames.Framed
@@ -193,6 +194,38 @@
         assertTrue(layoutLatch2.await(1, TimeUnit.SECONDS))
         assertTrue(measureLatch.await(1, TimeUnit.SECONDS))
     }
+
+    @Test
+    fun modelReadTriggersCallbackAfterSwitchingFrameWithinObserveReads() {
+        val node = DrawNode()
+        val countDownLatch = CountDownLatch(1)
+
+        val model = State(0)
+        val modelObserver = NodeStagesModelObserver { _, _ ->
+            assertEquals(1, countDownLatch.count)
+            countDownLatch.countDown()
+        }
+
+        modelObserver.enableModelUpdatesObserving(true)
+
+        open() // open the frame
+
+        modelObserver.observeReads {
+            modelObserver.stage(Stage.Draw, node) {
+                // switch to the next frame.
+                // this will be done by subcomposition, for example.
+                FrameManager.nextFrame()
+                // read the value
+                model.value
+            }
+        }
+
+        model.value++
+        commit() // close the frame
+
+        modelObserver.enableModelUpdatesObserving(false)
+        assertTrue(countDownLatch.await(1, TimeUnit.SECONDS))
+    }
 }
 
 // @Model generation is not enabled for this module and androidx.compose.State is internal
diff --git a/ui/ui-platform/src/main/java/androidx/ui/core/NodeStagesModelObserver.kt b/ui/ui-platform/src/main/java/androidx/ui/core/NodeStagesModelObserver.kt
index 3ce44df..fc363d9 100644
--- a/ui/ui-platform/src/main/java/androidx/ui/core/NodeStagesModelObserver.kt
+++ b/ui/ui-platform/src/main/java/androidx/ui/core/NodeStagesModelObserver.kt
@@ -23,7 +23,7 @@
 import androidx.compose.WeakReference
 import androidx.compose.frames.FrameCommitObserver
 import androidx.compose.frames.FrameReadObserver
-import androidx.compose.frames.currentFrame
+import androidx.compose.frames.observeAllReads
 import androidx.compose.frames.registerCommitObserver
 
 /**
@@ -128,7 +128,7 @@
         check(!isObserving)
         check(currentNodes.isEmpty())
         isObserving = true
-        currentFrame().observeReads(frameReadObserver, block)
+        observeAllReads(frameReadObserver, block)
         isObserving = false
         check(currentNodes.isEmpty())
     }