[go: nahoru, domu]

Skip to content

Commit

Permalink
Reverts changelist 578813627
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 643924789
  • Loading branch information
dimitar-asenov authored and tensorflower-gardener committed Jun 17, 2024
1 parent 5cd6ae2 commit 658f706
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
2 changes: 1 addition & 1 deletion third_party/xla/xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning(true);
opts.set_xla_gpu_llvm_verification_level(0);
opts.set_xla_gpu_target_config_filename("");
-opts.set_xla_gpu_enable_cub_radix_sort(true);
+opts.set_xla_gpu_enable_cub_radix_sort(false);
opts.set_xla_gpu_enable_cudnn_layer_norm(false);
opts.set_xla_gpu_threshold_for_windowed_einsum_mib(100000);

Expand Down
16 changes: 14 additions & 2 deletions third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ bool HloWasRewrittenToUseCubSort(const HloModule& module) {

class CubSortKeysTest : public HloTestBase,
public ::testing::WithParamInterface<
-std::tuple<PrimitiveType, bool, int>> {};
+std::tuple<PrimitiveType, bool, int>> {
+DebugOptions GetDebugOptionsForTest() override {
+auto options = HloTestBase::GetDebugOptionsForTest();
+options.set_xla_gpu_enable_cub_radix_sort(true);
+return options;
+}
+};

TEST_P(CubSortKeysTest, CompareToReference) {
int batch_size = std::get<2>(GetParam());
Expand Down Expand Up @@ -96,7 +102,13 @@ INSTANTIATE_TEST_SUITE_P(
class CubSortPairsTest
: public HloTestBase,
public ::testing::WithParamInterface<
-std::tuple<PrimitiveType, PrimitiveType, bool, int>> {};
+std::tuple<PrimitiveType, PrimitiveType, bool, int>> {
+DebugOptions GetDebugOptionsForTest() override {
+auto options = HloTestBase::GetDebugOptionsForTest();
+options.set_xla_gpu_enable_cub_radix_sort(true);
+return options;
+}
+};

TEST_P(CubSortPairsTest, CompareToReference) {
int batch_size = std::get<3>(GetParam());
Expand Down

0 comments on commit 658f706

Please sign in to comment.