From ecddea522a180c710c29215a49b06169c1f5965c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 14 Jun 2024 22:37:55 -0700 Subject: [PATCH] Automated Code Change PiperOrigin-RevId: 643542037 --- .../xla/third_party/tsl/tsl/lib/io/block.cc | 4 +- .../tsl/tsl/lib/io/buffered_file.h | 20 +-- .../tsl/tsl/lib/io/buffered_inputstream.cc | 55 ++++---- .../tsl/tsl/lib/io/buffered_inputstream.h | 22 ++-- .../tsl/lib/io/buffered_inputstream_test.cc | 8 +- .../xla/third_party/tsl/tsl/lib/io/format.cc | 17 +-- .../xla/third_party/tsl/tsl/lib/io/format.h | 8 +- .../third_party/tsl/tsl/lib/io/inputbuffer.cc | 57 ++++----- .../third_party/tsl/tsl/lib/io/inputbuffer.h | 33 ++--- .../tsl/tsl/lib/io/inputstream_interface.cc | 4 +- .../tsl/tsl/lib/io/inputstream_interface.h | 8 +- .../tsl/lib/io/inputstream_interface_test.cc | 8 +- .../third_party/tsl/tsl/lib/io/iterator.cc | 10 +- .../xla/third_party/tsl/tsl/lib/io/iterator.h | 4 +- .../tsl/tsl/lib/io/random_inputstream.cc | 23 ++-- .../tsl/tsl/lib/io/random_inputstream.h | 12 +- .../tsl/tsl/lib/io/record_reader.cc | 29 ++--- .../tsl/tsl/lib/io/record_reader.h | 18 +-- .../tsl/lib/io/record_reader_writer_test.cc | 4 +- .../tsl/tsl/lib/io/record_writer.cc | 30 ++--- .../tsl/tsl/lib/io/record_writer.h | 8 +- .../tsl/tsl/lib/io/recordio_test.cc | 30 ++--- third_party/xla/xla/service/gpu/BUILD | 4 +- .../xla/xla/service/gpu/ir_emitter_triton.cc | 119 +++++++++++------- third_party/xla/xla/service/gpu/model/BUILD | 12 ++ .../model/gpu_indexing_performance_model.cc | 112 +++++++++++++++++ .../model/gpu_indexing_performance_model.h | 26 ++++ .../gpu_indexing_performance_model_test.cc | 105 +++++++++++++++- .../gpu/model/gpu_performance_model_test.cc | 5 +- .../xla/service/gpu/triton_support_test.cc | 68 +++++++++- 30 files changed, 610 insertions(+), 253 deletions(-) diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/block.cc b/third_party/xla/third_party/tsl/tsl/lib/io/block.cc index 0bc9fa3664c97b..8eefa4b5a3609f 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/block.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/block.cc @@ -98,7 +98,7 @@ class Block::Iter : public Iterator { uint32 restart_index_; // Index of restart block in which current_ falls string key_; StringPiece value_; - Status status_; + absl::Status status_; inline int Compare(const StringPiece& a, const StringPiece& b) const { return a.compare(b); @@ -135,7 +135,7 @@ class Block::Iter : public Iterator { } bool Valid() const override { return current_ < restarts_; } - Status status() const override { return status_; } + absl::Status status() const override { return status_; } StringPiece key() const override { assert(Valid()); return key_; diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file.h b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file.h index 5627a7228fb782..69300956d9fe20 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file.h @@ -36,7 +36,7 @@ class BufferedWritableFile : public WritableFile { } ~BufferedWritableFile() override { Close().IgnoreError(); } - Status Append(StringPiece str_data) override { + absl::Status Append(StringPiece str_data) override { int64_t bytes_left = str_data.size(); const char* data = str_data.data(); @@ -58,22 +58,22 @@ class BufferedWritableFile : public WritableFile { bytes_left -= append_bytes; } - return OkStatus(); + return absl::OkStatus(); } - Status Append(const absl::Cord& data) override { + absl::Status Append(const absl::Cord& data) override { for (absl::string_view fragment : data.Chunks()) { TF_RETURN_IF_ERROR(Append(fragment)); } - return OkStatus(); + return absl::OkStatus(); } - Status Close() override { + absl::Status Close() override { TF_RETURN_IF_ERROR(Flush()); return file_->Close(); } - Status Flush() override { + absl::Status Flush() override { if (buffer_pos_ > 0) { TF_RETURN_IF_ERROR(file_->Append(StringPiece(&buffer_[0], buffer_pos_))); buffer_pos_ = 0; @@ -81,18 +81,18 @@ class BufferedWritableFile : public WritableFile { return file_->Flush(); } - tsl::Status Tell(int64_t* position) override { + absl::Status Tell(int64_t* position) override { int64_t bytes_written; - tsl::Status status = file_->Tell(&bytes_written); + absl::Status status = file_->Tell(&bytes_written); if (status.ok()) { *position = bytes_written + buffer_pos_; - return OkStatus(); + return absl::OkStatus(); } else { return status; } } - Status Sync() override { return file_->Sync(); } + absl::Status Sync() override { return file_->Sync(); } // For compatibilty with the TensorBundle writer, we expose CRC32 checksums. uint32_t crc32() const { return crc32_; } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.cc b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.cc index b3cfdbb20818ec..89ed20757cf093 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.cc @@ -41,13 +41,13 @@ BufferedInputStream::~BufferedInputStream() { } } -Status BufferedInputStream::FillBuffer() { +absl::Status BufferedInputStream::FillBuffer() { if (!file_status_.ok()) { pos_ = 0; limit_ = 0; return file_status_; } - Status s = input_stream_->ReadNBytes(size_, &buf_); + absl::Status s = input_stream_->ReadNBytes(size_, &buf_); pos_ = 0; limit_ = buf_.size(); if (!s.ok()) { @@ -57,10 +57,10 @@ Status BufferedInputStream::FillBuffer() { } template -Status BufferedInputStream::ReadLineHelper(StringType* result, - bool include_eol) { +absl::Status BufferedInputStream::ReadLineHelper(StringType* result, + bool include_eol) { result->clear(); - Status s; + absl::Status s; size_t start_pos = pos_; while (true) { if (pos_ == limit_) { @@ -79,7 +79,7 @@ Status BufferedInputStream::ReadLineHelper(StringType* result, result->append(1, c); } pos_++; - return OkStatus(); + return absl::OkStatus(); } // We don't append '\r' to *result if (c == '\r') { @@ -89,12 +89,13 @@ Status BufferedInputStream::ReadLineHelper(StringType* result, pos_++; } if (absl::IsOutOfRange(s) && !result->empty()) { - return OkStatus(); + return absl::OkStatus(); } return s; } -Status BufferedInputStream::ReadNBytes(int64_t bytes_to_read, tstring* result) { +absl::Status BufferedInputStream::ReadNBytes(int64_t bytes_to_read, + tstring* result) { if (bytes_to_read < 0) { return errors::InvalidArgument("Can't read a negative number of bytes: ", bytes_to_read); @@ -105,7 +106,7 @@ Status BufferedInputStream::ReadNBytes(int64_t bytes_to_read, tstring* result) { } result->reserve(bytes_to_read); - Status s; + absl::Status s; while (result->size() < static_cast(bytes_to_read)) { // Check whether the buffer is fully read or not. if (pos_ == limit_) { @@ -127,12 +128,12 @@ Status BufferedInputStream::ReadNBytes(int64_t bytes_to_read, tstring* result) { // obtained enough data to satisfy the function call. Returning OK then. if (absl::IsOutOfRange(s) && (result->size() == static_cast(bytes_to_read))) { - return OkStatus(); + return absl::OkStatus(); } return s; } -Status BufferedInputStream::SkipNBytes(int64_t bytes_to_skip) { +absl::Status BufferedInputStream::SkipNBytes(int64_t bytes_to_skip) { if (bytes_to_skip < 0) { return errors::InvalidArgument("Can only skip forward, not ", bytes_to_skip); @@ -144,7 +145,7 @@ Status BufferedInputStream::SkipNBytes(int64_t bytes_to_skip) { // Otherwise, we already have read limit_ - pos_, so skip the rest. At this // point we need to get fresh data into the buffer, so reset pos_ and // limit_. - Status s = input_stream_->SkipNBytes(bytes_to_skip - (limit_ - pos_)); + absl::Status s = input_stream_->SkipNBytes(bytes_to_skip - (limit_ - pos_)); pos_ = 0; limit_ = 0; if (absl::IsOutOfRange(s)) { @@ -152,14 +153,14 @@ Status BufferedInputStream::SkipNBytes(int64_t bytes_to_skip) { } return s; } - return OkStatus(); + return absl::OkStatus(); } int64_t BufferedInputStream::Tell() const { return input_stream_->Tell() - (limit_ - pos_); } -Status BufferedInputStream::Seek(int64_t position) { +absl::Status BufferedInputStream::Seek(int64_t position) { if (position < 0) { return errors::InvalidArgument("Seeking to a negative position: ", position); @@ -176,7 +177,7 @@ Status BufferedInputStream::Seek(int64_t position) { if (position < Tell()) { // Seek within buffer before 'pos_' pos_ -= Tell() - position; - return OkStatus(); + return absl::OkStatus(); } // Seek after 'pos_' @@ -184,9 +185,9 @@ Status BufferedInputStream::Seek(int64_t position) { } template -Status BufferedInputStream::ReadAll(T* result) { +absl::Status BufferedInputStream::ReadAll(T* result) { result->clear(); - Status status; + absl::Status status; while (status.ok()) { status = FillBuffer(); if (limit_ == 0) { @@ -198,7 +199,7 @@ Status BufferedInputStream::ReadAll(T* result) { if (absl::IsOutOfRange(status)) { file_status_ = status; - return OkStatus(); + return absl::OkStatus(); } return status; } @@ -206,19 +207,19 @@ Status BufferedInputStream::ReadAll(T* result) { template Status BufferedInputStream::ReadAll(std::string* result); template Status BufferedInputStream::ReadAll(tstring* result); -Status BufferedInputStream::Reset() { +absl::Status BufferedInputStream::Reset() { TF_RETURN_IF_ERROR(input_stream_->Reset()); pos_ = 0; limit_ = 0; - file_status_ = OkStatus(); - return OkStatus(); + file_status_ = absl::OkStatus(); + return absl::OkStatus(); } -Status BufferedInputStream::ReadLine(std::string* result) { +absl::Status BufferedInputStream::ReadLine(std::string* result) { return ReadLineHelper(result, false); } -Status BufferedInputStream::ReadLine(tstring* result) { +absl::Status BufferedInputStream::ReadLine(tstring* result) { return ReadLineHelper(result, false); } @@ -228,8 +229,8 @@ std::string BufferedInputStream::ReadLineAsString() { return result; } -Status BufferedInputStream::SkipLine() { - Status s; +absl::Status BufferedInputStream::SkipLine() { + absl::Status s; bool skipped = false; while (true) { if (pos_ == limit_) { @@ -242,11 +243,11 @@ Status BufferedInputStream::SkipLine() { char c = buf_[pos_++]; skipped = true; if (c == '\n') { - return OkStatus(); + return absl::OkStatus(); } } if (absl::IsOutOfRange(s) && skipped) { - return OkStatus(); + return absl::OkStatus(); } return s; } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.h b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.h index 5318434c63919c..6681f1bbfbed32 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.h @@ -43,9 +43,9 @@ class BufferedInputStream : public InputStreamInterface { ~BufferedInputStream() override; - Status ReadNBytes(int64_t bytes_to_read, tstring* result) override; + absl::Status ReadNBytes(int64_t bytes_to_read, tstring* result) override; - Status SkipNBytes(int64_t bytes_to_skip) override; + absl::Status SkipNBytes(int64_t bytes_to_skip) override; int64_t Tell() const override; @@ -58,7 +58,7 @@ class BufferedInputStream : public InputStreamInterface { // Note: When seeking backwards in a stream, this implementation uses // Reset() + SkipNBytes(), so its performance will be dependent // largely on the performance of SkipNBytes(). - Status Seek(int64_t position); + absl::Status Seek(int64_t position); // Read one text line of data into "*result" until end-of-file or a // \n is read. (The \n is not included in the result.) Overwrites @@ -67,8 +67,8 @@ class BufferedInputStream : public InputStreamInterface { // If successful, returns OK. If we are already at the end of the // file, we return an OUT_OF_RANGE error. Otherwise, we return // some other non-OK status. - Status ReadLine(std::string* result); - Status ReadLine(tstring* result); + absl::Status ReadLine(std::string* result); + absl::Status ReadLine(tstring* result); // Returns one text line of data until end-of-file or a '\n' is read. The '\n' // is included in the result. @@ -83,21 +83,21 @@ class BufferedInputStream : public InputStreamInterface { // If successful, returns OK. If we are already at the end of the // file, we return an OUT_OF_RANGE error. Otherwise, we return // some other non-OK status. - Status SkipLine(); + absl::Status SkipLine(); // Reads the entire contents of the file into *result. // // Note: the amount of memory used by this function call is unbounded, so only // use in ops that expect that behavior. template - Status ReadAll(T* result); + absl::Status ReadAll(T* result); - Status Reset() override; + absl::Status Reset() override; private: - Status FillBuffer(); + absl::Status FillBuffer(); template - Status ReadLineHelper(StringType* result, bool include_eol); + absl::Status ReadLineHelper(StringType* result, bool include_eol); InputStreamInterface* input_stream_; // not owned. size_t size_; // buffer size. @@ -108,7 +108,7 @@ class BufferedInputStream : public InputStreamInterface { bool owns_input_stream_ = false; // When EoF is reached, file_status_ contains the status to skip unnecessary // buffer allocations. - Status file_status_ = OkStatus(); + absl::Status file_status_ = absl::OkStatus(); BufferedInputStream(const BufferedInputStream&) = delete; void operator=(const BufferedInputStream&) = delete; diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc index 56dd88510377bd..ab1f58e0b14a83 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc @@ -36,7 +36,7 @@ class ReadOnceInputStream : public InputStreamInterface { public: ReadOnceInputStream() : start_(true) {} - virtual Status ReadNBytes(int64_t bytes_to_read, tstring* result) { + virtual absl::Status ReadNBytes(int64_t bytes_to_read, tstring* result) { if (bytes_to_read < 11) { return errors::InvalidArgument("Not reading all bytes: ", bytes_to_read); } @@ -52,9 +52,9 @@ class ReadOnceInputStream : public InputStreamInterface { int64_t Tell() const override { return start_ ? 0 : 10; } // Resets the stream to the beginning. - Status Reset() override { + absl::Status Reset() override { start_ = true; - return OkStatus(); + return absl::OkStatus(); } private: @@ -311,7 +311,7 @@ TEST(BufferedInputStream, OutOfRangeCache) { TF_ASSERT_OK((in.ReadNBytes(7, &read))); EXPECT_EQ(read, "3456789"); EXPECT_EQ(10, in.Tell()); - Status s = in.ReadNBytes(5, &read); + absl::Status s = in.ReadNBytes(5, &read); // Make sure the read is failing with OUT_OF_RANGE error. If it is failing // with other errors, it is not caching the OUT_OF_RANGE properly. EXPECT_EQ(error::OUT_OF_RANGE, s.code()) << s; diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/format.cc b/third_party/xla/third_party/tsl/tsl/lib/io/format.cc index bc12656f7fbec7..d0b20da64a385e 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/format.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/format.cc @@ -36,9 +36,9 @@ void BlockHandle::EncodeTo(string* dst) const { core::PutVarint64(dst, size_); } -Status BlockHandle::DecodeFrom(StringPiece* input) { +absl::Status BlockHandle::DecodeFrom(StringPiece* input) { if (core::GetVarint64(input, &offset_) && core::GetVarint64(input, &size_)) { - return OkStatus(); + return absl::OkStatus(); } else { return errors::DataLoss("bad block handle"); } @@ -56,7 +56,7 @@ void Footer::EncodeTo(string* dst) const { assert(dst->size() == original_size + kEncodedLength); } -Status Footer::DecodeFrom(StringPiece* input) { +absl::Status Footer::DecodeFrom(StringPiece* input) { const char* magic_ptr = input->data() + kEncodedLength - 8; const uint32 magic_lo = core::DecodeFixed32(magic_ptr); const uint32 magic_hi = core::DecodeFixed32(magic_ptr + 4); @@ -66,7 +66,7 @@ Status Footer::DecodeFrom(StringPiece* input) { return errors::DataLoss("not an sstable (bad magic number)"); } - Status result = metaindex_handle_.DecodeFrom(input); + absl::Status result = metaindex_handle_.DecodeFrom(input); if (result.ok()) { result = index_handle_.DecodeFrom(input); } @@ -78,8 +78,8 @@ Status Footer::DecodeFrom(StringPiece* input) { return result; } -Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, - BlockContents* result) { +absl::Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, + BlockContents* result) { result->data = StringPiece(); result->cacheable = false; result->heap_allocated = false; @@ -94,7 +94,8 @@ Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, char* buf = new char[n + kBlockTrailerSize]; StringPiece contents; - Status s = file->Read(handle.offset(), n + kBlockTrailerSize, &contents, buf); + absl::Status s = + file->Read(handle.offset(), n + kBlockTrailerSize, &contents, buf); if (!s.ok()) { delete[] buf; return s; @@ -159,7 +160,7 @@ Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, return errors::DataLoss("bad block type"); } - return OkStatus(); + return absl::OkStatus(); } } // namespace table diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/format.h b/third_party/xla/third_party/tsl/tsl/lib/io/format.h index cd8e863435f440..ae5bb26b8b8c86 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/format.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/format.h @@ -46,7 +46,7 @@ class BlockHandle { void set_size(uint64 size) { size_ = size; } void EncodeTo(string* dst) const; - Status DecodeFrom(StringPiece* input); + absl::Status DecodeFrom(StringPiece* input); // Maximum encoding length of a BlockHandle enum { kMaxEncodedLength = 10 + 10 }; @@ -71,7 +71,7 @@ class Footer { void set_index_handle(const BlockHandle& h) { index_handle_ = h; } void EncodeTo(string* dst) const; - Status DecodeFrom(StringPiece* input); + absl::Status DecodeFrom(StringPiece* input); // Encoded length of a Footer. Note that the serialization of a // Footer will always occupy exactly this many bytes. It consists @@ -99,8 +99,8 @@ struct BlockContents { // Read the block identified by "handle" from "file". On failure // return non-OK. On success fill *result and return OK. -extern Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, - BlockContents* result); +extern absl::Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, + BlockContents* result); // Implementation details follow. Clients should ignore, diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.cc b/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.cc index 1b9e2dc6fe2b21..3c183ee1ae1b3c 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.cc @@ -33,9 +33,9 @@ InputBuffer::InputBuffer(RandomAccessFile* file, size_t buffer_bytes) InputBuffer::~InputBuffer() { delete[] buf_; } -Status InputBuffer::FillBuffer() { +absl::Status InputBuffer::FillBuffer() { StringPiece data; - Status s = file_->Read(file_pos_, size_, &data, buf_); + absl::Status s = file_->Read(file_pos_, size_, &data, buf_); if (data.data() != buf_) { memmove(buf_, data.data(), data.size()); } @@ -46,9 +46,9 @@ Status InputBuffer::FillBuffer() { } template -Status InputBuffer::ReadLine(T* result) { +absl::Status InputBuffer::ReadLine(T* result) { result->clear(); - Status s; + absl::Status s; do { size_t buf_remain = limit_ - pos_; char* newline = static_cast(memchr(pos_, '\n', buf_remain)); @@ -59,7 +59,7 @@ Status InputBuffer::ReadLine(T* result) { if (!result->empty() && result->back() == '\r') { result->resize(result->size() - 1); } - return OkStatus(); + return absl::OkStatus(); } if (buf_remain > 0) result->append(pos_, buf_remain); // Get more data into buffer @@ -70,7 +70,7 @@ Status InputBuffer::ReadLine(T* result) { result->resize(result->size() - 1); } if (errors::IsOutOfRange(s) && !result->empty()) { - return OkStatus(); + return absl::OkStatus(); } return s; } @@ -78,7 +78,8 @@ Status InputBuffer::ReadLine(T* result) { template Status InputBuffer::ReadLine(std::string* result); template Status InputBuffer::ReadLine(tstring* result); -Status InputBuffer::ReadNBytes(int64_t bytes_to_read, std::string* result) { +absl::Status InputBuffer::ReadNBytes(int64_t bytes_to_read, + std::string* result) { result->clear(); if (bytes_to_read < 0) { return errors::InvalidArgument("Can't read a negative number of bytes: ", @@ -86,18 +87,18 @@ Status InputBuffer::ReadNBytes(int64_t bytes_to_read, std::string* result) { } result->resize(bytes_to_read); size_t bytes_read = 0; - Status status = ReadNBytes(bytes_to_read, &(*result)[0], &bytes_read); + absl::Status status = ReadNBytes(bytes_to_read, &(*result)[0], &bytes_read); if (bytes_read < bytes_to_read) result->resize(bytes_read); return status; } -Status InputBuffer::ReadNBytes(int64_t bytes_to_read, char* result, - size_t* bytes_read) { +absl::Status InputBuffer::ReadNBytes(int64_t bytes_to_read, char* result, + size_t* bytes_read) { if (bytes_to_read < 0) { return errors::InvalidArgument("Can't read a negative number of bytes: ", bytes_to_read); } - Status status; + absl::Status status; *bytes_read = 0; while (*bytes_read < static_cast(bytes_to_read)) { if (pos_ == limit_) { @@ -117,21 +118,21 @@ Status InputBuffer::ReadNBytes(int64_t bytes_to_read, char* result, } if (errors::IsOutOfRange(status) && (*bytes_read == static_cast(bytes_to_read))) { - return OkStatus(); + return absl::OkStatus(); } return status; } -Status InputBuffer::ReadVarint32Fallback(uint32* result) { - Status s = ReadVarintFallback(result, core::kMaxVarint32Bytes); +absl::Status InputBuffer::ReadVarint32Fallback(uint32* result) { + absl::Status s = ReadVarintFallback(result, core::kMaxVarint32Bytes); if (errors::IsDataLoss(s)) { return errors::DataLoss("Stored data is too large to be a varint32."); } return s; } -Status InputBuffer::ReadVarint64Fallback(uint64* result) { - Status s = ReadVarintFallback(result, core::kMaxVarint64Bytes); +absl::Status InputBuffer::ReadVarint64Fallback(uint64* result) { + absl::Status s = ReadVarintFallback(result, core::kMaxVarint64Bytes); if (errors::IsDataLoss(s)) { return errors::DataLoss("Stored data is too large to be a varint64."); } @@ -139,7 +140,7 @@ Status InputBuffer::ReadVarint64Fallback(uint64* result) { } template -Status InputBuffer::ReadVarintFallback(T* result, int max_bytes) { +absl::Status InputBuffer::ReadVarintFallback(T* result, int max_bytes) { uint8 scratch = 0; auto* p = reinterpret_cast(&scratch); size_t unused_bytes_read = 0; @@ -149,18 +150,18 @@ Status InputBuffer::ReadVarintFallback(T* result, int max_bytes) { int shift = 7 * index; TF_RETURN_IF_ERROR(ReadNBytes(1, p, &unused_bytes_read)); *result |= (static_cast(scratch) & 127) << shift; - if (!(scratch & 128)) return OkStatus(); + if (!(scratch & 128)) return absl::OkStatus(); } return errors::DataLoss("Stored data longer than ", max_bytes, " bytes."); } -Status InputBuffer::SkipNBytes(int64_t bytes_to_skip) { +absl::Status InputBuffer::SkipNBytes(int64_t bytes_to_skip) { if (bytes_to_skip < 0) { return errors::InvalidArgument("Can only skip forward, not ", bytes_to_skip); } int64_t bytes_skipped = 0; - Status s; + absl::Status s; while (bytes_skipped < bytes_to_skip) { if (pos_ == limit_) { // Get more data into buffer @@ -175,12 +176,12 @@ Status InputBuffer::SkipNBytes(int64_t bytes_to_skip) { pos_ += bytes_to_advance; } if (errors::IsOutOfRange(s) && bytes_skipped == bytes_to_skip) { - return OkStatus(); + return absl::OkStatus(); } return s; } -Status InputBuffer::Seek(int64_t position) { +absl::Status InputBuffer::Seek(int64_t position) { if (position < 0) { return errors::InvalidArgument("Seeking to a negative position: ", position); @@ -196,10 +197,10 @@ Status InputBuffer::Seek(int64_t position) { pos_ = limit_ = buf_; file_pos_ = position; } - return OkStatus(); + return absl::OkStatus(); } -Status InputBuffer::Hint(int64_t bytes_to_read) { +absl::Status InputBuffer::Hint(int64_t bytes_to_read) { if (bytes_to_read < 0) { return errors::InvalidArgument("Can't read a negative number of bytes: ", bytes_to_read); @@ -207,14 +208,14 @@ Status InputBuffer::Hint(int64_t bytes_to_read) { // The internal buffer is too small. Do nothing. if (bytes_to_read > size_) { - return OkStatus(); + return absl::OkStatus(); } const int64_t bytes_remain_in_buf = static_cast(limit_ - pos_); // There are enough data in the buffer. Do nothing. if (bytes_to_read <= bytes_remain_in_buf) { - return OkStatus(); + return absl::OkStatus(); } // Additional read from file is necessary. Make some room. @@ -225,7 +226,7 @@ Status InputBuffer::Hint(int64_t bytes_to_read) { // Read the remaining bytes from file. StringPiece data; - Status s = file_->Read(file_pos_, bytes_to_read, &data, limit_); + absl::Status s = file_->Read(file_pos_, bytes_to_read, &data, limit_); if (data.data() != limit_) { memmove(limit_, data.data(), data.size()); } @@ -233,7 +234,7 @@ Status InputBuffer::Hint(int64_t bytes_to_read) { file_pos_ += data.size(); if (errors::IsOutOfRange(s) && data.size() == bytes_to_read) { - return OkStatus(); + return absl::OkStatus(); } else { return s; } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.h b/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.h index e357efb5f75b53..57a4a983c11e75 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.h @@ -45,38 +45,39 @@ class InputBuffer { // file, we return an OUT_OF_RANGE error. Otherwise, we return // some other non-OK status. template - Status ReadLine(T* result); + absl::Status ReadLine(T* result); // Reads bytes_to_read bytes into *result, overwriting *result. // // If successful, returns OK. If we there are not enough bytes to // read before the end of the file, we return an OUT_OF_RANGE error. // Otherwise, we return some other non-OK status. - Status ReadNBytes(int64_t bytes_to_read, std::string* result); + absl::Status ReadNBytes(int64_t bytes_to_read, std::string* result); // An overload that writes to char*. Caller must ensure result[0, // bytes_to_read) is valid to be overwritten. Returns OK iff "*bytes_read == // bytes_to_read". - Status ReadNBytes(int64_t bytes_to_read, char* result, size_t* bytes_read); + absl::Status ReadNBytes(int64_t bytes_to_read, char* result, + size_t* bytes_read); // Reads a single varint32. - Status ReadVarint32(uint32* result); + absl::Status ReadVarint32(uint32* result); // Reads a single varint64. - Status ReadVarint64(uint64* result); + absl::Status ReadVarint64(uint64* result); // Like ReadNBytes() without returning the bytes read. - Status SkipNBytes(int64_t bytes_to_skip); + absl::Status SkipNBytes(int64_t bytes_to_skip); // Seek to this offset within the file. // // If we seek to somewhere within our pre-buffered data, we will re-use what // data we can. Otherwise, Seek() throws out the current buffer and the next // read will trigger a File::Read(). - Status Seek(int64_t position); + absl::Status Seek(int64_t position); // Provides a hint about future reads, which may improve their performance. - Status Hint(int64_t bytes_to_read); + absl::Status Hint(int64_t bytes_to_read); // Returns the position in the file. int64_t Tell() const { return file_pos_ - (limit_ - pos_); } @@ -85,19 +86,19 @@ class InputBuffer { RandomAccessFile* file() const { return file_; } private: - Status FillBuffer(); + absl::Status FillBuffer(); // Internal slow-path routine used by ReadVarint32(). - Status ReadVarint32Fallback(uint32* result); + absl::Status ReadVarint32Fallback(uint32* result); // Internal slow-path routine used by ReadVarint64(). - Status ReadVarint64Fallback(uint64* result); + absl::Status ReadVarint64Fallback(uint64* result); // Helper method for reading a varint which can span at max `max_bytes`. // If the varint is longer, a DataLoss error status is returned. // If end of file is reached while reading, OutOfRange error is returned. template - Status ReadVarintFallback(T* result, int max_bytes); + absl::Status ReadVarintFallback(T* result, int max_bytes); RandomAccessFile* file_; // Not owned int64_t file_pos_; // Next position to read from in "file_" @@ -118,28 +119,28 @@ extern template Status InputBuffer::ReadLine(std::string* result); extern template Status InputBuffer::ReadLine(tstring* result); // Inlined for performance. -inline Status InputBuffer::ReadVarint32(uint32* result) { +inline absl::Status InputBuffer::ReadVarint32(uint32* result) { if (pos_ + core::kMaxVarint32Bytes <= limit_) { // Fast path: directly parse from buffered data. // Reads strictly from the range [pos_, limit_). const char* offset = core::GetVarint32Ptr(pos_, limit_, result); if (offset == nullptr) return errors::OutOfRange("Parsed past limit."); pos_ = const_cast(offset); - return OkStatus(); + return absl::OkStatus(); } else { return ReadVarint32Fallback(result); } } // Inlined for performance. -inline Status InputBuffer::ReadVarint64(uint64* result) { +inline absl::Status InputBuffer::ReadVarint64(uint64* result) { if (pos_ + core::kMaxVarint64Bytes <= limit_) { // Fast path: directly parse from buffered data. // Reads strictly from the range [pos_, limit_). const char* offset = core::GetVarint64Ptr(pos_, limit_, result); if (offset == nullptr) return errors::OutOfRange("Parsed past limit."); pos_ = const_cast(offset); - return OkStatus(); + return absl::OkStatus(); } else { return ReadVarint64Fallback(result); } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.cc b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.cc index 1a2f11d4d2b2b2..6425ff0656b658 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.cc @@ -24,7 +24,7 @@ namespace io { // 8MB at a time. static constexpr int64_t kMaxSkipSize = 8 * 1024 * 1024; -Status InputStreamInterface::SkipNBytes(int64_t bytes_to_skip) { +absl::Status InputStreamInterface::SkipNBytes(int64_t bytes_to_skip) { if (bytes_to_skip < 0) { return errors::InvalidArgument("Can't skip a negative number of bytes"); } @@ -35,7 +35,7 @@ Status InputStreamInterface::SkipNBytes(int64_t bytes_to_skip) { TF_RETURN_IF_ERROR(ReadNBytes(bytes_to_read, &unused)); bytes_to_skip -= bytes_to_read; } - return OkStatus(); + return absl::OkStatus(); } } // namespace io diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.h b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.h index afe87a4b9cc37e..8eb7f2ad868965 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.h @@ -35,13 +35,13 @@ class InputStreamInterface { // Reads the next bytes_to_read from the file. Typical return codes: // * OK - in case of success. // * OUT_OF_RANGE - not enough bytes remaining before end of file. - virtual Status ReadNBytes(int64_t bytes_to_read, tstring* result) = 0; + virtual absl::Status ReadNBytes(int64_t bytes_to_read, tstring* result) = 0; #if defined(TF_CORD_SUPPORT) // Reads the next bytes_to_read from the file. Typical return codes: // * OK - in case of success. // * OUT_OF_RANGE - not enough bytes remaining before end of file. - virtual Status ReadNBytes(int64_t bytes_to_read, absl::Cord* cord) { + virtual absl::Status ReadNBytes(int64_t bytes_to_read, absl::Cord* cord) { return errors::Unimplemented( "ReadNBytes(int64, absl::Cord*) is not implemented."); } @@ -51,7 +51,7 @@ class InputStreamInterface { // Typical return codes: // * OK - in case of success. // * OUT_OF_RANGE - not enough bytes remaining before end of file. - virtual Status SkipNBytes(int64_t bytes_to_skip); + virtual absl::Status SkipNBytes(int64_t bytes_to_skip); // Return the offset of the current byte relative to the beginning of the // file. @@ -61,7 +61,7 @@ class InputStreamInterface { virtual int64_t Tell() const = 0; // Resets the stream to the beginning. - virtual Status Reset() = 0; + virtual absl::Status Reset() = 0; }; } // namespace io diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc index 2f7cda954fd13d..23d4fb0ddf50bc 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc @@ -27,21 +27,21 @@ class TestStringStream : public InputStreamInterface { public: explicit TestStringStream(const string& content) : content_(content) {} - Status ReadNBytes(int64_t bytes_to_read, tstring* result) override { + absl::Status ReadNBytes(int64_t bytes_to_read, tstring* result) override { result->clear(); if (pos_ + bytes_to_read > content_.size()) { return errors::OutOfRange("limit reached"); } *result = content_.substr(pos_, bytes_to_read); pos_ += bytes_to_read; - return OkStatus(); + return absl::OkStatus(); } int64_t Tell() const override { return pos_; } - Status Reset() override { + absl::Status Reset() override { pos_ = 0; - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.cc b/third_party/xla/third_party/tsl/tsl/lib/io/iterator.cc index 4dff9eb4f61761..a02a4254985087 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/iterator.cc @@ -53,7 +53,7 @@ void Iterator::RegisterCleanup(CleanupFunction func, void* arg1, void* arg2) { namespace { class EmptyIterator : public Iterator { public: - explicit EmptyIterator(const Status& s) : status_(s) {} + explicit EmptyIterator(const absl::Status& s) : status_(s) {} bool Valid() const override { return false; } void Seek(const StringPiece& target) override {} void SeekToFirst() override {} @@ -66,16 +66,16 @@ class EmptyIterator : public Iterator { assert(false); return StringPiece(); } - Status status() const override { return status_; } + absl::Status status() const override { return status_; } private: - Status status_; + absl::Status status_; }; } // namespace -Iterator* NewEmptyIterator() { return new EmptyIterator(OkStatus()); } +Iterator* NewEmptyIterator() { return new EmptyIterator(absl::OkStatus()); } -Iterator* NewErrorIterator(const Status& status) { +Iterator* NewErrorIterator(const absl::Status& status) { return new EmptyIterator(status); } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.h b/third_party/xla/third_party/tsl/tsl/lib/io/iterator.h index bb83f41ea47dd9..f0b16943c44b9c 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/iterator.h @@ -68,7 +68,7 @@ class Iterator { virtual StringPiece value() const = 0; // If an error has occurred, return it. Else return an ok status. - virtual Status status() const = 0; + virtual absl::Status status() const = 0; // Clients are allowed to register function/arg1/arg2 triples that // will be invoked when this iterator is destroyed. @@ -96,7 +96,7 @@ class Iterator { extern Iterator* NewEmptyIterator(); // Return an empty iterator with the specified status. -extern Iterator* NewErrorIterator(const Status& status); +extern Iterator* NewErrorIterator(const absl::Status& status); } // namespace table } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.cc b/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.cc index 1b5262057771b7..841e3d1bf26f6c 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.cc @@ -30,8 +30,8 @@ RandomAccessInputStream::~RandomAccessInputStream() { } } -Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, - tstring* result) { +absl::Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, + tstring* result) { if (bytes_to_read < 0) { return errors::InvalidArgument("Cannot read negative number of bytes"); } @@ -39,7 +39,7 @@ Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, result->resize_uninitialized(bytes_to_read); char* result_buffer = &(*result)[0]; StringPiece data; - Status s = file_->Read(pos_, bytes_to_read, &data, result_buffer); + absl::Status s = file_->Read(pos_, bytes_to_read, &data, result_buffer); if (data.data() != result_buffer) { memmove(result_buffer, data.data(), data.size()); } @@ -51,13 +51,13 @@ Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, } #if defined(TF_CORD_SUPPORT) -Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, - absl::Cord* result) { +absl::Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, + absl::Cord* result) { if (bytes_to_read < 0) { return errors::InvalidArgument("Cannot read negative number of bytes"); } int64_t current_size = result->size(); - Status s = file_->Read(pos_, bytes_to_read, result); + absl::Status s = file_->Read(pos_, bytes_to_read, result); if (s.ok() || errors::IsOutOfRange(s)) { pos_ += result->size() - current_size; } @@ -69,7 +69,7 @@ Status RandomAccessInputStream::ReadNBytes(int64_t bytes_to_read, // 8MB at a time. static constexpr int64_t kMaxSkipSize = 8 * 1024 * 1024; -Status RandomAccessInputStream::SkipNBytes(int64_t bytes_to_skip) { +absl::Status RandomAccessInputStream::SkipNBytes(int64_t bytes_to_skip) { if (bytes_to_skip < 0) { return errors::InvalidArgument("Can't skip a negative number of bytes"); } @@ -78,17 +78,18 @@ Status RandomAccessInputStream::SkipNBytes(int64_t bytes_to_skip) { // not reached yet and we could return. if (bytes_to_skip > 0) { StringPiece data; - Status s = file_->Read(pos_ + bytes_to_skip - 1, 1, &data, scratch.get()); + absl::Status s = + file_->Read(pos_ + bytes_to_skip - 1, 1, &data, scratch.get()); if ((s.ok() || errors::IsOutOfRange(s)) && data.size() == 1) { pos_ += bytes_to_skip; - return OkStatus(); + return absl::OkStatus(); } } // Read kDefaultSkipSize at a time till bytes_to_skip. while (bytes_to_skip > 0) { int64_t bytes_to_read = std::min(kMaxSkipSize, bytes_to_skip); StringPiece data; - Status s = file_->Read(pos_, bytes_to_read, &data, scratch.get()); + absl::Status s = file_->Read(pos_, bytes_to_read, &data, scratch.get()); if (s.ok() || errors::IsOutOfRange(s)) { pos_ += data.size(); } else { @@ -99,7 +100,7 @@ Status RandomAccessInputStream::SkipNBytes(int64_t bytes_to_skip) { } bytes_to_skip -= bytes_to_read; } - return OkStatus(); + return absl::OkStatus(); } int64_t RandomAccessInputStream::Tell() const { return pos_; } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.h b/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.h index e1608ce3ec2b9b..4d48db62c2b03f 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.h @@ -33,22 +33,22 @@ class RandomAccessInputStream : public InputStreamInterface { ~RandomAccessInputStream() override; - Status ReadNBytes(int64_t bytes_to_read, tstring* result) override; + absl::Status ReadNBytes(int64_t bytes_to_read, tstring* result) override; #if defined(TF_CORD_SUPPORT) - Status ReadNBytes(int64_t bytes_to_read, absl::Cord* result) override; + absl::Status ReadNBytes(int64_t bytes_to_read, absl::Cord* result) override; #endif - Status SkipNBytes(int64_t bytes_to_skip) override; + absl::Status SkipNBytes(int64_t bytes_to_skip) override; int64_t Tell() const override; - Status Seek(int64_t position) { + absl::Status Seek(int64_t position) { pos_ = position; - return OkStatus(); + return absl::OkStatus(); } - Status Reset() override { return Seek(0); } + absl::Status Reset() override { return Seek(0); } private: RandomAccessFile* file_; // Not owned. diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.cc b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.cc index e267b5cee84dab..8d17c610b09f71 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.cc @@ -101,7 +101,8 @@ inline const char* GetChecksumErrorSuffix(uint64 offset) { // and is used only in error messages. For failures at offset 0, // a reminder about the file format is added, because TFRecord files // contain no explicit format marker. -Status RecordReader::ReadChecksummed(uint64 offset, size_t n, tstring* result) { +absl::Status RecordReader::ReadChecksummed(uint64 offset, size_t n, + tstring* result) { if (n >= SIZE_MAX - sizeof(uint32)) { return errors::DataLoss("record size too large", GetChecksumErrorSuffix(offset)); @@ -125,10 +126,10 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n, tstring* result) { GetChecksumErrorSuffix(offset)); } result->resize(n); - return OkStatus(); + return absl::OkStatus(); } -Status RecordReader::GetMetadata(Metadata* md) { +absl::Status RecordReader::GetMetadata(Metadata* md) { if (!md) { return errors::InvalidArgument( "Metadata object call to GetMetadata() was null"); @@ -148,7 +149,7 @@ Status RecordReader::GetMetadata(Metadata* md) { tstring record; while (true) { // Read header, containing size of data. - Status s = ReadChecksummed(offset, sizeof(uint64), &record); + absl::Status s = ReadChecksummed(offset, sizeof(uint64), &record); if (!s.ok()) { if (errors::IsOutOfRange(s)) { // We should reach out of range when the record file is complete. @@ -178,10 +179,10 @@ Status RecordReader::GetMetadata(Metadata* md) { } md->stats = cached_metadata_->stats; - return OkStatus(); + return absl::OkStatus(); } -Status RecordReader::PositionInputStream(uint64 offset) { +absl::Status RecordReader::PositionInputStream(uint64 offset) { int64_t curr_pos = input_stream_->Tell(); int64_t desired_pos = static_cast(offset); if (curr_pos > desired_pos || curr_pos < 0 /* EOF */ || @@ -193,14 +194,14 @@ Status RecordReader::PositionInputStream(uint64 offset) { TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos - curr_pos)); } DCHECK_EQ(desired_pos, input_stream_->Tell()); - return OkStatus(); + return absl::OkStatus(); } -Status RecordReader::ReadRecord(uint64* offset, tstring* record) { +absl::Status RecordReader::ReadRecord(uint64* offset, tstring* record) { TF_RETURN_IF_ERROR(PositionInputStream(*offset)); // Read header data. - Status s = ReadChecksummed(*offset, sizeof(uint64), record); + absl::Status s = ReadChecksummed(*offset, sizeof(uint64), record); if (!s.ok()) { last_read_failed_ = true; return s; @@ -220,14 +221,14 @@ Status RecordReader::ReadRecord(uint64* offset, tstring* record) { *offset += kHeaderSize + length + kFooterSize; DCHECK_EQ(*offset, input_stream_->Tell()); - return OkStatus(); + return absl::OkStatus(); } -Status RecordReader::SkipRecords(uint64* offset, int num_to_skip, - int* num_skipped) { +absl::Status RecordReader::SkipRecords(uint64* offset, int num_to_skip, + int* num_skipped) { TF_RETURN_IF_ERROR(PositionInputStream(*offset)); - Status s; + absl::Status s; tstring record; *num_skipped = 0; for (int i = 0; i < num_to_skip; ++i) { @@ -252,7 +253,7 @@ Status RecordReader::SkipRecords(uint64* offset, int num_to_skip, DCHECK_EQ(*offset, input_stream_->Tell()); (*num_skipped)++; } - return OkStatus(); + return absl::OkStatus(); } SequentialRecordReader::SequentialRecordReader( diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.h b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.h index 282c0daff2a5a8..61540a657324c8 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.h @@ -94,14 +94,14 @@ class RecordReader { // Read the record at "*offset" into *record and update *offset to // point to the offset of the next record. Returns OK on success, // OUT_OF_RANGE for end of file, or something else for an error. - Status ReadRecord(uint64* offset, tstring* record); + absl::Status ReadRecord(uint64* offset, tstring* record); // Skip num_to_skip record starting at "*offset" and update *offset // to point to the offset of the next num_to_skip + 1 record. // Return OK on success, OUT_OF_RANGE for end of file, or something // else for an error. "*num_skipped" records the number of records that // are actually skipped. It should be equal to num_to_skip on success. - Status SkipRecords(uint64* offset, int num_to_skip, int* num_skipped); + absl::Status SkipRecords(uint64* offset, int num_to_skip, int* num_skipped); // Return the metadata of the Record file. // @@ -112,11 +112,11 @@ class RecordReader { // so that GetMetadata() could be a const method. // // 'metadata' must not be nullptr. - Status GetMetadata(Metadata* md); + absl::Status GetMetadata(Metadata* md); private: - Status ReadChecksummed(uint64 offset, size_t n, tstring* result); - Status PositionInputStream(uint64 offset); + absl::Status ReadChecksummed(uint64 offset, size_t n, tstring* result); + absl::Status PositionInputStream(uint64 offset); RecordReaderOptions options_; std::unique_ptr input_stream_; @@ -143,7 +143,7 @@ class SequentialRecordReader { // Read the next record in the file into *record. Returns OK on success, // OUT_OF_RANGE for end of file, or something else for an error. - Status ReadRecord(tstring* record) { + absl::Status ReadRecord(tstring* record) { return underlying_.ReadRecord(&offset_, record); } @@ -151,7 +151,7 @@ class SequentialRecordReader { // OUT_OF_RANGE for end of file, or something else for an error. // "*num_skipped" records the number of records that are actually skipped. // It should be equal to num_to_skip on success. - Status SkipRecords(int num_to_skip, int* num_skipped) { + absl::Status SkipRecords(int num_to_skip, int* num_skipped) { return underlying_.SkipRecords(&offset_, num_to_skip, num_skipped); } @@ -160,13 +160,13 @@ class SequentialRecordReader { // Seek to this offset within the file and set this offset as the current // offset. Trying to seek backward will throw error. - Status SeekOffset(uint64 offset) { + absl::Status SeekOffset(uint64 offset) { if (offset < offset_) return errors::InvalidArgument( "Trying to seek offset: ", offset, " which is less than the current offset: ", offset_); offset_ = offset; - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc index 2497db348a5729..67df783112f9ee 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc @@ -226,7 +226,7 @@ TEST(RecordReaderWriterTest, TestSkipOutOfRange) { uint64 offset = 0; int num_skipped; tstring record; - Status s = reader.SkipRecords(&offset, 3, &num_skipped); + absl::Status s = reader.SkipRecords(&offset, 3, &num_skipped); EXPECT_EQ(2, num_skipped); EXPECT_EQ(error::OUT_OF_RANGE, s.code()); } @@ -254,7 +254,7 @@ TEST(RecordReaderWriterTest, TestMalformedInput) { tstring record; // At offset 0, the error message reminds of the file type. uint64 offset = 0; - Status s = reader.ReadRecord(&offset, &record); + absl::Status s = reader.ReadRecord(&offset, &record); EXPECT_EQ(error::DATA_LOSS, s.code()); EXPECT_EQ("corrupted record at 0 (Is this even a TFRecord file?)", s.message()); diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.cc b/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.cc index aace9f10e14c6d..9a6a932dd77a26 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.cc @@ -69,7 +69,7 @@ RecordWriter::RecordWriter(WritableFile* dest, ZlibOutputBuffer* zlib_output_buffer = new ZlibOutputBuffer( dest, options.zlib_options.input_buffer_size, options.zlib_options.output_buffer_size, options.zlib_options); - Status s = zlib_output_buffer->Init(); + absl::Status s = zlib_output_buffer->Init(); if (!s.ok()) { LOG(FATAL) << "Failed to initialize Zlib inputbuffer. Error: " << s.ToString(); @@ -89,17 +89,17 @@ RecordWriter::RecordWriter(WritableFile* dest, RecordWriter::~RecordWriter() { if (dest_ != nullptr) { - Status s = Close(); + absl::Status s = Close(); if (!s.ok()) { LOG(ERROR) << "Could not finish writing file: " << s; } } } -Status RecordWriter::WriteRecord(StringPiece data) { +absl::Status RecordWriter::WriteRecord(StringPiece data) { if (dest_ == nullptr) { - return Status(absl::StatusCode::kFailedPrecondition, - "Writer not initialized or previously closed"); + return absl::Status(absl::StatusCode::kFailedPrecondition, + "Writer not initialized or previously closed"); } // Format of a single record: // uint64 length @@ -116,10 +116,10 @@ Status RecordWriter::WriteRecord(StringPiece data) { } #if defined(TF_CORD_SUPPORT) -Status RecordWriter::WriteRecord(const absl::Cord& data) { +absl::Status RecordWriter::WriteRecord(const absl::Cord& data) { if (dest_ == nullptr) { - return Status(absl::StatusCode::kFailedPrecondition, - "Writer not initialized or previously closed"); + return absl::Status(absl::StatusCode::kFailedPrecondition, + "Writer not initialized or previously closed"); } // Format of a single record: // uint64 length @@ -136,21 +136,21 @@ Status RecordWriter::WriteRecord(const absl::Cord& data) { } #endif -Status RecordWriter::Close() { - if (dest_ == nullptr) return OkStatus(); +absl::Status RecordWriter::Close() { + if (dest_ == nullptr) return absl::OkStatus(); if (IsZlibCompressed(options_) || IsSnappyCompressed(options_)) { - Status s = dest_->Close(); + absl::Status s = dest_->Close(); delete dest_; dest_ = nullptr; return s; } - return OkStatus(); + return absl::OkStatus(); } -Status RecordWriter::Flush() { +absl::Status RecordWriter::Flush() { if (dest_ == nullptr) { - return Status(absl::StatusCode::kFailedPrecondition, - "Writer not initialized or previously closed"); + return absl::Status(absl::StatusCode::kFailedPrecondition, + "Writer not initialized or previously closed"); } return dest_->Flush(); } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.h b/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.h index b585cb9b52f70c..06e9a5c847910c 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.h @@ -77,22 +77,22 @@ class RecordWriter { // implicit Close() call in the destructor. ~RecordWriter(); - Status WriteRecord(StringPiece data); + absl::Status WriteRecord(StringPiece data); #if defined(TF_CORD_SUPPORT) - Status WriteRecord(const absl::Cord& data); + absl::Status WriteRecord(const absl::Cord& data); #endif // Flushes any buffered data held by underlying containers of the // RecordWriter to the WritableFile. Does *not* flush the // WritableFile. - Status Flush(); + absl::Status Flush(); // Writes all output to the file. Does *not* close the WritableFile. // // After calling Close(), any further calls to `WriteRecord()` or `Flush()` // are invalid. - Status Close(); + absl::Status Close(); // Utility method to populate TFRecord headers. Populates record-header in // "header[0,kHeaderSize-1]". The record-header is based on data[0, n-1]. diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc index f9702f2ed13997..42adf76f7ef0d3 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc @@ -55,22 +55,22 @@ class StringDest : public WritableFile { public: explicit StringDest(string* contents) : contents_(contents) {} - Status Close() override { return OkStatus(); } - Status Flush() override { return OkStatus(); } - Status Sync() override { return OkStatus(); } - Status Append(StringPiece slice) override { + absl::Status Close() override { return absl::OkStatus(); } + absl::Status Flush() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } + absl::Status Append(StringPiece slice) override { contents_->append(slice.data(), slice.size()); - return OkStatus(); + return absl::OkStatus(); } #if defined(TF_CORD_SUPPORT) - Status Append(const absl::Cord& data) override { + absl::Status Append(const absl::Cord& data) override { contents_->append(std::string(data)); - return OkStatus(); + return absl::OkStatus(); } #endif - Status Tell(int64_t* pos) override { + absl::Status Tell(int64_t* pos) override { *pos = contents_->size(); - return OkStatus(); + return absl::OkStatus(); } private: @@ -82,8 +82,8 @@ class StringSource : public RandomAccessFile { explicit StringSource(string* contents) : contents_(contents), force_error_(false) {} - Status Read(uint64 offset, size_t n, StringPiece* result, - char* scratch) const override { + absl::Status Read(uint64 offset, size_t n, StringPiece* result, + char* scratch) const override { if (force_error_) { force_error_ = false; return errors::DataLoss("read error"); @@ -97,7 +97,7 @@ class StringSource : public RandomAccessFile { n = contents_->size() - offset; } *result = StringPiece(contents_->data() + offset, n); - return OkStatus(); + return absl::OkStatus(); } void force_error() { force_error_ = true; } @@ -150,7 +150,7 @@ class RecordioTest : public ::testing::Test { reading_ = true; } tstring record; - Status s = reader_->ReadRecord(&readpos_, &record); + absl::Status s = reader_->ReadRecord(&readpos_, &record); if (s.ok()) { return record; } else if (errors::IsOutOfRange(s)) { @@ -184,7 +184,7 @@ class RecordioTest : public ::testing::Test { reading_ = true; uint64 offset = WrittenBytes() + offset_past_end; tstring record; - Status s = reader_->ReadRecord(&offset, &record); + absl::Status s = reader_->ReadRecord(&offset, &record); ASSERT_TRUE(errors::IsOutOfRange(s)) << s; } }; @@ -317,7 +317,7 @@ void TestReadError(const RecordWriterOptions& writer_options, uint64 offset = 0; tstring read; file.force_error(); - Status status = reader.ReadRecord(&offset, &read); + absl::Status status = reader.ReadRecord(&offset, &read); ASSERT_TRUE(errors::IsDataLoss(status)); ASSERT_EQ(0, offset); diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index c6daa3e8d6745a..ea8feebb6aa687 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1259,16 +1259,16 @@ xla_test( deps = [ ":gpu_device_info_for_tests", ":ir_emitter_triton", - ":matmul_utils", ":triton_fusion_analysis", ":triton_support", ":triton_test_utils", + "//third_party/protobuf", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/lib/core:status_test_util", diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index fd03050d701fe3..d3e02d7e09237d 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -167,7 +167,7 @@ using mlir::ValueRange; namespace { // XLA -> Triton type conversions. -Type TritonType(mlir::OpBuilder b, PrimitiveType t) { +absl::StatusOr TritonType(mlir::OpBuilder b, PrimitiveType t) { switch (t) { case F64: return b.getF64Type(); @@ -195,8 +195,9 @@ Type TritonType(mlir::OpBuilder b, PrimitiveType t) { // Triton. return b.getFloat8E4M3FNUZType(); default: - LOG(FATAL) << "This type is not supported yet: " - << primitive_util::LowercasePrimitiveTypeName(t); + return absl::UnimplementedError( + absl::StrCat("This type is not supported yet: ", + primitive_util::LowercasePrimitiveTypeName(t))); } } @@ -485,8 +486,11 @@ absl::StatusOr EmitElementwise(ImplicitLocOpBuilder& b, case HloOpcode::kNegate: // NegFOp is not supported by Triton. return Subtract(b, {ZerosLike(b, inputs[0]), inputs[0]}); - case HloOpcode::kConvert: - return Cast(b, inputs[0], TritonType(b, hlo.shape().element_type())); + case HloOpcode::kConvert: { + TF_ASSIGN_OR_RETURN(Type dst_ty, + TritonType(b, hlo.shape().element_type())); + return Cast(b, inputs[0], dst_ty); + } case HloOpcode::kAdd: if (is_integer) { return b.create(inputs[0], inputs[1]); @@ -577,8 +581,9 @@ Value EmitParameterLoad(ImplicitLocOpBuilder& b, Value pointer, {}); } -Value EmitConstant(ImplicitLocOpBuilder& b, const HloInstruction& constant) { - Type ty = TritonType(b, constant.shape().element_type()); +absl::StatusOr EmitConstant(ImplicitLocOpBuilder& b, + const HloInstruction& constant) { + TF_ASSIGN_OR_RETURN(Type ty, TritonType(b, constant.shape().element_type())); if (constant.shape().IsInteger()) { if (constant.shape().element_type() == U64) { return CreateConst(b, ty, ScalarConstantValue(constant, U64)); @@ -681,13 +686,14 @@ absl::StatusOr EmitReduce(ImplicitLocOpBuilder& b, if (operand->opcode() == HloOpcode::kConvert) { TF_RET_CHECK(operand->operand(0)->opcode() == HloOpcode::kConstant); TF_RET_CHECK(operand->operand(0)->shape().element_type() == BF16); - PrimitiveType dest_ty = operand->shape().element_type(); - TF_RET_CHECK(dest_ty == F32); - neutral = EmitConstant(b, *operand->operand(0)); - neutral = Cast(b, neutral, TritonType(b, dest_ty)); + TF_RET_CHECK(operand->shape().element_type() == F32); + TF_ASSIGN_OR_RETURN(Type dst_ty, + TritonType(b, operand->shape().element_type())); + TF_ASSIGN_OR_RETURN(neutral, EmitConstant(b, *operand->operand(0))); + neutral = Cast(b, neutral, dst_ty); } else { TF_RET_CHECK(operand->opcode() == HloOpcode::kConstant); - neutral = EmitConstant(b, *operand); + TF_ASSIGN_OR_RETURN(neutral, EmitConstant(b, *operand)); } // Since every shape is padded to a power of 2 in Triton, the input tile may @@ -756,7 +762,9 @@ absl::StatusOr EmitReduce(ImplicitLocOpBuilder& b, result = Splat(b, result, {}); } - return Cast(b, result, TritonType(b, hlo_reduce.shape().element_type())); + TF_ASSIGN_OR_RETURN(Type result_ty, + TritonType(b, hlo_reduce.shape().element_type())); + return Cast(b, result, result_ty); } // Emit code corresponding to a fusion instruction somehow nested within the @@ -873,8 +881,9 @@ absl::StatusOr EmitTiledHloInstruction( if (hlo->opcode() == HloOpcode::kConstant && ShapeUtil::IsEffectiveScalar(hlo->shape())) { + TF_ASSIGN_OR_RETURN(Value constant, EmitConstant(b, *hlo)); // Splat makes it a tensor to avoid type mismatches. - return Splat(b, EmitConstant(b, *hlo), {}); + return Splat(b, constant, {}); } if (hlo->opcode() == HloOpcode::kBroadcast) { @@ -896,16 +905,18 @@ absl::StatusOr EmitTiledHloInstruction( return EmitElementwise(b, libdevice_path, device_info, *hlo, operands); } - if (hlo->opcode() == HloOpcode::kTranspose || - hlo->opcode() == HloOpcode::kSlice || hlo->opcode() == HloOpcode::kPad) { - // All these are currently supported only as operations on indices - // which are pushed to loads and stores. No operations on tiles are - // performed here. + // All these operations are currently supported only as operations on indices + // which are pushed to loads and stores. We don't generate any further code + // for these operations here. + std::vector passthrough_opcodes( + {HloOpcode::kBitcast, HloOpcode::kPad, HloOpcode::kReshape, + HloOpcode::kSlice, HloOpcode::kTranspose}); + if (absl::c_linear_search(passthrough_opcodes, hlo->opcode())) { return values[tiled_hlo.operand(0)]; } return absl::UnimplementedError( - absl::StrCat("Unsupported opcode: ", hlo->opcode())); + absl::StrCat("Unsupported operation ", hlo->ToString())); } // Emit sequence of instructions using compatible tiling ordered producers @@ -954,8 +965,9 @@ absl::StatusOr EmitScope( TF_RET_CHECK(values.contains(hlo)) << hlo->ToString(); continue; } else if (hlo->opcode() == HloOpcode::kConstant) { + TF_ASSIGN_OR_RETURN(Value constant, EmitConstant(b, *hlo)); // Splat makes it a tensor to avoid type mismatches. - result = Splat(b, EmitConstant(b, *hlo), {}); + result = Splat(b, constant, {}); } else if (hlo->opcode() == HloOpcode::kBroadcast) { TF_ASSIGN_OR_RETURN( result, EmitBroadcast(b, analysis, scope, tiled_dimensions, *hlo, @@ -1362,12 +1374,13 @@ class MatMulEmitterHelper { // TODO(b/266862493): Accumulator can be integer too. // Otherwise only f64 x f64 -> f64 uses f64 accumulator. - mlir::FloatType GetDotAccumulatorType() { + absl::StatusOr GetDotAccumulatorType() { const PrecisionConfig::Algorithm algorithm = dot_instr_->precision_config().algorithm(); if (algorithm == PrecisionConfig::ALG_UNSET) { - Type dot_output_ty = TritonType(b_, dot_instr_->shape().element_type()); + TF_ASSIGN_OR_RETURN(Type dot_output_ty, + TritonType(b_, dot_instr_->shape().element_type())); // The code below assumes that lhs and rhs have the same type. However // it's not always the case with fp8 matmuls, e.g. e4m3×e5m2 is supported // at the hardware level. NVidia GPU currently only supports f32 @@ -1377,14 +1390,14 @@ class MatMulEmitterHelper { } // Data type of dot() immediate inputs. - Type dot_input_ty = [&] { - const Type lhs_ty = - TritonType(b_, dot_instr_->operand(0)->shape().element_type()); - const Type rhs_ty = - TritonType(b_, dot_instr_->operand(1)->shape().element_type()); - CHECK(lhs_ty == rhs_ty); - return lhs_ty; - }(); + TF_ASSIGN_OR_RETURN( + const Type lhs_ty, + TritonType(b_, dot_instr_->operand(0)->shape().element_type())); + TF_ASSIGN_OR_RETURN( + const Type rhs_ty, + TritonType(b_, dot_instr_->operand(1)->shape().element_type())); + TF_RET_CHECK(lhs_ty == rhs_ty); + Type dot_input_ty = lhs_ty; // TODO(b/266862493): Accumulator can be integer too. // Otherwise only f64 x f64 -> f64 uses f64 accumulator. return (dot_output_ty.isF64() && dot_input_ty.isF64()) ? b_.getF64Type() @@ -1395,7 +1408,8 @@ class MatMulEmitterHelper { algorithm_util::GetDotAccumulatorType(algorithm); CHECK(accum_type.ok()) << "Unexpected algorithm: " << PrecisionConfig::Algorithm_Name(algorithm); - Type mlir_accum_type = TritonType(b_, accum_type.value()); + TF_ASSIGN_OR_RETURN(Type mlir_accum_type, + TritonType(b_, accum_type.value())); if (auto float_accum_type = mlir::dyn_cast(mlir_accum_type)) { return float_accum_type; @@ -2131,10 +2145,9 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, if (node->opcode() != HloOpcode::kConvert) { return false; } - Type in_type = - TritonType(builder, node->operand(0)->shape().element_type()); - Type out_type = TritonType(builder, node->shape().element_type()); - return in_type.getIntOrFloatBitWidth() <= 8 && out_type.isF32(); + int in_width = + primitive_util::BitWidth(node->operand(0)->shape().element_type()); + return in_width <= 8 && node->shape().element_type() == F32; }); // We'll be creating a lot of instructions from a single dot, use an @@ -2181,7 +2194,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, auto pid_n = b.create(b.create(pid_nc, c32(width)), group_size); - mlir::FloatType acc_ty = emitter.GetDotAccumulatorType(); + TF_ASSIGN_OR_RETURN(mlir::FloatType acc_ty, emitter.GetDotAccumulatorType()); ma::ConstantOp accumulator_init = CreateConst(b, acc_ty, 0, {block_m, block_n}); @@ -2229,8 +2242,17 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, size_t lsize = ScopeInputs(analysis, TritonFusionAnalysis::Scope::LHS).size(); size_t rsize = ScopeInputs(analysis, TritonFusionAnalysis::Scope::RHS).size(); + absl::flat_hash_map triton_type_for_input; + for (const Side& side : {lhs, rhs}) { + for (const HloInstruction* input : ScopeInputs(analysis, side.scope)) { + TF_ASSIGN_OR_RETURN(Type input_ty, + TritonType(b, input->shape().element_type())); + triton_type_for_input.insert({input, input_ty}); + } + } + auto body_builder = [&](mlir::OpBuilder&, mlir::Location, Value ki, - ValueRange iter_args) { + ValueRange iter_args) -> void { SmallVector iter_args_next; iter_args_next.reserve(iter_args.size()); std::array, 3> values; @@ -2243,7 +2265,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, const HloInstruction* param_hlo = iter_args_to_inputs[i]; Type param_ty = index == kLhsMetaOperandIdx ? b.getI16Type() - : TritonType(b, param_hlo->shape().element_type()); + : triton_type_for_input.at(param_hlo); Type param_storage_ty = StorageType(b, param_ty); Value param_value = EmitParameterLoad(b, iter_args[i], iter_args_to_boundary_checks[i]); @@ -2364,6 +2386,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, iter_args_next.push_back(accumulator_next); b.create(iter_args_next); + return; }; // Pointers to inputs of LHS scope, then RHS, then the accumulator @@ -2393,8 +2416,9 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, /*iterArgs=*/iter_args, body_builder) .getResult(iter_args.size() - 1); absl::flat_hash_map values_out; - values_out[dot_instr] = - Cast(b, acc_final, TritonType(b, dot_instr->shape().element_type())); + TF_ASSIGN_OR_RETURN(Type acc_final_ty, + TritonType(b, dot_instr->shape().element_type())); + values_out[dot_instr] = Cast(b, acc_final, acc_final_ty); // Emit the output scope. if (std::vector to_emit = @@ -2774,16 +2798,21 @@ absl::StatusOr> CreateTritonModule( SmallVector fn_arg_types; for (HloInstruction* p : hlo_computation->parameter_instructions()) { PrimitiveType type = p->shape().element_type(); - Type ir_type = type != U16 ? TritonType(b, type) : b.getI16Type(); + Type ir_type; + if (type == U16) { + ir_type = b.getI16Type(); + } else { + TF_ASSIGN_OR_RETURN(ir_type, TritonType(b, type)); + } fn_arg_types.push_back( mt::PointerType::get(StorageType(b, ir_type), mn::kGlobalMemorySpace)); } for (const ShapeUtil::IndexedShape& s : ShapeUtil::GetLeafShapes(fusion->shape())) { - fn_arg_types.push_back(mt::PointerType::get( - StorageType(b, TritonType(b, s.shape.element_type())), - mn::kGlobalMemorySpace)); + TF_ASSIGN_OR_RETURN(Type triton_ty, TritonType(b, s.shape.element_type())); + fn_arg_types.push_back(mt::PointerType::get(StorageType(b, triton_ty), + mn::kGlobalMemorySpace)); } auto fn = b.create(loc, fn_name, diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 4ece9de1054dee..9da56e03fffedf 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -261,6 +261,7 @@ xla_cc_test( name = "gpu_performance_model_test", srcs = ["gpu_performance_model_test.cc"], deps = [ + ":fusion_analysis_cache", ":gpu_hlo_cost_analysis", ":gpu_indexing_performance_model", ":gpu_performance_model", @@ -341,25 +342,35 @@ cc_library( hdrs = ["gpu_indexing_performance_model.h"], deps = [ ":coalescing_analysis", + ":fusion_analysis_cache", ":gpu_hlo_cost_analysis", ":gpu_performance_model_base", ":hlo_op_profiles", ":indexing_analysis", + ":symbolic_tile_analysis", + ":tiled_hlo_computation", "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:hlo_cost_analysis", + "//xla/service:instruction_fusion", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions:triton", "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", ], ) @@ -367,6 +378,7 @@ xla_cc_test( name = "gpu_indexing_performance_model_test", srcs = ["gpu_indexing_performance_model_test.cc"], deps = [ + ":fusion_analysis_cache", ":gpu_hlo_cost_analysis", ":gpu_indexing_performance_model", ":gpu_performance_model_base", diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index 3ab064f4391f41..6b8cc815fc52f0 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -17,15 +17,21 @@ limitations under the License. #include #include +#include +#include #include +#include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/fusions/triton.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/launch_dimensions.h" @@ -34,10 +40,14 @@ limitations under the License. #include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/symbolic_tile_analysis.h" +#include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/instruction_fusion.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -240,5 +250,107 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimes( return {time_unfused, time_fused}; } +absl::StatusOr +GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledFusion( + const HloFusionAdaptor& fusion_adaptor, + const LaunchDimensions& launch_dimensions, + const std::vector& tile_sizes) { + // TODO(b/332714755): Add caching for SymbolicTileAnalysis. + SymbolicTileAnalysisOrError analysis_or_error = + SymbolicTileAnalysis::AnalyzeFusion(fusion_adaptor, mlir_context_); + if (!std::holds_alternative(analysis_or_error)) { + return absl::FailedPreconditionError( + absl::StrCat("SymbolicTileAnalysis failed. ", + std::get(analysis_or_error).Explain())); + } + SymbolicTileAnalysis analysis = + std::get(std::move(analysis_or_error)); + + TF_ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation, + analysis.ComputeTiledHloInstructions(tile_sizes)); + + absl::flat_hash_map n_bytes_total_map; + + int64_t flops = 0; + int64_t bytes_read = 0; + + for (const auto& tiled_hlo : tiled_hlo_computation.instructions()) { + // Number of blocks that read or compute this tile. + int64_t num_blocks = tiled_hlo->block_id_to_tile_offsets_indexing() + .GetDimensionBound(0) + .GetLoopTripCount(); + + // Total number of elements that are read from memory or computed for this + // tile across all blocks. + int64_t num_elements = num_blocks * Product(tiled_hlo->tile_sizes()); + + const HloInstruction* hlo = tiled_hlo->hlo(); + + if (fusion_adaptor.ContainsInstruction(hlo)) { + // Tiles inside the computation contribute to the total FLOPs count. + flops += FlopsPerElement(hlo) * num_elements; + } else { + // Tiles of the operands of the fusion contribute to the total memory + // read time. + int64_t element_type_size = + ShapeUtil::ByteSizeOfPrimitiveType(hlo->shape().element_type()); + int64_t tile_bytes_read = element_type_size * num_elements; + + bytes_read += tile_bytes_read; + n_bytes_total_map[hlo] += tile_bytes_read; + } + } + + int64_t num_blocks = launch_dimensions.num_blocks(); + absl::Duration read_time = absl::ZeroDuration(); + for (const auto& [hlo, n_bytes_total] : n_bytes_total_map) { + int64_t operand_size = shape_size_(hlo->shape()); + int64_t n_bytes_net = std::min(operand_size, n_bytes_total); + + read_time += ReadTimeWithDRAMHeuristic( + *device_info_, num_blocks, n_bytes_net, n_bytes_total, + /*element_type=*/hlo->shape().element_type(), + /*coalesced=*/true); + } + + int64_t bytes_written = + GetShapeSizeRecursive(tiled_hlo_computation.GetRoot()->hlo()->shape()); + + absl::Duration compute_time = + ComputeTime(*device_info_, flops, launch_dimensions.num_blocks(), + launch_dimensions.num_threads_per_block()); + absl::Duration write_time = WriteTime(*device_info_, bytes_written); + absl::Duration memory_access_time = read_time + write_time; + absl::Duration exec_time = CombineComputeAndMemoryAccessTime( + compute_time, memory_access_time, + GpuPerformanceModelOptions::PriorityFusion()); + + return EstimateRunTimeData{/*flops=*/flops, + /*bytes_read=*/bytes_read, + /*bytes_written=*/bytes_written, + /*read_time=*/read_time, + /*write_time=*/write_time, + /*compute_time=*/compute_time, + /*exec_time=*/exec_time}; +} + +absl::StatusOr +GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTriton( + const HloInstruction* producer, const HloInstruction* consumer) { + const auto& fusion_analysis = + (consumer == nullptr) ? fusion_analysis_cache_->Get(*producer) + : fusion_analysis_cache_->Get(*producer, *consumer); + auto launch_config = TritonFusion(fusion_analysis).launch_config(); + + if (!launch_config.has_value()) { + return absl::InvalidArgumentError( + "Could not get launch config for Triton fusion."); + } + + return EstimateRunTimeForTiledFusion(fusion_analysis.fusion(), + launch_config->launch_dimensions, + launch_config->output_tile_sizes); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h index 14d7e520a820d3..a1f98a8660d663 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h @@ -16,12 +16,18 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_MODEL_GPU_INDEXING_PERFORMANCE_MODEL_H_ #define XLA_SERVICE_GPU_MODEL_GPU_INDEXING_PERFORMANCE_MODEL_H_ +#include #include +#include +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/service/gpu/model/hlo_op_profiles.h" #include "xla/service/hlo_cost_analysis.h" @@ -37,10 +43,12 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { public: explicit GpuPerformanceModelWithIndexingAnalysis( const se::DeviceDescription* device_info, + HloFusionAnalysisCache* fusion_analysis_cache, HloCostAnalysis::ShapeSizeFunction shape_size, mlir::MLIRContext* mlir_context) : hlo_op_profile_(&HloOpProfiles::Singleton().GetProfile(device_info)), device_info_(device_info), + fusion_analysis_cache_(fusion_analysis_cache), shape_size_(shape_size), mlir_context_(mlir_context) {} @@ -57,6 +65,23 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { const HloInstruction* producer, absl::Span fused_consumers = {}); + // Estimate the run time of the fusion with the given launch dimensions and + // output tile sizes. + // + // The model uses SymbolicTileAnalysis to build a TiledHloComputation with the + // given tile sizes. This way it can better estimate the amount of memory + // access and computation. + absl::StatusOr EstimateRunTimeForTiledFusion( + const HloFusionAdaptor& fusion_adaptor, + const LaunchDimensions& launch_dimensions, + const std::vector& output_tile_sizes); + + // Estimate the run time of producer and consumer fused together, assuming + // that they will be emitted with Triton. + // If consumer is nullptr, estimate run time of the producer alone. + absl::StatusOr EstimateRunTimeForTriton( + const HloInstruction* producer, const HloInstruction* consumer = nullptr); + private: // Returns an estimate how many FLOPs will be used to produce one element of // the output. @@ -66,6 +91,7 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { const HloOpProfiles::HloOpProfile* hlo_op_profile_; const se::DeviceDescription* device_info_; + HloFusionAnalysisCache* fusion_analysis_cache_; HloCostAnalysis::ShapeSizeFunction shape_size_; mlir::MLIRContext* mlir_context_; }; diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc index 5e52685762e524..57468cf258d521 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/shape.h" @@ -51,8 +52,10 @@ class GpuIndexingPerformanceModelTest : public HloTestBase { // The reference times in the test cases below are measured // on A6000 by profiling the execution of the HLOs. se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; + HloFusionAnalysisCache fusion_analysis_cache_{device_info_}; GpuPerformanceModelWithIndexingAnalysis indexing_cost_model_{ - &device_info_, ShapeSizeBytesFunction(), &mlir_context_}; + &device_info_, &fusion_analysis_cache_, ShapeSizeBytesFunction(), + &mlir_context_}; GpuIndexingPerformanceModelTest() : HloTestBase() {} }; @@ -167,6 +170,106 @@ ENTRY entry_computation { EXPECT_NEAR(absl::ToDoubleNanoseconds(runtime_data.exec_time), 58, 1); } +TEST_F(GpuIndexingPerformanceModelTest, + TritonSoftmaxFusionInstructionIsSupported) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + param_0 = f32[512,911]{1,0} parameter(0) + param_1 = f32[911]{0} parameter(1) + broadcast_0 = f32[512,911]{1,0} broadcast(param_1), dimensions={1} + multiply_0 = f32[512,911]{1,0} multiply(param_0, broadcast_0) + constant_0 = f32[] constant(0) + reduce_0 = f32[512]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[512,911]{1,0} broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[512,911]{1,0} multiply(multiply_0, broadcast_4) +} + +ENTRY main { + param_0 = f32[512,911]{1,0} parameter(0) + param_1 = f32[911]{0} parameter(1) + ROOT triton_softmax = f32[512,911]{1,0} fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} +} +)")); + TF_ASSERT_OK_AND_ASSIGN(auto runtime_data, + indexing_cost_model_.EstimateRunTimeForTriton( + module->entry_computation()->root_instruction())); + + constexpr int64_t kParam0SizeBytes = 512 * 911 * 4; + constexpr int64_t kParam1SizeBytes = 911 * 4; + constexpr int64_t kOutputSizeBytes = 512 * 911 * 4; + + // Each block reads 1 tile of shape [1, 911] from param_0 and full param_1. + // In total param_0 is read once and param_1 is read 512 times. + constexpr int64_t kExpectedBytesRead = + kParam0SizeBytes + 512 * kParam1SizeBytes; + + EXPECT_EQ(runtime_data.bytes_read, kExpectedBytesRead); + EXPECT_EQ(runtime_data.bytes_written, kOutputSizeBytes); + EXPECT_NEAR(absl::ToDoubleMicroseconds(runtime_data.exec_time), 5, 1); +} + +TEST_F(GpuIndexingPerformanceModelTest, + TritonSoftmaxProducerConsumerFusionIsSupported) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +fusion { + param_0 = f32[512,911] parameter(0) + param_1 = f32[911] parameter(1) + broadcast = f32[512,911] broadcast(param_1), dimensions={1} + ROOT multiply = f32[512,911] multiply(param_0, broadcast) +} + +triton_softmax_computation { + param_0 = f32[512,911] parameter(0) + constant_0 = f32[] constant(0) + reduce_0 = f32[512] reduce(param_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[512,911] broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[512,911] multiply(param_0, broadcast_4) +} + +ENTRY main { + param_0 = f32[512,911] parameter(0) + param_1 = f32[911] parameter(1) + fusion.1 = f32[512,911] fusion(param_0, param_1), kind=kLoop, calls=fusion + ROOT triton_softmax = f32[512,911] fusion(fusion.1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} +} +)")); + auto consumer = module->entry_computation()->root_instruction(); + auto producer = consumer->operand(0); + + TF_ASSERT_OK_AND_ASSIGN( + auto runtime_data, + indexing_cost_model_.EstimateRunTimeForTriton(producer, consumer)); + + constexpr int64_t kParam0SizeBytes = 512 * 911 * 4; + constexpr int64_t kParam1SizeBytes = 911 * 4; + constexpr int64_t kOutputSizeBytes = 512 * 911 * 4; + + // Each block reads 1 tile of shape [1, 911] from param_0 and full param_1. + // In total param_0 is read once and param_1 is read 512 times. + constexpr int64_t kExpectedBytesRead = + kParam0SizeBytes + 512 * kParam1SizeBytes; + + EXPECT_EQ(runtime_data.bytes_read, kExpectedBytesRead); + EXPECT_EQ(runtime_data.bytes_written, kOutputSizeBytes); + EXPECT_NEAR(absl::ToDoubleMicroseconds(runtime_data.exec_time), 5, 1); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc index c8c42019d2efdb..e76f783f477e95 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_indexing_performance_model.h" #include "xla/service/gpu/model/gpu_performance_model_base.h" @@ -79,10 +80,12 @@ class GpuPerformanceModelTest : public HloTestBase { // The reference times in the test cases below are measured // on A6000 by profiling the execution of the HLOs. se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; + HloFusionAnalysisCache fusion_analysis_cache_{device_info_}; GpuHloCostAnalysis analysis_{options_, &device_info_}; GpuPerformanceModelWithIndexingAnalysis indexing_cost_model_{ - &device_info_, ShapeSizeBytesFunction(), &mlir_context_}; + &device_info_, &fusion_analysis_cache_, ShapeSizeBytesFunction(), + &mlir_context_}; GpuPerformanceModelTest() : HloTestBase() {} }; diff --git a/third_party/xla/xla/service/gpu/triton_support_test.cc b/third_party/xla/xla/service/gpu/triton_support_test.cc index 99fb5b7538515f..0d6f2998953362 100644 --- a/third_party/xla/xla/service/gpu/triton_support_test.cc +++ b/third_party/xla/xla/service/gpu/triton_support_test.cc @@ -20,11 +20,14 @@ limitations under the License. #include #include #include +#include #include #include +#include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "third_party/protobuf/descriptor.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/ir_emitter_triton.h" @@ -41,9 +44,72 @@ namespace xla { namespace gpu { namespace { -using UnaryElementwiseTest = TritonSupportTestWithParam; +using ::testing::Not; +using ::testing::status::IsOk; + +auto AllXlaDataTypes() { + std::vector xla_data_types; + std::vector to_filter_out = {PRIMITIVE_TYPE_INVALID, + TUPLE, OPAQUE_TYPE, TOKEN}; + const proto2::EnumDescriptor* xla_type_descriptor = + proto2::GetEnumDescriptor(); + for (int enum_ix = 0; enum_ix < xla_type_descriptor->value_count(); + ++enum_ix) { + xla::PrimitiveType xla_type = static_cast( + xla_type_descriptor->value(enum_ix)->number()); + if (!absl::c_linear_search(to_filter_out, xla_type)) { + xla_data_types.push_back(xla_type); + } + } + return ::testing::ValuesIn(xla_data_types); +} // TODO(b/343158720): remove references to TritonFusionAnalysis in this file. +// TODO(b/343158720): factor out implication tests into a util in order to +// simplify the test structure. +using BitcastOrReshapeTest = TritonSupportTestWithParam; + +TEST_P(BitcastOrReshapeTest, IsTritonSupportedBitcastOrReshape) { + auto [data_type, opcode] = GetParam(); + const std::string kHloTestTemplate = R"( +triton_computation { + parameter_0 = $0[1,16,4]{2,1,0} parameter(0) + ROOT bitcast_or_reshape = $0[64]{0} $1(parameter_0) +} + +ENTRY e { + parameter_0 = $0[1,16,4]{2,1,0} parameter(0) + ROOT root_op = $0[64]{0} fusion(parameter_0), + kind=kCustom, calls=triton_computation, + backend_config={"fusion_backend_config":{"kind":"__triton"}} +})"; + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti, + ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); + if (IsTritonSupportedInstruction(ti.Instruction(), + GetCudaComputeCapability())) { + TF_EXPECT_OK(CreateTritonIrAndFileCheck(ti.TritonComputation(), + FromOutputTileSizes({16}), + "CHECK: tt.func @triton_fn")); + } else { + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + EXPECT_THAT( + TritonWrapper("test_fn", &ti.TritonFusion(), GetCudaComputeCapability(), + dev_info, FromOutputTileSizes({1}), &llvm_module_, + mlir_context_), + Not(IsOk())); + } +} + +INSTANTIATE_TEST_SUITE_P( + BitcastOrReshapeTestSuite, BitcastOrReshapeTest, + ::testing::Combine(AllXlaDataTypes(), + ::testing::Values(HloOpcode::kBitcast, + HloOpcode::kReshape)), + TritonSupportTestParamsToString); + +using UnaryElementwiseTest = TritonSupportTestWithParam; // TODO(b/331636835): updates elementwise op tests to directly emit single op // instead of relying on triton gemm kernel.