[go: nahoru, domu]

Skip to content

Commit

Permalink
Automated Code Change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623886396
  • Loading branch information
tensorflower-gardener committed May 28, 2024
1 parent 092d33a commit 4ebd066
Show file tree
Hide file tree
Showing 14 changed files with 123 additions and 70 deletions.
1 change: 1 addition & 0 deletions tensorflow/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,7 @@ bzl_library(
srcs = ["tensorflow.bzl"],
visibility = ["//visibility:public"],
deps = [
"//devtools/build_cleaner/skylark:build_defs_lib",
"//tensorflow/core/platform:build_config_root_bzl",
"//tensorflow/core/platform:rules_cc_bzl",
"//third_party/compute_library:build_defs_bzl",
Expand Down
8 changes: 4 additions & 4 deletions third_party/systemlibs/grpc.bazel.cc_grpc_library.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

load("@com_github_grpc_grpc//bazel:generate_cc.bzl", "generate_cc")
load("@com_github_grpc_grpc//bazel:protobuf.bzl", "well_known_proto_libs")
load("//third_party/protobuf/bazel:cc_proto_library.bzl", "cc_proto_library")
load("//third_party/protobuf/bazel:proto_library.bzl", "proto_library")

def cc_grpc_library(
name,
Expand Down Expand Up @@ -63,15 +65,13 @@ def cc_grpc_library(
proto_deps += [dep.split(":")[0] + ":" + "_" + dep.split(":")[1] + "_only" for dep in deps if dep.find(":") != -1]
if well_known_protos:
proto_deps += well_known_proto_libs()

native.proto_library(
proto_library(
name = proto_target,
srcs = srcs,
deps = proto_deps,
**kwargs
)

native.cc_proto_library(
cc_proto_library(
name = cc_proto_target,
deps = [":" + proto_target],
**kwargs
Expand Down
1 change: 1 addition & 0 deletions third_party/systemlibs/grpc.bazel.generate_cc.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ load(
"get_proto_root",
"proto_path_to_generated_filename",
)
load("//third_party/protobuf/bazel/common:proto_info.bzl", "ProtoInfo")

_GRPC_PROTO_HEADER_FMT = "{}.grpc.pb.h"
_GRPC_PROTO_SRC_FMT = "{}.grpc.pb.cc"
Expand Down
2 changes: 2 additions & 0 deletions third_party/systemlibs/grpc.bazel.protobuf.bzl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Utility functions for generating protobuf code."""

load("//third_party/protobuf/bazel/common:proto_info.bzl", "ProtoInfo")

_PROTO_EXTENSION = ".proto"
_VIRTUAL_IMPORTS = "/_virtual_imports/"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

load("@com_github_grpc_grpc//bazel:generate_cc.bzl", "generate_cc")
load("@com_github_grpc_grpc//bazel:protobuf.bzl", "well_known_proto_libs")
load("//third_party/protobuf/bazel:cc_proto_library.bzl", "cc_proto_library")
load("//third_party/protobuf/bazel:proto_library.bzl", "proto_library")

def cc_grpc_library(
name,
Expand Down Expand Up @@ -63,15 +65,13 @@ def cc_grpc_library(
proto_deps += [dep.split(":")[0] + ":" + "_" + dep.split(":")[1] + "_only" for dep in deps if dep.find(":") != -1]
if well_known_protos:
proto_deps += well_known_proto_libs()

native.proto_library(
proto_library(
name = proto_target,
srcs = srcs,
deps = proto_deps,
**kwargs
)

native.cc_proto_library(
cc_proto_library(
name = cc_proto_target,
deps = [":" + proto_target],
**kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ load(
"get_proto_root",
"proto_path_to_generated_filename",
)
load("//third_party/protobuf/bazel/common:proto_info.bzl", "ProtoInfo")

_GRPC_PROTO_HEADER_FMT = "{}.grpc.pb.h"
_GRPC_PROTO_SRC_FMT = "{}.grpc.pb.cc"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Utility functions for generating protobuf code."""

load("//third_party/protobuf/bazel/common:proto_info.bzl", "ProtoInfo")

_PROTO_EXTENSION = ".proto"
_VIRTUAL_IMPORTS = "/_virtual_imports/"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ load(
"if_not_windows",
"if_tsl_link_protobuf",
)
load("//third_party/protobuf/bazel:proto_library.bzl", "proto_library")
load("//tsl/platform:build_config_root.bzl", "if_static")

def well_known_proto_libs():
Expand Down Expand Up @@ -607,8 +608,7 @@ def tf_proto_library(
name_sans_proto = name[:-6]
else:
name_sans_proto = name

native.proto_library(
proto_library(
name = name,
srcs = srcs,
deps = protodeps + well_known_proto_libs(),
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/fusions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ cc_library(
"//xla:status",
"//xla:statusor",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/ffi:attribute_map",
"//xla/ffi:ffi_api",
"//xla/hlo/ir:hlo",
Expand Down
44 changes: 31 additions & 13 deletions third_party/xla/xla/service/gpu/fusions/custom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ limitations under the License.
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <variant>
#include <vector>
Expand Down Expand Up @@ -64,6 +63,7 @@ limitations under the License.
#include "xla/shape_util.h"
#include "xla/status.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

Expand Down Expand Up @@ -172,7 +172,8 @@ absl::Status CollectSliceInfo(
const BufferAssignment& buffer_assignment,
const HloInstruction& fusion_instr,
absl::Span<HloInstruction*> slice_instrs,
std::vector<std::optional<std::vector<BufferAllocation::Slice>>>&
std::vector<std::optional<
std::vector<std::variant<int64_t, BufferAllocation::Slice>>>>&
offset_buffer_indices,
std::vector<std::optional<Shape>>& orig_shapes,
std::vector<std::optional<Shape>>& sliced_shapes,
Expand All @@ -183,15 +184,30 @@ absl::Status CollectSliceInfo(
return absl::OkStatus();
}

std::vector<BufferAllocation::Slice> offset_slices;
std::vector<std::variant<int64_t, BufferAllocation::Slice>> offset_slices;
for (auto idx_op : slice_instr->index_operands()) {
const auto* param = Cast<HloParameterInstruction>(idx_op);
TF_ASSIGN_OR_RETURN(
auto offset_slice,
GetAllocationSlice(buffer_assignment,
fusion_instr.operand(param->parameter_number()),
/*index=*/{}));
offset_slices.push_back(offset_slice);
const auto* offset_param = fusion_instr.operand(param->parameter_number());

if (auto* cst_offset = DynCast<HloConstantInstruction>(offset_param)) {
auto s32_scalar = ShapeUtil::MakeShape(PrimitiveType::S32, {});
auto s64_scalar = ShapeUtil::MakeShape(PrimitiveType::S64, {});

if (cst_offset->shape() == s32_scalar) {
offset_slices.emplace_back() = cst_offset->literal().data<int32_t>()[0];
} else if (cst_offset->shape() == s64_scalar) {
offset_slices.emplace_back() = cst_offset->literal().data<int64_t>()[0];
} else {
return absl::InternalError(
absl::StrCat("Unsupported constant offset shape: ",
cst_offset->shape().ToString()));
}

} else {
TF_ASSIGN_OR_RETURN(offset_slices.emplace_back(),
GetAllocationSlice(buffer_assignment, offset_param,
/*index=*/{}));
}
}
offset_buffer_indices[arg_idx] = std::move(offset_slices);
orig_shapes[arg_idx] = slice_instr->operand(0)->shape();
Expand Down Expand Up @@ -256,7 +272,8 @@ absl::StatusOr<FusionEmissionResult> EmitGemm(
const BufferAssignment& buffer_assignment =
ir_emitter_context.buffer_assignment();

std::vector<std::optional<std::vector<BufferAllocation::Slice>>>
std::vector<std::optional<
std::vector<std::variant<int64_t, BufferAllocation::Slice>>>>
offset_buffer_indices(4, std::nullopt);
std::vector<std::optional<Shape>> orig_shapes(4, std::nullopt);
std::vector<std::optional<Shape>> sliced_shapes(4, std::nullopt);
Expand Down Expand Up @@ -379,7 +396,7 @@ absl::StatusOr<FusionEmissionResult> EmitGemm(
thunk_info, std::move(config), slice_lhs_fake, slice_rhs_fake,
slice_out_fake, slice_workspace_fake, deterministic_ops));

std::vector<std::optional<const BufferAllocation::Slice>> arguments{
std::vector<std::optional<BufferAllocation::Slice>> arguments{
lhs_slice, rhs_slice, output, workspace};

thunk = std::make_unique<AddressComputationThunk>(
Expand Down Expand Up @@ -435,15 +452,16 @@ absl::StatusOr<FusionEmissionResult> EmitCustomCall(
num_args += ShapeUtil::GetLeafCount(operand->shape());
});

std::vector<std::optional<std::vector<BufferAllocation::Slice>>>
std::vector<std::optional<
std::vector<std::variant<int64_t, BufferAllocation::Slice>>>>
offset_buffer_indices(num_args, std::nullopt);
std::vector<std::optional<Shape>> orig_shapes(num_args, std::nullopt);
std::vector<std::optional<Shape>> sliced_shapes(num_args, std::nullopt);
std::vector<std::optional<uint64_t>> offset_byte_sizes(num_args,
std::nullopt);

std::vector<HloInstruction*> slice_instrs(num_args, nullptr);
std::vector<std::optional<const BufferAllocation::Slice>> arguments;
std::vector<std::optional<BufferAllocation::Slice>> arguments;

unsigned arg_idx = 0;
// TODO(vuson): add test for custom call with token-typed operands
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ cc_library(
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:statusor",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ limitations under the License.
#include <memory>
#include <optional>
#include <utility>
#include <variant>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "llvm/ADT/STLExtras.h"
#include "xla/service/buffer_assignment.h"
Expand All @@ -38,16 +38,18 @@ limitations under the License.
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/memory_allocation.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace gpu {

AddressComputationThunk::AddressComputationThunk(
ThunkInfo thunk_info, std::unique_ptr<ThunkSequence> embedded_thunk,
std::vector<std::optional<const BufferAllocation::Slice>> arguments,
std::vector<std::optional<BufferAllocation::Slice>> arguments,
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations,
std::vector<std::optional<std::vector<BufferAllocation::Slice>>>
std::vector<std::optional<
std::vector<std::variant<int64_t, BufferAllocation::Slice>>>>
offset_buffer_indices,
std::vector<std::optional<Shape>> orig_shapes,
std::vector<std::optional<Shape>> sliced_shapes,
Expand Down Expand Up @@ -151,28 +153,47 @@ absl::Status AddressComputationThunk::ExecuteOnStream(
std::vector<int64_t> slice_starts;
slice_starts.reserve(dst_shape.rank());

// Number of issues d2h transfers to copy offset values from device to host.
int64_t num_transfers = 0;

// Get offset for `argument_idx`-th argument, which has `dst_shape.rank()`
// components.
for (auto [offset_idx, values] : llvm::enumerate(llvm::zip(
*offset_slice, src_shape.dimensions(), dst_shape.dimensions()))) {
auto [slice, src_dim, dst_dim] = values;
se::DeviceMemoryBase offset_src =
orig_allocations.GetDeviceAddress(slice);
int64_t* offset_dst = &offsets_base[argument_idx + offset_idx];
// Copy the `offset_idx`-th component of the offset for the
// `argument_idx`-th argument from device to host.
TF_RETURN_IF_ERROR(
stream.Memcpy(offset_dst, offset_src, offset_byte_size.value()));

if (absl::Status blocked = stream.BlockHostUntilDone(); !blocked.ok()) {
return absl::InternalError(absl::StrFormat(
"Failed to retrieve all slice offset values on stream %p: %s",
&stream, blocked.message()));

if (int64_t* const_offset = std::get_if<int64_t>(&slice)) {
// Forward slice offsets that are known constant values
offsets_base[argument_idx + offset_idx] = *const_offset;
} else {
// Transfer slice offset value from device to host.
se::DeviceMemoryBase offset_src = orig_allocations.GetDeviceAddress(
std::get<BufferAllocation::Slice>(slice));
int64_t* offset_dst = &offsets_base[argument_idx + offset_idx];

// Copy the `offset_idx`-th component of the offset for the
// `argument_idx`-th argument from device to host.
TF_RETURN_IF_ERROR(
stream.Memcpy(offset_dst, offset_src, offset_byte_size.value()));
++num_transfers;
}
// Clamp start indices:
// start_indices[i] = min(max(start_indices[i], 0),
// operand.dimension_size[i] - size_indices[i])
auto start_index = std::min(std::max(*offset_dst, 0L), src_dim - dst_dim);
}

// Wait for the completion of all transfers.
if (num_transfers > 0) {
VLOG(2) << "Wait for completion of " << num_transfers << " transfer";
TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
}

// Clamp start indices:
// start_indices[i] = min(max(start_indices[i], 0),
// operand.dimension_size[i] - size_indices[i])
for (auto [offset_idx, values] : llvm::enumerate(
llvm::zip(src_shape.dimensions(), dst_shape.dimensions()))) {
auto [src_dim, dst_dim] = values;
int64_t start_index =
std::min(std::max(offsets_base[argument_idx + offset_idx], 0L),
src_dim - dst_dim);
slice_starts.push_back(start_index);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <cstdint>
#include <memory>
#include <optional>
#include <variant>
#include <vector>

#include "absl/base/thread_annotations.h"
Expand All @@ -29,6 +30,7 @@ limitations under the License.
#include "xla/service/buffer_assignment.h"
#include "xla/service/gpu/runtime/sequential_thunk.h"
#include "xla/service/gpu/runtime/thunk.h"
#include "xla/shape.h"
#include "xla/status.h"
#include "xla/stream_executor/memory_allocation.h"
#include "xla/stream_executor/stream_executor.h"
Expand All @@ -45,9 +47,10 @@ class AddressComputationThunk : public Thunk {
public:
AddressComputationThunk(
ThunkInfo thunk_info, std::unique_ptr<ThunkSequence> embedded_thunk,
std::vector<std::optional<const BufferAllocation::Slice>> arguments,
std::vector<std::optional<BufferAllocation::Slice>> arguments,
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations_,
std::vector<std::optional<std::vector<BufferAllocation::Slice>>>
std::vector<std::optional<
std::vector<std::variant<int64_t, BufferAllocation::Slice>>>>
offset_buffer_indices,
std::vector<std::optional<Shape>> orig_shapes,
std::vector<std::optional<Shape>> sliced_shapes,
Expand All @@ -65,10 +68,10 @@ class AddressComputationThunk : public Thunk {

private:
std::unique_ptr<SequentialThunk> embedded_thunk_;
std::vector<std::optional<const BufferAllocation::Slice>>
embedded_thunk_arguments_;
std::vector<std::optional<BufferAllocation::Slice>> embedded_thunk_arguments_;
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations_;
std::vector<std::optional<std::vector<BufferAllocation::Slice>>>
std::vector<std::optional<
std::vector<std::variant<int64_t, BufferAllocation::Slice>>>>
offset_buffer_indices_;
std::vector<std::optional<Shape>> orig_shapes_;
std::vector<std::optional<Shape>> sliced_shapes_;
Expand Down
Loading

0 comments on commit 4ebd066

Please sign in to comment.