[go: nahoru, domu]

Refactor LazyLayoutAnimateItemModifierNode to not be a Modifier.Node

We need this to decouple the lifecycle of this object from the Modifier.Node lifecycle. When an item is removed, its modifiers are detached, but we still need to keep displaying its content during the upcoming disappearance animation.

Test: existing tests in lazy package
Change-Id: Ib09c92af5d742d520ad71dc0c19d04b9770e8eb1
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyItemScopeImpl.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyItemScopeImpl.kt
index 33940b8..cbc947b 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyItemScopeImpl.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyItemScopeImpl.kt
@@ -22,20 +22,17 @@
 import androidx.compose.animation.core.spring
 import androidx.compose.animation.core.tween
 import androidx.compose.foundation.ExperimentalFoundationApi
-import androidx.compose.foundation.lazy.layout.LazyLayoutAnimateItemModifierNode
+import androidx.compose.foundation.lazy.layout.LazyLayoutAnimationSpecsNode
 import androidx.compose.runtime.State
 import androidx.compose.runtime.mutableIntStateOf
 import androidx.compose.ui.Modifier
 import androidx.compose.ui.layout.Measurable
 import androidx.compose.ui.layout.MeasureResult
 import androidx.compose.ui.layout.MeasureScope
-import androidx.compose.ui.node.DelegatingNode
 import androidx.compose.ui.node.LayoutModifierNode
 import androidx.compose.ui.node.ModifierNodeElement
-import androidx.compose.ui.node.ParentDataModifierNode
 import androidx.compose.ui.platform.InspectorInfo
 import androidx.compose.ui.unit.Constraints
-import androidx.compose.ui.unit.Density
 import androidx.compose.ui.unit.IntOffset
 import kotlin.math.roundToInt
 
@@ -182,15 +179,15 @@
 
 private data class AnimateItemElement(
     val appearanceSpec: FiniteAnimationSpec<Float>?,
-    val placementSpec: FiniteAnimationSpec<IntOffset>?,
-) : ModifierNodeElement<AnimateItemPlacementNode>() {
+    val placementSpec: FiniteAnimationSpec<IntOffset>?
+) : ModifierNodeElement<LazyLayoutAnimationSpecsNode>() {
 
-    override fun create(): AnimateItemPlacementNode =
-        AnimateItemPlacementNode(appearanceSpec, placementSpec)
+    override fun create(): LazyLayoutAnimationSpecsNode =
+        LazyLayoutAnimationSpecsNode(appearanceSpec, placementSpec)
 
-    override fun update(node: AnimateItemPlacementNode) {
-        node.delegatingNode.appearanceSpec = appearanceSpec
-        node.delegatingNode.placementSpec = placementSpec
+    override fun update(node: LazyLayoutAnimationSpecsNode) {
+        node.appearanceSpec = appearanceSpec
+        node.placementSpec = placementSpec
     }
 
     override fun InspectorInfo.inspectableProperties() {
@@ -199,18 +196,3 @@
         value = placementSpec
     }
 }
-
-private class AnimateItemPlacementNode(
-    appearanceSpec: FiniteAnimationSpec<Float>?,
-    placementSpec: FiniteAnimationSpec<IntOffset>?,
-) : DelegatingNode(), ParentDataModifierNode {
-
-    val delegatingNode = delegate(
-        LazyLayoutAnimateItemModifierNode(
-            appearanceSpec,
-            placementSpec
-        )
-    )
-
-    override fun Density.modifyParentData(parentData: Any?): Any = delegatingNode
-}
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyList.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyList.kt
index d9f7585..ca072ec 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyList.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyList.kt
@@ -292,7 +292,8 @@
                     spacing = spacing,
                     visualOffset = visualItemOffset,
                     key = key,
-                    contentType = contentType
+                    contentType = contentType,
+                    animator = state.itemAnimator
                 )
             }
         }
@@ -341,6 +342,9 @@
             hasLookaheadPassOccurred = hasLookaheadPassOccurred,
             isLookingAhead = isLookingAhead,
             postLookaheadLayoutInfo = state.postLookaheadLayoutInfo,
+            coroutineScope = requireNotNull(state.coroutineScope) {
+                "coroutineScope should not be null"
+            },
             layout = { width, height, placement ->
                 layout(
                     containerConstraints.constrainWidth(width + totalHorizontalPadding),
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyListItemAnimator.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyListItemAnimator.kt
index 023c327..978c87b 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyListItemAnimator.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyListItemAnimator.kt
@@ -16,11 +16,13 @@
 
 package androidx.compose.foundation.lazy
 
-import androidx.compose.foundation.lazy.layout.LazyLayoutAnimateItemModifierNode
+import androidx.compose.foundation.lazy.layout.LazyLayoutAnimation
+import androidx.compose.foundation.lazy.layout.LazyLayoutAnimationSpecsNode
 import androidx.compose.foundation.lazy.layout.LazyLayoutKeyIndexMap
 import androidx.compose.ui.unit.IntOffset
 import androidx.compose.ui.util.fastAny
 import androidx.compose.ui.util.fastForEach
+import kotlinx.coroutines.CoroutineScope
 
 /**
  * Handles the item animations when it is set via [LazyItemScope.animateItemPlacement].
@@ -28,11 +30,11 @@
  * This class is responsible for:
  * - animating item appearance for the new items.
  * - detecting when item position changed, figuring our start/end offsets and starting the
- * animations for placement animations. *
+ * animations for placement animations.
  */
 internal class LazyListItemAnimator {
-    // contains the keys of the active items with animation node.
-    private val activeKeys = mutableSetOf<Any>()
+    // state containing relevant info for active items.
+    private val keyToItemInfoMap = mutableMapOf<Any, ItemInfo>()
 
     // snapshot of the key to index map used for the last measuring.
     private var keyIndexMap: LazyLayoutKeyIndexMap? = null
@@ -60,14 +62,15 @@
         itemProvider: LazyListMeasuredItemProvider,
         isVertical: Boolean,
         isLookingAhead: Boolean,
-        hasLookaheadOccurred: Boolean
+        hasLookaheadOccurred: Boolean,
+        coroutineScope: CoroutineScope
     ) {
         val previousKeyToIndexMap = this.keyIndexMap
         val keyIndexMap = itemProvider.keyIndexMap
         this.keyIndexMap = keyIndexMap
 
         val hasAnimations = positionedItems.fastAny { it.hasAnimations }
-        if (!hasAnimations && activeKeys.isEmpty()) {
+        if (!hasAnimations && keyToItemInfoMap.isEmpty()) {
             // no animations specified - no work needed
             reset()
             return
@@ -89,14 +92,18 @@
         // means lookahead pass, or regular pass when not in a lookahead scope.
         val shouldSetupAnimation = isLookingAhead || !hasLookaheadOccurred
         // first add all items we had in the previous run
-        movingAwayKeys.addAll(activeKeys)
+        movingAwayKeys.addAll(keyToItemInfoMap.keys)
         // iterate through the items which are visible (without animated offsets)
         positionedItems.fastForEach { item ->
             // remove items we have in the current one as they are still visible.
             movingAwayKeys.remove(item.key)
             if (item.hasAnimations) {
-                if (!activeKeys.contains(item.key)) {
-                    activeKeys += item.key
+                val itemInfo = keyToItemInfoMap[item.key]
+                // there is no state associated with this item yet
+                if (itemInfo == null) {
+                    val newItemInfo = ItemInfo()
+                    newItemInfo.updateAnimation(item, coroutineScope)
+                    keyToItemInfoMap[item.key] = newItemInfo
                     val previousIndex = previousKeyToIndexMap?.getIndex(item.key) ?: -1
                     if (item.index != previousIndex && previousIndex != -1) {
                         if (previousIndex < previousFirstVisibleIndex) {
@@ -106,22 +113,25 @@
                             movingInFromEndBound.add(item)
                         }
                     } else {
-                        initializeNode(
+                        initializeAnimation(
                             item,
-                            item.getOffset(0).let { if (item.isVertical) it.y else it.x }
+                            item.getOffset(0).let { if (item.isVertical) it.y else it.x },
+                            newItemInfo
                         )
                         if (previousIndex == -1 && previousKeyToIndexMap != null) {
-                            item.forEachNode { _, node ->
-                                node.animateAppearance()
+                            newItemInfo.animations.forEach {
+                                it?.animateAppearance()
                             }
                         }
                     }
                 } else {
                     if (shouldSetupAnimation) {
-                        item.forEachNode { _, node ->
-                            if (node.rawOffset != LazyLayoutAnimateItemModifierNode.NotInitialized
+                        itemInfo.updateAnimation(item, coroutineScope)
+                        itemInfo.animations.forEach { animation ->
+                            if (animation != null &&
+                                animation.rawOffset != LazyLayoutAnimation.NotInitialized
                             ) {
-                                node.rawOffset += scrollOffset
+                                animation.rawOffset += scrollOffset
                             }
                         }
                         startPlacementAnimationsIfNeeded(item)
@@ -129,7 +139,7 @@
                 }
             } else {
                 // no animation, clean up if needed
-                activeKeys.remove(item.key)
+                keyToItemInfoMap.remove(item.key)
             }
         }
 
@@ -139,7 +149,7 @@
             movingInFromStartBound.fastForEach { item ->
                 accumulatedOffset += item.size
                 val mainAxisOffset = 0 - accumulatedOffset
-                initializeNode(item, mainAxisOffset)
+                initializeAnimation(item, mainAxisOffset)
                 startPlacementAnimationsIfNeeded(item)
             }
             accumulatedOffset = 0
@@ -147,7 +157,7 @@
             movingInFromEndBound.fastForEach { item ->
                 val mainAxisOffset = mainAxisLayoutSize + accumulatedOffset
                 accumulatedOffset += item.size
-                initializeNode(item, mainAxisOffset)
+                initializeAnimation(item, mainAxisOffset)
                 startPlacementAnimationsIfNeeded(item)
             }
         }
@@ -158,19 +168,15 @@
             val newIndex = keyIndexMap.getIndex(key)
 
             if (newIndex == -1) {
-                activeKeys.remove(key)
+                keyToItemInfoMap.remove(key)
             } else {
                 val item = itemProvider.getAndMeasure(newIndex)
+                val itemInfo = keyToItemInfoMap.getValue(key)
                 // check if we have any active placement animation on the item
-                var inProgress = false
-                repeat(item.placeablesCount) {
-                    if (item.getParentData(it).node?.isPlacementAnimationInProgress == true) {
-                        inProgress = true
-                        return@repeat
-                    }
-                }
+                val inProgress =
+                    itemInfo.animations.any { it?.isPlacementAnimationInProgress == true }
                 if ((!inProgress && newIndex == previousKeyToIndexMap?.getIndex(key))) {
-                    activeKeys.remove(key)
+                    keyToItemInfoMap.remove(key)
                 } else {
                     if (newIndex < firstVisibleIndex) {
                         movingAwayToStartBound.add(item)
@@ -222,14 +228,15 @@
      * for example when we snap to a new position.
      */
     fun reset() {
-        activeKeys.clear()
+        keyToItemInfoMap.clear()
         keyIndexMap = LazyLayoutKeyIndexMap.Empty
         firstVisibleIndex = -1
     }
 
-    private fun initializeNode(
+    private fun initializeAnimation(
         item: LazyListMeasuredItem,
-        mainAxisOffset: Int
+        mainAxisOffset: Int,
+        itemInfo: ItemInfo = keyToItemInfoMap.getValue(item.key)
     ) {
         val firstPlaceableOffset = item.getOffset(0)
 
@@ -240,39 +247,77 @@
         }
 
         // initialize offsets
-        item.forEachNode { placeableIndex, node ->
-            val diffToFirstPlaceableOffset =
-                item.getOffset(placeableIndex) - firstPlaceableOffset
-            node.rawOffset = targetFirstPlaceableOffset + diffToFirstPlaceableOffset
+        itemInfo.animations.forEachIndexed { placeableIndex, animation ->
+            if (animation != null) {
+                val diffToFirstPlaceableOffset =
+                    item.getOffset(placeableIndex) - firstPlaceableOffset
+                animation.rawOffset = targetFirstPlaceableOffset + diffToFirstPlaceableOffset
+            }
         }
     }
 
     private fun startPlacementAnimationsIfNeeded(item: LazyListMeasuredItem) {
-        item.forEachNode { placeableIndex, node ->
-            val newTarget = item.getOffset(placeableIndex)
-            val currentTarget = node.rawOffset
-            if (currentTarget != LazyLayoutAnimateItemModifierNode.NotInitialized &&
-                currentTarget != newTarget
-            ) {
-                node.animatePlacementDelta(newTarget - currentTarget)
+        val itemInfo = keyToItemInfoMap.getValue(item.key)
+        itemInfo.animations.forEachIndexed { placeableIndex, animation ->
+            if (animation != null) {
+                val newTarget = item.getOffset(placeableIndex)
+                val currentTarget = animation.rawOffset
+                if (currentTarget != LazyLayoutAnimation.NotInitialized &&
+                    currentTarget != newTarget
+                ) {
+                    animation.animatePlacementDelta(newTarget - currentTarget)
+                }
+                animation.rawOffset = newTarget
             }
-            node.rawOffset = newTarget
         }
     }
 
-    private val Any?.node get() = this as? LazyLayoutAnimateItemModifierNode
+    fun getAnimation(key: Any, placeableIndex: Int): LazyLayoutAnimation? =
+        keyToItemInfoMap[key]?.animations?.getOrNull(placeableIndex)
 
     private val LazyListMeasuredItem.hasAnimations: Boolean
         get() {
-            forEachNode { _, _ -> return true }
+            repeat(placeablesCount) { index ->
+                getParentData(index).specs?.let {
+                    // found at least one
+                    return true
+                }
+            }
             return false
         }
 
-    private inline fun LazyListMeasuredItem.forEachNode(
-        block: (placeableIndex: Int, node: LazyLayoutAnimateItemModifierNode) -> Unit
-    ) {
-        repeat(placeablesCount) { index ->
-            getParentData(index).node?.let { block(index, it) }
+    private class ItemInfo {
+        /**
+         * This array will have the same amount of elements as there are placeables on the item.
+         * If the element is not null this means there are specs associated with the given placeable.
+         */
+        var animations = EmptyArray
+            private set
+
+        fun updateAnimation(positionedItem: LazyListMeasuredItem, coroutineScope: CoroutineScope) {
+            for (i in positionedItem.placeablesCount until animations.size) {
+                animations[i]?.stopAnimations()
+            }
+            if (animations.size != positionedItem.placeablesCount) {
+                animations = animations.copyOf(positionedItem.placeablesCount)
+            }
+            repeat(positionedItem.placeablesCount) { index ->
+                val specs = positionedItem.getParentData(index).specs
+                if (specs == null) {
+                    animations[index]?.stopAnimations()
+                    animations[index] = null
+                } else {
+                    val animation = animations[index] ?: LazyLayoutAnimation(coroutineScope).also {
+                        animations[index] = it
+                    }
+                    animation.appearanceSpec = specs.appearanceSpec
+                    animation.placementSpec = specs.placementSpec
+                }
+            }
         }
     }
 }
+
+private val Any?.specs get() = this as? LazyLayoutAnimationSpecsNode
+
+private val EmptyArray = emptyArray<LazyLayoutAnimation?>()
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyListMeasure.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyListMeasure.kt
index 3876e05..6402fc2 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyListMeasure.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyListMeasure.kt
@@ -33,6 +33,7 @@
 import kotlin.math.abs
 import kotlin.math.roundToInt
 import kotlin.math.sign
+import kotlinx.coroutines.CoroutineScope
 
 /**
  * Measures and calculates the positions for the requested items. The result is produced
@@ -62,6 +63,7 @@
     hasLookaheadPassOccurred: Boolean,
     isLookingAhead: Boolean,
     postLookaheadLayoutInfo: LazyListLayoutInfo?,
+    coroutineScope: CoroutineScope,
     @Suppress("PrimitiveInLambda")
     layout: (Int, Int, Placeable.PlacementScope.() -> Unit) -> MeasureResult
 ): LazyListMeasureResult {
@@ -79,7 +81,8 @@
             itemProvider = measuredItemProvider,
             isVertical = isVertical,
             isLookingAhead = isLookingAhead,
-            hasLookaheadOccurred = hasLookaheadPassOccurred
+            hasLookaheadOccurred = hasLookaheadPassOccurred,
+            coroutineScope = coroutineScope
         )
         return LazyListMeasureResult(
             firstVisibleItem = null,
@@ -321,7 +324,8 @@
             itemProvider = measuredItemProvider,
             isVertical = isVertical,
             isLookingAhead = isLookingAhead,
-            hasLookaheadOccurred = hasLookaheadPassOccurred
+            hasLookaheadOccurred = hasLookaheadPassOccurred,
+            coroutineScope = coroutineScope
         )
 
         val headerItem = if (headerIndexes.isNotEmpty()) {
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyListMeasuredItem.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyListMeasuredItem.kt
index 3420577..24a1642 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyListMeasuredItem.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/LazyListMeasuredItem.kt
@@ -17,8 +17,8 @@
 package androidx.compose.foundation.lazy
 
 import androidx.compose.foundation.ExperimentalFoundationApi
-import androidx.compose.foundation.lazy.layout.LazyLayoutAnimateItemModifierNode
-import androidx.compose.foundation.lazy.layout.LazyLayoutAnimateItemModifierNode.Companion.NotInitialized
+import androidx.compose.foundation.lazy.layout.DefaultLayerBlock
+import androidx.compose.foundation.lazy.layout.LazyLayoutAnimation.Companion.NotInitialized
 import androidx.compose.ui.Alignment
 import androidx.compose.ui.graphics.GraphicsLayerScope
 import androidx.compose.ui.layout.Placeable
@@ -52,7 +52,8 @@
      */
     private val visualOffset: IntOffset,
     override val key: Any,
-    override val contentType: Any?
+    override val contentType: Any?,
+    private val animator: LazyListItemAnimator
 ) : LazyListItemInfo {
     override var offset: Int = 0
         private set
@@ -144,30 +145,30 @@
             val minOffset = minMainAxisOffset - placeable.mainAxisSize
             val maxOffset = maxMainAxisOffset
             var offset = getOffset(index)
-            val animateNode = getParentData(index) as? LazyLayoutAnimateItemModifierNode
+            val animation = animator.getAnimation(key, index)
             val layerBlock: GraphicsLayerScope.() -> Unit
-            if (animateNode != null) {
+            if (animation != null) {
                 if (isLookingAhead) {
                     // Skip animation in lookahead pass
-                    animateNode.lookaheadOffset = offset
+                    animation.lookaheadOffset = offset
                 } else {
-                    val targetOffset = if (animateNode.lookaheadOffset != NotInitialized) {
-                        animateNode.lookaheadOffset
+                    val targetOffset = if (animation.lookaheadOffset != NotInitialized) {
+                        animation.lookaheadOffset
                     } else {
                         offset
                     }
-                    val animatedOffset = targetOffset + animateNode.placementDelta
+                    val animatedOffset = targetOffset + animation.placementDelta
                     // cancel the animation if current and target offsets are both out of the bounds
                     if ((targetOffset.mainAxis <= minOffset &&
                             animatedOffset.mainAxis <= minOffset) ||
                         (targetOffset.mainAxis >= maxOffset &&
                             animatedOffset.mainAxis >= maxOffset)
                     ) {
-                        animateNode.cancelPlacementAnimation()
+                        animation.cancelPlacementAnimation()
                     }
                     offset = animatedOffset
                 }
-                layerBlock = animateNode
+                layerBlock = animation
             } else {
                 layerBlock = DefaultLayerBlock
             }
@@ -192,8 +193,3 @@
 }
 
 private const val Unset = Int.MIN_VALUE
-
-/**
- * Block on [GraphicsLayerScope] which applies the default layer parameters.
- */
-private val DefaultLayerBlock: GraphicsLayerScope.() -> Unit = {}
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGrid.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGrid.kt
index ff95f02..327f8e1 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGrid.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGrid.kt
@@ -34,6 +34,7 @@
 import androidx.compose.foundation.overscroll
 import androidx.compose.runtime.Composable
 import androidx.compose.runtime.remember
+import androidx.compose.runtime.rememberCoroutineScope
 import androidx.compose.runtime.snapshots.Snapshot
 import androidx.compose.ui.Modifier
 import androidx.compose.ui.layout.MeasureResult
@@ -47,6 +48,7 @@
 import androidx.compose.ui.unit.dp
 import androidx.compose.ui.unit.offset
 import androidx.compose.ui.util.fastForEach
+import kotlinx.coroutines.CoroutineScope
 
 @OptIn(ExperimentalFoundationApi::class)
 @Composable
@@ -80,6 +82,7 @@
 
     val semanticState = rememberLazyGridSemanticState(state, reverseLayout)
 
+    val coroutineScope = rememberCoroutineScope()
     val measurePolicy = rememberLazyGridMeasurePolicy(
         itemProviderLambda,
         state,
@@ -89,6 +92,7 @@
         isVertical,
         horizontalArrangement,
         verticalArrangement,
+        coroutineScope
     )
 
     state.isVertical = isVertical
@@ -167,10 +171,12 @@
     reverseLayout: Boolean,
     /** The layout orientation of the list */
     isVertical: Boolean,
-    /** The horizontal arrangement for items. Required when isVertical is false */
-    horizontalArrangement: Arrangement.Horizontal? = null,
-    /** The vertical arrangement for items. Required when isVertical is true */
-    verticalArrangement: Arrangement.Vertical? = null,
+    /** The horizontal arrangement for items */
+    horizontalArrangement: Arrangement.Horizontal?,
+    /** The vertical arrangement for items */
+    verticalArrangement: Arrangement.Vertical?,
+    /** Coroutine scope for item animations */
+    coroutineScope: CoroutineScope
 ) = remember<LazyLayoutMeasureScope.(Constraints) -> MeasureResult>(
     state,
     slots,
@@ -281,7 +287,8 @@
                 afterContentPadding = afterContentPadding,
                 visualOffset = visualItemOffset,
                 placeables = placeables,
-                contentType = contentType
+                contentType = contentType,
+                animator = state.placementAnimator
             )
         }
         val measuredLineProvider = object : LazyGridMeasuredLineProvider(
@@ -363,6 +370,7 @@
             placementAnimator = state.placementAnimator,
             spanLayoutProvider = spanLayoutProvider,
             pinnedItems = pinnedItems,
+            coroutineScope = coroutineScope,
             layout = { width, height, placement ->
                 layout(
                     containerConstraints.constrainWidth(width + totalHorizontalPadding),
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridItemPlacementAnimator.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridItemPlacementAnimator.kt
index 58be76c..19d1f00 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridItemPlacementAnimator.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridItemPlacementAnimator.kt
@@ -16,12 +16,14 @@
 
 package androidx.compose.foundation.lazy.grid
 
-import androidx.compose.foundation.lazy.layout.LazyLayoutAnimateItemModifierNode
+import androidx.compose.foundation.lazy.layout.LazyLayoutAnimation
+import androidx.compose.foundation.lazy.layout.LazyLayoutAnimationSpecsNode
 import androidx.compose.foundation.lazy.layout.LazyLayoutKeyIndexMap
 import androidx.compose.ui.unit.Constraints
 import androidx.compose.ui.unit.IntOffset
 import androidx.compose.ui.util.fastAny
 import androidx.compose.ui.util.fastForEach
+import kotlinx.coroutines.CoroutineScope
 
 /**
  * Handles the item placement animations when it is set via [LazyGridItemScope.animateItemPlacement].
@@ -58,7 +60,8 @@
         positionedItems: MutableList<LazyGridMeasuredItem>,
         itemProvider: LazyGridMeasuredItemProvider,
         spanLayoutProvider: LazyGridSpanLayoutProvider,
-        isVertical: Boolean
+        isVertical: Boolean,
+        coroutineScope: CoroutineScope
     ) {
         if (!positionedItems.fastAny { it.hasAnimations } && keyToItemInfoMap.isEmpty()) {
             // no animations specified - no work needed
@@ -90,8 +93,9 @@
                 val itemInfo = keyToItemInfoMap[item.key]
                 // there is no state associated with this item yet
                 if (itemInfo == null) {
-                    keyToItemInfoMap[item.key] =
-                        ItemInfo(item.crossAxisSize, item.crossAxisOffset)
+                    val newItemInfo = ItemInfo(item.crossAxisSize, item.crossAxisOffset)
+                    newItemInfo.updateAnimation(item, coroutineScope)
+                    keyToItemInfoMap[item.key] = newItemInfo
                     val previousIndex = previousKeyToIndexMap.getIndex(item.key)
                     if (previousIndex != -1 && item.index != previousIndex) {
                         if (previousIndex < previousFirstVisibleIndex) {
@@ -101,15 +105,18 @@
                             movingInFromEndBound.add(item)
                         }
                     } else {
-                        initializeNode(
+                        initializeAnimation(
                             item,
-                            item.offset.let { if (item.isVertical) it.y else it.x }
+                            item.offset.let { if (item.isVertical) it.y else it.x },
+                            newItemInfo
                         )
                     }
                 } else {
-                    item.forEachNode {
-                        if (it.rawOffset != LazyLayoutAnimateItemModifierNode.NotInitialized) {
-                            it.rawOffset += scrollOffset
+                    itemInfo.animations.forEach { animation ->
+                        if (animation != null &&
+                            animation.rawOffset != LazyLayoutAnimation.NotInitialized
+                        ) {
+                            animation.rawOffset += scrollOffset
                         }
                     }
                     itemInfo.crossAxisSize = item.crossAxisSize
@@ -136,7 +143,7 @@
                 previousLine = line
             }
             val mainAxisOffset = 0 - accumulatedOffset - item.mainAxisSize
-            initializeNode(item, mainAxisOffset)
+            initializeAnimation(item, mainAxisOffset)
             startAnimationsIfNeeded(item)
         }
         accumulatedOffset = 0
@@ -153,7 +160,7 @@
                 previousLine = line
             }
             val mainAxisOffset = mainAxisLayoutSize + accumulatedOffset
-            initializeNode(item, mainAxisOffset)
+            initializeAnimation(item, mainAxisOffset)
             startAnimationsIfNeeded(item)
         }
 
@@ -175,13 +182,8 @@
                     }
                 )
                 // check if we have any active placement animation on the item
-                var inProgress = false
-                repeat(item.placeablesCount) {
-                    if (item.getParentData(it).node?.isPlacementAnimationInProgress == true) {
-                        inProgress = true
-                        return@repeat
-                    }
-                }
+                val inProgress =
+                    itemInfo.animations.any { it?.isPlacementAnimationInProgress == true }
                 if ((!inProgress && newIndex == previousKeyToIndexMap.getIndex(key))) {
                     keyToItemInfoMap.remove(key)
                 } else {
@@ -264,9 +266,10 @@
         firstVisibleIndex = -1
     }
 
-    private fun initializeNode(
+    private fun initializeAnimation(
         item: LazyGridMeasuredItem,
-        mainAxisOffset: Int
+        mainAxisOffset: Int,
+        itemInfo: ItemInfo = keyToItemInfoMap.getValue(item.key)
     ) {
         val firstPlaceableOffset = item.offset
 
@@ -277,44 +280,80 @@
         }
 
         // initialize offsets
-        item.forEachNode { node ->
-            val diffToFirstPlaceableOffset =
-                item.offset - firstPlaceableOffset
-            node.rawOffset = targetFirstPlaceableOffset + diffToFirstPlaceableOffset
+        itemInfo.animations.forEach { animation ->
+            if (animation != null) {
+                val diffToFirstPlaceableOffset =
+                    item.offset - firstPlaceableOffset
+                animation.rawOffset = targetFirstPlaceableOffset + diffToFirstPlaceableOffset
+            }
         }
     }
 
     private fun startAnimationsIfNeeded(item: LazyGridMeasuredItem) {
-        item.forEachNode { node ->
-            val newTarget = item.offset
-            val currentTarget = node.rawOffset
-            if (currentTarget != LazyLayoutAnimateItemModifierNode.NotInitialized &&
-                currentTarget != newTarget
-            ) {
-                node.animatePlacementDelta(newTarget - currentTarget)
+        val itemInfo = keyToItemInfoMap.getValue(item.key)
+        itemInfo.animations.forEach { animation ->
+            if (animation != null) {
+                val newTarget = item.offset
+                val currentTarget = animation.rawOffset
+                if (currentTarget != LazyLayoutAnimation.NotInitialized &&
+                    currentTarget != newTarget
+                ) {
+                    animation.animatePlacementDelta(newTarget - currentTarget)
+                }
+                animation.rawOffset = newTarget
             }
-            node.rawOffset = newTarget
         }
     }
 
-    private val Any?.node get() = this as? LazyLayoutAnimateItemModifierNode
+    fun getAnimation(key: Any, placeableIndex: Int): LazyLayoutAnimation? =
+        keyToItemInfoMap[key]?.animations?.getOrNull(placeableIndex) // null if array is stale
 
     private val LazyGridMeasuredItem.hasAnimations: Boolean
         get() {
-            forEachNode { return true }
+            repeat(placeablesCount) { index ->
+                getParentData(index).specs?.let {
+                    // found at least one
+                    return true
+                }
+            }
             return false
         }
-
-    private inline fun LazyGridMeasuredItem.forEachNode(
-        block: (LazyLayoutAnimateItemModifierNode) -> Unit
-    ) {
-        repeat(placeablesCount) { index ->
-            getParentData(index).node?.let(block)
-        }
-    }
 }
 
 private class ItemInfo(
     var crossAxisSize: Int,
     var crossAxisOffset: Int
-)
+) {
+    /**
+     * This array has one element per placeable on the item. A non-null element means the
+     * corresponding placeable has animation specs associated with it.
+     */
+    var animations = EmptyArray
+        private set
+
+    fun updateAnimation(positionedItem: LazyGridMeasuredItem, coroutineScope: CoroutineScope) {
+        for (i in positionedItem.placeablesCount until animations.size) {
+            animations[i]?.stopAnimations()
+        }
+        if (animations.size != positionedItem.placeablesCount) {
+            animations = animations.copyOf(positionedItem.placeablesCount)
+        }
+        repeat(positionedItem.placeablesCount) { index ->
+            val specs = positionedItem.getParentData(index).specs
+            if (specs == null) {
+                animations[index]?.stopAnimations()
+                animations[index] = null
+            } else {
+                val item = animations[index] ?: LazyLayoutAnimation(coroutineScope).also {
+                    animations[index] = it
+                }
+                item.appearanceSpec = specs.appearanceSpec
+                item.placementSpec = specs.placementSpec
+            }
+        }
+    }
+}
+
+private val Any?.specs get() = this as? LazyLayoutAnimationSpecsNode
+
+private val EmptyArray = emptyArray<LazyLayoutAnimation?>()
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridItemScopeImpl.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridItemScopeImpl.kt
index ea43720..bfda57e 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridItemScopeImpl.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridItemScopeImpl.kt
@@ -18,53 +18,32 @@
 
 import androidx.compose.animation.core.FiniteAnimationSpec
 import androidx.compose.foundation.ExperimentalFoundationApi
-import androidx.compose.foundation.lazy.layout.LazyLayoutAnimateItemModifierNode
+import androidx.compose.foundation.lazy.layout.LazyLayoutAnimationSpecsNode
 import androidx.compose.ui.Modifier
-import androidx.compose.ui.node.DelegatingNode
 import androidx.compose.ui.node.ModifierNodeElement
-import androidx.compose.ui.node.ParentDataModifierNode
 import androidx.compose.ui.platform.InspectorInfo
-import androidx.compose.ui.unit.Density
 import androidx.compose.ui.unit.IntOffset
 
 @OptIn(ExperimentalFoundationApi::class)
 internal object LazyGridItemScopeImpl : LazyGridItemScope {
     @ExperimentalFoundationApi
     override fun Modifier.animateItemPlacement(animationSpec: FiniteAnimationSpec<IntOffset>) =
-        this then AnimateItemPlacementElement(animationSpec)
+        this then AnimateItemElement(animationSpec)
 }
 
-private class AnimateItemPlacementElement(
-    val animationSpec: FiniteAnimationSpec<IntOffset>
-) : ModifierNodeElement<AnimateItemPlacementNode>() {
+private data class AnimateItemElement(
+    val placementSpec: FiniteAnimationSpec<IntOffset>
+) : ModifierNodeElement<LazyLayoutAnimationSpecsNode>() {
 
-    override fun create(): AnimateItemPlacementNode = AnimateItemPlacementNode(animationSpec)
+    override fun create(): LazyLayoutAnimationSpecsNode =
+        LazyLayoutAnimationSpecsNode(null, placementSpec)
 
-    override fun update(node: AnimateItemPlacementNode) {
-        node.delegatingNode.placementSpec = animationSpec
-    }
-
-    override fun equals(other: Any?): Boolean {
-        if (this === other) return true
-        if (other !is AnimateItemPlacementElement) return false
-        return animationSpec != other.animationSpec
-    }
-
-    override fun hashCode(): Int {
-        return animationSpec.hashCode()
+    override fun update(node: LazyLayoutAnimationSpecsNode) {
+        node.placementSpec = placementSpec
     }
 
     override fun InspectorInfo.inspectableProperties() {
         name = "animateItemPlacement"
-        value = animationSpec
+        value = placementSpec
     }
 }
-
-private class AnimateItemPlacementNode(
-    animationSpec: FiniteAnimationSpec<IntOffset>
-) : DelegatingNode(), ParentDataModifierNode {
-
-    val delegatingNode = delegate(LazyLayoutAnimateItemModifierNode(null, animationSpec))
-
-    override fun Density.modifyParentData(parentData: Any?): Any = delegatingNode
-}
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridMeasure.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridMeasure.kt
index f768d21..07c9edb 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridMeasure.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridMeasure.kt
@@ -33,6 +33,7 @@
 import kotlin.math.min
 import kotlin.math.roundToInt
 import kotlin.math.sign
+import kotlinx.coroutines.CoroutineScope
 
 /**
  * Measures and calculates the positions for the currently visible items. The result is produced
@@ -59,6 +60,7 @@
     placementAnimator: LazyGridItemPlacementAnimator,
     spanLayoutProvider: LazyGridSpanLayoutProvider,
     pinnedItems: List<Int>,
+    coroutineScope: CoroutineScope,
     layout: (Int, Int, Placeable.PlacementScope.() -> Unit) -> MeasureResult
 ): LazyGridMeasureResult {
     require(beforeContentPadding >= 0) { "negative beforeContentPadding" }
@@ -270,7 +272,8 @@
             positionedItems = positionedItems,
             itemProvider = measuredItemProvider,
             spanLayoutProvider = spanLayoutProvider,
-            isVertical = isVertical
+            isVertical = isVertical,
+            coroutineScope = coroutineScope
         )
 
         return LazyGridMeasureResult(
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridMeasuredItem.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridMeasuredItem.kt
index 3b698b1..961341b 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridMeasuredItem.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/grid/LazyGridMeasuredItem.kt
@@ -16,7 +16,6 @@
 
 package androidx.compose.foundation.lazy.grid
 
-import androidx.compose.foundation.lazy.layout.LazyLayoutAnimateItemModifierNode
 import androidx.compose.ui.layout.Placeable
 import androidx.compose.ui.unit.IntOffset
 import androidx.compose.ui.unit.IntSize
@@ -47,7 +46,8 @@
      * value passed into the place() call.
      */
     private val visualOffset: IntOffset,
-    override val contentType: Any?
+    override val contentType: Any?,
+    private val animator: LazyGridItemPlacementAnimator
 ) : LazyGridItemInfo {
     /**
      * Main axis size of the item - the max main axis size of the placeables.
@@ -133,14 +133,14 @@
             val maxOffset = maxMainAxisOffset
 
             var offset = offset
-            val animateNode = getParentData(index) as? LazyLayoutAnimateItemModifierNode
-            if (animateNode != null) {
-                val animatedOffset = offset + animateNode.placementDelta
+            val animation = animator.getAnimation(key, index)
+            if (animation != null) {
+                val animatedOffset = offset + animation.placementDelta
                 // cancel the animation if current and target offsets are both out of the bounds.
                 if ((offset.mainAxis <= minOffset && animatedOffset.mainAxis <= minOffset) ||
                     (offset.mainAxis >= maxOffset && animatedOffset.mainAxis >= maxOffset)
                 ) {
-                    animateNode.cancelPlacementAnimation()
+                    animation.cancelPlacementAnimation()
                 }
                 offset = animatedOffset
             }
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/layout/LazyLayoutAnimateItemModifierNode.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/layout/LazyLayoutAnimation.kt
similarity index 84%
rename from compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/layout/LazyLayoutAnimateItemModifierNode.kt
rename to compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/layout/LazyLayoutAnimation.kt
index aeb8a58..0dd7231 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/layout/LazyLayoutAnimateItemModifierNode.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/layout/LazyLayoutAnimation.kt
@@ -29,14 +29,19 @@
 import androidx.compose.runtime.setValue
 import androidx.compose.ui.Modifier
 import androidx.compose.ui.graphics.GraphicsLayerScope
+import androidx.compose.ui.node.ParentDataModifierNode
+import androidx.compose.ui.unit.Density
 import androidx.compose.ui.unit.IntOffset
 import kotlinx.coroutines.CancellationException
+import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.launch
 
-internal class LazyLayoutAnimateItemModifierNode(
-    var appearanceSpec: FiniteAnimationSpec<Float>?,
-    var placementSpec: FiniteAnimationSpec<IntOffset>?
-) : Modifier.Node(), (GraphicsLayerScope) -> Unit {
+internal class LazyLayoutAnimation(
+    val coroutineScope: CoroutineScope
+) : (GraphicsLayerScope) -> Unit {
+
+    var appearanceSpec: FiniteAnimationSpec<Float>? = null
+    var placementSpec: FiniteAnimationSpec<IntOffset>? = null
 
     /**
      * Returns true when the placement animation is currently in progress so the parent
@@ -153,13 +158,22 @@
         }
     }
 
-    override fun onDetach() {
+    fun stopAnimations() {
+        if (isPlacementAnimationInProgress) {
+            isPlacementAnimationInProgress = false
+            coroutineScope.launch {
+                placementDeltaAnimation.stop()
+            }
+        }
+        if (isAppearanceAnimationInProgress) {
+            isAppearanceAnimationInProgress = false
+            coroutineScope.launch {
+                visibilityAnimation.stop()
+            }
+        }
         placementDelta = IntOffset.Zero
-        isPlacementAnimationInProgress = false
         rawOffset = NotInitialized
         visibility = 1f
-        isAppearanceAnimationInProgress = false
-        // animations will be canceled because coroutineScope will be canceled.
     }
 
     override fun invoke(scope: GraphicsLayerScope) {
@@ -171,6 +185,13 @@
     }
 }
 
+internal class LazyLayoutAnimationSpecsNode(
+    var appearanceSpec: FiniteAnimationSpec<Float>?,
+    var placementSpec: FiniteAnimationSpec<IntOffset>?
+) : Modifier.Node(), ParentDataModifierNode {
+    override fun Density.modifyParentData(parentData: Any?): Any = this@LazyLayoutAnimationSpecsNode
+}
+
 /**
  * We switch to this spec when a duration based animation is being interrupted.
  */
@@ -178,3 +199,8 @@
     stiffness = Spring.StiffnessMediumLow,
     visibilityThreshold = IntOffset.VisibilityThreshold
 )
+
+/**
+ * Block on [GraphicsLayerScope] which applies the default layer parameters.
+ */
+internal val DefaultLayerBlock: GraphicsLayerScope.() -> Unit = {}
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGrid.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGrid.kt
index 8eb1453..f92fcc3 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGrid.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGrid.kt
@@ -28,6 +28,7 @@
 import androidx.compose.foundation.lazy.layout.lazyLayoutSemantics
 import androidx.compose.foundation.overscroll
 import androidx.compose.runtime.Composable
+import androidx.compose.runtime.rememberCoroutineScope
 import androidx.compose.ui.Modifier
 import androidx.compose.ui.platform.LocalLayoutDirection
 import androidx.compose.ui.unit.Constraints
@@ -64,6 +65,7 @@
     val overscrollEffect = ScrollableDefaults.overscrollEffect()
 
     val itemProviderLambda = rememberStaggeredGridItemProviderLambda(state, content)
+    val coroutineScope = rememberCoroutineScope()
     val measurePolicy = rememberStaggeredGridMeasurePolicy(
         state,
         itemProviderLambda,
@@ -72,6 +74,7 @@
         orientation,
         mainAxisSpacing,
         crossAxisSpacing,
+        coroutineScope,
         slots,
     )
     val semanticState = rememberLazyStaggeredGridSemanticState(state, reverseLayout)
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridItemPlacementAnimator.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridItemPlacementAnimator.kt
index b46ca41..36d466d 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridItemPlacementAnimator.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridItemPlacementAnimator.kt
@@ -16,11 +16,13 @@
 
 package androidx.compose.foundation.lazy.staggeredgrid
 
-import androidx.compose.foundation.lazy.layout.LazyLayoutAnimateItemModifierNode
+import androidx.compose.foundation.lazy.layout.LazyLayoutAnimation
+import androidx.compose.foundation.lazy.layout.LazyLayoutAnimationSpecsNode
 import androidx.compose.foundation.lazy.layout.LazyLayoutKeyIndexMap
 import androidx.compose.ui.unit.IntOffset
 import androidx.compose.ui.util.fastAny
 import androidx.compose.ui.util.fastForEach
+import kotlinx.coroutines.CoroutineScope
 
 /**
  * Handles the item placement animations when it is set via
@@ -58,7 +60,8 @@
         positionedItems: MutableList<LazyStaggeredGridMeasuredItem>,
         itemProvider: LazyStaggeredGridMeasureProvider,
         isVertical: Boolean,
-        laneCount: Int
+        laneCount: Int,
+        coroutineScope: CoroutineScope
     ) {
         if (!positionedItems.fastAny { it.hasAnimations } && keyToItemInfoMap.isEmpty()) {
             // no animations specified - no work needed
@@ -90,8 +93,9 @@
                 val itemInfo = keyToItemInfoMap[item.key]
                 // there is no state associated with this item yet
                 if (itemInfo == null) {
-                    keyToItemInfoMap[item.key] =
-                        ItemInfo(item.lane, item.span, item.crossAxisOffset)
+                    val newItemInfo = ItemInfo(item.lane, item.span, item.crossAxisOffset)
+                    newItemInfo.updateAnimation(item, coroutineScope)
+                    keyToItemInfoMap[item.key] = newItemInfo
                     val previousIndex = previousKeyToIndexMap.getIndex(item.key)
                     if (previousIndex != -1 && item.index != previousIndex) {
                         if (previousIndex < previousFirstVisibleIndex) {
@@ -101,15 +105,18 @@
                             movingInFromEndBound.add(item)
                         }
                     } else {
-                        initializeNode(
+                        initializeAnimation(
                             item,
-                            item.offset.let { if (item.isVertical) it.y else it.x }
+                            item.offset.let { if (item.isVertical) it.y else it.x },
+                            newItemInfo
                         )
                     }
                 } else {
-                    item.forEachNode {
-                        if (it.rawOffset != LazyLayoutAnimateItemModifierNode.NotInitialized) {
-                            it.rawOffset += scrollOffset
+                    itemInfo.animations.forEach { animation ->
+                        if (animation != null &&
+                            animation.rawOffset != LazyLayoutAnimation.NotInitialized
+                        ) {
+                            animation.rawOffset += scrollOffset
                         }
                     }
                     itemInfo.lane = item.lane
@@ -129,7 +136,7 @@
             movingInFromStartBound.fastForEach { item ->
                 accumulatedOffsetPerLane[item.lane] += item.mainAxisSize
                 val mainAxisOffset = 0 - accumulatedOffsetPerLane[item.lane]
-                initializeNode(item, mainAxisOffset)
+                initializeAnimation(item, mainAxisOffset)
                 startAnimationsIfNeeded(item)
             }
             accumulatedOffsetPerLane.fill(0)
@@ -139,7 +146,7 @@
             movingInFromEndBound.fastForEach { item ->
                 val mainAxisOffset = mainAxisLayoutSize + accumulatedOffsetPerLane[item.lane]
                 accumulatedOffsetPerLane[item.lane] += item.mainAxisSize
-                initializeNode(item, mainAxisOffset)
+                initializeAnimation(item, mainAxisOffset)
                 startAnimationsIfNeeded(item)
             }
             accumulatedOffsetPerLane.fill(0)
@@ -159,13 +166,8 @@
                     SpanRange(itemInfo.lane, itemInfo.span)
                 )
                 // check if we have any active placement animation on the item
-                var inProgress = false
-                repeat(item.placeablesCount) {
-                    if (item.getParentData(it).node?.isPlacementAnimationInProgress == true) {
-                        inProgress = true
-                        return@repeat
-                    }
-                }
+                val inProgress =
+                    itemInfo.animations.any { it?.isPlacementAnimationInProgress == true }
                 if ((!inProgress && newIndex == previousKeyToIndexMap.getIndex(key))) {
                     keyToItemInfoMap.remove(key)
                 } else {
@@ -221,9 +223,10 @@
         firstVisibleIndex = -1
     }
 
-    private fun initializeNode(
+    private fun initializeAnimation(
         item: LazyStaggeredGridMeasuredItem,
-        mainAxisOffset: Int
+        mainAxisOffset: Int,
+        itemInfo: ItemInfo = keyToItemInfoMap.getValue(item.key)
     ) {
         val firstPlaceableOffset = item.offset
 
@@ -234,45 +237,84 @@
         }
 
         // initialize offsets
-        item.forEachNode { node ->
-            val diffToFirstPlaceableOffset =
-                item.offset - firstPlaceableOffset
-            node.rawOffset = targetFirstPlaceableOffset + diffToFirstPlaceableOffset
+        itemInfo.animations.forEach { animation ->
+            if (animation != null) {
+                val diffToFirstPlaceableOffset =
+                    item.offset - firstPlaceableOffset
+                animation.rawOffset = targetFirstPlaceableOffset + diffToFirstPlaceableOffset
+            }
         }
     }
 
     private fun startAnimationsIfNeeded(item: LazyStaggeredGridMeasuredItem) {
-        item.forEachNode { node ->
-            val newTarget = item.offset
-            val currentTarget = node.rawOffset
-            if (currentTarget != LazyLayoutAnimateItemModifierNode.NotInitialized &&
-                currentTarget != newTarget
-            ) {
-                node.animatePlacementDelta(newTarget - currentTarget)
+        val itemInfo = keyToItemInfoMap.getValue(item.key)
+        itemInfo.animations.forEach { animation ->
+            if (animation != null) {
+                val newTarget = item.offset
+                val currentTarget = animation.rawOffset
+                if (currentTarget != LazyLayoutAnimation.NotInitialized &&
+                    currentTarget != newTarget
+                ) {
+                    animation.animatePlacementDelta(newTarget - currentTarget)
+                }
+                animation.rawOffset = newTarget
             }
-            node.rawOffset = newTarget
         }
     }
 
-    private val Any?.node get() = this as? LazyLayoutAnimateItemModifierNode
+    fun getAnimation(key: Any, placeableIndex: Int): LazyLayoutAnimation? =
+        keyToItemInfoMap[key]?.animations?.getOrNull(placeableIndex) // null if array is stale
 
     private val LazyStaggeredGridMeasuredItem.hasAnimations: Boolean
         get() {
-            forEachNode { return true }
+            repeat(placeablesCount) { index ->
+                getParentData(index).specs?.let {
+                    // found at least one
+                    return true
+                }
+            }
             return false
         }
-
-    private inline fun LazyStaggeredGridMeasuredItem.forEachNode(
-        block: (LazyLayoutAnimateItemModifierNode) -> Unit
-    ) {
-        repeat(placeablesCount) { index ->
-            getParentData(index).node?.let(block)
-        }
-    }
 }
 
 private class ItemInfo(
     var lane: Int,
     var span: Int,
     var crossAxisOffset: Int
-)
+) {
+    /**
+     * This array has one element per placeable on the item. A non-null element means the
+     * corresponding placeable has animation specs associated with it.
+     */
+    var animations = EmptyArray
+        private set
+
+    fun updateAnimation(
+        positionedItem: LazyStaggeredGridMeasuredItem,
+        coroutineScope: CoroutineScope
+    ) {
+        for (i in positionedItem.placeablesCount until animations.size) {
+            animations[i]?.stopAnimations()
+        }
+        if (animations.size != positionedItem.placeablesCount) {
+            animations = animations.copyOf(positionedItem.placeablesCount)
+        }
+        repeat(positionedItem.placeablesCount) { index ->
+            val specs = positionedItem.getParentData(index).specs
+            if (specs == null) {
+                animations[index]?.stopAnimations()
+                animations[index] = null
+            } else {
+                val item = animations[index] ?: LazyLayoutAnimation(coroutineScope).also {
+                    animations[index] = it
+                }
+                item.appearanceSpec = specs.appearanceSpec
+                item.placementSpec = specs.placementSpec
+            }
+        }
+    }
+}
+
+private val Any?.specs get() = this as? LazyLayoutAnimationSpecsNode
+
+private val EmptyArray = emptyArray<LazyLayoutAnimation?>()
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridItemScope.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridItemScope.kt
index 1ff041d..6f35395 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridItemScope.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridItemScope.kt
@@ -21,14 +21,11 @@
 import androidx.compose.animation.core.VisibilityThreshold
 import androidx.compose.animation.core.spring
 import androidx.compose.foundation.ExperimentalFoundationApi
-import androidx.compose.foundation.lazy.layout.LazyLayoutAnimateItemModifierNode
+import androidx.compose.foundation.lazy.layout.LazyLayoutAnimationSpecsNode
 import androidx.compose.runtime.Stable
 import androidx.compose.ui.Modifier
-import androidx.compose.ui.node.DelegatingNode
 import androidx.compose.ui.node.ModifierNodeElement
-import androidx.compose.ui.node.ParentDataModifierNode
 import androidx.compose.ui.platform.InspectorInfo
-import androidx.compose.ui.unit.Density
 import androidx.compose.ui.unit.IntOffset
 
 /**
@@ -62,40 +59,22 @@
 internal object LazyStaggeredGridItemScopeImpl : LazyStaggeredGridItemScope {
     @ExperimentalFoundationApi
     override fun Modifier.animateItemPlacement(animationSpec: FiniteAnimationSpec<IntOffset>) =
-        this then AnimateItemPlacementElement(animationSpec)
+        this then AnimateItemElement(animationSpec)
 }
 
-private class AnimateItemPlacementElement(
-    val animationSpec: FiniteAnimationSpec<IntOffset>
-) : ModifierNodeElement<AnimateItemPlacementNode>() {
+private data class AnimateItemElement(
+    val placementSpec: FiniteAnimationSpec<IntOffset>
+) : ModifierNodeElement<LazyLayoutAnimationSpecsNode>() {
 
-    override fun create(): AnimateItemPlacementNode = AnimateItemPlacementNode(animationSpec)
+    override fun create(): LazyLayoutAnimationSpecsNode =
+        LazyLayoutAnimationSpecsNode(null, placementSpec)
 
-    override fun update(node: AnimateItemPlacementNode) {
-        node.delegatingNode.placementSpec = animationSpec
-    }
-
-    override fun equals(other: Any?): Boolean {
-        if (this === other) return true
-        if (other !is AnimateItemPlacementElement) return false
-        return animationSpec != other.animationSpec
-    }
-
-    override fun hashCode(): Int {
-        return animationSpec.hashCode()
+    override fun update(node: LazyLayoutAnimationSpecsNode) {
+        node.placementSpec = placementSpec
     }
 
     override fun InspectorInfo.inspectableProperties() {
         name = "animateItemPlacement"
-        value = animationSpec
+        value = placementSpec
     }
 }
-
-private class AnimateItemPlacementNode(
-    animationSpec: FiniteAnimationSpec<IntOffset>
-) : DelegatingNode(), ParentDataModifierNode {
-
-    val delegatingNode = delegate(LazyLayoutAnimateItemModifierNode(null, animationSpec))
-
-    override fun Density.modifyParentData(parentData: Any?): Any = delegatingNode
-}
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridMeasure.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridMeasure.kt
index d1b13f3..5765b23 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridMeasure.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridMeasure.kt
@@ -18,7 +18,6 @@
 
 import androidx.compose.foundation.ExperimentalFoundationApi
 import androidx.compose.foundation.fastMaxOfOrNull
-import androidx.compose.foundation.lazy.layout.LazyLayoutAnimateItemModifierNode
 import androidx.compose.foundation.lazy.layout.LazyLayoutKeyIndexMap
 import androidx.compose.foundation.lazy.layout.LazyLayoutMeasureScope
 import androidx.compose.foundation.lazy.staggeredgrid.LazyStaggeredGridLaneInfo.Companion.FullSpan
@@ -39,6 +38,7 @@
 import kotlin.math.min
 import kotlin.math.roundToInt
 import kotlin.math.sign
+import kotlinx.coroutines.CoroutineScope
 
 private const val DebugLoggingEnabled = false
 
@@ -88,6 +88,7 @@
     mainAxisSpacing: Int,
     beforeContentPadding: Int,
     afterContentPadding: Int,
+    coroutineScope: CoroutineScope,
 ): LazyStaggeredGridMeasureResult {
     val context = LazyStaggeredGridMeasureContext(
         state = state,
@@ -103,6 +104,7 @@
         reverseLayout = reverseLayout,
         mainAxisSpacing = mainAxisSpacing,
         measureScope = this,
+        coroutineScope = coroutineScope
     )
 
     val initialItemIndices: IntArray
@@ -185,6 +187,7 @@
     val afterContentPadding: Int,
     val reverseLayout: Boolean,
     val mainAxisSpacing: Int,
+    val coroutineScope: CoroutineScope
 ) {
     val measuredItemProvider = object : LazyStaggeredGridMeasureProvider(
         isVertical = isVertical,
@@ -209,7 +212,8 @@
             span = span,
             beforeContentPadding = beforeContentPadding,
             afterContentPadding = afterContentPadding,
-            contentType = contentType
+            contentType = contentType,
+            animator = state.placementAnimator
         )
     }
 
@@ -802,7 +806,8 @@
             positionedItems = positionedItems,
             itemProvider = measuredItemProvider,
             isVertical = isVertical,
-            laneCount = laneCount
+            laneCount = laneCount,
+            coroutineScope = coroutineScope
         )
 
         // end placement
@@ -1062,7 +1067,8 @@
     val span: Int,
     private val beforeContentPadding: Int,
     private val afterContentPadding: Int,
-    override val contentType: Any?
+    override val contentType: Any?,
+    private val animator: LazyStaggeredGridItemPlacementAnimator
 ) : LazyStaggeredGridItemInfo {
     var isVisible = true
 
@@ -1120,14 +1126,14 @@
                 val maxOffset = maxMainAxisOffset
 
                 var offset = offset
-                val animateNode = getParentData(index) as? LazyLayoutAnimateItemModifierNode
-                if (animateNode != null) {
-                    val animatedOffset = offset + animateNode.placementDelta
+                val animation = animator.getAnimation(key, index)
+                if (animation != null) {
+                    val animatedOffset = offset + animation.placementDelta
                     // cancel the animation if current and target offsets are both out of the bounds.
                     if ((offset.mainAxis <= minOffset && animatedOffset.mainAxis <= minOffset) ||
                         (offset.mainAxis >= maxOffset && animatedOffset.mainAxis >= maxOffset)
                     ) {
-                        animateNode.cancelPlacementAnimation()
+                        animation.cancelPlacementAnimation()
                     }
                     offset = animatedOffset
                 }
diff --git a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridMeasurePolicy.kt b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridMeasurePolicy.kt
index 2ba719b93..c0ece02 100644
--- a/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridMeasurePolicy.kt
+++ b/compose/foundation/foundation/src/commonMain/kotlin/androidx/compose/foundation/lazy/staggeredgrid/LazyStaggeredGridMeasurePolicy.kt
@@ -33,6 +33,7 @@
 import androidx.compose.ui.unit.LayoutDirection
 import androidx.compose.ui.unit.constrainHeight
 import androidx.compose.ui.unit.constrainWidth
+import kotlinx.coroutines.CoroutineScope
 
 @OptIn(ExperimentalFoundationApi::class)
 @Composable
@@ -44,6 +45,7 @@
     orientation: Orientation,
     mainAxisSpacing: Dp,
     crossAxisSpacing: Dp,
+    coroutineScope: CoroutineScope,
     slots: Density.(Constraints) -> LazyStaggeredGridSlots
 ): LazyLayoutMeasureScope.(Constraints) -> LazyStaggeredGridMeasureResult = remember(
     state,
@@ -116,6 +118,7 @@
             reverseLayout = reverseLayout,
             beforeContentPadding = beforeContentPadding,
             afterContentPadding = afterContentPadding,
+            coroutineScope = coroutineScope
         ).also {
             state.applyMeasureResult(it)
         }