[go: nahoru, domu]

Make sure fromSavedStateHandle restores custom arrays

We need to properly cast custom parcelable arrays in
fromSavedStateHandle so that restoring them does not crash.

RelNote: "SafeArgs no longer crashes when attempting to restore custom
parcelable arrays after process death."
Test: ./gradlew navigation:navigation-safe-args-generator:test --rerun-tasks
Bug: 207315994

Change-Id: I618e8b5027ef6c8b95c1696eddf3bdf7dd15ac4d
(cherry picked from commit bf2c5c847f289797ea98f826d272593340cd4e63)
Merged-In:I618e8b5027ef6c8b95c1696eddf3bdf7dd15ac4d
diff --git a/navigation/navigation-safe-args-generator/src/main/kotlin/androidx/navigation/safe/args/generator/kotlin/KotlinNavWriter.kt b/navigation/navigation-safe-args-generator/src/main/kotlin/androidx/navigation/safe/args/generator/kotlin/KotlinNavWriter.kt
index 57cac57..2ab4454 100644
--- a/navigation/navigation-safe-args-generator/src/main/kotlin/androidx/navigation/safe/args/generator/kotlin/KotlinNavWriter.kt
+++ b/navigation/navigation-safe-args-generator/src/main/kotlin/androidx/navigation/safe/args/generator/kotlin/KotlinNavWriter.kt
@@ -303,7 +303,7 @@
                     arg.type.typeName().copy(nullable = true)
                 )
                 beginControlFlow("if (%L.contains(%S))", savedStateParamName, arg.name)
-                addStatement("%L = %L[%S]", tempVal, savedStateParamName, arg.name)
+                arg.type.addSavedStateGetStatement(this, arg, tempVal, savedStateParamName)
                 if (!arg.isNullable) {
                     beginControlFlow("if (%L == null)", tempVal)
                     val errorMessage = if (arg.type.allowsNullable()) {
diff --git a/navigation/navigation-safe-args-generator/src/main/kotlin/androidx/navigation/safe/args/generator/kotlin/KotlinTypes.kt b/navigation/navigation-safe-args-generator/src/main/kotlin/androidx/navigation/safe/args/generator/kotlin/KotlinTypes.kt
index 67551f9..8f35ae4 100644
--- a/navigation/navigation-safe-args-generator/src/main/kotlin/androidx/navigation/safe/args/generator/kotlin/KotlinTypes.kt
+++ b/navigation/navigation-safe-args-generator/src/main/kotlin/androidx/navigation/safe/args/generator/kotlin/KotlinTypes.kt
@@ -155,6 +155,47 @@
     )
 }
 
+/**
+ * Emits code assigning [arg], read from [savedStateHandle], to [lValue]. Object arrays are
+ * read as `Array<Parcelable>` and cast per element so restored arrays do not crash.
+ */
+internal fun NavType.addSavedStateGetStatement(
+    builder: FunSpec.Builder,
+    arg: Argument,
+    lValue: String,
+    savedStateHandle: String
+): FunSpec.Builder = with(builder) {
+    when (this@addSavedStateGetStatement) {
+        is ObjectType -> {
+            val objectType = arg.type.typeName()
+            beginControlFlow(
+                "if (%T::class.java.isAssignableFrom(%T::class.java) " +
+                    "|| %T::class.java.isAssignableFrom(%T::class.java))",
+                PARCELABLE_CLASSNAME, objectType, SERIALIZABLE_CLASSNAME, objectType
+            )
+            addStatement(
+                "%L = %L.get<%T>(%S)",
+                lValue, savedStateHandle, objectType.copy(nullable = true), arg.name
+            )
+            nextControlFlow("else")
+            addStatement(
+                "throw·%T(%T::class.java.name + %S)",
+                UnsupportedOperationException::class.asTypeName(), objectType,
+                " must implement Parcelable or Serializable or must be an Enum."
+            )
+            endControlFlow()
+        }
+        is ObjectArrayType -> {
+            val elementType = (arg.type.typeName() as ParameterizedTypeName).typeArguments.first()
+            addStatement(
+                "%L = %L.get<Array<%T>>(%S)?.map { it as %T }?.toTypedArray()",
+                lValue, savedStateHandle, PARCELABLE_CLASSNAME, arg.name, elementType
+            )
+        }
+        else -> addStatement("%L = %L[%S]", lValue, savedStateHandle, arg.name)
+    }
+}
+
 internal fun NavType.addSavedStateSetStatement(
     builder: FunSpec.Builder,
     arg: Argument,
diff --git a/navigation/navigation-safe-args-generator/src/test/test-data/expected/kotlin_nav_writer_test/MainFragmentArgs.kt b/navigation/navigation-safe-args-generator/src/test/test-data/expected/kotlin_nav_writer_test/MainFragmentArgs.kt
index 90bfdeb..b3ac929 100644
--- a/navigation/navigation-safe-args-generator/src/test/test-data/expected/kotlin_nav_writer_test/MainFragmentArgs.kt
+++ b/navigation/navigation-safe-args-generator/src/test/test-data/expected/kotlin_nav_writer_test/MainFragmentArgs.kt
@@ -230,7 +230,8 @@
       }
       val __objectArrayArg : Array<ActivityInfo>?
       if (savedStateHandle.contains("objectArrayArg")) {
-        __objectArrayArg = savedStateHandle["objectArrayArg"]
+        __objectArrayArg = savedStateHandle.get<Array<Parcelable>>("objectArrayArg")?.map { it as
+            ActivityInfo }?.toTypedArray()
         if (__objectArrayArg == null) {
           throw IllegalArgumentException("Argument \"objectArrayArg\" is marked as non-null but was passed a null value")
         }
@@ -248,13 +249,25 @@
       }
       val __optionalParcelable : ActivityInfo?
       if (savedStateHandle.contains("optionalParcelable")) {
-        __optionalParcelable = savedStateHandle["optionalParcelable"]
+        if (Parcelable::class.java.isAssignableFrom(ActivityInfo::class.java) ||
+            Serializable::class.java.isAssignableFrom(ActivityInfo::class.java)) {
+          __optionalParcelable = savedStateHandle.get<ActivityInfo?>("optionalParcelable")
+        } else {
+          throw UnsupportedOperationException(ActivityInfo::class.java.name +
+              " must implement Parcelable or Serializable or must be an Enum.")
+        }
       } else {
         __optionalParcelable = null
       }
       val __enumArg : AccessMode?
       if (savedStateHandle.contains("enumArg")) {
-        __enumArg = savedStateHandle["enumArg"]
+        if (Parcelable::class.java.isAssignableFrom(AccessMode::class.java) ||
+            Serializable::class.java.isAssignableFrom(AccessMode::class.java)) {
+          __enumArg = savedStateHandle.get<AccessMode?>("enumArg")
+        } else {
+          throw UnsupportedOperationException(AccessMode::class.java.name +
+              " must implement Parcelable or Serializable or must be an Enum.")
+        }
         if (__enumArg == null) {
           throw IllegalArgumentException("Argument \"enumArg\" is marked as non-null but was passed a null value")
         }