[go: nahoru, domu]

Skip to content

Commit

Permalink
[XLA:GPU] Minor improvements to the CUB sort test.
Browse files Browse the repository at this point in the history
- Add a comment explaining why the sort comparator ignores the values operand of `sort`.
- Verify that the rewrite actually happens and that CUB sort is indeed used. Previously, a temporary change that did not actually use CUB still passed the test.
- Make the test size medium; otherwise it frequently timed out for me.

PiperOrigin-RevId: 640875027
  • Loading branch information
dimitar-asenov authored and tensorflower-gardener committed Jun 6, 2024
1 parent 12ccaf3 commit afa9a45
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,7 @@ xla_test(

xla_test(
name = "gpu_cub_sort_test",
size = "medium",
srcs = ["gpu_cub_sort_test.cc"],
backends = ["gpu"],
shard_count = 15,
Expand Down
23 changes: 23 additions & 0 deletions third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ namespace xla {
namespace gpu {
namespace {

// Reports whether the "gpu-sort-rewriter" pass recorded a change to `module`.
//
// Scans the module's pass metadata for an entry whose pass name is
// "gpu-sort-rewriter" and returns that entry's `module_changed()` flag. Only
// the first matching entry is consulted. Returns false when no entry for that
// pass is present. The tests use this to confirm that the sort was rewritten
// to use CUB rather than silently taking another path.
bool HloWasRewrittenToUseCubSort(const HloModule& module) {
  const auto& recorded_passes = module.metadata().proto().pass_metadata();
  for (const auto& record : recorded_passes) {
    if (record.pass_name() != "gpu-sort-rewriter") {
      continue;
    }
    // First record for the sort rewriter decides the answer.
    return record.module_changed();
  }
  return false;
}

// ----- Sort keys

class CubSortKeysTest : public HloTestBase,
Expand Down Expand Up @@ -60,6 +69,11 @@ ENTRY main {
kHloTpl,
primitive_util::LowercasePrimitiveTypeName(std::get<0>(GetParam())),
std::get<1>(GetParam()) ? "LT" : "GT", batch_size, segment_size);

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_hlo_module,
GetOptimizedModule(hlo_str));
EXPECT_TRUE(HloWasRewrittenToUseCubSort(*optimized_hlo_module));

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
ParseAndReturnVerifiedModule(hlo_str));
EXPECT_TRUE(RunAndCompare(std::move(hlo_module), ErrorSpec{0, 0}));
Expand Down Expand Up @@ -94,6 +108,10 @@ HloModule TestSortPairs
compare {
%lhs = $0[] parameter(0)
%rhs = $0[] parameter(1)
// Note that only the keys (first operand of `sort`) are sorted and the values
// (second operand of `sort`) are ignored. For the case where this sort is
// part of a TopK decomposition, this works fine, because CUB sort is stable
// and `values` are actually the unique indices, produced by an iota.
%v0 = $1[] parameter(2)
%v1 = $1[] parameter(3)
ROOT %comp = pred[] compare(%lhs, %rhs), direction=$2
Expand All @@ -110,6 +128,11 @@ ENTRY main {
primitive_util::LowercasePrimitiveTypeName(std::get<0>(GetParam())),
primitive_util::LowercasePrimitiveTypeName(std::get<1>(GetParam())),
std::get<2>(GetParam()) ? "LT" : "GT", batch_size, segment_size);

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_hlo_module,
GetOptimizedModule(hlo_str));
EXPECT_TRUE(HloWasRewrittenToUseCubSort(*optimized_hlo_module));

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
ParseAndReturnVerifiedModule(hlo_str));
EXPECT_TRUE(RunAndCompare(std::move(hlo_module), ErrorSpec{0, 0}));
Expand Down

0 comments on commit afa9a45

Please sign in to comment.