[go: nahoru, domu]

Skip to content

Commit

Permalink
Canonicalize affine expression trees in simplifier.
Browse files Browse the repository at this point in the history
I don't know where the current non-determinism comes from, but this should
fix it.

The particular order is not important, but it being canonical is. These
expressions are used for codegen in the new emitters, so the simplification
must be deterministic.

PiperOrigin-RevId: 633531399
  • Loading branch information
jreiffers authored and tensorflower-gardener committed May 14, 2024
1 parent d6bfe26 commit b8629e1
Show file tree
Hide file tree
Showing 14 changed files with 107 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ TEST_F(MlirConcatenateFusionTest, ThreadIdIndexing) {

constexpr auto kIndexing = R"(
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
(th_x + bl_x * 128) mod 400)
(bl_x * 128 + th_x) mod 400)
domain:
th_x in [0, 127]
th_y in [0, 0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ TEST_F(ConcatenateTest, ThreadIndexing) {

constexpr auto kIndexing = R"(
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
(th_x + bl_x * 128) mod 400)
(bl_x * 128 + th_x) mod 400)
domain:
th_x in [0, 127]
th_y in [0, 0]
Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/gpu/fusions/input_slices_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ TEST_F(InputSlicesTest, ThreadIndexing) {
EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_),
MatchIndexingString(R"(
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (0,
((th_x + bl_x * 128) floordiv 3) mod 2,
(th_x + bl_x * 128) mod 3,
((bl_x * 128 + th_x) floordiv 3) mod 2,
(bl_x * 128 + th_x) mod 3,
((bl_x * 64 + th_x floordiv 2) floordiv 3) mod 5)
domain:
th_x in [0, 127]
Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingUnrolled) {
MatchIndexingString(R"(
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
(((bl_x * 16 + th_x floordiv 8) floordiv 3 + chunk_id * 5376) floordiv 625) mod 100,
(((th_x + bl_x * 128) floordiv 3 + chunk_id * 43008) floordiv 25) mod 200,
(((bl_x * 128 + th_x) floordiv 3 + chunk_id * 43008) floordiv 25) mod 200,
(th_x * 4 + bl_x * 512 + chunk_id * 516096) mod 300 + unroll_id
)
domain:
Expand Down Expand Up @@ -150,7 +150,7 @@ TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) {
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
((bl_x * 16 + th_x floordiv 8) floordiv 75) mod 10,
((bl_x * 64 + th_x floordiv 2) floordiv 15) mod 20,
(th_x + bl_x * 128) mod 30)
(bl_x * 128 + th_x) mod 30)
domain:
th_x in [0, 127]
th_y in [0, 0]
Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/gpu/fusions/loop_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) {
MatchIndexingString(R"(
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
(((bl_x * 16 + th_x floordiv 8) floordiv 3 + chunk_id * 5376) floordiv 625) mod 100,
(((th_x + bl_x * 128) floordiv 3 + chunk_id * 43008) floordiv 25) mod 200,
(((bl_x * 128 + th_x) floordiv 3 + chunk_id * 43008) floordiv 25) mod 200,
(th_x * 4 + bl_x * 512 + chunk_id * 516096) mod 300 + unroll_id
)
domain:
Expand Down Expand Up @@ -186,7 +186,7 @@ TEST_F(LoopTest, Broadcast) {
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
((bl_x * 16 + th_x floordiv 8) floordiv 75) mod 10,
((bl_x * 64 + th_x floordiv 2) floordiv 15) mod 20,
(th_x + bl_x * 128) mod 30)
(bl_x * 128 + th_x) mod 30)
domain:
th_x in [0, 127]
th_y in [0, 0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ TEST_F(ReductionTest, ThreadIndexingSideOutput) {
constexpr char kExpectedIndexing[] = R"(
(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> (
d3 floordiv 8,
d0 floordiv 32 + (d3 mod 8) * 8,
(d3 mod 8) * 8 + d0 floordiv 32,
(d0 mod 32) * 2 + s2 * 64 + s3
)
domain:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ TEST_F(MlirScatterFusionTest, ThreadIdIndexing) {
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
((bl_x * 16 + th_x floordiv 8) floordiv 25) mod 42,
((bl_x * 32 + th_x floordiv 4) floordiv 5) mod 10,
(th_x + bl_x * 128) mod 20)
(bl_x * 128 + th_x) mod 20)
domain:
th_x in [0, 127]
th_y in [0, 0]
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/fusions/scatter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) {
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
((bl_x * 16 + th_x floordiv 8) floordiv 25) mod 42,
((bl_x * 32 + th_x floordiv 4) floordiv 5) mod 10,
(th_x + bl_x * 128) mod 20)
(bl_x * 128 + th_x) mod 20)
domain:
th_x in [0, 127]
th_y in [0, 0]
Expand Down
12 changes: 6 additions & 6 deletions third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing021) {
MatchIndexingString(R"(
(d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
d3 floordiv 2,
### d0 floordiv 32 + s1 * 4 ###,
### (d3 mod 2) * 32 + d0 mod 32 ###
d0 floordiv 32 + s1 * 4,
(d3 mod 2) * 32 + d0 mod 32
)
domain:
d0 in [0, 127]
Expand All @@ -71,7 +71,7 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing021) {
MatchIndexingString(R"(
(d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
d3 floordiv 2,
### d0 floordiv 32 + (d3 mod 2) * 32 + s1 * 4 ###,
(d3 mod 2) * 32 + s1 * 4 + d0 floordiv 32,
d0 mod 32
)
domain:
Expand Down Expand Up @@ -110,7 +110,7 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing201) {
MatchIndexingString(R"(
(d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
d3 floordiv 2,
d0 floordiv 32 + (d3 * 32 + s1 * 4) mod 64,
(d3 * 32 + s1 * 4) mod 64 + d0 floordiv 32,
d0 mod 32
)
domain:
Expand All @@ -129,9 +129,9 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing201) {
fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(),
MatchIndexingString(R"(
(d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
### d0 floordiv 32 + s1 * 4 ###,
d0 floordiv 32 + s1 * 4,
d3 floordiv 2,
### (d3 mod 2) * 32 + d0 mod 32 ###
(d3 mod 2) * 32 + d0 mod 32
)
domain:
d0 in [0, 127]
Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/gpu/fusions/transpose_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ TEST_F(TransposeTest, ThreadIndexing021) {
MatchIndexingString(R"(
(d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
d3 floordiv 2,
d0 floordiv 32 + (d3 mod 2) * 32 + s1 * 4,
(d3 mod 2) * 32 + s1 * 4 + d0 floordiv 32,
d0 mod 32
)
domain:
Expand Down Expand Up @@ -139,7 +139,7 @@ TEST_F(TransposeTest, ThreadIndexing201) {
MatchIndexingString(R"(
(d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
d3 floordiv 2,
d0 floordiv 32 + (d3 * 32 + s1 * 4) mod 64,
(d3 * 32 + s1 * 4) mod 64 + d0 floordiv 32,
d0 mod 32
)
domain:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1643,7 +1643,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpGenericReshape2DTo3D) {
EXPECT_THAT(input_indexing.indexing_maps,
ElementsAre(ElementsAre(MatchIndexingMap(R"(
(d0, d1, d2) -> (d0 * 2 + d1 floordiv 2,
d2 + (d1 mod 2) * 4)
(d1 mod 2) * 4 + d2)
domain:
d0 in [0, 1]
d1 in [0, 3]
Expand All @@ -1662,7 +1662,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpGenericReshape3DTo2D) {
EXPECT_THAT(input_indexing.indexing_maps,
ElementsAre(ElementsAre(MatchIndexingMap(R"(
(d0, d1) -> (d0 floordiv 2,
d1 floordiv 4 + (d0 mod 2) * 2,
(d0 mod 2) * 2 + d1 floordiv 4,
d1 mod 4)
domain:
d0 in [0, 3]
Expand Down Expand Up @@ -2615,7 +2615,7 @@ TEST_F(IndexingAnalysisTest, TilingIndexing) {
EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"(
(d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
(d3 floordiv 64) * 8 + s0,
d0 floordiv 4 + (d3 mod 64) * 4,
(d3 mod 64) * 4 + d0 floordiv 4,
d0 mod 4 + s2 * 4
)
domain:
Expand Down Expand Up @@ -2659,7 +2659,7 @@ TEST_F(IndexingAnalysisTest, EpilogueIndexing) {
ComputeEpilogueInputToOutputIndexing(transpose, log, &mlir_context_)
.ToString(),
MatchIndexingString(R"(
(d0, d1) -> (d0 + d1 * 1000)
(d0, d1) -> (d1 * 1000 + d0)
domain:
d0 in [0, 999]
d1 in [0, 999]
Expand Down
70 changes: 69 additions & 1 deletion third_party/xla/xla/service/gpu/model/indexing_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/IR/AffineMap.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
Expand Down Expand Up @@ -359,6 +360,69 @@ AffineExpr AffineExprSimplifier::RewriteSumIf(
return pred(expr) ? expr : mlir::getAffineConstantExpr(0, expr.getContext());
}

// Three-way comparison of two affine expressions by their AST.
//
// Returns a negative value if `a` orders before `b`, a positive value if `a`
// orders after `b`, and 0 if both trees compare equal. The ordering is
// arbitrary but similar to what MLIR's simplifier does:
//   1. A constant orders after any non-constant expression.
//   2. Otherwise, expressions of different kinds are ordered by kind.
//   3. Binary expressions of the same kind are ordered by their left operand,
//      then by their right operand.
//   4. Leaves of the same kind are ordered by constant value or by
//      dimension/symbol position.
int CompareExprs(AffineExpr a, AffineExpr b) {
  const auto a_kind = a.getKind();
  const auto b_kind = b.getKind();

  // Exactly one side being a constant decides the result immediately.
  const bool a_is_constant = a_kind == AffineExprKind::Constant;
  const bool b_is_constant = b_kind == AffineExprKind::Constant;
  if (a_is_constant != b_is_constant) {
    return a_is_constant ? 1 : -1;
  }

  if (a_kind != b_kind) {
    return a_kind < b_kind ? -1 : 1;
  }
  assert(a.getKind() == b.getKind());

  const auto three_way = [](int64_t x, int64_t y) {
    return x < y ? -1 : (x > y ? 1 : 0);
  };

  switch (a_kind) {
    case AffineExprKind::Constant:
      return three_way(mlir::cast<AffineConstantExpr>(a).getValue(),
                       mlir::cast<AffineConstantExpr>(b).getValue());
    case AffineExprKind::DimId:
      return three_way(mlir::cast<AffineDimExpr>(a).getPosition(),
                       mlir::cast<AffineDimExpr>(b).getPosition());
    case AffineExprKind::SymbolId:
      return three_way(mlir::cast<AffineSymbolExpr>(a).getPosition(),
                       mlir::cast<AffineSymbolExpr>(b).getPosition());
    case AffineExprKind::Add:
    case AffineExprKind::Mul:
    case AffineExprKind::Mod:
    case AffineExprKind::FloorDiv:
    case AffineExprKind::CeilDiv: {
      auto a_binary = mlir::cast<AffineBinaryOpExpr>(a);
      auto b_binary = mlir::cast<AffineBinaryOpExpr>(b);
      // The left operands take precedence; the right operands only break ties.
      const int lhs_order =
          CompareExprs(a_binary.getLHS(), b_binary.getLHS());
      return lhs_order != 0
                 ? lhs_order
                 : CompareExprs(a_binary.getRHS(), b_binary.getRHS());
    }
  }
  // Not reached for the kinds handled above; treat anything else as equal.
  return 0;
}

// Recursively rebuilds `in` so that the operands of every Add and Mul appear
// in the order defined by CompareExprs. Operands of other binary operations
// are canonicalized too but keep their positions. Non-binary expressions are
// returned unchanged.
AffineExpr CanonicalizeOrder(AffineExpr in) {
  auto binary = mlir::dyn_cast<AffineBinaryOpExpr>(in);
  if (!binary) {
    return in;
  }

  const auto kind = binary.getKind();
  AffineExpr canonical_lhs = CanonicalizeOrder(binary.getLHS());
  AffineExpr canonical_rhs = CanonicalizeOrder(binary.getRHS());

  // Only commutative operations may have their operands exchanged.
  const bool is_commutative =
      kind == AffineExprKind::Add || kind == AffineExprKind::Mul;
  if (is_commutative && CompareExprs(canonical_lhs, canonical_rhs) > 0) {
    return getAffineBinaryOpExpr(kind, canonical_rhs, canonical_lhs);
  }
  return getAffineBinaryOpExpr(kind, canonical_lhs, canonical_rhs);
}

AffineExpr AffineExprSimplifier::SimplifyOnce(AffineExpr expr) {
switch (expr.getKind()) {
case AffineExprKind::Mul: {
Expand Down Expand Up @@ -478,7 +542,11 @@ AffineMap AffineExprSimplifier::Simplify(AffineMap affine_map) {
results.push_back(simplified);
}
if (nothing_changed) {
return affine_map;
for (auto& result : results) {
result = CanonicalizeOrder(result);
}
return AffineMap::get(affine_map.getNumDims(), affine_map.getNumSymbols(),
results, affine_map.getContext());
}
return Simplify(AffineMap::get(affine_map.getNumDims(),
affine_map.getNumSymbols(), results,
Expand Down
6 changes: 3 additions & 3 deletions third_party/xla/xla/service/gpu/model/indexing_map_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ TEST_F(IndexingMapTest,
ParseAffineMap(serialized_map, &mlir_context_), {10, 10, 10}, {});
indexing_map.Simplify(GetIndexingMapForInstruction);
EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"(
(d0, d1, d2) -> (d0 * 2 + (d1 + d2 floordiv 4) floordiv 2,
(d0, d1, d2) -> (d0 * 2 + (d2 floordiv 4 + d1) floordiv 2,
(d1 * 4 + d2) mod 8)
domain:
d0 in [0, 9]
Expand Down Expand Up @@ -619,7 +619,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape_Regression) {
ParseAffineMap(serialized_map, &mlir_context_), {}, {128});
indexing_map.Simplify(GetIndexingMapForInstruction);
EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"(
()[s0] -> ((s0 * 128) mod 715 + ((s0 * 64) floordiv 715) * 715)
()[s0] -> (((s0 * 64) floordiv 715) * 715 + (s0 * 128) mod 715)
domain: s0 in [0, 127]
)"));
}
Expand Down Expand Up @@ -663,7 +663,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) {
indexing_map.Simplify(GetIndexingMapForInstruction);
EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"(
()[s0, s1, s2, s3] -> (
s1 + (s0 * 458752 + s2 * 4 + s3 * 512) mod 20000
(s0 * 458752 + s2 * 4 + s3 * 512) mod 20000 + s1
)
domain:
s0 in [0, 871]
Expand Down
77 changes: 12 additions & 65 deletions third_party/xla/xla/service/gpu/model/indexing_test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,78 +140,25 @@ AffineExpr ParseAffineExpr(absl::string_view serialized_affine_expr,
.getResult(0);
}

// Splits `s` at every occurrence of `pattern` and returns the pieces in order.
//
// Empty pieces before or between delimiters are kept (e.g. "###a" split on
// "###" yields {"", "a"}), but an empty piece after the final delimiter is
// dropped (e.g. "a###" yields {"a"}). An empty `s` yields an empty vector.
//
// An empty `pattern` never matches: the whole of `s` is returned as a single
// piece (or nothing if `s` is empty). Without this guard, find("") would
// match at every position without consuming input and the loop would never
// terminate.
inline std::vector<std::string> split_string(std::string s,
                                             std::string pattern) {
  std::vector<std::string> result;
  if (pattern.empty()) {
    if (!s.empty()) result.push_back(s);
    return result;
  }
  // Advance a search offset instead of erasing the consumed prefix, so each
  // character of `s` is copied at most once.
  size_t start = 0;
  size_t pos = 0;
  while ((pos = s.find(pattern, start)) != std::string::npos) {
    result.push_back(s.substr(start, pos - start));
    start = pos + pattern.length();
  }
  if (start < s.size()) result.push_back(s.substr(start));
  return result;
}

// Returns true if `s` begins with `pattern`. An empty `pattern` is a prefix of
// every string; a `pattern` longer than `s` never matches.
inline bool startswith(const std::string& s, const std::string& pattern) {
  const size_t prefix_length = pattern.size();
  // compare() clamps the range to the end of `s`, so a too-short `s` simply
  // compares unequal.
  return s.compare(0, prefix_length, pattern) == 0;
}

bool ApproximateMatch(std::string_view lhs, std::string_view rhs) {
std::string lhs_unspaced, rhs_unspaced;
for (auto c : lhs) {
if (!std::isspace(c)) {
lhs_unspaced += c;
size_t lhs_length = lhs.size();
size_t rhs_length = rhs.size();
size_t l = 0, r = 0;
while (l < lhs_length && r < rhs_length) {
while (l < lhs_length && std::isspace(lhs[l])) {
++l;
}
}
for (auto c : rhs) {
if (!std::isspace(c)) {
rhs_unspaced += c;
}
}

if (lhs_unspaced.find("###") == std::string::npos)
return lhs_unspaced == rhs_unspaced;

std::vector<std::string> frags = split_string(lhs_unspaced, "###");

while (frags.size() >= 2) {
if (!startswith(rhs_unspaced, frags[0])) {
return false;
while (r < rhs_length && std::isspace(rhs[r])) {
++r;
}

rhs_unspaced = rhs_unspaced.substr(frags[0].size());

auto terms = split_string(frags[1], "+");
// iterate through permutations of terms
std::vector<int> indexes(terms.size());
for (auto i = 0; i < terms.size(); i++) {
indexes[i] = i;
if (l == lhs_length || r == rhs_length) {
continue;
}
bool match = false;
do {
std::string permuted = "";
for (auto i : indexes) {
permuted += terms[i] + "+";
}
permuted.pop_back();
if (startswith(rhs_unspaced, permuted)) {
match = true;
break;
}
} while (std::next_permutation(indexes.begin(), indexes.end()));

if (!match) {
if (lhs[l++] != rhs[r++]) {
return false;
}

rhs_unspaced = rhs_unspaced.substr(frags[1].size());
frags.erase(frags.begin());
frags.erase(frags.begin());
}
if (frags.empty())
return rhs_unspaced.empty();
else
return rhs_unspaced == frags[0];
return l == lhs_length && r == rhs_length;
}

} // namespace gpu
Expand Down

0 comments on commit b8629e1

Please sign in to comment.