From 4ebd0664888d9bd6a194db810fcd96194f860076 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 11 Apr 2024 11:25:40 -0700 Subject: [PATCH] Automated Code Change PiperOrigin-RevId: 623886396 --- tensorflow/BUILD | 1 + .../systemlibs/grpc.bazel.cc_grpc_library.bzl | 8 +-- .../systemlibs/grpc.bazel.generate_cc.bzl | 1 + .../systemlibs/grpc.bazel.protobuf.bzl | 2 + .../systemlibs/grpc.bazel.cc_grpc_library.bzl | 8 +-- .../systemlibs/grpc.bazel.generate_cc.bzl | 1 + .../systemlibs/grpc.bazel.protobuf.bzl | 2 + .../tsl/tsl/platform/default/build_config.bzl | 4 +- third_party/xla/xla/service/gpu/fusions/BUILD | 1 + .../xla/xla/service/gpu/fusions/custom.cc | 44 ++++++++++---- third_party/xla/xla/service/gpu/runtime/BUILD | 1 + .../gpu/runtime/address_computation_thunk.cc | 59 +++++++++++++------ .../gpu/runtime/address_computation_thunk.h | 13 ++-- .../runtime/address_computation_thunk_test.cc | 48 +++++++-------- 14 files changed, 123 insertions(+), 70 deletions(-) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 5018c6e642e516..f027b31cf388ee 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -1107,6 +1107,7 @@ bzl_library( srcs = ["tensorflow.bzl"], visibility = ["//visibility:public"], deps = [ + "//devtools/build_cleaner/skylark:build_defs_lib", "//tensorflow/core/platform:build_config_root_bzl", "//tensorflow/core/platform:rules_cc_bzl", "//third_party/compute_library:build_defs_bzl", diff --git a/third_party/systemlibs/grpc.bazel.cc_grpc_library.bzl b/third_party/systemlibs/grpc.bazel.cc_grpc_library.bzl index e427328c39be80..9656c07c496f07 100644 --- a/third_party/systemlibs/grpc.bazel.cc_grpc_library.bzl +++ b/third_party/systemlibs/grpc.bazel.cc_grpc_library.bzl @@ -2,6 +2,8 @@ load("@com_github_grpc_grpc//bazel:generate_cc.bzl", "generate_cc") load("@com_github_grpc_grpc//bazel:protobuf.bzl", "well_known_proto_libs") +load("//third_party/protobuf/bazel:cc_proto_library.bzl", "cc_proto_library") +load("//third_party/protobuf/bazel:proto_library.bzl", "proto_library") def cc_grpc_library( name, @@ -63,15 +65,13 @@ def cc_grpc_library( proto_deps += [dep.split(":")[0] + ":" + "_" + dep.split(":")[1] + "_only" for dep in deps if dep.find(":") != -1] if well_known_protos: proto_deps += well_known_proto_libs() - - native.proto_library( + proto_library( name = proto_target, srcs = srcs, deps = proto_deps, **kwargs ) - - native.cc_proto_library( + cc_proto_library( name = cc_proto_target, deps = [":" + proto_target], **kwargs diff --git a/third_party/systemlibs/grpc.bazel.generate_cc.bzl b/third_party/systemlibs/grpc.bazel.generate_cc.bzl index c659ca16366b7a..083866d803136c 100644 --- a/third_party/systemlibs/grpc.bazel.generate_cc.bzl +++ b/third_party/systemlibs/grpc.bazel.generate_cc.bzl @@ -11,6 +11,7 @@ load( "get_proto_root", "proto_path_to_generated_filename", ) +load("//third_party/protobuf/bazel/common:proto_info.bzl", "ProtoInfo") _GRPC_PROTO_HEADER_FMT = "{}.grpc.pb.h" _GRPC_PROTO_SRC_FMT = "{}.grpc.pb.cc" diff --git a/third_party/systemlibs/grpc.bazel.protobuf.bzl b/third_party/systemlibs/grpc.bazel.protobuf.bzl index 3eca97dc2311fb..0f6ac66c0b140d 100644 --- a/third_party/systemlibs/grpc.bazel.protobuf.bzl +++ b/third_party/systemlibs/grpc.bazel.protobuf.bzl @@ -1,5 +1,7 @@ """Utility functions for generating protobuf code.""" +load("//third_party/protobuf/bazel/common:proto_info.bzl", "ProtoInfo") + _PROTO_EXTENSION = ".proto" _VIRTUAL_IMPORTS = "/_virtual_imports/" diff --git a/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.cc_grpc_library.bzl b/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.cc_grpc_library.bzl index e427328c39be80..9656c07c496f07 100644 --- a/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.cc_grpc_library.bzl +++ b/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.cc_grpc_library.bzl @@ -2,6 +2,8 @@ load("@com_github_grpc_grpc//bazel:generate_cc.bzl", "generate_cc") load("@com_github_grpc_grpc//bazel:protobuf.bzl", "well_known_proto_libs") +load("//third_party/protobuf/bazel:cc_proto_library.bzl", "cc_proto_library") +load("//third_party/protobuf/bazel:proto_library.bzl", "proto_library") def cc_grpc_library( name, @@ -63,15 +65,13 @@ def cc_grpc_library( proto_deps += [dep.split(":")[0] + ":" + "_" + dep.split(":")[1] + "_only" for dep in deps if dep.find(":") != -1] if well_known_protos: proto_deps += well_known_proto_libs() - - native.proto_library( + proto_library( name = proto_target, srcs = srcs, deps = proto_deps, **kwargs ) - - native.cc_proto_library( + cc_proto_library( name = cc_proto_target, deps = [":" + proto_target], **kwargs diff --git a/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.generate_cc.bzl b/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.generate_cc.bzl index c659ca16366b7a..083866d803136c 100644 --- a/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.generate_cc.bzl +++ b/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.generate_cc.bzl @@ -11,6 +11,7 @@ load( "get_proto_root", "proto_path_to_generated_filename", ) +load("//third_party/protobuf/bazel/common:proto_info.bzl", "ProtoInfo") _GRPC_PROTO_HEADER_FMT = "{}.grpc.pb.h" _GRPC_PROTO_SRC_FMT = "{}.grpc.pb.cc" diff --git a/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.protobuf.bzl b/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.protobuf.bzl index 3eca97dc2311fb..0f6ac66c0b140d 100644 --- a/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.protobuf.bzl +++ b/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.protobuf.bzl @@ -1,5 +1,7 @@ """Utility functions for generating protobuf code.""" +load("//third_party/protobuf/bazel/common:proto_info.bzl", "ProtoInfo") + _PROTO_EXTENSION = ".proto" _VIRTUAL_IMPORTS = "/_virtual_imports/" diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl b/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl index 35cdcdc503add4..5e1f813424eafd 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl +++ b/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl @@ -8,6 +8,7 @@ load( "if_not_windows", "if_tsl_link_protobuf", ) +load("//third_party/protobuf/bazel:proto_library.bzl", "proto_library") load("//tsl/platform:build_config_root.bzl", "if_static") def well_known_proto_libs(): @@ -607,8 +608,7 @@ def tf_proto_library( name_sans_proto = name[:-6] else: name_sans_proto = name - - native.proto_library( + proto_library( name = name, srcs = srcs, deps = protodeps + well_known_proto_libs(), diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 7873af21dc77c1..6cdbf912fc730a 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -129,6 +129,7 @@ cc_library( "//xla:status", "//xla:statusor", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/ffi:attribute_map", "//xla/ffi:ffi_api", "//xla/hlo/ir:hlo", diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index 2d4f26aa2f77da..9974c15d877a14 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -64,6 +63,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -172,7 +172,8 @@ absl::Status CollectSliceInfo( const BufferAssignment& buffer_assignment, const HloInstruction& fusion_instr, absl::Span slice_instrs, - std::vector>>& + std::vector>>>& offset_buffer_indices, std::vector>& orig_shapes, std::vector>& sliced_shapes, @@ -183,15 +184,30 @@ absl::Status CollectSliceInfo( return absl::OkStatus(); } - std::vector offset_slices; + std::vector> offset_slices; for (auto idx_op : slice_instr->index_operands()) { const auto* param = Cast(idx_op); - TF_ASSIGN_OR_RETURN( - auto offset_slice, - GetAllocationSlice(buffer_assignment, - fusion_instr.operand(param->parameter_number()), - /*index=*/{})); - offset_slices.push_back(offset_slice); + const auto* offset_param = fusion_instr.operand(param->parameter_number()); + + if (auto* cst_offset = DynCast(offset_param)) { + auto s32_scalar = ShapeUtil::MakeShape(PrimitiveType::S32, {}); + auto s64_scalar = ShapeUtil::MakeShape(PrimitiveType::S64, {}); + + if (cst_offset->shape() == s32_scalar) { + offset_slices.emplace_back() = cst_offset->literal().data()[0]; + } else if (cst_offset->shape() == s64_scalar) { + offset_slices.emplace_back() = cst_offset->literal().data()[0]; + } else { + return absl::InternalError( + absl::StrCat("Unsupported constant offset shape: ", + cst_offset->shape().ToString())); + } + + } else { + TF_ASSIGN_OR_RETURN(offset_slices.emplace_back(), + GetAllocationSlice(buffer_assignment, offset_param, + /*index=*/{})); + } } offset_buffer_indices[arg_idx] = std::move(offset_slices); orig_shapes[arg_idx] = slice_instr->operand(0)->shape(); @@ -256,7 +272,8 @@ absl::StatusOr EmitGemm( const BufferAssignment& buffer_assignment = ir_emitter_context.buffer_assignment(); - std::vector>> + std::vector>>> offset_buffer_indices(4, std::nullopt); std::vector> orig_shapes(4, std::nullopt); std::vector> sliced_shapes(4, std::nullopt); @@ -379,7 +396,7 @@ absl::StatusOr EmitGemm( thunk_info, std::move(config), slice_lhs_fake, slice_rhs_fake, slice_out_fake, slice_workspace_fake, deterministic_ops)); - std::vector> arguments{ + std::vector> arguments{ lhs_slice, rhs_slice, output, workspace}; thunk = std::make_unique( @@ -435,7 +452,8 @@ absl::StatusOr EmitCustomCall( num_args += ShapeUtil::GetLeafCount(operand->shape()); }); - std::vector>> + std::vector>>> offset_buffer_indices(num_args, std::nullopt); std::vector> orig_shapes(num_args, std::nullopt); std::vector> sliced_shapes(num_args, std::nullopt); @@ -443,7 +461,7 @@ absl::StatusOr EmitCustomCall( std::nullopt); std::vector slice_instrs(num_args, nullptr); - std::vector> arguments; + std::vector> arguments; unsigned arg_idx = 0; // TODO(vuson): add test for custom call with token-typed operands diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index df5aec07ace110..feeaf7ef046c9a 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -337,6 +337,7 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc index 2e49de4220e2df..c31f8176fffca4 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc @@ -20,10 +20,10 @@ limitations under the License. #include #include #include +#include #include #include "absl/status/status.h" -#include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "llvm/ADT/STLExtras.h" #include "xla/service/buffer_assignment.h" @@ -38,6 +38,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/memory_allocation.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" namespace xla { @@ -45,9 +46,10 @@ namespace gpu { AddressComputationThunk::AddressComputationThunk( ThunkInfo thunk_info, std::unique_ptr embedded_thunk, - std::vector> arguments, + std::vector> arguments, std::vector> fake_allocations, - std::vector>> + std::vector>>> offset_buffer_indices, std::vector> orig_shapes, std::vector> sliced_shapes, @@ -151,28 +153,47 @@ absl::Status AddressComputationThunk::ExecuteOnStream( std::vector slice_starts; slice_starts.reserve(dst_shape.rank()); + // Number of issues d2h transfers to copy offset values from device to host. + int64_t num_transfers = 0; + // Get offset for `argument_idx`-th argument, which has `dst_shape.rank()` // components. for (auto [offset_idx, values] : llvm::enumerate(llvm::zip( *offset_slice, src_shape.dimensions(), dst_shape.dimensions()))) { auto [slice, src_dim, dst_dim] = values; - se::DeviceMemoryBase offset_src = - orig_allocations.GetDeviceAddress(slice); - int64_t* offset_dst = &offsets_base[argument_idx + offset_idx]; - // Copy the `offset_idx`-th component of the offset for the - // `argument_idx`-th argument from device to host. - TF_RETURN_IF_ERROR( - stream.Memcpy(offset_dst, offset_src, offset_byte_size.value())); - - if (absl::Status blocked = stream.BlockHostUntilDone(); !blocked.ok()) { - return absl::InternalError(absl::StrFormat( - "Failed to retrieve all slice offset values on stream %p: %s", - &stream, blocked.message())); + + if (int64_t* const_offset = std::get_if(&slice)) { + // Forward slice offsets that are known constant values + offsets_base[argument_idx + offset_idx] = *const_offset; + } else { + // Transfer slice offset value from device to host. + se::DeviceMemoryBase offset_src = orig_allocations.GetDeviceAddress( + std::get(slice)); + int64_t* offset_dst = &offsets_base[argument_idx + offset_idx]; + + // Copy the `offset_idx`-th component of the offset for the + // `argument_idx`-th argument from device to host. + TF_RETURN_IF_ERROR( + stream.Memcpy(offset_dst, offset_src, offset_byte_size.value())); + ++num_transfers; } - // Clamp start indices: - // start_indices[i] = min(max(start_indices[i], 0), - // operand.dimension_size[i] - size_indices[i]) - auto start_index = std::min(std::max(*offset_dst, 0L), src_dim - dst_dim); + } + + // Wait for the completion of all transfers. + if (num_transfers > 0) { + VLOG(2) << "Wait for completion of " << num_transfers << " transfer"; + TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); + } + + // Clamp start indices: + // start_indices[i] = min(max(start_indices[i], 0), + // operand.dimension_size[i] - size_indices[i]) + for (auto [offset_idx, values] : llvm::enumerate( + llvm::zip(src_shape.dimensions(), dst_shape.dimensions()))) { + auto [src_dim, dst_dim] = values; + int64_t start_index = + std::min(std::max(offsets_base[argument_idx + offset_idx], 0L), + src_dim - dst_dim); slice_starts.push_back(start_index); } diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h index bfc70574d975d5..8374362f11da50 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/base/thread_annotations.h" @@ -29,6 +30,7 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/runtime/sequential_thunk.h" #include "xla/service/gpu/runtime/thunk.h" +#include "xla/shape.h" #include "xla/status.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/stream_executor.h" @@ -45,9 +47,10 @@ class AddressComputationThunk : public Thunk { public: AddressComputationThunk( ThunkInfo thunk_info, std::unique_ptr embedded_thunk, - std::vector> arguments, + std::vector> arguments, std::vector> fake_allocations_, - std::vector>> + std::vector>>> offset_buffer_indices, std::vector> orig_shapes, std::vector> sliced_shapes, @@ -65,10 +68,10 @@ class AddressComputationThunk : public Thunk { private: std::unique_ptr embedded_thunk_; - std::vector> - embedded_thunk_arguments_; + std::vector> embedded_thunk_arguments_; std::vector> fake_allocations_; - std::vector>> + std::vector>>> offset_buffer_indices_; std::vector> orig_shapes_; std::vector> sliced_shapes_; diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc index 39398f6dd10f8b..5673a1efa75b35 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" @@ -49,6 +50,7 @@ limitations under the License. #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/types.h" // IWYU pragma: keep #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" #if GOOGLE_CUDA @@ -135,8 +137,8 @@ TEST(AddressComputationThunkTest, SlicedGemm) { slice_workspace, /*deterministic=*/true)); // Wrapping address computation thunk around the GEMM thunk. - std::vector lhs_offsets{slice_lhs_offset_0, - slice_lhs_offset_1}; + std::vector> lhs_offsets{ + slice_lhs_offset_0, slice_lhs_offset_1}; AddressComputationThunk thunk( Thunk::ThunkInfo(), std::make_unique(std::move(seq)), {slice_lhs, slice_rhs, slice_out, slice_workspace}, @@ -288,10 +290,10 @@ TEST(AddressComputationThunkTest, SlicedNonContiguousGemm) { slice_out, slice_workspace, /*deterministic=*/true)); // Wrapping address computation thunk around the GEMM thunk. - std::vector lhs_offsets{slice_lhs_offset_0, - slice_lhs_offset_1}; - std::vector rhs_offsets{slice_rhs_offset_0, - slice_rhs_offset_1}; + std::vector> lhs_offsets{ + slice_lhs_offset_0, slice_lhs_offset_1}; + std::vector> rhs_offsets{ + slice_rhs_offset_0, slice_rhs_offset_1}; AddressComputationThunk thunk( Thunk::ThunkInfo(), std::make_unique(std::move(seq)), {slice_lhs, slice_rhs, slice_out, slice_workspace}, @@ -452,10 +454,10 @@ TEST(AddressComputationThunkTest, MulipleSlicedOperandsGemm) { slice_out, slice_workspace, /*deterministic=*/true)); // Wrapping address computation thunk around the GEMM thunk. - std::vector lhs_offsets{slice_lhs_offset_0, - slice_lhs_offset_1}; - std::vector rhs_offsets{slice_rhs_offset_0, - slice_rhs_offset_1}; + std::vector> lhs_offsets{ + slice_lhs_offset_0, slice_lhs_offset_1}; + std::vector> rhs_offsets{ + slice_rhs_offset_0, slice_rhs_offset_1}; AddressComputationThunk thunk( Thunk::ThunkInfo(), std::make_unique(std::move(seq)), {slice_lhs, slice_rhs, slice_out, slice_workspace}, @@ -630,7 +632,7 @@ TEST(AddressComputationThunkTest, SlicedMemcpy) { /*called_computation=*/nullptr)); // Wrapping address computation thunk around the custom call thunk. - std::vector slice_offsets{ + std::vector> slice_offsets{ slice_offset_0, slice_offset_1, slice_offset_2, slice_offset_3}; AddressComputationThunk thunk( Thunk::ThunkInfo(), std::make_unique(std::move(seq)), @@ -788,10 +790,10 @@ TEST(AddressComputationThunkTest, SlicedOutputMemcpy) { /*called_computation=*/nullptr)); // Wrapping address computation thunk around the custom call thunk. - std::vector slice_src_offsets{ + std::vector> slice_src_offsets{ slice_src_offset_0, slice_src_offset_1, slice_src_offset_2, slice_src_offset_3}; - std::vector slice_dst_offsets{ + std::vector> slice_dst_offsets{ slice_dst_offset_0, slice_dst_offset_1, slice_dst_offset_2, slice_dst_offset_3}; AddressComputationThunk thunk( @@ -968,8 +970,8 @@ TEST(AddressComputationThunkTest, SlicedGemmArbitraryArgumentOrder) { slice_out_fake, slice_workspace_fake, /*deterministic=*/true)); // Wrapping address computation thunk around the GEMM thunk. - std::vector lhs_offsets{slice_lhs_offset_0, - slice_lhs_offset_1}; + std::vector> lhs_offsets{ + slice_lhs_offset_0, slice_lhs_offset_1}; AddressComputationThunk thunk( Thunk::ThunkInfo(), std::make_unique(std::move(seq)), {slice_lhs, slice_rhs, slice_out, slice_workspace}, @@ -1116,8 +1118,8 @@ TEST(AddressComputationThunkTest, SlicedGemmArbitraryNumberOfArguments) { slice_out_fake, slice_workspace_fake, /*deterministic=*/true)); // Wrapping address computation thunk around the GEMM thunk. - std::vector lhs_offsets{slice_lhs_offset_0, - slice_lhs_offset_1}; + std::vector> lhs_offsets{ + slice_lhs_offset_0, slice_lhs_offset_1}; AddressComputationThunk thunk( Thunk::ThunkInfo(), std::make_unique(std::move(seq)), {slice_lhs, slice_rhs, slice_out, slice_workspace}, @@ -1257,8 +1259,8 @@ TEST(AddressComputationThunkTest, SlicedTupledOperandGemm) { slice_workspace, /*deterministic=*/true)); // Wrapping address computation thunk around the GEMM thunk. - std::vector lhs_offsets{slice_lhs_offset_0, - slice_lhs_offset_1}; + std::vector> lhs_offsets{ + slice_lhs_offset_0, slice_lhs_offset_1}; AddressComputationThunk thunk( Thunk::ThunkInfo(), std::make_unique(std::move(seq)), {slice_lhs, slice_rhs, slice_out, slice_workspace}, @@ -1427,10 +1429,10 @@ TEST(AddressComputationThunkTest, SlicedMemcpyOOB) { /*called_computation=*/nullptr)); // Wrapping address computation thunk around the custom call thunk. - std::vector slice_src_offsets{ + std::vector> slice_src_offsets{ slice_src_offset_0, slice_src_offset_1, slice_src_offset_2, slice_src_offset_3}; - std::vector slice_dst_offsets{ + std::vector> slice_dst_offsets{ slice_dst_offset_0, slice_dst_offset_1, slice_dst_offset_2, slice_dst_offset_3}; AddressComputationThunk thunk( @@ -1608,8 +1610,8 @@ TEST(AddressComputationThunkTest, SlicedOperandsSameBufferGemm) { slice_out_fake, slice_workspace_fake, /*deterministic=*/true)); // Wrapping address computation thunk around the GEMM thunk. - std::vector lhs_offsets{slice_lhs_offset_0, - slice_lhs_offset_1}; + std::vector> lhs_offsets{ + slice_lhs_offset_0, slice_lhs_offset_1}; AddressComputationThunk thunk( Thunk::ThunkInfo(), std::make_unique(std::move(seq)), {slice_lhs, slice_rhs, slice_out, slice_workspace},