[go: nahoru, domu]

Skip to content

Commit

Permalink
PR #11563: [NVIDIA GPU] Improve GPU collective matmul to support all-…
Browse files Browse the repository at this point in the history
…gather having multiple users

Imported from GitHub PR openxla/xla#11563

We have identified another optimization opportunity for GPT-3 using collective matmul: in the backward pass, the all-gather has multiple dot users, but the current SPMD partitioner duplicates the collective matmul loop for each of them. We'd like this transformation:
before:
```
  //                       input
  //                       /    |
  //                      /     |
  //                     AG    windowed loop
  //                     /
  //                    /
  //                   dot

```
after:
```
  //                       input
  //                       |
  //                       |
  //                     windowed loop
  //                       |
  //                       |
  //                      dot
```
This is advantageous since the chained dot can fully utilize all the resources on the GPU while communication is hidden by the first collective matmul loop.

We introduced an option to turn off collective-matmul loop duplication in SPMD and rewrite the graph to the desired pattern in the gpu_windowed_einsum_handler pass.
Copybara import of the project:

--
986ac94ab44d31f6d11ec6f135f6cfb2e5636d80 by TJ <tjx@nvidia.com>:

Moved most of changes to gpu pass

--
44e81df91c235cac635f334c89d1d8a117ac6511 by TJ <tjx@nvidia.com>:

Added e2e test for windowed einsum
Minimized unit test hlo

--
8fc24a479de7515f532f36de8ffbcce49516c154 by TJ <tjx@nvidia.com>:

Added explanations for spmd tests and dot_handler to skip multiple
consumers

--
142d84d54db2b6291484443e43913d86c44a485c by TJ <tjx@nvidia.com>:

move windowed einsum test to stateful_rng_spmd_partitioner_test

--
8b9fc43746136b40a814d93bf8086a687490fd7f by TJ <tjx@nvidia.com>:

Changed e2e test back to include reducescatter

Merging this change closes #11563

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11563 from Tixxx:tixxx/ag_multi_user 8b9fc43746136b40a814d93bf8086a687490fd7f
PiperOrigin-RevId: 633179304
  • Loading branch information
Tixxx authored and tensorflower-gardener committed May 14, 2024
1 parent f32e67e commit 2ff2df0
Show file tree
Hide file tree
Showing 10 changed files with 479 additions and 26 deletions.
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6070,6 +6070,8 @@ xla_cc_test(
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings:string_view",
Expand Down
4 changes: 3 additions & 1 deletion third_party/xla/xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,9 @@ absl::Status RunSPMDPasses(
.xla_gpu_threshold_for_windowed_einsum_mib(),
hlo_module->config()
.debug_options()
.xla_gpu_multi_streamed_windowed_einsum());
.xla_gpu_multi_streamed_windowed_einsum(),
/*skip_checking_windowed_einsum_users=*/true,
/*disable_ag_rewrite_for_multiple_consumers=*/true);
spmd_pipeline.AddPass<CollectivePermuteMotion>();
return spmd_pipeline.Run(hlo_module).status();
} else {
Expand Down
162 changes: 162 additions & 0 deletions third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
Expand Down Expand Up @@ -142,6 +143,136 @@ absl::StatusOr<bool> HandleAgWindowedEinsumLoop(HloComputation* comp,
return changed;
}

// Rewrites the body of the all-gather windowed einsum while-loop so that the
// full (all-gathered) activation is cached into the loop's 4th output buffer,
// making it consumable by dots outside the loop.  Returns an internal error if
// the loop body does not have the expected shape.
absl::Status ProcessWindowedEinsumLoopForActivationCaching(
    GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop) {
  HloInstruction* loop = ag_loop.loop;
  // Transform the while body to cache the allgathered result in the
  // output buffer to be consumed by the dot.
  HloComputation* while_body = loop->while_body();
  // The GTE feeding the loop body with the sharded input is at tuple index 0.
  // Initialize to nullptr so a malformed body is reported instead of
  // dereferencing an indeterminate pointer.
  HloInstruction* input_gte = nullptr;
  for (HloInstruction* gte : while_body->parameter_instruction(0)->users()) {
    if (gte->tuple_index() == 0) {
      input_gte = gte;
    }
  }
  if (input_gte == nullptr) {
    return absl::InternalError(
        "Missing get-tuple-element for operand 0 of the windowed einsum loop "
        "body.");
  }
  // Get the output operand of the full buffer.
  HloInstruction* root = while_body->root_instruction();
  // The full buffer that we will use to cache the accumulated activation
  // is the 4th operand in the output tuple.
  int64_t full_cache_buffer_index = 3;
  HloInstruction* full_buffer_output_gte =
      root->mutable_operand(full_cache_buffer_index);
  HloInstruction* new_full_buffer_output = nullptr;
  // Find the DUS in the loop body and re-use the slice indices.
  // This should just be a constant(0).
  HloInstruction* dus_boundary_constant = nullptr;
  for (HloInstruction* inst : while_body->MakeInstructionPostOrder()) {
    HloInstruction* slice_indices;
    // If we have a DUS(PARAM,DS) pattern, we need to update the output
    // buffer with the first slice.
    if (Match(inst,
              m::DynamicUpdateSlice(
                  m::GetTupleElement(m::Parameter()), m::Op(),
                  m::Constant(&dus_boundary_constant),
                  m::Reshape(m::DynamicSlice(&slice_indices, m::Op(), m::Op())),
                  m::Op()))) {
      slice_indices = while_body->AddInstruction(HloInstruction::CreateReshape(
          dus_boundary_constant->shape(), slice_indices));
      VLOG(5) << "Created slice op for first slice: "
              << slice_indices->ToString();
      full_buffer_output_gte =
          while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
              full_buffer_output_gte->shape(), full_buffer_output_gte,
              input_gte,
              {dus_boundary_constant, slice_indices, dus_boundary_constant}));
      // For a non-unrolled loop this single DUS is the final cached buffer;
      // record it so the root operand is still replaced when the unrolled
      // pattern below never fires.
      new_full_buffer_output = full_buffer_output_gte;
    }
    // If we have a DUS(DUS,DS) pattern, then the einsum loop is
    // unrolled, we need to update the output buffer again with the
    // second slice. Since the second slice will have different indices,
    // we need to re-capture slice_indices.  dus_boundary_constant is
    // captured by the first pattern, which post-order guarantees we have
    // already visited; guard anyway so we never read it uninitialized.
    if (dus_boundary_constant != nullptr &&
        Match(inst,
              m::DynamicUpdateSlice(
                  m::DynamicUpdateSlice(), m::Op(), m::Constant(),
                  m::Reshape(m::DynamicSlice(&slice_indices, m::Op(), m::Op())),
                  m::Op()))) {
      slice_indices = while_body->AddInstruction(HloInstruction::CreateReshape(
          dus_boundary_constant->shape(), slice_indices));
      VLOG(5) << "Created slice op for second slice: "
              << slice_indices->ToString();
      // The slice we need this time is the output of the first
      // collective-permute.
      HloInstruction* cp_output = nullptr;
      for (HloInstruction* gte_user : input_gte->users()) {
        if (gte_user->opcode() == HloOpcode::kCollectivePermute) {
          cp_output = gte_user;
          break;
        }
      }
      if (cp_output == nullptr) {
        return absl::InternalError(
            "Missing collective-permute user of the windowed einsum loop "
            "input.");
      }
      new_full_buffer_output =
          while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
              full_buffer_output_gte->shape(), full_buffer_output_gte,
              cp_output,
              {dus_boundary_constant, slice_indices, dus_boundary_constant}));
    }
  }
  if (new_full_buffer_output == nullptr) {
    return absl::InternalError(
        "Could not find a dynamic-update-slice to cache the all-gathered "
        "activation in the windowed einsum loop.");
  }
  TF_RETURN_IF_ERROR(root->ReplaceOperandWith(full_cache_buffer_index,
                                              new_full_buffer_output));
  return absl::OkStatus();
}

// Visitor that re-chains dots consuming an all-gather onto the cached output
// of a windowed einsum loop that shares the all-gather's operand, removing
// the redundant all-gather.
class WindowedEinsumVisitor : public DfsHloRewriteVisitor {
 public:
  explicit WindowedEinsumVisitor(
      std::vector<GpuWindowedEinsumHandler::WindowedEinsumAgLoops>&
          all_ag_loops)
      : all_ag_loops_(all_ag_loops) {}
  // Rewrites a allgather-dot pattern that shares the same operand
  // with a windowed einsum loop to consume the output of the loop
  // and remove the all-gather.
  absl::Status HandleDot(HloInstruction* dot) override {
    CHECK_EQ(dot->opcode(), HloOpcode::kDot);
    // Iterate by reference: `consumed` must persist across HandleDot calls,
    // otherwise the loop body would be rewritten once per matching dot.
    for (GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop :
         all_ag_loops_) {
      HloInstruction* loop = ag_loop.loop;
      HloInstruction* ag_operand = nullptr;

      if (Match(dot, m::Dot(m::AllGather(&ag_operand), m::Op())) ||
          Match(dot, m::Dot(m::Op(), m::AllGather(&ag_operand)))) {
        // The loop's LHS is operand 0 of the while-init tuple.
        HloInstruction* windowed_lhs =
            loop->mutable_operand(0)->mutable_operand(0);
        HloInstruction* ag_with_shared_operand = nullptr;
        if (ag_operand && ag_operand->mutable_operand(0) == windowed_lhs) {
          ag_with_shared_operand = ag_operand;
        }

        if (!ag_with_shared_operand) {
          continue;
        }

        VLOG(5) << "Found all-gather that shares the same operand with a "
                   "windowed einsum loop : "
                << loop->ToString();
        // Feed the dot from the loop's cached full buffer (tuple index 3)
        // instead of the all-gather, then delete the all-gather.
        int64_t cache_output_index = dot->operand_index(ag_with_shared_operand);
        HloComputation* comp = dot->parent();
        HloInstruction* new_gte = comp->AddInstruction(
            HloInstruction::CreateGetTupleElement(loop, 3));
        TF_RETURN_IF_ERROR(
            dot->ReplaceOperandWith(cache_output_index, new_gte));
        TF_RETURN_IF_ERROR(comp->RemoveInstruction(ag_with_shared_operand));
        // Rewrite the loop body to populate the cache buffer only once.
        if (!ag_loop.consumed) {
          TF_RETURN_IF_ERROR(
              ProcessWindowedEinsumLoopForActivationCaching(ag_loop));
          ag_loop.consumed = true;
        }
        // Record the rewrite so the caller's `visitor.changed()` is accurate.
        MarkAsChanged();
      }
    }
    return absl::OkStatus();
  }

 private:
  std::vector<GpuWindowedEinsumHandler::WindowedEinsumAgLoops>& all_ag_loops_;
};

} // namespace

absl::StatusOr<bool> GpuWindowedEinsumHandler::Run(
Expand All @@ -163,9 +294,40 @@ absl::StatusOr<bool> GpuWindowedEinsumHandler::Run(
VLOG(5) << "Processing computation: " << comp->name();
TF_ASSIGN_OR_RETURN(bool comp_result,
HandleAgWindowedEinsumLoop(comp, stream_id));
all_ag_loops_.push_back(
WindowedEinsumAgLoops(comp->WhileCallInstruction()));
changed = comp_result;
}
}
// Now that we have processed all loops, we can check if there are any
// allgather-dot pattern that we can optimize. We'd want to transform:
// input
// / |
// / |
// AG windowed loop
// /
// /
// dot
// to:
// input
// |
// |
// windowed loop
// |
// |
// dot
// The windowed einsum loop will also be rewritten to output the full input to
// be consumed by the dot.
// This is advantageous since the chained dot can fully utilize all the
// resources on the GPU while comm is hidden by the first collective matmul
// loop.
for (HloComputation* comp :
module->MakeNonfusionComputations(execution_threads)) {
WindowedEinsumVisitor visitor(all_ag_loops_);
TF_RETURN_IF_ERROR(comp->Accept(&visitor));
changed |= visitor.changed();
}

XLA_VLOG_LINES(
5, "GpuWindowedEinsumHandler::Run(), after:\n" + module->ToString());
return changed;
Expand Down
10 changes: 9 additions & 1 deletion third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,24 @@ class GpuWindowedEinsumHandler : public HloModulePass {
return "gpu-windowed-einsum-handler";
}

struct WindowedEinsumAgLoops {
WindowedEinsumAgLoops(HloInstruction* loop) : loop(loop) {}
HloInstruction* loop;
bool consumed = false;
};

using HloPassInterface::Run;
absl::StatusOr<bool> Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;

private:
constexpr static const char* kWindowedEinsumRsLoopName =
"windowed_dot_general_body_rs";
constexpr static const char* kWindowedEinsumAgLoopName =
"windowed_dot_general_body_ag";

private:
std::vector<WindowedEinsumAgLoops> all_ag_loops_;
};

} // namespace xla::gpu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/pattern_matcher.h"
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/tests/hlo_test_base.h"
#include "tsl/platform/statusor.h"

namespace xla::gpu {
namespace {

namespace m = ::xla::match;

using GpuWindowedEinsumHanlderTest = HloTestBase;

HloInstruction* FindInstructionByName(HloComputation* comp, std::string name) {
Expand Down Expand Up @@ -193,5 +197,96 @@ ENTRY main.9_spmd {
cp1->backend_config<GpuBackendConfig>()->force_earliest_schedule());
}

// Verifies that when an all-gather shares its operand with an existing
// windowed einsum while-loop and feeds a separate dot, the pass (1) rewires
// the dot to consume the loop's cached full buffer (tuple index 3) and
// removes the all-gather, and (2) rewrites the loop body so the full buffer
// is populated by a chain of dynamic-update-slices.
TEST_F(GpuWindowedEinsumHanlderTest, AgLoopsMultipleConsumersAreChained) {
  constexpr absl::string_view kHloString = R"(
HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[24576,24576]{1,0})->bf16[2,2048,24576]{2,1,0}}, num_partitions=4
windowed_dot_general_body_ag {
  param.1 = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) parameter(0)
  get-tuple-element.1 = bf16[2,512,24576]{2,1,0} get-tuple-element(param.1), index=0
  collective-permute = bf16[2,512,24576]{2,1,0} collective-permute(get-tuple-element.1), channel_id=2, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
  collective-permute.1 = bf16[2,512,24576]{2,1,0} collective-permute(collective-permute), channel_id=3, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
  get-tuple-element.2 = bf16[24576,24576]{1,0} get-tuple-element(param.1), index=1
  get-tuple-element.3 = bf16[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=2
  dot = bf16[2,512,24576]{2,1,0} dot(get-tuple-element.1, get-tuple-element.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
  constant.2 = s32[] constant(0)
  constant.3 = s32[4]{0} constant({0, 512, 1024, 1536})
  get-tuple-element.5 = u32[] get-tuple-element(param.1), index=4
  partition-id = u32[] partition-id()
  add = u32[] add(get-tuple-element.5, partition-id)
  constant.1 = u32[] constant(4)
  remainder = u32[] remainder(add, constant.1)
  dynamic-slice = s32[1]{0} dynamic-slice(constant.3, remainder), dynamic_slice_sizes={1}
  reshape = s32[] reshape(dynamic-slice)
  dynamic-update-slice = bf16[2,2048,24576]{2,1,0} dynamic-update-slice(get-tuple-element.3, dot, constant.2, reshape, constant.2)
  dot.1 = bf16[2,512,24576]{2,1,0} dot(collective-permute, get-tuple-element.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
  constant.5 = u32[] constant(1)
  add.1 = u32[] add(get-tuple-element.5, constant.5)
  add.2 = u32[] add(add.1, partition-id)
  remainder.1 = u32[] remainder(add.2, constant.1)
  dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.3, remainder.1), dynamic_slice_sizes={1}
  reshape.1 = s32[] reshape(dynamic-slice.1)
  dynamic-update-slice.1 = bf16[2,2048,24576]{2,1,0} dynamic-update-slice(dynamic-update-slice, dot.1, constant.2, reshape.1, constant.2)
  get-tuple-element.4 = bf16[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=3
  add.3 = u32[] add(add.1, constant.5)
  ROOT tuple = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) tuple(collective-permute.1, get-tuple-element.2, dynamic-update-slice.1, get-tuple-element.4, add.3)
} // windowed_dot_general_body_ag
windowed_dot_general_cond_ag {
  param = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) parameter(0)
  get-tuple-element = u32[] get-tuple-element(param), index=4
  constant = u32[] constant(4)
  ROOT compare = pred[] compare(get-tuple-element, constant), direction=LT
}
ENTRY main.12_spmd {
  param.4 = bf16[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
  param.5 = bf16[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
  constant.22 = bf16[] constant(0)
  broadcast = bf16[2,2048,24576]{2,1,0} broadcast(constant.22), dimensions={}
  constant.24 = u32[] constant(0)
  tuple.2 = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) tuple(param.4, param.5, broadcast, broadcast, constant.24)
  while = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag
  get-tuple-element.13 = bf16[2,2048,24576]{2,1,0} get-tuple-element(while), index=2
  copy.1 = bf16[2,2048,24576]{2,1,0} copy(get-tuple-element.13)
  all-gather = bf16[2,2048,24576]{2,1,0} all-gather(param.4), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={1}, use_global_device_ids=true
  param.6 = bf16[24576,24576]{1,0} parameter(2), sharding={devices=[1,4]<=[4]}
  ROOT dot.7 = bf16[2,2048,24576]{2,1,0} dot(all-gather, param.6), lhs_contracting_dims={2}, rhs_contracting_dims={0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloString));

  // Run the pass; it should report that it changed the module.
  GpuWindowedEinsumHandler gpu_handler;
  bool changed;
  TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
  EXPECT_TRUE(changed);

  HloInstruction* ag_loop =
      FindInstructionByName(module->entry_computation(), "while");
  HloInstruction* inst =
      FindInstructionByName(module->entry_computation(), "dot.7");
  // dot.7 should now consume output of the windowed einsum while loop.
  EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement);
  EXPECT_EQ(inst->operand(0)->tuple_index(), 3);
  EXPECT_EQ(inst->operand(0)->operand(0), ag_loop);

  // while loop's root should now have a chain of DUS.
  // The loop body is unrolled in this HLO, so the cache buffer (tuple
  // index 3) is filled by two chained dynamic-update-slices rooted at the
  // get-tuple-element of the parameter at index 3.
  HloInstruction* ag_while_root = ag_loop->while_body()->root_instruction();
  EXPECT_THAT(ag_while_root,
              GmockMatch(m::Tuple(
                  m::Op(), m::Op(), m::Op(),
                  m::DynamicUpdateSlice(
                      m::DynamicUpdateSlice(
                          m::GetTupleElement(m::Parameter())
                              .WithPredicate([](const HloInstruction* instr) {
                                return instr->tuple_index() == 3;
                              }),
                          m::Op(), m::Op(), m::Op(), m::Op()),
                      m::Op(), m::Op(), m::Op(), m::Op()),
                  m::Op())));
}

} // namespace
} // namespace xla::gpu
Loading

0 comments on commit 2ff2df0

Please sign in to comment.