[go: nahoru, domu]

Skip to content

Commit

Permalink
[XLA:GPU] Fix in-place dynamic update slice emitter for producer-consumer fusion.
Browse files Browse the repository at this point in the history

When a potential producer-consumer fusion has a dynamic-update-slice (DUS) producer and a bitcast consumer, the current code does not recognize that the in-place DUS emitter can be used, and instead applies the cost model of the loop emitter.

PiperOrigin-RevId: 631391738
  • Loading branch information
olegshyshkov authored and tensorflower-gardener committed May 7, 2024
1 parent d5fd2e3 commit 2c83f81
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 38 deletions.
13 changes: 2 additions & 11 deletions third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -991,13 +991,10 @@ cc_library(
"//xla:literal",
"//xla:shape_util",
"//xla:status",
"//xla:status_macros",
"//xla:statusor",
"//xla:types",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/mlir_hlo",
"//xla/service:buffer_assignment",
"//xla/service:hlo_parser",
"//xla/service/llvm_ir:buffer_assignment_util",
Expand All @@ -1017,15 +1014,11 @@ cc_library(
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TargetParser",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:Support",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:ml_dtypes",
"@local_tsl//tsl/platform:statusor",
],
)
Expand Down Expand Up @@ -4434,17 +4427,15 @@ cc_library(
":backend_configs_cc",
":hlo_traversal",
":ir_emission_utils",
":launch_dimensions",
":reduction_utils",
"//xla:shape_util",
"//xla:statusor",
"//xla/hlo/ir:hlo",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
"@llvm-project//llvm:Support",
],
Expand Down Expand Up @@ -5619,9 +5610,9 @@ cc_library(
hdrs = ["copy_fusion.h"],
deps = [
":gpu_fusible",
":hlo_traversal",
":ir_emission_utils",
":reduction_utils",
"//xla:statusor",
"//xla/hlo/ir:hlo",
"//xla/service:hlo_pass",
"@com_google_absl//absl/algorithm:container",
Expand Down
6 changes: 4 additions & 2 deletions third_party/xla/xla/service/gpu/copy_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/gpu/gpu_fusible.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/reduction_utils.h"
#include "tsl/platform/errors.h"
Expand Down Expand Up @@ -111,8 +112,9 @@ absl::StatusOr<bool> CopyFusion::DoCopyFusion(HloComputation* computation) {
if (copies.empty()) {
continue;
}
auto dynamic_update_slices = GetOutputDefiningDynamicUpdateSlices(
GetFusionRoots(*fused_computation));
auto fusion_adaptor = HloFusionAdaptor::ForComputation(fused_computation);
auto dynamic_update_slices =
GetOutputDefiningDynamicUpdateSlices(fusion_adaptor->GetRoots());
// Skip dynamic update slice fusions which might be emitted in-place.
if (!dynamic_update_slices.empty() &&
(root->opcode() != HloOpcode::kTuple ||
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/fusions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ cc_library(
"//xla/service:buffer_assignment",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
"@com_google_absl//absl/algorithm:container",
Expand Down
11 changes: 6 additions & 5 deletions third_party/xla/xla/service/gpu/fusions/fusions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ limitations under the License.
#include "xla/service/gpu/fusions/transpose_mlir.h"
#include "xla/service/gpu/fusions/triton.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
Expand All @@ -76,10 +77,10 @@ bool IsParameterOrGteOfParameter(const HloInstruction* instr) {

bool IsDynamicUpdateSliceFusion(const HloFusionAnalysis& analysis) {
return absl::c_all_of(
analysis.fusion_roots(), [](const HloInstruction* root) {
return root->opcode() == HloOpcode::kDynamicUpdateSlice ||
(root->opcode() == HloOpcode::kBitcast &&
root->operand(0)->opcode() == HloOpcode::kDynamicUpdateSlice);
analysis.fusion_root_adaptors(), [](const HloInstructionAdaptor& root) {
return root.opcode() == HloOpcode::kDynamicUpdateSlice ||
(root.opcode() == HloOpcode::kBitcast &&
root.GetOperand(0).opcode() == HloOpcode::kDynamicUpdateSlice);
});
}

Expand Down Expand Up @@ -126,7 +127,7 @@ HloFusionInfo::GetCopyFusion() const {

bool HloFusionInfo::CanEmitDynamicUpdateSliceInPlace() const {
auto ret = CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
instr_, buffer_assignment_, analysis().fusion_roots());
instr_, buffer_assignment_, analysis().fusion_root_adaptors());
return ret.ok() && *ret;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class InPlaceDynamicUpdateSliceFusion : public KernelFusionEmitterBase {
public:
explicit InPlaceDynamicUpdateSliceFusion(const HloFusionAnalysis& analysis)
: analysis_(analysis),
dus_ops_(
GetOutputDefiningDynamicUpdateSlices(analysis.fusion_roots())) {}
dus_ops_(GetOutputDefiningDynamicUpdateSlices(
analysis.fusion_root_adaptors())) {}
LaunchDimensions launch_dimensions() const override;

std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class MlirInPlaceDynamicUpdateSliceFusion : public MlirFusionEmitterBase {
explicit MlirInPlaceDynamicUpdateSliceFusion(
const HloFusionAnalysis& analysis)
: analysis_(analysis),
dus_ops_(
GetOutputDefiningDynamicUpdateSlices(analysis.fusion_roots())) {}
dus_ops_(GetOutputDefiningDynamicUpdateSlices(
analysis.fusion_root_adaptors())) {}

LaunchDimensions launch_dimensions() const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ class InPlaceDynamicUpdateSliceFusionTest : public HloTestBase {
protected:
AffineMapPrinter printer_;
mlir::MLIRContext mlir_context_;
stream_executor::DeviceDescription device_info_ =
TestGpuDeviceInfo::RTXA6000DeviceInfo();
};

TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) {
auto module = ParseAndReturnVerifiedModule(R"(
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
HloModule module
fused_computation {
Expand All @@ -64,14 +66,10 @@ TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) {
i1 = s32[] constant(3)
ROOT fusion = f32[20,30] fusion(in, updates, i0, i1), kind=kLoop, calls=fused_computation
}
)")
.value();

stream_executor::DeviceDescription device_info =
TestGpuDeviceInfo::RTXA6000DeviceInfo();
)"));

auto* root = module->entry_computation()->root_instruction();
auto analysis_fused = AnalyzeFusion(*root, device_info);
auto analysis_fused = AnalyzeFusion(*root, device_info_);

TF_ASSERT_OK_AND_ASSIGN(
auto emitter,
Expand Down Expand Up @@ -100,6 +98,44 @@ TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) {
EXPECT_THAT(thread_id_dst_indexing, ::testing::Eq(std::nullopt));
}

// Verifies that fusing a dynamic-update-slice fusion (producer) into a bitcast
// of its result (consumer) is still recognized as an in-place DUS fusion: the
// emitter selected for the producer-consumer analysis must be
// InPlaceDynamicUpdateSliceFusion.
TEST_F(InPlaceDynamicUpdateSliceFusionTest, ProducerConsumerFusion) {
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
HloModule m
fused_computation.1 {
param_0 = bf16[1,2,5,1,2] parameter(0)
bitcast = bf16[1,5,1,2,2] bitcast(param_0)
param_1 = bf16[1,1,1,2,2] parameter(1)
param_2 = s32[] parameter(2)
param_3 = s32[] parameter(3)
ROOT dynamic-update-slice = bf16[1,5,1,2,2] dynamic-update-slice(bitcast, param_1, param_2, param_3, param_2, param_2, param_2)
}
ENTRY entry_computation {
param_0.2 = bf16[1,2,5,1,2] parameter(3)
param_1.2 = bf16[1,1,1,2,2] parameter(0)
param_2.2 = s32[] parameter(1)
param_3.2 = s32[] parameter(2)
fusion = bf16[1,5,1,2,2] fusion(param_0.2, param_1.2, param_2.2, param_3.2), kind=kLoop, calls=fused_computation.1
ROOT bitcast.1 = bf16[1,2,5,1,2] bitcast(fusion)
}
)"));

  // Consumer: the entry root `bitcast.1`. Producer: its operand 0, the DUS
  // `fusion`.
  auto* root = module->entry_computation()->root_instruction();

  auto analysis_fused =
      AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info_);

  TF_ASSERT_OK_AND_ASSIGN(
      auto emitter,
      GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused}));

  auto* fusion = dynamic_cast<InPlaceDynamicUpdateSliceFusion*>(emitter.get());
  ASSERT_NE(fusion, nullptr);

  // The launch bound is expected to equal the element count of the update
  // operand param_1: bf16[1,1,1,2,2] has 1*1*1*2*2 = 4 elements.
  EXPECT_EQ(fusion->launch_dimensions().launch_bound(), 4 /* update size */);
}

} // namespace
} // namespace gpu
} // namespace xla
9 changes: 9 additions & 0 deletions third_party/xla/xla/service/gpu/hlo_fusion_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include <optional>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
Expand Down Expand Up @@ -65,6 +66,14 @@ class HloFusionAnalysis {
// Returns the roots of the analyzed fusion as raw `HloInstruction` pointers.
// The result is a const reference to the `fusion_roots_` member; no copy is
// made.
const std::vector<const HloInstruction*>& fusion_roots() const {
  const std::vector<const HloInstruction*>& stored_roots = fusion_roots_;
  return stored_roots;
}

// TODO(b/336597139): Merge with fusion_roots() and only return
// HloInstructionAdaptor. This function is added temporarily to make
// transition easier and avoid breaking the existing code.
//
// Returns the fusion roots as `HloInstructionAdaptor`s, as reported by the
// fusion adaptor (`fusion_->GetRoots()`). Each call queries the adaptor again
// and returns a fresh vector by value.
absl::InlinedVector<HloInstructionAdaptor, 2> fusion_root_adaptors() const {
  absl::InlinedVector<HloInstructionAdaptor, 2> root_adaptors =
      fusion_->GetRoots();
  return root_adaptors;
}

// Wraps the `i`-th entry of `fusion_roots_` in an `HloInstructionAdaptor`
// associated with `fusion_`. `i` is not bounds-checked here; callers must pass
// 0 <= i < fusion_roots().size().
HloInstructionAdaptor fusion_root(int64_t i) const {
  const HloInstruction& root_instr = *fusion_roots_[i];
  return HloInstructionAdaptor(root_instr, fusion_.get());
}
Expand Down
14 changes: 7 additions & 7 deletions third_party/xla/xla/service/gpu/ir_emission_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,15 +362,15 @@ absl::StatusOr<BufferAllocation::Slice> GetAllocationSlice(
}

std::vector<const HloInstruction*> GetOutputDefiningDynamicUpdateSlices(
const std::vector<const HloInstruction*>& roots) {
absl::Span<HloInstructionAdaptor const> roots) {
std::vector<const HloInstruction*> dus_ops;
for (const HloInstruction* root : roots) {
while (root->opcode() == HloOpcode::kBitcast) {
root = root->operand(0);
for (HloInstructionAdaptor root : roots) {
while (root.opcode() == HloOpcode::kBitcast) {
root = root.GetOperand(0);
}

if (root->opcode() == HloOpcode::kDynamicUpdateSlice) {
dus_ops.push_back(root);
if (root.opcode() == HloOpcode::kDynamicUpdateSlice) {
dus_ops.push_back(&root.instruction());
}
}
return dus_ops;
Expand All @@ -390,7 +390,7 @@ absl::InlinedVector<const HloInstruction*, 4> GetStartIndices(T instr) {
absl::StatusOr<bool> CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
const HloFusionInstruction* fusion,
const BufferAssignment* buffer_assignment,
const std::vector<const HloInstruction*>& roots) {
absl::Span<HloInstructionAdaptor const> roots) {
std::vector<const HloInstruction*> dus_instrs =
GetOutputDefiningDynamicUpdateSlices(roots);

Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/gpu/ir_emission_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,15 @@ absl::StatusOr<BufferAllocation::Slice> GetAllocationSlice(
absl::StatusOr<bool> CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
const HloFusionInstruction* fusion,
const BufferAssignment* buffer_assignment,
const std::vector<const HloInstruction*>& roots);
absl::Span<HloInstructionAdaptor const> roots);

// Returns the dynamic-update-slice instructions defining the results of a
// fusion node. A dynamic slice update is said to be "defining" of a result if
// that result is the output of a dynamic slice update, or if that result is the
// output of a bitcast of a dynamic slice update---since such bitcast may be
// handled as a no-op.
std::vector<const HloInstruction*> GetOutputDefiningDynamicUpdateSlices(
const std::vector<const HloInstruction*>& roots);
absl::Span<HloInstructionAdaptor const> roots);

Shape GetShape(mlir::Value value);

Expand Down

0 comments on commit 2c83f81

Please sign in to comment.