[XlaCallModule] Allow i64 platform index arguments.
Previously, for multi-platform serialization the platform index
argument was required to be an i32. Now we also allow i64,
just as we do for dimension variables. This flexibility
is useful for JAX when running in 64-bit mode.

PiperOrigin-RevId: 573843239
gnecula authored and tensorflower-gardener committed Oct 16, 2023
1 parent bab9c15 commit 2758a19
Showing 3 changed files with 67 additions and 48 deletions.
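
For illustration, here is a minimal sketch (not part of this commit) of the kind of serialized module the relaxed check now accepts: the platform index argument is declared as tensor<i64> and converted to i32 only where stablehlo.case requires it. The serialize() helper and the xla.call_module keyword arguments (Tout, Sout, platforms) mirror the ones used in xla_call_module_test.py below; treat the exact module contents and argument values as illustrative assumptions rather than code from this change.

  # Sketch only: an i64 platform index argument, accepted after this change.
  # The loader substitutes an i64 constant for %arg_platform_idx at load time.
  module, version = serialize("""
  module @jit_f.0 {
    func.func public @main(%arg_platform_idx: tensor<i64>, %arg0: tensor<f32>) -> tensor<f32> {
      %idx_i32 = stablehlo.convert %arg_platform_idx : (tensor<i64>) -> tensor<i32>
      %to_add = "stablehlo.case"(%idx_i32) ({
        %cpu_val = stablehlo.constant dense<2.> : tensor<f32>
        stablehlo.return %cpu_val : tensor<f32>
      }, {
        %tpu_val = stablehlo.constant dense<4.> : tensor<f32>
        stablehlo.return %tpu_val : tensor<f32>
      }) : (tensor<i32>) -> tensor<f32>
      %sum = stablehlo.add %arg0, %to_add : tensor<f32>
      return %sum : tensor<f32>
    }
  }
  """)
  out = xla.call_module([np.float32(0.)], module=module, version=version,
                        Tout=[np.float32], Sout=[()],
                        platforms=['CPU', 'TPU'])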
3 changes: 3 additions & 0 deletions tensorflow/compiler/tests/BUILD
@@ -2788,6 +2788,7 @@ tf_xla_py_strict_test(
"//tensorflow/python/platform:client_testlib",
"//tensorflow/python/platform:test",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)

@@ -2820,6 +2821,7 @@ tf_xla_py_strict_test(
"//tensorflow/python/platform:client_testlib",
"//tensorflow/python/platform:test",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)

@@ -2851,6 +2853,7 @@ tf_xla_py_strict_test(
"//tensorflow/python/platform:client_testlib",
"//tensorflow/python/platform:test",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)

71 changes: 42 additions & 29 deletions tensorflow/compiler/tests/xla_call_module_test.py
@@ -19,6 +19,7 @@
from typing import Optional, Sequence
import unittest

from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.mlir.stablehlo import stablehlo
@@ -41,7 +42,7 @@ def serialize(module_str: str) -> tuple[str, int]:
return byte_str, xla.call_module_maximum_supported_version()


class XlaCallModuleOpTest(xla_test.XLATestCase):
class XlaCallModuleOpTest(xla_test.XLATestCase, parameterized.TestCase):

def _assertOpOutputMatchesExpected(self,
op,
@@ -212,23 +213,29 @@ def f(x): # x: f32[2, b]

self._assertOpOutputMatchesExpected(f, (x,), (np.sin(x), x.shape[1]))

def test_poly_basic(self):
@parameterized.named_parameters(
dict(testcase_name='_' + dim_var_type,
dim_var_type=dim_var_type)
for dim_var_type in ('i32', 'i64')
)
def test_poly_basic(self, *, dim_var_type: str):
x = np.arange(6, dtype=np.float32).reshape((2, 3))

def f(x): # x: f32[2, b]
# (sin(x), x.shape[1])
module, version = serialize("""
module @jit_f.0 attributes {jax.uses_shape_polymorphism = true} {
func.func public @main(%arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor<i32>) {
%arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 1 : i64} : (tensor<2x?xf32>) -> tensor<i32>
%0, %1 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor<i32>)
return %0, %1 : tensor<2x?xf32>, tensor<i32>
}
func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor<i32>) {
module, version = serialize(f"""
module @jit_f.0 attributes {{jax.uses_shape_polymorphism = true}} {{
func.func public @main(%arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor<{dim_var_type}>) {{
%arg0_new_i32 = "stablehlo.get_dimension_size"(%arg1) {{dimension = 1 : i64}} : (tensor<2x?xf32>) -> tensor<i32>
%arg0_new = stablehlo.convert %arg0_new_i32 : (tensor<i32>) -> tensor<{dim_var_type}>
%0, %1 = call @dyn_main(%arg0_new, %arg1) : (tensor<{dim_var_type}>, tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor<{dim_var_type}>)
return %0, %1 : tensor<2x?xf32>, tensor<{dim_var_type}>
}}
func.func private @dyn_main(%arg0: tensor<{dim_var_type}>, %arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor<{dim_var_type}>) {{
%0 = stablehlo.sine %arg1 : tensor<2x?xf32>
return %0, %arg0 : tensor<2x?xf32>, tensor<i32>
}
}
return %0, %arg0 : tensor<2x?xf32>, tensor<{dim_var_type}>
}}
}}
""")
return xla.call_module([x],
module=module, version=version,
@@ -308,27 +315,33 @@ def f(x, y):
):
self._assertOpOutputMatchesExpected(f, (x_bad_shape, y), (x_bad_shape,))

def test_platforms_basic(self):
@parameterized.named_parameters(
dict(testcase_name='_' + platform_idx_type,
platform_idx_type=platform_idx_type)
for platform_idx_type in ('i32', 'i64')
)
def test_platforms_basic(self, *, platform_idx_type: str):
x = np.float32(0.)

# returns x + 2. on CPU, x + 3. on GPU (CUDA or ROCM) and x + 4. on TPU
module, version = serialize("""
module @jit_f.0 {
func.func public @main(%arg_platform_idx: tensor<i32>, %arg0: tensor<f32>) -> tensor<f32> {
%to_add = "stablehlo.case"(%arg_platform_idx) ({
module, version = serialize(f"""
module @jit_f.0 {{
func.func public @main(%arg_platform_idx: tensor<{platform_idx_type}>, %arg0: tensor<f32>) -> tensor<f32> {{
%0 = stablehlo.convert %arg_platform_idx : (tensor<{platform_idx_type}>) -> tensor<i32>
%to_add = "stablehlo.case"(%0) ({{
%cpu_val = stablehlo.constant dense<2.> : tensor<f32>
stablehlo.return %cpu_val : tensor<f32>
}, {
}}, {{
%gpu_val = stablehlo.constant dense<3.> : tensor<f32>
stablehlo.return %gpu_val : tensor<f32>
}, {
}}, {{
%tpu_val = stablehlo.constant dense<4.> : tensor<f32>
stablehlo.return %tpu_val : tensor<f32>
}) : (tensor<i32>) -> tensor<f32>
%0 = stablehlo.add %arg0, %to_add : tensor<f32>
return %0 : tensor<f32>
}
}
}}) : (tensor<i32>) -> tensor<f32>
%1 = stablehlo.add %arg0, %to_add : tensor<f32>
return %1 : tensor<f32>
}}
}}
""")

platforms = ['CPU', 'CUDA', 'ROCM', 'TPU']
@@ -494,15 +507,15 @@ def platforms_errors_no_platform_index_arg(self):
),
)

def platforms_errors_platform_index_i64(self):
module_str = self.platforms_errors_module_str.replace('i32', 'i64')
def platforms_errors_platform_index_i16(self):
module_str = self.platforms_errors_module_str.replace('i32', 'i16')
self.platforms_errors_helper(
module_str=module_str,
expected_error=errors.InvalidArgumentError,
expected_error_message=(
'Module argument at index 0 should be a 0-dimensional '
'32-bit integer-tensor platform index argument .* has type '
'tensor<i64>'
'32-bit or 64-bit integer-tensor platform index argument '
'.* has type tensor<i16>'
),
)

41 changes: 22 additions & 19 deletions tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc
@@ -165,16 +165,19 @@ tsl::Status SetPlatformIndex(mlir::func::FuncOp main, int platform_index) {
mlir::RankedTensorType arg_ranked_type =
platform_index_arg.getType().dyn_cast<mlir::RankedTensorType>();
if (!arg_ranked_type || arg_ranked_type.getRank() != 0 ||
!arg_ranked_type.getElementType().isSignlessInteger(32)) {
!(arg_ranked_type.getElementType().isSignlessInteger(32) ||
arg_ranked_type.getElementType().isSignlessInteger(64))) {
return absl::InvalidArgumentError(
absl::StrCat("Module argument at index 0 should be a 0-dimensional "
"32-bit integer-tensor platform index argument but "
"has type ",
"32-bit or 64-bit integer-tensor platform index argument "
"but has type ",
mlir::debugString(platform_index_arg.getType())));
}
bool is_32_bit = arg_ranked_type.getElementType().isSignlessInteger(32);
auto const_attr = is_32_bit ? op_builder.getI32IntegerAttr(platform_index)
: op_builder.getI64IntegerAttr(platform_index);
auto platform_index_op = op_builder.create<mlir::stablehlo::ConstantOp>(
platform_index_arg.getLoc(),
op_builder.getI32IntegerAttr(platform_index));
platform_index_arg.getLoc(), const_attr);
platform_index_arg.replaceAllUsesWith(platform_index_op);

main.eraseArgument(0);
@@ -218,8 +221,9 @@ tsl::StatusOr<std::unique_ptr<XlaCallModuleLoader>> XlaCallModuleLoader::Create(
// for %arg_dim0 and one for %arg_dim1. E.g., ['0.0', '0.1'] specifies that
// %arg_dim0 should be set to the size of axis 0 or array argument 0 (%arg0),
// while %arg_dim1 should be set to the size of axis 1.
// The platform index argument must be a 0-dimensional 32-bit integer, and the
// dimension arguments must be 0-dimensional tensors of integer type.
// The platform index argument must be a 0-dimensional 32-bit or 64-bit integer,
// and the dimension arguments must be 0-dimensional 32-bit or 64-bit integer
// tensors.
//
// We create a new "main" function as follows:
// func public main(%arg0: f32[?, ?, 8]) {
@@ -290,25 +294,24 @@ tsl::Status XlaCallModuleLoader::AddMainWrapper() {
arg_type.dyn_cast<mlir::RankedTensorType>();
if (!arg_ranked_type ||
!arg_ranked_type.getElementType().dyn_cast<mlir::IntegerType>() ||
!arg_ranked_type.getShape().empty()) {
!arg_ranked_type.getShape().empty() ||
!(arg_ranked_type.getElementTypeBitWidth() == 32 ||
arg_ranked_type.getElementTypeBitWidth() == 64)) {
std::string argument_type =
(i < nr_platform_args) ? "platform index" : "dimension";
return absl::InvalidArgumentError(absl::StrCat(
"Module argument at index ", i,
" should be a 0-dimensional integer-tensor ", argument_type,
" argument but has type ", mlir::debugString(arg_type)));
" should be a 0-dimensional 32-bit or 64-bit integer-tensor ",
argument_type, " argument but has type ",
mlir::debugString(arg_type)));
}
if (i < nr_platform_args) {
if (arg_ranked_type.getElementTypeBitWidth() != 32) {
return absl::InvalidArgumentError(
absl::StrCat("Module argument at index ", i,
" should be a 0-dimensional 32-bit integer-tensor"
" platform index argument but has type ",
mlir::debugString(arg_type)));
}
bool is_32_bit = arg_ranked_type.getElementType().isSignlessInteger(32);
auto const_attr = is_32_bit
? op_builder.getI32IntegerAttr(platform_index_)
: op_builder.getI64IntegerAttr(platform_index_);
call_args[i] = op_builder.create<mlir::stablehlo::ConstantOp>(
block_args[0].getLoc(),
op_builder.getI32IntegerAttr(platform_index_));
block_args[0].getLoc(), const_attr);
} else {
TF_ASSIGN_OR_RETURN(
call_args[i],
