[go: nahoru, domu]

Skip to content

Commit

Permalink
Update compatibility constraint in BUILD file.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 626318760
  • Loading branch information
tensorflower-gardener committed Jun 1, 2024
1 parent 1d36138 commit 7ea7e8a
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 48 deletions.
24 changes: 1 addition & 23 deletions third_party/xla/xla/service/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -52,29 +52,6 @@ filegroup(
]),
)

# Collection of XLA tests that support XLA:CPU thunk-based runtime. We keep
# running them on TAP while we keep working on porting XLA:CPU to the new
# runtime.
#
# XLA:CPU thunks enabled with:
# --test_env=XLA_FLAGS=--xla_cpu_use_thunk_runtime=true
#
test_suite(
name = "thunk_runtime_tests",
tests = [
"//xla/tests:array_elementwise_ops_test_cpu",
"//xla/tests:axpy_simple_test_cpu",
"//xla/tests:convert_test_cpu",
"//xla/tests:copy_test_cpu",
"//xla/tests:floor_ceil_test_cpu",
"//xla/tests:numerics_test_cpu",
"//xla/tests:reshape_test_cpu",
"//xla/tests:reverse_test_cpu",
"//xla/tests:unary_op_test_cpu",
"//xla/tests/exhaustive:exhaustive_binary_16_bit_test_cpu",
],
)

cc_library(
name = "test_header_helper",
testonly = True,
Expand Down Expand Up @@ -641,6 +618,7 @@ cc_library(
":elemental_math_emitter",
":ir_emitter",
"//xla:shape_util",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/service:elemental_ir_emitter",
"//xla/service/llvm_ir:fused_ir_emitter",
Expand Down
49 changes: 36 additions & 13 deletions third_party/xla/xla/service/cpu/ir_emitter2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/service/cpu/elemental_math_emitter.h"
#include "xla/service/cpu/ir_emitter.h"
Expand All @@ -47,6 +48,7 @@ limitations under the License.
#include "xla/service/llvm_ir/loop_emitter.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
Expand Down Expand Up @@ -201,6 +203,35 @@ bool IrEmitter2::fast_min_max() const {
return hlo_module_.config().debug_options().xla_cpu_enable_fast_min_max();
}

// Emits the loop nest that evaluates `element_generator` and writes results
// into `results` arrays for the given HLO instruction. Multi-output emission
// is only legal for fusion, reduce, and reduce-window instructions; any other
// instruction with more than one result is rejected with an internal error.
static absl::Status EmitElementalLoops(
    llvm::IRBuilder<>& b, const HloInstruction* instr,
    const llvm_ir::ElementGenerator& element_generator,
    absl::Span<const llvm_ir::IrArray> results) {
  const HloOpcode opcode = instr->opcode();

  // Only a handful of instruction kinds may legally produce multiple results.
  const bool has_many_results = results.size() > 1;
  const bool allows_many_results = opcode == HloOpcode::kFusion ||
                                   opcode == HloOpcode::kReduce ||
                                   opcode == HloOpcode::kReduceWindow;

  if (has_many_results && !allows_many_results) {
    return Internal(
        "Multi-output host kernels are not supported for %s instruction",
        HloOpcodeString(opcode));
  }

  if (has_many_results) {
    // Multi-output path: hand all result arrays to the loop emitter at once.
    TF_RETURN_IF_ERROR(llvm_ir::LoopEmitter(element_generator, results, &b)
                           .EmitLoop(llvm_ir::IrName(instr)));
    return absl::OkStatus();
  }

  // Single-output path: emit a loop over the only result array.
  TF_RETURN_IF_ERROR(
      llvm_ir::LoopEmitter(element_generator, results.front(), &b)
          .EmitLoop(llvm_ir::IrName(instr)));
  return absl::OkStatus();
}

absl::StatusOr<IrEmitter2::KernelInfo> IrEmitter2::EmitElementalHostKernel(
const HloInstruction* instr) {
VLOG(2) << "Emit elemental host kernel: " << instr->name();
Expand All @@ -218,19 +249,13 @@ absl::StatusOr<IrEmitter2::KernelInfo> IrEmitter2::EmitElementalHostKernel(
};
}

if (kernel_prototype.results.size() > 1) {
return absl::InternalError("Multi-output host kernels are not supported");
}

ElementalIrEmitter elemental_emitter(module_, &b, &hlo_module_,
nested_ir_emitter_, fast_min_max());
llvm_ir::ElementGenerator element_generator =
elemental_emitter.MakeElementGenerator(instr, operand_to_generator);

TF_RETURN_IF_ERROR(
llvm_ir::LoopEmitter(element_generator, kernel_prototype.results[0], &b)
.EmitLoop(llvm_ir::IrName(instr)));

TF_RETURN_IF_ERROR(EmitElementalLoops(b, instr, element_generator,
kernel_prototype.results));
return kernels_.emplace_back(kernel_prototype.function->getName().str());
}

Expand Down Expand Up @@ -263,10 +288,8 @@ absl::StatusOr<IrEmitter2::KernelInfo> IrEmitter2::EmitFusionHostKernel(
auto element_generator,
fused_emitter.GetGenerator(*fusion->fused_expression_root()));

TF_RETURN_IF_ERROR(
llvm_ir::LoopEmitter(element_generator, kernel_prototype.results[0], &b)
.EmitLoop(llvm_ir::IrName(fusion)));

TF_RETURN_IF_ERROR(EmitElementalLoops(b, fusion, element_generator,
kernel_prototype.results));
return kernels_.emplace_back(kernel_prototype.function->getName().str());
}

Expand Down Expand Up @@ -328,7 +351,7 @@ IrEmitter2::KernelPrototype IrEmitter2::EmitKernelPrototype(
<< ", #arguments=" << arguments.size()
<< ", #results=" << results.size();
for (const Shape& argument : arguments) {
VLOG(3) << " arguments: " << argument.ToString(true);
VLOG(3) << " argument: " << argument.ToString(true);
}
for (const Shape& result : results) {
VLOG(3) << " result: " << result.ToString(true);
Expand Down
8 changes: 2 additions & 6 deletions third_party/xla/xla/service/cpu/runtime/copy_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,6 @@ CopyThunk::CopyThunk(BufferAllocation::Slice source_buffer,
<< " must be compatble with destination shape "
<< destination_shape_.ToString(true);

// TODO(ezhulenev): This is almost certainly wrong for many types of copies
// that change layout, however it works in a few tests. This implementation
// is copied from `xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc`. It seems to
// work only if destination is a row-major layout.
if (source_shape_ != destination_shape_) {
TransposePlan::Options options;
options.elem_size_in_bytes =
Expand Down Expand Up @@ -83,10 +79,10 @@ absl::Status CopyThunk::Execute(const ExecuteParams& params) {
VLOG(3) << absl::StreamFormat("Copy buffer: use_transpose=%s",
transpose_plan_ ? "true" : "false");
VLOG(3) << absl::StreamFormat(
" - src: %s in slice %s (%p)", source_shape_.ToString(true),
" src: %s in slice %s (%p)", source_shape_.ToString(true),
source_buffer_.ToString(), source_data.opaque());
VLOG(3) << absl::StreamFormat(
" - dst: %s in slice %s (%p)", destination_shape_.ToString(true),
" dst: %s in slice %s (%p)", destination_shape_.ToString(true),
destination_buffer_.ToString(), destination_data.opaque());

// TODO(ezhulenev): Add benchmarks for copy thunk and add support for
Expand Down
5 changes: 4 additions & 1 deletion third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include "xla/service/cpu/runtime/kernel_thunk.h"

#include <cstdint>
#include <string>
#include <utility>

Expand Down Expand Up @@ -49,10 +50,12 @@ absl::Status KernelThunk::Execute(const ExecuteParams& params) {
absl::InlinedVector<se::DeviceMemoryBase, 8> buffers_data;
buffers_data.reserve(buffers_.size());

int64_t arg_num = 0;
for (BufferAllocation::Slice& buffer : buffers_) {
TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(),
params.buffer_allocations->GetDeviceAddress(buffer));
VLOG(3) << absl::StreamFormat(" - arg: %s (%p)", buffer.ToString(),
VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", arg_num++,
buffer.ToString(),
buffers_data.back().opaque());
}

Expand Down
6 changes: 5 additions & 1 deletion third_party/xla/xla/service/cpu/thunk_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitEntryComputation(
if (!module.has_schedule()) {
return absl::InternalError("HLO module must be scheduled to emit thunks");
}
VLOG(0) << module.ToString();
return EmitHloComputation(module.entry_computation());
}

Expand Down Expand Up @@ -114,6 +113,8 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitHloInstruction(
case HloOpcode::kAnd:
case HloOpcode::kAtan2:
case HloOpcode::kBroadcast:
case HloOpcode::kCbrt:
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kClz:
case HloOpcode::kCompare:
Expand All @@ -123,6 +124,7 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitHloInstruction(
case HloOpcode::kErf:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kImag:
case HloOpcode::kIota:
case HloOpcode::kIsFinite:
Expand All @@ -139,6 +141,8 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitHloInstruction(
case HloOpcode::kReal:
case HloOpcode::kRemainder:
case HloOpcode::kReverse:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kRoundNearestEven:
case HloOpcode::kRsqrt:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
Expand Down
9 changes: 9 additions & 0 deletions third_party/xla/xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,7 @@ xla_test(
xla_test(
name = "axpy_simple_test",
srcs = ["axpy_simple_test.cc"],
tags = ["test_xla_cpu_thunks"],
deps = [
":client_library_test_base",
":literal_test_util",
Expand Down Expand Up @@ -695,6 +696,7 @@ xla_test(
xla_test(
name = "unary_op_test",
srcs = ["unary_op_test.cc"],
tags = ["test_xla_cpu_thunks"],
deps = [
":client_library_test_base",
":literal_test_util",
Expand Down Expand Up @@ -803,6 +805,7 @@ xla_test(
"TENSORFLOW_USE_ROCM=1",
]),
shard_count = 25,
tags = ["test_xla_cpu_thunks"],
deps = [
":client_library_test_base",
":literal_test_util",
Expand Down Expand Up @@ -1595,6 +1598,7 @@ xla_test(
shard_count = 31,
tags = [
"optonly",
"test_xla_cpu_thunks",
],
deps = [
":client_library_test_base",
Expand Down Expand Up @@ -1981,6 +1985,7 @@ xla_test(
name = "reshape_test",
srcs = ["reshape_test.cc"],
shard_count = 30,
tags = ["test_xla_cpu_thunks"],
deps = [
":client_library_test_base",
":literal_test_util",
Expand All @@ -2006,6 +2011,7 @@ xla_test(
xla_test(
name = "reverse_test",
srcs = ["reverse_test.cc"],
tags = ["test_xla_cpu_thunks"],
deps = [
":client_library_test_base",
":literal_test_util",
Expand Down Expand Up @@ -2070,6 +2076,7 @@ xla_test(
xla_test(
name = "convert_test",
srcs = ["convert_test.cc"],
tags = ["test_xla_cpu_thunks"],
deps = [
":client_library_test_base",
":literal_test_util",
Expand Down Expand Up @@ -2261,6 +2268,7 @@ xla_test(
xla_test(
name = "floor_ceil_test",
srcs = ["floor_ceil_test.cc"],
tags = ["test_xla_cpu_thunks"],
deps = [
":client_library_test_base",
":literal_test_util",
Expand Down Expand Up @@ -2978,6 +2986,7 @@ xla_test(
xla_test(
name = "numerics_test",
srcs = ["numerics_test.cc"],
tags = ["test_xla_cpu_thunks"],
deps = [
":hlo_test_base",
":test_macros_header",
Expand Down
13 changes: 9 additions & 4 deletions third_party/xla/xla/tests/exhaustive/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ xla_test(
shard_count = 50,
tags = [
"optonly",
"test_xla_cpu_thunks",
# This is a big test that we skip for capacity reasons in OSS testing.
"no_oss",
],
Expand All @@ -83,9 +84,10 @@ xla_test(
"gpu",
"cpu",
],
shard_count = 48,
shard_count = 50,
tags = [
"optonly",
"test_xla_cpu_thunks",
# This is a big test that we skip for capacity reasons in OSS testing.
"no_oss",
# TODO(b/151340488): Timed out on 2020-03-18.
Expand All @@ -107,9 +109,10 @@ xla_test(
"gpu",
"cpu",
],
shard_count = 48,
shard_count = 50,
tags = [
"optonly",
"test_xla_cpu_thunks",
# This is a big test that we skip for capacity reasons in OSS testing.
"no_oss",
],
Expand All @@ -132,9 +135,10 @@ xla_test(
"gpu",
"cpu",
],
shard_count = 48,
shard_count = 50,
tags = [
"optonly",
"test_xla_cpu_thunks",
# This is a big test that we skip for capacity reasons in OSS testing.
"no_oss",
],
Expand All @@ -151,9 +155,10 @@ xla_test(
"gpu",
"cpu",
],
shard_count = 48,
shard_count = 50,
tags = [
"optonly",
"test_xla_cpu_thunks",
# This is a big test that we skip for capacity reasons in OSS testing.
"no_oss",
],
Expand Down

0 comments on commit 7ea7e8a

Please sign in to comment.