From 6d06d8a4ac13c47def72109bc2b63d6d682d6fc2 Mon Sep 17 00:00:00 2001 From: "Chunxiang (Jake) Zheng" Date: Thu, 27 Jun 2024 09:12:09 -0700 Subject: [PATCH 1/3] Explicitly #include onednn_config.proto.h in a few places. onednn_config.proto is moved out of backend_config.proto in PR#10301, and hence we need to explicitly #include onednn_config.proto.h to avoid build break when strict layering check is used. PiperOrigin-RevId: 647339933 --- tensorflow/core/kernels/mkl/BUILD | 1 + third_party/xla/xla/service/cpu/BUILD | 9 +++++++++ third_party/xla/xla/service/cpu/ir_emitter.cc | 1 + third_party/xla/xla/service/cpu/onednn_convolution.cc | 1 + .../xla/xla/service/cpu/onednn_convolution_rewriter.cc | 1 + third_party/xla/xla/service/cpu/onednn_layer_norm.cc | 1 + third_party/xla/xla/service/cpu/onednn_matmul.cc | 1 + .../xla/xla/service/cpu/onednn_matmul_rewriter.cc | 1 + third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc | 1 + third_party/xla/xla/service/cpu/onednn_softmax.cc | 1 + 10 files changed, 18 insertions(+) diff --git a/tensorflow/core/kernels/mkl/BUILD b/tensorflow/core/kernels/mkl/BUILD index fd1ff2abcdea16..3dc886b53b1032 100644 --- a/tensorflow/core/kernels/mkl/BUILD +++ b/tensorflow/core/kernels/mkl/BUILD @@ -99,6 +99,7 @@ tf_mkl_kernel_library( ], deps = [ "//tensorflow/core:graph", + "@com_google_absl//absl/container:inlined_vector", ] + MKL_DEPS, ) diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index fef197f94143e8..40f01c0779674b 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -692,6 +692,7 @@ cc_library( ":elemental_math_emitter", ":ir_emission_utils", ":ir_function", + ":onednn_config_proto_cc", ":onednn_memory_util", ":parallel_loop_emitter", ":target_machine_features", @@ -1719,6 +1720,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":backend_config_proto_cc", + ":onednn_config_proto_cc", ":onednn_memory_util", ":onednn_util", ":runtime_lightweight_check", @@ -1743,6 +1745,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":backend_config_proto_cc", + ":onednn_config_proto_cc", ":onednn_memory_util", ":onednn_util", ":runtime_lightweight_check", @@ -1769,6 +1772,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":backend_config_proto_cc", + ":onednn_config_proto_cc", ":onednn_memory_util", ":runtime_lightweight_check", "//xla:executable_run_options", @@ -1792,6 +1796,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":backend_config_proto_cc", + ":onednn_config_proto_cc", ":onednn_memory_util", ":runtime_lightweight_check", "//xla:executable_run_options", @@ -1824,6 +1829,7 @@ cc_library( copts = tsl_copts(), deps = [ ":backend_config_proto_cc", + ":onednn_config_proto_cc", ":onednn_matmul", ":onednn_memory_util", ":onednn_pattern_utils", @@ -1854,6 +1860,7 @@ cc_library( copts = tsl_copts(), deps = [ ":backend_config_proto_cc", + ":onednn_config_proto_cc", ":onednn_memory_util", ":onednn_pattern_utils", ":onednn_util", @@ -1877,6 +1884,7 @@ cc_library( copts = tsl_copts(), deps = [ ":backend_config_proto_cc", + ":onednn_config_proto_cc", ":onednn_convolution", ":onednn_memory_util", ":onednn_util", @@ -1890,6 +1898,7 @@ cc_library( "//xla/service:hlo_pass", "//xla/service:pattern_matcher", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:env", diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index 0af8d8f5d23cd7..b8344183a0711c 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -68,6 +68,7 @@ limitations under the License. #include "xla/service/cpu/elemental_math_emitter.h" #include "xla/service/cpu/ir_emission_utils.h" #include "xla/service/cpu/ir_function.h" +#include "xla/service/cpu/onednn_config.pb.h" #include "xla/service/cpu/parallel_loop_emitter.h" #include "xla/service/elemental_ir_emitter.h" #include "xla/service/hlo_module_config.h" diff --git a/third_party/xla/xla/service/cpu/onednn_convolution.cc b/third_party/xla/xla/service/cpu/onednn_convolution.cc index f7a4e17d339fb7..7225169fcc512a 100644 --- a/third_party/xla/xla/service/cpu/onednn_convolution.cc +++ b/third_party/xla/xla/service/cpu/onednn_convolution.cc @@ -31,6 +31,7 @@ limitations under the License. #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "xla/executable_run_options.h" #include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/onednn_config.pb.h" #include "xla/service/cpu/onednn_memory_util.h" #include "xla/service/cpu/runtime_lightweight_check.h" #include "xla/tsl/util/onednn_threadpool.h" diff --git a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc index b8533cb2eb7481..0c65c5d3dd2ff2 100644 --- a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc @@ -19,6 +19,7 @@ limitations under the License. #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/onednn_config.pb.h" #include "xla/service/cpu/onednn_memory_util.h" #include "xla/service/cpu/onednn_util.h" #include "xla/service/pattern_matcher.h" diff --git a/third_party/xla/xla/service/cpu/onednn_layer_norm.cc b/third_party/xla/xla/service/cpu/onednn_layer_norm.cc index 6abd698f5898a6..8807fb3a73d60d 100644 --- a/third_party/xla/xla/service/cpu/onednn_layer_norm.cc +++ b/third_party/xla/xla/service/cpu/onednn_layer_norm.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/base/dynamic_annotations.h" #include "xla/executable_run_options.h" #include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/onednn_config.pb.h" #include "xla/service/cpu/onednn_memory_util.h" #include "xla/service/cpu/runtime_lightweight_check.h" #include "xla/tsl/util/onednn_threadpool.h" diff --git a/third_party/xla/xla/service/cpu/onednn_matmul.cc b/third_party/xla/xla/service/cpu/onednn_matmul.cc index 5e066795344766..8e8468e599541c 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul.cc +++ b/third_party/xla/xla/service/cpu/onednn_matmul.cc @@ -30,6 +30,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/onednn_config.pb.h" #include "xla/service/cpu/onednn_memory_util.h" #include "xla/service/cpu/onednn_util.h" #include "xla/service/cpu/runtime_lightweight_check.h" diff --git a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc index 95098dd31cfc6b..3ddfec5e41568c 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc @@ -26,6 +26,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/onednn_config.pb.h" #include "xla/service/cpu/onednn_matmul.h" #include "xla/service/cpu/onednn_memory_util.h" #include "xla/service/cpu/onednn_pattern_utils.h" diff --git a/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc index f95bf13880000a..ec94eb695d2397 100644 --- a/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc @@ -20,6 +20,7 @@ limitations under the License. #include "xla/literal_comparison.h" #include "xla/literal_util.h" #include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/onednn_config.pb.h" #include "xla/service/cpu/onednn_memory_util.h" #include "xla/service/cpu/onednn_pattern_utils.h" #include "xla/service/cpu/onednn_util.h" diff --git a/third_party/xla/xla/service/cpu/onednn_softmax.cc b/third_party/xla/xla/service/cpu/onednn_softmax.cc index 8c187a042b2ab2..275a58fdb2145e 100644 --- a/third_party/xla/xla/service/cpu/onednn_softmax.cc +++ b/third_party/xla/xla/service/cpu/onednn_softmax.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/base/dynamic_annotations.h" #include "xla/executable_run_options.h" #include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/onednn_config.pb.h" #include "xla/service/cpu/onednn_memory_util.h" #include "xla/service/cpu/runtime_lightweight_check.h" #include "xla/tsl/util/onednn_threadpool.h" From 8213f2421b7d56f70b5066918f0a3477c3cc29e8 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 27 Jun 2024 09:13:27 -0700 Subject: [PATCH 2/3] [xla:ffi] Add a test for returning errors from FFI handlers PiperOrigin-RevId: 647340322 --- third_party/xla/xla/ffi/api/ffi_test.cc | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc index df234501637b8a..77c81143774a74 100644 --- a/third_party/xla/xla/ffi/api/ffi_test.cc +++ b/third_party/xla/xla/ffi/api/ffi_test.cc @@ -136,6 +136,17 @@ TEST(FfiTest, ErrorEnumValue) { encoded(ErrorCode::kUnauthenticated)); } +TEST(FfiTest, ReturnError) { + CallFrameBuilder builder; + auto call_frame = builder.Build(); + + auto handler = Ffi::Bind().To( + []() { return Error(ErrorCode::kInternal, "Test error"); }); + + auto status = Call(*handler, call_frame); + EXPECT_EQ(status, absl::InternalError("Test error")); +} + TEST(FfiTest, AnyBufferArgument) { std::vector storage(4, 0.0f); se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); From c729596c18adb86cbf271b6e377a13acff538412 Mon Sep 17 00:00:00 2001 From: Seher Ellis Date: Tue, 25 Jun 2024 18:25:17 -0700 Subject: [PATCH 3/3] [XLA:CollectivePipeliner] Refactor IsSupportedDynamicUpdateSlice into a separate function. Reverts 49f4d9052259ae562ad0b0b84b5ba759494e6f83 FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/13106 from buptzyb:asyncpool f25bd029f1308b23ab8f91876ae967fcfad29890 PiperOrigin-RevId: 646676579 --- tensorflow/core/common_runtime/gpu/BUILD | 1 + .../gpu/gpu_serving_device_selector.cc | 2 +- .../gpu/gpu_serving_device_selector.h | 4 +- .../gpu/gpu_serving_device_selector_test.cc | 24 +- tensorflow/lite/CMakeLists.txt | 50 ++-- .../lite/kernels/cpu_backend_context.cc | 4 + tensorflow/lite/kernels/cpu_backend_context.h | 6 + .../lite/python/kernel_tests/signal/BUILD | 40 +++ .../python/kernel_tests/signal/test_util.py | 69 +++++ .../kernel_tests/signal/window_ops_test.py | 60 +++++ tensorflow/python/kernel_tests/signal/BUILD | 3 - .../python/kernel_tests/signal/test_util.py | 52 ---- .../kernel_tests/signal/window_ops_test.py | 25 -- .../xla/xla/service/collective_pipeliner.cc | 240 +++++++++++------- .../xla/xla/service/sharding_propagation.cc | 20 +- .../xla/service/sharding_propagation_test.cc | 52 ++++ .../gpu/gpu_cudamallocasync_allocator.cc | 25 +- .../gpu/gpu_cudamallocasync_allocator_test.cc | 33 +++ .../tsl/framework/serving_device_selector.h | 8 + .../test_util/mock_serving_device_selector.h | 2 + 20 files changed, 504 insertions(+), 216 deletions(-) create mode 100644 tensorflow/lite/python/kernel_tests/signal/BUILD create mode 100644 tensorflow/lite/python/kernel_tests/signal/test_util.py create mode 100644 tensorflow/lite/python/kernel_tests/signal/window_ops_test.py diff --git a/tensorflow/core/common_runtime/gpu/BUILD b/tensorflow/core/common_runtime/gpu/BUILD index a6d52852eba92a..16115ea2c94c4b 100644 --- a/tensorflow/core/common_runtime/gpu/BUILD +++ b/tensorflow/core/common_runtime/gpu/BUILD @@ -479,5 +479,6 @@ cc_library( name = "gpu_scheduling_metrics_storage", srcs = ["gpu_scheduling_metrics_storage.cc"], hdrs = ["gpu_scheduling_metrics_storage.h"], + visibility = ["//visibility:public"], deps = ["@local_xla//xla/tsl/framework:real_time_in_memory_metric"], ) diff --git a/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.cc b/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.cc index 3e440c4c6d8d3e..2de45258933f7a 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.cc @@ -69,7 +69,7 @@ tsl::DeviceReservation GpuServingDeviceSelector::ReserveDevice( void GpuServingDeviceSelector::FreeDeviceReservation( const tsl::DeviceReservation& reservation) { - Completed(reservation.device_index()); + Completed(reservation.device_index(), /*had_error=*/false); } void GpuServingDeviceSelector::Enqueue(int32_t index_on_host, diff --git a/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h b/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h index c6f352acb961f6..51a342fb56f14f 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h +++ b/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h @@ -61,12 +61,12 @@ class GpuServingDeviceSelector : public tsl::ServingDeviceSelector { absl::string_view program_fingerprint) override; // Enqueues the program on the stream of index `index_on_host`. - void Enqueue(int32_t index_on_host, absl::string_view fingerprint); + void Enqueue(int32_t index_on_host, absl::string_view fingerprint) override; // Marks the completion of a program on the given stream. // If `had_error` is true, this function doesn't update program's execution // time stats to avoid incorrect estimates. - void Completed(int32_t index_on_host, bool had_error = false); + void Completed(int32_t index_on_host, bool had_error) override; private: friend class ServingDeviceSelectorTestHelper; diff --git a/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector_test.cc b/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector_test.cc index fef168a30ea4c8..8d6e7f21f2faf7 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector_test.cc @@ -84,18 +84,18 @@ TEST(GpuServingDeviceSelector, DefaultPolicyOnlyEnqueueCall) { serving_device_selector->Enqueue(1, "4ms"); serving_device_selector->Enqueue(0, "2ms"); helper.ElapseNs(2e6); - serving_device_selector->Completed(0); + serving_device_selector->Completed(0, false); helper.ElapseNs(2e6); - serving_device_selector->Completed(0); - serving_device_selector->Completed(1); + serving_device_selector->Completed(0, false); + serving_device_selector->Completed(1, false); helper.ElapseNs(4e6); - serving_device_selector->Completed(1); - serving_device_selector->Completed(2); + serving_device_selector->Completed(1, false); + serving_device_selector->Completed(2, false); helper.ElapseNs(8e6); - serving_device_selector->Completed(2); - serving_device_selector->Completed(3); + serving_device_selector->Completed(2, false); + serving_device_selector->Completed(3, false); helper.ElapseNs(16e6); - serving_device_selector->Completed(3); + serving_device_selector->Completed(3, false); serving_device_selector->Enqueue(3, "16ms"); EXPECT_EQ( @@ -114,22 +114,22 @@ TEST(GpuServingDeviceSelector, DefaultPolicyOnlyEnqueueCall) { GpuSchedulingMetricsStorage::GetGlobalStorage().TotalGpuLoadNs().Get(), 30e6); helper.ElapseNs(2e6); - serving_device_selector->Completed(0); + serving_device_selector->Completed(0, false); EXPECT_EQ( GpuSchedulingMetricsStorage::GetGlobalStorage().TotalGpuLoadNs().Get(), 22e6); helper.ElapseNs(2e6); - serving_device_selector->Completed(1); + serving_device_selector->Completed(1, false); EXPECT_EQ( GpuSchedulingMetricsStorage::GetGlobalStorage().TotalGpuLoadNs().Get(), 16e6); helper.ElapseNs(4e6); - serving_device_selector->Completed(2); + serving_device_selector->Completed(2, false); EXPECT_EQ( GpuSchedulingMetricsStorage::GetGlobalStorage().TotalGpuLoadNs().Get(), 8e6); helper.ElapseNs(8e6); - serving_device_selector->Completed(3); + serving_device_selector->Completed(3, false); EXPECT_EQ( GpuSchedulingMetricsStorage::GetGlobalStorage().TotalGpuLoadNs().Get(), 0e6); diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt index 6e51f1115943cd..5b84aac7f82c20 100644 --- a/tensorflow/lite/CMakeLists.txt +++ b/tensorflow/lite/CMakeLists.txt @@ -182,26 +182,29 @@ include_directories( ${XLA_SOURCE_DIR} ) # Download necessary dependencies. -# Download pthreadpool source package if it doesn't exist. -if(SYSTEM_PTHREADPOOL) - find_library(PTHREADPOOL_LIB pthreadpool REQUIRED) -elseif(NOT DEFINED PTHREADPOOL_SOURCE_DIR) - message(STATUS "Downloading pthreadpool to ${CMAKE_BINARY_DIR}/pthreadpool-source (define SYSTEM_PTHREADPOOL or PTHREADPOOL_SOURCE_DIR to avoid it)") - configure_file(cmake/DownloadPThreadPool.cmake "${CMAKE_BINARY_DIR}/pthreadpool-download/CMakeLists.txt") - execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . - WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/pthreadpool-download") - execute_process(COMMAND "${CMAKE_COMMAND}" --build . - WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/pthreadpool-download") - set(PTHREADPOOL_SOURCE_DIR "${CMAKE_BINARY_DIR}/pthreadpool-source" CACHE STRING "pthreadpool source directory") -endif() -# Configure pthreadpool -if(NOT SYSTEM_PTHREADPOOL AND NOT TARGET pthreadpool) - set(PTHREADPOOL_BUILD_TESTS OFF CACHE BOOL "") - set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE BOOL "") - set(PTHREADPOOL_ALLOW_DEPRECATED_API OFF CACHE BOOL "") - add_subdirectory( - "${PTHREADPOOL_SOURCE_DIR}" - "${CMAKE_BINARY_DIR}/pthreadpool") +if(TFLITE_ENABLE_XNNPACK) + # pthreadpool is used by XNNPACK. + if(SYSTEM_PTHREADPOOL) + find_library(PTHREADPOOL_LIB pthreadpool REQUIRED) + elseif(NOT DEFINED PTHREADPOOL_SOURCE_DIR) + message(STATUS "Downloading pthreadpool to ${CMAKE_BINARY_DIR}/pthreadpool-source (define SYSTEM_PTHREADPOOL or PTHREADPOOL_SOURCE_DIR to avoid it)") + configure_file(cmake/DownloadPThreadPool.cmake "${CMAKE_BINARY_DIR}/pthreadpool-download/CMakeLists.txt") + execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . + WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/pthreadpool-download") + execute_process(COMMAND "${CMAKE_COMMAND}" --build . + WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/pthreadpool-download") + set(PTHREADPOOL_SOURCE_DIR "${CMAKE_BINARY_DIR}/pthreadpool-source" CACHE STRING "pthreadpool source directory") + endif() + # Configure pthreadpool + if(NOT SYSTEM_PTHREADPOOL AND NOT TARGET pthreadpool) + set(PTHREADPOOL_BUILD_TESTS OFF CACHE BOOL "") + set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE BOOL "") + set(PTHREADPOOL_ALLOW_DEPRECATED_API OFF CACHE BOOL "") + add_subdirectory( + "${PTHREADPOOL_SOURCE_DIR}" + "${CMAKE_BINARY_DIR}/pthreadpool") + endif() + list(APPEND TFLITE_TARGET_DEPENDENCIES pthreadpool) endif() set(TF_TARGET_PRIVATE_OPTIONS "") if(CMAKE_CXX_COMPILER_ID MATCHES "Clang$") @@ -536,6 +539,12 @@ if(TFLITE_ENABLE_XNNPACK) ruy::ruy ) + list(APPEND TFLITE_TARGET_PUBLIC_OPTIONS "-DTFLITE_KERNEL_USE_XNNPACK") + target_compile_options(xnnpack-delegate + PUBLIC ${TFLITE_TARGET_PUBLIC_OPTIONS} + PRIVATE ${TFLITE_TARGET_PRIVATE_OPTIONS} + ) + list(APPEND TFLITE_TARGET_DEPENDENCIES xnnpack-delegate XNNPACK @@ -696,7 +705,6 @@ target_link_libraries(tensorflow-lite gemmlowp::gemmlowp ml_dtypes ruy::ruy - pthreadpool ${CMAKE_DL_LIBS} ${TFLITE_TARGET_DEPENDENCIES} ) diff --git a/tensorflow/lite/kernels/cpu_backend_context.cc b/tensorflow/lite/kernels/cpu_backend_context.cc index 70ae333bb78e17..010d38a8564cdd 100644 --- a/tensorflow/lite/kernels/cpu_backend_context.cc +++ b/tensorflow/lite/kernels/cpu_backend_context.cc @@ -17,7 +17,9 @@ limitations under the License. #include +#ifdef TFLITE_KERNEL_USE_XNNPACK #include "pthreadpool.h" // from @pthreadpool +#endif #ifdef TFLITE_HAVE_CPUINFO #include "include/cpuinfo.h" @@ -149,6 +151,7 @@ void CpuBackendContext::SetMaxNumThreads(int max_num_threads) { void CpuBackendContext::SetUseCaching(bool flag) { use_caching_ = flag; } +#ifdef TFLITE_KERNEL_USE_XNNPACK pthreadpool_t CpuBackendContext::get_xnnpack_threadpool() { if (!xnnpack_threadpool_ && max_num_threads_ > 1) { xnnpack_threadpool_.reset( @@ -156,6 +159,7 @@ pthreadpool_t CpuBackendContext::get_xnnpack_threadpool() { } return xnnpack_threadpool_.get(); } +#endif bool CpuBackendContext::PreferGemmlowpOnX86() { bool use_gemmlowp_on_x86 = false; diff --git a/tensorflow/lite/kernels/cpu_backend_context.h b/tensorflow/lite/kernels/cpu_backend_context.h index 24e70bb6391619..d5bf9114c4b622 100644 --- a/tensorflow/lite/kernels/cpu_backend_context.h +++ b/tensorflow/lite/kernels/cpu_backend_context.h @@ -24,7 +24,9 @@ limitations under the License. #include #include "public/gemmlowp.h" +#ifdef TFLITE_KERNEL_USE_XNNPACK #include "pthreadpool.h" // from @pthreadpool +#endif #include "ruy/context.h" // from @ruy #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/external_cpu_backend_context.h" @@ -54,7 +56,9 @@ class CpuBackendContext final : public TfLiteInternalBackendContext { bool use_caching() const { return use_caching_; } +#ifdef TFLITE_KERNEL_USE_XNNPACK pthreadpool_t get_xnnpack_threadpool(); +#endif void ClearCaches() override { ruy_context_->ClearPrepackedCache(); } @@ -118,10 +122,12 @@ class CpuBackendContext final : public TfLiteInternalBackendContext { // (currently the Ruy library only). bool use_caching_; +#ifdef TFLITE_KERNEL_USE_XNNPACK // A smart pointer for the xnnpack threadpool. Is created by a call from the // interpreter, and then consumed by xnnpack, possibly via a TFLite kernel. std::unique_ptr xnnpack_threadpool_{nullptr, &pthreadpool_destroy}; +#endif CpuBackendContext(const CpuBackendContext&) = delete; }; diff --git a/tensorflow/lite/python/kernel_tests/signal/BUILD b/tensorflow/lite/python/kernel_tests/signal/BUILD new file mode 100644 index 00000000000000..a6128e6be32f54 --- /dev/null +++ b/tensorflow/lite/python/kernel_tests/signal/BUILD @@ -0,0 +1,40 @@ +load("//tensorflow:strict.default.bzl", "py_strict_library") +load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], +) + +py_strict_library( + name = "test_util", + srcs = ["test_util.py"], + srcs_version = "PY3", + deps = [ + "//tensorflow/lite/python:interpreter", + "//tensorflow/lite/python:lite", + "//tensorflow/python/eager:def_function", + ], +) + +cuda_py_strict_test( + name = "window_ops_test", + srcs = ["window_ops_test.py"], + python_version = "PY3", + shard_count = 4, + tags = [ + "no_rocm", + "no_windows_gpu", + ], + deps = [ + ":test_util", + "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops/signal:window_ops", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/lite/python/kernel_tests/signal/test_util.py b/tensorflow/lite/python/kernel_tests/signal/test_util.py new file mode 100644 index 00000000000000..2d32e8d0d48882 --- /dev/null +++ b/tensorflow/lite/python/kernel_tests/signal/test_util.py @@ -0,0 +1,69 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test utilities for tf.signal.""" + +from tensorflow.lite.python import interpreter +from tensorflow.lite.python import lite +from tensorflow.python.eager import def_function + + +def tflite_convert(fn, input_templates): + """Converts the provided fn to tf.lite model. + + Args: + fn: A callable that expects a list of inputs like input_templates that + returns a tensor or structure of tensors. + input_templates: A list of Tensors, ndarrays or TensorSpecs describing the + inputs that fn expects. The actual values of the Tensors or ndarrays are + unused. + + Returns: + The serialized tf.lite model. + """ + fn = def_function.function(fn) + concrete_func = fn.get_concrete_function(*input_templates) + converter = lite.TFLiteConverterV2([concrete_func]) + return converter.convert() + + +def evaluate_tflite_model(tflite_model, input_ndarrays): + """Evaluates the provided tf.lite model with the given input ndarrays. + + Args: + tflite_model: bytes. The serialized tf.lite model. + input_ndarrays: A list of NumPy arrays to feed as input to the model. + + Returns: + A list of ndarrays produced by the model. + + Raises: + ValueError: If the number of input arrays does not match the number of + inputs the model expects. + """ + the_interpreter = interpreter.Interpreter(model_content=tflite_model) + the_interpreter.allocate_tensors() + + input_details = the_interpreter.get_input_details() + output_details = the_interpreter.get_output_details() + + if len(input_details) != len(input_ndarrays): + raise ValueError('Wrong number of inputs: provided=%s, ' + 'input_details=%s output_details=%s' % ( + input_ndarrays, input_details, output_details)) + for input_tensor, data in zip(input_details, input_ndarrays): + the_interpreter.set_tensor(input_tensor['index'], data) + the_interpreter.invoke() + return [the_interpreter.get_tensor(details['index']) + for details in output_details] diff --git a/tensorflow/lite/python/kernel_tests/signal/window_ops_test.py b/tensorflow/lite/python/kernel_tests/signal/window_ops_test.py new file mode 100644 index 00000000000000..240d04c7138ae0 --- /dev/null +++ b/tensorflow/lite/python/kernel_tests/signal/window_ops_test.py @@ -0,0 +1,60 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for window_ops.""" + +from absl.testing import parameterized +import numpy as np + +from tensorflow.lite.python.kernel_tests.signal import test_util +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import test_util as tf_test_util +from tensorflow.python.ops.signal import window_ops +from tensorflow.python.platform import test + + +@tf_test_util.run_all_in_graph_and_eager_modes +class WindowOpsTest(test.TestCase, parameterized.TestCase): + + @parameterized.parameters( + # Only float32 is supported. + (window_ops.hann_window, 10, False, dtypes.float32), + (window_ops.hann_window, 10, True, dtypes.float32), + (window_ops.hamming_window, 10, False, dtypes.float32), + (window_ops.hamming_window, 10, True, dtypes.float32), + (window_ops.vorbis_window, 12, None, dtypes.float32), + ) + def test_tflite_convert(self, window_fn, window_length, periodic, dtype): + + def fn(window_length): + try: + return window_fn(window_length, periodic=periodic, dtype=dtype) + except TypeError: + return window_fn(window_length, dtype=dtype) + + tflite_model = test_util.tflite_convert( + fn, [tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32)] + ) + window_length = np.array(window_length).astype(np.int32) + (actual_output,) = test_util.evaluate_tflite_model( + tflite_model, [window_length] + ) + + expected_output = self.evaluate(fn(window_length)) + self.assertAllClose(actual_output, expected_output, rtol=1e-6, atol=1e-6) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/kernel_tests/signal/BUILD b/tensorflow/python/kernel_tests/signal/BUILD index 687f27d7d2b078..641303f87bdfb2 100644 --- a/tensorflow/python/kernel_tests/signal/BUILD +++ b/tensorflow/python/kernel_tests/signal/BUILD @@ -14,9 +14,6 @@ py_strict_library( srcs_version = "PY3", deps = [ "//tensorflow/core:protos_all_py", - "//tensorflow/lite/python:interpreter", - "//tensorflow/lite/python:lite", - "//tensorflow/python/eager:def_function", "//tensorflow/python/grappler:tf_optimizer", "//tensorflow/python/training:saver", ], diff --git a/tensorflow/python/kernel_tests/signal/test_util.py b/tensorflow/python/kernel_tests/signal/test_util.py index 0d82ca8e634122..3a2f7b7005ed5b 100644 --- a/tensorflow/python/kernel_tests/signal/test_util.py +++ b/tensorflow/python/kernel_tests/signal/test_util.py @@ -15,9 +15,6 @@ """Test utilities for tf.signal.""" from tensorflow.core.protobuf import config_pb2 -from tensorflow.lite.python import interpreter -from tensorflow.lite.python import lite -from tensorflow.python.eager import def_function from tensorflow.python.grappler import tf_optimizer from tensorflow.python.training import saver @@ -45,52 +42,3 @@ def grappler_optimize(graph, fetches=None, config_proto=None): metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def()) return tf_optimizer.OptimizeGraph(config_proto, metagraph) - -def tflite_convert(fn, input_templates): - """Converts the provided fn to tf.lite model. - - Args: - fn: A callable that expects a list of inputs like input_templates that - returns a tensor or structure of tensors. - input_templates: A list of Tensors, ndarrays or TensorSpecs describing the - inputs that fn expects. The actual values of the Tensors or ndarrays are - unused. - - Returns: - The serialized tf.lite model. - """ - fn = def_function.function(fn) - concrete_func = fn.get_concrete_function(*input_templates) - converter = lite.TFLiteConverterV2([concrete_func]) - return converter.convert() - - -def evaluate_tflite_model(tflite_model, input_ndarrays): - """Evaluates the provided tf.lite model with the given input ndarrays. - - Args: - tflite_model: bytes. The serialized tf.lite model. - input_ndarrays: A list of NumPy arrays to feed as input to the model. - - Returns: - A list of ndarrays produced by the model. - - Raises: - ValueError: If the number of input arrays does not match the number of - inputs the model expects. - """ - the_interpreter = interpreter.Interpreter(model_content=tflite_model) - the_interpreter.allocate_tensors() - - input_details = the_interpreter.get_input_details() - output_details = the_interpreter.get_output_details() - - if len(input_details) != len(input_ndarrays): - raise ValueError('Wrong number of inputs: provided=%s, ' - 'input_details=%s output_details=%s' % ( - input_ndarrays, input_details, output_details)) - for input_tensor, data in zip(input_details, input_ndarrays): - the_interpreter.set_tensor(input_tensor['index'], data) - the_interpreter.invoke() - return [the_interpreter.get_tensor(details['index']) - for details in output_details] diff --git a/tensorflow/python/kernel_tests/signal/window_ops_test.py b/tensorflow/python/kernel_tests/signal/window_ops_test.py index ba783b34ef0f9c..7bedb3ee25c7c8 100644 --- a/tensorflow/python/kernel_tests/signal/window_ops_test.py +++ b/tensorflow/python/kernel_tests/signal/window_ops_test.py @@ -23,7 +23,6 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.kernel_tests.signal import test_util from tensorflow.python.ops.signal import window_ops @@ -151,30 +150,6 @@ def test_constant_folding(self, window_fn, periodic, tf_dtype_tol): rewritten_graph = test_util.grappler_optimize(g, [window]) self.assertLen(rewritten_graph.node, 1) - @parameterized.parameters( - # Only float32 is supported. - (window_ops.hann_window, 10, False, dtypes.float32), - (window_ops.hann_window, 10, True, dtypes.float32), - (window_ops.hamming_window, 10, False, dtypes.float32), - (window_ops.hamming_window, 10, True, dtypes.float32), - (window_ops.vorbis_window, 12, None, dtypes.float32)) - def test_tflite_convert(self, window_fn, window_length, periodic, dtype): - - def fn(window_length): - try: - return window_fn(window_length, periodic=periodic, dtype=dtype) - except TypeError: - return window_fn(window_length, dtype=dtype) - - tflite_model = test_util.tflite_convert( - fn, [tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32)]) - window_length = np.array(window_length).astype(np.int32) - actual_output, = test_util.evaluate_tflite_model( - tflite_model, [window_length]) - - expected_output = self.evaluate(fn(window_length)) - self.assertAllClose(actual_output, expected_output, rtol=1e-6, atol=1e-6) - @parameterized.parameters( itertools.product( _MDCT_WINDOW_LENGTHS, diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc index 6e08ae23708dad..b15a8429c13c26 100644 --- a/third_party/xla/xla/service/collective_pipeliner.cc +++ b/third_party/xla/xla/service/collective_pipeliner.cc @@ -442,8 +442,9 @@ bool IsLoopIterator(const HloInstruction* instr, // Scavenge operands that are dependencies not included in the ops set and that // aren't the source_op passed as input parameter and return them in a vector. std::vector CollectDependenciesToPipeline( - HloInstruction* source_op, absl::Span ops) { - absl::flat_hash_set formatting_set(ops.begin(), ops.end()); + const HloInstruction* source_op, absl::Span ops) { + absl::flat_hash_set formatting_set(ops.begin(), + ops.end()); formatting_set.insert(source_op); std::vector to_return; absl::flat_hash_set already_inserted; @@ -709,6 +710,27 @@ class WhileLoopAnalysis { int64_t GetMaxPipeliningPerLoop() const { return max_pipelining_per_loop_; } bool ComputeLoopStatistics(); + // Checks if the given dynamic-update-slice is supported for pipelining and + // returns its slice dimension and index in the while tuple if supported. + // Returns std::nullopt if it is not supported, which can happen for several + // reasons: + // - The slice dimension can not be found or is not 0 for forward-sinking. + // - The number of slices size does not match the loop iteration count. + // - There is an unexpected shape/size in the overall dependency chain. + // - The buffer to insert into is not a GTE from the loop parameter. + // - The parameter usage is not compatible with the expected pattern. + // - The update slicing is not compatible with the expected pattern. + // - The update index is not monotonic. + // - The output index for the insertion can not be found. + std::optional> IsSupportedDynamicUpdateSlice( + const HloDynamicUpdateSliceInstruction* dyn_update, + const HloInstruction* instr, + const std::vector& formatting_ops, + CollectivePipeliner::PipeliningDirection direction, + int64_t level_to_operate_on, + const absl::flat_hash_map& parameter_gtes_count, + const absl::flat_hash_map& index_ranges) + const; void CollectCollectivesToMove( int64_t level_to_operate_on, CollectivePipeliner::PipeliningDirection direction, @@ -849,6 +871,119 @@ bool WhileLoopAnalysis::ComputeLoopStatistics() { return true; } +std::optional> +WhileLoopAnalysis::IsSupportedDynamicUpdateSlice( + const HloDynamicUpdateSliceInstruction* dyn_update, + const HloInstruction* instr, + const std::vector& formatting_ops, + CollectivePipeliner::PipeliningDirection direction, + int64_t level_to_operate_on, + const absl::flat_hash_map& parameter_gtes_count, + const absl::flat_hash_map& index_ranges) + const { + HloComputation* while_body = while_->while_body(); + const HloInstruction* loop_parameter = + while_body->parameter_instructions()[0]; + std::optional sliced_dim = GetSlicedDimension(dyn_update); + if (!sliced_dim.has_value()) { + VLOG(5) << "Skipping " << instr->name() + << " because couldn't find sliced dimension"; + return std::nullopt; + } + if (direction == CollectivePipeliner::PipeliningDirection::kForwardSink && + (*sliced_dim != 0 || dyn_update->shape().dimensions(0) != + loop_iteration_count_->GetUnsignedValue())) { + VLOG(5) << "Skipping " << instr->name() + << " because number of iteration of the loop doesn't match " + "slices being inserted or slice dim is not 0. slice_dim = " + << *sliced_dim + << " loop count = " << loop_iteration_count_->GetUnsignedValue(); + } + if (!process_different_sized_options_) { + if (!formatting_ops.empty()) { + if (instr->operand(0)->shape() != formatting_ops.back()->shape()) { + VLOG(5) << "Skipping " << instr->name() + << " because operand and last formatting op don't have the " + "same shape"; + return std::nullopt; + } + auto dependencies_to_pipeline = CollectDependenciesToPipeline( + instr, absl::MakeConstSpan(formatting_ops)); + bool skip_because_not_same_size = false; + // If any instruction in the dependency chain is not of the same size + // then we abort for this instruction. + for (auto* dependency : dependencies_to_pipeline) { + if (ShapeUtil::IsEffectiveScalar(dependency->shape())) { + skip_because_not_same_size = true; + break; + } + } + if (skip_because_not_same_size) { + VLOG(5) + << "Skipping " << instr->name() + << " because formatting ops do not have the expected shapes/sizes"; + return std::nullopt; + } + } else if (instr->operand(0)->shape() != instr->shape()) { + VLOG(5) << "Skipping " << instr->name() + << " because instr does not have the same shape as its operand"; + return std::nullopt; + } + } + const HloInstruction* to_insert_into = dyn_update->operand(0); + if (level_to_operate_on == 0 && + (to_insert_into->opcode() != HloOpcode::kGetTupleElement || + to_insert_into->operand(0) != loop_parameter)) { + VLOG(5) << "Skipping " << instr->name() + << " because slice to insert into is not a GTE from input " + "parameter " + << to_insert_into->ToString(); + return std::nullopt; + } + // If Level is > 0 then we already did our analysis in the previous + // iteration for safeness of this index to transform. + if (level_to_operate_on == 0) { + if (to_insert_into->opcode() == HloOpcode::kGetTupleElement) { + // GTE for this parameter is not CSEd. Abort because we don't analyze + // every single use from other GTEs. + if (parameter_gtes_count.at(to_insert_into->tuple_index()) != 1) { + VLOG(5) << "Skipping " << instr->name() + << " because there are multiple parameter GTEs for this slice"; + return std::nullopt; + } + } + const HloInstruction* dyn_update_idx = dyn_update->operand( + dyn_update->first_index_operand_number() + *sliced_dim); + if (level_to_operate_on == 0 && + !CheckParameterUsageIsCompatible(to_insert_into, dyn_update, + dyn_update_idx, *sliced_dim)) { + VLOG(5) << "Skipping " << instr->name() + << " because parameter usage doesn't follow the expected pattern"; + return std::nullopt; + } + if (!AllIndicesConstantsExceptOne( + dyn_update, + dyn_update->first_index_operand_number() + *sliced_dim)) { + VLOG(5) << "Skipping " << instr->name() + << " because update slicing doesn't match expectation"; + return std::nullopt; + } + if (!CheckIndexIsMonotonic(dyn_update_idx, index_ranges)) { + VLOG(5) << "Skipping " << instr->name() + << " because update index is not monotonic"; + return std::nullopt; + } + } + std::optional output_idx = FindOutputIndexForDynamicUpdateSlice( + dyn_update, while_body->root_instruction()); + if (!output_idx.has_value()) { + VLOG(5) << "Skipping " << instr->name() + << " because couldn't find unique output index for insertion"; + return std::nullopt; + } + return std::make_pair(*sliced_dim, *output_idx); +} + void WhileLoopAnalysis::CollectCollectivesToMove( int64_t level_to_operate_on, CollectivePipeliner::PipeliningDirection direction, @@ -927,100 +1062,15 @@ void WhileLoopAnalysis::CollectCollectivesToMove( "computation"; continue; } - std::optional sliced_dim = GetSlicedDimension(dyn_update); - if (!sliced_dim.has_value()) { - VLOG(5) << "Skipping " << instr->name() - << " because couldn't find sliced dimension"; - continue; - } - if (direction == CollectivePipeliner::PipeliningDirection::kForwardSink && - (*sliced_dim != 0 || dyn_update->shape().dimensions(0) != - loop_iteration_count_->GetUnsignedValue())) { - VLOG(5) << "Skipping " << instr->name() - << " because number of iteration of the loop doesn't match " - "slices being inserted or slice dim is not 0. slice_dim = " - << *sliced_dim << " loop count = " - << loop_iteration_count_->GetUnsignedValue(); - } - if (!process_different_sized_options_) { - if (!formatting_ops.empty()) { - if (instr->operand(0)->shape() != formatting_ops.back()->shape()) { - continue; - } - auto dependencies_to_pipeline = CollectDependenciesToPipeline( - instr, absl::MakeConstSpan(formatting_ops)); - bool skip_because_not_same_size = false; - // If any instruction in the dependency chain is not of the same size - // then we abort for this instruction. - for (auto* dependency : dependencies_to_pipeline) { - if (ShapeUtil::IsEffectiveScalar(dependency->shape())) { - skip_because_not_same_size = true; - break; - } - } - if (skip_because_not_same_size) { - continue; - } - } else if (instr->operand(0)->shape() != instr->shape()) { - continue; - } - } - const HloInstruction* to_insert_into = dyn_update->operand(0); - if (level_to_operate_on == 0 && - (to_insert_into->opcode() != HloOpcode::kGetTupleElement || - to_insert_into->operand(0) != loop_parameter)) { - VLOG(5) << "Skipping " << instr->name() - << " because slice to insert into is not a GTE from input " - "parameter " - << to_insert_into->ToString(); - continue; - } - if (dyn_update->user_count() != 1) { - continue; - } - // If Level is > 0 then we already did our analysis in the previous - // iteration for safeness of this index to transform. - if (level_to_operate_on == 0) { - if (to_insert_into->opcode() == HloOpcode::kGetTupleElement) { - // GTE for this parameter is not CSEd. Abort because we don't analyze - // every single use from other GTEs. - if (parameter_gtes_count.at(to_insert_into->tuple_index()) != 1) { - VLOG(5) - << "Skipping " << instr->name() - << " because there are multiple parameter GTEs for this slice"; - continue; - } - } - HloInstruction* dyn_update_idx = dyn_update->mutable_operand( - dyn_update->first_index_operand_number() + *sliced_dim); - if (level_to_operate_on == 0 && - !CheckParameterUsageIsCompatible(to_insert_into, dyn_update, - dyn_update_idx, *sliced_dim)) { - VLOG(5) - << "Skipping " << instr->name() - << " because parameter usage doesn't follow the expected pattern"; - continue; - } - if (!AllIndicesConstantsExceptOne( - dyn_update, - dyn_update->first_index_operand_number() + *sliced_dim)) { - VLOG(5) << "Skipping " << instr->name() - << " because update slicing doesn't match expectation"; - continue; - } - if (!CheckIndexIsMonotonic(dyn_update_idx, index_ranges)) { - VLOG(5) << "Skipping " << instr->name() - << " because update index is not monotonic"; - continue; - } - } - std::optional output_idx = FindOutputIndexForDynamicUpdateSlice( - dyn_update, while_body->root_instruction()); - if (!output_idx.has_value()) { - VLOG(5) << "Skipping " << instr->name() - << " because couldn't find unique output index for insertion"; + std::optional> maybe_dus_info = + IsSupportedDynamicUpdateSlice(dyn_update, instr, formatting_ops, + direction, level_to_operate_on, + parameter_gtes_count, index_ranges); + if (!maybe_dus_info.has_value()) { continue; } + int64_t sliced_dim = maybe_dus_info->first; + int64_t output_idx = maybe_dus_info->second; auto merge_as_formatting = [this, &instruction_order]( absl::flat_hash_map::iterator it, @@ -1057,7 +1107,7 @@ void WhileLoopAnalysis::CollectCollectivesToMove( } index_per_dyn_update_slice[dyn_update] = move_infos_.size(); move_infos_.push_back({instr, dyn_update, std::move(formatting_ops), - *sliced_dim, *output_idx}); + sliced_dim, output_idx}); } else { CHECK_EQ(direction, CollectivePipeliner::PipeliningDirection::kBackward); auto chain_collected = CollectChainsToPushBackwards( diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index ad088e6f15e8bb..de54c93c4fc8b1 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -2863,7 +2863,8 @@ absl::StatusOr ShardingPropagation::Run( // Instructions that are related through a computation and need to share the // same sharding. - auto get_related_instructions = [this](HloInstruction* inst) { + auto get_related_instructions = [this, + &computation_map](HloInstruction* inst) { if (inst->opcode() == HloOpcode::kWhile) { return std::vector{ inst, inst->while_body()->root_instruction(), @@ -2887,6 +2888,18 @@ absl::StatusOr ShardingPropagation::Run( } else if (inst->opcode() == HloOpcode::kCall) { HloComputation* callee = inst->called_computations().front(); return std::vector{inst, callee->root_instruction()}; + } else if (inst->opcode() == HloOpcode::kParameter) { + auto it = computation_map.find(inst->parent()); + if (it != computation_map.end() && + it->second->opcode() == HloOpcode::kConditional) { + HloInstruction* cond = it->second; + for (int64_t i = 1; i < cond->operand_count(); ++i) { + if (cond->called_computations()[i - 1] == inst->parent()) { + return std::vector{inst, cond->mutable_operand(i)}; + } + } + } + return std::vector{}; } else { CHECK(false); } @@ -2927,6 +2940,11 @@ absl::StatusOr ShardingPropagation::Run( auto it = computation_map.find(instruction->parent()); if (it != computation_map.end()) { propagate_to_instruction(it->second); + // Propagate parameter shardings back to conditional's operands. + if (instruction->opcode() == HloOpcode::kParameter && + it->second->opcode() == HloOpcode::kConditional) { + propagate_to_instruction(instruction); + } } } }; diff --git a/third_party/xla/xla/service/sharding_propagation_test.cc b/third_party/xla/xla/service/sharding_propagation_test.cc index 554b7a035a4f74..0123e14ec535f1 100644 --- a/third_party/xla/xla/service/sharding_propagation_test.cc +++ b/third_party/xla/xla/service/sharding_propagation_test.cc @@ -10732,6 +10732,58 @@ ENTRY entry { "{{devices=[1,8]0,1,2,3,4,5,6,7}, {devices=[1,8]0,1,2,3,4,5,6,7}}")); } +TEST_F(ShardingPropagationTest, ConditionalManual) { + const char* const hlo_string = R"( +HloModule module + +%true_comp { + %tp = (f32[3,5], f32[]) parameter(0) + %tgte.0 = f32[3,5] get-tuple-element(%tp), index=0 + %tgte.1 = f32[] get-tuple-element(%tp), index=1 + %ttr = f32[5,3] transpose(%tgte.0), dimensions={1,0} + + %broadcast.1 = f32[5,3] broadcast(%tgte.1), dimensions={} + %add.1 = f32[5,3] add(%broadcast.1, %ttr) + + ROOT %tr = (f32[5,3], f32[]) tuple(%add.1, %tgte.1) +} + +%false_comp { + %fp = (f32[5,3], f32[5,3], f32[]) parameter(0) + %fgte.0 = f32[5,3] get-tuple-element(%fp), index=0 + %fgte.1 = f32[] get-tuple-element(%fp), index=2 + ROOT %fr = (f32[5,3], f32[]) tuple(%fgte.0, %fgte.1) +} + +ENTRY entry { + %cond = pred[] parameter(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}} + %tp.0 = f32[3,5] parameter(1), sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual, replicated}} + %fp.0 = f32[5,3] parameter(2), sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual, replicated}} + %const0 = f32[] constant(0) + %const1 = f32[] constant(1) + %true_param = (f32[3,5], f32[]) tuple(%tp.0, %const0) + %false_param = (f32[5,3], f32[5,3], f32[]) tuple(%fp.0, fp.0, %const1) + ROOT %conditional = (f32[5,3], f32[]) conditional( + %cond, %true_param, %false_param), + true_computation=%true_comp, + false_computation=%false_comp +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + auto* tp = FindInstruction(module.get(), "tp"); + auto* true_param = FindInstruction(module.get(), "true_param"); + EXPECT_EQ(tp->sharding(), true_param->sharding()); + auto* fp = FindInstruction(module.get(), "fp"); + auto* false_param = FindInstruction(module.get(), "false_param"); + EXPECT_EQ(fp->sharding(), false_param->sharding()); +} + TEST_F(ShardingPropagationTest, PropagateToOutput) { const char* const hlo_string = R"( HloModule module diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc b/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc index 4edaa0a8ac3a53..9e4ab68fe9d9f2 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc @@ -241,6 +241,22 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator( static auto* all_ids_ = new std::vector(); if (!create_new_pool_) { DCHECK(all_pools_->size() == all_ids_->size()); + + // If the pool_ is found in all_pools_, it means it has been initialized + // before. This can happen in some cases, such as when multiple virtual + // devices are created from one physical GPU, the virtual devices will + // actually share the same CUDA memory pool. So the following pool + // initialization steps should be skipped to avoid duplicated initialization + // of the same pool. + for (auto& pool_item_ : *all_pools_) { + if (*pool_item_ == pool_) { + VLOG(2) << Name() + << " GpuCudaMallocAsyncAllocator pool already initialized. " + "PoolSize " + << pool_size; + return; + } + } for (int i = 0; i < all_pools_->size(); ++i) { // Set the current pool access to the previous GPUs. CUmemAccessDesc map; @@ -273,9 +289,10 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator( // Set the previous pools access to the current GPU. map.location.id = platform_device_id.value(); - VLOG(2) << "Set access to the pool id: " << i + int previous_pool_id = (*all_ids_)[i].value(); + VLOG(2) << "Set access to the pool id: " << previous_pool_id << " location id: " << map.location.id; - if (auto status = cuDeviceCanAccessPeer(&canAccessPeer, i, + if (auto status = cuDeviceCanAccessPeer(&canAccessPeer, previous_pool_id, platform_device_id.value())) { pool_ = nullptr; LOG(FATAL) // Crash OK. @@ -285,8 +302,8 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator( if (auto status = cuMemPoolSetAccess(*(*all_pools_)[i], &map, 1)) { pool_ = nullptr; LOG(FATAL) // Crash OK. - << "Error when setting access to the pool id: " << i - << " location id: " << map.location.id + << "Error when setting access to the pool id: " + << previous_pool_id << " location id: " << map.location.id << " error: " << GetCudaErrorMessage(status); } } diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator_test.cc index 0b155be7f91ced..6211d35bc64241 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator_test.cc @@ -47,6 +47,39 @@ static se::StreamExecutor* GpuExecutor() { namespace stream_executor { +TEST(GpuCudaMallocAsyncAllocator, TwoAllocatorsShareDefaultPool) { +#if CUDA_VERSION < 11030 + GTEST_SKIP() << "Cuda async memory allocator is not supported for CUDA " + "version less than 11030"; +#endif + + se::StreamExecutor* executor = GpuExecutor(); + TF_ASSERT_OK_AND_ASSIGN(auto stream1, executor->CreateStream()); + auto allocator1 = GpuCudaMallocAsyncAllocator( + /*platform_device_id*/ tsl::PlatformDeviceId(executor->device_ordinal()), + /*pool_size*/ 2048, + /*new_pool_size*/ true, + /*release_threshold*/ true); + allocator1.SetStreamAndPreallocateMemory( + se::gpu::AsGpuStreamValue(stream1.get())); + TF_ASSERT_OK_AND_ASSIGN(auto stream2, executor->CreateStream()); + auto allocator2 = GpuCudaMallocAsyncAllocator( + /*platform_device_id*/ tsl::PlatformDeviceId(executor->device_ordinal()), + /*pool_size*/ 2048, + /*new_pool_size*/ true, + /*release_threshold*/ true); + allocator2.SetStreamAndPreallocateMemory( + se::gpu::AsGpuStreamValue(stream2.get())); + void* addr1 = allocator1.AllocateRaw(128, 127); + void* addr2 = allocator2.AllocateRaw(128, 129); + CHECK_EQ((reinterpret_cast(addr1) & 127), 0); + CHECK_EQ((reinterpret_cast(addr2) & 127), 0); + allocator1.DeallocateRaw(addr1); + allocator2.DeallocateRaw(addr2); + EXPECT_TRUE(stream1->ok()); + EXPECT_TRUE(stream2->ok()); +} + TEST(GpuCudaMallocAsyncAllocator, AddressAlignedDefaultPool) { #if CUDA_VERSION < 11030 GTEST_SKIP() << "Cuda async memory allocator is not supported for CUDA " diff --git a/third_party/xla/xla/tsl/framework/serving_device_selector.h b/third_party/xla/xla/tsl/framework/serving_device_selector.h index 827f4a269932b6..7baa9d338dccf6 100644 --- a/third_party/xla/xla/tsl/framework/serving_device_selector.h +++ b/third_party/xla/xla/tsl/framework/serving_device_selector.h @@ -152,6 +152,14 @@ class ServingDeviceSelector { virtual DeviceReservation ReserveDevice( absl::string_view program_fingerprint) = 0; + // Enqueues a program on the given device. Used only for load tracking + // purposes when the device selection feature is unused. + virtual void Enqueue(int32_t device_index, absl::string_view fingerprint) = 0; + + // Marks the completion of a program on the given device. Used only for load + // tracking purposes when the device selection feature is unused. + virtual void Completed(int32_t device_index, bool had_error) = 0; + protected: // A helper function for Enqueue. The EnqueueHelper does the following things. // 1. If there are programs in the scheduled_programs queue of the given diff --git a/third_party/xla/xla/tsl/framework/test_util/mock_serving_device_selector.h b/third_party/xla/xla/tsl/framework/test_util/mock_serving_device_selector.h index 71627bab4c39a0..80add74bbd413e 100644 --- a/third_party/xla/xla/tsl/framework/test_util/mock_serving_device_selector.h +++ b/third_party/xla/xla/tsl/framework/test_util/mock_serving_device_selector.h @@ -29,6 +29,8 @@ class MockServingDeviceSelector : public tsl::ServingDeviceSelector { public: MOCK_METHOD(tsl::DeviceReservation, ReserveDevice, (absl::string_view), (override)); + MOCK_METHOD(void, Enqueue, (int32_t, absl::string_view), (override)); + MOCK_METHOD(void, Completed, (int32_t, bool), (override)); MOCK_METHOD(void, FreeDeviceReservation, (const tsl::DeviceReservation&), (override)); };