[SPMD] Support shard like interface in XLA and JAX (0/N).
Define sharding annotations with the keywords

1. shard_as <num>
2. shard_like <num>

to group shardings.

PiperOrigin-RevId: 561733657
Tongfei-Guo authored and tensorflower-gardener committed Aug 31, 2023
1 parent 6480d94 commit ae366f7
Showing 11 changed files with 254 additions and 14 deletions.
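
The two keywords surface in the textual form of HloSharding (see the print_shard_group hook added in hlo_sharding.cc below) and in new C++ helpers on HloSharding. A minimal sketch, assuming the ShardAs/ShardLike/SetShardGroup helpers land as declared in hlo_sharding.h further down; the group id 1 is arbitrary:

#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h"

void ShardGroupAnnotationSketch() {
  // Instructions in the same group are sharded alike or exactly the same
  // (per the comment on shard_group_ added in hlo_sharding.h below); the
  // replicated shardings here are only to show the printed text form.
  xla::HloSharding a =
      xla::HloSharding::Replicate().SetShardGroup(xla::HloSharding::ShardAs(1));
  // a.ToString() == "{replicated shard_as 1}"

  xla::HloSharding b =
      xla::HloSharding::Replicate().SetShardGroup(xla::HloSharding::ShardLike(1));
  // b.ToString() == "{replicated shard_like 1}"
}
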
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/hlo/ir/BUILD
@@ -93,6 +93,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
48 changes: 37 additions & 11 deletions tensorflow/compiler/xla/hlo/ir/hlo_sharding.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_op_metadata.h"
@@ -403,23 +404,32 @@ void HloSharding::Print(Printer* printer, bool include_metadata) const {
printer->Append("}");
}
};
auto print_shard_group = [&] {
auto shard_group_str = shard_group_.ToString();
if (!shard_group_str.empty()) {
printer->Append(" " + shard_group_str);
}
};

if (replicated_) {
printer->Append("{replicated");
print_shard_group();
print_metadata();
printer->Append("}");
return;
}

if (manual_) {
printer->Append("{manual");
print_shard_group();
print_metadata();
printer->Append("}");
return;
}
if (maximal_) {
AppendCat(printer, "{maximal device=",
static_cast<int64_t>(*tile_assignment_.array().begin()));
print_shard_group();
print_metadata();
printer->Append("}");
return;
@@ -447,12 +457,14 @@ void HloSharding::Print(Printer* printer, bool include_metadata) const {
printer->Append("}");
}
};

printer->Append("{");
tile_assignment_.Print(printer);
if (replicate_on_last_tile_dim_) {
printer->Append(" last_tile_dim_replicate");
}
print_last_tile_dims();
print_shard_group();
print_metadata();
printer->Append("}");
}
@@ -497,7 +509,7 @@ std::map<int64_t, int64_t> HloSharding::UsedDevices(int64_t* count) const {

std::vector<int64_t> HloSharding::TileIndexForDevice(int64_t device) const {
CHECK(!maximal_);
CHECK(!manual_);
CHECK(!IsManual());
CHECK(!IsTuple());
std::vector<int64_t> ret_index;
tile_assignment_.Each([&](absl::Span<const int64_t> index, int64_t d) {
@@ -512,7 +524,7 @@ std::vector<int64_t> HloSharding::TileIndexForDevice(int64_t device) const {

int64_t HloSharding::DeviceForTileIndex(absl::Span<const int64_t> index) const {
CHECK(!replicated_);
CHECK(!manual_);
CHECK(!IsManual());
CHECK(!IsTuple());
if (maximal_) {
return *tile_assignment_.array().begin();
@@ -532,7 +544,7 @@ int64_t HloSharding::DeviceForTileIndex(absl::Span<const int64_t> index) const {
std::vector<int64_t> HloSharding::TileOffsetForDevice(const Shape& shape,
int64_t device) const {
CHECK(!IsTuple());
CHECK(!manual_);
CHECK(!IsManual());

if (maximal_) {
return std::vector<int64_t>(shape.dimensions_size(), 0);
@@ -550,7 +562,7 @@ std::vector<int64_t> HloSharding::TileOffsetForDevice(const Shape& shape,
std::vector<int64_t> HloSharding::TileLimitForDevice(const Shape& shape,
int64_t device) const {
CHECK(!IsTuple());
CHECK(!manual_);
CHECK(!IsManual());

if (maximal_) {
return std::vector<int64_t>(shape.dimensions().begin(),
@@ -781,17 +793,18 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
HloSharding::FromProto(tuple_sharding_proto));
tuple_shardings.push_back(sharding);
}
return HloSharding(tuple_shardings);
return HloSharding(tuple_shardings).SetShardGroupFromProto(proto);
} else if (proto.type() == OpSharding::REPLICATED) {
return Replicate(metadata);
return Replicate(metadata).SetShardGroupFromProto(proto);
} else if (proto.type() == OpSharding::MANUAL) {
return Manual(metadata);
return Manual(metadata).SetShardGroupFromProto(proto);
} else if (proto.tile_assignment_devices().size() == 1) {
return HloSharding(proto.tile_assignment_devices(0), metadata);
return HloSharding(proto.tile_assignment_devices(0), metadata)
.SetShardGroupFromProto(proto);
} else if (!proto.iota_reshape_dims().empty() &&
absl::c_all_of(proto.iota_reshape_dims(),
[](int64_t d) { return d == 1; })) {
return HloSharding(0, metadata);
return HloSharding(0, metadata).SetShardGroupFromProto(proto);
}

TF_RET_CHECK(proto.type() != OpSharding::MAXIMAL)
@@ -848,12 +861,15 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
};
if (!subgroup_types.empty()) {
TF_RET_CHECK(!proto.replicate_on_last_tile_dim());
return Subgroup(create_tile_assignment(), subgroup_types, metadata);
return Subgroup(create_tile_assignment(), subgroup_types, metadata)
.SetShardGroupFromProto(proto);
}
return proto.replicate_on_last_tile_dim()
? PartialTile(create_tile_assignment(), metadata)
.SetShardGroupFromProto(proto)
: HloSharding(create_tile_assignment(),
/*replicate_on_last_tile_dim=*/false, metadata);
/*replicate_on_last_tile_dim=*/false, metadata)
.SetShardGroupFromProto(proto);
}

OpSharding HloSharding::ToProto() const {
@@ -912,6 +928,16 @@ OpSharding HloSharding::ToProto() const {
result.add_last_tile_dims(type);
}
}

if (IsShardGroup()) {
result.set_is_shard_group(true);
result.set_shard_group_id(shard_group_.shard_group_id);
if (shard_group_.shard_as) {
result.set_shard_group_type(OpSharding::AS);
} else {
result.set_shard_group_type(OpSharding::LIKE);
}
}
return result;
}
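
FromProto now tacks SetShardGroupFromProto onto every construction path, and ToProto writes the group back out, so the tag survives a serialization round trip. A rough sketch, assuming the new is_shard_group/shard_group_id/shard_group_type proto fields from this change:

#include "absl/log/check.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h"

void ShardGroupRoundTripSketch() {
  xla::OpSharding proto;
  proto.set_type(xla::OpSharding::REPLICATED);
  proto.set_is_shard_group(true);
  proto.set_shard_group_id(7);
  proto.set_shard_group_type(xla::OpSharding::AS);

  // Parsing keeps the tag...
  xla::HloSharding sharding = xla::HloSharding::FromProto(proto).value();
  CHECK(sharding.IsShardAs());

  // ...and serializing writes it back out.
  xla::OpSharding round_tripped = sharding.ToProto();
  CHECK(round_tripped.is_shard_group());
  CHECK_EQ(round_tripped.shard_group_id(), 7);
}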

105 changes: 103 additions & 2 deletions tensorflow/compiler/xla/hlo/ir/hlo_sharding.h
@@ -24,6 +24,7 @@ limitations under the License.
#include <memory>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
@@ -188,6 +189,31 @@ class HloSharding {
[](const HloSharding& s) { return s.IsManual(); });
}

bool IsShardGroup() const {
if (!IsTuple()) {
return shard_group_.shard_group_id != -1 &&
(shard_group_.shard_like || shard_group_.shard_as);
}
return absl::c_all_of(
tuple_elements_, [](const HloSharding& s) { return s.IsShardGroup(); });
}

bool IsShardAs() const {
if (!IsTuple()) {
return shard_group_.shard_group_id != -1 && shard_group_.shard_as;
}
return absl::c_all_of(tuple_elements_,
[](const HloSharding& s) { return s.IsShardAs(); });
}

bool IsShardLike() const {
if (!IsTuple()) {
return shard_group_.shard_group_id != -1 && shard_group_.shard_like;
}
return absl::c_all_of(tuple_elements_,
[](const HloSharding& s) { return s.IsShardLike(); });
}

// Returns whether the sharding represents manual subgroup sharding.
bool IsManualSubgroup() const {
if (!IsTuple()) {
@@ -312,7 +338,8 @@ class HloSharding {
tile_assignment_ == other.tile_assignment_ &&
tuple_elements_ == other.tuple_elements_ &&
replicate_on_last_tile_dim_ == other.replicate_on_last_tile_dim_ &&
subgroup_types_ == other.subgroup_types_;
subgroup_types_ == other.subgroup_types_ &&
shard_group_ == other.shard_group_;
}
bool operator!=(const HloSharding& other) const { return !(*this == other); }

@@ -323,7 +350,8 @@ class HloSharding {
}
return H::combine(std::move(h), sharding.replicated_, sharding.manual_,
sharding.tile_assignment_.array(),
sharding.replicate_on_last_tile_dim_);
sharding.replicate_on_last_tile_dim_,
sharding.shard_group_.ToString());
}

// Gets the tile assignment tensor.
@@ -400,6 +428,75 @@ class HloSharding {
// Returns the number of tuple_elements_ entries to fit the shape.
static int64_t RequiredLeaves(const Shape& shape);

struct ShardGroup {
ShardGroup(int64_t shard_group_id, bool shard_as, bool shard_like)
: shard_group_id(shard_group_id),
shard_as(shard_as),
shard_like(shard_like) {}

bool operator==(const ShardGroup& rhs) const {
return shard_group_id == rhs.shard_group_id && shard_as == rhs.shard_as &&
shard_like == rhs.shard_like;
}

std::string ToString() const {
std::ostringstream result;
if (shard_as) {
result << "shard_as " << shard_group_id;
} else if (shard_like) {
result << "shard_like " << shard_group_id;
}
return result.str();
}

int64_t shard_group_id = 0;
bool shard_as;
bool shard_like;
};
static ShardGroup NotShardGroup() {
return ShardGroup(
/*shard_group_id=*/-1,
/*shard_as=*/false,
/*shard_like=*/false);
}

static ShardGroup ShardAs(int64_t shard_group_id) {
return ShardGroup(shard_group_id,
/*shard_as=*/true,
/*shard_like=*/false);
}

static ShardGroup ShardLike(int64_t shard_group_id) {
return ShardGroup(shard_group_id,
/*shard_as=*/false,
/*shard_like=*/true);
}

HloSharding& SetShardGroup(const ShardGroup& shard_group) {
shard_group_ = shard_group;
return *this;
}

HloSharding& SetShardGroupFromProto(const OpSharding& proto) {
ShardGroup shard_group = NotShardGroup();
if (proto.is_shard_group()) {
if (proto.shard_group_type() == OpSharding::AS) {
shard_group = ShardAs(proto.shard_group_id());
} else {
shard_group = ShardLike(proto.shard_group_id());
}
}
SetShardGroup(shard_group);
return *this;
}

HloSharding& ClearShardGroup() {
shard_group_ = NotShardGroup();
return *this;
}

const ShardGroup& GetShardGroup() const { return shard_group_; }

private:
explicit HloSharding(bool manual, bool replicated,
absl::Span<const OpMetadata> metadata)
@@ -522,6 +619,10 @@ class HloSharding {
// shape rank, and the added last dimension represents the subgroups of
// replications, i.e., elements in slice [..., :] will be replicated.
bool replicate_on_last_tile_dim_;
// This field is used to store the shard group information. Instructions
// within the same shard group(i.e. under the same shard_group_id) will be
// sharded alike or exactly the same as each other.
ShardGroup shard_group_ = NotShardGroup();
};

std::ostream& operator<<(std::ostream& out, const HloSharding& sharding);
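
Because shard_group_ now participates in operator== and AbslHashValue, two otherwise identical shardings in different groups compare (and hash) as distinct. A small sketch, assuming the helpers declared above:

#include "absl/log/check.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h"

void ShardGroupEqualitySketch() {
  xla::HloSharding x =
      xla::HloSharding::Replicate().SetShardGroup(xla::HloSharding::ShardAs(0));
  xla::HloSharding y =
      xla::HloSharding::Replicate().SetShardGroup(xla::HloSharding::ShardAs(1));
  CHECK(x != y);  // Identical layouts, but different shard_group_id.

  // Clearing the group makes y compare equal to a plain replicated sharding.
  CHECK(y.ClearShardGroup() == xla::HloSharding::Replicate());
}
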
16 changes: 16 additions & 0 deletions tensorflow/compiler/xla/python/xla_compiler.cc
@@ -938,11 +938,21 @@ void BuildXlaCompilerSubmodule(py::module& m) {
.value("TUPLE", OpSharding::TUPLE)
.value("OTHER", OpSharding::OTHER);

py::enum_<OpSharding::ShardGroupType> op_sharding_shard_group_type(
m, "OpSharding_ShardGroupType");
op_sharding_shard_group_type.value("AS", OpSharding::AS)
.value("LIKE", OpSharding::LIKE);

py::class_<OpSharding> op_sharding(m, "OpSharding");
op_sharding
.def_property_readonly_static(
"Type",
[op_sharding_type](const py::object&) { return op_sharding_type; })
.def_property_readonly_static(
"ShardGroupType",
[op_sharding_shard_group_type](const py::object&) {
return op_sharding_shard_group_type;
})
.def(py::init<>())
.def(py::pickle(
[](const OpSharding& self) {
@@ -957,6 +967,12 @@ void BuildXlaCompilerSubmodule(py::module& m) {
.def_property("replicate_on_last_tile_dim",
&xla::OpSharding::replicate_on_last_tile_dim,
&xla::OpSharding::set_replicate_on_last_tile_dim)
.def_property("is_shard_group", &xla::OpSharding::is_shard_group,
&xla::OpSharding::set_is_shard_group)
.def_property("shard_group_id", &xla::OpSharding::shard_group_id,
&xla::OpSharding::set_shard_group_id)
.def_property("shard_group_type", &xla::OpSharding::shard_group_type,
&xla::OpSharding::set_shard_group_type)
.def("__repr__", &xla::OpSharding::DebugString)
.def("ParseFromString",
[](OpSharding& sharding, const std::string& s) {
8 changes: 8 additions & 0 deletions tensorflow/compiler/xla/python/xla_extension/__init__.pyi
@@ -292,6 +292,10 @@ class OpSharding_Type(enum.IntEnum):
OTHER: int
MANUAL: int

class OpSharding_ShardGroupType(enum.IntEnum):
AS: int
LIKE: int

class OpSharding:
Type: typing.Type[OpSharding_Type]
type: OpSharding_Type
@@ -302,6 +306,10 @@ class OpSharding:
iota_reshape_dims: Sequence[int]
iota_transpose_perm: Sequence[int]
tuple_shardings: Sequence[OpSharding]
is_shard_group: bool
shard_group_id: int
ShardGroupType: typing.Type[OpSharding_ShardGroupType]
shard_group_type: OpSharding_ShardGroupType
def ParseFromString(self, s: bytes) -> None: ...
def SerializeToString(self) -> bytes: ...
def clone(self) -> OpSharding: ...
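
The pybind properties and the stub entries above are thin wrappers over the same three proto fields. A hedged sketch of the equivalent proto mutation in C++ (in Python the spelling is the attribute form, e.g. op.is_shard_group = True and op.shard_group_type = OpSharding.ShardGroupType.LIKE):

#include "tensorflow/compiler/xla/xla_data.pb.h"

void ShardGroupProtoFieldsSketch() {
  xla::OpSharding op;
  op.set_is_shard_group(true);                     // OpSharding.is_shard_group in Python
  op.set_shard_group_id(3);                        // OpSharding.shard_group_id
  op.set_shard_group_type(xla::OpSharding::LIKE);  // OpSharding.ShardGroupType.LIKE
}
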
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/service/BUILD
@@ -5879,6 +5879,7 @@ xla_cc_test(
"//tensorflow/tsl/platform:statusor",
"//tensorflow/tsl/platform:test",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
],
)

6 changes: 6 additions & 0 deletions tensorflow/compiler/xla/service/hlo_lexer.cc
@@ -326,6 +326,8 @@ TokKind HloLexer::LexIdentifier() {
KEYWORD(replicated);
KEYWORD(manual);
KEYWORD(last_tile_dim_replicate);
KEYWORD(shard_as);
KEYWORD(shard_like);

#undef KEYWORD

@@ -606,6 +608,10 @@ std::string TokKindToString(TokKind kind) {
return "kw_manual";
case TokKind::kw_last_tile_dim_replicate:
return "kw_last_tile_dim_replicate";
case TokKind::kw_shard_as:
return "kw_shard_as";
case TokKind::kw_shard_like:
return "kw_shard_like";
case TokKind::kw_inf:
return "kw_inf";
case TokKind::kNegInf:
2 changes: 2 additions & 0 deletions tensorflow/compiler/xla/service/hlo_lexer.h
@@ -67,6 +67,8 @@ enum class TokKind {
kw_replicated,
kw_manual,
kw_last_tile_dim_replicate,
kw_shard_as,
kw_shard_like,
kw_inf,

kNegInf, // -inf
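
With shard_as and shard_like registered as lexer keywords (and assuming matching HLO parser support elsewhere in this change, among the collapsed files), sharding attributes carrying the annotation become legal HLO text. An illustrative sketch only; the module and ops are made up:

#include "tensorflow/compiler/xla/service/hlo_parser.h"

void ShardGroupHloTextSketch() {
  constexpr char kHloText[] = R"(
HloModule shard_group_example

ENTRY main {
  p0 = f32[8,8] parameter(0), sharding={devices=[2,1]0,1 shard_as 0}
  ROOT c0 = f32[8,8] copy(p0), sharding={devices=[2,1]0,1 shard_like 0}
}
)";
  // Expected to parse once the keywords are wired through the HLO parser.
  auto module = xla::ParseAndReturnUnverifiedModule(kHloText).value();
  (void)module;
}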