[go: nahoru, domu]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:GPU] Use llvm::SmallVector instead of std::vector. #71575

Merged
merged 1 commit into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/ir_emitter_triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2511,7 +2511,7 @@ MakeTensorPtrOpAndBoundaryChecks CreateMakeTensorPtrOp(
llvm::SmallVector<int32_t> order;
llvm::SmallVector<int32_t> boundary_checks;

const std::vector<int64_t>& tile_strides = tiled_hlo.tile_strides();
const llvm::SmallVector<int64_t>& tile_strides = tiled_hlo.tile_strides();
const Shape& shape = tiled_hlo.hlo()->shape();

// Compute physical strides of the tile. `tile_strides` contains strides for
Expand Down
4 changes: 4 additions & 0 deletions third_party/xla/xla/service/gpu/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ cc_library(
hdrs = ["affine_map_evaluator.h"],
deps = [
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@local_tsl//tsl/platform:logging",
Expand Down Expand Up @@ -604,6 +605,7 @@ cc_library(
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
],
)

Expand Down Expand Up @@ -637,6 +639,8 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
],
)

Expand Down
5 changes: 3 additions & 2 deletions third_party/xla/xla/service/gpu/model/affine_map_evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace gpu {

namespace {

using llvm::SmallVector;
using mlir::AffineBinaryOpExpr;
using mlir::AffineConstantExpr;
using mlir::AffineDimExpr;
Expand Down Expand Up @@ -82,13 +83,13 @@ int64_t EvaluateAffineExpr(AffineExpr expr,
}
}

std::vector<int64_t> EvaluateAffineMap(
SmallVector<int64_t> EvaluateAffineMap(
AffineMap affine_map, absl::Span<int64_t const> dim_values,
absl::Span<int64_t const> symbol_values) {
CHECK_EQ(affine_map.getNumDims(), dim_values.size());
CHECK_EQ(affine_map.getNumSymbols(), symbol_values.size());

std::vector<int64_t> results;
SmallVector<int64_t> results;
results.reserve(affine_map.getNumResults());
for (auto expr : affine_map.getResults()) {
results.push_back(EvaluateAffineExpr(expr, dim_values, symbol_values));
Expand Down
3 changes: 2 additions & 1 deletion third_party/xla/xla/service/gpu/model/affine_map_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <vector>

#include "absl/types/span.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/IR/AffineMap.h" // from @llvm-project

Expand All @@ -37,7 +38,7 @@ int64_t EvaluateAffineExpr(mlir::AffineExpr expr,

// Given an AffineMap and the values for its dimensions and symbols, evaluates
// the results.
std::vector<int64_t> EvaluateAffineMap(
llvm::SmallVector<int64_t> EvaluateAffineMap(
mlir::AffineMap affine_map, absl::Span<int64_t const> dim_values,
absl::Span<int64_t const> symbol_values = {});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -447,9 +447,9 @@ SymbolicTileAnalysis::ComputeTiledHloInstructions(

for (const std::unique_ptr<SymbolicTiledHloInstruction>& symbolic_tiled_hlo :
symbolic_tiled_hlo_instructions_) {
std::vector<int64_t> tile_sizes =
llvm::SmallVector<int64_t> tile_sizes =
symbolic_tiled_hlo->TileSizes(tile_parameters);
std::vector<int64_t> tile_strides =
llvm::SmallVector<int64_t> tile_strides =
symbolic_tiled_hlo->TileStrides(tile_parameters);

TF_ASSIGN_OR_RETURN(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,28 @@ limitations under the License.
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "llvm/ADT/SmallVector.h"
#include "xla/service/gpu/model/affine_map_evaluator.h"
#include "xla/service/gpu/model/symbolic_tile.h"

namespace xla {
namespace gpu {

std::vector<int64_t> SymbolicTiledHloInstruction::TileOffsets(
llvm::SmallVector<int64_t> SymbolicTiledHloInstruction::TileOffsets(
absl::Span<int64_t const> tile_parameters) const {
return EvaluateAffineMap(symbolic_tile().offset_map(),
/*dim_values=*/tile_parameters);
}

std::vector<int64_t> SymbolicTiledHloInstruction::TileSizes(
llvm::SmallVector<int64_t> SymbolicTiledHloInstruction::TileSizes(
absl::Span<int64_t const> tile_parameters) const {
return EvaluateAffineMap(symbolic_tile().size_map(),
/*dim_values=*/tile_parameters);
}

std::vector<int64_t> SymbolicTiledHloInstruction::TileStrides(
llvm::SmallVector<int64_t> SymbolicTiledHloInstruction::TileStrides(
absl::Span<int64_t const> tile_parameters) const {
return EvaluateAffineMap(symbolic_tile().stride_map(),
/*dim_values=*/tile_parameters);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.

#include "absl/log/check.h"
#include "absl/types/span.h"
#include "llvm/ADT/SmallVector.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/model/indexing_map.h"
#include "xla/service/gpu/model/symbolic_tile.h"
Expand All @@ -41,13 +42,13 @@ class SymbolicTiledHloInstruction {
: hlo_(hlo), indexing_map_(std::move(indexing_map)) {}

// Evaluates the tile offsets of an instruction with given tile parameters.
std::vector<int64_t> TileOffsets(
llvm::SmallVector<int64_t> TileOffsets(
absl::Span<int64_t const> tile_parameters) const;
// Evaluates the tile sizes of an instruction with given tile parameters.
std::vector<int64_t> TileSizes(
llvm::SmallVector<int64_t> TileSizes(
absl::Span<int64_t const> tile_parameters) const;
// Evaluates the tile strides of an instruction with given tile parameters.
std::vector<int64_t> TileStrides(
llvm::SmallVector<int64_t> TileStrides(
absl::Span<int64_t const> tile_parameters) const;

const HloInstruction* hlo() const { return hlo_; }
Expand Down
10 changes: 3 additions & 7 deletions third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,13 @@ limitations under the License.

#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "llvm/ADT/SmallVector.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/model/indexing_map.h"
#include "xla/util.h"
Expand All @@ -38,8 +34,8 @@ namespace gpu {
/*static*/
absl::StatusOr<std::unique_ptr<TiledHloInstruction>>
TiledHloInstruction::Create(const HloInstruction* hlo,
std::vector<int64_t> tile_sizes,
std::vector<int64_t> tile_strides,
llvm::SmallVector<int64_t> tile_sizes,
llvm::SmallVector<int64_t> tile_strides,
IndexingMap tile_offsets_indexing) {
int rank = hlo->shape().rank();

Expand Down
33 changes: 21 additions & 12 deletions third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ limitations under the License.
#include <vector>

#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "llvm/ADT/SmallVector.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/model/indexing_map.h"

Expand All @@ -46,19 +48,22 @@ class TiledHloInstruction {
// * `tile_offsets_indexing` should have the number of dimensions equal to the
// rank of the output tile and 0 symbols.
static absl::StatusOr<std::unique_ptr<TiledHloInstruction>> Create(
const HloInstruction* hlo, std::vector<int64_t> tile_sizes,
std::vector<int64_t> tile_strides, IndexingMap tile_offsets_indexing);
const HloInstruction* hlo, llvm::SmallVector<int64_t> tile_sizes,
llvm::SmallVector<int64_t> tile_strides,
IndexingMap tile_offsets_indexing);

// Returns the original HLO instruction.
const HloInstruction* hlo() const { return hlo_; }

// Returns the tile sizes. The number of tile sizes is equal to the rank of
// the output shape.
const std::vector<int64_t>& tile_sizes() const { return tile_sizes_; }
const llvm::SmallVector<int64_t>& tile_sizes() const { return tile_sizes_; }

// Returns the tile strides. The number of tile strides is equal to the rank
// of the output shape.
const std::vector<int64_t>& tile_strides() const { return tile_strides_; }
const llvm::SmallVector<int64_t>& tile_strides() const {
return tile_strides_;
}

// Returns the indexing map from tile multi-index to tile offsets. The map has
// a form of `(d0, d1, ...) -> (tile_offset0, tile_offset1, ...)`. The number
Expand Down Expand Up @@ -90,8 +95,8 @@ class TiledHloInstruction {

private:
TiledHloInstruction(const HloInstruction* hlo,
std::vector<int64_t> tile_sizes,
std::vector<int64_t> tile_strides,
llvm::SmallVector<int64_t> tile_sizes,
llvm::SmallVector<int64_t> tile_strides,
IndexingMap tile_offsets_indexing)
: hlo_(hlo),
tile_sizes_(std::move(tile_sizes)),
Expand All @@ -102,8 +107,8 @@ class TiledHloInstruction {
const HloInstruction* hlo_;

// Tile sizes and strides.
std::vector<int64_t> tile_sizes_;
std::vector<int64_t> tile_strides_;
llvm::SmallVector<int64_t> tile_sizes_;
llvm::SmallVector<int64_t> tile_strides_;

// Indexing map for tile offsets.
IndexingMap tile_offsets_indexing_;
Expand All @@ -126,10 +131,14 @@ inline bool operator!=(const TiledHloInstruction& lhs,

template <typename H>
H AbslHashValue(H h, const TiledHloInstruction& tiled_hlo_instruction) {
return H::combine(std::move(h), tiled_hlo_instruction.hlo(),
tiled_hlo_instruction.tile_sizes(),
tiled_hlo_instruction.tile_strides(),
tiled_hlo_instruction.tile_offsets_indexing());
// There is no default hash implementation for llvm::SmallVector in either
// AbslHashValue or llvm::hash_value. We can use the available hash
// implementation for absl::Span instead.
return H::combine(
std::move(h), tiled_hlo_instruction.hlo(),
absl::Span<int64_t const>(tiled_hlo_instruction.tile_sizes()),
absl::Span<int64_t const>(tiled_hlo_instruction.tile_strides()),
tiled_hlo_instruction.tile_offsets_indexing());
}

} // namespace gpu
Expand Down