PadThunk implementation in new API #70874

Draft · wants to merge 1 commit into base: master
PadThunk implementation in new API
PiperOrigin-RevId: 643030521
tvladyslav authored and tensorflower-gardener committed Jul 4, 2024
commit cf5d48260db087025d0f9c5857a8940c3077228c
3 changes: 1 addition & 2 deletions third_party/xla/xla/service/cpu/BUILD
@@ -623,6 +623,7 @@ cc_library(
hdrs = ["ir_emitter2.h"],
deps = [
":backend_config_proto_cc",
":dot_op_emitter",
":elemental_math_emitter",
":ir_emitter",
":parallel_loop_emitter",
@@ -634,13 +635,11 @@ cc_library(
"//xla/hlo/ir:hlo",
"//xla/service:buffer_assignment",
"//xla/service:elemental_ir_emitter",
"//xla/service/cpu:dot_op_emitter",
"//xla/service/llvm_ir:dynamic_update_slice_util",
"//xla/service/llvm_ir:fused_ir_emitter",
"//xla/service/llvm_ir:ir_array",
"//xla/service/llvm_ir:llvm_util",
"//xla/service/llvm_ir:loop_emitter",
"//xla/service/llvm_ir:tuple_ops",
"//xla/stream_executor:launch_dim",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
16 changes: 16 additions & 0 deletions third_party/xla/xla/service/cpu/benchmarks/BUILD
@@ -85,6 +85,22 @@ xla_cc_test(
],
)

xla_cc_test(
name = "pad_benchmark_test",
srcs = ["pad_benchmark_test.cc"],
deps = [
":hlo_benchmark_runner",
"//xla:literal",
"//xla:literal_util",
"//xla:shape_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:test_benchmark",
"@local_tsl//tsl/platform:test_main",
],
)

xla_cc_test(
name = "fusion_benchmark_test",
srcs = ["fusion_benchmark_test.cc"],
66 changes: 66 additions & 0 deletions third_party/xla/xla/service/cpu/benchmarks/pad_benchmark_test.cc
@@ -0,0 +1,66 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <random>
#include <string_view>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h"
#include "xla/shape_util.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/test_benchmark.h"

namespace xla::cpu {

static void BM_PadF32(benchmark::State& state) {
int64_t d0 = state.range(0);

std::string_view hlo = R"(
HloModule pad_f32_$d0

ENTRY e {
input = f32[1,4,$d0,$d0,4] parameter(0)
value = f32[] parameter(1)
ROOT pad = pad(input, value), padding=0_0_0x0_-1_0x0_-1_0x-2_-2_0x-1_-1_3
}
)";

std::minstd_rand0 engine;

auto input_shape = ShapeUtil::MakeShape(F32, {1, 4, d0, d0, 4});
auto value_shape = ShapeUtil::MakeShape(F32, {});
auto p0 =
*LiteralUtil::CreateRandomLiteral<F32>(input_shape, &engine, 1.0f, 0.1f);
auto p1 =
*LiteralUtil::CreateRandomLiteral<F32>(value_shape, &engine, 1.0f, 0.1f);

std::vector<const Literal*> args = {&p0, &p1};
CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}));
}

BENCHMARK(BM_PadF32)
->MeasureProcessCPUTime()
->Arg(128)
->Arg(256)
->Arg(512)
->Arg(1024)
->Arg(8192);

} // namespace xla::cpu
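Note (not part of the diff): each x-separated group in the padding string is a low_high_interior triple, and the padded output dimension follows XLA's pad shape rule, output = edge_low + edge_high + input + max(input - 1, 0) * interior. The negative values here shrink dimensions, which is exactly the case the emitter below rejects and expects the algebraic simplifier to canonicalize away first. A minimal standalone sketch of that arithmetic for d0 = 128:

// pad_shape_sketch.cc (illustrative only, not part of this PR): computes the
// padded output shape for the benchmark HLO using XLA's pad shape rule.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

namespace {

// One "low_high_interior" triple from the HLO padding string.
struct PadDim {
  int64_t low;
  int64_t high;
  int64_t interior;
};

int64_t PaddedDim(int64_t input, const PadDim& p) {
  // Interior padding inserts `interior` elements between adjacent inputs.
  return p.low + p.high + input + std::max<int64_t>(input - 1, 0) * p.interior;
}

}  // namespace

int main() {
  // f32[1,4,128,128,4] padded with 0_0_0x0_-1_0x0_-1_0x-2_-2_0x-1_-1_3.
  const std::vector<int64_t> dims = {1, 4, 128, 128, 4};
  const std::vector<PadDim> pads = {
      {0, 0, 0}, {0, -1, 0}, {0, -1, 0}, {-2, -2, 0}, {-1, -1, 3}};
  for (size_t i = 0; i < dims.size(); ++i) {
    // Prints 1->1, 4->3, 128->127, 128->124, 4->11, i.e. f32[1,3,127,124,11].
    std::printf("dim %zu: %lld -> %lld\n", i,
                static_cast<long long>(dims[i]),
                static_cast<long long>(PaddedDim(dims[i], pads[i])));
  }
  return 0;
}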
73 changes: 73 additions & 0 deletions third_party/xla/xla/service/cpu/ir_emitter.cc
@@ -20,6 +20,7 @@ limitations under the License.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <limits>
#include <map>
@@ -46,6 +47,7 @@ limitations under the License.
#include "llvm/IR/Constants.h"
#include "llvm/IR/FMF.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsX86.h"
@@ -2336,6 +2338,77 @@ absl::Status IrEmitter::HandlePad(HloInstruction* pad) {
return absl::OkStatus();
}

absl::Status IrEmitter::HandlePad(llvm::LLVMContext& context,
HloInstruction* pad,
llvm::Function* kernel_function,
llvm::IRBuilder<>* b,
const llvm_ir::IrArray& operand_array,
const llvm_ir::IrArray& padding_value_array,
const llvm_ir::IrArray& output_array) {
CHECK_EQ(pad->operand_count(), 2);

// The CPU backend does not properly handle negative padding, but this is OK
// because negative padding should have been removed by the algebraic
// simplifier.
for (auto& padding_dimension : pad->padding_config().dimensions()) {
if (padding_dimension.edge_padding_low() < 0 ||
padding_dimension.edge_padding_high() < 0) {
return InternalStrCat(
"Encountered negative padding in IrEmitter on CPU. "
"This should have been eliminated at the HLO level. ",
pad->ToString());
}
}

const HloInstruction* padding_value = pad->operand(1); // TODO: just Shape?
const auto index_type = b->getInt64Ty();
const auto index = llvm_ir::IrArray::Index(index_type);
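// The padding value is a scalar (rank-0 shape), so its address is computed
// once with an empty index; the fill loop below reloads it for every output
// element.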
llvm::Value* padding_value_addr = padding_value_array.EmitArrayElementAddress(
index, b, "padding_value_addr", true, nullptr);
const llvm_ir::ElementGenerator element_generator =
[this, b, padding_value,
padding_value_addr](const llvm_ir::IrArray::Index& target_index) {
return b->CreateLoad(IrShapeType(padding_value->shape()),
padding_value_addr);
};

// First, fill in the padding value to all output elements.
auto le = llvm_ir::LoopEmitter(element_generator, output_array, b);
TF_RETURN_IF_ERROR(le.EmitLoop(IrName(pad, "initialize"), index_type));

// Create a loop to iterate over the operand elements and update the output
// locations where the operand elements should be stored.
llvm_ir::ForLoopNest loops(IrName(pad, "assign"), b);
const llvm_ir::IrArray::Index operand_index =
loops.AddLoopsForShape(pad->operand(0)->shape(), "operand");

SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b);

// Load an element from the operand.
llvm::Value* operand_data =
operand_array.EmitReadArrayElement(operand_index, b);

// Compute the output index the operand element should be assigned to.
// output_index := edge_padding_low + operand_index * (interior_padding + 1)
const PaddingConfig& padding_config = pad->padding_config();
std::vector<llvm::Value*> output_multi_index;
for (size_t i = 0; i < operand_index.size(); ++i) {
llvm::Value* offset = b->CreateMul(
operand_index[i],
b->getInt64(padding_config.dimensions(i).interior_padding() + 1));
llvm::Value* index = b->CreateAdd(
offset, b->getInt64(padding_config.dimensions(i).edge_padding_low()));
output_multi_index.push_back(index);
}

// Store the operand element to the computed output location.
llvm_ir::IrArray::Index output_index(
output_multi_index, output_array.GetShape(), operand_index.GetType());
output_array.EmitWriteArrayElement(output_index, operand_data, b);

SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b);
return absl::OkStatus();
}

absl::Status IrEmitter::HandleFusion(HloInstruction* fusion) {
auto* root = fusion->fused_expression_root();
if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) {
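Note (illustrative, not PR code): the scatter loop above writes each operand element to output_index := edge_padding_low + operand_index * (interior_padding + 1) per dimension; every output position it skips keeps the padding value written by the first loop. A minimal standalone sketch for one dimension:

// pad_index_sketch.cc (illustrative only): the per-dimension index mapping
// used by HandlePad's second loop.
#include <cstdint>
#include <cstdio>

int main() {
  // With edge_padding_low = 2 and interior_padding = 1, an input of 4
  // elements lands at output positions 2, 4, 6, 8; positions 0-1 (edge)
  // and 3, 5, 7 (interior) keep the padding value from the fill loop.
  const int64_t edge_padding_low = 2;
  const int64_t interior_padding = 1;
  for (int64_t operand_index = 0; operand_index < 4; ++operand_index) {
    const int64_t output_index =
        edge_padding_low + operand_index * (interior_padding + 1);
    std::printf("operand %lld -> output %lld\n",
                static_cast<long long>(operand_index),
                static_cast<long long>(output_index));
  }
  return 0;
}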
10 changes: 10 additions & 0 deletions third_party/xla/xla/service/cpu/ir_emitter.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <ostream>
#include <string>
#include <utility>
@@ -232,6 +233,11 @@ class IrEmitter : public DfsHloVisitorWithDefault,
const llvm_ir::IrArray& operand_array,
const llvm_ir::IrArray& source_array,
const llvm_ir::IrArray& output_array);
absl::Status HandlePad(llvm::LLVMContext& context, HloInstruction* pad,
llvm::Function* kernel_function, llvm::IRBuilder<>* b,
const llvm_ir::IrArray& operand_array,
const llvm_ir::IrArray& padding_value_array,
const llvm_ir::IrArray& output_array);

// A convenient helper for calling BufferAssignment::GetUniqueSlice.
BufferAllocation::Slice GetAllocationSlice(
@@ -378,6 +384,10 @@ absl::Status EmitTargetElementLoop(
absl::Status EmitTargetElementLoop(
HloInstruction* target_op, absl::string_view desc,
const llvm_ir::ElementGenerator& element_generator);
absl::Status EmitTargetElementLoop(
HloInstruction* target_op, absl::string_view desc,
const llvm_ir::ElementGenerator& element_generator,
llvm::Value* target_op_address, llvm_ir::IrArray target_array);

// Emits a memcpy from the source instruction's result value to the
// destination's. Both source and destination must have an entry in the
28 changes: 28 additions & 0 deletions third_party/xla/xla/service/cpu/ir_emitter2.cc
@@ -220,6 +220,34 @@ bool IrEmitter2::fast_min_max() const {
return hlo_module_.config().debug_options().xla_cpu_enable_fast_min_max();
}

absl::StatusOr<IrEmitter2::KernelInfo> IrEmitter2::EmitPadHostKernel(
const HloInstruction* pad) {
VLOG(2) << "Emit Pad host kernel.";

TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype,
EmitKernelPrototype(pad));

llvm_ir::IrArray operand_array = kernel_prototype.arguments[0];
llvm_ir::IrArray padvalue_array = kernel_prototype.arguments[1];
llvm_ir::IrArray output_array = kernel_prototype.results[0];

llvm::LLVMContext& ctx = module_->getContext();
llvm::IRBuilder<> b(ctx);
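// EmitKernelPrototype terminates the entry block with a return instruction;
// erase it so the pad loops can be emitted in its place, then re-emit the
// `ret ptr null` terminator below.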
llvm::BasicBlock& start_bb = kernel_prototype.function->getEntryBlock();
start_bb.getTerminator()->eraseFromParent();
b.SetInsertPoint(&start_bb);

TF_RETURN_IF_ERROR(nested_ir_emitter_->HandlePad(
ctx, const_cast<HloInstruction*>(pad), kernel_prototype.function, &b,
operand_array, padvalue_array, output_array));
b.CreateRet(
llvm::ConstantPointerNull::get(llvm::PointerType::getUnqual(ctx)));

return kernels_.emplace_back(
KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(),
se::ThreadDim()});
}

absl::StatusOr<IrEmitter2::KernelInfo> IrEmitter2::EmitElementalHostKernel(
const HloInstruction* instr) {
VLOG(2) << "Emit elemental host kernel: " << instr->name();
3 changes: 3 additions & 0 deletions third_party/xla/xla/service/cpu/ir_emitter2.h
@@ -112,6 +112,9 @@ class IrEmitter2 {
// Returns all the kernels emitted so far via this emitter.
absl::Span<const KernelInfo> kernels() const { return kernels_; }

// Emits a host kernel for the pad instruction.
absl::StatusOr<KernelInfo> EmitPadHostKernel(const HloInstruction* pad);

// Emits an elemental host kernel for the given HLO instruction.
absl::StatusOr<KernelInfo> EmitElementalHostKernel(
const HloInstruction* instr);
14 changes: 12 additions & 2 deletions third_party/xla/xla/service/cpu/thunk_emitter.cc
@@ -247,9 +247,8 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitHloInstruction(
case HloOpcode::kCollectivePermute:
return EmitCollectivePermuteThunk(instruction);

// TODO(ezhulenev): Port pad optimizations from IrEmitter.
case HloOpcode::kPad:
return EmitElementalKernelThunk(instruction);
return EmitPadKernelThunk(instruction);

// TODO(ezhulenev): Implement slice operations as separate Thunks because
// it's much easier to get peak performance from hand written code.
@@ -549,6 +548,17 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitCopyThunk(
instruction->shape());
}

absl::StatusOr<ThunkSequence> ThunkEmitter::EmitPadKernelThunk(
const HloInstruction* instruction) {
const HloPadInstruction* pad_instr = Cast<HloPadInstruction>(instruction);
TF_ASSIGN_OR_RETURN(auto kernel, ir_emitter_.EmitPadHostKernel(pad_instr));
TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(pad_instr));

return ThunkSequence::Of<KernelThunk>(
ThunkInfo(pad_instr), buffers.arguments, buffers.results, kernel.name,
kernel.thread_dims, /*min_alignment=*/cpu_function_runtime::MinAlign());
}

absl::StatusOr<ThunkSequence> ThunkEmitter::EmitElementalKernelThunk(
const HloInstruction* instruction) {
TF_ASSIGN_OR_RETURN(auto kernel,
3 changes: 3 additions & 0 deletions third_party/xla/xla/service/cpu/thunk_emitter.h
@@ -86,6 +86,9 @@ class ThunkEmitter {
absl::StatusOr<ThunkSequence> EmitCopyThunk(
const HloInstruction* instruction);

absl::StatusOr<ThunkSequence> EmitPadKernelThunk(
const HloInstruction* instruction);

absl::StatusOr<ThunkSequence> EmitElementalKernelThunk(
const HloInstruction* instruction);

1 change: 1 addition & 0 deletions third_party/xla/xla/tests/BUILD
@@ -1900,6 +1900,7 @@ xla_test(
xla_test(
name = "pad_test",
srcs = ["pad_test.cc"],
tags = ["test_xla_cpu_thunks"],
deps = [
":client_library_test_base",
":literal_test_util",