[go: nahoru, domu]

Skip to content

Commit

Permalink
Implements the FullyReplicatedShard method for the BasicStringArray class.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 643698167
  • Loading branch information
tensorflower-gardener committed Jun 16, 2024
1 parent 60fdc48 commit 49f3f3d
Show file tree
Hide file tree
Showing 2 changed files with 365 additions and 11 deletions.
112 changes: 110 additions & 2 deletions third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,20 @@ limitations under the License.
#include "xla/pjrt/pjrt_layout.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/future.h"
#include "xla/python/ifrt/memory.h"
#include "xla/python/ifrt/shape.h"
#include "xla/python/ifrt/sharding.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/statusor.h"

// TODO(jmudigonda): Several BasicStringArray operations such as
// DisassembleIntoSingleDeviceArrays, Reshard, FullyReplicatedShard,
// CopyToHostBuffer and AssembleFromSingleDeviceArrays share a common pattern
// that waits for the source array(s) buffers to become ready and then copies
// the data into a new array's buffer backing store. Factor out the common
// pattern into a helper function.

namespace xla {
namespace ifrt {

Expand Down Expand Up @@ -247,13 +255,113 @@ absl::StatusOr<tsl::RCReference<Array>> BasicStringArray::Reshard(
std::shared_ptr<const Sharding> new_sharding,
ArrayCopySemantics semantics) {
DCHECK(this);
return absl::UnimplementedError("Not implemented");
absl::MutexLock lock(&mu_);
if (is_deleted_) {
return absl::FailedPreconditionError("Array has already been deleted");
}

if (new_sharding->devices().size() != sharding_->devices().size()) {
return absl::InvalidArgumentError(absl::StrCat(
"Number of devices in new sharding: ", new_sharding->devices().size(),
" does not match the number of devices in the current sharding: ",
sharding_->devices().size()));
}

struct BufferBackingStore {
void AddShardData(absl::Span<const absl::string_view> input_buffer) {
auto& shard_strings = strings.emplace_back();
shard_strings.reserve(input_buffer.size());

auto& shard_string_views = string_views.emplace_back();
shard_string_views.reserve(input_buffer.size());

for (absl::string_view buf : input_buffer) {
shard_strings.push_back(std::string(buf.data(), buf.size()));
shard_string_views.push_back(shard_strings.back());
}
}
std::vector<std::vector<std::string>> strings;
std::vector<std::vector<absl::string_view>> string_views;
};

auto backing_store = std::make_shared<BufferBackingStore>();
auto on_done_with_buffer = [backing_store]() {};
auto buffers_promise = Future<Buffers>::CreatePromise();
auto buffers_future = Future<Buffers>(buffers_promise);

auto copier = [backing_store = std::move(backing_store),
buffers_promise = std::move(buffers_promise)](
absl::StatusOr<Buffers> input_buffers) mutable {
if (!input_buffers.ok()) {
buffers_promise.Set(input_buffers.status());
return;
}
Buffers buffers;
buffers.reserve(input_buffers->size());
for (auto& input_buffer : *input_buffers) {
backing_store->AddShardData(input_buffer);
buffers.push_back(backing_store->string_views.back());
}
buffers_promise.Set(std::move(buffers));
};
buffers_.OnReady(std::move(copier));
return BasicStringArray::Create(client_, shape_, std::move(new_sharding),
std::move(buffers_future),
std::move(on_done_with_buffer));
}

// Builds a single-device-sharded BasicStringArray out of the first shard of
// this (fully replicated) array. The shard's strings are deep-copied into a
// backing store owned by the returned array, so the result does not alias
// this array's storage.
//
// Returns an error if this array has been deleted or is not fully replicated.
// NOTE(review): `semantics` is currently unused — the shard is always copied.
absl::StatusOr<tsl::RCReference<Array>> BasicStringArray::FullyReplicatedShard(
    ArrayCopySemantics semantics) {
  DCHECK(this);
  absl::MutexLock lock(&mu_);
  if (is_deleted_) {
    return absl::FailedPreconditionError("Array has already been deleted");
  }
  if (!sharding_->IsFullyReplicated()) {
    return absl::FailedPreconditionError("This array is not fully replicated");
  }

  // Data (strings) for a single shard. `string_views` points into `strings`;
  // the shared_ptr keeps both alive until on_done_with_buffer runs.
  struct BufferBackingStore {
    void CopyFrom(absl::Span<const absl::string_view> input_buffer) {
      strings.reserve(input_buffer.size());
      string_views.reserve(input_buffer.size());
      for (absl::string_view buf : input_buffer) {
        strings.push_back(std::string(buf.data(), buf.size()));
        string_views.push_back(strings.back());
      }
    }
    std::vector<std::string> strings;
    std::vector<absl::string_view> string_views;
  };

  auto backing_store = std::make_shared<BufferBackingStore>();
  // Copy-captured (before the move below) so the backing store outlives the
  // new array's buffers.
  auto on_done_with_buffer = [backing_store]() {};
  auto buffers_promise = Future<Buffers>::CreatePromise();
  auto buffers_future = Future<Buffers>(buffers_promise);

  // Runs when this array's buffers become ready: deep-copies shard 0 and
  // fulfills the new array's buffers promise (or propagates the error).
  auto copier = [backing_store = std::move(backing_store),
                 buffers_promise = std::move(buffers_promise)](
                    absl::StatusOr<Buffers> input_buffers) mutable {
    if (!input_buffers.ok()) {
      buffers_promise.Set(input_buffers.status());
      return;
    }

    // No need to check the size of input_buffers. The consistency checks that
    // were run when the source array's buffers became ready would have ensured
    // that the input_buffers have at least one shard's worth of data.
    auto& input_buffer = (*input_buffers)[0];
    backing_store->CopyFrom(input_buffer);

    Buffers buffers;
    buffers.push_back(backing_store->string_views);
    buffers_promise.Set(std::move(buffers));
  };
  buffers_.OnReady(std::move(copier));

  // Make a single-device-sharded BasicStringArray from the first shard.
  return BasicStringArray::Create(
      client_, shape_,
      SingleDeviceSharding::Create(sharding_->devices().at(0), MemoryKind()),
      std::move(buffers_future), std::move(on_done_with_buffer));
}

absl::StatusOr<std::unique_ptr<PjRtLayout>> BasicStringArray::layout() const {
Expand Down
Loading

0 comments on commit 49f3f3d

Please sign in to comment.