[go: nahoru, domu]

Skip to content

Commit

Permalink
Automated Code Change
Browse files Browse the repository at this point in the history
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#13310 from Tixxx:tixxx/a2a_collective_matmul d7790ed5e206c5e1ebf33afa8e34d7faedff4d47
PiperOrigin-RevId: 644199531
  • Loading branch information
tensorflower-gardener committed Jun 18, 2024
1 parent 38dd164 commit f082a8f
Show file tree
Hide file tree
Showing 17 changed files with 1,256 additions and 31 deletions.
1 change: 0 additions & 1 deletion tensorflow/compiler/mlir/tf2xla/api/v1/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
"//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:export_graphdef",
"//tensorflow/compiler/mlir/tensorflow:import_model",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils",
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.

#include <memory>
#include <string>
#include <variant>

#include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h"
#include "llvm/ADT/ArrayRef.h"
Expand Down Expand Up @@ -96,7 +97,7 @@ constexpr absl::string_view kGroupKeyAttrName =
// that is converted to a TensorShape.
absl::StatusOr<TensorShape> GetTensorShapeFromXlaArgument(
const XlaArgument& arg) {
if (absl::holds_alternative<xla::Shape>(arg.shape)) {
if (std::holds_alternative<xla::Shape>(arg.shape)) {
TensorShape arg_shape;
TF_RETURN_IF_ERROR(
XLAShapeToTensorShape(std::get<xla::Shape>(arg.shape), &arg_shape));
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/service/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,7 @@ cc_library(
"//xla/service/cpu/runtime:conditional_thunk",
"//xla/service/cpu/runtime:copy_thunk",
"//xla/service/cpu/runtime:dot_thunk",
"//xla/service/cpu/runtime:fft_thunk",
"//xla/service/cpu/runtime:infeed_thunk",
"//xla/service/cpu/runtime:kernel_thunk",
"//xla/service/cpu/runtime:outfeed_thunk",
Expand Down
23 changes: 23 additions & 0 deletions third_party/xla/xla/service/cpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -423,3 +423,26 @@ cc_library(
"@local_tsl//tsl/profiler/lib:traceme",
],
)

# Thunk implementing FFT HLO ops on the CPU backend; dispatches to the
# multi-threaded or single-threaded FFT runtime (see the runtime_fft deps).
cc_library(
    name = "fft_thunk",
    srcs = ["fft_thunk.cc"],
    hdrs = ["fft_thunk.h"],
    deps = [
        ":thunk",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla/runtime:buffer_use",
        "//xla/service:buffer_assignment",
        "//xla/service/cpu:runtime_fft",
        "//xla/service/cpu:runtime_single_threaded_fft",
        "//xla/stream_executor",
        "//xla/tsl/concurrency:async_value",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:span",
        "@local_tsl//tsl/platform:statusor",
        "@local_tsl//tsl/profiler/lib:traceme",
    ],
)
115 changes: 115 additions & 0 deletions third_party/xla/xla/service/cpu/runtime/fft_thunk.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/service/cpu/runtime/fft_thunk.h"

#include <cstdint>
#include <memory>

#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/layout_util.h"
#include "xla/runtime/buffer_use.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/cpu/runtime/thunk.h"
#include "xla/service/cpu/runtime_fft.h"
#include "xla/service/cpu/runtime_single_threaded_fft.h"
#include "xla/shape.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {

// Private constructor; use FftThunk::Create instead. Copies the fft_length
// span and both shapes so the thunk owns all state it needs at execution time.
FftThunk::FftThunk(Info thunk_info, bool is_multi_thread_eigen,
                   int32_t fft_type, absl::Span<const int64_t> fft_length,
                   BufferAllocation::Slice input_buffer,
                   const Shape& input_shape,
                   BufferAllocation::Slice output_buffer,
                   const Shape& output_shape)
    : Thunk(Kind::kFft, thunk_info),
      is_multi_thread_eigen_(is_multi_thread_eigen),
      // Double precision is decided once, from the input element type (F64 for
      // real, C128 for complex inputs).
      is_double_precision_(input_shape.element_type() == F64 ||
                           input_shape.element_type() == C128),
      fft_type_(fft_type),
      fft_length_(fft_length.begin(), fft_length.end()),
      input_buffer_(input_buffer),
      output_buffer_(output_buffer),
      input_shape_(input_shape),
      output_shape_(output_shape) {}

// Factory for FftThunk. Returns StatusOr for interface consistency with other
// thunk factories; construction itself cannot fail.
absl::StatusOr<std::unique_ptr<FftThunk>> FftThunk::Create(
    Info thunk_info, bool is_multi_thread_eigen, int32_t fft_type,
    absl::Span<const int64_t> fft_length, BufferAllocation::Slice input_buffer,
    const Shape& input_shape, BufferAllocation::Slice output_buffer,
    const Shape& output_shape) {
  // The constructor is private, so std::make_unique is unavailable; wrap the
  // raw allocation instead.
  auto* thunk =
      new FftThunk(thunk_info, is_multi_thread_eigen, fft_type, fft_length,
                   input_buffer, input_shape, output_buffer, output_shape);
  return absl::WrapUnique(thunk);
}

// Resolves the input/output buffers, flattens all batch dimensions into one,
// and dispatches to the DUCC FFT runtime. Runs synchronously; returns an
// already-completed event on success.
tsl::AsyncValueRef<Thunk::ExecuteEvent> FftThunk::Execute(
    const ExecuteParams& params) {
  tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); });
  // The flattening below indexes dimensions positionally, so it is only valid
  // for the default dim0-major (row-major) layout on both shapes.
  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(input_shape_.layout()));
  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(output_shape_.layout()));

  TF_ASSIGN_OR_RETURN(
      se::DeviceMemoryBase input_data,
      params.buffer_allocations->GetDeviceAddress(input_buffer_));
  TF_ASSIGN_OR_RETURN(
      se::DeviceMemoryBase output_data,
      params.buffer_allocations->GetDeviceAddress(output_buffer_));

  const int fft_rank = fft_length_.size();

  // Flatten operand batches: collapse all leading (batch) dimensions into a
  // single dimension, followed by the fft_rank transformed dimensions, giving
  // operand_shape_flat = [batch, dim_0, ..., dim_{rank-1}].
  absl::InlinedVector<int64_t, 4> operand_shape_flat(fft_rank + 1);
  int64_t input_batch = 1;
  // NOTE(review): the batch length is derived from the *output* rank while the
  // dimensions multiplied are from the *input* shape — presumably input and
  // output ranks always match for FFT ops; confirm against shape inference.
  int64_t input_batch_length = output_shape_.dimensions_size() - fft_rank;
  for (int i = 0; i < input_batch_length; i++) {
    input_batch *= input_shape_.dimensions(i);
  }
  operand_shape_flat[0] = input_batch;
  for (int i = 0; i < fft_rank; ++i) {
    operand_shape_flat[i + 1] = input_shape_.dimensions(i + input_batch_length);
  }

  // Args have been computed, make the call. The buffers are passed as float*
  // regardless of element type; the runtime reinterprets them based on
  // fft_type and the double-precision flag. The first (run_options) argument
  // is unused here, hence nullptr.
  if (is_multi_thread_eigen_) {
    __xla_cpu_runtime_DuccFft(nullptr,
                              reinterpret_cast<float*>(output_data.opaque()),
                              reinterpret_cast<float*>(input_data.opaque()),
                              fft_type_, is_double_precision_, fft_rank,
                              operand_shape_flat.data(), fft_length_.data());
  } else {
    __xla_cpu_runtime_DuccSingleThreadedFft(
        nullptr, reinterpret_cast<float*>(output_data.opaque()),
        reinterpret_cast<float*>(input_data.opaque()), fft_type_,
        is_double_precision_, fft_rank, operand_shape_flat.data(),
        fft_length_.data());
  }
  return OkExecuteEvent();
}

// Reports buffer accesses for scheduling: the FFT reads its input slice and
// writes its output slice.
Thunk::BufferUses FftThunk::buffer_uses() const {
  BufferUses uses;
  uses.push_back({input_buffer_, BufferUse::kRead});
  uses.push_back({output_buffer_, BufferUse::kWrite});
  return uses;
}

} // namespace xla::cpu
71 changes: 71 additions & 0 deletions third_party/xla/xla/service/cpu/runtime/fft_thunk.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_CPU_RUNTIME_FFT_THUNK_H_
#define XLA_SERVICE_CPU_RUNTIME_FFT_THUNK_H_

#include <cstdint>
#include <memory>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/cpu/runtime/thunk.h"
#include "xla/shape.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/tsl/concurrency/async_value_ref.h"

namespace xla::cpu {

// This class stores everything that is needed to launch an FFT on the host:
// the FFT parameters, the input/output buffer slices, and their shapes. It is
// generated by the CPU thunk emitter.
//
// This is thread-compatible.
class FftThunk final : public Thunk {
 public:
  // Creates an FFT thunk. `fft_type` selects the transform kind, `fft_length`
  // holds the transformed dimension sizes (its size is the FFT rank), and
  // `is_multi_thread_eigen` selects the multi-threaded runtime path.
  static absl::StatusOr<std::unique_ptr<FftThunk>> Create(
      Info thunk_info, bool is_multi_thread_eigen, int32_t fft_type,
      absl::Span<const int64_t> fft_length,
      BufferAllocation::Slice input_buffer, const Shape& input_shape,
      BufferAllocation::Slice output_buffer, const Shape& output_shape);

  // Runs the FFT synchronously against the buffers resolved from `params`.
  tsl::AsyncValueRef<Thunk::ExecuteEvent> Execute(
      const ExecuteParams& params) final;

  // Input slice is read, output slice is written.
  BufferUses buffer_uses() const final;

 private:
  // Constructs a thunk for launching an FFT on a host.
  FftThunk(Info thunk_info, bool is_multi_thread_eigen, int32_t fft_type,
           absl::Span<const int64_t> fft_length,
           BufferAllocation::Slice input_buffer, const Shape& input_shape,
           BufferAllocation::Slice output_buffer, const Shape& output_shape);

  const bool is_multi_thread_eigen_;  // multi- vs single-threaded runtime
  const bool is_double_precision_;    // derived from the input element type
  const int32_t fft_type_;
  const std::vector<int64_t> fft_length_;  // size == FFT rank

  const BufferAllocation::Slice input_buffer_;
  const BufferAllocation::Slice output_buffer_;

  // Shapes are stored by value so the thunk is self-contained at run time.
  const Shape input_shape_;
  const Shape output_shape_;
};

} // namespace xla::cpu

#endif // XLA_SERVICE_CPU_RUNTIME_FFT_THUNK_H_
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/cpu/runtime/thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ std::string_view Thunk::KindToString(Kind kind) {
return "conditional";
case Kind::kDot:
return "dot";
case Kind::kFft:
return "fft";
case Kind::kInfeed:
return "infeed";
case Kind::kRngGetAndUpdateState:
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/service/cpu/runtime/thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class Thunk {
kCopy,
kConditional,
kDot,
kFft,
kInfeed,
kKernel,
kOutfeed,
Expand Down
25 changes: 25 additions & 0 deletions third_party/xla/xla/service/cpu/thunk_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ limitations under the License.
#include "xla/service/cpu/runtime/conditional_thunk.h"
#include "xla/service/cpu/runtime/copy_thunk.h"
#include "xla/service/cpu/runtime/dot_thunk.h"
#include "xla/service/cpu/runtime/fft_thunk.h"
#include "xla/service/cpu/runtime/infeed_thunk.h"
#include "xla/service/cpu/runtime/kernel_thunk.h"
#include "xla/service/cpu/runtime/outfeed_thunk.h"
Expand Down Expand Up @@ -235,6 +236,9 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitHloInstruction(
case HloOpcode::kDot:
return EmitDotThunk(instruction);

case HloOpcode::kFft:
return EmitFftThunk(instruction);

default:
return absl::UnimplementedError(
absl::StrCat("HLO opcode `", HloOpcodeString(instruction->opcode()),
Expand Down Expand Up @@ -452,6 +456,27 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitDotThunk(
}
}

// Emits an FftThunk for an HLO fft instruction: validates the element types,
// resolves the operand/result buffer slices, and forwards the FFT parameters.
absl::StatusOr<ThunkSequence> ThunkEmitter::EmitFftThunk(
    const HloInstruction* instruction) {
  // Only real/complex float types are supported by the FFT runtime.
  TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
      /*instruction=*/*instruction, /*operands=*/{instruction->operands()},
      /*supported_types=*/{F32, F64, C64, C128}));
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice arg_slice,
                      GetAllocationSlice(instruction->operand(0)));
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dest_slice,
                      GetAllocationSlice(instruction));
  return ThunkSequence::Of<FftThunk>(
      /*info=*/ThunkInfo(instruction),
      // Threading choice follows the module's debug options, matching the
      // rest of the CPU backend.
      /*is_multi_thread_eigen=*/
      hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(),
      /*fft_type=*/instruction->fft_type(),
      /*fft_length=*/instruction->fft_length(),
      /*input_buffer=*/arg_slice,
      /*input_shape=*/instruction->operand(0)->shape(),
      /*output_buffer=*/dest_slice,
      /*output_shape=*/instruction->shape());
}

absl::StatusOr<ThunkEmitter::HostKernelAllocationSlices>
ThunkEmitter::GetHostKernelAllocationSlices(const HloInstruction* instruction) {
HostKernelAllocationSlices slices;
Expand Down
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/cpu/thunk_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class ThunkEmitter {
absl::StatusOr<ThunkSequence> EmitElementalKernelThunk(
const HloInstruction* instruction);

absl::StatusOr<ThunkSequence> EmitFftThunk(const HloInstruction* instruction);

absl::StatusOr<ThunkSequence> EmitFusionKernelThunk(
const HloInstruction* instruction);

Expand Down
6 changes: 6 additions & 0 deletions third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6084,13 +6084,18 @@ cc_library(
hdrs = ["gpu_windowed_einsum_handler.h"],
deps = [
":backend_configs_cc",
"//xla:literal_util",
"//xla:status",
"//xla:statusor",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_query",
"//xla/service:hlo_creation_utils",
"//xla/service:hlo_pass",
"//xla/service:pattern_matcher",
"//xla/service:shape_inference",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand All @@ -6110,6 +6115,7 @@ xla_cc_test(
"//xla/hlo/ir:hlo",
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
Expand Down
Loading

0 comments on commit f082a8f

Please sign in to comment.