[XLA:GPU] Make the TritonSupportTest have exhaustive coverage of elementwise ops.

The one case that remains is covering `Convert` where the input and output types differ.

PiperOrigin-RevId: 646800379
dimitar-asenov authored and tensorflower-gardener committed Jun 26, 2024
1 parent 9a84833 commit 80c6906
Showing 10 changed files with 325 additions and 142 deletions.
2 changes: 1 addition & 1 deletion tensorflow/tools/pip_package/setup.py

@@ -160,7 +160,7 @@ def standard_or_nightly(standard, nightly):
     'nvidia-curand-cu12 == 10.3.4.107',
     'nvidia-cusolver-cu12 == 11.5.4.101',
     'nvidia-cusparse-cu12 == 12.2.0.103',
-    'nvidia-nccl-cu12 == 2.19.3',
+    'nvidia-nccl-cu12 == 2.21.5',
     'nvidia-nvjitlink-cu12 == 12.3.101',
 ]
6 changes: 3 additions & 3 deletions tensorflow/workspace2.bzl

@@ -516,9 +516,9 @@ def _tf_repositories():
         name = "nccl_archive",
         build_file = "//third_party:nccl/archive.BUILD",
         patch_file = ["//third_party/nccl:archive.patch"],
-        sha256 = "1c5474553afedb88e878c772f13d6f90b9226b3f2971dfa6f873adb9443100c2",
-        strip_prefix = "nccl-2.19.3-1",
-        urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.19.3-1.tar.gz"),
+        sha256 = "1923596984d85e310b5b6c52b2c72a1b93da57218f2bc5a5c7ac3d59297a3303",
+        strip_prefix = "nccl-2.21.5-1",
+        urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.21.5-1.tar.gz"),
     )

     java_import_external(
4 changes: 2 additions & 2 deletions third_party/nccl/archive.BUILD

@@ -22,9 +22,9 @@ exports_files(["LICENSE.txt"])

 NCCL_MAJOR = 2

-NCCL_MINOR = 19
+NCCL_MINOR = 21

-NCCL_PATCH = 3
+NCCL_PATCH = 5

 NCCL_VERSION = NCCL_MAJOR * 10000 + NCCL_MINOR * 100 + NCCL_PATCH  # e.g., 21605
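The constants above feed a flat integer encoding (major × 10000 + minor × 100 + patch), so the new pin 2.21.5 encodes as 22105, while the old 2.19.3 was 21903; the "e.g., 21605" comment is pre-existing upstream text corresponding to 2.16.5. A minimal standalone sketch of the arithmetic (the helper name is illustrative, not part of the BUILD file):

```cpp
#include <cstdio>

// Mirrors NCCL_VERSION = NCCL_MAJOR * 10000 + NCCL_MINOR * 100 + NCCL_PATCH.
constexpr int EncodeNcclVersion(int major, int minor, int patch) {
  return major * 10000 + minor * 100 + patch;
}

static_assert(EncodeNcclVersion(2, 19, 3) == 21903, "old pinned version");
static_assert(EncodeNcclVersion(2, 21, 5) == 22105, "new pinned version");

int main() {
  std::printf("NCCL 2.21.5 -> %d\n", EncodeNcclVersion(2, 21, 5));
  return 0;
}
```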
50 changes: 47 additions & 3 deletions third_party/nccl/archive.patch

@@ -1,9 +1,31 @@
+diff --git a/src/device/all_gather.h b/src/device/all_gather.h
+index 809e8ae..57eab81 100644
+--- a/src/device/all_gather.h
++++ b/src/device/all_gather.h
+@@ -296,7 +296,7 @@ struct RunWorkElement<ncclFuncAllGather, T, RedOp, NCCL_ALGO_COLLNET_DIRECT, NCC
+       scat.args = args;
+       scat.chunkSize = chunkSize;
+       scat.railGridOffset = railGridOffset;
+-      prims.process</*Recv=*/1, /*Send=*/1>(scat);
++      prims.template process</*Recv=*/1, /*Send=*/1>(scat);
+     }
+   }
+   return;
+@@ -314,7 +314,7 @@ struct RunWorkElement<ncclFuncAllGather, T, RedOp, NCCL_ALGO_COLLNET_DIRECT, NCC
+       scat.args = args;
+       scat.chunkSize = chunkSize;
+       scat.railGridOffset = railGridOffset;
+-      prims.process</*Recv=*/1, /*Send=*/0>(scat);
++      prims.template process</*Recv=*/1, /*Send=*/0>(scat);
+     }
+     return;
+   }
 diff --git a/src/device/common.cu b/src/device/common.cu.cc
 similarity index 100%
 rename from src/device/common.cu
 rename to src/device/common.cu.cc
 diff --git a/src/device/common.h b/src/device/common.h
-index 97581f7..134fdb8 100644
+index d8581d3..09ac3b6 100644
 --- a/src/device/common.h
 +++ b/src/device/common.h
 @@ -15,7 +15,7 @@
@@ -14,9 +36,9 @@ index 97581f7..134fdb8 100644
 +extern __device__ ncclDevFuncPtr_t ncclDevFuncTable[];
 
  struct ncclShmemGroup {
-   ncclConnInfo *recvConns[NCCL_MAX_NVLS_ARITY];
+   ncclConnInfo *recvConns[NCCL_MAX_ARITY];
 diff --git a/src/device/generate.py b/src/device/generate.py
-index 0b053de..87bf6cb 100755
+index 43de85d..87cd677 100755
 --- a/src/device/generate.py
 +++ b/src/device/generate.py
 @@ -195,7 +195,7 @@ kernel_funcs = sorted(set(best_kernel(*fn) for fn in primary_funcs))
@@ -111,3 +133,25 @@ diff --git a/src/device/onerank.cu b/src/device/onerank.cu.cc
 similarity index 100%
 rename from src/device/onerank.cu
 rename to src/device/onerank.cu.cc
+diff --git a/src/device/reduce_scatter.h b/src/device/reduce_scatter.h
+index d0b5249..2dacd60 100644
+--- a/src/device/reduce_scatter.h
++++ b/src/device/reduce_scatter.h
+@@ -254,7 +254,7 @@ struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_COLLNET_DIRECT,
+       scat.args = args;
+       scat.chunkSize = chunkSize;
+       scat.railGridOffset = railGridOffset;
+-      prims.process</*Recv=*/0, /*Send=*/1>(scat);
++      prims.template process</*Recv=*/0, /*Send=*/1>(scat);
+     }
+     return;
+   }
+@@ -278,7 +278,7 @@ struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_COLLNET_DIRECT,
+       scat.args = args;
+       scat.chunkSize = chunkSize;
+       scat.railGridOffset = railGridOffset;
+-      prims.process</*Recv=*/1, /*Send=*/1>(scat);
++      prims.template process</*Recv=*/1, /*Send=*/1>(scat);
+     }
+   }
+   return;
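The substantive additions to this patch insert the `template` keyword before calls to the member template `process`. Inside a template, calling a member template through a dependent object requires the `template` disambiguator; otherwise the `<` parses as less-than. Presumably this matters here because the TensorFlow/XLA build compiles these sources with a compiler that enforces the rule strictly. A minimal standalone reproduction (type and function names here are illustrative, not NCCL's):

```cpp
// Stand-in for NCCL's primitives object; `process` is a member template.
template <typename T>
struct Prims {
  template <int Recv, int Send>
  void process(T arg) { (void)arg; /* send/receive work would go here */ }
};

template <typename T>
void Run(Prims<T>& prims, T arg) {
  // prims.process<1, 0>(arg);        // ill-formed: `process` is a dependent
  //                                  // name, so `<` parses as less-than.
  prims.template process<1, 0>(arg);  // OK: `template` disambiguates.
}

int main() {
  Prims<int> p;
  Run(p, 42);
  return 0;
}
```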
4 changes: 2 additions & 2 deletions third_party/xla/third_party/tsl/third_party/nccl/archive.BUILD

@@ -22,9 +22,9 @@ exports_files(["LICENSE.txt"])

 NCCL_MAJOR = 2

-NCCL_MINOR = 19
+NCCL_MINOR = 21

-NCCL_PATCH = 3
+NCCL_PATCH = 5

 NCCL_VERSION = NCCL_MAJOR * 10000 + NCCL_MINOR * 100 + NCCL_PATCH  # e.g., 21605
50 changes: 47 additions & 3 deletions third_party/xla/third_party/tsl/third_party/nccl/archive.patch

@@ -1,9 +1,31 @@
+diff --git a/src/device/all_gather.h b/src/device/all_gather.h
+index 809e8ae..57eab81 100644
+--- a/src/device/all_gather.h
++++ b/src/device/all_gather.h
+@@ -296,7 +296,7 @@ struct RunWorkElement<ncclFuncAllGather, T, RedOp, NCCL_ALGO_COLLNET_DIRECT, NCC
+       scat.args = args;
+       scat.chunkSize = chunkSize;
+       scat.railGridOffset = railGridOffset;
+-      prims.process</*Recv=*/1, /*Send=*/1>(scat);
++      prims.template process</*Recv=*/1, /*Send=*/1>(scat);
+     }
+   }
+   return;
+@@ -314,7 +314,7 @@ struct RunWorkElement<ncclFuncAllGather, T, RedOp, NCCL_ALGO_COLLNET_DIRECT, NCC
+       scat.args = args;
+       scat.chunkSize = chunkSize;
+       scat.railGridOffset = railGridOffset;
+-      prims.process</*Recv=*/1, /*Send=*/0>(scat);
++      prims.template process</*Recv=*/1, /*Send=*/0>(scat);
+     }
+     return;
+   }
 diff --git a/src/device/common.cu b/src/device/common.cu.cc
 similarity index 100%
 rename from src/device/common.cu
 rename to src/device/common.cu.cc
 diff --git a/src/device/common.h b/src/device/common.h
-index 97581f7..134fdb8 100644
+index d8581d3..09ac3b6 100644
 --- a/src/device/common.h
 +++ b/src/device/common.h
 @@ -15,7 +15,7 @@
@@ -14,9 +36,9 @@ index 97581f7..134fdb8 100644
 +extern __device__ ncclDevFuncPtr_t ncclDevFuncTable[];
 
  struct ncclShmemGroup {
-   ncclConnInfo *recvConns[NCCL_MAX_NVLS_ARITY];
+   ncclConnInfo *recvConns[NCCL_MAX_ARITY];
 diff --git a/src/device/generate.py b/src/device/generate.py
-index 0b053de..87bf6cb 100755
+index 43de85d..87cd677 100755
 --- a/src/device/generate.py
 +++ b/src/device/generate.py
 @@ -195,7 +195,7 @@ kernel_funcs = sorted(set(best_kernel(*fn) for fn in primary_funcs))
@@ -111,3 +133,25 @@ diff --git a/src/device/onerank.cu b/src/device/onerank.cu.cc
 similarity index 100%
 rename from src/device/onerank.cu
 rename to src/device/onerank.cu.cc
+diff --git a/src/device/reduce_scatter.h b/src/device/reduce_scatter.h
+index d0b5249..2dacd60 100644
+--- a/src/device/reduce_scatter.h
++++ b/src/device/reduce_scatter.h
+@@ -254,7 +254,7 @@ struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_COLLNET_DIRECT,
+       scat.args = args;
+       scat.chunkSize = chunkSize;
+       scat.railGridOffset = railGridOffset;
+-      prims.process</*Recv=*/0, /*Send=*/1>(scat);
++      prims.template process</*Recv=*/0, /*Send=*/1>(scat);
+     }
+     return;
+   }
+@@ -278,7 +278,7 @@ struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_COLLNET_DIRECT,
+       scat.args = args;
+       scat.chunkSize = chunkSize;
+       scat.railGridOffset = railGridOffset;
+-      prims.process</*Recv=*/1, /*Send=*/1>(scat);
++      prims.template process</*Recv=*/1, /*Send=*/1>(scat);
+     }
+   }
+   return;
6 changes: 3 additions & 3 deletions third_party/xla/third_party/tsl/workspace2.bzl

@@ -397,9 +397,9 @@ def _tf_repositories():
         name = "nccl_archive",
         build_file = "//third_party:nccl/archive.BUILD",
         patch_file = ["//third_party/nccl:archive.patch"],
-        sha256 = "1c5474553afedb88e878c772f13d6f90b9226b3f2971dfa6f873adb9443100c2",
-        strip_prefix = "nccl-2.19.3-1",
-        urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.19.3-1.tar.gz"),
+        sha256 = "1923596984d85e310b5b6c52b2c72a1b93da57218f2bc5a5c7ac3d59297a3303",
+        strip_prefix = "nccl-2.21.5-1",
+        urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.21.5-1.tar.gz"),
     )

     # Note that we are currently taking NVTX headers from a NCCL release to get nvToolsExtPayload.h
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/BUILD

@@ -1275,6 +1275,8 @@ xla_cc_test(
         "//xla/service/gpu/model:tiled_hlo_computation",
         "//xla/stream_executor:device_description",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest_main",
         "@local_tsl//tsl/platform:protobuf",
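The two new dependencies are the Abseil hash containers used by the updated test and by the support-set functions in triton_support.cc below. As a minimal sketch of the lookup pattern those functions serve, assuming only the Abseil API (the enum here is a stand-in for XLA's `HloOpcode`):

```cpp
#include "absl/container/flat_hash_set.h"

// Stand-in for xla::HloOpcode; the real enum lives in hlo_opcode.h.
enum class Opcode { kAbs, kAdd, kConvert, kNot };

// The support functions build a set of opcodes per element type and callers
// test membership, so an O(1) hash set is the natural container.
bool IsSupported(Opcode op) {
  static const auto* kSupported =
      new absl::flat_hash_set<Opcode>{Opcode::kAbs, Opcode::kConvert};
  return kSupported->contains(op);
}
```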
96 changes: 77 additions & 19 deletions third_party/xla/xla/service/gpu/triton_support.cc

@@ -29,6 +29,7 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/layout.h"
+#include "xla/primitive_util.h"
 #include "xla/service/gpu/variant_visitor.h"
 #include "xla/stream_executor/device_description.h"
 #include "xla/xla_data.pb.h"
@@ -107,7 +108,12 @@ bool IsTritonSupportedDataType(PrimitiveType type,
     case S8:
     case S16:
     case S32:
+    case S64:
+    case U16:
+    case F64:
       return true;
+    case F8E4M3FN:
+      return std::holds_alternative<se::CudaComputeCapability>(gpu_version);
     default:
       return false;
   }
@@ -432,10 +438,24 @@ namespace {
 absl::flat_hash_set<HloOpcode> TritonSupportedUnaryElementwiseOps(
     PrimitiveType element_type) {
   if (element_type == PrimitiveType::PRED) {
-    return {HloOpcode::kConvert, HloOpcode::kNot};
+    return {HloOpcode::kNot};
   }
 
-  absl::flat_hash_set<HloOpcode> ret = {HloOpcode::kConvert, HloOpcode::kAbs,
-                                        HloOpcode::kNegate};
+  if (element_type == PrimitiveType::U16) {
+    return {HloOpcode::kAbs};
+  }
+
+  absl::flat_hash_set<HloOpcode> ret{HloOpcode::kAbs, HloOpcode::kConvert};
+
+  if (element_type != PrimitiveType::F8E5M2 &&
+      element_type != PrimitiveType::F8E4M3FN) {
+    ret.insert(HloOpcode::kNegate);
+  }
+
+  if (primitive_util::IsIntegralType(element_type)) {
+    ret.insert(HloOpcode::kNot);
+  }
 
   if (element_type == PrimitiveType::F32 ||
       element_type == PrimitiveType::F64) {
     absl::flat_hash_set<HloOpcode> additional_opcodes{
@@ -460,19 +480,41 @@ absl::flat_hash_set<HloOpcode> TritonSupportedUnaryElementwiseOps(
 // TODO(b/345763510): make sure that this is accurate. At the moment, this is
 // mostly a fork of the same code in legacy_triton::.
 absl::flat_hash_set<HloOpcode> TritonSupportedBinaryElementwiseOps(
-    PrimitiveType element_type) {
+    PrimitiveType element_type, const se::GpuComputeCapability& gpu_version) {
+  if (element_type == PrimitiveType::F8E5M2 ||
+      element_type == PrimitiveType::F8E4M3FN) {
+    return {};
+  }
+
   if (element_type == PrimitiveType::PRED) {
-    return {HloOpcode::kAnd, HloOpcode::kOr, HloOpcode::kXor,
-            HloOpcode::kCompare};
+    return {HloOpcode::kAnd, HloOpcode::kOr, HloOpcode::kXor,
+            HloOpcode::kCompare, HloOpcode::kAdd, HloOpcode::kMultiply,
+            HloOpcode::kMaximum, HloOpcode::kMinimum};
   }
-  absl::flat_hash_set<HloOpcode> ret = {
-      HloOpcode::kAdd,     HloOpcode::kCompare,  HloOpcode::kMaximum,
-      HloOpcode::kMinimum, HloOpcode::kMultiply, HloOpcode::kSubtract};
+
+  absl::flat_hash_set<HloOpcode> ret{HloOpcode::kCompare};
+
+  if (element_type != PrimitiveType::U16) {
+    ret.insert(HloOpcode::kAdd);
+    ret.insert(HloOpcode::kSubtract);
+    ret.insert(HloOpcode::kMaximum);
+    ret.insert(HloOpcode::kMinimum);
+    ret.insert(HloOpcode::kMultiply);
+
+    if (primitive_util::IsIntegralType(element_type)) {
+      ret.insert(HloOpcode::kDivide);
+      ret.insert(HloOpcode::kAnd);
+      ret.insert(HloOpcode::kOr);
+      ret.insert(HloOpcode::kXor);
+    }
+  }
 
   if (element_type == PrimitiveType::F32 ||
       element_type == PrimitiveType::F64) {
-    absl::flat_hash_set<HloOpcode> additional_opcodes{
-        HloOpcode::kAtan2, HloOpcode::kDivide, HloOpcode::kPower};
-    ret.insert(additional_opcodes.begin(), additional_opcodes.end());
+    ret.insert(HloOpcode::kAtan2);
+    ret.insert(HloOpcode::kDivide);
+    ret.insert(HloOpcode::kRemainder);
+    ret.insert(HloOpcode::kPower);
   }
   return ret;
 }
@@ -481,19 +523,30 @@ absl::flat_hash_set<HloOpcode> TritonSupportedBinaryElementwiseOps(
 // TODO(b/345763510): make sure that this is accurate. At the moment, this is
 // mostly a fork of the same code in legacy_triton::.
 absl::flat_hash_set<HloOpcode> TritonSupportedTernaryElementwiseOps(
-    PrimitiveType element_type) {
+    PrimitiveType element_type, const se::GpuComputeCapability& gpu_version) {
+  if (element_type == PrimitiveType::U16) {
+    return {};
+  }
+
+  if (element_type == PrimitiveType::F8E5M2 ||
+      element_type == PrimitiveType::F8E4M3FN) {
+    return {HloOpcode::kSelect};
+  }
+
   return {HloOpcode::kSelect, HloOpcode::kClamp};
 }
 
 // Returns `true` if the given opcode and element type correspond to a n-ary
 // elementwise op that is genuinely supported by Triton. The caller is
 // responsible for ensuring that the relevant data type is supported on the
 // device of interest.
-bool IsTritonSupportedElementwise(HloOpcode opcode,
-                                  PrimitiveType element_type) {
+bool IsTritonSupportedElementwise(HloOpcode opcode, PrimitiveType element_type,
+                                  const se::GpuComputeCapability& gpu_version) {
   return TritonSupportedUnaryElementwiseOps(element_type).contains(opcode) ||
-         TritonSupportedBinaryElementwiseOps(element_type).contains(opcode) ||
-         TritonSupportedTernaryElementwiseOps(element_type).contains(opcode);
+         TritonSupportedBinaryElementwiseOps(element_type, gpu_version)
+             .contains(opcode) ||
+         TritonSupportedTernaryElementwiseOps(element_type, gpu_version)
+             .contains(opcode);
 }
 
 }  // namespace
@@ -518,8 +571,13 @@ CodegenDecision IsTritonSupportedInstruction(
   }
 
   if (instr.IsElementwise()) {
-    if (!IsTritonSupportedElementwise(instr.opcode(),
-                                      instr.shape().element_type())) {
+    if (!IsTritonSupportedElementwise(
+            instr.opcode(),
+            // Use the last operand below in order to support both `compare`
+            // and `select` which have a fixed PRED type in the output and first
+            // operand.
+            instr.operand(instr.operand_count() - 1)->shape().element_type(),
+            gpu_version)) {
       return "Unsupported elementwise operation.";
     }
     return CodegenDecision{};
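Two details of the new support logic are worth unpacking. First, `IsTritonSupportedInstruction` now keys the elementwise check off the last operand's element type: `compare` always produces PRED, and `select`'s first operand is always PRED, so only the trailing operand carries the payload type. A hedged sketch of that selection rule (the struct is a stand-in, not XLA's `HloInstruction` API):

```cpp
#include <vector>

enum class PrimitiveType { PRED, F32, F64 };

// Stand-in for the bits of HloInstruction the check needs.
struct Instruction {
  std::vector<PrimitiveType> operand_types;
  PrimitiveType result_type;
};

// compare(f32, f32) -> pred: the result type is fixed to PRED, so the last
// operand (f32) decides Triton support.
// select(pred, f32, f32) -> f32: the first operand is fixed to PRED, so the
// last operand is again the deciding one.
PrimitiveType DecidingElementType(const Instruction& instr) {
  return instr.operand_types.back();
}
```

Second, F8E4M3FN support is gated on the compute capability actually being a CUDA one: `se::GpuComputeCapability` is a variant over the CUDA and ROCm capability types, and `std::holds_alternative` asks which alternative is active. A minimal sketch with stand-in types:

```cpp
#include <variant>

struct CudaComputeCapability { int major, minor; };
struct RocmComputeCapability { /* ... */ };
using GpuComputeCapability =
    std::variant<CudaComputeCapability, RocmComputeCapability>;

// Mirrors the F8E4M3FN case in IsTritonSupportedDataType: the type is only
// reported as supported when the variant holds the CUDA alternative.
bool SupportsF8E4M3FN(const GpuComputeCapability& gpu_version) {
  return std::holds_alternative<CudaComputeCapability>(gpu_version);
}
```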
