[go: nahoru, domu]

blob: 985088dbe63149036dc48bffdc70b3423c7d0114 [file] [log] [blame]
/*
* Copyright 2020 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package androidx.compose.compiler.plugins.kotlin.lower
import androidx.compose.compiler.plugins.kotlin.ComposeFqNames
import androidx.compose.compiler.plugins.kotlin.analysis.ComposeWritableSlices
import androidx.compose.compiler.plugins.kotlin.composableTrackedContract
import androidx.compose.compiler.plugins.kotlin.irTrace
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContextImpl
import org.jetbrains.kotlin.backend.common.lower.DeclarationIrBuilder
import org.jetbrains.kotlin.backend.common.peek
import org.jetbrains.kotlin.backend.common.pop
import org.jetbrains.kotlin.backend.common.push
import org.jetbrains.kotlin.incremental.components.NoLookupLocation
import org.jetbrains.kotlin.ir.IrStatement
import org.jetbrains.kotlin.ir.UNDEFINED_OFFSET
import org.jetbrains.kotlin.ir.builders.irBlock
import org.jetbrains.kotlin.ir.builders.irBoolean
import org.jetbrains.kotlin.ir.builders.irCall
import org.jetbrains.kotlin.ir.builders.irGet
import org.jetbrains.kotlin.ir.builders.irInt
import org.jetbrains.kotlin.ir.builders.irNull
import org.jetbrains.kotlin.ir.builders.irReturn
import org.jetbrains.kotlin.ir.builders.irTemporary
import org.jetbrains.kotlin.ir.declarations.IrDeclaration
import org.jetbrains.kotlin.ir.declarations.IrFunction
import org.jetbrains.kotlin.ir.declarations.IrModuleFragment
import org.jetbrains.kotlin.ir.declarations.IrSymbolOwner
import org.jetbrains.kotlin.ir.declarations.IrValueDeclaration
import org.jetbrains.kotlin.ir.declarations.IrVariable
import org.jetbrains.kotlin.ir.declarations.copyAttributes
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.IrFunctionAccessExpression
import org.jetbrains.kotlin.ir.expressions.IrFunctionExpression
import org.jetbrains.kotlin.ir.expressions.IrFunctionReference
import org.jetbrains.kotlin.ir.expressions.IrStatementOrigin
import org.jetbrains.kotlin.ir.expressions.IrValueAccessExpression
import org.jetbrains.kotlin.ir.expressions.impl.IrCallImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrFunctionReferenceImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrVarargImpl
import org.jetbrains.kotlin.ir.symbols.IrSymbol
import org.jetbrains.kotlin.ir.types.toKotlinType
import org.jetbrains.kotlin.ir.util.DeepCopySymbolRemapper
import org.jetbrains.kotlin.ir.util.patchDeclarationParents
import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid
import org.jetbrains.kotlin.js.resolve.diagnostics.findPsi
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.psi.KtFunctionLiteral
import org.jetbrains.kotlin.resolve.BindingTrace
import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe
import org.jetbrains.kotlin.resolve.descriptorUtil.module
import org.jetbrains.kotlin.resolve.inline.InlineUtil
import org.jetbrains.kotlin.types.typeUtil.isUnit
import org.jetbrains.kotlin.types.typeUtil.replaceArgumentsWithStarProjections
private class CaptureCollector {
val captures = mutableSetOf<IrValueDeclaration>()
fun recordCapture(local: IrValueDeclaration) {
captures.add(local)
}
}
private abstract class DeclarationContext {
abstract val composable: Boolean
abstract val symbol: IrSymbol
abstract fun declareLocal(local: IrValueDeclaration?)
abstract fun recordCapture(local: IrValueDeclaration?)
abstract fun pushCollector(collector: CaptureCollector)
abstract fun popCollector(collector: CaptureCollector)
}
private class SymbolOwnerContext(val declaration: IrSymbolOwner) : DeclarationContext() {
override val composable get() = false
override val symbol get() = declaration.symbol
override fun declareLocal(local: IrValueDeclaration?) { }
override fun recordCapture(local: IrValueDeclaration?) { }
override fun pushCollector(collector: CaptureCollector) { }
override fun popCollector(collector: CaptureCollector) { }
}
private class FunctionContext(
val declaration: IrFunction,
override val composable: Boolean,
val canRemember: Boolean
) : DeclarationContext() {
override val symbol get() = declaration.symbol
val locals = mutableSetOf<IrValueDeclaration>()
var collectors = mutableListOf<CaptureCollector>()
init {
declaration.valueParameters.forEach {
declareLocal(it)
}
}
override fun declareLocal(local: IrValueDeclaration?) {
if (local != null) {
locals.add(local)
}
}
override fun recordCapture(local: IrValueDeclaration?) {
if (local != null && collectors.isNotEmpty() && locals.contains(local)) {
for (collector in collectors) {
collector.recordCapture(local)
}
}
}
override fun pushCollector(collector: CaptureCollector) {
collectors.add(collector)
}
override fun popCollector(collector: CaptureCollector) {
require(collectors.lastOrNull() == collector)
collectors.removeAt(collectors.size - 1)
}
}
const val COMPOSABLE_LAMBDA = "composableLambda"
const val COMPOSABLE_LAMBDA_N = "composableLambdaN"
const val COMPOSABLE_LAMBDA_INSTANCE = "composableLambdaInstance"
const val COMPOSABLE_LAMBDA_N_INSTANCE = "composableLambdaNInstance"
@Suppress("PRE_RELEASE_CLASS")
class ComposerLambdaMemoization(
context: IrPluginContext,
symbolRemapper: DeepCopySymbolRemapper,
bindingTrace: BindingTrace
) :
AbstractComposeLowering(context, symbolRemapper, bindingTrace),
ModuleLoweringPass {
private val declarationContextStack = mutableListOf<DeclarationContext>()
override fun lower(module: IrModuleFragment) = module.transformChildrenVoid(this)
override fun visitDeclaration(declaration: IrDeclaration): IrStatement {
if (declaration is IrFunction) return super.visitDeclaration(declaration)
val symbolOwner = declaration as? IrSymbolOwner
if (symbolOwner != null) declarationContextStack.push(SymbolOwnerContext(declaration))
val result = super.visitDeclaration(declaration)
if (symbolOwner != null) declarationContextStack.pop()
return result
}
private fun irCurrentComposer(): IrExpression {
val currentComposerSymbol = getTopLevelPropertyGetter(
ComposeFqNames.fqNameFor("currentComposer")
)
currentComposerSymbol.bindIfNecessary()
return IrCallImpl(
UNDEFINED_OFFSET,
UNDEFINED_OFFSET,
composerTypeDescriptor.defaultType.replaceArgumentsWithStarProjections().toIrType(),
currentComposerSymbol,
IrStatementOrigin.FOR_LOOP_ITERATOR
)
}
override fun visitFunction(declaration: IrFunction): IrStatement {
val descriptor = declaration.descriptor
val composable = descriptor.allowsComposableCalls()
val canRemember = composable &&
// Don't use remember in an inline function
!descriptor.isInline &&
// Don't use remember if in a composable that returns a value
// TODO(b/150390108): Consider allowing remember in effects
descriptor.returnType.let { it != null && it.isUnit() }
declarationContextStack.push(FunctionContext(declaration, composable, canRemember))
val result = super.visitFunction(declaration)
declarationContextStack.pop()
return result
}
override fun visitVariable(declaration: IrVariable): IrStatement {
declarationContextStack.peek()?.declareLocal(declaration)
return super.visitVariable(declaration)
}
override fun visitValueAccess(expression: IrValueAccessExpression): IrExpression {
declarationContextStack.forEach { it.recordCapture(expression.symbol.owner) }
return super.visitValueAccess(expression)
}
override fun visitFunctionReference(expression: IrFunctionReference): IrExpression {
// Memoize the instance created by using the :: operator
val result = super.visitFunctionReference(expression)
val functionContext = declarationContextStack.peek() as? FunctionContext
?: return result
if (expression.valueArgumentsCount != 0) {
// If this syntax is as a curry syntax in the future, don't memoize.
// The syntax <expr>::<method>(<params>) and ::<function>(<params>) is reserved for
// future use. This ensures we don't try to memoize this syntax without knowing
// its meaning.
// The most likely correct implementation is to treat the parameters exactly as the
// receivers are treated below.
return result
}
if (functionContext.canRemember) {
// Memoize the reference for <expr>::<method>
val dispatchReceiver = expression.dispatchReceiver
val extensionReceiver = expression.extensionReceiver
if ((dispatchReceiver != null || extensionReceiver != null) &&
dispatchReceiver.isNullOrStable() &&
extensionReceiver.isNullOrStable()) {
// Save the receivers into a temporaries and memoize the function reference using
// the resulting temporaries
val builder = DeclarationIrBuilder(
generatorContext = context,
symbol = functionContext.symbol,
startOffset = expression.startOffset,
endOffset = expression.endOffset
)
return builder.irBlock(
resultType = expression.type
) {
val captures = mutableListOf<IrValueDeclaration>()
val tempDispatchReceiver = dispatchReceiver?.let {
val tmp = irTemporary(it)
captures.add(tmp)
tmp
}
val tempExtensionReceiver = extensionReceiver?.let {
val tmp = irTemporary(it)
captures.add(tmp)
tmp
}
+rememberExpression(
functionContext,
IrFunctionReferenceImpl(
startOffset,
endOffset,
expression.type,
expression.symbol,
expression.typeArgumentsCount,
expression.reflectionTarget).copyAttributes(expression).apply {
this.dispatchReceiver = tempDispatchReceiver?.let { irGet(it) }
this.extensionReceiver = tempExtensionReceiver?.let { irGet(it) }
},
captures
)
}
} else if (dispatchReceiver == null) {
return rememberExpression(functionContext, result, emptyList())
}
}
return result
}
private fun visitNonComposableFunctionExpression(
expression: IrFunctionExpression,
declarationContext: DeclarationContext
): IrExpression {
val functionContext = declarationContext as? FunctionContext
?: return super.visitFunctionExpression(expression)
if (
// Only memoize non-composable lambdas in a context we can use remember
!functionContext.canRemember ||
// Don't memoize inlined lambdas
expression.isInlineArgument() ||
// Don't memoize untracked function
!expression.isTracked()) {
return super.visitFunctionExpression(expression)
}
// Record capture variables for this scope
val collector = CaptureCollector()
startCollector(collector)
// Wrap composable functions expressions or memoize non-composable function expressions
val result = super.visitFunctionExpression(expression)
stopCollector(collector)
// If the ancestor converted this then return
val functionExpression = result as? IrFunctionExpression ?: return result
return rememberExpression(
functionContext,
functionExpression,
collector.captures.toList()
)
}
private fun visitComposableFunctionExpression(
expression: IrFunctionExpression,
declarationContext: DeclarationContext
): IrExpression {
val result = super.visitFunctionExpression(expression)
// If the ancestor converted this then return
val functionExpression = result as? IrFunctionExpression ?: return result
// Do not wrap target of an inline function
if (expression.isInlineArgument()) return functionExpression
// Do not wrap composable lambdas with return results
if (!functionExpression.function.descriptor.returnType.let { it == null || it.isUnit() }) {
return functionExpression
}
return wrapFunctionExpression(declarationContext, functionExpression)
}
override fun visitFunctionExpression(expression: IrFunctionExpression): IrExpression {
val declarationContext = declarationContextStack.peek()
?: return super.visitFunctionExpression(expression)
return if (expression.allowsComposableCalls())
visitComposableFunctionExpression(expression, declarationContext)
else
visitNonComposableFunctionExpression(expression, declarationContext)
}
private fun startCollector(collector: CaptureCollector) {
for (declarationContext in declarationContextStack) {
declarationContext.pushCollector(collector)
}
}
private fun stopCollector(collector: CaptureCollector) {
for (declarationContext in declarationContextStack) {
declarationContext.popCollector(collector)
}
}
private fun wrapFunctionExpression(
declarationContext: DeclarationContext,
expression: IrFunctionExpression
): IrExpression {
val function = expression.function
val argumentCount = function.descriptor.valueParameters.size
val useComposableLambdaN = argumentCount > MAX_RESTART_ARGUMENT_COUNT
val restartFunctionFactory =
if (declarationContext.composable)
if (useComposableLambdaN)
COMPOSABLE_LAMBDA_N
else COMPOSABLE_LAMBDA
else if (useComposableLambdaN)
COMPOSABLE_LAMBDA_N_INSTANCE
else COMPOSABLE_LAMBDA_INSTANCE
val restartFactorySymbol =
getTopLevelFunction(ComposeFqNames.internalFqNameFor(restartFunctionFactory))
val irBuilder = DeclarationIrBuilder(context,
symbol = declarationContext.symbol,
startOffset = expression.startOffset,
endOffset = expression.endOffset
)
(context as IrPluginContextImpl).linker.getDeclaration(restartFactorySymbol)
return irBuilder.irCall(restartFactorySymbol).apply {
var index = 0
// first parameter is the composer parameter if we are ina composable context
if (declarationContext.composable) {
putValueArgument(
index++,
irCurrentComposer()
)
}
// key parameter
putValueArgument(
index++, irBuilder.irInt(
@Suppress("DEPRECATION")
symbol.descriptor.fqNameSafe.hashCode() xor expression.startOffset
)
)
// tracked parameter
putValueArgument(index++, irBuilder.irBoolean(expression.isTracked()))
if (declarationContext.composable) {
// sourceInformation parameter
putValueArgument(index++, irBuilder.irNull())
}
// ComposableLambdaN requires the arity
if (useComposableLambdaN) {
// arity parameter
putValueArgument(index++, irBuilder.irInt(argumentCount))
}
if (index >= valueArgumentsCount) {
error("function = ${
function.name.asString()
}, count = $valueArgumentsCount, index = $index")
}
// block parameter
putValueArgument(index, expression)
}
}
private fun rememberExpression(
functionContext: FunctionContext,
expression: IrExpression,
captures: List<IrValueDeclaration>
): IrExpression {
if (captures.any {
!((it as? IrVariable)?.isVar != true && it.type.toKotlinType().isStable())
}) return expression
val rememberParameterCount = captures.size + 1 // One additional parameter for the lambda
val declaration = functionContext.declaration
val descriptor = declaration.descriptor
val module = descriptor.module
val rememberFunctions = module
.getPackage(ComposeFqNames.Package)
.memberScope
.getContributedFunctions(
Name.identifier("remember"),
NoLookupLocation.FROM_BACKEND
)
val directRememberFunction = // Exclude the varargs version
rememberFunctions.singleOrNull {
it.valueParameters.size == rememberParameterCount &&
// Exclude the varargs version
it.valueParameters.firstOrNull()?.varargElementType == null
}
val rememberFunctionDescriptor = directRememberFunction
?: rememberFunctions.single {
// Use the varargs version
it.valueParameters.firstOrNull()?.varargElementType != null
}
val rememberFunctionSymbol = referenceSimpleFunction(rememberFunctionDescriptor)
val irBuilder = DeclarationIrBuilder(
generatorContext = context,
symbol = functionContext.symbol,
startOffset = expression.startOffset,
endOffset = expression.endOffset
)
return irBuilder.irCall(
callee = rememberFunctionSymbol,
descriptor = rememberFunctionDescriptor,
type = expression.type
).apply {
// The result type type parameter is first, followed by the argument types
putTypeArgument(0, expression.type)
val lambdaArgumentIndex = if (directRememberFunction != null) {
// Call to the non-vararg version
for (i in 1..captures.size) {
putTypeArgument(i, captures[i - 1].type)
}
// condition arguments are the first `arg.size` arguments
for (i in captures.indices) {
putValueArgument(i, irBuilder.irGet(captures[i]))
}
// The lambda is the last parameter
captures.size
} else {
val parameterType = rememberFunctionDescriptor.valueParameters[0].type.toIrType()
// Call to the vararg version
putValueArgument(0,
IrVarargImpl(
startOffset = UNDEFINED_OFFSET,
endOffset = UNDEFINED_OFFSET,
type = parameterType,
varargElementType = context.irBuiltIns.anyType,
elements = captures.map {
irBuilder.irGet(it)
}
)
)
1
}
putValueArgument(lambdaArgumentIndex, irBuilder.irLambdaExpression(
descriptor = irBuilder.createFunctionDescriptor(
rememberFunctionDescriptor.valueParameters.last().type
),
type = rememberFunctionDescriptor.valueParameters.last().type.toIrType(),
body = {
+irReturn(expression)
}
))
}.patchDeclarationParents(declaration).markAsSynthetic(mark = true)
}
private fun <T : IrFunctionAccessExpression> T.markAsSynthetic(mark: Boolean): T {
if (mark) {
// Mark it so the ComposableCallTransformer will insert the correct code around this
// call
context.irTrace.record(
ComposeWritableSlices.IS_SYNTHETIC_COMPOSABLE_CALL,
this,
true
)
}
return this
}
private fun IrFunctionExpression.isInlineArgument(): Boolean {
function.descriptor.findPsi()?.let { psi ->
(psi as? KtFunctionLiteral)?.let {
if (InlineUtil.isInlinedArgument(
it,
@Suppress("DEPRECATION") context.bindingContext,
false
)
)
return true
}
}
return false
}
private fun IrExpression?.isNullOrStable() = this == null || type.toKotlinType().isStable()
private fun IrFunctionExpression.isTracked() =
function.descriptor.composableTrackedContract() != false
}
// This must match the highest value of FunctionXX which is current Function22
private const val MAX_RESTART_ARGUMENT_COUNT = 22