[XLA:GPU] Construct constraints for destructured summations in SymbolicTile derivation.
#70421
+199
−35
Take the summation expression
`expr = sum(map(lambda [size, stride]: stride * size, sizes_and_strides))`.

In order to assign a single stride for the summation expression, we need to
ensure that the parameters (sizes) involved in the expression are such that
the gap between them is always the same. Concretely, given a list of sizes
`[s0, s1, ..., s{n}]` ordered in descending order of associated strides, we
expect that each size `s{k}` is either:

1. `1` (and the corresponding stride is irrelevant);
2. `s{k} = upper_bound(s{k})`. Assume `s{k}` is the leftmost fully captured
   dimension. In that case, `for i in {0, ..., n-k-1}`, `s{k+i+1}` is allowed
   to be fully captured if `s{k+i}` is also fully captured. Otherwise,
   `s{k+i+1} = 1`. The resulting stride is the smallest stride associated
   with a fully captured dimension;
3. `1 < s{k} < upper_bound(s{k})`. In that case, `for i in {0, ..., k-1}`,
   `s{i} = 1`. `s{k+1}` is allowed to be fully captured (and thus the
   leftmost fully captured dimension), in which case we do as in `2.`. If
   `s{k+1}` is not fully captured, then `for i in {k+1, ..., n}`, `s{i} = 1`,
   and the stride of the expression is the stride associated with `s{k}`.

As a regex-like summary, we expect the sizes to be as follows in row-major
order (i.e. strictly decreasing order of strides):

`(1*, partial_dim?, full_dims*, 1*)`

and construct constraints as such.
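
To make the rule concrete, here is a minimal Python sketch of the pattern check and the resulting stride. It is an illustration only, not the C++ code in this change: the names `classify`, `matches_pattern` and `resulting_stride` are hypothetical, sizes are assumed to be concrete integers rather than symbolic parameters, and a size of `1` is assumed to count as a unit dimension even when its upper bound is also `1`.

```python
import re
from typing import Optional, Sequence

# Regex form of (1*, partial_dim?, full_dims*, 1*) over one letter per dimension:
# "1" = unit size, "P" = partially captured, "F" = fully captured.
_PATTERN = re.compile(r"1*P?F*1*")


def classify(size: int, bound: int) -> str:
    """Maps a (size, upper bound) pair to "1", "P" or "F" (hypothetical helper)."""
    if size == 1:
        return "1"  # Assumption: size 1 is a unit dimension even if bound == 1.
    if size == bound:
        return "F"
    if 1 < size < bound:
        return "P"
    raise ValueError(f"size {size} is outside [1, {bound}]")


def matches_pattern(sizes: Sequence[int], bounds: Sequence[int]) -> bool:
    """Checks the sizes, given in descending order of stride, against the pattern."""
    kinds = "".join(classify(s, b) for s, b in zip(sizes, bounds))
    return _PATTERN.fullmatch(kinds) is not None


def resulting_stride(
    sizes: Sequence[int], bounds: Sequence[int], strides: Sequence[int]
) -> Optional[int]:
    """Returns the single stride of the summation.

    Returns None if the pattern does not match, or if every size is 1
    (in which case the stride is irrelevant).
    """
    if not matches_pattern(sizes, bounds):
        return None
    kinds = [classify(s, b) for s, b in zip(sizes, bounds)]
    full = [st for kind, st in zip(kinds, strides) if kind == "F"]
    if full:
        # Smallest stride associated with a fully captured dimension.
        return min(full)
    partial = [st for kind, st in zip(kinds, strides) if kind == "P"]
    return partial[0] if partial else None


# Row-major strides for a hypothetical shape [4, 8, 16].
BOUNDS = [4, 8, 16]
STRIDES = [128, 16, 1]

print(resulting_stride([2, 8, 1], BOUNDS, STRIDES))   # partial, full, 1
print(resulting_stride([2, 3, 16], BOUNDS, STRIDES))  # two partial dimensions
```

With these inputs, `[2, 8, 1]` covers offsets `0, 16, ..., 240`, so a single stride of `16` describes it, while `[2, 3, 16]` has two partially captured dimensions and no single stride exists.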