[go: nahoru, domu]

Skip to content

Commit

Permalink
[XLA:GPU][NFC] Rename gpu_version_ to compute_capability_ in `gemm_fusion`.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 646992650
  • Loading branch information
bchetioui authored and tensorflower-gardener committed Jun 26, 2024
1 parent 22804b0 commit 7ab29f8
Show file tree
Hide file tree
Showing 8 changed files with 234 additions and 63 deletions.
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1253,6 +1253,8 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@local_tsl//tsl/platform:tensor_float_32_utils",
],
)
Expand Down
16 changes: 3 additions & 13 deletions third_party/xla/xla/service/gpu/gemm_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -798,24 +798,14 @@ bool ShouldTritonHandleGEMM(HloDotInstruction& dot,
absl::StatusOr<bool> GemmFusion::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
auto cuda_compute_capability =
std::get_if<se::CudaComputeCapability>(&gpu_version_);
if (!cuda_compute_capability) {
return absl::FailedPreconditionError(
"Triton support is only enabled for CUDA GPUs.");
} else if (!cuda_compute_capability->IsAtLeastAmpere()) {
return absl::FailedPreconditionError(
absl::StrCat("Triton support is only enabled for Ampere GPUs (compute ",
"capability 8.0) and up, but got compute capability ",
cuda_compute_capability->major, ".",
cuda_compute_capability->minor, "."));
}
TF_RETURN_IF_ERROR(
EnsureTritonSupportsComputeCapability(compute_capability_));

bool changed = false;
for (HloComputation* computation :
module->MakeNonfusionComputations(execution_threads)) {
TF_ASSIGN_OR_RETURN(bool result,
RunOnComputation(computation, gpu_version_));
RunOnComputation(computation, compute_capability_));
changed |= result;
}
return changed;
Expand Down
6 changes: 3 additions & 3 deletions third_party/xla/xla/service/gpu/gemm_fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ bool ShouldTritonHandleGEMM(HloDotInstruction&,
// that target Triton-based matmul emitter.
class GemmFusion : public HloModulePass {
public:
explicit GemmFusion(const se::GpuComputeCapability& gpu_version)
: gpu_version_(gpu_version) {}
explicit GemmFusion(const se::GpuComputeCapability& compute_capability)
: compute_capability_(compute_capability) {}
absl::string_view name() const override { return "triton-gemm-rewriter"; }

using HloPassInterface::Run;
Expand All @@ -48,7 +48,7 @@ class GemmFusion : public HloModulePass {
const absl::flat_hash_set<absl::string_view>& execution_threads) override;

private:
se::GpuComputeCapability gpu_version_;
se::GpuComputeCapability compute_capability_;
};

} // namespace gpu
Expand Down
215 changes: 187 additions & 28 deletions third_party/xla/xla/service/gpu/model/symbolic_tile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <numeric>
#include <optional>
#include <ostream>
#include <sstream>
Expand All @@ -33,6 +34,7 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/IR/AffineMap.h" // from @llvm-project
Expand All @@ -56,29 +58,32 @@ using ::mlir::getAffineDimExpr;
using ::mlir::MLIRContext;
using ConjointConstraints = ConstraintExpression::ConjointConstraints;

// Rewrites `expr` so that each dimension `d{i}`, for `i` in
// [0, num_dimensions), is replaced by the symbol `s{i}` with the same
// position. The resulting expression lives in the same MLIR context as `expr`.
AffineExpr DimsToSymbols(AffineExpr expr, int64_t num_dimensions) {
  MLIRContext* mlir_context = expr.getContext();

  // Build the replacement list: position `p` maps to symbol `s{p}`.
  llvm::SmallVector<AffineExpr> replacement_symbols;
  replacement_symbols.reserve(num_dimensions);
  for (int64_t position = 0; position < num_dimensions; ++position) {
    replacement_symbols.push_back(getAffineSymbolExpr(position, mlir_context));
  }

  return expr.replaceDims(replacement_symbols);
}

// Gets a modified version of `expressions` where both the original dimensions
// and symbols are replaced with symbols.
//
// (dimensions)[symbols] -> ()[dimensions, symbols]
std::vector<AffineExpr> DimsToSymbols(std::vector<AffineExpr> expressions,
const IndexingMap& indexing_map) {
MLIRContext* mlir_context = indexing_map.GetMLIRContext();

// Move symbols right
// Shift existing symbols to the right such that they end up following the
// newly introduced symbols.
for (AffineExpr& expression : expressions) {
expression =
expression.shiftSymbols(/*numSymbols=*/indexing_map.GetSymbolCount(),
/*shift=*/indexing_map.GetDimensionCount());
}

// Convert dimensions to symbols
llvm::DenseMap<AffineExpr, AffineExpr> dim_to_symbol_map;
for (int i = 0; i < indexing_map.GetDimensionCount(); i++) {
dim_to_symbol_map[getAffineDimExpr(i, mlir_context)] =
getAffineSymbolExpr(i, mlir_context);
}
for (AffineExpr& expression : expressions) {
expression = expression.replace(dim_to_symbol_map);
expression = DimsToSymbols(expression, indexing_map.GetDimensionCount());
}

return expressions;
Expand Down Expand Up @@ -122,11 +127,11 @@ struct SizeAndStrideExpression {
// Extracts size and stride expressions from the operands to a modulo
// expression.
//
// TODO(b/326998704): Currently, this fails when the stride is not exactly unit.
// TODO(b/349487906): Currently, this fails when the stride is not exactly unit.
std::optional<SizeAndStrideExpression> ExtractSizeAndStrideFromMod(
AffineExpr lhs, AffineExpr modulus) {
// TODO(b/326998704): finish deriving constraints here, as well as the non-one
// stride case, both in the code and in the proof.
// TODO(b/349487906): handle the non-one stride case, both in the code and in
// the proof.
// Let f(d0) = d0 mod c. Then, given an input tile size n,
// {f(x) | x in Fin(n)} contains:
// * n elements if n < c (and we add a constraint that c % n == 0)
Expand All @@ -153,7 +158,7 @@ std::optional<SizeAndStrideExpression> ExtractSizeAndStrideFromMod(
AffineExpr tile_size_expr =
getAffineSymbolExpr(dim_expr.getPosition(), lhs.getContext());
Interval zero_interval{/*lower=*/0, /*upper=*/0};
// TODO(b/326998704): the below also becomes more complicated if stride is
// TODO(b/349487906): the below also becomes more complicated if stride is
// not unit.
//
// tile_size % modulus == 0 || modulus % tile_size == 0
Expand All @@ -175,7 +180,7 @@ std::optional<SizeAndStrideExpression> ExtractSizeAndStrideFromMod(
// Extracts size and stride expressions from the operands to a floordiv
// expression.
//
// TODO(b/326998704): Currently, this fails when the numerator of the stride
// TODO(b/349487906): Currently, this fails when the numerator of the stride
// is not exactly unit.
std::optional<SizeAndStrideExpression> ExtractSizeAndStrideFromFloorDiv(
AffineExpr num, AffineExpr den) {
Expand Down Expand Up @@ -316,30 +321,49 @@ AffineExpr IfNeqOne(AffineExpr eq_param, AffineExpr true_expr,

// Sorts a list of `SizeAndStrideExpression`s by stride. There is a precondition
// that all strides are constant.
void SortByStride(std::vector<SizeAndStrideExpression>& sizes_and_strides) {
absl::c_sort(sizes_and_strides, [](const SizeAndStrideExpression& sas1,
const SizeAndStrideExpression& sas2) {
void SortByStride(std::vector<SizeAndStrideExpression>& sizes_and_strides,
bool reverse = false) {
absl::c_sort(sizes_and_strides, [&](const SizeAndStrideExpression& sas1,
const SizeAndStrideExpression& sas2) {
int64_t stride1 = llvm::cast<AffineConstantExpr>(sas1.stride).getValue();
int64_t stride2 = llvm::cast<AffineConstantExpr>(sas2.stride).getValue();
if (reverse) {
return stride1 > stride2;
}
return stride1 < stride2;
});
}

// Returns the range size of the given size expression.
//
// `size` must be a constant or dimension expression.
// `size` must be a constant or dimension expression by default. If
// `dimensions_were_converted_to_symbols` is true, then `size` can not be a
// dimension expression but can be a symbol expression instead.
std::optional<int64_t> TryGetSizeExpressionRangeSize(
AffineExpr size, absl::Span<Interval const> dimension_intervals) {
CHECK(size.getKind() == AffineExprKind::Constant ||
size.getKind() == AffineExprKind::DimId);
if (auto dimension = llvm::dyn_cast<AffineDimExpr>(size)) {
const Interval& interval = dimension_intervals.at(dimension.getPosition());
AffineExpr size, absl::Span<Interval const> dimension_intervals,
bool dimensions_were_converted_to_symbols = false) {
std::optional<int64_t> dim_or_symbol_position;
if (dimensions_were_converted_to_symbols) {
CHECK(size.getKind() == AffineExprKind::Constant ||
size.getKind() == AffineExprKind::SymbolId);
if (auto symbol = llvm::dyn_cast<AffineSymbolExpr>(size)) {
dim_or_symbol_position = symbol.getPosition();
}
} else {
CHECK(size.getKind() == AffineExprKind::Constant ||
size.getKind() == AffineExprKind::DimId);
if (auto dimension = llvm::dyn_cast<AffineDimExpr>(size)) {
dim_or_symbol_position = dimension.getPosition();
}
}
if (dim_or_symbol_position.has_value()) {
const Interval& interval = dimension_intervals.at(*dim_or_symbol_position);
if (interval.lower != 0) {
// TODO(bchetioui): I think we may need to handle this to have reshapes
// working well with concatenations. Nevertheless, we can take a look
// later.
VLOG(1) << "Attempted to combine strides but got dimension "
<< AffineMapPrinter().ToString(dimension) << " with lower bound "
<< AffineMapPrinter().ToString(size) << " with lower bound "
<< interval.lower << " != 0";
return std::nullopt;
}
Expand Down Expand Up @@ -460,6 +484,133 @@ std::optional<AffineExpr> CombineStrides(
return nested_if;
}

// Given a set of size expressions assumed to be sorted in descending order of
// associated stride, returns a conjunction such that:
//   - the first `partial_dim_index` size expressions are constrained to be
//     equal to 1;
//   - the `partial_dim_index`-th size expression is unconstrained;
//   - the next `num_full_dims` size expressions are constrained to be equal to
//     their upper bound;
//   - the remaining size expressions are constrained to be equal to 1.
//
// The size expressions are looked up with
// `dimensions_were_converted_to_symbols=true`, so `sizes` must be expressed in
// terms of symbols (or constants) rather than dimensions.
//
// Returns `std::nullopt` if the range size of one of the fully captured size
// expressions can not be determined.
//
// Precondition: `partial_dim_index + num_full_dims < sizes.size()`, because
// the last fully captured size expression is read at index
// `partial_dim_index + num_full_dims`.
//
// See also the documentation of
// `ConstructConstraintExpressionForDestructuredSummation` for broader context.
std::optional<ConjointConstraints>
TryConstructSingleConjointConstraintForDestructuredSummation(
    absl::Span<AffineExpr const> sizes,
    absl::Span<Interval const> dimension_intervals, int64_t partial_dim_index,
    int64_t num_full_dims) {
  const int64_t num_sizes = static_cast<int64_t>(sizes.size());
  const int64_t last_full_dim_index = partial_dim_index + num_full_dims;
  // `sizes[last_full_dim_index]` is accessed below (when `num_full_dims > 0`)
  // and the partial dimension itself must exist, so the index must be
  // strictly in bounds.
  CHECK_LT(last_full_dim_index, num_sizes);

  ConjointConstraints constraints;
  Interval one{/*lower=*/1, /*upper=*/1};
  int64_t running_size_index = 0;

  // Add leading ones.
  while (running_size_index < partial_dim_index) {
    constraints.insert({sizes[running_size_index], one});
    ++running_size_index;
  }

  // Skip partial dimension, since "partial" basically means unconstrained.
  ++running_size_index;

  // Add full dimensions: indices (partial_dim_index, last_full_dim_index].
  while (running_size_index <= last_full_dim_index) {
    AffineExpr size_expr = sizes[running_size_index];
    std::optional<int64_t> max_size = TryGetSizeExpressionRangeSize(
        size_expr, dimension_intervals,
        /*dimensions_were_converted_to_symbols=*/true);
    if (!max_size.has_value()) {
      return std::nullopt;
    }
    constraints.insert(
        {size_expr, Interval{/*lower=*/*max_size, /*upper=*/*max_size}});
    ++running_size_index;
  }

  // Add trailing ones.
  while (running_size_index < num_sizes) {
    constraints.insert({sizes[running_size_index], one});
    ++running_size_index;
  }

  return constraints;
}

// Constructs constraints for the summation expression
// expr = sum(map(lambda [size, stride]: stride * size, sizes_and_strides)).
//
// In order to assign a single stride for the summation expression, we need to
// ensure that the parameters (sizes) involved in the expression are such that
// the gap between them is always the same. Concretely, given a list of sizes
// [s0, s1, ..., s{n}] ordered in descending order of associated strides, we
// expect that each size s{k} is either:
// a) 1 (and the corresponding stride is irrelevant);
// b) fully captured---i.e. s{k} = upper_bound(s{k}). Assume s{k} is the
//    leftmost fully captured dimension. In that case,
//    for i in {0, ..., n-k-1}, s{k+i+1} is allowed to be fully captured if
//    s{k+i} is also fully captured. Otherwise, s{k+i+1} = 1. The resulting
//    stride is the smallest stride associated with a fully captured
//    dimension, or the stride of s{k};
// c) partially captured---i.e. 1 < s{k} < upper_bound(s{k}). In that case,
//    for i in {0, ..., k-1}, s{i} = 1. s{k+1} is allowed to be fully
//    captured (and thus the leftmost fully captured dimension), in which case
//    we do as in b). If s{k+1} is not fully captured, then
//    for i in {k+1, ..., n}, s{i} = 1, and the stride of the expression is
//    the stride associated with s{k}.
//
// As a regex-like summary, we expect the sizes to be as follows in row-major
// order (i.e. strictly decreasing order of strides):
//   (1*, partial_dim?, full_dims*, 1*).
//
// The result is the disjunction, over every (partial_dim_index, num_full_dims)
// pair, of the conjunctions built by
// `TryConstructSingleConjointConstraintForDestructuredSummation`. Pairs for
// which no conjunction can be derived are dropped. If the result is still
// always satisfied after trying all pairs, an unsatisfiable constraint
// expression is returned instead.
//
// See also the documentation of `CombineStrides`.
ConstraintExpression ConstructConstraintExpressionForDestructuredSummation(
    std::vector<SizeAndStrideExpression> sizes_and_strides,
    absl::Span<Interval const> dimension_intervals) {
  SortByStride(sizes_and_strides, /*reverse=*/true);
  ConstraintExpression result;

  int64_t num_dimension_parameters = dimension_intervals.size();
  std::vector<AffineExpr> sizes_with_symbols;
  // One entry is produced per summation component, not per dimension
  // parameter.
  sizes_with_symbols.reserve(sizes_and_strides.size());
  // Use symbols here because constraints operate on the symbols of the
  // `SymbolicTile`, as explained in the documentation of the class.
  for (const SizeAndStrideExpression& size_and_stride : sizes_and_strides) {
    sizes_with_symbols.push_back(
        DimsToSymbols(size_and_stride.size, num_dimension_parameters));
  }

  int64_t num_components = sizes_and_strides.size();
  for (int64_t partial_dim_index = 0; partial_dim_index < num_components;
       ++partial_dim_index) {
    for (int64_t num_full_dims = 0;
         num_full_dims < num_components - partial_dim_index; ++num_full_dims) {
      std::optional<ConjointConstraints> single_conjoint_constraint =
          TryConstructSingleConjointConstraintForDestructuredSummation(
              sizes_with_symbols, dimension_intervals, partial_dim_index,
              num_full_dims);
      if (!single_conjoint_constraint.has_value()) {
        // Even if we fail to derive a single conjunction, we can still recover
        // if we are able to derive another one. The constraint system will
        // just end up being more restricted (since one of the branches of the
        // overall disjunction will disappear).
        continue;
      }
      result.Or(std::move(*single_conjoint_constraint));
    }
  }

  // If we didn't succeed at constructing any constraint, we don't really know
  // what valid tile sizes could even make this work---hence, we return an
  // unsatisfiable map.
  if (result.IsAlwaysSatisfied()) {
    return ConstraintExpression::GetUnsatisfiableConstraintExpression();
  }

  return result;
}

// See documentation of `CombineSizes` and `CombineStrides` for an explanation
// of how sizes and strides are combined.
std::optional<SizeAndStrideExpression> CombineSizesAndStrides(
Expand All @@ -485,12 +636,20 @@ std::optional<SizeAndStrideExpression> CombineSizesAndStrides(

AffineExpr size = CombineSizes(sizes_and_strides);
std::optional<AffineExpr> stride =
CombineStrides(std::move(sizes_and_strides), dimension_intervals);
CombineStrides(sizes_and_strides, dimension_intervals);
if (!stride.has_value()) {
return std::nullopt;
}

// TODO(b/326998704): handle reshape constraints here.
// Derive necessary constraints for the summation expression. These
// constraints are explained in the documentation of
// `ConstructConstraintExpressionForDestructuredSummation` and
// `CombineStrides`.
constraints = ConstraintExpression::And(
std::move(constraints),
ConstructConstraintExpressionForDestructuredSummation(
std::move(sizes_and_strides), dimension_intervals));

return SizeAndStrideExpression(size, *stride, std::move(constraints));
}

Expand Down Expand Up @@ -900,7 +1059,7 @@ void ConstraintExpression::Print(std::ostream& out,
/*results=*/results,
/*context=*/indexing_map.GetMLIRContext());

// TODO(b/326998704): Can we derive any constraint from the constraints of
// TODO(b/349507828): Can we derive any constraint from the constraints of
// the original indexing map?
IndexingMap tile_map(
/*affine_map=*/std::move(tile_affine_map),
Expand Down
Loading

0 comments on commit 7ab29f8

Please sign in to comment.