[go: nahoru, domu]

Correctly record capture scope of local functions and classes

Fixes: 201252574
Change-Id: I19b2548fd7a03ebeb12b769a5970231600ca4563
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/ComposerParamSignatureTests.kt b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/ComposerParamSignatureTests.kt
index 2918915..e140173 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/ComposerParamSignatureTests.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/ComposerParamSignatureTests.kt
@@ -168,6 +168,25 @@
     )
 
     @Test
+    fun testCaptureIssue23(): Unit = codegen(
+        """
+            import androidx.compose.animation.AnimatedContent
+            import androidx.compose.animation.ExperimentalAnimationApi
+            import androidx.compose.runtime.Composable
+
+            @OptIn(ExperimentalAnimationApi::class)
+            @Composable
+            fun SimpleAnimatedContentSample() {
+                @Composable fun Foo() {}
+
+                AnimatedContent(1f) {
+                    Foo()
+                }
+            }
+        """
+    )
+
+    @Test
     fun test32Params(): Unit = codegen(
         """
         @Composable
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/LambdaMemoizationTransformTests.kt b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/LambdaMemoizationTransformTests.kt
index 732e8bf..d37af66 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/LambdaMemoizationTransformTests.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/LambdaMemoizationTransformTests.kt
@@ -109,6 +109,171 @@
         """
     )
 
+    // Regression test for b/201252574
+    @Test
+    fun testLocalFunCaptures(): Unit = verifyComposeIrTransform(
+        """
+            import androidx.compose.runtime.NonRestartableComposable
+            import androidx.compose.runtime.Composable
+
+            @NonRestartableComposable
+            @Composable
+            fun Err() {
+                // `x` is not a capture of handler, but is treated as such.
+                fun handler() {
+                    { x: Int -> x }
+                }
+                // Lambda calling handler. To find captures, we need captures of `handler`.
+                {
+                  handler()
+                }
+            }
+        """,
+        """
+            @NonRestartableComposable
+            @Composable
+            fun Err(%composer: Composer?, %changed: Int) {
+              %composer.startReplaceableGroup(<>)
+              sourceInformation(%composer, "C(Err):Test.kt")
+              fun handler() {
+                { x: Int ->
+                  x
+                }
+              }
+              {
+                handler()
+              }
+              %composer.endReplaceableGroup()
+            }
+        """,
+        """
+        """
+    )
+
+    @Test
+    fun testLocalClassCaptures1(): Unit = verifyComposeIrTransform(
+        """
+            import androidx.compose.runtime.NonRestartableComposable
+            import androidx.compose.runtime.Composable
+
+            @NonRestartableComposable
+            @Composable
+            fun Err(y: Int, z: Int) {
+                class Local {
+                    val w = z
+                    fun something(x: Int): Int { return x + y + w }
+                }
+                {
+                  Local().something(2)
+                }
+            }
+        """,
+        """
+            @NonRestartableComposable
+            @Composable
+            fun Err(y: Int, z: Int, %composer: Composer?, %changed: Int) {
+              %composer.startReplaceableGroup(<>)
+              sourceInformation(%composer, "C(Err)<{>:Test.kt")
+              class Local {
+                val w: Int = z
+                fun something(x: Int): Int {
+                  return x + y + w
+                }
+              }
+              remember(y, z, {
+                {
+                  Local().something(2)
+                }
+              }, %composer, 0)
+              %composer.endReplaceableGroup()
+            }
+        """,
+        """
+        """
+    )
+
+    @Test
+    fun testLocalClassCaptures2(): Unit = verifyComposeIrTransform(
+        """
+            import androidx.compose.runtime.Composable
+            import androidx.compose.runtime.NonRestartableComposable
+
+            @NonRestartableComposable
+            @Composable
+            fun Example(z: Int) {
+                class Foo(val x: Int) { val y = z }
+                val lambda: () -> Any = {
+                    Foo(1)
+                }
+            }
+        """,
+        """
+            @NonRestartableComposable
+            @Composable
+            fun Example(z: Int, %composer: Composer?, %changed: Int) {
+              %composer.startReplaceableGroup(<>)
+              sourceInformation(%composer, "C(Example)<{>:Test.kt")
+              class Foo(val x: Int) {
+                val y: Int = z
+              }
+              val lambda = remember(z, {
+                {
+                  Foo(1)
+                }
+              }, %composer, 0)
+              %composer.endReplaceableGroup()
+            }
+        """,
+        """
+        """
+    )
+
+    @Test
+    fun testLocalFunCaptures3(): Unit = verifyComposeIrTransform(
+        """
+            import androidx.compose.animation.AnimatedContent
+            import androidx.compose.animation.ExperimentalAnimationApi
+            import androidx.compose.runtime.Composable
+
+            @OptIn(ExperimentalAnimationApi::class)
+            @Composable
+            fun SimpleAnimatedContentSample() {
+                @Composable fun Foo() {}
+
+                AnimatedContent(1f) {
+                    Foo()
+                }
+            }
+        """,
+        """
+            @OptIn(markerClass = ExperimentalAnimationApi::class)
+            @Composable
+            fun SimpleAnimatedContentSample(%composer: Composer?, %changed: Int) {
+              %composer = %composer.startRestartGroup(<>)
+              sourceInformation(%composer, "C(SimpleAnimatedContentSample)<Animat...>:Test.kt")
+              if (%changed !== 0 || !%composer.skipping) {
+                @Composable
+                fun Foo(%composer: Composer?, %changed: Int) {
+                  %composer.startReplaceableGroup(<>)
+                  sourceInformation(%composer, "C(Foo):Test.kt")
+                  %composer.endReplaceableGroup()
+                }
+                AnimatedContent(1.0f, null, null, null, composableLambda(%composer, <>, false) { it: Float, %composer: Composer?, %changed: Int ->
+                  sourceInformation(%composer, "C<Foo()>:Test.kt")
+                  Foo(%composer, 0)
+                }, %composer, 0b0110000000000110, 0b1110)
+              } else {
+                %composer.skipToGroupEnd()
+              }
+              %composer.endRestartGroup()?.updateScope { %composer: Composer?, %force: Int ->
+                SimpleAnimatedContentSample(%composer, %changed or 0b0001)
+              }
+            }
+        """,
+        """
+        """
+    )
+
     @Test
     fun testStateDelegateCapture(): Unit = verifyComposeIrTransform(
         """
diff --git a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/ComposerLambdaMemoization.kt b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/ComposerLambdaMemoization.kt
index 04d8902..bb88a82 100644
--- a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/ComposerLambdaMemoization.kt
+++ b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/ComposerLambdaMemoization.kt
@@ -53,6 +53,7 @@
 import org.jetbrains.kotlin.ir.builders.irTemporary
 import org.jetbrains.kotlin.ir.declarations.IrAttributeContainer
 import org.jetbrains.kotlin.ir.declarations.IrClass
+import org.jetbrains.kotlin.ir.declarations.IrConstructor
 import org.jetbrains.kotlin.ir.declarations.IrDeclarationBase
 import org.jetbrains.kotlin.ir.declarations.IrDeclarationOrigin
 import org.jetbrains.kotlin.ir.declarations.IrFile
@@ -63,6 +64,7 @@
 import org.jetbrains.kotlin.ir.declarations.IrVariable
 import org.jetbrains.kotlin.ir.declarations.copyAttributes
 import org.jetbrains.kotlin.ir.expressions.IrCall
+import org.jetbrains.kotlin.ir.expressions.IrConstructorCall
 import org.jetbrains.kotlin.ir.expressions.IrExpression
 import org.jetbrains.kotlin.ir.expressions.IrFunctionAccessExpression
 import org.jetbrains.kotlin.ir.expressions.IrFunctionExpression
@@ -97,53 +99,85 @@
 
 private class CaptureCollector {
     val captures = mutableSetOf<IrValueDeclaration>()
-    val capturedFunctions = mutableSetOf<IrFunction>()
-    val hasCaptures: Boolean get() = captures.isNotEmpty() || capturedFunctions.isNotEmpty()
+    val capturedDeclarations = mutableSetOf<IrSymbolOwner>()
+    val hasCaptures: Boolean get() = captures.isNotEmpty() || capturedDeclarations.isNotEmpty()
 
     fun recordCapture(local: IrValueDeclaration) {
         captures.add(local)
     }
 
-    fun recordCapture(local: IrFunction) {
-        capturedFunctions.add(local)
+    fun recordCapture(local: IrSymbolOwner) {
+        capturedDeclarations.add(local)
     }
 }
 
 private abstract class DeclarationContext {
+    val localDeclarationCaptures = mutableMapOf<IrSymbolOwner, Set<IrValueDeclaration>>()
+    fun recordLocalDeclaration(local: DeclarationContext) {
+        localDeclarationCaptures[local.declaration] = local.captures
+    }
     abstract val composable: Boolean
     abstract val symbol: IrSymbol
+    abstract val declaration: IrSymbolOwner
+    abstract val captures: Set<IrValueDeclaration>
     abstract val functionContext: FunctionContext?
     abstract fun declareLocal(local: IrValueDeclaration?)
-    abstract fun recordLocalFunction(local: FunctionContext)
-    abstract fun recordCapture(local: IrValueDeclaration?)
-    abstract fun recordCapture(local: IrFunction?)
+    abstract fun recordCapture(local: IrValueDeclaration?): Boolean
+    abstract fun recordCapture(local: IrSymbolOwner?)
     abstract fun pushCollector(collector: CaptureCollector)
     abstract fun popCollector(collector: CaptureCollector)
 }
 
-private class SymbolOwnerContext(val declaration: IrSymbolOwner) : DeclarationContext() {
+private fun List<DeclarationContext>.recordCapture(value: IrValueDeclaration) {
+    for (dec in reversed()) {
+        val shouldBreak = dec.recordCapture(value)
+        if (shouldBreak) break
+    }
+}
+
+private fun List<DeclarationContext>.recordLocalDeclaration(local: DeclarationContext) {
+    for (dec in reversed()) {
+        dec.recordLocalDeclaration(local)
+    }
+}
+
+private fun List<DeclarationContext>.recordLocalCapture(local: IrSymbolOwner) {
+    val capturesForLocal = reversed().firstNotNullOfOrNull { it.localDeclarationCaptures[local] }
+    if (capturesForLocal != null) {
+        capturesForLocal.forEach { recordCapture(it) }
+        for (dec in reversed()) {
+            dec.recordCapture(local)
+            if (dec.localDeclarationCaptures.containsKey(local)) {
+                // this is the scope that the local function or class was declared in, so the
+                // scopes above it don't need to do anything
+                break
+            }
+        }
+    }
+}
+
+private class SymbolOwnerContext(override val declaration: IrSymbolOwner) : DeclarationContext() {
     override val composable get() = false
     override val functionContext: FunctionContext? get() = null
     override val symbol get() = declaration.symbol
+    override val captures: Set<IrValueDeclaration> get() = emptySet()
     override fun declareLocal(local: IrValueDeclaration?) { }
-    override fun recordLocalFunction(local: FunctionContext) { }
-    override fun recordCapture(local: IrValueDeclaration?) { }
-    override fun recordCapture(local: IrFunction?) { }
+    override fun recordCapture(local: IrValueDeclaration?): Boolean { return false }
+    override fun recordCapture(local: IrSymbolOwner?) { }
     override fun pushCollector(collector: CaptureCollector) { }
     override fun popCollector(collector: CaptureCollector) { }
 }
 
 private class FunctionLocalSymbol(
-    val declaration: IrSymbolOwner,
+    override val declaration: IrSymbolOwner,
     override val functionContext: FunctionContext
 ) : DeclarationContext() {
     override val composable: Boolean get() = functionContext.composable
     override val symbol: IrSymbol get() = declaration.symbol
+    override val captures: Set<IrValueDeclaration> get() = functionContext.captures
     override fun declareLocal(local: IrValueDeclaration?) = functionContext.declareLocal(local)
-    override fun recordLocalFunction(local: FunctionContext) =
-        functionContext.recordLocalFunction(local)
     override fun recordCapture(local: IrValueDeclaration?) = functionContext.recordCapture(local)
-    override fun recordCapture(local: IrFunction?) = functionContext.recordCapture(local)
+    override fun recordCapture(local: IrSymbolOwner?) = functionContext.recordCapture(local)
     override fun pushCollector(collector: CaptureCollector) =
         functionContext.pushCollector(collector)
     override fun popCollector(collector: CaptureCollector) =
@@ -151,16 +185,15 @@
 }
 
 private class FunctionContext(
-    val declaration: IrFunction,
+    override val declaration: IrFunction,
     override val composable: Boolean,
     val canRemember: Boolean
 ) : DeclarationContext() {
     override val symbol get() = declaration.symbol
     override val functionContext: FunctionContext? get() = this
     val locals = mutableSetOf<IrValueDeclaration>()
-    val captures = mutableSetOf<IrValueDeclaration>()
+    override val captures: MutableSet<IrValueDeclaration> = mutableSetOf()
     var collectors = mutableListOf<CaptureCollector>()
-    val localFunctionCaptures = mutableMapOf<IrFunction, Set<IrValueDeclaration>>()
 
     init {
         declaration.valueParameters.forEach {
@@ -176,26 +209,22 @@
         }
     }
 
-    override fun recordLocalFunction(local: FunctionContext) {
-        if (local.captures.isNotEmpty() && local.declaration.isLocal) {
-            localFunctionCaptures[local.declaration] = local.captures
-        }
-    }
-
-    override fun recordCapture(local: IrValueDeclaration?) {
-        if (local != null && collectors.isNotEmpty() && locals.contains(local)) {
+    override fun recordCapture(local: IrValueDeclaration?): Boolean {
+        val containsLocal = locals.contains(local)
+        if (local != null && collectors.isNotEmpty() && containsLocal) {
             for (collector in collectors) {
                 collector.recordCapture(local)
             }
         }
-        if (local != null && declaration.isLocal && !locals.contains(local)) {
+        if (local != null && declaration.isLocal && !containsLocal) {
             captures.add(local)
         }
+        return containsLocal
     }
 
-    override fun recordCapture(local: IrFunction?) {
+    override fun recordCapture(local: IrSymbolOwner?) {
         if (local != null) {
-            val captures = localFunctionCaptures[local]
+            val captures = localDeclarationCaptures[local]
             for (collector in collectors) {
                 collector.recordCapture(local)
                 if (captures != null) {
@@ -217,22 +246,28 @@
     }
 }
 
-private class ClassContext(val declaration: IrClass) : DeclarationContext() {
+private class ClassContext(override val declaration: IrClass) : DeclarationContext() {
     override val composable: Boolean = false
     override val symbol get() = declaration.symbol
     override val functionContext: FunctionContext? = null
+    override val captures: MutableSet<IrValueDeclaration> = mutableSetOf()
     val thisParam: IrValueDeclaration? = declaration.thisReceiver!!
     var collectors = mutableListOf<CaptureCollector>()
     override fun declareLocal(local: IrValueDeclaration?) { }
-    override fun recordLocalFunction(local: FunctionContext) { }
-    override fun recordCapture(local: IrValueDeclaration?) {
-        if (local != null && collectors.isNotEmpty() && local == thisParam) {
+    override fun recordCapture(local: IrValueDeclaration?): Boolean {
+        val isThis = local == thisParam
+        val isCtorParam = (local?.parent as? IrConstructor)?.parent === declaration
+        if (local != null && collectors.isNotEmpty() && isThis) {
             for (collector in collectors) {
                 collector.recordCapture(local)
             }
         }
+        if (local != null && declaration.isLocal && !isThis && !isCtorParam) {
+            captures.add(local)
+        }
+        return isThis || isCtorParam
     }
-    override fun recordCapture(local: IrFunction?) { }
+    override fun recordCapture(local: IrSymbolOwner?) { }
     override fun pushCollector(collector: CaptureCollector) {
         collectors.add(collector)
     }
@@ -383,7 +418,7 @@
         val result = super.visitFunction(declaration)
         declarationContextStack.pop()
         if (declaration.isLocal) {
-            declarationContextStack.peek()?.recordLocalFunction(context)
+            declarationContextStack.recordLocalDeclaration(context)
         }
         return result
     }
@@ -393,6 +428,9 @@
         declarationContextStack.push(context)
         val result = super.visitClass(declaration)
         declarationContextStack.pop()
+        if (declaration.isLocal) {
+            declarationContextStack.recordLocalDeclaration(context)
+        }
         return result
     }
 
@@ -402,9 +440,7 @@
     }
 
     override fun visitValueAccess(expression: IrValueAccessExpression): IrExpression {
-        declarationContextStack.forEach {
-            it.recordCapture(expression.symbol.owner)
-        }
+        declarationContextStack.recordCapture(expression.symbol.owner)
         return super.visitValueAccess(expression)
     }
 
@@ -514,13 +550,20 @@
     override fun visitCall(expression: IrCall): IrExpression {
         val fn = expression.symbol.owner
         if (fn.isLocal) {
-            declarationContextStack.forEach {
-                it.recordCapture(fn)
-            }
+            declarationContextStack.recordLocalCapture(fn)
         }
         return super.visitCall(expression)
     }
 
+    override fun visitConstructorCall(expression: IrConstructorCall): IrExpression {
+        val fn = expression.symbol.owner
+        val cls = fn.parent as? IrClass
+        if (cls != null && fn.isLocal) {
+            declarationContextStack.recordLocalCapture(cls)
+        }
+        return super.visitConstructorCall(expression)
+    }
+
     @ObsoleteDescriptorBasedAPI
     private fun visitComposableFunctionExpression(
         expression: IrFunctionExpression,