| /* |
| * Copyright 2020 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package androidx.compose.plugins.kotlin.compiler.lower |
| |
| import androidx.compose.plugins.kotlin.ComposeFqNames |
| import androidx.compose.plugins.kotlin.KtxNameConventions |
| import androidx.compose.plugins.kotlin.analysis.ComposeWritableSlices |
| import androidx.compose.plugins.kotlin.hasDirectAnnotation |
| import androidx.compose.plugins.kotlin.hasUntrackedAnnotation |
| import androidx.compose.plugins.kotlin.irTrace |
| import androidx.compose.plugins.kotlin.isEmitInline |
| import org.jetbrains.kotlin.backend.common.FileLoweringPass |
| import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext |
| import org.jetbrains.kotlin.backend.common.lower.DeclarationIrBuilder |
| import org.jetbrains.kotlin.backend.common.pop |
| import org.jetbrains.kotlin.backend.common.push |
| import org.jetbrains.kotlin.descriptors.CallableMemberDescriptor |
| import org.jetbrains.kotlin.descriptors.FunctionDescriptor |
| import org.jetbrains.kotlin.descriptors.Modality |
| import org.jetbrains.kotlin.descriptors.PropertyDescriptor |
| import org.jetbrains.kotlin.descriptors.SimpleFunctionDescriptor |
| import org.jetbrains.kotlin.descriptors.SourceElement |
| import org.jetbrains.kotlin.descriptors.Visibilities |
| import org.jetbrains.kotlin.descriptors.annotations.Annotations |
| import org.jetbrains.kotlin.descriptors.impl.AnonymousFunctionDescriptor |
| import org.jetbrains.kotlin.descriptors.impl.ValueParameterDescriptorImpl |
| import org.jetbrains.kotlin.fir.java.topLevelName |
| import org.jetbrains.kotlin.incremental.components.NoLookupLocation |
| import org.jetbrains.kotlin.ir.IrElement |
| import org.jetbrains.kotlin.ir.IrStatement |
| import org.jetbrains.kotlin.ir.UNDEFINED_OFFSET |
| import org.jetbrains.kotlin.ir.backend.js.utils.OperatorNames |
| import org.jetbrains.kotlin.ir.builders.declarations.addValueParameter |
| import org.jetbrains.kotlin.ir.builders.irBlockBody |
| import org.jetbrains.kotlin.ir.builders.irCall |
| import org.jetbrains.kotlin.ir.builders.irGet |
| import org.jetbrains.kotlin.ir.builders.irReturn |
| import org.jetbrains.kotlin.ir.declarations.IrAnonymousInitializer |
| import org.jetbrains.kotlin.ir.declarations.IrClass |
| import org.jetbrains.kotlin.ir.declarations.IrDeclaration |
| import org.jetbrains.kotlin.ir.declarations.IrDeclarationOrigin |
| import org.jetbrains.kotlin.ir.declarations.IrEnumEntry |
| import org.jetbrains.kotlin.ir.declarations.IrField |
| import org.jetbrains.kotlin.ir.declarations.IrFile |
| import org.jetbrains.kotlin.ir.declarations.IrFunction |
| import org.jetbrains.kotlin.ir.declarations.IrLocalDelegatedProperty |
| import org.jetbrains.kotlin.ir.declarations.IrModuleFragment |
| import org.jetbrains.kotlin.ir.declarations.IrProperty |
| import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction |
| import org.jetbrains.kotlin.ir.declarations.IrTypeAlias |
| import org.jetbrains.kotlin.ir.declarations.IrTypeParameter |
| import org.jetbrains.kotlin.ir.declarations.IrValueDeclaration |
| import org.jetbrains.kotlin.ir.declarations.IrValueParameter |
| import org.jetbrains.kotlin.ir.declarations.IrVariable |
| import org.jetbrains.kotlin.ir.declarations.impl.IrFunctionImpl |
| import org.jetbrains.kotlin.ir.declarations.impl.IrVariableImpl |
| import org.jetbrains.kotlin.ir.expressions.IrBlock |
| import org.jetbrains.kotlin.ir.expressions.IrBody |
| import org.jetbrains.kotlin.ir.expressions.IrBreakContinue |
| import org.jetbrains.kotlin.ir.expressions.IrCall |
| import org.jetbrains.kotlin.ir.expressions.IrConst |
| import org.jetbrains.kotlin.ir.expressions.IrConstKind |
| import org.jetbrains.kotlin.ir.expressions.IrConstructorCall |
| import org.jetbrains.kotlin.ir.expressions.IrDoWhileLoop |
| import org.jetbrains.kotlin.ir.expressions.IrElseBranch |
| import org.jetbrains.kotlin.ir.expressions.IrExpression |
| import org.jetbrains.kotlin.ir.expressions.IrFunctionAccessExpression |
| import org.jetbrains.kotlin.ir.expressions.IrFunctionExpression |
| import org.jetbrains.kotlin.ir.expressions.IrGetEnumValue |
| import org.jetbrains.kotlin.ir.expressions.IrGetObjectValue |
| import org.jetbrains.kotlin.ir.expressions.IrGetValue |
| import org.jetbrains.kotlin.ir.expressions.IrLoop |
| import org.jetbrains.kotlin.ir.expressions.IrReturn |
| import org.jetbrains.kotlin.ir.expressions.IrStatementContainer |
| import org.jetbrains.kotlin.ir.expressions.IrStatementOrigin |
| import org.jetbrains.kotlin.ir.expressions.IrVararg |
| import org.jetbrains.kotlin.ir.expressions.IrWhen |
| import org.jetbrains.kotlin.ir.expressions.IrWhileLoop |
| import org.jetbrains.kotlin.ir.expressions.impl.IrBlockBodyImpl |
| import org.jetbrains.kotlin.ir.expressions.impl.IrBlockImpl |
| import org.jetbrains.kotlin.ir.expressions.impl.IrBranchImpl |
| import org.jetbrains.kotlin.ir.expressions.impl.IrCallImpl |
| import org.jetbrains.kotlin.ir.expressions.impl.IrCompositeImpl |
| import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl |
| import org.jetbrains.kotlin.ir.expressions.impl.IrContainerExpressionBase |
| import org.jetbrains.kotlin.ir.expressions.impl.IrElseBranchImpl |
| import org.jetbrains.kotlin.ir.expressions.impl.IrGetValueImpl |
| import org.jetbrains.kotlin.ir.expressions.impl.IrLoopBase |
| import org.jetbrains.kotlin.ir.expressions.impl.IrReturnImpl |
| import org.jetbrains.kotlin.ir.expressions.impl.IrSpreadElementImpl |
| import org.jetbrains.kotlin.ir.expressions.impl.IrVarargImpl |
| import org.jetbrains.kotlin.ir.expressions.impl.IrWhenImpl |
| import org.jetbrains.kotlin.ir.symbols.IrReturnTargetSymbol |
| import org.jetbrains.kotlin.ir.symbols.impl.IrSimpleFunctionSymbolImpl |
| import org.jetbrains.kotlin.ir.types.IrType |
| import org.jetbrains.kotlin.ir.types.classOrNull |
| import org.jetbrains.kotlin.ir.types.defaultType |
| import org.jetbrains.kotlin.ir.types.isNothing |
| import org.jetbrains.kotlin.ir.types.isUnit |
| import org.jetbrains.kotlin.ir.types.isUnitOrNullableUnit |
| import org.jetbrains.kotlin.ir.types.makeNullable |
| import org.jetbrains.kotlin.ir.types.toKotlinType |
| import org.jetbrains.kotlin.ir.util.DeepCopySymbolRemapper |
| import org.jetbrains.kotlin.ir.util.getArguments |
| import org.jetbrains.kotlin.ir.util.getPropertyGetter |
| import org.jetbrains.kotlin.ir.util.hasDefaultValue |
| import org.jetbrains.kotlin.ir.util.isInlined |
| import org.jetbrains.kotlin.ir.util.isVararg |
| import org.jetbrains.kotlin.ir.util.patchDeclarationParents |
| import org.jetbrains.kotlin.ir.util.statements |
| import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid |
| import org.jetbrains.kotlin.js.resolve.diagnostics.findPsi |
| import org.jetbrains.kotlin.name.FqName |
| import org.jetbrains.kotlin.name.Name |
| import org.jetbrains.kotlin.psi.KtFunctionLiteral |
| import org.jetbrains.kotlin.psi2ir.findFirstFunction |
| import org.jetbrains.kotlin.resolve.BindingTrace |
| import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe |
| import org.jetbrains.kotlin.resolve.inline.InlineUtil |
| import org.jetbrains.kotlin.types.typeUtil.isUnit |
| import org.jetbrains.kotlin.types.typeUtil.makeNullable |
| import org.jetbrains.kotlin.types.typeUtil.replaceArgumentsWithStarProjections |
| import org.jetbrains.kotlin.utils.addToStdlib.cast |
| import org.jetbrains.kotlin.utils.addToStdlib.lastIsInstanceOrNull |
| import org.jetbrains.kotlin.utils.ifEmpty |
| import kotlin.math.abs |
| import kotlin.math.ceil |
| import kotlin.math.min |
| |
/**
 * What a composable function knows about one of its parameters when deciding whether a
 * comparison is needed ("comparison propagation"). Every state is encoded in two bits of the
 * `$changed` bitmask.
 */
enum class ParamState(private val encoding: Int) {
    /**
     * Nothing is known about the parameter: it may or may not differ from the value seen during
     * the previous execution, so the function inspecting it has to call equals to find out.
     * This is the only state that can make the function spend slot table space on the
     * parameter.
     */
    Uncertain(0b00),

    /**
     * The value is known to be unchanged since the previous execution. The value does not need
     * to be stored in the slot table, because the caller *always* knows whether it is the same
     * as, or different from, the previous execution.
     */
    Same(0b01),

    /**
     * The value is known to have changed since the previous execution. The value does not need
     * to be stored in the slot table, because the caller *always* knows whether it is the same
     * as, or different from, the previous execution.
     */
    Different(0b10),

    /**
     * The value is known to *never change* while the program is running.
     */
    Static(0b11);

    /**
     * Returns this state's two-bit encoding moved into the position of [slot], by delegating to
     * the top-level `bitsForSlot(bits, slot)` helper.
     */
    fun bitsForSlot(slot: Int): Int {
        return bitsForSlot(encoding, slot)
    }
}
| |
/** Divisor used to spread `$default` bit indices over several int parameters. */
const val BITS_PER_INT = 31

/** Number of two-bit parameter slots addressed within a single `$changed` int. */
const val SLOTS_PER_INT = 15

/**
 * Positions [bits] at the location of [slot] inside a `$changed` int.
 *
 * The slot index wraps every [SLOTS_PER_INT] slots. Each slot is two bits wide and the lowest
 * bit of the int is left free, hence the shift of `2 * slot + 1`.
 */
fun bitsForSlot(bits: Int, slot: Int): Int = bits shl (2 * (slot % SLOTS_PER_INT) + 1)

/** Index of the `$default` int that holds the bit for the parameter at [index]. */
fun defaultsParamIndex(index: Int): Int {
    return index / BITS_PER_INT
}

/** Position, within its `$default` int, of the bit for the parameter at [index]. */
fun defaultsBitIndex(index: Int): Int {
    return index % BITS_PER_INT
}
| |
/**
 * Number of receiver parameters this function declares: one for a dispatch receiver plus one
 * for an extension receiver, so the result is 0, 1 or 2.
 */
val IrFunction.thisParamCount: Int
    get() {
        var receivers = 0
        if (dispatchReceiverParameter != null) receivers++
        if (extensionReceiverParameter != null) receivers++
        return receivers
    }
| |
/**
 * Number of `$changed` int parameters needed for a function with [realValueParams] value
 * parameters and [thisParams] receiver parameters.
 *
 * A function without value parameters always gets exactly one; otherwise one int is needed per
 * [SLOTS_PER_INT] parameters (receivers included), rounded up.
 */
fun changedParamCount(realValueParams: Int, thisParams: Int): Int = when (realValueParams) {
    0 -> 1
    else -> {
        val slotsNeeded = (realValueParams + thisParams).toDouble()
        ceil(slotsNeeded / SLOTS_PER_INT.toDouble()).toInt()
    }
}
| |
/**
 * Derives the number of `$changed` int parameters from the total parameter count of an already
 * transformed function ([totalParamsIncludingThisParams]).
 *
 * The composer parameter, the key parameter and the first `$changed` parameter (which is always
 * present) are subtracted first. What is left is consumed [SLOTS_PER_INT] at a time, with a
 * minimum result of 1.
 */
fun changedParamCountFromTotal(totalParamsIncludingThisParams: Int): Int {
    // composer param + key param + first changed param (always present)
    val remaining = totalParamsIncludingThisParams - 3
    // Closed form of "subtract SLOTS_PER_INT and count, until nothing positive is left".
    return if (remaining <= 0) 1 else 1 + (remaining - 1) / SLOTS_PER_INT
}
| |
/**
 * Number of `$default` int parameters needed for [realValueParams] value parameters: one per
 * [BITS_PER_INT] parameters, rounded up.
 */
fun defaultParamCount(realValueParams: Int): Int =
    ceil(realValueParams.toDouble() / BITS_PER_INT.toDouble()).toInt()
| |
/**
 * Total number of synthetic parameters added to a composable function: the composer, the key,
 * the `$changed` ints and, when [hasDefaults] is set, the `$default` ints.
 */
fun composeSyntheticParamCount(
    realValueParams: Int,
    thisParams: Int = 0,
    hasDefaults: Boolean = false
): Int {
    val composerAndKey = 2 // composer param + key param
    val changedInts = changedParamCount(realValueParams, thisParams)
    val defaultInts = if (hasDefaults) defaultParamCount(realValueParams) else 0
    return composerAndKey + changedInts + defaultInts
}
| |
/**
 * A readable `$changed`-style bitmask, exposed as a set of IR expression builders. The
 * transformer uses it both for the incoming changed parameter and for a function scope's
 * `dirty` value.
 *
 * NOTE(review): the implementations are not part of this section of the file; member
 * descriptions marked "presumably" are inferred from names and should be confirmed there.
 */
interface IrChangedBitMaskValue {
    /** Presumably an expression for the lowest bit of the mask. */
    fun irLowBit(): IrExpression

    /** Presumably an expression isolating the bits that belong to [slot]. */
    fun irIsolateBitsAtSlot(slot: Int): IrExpression

    /**
     * Expression used by the transformer as the "parameters differ from the last execution"
     * part of the should-execute check.
     */
    fun irHasDifferences(): IrExpression

    /**
     * Copies this mask into a temporary and returns it as a variable. The transformer uses this
     * to create the local `$dirty` copy of the changed parameter.
     */
    fun irCopyToTemporary(
        nameHint: String? = null,
        isVar: Boolean = false,
        exactName: Boolean = false
    ): IrChangedBitMaskVariable

    /**
     * Presumably stores this mask as value argument(s) of [fn] beginning at [startIndex], with
     * the low bit controlled by [lowBit].
     */
    fun putAsValueArgumentInWithLowBit(fn: IrFunctionAccessExpression, startIndex: Int, lowBit: Boolean)

    /** Presumably an expression moving the bits of [fromSlot] into the position of [toSlot]. */
    fun irShiftBits(fromSlot: Int, toSlot: Int): IrExpression
}
| |
/**
 * A readable `$default`-style bitmask, exposed as a set of IR expression builders.
 *
 * NOTE(review): the implementations are not part of this section of the file; the member
 * descriptions are inferred from names and should be confirmed there.
 */
interface IrDefaultBitMaskValue {
    /** Presumably an expression isolating the bit of the parameter at [index]. */
    fun irIsolateBitAtIndex(index: Int): IrExpression

    /**
     * Presumably an expression that is true when any parameter flagged in [unstable] was
     * provided by the caller.
     */
    fun irHasAnyProvidedAndUnstable(unstable: BooleanArray): IrExpression

    /** Presumably stores this mask as value argument(s) of [fn] beginning at [startIndex]. */
    fun putAsValueArgumentIn(fn: IrFunctionAccessExpression, startIndex: Int)
}
| |
/**
 * An [IrChangedBitMaskValue] that is backed by declared variable(s), as returned by
 * [IrChangedBitMaskValue.irCopyToTemporary].
 */
interface IrChangedBitMaskVariable : IrChangedBitMaskValue {
    /** The statements the transformer adds to a function preamble to introduce this variable. */
    fun asStatements(): List<IrStatement>

    /** Presumably an expression that ORs [value] into the bits of [slot] - confirm in the impl. */
    fun irOrSetBitsAtSlot(slot: Int, value: IrExpression): IrExpression
}
| |
| /** |
| * This IR Transform is responsible for the main transformations of the body of a composable |
| * function. |
| * |
| * 1. Control-Flow Group Generation |
| * 2. Default arguments |
| * 3. Composable Function Skipping |
| * 4. Comparison Propagation |
| * 5. Recomposability |
| * |
| * Control-Flow Group Generation |
| * ============================= |
| * |
| * This transform will insert groups inside of the bodies of Composable functions |
| * depending on the control-flow structures that exist inside of them. |
| * |
| * There are 3 types of groups in Compose: |
| * |
| * 1. Replaceable Groups |
| * 2. Movable Groups |
| * 3. Restart Groups |
| * |
| * Generally speaking, every composable function *must* emit a single group when it executes. |
| * Every group can have any number of children groups. Additionally, we analyze each executable |
| * block and apply the following rules: |
| * |
| * 1. If a block executes exactly 1 time always, no groups are needed |
| * 2. If a set of blocks are such that exactly one of them is executed exactly once (for example, |
| * the result blocks of a when clause), then we insert a replaceable group around each block. |
| * 3. A movable group is only needed if the immediate composable call in the group has a Pivotal |
| * property. |
| * |
| * Default Arguments |
| * ================= |
| * |
| * Composable functions need to have the default expressions executed inside of the group of the |
| * function. In order to accomplish this, composable functions handle default arguments |
| * themselves, instead of using the default handling of kotlin. This is also a win because we can |
| * handle the default arguments without generating an additional function since we do not need to |
| * worry about callers from java. Generally speaking though, compose handles default arguments |
| * similarly to kotlin in that we generate a $default bitmask parameter which maps each parameter |
| * index to a bit on the int. A value of "1" for a given parameter index indicates that the |
| * value was *not* provided at the callsite, and the default expression should be used instead. |
| * |
| * @Composable fun A(x: Int = 0) { |
| * f(x) |
| * } |
| * |
| * gets transformed into |
| * |
| * @Composable fun A(x: Int, $default: Int) { |
| * val x = if ($default and 0b1 != 0) 0 else x |
| * f(x) |
| * } |
| * |
| * Note: This transform requires [ComposerParamTransformer] to also be run in order to work |
| * properly. |
| * |
| * Composable Function Skipping |
| * ============================ |
| * |
| * Composable functions can "skip" their execution if certain conditions are met. This is done by |
| * appealing to the composer and storing previous values of functions and determining if we can |
| * skip based on whether or not they have changed. |
| * |
| * @Composable fun A(x: Int) { |
| * f(x) |
| * } |
| * |
| * gets transformed into |
| * |
| * @Composable fun A(x: Int, $composer: Composer<*>, $changed: Int) { |
| * var $dirty = $changed |
| * if ($changed and 0b0110 === 0) { |
| * $dirty = $dirty or if ($composer.changed(x)) 0b0100 else 0b0010 |
| * } |
| * if ($dirty and 0b1011 xor 0b1010 !== 0 || !$composer.skipping) { |
| * f(x) |
| * } else { |
| * $composer.skipToGroupEnd() |
| * } |
| * } |
| * |
| * Note that this makes use of bitmasks for the $changed and $dirty values. These bitmasks work |
| * in a different bit-space than the $default bitmask because two bits are needed to hold the |
| * four different possible states of each parameter. Additionally, the lowest bit of the bitmask |
| * is a special bit which forces execution of the function. |
| * |
| * This means that for the ith parameter of a composable function, the bit range of i*2 + 1 to |
| * i*2 + 2 are used to store the state of the parameter. |
| * |
| * The states are outlined by the [ParamState] class. |
| * |
| * Comparison Propagation |
| * ====================== |
| * |
| * Because we detect changes in parameters of composable functions and have that data available |
| * in the body of a composable function, if we pass values to another composable function, it |
| * makes sense for us to pass on whatever information about that value we can determine at the |
| * time. This type of propagation of information through composable functions is called |
| * Comparison Propagation. |
| * |
| * Essentially, this comes down to us passing in useful values into the `$changed` parameter of |
| * composable functions. |
| * |
| * When a composable function executes, we have the current known states of all of the function's |
| * parameters in the $dirty variable. We can take bits off of this variable and pass them into a |
| * composable function in order to tell that function what we know. |
| * |
| * @Composable fun A(x: Int) { |
| * B(x, 123) |
| * } |
| * |
| * gets transformed into |
| * |
| * @Composable fun A(x: Int, $composer: Composer<*>, $changed: Int) { |
| * var $dirty = ... |
| * // ... |
| * B( |
| * x, |
| * 123, |
| * $composer, |
| * (0b110 and $dirty) or // 1st param has same state that our 1st param does |
| * 0b11000 // 2nd parameter is "static" |
| * ) |
| * } |
| * |
| * Recomposability |
| * =============== |
| * |
| * Restartable composable functions get wrapped with "restart groups". Restart groups are like |
| * other groups except the end call is more complicated, as it returns a null value if and |
| * only if a subscription to that scope could not have occurred. If the value returned is |
| * non-null, we generate a lambda that teaches the runtime how to "restart" that group. At a high |
| * level, this transform comes down to: |
| * |
| * @Composable fun A(x: Int) { |
| * f(x) |
| * } |
| * |
| * getting transformed into |
| * |
| * @Composable fun A(x: Int, $composer: Composer<*>, $changed: Int) { |
| * $composer.startRestartGroup() |
| * // ... |
| * f(x) |
| * $composer.endRestartGroup()?.updateScope { next -> A(x, next, $changed or 0b1) } |
| * } |
| */ |
| class ComposableFunctionBodyTransformer( |
| context: IrPluginContext, |
| symbolRemapper: DeepCopySymbolRemapper, |
| bindingTrace: BindingTrace |
| ) : |
| AbstractComposeLowering(context, symbolRemapper, bindingTrace), |
| FileLoweringPass, |
| ModuleLoweringPass { |
| |
// Module-level entry point: runs this transformer over the whole module, then repairs the
// parent links of the declarations.
override fun lower(module: IrModuleFragment) {
    with(module) {
        transformChildrenVoid(this@ComposableFunctionBodyTransformer)
        patchDeclarationParents()
    }
}
| |
// File-level entry point: runs this transformer over every child of the file.
override fun lower(irFile: IrFile) = irFile.transformChildrenVoid(this)
| |
// Descriptors of the composer members this transform generates calls to. Each one is looked up
// by name on the composer type and disambiguated by its arity.

private val changedDescriptor = composerTypeDescriptor
    .unsubstitutedMemberScope
    .findFirstFunction("changed") { candidate ->
        // this is the changed(value: T) variant.
        // TODO(lmr): Add handling for different primitive types
        candidate.typeParameters.size == 1
    }

private val skipToGroupEndDescriptor = composerTypeDescriptor
    .unsubstitutedMemberScope
    .findFirstFunction("skipToGroupEnd") { candidate -> candidate.valueParameters.isEmpty() }

private val skipCurrentGroupDescriptor = composerTypeDescriptor
    .unsubstitutedMemberScope
    .findFirstFunction("skipCurrentGroup") { candidate -> candidate.valueParameters.isEmpty() }

private val startReplaceableDescriptor = composerTypeDescriptor
    .unsubstitutedMemberScope
    .findFirstFunction("startReplaceableGroup") { candidate ->
        candidate.valueParameters.size == 1
    }

private val endReplaceableDescriptor = composerTypeDescriptor
    .unsubstitutedMemberScope
    .findFirstFunction("endReplaceableGroup") { candidate ->
        candidate.valueParameters.isEmpty()
    }

private val startDefaultsDescriptor = composerTypeDescriptor
    .unsubstitutedMemberScope
    .findFirstFunction("startDefaults") { candidate -> candidate.valueParameters.isEmpty() }

private val endDefaultsDescriptor = composerTypeDescriptor
    .unsubstitutedMemberScope
    .findFirstFunction("endDefaults") { candidate -> candidate.valueParameters.isEmpty() }

private val startMovableDescriptor = composerTypeDescriptor
    .unsubstitutedMemberScope
    .findFirstFunction("startMovableGroup") { candidate -> candidate.valueParameters.size == 2 }

private val endMovableDescriptor = composerTypeDescriptor
    .unsubstitutedMemberScope
    .findFirstFunction("endMovableGroup") { candidate -> candidate.valueParameters.isEmpty() }

private val startRestartGroupDescriptor = composerTypeDescriptor
    .unsubstitutedMemberScope
    .findFirstFunction(KtxNameConventions.STARTRESTARTGROUP.identifier) { candidate ->
        candidate.valueParameters.size == 1
    }

private val endRestartGroupDescriptor = composerTypeDescriptor
    .unsubstitutedMemberScope
    .findFirstFunction(KtxNameConventions.ENDRESTARTGROUP.identifier) { candidate ->
        candidate.valueParameters.isEmpty()
    }

// `updateScope` is resolved on the return type of endRestartGroup; the overload wanted is the
// single one whose first parameter type has four type arguments.
private val updateScopeDescriptor = run {
    val candidates = endRestartGroupDescriptor.returnType
        ?.memberScope
        ?.getContributedFunctions(KtxNameConventions.UPDATE_SCOPE, NoLookupLocation.FROM_BACKEND)
    candidates?.singleOrNull { fn -> fn.valueParameters.first().type.arguments.size == 4 }
        ?: error("new updateScope not found in result type of endRestartGroup")
}

private val updateScopeBlockType = updateScopeDescriptor.valueParameters.single().type

private val isSkippingDescriptor = composerTypeDescriptor
    .unsubstitutedMemberScope
    .getContributedDescriptors { name -> name.asString() == "skipping" }
    .filterIsInstance<PropertyDescriptor>()
    .first { property -> property.name.asString() == "skipping" }

private val defaultsInvalidDescriptor = composerTypeDescriptor
    .unsubstitutedMemberScope
    .getContributedDescriptors { name -> name.asString() == "defaultsInvalid" }
    .filterIsInstance<PropertyDescriptor>()
    .first { property -> property.name.asString() == "defaultsInvalid" }

private val joinKeyDescriptor = composerTypeDescriptor
    .unsubstitutedMemberScope
    .findFirstFunction(KtxNameConventions.JOINKEY.identifier) { candidate ->
        candidate.valueParameters.size == 2
    }
| |
// Scopes currently being visited. The visit* overrides push a scope on entry and pop it on exit.
private val scopeStack: MutableList<Scope> = mutableListOf()

// Renders the scope stack for error messages: one scope name per line, in stack order.
private fun printScopeStack(): String {
    val dump = StringBuilder()
    scopeStack.forEach { scope -> dump.appendln(scope.name) }
    return dump.toString()
}
| |
// True when the nearest enclosing function scope is composable. Walking from the innermost
// scope outwards: a function scope decides the answer, any other block scope is skipped, and
// every remaining kind of scope (or an empty stack) means "not composable".
private val isInComposableScope: Boolean
    get() {
        for (enclosing in scopeStack.asReversed()) {
            // The FunctionScope test must come before the BlockScope test, as in the
            // original `when`, in case a function scope is also a block scope.
            if (enclosing is Scope.FunctionScope) return enclosing.isComposable
            if (enclosing !is Scope.BlockScope) return false
        }
        return false
    }
| |
// The innermost function scope on the stack. Fails with a dump of the stack when there is none.
private val currentFunctionScope: Scope.FunctionScope
    get() {
        val nearest = scopeStack.lastIsInstanceOrNull<Scope.FunctionScope>()
        return nearest
            ?: error("Expected a FunctionScope but none exist. \n${printScopeStack()}")
    }
| |
// Visits a class with a ClassScope on the stack, and verifies on the way out that the same
// scope is still on top.
override fun visitClass(declaration: IrClass): IrStatement {
    val classScope = Scope.ClassScope(declaration.name)
    return try {
        scopeStack.push(classScope)
        super.visitDeclaration(declaration)
    } finally {
        val top = scopeStack.pop()
        require(top == classScope) { "Unbalanced scope stack" }
    }
}
| |
// Visits a function with a FunctionScope on the stack, and verifies on the way out that the
// same scope is still on top.
override fun visitFunction(declaration: IrFunction): IrStatement {
    val functionScope = Scope.FunctionScope(declaration, this)
    return try {
        scopeStack.push(functionScope)
        visitFunctionInScope(declaration)
    } finally {
        require(scopeStack.pop() == functionScope) { "Unbalanced scope stack" }
    }
}
| |
// Chooses the rewrite for a function whose scope is already on the stack: tracked lambdas,
// tracked restartable functions, and everything else each get their own treatment.
private fun visitFunctionInScope(declaration: IrFunction): IrStatement {
    val functionScope = currentFunctionScope

    // Non-composable functions are left to the default traversal.
    if (!functionScope.isComposable) {
        return super.visitFunction(declaration)
    }

    val canRestart = declaration.shouldBeRestartable()
    val lambda = declaration.isLambda()
    // An untracked lambda gets the body of a non-restartable function, since the
    // group/update scope is not going to be handled by the RestartableFunction class.
    val tracked = !declaration.descriptor.hasUntrackedAnnotation()

    // A declaration without a body has nothing to rewrite.
    if (declaration.body == null) {
        return declaration
    }

    val changed = functionScope.changedParameter!!
    val defaults = functionScope.defaultParameter

    // Restartable functions get extra logic and different types of groups from
    // non-restartable functions, and lambdas get no groups at all.
    if (lambda && tracked) {
        return visitComposableLambda(declaration, functionScope, changed)
    }
    if (canRestart && tracked) {
        return visitRestartableComposableFunction(declaration, functionScope, changed, defaults)
    }
    return visitNonRestartableComposableFunction(declaration, functionScope, changed, defaults)
}
| |
// Currently, all composable functions are restartable by default, unless:
// 1. They are inline
// 2. They have a return value (may get relaxed in the future)
// 3. They are a lambda (we use RestartableFunction<...> class for this instead)
// 4. They are annotated as @Direct
// Functions with no body or no composer parameter are never restartable either.
private fun IrFunction.shouldBeRestartable(): Boolean {
    // Only insert observe scopes in non-empty composable function
    if (body == null) return false

    val fnDescriptor = descriptor

    // Do not insert observe scope in an inline function or in a @Direct function
    if (fnDescriptor.isInline || fnDescriptor.hasDirectAnnotation()) return false

    // Do not insert an observe scope in an inline composable lambda
    val literal = fnDescriptor.findPsi() as? KtFunctionLiteral
    if (literal != null) {
        if (InlineUtil.isInlinedArgument(literal, context.bindingContext, false)) return false
        if (literal.isEmitInline(context.bindingContext)) return false
    }

    // Do not insert an observe scope if the function has a return result
    val declaredReturnType = fnDescriptor.returnType
    if (declaredReturnType == null || !declaredReturnType.isUnit()) return false

    // Do not insert an observe scope if the function hasn't been transformed by the
    // ComposerParamTransformer and has a synthetic "composer param" as its last parameter
    if (composerParam() == null) return false

    // Lambdas should be ignored. All composable lambdas are wrapped by a restartable
    // function wrapper by ComposerLambdaMemoization which supplies the startRestartGroup/
    // endRestartGroup pair on behalf of the lambda.
    val hasLambdaOrigin = origin == IrDeclarationOrigin.LOCAL_FUNCTION_FOR_LAMBDA ||
        origin == IrDeclarationOrigin.LOCAL_FUNCTION_NO_CLOSURE

    // Check if the descriptor has restart scope calls resolved
    return fnDescriptor is SimpleFunctionDescriptor && !hasLambdaOrigin
}
| |
// A function counts as a lambda when its descriptor carries the "<anonymous>" name.
// There is probably a better way to determine this, but if there is, it isn't obvious.
private fun IrFunction.isLambda(): Boolean {
    val functionName = descriptor.name.asString()
    return functionName == "<anonymous>"
}
| |
// At a high level, a non-restartable composable function
// 1. gets a replaceable group placed around the body
// 2. never calls `$composer.changed(...)` with its parameters
// 3. can have default parameters, so needs to add the defaults preamble if defaults present
// 4. proper groups around control flow structures in the body
private fun visitNonRestartableComposableFunction(
    declaration: IrFunction,
    scope: Scope.FunctionScope,
    changedParam: IrChangedBitMaskValue,
    defaultParam: IrDefaultBitMaskValue?
): IrStatement {
    val originalBody = declaration.body!!

    val skipStatements = mutableStatementContainer()
    val preambleStatements = mutableStatementContainer()

    // No local copy of the mask is made here: the incoming changed parameter doubles as the
    // scope's dirty value.
    scope.dirty = changedParam

    val userParams = declaration.valueParameters.take(scope.realValueParamCount)

    buildStatementsForSkippingAndDefaults(
        originalBody,
        skipStatements,
        preambleStatements,
        false,
        userParams,
        scope,
        changedParam,
        changedParam,
        defaultParam,
        booleanArrayOf()
    )

    // we want to remove the default expression from the function. This will prevent
    // the kotlin compiler from doing its own default handling, which we don't need.
    for (param in userParams) {
        param.defaultValue = null
    }

    val (bodyWithoutReturn, resultVar) = originalBody.asBodyAndResultVar()
    val transformedBody = bodyWithoutReturn.transformChildren()

    scope.realizeGroup(::irEndReplaceableGroup)

    // start group, skip preamble, body preamble, transformed body, end group, optional return
    val newStatements = mutableListOf<IrStatement>()
    newStatements.add(irStartReplaceableGroup(originalBody, irGet(scope.keyParameter!!)))
    newStatements.addAll(skipStatements.statements)
    newStatements.addAll(preambleStatements.statements)
    newStatements.addAll(transformedBody.statements)
    newStatements.add(irEndReplaceableGroup())
    if (resultVar != null) {
        newStatements.add(irReturn(declaration.symbol, irGet(resultVar)))
    }

    declaration.body = IrBlockBodyImpl(
        originalBody.startOffset,
        originalBody.endOffset,
        newStatements
    )

    return declaration
}
| |
// Composable lambdas are always wrapped with a RestartableFunction class, which has its own
// group in the invoke call. As a result, composable lambdas:
// 1. receive no group at the root of their body
// 2. cannot have default parameters, so have no default handling
// 3. can still skip their body: when every real parameter has a stable type and the lambda
//    returns Unit (and the extension receiver handling does not veto it), the body is wrapped
//    in an "if (has differences || !skipping) body else skipToGroupEnd" check
// 4. proper groups around control flow structures in the body
private fun visitComposableLambda(
    declaration: IrFunction,
    scope: Scope.FunctionScope,
    changedParam: IrChangedBitMaskValue
): IrStatement {
    // no group, since restartableFunction should already create one
    // no default logic
    val body = declaration.body!!
    val skipPreamble = mutableStatementContainer()
    val bodyPreamble = mutableStatementContainer()

    val realParams = declaration.valueParameters.take(scope.realValueParamCount)

    // true when there is at least one real parameter or an extension receiver to test
    val hasParamsToTest =
        realParams.isNotEmpty() || declaration.extensionReceiverParameter != null

    // boolean array mapped to parameters. true indicates that the type is unstable
    val unstableMask = realParams.map {
        !it.type.toKotlinType().isStable()
    }.toBooleanArray()

    // we start off assuming that we *can* skip execution of the function
    var canSkipExecution = unstableMask.none { it } && declaration.returnType.isUnit()

    // if the function can never skip, or there are no parameters to test, then we
    // don't need to have the dirty parameter locally since it will never be different from
    // the passed in `changed` parameter.
    val dirty = if (canSkipExecution && hasParamsToTest)
        // NOTE(lmr): Technically, dirty is a mutable variable, but we don't want to mark it
        // as one since that will cause a `Ref<Int>` to get created if it is captured. Since
        // we know we will never be mutating this variable _after_ it gets captured, we can
        // safely mark this as `isVar = false`.
        changedParam.irCopyToTemporary(
            isVar = false,
            nameHint = "\$dirty",
            exactName = true
        ).also {
            skipPreamble.statements.addAll(it.asStatements())
        }
    else
        changedParam

    scope.dirty = dirty

    buildStatementsForSkippingAndDefaults(
        body,
        skipPreamble,
        bodyPreamble,
        canSkipExecution,
        realParams,
        scope,
        dirty,
        changedParam,
        null,
        unstableMask
    )

    val (nonReturningBody, returnVar) = body.asBodyAndResultVar()

    // we must transform the body first, since that will allow us to see whether or not we
    // are using the dispatchReceiverParameter or the extensionReceiverParameter
    val transformed = nonReturningBody.transformChildren()

    // The receiver is read again here, after the body transform, exactly as before.
    declaration.extensionReceiverParameter?.let { receiver ->
        canSkipExecution = buildStatementsForSkippingThisParameter(
            receiver,
            scope.extensionReceiverUsed,
            canSkipExecution,
            skipPreamble,
            changedParam,
            dirty,
            scope.realValueParamCount
        )
    }

    val bodyStatement: IrStatement = if (canSkipExecution) {
        // We CANNOT skip if any of the following conditions are met
        // 1. if any of the stable parameters have *differences* from last execution.
        // 2. if the composer.skipping call returns false
        val shouldExecute = irOrOr(
            // read back from the scope rather than the local `dirty`, as the original did
            scope.dirty!!.irHasDifferences(),
            irNot(irIsSkipping())
        )
        irIfThenElse(
            condition = shouldExecute,
            thenPart = irBlock(
                type = context.irBuiltIns.unitType,
                statements = transformed.statements
            ),
            elsePart = irSkipToGroupEnd()
        )
    } else {
        // Not skippable: the transformed body is emitted as-is.
        transformed
    }

    // Both variants share the same layout: skip preamble, body preamble, body, optional return.
    declaration.body = IrBlockBodyImpl(
        body.startOffset,
        body.endOffset,
        listOfNotNull(
            *skipPreamble.statements.toTypedArray(),
            *bodyPreamble.statements.toTypedArray(),
            bodyStatement,
            returnVar?.let { irReturn(declaration.symbol, irGet(it)) }
        )
    )

    return declaration
}
| |
// Restartable composable functions are the common case. For such a function this transform:
//   1. wraps the body in a startRestartGroup call and an endRestartGroup call
//   2. emits an updateScope call with a lambda that invokes the function again
//   3. emits default parameter handling where required
//   4. emits skipping logic derived from the parameters passed into the function
//   5. emits groups around control flow structures found in the body
private fun visitRestartableComposableFunction(
    declaration: IrFunction,
    scope: Scope.FunctionScope,
    changedParam: IrChangedBitMaskValue,
    defaultParam: IrDefaultBitMaskValue?
): IrStatement {
    val body = declaration.body!!
    val skipPreamble = mutableStatementContainer()
    val bodyPreamble = mutableStatementContainer()

    // Only the leading `realValueParamCount` value parameters take part in the skipping
    // calculations; the synthetic parameters generated for compose follow them.
    val realParams = declaration.valueParameters.take(scope.realValueParamCount)
    val receiverCount = listOfNotNull(
        declaration.extensionReceiverParameter,
        declaration.dispatchReceiverParameter
    ).size
    val hasParamsToTest = realParams.size + receiverCount > 0

    // One entry per real parameter; true means the parameter's type is unstable.
    val unstableMask = BooleanArray(realParams.size) { i ->
        val param = realParams[i]
        !(param.varargElementType ?: param.type).toKotlinType().isStable()
    }

    // We assume the function can be skipped unless it has an unstable parameter that has no
    // default value, in which case it can never skip.
    var canSkipExecution = realParams.indices.none { i ->
        unstableMask[i] && !realParams[i].hasDefaultValue()
    }

    // A local copy of `changed` is only needed when the function may skip and there is at
    // least one parameter to test; otherwise it could never differ from `changed` itself.
    val dirty = if (canSkipExecution && hasParamsToTest) {
        // NOTE(lmr): Technically, dirty is a mutable variable, but we don't want to mark it
        // as one since that will cause a `Ref<Int>` to get created if it is captured. Since
        // we know we will never be mutating this variable _after_ it gets captured, we can
        // safely mark this as `isVar = false`.
        changedParam
            .irCopyToTemporary(isVar = false, nameHint = "\$dirty", exactName = true)
            .also { copy -> skipPreamble.statements.addAll(copy.asStatements()) }
    } else {
        changedParam
    }

    scope.dirty = dirty

    buildStatementsForSkippingAndDefaults(
        sourceElement = body,
        skipPreamble = skipPreamble,
        bodyPreamble = bodyPreamble,
        canSkipExecution = canSkipExecution,
        parameters = realParams,
        scope = scope,
        dirty = dirty,
        changedParam = changedParam,
        defaultParam = defaultParam,
        unstableMask = unstableMask
    )

    // Drop the default expressions so the kotlin compiler does not do its own default
    // handling. This has to come AFTER buildStatementsForSkippingAndDefaults, which reads
    // the default values.
    for (param in realParams) {
        param.defaultValue = null
    }

    val bodyAndResult = body.asBodyAndResultVar()
    val nonReturningBody = bodyAndResult.first
    val returnVar = bodyAndResult.second

    val end = {
        irEndRestartGroupAndUpdateScope(
            scope,
            changedParam,
            defaultParam,
            scope.realValueParamCount
        )
    }

    // The body is transformed before the receivers are examined, because only afterwards do
    // we know whether the dispatch receiver or the extension receiver is used.
    val transformed = nonReturningBody.transformChildren()

    // Receivers occupy the slots directly after the real value parameters: the extension
    // receiver first (when present), then the dispatch receiver.
    val firstReceiverSlot = scope.realValueParamCount

    val extensionReceiver = declaration.extensionReceiverParameter
    if (extensionReceiver != null) {
        canSkipExecution = buildStatementsForSkippingThisParameter(
            thisParam = extensionReceiver,
            isUsed = scope.extensionReceiverUsed,
            canSkipExecution = canSkipExecution,
            preamble = skipPreamble,
            changedParam = changedParam,
            dirty = dirty,
            index = firstReceiverSlot
        )
    }

    val dispatchReceiver = declaration.dispatchReceiverParameter
    if (dispatchReceiver != null) {
        canSkipExecution = buildStatementsForSkippingThisParameter(
            thisParam = dispatchReceiver,
            isUsed = scope.dispatchReceiverUsed,
            canSkipExecution = canSkipExecution,
            preamble = skipPreamble,
            changedParam = changedParam,
            dirty = dirty,
            index = if (extensionReceiver != null) firstReceiverSlot + 1 else firstReceiverSlot
        )
    }

    // When the function can never skip, the body always runs. Otherwise the body is wrapped
    // in an if and is skipped only when certain conditions are met.
    val transformedBody = if (canSkipExecution) {
        // We CANNOT skip if any of the following conditions are met
        // 1. if any of the stable parameters have *differences* from last execution.
        // 2. if the composer.skipping call returns false
        // 3. if any of the provided parameters to the function were unstable
        val differentOrNotSkipping = irOrOr(
            scope.dirty!!.irHasDifferences(),
            irNot(irIsSkipping())
        )

        // (3) is only generated when there actually are unstable params. It is placed first
        // so that a provided unstable param short-circuits and always executes the body.
        val shouldExecute = if (unstableMask.any { it } && defaultParam != null) {
            irOrOr(
                defaultParam.irHasAnyProvidedAndUnstable(unstableMask),
                differentOrNotSkipping
            )
        } else {
            differentOrNotSkipping
        }

        irIfThenElse(
            condition = shouldExecute,
            thenPart = irBlock(
                statements = bodyPreamble.statements + transformed.statements
            ),
            elsePart = irSkipToGroupEnd()
        )
    } else {
        irComposite(
            statements = bodyPreamble.statements + transformed.statements
        )
    }

    scope.realizeGroup(end)

    val newStatements = mutableListOf<IrStatement>()
    newStatements.add(irStartRestartGroup(body, irGet(scope.keyParameter!!)))
    newStatements.addAll(skipPreamble.statements)
    newStatements.add(transformedBody)
    if (returnVar == null) {
        // no result variable was extracted from the body, so the group is ended here
        newStatements.add(end())
    } else {
        newStatements.add(irReturn(declaration.symbol, irGet(returnVar)))
    }

    declaration.body = IrBlockBodyImpl(body.startOffset, body.endOffset, newStatements)

    return declaration
}
| |
/**
 * Emits the skipping statements for a receiver parameter (extension or dispatch receiver) and
 * reports whether the function can still skip execution afterwards.
 *
 * @param thisParam the receiver parameter being considered
 * @param isUsed whether the receiver is referenced by the function
 * @param canSkipExecution whether the function was considered skippable so far
 * @param preamble container that receives the generated statements
 * @param changedParam the incoming `changed` bit mask, consulted for the "Uncertain" state
 * @param dirty the bit mask that is updated; statements are only generated when it is an
 * [IrChangedBitMaskVariable]
 * @param index the slot of the receiver inside the bit masks
 * @return `false` when the receiver is unstable and used; `true` when statements were
 * generated; otherwise [canSkipExecution] unchanged
 */
private fun buildStatementsForSkippingThisParameter(
    thisParam: IrValueParameter,
    isUsed: Boolean,
    canSkipExecution: Boolean,
    preamble: IrStatementContainer,
    changedParam: IrChangedBitMaskValue,
    dirty: IrChangedBitMaskValue,
    index: Int
): Boolean {
    val type = thisParam.type
    val isStable = type.toKotlinType().isStable()

    return when {
        // an unstable receiver that is actually used means the function can never skip
        !isStable && isUsed -> false
        isStable && isUsed && canSkipExecution && dirty is IrChangedBitMaskVariable -> {
            preamble.statements.add(irIf(
                // we only call `$composer.changed(...)` on a parameter if the value came in
                // with an "Uncertain" state AND the value was provided. This is safe to do
                // because this will remain true or false for *every* execution of the
                // function, so we will never get a slot table misalignment as a result.
                condition = irIsUncertain(changedParam, index),
                body = dirty.irOrSetBitsAtSlot(
                    index,
                    irIfThenElse(
                        context.irBuiltIns.intType,
                        irChanged(irGet(thisParam)),
                        // if the value has changed, update the bits in the slot to be
                        // "Different"
                        thenPart = irConst(ParamState.Different.bitsForSlot(index)),
                        // if the value has not changed, update the bits in the slot to
                        // be "Same"
                        elsePart = irConst(ParamState.Same.bitsForSlot(index))
                    )
                )
            ))
            true
        }
        !isUsed && canSkipExecution && dirty is IrChangedBitMaskVariable -> {
            // if the param isn't used we can safely ignore it, but if we can skip the
            // execution of the function, then we need to make sure that we are only
            // considering the not-ignored parameters. to do this, we set the changed slot bits
            // to Static
            preamble.statements.add(dirty.irOrSetBitsAtSlot(
                index,
                irConst(ParamState.Static.bitsForSlot(index))
            ))
            // The function remains skippable. This is stated explicitly so that the result
            // of this branch is not the incidental Boolean returned by `add`.
            true
        }
        // nothing changes
        else -> canSkipExecution
    }
}
| |
/**
 * Fills [skipPreamble] and [bodyPreamble] with the statements that perform change detection
 * on [parameters] and that evaluate their default expressions.
 *
 * The work happens in this order:
 *  1. Every parameter that has a default expression (when [defaultParam] is present) is
 *     shadowed by a temporary variable carrying the parameter's name. [scope] is updated so
 *     that each parameter maps to the declaration to read and to its slot index.
 *  2. When the function can skip and [dirty] is an [IrChangedBitMaskVariable], the `changed`
 *     checks for non-vararg parameters are emitted, followed by the checks for vararg
 *     parameters (which are wrapped in a replaceable group).
 *  3. Default assignments deferred in step 1 are added to [bodyPreamble], guarded by a
 *     condition unless the function never skips or all defaults are static.
 *
 * @param sourceElement element passed to the generated `startDefaults` call
 * @param skipPreamble receives temporary declarations and change-detection statements
 * @param bodyPreamble receives the default assignments
 * @param canSkipExecution whether the function was determined to be skippable
 * @param parameters the real value parameters, in slot order
 * @param scope function scope whose `remappedParams` and `paramsToSlots` are populated
 * @param dirty the bit mask updated by the generated statements
 * @param changedParam the incoming `changed` bit mask
 * @param defaultParam the default bit mask, or null when the function has none
 * @param unstableMask per-parameter flags; true marks a parameter as unstable
 */
private fun buildStatementsForSkippingAndDefaults(
    sourceElement: IrElement,
    skipPreamble: IrStatementContainer,
    bodyPreamble: IrStatementContainer,
    canSkipExecution: Boolean,
    parameters: List<IrValueParameter>,
    scope: Scope.FunctionScope,
    dirty: IrChangedBitMaskValue,
    changedParam: IrChangedBitMaskValue,
    defaultParam: IrDefaultBitMaskValue?,
    unstableMask: BooleanArray
) {
    // we default to true because the absence of a default expression we want to consider as
    // "static"
    val defaultExprIsStatic = BooleanArray(parameters.size) { true }

    // first we create the necessary local variables for default handling.
    val setDefaults = mutableStatementContainer()
    parameters.forEachIndexed { index, param ->
        val defaultValue = param.defaultValue
        if (defaultParam == null || defaultValue == null) {
            // nothing to shadow: the parameter itself is what gets read
            scope.remappedParams[param] = param
            scope.paramsToSlots[param] = index
            return@forEachIndexed
        }

        // we need to ensure that this transform runs on the default expression. It
        // could contain conditional logic as well as composable calls
        val transformedDefault = defaultValue.expression.transform(this, null)

        // we want to call this on the transformed version.
        val isStaticDefault = transformedDefault.isStatic()
        defaultExprIsStatic[index] = isStaticDefault

        // create a new temporary variable with the same name as the parameter itself
        // initialized to the parameter value.
        val varSymbol = if (!canSkipExecution || isStaticDefault) {
            // If we can't skip execution, or if the expression is static, there's no need
            // to separate the assignment of the temporary and the declaration.
            irTemporary(
                irIfThenElse(
                    param.type,
                    condition = irGetBit(defaultParam, index),
                    thenPart = transformedDefault,
                    elsePart = irGet(param)
                ),
                param.name.identifier,
                param.type,
                isVar = false,
                exactName = true
            )
        } else {
            // If we can skip execution, we want to only execute the default expression
            // in certain cases. as a result, we first create the temp variable, and then
            // add the logic to set it to the "setDefaults" container.
            irTemporary(
                irGet(param),
                param.name.identifier,
                param.type,
                // NOTE(lmr): technically, we end up mutating this variable in the body of
                // the function. It turns out that the isVar doesn't validate this, but
                // it does cause the variable to be wrapped in a `Ref` object if it is
                // captured by a closure. We do NOT want that, and we know that the code
                // will be correct without it, so we set `isVar = false` here.
                isVar = false,
                exactName = true
            ).also {
                setDefaults.statements.add(
                    irIf(
                        condition = irGetBit(defaultParam, index),
                        body = irSet(it, transformedDefault)
                    )
                )
            }
        }

        // semantically, any reference to the parameter symbol now needs to be remapped
        // to the temporary variable.
        scope.remappedParams[param] = varSymbol

        // in order to propagate the change detection we might perform on this parameter,
        // we need to know which "slot" it is in
        scope.paramsToSlots[varSymbol] = index
        skipPreamble.statements.add(varSymbol)
    }

    // Change detection is only generated when the function can skip and there is a local
    // dirty variable to update. Neither condition depends on the parameter being visited,
    // so it is checked once for both loops below.
    if (canSkipExecution && dirty is IrChangedBitMaskVariable) {
        // we start the skipPreamble with all of the changed calls. These need to go at the top
        // of the function's group. Note that these end up getting called *before* default
        // expressions, but this is okay because it will only ever get called on parameters that
        // are provided to the function
        parameters.forEachIndexed { index, param ->
            // varargs get handled separately because they will require their own groups
            if (param.isVararg) return@forEachIndexed

            if (unstableMask[index]) {
                if (defaultParam != null && param.defaultValue != null) {
                    // when the default is in use for an unstable parameter, its slot is
                    // marked as "Same"
                    skipPreamble.statements.add(
                        irIf(
                            condition = irGetBit(defaultParam, index),
                            body = dirty.irOrSetBitsAtSlot(
                                index,
                                irConst(ParamState.Same.bitsForSlot(index))
                            )
                        )
                    )
                }

                // if the value is unstable, there is no reason for us to store it in the
                // slot table
                return@forEachIndexed
            }

            val defaultValueIsStatic = defaultExprIsStatic[index]
            val callChanged = irChanged(irGet(scope.remappedParams[param]!!))
            val isChanged = if (defaultParam != null && !defaultValueIsStatic)
                irAndAnd(irIsProvided(defaultParam, index), callChanged)
            else
                callChanged
            val modifyDirtyFromChangedResult = dirty.irOrSetBitsAtSlot(
                index,
                irIfThenElse(
                    context.irBuiltIns.intType,
                    isChanged,
                    // if the value has changed, update the bits in the slot to be
                    // "Different"
                    thenPart = irConst(ParamState.Different.bitsForSlot(index)),
                    // if the value has not changed, update the bits in the slot to
                    // be "Same"
                    elsePart = irConst(ParamState.Same.bitsForSlot(index))
                )
            )

            val stmt = if (defaultParam != null && defaultValueIsStatic)
                // if the default expression is "static", then we know that if we are using the
                // default expression, the parameter can be considered "static".
                irWhen(
                    origin = IrStatementOrigin.IF,
                    branches = listOf(
                        irBranch(
                            condition = irGetBit(defaultParam, index),
                            result = dirty.irOrSetBitsAtSlot(
                                index,
                                irConst(ParamState.Static.bitsForSlot(index))
                            )
                        ),
                        irBranch(
                            condition = irIsUncertain(changedParam, index),
                            result = modifyDirtyFromChangedResult
                        )
                    )
                )
            else
                irIf(
                    // we only call `$composer.changed(...)` on a parameter if the value came in
                    // with an "Uncertain" state AND the value was provided. This is safe to do
                    // because this will remain true or false for *every* execution of the
                    // function, so we will never get a slot table misalignment as a result.
                    condition = irIsUncertain(changedParam, index),
                    body = modifyDirtyFromChangedResult
                )
            skipPreamble.statements.add(stmt)
        }

        // now we handle the vararg parameters specially since it needs to create a group
        parameters.forEachIndexed { index, param ->
            val varargElementType = param.varargElementType ?: return@forEachIndexed

            // if the value is unstable, there is no reason for us to store it in the slot table
            if (unstableMask[index]) return@forEachIndexed

            // for vararg parameters of stable type, we can store each value in the slot
            // table, but need to generate a group since the size of the array could change
            // over time. In the future, we may want to make an optimization where whether or
            // not the call site had a spread or not and only create groups if it did.

            // composer.startReplaceableGroup(values.size)
            val irGetParamSize = irMethodCall(
                irGet(param),
                param.type.classOrNull!!.getPropertyGetter("size")!!.descriptor
            )
            // TODO(lmr): verify this works with default vararg expressions!
            skipPreamble.statements.add(irStartReplaceableGroup(param, irGetParamSize))

            // for (value in values) {
            //     dirty = dirty or if (composer.changed(value)) 0b0100 else 0b0000
            // }
            skipPreamble.statements.add(irForLoop(
                scope.function.symbol.descriptor,
                varargElementType,
                irGet(param)
            ) { loopVar ->
                dirty.irOrSetBitsAtSlot(
                    index,
                    irIfThenElse(
                        context.irBuiltIns.intType,
                        irChanged(irGet(loopVar)),
                        // if the value has changed, update the bits in the slot to be
                        // "Different".
                        thenPart = irConst(ParamState.Different.bitsForSlot(index)),
                        // if the value has not changed, we are still uncertain if the entire
                        // list of values has gone unchanged or not, so we use Uncertain
                        elsePart = irConst(ParamState.Uncertain.bitsForSlot(index))
                    )
                )
            })

            // composer.endReplaceableGroup()
            skipPreamble.statements.add(irEndReplaceableGroup())

            // if (dirty and 0b0110 === 0) {
            //     dirty = dirty or 0b0010
            // }
            skipPreamble.statements.add(irIf(
                condition = irIsUncertain(dirty, index),
                body = dirty.irOrSetBitsAtSlot(
                    index,
                    irConst(ParamState.Same.bitsForSlot(index))
                )
            ))
        }
    }

    // after all of this, we need to potentially wrap the default setters in a group and if
    // statement, to make sure that defaults are only executed when they need to be.
    if (!canSkipExecution || defaultExprIsStatic.all { it }) {
        // if we don't skip execution ever, then we don't need these groups at all.
        // Additionally, if all of the defaults are static, we can avoid creating the groups
        // as well.
        // NOTE(lmr): should we still wrap this in an if statement to be safe???
        bodyPreamble.statements.addAll(setDefaults.statements)
    } else if (setDefaults.statements.isNotEmpty()) {
        // otherwise, we wrap the whole thing in an if expression with a skip
        bodyPreamble.statements.add(
            irIfThenElse(
                // this prevents us from re-executing the defaults if this function is getting
                // executed from a recomposition
                // if (%changed and 0b0001 === 0 || %composer.defaultsInvalid) {
                condition = irOrOr(
                    irEqual(changedParam.irLowBit(), irConst(0)),
                    irDefaultsInvalid()
                ),
                // set all of the default temp vars
                thenPart = irBlock(
                    statements = listOf(
                        irStartDefaults(sourceElement),
                        *setDefaults.statements.toTypedArray(),
                        irEndDefaults()
                    )
                ),
                // composer.skipCurrentGroup()
                elsePart = irSkipCurrentGroup()
            )
        )
    }
}
| |
/**
 * Builds the expression that ends the restart group of [scope]'s function and safe-calls
 * `updateScope` on the result of `endRestartGroup()`, passing a lambda that calls the
 * function again with the same arguments.
 *
 * @param scope scope of the function being made restartable
 * @param changedParam bit mask forwarded to the re-invocation with its low bit set
 * @param defaultParam default bit mask forwarded to the re-invocation, if the function has one
 * @param numRealValueParameters number of value parameters that precede the synthetic ones
 */
private fun irEndRestartGroupAndUpdateScope(
    scope: Scope.FunctionScope,
    changedParam: IrChangedBitMaskValue,
    defaultParam: IrDefaultBitMaskValue?,
    numRealValueParameters: Int
): IrExpression {
    val function = scope.function

    // The dispatch receiver is copied into a temporary that lives in the outer scope because
    // direct references to the receiver sometimes cause an invalid name, "$<this>", to be
    // generated.
    val outerReceiver = function.dispatchReceiverParameter?.let { receiver ->
        irTemporary(
            value = irGet(receiver),
            nameHint = "rcvr"
        )
    }

    // Descriptor of the self-invoking lambda and of its three parameters.
    val lambdaDescriptor = AnonymousFunctionDescriptor(
        function.descriptor,
        Annotations.EMPTY,
        CallableMemberDescriptor.Kind.DECLARATION,
        SourceElement.NO_SOURCE,
        false
    )

    val composerDescriptorParam = ValueParameterDescriptorImpl(
        containingDeclaration = lambdaDescriptor,
        original = null,
        index = 0,
        annotations = Annotations.EMPTY,
        name = KtxNameConventions.COMPOSER_PARAMETER,
        outType = composerTypeDescriptor.defaultType.makeNullable(),
        declaresDefaultValue = false,
        isCrossinline = false,
        isNoinline = false,
        varargElementType = null,
        source = SourceElement.NO_SOURCE
    )

    val keyDescriptorParam = ValueParameterDescriptorImpl(
        containingDeclaration = lambdaDescriptor,
        original = null,
        index = 1,
        annotations = Annotations.EMPTY,
        name = KtxNameConventions.KEY_PARAMETER,
        outType = builtIns.int,
        declaresDefaultValue = false,
        isCrossinline = false,
        isNoinline = false,
        varargElementType = null,
        source = SourceElement.NO_SOURCE
    )

    val changedDescriptorParam = ValueParameterDescriptorImpl(
        containingDeclaration = lambdaDescriptor,
        original = null,
        index = 2,
        annotations = Annotations.EMPTY,
        name = KtxNameConventions.CHANGED_PARAMETER,
        outType = builtIns.int,
        declaresDefaultValue = false,
        isCrossinline = false,
        isNoinline = false,
        varargElementType = null,
        source = SourceElement.NO_SOURCE
    )

    lambdaDescriptor.initialize(
        null,
        null,
        emptyList(),
        listOf(composerDescriptorParam, keyDescriptorParam, changedDescriptorParam),
        updateScopeBlockType,
        Modality.FINAL,
        Visibilities.LOCAL
    )

    // Positions of the synthetic arguments in the function's parameter list. The composer
    // argument sits at index `numRealValueParameters`; the key, the changed mask(s) and the
    // default mask(s) follow it.
    val parameterCount = function.symbol.descriptor.valueParameters.size
    val keyIndex = numRealValueParameters + 1
    val changedIndex = keyIndex + 1
    val defaultIndex = changedIndex + changedParamCount(
        numRealValueParameters,
        function.thisParamCount
    )

    // param count is 1-based, index is 0-based
    val expectedParameterCount = if (defaultParam == null) {
        defaultIndex
    } else {
        defaultIndex + defaultParamCount(numRealValueParameters)
    }
    require(parameterCount == expectedParameterCount)

    val lambda = IrFunctionImpl(
        UNDEFINED_OFFSET, UNDEFINED_OFFSET,
        IrDeclarationOrigin.LOCAL_FUNCTION_FOR_LAMBDA,
        IrSimpleFunctionSymbolImpl(lambdaDescriptor),
        context.irBuiltIns.unitType
    )
    lambda.parent = function
    val localIrBuilder = DeclarationIrBuilder(context, lambda.symbol)
    lambda.addValueParameter(
        KtxNameConventions.COMPOSER_PARAMETER.identifier,
        composerTypeDescriptor
            .defaultType
            .replaceArgumentsWithStarProjections()
            .toIrType()
            .makeNullable()
    )
    lambda.addValueParameter(
        KtxNameConventions.KEY_PARAMETER.identifier,
        context.irBuiltIns.intType
    )
    lambda.addValueParameter(
        "\$force",
        context.irBuiltIns.intType
    )
    lambda.body = localIrBuilder.irBlockBody {

        // the declaration to read for the function's parameter at `index`: its remapped
        // counterpart when the scope has one, otherwise the parameter itself
        fun remappedParam(index: Int) = function.valueParameters[index].let { declared ->
            scope.remappedParams[declared] ?: declared
        }

        // Call the function again with the same parameters
        +irReturn(irCall(function.symbol).apply {
            symbol.owner
                .valueParameters
                .forEachIndexed { index, param ->
                    val argument = if (param.isVararg) {
                        // a vararg is forwarded as a single spread element
                        IrVarargImpl(
                            UNDEFINED_OFFSET,
                            UNDEFINED_OFFSET,
                            param.type,
                            param.varargElementType!!,
                            elements = listOf(
                                IrSpreadElementImpl(
                                    UNDEFINED_OFFSET,
                                    UNDEFINED_OFFSET,
                                    irGet(remappedParam(index))
                                )
                            )
                        )
                    } else {
                        // NOTE(lmr): should we be using the parameter here, or the temporary
                        // with the default value?
                        irGet(remappedParam(index))
                    }
                    putValueArgument(index, argument)
                }

            // new composer
            putValueArgument(
                numRealValueParameters,
                irGet(lambda.valueParameters[0])
            )

            putValueArgument(
                keyIndex,
                irGet(lambda.valueParameters[1])
            )

            // the call in updateScope needs to *always* have the low bit set to 1.
            // This ensures that the body of the function is actually executed.
            changedParam.putAsValueArgumentInWithLowBit(
                this,
                changedIndex,
                lowBit = true
            )

            defaultParam?.putAsValueArgumentIn(this, defaultIndex)

            extensionReceiver = function.extensionReceiverParameter?.let { irGet(it) }
            dispatchReceiver = outerReceiver?.let { irGet(it) }
            function.typeParameters.forEachIndexed { index, parameter ->
                putTypeArgument(index, parameter.defaultType)
            }
        })
    }

    // $composer.endRestartGroup()?.updateScope { next -> TheFunction(..., next) }
    val updateScopeCall = irSafeCall(
        irEndRestartGroup(),
        updateScopeDescriptor,
        irLambda(lambda, updateScopeBlockType.toIrType())
    )
    return irBlock(
        statements = listOfNotNull(outerReceiver, updateScopeCall)
    )
}
| |
/** Builds a call to the getter of [isSkippingDescriptor] on the current composer. */
private fun irIsSkipping(): IrCall {
    return irMethodCall(irCurrentComposer(), isSkippingDescriptor.getter!!)
}
/** Builds a call to the getter of [defaultsInvalidDescriptor] on the current composer. */
private fun irDefaultsInvalid(): IrCall {
    return irMethodCall(irCurrentComposer(), defaultsInvalidDescriptor.getter!!)
}
| |
/**
 * Builds the expression that tests whether the bit for [slot] in the [default] mask is zero,
 * i.e. the isolated bit compared against the constant `0`.
 */
private fun irIsProvided(default: IrDefaultBitMaskValue, slot: Int) =
    default.irIsolateBitAtIndex(slot).let { isolatedBit ->
        irEqual(isolatedBit, irConst(0))
    }
| |
/**
 * Builds the expression that tests whether the bits for [slot] in the [changed] mask are all
 * zero.
 *
 * Generated shape: `%changed and 0b11 == 0`
 */
private fun irIsUncertain(changed: IrChangedBitMaskValue, slot: Int) =
    changed.irIsolateBitsAtSlot(slot).let { slotBits ->
        irEqual(slotBits, irConst(0))
    }
| |
/** Wraps the result of `bitsForSlot(bits, slot)` in an IR constant expression. */
@Suppress("SameParameterValue")
private fun irBitsForSlot(bits: Int, slot: Int): IrExpression =
    irConst(bitsForSlot(bits, slot))
| |
/**
 * Reports whether this expression ends with a return or a break/continue.
 *
 * The check descends through the last statement of nested [IrBlock]s. It answers `true` on
 * reaching an [IrReturn] or [IrBreakContinue], and `false` on reaching any other statement
 * kind or an empty block.
 */
private fun IrExpression.endsWithReturnOrJump(): Boolean {
    var current: IrStatement? = this
    while (current != null) {
        current = when (current) {
            is IrReturn, is IrBreakContinue -> return true
            is IrBlock -> current.statements.lastOrNull()
            else -> return false
        }
    }
    return false
}
| |
/**
 * Wraps the statements of this body in a composite expression and replaces a trailing
 * return, if there is one, with a plain statement.
 *
 * The trailing statement is searched for through the last statement of nested [IrBlock]s.
 * When it is an [IrReturn]:
 *  - a returned value of type Unit (or nullable Unit) replaces the return in place and no
 *    variable is produced;
 *  - any other returned value is stored in a new temporary, which takes the place of the
 *    return and is handed back as the second component.
 *
 * @return the composite together with the result variable, or null when there is none
 */
private fun IrBody.asBodyAndResultVar(): Pair<IrContainerExpressionBase, IrVariable?> {
    val original = IrCompositeImpl(
        startOffset,
        endOffset,
        context.irBuiltIns.unitType,
        null,
        statements
    )
    var container: IrStatementContainer = original
    while (true) {
        val last = container.statements.lastOrNull() ?: return original to null
        when (last) {
            is IrReturn -> {
                container.statements.pop()
                val returned = last.value
                if (returned.type.isUnitOrNullableUnit()) {
                    container.statements.add(returned)
                    return original to null
                }
                val temp = irTemporary(returned)
                container.statements.add(temp)
                return original to temp
            }
            // keep descending into the trailing block
            is IrBlock -> container = last
            else -> return original to null
        }
    }
}
| |
/**
 * Visits [declaration] with a [Scope.PropertyScope] pushed on the scope stack, and verifies
 * on the way out that the scope popped is the one that was pushed.
 */
override fun visitProperty(declaration: IrProperty): IrStatement {
    val propertyScope = Scope.PropertyScope(declaration.name)
    return try {
        scopeStack.push(propertyScope)
        super.visitProperty(declaration)
    } finally {
        require(scopeStack.pop() == propertyScope) { "Unbalanced scope stack" }
    }
}
| |
/**
 * Visits [declaration] with a [Scope.FieldScope] pushed on the scope stack, and verifies on
 * the way out that the scope popped is the one that was pushed.
 */
override fun visitField(declaration: IrField): IrStatement {
    val fieldScope = Scope.FieldScope(declaration.name)
    return try {
        scopeStack.push(fieldScope)
        super.visitField(declaration)
    } finally {
        require(scopeStack.pop() == fieldScope) { "Unbalanced scope stack" }
    }
}
| |
/**
 * Visits [declaration] with a [Scope.FileScope] pushed on the scope stack, and verifies on
 * the way out that the scope popped is the one that was pushed.
 */
override fun visitFile(declaration: IrFile): IrFile {
    val fileScope = Scope.FileScope(declaration.fqName)
    return try {
        scopeStack.push(fileScope)
        super.visitFile(declaration)
    } finally {
        require(scopeStack.pop() == fileScope) { "Unbalanced scope stack" }
    }
}
| |
/**
 * Delegates every known kind of declaration to the super implementation and fails on any
 * kind not listed here, so that a new declaration type cannot go unnoticed.
 */
override fun visitDeclaration(declaration: IrDeclaration): IrStatement = when (declaration) {
    // these declarations get scopes, but they are handled individually
    is IrField,
    is IrProperty,
    is IrFunction,
    is IrClass -> super.visitDeclaration(declaration)

    // these declarations do not create new "scopes", so we do nothing
    is IrTypeAlias,
    is IrEnumEntry,
    is IrAnonymousInitializer,
    is IrTypeParameter,
    is IrLocalDelegatedProperty,
    is IrValueDeclaration -> super.visitDeclaration(declaration)

    else -> error("Unhandled declaration! ${declaration::class.java.simpleName}")
}
| |
/**
 * Finds the composer parameter of the innermost function scope that has one, searching the
 * scope stack from the most recently pushed scope outwards.
 *
 * Fails with an error that includes the printed scope stack when no enclosing function scope
 * has a composer parameter.
 */
private fun nearestComposer(): IrValueParameter =
    scopeStack.asReversed()
        .asSequence()
        .filterIsInstance<Scope.FunctionScope>()
        .mapNotNull { functionScope -> functionScope.composerParameter }
        .firstOrNull()
        ?: error("Not in a composable function \n${printScopeStack()}")
| |
/** Builds a read of the nearest composer parameter (see [nearestComposer]). */
private fun irCurrentComposer(): IrExpression {
    val composerSymbol = nearestComposer().symbol
    return IrGetValueImpl(UNDEFINED_OFFSET, UNDEFINED_OFFSET, composerSymbol)
}
| |
/**
 * Computes an integer key for this element by combining the hash of the current function's
 * fully qualified name with this element's start offset.
 */
private fun IrElement.sourceKey(): Int {
    val functionName = currentFunctionScope
        .function
        .symbol
        .descriptor
        .fqNameSafe
        .toString()
    return 31 * functionName.hashCode() + startOffset
}
| |
/** Wraps the value of [sourceKey] in an `Int` constant expression with undefined offsets. */
private fun IrElement.irSourceKey(): IrConst<Int> = IrConstImpl(
    UNDEFINED_OFFSET,
    UNDEFINED_OFFSET,
    context.irBuiltIns.intType,
    IrConstKind.Int,
    sourceKey()
)
| |
/**
 * Builds a call to [startReplaceableDescriptor] on the current composer.
 *
 * @param element supplies the offsets of the call and, by default, the key
 * @param key expression passed as the first value argument
 */
private fun irStartReplaceableGroup(
    element: IrElement,
    key: IrExpression = element.irSourceKey()
): IrExpression {
    val call = irMethodCall(
        irCurrentComposer(),
        startReplaceableDescriptor,
        element.startOffset,
        element.endOffset
    )
    call.putValueArgument(0, key)
    return call
}
| |
/**
 * Builds a call to [startDefaultsDescriptor] on the current composer, using the offsets of
 * [element].
 */
private fun irStartDefaults(element: IrElement): IrExpression =
    irMethodCall(
        irCurrentComposer(),
        startDefaultsDescriptor,
        element.startOffset,
        element.endOffset
    )
| |
/**
 * Builds a call to [startRestartGroupDescriptor] on the current composer.
 *
 * @param element supplies the offsets of the call and, by default, the key
 * @param key expression passed as the first value argument
 */
private fun irStartRestartGroup(
    element: IrElement,
    key: IrExpression = element.irSourceKey()
): IrExpression {
    val call = irMethodCall(
        irCurrentComposer(),
        startRestartGroupDescriptor,
        element.startOffset,
        element.endOffset
    )
    call.putValueArgument(0, key)
    return call
}
| |
/** Builds a call to [endRestartGroupDescriptor] on the current composer. */
private fun irEndRestartGroup(): IrExpression =
    irMethodCall(irCurrentComposer(), endRestartGroupDescriptor)
| |
/**
 * Builds a call to [changedDescriptor] on the current composer, passing [value] as the value
 * argument and the type of [value] as the type argument.
 */
private fun irChanged(value: IrExpression): IrExpression {
    val call = irMethodCall(irCurrentComposer(), changedDescriptor)
    call.putValueArgument(0, value)
    call.putTypeArgument(0, value.type)
    return call
}
| |
/** Builds a call to [skipToGroupEndDescriptor] on the current composer. */
private fun irSkipToGroupEnd(): IrExpression =
    irMethodCall(irCurrentComposer(), skipToGroupEndDescriptor)
| |
/** Builds a call to [skipCurrentGroupDescriptor] on the current composer. */
private fun irSkipCurrentGroup(): IrExpression =
    irMethodCall(irCurrentComposer(), skipCurrentGroupDescriptor)
| |
/** Builds a call to [endReplaceableDescriptor] on the current composer. */
private fun irEndReplaceableGroup(): IrExpression =
    irMethodCall(irCurrentComposer(), endReplaceableDescriptor)
| |
/** Builds a call to [endDefaultsDescriptor] on the current composer. */
private fun irEndDefaults(): IrExpression =
    irMethodCall(irCurrentComposer(), endDefaultsDescriptor)
| |
/**
 * Builds a call to [startMovableDescriptor] on the current composer.
 *
 * @param element supplies the offsets of the call and the source key (first argument)
 * @param joinedData expression passed as the second value argument
 */
private fun irStartMovableGroup(element: IrElement, joinedData: IrExpression): IrExpression {
    val call = irMethodCall(
        irCurrentComposer(),
        startMovableDescriptor,
        element.startOffset,
        element.endOffset
    )
    call.putValueArgument(0, element.irSourceKey())
    call.putValueArgument(1, joinedData)
    return call
}
| |
/** Builds a call to [endMovableDescriptor] on the current composer. */
private fun irEndMovableGroup(): IrExpression =
    irMethodCall(irCurrentComposer(), endMovableDescriptor)
| |
/**
 * Folds [keyExprs] from left to right into nested calls to [joinKeyDescriptor] on the current
 * composer. A single-element list yields that element unchanged.
 *
 * The list must not be empty, because `reduce` fails on an empty collection.
 */
private fun irJoinKeyChain(keyExprs: List<IrExpression>): IrExpression {
    return keyExprs.reduce { joinedSoFar, nextKey ->
        val joinCall = irMethodCall(irCurrentComposer(), joinKeyDescriptor)
        joinCall.putValueArgument(0, joinedSoFar)
        joinCall.putValueArgument(1, nextKey)
        joinCall
    }
}
| |
/**
 * Builds the IR of a safe call (`target?.descriptor(args...)`).
 *
 * [target] is stored in a temporary. The resulting block yields null when that temporary
 * equals null, and otherwise the result of calling [descriptor] with the temporary as the
 * dispatch receiver and [args] as the value arguments in order.
 */
private fun irSafeCall(
    target: IrExpression,
    descriptor: FunctionDescriptor,
    vararg args: IrExpression
): IrExpression {
    val receiver = irTemporary(target, nameHint = "safe_receiver")

    val receiverIsNull = irEqual(irGet(receiver), irNull())
    val nullResult = irNull()
    val call = irCall(descriptor)
    call.dispatchReceiver = irGet(receiver)
    for ((position, argument) in args.withIndex()) {
        call.putValueArgument(position, argument)
    }

    return irBlock(
        origin = IrStatementOrigin.SAFE_CALL,
        statements = listOf(
            receiver,
            irIfThenElse(
                condition = receiverIsNull,
                thenPart = nullResult,
                elsePart = call
            )
        )
    )
}
| |
/**
 * Builds a call to the function described by [descriptor], with no receiver or arguments set.
 *
 * Fails with an error when [descriptor] has no return type.
 *
 * @param startOffset start offset recorded on the call
 * @param endOffset end offset recorded on the call
 */
private fun irCall(
    descriptor: FunctionDescriptor,
    startOffset: Int = UNDEFINED_OFFSET,
    endOffset: Int = UNDEFINED_OFFSET
): IrCall {
    val returnType = descriptor.returnType?.toIrType() ?: error("Expected a return type")
    val callee = referenceFunction(descriptor)
    return IrCallImpl(startOffset, endOffset, returnType, callee)
}
| |
/** Same as [irCall], with [target] installed as the call's dispatch receiver. */
private fun irMethodCall(
    target: IrExpression,
    descriptor: FunctionDescriptor,
    startOffset: Int = UNDEFINED_OFFSET,
    endOffset: Int = UNDEFINED_OFFSET
): IrCall {
    val call = irCall(descriptor, startOffset, endOffset)
    call.dispatchReceiver = target
    return call
}
| |
/**
 * Declares a temporary variable in the current function scope, initialized with [value].
 *
 * @param nameHint used verbatim as the name when [exactName] is true; otherwise it is handed to
 *   the function scope's `getNameForTemporary` to derive the name.
 * @param irType declared type of the variable; defaults to the type of [value].
 * @param isVar whether the variable is mutable.
 * @param exactName only takes effect when [nameHint] is non-null.
 */
private fun irTemporary(
    value: IrExpression,
    nameHint: String? = null,
    irType: IrType = value.type,
    isVar: Boolean = false,
    exactName: Boolean = false
): IrVariableImpl {
    val functionScope = currentFunctionScope
    val variableName = when {
        exactName && nameHint != null -> nameHint
        else -> functionScope.getNameForTemporary(nameHint)
    }
    return irTemporary(
        functionScope.function.symbol.descriptor,
        value,
        variableName,
        irType,
        isVar
    )
}
| |
/**
 * Wraps this expression with start/end replaceable-group calls.
 *
 * @param scope the block scope recorded while this expression was transformed; its flags decide
 *   where the end call has to be placed.
 */
private fun IrExpression.asReplaceableGroup(scope: Scope.BlockScope): IrExpression {
    // With no composable calls, returns or jumps inside, all that matters is that a start and
    // an end call both run, so they can simply be emitted back-to-back ahead of the expression
    // and none of the jump handling below is needed.
    val needsEndHandling = scope.hasComposableCalls || scope.hasReturn || scope.hasJump
    if (!needsEndHandling) {
        return wrap(
            before = listOf(irStartReplaceableGroup(this), irEndReplaceableGroup())
        )
    }
    scope.realizeGroup(::irEndReplaceableGroup)

    // When the expression already ends in a return/jump, the end call is supplied through the
    // scope (see realizeGroup above), so only the start call is prepended. Otherwise an end
    // call is also appended after the expression.
    val exitsAtEnd = endsWithReturnOrJump()
    val startCall = irStartReplaceableGroup(this)
    return if (exitsAtEnd) {
        wrap(before = listOf(startCall))
    } else {
        wrap(before = listOf(startCall), after = listOf(irEndReplaceableGroup()))
    }
}
| |
/**
 * Surrounds this expression with [before] and [after] statements inside a block of the same type.
 *
 * When there are trailing statements and the expression's type is neither `Nothing` nor
 * (nullable) `Unit`, the expression is first stored in a temporary that is read back as the
 * block's last statement.
 */
private fun IrExpression.wrap(
    before: List<IrExpression> = emptyList(),
    after: List<IrExpression> = emptyList()
): IrExpression {
    val canWrapDirectly = after.isEmpty() || type.isNothing() || type.isUnitOrNullableUnit()
    if (canWrapDirectly) {
        return wrap(type, before, after)
    }
    val groupResult = irTemporary(this, nameHint = "group")
    return groupResult.wrap(type, before, after + irGet(groupResult))
}
| |
/**
 * Creates an [IrBlockImpl] of the given [type] (null origin, this statement's offsets) whose
 * statements are [before], then this statement, then [after].
 */
private fun IrStatement.wrap(
    type: IrType,
    before: List<IrExpression> = emptyList(),
    after: List<IrExpression> = emptyList()
): IrExpression {
    val wrapped = ArrayList<IrStatement>(before.size + 1 + after.size)
    wrapped.addAll(before)
    wrapped.add(this)
    wrapped.addAll(after)
    return IrBlockImpl(startOffset, endOffset, type, null, wrapped)
}
| |
/**
 * Wraps this expression between two initially empty statement containers and registers it as a
 * coalescable group via [encounteredCoalescableGroup].
 *
 * The containers are only filled in (with start/end replaceable-group calls) if the registered
 * `realizeGroup` callback is later invoked.
 */
private fun IrExpression.asCoalescableGroup(scope: Scope.BlockScope): IrExpression {
    val prologue = mutableStatementContainer()
    val epilogue = mutableStatementContainer()

    // This expression produces a dynamic number of groups, so it might need a group of its own
    // around it. Whether it does is not known yet; the callbacks below let that decision be
    // made later.
    val realize: () -> Unit = {
        prologue.statements.add(irStartReplaceableGroup(this@asCoalescableGroup))
        epilogue.statements.add(irEndReplaceableGroup())
    }
    encounteredCoalescableGroup(
        scope,
        realizeGroup = realize,
        makeEnd = ::irEndReplaceableGroup
    )
    return wrap(type, listOf(prologue), listOf(epilogue))
}
| |
/**
 * Creates an empty, Unit-typed [IrCompositeImpl] with undefined offsets.
 *
 * NOTE(lmr): IrComposite is used deliberately so that no new scope is introduced.
 */
private fun mutableStatementContainer(): IrContainerExpressionBase =
    IrCompositeImpl(UNDEFINED_OFFSET, UNDEFINED_OFFSET, context.irBuiltIns.unitType)
| |
/**
 * Marks a composable call on every block scope from the innermost outwards, stopping after the
 * first function scope (which is marked too) or at the first class scope (which is not).
 */
private fun encounteredComposableCall() {
    for (scope in scopeStack.asReversed()) {
        when (scope) {
            is Scope.FunctionScope -> {
                scope.markComposableCall()
                return
            }
            is Scope.BlockScope -> scope.markComposableCall()
            is Scope.ClassScope -> return
        }
    }
}
| |
/**
 * Registers a coalescable child group on the innermost scope of the stack.
 *
 * Does nothing when the stack is empty; fails with "Unexpected scope type" when the innermost
 * scope is not a [Scope.BlockScope].
 */
private fun encounteredCoalescableGroup(
    coalescableScope: Scope.BlockScope,
    realizeGroup: () -> Unit,
    makeEnd: () -> IrExpression
) {
    val innermost = scopeStack.lastOrNull() ?: return
    if (innermost !is Scope.BlockScope) error("Unexpected scope type")
    innermost.markCoalescableGroup(coalescableScope, realizeGroup, makeEnd)
}
| |
/**
 * Records a return on every block scope from the innermost outwards, up to and including the
 * first function scope.
 *
 * @param symbol the return target; if the first function scope reached is not its owner, the
 *   case is not supported yet and a `TODO` is raised.
 * @param extraEndLocation callback handed to each marked scope.
 */
private fun encounteredReturn(
    symbol: IrReturnTargetSymbol,
    extraEndLocation: (IrExpression) -> Unit
) {
    for (scope in scopeStack.asReversed()) {
        when (scope) {
            is Scope.FunctionScope -> {
                scope.markReturn(extraEndLocation)
                if (scope.function != symbol.owner) {
                    TODO("Need to handle nested returns!")
                }
                return
            }
            is Scope.BlockScope -> scope.markReturn(extraEndLocation)
        }
    }
}
| |
/**
 * Records a `break`/`continue` on every block scope between the jump and its target loop.
 *
 * Scopes are visited from the innermost outwards. The walk stops at the loop scope whose loop is
 * the jump's target (that scope itself is not marked). Reaching a function or class scope first
 * is treated as an error.
 *
 * @param jump the break/continue being transformed.
 * @param extraEndLocation callback handed to each marked scope.
 */
private fun encounteredJump(jump: IrBreakContinue, extraEndLocation: (IrExpression) -> Unit) {
    loop@ for (scope in scopeStack.asReversed()) {
        when (scope) {
            is Scope.FunctionScope -> error("Unexpected Function Scope encountered")
            // The message used to say "Function Scope" here as well, which misreported
            // which kind of scope had been hit.
            is Scope.ClassScope -> error("Unexpected Class Scope encountered")
            is Scope.LoopScope -> {
                if (jump.loop == scope.loop) {
                    break@loop
                }
                scope.markJump(extraEndLocation)
            }
            is Scope.BlockScope -> {
                scope.markJump(extraEndLocation)
            }
        }
    }
}
| |
/**
 * Transforms this expression with [scope] pushed on the scope stack for the duration of the
 * transform.
 *
 * @return the scope together with the transformed expression.
 */
private fun <T : Scope> IrExpression.transformWithScope(scope: T): Pair<T, IrExpression> {
    // Push before entering `try` (as withScope does) so the `finally` block only ever pops a
    // scope that was actually pushed.
    scopeStack.push(scope)
    try {
        val result = transform(this@ComposableFunctionBodyTransformer, null)
        return scope to result
    } finally {
        require(scopeStack.pop() === scope) {
            "Scope stack out of sync: popped a scope other than ${scope.name}"
        }
    }
}
| |
/**
 * Runs [block] with [scope] pushed on the scope stack and returns [scope].
 *
 * The scope is popped again even if [block] throws; the popped element must be [scope] itself.
 */
private inline fun <T : Scope> withScope(scope: T, block: () -> Unit): T {
    scopeStack.push(scope)
    try {
        block()
        return scope
    } finally {
        require(scopeStack.pop() === scope)
    }
}
| |
/**
 * Mutable record describing one argument (or receiver) of a composable call; consumed by
 * buildChangedParamForCall to decide which bits go into the call's changed mask.
 *
 * @property isVararg set for a missing (null) argument whose parameter is a vararg.
 * @property isProvided false when the call's default mask marks the argument as defaulted;
 *   receivers are always created with `true`.
 * @property isStatic set by populateParamMeta when the argument expression is static.
 * @property isCertain set when the argument reads a remapped parameter of an enclosing function
 *   scope, in which case [maskParam] and [maskSlot] identify that parameter's state.
 * @property maskSlot slot of the parameter in the enclosing function scope; -1 when unset.
 * @property maskParam the enclosing function scope's `dirty` mask; null when unset.
 */
| data class ParamMeta( |
| var isVararg: Boolean = false, |
| var isProvided: Boolean = false, |
| var isStatic: Boolean = false, |
| var isCertain: Boolean = false, |
| var maskSlot: Int = -1, |
| var maskParam: IrChangedBitMaskValue? = null |
| ) |
| |
/**
 * Creates the [ParamMeta] for [arg], seeded with [isProvided] and then filled in by
 * [populateParamMeta].
 */
fun paramMetaOf(arg: IrExpression, isProvided: Boolean): ParamMeta =
    ParamMeta(isProvided = isProvided).also { populateParamMeta(arg, it) }
| |
/**
 * Fills [meta] with what can be determined about [arg]:
 *
 * - a static expression marks the meta as static;
 * - a value read is first looked up in the enclosing function scopes;
 * - failing that, a read of a `const` variable is static, and a read of a non-`var` variable
 *   with an initializer is analysed through that initializer.
 *
 * Anything else leaves [meta] untouched.
 */
private fun populateParamMeta(arg: IrExpression, meta: ParamMeta) {
    if (arg.isStatic()) {
        meta.isStatic = true
        return
    }
    if (arg !is IrGetValue) return

    val owner = arg.symbol.owner
    if (extractParamMetaFromScopes(meta, owner)) return
    if (owner !is IrVariable) return

    if (owner.isConst) {
        meta.isStatic = true
        return
    }
    if (owner.isVar) return
    val initializer = owner.initializer ?: return
    populateParamMeta(initializer, meta)
}
| |
/**
 * Transforms blocks, with special handling for the two-statement block psi2ir produces for a
 * `for` loop so that its shape survives the transform.
 */
override fun visitBlock(expression: IrBlock): IrExpression {
    if (expression.origin != IrStatementOrigin.FOR_LOOP) {
        return super.visitBlock(expression)
    }
    // psi2ir represents
    //
    //     for (loopVar in <someIterable>)
    //
    // as a block of exactly two statements:
    //
    //     val it = <someIterable>.iterator()   // #1: the "header"
    //     while (it.hasNext()) {               // #2: the inner while loop
    //         val loopVar = it.next()
    //         // loop body
    //     }
    //
    // Later IR lowerings recognise this block and optimise some shapes of for loops, so the
    // original shape is kept intact here in order not to defeat those optimisations.
    val forLoopStatements = expression.statements
    require(forLoopStatements.size == 2) {
        "Expected 2 statements in for-loop block"
    }

    val iteratorVar = forLoopStatements[0] as IrVariable
    require(iteratorVar.origin == IrDeclarationOrigin.FOR_LOOP_ITERATOR) {
        "Expected FOR_LOOP_ITERATOR origin for iterator variable"
    }
    val transformedVar = iteratorVar.transform(this, null)

    val innerWhile = forLoopStatements[1] as IrWhileLoop
    require(innerWhile.origin == IrStatementOrigin.FOR_LOOP_INNER_WHILE) {
        "Expected FOR_LOOP_INNER_WHILE origin for while loop"
    }
    val transformedLoop = innerWhile.transform(this, null)

    if (transformedVar == iteratorVar && transformedLoop == innerWhile) {
        return expression
    }
    if (transformedLoop !is IrBlock) {
        error("Expected transformed loop to be an IrBlock")
    }

    // The transformed loop is expected to be [before, loop, after]; the for-loop block is
    // rebuilt around the middle element and placed between the two containers.
    require(transformedLoop.statements.size == 3)
    val before = transformedLoop.statements[0] as IrContainerExpressionBase
    val loop = transformedLoop.statements[1] as IrWhileLoop
    val after = transformedLoop.statements[2] as IrContainerExpressionBase

    val result = mutableStatementContainer()
    val rebuiltForLoop = irBlock(
        type = expression.type,
        origin = IrStatementOrigin.FOR_LOOP,
        statements = listOf(transformedVar, loop)
    )
    result.statements.addAll(listOf(before, rebuiltForLoop, after))
    return result
}
| |
/**
 * Dispatches a call: calls with emit metadata in the IR trace go to [visitEmitCall], transformed
 * or synthetic composable calls go to [visitComposableCall], everything else to the default
 * transform.
 */
override fun visitCall(expression: IrCall): IrExpression {
    val emitMetadata = context.irTrace[
        ComposeWritableSlices.COMPOSABLE_EMIT_METADATA,
        expression
    ]
    return when {
        emitMetadata != null -> visitEmitCall(expression)
        expression.isTransformedComposableCall() ||
            expression.isSyntheticComposableCall() -> visitComposableCall(expression)
        else -> super.visitCall(expression)
    }
}
| |
/**
 * Fills in the synthetic key and changed-mask arguments of a composable call.
 *
 * Calls to `ComposeFqNames.key` are delegated to [visitKeyCall]. For every other call the
 * value-parameter list is split into real parameters, composer, key, changed masks and default
 * masks; a [ParamMeta] is computed for each real argument and each receiver; and the resulting
 * changed-mask expressions plus the source key are written back into the call.
 *
 * @return the same [expression], mutated in place (or the result of [visitKeyCall]).
 */
private fun visitComposableCall(expression: IrCall): IrExpression {
    encounteredComposableCall()
    if (expression.symbol.descriptor.fqNameSafe == ComposeFqNames.key) {
        return visitKeyCall(expression)
    }
    // it's important that we transform all of the parameters here since this will cause the
    // IrGetValue's of remapped default parameters to point to the right variable.
    expression.transformChildrenVoid()

    val ownerFn = expression.symbol.owner
    val numValueParams = ownerFn.valueParameters.size
    val numDefaults: Int
    val numChanged: Int
    val numRealValueParams: Int

    if (expression.origin == IrStatementOrigin.INVOKE) {
        // in the case of an invoke, all of the parameters are going to be type parameter
        // args which won't have special names. In this case, we know that the values cannot
        // be defaulted though, so we can calculate the number of real parameters based on
        // the total number of parameters
        numDefaults = 0
        numChanged = changedParamCountFromTotal(numValueParams + ownerFn.thisParamCount)
        numRealValueParams = numValueParams -
            1 - // composer param
            1 - // key param
            numChanged
    } else {
        val hasDefaults = ownerFn.valueParameters.any {
            it.name == KtxNameConventions.DEFAULT_PARAMETER
        }
        numRealValueParams = ownerFn.valueParameters.indexOfLast {
            !it.name.asString().startsWith('$')
        } + 1
        numDefaults = if (hasDefaults) defaultParamCount(numRealValueParams) else 0
        numChanged = changedParamCount(numRealValueParams, ownerFn.thisParamCount)
    }

    require(
        numRealValueParams +
            1 + // composer param
            1 + // key param
            numChanged +
            numDefaults == numValueParams
    ) {
        "Unexpected parameter layout for composable call: real=$numRealValueParams, " +
            "changed=$numChanged, defaults=$numDefaults, total=$numValueParams"
    }

    val composerIndex = numRealValueParams
    val keyIndex = composerIndex + 1
    val changedArgIndex = keyIndex + 1
    val defaultArgIndex = changedArgIndex + numChanged
    val defaultArgs = (defaultArgIndex until numValueParams).map {
        expression.getValueArgument(it)
    }

    // Empty when the callee has no default-mask parameters (numDefaults == 0).
    val defaultMasks = defaultArgs.map {
        when (it) {
            !is IrConst<*> -> error("Expected default mask to be a const")
            else -> it.value as? Int ?: error("Expected default mask to be an Int")
        }
    }

    val paramMeta = mutableListOf<ParamMeta>()

    for (index in 0 until numRealValueParams) {
        val arg = expression.getValueArgument(index)
        if (arg == null) {
            val param = ownerFn.valueParameters[index]
            if (param.varargElementType == null) {
                // ComposerParamTransformer should not allow for any null arguments on a
                // composable invocation unless the parameter is vararg. If this is null
                // here, we have missed something.
                error("Unexpected null argument for composable call")
            }
            paramMeta.add(ParamMeta(isVararg = true))
            continue
        }
        val bitIndex = defaultsBitIndex(index)
        // Without default masks every argument counts as provided. Looking the mask up
        // unconditionally used to rely on a single padded 0 entry, which is out of range as
        // soon as defaultsParamIndex(index) is greater than 0.
        val maskValue = if (numDefaults > 0) defaultMasks[defaultsParamIndex(index)] else 0
        val meta = paramMetaOf(arg, isProvided = (maskValue and (0b1 shl bitIndex)) == 0)

        paramMeta.add(meta)
    }

    val extensionMeta = expression.extensionReceiver?.let { paramMetaOf(it, isProvided = true) }
    val dispatchMeta = expression.dispatchReceiver?.let { paramMetaOf(it, isProvided = true) }

    val changedParams = buildChangedParamsForCall(
        paramMeta,
        extensionMeta,
        dispatchMeta
    )

    expression.putValueArgument(
        keyIndex,
        expression.irSourceKey()
    )

    changedParams.forEachIndexed { i, param ->
        expression.putValueArgument(changedArgIndex + i, param)
    }

    return expression
}
| |
/**
 * Handles an emit call: it is recorded as a composable call on the enclosing scopes and then
 * transformed like any ordinary call.
 */
private fun visitEmitCall(expression: IrCall): IrExpression {
    // TODO(lmr): eventually, we want to handle emits in this transform
    encounteredComposableCall()
    val visited = super.visitCall(expression)
    return visited
}
| |
/**
 * Lowers a `key(...)` call into an inline movable group.
 *
 * The non-synthetic arguments are split into the key expressions and the `block` lambda. The
 * lambda body is transformed inside its own branch scope and emitted between a start-movable
 * call (keyed by the joined key expressions) and an end-movable call, followed by a read of the
 * body's result variable if there is one.
 *
 * Fails when an argument is missing, when `block` is not a function expression, or when the
 * lambda has no body.
 */
private fun visitKeyCall(expression: IrCall): IrExpression {
    val keyArgs = mutableListOf<IrExpression>()
    var blockArg: IrExpression? = null
    for (i in 0 until expression.valueArgumentsCount) {
        val param = expression.symbol.owner.valueParameters[i]
        val arg = expression.getValueArgument(i)
            ?: error("Unexpected null argument found on key call")
        if (param.name.asString().startsWith('$'))
            // we are done. synthetic args go at
            // the end
            break

        when {
            param.name.identifier == "block" -> {
                blockArg = arg
            }
            arg is IrVararg -> {
                keyArgs.addAll(arg.elements.mapNotNull { it as? IrExpression })
            }
            else -> {
                keyArgs.add(arg)
            }
        }
    }
    val before = mutableStatementContainer()
    val after = mutableStatementContainer()

    if (blockArg !is IrFunctionExpression) error("Expected function expression")

    val lambdaBody = blockArg.function.body
        ?: error("Expected the block of a key call to have a body")
    val (block, resultVar) = lambdaBody.asBodyAndResultVar()

    var transformed: IrExpression = block

    withScope(Scope.BranchScope()) {
        transformed = transformed.transform(this, null)
    }

    return irBlock(
        type = expression.type,
        statements = listOfNotNull(
            before,
            irStartMovableGroup(
                expression,
                irJoinKeyChain(keyArgs.map { it.transform(this, null) })
            ),
            // Emit the result of the transform. The untransformed `block` was emitted before,
            // which drops the transform whenever it returns a replacement node instead of
            // mutating in place.
            transformed,
            irEndMovableGroup(),
            after,
            resultVar?.let { irGet(resultVar) }
        )
    )
}
| |
/**
 * Decides whether this expression can be treated as static.
 *
 * Recognised as static:
 * - constants and enum values;
 * - object references whose owner has a stable supertype;
 * - inline-class constructor calls with a stable underlying type and a static first argument;
 * - property reads, the listed stdlib/@Stable operators and @Stable function calls, subject to
 *   the stability and static-operand checks spelled out in each branch;
 * - reads of non-`var` local variables whose initializer is static.
 *
 * Everything else is reported as not static.
 */
| private fun IrExpression.isStatic(): Boolean { |
| return when (this) { |
| // A constant by definition is static |
| is IrConst<*> -> true |
| // We want to consider all enum values as static |
| is IrGetEnumValue -> true |
| // Getting a companion object or top level object can be considered static if the |
| // type of that object is Stable. (`Modifier` for instance is a common example) |
| is IrGetObjectValue -> symbol.owner.superTypes.any { it.toKotlinType().isStable() } |
| is IrConstructorCall -> { |
| // special case constructors of inline classes as static if their underlying |
| // value is static. |
| if ( |
| type.isInlined() && |
| type.unboxInlineClass().toKotlinType().isStable() && |
| getValueArgument(0)?.isStatic() == true |
| ) { |
| return true |
| } |
| false |
| } |
| is IrCall -> when (origin) { |
| is IrStatementOrigin.GET_PROPERTY -> { |
| // If we are in a GET_PROPERTY call, then this should usually resolve to |
| // non-null, but in case it doesn't, just return false |
| val prop = (symbol.owner as? IrSimpleFunction) |
| ?.correspondingPropertySymbol?.owner ?: return false |
| |
| // if the property is a top level constant, then it is static. |
| if (prop.isConst) return true |
| |
| val typeIsStable = type.toKotlinType().isStable() |
// A missing receiver (null) counts as static: only an explicit `false` fails the check.
| val dispatchReceiverIsStatic = dispatchReceiver?.isStatic() != false |
| val extensionReceiverIsStatic = extensionReceiver?.isStatic() != false |
| |
| // if we see that the property is read-only with a default getter and a |
| // stable return type, then reading the property can also be considered |
| // static if this is a top level property or the subject is also static. |
| if (!prop.isVar && |
| prop.getter?.origin == IrDeclarationOrigin.DEFAULT_PROPERTY_ACCESSOR && |
| typeIsStable && |
| dispatchReceiverIsStatic && extensionReceiverIsStatic |
| ) { |
| return true |
| } |
| |
| val getterIsStable = prop.hasStableAnnotation() || |
| symbol.owner.hasStableAnnotation() |
| |
| if ( |
| getterIsStable && |
| typeIsStable && |
| dispatchReceiverIsStatic && |
| extensionReceiverIsStatic |
| ) { |
| return true |
| } |
| |
| false |
| } |
| is IrStatementOrigin.PLUS, |
| is IrStatementOrigin.MUL, |
| is IrStatementOrigin.MINUS, |
| is IrStatementOrigin.ANDAND, |
| is IrStatementOrigin.OROR, |
| is IrStatementOrigin.DIV, |
| is IrStatementOrigin.EQ, |
| is IrStatementOrigin.EQEQ, |
| is IrStatementOrigin.EQEQEQ, |
| is IrStatementOrigin.GT, |
| is IrStatementOrigin.GTEQ, |
| is IrStatementOrigin.LT, |
| is IrStatementOrigin.LTEQ -> { |
| // special case mathematical operators that are in the stdlib. These are |
| // immutable operations so the overall result is static if the operands are |
| // also static |
| val isStableOperator = symbol |
| .descriptor |
| .fqNameSafe |
| .topLevelName() == "kotlin" || |
| symbol.owner.hasStableAnnotation() |
| |
| val typeIsStable = type.toKotlinType().isStable() |
| if (!typeIsStable) return false |
| |
| if (!isStableOperator) { |
| return false |
| } |
| |
| getArguments().all { it.second.isStatic() } |
| } |
| null -> { |
| // normal function call. If the function is marked as Stable and the result |
| // is Stable, then the static-ness of it is the static-ness of its arguments |
| val isStable = symbol.owner.hasStableAnnotation() |
| if (!isStable) return false |
| |
| val typeIsStable = type.toKotlinType().isStable() |
| if (!typeIsStable) return false |
| |
| // getArguments includes the receivers! |
| getArguments().all { it.second.isStatic() } |
| } |
| else -> false |
| } |
| is IrGetValue -> { |
| val owner = symbol.owner |
| when (owner) { |
| is IrVariable -> { |
| // If we have an immutable variable whose initializer is also static, |
| // then we can determine that the variable reference is also static. |
| !owner.isVar && owner.initializer?.isStatic() == true |
| } |
| else -> false |
| } |
| } |
| else -> false |
| } |
| } |
| |
/**
 * Looks for [param] among the remapped parameters of the enclosing function scopes, innermost
 * first. On the first hit, [meta] is marked certain and pointed at that scope's `dirty` mask and
 * the parameter's slot.
 *
 * @return true if a scope was found and [meta] was updated, false otherwise ([meta] is then left
 *   untouched).
 */
private fun extractParamMetaFromScopes(meta: ParamMeta, param: IrValueDeclaration): Boolean {
    for (scope in scopeStack.asReversed()) {
        if (scope !is Scope.FunctionScope) continue
        if (!scope.remappedParams.containsValue(param)) continue

        // Resolve the slot before touching meta so that a missing entry fails with a clear
        // message (instead of a bare NullPointerException) and leaves meta unmodified.
        val slot = scope.paramsToSlots[param]
            ?: error("Remapped parameter has no slot recorded in scope '${scope.name}'")
        meta.isCertain = true
        meta.maskParam = scope.dirty
        meta.maskSlot = slot
        return true
    }
    return false
}
| |
/**
 * Builds the list of changed-mask argument expressions for a call.
 *
 * The value parameters are followed by the extension and dispatch receivers (when present); that
 * combined list is cut into consecutive slices of at most `SLOTS_PER_INT` entries, and one mask
 * expression is built per slice.
 */
private fun buildChangedParamsForCall(
    valueParams: List<ParamMeta>,
    extensionParam: ParamMeta?,
    dispatchParam: ParamMeta?
): List<IrExpression> {
    val receiverParams = listOfNotNull(extensionParam, dispatchParam)
    val allSlots = valueParams + receiverParams
    // The number of masks comes from changedParamCount, which is given the value-parameter and
    // receiver counts separately.
    val maskCount = changedParamCount(valueParams.size, receiverParams.size)
    return (0 until maskCount).map { maskIndex ->
        val from = maskIndex * SLOTS_PER_INT
        val to = min(from + SLOTS_PER_INT, allSlots.size)
        buildChangedParamForCall(allSlots.subList(from, to))
    }
}
| |
/**
 * Builds a single changed-mask expression for the given slice of parameters, where a parameter's
 * index in [params] is its slot in the resulting mask.
 */
private fun buildChangedParamForCall(params: List<ParamMeta>): IrExpression {
    // The expression produced has the general shape
    //
    //     $changed = constantBits or
    //         (0b11 and someMask shl y) or
    //         (0b1100 and someMask shl x) or
    //         ...
    //         (0b11000000 and someMask shr z)
    //
    // `constantBits` collects the slots whose state is known here: static arguments, and
    // uncertain ones (varargs, defaulted arguments, and anything that is not a direct pass
    // through of an enclosing function's parameter). The remaining slots are "certain": their
    // state is read from the mask recorded on the ParamMeta (`someMask`) and shifted from the
    // parent slot to this call's slot (the `x`, `y`, `z` amounts above).

    // TODO: we could make some small optimization here if we have multiple values passed
    // from one function into another in the same order. This may not happen commonly enough
    // to be worth the complication though.

    // NOTE: starts at 0b0 because it is important that the low bit is always 0
    var constantBits = 0b0
    val maskedSlots = mutableListOf<IrExpression>()

    for ((slot, meta) in params.withIndex()) {
        when {
            meta.isVararg || !meta.isProvided -> {
                constantBits = constantBits or ParamState.Uncertain.bitsForSlot(slot)
            }
            meta.isStatic -> {
                constantBits = constantBits or ParamState.Static.bitsForSlot(slot)
            }
            !meta.isCertain -> {
                constantBits = constantBits or ParamState.Uncertain.bitsForSlot(slot)
            }
            else -> {
                val someMask = meta.maskParam ?: error("Mask param required if param is Certain")
                val parentSlot = meta.maskSlot
                require(parentSlot != -1) { "invalid parent slot for Certain param" }

                // if parentSlot is lower than slot, we shift left a positive amount of bits
                val slotBits = irConst(bitsForSlot(0b11, slot))
                maskedSlots.add(irAnd(slotBits, someMask.irShiftBits(parentSlot, slot)))
            }
        }
    }

    // Nothing to read from a parent mask: the constant alone is the result.
    if (maskedSlots.isEmpty()) {
        return irConst(constantBits)
    }
    // A zero constant can be left out; the low bit stays 0 whatever the masked slots yield.
    if (constantBits == 0) {
        return maskedSlots.reduce { combined, next -> irOr(combined, next) }
    }
    // Otherwise: constantBits or a or b ... or z
    val seed: IrExpression = irConst(constantBits)
    return maskedSlots.fold(seed) { combined, next -> irOr(combined, next) }
}
| |
/**
 * Tracks receiver usage and redirects reads of remapped parameters.
 *
 * Function scopes are visited innermost first. Each one whose extension or dispatch receiver is
 * the value being read is marked accordingly; the first one that has a remapping for the value
 * yields a read of the remapped declaration instead. Without a remapping the expression is
 * returned unchanged.
 */
override fun visitGetValue(expression: IrGetValue): IrExpression {
    val target = expression.symbol.owner
    for (scope in scopeStack.asReversed()) {
        if (scope !is Scope.FunctionScope) continue
        val function = scope.function
        if (function.extensionReceiverParameter == target) {
            scope.markGetExtensionReceiver()
        }
        if (function.dispatchReceiverParameter == target) {
            scope.markGetDispatchReceiver()
        }
        scope.remappedParams[target]?.let { remapped ->
            return irGet(remapped)
        }
    }
    return expression
}
| |
/**
 * Rewrites a return inside a composable scope so that the end calls collected from the enclosing
 * scopes run before control leaves.
 *
 * For a (nullable) Unit value the end calls are placed directly before the return. Otherwise the
 * returned value is first stored in a temporary, the end calls follow, and a new return yields
 * the temporary.
 */
override fun visitReturn(expression: IrReturn): IrExpression {
    if (!isInComposableScope) return super.visitReturn(expression)
    expression.transformChildren()

    val endCalls = mutableStatementContainer()
    encounteredReturn(expression.returnTargetSymbol) { endCalls.statements.add(it) }

    if (expression.value.type.isUnitOrNullableUnit()) {
        return expression.wrap(before = listOf(endCalls))
    }

    val returnValue = irTemporary(expression.value, nameHint = "return")
    val rebuiltReturn = IrReturnImpl(
        expression.startOffset,
        expression.endOffset,
        expression.type,
        expression.returnTargetSymbol,
        irGet(returnValue)
    )
    return returnValue.wrap(expression.type, after = listOf(endCalls, rebuiltReturn))
}
| |
/**
 * Inside a composable scope, places a container in front of the jump that the enclosing scopes
 * can fill with their end calls.
 */
override fun visitBreakContinue(jump: IrBreakContinue): IrExpression {
    if (!isInComposableScope) return super.visitBreakContinue(jump)
    val endCalls = mutableStatementContainer()
    encounteredJump(jump) { endCall -> endCalls.statements.add(endCall) }
    return jump.wrap(before = listOf(endCalls))
}
| |
/** Routes do-while loops inside a composable scope through [handleLoop]. */
override fun visitDoWhileLoop(loop: IrDoWhileLoop): IrExpression =
    if (isInComposableScope) handleLoop(loop as IrLoopBase) else super.visitDoWhileLoop(loop)
| |
/** Routes while loops inside a composable scope through [handleLoop]. */
override fun visitWhileLoop(loop: IrWhileLoop): IrExpression =
    if (isInComposableScope) handleLoop(loop as IrLoopBase) else super.visitWhileLoop(loop)
| |
/**
 * Transforms the children of [loop] inside a fresh loop scope. If that scope saw composable
 * calls the loop is turned into a coalescable group; otherwise the loop itself is returned.
 */
private fun handleLoop(loop: IrLoopBase): IrExpression {
    val loopScope = Scope.LoopScope(loop)
    withScope(loopScope) {
        loop.transformChildren()
    }
    if (!loopScope.hasComposableCalls) return loop
    return loop.asCoalescableGroup(loopScope)
}
| |
/**
 * Transforms a `when` expression inside a composable scope and decides where groups are needed.
 *
 * Every branch is rebuilt into a new [IrWhenImpl]; each condition and result is transformed in
 * its own branch scope so that composable calls can be attributed to it. Depending on where the
 * calls were found, results and/or conditions are wrapped in replaceable groups and the whole
 * expression may be turned into a coalescable group.
 */
| override fun visitWhen(expression: IrWhen): IrExpression { |
| if (!isInComposableScope) return super.visitWhen(expression) |
| |
| // Composable calls in conditions are more expensive than composable calls in the different |
| // result branches of the when clause. This is because if we have N branches of a when |
| // clause, we will always execute exactly 1 result branch, but we will execute 0-N of the |
| // conditions. This means that if only the results have composable calls, we can use |
| // replaceable groups to represent the entire expression. If a condition has a composable |
| // call in it, we need to place the whole expression in a Container group, since a variable |
| // number of them will be created. The exception here is the first branch's condition, |
| // since it will *always* be executed. As a result, if only the first conditional has a |
| // composable call in it, we can avoid creating a group for it since it is not |
| // conditionally executed. |
| var needsWrappingGroup = false |
| var someResultsHaveCalls = false |
| var hasElseBranch = false |
| |
| val transformed = IrWhenImpl( |
| expression.startOffset, |
| expression.endOffset, |
| expression.type, |
| expression.origin |
| ) |
// condScopes and resultScopes are kept index-aligned with transformed.branches.
| val resultScopes = mutableListOf<Scope.BranchScope>() |
| val condScopes = mutableListOf<Scope.BranchScope>() |
| val whenScope = withScope(Scope.WhenScope()) { |
| expression.branches.forEachIndexed { index, it -> |
| if (it is IrElseBranch) { |
| hasElseBranch = true |
| val (resultScope, result) = it.result.transformWithScope(Scope.BranchScope()) |
| |
// The else condition is not transformed; a fresh scope keeps the lists aligned.
| condScopes.add(Scope.BranchScope()) |
| resultScopes.add(resultScope) |
| |
| someResultsHaveCalls = someResultsHaveCalls || resultScope.hasComposableCalls |
| transformed.branches.add( |
| IrElseBranchImpl( |
| it.startOffset, |
| it.endOffset, |
| it.condition, |
| result |
| ) |
| ) |
| } else { |
| val (condScope, condition) = it |
| .condition |
| .transformWithScope(Scope.BranchScope()) |
| val (resultScope, result) = it |
| .result |
| .transformWithScope(Scope.BranchScope()) |
| |
| condScopes.add(condScope) |
| resultScopes.add(resultScope) |
| |
| // the first condition is always executed so if it has a composable call in it, |
| // it doesn't necessitate a group |
| needsWrappingGroup = |
| needsWrappingGroup || (index != 0 && condScope.hasComposableCalls) |
| someResultsHaveCalls = someResultsHaveCalls || resultScope.hasComposableCalls |
| transformed.branches.add( |
| IrBranchImpl( |
| it.startOffset, |
| it.endOffset, |
| condition, |
| result |
| ) |
| ) |
| } |
| } |
| } |
| |
| // If we are putting groups around the result branches, we need to guarantee that exactly |
| // one result branch is executed. We do this by adding an else branch if there is not |
| // one already. Note that we only need to do this if we aren't going to wrap the if |
| // statement in a group entirely, which we will do if the conditions have calls in them. |
| // NOTE: we might also be able to assume that the when is exhaustive if it has a non-unit |
| // resulting type, since the type system should enforce that. |
| if (!hasElseBranch && someResultsHaveCalls && !needsWrappingGroup) { |
| condScopes.add(Scope.BranchScope()) |
| resultScopes.add(Scope.BranchScope()) |
| transformed.branches.add( |
| IrElseBranchImpl( |
| UNDEFINED_OFFSET, |
| UNDEFINED_OFFSET, |
| condition = IrConstImpl( |
| UNDEFINED_OFFSET, |
| UNDEFINED_OFFSET, |
| context.irBuiltIns.booleanType, |
| IrConstKind.Boolean, |
| true |
| ), |
| result = IrBlockImpl( |
| UNDEFINED_OFFSET, |
| UNDEFINED_OFFSET, |
| context.irBuiltIns.unitType, |
| null, |
| emptyList() |
| ) |
| ) |
| ) |
| } |
| |
| forEachWith(transformed.branches, condScopes, resultScopes) { it, condScope, resultScope -> |
| // If the conditional block doesn't have a composable call in it, we don't need |
| // to generate a group around it because we will be generating one around the entire |
| // if statement |
| if (needsWrappingGroup && condScope.hasComposableCalls) { |
| it.condition = it.condition.asReplaceableGroup(condScope) |
| } |
| if ( |
| // if no wrapping group but some results have calls, we have to have every result |
| // be a group so that we have a consistent number of groups during execution |
| (someResultsHaveCalls && !needsWrappingGroup) || |
| // if we are wrapping the if with a group, then we only need to add a group when |
| // the block has composable calls |
| (needsWrappingGroup && resultScope.hasComposableCalls) |
| ) { |
| it.result = it.result.asReplaceableGroup(resultScope) |
| } |
| } |
| |
| return if (needsWrappingGroup) { |
| transformed.asCoalescableGroup(whenScope) |
| } else { |
| transformed |
| } |
| } |
| |
| sealed class Scope(val name: String) { |
/**
 * Scope for a function body being transformed.
 *
 * On construction the function's value parameters are classified by name: names that do not
 * start with `$` count as real value parameters; the composer and key parameters are captured;
 * default and changed parameters are collected by name prefix; `$anonymous$parameter...` and
 * `$name$for$destructuring...` are ignored; any other `$`-prefixed name is an error.
 *
 * @property function the function this scope belongs to.
 */
| class FunctionScope( |
| val function: IrFunction, |
| transformer: ComposableFunctionBodyTransformer |
| ) : BlockScope("fun ${function.name.asString()}") { |
| val remappedParams = mutableMapOf<IrValueDeclaration, IrValueDeclaration>() |
| val paramsToSlots = mutableMapOf<IrValueDeclaration, Int>() |
| |
// Counter behind getNameForTemporary; each call consumes one index.
| private var lastTemporaryIndex: Int = 0 |
| |
| private fun nextTemporaryIndex(): Int = lastTemporaryIndex++ |
| |
| var composerParameter: IrValueParameter? = null |
| private set |
| |
| var keyParameter: IrValueParameter? = null |
| private set |
| |
| var defaultParameter: IrDefaultBitMaskValue? = null |
| private set |
| |
| var changedParameter: IrChangedBitMaskValue? = null |
| private set |
| |
| var realValueParamCount: Int = 0 |
| private set |
| |
| // slotCount will include the dispatchReceiver and extensionReceivers |
| var slotCount: Int = 0 |
| private set |
| |
| var dirty: IrChangedBitMaskValue? = null |
| |
| var dispatchReceiverUsed: Boolean = false |
| private set |
| |
| var extensionReceiverUsed: Boolean = false |
| private set |
| |
| fun markGetDispatchReceiver() { |
| dispatchReceiverUsed = true |
| } |
| |
| fun markGetExtensionReceiver() { |
| extensionReceiverUsed = true |
| } |
| |
| init { |
| val defaultParams = mutableListOf<IrValueParameter>() |
| val changedParams = mutableListOf<IrValueParameter>() |
| for (param in function.valueParameters) { |
| val paramName = param.name.asString() |
| when { |
| !paramName.startsWith('$') -> realValueParamCount++ |
| paramName == KtxNameConventions.COMPOSER_PARAMETER.identifier -> |
| composerParameter = param |
| paramName == KtxNameConventions.KEY_PARAMETER.identifier -> |
| keyParameter = param |
| paramName.startsWith(KtxNameConventions.DEFAULT_PARAMETER.identifier) -> |
| defaultParams += param |
| paramName.startsWith(KtxNameConventions.CHANGED_PARAMETER.identifier) -> |
| changedParams += param |
| paramName.startsWith("\$anonymous\$parameter") -> Unit |
| paramName.startsWith("\$name\$for\$destructuring") -> Unit |
| else -> { |
| error("Unexpected parameter name: $paramName") |
| } |
| } |
| } |
// One slot per real value parameter, plus one for an extension receiver and one for a
// dispatch receiver; a LOCAL_FUNCTION_FOR_LAMBDA without a dispatch receiver still gets
// the extra slot.
| slotCount = realValueParamCount |
| if (function.extensionReceiverParameter != null) slotCount++ |
| if (function.dispatchReceiverParameter != null) { |
| slotCount++ |
| } else if (function.origin == IrDeclarationOrigin.LOCAL_FUNCTION_FOR_LAMBDA) { |
| slotCount++ |
| } |
// The changed mask is only created when a composer parameter was found.
| changedParameter = if (composerParameter != null) |
| transformer.IrChangedBitMaskValueImpl( |
| changedParams, |
| slotCount |
| ) |
| else |
| null |
| defaultParameter = if (defaultParams.isNotEmpty()) |
| transformer.IrDefaultBitMaskValueImpl( |
| defaultParams, |
| realValueParamCount |
| ) |
| else |
| null |
| } |
| |
// Declared after the init block, so it reads the composerParameter found there.
| val isComposable = composerParameter != null |
| |
| fun getNameForTemporary(nameHint: String?): String { |
| val index = nextTemporaryIndex() |
| return if (nameHint != null) "tmp${index}_$nameHint" else "tmp$index" |
| } |
| } |
| |
| abstract class BlockScope(name: String) : Scope(name) { |
| private val extraEndLocations = mutableListOf<(IrExpression) -> Unit>() |
| |
| fun realizeGroup(makeEnd: () -> IrExpression) { |
| realizeCoalescableGroup() |
| realizeEndCalls(makeEnd) |
| } |
| |
| fun markComposableCall() { |
| hasComposableCalls = true |
| if (coalescableChild != null) { |
| // if a call happens after the coalescable child group, then we should |
| // realize the group of the coalescable child |
| shouldRealizeCoalescableChild = true |
| } |
| } |
| |
| fun markReturn(extraEndLocation: (IrExpression) -> Unit) { |
| hasReturn = true |
| extraEndLocations.push(extraEndLocation) |
| } |
| |
| fun markJump(extraEndLocation: (IrExpression) -> Unit) { |
| hasJump = true |
| extraEndLocations.push(extraEndLocation) |
| } |
| |
| fun markCoalescableGroup( |
| scope: BlockScope, |
| realizeGroup: () -> Unit, |
| makeEnd: () -> IrExpression |
| ) { |
| coalescableChild = scope |
| realizeCoalescableChildGroup = { |
| realizeGroup() |
| scope.realizeGroup(makeEnd) |
| realizeCoalescableChildGroup = { error("Attempted to realize group twice") } |
| } |
| } |
| |
| private fun realizeCoalescableGroup() { |
| if (shouldRealizeCoalescableChild) { |
| realizeCoalescableChildGroup() |
| } else { |
| coalescableChild?.realizeCoalescableGroup() |
| } |
| } |
| |
| private fun realizeEndCalls(makeEnd: () -> IrExpression) { |
| extraEndLocations.forEach { |
| it(makeEnd()) |
| } |
| } |
| |
| var hasComposableCalls = false |
| private set |
| var hasReturn = false |
| private set |
| var hasJump = false |
| private set |
| private var realizeCoalescableChildGroup = {} |
| private var shouldRealizeCoalescableChild = false |
| private var coalescableChild: BlockScope? = null |
| } |
| class ClassScope(name: Name) : Scope("class ${name.asString()}") |
| class PropertyScope(name: Name) : Scope("val ${name.asString()}") |
| class FieldScope(name: Name) : Scope("field ${name.asString()}") |
| class FileScope(name: FqName) : Scope("file $name") |
| class LoopScope(val loop: IrLoop) : BlockScope("loop") |
| class WhenScope : BlockScope("when") |
| class BranchScope : BlockScope("branch") |
| } |
| |
/**
 * [IrDefaultBitMaskValue] backed by the synthetic default-mask parameters of a function.
 *
 * @param params the parameters holding the mask; each entry covers a window of [BITS_PER_INT]
 * parameters, and there must be exactly `defaultParamCount(count)` of them
 * @param count the number of parameters described by the mask
 */
inner class IrDefaultBitMaskValueImpl(
    private val params: List<IrValueParameter>,
    private val count: Int
) : IrDefaultBitMaskValue {

    init {
        val provided = params.size
        val needed = defaultParamCount(count)
        require(provided == needed) {
            "Function with $count params had $provided default params but expected $needed"
        }
    }

    /**
     * Builds `(%default and bit)` for the parameter at [index]. A set bit in the default mask
     * means the argument was NOT provided.
     */
    override fun irIsolateBitAtIndex(index: Int): IrExpression {
        // NOTE(review): this accepts index == count, which is one past the last described
        // parameter — `index < count` may have been intended; confirm against callers.
        require(index <= count)
        val maskWord = irGet(params[defaultsParamIndex(index)])
        val bit = irConst(0b1 shl defaultsBitIndex(index))
        return irAnd(maskWord, bit)
    }

    /**
     * Builds an expression that is true when at least one parameter was both *provided* and is
     * flagged in [unstable]. [unstable] must have exactly one entry per described parameter.
     */
    override fun irHasAnyProvidedAndUnstable(unstable: BooleanArray): IrExpression {
        require(count == unstable.size)
        val perWordChecks = params.mapIndexed { wordIndex, word ->
            // the window of parameters whose bits live in this mask word
            val from = wordIndex * BITS_PER_INT
            val to = min(from + BITS_PER_INT, count)
            val unstableMask = bitMask(*unstable.sliceArray(from until to))
            // ~$default and unstableMask is non-zero if any parameter in the window was
            // *provided* AND *unstable*
            val providedAndUnstable = irAnd(irInv(irGet(word)), irConst(unstableMask))
            irNotEqual(providedAndUnstable, irConst(0))
        }
        // for a single check, reduce simply hands that check back
        return perWordChecks.reduce { lhs, rhs -> irOrOr(lhs, rhs) }
    }

    /**
     * Passes every mask parameter on to [fn] as consecutive value arguments, beginning at
     * [startIndex].
     */
    override fun putAsValueArgumentIn(fn: IrFunctionAccessExpression, startIndex: Int) {
        for ((offset, param) in params.withIndex()) {
            fn.putValueArgument(startIndex + offset, irGet(param))
        }
    }
}
| |
/**
 * [IrChangedBitMaskValue] backed by a list of int-typed value declarations (`$changed`
 * parameters or temporaries copied from them).
 *
 * Each declaration in [params] covers [SLOTS_PER_INT] slots; the lowest bit of the first
 * declaration is handled separately (see [irLowBit]).
 *
 * @param params the declarations holding the mask; there must be exactly
 * `changedParamCount(count, 0)` of them
 * @param count the number of slots described by the mask
 */
open inner class IrChangedBitMaskValueImpl(
    private val params: List<IrValueDeclaration>,
    private val count: Int
) : IrChangedBitMaskValue {
    /** Index into the backing declarations of the one that holds the bits for [slot]. */
    protected fun paramIndexForSlot(slot: Int): Int = slot / SLOTS_PER_INT

    init {
        val actual = params.size
        val expected = changedParamCount(count, 0)
        require(actual == expected) {
            "Function with $count params had $actual changed params but expected $expected"
        }
    }

    /** Builds `(%changed and 0b1)` on the first backing declaration. */
    override fun irLowBit(): IrExpression {
        return irAnd(
            irGet(params[0]),
            irConst(0b1)
        )
    }

    /** Builds `(%changed and 0b11)` with the `0b11` pattern positioned at [slot]. */
    override fun irIsolateBitsAtSlot(slot: Int): IrExpression {
        return irAnd(
            irGet(params[paramIndexForSlot(slot)]),
            irBitsForSlot(0b11, slot)
        )
    }

    /**
     * Builds an expression that is true when the mask reports anything other than "all slots
     * are the same", or when the low bit is set.
     */
    override fun irHasDifferences(): IrExpression {
        if (count == 0) {
            // for 0 slots (no params), we can create a shortcut expression of just checking the
            // low-bit for non-zero. Since all of the higher bits will also be 0, we can just
            // simplify this to check if dirty is non-zero
            return irNotEqual(
                irGet(params[0]),
                irConst(0)
            )
        }

        val expressions = params.mapIndexed { index, param ->
            // the window of slots stored in this declaration
            val start = index * SLOTS_PER_INT
            val end = min(start + SLOTS_PER_INT, count)

            // makes an int with each slot having 0b01 mask and the low bit being 0.
            // so for 3 slots, we would get 0b 01 01 01 0.
            // This pattern is useful because we can and + xor it with our $changed bitmask and it
            // will only be non-zero if any of the slots were DIFFERENT or UNCERTAIN.
            val bitPattern = (start until end).fold(0) { mask, slot ->
                mask or bitsForSlot(0b01, slot)
            }

            // we use this pattern with the low bit set to 1 in the "and", and the low bit set to 0
            // for the "xor". This means that if the low bit was set, we will get 1 in the resulting
            // low bit. Since we use this calculation to determine if we need to run the body of the
            // function, this is exactly what we want.

            // $dirty and (0b 01 ... 01 1) xor (0b 01 ... 01 0)
            irNotEqual(
                irXor(
                    irAnd(
                        irGet(param),
                        irConst(bitPattern or 0b1)
                    ),
                    irConst(bitPattern or 0b0)
                ),
                irConst(0) // anything non-zero means we have differences
            )
        }
        // for a single expression, reduce simply hands that expression back
        return expressions.reduce { lhs, rhs -> irOrOr(lhs, rhs) }
    }

    /**
     * Copies every backing declaration into an int-typed temporary and returns a mask over the
     * copies.
     *
     * The first temporary uses [nameHint] as is; every further one uses the hint with its index
     * appended. When [nameHint] is null, all temporaries get a null hint (previously the
     * string template turned it into hints such as "null1").
     *
     * @param isVar forwarded to `irTemporary`
     * @param exactName forwarded to `irTemporary`
     */
    override fun irCopyToTemporary(
        nameHint: String?,
        isVar: Boolean,
        exactName: Boolean
    ): IrChangedBitMaskVariable {
        val temps = params.mapIndexed { index, param ->
            // keep a missing hint missing instead of interpolating the text "null"
            val hint = if (index == 0) nameHint else nameHint?.let { "$it$index" }
            irTemporary(
                irGet(param),
                hint,
                context.irBuiltIns.intType,
                isVar,
                exactName
            )
        }
        return IrChangedBitMaskVariableImpl(temps, count)
    }

    /**
     * Passes every backing declaration on to [fn] as consecutive value arguments, beginning at
     * [startIndex]. The first one is or-ed with `0b1` when [lowBit] is true (and with `0b0`
     * otherwise); the others are passed through unchanged.
     */
    override fun putAsValueArgumentInWithLowBit(
        fn: IrFunctionAccessExpression,
        startIndex: Int,
        lowBit: Boolean
    ) {
        params.forEachIndexed { index, param ->
            fn.putValueArgument(
                startIndex + index,
                if (index == 0) {
                    irOr(irGet(param), irConst(if (lowBit) 0b1 else 0b0))
                } else {
                    irGet(param)
                }
            )
        }
    }

    /**
     * Builds an expression that reads the declaration holding [fromSlot] and shifts it so that
     * the bits of [fromSlot] end up at the position of [toSlot] (both taken relative to their
     * own int). No shift is emitted when the two positions coincide.
     */
    override fun irShiftBits(fromSlot: Int, toSlot: Int): IrExpression {
        val fromSlotAdjusted = fromSlot.rem(SLOTS_PER_INT)
        val toSlotAdjusted = toSlot.rem(SLOTS_PER_INT)
        // the slot distance is doubled to get the bit distance
        val bitsToShiftLeft = (toSlotAdjusted - fromSlotAdjusted) * 2
        val value = irGet(params[paramIndexForSlot(fromSlot)])

        if (bitsToShiftLeft == 0) return value
        val int = context.builtIns.intType
        val shiftLeft = context.symbols.getBinaryOperator(
            OperatorNames.SHL,
            int,
            int
        )
        val shiftRight = context.symbols.getBinaryOperator(
            OperatorNames.SHR,
            int,
            int
        )

        // a negative distance becomes a right shift by the absolute amount
        return irCall(
            if (bitsToShiftLeft > 0) shiftLeft else shiftRight,
            null,
            value,
            null,
            irConst(abs(bitsToShiftLeft))
        )
    }
}
| |
/**
 * A changed bit mask whose backing storage is a list of local variables ([temps]) rather than
 * parameters, which additionally allows bits to be or-ed into it.
 *
 * @param temps the variables holding the mask
 * @param count the number of slots described by the mask
 */
inner class IrChangedBitMaskVariableImpl(
    private val temps: List<IrVariable>,
    count: Int
) : IrChangedBitMaskVariable, IrChangedBitMaskValueImpl(temps, count) {
    /** The variable declarations themselves, as statements. */
    override fun asStatements(): List<IrStatement> = temps

    /**
     * Builds `temp = temp or value`, where `temp` is the variable that stores the bits of
     * [slot]. [value] is used as given; it is not shifted here.
     */
    override fun irOrSetBitsAtSlot(slot: Int, value: IrExpression): IrExpression {
        val target = temps[paramIndexForSlot(slot)]
        val merged = irOr(irGet(target), value)
        return irSet(target, merged)
    }
}
| } |
| |
/**
 * Walks three lists in lock step, calling [fn] with the elements found at the same index.
 *
 * The number of iterations is the size of [a]. [b] and [c] are indexed directly, so they must
 * be at least as long as [a]; any extra elements they hold are ignored.
 */
inline fun <A, B, C> forEachWith(a: List<A>, b: List<B>, c: List<C>, fn: (A, B, C) -> Unit) {
    repeat(a.size) { position ->
        fn(a[position], b[position], c[position])
    }
}