From 2758a19a232454d3317db2181b0672753806e3f1 Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 16 Oct 2023 09:53:14 -0700 Subject: [PATCH] [XlaCallModule] Allow i64 platform index arguments. Previously, for multi-platform serialization, the platform index argument was required to be an i32. Now we also allow i64, just like we do for dimension variables. This flexibility is useful for JAX when running in 64-bit mode. PiperOrigin-RevId: 573843239 --- tensorflow/compiler/tests/BUILD | 3 + .../compiler/tests/xla_call_module_test.py | 71 +++++++++++-------- .../tf2xla/kernels/xla_call_module_loader.cc | 41 ++++++----- 3 files changed, 67 insertions(+), 48 deletions(-) diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 368d6dd0af1a9d..612dee33d8530d 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -2788,6 +2788,7 @@ tf_xla_py_strict_test( "//tensorflow/python/platform:client_testlib", "//tensorflow/python/platform:test", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -2820,6 +2821,7 @@ tf_xla_py_strict_test( "//tensorflow/python/platform:client_testlib", "//tensorflow/python/platform:test", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -2851,6 +2853,7 @@ tf_xla_py_strict_test( "//tensorflow/python/platform:client_testlib", "//tensorflow/python/platform:test", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/compiler/tests/xla_call_module_test.py b/tensorflow/compiler/tests/xla_call_module_test.py index a24930a7b8c846..c3017de28a70b0 100644 --- a/tensorflow/compiler/tests/xla_call_module_test.py +++ b/tensorflow/compiler/tests/xla_call_module_test.py @@ -19,6 +19,7 @@ from typing import Optional, Sequence import unittest +from absl.testing import parameterized import numpy as np from tensorflow.compiler.mlir.stablehlo import stablehlo @@ -41,7 +42,7 @@ def serialize(module_str: str) -> 
tuple[str, int]: return byte_str, xla.call_module_maximum_supported_version() -class XlaCallModuleOpTest(xla_test.XLATestCase): +class XlaCallModuleOpTest(xla_test.XLATestCase, parameterized.TestCase): def _assertOpOutputMatchesExpected(self, op, @@ -212,23 +213,29 @@ def f(x): # x: f32[2, b] self._assertOpOutputMatchesExpected(f, (x,), (np.sin(x), x.shape[1])) - def test_poly_basic(self): + @parameterized.named_parameters( + dict(testcase_name='_' + dim_var_type, + dim_var_type=dim_var_type) + for dim_var_type in ('i32', 'i64') + ) + def test_poly_basic(self, *, dim_var_type: str): x = np.arange(6, dtype=np.float32).reshape((2, 3)) def f(x): # x: f32[2, b] # (sin(x), x.shape[1]) - module, version = serialize(""" -module @jit_f.0 attributes {jax.uses_shape_polymorphism = true} { - func.func public @main(%arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor) { - %arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 1 : i64} : (tensor<2x?xf32>) -> tensor - %0, %1 = call @dyn_main(%arg0_new, %arg1) : (tensor, tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor) - return %0, %1 : tensor<2x?xf32>, tensor - } - func.func private @dyn_main(%arg0: tensor, %arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor) { + module, version = serialize(f""" +module @jit_f.0 attributes {{jax.uses_shape_polymorphism = true}} {{ + func.func public @main(%arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor<{dim_var_type}>) {{ + %arg0_new_i32 = "stablehlo.get_dimension_size"(%arg1) {{dimension = 1 : i64}} : (tensor<2x?xf32>) -> tensor + %arg0_new = stablehlo.convert %arg0_new_i32 : (tensor) -> tensor<{dim_var_type}> + %0, %1 = call @dyn_main(%arg0_new, %arg1) : (tensor<{dim_var_type}>, tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor<{dim_var_type}>) + return %0, %1 : tensor<2x?xf32>, tensor<{dim_var_type}> + }} + func.func private @dyn_main(%arg0: tensor<{dim_var_type}>, %arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor<{dim_var_type}>) {{ %0 = stablehlo.sine %arg1 : tensor<2x?xf32> - return 
%0, %arg0 : tensor<2x?xf32>, tensor - } -} + return %0, %arg0 : tensor<2x?xf32>, tensor<{dim_var_type}> + }} +}} """) return xla.call_module([x], module=module, version=version, @@ -308,27 +315,33 @@ def f(x, y): ): self._assertOpOutputMatchesExpected(f, (x_bad_shape, y), (x_bad_shape,)) - def test_platforms_basic(self): + @parameterized.named_parameters( + dict(testcase_name='_' + platform_idx_type, + platform_idx_type=platform_idx_type) + for platform_idx_type in ('i32', 'i64') + ) + def test_platforms_basic(self, *, platform_idx_type: str): x = np.float32(0.) # returns x + 2. on CPU, x + 3. on GPU (CUDA or ROCM) and x + 4. on TPU - module, version = serialize(""" -module @jit_f.0 { - func.func public @main(%arg_platform_idx: tensor, %arg0: tensor) -> tensor { - %to_add = "stablehlo.case"(%arg_platform_idx) ({ + module, version = serialize(f""" +module @jit_f.0 {{ + func.func public @main(%arg_platform_idx: tensor<{platform_idx_type}>, %arg0: tensor) -> tensor {{ + %0 = stablehlo.convert %arg_platform_idx : (tensor<{platform_idx_type}>) -> tensor + %to_add = "stablehlo.case"(%0) ({{ %cpu_val = stablehlo.constant dense<2.> : tensor stablehlo.return %cpu_val : tensor - }, { + }}, {{ %gpu_val = stablehlo.constant dense<3.> : tensor stablehlo.return %gpu_val : tensor - }, { + }}, {{ %tpu_val = stablehlo.constant dense<4.> : tensor stablehlo.return %tpu_val : tensor - }) : (tensor) -> tensor - %0 = stablehlo.add %arg0, %to_add : tensor - return %0 : tensor - } -} + }}) : (tensor) -> tensor + %1 = stablehlo.add %arg0, %to_add : tensor + return %1 : tensor + }} +}} """) platforms = ['CPU', 'CUDA', 'ROCM', 'TPU'] @@ -494,15 +507,15 @@ def platforms_errors_no_platform_index_arg(self): ), ) - def platforms_errors_platform_index_i64(self): - module_str = self.platforms_errors_module_str.replace('i32', 'i64') + def platforms_errors_platform_index_i16(self): + module_str = self.platforms_errors_module_str.replace('i32', 'i16') self.platforms_errors_helper( 
module_str=module_str, expected_error=errors.InvalidArgumentError, expected_error_message=( 'Module argument at index 0 should be a 0-dimensional ' - '32-bit integer-tensor platform index argument .* has type ' - 'tensor' + '32-bit or 64-bit integer-tensor platform index argument ' + '.* has type tensor' ), ) diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index aa739954f0f1c0..9246ec4af8eae4 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -165,16 +165,19 @@ tsl::Status SetPlatformIndex(mlir::func::FuncOp main, int platform_index) { mlir::RankedTensorType arg_ranked_type = platform_index_arg.getType().dyn_cast(); if (!arg_ranked_type || arg_ranked_type.getRank() != 0 || - !arg_ranked_type.getElementType().isSignlessInteger(32)) { + !(arg_ranked_type.getElementType().isSignlessInteger(32) || + arg_ranked_type.getElementType().isSignlessInteger(64))) { return absl::InvalidArgumentError( absl::StrCat("Module argument at index 0 should be a 0-dimensional " - "32-bit integer-tensor platform index argument but " - "has type ", + "32-bit or 64-bit integer-tensor platform index argument " + "but has type ", mlir::debugString(platform_index_arg.getType()))); } + bool is_32_bit = arg_ranked_type.getElementType().isSignlessInteger(32); + auto const_attr = is_32_bit ? op_builder.getI32IntegerAttr(platform_index) + : op_builder.getI64IntegerAttr(platform_index); auto platform_index_op = op_builder.create( - platform_index_arg.getLoc(), - op_builder.getI32IntegerAttr(platform_index)); + platform_index_arg.getLoc(), const_attr); platform_index_arg.replaceAllUsesWith(platform_index_op); main.eraseArgument(0); @@ -218,8 +221,9 @@ tsl::StatusOr> XlaCallModuleLoader::Create( // for %arg_dim0 and one for %arg_dim1. 
E.g., ['0.0', '0.1'] specifies that // %arg_dim0 should be set to the size of axis 0 or array argument 0 (%arg0), // while %arg_dim1 should be set to the size of axis 1. -// The platform index argument must be a 0-dimensional 32-bit integer, and the -// dimension arguments must be 0-dimensional tensors of integer type. +// The platform index argument must be a 0-dimensional 32-bit or 64-bit integer, +// and the dimension arguments must be 0-dimensional 32-bit or 64-bit integer +// tensors. // // We create a new "main" function as follows: // func public main(%arg0: f32[?, ?, 8]) { @@ -290,25 +294,24 @@ tsl::Status XlaCallModuleLoader::AddMainWrapper() { arg_type.dyn_cast(); if (!arg_ranked_type || !arg_ranked_type.getElementType().dyn_cast() || - !arg_ranked_type.getShape().empty()) { + !arg_ranked_type.getShape().empty() || + !(arg_ranked_type.getElementTypeBitWidth() == 32 || + arg_ranked_type.getElementTypeBitWidth() == 64)) { std::string argument_type = (i < nr_platform_args) ? "platform index" : "dimension"; return absl::InvalidArgumentError(absl::StrCat( "Module argument at index ", i, - " should be a 0-dimensional integer-tensor ", argument_type, - " argument but has type ", mlir::debugString(arg_type))); + " should be a 0-dimensional 32-bit or 64-bit integer-tensor ", + argument_type, " argument but has type ", + mlir::debugString(arg_type))); } if (i < nr_platform_args) { - if (arg_ranked_type.getElementTypeBitWidth() != 32) { - return absl::InvalidArgumentError( - absl::StrCat("Module argument at index ", i, - " should be a 0-dimensional 32-bit integer-tensor" - " platform index argument but has type ", - mlir::debugString(arg_type))); - } + bool is_32_bit = arg_ranked_type.getElementType().isSignlessInteger(32); + auto const_attr = is_32_bit + ? 
op_builder.getI32IntegerAttr(platform_index_) + : op_builder.getI64IntegerAttr(platform_index_); call_args[i] = op_builder.create( - block_args[0].getLoc(), - op_builder.getI32IntegerAttr(platform_index_)); + block_args[0].getLoc(), const_attr); } else { TF_ASSIGN_OR_RETURN( call_args[i],