Update compatibility constraint in BUILD file.
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11832 from Intel-tensorflow:nhatle/mm-bias-sigmoid-fusion a96013ce31dcb32c82982c61ba2a7baafab97ed4
PiperOrigin-RevId: 626318760
tensorflower-gardener committed May 30, 2024
1 parent 6efeb06 commit 1ad51c4
Showing 13 changed files with 371 additions and 83 deletions.
1 change: 0 additions & 1 deletion third_party/xla/xla/service/BUILD
@@ -1377,7 +1377,6 @@ cc_library(

cc_library(
name = "cpu_plugin",
compatible_with = [],
deps = [
":service",
"//xla/service/cpu:cpu_compiler",
1 change: 1 addition & 0 deletions third_party/xla/xla/service/cpu/backend_config.proto
@@ -29,6 +29,7 @@ message OneDnnMatMulConfig {
LINEAR = 7;
ELU = 8;
RELU6 = 9;
SIGMOID = 10;
}
repeated FusionKind fused_ops = 3;
bool bias_broadcast = 4;
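
For context, a minimal sketch (not part of this commit) of how the new enum value could be set through the generated C++ proto API; the generated header path and the xla::cpu namespace are assumptions based on the proto's location.

#include "xla/service/cpu/backend_config.pb.h"  // assumed generated header

// Requests a bias post-op followed by the new sigmoid post-op.
xla::cpu::OneDnnMatMulConfig MakeBiasSigmoidConfig() {
  xla::cpu::OneDnnMatMulConfig config;
  config.add_fused_ops(xla::cpu::OneDnnMatMulConfig::BIAS);
  config.add_fused_ops(xla::cpu::OneDnnMatMulConfig::SIGMOID);
  return config;
}
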
3 changes: 3 additions & 0 deletions third_party/xla/xla/service/cpu/onednn_matmul.cc
@@ -141,6 +141,9 @@ std::unique_ptr<matmul::primitive_desc> CreateMatMulPrimDesc(
case OneDnnMatMulConfig::RELU6:
post_ops.append_eltwise(dnnl::algorithm::eltwise_clip_v2, 0.f, 6.0f);
break;
case OneDnnMatMulConfig::SIGMOID:
post_ops.append_eltwise(dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
break;
case OneDnnMatMulConfig::BIAS: {
bias_md = fused_mds.at(fused_operand_idx);
// Extend bias rank to match result rank.
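
For reference, a minimal sketch (not from this commit, assuming oneDNN v3.x) of how an eltwise_logistic post-op is attached to a matmul primitive descriptor; alpha and beta are ignored by eltwise_logistic, which is why 0.f, 0.f is passed above.

#include "dnnl.hpp"

dnnl::matmul::primitive_desc MakeSigmoidFusedMatMul(
    const dnnl::engine& engine, const dnnl::memory::desc& src_md,
    const dnnl::memory::desc& weights_md, const dnnl::memory::desc& dst_md) {
  dnnl::post_ops post_ops;
  post_ops.append_eltwise(dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
  dnnl::primitive_attr attr;
  attr.set_post_ops(post_ops);
  // The resulting primitive computes logistic(src * weights) in one pass.
  return dnnl::matmul::primitive_desc(engine, src_md, weights_md, dst_md, attr);
}
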
24 changes: 24 additions & 0 deletions third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc
@@ -762,6 +762,30 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor {
return absl::OkStatus();
}

// Matches the sigmoid pattern 1 / (1 + exp(-x)) rooted at 'instr'; on
// success, binds 'src' to x.
auto SigmoidActivation(HloInstruction* instr, HloInstruction** src) {
return Match(instr,
m::Divide(BcastConstScalar(1.0),
m::AddAnyOrder(BcastConstScalar(1.0),
m::Exp(m::Negate(m::Op(src))))));
}

// Sigmoid is decomposed into a divide in HLO, so check divides for the
// pattern and fuse a match into the preceding oneDNN matmul custom call as a
// SIGMOID post-op.
Status HandleDivide(HloInstruction* instr) override {
HloInstruction* matmul_call;
HloInstruction* intermediate_instr = nullptr;
HloInstruction* optional_bitcast = nullptr;
HloInstruction* src;
if (SigmoidActivation(instr, &src)) {
if (Match(src, ElementwiseSafeIntermediates(
&intermediate_instr, &optional_bitcast,
OneDnnMatmulInstr(&matmul_call))
.WithOneUser())) {
return FuseActivation(OneDnnMatMulConfig::SIGMOID, instr, matmul_call,
intermediate_instr, optional_bitcast);
}
}
return OkStatus();
}

absl::Status FuseActivation(OneDnnMatMulConfig_FusionKind kind,
HloInstruction* activation,
HloInstruction* matmul,
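
For reference, the divide pattern matched above is the logistic (sigmoid) function that oneDNN's eltwise_logistic post-op computes; a small standalone check of the identity (not part of this commit):

#include <cassert>
#include <cmath>

int main() {
  for (double x : {-4.0, -1.0, 0.0, 1.0, 4.0}) {
    double as_matched = 1.0 / (1.0 + std::exp(-x));      // Divide(1, 1 + Exp(Negate(x)))
    double logistic = 0.5 * (1.0 + std::tanh(0.5 * x));  // equivalent closed form
    assert(std::abs(as_matched - logistic) < 1e-12);
  }
  return 0;
}
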
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/BUILD
@@ -1206,6 +1206,7 @@ cc_library(
hdrs = ["triton_support.h"],
deps = [
":variant_visitor",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/service:instruction_fusion",
@@ -287,7 +287,7 @@ MlirFusionEmitterBase::CreateLLVMModule(
const BufferAssignment* buffer_assignment) const {
bool is_amd = std::holds_alternative<se::RocmComputeCapability>(
device.gpu_compute_capability());
auto* hlo_module = fusion.GetModule();
HloModule* hlo_module = fusion.GetModule();
std::unique_ptr<mlir::interpreter::MlirCompilationTrace> trace = nullptr;
if (DumpingEnabledForHloModule(*hlo_module) &&
DumpingEnabledForHloPass("mlir-fusion-emitter",
@@ -574,12 +574,9 @@ absl::Status MlirFusionEmitterBase::RunPassPipeline(
std::make_unique<mlir::interpreter::MlirCompilerTraceInstrumentation>(
*trace));
}

if (pm.run(module).failed()) {
std::string module_dump;
llvm::raw_string_ostream os(module_dump);
module->print(os);
return absl::InternalError(absl::StrFormat(
"Failed to run pass pipeline.\nMLIR module:\n%s", module_dump));
return absl::InternalError("Failed to run pass pipeline");
}
return absl::OkStatus();
}
109 changes: 63 additions & 46 deletions third_party/xla/xla/service/gpu/gemm_rewriter.cc
@@ -69,7 +69,6 @@ limitations under the License.
#include "tsl/platform/statusor.h"
#include "tsl/protobuf/dnn.pb.h"


namespace xla {
namespace gpu {
namespace {
@@ -184,67 +183,87 @@ absl::StatusOr<HloInstruction *> InvertAndConvertScalar(HloInstruction *scalar,
return scalar;
}

// Recursively collects unary, divide, dynamic-slice, pad, multiply or select
// operands of instr and the index of the operand identifying the next op in the
// sequence until an instruction with FP8 element type is reached. Returns an
// empty vector when no FP8 instruction is reached.
std::vector<std::pair<HloInstruction *, int>> FindF8SubgraphRecursive(
HloInstruction *instr, absl::flat_hash_set<int> &visited_instrs,
std::vector<std::pair<HloInstruction *, int>> subgraph) {
// A path of instructions by traversing downwards through users, as (op,
// operand_index) pairs. operand_index is the index to get to the previous
// element in the path. I.e.,
// path[i].first->operand(path[i].second) == path[i-1].first
using InstrPath = std::vector<std::pair<HloInstruction *, int>>;

// From 'instr', recursively traverses operands until an FP8 instruction is
// encountered. Only unary ops and a few types of non-unary ops are traversed.
// If an FP8 instruction is found, returns the path from the FP8 instruction to
// 'instr'. Returns nullopt when no FP8 instruction is reached.
//
// The intent is, given 'instr' is the operand of a dot, to find a sequence of
// instructions that can potentially be fused into a cuBLAS LT FP8 gemm.
std::optional<InstrPath> FindF8SubgraphRecursive(
HloInstruction *instr, absl::flat_hash_set<int> &visited_instrs) {
// Avoid visiting the same instruction more than once.
if (!visited_instrs.emplace(instr->unique_id()).second) {
return {};
return std::nullopt;
}
subgraph.emplace_back(std::make_pair(instr, 0));
if (IsF8Type(instr)) {
return subgraph;
// The initial operand index is meaningless. Arbitrarily use -1.
return InstrPath{{instr, -1}};
}
if (instr->operand_count() == 1 || instr->opcode() == HloOpcode::kDivide ||
instr->opcode() == HloOpcode::kDynamicSlice ||
instr->opcode() == HloOpcode::kPad) {
return FindF8SubgraphRecursive(instr->mutable_operand(0), visited_instrs,
std::move(subgraph));
std::optional<InstrPath> subgraph =
FindF8SubgraphRecursive(instr->mutable_operand(0), visited_instrs);
if (subgraph) {
subgraph->emplace_back(std::make_pair(instr, 0));
}
return subgraph;
} else if (instr->opcode() == HloOpcode::kMultiply ||
instr->opcode() == HloOpcode::kSelect) {
for (int k = 0; k < 2; ++k) {
// Iterate over operands 0 and 1 for multiply and operands 1 and 2 for
// select.
int operand_idx = k + (instr->opcode() == HloOpcode::kSelect);
subgraph.back().second = operand_idx;
auto binary_subgraph = FindF8SubgraphRecursive(
instr->mutable_operand(operand_idx), visited_instrs, subgraph);
if (!binary_subgraph.empty()) {
return binary_subgraph;
std::optional<InstrPath> subgraph = FindF8SubgraphRecursive(
instr->mutable_operand(operand_idx), visited_instrs);
if (subgraph) {
subgraph->emplace_back(std::make_pair(instr, operand_idx));
return subgraph;
}
}
}
return {};
return std::nullopt;
}

// Returns whether instr and its operands describe a pattern which is compatible
// with rewriting the dot operating on instr into an FP8 Custom Call. If
// applicable, captures the operand of the Custom Call, its scaling factor,
// whether the scaling factor is applied by multiplication and intermediate
// unary ops.
bool IsSupportedF8Pattern(
HloInstruction *instr, HloInstruction *&x, HloInstruction *&x_scale,
bool &x_mult_scale, std::vector<std::pair<HloInstruction *, int>> &x_ops) {
// Given an operand of a dot, 'instr', returns true if this operand allows
// rewriting the dot into an FP8 cublasLT custom call, optionally with scaling.
// In particular, returns true if either 'instr' is FP8 or there is a
// path from an FP8 instruction 'x' to 'instr' consisting of the following.
// 1. A convert to a wider type.
// 2. Optionally, a multiplication/division by a scalar, representing the scale.
// If present, the scalar scale is returned as 'x_scale' and 'x_mult_scale'
// is set to true or false depending on whether there is a multiplication or
// a division.
// 3. A possibly-empty set of ops commutative with steps (1) and (2), meaning
// they can be safely moved before step (1). Such ops are returned in
// 'x_ops'.
// Steps (1) and (2) together are a dequantization, and can be fused into a
// cublas LT matmul. Step (3) can be moved before the cublas LT matmul.
bool IsSupportedF8Pattern(HloInstruction *instr, HloInstruction *&x,
HloInstruction *&x_scale, bool &x_mult_scale,
InstrPath &x_ops) {
absl::flat_hash_set<int> visited_instrs;
std::vector<std::pair<HloInstruction *, int>> subgraph =
FindF8SubgraphRecursive(instr, visited_instrs,
std::vector<std::pair<HloInstruction *, int>>{});

if (subgraph.empty()) {
std::optional<InstrPath> maybe_subgraph =
FindF8SubgraphRecursive(instr, visited_instrs);
if (!maybe_subgraph) {
return false;
}
InstrPath &subgraph = maybe_subgraph.value();

// Directly operating on an FP8 operand.
if (subgraph.size() == 1) {
x = subgraph[0].first;
CHECK(IsF8Type(x));
return true;
}

std::reverse(subgraph.begin(), subgraph.end());
// When not operating directly on an FP8 operand, the second and
// third instructions in the subgraph must describe a dequantization, i.e. a
// convert instruction followed by a multiply/divide instruction.
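
Restated as an equation (an illustrative sketch, not from the source): the dequantization is x_dq = convert(x_f8) * s for a scalar scale s (or a division by s), and step (3) relies on data-movement ops commuting with that scaling, e.g. slice(convert(x_f8) * s) = convert(slice(x_f8)) * s, which is what allows such ops to be hoisted above the dequantization absorbed by the cuBLAS LT call.
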
@@ -272,6 +291,7 @@ bool IsSupportedF8Pattern(
return instr->GetModule()->config().use_spmd_partitioning();
};

// Skip the initial FP8 instruction and the two dequantization instructions.
for (int i = 3; i < subgraph.size(); ++i) {
// The remaining instructions must be commutative with dequantization.
// Bitcast, broadcast, copy, dynamic-slice, pad, reshape, select, slice,
@@ -580,10 +600,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
GemmIsSupportedByCublasLt(*instr, gemm_backend_config));
HloInstruction *a, *b, *a_scale = nullptr, *b_scale = nullptr;
// Sequence of ops between dequantization and GEMM which are
// mathematically commutative with dequantization. The second element of
// the pair gives the index of the operand identifying the next op in the
// sequence.
std::vector<std::pair<HloInstruction *, int>> a_ops, b_ops;
// mathematically commutative with dequantization.
InstrPath a_ops, b_ops;
bool a_mult_scale{}, b_mult_scale{};
if (supported_by_cublaslt &&
Match(instr,
@@ -948,12 +966,13 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
return *rocm_cc;
}

absl::StatusOr<bool> CreateF8CustomCall(
HloInstruction *instr, GpuBackendConfig &gpu_backend_config,
HloInstruction *a, HloInstruction *b, HloInstruction *a_scale,
HloInstruction *b_scale, bool a_mult_scale, bool b_mult_scale,
std::vector<std::pair<HloInstruction *, int>> a_ops,
std::vector<std::pair<HloInstruction *, int>> b_ops) {
absl::StatusOr<bool> CreateF8CustomCall(HloInstruction *instr,
GpuBackendConfig &gpu_backend_config,
HloInstruction *a, HloInstruction *b,
HloInstruction *a_scale,
HloInstruction *b_scale,
bool a_mult_scale, bool b_mult_scale,
InstrPath a_ops, InstrPath b_ops) {
GemmBackendConfig &gemm_backend_config =
*gpu_backend_config.mutable_gemm_backend_config();
if (IsCuda(gpu_version_)) {
@@ -1110,9 +1129,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {

// Sequentially apply the collected unary, dynamic-slice, pad and select ops
// to the unconverted and unscaled operands.
auto shift_ops =
[&instr](HloInstruction *&x,
std::vector<std::pair<HloInstruction *, int>> &x_ops) -> void {
auto shift_ops = [&instr](HloInstruction *&x, InstrPath &x_ops) -> void {
for (std::pair<HloInstruction *, int> op : x_ops) {
std::vector<HloInstruction *> operands = {x};
// Insert the additional operands of dynamic-slice ops.
23 changes: 13 additions & 10 deletions third_party/xla/xla/service/gpu/ir_emitter_triton.cc
@@ -1556,21 +1556,24 @@ class MatMulEmitterHelper {
majormost_dim_start_index_ptr_val, mt::CacheModifier::NONE,
mt::EvictionPolicy::NORMAL,
/*isVolatile=*/false);
Value majormost_dim_start_index_lower_limit_val =
CreateConst(b_, majormost_dim_start_index_val.getType(), 0);
int64_t majormost_dim_start_index_upper_limit =
hlo->operand(0)->shape().dimensions(majormost_dim) -
hlo->dynamic_slice_sizes().at(majormost_dim);
Value majormost_dim_start_index_upper_limit_val =
CreateConst(b_, majormost_dim_start_index_val.getType(),
majormost_dim_start_index_upper_limit);
// Our Triton codegen only supports signed integers so far.
// We don't want to cast S64 indices to S32, because that could result
// in an incorrect value.
if (majormost_dim_start_index_val.getType().isInteger() &&
majormost_dim_start_index_val.getType().getIntOrFloatBitWidth() ==
64) {
return UncompilableMatmul(
"64 bit dynamic-slice indices are not supported yet.");
}
majormost_dim_start_index_val =
    b_.create<ma::MaxSIOp>(majormost_dim_start_index_val,
                           majormost_dim_start_index_lower_limit_val);
majormost_dim_start_index_val =
    b_.create<ma::MinSIOp>(majormost_dim_start_index_val,
                           majormost_dim_start_index_upper_limit_val);
majormost_dim_start_index_val =
    Cast(b_, majormost_dim_start_index_val, b_.getI32Type());
majormost_dim_start_index_val =
    b_.create<ma::MaxSIOp>(majormost_dim_start_index_val, Cst32(0));
majormost_dim_start_index_val = b_.create<ma::MinSIOp>(
    majormost_dim_start_index_val,
    Cst32(majormost_dim_start_index_upper_limit));

// How many "rows" (non-contracting dim values) are there in a slice of
// size 1?
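
For reference, the MaxSIOp/MinSIOp pair above reproduces HLO dynamic-slice semantics, which clamp each start index into the valid range; a standalone sketch (not from this commit) of the same clamping:

#include <algorithm>
#include <cstdint>

// Clamps a dynamic-slice start index into [0, dim_size - slice_size], the
// range that keeps the slice in bounds.
int64_t ClampStartIndex(int64_t start, int64_t dim_size, int64_t slice_size) {
  return std::clamp<int64_t>(start, int64_t{0}, dim_size - slice_size);
}
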
43 changes: 43 additions & 0 deletions third_party/xla/xla/service/gpu/triton_support.cc
@@ -15,6 +15,7 @@ limitations under the License.

#include "xla/service/gpu/triton_support.h"

#include <cstdint>
#include <iterator>
#include <variant>
#include <vector>
@@ -25,6 +26,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/layout.h"
#include "xla/service/gpu/variant_visitor.h"
#include "xla/stream_executor/device_description.h"
#include "xla/xla_data.pb.h"
@@ -314,6 +316,43 @@ CodegenDecision CanTritonHandleReduce(
return "Reduction is not a row-reduction of a single operand.";
}

CodegenDecision IsTritonSupportedDynamicSlice(
const HloDynamicSliceInstruction& instr) {
for (const HloInstruction* index_operand : instr.index_operands()) {
switch (index_operand->shape().element_type()) {
case S8:
case S16:
case S32:
break; // supported
default:
return CodegenDecision(
"Dynamic slice is only supported with S8, S16, or S32 indices.");
}
}

// Similar to normal slice, we cannot slice a non-major-most dimension as
// that would introduce non-contiguous strides under tiling. The existing
// check against this in GetRequirementsIfSupportedOrder is not suitable for
// dynamic slices, so we instead check for this here.
const HloInstruction* input = instr.operand(0);
Layout in_layout = input->shape().layout();
int64_t majormost_dim_id =
in_layout.minor_to_major(in_layout.minor_to_major_size() - 1);

for (int i = 0; i < input->shape().dimensions_size(); ++i) {
if (i == majormost_dim_id) {
continue;
} else if (input->shape().dimensions(i) != instr.slice_sizes(i)) {
return CodegenDecision(
"Unsupported dynamic slice on non-major-most dimension.");
}
}

// TODO(b/343143854): Check the subtleties of which dynamic slices are
// supported, for example that a fragmented dimension cannot be sliced.
return CodegenDecision{};
}

CodegenDecision IsTritonSupportedInstruction(
const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) {
if (instr.IsElementwise()) {
@@ -334,6 +373,10 @@ CodegenDecision IsTritonSupportedInstruction(
}
return "Only supports root tuples.";
}
case HloOpcode::kDynamicSlice: {
return IsTritonSupportedDynamicSlice(
*Cast<HloDynamicSliceInstruction>(&instr));
}
case HloOpcode::kBitcast:
case HloOpcode::kTranspose:
case HloOpcode::kSlice:
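
To make the major-most-dimension rule above concrete, a small standalone sketch (not from this commit) using XLA's minor_to_major convention, in which the last minor_to_major entry names the major-most dimension:

#include <cstdint>
#include <vector>

// Returns true when only the major-most dimension is sliced to a smaller
// size, mirroring the check in IsTritonSupportedDynamicSlice above.
bool OnlyMajormostDimSliced(const std::vector<int64_t>& dims,
                            const std::vector<int64_t>& slice_sizes,
                            const std::vector<int64_t>& minor_to_major) {
  const int64_t majormost_dim = minor_to_major.back();
  for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
    if (i != majormost_dim && dims[i] != slice_sizes[i]) return false;
  }
  return true;
}
// E.g. dims {16, 64} with layout {1, 0}: slice sizes {4, 64} are accepted,
// while {16, 32} are rejected because dimension 1 is not major-most.
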
9 changes: 9 additions & 0 deletions third_party/xla/xla/service/gpu/triton_support.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/instruction_fusion.h"
#include "xla/stream_executor/device_description.h"
@@ -52,6 +53,14 @@ bool IsTritonSupportedElementwise(HloOpcode, PrimitiveType);
CodegenDecision IsTritonSupportedInstruction(
const HloInstruction& instr, const se::GpuComputeCapability& gpu_version);

// Checks dynamic slice against requirements of triton emitter.
//
// This is exposed separately from IsTritonSupportedInstruction because we can
// use it in the dimension order propagation without adding a dependency on the
// GPU version.
CodegenDecision IsTritonSupportedDynamicSlice(
const HloDynamicSliceInstruction& instr);

} // namespace gpu
} // namespace xla
