[go: nahoru, domu]

Skip to content

Commit

Permalink
[XLA:GPU] Add ToString() methods to Tiled HLO instructions and computations.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 639051471
  • Loading branch information
olegshyshkov authored and tensorflower-gardener committed May 31, 2024
1 parent 9bec073 commit b17844b
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 4 deletions.
12 changes: 12 additions & 0 deletions third_party/xla/xla/service/gpu/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,9 @@ cc_library(
":indexing_analysis",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/service:name_uniquer",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
Expand Down Expand Up @@ -614,10 +617,16 @@ xla_cc_test(

# Library target built from tiled_hlo_computation.cc / tiled_hlo_computation.h.
cc_library(
name = "tiled_hlo_computation",
srcs = ["tiled_hlo_computation.cc"],
hdrs = ["tiled_hlo_computation.h"],
deps = [
":tiled_hlo_instruction",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/service:name_uniquer",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@local_tsl//tsl/lib/gtl:iterator_range",
],
)
Expand All @@ -627,16 +636,19 @@ cc_library(
srcs = ["symbolic_tile_analysis.cc"],
hdrs = ["symbolic_tile_analysis.h"],
deps = [
":affine_map_printer",
":indexing_analysis",
":symbolic_tile",
":symbolic_tiled_hlo_instruction",
":tiled_hlo_computation",
":tiled_hlo_instruction",
"//xla/hlo/ir:hlo",
"//xla/service:instruction_fusion",
"//xla/service:name_uniquer",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down
30 changes: 30 additions & 0 deletions third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,20 @@ limitations under the License.
#include <functional>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/AffineExpr.h" // from @llvm-project
Expand All @@ -37,13 +41,15 @@ limitations under the License.
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/gpu/model/affine_map_printer.h"
#include "xla/service/gpu/model/indexing_analysis.h"
#include "xla/service/gpu/model/indexing_map.h"
#include "xla/service/gpu/model/symbolic_tile.h"
#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h"
#include "xla/service/gpu/model/tiled_hlo_computation.h"
#include "xla/service/gpu/model/tiled_hlo_instruction.h"
#include "xla/service/instruction_fusion.h"
#include "xla/service/name_uniquer.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"

Expand Down Expand Up @@ -315,5 +321,29 @@ SymbolicTileAnalysis::ComputeTiledHloInstructions(
std::move(tiled_hlo_instructions));
}

std::string SymbolicTileAnalysis::ToString(
const AffineMapPrinter& printer) const {
std::stringstream ss;
NameUniquer name_uniquer("_");
absl::flat_hash_map<SymbolicTiledHloInstruction*, std::string> tile_names;

for (const auto& tiled_hlo : symbolic_tiled_hlo_instructions_) {
std::string tile_name = name_uniquer.GetUniqueName(
absl::StrCat(tiled_hlo->hlo()->name(), ".tile_0"));
tile_names[tiled_hlo.get()] = tile_name;

absl::InlinedVector<std::string, 4> operand_names;
for (const auto& operand : tiled_hlo->operands()) {
operand_names.push_back(tile_names.at(operand));
}

ss << tile_name << " = " << HloOpcodeString(tiled_hlo->hlo()->opcode())
<< "(" << absl::StrJoin(operand_names, ", ") << ")\n";

ss << tiled_hlo->ToString();
}
return ss.str();
}

} // namespace gpu
} // namespace xla
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ limitations under the License.

#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "absl/status/statusor.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/model/affine_map_printer.h"
#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h"
#include "xla/service/gpu/model/tiled_hlo_computation.h"
#include "xla/service/instruction_fusion.h"
Expand Down Expand Up @@ -67,6 +69,11 @@ class SymbolicTileAnalysis {
// Accessor: hands back the MLIRContext pointer stored in context_.
mlir::MLIRContext* GetMLIRContext() const {
  return context_;
}

// Returns a string representation of the analysis. Used only for error
// messages and debugging.
std::string ToString(
const AffineMapPrinter& printer = AffineMapPrinter()) const;

private:
SymbolicTileAnalysis(std::vector<std::unique_ptr<SymbolicTiledHloInstruction>>
symbolic_tiled_hlo_instructions,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ limitations under the License.
#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h"

#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

#include "absl/log/check.h"
Expand Down Expand Up @@ -75,5 +77,13 @@ std::vector<int64_t> SymbolicTiledHloInstruction::TileStrides(
return EvaluateTileMap(symbolic_tile_.stride_map(), tile_parameters);
}

// Formats this instruction as three tab-prefixed, newline-terminated entries:
// the HLO instruction, the symbolic tile, and the indexing map.
std::string SymbolicTiledHloInstruction::ToString() const {
  std::stringstream out;
  out << "\thlo: " << hlo_->ToString() << "\n"
      << "\t" << symbolic_tile_.ToString() << "\n"
      << "\tindexing map: " << indexing_map_.ToString() << "\n";
  return out.str();
}

} // namespace gpu
} // namespace xla
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#define XLA_SERVICE_GPU_MODEL_SYMBOLIC_TILED_HLO_INSTRUCTION_H_

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -69,6 +70,10 @@ class SymbolicTiledHloInstruction {
operands_.push_back(operand);
}

// Returns a string representation of the instruction. Used only for error
// messages and debugging.
std::string ToString() const;

private:
// Pointer to the original HLO instruction.
const HloInstruction* hlo_;
Expand Down
56 changes: 56 additions & 0 deletions third_party/xla/xla/service/gpu/model/tiled_hlo_computation.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/model/tiled_hlo_computation.h"

#include <sstream>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/gpu/model/tiled_hlo_instruction.h"
#include "xla/service/name_uniquer.h"

namespace xla {
namespace gpu {

std::string TiledHloComputation::ToString() const {
std::stringstream ss;
NameUniquer name_uniquer("_");
absl::flat_hash_map<const TiledHloInstruction*, std::string> tile_names;

for (const auto* tiled_hlo : instructions()) {
std::string tile_name = name_uniquer.GetUniqueName(
absl::StrCat(tiled_hlo->hlo()->name(), ".tile_0"));
tile_names[tiled_hlo] = tile_name;

absl::InlinedVector<std::string, 4> operand_names;
for (const auto& operand : tiled_hlo->operands()) {
operand_names.push_back(tile_names.at(operand));
}

ss << tile_name << " = " << HloOpcodeString(tiled_hlo->hlo()->opcode())
<< "(" << absl::StrJoin(operand_names, ", ") << ")\n";

ss << tiled_hlo->ToString() << "\n";
}
return ss.str();
}

} // namespace gpu
} // namespace xla
5 changes: 5 additions & 0 deletions third_party/xla/xla/service/gpu/model/tiled_hlo_computation.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#define XLA_SERVICE_GPU_MODEL_TILED_HLO_COMPUTATION_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -56,6 +57,10 @@ class TiledHloComputation {
return instructions_.back().get();
}

// Returns a string representation of the computation. Used only for error
// messages and debugging.
std::string ToString() const;

private:
explicit TiledHloComputation(
std::vector<std::unique_ptr<TiledHloInstruction>> instructions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ TiledHloInstruction::Create(const HloInstruction* hlo,

// Builds a textual description of this tiled instruction: the HLO, the tile
// sizes, the tile strides and the block-id-to-tile-offsets indexing.
// NOTE(review): this span is a diff hunk rendered without +/- markers. The
// first four `ss <<` lines (unindented labels, "{...}" lists) appear to be the
// removed version and the following four (tab-prefixed labels, "(...)" lists)
// their replacement — confirm against the repository before editing.
std::string TiledHloInstruction::ToString() const {
std::stringstream ss;
ss << "hlo: " << hlo_->ToString() << "\n";
ss << "tile_sizes: {" << absl::StrJoin(tile_sizes_, ", ") << "}\n";
ss << "tile_strides: {" << absl::StrJoin(tile_strides_, ", ") << "}\n";
ss << "block_id_to_tile_offsets_indexing: "
ss << "\thlo: " << hlo_->ToString() << "\n";
ss << "\ttile_sizes: (" << absl::StrJoin(tile_sizes_, ", ") << ")\n";
ss << "\ttile_strides: (" << absl::StrJoin(tile_strides_, ", ") << ")\n";
ss << "\tblock_id_to_tile_offsets_indexing: "
<< block_id_to_tile_offsets_indexing_;
return ss.str();
}
Expand Down

0 comments on commit b17844b

Please sign in to comment.