Make sure fromSavedStateHandle restores custom arrays
We need to properly cast custom Parcelable arrays in
fromSavedStateHandle so that restoring them does not crash.
RelNote: "SafeArgs no longer crashes when attempting to restore custom
parcelable arrays after process death."
Test: ./gradlew navigation:navigation-safe-args-generator:test --rerun-tasks
Bug: 207315994
Change-Id: I618e8b5027ef6c8b95c1696eddf3bdf7dd15ac4d
(cherry picked from commit bf2c5c847f289797ea98f826d272593340cd4e63)
Merged-In:I618e8b5027ef6c8b95c1696eddf3bdf7dd15ac4d
diff --git a/navigation/navigation-safe-args-generator/src/main/kotlin/androidx/navigation/safe/args/generator/kotlin/KotlinNavWriter.kt b/navigation/navigation-safe-args-generator/src/main/kotlin/androidx/navigation/safe/args/generator/kotlin/KotlinNavWriter.kt
index 57cac57..2ab4454 100644
--- a/navigation/navigation-safe-args-generator/src/main/kotlin/androidx/navigation/safe/args/generator/kotlin/KotlinNavWriter.kt
+++ b/navigation/navigation-safe-args-generator/src/main/kotlin/androidx/navigation/safe/args/generator/kotlin/KotlinNavWriter.kt
@@ -303,7 +303,7 @@
arg.type.typeName().copy(nullable = true)
)
beginControlFlow("if (%L.contains(%S))", savedStateParamName, arg.name)
- addStatement("%L = %L[%S]", tempVal, savedStateParamName, arg.name)
+ arg.type.addSavedStateGetStatement(this, arg, tempVal, savedStateParamName)
if (!arg.isNullable) {
beginControlFlow("if (%L == null)", tempVal)
val errorMessage = if (arg.type.allowsNullable()) {
diff --git a/navigation/navigation-safe-args-generator/src/main/kotlin/androidx/navigation/safe/args/generator/kotlin/KotlinTypes.kt b/navigation/navigation-safe-args-generator/src/main/kotlin/androidx/navigation/safe/args/generator/kotlin/KotlinTypes.kt
index 67551f9..8f35ae4 100644
--- a/navigation/navigation-safe-args-generator/src/main/kotlin/androidx/navigation/safe/args/generator/kotlin/KotlinTypes.kt
+++ b/navigation/navigation-safe-args-generator/src/main/kotlin/androidx/navigation/safe/args/generator/kotlin/KotlinTypes.kt
@@ -155,6 +155,47 @@
)
}
+internal fun NavType.addSavedStateGetStatement(
+ builder: FunSpec.Builder,
+ arg: Argument,
+ lValue: String,
+ savedStateHandle: String
+): FunSpec.Builder = when (this) {
+ is ObjectType -> builder.apply {
+ beginControlFlow(
+ "if (%T::class.java.isAssignableFrom(%T::class.java) " +
+ "|| %T::class.java.isAssignableFrom(%T::class.java))",
+ PARCELABLE_CLASSNAME, arg.type.typeName(),
+ SERIALIZABLE_CLASSNAME, arg.type.typeName()
+ )
+ addStatement(
+ "%L = %L.get<%T>(%S)",
+ lValue, savedStateHandle, arg.type.typeName().copy(nullable = true), arg.name
+ )
+ nextControlFlow("else")
+ addStatement(
+ "throw·%T(%T::class.java.name + %S)",
+ UnsupportedOperationException::class.asTypeName(),
+ arg.type.typeName(),
+ " must implement Parcelable or Serializable or must be an Enum."
+ )
+ endControlFlow()
+ }
+ is ObjectArrayType -> builder.apply {
+ val baseType = (arg.type.typeName() as ParameterizedTypeName).typeArguments.first()
+ addStatement(
+ "%L = %L.get<Array<%T>>(%S)?.map { it as %T }?.toTypedArray()",
+ lValue, savedStateHandle, PARCELABLE_CLASSNAME, arg.name, baseType
+ )
+ }
+ else -> builder.addStatement(
+ "%L = %L[%S]",
+ lValue,
+ savedStateHandle,
+ arg.name
+ )
+}
+
internal fun NavType.addSavedStateSetStatement(
builder: FunSpec.Builder,
arg: Argument,
diff --git a/navigation/navigation-safe-args-generator/src/test/test-data/expected/kotlin_nav_writer_test/MainFragmentArgs.kt b/navigation/navigation-safe-args-generator/src/test/test-data/expected/kotlin_nav_writer_test/MainFragmentArgs.kt
index 90bfdeb..b3ac929 100644
--- a/navigation/navigation-safe-args-generator/src/test/test-data/expected/kotlin_nav_writer_test/MainFragmentArgs.kt
+++ b/navigation/navigation-safe-args-generator/src/test/test-data/expected/kotlin_nav_writer_test/MainFragmentArgs.kt
@@ -230,7 +230,8 @@
}
val __objectArrayArg : Array<ActivityInfo>?
if (savedStateHandle.contains("objectArrayArg")) {
- __objectArrayArg = savedStateHandle["objectArrayArg"]
+ __objectArrayArg = savedStateHandle.get<Array<Parcelable>>("objectArrayArg")?.map { it as
+ ActivityInfo }?.toTypedArray()
if (__objectArrayArg == null) {
throw IllegalArgumentException("Argument \"objectArrayArg\" is marked as non-null but was passed a null value")
}
@@ -248,13 +249,25 @@
}
val __optionalParcelable : ActivityInfo?
if (savedStateHandle.contains("optionalParcelable")) {
- __optionalParcelable = savedStateHandle["optionalParcelable"]
+ if (Parcelable::class.java.isAssignableFrom(ActivityInfo::class.java) ||
+ Serializable::class.java.isAssignableFrom(ActivityInfo::class.java)) {
+ __optionalParcelable = savedStateHandle.get<ActivityInfo?>("optionalParcelable")
+ } else {
+ throw UnsupportedOperationException(ActivityInfo::class.java.name +
+ " must implement Parcelable or Serializable or must be an Enum.")
+ }
} else {
__optionalParcelable = null
}
val __enumArg : AccessMode?
if (savedStateHandle.contains("enumArg")) {
- __enumArg = savedStateHandle["enumArg"]
+ if (Parcelable::class.java.isAssignableFrom(AccessMode::class.java) ||
+ Serializable::class.java.isAssignableFrom(AccessMode::class.java)) {
+ __enumArg = savedStateHandle.get<AccessMode?>("enumArg")
+ } else {
+ throw UnsupportedOperationException(AccessMode::class.java.name +
+ " must implement Parcelable or Serializable or must be an Enum.")
+ }
if (__enumArg == null) {
throw IllegalArgumentException("Argument \"enumArg\" is marked as non-null but was passed a null value")
}