[go: nahoru, domu]

Skip to content

Commit

Permalink
Simplify XlaCallModule test for disabling platform check.
Browse files Browse the repository at this point in the history
Previously, we had a separate test target `xla_call_module_no_platform_check_test_cpu` and a separate
test file `xla_call_module_no_platform_check_test.py` to test
the effect of disabling the platform check with `TF_XLA_FLAGS`.

Now we keep the separate test target, but it reuses the
tests in `xla_call_module_test.py`, which check whether
`TF_XLA_FLAGS` disables the platform check.

PiperOrigin-RevId: 545663649
  • Loading branch information
gnecula authored and tensorflower-gardener committed Jul 5, 2023
1 parent 2462d88 commit 3ba0e0e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 92 deletions.
5 changes: 4 additions & 1 deletion tensorflow/compiler/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2773,9 +2773,12 @@ tf_xla_py_strict_test(
tf_xla_py_strict_test(
name = "xla_call_module_no_platform_check_test",
size = "small",
srcs = ["xla_call_module_no_platform_check_test.py"],
srcs = ["xla_call_module_test.py"],
# cpu_ondemand overrides the TF_XLA_FLAGS
disabled_backends = ["cpu_ondemand"],
enable_mlir_bridge = False,
env = {"TF_XLA_FLAGS": "--tf_xla_call_module_disabled_checks=platform"},
main = "xla_call_module_test.py",
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
Expand Down

This file (`tensorflow/compiler/tests/xla_call_module_no_platform_check_test.py`) was deleted.

18 changes: 14 additions & 4 deletions tensorflow/compiler/tests/xla_call_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Tests for XLA call module op wrapper."""
import os
from typing import Tuple
import unittest

Expand Down Expand Up @@ -363,10 +364,17 @@ def f(x):
'and 0 dimension arguments.'):
self._assertOpOutputMatchesExpected(f, (x,), (x,))

platform_check_disabled_by_flags = (
'--tf_xla_call_module_disabled_checks=platform'
in os.getenv('TF_XLA_FLAGS', ''))
platforms = ['RANDOM_PLATFORM_1', 'RANDOM_PLATFORM_2']
with self.assertRaisesRegex(
errors.NotFoundError,
'The current platform .* is not among the platforms'):
if not platform_check_disabled_by_flags:
with self.assertRaisesRegex(
errors.NotFoundError,
'The current platform .* is not among the platforms'):
self._assertOpOutputMatchesExpected(f, (x,), (x,))
else:
# No error
self._assertOpOutputMatchesExpected(f, (x,), (x,))

# Disable the check but have two platforms
Expand All @@ -390,12 +398,14 @@ def f(x):
self._assertOpOutputMatchesExpected(f, (x,), (x,))

platforms = ['CPU', 'CUDA', 'ROCM']
if self.testing_platform() not in platforms:
if (self.testing_platform() not in platforms
and not platform_check_disabled_by_flags):
with self.assertRaisesRegex(
errors.NotFoundError,
'The current platform .* is not among the platforms'):
self._assertOpOutputMatchesExpected(f, (x,), (x,))
else:
# No error
self._assertOpOutputMatchesExpected(f, (x,), (x,))

# The module cannot have i64 %arg_platform_idx
Expand Down

0 comments on commit 3ba0e0e

Please sign in to comment.