[go: nahoru, domu]

Skip to content

Commit

Permalink
[XLA:SPMD] Support partitioning kCall.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 634944833
  • Loading branch information
Tongfei-Guo authored and tensorflower-gardener committed May 18, 2024
1 parent 17c7b86 commit b409682
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 1 deletion.
27 changes: 26 additions & 1 deletion third_party/xla/xla/service/spmd/spmd_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2266,6 +2266,30 @@ std::vector<ReplicaGroup> SpmdPartitioningVisitor::CreateReplicaGroups(
return device_groups;
}

// Partitions a kCall instruction.
//
// The called computation is partitioned in place:
//   * each of its parameters is assigned the sharding of the matching call
//     operand, and
//   * the computation is partitioned with the call's own sharding passed as
//     its root sharding.
// A new kCall is then emitted over the already-partitioned operands, with the
// per-partition shape derived from the call's sharding. That kCall targets the
// same (now partitioned) computation.
//
// NOTE(review): because the callee is mutated in place, this presumably
// assumes the computation is not shared with another call site whose operands
// carry different shardings — confirm against the call graph if that can
// happen.
Status SpmdPartitioningVisitor::HandleCall(HloInstruction* hlo) {
  HloComputation* computation = hlo->called_computations()[0];

  std::vector<HloInstruction*> call_args;
  call_args.reserve(hlo->operand_count());
  for (int64_t i = 0; i < hlo->operand_count(); ++i) {
    const HloInstruction* operand = hlo->operand(i);
    // Shardings of the computation parameter and its argument must be the
    // same: the partitioned operand below is passed straight into the
    // partitioned callee, so the parameter has to expect that exact sharding.
    computation->parameter_instruction(i)->set_sharding(operand->sharding());
    call_args.push_back(GetPartitionedHlo(operand).hlo());
  }

  // Partition the callee, using the call's sharding as its root sharding.
  TF_RETURN_IF_ERROR(partitioner_
                         ->PartitionComputation(computation, hlo->sharding(),
                                                next_channel_id_, logger_,
                                                call_graph_)
                         .status());

  // Re-emit the call with the per-partition result shape. Reuse `computation`
  // rather than re-querying called_computations(), so both uses provably refer
  // to the same callee.
  SetPartitionedHlo(hlo, [&] {
    return b_.AddInstruction(HloInstruction::CreateCall(
        MakePartitionedShape(hlo->shape(), hlo->sharding()), call_args,
        computation));
  });
  return OkStatus();
}

Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) {
if (hlo->HasSideEffect() && !hlo->sharding().HasUniqueDevice()) {
return Unimplemented("Side-effect ops cannot be replicated: %s",
Expand Down Expand Up @@ -2343,7 +2367,8 @@ Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) {
hlo->opcode() != HloOpcode::kParameter &&
hlo->opcode() != HloOpcode::kWhile && hlo->opcode() != HloOpcode::kRng &&
hlo->opcode() != HloOpcode::kOutfeed &&
hlo->opcode() != HloOpcode::kAllReduce) {
hlo->opcode() != HloOpcode::kAllReduce &&
hlo->opcode() != HloOpcode::kCall) {
const bool has_manual_sharding =
hlo->sharding().IsManual() ||
(hlo->sharding().IsTuple() &&
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/service/spmd/spmd_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
Status DefaultAction(HloInstruction* hlo) override;
Status HandleAllReduce(HloInstruction* hlo) override;
Status HandleBroadcast(HloInstruction* hlo) override;
Status HandleCall(HloInstruction* hlo) override;
Status HandleConstant(HloInstruction* hlo) override;
Status HandleCustomCall(HloInstruction* hlo) override;
Status HandleDot(HloInstruction* hlo) override;
Expand Down
30 changes: 30 additions & 0 deletions third_party/xla/xla/service/spmd/spmd_partitioner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,36 @@ ENTRY entry {
op::Shape("s32[1,3]")));
}

TEST_P(SpmdPartitioningTest, PartitionCall) {
  // `main` feeds a [2,2]-tiled s32[8,2] value into a kCall of `g`. The
  // parameter and root of `g`, and the call itself, carry the same tiling.
  const absl::string_view kModuleText = R"(
HloModule jit_f

g {
Arg_0.6 = s32[8,2]{1,0} parameter(0), sharding={devices=[2,2]<=[4]}
constant.0 = s32[] constant(2), sharding={replicated}
broadcast.0 = s32[8,2]{1,0} broadcast(constant.0), dimensions={}, sharding={devices=[2,2]<=[4]}
ROOT multiply.9 = s32[8,2]{1,0} multiply(Arg_0.6, broadcast.0), sharding={devices=[2,2]<=[4]}
}

ENTRY main {
Arg_0.1 = s32[8,2]{1,0} parameter(0), sharding={devices=[2,2]<=[4]}
constant.1 = s32[] constant(3), sharding={replicated}
broadcast.1 = s32[8,2]{1,0} broadcast(constant.1), dimensions={}, sharding={devices=[2,2]<=[4]}
multiply.4 = s32[8,2]{1,0} multiply(Arg_0.1, broadcast.1), sharding={devices=[2,2]<=[4]}
ROOT call = s32[8,2]{1,0} call(multiply.4), to_apply=g, sharding={devices=[2,2]<=[4]}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_HOST","used_scoped_memory_configs":[]}
})";
  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned, PartitionComputation(kModuleText, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  // With a [2,2] tiling over 4 devices, every partition holds an
  // s32[8/2, 2/2] = s32[4,1] tile.
  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, AllOf(op::Call(), op::Shape("s32[4,1]")));

  // The callee must have been rewritten to operate on the same tile shape.
  const HloComputation* callee = entry_root->called_computations()[0];
  const HloInstruction* callee_root = callee->root_instruction();
  EXPECT_THAT(callee_root,
              AllOf(op::Multiply(op::Parameter(0),
                                 op::Broadcast(op::Constant())),
                    op::Shape("s32[4,1]")));
}

TEST_P(SpmdPartitioningTest, TiledToReplicated) {
absl::string_view hlo_string = R"(
HloModule module
Expand Down

0 comments on commit b409682

Please sign in to comment.