[XLA:GPU] Add num_warps to BlockLevelFusionConfig and a method to convert the struct to proto.

PiperOrigin-RevId: 644782060
olegshyshkov authored and tensorflower-gardener committed Jun 19, 2024
1 parent 4347a69 commit 1d4b49f
Showing 4 changed files with 36 additions and 6 deletions.
3 changes: 3 additions & 0 deletions third_party/xla/xla/service/gpu/backend_configs.proto
@@ -159,6 +159,9 @@ message BlockLevelFusionConfig {
   // The output tile sizes of the associated instruction. The length of this
   // field is expected to be the rank of the output shape.
   repeated int64 output_tile_sizes = 1;
+
+  // The number of warps to use for the kernel.
+  int64 num_warps = 2;
 }
 
 message FusionBackendConfig {
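For a scalar int64 field, protoc generates set_num_warps()/num_warps() accessors in C++, and add_output_tile_sizes() for the repeated field. A minimal sketch of populating the extended message (MakeExampleConfig and the values are illustrative, not part of the commit):

#include "xla/service/gpu/backend_configs.pb.h"

// Hypothetical helper, not part of the commit: builds a config that
// exercises the new field.
xla::gpu::BlockLevelFusionConfig MakeExampleConfig() {
  xla::gpu::BlockLevelFusionConfig config;
  config.add_output_tile_sizes(18);  // one entry per output dimension
  config.add_output_tile_sizes(19);
  config.set_num_warps(12);          // new field added by this commit
  return config;
}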
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/model/BUILD
@@ -636,7 +636,6 @@ cc_library(
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/lib/gtl:iterator_range",
     ],
 )
@@ -646,6 +645,7 @@ xla_cc_test(
     srcs = ["tiled_hlo_computation_test.cc"],
     deps = [
         ":tiled_hlo_computation",
+        "//xla/service/gpu:backend_configs_cc",
         "@com_google_googletest//:gtest_main",
     ],
 )
14 changes: 12 additions & 2 deletions third_party/xla/xla/service/gpu/model/tiled_hlo_computation.h
@@ -36,7 +36,7 @@ struct BlockLevelParameters {
   std::vector<int64_t> output_tile_sizes;
 
   // Triton-specific parameters.
-  int num_warps = 1;
+  int64_t num_warps = 1;
   int num_ctas = 1;
   int num_stages = 1;
 
@@ -46,7 +46,17 @@
     return BlockLevelParameters{
         /*output_tile_sizes=*/
         std::vector<int64_t>(config.output_tile_sizes().begin(),
-                             config.output_tile_sizes().end())};
+                             config.output_tile_sizes().end()),
+        /*num_warps=*/config.num_warps()};
   }
 
+  // Returns a BlockLevelFusionConfig proto from a BlockLevelParameters struct.
+  BlockLevelFusionConfig ToBlockLevelFusionConfig() const {
+    BlockLevelFusionConfig config;
+    config.mutable_output_tile_sizes()->Add(output_tile_sizes.begin(),
+                                            output_tile_sizes.end());
+    config.set_num_warps(num_warps);
+    return config;
+  }
 };
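As an illustration of the new API, a minimal round-trip sketch (RoundTripExample is hypothetical and not part of the commit; it assumes only the two headers shown in this diff):

#include <cassert>

#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/model/tiled_hlo_computation.h"

// Hypothetical example, not part of the commit: convert a
// BlockLevelParameters struct to its proto form and parse it back.
void RoundTripExample() {
  xla::gpu::BlockLevelParameters params;
  params.output_tile_sizes = {16, 32};
  params.num_warps = 4;

  // Struct -> proto, via the method added in this commit.
  xla::gpu::BlockLevelFusionConfig config = params.ToBlockLevelFusionConfig();

  // Proto -> struct, via the pre-existing factory that now also
  // reads num_warps.
  xla::gpu::BlockLevelParameters parsed =
      xla::gpu::BlockLevelParameters::FromBlockLevelFusionConfig(config);

  assert(parsed.output_tile_sizes == params.output_tile_sizes);
  assert(parsed.num_warps == 4);
}

Note that num_ctas and num_stages are not carried by BlockLevelFusionConfig, so they keep their defaults of 1 across such a round trip.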
third_party/xla/xla/service/gpu/model/tiled_hlo_computation_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "xla/service/gpu/backend_configs.pb.h"
 
 namespace xla {
 namespace gpu {
@@ -29,11 +30,27 @@ TEST(BlockLevelParametersTest,
   BlockLevelFusionConfig block_level_fusion_config;
   block_level_fusion_config.mutable_output_tile_sizes()->Add(18);
   block_level_fusion_config.mutable_output_tile_sizes()->Add(19);
+  block_level_fusion_config.set_num_warps(12);
 
-  EXPECT_THAT(BlockLevelParameters::FromBlockLevelFusionConfig(
-                  block_level_fusion_config)
-                  .output_tile_sizes,
-              ElementsAre(18, 19));
+  BlockLevelParameters block_level_parameters =
+      BlockLevelParameters::FromBlockLevelFusionConfig(
+          block_level_fusion_config);
+  EXPECT_THAT(block_level_parameters.output_tile_sizes, ElementsAre(18, 19));
+  EXPECT_THAT(block_level_parameters.num_warps, 12);
 }
 
+TEST(BlockLevelParametersTest,
+     BlockLevelParametersCanBeConvertedToBlockLevelFusionConfig) {
+  BlockLevelParameters block_level_parameters;
+  block_level_parameters.output_tile_sizes = {18, 19};
+  block_level_parameters.num_warps = 12;
+
+  BlockLevelFusionConfig block_level_fusion_config =
+      block_level_parameters.ToBlockLevelFusionConfig();
+
+  EXPECT_THAT(block_level_fusion_config.output_tile_sizes(),
+              ElementsAre(18, 19));
+  EXPECT_THAT(block_level_fusion_config.num_warps(), 12);
+}
+
 }  // namespace
