[go: nahoru, domu]

Skip to content

Commit

Permalink
Merge pull request #59831 from ROCmSoftwarePlatform:fix_xla_call_module_test
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 538409638
  • Loading branch information
tensorflower-gardener committed Jun 7, 2023
2 parents 42ea7ad + a5277b8 commit 4993c55
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
1 change: 1 addition & 0 deletions tensorflow/compiler/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2723,6 +2723,7 @@ tf_xla_py_strict_test(
"//tensorflow/python/framework:function",
"//tensorflow/python/framework:ops",
"//tensorflow/python/ops:array_ops",
"//tensorflow/python/platform:client_testlib",
"//third_party/py/numpy",
],
)
Expand Down
18 changes: 12 additions & 6 deletions tensorflow/compiler/tests/xla_call_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test


def serialize(module_str: str) -> Tuple[str, int]:
Expand Down Expand Up @@ -63,7 +64,10 @@ def testing_platform(self):
if self.device in ['CPU', 'XLA_CPU']:
return 'CPU'
elif self.device in ['GPU', 'XLA_GPU']:
return 'CUDA'
if test.is_built_with_rocm():
return 'ROCM'
else:
return 'CUDA'
elif self.device in ['TPU', 'XLA_TPU']:
return 'TPU'
else:
Expand Down Expand Up @@ -283,7 +287,7 @@ def f(x, y):
def test_platforms_basic(self):
x = np.float32(0.)

# returns x + 2. on CPU, x + 3. on GPU and x + 4. on TPU
# returns x + 2. on CPU, x + 3. on GPU (CUDA or ROCM) and x + 4. on TPU
module, version = serialize("""
module @jit_f.0 {
func.func public @main(%arg_platform_idx: tensor<i32>, %arg0: tensor<f32>) -> tensor<f32> {
Expand All @@ -303,15 +307,17 @@ def test_platforms_basic(self):
}
""")

platforms = ['CPU', 'CUDA', 'TPU']
platforms = ['CPU', 'CUDA', 'ROCM', 'TPU']
def f(x):
return xla.call_module([x], version=version,
module=module,
Tout=[np.float32],
Sout=[()],
platforms=platforms)

expected_value = x + dict(CPU=2., CUDA=3., TPU=4.)[self.testing_platform()]
expected_value = (
x + dict(CPU=2.0, CUDA=3.0, ROCM=3.0, TPU=4.0)[self.testing_platform()]
)
self._assertOpOutputMatchesExpected(f, (x,), (expected_value,))

def test_platforms_errors(self):
Expand Down Expand Up @@ -358,7 +364,7 @@ def f(x):
'The current platform .* is not among the platforms'):
self._assertOpOutputMatchesExpected(f, (x,), (x,))

platforms = ['CPU', 'CUDA']
platforms = ['CPU', 'CUDA', 'ROCM']
if self.testing_platform() not in platforms:
with self.assertRaisesRegex(
errors.NotFoundError,
Expand All @@ -369,7 +375,7 @@ def f(x):

# The module cannot have i64 %arg_platform_idx
module, version = serialize(module_str.replace('i32', 'i64'))
platforms = ['CPU', 'CUDA', 'TPU']
platforms = ['CPU', 'CUDA', 'ROCM', 'TPU']
with self.assertRaisesRegex(
errors.InvalidArgumentError,
'Module argument at index 0 should be a 0-dimensional '
Expand Down

0 comments on commit 4993c55

Please sign in to comment.