[XLA:GPU] Unroll small column reductions less, and use less shmem.
A column reduction has an unroll factor called `num_partial_results`.  I think
this is a misnomer -- I observe that with num_partial_results == 4, my kernel
produces four *complete* (i.e. not partial) results per warp.

This patch makes two changes.

1. We now unroll small column reductions less.  Previously, we could get into a
situation where, due to unrolling, we didn't produce enough blocks to saturate
the GPU.

2. We use a new codegen strategy for column reductions that uses less shared
memory when the column reduction is unrolled.  Previously we used a chunk of
Nx33x32 elements where N is the unroll factor.  But actually only one 33x32
block is live at a time, so this is N times larger than necessary.

If we don't do (1) before doing (2), XLA takes advantage of the additional
available shmem and unrolls small reductions even more, causing performance
regressions!

PiperOrigin-RevId: 538403704
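
To make change (2) concrete, here is an editorial CUDA sketch of a column
reduction that stages its transpose through a single padded 32x33 shared-memory
tile.  It is not part of this commit and not XLA's generated code; the kernel
name, the plain-store epilogue, and the assumed launch shape (32x32 threads per
block) are illustration-only.  The point it shows is the one the patch relies
on: only one tile is live at a time, so a __syncthreads() between reuses
replaces the old N-tile allocation.

  // Editorial sketch (not XLA codegen): one padded 32x33 shared-memory tile,
  // reused for every unrolled result, with a barrier between reuses.
  __global__ void column_reduce_sketch(const float* in, float* out,
                                       int rows, int cols, int unroll) {
    constexpr int kTile = 32;
    // "+1" pads each row so transposed accesses hit distinct banks.
    __shared__ float cache[kTile][kTile + 1];

    int tx = threadIdx.x;  // 0..31, lane within a warp
    int ty = threadIdx.y;  // 0..31, warp index within the block

    for (int r = 0; r < unroll; ++r) {
      int col_base = (blockIdx.x * unroll + r) * kTile;
      float acc = 0.0f;
      for (int row = ty; row < rows; row += kTile) {
        if (col_base + tx < cols) acc += in[row * cols + col_base + tx];
      }

      // Wait for reads of the previous iteration's tile before overwriting it;
      // this is what lets one tile (instead of `unroll` tiles) suffice.
      if (r > 0) __syncthreads();
      cache[tx][ty] = acc;
      __syncthreads();
      float v = cache[ty][tx];  // warp `ty` now holds all partials of one column

      // Warp-reduce: lane 0 ends up with one *complete* result per warp.
      for (int off = 16; off > 0; off /= 2) {
        v += __shfl_down_sync(0xffffffffu, v, off);
      }
      if (tx == 0 && col_base + ty < cols) {
        out[col_base + ty] = v;  // the real emitter may use atomics instead
      }
    }
  }

With the padded shape, the write cache[tx][ty] strides by 33 words, so a warp's
32 lanes land in 32 distinct banks; that is what the "+1" in the emitter's
{num_threads_x, num_threads_x + 1} allocation buys.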
Justin Lebar authored and tensorflower-gardener committed Jun 7, 2023
1 parent d27dc25 commit 42ea7ad
Showing 2 changed files with 116 additions and 52 deletions.
136 changes: 100 additions & 36 deletions tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -3902,14 +3902,16 @@ ReductionCodegenState IrEmitterUnnested::GenerateReductionCodegenState(
"shared_cache");
} else {
// Allocate __shared__
// cache[num_partial_results][num_threads][num_threads + 1], where
// cache[num_threads][num_threads + 1], where
// num_threads == num_threads_x == num_threads_y. The "+1" is used to
// avoid bank conflicts.
//
// (Although each thread produces num_partial_results results, we
// don't need that much cache: Only one result is live at a time.)
CHECK_EQ(num_threads_x, tiling_scheme.GetNumThreadsFor(kDimY));
return AllocateShared(
tiling_scheme, element_type,
{num_partial_results, num_threads_x, num_threads_x + 1},
"shared_cache");
return AllocateShared(tiling_scheme, element_type,
{num_threads_x, num_threads_x + 1},
"shared_cache");
}
}();

@@ -4258,6 +4260,13 @@ void IrEmitterUnnested::EmitReductionOutputForColumnReduction(
const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme();
int num_outputs = reducer->num_parameters() / 2;

// Wait for reads from shmem in the last iteration to complete. (If this is
// slow, we could "double-buffer" by having two shmem buffers and switching
// between them.)
if (partial_result_idx > 0) {
EmitSyncThreads();
}

// Store the transpose in shared memory.
for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
const ReductionCodegenState::ReductionCalculationState& state =
@@ -4266,8 +4275,7 @@ void IrEmitterUnnested::EmitReductionOutputForColumnReduction(
llvm::AddrSpaceCastInst* shmem_output_addr =
llvm::cast<llvm::AddrSpaceCastInst>(thread_id_info.GEPIntoSharedMemory(
&b_, shared_cache,
{constant(partial_result_idx), thread_id_info.thread_id_x,
thread_id_info.thread_id_y},
{thread_id_info.thread_id_x, thread_id_info.thread_id_y},
"shmem_output_address"));
llvm::Value* current_output =
InBoundsGEP(state.partial_result_address->getAllocatedType(),
@@ -4289,8 +4297,7 @@ void IrEmitterUnnested::EmitReductionOutputForColumnReduction(
llvm::AddrSpaceCastInst* shmem_transposed_addr =
llvm::cast<llvm::AddrSpaceCastInst>(thread_id_info.GEPIntoSharedMemory(
&b_, state.shared_cache,
{constant(partial_result_idx), thread_id_info.thread_id_y,
thread_id_info.thread_id_x},
{thread_id_info.thread_id_y, thread_id_info.thread_id_x},
"shmem_transposed_addr"));
shmem_transposed_addrs.push_back(
{shmem_transposed_addr, llvm::cast<llvm::GetElementPtrInst>(
@@ -4705,17 +4712,45 @@ int64_t NumInputsWithMoreElementsThan(mlir::lmhlo::FusionOp fusion,
});
}

// If the reduce is relatively small, it's possible to unroll so much that we
// don't have enough blocks to saturate the GPU. This function computes the max
// number of times we can unroll the reduction while still saturating the GPU.
static int64_t MaxBeneficialColumnReductionUnrollBasedOnBlockSize(
const GpuDeviceInfo& gpu_info, HloComputation* fused_computation) {
int64_t num_reduce_output_elems = 0;
for (const HloInstruction* root : GetFusionRoots(fused_computation)) {
if (!IsReductionFromOrToContiguousDimensions(*root)) {
continue;
}
const Shape* output_shape = &root->shape();
// Unwrap multi-output reduction. All outputs should be the same shape.
if (output_shape->IsTuple()) {
output_shape = &output_shape->tuple_shapes()[0];
}
num_reduce_output_elems =
std::max(num_reduce_output_elems, ShapeUtil::ElementsIn(*output_shape));
}

// A column reduction that's unrolled N times uses one warp to generate N
// output elements. The block size is always 32 warps = 1024 threads.
int64_t num_blocks = CeilOfRatio(num_reduce_output_elems, int64_t{32});
int64_t num_threads = num_blocks * 1024;
// Number of SMs we can saturate with this work.
int num_cores =
CeilOfRatio<int64_t>(num_threads, gpu_info.threads_per_core_limit);
return static_cast<int>(CeilOfRatio(num_cores, gpu_info.core_count));
}
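
To sanity-check the heuristic above, the standalone snippet below (an editorial
illustration, not part of this commit) replays the same arithmetic with
assumed, roughly A100-like limits; MaxBeneficialUnroll, threads_per_sm, and sms
are hypothetical names and values standing in for gpu_info.threads_per_core_limit
and gpu_info.core_count.  With 512 output elements the result is 1, so
unrolling is rejected; 65536 outputs leave room to unroll, which is presumably
why the test shapes further down grow a third dimension.

  #include <cstdint>
  #include <iostream>

  // Restatement of the heuristic above with assumed device numbers.
  int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

  int MaxBeneficialUnroll(int64_t num_reduce_output_elems,
                          int64_t threads_per_core_limit, int64_t core_count) {
    // One warp per output element; a block is 32 warps = 1024 threads.
    int64_t num_blocks = CeilOfRatio(num_reduce_output_elems, 32);
    int64_t num_threads = num_blocks * 1024;
    int64_t num_cores = CeilOfRatio(num_threads, threads_per_core_limit);
    return static_cast<int>(CeilOfRatio(num_cores, core_count));
  }

  int main() {
    // Assumed, A100-like limits: 2048 resident threads per SM, 108 SMs.
    const int64_t threads_per_sm = 2048, sms = 108;
    // 512 outputs (e.g. f32[1024,512] reduced over dim 0):
    //   16 blocks -> 16384 threads -> 8 SMs -> prints 1, so don't unroll.
    std::cout << MaxBeneficialUnroll(512, threads_per_sm, sms) << "\n";
    // 65536 outputs (e.g. f32[1024,512,128] reduced over dim 0):
    //   2048 blocks -> ~2M threads -> prints 10 with these limits.
    std::cout << MaxBeneficialUnroll(65536, threads_per_sm, sms) << "\n";
  }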

// The benefit of unrolling a kInput fusion that is a column reduction comes
// from the vectorization of non-reduction fusion outputs and fusion inputs.
// On the other hand, unrolling can also introduce factors that can cause
// the kernel to run slower. This routine uses a simple heuristic to estimate
// the benefit as well as the overhead of unrolling in order to decide whether
// unrolling is beneficial for the given kInput fusion.
bool IsUnrollingColumnReductionBeneficial(mlir::lmhlo::FusionOp fusion,
HloComputation* fused_computation,
const Shape& input_shape,
int64_t num_kept_minor,
bool reduction_is_race_free) {
static bool IsUnrollingColumnReductionBeneficial(
const GpuDeviceInfo& gpu_info, mlir::lmhlo::FusionOp fusion,
HloComputation* fused_computation, const Shape& input_shape,
int64_t num_kept_minor, bool reduction_is_race_free) {
if (num_kept_minor % (WarpSize() * 2) != 0) {
return false;
}
@@ -4750,7 +4785,12 @@ bool IsUnrollingColumnReductionBeneficial(mlir::lmhlo::FusionOp fusion,
// unrolled even with such an assumption, and the accesses to those inputs
// turn out to be vectorizable, the compiler will still vectorize them.
cannot_be_vectorized += NumInputsWithMoreElementsThan(fusion, input_shape);
return can_be_vectorized >= cannot_be_vectorized;
if (can_be_vectorized < cannot_be_vectorized) {
return false;
}

return MaxBeneficialColumnReductionUnrollBasedOnBlockSize(
gpu_info, fused_computation) > 1;
}

int64_t NearestPowerOfTwo(int64_t v) {
@@ -4844,16 +4884,30 @@ llvm::GlobalVariable* IrEmitterUnnested::AllocateShared(
array_type, buffer_name);
}

// Returns the size of the dtype of the smallest "nontrivial" input to this
// fusion.
//
// We use this as part of the heuristic of choosing our fusion unroll factor.
// Fusions which read smaller elements (e.g. i8) should unroll more, so that
// they execute larger coalesced loads.
static int SmallestInputDtypeBits(mlir::lmhlo::FusionOp fusion) {
int bits = std::numeric_limits<int>::max();
for (mlir::Value operand : fusion.getInputBuffers()) {
bits = std::min(GetPrimitiveBitwidth(operand), bits);
}
return bits;
}
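
The comment above is the motivation for looking at the smallest input dtype: a
fusion that reads narrow elements wants a larger unroll so each thread issues a
wider, still-coalesced load.  The fragment below is only an editorial
illustration of that effect (not XLA-generated code; the kernel names are made
up): a scalar f16 load moves 64 bytes per warp per transaction, while a
__half2 load per thread restores the full 128 bytes.

  #include <cuda_fp16.h>

  __global__ void scalar_half_loads(const __half* in, float* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = __half2float(in[i]);  // 2 bytes per thread
  }

  __global__ void vector_half_loads(const __half2* in, float* out, int n2) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n2) {
      float2 v = __half22float2(in[i]);       // 4 bytes per thread
      out[2 * i] = v.x;
      out[2 * i + 1] = v.y;
    }
  }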

// Whether the reduction can be vectorized.
static bool CanVectorizeReduction(
se::CudaComputeCapability cc, mlir::lmhlo::FusionOp fusion,
HloComputation* fused_computation,
se::CudaComputeCapability cc, const GpuDeviceInfo& gpu_info,
mlir::lmhlo::FusionOp fusion, HloComputation* fused_computation,
const ReductionDimensions& reduction_dimensions, int num_threads_x,
Vector3 reduction_tiling, const Shape& input_shape,
bool reduction_is_race_free) {
if (!reduction_dimensions.is_row_reduction) {
return IsUnrollingColumnReductionBeneficial(
fusion, fused_computation, input_shape,
gpu_info, fusion, fused_computation, input_shape,
reduction_dimensions.dimensions[kDimX], reduction_is_race_free);
}

@@ -4872,13 +4926,8 @@ static bool CanVectorizeReduction(
return true;
}

int smallest_input_dtype_bits = std::numeric_limits<int>::max();
for (mlir::Value operand : fusion.getInputBuffers()) {
smallest_input_dtype_bits =
std::min(GetPrimitiveBitwidth(operand), smallest_input_dtype_bits);
}
if (cc.IsAtLeast(se::CudaComputeCapability::PASCAL_)) {
return smallest_input_dtype_bits <= 32 &&
return SmallestInputDtypeBits(fusion) <= 32 &&
reduction_dimensions.dimensions[kDimX] %
(reduction_tiling[2] * num_threads_x) ==
0;
@@ -4943,12 +4992,6 @@ StatusOr<ReductionCodegenInfo> IrEmitterUnnested::ComputeReductionCodegenInfo(

se::CudaComputeCapability cc = ir_emitter_context_->cuda_compute_capability();

int smallest_input_dtype_bits = std::numeric_limits<int>::max();
for (mlir::Value operand : fusion.getInputBuffers()) {
smallest_input_dtype_bits =
std::min(GetPrimitiveBitwidth(operand), smallest_input_dtype_bits);
}

TilingScheme::IndexingOrder indexing_order =
reduction_dimensions.is_row_reduction ? kStridedIndexingX
: kLinearIndexingX;
@@ -4960,13 +5003,15 @@ StatusOr<ReductionCodegenInfo> IrEmitterUnnested::ComputeReductionCodegenInfo(
bool vectorize =
// Vectorization might cause us to run out of budget.
(shmem_usage * 2 <= shmem_budget) &&
CanVectorizeReduction(cc, fusion, fused_computation, reduction_dimensions,
CanVectorizeReduction(cc, ir_emitter_context_->gpu_device_info(), fusion,
fused_computation, reduction_dimensions,
num_threads_x, reduction_tiling, input_shape,
reduction_is_race_free);
int vector_size = vectorize ? 2 : 1;

int num_partial_results = 1;
if (!reduction_dimensions.is_row_reduction && vectorize) {
int smallest_input_dtype_bits = SmallestInputDtypeBits(fusion);
if (smallest_input_dtype_bits <= 32) {
// Make sure to use all the data read at once.
// Instead of hardcoding the granularity, we can query the granularity we
@@ -4986,12 +5031,31 @@ StatusOr<ReductionCodegenInfo> IrEmitterUnnested::ComputeReductionCodegenInfo(
} else {
num_partial_results = 2;
}
}

while (shmem_usage * num_partial_results > shmem_budget) {
num_partial_results /= 2;
if (num_partial_results == 1) {
break;
// Take into account MaxBeneficialColumnReductionUnrollBasedOnBlockSize.
// (We can't go below 2 for the unroll factor -- if we wanted to use 1 as
// the unroll factor, we should have set this reduction as unvectorized.)
num_partial_results = std::clamp<int>(
2, //
num_partial_results,
MaxBeneficialColumnReductionUnrollBasedOnBlockSize(
ir_emitter_context_->gpu_device_info(), fused_computation));
}

// TODO(b/283542954): Autotune num_partial_results? This can make a big
// difference, e.g. by affecting register spilling.

// Row reductions use one shmem block per partial result, so we have to make
// sure we fit in budget. Column reductions only ever use one shmem block.
// (Indeed I *think* "num_partial_results" is a misnomer for column
// reductions; I think it's the number of *complete*, i.e. not partial,
// results per warp.)
if (reduction_dimensions.is_row_reduction) {
while (shmem_usage * num_partial_results > shmem_budget) {
num_partial_results /= 2;
if (num_partial_results == 1) {
break;
}
}
}

32 changes: 16 additions & 16 deletions tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
@@ -291,9 +291,9 @@ TEST_F(GpuKernelTilingTest, ColumnReductionWithPowerOf2OutputElementsUnrolled) {
ENTRY kernel_entry {
constant0 = f32[] constant(0)
arg1 = f16[1024,512]{1,0} parameter(0)
arg1_conv = f32[1024,512]{1,0} convert(arg1)
ROOT reduce = f32[512]{0} reduce(arg1_conv, constant0), dimensions={0}, to_apply=reduction
arg1 = f16[1024,512,128]{2,1,0} parameter(0)
arg1_conv = f32[1024,512,128]{2,1,0} convert(arg1)
ROOT reduce = f32[512,128]{1,0} reduce(arg1_conv, constant0), dimensions={0}, to_apply=reduction
})";

// Check that two calls to llvm.nvvm.atomic are generated.
@@ -322,23 +322,23 @@ TEST_F(GpuKernelTilingTest, ColumnReductionMOFUnrolled) {
fused_computation {
constant0 = f32[] constant(0)
arg.1 = f16[1024,512]{1,0} parameter(0)
arg.2 = f16[1024,512]{1,0} parameter(1)
arg1.conv = f32[1024,512]{1,0} convert(arg.1)
arg2.conv = f32[1024,512]{1,0} convert(arg.2)
reduce1 = f32[512]{0} reduce(arg1.conv, constant0), dimensions={0},
arg.1 = f16[1024,512,128]{2,1,0} parameter(0)
arg.2 = f16[1024,512,128]{2,1,0} parameter(1)
arg1.conv = f32[1024,512,128]{2,1,0} convert(arg.1)
arg2.conv = f32[1024,512,128]{2,1,0} convert(arg.2)
reduce1 = f32[512,128]{1,0} reduce(arg1.conv, constant0), dimensions={0},
to_apply=reduction22
reduce2 = f32[512]{0} reduce(arg2.conv, constant0), dimensions={0},
reduce2 = f32[512,128]{1,0} reduce(arg2.conv, constant0), dimensions={0},
to_apply=reduction22
add = f32[1024,512]{1,0} add(arg1.conv, arg2.conv)
ROOT tuple = (f32[512]{0}, f32[512]{0}, f32[1024,512]{1,0})
add = f32[1024,512,128]{2,1,0} add(arg1.conv, arg2.conv)
ROOT tuple = (f32[512,128]{1,0}, f32[512,128]{1,0}, f32[1024,512,128]{2,1,0})
tuple(reduce1, reduce2, add)
}
ENTRY kernel_entry {
arg1 = f16[1024,512]{1,0} parameter(0)
arg2 = f16[1024,512]{1,0} parameter(1)
ROOT fusion = (f32[512]{0}, f32[512]{0}, f32[1024,512]{1,0})
arg1 = f16[1024,512,128]{2,1,0} parameter(0)
arg2 = f16[1024,512,128]{2,1,0} parameter(1)
ROOT fusion = (f32[512,128]{1,0}, f32[512,128]{1,0}, f32[1024,512,128]{2,1,0})
fusion(arg1, arg2), kind=kInput, calls=fused_computation
})";

@@ -731,8 +731,8 @@ reduction {
ENTRY kernel_entry {
constant0 = f32[] constant(0)
arg1 = f32[1024,512]{1,0} parameter(0)
ROOT reduce = f32[512]{0} reduce(arg1, constant0), dimensions={0}, to_apply=reduction
arg1 = f32[1024,512,128]{2,1,0} parameter(0)
ROOT reduce = f32[512,128]{1,0} reduce(arg1, constant0), dimensions={0}, to_apply=reduction
}
)";
auto expected_ir = R"(
