[go: nahoru, domu]

Skip to content

Commit

Permalink
[XLA:GPU] Use IndexingMap's simplifier in tiling
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 634360780
  • Loading branch information
tdanyluk authored and tensorflower-gardener committed May 16, 2024
1 parent 228e8a4 commit c9e8a7f
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 21 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ cc_library(
hdrs = ["symbolic_tile.h"],
deps = [
":affine_map_printer",
":indexing_analysis",
":indexing_map",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
Expand Down
42 changes: 30 additions & 12 deletions third_party/xla/xla/service/gpu/model/symbolic_tile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "xla/service/gpu/model/affine_map_printer.h"
#include "xla/service/gpu/model/indexing_analysis.h"
#include "xla/service/gpu/model/indexing_map.h"

namespace xla {
Expand Down Expand Up @@ -278,6 +279,30 @@ std::optional<SizeAndStrideExpression> ExtractSizeAndStride(
LOG(FATAL) << "unreachable";
}

// Returns `expr` after running it through IndexingMap's simplifier, using the
// variable bounds and constraints carried by `reference`.
//
// `expr` must be expressed over the same dimensions and symbols as
// `reference`, because the temporary map built below reuses the reference
// map's dimension count, symbol count, variables and constraints unchanged.
AffineExpr SimplifyAffineExpr(const AffineExpr& expr,
                              const IndexingMap& reference) {
  // Wrap the lone expression in a single-result indexing map that mirrors the
  // reference map, so the simplifier sees the same variables and constraints.
  IndexingMap single_result_map(
      /*affine_map=*/AffineMap::get(
          /*dimCount=*/reference.GetDimVars().size(),
          /*symbolCount=*/reference.GetSymbolCount(),
          /*results=*/{expr},
          /*context=*/reference.GetMLIRContext()),
      /*dimensions=*/reference.GetDimVars(),
      /*range_vars=*/reference.GetRangeVars(),
      /*rt_vars=*/reference.GetRTVars(),
      /*constraints=*/reference.GetConstraints());
  single_result_map.Simplify(GetIndexingMapForInstruction);

  // Exactly one result went in, so exactly one is expected back out.
  const auto simplified_results = single_result_map.GetAffineMap().getResults();
  CHECK_EQ(simplified_results.size(), 1);
  return simplified_results.front();
}

} // anonymous namespace

/*static*/ std::optional<SymbolicTile> SymbolicTile::FromIndexingMap(
Expand Down Expand Up @@ -322,6 +347,9 @@ std::optional<SizeAndStrideExpression> ExtractSizeAndStride(
input_affine_map, getAffineConstantExpr(0, mlir_context),
indexing_map.GetRangeVarsCount())
.getResults();
for (AffineExpr& expr : offset_expressions) {
expr = SimplifyAffineExpr(expr, indexing_map);
}

std::vector<AffineExpr> size_expressions;
std::vector<AffineExpr> stride_expressions;
Expand All @@ -333,19 +361,9 @@ std::optional<SizeAndStrideExpression> ExtractSizeAndStride(
for (auto [composite_indexing, offset] :
llvm::zip(input_affine_map.getResults(), offset_expressions)) {
std::optional<SizeAndStrideExpression> maybe_size_and_stride =
ExtractSizeAndStride(composite_indexing - offset,
ExtractSizeAndStride(SimplifyAffineExpr(composite_indexing - offset,
/*reference=*/indexing_map),
indexing_map.GetSymbolBounds());
if (!maybe_size_and_stride.has_value()) {
// Retry with a simplified expression.
// For example `(d0 + s0 - s0)` will be simplified to `d0`.
// But the simplification doesn't help when it rewrites `mod` to
// `floordiv` & `add`, so at first we try without simplification.
maybe_size_and_stride = ExtractSizeAndStride(
simplifyAffineExpr(composite_indexing - offset,
input_affine_map.getNumDims(),
input_affine_map.getNumSymbols()),
indexing_map.GetSymbolBounds());
}
if (!maybe_size_and_stride.has_value()) {
VLOG(1) << "No size and stride extracted";
return std::nullopt;
Expand Down
16 changes: 7 additions & 9 deletions third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,20 +314,18 @@ TEST_F(SymbolicTileTest, CanPropagateTileThroughDynamicSlice) {
EXPECT_THAT(
SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()),
// s0, s1, s2: tile sizes
// s3, s4, s5: runtime parameters
// TODO(tdanyluk): If the RTVar can only have 1 value, maybe we should
// optimize it out?
// s3, s4: runtime parameters
// Note: We don't have s0 in the size map's rhs, because the first dim
// of the tile size can only be 1. The second offset is optimized to 0,
// because that is the only possible value.
Optional(MatchSymbolicTileWithRtVars(
"()[s0, s1, s2, s3, s4, s5] -> (s3, s4, s5)",
"()[s0, s1, s2] -> (s0, s1, s2)", "()[s0, s1, s2] -> (1, 1, 1)",
"()[s0, s1, s2, s3, s4] -> (s3, 0, s4)",
"()[s0, s1, s2] -> (1, s1, s2)", "()[s0, s1, s2] -> (0, 1, 1)",
R"(
s3 in [0, 1]
hlo: %of1 = s32[] parameter(1)
(d0, d1, d2) -> ()
s4 in [0, 0]
hlo: %of2 = s32[] parameter(2)
(d0, d1, d2) -> ()
s5 in [0, 226]
s4 in [0, 226]
hlo: %of3 = s32[] parameter(3)
(d0, d1, d2) -> ()
)")));
Expand Down

0 comments on commit c9e8a7f

Please sign in to comment.