[XLA:GPU] Construct constraints for destructured summations in SymbolicTile derivation. #70421

Merged
merged 1 commit into from
Jun 28, 2024

Commits on Jun 28, 2024

  1. [XLA:GPU] Construct constraints for destructured summations in `SymbolicTile` derivation.
    
    Take the summation expression
    
      `expr = sum(map(lambda [size, stride]: stride * size, sizes_and_strides))`.
    
    To assign a single stride to the summation expression, we need to ensure
    that the parameters (sizes) involved in the expression are such that the
    gap between consecutive elements it addresses is always the same.
    Concretely, given a list of sizes `[s0, s1, ..., s{n}]` sorted in descending
    order of their associated strides, we expect that each size `s{k}` is
    either:
    
    1. `1` (and the corresponding stride is irrelevant);
    2. fully captured---i.e. `s{k} = upper_bound(s{k})`. Assume `s{k}` is the
       leftmost fully captured dimension. In that case,
       `for i in {0, ..., n-k-1}`, `s{k+i+1}` is allowed to be fully captured if
       `s{k+i}` is also fully captured.  Otherwise, `s{k+i+1} = 1`. The resulting
       stride is the smallest stride associated with a fully captured
       dimension;
    3. partially captured---i.e. `1 < s{k} < upper_bound(s{k})`. In that case,
       `for i in {0, ..., k-1}`, `s{i} = 1`. `s{k+1}` is allowed to be fully
       captured (and thus the leftmost fully captured dimension), in which case
       we do as in `2.`. If `s{k+1}` is not fully captured, then
       `for i in {k+1, ..., n}`, `s{i} = 1`, and the stride of the expression is
       the stride associated with `s{k}`.
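
The three rules above can be sketched on concrete integers. This is an illustrative Python sketch only, not the XLA code (which is C++ and works on symbolic expressions): the function name `summation_stride` is hypothetical, and returning `None` when every size is `1` is an assumption, since the text only says the strides are then irrelevant.

```python
from typing import Optional, Sequence


def summation_stride(
    sizes: Sequence[int],
    upper_bounds: Sequence[int],
    strides: Sequence[int],
) -> Optional[int]:
    """Returns the single stride of the summation under the rules above.

    All three sequences are ordered by descending stride. Raises ValueError
    if the sizes violate the rules. Returns None if every size is 1
    (assumption: no meaningful stride exists in that case).
    """
    non_unit = [i for i, size in enumerate(sizes) if size != 1]
    if not non_unit:
        return None

    # Every non-unit size must lie in (1, upper_bound].
    for i in non_unit:
        if not 1 < sizes[i] <= upper_bounds[i]:
            raise ValueError(f"size {sizes[i]} at index {i} is out of range")

    # Non-unit sizes must form one contiguous run: once a 1 follows a
    # captured dimension, everything to its right must be 1 as well.
    first, last = non_unit[0], non_unit[-1]
    if last - first + 1 != len(non_unit):
        raise ValueError("captured dimensions are not contiguous")

    # Only the leftmost non-unit size may be partially captured.
    for i in non_unit[1:]:
        if sizes[i] != upper_bounds[i]:
            raise ValueError(f"size at index {i} must be fully captured")

    fully_captured = [i for i in non_unit if sizes[i] == upper_bounds[i]]
    if fully_captured:
        # Rule 2: smallest stride among the fully captured dimensions.
        return min(strides[i] for i in fully_captured)
    # Rule 3 with no fully captured dimension: stride of the partial one.
    return strides[first]
```

For example, with upper bounds `[4, 5, 8, 2]` and strides `[80, 16, 2, 1]`, the sizes `[1, 3, 8, 1]` (a partial dimension followed by a full one) yield stride `2`, while `[2, 1, 8, 1]` is rejected because a `1` separates two captured dimensions.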
    
    As a regex-like summary, we expect the sizes to be as follows in row-major
    order (i.e. strictly decreasing order of strides):
    
      `(1*, partial_dim?, full_dims*, 1*)`
    
    and construct constraints as such.
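
The regex-like summary can be taken literally as a quick sanity check. The sketch below is illustrative Python with hypothetical names (`size_pattern`, `matches_expected_layout`). It encodes each size as one character (`1` for a unit size, `p` for partially captured, `f` for fully captured) and matches the result against `1*p?f*1*`. It assumes concrete integer sizes, whereas the commit constructs constraints on symbolic ones.

```python
import re
from typing import Sequence

# One character per dimension: "1" = size 1, "p" = partial, "f" = full.
_LAYOUT = re.compile(r"1*p?f*1*")


def size_pattern(sizes: Sequence[int], upper_bounds: Sequence[int]) -> str:
    """Encodes sizes (ordered by descending stride) as a string of 1/p/f."""
    chars = []
    for size, bound in zip(sizes, upper_bounds):
        if size == 1:
            chars.append("1")
        elif size == bound:
            chars.append("f")
        elif 1 < size < bound:
            chars.append("p")
        else:
            raise ValueError(f"size {size} is outside [1, {bound}]")
    return "".join(chars)


def matches_expected_layout(
    sizes: Sequence[int], upper_bounds: Sequence[int]
) -> bool:
    """True if the sizes follow (1*, partial_dim?, full_dims*, 1*)."""
    return _LAYOUT.fullmatch(size_pattern(sizes, upper_bounds)) is not None
```

With upper bounds `[4, 5, 8, 2]`, the sizes `[1, 3, 8, 1]` encode to `1pf1` and match, whereas `[1, 5, 3, 1]` encodes to `1fp1` and does not, because a partial dimension may not follow a full one.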
    
    PiperOrigin-RevId: 647638017
    bchetioui authored and tensorflower-gardener committed Jun 28, 2024
    30e6caf