[go: nahoru, domu]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:GPU] Construct constraints for destructured summations in SymbolicTile derivation. #70421

Merged
merged 1 commit into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 187 additions & 28 deletions third_party/xla/xla/service/gpu/model/symbolic_tile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <numeric>
#include <optional>
#include <ostream>
#include <sstream>
Expand All @@ -33,6 +34,7 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/IR/AffineMap.h" // from @llvm-project
Expand All @@ -56,29 +58,32 @@ using ::mlir::getAffineDimExpr;
using ::mlir::MLIRContext;
using ConjointConstraints = ConstraintExpression::ConjointConstraints;

// Rewrites `expr` so that each dimension d{i}, for i in [0, num_dimensions),
// is replaced by the symbol s{i} with the same position index.
AffineExpr DimsToSymbols(AffineExpr expr, int64_t num_dimensions) {
  MLIRContext* mlir_context = expr.getContext();

  // replacement_symbols[i] is the expression that d{i} gets rewritten to.
  llvm::SmallVector<AffineExpr> replacement_symbols;
  replacement_symbols.reserve(num_dimensions);
  for (int64_t position = 0; position < num_dimensions; ++position) {
    replacement_symbols.push_back(getAffineSymbolExpr(position, mlir_context));
  }

  return expr.replaceDims(replacement_symbols);
}

// Gets a modified version of `expressions` where both the original dimensions
// and symbols are replaced with symbols.
//
// (dimensions)[symbols] -> ()[dimensions, symbols]
std::vector<AffineExpr> DimsToSymbols(std::vector<AffineExpr> expressions,
const IndexingMap& indexing_map) {
MLIRContext* mlir_context = indexing_map.GetMLIRContext();

// Move symbols right
// Shift existing symbols to the right such that they end up following the
// newly introduced symbols.
for (AffineExpr& expression : expressions) {
expression =
expression.shiftSymbols(/*numSymbols=*/indexing_map.GetSymbolCount(),
/*shift=*/indexing_map.GetDimensionCount());
}

// Convert dimensions to symbols
llvm::DenseMap<AffineExpr, AffineExpr> dim_to_symbol_map;
for (int i = 0; i < indexing_map.GetDimensionCount(); i++) {
dim_to_symbol_map[getAffineDimExpr(i, mlir_context)] =
getAffineSymbolExpr(i, mlir_context);
}
for (AffineExpr& expression : expressions) {
expression = expression.replace(dim_to_symbol_map);
expression = DimsToSymbols(expression, indexing_map.GetDimensionCount());
}

return expressions;
Expand Down Expand Up @@ -122,11 +127,11 @@ struct SizeAndStrideExpression {
// Extracts size and stride expressions from the operands to a modulo
// expression.
//
// TODO(b/326998704): Currently, this fails when the stride is not exactly unit.
// TODO(b/349487906): Currently, this fails when the stride is not exactly unit.
std::optional<SizeAndStrideExpression> ExtractSizeAndStrideFromMod(
AffineExpr lhs, AffineExpr modulus) {
// TODO(b/326998704): finish deriving constraints here, as well as the non-one
// stride case, both in the code and in the proof.
// TODO(b/349487906): handle the non-one stride case, both in the code and in
// the proof.
// Let f(d0) = d0 mod c. Then, given an input tile size n,
// {f(x) | x in Fin(n)} contains:
// * n elements if n < c (and we add a constraint that c % n == 0)
Expand All @@ -153,7 +158,7 @@ std::optional<SizeAndStrideExpression> ExtractSizeAndStrideFromMod(
AffineExpr tile_size_expr =
getAffineSymbolExpr(dim_expr.getPosition(), lhs.getContext());
Interval zero_interval{/*lower=*/0, /*upper=*/0};
// TODO(b/326998704): the below also becomes more complicated if stride is
// TODO(b/349487906): the below also becomes more complicated if stride is
// not unit.
//
// tile_size % modulus == 0 || modulus % tile_size == 0
Expand All @@ -175,7 +180,7 @@ std::optional<SizeAndStrideExpression> ExtractSizeAndStrideFromMod(
// Extracts size and stride expressions from the operands to a floordiv
// expression.
//
// TODO(b/326998704): Currently, this fails when the numerator of the stride
// TODO(b/349487906): Currently, this fails when the numerator of the stride
// is not exactly unit.
std::optional<SizeAndStrideExpression> ExtractSizeAndStrideFromFloorDiv(
AffineExpr num, AffineExpr den) {
Expand Down Expand Up @@ -316,30 +321,49 @@ AffineExpr IfNeqOne(AffineExpr eq_param, AffineExpr true_expr,

// Sorts a list of `SizeAndStrideExpression`s by stride: in ascending order of
// stride by default, or in descending order of stride when `reverse` is true.
// There is a precondition that all strides are constant, since each stride is
// `llvm::cast` to an `AffineConstantExpr` to read its value.
void SortByStride(std::vector<SizeAndStrideExpression>& sizes_and_strides) {
absl::c_sort(sizes_and_strides, [](const SizeAndStrideExpression& sas1,
const SizeAndStrideExpression& sas2) {
void SortByStride(std::vector<SizeAndStrideExpression>& sizes_and_strides,
bool reverse = false) {
absl::c_sort(sizes_and_strides, [&](const SizeAndStrideExpression& sas1,
const SizeAndStrideExpression& sas2) {
int64_t stride1 = llvm::cast<AffineConstantExpr>(sas1.stride).getValue();
int64_t stride2 = llvm::cast<AffineConstantExpr>(sas2.stride).getValue();
// `reverse` flips the comparison so that larger strides come first.
if (reverse) {
return stride1 > stride2;
}
return stride1 < stride2;
});
}

// Returns the range size of the given size expression.
//
// `size` must be a constant or dimension expression.
// `size` must be a constant or a dimension expression by default. If
// `dimensions_were_converted_to_symbols` is true, then `size` must instead be
// a constant or a symbol expression; dimension expressions are rejected.
std::optional<int64_t> TryGetSizeExpressionRangeSize(
AffineExpr size, absl::Span<Interval const> dimension_intervals) {
CHECK(size.getKind() == AffineExprKind::Constant ||
size.getKind() == AffineExprKind::DimId);
if (auto dimension = llvm::dyn_cast<AffineDimExpr>(size)) {
const Interval& interval = dimension_intervals.at(dimension.getPosition());
AffineExpr size, absl::Span<Interval const> dimension_intervals,
bool dimensions_were_converted_to_symbols = false) {
std::optional<int64_t> dim_or_symbol_position;
if (dimensions_were_converted_to_symbols) {
CHECK(size.getKind() == AffineExprKind::Constant ||
size.getKind() == AffineExprKind::SymbolId);
if (auto symbol = llvm::dyn_cast<AffineSymbolExpr>(size)) {
dim_or_symbol_position = symbol.getPosition();
}
} else {
CHECK(size.getKind() == AffineExprKind::Constant ||
size.getKind() == AffineExprKind::DimId);
if (auto dimension = llvm::dyn_cast<AffineDimExpr>(size)) {
dim_or_symbol_position = dimension.getPosition();
}
}
if (dim_or_symbol_position.has_value()) {
const Interval& interval = dimension_intervals.at(*dim_or_symbol_position);
if (interval.lower != 0) {
// TODO(bchetioui): I think we may need to handle this to have reshapes
// working well with concatenations. Nevertheless, we can take a look
// later.
VLOG(1) << "Attempted to combine strides but got dimension "
<< AffineMapPrinter().ToString(dimension) << " with lower bound "
<< AffineMapPrinter().ToString(size) << " with lower bound "
<< interval.lower << " != 0";
return std::nullopt;
}
Expand Down Expand Up @@ -460,6 +484,133 @@ std::optional<AffineExpr> CombineStrides(
return nested_if;
}

// Given a set of size expressions assumed to be sorted in descending order of
// associated stride, returns a conjunction such that:
//   - the first `partial_dim_index` size expressions are constrained to be
//     equal to 1;
//   - the `partial_dim_index`-th size expression is unconstrained;
//   - the next `num_full_dims` size expressions are constrained to be equal to
//     their upper bound;
//   - the remaining size expressions are constrained to be equal to 1.
//
// Preconditions: `partial_dim_index >= 0`, `num_full_dims >= 0`, and
// `partial_dim_index + num_full_dims < sizes.size()`. The partial dimension
// and every full dimension that follows it must be actual entries of `sizes`.
//
// The size expressions are expected to have had their dimensions converted to
// symbols, with symbol s{i} bounded by `dimension_intervals[i]`.
//
// Returns `std::nullopt` if the range size of one of the full dimensions
// cannot be determined (see `TryGetSizeExpressionRangeSize`).
//
// See also the documentation of
// `ConstructConstraintExpressionForDestructuredSummation` for broader context.
std::optional<ConjointConstraints>
TryConstructSingleConjointConstraintForDestructuredSummation(
    absl::Span<AffineExpr const> sizes,
    absl::Span<Interval const> dimension_intervals, int64_t partial_dim_index,
    int64_t num_full_dims) {
  const int64_t num_sizes = static_cast<int64_t>(sizes.size());
  // The last index read below is `partial_dim_index + num_full_dims`, so it
  // must be strictly smaller than the number of sizes.
  CHECK(partial_dim_index >= 0 && num_full_dims >= 0 &&
        partial_dim_index + num_full_dims < num_sizes);

  const int64_t last_full_dim_index = partial_dim_index + num_full_dims;
  ConjointConstraints constraints;
  Interval one{/*lower=*/1, /*upper=*/1};

  // Constraints are inserted in index order: leading ones, then full
  // dimensions, then trailing ones.
  for (int64_t index = 0; index < num_sizes; ++index) {
    if (index == partial_dim_index) {
      // Skip the partial dimension, since "partial" basically means
      // unconstrained.
      continue;
    }

    AffineExpr size_expr = sizes[index];
    if (index < partial_dim_index || index > last_full_dim_index) {
      // Leading or trailing dimension: must not contribute to the tile.
      constraints.insert({size_expr, one});
      continue;
    }

    // Full dimension: the tile size must cover the whole range.
    std::optional<int64_t> max_size = TryGetSizeExpressionRangeSize(
        size_expr, dimension_intervals,
        /*dimensions_were_converted_to_symbols=*/true);
    if (!max_size.has_value()) {
      return std::nullopt;
    }
    constraints.insert(
        {size_expr, Interval{/*lower=*/*max_size, /*upper=*/*max_size}});
  }

  return constraints;
}

// Constructs constraints for the summation expression
//   expr = sum(map(lambda [size, stride]: stride * size, sizes_and_strides)).
//
// In order to assign a single stride for the summation expression, we need to
// ensure that the parameters (sizes) involved in the expression are such that
// the gap between them is always the same. Concretely, given a list of sizes
// [s0, s1, ..., s{n}] ordered in descending order of associated strides, we
// expect that each size s{k} is either:
//   a) 1 (and the corresponding stride is irrelevant);
//   b) fully captured---i.e. s{k} = upper_bound(s{k}). Assume s{k} is the
//      leftmost fully captured dimension. In that case,
//      for i in {0, ..., n-k-1}, s{k+i+1} is allowed to be fully captured if
//      s{k+i} is also fully captured. Otherwise, s{k+i+1} = 1. The resulting
//      stride is the smallest stride associated with a fully captured
//      dimension, or the stride of s{k};
//   c) partially captured---i.e. 1 < s{k} < upper_bound(s{k}). In that case,
//      for i in {0, ..., k-1}, s{i} = 1. s{k+1} is allowed to be fully
//      captured (and thus the leftmost fully captured dimension), in which case
//      we do as in b). If s{k+1} is not fully captured, then
//      for i in {k+1, ..., n}, s{i} = 1, and the stride of the expression is
//      the stride associated with s{k}.
//
// As a regex-like summary, we expect the sizes to be as follows in row-major
// order (i.e. strictly decreasing order of strides):
//   (1*, partial_dim?, full_dims*, 1*).
//
// The result is the disjunction of one conjunction per such pattern. If no
// conjunction at all could be constructed, an unsatisfiable constraint
// expression is returned.
//
// See also the documentation of `CombineStrides`.
ConstraintExpression ConstructConstraintExpressionForDestructuredSummation(
    std::vector<SizeAndStrideExpression> sizes_and_strides,
    absl::Span<Interval const> dimension_intervals) {
  SortByStride(sizes_and_strides, /*reverse=*/true);
  ConstraintExpression result;

  const int64_t num_dimension_parameters =
      static_cast<int64_t>(dimension_intervals.size());
  const int64_t num_components =
      static_cast<int64_t>(sizes_and_strides.size());

  // Use symbols here because constraints operate on the symbols of the
  // `SymbolicTile`, as explained in the documentation of the class. There is
  // exactly one size per component, so reserve for the number of components.
  std::vector<AffineExpr> sizes_with_symbols;
  sizes_with_symbols.reserve(num_components);
  for (const SizeAndStrideExpression& size_and_stride : sizes_and_strides) {
    sizes_with_symbols.push_back(
        DimsToSymbols(size_and_stride.size, num_dimension_parameters));
  }

  // Enumerate every (1*, partial_dim?, full_dims*, 1*) pattern: the partial
  // dimension is at `partial_dim_index` and is followed by `num_full_dims`
  // full dimensions, which must all fit within the components.
  for (int64_t partial_dim_index = 0; partial_dim_index < num_components;
       ++partial_dim_index) {
    for (int64_t num_full_dims = 0;
         num_full_dims < num_components - partial_dim_index; ++num_full_dims) {
      std::optional<ConjointConstraints> single_conjoint_constraint =
          TryConstructSingleConjointConstraintForDestructuredSummation(
              sizes_with_symbols, dimension_intervals, partial_dim_index,
              num_full_dims);
      if (!single_conjoint_constraint.has_value()) {
        // Even if we fail to derive a single conjunction, we can still recover
        // if we are able to derive another one. The constraint system will
        // just end up being more restricted (since one of the branches of the
        // overall disjunction will disappear).
        continue;
      }
      result.Or(std::move(*single_conjoint_constraint));
    }
  }

  // If we didn't succeed at constructing any constraint, we don't really know
  // what valid tile sizes could even make this work---hence, we return an
  // unsatisfiable constraint expression.
  // NOTE(review): with a single component, the only conjunction built is
  // empty; confirm that `Or`-ing an empty conjunction is not reported as
  // `IsAlwaysSatisfied()`, otherwise that case would become unsatisfiable.
  if (result.IsAlwaysSatisfied()) {
    return ConstraintExpression::GetUnsatisfiableConstraintExpression();
  }

  return result;
}

// See documentation of `CombineSizes` and `CombineStrides` for an explanation
// of how sizes and strides are combined.
std::optional<SizeAndStrideExpression> CombineSizesAndStrides(
Expand All @@ -485,12 +636,20 @@ std::optional<SizeAndStrideExpression> CombineSizesAndStrides(

AffineExpr size = CombineSizes(sizes_and_strides);
std::optional<AffineExpr> stride =
CombineStrides(std::move(sizes_and_strides), dimension_intervals);
CombineStrides(sizes_and_strides, dimension_intervals);
if (!stride.has_value()) {
return std::nullopt;
}

// TODO(b/326998704): handle reshape constraints here.
// Derive necessary constraints for the summation expression. These
// constraints are explained in the documentation of
// `ConstructConstraintExpressionForDestructuredSummation` and
// `CombineStrides`.
constraints = ConstraintExpression::And(
std::move(constraints),
ConstructConstraintExpressionForDestructuredSummation(
std::move(sizes_and_strides), dimension_intervals));

return SizeAndStrideExpression(size, *stride, std::move(constraints));
}

Expand Down Expand Up @@ -900,7 +1059,7 @@ void ConstraintExpression::Print(std::ostream& out,
/*results=*/results,
/*context=*/indexing_map.GetMLIRContext());

// TODO(b/326998704): Can we derive any constraint from the constraints of
// TODO(b/349507828): Can we derive any constraint from the constraints of
// the original indexing map?
IndexingMap tile_map(
/*affine_map=*/std::move(tile_affine_map),
Expand Down
19 changes: 12 additions & 7 deletions third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,6 @@ TEST_F(SymbolicTileTest,

TEST_F(SymbolicTileTest,
CanPropagateTileThroughNonTrivialSplitReshapeFromOutputToInput) {
// TODO(b/334043867): we need disjunctions here to derive the proper
// constraints for the tile sizes.
auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"(
HloModule m
ENTRY e {
Expand All @@ -180,6 +178,12 @@ TEST_F(SymbolicTileTest,
(((-s2 + 7) floordiv 6) * (((-s1 + 9) floordiv 8) *
((-((-s0 + 5) floordiv 4) + 1) * 48) +
(-((-s1 + 9) floordiv 8) + 1) * 6) + -((-s2 + 7) floordiv 6) + 1, 1)
constraints: s0 in [1, 2) && s1 in [1, 2) ||
s0 in [1, 2) && s2 in [1, 2) ||
s0 in [1, 2) && s2 in [6, 7) ||
s1 in [1, 2) && s2 in [1, 2) ||
s1 in [8, 9) && s2 in [1, 2) ||
s1 in [8, 9) && s2 in [6, 7)
)")));

// Capturing elements along dimensions 0, 1, and 2 makes the stride equal to
Expand Down Expand Up @@ -575,7 +579,7 @@ TEST_F(SymbolicTileTest, CanPropagateTileThroughSplitReshapeOfReverse) {

TEST_F(SymbolicTileTest,
FailsGracefullyAtPropagatingTileThroughSliceOfSplitReshape) {
// TODO(b/326998704): constraints should allow us to unblock this use case.
// TODO(b/349487906): constraints should allow us to unblock this use case.
// A slice of a split reshape creates a non-unit stride atop a floordiv.
auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"(
HloModule m
Expand All @@ -598,8 +602,6 @@ TEST_F(SymbolicTileTest,

TEST_F(SymbolicTileTest,
FailsGracefullyAtPropagatingTileThroughMisalignedSliceOfSplitReshape) {
// TODO(b/326998704): constraints should allow us to unblock part of this use
// case.
// TODO(b/331257678): handling correctly cases where offsets don't get
// simplified away perfectly will allow us to unblock part of this use case.
auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"(
Expand All @@ -623,7 +625,7 @@ TEST_F(SymbolicTileTest,

TEST_F(SymbolicTileTest,
FailsGracefullyAtPropagatingTileThroughSliceOfSplitReshapeOnTranspose) {
// TODO(b/326998704): constraints should allow us to unblock this use case.
// TODO(b/349487906): constraints should allow us to unblock this use case.
// A slice of a split reshape creates a non-unit stride atop a floordiv.
auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"(
HloModule m
Expand All @@ -647,7 +649,7 @@ TEST_F(SymbolicTileTest,

TEST_F(SymbolicTileTest,
FailsGracefullyAtPropagatingTileThroughSliceOfSplitReshapeOfReverse) {
// TODO(b/326998704): constraints should allow us to unblock this use case.
// TODO(b/349487906): constraints should allow us to unblock this use case.
// A slice of a split reshape of a reverse creates a negative non-unit stride
// atop a floordiv.
auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"(
Expand Down Expand Up @@ -746,6 +748,7 @@ TEST_F(SymbolicTileTest,
offset_map: ()[s0, s1, s2] -> (0, 0)
size_map: ()[s0, s1, s2] -> (s0 * s1, 50304)
stride_map: ()[s0, s1, s2] -> (((-s1 + 2049) floordiv 2048) * ((-((-s0 + 5) floordiv 4) + 1) * 2048) + -((-s1 + 2049) floordiv 2048) + 1, 1)
constraints: s0 in [1, 2) || s1 in [1, 2) || s1 in [2048, 2049)
)")));
}

Expand All @@ -757,12 +760,14 @@ TEST_F(SymbolicTileTest, CanDeriveTileWhenTheIndexingMapHasSymbolsInASum) {
&mlir_context_),
{4, 2048, 393}, {128});

// TODO(b/349377672): the constraints here can be simplified away.
EXPECT_THAT(SymbolicTile::FromIndexingMap(indexing_map),
Optional(MatchSymbolicTileString(R"(
Symbolic tile with
offset_map: ()[s0, s1, s2] -> (0, 0, 0)
size_map: ()[s0, s1, s2] -> (s0, s1, s2 * 128)
stride_map: ()[s0, s1, s2] -> (1, 1, 1)
constraints: 128 in [1, 2) || 128 in [128, 129) || s2 in [1, 2)
)")));
}

Expand Down