diff --git a/.bazelrc b/.bazelrc index d7ae76f096431a..1a47809a8f9753 100644 --- a/.bazelrc +++ b/.bazelrc @@ -251,7 +251,7 @@ build:cuda_clang_official --action_env=TF_CUDA_VERSION="12" build:cuda_clang_official --action_env=TF_CUDNN_VERSION="8" build:cuda_clang_official --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.3" build:cuda_clang_official --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" -build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-17/bin/clang" +build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:cuda_clang_official --crosstool_top="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain" @@ -602,15 +602,12 @@ try-import %workspace%/.bazelrc.user # Build TensorFlow v2. test:release_base --test_size_filters=small,medium -# Ensure release_base is set on linux -build:release_linux_base --config=release_base - -# Target the AVX instruction set -build:release_linux_base --config=avx_linux - # Enable support for all targets build:release_base --config=cpu_cross +# Ensure release_base is set on linux +build:release_linux_base --config=release_base + # Disable clang extension that rejects type definitions within offsetof. # This was added in clang-16 by https://reviews.llvm.org/D133574. # Can be removed once upb is updated, since a type definition is used within @@ -633,8 +630,8 @@ build:release_linux_base --action_env PYTHON_BIN_PATH="/usr/bin/python3" build:release_linux_base --action_env PYTHON_LIB_PATH="/usr/lib/tf_python" build:release_linux_base --python_path="/usr/bin/python3" # Set Clang as compiler. Use the actual path to clang installed in container. -build:release_cpu_linux_base --repo_env=CC="/usr/lib/llvm-17/bin/clang" -build:release_cpu_linux_base --repo_env=BAZEL_COMPILER="/usr/lib/llvm-17/bin/clang" +build:release_cpu_linux_base --repo_env=CC="/usr/lib/llvm-18/bin/clang" +build:release_cpu_linux_base --repo_env=BAZEL_COMPILER="/usr/lib/llvm-18/bin/clang" # Test-related settings below this point. test:release_linux_base --build_tests_only --keep_going --test_output=errors --verbose_failures=true test:release_linux_base --local_test_jobs=HOST_CPUS @@ -645,6 +642,8 @@ test:release_linux_base --test_summary=short # Use the Clang toolchain to compile build:release_cpu_linux --config=release_linux_base build:release_cpu_linux --crosstool_top="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain" +# Target the AVX instruction set +build:release_cpu_linux --config=avx_linux build:release_gpu_linux --config=release_cpu_linux # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. diff --git a/ci/official/containers/linux_arm64/cuda.packages.txt b/ci/official/containers/linux_arm64/cuda.packages.txt index 657793f1825252..91268178785847 100644 --- a/ci/official/containers/linux_arm64/cuda.packages.txt +++ b/ci/official/containers/linux_arm64/cuda.packages.txt @@ -1,6 +1,8 @@ # CuDNN: https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#ubuntu-network-installation libcudnn8=8.9.6.50-1+cuda12.2 libcudnn8-dev=8.9.6.50-1+cuda12.2 +libnvinfer-headers-dev=8.6.1.6-1+cuda12.0 +libnvinfer-headers-plugin-dev=8.6.1.6-1+cuda12.0 # This can be removed once NVIDIA publishes a cuda-12.3.2 Docker image. 
# For now it ensures that we install at least version 12.3.107 of PTXAS, diff --git a/ci/official/containers/linux_arm64/devel.packages.txt b/ci/official/containers/linux_arm64/devel.packages.txt index a8a9cb442c8b0b..61c7a97f1daf0f 100644 --- a/ci/official/containers/linux_arm64/devel.packages.txt +++ b/ci/official/containers/linux_arm64/devel.packages.txt @@ -3,10 +3,10 @@ autoconf automake build-essential ca-certificates -llvm-17 -clang-17 +llvm-18 +clang-18 clang-format-12 -lld-17 +lld-18 colordiff curl ffmpeg @@ -18,7 +18,7 @@ libcurl3-dev libcurl4-openssl-dev libfreetype6-dev libhdf5-serial-dev -libomp-17-dev +libomp-18-dev libssl-dev libtool libxml2-dev @@ -26,8 +26,8 @@ libxslt1-dev libzmq3-dev mlocate moreutils -openjdk-11-jdk -openjdk-11-jre-headless +openjdk-21-jdk +openjdk-21-jre-headless openssl patchelf pkg-config diff --git a/ci/official/containers/linux_arm64/setup.sources.sh b/ci/official/containers/linux_arm64/setup.sources.sh index ea8dc376f67065..f8c87d4ceade60 100755 --- a/ci/official/containers/linux_arm64/setup.sources.sh +++ b/ci/official/containers/linux_arm64/setup.sources.sh @@ -39,7 +39,7 @@ cat >/etc/apt/sources.list.d/custom.list <(builder.getUnknownLoc()); } - tsl::Status GetDumpDir(std::string* dump_dir) { + absl::Status GetDumpDir(std::string* dump_dir) { std::vector files; if (auto status = tsl::Env::Default()->GetChildren(path_, &files); !status.ok()) { @@ -131,7 +131,7 @@ class InitPassManagerTest : public testing::Test { "Expecting directory to have one child."); } *dump_dir = tsl::io::JoinPath(path_, files[0]); - return tsl::OkStatus(); + return absl::OkStatus(); } std::string path_; diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 30a19e04fb368a..8ac81939d0d4de 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -578,6 +578,14 @@ inline bool IsBF16ShapedType(Type t) { return false; } +// Returns true if it is a shaped type of FloatType elements. +inline bool IsFloatShapedType(Type t) { + if (auto shaped_type = t.dyn_cast_or_null()) { + return shaped_type.getElementType().isa(); + } + return false; +} + // Returns new shape with rank 'new_dims' with padded ones on the // left if needed. inline std::vector GetPaddedShape(ArrayRef old_shape, @@ -3069,6 +3077,50 @@ OpFoldResult SquareOp::fold(FoldAdaptor adaptor) { return ConstFoldUnaryOp(result_type, operands[0], compute); } +//===----------------------------------------------------------------------===// +// MaximumOp +//===----------------------------------------------------------------------===// + +OpFoldResult MaximumOp::fold(FoldAdaptor adaptor) { + auto lhs_type = getLhs().getType(); + auto rhs_type = getRhs().getType(); + // Only constant fold for float tensors of the same type is implemented. 
+ if (lhs_type != rhs_type || !IsFloatShapedType(lhs_type)) return nullptr; + + auto lhs = adaptor.getLhs().dyn_cast_or_null(); + auto rhs = adaptor.getRhs().dyn_cast_or_null(); + if (lhs && lhs.isSplat()) { + APFloat lhs_value = lhs.getSplatValue(); + lhs_value.changeSign(); + if (lhs_value.isLargest()) return getRhs(); + } + if (rhs && rhs.isSplat()) { + APFloat rhs_value = rhs.getSplatValue(); + rhs_value.changeSign(); + if (rhs_value.isLargest()) return getLhs(); + } + return nullptr; +} + +//===----------------------------------------------------------------------===// +// MinimumOp +//===----------------------------------------------------------------------===// + +OpFoldResult MinimumOp::fold(FoldAdaptor adaptor) { + auto lhs_type = getLhs().getType(); + auto rhs_type = getRhs().getType(); + // Only constant fold for float tensors of the same type is implemented. + if (lhs_type != rhs_type || !IsFloatShapedType(lhs_type)) return nullptr; + + auto lhs = adaptor.getLhs().dyn_cast_or_null(); + auto rhs = adaptor.getRhs().dyn_cast_or_null(); + if (lhs && lhs.isSplat() && lhs.getSplatValue().isLargest()) + return getRhs(); + if (rhs && rhs.isSplat() && rhs.getSplatValue().isLargest()) + return getLhs(); + return nullptr; +} + //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 481f5573058b8c..5f4cce6d8e8a76 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -2269,6 +2269,8 @@ def TFL_MaximumOp : TFL_Op<"maximum", [ TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$max ); + let hasFolder = 1; + let builders = [TFL_BroadcastableBinaryBuilder]; let hasOptions = 0; @@ -2528,6 +2530,8 @@ def TFL_MinimumOp : TFL_Op<"minimum", [ TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$min ); + let hasFolder = 1; + let builders = [TFL_BroadcastableBinaryBuilder]; let hasOptions = 0; diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 3e50192fa0640d..16e12bbb6da04d 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -124,7 +124,7 @@ Status HandleInputOutputArraysWithModule( ") does not exist in the given graph"); } } - return OkStatus(); + return absl::OkStatus(); } Status ConvertSavedModelToTFLiteFlatBuffer( diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir index ba9c1e58565286..f244d15294c253 100644 --- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir +++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir @@ -181,6 +181,46 @@ func.func @elementwise_unary_ops() -> (tensor, tensor, tensor, te func.return %7, %8, %9, %10, %11, %12, %13 : tensor, tensor, tensor, tensor, tensor, tensor, tensor } +// CHECK-LABEL: @max_with_neg_f32_max_val +// CHECK-SAME: (%[[ARG0:.+]]: tensor) +func.func @max_with_neg_f32_max_val(%arg0 : tensor) -> (tensor, tensor) { + %neg_f32_max = arith.constant dense<-3.40282347E+38> : tensor + %0 = "tfl.maximum"(%arg0, %neg_f32_max) : (tensor, tensor) -> tensor + %1 = "tfl.maximum"(%neg_f32_max, %arg0) : (tensor, tensor) -> tensor + func.return %0, %1 : tensor, tensor + // CHECK: return %[[ARG0]], %[[ARG0]] 
+} + +// CHECK-LABEL: @min_with_f32_max_val +// CHECK-SAME: (%[[ARG0:.+]]: tensor) +func.func @min_with_f32_max_val(%arg0 : tensor) -> (tensor, tensor) { + %f32_max = arith.constant dense<3.40282347E+38> : tensor + %0 = "tfl.minimum"(%arg0, %f32_max) : (tensor, tensor) -> tensor + %1 = "tfl.minimum"(%f32_max, %arg0) : (tensor, tensor) -> tensor + func.return %0, %1 : tensor, tensor + // CHECK: return %[[ARG0]], %[[ARG0]] +} + +// CHECK-LABEL: @max_with_neg_f64_max_val +// CHECK-SAME: (%[[ARG0:.+]]: tensor) +func.func @max_with_neg_f64_max_val(%arg0 : tensor) -> (tensor, tensor) { + %neg_f64_max = arith.constant dense<-1.7976931348623157E+308> : tensor + %0 = "tfl.maximum"(%arg0, %neg_f64_max) : (tensor, tensor) -> tensor + %1 = "tfl.maximum"(%neg_f64_max, %arg0) : (tensor, tensor) -> tensor + func.return %0, %1 : tensor, tensor + // CHECK: return %[[ARG0]], %[[ARG0]] +} + +// CHECK-LABEL: @min_with_f64_max_val +// CHECK-SAME: (%[[ARG0:.+]]: tensor) +func.func @min_with_f64_max_val(%arg0 : tensor) -> (tensor, tensor) { + %f64_max = arith.constant dense<1.7976931348623157E+308> : tensor + %0 = "tfl.minimum"(%arg0, %f64_max) : (tensor, tensor) -> tensor + %1 = "tfl.minimum"(%f64_max, %arg0) : (tensor, tensor) -> tensor + func.return %0, %1 : tensor, tensor + // CHECK: return %[[ARG0]], %[[ARG0]] +} + // CHECK-LABEL: @mul_int func.func @mul_int() -> (tensor, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) { %0 = arith.constant dense<8> : tensor diff --git a/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc b/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc index 2ed12a34059588..e212ce16ee6ccd 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc @@ -219,7 +219,7 @@ class LiftFlexCustomOp : public OpRewritePattern { for (const auto& name_and_value : node_def.attr()) { const std::string& attr_name = name_and_value.first; const tensorflow::AttrValue& attr_value = name_and_value.second; - StatusOr mlir_attr = + absl::StatusOr mlir_attr = tensorflow::ConvertAttributeValue(attr_value, &builder); if (!mlir_attr.ok()) { return emitError(loc, mlir_attr.status().message()); diff --git a/tensorflow/compiler/mlir/quantization/common/BUILD b/tensorflow/compiler/mlir/quantization/common/BUILD index da122b67993af7..5faa358598811c 100644 --- a/tensorflow/compiler/mlir/quantization/common/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/BUILD @@ -145,6 +145,7 @@ cc_library( deps = [ ":uniform_quantized_types", "//tensorflow/compiler/mlir/quantization/common/quantization_lib", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", @@ -155,6 +156,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:protobuf", "@stablehlo//:stablehlo_ops", ], ) diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc index 540eff26685968..e116341eb79f71 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc @@ -34,12 +34,16 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep #include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" +#include "tsl/platform/protobuf.h" namespace mlir::quant { using ::mlir::stablehlo::DotGeneralOp; +using ::stablehlo::quantization::Method; +using ::tsl::protobuf::TextFormat; bool HasStaticShape(Value value) { auto shaped_type = value.getType().dyn_cast(); @@ -174,4 +178,29 @@ std::optional GetDotGeneralQuantizationDim( return filter_rank - 1; } +bool HasWeightOnlyPtqMethod(Operation& op) { + if (auto quantization_method_txtpb = + op.getAttrOfType(kQuantizationMethodAttr)) { + Method method; + if (TextFormat::ParseFromString(quantization_method_txtpb.getValue().str(), + &method)) { + return method.has_weight_only_ptq(); + } + } + return false; +} + +bool IsWeightOnlyQuantizableOp(const Operation& op) { + if (auto call_op = dyn_cast(op)) { + StringRef entry_function_name = GetEntryFunctionName(call_op); + return ContainsConvOrDot(entry_function_name) && + HasWeightOnlyPtqMethod(*call_op); + } + return false; +} + +bool ContainsConvOrDot(StringRef str) { + return str.contains("conv") || str.contains("dot_general"); +} + } // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h index 490a77a3b73ffa..dfbe3c2d45e267 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h @@ -42,6 +42,10 @@ namespace mlir::quant { constexpr char kAttrMapAttribute[] = "attr_map"; +// Name of the string attribute attached to `XlaCallModuleOp`, which is the +// textproto representation of `Method`. +inline constexpr StringRef kQuantizationMethodAttr = "_quantization_method"; + // Permutation from the NHWC tensor format to NCHW. This is an inverse // permutation of `kNchwToNhwcPermutation`. inline constexpr std::array kNhwcToNchwPermutation = {0, 3, 1, 2}; @@ -248,6 +252,16 @@ absl::StatusOr IsDotGeneralFullyConnected( std::optional GetDotGeneralQuantizationDim( ::mlir::stablehlo::DotGeneralOp dot_general_op); +// Checks if the `Method` attatched to the given op has `WeightOnlyPtq`. +bool HasWeightOnlyPtqMethod(Operation& op); + +// Checks if an op is a `tf.XlaCallModule` op, contains 'conv' or 'dot_general' +// in its name and has `Method` with `WeightOnlyPtq`. +bool IsWeightOnlyQuantizableOp(const Operation& op); + +// Checks if a `StringRef` contains 'conv' or 'dot_general'. 
+bool ContainsConvOrDot(StringRef str); + } // namespace mlir::quant #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_ATTRS_AND_CONSTRAINTS_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc index ca0df77f81b51c..041ce43eba20ac 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc @@ -98,6 +98,21 @@ constexpr absl::string_view kModuleXlaCallModule = R"mlir( } )mlir"; +constexpr absl::string_view kModuleDotWeightOnlyPtq = R"mlir( + module { + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor) { + %0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32> + %1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32> + %2 = "tf.XlaCallModule"(%arg0, %1, %0) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", _quantization_method = "weight_only_ptq { }"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor + return %2 : tensor + } + func.func private @composite_dot_general_fn_1(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor, tensor<2x2xf32>) -> tensor + return %0 : tensor + } + } +)mlir"; + constexpr absl::string_view kModuleXlaCallModuleNoEntryNoQuantTrait = R"mlir( module { func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor) { @@ -526,5 +541,142 @@ TEST_F(AttrsAndConstraintsTest, DotGeneralBatchMatmulReturnsNullQuantDim) { EXPECT_THAT(GetDotGeneralQuantizationDim(dot_general_op), Eq(std::nullopt)); } +TEST_F(AttrsAndConstraintsTest, HasWeightOnlyPtqMethodExists) { + OwningOpRef module_op = + ParseModuleOpString(kModuleDotWeightOnlyPtq); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto call_op = *main_fn.getOps().begin(); + EXPECT_TRUE(HasWeightOnlyPtqMethod(*call_op)); +} + +TEST_F(AttrsAndConstraintsTest, HasWeightOnlyPtqMethodDifferentMethod) { + const absl::string_view kModuleDotNoQuantization = R"mlir( + module { + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor) { + %0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32> + %1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32> + %2 = "tf.XlaCallModule"(%arg0, %1, %0) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", _quantization_method = "no_quantization { }"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor + return %2 : tensor + } + func.func private @composite_dot_general_fn_1(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor, tensor<2x2xf32>) -> tensor + return %0 : tensor + } + } + )mlir"; + OwningOpRef module_op = + 
ParseModuleOpString(kModuleDotNoQuantization); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto call_op = *main_fn.getOps().begin(); + EXPECT_FALSE(HasWeightOnlyPtqMethod(*call_op)); +} + +TEST_F(AttrsAndConstraintsTest, HasWeightOnlyPtqMethodNoMethod) { + OwningOpRef module_op = ParseModuleOpString(kModuleXlaCallModule); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto call_op = *main_fn.getOps().begin(); + EXPECT_FALSE(HasWeightOnlyPtqMethod(*call_op)); +} + +TEST_F(AttrsAndConstraintsTest, IsWeightOnlyQuantizableOpDot) { + OwningOpRef module_op = + ParseModuleOpString(kModuleDotWeightOnlyPtq); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto call_op = *main_fn.getOps().begin(); + EXPECT_TRUE(IsWeightOnlyQuantizableOp(*call_op)); +} + +TEST_F(AttrsAndConstraintsTest, IsWeightOnlyQuantizableOpNotTfXlaCallModuleOp) { + const absl::string_view kModulePartitionedCallDot = R"mlir( + module { + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor) { + %0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32> + %1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32> + %2 = "tf.PartitionedCall"(%arg0, %1, %0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_dot_general_fn_1, _quantization_method = "weight_only_ptq { }"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor + return %2 : tensor + } + func.func private @composite_dot_general_fn_1(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor, tensor<2x2xf32>) -> tensor + return %0 : tensor + } + } + )mlir"; + OwningOpRef module_op = + ParseModuleOpString(kModulePartitionedCallDot); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto call_op = *main_fn.getOps().begin(); + EXPECT_FALSE(IsWeightOnlyQuantizableOp(*call_op)); +} + +TEST_F(AttrsAndConstraintsTest, IsWeightOnlyQuantizableOpNoConvNoDot) { + constexpr absl::string_view kModuleXlaCallModule = R"mlir( + module { + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor) { + %0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32> + %1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32> + %2 = "tf.XlaCallModule"(%arg0, %1, %0) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_fn_1, _original_entry_function = "composite_fn_1", _tfl_quant_trait = "fully_quantizable", _quantization_method = "weight_only_ptq { }"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor + return %2 : tensor + } + func.func private @composite_fn_1(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module, tf_quant.composite_function} { + return %arg0 : tensor + } + } + )mlir"; + OwningOpRef module_op = ParseModuleOpString(kModuleXlaCallModule); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto call_op = *main_fn.getOps().begin(); + 
EXPECT_FALSE(IsWeightOnlyQuantizableOp(*call_op)); +} + +TEST_F(AttrsAndConstraintsTest, ContainsConvOrDotTrue) { + OwningOpRef module_op = + ParseModuleOpString(kModuleDotWeightOnlyPtq); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto call_op = *main_fn.getOps().begin(); + const StringRef function_name = GetEntryFunctionName(call_op); + EXPECT_TRUE(ContainsConvOrDot(function_name)); +} + +TEST_F(AttrsAndConstraintsTest, ContainsConvOrDotFalse) { + OwningOpRef module_op = + ParseModuleOpString(kModuleXlaCallModuleNoEntryNoQuantTrait); + ASSERT_TRUE(module_op); + + func::FuncOp main_fn = FindMainFuncOp(*module_op); + ASSERT_THAT(main_fn, NotNull()); + + auto call_op = *main_fn.getOps().begin(); + const StringRef function_name = GetEntryFunctionName(call_op); + EXPECT_FALSE(ContainsConvOrDot(function_name)); +} + } // namespace } // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h index bfef9a13df1a01..655eb3103eab54 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h @@ -43,10 +43,6 @@ constexpr StringRef kCompositeFuncPrefix = "composite_"; inline constexpr StringRef kOriginalStablehloEntryFunctionAttrName = "_original_entry_function"; -// Name of the string attribute attached to `XlaCallModuleOp`, which is the -// textproto representation of `Method`. -inline constexpr StringRef kQuantizationMethodAttr = "_quantization_method"; - // FunctionCallOpType to be generated as the function call operator when // function lifting will happen. enum FunctionCallOpType { TFPartitionedCallOp = 0, TFXlaCallModuleOp = 1 }; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc index 622ff502c01ed9..ea20a8875ded5a 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc @@ -64,7 +64,6 @@ void AddPostCalibrationPasses(OpPassManager& pm, options.enable_per_channel_quantized_weight_ = true; // For debugging purposes. options.mlir_dump_file_name_ = "quantize_composite_functions"; - options.enable_weight_only_ = false; options.merge_fusion_with_dequantize_ = pipeline_config.merge_fusion_with_dequantize(); @@ -101,7 +100,6 @@ void AddWeightOnlyQuantizationPasses( QuantizeCompositeFunctionsPassOptions options; // For debugging purposes. options.mlir_dump_file_name_ = "quantize_composite_functions"; - options.enable_weight_only_ = true; pm.addPass(createQuantizeCompositeFunctionsPass(options)); // Add an inliner pass to inline quantized StableHLO functions. 
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc index 6f95f578eacfa3..001ece707cfe90 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc @@ -40,7 +40,7 @@ absl::StatusOr PostCalibrationComponent::Run( ModuleOp module_op, const QuantizationConfig& config) { TF_RETURN_IF_ERROR(RunPasses( kName, /*add_passes_func=*/ - [&config, this](PassManager& pm) { + [&config](PassManager& pm) { AddPostCalibrationPasses(pm, config.pipeline_config(), config.specs()); }, *ctx_, module_op)); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc index a423bdc5f80142..6143b21eec32cd 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc @@ -38,7 +38,7 @@ absl::StatusOr PreCalibrationComponent::Run( ModuleOp module_op, const QuantizationConfig& config) { TF_RETURN_IF_ERROR(RunPasses( kName, /*add_passes_func=*/ - [&config, this](PassManager& pm) { + [&config](PassManager& pm) { AddPreCalibrationPasses(pm, config.calibration_options(), config.specs(), config.debugger_config()); }, diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc index 9fb1e9e985d15e..d785cd5bf4d970 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc @@ -97,23 +97,13 @@ class InsertWeightParamPattern return false; } Operation* user = operand.getOwner(); - if (isa(user)) { - auto call_op = cast(user); - const StringRef function_name = GetEntryFunctionName(call_op); - const bool is_conv_or_dot = function_name.contains("conv") || - function_name.contains("dot_general"); - const bool has_quant_trait = HasQuantizableTrait(call_op); - return is_conv_or_dot && has_quant_trait; - } - return false; + return IsWeightOnlyQuantizableOp(*user); } void rewrite(Operation* op, PatternRewriter& rewriter) const override { Operation* quantizable_op = *op->getUsers().begin(); DenseFPElementsAttr attr; - if (!matchPattern(op->getResult(0), m_Constant(&attr))) { - return; - } + matchPattern(op->getResult(0), m_Constant(&attr)); auto quant_type = quant::GetUniformQuantizedTypeForWeight( attr, /*symmetric=*/false, /*num_bits=*/8, /*is_signed=*/true, diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td index fdb7fa7941f025..1ca1738566948a 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td @@ -63,10 +63,6 @@ def QuantizeCompositeFunctionsPass : Pass<"stablehlo-quantize-composite-function Option<"mlir_dump_file_name_", "mlir-dump-file-name", "std::optional", /*default=*/"std::nullopt", "MLIR dump file name.">, - Option<"enable_weight_only_", - "enable-weight-only", - "bool", /*default=*/"false", - "Whether to produce weight-only quantized op for convolution and dot_general op.">, Option<"merge_fusion_with_dequantize_", "merge-fusion-with-dequantize", "bool", /*default=*/"false", @@ -106,10 +102,6 @@ def QuantizePass : 
Pass<"stablehlo-quantize", "mlir::ModuleOp"> { "enable-per-channel-quantized-weight", "bool", /*default=*/"true", "Whether to enable per-channel quantized weights.">, - Option<"enable_weight_only_", - "enable-weight-only", - "bool", /*default=*/"false", - "Whether to produce weight-only quantized op for convolution and dot_general op.">, ]; let dependentDialects = [ "mlir::stablehlo::StablehloDialect", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc index 5578aede7dee3c..dd4ae2ca9ba7b2 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc @@ -269,7 +269,8 @@ class EntryFuncBodyQuantizationPattern { // Returns `success()` if `entry_func_op`'s body is eligible for rewriting. At // this point `entry_func_op`'s signature has not been reset with quantized // types. - virtual LogicalResult match(func::FuncOp entry_func_op) const = 0; + virtual LogicalResult match(func::FuncOp entry_func_op, + const Method& quantization_method) const = 0; // Rewrites the `entry_func_op`'s body. virtual void rewrite(func::FuncOp entry_func_op, @@ -408,19 +409,20 @@ void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter, class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { public: explicit QuantizeDotGeneralOpPattern( - const bool enable_per_channel_quantized_weight, - const bool enable_weight_only) + const bool enable_per_channel_quantized_weight) : enable_per_channel_quantized_weight_( - enable_per_channel_quantized_weight), - enable_weight_only_(enable_weight_only) {} + enable_per_channel_quantized_weight) {} - LogicalResult match(func::FuncOp entry_func_op) const override { + LogicalResult match(func::FuncOp entry_func_op, + const Method& quantization_method) const override { + if (!quantization_method.has_static_range_ptq()) { + return failure(); + } return MatchGemmStyleOp(entry_func_op); } void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, PatternRewriter& rewriter) const override { - if (enable_weight_only_) return; DotGeneralOp dot_general_op = *entry_func_op.getOps().begin(); const bool should_quantize_per_channel = enable_per_channel_quantized_weight_ && @@ -433,28 +435,26 @@ class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { [[deprecated( "Do not rely on this field for per-channel quantization. Use `Method` " "instead.")]] const bool enable_per_channel_quantized_weight_; - // TODO: b/331510853 - Deprecate boolean flag and use `Method` to perform - // weight-only quantization. - const bool enable_weight_only_; }; // Quantizes the entry function's body containing a `ConvolutionOp`. 
class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { public: explicit QuantizeConvolutionOpPattern( - const bool enable_per_channel_quantized_weight, - const bool enable_weight_only) + const bool enable_per_channel_quantized_weight) : enable_per_channel_quantized_weight_( - enable_per_channel_quantized_weight), - enable_weight_only_(enable_weight_only) {} + enable_per_channel_quantized_weight) {} - LogicalResult match(func::FuncOp entry_func_op) const override { + LogicalResult match(func::FuncOp entry_func_op, + const Method& quantization_method) const override { + if (!quantization_method.has_static_range_ptq()) { + return failure(); + } return MatchGemmStyleOp(entry_func_op); } void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, PatternRewriter& rewriter) const override { - if (enable_weight_only_) return; RewriteGemmStyleOp( entry_func_op, rewriter, enable_per_channel_quantized_weight_ && @@ -482,19 +482,42 @@ class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { [[deprecated( "Do not rely on this field for per-channel quantization. Use `Method` " "instead.")]] const bool enable_per_channel_quantized_weight_; - // TODO: b/331510853 - Deprecate boolean flag and use `Method` to perform - // weight-only quantization. - const bool enable_weight_only_; +}; + +// Quantizes the entry function's body for weight-only quantized op. +template +class QuantizeWeightOnlyOpPattern : public EntryFuncBodyQuantizationPattern { + public: + explicit QuantizeWeightOnlyOpPattern( + const bool enable_per_channel_quantized_weight) + : enable_per_channel_quantized_weight_( + enable_per_channel_quantized_weight) {} + + LogicalResult match(func::FuncOp entry_func_op, + const Method& quantization_method) const override { + if (!quantization_method.has_weight_only_ptq()) { + return failure(); + } + return MatchGemmStyleOp(entry_func_op); + } + + void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, + PatternRewriter& rewriter) const override {} + + private: + [[deprecated( + "Do not rely on this field for per-channel quantization. Use `Method` " + "instead.")]] const bool enable_per_channel_quantized_weight_; }; template class QuantizeSingularOpPattern : public EntryFuncBodyQuantizationPattern { public: explicit QuantizeSingularOpPattern( - const bool enable_per_channel_quantized_weight, - const bool enable_weight_only) {} + const bool enable_per_channel_quantized_weight) {} - LogicalResult match(func::FuncOp entry_func_op) const override { + LogicalResult match(func::FuncOp entry_func_op, + const Method& quantization_method) const override { const auto op_iterator_range = entry_func_op.getOps(); if (op_iterator_range.empty()) { LLVM_DEBUG(llvm::dbgs() << "Function does not have " @@ -606,12 +629,10 @@ template { public: explicit XlaCallModuleOpToCallOp( - MLIRContext& ctx, const bool enable_per_channel_quantized_weight, - const bool enable_weight_only) + MLIRContext& ctx, const bool enable_per_channel_quantized_weight) : OpRewritePattern(&ctx), enable_per_channel_quantized_weight_( - enable_per_channel_quantized_weight), - enable_weight_only_(enable_weight_only) {} + enable_per_channel_quantized_weight) {} LogicalResult match(TF::XlaCallModuleOp op) const override { ModuleOp module_op = op->getParentOfType(); @@ -625,7 +646,7 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern { if (!IsQuantizedXlaCallModuleOp(op)) return failure(); // For weight-only quantization, op should be hybrid quantized. 
- if (enable_weight_only_ && !IsHybridQuantizedOp(op)) { + if (HasWeightOnlyPtqMethod(*op) && !IsHybridQuantizedOp(op)) { return failure(); } @@ -634,10 +655,9 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern { op->emitError("Failed to find a valid entry function."); return failure(); } - - return FuncBodyRewritePatternT(enable_per_channel_quantized_weight_, - enable_weight_only_) - .match(entry_func_op); + Method quantization_method = GetQuantizationMethodOrDefault(op); + return FuncBodyRewritePatternT(enable_per_channel_quantized_weight_) + .match(entry_func_op, quantization_method); } void rewrite(TF::XlaCallModuleOp xla_call_module_op, @@ -650,8 +670,7 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern { ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( *rewriter.getContext(), rewriter, xla_call_module_op, - FuncBodyRewritePatternT(enable_per_channel_quantized_weight_, - enable_weight_only_), + FuncBodyRewritePatternT(enable_per_channel_quantized_weight_), quantization_method); } @@ -659,9 +678,6 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern { [[deprecated( "Do not rely on this field for per-channel quantization. Use `Method` " "instead.")]] const bool enable_per_channel_quantized_weight_; - // TODO: b/331510853 - Deprecate boolean flag and use `Method` to perform - // weight-only quantization. - const bool enable_weight_only_; }; // Quantizes op with regions such as stablehlo.reduce_window op. @@ -937,20 +953,6 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) { return false; } -template -class QuantizeWeightOnlyOpPattern : public EntryFuncBodyQuantizationPattern { - public: - explicit QuantizeWeightOnlyOpPattern( - const bool enable_per_channel_quantized_weight) {} - - LogicalResult match(func::FuncOp entry_func_op) const override { - return MatchGemmStyleOp(entry_func_op); - } - - void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, - PatternRewriter& rewriter) const override {} -}; - // Compute heavy patterns should be quantized for both server and ODML targets. // Most patterns here are useful when quantized since they are compute heavy // or memory bound. @@ -958,13 +960,18 @@ void PopulateCommonQuantizationPatterns( MLIRContext& ctx, RewritePatternSet& patterns, const bool enable_per_channel_quantized_weight) { patterns.add>( - ctx, enable_per_channel_quantized_weight, /*enable_weight_only=*/false); + ctx, enable_per_channel_quantized_weight); patterns.add>( - ctx, enable_per_channel_quantized_weight, /*enable_weight_only=*/false); + ctx, enable_per_channel_quantized_weight); + patterns + .add>>( + ctx, enable_per_channel_quantized_weight); + patterns + .add>>( + ctx, enable_per_channel_quantized_weight); // TODO: b/307620772 - Per-channel quantization for gather. patterns.add>>( - ctx, /*enable_per_channel_quantized_weight=*/false, - /*enable_weight_only=*/false); + ctx, /*enable_per_channel_quantized_weight=*/false); // Populate pattern for quantization of ops with regions such as // `stablehlo.reduce_window` op. 
patterns.add(ctx); @@ -973,16 +980,7 @@ void PopulateCommonQuantizationPatterns( void PopulateAllQuantizablePatterns(MLIRContext& ctx, RewritePatternSet& patterns) { patterns.add>>( - ctx, /*enable_per_channel_quantized_weight=*/false, - /*enable_weight_only=*/false); -} - -void PopulateQuantizeWeightOnlyPatterns(MLIRContext& ctx, - RewritePatternSet& patterns) { - patterns.add, - XlaCallModuleOpToCallOp>( - ctx, /*enable_per_channel_quantized_weight*/ false, - /*enable_weight_only=*/true); + ctx, /*enable_per_channel_quantized_weight=*/false); } } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h index 67eb267c1d9037..b8ebe592c41f21 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h" #include "tensorflow/core/framework/types.pb.h" @@ -59,18 +60,8 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op); // quantization parameters are annotated by the QuantizeOp/DequantizeOp pairs. // Each matched pattern are rewritten by its quantized alternatives. // -// The concrete pattern, extends from this base pattern, can specify whether it -// allows weight-only quantization. If it is allowed, for operand/result that is -// not adjacent to dequantize/quantize op, it remains as float. For -// operand/result that is adjacent to dequantize/quantize, it is quantized. -// Weight-only quantization can be used to generate both weight-only -// quantization and dynamic range quantization. The condition for allowing -// weight-only quantization or not for an op can be specified in the below -// function: -// -// static bool AllowWeightOnlyQuantization(Operation& op) -// -// This is a templatized `OpRewritePattern`. +// Quantization method is determined by the `_quantization_method` attributes +// attached to each quantizable units. // // Template constraints are imposed as follows: // @@ -159,6 +150,9 @@ class StableHloQuantizationPattern : public OpRewritePattern { return failure(); } + const bool weight_only_quantizable = + IsWeightOnlyQuantizableOp(*candidate_op); + // Collect all the quantized inputs and "clone" the matched op by these // inputs. SmallVector inputs; @@ -178,8 +172,7 @@ class StableHloQuantizationPattern : public OpRewritePattern { // If the operand is an integer tensor, then it doesn't require the // DequantizeOp in the pattern. inputs.push_back(operand); - } else if (static_cast(this) - ->AllowWeightOnlyQuantization(*candidate_op)) { + } else if (weight_only_quantizable) { inputs.push_back(operand); } else { return failure(); @@ -215,8 +208,7 @@ class StableHloQuantizationPattern : public OpRewritePattern { // D op in the pattern. 
outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result.getType()); - } else if (static_cast(this) - ->AllowWeightOnlyQuantization(*candidate_op)) { + } else if (weight_only_quantizable) { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result.getType()); } else { @@ -260,10 +252,6 @@ void PopulateCommonQuantizationPatterns( void PopulateAllQuantizablePatterns(MLIRContext& ctx, RewritePatternSet& patterns); -// Populates pattern weight-only quantization. -void PopulateQuantizeWeightOnlyPatterns(MLIRContext& ctx, - RewritePatternSet& patterns); - } // namespace mlir::quant::stablehlo #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERNS_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc index 0000057402886f..86dbae8e4181f9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc @@ -77,35 +77,14 @@ struct StableHloQuantizationReverse quantfork::QuantizeCastOp>(ctx) {} }; -bool IsHybridQuantizableOp(Operation& op) { - auto call_op = cast(op); - if (call_op == nullptr) return false; - StringRef entry_function_name = GetEntryFunctionName(call_op); - return entry_function_name.contains("conv") || - entry_function_name.contains("dot_general"); -} - -// Quantization rewrite pattern using DQ as the root op. -struct StableHloQuantizationWeightOnly - : public StableHloQuantizationBase { - explicit StableHloQuantizationWeightOnly(MLIRContext* ctx) - : StableHloQuantizationBase(ctx) {} - - static bool AllowWeightOnlyQuantization(Operation& op) { - return IsHybridQuantizableOp(op); - } -}; - class QuantizePass : public impl::QuantizePassBase { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizePass) using impl::QuantizePassBase::QuantizePassBase; - explicit QuantizePass(const bool enable_per_channel_quantized_weight, - const bool enable_weight_only) { + explicit QuantizePass(const bool enable_per_channel_quantized_weight) { enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; - enable_weight_only_ = enable_weight_only; } private: @@ -118,10 +97,6 @@ void QuantizePass::runOnOperation() { RewritePatternSet patterns(&ctx); patterns.add(&ctx); - if (enable_weight_only_) { - patterns.add(&ctx); - PopulateQuantizeWeightOnlyPatterns(ctx, patterns); - } PopulateCommonQuantizationPatterns(ctx, patterns, enable_per_channel_quantized_weight_); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc index 1efc5d40c7ce20..0328b02c68c609 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc @@ -55,10 +55,8 @@ class QuantizeCompositeFunctionsPass QuantizeCompositeFunctionsPass>::QuantizeCompositeFunctionsPassBase; explicit QuantizeCompositeFunctionsPass( - const bool enable_per_channel_quantized_weight, - const bool enable_weight_only) { + const bool enable_per_channel_quantized_weight) { enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; - enable_weight_only_ = enable_weight_only; } private: @@ -80,9 +78,10 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { // Change this to user-given 
bit width once we have custom configuration. options.bit_width_ = 8; - if (enable_weight_only_) { - pm.addNestedPass(createInsertWeightParamPass()); - } + // Insert quantization parameters for weights for ops with `weight_only_ptq` + // attribute. + pm.addNestedPass(createInsertWeightParamPass()); + // PrepareQuantizePass uses SymbolTable to fetch relevant GEMM ops for // determining quantization attributes. This requires module-level context. pm.addPass(createPrepareQuantizePass(options)); @@ -90,7 +89,7 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { QuantizePassOptions quantize_options; quantize_options.enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight_; - quantize_options.enable_weight_only_ = enable_weight_only_; + // QuantizePass modifies FuncOps referenced outside of its given scope // and therefore requires a module-level context. pm.addPass(createQuantizePass(quantize_options)); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_weight_param.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_weight_param.mlir index 89ff96efecf471..a8b694b41b8cc5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_weight_param.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/insert_weight_param.mlir @@ -1,14 +1,14 @@ // RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-insert-weight-param | FileCheck %s // Test that q/dq pair is inserted between constant and XlaCallModule op -// with quantizable trait and function name containing conv. +// with `weight_only_ptq` method and function name containing conv. func.func @qdq_for_conv_weight(%arg0: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> attributes {tf._original_func_name = "main_0"} { %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> %0 = "tf.XlaCallModule"(%arg0, %cst) { Sout = [#tf_type.shape<1x2x2x2>], _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", - _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + _stablehlo_module_attrs = {}, _quantization_method = "weight_only_ptq { }", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64 @@ -21,20 +21,20 @@ func.func @qdq_for_conv_weight(%arg0: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32 // CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> // CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> // CHECK: %[[DQ:.+]] = "quantfork.dcast"(%[[Q]]) : (tensor<2x3x3x2x!quant.uniform>) -> tensor<2x3x3x2xf32> -// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) <{Sout = [#tf_type.shape<1x2x2x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> +// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) <{Sout = [#tf_type.shape<1x2x2x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", 
_quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, device = ""} : (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> // CHECK: return %[[CALL]] : tensor<1x2x2x2xf32> // ----- // Test that q/dq pair is inserted between constant and XlaCallModule op -// with quantizable trait and function name containing dot_general. +// with `weight_only_ptq` method and function name containing dot_general. func.func @qdq_for_dot_general_weight(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32> %0 = "tf.XlaCallModule"(%arg0, %cst) { Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", - _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64 @@ -47,7 +47,7 @@ func.func @qdq_for_dot_general_weight(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> // CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3xf32>}> : () -> tensor<2x3xf32> // CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform> // CHECK: %[[DQ:.+]] = "quantfork.dcast"(%[[Q]]) : (tensor<2x3x!quant.uniform>) -> tensor<2x3xf32> -// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> // CHECK: return %[[CALL]] : tensor<1x3xf32> // ----- @@ -59,7 +59,7 @@ func.func @no_qdq_except_conv_and_dot_general(%arg0: tensor<2x3x2xi64>) -> tenso %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<3x4x2xf32>} : () -> tensor<3x4x2xf32> %0 = "tf.XlaCallModule"(%cst, %arg0) { Sout = [#tf_type.shape<1x3>], _entry_function = @composite_gather_fn, - _original_entry_function = "composite_gather_fn", + _original_entry_function = "composite_gather_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64 @@ -81,7 +81,7 @@ func.func @no_qdq_for_non_weight_constant(%arg0: tensor<1x2xf32>, %arg1: tensor< %0 = "tf.XlaCallModule"(%arg0, %arg1, %cst) { Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_bias_fn, _original_entry_function = "composite_dot_general_with_bias_fn", - _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + _stablehlo_module_attrs = {}, _quantization_method = 
"weight_only_ptq { }", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64 @@ -96,7 +96,7 @@ func.func @no_qdq_for_non_weight_constant(%arg0: tensor<1x2xf32>, %arg1: tensor< // ----- // Test that q/dq pair is not inserted between constant and XlaCallModule op -// without quantizable trait. +// without `weight_only_ptq` method. func.func @no_qdq_for_not_quantizable_call(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32> @@ -116,6 +116,27 @@ func.func @no_qdq_for_not_quantizable_call(%arg0: tensor<1x2xf32>) -> tensor<1x3 // ----- +// Test that q/dq pair is not inserted between constant and XlaCallModule op +// with different method. + +func.func @no_qdq_for_not_quantizable_call(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _stablehlo_module_attrs = {}, device = "", dim_args_spec = [], + disabled_checks = [], has_token_input_output = false, module = "", + platforms = [], _quantization_method = "static_range_ptq { }", version = 5 : i64 + } : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> +} + +// CHECK-LABEL: func.func @no_qdq_for_not_quantizable_call +// CHECK-NOT: quantfork.qcast +// CHECK-NOT: quantfork.dcast + +// ----- + // Test that q/dq pair is not inserted when constant has multiple users. 
func.func @no_qdq_for_multiple_users(%arg0: tensor<2x2xf32>) -> tensor<2x3xf32> attributes {tf._original_func_name = "main_0"} { @@ -123,7 +144,7 @@ func.func @no_qdq_for_multiple_users(%arg0: tensor<2x2xf32>) -> tensor<2x3xf32> %0 = "tf.XlaCallModule"(%arg0, %cst) { Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", - _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + _stablehlo_module_attrs = {}, _quantization_method = "weight_only_ptq { }", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64 diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize.mlir index 9f2e371fb4d1d6..79f44e10b03e46 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize.mlir @@ -40,7 +40,7 @@ module attributes {tf_saved_model.semantics} { // CHECK-LABEL: quantize_simple_xla_call_module_no_operand func.func private @quantize_simple_xla_call_module_no_operand() -> tensor<1x3xf32> { - %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> %1 = "quantfork.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> %2 = "quantfork.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> return %2 : tensor<1x3xf32> @@ -63,7 +63,7 @@ module attributes {tf_saved_model.semantics} { %4 = "quantfork.dcast"(%3) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> // expected-error @+2 {{Failed to find a valid entry function}} // expected-error @+1 {{'tf.XlaCallModule' op operand #0 must be variadic of tensor of tf.dtype values}} - %5 = "tf.XlaCallModule"(%4, %2) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %5 = "tf.XlaCallModule"(%4, %2) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> %6 = 
"quantfork.qcast"(%5) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> %7 = "quantfork.dcast"(%6) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> return %7 : tensor<1x3xf32> diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir index 15330b0b79b800..a6f0111d2c8293 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir @@ -1,4 +1,4 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-quantize=enable-weight-only=true | FileCheck %s +// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-quantize | FileCheck %s // Test that hybrid quantized dot_general is produced when q/dq pair only exists // for weight. @@ -41,7 +41,7 @@ module attributes {tf_saved_model.semantics} { %cst = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> %0 = "quantfork.qcast"(%cst) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> %1 = "quantfork.dcast"(%0) : (tensor<2x3x3x2x!quant.uniform>) -> tensor<2x3x3x2xf32> - %2 = "tf.XlaCallModule"(%arg0, %1) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %2 = "tf.XlaCallModule"(%arg0, %1) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> return %2 : tensor<1x3x4x2xf32> } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir index dbe192bbb55cde..a614ee0af36adc 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir @@ -1,5 +1,5 @@ // RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ -// RUN: -stablehlo-quantize-composite-functions=enable-weight-only=true | FileCheck --check-prefix=CHECK %s +// RUN: -stablehlo-quantize-composite-functions | FileCheck --check-prefix=CHECK %s // Test that weight-only quantized dot_general op is produced when // enable-weight-only is set to true. 
@@ -37,7 +37,7 @@ module attributes {tf_saved_model.semantics} { module attributes {tf_saved_model.semantics} { func.func private @quantize_conv_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { %0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> - %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> return %1 : tensor<1x3x4x2xf32> } diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD index e848e7f99fb3bf..9641e092815b58 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD @@ -134,7 +134,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_pass_inc_gen", "//tensorflow/core:framework", - "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", @@ -394,22 +393,6 @@ cc_library( ], ) -tf_cc_test( - name = "tpu_cluster_formation_test", - srcs = ["tpu_cluster_formation_test.cc"], - deps = [ - ":clustering_passes", - "//tensorflow/compiler/mlir/tf2xla/transforms:test_utils", - "//tensorflow/core/lib/monitoring:cell_reader", - "@com_google_googletest//:gtest_main", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "lowering_passes", hdrs = [ diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc index 637369ed4fb6fc..b600c865661d58 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc @@ -59,7 +59,6 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/string_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" -#include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { @@ -96,8 +95,6 @@ constexpr llvm::StringRef kNoReplicationCluster = "__no_replication_cluster"; constexpr llvm::StringRef kBadReplicateInfoAttrMsg = "requires '_replication_info' string attribute"; -constexpr char kUseMlirBridge[] = "kUseMlirBridge"; - // Mapping for `_replication_info` attribute to TPUReplicateMetadata attributes. using MetadataMap = llvm::SmallDenseMap; @@ -108,15 +105,6 @@ using OpSetVector = llvm::SmallSetVector; // Mapping for `_replication_info` attribute to ops of a cluster. using ClusterMap = llvm::SmallDenseMap; -auto* jit_compile_single_core_tpu_count = - tensorflow::monitoring::Counter<1>::New( - /* metric name */ - "/tensorflow/core/jit_compile_single_core_tpu_count", - /* metric description */ - "Tracks if single core tpu support goes through the first " - "phase of the MLIR bridge", - /* metric field */ "use_mlir_bridge"); - #define GEN_PASS_DEF_TPUCLUSTERFORMATIONPASS #include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" @@ -943,7 +931,7 @@ void SetNoReplicationClusterAttrs(mlir::tf_device::ClusterOp cluster, LogicalResult FormClustersInBlock( Block* block, const mlir::TF::SideEffectAnalysis::Info& side_effect_analysis, - bool strict_clusters, bool& has_replication_in_module) { + bool strict_clusters) { MetadataMap metadata_map; LogicalResult result = CollectMetadata(block, &metadata_map); if (failed(result)) return result; @@ -956,8 +944,7 @@ LogicalResult FormClustersInBlock( if (!llvm::hasSingleElement(region)) return op.emitOpError("Expected single block region"); if (failed(FormClustersInBlock(®ion.front(), side_effect_analysis, - strict_clusters, - has_replication_in_module))) + strict_clusters))) return mlir::failure(); } } @@ -998,7 +985,6 @@ LogicalResult FormClustersInBlock( block, cluster_ops, results, cluster_successor_ops.getArrayRef()); if (!has_replication) { - has_replication_in_module = false; SetNoReplicationClusterAttrs(cluster, device_type, device); continue; } @@ -1034,12 +1020,12 @@ LogicalResult FormClustersInBlock( LogicalResult FormClustersInFunction( mlir::func::FuncOp func, const mlir::TF::SideEffectAnalysis::Info& side_effect_analysis, - bool strict_clusters, bool& has_replication_in_module) { + bool strict_clusters) { if (!llvm::hasSingleElement(func)) return func.emitOpError("Expecting a single block function"); if (failed(FormClustersInBlock(&func.front(), side_effect_analysis, - strict_clusters, has_replication_in_module))) + strict_clusters))) return mlir::failure(); // Remove TPUReplicatedInput and TPUReplicatedOutput nodes. 
@@ -1091,17 +1077,12 @@ void TPUClusterFormationPass::runOnOperation() { }); auto& side_effect_analysis = getAnalysis(); - bool has_replication_in_module = true; for (auto func : getOperation().getOps()) if (!func.isExternal() && failed(FormClustersInFunction( func, side_effect_analysis.GetAnalysisForFunc(func), - strict_clusters_, has_replication_in_module))) + strict_clusters_))) return signalPassFailure(); - - if (!has_replication_in_module) { - jit_compile_single_core_tpu_count->GetCell(kUseMlirBridge)->IncrementBy(1); - } } } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation_test.cc deleted file mode 100644 index 640385f0156aae..00000000000000 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation_test.cc +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OwningOpRef.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h" -#include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h" -#include "tensorflow/core/lib/monitoring/cell_reader.h" -#include "tsl/platform/statusor.h" - -namespace tensorflow { -namespace tf2xla { -namespace internal { - -namespace { - -constexpr char kJitCompileSingleCoreTpuCount[] = - "/tensorflow/core/jit_compile_single_core_tpu_count"; -constexpr char kUseMlirBridge[] = "kUseMlirBridge"; -using mlir::mhlo::test::GetMlirModuleFromString; - -class TPUClusterFormationPassTest : public testing::Test { - protected: - void CreateModule(const char* module_string) { - TF_ASSERT_OK_AND_ASSIGN(module_, - GetMlirModuleFromString(module_string, &context_)); - bool strict_clusters = true; - pm_ = std::make_unique(&context_); - pm_->addPass(tensorflow::tf2xla::internal::CreateTPUClusterFormationPass( - strict_clusters)); - } - - mlir::LogicalResult Run() { return pm_->run(module_.get()); } - - private: - mlir::MLIRContext context_; - mlir::OwningOpRef module_; - std::unique_ptr pm_; -}; - -TEST_F(TPUClusterFormationPassTest, NonReplicatedTPU) { - monitoring::testing::CellReader feature_metric_reader( - kJitCompileSingleCoreTpuCount); - static constexpr char kMlirModuleStr[] = R"( - module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { - func.func @valid_compilation_cluster_no_replication() { - "tf.opA"() { _xla_compile_device_type = "TPU", is_stateless = true} : () -> () - "tf.opB"() { _xla_compile_device_type = "TPU", is_stateless = true} : () -> () - func.return - } - })"; - 
CreateModule(kMlirModuleStr); - auto result = Run(); - EXPECT_TRUE(result.succeeded()); - EXPECT_EQ(feature_metric_reader.Delta(kUseMlirBridge), 1); -} - -TEST_F(TPUClusterFormationPassTest, ReplicatedTPU) { - monitoring::testing::CellReader feature_metric_reader( - kJitCompileSingleCoreTpuCount); - static constexpr char kMlirModuleStr[] = R"( - module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { - func.func @interleaved_clusters(%arg0 : tensor) -> (tensor, tensor) { - "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate_1", device = "device_1", num_replicas = 1, topology = "topology_1"} : () -> () - %0 = "tf.opA"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "replicate_0", is_stateless = true} : (tensor) -> tensor - %1 = "tf.opB"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "replicate_1", is_stateless = true} : (tensor) -> tensor - %2 = "tf.opC"(%0) {_xla_compile_device_type = "TPU", _replication_info = "replicate_0", is_stateless = true} : (tensor) -> tensor - %3 = "tf.opD"(%1) {_xla_compile_device_type = "TPU", _replication_info = "replicate_1", is_stateless = true} : (tensor) -> tensor - "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate_0", device = "device_0", num_replicas = 1, topology = "topology_0"} : () -> () - func.return %2, %3 : tensor, tensor - } - })"; - CreateModule(kMlirModuleStr); - auto result = Run(); - EXPECT_TRUE(result.succeeded()); - EXPECT_EQ(feature_metric_reader.Delta(kUseMlirBridge), 0); -} - -} // namespace -} // namespace internal -} // namespace tf2xla -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index ba36284dc246c9..31b6aa272faf1d 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -278,6 +278,36 @@ cc_library( ], ) +cc_library( + name = "function", + srcs = [ + "function/function.cc", + ], + hdrs = [ + "function/function.h", + ], + deps = [ + ":tf_to_tfrt", + ":tfrt_compile_options", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:import_model", + "//tensorflow/compiler/mlir/tensorflow:translate_lib", + "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@tf_runtime//:bef", + "@tf_runtime//:core_runtime", + "@tf_runtime//:hostcontext", + "@tf_runtime//:mlirtobef", + "@tf_runtime//:tensor", + ], +) + cc_library( name = "saved_model", srcs = [ @@ -326,6 +356,7 @@ cc_library( ], deps = [ ":backend_compiler", + ":function", ":tf_to_tfrt", ":tfrt_compile_options", ":tfrt_pipeline_options", diff --git a/tensorflow/compiler/mlir/tfrt/function/function.cc b/tensorflow/compiler/mlir/tfrt/function/function.cc new file mode 100644 index 00000000000000..42b7ff2b38982a --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/function/function.cc @@ -0,0 +1,99 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfrt/function/function.h" + +#include "absl/strings/match.h" +#include "absl/strings/str_split.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" +#include "tfrt/bef_converter/mlir_to_bef.h" // from @tf_runtime +#include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime +#include "tfrt/core_runtime/op_handler.h" // from @tf_runtime +#include "tfrt/host_context/host_context.h" // from @tf_runtime +#include "tfrt/tensor/dense_host_tensor_view.h" // from @tf_runtime + +namespace tensorflow { + +Status CompileTFMLIRToBEF(const TfrtFunctionCompileOptions& options, + mlir::ModuleOp module, tfrt::BefBuffer* bef_buffer) { + mlir::OpPrintingFlags print_flags; + print_flags.elideLargeElementsAttrs(); + + if (VLOG_IS_ON(1)) { + VLOG(1) << "Input TF Executor dialect:"; + DumpMlirOpToFile("tf_to_tfrt_tf_executor_dialect", module); + } + + mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); + + // Lower MLIR TF Dialect to MLIR TFRT CoreRT dialect. + mlir::PassManager pm(module.getContext()); + tensorflow::applyTensorflowAndCLOptions(pm); + + tensorflow::TfrtPipelineOptions pass_options; + if (!options.default_device.empty()) { + pass_options.default_device = options.default_device; + } + if (!options.force_data_format.empty()) { + pass_options.force_data_format = options.force_data_format; + } + // TODO(tfrt-devs): Current MaxPoolingOp only supports NHWC on device type + // CPU. Enable this layout optimization after we introduce TFRT native ops + // for training. + if (absl::StrContains(pass_options.default_device, "CPU")) { + pass_options.skip_fold_transpose_in_ops = true; + } + pass_options.enable_optimizer = options.enable_optimizer; + // Use TFRT TPU OpKernel for training. 
+ pass_options.target_tpurt = false; + pass_options.tpu_use_core_selector = options.tpu_use_core_selector; + pass_options.tpu_use_bundled_transfer = options.tpu_use_bundled_transfer; + pass_options.tpu_lower_to_fallback = options.tpu_lower_to_fallback; + pass_options.tpu_fuse_ops = options.tpu_fuse_ops; + pass_options.tpu_transfer_result_to_host = + options.tpu_transfer_result_to_host; + Status status = tensorflow::CreateTfExecutorToTfrtPipeline(pm, pass_options); + if (!status.ok()) { + return diag_handler.Combine(status); + } + + if (mlir::failed(pm.run(module))) + return diag_handler.Combine(tensorflow::errors::Internal( + "failed to lower TF Dialect to CoreRT dialect.")); + + if (VLOG_IS_ON(1)) { + VLOG(1) << "TFRT dialect: "; + DumpMlirOpToFile("tf_to_tfrt_tfrt_dialect", module); + } + + *bef_buffer = + tfrt::ConvertMLIRToBEF(module, /* disable_optional_sections = */ true); + if (bef_buffer->empty()) + return diag_handler.Combine( + tensorflow::errors::Internal("failed to convert MLIR to BEF.")); + + return OkStatus(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/function/function.h b/tensorflow/compiler/mlir/tfrt/function/function.h new file mode 100644 index 00000000000000..1a7d8bd05928da --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/function/function.h @@ -0,0 +1,81 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_FUNCTION_FUNCTION_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_FUNCTION_FUNCTION_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" +#include "tensorflow/core/platform/status.h" +#include "tfrt/bef/bef_buffer.h" // from @tf_runtime +#include "tfrt/core_runtime/tensor_handle.h" // from @tf_runtime + +namespace tfrt { +class CoreRuntime; +} + +namespace mlir { +class ModuleOp; +} + +namespace tensorflow { + +struct TfrtFunctionCompileOptions : public TfrtCompileOptions { + // Currently only SavedModel API inference uses the tpu_fuse_ops option + TfrtFunctionCompileOptions() { + tpu_fuse_ops = false; + // Currently grappler is not correctly applied in the eager execution of TF + // functions, as it may sometimes remove arguments and results. + enable_grappler = false; + } + + // If true, use ServingCoreSelector to pick TPU core. Otherwise, obtain core + // location from assigned device name. + // Currently we don't use core_selector for training use cases. + bool tpu_use_core_selector = false; + + // If true, use BundledTransferToTpuOp to transfer variables and input tensors + // to TPU. + bool tpu_use_bundled_transfer = false; + + // If true, lower an TF op that's placed on TPU device to be executed with + // tfrt_fallback.execute. 
+ // Currently for training use cases we need to lower the op to corert.execute + // to execute with TPU OpHandler, and with TFRT's native implementation. + // TODO(b/188940204): remove this config after we clear up the TPU variable + // implementation. + bool tpu_lower_to_fallback = false; + // If true, transfer the result of TPUExecuteOp from TPU to host. + // Currently for training and Python bulk inference use cases, we don't need + // to proactively transfer the result to host since the consumer op (or + // function) of the result may still be on TPU. + // TODO(b/194081364): remove this option once we unify servo TPU serving + // result transfer behavior. + bool tpu_transfer_result_to_host = false; +}; + +// Compile MLIR generated by tf.function in TF dialect into BEF. +Status CompileTFMLIRToBEF(const TfrtFunctionCompileOptions& options, + mlir::ModuleOp module, tfrt::BefBuffer* bef_buffer); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_FUNCTION_FUNCTION_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc index 3131742b6b95be..ebaf2570bba3f4 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc @@ -92,6 +92,7 @@ CompileAndRegisterIfrtPrograms(absl::string_view model_name, model_name, entry_function_name.str(), *std::move(submodule), ifrt_model_context.GetClient(), &ifrt_model_context.GetThreadPool(), &ifrt_model_context.GetLoadedVariableRegistry(), + &ifrt_model_context.GetRestoreTensorRegistry(), ifrt_model_context.GetDeviceMgr(), ifrt_model_context.GetShapeRepresentationFn()); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/passes.h b/tensorflow/compiler/mlir/tfrt/transforms/passes.h index 5a8665e3a78090..6bcbf8dbad317b 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/passes.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/passes.h @@ -130,6 +130,11 @@ void CreateTfToTfrtPipeline(mlir::OpPassManager& pm, Status CreateTfExecutorToTfrtPipeline(mlir::PassManager& pm, const TfrtPipelineOptions& options); +// Creates a pipeline of passes that lowers MLIR TF Executor dialect to TF +// dialect for CoreRT purposes. +Status CreateTFExecutorToTFPipeline(mlir::PassManager& pm, + const TfrtPipelineOptions& options); + // TODO(deqiangc): refactor below helpers once mlrt is OSSed. 
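Since function.h above only declares the entry point, here is a rough usage sketch of CompileTFMLIRToBEF under the options that header documents. The caller, the device string, and how `module` (a tf.function body already imported into the TF dialect) is produced are assumptions elided from the sketch.

// Sketch: lower an imported tf.function module to BEF with the new options.
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tfrt/function/function.h"
#include "tfrt/bef/bef_buffer.h"  // from @tf_runtime

tensorflow::Status LowerFunctionToBef(mlir::ModuleOp module,
                                      tfrt::BefBuffer* bef_buffer) {
  tensorflow::TfrtFunctionCompileOptions options;
  // Illustrative device name; pick the device the function should target.
  options.default_device = "/job:localhost/replica:0/task:0/device:CPU:0";
  options.enable_optimizer = true;
  // The TPU knobs keep the defaults declared in function.h (core selector,
  // bundled transfer, and fallback lowering all off).
  return tensorflow::CompileTFMLIRToBEF(options, module, bef_buffer);
}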
void CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper( mlir::OpPassManager& pm, const TfrtPipelineOptions& options); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc index c3945d9bd48b9b..693bd78df0b170 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc @@ -1913,6 +1913,14 @@ Status CreateTfExecutorToTfrtPipeline(mlir::PassManager &pm, return absl::OkStatus(); } +Status CreateTFExecutorToTFPipeline(mlir::PassManager &pm, + const TfrtPipelineOptions &options) { + TF_RETURN_IF_ERROR( + CreateTFExecutorToTFPreInvariantOptimizationPipeline(pm, options)); + CreateTFExecutorToTFInvariantOptimizationPipelineHelper(pm, options); + return absl::OkStatus(); +} + static mlir::PassRegistration tf_to_tfrt_pass; static mlir::PassPipelineRegistration tf_pipeline( diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc index c44f2bfa23598c..3cf8be9c90cb62 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc @@ -52,6 +52,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.h" #include "tensorflow/compiler/mlir/tfrt/backend_compiler.h" +#include "tensorflow/compiler/mlir/tfrt/function/function.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" #include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" #include "tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h" diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.h b/tensorflow/compiler/mlir/tfrt/translate/import_model.h index 22f770a6e5b82f..c8aece1e8f4706 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/import_model.h +++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.h @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/function/function.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" #include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" #include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index ada633481a11cd..bf855d8eba0c40 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -66,12 +66,12 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core/data:dataset_utils", "//tensorflow/core/data:name_utils", "//tensorflow/core/data:serialization_utils", "//tensorflow/core/framework:dataset_options_proto_cc", "//tensorflow/core/util/tensor_bundle", "//tensorflow/core/util/tensor_bundle:naming", + "@com_google_absl//absl/status", ], ) diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index 98d75f2cfe2641..3460bbc7475ca2 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include +#include "absl/status/status.h" #include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/data/serialization_utils.h" #include "tensorflow/core/framework/dataset.h" @@ -76,9 +77,10 @@ constexpr char kIncompleteCacheErrorMessage[] = "should use `dataset.take(k).cache().repeat()` instead."; } // namespace -class PartialCache { +class DatasetRandomAccessCache { public: - explicit PartialCache(const DatasetBase* dataset) : input_(dataset) {} + explicit DatasetRandomAccessCache(const DatasetBase* dataset) + : input_(dataset) {} // Extends the temporary cache up to a given index and then updates // out_tensors with the element at that index. @@ -136,6 +138,43 @@ class PartialCache { std::vector> cache_; }; +// Caches dataset elements when global shuffling is enabled. +// TODO(b/325112575): Support save/load. +class IteratorRandomAccessCache { + public: + explicit IteratorRandomAccessCache(const DatasetBase* input) + : input_(input) {} + + absl::Status Get(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) { + TF_ASSIGN_OR_RETURN(size_t element_position, + ctx->index_mapper()(element_count_++)); + if (element_position < cache_.size() && !cache_[element_position].empty()) { + *out_tensors = cache_[element_position]; + *end_of_sequence = false; + return absl::OkStatus(); + } + + absl::Status status = + input_->Get(AnyContext(ctx), element_position, out_tensors); + if (absl::IsOutOfRange(status)) { + *end_of_sequence = true; + return absl::OkStatus(); + } + + if (element_position >= cache_.size()) { + cache_.resize(element_position + 1); + } + cache_[element_position] = *out_tensors; + return absl::OkStatus(); + } + + private: + int64_t element_count_ = 0; + const DatasetBase* input_ = nullptr; + std::vector> cache_; +}; + class CacheDatasetOp::FileDatasetBase : public DatasetBase { public: FileDatasetBase(OpKernelContext* ctx, const DatasetBase* input, @@ -736,6 +775,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { input_(input), cache_(std::move(cache)) { input_->Ref(); + random_indexing_compatible_ = input_->RandomIndexingCompatible(); } ~MemoryDatasetBase() override { input_->Unref(); } @@ -781,10 +821,11 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { return errors::OutOfRange("Index out of range [0, ", cardinality, "):", index); } - if (!partial_cache_) { - partial_cache_ = std::make_unique(input_); + if (!dataset_random_access_cache_) { + dataset_random_access_cache_ = + std::make_unique(input_); } - return partial_cache_->Get(ctx, index, out_tensors); + return dataset_random_access_cache_->Get(ctx, index, out_tensors); } Status InputDatasets(std::vector* inputs) const override { @@ -796,6 +837,10 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { return input_->CheckExternalState(); } + absl::Status RandomIndexingCompatible() const override { + return random_indexing_compatible_; + } + protected: class MemoryIterator : public DatasetIterator { public: @@ -811,6 +856,14 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); + if (ctx->index_mapper() != nullptr) { + if (!iterator_random_access_cache_) { + iterator_random_access_cache_ = + std::make_unique(dataset()->input_); + } + return iterator_random_access_cache_->Get(ctx, out_tensors, + end_of_sequence); + } return iterator_->GetNext(ctx, out_tensors, end_of_sequence); } @@ -1015,12 +1068,16 @@ class CacheDatasetOp::MemoryDatasetBase : public 
DatasetBase { mutex mu_; MemoryCache* cache_ TF_GUARDED_BY(mu_); // not owned. std::unique_ptr iterator_ TF_GUARDED_BY(mu_); + std::unique_ptr iterator_random_access_cache_ + TF_GUARDED_BY(mu_); }; // MemoryIterator mutable mutex mu_; const DatasetBase* const input_; const std::shared_ptr cache_; - mutable std::unique_ptr partial_cache_ TF_GUARDED_BY(mu_); + mutable std::unique_ptr dataset_random_access_cache_ + TF_GUARDED_BY(mu_); + absl::Status random_indexing_compatible_ = absl::OkStatus(); }; // MemoryDatasetBase // This version of memory dataset has an exclusive ownership of the memory cache diff --git a/tensorflow/core/kernels/data/experimental/index_flat_map_dataset_op.cc b/tensorflow/core/kernels/data/experimental/index_flat_map_dataset_op.cc index c71c20f17b6caa..62006861af2c42 100644 --- a/tensorflow/core/kernels/data/experimental/index_flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/index_flat_map_dataset_op.cc @@ -89,15 +89,21 @@ absl::StatusOr GetValue(const Tensor& tensor) { } // Returns the `offset`-th element from `tensors`. -std::vector GetSlice(const std::vector& tensors, - size_t offset) { +absl::StatusOr> GetSlice(const std::vector& tensors, + size_t offset) { std::vector result; for (size_t i = 0; i < tensors.size(); ++i) { if (tensors[i].dims() == 0) { // Scalar. result.push_back(tensors[i]); - } else { - result.push_back(MaybeCopySubSlice(tensors[i], offset)); + continue; + } + if (offset > tensors[i].dim_size(0)) { + return absl::InvalidArgumentError(absl::StrCat( + "`index_flat_map` got invalid `index_map_fn` which returns offset ", + offset, ", but the input element has ", tensors[i].dim_size(0), + " elements: ", tensors[i].DebugString())); } + result.push_back(MaybeCopySubSlice(tensors[i], offset)); } return result; } @@ -264,7 +270,8 @@ class IndexFlatMapDatasetOp::Dataset::Iterator } input_element_count_ = next_input_index; } - *out_tensors = GetSlice(input_unflattened_tensors_, offset); + TF_ASSIGN_OR_RETURN(*out_tensors, + GetSlice(input_unflattened_tensors_, offset)); ++element_count_; } else { // TODO(b/325112575): Make it easier to return multiple values from @@ -279,7 +286,7 @@ class IndexFlatMapDatasetOp::Dataset::Iterator if (*end_of_sequence) { return absl::OkStatus(); } - *out_tensors = GetSlice(mapped_tensors, offset); + TF_ASSIGN_OR_RETURN(*out_tensors, GetSlice(mapped_tensors, offset)); } return absl::OkStatus(); } diff --git a/tensorflow/core/kernels/linalg/matrix_diag_op.cc b/tensorflow/core/kernels/linalg/matrix_diag_op.cc index 89dc4f06ea111a..b7302ef1514e78 100644 --- a/tensorflow/core/kernels/linalg/matrix_diag_op.cc +++ b/tensorflow/core/kernels/linalg/matrix_diag_op.cc @@ -236,12 +236,13 @@ class MatrixDiagOp : public OpKernel { errors::InvalidArgument( "lower_diag_index must not be larger than upper_diag_index: ", lower_diag_index, " > ", upper_diag_index)); - OP_REQUIRES(context, - lower_diag_index == upper_diag_index || - diagonal_shape.dim_size(diag_rank - 2) == num_diags, - errors::InvalidArgument( - "The number of diagonals provided in the input does not " - "match the lower_diag_index and upper_diag_index range.")); + OP_REQUIRES( + context, + lower_diag_index == upper_diag_index || + diagonal_shape.dim_size(std::max(diag_rank - 2, 0)) == num_diags, + errors::InvalidArgument( + "The number of diagonals provided in the input does not " + "match the lower_diag_index and upper_diag_index range.")); const Eigen::Index max_diag_len = diagonal_shape.dim_size(diag_rank - 1); const Eigen::Index min_num_rows = diff 
--git a/tensorflow/core/profiler/lib/profiler_disabled_test.cc b/tensorflow/core/profiler/lib/profiler_disabled_test.cc index 4e7fafe84ff71f..089c8692ca7bde 100644 --- a/tensorflow/core/profiler/lib/profiler_disabled_test.cc +++ b/tensorflow/core/profiler/lib/profiler_disabled_test.cc @@ -24,7 +24,8 @@ namespace { TEST(ProfilerDisabledTest, ProfilerDisabledTest) { setenv("TF_DISABLE_PROFILING", "1", /*overwrite=*/1); - StatusOr profiler_lock = tsl::profiler::ProfilerLock::Acquire(); + absl::StatusOr profiler_lock = + tsl::profiler::ProfilerLock::Acquire(); EXPECT_FALSE(profiler_lock.ok()); } diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 0831f4479c16c6..9249f032b1e1c1 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1834 // Updated: 2024/4/16 +#define TF_GRAPH_DEF_VERSION 1835 // Updated: 2024/4/17 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/core/tfrt/ifrt/BUILD b/tensorflow/core/tfrt/ifrt/BUILD index 37b449af8d6a88..e8d4b629be8400 100644 --- a/tensorflow/core/tfrt/ifrt/BUILD +++ b/tensorflow/core/tfrt/ifrt/BUILD @@ -53,6 +53,7 @@ cc_library( hdrs = ["ifrt_serving_executable.h"], deps = [ ":ifrt_loaded_variable_registry", + ":ifrt_restore_tensor_registry", ":ifrt_tensor_utils", ":sharding_utils", ":tf_host_callback", @@ -66,8 +67,6 @@ cc_library( "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/common_runtime/eager:context", - "//tensorflow/core/common_runtime/eager:tensor_handle", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -118,7 +117,10 @@ cc_library( srcs = ["ifrt_restore_tensor_registry.cc"], hdrs = ["ifrt_restore_tensor_registry.h"], deps = [ + "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_types", + "//tensorflow/core:framework", "//tensorflow/core/framework:tensor", + "//tensorflow/core/framework:types_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", @@ -240,8 +242,6 @@ cc_library( ":sharding_utils", "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_types", "//tensorflow/core:framework", - "//tensorflow/core/tfrt/mlrt/interpreter:future", - "//tensorflow/core/tfrt/utils:fallback_tensor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -300,8 +300,6 @@ tf_cc_test( "//tensorflow/core/framework:tensor_matcher", "//tensorflow/core/framework:tensor_testutil", "//tensorflow/core/framework:types_proto_cc", - "//tensorflow/core/tfrt/mlrt/interpreter:future", - "//tensorflow/core/tfrt/utils:fallback_tensor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", @@ -389,6 +387,7 @@ tf_cc_test( tags = ["no_oss"], deps = [ ":ifrt_loaded_variable_registry", + ":ifrt_restore_tensor_registry", ":ifrt_serving_executable", ":sharding_utils", ":tf_host_callback", @@ -440,6 +439,7 @@ tf_cc_test( deps = [ ":ifrt_executable_registry", ":ifrt_loaded_variable_registry", + ":ifrt_restore_tensor_registry", ":ifrt_serving_executable", ":tf_host_callback", "//tensorflow/compiler/mlir/tensorflow", diff --git 
a/tensorflow/core/tfrt/ifrt/ifrt_executable_registry_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_executable_registry_test.cc index cfddce835709f5..b30340287badd2 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_executable_registry_test.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_executable_registry_test.cc @@ -41,6 +41,7 @@ limitations under the License. #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h" #include "tensorflow/core/tfrt/ifrt/tf_host_callback.h" #include "tsl/platform/env.h" @@ -79,13 +80,14 @@ CreateIfrtServingExecutable(mlir::MLIRContext& context) { xla::ifrt::test_util::GetClient()); IfrtLoadedVariableRegistry ifrt_loaded_variable_registry; + IfrtRestoreTensorRegistry ifrt_restore_tensor_registry; TF_ASSIGN_OR_RETURN(std::unique_ptr device_mgr, CreateTfStaticDeviceMgr()); return std::make_unique( "test", "main", std::move(mlir_module), client, &GetThreadPool(), - &ifrt_loaded_variable_registry, device_mgr.get(), - tensorflow::IdentityShapeRepresentationFn()); + &ifrt_loaded_variable_registry, &ifrt_restore_tensor_registry, + device_mgr.get(), tensorflow::IdentityShapeRepresentationFn()); } TEST(IfrtExecutableRegistry, Basic) { diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h index 5cb885598174bd..469dae297dd0b0 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h @@ -34,7 +34,6 @@ namespace ifrt_serving { class IfrtLoadedVariableRegistry { public: struct LoadedVariable { - DtypeAndShape dtype_and_shape; xla::ifrt::Future>> array; }; using LoadedVariableConstructor = diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc index af7e8e2131b23c..9161a37f55f5ce 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "xla/hlo/ir/hlo_sharding.h" @@ -32,12 +33,9 @@ limitations under the License. 
#include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" #include "tensorflow/core/tfrt/ifrt/sharding_utils.h" -#include "tensorflow/core/tfrt/mlrt/interpreter/future.h" -#include "tensorflow/core/tfrt/utils/fallback_tensor.h" #include "tsl/concurrency/ref_count.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -63,6 +61,8 @@ absl::StatusOr> LoadIfrtVariable( thread_pool); } +} // namespace + absl::StatusOr GetDtypeAndShape( const ResourceHandle& resource_handle) { const std::vector& dtype_and_partial_shapes = @@ -84,31 +84,20 @@ absl::StatusOr GetDtypeAndShape( return dtype_and_shape; } -} // namespace - std::string GetRuntimeNameFromVarHandle(const ResourceHandle& handle) { return absl::StrCat(handle.container(), "__", handle.name()); } absl::Status LoadRestoredTensorAsIfrtLoadedVariable( - const tensorflow::Tensor& variable_handle_tensor, + absl::string_view runtime_name, std::shared_ptr ifrt_client, const tsl::thread::ThreadPool& thread_pool, ifrt_serving::IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry, ifrt_serving::IfrtLoadedVariableRegistry& ifrt_loaded_variable_registry, tfrt::ConcurrentWorkQueue* checkpoint_loader_queue, const VariableDeviceShardingConfigProto& sharding_config) { - if (variable_handle_tensor.dtype() != DT_RESOURCE) { - return absl::InvalidArgumentError( - absl::StrCat("variable_handle_tensor is ", - DataTypeString(variable_handle_tensor.dtype()), - " but expected DT_RESOURCE")); - } - const ResourceHandle& handle = - variable_handle_tensor.scalar()(); - std::string runtime_name = GetRuntimeNameFromVarHandle(handle); xla::ifrt::Future> restored_tensor_future = - ifrt_restore_tensor_registry.Get(runtime_name); + ifrt_restore_tensor_registry.GetRestoredTensor(runtime_name); if (!restored_tensor_future.IsValid()) { return absl::InternalError(absl::StrCat( "LoadVariableOp: failed to fetch variable tensor: ", runtime_name)); @@ -120,8 +109,9 @@ absl::Status LoadRestoredTensorAsIfrtLoadedVariable( xla::ifrt::Future>>( loaded_variable_promise); - TF_ASSIGN_OR_RETURN(ifrt_serving::DtypeAndShape dtype_and_shape, - GetDtypeAndShape(handle)); + TF_ASSIGN_OR_RETURN( + absl::StatusOr dtype_and_shape, + ifrt_restore_tensor_registry.GetDtypeAndShape(runtime_name)); // TODO(b/330360798) Load variable on devices from the result of core // selection. TF_RETURN_IF_ERROR(ifrt_loaded_variable_registry.TryRegisterLoadedVariable( @@ -129,8 +119,7 @@ absl::Status LoadRestoredTensorAsIfrtLoadedVariable( [&]() -> absl::StatusOr< ifrt_serving::IfrtLoadedVariableRegistry::LoadedVariable> { return ifrt_serving::IfrtLoadedVariableRegistry::LoadedVariable( - {.dtype_and_shape = dtype_and_shape, - .array = loaded_variable_future}); + {.array = loaded_variable_future}); })); restored_tensor_future.OnReady( [ifrt_client = ifrt_client, &thread_pool = thread_pool, diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h index 4a2c9542be85bb..aafebbda16bd77 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h @@ -22,17 +22,18 @@ limitations under the License. 
#include "absl/status/status.h" #include "xla/python/ifrt/client.h" #include "tensorflow/core/framework/resource_handle.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" -#include "tensorflow/core/tfrt/mlrt/interpreter/future.h" #include "tsl/platform/threadpool.h" #include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime namespace tensorflow { namespace ifrt_serving { +absl::StatusOr GetDtypeAndShape( + const ResourceHandle& resource_handle); + // Returns the runtime name from the resource handle. The name will be concat of // handle's container name and handle's name. std::string GetRuntimeNameFromVarHandle(const ResourceHandle& handle); @@ -44,7 +45,7 @@ std::string GetRuntimeNameFromVarHandle(const ResourceHandle& handle); // can look for the actual loaded variable value in // `ifrt_loaded_variable_registry`. absl::Status LoadRestoredTensorAsIfrtLoadedVariable( - const tensorflow::Tensor& variable_handle_tensor, + absl::string_view runtime_name, std::shared_ptr ifrt_client, const tsl::thread::ThreadPool& thread_pool, ifrt_serving::IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry, diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils_test.cc index 85726c7a3f2cc1..40cf1903d46fe8 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils_test.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils_test.cc @@ -38,8 +38,6 @@ limitations under the License. #include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" -#include "tensorflow/core/tfrt/mlrt/interpreter/future.h" -#include "tensorflow/core/tfrt/utils/fallback_tensor.h" #include "tsl/concurrency/ref_count.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" @@ -86,17 +84,17 @@ TEST(ShardingUtilsTest, ShardTensorToIfrtLoadedVariableNotFoundWrongName) { xla::ifrt::Future>::CreatePromise(); auto future = xla::ifrt::Future>(promise); - TF_ASSERT_OK( - restored_tensor_registry.TryRegister("var_x_wrong", std::move(future))); + IfrtRestoreTensorRegistry::RestoredTensorInfo restored_tensor_info = { + GetDtypeAndShape(variable_handle.scalar()()).value(), + future}; + TF_ASSERT_OK(restored_tensor_registry.TryRegister("var_x_wrong", + restored_tensor_info)); promise.Set(input_tensor); - TF_ASSERT_OK(LoadRestoredTensorAsIfrtLoadedVariable( - variable_handle, client, thread_pool, restored_tensor_registry, - loaded_variable_registry, restore_work_queue.get(), sharding_config)); - TF_ASSERT_OK_AND_ASSIGN( - auto v, - loaded_variable_registry.GetLoadedVariable(GetRuntimeNameFromVarHandle( - variable_handle.scalar()()))); - EXPECT_THAT(v.array.Await().status(), StatusIs(absl::StatusCode::kNotFound)); + EXPECT_THAT( + LoadRestoredTensorAsIfrtLoadedVariable( + "var_x", client, thread_pool, restored_tensor_registry, + loaded_variable_registry, restore_work_queue.get(), sharding_config), + StatusIs(absl::StatusCode::kNotFound)); } TEST(ShardingUtilsTest, ShardTensorToIfrtLoadedVariableSucceed) { @@ -129,21 +127,20 @@ TEST(ShardingUtilsTest, ShardTensorToIfrtLoadedVariableSucceed) { xla::ifrt::Future>::CreatePromise(); auto future = xla::ifrt::Future>(promise); - 
TF_ASSERT_OK(restored_tensor_registry.TryRegister( - GetRuntimeNameFromVarHandle(variable_handle.scalar()()), - std::move(future))); + IfrtRestoreTensorRegistry::RestoredTensorInfo restored_tensor_info = { + GetDtypeAndShape(variable_handle.scalar()()).value(), + future}; + + TF_ASSERT_OK( + restored_tensor_registry.TryRegister("var_x", restored_tensor_info)); TF_ASSERT_OK(LoadRestoredTensorAsIfrtLoadedVariable( - variable_handle, client, thread_pool, restored_tensor_registry, + "var_x", client, thread_pool, restored_tensor_registry, loaded_variable_registry, restore_work_queue.get(), sharding_config)); promise.Set(input_tensor); - TF_ASSERT_OK_AND_ASSIGN( - auto v, - loaded_variable_registry.GetLoadedVariable(GetRuntimeNameFromVarHandle( - variable_handle.scalar()()))); + TF_ASSERT_OK_AND_ASSIGN(auto v, + loaded_variable_registry.GetLoadedVariable("var_x")); TF_ASSERT_OK_AND_ASSIGN(auto assembled_array, v.array.Await()); - EXPECT_TRUE(v.dtype_and_shape.shape.IsSameSize(TensorShape({2, 2}))); - EXPECT_EQ(v.dtype_and_shape.dtype, DT_INT32); TF_ASSERT_OK_AND_ASSIGN(auto disassembled_arrays, assembled_array->DisassembleIntoSingleDeviceArrays( xla::ifrt::ArrayCopySemantics::kAlwaysCopy)); diff --git a/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.cc b/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.cc index 752a22384539bb..e04bafd11ef16c 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "xla/python/ifrt/future.h" #include "tensorflow/core/framework/tensor.h" @@ -31,20 +32,19 @@ namespace tensorflow { namespace ifrt_serving { absl::Status IfrtRestoreTensorRegistry::TryRegister( - absl::string_view name, - xla::ifrt::Future> tensor_future) { + absl::string_view name, RestoredTensorInfo restored_tensor_info) { absl::MutexLock lock(&mutex_); - auto& variable = restored_tensors_[name]; - if (variable.IsValid()) { + auto& info = restored_tensors_[name]; + if (info.tensor_future.IsValid()) { return absl::AlreadyExistsError( absl::StrCat("Variable '", name, "' already registered.")); } - variable = std::move(tensor_future); + info = std::move(restored_tensor_info); return absl::OkStatus(); } xla::ifrt::Future> -IfrtRestoreTensorRegistry::Get(absl::string_view name) const { +IfrtRestoreTensorRegistry::GetRestoredTensor(absl::string_view name) const { absl::MutexLock lock(&mutex_); auto it = restored_tensors_.find(name); if (it == restored_tensors_.end()) { @@ -52,7 +52,19 @@ IfrtRestoreTensorRegistry::Get(absl::string_view name) const { absl::NotFoundError(absl::StrCat("Variable '", name, "' not found."))); } - return it->second; + return it->second.tensor_future; +} + +absl::StatusOr IfrtRestoreTensorRegistry::GetDtypeAndShape( + absl::string_view name) const { + absl::MutexLock lock(&mutex_); + auto it = restored_tensors_.find(name); + if (it == restored_tensors_.end()) { + return absl::NotFoundError( + absl::StrCat("Variable '", name, "' not found.")); + } + + return it->second.dtype_and_shape; } } // namespace ifrt_serving diff --git a/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h b/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h index d14b3774874543..a178065170b94e 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h +++ 
b/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h @@ -23,8 +23,11 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "xla/python/ifrt/future.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace ifrt_serving { @@ -32,21 +35,26 @@ namespace ifrt_serving { // This class is thread safe. class IfrtRestoreTensorRegistry { public: + struct RestoredTensorInfo { + DtypeAndShape dtype_and_shape; + xla::ifrt::Future> tensor_future; + }; // Tries to register a loaded variable with the given name. // Returns an error if the named tensor already exists. - absl::Status TryRegister( - absl::string_view name, - xla::ifrt::Future> tensor_future) + absl::Status TryRegister(absl::string_view name, + RestoredTensorInfo restored_tensor_info) ABSL_LOCKS_EXCLUDED(mutex_); - xla::ifrt::Future> Get( + xla::ifrt::Future> GetRestoredTensor( absl::string_view name) const ABSL_LOCKS_EXCLUDED(mutex_); + absl::StatusOr GetDtypeAndShape(absl::string_view name) const + ABSL_LOCKS_EXCLUDED(mutex_); + private: mutable absl::Mutex mutex_; - absl::flat_hash_map>> - restored_tensors_ ABSL_GUARDED_BY(mutex_); + absl::flat_hash_map restored_tensors_ + ABSL_GUARDED_BY(mutex_); }; } // namespace ifrt_serving diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc index 41fe0d19ce6d70..c77fe193870990 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc @@ -65,6 +65,7 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.h" #include "tensorflow/core/tfrt/ifrt/sharding_utils.h" #include "tensorflow/core/tfrt/ifrt/tf_host_callback.h" @@ -79,7 +80,7 @@ namespace { absl::StatusOr> BuildDtypeAndShape( absl::Span inputs, absl::Span variable_arg_indices, - const IfrtLoadedVariableRegistry& ifrt_loaded_variable_registry) { + const IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry) { std::vector dtypes_and_shapes; dtypes_and_shapes.reserve(inputs.size()); @@ -88,10 +89,10 @@ absl::StatusOr> BuildDtypeAndShape( if (variable_index < variable_arg_indices.size() && i == variable_arg_indices[variable_index]) { // Get already loaded variable tensor. 
- TF_ASSIGN_OR_RETURN(auto loaded_variable, - ifrt_loaded_variable_registry.GetLoadedVariable( + TF_ASSIGN_OR_RETURN(auto dtype_and_shape, + ifrt_restore_tensor_registry.GetDtypeAndShape( inputs[i].scalar()())); - dtypes_and_shapes.push_back(loaded_variable.dtype_and_shape); + dtypes_and_shapes.push_back(dtype_and_shape); variable_index++; } else { @@ -424,7 +425,7 @@ absl::StatusOr> IfrtServingExecutable::Execute( TF_ASSIGN_OR_RETURN(std::vector dtypes_and_shapes, BuildDtypeAndShape(inputs, variable_arg_indices, - ifrt_loaded_variable_registry_)); + ifrt_restore_tensor_registry_)); TF_ASSIGN_OR_RETURN( CachedExecutableBundle executable_bundle, LookUpOrCreateExecutable(absl::MakeSpan(dtypes_and_shapes)).Await()); diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h index 4d040e7e4b6c23..691a22b1e4638b 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h @@ -46,6 +46,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" #include "tensorflow/core/tfrt/ifrt/tf_host_callback.h" #include "tsl/concurrency/ref_count.h" #include "tsl/platform/threadpool.h" @@ -61,6 +62,7 @@ class IfrtServingExecutable { std::shared_ptr client, const tsl::thread::ThreadPool* thread_pool, const IfrtLoadedVariableRegistry* ifrt_loaded_variable_registry, + const IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry, tensorflow::StaticDeviceMgr* device_mgr, tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn) : model_name_(std::string(model_name)), @@ -69,6 +71,7 @@ class IfrtServingExecutable { ifrt_client_(std::move(client)), thread_pool_(*thread_pool), ifrt_loaded_variable_registry_(*ifrt_loaded_variable_registry), + ifrt_restore_tensor_registry_(*ifrt_restore_tensor_registry), device_mgr_(device_mgr), shape_representation_fn_(std::move(shape_representation_fn)) {} @@ -129,6 +132,7 @@ class IfrtServingExecutable { const tsl::thread::ThreadPool& thread_pool_; const IfrtLoadedVariableRegistry& ifrt_loaded_variable_registry_; + const IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry_; tensorflow::StaticDeviceMgr* device_mgr_; // Not owned. For host callback. tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_; diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc index f10367fd58237a..593deedb4cb22b 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc @@ -49,6 +49,7 @@ limitations under the License. 
#include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" #include "tensorflow/core/tfrt/ifrt/sharding_utils.h" #include "tensorflow/core/tfrt/ifrt/tf_host_callback.h" #include "tsl/concurrency/ref_count.h" @@ -104,14 +105,15 @@ TEST(IfrtServingExecutableTest, Basic) { xla::ifrt::test_util::GetClient()); IfrtLoadedVariableRegistry ifrt_loaded_variable_registry; + IfrtRestoreTensorRegistry ifrt_restore_tensor_registry; TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr device_mgr, CreateTfStaticDeviceMgr()); IfrtServingExecutable executable( "test", "main", std::move(mlir_module), client, &GetThreadPool(), - &ifrt_loaded_variable_registry, device_mgr.get(), - tensorflow::IdentityShapeRepresentationFn()); + &ifrt_loaded_variable_registry, &ifrt_restore_tensor_registry, + device_mgr.get(), tensorflow::IdentityShapeRepresentationFn()); auto x = AsTensor({1, 2, 3}, tensorflow::TensorShape({1, 3})); auto y = AsTensor({1, 2, 3}, tensorflow::TensorShape({3, 1})); @@ -149,14 +151,15 @@ TEST(IfrtServingExecutableTest, MultipleShapes) { xla::ifrt::test_util::GetClient()); IfrtLoadedVariableRegistry ifrt_loaded_variable_registry; + IfrtRestoreTensorRegistry ifrt_restore_tensor_registry; TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr device_mgr, CreateTfStaticDeviceMgr()); IfrtServingExecutable executable( "test", "main", std::move(mlir_module), client, &GetThreadPool(), - &ifrt_loaded_variable_registry, device_mgr.get(), - tensorflow::IdentityShapeRepresentationFn()); + &ifrt_loaded_variable_registry, &ifrt_restore_tensor_registry, + device_mgr.get(), tensorflow::IdentityShapeRepresentationFn()); auto x1 = AsTensor({1, 2, 3}, tensorflow::TensorShape({1, 3})); auto y1 = AsTensor({1, 2, 3}, tensorflow::TensorShape({3, 1})); @@ -209,14 +212,15 @@ TEST(IfrtServingExecutableTest, Spmd) { xla::ifrt::test_util::GetClient()); IfrtLoadedVariableRegistry ifrt_loaded_variable_registry; + IfrtRestoreTensorRegistry ifrt_restore_tensor_registry; TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr device_mgr, CreateTfStaticDeviceMgr()); IfrtServingExecutable executable( "test", "main", std::move(mlir_module), client, &GetThreadPool(), - &ifrt_loaded_variable_registry, device_mgr.get(), - tensorflow::IdentityShapeRepresentationFn()); + &ifrt_loaded_variable_registry, &ifrt_restore_tensor_registry, + device_mgr.get(), tensorflow::IdentityShapeRepresentationFn()); auto x = AsTensor({1, 2, 3, 4, 5, 6, 7, 8}, tensorflow::TensorShape({4, 2})); @@ -259,14 +263,15 @@ TEST(IfrtServingExecutableTest, SpmdTwoReturns) { xla::ifrt::test_util::GetClient()); IfrtLoadedVariableRegistry ifrt_loaded_variable_registry; + IfrtRestoreTensorRegistry ifrt_restore_tensor_registry; TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr device_mgr, CreateTfStaticDeviceMgr()); IfrtServingExecutable executable( "test", "main", std::move(mlir_module), client, &GetThreadPool(), - &ifrt_loaded_variable_registry, device_mgr.get(), - tensorflow::IdentityShapeRepresentationFn()); + &ifrt_loaded_variable_registry, &ifrt_restore_tensor_registry, + device_mgr.get(), tensorflow::IdentityShapeRepresentationFn()); auto x = AsTensor({1, 2, 3, 4, 5, 6, 7, 8}, tensorflow::TensorShape({4, 2})); @@ -313,14 +318,15 @@ TEST(IfrtServingExecutableTest, NoReturn) { xla::ifrt::test_util::GetClient()); IfrtLoadedVariableRegistry ifrt_loaded_variable_registry; + IfrtRestoreTensorRegistry ifrt_restore_tensor_registry; 
TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr device_mgr, CreateTfStaticDeviceMgr()); IfrtServingExecutable executable( "test", "main", std::move(mlir_module), client, &GetThreadPool(), - &ifrt_loaded_variable_registry, device_mgr.get(), - tensorflow::IdentityShapeRepresentationFn()); + &ifrt_loaded_variable_registry, &ifrt_restore_tensor_registry, + device_mgr.get(), tensorflow::IdentityShapeRepresentationFn()); auto x = AsTensor({1, 2, 3}, tensorflow::TensorShape({1, 3})); auto y = AsTensor({1, 2, 3}, tensorflow::TensorShape({3, 1})); @@ -355,19 +361,24 @@ TEST_P(VariableInputTest, InterleaveVariable) { xla::ifrt::test_util::GetClient()); IfrtLoadedVariableRegistry ifrt_loaded_variable_registry; + IfrtRestoreTensorRegistry ifrt_restore_tensor_registry; TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr device_mgr, CreateTfStaticDeviceMgr()); IfrtServingExecutable executable( "test", "main", std::move(mlir_module), client, &GetThreadPool(), - &ifrt_loaded_variable_registry, device_mgr.get(), - tensorflow::IdentityShapeRepresentationFn()); + &ifrt_loaded_variable_registry, &ifrt_restore_tensor_registry, + device_mgr.get(), tensorflow::IdentityShapeRepresentationFn()); std::vector inputs; std::vector loaded_variable_indices; for (int i = 0; i < GetParam().in_tensors.size(); i++) { if (GetParam().is_variable[i]) { + IfrtRestoreTensorRegistry::RestoredTensorInfo restore_tensor_info = { + {GetParam().in_tensors[i].dtype(), GetParam().in_tensors[i].shape()}}; std::string variable_name = absl::StrCat("variable_", i); + ASSERT_OK(ifrt_restore_tensor_registry.TryRegister(variable_name, + restore_tensor_info)); ASSERT_OK(ifrt_loaded_variable_registry.TryRegisterLoadedVariable( variable_name, [&]() -> absl::StatusOr { @@ -387,9 +398,6 @@ TEST_P(VariableInputTest, InterleaveVariable) { IfrtLoadedVariableRegistry::LoadedVariable loaded_variable; loaded_variable.array = future; - loaded_variable.dtype_and_shape.dtype = in_tensor.dtype(); - loaded_variable.dtype_and_shape.shape = in_tensor.shape(); - return loaded_variable; })); loaded_variable_indices.push_back(i); diff --git a/tensorflow/core/tfrt/ifrt/sharding_utils.cc b/tensorflow/core/tfrt/ifrt/sharding_utils.cc index dd9ec62c0e0658..184904ee694903 100644 --- a/tensorflow/core/tfrt/ifrt/sharding_utils.cc +++ b/tensorflow/core/tfrt/ifrt/sharding_utils.cc @@ -616,7 +616,7 @@ absl::StatusOr MakeTensorFromArray( b.index_domain.origin().elements().end()); }); - std::vector> arrays_copy_status; + std::vector> arrays_copy_status; std::vector input_tensors; input_tensors.reserve(index_domain_device_arrays.size()); arrays_copy_status.reserve(index_domain_device_arrays.size()); @@ -626,7 +626,7 @@ absl::StatusOr MakeTensorFromArray( ToTensorDataType(array->dtype())); tensorflow::Tensor tensor(dtype, tensor_shape); input_tensors.push_back(tensor); - xla::ifrt::Future copy_status = + xla::ifrt::Future<> copy_status = array->CopyToHostBuffer(tensor.data(), /*byte_strides=*/{}, xla::ifrt::ArrayCopySemantics::kAlwaysCopy); copy_status.OnReady([tensor](absl::Status status) { diff --git a/tensorflow/core/tfrt/mlrt/kernel/BUILD b/tensorflow/core/tfrt/mlrt/kernel/BUILD index 12ac71f2c9d1ac..a68cff125ea408 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/BUILD +++ b/tensorflow/core/tfrt/mlrt/kernel/BUILD @@ -193,6 +193,7 @@ tf_cc_shared_test( "//tensorflow/core/tfrt/fallback:op_kernel_runner", "//tensorflow/core/tfrt/ifrt:ifrt_config_proto_cc", "//tensorflow/core/tfrt/ifrt:ifrt_model_context", + "//tensorflow/core/tfrt/ifrt:ifrt_restore_tensor_registry", 
"//tensorflow/core/tfrt/mlrt/bytecode", "//tensorflow/core/tfrt/mlrt/bytecode:executable", "//tensorflow/core/tfrt/mlrt/interpreter:builtin_kernels", diff --git a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc index 0c40ea7887679e..3d3cd0ae7d6c04 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc @@ -61,27 +61,6 @@ namespace tf_mlrt { namespace { -absl::StatusOr GetDtypeAndShape( - const ResourceHandle& variable) { - std::vector dtype_and_partial_shapes = - variable.dtypes_and_shapes(); - - if (dtype_and_partial_shapes.size() != 1) { - return absl::InvalidArgumentError(absl::StrCat( - "Expected 1 dtype and shape, got ", dtype_and_partial_shapes.size())); - } - ifrt_serving::DtypeAndShape dtype_and_shape; - if (!dtype_and_partial_shapes.front().shape.AsTensorShape( - &dtype_and_shape.shape)) { - return absl::InvalidArgumentError( - absl::StrCat("Failed to convert partial shape to full tensor shape: ", - dtype_and_partial_shapes.front().shape.DebugString())); - } - - dtype_and_shape.dtype = dtype_and_partial_shapes.front().dtype; - return dtype_and_shape; -} - struct MlrtIfrtRestoreVariableKernel : mlrt::KernelFrame { using KernelFrame::KernelFrame; @@ -189,11 +168,22 @@ void MlrtIfrtRestoreVariableKernel::Invoke() { xla::ifrt::Future>::CreatePromise(); auto future = xla::ifrt::Future>(promise); - - std::string runtime_name = ifrt_serving::GetRuntimeNameFromVarHandle( - var_handles()[i].tensor().scalar()()); - if (auto status = - ifrt_restore_tensor_registry.TryRegister(runtime_name, future); + const ResourceHandle& var_handle = + var_handles()[i].tensor().scalar()(); + absl::StatusOr dtype_and_shape = + ifrt_serving::GetDtypeAndShape(var_handle); + if (!dtype_and_shape.ok()) { + // TODO(b/330360798) Refactor Invoke() to have less usage on + // execution_context().Fail. + execution_context().Fail(dtype_and_shape.status()); + return; + } + std::string runtime_name = + ifrt_serving::GetRuntimeNameFromVarHandle(var_handle); + ifrt_serving::IfrtRestoreTensorRegistry::RestoredTensorInfo + restored_tensor_info = {*std::move(dtype_and_shape), std::move(future)}; + if (auto status = ifrt_restore_tensor_registry.TryRegister( + runtime_name, restored_tensor_info); !status.ok()) { // Propagate errors so that if already-registered futures are being waited // on, they can be unblocked. 
@@ -236,7 +226,7 @@ class MlrtIfrtLoadVariableKernel : public mlrt::KernelFrame { static constexpr char kName[] = "tf_mlrt.ifrt_load_variable"; - const tensorflow::Tensor& variable_tensor() const { + const tensorflow::Tensor& variable_handler_tensor() const { DCHECK_GE(arguments().size(), 1); const tensorflow::Tensor& ret = arguments()[0].Get().tensor(); @@ -244,10 +234,6 @@ class MlrtIfrtLoadVariableKernel : public mlrt::KernelFrame { return ret; } - const ResourceHandle& variable_resource_handle() const { - const auto& tensor = variable_tensor(); - return tensor.scalar()(); - } absl::string_view sharding_config_proto_text() const { DCHECK_EQ(attributes().size(), 2); return attributes().GetAs(0).Get(); @@ -294,15 +280,17 @@ absl::Status MlrtIfrtLoadVariableKernel::InvokeHelper() { ifrt_serving::IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry = (*ifrt_model_context)->GetRestoreTensorRegistry(); + std::string runtime_name = ifrt_serving::GetRuntimeNameFromVarHandle( + variable_handler_tensor().scalar()()); + TF_RETURN_IF_ERROR(ifrt_serving::LoadRestoredTensorAsIfrtLoadedVariable( - variable_tensor(), (*ifrt_model_context)->GetClient(), + runtime_name, (*ifrt_model_context)->GetClient(), (*ifrt_model_context)->GetThreadPool(), ifrt_restore_tensor_registry, (*ifrt_model_context)->GetLoadedVariableRegistry(), (*ifrt_model_context)->checkpoint_loader_queue(), sharding_config)); - std::string runtime_name = - ifrt_serving::GetRuntimeNameFromVarHandle(variable_resource_handle()); xla::ifrt::Future> restored_tensor_future = - ifrt_restore_tensor_registry.Get(runtime_name); + ifrt_restore_tensor_registry.GetRestoredTensor(runtime_name); + restored_tensor_future.OnReady( [tensor_promise = std::move(tensor_promise)]( absl::StatusOr restored_tensor) mutable { diff --git a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc index ce3fc4a486e524..56d276fbfc0abc 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc @@ -43,6 +43,7 @@ limitations under the License. 
#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" #include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" #include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" #include "tensorflow/core/tfrt/mlrt/interpreter/builtin_kernels.h" @@ -401,10 +402,14 @@ TEST(KernelTest, IfrtLoadVariableOp) { auto input_tensor_future = xla::ifrt::Future>( input_tensor_promise); + ifrt_serving::IfrtRestoreTensorRegistry::RestoredTensorInfo + restore_tensor_info{.dtype_and_shape = {.dtype = input_tensor.dtype(), + .shape = input_tensor.shape()}, + .tensor_future = input_tensor_future}; input_tensor_promise.Set(input_tensor); TF_ASSERT_OK((*ifrt_model_context) ->GetRestoreTensorRegistry() - .TryRegister(kVariableRuntimeName, input_tensor_future)); + .TryRegister(kVariableRuntimeName, restore_tensor_info)); std::vector args; std::vector last_uses; @@ -500,10 +505,14 @@ TEST(KernelTest, DuplicateIfrtLoadVariableOpShallSucceed) { auto input_tensor_future = xla::ifrt::Future>( input_tensor_promise); + ifrt_serving::IfrtRestoreTensorRegistry::RestoredTensorInfo + restore_tensor_info{.dtype_and_shape = {.dtype = input_tensor.dtype(), + .shape = input_tensor.shape()}, + .tensor_future = input_tensor_future}; input_tensor_promise.Set(input_tensor); TF_ASSERT_OK((*ifrt_model_context) ->GetRestoreTensorRegistry() - .TryRegister(kVariableRuntimeName, input_tensor_future)); + .TryRegister(kVariableRuntimeName, restore_tensor_info)); std::vector args; std::vector last_uses; @@ -589,7 +598,7 @@ TEST(KernelTest, IfrtRestoreVariableOp) { xla::ifrt::Future> uninitialized_entry = (*ifrt_model_context) ->GetRestoreTensorRegistry() - .Get(kVariableRuntimeName); + .GetRestoredTensor(kVariableRuntimeName); ASSERT_TRUE(uninitialized_entry.IsReady()); EXPECT_THAT(uninitialized_entry.Await().status(), ::tsl::testing::StatusIs(absl::StatusCode::kNotFound)); @@ -630,7 +639,7 @@ TEST(KernelTest, IfrtRestoreVariableOp) { xla::ifrt::Future> restored_future = (*ifrt_model_context) ->GetRestoreTensorRegistry() - .Get(kVariableRuntimeName); + .GetRestoredTensor(kVariableRuntimeName); absl::StatusOr restored_tensor = restored_future.Await(); TF_ASSERT_OK(restored_tensor.status()); EXPECT_THAT(*restored_tensor, TensorEq(AsTensor({1, 2, 3}, {3}))); diff --git a/tensorflow/dtensor/cc/parallel_executor.h b/tensorflow/dtensor/cc/parallel_executor.h index 8d3570c837eba4..8234cd36e317b3 100644 --- a/tensorflow/dtensor/cc/parallel_executor.h +++ b/tensorflow/dtensor/cc/parallel_executor.h @@ -30,7 +30,7 @@ limitations under the License. namespace tensorflow { namespace dtensor { -template +template using Future = ::xla::PjRtFuture; // ParallelExecutor Interface @@ -53,7 +53,7 @@ class ParallelExecutor { // raw pointers. // The client is responsible for the ownership of the outputs. struct ExecutionResult { - Future status; + Future<> status; // The pointed data of `outputs` are filled after `status` future resolves // as ok. 
std::vector outputs; diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 327cf0da8c45d6..e9fd75d7050234 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -53,7 +53,7 @@ package_group( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:LICENSE"], - default_visibility = [":anything_but_tf"], + default_visibility = ["//visibility:public"], licenses = ["notice"], ) diff --git a/tensorflow/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc b/tensorflow/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc index fbdf23d11813cc..ade38194bbe4d1 100644 --- a/tensorflow/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc @@ -100,7 +100,7 @@ ::tensorflow::Status RemoveTrivialQuantizedActivationFunc::Run( const auto it = model->operators.begin() + op_index; auto* op = it->get(); if (op->inputs.empty()) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (IsTrivialUnfusedActivationFunc(this, *model, op->type, op->inputs[0])) { @@ -109,7 +109,7 @@ ::tensorflow::Status RemoveTrivialQuantizedActivationFunc::Run( "minmax imply at least as tight a clamp anyway.", LogName(*op)); *modified = RemoveTrivialPassthroughOp(this, model, op_index); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (IsTrivialFusedActivationFunc(this, *model, op->fused_activation_function, op->outputs[0])) { @@ -120,9 +120,9 @@ ::tensorflow::Status RemoveTrivialQuantizedActivationFunc::Run( "a clamp anyway.", LogName(*op)); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc b/tensorflow/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc index c26816ff490aa4..4e0669e3629cb0 100644 --- a/tensorflow/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc @@ -78,7 +78,7 @@ ::tensorflow::Status RemoveTrivialQuantizedMinMax::Run(Model* model, if ((op->type != OperatorType::kMinimum && op->type != OperatorType::kMaximum) || op->inputs.size() != 2) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (IsTrivialMinMax(this, *model, op->type, op->inputs[0], op->inputs[1])) { AddMessageF( @@ -86,9 +86,9 @@ ::tensorflow::Status RemoveTrivialQuantizedMinMax::Run(Model* model, "at least as tight a clamp anyway.", LogName(*op)); *modified = RemoveTrivialPassthroughOp(this, model, op_index); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/lite/toco/graph_transformations/remove_trivial_reshape.cc index 2ccefc6286a863..37e2735874d9ae 100644 --- a/tensorflow/lite/toco/graph_transformations/remove_trivial_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_trivial_reshape.cc @@ -88,19 +88,19 @@ ::tensorflow::Status RemoveTrivialReshape::Run(Model* model, const auto reshape_it = model->operators.begin() + op_index; auto* reshape_op = reshape_it->get(); if (reshape_op->type != OperatorType::kReshape) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if 
(!IsReshapeTrivial(*model, *reshape_op, this)) { AddMessageF("%s is not trivial", LogName(*reshape_op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } AddMessageF("Removing trivial %s", LogName(*reshape_op)); CHECK_EQ(reshape_op->inputs.size(), 2); *modified = RemoveTrivialPassthroughOp(this, model, op_index); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/remove_trivial_slice.cc b/tensorflow/lite/toco/graph_transformations/remove_trivial_slice.cc index 79fc79abb1d669..7b75596fac59a4 100644 --- a/tensorflow/lite/toco/graph_transformations/remove_trivial_slice.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_trivial_slice.cc @@ -55,18 +55,18 @@ ::tensorflow::Status RemoveTrivialSlice::Run(Model* model, std::size_t op_index, const auto reshape_it = model->operators.begin() + op_index; auto* slice_op = reshape_it->get(); if (slice_op->type != OperatorType::kSlice) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!IsSliceTrivial(*model, *slice_op, this)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } AddMessageF("Removing trivial %s", LogName(*slice_op)); CHECK_EQ(slice_op->inputs.size(), 3); *modified = RemoveTrivialPassthroughOp(this, model, op_index); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/lite/toco/graph_transformations/remove_unused_op.cc index b2b6fc34fb247a..58202fc4c6fe7e 100644 --- a/tensorflow/lite/toco/graph_transformations/remove_unused_op.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_unused_op.cc @@ -60,7 +60,7 @@ ::tensorflow::Status RemoveUnusedOp::Run(Model* model, std::size_t op_index, } for (const std::string& output_array : model->flags.output_arrays()) { if (output == output_array) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } for (const auto& rnn_state : model->flags.rnn_states()) { @@ -69,26 +69,26 @@ ::tensorflow::Status RemoveUnusedOp::Run(Model* model, std::size_t op_index, if (!IsDiscardableArray(*model, rnn_state.back_edge_source_array()) || !IsDiscardableArray(*model, rnn_state.state_array()) || CountOpsWithInput(*model, rnn_state.state_array())) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } } if (CountOpsWithInput(*model, output)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } if (op->unresolved_outputs) { AddMessageF("Not discarding %s because it has unresolved outputs.", LogName(*op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } AddMessageF("Discarding %s because none of its outputs is used.", LogName(*op)); DeleteOpAndArrays(model, op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/reorder_elementwise_unary.cc b/tensorflow/lite/toco/graph_transformations/reorder_elementwise_unary.cc index 12a986cc0cc904..a17e241600b39b 100644 --- a/tensorflow/lite/toco/graph_transformations/reorder_elementwise_unary.cc +++ b/tensorflow/lite/toco/graph_transformations/reorder_elementwise_unary.cc @@ -72,25 +72,25 @@ ::tensorflow::Status ReorderElementwiseUnary::Run(Model* model, const auto element_op_it = model->operators.begin() + op_index; std::unique_ptr& element_op = *element_op_it; if (!IsElementwiseOperator(element_op->type)) { - return ::tensorflow::OkStatus(); + return 
absl::OkStatus(); } const std::string intermediate_name = element_op->inputs[0]; auto it = FindOpWithOutput(*model, intermediate_name); if (it == model->operators.end()) { AddMessageF("No preceding operator"); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } std::unique_ptr& move_op = *it; if (!IsMoveOperator(move_op->type)) { AddMessageF("Preceding operator is not a move operator"); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (CountOpsWithInput(*model, intermediate_name) != 1) { AddMessageF("Input %s used elsewhere", intermediate_name); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Check that the intermediate is discardable. @@ -99,7 +99,7 @@ ::tensorflow::Status ReorderElementwiseUnary::Run(Model* model, "Cannot swap elementwise as it would invalidate %s which is " "an output array.", intermediate_name); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // op->inputs may change so we need to keep a value by copy. @@ -153,7 +153,7 @@ ::tensorflow::Status ReorderElementwiseUnary::Run(Model* model, element_op.swap(move_op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/reorder_reshape_transpose.cc b/tensorflow/lite/toco/graph_transformations/reorder_reshape_transpose.cc index 56bd1d82aa171d..9ff7e49dfb73af 100644 --- a/tensorflow/lite/toco/graph_transformations/reorder_reshape_transpose.cc +++ b/tensorflow/lite/toco/graph_transformations/reorder_reshape_transpose.cc @@ -111,30 +111,30 @@ ::tensorflow::Status ReorderReshapeTranspose::Run(Model* model, transpose_it->get(), OperatorType::kTranspose); if (transpose_op == nullptr) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) { // Wait for values to propagate. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Find the operator that produces the transpose op. auto reshape_it = FindOpWithOutput(*model, transpose_op->inputs[0]); if (reshape_it == model->operators.end()) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } TensorFlowReshapeOperator* reshape_op = ConvertOperator(reshape_it->get(), OperatorType::kReshape); if (reshape_op == nullptr) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Ignore if the reshape is uninitialized. if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Need to copy to keep static if permutated. @@ -145,7 +145,7 @@ ::tensorflow::Status ReorderReshapeTranspose::Run(Model* model, // Intermediate should not be consumed by any other operators. if (CountOpsWithInput(*model, intermediate_name) != 1) { AddMessageF("Input %s used elsewhere", intermediate_name); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Check that the intermediate is not an output array. @@ -154,7 +154,7 @@ ::tensorflow::Status ReorderReshapeTranspose::Run(Model* model, "Cannot reorder reshape-transpose as it would invalidate %s which is " "an output array.", intermediate_name); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Get the arrays. @@ -176,7 +176,7 @@ ::tensorflow::Status ReorderReshapeTranspose::Run(Model* model, // dimensions then it can be moved between the transpose. 
if (!ReshapeIsEquivalentToTranspose(*model, reshape_op, true /*allow_extra_unary_dims*/)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!IsDiscardableArray(*model, output_name)) { @@ -247,7 +247,7 @@ ::tensorflow::Status ReorderReshapeTranspose::Run(Model* model, transpose_it->swap(*reshape_it); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/lite/toco/graph_transformations/resolve_batch_normalization.cc index d031654d0f4b36..ce8ca7fd9c600a 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_batch_normalization.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_batch_normalization.cc @@ -31,7 +31,7 @@ ::tensorflow::Status ResolveBatchNormalization::Run(Model* model, *modified = false; auto bn_it = model->operators.begin() + op_index; if (bn_it->get()->type != OperatorType::kBatchNormalization) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto* bn_op = static_cast(bn_it->get()); @@ -44,7 +44,7 @@ ::tensorflow::Status ResolveBatchNormalization::Run(Model* model, // we need to exit early if these buffers don't exist yet (i.e. if the params // haven't yet been resolved as constants) and will process it once they have. if (!mean_array.buffer || !multiplier_array.buffer || !offset_array.buffer) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } CHECK(IsConstantParameterArray(*model, bn_op->inputs[1]) && @@ -141,7 +141,7 @@ ::tensorflow::Status ResolveBatchNormalization::Run(Model* model, DeleteOpAndArrays(model, bn_op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc index 03aaee0614857e..5d4dec284f9ecb 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc @@ -30,29 +30,29 @@ ::tensorflow::Status ResolveBatchToSpaceNDAttributes::Run(Model* model, *modified = false; const auto op_it = model->operators.begin() + op_index; if (op_it->get()->type != OperatorType::kBatchToSpaceND) - return ::tensorflow::OkStatus(); + return absl::OkStatus(); auto* op = static_cast(op_it->get()); // The attributes are resolved only when the 3 attributes (block_shape, // before_crops, after_crops) are all constant. if (!op->block_shape.empty()) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } CHECK_EQ(op->inputs.size(), 3); if (!IsConstantParameterArray(*model, op->inputs[1]) || !IsConstantParameterArray(*model, op->inputs[2])) - return ::tensorflow::OkStatus(); + return absl::OkStatus(); // Handle crops const auto& crops_array = model->GetArray(op->inputs[2]); - if (!crops_array.has_shape()) return ::tensorflow::OkStatus(); + if (!crops_array.has_shape()) return absl::OkStatus(); const std::vector& crops_dims = crops_array.shape().dims(); if (crops_dims.size() != 2) { // Code only handles crops of 2 dimensions. Perhaps another transformation // will delete this op. 
- return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const std::vector& crops_buffer = crops_array.GetBuffer().data; @@ -63,7 +63,7 @@ ::tensorflow::Status ResolveBatchToSpaceNDAttributes::Run(Model* model, // Handle block_shape const auto& block_shape_array = model->GetArray(op->inputs[1]); - if (!block_shape_array.has_shape()) return ::tensorflow::OkStatus(); + if (!block_shape_array.has_shape()) return absl::OkStatus(); const std::vector& block_shape_dims = block_shape_array.shape().dims(); CHECK_EQ(block_shape_dims.size(), 1); const std::vector& block_shape_buffer = @@ -73,7 +73,7 @@ ::tensorflow::Status ResolveBatchToSpaceNDAttributes::Run(Model* model, } *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_binary.cc index eabe90c6d24242..89d7cc0664d816 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_binary.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_binary.cc @@ -207,7 +207,7 @@ ::tensorflow::Status ResolveConstantBinaryOperator::Run(Model* model, binary_op->type != OperatorType::kLessEqual && binary_op->type != OperatorType::kGreater && binary_op->type != OperatorType::kGreaterEqual) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } CHECK_EQ(binary_op->inputs.size(), 2); @@ -215,13 +215,13 @@ ::tensorflow::Status ResolveConstantBinaryOperator::Run(Model* model, const auto& input1_array = model->GetArray(binary_op->inputs[1]); // Check if both inputs are constant parameters. if (!input0_array.buffer || !input1_array.buffer) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } auto& output_array = model->GetArray(binary_op->outputs[0]); // Yield until the output array dims have been resolved. if (!output_array.has_shape()) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // At the moment we don't want to care about fused activation functions. @@ -232,7 +232,7 @@ ::tensorflow::Status ResolveConstantBinaryOperator::Run(Model* model, AddMessageF( "Not resolving constant %s because it has a fused activation function", LogName(*binary_op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Check that input data types agree. 
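// --- Illustrative sketch (not part of the patch) ---------------------------
// The many ::tensorflow::OkStatus() -> absl::OkStatus() edits in these toco
// files all sit in the same GraphTransformation skeleton: Run() "yields" by
// returning an OK status without touching *modified when the op does not
// match or its inputs are not resolved yet, and only sets *modified = true
// after it has actually rewritten the model. A schematic transformation, with
// the class name and the folding step hypothetical:
#include <cstddef>

#include "absl/status/status.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"

namespace toco {

class ResolveConstantExampleOp : public GraphTransformation {
 public:
  ::tensorflow::Status Run(Model* model, std::size_t op_index,
                           bool* modified) override {
    *modified = false;
    const auto it = model->operators.begin() + op_index;
    Operator* op = it->get();

    // Yield: not the op this transformation handles.
    if (op->type != OperatorType::kNeg) return absl::OkStatus();
    // Yield: wait until the input has been resolved to a constant.
    if (!IsConstantParameterArray(*model, op->inputs[0])) {
      return absl::OkStatus();
    }

    // ... fold the constant and write the result into the output array ...

    DeleteOpAndArrays(model, op);
    *modified = true;
    return absl::OkStatus();
  }

  const char* Name() const override { return "ResolveConstantExampleOp"; }
};

}  // namespace toco
// ---------------------------------------------------------------------------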
@@ -245,12 +245,12 @@ ::tensorflow::Status ResolveConstantBinaryOperator::Run(Model* model, // Do the actual constants propagation if (!EvaluateBinaryOperatorOnConstantInputs(model, binary_op)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } DeleteOpAndArrays(model, binary_op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_concatenation.cc index a726c9f79fee55..c8ef9e09f97b1e 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_concatenation.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_concatenation.cc @@ -143,7 +143,7 @@ ::tensorflow::Status ResolveConstantConcatenation::Run(Model* model, const auto concat_it = model->operators.begin() + op_index; const auto* concat_base_op = concat_it->get(); if (concat_base_op->type != OperatorType::kConcatenation) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto* concat_op = static_cast(concat_base_op); @@ -153,15 +153,12 @@ ::tensorflow::Status ResolveConstantConcatenation::Run(Model* model, // We also make sure the shapes of the input arrays are known and they are // all discardable. const Operator* input_op = GetOpWithOutput(*model, input_name); - if (input_op) return ::tensorflow::OkStatus(); - if (!IsConstantParameterArray(*model, input_name)) - return ::tensorflow::OkStatus(); - if (!model->GetArray(input_name).has_shape()) - return ::tensorflow::OkStatus(); + if (input_op) return absl::OkStatus(); + if (!IsConstantParameterArray(*model, input_name)) return absl::OkStatus(); + if (!model->GetArray(input_name).has_shape()) return absl::OkStatus(); if (model->GetArray(input_name).quantization_params) - return ::tensorflow::OkStatus(); - if (!IsDiscardableArray(*model, input_name)) - return ::tensorflow::OkStatus(); + return absl::OkStatus(); + if (!IsDiscardableArray(*model, input_name)) return absl::OkStatus(); } const int concatenation_axis = concat_op->axis; @@ -210,7 +207,7 @@ ::tensorflow::Status ResolveConstantConcatenation::Run(Model* model, DeleteOpAndArrays(model, concat_op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_fake_quant.cc index bcb1b9f4f8402a..f40f1a964c06ce 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_fake_quant.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_fake_quant.cc @@ -66,7 +66,7 @@ ::tensorflow::Status ResolveConstantFakeQuant::Run(Model* model, const auto fakequant_it = model->operators.begin() + op_index; const auto* fakequant_base_op = fakequant_it->get(); if (fakequant_base_op->type != OperatorType::kFakeQuant) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto* fakequant_op = @@ -74,12 +74,12 @@ ::tensorflow::Status ResolveConstantFakeQuant::Run(Model* model, // Yield until the fakequant MinMax has been resolved. if (!fakequant_op->minmax) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // This transformation only applies when the input array is constant. 
if (!IsConstantParameterArray(*model, fakequant_op->inputs[0])) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto& input_array = model->GetArray(fakequant_op->inputs[0]); @@ -90,7 +90,7 @@ ::tensorflow::Status ResolveConstantFakeQuant::Run(Model* model, if (!InferQuantizedDataTypeFromFakeQuant(*fakequant_op, &quantized_data_type)) { AddMessageF("Unsupported FakeQuant num_bits=%d", fakequant_op->num_bits); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } AddMessageF("Resolving constant %s", LogName(*fakequant_op)); @@ -134,7 +134,7 @@ ::tensorflow::Status ResolveConstantFakeQuant::Run(Model* model, size); DeleteOpAndArrays(model, fakequant_op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_fill.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_fill.cc index 0b550b300647aa..13fd18d430c0b5 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_fill.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_fill.cc @@ -48,7 +48,7 @@ ::tensorflow::Status ResolveConstantFill::Run(Model* model, const auto fill_it = model->operators.begin() + op_index; auto* base_op = fill_it->get(); if (base_op->type != OperatorType::kFill) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } auto* op = static_cast(base_op); @@ -58,49 +58,49 @@ ::tensorflow::Status ResolveConstantFill::Run(Model* model, auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto& val_array = model->GetArray(op->inputs[1]); if (!val_array.has_shape()) { // Yield until the value shape has been resolved. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!IsConstantParameterArray(*model, op->inputs[1])) { // Yield until the value is constant. 
- return ::tensorflow::OkStatus(); + return absl::OkStatus(); } CHECK_EQ(RequiredBufferSizeForShape(val_array.shape()), 1); switch (output_array.data_type) { case ArrayDataType::kFloat: if (!ComputeFillArray(model, op)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } break; case ArrayDataType::kUint8: if (!ComputeFillArray(model, op)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } break; case ArrayDataType::kInt32: if (!ComputeFillArray(model, op)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } break; case ArrayDataType::kInt64: if (!ComputeFillArray(model, op)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } break; case ArrayDataType::kComplex64: if (!ComputeFillArray(model, op)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } break; default: @@ -111,7 +111,7 @@ ::tensorflow::Status ResolveConstantFill::Run(Model* model, DeleteOpAndArrays(model, op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_gather.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_gather.cc index 8a4354a5421404..961596fd6426e1 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_gather.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_gather.cc @@ -71,7 +71,7 @@ ::tensorflow::Status ResolveConstantGather::Run(Model* model, auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kGather) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto* op = static_cast(base_op); @@ -80,28 +80,28 @@ ::tensorflow::Status ResolveConstantGather::Run(Model* model, auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!op->axis) { // Yield until axis has been set by ResolveGatherAttributes. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (op->axis.value() != 0) { // Only handling axis=0 for now. AddMessageF("%s has axis %d; only axis=0 is supported", LogName(*op), op->axis.value()); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // We require constant inputs. 
if (!IsConstantParameterArray(*model, op->inputs[0]) || !IsConstantParameterArray(*model, op->inputs[1])) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const Array& input_array = model->GetArray(op->inputs[0]); const Array& coords_array = model->GetArray(op->inputs[1]); @@ -144,7 +144,7 @@ ::tensorflow::Status ResolveConstantGather::Run(Model* model, DeleteOpAndArrays(model, op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_pack.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_pack.cc index f471ed74755d16..ead1f91fea42b4 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_pack.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_pack.cc @@ -56,7 +56,7 @@ ::tensorflow::Status ResolveConstantPack::Run(Model* model, auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kPack) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto* op = static_cast(base_op); @@ -65,18 +65,18 @@ ::tensorflow::Status ResolveConstantPack::Run(Model* model, auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } for (const auto& input : op->inputs) { if (!IsConstantParameterArray(*model, input)) { // Yield if any input is mutable - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } @@ -112,7 +112,7 @@ ::tensorflow::Status ResolveConstantPack::Run(Model* model, DeleteOpAndArrays(model, op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_random_uniform.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_random_uniform.cc index c7f90964647f71..2f011182e31869 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_random_uniform.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_random_uniform.cc @@ -66,7 +66,7 @@ ::tensorflow::Status ResolveConstantRandomUniform::Run(Model* model, const auto it = model->operators.begin() + op_index; auto* base_op = it->get(); if (base_op->type != OperatorType::kRandomUniform) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } auto* op = static_cast(base_op); @@ -76,12 +76,12 @@ ::tensorflow::Status ResolveConstantRandomUniform::Run(Model* model, auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if ((op->seed == 0) && (op->seed2 == 0)) { @@ -89,13 +89,13 @@ ::tensorflow::Status ResolveConstantRandomUniform::Run(Model* model, << "\" is truly random (using /dev/random system entropy). " "Therefore, cannot resolve as constant. 
Set \"seed\" or " "\"seed2\" attr non-zero to fix this"; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } switch (output_array.data_type) { case ArrayDataType::kFloat: if (!ComputeRandomUniformArray(model, op)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } break; // For future support of double or half. @@ -109,7 +109,7 @@ ::tensorflow::Status ResolveConstantRandomUniform::Run(Model* model, DeleteOpAndArrays(model, op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_range.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_range.cc index 7afd4f6b7535a6..01701757e6c37b 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_range.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_range.cc @@ -47,7 +47,7 @@ ::tensorflow::Status ResolveConstantRange::Run(Model* model, const auto it = model->operators.begin() + op_index; auto* base_op = it->get(); if (base_op->type != OperatorType::kRange) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } auto* op = static_cast(base_op); @@ -55,23 +55,23 @@ ::tensorflow::Status ResolveConstantRange::Run(Model* model, const auto& start_array = model->GetArray(op->inputs[0]); if (!start_array.has_shape()) { // Yield until all input dims have been resolved. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto& limit_array = model->GetArray(op->inputs[1]); if (!limit_array.has_shape()) { // Yield until all input dims have been resolved. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto& delta_array = model->GetArray(op->inputs[2]); if (!delta_array.has_shape()) { // Yield until all input dims have been resolved. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } for (const auto& input : op->inputs) { if (!IsConstantParameterArray(*model, input)) { // yield if any input is mutable - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } @@ -79,7 +79,7 @@ ::tensorflow::Status ResolveConstantRange::Run(Model* model, auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } CHECK_EQ(RequiredBufferSizeForShape(start_array.shape()), 1) @@ -107,7 +107,7 @@ ::tensorflow::Status ResolveConstantRange::Run(Model* model, DeleteOpAndArrays(model, op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_reshape.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_reshape.cc index 6e1f0d7335846e..e68dcbaff6416b 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_reshape.cc @@ -29,7 +29,7 @@ ::tensorflow::Status ResolveConstantReshape::Run(Model* model, auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kReshape) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto* op = static_cast(base_op); @@ -39,17 +39,17 @@ ::tensorflow::Status ResolveConstantReshape::Run(Model* model, // We require constant inputs. 
if (!IsConstantParameterArray(*model, op->inputs[0]) || !IsConstantParameterArray(*model, op->inputs[1])) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const Array& input_array = model->GetArray(op->inputs[0]); @@ -57,7 +57,7 @@ ::tensorflow::Status ResolveConstantReshape::Run(Model* model, AddMessageF("Constant reshape is non-trivial (%s -> %s)", ShapeToString(input_array.shape()), ShapeToString(output_array.shape())); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } CHECK(!output_array.buffer); @@ -101,7 +101,7 @@ ::tensorflow::Status ResolveConstantReshape::Run(Model* model, default: LOG(FATAL) << "Unsupported data type: " << ArrayDataTypeName(input_array.data_type); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } AddMessageF("Resolving constant reshape of %s", LogName(*op)); @@ -110,7 +110,7 @@ ::tensorflow::Status ResolveConstantReshape::Run(Model* model, DeleteOpAndArrays(model, op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_select.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_select.cc index c59b98334e8170..29744b3ebc64f6 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_select.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_select.cc @@ -34,7 +34,7 @@ ::tensorflow::Status ResolveConstantSelect::Run(Model* model, auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kSelect) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto* op = static_cast(base_op); @@ -43,23 +43,23 @@ ::tensorflow::Status ResolveConstantSelect::Run(Model* model, auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // We require the cond input to be constant. if (!IsConstantParameterArray(*model, op->inputs[0])) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const Array& cond_array = model->GetArray(op->inputs[0]); CHECK(cond_array.data_type == ArrayDataType::kBool) << "Only bool conditions are supported"; const auto& cond_data = cond_array.GetBuffer().data; if (cond_data.empty()) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Check if the condition is the same for all elements. @@ -70,14 +70,14 @@ ::tensorflow::Status ResolveConstantSelect::Run(Model* model, "Cannot resolve %s as constant; cond_array has differing " "per-element values", LogName(*op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // Pass-through the selected input. *modified = RemoveTrivialPassthroughOp(this, model, op_index, cond_value ? 
1 : 2); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc index ca5de5e638c187..47768259d29b44 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc @@ -26,25 +26,25 @@ ::tensorflow::Status ResolveConstantShapeOrRank::Run(Model* model, const auto it = model->operators.begin() + op_index; const auto* op = it->get(); if (!(op->type == OperatorType::kShape || op->type == OperatorType::kRank)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } CHECK_EQ(op->outputs.size(), 1); auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been resolved - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until the input array's shape has been resolved. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!output_array.has_shape()) { // Yield until the output shape has been resolved. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Compute the output @@ -63,7 +63,7 @@ ::tensorflow::Status ResolveConstantShapeOrRank::Run(Model* model, DeleteOpAndArrays(model, op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_slice.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_slice.cc index 742987cff9f60e..29a0a60ccb2d23 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_slice.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_slice.cc @@ -93,7 +93,7 @@ ::tensorflow::Status ResolveConstantSlice::Run(Model* model, const auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kSlice) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const SliceOperator* op = static_cast(base_op); @@ -102,54 +102,54 @@ ::tensorflow::Status ResolveConstantSlice::Run(Model* model, auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (op->begin.empty() || op->size.empty()) { // Attributes have not resolved yet. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until the value shape has been resolved. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!IsConstantParameterArray(*model, op->inputs[0])) { // Yield until the value is constant. 
- return ::tensorflow::OkStatus(); + return absl::OkStatus(); } CHECK(!output_array.buffer); switch (output_array.data_type) { case ArrayDataType::kFloat: if (!Slice(*op, input_array, &output_array)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } break; case ArrayDataType::kUint8: if (!Slice(*op, input_array, &output_array)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } break; case ArrayDataType::kInt32: if (!Slice(*op, input_array, &output_array)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } break; case ArrayDataType::kInt64: if (!Slice(*op, input_array, &output_array)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } break; case ArrayDataType::kComplex64: if (!Slice(*op, input_array, &output_array)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } break; default: @@ -160,7 +160,7 @@ ::tensorflow::Status ResolveConstantSlice::Run(Model* model, DeleteOpAndArrays(model, op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_strided_slice.cc index 01c63f1f6aea27..7519b35d0f9cc6 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_strided_slice.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_strided_slice.cc @@ -105,7 +105,7 @@ ::tensorflow::Status ResolveConstantStridedSlice::Run(Model* model, const auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kStridedSlice) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const StridedSliceOperator* op = @@ -115,28 +115,28 @@ ::tensorflow::Status ResolveConstantStridedSlice::Run(Model* model, auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (op->start_indices.empty() || op->stop_indices.empty() || op->strides.empty()) { // Attributes have not resolved yet. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until the value shape has been resolved. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!IsConstantParameterArray(*model, op->inputs[0])) { // Yield until the value is constant. 
- return ::tensorflow::OkStatus(); + return absl::OkStatus(); } CHECK(!output_array.buffer); @@ -165,7 +165,7 @@ ::tensorflow::Status ResolveConstantStridedSlice::Run(Model* model, DeleteOpAndArrays(model, op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_tile.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_tile.cc index 9be148571f0910..4c798717544e26 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_tile.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_tile.cc @@ -105,7 +105,7 @@ ::tensorflow::Status ResolveConstantTile::Run(Model* model, auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kTile) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto* op = static_cast(base_op); @@ -114,17 +114,17 @@ ::tensorflow::Status ResolveConstantTile::Run(Model* model, auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // We require constant inputs. if (!IsConstantParameterArray(*model, op->inputs[0]) || !IsConstantParameterArray(*model, op->inputs[1])) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const Array& input_array = model->GetArray(op->inputs[0]); const Array& multiples_array = model->GetArray(op->inputs[1]); @@ -163,7 +163,7 @@ ::tensorflow::Status ResolveConstantTile::Run(Model* model, DeleteOpAndArrays(model, op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_transpose.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_transpose.cc index 5dc8245b4ded91..265db153e4907a 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_transpose.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_transpose.cc @@ -108,7 +108,7 @@ ::tensorflow::Status ResolveConstantTranspose::Run(Model* model, auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kTranspose) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto* op = static_cast(base_op); @@ -117,17 +117,17 @@ ::tensorflow::Status ResolveConstantTranspose::Run(Model* model, auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // We require constant inputs. 
if (!IsConstantParameterArray(*model, op->inputs[0]) || !IsConstantParameterArray(*model, op->inputs[1])) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const Array& input_array = model->GetArray(op->inputs[0]); @@ -135,7 +135,7 @@ ::tensorflow::Status ResolveConstantTranspose::Run(Model* model, if (op->perm.empty()) { // Yield until perm has been populated by ResolveTransposeAttributes. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // We currently only support 1-4 dimensions. @@ -173,7 +173,7 @@ ::tensorflow::Status ResolveConstantTranspose::Run(Model* model, DeleteOpAndArrays(model, op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_unary.cc index 7016914f3b0978..e415a12cbe6412 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_unary.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -138,28 +138,28 @@ ::tensorflow::Status ResolveConstantUnaryOperator::Run(Model* model, case OperatorType::kRelu: break; default: - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Check if the input is a constant parameter. if (!IsConstantParameterArray(*model, unary_op->inputs[0])) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // if the unary op involves a tensor required by a rnn state, ignore it for (const auto& rnn_state : model->flags.rnn_states()) { if (unary_op->inputs[0] == rnn_state.back_edge_source_array()) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (unary_op->inputs[0] == rnn_state.state_array()) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } auto& output_array = model->GetArray(unary_op->outputs[0]); if (!output_array.has_shape()) { // Yield until the output array dims have been resolved. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // At the moment we don't want to care about fused activation functions. @@ -171,7 +171,7 @@ ::tensorflow::Status ResolveConstantUnaryOperator::Run(Model* model, "Not resolving constant %s " " because it has a fused activation function", LogName(*unary_op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // The min-max is only copied for ops that copy data without arithmetic. 
@@ -193,7 +193,7 @@ ::tensorflow::Status ResolveConstantUnaryOperator::Run(Model* model, "Not resolving constant %s because we currently only support casting " "to float", LogName(*unary_op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (cast_op->src_data_type != input_array.buffer->type) { AddMessageF( @@ -203,7 +203,7 @@ ::tensorflow::Status ResolveConstantUnaryOperator::Run(Model* model, } } else { if (input_array.buffer->type != ArrayDataType::kFloat) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } input_float_data = &(input_array.GetBuffer().data); } @@ -248,7 +248,7 @@ ::tensorflow::Status ResolveConstantUnaryOperator::Run(Model* model, CHECK_EQ(unary_op->inputs.size(), 2) << "Sum needs 2 inputs"; if (!IsConstantParameterArray(*model, unary_op->inputs[1])) { AddMessageF("Axis input is non-constant"); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } auto& axis_array = model->GetArray(unary_op->inputs[1]); CHECK(axis_array.data_type == ArrayDataType::kInt32); @@ -345,7 +345,7 @@ ::tensorflow::Status ResolveConstantUnaryOperator::Run(Model* model, default: LOG(FATAL) << "Unsupported activation function " << LogName(*unary_op); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } output_float_data[i] = new_value; } @@ -355,7 +355,7 @@ ::tensorflow::Status ResolveConstantUnaryOperator::Run(Model* model, DeleteOpAndArrays(model, unary_op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc b/tensorflow/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc index eb16bc8c10a35a..b52632e4bce15d 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc @@ -32,13 +32,13 @@ ::tensorflow::Status ResolveFakeQuantArgsFromVars::Run(Model* model, const auto fakequant_it = model->operators.begin() + op_index; auto* fakequant_base_op = fakequant_it->get(); if (fakequant_base_op->type != OperatorType::kFakeQuant) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } auto* fakequant_op = static_cast(fakequant_base_op); if (fakequant_op->minmax) { // Already resolved. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } CHECK_EQ(fakequant_op->inputs.size(), 3); @@ -46,7 +46,7 @@ ::tensorflow::Status ResolveFakeQuantArgsFromVars::Run(Model* model, // resolved to constant arrays. 
for (int i = 1; i <= 2; i++) { if (!IsConstantParameterArray(*model, fakequant_op->inputs[i])) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } @@ -79,7 +79,7 @@ ::tensorflow::Status ResolveFakeQuantArgsFromVars::Run(Model* model, } fakequant_op->inputs.resize(1); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_gather_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_gather_attributes.cc index 356e7df7b82441..083ce2756d6aa0 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_gather_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_gather_attributes.cc @@ -29,19 +29,18 @@ ::tensorflow::Status ResolveGatherAttributes::Run(Model* model, bool* modified) { *modified = false; auto* gather_op = model->operators[op_index].get(); - if (gather_op->type != OperatorType::kGather) return ::tensorflow::OkStatus(); + if (gather_op->type != OperatorType::kGather) return absl::OkStatus(); auto* op = static_cast(gather_op); if (op->axis) { // Attributes already resolved - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } - if (op->inputs.size() != 3) return ::tensorflow::OkStatus(); - if (!IsConstantParameterArray(*model, op->inputs[2])) - return ::tensorflow::OkStatus(); + if (op->inputs.size() != 3) return absl::OkStatus(); + if (!IsConstantParameterArray(*model, op->inputs[2])) return absl::OkStatus(); const auto& indices_array = model->GetArray(op->inputs[2]); - if (!indices_array.has_shape()) return ::tensorflow::OkStatus(); + if (!indices_array.has_shape()) return absl::OkStatus(); const auto& axis_data = indices_array.GetBuffer().data; CHECK_EQ(axis_data.size(), 1) << "Multidimensional gather not supported on " << LogName(*op); @@ -52,7 +51,7 @@ ::tensorflow::Status ResolveGatherAttributes::Run(Model* model, op->inputs.resize(2); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_multiply_by_zero.cc b/tensorflow/lite/toco/graph_transformations/resolve_multiply_by_zero.cc index fc8a669758ba00..5c45ec4a7b360a 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_multiply_by_zero.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_multiply_by_zero.cc @@ -58,23 +58,23 @@ ::tensorflow::Status ResolveMultiplyByZero::Run(Model* model, const auto mul_it = model->operators.begin() + op_index; auto* mul_op = mul_it->get(); if (mul_op->type != OperatorType::kMul) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto& output_array_name = mul_op->outputs[0]; auto& output_array = model->GetArray(output_array_name); if (!IsDiscardableArray(*model, output_array_name)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Yield if the output shape is not known yet. if (!output_array.has_shape()) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // This transformation only handles the case where one operand is all 0's and @@ -86,12 +86,12 @@ ::tensorflow::Status ResolveMultiplyByZero::Run(Model* model, }; if (!is_input_constant[0] && !is_input_constant[1]) { // Neither input is constant, so nothing we can resolve here. 
- return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (is_input_constant[0] && is_input_constant[1]) { // Both inputs are constants. That's a job for constants propagation, not // for us to handle here. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const int index_of_constant_input = is_input_constant[0] ? 0 : 1; const int index_of_variable_input = is_input_constant[0] ? 1 : 0; @@ -108,7 +108,7 @@ ::tensorflow::Status ResolveMultiplyByZero::Run(Model* model, constant_input_array.GetBuffer().data; if (!AreAllBufferElementsZero>( constant_input_data)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } FillArrayWithZeros(&output_array); } break; @@ -117,7 +117,7 @@ ::tensorflow::Status ResolveMultiplyByZero::Run(Model* model, constant_input_array.GetBuffer().data; if (!AreAllBufferElementsZero>( constant_input_data)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } FillArrayWithZeros(&output_array); } break; @@ -126,7 +126,7 @@ ::tensorflow::Status ResolveMultiplyByZero::Run(Model* model, constant_input_array.GetBuffer().data; if (!AreAllBufferElementsZero>( constant_input_data)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } FillArrayWithZeros(&output_array); } break; @@ -135,7 +135,7 @@ ::tensorflow::Status ResolveMultiplyByZero::Run(Model* model, constant_input_array.GetBuffer().data; if (!AreAllBufferElementsZero>( constant_input_data)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } FillArrayWithZeros(&output_array); } break; @@ -144,19 +144,19 @@ ::tensorflow::Status ResolveMultiplyByZero::Run(Model* model, constant_input_array.GetBuffer().data; if (!AreAllBufferElementsZero>( constant_input_data)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } FillArrayWithZeros(&output_array); } break; default: AddMessageF( "Cannot resolve multiply by 0 because of unsupported data type\n"); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } DeleteOpAndArrays(model, mul_op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_pad_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_pad_attributes.cc index a2b3d2f0a5cd27..2ba9566ac02287 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_pad_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_pad_attributes.cc @@ -30,17 +30,16 @@ ::tensorflow::Status ResolvePadAttributes::Run(Model* model, *modified = false; const auto pad_it = model->operators.begin() + op_index; auto* pad_op = pad_it->get(); - if (pad_op->type != OperatorType::kPad) return ::tensorflow::OkStatus(); + if (pad_op->type != OperatorType::kPad) return absl::OkStatus(); auto* op = static_cast(pad_op); - if (!op->left_padding.empty()) return ::tensorflow::OkStatus(); + if (!op->left_padding.empty()) return absl::OkStatus(); CHECK_EQ(op->inputs.size(), 2); - if (!IsConstantParameterArray(*model, op->inputs[1])) - return ::tensorflow::OkStatus(); + if (!IsConstantParameterArray(*model, op->inputs[1])) return absl::OkStatus(); const auto& array = model->GetArray(op->inputs[1]); - if (!array.has_shape()) return ::tensorflow::OkStatus(); + if (!array.has_shape()) return absl::OkStatus(); const std::vector& dims = array.shape().dims(); CHECK_EQ(dims.size(), 2); @@ -55,6 +54,6 @@ ::tensorflow::Status ResolvePadAttributes::Run(Model* model, // TODO(dkalenichenko): Delete the extra input? 
*modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_padv2_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_padv2_attributes.cc index c0990c3a682353..8f67c01864396f 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_padv2_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_padv2_attributes.cc @@ -30,17 +30,16 @@ ::tensorflow::Status ResolvePadV2Attributes::Run(Model* model, *modified = false; const auto pad_it = model->operators.begin() + op_index; auto* pad_op = pad_it->get(); - if (pad_op->type != OperatorType::kPadV2) return ::tensorflow::OkStatus(); + if (pad_op->type != OperatorType::kPadV2) return absl::OkStatus(); auto* op = static_cast(pad_op); - if (!op->left_padding.empty()) return ::tensorflow::OkStatus(); + if (!op->left_padding.empty()) return absl::OkStatus(); CHECK_EQ(op->inputs.size(), 3); - if (!IsConstantParameterArray(*model, op->inputs[1])) - return ::tensorflow::OkStatus(); + if (!IsConstantParameterArray(*model, op->inputs[1])) return absl::OkStatus(); const auto& array = model->GetArray(op->inputs[1]); - if (!array.has_shape()) return ::tensorflow::OkStatus(); + if (!array.has_shape()) return absl::OkStatus(); const std::vector& dims = array.shape().dims(); CHECK_EQ(dims.size(), 2); @@ -55,6 +54,6 @@ ::tensorflow::Status ResolvePadV2Attributes::Run(Model* model, // TODO(dkalenichenko): Delete the extra input? *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_reduce_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_reduce_attributes.cc index 1fde09a8466b52..0d4ffd31220b72 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_reduce_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_reduce_attributes.cc @@ -52,29 +52,29 @@ ::tensorflow::Status ResolveReduceAttributes::Run(Model* model, switch (op->type) { case OperatorType::kMean: *modified = ResolveAttributes(model, static_cast(op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); case OperatorType::kSum: *modified = ResolveAttributes(model, static_cast(op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); case OperatorType::kReduceProd: *modified = ResolveAttributes(model, static_cast(op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); case OperatorType::kReduceMin: *modified = ResolveAttributes(model, static_cast(op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); case OperatorType::kReduceMax: *modified = ResolveAttributes(model, static_cast(op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); case OperatorType::kAny: *modified = ResolveAttributes(model, static_cast(op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); default: - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } diff --git a/tensorflow/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/lite/toco/graph_transformations/resolve_reorder_axes.cc index 24a41b6f40228d..bb86b366d37a58 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_reorder_axes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_reorder_axes.cc @@ -85,7 +85,7 @@ ::tensorflow::Status ResolveReorderAxes::Run(Model* model, std::size_t op_index, auto it = model->operators.begin() + op_index; auto* op = it->get(); if (op->type != 
OperatorType::kReorderAxes) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } auto* reorder_op = static_cast(op); @@ -96,11 +96,11 @@ ::tensorflow::Status ResolveReorderAxes::Run(Model* model, std::size_t op_index, auto& input_array = model->GetArray(input_array_name); auto& output_array = model->GetArray(output_array_name); if (!input_array.buffer) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Yield until output dims have been resolved. if (!output_array.has_shape()) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Reorder the input array dims and buffer data if (input_array.buffer->type == ArrayDataType::kFloat) { @@ -124,7 +124,7 @@ ::tensorflow::Status ResolveReorderAxes::Run(Model* model, std::size_t op_index, RenameArray(model, output_array_name, input_array_name); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_reshape_attributes.cc index b361dd1b4af7f9..e4d4c0db3d96af 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_reshape_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_reshape_attributes.cc @@ -32,22 +32,22 @@ ::tensorflow::Status ResolveReshapeAttributes::Run(Model* model, const auto reshape_it = model->operators.begin() + op_index; auto* reshape_op = reshape_it->get(); if (reshape_op->type != OperatorType::kReshape) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } auto* op = static_cast(reshape_op); - if (!op->shape.empty()) return ::tensorflow::OkStatus(); + if (!op->shape.empty()) return absl::OkStatus(); if (IsConstantParameterArray(*model, reshape_op->inputs[1])) { const auto& constant_input_array = model->GetArray(reshape_op->inputs[1]); op->shape = constant_input_array.GetBuffer().data; } - if (op->shape.empty()) return ::tensorflow::OkStatus(); + if (op->shape.empty()) return absl::OkStatus(); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_slice_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_slice_attributes.cc index 390edddcd770fe..80940815d55524 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_slice_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_slice_attributes.cc @@ -30,22 +30,20 @@ ::tensorflow::Status ResolveSliceAttributes::Run(Model* model, *modified = false; const auto slice_it = model->operators.begin() + op_index; auto* slice_op = slice_it->get(); - if (slice_op->type != OperatorType::kSlice) return ::tensorflow::OkStatus(); + if (slice_op->type != OperatorType::kSlice) return absl::OkStatus(); auto* op = static_cast(slice_op); - if (!op->begin.empty()) return ::tensorflow::OkStatus(); + if (!op->begin.empty()) return absl::OkStatus(); CHECK_EQ(op->inputs.size(), 3); - if (!IsConstantParameterArray(*model, op->inputs[1])) - return ::tensorflow::OkStatus(); - if (!IsConstantParameterArray(*model, op->inputs[2])) - return ::tensorflow::OkStatus(); + if (!IsConstantParameterArray(*model, op->inputs[1])) return absl::OkStatus(); + if (!IsConstantParameterArray(*model, op->inputs[2])) return absl::OkStatus(); const auto& begin_array = model->GetArray(op->inputs[1]); - if (!begin_array.has_shape()) return ::tensorflow::OkStatus(); + if (!begin_array.has_shape()) return absl::OkStatus(); 
const auto& size_array = model->GetArray(op->inputs[2]); - if (!size_array.has_shape()) return ::tensorflow::OkStatus(); + if (!size_array.has_shape()) return absl::OkStatus(); op->begin = begin_array.GetBuffer().data; op->size = size_array.GetBuffer().data; @@ -53,6 +51,6 @@ ::tensorflow::Status ResolveSliceAttributes::Run(Model* model, // TODO(dkalenichenko): Delete the extra inputs? *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc index 1538f58b151ad0..c97d7e30bd55f6 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc @@ -30,14 +30,14 @@ ::tensorflow::Status ResolveSpaceToBatchNDAttributes::Run(Model* model, *modified = false; const auto op_it = model->operators.begin() + op_index; if (op_it->get()->type != OperatorType::kSpaceToBatchND) - return ::tensorflow::OkStatus(); + return absl::OkStatus(); auto* op = static_cast(op_it->get()); // The attributes are resolved only when the 3 attributes (block_shape, // before_paddings, after_paddings) are all constant. if (!op->block_shape.empty()) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const int block_shape_index = 1; @@ -46,16 +46,16 @@ ::tensorflow::Status ResolveSpaceToBatchNDAttributes::Run(Model* model, CHECK_EQ(op->inputs.size(), 3); if (!IsConstantParameterArray(*model, op->inputs[block_shape_index]) || !IsConstantParameterArray(*model, op->inputs[paddings_index])) - return ::tensorflow::OkStatus(); + return absl::OkStatus(); // Handle paddings. const auto& paddings_array = model->GetArray(op->inputs[paddings_index]); - if (!paddings_array.has_shape()) return ::tensorflow::OkStatus(); + if (!paddings_array.has_shape()) return absl::OkStatus(); const std::vector& paddings_dims = paddings_array.shape().dims(); if (paddings_dims.size() != 2) { // Code only handles padding of 2 dimensions. Perhaps another transformation // will delete this op. - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const std::vector& paddings_buffer = paddings_array.GetBuffer().data; @@ -67,7 +67,7 @@ ::tensorflow::Status ResolveSpaceToBatchNDAttributes::Run(Model* model, // Handle block_shape. 
const auto& block_shape_array = model->GetArray(op->inputs[block_shape_index]); - if (!block_shape_array.has_shape()) return ::tensorflow::OkStatus(); + if (!block_shape_array.has_shape()) return absl::OkStatus(); const std::vector& block_shape_dims = block_shape_array.shape().dims(); CHECK_EQ(block_shape_dims.size(), 1); const std::vector& block_shape_buffer = @@ -77,7 +77,7 @@ ::tensorflow::Status ResolveSpaceToBatchNDAttributes::Run(Model* model, } *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_squeeze_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_squeeze_attributes.cc index dc6a9ed8eef4f4..9acb34861c6535 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_squeeze_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_squeeze_attributes.cc @@ -31,7 +31,7 @@ ::tensorflow::Status ResolveSqueezeAttributes::Run(Model* model, *modified = false; auto* squeeze_op = model->operators[op_index].get(); if (squeeze_op->type != OperatorType::kSqueeze) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } DCHECK_EQ(squeeze_op->inputs.size(), 1); DCHECK_EQ(squeeze_op->outputs.size(), 1); @@ -46,10 +46,10 @@ ::tensorflow::Status ResolveSqueezeAttributes::Run(Model* model, LogName(*squeeze_op)); *modified = RemoveTrivialPassthroughOp(this, model, op_index); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc index 7aecc605686d36..0d6334fe8e7af9 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc @@ -43,41 +43,37 @@ ::tensorflow::Status ResolveStridedSliceAttributes::Run(Model* model, *modified = false; const auto slice_it = model->operators.begin() + op_index; auto* slice_op = slice_it->get(); - if (slice_op->type != OperatorType::kStridedSlice) - return ::tensorflow::OkStatus(); + if (slice_op->type != OperatorType::kStridedSlice) return absl::OkStatus(); auto* op = static_cast(slice_op); if (!op->start_indices.empty()) { // We have already resolved these attributes - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } CHECK_EQ(op->inputs.size(), 4); const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // We require the dimensionality of the input to pad the indices - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } auto& start_array = model->GetArray(op->inputs[1]); - if (!start_array.has_shape()) return ::tensorflow::OkStatus(); + if (!start_array.has_shape()) return absl::OkStatus(); if (toco::RequiredBufferSizeForShape(start_array.shape()) > 4) { // Only 1-4D arrays are supported for now. 
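Note that these functions keep their declared return type of ::tensorflow::Status while the bodies now return absl::OkStatus(). That only compiles because the two types are interchangeable; the sketch below makes that assumption explicit with a local alias (in TensorFlow this goes through the tsl::Status alias of absl::Status).

#include "absl/status/status.h"

// Assumption for this sketch: the project-level Status type is an alias of
// absl::Status, so returning absl::OkStatus() from a function declared with
// the old spelling changes no behaviour.
namespace tf_like {
using Status = absl::Status;
}  // namespace tf_like

tf_like::Status Run() { return absl::OkStatus(); }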
- return ::tensorflow::OkStatus(); + return absl::OkStatus(); } auto& stop_array = model->GetArray(op->inputs[2]); - if (!stop_array.has_shape()) return ::tensorflow::OkStatus(); + if (!stop_array.has_shape()) return absl::OkStatus(); auto& stride_array = model->GetArray(op->inputs[3]); - if (!stride_array.has_shape()) return ::tensorflow::OkStatus(); + if (!stride_array.has_shape()) return absl::OkStatus(); - if (!IsConstantParameterArray(*model, op->inputs[1])) - return ::tensorflow::OkStatus(); - if (!IsConstantParameterArray(*model, op->inputs[2])) - return ::tensorflow::OkStatus(); - if (!IsConstantParameterArray(*model, op->inputs[3])) - return ::tensorflow::OkStatus(); + if (!IsConstantParameterArray(*model, op->inputs[1])) return absl::OkStatus(); + if (!IsConstantParameterArray(*model, op->inputs[2])) return absl::OkStatus(); + if (!IsConstantParameterArray(*model, op->inputs[3])) return absl::OkStatus(); int num_input_axes = input_array.shape().dimensions_count(); int start_indices_size = start_array.shape().dims(0); @@ -120,6 +116,6 @@ ::tensorflow::Status ResolveStridedSliceAttributes::Run(Model* model, op->strides = stride_array.GetBuffer().data; *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_concat.cc index 9939381517ed2f..d20c499355a34c 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_concat.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_concat.cc @@ -33,7 +33,7 @@ ::tensorflow::Status ResolveTensorFlowConcat::Run(Model* model, const auto* tf_concat_op = concat_it->get(); if (tf_concat_op->type != OperatorType::kConcat && tf_concat_op->type != OperatorType::kConcatV2) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } CHECK_GE(tf_concat_op->inputs.size(), 2); @@ -57,7 +57,7 @@ ::tensorflow::Status ResolveTensorFlowConcat::Run(Model* model, if (!axis_array.buffer) { AddMessageF("Waiting for the axis of %s to be resolved to a constant", LogName(*tf_concat_op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } CHECK(axis_array.data_type == ArrayDataType::kInt32); @@ -75,7 +75,7 @@ ::tensorflow::Status ResolveTensorFlowConcat::Run(Model* model, DeleteOpAndArrays(model, tf_concat_op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc index 6c347080790e6f..396e71ae286e74 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc @@ -62,7 +62,7 @@ ::tensorflow::Status ResolveTensorFlowMatMul::Run(Model* model, *modified = false; auto matmul_it = model->operators.begin() + op_index; if (matmul_it->get()->type != OperatorType::kMatMul) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto* matmul_op = static_cast(matmul_it->get()); @@ -87,7 +87,7 @@ ::tensorflow::Status ResolveTensorFlowMatMul::Run(Model* model, "Not replacing %s by a FullyConnected operator, because it has " "the transpose_a attribute and LHS has no shape", LogName(*matmul_op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } int dimensions_count = lhs_array.shape().dimensions_count(); @@ 
-228,7 +228,7 @@ ::tensorflow::Status ResolveTensorFlowMatMul::Run(Model* model, // erase the MatMul operator model->operators.erase(matmul_it); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_merge.cc index 1569427fa5c12f..2fb79f5c0008b0 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_merge.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_merge.cc @@ -31,7 +31,7 @@ ::tensorflow::Status ResolveTensorFlowMerge::Run(Model* model, const auto merge_it = model->operators.begin() + op_index; const auto* merge_op = merge_it->get(); if (merge_op->type != OperatorType::kMerge) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // We need to yield until this Merge node has only 1 input, which will mean @@ -40,7 +40,7 @@ ::tensorflow::Status ResolveTensorFlowMerge::Run(Model* model, // non-selected inputs, so that at some point there will be only 1 input left. if (merge_op->inputs.size() > 1) { AddMessageF("Waiting for %s to be resolved", LogName(*merge_op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Now that the merge node has 1 input exactly, it is the same as an Identity @@ -58,7 +58,7 @@ ::tensorflow::Status ResolveTensorFlowMerge::Run(Model* model, DeleteOpAndArrays(model, merge_op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_switch.cc index 390a3a76dddef0..9aefb8799fff1b 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_switch.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_switch.cc @@ -32,7 +32,7 @@ ::tensorflow::Status ResolveTensorFlowSwitch::Run(Model* model, const auto switch_it = model->operators.begin() + op_index; const auto* switch_op = switch_it->get(); if (switch_op->type != OperatorType::kSwitch) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } CHECK_EQ(switch_op->inputs.size(), 2); @@ -44,7 +44,7 @@ ::tensorflow::Status ResolveTensorFlowSwitch::Run(Model* model, AddMessageF( "Waiting for the boolean predicate of %s to be resolved to a constant", LogName(*switch_op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // The predicate should be boolean, and should consist of a single value. 
@@ -132,7 +132,7 @@ ::tensorflow::Status ResolveTensorFlowSwitch::Run(Model* model, AddMessageF("Removing already-resolved %s", LogName(*switch_op)); DeleteOpAndArrays(model, switch_op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/resolve_transpose_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_transpose_attributes.cc index 258156244b8006..d10cf3b143e96f 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_transpose_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_transpose_attributes.cc @@ -29,19 +29,17 @@ ::tensorflow::Status ResolveTransposeAttributes::Run(Model* model, bool* modified) { *modified = false; const auto op_it = model->operators.begin() + op_index; - if (op_it->get()->type != OperatorType::kTranspose) - return ::tensorflow::OkStatus(); + if (op_it->get()->type != OperatorType::kTranspose) return absl::OkStatus(); auto* op = static_cast(op_it->get()); - if (!op->perm.empty()) return ::tensorflow::OkStatus(); + if (!op->perm.empty()) return absl::OkStatus(); CHECK_EQ(op->inputs.size(), 2); - if (!IsConstantParameterArray(*model, op->inputs[1])) - return ::tensorflow::OkStatus(); + if (!IsConstantParameterArray(*model, op->inputs[1])) return absl::OkStatus(); // Handling perm. const auto& perm_array = model->GetArray(op->inputs[1]); - if (!perm_array.has_shape()) return ::tensorflow::OkStatus(); + if (!perm_array.has_shape()) return absl::OkStatus(); const std::vector& perm_dims = perm_array.shape().dims(); CHECK_EQ(perm_dims.size(), 1); @@ -53,7 +51,7 @@ ::tensorflow::Status ResolveTransposeAttributes::Run(Model* model, } *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/shuffle_fc_weights.cc b/tensorflow/lite/toco/graph_transformations/shuffle_fc_weights.cc index 253cb0f77c5759..1f758076772044 100644 --- a/tensorflow/lite/toco/graph_transformations/shuffle_fc_weights.cc +++ b/tensorflow/lite/toco/graph_transformations/shuffle_fc_weights.cc @@ -30,12 +30,12 @@ ::tensorflow::Status ShuffleFCWeights::Run(Model* model, std::size_t op_index, *modified = false; Operator* op = model->operators[op_index].get(); if (op->type != OperatorType::kFullyConnected) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } FullyConnectedOperator* fc_op = static_cast(op); // Exit if this FC op already has shuffled weights if (fc_op->weights_format != FullyConnectedWeightsFormat::kDefault) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const Array& input_array = model->GetArray(fc_op->inputs[0]); const std::string& weights_name = fc_op->inputs[1]; @@ -49,11 +49,11 @@ ::tensorflow::Status ShuffleFCWeights::Run(Model* model, std::size_t op_index, output_array.data_type != ArrayDataType::kInt16 || !input_array.quantization_params || !weights_array.quantization_params || !output_array.quantization_params) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Exit if the shapes aren't known if (!input_array.has_shape() || !weights_array.has_shape()) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Exit if, based on the known shapes, this FC op is not a GEMV. // The shuffling of FC weights is only useful to enable fast GEMV paths. 
@@ -67,7 +67,7 @@ ::tensorflow::Status ShuffleFCWeights::Run(Model* model, std::size_t op_index, "the input shape is not 1D or 2D (possibly with additional inner " "dimensions of size 1)", LogName(*op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } if (input_shape.dims(0) != 1 && input_shape.dims(0) != 4) { @@ -76,7 +76,7 @@ ::tensorflow::Status ShuffleFCWeights::Run(Model* model, std::size_t op_index, "the input shape's leading dimension, i.e. the 'batch size', is not " "equal to 1 or 4", LogName(*op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Exit if the weights shape isn't an integral multiple of the shuffled // block shape, 4x16. We don't want to have to write code dealing with @@ -91,7 +91,7 @@ ::tensorflow::Status ShuffleFCWeights::Run(Model* model, std::size_t op_index, // two. const Shape& weights_shape = weights_array.shape(); if (weights_shape.dimensions_count() != 2) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const int rows = weights_shape.dims(0); const int cols = weights_shape.dims(1); @@ -100,11 +100,11 @@ ::tensorflow::Status ShuffleFCWeights::Run(Model* model, std::size_t op_index, "Not applying experimental shuffling to the weights of %s because its " "shape isn't a multiple of the shuffling block shape, 4x16", LogName(*op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Exit if the weights aren't already a constant array. if (!weights_array.buffer) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Exit if the weights are used by more than one op. if (CountOpsWithInput(*model, weights_name) != 1) { @@ -112,7 +112,7 @@ ::tensorflow::Status ShuffleFCWeights::Run(Model* model, std::size_t op_index, "Not applying experimental shuffling to the weights of %s because that " "array is consumed by other operators", LogName(*op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // Compute the shuffled weights auto& weights_data = @@ -156,7 +156,7 @@ ::tensorflow::Status ShuffleFCWeights::Run(Model* model, std::size_t op_index, input_array.GetQuantizationParams(); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/unfuse_activation_functions.cc b/tensorflow/lite/toco/graph_transformations/unfuse_activation_functions.cc index ef4a92b1d46bf7..52f9ea27171091 100644 --- a/tensorflow/lite/toco/graph_transformations/unfuse_activation_functions.cc +++ b/tensorflow/lite/toco/graph_transformations/unfuse_activation_functions.cc @@ -34,7 +34,7 @@ ::tensorflow::Status UnfuseActivationFunctions::Run(Model* model, // If a conv operation has an im2col array, yield: it should be dropped first. if ((op->type == OperatorType::kConv) && (op->outputs.size() == 2)) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } Operator* ac_op = nullptr; @@ -49,7 +49,7 @@ ::tensorflow::Status UnfuseActivationFunctions::Run(Model* model, ac_op = new Relu1Operator; break; default: - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } // At this point we know that the op has a fused activation function. 
At the @@ -78,7 +78,7 @@ ::tensorflow::Status UnfuseActivationFunctions::Run(Model* model, ac_op->inputs = {tmp_array_name}; op->outputs = {tmp_array_name}; *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/unpartition_embedding_lookup.cc b/tensorflow/lite/toco/graph_transformations/unpartition_embedding_lookup.cc index f030d8fdbd77ca..374a98605c82e5 100644 --- a/tensorflow/lite/toco/graph_transformations/unpartition_embedding_lookup.cc +++ b/tensorflow/lite/toco/graph_transformations/unpartition_embedding_lookup.cc @@ -50,7 +50,7 @@ ::tensorflow::Status UnpartitionEmbeddingLookup::Run(Model* model, // First look for the final DynamicStitch. auto op_it = model->operators.begin() + op_index; if (op_it->get()->type != OperatorType::kDynamicStitch) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } auto* stitch_op = static_cast(op_it->get()); @@ -77,7 +77,7 @@ ::tensorflow::Status UnpartitionEmbeddingLookup::Run(Model* model, "Skipping because indices input %s into " "%s is unexpected", LogName(*op), LogName(*stitch_op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!indices_partition_op) { indices_partition_op = static_cast(op); @@ -88,7 +88,7 @@ ::tensorflow::Status UnpartitionEmbeddingLookup::Run(Model* model, "Skipping because indices input %s into " "%s is from a different source op than others", LogName(*op), LogName(*stitch_op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } } @@ -97,12 +97,12 @@ ::tensorflow::Status UnpartitionEmbeddingLookup::Run(Model* model, // The data for the indices must be a constant range of the array shape. if (!IsConstantParameterArray(*model, indices_partition_op->inputs[0])) { AddMessageF("Skipping because indices partition data is non-constant"); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } auto& indices_data_array = model->GetArray(indices_partition_op->inputs[0]); if (indices_data_array.data_type == ArrayDataType::kNone) { // Yield until data types are propagated. 
- return ::tensorflow::OkStatus(); + return absl::OkStatus(); } CHECK(indices_data_array.data_type == ArrayDataType::kInt32) << "Indices partition inputs must be int32"; @@ -122,7 +122,7 @@ ::tensorflow::Status UnpartitionEmbeddingLookup::Run(Model* model, "Skipping because data input %s into %s " "is unexpected", LogName(*op), LogName(*stitch_op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } gather_ops.push_back(static_cast(op)); } @@ -137,7 +137,7 @@ ::tensorflow::Status UnpartitionEmbeddingLookup::Run(Model* model, "Skipping because data input %s into " "%s is unexpected", LogName(*op), LogName(*gather_op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (!data_partition_op) { data_partition_op = static_cast(op); @@ -148,7 +148,7 @@ ::tensorflow::Status UnpartitionEmbeddingLookup::Run(Model* model, "Skipping because data input %s into " "%s is from a different source op than others", LogName(*op), LogName(*gather_op)); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } } @@ -242,7 +242,7 @@ ::tensorflow::Status UnpartitionEmbeddingLookup::Run(Model* model, DeleteOpAndArrays(model, data_partition_op); DeleteOpAndArrays(model, stitch_op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc b/tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc index 4bc4c19cbfaf43..d241e668ded88a 100644 --- a/tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc +++ b/tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc @@ -138,7 +138,7 @@ ::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index, *modified = false; auto batch_op_it = model->operators.begin() + op_index; if (batch_op_it->get()->type != OperatorType::kBatchMatMul) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } const auto* batch_op = static_cast(batch_op_it->get()); @@ -149,7 +149,7 @@ ::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index, const auto& input_lhs_array = model->GetArray(input_lhs); const auto& input_rhs_array = model->GetArray(input_rhs); if (!input_lhs_array.has_shape() || !input_rhs_array.has_shape()) - return ::tensorflow::OkStatus(); + return absl::OkStatus(); // Transpose LHS input if necessary. 
if (batch_op->adj_x) { @@ -194,7 +194,7 @@ ::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index, model->operators.emplace(tail_it, matmul_op); DeleteOpAndArrays(model, batch_op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op), bcast.output_batch_size()); @@ -262,7 +262,7 @@ ::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index, DeleteOpAndArrays(model, batch_op); *modified = true; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace toco diff --git a/tensorflow/lite/tools/build_aar_with_docker.sh b/tensorflow/lite/tools/build_aar_with_docker.sh index 98b091422c0e1b..e815a3dfd06064 100755 --- a/tensorflow/lite/tools/build_aar_with_docker.sh +++ b/tensorflow/lite/tools/build_aar_with_docker.sh @@ -113,7 +113,7 @@ else 'N' 'N' 'Y' - '/usr/lib/llvm-17/bin/clang' + '/usr/lib/llvm-18/bin/clang' '-Wno-sign-compare -Wno-c++20-designator -Wno-gnu-inline-cpp-without-extern' 'y' '/android/sdk' diff --git a/tensorflow/lite/tools/command_line_flags_test.cc b/tensorflow/lite/tools/command_line_flags_test.cc index 8dc1a295ef0e4d..f25279468d3657 100644 --- a/tensorflow/lite/tools/command_line_flags_test.cc +++ b/tensorflow/lite/tools/command_line_flags_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include -#include #include #include "tensorflow/lite/tools/tool_params.h" diff --git a/tensorflow/lite/tools/signature/signature_def_util.cc b/tensorflow/lite/tools/signature/signature_def_util.cc index 9d36d2d08e90fe..f3d5657edc64d6 100644 --- a/tensorflow/lite/tools/signature/signature_def_util.cc +++ b/tensorflow/lite/tools/signature/signature_def_util.cc @@ -62,7 +62,7 @@ Status ReadSignatureDefMap(const Model* model, const Metadata* metadata, const std::string key = signature_defs.Keys()[i].AsString().c_str(); (*map)[key] = signature_defs[key].AsString().c_str(); } - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace @@ -105,7 +105,7 @@ Status SetSignatureDefMap(const Model* model, *model_data_with_signature_def = std::string(reinterpret_cast(builder.GetBufferPointer()), builder.GetSize()); - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } bool HasSignatureDef(const Model* model, const std::string& signature_key) { @@ -118,7 +118,7 @@ bool HasSignatureDef(const Model* model, const std::string& signature_key) { } SerializedSignatureDefMap signature_defs; if (ReadSignatureDefMap(model, metadata, &signature_defs) != - ::tensorflow::OkStatus()) { + absl::OkStatus()) { return false; } return (signature_defs.find(signature_key) != signature_defs.end()); @@ -134,7 +134,7 @@ Status GetSignatureDefMap(const Model* model, if (metadata) { SerializedSignatureDefMap signature_defs; auto status = ReadSignatureDefMap(model, metadata, &signature_defs); - if (status != ::tensorflow::OkStatus()) { + if (status != absl::OkStatus()) { return tensorflow::errors::Internal("Error reading signature def map: ", status.message()); } @@ -148,7 +148,7 @@ Status GetSignatureDefMap(const Model* model, } *signature_def_map = retrieved_signature_def_map; } - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } Status ClearSignatureDefMap(const Model* model, std::string* model_data) { @@ -172,7 +172,7 @@ Status ClearSignatureDefMap(const Model* model, std::string* model_data) { *model_data = std::string(reinterpret_cast(builder.GetBufferPointer()), builder.GetSize()); - return ::tensorflow::OkStatus(); 
+ return absl::OkStatus(); } } // namespace tflite diff --git a/tensorflow/lite/tools/signature/signature_def_util_test.cc b/tensorflow/lite/tools/signature/signature_def_util_test.cc index 51da756059c550..38b064707a48a5 100644 --- a/tensorflow/lite/tools/signature/signature_def_util_test.cc +++ b/tensorflow/lite/tools/signature/signature_def_util_test.cc @@ -69,12 +69,12 @@ TEST_F(SimpleSignatureDefUtilTest, SetSignatureDefTest) { const std::map expected_signature_def_map = { {kDefaultServingSignatureDefKey, expected_signature_def}}; EXPECT_EQ( - ::tensorflow::OkStatus(), + absl::OkStatus(), SetSignatureDefMap(model_, expected_signature_def_map, &model_output)); const Model* add_model = flatbuffers::GetRoot(model_output.data()); EXPECT_TRUE(HasSignatureDef(add_model, kDefaultServingSignatureDefKey)); std::map test_signature_def_map; - EXPECT_EQ(::tensorflow::OkStatus(), + EXPECT_EQ(absl::OkStatus(), GetSignatureDefMap(add_model, &test_signature_def_map)); SignatureDef test_signature_def = test_signature_def_map[kDefaultServingSignatureDefKey]; @@ -88,12 +88,12 @@ TEST_F(SimpleSignatureDefUtilTest, OverwriteSignatureDefTest) { std::map expected_signature_def_map = { {kDefaultServingSignatureDefKey, expected_signature_def}}; EXPECT_EQ( - ::tensorflow::OkStatus(), + absl::OkStatus(), SetSignatureDefMap(model_, expected_signature_def_map, &model_output)); const Model* add_model = flatbuffers::GetRoot(model_output.data()); EXPECT_TRUE(HasSignatureDef(add_model, kDefaultServingSignatureDefKey)); std::map test_signature_def_map; - EXPECT_EQ(::tensorflow::OkStatus(), + EXPECT_EQ(absl::OkStatus(), GetSignatureDefMap(add_model, &test_signature_def_map)); SignatureDef test_signature_def = test_signature_def_map[kDefaultServingSignatureDefKey]; @@ -105,16 +105,16 @@ TEST_F(SimpleSignatureDefUtilTest, OverwriteSignatureDefTest) { constexpr char kTestSignatureDefKey[] = "ServingTest"; expected_signature_def_map[kTestSignatureDefKey] = expected_signature_def; EXPECT_EQ( - ::tensorflow::OkStatus(), + absl::OkStatus(), SetSignatureDefMap(add_model, expected_signature_def_map, &model_output)); const Model* final_model = flatbuffers::GetRoot(model_output.data()); EXPECT_FALSE(HasSignatureDef(final_model, kDefaultServingSignatureDefKey)); - EXPECT_EQ(::tensorflow::OkStatus(), + EXPECT_EQ(absl::OkStatus(), GetSignatureDefMap(final_model, &test_signature_def_map)); EXPECT_NE(expected_signature_def.SerializeAsString(), test_signature_def.SerializeAsString()); EXPECT_TRUE(HasSignatureDef(final_model, kTestSignatureDefKey)); - EXPECT_EQ(::tensorflow::OkStatus(), + EXPECT_EQ(absl::OkStatus(), GetSignatureDefMap(final_model, &test_signature_def_map)); test_signature_def = test_signature_def_map[kTestSignatureDefKey]; EXPECT_EQ(expected_signature_def.SerializeAsString(), @@ -123,7 +123,7 @@ TEST_F(SimpleSignatureDefUtilTest, OverwriteSignatureDefTest) { TEST_F(SimpleSignatureDefUtilTest, GetSignatureDefTest) { std::map test_signature_def_map; - EXPECT_EQ(::tensorflow::OkStatus(), + EXPECT_EQ(absl::OkStatus(), GetSignatureDefMap(model_, &test_signature_def_map)); EXPECT_FALSE(HasSignatureDef(model_, kDefaultServingSignatureDefKey)); } @@ -135,19 +135,18 @@ TEST_F(SimpleSignatureDefUtilTest, ClearSignatureDefTest) { std::map expected_signature_def_map = { {kDefaultServingSignatureDefKey, expected_signature_def}}; EXPECT_EQ( - ::tensorflow::OkStatus(), + absl::OkStatus(), SetSignatureDefMap(model_, expected_signature_def_map, &model_output)); const Model* add_model = flatbuffers::GetRoot(model_output.data()); 
EXPECT_TRUE(HasSignatureDef(add_model, kDefaultServingSignatureDefKey)); SignatureDef test_signature_def; std::map test_signature_def_map; - EXPECT_EQ(::tensorflow::OkStatus(), + EXPECT_EQ(absl::OkStatus(), GetSignatureDefMap(add_model, &test_signature_def_map)); test_signature_def = test_signature_def_map[kDefaultServingSignatureDefKey]; EXPECT_EQ(expected_signature_def.SerializeAsString(), test_signature_def.SerializeAsString()); - EXPECT_EQ(::tensorflow::OkStatus(), - ClearSignatureDefMap(add_model, &model_output)); + EXPECT_EQ(absl::OkStatus(), ClearSignatureDefMap(add_model, &model_output)); const Model* clear_model = flatbuffers::GetRoot(model_output.data()); EXPECT_FALSE(HasSignatureDef(clear_model, kDefaultServingSignatureDefKey)); EXPECT_EQ(expected_num_buffers, clear_model->buffers()->size()); diff --git a/tensorflow/lite/tools/signature/signature_def_util_wrapper_pybind11.cc b/tensorflow/lite/tools/signature/signature_def_util_wrapper_pybind11.cc index a0c80eb2615861..75b71ad6fd0617 100644 --- a/tensorflow/lite/tools/signature/signature_def_util_wrapper_pybind11.cc +++ b/tensorflow/lite/tools/signature/signature_def_util_wrapper_pybind11.cc @@ -40,7 +40,7 @@ py::bytes WrappedSetSignatureDefMap( signature_def_map[entry.first] = signature_def; } auto status = tflite::SetSignatureDefMap(model, signature_def_map, &data); - if (status != ::tensorflow::OkStatus()) { + if (status != absl::OkStatus()) { throw std::invalid_argument(std::string(status.message())); } return py::bytes(data); @@ -57,7 +57,7 @@ std::map WrappedGetSignatureDefMap( std::string content; std::map signature_def_map; auto status = tflite::GetSignatureDefMap(model, &signature_def_map); - if (status != ::tensorflow::OkStatus()) { + if (status != absl::OkStatus()) { throw std::invalid_argument("Cannot parse signature def"); } std::map serialized_signature_def_map; @@ -77,7 +77,7 @@ py::bytes WrappedClearSignatureDefs(const std::vector& model_buffer) { } std::string content; auto status = tflite::ClearSignatureDefMap(model, &content); - if (status != ::tensorflow::OkStatus()) { + if (status != absl::OkStatus()) { throw std::invalid_argument("An unknown error occurred"); } return py::bytes(content); diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 0ad8d5efd15975..41d91d9ae07bf0 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. 
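The signature-def utility and pybind wrapper hunks above keep checking for failure by comparing a status against the OK value, now spelled absl::OkStatus(). For reference, absl::Status also provides ok(); the two helpers below are illustrative only, not part of this change, and behave the same for the statuses these utilities return.

#include "absl/status/status.h"

// Comparison against the canonical OK value, as written in the wrappers above.
bool FailedViaComparison(const absl::Status& s) { return s != absl::OkStatus(); }

// Idiomatic equivalent using absl::Status::ok().
bool FailedViaOk(const absl::Status& s) { return !s.ok(); }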
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 4, 16) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 4, 17) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index 0ee2c99f45d690..93f33fdd02d64f 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -243,7 +243,7 @@ tf_py_strict_test( tf_py_strict_test( name = "index_flat_map_test", srcs = ["index_flat_map_test.py"], - shard_count = 4, + shard_count = 8, deps = [ "//tensorflow/python/data/experimental/ops:cardinality", "//tensorflow/python/data/experimental/ops:global_shuffle_op", diff --git a/tensorflow/python/data/experimental/kernel_tests/index_flat_map_test.py b/tensorflow/python/data/experimental/kernel_tests/index_flat_map_test.py index 3dae41bc4e3193..9135a9b9a1fd44 100644 --- a/tensorflow/python/data/experimental/kernel_tests/index_flat_map_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/index_flat_map_test.py @@ -14,7 +14,7 @@ # ============================================================================== """Tests for the index flat map dataset.""" -from typing import Any, Callable, Union +from typing import Any, Callable, Optional, Union from absl.testing import parameterized @@ -78,8 +78,18 @@ def test_cache(self): self.assertEqual(output, [b"0", b"1", b"2", b"3", b"4", b"5", b"6", b"7", b"8"]) - @combinations.generate(test_base.default_test_combinations()) - def test_global_shuffle(self): + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine( + repetitions=[1, 3], + seed=[None, 42], + reshuffle_each_iteration=[True, False]))) + def test_global_shuffle( + self, + repetitions: int, + seed: Optional[int], + reshuffle_each_iteration: bool): input_data = ["0 1", "2 3 4 5", "6 7", "8"] metadata = _get_metadata(input_data) @@ -87,11 +97,15 @@ def test_global_shuffle(self): dataset = index_flat_map_op.index_flat_map( dataset, _split, _get_index_map_func(metadata)) dataset = dataset.apply(cardinality_lib.assert_cardinality(9)) - dataset = global_shuffle_op._global_shuffle(dataset) + if repetitions > 1: + dataset = dataset.repeat(repetitions) + dataset = global_shuffle_op._global_shuffle( + dataset, seed=seed, reshuffle_each_iteration=reshuffle_each_iteration) dataset_output = self.getDatasetOutput( dataset, requires_initialization=True) - expected = [b"0", b"1", b"2", b"3", b"4", b"5", b"6", b"7", b"8"] + expected = [ + b"0", b"1", b"2", b"3", b"4", b"5", b"6", b"7", b"8"] * repetitions self.assertCountEqual(dataset_output, expected) self.assertNotEqual(dataset_output, expected) @@ -112,6 +126,20 @@ def _index_map_func(index: int) -> tuple[int, int]: dataset, _map_func, _index_map_func) self.assertDatasetProduces(dataset, list(range(dataset_range))) + @combinations.generate(test_base.default_test_combinations()) + def test_offset_out_of_range(self): + + def _index_map_func(_) -> tuple[int, int]: + return (0, 1000) + + input_data = ["0 1", "2 3 4 5", "6 7", "8"] + dataset = dataset_ops.Dataset.from_tensor_slices(input_data) + dataset = index_flat_map_op.index_flat_map(dataset, _split, _index_map_func) + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "invalid `index_map_fn` which returns offset 1000"): + self.getDatasetOutput(dataset) + 
@combinations.generate(test_base.default_test_combinations()) def test_invalid_map_fn(self): @@ -121,8 +149,7 @@ def _index_map_func(_) -> str: input_data = ["0 1", "2 3 4 5", "6 7", "8"] dataset = dataset_ops.Dataset.from_tensor_slices(input_data) - dataset = index_flat_map_op.index_flat_map( - dataset, _split, _index_map_func) + dataset = index_flat_map_op.index_flat_map(dataset, _split, _index_map_func) with self.assertRaisesRegex( errors.InvalidArgumentError, "expected to return two int values"): @@ -132,51 +159,58 @@ def _index_map_func(_) -> str: class IndexFlatMapCheckpointTest( checkpoint_test_base.CheckpointTestBase, parameterized.TestCase): - # TODO(b/325112575): Support the graph mode. @combinations.generate( combinations.times( - test_base.eager_only_combinations(), + test_base.default_test_combinations(), checkpoint_test_base.default_test_combinations(), - combinations.combine(symbolic_checkpoint=[True, False]))) + combinations.combine( + repetitions=[1, 3], + symbolic_checkpoint=[True, False]))) def test_index_flat_map( self, verify_fn: Callable[..., None], + repetitions: int, symbolic_checkpoint: bool): input_data = ["0 1", "2 3 4 5", "6 7", "8"] - metadata = _get_metadata(input_data) def _build_dataset() -> dataset_ops.Dataset: dataset = dataset_ops.Dataset.from_tensor_slices(input_data) dataset = index_flat_map_op.index_flat_map( - dataset, _split, _get_index_map_func(metadata)) + dataset, _split, _get_index_map_func(_get_metadata(input_data))) + dataset = dataset.apply(cardinality_lib.assert_cardinality(9)) + if repetitions > 1: + dataset = dataset.repeat(repetitions) options = options_lib.Options() options.experimental_symbolic_checkpoint = symbolic_checkpoint return dataset.with_options(options) - verify_fn(self, _build_dataset, num_outputs=9) + verify_fn(self, _build_dataset, num_outputs=9 * repetitions) @combinations.generate( combinations.times( - test_base.eager_only_combinations(), + test_base.default_test_combinations(), checkpoint_test_base.default_test_combinations(), combinations.combine( + repetitions=[1, 3], reshuffle_each_iteration=[True, False], symbolic_checkpoint=[True, False]))) def test_global_shuffle( self, verify_fn: Callable[..., None], + repetitions: list[int], reshuffle_each_iteration: bool, symbolic_checkpoint: bool): input_data = ["0 1", "2 3 4 5", "6 7", "8"] - metadata = _get_metadata(input_data) def _build_dataset() -> dataset_ops.Dataset: dataset = dataset_ops.Dataset.from_tensor_slices(input_data) dataset = index_flat_map_op.index_flat_map( - dataset, _split, _get_index_map_func(metadata)) + dataset, _split, _get_index_map_func(_get_metadata(input_data))) dataset = dataset.apply(cardinality_lib.assert_cardinality(9)) + if repetitions > 1: + dataset = dataset.repeat(repetitions) dataset = global_shuffle_op._global_shuffle( dataset, seed=42, reshuffle_each_iteration=reshuffle_each_iteration) @@ -187,7 +221,7 @@ def _build_dataset() -> dataset_ops.Dataset: verify_fn( self, _build_dataset, - num_outputs=9, + num_outputs=9 * repetitions, assert_items_equal=reshuffle_each_iteration) diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index bfdb79b802c64b..daf24612198162 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -94,6 +94,7 @@ tf_py_strict_test( ":test_base", "//tensorflow/python/checkpoint", "//tensorflow/python/checkpoint:checkpoint_management", + "//tensorflow/python/data/experimental/ops:global_shuffle_op", 
"//tensorflow/python/data/experimental/ops:random_access", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:options", diff --git a/tensorflow/python/data/kernel_tests/cache_test.py b/tensorflow/python/data/kernel_tests/cache_test.py index fce4c7bd30c3f9..7e0324cdd11921 100644 --- a/tensorflow/python/data/kernel_tests/cache_test.py +++ b/tensorflow/python/data/kernel_tests/cache_test.py @@ -18,11 +18,13 @@ from os import path import shutil import tempfile +from typing import Optional from absl.testing import parameterized import numpy as np from tensorflow.python.checkpoint import checkpoint as trackable_utils from tensorflow.python.checkpoint import checkpoint_management +from tensorflow.python.data.experimental.ops import global_shuffle_op from tensorflow.python.data.experimental.ops import random_access from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base @@ -692,5 +694,38 @@ def testCacheInputDatasetInfiniteCardinality(self): # with caching will cache through index 11. self.verifyRandomAccessInfiniteCardinality(dataset, expected) + +class CacheGlobalShuffleTest(test_base.DatasetTestBase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine( + dataset_range=[10], + repetitions=[1, 2], + seed=[None, 42], + reshuffle_each_iteration=[True, False]))) + def test( + self, + dataset_range: int, + repetitions: int, + seed: Optional[int], + reshuffle_each_iteration: bool): + dataset = dataset_ops.Dataset.range(dataset_range) + dataset = dataset.cache() + dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE) + if repetitions > 1: + dataset = dataset.repeat(repetitions) + dataset = global_shuffle_op._global_shuffle( + dataset, seed=seed, reshuffle_each_iteration=reshuffle_each_iteration) + + expected = list(range(0, dataset_range)) * repetitions + dataset_output = self.getDatasetOutput( + dataset, requires_initialization=True) + self.assertCountEqual(dataset_output, expected) + self.assertNotEqual(dataset_output, expected) + self.assertLen(dataset_output, self.evaluate(dataset.cardinality())) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/tools/ci_build/install/install_clang_18.sh b/tensorflow/tools/ci_build/install/install_clang_18.sh new file mode 100755 index 00000000000000..f451b3a5356b6e --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_clang_18.sh @@ -0,0 +1,31 @@ +#!/bin/bash -eu +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +# LLVM/Clang: https://apt.llvm.org/ +apt-key adv --fetch-keys https://apt.llvm.org/llvm-snapshot.gpg.key + +# Set up custom sources +cat >/etc/apt/sources.list.d/custom.list <`_) -diff -ruN --strip-trailing-cr a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h ---- a/clang/include/clang/Sema/Sema.h -+++ b/clang/include/clang/Sema/Sema.h -@@ -5452,8 +5452,7 @@ - - ExprResult BuildDeclarationNameExpr(const CXXScopeSpec &SS, LookupResult &R, - bool NeedsADL, -- bool AcceptInvalidDecl = false, -- bool NeedUnresolved = false); -+ bool AcceptInvalidDecl = false); - ExprResult BuildDeclarationNameExpr( - const CXXScopeSpec &SS, const DeclarationNameInfo &NameInfo, NamedDecl *D, - NamedDecl *FoundD = nullptr, -@@ -6596,10 +6595,7 @@ - SourceLocation RParenLoc); - - //// ActOnCXXThis - Parse 'this' pointer. -- ExprResult ActOnCXXThis(SourceLocation Loc); -- -- /// Check whether the type of 'this' is valid in the current context. -- bool CheckCXXThisType(SourceLocation Loc, QualType Type); -+ ExprResult ActOnCXXThis(SourceLocation loc); - - /// Build a CXXThisExpr and mark it referenced in the current context. - Expr *BuildCXXThisExpr(SourceLocation Loc, QualType Type, bool IsImplicit); -@@ -7022,14 +7018,10 @@ - ///@{ - - public: -- /// Check whether an expression might be an implicit class member access. -- bool isPotentialImplicitMemberAccess(const CXXScopeSpec &SS, LookupResult &R, -- bool IsAddressOfOperand); -- - ExprResult BuildPossibleImplicitMemberExpr( - const CXXScopeSpec &SS, SourceLocation TemplateKWLoc, LookupResult &R, -- const TemplateArgumentListInfo *TemplateArgs, const Scope *S); -- -+ const TemplateArgumentListInfo *TemplateArgs, const Scope *S, -+ UnresolvedLookupExpr *AsULE = nullptr); - ExprResult - BuildImplicitMemberExpr(const CXXScopeSpec &SS, SourceLocation TemplateKWLoc, - LookupResult &R, -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp ---- a/clang/lib/Sema/SemaExpr.cpp -+++ b/clang/lib/Sema/SemaExpr.cpp -@@ -2917,9 +2917,26 @@ - // to get this right here so that we don't end up making a - // spuriously dependent expression if we're inside a dependent - // instance method. -- if (isPotentialImplicitMemberAccess(SS, R, IsAddressOfOperand)) -- return BuildPossibleImplicitMemberExpr(SS, TemplateKWLoc, R, TemplateArgs, -- S); -+ if (getLangOpts().CPlusPlus && !R.empty() && -+ (*R.begin())->isCXXClassMember()) { -+ bool MightBeImplicitMember; -+ if (!IsAddressOfOperand) -+ MightBeImplicitMember = true; -+ else if (!SS.isEmpty()) -+ MightBeImplicitMember = false; -+ else if (R.isOverloadedResult()) -+ MightBeImplicitMember = false; -+ else if (R.isUnresolvableResult()) -+ MightBeImplicitMember = true; -+ else -+ MightBeImplicitMember = isa(R.getFoundDecl()) || -+ isa(R.getFoundDecl()) || -+ isa(R.getFoundDecl()); -+ -+ if (MightBeImplicitMember) -+ return BuildPossibleImplicitMemberExpr(SS, TemplateKWLoc, -+ R, TemplateArgs, S); -+ } - - if (TemplateArgs || TemplateKWLoc.isValid()) { - -@@ -3430,11 +3447,10 @@ - - ExprResult Sema::BuildDeclarationNameExpr(const CXXScopeSpec &SS, - LookupResult &R, bool NeedsADL, -- bool AcceptInvalidDecl, -- bool NeedUnresolved) { -+ bool AcceptInvalidDecl) { - // If this is a single, fully-resolved result and we don't need ADL, - // just build an ordinary singleton decl ref. 
-- if (!NeedUnresolved && !NeedsADL && R.isSingleResult() && -+ if (!NeedsADL && R.isSingleResult() && - !R.getAsSingle() && - !ShouldLookupResultBeMultiVersionOverload(R)) - return BuildDeclarationNameExpr(SS, R.getLookupNameInfo(), R.getFoundDecl(), -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp ---- a/clang/lib/Sema/SemaExprCXX.cpp -+++ b/clang/lib/Sema/SemaExprCXX.cpp -@@ -1416,42 +1416,26 @@ - } - - ExprResult Sema::ActOnCXXThis(SourceLocation Loc) { -- // C++20 [expr.prim.this]p1: -- // The keyword this names a pointer to the object for which an -- // implicit object member function is invoked or a non-static -- // data member's initializer is evaluated. -+ /// C++ 9.3.2: In the body of a non-static member function, the keyword this -+ /// is a non-lvalue expression whose value is the address of the object for -+ /// which the function is called. - QualType ThisTy = getCurrentThisType(); - -- if (CheckCXXThisType(Loc, ThisTy)) -- return ExprError(); -+ if (ThisTy.isNull()) { -+ DeclContext *DC = getFunctionLevelDeclContext(); - -- return BuildCXXThisExpr(Loc, ThisTy, /*IsImplicit=*/false); --} -+ if (const auto *Method = dyn_cast(DC); -+ Method && Method->isExplicitObjectMemberFunction()) { -+ return Diag(Loc, diag::err_invalid_this_use) << 1; -+ } - --bool Sema::CheckCXXThisType(SourceLocation Loc, QualType Type) { -- if (!Type.isNull()) -- return false; -+ if (isLambdaCallWithExplicitObjectParameter(CurContext)) -+ return Diag(Loc, diag::err_invalid_this_use) << 1; - -- // C++20 [expr.prim.this]p3: -- // If a declaration declares a member function or member function template -- // of a class X, the expression this is a prvalue of type -- // "pointer to cv-qualifier-seq X" wherever X is the current class between -- // the optional cv-qualifier-seq and the end of the function-definition, -- // member-declarator, or declarator. It shall not appear within the -- // declaration of either a static member function or an explicit object -- // member function of the current class (although its type and value -- // category are defined within such member functions as they are within -- // an implicit object member function). -- DeclContext *DC = getFunctionLevelDeclContext(); -- if (const auto *Method = dyn_cast(DC); -- Method && Method->isExplicitObjectMemberFunction()) { -- Diag(Loc, diag::err_invalid_this_use) << 1; -- } else if (isLambdaCallWithExplicitObjectParameter(CurContext)) { -- Diag(Loc, diag::err_invalid_this_use) << 1; -- } else { -- Diag(Loc, diag::err_invalid_this_use) << 0; -+ return Diag(Loc, diag::err_invalid_this_use) << 0; - } -- return true; -+ -+ return BuildCXXThisExpr(Loc, ThisTy, /*IsImplicit=*/false); - } - - Expr *Sema::BuildCXXThisExpr(SourceLocation Loc, QualType Type, -@@ -8658,8 +8642,21 @@ - - // Detect and handle the case where the decl might be an implicit - // member. 
-- if (SemaRef.isPotentialImplicitMemberAccess( -- NewSS, R, Consumer.isAddressOfOperand())) -+ bool MightBeImplicitMember; -+ if (!Consumer.isAddressOfOperand()) -+ MightBeImplicitMember = true; -+ else if (!NewSS.isEmpty()) -+ MightBeImplicitMember = false; -+ else if (R.isOverloadedResult()) -+ MightBeImplicitMember = false; -+ else if (R.isUnresolvableResult()) -+ MightBeImplicitMember = true; -+ else -+ MightBeImplicitMember = isa(ND) || -+ isa(ND) || -+ isa(ND); -+ -+ if (MightBeImplicitMember) - return SemaRef.BuildPossibleImplicitMemberExpr( - NewSS, /*TemplateKWLoc*/ SourceLocation(), R, - /*TemplateArgs*/ nullptr, /*S*/ nullptr); -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaExprMember.cpp b/clang/lib/Sema/SemaExprMember.cpp ---- a/clang/lib/Sema/SemaExprMember.cpp -+++ b/clang/lib/Sema/SemaExprMember.cpp -@@ -61,10 +61,6 @@ - /// The reference is a contextually-permitted abstract member reference. - IMA_Abstract, - -- /// Whether the context is static is dependent on the enclosing template (i.e. -- /// in a dependent class scope explicit specialization). -- IMA_Dependent, -- - /// The reference may be to an unresolved using declaration and the - /// context is not an instance method. - IMA_Unresolved_StaticOrExplicitContext, -@@ -95,18 +91,10 @@ - - DeclContext *DC = SemaRef.getFunctionLevelDeclContext(); - -- bool couldInstantiateToStatic = false; -- bool isStaticOrExplicitContext = SemaRef.CXXThisTypeOverride.isNull(); -- -- if (auto *MD = dyn_cast(DC)) { -- if (MD->isImplicitObjectMemberFunction()) { -- isStaticOrExplicitContext = false; -- // A dependent class scope function template explicit specialization -- // that is neither declared 'static' nor with an explicit object -- // parameter could instantiate to a static or non-static member function. -- couldInstantiateToStatic = MD->getDependentSpecializationInfo(); -- } -- } -+ bool isStaticOrExplicitContext = -+ SemaRef.CXXThisTypeOverride.isNull() && -+ (!isa(DC) || cast(DC)->isStatic() || -+ cast(DC)->isExplicitObjectMemberFunction()); - - if (R.isUnresolvableResult()) - return isStaticOrExplicitContext ? IMA_Unresolved_StaticOrExplicitContext -@@ -135,9 +123,6 @@ - if (Classes.empty()) - return IMA_Static; - -- if (couldInstantiateToStatic) -- return IMA_Dependent; -- - // C++11 [expr.prim.general]p12: - // An id-expression that denotes a non-static data member or non-static - // member function of a class can only be used: -@@ -278,52 +263,32 @@ - } - } - --bool Sema::isPotentialImplicitMemberAccess(const CXXScopeSpec &SS, -- LookupResult &R, -- bool IsAddressOfOperand) { -- if (!getLangOpts().CPlusPlus) -- return false; -- else if (R.empty() || !R.begin()->isCXXClassMember()) -- return false; -- else if (!IsAddressOfOperand) -- return true; -- else if (!SS.isEmpty()) -- return false; -- else if (R.isOverloadedResult()) -- return false; -- else if (R.isUnresolvableResult()) -- return true; -- else -- return isa(R.getFoundDecl()); --} -- - /// Builds an expression which might be an implicit member expression. 
- ExprResult Sema::BuildPossibleImplicitMemberExpr( - const CXXScopeSpec &SS, SourceLocation TemplateKWLoc, LookupResult &R, -- const TemplateArgumentListInfo *TemplateArgs, const Scope *S) { -- switch (IMAKind Classification = ClassifyImplicitMemberAccess(*this, R)) { -+ const TemplateArgumentListInfo *TemplateArgs, const Scope *S, -+ UnresolvedLookupExpr *AsULE) { -+ switch (ClassifyImplicitMemberAccess(*this, R)) { - case IMA_Instance: -+ return BuildImplicitMemberExpr(SS, TemplateKWLoc, R, TemplateArgs, true, S); -+ - case IMA_Mixed: - case IMA_Mixed_Unrelated: - case IMA_Unresolved: -- return BuildImplicitMemberExpr( -- SS, TemplateKWLoc, R, TemplateArgs, -- /*IsKnownInstance=*/Classification == IMA_Instance, S); -+ return BuildImplicitMemberExpr(SS, TemplateKWLoc, R, TemplateArgs, false, -+ S); -+ - case IMA_Field_Uneval_Context: - Diag(R.getNameLoc(), diag::warn_cxx98_compat_non_static_member_use) - << R.getLookupNameInfo().getName(); - [[fallthrough]]; - case IMA_Static: - case IMA_Abstract: -- case IMA_Dependent: - case IMA_Mixed_StaticOrExplicitContext: - case IMA_Unresolved_StaticOrExplicitContext: - if (TemplateArgs || TemplateKWLoc.isValid()) -- return BuildTemplateIdExpr(SS, TemplateKWLoc, R, /*RequiresADL=*/false, -- TemplateArgs); -- return BuildDeclarationNameExpr( -- SS, R, /*NeedsADL=*/false, /*AcceptInvalidDecl=*/false, -- /*NeedUnresolved=*/Classification == IMA_Dependent); -+ return BuildTemplateIdExpr(SS, TemplateKWLoc, R, false, TemplateArgs); -+ return AsULE ? AsULE : BuildDeclarationNameExpr(SS, R, false); - - case IMA_Error_StaticOrExplicitContext: - case IMA_Error_Unrelated: -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp ---- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp -+++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp -@@ -5097,14 +5097,6 @@ - EnterExpressionEvaluationContext EvalContext( - *this, Sema::ExpressionEvaluationContext::PotentiallyEvaluated); - -- Qualifiers ThisTypeQuals; -- CXXRecordDecl *ThisContext = nullptr; -- if (CXXMethodDecl *Method = dyn_cast(Function)) { -- ThisContext = Method->getParent(); -- ThisTypeQuals = Method->getMethodQualifiers(); -- } -- CXXThisScopeRAII ThisScope(*this, ThisContext, ThisTypeQuals); -- - // Introduce a new scope where local variable instantiations will be - // recorded, unless we're actually a member function within a local - // class, in which case we need to merge our results with the parent -diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h ---- a/clang/lib/Sema/TreeTransform.h -+++ b/clang/lib/Sema/TreeTransform.h -@@ -795,9 +795,6 @@ - ParenExpr *PE, DependentScopeDeclRefExpr *DRE, bool IsAddressOfOperand, - TypeSourceInfo **RecoveryTSI); - -- ExprResult TransformUnresolvedLookupExpr(UnresolvedLookupExpr *E, -- bool IsAddressOfOperand); -- - StmtResult TransformOMPExecutableDirective(OMPExecutableDirective *S); - - // FIXME: We use LLVM_ATTRIBUTE_NOINLINE because inlining causes a ridiculous -@@ -3312,13 +3309,12 @@ - - /// Build a new C++ "this" expression. - /// -- /// By default, performs semantic analysis to build a new "this" expression. -- /// Subclasses may override this routine to provide different behavior. -+ /// By default, builds a new "this" expression without performing any -+ /// semantic analysis. Subclasses may override this routine to provide -+ /// different behavior. 
- ExprResult RebuildCXXThisExpr(SourceLocation ThisLoc, - QualType ThisType, - bool isImplicit) { -- if (getSema().CheckCXXThisType(ThisLoc, ThisType)) -- return ExprError(); - return getSema().BuildCXXThisExpr(ThisLoc, ThisType, isImplicit); - } - -@@ -11369,11 +11365,7 @@ - ExprResult - TreeTransform::TransformAddressOfOperand(Expr *E) { - if (DependentScopeDeclRefExpr *DRE = dyn_cast(E)) -- return getDerived().TransformDependentScopeDeclRefExpr( -- DRE, /*IsAddressOfOperand=*/true, nullptr); -- else if (UnresolvedLookupExpr *ULE = dyn_cast(E)) -- return getDerived().TransformUnresolvedLookupExpr( -- ULE, /*IsAddressOfOperand=*/true); -+ return getDerived().TransformDependentScopeDeclRefExpr(DRE, true, nullptr); - else - return getDerived().TransformExpr(E); - } -@@ -13079,16 +13071,10 @@ - return false; - } - --template --ExprResult TreeTransform::TransformUnresolvedLookupExpr( -- UnresolvedLookupExpr *Old) { -- return TransformUnresolvedLookupExpr(Old, /*IsAddressOfOperand=*/false); --} -- --template -+template - ExprResult --TreeTransform::TransformUnresolvedLookupExpr(UnresolvedLookupExpr *Old, -- bool IsAddressOfOperand) { -+TreeTransform::TransformUnresolvedLookupExpr( -+ UnresolvedLookupExpr *Old) { - LookupResult R(SemaRef, Old->getName(), Old->getNameLoc(), - Sema::LookupOrdinaryName); - -@@ -13120,8 +13106,26 @@ - R.setNamingClass(NamingClass); - } - -- // Rebuild the template arguments, if any. - SourceLocation TemplateKWLoc = Old->getTemplateKeywordLoc(); -+ -+ // If we have neither explicit template arguments, nor the template keyword, -+ // it's a normal declaration name or member reference. -+ if (!Old->hasExplicitTemplateArgs() && !TemplateKWLoc.isValid()) { -+ NamedDecl *D = R.getAsSingle(); -+ // In a C++11 unevaluated context, an UnresolvedLookupExpr might refer to an -+ // instance member. In other contexts, BuildPossibleImplicitMemberExpr will -+ // give a good diagnostic. -+ if (D && D->isCXXInstanceMember()) { -+ return SemaRef.BuildPossibleImplicitMemberExpr(SS, TemplateKWLoc, R, -+ /*TemplateArgs=*/nullptr, -+ /*Scope=*/nullptr); -+ } -+ -+ return getDerived().RebuildDeclarationNameExpr(SS, R, Old->requiresADL()); -+ } -+ -+ // If we have template arguments, rebuild them, then rebuild the -+ // templateid expression. - TemplateArgumentListInfo TransArgs(Old->getLAngleLoc(), Old->getRAngleLoc()); - if (Old->hasExplicitTemplateArgs() && - getDerived().TransformTemplateArguments(Old->getTemplateArgs(), -@@ -13131,23 +13135,6 @@ - return ExprError(); - } - -- // An UnresolvedLookupExpr can refer to a class member. This occurs e.g. when -- // a non-static data member is named in an unevaluated operand, or when -- // a member is named in a dependent class scope function template explicit -- // specialization that is neither declared static nor with an explicit object -- // parameter. -- if (SemaRef.isPotentialImplicitMemberAccess(SS, R, IsAddressOfOperand)) -- return SemaRef.BuildPossibleImplicitMemberExpr( -- SS, TemplateKWLoc, R, -- Old->hasExplicitTemplateArgs() ? &TransArgs : nullptr, -- /*S=*/nullptr); -- -- // If we have neither explicit template arguments, nor the template keyword, -- // it's a normal declaration name or member reference. -- if (!Old->hasExplicitTemplateArgs() && !TemplateKWLoc.isValid()) -- return getDerived().RebuildDeclarationNameExpr(SS, R, Old->requiresADL()); -- -- // If we have template arguments, then rebuild the template-id expression. 
- return getDerived().RebuildTemplateIdExpr(SS, TemplateKWLoc, R, - Old->requiresADL(), &TransArgs); - } -diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/instantiate-using-decl.cpp b/clang/test/SemaTemplate/instantiate-using-decl.cpp ---- a/clang/test/SemaTemplate/instantiate-using-decl.cpp -+++ b/clang/test/SemaTemplate/instantiate-using-decl.cpp -@@ -121,7 +121,7 @@ - (void)&field; - // expected-error@+1 {{call to non-static member function without an object argument}} - (void)method; -- // expected-error@+1 {{must explicitly qualify name of member function when taking its address}} -+ // expected-error@+1 {{call to non-static member function without an object argument}} - (void)&method; - // expected-error@+1 {{call to non-static member function without an object argument}} - method(); -diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/ms-function-specialization-class-scope.cpp b/clang/test/SemaTemplate/ms-function-specialization-class-scope.cpp ---- a/clang/test/SemaTemplate/ms-function-specialization-class-scope.cpp -+++ b/clang/test/SemaTemplate/ms-function-specialization-class-scope.cpp -@@ -1,6 +1,7 @@ --// RUN: %clang_cc1 -fms-extensions -fsyntax-only -Wno-unused-value -verify %s --// RUN: %clang_cc1 -fms-extensions -fdelayed-template-parsing -fsyntax-only -Wno-unused-value -verify %s -+// RUN: %clang_cc1 -fms-extensions -fsyntax-only -verify %s -+// RUN: %clang_cc1 -fms-extensions -fdelayed-template-parsing -fsyntax-only -verify %s - -+// expected-no-diagnostics - class A { - public: - template A(U p) {} -@@ -75,104 +76,3 @@ - int f<0>(int); - }; - } -- --namespace UsesThis { -- template -- struct A { -- int x; -- -- static inline int y; -- -- template -- static void f(); -- -- template -- void g(); -- -- template -- static auto h() -> A*; -- -- void i(); -- -- static void j(); -- -- template<> -- void f() { -- this->x; // expected-error {{invalid use of 'this' outside of a non-static member function}} -- x; // expected-error {{invalid use of member 'x' in static member function}} -- A::x; // expected-error {{invalid use of member 'x' in static member function}} -- +x; // expected-error {{invalid use of member 'x' in static member function}} -- +A::x; // expected-error {{invalid use of member 'x' in static member function}} -- &x; // expected-error {{invalid use of member 'x' in static member function}} -- &A::x; -- this->y; // expected-error {{invalid use of 'this' outside of a non-static member function}} -- y; -- A::y; -- +y; -- +A::y; -- &y; -- &A::y; -- f(); -- f(); -- g(); // expected-error {{call to non-static member function without an object argument}} -- g(); // expected-error {{call to non-static member function without an object argument}} -- i(); // expected-error {{call to non-static member function without an object argument}} -- j(); -- &i; // expected-error 2{{must explicitly qualify name of member function when taking its address}} -- &j; -- &A::i; -- &A::j; -- } -- -- template<> -- void g() { -- this->x; -- x; -- A::x; -- +x; -- +A::x; -- &x; -- &A::x; -- this->y; -- y; -- A::y; -- +y; -- +A::y; -- &y; -- &A::y; -- f(); -- f(); -- g(); -- g(); -- i(); -- j(); -- &i; // expected-error 2{{must explicitly qualify name of member function when taking its address}} -- &j; -- &A::i; -- &A::j; -- } -- -- template<> -- auto h() -> decltype(this); // expected-error {{'this' cannot be used in a static member function declaration}} -- }; -- -- template struct A; // expected-note 3{{in instantiation of}} -- -- template -- struct Foo { -- template -- int bar(X 
x) { -- return 0; -- } -- -- template <> -- int bar(int x) { -- return bar(5.0); // ok -- } -- }; -- -- void call() { -- Foo f; -- f.bar(1); -- } --} -diff -ruN --strip-trailing-cr a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp ---- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp -+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp -@@ -24463,23 +24463,6 @@ - if (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, NVT)) - return DAG.getSplatVector(NVT, DL, V.getOperand(0)); - -- // extract_subvector(insert_subvector(x,y,c1),c2) -- // --> extract_subvector(y,c2-c1) -- // iff we're just extracting from the inserted subvector. -- if (V.getOpcode() == ISD::INSERT_SUBVECTOR) { -- SDValue InsSub = V.getOperand(1); -- EVT InsSubVT = InsSub.getValueType(); -- unsigned NumInsElts = InsSubVT.getVectorMinNumElements(); -- unsigned InsIdx = V.getConstantOperandVal(2); -- unsigned NumSubElts = NVT.getVectorMinNumElements(); -- if (InsIdx <= ExtIdx && (ExtIdx + NumSubElts) <= (InsIdx + NumInsElts) && -- TLI.isExtractSubvectorCheap(NVT, InsSubVT, ExtIdx - InsIdx)) { -- SDLoc DL(N); -- return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT, InsSub, -- DAG.getVectorIdxConstant(ExtIdx - InsIdx, DL)); -- } -- } -- - // Try to move vector bitcast after extract_subv by scaling extraction index: - // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index') - if (V.getOpcode() == ISD::BITCAST && -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast.ll b/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast.ll ---- a/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast.ll -+++ b/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast.ll -@@ -314,8 +314,8 @@ - ; - ; AVX512F-LABEL: vec64_i16_widen_to_i32_factor2_broadcast_to_v2i32_factor2: - ; AVX512F: # %bb.0: --; AVX512F-NEXT: vmovdqa (%rdi), %xmm0 --; AVX512F-NEXT: vpaddb (%rsi), %xmm0, %xmm0 -+; AVX512F-NEXT: vmovdqa (%rdi), %ymm0 -+; AVX512F-NEXT: vpaddb (%rsi), %ymm0, %ymm0 - ; AVX512F-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[0,1,10,11,0,1,14,15,u,u,u,u,u,u,u,u] - ; AVX512F-NEXT: vpaddb (%rdx), %ymm0, %ymm0 - ; AVX512F-NEXT: vmovdqa %ymm0, (%rcx) -@@ -324,8 +324,8 @@ - ; - ; AVX512DQ-LABEL: vec64_i16_widen_to_i32_factor2_broadcast_to_v2i32_factor2: - ; AVX512DQ: # %bb.0: --; AVX512DQ-NEXT: vmovdqa (%rdi), %xmm0 --; AVX512DQ-NEXT: vpaddb (%rsi), %xmm0, %xmm0 -+; AVX512DQ-NEXT: vmovdqa (%rdi), %ymm0 -+; AVX512DQ-NEXT: vpaddb (%rsi), %ymm0, %ymm0 - ; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[0,1,10,11,0,1,14,15,u,u,u,u,u,u,u,u] - ; AVX512DQ-NEXT: vpaddb (%rdx), %ymm0, %ymm0 - ; AVX512DQ-NEXT: vmovdqa %ymm0, (%rcx) -@@ -981,7 +981,7 @@ - ; AVX512F-NEXT: vpmovsxbd {{.*#+}} xmm0 = [0,5,0,7] - ; AVX512F-NEXT: vmovdqa (%rdi), %ymm1 - ; AVX512F-NEXT: vpaddb (%rsi), %ymm1, %ymm1 --; AVX512F-NEXT: vpermd %ymm1, %ymm0, %ymm0 -+; AVX512F-NEXT: vpermd %zmm1, %zmm0, %zmm0 - ; AVX512F-NEXT: vpaddb (%rdx), %ymm0, %ymm0 - ; AVX512F-NEXT: vmovdqa %ymm0, (%rcx) - ; AVX512F-NEXT: vzeroupper -@@ -992,7 +992,7 @@ - ; AVX512DQ-NEXT: vpmovsxbd {{.*#+}} xmm0 = [0,5,0,7] - ; AVX512DQ-NEXT: vmovdqa (%rdi), %ymm1 - ; AVX512DQ-NEXT: vpaddb (%rsi), %ymm1, %ymm1 --; AVX512DQ-NEXT: vpermd %ymm1, %ymm0, %ymm0 -+; AVX512DQ-NEXT: vpermd %zmm1, %zmm0, %zmm0 - ; AVX512DQ-NEXT: vpaddb (%rdx), %ymm0, %ymm0 - ; AVX512DQ-NEXT: vmovdqa %ymm0, (%rcx) - ; AVX512DQ-NEXT: vzeroupper -@@ -3507,12 +3507,13 @@ - ; - ; AVX512F-LABEL: vec384_i16_widen_to_i32_factor2_broadcast_to_v12i32_factor12: - ; 
AVX512F: # %bb.0: --; AVX512F-NEXT: vmovdqa (%rdi), %xmm0 -+; AVX512F-NEXT: vmovdqa (%rdi), %ymm0 -+; AVX512F-NEXT: vpaddb (%rsi), %ymm0, %ymm0 - ; AVX512F-NEXT: vmovdqa 48(%rdi), %xmm1 - ; AVX512F-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 --; AVX512F-NEXT: vpaddb (%rsi), %xmm0, %xmm0 -+; AVX512F-NEXT: vpbroadcastw %xmm0, %ymm2 -+; AVX512F-NEXT: vpblendw {{.*#+}} ymm1 = ymm2[0],ymm1[1],ymm2[2],ymm1[3],ymm2[4],ymm1[5],ymm2[6],ymm1[7],ymm2[8],ymm1[9],ymm2[10],ymm1[11],ymm2[12],ymm1[13],ymm2[14],ymm1[15] - ; AVX512F-NEXT: vpbroadcastw %xmm0, %ymm0 --; AVX512F-NEXT: vpblendw {{.*#+}} ymm1 = ymm0[0],ymm1[1],ymm0[2],ymm1[3],ymm0[4],ymm1[5],ymm0[6],ymm1[7],ymm0[8],ymm1[9],ymm0[10],ymm1[11],ymm0[12],ymm1[13],ymm0[14],ymm1[15] - ; AVX512F-NEXT: vpaddb (%rdx), %ymm1, %ymm1 - ; AVX512F-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 - ; AVX512F-NEXT: vmovdqa %ymm0, 32(%rcx) -@@ -3522,12 +3523,13 @@ - ; - ; AVX512DQ-LABEL: vec384_i16_widen_to_i32_factor2_broadcast_to_v12i32_factor12: - ; AVX512DQ: # %bb.0: --; AVX512DQ-NEXT: vmovdqa (%rdi), %xmm0 -+; AVX512DQ-NEXT: vmovdqa (%rdi), %ymm0 -+; AVX512DQ-NEXT: vpaddb (%rsi), %ymm0, %ymm0 - ; AVX512DQ-NEXT: vmovdqa 48(%rdi), %xmm1 - ; AVX512DQ-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 --; AVX512DQ-NEXT: vpaddb (%rsi), %xmm0, %xmm0 -+; AVX512DQ-NEXT: vpbroadcastw %xmm0, %ymm2 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm1 = ymm2[0],ymm1[1],ymm2[2],ymm1[3],ymm2[4],ymm1[5],ymm2[6],ymm1[7],ymm2[8],ymm1[9],ymm2[10],ymm1[11],ymm2[12],ymm1[13],ymm2[14],ymm1[15] - ; AVX512DQ-NEXT: vpbroadcastw %xmm0, %ymm0 --; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm1 = ymm0[0],ymm1[1],ymm0[2],ymm1[3],ymm0[4],ymm1[5],ymm0[6],ymm1[7],ymm0[8],ymm1[9],ymm0[10],ymm1[11],ymm0[12],ymm1[13],ymm0[14],ymm1[15] - ; AVX512DQ-NEXT: vpaddb (%rdx), %ymm1, %ymm1 - ; AVX512DQ-NEXT: vpaddb 32(%rdx), %ymm0, %ymm0 - ; AVX512DQ-NEXT: vmovdqa %ymm0, 32(%rcx) -@@ -3766,10 +3768,10 @@ - ; - ; AVX512F-LABEL: vec384_i16_widen_to_i64_factor4_broadcast_to_v6i64_factor6: - ; AVX512F: # %bb.0: --; AVX512F-NEXT: vmovdqa (%rdi), %xmm0 -+; AVX512F-NEXT: vmovdqa (%rdi), %ymm0 -+; AVX512F-NEXT: vpaddb (%rsi), %ymm0, %ymm0 - ; AVX512F-NEXT: vmovdqa 48(%rdi), %xmm1 - ; AVX512F-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 --; AVX512F-NEXT: vpaddb (%rsi), %xmm0, %xmm0 - ; AVX512F-NEXT: vpbroadcastq %xmm0, %ymm2 - ; AVX512F-NEXT: vpblendw {{.*#+}} ymm1 = ymm2[0],ymm1[1,2,3],ymm2[4],ymm1[5,6,7],ymm2[8],ymm1[9,10,11],ymm2[12],ymm1[13,14,15] - ; AVX512F-NEXT: vpbroadcastw %xmm0, %ymm0 -@@ -3782,10 +3784,10 @@ - ; - ; AVX512DQ-LABEL: vec384_i16_widen_to_i64_factor4_broadcast_to_v6i64_factor6: - ; AVX512DQ: # %bb.0: --; AVX512DQ-NEXT: vmovdqa (%rdi), %xmm0 -+; AVX512DQ-NEXT: vmovdqa (%rdi), %ymm0 -+; AVX512DQ-NEXT: vpaddb (%rsi), %ymm0, %ymm0 - ; AVX512DQ-NEXT: vmovdqa 48(%rdi), %xmm1 - ; AVX512DQ-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 --; AVX512DQ-NEXT: vpaddb (%rsi), %xmm0, %xmm0 - ; AVX512DQ-NEXT: vpbroadcastq %xmm0, %ymm2 - ; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm1 = ymm2[0],ymm1[1,2,3],ymm2[4],ymm1[5,6,7],ymm2[8],ymm1[9,10,11],ymm2[12],ymm1[13,14,15] - ; AVX512DQ-NEXT: vpbroadcastw %xmm0, %ymm0 -@@ -4145,9 +4147,9 @@ - ; - ; AVX512F-LABEL: vec384_i16_widen_to_i192_factor12_broadcast_to_v2i192_factor2: - ; AVX512F: # %bb.0: --; AVX512F-NEXT: vmovdqa (%rdi), %xmm0 -+; AVX512F-NEXT: vmovdqa (%rdi), %ymm0 -+; AVX512F-NEXT: vpaddb (%rsi), %ymm0, %ymm0 - ; AVX512F-NEXT: vmovdqa 48(%rdi), %xmm1 --; AVX512F-NEXT: vpaddb (%rsi), %xmm0, %xmm0 - ; AVX512F-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 - ; AVX512F-NEXT: vpblendw {{.*#+}} xmm1 = xmm0[0],xmm1[1,2,3,4,5,6,7] - ; 
AVX512F-NEXT: vpbroadcastw %xmm0, %xmm0 -@@ -4159,9 +4161,9 @@ - ; - ; AVX512DQ-LABEL: vec384_i16_widen_to_i192_factor12_broadcast_to_v2i192_factor2: - ; AVX512DQ: # %bb.0: --; AVX512DQ-NEXT: vmovdqa (%rdi), %xmm0 -+; AVX512DQ-NEXT: vmovdqa (%rdi), %ymm0 -+; AVX512DQ-NEXT: vpaddb (%rsi), %ymm0, %ymm0 - ; AVX512DQ-NEXT: vmovdqa 48(%rdi), %xmm1 --; AVX512DQ-NEXT: vpaddb (%rsi), %xmm0, %xmm0 - ; AVX512DQ-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 - ; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm1 = xmm0[0],xmm1[1,2,3,4,5,6,7] - ; AVX512DQ-NEXT: vpbroadcastw %xmm0, %xmm0 -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/dpbusd_i4.ll b/llvm/test/CodeGen/X86/dpbusd_i4.ll ---- a/llvm/test/CodeGen/X86/dpbusd_i4.ll -+++ b/llvm/test/CodeGen/X86/dpbusd_i4.ll -@@ -86,7 +86,7 @@ - ; CHECK-NEXT: vpsraw $12, %ymm0, %ymm0 - ; CHECK-NEXT: vpmaddwd %ymm1, %ymm0, %ymm0 - ; CHECK-NEXT: vextracti128 $1, %ymm0, %xmm1 --; CHECK-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -+; CHECK-NEXT: vpaddd %ymm1, %ymm0, %ymm0 - ; CHECK-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3] - ; CHECK-NEXT: vpaddd %xmm1, %xmm0, %xmm0 - ; CHECK-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/dpbusd.ll b/llvm/test/CodeGen/X86/dpbusd.ll ---- a/llvm/test/CodeGen/X86/dpbusd.ll -+++ b/llvm/test/CodeGen/X86/dpbusd.ll -@@ -26,7 +26,7 @@ - ; AVX512-NEXT: vpmovzxbw {{.*#+}} ymm1 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero - ; AVX512-NEXT: vpmaddwd %ymm0, %ymm1, %ymm0 - ; AVX512-NEXT: vextracti128 $1, %ymm0, %xmm1 --; AVX512-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -+; AVX512-NEXT: vpaddd %ymm1, %ymm0, %ymm0 - ; AVX512-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3] - ; AVX512-NEXT: vpaddd %xmm1, %xmm0, %xmm0 - ; AVX512-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-interleaved-load-i16-stride-3.ll b/llvm/test/CodeGen/X86/vector-interleaved-load-i16-stride-3.ll ---- a/llvm/test/CodeGen/X86/vector-interleaved-load-i16-stride-3.ll -+++ b/llvm/test/CodeGen/X86/vector-interleaved-load-i16-stride-3.ll -@@ -1828,22 +1828,22 @@ - ; - ; AVX512-LABEL: load_i16_stride3_vf32: - ; AVX512: # %bb.0: --; AVX512-NEXT: vmovdqa {{.*#+}} ymm0 = [65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535] -+; AVX512-NEXT: vmovdqa {{.*#+}} ymm1 = [65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535] - ; AVX512-NEXT: vmovdqa 128(%rdi), %ymm5 - ; AVX512-NEXT: vmovdqa 160(%rdi), %ymm6 --; AVX512-NEXT: vmovdqa %ymm0, %ymm1 --; AVX512-NEXT: vpternlogq $202, %ymm5, %ymm6, %ymm1 --; AVX512-NEXT: vpermq {{.*#+}} ymm2 = ymm1[2,3,0,1] --; AVX512-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0],ymm2[1],ymm1[2,3],ymm2[4],ymm1[5,6],ymm2[7],ymm1[8],ymm2[9],ymm1[10,11],ymm2[12],ymm1[13,14],ymm2[15] --; AVX512-NEXT: vpshufb {{.*#+}} ymm3 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,4,5,10,11,16,17,22,23,28,29,18,19,24,25,30,31,20,21,26,27] --; AVX512-NEXT: vmovdqa 112(%rdi), %xmm1 -+; AVX512-NEXT: vmovdqa %ymm1, %ymm0 -+; AVX512-NEXT: vpternlogq $202, %ymm5, %ymm6, %ymm0 -+; AVX512-NEXT: vpermq {{.*#+}} ymm2 = ymm0[2,3,0,1] -+; AVX512-NEXT: vpblendw {{.*#+}} ymm0 = ymm0[0],ymm2[1],ymm0[2,3],ymm2[4],ymm0[5,6],ymm2[7],ymm0[8],ymm2[9],ymm0[10,11],ymm2[12],ymm0[13,14],ymm2[15] -+; AVX512-NEXT: vpshufb {{.*#+}} ymm3 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,4,5,10,11,16,17,22,23,28,29,18,19,24,25,30,31,20,21,26,27] -+; AVX512-NEXT: vmovdqa 112(%rdi), %xmm0 
- ; AVX512-NEXT: vmovdqa 96(%rdi), %xmm2 --; AVX512-NEXT: vpblendw {{.*#+}} xmm4 = xmm2[0],xmm1[1],xmm2[2,3],xmm1[4],xmm2[5,6],xmm1[7] -+; AVX512-NEXT: vpblendw {{.*#+}} xmm4 = xmm2[0],xmm0[1],xmm2[2,3],xmm0[4],xmm2[5,6],xmm0[7] - ; AVX512-NEXT: vpshufb {{.*#+}} xmm4 = xmm4[0,1,6,7,12,13,2,3,8,9,14,15,u,u,u,u] - ; AVX512-NEXT: vpblendd {{.*#+}} ymm7 = ymm4[0,1,2],ymm3[3,4,5,6,7] - ; AVX512-NEXT: vmovdqa (%rdi), %ymm8 - ; AVX512-NEXT: vmovdqa 32(%rdi), %ymm9 --; AVX512-NEXT: vmovdqa %ymm0, %ymm3 -+; AVX512-NEXT: vmovdqa %ymm1, %ymm3 - ; AVX512-NEXT: vpternlogq $202, %ymm9, %ymm8, %ymm3 - ; AVX512-NEXT: vpermq {{.*#+}} ymm4 = ymm3[2,3,0,1] - ; AVX512-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm4[1],ymm3[2,3],ymm4[4],ymm3[5,6],ymm4[7],ymm3[8],ymm4[9],ymm3[10,11],ymm4[12],ymm3[13,14],ymm4[15] -@@ -1857,14 +1857,14 @@ - ; AVX512-NEXT: vpshufhw {{.*#+}} xmm10 = xmm10[0,1,2,3,6,5,4,7] - ; AVX512-NEXT: vpblendd {{.*#+}} ymm10 = ymm10[0,1,2,3],ymm11[4,5,6,7] - ; AVX512-NEXT: vinserti64x4 $1, %ymm7, %zmm10, %zmm7 --; AVX512-NEXT: vmovdqa %ymm0, %ymm10 -+; AVX512-NEXT: vmovdqa %ymm1, %ymm10 - ; AVX512-NEXT: vpternlogq $202, %ymm6, %ymm5, %ymm10 - ; AVX512-NEXT: vpermq {{.*#+}} ymm11 = ymm10[2,3,0,1] - ; AVX512-NEXT: vpblendw {{.*#+}} ymm10 = ymm10[0,1],ymm11[2],ymm10[3,4],ymm11[5],ymm10[6,7,8,9],ymm11[10],ymm10[11,12],ymm11[13],ymm10[14,15] - ; AVX512-NEXT: vmovdqa {{.*#+}} ymm11 = [2,3,8,9,14,15,4,5,10,11,0,1,6,7,12,13,18,19,24,25,30,31,20,21,26,27,16,17,22,23,28,29] - ; AVX512-NEXT: vpshufb %ymm11, %ymm10, %ymm10 --; AVX512-NEXT: vpblendw {{.*#+}} xmm12 = xmm2[0,1],xmm1[2],xmm2[3,4],xmm1[5],xmm2[6,7] --; AVX512-NEXT: vpshufb {{.*#+}} xmm12 = xmm12[2,3,8,9,14,15,4,5,10,11,10,11,10,11,10,11] -+; AVX512-NEXT: vpblendw {{.*#+}} xmm12 = xmm2[0,1],xmm0[2],xmm2[3,4],xmm0[5],xmm2[6,7] -+; AVX512-NEXT: vpshufb %xmm11, %xmm12, %xmm12 - ; AVX512-NEXT: vpblendw {{.*#+}} xmm12 = xmm12[0,1,2,3,4],xmm10[5,6,7] - ; AVX512-NEXT: vpblendd {{.*#+}} ymm10 = ymm12[0,1,2,3],ymm10[4,5,6,7] - ; AVX512-NEXT: vmovdqa {{.*#+}} ymm12 = [65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535] -@@ -1885,19 +1885,21 @@ - ; AVX512-NEXT: vpblendw {{.*#+}} ymm5 = ymm5[0],ymm12[1,2],ymm5[3],ymm12[4,5],ymm5[6],ymm12[7],ymm5[8],ymm12[9,10],ymm5[11],ymm12[12,13],ymm5[14],ymm12[15] - ; AVX512-NEXT: vmovdqa {{.*#+}} ymm6 = [4,5,10,11,0,1,6,7,12,13,2,3,8,9,14,15,20,21,26,27,16,17,22,23,28,29,18,19,24,25,30,31] - ; AVX512-NEXT: vpshufb %ymm6, %ymm5, %ymm5 --; AVX512-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0,1],xmm2[2],xmm1[3,4],xmm2[5],xmm1[6,7] --; AVX512-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[4,5,10,11,0,1,6,7,12,13,14,15,0,1,2,3] --; AVX512-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0,1,2,3,4],xmm5[5,6,7] --; AVX512-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm5[4,5,6,7] --; AVX512-NEXT: vpternlogq $202, %ymm8, %ymm9, %ymm0 --; AVX512-NEXT: vpermq {{.*#+}} ymm2 = ymm0[2,3,0,1] --; AVX512-NEXT: vpblendw {{.*#+}} ymm0 = ymm2[0],ymm0[1,2],ymm2[3],ymm0[4,5],ymm2[6],ymm0[7],ymm2[8],ymm0[9,10],ymm2[11],ymm0[12,13],ymm2[14],ymm0[15] --; AVX512-NEXT: vpshufb %ymm6, %ymm0, %ymm0 --; AVX512-NEXT: vpblendw {{.*#+}} xmm2 = xmm4[0],xmm3[1],xmm4[2,3],xmm3[4],xmm4[5,6],xmm3[7] --; AVX512-NEXT: vpshufb %xmm6, %xmm2, %xmm2 --; AVX512-NEXT: vinserti128 $1, %xmm2, %ymm0, %ymm2 --; AVX512-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3,4],ymm2[5,6,7] --; AVX512-NEXT: vinserti64x4 $1, %ymm1, %zmm0, %zmm0 -+; AVX512-NEXT: vpternlogq $202, %ymm8, %ymm9, %ymm1 -+; AVX512-NEXT: vpermq {{.*#+}} ymm8 = ymm1[2,3,0,1] -+; AVX512-NEXT: vpblendw {{.*#+}} ymm1 = 
ymm8[0],ymm1[1,2],ymm8[3],ymm1[4,5],ymm8[6],ymm1[7],ymm8[8],ymm1[9,10],ymm8[11],ymm1[12,13],ymm8[14],ymm1[15] -+; AVX512-NEXT: vpshufb %ymm6, %ymm1, %ymm1 -+; AVX512-NEXT: vpblendw {{.*#+}} xmm3 = xmm4[0],xmm3[1],xmm4[2,3],xmm3[4],xmm4[5,6],xmm3[7] -+; AVX512-NEXT: vpshufb %xmm6, %xmm3, %xmm3 -+; AVX512-NEXT: vinserti128 $1, %xmm3, %ymm0, %ymm3 -+; AVX512-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3,4],ymm3[5,6,7] -+; AVX512-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2],xmm0[3,4],xmm2[5],xmm0[6,7] -+; AVX512-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[4,5,10,11,0,1,6,7,12,13,14,15,0,1,2,3] -+; AVX512-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 -+; AVX512-NEXT: vextracti32x4 $2, %zmm0, %xmm0 -+; AVX512-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1,2,3,4],xmm5[5,6,7] -+; AVX512-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3],ymm5[4,5,6,7] -+; AVX512-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 - ; AVX512-NEXT: vmovdqa64 %zmm7, (%rsi) - ; AVX512-NEXT: vmovdqa64 %zmm10, (%rdx) - ; AVX512-NEXT: vmovdqa64 %zmm0, (%rcx) -@@ -1906,22 +1908,22 @@ - ; - ; AVX512-FCP-LABEL: load_i16_stride3_vf32: - ; AVX512-FCP: # %bb.0: --; AVX512-FCP-NEXT: vmovdqa {{.*#+}} ymm0 = [65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535] -+; AVX512-FCP-NEXT: vmovdqa {{.*#+}} ymm1 = [65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535] - ; AVX512-FCP-NEXT: vmovdqa 128(%rdi), %ymm5 - ; AVX512-FCP-NEXT: vmovdqa 160(%rdi), %ymm6 --; AVX512-FCP-NEXT: vmovdqa %ymm0, %ymm1 --; AVX512-FCP-NEXT: vpternlogq $202, %ymm5, %ymm6, %ymm1 --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm2 = ymm1[2,3,0,1] --; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0],ymm2[1],ymm1[2,3],ymm2[4],ymm1[5,6],ymm2[7],ymm1[8],ymm2[9],ymm1[10,11],ymm2[12],ymm1[13,14],ymm2[15] --; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm3 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,4,5,10,11,16,17,22,23,28,29,18,19,24,25,30,31,20,21,26,27] --; AVX512-FCP-NEXT: vmovdqa 112(%rdi), %xmm1 -+; AVX512-FCP-NEXT: vmovdqa %ymm1, %ymm0 -+; AVX512-FCP-NEXT: vpternlogq $202, %ymm5, %ymm6, %ymm0 -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm2 = ymm0[2,3,0,1] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm0 = ymm0[0],ymm2[1],ymm0[2,3],ymm2[4],ymm0[5,6],ymm2[7],ymm0[8],ymm2[9],ymm0[10,11],ymm2[12],ymm0[13,14],ymm2[15] -+; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm3 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,4,5,10,11,16,17,22,23,28,29,18,19,24,25,30,31,20,21,26,27] -+; AVX512-FCP-NEXT: vmovdqa 112(%rdi), %xmm0 - ; AVX512-FCP-NEXT: vmovdqa 96(%rdi), %xmm2 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm2[0],xmm1[1],xmm2[2,3],xmm1[4],xmm2[5,6],xmm1[7] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm2[0],xmm0[1],xmm2[2,3],xmm0[4],xmm2[5,6],xmm0[7] - ; AVX512-FCP-NEXT: vpshufb {{.*#+}} xmm4 = xmm4[0,1,6,7,12,13,2,3,8,9,14,15,u,u,u,u] - ; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm7 = ymm4[0,1,2],ymm3[3,4,5,6,7] - ; AVX512-FCP-NEXT: vmovdqa (%rdi), %ymm8 - ; AVX512-FCP-NEXT: vmovdqa 32(%rdi), %ymm9 --; AVX512-FCP-NEXT: vmovdqa %ymm0, %ymm3 -+; AVX512-FCP-NEXT: vmovdqa %ymm1, %ymm3 - ; AVX512-FCP-NEXT: vpternlogq $202, %ymm9, %ymm8, %ymm3 - ; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm4 = ymm3[2,3,0,1] - ; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm4[1],ymm3[2,3],ymm4[4],ymm3[5,6],ymm4[7],ymm3[8],ymm4[9],ymm3[10,11],ymm4[12],ymm3[13,14],ymm4[15] -@@ -1935,14 +1937,14 @@ - ; AVX512-FCP-NEXT: vpshufhw {{.*#+}} xmm10 = xmm10[0,1,2,3,6,5,4,7] - ; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm10 = ymm10[0,1,2,3],ymm11[4,5,6,7] - ; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm7, %zmm10, %zmm7 --; AVX512-FCP-NEXT: vmovdqa %ymm0, %ymm10 
-+; AVX512-FCP-NEXT: vmovdqa %ymm1, %ymm10 - ; AVX512-FCP-NEXT: vpternlogq $202, %ymm6, %ymm5, %ymm10 - ; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm11 = ymm10[2,3,0,1] - ; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm10 = ymm10[0,1],ymm11[2],ymm10[3,4],ymm11[5],ymm10[6,7,8,9],ymm11[10],ymm10[11,12],ymm11[13],ymm10[14,15] - ; AVX512-FCP-NEXT: vmovdqa {{.*#+}} ymm11 = [2,3,8,9,14,15,4,5,10,11,0,1,6,7,12,13,18,19,24,25,30,31,20,21,26,27,16,17,22,23,28,29] - ; AVX512-FCP-NEXT: vpshufb %ymm11, %ymm10, %ymm10 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm12 = xmm2[0,1],xmm1[2],xmm2[3,4],xmm1[5],xmm2[6,7] --; AVX512-FCP-NEXT: vpshufb {{.*#+}} xmm12 = xmm12[2,3,8,9,14,15,4,5,10,11,10,11,10,11,10,11] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm12 = xmm2[0,1],xmm0[2],xmm2[3,4],xmm0[5],xmm2[6,7] -+; AVX512-FCP-NEXT: vpshufb %xmm11, %xmm12, %xmm12 - ; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm12 = xmm12[0,1,2,3,4],xmm10[5,6,7] - ; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm10 = ymm12[0,1,2,3],ymm10[4,5,6,7] - ; AVX512-FCP-NEXT: vmovdqa {{.*#+}} ymm12 = [65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535] -@@ -1963,19 +1965,21 @@ - ; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm5 = ymm5[0],ymm12[1,2],ymm5[3],ymm12[4,5],ymm5[6],ymm12[7],ymm5[8],ymm12[9,10],ymm5[11],ymm12[12,13],ymm5[14],ymm12[15] - ; AVX512-FCP-NEXT: vmovdqa {{.*#+}} ymm6 = [4,5,10,11,0,1,6,7,12,13,2,3,8,9,14,15,20,21,26,27,16,17,22,23,28,29,18,19,24,25,30,31] - ; AVX512-FCP-NEXT: vpshufb %ymm6, %ymm5, %ymm5 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0,1],xmm2[2],xmm1[3,4],xmm2[5],xmm1[6,7] --; AVX512-FCP-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[4,5,10,11,0,1,6,7,12,13,14,15,0,1,2,3] --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0,1,2,3,4],xmm5[5,6,7] --; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm5[4,5,6,7] --; AVX512-FCP-NEXT: vpternlogq $202, %ymm8, %ymm9, %ymm0 --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm2 = ymm0[2,3,0,1] --; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm0 = ymm2[0],ymm0[1,2],ymm2[3],ymm0[4,5],ymm2[6],ymm0[7],ymm2[8],ymm0[9,10],ymm2[11],ymm0[12,13],ymm2[14],ymm0[15] --; AVX512-FCP-NEXT: vpshufb %ymm6, %ymm0, %ymm0 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm2 = xmm4[0],xmm3[1],xmm4[2,3],xmm3[4],xmm4[5,6],xmm3[7] --; AVX512-FCP-NEXT: vpshufb %xmm6, %xmm2, %xmm2 --; AVX512-FCP-NEXT: vinserti128 $1, %xmm2, %ymm0, %ymm2 --; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3,4],ymm2[5,6,7] --; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm1, %zmm0, %zmm0 -+; AVX512-FCP-NEXT: vpternlogq $202, %ymm8, %ymm9, %ymm1 -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm8 = ymm1[2,3,0,1] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm1 = ymm8[0],ymm1[1,2],ymm8[3],ymm1[4,5],ymm8[6],ymm1[7],ymm8[8],ymm1[9,10],ymm8[11],ymm1[12,13],ymm8[14],ymm1[15] -+; AVX512-FCP-NEXT: vpshufb %ymm6, %ymm1, %ymm1 -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm3 = xmm4[0],xmm3[1],xmm4[2,3],xmm3[4],xmm4[5,6],xmm3[7] -+; AVX512-FCP-NEXT: vpshufb %xmm6, %xmm3, %xmm3 -+; AVX512-FCP-NEXT: vinserti128 $1, %xmm3, %ymm0, %ymm3 -+; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3,4],ymm3[5,6,7] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2],xmm0[3,4],xmm2[5],xmm0[6,7] -+; AVX512-FCP-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[4,5,10,11,0,1,6,7,12,13,14,15,0,1,2,3] -+; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 -+; AVX512-FCP-NEXT: vextracti32x4 $2, %zmm0, %xmm0 -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1,2,3,4],xmm5[5,6,7] -+; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3],ymm5[4,5,6,7] -+; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm0, %zmm1, 
%zmm0 - ; AVX512-FCP-NEXT: vmovdqa64 %zmm7, (%rsi) - ; AVX512-FCP-NEXT: vmovdqa64 %zmm10, (%rdx) - ; AVX512-FCP-NEXT: vmovdqa64 %zmm0, (%rcx) -@@ -1984,22 +1988,22 @@ - ; - ; AVX512DQ-LABEL: load_i16_stride3_vf32: - ; AVX512DQ: # %bb.0: --; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm0 = [65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535] -+; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm1 = [65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535] - ; AVX512DQ-NEXT: vmovdqa 128(%rdi), %ymm5 - ; AVX512DQ-NEXT: vmovdqa 160(%rdi), %ymm6 --; AVX512DQ-NEXT: vmovdqa %ymm0, %ymm1 --; AVX512DQ-NEXT: vpternlogq $202, %ymm5, %ymm6, %ymm1 --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm2 = ymm1[2,3,0,1] --; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0],ymm2[1],ymm1[2,3],ymm2[4],ymm1[5,6],ymm2[7],ymm1[8],ymm2[9],ymm1[10,11],ymm2[12],ymm1[13,14],ymm2[15] --; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm3 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,4,5,10,11,16,17,22,23,28,29,18,19,24,25,30,31,20,21,26,27] --; AVX512DQ-NEXT: vmovdqa 112(%rdi), %xmm1 -+; AVX512DQ-NEXT: vmovdqa %ymm1, %ymm0 -+; AVX512DQ-NEXT: vpternlogq $202, %ymm5, %ymm6, %ymm0 -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm2 = ymm0[2,3,0,1] -+; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm0 = ymm0[0],ymm2[1],ymm0[2,3],ymm2[4],ymm0[5,6],ymm2[7],ymm0[8],ymm2[9],ymm0[10,11],ymm2[12],ymm0[13,14],ymm2[15] -+; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm3 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,4,5,10,11,16,17,22,23,28,29,18,19,24,25,30,31,20,21,26,27] -+; AVX512DQ-NEXT: vmovdqa 112(%rdi), %xmm0 - ; AVX512DQ-NEXT: vmovdqa 96(%rdi), %xmm2 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm4 = xmm2[0],xmm1[1],xmm2[2,3],xmm1[4],xmm2[5,6],xmm1[7] -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm4 = xmm2[0],xmm0[1],xmm2[2,3],xmm0[4],xmm2[5,6],xmm0[7] - ; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm4 = xmm4[0,1,6,7,12,13,2,3,8,9,14,15,u,u,u,u] - ; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm7 = ymm4[0,1,2],ymm3[3,4,5,6,7] - ; AVX512DQ-NEXT: vmovdqa (%rdi), %ymm8 - ; AVX512DQ-NEXT: vmovdqa 32(%rdi), %ymm9 --; AVX512DQ-NEXT: vmovdqa %ymm0, %ymm3 -+; AVX512DQ-NEXT: vmovdqa %ymm1, %ymm3 - ; AVX512DQ-NEXT: vpternlogq $202, %ymm9, %ymm8, %ymm3 - ; AVX512DQ-NEXT: vpermq {{.*#+}} ymm4 = ymm3[2,3,0,1] - ; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm4[1],ymm3[2,3],ymm4[4],ymm3[5,6],ymm4[7],ymm3[8],ymm4[9],ymm3[10,11],ymm4[12],ymm3[13,14],ymm4[15] -@@ -2013,14 +2017,14 @@ - ; AVX512DQ-NEXT: vpshufhw {{.*#+}} xmm10 = xmm10[0,1,2,3,6,5,4,7] - ; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm10 = ymm10[0,1,2,3],ymm11[4,5,6,7] - ; AVX512DQ-NEXT: vinserti64x4 $1, %ymm7, %zmm10, %zmm7 --; AVX512DQ-NEXT: vmovdqa %ymm0, %ymm10 -+; AVX512DQ-NEXT: vmovdqa %ymm1, %ymm10 - ; AVX512DQ-NEXT: vpternlogq $202, %ymm6, %ymm5, %ymm10 - ; AVX512DQ-NEXT: vpermq {{.*#+}} ymm11 = ymm10[2,3,0,1] - ; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm10 = ymm10[0,1],ymm11[2],ymm10[3,4],ymm11[5],ymm10[6,7,8,9],ymm11[10],ymm10[11,12],ymm11[13],ymm10[14,15] - ; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm11 = [2,3,8,9,14,15,4,5,10,11,0,1,6,7,12,13,18,19,24,25,30,31,20,21,26,27,16,17,22,23,28,29] - ; AVX512DQ-NEXT: vpshufb %ymm11, %ymm10, %ymm10 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm12 = xmm2[0,1],xmm1[2],xmm2[3,4],xmm1[5],xmm2[6,7] --; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm12 = xmm12[2,3,8,9,14,15,4,5,10,11,10,11,10,11,10,11] -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm12 = xmm2[0,1],xmm0[2],xmm2[3,4],xmm0[5],xmm2[6,7] -+; AVX512DQ-NEXT: vpshufb %xmm11, %xmm12, %xmm12 - ; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm12 = xmm12[0,1,2,3,4],xmm10[5,6,7] - ; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm10 = 
ymm12[0,1,2,3],ymm10[4,5,6,7] - ; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm12 = [65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535] -@@ -2041,19 +2045,21 @@ - ; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm5 = ymm5[0],ymm12[1,2],ymm5[3],ymm12[4,5],ymm5[6],ymm12[7],ymm5[8],ymm12[9,10],ymm5[11],ymm12[12,13],ymm5[14],ymm12[15] - ; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm6 = [4,5,10,11,0,1,6,7,12,13,2,3,8,9,14,15,20,21,26,27,16,17,22,23,28,29,18,19,24,25,30,31] - ; AVX512DQ-NEXT: vpshufb %ymm6, %ymm5, %ymm5 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0,1],xmm2[2],xmm1[3,4],xmm2[5],xmm1[6,7] --; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[4,5,10,11,0,1,6,7,12,13,14,15,0,1,2,3] --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0,1,2,3,4],xmm5[5,6,7] --; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm5[4,5,6,7] --; AVX512DQ-NEXT: vpternlogq $202, %ymm8, %ymm9, %ymm0 --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm2 = ymm0[2,3,0,1] --; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm0 = ymm2[0],ymm0[1,2],ymm2[3],ymm0[4,5],ymm2[6],ymm0[7],ymm2[8],ymm0[9,10],ymm2[11],ymm0[12,13],ymm2[14],ymm0[15] --; AVX512DQ-NEXT: vpshufb %ymm6, %ymm0, %ymm0 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm2 = xmm4[0],xmm3[1],xmm4[2,3],xmm3[4],xmm4[5,6],xmm3[7] --; AVX512DQ-NEXT: vpshufb %xmm6, %xmm2, %xmm2 --; AVX512DQ-NEXT: vinserti128 $1, %xmm2, %ymm0, %ymm2 --; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3,4],ymm2[5,6,7] --; AVX512DQ-NEXT: vinserti64x4 $1, %ymm1, %zmm0, %zmm0 -+; AVX512DQ-NEXT: vpternlogq $202, %ymm8, %ymm9, %ymm1 -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm8 = ymm1[2,3,0,1] -+; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm1 = ymm8[0],ymm1[1,2],ymm8[3],ymm1[4,5],ymm8[6],ymm1[7],ymm8[8],ymm1[9,10],ymm8[11],ymm1[12,13],ymm8[14],ymm1[15] -+; AVX512DQ-NEXT: vpshufb %ymm6, %ymm1, %ymm1 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm3 = xmm4[0],xmm3[1],xmm4[2,3],xmm3[4],xmm4[5,6],xmm3[7] -+; AVX512DQ-NEXT: vpshufb %xmm6, %xmm3, %xmm3 -+; AVX512DQ-NEXT: vinserti128 $1, %xmm3, %ymm0, %ymm3 -+; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3,4],ymm3[5,6,7] -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2],xmm0[3,4],xmm2[5],xmm0[6,7] -+; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[4,5,10,11,0,1,6,7,12,13,14,15,0,1,2,3] -+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 -+; AVX512DQ-NEXT: vextracti32x4 $2, %zmm0, %xmm0 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1,2,3,4],xmm5[5,6,7] -+; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3],ymm5[4,5,6,7] -+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 - ; AVX512DQ-NEXT: vmovdqa64 %zmm7, (%rsi) - ; AVX512DQ-NEXT: vmovdqa64 %zmm10, (%rdx) - ; AVX512DQ-NEXT: vmovdqa64 %zmm0, (%rcx) -@@ -2062,22 +2068,22 @@ - ; - ; AVX512DQ-FCP-LABEL: load_i16_stride3_vf32: - ; AVX512DQ-FCP: # %bb.0: --; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} ymm0 = [65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535] -+; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} ymm1 = [65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535] - ; AVX512DQ-FCP-NEXT: vmovdqa 128(%rdi), %ymm5 - ; AVX512DQ-FCP-NEXT: vmovdqa 160(%rdi), %ymm6 --; AVX512DQ-FCP-NEXT: vmovdqa %ymm0, %ymm1 --; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm5, %ymm6, %ymm1 --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm2 = ymm1[2,3,0,1] --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0],ymm2[1],ymm1[2,3],ymm2[4],ymm1[5,6],ymm2[7],ymm1[8],ymm2[9],ymm1[10,11],ymm2[12],ymm1[13,14],ymm2[15] --; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm3 = 
ymm1[u,u,u,u,u,u,u,u,u,u,u,u,4,5,10,11,16,17,22,23,28,29,18,19,24,25,30,31,20,21,26,27] --; AVX512DQ-FCP-NEXT: vmovdqa 112(%rdi), %xmm1 -+; AVX512DQ-FCP-NEXT: vmovdqa %ymm1, %ymm0 -+; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm5, %ymm6, %ymm0 -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm2 = ymm0[2,3,0,1] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm0 = ymm0[0],ymm2[1],ymm0[2,3],ymm2[4],ymm0[5,6],ymm2[7],ymm0[8],ymm2[9],ymm0[10,11],ymm2[12],ymm0[13,14],ymm2[15] -+; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm3 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,4,5,10,11,16,17,22,23,28,29,18,19,24,25,30,31,20,21,26,27] -+; AVX512DQ-FCP-NEXT: vmovdqa 112(%rdi), %xmm0 - ; AVX512DQ-FCP-NEXT: vmovdqa 96(%rdi), %xmm2 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm2[0],xmm1[1],xmm2[2,3],xmm1[4],xmm2[5,6],xmm1[7] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm2[0],xmm0[1],xmm2[2,3],xmm0[4],xmm2[5,6],xmm0[7] - ; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} xmm4 = xmm4[0,1,6,7,12,13,2,3,8,9,14,15,u,u,u,u] - ; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm7 = ymm4[0,1,2],ymm3[3,4,5,6,7] - ; AVX512DQ-FCP-NEXT: vmovdqa (%rdi), %ymm8 - ; AVX512DQ-FCP-NEXT: vmovdqa 32(%rdi), %ymm9 --; AVX512DQ-FCP-NEXT: vmovdqa %ymm0, %ymm3 -+; AVX512DQ-FCP-NEXT: vmovdqa %ymm1, %ymm3 - ; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm9, %ymm8, %ymm3 - ; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm4 = ymm3[2,3,0,1] - ; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm4[1],ymm3[2,3],ymm4[4],ymm3[5,6],ymm4[7],ymm3[8],ymm4[9],ymm3[10,11],ymm4[12],ymm3[13,14],ymm4[15] -@@ -2091,14 +2097,14 @@ - ; AVX512DQ-FCP-NEXT: vpshufhw {{.*#+}} xmm10 = xmm10[0,1,2,3,6,5,4,7] - ; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm10 = ymm10[0,1,2,3],ymm11[4,5,6,7] - ; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm7, %zmm10, %zmm7 --; AVX512DQ-FCP-NEXT: vmovdqa %ymm0, %ymm10 -+; AVX512DQ-FCP-NEXT: vmovdqa %ymm1, %ymm10 - ; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm6, %ymm5, %ymm10 - ; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm11 = ymm10[2,3,0,1] - ; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm10 = ymm10[0,1],ymm11[2],ymm10[3,4],ymm11[5],ymm10[6,7,8,9],ymm11[10],ymm10[11,12],ymm11[13],ymm10[14,15] - ; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} ymm11 = [2,3,8,9,14,15,4,5,10,11,0,1,6,7,12,13,18,19,24,25,30,31,20,21,26,27,16,17,22,23,28,29] - ; AVX512DQ-FCP-NEXT: vpshufb %ymm11, %ymm10, %ymm10 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm12 = xmm2[0,1],xmm1[2],xmm2[3,4],xmm1[5],xmm2[6,7] --; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} xmm12 = xmm12[2,3,8,9,14,15,4,5,10,11,10,11,10,11,10,11] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm12 = xmm2[0,1],xmm0[2],xmm2[3,4],xmm0[5],xmm2[6,7] -+; AVX512DQ-FCP-NEXT: vpshufb %xmm11, %xmm12, %xmm12 - ; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm12 = xmm12[0,1,2,3,4],xmm10[5,6,7] - ; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm10 = ymm12[0,1,2,3],ymm10[4,5,6,7] - ; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} ymm12 = [65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535] -@@ -2119,19 +2125,21 @@ - ; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm5 = ymm5[0],ymm12[1,2],ymm5[3],ymm12[4,5],ymm5[6],ymm12[7],ymm5[8],ymm12[9,10],ymm5[11],ymm12[12,13],ymm5[14],ymm12[15] - ; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} ymm6 = [4,5,10,11,0,1,6,7,12,13,2,3,8,9,14,15,20,21,26,27,16,17,22,23,28,29,18,19,24,25,30,31] - ; AVX512DQ-FCP-NEXT: vpshufb %ymm6, %ymm5, %ymm5 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0,1],xmm2[2],xmm1[3,4],xmm2[5],xmm1[6,7] --; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[4,5,10,11,0,1,6,7,12,13,14,15,0,1,2,3] --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm1 = 
xmm1[0,1,2,3,4],xmm5[5,6,7] --; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm5[4,5,6,7] --; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm8, %ymm9, %ymm0 --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm2 = ymm0[2,3,0,1] --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm0 = ymm2[0],ymm0[1,2],ymm2[3],ymm0[4,5],ymm2[6],ymm0[7],ymm2[8],ymm0[9,10],ymm2[11],ymm0[12,13],ymm2[14],ymm0[15] --; AVX512DQ-FCP-NEXT: vpshufb %ymm6, %ymm0, %ymm0 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm2 = xmm4[0],xmm3[1],xmm4[2,3],xmm3[4],xmm4[5,6],xmm3[7] --; AVX512DQ-FCP-NEXT: vpshufb %xmm6, %xmm2, %xmm2 --; AVX512DQ-FCP-NEXT: vinserti128 $1, %xmm2, %ymm0, %ymm2 --; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3,4],ymm2[5,6,7] --; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm1, %zmm0, %zmm0 -+; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm8, %ymm9, %ymm1 -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm8 = ymm1[2,3,0,1] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm1 = ymm8[0],ymm1[1,2],ymm8[3],ymm1[4,5],ymm8[6],ymm1[7],ymm8[8],ymm1[9,10],ymm8[11],ymm1[12,13],ymm8[14],ymm1[15] -+; AVX512DQ-FCP-NEXT: vpshufb %ymm6, %ymm1, %ymm1 -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm3 = xmm4[0],xmm3[1],xmm4[2,3],xmm3[4],xmm4[5,6],xmm3[7] -+; AVX512DQ-FCP-NEXT: vpshufb %xmm6, %xmm3, %xmm3 -+; AVX512DQ-FCP-NEXT: vinserti128 $1, %xmm3, %ymm0, %ymm3 -+; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3,4],ymm3[5,6,7] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2],xmm0[3,4],xmm2[5],xmm0[6,7] -+; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[4,5,10,11,0,1,6,7,12,13,14,15,0,1,2,3] -+; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 -+; AVX512DQ-FCP-NEXT: vextracti32x4 $2, %zmm0, %xmm0 -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1,2,3,4],xmm5[5,6,7] -+; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3],ymm5[4,5,6,7] -+; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 - ; AVX512DQ-FCP-NEXT: vmovdqa64 %zmm7, (%rsi) - ; AVX512DQ-FCP-NEXT: vmovdqa64 %zmm10, (%rdx) - ; AVX512DQ-FCP-NEXT: vmovdqa64 %zmm0, (%rcx) -@@ -3492,668 +3500,688 @@ - ; AVX512-LABEL: load_i16_stride3_vf64: - ; AVX512: # %bb.0: - ; AVX512-NEXT: vmovdqa {{.*#+}} ymm0 = [65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535] --; AVX512-NEXT: vmovdqa64 224(%rdi), %ymm18 --; AVX512-NEXT: vmovdqa64 192(%rdi), %ymm20 -+; AVX512-NEXT: vmovdqa64 224(%rdi), %ymm20 -+; AVX512-NEXT: vmovdqa64 192(%rdi), %ymm21 - ; AVX512-NEXT: vmovdqa %ymm0, %ymm1 --; AVX512-NEXT: vpternlogq $202, %ymm18, %ymm20, %ymm1 --; AVX512-NEXT: vpermq {{.*#+}} ymm2 = ymm1[2,3,0,1] --; AVX512-NEXT: vpblendw {{.*#+}} ymm2 = ymm1[0],ymm2[1],ymm1[2,3],ymm2[4],ymm1[5,6],ymm2[7],ymm1[8],ymm2[9],ymm1[10,11],ymm2[12],ymm1[13,14],ymm2[15] --; AVX512-NEXT: vmovdqa {{.*#+}} ymm7 = [0,1,6,7,12,13,2,3,4,5,14,15,8,9,10,11,16,17,22,23,28,29,18,19,20,21,30,31,24,25,26,27] --; AVX512-NEXT: vpshufb %ymm7, %ymm2, %ymm5 --; AVX512-NEXT: vmovdqa 272(%rdi), %xmm1 -+; AVX512-NEXT: vpternlogq $202, %ymm20, %ymm21, %ymm1 -+; AVX512-NEXT: vpermq {{.*#+}} ymm3 = ymm1[2,3,0,1] -+; AVX512-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0],ymm3[1],ymm1[2,3],ymm3[4],ymm1[5,6],ymm3[7],ymm1[8],ymm3[9],ymm1[10,11],ymm3[12],ymm1[13,14],ymm3[15] -+; AVX512-NEXT: vmovdqa {{.*#+}} ymm3 = [0,1,6,7,12,13,2,3,4,5,14,15,8,9,10,11,16,17,22,23,28,29,18,19,20,21,30,31,24,25,26,27] -+; AVX512-NEXT: vpshufb %ymm3, %ymm1, %ymm5 -+; AVX512-NEXT: vmovdqa 272(%rdi), %xmm8 - ; AVX512-NEXT: vmovdqa 256(%rdi), %xmm2 --; AVX512-NEXT: vpblendw {{.*#+}} xmm6 = 
xmm2[0,1],xmm1[2],xmm2[3,4],xmm1[5],xmm2[6,7] --; AVX512-NEXT: vmovdqa %xmm2, %xmm3 --; AVX512-NEXT: vmovdqa64 %xmm1, %xmm19 --; AVX512-NEXT: vmovdqa {{.*#+}} xmm13 = [4,5,14,15,0,1,2,3,8,9,14,15,4,5,10,11] --; AVX512-NEXT: vpshufb %xmm13, %xmm6, %xmm6 -+; AVX512-NEXT: vpblendw {{.*#+}} xmm6 = xmm2[0,1],xmm8[2],xmm2[3,4],xmm8[5],xmm2[6,7] -+; AVX512-NEXT: vmovdqa %xmm2, %xmm14 -+; AVX512-NEXT: vmovdqa {{.*#+}} xmm9 = [4,5,14,15,0,1,2,3,8,9,14,15,4,5,10,11] -+; AVX512-NEXT: vpshufb %xmm9, %xmm6, %xmm6 - ; AVX512-NEXT: vinserti128 $1, %xmm6, %ymm0, %ymm6 - ; AVX512-NEXT: vpblendw {{.*#+}} ymm6 = ymm5[0,1,2],ymm6[3,4,5,6,7],ymm5[8,9,10],ymm6[11,12,13,14,15] - ; AVX512-NEXT: vpshufhw {{.*#+}} xmm5 = xmm5[0,1,2,3,6,5,4,7] - ; AVX512-NEXT: vpblendd {{.*#+}} ymm5 = ymm5[0,1,2,3],ymm6[4,5,6,7] --; AVX512-NEXT: vmovdqa64 320(%rdi), %ymm21 --; AVX512-NEXT: vmovdqa64 352(%rdi), %ymm22 --; AVX512-NEXT: vmovdqa %ymm0, %ymm8 --; AVX512-NEXT: vpternlogq $202, %ymm21, %ymm22, %ymm8 --; AVX512-NEXT: vpermq {{.*#+}} ymm9 = ymm8[2,3,0,1] --; AVX512-NEXT: vpblendw {{.*#+}} ymm8 = ymm8[0],ymm9[1],ymm8[2,3],ymm9[4],ymm8[5,6],ymm9[7],ymm8[8],ymm9[9],ymm8[10,11],ymm9[12],ymm8[13,14],ymm9[15] --; AVX512-NEXT: vmovdqa {{.*#+}} ymm10 = [0,1,6,7,12,13,2,3,8,9,14,15,4,5,10,11,16,17,22,23,28,29,18,19,24,25,30,31,20,21,26,27] --; AVX512-NEXT: vpshufb %ymm10, %ymm8, %ymm11 -+; AVX512-NEXT: vmovdqa64 320(%rdi), %ymm22 -+; AVX512-NEXT: vmovdqa64 352(%rdi), %ymm23 -+; AVX512-NEXT: vmovdqa %ymm0, %ymm6 -+; AVX512-NEXT: vpternlogq $202, %ymm22, %ymm23, %ymm6 -+; AVX512-NEXT: vpermq {{.*#+}} ymm7 = ymm6[2,3,0,1] -+; AVX512-NEXT: vpblendw {{.*#+}} ymm6 = ymm6[0],ymm7[1],ymm6[2,3],ymm7[4],ymm6[5,6],ymm7[7],ymm6[8],ymm7[9],ymm6[10,11],ymm7[12],ymm6[13,14],ymm7[15] -+; AVX512-NEXT: vmovdqa {{.*#+}} ymm11 = [0,1,6,7,12,13,2,3,8,9,14,15,4,5,10,11,16,17,22,23,28,29,18,19,24,25,30,31,20,21,26,27] -+; AVX512-NEXT: vpshufb %ymm11, %ymm6, %ymm12 - ; AVX512-NEXT: vmovdqa 304(%rdi), %xmm1 - ; AVX512-NEXT: vmovdqa 288(%rdi), %xmm2 --; AVX512-NEXT: vpblendw {{.*#+}} xmm12 = xmm2[0],xmm1[1],xmm2[2,3],xmm1[4],xmm2[5,6],xmm1[7] -+; AVX512-NEXT: vpblendw {{.*#+}} xmm13 = xmm2[0],xmm1[1],xmm2[2,3],xmm1[4],xmm2[5,6],xmm1[7] - ; AVX512-NEXT: vmovdqa %xmm2, %xmm4 --; AVX512-NEXT: vmovdqa %xmm1, %xmm8 --; AVX512-NEXT: vmovdqa {{.*#+}} xmm14 = [0,1,6,7,12,13,2,3,8,9,14,15,12,13,14,15] --; AVX512-NEXT: vpshufb %xmm14, %xmm12, %xmm12 --; AVX512-NEXT: vpblendd {{.*#+}} ymm11 = ymm12[0,1,2],ymm11[3,4,5,6,7] --; AVX512-NEXT: vinserti64x4 $1, %ymm11, %zmm5, %zmm16 --; AVX512-NEXT: vmovdqa64 128(%rdi), %ymm23 --; AVX512-NEXT: vmovdqa 160(%rdi), %ymm11 -+; AVX512-NEXT: vmovdqa %xmm1, %xmm6 -+; AVX512-NEXT: vmovdqa {{.*#+}} xmm15 = [0,1,6,7,12,13,2,3,8,9,14,15,12,13,14,15] -+; AVX512-NEXT: vpshufb %xmm15, %xmm13, %xmm13 -+; AVX512-NEXT: vpblendd {{.*#+}} ymm12 = ymm13[0,1,2],ymm12[3,4,5,6,7] -+; AVX512-NEXT: vinserti64x4 $1, %ymm12, %zmm5, %zmm16 -+; AVX512-NEXT: vmovdqa64 128(%rdi), %ymm24 -+; AVX512-NEXT: vmovdqa 160(%rdi), %ymm13 - ; AVX512-NEXT: vmovdqa %ymm0, %ymm5 --; AVX512-NEXT: vpternlogq $202, %ymm23, %ymm11, %ymm5 -+; AVX512-NEXT: vpternlogq $202, %ymm24, %ymm13, %ymm5 - ; AVX512-NEXT: vpermq {{.*#+}} ymm12 = ymm5[2,3,0,1] - ; AVX512-NEXT: vpblendw {{.*#+}} ymm5 = ymm5[0],ymm12[1],ymm5[2,3],ymm12[4],ymm5[5,6],ymm12[7],ymm5[8],ymm12[9],ymm5[10,11],ymm12[12],ymm5[13,14],ymm12[15] --; AVX512-NEXT: vpshufb %ymm10, %ymm5, %ymm10 --; AVX512-NEXT: vmovdqa 112(%rdi), %xmm15 --; AVX512-NEXT: vmovdqa 96(%rdi), %xmm5 --; AVX512-NEXT: vpblendw {{.*#+}} xmm12 = 
xmm5[0],xmm15[1],xmm5[2,3],xmm15[4],xmm5[5,6],xmm15[7] --; AVX512-NEXT: vpshufb %xmm14, %xmm12, %xmm12 --; AVX512-NEXT: vpblendd {{.*#+}} ymm6 = ymm12[0,1,2],ymm10[3,4,5,6,7] --; AVX512-NEXT: vmovdqa64 (%rdi), %ymm24 --; AVX512-NEXT: vmovdqa 32(%rdi), %ymm12 -+; AVX512-NEXT: vpshufb %ymm11, %ymm5, %ymm5 -+; AVX512-NEXT: vmovdqa 112(%rdi), %xmm11 -+; AVX512-NEXT: vmovdqa 96(%rdi), %xmm12 -+; AVX512-NEXT: vpblendw {{.*#+}} xmm10 = xmm12[0],xmm11[1],xmm12[2,3],xmm11[4],xmm12[5,6],xmm11[7] -+; AVX512-NEXT: vpshufb %xmm15, %xmm10, %xmm10 -+; AVX512-NEXT: vpblendd {{.*#+}} ymm1 = ymm10[0,1,2],ymm5[3,4,5,6,7] -+; AVX512-NEXT: vmovdqa64 (%rdi), %ymm17 -+; AVX512-NEXT: vmovdqa 32(%rdi), %ymm5 - ; AVX512-NEXT: vmovdqa %ymm0, %ymm10 --; AVX512-NEXT: vpternlogq $202, %ymm12, %ymm24, %ymm10 --; AVX512-NEXT: vpermq {{.*#+}} ymm1 = ymm10[2,3,0,1] --; AVX512-NEXT: vpblendw {{.*#+}} ymm1 = ymm10[0],ymm1[1],ymm10[2,3],ymm1[4],ymm10[5,6],ymm1[7],ymm10[8],ymm1[9],ymm10[10,11],ymm1[12],ymm10[13,14],ymm1[15] --; AVX512-NEXT: vpshufb %ymm7, %ymm1, %ymm7 -+; AVX512-NEXT: vpternlogq $202, %ymm5, %ymm17, %ymm10 -+; AVX512-NEXT: vpermq {{.*#+}} ymm15 = ymm10[2,3,0,1] -+; AVX512-NEXT: vpblendw {{.*#+}} ymm10 = ymm10[0],ymm15[1],ymm10[2,3],ymm15[4],ymm10[5,6],ymm15[7],ymm10[8],ymm15[9],ymm10[10,11],ymm15[12],ymm10[13,14],ymm15[15] -+; AVX512-NEXT: vpshufb %ymm3, %ymm10, %ymm2 - ; AVX512-NEXT: vmovdqa 80(%rdi), %xmm10 --; AVX512-NEXT: vmovdqa 64(%rdi), %xmm1 --; AVX512-NEXT: vpblendw {{.*#+}} xmm2 = xmm1[0,1],xmm10[2],xmm1[3,4],xmm10[5],xmm1[6,7] --; AVX512-NEXT: vpshufb %xmm13, %xmm2, %xmm2 --; AVX512-NEXT: vinserti128 $1, %xmm2, %ymm0, %ymm2 --; AVX512-NEXT: vpblendw {{.*#+}} ymm2 = ymm7[0,1,2],ymm2[3,4,5,6,7],ymm7[8,9,10],ymm2[11,12,13,14,15] --; AVX512-NEXT: vpshufhw {{.*#+}} xmm7 = xmm7[0,1,2,3,6,5,4,7] --; AVX512-NEXT: vpblendd {{.*#+}} ymm2 = ymm7[0,1,2,3],ymm2[4,5,6,7] --; AVX512-NEXT: vinserti64x4 $1, %ymm6, %zmm2, %zmm17 --; AVX512-NEXT: vmovdqa %ymm0, %ymm2 --; AVX512-NEXT: vpternlogq $202, %ymm22, %ymm21, %ymm2 --; AVX512-NEXT: vpermq {{.*#+}} ymm6 = ymm2[2,3,0,1] --; AVX512-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm6[2],ymm2[3,4],ymm6[5],ymm2[6,7,8,9],ymm6[10],ymm2[11,12],ymm6[13],ymm2[14,15] --; AVX512-NEXT: vmovdqa {{.*#+}} ymm9 = [2,3,8,9,14,15,4,5,10,11,0,1,6,7,12,13,18,19,24,25,30,31,20,21,26,27,16,17,22,23,28,29] --; AVX512-NEXT: vpshufb %ymm9, %ymm2, %ymm2 --; AVX512-NEXT: vpblendw {{.*#+}} xmm7 = xmm4[0,1],xmm8[2],xmm4[3,4],xmm8[5],xmm4[6,7] --; AVX512-NEXT: vmovdqa64 %xmm8, %xmm25 -+; AVX512-NEXT: vmovdqa 64(%rdi), %xmm15 -+; AVX512-NEXT: vpblendw {{.*#+}} xmm3 = xmm15[0,1],xmm10[2],xmm15[3,4],xmm10[5],xmm15[6,7] -+; AVX512-NEXT: vpshufb %xmm9, %xmm3, %xmm3 -+; AVX512-NEXT: vinserti128 $1, %xmm3, %ymm0, %ymm3 -+; AVX512-NEXT: vpblendw {{.*#+}} ymm3 = ymm2[0,1,2],ymm3[3,4,5,6,7],ymm2[8,9,10],ymm3[11,12,13,14,15] -+; AVX512-NEXT: vpshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,6,5,4,7] -+; AVX512-NEXT: vpblendd {{.*#+}} ymm2 = ymm2[0,1,2,3],ymm3[4,5,6,7] -+; AVX512-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm18 -+; AVX512-NEXT: vmovdqa %ymm0, %ymm1 -+; AVX512-NEXT: vpternlogq $202, %ymm23, %ymm22, %ymm1 -+; AVX512-NEXT: vpermq {{.*#+}} ymm2 = ymm1[2,3,0,1] -+; AVX512-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm2[2],ymm1[3,4],ymm2[5],ymm1[6,7,8,9],ymm2[10],ymm1[11,12],ymm2[13],ymm1[14,15] -+; AVX512-NEXT: vmovdqa {{.*#+}} ymm2 = [2,3,8,9,14,15,4,5,10,11,0,1,6,7,12,13,18,19,24,25,30,31,20,21,26,27,16,17,22,23,28,29] -+; AVX512-NEXT: vpshufb %ymm2, %ymm1, %ymm1 -+; AVX512-NEXT: vmovdqa64 %ymm2, %ymm28 -+; 
AVX512-NEXT: vpblendw {{.*#+}} xmm3 = xmm4[0,1],xmm6[2],xmm4[3,4],xmm6[5],xmm4[6,7] -+; AVX512-NEXT: vmovdqa64 %xmm6, %xmm25 - ; AVX512-NEXT: vmovdqa64 %xmm4, %xmm26 - ; AVX512-NEXT: vmovdqa {{.*#+}} xmm6 = [2,3,8,9,14,15,4,5,10,11,10,11,10,11,10,11] --; AVX512-NEXT: vpshufb %xmm6, %xmm7, %xmm7 --; AVX512-NEXT: vpblendw {{.*#+}} xmm7 = xmm7[0,1,2,3,4],xmm2[5,6,7] --; AVX512-NEXT: vpblendd {{.*#+}} ymm7 = ymm7[0,1,2,3],ymm2[4,5,6,7] --; AVX512-NEXT: vmovdqa {{.*#+}} ymm13 = [65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535] --; AVX512-NEXT: vmovdqa %ymm13, %ymm2 --; AVX512-NEXT: vpternlogq $202, %ymm20, %ymm18, %ymm2 --; AVX512-NEXT: vpermq {{.*#+}} ymm4 = ymm2[2,3,0,1] --; AVX512-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm4[2],ymm2[3,4],ymm4[5],ymm2[6,7,8,9],ymm4[10],ymm2[11,12],ymm4[13],ymm2[14,15] -+; AVX512-NEXT: vpshufb %xmm6, %xmm3, %xmm3 -+; AVX512-NEXT: vpblendw {{.*#+}} xmm3 = xmm3[0,1,2,3,4],xmm1[5,6,7] -+; AVX512-NEXT: vpblendd {{.*#+}} ymm3 = ymm3[0,1,2,3],ymm1[4,5,6,7] -+; AVX512-NEXT: vmovdqa {{.*#+}} ymm9 = [65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535] -+; AVX512-NEXT: vmovdqa %ymm9, %ymm1 -+; AVX512-NEXT: vpternlogq $202, %ymm21, %ymm20, %ymm1 -+; AVX512-NEXT: vpermq {{.*#+}} ymm4 = ymm1[2,3,0,1] -+; AVX512-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm4[2],ymm1[3,4],ymm4[5],ymm1[6,7,8,9],ymm4[10],ymm1[11,12],ymm4[13],ymm1[14,15] - ; AVX512-NEXT: vmovdqa {{.*#+}} ymm4 = [2,3,8,9,14,15,4,5,12,13,10,11,0,1,6,7,18,19,24,25,30,31,20,21,28,29,26,27,16,17,22,23] --; AVX512-NEXT: vpshufb %ymm4, %ymm2, %ymm2 --; AVX512-NEXT: vmovdqa64 %xmm19, %xmm8 --; AVX512-NEXT: vpblendw {{.*#+}} xmm14 = xmm8[0,1],xmm3[2],xmm8[3,4],xmm3[5],xmm8[6,7] --; AVX512-NEXT: vmovdqa64 %xmm3, %xmm27 --; AVX512-NEXT: vmovdqa {{.*#+}} xmm3 = [4,5,4,5,4,5,4,5,10,11,0,1,6,7,12,13] --; AVX512-NEXT: vpshufb %xmm3, %xmm14, %xmm14 -+; AVX512-NEXT: vpshufb %ymm4, %ymm1, %ymm1 -+; AVX512-NEXT: vmovdqa %xmm14, %xmm7 -+; AVX512-NEXT: vpblendw {{.*#+}} xmm14 = xmm8[0,1],xmm14[2],xmm8[3,4],xmm14[5],xmm8[6,7] -+; AVX512-NEXT: vmovdqa64 %xmm8, %xmm27 -+; AVX512-NEXT: vmovdqa {{.*#+}} xmm2 = [4,5,4,5,4,5,4,5,10,11,0,1,6,7,12,13] -+; AVX512-NEXT: vpshufb %xmm2, %xmm14, %xmm14 - ; AVX512-NEXT: vinserti128 $1, %xmm14, %ymm0, %ymm14 --; AVX512-NEXT: vpblendw {{.*#+}} ymm14 = ymm2[0,1,2],ymm14[3,4,5,6,7],ymm2[8,9,10],ymm14[11,12,13,14,15] --; AVX512-NEXT: vpshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,5,6,7,4] --; AVX512-NEXT: vpblendd {{.*#+}} ymm2 = ymm2[0,1,2,3],ymm14[4,5,6,7] --; AVX512-NEXT: vinserti64x4 $1, %ymm7, %zmm2, %zmm19 --; AVX512-NEXT: vmovdqa %ymm0, %ymm2 --; AVX512-NEXT: vpternlogq $202, %ymm11, %ymm23, %ymm2 --; AVX512-NEXT: vpermq {{.*#+}} ymm7 = ymm2[2,3,0,1] --; AVX512-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm7[2],ymm2[3,4],ymm7[5],ymm2[6,7,8,9],ymm7[10],ymm2[11,12],ymm7[13],ymm2[14,15] --; AVX512-NEXT: vpshufb %ymm9, %ymm2, %ymm2 --; AVX512-NEXT: vpblendw {{.*#+}} xmm7 = xmm5[0,1],xmm15[2],xmm5[3,4],xmm15[5],xmm5[6,7] --; AVX512-NEXT: vpshufb %xmm6, %xmm7, %xmm6 --; AVX512-NEXT: vpblendw {{.*#+}} xmm6 = xmm6[0,1,2,3,4],xmm2[5,6,7] --; AVX512-NEXT: vpblendd {{.*#+}} ymm2 = ymm6[0,1,2,3],ymm2[4,5,6,7] --; AVX512-NEXT: vmovdqa %ymm13, %ymm6 --; AVX512-NEXT: vpternlogq $202, %ymm24, %ymm12, %ymm6 --; AVX512-NEXT: vpermq {{.*#+}} ymm7 = ymm6[2,3,0,1] --; AVX512-NEXT: vpblendw {{.*#+}} ymm6 = ymm6[0,1],ymm7[2],ymm6[3,4],ymm7[5],ymm6[6,7,8,9],ymm7[10],ymm6[11,12],ymm7[13],ymm6[14,15] --; AVX512-NEXT: vpshufb %ymm4, %ymm6, %ymm4 --; AVX512-NEXT: vpblendw {{.*#+}} 
xmm6 = xmm10[0,1],xmm1[2],xmm10[3,4],xmm1[5],xmm10[6,7] --; AVX512-NEXT: vpshufb %xmm3, %xmm6, %xmm3 --; AVX512-NEXT: vinserti128 $1, %xmm3, %ymm0, %ymm3 --; AVX512-NEXT: vpblendw {{.*#+}} ymm3 = ymm4[0,1,2],ymm3[3,4,5,6,7],ymm4[8,9,10],ymm3[11,12,13,14,15] --; AVX512-NEXT: vpshufhw {{.*#+}} xmm4 = xmm4[0,1,2,3,5,6,7,4] --; AVX512-NEXT: vpblendd {{.*#+}} ymm3 = ymm4[0,1,2,3],ymm3[4,5,6,7] --; AVX512-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm2 --; AVX512-NEXT: vpternlogq $226, %ymm23, %ymm13, %ymm11 --; AVX512-NEXT: vpermq {{.*#+}} ymm3 = ymm11[2,3,0,1] --; AVX512-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm11[1,2],ymm3[3],ymm11[4,5],ymm3[6],ymm11[7],ymm3[8],ymm11[9,10],ymm3[11],ymm11[12,13],ymm3[14],ymm11[15] --; AVX512-NEXT: vmovdqa {{.*#+}} ymm11 = [4,5,10,11,0,1,6,7,12,13,2,3,8,9,14,15,20,21,26,27,16,17,22,23,28,29,18,19,24,25,30,31] --; AVX512-NEXT: vpshufb %ymm11, %ymm3, %ymm3 --; AVX512-NEXT: vpblendw {{.*#+}} xmm4 = xmm15[0,1],xmm5[2],xmm15[3,4],xmm5[5],xmm15[6,7] --; AVX512-NEXT: vmovdqa {{.*#+}} xmm5 = [4,5,10,11,0,1,6,7,12,13,14,15,0,1,2,3] --; AVX512-NEXT: vpshufb %xmm5, %xmm4, %xmm4 --; AVX512-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1,2,3,4],xmm3[5,6,7] --; AVX512-NEXT: vpblendd {{.*#+}} ymm3 = ymm4[0,1,2,3],ymm3[4,5,6,7] --; AVX512-NEXT: vpternlogq $226, %ymm24, %ymm0, %ymm12 --; AVX512-NEXT: vpermq {{.*#+}} ymm4 = ymm12[2,3,0,1] --; AVX512-NEXT: vpblendw {{.*#+}} ymm4 = ymm4[0],ymm12[1,2],ymm4[3],ymm12[4,5],ymm4[6],ymm12[7],ymm4[8],ymm12[9,10],ymm4[11],ymm12[12,13],ymm4[14],ymm12[15] --; AVX512-NEXT: vpshufb %ymm11, %ymm4, %ymm4 --; AVX512-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0],xmm10[1],xmm1[2,3],xmm10[4],xmm1[5,6],xmm10[7] -+; AVX512-NEXT: vpblendw {{.*#+}} ymm14 = ymm1[0,1,2],ymm14[3,4,5,6,7],ymm1[8,9,10],ymm14[11,12,13,14,15] -+; AVX512-NEXT: vpshufhw {{.*#+}} xmm1 = xmm1[0,1,2,3,5,6,7,4] -+; AVX512-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm14[4,5,6,7] -+; AVX512-NEXT: vinserti64x4 $1, %ymm3, %zmm1, %zmm19 -+; AVX512-NEXT: vmovdqa %ymm0, %ymm1 -+; AVX512-NEXT: vpternlogq $202, %ymm13, %ymm24, %ymm1 -+; AVX512-NEXT: vpermq {{.*#+}} ymm3 = ymm1[2,3,0,1] -+; AVX512-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm3[2],ymm1[3,4],ymm3[5],ymm1[6,7,8,9],ymm3[10],ymm1[11,12],ymm3[13],ymm1[14,15] -+; AVX512-NEXT: vmovdqa64 %ymm28, %ymm3 -+; AVX512-NEXT: vpshufb %ymm3, %ymm1, %ymm1 -+; AVX512-NEXT: vpblendw {{.*#+}} xmm3 = xmm12[0,1],xmm11[2],xmm12[3,4],xmm11[5],xmm12[6,7] -+; AVX512-NEXT: vpshufb %xmm6, %xmm3, %xmm3 -+; AVX512-NEXT: vpblendw {{.*#+}} xmm3 = xmm3[0,1,2,3,4],xmm1[5,6,7] -+; AVX512-NEXT: vpblendd {{.*#+}} ymm1 = ymm3[0,1,2,3],ymm1[4,5,6,7] -+; AVX512-NEXT: vmovdqa %ymm9, %ymm3 -+; AVX512-NEXT: vpternlogq $202, %ymm17, %ymm5, %ymm3 -+; AVX512-NEXT: vpermq {{.*#+}} ymm6 = ymm3[2,3,0,1] -+; AVX512-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0,1],ymm6[2],ymm3[3,4],ymm6[5],ymm3[6,7,8,9],ymm6[10],ymm3[11,12],ymm6[13],ymm3[14,15] -+; AVX512-NEXT: vpshufb %ymm4, %ymm3, %ymm3 -+; AVX512-NEXT: vpblendw {{.*#+}} xmm4 = xmm10[0,1],xmm15[2],xmm10[3,4],xmm15[5],xmm10[6,7] -+; AVX512-NEXT: vpshufb %xmm2, %xmm4, %xmm2 -+; AVX512-NEXT: vinserti128 $1, %xmm2, %ymm0, %ymm2 -+; AVX512-NEXT: vpblendw {{.*#+}} ymm2 = ymm3[0,1,2],ymm2[3,4,5,6,7],ymm3[8,9,10],ymm2[11,12,13,14,15] -+; AVX512-NEXT: vpshufhw {{.*#+}} xmm3 = xmm3[0,1,2,3,5,6,7,4] -+; AVX512-NEXT: vpblendd {{.*#+}} ymm2 = ymm3[0,1,2,3],ymm2[4,5,6,7] -+; AVX512-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm1 -+; AVX512-NEXT: vpternlogq $226, %ymm24, %ymm9, %ymm13 -+; AVX512-NEXT: vpermq {{.*#+}} ymm2 = ymm13[2,3,0,1] -+; AVX512-NEXT: vpblendw 
{{.*#+}} ymm2 = ymm2[0],ymm13[1,2],ymm2[3],ymm13[4,5],ymm2[6],ymm13[7],ymm2[8],ymm13[9,10],ymm2[11],ymm13[12,13],ymm2[14],ymm13[15] -+; AVX512-NEXT: vpternlogq $226, %ymm17, %ymm0, %ymm5 -+; AVX512-NEXT: vpermq {{.*#+}} ymm3 = ymm5[2,3,0,1] -+; AVX512-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm5[1,2],ymm3[3],ymm5[4,5],ymm3[6],ymm5[7],ymm3[8],ymm5[9,10],ymm3[11],ymm5[12,13],ymm3[14],ymm5[15] -+; AVX512-NEXT: vmovdqa {{.*#+}} ymm4 = [4,5,10,11,0,1,6,7,12,13,2,3,8,9,14,15,20,21,26,27,16,17,22,23,28,29,18,19,24,25,30,31] -+; AVX512-NEXT: vpshufb %ymm4, %ymm3, %ymm3 -+; AVX512-NEXT: vpblendw {{.*#+}} xmm5 = xmm15[0],xmm10[1],xmm15[2,3],xmm10[4],xmm15[5,6],xmm10[7] - ; AVX512-NEXT: vmovdqa {{.*#+}} xmm6 = [0,1,2,3,0,1,6,7,12,13,2,3,8,9,14,15] --; AVX512-NEXT: vpshufb %xmm6, %xmm1, %xmm1 --; AVX512-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm1 --; AVX512-NEXT: vpblendd {{.*#+}} ymm1 = ymm4[0,1,2,3,4],ymm1[5,6,7] --; AVX512-NEXT: vinserti64x4 $1, %ymm3, %zmm1, %zmm1 --; AVX512-NEXT: vpternlogq $202, %ymm21, %ymm22, %ymm13 --; AVX512-NEXT: vpermq {{.*#+}} ymm3 = ymm13[2,3,0,1] --; AVX512-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm13[1,2],ymm3[3],ymm13[4,5],ymm3[6],ymm13[7],ymm3[8],ymm13[9,10],ymm3[11],ymm13[12,13],ymm3[14],ymm13[15] --; AVX512-NEXT: vmovdqa64 %xmm25, %xmm4 --; AVX512-NEXT: vmovdqa64 %xmm26, %xmm7 --; AVX512-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1],xmm7[2],xmm4[3,4],xmm7[5],xmm4[6,7] --; AVX512-NEXT: vpshufb %xmm5, %xmm4, %xmm4 --; AVX512-NEXT: vpshufb %ymm11, %ymm3, %ymm3 --; AVX512-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1,2,3,4],xmm3[5,6,7] --; AVX512-NEXT: vpblendd {{.*#+}} ymm3 = ymm4[0,1,2,3],ymm3[4,5,6,7] --; AVX512-NEXT: vpternlogq $202, %ymm20, %ymm18, %ymm0 --; AVX512-NEXT: vpermq {{.*#+}} ymm4 = ymm0[2,3,0,1] --; AVX512-NEXT: vpblendw {{.*#+}} ymm0 = ymm4[0],ymm0[1,2],ymm4[3],ymm0[4,5],ymm4[6],ymm0[7],ymm4[8],ymm0[9,10],ymm4[11],ymm0[12,13],ymm4[14],ymm0[15] --; AVX512-NEXT: vpshufb %ymm11, %ymm0, %ymm0 -+; AVX512-NEXT: vpshufb %xmm6, %xmm5, %xmm5 -+; AVX512-NEXT: vinserti128 $1, %xmm5, %ymm0, %ymm5 -+; AVX512-NEXT: vpblendd {{.*#+}} ymm3 = ymm3[0,1,2,3,4],ymm5[5,6,7] -+; AVX512-NEXT: vpshufb %ymm4, %ymm2, %ymm2 -+; AVX512-NEXT: vpblendw {{.*#+}} xmm5 = xmm11[0,1],xmm12[2],xmm11[3,4],xmm12[5],xmm11[6,7] -+; AVX512-NEXT: vmovdqa {{.*#+}} xmm8 = [4,5,10,11,0,1,6,7,12,13,14,15,0,1,2,3] -+; AVX512-NEXT: vpshufb %xmm8, %xmm5, %xmm5 -+; AVX512-NEXT: vinserti64x4 $1, %ymm5, %zmm3, %zmm5 -+; AVX512-NEXT: vextracti32x4 $2, %zmm5, %xmm5 -+; AVX512-NEXT: vpblendw {{.*#+}} xmm5 = xmm5[0,1,2,3,4],xmm2[5,6,7] -+; AVX512-NEXT: vpblendd {{.*#+}} ymm2 = ymm5[0,1,2,3],ymm2[4,5,6,7] -+; AVX512-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm2 -+; AVX512-NEXT: vpternlogq $202, %ymm22, %ymm23, %ymm9 -+; AVX512-NEXT: vpermq {{.*#+}} ymm3 = ymm9[2,3,0,1] -+; AVX512-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm9[1,2],ymm3[3],ymm9[4,5],ymm3[6],ymm9[7],ymm3[8],ymm9[9,10],ymm3[11],ymm9[12,13],ymm3[14],ymm9[15] -+; AVX512-NEXT: vpternlogq $202, %ymm21, %ymm20, %ymm0 -+; AVX512-NEXT: vpermq {{.*#+}} ymm5 = ymm0[2,3,0,1] -+; AVX512-NEXT: vpblendw {{.*#+}} ymm0 = ymm5[0],ymm0[1,2],ymm5[3],ymm0[4,5],ymm5[6],ymm0[7],ymm5[8],ymm0[9,10],ymm5[11],ymm0[12,13],ymm5[14],ymm0[15] -+; AVX512-NEXT: vpshufb %ymm4, %ymm3, %ymm3 -+; AVX512-NEXT: vpshufb %ymm4, %ymm0, %ymm0 - ; AVX512-NEXT: vmovdqa64 %xmm27, %xmm4 --; AVX512-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0],xmm8[1],xmm4[2,3],xmm8[4],xmm4[5,6],xmm8[7] -+; AVX512-NEXT: vpblendw {{.*#+}} xmm4 = xmm7[0],xmm4[1],xmm7[2,3],xmm4[4],xmm7[5,6],xmm4[7] - ; AVX512-NEXT: vpshufb %xmm6, 
%xmm4, %xmm4 - ; AVX512-NEXT: vinserti128 $1, %xmm4, %ymm0, %ymm4 - ; AVX512-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3,4],ymm4[5,6,7] -+; AVX512-NEXT: vmovdqa64 %xmm25, %xmm4 -+; AVX512-NEXT: vmovdqa64 %xmm26, %xmm5 -+; AVX512-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1],xmm5[2],xmm4[3,4],xmm5[5],xmm4[6,7] -+; AVX512-NEXT: vpshufb %xmm8, %xmm4, %xmm4 -+; AVX512-NEXT: vinserti64x4 $1, %ymm4, %zmm0, %zmm4 -+; AVX512-NEXT: vextracti32x4 $2, %zmm4, %xmm4 -+; AVX512-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1,2,3,4],xmm3[5,6,7] -+; AVX512-NEXT: vpblendd {{.*#+}} ymm3 = ymm4[0,1,2,3],ymm3[4,5,6,7] - ; AVX512-NEXT: vinserti64x4 $1, %ymm3, %zmm0, %zmm0 --; AVX512-NEXT: vmovdqa64 %zmm17, (%rsi) -+; AVX512-NEXT: vmovdqa64 %zmm18, (%rsi) - ; AVX512-NEXT: vmovdqa64 %zmm16, 64(%rsi) - ; AVX512-NEXT: vmovdqa64 %zmm19, 64(%rdx) --; AVX512-NEXT: vmovdqa64 %zmm2, (%rdx) -+; AVX512-NEXT: vmovdqa64 %zmm1, (%rdx) - ; AVX512-NEXT: vmovdqa64 %zmm0, 64(%rcx) --; AVX512-NEXT: vmovdqa64 %zmm1, (%rcx) -+; AVX512-NEXT: vmovdqa64 %zmm2, (%rcx) - ; AVX512-NEXT: vzeroupper - ; AVX512-NEXT: retq - ; - ; AVX512-FCP-LABEL: load_i16_stride3_vf64: - ; AVX512-FCP: # %bb.0: - ; AVX512-FCP-NEXT: vmovdqa {{.*#+}} ymm0 = [65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535] --; AVX512-FCP-NEXT: vmovdqa64 224(%rdi), %ymm18 --; AVX512-FCP-NEXT: vmovdqa64 192(%rdi), %ymm20 -+; AVX512-FCP-NEXT: vmovdqa64 224(%rdi), %ymm20 -+; AVX512-FCP-NEXT: vmovdqa64 192(%rdi), %ymm21 - ; AVX512-FCP-NEXT: vmovdqa %ymm0, %ymm1 --; AVX512-FCP-NEXT: vpternlogq $202, %ymm18, %ymm20, %ymm1 --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm2 = ymm1[2,3,0,1] --; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm2 = ymm1[0],ymm2[1],ymm1[2,3],ymm2[4],ymm1[5,6],ymm2[7],ymm1[8],ymm2[9],ymm1[10,11],ymm2[12],ymm1[13,14],ymm2[15] --; AVX512-FCP-NEXT: vmovdqa {{.*#+}} ymm7 = [0,1,6,7,12,13,2,3,4,5,14,15,8,9,10,11,16,17,22,23,28,29,18,19,20,21,30,31,24,25,26,27] --; AVX512-FCP-NEXT: vpshufb %ymm7, %ymm2, %ymm5 --; AVX512-FCP-NEXT: vmovdqa 272(%rdi), %xmm1 -+; AVX512-FCP-NEXT: vpternlogq $202, %ymm20, %ymm21, %ymm1 -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm3 = ymm1[2,3,0,1] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0],ymm3[1],ymm1[2,3],ymm3[4],ymm1[5,6],ymm3[7],ymm1[8],ymm3[9],ymm1[10,11],ymm3[12],ymm1[13,14],ymm3[15] -+; AVX512-FCP-NEXT: vmovdqa {{.*#+}} ymm3 = [0,1,6,7,12,13,2,3,4,5,14,15,8,9,10,11,16,17,22,23,28,29,18,19,20,21,30,31,24,25,26,27] -+; AVX512-FCP-NEXT: vpshufb %ymm3, %ymm1, %ymm5 -+; AVX512-FCP-NEXT: vmovdqa 272(%rdi), %xmm8 - ; AVX512-FCP-NEXT: vmovdqa 256(%rdi), %xmm2 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm6 = xmm2[0,1],xmm1[2],xmm2[3,4],xmm1[5],xmm2[6,7] --; AVX512-FCP-NEXT: vmovdqa %xmm2, %xmm3 --; AVX512-FCP-NEXT: vmovdqa64 %xmm1, %xmm19 --; AVX512-FCP-NEXT: vmovdqa {{.*#+}} xmm13 = [4,5,14,15,0,1,2,3,8,9,14,15,4,5,10,11] --; AVX512-FCP-NEXT: vpshufb %xmm13, %xmm6, %xmm6 -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm6 = xmm2[0,1],xmm8[2],xmm2[3,4],xmm8[5],xmm2[6,7] -+; AVX512-FCP-NEXT: vmovdqa %xmm2, %xmm14 -+; AVX512-FCP-NEXT: vmovdqa {{.*#+}} xmm9 = [4,5,14,15,0,1,2,3,8,9,14,15,4,5,10,11] -+; AVX512-FCP-NEXT: vpshufb %xmm9, %xmm6, %xmm6 - ; AVX512-FCP-NEXT: vinserti128 $1, %xmm6, %ymm0, %ymm6 - ; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm6 = ymm5[0,1,2],ymm6[3,4,5,6,7],ymm5[8,9,10],ymm6[11,12,13,14,15] - ; AVX512-FCP-NEXT: vpshufhw {{.*#+}} xmm5 = xmm5[0,1,2,3,6,5,4,7] - ; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm5 = ymm5[0,1,2,3],ymm6[4,5,6,7] --; AVX512-FCP-NEXT: vmovdqa64 320(%rdi), %ymm21 --; AVX512-FCP-NEXT: vmovdqa64 352(%rdi), %ymm22 --; 
AVX512-FCP-NEXT: vmovdqa %ymm0, %ymm8 --; AVX512-FCP-NEXT: vpternlogq $202, %ymm21, %ymm22, %ymm8 --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm9 = ymm8[2,3,0,1] --; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm8 = ymm8[0],ymm9[1],ymm8[2,3],ymm9[4],ymm8[5,6],ymm9[7],ymm8[8],ymm9[9],ymm8[10,11],ymm9[12],ymm8[13,14],ymm9[15] --; AVX512-FCP-NEXT: vmovdqa {{.*#+}} ymm10 = [0,1,6,7,12,13,2,3,8,9,14,15,4,5,10,11,16,17,22,23,28,29,18,19,24,25,30,31,20,21,26,27] --; AVX512-FCP-NEXT: vpshufb %ymm10, %ymm8, %ymm11 -+; AVX512-FCP-NEXT: vmovdqa64 320(%rdi), %ymm22 -+; AVX512-FCP-NEXT: vmovdqa64 352(%rdi), %ymm23 -+; AVX512-FCP-NEXT: vmovdqa %ymm0, %ymm6 -+; AVX512-FCP-NEXT: vpternlogq $202, %ymm22, %ymm23, %ymm6 -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm7 = ymm6[2,3,0,1] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm6 = ymm6[0],ymm7[1],ymm6[2,3],ymm7[4],ymm6[5,6],ymm7[7],ymm6[8],ymm7[9],ymm6[10,11],ymm7[12],ymm6[13,14],ymm7[15] -+; AVX512-FCP-NEXT: vmovdqa {{.*#+}} ymm11 = [0,1,6,7,12,13,2,3,8,9,14,15,4,5,10,11,16,17,22,23,28,29,18,19,24,25,30,31,20,21,26,27] -+; AVX512-FCP-NEXT: vpshufb %ymm11, %ymm6, %ymm12 - ; AVX512-FCP-NEXT: vmovdqa 304(%rdi), %xmm1 - ; AVX512-FCP-NEXT: vmovdqa 288(%rdi), %xmm2 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm12 = xmm2[0],xmm1[1],xmm2[2,3],xmm1[4],xmm2[5,6],xmm1[7] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm13 = xmm2[0],xmm1[1],xmm2[2,3],xmm1[4],xmm2[5,6],xmm1[7] - ; AVX512-FCP-NEXT: vmovdqa %xmm2, %xmm4 --; AVX512-FCP-NEXT: vmovdqa %xmm1, %xmm8 --; AVX512-FCP-NEXT: vmovdqa {{.*#+}} xmm14 = [0,1,6,7,12,13,2,3,8,9,14,15,12,13,14,15] --; AVX512-FCP-NEXT: vpshufb %xmm14, %xmm12, %xmm12 --; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm11 = ymm12[0,1,2],ymm11[3,4,5,6,7] --; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm11, %zmm5, %zmm16 --; AVX512-FCP-NEXT: vmovdqa64 128(%rdi), %ymm23 --; AVX512-FCP-NEXT: vmovdqa 160(%rdi), %ymm11 -+; AVX512-FCP-NEXT: vmovdqa %xmm1, %xmm6 -+; AVX512-FCP-NEXT: vmovdqa {{.*#+}} xmm15 = [0,1,6,7,12,13,2,3,8,9,14,15,12,13,14,15] -+; AVX512-FCP-NEXT: vpshufb %xmm15, %xmm13, %xmm13 -+; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm12 = ymm13[0,1,2],ymm12[3,4,5,6,7] -+; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm12, %zmm5, %zmm16 -+; AVX512-FCP-NEXT: vmovdqa64 128(%rdi), %ymm24 -+; AVX512-FCP-NEXT: vmovdqa 160(%rdi), %ymm13 - ; AVX512-FCP-NEXT: vmovdqa %ymm0, %ymm5 --; AVX512-FCP-NEXT: vpternlogq $202, %ymm23, %ymm11, %ymm5 -+; AVX512-FCP-NEXT: vpternlogq $202, %ymm24, %ymm13, %ymm5 - ; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm12 = ymm5[2,3,0,1] - ; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm5 = ymm5[0],ymm12[1],ymm5[2,3],ymm12[4],ymm5[5,6],ymm12[7],ymm5[8],ymm12[9],ymm5[10,11],ymm12[12],ymm5[13,14],ymm12[15] --; AVX512-FCP-NEXT: vpshufb %ymm10, %ymm5, %ymm10 --; AVX512-FCP-NEXT: vmovdqa 112(%rdi), %xmm15 --; AVX512-FCP-NEXT: vmovdqa 96(%rdi), %xmm5 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm12 = xmm5[0],xmm15[1],xmm5[2,3],xmm15[4],xmm5[5,6],xmm15[7] --; AVX512-FCP-NEXT: vpshufb %xmm14, %xmm12, %xmm12 --; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm6 = ymm12[0,1,2],ymm10[3,4,5,6,7] --; AVX512-FCP-NEXT: vmovdqa64 (%rdi), %ymm24 --; AVX512-FCP-NEXT: vmovdqa 32(%rdi), %ymm12 -+; AVX512-FCP-NEXT: vpshufb %ymm11, %ymm5, %ymm5 -+; AVX512-FCP-NEXT: vmovdqa 112(%rdi), %xmm11 -+; AVX512-FCP-NEXT: vmovdqa 96(%rdi), %xmm12 -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm10 = xmm12[0],xmm11[1],xmm12[2,3],xmm11[4],xmm12[5,6],xmm11[7] -+; AVX512-FCP-NEXT: vpshufb %xmm15, %xmm10, %xmm10 -+; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm1 = ymm10[0,1,2],ymm5[3,4,5,6,7] -+; AVX512-FCP-NEXT: vmovdqa64 (%rdi), %ymm17 -+; AVX512-FCP-NEXT: 
vmovdqa 32(%rdi), %ymm5 - ; AVX512-FCP-NEXT: vmovdqa %ymm0, %ymm10 --; AVX512-FCP-NEXT: vpternlogq $202, %ymm12, %ymm24, %ymm10 --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm1 = ymm10[2,3,0,1] --; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm1 = ymm10[0],ymm1[1],ymm10[2,3],ymm1[4],ymm10[5,6],ymm1[7],ymm10[8],ymm1[9],ymm10[10,11],ymm1[12],ymm10[13,14],ymm1[15] --; AVX512-FCP-NEXT: vpshufb %ymm7, %ymm1, %ymm7 -+; AVX512-FCP-NEXT: vpternlogq $202, %ymm5, %ymm17, %ymm10 -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm15 = ymm10[2,3,0,1] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm10 = ymm10[0],ymm15[1],ymm10[2,3],ymm15[4],ymm10[5,6],ymm15[7],ymm10[8],ymm15[9],ymm10[10,11],ymm15[12],ymm10[13,14],ymm15[15] -+; AVX512-FCP-NEXT: vpshufb %ymm3, %ymm10, %ymm2 - ; AVX512-FCP-NEXT: vmovdqa 80(%rdi), %xmm10 --; AVX512-FCP-NEXT: vmovdqa 64(%rdi), %xmm1 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm2 = xmm1[0,1],xmm10[2],xmm1[3,4],xmm10[5],xmm1[6,7] --; AVX512-FCP-NEXT: vpshufb %xmm13, %xmm2, %xmm2 --; AVX512-FCP-NEXT: vinserti128 $1, %xmm2, %ymm0, %ymm2 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm2 = ymm7[0,1,2],ymm2[3,4,5,6,7],ymm7[8,9,10],ymm2[11,12,13,14,15] --; AVX512-FCP-NEXT: vpshufhw {{.*#+}} xmm7 = xmm7[0,1,2,3,6,5,4,7] --; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm2 = ymm7[0,1,2,3],ymm2[4,5,6,7] --; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm6, %zmm2, %zmm17 --; AVX512-FCP-NEXT: vmovdqa %ymm0, %ymm2 --; AVX512-FCP-NEXT: vpternlogq $202, %ymm22, %ymm21, %ymm2 --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm6 = ymm2[2,3,0,1] --; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm6[2],ymm2[3,4],ymm6[5],ymm2[6,7,8,9],ymm6[10],ymm2[11,12],ymm6[13],ymm2[14,15] --; AVX512-FCP-NEXT: vmovdqa {{.*#+}} ymm9 = [2,3,8,9,14,15,4,5,10,11,0,1,6,7,12,13,18,19,24,25,30,31,20,21,26,27,16,17,22,23,28,29] --; AVX512-FCP-NEXT: vpshufb %ymm9, %ymm2, %ymm2 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm7 = xmm4[0,1],xmm8[2],xmm4[3,4],xmm8[5],xmm4[6,7] --; AVX512-FCP-NEXT: vmovdqa64 %xmm8, %xmm25 -+; AVX512-FCP-NEXT: vmovdqa 64(%rdi), %xmm15 -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm3 = xmm15[0,1],xmm10[2],xmm15[3,4],xmm10[5],xmm15[6,7] -+; AVX512-FCP-NEXT: vpshufb %xmm9, %xmm3, %xmm3 -+; AVX512-FCP-NEXT: vinserti128 $1, %xmm3, %ymm0, %ymm3 -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm3 = ymm2[0,1,2],ymm3[3,4,5,6,7],ymm2[8,9,10],ymm3[11,12,13,14,15] -+; AVX512-FCP-NEXT: vpshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,6,5,4,7] -+; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm2 = ymm2[0,1,2,3],ymm3[4,5,6,7] -+; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm18 -+; AVX512-FCP-NEXT: vmovdqa %ymm0, %ymm1 -+; AVX512-FCP-NEXT: vpternlogq $202, %ymm23, %ymm22, %ymm1 -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm2 = ymm1[2,3,0,1] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm2[2],ymm1[3,4],ymm2[5],ymm1[6,7,8,9],ymm2[10],ymm1[11,12],ymm2[13],ymm1[14,15] -+; AVX512-FCP-NEXT: vmovdqa {{.*#+}} ymm2 = [2,3,8,9,14,15,4,5,10,11,0,1,6,7,12,13,18,19,24,25,30,31,20,21,26,27,16,17,22,23,28,29] -+; AVX512-FCP-NEXT: vpshufb %ymm2, %ymm1, %ymm1 -+; AVX512-FCP-NEXT: vmovdqa64 %ymm2, %ymm28 -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm3 = xmm4[0,1],xmm6[2],xmm4[3,4],xmm6[5],xmm4[6,7] -+; AVX512-FCP-NEXT: vmovdqa64 %xmm6, %xmm25 - ; AVX512-FCP-NEXT: vmovdqa64 %xmm4, %xmm26 - ; AVX512-FCP-NEXT: vmovdqa {{.*#+}} xmm6 = [2,3,8,9,14,15,4,5,10,11,10,11,10,11,10,11] --; AVX512-FCP-NEXT: vpshufb %xmm6, %xmm7, %xmm7 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm7 = xmm7[0,1,2,3,4],xmm2[5,6,7] --; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm7 = ymm7[0,1,2,3],ymm2[4,5,6,7] --; AVX512-FCP-NEXT: vmovdqa {{.*#+}} ymm13 = 
[65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535] --; AVX512-FCP-NEXT: vmovdqa %ymm13, %ymm2 --; AVX512-FCP-NEXT: vpternlogq $202, %ymm20, %ymm18, %ymm2 --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm4 = ymm2[2,3,0,1] --; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm4[2],ymm2[3,4],ymm4[5],ymm2[6,7,8,9],ymm4[10],ymm2[11,12],ymm4[13],ymm2[14,15] -+; AVX512-FCP-NEXT: vpshufb %xmm6, %xmm3, %xmm3 -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm3 = xmm3[0,1,2,3,4],xmm1[5,6,7] -+; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm3 = ymm3[0,1,2,3],ymm1[4,5,6,7] -+; AVX512-FCP-NEXT: vmovdqa {{.*#+}} ymm9 = [65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535] -+; AVX512-FCP-NEXT: vmovdqa %ymm9, %ymm1 -+; AVX512-FCP-NEXT: vpternlogq $202, %ymm21, %ymm20, %ymm1 -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm4 = ymm1[2,3,0,1] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm4[2],ymm1[3,4],ymm4[5],ymm1[6,7,8,9],ymm4[10],ymm1[11,12],ymm4[13],ymm1[14,15] - ; AVX512-FCP-NEXT: vmovdqa {{.*#+}} ymm4 = [2,3,8,9,14,15,4,5,12,13,10,11,0,1,6,7,18,19,24,25,30,31,20,21,28,29,26,27,16,17,22,23] --; AVX512-FCP-NEXT: vpshufb %ymm4, %ymm2, %ymm2 --; AVX512-FCP-NEXT: vmovdqa64 %xmm19, %xmm8 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm14 = xmm8[0,1],xmm3[2],xmm8[3,4],xmm3[5],xmm8[6,7] --; AVX512-FCP-NEXT: vmovdqa64 %xmm3, %xmm27 --; AVX512-FCP-NEXT: vmovdqa {{.*#+}} xmm3 = [4,5,4,5,4,5,4,5,10,11,0,1,6,7,12,13] --; AVX512-FCP-NEXT: vpshufb %xmm3, %xmm14, %xmm14 -+; AVX512-FCP-NEXT: vpshufb %ymm4, %ymm1, %ymm1 -+; AVX512-FCP-NEXT: vmovdqa %xmm14, %xmm7 -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm14 = xmm8[0,1],xmm14[2],xmm8[3,4],xmm14[5],xmm8[6,7] -+; AVX512-FCP-NEXT: vmovdqa64 %xmm8, %xmm27 -+; AVX512-FCP-NEXT: vmovdqa {{.*#+}} xmm2 = [4,5,4,5,4,5,4,5,10,11,0,1,6,7,12,13] -+; AVX512-FCP-NEXT: vpshufb %xmm2, %xmm14, %xmm14 - ; AVX512-FCP-NEXT: vinserti128 $1, %xmm14, %ymm0, %ymm14 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm14 = ymm2[0,1,2],ymm14[3,4,5,6,7],ymm2[8,9,10],ymm14[11,12,13,14,15] --; AVX512-FCP-NEXT: vpshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,5,6,7,4] --; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm2 = ymm2[0,1,2,3],ymm14[4,5,6,7] --; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm7, %zmm2, %zmm19 --; AVX512-FCP-NEXT: vmovdqa %ymm0, %ymm2 --; AVX512-FCP-NEXT: vpternlogq $202, %ymm11, %ymm23, %ymm2 --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm7 = ymm2[2,3,0,1] --; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm7[2],ymm2[3,4],ymm7[5],ymm2[6,7,8,9],ymm7[10],ymm2[11,12],ymm7[13],ymm2[14,15] --; AVX512-FCP-NEXT: vpshufb %ymm9, %ymm2, %ymm2 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm7 = xmm5[0,1],xmm15[2],xmm5[3,4],xmm15[5],xmm5[6,7] --; AVX512-FCP-NEXT: vpshufb %xmm6, %xmm7, %xmm6 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm6 = xmm6[0,1,2,3,4],xmm2[5,6,7] --; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm2 = ymm6[0,1,2,3],ymm2[4,5,6,7] --; AVX512-FCP-NEXT: vmovdqa %ymm13, %ymm6 --; AVX512-FCP-NEXT: vpternlogq $202, %ymm24, %ymm12, %ymm6 --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm7 = ymm6[2,3,0,1] --; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm6 = ymm6[0,1],ymm7[2],ymm6[3,4],ymm7[5],ymm6[6,7,8,9],ymm7[10],ymm6[11,12],ymm7[13],ymm6[14,15] --; AVX512-FCP-NEXT: vpshufb %ymm4, %ymm6, %ymm4 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm6 = xmm10[0,1],xmm1[2],xmm10[3,4],xmm1[5],xmm10[6,7] --; AVX512-FCP-NEXT: vpshufb %xmm3, %xmm6, %xmm3 --; AVX512-FCP-NEXT: vinserti128 $1, %xmm3, %ymm0, %ymm3 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm3 = ymm4[0,1,2],ymm3[3,4,5,6,7],ymm4[8,9,10],ymm3[11,12,13,14,15] --; AVX512-FCP-NEXT: 
vpshufhw {{.*#+}} xmm4 = xmm4[0,1,2,3,5,6,7,4] --; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm3 = ymm4[0,1,2,3],ymm3[4,5,6,7] --; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm2 --; AVX512-FCP-NEXT: vpternlogq $226, %ymm23, %ymm13, %ymm11 --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm3 = ymm11[2,3,0,1] --; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm11[1,2],ymm3[3],ymm11[4,5],ymm3[6],ymm11[7],ymm3[8],ymm11[9,10],ymm3[11],ymm11[12,13],ymm3[14],ymm11[15] --; AVX512-FCP-NEXT: vmovdqa {{.*#+}} ymm11 = [4,5,10,11,0,1,6,7,12,13,2,3,8,9,14,15,20,21,26,27,16,17,22,23,28,29,18,19,24,25,30,31] --; AVX512-FCP-NEXT: vpshufb %ymm11, %ymm3, %ymm3 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm15[0,1],xmm5[2],xmm15[3,4],xmm5[5],xmm15[6,7] --; AVX512-FCP-NEXT: vmovdqa {{.*#+}} xmm5 = [4,5,10,11,0,1,6,7,12,13,14,15,0,1,2,3] --; AVX512-FCP-NEXT: vpshufb %xmm5, %xmm4, %xmm4 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1,2,3,4],xmm3[5,6,7] --; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm3 = ymm4[0,1,2,3],ymm3[4,5,6,7] --; AVX512-FCP-NEXT: vpternlogq $226, %ymm24, %ymm0, %ymm12 --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm4 = ymm12[2,3,0,1] --; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm4 = ymm4[0],ymm12[1,2],ymm4[3],ymm12[4,5],ymm4[6],ymm12[7],ymm4[8],ymm12[9,10],ymm4[11],ymm12[12,13],ymm4[14],ymm12[15] --; AVX512-FCP-NEXT: vpshufb %ymm11, %ymm4, %ymm4 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0],xmm10[1],xmm1[2,3],xmm10[4],xmm1[5,6],xmm10[7] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm14 = ymm1[0,1,2],ymm14[3,4,5,6,7],ymm1[8,9,10],ymm14[11,12,13,14,15] -+; AVX512-FCP-NEXT: vpshufhw {{.*#+}} xmm1 = xmm1[0,1,2,3,5,6,7,4] -+; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm14[4,5,6,7] -+; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm3, %zmm1, %zmm19 -+; AVX512-FCP-NEXT: vmovdqa %ymm0, %ymm1 -+; AVX512-FCP-NEXT: vpternlogq $202, %ymm13, %ymm24, %ymm1 -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm3 = ymm1[2,3,0,1] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm3[2],ymm1[3,4],ymm3[5],ymm1[6,7,8,9],ymm3[10],ymm1[11,12],ymm3[13],ymm1[14,15] -+; AVX512-FCP-NEXT: vmovdqa64 %ymm28, %ymm3 -+; AVX512-FCP-NEXT: vpshufb %ymm3, %ymm1, %ymm1 -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm3 = xmm12[0,1],xmm11[2],xmm12[3,4],xmm11[5],xmm12[6,7] -+; AVX512-FCP-NEXT: vpshufb %xmm6, %xmm3, %xmm3 -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm3 = xmm3[0,1,2,3,4],xmm1[5,6,7] -+; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm1 = ymm3[0,1,2,3],ymm1[4,5,6,7] -+; AVX512-FCP-NEXT: vmovdqa %ymm9, %ymm3 -+; AVX512-FCP-NEXT: vpternlogq $202, %ymm17, %ymm5, %ymm3 -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm6 = ymm3[2,3,0,1] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0,1],ymm6[2],ymm3[3,4],ymm6[5],ymm3[6,7,8,9],ymm6[10],ymm3[11,12],ymm6[13],ymm3[14,15] -+; AVX512-FCP-NEXT: vpshufb %ymm4, %ymm3, %ymm3 -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm10[0,1],xmm15[2],xmm10[3,4],xmm15[5],xmm10[6,7] -+; AVX512-FCP-NEXT: vpshufb %xmm2, %xmm4, %xmm2 -+; AVX512-FCP-NEXT: vinserti128 $1, %xmm2, %ymm0, %ymm2 -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm2 = ymm3[0,1,2],ymm2[3,4,5,6,7],ymm3[8,9,10],ymm2[11,12,13,14,15] -+; AVX512-FCP-NEXT: vpshufhw {{.*#+}} xmm3 = xmm3[0,1,2,3,5,6,7,4] -+; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm2 = ymm3[0,1,2,3],ymm2[4,5,6,7] -+; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm1 -+; AVX512-FCP-NEXT: vpternlogq $226, %ymm24, %ymm9, %ymm13 -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm2 = ymm13[2,3,0,1] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm2 = 
ymm2[0],ymm13[1,2],ymm2[3],ymm13[4,5],ymm2[6],ymm13[7],ymm2[8],ymm13[9,10],ymm2[11],ymm13[12,13],ymm2[14],ymm13[15] -+; AVX512-FCP-NEXT: vpternlogq $226, %ymm17, %ymm0, %ymm5 -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm3 = ymm5[2,3,0,1] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm5[1,2],ymm3[3],ymm5[4,5],ymm3[6],ymm5[7],ymm3[8],ymm5[9,10],ymm3[11],ymm5[12,13],ymm3[14],ymm5[15] -+; AVX512-FCP-NEXT: vmovdqa {{.*#+}} ymm4 = [4,5,10,11,0,1,6,7,12,13,2,3,8,9,14,15,20,21,26,27,16,17,22,23,28,29,18,19,24,25,30,31] -+; AVX512-FCP-NEXT: vpshufb %ymm4, %ymm3, %ymm3 -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm5 = xmm15[0],xmm10[1],xmm15[2,3],xmm10[4],xmm15[5,6],xmm10[7] - ; AVX512-FCP-NEXT: vmovdqa {{.*#+}} xmm6 = [0,1,2,3,0,1,6,7,12,13,2,3,8,9,14,15] --; AVX512-FCP-NEXT: vpshufb %xmm6, %xmm1, %xmm1 --; AVX512-FCP-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm1 --; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm1 = ymm4[0,1,2,3,4],ymm1[5,6,7] --; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm3, %zmm1, %zmm1 --; AVX512-FCP-NEXT: vpternlogq $202, %ymm21, %ymm22, %ymm13 --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm3 = ymm13[2,3,0,1] --; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm13[1,2],ymm3[3],ymm13[4,5],ymm3[6],ymm13[7],ymm3[8],ymm13[9,10],ymm3[11],ymm13[12,13],ymm3[14],ymm13[15] --; AVX512-FCP-NEXT: vmovdqa64 %xmm25, %xmm4 --; AVX512-FCP-NEXT: vmovdqa64 %xmm26, %xmm7 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1],xmm7[2],xmm4[3,4],xmm7[5],xmm4[6,7] --; AVX512-FCP-NEXT: vpshufb %xmm5, %xmm4, %xmm4 --; AVX512-FCP-NEXT: vpshufb %ymm11, %ymm3, %ymm3 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1,2,3,4],xmm3[5,6,7] --; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm3 = ymm4[0,1,2,3],ymm3[4,5,6,7] --; AVX512-FCP-NEXT: vpternlogq $202, %ymm20, %ymm18, %ymm0 --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm4 = ymm0[2,3,0,1] --; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm0 = ymm4[0],ymm0[1,2],ymm4[3],ymm0[4,5],ymm4[6],ymm0[7],ymm4[8],ymm0[9,10],ymm4[11],ymm0[12,13],ymm4[14],ymm0[15] --; AVX512-FCP-NEXT: vpshufb %ymm11, %ymm0, %ymm0 -+; AVX512-FCP-NEXT: vpshufb %xmm6, %xmm5, %xmm5 -+; AVX512-FCP-NEXT: vinserti128 $1, %xmm5, %ymm0, %ymm5 -+; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm3 = ymm3[0,1,2,3,4],ymm5[5,6,7] -+; AVX512-FCP-NEXT: vpshufb %ymm4, %ymm2, %ymm2 -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm5 = xmm11[0,1],xmm12[2],xmm11[3,4],xmm12[5],xmm11[6,7] -+; AVX512-FCP-NEXT: vmovdqa {{.*#+}} xmm8 = [4,5,10,11,0,1,6,7,12,13,14,15,0,1,2,3] -+; AVX512-FCP-NEXT: vpshufb %xmm8, %xmm5, %xmm5 -+; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm5, %zmm3, %zmm5 -+; AVX512-FCP-NEXT: vextracti32x4 $2, %zmm5, %xmm5 -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm5 = xmm5[0,1,2,3,4],xmm2[5,6,7] -+; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm2 = ymm5[0,1,2,3],ymm2[4,5,6,7] -+; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm2 -+; AVX512-FCP-NEXT: vpternlogq $202, %ymm22, %ymm23, %ymm9 -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm3 = ymm9[2,3,0,1] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm9[1,2],ymm3[3],ymm9[4,5],ymm3[6],ymm9[7],ymm3[8],ymm9[9,10],ymm3[11],ymm9[12,13],ymm3[14],ymm9[15] -+; AVX512-FCP-NEXT: vpternlogq $202, %ymm21, %ymm20, %ymm0 -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm5 = ymm0[2,3,0,1] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} ymm0 = ymm5[0],ymm0[1,2],ymm5[3],ymm0[4,5],ymm5[6],ymm0[7],ymm5[8],ymm0[9,10],ymm5[11],ymm0[12,13],ymm5[14],ymm0[15] -+; AVX512-FCP-NEXT: vpshufb %ymm4, %ymm3, %ymm3 -+; AVX512-FCP-NEXT: vpshufb %ymm4, %ymm0, %ymm0 - ; AVX512-FCP-NEXT: vmovdqa64 %xmm27, %xmm4 --; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm4 = 
xmm4[0],xmm8[1],xmm4[2,3],xmm8[4],xmm4[5,6],xmm8[7] -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm7[0],xmm4[1],xmm7[2,3],xmm4[4],xmm7[5,6],xmm4[7] - ; AVX512-FCP-NEXT: vpshufb %xmm6, %xmm4, %xmm4 - ; AVX512-FCP-NEXT: vinserti128 $1, %xmm4, %ymm0, %ymm4 - ; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3,4],ymm4[5,6,7] -+; AVX512-FCP-NEXT: vmovdqa64 %xmm25, %xmm4 -+; AVX512-FCP-NEXT: vmovdqa64 %xmm26, %xmm5 -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1],xmm5[2],xmm4[3,4],xmm5[5],xmm4[6,7] -+; AVX512-FCP-NEXT: vpshufb %xmm8, %xmm4, %xmm4 -+; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm4, %zmm0, %zmm4 -+; AVX512-FCP-NEXT: vextracti32x4 $2, %zmm4, %xmm4 -+; AVX512-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1,2,3,4],xmm3[5,6,7] -+; AVX512-FCP-NEXT: vpblendd {{.*#+}} ymm3 = ymm4[0,1,2,3],ymm3[4,5,6,7] - ; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm3, %zmm0, %zmm0 --; AVX512-FCP-NEXT: vmovdqa64 %zmm17, (%rsi) -+; AVX512-FCP-NEXT: vmovdqa64 %zmm18, (%rsi) - ; AVX512-FCP-NEXT: vmovdqa64 %zmm16, 64(%rsi) - ; AVX512-FCP-NEXT: vmovdqa64 %zmm19, 64(%rdx) --; AVX512-FCP-NEXT: vmovdqa64 %zmm2, (%rdx) -+; AVX512-FCP-NEXT: vmovdqa64 %zmm1, (%rdx) - ; AVX512-FCP-NEXT: vmovdqa64 %zmm0, 64(%rcx) --; AVX512-FCP-NEXT: vmovdqa64 %zmm1, (%rcx) -+; AVX512-FCP-NEXT: vmovdqa64 %zmm2, (%rcx) - ; AVX512-FCP-NEXT: vzeroupper - ; AVX512-FCP-NEXT: retq - ; - ; AVX512DQ-LABEL: load_i16_stride3_vf64: - ; AVX512DQ: # %bb.0: - ; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm0 = [65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535] --; AVX512DQ-NEXT: vmovdqa64 224(%rdi), %ymm18 --; AVX512DQ-NEXT: vmovdqa64 192(%rdi), %ymm20 -+; AVX512DQ-NEXT: vmovdqa64 224(%rdi), %ymm20 -+; AVX512DQ-NEXT: vmovdqa64 192(%rdi), %ymm21 - ; AVX512DQ-NEXT: vmovdqa %ymm0, %ymm1 --; AVX512DQ-NEXT: vpternlogq $202, %ymm18, %ymm20, %ymm1 --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm2 = ymm1[2,3,0,1] --; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm2 = ymm1[0],ymm2[1],ymm1[2,3],ymm2[4],ymm1[5,6],ymm2[7],ymm1[8],ymm2[9],ymm1[10,11],ymm2[12],ymm1[13,14],ymm2[15] --; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm7 = [0,1,6,7,12,13,2,3,4,5,14,15,8,9,10,11,16,17,22,23,28,29,18,19,20,21,30,31,24,25,26,27] --; AVX512DQ-NEXT: vpshufb %ymm7, %ymm2, %ymm5 --; AVX512DQ-NEXT: vmovdqa 272(%rdi), %xmm1 -+; AVX512DQ-NEXT: vpternlogq $202, %ymm20, %ymm21, %ymm1 -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm3 = ymm1[2,3,0,1] -+; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0],ymm3[1],ymm1[2,3],ymm3[4],ymm1[5,6],ymm3[7],ymm1[8],ymm3[9],ymm1[10,11],ymm3[12],ymm1[13,14],ymm3[15] -+; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm3 = [0,1,6,7,12,13,2,3,4,5,14,15,8,9,10,11,16,17,22,23,28,29,18,19,20,21,30,31,24,25,26,27] -+; AVX512DQ-NEXT: vpshufb %ymm3, %ymm1, %ymm5 -+; AVX512DQ-NEXT: vmovdqa 272(%rdi), %xmm8 - ; AVX512DQ-NEXT: vmovdqa 256(%rdi), %xmm2 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm6 = xmm2[0,1],xmm1[2],xmm2[3,4],xmm1[5],xmm2[6,7] --; AVX512DQ-NEXT: vmovdqa %xmm2, %xmm3 --; AVX512DQ-NEXT: vmovdqa64 %xmm1, %xmm19 --; AVX512DQ-NEXT: vmovdqa {{.*#+}} xmm13 = [4,5,14,15,0,1,2,3,8,9,14,15,4,5,10,11] --; AVX512DQ-NEXT: vpshufb %xmm13, %xmm6, %xmm6 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm6 = xmm2[0,1],xmm8[2],xmm2[3,4],xmm8[5],xmm2[6,7] -+; AVX512DQ-NEXT: vmovdqa %xmm2, %xmm14 -+; AVX512DQ-NEXT: vmovdqa {{.*#+}} xmm9 = [4,5,14,15,0,1,2,3,8,9,14,15,4,5,10,11] -+; AVX512DQ-NEXT: vpshufb %xmm9, %xmm6, %xmm6 - ; AVX512DQ-NEXT: vinserti128 $1, %xmm6, %ymm0, %ymm6 - ; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm6 = ymm5[0,1,2],ymm6[3,4,5,6,7],ymm5[8,9,10],ymm6[11,12,13,14,15] - ; AVX512DQ-NEXT: vpshufhw 
{{.*#+}} xmm5 = xmm5[0,1,2,3,6,5,4,7] - ; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm5 = ymm5[0,1,2,3],ymm6[4,5,6,7] --; AVX512DQ-NEXT: vmovdqa64 320(%rdi), %ymm21 --; AVX512DQ-NEXT: vmovdqa64 352(%rdi), %ymm22 --; AVX512DQ-NEXT: vmovdqa %ymm0, %ymm8 --; AVX512DQ-NEXT: vpternlogq $202, %ymm21, %ymm22, %ymm8 --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm9 = ymm8[2,3,0,1] --; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm8 = ymm8[0],ymm9[1],ymm8[2,3],ymm9[4],ymm8[5,6],ymm9[7],ymm8[8],ymm9[9],ymm8[10,11],ymm9[12],ymm8[13,14],ymm9[15] --; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm10 = [0,1,6,7,12,13,2,3,8,9,14,15,4,5,10,11,16,17,22,23,28,29,18,19,24,25,30,31,20,21,26,27] --; AVX512DQ-NEXT: vpshufb %ymm10, %ymm8, %ymm11 -+; AVX512DQ-NEXT: vmovdqa64 320(%rdi), %ymm22 -+; AVX512DQ-NEXT: vmovdqa64 352(%rdi), %ymm23 -+; AVX512DQ-NEXT: vmovdqa %ymm0, %ymm6 -+; AVX512DQ-NEXT: vpternlogq $202, %ymm22, %ymm23, %ymm6 -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm7 = ymm6[2,3,0,1] -+; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm6 = ymm6[0],ymm7[1],ymm6[2,3],ymm7[4],ymm6[5,6],ymm7[7],ymm6[8],ymm7[9],ymm6[10,11],ymm7[12],ymm6[13,14],ymm7[15] -+; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm11 = [0,1,6,7,12,13,2,3,8,9,14,15,4,5,10,11,16,17,22,23,28,29,18,19,24,25,30,31,20,21,26,27] -+; AVX512DQ-NEXT: vpshufb %ymm11, %ymm6, %ymm12 - ; AVX512DQ-NEXT: vmovdqa 304(%rdi), %xmm1 - ; AVX512DQ-NEXT: vmovdqa 288(%rdi), %xmm2 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm12 = xmm2[0],xmm1[1],xmm2[2,3],xmm1[4],xmm2[5,6],xmm1[7] -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm13 = xmm2[0],xmm1[1],xmm2[2,3],xmm1[4],xmm2[5,6],xmm1[7] - ; AVX512DQ-NEXT: vmovdqa %xmm2, %xmm4 --; AVX512DQ-NEXT: vmovdqa %xmm1, %xmm8 --; AVX512DQ-NEXT: vmovdqa {{.*#+}} xmm14 = [0,1,6,7,12,13,2,3,8,9,14,15,12,13,14,15] --; AVX512DQ-NEXT: vpshufb %xmm14, %xmm12, %xmm12 --; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm11 = ymm12[0,1,2],ymm11[3,4,5,6,7] --; AVX512DQ-NEXT: vinserti64x4 $1, %ymm11, %zmm5, %zmm16 --; AVX512DQ-NEXT: vmovdqa64 128(%rdi), %ymm23 --; AVX512DQ-NEXT: vmovdqa 160(%rdi), %ymm11 -+; AVX512DQ-NEXT: vmovdqa %xmm1, %xmm6 -+; AVX512DQ-NEXT: vmovdqa {{.*#+}} xmm15 = [0,1,6,7,12,13,2,3,8,9,14,15,12,13,14,15] -+; AVX512DQ-NEXT: vpshufb %xmm15, %xmm13, %xmm13 -+; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm12 = ymm13[0,1,2],ymm12[3,4,5,6,7] -+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm12, %zmm5, %zmm16 -+; AVX512DQ-NEXT: vmovdqa64 128(%rdi), %ymm24 -+; AVX512DQ-NEXT: vmovdqa 160(%rdi), %ymm13 - ; AVX512DQ-NEXT: vmovdqa %ymm0, %ymm5 --; AVX512DQ-NEXT: vpternlogq $202, %ymm23, %ymm11, %ymm5 -+; AVX512DQ-NEXT: vpternlogq $202, %ymm24, %ymm13, %ymm5 - ; AVX512DQ-NEXT: vpermq {{.*#+}} ymm12 = ymm5[2,3,0,1] - ; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm5 = ymm5[0],ymm12[1],ymm5[2,3],ymm12[4],ymm5[5,6],ymm12[7],ymm5[8],ymm12[9],ymm5[10,11],ymm12[12],ymm5[13,14],ymm12[15] --; AVX512DQ-NEXT: vpshufb %ymm10, %ymm5, %ymm10 --; AVX512DQ-NEXT: vmovdqa 112(%rdi), %xmm15 --; AVX512DQ-NEXT: vmovdqa 96(%rdi), %xmm5 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm12 = xmm5[0],xmm15[1],xmm5[2,3],xmm15[4],xmm5[5,6],xmm15[7] --; AVX512DQ-NEXT: vpshufb %xmm14, %xmm12, %xmm12 --; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm6 = ymm12[0,1,2],ymm10[3,4,5,6,7] --; AVX512DQ-NEXT: vmovdqa64 (%rdi), %ymm24 --; AVX512DQ-NEXT: vmovdqa 32(%rdi), %ymm12 -+; AVX512DQ-NEXT: vpshufb %ymm11, %ymm5, %ymm5 -+; AVX512DQ-NEXT: vmovdqa 112(%rdi), %xmm11 -+; AVX512DQ-NEXT: vmovdqa 96(%rdi), %xmm12 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm10 = xmm12[0],xmm11[1],xmm12[2,3],xmm11[4],xmm12[5,6],xmm11[7] -+; AVX512DQ-NEXT: vpshufb %xmm15, %xmm10, %xmm10 -+; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm1 
= ymm10[0,1,2],ymm5[3,4,5,6,7] -+; AVX512DQ-NEXT: vmovdqa64 (%rdi), %ymm17 -+; AVX512DQ-NEXT: vmovdqa 32(%rdi), %ymm5 - ; AVX512DQ-NEXT: vmovdqa %ymm0, %ymm10 --; AVX512DQ-NEXT: vpternlogq $202, %ymm12, %ymm24, %ymm10 --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm1 = ymm10[2,3,0,1] --; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm1 = ymm10[0],ymm1[1],ymm10[2,3],ymm1[4],ymm10[5,6],ymm1[7],ymm10[8],ymm1[9],ymm10[10,11],ymm1[12],ymm10[13,14],ymm1[15] --; AVX512DQ-NEXT: vpshufb %ymm7, %ymm1, %ymm7 -+; AVX512DQ-NEXT: vpternlogq $202, %ymm5, %ymm17, %ymm10 -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm15 = ymm10[2,3,0,1] -+; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm10 = ymm10[0],ymm15[1],ymm10[2,3],ymm15[4],ymm10[5,6],ymm15[7],ymm10[8],ymm15[9],ymm10[10,11],ymm15[12],ymm10[13,14],ymm15[15] -+; AVX512DQ-NEXT: vpshufb %ymm3, %ymm10, %ymm2 - ; AVX512DQ-NEXT: vmovdqa 80(%rdi), %xmm10 --; AVX512DQ-NEXT: vmovdqa 64(%rdi), %xmm1 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm2 = xmm1[0,1],xmm10[2],xmm1[3,4],xmm10[5],xmm1[6,7] --; AVX512DQ-NEXT: vpshufb %xmm13, %xmm2, %xmm2 --; AVX512DQ-NEXT: vinserti128 $1, %xmm2, %ymm0, %ymm2 --; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm2 = ymm7[0,1,2],ymm2[3,4,5,6,7],ymm7[8,9,10],ymm2[11,12,13,14,15] --; AVX512DQ-NEXT: vpshufhw {{.*#+}} xmm7 = xmm7[0,1,2,3,6,5,4,7] --; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm2 = ymm7[0,1,2,3],ymm2[4,5,6,7] --; AVX512DQ-NEXT: vinserti64x4 $1, %ymm6, %zmm2, %zmm17 --; AVX512DQ-NEXT: vmovdqa %ymm0, %ymm2 --; AVX512DQ-NEXT: vpternlogq $202, %ymm22, %ymm21, %ymm2 --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm6 = ymm2[2,3,0,1] --; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm6[2],ymm2[3,4],ymm6[5],ymm2[6,7,8,9],ymm6[10],ymm2[11,12],ymm6[13],ymm2[14,15] --; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm9 = [2,3,8,9,14,15,4,5,10,11,0,1,6,7,12,13,18,19,24,25,30,31,20,21,26,27,16,17,22,23,28,29] --; AVX512DQ-NEXT: vpshufb %ymm9, %ymm2, %ymm2 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm7 = xmm4[0,1],xmm8[2],xmm4[3,4],xmm8[5],xmm4[6,7] --; AVX512DQ-NEXT: vmovdqa64 %xmm8, %xmm25 -+; AVX512DQ-NEXT: vmovdqa 64(%rdi), %xmm15 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm3 = xmm15[0,1],xmm10[2],xmm15[3,4],xmm10[5],xmm15[6,7] -+; AVX512DQ-NEXT: vpshufb %xmm9, %xmm3, %xmm3 -+; AVX512DQ-NEXT: vinserti128 $1, %xmm3, %ymm0, %ymm3 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm3 = ymm2[0,1,2],ymm3[3,4,5,6,7],ymm2[8,9,10],ymm3[11,12,13,14,15] -+; AVX512DQ-NEXT: vpshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,6,5,4,7] -+; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm2 = ymm2[0,1,2,3],ymm3[4,5,6,7] -+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm18 -+; AVX512DQ-NEXT: vmovdqa %ymm0, %ymm1 -+; AVX512DQ-NEXT: vpternlogq $202, %ymm23, %ymm22, %ymm1 -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm2 = ymm1[2,3,0,1] -+; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm2[2],ymm1[3,4],ymm2[5],ymm1[6,7,8,9],ymm2[10],ymm1[11,12],ymm2[13],ymm1[14,15] -+; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm2 = [2,3,8,9,14,15,4,5,10,11,0,1,6,7,12,13,18,19,24,25,30,31,20,21,26,27,16,17,22,23,28,29] -+; AVX512DQ-NEXT: vpshufb %ymm2, %ymm1, %ymm1 -+; AVX512DQ-NEXT: vmovdqa64 %ymm2, %ymm28 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm3 = xmm4[0,1],xmm6[2],xmm4[3,4],xmm6[5],xmm4[6,7] -+; AVX512DQ-NEXT: vmovdqa64 %xmm6, %xmm25 - ; AVX512DQ-NEXT: vmovdqa64 %xmm4, %xmm26 - ; AVX512DQ-NEXT: vmovdqa {{.*#+}} xmm6 = [2,3,8,9,14,15,4,5,10,11,10,11,10,11,10,11] --; AVX512DQ-NEXT: vpshufb %xmm6, %xmm7, %xmm7 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm7 = xmm7[0,1,2,3,4],xmm2[5,6,7] --; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm7 = ymm7[0,1,2,3],ymm2[4,5,6,7] --; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm13 = 
[65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535] --; AVX512DQ-NEXT: vmovdqa %ymm13, %ymm2 --; AVX512DQ-NEXT: vpternlogq $202, %ymm20, %ymm18, %ymm2 --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm4 = ymm2[2,3,0,1] --; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm4[2],ymm2[3,4],ymm4[5],ymm2[6,7,8,9],ymm4[10],ymm2[11,12],ymm4[13],ymm2[14,15] -+; AVX512DQ-NEXT: vpshufb %xmm6, %xmm3, %xmm3 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm3 = xmm3[0,1,2,3,4],xmm1[5,6,7] -+; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm3 = ymm3[0,1,2,3],ymm1[4,5,6,7] -+; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm9 = [65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535] -+; AVX512DQ-NEXT: vmovdqa %ymm9, %ymm1 -+; AVX512DQ-NEXT: vpternlogq $202, %ymm21, %ymm20, %ymm1 -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm4 = ymm1[2,3,0,1] -+; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm4[2],ymm1[3,4],ymm4[5],ymm1[6,7,8,9],ymm4[10],ymm1[11,12],ymm4[13],ymm1[14,15] - ; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm4 = [2,3,8,9,14,15,4,5,12,13,10,11,0,1,6,7,18,19,24,25,30,31,20,21,28,29,26,27,16,17,22,23] --; AVX512DQ-NEXT: vpshufb %ymm4, %ymm2, %ymm2 --; AVX512DQ-NEXT: vmovdqa64 %xmm19, %xmm8 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm14 = xmm8[0,1],xmm3[2],xmm8[3,4],xmm3[5],xmm8[6,7] --; AVX512DQ-NEXT: vmovdqa64 %xmm3, %xmm27 --; AVX512DQ-NEXT: vmovdqa {{.*#+}} xmm3 = [4,5,4,5,4,5,4,5,10,11,0,1,6,7,12,13] --; AVX512DQ-NEXT: vpshufb %xmm3, %xmm14, %xmm14 -+; AVX512DQ-NEXT: vpshufb %ymm4, %ymm1, %ymm1 -+; AVX512DQ-NEXT: vmovdqa %xmm14, %xmm7 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm14 = xmm8[0,1],xmm14[2],xmm8[3,4],xmm14[5],xmm8[6,7] -+; AVX512DQ-NEXT: vmovdqa64 %xmm8, %xmm27 -+; AVX512DQ-NEXT: vmovdqa {{.*#+}} xmm2 = [4,5,4,5,4,5,4,5,10,11,0,1,6,7,12,13] -+; AVX512DQ-NEXT: vpshufb %xmm2, %xmm14, %xmm14 - ; AVX512DQ-NEXT: vinserti128 $1, %xmm14, %ymm0, %ymm14 --; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm14 = ymm2[0,1,2],ymm14[3,4,5,6,7],ymm2[8,9,10],ymm14[11,12,13,14,15] --; AVX512DQ-NEXT: vpshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,5,6,7,4] --; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm2 = ymm2[0,1,2,3],ymm14[4,5,6,7] --; AVX512DQ-NEXT: vinserti64x4 $1, %ymm7, %zmm2, %zmm19 --; AVX512DQ-NEXT: vmovdqa %ymm0, %ymm2 --; AVX512DQ-NEXT: vpternlogq $202, %ymm11, %ymm23, %ymm2 --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm7 = ymm2[2,3,0,1] --; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm7[2],ymm2[3,4],ymm7[5],ymm2[6,7,8,9],ymm7[10],ymm2[11,12],ymm7[13],ymm2[14,15] --; AVX512DQ-NEXT: vpshufb %ymm9, %ymm2, %ymm2 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm7 = xmm5[0,1],xmm15[2],xmm5[3,4],xmm15[5],xmm5[6,7] --; AVX512DQ-NEXT: vpshufb %xmm6, %xmm7, %xmm6 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm6 = xmm6[0,1,2,3,4],xmm2[5,6,7] --; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm2 = ymm6[0,1,2,3],ymm2[4,5,6,7] --; AVX512DQ-NEXT: vmovdqa %ymm13, %ymm6 --; AVX512DQ-NEXT: vpternlogq $202, %ymm24, %ymm12, %ymm6 --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm7 = ymm6[2,3,0,1] --; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm6 = ymm6[0,1],ymm7[2],ymm6[3,4],ymm7[5],ymm6[6,7,8,9],ymm7[10],ymm6[11,12],ymm7[13],ymm6[14,15] --; AVX512DQ-NEXT: vpshufb %ymm4, %ymm6, %ymm4 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm6 = xmm10[0,1],xmm1[2],xmm10[3,4],xmm1[5],xmm10[6,7] --; AVX512DQ-NEXT: vpshufb %xmm3, %xmm6, %xmm3 --; AVX512DQ-NEXT: vinserti128 $1, %xmm3, %ymm0, %ymm3 --; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm3 = ymm4[0,1,2],ymm3[3,4,5,6,7],ymm4[8,9,10],ymm3[11,12,13,14,15] --; AVX512DQ-NEXT: vpshufhw {{.*#+}} xmm4 = xmm4[0,1,2,3,5,6,7,4] --; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm3 = 
ymm4[0,1,2,3],ymm3[4,5,6,7] --; AVX512DQ-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm2 --; AVX512DQ-NEXT: vpternlogq $226, %ymm23, %ymm13, %ymm11 --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm3 = ymm11[2,3,0,1] --; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm11[1,2],ymm3[3],ymm11[4,5],ymm3[6],ymm11[7],ymm3[8],ymm11[9,10],ymm3[11],ymm11[12,13],ymm3[14],ymm11[15] --; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm11 = [4,5,10,11,0,1,6,7,12,13,2,3,8,9,14,15,20,21,26,27,16,17,22,23,28,29,18,19,24,25,30,31] --; AVX512DQ-NEXT: vpshufb %ymm11, %ymm3, %ymm3 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm4 = xmm15[0,1],xmm5[2],xmm15[3,4],xmm5[5],xmm15[6,7] --; AVX512DQ-NEXT: vmovdqa {{.*#+}} xmm5 = [4,5,10,11,0,1,6,7,12,13,14,15,0,1,2,3] --; AVX512DQ-NEXT: vpshufb %xmm5, %xmm4, %xmm4 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1,2,3,4],xmm3[5,6,7] --; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm3 = ymm4[0,1,2,3],ymm3[4,5,6,7] --; AVX512DQ-NEXT: vpternlogq $226, %ymm24, %ymm0, %ymm12 --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm4 = ymm12[2,3,0,1] --; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm4 = ymm4[0],ymm12[1,2],ymm4[3],ymm12[4,5],ymm4[6],ymm12[7],ymm4[8],ymm12[9,10],ymm4[11],ymm12[12,13],ymm4[14],ymm12[15] --; AVX512DQ-NEXT: vpshufb %ymm11, %ymm4, %ymm4 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0],xmm10[1],xmm1[2,3],xmm10[4],xmm1[5,6],xmm10[7] -+; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm14 = ymm1[0,1,2],ymm14[3,4,5,6,7],ymm1[8,9,10],ymm14[11,12,13,14,15] -+; AVX512DQ-NEXT: vpshufhw {{.*#+}} xmm1 = xmm1[0,1,2,3,5,6,7,4] -+; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm14[4,5,6,7] -+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm3, %zmm1, %zmm19 -+; AVX512DQ-NEXT: vmovdqa %ymm0, %ymm1 -+; AVX512DQ-NEXT: vpternlogq $202, %ymm13, %ymm24, %ymm1 -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm3 = ymm1[2,3,0,1] -+; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm3[2],ymm1[3,4],ymm3[5],ymm1[6,7,8,9],ymm3[10],ymm1[11,12],ymm3[13],ymm1[14,15] -+; AVX512DQ-NEXT: vmovdqa64 %ymm28, %ymm3 -+; AVX512DQ-NEXT: vpshufb %ymm3, %ymm1, %ymm1 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm3 = xmm12[0,1],xmm11[2],xmm12[3,4],xmm11[5],xmm12[6,7] -+; AVX512DQ-NEXT: vpshufb %xmm6, %xmm3, %xmm3 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm3 = xmm3[0,1,2,3,4],xmm1[5,6,7] -+; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm1 = ymm3[0,1,2,3],ymm1[4,5,6,7] -+; AVX512DQ-NEXT: vmovdqa %ymm9, %ymm3 -+; AVX512DQ-NEXT: vpternlogq $202, %ymm17, %ymm5, %ymm3 -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm6 = ymm3[2,3,0,1] -+; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0,1],ymm6[2],ymm3[3,4],ymm6[5],ymm3[6,7,8,9],ymm6[10],ymm3[11,12],ymm6[13],ymm3[14,15] -+; AVX512DQ-NEXT: vpshufb %ymm4, %ymm3, %ymm3 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm4 = xmm10[0,1],xmm15[2],xmm10[3,4],xmm15[5],xmm10[6,7] -+; AVX512DQ-NEXT: vpshufb %xmm2, %xmm4, %xmm2 -+; AVX512DQ-NEXT: vinserti128 $1, %xmm2, %ymm0, %ymm2 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm2 = ymm3[0,1,2],ymm2[3,4,5,6,7],ymm3[8,9,10],ymm2[11,12,13,14,15] -+; AVX512DQ-NEXT: vpshufhw {{.*#+}} xmm3 = xmm3[0,1,2,3,5,6,7,4] -+; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm2 = ymm3[0,1,2,3],ymm2[4,5,6,7] -+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm1 -+; AVX512DQ-NEXT: vpternlogq $226, %ymm24, %ymm9, %ymm13 -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm2 = ymm13[2,3,0,1] -+; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0],ymm13[1,2],ymm2[3],ymm13[4,5],ymm2[6],ymm13[7],ymm2[8],ymm13[9,10],ymm2[11],ymm13[12,13],ymm2[14],ymm13[15] -+; AVX512DQ-NEXT: vpternlogq $226, %ymm17, %ymm0, %ymm5 -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm3 = ymm5[2,3,0,1] -+; AVX512DQ-NEXT: vpblendw 
{{.*#+}} ymm3 = ymm3[0],ymm5[1,2],ymm3[3],ymm5[4,5],ymm3[6],ymm5[7],ymm3[8],ymm5[9,10],ymm3[11],ymm5[12,13],ymm3[14],ymm5[15] -+; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm4 = [4,5,10,11,0,1,6,7,12,13,2,3,8,9,14,15,20,21,26,27,16,17,22,23,28,29,18,19,24,25,30,31] -+; AVX512DQ-NEXT: vpshufb %ymm4, %ymm3, %ymm3 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm5 = xmm15[0],xmm10[1],xmm15[2,3],xmm10[4],xmm15[5,6],xmm10[7] - ; AVX512DQ-NEXT: vmovdqa {{.*#+}} xmm6 = [0,1,2,3,0,1,6,7,12,13,2,3,8,9,14,15] --; AVX512DQ-NEXT: vpshufb %xmm6, %xmm1, %xmm1 --; AVX512DQ-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm1 --; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm1 = ymm4[0,1,2,3,4],ymm1[5,6,7] --; AVX512DQ-NEXT: vinserti64x4 $1, %ymm3, %zmm1, %zmm1 --; AVX512DQ-NEXT: vpternlogq $202, %ymm21, %ymm22, %ymm13 --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm3 = ymm13[2,3,0,1] --; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm13[1,2],ymm3[3],ymm13[4,5],ymm3[6],ymm13[7],ymm3[8],ymm13[9,10],ymm3[11],ymm13[12,13],ymm3[14],ymm13[15] --; AVX512DQ-NEXT: vmovdqa64 %xmm25, %xmm4 --; AVX512DQ-NEXT: vmovdqa64 %xmm26, %xmm7 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1],xmm7[2],xmm4[3,4],xmm7[5],xmm4[6,7] --; AVX512DQ-NEXT: vpshufb %xmm5, %xmm4, %xmm4 --; AVX512DQ-NEXT: vpshufb %ymm11, %ymm3, %ymm3 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1,2,3,4],xmm3[5,6,7] --; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm3 = ymm4[0,1,2,3],ymm3[4,5,6,7] --; AVX512DQ-NEXT: vpternlogq $202, %ymm20, %ymm18, %ymm0 --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm4 = ymm0[2,3,0,1] --; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm0 = ymm4[0],ymm0[1,2],ymm4[3],ymm0[4,5],ymm4[6],ymm0[7],ymm4[8],ymm0[9,10],ymm4[11],ymm0[12,13],ymm4[14],ymm0[15] --; AVX512DQ-NEXT: vpshufb %ymm11, %ymm0, %ymm0 -+; AVX512DQ-NEXT: vpshufb %xmm6, %xmm5, %xmm5 -+; AVX512DQ-NEXT: vinserti128 $1, %xmm5, %ymm0, %ymm5 -+; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm3 = ymm3[0,1,2,3,4],ymm5[5,6,7] -+; AVX512DQ-NEXT: vpshufb %ymm4, %ymm2, %ymm2 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm5 = xmm11[0,1],xmm12[2],xmm11[3,4],xmm12[5],xmm11[6,7] -+; AVX512DQ-NEXT: vmovdqa {{.*#+}} xmm8 = [4,5,10,11,0,1,6,7,12,13,14,15,0,1,2,3] -+; AVX512DQ-NEXT: vpshufb %xmm8, %xmm5, %xmm5 -+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm5, %zmm3, %zmm5 -+; AVX512DQ-NEXT: vextracti32x4 $2, %zmm5, %xmm5 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm5 = xmm5[0,1,2,3,4],xmm2[5,6,7] -+; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm2 = ymm5[0,1,2,3],ymm2[4,5,6,7] -+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm2 -+; AVX512DQ-NEXT: vpternlogq $202, %ymm22, %ymm23, %ymm9 -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm3 = ymm9[2,3,0,1] -+; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm9[1,2],ymm3[3],ymm9[4,5],ymm3[6],ymm9[7],ymm3[8],ymm9[9,10],ymm3[11],ymm9[12,13],ymm3[14],ymm9[15] -+; AVX512DQ-NEXT: vpternlogq $202, %ymm21, %ymm20, %ymm0 -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm5 = ymm0[2,3,0,1] -+; AVX512DQ-NEXT: vpblendw {{.*#+}} ymm0 = ymm5[0],ymm0[1,2],ymm5[3],ymm0[4,5],ymm5[6],ymm0[7],ymm5[8],ymm0[9,10],ymm5[11],ymm0[12,13],ymm5[14],ymm0[15] -+; AVX512DQ-NEXT: vpshufb %ymm4, %ymm3, %ymm3 -+; AVX512DQ-NEXT: vpshufb %ymm4, %ymm0, %ymm0 - ; AVX512DQ-NEXT: vmovdqa64 %xmm27, %xmm4 --; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0],xmm8[1],xmm4[2,3],xmm8[4],xmm4[5,6],xmm8[7] -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm4 = xmm7[0],xmm4[1],xmm7[2,3],xmm4[4],xmm7[5,6],xmm4[7] - ; AVX512DQ-NEXT: vpshufb %xmm6, %xmm4, %xmm4 - ; AVX512DQ-NEXT: vinserti128 $1, %xmm4, %ymm0, %ymm4 - ; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3,4],ymm4[5,6,7] -+; AVX512DQ-NEXT: vmovdqa64 %xmm25, 
%xmm4 -+; AVX512DQ-NEXT: vmovdqa64 %xmm26, %xmm5 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1],xmm5[2],xmm4[3,4],xmm5[5],xmm4[6,7] -+; AVX512DQ-NEXT: vpshufb %xmm8, %xmm4, %xmm4 -+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm4, %zmm0, %zmm4 -+; AVX512DQ-NEXT: vextracti32x4 $2, %zmm4, %xmm4 -+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1,2,3,4],xmm3[5,6,7] -+; AVX512DQ-NEXT: vpblendd {{.*#+}} ymm3 = ymm4[0,1,2,3],ymm3[4,5,6,7] - ; AVX512DQ-NEXT: vinserti64x4 $1, %ymm3, %zmm0, %zmm0 --; AVX512DQ-NEXT: vmovdqa64 %zmm17, (%rsi) -+; AVX512DQ-NEXT: vmovdqa64 %zmm18, (%rsi) - ; AVX512DQ-NEXT: vmovdqa64 %zmm16, 64(%rsi) - ; AVX512DQ-NEXT: vmovdqa64 %zmm19, 64(%rdx) --; AVX512DQ-NEXT: vmovdqa64 %zmm2, (%rdx) -+; AVX512DQ-NEXT: vmovdqa64 %zmm1, (%rdx) - ; AVX512DQ-NEXT: vmovdqa64 %zmm0, 64(%rcx) --; AVX512DQ-NEXT: vmovdqa64 %zmm1, (%rcx) -+; AVX512DQ-NEXT: vmovdqa64 %zmm2, (%rcx) - ; AVX512DQ-NEXT: vzeroupper - ; AVX512DQ-NEXT: retq - ; - ; AVX512DQ-FCP-LABEL: load_i16_stride3_vf64: - ; AVX512DQ-FCP: # %bb.0: - ; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} ymm0 = [65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535] --; AVX512DQ-FCP-NEXT: vmovdqa64 224(%rdi), %ymm18 --; AVX512DQ-FCP-NEXT: vmovdqa64 192(%rdi), %ymm20 -+; AVX512DQ-FCP-NEXT: vmovdqa64 224(%rdi), %ymm20 -+; AVX512DQ-FCP-NEXT: vmovdqa64 192(%rdi), %ymm21 - ; AVX512DQ-FCP-NEXT: vmovdqa %ymm0, %ymm1 --; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm18, %ymm20, %ymm1 --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm2 = ymm1[2,3,0,1] --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm2 = ymm1[0],ymm2[1],ymm1[2,3],ymm2[4],ymm1[5,6],ymm2[7],ymm1[8],ymm2[9],ymm1[10,11],ymm2[12],ymm1[13,14],ymm2[15] --; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} ymm7 = [0,1,6,7,12,13,2,3,4,5,14,15,8,9,10,11,16,17,22,23,28,29,18,19,20,21,30,31,24,25,26,27] --; AVX512DQ-FCP-NEXT: vpshufb %ymm7, %ymm2, %ymm5 --; AVX512DQ-FCP-NEXT: vmovdqa 272(%rdi), %xmm1 -+; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm20, %ymm21, %ymm1 -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm3 = ymm1[2,3,0,1] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0],ymm3[1],ymm1[2,3],ymm3[4],ymm1[5,6],ymm3[7],ymm1[8],ymm3[9],ymm1[10,11],ymm3[12],ymm1[13,14],ymm3[15] -+; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} ymm3 = [0,1,6,7,12,13,2,3,4,5,14,15,8,9,10,11,16,17,22,23,28,29,18,19,20,21,30,31,24,25,26,27] -+; AVX512DQ-FCP-NEXT: vpshufb %ymm3, %ymm1, %ymm5 -+; AVX512DQ-FCP-NEXT: vmovdqa 272(%rdi), %xmm8 - ; AVX512DQ-FCP-NEXT: vmovdqa 256(%rdi), %xmm2 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm6 = xmm2[0,1],xmm1[2],xmm2[3,4],xmm1[5],xmm2[6,7] --; AVX512DQ-FCP-NEXT: vmovdqa %xmm2, %xmm3 --; AVX512DQ-FCP-NEXT: vmovdqa64 %xmm1, %xmm19 --; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} xmm13 = [4,5,14,15,0,1,2,3,8,9,14,15,4,5,10,11] --; AVX512DQ-FCP-NEXT: vpshufb %xmm13, %xmm6, %xmm6 -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm6 = xmm2[0,1],xmm8[2],xmm2[3,4],xmm8[5],xmm2[6,7] -+; AVX512DQ-FCP-NEXT: vmovdqa %xmm2, %xmm14 -+; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} xmm9 = [4,5,14,15,0,1,2,3,8,9,14,15,4,5,10,11] -+; AVX512DQ-FCP-NEXT: vpshufb %xmm9, %xmm6, %xmm6 - ; AVX512DQ-FCP-NEXT: vinserti128 $1, %xmm6, %ymm0, %ymm6 - ; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm6 = ymm5[0,1,2],ymm6[3,4,5,6,7],ymm5[8,9,10],ymm6[11,12,13,14,15] - ; AVX512DQ-FCP-NEXT: vpshufhw {{.*#+}} xmm5 = xmm5[0,1,2,3,6,5,4,7] - ; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm5 = ymm5[0,1,2,3],ymm6[4,5,6,7] --; AVX512DQ-FCP-NEXT: vmovdqa64 320(%rdi), %ymm21 --; AVX512DQ-FCP-NEXT: vmovdqa64 352(%rdi), %ymm22 --; AVX512DQ-FCP-NEXT: vmovdqa %ymm0, %ymm8 --; 
AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm21, %ymm22, %ymm8 --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm9 = ymm8[2,3,0,1] --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm8 = ymm8[0],ymm9[1],ymm8[2,3],ymm9[4],ymm8[5,6],ymm9[7],ymm8[8],ymm9[9],ymm8[10,11],ymm9[12],ymm8[13,14],ymm9[15] --; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} ymm10 = [0,1,6,7,12,13,2,3,8,9,14,15,4,5,10,11,16,17,22,23,28,29,18,19,24,25,30,31,20,21,26,27] --; AVX512DQ-FCP-NEXT: vpshufb %ymm10, %ymm8, %ymm11 -+; AVX512DQ-FCP-NEXT: vmovdqa64 320(%rdi), %ymm22 -+; AVX512DQ-FCP-NEXT: vmovdqa64 352(%rdi), %ymm23 -+; AVX512DQ-FCP-NEXT: vmovdqa %ymm0, %ymm6 -+; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm22, %ymm23, %ymm6 -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm7 = ymm6[2,3,0,1] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm6 = ymm6[0],ymm7[1],ymm6[2,3],ymm7[4],ymm6[5,6],ymm7[7],ymm6[8],ymm7[9],ymm6[10,11],ymm7[12],ymm6[13,14],ymm7[15] -+; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} ymm11 = [0,1,6,7,12,13,2,3,8,9,14,15,4,5,10,11,16,17,22,23,28,29,18,19,24,25,30,31,20,21,26,27] -+; AVX512DQ-FCP-NEXT: vpshufb %ymm11, %ymm6, %ymm12 - ; AVX512DQ-FCP-NEXT: vmovdqa 304(%rdi), %xmm1 - ; AVX512DQ-FCP-NEXT: vmovdqa 288(%rdi), %xmm2 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm12 = xmm2[0],xmm1[1],xmm2[2,3],xmm1[4],xmm2[5,6],xmm1[7] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm13 = xmm2[0],xmm1[1],xmm2[2,3],xmm1[4],xmm2[5,6],xmm1[7] - ; AVX512DQ-FCP-NEXT: vmovdqa %xmm2, %xmm4 --; AVX512DQ-FCP-NEXT: vmovdqa %xmm1, %xmm8 --; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} xmm14 = [0,1,6,7,12,13,2,3,8,9,14,15,12,13,14,15] --; AVX512DQ-FCP-NEXT: vpshufb %xmm14, %xmm12, %xmm12 --; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm11 = ymm12[0,1,2],ymm11[3,4,5,6,7] --; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm11, %zmm5, %zmm16 --; AVX512DQ-FCP-NEXT: vmovdqa64 128(%rdi), %ymm23 --; AVX512DQ-FCP-NEXT: vmovdqa 160(%rdi), %ymm11 -+; AVX512DQ-FCP-NEXT: vmovdqa %xmm1, %xmm6 -+; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} xmm15 = [0,1,6,7,12,13,2,3,8,9,14,15,12,13,14,15] -+; AVX512DQ-FCP-NEXT: vpshufb %xmm15, %xmm13, %xmm13 -+; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm12 = ymm13[0,1,2],ymm12[3,4,5,6,7] -+; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm12, %zmm5, %zmm16 -+; AVX512DQ-FCP-NEXT: vmovdqa64 128(%rdi), %ymm24 -+; AVX512DQ-FCP-NEXT: vmovdqa 160(%rdi), %ymm13 - ; AVX512DQ-FCP-NEXT: vmovdqa %ymm0, %ymm5 --; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm23, %ymm11, %ymm5 -+; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm24, %ymm13, %ymm5 - ; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm12 = ymm5[2,3,0,1] - ; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm5 = ymm5[0],ymm12[1],ymm5[2,3],ymm12[4],ymm5[5,6],ymm12[7],ymm5[8],ymm12[9],ymm5[10,11],ymm12[12],ymm5[13,14],ymm12[15] --; AVX512DQ-FCP-NEXT: vpshufb %ymm10, %ymm5, %ymm10 --; AVX512DQ-FCP-NEXT: vmovdqa 112(%rdi), %xmm15 --; AVX512DQ-FCP-NEXT: vmovdqa 96(%rdi), %xmm5 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm12 = xmm5[0],xmm15[1],xmm5[2,3],xmm15[4],xmm5[5,6],xmm15[7] --; AVX512DQ-FCP-NEXT: vpshufb %xmm14, %xmm12, %xmm12 --; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm6 = ymm12[0,1,2],ymm10[3,4,5,6,7] --; AVX512DQ-FCP-NEXT: vmovdqa64 (%rdi), %ymm24 --; AVX512DQ-FCP-NEXT: vmovdqa 32(%rdi), %ymm12 -+; AVX512DQ-FCP-NEXT: vpshufb %ymm11, %ymm5, %ymm5 -+; AVX512DQ-FCP-NEXT: vmovdqa 112(%rdi), %xmm11 -+; AVX512DQ-FCP-NEXT: vmovdqa 96(%rdi), %xmm12 -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm10 = xmm12[0],xmm11[1],xmm12[2,3],xmm11[4],xmm12[5,6],xmm11[7] -+; AVX512DQ-FCP-NEXT: vpshufb %xmm15, %xmm10, %xmm10 -+; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm1 = ymm10[0,1,2],ymm5[3,4,5,6,7] -+; 
AVX512DQ-FCP-NEXT: vmovdqa64 (%rdi), %ymm17 -+; AVX512DQ-FCP-NEXT: vmovdqa 32(%rdi), %ymm5 - ; AVX512DQ-FCP-NEXT: vmovdqa %ymm0, %ymm10 --; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm12, %ymm24, %ymm10 --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm1 = ymm10[2,3,0,1] --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm1 = ymm10[0],ymm1[1],ymm10[2,3],ymm1[4],ymm10[5,6],ymm1[7],ymm10[8],ymm1[9],ymm10[10,11],ymm1[12],ymm10[13,14],ymm1[15] --; AVX512DQ-FCP-NEXT: vpshufb %ymm7, %ymm1, %ymm7 -+; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm5, %ymm17, %ymm10 -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm15 = ymm10[2,3,0,1] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm10 = ymm10[0],ymm15[1],ymm10[2,3],ymm15[4],ymm10[5,6],ymm15[7],ymm10[8],ymm15[9],ymm10[10,11],ymm15[12],ymm10[13,14],ymm15[15] -+; AVX512DQ-FCP-NEXT: vpshufb %ymm3, %ymm10, %ymm2 - ; AVX512DQ-FCP-NEXT: vmovdqa 80(%rdi), %xmm10 --; AVX512DQ-FCP-NEXT: vmovdqa 64(%rdi), %xmm1 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm2 = xmm1[0,1],xmm10[2],xmm1[3,4],xmm10[5],xmm1[6,7] --; AVX512DQ-FCP-NEXT: vpshufb %xmm13, %xmm2, %xmm2 --; AVX512DQ-FCP-NEXT: vinserti128 $1, %xmm2, %ymm0, %ymm2 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm2 = ymm7[0,1,2],ymm2[3,4,5,6,7],ymm7[8,9,10],ymm2[11,12,13,14,15] --; AVX512DQ-FCP-NEXT: vpshufhw {{.*#+}} xmm7 = xmm7[0,1,2,3,6,5,4,7] --; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm2 = ymm7[0,1,2,3],ymm2[4,5,6,7] --; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm6, %zmm2, %zmm17 --; AVX512DQ-FCP-NEXT: vmovdqa %ymm0, %ymm2 --; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm22, %ymm21, %ymm2 --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm6 = ymm2[2,3,0,1] --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm6[2],ymm2[3,4],ymm6[5],ymm2[6,7,8,9],ymm6[10],ymm2[11,12],ymm6[13],ymm2[14,15] --; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} ymm9 = [2,3,8,9,14,15,4,5,10,11,0,1,6,7,12,13,18,19,24,25,30,31,20,21,26,27,16,17,22,23,28,29] --; AVX512DQ-FCP-NEXT: vpshufb %ymm9, %ymm2, %ymm2 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm7 = xmm4[0,1],xmm8[2],xmm4[3,4],xmm8[5],xmm4[6,7] --; AVX512DQ-FCP-NEXT: vmovdqa64 %xmm8, %xmm25 -+; AVX512DQ-FCP-NEXT: vmovdqa 64(%rdi), %xmm15 -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm3 = xmm15[0,1],xmm10[2],xmm15[3,4],xmm10[5],xmm15[6,7] -+; AVX512DQ-FCP-NEXT: vpshufb %xmm9, %xmm3, %xmm3 -+; AVX512DQ-FCP-NEXT: vinserti128 $1, %xmm3, %ymm0, %ymm3 -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm3 = ymm2[0,1,2],ymm3[3,4,5,6,7],ymm2[8,9,10],ymm3[11,12,13,14,15] -+; AVX512DQ-FCP-NEXT: vpshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,6,5,4,7] -+; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm2 = ymm2[0,1,2,3],ymm3[4,5,6,7] -+; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm18 -+; AVX512DQ-FCP-NEXT: vmovdqa %ymm0, %ymm1 -+; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm23, %ymm22, %ymm1 -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm2 = ymm1[2,3,0,1] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm2[2],ymm1[3,4],ymm2[5],ymm1[6,7,8,9],ymm2[10],ymm1[11,12],ymm2[13],ymm1[14,15] -+; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} ymm2 = [2,3,8,9,14,15,4,5,10,11,0,1,6,7,12,13,18,19,24,25,30,31,20,21,26,27,16,17,22,23,28,29] -+; AVX512DQ-FCP-NEXT: vpshufb %ymm2, %ymm1, %ymm1 -+; AVX512DQ-FCP-NEXT: vmovdqa64 %ymm2, %ymm28 -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm3 = xmm4[0,1],xmm6[2],xmm4[3,4],xmm6[5],xmm4[6,7] -+; AVX512DQ-FCP-NEXT: vmovdqa64 %xmm6, %xmm25 - ; AVX512DQ-FCP-NEXT: vmovdqa64 %xmm4, %xmm26 - ; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} xmm6 = [2,3,8,9,14,15,4,5,10,11,10,11,10,11,10,11] --; AVX512DQ-FCP-NEXT: vpshufb %xmm6, %xmm7, %xmm7 --; AVX512DQ-FCP-NEXT: vpblendw 
{{.*#+}} xmm7 = xmm7[0,1,2,3,4],xmm2[5,6,7] --; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm7 = ymm7[0,1,2,3],ymm2[4,5,6,7] --; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} ymm13 = [65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535] --; AVX512DQ-FCP-NEXT: vmovdqa %ymm13, %ymm2 --; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm20, %ymm18, %ymm2 --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm4 = ymm2[2,3,0,1] --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm4[2],ymm2[3,4],ymm4[5],ymm2[6,7,8,9],ymm4[10],ymm2[11,12],ymm4[13],ymm2[14,15] -+; AVX512DQ-FCP-NEXT: vpshufb %xmm6, %xmm3, %xmm3 -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm3 = xmm3[0,1,2,3,4],xmm1[5,6,7] -+; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm3 = ymm3[0,1,2,3],ymm1[4,5,6,7] -+; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} ymm9 = [65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535] -+; AVX512DQ-FCP-NEXT: vmovdqa %ymm9, %ymm1 -+; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm21, %ymm20, %ymm1 -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm4 = ymm1[2,3,0,1] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm4[2],ymm1[3,4],ymm4[5],ymm1[6,7,8,9],ymm4[10],ymm1[11,12],ymm4[13],ymm1[14,15] - ; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} ymm4 = [2,3,8,9,14,15,4,5,12,13,10,11,0,1,6,7,18,19,24,25,30,31,20,21,28,29,26,27,16,17,22,23] --; AVX512DQ-FCP-NEXT: vpshufb %ymm4, %ymm2, %ymm2 --; AVX512DQ-FCP-NEXT: vmovdqa64 %xmm19, %xmm8 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm14 = xmm8[0,1],xmm3[2],xmm8[3,4],xmm3[5],xmm8[6,7] --; AVX512DQ-FCP-NEXT: vmovdqa64 %xmm3, %xmm27 --; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} xmm3 = [4,5,4,5,4,5,4,5,10,11,0,1,6,7,12,13] --; AVX512DQ-FCP-NEXT: vpshufb %xmm3, %xmm14, %xmm14 -+; AVX512DQ-FCP-NEXT: vpshufb %ymm4, %ymm1, %ymm1 -+; AVX512DQ-FCP-NEXT: vmovdqa %xmm14, %xmm7 -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm14 = xmm8[0,1],xmm14[2],xmm8[3,4],xmm14[5],xmm8[6,7] -+; AVX512DQ-FCP-NEXT: vmovdqa64 %xmm8, %xmm27 -+; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} xmm2 = [4,5,4,5,4,5,4,5,10,11,0,1,6,7,12,13] -+; AVX512DQ-FCP-NEXT: vpshufb %xmm2, %xmm14, %xmm14 - ; AVX512DQ-FCP-NEXT: vinserti128 $1, %xmm14, %ymm0, %ymm14 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm14 = ymm2[0,1,2],ymm14[3,4,5,6,7],ymm2[8,9,10],ymm14[11,12,13,14,15] --; AVX512DQ-FCP-NEXT: vpshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,5,6,7,4] --; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm2 = ymm2[0,1,2,3],ymm14[4,5,6,7] --; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm7, %zmm2, %zmm19 --; AVX512DQ-FCP-NEXT: vmovdqa %ymm0, %ymm2 --; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm11, %ymm23, %ymm2 --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm7 = ymm2[2,3,0,1] --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm7[2],ymm2[3,4],ymm7[5],ymm2[6,7,8,9],ymm7[10],ymm2[11,12],ymm7[13],ymm2[14,15] --; AVX512DQ-FCP-NEXT: vpshufb %ymm9, %ymm2, %ymm2 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm7 = xmm5[0,1],xmm15[2],xmm5[3,4],xmm15[5],xmm5[6,7] --; AVX512DQ-FCP-NEXT: vpshufb %xmm6, %xmm7, %xmm6 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm6 = xmm6[0,1,2,3,4],xmm2[5,6,7] --; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm2 = ymm6[0,1,2,3],ymm2[4,5,6,7] --; AVX512DQ-FCP-NEXT: vmovdqa %ymm13, %ymm6 --; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm24, %ymm12, %ymm6 --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm7 = ymm6[2,3,0,1] --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm6 = ymm6[0,1],ymm7[2],ymm6[3,4],ymm7[5],ymm6[6,7,8,9],ymm7[10],ymm6[11,12],ymm7[13],ymm6[14,15] --; AVX512DQ-FCP-NEXT: vpshufb %ymm4, %ymm6, %ymm4 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm6 = 
xmm10[0,1],xmm1[2],xmm10[3,4],xmm1[5],xmm10[6,7] --; AVX512DQ-FCP-NEXT: vpshufb %xmm3, %xmm6, %xmm3 --; AVX512DQ-FCP-NEXT: vinserti128 $1, %xmm3, %ymm0, %ymm3 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm3 = ymm4[0,1,2],ymm3[3,4,5,6,7],ymm4[8,9,10],ymm3[11,12,13,14,15] --; AVX512DQ-FCP-NEXT: vpshufhw {{.*#+}} xmm4 = xmm4[0,1,2,3,5,6,7,4] --; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm3 = ymm4[0,1,2,3],ymm3[4,5,6,7] --; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm2 --; AVX512DQ-FCP-NEXT: vpternlogq $226, %ymm23, %ymm13, %ymm11 --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm3 = ymm11[2,3,0,1] --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm11[1,2],ymm3[3],ymm11[4,5],ymm3[6],ymm11[7],ymm3[8],ymm11[9,10],ymm3[11],ymm11[12,13],ymm3[14],ymm11[15] --; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} ymm11 = [4,5,10,11,0,1,6,7,12,13,2,3,8,9,14,15,20,21,26,27,16,17,22,23,28,29,18,19,24,25,30,31] --; AVX512DQ-FCP-NEXT: vpshufb %ymm11, %ymm3, %ymm3 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm15[0,1],xmm5[2],xmm15[3,4],xmm5[5],xmm15[6,7] --; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} xmm5 = [4,5,10,11,0,1,6,7,12,13,14,15,0,1,2,3] --; AVX512DQ-FCP-NEXT: vpshufb %xmm5, %xmm4, %xmm4 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1,2,3,4],xmm3[5,6,7] --; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm3 = ymm4[0,1,2,3],ymm3[4,5,6,7] --; AVX512DQ-FCP-NEXT: vpternlogq $226, %ymm24, %ymm0, %ymm12 --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm4 = ymm12[2,3,0,1] --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm4 = ymm4[0],ymm12[1,2],ymm4[3],ymm12[4,5],ymm4[6],ymm12[7],ymm4[8],ymm12[9,10],ymm4[11],ymm12[12,13],ymm4[14],ymm12[15] --; AVX512DQ-FCP-NEXT: vpshufb %ymm11, %ymm4, %ymm4 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0],xmm10[1],xmm1[2,3],xmm10[4],xmm1[5,6],xmm10[7] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm14 = ymm1[0,1,2],ymm14[3,4,5,6,7],ymm1[8,9,10],ymm14[11,12,13,14,15] -+; AVX512DQ-FCP-NEXT: vpshufhw {{.*#+}} xmm1 = xmm1[0,1,2,3,5,6,7,4] -+; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm14[4,5,6,7] -+; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm3, %zmm1, %zmm19 -+; AVX512DQ-FCP-NEXT: vmovdqa %ymm0, %ymm1 -+; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm13, %ymm24, %ymm1 -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm3 = ymm1[2,3,0,1] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm3[2],ymm1[3,4],ymm3[5],ymm1[6,7,8,9],ymm3[10],ymm1[11,12],ymm3[13],ymm1[14,15] -+; AVX512DQ-FCP-NEXT: vmovdqa64 %ymm28, %ymm3 -+; AVX512DQ-FCP-NEXT: vpshufb %ymm3, %ymm1, %ymm1 -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm3 = xmm12[0,1],xmm11[2],xmm12[3,4],xmm11[5],xmm12[6,7] -+; AVX512DQ-FCP-NEXT: vpshufb %xmm6, %xmm3, %xmm3 -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm3 = xmm3[0,1,2,3,4],xmm1[5,6,7] -+; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm1 = ymm3[0,1,2,3],ymm1[4,5,6,7] -+; AVX512DQ-FCP-NEXT: vmovdqa %ymm9, %ymm3 -+; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm17, %ymm5, %ymm3 -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm6 = ymm3[2,3,0,1] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0,1],ymm6[2],ymm3[3,4],ymm6[5],ymm3[6,7,8,9],ymm6[10],ymm3[11,12],ymm6[13],ymm3[14,15] -+; AVX512DQ-FCP-NEXT: vpshufb %ymm4, %ymm3, %ymm3 -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm10[0,1],xmm15[2],xmm10[3,4],xmm15[5],xmm10[6,7] -+; AVX512DQ-FCP-NEXT: vpshufb %xmm2, %xmm4, %xmm2 -+; AVX512DQ-FCP-NEXT: vinserti128 $1, %xmm2, %ymm0, %ymm2 -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm2 = ymm3[0,1,2],ymm2[3,4,5,6,7],ymm3[8,9,10],ymm2[11,12,13,14,15] -+; AVX512DQ-FCP-NEXT: vpshufhw {{.*#+}} xmm3 = xmm3[0,1,2,3,5,6,7,4] 
-+; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm2 = ymm3[0,1,2,3],ymm2[4,5,6,7] -+; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm1 -+; AVX512DQ-FCP-NEXT: vpternlogq $226, %ymm24, %ymm9, %ymm13 -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm2 = ymm13[2,3,0,1] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0],ymm13[1,2],ymm2[3],ymm13[4,5],ymm2[6],ymm13[7],ymm2[8],ymm13[9,10],ymm2[11],ymm13[12,13],ymm2[14],ymm13[15] -+; AVX512DQ-FCP-NEXT: vpternlogq $226, %ymm17, %ymm0, %ymm5 -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm3 = ymm5[2,3,0,1] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm5[1,2],ymm3[3],ymm5[4,5],ymm3[6],ymm5[7],ymm3[8],ymm5[9,10],ymm3[11],ymm5[12,13],ymm3[14],ymm5[15] -+; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} ymm4 = [4,5,10,11,0,1,6,7,12,13,2,3,8,9,14,15,20,21,26,27,16,17,22,23,28,29,18,19,24,25,30,31] -+; AVX512DQ-FCP-NEXT: vpshufb %ymm4, %ymm3, %ymm3 -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm5 = xmm15[0],xmm10[1],xmm15[2,3],xmm10[4],xmm15[5,6],xmm10[7] - ; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} xmm6 = [0,1,2,3,0,1,6,7,12,13,2,3,8,9,14,15] --; AVX512DQ-FCP-NEXT: vpshufb %xmm6, %xmm1, %xmm1 --; AVX512DQ-FCP-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm1 --; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm1 = ymm4[0,1,2,3,4],ymm1[5,6,7] --; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm3, %zmm1, %zmm1 --; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm21, %ymm22, %ymm13 --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm3 = ymm13[2,3,0,1] --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm13[1,2],ymm3[3],ymm13[4,5],ymm3[6],ymm13[7],ymm3[8],ymm13[9,10],ymm3[11],ymm13[12,13],ymm3[14],ymm13[15] --; AVX512DQ-FCP-NEXT: vmovdqa64 %xmm25, %xmm4 --; AVX512DQ-FCP-NEXT: vmovdqa64 %xmm26, %xmm7 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1],xmm7[2],xmm4[3,4],xmm7[5],xmm4[6,7] --; AVX512DQ-FCP-NEXT: vpshufb %xmm5, %xmm4, %xmm4 --; AVX512DQ-FCP-NEXT: vpshufb %ymm11, %ymm3, %ymm3 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1,2,3,4],xmm3[5,6,7] --; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm3 = ymm4[0,1,2,3],ymm3[4,5,6,7] --; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm20, %ymm18, %ymm0 --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm4 = ymm0[2,3,0,1] --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm0 = ymm4[0],ymm0[1,2],ymm4[3],ymm0[4,5],ymm4[6],ymm0[7],ymm4[8],ymm0[9,10],ymm4[11],ymm0[12,13],ymm4[14],ymm0[15] --; AVX512DQ-FCP-NEXT: vpshufb %ymm11, %ymm0, %ymm0 -+; AVX512DQ-FCP-NEXT: vpshufb %xmm6, %xmm5, %xmm5 -+; AVX512DQ-FCP-NEXT: vinserti128 $1, %xmm5, %ymm0, %ymm5 -+; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm3 = ymm3[0,1,2,3,4],ymm5[5,6,7] -+; AVX512DQ-FCP-NEXT: vpshufb %ymm4, %ymm2, %ymm2 -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm5 = xmm11[0,1],xmm12[2],xmm11[3,4],xmm12[5],xmm11[6,7] -+; AVX512DQ-FCP-NEXT: vmovdqa {{.*#+}} xmm8 = [4,5,10,11,0,1,6,7,12,13,14,15,0,1,2,3] -+; AVX512DQ-FCP-NEXT: vpshufb %xmm8, %xmm5, %xmm5 -+; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm5, %zmm3, %zmm5 -+; AVX512DQ-FCP-NEXT: vextracti32x4 $2, %zmm5, %xmm5 -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm5 = xmm5[0,1,2,3,4],xmm2[5,6,7] -+; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm2 = ymm5[0,1,2,3],ymm2[4,5,6,7] -+; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm2 -+; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm22, %ymm23, %ymm9 -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm3 = ymm9[2,3,0,1] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0],ymm9[1,2],ymm3[3],ymm9[4,5],ymm3[6],ymm9[7],ymm3[8],ymm9[9,10],ymm3[11],ymm9[12,13],ymm3[14],ymm9[15] -+; AVX512DQ-FCP-NEXT: vpternlogq $202, %ymm21, %ymm20, %ymm0 -+; AVX512DQ-FCP-NEXT: 
vpermq {{.*#+}} ymm5 = ymm0[2,3,0,1] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} ymm0 = ymm5[0],ymm0[1,2],ymm5[3],ymm0[4,5],ymm5[6],ymm0[7],ymm5[8],ymm0[9,10],ymm5[11],ymm0[12,13],ymm5[14],ymm0[15] -+; AVX512DQ-FCP-NEXT: vpshufb %ymm4, %ymm3, %ymm3 -+; AVX512DQ-FCP-NEXT: vpshufb %ymm4, %ymm0, %ymm0 - ; AVX512DQ-FCP-NEXT: vmovdqa64 %xmm27, %xmm4 --; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0],xmm8[1],xmm4[2,3],xmm8[4],xmm4[5,6],xmm8[7] -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm7[0],xmm4[1],xmm7[2,3],xmm4[4],xmm7[5,6],xmm4[7] - ; AVX512DQ-FCP-NEXT: vpshufb %xmm6, %xmm4, %xmm4 - ; AVX512DQ-FCP-NEXT: vinserti128 $1, %xmm4, %ymm0, %ymm4 - ; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3,4],ymm4[5,6,7] -+; AVX512DQ-FCP-NEXT: vmovdqa64 %xmm25, %xmm4 -+; AVX512DQ-FCP-NEXT: vmovdqa64 %xmm26, %xmm5 -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1],xmm5[2],xmm4[3,4],xmm5[5],xmm4[6,7] -+; AVX512DQ-FCP-NEXT: vpshufb %xmm8, %xmm4, %xmm4 -+; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm4, %zmm0, %zmm4 -+; AVX512DQ-FCP-NEXT: vextracti32x4 $2, %zmm4, %xmm4 -+; AVX512DQ-FCP-NEXT: vpblendw {{.*#+}} xmm4 = xmm4[0,1,2,3,4],xmm3[5,6,7] -+; AVX512DQ-FCP-NEXT: vpblendd {{.*#+}} ymm3 = ymm4[0,1,2,3],ymm3[4,5,6,7] - ; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm3, %zmm0, %zmm0 --; AVX512DQ-FCP-NEXT: vmovdqa64 %zmm17, (%rsi) -+; AVX512DQ-FCP-NEXT: vmovdqa64 %zmm18, (%rsi) - ; AVX512DQ-FCP-NEXT: vmovdqa64 %zmm16, 64(%rsi) - ; AVX512DQ-FCP-NEXT: vmovdqa64 %zmm19, 64(%rdx) --; AVX512DQ-FCP-NEXT: vmovdqa64 %zmm2, (%rdx) -+; AVX512DQ-FCP-NEXT: vmovdqa64 %zmm1, (%rdx) - ; AVX512DQ-FCP-NEXT: vmovdqa64 %zmm0, 64(%rcx) --; AVX512DQ-FCP-NEXT: vmovdqa64 %zmm1, (%rcx) -+; AVX512DQ-FCP-NEXT: vmovdqa64 %zmm2, (%rcx) - ; AVX512DQ-FCP-NEXT: vzeroupper - ; AVX512DQ-FCP-NEXT: retq - ; -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-7.ll b/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-7.ll ---- a/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-7.ll -+++ b/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-7.ll -@@ -1246,28 +1246,29 @@ - ; AVX512BW-FCP-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm0 - ; AVX512BW-FCP-NEXT: vpmovsxbq {{.*#+}} ymm1 = [0,2,4,0] - ; AVX512BW-FCP-NEXT: vpermi2q %ymm3, %ymm0, %ymm1 --; AVX512BW-FCP-NEXT: vpshufb {{.*#+}} ymm0 = ymm2[0,8],zero,zero,zero,zero,zero,ymm2[1,9],zero,zero,zero,zero,zero,ymm2[2,10,18,26],zero,zero,zero,zero,zero,ymm2[19,27],zero,zero,zero,zero,zero,ymm2[20,28] --; AVX512BW-FCP-NEXT: vpermq {{.*#+}} ymm3 = ymm2[2,3,0,1] --; AVX512BW-FCP-NEXT: vpshufb {{.*#+}} ymm3 = zero,zero,ymm3[0,8],zero,zero,zero,zero,zero,ymm3[1,9],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,ymm3[19,27],zero,zero,zero,zero,zero,ymm3[20,28],zero,zero -+; AVX512BW-FCP-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [1,3,5,7,1,3,5,7] -+; AVX512BW-FCP-NEXT: # ymm0 = mem[0,1,0,1] -+; AVX512BW-FCP-NEXT: vpermd %ymm2, %ymm0, %ymm0 -+; AVX512BW-FCP-NEXT: vpshufb {{.*#+}} ymm0 = zero,zero,zero,ymm0[1,5,9,13],zero,zero,zero,ymm0[2,6,10,14],zero,zero,zero,ymm0[19,23,27,31],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero -+; AVX512BW-FCP-NEXT: vpmovsxbd {{.*#+}} ymm3 = [1,3,5,0,5,1,3,0] -+; AVX512BW-FCP-NEXT: vpermd %ymm1, %ymm3, %ymm3 -+; AVX512BW-FCP-NEXT: vpshufb {{.*#+}} ymm3 = ymm3[0,4,8],zero,zero,zero,zero,ymm3[1,5,9],zero,zero,zero,zero,ymm3[2,6,18],zero,zero,zero,zero,ymm3[23,27,19],zero,zero,zero,zero,zero,zero,zero,zero - ; AVX512BW-FCP-NEXT: vpor %ymm0, %ymm3, %ymm0 -+; AVX512BW-FCP-NEXT: vpshufb {{.*#+}} ymm3 = 
ymm2[0,8],zero,zero,zero,zero,zero,ymm2[1,9],zero,zero,zero,zero,zero,ymm2[2,10,18,26],zero,zero,zero,zero,zero,ymm2[19,27],zero,zero,zero,zero,zero,ymm2[20,28] -+; AVX512BW-FCP-NEXT: vpermq {{.*#+}} ymm2 = ymm2[2,3,0,1] -+; AVX512BW-FCP-NEXT: vpshufb {{.*#+}} ymm2 = zero,zero,ymm2[0,8],zero,zero,zero,zero,zero,ymm2[1,9],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,ymm2[19,27],zero,zero,zero,zero,zero,ymm2[20,28],zero,zero -+; AVX512BW-FCP-NEXT: vpor %ymm3, %ymm2, %ymm2 - ; AVX512BW-FCP-NEXT: vbroadcasti128 {{.*#+}} ymm3 = [0,2,4,6,0,2,4,6] - ; AVX512BW-FCP-NEXT: # ymm3 = mem[0,1,0,1] --; AVX512BW-FCP-NEXT: vpermd %ymm1, %ymm3, %ymm3 -+; AVX512BW-FCP-NEXT: vpermd %ymm1, %ymm3, %ymm1 - ; AVX512BW-FCP-NEXT: movl $236730480, %ecx # imm = 0xE1C3870 - ; AVX512BW-FCP-NEXT: kmovd %ecx, %k1 --; AVX512BW-FCP-NEXT: vpshufb {{.*#+}} ymm0 {%k1} = ymm3[u,u,u,u,0,4,8,u,u,u,u,1,5,9,u,u,u,u,18,22,26,u,u,u,u,19,23,27,u,u,u,u] --; AVX512BW-FCP-NEXT: vbroadcasti128 {{.*#+}} ymm3 = [1,3,5,7,1,3,5,7] --; AVX512BW-FCP-NEXT: # ymm3 = mem[0,1,0,1] --; AVX512BW-FCP-NEXT: vpermd %ymm2, %ymm3, %ymm2 --; AVX512BW-FCP-NEXT: vpshufb {{.*#+}} ymm2 = zero,zero,zero,ymm2[1,5,9,13],zero,zero,zero,ymm2[2,6,10,14],zero,zero,zero,ymm2[19,23,27,31],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero --; AVX512BW-FCP-NEXT: vpmovsxbd {{.*#+}} ymm3 = [1,3,5,0,5,1,3,0] --; AVX512BW-FCP-NEXT: vpermd %ymm1, %ymm3, %ymm1 --; AVX512BW-FCP-NEXT: vpshufb {{.*#+}} ymm1 = ymm1[0,4,8],zero,zero,zero,zero,ymm1[1,5,9],zero,zero,zero,zero,ymm1[2,6,18],zero,zero,zero,zero,ymm1[23,27,19],zero,zero,zero,zero,zero,zero,zero,zero --; AVX512BW-FCP-NEXT: vpor %ymm2, %ymm1, %ymm1 --; AVX512BW-FCP-NEXT: vextracti128 $1, %ymm1, %xmm2 --; AVX512BW-FCP-NEXT: vmovq %xmm2, 48(%rax) --; AVX512BW-FCP-NEXT: vmovdqa %xmm1, 32(%rax) --; AVX512BW-FCP-NEXT: vmovdqa %ymm0, (%rax) -+; AVX512BW-FCP-NEXT: vpshufb {{.*#+}} ymm2 {%k1} = ymm1[u,u,u,u,0,4,8,u,u,u,u,1,5,9,u,u,u,u,18,22,26,u,u,u,u,19,23,27,u,u,u,u] -+; AVX512BW-FCP-NEXT: vinserti64x4 $1, %ymm0, %zmm2, %zmm0 -+; AVX512BW-FCP-NEXT: vmovdqa %ymm2, (%rax) -+; AVX512BW-FCP-NEXT: vextracti32x4 $2, %zmm0, 32(%rax) -+; AVX512BW-FCP-NEXT: vextracti32x4 $3, %zmm0, %xmm0 -+; AVX512BW-FCP-NEXT: vmovq %xmm0, 48(%rax) - ; AVX512BW-FCP-NEXT: vzeroupper - ; AVX512BW-FCP-NEXT: retq - ; -@@ -1325,28 +1326,29 @@ - ; AVX512DQ-BW-FCP-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm0 - ; AVX512DQ-BW-FCP-NEXT: vpmovsxbq {{.*#+}} ymm1 = [0,2,4,0] - ; AVX512DQ-BW-FCP-NEXT: vpermi2q %ymm3, %ymm0, %ymm1 --; AVX512DQ-BW-FCP-NEXT: vpshufb {{.*#+}} ymm0 = ymm2[0,8],zero,zero,zero,zero,zero,ymm2[1,9],zero,zero,zero,zero,zero,ymm2[2,10,18,26],zero,zero,zero,zero,zero,ymm2[19,27],zero,zero,zero,zero,zero,ymm2[20,28] --; AVX512DQ-BW-FCP-NEXT: vpermq {{.*#+}} ymm3 = ymm2[2,3,0,1] --; AVX512DQ-BW-FCP-NEXT: vpshufb {{.*#+}} ymm3 = zero,zero,ymm3[0,8],zero,zero,zero,zero,zero,ymm3[1,9],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,ymm3[19,27],zero,zero,zero,zero,zero,ymm3[20,28],zero,zero -+; AVX512DQ-BW-FCP-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [1,3,5,7,1,3,5,7] -+; AVX512DQ-BW-FCP-NEXT: # ymm0 = mem[0,1,0,1] -+; AVX512DQ-BW-FCP-NEXT: vpermd %ymm2, %ymm0, %ymm0 -+; AVX512DQ-BW-FCP-NEXT: vpshufb {{.*#+}} ymm0 = zero,zero,zero,ymm0[1,5,9,13],zero,zero,zero,ymm0[2,6,10,14],zero,zero,zero,ymm0[19,23,27,31],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero -+; AVX512DQ-BW-FCP-NEXT: vpmovsxbd {{.*#+}} ymm3 = [1,3,5,0,5,1,3,0] -+; AVX512DQ-BW-FCP-NEXT: vpermd %ymm1, %ymm3, %ymm3 -+; AVX512DQ-BW-FCP-NEXT: vpshufb {{.*#+}} ymm3 = 
ymm3[0,4,8],zero,zero,zero,zero,ymm3[1,5,9],zero,zero,zero,zero,ymm3[2,6,18],zero,zero,zero,zero,ymm3[23,27,19],zero,zero,zero,zero,zero,zero,zero,zero - ; AVX512DQ-BW-FCP-NEXT: vpor %ymm0, %ymm3, %ymm0 -+; AVX512DQ-BW-FCP-NEXT: vpshufb {{.*#+}} ymm3 = ymm2[0,8],zero,zero,zero,zero,zero,ymm2[1,9],zero,zero,zero,zero,zero,ymm2[2,10,18,26],zero,zero,zero,zero,zero,ymm2[19,27],zero,zero,zero,zero,zero,ymm2[20,28] -+; AVX512DQ-BW-FCP-NEXT: vpermq {{.*#+}} ymm2 = ymm2[2,3,0,1] -+; AVX512DQ-BW-FCP-NEXT: vpshufb {{.*#+}} ymm2 = zero,zero,ymm2[0,8],zero,zero,zero,zero,zero,ymm2[1,9],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,ymm2[19,27],zero,zero,zero,zero,zero,ymm2[20,28],zero,zero -+; AVX512DQ-BW-FCP-NEXT: vpor %ymm3, %ymm2, %ymm2 - ; AVX512DQ-BW-FCP-NEXT: vbroadcasti128 {{.*#+}} ymm3 = [0,2,4,6,0,2,4,6] - ; AVX512DQ-BW-FCP-NEXT: # ymm3 = mem[0,1,0,1] --; AVX512DQ-BW-FCP-NEXT: vpermd %ymm1, %ymm3, %ymm3 -+; AVX512DQ-BW-FCP-NEXT: vpermd %ymm1, %ymm3, %ymm1 - ; AVX512DQ-BW-FCP-NEXT: movl $236730480, %ecx # imm = 0xE1C3870 - ; AVX512DQ-BW-FCP-NEXT: kmovd %ecx, %k1 --; AVX512DQ-BW-FCP-NEXT: vpshufb {{.*#+}} ymm0 {%k1} = ymm3[u,u,u,u,0,4,8,u,u,u,u,1,5,9,u,u,u,u,18,22,26,u,u,u,u,19,23,27,u,u,u,u] --; AVX512DQ-BW-FCP-NEXT: vbroadcasti128 {{.*#+}} ymm3 = [1,3,5,7,1,3,5,7] --; AVX512DQ-BW-FCP-NEXT: # ymm3 = mem[0,1,0,1] --; AVX512DQ-BW-FCP-NEXT: vpermd %ymm2, %ymm3, %ymm2 --; AVX512DQ-BW-FCP-NEXT: vpshufb {{.*#+}} ymm2 = zero,zero,zero,ymm2[1,5,9,13],zero,zero,zero,ymm2[2,6,10,14],zero,zero,zero,ymm2[19,23,27,31],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero --; AVX512DQ-BW-FCP-NEXT: vpmovsxbd {{.*#+}} ymm3 = [1,3,5,0,5,1,3,0] --; AVX512DQ-BW-FCP-NEXT: vpermd %ymm1, %ymm3, %ymm1 --; AVX512DQ-BW-FCP-NEXT: vpshufb {{.*#+}} ymm1 = ymm1[0,4,8],zero,zero,zero,zero,ymm1[1,5,9],zero,zero,zero,zero,ymm1[2,6,18],zero,zero,zero,zero,ymm1[23,27,19],zero,zero,zero,zero,zero,zero,zero,zero --; AVX512DQ-BW-FCP-NEXT: vpor %ymm2, %ymm1, %ymm1 --; AVX512DQ-BW-FCP-NEXT: vextracti128 $1, %ymm1, %xmm2 --; AVX512DQ-BW-FCP-NEXT: vmovq %xmm2, 48(%rax) --; AVX512DQ-BW-FCP-NEXT: vmovdqa %xmm1, 32(%rax) --; AVX512DQ-BW-FCP-NEXT: vmovdqa %ymm0, (%rax) -+; AVX512DQ-BW-FCP-NEXT: vpshufb {{.*#+}} ymm2 {%k1} = ymm1[u,u,u,u,0,4,8,u,u,u,u,1,5,9,u,u,u,u,18,22,26,u,u,u,u,19,23,27,u,u,u,u] -+; AVX512DQ-BW-FCP-NEXT: vinserti64x4 $1, %ymm0, %zmm2, %zmm0 -+; AVX512DQ-BW-FCP-NEXT: vmovdqa %ymm2, (%rax) -+; AVX512DQ-BW-FCP-NEXT: vextracti32x4 $2, %zmm0, 32(%rax) -+; AVX512DQ-BW-FCP-NEXT: vextracti32x4 $3, %zmm0, %xmm0 -+; AVX512DQ-BW-FCP-NEXT: vmovq %xmm0, 48(%rax) - ; AVX512DQ-BW-FCP-NEXT: vzeroupper - ; AVX512DQ-BW-FCP-NEXT: retq - %in.vec0 = load <8 x i8>, ptr %in.vecptr0, align 64 -@@ -2051,76 +2053,77 @@ - ; AVX512: # %bb.0: - ; AVX512-NEXT: movq {{[0-9]+}}(%rsp), %rax - ; AVX512-NEXT: movq {{[0-9]+}}(%rsp), %r10 --; AVX512-NEXT: vmovdqa (%rdi), %xmm0 --; AVX512-NEXT: vmovdqa (%rsi), %xmm1 --; AVX512-NEXT: vmovdqa (%rdx), %xmm5 --; AVX512-NEXT: vmovdqa (%rcx), %xmm6 --; AVX512-NEXT: vmovdqa (%r8), %xmm3 --; AVX512-NEXT: vmovdqa (%r9), %xmm4 --; AVX512-NEXT: vmovdqa (%r10), %xmm2 --; AVX512-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm8 --; AVX512-NEXT: vinserti128 $1, %xmm6, %ymm5, %ymm9 --; AVX512-NEXT: vinserti128 $1, %xmm4, %ymm3, %ymm7 --; AVX512-NEXT: vpshufb {{.*#+}} ymm10 = ymm9[u,u,u,u,u,5],zero,ymm9[u,u,u,u,u,6],zero,ymm9[u,u,u,u,u],zero,ymm9[23,u,u,u,u,u],zero,ymm9[24,u,u,u,u] --; AVX512-NEXT: vpermq {{.*#+}} ymm11 = ymm9[2,3,0,1] --; AVX512-NEXT: vpshufb {{.*#+}} ymm11 = 
ymm11[u,u,u,u,u],zero,ymm11[5,u,u,u,u,u],zero,ymm11[6,u,u,u,u,u,23],zero,ymm11[u,u,u,u,u,24],zero,ymm11[u,u,u,u] --; AVX512-NEXT: vmovdqa {{.*#+}} ymm12 = [255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255] --; AVX512-NEXT: vpternlogq $50, %ymm10, %ymm12, %ymm11 --; AVX512-NEXT: vpermq {{.*#+}} ymm10 = ymm9[0,2,0,2] --; AVX512-NEXT: vpshufb {{.*#+}} ymm10 = zero,zero,ymm10[0,8,u,u,u],zero,zero,ymm10[1,9,u,u,u],zero,zero,ymm10[18,26,u,u,u],zero,zero,ymm10[19,27,u,u,u],zero,zero,ymm10[20,28] --; AVX512-NEXT: vinserti64x4 $1, %ymm11, %zmm10, %zmm10 --; AVX512-NEXT: vpshufb {{.*#+}} ymm11 = ymm8[u,u,u,5],zero,ymm8[u,u,u,u,u,6],zero,ymm8[u,u,u,u,u],zero,ymm8[23,u,u,u,u,u],zero,ymm8[24,u,u,u,u,u],zero --; AVX512-NEXT: vpermq {{.*#+}} ymm13 = ymm8[2,3,0,1] --; AVX512-NEXT: vpshufb {{.*#+}} ymm13 = ymm13[u,u,u],zero,ymm13[5,u,u,u,u,u],zero,ymm13[6,u,u,u,u,u,23],zero,ymm13[u,u,u,u,u,24],zero,ymm13[u,u,u,u,u,25] --; AVX512-NEXT: vpternlogq $200, %ymm11, %ymm12, %ymm13 --; AVX512-NEXT: vpermq {{.*#+}} ymm11 = ymm8[0,2,0,2] --; AVX512-NEXT: vpshufb {{.*#+}} ymm11 = ymm11[0,8],zero,zero,ymm11[u,u,u,1,9],zero,zero,ymm11[u,u,u,2,10],zero,zero,ymm11[u,u,u,19,27],zero,zero,ymm11[u,u,u,20,28],zero,zero --; AVX512-NEXT: vinserti64x4 $1, %ymm13, %zmm11, %zmm11 --; AVX512-NEXT: vporq %zmm10, %zmm11, %zmm10 --; AVX512-NEXT: vpshufb {{.*#+}} ymm11 = ymm7[4],zero,ymm7[u,u,u,u,u,5],zero,ymm7[u,u,u,u,u,6],zero,ymm7[u,u,u,u,u],zero,ymm7[23,u,u,u,u,u],zero,ymm7[24,u,u] --; AVX512-NEXT: vpermq {{.*#+}} ymm12 = ymm7[2,3,0,1] --; AVX512-NEXT: vpshufb {{.*#+}} ymm12 = zero,ymm12[4,u,u,u,u,u],zero,ymm12[5,u,u,u,u,u],zero,ymm12[6,u,u,u,u,u,23],zero,ymm12[u,u,u,u,u,24],zero,ymm12[u,u] --; AVX512-NEXT: vmovdqa {{.*#+}} ymm13 = [255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255] --; AVX512-NEXT: vpternlogq $200, %ymm11, %ymm13, %ymm12 --; AVX512-NEXT: vpermq {{.*#+}} ymm11 = ymm7[0,2,0,2] --; AVX512-NEXT: vpshufb {{.*#+}} ymm11 = ymm11[u,u,u,u,0,8],zero,ymm11[u,u,u,u,1,9],zero,ymm11[u,u,u,u,18,26],zero,ymm11[u,u,u,u,19,27],zero,ymm11[u,u,u,u] --; AVX512-NEXT: vinserti64x4 $1, %ymm12, %zmm11, %zmm11 --; AVX512-NEXT: vpshufb {{.*#+}} xmm12 = xmm2[4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] --; AVX512-NEXT: vpermq {{.*#+}} ymm12 = ymm12[0,0,1,0] --; AVX512-NEXT: vpandn %ymm12, %ymm13, %ymm12 --; AVX512-NEXT: vpshuflw {{.*#+}} xmm13 = xmm2[1,1,0,0,4,5,6,7] --; AVX512-NEXT: vpshufd {{.*#+}} xmm13 = xmm13[0,1,2,0] --; AVX512-NEXT: vpermq {{.*#+}} ymm13 = ymm13[0,0,1,0] --; AVX512-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm13, %ymm13 --; AVX512-NEXT: vinserti64x4 $1, %ymm12, %zmm13, %zmm12 --; AVX512-NEXT: vporq %zmm12, %zmm11, %zmm11 --; AVX512-NEXT: vpternlogd $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm10, %zmm11 --; AVX512-NEXT: vpermq {{.*#+}} ymm8 = ymm8[3,1,1,3] --; AVX512-NEXT: vpshufb {{.*#+}} ymm8 = ymm8[1],zero,zero,ymm8[u,u,u,10,2],zero,zero,ymm8[u,u,u,11,3],zero,zero,ymm8[u,u,u,20,28],zero,zero,ymm8[u,u,u,21,29],zero,zero,ymm8[u] --; AVX512-NEXT: vpermq {{.*#+}} ymm9 = ymm9[1,3,3,1] --; AVX512-NEXT: vpshufb {{.*#+}} ymm9 = zero,ymm9[1,9,u,u,u],zero,zero,ymm9[2,10,u,u,u],zero,zero,ymm9[3,19,u,u,u],zero,zero,ymm9[28,20,u,u,u],zero,zero,ymm9[29,21,u] --; AVX512-NEXT: vpor %ymm8, %ymm9, %ymm8 --; AVX512-NEXT: vpshufhw {{.*#+}} xmm9 = xmm2[0,1,2,3,4,5,5,6] --; AVX512-NEXT: vpshufd {{.*#+}} xmm9 = xmm9[2,2,3,3] --; AVX512-NEXT: vpermq {{.*#+}} ymm9 = ymm9[0,1,0,1] --; AVX512-NEXT: vpermq {{.*#+}} ymm7 = 
ymm7[1,3,1,3] -+; AVX512-NEXT: vmovdqa (%rdi), %xmm4 -+; AVX512-NEXT: vmovdqa (%rsi), %xmm5 -+; AVX512-NEXT: vmovdqa (%rdx), %xmm6 -+; AVX512-NEXT: vmovdqa (%rcx), %xmm7 -+; AVX512-NEXT: vmovdqa (%r8), %xmm0 -+; AVX512-NEXT: vmovdqa (%r10), %xmm1 -+; AVX512-NEXT: vinserti128 $1, %xmm7, %ymm6, %ymm3 -+; AVX512-NEXT: vinserti128 $1, %xmm5, %ymm4, %ymm2 -+; AVX512-NEXT: vinserti128 $1, (%r9), %ymm0, %ymm0 -+; AVX512-NEXT: vinserti32x4 $2, %xmm1, %zmm0, %zmm0 -+; AVX512-NEXT: vpunpckhbw {{.*#+}} xmm6 = xmm6[8],xmm7[8],xmm6[9],xmm7[9],xmm6[10],xmm7[10],xmm6[11],xmm7[11],xmm6[12],xmm7[12],xmm6[13],xmm7[13],xmm6[14],xmm7[14],xmm6[15],xmm7[15] -+; AVX512-NEXT: vpshufb {{.*#+}} xmm6 = xmm6[u,u],zero,zero,xmm6[12,13,u,u,u],zero,zero,xmm6[14,15,u,u,u] -+; AVX512-NEXT: vpunpckhbw {{.*#+}} xmm4 = xmm4[8],xmm5[8],xmm4[9],xmm5[9],xmm4[10],xmm5[10],xmm4[11],xmm5[11],xmm4[12],xmm5[12],xmm4[13],xmm5[13],xmm4[14],xmm5[14],xmm4[15],xmm5[15] -+; AVX512-NEXT: vpshufb {{.*#+}} xmm4 = xmm4[u,u,12,13],zero,zero,xmm4[u,u,u,14,15],zero,zero,xmm4[u,u,u] -+; AVX512-NEXT: vpor %xmm6, %xmm4, %xmm4 -+; AVX512-NEXT: vextracti128 $1, %ymm0, %xmm5 -+; AVX512-NEXT: vpunpckhbw {{.*#+}} xmm5 = xmm5[8],xmm0[8],xmm5[9],xmm0[9],xmm5[10],xmm0[10],xmm5[11],xmm0[11],xmm5[12],xmm0[12],xmm5[13],xmm0[13],xmm5[14],xmm0[14],xmm5[15],xmm0[15] -+; AVX512-NEXT: vpshufb {{.*#+}} xmm5 = xmm5[10],zero,xmm5[u,u,u,u,13,12],zero,xmm5[u,u,u,u,15,14],zero -+; AVX512-NEXT: vpshufb {{.*#+}} xmm6 = zero,xmm1[13,u,u,u,u],zero,zero,xmm1[14,u,u,u,u],zero,zero,xmm1[15] -+; AVX512-NEXT: vpor %xmm6, %xmm5, %xmm5 -+; AVX512-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm4, %xmm5 -+; AVX512-NEXT: vpermq {{.*#+}} ymm4 = ymm2[3,1,1,3] -+; AVX512-NEXT: vpshufb {{.*#+}} ymm4 = ymm4[1],zero,zero,ymm4[u,u,u,10,2],zero,zero,ymm4[u,u,u,11,3],zero,zero,ymm4[u,u,u,20,28],zero,zero,ymm4[u,u,u,21,29],zero,zero,ymm4[u] -+; AVX512-NEXT: vpermq {{.*#+}} ymm6 = ymm3[1,3,3,1] -+; AVX512-NEXT: vpshufb {{.*#+}} ymm6 = zero,ymm6[1,9,u,u,u],zero,zero,ymm6[2,10,u,u,u],zero,zero,ymm6[3,19,u,u,u],zero,zero,ymm6[28,20,u,u,u],zero,zero,ymm6[29,21,u] -+; AVX512-NEXT: vpor %ymm4, %ymm6, %ymm4 -+; AVX512-NEXT: vpshufhw {{.*#+}} xmm6 = xmm1[0,1,2,3,4,5,5,6] -+; AVX512-NEXT: vpshufd {{.*#+}} xmm6 = xmm6[2,2,3,3] -+; AVX512-NEXT: vpermq {{.*#+}} ymm6 = ymm6[0,1,0,1] -+; AVX512-NEXT: vpermq {{.*#+}} ymm7 = ymm0[1,3,1,3] - ; AVX512-NEXT: vpshufb {{.*#+}} ymm7 = ymm7[u,u,u,1,9],zero,ymm7[u,u,u,u,2,10],zero,ymm7[u,u,u,u,19,27],zero,ymm7[u,u,u,u,20,28],zero,ymm7[u,u,u,u,21] --; AVX512-NEXT: vpternlogq $244, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm9, %ymm7 --; AVX512-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm8, %ymm7 --; AVX512-NEXT: vpunpckhbw {{.*#+}} xmm5 = xmm5[8],xmm6[8],xmm5[9],xmm6[9],xmm5[10],xmm6[10],xmm5[11],xmm6[11],xmm5[12],xmm6[12],xmm5[13],xmm6[13],xmm5[14],xmm6[14],xmm5[15],xmm6[15] --; AVX512-NEXT: vpshufb {{.*#+}} xmm5 = xmm5[u,u],zero,zero,xmm5[12,13,u,u,u],zero,zero,xmm5[14,15,u,u,u] --; AVX512-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm0[8],xmm1[8],xmm0[9],xmm1[9],xmm0[10],xmm1[10],xmm0[11],xmm1[11],xmm0[12],xmm1[12],xmm0[13],xmm1[13],xmm0[14],xmm1[14],xmm0[15],xmm1[15] --; AVX512-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[u,u,12,13],zero,zero,xmm0[u,u,u,14,15],zero,zero,xmm0[u,u,u] --; AVX512-NEXT: vpor %xmm5, %xmm0, %xmm0 --; AVX512-NEXT: vpunpckhbw {{.*#+}} xmm1 = xmm4[8],xmm3[8],xmm4[9],xmm3[9],xmm4[10],xmm3[10],xmm4[11],xmm3[11],xmm4[12],xmm3[12],xmm4[13],xmm3[13],xmm4[14],xmm3[14],xmm4[15],xmm3[15] --; AVX512-NEXT: vpshufb {{.*#+}} xmm1 = 
xmm1[10],zero,xmm1[u,u,u,u,13,12],zero,xmm1[u,u,u,u,15,14],zero --; AVX512-NEXT: vpshufb {{.*#+}} xmm2 = zero,xmm2[13,u,u,u,u],zero,zero,xmm2[14,u,u,u,u],zero,zero,xmm2[15] --; AVX512-NEXT: vpor %xmm2, %xmm1, %xmm1 --; AVX512-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1 --; AVX512-NEXT: vinserti32x4 $2, %xmm1, %zmm7, %zmm0 --; AVX512-NEXT: vmovdqa %xmm1, 96(%rax) --; AVX512-NEXT: vmovdqa %ymm0, 64(%rax) --; AVX512-NEXT: vmovdqa64 %zmm11, (%rax) -+; AVX512-NEXT: vpternlogq $244, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm6, %ymm7 -+; AVX512-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm4, %ymm7 -+; AVX512-NEXT: vinserti32x4 $2, %xmm5, %zmm7, %zmm4 -+; AVX512-NEXT: vpshufb {{.*#+}} ymm6 = ymm3[u,u,u,u,u,5],zero,ymm3[u,u,u,u,u,6],zero,ymm3[u,u,u,u,u],zero,ymm3[23,u,u,u,u,u],zero,ymm3[24,u,u,u,u] -+; AVX512-NEXT: vpermq {{.*#+}} ymm7 = ymm3[2,3,0,1] -+; AVX512-NEXT: vpshufb {{.*#+}} ymm7 = ymm7[u,u,u,u,u],zero,ymm7[5,u,u,u,u,u],zero,ymm7[6,u,u,u,u,u,23],zero,ymm7[u,u,u,u,u,24],zero,ymm7[u,u,u,u] -+; AVX512-NEXT: vmovdqa {{.*#+}} ymm8 = [255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255] -+; AVX512-NEXT: vpternlogq $50, %ymm6, %ymm8, %ymm7 -+; AVX512-NEXT: vpermq {{.*#+}} ymm3 = ymm3[0,2,0,2] -+; AVX512-NEXT: vpshufb {{.*#+}} ymm3 = zero,zero,ymm3[0,8,u,u,u],zero,zero,ymm3[1,9,u,u,u],zero,zero,ymm3[18,26,u,u,u],zero,zero,ymm3[19,27,u,u,u],zero,zero,ymm3[20,28] -+; AVX512-NEXT: vinserti64x4 $1, %ymm7, %zmm3, %zmm3 -+; AVX512-NEXT: vpshufb {{.*#+}} ymm6 = ymm2[u,u,u,5],zero,ymm2[u,u,u,u,u,6],zero,ymm2[u,u,u,u,u],zero,ymm2[23,u,u,u,u,u],zero,ymm2[24,u,u,u,u,u],zero -+; AVX512-NEXT: vpermq {{.*#+}} ymm7 = ymm2[2,3,0,1] -+; AVX512-NEXT: vpshufb {{.*#+}} ymm7 = ymm7[u,u,u],zero,ymm7[5,u,u,u,u,u],zero,ymm7[6,u,u,u,u,u,23],zero,ymm7[u,u,u,u,u,24],zero,ymm7[u,u,u,u,u,25] -+; AVX512-NEXT: vpternlogq $200, %ymm6, %ymm8, %ymm7 -+; AVX512-NEXT: vpermq {{.*#+}} ymm2 = ymm2[0,2,0,2] -+; AVX512-NEXT: vpshufb {{.*#+}} ymm2 = ymm2[0,8],zero,zero,ymm2[u,u,u,1,9],zero,zero,ymm2[u,u,u,2,10],zero,zero,ymm2[u,u,u,19,27],zero,zero,ymm2[u,u,u,20,28],zero,zero -+; AVX512-NEXT: vinserti64x4 $1, %ymm7, %zmm2, %zmm2 -+; AVX512-NEXT: vporq %zmm3, %zmm2, %zmm2 -+; AVX512-NEXT: vpshufb {{.*#+}} xmm3 = xmm1[4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] -+; AVX512-NEXT: vpermq {{.*#+}} ymm3 = ymm3[0,0,1,0] -+; AVX512-NEXT: vmovdqa {{.*#+}} ymm6 = [255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255] -+; AVX512-NEXT: vpandn %ymm3, %ymm6, %ymm3 -+; AVX512-NEXT: vpshuflw {{.*#+}} xmm1 = xmm1[1,1,0,0,4,5,6,7] -+; AVX512-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[0,1,2,0] -+; AVX512-NEXT: vpermq {{.*#+}} ymm1 = ymm1[0,0,1,0] -+; AVX512-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm1, %ymm1 -+; AVX512-NEXT: vinserti64x4 $1, %ymm3, %zmm1, %zmm1 -+; AVX512-NEXT: vpshufb {{.*#+}} ymm3 = ymm0[4],zero,ymm0[u,u,u,u,u,5],zero,ymm0[u,u,u,u,u,6],zero,ymm0[u,u,u,u,u],zero,ymm0[23,u,u,u,u,u],zero,ymm0[24,u,u] -+; AVX512-NEXT: vpermq {{.*#+}} ymm7 = ymm0[2,3,0,1] -+; AVX512-NEXT: vpshufb {{.*#+}} ymm7 = zero,ymm7[4,u,u,u,u,u],zero,ymm7[5,u,u,u,u,u],zero,ymm7[6,u,u,u,u,u,23],zero,ymm7[u,u,u,u,u,24],zero,ymm7[u,u] -+; AVX512-NEXT: vpternlogq $200, %ymm3, %ymm6, %ymm7 -+; AVX512-NEXT: vpermq {{.*#+}} ymm0 = ymm0[0,2,0,2] -+; AVX512-NEXT: vpshufb {{.*#+}} ymm0 = ymm0[u,u,u,u,0,8],zero,ymm0[u,u,u,u,1,9],zero,ymm0[u,u,u,u,18,26],zero,ymm0[u,u,u,u,19,27],zero,ymm0[u,u,u,u] -+; AVX512-NEXT: vinserti64x4 $1, %ymm7, 
%zmm0, %zmm0 -+; AVX512-NEXT: vporq %zmm1, %zmm0, %zmm0 -+; AVX512-NEXT: vpternlogd $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm0 -+; AVX512-NEXT: vmovdqa %xmm5, 96(%rax) -+; AVX512-NEXT: vmovdqa64 %zmm0, (%rax) -+; AVX512-NEXT: vmovdqa %ymm4, 64(%rax) - ; AVX512-NEXT: vzeroupper - ; AVX512-NEXT: retq - ; -@@ -2128,69 +2131,70 @@ - ; AVX512-FCP: # %bb.0: - ; AVX512-FCP-NEXT: movq {{[0-9]+}}(%rsp), %rax - ; AVX512-FCP-NEXT: movq {{[0-9]+}}(%rsp), %r10 --; AVX512-FCP-NEXT: vmovdqa (%rdi), %xmm0 --; AVX512-FCP-NEXT: vmovdqa (%rsi), %xmm1 --; AVX512-FCP-NEXT: vmovdqa (%rdx), %xmm5 --; AVX512-FCP-NEXT: vmovdqa (%rcx), %xmm6 --; AVX512-FCP-NEXT: vmovdqa (%r8), %xmm3 --; AVX512-FCP-NEXT: vmovdqa (%r9), %xmm4 --; AVX512-FCP-NEXT: vmovdqa (%r10), %xmm2 --; AVX512-FCP-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm7 --; AVX512-FCP-NEXT: vinserti128 $1, %xmm6, %ymm5, %ymm8 --; AVX512-FCP-NEXT: vinserti128 $1, %xmm4, %ymm3, %ymm9 --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm10 = ymm8[0,2,0,2] --; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm10 = zero,zero,ymm10[0,8,u,u,u],zero,zero,ymm10[1,9,u,u,u],zero,zero,ymm10[18,26,u,u,u],zero,zero,ymm10[19,27,u,u,u],zero,zero,ymm10[20,28] --; AVX512-FCP-NEXT: vbroadcasti128 {{.*#+}} ymm11 = [1,5,2,6,1,5,2,6] --; AVX512-FCP-NEXT: # ymm11 = mem[0,1,0,1] --; AVX512-FCP-NEXT: vpermd %ymm8, %ymm11, %ymm12 --; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm12 = ymm12[u,u,u],zero,zero,ymm12[1,5,u,u,u],zero,zero,ymm12[2,6,u,u,u],zero,zero,ymm12[19,23,u,u,u],zero,zero,ymm12[24,28,u,u,u],zero --; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm12, %zmm10, %zmm10 --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm12 = ymm7[0,2,0,2] --; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm12 = ymm12[0,8],zero,zero,ymm12[u,u,u,1,9],zero,zero,ymm12[u,u,u,2,10],zero,zero,ymm12[u,u,u,19,27],zero,zero,ymm12[u,u,u,20,28],zero,zero --; AVX512-FCP-NEXT: vpermd %ymm7, %ymm11, %ymm13 --; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm13 = ymm13[u,u,u,1,5],zero,zero,ymm13[u,u,u,2,6],zero,zero,ymm13[u,u,u,19,23],zero,zero,ymm13[u,u,u,24,28],zero,zero,ymm13[u,u,u,25] --; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm13, %zmm12, %zmm12 --; AVX512-FCP-NEXT: vporq %zmm10, %zmm12, %zmm10 --; AVX512-FCP-NEXT: vpshuflw {{.*#+}} xmm12 = xmm2[1,1,0,0,4,5,6,7] --; AVX512-FCP-NEXT: vpmovsxbd {{.*#+}} ymm13 = [0,1,0,1,0,0,0,0] --; AVX512-FCP-NEXT: vpermd %ymm12, %ymm13, %ymm12 --; AVX512-FCP-NEXT: vpshufb {{.*#+}} xmm13 = xmm2[4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm13 = ymm13[0,0,1,0] --; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm13, %zmm12, %zmm12 --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm13 = ymm9[0,2,0,2] --; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm13 = ymm13[u,u,u,u,0,8],zero,ymm13[u,u,u,u,1,9],zero,ymm13[u,u,u,u,18,26],zero,ymm13[u,u,u,u,19,27],zero,ymm13[u,u,u,u] --; AVX512-FCP-NEXT: vpermd %ymm9, %ymm11, %ymm11 --; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm11 = ymm11[0,4],zero,ymm11[u,u,u,u,1,5],zero,ymm11[u,u,u,u,2,6],zero,ymm11[u,u,u,u,19,23],zero,ymm11[u,u,u,u,24,28],zero,ymm11[u] --; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm11, %zmm13, %zmm11 --; AVX512-FCP-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm12, %zmm11 --; AVX512-FCP-NEXT: vpternlogd $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm10, %zmm11 --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm7 = ymm7[3,1,1,3] --; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm7 = ymm7[1],zero,zero,ymm7[u,u,u,10,2],zero,zero,ymm7[u,u,u,11,3],zero,zero,ymm7[u,u,u,20,28],zero,zero,ymm7[u,u,u,21,29],zero,zero,ymm7[u] --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm8 = ymm8[1,3,3,1] --; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm8 = 
zero,ymm8[1,9,u,u,u],zero,zero,ymm8[2,10,u,u,u],zero,zero,ymm8[3,19,u,u,u],zero,zero,ymm8[28,20,u,u,u],zero,zero,ymm8[29,21,u] --; AVX512-FCP-NEXT: vpor %ymm7, %ymm8, %ymm7 --; AVX512-FCP-NEXT: vpshufhw {{.*#+}} xmm8 = xmm2[0,1,2,3,4,5,5,6] --; AVX512-FCP-NEXT: vbroadcasti128 {{.*#+}} ymm10 = [2,2,3,3,2,2,3,3] --; AVX512-FCP-NEXT: # ymm10 = mem[0,1,0,1] --; AVX512-FCP-NEXT: vpermd %ymm8, %ymm10, %ymm8 --; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm9 = ymm9[1,3,1,3] --; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm9 = ymm9[u,u,u,1,9],zero,ymm9[u,u,u,u,2,10],zero,ymm9[u,u,u,u,19,27],zero,ymm9[u,u,u,u,20,28],zero,ymm9[u,u,u,u,21] --; AVX512-FCP-NEXT: vpternlogq $244, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm8, %ymm9 --; AVX512-FCP-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm7, %ymm9 --; AVX512-FCP-NEXT: vpunpckhbw {{.*#+}} xmm5 = xmm5[8],xmm6[8],xmm5[9],xmm6[9],xmm5[10],xmm6[10],xmm5[11],xmm6[11],xmm5[12],xmm6[12],xmm5[13],xmm6[13],xmm5[14],xmm6[14],xmm5[15],xmm6[15] --; AVX512-FCP-NEXT: vpshufb {{.*#+}} xmm5 = xmm5[u,u],zero,zero,xmm5[12,13,u,u,u],zero,zero,xmm5[14,15,u,u,u] --; AVX512-FCP-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm0[8],xmm1[8],xmm0[9],xmm1[9],xmm0[10],xmm1[10],xmm0[11],xmm1[11],xmm0[12],xmm1[12],xmm0[13],xmm1[13],xmm0[14],xmm1[14],xmm0[15],xmm1[15] --; AVX512-FCP-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[u,u,12,13],zero,zero,xmm0[u,u,u,14,15],zero,zero,xmm0[u,u,u] --; AVX512-FCP-NEXT: vpor %xmm5, %xmm0, %xmm0 --; AVX512-FCP-NEXT: vpunpckhbw {{.*#+}} xmm1 = xmm4[8],xmm3[8],xmm4[9],xmm3[9],xmm4[10],xmm3[10],xmm4[11],xmm3[11],xmm4[12],xmm3[12],xmm4[13],xmm3[13],xmm4[14],xmm3[14],xmm4[15],xmm3[15] --; AVX512-FCP-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[10],zero,xmm1[u,u,u,u,13,12],zero,xmm1[u,u,u,u,15,14],zero --; AVX512-FCP-NEXT: vpshufb {{.*#+}} xmm2 = zero,xmm2[13,u,u,u,u],zero,zero,xmm2[14,u,u,u,u],zero,zero,xmm2[15] --; AVX512-FCP-NEXT: vpor %xmm2, %xmm1, %xmm1 --; AVX512-FCP-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1 --; AVX512-FCP-NEXT: vinserti32x4 $2, %xmm1, %zmm9, %zmm0 --; AVX512-FCP-NEXT: vmovdqa %xmm1, 96(%rax) --; AVX512-FCP-NEXT: vmovdqa64 %zmm11, (%rax) --; AVX512-FCP-NEXT: vmovdqa %ymm0, 64(%rax) -+; AVX512-FCP-NEXT: vmovdqa (%rdi), %xmm2 -+; AVX512-FCP-NEXT: vmovdqa (%rsi), %xmm3 -+; AVX512-FCP-NEXT: vmovdqa (%rdx), %xmm4 -+; AVX512-FCP-NEXT: vmovdqa (%rcx), %xmm5 -+; AVX512-FCP-NEXT: vmovdqa (%r8), %xmm1 -+; AVX512-FCP-NEXT: vmovdqa (%r10), %xmm0 -+; AVX512-FCP-NEXT: vinserti128 $1, %xmm5, %ymm4, %ymm6 -+; AVX512-FCP-NEXT: vinserti128 $1, %xmm3, %ymm2, %ymm7 -+; AVX512-FCP-NEXT: vinserti128 $1, (%r9), %ymm1, %ymm1 -+; AVX512-FCP-NEXT: vinserti32x4 $2, %xmm0, %zmm1, %zmm1 -+; AVX512-FCP-NEXT: vpunpckhbw {{.*#+}} xmm4 = xmm4[8],xmm5[8],xmm4[9],xmm5[9],xmm4[10],xmm5[10],xmm4[11],xmm5[11],xmm4[12],xmm5[12],xmm4[13],xmm5[13],xmm4[14],xmm5[14],xmm4[15],xmm5[15] -+; AVX512-FCP-NEXT: vpshufb {{.*#+}} xmm4 = xmm4[u,u],zero,zero,xmm4[12,13,u,u,u],zero,zero,xmm4[14,15,u,u,u] -+; AVX512-FCP-NEXT: vpunpckhbw {{.*#+}} xmm2 = xmm2[8],xmm3[8],xmm2[9],xmm3[9],xmm2[10],xmm3[10],xmm2[11],xmm3[11],xmm2[12],xmm3[12],xmm2[13],xmm3[13],xmm2[14],xmm3[14],xmm2[15],xmm3[15] -+; AVX512-FCP-NEXT: vpshufb {{.*#+}} xmm2 = xmm2[u,u,12,13],zero,zero,xmm2[u,u,u,14,15],zero,zero,xmm2[u,u,u] -+; AVX512-FCP-NEXT: vpor %xmm4, %xmm2, %xmm2 -+; AVX512-FCP-NEXT: vextracti128 $1, %ymm1, %xmm3 -+; AVX512-FCP-NEXT: vpunpckhbw {{.*#+}} xmm3 = xmm3[8],xmm1[8],xmm3[9],xmm1[9],xmm3[10],xmm1[10],xmm3[11],xmm1[11],xmm3[12],xmm1[12],xmm3[13],xmm1[13],xmm3[14],xmm1[14],xmm3[15],xmm1[15] -+; AVX512-FCP-NEXT: 
vpshufb {{.*#+}} xmm3 = xmm3[10],zero,xmm3[u,u,u,u,13,12],zero,xmm3[u,u,u,u,15,14],zero -+; AVX512-FCP-NEXT: vpshufb {{.*#+}} xmm4 = zero,xmm0[13,u,u,u,u],zero,zero,xmm0[14,u,u,u,u],zero,zero,xmm0[15] -+; AVX512-FCP-NEXT: vpor %xmm4, %xmm3, %xmm3 -+; AVX512-FCP-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm2, %xmm3 -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm2 = ymm7[3,1,1,3] -+; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm2 = ymm2[1],zero,zero,ymm2[u,u,u,10,2],zero,zero,ymm2[u,u,u,11,3],zero,zero,ymm2[u,u,u,20,28],zero,zero,ymm2[u,u,u,21,29],zero,zero,ymm2[u] -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm4 = ymm6[1,3,3,1] -+; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm4 = zero,ymm4[1,9,u,u,u],zero,zero,ymm4[2,10,u,u,u],zero,zero,ymm4[3,19,u,u,u],zero,zero,ymm4[28,20,u,u,u],zero,zero,ymm4[29,21,u] -+; AVX512-FCP-NEXT: vpor %ymm2, %ymm4, %ymm2 -+; AVX512-FCP-NEXT: vpshufhw {{.*#+}} xmm4 = xmm0[0,1,2,3,4,5,5,6] -+; AVX512-FCP-NEXT: vbroadcasti128 {{.*#+}} ymm5 = [2,2,3,3,2,2,3,3] -+; AVX512-FCP-NEXT: # ymm5 = mem[0,1,0,1] -+; AVX512-FCP-NEXT: vpermd %ymm4, %ymm5, %ymm4 -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm5 = ymm1[1,3,1,3] -+; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm5 = ymm5[u,u,u,1,9],zero,ymm5[u,u,u,u,2,10],zero,ymm5[u,u,u,u,19,27],zero,ymm5[u,u,u,u,20,28],zero,ymm5[u,u,u,u,21] -+; AVX512-FCP-NEXT: vpternlogq $244, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm4, %ymm5 -+; AVX512-FCP-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm2, %ymm5 -+; AVX512-FCP-NEXT: vinserti32x4 $2, %xmm3, %zmm5, %zmm2 -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm4 = ymm6[0,2,0,2] -+; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm4 = zero,zero,ymm4[0,8,u,u,u],zero,zero,ymm4[1,9,u,u,u],zero,zero,ymm4[18,26,u,u,u],zero,zero,ymm4[19,27,u,u,u],zero,zero,ymm4[20,28] -+; AVX512-FCP-NEXT: vbroadcasti128 {{.*#+}} ymm5 = [1,5,2,6,1,5,2,6] -+; AVX512-FCP-NEXT: # ymm5 = mem[0,1,0,1] -+; AVX512-FCP-NEXT: vpermd %ymm6, %ymm5, %ymm6 -+; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm6 = ymm6[u,u,u],zero,zero,ymm6[1,5,u,u,u],zero,zero,ymm6[2,6,u,u,u],zero,zero,ymm6[19,23,u,u,u],zero,zero,ymm6[24,28,u,u,u],zero -+; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm6, %zmm4, %zmm4 -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm6 = ymm7[0,2,0,2] -+; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm6 = ymm6[0,8],zero,zero,ymm6[u,u,u,1,9],zero,zero,ymm6[u,u,u,2,10],zero,zero,ymm6[u,u,u,19,27],zero,zero,ymm6[u,u,u,20,28],zero,zero -+; AVX512-FCP-NEXT: vpermd %ymm7, %ymm5, %ymm7 -+; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm7 = ymm7[u,u,u,1,5],zero,zero,ymm7[u,u,u,2,6],zero,zero,ymm7[u,u,u,19,23],zero,zero,ymm7[u,u,u,24,28],zero,zero,ymm7[u,u,u,25] -+; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm7, %zmm6, %zmm6 -+; AVX512-FCP-NEXT: vporq %zmm4, %zmm6, %zmm4 -+; AVX512-FCP-NEXT: vpshuflw {{.*#+}} xmm6 = xmm0[1,1,0,0,4,5,6,7] -+; AVX512-FCP-NEXT: vpmovsxbd {{.*#+}} ymm7 = [0,1,0,1,0,0,0,0] -+; AVX512-FCP-NEXT: vpermd %ymm6, %ymm7, %ymm6 -+; AVX512-FCP-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm0 = ymm0[0,0,1,0] -+; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm0, %zmm6, %zmm0 -+; AVX512-FCP-NEXT: vpermq {{.*#+}} ymm6 = ymm1[0,2,0,2] -+; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm6 = ymm6[u,u,u,u,0,8],zero,ymm6[u,u,u,u,1,9],zero,ymm6[u,u,u,u,18,26],zero,ymm6[u,u,u,u,19,27],zero,ymm6[u,u,u,u] -+; AVX512-FCP-NEXT: vpermd %ymm1, %ymm5, %ymm1 -+; AVX512-FCP-NEXT: vpshufb {{.*#+}} ymm1 = ymm1[0,4],zero,ymm1[u,u,u,u,1,5],zero,ymm1[u,u,u,u,2,6],zero,ymm1[u,u,u,u,19,23],zero,ymm1[u,u,u,u,24,28],zero,ymm1[u] -+; AVX512-FCP-NEXT: vinserti64x4 $1, %ymm1, %zmm6, %zmm1 -+; 
AVX512-FCP-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm1 -+; AVX512-FCP-NEXT: vpternlogd $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm1 -+; AVX512-FCP-NEXT: vmovdqa %xmm3, 96(%rax) -+; AVX512-FCP-NEXT: vmovdqa64 %zmm1, (%rax) -+; AVX512-FCP-NEXT: vmovdqa %ymm2, 64(%rax) - ; AVX512-FCP-NEXT: vzeroupper - ; AVX512-FCP-NEXT: retq - ; -@@ -2198,76 +2202,77 @@ - ; AVX512DQ: # %bb.0: - ; AVX512DQ-NEXT: movq {{[0-9]+}}(%rsp), %rax - ; AVX512DQ-NEXT: movq {{[0-9]+}}(%rsp), %r10 --; AVX512DQ-NEXT: vmovdqa (%rdi), %xmm0 --; AVX512DQ-NEXT: vmovdqa (%rsi), %xmm1 --; AVX512DQ-NEXT: vmovdqa (%rdx), %xmm5 --; AVX512DQ-NEXT: vmovdqa (%rcx), %xmm6 --; AVX512DQ-NEXT: vmovdqa (%r8), %xmm3 --; AVX512DQ-NEXT: vmovdqa (%r9), %xmm4 --; AVX512DQ-NEXT: vmovdqa (%r10), %xmm2 --; AVX512DQ-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm8 --; AVX512DQ-NEXT: vinserti128 $1, %xmm6, %ymm5, %ymm9 --; AVX512DQ-NEXT: vinserti128 $1, %xmm4, %ymm3, %ymm7 --; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm10 = ymm9[u,u,u,u,u,5],zero,ymm9[u,u,u,u,u,6],zero,ymm9[u,u,u,u,u],zero,ymm9[23,u,u,u,u,u],zero,ymm9[24,u,u,u,u] --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm11 = ymm9[2,3,0,1] --; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm11 = ymm11[u,u,u,u,u],zero,ymm11[5,u,u,u,u,u],zero,ymm11[6,u,u,u,u,u,23],zero,ymm11[u,u,u,u,u,24],zero,ymm11[u,u,u,u] --; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm12 = [255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255] --; AVX512DQ-NEXT: vpternlogq $50, %ymm10, %ymm12, %ymm11 --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm10 = ymm9[0,2,0,2] --; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm10 = zero,zero,ymm10[0,8,u,u,u],zero,zero,ymm10[1,9,u,u,u],zero,zero,ymm10[18,26,u,u,u],zero,zero,ymm10[19,27,u,u,u],zero,zero,ymm10[20,28] --; AVX512DQ-NEXT: vinserti64x4 $1, %ymm11, %zmm10, %zmm10 --; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm11 = ymm8[u,u,u,5],zero,ymm8[u,u,u,u,u,6],zero,ymm8[u,u,u,u,u],zero,ymm8[23,u,u,u,u,u],zero,ymm8[24,u,u,u,u,u],zero --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm13 = ymm8[2,3,0,1] --; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm13 = ymm13[u,u,u],zero,ymm13[5,u,u,u,u,u],zero,ymm13[6,u,u,u,u,u,23],zero,ymm13[u,u,u,u,u,24],zero,ymm13[u,u,u,u,u,25] --; AVX512DQ-NEXT: vpternlogq $200, %ymm11, %ymm12, %ymm13 --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm11 = ymm8[0,2,0,2] --; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm11 = ymm11[0,8],zero,zero,ymm11[u,u,u,1,9],zero,zero,ymm11[u,u,u,2,10],zero,zero,ymm11[u,u,u,19,27],zero,zero,ymm11[u,u,u,20,28],zero,zero --; AVX512DQ-NEXT: vinserti64x4 $1, %ymm13, %zmm11, %zmm11 --; AVX512DQ-NEXT: vporq %zmm10, %zmm11, %zmm10 --; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm11 = ymm7[4],zero,ymm7[u,u,u,u,u,5],zero,ymm7[u,u,u,u,u,6],zero,ymm7[u,u,u,u,u],zero,ymm7[23,u,u,u,u,u],zero,ymm7[24,u,u] --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm12 = ymm7[2,3,0,1] --; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm12 = zero,ymm12[4,u,u,u,u,u],zero,ymm12[5,u,u,u,u,u],zero,ymm12[6,u,u,u,u,u,23],zero,ymm12[u,u,u,u,u,24],zero,ymm12[u,u] --; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm13 = [255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255] --; AVX512DQ-NEXT: vpternlogq $200, %ymm11, %ymm13, %ymm12 --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm11 = ymm7[0,2,0,2] --; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm11 = ymm11[u,u,u,u,0,8],zero,ymm11[u,u,u,u,1,9],zero,ymm11[u,u,u,u,18,26],zero,ymm11[u,u,u,u,19,27],zero,ymm11[u,u,u,u] --; AVX512DQ-NEXT: vinserti64x4 $1, %ymm12, %zmm11, %zmm11 --; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm12 = 
xmm2[4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm12 = ymm12[0,0,1,0] --; AVX512DQ-NEXT: vpandn %ymm12, %ymm13, %ymm12 --; AVX512DQ-NEXT: vpshuflw {{.*#+}} xmm13 = xmm2[1,1,0,0,4,5,6,7] --; AVX512DQ-NEXT: vpshufd {{.*#+}} xmm13 = xmm13[0,1,2,0] --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm13 = ymm13[0,0,1,0] --; AVX512DQ-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm13, %ymm13 --; AVX512DQ-NEXT: vinserti64x4 $1, %ymm12, %zmm13, %zmm12 --; AVX512DQ-NEXT: vporq %zmm12, %zmm11, %zmm11 --; AVX512DQ-NEXT: vpternlogd $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm10, %zmm11 --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm8 = ymm8[3,1,1,3] --; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm8 = ymm8[1],zero,zero,ymm8[u,u,u,10,2],zero,zero,ymm8[u,u,u,11,3],zero,zero,ymm8[u,u,u,20,28],zero,zero,ymm8[u,u,u,21,29],zero,zero,ymm8[u] --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm9 = ymm9[1,3,3,1] --; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm9 = zero,ymm9[1,9,u,u,u],zero,zero,ymm9[2,10,u,u,u],zero,zero,ymm9[3,19,u,u,u],zero,zero,ymm9[28,20,u,u,u],zero,zero,ymm9[29,21,u] --; AVX512DQ-NEXT: vpor %ymm8, %ymm9, %ymm8 --; AVX512DQ-NEXT: vpshufhw {{.*#+}} xmm9 = xmm2[0,1,2,3,4,5,5,6] --; AVX512DQ-NEXT: vpshufd {{.*#+}} xmm9 = xmm9[2,2,3,3] --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm9 = ymm9[0,1,0,1] --; AVX512DQ-NEXT: vpermq {{.*#+}} ymm7 = ymm7[1,3,1,3] -+; AVX512DQ-NEXT: vmovdqa (%rdi), %xmm4 -+; AVX512DQ-NEXT: vmovdqa (%rsi), %xmm5 -+; AVX512DQ-NEXT: vmovdqa (%rdx), %xmm6 -+; AVX512DQ-NEXT: vmovdqa (%rcx), %xmm7 -+; AVX512DQ-NEXT: vmovdqa (%r8), %xmm0 -+; AVX512DQ-NEXT: vmovdqa (%r10), %xmm1 -+; AVX512DQ-NEXT: vinserti128 $1, %xmm7, %ymm6, %ymm3 -+; AVX512DQ-NEXT: vinserti128 $1, %xmm5, %ymm4, %ymm2 -+; AVX512DQ-NEXT: vinserti128 $1, (%r9), %ymm0, %ymm0 -+; AVX512DQ-NEXT: vinserti32x4 $2, %xmm1, %zmm0, %zmm0 -+; AVX512DQ-NEXT: vpunpckhbw {{.*#+}} xmm6 = xmm6[8],xmm7[8],xmm6[9],xmm7[9],xmm6[10],xmm7[10],xmm6[11],xmm7[11],xmm6[12],xmm7[12],xmm6[13],xmm7[13],xmm6[14],xmm7[14],xmm6[15],xmm7[15] -+; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm6 = xmm6[u,u],zero,zero,xmm6[12,13,u,u,u],zero,zero,xmm6[14,15,u,u,u] -+; AVX512DQ-NEXT: vpunpckhbw {{.*#+}} xmm4 = xmm4[8],xmm5[8],xmm4[9],xmm5[9],xmm4[10],xmm5[10],xmm4[11],xmm5[11],xmm4[12],xmm5[12],xmm4[13],xmm5[13],xmm4[14],xmm5[14],xmm4[15],xmm5[15] -+; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm4 = xmm4[u,u,12,13],zero,zero,xmm4[u,u,u,14,15],zero,zero,xmm4[u,u,u] -+; AVX512DQ-NEXT: vpor %xmm6, %xmm4, %xmm4 -+; AVX512DQ-NEXT: vextracti128 $1, %ymm0, %xmm5 -+; AVX512DQ-NEXT: vpunpckhbw {{.*#+}} xmm5 = xmm5[8],xmm0[8],xmm5[9],xmm0[9],xmm5[10],xmm0[10],xmm5[11],xmm0[11],xmm5[12],xmm0[12],xmm5[13],xmm0[13],xmm5[14],xmm0[14],xmm5[15],xmm0[15] -+; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm5 = xmm5[10],zero,xmm5[u,u,u,u,13,12],zero,xmm5[u,u,u,u,15,14],zero -+; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm6 = zero,xmm1[13,u,u,u,u],zero,zero,xmm1[14,u,u,u,u],zero,zero,xmm1[15] -+; AVX512DQ-NEXT: vpor %xmm6, %xmm5, %xmm5 -+; AVX512DQ-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm4, %xmm5 -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm4 = ymm2[3,1,1,3] -+; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm4 = ymm4[1],zero,zero,ymm4[u,u,u,10,2],zero,zero,ymm4[u,u,u,11,3],zero,zero,ymm4[u,u,u,20,28],zero,zero,ymm4[u,u,u,21,29],zero,zero,ymm4[u] -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm6 = ymm3[1,3,3,1] -+; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm6 = zero,ymm6[1,9,u,u,u],zero,zero,ymm6[2,10,u,u,u],zero,zero,ymm6[3,19,u,u,u],zero,zero,ymm6[28,20,u,u,u],zero,zero,ymm6[29,21,u] -+; AVX512DQ-NEXT: vpor %ymm4, %ymm6, %ymm4 -+; AVX512DQ-NEXT: vpshufhw {{.*#+}} 
xmm6 = xmm1[0,1,2,3,4,5,5,6] -+; AVX512DQ-NEXT: vpshufd {{.*#+}} xmm6 = xmm6[2,2,3,3] -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm6 = ymm6[0,1,0,1] -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm7 = ymm0[1,3,1,3] - ; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm7 = ymm7[u,u,u,1,9],zero,ymm7[u,u,u,u,2,10],zero,ymm7[u,u,u,u,19,27],zero,ymm7[u,u,u,u,20,28],zero,ymm7[u,u,u,u,21] --; AVX512DQ-NEXT: vpternlogq $244, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm9, %ymm7 --; AVX512DQ-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm8, %ymm7 --; AVX512DQ-NEXT: vpunpckhbw {{.*#+}} xmm5 = xmm5[8],xmm6[8],xmm5[9],xmm6[9],xmm5[10],xmm6[10],xmm5[11],xmm6[11],xmm5[12],xmm6[12],xmm5[13],xmm6[13],xmm5[14],xmm6[14],xmm5[15],xmm6[15] --; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm5 = xmm5[u,u],zero,zero,xmm5[12,13,u,u,u],zero,zero,xmm5[14,15,u,u,u] --; AVX512DQ-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm0[8],xmm1[8],xmm0[9],xmm1[9],xmm0[10],xmm1[10],xmm0[11],xmm1[11],xmm0[12],xmm1[12],xmm0[13],xmm1[13],xmm0[14],xmm1[14],xmm0[15],xmm1[15] --; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[u,u,12,13],zero,zero,xmm0[u,u,u,14,15],zero,zero,xmm0[u,u,u] --; AVX512DQ-NEXT: vpor %xmm5, %xmm0, %xmm0 --; AVX512DQ-NEXT: vpunpckhbw {{.*#+}} xmm1 = xmm4[8],xmm3[8],xmm4[9],xmm3[9],xmm4[10],xmm3[10],xmm4[11],xmm3[11],xmm4[12],xmm3[12],xmm4[13],xmm3[13],xmm4[14],xmm3[14],xmm4[15],xmm3[15] --; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[10],zero,xmm1[u,u,u,u,13,12],zero,xmm1[u,u,u,u,15,14],zero --; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm2 = zero,xmm2[13,u,u,u,u],zero,zero,xmm2[14,u,u,u,u],zero,zero,xmm2[15] --; AVX512DQ-NEXT: vpor %xmm2, %xmm1, %xmm1 --; AVX512DQ-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1 --; AVX512DQ-NEXT: vinserti32x4 $2, %xmm1, %zmm7, %zmm0 --; AVX512DQ-NEXT: vmovdqa %xmm1, 96(%rax) --; AVX512DQ-NEXT: vmovdqa %ymm0, 64(%rax) --; AVX512DQ-NEXT: vmovdqa64 %zmm11, (%rax) -+; AVX512DQ-NEXT: vpternlogq $244, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm6, %ymm7 -+; AVX512DQ-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm4, %ymm7 -+; AVX512DQ-NEXT: vinserti32x4 $2, %xmm5, %zmm7, %zmm4 -+; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm6 = ymm3[u,u,u,u,u,5],zero,ymm3[u,u,u,u,u,6],zero,ymm3[u,u,u,u,u],zero,ymm3[23,u,u,u,u,u],zero,ymm3[24,u,u,u,u] -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm7 = ymm3[2,3,0,1] -+; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm7 = ymm7[u,u,u,u,u],zero,ymm7[5,u,u,u,u,u],zero,ymm7[6,u,u,u,u,u,23],zero,ymm7[u,u,u,u,u,24],zero,ymm7[u,u,u,u] -+; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm8 = [255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255] -+; AVX512DQ-NEXT: vpternlogq $50, %ymm6, %ymm8, %ymm7 -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm3 = ymm3[0,2,0,2] -+; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm3 = zero,zero,ymm3[0,8,u,u,u],zero,zero,ymm3[1,9,u,u,u],zero,zero,ymm3[18,26,u,u,u],zero,zero,ymm3[19,27,u,u,u],zero,zero,ymm3[20,28] -+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm7, %zmm3, %zmm3 -+; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm6 = ymm2[u,u,u,5],zero,ymm2[u,u,u,u,u,6],zero,ymm2[u,u,u,u,u],zero,ymm2[23,u,u,u,u,u],zero,ymm2[24,u,u,u,u,u],zero -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm7 = ymm2[2,3,0,1] -+; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm7 = ymm7[u,u,u],zero,ymm7[5,u,u,u,u,u],zero,ymm7[6,u,u,u,u,u,23],zero,ymm7[u,u,u,u,u,24],zero,ymm7[u,u,u,u,u,25] -+; AVX512DQ-NEXT: vpternlogq $200, %ymm6, %ymm8, %ymm7 -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm2 = ymm2[0,2,0,2] -+; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm2 = 
ymm2[0,8],zero,zero,ymm2[u,u,u,1,9],zero,zero,ymm2[u,u,u,2,10],zero,zero,ymm2[u,u,u,19,27],zero,zero,ymm2[u,u,u,20,28],zero,zero -+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm7, %zmm2, %zmm2 -+; AVX512DQ-NEXT: vporq %zmm3, %zmm2, %zmm2 -+; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm3 = xmm1[4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm3 = ymm3[0,0,1,0] -+; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm6 = [255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255] -+; AVX512DQ-NEXT: vpandn %ymm3, %ymm6, %ymm3 -+; AVX512DQ-NEXT: vpshuflw {{.*#+}} xmm1 = xmm1[1,1,0,0,4,5,6,7] -+; AVX512DQ-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[0,1,2,0] -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm1 = ymm1[0,0,1,0] -+; AVX512DQ-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm1, %ymm1 -+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm3, %zmm1, %zmm1 -+; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm3 = ymm0[4],zero,ymm0[u,u,u,u,u,5],zero,ymm0[u,u,u,u,u,6],zero,ymm0[u,u,u,u,u],zero,ymm0[23,u,u,u,u,u],zero,ymm0[24,u,u] -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm7 = ymm0[2,3,0,1] -+; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm7 = zero,ymm7[4,u,u,u,u,u],zero,ymm7[5,u,u,u,u,u],zero,ymm7[6,u,u,u,u,u,23],zero,ymm7[u,u,u,u,u,24],zero,ymm7[u,u] -+; AVX512DQ-NEXT: vpternlogq $200, %ymm3, %ymm6, %ymm7 -+; AVX512DQ-NEXT: vpermq {{.*#+}} ymm0 = ymm0[0,2,0,2] -+; AVX512DQ-NEXT: vpshufb {{.*#+}} ymm0 = ymm0[u,u,u,u,0,8],zero,ymm0[u,u,u,u,1,9],zero,ymm0[u,u,u,u,18,26],zero,ymm0[u,u,u,u,19,27],zero,ymm0[u,u,u,u] -+; AVX512DQ-NEXT: vinserti64x4 $1, %ymm7, %zmm0, %zmm0 -+; AVX512DQ-NEXT: vporq %zmm1, %zmm0, %zmm0 -+; AVX512DQ-NEXT: vpternlogd $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm0 -+; AVX512DQ-NEXT: vmovdqa %xmm5, 96(%rax) -+; AVX512DQ-NEXT: vmovdqa64 %zmm0, (%rax) -+; AVX512DQ-NEXT: vmovdqa %ymm4, 64(%rax) - ; AVX512DQ-NEXT: vzeroupper - ; AVX512DQ-NEXT: retq - ; -@@ -2275,69 +2280,70 @@ - ; AVX512DQ-FCP: # %bb.0: - ; AVX512DQ-FCP-NEXT: movq {{[0-9]+}}(%rsp), %rax - ; AVX512DQ-FCP-NEXT: movq {{[0-9]+}}(%rsp), %r10 --; AVX512DQ-FCP-NEXT: vmovdqa (%rdi), %xmm0 --; AVX512DQ-FCP-NEXT: vmovdqa (%rsi), %xmm1 --; AVX512DQ-FCP-NEXT: vmovdqa (%rdx), %xmm5 --; AVX512DQ-FCP-NEXT: vmovdqa (%rcx), %xmm6 --; AVX512DQ-FCP-NEXT: vmovdqa (%r8), %xmm3 --; AVX512DQ-FCP-NEXT: vmovdqa (%r9), %xmm4 --; AVX512DQ-FCP-NEXT: vmovdqa (%r10), %xmm2 --; AVX512DQ-FCP-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm7 --; AVX512DQ-FCP-NEXT: vinserti128 $1, %xmm6, %ymm5, %ymm8 --; AVX512DQ-FCP-NEXT: vinserti128 $1, %xmm4, %ymm3, %ymm9 --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm10 = ymm8[0,2,0,2] --; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm10 = zero,zero,ymm10[0,8,u,u,u],zero,zero,ymm10[1,9,u,u,u],zero,zero,ymm10[18,26,u,u,u],zero,zero,ymm10[19,27,u,u,u],zero,zero,ymm10[20,28] --; AVX512DQ-FCP-NEXT: vbroadcasti128 {{.*#+}} ymm11 = [1,5,2,6,1,5,2,6] --; AVX512DQ-FCP-NEXT: # ymm11 = mem[0,1,0,1] --; AVX512DQ-FCP-NEXT: vpermd %ymm8, %ymm11, %ymm12 --; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm12 = ymm12[u,u,u],zero,zero,ymm12[1,5,u,u,u],zero,zero,ymm12[2,6,u,u,u],zero,zero,ymm12[19,23,u,u,u],zero,zero,ymm12[24,28,u,u,u],zero --; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm12, %zmm10, %zmm10 --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm12 = ymm7[0,2,0,2] --; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm12 = ymm12[0,8],zero,zero,ymm12[u,u,u,1,9],zero,zero,ymm12[u,u,u,2,10],zero,zero,ymm12[u,u,u,19,27],zero,zero,ymm12[u,u,u,20,28],zero,zero --; AVX512DQ-FCP-NEXT: vpermd %ymm7, %ymm11, %ymm13 --; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm13 = 
ymm13[u,u,u,1,5],zero,zero,ymm13[u,u,u,2,6],zero,zero,ymm13[u,u,u,19,23],zero,zero,ymm13[u,u,u,24,28],zero,zero,ymm13[u,u,u,25] --; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm13, %zmm12, %zmm12 --; AVX512DQ-FCP-NEXT: vporq %zmm10, %zmm12, %zmm10 --; AVX512DQ-FCP-NEXT: vpshuflw {{.*#+}} xmm12 = xmm2[1,1,0,0,4,5,6,7] --; AVX512DQ-FCP-NEXT: vpmovsxbd {{.*#+}} ymm13 = [0,1,0,1,0,0,0,0] --; AVX512DQ-FCP-NEXT: vpermd %ymm12, %ymm13, %ymm12 --; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} xmm13 = xmm2[4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm13 = ymm13[0,0,1,0] --; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm13, %zmm12, %zmm12 --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm13 = ymm9[0,2,0,2] --; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm13 = ymm13[u,u,u,u,0,8],zero,ymm13[u,u,u,u,1,9],zero,ymm13[u,u,u,u,18,26],zero,ymm13[u,u,u,u,19,27],zero,ymm13[u,u,u,u] --; AVX512DQ-FCP-NEXT: vpermd %ymm9, %ymm11, %ymm11 --; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm11 = ymm11[0,4],zero,ymm11[u,u,u,u,1,5],zero,ymm11[u,u,u,u,2,6],zero,ymm11[u,u,u,u,19,23],zero,ymm11[u,u,u,u,24,28],zero,ymm11[u] --; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm11, %zmm13, %zmm11 --; AVX512DQ-FCP-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm12, %zmm11 --; AVX512DQ-FCP-NEXT: vpternlogd $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm10, %zmm11 --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm7 = ymm7[3,1,1,3] --; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm7 = ymm7[1],zero,zero,ymm7[u,u,u,10,2],zero,zero,ymm7[u,u,u,11,3],zero,zero,ymm7[u,u,u,20,28],zero,zero,ymm7[u,u,u,21,29],zero,zero,ymm7[u] --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm8 = ymm8[1,3,3,1] --; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm8 = zero,ymm8[1,9,u,u,u],zero,zero,ymm8[2,10,u,u,u],zero,zero,ymm8[3,19,u,u,u],zero,zero,ymm8[28,20,u,u,u],zero,zero,ymm8[29,21,u] --; AVX512DQ-FCP-NEXT: vpor %ymm7, %ymm8, %ymm7 --; AVX512DQ-FCP-NEXT: vpshufhw {{.*#+}} xmm8 = xmm2[0,1,2,3,4,5,5,6] --; AVX512DQ-FCP-NEXT: vbroadcasti128 {{.*#+}} ymm10 = [2,2,3,3,2,2,3,3] --; AVX512DQ-FCP-NEXT: # ymm10 = mem[0,1,0,1] --; AVX512DQ-FCP-NEXT: vpermd %ymm8, %ymm10, %ymm8 --; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm9 = ymm9[1,3,1,3] --; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm9 = ymm9[u,u,u,1,9],zero,ymm9[u,u,u,u,2,10],zero,ymm9[u,u,u,u,19,27],zero,ymm9[u,u,u,u,20,28],zero,ymm9[u,u,u,u,21] --; AVX512DQ-FCP-NEXT: vpternlogq $244, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm8, %ymm9 --; AVX512DQ-FCP-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm7, %ymm9 --; AVX512DQ-FCP-NEXT: vpunpckhbw {{.*#+}} xmm5 = xmm5[8],xmm6[8],xmm5[9],xmm6[9],xmm5[10],xmm6[10],xmm5[11],xmm6[11],xmm5[12],xmm6[12],xmm5[13],xmm6[13],xmm5[14],xmm6[14],xmm5[15],xmm6[15] --; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} xmm5 = xmm5[u,u],zero,zero,xmm5[12,13,u,u,u],zero,zero,xmm5[14,15,u,u,u] --; AVX512DQ-FCP-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm0[8],xmm1[8],xmm0[9],xmm1[9],xmm0[10],xmm1[10],xmm0[11],xmm1[11],xmm0[12],xmm1[12],xmm0[13],xmm1[13],xmm0[14],xmm1[14],xmm0[15],xmm1[15] --; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[u,u,12,13],zero,zero,xmm0[u,u,u,14,15],zero,zero,xmm0[u,u,u] --; AVX512DQ-FCP-NEXT: vpor %xmm5, %xmm0, %xmm0 --; AVX512DQ-FCP-NEXT: vpunpckhbw {{.*#+}} xmm1 = xmm4[8],xmm3[8],xmm4[9],xmm3[9],xmm4[10],xmm3[10],xmm4[11],xmm3[11],xmm4[12],xmm3[12],xmm4[13],xmm3[13],xmm4[14],xmm3[14],xmm4[15],xmm3[15] --; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[10],zero,xmm1[u,u,u,u,13,12],zero,xmm1[u,u,u,u,15,14],zero --; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} xmm2 = 
zero,xmm2[13,u,u,u,u],zero,zero,xmm2[14,u,u,u,u],zero,zero,xmm2[15] --; AVX512DQ-FCP-NEXT: vpor %xmm2, %xmm1, %xmm1 --; AVX512DQ-FCP-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1 --; AVX512DQ-FCP-NEXT: vinserti32x4 $2, %xmm1, %zmm9, %zmm0 --; AVX512DQ-FCP-NEXT: vmovdqa %xmm1, 96(%rax) --; AVX512DQ-FCP-NEXT: vmovdqa64 %zmm11, (%rax) --; AVX512DQ-FCP-NEXT: vmovdqa %ymm0, 64(%rax) -+; AVX512DQ-FCP-NEXT: vmovdqa (%rdi), %xmm2 -+; AVX512DQ-FCP-NEXT: vmovdqa (%rsi), %xmm3 -+; AVX512DQ-FCP-NEXT: vmovdqa (%rdx), %xmm4 -+; AVX512DQ-FCP-NEXT: vmovdqa (%rcx), %xmm5 -+; AVX512DQ-FCP-NEXT: vmovdqa (%r8), %xmm1 -+; AVX512DQ-FCP-NEXT: vmovdqa (%r10), %xmm0 -+; AVX512DQ-FCP-NEXT: vinserti128 $1, %xmm5, %ymm4, %ymm6 -+; AVX512DQ-FCP-NEXT: vinserti128 $1, %xmm3, %ymm2, %ymm7 -+; AVX512DQ-FCP-NEXT: vinserti128 $1, (%r9), %ymm1, %ymm1 -+; AVX512DQ-FCP-NEXT: vinserti32x4 $2, %xmm0, %zmm1, %zmm1 -+; AVX512DQ-FCP-NEXT: vpunpckhbw {{.*#+}} xmm4 = xmm4[8],xmm5[8],xmm4[9],xmm5[9],xmm4[10],xmm5[10],xmm4[11],xmm5[11],xmm4[12],xmm5[12],xmm4[13],xmm5[13],xmm4[14],xmm5[14],xmm4[15],xmm5[15] -+; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} xmm4 = xmm4[u,u],zero,zero,xmm4[12,13,u,u,u],zero,zero,xmm4[14,15,u,u,u] -+; AVX512DQ-FCP-NEXT: vpunpckhbw {{.*#+}} xmm2 = xmm2[8],xmm3[8],xmm2[9],xmm3[9],xmm2[10],xmm3[10],xmm2[11],xmm3[11],xmm2[12],xmm3[12],xmm2[13],xmm3[13],xmm2[14],xmm3[14],xmm2[15],xmm3[15] -+; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} xmm2 = xmm2[u,u,12,13],zero,zero,xmm2[u,u,u,14,15],zero,zero,xmm2[u,u,u] -+; AVX512DQ-FCP-NEXT: vpor %xmm4, %xmm2, %xmm2 -+; AVX512DQ-FCP-NEXT: vextracti128 $1, %ymm1, %xmm3 -+; AVX512DQ-FCP-NEXT: vpunpckhbw {{.*#+}} xmm3 = xmm3[8],xmm1[8],xmm3[9],xmm1[9],xmm3[10],xmm1[10],xmm3[11],xmm1[11],xmm3[12],xmm1[12],xmm3[13],xmm1[13],xmm3[14],xmm1[14],xmm3[15],xmm1[15] -+; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} xmm3 = xmm3[10],zero,xmm3[u,u,u,u,13,12],zero,xmm3[u,u,u,u,15,14],zero -+; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} xmm4 = zero,xmm0[13,u,u,u,u],zero,zero,xmm0[14,u,u,u,u],zero,zero,xmm0[15] -+; AVX512DQ-FCP-NEXT: vpor %xmm4, %xmm3, %xmm3 -+; AVX512DQ-FCP-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm2, %xmm3 -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm2 = ymm7[3,1,1,3] -+; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm2 = ymm2[1],zero,zero,ymm2[u,u,u,10,2],zero,zero,ymm2[u,u,u,11,3],zero,zero,ymm2[u,u,u,20,28],zero,zero,ymm2[u,u,u,21,29],zero,zero,ymm2[u] -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm4 = ymm6[1,3,3,1] -+; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm4 = zero,ymm4[1,9,u,u,u],zero,zero,ymm4[2,10,u,u,u],zero,zero,ymm4[3,19,u,u,u],zero,zero,ymm4[28,20,u,u,u],zero,zero,ymm4[29,21,u] -+; AVX512DQ-FCP-NEXT: vpor %ymm2, %ymm4, %ymm2 -+; AVX512DQ-FCP-NEXT: vpshufhw {{.*#+}} xmm4 = xmm0[0,1,2,3,4,5,5,6] -+; AVX512DQ-FCP-NEXT: vbroadcasti128 {{.*#+}} ymm5 = [2,2,3,3,2,2,3,3] -+; AVX512DQ-FCP-NEXT: # ymm5 = mem[0,1,0,1] -+; AVX512DQ-FCP-NEXT: vpermd %ymm4, %ymm5, %ymm4 -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm5 = ymm1[1,3,1,3] -+; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm5 = ymm5[u,u,u,1,9],zero,ymm5[u,u,u,u,2,10],zero,ymm5[u,u,u,u,19,27],zero,ymm5[u,u,u,u,20,28],zero,ymm5[u,u,u,u,21] -+; AVX512DQ-FCP-NEXT: vpternlogq $244, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm4, %ymm5 -+; AVX512DQ-FCP-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm2, %ymm5 -+; AVX512DQ-FCP-NEXT: vinserti32x4 $2, %xmm3, %zmm5, %zmm2 -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm4 = ymm6[0,2,0,2] -+; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm4 = 
zero,zero,ymm4[0,8,u,u,u],zero,zero,ymm4[1,9,u,u,u],zero,zero,ymm4[18,26,u,u,u],zero,zero,ymm4[19,27,u,u,u],zero,zero,ymm4[20,28] -+; AVX512DQ-FCP-NEXT: vbroadcasti128 {{.*#+}} ymm5 = [1,5,2,6,1,5,2,6] -+; AVX512DQ-FCP-NEXT: # ymm5 = mem[0,1,0,1] -+; AVX512DQ-FCP-NEXT: vpermd %ymm6, %ymm5, %ymm6 -+; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm6 = ymm6[u,u,u],zero,zero,ymm6[1,5,u,u,u],zero,zero,ymm6[2,6,u,u,u],zero,zero,ymm6[19,23,u,u,u],zero,zero,ymm6[24,28,u,u,u],zero -+; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm6, %zmm4, %zmm4 -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm6 = ymm7[0,2,0,2] -+; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm6 = ymm6[0,8],zero,zero,ymm6[u,u,u,1,9],zero,zero,ymm6[u,u,u,2,10],zero,zero,ymm6[u,u,u,19,27],zero,zero,ymm6[u,u,u,20,28],zero,zero -+; AVX512DQ-FCP-NEXT: vpermd %ymm7, %ymm5, %ymm7 -+; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm7 = ymm7[u,u,u,1,5],zero,zero,ymm7[u,u,u,2,6],zero,zero,ymm7[u,u,u,19,23],zero,zero,ymm7[u,u,u,24,28],zero,zero,ymm7[u,u,u,25] -+; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm7, %zmm6, %zmm6 -+; AVX512DQ-FCP-NEXT: vporq %zmm4, %zmm6, %zmm4 -+; AVX512DQ-FCP-NEXT: vpshuflw {{.*#+}} xmm6 = xmm0[1,1,0,0,4,5,6,7] -+; AVX512DQ-FCP-NEXT: vpmovsxbd {{.*#+}} ymm7 = [0,1,0,1,0,0,0,0] -+; AVX512DQ-FCP-NEXT: vpermd %ymm6, %ymm7, %ymm6 -+; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm0 = ymm0[0,0,1,0] -+; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm0, %zmm6, %zmm0 -+; AVX512DQ-FCP-NEXT: vpermq {{.*#+}} ymm6 = ymm1[0,2,0,2] -+; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm6 = ymm6[u,u,u,u,0,8],zero,ymm6[u,u,u,u,1,9],zero,ymm6[u,u,u,u,18,26],zero,ymm6[u,u,u,u,19,27],zero,ymm6[u,u,u,u] -+; AVX512DQ-FCP-NEXT: vpermd %ymm1, %ymm5, %ymm1 -+; AVX512DQ-FCP-NEXT: vpshufb {{.*#+}} ymm1 = ymm1[0,4],zero,ymm1[u,u,u,u,1,5],zero,ymm1[u,u,u,u,2,6],zero,ymm1[u,u,u,u,19,23],zero,ymm1[u,u,u,u,24,28],zero,ymm1[u] -+; AVX512DQ-FCP-NEXT: vinserti64x4 $1, %ymm1, %zmm6, %zmm1 -+; AVX512DQ-FCP-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm1 -+; AVX512DQ-FCP-NEXT: vpternlogd $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm1 -+; AVX512DQ-FCP-NEXT: vmovdqa %xmm3, 96(%rax) -+; AVX512DQ-FCP-NEXT: vmovdqa64 %zmm1, (%rax) -+; AVX512DQ-FCP-NEXT: vmovdqa %ymm2, 64(%rax) - ; AVX512DQ-FCP-NEXT: vzeroupper - ; AVX512DQ-FCP-NEXT: retq - ; -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/zero_extend_vector_inreg_of_broadcast.ll b/llvm/test/CodeGen/X86/zero_extend_vector_inreg_of_broadcast.ll ---- a/llvm/test/CodeGen/X86/zero_extend_vector_inreg_of_broadcast.ll -+++ b/llvm/test/CodeGen/X86/zero_extend_vector_inreg_of_broadcast.ll -@@ -314,8 +314,8 @@ - ; - ; AVX512F-LABEL: vec64_i16_widen_to_i32_factor2_broadcast_to_v2i32_factor2: - ; AVX512F: # %bb.0: --; AVX512F-NEXT: vmovdqa (%rdi), %xmm0 --; AVX512F-NEXT: vpaddb (%rsi), %xmm0, %xmm0 -+; AVX512F-NEXT: vmovdqa (%rdi), %ymm0 -+; AVX512F-NEXT: vpaddb (%rsi), %ymm0, %ymm0 - ; AVX512F-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[0,1,10,11,0,1,14,15,u,u,u,u,u,u,u,u] - ; AVX512F-NEXT: vpaddb (%rdx), %ymm0, %ymm0 - ; AVX512F-NEXT: vmovdqa %ymm0, (%rcx) -@@ -324,8 +324,8 @@ - ; - ; AVX512DQ-LABEL: vec64_i16_widen_to_i32_factor2_broadcast_to_v2i32_factor2: - ; AVX512DQ: # %bb.0: --; AVX512DQ-NEXT: vmovdqa (%rdi), %xmm0 --; AVX512DQ-NEXT: vpaddb (%rsi), %xmm0, %xmm0 -+; AVX512DQ-NEXT: vmovdqa (%rdi), %ymm0 -+; AVX512DQ-NEXT: vpaddb (%rsi), %ymm0, %ymm0 - ; AVX512DQ-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[0,1,10,11,0,1,14,15,u,u,u,u,u,u,u,u] - ; AVX512DQ-NEXT: vpaddb (%rdx), %ymm0, 
%ymm0 - ; AVX512DQ-NEXT: vmovdqa %ymm0, (%rcx) -@@ -981,7 +981,7 @@ - ; AVX512F-NEXT: vpmovsxbd {{.*#+}} xmm0 = [0,5,0,7] - ; AVX512F-NEXT: vmovdqa (%rdi), %ymm1 - ; AVX512F-NEXT: vpaddb (%rsi), %ymm1, %ymm1 --; AVX512F-NEXT: vpermd %ymm1, %ymm0, %ymm0 -+; AVX512F-NEXT: vpermd %zmm1, %zmm0, %zmm0 - ; AVX512F-NEXT: vpaddb (%rdx), %ymm0, %ymm0 - ; AVX512F-NEXT: vmovdqa %ymm0, (%rcx) - ; AVX512F-NEXT: vzeroupper -@@ -992,7 +992,7 @@ - ; AVX512DQ-NEXT: vpmovsxbd {{.*#+}} xmm0 = [0,5,0,7] - ; AVX512DQ-NEXT: vmovdqa (%rdi), %ymm1 - ; AVX512DQ-NEXT: vpaddb (%rsi), %ymm1, %ymm1 --; AVX512DQ-NEXT: vpermd %ymm1, %ymm0, %ymm0 -+; AVX512DQ-NEXT: vpermd %zmm1, %zmm0, %zmm0 - ; AVX512DQ-NEXT: vpaddb (%rdx), %ymm0, %ymm0 - ; AVX512DQ-NEXT: vmovdqa %ymm0, (%rcx) - ; AVX512DQ-NEXT: vzeroupper -@@ -4026,10 +4026,10 @@ - ; - ; AVX512F-FAST-LABEL: vec384_i16_widen_to_i64_factor4_broadcast_to_v6i64_factor6: - ; AVX512F-FAST: # %bb.0: --; AVX512F-FAST-NEXT: vmovdqa (%rdi), %xmm0 -+; AVX512F-FAST-NEXT: vmovdqa (%rdi), %ymm0 -+; AVX512F-FAST-NEXT: vpaddb (%rsi), %ymm0, %ymm0 - ; AVX512F-FAST-NEXT: vmovdqa 48(%rdi), %xmm1 - ; AVX512F-FAST-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 --; AVX512F-FAST-NEXT: vpaddb (%rsi), %xmm0, %xmm0 - ; AVX512F-FAST-NEXT: vpbroadcastq %xmm0, %ymm2 - ; AVX512F-FAST-NEXT: vpblendw {{.*#+}} ymm1 = ymm2[0],ymm1[1,2,3],ymm2[4],ymm1[5,6,7],ymm2[8],ymm1[9,10,11],ymm2[12],ymm1[13,14,15] - ; AVX512F-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm1, %ymm1 -@@ -4062,10 +4062,10 @@ - ; - ; AVX512DQ-FAST-LABEL: vec384_i16_widen_to_i64_factor4_broadcast_to_v6i64_factor6: - ; AVX512DQ-FAST: # %bb.0: --; AVX512DQ-FAST-NEXT: vmovdqa (%rdi), %xmm0 -+; AVX512DQ-FAST-NEXT: vmovdqa (%rdi), %ymm0 -+; AVX512DQ-FAST-NEXT: vpaddb (%rsi), %ymm0, %ymm0 - ; AVX512DQ-FAST-NEXT: vmovdqa 48(%rdi), %xmm1 - ; AVX512DQ-FAST-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 --; AVX512DQ-FAST-NEXT: vpaddb (%rsi), %xmm0, %xmm0 - ; AVX512DQ-FAST-NEXT: vpbroadcastq %xmm0, %ymm2 - ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm1 = ymm2[0],ymm1[1,2,3],ymm2[4],ymm1[5,6,7],ymm2[8],ymm1[9,10,11],ymm2[12],ymm1[13,14,15] - ; AVX512DQ-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm1, %ymm1 -@@ -4541,9 +4541,9 @@ - ; - ; AVX512F-LABEL: vec384_i16_widen_to_i192_factor12_broadcast_to_v2i192_factor2: - ; AVX512F: # %bb.0: --; AVX512F-NEXT: vmovdqa (%rdi), %xmm0 -+; AVX512F-NEXT: vmovdqa (%rdi), %ymm0 -+; AVX512F-NEXT: vpaddb (%rsi), %ymm0, %ymm0 - ; AVX512F-NEXT: vmovdqa 48(%rdi), %xmm1 --; AVX512F-NEXT: vpaddb (%rsi), %xmm0, %xmm0 - ; AVX512F-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 - ; AVX512F-NEXT: vpblendw {{.*#+}} xmm1 = xmm0[0],xmm1[1,2,3,4,5,6,7] - ; AVX512F-NEXT: vpbroadcastw %xmm0, %ymm0 -@@ -4559,9 +4559,9 @@ - ; - ; AVX512DQ-LABEL: vec384_i16_widen_to_i192_factor12_broadcast_to_v2i192_factor2: - ; AVX512DQ: # %bb.0: --; AVX512DQ-NEXT: vmovdqa (%rdi), %xmm0 -+; AVX512DQ-NEXT: vmovdqa (%rdi), %ymm0 -+; AVX512DQ-NEXT: vpaddb (%rsi), %ymm0, %ymm0 - ; AVX512DQ-NEXT: vmovdqa 48(%rdi), %xmm1 --; AVX512DQ-NEXT: vpaddb (%rsi), %xmm0, %xmm0 - ; AVX512DQ-NEXT: vpaddb 48(%rsi), %xmm1, %xmm1 - ; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm1 = xmm0[0],xmm1[1,2,3,4,5,6,7] - ; AVX512DQ-NEXT: vpbroadcastw %xmm0, %ymm0 diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 5fee96230ce7cb..5ef93e8ac6da3e 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = 
"37b7207651b44743909a427b5509bed5e6c21b59" - LLVM_SHA256 = "dfa0cbf5107203a8a80b4d8dc65b66d68c59414eb9743fce90249da16d719780" + LLVM_COMMIT = "694c444b5bbb56dcba8978d283fe5385237c309a" + LLVM_SHA256 = "b5c953f633562b6e81101f9c50f246b741cd43158cec852d426ec55adb63f4b3" tf_http_archive( name = name, diff --git a/third_party/png.BUILD b/third_party/png.BUILD index e79442e35bc58d..ea1995e025fe87 100644 --- a/third_party/png.BUILD +++ b/third_party/png.BUILD @@ -61,7 +61,7 @@ genrule( name = "snappy_stubs_public_h", srcs = ["scripts/pnglibconf.h.prebuilt"], outs = ["pnglibconf.h"], - cmd = "sed -e 's/PNG_ZLIB_VERNUM 0/PNG_ZLIB_VERNUM 0x12d0/' $< >$@", + cmd = "sed -e 's/PNG_ZLIB_VERNUM 0/PNG_ZLIB_VERNUM 0x1310/' $< >$@", ) config_setting( diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index f9229c0d13765e..c8897665644baa 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -174,170 +174,6 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists add_subdirectory(integrations) add_subdirectory(reference) add_subdirectory(tests) -diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.cpp b/stablehlo/stablehlo/dialect/AssemblyFormat.cpp ---- stablehlo/stablehlo/dialect/AssemblyFormat.cpp -+++ stablehlo/stablehlo/dialect/AssemblyFormat.cpp -@@ -15,6 +15,7 @@ - - #include "stablehlo/dialect/AssemblyFormat.h" - -+#include - #include - #include - #include -@@ -130,6 +131,42 @@ - for (Type& t : opTypes) typePtrs.push_back(&t); - - return detail::parseSameOperandsAndResultTypeImpl(parser, typePtrs, result); -+} -+ -+void printConstantOp(OpAsmPrinter& p, Operation* op, ElementsAttr value) { -+ assert(op->getNumResults() == 1); -+ // If not all types are the same, use generic form. -+ if (value.getType() != op->getResultTypes().front()) { -+ p.printGenericOp(op, /*printOpName=*/false); -+ return; -+ } -+ -+ p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"}); -+ p << ' '; -+ p.printStrippedAttrOrType(value); -+} -+ -+ParseResult parseConstantOp(OpAsmParser& parser, OperationState& result) { -+ // Parse the generic form. -+ if (succeeded(parser.parseOptionalLParen())) { -+ if (parser.parseRParen()) return failure(); -+ if (parser.parseOptionalAttrDict(result.attributes)) return failure(); -+ if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() || -+ parser.parseArrow()) -+ return failure(); -+ Type resultTy; -+ if (parser.parseType(resultTy)) return failure(); -+ result.addTypes(resultTy); -+ return success(); -+ } -+ -+ ElementsAttr valueAttr; -+ if (parser.parseOptionalAttrDict(result.attributes)) return failure(); -+ if (parser.parseCustomAttributeWithFallback(valueAttr, Type{}, "value", -+ result.attributes)) -+ return failure(); -+ result.addTypes(valueAttr.getType()); -+ return success(); - } - - void printTupleOpType(OpAsmPrinter& p, Operation*, TypeRange operands, -diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.h b/stablehlo/stablehlo/dialect/AssemblyFormat.h ---- stablehlo/stablehlo/dialect/AssemblyFormat.h -+++ stablehlo/stablehlo/dialect/AssemblyFormat.h -@@ -101,6 +101,16 @@ - SmallVectorImpl& operands, - SmallVectorImpl& opTypes, Type& result); - -+// Print a `constant` op. -+// -+// op ::= attr-dict $value -+// -+// When the `value` and `output` have different type, it just uses the default -+// operator assembly format as a fallback. 
-+void printConstantOp(OpAsmPrinter& p, Operation* op, ElementsAttr value); -+ -+ParseResult parseConstantOp(OpAsmParser& parser, OperationState& result); -+ - // TuplesOp - only print result type. Operand type is trivially inferrable. - // - // Inferring operand types from tuple type: -diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp ---- stablehlo/stablehlo/dialect/StablehloOps.cpp -+++ stablehlo/stablehlo/dialect/StablehloOps.cpp -@@ -256,6 +256,16 @@ - // ConstantOp - //===----------------------------------------------------------------------===// - -+void ConstantOp::getAsmResultNames( -+ function_ref setNameFn) { -+ mlir::TensorType type = getType(); -+ if (type.getElementType().isa()) { -+ setNameFn(getResult(), "c"); -+ } else { -+ setNameFn(getResult(), "cst"); -+ } -+} -+ - OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { - assert(adaptor.getOperands().empty() && "constant has no operands"); - -@@ -311,44 +321,11 @@ - } - - ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) { -- // Parse the generic form. -- if (succeeded(parser.parseOptionalLParen())) { -- if (parser.parseRParen()) return failure(); -- if (parser.parseOptionalAttrDict(result.attributes)) return failure(); -- if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() || -- parser.parseArrow()) -- return failure(); -- Type resultTy; -- if (parser.parseType(resultTy)) return failure(); -- result.addTypes(resultTy); -- return success(); -- } -- -- ElementsAttr valueAttr; -- if (parser.parseOptionalAttrDict(result.attributes)) return failure(); -- if (parser.parseCustomAttributeWithFallback(valueAttr, Type{}, "value", -- result.attributes)) -- return failure(); -- result.addTypes(valueAttr.getType()); -- return success(); --} -- --/// Print a `constant` op. --/// --/// op ::= attr-dict $value --/// --/// When the `value` and `output` have different type, it just uses the default --/// operator assembly format as a fallback. -+ return hlo::parseConstantOp(parser, result); -+} -+ - void ConstantOp::print(::mlir::OpAsmPrinter& p) { -- // If not all types are the same, use generic form. -- if (getValue().getType() != getType()) { -- p.printGenericOp(getOperation(), /*printOpName=*/false); -- return; -- } -- -- p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); -- p << ' '; -- p.printStrippedAttrOrType(getValueAttr()); -+ hlo::printConstantOp(p, getOperation(), getValue()); - } - - //===----------------------------------------------------------------------===// -diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/dialect/StablehloOps.td ---- stablehlo/stablehlo/dialect/StablehloOps.td -+++ stablehlo/stablehlo/dialect/StablehloOps.td -@@ -38,6 +38,7 @@ - - let useDefaultAttributePrinterParser = 0; - let useDefaultTypePrinterParser = 0; -+ let usePropertiesForAttributes = 0; - } - - class StableHLO_Op traits = []> : -@@ -65,7 +66,8 @@ - //===----------------------------------------------------------------------===// - - def StableHLO_ConstantOp : StableHLO_Op<"constant", -- [ConstantLike, Pure, DeclareOpInterfaceMethods]> { -+ [ConstantLike, Pure, DeclareOpInterfaceMethods, -+ DeclareOpInterfaceMethods]> { - let summary = "Constant operation"; - let description = [{ - Produces an `output` tensor from a constant `value`. 
diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel --- stablehlo/stablehlo/experimental/BUILD.bazel +++ stablehlo/stablehlo/experimental/BUILD.bazel @@ -2800,504 +2636,4 @@ diff --ruN a/stablehlo/stablehlo/reference/Ops.cpp b/stablehlo/stablehlo/referen result.set(resultIndex, dotProduct.get(*dotProductIt)); } } -diff --ruN a/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir b/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir ---- stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir -+++ stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir -@@ -1317,8 +1317,8 @@ - // CHECK: %[[TMP_33:.*]] = stablehlo.add %[[TMP_30]], %[[TMP_3]] - // CHECK: %[[TMP_34:.*]] = stablehlo.power %[[TMP_33]], %[[TMP_4]] - // CHECK: %[[TMP_35:.*]] = stablehlo.constant dense<1.000000e+00> -- // CHECK-DAG: %[[TMP_36:.*]] = stablehlo.multiply %[[TMP_34]], %[[TMP_33]] -- // CHECK-DAG: %[[TMP_37:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_35]] -+ // CHECK: %[[TMP_36:.*]] = stablehlo.multiply %[[TMP_34]], %[[TMP_33]] -+ // CHECK: %[[TMP_37:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_35]] - // CHECK: %[[TMP_38:.*]] = stablehlo.divide %[[TMP_36]], %[[TMP_37]] - // CHECK: %[[TMP_39:.*]] = stablehlo.multiply %[[TMP_33]], %[[TMP_33]] - // CHECK: %[[TMP_40:.*]] = stablehlo.divide %[[TMP_3]], %[[TMP_39]] -@@ -1465,7 +1465,7 @@ - - // ----- - --// CHECK-LABEL: @polygamma_f32 -+// CHECK: @polygamma_f32 - // CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) - func.func @polygamma_f32(%lhs : tensor, %rhs : tensor) -> tensor { - // CHECK-DAG: %[[TMP_0:.*]] = stablehlo.constant dense<1.000000e+00> -@@ -1592,8 +1592,8 @@ - // CHECK: %[[TMP_121:.*]] = stablehlo.add %[[TMP_118]], %[[TMP_91]] - // CHECK: %[[TMP_122:.*]] = stablehlo.power %[[TMP_121]], %[[TMP_92]] - // CHECK: %[[TMP_123:.*]] = stablehlo.constant dense<1.000000e+00> -- // CHECK-DAG: %[[TMP_124:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] -- // CHECK-DAG: %[[TMP_125:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] -+ // CHECK: %[[TMP_124:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] -+ // CHECK: %[[TMP_125:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] - // CHECK: %[[TMP_126:.*]] = stablehlo.divide %[[TMP_124]], %[[TMP_125]] - // CHECK: %[[TMP_127:.*]] = stablehlo.multiply %[[TMP_121]], %[[TMP_121]] - // CHECK: %[[TMP_128:.*]] = stablehlo.divide %[[TMP_91]], %[[TMP_127]] -@@ -1602,7 +1602,7 @@ - // CHECK: %[[TMP_131:.*]] = stablehlo.constant dense<2.100000e+01> - // CHECK: %[[TMP_132:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_131]] - // CHECK: %[[TMP_133:.*]] = stablehlo.multiply %[[TMP_130]], %[[TMP_132]] -- // CHECK: %[[TMP_134:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_134:.*]] = stablehlo.constant dense<-1.39544646E-19> - // CHECK: %[[TMP_135:.*]] = stablehlo.add %[[TMP_90]], %[[TMP_134]] - // CHECK: %[[TMP_136:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_135]] - // CHECK: %[[TMP_137:.*]] = stablehlo.multiply %[[TMP_133]], %[[TMP_136]] -@@ -1611,7 +1611,7 @@ - // CHECK: %[[TMP_140:.*]] = stablehlo.constant dense<1.900000e+01> - // CHECK: %[[TMP_141:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_140]] - // CHECK: %[[TMP_142:.*]] = stablehlo.multiply %[[TMP_139]], %[[TMP_141]] -- // CHECK: %[[TMP_143:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_143:.*]] = stablehlo.constant dense<5.50900303E-18> - // CHECK: %[[TMP_144:.*]] = stablehlo.add %[[TMP_137]], %[[TMP_143]] - // CHECK: %[[TMP_145:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_144]] - // 
CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_142]], %[[TMP_145]] -@@ -1620,7 +1620,7 @@ - // CHECK: %[[TMP_149:.*]] = stablehlo.constant dense<1.700000e+01> - // CHECK: %[[TMP_150:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_149]] - // CHECK: %[[TMP_151:.*]] = stablehlo.multiply %[[TMP_148]], %[[TMP_150]] -- // CHECK: %[[TMP_152:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_152:.*]] = stablehlo.constant dense<-2.17486866E-16> - // CHECK: %[[TMP_153:.*]] = stablehlo.add %[[TMP_146]], %[[TMP_152]] - // CHECK: %[[TMP_154:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_153]] - // CHECK: %[[TMP_155:.*]] = stablehlo.multiply %[[TMP_151]], %[[TMP_154]] -@@ -1629,7 +1629,7 @@ - // CHECK: %[[TMP_158:.*]] = stablehlo.constant dense<1.500000e+01> - // CHECK: %[[TMP_159:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_158]] - // CHECK: %[[TMP_160:.*]] = stablehlo.multiply %[[TMP_157]], %[[TMP_159]] -- // CHECK: %[[TMP_161:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_161:.*]] = stablehlo.constant dense<8.58606213E-15> - // CHECK: %[[TMP_162:.*]] = stablehlo.add %[[TMP_155]], %[[TMP_161]] - // CHECK: %[[TMP_163:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_162]] - // CHECK: %[[TMP_164:.*]] = stablehlo.multiply %[[TMP_160]], %[[TMP_163]] -@@ -1638,7 +1638,7 @@ - // CHECK: %[[TMP_167:.*]] = stablehlo.constant dense<1.300000e+01> - // CHECK: %[[TMP_168:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_167]] - // CHECK: %[[TMP_169:.*]] = stablehlo.multiply %[[TMP_166]], %[[TMP_168]] -- // CHECK: %[[TMP_170:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_170:.*]] = stablehlo.constant dense<-3.3896803E-13> - // CHECK: %[[TMP_171:.*]] = stablehlo.add %[[TMP_164]], %[[TMP_170]] - // CHECK: %[[TMP_172:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_171]] - // CHECK: %[[TMP_173:.*]] = stablehlo.multiply %[[TMP_169]], %[[TMP_172]] -@@ -1647,7 +1647,7 @@ - // CHECK: %[[TMP_176:.*]] = stablehlo.constant dense<1.100000e+01> - // CHECK: %[[TMP_177:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_176]] - // CHECK: %[[TMP_178:.*]] = stablehlo.multiply %[[TMP_175]], %[[TMP_177]] -- // CHECK: %[[TMP_179:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_179:.*]] = stablehlo.constant dense<1.33825364E-11> - // CHECK: %[[TMP_180:.*]] = stablehlo.add %[[TMP_173]], %[[TMP_179]] - // CHECK: %[[TMP_181:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_180]] - // CHECK: %[[TMP_182:.*]] = stablehlo.multiply %[[TMP_178]], %[[TMP_181]] -@@ -1656,7 +1656,7 @@ - // CHECK: %[[TMP_185:.*]] = stablehlo.constant dense<9.000000e+00> - // CHECK: %[[TMP_186:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_185]] - // CHECK: %[[TMP_187:.*]] = stablehlo.multiply %[[TMP_184]], %[[TMP_186]] -- // CHECK: %[[TMP_188:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_188:.*]] = stablehlo.constant dense<-5.28419031E-10> - // CHECK: %[[TMP_189:.*]] = stablehlo.add %[[TMP_182]], %[[TMP_188]] - // CHECK: %[[TMP_190:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_189]] - // CHECK: %[[TMP_191:.*]] = stablehlo.multiply %[[TMP_187]], %[[TMP_190]] -@@ -1665,7 +1665,7 @@ - // CHECK: %[[TMP_194:.*]] = stablehlo.constant dense<7.000000e+00> - // CHECK: %[[TMP_195:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_194]] - // CHECK: %[[TMP_196:.*]] = stablehlo.multiply %[[TMP_193]], %[[TMP_195]] -- // CHECK: %[[TMP_197:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_197:.*]] = stablehlo.constant dense<2.08767563E-8> - // CHECK: %[[TMP_198:.*]] = stablehlo.add %[[TMP_191]], %[[TMP_197]] - // CHECK: %[[TMP_199:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_198]] - // CHECK: %[[TMP_200:.*]] = stablehlo.multiply %[[TMP_196]], 
%[[TMP_199]] -@@ -1674,7 +1674,7 @@ - // CHECK: %[[TMP_203:.*]] = stablehlo.constant dense<5.000000e+00> - // CHECK: %[[TMP_204:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_203]] - // CHECK: %[[TMP_205:.*]] = stablehlo.multiply %[[TMP_202]], %[[TMP_204]] -- // CHECK: %[[TMP_206:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_206:.*]] = stablehlo.constant dense<-8.26719599E-7> - // CHECK: %[[TMP_207:.*]] = stablehlo.add %[[TMP_200]], %[[TMP_206]] - // CHECK: %[[TMP_208:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_207]] - // CHECK: %[[TMP_209:.*]] = stablehlo.multiply %[[TMP_205]], %[[TMP_208]] -@@ -1683,7 +1683,7 @@ - // CHECK: %[[TMP_212:.*]] = stablehlo.constant dense<3.000000e+00> - // CHECK: %[[TMP_213:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_212]] - // CHECK: %[[TMP_214:.*]] = stablehlo.multiply %[[TMP_211]], %[[TMP_213]] -- // CHECK: %[[TMP_215:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_215:.*]] = stablehlo.constant dense<3.30687835E-5> - // CHECK: %[[TMP_216:.*]] = stablehlo.add %[[TMP_209]], %[[TMP_215]] - // CHECK: %[[TMP_217:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_216]] - // CHECK: %[[TMP_218:.*]] = stablehlo.multiply %[[TMP_214]], %[[TMP_217]] -@@ -1692,13 +1692,13 @@ - // CHECK: %[[TMP_221:.*]] = stablehlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_222:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_221]] - // CHECK: %[[TMP_223:.*]] = stablehlo.multiply %[[TMP_220]], %[[TMP_222]] -- // CHECK: %[[TMP_224:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_224:.*]] = stablehlo.constant dense<-0.00138888892> - // CHECK: %[[TMP_225:.*]] = stablehlo.add %[[TMP_218]], %[[TMP_224]] - // CHECK: %[[TMP_226:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_225]] - // CHECK: %[[TMP_227:.*]] = stablehlo.multiply %[[TMP_223]], %[[TMP_226]] - // CHECK: %[[TMP_228:.*]] = stablehlo.constant dense<5.000000e-01> - // CHECK: %[[TMP_229:.*]] = stablehlo.divide %[[TMP_5]], %[[TMP_121]] -- // CHECK: %[[TMP_230:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_230:.*]] = stablehlo.constant dense<0.0833333358> - // CHECK: %[[TMP_231:.*]] = stablehlo.add %[[TMP_230]], %[[TMP_227]] - // CHECK: %[[TMP_232:.*]] = stablehlo.multiply %[[TMP_229]], %[[TMP_231]] - // CHECK: %[[TMP_233:.*]] = stablehlo.add %[[TMP_228]], %[[TMP_232]] -@@ -1707,11 +1707,11 @@ - // CHECK: %[[TMP_236:.*]] = stablehlo.add %[[TMP_235]], %[[TMP_234]] - // CHECK: %[[TMP_237:.*]] = stablehlo.abs %[[TMP_122]] - // CHECK: %[[TMP_238:.*]] = stablehlo.abs %[[TMP_120]] -- // CHECK: %[[TMP_239:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_239:.*]] = stablehlo.constant dense<1.401300e-45> - // CHECK: %[[TMP_240:.*]] = stablehlo.multiply %[[TMP_238]], %[[TMP_239]] - // CHECK: %[[TMP_241:.*]] = stablehlo.compare LT, %[[TMP_237]], %[[TMP_240]], NOTYPE - // CHECK: %[[TMP_242:.*]] = stablehlo.select %[[TMP_241]], %[[TMP_120]], %[[TMP_236]] -- // CHECK: %[[TMP_243:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_243:.*]] = stablehlo.constant dense<0x7FC00000> - // CHECK: %[[TMP_244:.*]] = stablehlo.compare LT, %[[TMP_5]], %[[TMP_123]], NOTYPE - // CHECK: %[[TMP_245:.*]] = stablehlo.select %[[TMP_244]], %[[TMP_243]], %[[TMP_242]] - // CHECK: %[[TMP_246:.*]] = stablehlo.compare LE, %[[ARG1]], %[[TMP_90]], NOTYPE -@@ -1719,7 +1719,7 @@ - // CHECK: %[[TMP_248:.*]] = stablehlo.compare NE, %[[TMP_5]], %[[TMP_247]], NOTYPE - // CHECK: %[[TMP_249:.*]] = stablehlo.and %[[TMP_246]], %[[TMP_248]] - // CHECK: %[[TMP_250:.*]] = stablehlo.select %[[TMP_249]], %[[TMP_243]], %[[TMP_245]] -- // CHECK: %[[TMP_251:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_251:.*]] = 
stablehlo.constant dense<0x7F800000> - // CHECK: %[[TMP_252:.*]] = stablehlo.floor %[[ARG1]] - // CHECK: %[[TMP_253:.*]] = stablehlo.compare EQ, %[[ARG1]], %[[TMP_252]], NOTYPE - // CHECK: %[[TMP_254:.*]] = stablehlo.and %[[TMP_246]], %[[TMP_253]] -@@ -1744,8 +1744,8 @@ - // CHECK: %[[TMP_273:.*]] = stablehlo.subtract %[[ARG1]], %[[TMP_272]] - // CHECK: %[[TMP_274:.*]] = stablehlo.select %[[TMP_270]], %[[TMP_271]], %[[TMP_273]] - // CHECK-DAG: %[[TMP_275:.*]] = stablehlo.constant dense<0.000000e+00> -- // CHECK-DAG: %[[TMP_276:.*]] = stablehlo.constant -- // CHECK-DAG: %[[TMP_277:.*]] = stablehlo.constant -+ // CHECK-DAG: %[[TMP_276:.*]] = stablehlo.constant dense<1.000000e+00> -+ // CHECK-DAG: %[[TMP_277:.*]] = stablehlo.constant dense<676.520386> - // CHECK-DAG: %[[TMP_278:.*]] = stablehlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_279:.*]] = stablehlo.add %[[TMP_274]], %[[TMP_278]] - // CHECK: %[[TMP_280:.*]] = stablehlo.multiply %[[TMP_279]], %[[TMP_279]] -@@ -1753,7 +1753,7 @@ - // CHECK: %[[TMP_282:.*]] = stablehlo.subtract %[[TMP_275]], %[[TMP_281]] - // CHECK: %[[TMP_283:.*]] = stablehlo.divide %[[TMP_277]], %[[TMP_279]] - // CHECK: %[[TMP_284:.*]] = stablehlo.add %[[TMP_276]], %[[TMP_283]] -- // CHECK-DAG: %[[TMP_285:.*]] = stablehlo.constant -+ // CHECK-DAG: %[[TMP_285:.*]] = stablehlo.constant dense<-1259.13916> - // CHECK-DAG: %[[TMP_286:.*]] = stablehlo.constant dense<2.000000e+00> - // CHECK: %[[TMP_287:.*]] = stablehlo.add %[[TMP_274]], %[[TMP_286]] - // CHECK: %[[TMP_288:.*]] = stablehlo.multiply %[[TMP_287]], %[[TMP_287]] -@@ -1761,7 +1761,7 @@ - // CHECK: %[[TMP_290:.*]] = stablehlo.subtract %[[TMP_282]], %[[TMP_289]] - // CHECK: %[[TMP_291:.*]] = stablehlo.divide %[[TMP_285]], %[[TMP_287]] - // CHECK: %[[TMP_292:.*]] = stablehlo.add %[[TMP_284]], %[[TMP_291]] -- // CHECK-DAG: %[[TMP_293:.*]] = stablehlo.constant -+ // CHECK-DAG: %[[TMP_293:.*]] = stablehlo.constant dense<771.323425> - // CHECK-DAG: %[[TMP_294:.*]] = stablehlo.constant dense<3.000000e+00> - // CHECK: %[[TMP_295:.*]] = stablehlo.add %[[TMP_274]], %[[TMP_294]] - // CHECK: %[[TMP_296:.*]] = stablehlo.multiply %[[TMP_295]], %[[TMP_295]] -@@ -1769,15 +1769,15 @@ - // CHECK: %[[TMP_298:.*]] = stablehlo.subtract %[[TMP_290]], %[[TMP_297]] - // CHECK: %[[TMP_299:.*]] = stablehlo.divide %[[TMP_293]], %[[TMP_295]] - // CHECK: %[[TMP_300:.*]] = stablehlo.add %[[TMP_292]], %[[TMP_299]] -- // CHECK-DAG: %[[TMP_301:.*]] = stablehlo.constant -- // CHECK-DAG: %[[TMP_302:.*]] = stablehlo.constant -+ // CHECK-DAG: %[[TMP_301:.*]] = stablehlo.constant dense<-176.615036> -+ // CHECK-DAG: %[[TMP_302:.*]] = stablehlo.constant dense<4.000000e+00> - // CHECK: %[[TMP_303:.*]] = stablehlo.add %[[TMP_274]], %[[TMP_302]] - // CHECK: %[[TMP_304:.*]] = stablehlo.multiply %[[TMP_303]], %[[TMP_303]] - // CHECK: %[[TMP_305:.*]] = stablehlo.divide %[[TMP_301]], %[[TMP_304]] - // CHECK: %[[TMP_306:.*]] = stablehlo.subtract %[[TMP_298]], %[[TMP_305]] - // CHECK: %[[TMP_307:.*]] = stablehlo.divide %[[TMP_301]], %[[TMP_303]] - // CHECK: %[[TMP_308:.*]] = stablehlo.add %[[TMP_300]], %[[TMP_307]] -- // CHECK-DAG: %[[TMP_309:.*]] = stablehlo.constant -+ // CHECK-DAG: %[[TMP_309:.*]] = stablehlo.constant dense<12.5073433> - // CHECK-DAG: %[[TMP_310:.*]] = stablehlo.constant dense<5.000000e+00> - // CHECK: %[[TMP_311:.*]] = stablehlo.add %[[TMP_274]], %[[TMP_310]] - // CHECK: %[[TMP_312:.*]] = stablehlo.multiply %[[TMP_311]], %[[TMP_311]] -@@ -1785,7 +1785,7 @@ - // CHECK: %[[TMP_314:.*]] = stablehlo.subtract %[[TMP_306]], %[[TMP_313]] - // 
CHECK: %[[TMP_315:.*]] = stablehlo.divide %[[TMP_309]], %[[TMP_311]] - // CHECK: %[[TMP_316:.*]] = stablehlo.add %[[TMP_308]], %[[TMP_315]] -- // CHECK-DAG: %[[TMP_317:.*]] = stablehlo.constant -+ // CHECK-DAG: %[[TMP_317:.*]] = stablehlo.constant dense<-0.138571098> - // CHECK-DAG: %[[TMP_318:.*]] = stablehlo.constant dense<6.000000e+00> - // CHECK: %[[TMP_319:.*]] = stablehlo.add %[[TMP_274]], %[[TMP_318]] - // CHECK: %[[TMP_320:.*]] = stablehlo.multiply %[[TMP_319]], %[[TMP_319]] -@@ -1793,7 +1793,7 @@ - // CHECK: %[[TMP_322:.*]] = stablehlo.subtract %[[TMP_314]], %[[TMP_321]] - // CHECK: %[[TMP_323:.*]] = stablehlo.divide %[[TMP_317]], %[[TMP_319]] - // CHECK: %[[TMP_324:.*]] = stablehlo.add %[[TMP_316]], %[[TMP_323]] -- // CHECK-DAG: %[[TMP_325:.*]] = stablehlo.constant -+ // CHECK-DAG: %[[TMP_325:.*]] = stablehlo.constant dense<9.98436917E-6> - // CHECK-DAG: %[[TMP_326:.*]] = stablehlo.constant dense<7.000000e+00> - // CHECK: %[[TMP_327:.*]] = stablehlo.add %[[TMP_274]], %[[TMP_326]] - // CHECK: %[[TMP_328:.*]] = stablehlo.multiply %[[TMP_327]], %[[TMP_327]] -@@ -1801,7 +1801,7 @@ - // CHECK: %[[TMP_330:.*]] = stablehlo.subtract %[[TMP_322]], %[[TMP_329]] - // CHECK: %[[TMP_331:.*]] = stablehlo.divide %[[TMP_325]], %[[TMP_327]] - // CHECK: %[[TMP_332:.*]] = stablehlo.add %[[TMP_324]], %[[TMP_331]] -- // CHECK-DAG: %[[TMP_333:.*]] = stablehlo.constant -+ // CHECK-DAG: %[[TMP_333:.*]] = stablehlo.constant dense<1.50563267E-7> - // CHECK-DAG: %[[TMP_334:.*]] = stablehlo.constant dense<8.000000e+00> - // CHECK: %[[TMP_335:.*]] = stablehlo.add %[[TMP_274]], %[[TMP_334]] - // CHECK: %[[TMP_336:.*]] = stablehlo.multiply %[[TMP_335]], %[[TMP_335]] -@@ -1811,7 +1811,7 @@ - // CHECK: %[[TMP_340:.*]] = stablehlo.add %[[TMP_332]], %[[TMP_339]] - // CHECK: %[[TMP_341:.*]] = stablehlo.constant dense<7.500000e+00> - // CHECK: %[[TMP_342:.*]] = stablehlo.add %[[TMP_341]], %[[TMP_274]] -- // CHECK: %[[TMP_343:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_343:.*]] = stablehlo.constant dense<2.01490307> - // CHECK: %[[TMP_344:.*]] = stablehlo.divide %[[TMP_274]], %[[TMP_341]] - // CHECK: %[[TMP_345:.*]] = stablehlo.log_plus_one %[[TMP_344]] - // CHECK: %[[TMP_346:.*]] = stablehlo.add %[[TMP_343]], %[[TMP_345]] -@@ -1825,7 +1825,7 @@ - // CHECK: %[[TMP_354:.*]] = stablehlo.floor %[[TMP_353]] - // CHECK: %[[TMP_355:.*]] = stablehlo.abs %[[TMP_354]] - // CHECK: %[[TMP_356:.*]] = stablehlo.add %[[ARG1]], %[[TMP_355]] -- // CHECK: %[[TMP_357:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_357:.*]] = stablehlo.constant dense<3.14159274> - // CHECK: %[[TMP_358:.*]] = stablehlo.multiply %[[TMP_357]], %[[TMP_356]] - // CHECK: %[[TMP_359:.*]] = stablehlo.cosine %[[TMP_358]] - // CHECK: %[[TMP_360:.*]] = stablehlo.sine %[[TMP_358]] -@@ -1837,14 +1837,14 @@ - // CHECK: %[[TMP_366:.*]] = stablehlo.floor %[[ARG1]] - // CHECK: %[[TMP_367:.*]] = stablehlo.compare EQ, %[[ARG1]], %[[TMP_366]], NOTYPE - // CHECK: %[[TMP_368:.*]] = stablehlo.and %[[TMP_365]], %[[TMP_367]] -- // CHECK: %[[TMP_369:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_369:.*]] = stablehlo.constant dense<0x7FC00000> - // CHECK: %[[TMP_370:.*]] = stablehlo.select %[[TMP_368]], %[[TMP_369]], %[[TMP_364]] - // CHECK: %[[TMP_371:.*]] = stablehlo.select %[[TMP_268]], %[[TMP_370]], %[[TMP_266]] - // CHECK: %[[TMP_372:.*]] = stablehlo.floor %[[ARG0]] - // CHECK: %[[TMP_373:.*]] = stablehlo.compare NE, %[[ARG0]], %[[TMP_372]], NOTYPE - // CHECK: %[[TMP_374:.*]] = stablehlo.compare LT, %[[ARG0]], %[[TMP_267]], NOTYPE - // CHECK: %[[TMP_375:.*]] = stablehlo.or 
%[[TMP_373]], %[[TMP_374]] -- // CHECK: %[[TMP_376:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_376:.*]] = stablehlo.constant dense<0x7FC00000> - // CHECK: %[[TMP_377:.*]] = stablehlo.select %[[TMP_375]], %[[TMP_376]], %[[TMP_371]] - %1 = chlo.polygamma %lhs, %rhs : tensor, tensor -> tensor - func.return %1 : tensor -@@ -1852,7 +1852,7 @@ - - // ----- - --// CHECK-LABEL: @polygamma_f64 -+// CHECK: @polygamma_f64 - // CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) - func.func @polygamma_f64(%lhs : tensor, %rhs : tensor) -> tensor { - // CHECK-DAG: %[[TMP_0:.*]] = stablehlo.constant dense<1.000000e+00> -@@ -1979,8 +1979,8 @@ - // CHECK: %[[TMP_121:.*]] = stablehlo.add %[[TMP_118]], %[[TMP_91]] - // CHECK: %[[TMP_122:.*]] = stablehlo.power %[[TMP_121]], %[[TMP_92]] - // CHECK: %[[TMP_123:.*]] = stablehlo.constant dense<1.000000e+00> -- // CHECK-DAG: %[[TMP_124:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] -- // CHECK-DAG: %[[TMP_125:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] -+ // CHECK: %[[TMP_124:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] -+ // CHECK: %[[TMP_125:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] - // CHECK: %[[TMP_126:.*]] = stablehlo.divide %[[TMP_124]], %[[TMP_125]] - // CHECK: %[[TMP_127:.*]] = stablehlo.multiply %[[TMP_121]], %[[TMP_121]] - // CHECK: %[[TMP_128:.*]] = stablehlo.divide %[[TMP_91]], %[[TMP_127]] -@@ -2492,11 +2492,11 @@ - // CHECK-SAME: (%[[ARG:.*]]: tensor<16x16xf32>) - func.func @top_k(%arg : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) { - // CHECK: %[[IOTA:.*]] = stablehlo.iota dim = 1 : tensor<16x16xi32> -- // CHECK-NEXT: %[[SORT:.*]]:2 = "stablehlo.sort"(%[[ARG]], %[[IOTA]]) <{dimension = 1 : i64, is_stable = true}> ({ -+ // CHECK-NEXT: %[[SORT:.*]]:2 = "stablehlo.sort"(%[[ARG]], %[[IOTA]]) ({ - // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor, %{{.*}}: tensor, %{{.*}}: tensor): - // CHECK-NEXT: %[[CMP:.*]] = stablehlo.compare GT, %[[LHS]], %[[RHS]], TOTALORDER - // CHECK-NEXT: stablehlo.return %[[CMP]] -- // CHECK-NEXT: }) : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) -+ // CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) - // CHECK-NEXT: %[[VAL:.*]] = stablehlo.slice %[[SORT]]#0 [0:16, 0:8] : (tensor<16x16xf32>) -> tensor<16x8xf32> - // CHECK-NEXT: %[[IDX:.*]] = stablehlo.slice %[[SORT]]#1 [0:16, 0:8] : (tensor<16x16xi32>) -> tensor<16x8xi32> - // CHECK-NEXT: return %[[VAL]], %[[IDX]] -@@ -2521,11 +2521,11 @@ - // CHECK-NEXT: [[K_I32x1:%.*]] = stablehlo.reshape [[K_I32]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: [[RESULT_SHAPE:%.*]] = stablehlo.concatenate [[DIM_0_I32x1]], [[DIM_1_I32x1]], [[K_I32x1]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> - // CHECK-NEXT: [[IOTA:%.*]] = stablehlo.dynamic_iota [[IOTA_SHAPE]], dim = 2 : (tensor<3xi32>) -> tensor -- // CHECK-NEXT: [[SORT:%.*]]:2 = "stablehlo.sort"([[ARG]], [[IOTA]]) <{dimension = 2 : i64, is_stable = true}> ({ -+ // CHECK-NEXT: [[SORT:%.*]]:2 = "stablehlo.sort"([[ARG]], [[IOTA]]) ({ - // CHECK-NEXT: ^bb0([[ARG_1:%.*]]: tensor, [[ARG_2:%.*]]: tensor, [[ARG_3:%.*]]: tensor, [[ARG_4:%.*]]: tensor): - // CHECK-NEXT: [[CMP:%.*]] = stablehlo.compare GT, [[ARG_1]], [[ARG_2]], NOTYPE : (tensor, tensor) -> tensor - // CHECK-NEXT: stablehlo.return [[CMP]] : tensor -- // CHECK-NEXT: }) : (tensor, tensor) -> (tensor, tensor) -+ // CHECK-NEXT: }) {dimension = 2 : i64, is_stable = true} : 
(tensor, tensor) -> (tensor, tensor) - // CHECK-NEXT: [[STARTS:%.*]] = stablehlo.constant dense<0> : tensor<3xi64> - // CHECK-NEXT: [[LIMITS:%.*]] = stablehlo.convert [[RESULT_SHAPE]] : (tensor<3xi32>) -> tensor<3xi64> - // CHECK-NEXT: [[STRIDES:%.*]] = stablehlo.constant dense<1> : tensor<3xi64> -diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/tests/ops_stablehlo.mlir ---- stablehlo/stablehlo/tests/ops_stablehlo.mlir -+++ stablehlo/stablehlo/tests/ops_stablehlo.mlir -@@ -2014,7 +2014,7 @@ - - // CHECK-LABEL: func @rng_normal - func.func @rng_normal(%arg0: tensor, %arg1: tensor) -> tensor<2x3x5xf32> { -- %cst = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64> -+ %cst = "stablehlo.constant"() {value = dense<[2, 3, 5]> : tensor<3xi64>} : () -> tensor<3xi64> - %0 = "stablehlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #stablehlo}: (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> - func.return %0 : tensor<2x3x5xf32> - } -@@ -2030,7 +2030,7 @@ - // ----- - - func.func @rng_normal_invalid_shape(%arg0: tensor, %arg1: tensor) { -- %cst = stablehlo.constant dense<7> : tensor<1xi64> -+ %cst = "stablehlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64> - // expected-error@+2 {{failed to infer returned types}} - // expected-error @+1 {{inferred type(s) 'tensor<7xf32>' are incompatible with return type(s) of operation 'tensor<12xf32>'}} - %0 = "stablehlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #stablehlo}: (tensor, tensor, tensor<1xi64>) -> tensor<12xf32> -@@ -2067,7 +2067,7 @@ - // ----- - - func.func @rng_normal_invalid_type(%arg0: tensor>, %arg1: tensor) { -- %cst = stablehlo.constant dense<7> : tensor<1xi64> -+ %cst = "stablehlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64> - // expected-error @+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} - %0 = "stablehlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #stablehlo}: (tensor>, tensor, tensor<1xi64>) -> tensor<7xf32> - func.return -@@ -2691,7 +2691,7 @@ - // CHECK-LABEL: func @constants - func.func @constants() -> () { - // CHECK: stablehlo.constant dense<0> : tensor -- %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> (tensor) -+ %0 = "stablehlo.constant"() {value = dense<0> : tensor} : () -> (tensor) - - // CHECK: stablehlo.constant {extra_attr = 3 : i32} dense<0> : tensor - %1 = "stablehlo.constant"() {extra_attr = 3 : i32, value = dense<0> : tensor} : () -> (tensor) -@@ -2703,7 +2703,7 @@ - func.func @constant_invalid() -> () { - // expected-error@+2 {{failed to infer returned types}} - // expected-error@+1 {{'stablehlo.constant' op inferred type(s) 'tensor' are incompatible with return type(s) of operation 'tensor<3xi32>'}} -- %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> (tensor<3xi32>) -+ %0 = "stablehlo.constant"() {value = dense<0> : tensor} : () -> (tensor<3xi32>) - func.return - } - -@@ -2711,7 +2711,7 @@ - - func.func @constant_invalid() -> () { - // expected-error@+1 {{op result #0 must be statically shaped tensor}} -- %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor -+ %0 = "stablehlo.constant"() {value = dense<1> : tensor} : () -> tensor - func.return - } - -@@ -2719,7 +2719,7 @@ - - func.func @constant_invalid() -> () 
{ - // expected-error@+1 {{elements literal type must have static shape}} -- %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor -+ %0 = "stablehlo.constant"() {value = dense<1> : tensor} : () -> tensor - func.return - } - -@@ -4872,7 +4872,7 @@ - %3 = stablehlo.uniform_quantize %2 : (tensor<2xf32>) -> tensor<2x!quant.uniform> - %4 = stablehlo.uniform_quantize %1 : (tensor<2xf32>) -> tensor<2x!quant.uniform> - func.return %0, %4, %3 : tensor<2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform> -- // CHECK: stablehlo.constant() <{value = dense<[1, 2]> : tensor<2xi8>}> : () -> tensor<2x!quant.uniform> -+ // CHECK: stablehlo.constant() {value = dense<[1, 2]> : tensor<2xi8>} : () -> tensor<2x!quant.uniform> - // CHECK-NEXT: stablehlo.constant dense<[1.000000e+01, 1.200000e+01]> : tensor<2xf32> - // CHECK-NEXT: stablehlo.constant dense<[3.000000e+00, 1.000000e+02]> : tensor<2xf32> - } -diff --ruN a/stablehlo/stablehlo/tests/shape_legalize_to_stablehlo.mlir b/stablehlo/stablehlo/tests/shape_legalize_to_stablehlo.mlir ---- stablehlo/stablehlo/tests/shape_legalize_to_stablehlo.mlir -+++ stablehlo/stablehlo/tests/shape_legalize_to_stablehlo.mlir -@@ -14,11 +14,11 @@ - // CHECK-NEXT: %[[INPUT_SIZE_PRODUCT:.*]] = stablehlo.multiply %[[TMP1]], %[[INPUT_SIZE1]] : tensor - // CHECK-NEXT: %[[COMPUTED_SIZE:.*]] = stablehlo.divide %[[ARG0_I32]], %[[INPUT_SIZE_PRODUCT]] : tensor - // CHECK-NEXT: %[[M1:.*]] = stablehlo.constant dense<-1> : tensor -- // CHECK-NEXT: %[[INPUT_SIZE0_EQ_M1:.*]] = stablehlo.compare EQ, %3, %[[M1]], NOTYPE : (tensor, tensor) -> tensor -- // CHECK-NEXT: %[[RESULT_SIZE0:.*]] = stablehlo.select %[[INPUT_SIZE0_EQ_M1]], %[[COMPUTED_SIZE]], %3 : tensor, tensor -+ // CHECK-NEXT: %[[INPUT_SIZE0_EQ_M1:.*]] = stablehlo.compare EQ, %[[INPUT_SIZE0]], %[[M1]], NOTYPE : (tensor, tensor) -> tensor -+ // CHECK-NEXT: %[[RESULT_SIZE0:.*]] = stablehlo.select %[[INPUT_SIZE0_EQ_M1]], %[[COMPUTED_SIZE]], %[[INPUT_SIZE0]] : tensor, tensor - // CHECK-NEXT: %[[RESULT_SIZE0x1:.*]] = stablehlo.reshape %[[RESULT_SIZE0]] : (tensor) -> tensor<1xi32> -- // CHECK-NEXT: %[[INPUT_SIZE1_EQ_M1:.*]] = stablehlo.compare EQ, %6, %[[M1]], NOTYPE : (tensor, tensor) -> tensor -- // CHECK-NEXT: %[[RESULT_SIZE1:.*]] = stablehlo.select %[[INPUT_SIZE1_EQ_M1]], %[[COMPUTED_SIZE]], %6 : tensor, tensor -+ // CHECK-NEXT: %[[INPUT_SIZE1_EQ_M1:.*]] = stablehlo.compare EQ, %[[INPUT_SIZE1]], %[[M1]], NOTYPE : (tensor, tensor) -> tensor -+ // CHECK-NEXT: %[[RESULT_SIZE1:.*]] = stablehlo.select %[[INPUT_SIZE1_EQ_M1]], %[[COMPUTED_SIZE]], %[[INPUT_SIZE1]] : tensor, tensor - // CHECK-NEXT: %[[RESULT_SIZE1x1:.*]] = stablehlo.reshape %[[RESULT_SIZE1]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: %[[RESULT:.*]] = stablehlo.concatenate %[[RESULT_SIZE0x1]], %[[RESULT_SIZE1x1]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - // CHECK-NEXT: return %[[RESULT]] : tensor<2xi32> -diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir ---- stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir -+++ stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir -@@ -137,7 +137,7 @@ - - // CHECK-LABEL: func @custom_call_inapplicable_missing_indices_of_shape_operands - func.func @custom_call_inapplicable_missing_indices_of_shape_operands(%arg0: tensor<4xf32>) -> tensor<1x2xf32> { -- // CHECK: stablehlo.custom_call @foo(%arg0, %0) -+ // CHECK: stablehlo.custom_call @foo(%arg0, %c) - %0 = stablehlo.constant 
dense<[1, 2]> : tensor<2xi64> - %1 = stablehlo.custom_call @foo(%arg0, %0) : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x2xf32> - return %1 : tensor<1x2xf32> -@@ -147,7 +147,7 @@ - - // CHECK-LABEL: func @custom_call_inapplicable_dynamic_result_type - func.func @custom_call_inapplicable_dynamic_result_type(%arg0: tensor<4xf32>) -> tensor<1x?xf32> { -- // CHECK: stablehlo.custom_call @foo(%arg0, %0) -+ // CHECK: stablehlo.custom_call @foo(%arg0, %c) - %0 = stablehlo.constant dense<[1, 2]> : tensor<2xi64> - %1 = stablehlo.custom_call @foo(%arg0, %0) { - indices_of_shape_operands = dense<[1]> : tensor<1xi64> -@@ -272,7 +272,7 @@ - // CHECK-LABEL: @dynamic_gather_success_static_result_type - func.func @dynamic_gather_success_static_result_type(%arg0 : tensor<2x4x9xi32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x8xi32> { - // CHECK-NOT: stablehlo.dynamic_gather -- // CHECK: "stablehlo.gather"(%arg0, %arg1) <{ -+ // CHECK: "stablehlo.gather"(%arg0, %arg1) { - // CHECK-SAME: dimension_numbers = #stablehlo.gather< - // CHECK-SAME: offset_dims = [2], - // CHECK-SAME: collapsed_slice_dims = [0, 1], -@@ -280,7 +280,7 @@ - // CHECK-SAME: index_vector_dim = 2 - // CHECK-SAME: >, - // CHECK-SAME: slice_sizes = array -- // CHECK-SAME: }> : (tensor<2x4x9xi32>, tensor<1x5x2xi32>) -> tensor<1x5x8xi32> -+ // CHECK-SAME: } : (tensor<2x4x9xi32>, tensor<1x5x2xi32>) -> tensor<1x5x8xi32> - %0 = stablehlo.constant dense<[1, 1, 8]> : tensor<3xi32> - %1 = "stablehlo.dynamic_gather"(%arg0, %arg1, %0) { - dimension_numbers = #stablehlo.gather< -@@ -298,7 +298,7 @@ - // CHECK-LABEL: @dynamic_gather_success_dynamic_result_type - func.func @dynamic_gather_success_dynamic_result_type(%arg0 : tensor<2x4x9xi32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x?xi32> { - // CHECK-NOT: stablehlo.dynamic_gather -- // CHECK: "stablehlo.gather"(%arg0, %arg1) <{ -+ // CHECK: "stablehlo.gather"(%arg0, %arg1) { - // CHECK-SAME: dimension_numbers = #stablehlo.gather< - // CHECK-SAME: offset_dims = [2], - // CHECK-SAME: collapsed_slice_dims = [0, 1], -@@ -306,16 +306,16 @@ - // CHECK-SAME: index_vector_dim = 2 - // CHECK-SAME: >, - // CHECK-SAME: slice_sizes = array -- // CHECK-SAME: }> : (tensor<2x4x9xi32>, tensor<1x5x2xi32>) -> tensor<1x5x?xi32> -+ // CHECK-SAME: } : (tensor<2x4x9xi32>, tensor<1x5x2xi32>) -> tensor<1x5x?xi32> - %0 = stablehlo.constant dense<[1, 1, 8]> : tensor<3xi32> -- %1 = "stablehlo.dynamic_gather"(%arg0, %arg1, %0) <{ -+ %1 = "stablehlo.dynamic_gather"(%arg0, %arg1, %0) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [0, 1], - index_vector_dim = 2, - offset_dims = [2], - start_index_map = [0, 1] - > -- }> : (tensor<2x4x9xi32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x?xi32> -+ } : (tensor<2x4x9xi32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x?xi32> - return %1 : tensor<1x5x?xi32> - } - -@@ -324,14 +324,14 @@ - // CHECK-LABEL: @dynamic_gather_inapplicable_dynamic_slice_sizes - func.func @dynamic_gather_inapplicable_dynamic_slice_sizes(%arg0 : tensor<2x4x9xi32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xi32> { - // CHECK: stablehlo.dynamic_gather -- %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) <{ -+ %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [0, 1], - index_vector_dim = 2, - offset_dims = [2], - start_index_map = [0, 1] - > -- }> : (tensor<2x4x9xi32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xi32> -+ } : (tensor<2x4x9xi32>, tensor<1x5x2xi32>, tensor<3xi32>) -> 
tensor<1x5x8xi32> - return %0 : tensor<1x5x8xi32> - } - diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index abaa26639918b6..86613c0dcd88cf 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "714d9acacd96760f4d8a9fe3898bee4f204cd419" - STABLEHLO_SHA256 = "4912edb1bef3362862e4942a624de7430eafbb0a39d93c0120a2ffbfcbedf11a" + STABLEHLO_COMMIT = "e81411ef562e11337283ef24bb3c40b2f3a6ebfa" + STABLEHLO_SHA256 = "167f15fbdfc3dc54601b6e37d53bce7323123f701893deb3935c8629a763766a" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 158ea9dc0be0b3..c707d71c20d603 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "16ceace318b63a569138261331c66d7809d079de" - TFRT_SHA256 = "f7572ad7b08aa59023e25b9907710ec10fff495837bed57088c9da16d46c2174" + TFRT_COMMIT = "9fdfdeada1eb04e11e0db68461b5bd1dcdb02062" + TFRT_SHA256 = "1e310a961a1248efd767fadfda4c205d3e05ca15205bb8d299ca4037f42f9b18" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index d7ae76f096431a..1a47809a8f9753 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -251,7 +251,7 @@ build:cuda_clang_official --action_env=TF_CUDA_VERSION="12" build:cuda_clang_official --action_env=TF_CUDNN_VERSION="8" build:cuda_clang_official --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.3" build:cuda_clang_official --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" -build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-17/bin/clang" +build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:cuda_clang_official --crosstool_top="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain" @@ -602,15 +602,12 @@ try-import %workspace%/.bazelrc.user # Build TensorFlow v2. test:release_base --test_size_filters=small,medium -# Ensure release_base is set on linux -build:release_linux_base --config=release_base - -# Target the AVX instruction set -build:release_linux_base --config=avx_linux - # Enable support for all targets build:release_base --config=cpu_cross +# Ensure release_base is set on linux +build:release_linux_base --config=release_base + # Disable clang extension that rejects type definitions within offsetof. # This was added in clang-16 by https://reviews.llvm.org/D133574. # Can be removed once upb is updated, since a type definition is used within @@ -633,8 +630,8 @@ build:release_linux_base --action_env PYTHON_BIN_PATH="/usr/bin/python3" build:release_linux_base --action_env PYTHON_LIB_PATH="/usr/lib/tf_python" build:release_linux_base --python_path="/usr/bin/python3" # Set Clang as compiler. Use the actual path to clang installed in container. 
-build:release_cpu_linux_base --repo_env=CC="/usr/lib/llvm-17/bin/clang" -build:release_cpu_linux_base --repo_env=BAZEL_COMPILER="/usr/lib/llvm-17/bin/clang" +build:release_cpu_linux_base --repo_env=CC="/usr/lib/llvm-18/bin/clang" +build:release_cpu_linux_base --repo_env=BAZEL_COMPILER="/usr/lib/llvm-18/bin/clang" # Test-related settings below this point. test:release_linux_base --build_tests_only --keep_going --test_output=errors --verbose_failures=true test:release_linux_base --local_test_jobs=HOST_CPUS @@ -645,6 +642,8 @@ test:release_linux_base --test_summary=short # Use the Clang toolchain to compile build:release_cpu_linux --config=release_linux_base build:release_cpu_linux --crosstool_top="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain" +# Target the AVX instruction set +build:release_cpu_linux --config=avx_linux build:release_gpu_linux --config=release_cpu_linux # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. diff --git a/third_party/xla/.github/workflows/benchmark.yml b/third_party/xla/.github/workflows/benchmark.yml deleted file mode 100644 index 9d4a8d0881b47e..00000000000000 --- a/third_party/xla/.github/workflows/benchmark.yml +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -name: A/B Diff Performance Benchmarking - -on: - pull_request: - types: [labeled, synchronize] - -jobs: - run-benchmark-at-pr: - # TODO(b/278787029): Rework triggering to prevent new labels from overriding job status. - if: contains(github.event.pull_request.labels.*.name, 'A/B diff benchmarking') - runs-on: - - self-hosted # must come first - - environment=testing - - cpu - - os-family=Linux - defaults: - run: - shell: bash - timeout-minutes: 60 - steps: - - name: "Checking out PR repository" - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 - - name: "Build docker" # TODO(b/277242108): build once and reference docker image by digest. 
- run: | - docker build --file build_tools/docker/dockerfiles/benchmarking.Dockerfile \ - --tag base \ - build_tools/docker/context - - name: "Benchmark at PR" - run: | - docker run --mount="type=bind,src="${PWD}",target=/work" --workdir="/work" \ - base:latest \ - build_tools/github_actions/build_xla.sh - - name: "Checking out base repository" - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 - with: - ref: "${{ github.event.pull_request.base.sha }}" - - name: "Benchmark at Base" - run: | - docker run --mount="type=bind,src="${PWD}",target=/work" --workdir="/work" \ - base:latest \ - build_tools/github_actions/build_xla.sh diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index f9229c0d13765e..c8897665644baa 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -174,170 +174,6 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists add_subdirectory(integrations) add_subdirectory(reference) add_subdirectory(tests) -diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.cpp b/stablehlo/stablehlo/dialect/AssemblyFormat.cpp ---- stablehlo/stablehlo/dialect/AssemblyFormat.cpp -+++ stablehlo/stablehlo/dialect/AssemblyFormat.cpp -@@ -15,6 +15,7 @@ - - #include "stablehlo/dialect/AssemblyFormat.h" - -+#include - #include - #include - #include -@@ -130,6 +131,42 @@ - for (Type& t : opTypes) typePtrs.push_back(&t); - - return detail::parseSameOperandsAndResultTypeImpl(parser, typePtrs, result); -+} -+ -+void printConstantOp(OpAsmPrinter& p, Operation* op, ElementsAttr value) { -+ assert(op->getNumResults() == 1); -+ // If not all types are the same, use generic form. -+ if (value.getType() != op->getResultTypes().front()) { -+ p.printGenericOp(op, /*printOpName=*/false); -+ return; -+ } -+ -+ p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"}); -+ p << ' '; -+ p.printStrippedAttrOrType(value); -+} -+ -+ParseResult parseConstantOp(OpAsmParser& parser, OperationState& result) { -+ // Parse the generic form. -+ if (succeeded(parser.parseOptionalLParen())) { -+ if (parser.parseRParen()) return failure(); -+ if (parser.parseOptionalAttrDict(result.attributes)) return failure(); -+ if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() || -+ parser.parseArrow()) -+ return failure(); -+ Type resultTy; -+ if (parser.parseType(resultTy)) return failure(); -+ result.addTypes(resultTy); -+ return success(); -+ } -+ -+ ElementsAttr valueAttr; -+ if (parser.parseOptionalAttrDict(result.attributes)) return failure(); -+ if (parser.parseCustomAttributeWithFallback(valueAttr, Type{}, "value", -+ result.attributes)) -+ return failure(); -+ result.addTypes(valueAttr.getType()); -+ return success(); - } - - void printTupleOpType(OpAsmPrinter& p, Operation*, TypeRange operands, -diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.h b/stablehlo/stablehlo/dialect/AssemblyFormat.h ---- stablehlo/stablehlo/dialect/AssemblyFormat.h -+++ stablehlo/stablehlo/dialect/AssemblyFormat.h -@@ -101,6 +101,16 @@ - SmallVectorImpl& operands, - SmallVectorImpl& opTypes, Type& result); - -+// Print a `constant` op. -+// -+// op ::= attr-dict $value -+// -+// When the `value` and `output` have different type, it just uses the default -+// operator assembly format as a fallback. 
-+void printConstantOp(OpAsmPrinter& p, Operation* op, ElementsAttr value); -+ -+ParseResult parseConstantOp(OpAsmParser& parser, OperationState& result); -+ - // TuplesOp - only print result type. Operand type is trivially inferrable. - // - // Inferring operand types from tuple type: -diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp ---- stablehlo/stablehlo/dialect/StablehloOps.cpp -+++ stablehlo/stablehlo/dialect/StablehloOps.cpp -@@ -256,6 +256,16 @@ - // ConstantOp - //===----------------------------------------------------------------------===// - -+void ConstantOp::getAsmResultNames( -+ function_ref setNameFn) { -+ mlir::TensorType type = getType(); -+ if (type.getElementType().isa()) { -+ setNameFn(getResult(), "c"); -+ } else { -+ setNameFn(getResult(), "cst"); -+ } -+} -+ - OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { - assert(adaptor.getOperands().empty() && "constant has no operands"); - -@@ -311,44 +321,11 @@ - } - - ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) { -- // Parse the generic form. -- if (succeeded(parser.parseOptionalLParen())) { -- if (parser.parseRParen()) return failure(); -- if (parser.parseOptionalAttrDict(result.attributes)) return failure(); -- if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() || -- parser.parseArrow()) -- return failure(); -- Type resultTy; -- if (parser.parseType(resultTy)) return failure(); -- result.addTypes(resultTy); -- return success(); -- } -- -- ElementsAttr valueAttr; -- if (parser.parseOptionalAttrDict(result.attributes)) return failure(); -- if (parser.parseCustomAttributeWithFallback(valueAttr, Type{}, "value", -- result.attributes)) -- return failure(); -- result.addTypes(valueAttr.getType()); -- return success(); --} -- --/// Print a `constant` op. --/// --/// op ::= attr-dict $value --/// --/// When the `value` and `output` have different type, it just uses the default --/// operator assembly format as a fallback. -+ return hlo::parseConstantOp(parser, result); -+} -+ - void ConstantOp::print(::mlir::OpAsmPrinter& p) { -- // If not all types are the same, use generic form. -- if (getValue().getType() != getType()) { -- p.printGenericOp(getOperation(), /*printOpName=*/false); -- return; -- } -- -- p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); -- p << ' '; -- p.printStrippedAttrOrType(getValueAttr()); -+ hlo::printConstantOp(p, getOperation(), getValue()); - } - - //===----------------------------------------------------------------------===// -diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/dialect/StablehloOps.td ---- stablehlo/stablehlo/dialect/StablehloOps.td -+++ stablehlo/stablehlo/dialect/StablehloOps.td -@@ -38,6 +38,7 @@ - - let useDefaultAttributePrinterParser = 0; - let useDefaultTypePrinterParser = 0; -+ let usePropertiesForAttributes = 0; - } - - class StableHLO_Op traits = []> : -@@ -65,7 +66,8 @@ - //===----------------------------------------------------------------------===// - - def StableHLO_ConstantOp : StableHLO_Op<"constant", -- [ConstantLike, Pure, DeclareOpInterfaceMethods]> { -+ [ConstantLike, Pure, DeclareOpInterfaceMethods, -+ DeclareOpInterfaceMethods]> { - let summary = "Constant operation"; - let description = [{ - Produces an `output` tensor from a constant `value`. 
diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel --- stablehlo/stablehlo/experimental/BUILD.bazel +++ stablehlo/stablehlo/experimental/BUILD.bazel @@ -2800,504 +2636,4 @@ diff --ruN a/stablehlo/stablehlo/reference/Ops.cpp b/stablehlo/stablehlo/referen result.set(resultIndex, dotProduct.get(*dotProductIt)); } } -diff --ruN a/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir b/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir ---- stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir -+++ stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir -@@ -1317,8 +1317,8 @@ - // CHECK: %[[TMP_33:.*]] = stablehlo.add %[[TMP_30]], %[[TMP_3]] - // CHECK: %[[TMP_34:.*]] = stablehlo.power %[[TMP_33]], %[[TMP_4]] - // CHECK: %[[TMP_35:.*]] = stablehlo.constant dense<1.000000e+00> -- // CHECK-DAG: %[[TMP_36:.*]] = stablehlo.multiply %[[TMP_34]], %[[TMP_33]] -- // CHECK-DAG: %[[TMP_37:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_35]] -+ // CHECK: %[[TMP_36:.*]] = stablehlo.multiply %[[TMP_34]], %[[TMP_33]] -+ // CHECK: %[[TMP_37:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_35]] - // CHECK: %[[TMP_38:.*]] = stablehlo.divide %[[TMP_36]], %[[TMP_37]] - // CHECK: %[[TMP_39:.*]] = stablehlo.multiply %[[TMP_33]], %[[TMP_33]] - // CHECK: %[[TMP_40:.*]] = stablehlo.divide %[[TMP_3]], %[[TMP_39]] -@@ -1465,7 +1465,7 @@ - - // ----- - --// CHECK-LABEL: @polygamma_f32 -+// CHECK: @polygamma_f32 - // CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) - func.func @polygamma_f32(%lhs : tensor, %rhs : tensor) -> tensor { - // CHECK-DAG: %[[TMP_0:.*]] = stablehlo.constant dense<1.000000e+00> -@@ -1592,8 +1592,8 @@ - // CHECK: %[[TMP_121:.*]] = stablehlo.add %[[TMP_118]], %[[TMP_91]] - // CHECK: %[[TMP_122:.*]] = stablehlo.power %[[TMP_121]], %[[TMP_92]] - // CHECK: %[[TMP_123:.*]] = stablehlo.constant dense<1.000000e+00> -- // CHECK-DAG: %[[TMP_124:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] -- // CHECK-DAG: %[[TMP_125:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] -+ // CHECK: %[[TMP_124:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] -+ // CHECK: %[[TMP_125:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] - // CHECK: %[[TMP_126:.*]] = stablehlo.divide %[[TMP_124]], %[[TMP_125]] - // CHECK: %[[TMP_127:.*]] = stablehlo.multiply %[[TMP_121]], %[[TMP_121]] - // CHECK: %[[TMP_128:.*]] = stablehlo.divide %[[TMP_91]], %[[TMP_127]] -@@ -1602,7 +1602,7 @@ - // CHECK: %[[TMP_131:.*]] = stablehlo.constant dense<2.100000e+01> - // CHECK: %[[TMP_132:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_131]] - // CHECK: %[[TMP_133:.*]] = stablehlo.multiply %[[TMP_130]], %[[TMP_132]] -- // CHECK: %[[TMP_134:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_134:.*]] = stablehlo.constant dense<-1.39544646E-19> - // CHECK: %[[TMP_135:.*]] = stablehlo.add %[[TMP_90]], %[[TMP_134]] - // CHECK: %[[TMP_136:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_135]] - // CHECK: %[[TMP_137:.*]] = stablehlo.multiply %[[TMP_133]], %[[TMP_136]] -@@ -1611,7 +1611,7 @@ - // CHECK: %[[TMP_140:.*]] = stablehlo.constant dense<1.900000e+01> - // CHECK: %[[TMP_141:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_140]] - // CHECK: %[[TMP_142:.*]] = stablehlo.multiply %[[TMP_139]], %[[TMP_141]] -- // CHECK: %[[TMP_143:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_143:.*]] = stablehlo.constant dense<5.50900303E-18> - // CHECK: %[[TMP_144:.*]] = stablehlo.add %[[TMP_137]], %[[TMP_143]] - // CHECK: %[[TMP_145:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_144]] - // 
CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_142]], %[[TMP_145]] -@@ -1620,7 +1620,7 @@ - // CHECK: %[[TMP_149:.*]] = stablehlo.constant dense<1.700000e+01> - // CHECK: %[[TMP_150:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_149]] - // CHECK: %[[TMP_151:.*]] = stablehlo.multiply %[[TMP_148]], %[[TMP_150]] -- // CHECK: %[[TMP_152:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_152:.*]] = stablehlo.constant dense<-2.17486866E-16> - // CHECK: %[[TMP_153:.*]] = stablehlo.add %[[TMP_146]], %[[TMP_152]] - // CHECK: %[[TMP_154:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_153]] - // CHECK: %[[TMP_155:.*]] = stablehlo.multiply %[[TMP_151]], %[[TMP_154]] -@@ -1629,7 +1629,7 @@ - // CHECK: %[[TMP_158:.*]] = stablehlo.constant dense<1.500000e+01> - // CHECK: %[[TMP_159:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_158]] - // CHECK: %[[TMP_160:.*]] = stablehlo.multiply %[[TMP_157]], %[[TMP_159]] -- // CHECK: %[[TMP_161:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_161:.*]] = stablehlo.constant dense<8.58606213E-15> - // CHECK: %[[TMP_162:.*]] = stablehlo.add %[[TMP_155]], %[[TMP_161]] - // CHECK: %[[TMP_163:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_162]] - // CHECK: %[[TMP_164:.*]] = stablehlo.multiply %[[TMP_160]], %[[TMP_163]] -@@ -1638,7 +1638,7 @@ - // CHECK: %[[TMP_167:.*]] = stablehlo.constant dense<1.300000e+01> - // CHECK: %[[TMP_168:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_167]] - // CHECK: %[[TMP_169:.*]] = stablehlo.multiply %[[TMP_166]], %[[TMP_168]] -- // CHECK: %[[TMP_170:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_170:.*]] = stablehlo.constant dense<-3.3896803E-13> - // CHECK: %[[TMP_171:.*]] = stablehlo.add %[[TMP_164]], %[[TMP_170]] - // CHECK: %[[TMP_172:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_171]] - // CHECK: %[[TMP_173:.*]] = stablehlo.multiply %[[TMP_169]], %[[TMP_172]] -@@ -1647,7 +1647,7 @@ - // CHECK: %[[TMP_176:.*]] = stablehlo.constant dense<1.100000e+01> - // CHECK: %[[TMP_177:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_176]] - // CHECK: %[[TMP_178:.*]] = stablehlo.multiply %[[TMP_175]], %[[TMP_177]] -- // CHECK: %[[TMP_179:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_179:.*]] = stablehlo.constant dense<1.33825364E-11> - // CHECK: %[[TMP_180:.*]] = stablehlo.add %[[TMP_173]], %[[TMP_179]] - // CHECK: %[[TMP_181:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_180]] - // CHECK: %[[TMP_182:.*]] = stablehlo.multiply %[[TMP_178]], %[[TMP_181]] -@@ -1656,7 +1656,7 @@ - // CHECK: %[[TMP_185:.*]] = stablehlo.constant dense<9.000000e+00> - // CHECK: %[[TMP_186:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_185]] - // CHECK: %[[TMP_187:.*]] = stablehlo.multiply %[[TMP_184]], %[[TMP_186]] -- // CHECK: %[[TMP_188:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_188:.*]] = stablehlo.constant dense<-5.28419031E-10> - // CHECK: %[[TMP_189:.*]] = stablehlo.add %[[TMP_182]], %[[TMP_188]] - // CHECK: %[[TMP_190:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_189]] - // CHECK: %[[TMP_191:.*]] = stablehlo.multiply %[[TMP_187]], %[[TMP_190]] -@@ -1665,7 +1665,7 @@ - // CHECK: %[[TMP_194:.*]] = stablehlo.constant dense<7.000000e+00> - // CHECK: %[[TMP_195:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_194]] - // CHECK: %[[TMP_196:.*]] = stablehlo.multiply %[[TMP_193]], %[[TMP_195]] -- // CHECK: %[[TMP_197:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_197:.*]] = stablehlo.constant dense<2.08767563E-8> - // CHECK: %[[TMP_198:.*]] = stablehlo.add %[[TMP_191]], %[[TMP_197]] - // CHECK: %[[TMP_199:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_198]] - // CHECK: %[[TMP_200:.*]] = stablehlo.multiply %[[TMP_196]], 
%[[TMP_199]] -@@ -1674,7 +1674,7 @@ - // CHECK: %[[TMP_203:.*]] = stablehlo.constant dense<5.000000e+00> - // CHECK: %[[TMP_204:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_203]] - // CHECK: %[[TMP_205:.*]] = stablehlo.multiply %[[TMP_202]], %[[TMP_204]] -- // CHECK: %[[TMP_206:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_206:.*]] = stablehlo.constant dense<-8.26719599E-7> - // CHECK: %[[TMP_207:.*]] = stablehlo.add %[[TMP_200]], %[[TMP_206]] - // CHECK: %[[TMP_208:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_207]] - // CHECK: %[[TMP_209:.*]] = stablehlo.multiply %[[TMP_205]], %[[TMP_208]] -@@ -1683,7 +1683,7 @@ - // CHECK: %[[TMP_212:.*]] = stablehlo.constant dense<3.000000e+00> - // CHECK: %[[TMP_213:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_212]] - // CHECK: %[[TMP_214:.*]] = stablehlo.multiply %[[TMP_211]], %[[TMP_213]] -- // CHECK: %[[TMP_215:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_215:.*]] = stablehlo.constant dense<3.30687835E-5> - // CHECK: %[[TMP_216:.*]] = stablehlo.add %[[TMP_209]], %[[TMP_215]] - // CHECK: %[[TMP_217:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_216]] - // CHECK: %[[TMP_218:.*]] = stablehlo.multiply %[[TMP_214]], %[[TMP_217]] -@@ -1692,13 +1692,13 @@ - // CHECK: %[[TMP_221:.*]] = stablehlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_222:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_221]] - // CHECK: %[[TMP_223:.*]] = stablehlo.multiply %[[TMP_220]], %[[TMP_222]] -- // CHECK: %[[TMP_224:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_224:.*]] = stablehlo.constant dense<-0.00138888892> - // CHECK: %[[TMP_225:.*]] = stablehlo.add %[[TMP_218]], %[[TMP_224]] - // CHECK: %[[TMP_226:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_225]] - // CHECK: %[[TMP_227:.*]] = stablehlo.multiply %[[TMP_223]], %[[TMP_226]] - // CHECK: %[[TMP_228:.*]] = stablehlo.constant dense<5.000000e-01> - // CHECK: %[[TMP_229:.*]] = stablehlo.divide %[[TMP_5]], %[[TMP_121]] -- // CHECK: %[[TMP_230:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_230:.*]] = stablehlo.constant dense<0.0833333358> - // CHECK: %[[TMP_231:.*]] = stablehlo.add %[[TMP_230]], %[[TMP_227]] - // CHECK: %[[TMP_232:.*]] = stablehlo.multiply %[[TMP_229]], %[[TMP_231]] - // CHECK: %[[TMP_233:.*]] = stablehlo.add %[[TMP_228]], %[[TMP_232]] -@@ -1707,11 +1707,11 @@ - // CHECK: %[[TMP_236:.*]] = stablehlo.add %[[TMP_235]], %[[TMP_234]] - // CHECK: %[[TMP_237:.*]] = stablehlo.abs %[[TMP_122]] - // CHECK: %[[TMP_238:.*]] = stablehlo.abs %[[TMP_120]] -- // CHECK: %[[TMP_239:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_239:.*]] = stablehlo.constant dense<1.401300e-45> - // CHECK: %[[TMP_240:.*]] = stablehlo.multiply %[[TMP_238]], %[[TMP_239]] - // CHECK: %[[TMP_241:.*]] = stablehlo.compare LT, %[[TMP_237]], %[[TMP_240]], NOTYPE - // CHECK: %[[TMP_242:.*]] = stablehlo.select %[[TMP_241]], %[[TMP_120]], %[[TMP_236]] -- // CHECK: %[[TMP_243:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_243:.*]] = stablehlo.constant dense<0x7FC00000> - // CHECK: %[[TMP_244:.*]] = stablehlo.compare LT, %[[TMP_5]], %[[TMP_123]], NOTYPE - // CHECK: %[[TMP_245:.*]] = stablehlo.select %[[TMP_244]], %[[TMP_243]], %[[TMP_242]] - // CHECK: %[[TMP_246:.*]] = stablehlo.compare LE, %[[ARG1]], %[[TMP_90]], NOTYPE -@@ -1719,7 +1719,7 @@ - // CHECK: %[[TMP_248:.*]] = stablehlo.compare NE, %[[TMP_5]], %[[TMP_247]], NOTYPE - // CHECK: %[[TMP_249:.*]] = stablehlo.and %[[TMP_246]], %[[TMP_248]] - // CHECK: %[[TMP_250:.*]] = stablehlo.select %[[TMP_249]], %[[TMP_243]], %[[TMP_245]] -- // CHECK: %[[TMP_251:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_251:.*]] = 
stablehlo.constant dense<0x7F800000> - // CHECK: %[[TMP_252:.*]] = stablehlo.floor %[[ARG1]] - // CHECK: %[[TMP_253:.*]] = stablehlo.compare EQ, %[[ARG1]], %[[TMP_252]], NOTYPE - // CHECK: %[[TMP_254:.*]] = stablehlo.and %[[TMP_246]], %[[TMP_253]] -@@ -1744,8 +1744,8 @@ - // CHECK: %[[TMP_273:.*]] = stablehlo.subtract %[[ARG1]], %[[TMP_272]] - // CHECK: %[[TMP_274:.*]] = stablehlo.select %[[TMP_270]], %[[TMP_271]], %[[TMP_273]] - // CHECK-DAG: %[[TMP_275:.*]] = stablehlo.constant dense<0.000000e+00> -- // CHECK-DAG: %[[TMP_276:.*]] = stablehlo.constant -- // CHECK-DAG: %[[TMP_277:.*]] = stablehlo.constant -+ // CHECK-DAG: %[[TMP_276:.*]] = stablehlo.constant dense<1.000000e+00> -+ // CHECK-DAG: %[[TMP_277:.*]] = stablehlo.constant dense<676.520386> - // CHECK-DAG: %[[TMP_278:.*]] = stablehlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_279:.*]] = stablehlo.add %[[TMP_274]], %[[TMP_278]] - // CHECK: %[[TMP_280:.*]] = stablehlo.multiply %[[TMP_279]], %[[TMP_279]] -@@ -1753,7 +1753,7 @@ - // CHECK: %[[TMP_282:.*]] = stablehlo.subtract %[[TMP_275]], %[[TMP_281]] - // CHECK: %[[TMP_283:.*]] = stablehlo.divide %[[TMP_277]], %[[TMP_279]] - // CHECK: %[[TMP_284:.*]] = stablehlo.add %[[TMP_276]], %[[TMP_283]] -- // CHECK-DAG: %[[TMP_285:.*]] = stablehlo.constant -+ // CHECK-DAG: %[[TMP_285:.*]] = stablehlo.constant dense<-1259.13916> - // CHECK-DAG: %[[TMP_286:.*]] = stablehlo.constant dense<2.000000e+00> - // CHECK: %[[TMP_287:.*]] = stablehlo.add %[[TMP_274]], %[[TMP_286]] - // CHECK: %[[TMP_288:.*]] = stablehlo.multiply %[[TMP_287]], %[[TMP_287]] -@@ -1761,7 +1761,7 @@ - // CHECK: %[[TMP_290:.*]] = stablehlo.subtract %[[TMP_282]], %[[TMP_289]] - // CHECK: %[[TMP_291:.*]] = stablehlo.divide %[[TMP_285]], %[[TMP_287]] - // CHECK: %[[TMP_292:.*]] = stablehlo.add %[[TMP_284]], %[[TMP_291]] -- // CHECK-DAG: %[[TMP_293:.*]] = stablehlo.constant -+ // CHECK-DAG: %[[TMP_293:.*]] = stablehlo.constant dense<771.323425> - // CHECK-DAG: %[[TMP_294:.*]] = stablehlo.constant dense<3.000000e+00> - // CHECK: %[[TMP_295:.*]] = stablehlo.add %[[TMP_274]], %[[TMP_294]] - // CHECK: %[[TMP_296:.*]] = stablehlo.multiply %[[TMP_295]], %[[TMP_295]] -@@ -1769,15 +1769,15 @@ - // CHECK: %[[TMP_298:.*]] = stablehlo.subtract %[[TMP_290]], %[[TMP_297]] - // CHECK: %[[TMP_299:.*]] = stablehlo.divide %[[TMP_293]], %[[TMP_295]] - // CHECK: %[[TMP_300:.*]] = stablehlo.add %[[TMP_292]], %[[TMP_299]] -- // CHECK-DAG: %[[TMP_301:.*]] = stablehlo.constant -- // CHECK-DAG: %[[TMP_302:.*]] = stablehlo.constant -+ // CHECK-DAG: %[[TMP_301:.*]] = stablehlo.constant dense<-176.615036> -+ // CHECK-DAG: %[[TMP_302:.*]] = stablehlo.constant dense<4.000000e+00> - // CHECK: %[[TMP_303:.*]] = stablehlo.add %[[TMP_274]], %[[TMP_302]] - // CHECK: %[[TMP_304:.*]] = stablehlo.multiply %[[TMP_303]], %[[TMP_303]] - // CHECK: %[[TMP_305:.*]] = stablehlo.divide %[[TMP_301]], %[[TMP_304]] - // CHECK: %[[TMP_306:.*]] = stablehlo.subtract %[[TMP_298]], %[[TMP_305]] - // CHECK: %[[TMP_307:.*]] = stablehlo.divide %[[TMP_301]], %[[TMP_303]] - // CHECK: %[[TMP_308:.*]] = stablehlo.add %[[TMP_300]], %[[TMP_307]] -- // CHECK-DAG: %[[TMP_309:.*]] = stablehlo.constant -+ // CHECK-DAG: %[[TMP_309:.*]] = stablehlo.constant dense<12.5073433> - // CHECK-DAG: %[[TMP_310:.*]] = stablehlo.constant dense<5.000000e+00> - // CHECK: %[[TMP_311:.*]] = stablehlo.add %[[TMP_274]], %[[TMP_310]] - // CHECK: %[[TMP_312:.*]] = stablehlo.multiply %[[TMP_311]], %[[TMP_311]] -@@ -1785,7 +1785,7 @@ - // CHECK: %[[TMP_314:.*]] = stablehlo.subtract %[[TMP_306]], %[[TMP_313]] - // 
CHECK: %[[TMP_315:.*]] = stablehlo.divide %[[TMP_309]], %[[TMP_311]] - // CHECK: %[[TMP_316:.*]] = stablehlo.add %[[TMP_308]], %[[TMP_315]] -- // CHECK-DAG: %[[TMP_317:.*]] = stablehlo.constant -+ // CHECK-DAG: %[[TMP_317:.*]] = stablehlo.constant dense<-0.138571098> - // CHECK-DAG: %[[TMP_318:.*]] = stablehlo.constant dense<6.000000e+00> - // CHECK: %[[TMP_319:.*]] = stablehlo.add %[[TMP_274]], %[[TMP_318]] - // CHECK: %[[TMP_320:.*]] = stablehlo.multiply %[[TMP_319]], %[[TMP_319]] -@@ -1793,7 +1793,7 @@ - // CHECK: %[[TMP_322:.*]] = stablehlo.subtract %[[TMP_314]], %[[TMP_321]] - // CHECK: %[[TMP_323:.*]] = stablehlo.divide %[[TMP_317]], %[[TMP_319]] - // CHECK: %[[TMP_324:.*]] = stablehlo.add %[[TMP_316]], %[[TMP_323]] -- // CHECK-DAG: %[[TMP_325:.*]] = stablehlo.constant -+ // CHECK-DAG: %[[TMP_325:.*]] = stablehlo.constant dense<9.98436917E-6> - // CHECK-DAG: %[[TMP_326:.*]] = stablehlo.constant dense<7.000000e+00> - // CHECK: %[[TMP_327:.*]] = stablehlo.add %[[TMP_274]], %[[TMP_326]] - // CHECK: %[[TMP_328:.*]] = stablehlo.multiply %[[TMP_327]], %[[TMP_327]] -@@ -1801,7 +1801,7 @@ - // CHECK: %[[TMP_330:.*]] = stablehlo.subtract %[[TMP_322]], %[[TMP_329]] - // CHECK: %[[TMP_331:.*]] = stablehlo.divide %[[TMP_325]], %[[TMP_327]] - // CHECK: %[[TMP_332:.*]] = stablehlo.add %[[TMP_324]], %[[TMP_331]] -- // CHECK-DAG: %[[TMP_333:.*]] = stablehlo.constant -+ // CHECK-DAG: %[[TMP_333:.*]] = stablehlo.constant dense<1.50563267E-7> - // CHECK-DAG: %[[TMP_334:.*]] = stablehlo.constant dense<8.000000e+00> - // CHECK: %[[TMP_335:.*]] = stablehlo.add %[[TMP_274]], %[[TMP_334]] - // CHECK: %[[TMP_336:.*]] = stablehlo.multiply %[[TMP_335]], %[[TMP_335]] -@@ -1811,7 +1811,7 @@ - // CHECK: %[[TMP_340:.*]] = stablehlo.add %[[TMP_332]], %[[TMP_339]] - // CHECK: %[[TMP_341:.*]] = stablehlo.constant dense<7.500000e+00> - // CHECK: %[[TMP_342:.*]] = stablehlo.add %[[TMP_341]], %[[TMP_274]] -- // CHECK: %[[TMP_343:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_343:.*]] = stablehlo.constant dense<2.01490307> - // CHECK: %[[TMP_344:.*]] = stablehlo.divide %[[TMP_274]], %[[TMP_341]] - // CHECK: %[[TMP_345:.*]] = stablehlo.log_plus_one %[[TMP_344]] - // CHECK: %[[TMP_346:.*]] = stablehlo.add %[[TMP_343]], %[[TMP_345]] -@@ -1825,7 +1825,7 @@ - // CHECK: %[[TMP_354:.*]] = stablehlo.floor %[[TMP_353]] - // CHECK: %[[TMP_355:.*]] = stablehlo.abs %[[TMP_354]] - // CHECK: %[[TMP_356:.*]] = stablehlo.add %[[ARG1]], %[[TMP_355]] -- // CHECK: %[[TMP_357:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_357:.*]] = stablehlo.constant dense<3.14159274> - // CHECK: %[[TMP_358:.*]] = stablehlo.multiply %[[TMP_357]], %[[TMP_356]] - // CHECK: %[[TMP_359:.*]] = stablehlo.cosine %[[TMP_358]] - // CHECK: %[[TMP_360:.*]] = stablehlo.sine %[[TMP_358]] -@@ -1837,14 +1837,14 @@ - // CHECK: %[[TMP_366:.*]] = stablehlo.floor %[[ARG1]] - // CHECK: %[[TMP_367:.*]] = stablehlo.compare EQ, %[[ARG1]], %[[TMP_366]], NOTYPE - // CHECK: %[[TMP_368:.*]] = stablehlo.and %[[TMP_365]], %[[TMP_367]] -- // CHECK: %[[TMP_369:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_369:.*]] = stablehlo.constant dense<0x7FC00000> - // CHECK: %[[TMP_370:.*]] = stablehlo.select %[[TMP_368]], %[[TMP_369]], %[[TMP_364]] - // CHECK: %[[TMP_371:.*]] = stablehlo.select %[[TMP_268]], %[[TMP_370]], %[[TMP_266]] - // CHECK: %[[TMP_372:.*]] = stablehlo.floor %[[ARG0]] - // CHECK: %[[TMP_373:.*]] = stablehlo.compare NE, %[[ARG0]], %[[TMP_372]], NOTYPE - // CHECK: %[[TMP_374:.*]] = stablehlo.compare LT, %[[ARG0]], %[[TMP_267]], NOTYPE - // CHECK: %[[TMP_375:.*]] = stablehlo.or 
%[[TMP_373]], %[[TMP_374]] -- // CHECK: %[[TMP_376:.*]] = stablehlo.constant -+ // CHECK: %[[TMP_376:.*]] = stablehlo.constant dense<0x7FC00000> - // CHECK: %[[TMP_377:.*]] = stablehlo.select %[[TMP_375]], %[[TMP_376]], %[[TMP_371]] - %1 = chlo.polygamma %lhs, %rhs : tensor, tensor -> tensor - func.return %1 : tensor -@@ -1852,7 +1852,7 @@ - - // ----- - --// CHECK-LABEL: @polygamma_f64 -+// CHECK: @polygamma_f64 - // CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) - func.func @polygamma_f64(%lhs : tensor, %rhs : tensor) -> tensor { - // CHECK-DAG: %[[TMP_0:.*]] = stablehlo.constant dense<1.000000e+00> -@@ -1979,8 +1979,8 @@ - // CHECK: %[[TMP_121:.*]] = stablehlo.add %[[TMP_118]], %[[TMP_91]] - // CHECK: %[[TMP_122:.*]] = stablehlo.power %[[TMP_121]], %[[TMP_92]] - // CHECK: %[[TMP_123:.*]] = stablehlo.constant dense<1.000000e+00> -- // CHECK-DAG: %[[TMP_124:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] -- // CHECK-DAG: %[[TMP_125:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] -+ // CHECK: %[[TMP_124:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] -+ // CHECK: %[[TMP_125:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] - // CHECK: %[[TMP_126:.*]] = stablehlo.divide %[[TMP_124]], %[[TMP_125]] - // CHECK: %[[TMP_127:.*]] = stablehlo.multiply %[[TMP_121]], %[[TMP_121]] - // CHECK: %[[TMP_128:.*]] = stablehlo.divide %[[TMP_91]], %[[TMP_127]] -@@ -2492,11 +2492,11 @@ - // CHECK-SAME: (%[[ARG:.*]]: tensor<16x16xf32>) - func.func @top_k(%arg : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) { - // CHECK: %[[IOTA:.*]] = stablehlo.iota dim = 1 : tensor<16x16xi32> -- // CHECK-NEXT: %[[SORT:.*]]:2 = "stablehlo.sort"(%[[ARG]], %[[IOTA]]) <{dimension = 1 : i64, is_stable = true}> ({ -+ // CHECK-NEXT: %[[SORT:.*]]:2 = "stablehlo.sort"(%[[ARG]], %[[IOTA]]) ({ - // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor, %{{.*}}: tensor, %{{.*}}: tensor): - // CHECK-NEXT: %[[CMP:.*]] = stablehlo.compare GT, %[[LHS]], %[[RHS]], TOTALORDER - // CHECK-NEXT: stablehlo.return %[[CMP]] -- // CHECK-NEXT: }) : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) -+ // CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) - // CHECK-NEXT: %[[VAL:.*]] = stablehlo.slice %[[SORT]]#0 [0:16, 0:8] : (tensor<16x16xf32>) -> tensor<16x8xf32> - // CHECK-NEXT: %[[IDX:.*]] = stablehlo.slice %[[SORT]]#1 [0:16, 0:8] : (tensor<16x16xi32>) -> tensor<16x8xi32> - // CHECK-NEXT: return %[[VAL]], %[[IDX]] -@@ -2521,11 +2521,11 @@ - // CHECK-NEXT: [[K_I32x1:%.*]] = stablehlo.reshape [[K_I32]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: [[RESULT_SHAPE:%.*]] = stablehlo.concatenate [[DIM_0_I32x1]], [[DIM_1_I32x1]], [[K_I32x1]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> - // CHECK-NEXT: [[IOTA:%.*]] = stablehlo.dynamic_iota [[IOTA_SHAPE]], dim = 2 : (tensor<3xi32>) -> tensor -- // CHECK-NEXT: [[SORT:%.*]]:2 = "stablehlo.sort"([[ARG]], [[IOTA]]) <{dimension = 2 : i64, is_stable = true}> ({ -+ // CHECK-NEXT: [[SORT:%.*]]:2 = "stablehlo.sort"([[ARG]], [[IOTA]]) ({ - // CHECK-NEXT: ^bb0([[ARG_1:%.*]]: tensor, [[ARG_2:%.*]]: tensor, [[ARG_3:%.*]]: tensor, [[ARG_4:%.*]]: tensor): - // CHECK-NEXT: [[CMP:%.*]] = stablehlo.compare GT, [[ARG_1]], [[ARG_2]], NOTYPE : (tensor, tensor) -> tensor - // CHECK-NEXT: stablehlo.return [[CMP]] : tensor -- // CHECK-NEXT: }) : (tensor, tensor) -> (tensor, tensor) -+ // CHECK-NEXT: }) {dimension = 2 : i64, is_stable = true} : 
(tensor, tensor) -> (tensor, tensor) - // CHECK-NEXT: [[STARTS:%.*]] = stablehlo.constant dense<0> : tensor<3xi64> - // CHECK-NEXT: [[LIMITS:%.*]] = stablehlo.convert [[RESULT_SHAPE]] : (tensor<3xi32>) -> tensor<3xi64> - // CHECK-NEXT: [[STRIDES:%.*]] = stablehlo.constant dense<1> : tensor<3xi64> -diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/tests/ops_stablehlo.mlir ---- stablehlo/stablehlo/tests/ops_stablehlo.mlir -+++ stablehlo/stablehlo/tests/ops_stablehlo.mlir -@@ -2014,7 +2014,7 @@ - - // CHECK-LABEL: func @rng_normal - func.func @rng_normal(%arg0: tensor, %arg1: tensor) -> tensor<2x3x5xf32> { -- %cst = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64> -+ %cst = "stablehlo.constant"() {value = dense<[2, 3, 5]> : tensor<3xi64>} : () -> tensor<3xi64> - %0 = "stablehlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #stablehlo}: (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> - func.return %0 : tensor<2x3x5xf32> - } -@@ -2030,7 +2030,7 @@ - // ----- - - func.func @rng_normal_invalid_shape(%arg0: tensor, %arg1: tensor) { -- %cst = stablehlo.constant dense<7> : tensor<1xi64> -+ %cst = "stablehlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64> - // expected-error@+2 {{failed to infer returned types}} - // expected-error @+1 {{inferred type(s) 'tensor<7xf32>' are incompatible with return type(s) of operation 'tensor<12xf32>'}} - %0 = "stablehlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #stablehlo}: (tensor, tensor, tensor<1xi64>) -> tensor<12xf32> -@@ -2067,7 +2067,7 @@ - // ----- - - func.func @rng_normal_invalid_type(%arg0: tensor>, %arg1: tensor) { -- %cst = stablehlo.constant dense<7> : tensor<1xi64> -+ %cst = "stablehlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64> - // expected-error @+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} - %0 = "stablehlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #stablehlo}: (tensor>, tensor, tensor<1xi64>) -> tensor<7xf32> - func.return -@@ -2691,7 +2691,7 @@ - // CHECK-LABEL: func @constants - func.func @constants() -> () { - // CHECK: stablehlo.constant dense<0> : tensor -- %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> (tensor) -+ %0 = "stablehlo.constant"() {value = dense<0> : tensor} : () -> (tensor) - - // CHECK: stablehlo.constant {extra_attr = 3 : i32} dense<0> : tensor - %1 = "stablehlo.constant"() {extra_attr = 3 : i32, value = dense<0> : tensor} : () -> (tensor) -@@ -2703,7 +2703,7 @@ - func.func @constant_invalid() -> () { - // expected-error@+2 {{failed to infer returned types}} - // expected-error@+1 {{'stablehlo.constant' op inferred type(s) 'tensor' are incompatible with return type(s) of operation 'tensor<3xi32>'}} -- %0 = "stablehlo.constant"() <{value = dense<0> : tensor}> : () -> (tensor<3xi32>) -+ %0 = "stablehlo.constant"() {value = dense<0> : tensor} : () -> (tensor<3xi32>) - func.return - } - -@@ -2711,7 +2711,7 @@ - - func.func @constant_invalid() -> () { - // expected-error@+1 {{op result #0 must be statically shaped tensor}} -- %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor -+ %0 = "stablehlo.constant"() {value = dense<1> : tensor} : () -> tensor - func.return - } - -@@ -2719,7 +2719,7 @@ - - func.func @constant_invalid() -> () 
{ - // expected-error@+1 {{elements literal type must have static shape}} -- %0 = "stablehlo.constant"() <{value = dense<1> : tensor}> : () -> tensor -+ %0 = "stablehlo.constant"() {value = dense<1> : tensor} : () -> tensor - func.return - } - -@@ -4872,7 +4872,7 @@ - %3 = stablehlo.uniform_quantize %2 : (tensor<2xf32>) -> tensor<2x!quant.uniform> - %4 = stablehlo.uniform_quantize %1 : (tensor<2xf32>) -> tensor<2x!quant.uniform> - func.return %0, %4, %3 : tensor<2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform> -- // CHECK: stablehlo.constant() <{value = dense<[1, 2]> : tensor<2xi8>}> : () -> tensor<2x!quant.uniform> -+ // CHECK: stablehlo.constant() {value = dense<[1, 2]> : tensor<2xi8>} : () -> tensor<2x!quant.uniform> - // CHECK-NEXT: stablehlo.constant dense<[1.000000e+01, 1.200000e+01]> : tensor<2xf32> - // CHECK-NEXT: stablehlo.constant dense<[3.000000e+00, 1.000000e+02]> : tensor<2xf32> - } -diff --ruN a/stablehlo/stablehlo/tests/shape_legalize_to_stablehlo.mlir b/stablehlo/stablehlo/tests/shape_legalize_to_stablehlo.mlir ---- stablehlo/stablehlo/tests/shape_legalize_to_stablehlo.mlir -+++ stablehlo/stablehlo/tests/shape_legalize_to_stablehlo.mlir -@@ -14,11 +14,11 @@ - // CHECK-NEXT: %[[INPUT_SIZE_PRODUCT:.*]] = stablehlo.multiply %[[TMP1]], %[[INPUT_SIZE1]] : tensor - // CHECK-NEXT: %[[COMPUTED_SIZE:.*]] = stablehlo.divide %[[ARG0_I32]], %[[INPUT_SIZE_PRODUCT]] : tensor - // CHECK-NEXT: %[[M1:.*]] = stablehlo.constant dense<-1> : tensor -- // CHECK-NEXT: %[[INPUT_SIZE0_EQ_M1:.*]] = stablehlo.compare EQ, %3, %[[M1]], NOTYPE : (tensor, tensor) -> tensor -- // CHECK-NEXT: %[[RESULT_SIZE0:.*]] = stablehlo.select %[[INPUT_SIZE0_EQ_M1]], %[[COMPUTED_SIZE]], %3 : tensor, tensor -+ // CHECK-NEXT: %[[INPUT_SIZE0_EQ_M1:.*]] = stablehlo.compare EQ, %[[INPUT_SIZE0]], %[[M1]], NOTYPE : (tensor, tensor) -> tensor -+ // CHECK-NEXT: %[[RESULT_SIZE0:.*]] = stablehlo.select %[[INPUT_SIZE0_EQ_M1]], %[[COMPUTED_SIZE]], %[[INPUT_SIZE0]] : tensor, tensor - // CHECK-NEXT: %[[RESULT_SIZE0x1:.*]] = stablehlo.reshape %[[RESULT_SIZE0]] : (tensor) -> tensor<1xi32> -- // CHECK-NEXT: %[[INPUT_SIZE1_EQ_M1:.*]] = stablehlo.compare EQ, %6, %[[M1]], NOTYPE : (tensor, tensor) -> tensor -- // CHECK-NEXT: %[[RESULT_SIZE1:.*]] = stablehlo.select %[[INPUT_SIZE1_EQ_M1]], %[[COMPUTED_SIZE]], %6 : tensor, tensor -+ // CHECK-NEXT: %[[INPUT_SIZE1_EQ_M1:.*]] = stablehlo.compare EQ, %[[INPUT_SIZE1]], %[[M1]], NOTYPE : (tensor, tensor) -> tensor -+ // CHECK-NEXT: %[[RESULT_SIZE1:.*]] = stablehlo.select %[[INPUT_SIZE1_EQ_M1]], %[[COMPUTED_SIZE]], %[[INPUT_SIZE1]] : tensor, tensor - // CHECK-NEXT: %[[RESULT_SIZE1x1:.*]] = stablehlo.reshape %[[RESULT_SIZE1]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: %[[RESULT:.*]] = stablehlo.concatenate %[[RESULT_SIZE0x1]], %[[RESULT_SIZE1x1]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - // CHECK-NEXT: return %[[RESULT]] : tensor<2xi32> -diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir ---- stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir -+++ stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir -@@ -137,7 +137,7 @@ - - // CHECK-LABEL: func @custom_call_inapplicable_missing_indices_of_shape_operands - func.func @custom_call_inapplicable_missing_indices_of_shape_operands(%arg0: tensor<4xf32>) -> tensor<1x2xf32> { -- // CHECK: stablehlo.custom_call @foo(%arg0, %0) -+ // CHECK: stablehlo.custom_call @foo(%arg0, %c) - %0 = stablehlo.constant 
dense<[1, 2]> : tensor<2xi64> - %1 = stablehlo.custom_call @foo(%arg0, %0) : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x2xf32> - return %1 : tensor<1x2xf32> -@@ -147,7 +147,7 @@ - - // CHECK-LABEL: func @custom_call_inapplicable_dynamic_result_type - func.func @custom_call_inapplicable_dynamic_result_type(%arg0: tensor<4xf32>) -> tensor<1x?xf32> { -- // CHECK: stablehlo.custom_call @foo(%arg0, %0) -+ // CHECK: stablehlo.custom_call @foo(%arg0, %c) - %0 = stablehlo.constant dense<[1, 2]> : tensor<2xi64> - %1 = stablehlo.custom_call @foo(%arg0, %0) { - indices_of_shape_operands = dense<[1]> : tensor<1xi64> -@@ -272,7 +272,7 @@ - // CHECK-LABEL: @dynamic_gather_success_static_result_type - func.func @dynamic_gather_success_static_result_type(%arg0 : tensor<2x4x9xi32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x8xi32> { - // CHECK-NOT: stablehlo.dynamic_gather -- // CHECK: "stablehlo.gather"(%arg0, %arg1) <{ -+ // CHECK: "stablehlo.gather"(%arg0, %arg1) { - // CHECK-SAME: dimension_numbers = #stablehlo.gather< - // CHECK-SAME: offset_dims = [2], - // CHECK-SAME: collapsed_slice_dims = [0, 1], -@@ -280,7 +280,7 @@ - // CHECK-SAME: index_vector_dim = 2 - // CHECK-SAME: >, - // CHECK-SAME: slice_sizes = array -- // CHECK-SAME: }> : (tensor<2x4x9xi32>, tensor<1x5x2xi32>) -> tensor<1x5x8xi32> -+ // CHECK-SAME: } : (tensor<2x4x9xi32>, tensor<1x5x2xi32>) -> tensor<1x5x8xi32> - %0 = stablehlo.constant dense<[1, 1, 8]> : tensor<3xi32> - %1 = "stablehlo.dynamic_gather"(%arg0, %arg1, %0) { - dimension_numbers = #stablehlo.gather< -@@ -298,7 +298,7 @@ - // CHECK-LABEL: @dynamic_gather_success_dynamic_result_type - func.func @dynamic_gather_success_dynamic_result_type(%arg0 : tensor<2x4x9xi32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x?xi32> { - // CHECK-NOT: stablehlo.dynamic_gather -- // CHECK: "stablehlo.gather"(%arg0, %arg1) <{ -+ // CHECK: "stablehlo.gather"(%arg0, %arg1) { - // CHECK-SAME: dimension_numbers = #stablehlo.gather< - // CHECK-SAME: offset_dims = [2], - // CHECK-SAME: collapsed_slice_dims = [0, 1], -@@ -306,16 +306,16 @@ - // CHECK-SAME: index_vector_dim = 2 - // CHECK-SAME: >, - // CHECK-SAME: slice_sizes = array -- // CHECK-SAME: }> : (tensor<2x4x9xi32>, tensor<1x5x2xi32>) -> tensor<1x5x?xi32> -+ // CHECK-SAME: } : (tensor<2x4x9xi32>, tensor<1x5x2xi32>) -> tensor<1x5x?xi32> - %0 = stablehlo.constant dense<[1, 1, 8]> : tensor<3xi32> -- %1 = "stablehlo.dynamic_gather"(%arg0, %arg1, %0) <{ -+ %1 = "stablehlo.dynamic_gather"(%arg0, %arg1, %0) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [0, 1], - index_vector_dim = 2, - offset_dims = [2], - start_index_map = [0, 1] - > -- }> : (tensor<2x4x9xi32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x?xi32> -+ } : (tensor<2x4x9xi32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x?xi32> - return %1 : tensor<1x5x?xi32> - } - -@@ -324,14 +324,14 @@ - // CHECK-LABEL: @dynamic_gather_inapplicable_dynamic_slice_sizes - func.func @dynamic_gather_inapplicable_dynamic_slice_sizes(%arg0 : tensor<2x4x9xi32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xi32> { - // CHECK: stablehlo.dynamic_gather -- %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) <{ -+ %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [0, 1], - index_vector_dim = 2, - offset_dims = [2], - start_index_map = [0, 1] - > -- }> : (tensor<2x4x9xi32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xi32> -+ } : (tensor<2x4x9xi32>, tensor<1x5x2xi32>, tensor<3xi32>) -> 
tensor<1x5x8xi32> - return %0 : tensor<1x5x8xi32> - } - diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index abaa26639918b6..86613c0dcd88cf 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "714d9acacd96760f4d8a9fe3898bee4f204cd419" - STABLEHLO_SHA256 = "4912edb1bef3362862e4942a624de7430eafbb0a39d93c0120a2ffbfcbedf11a" + STABLEHLO_COMMIT = "e81411ef562e11337283ef24bb3c40b2f3a6ebfa" + STABLEHLO_SHA256 = "167f15fbdfc3dc54601b6e37d53bce7323123f701893deb3935c8629a763766a" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index d7ae76f096431a..1a47809a8f9753 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -251,7 +251,7 @@ build:cuda_clang_official --action_env=TF_CUDA_VERSION="12" build:cuda_clang_official --action_env=TF_CUDNN_VERSION="8" build:cuda_clang_official --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.3" build:cuda_clang_official --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" -build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-17/bin/clang" +build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:cuda_clang_official --crosstool_top="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain" @@ -602,15 +602,12 @@ try-import %workspace%/.bazelrc.user # Build TensorFlow v2. test:release_base --test_size_filters=small,medium -# Ensure release_base is set on linux -build:release_linux_base --config=release_base - -# Target the AVX instruction set -build:release_linux_base --config=avx_linux - # Enable support for all targets build:release_base --config=cpu_cross +# Ensure release_base is set on linux +build:release_linux_base --config=release_base + # Disable clang extension that rejects type definitions within offsetof. # This was added in clang-16 by https://reviews.llvm.org/D133574. # Can be removed once upb is updated, since a type definition is used within @@ -633,8 +630,8 @@ build:release_linux_base --action_env PYTHON_BIN_PATH="/usr/bin/python3" build:release_linux_base --action_env PYTHON_LIB_PATH="/usr/lib/tf_python" build:release_linux_base --python_path="/usr/bin/python3" # Set Clang as compiler. Use the actual path to clang installed in container. -build:release_cpu_linux_base --repo_env=CC="/usr/lib/llvm-17/bin/clang" -build:release_cpu_linux_base --repo_env=BAZEL_COMPILER="/usr/lib/llvm-17/bin/clang" +build:release_cpu_linux_base --repo_env=CC="/usr/lib/llvm-18/bin/clang" +build:release_cpu_linux_base --repo_env=BAZEL_COMPILER="/usr/lib/llvm-18/bin/clang" # Test-related settings below this point. 
test:release_linux_base --build_tests_only --keep_going --test_output=errors --verbose_failures=true test:release_linux_base --local_test_jobs=HOST_CPUS @@ -645,6 +642,8 @@ test:release_linux_base --test_summary=short # Use the Clang toolchain to compile build:release_cpu_linux --config=release_linux_base build:release_cpu_linux --crosstool_top="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain" +# Target the AVX instruction set +build:release_cpu_linux --config=avx_linux build:release_gpu_linux --config=release_cpu_linux # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index 158ea9dc0be0b3..c707d71c20d603 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "16ceace318b63a569138261331c66d7809d079de" - TFRT_SHA256 = "f7572ad7b08aa59023e25b9907710ec10fff495837bed57088c9da16d46c2174" + TFRT_COMMIT = "9fdfdeada1eb04e11e0db68461b5bd1dcdb02062" + TFRT_SHA256 = "1e310a961a1248efd767fadfda4c205d3e05ca15205bb8d299ca4037f42f9b18" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl b/third_party/xla/third_party/tsl/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl index 4455aea60109fa..6a7a4f116843d9 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl +++ b/third_party/xla/third_party/tsl/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl @@ -74,22 +74,23 @@ def aarch64_compiler_configure(): "ml2014_clang_aarch64-python3.9": "docker://localhost/tensorflow-build-aarch64:latest-python3.9", "ml2014_clang_aarch64-python3.10": "docker://localhost/tensorflow-build-aarch64:latest-python3.10", "ml2014_clang_aarch64-python3.11": "docker://localhost/tensorflow-build-aarch64:latest-python3.11", + "ml2014_clang_aarch64-python3.12": "docker://localhost/tensorflow-build-aarch64:latest-python3.12", }, env = { "ABI_LIBC_VERSION": "glibc_2.17", "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-17/bin/clang", + "BAZEL_COMPILER": "/usr/lib/llvm-18/bin/clang", "BAZEL_HOST_SYSTEM": "aarch64-unknown-linux-gnu", "BAZEL_TARGET_CPU": "generic", "BAZEL_TARGET_LIBC": "glibc_2.17", "BAZEL_TARGET_SYSTEM": "aarch64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-17/bin/clang", + "CC": "/usr/lib/llvm-18/bin/clang", "CC_TOOLCHAIN_NAME": "linux_llvm_aarch64", "CLEAR_CACHE": "1", "CUDNN_INSTALL_PATH": "", - "CLANG_COMPILER_PATH": "/usr/lib/llvm-17/bin/clang", - "HOST_CXX_COMPILER": "/usr/lib/llvm-17/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-17/bin/clang", + "CLANG_COMPILER_PATH": "/usr/lib/llvm-18/bin/clang", + "HOST_CXX_COMPILER": "/usr/lib/llvm-18/bin/clang", + "HOST_C_COMPILER": "/usr/lib/llvm-18/bin/clang", "PYTHON_BIN_PATH": "/usr/local/bin/python3", "TENSORRT_INSTALL_PATH": "", "TF_CUDA_CLANG": "0", diff --git a/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD b/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD index 7cf6d8c3747b27..a69c431c08f5a4 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD +++ b/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD @@ 
-62,15 +62,15 @@ cc_toolchain_config( cpu = "k8", cxx_builtin_include_directories = [ "/dt9/", - "/usr/lib/llvm-17/include/", - "/usr/lib/llvm-17/lib/clang/17/include", + "/usr/lib/llvm-18/include/", + "/usr/lib/llvm-18/lib/clang/18/include", ], dbg_compile_flags = ["-g"], host_system_name = "linux", link_flags = [ "--target=x86_64-unknown-linux-gnu", "-fuse-ld=lld", - "--ld-path=/usr/lib/llvm-17/bin/ld.lld", + "--ld-path=/usr/lib/llvm-18/bin/ld.lld", "-Wl,--undefined-version", ], link_libs = [ @@ -90,14 +90,14 @@ cc_toolchain_config( target_libc = "", target_system_name = "x86_64-unknown-linux-gnu", tool_paths = { - "gcc": "/usr/lib/llvm-17/bin/clang", - "ld": "/usr/lib/llvm-17/bin/ld.lld", - "ar": "/usr/lib/llvm-17/bin/llvm-ar", - "cpp": "/usr/lib/llvm-17/bin/clang++", - "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", - "nm": "/usr/lib/llvm-17/bin/llvm-nm", - "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", - "strip": "/usr/lib/llvm-17/bin/llvm-strip", + "gcc": "/usr/lib/llvm-18/bin/clang", + "ld": "/usr/lib/llvm-18/bin/ld.lld", + "ar": "/usr/lib/llvm-18/bin/llvm-ar", + "cpp": "/usr/lib/llvm-18/bin/clang++", + "llvm-cov": "/usr/lib/llvm-18/bin/llvm-cov", + "nm": "/usr/lib/llvm-18/bin/llvm-nm", + "objdump": "/usr/lib/llvm-18/bin/llvm-objdump", + "strip": "/usr/lib/llvm-18/bin/llvm-strip", }, toolchain_identifier = "linux_x86_toolchain", unfiltered_compile_flags = [ @@ -148,15 +148,15 @@ cc_toolchain_config( cpu = "aarch64", cxx_builtin_include_directories = [ "/dt10/", - "/usr/lib/llvm-17/include/", - "/usr/lib/llvm-17/lib/clang/17/include", + "/usr/lib/llvm-18/include/", + "/usr/lib/llvm-18/lib/clang/18/include", ], dbg_compile_flags = ["-g"], host_system_name = "linux", link_flags = [ "--target=aarch64-unknown-linux-gnu", "-fuse-ld=lld", - "--ld-path=/usr/lib/llvm-17/bin/ld.lld", + "--ld-path=/usr/lib/llvm-18/bin/ld.lld", "-Wl,--undefined-version", ], link_libs = [ @@ -176,14 +176,14 @@ cc_toolchain_config( target_libc = "", target_system_name = "aarch64-unknown-linux-gnu", tool_paths = { - "gcc": "/usr/lib/llvm-17/bin/clang", - "ld": "/usr/lib/llvm-17/bin/ld.lld", - "ar": "/usr/lib/llvm-17/bin/llvm-ar", - "cpp": "/usr/lib/llvm-17/bin/clang++", - "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", - "nm": "/usr/lib/llvm-17/bin/llvm-nm", - "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", - "strip": "/usr/lib/llvm-17/bin/llvm-strip", + "gcc": "/usr/lib/llvm-18/bin/clang", + "ld": "/usr/lib/llvm-18/bin/ld.lld", + "ar": "/usr/lib/llvm-18/bin/llvm-ar", + "cpp": "/usr/lib/llvm-18/bin/clang++", + "llvm-cov": "/usr/lib/llvm-18/bin/llvm-cov", + "nm": "/usr/lib/llvm-18/bin/llvm-nm", + "objdump": "/usr/lib/llvm-18/bin/llvm-objdump", + "strip": "/usr/lib/llvm-18/bin/llvm-strip", }, toolchain_identifier = "linux_aarch64_toolchain", unfiltered_compile_flags = [ @@ -238,8 +238,8 @@ cc_toolchain_config( cpu = "darwin", cxx_builtin_include_directories = [ "%sysroot%/usr/include", - "/usr/lib/llvm-17/include/", - "/usr/lib/llvm-17/lib/clang/17/include", + "/usr/lib/llvm-18/include/", + "/usr/lib/llvm-18/lib/clang/18/include", "%sysroot%/System/Library/Frameworks/Security.framework/Headers", "%sysroot%/System/Library/Frameworks/CoreFoundation.framework/Headers", "%sysroot%/System/Library/Frameworks/SystemConfiguration.framework/Headers", @@ -250,7 +250,7 @@ cc_toolchain_config( "--target=x86_64-apple-darwin", "-lSystem", "-fuse-ld=lld", - "--ld-path=/usr/lib/llvm-17/bin/ld64.lld", + "--ld-path=/usr/lib/llvm-18/bin/ld64.lld", "-headerpad_max_install_names", "-Wl,-undefined,dynamic_lookup", # Target Catalina as the 
minimum supported OS @@ -274,13 +274,13 @@ cc_toolchain_config( target_system_name = "x86_64-apple-macosx10.15", tool_paths = { "gcc": "cc_wrapper.sh", - "ld": "/usr/lib/llvm-17/bin/ld64.lld", - "ar": "/usr/lib/llvm-17/bin/llvm-libtool-darwin", - "cpp": "/usr/lib/llvm-17/bin/clang++", - "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", - "nm": "/usr/lib/llvm-17/bin/llvm-nm", - "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", - "strip": "/usr/lib/llvm-17/bin/llvm-strip", + "ld": "/usr/lib/llvm-18/bin/ld64.lld", + "ar": "/usr/lib/llvm-18/bin/llvm-libtool-darwin", + "cpp": "/usr/lib/llvm-18/bin/clang++", + "llvm-cov": "/usr/lib/llvm-18/bin/llvm-cov", + "nm": "/usr/lib/llvm-18/bin/llvm-nm", + "objdump": "/usr/lib/llvm-18/bin/llvm-objdump", + "strip": "/usr/lib/llvm-18/bin/llvm-strip", }, toolchain_identifier = "macos_x86_toolchain", unfiltered_compile_flags = [ diff --git a/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/config/BUILD b/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/config/BUILD index e60a32aced24e7..7b60e5bacf2dea 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/config/BUILD +++ b/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/config/BUILD @@ -9,7 +9,7 @@ platform( "@platforms//cpu:x86_64", ], exec_properties = { - "container-image": "docker://gcr.io/tensorflow-testing/ml-devinfra-linux-aarch64-cross-compile@sha256:11c5ac3b9b4e01cfa82b39b90826a9bfc5b806ccc92cd3d272e6bf861de43be1", + "container-image": "docker://gcr.io/tensorflow-testing/ml-devinfra-linux-aarch64-cross-compile@sha256:06040763c500bd2ebaaa4585d4729c88d2c8ccec94baa7fbe9bbe3dc2827d79d", "OSFamily": "Linux", }, ) diff --git a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl index 756c7bbe8d786c..d8f137804ae52c 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl +++ b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl @@ -728,11 +728,11 @@ def initialize_rbe_configs(): sigbuild_tf_configs( name_container_map = { - "sigbuild-r2.17": "docker://gcr.io/tensorflow-sigs/build@sha256:dddcaf30321e9007103dce75c51b83fea3c06de462fcf41e7c6ae93f37fc3545", - "sigbuild-r2.17-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:8ca6b205b54f18d26a053cfe606145b8b11cc99cf83fc970a936ce327913c3c3", - "sigbuild-r2.17-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:5cfd081a337548165a800546f2365a38245e38e7a97052b1a21830bf66b2356d", - "sigbuild-r2.17-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:dddcaf30321e9007103dce75c51b83fea3c06de462fcf41e7c6ae93f37fc3545", - "sigbuild-r2.17-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:933c9f4bf65c92780863e00bd2132c6cfd41dbd624736c1af0dd2a5a056db6b8", + "sigbuild-r2.17": "docker://gcr.io/tensorflow-sigs/build@sha256:b6f572a897a69fa3311773f949b9aa9e81bc393e4fbe2c0d56d8afb03a6de080", + "sigbuild-r2.17-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:d0f27a4c7b97dbe9d530703dca3449afd464758e56b3ac4e1609c701223a0572", + "sigbuild-r2.17-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:64e68a1d65ac265a2a59c8c2f6eb1f2148a323048a679a08e53239d467fa1478", + "sigbuild-r2.17-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:b6f572a897a69fa3311773f949b9aa9e81bc393e4fbe2c0d56d8afb03a6de080", + "sigbuild-r2.17-python3.12": 
"docker://gcr.io/tensorflow-sigs/build@sha256:8b856ad736147bb9c8bc9e1ec2c8e1ab17d36397905da7a5b63dadeff9310f0c", }, # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 # and manylinux2014 is 2.17. @@ -767,29 +767,29 @@ def initialize_rbe_configs(): sigbuild_tf_configs( name_container_map = { - "sigbuild-r2.17-clang": "docker://gcr.io/tensorflow-sigs/build@sha256:dddcaf30321e9007103dce75c51b83fea3c06de462fcf41e7c6ae93f37fc3545", - "sigbuild-r2.17-clang-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:8ca6b205b54f18d26a053cfe606145b8b11cc99cf83fc970a936ce327913c3c3", - "sigbuild-r2.17-clang-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:5cfd081a337548165a800546f2365a38245e38e7a97052b1a21830bf66b2356d", - "sigbuild-r2.17-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:dddcaf30321e9007103dce75c51b83fea3c06de462fcf41e7c6ae93f37fc3545", - "sigbuild-r2.17-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:933c9f4bf65c92780863e00bd2132c6cfd41dbd624736c1af0dd2a5a056db6b8", + "sigbuild-r2.17-clang": "docker://gcr.io/tensorflow-sigs/build@sha256:b6f572a897a69fa3311773f949b9aa9e81bc393e4fbe2c0d56d8afb03a6de080", + "sigbuild-r2.17-clang-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:d0f27a4c7b97dbe9d530703dca3449afd464758e56b3ac4e1609c701223a0572", + "sigbuild-r2.17-clang-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:64e68a1d65ac265a2a59c8c2f6eb1f2148a323048a679a08e53239d467fa1478", + "sigbuild-r2.17-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:b6f572a897a69fa3311773f949b9aa9e81bc393e4fbe2c0d56d8afb03a6de080", + "sigbuild-r2.17-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:8b856ad736147bb9c8bc9e1ec2c8e1ab17d36397905da7a5b63dadeff9310f0c", }, # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 # and manylinux2014 is 2.17. env = { "ABI_LIBC_VERSION": "glibc_2.19", "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-17/bin/clang", + "BAZEL_COMPILER": "/usr/lib/llvm-18/bin/clang", "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", "BAZEL_TARGET_CPU": "k8", "BAZEL_TARGET_LIBC": "glibc_2.19", "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-17/bin/clang", + "CC": "/usr/lib/llvm-18/bin/clang", "CC_TOOLCHAIN_NAME": "linux_gnu_x86", "CLEAR_CACHE": "1", "CUDNN_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", - "CLANG_CUDA_COMPILER_PATH": "/usr/lib/llvm-17/bin/clang", - "HOST_CXX_COMPILER": "/usr/lib/llvm-17/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-17/bin/clang", + "CLANG_CUDA_COMPILER_PATH": "/usr/lib/llvm-18/bin/clang", + "HOST_CXX_COMPILER": "/usr/lib/llvm-18/bin/clang", + "HOST_C_COMPILER": "/usr/lib/llvm-18/bin/clang", "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", diff --git a/third_party/xla/third_party/tsl/tsl/platform/base64.cc b/third_party/xla/third_party/tsl/tsl/platform/base64.cc index 6421b5ec920010..7cf8f2d606887f 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/base64.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/base64.cc @@ -63,7 +63,7 @@ inline uint32 Convert(char x) { return static_cast(z); } -Status DecodeThreeChars(const char* codes, char* result) { +absl::Status DecodeThreeChars(const char* codes, char* result) { const uint32 packed = (Convert(codes[0]) << 18) | (Convert(codes[1]) << 12) | (Convert(codes[2]) << 6) | (Convert(codes[3])); // Convert() return value has upper 25 bits set if input is invalid. 
@@ -74,19 +74,19 @@ Status DecodeThreeChars(const char* codes, char* result) { result[0] = static_cast(packed >> 16); result[1] = static_cast(packed >> 8); result[2] = static_cast(packed); - return OkStatus(); + return absl::OkStatus(); } } // namespace template -Status Base64Decode(StringPiece data, T* decoded) { +absl::Status Base64Decode(StringPiece data, T* decoded) { if (decoded == nullptr) { return errors::Internal("'decoded' cannot be nullptr."); } if (data.empty()) { decoded->clear(); - return OkStatus(); + return absl::OkStatus(); } // This decoding procedure will write 3 * ceil(data.size() / 4) bytes to be @@ -138,16 +138,16 @@ Status Base64Decode(StringPiece data, T* decoded) { current += remain - 1; decoded->assign(buffer.get(), current - buffer.get()); - return OkStatus(); + return absl::OkStatus(); } template -Status Base64Encode(StringPiece source, T* encoded) { +absl::Status Base64Encode(StringPiece source, T* encoded) { return Base64Encode(source, false, encoded); } template -Status Base64Encode(StringPiece source, bool with_padding, T* encoded) { +absl::Status Base64Encode(StringPiece source, bool with_padding, T* encoded) { const char* const base64_chars = kBase64UrlSafeChars; if (encoded == nullptr) { return errors::Internal("'encoded' cannot be nullptr."); @@ -196,7 +196,7 @@ Status Base64Encode(StringPiece source, bool with_padding, T* encoded) { } encoded->assign(buffer.get(), current - buffer.get()); - return OkStatus(); + return absl::OkStatus(); } template Status Base64Decode(StringPiece data, diff --git a/third_party/xla/third_party/tsl/tsl/platform/base64.h b/third_party/xla/third_party/tsl/tsl/platform/base64.h index 888a3ebb35545a..fa2ad0ad40d618 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/base64.h +++ b/third_party/xla/third_party/tsl/tsl/platform/base64.h @@ -27,16 +27,16 @@ namespace tsl { /// /// See https://en.wikipedia.org/wiki/Base64 template -Status Base64Encode(StringPiece source, bool with_padding, T* encoded); +absl::Status Base64Encode(StringPiece source, bool with_padding, T* encoded); template -Status Base64Encode(StringPiece source, - T* encoded); // with_padding=false. +absl::Status Base64Encode(StringPiece source, + T* encoded); // with_padding=false. /// \brief Converts data from web-safe base64 encoding. /// /// See https://en.wikipedia.org/wiki/Base64 template -Status Base64Decode(StringPiece data, T* decoded); +absl::Status Base64Decode(StringPiece data, T* decoded); // Explicit instantiations defined in base64.cc. extern template Status Base64Decode(StringPiece data, diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/dnn.proto b/third_party/xla/third_party/tsl/tsl/protobuf/dnn.proto index cc16b2141e0e7a..695db935f6a0b4 100644 --- a/third_party/xla/third_party/tsl/tsl/protobuf/dnn.proto +++ b/third_party/xla/third_party/tsl/tsl/protobuf/dnn.proto @@ -192,3 +192,12 @@ enum FusedMHAKind { BMM1_OUTPUT_INPUT_TYPE = 1; BMM1_OUTPUT_FLOAT = 2; } + +// FusedMHAMaskKind kind +enum FMHAMaskKind { + NO_MASK = 0; + PADDING = 1; + CAUSAL = 2; + PADDING_CAUSAL = 3; + ALIBI = 4; +} diff --git a/third_party/xla/third_party/tsl/workspace2.bzl b/third_party/xla/third_party/tsl/workspace2.bzl index d5004e732eece7..823417850b1cf5 100644 --- a/third_party/xla/third_party/tsl/workspace2.bzl +++ b/third_party/xla/third_party/tsl/workspace2.bzl @@ -6,7 +6,6 @@ load("@bazel_skylib//lib:versions.bzl", "versions") # Import external repository rules. 
load("@bazel_tools//tools/build_defs/repo:java.bzl", "java_import_external") load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") -load("@tf_runtime//:dependencies.bzl", "tfrt_dependencies") load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # Import third party repository rules. See go/tfbr-thirdparty. @@ -386,10 +385,10 @@ def _tf_repositories(): tf_http_archive( name = "zlib", build_file = "//third_party:zlib.BUILD", - sha256 = "b3a24de97a8fdbc835b9833169501030b8977031bcb54b3b3ac13740f846ab30", - strip_prefix = "zlib-1.2.13", + sha256 = "9a93b2b7dfdac77ceba5a558a580e74667dd6fede4585b91eefb60f03b72df23", + strip_prefix = "zlib-1.3.1", system_build_file = "//third_party/systemlibs:zlib.BUILD", - urls = tf_mirror_urls("https://zlib.net/fossils/zlib-1.2.13.tar.gz"), + urls = tf_mirror_urls("https://zlib.net/fossils/zlib-1.3.1.tar.gz"), ) tf_http_archive( @@ -623,8 +622,6 @@ def workspace(): # written according to common practice to query native.existing_rule()). _tf_repositories() - tfrt_dependencies() - # Alias so it can be loaded without assigning to a different symbol to prevent # shadowing previous loads and trigger a buildifier warning. tsl_workspace2 = workspace diff --git a/third_party/xla/third_party/tsl/workspace3.bzl b/third_party/xla/third_party/tsl/workspace3.bzl index 9510b09374206c..a1293f59a48885 100644 --- a/third_party/xla/third_party/tsl/workspace3.bzl +++ b/third_party/xla/third_party/tsl/workspace3.bzl @@ -2,7 +2,6 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("//third_party/llvm:workspace.bzl", llvm = "repo") -load("//third_party/tf_runtime:workspace.bzl", tf_runtime = "repo") def workspace(): http_archive( @@ -15,8 +14,6 @@ def workspace(): ], ) - tf_runtime() - # https://github.com/bazelbuild/bazel-skylib/releases http_archive( name = "bazel_skylib", diff --git a/third_party/xla/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl b/third_party/xla/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl index 4455aea60109fa..6a7a4f116843d9 100644 --- a/third_party/xla/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl +++ b/third_party/xla/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl @@ -74,22 +74,23 @@ def aarch64_compiler_configure(): "ml2014_clang_aarch64-python3.9": "docker://localhost/tensorflow-build-aarch64:latest-python3.9", "ml2014_clang_aarch64-python3.10": "docker://localhost/tensorflow-build-aarch64:latest-python3.10", "ml2014_clang_aarch64-python3.11": "docker://localhost/tensorflow-build-aarch64:latest-python3.11", + "ml2014_clang_aarch64-python3.12": "docker://localhost/tensorflow-build-aarch64:latest-python3.12", }, env = { "ABI_LIBC_VERSION": "glibc_2.17", "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-17/bin/clang", + "BAZEL_COMPILER": "/usr/lib/llvm-18/bin/clang", "BAZEL_HOST_SYSTEM": "aarch64-unknown-linux-gnu", "BAZEL_TARGET_CPU": "generic", "BAZEL_TARGET_LIBC": "glibc_2.17", "BAZEL_TARGET_SYSTEM": "aarch64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-17/bin/clang", + "CC": "/usr/lib/llvm-18/bin/clang", "CC_TOOLCHAIN_NAME": "linux_llvm_aarch64", "CLEAR_CACHE": "1", "CUDNN_INSTALL_PATH": "", - "CLANG_COMPILER_PATH": "/usr/lib/llvm-17/bin/clang", - "HOST_CXX_COMPILER": "/usr/lib/llvm-17/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-17/bin/clang", + "CLANG_COMPILER_PATH": "/usr/lib/llvm-18/bin/clang", + "HOST_CXX_COMPILER": "/usr/lib/llvm-18/bin/clang", + "HOST_C_COMPILER": "/usr/lib/llvm-18/bin/clang", "PYTHON_BIN_PATH": 
"/usr/local/bin/python3", "TENSORRT_INSTALL_PATH": "", "TF_CUDA_CLANG": "0", diff --git a/third_party/xla/tools/toolchains/cross_compile/cc/BUILD b/third_party/xla/tools/toolchains/cross_compile/cc/BUILD index a064bbf1402511..be6e8968d53a64 100644 --- a/third_party/xla/tools/toolchains/cross_compile/cc/BUILD +++ b/third_party/xla/tools/toolchains/cross_compile/cc/BUILD @@ -62,15 +62,15 @@ cc_toolchain_config( cpu = "k8", cxx_builtin_include_directories = [ "/dt9/", - "/usr/lib/llvm-17/include/", - "/usr/lib/llvm-17/lib/clang/17/include", + "/usr/lib/llvm-18/include/", + "/usr/lib/llvm-18/lib/clang/18/include", ], dbg_compile_flags = ["-g"], host_system_name = "linux", link_flags = [ "--target=x86_64-unknown-linux-gnu", "-fuse-ld=lld", - "--ld-path=/usr/lib/llvm-17/bin/ld.lld", + "--ld-path=/usr/lib/llvm-18/bin/ld.lld", "-Wl,--undefined-version", ], link_libs = [ @@ -90,14 +90,14 @@ cc_toolchain_config( target_libc = "", target_system_name = "x86_64-unknown-linux-gnu", tool_paths = { - "gcc": "/usr/lib/llvm-17/bin/clang", - "ld": "/usr/lib/llvm-17/bin/ld.lld", - "ar": "/usr/lib/llvm-17/bin/llvm-ar", - "cpp": "/usr/lib/llvm-17/bin/clang++", - "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", - "nm": "/usr/lib/llvm-17/bin/llvm-nm", - "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", - "strip": "/usr/lib/llvm-17/bin/llvm-strip", + "gcc": "/usr/lib/llvm-18/bin/clang", + "ld": "/usr/lib/llvm-18/bin/ld.lld", + "ar": "/usr/lib/llvm-18/bin/llvm-ar", + "cpp": "/usr/lib/llvm-18/bin/clang++", + "llvm-cov": "/usr/lib/llvm-18/bin/llvm-cov", + "nm": "/usr/lib/llvm-18/bin/llvm-nm", + "objdump": "/usr/lib/llvm-18/bin/llvm-objdump", + "strip": "/usr/lib/llvm-18/bin/llvm-strip", }, toolchain_identifier = "linux_x86_toolchain", unfiltered_compile_flags = [ @@ -148,15 +148,15 @@ cc_toolchain_config( cpu = "aarch64", cxx_builtin_include_directories = [ "/dt10/", - "/usr/lib/llvm-17/include/", - "/usr/lib/llvm-17/lib/clang/17/include", + "/usr/lib/llvm-18/include/", + "/usr/lib/llvm-18/lib/clang/18/include", ], dbg_compile_flags = ["-g"], host_system_name = "linux", link_flags = [ "--target=aarch64-unknown-linux-gnu", "-fuse-ld=lld", - "--ld-path=/usr/lib/llvm-17/bin/ld.lld", + "--ld-path=/usr/lib/llvm-18/bin/ld.lld", "-Wl,--undefined-version", ], link_libs = [ @@ -176,14 +176,14 @@ cc_toolchain_config( target_libc = "", target_system_name = "aarch64-unknown-linux-gnu", tool_paths = { - "gcc": "/usr/lib/llvm-17/bin/clang", - "ld": "/usr/lib/llvm-17/bin/ld.lld", - "ar": "/usr/lib/llvm-17/bin/llvm-ar", - "cpp": "/usr/lib/llvm-17/bin/clang++", - "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", - "nm": "/usr/lib/llvm-17/bin/llvm-nm", - "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", - "strip": "/usr/lib/llvm-17/bin/llvm-strip", + "gcc": "/usr/lib/llvm-18/bin/clang", + "ld": "/usr/lib/llvm-18/bin/ld.lld", + "ar": "/usr/lib/llvm-18/bin/llvm-ar", + "cpp": "/usr/lib/llvm-18/bin/clang++", + "llvm-cov": "/usr/lib/llvm-18/bin/llvm-cov", + "nm": "/usr/lib/llvm-18/bin/llvm-nm", + "objdump": "/usr/lib/llvm-18/bin/llvm-objdump", + "strip": "/usr/lib/llvm-18/bin/llvm-strip", }, toolchain_identifier = "linux_aarch64_toolchain", unfiltered_compile_flags = [ @@ -238,8 +238,8 @@ cc_toolchain_config( cpu = "darwin", cxx_builtin_include_directories = [ "%sysroot%/usr/include", - "/usr/lib/llvm-17/include/", - "/usr/lib/llvm-17/lib/clang/17/include", + "/usr/lib/llvm-18/include/", + "/usr/lib/llvm-18/lib/clang/18/include", "%sysroot%/System/Library/Frameworks/Security.framework/Headers", 
"%sysroot%/System/Library/Frameworks/CoreFoundation.framework/Headers", "%sysroot%/System/Library/Frameworks/SystemConfiguration.framework/Headers", @@ -250,7 +250,7 @@ cc_toolchain_config( "--target=x86_64-apple-darwin", "-lSystem", "-fuse-ld=lld", - "--ld-path=/usr/lib/llvm-17/bin/ld64.lld", + "--ld-path=/usr/lib/llvm-18/bin/ld64.lld", "-headerpad_max_install_names", "-Wl,-undefined,dynamic_lookup", # Target Catalina as the minimum supported OS @@ -274,13 +274,13 @@ cc_toolchain_config( target_system_name = "x86_64-apple-macosx10.15", tool_paths = { "gcc": "cc_wrapper.sh", - "ld": "/usr/lib/llvm-17/bin/ld64.lld", - "ar": "/usr/lib/llvm-17/bin/llvm-libtool-darwin", - "cpp": "/usr/lib/llvm-17/bin/clang++", - "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", - "nm": "/usr/lib/llvm-17/bin/llvm-nm", - "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", - "strip": "/usr/lib/llvm-17/bin/llvm-strip", + "ld": "/usr/lib/llvm-18/bin/ld64.lld", + "ar": "/usr/lib/llvm-18/bin/llvm-libtool-darwin", + "cpp": "/usr/lib/llvm-18/bin/clang++", + "llvm-cov": "/usr/lib/llvm-18/bin/llvm-cov", + "nm": "/usr/lib/llvm-18/bin/llvm-nm", + "objdump": "/usr/lib/llvm-18/bin/llvm-objdump", + "strip": "/usr/lib/llvm-18/bin/llvm-strip", }, toolchain_identifier = "macos_x86_toolchain", unfiltered_compile_flags = [ diff --git a/third_party/xla/tools/toolchains/cross_compile/config/BUILD b/third_party/xla/tools/toolchains/cross_compile/config/BUILD index 386b8858fa8b3d..efc929c9b24a2d 100644 --- a/third_party/xla/tools/toolchains/cross_compile/config/BUILD +++ b/third_party/xla/tools/toolchains/cross_compile/config/BUILD @@ -9,7 +9,7 @@ platform( "@platforms//cpu:x86_64", ], exec_properties = { - "container-image": "docker://gcr.io/tensorflow-testing/ml-devinfra-linux-aarch64-cross-compile@sha256:11c5ac3b9b4e01cfa82b39b90826a9bfc5b806ccc92cd3d272e6bf861de43be1", + "container-image": "docker://gcr.io/tensorflow-testing/ml-devinfra-linux-aarch64-cross-compile@sha256:06040763c500bd2ebaaa4585d4729c88d2c8ccec94baa7fbe9bbe3dc2827d79d", "OSFamily": "Linux", }, ) diff --git a/third_party/xla/tools/toolchains/remote_config/configs.bzl b/third_party/xla/tools/toolchains/remote_config/configs.bzl index 756c7bbe8d786c..d8f137804ae52c 100644 --- a/third_party/xla/tools/toolchains/remote_config/configs.bzl +++ b/third_party/xla/tools/toolchains/remote_config/configs.bzl @@ -728,11 +728,11 @@ def initialize_rbe_configs(): sigbuild_tf_configs( name_container_map = { - "sigbuild-r2.17": "docker://gcr.io/tensorflow-sigs/build@sha256:dddcaf30321e9007103dce75c51b83fea3c06de462fcf41e7c6ae93f37fc3545", - "sigbuild-r2.17-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:8ca6b205b54f18d26a053cfe606145b8b11cc99cf83fc970a936ce327913c3c3", - "sigbuild-r2.17-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:5cfd081a337548165a800546f2365a38245e38e7a97052b1a21830bf66b2356d", - "sigbuild-r2.17-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:dddcaf30321e9007103dce75c51b83fea3c06de462fcf41e7c6ae93f37fc3545", - "sigbuild-r2.17-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:933c9f4bf65c92780863e00bd2132c6cfd41dbd624736c1af0dd2a5a056db6b8", + "sigbuild-r2.17": "docker://gcr.io/tensorflow-sigs/build@sha256:b6f572a897a69fa3311773f949b9aa9e81bc393e4fbe2c0d56d8afb03a6de080", + "sigbuild-r2.17-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:d0f27a4c7b97dbe9d530703dca3449afd464758e56b3ac4e1609c701223a0572", + "sigbuild-r2.17-python3.10": 
"docker://gcr.io/tensorflow-sigs/build@sha256:64e68a1d65ac265a2a59c8c2f6eb1f2148a323048a679a08e53239d467fa1478", + "sigbuild-r2.17-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:b6f572a897a69fa3311773f949b9aa9e81bc393e4fbe2c0d56d8afb03a6de080", + "sigbuild-r2.17-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:8b856ad736147bb9c8bc9e1ec2c8e1ab17d36397905da7a5b63dadeff9310f0c", }, # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 # and manylinux2014 is 2.17. @@ -767,29 +767,29 @@ def initialize_rbe_configs(): sigbuild_tf_configs( name_container_map = { - "sigbuild-r2.17-clang": "docker://gcr.io/tensorflow-sigs/build@sha256:dddcaf30321e9007103dce75c51b83fea3c06de462fcf41e7c6ae93f37fc3545", - "sigbuild-r2.17-clang-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:8ca6b205b54f18d26a053cfe606145b8b11cc99cf83fc970a936ce327913c3c3", - "sigbuild-r2.17-clang-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:5cfd081a337548165a800546f2365a38245e38e7a97052b1a21830bf66b2356d", - "sigbuild-r2.17-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:dddcaf30321e9007103dce75c51b83fea3c06de462fcf41e7c6ae93f37fc3545", - "sigbuild-r2.17-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:933c9f4bf65c92780863e00bd2132c6cfd41dbd624736c1af0dd2a5a056db6b8", + "sigbuild-r2.17-clang": "docker://gcr.io/tensorflow-sigs/build@sha256:b6f572a897a69fa3311773f949b9aa9e81bc393e4fbe2c0d56d8afb03a6de080", + "sigbuild-r2.17-clang-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:d0f27a4c7b97dbe9d530703dca3449afd464758e56b3ac4e1609c701223a0572", + "sigbuild-r2.17-clang-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:64e68a1d65ac265a2a59c8c2f6eb1f2148a323048a679a08e53239d467fa1478", + "sigbuild-r2.17-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:b6f572a897a69fa3311773f949b9aa9e81bc393e4fbe2c0d56d8afb03a6de080", + "sigbuild-r2.17-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:8b856ad736147bb9c8bc9e1ec2c8e1ab17d36397905da7a5b63dadeff9310f0c", }, # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 # and manylinux2014 is 2.17. 
env = { "ABI_LIBC_VERSION": "glibc_2.19", "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-17/bin/clang", + "BAZEL_COMPILER": "/usr/lib/llvm-18/bin/clang", "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", "BAZEL_TARGET_CPU": "k8", "BAZEL_TARGET_LIBC": "glibc_2.19", "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-17/bin/clang", + "CC": "/usr/lib/llvm-18/bin/clang", "CC_TOOLCHAIN_NAME": "linux_gnu_x86", "CLEAR_CACHE": "1", "CUDNN_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", - "CLANG_CUDA_COMPILER_PATH": "/usr/lib/llvm-17/bin/clang", - "HOST_CXX_COMPILER": "/usr/lib/llvm-17/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-17/bin/clang", + "CLANG_CUDA_COMPILER_PATH": "/usr/lib/llvm-18/bin/clang", + "HOST_CXX_COMPILER": "/usr/lib/llvm-18/bin/clang", + "HOST_C_COMPILER": "/usr/lib/llvm-18/bin/clang", "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index eaca59116decf2..fbf38aea099f98 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -128,25 +128,27 @@ double ComputeMemoryReshardingCost(const Shape& shape, shape, device_mesh.num_elements(), dst_sharding)); if (src_n_dim != dst_n_dim && src_n_dim != -1 && dst_n_dim != -1) { - Shape inter_shape = ComputeIntermediateShape(src_sharding, dst_sharding, - shape, device_mesh); - - std::optional src_inter_sharding = - hlo_sharding_util::ReshapeSharding(shape, inter_shape, src_sharding); - std::optional dst_inter_sharding = - hlo_sharding_util::ReshapeSharding(shape, inter_shape, dst_sharding); - if (!src_inter_sharding.has_value() || !dst_inter_sharding.has_value()) { - src_inter_sharding = HloSharding::Replicate(); - dst_inter_sharding = HloSharding::Replicate(); - } - - result = std::max( - result, - static_cast(std::max( - GetShardedInstructionSize(inter_shape, device_mesh.num_elements(), - src_inter_sharding), - GetShardedInstructionSize(inter_shape, device_mesh.num_elements(), - dst_inter_sharding)))); + absl::StatusOr inter_shape = ComputeIntermediateShape( + src_sharding, dst_sharding, shape, device_mesh); + if (inter_shape.ok()) { + std::optional src_inter_sharding = + hlo_sharding_util::ReshapeSharding(shape, *inter_shape, src_sharding); + std::optional dst_inter_sharding = + hlo_sharding_util::ReshapeSharding(shape, *inter_shape, dst_sharding); + if (!src_inter_sharding.has_value() || !dst_inter_sharding.has_value()) { + src_inter_sharding = HloSharding::Replicate(); + dst_inter_sharding = HloSharding::Replicate(); + } + + result = std::max( + result, + static_cast(std::max( + GetShardedInstructionSize( + *inter_shape, device_mesh.num_elements(), src_inter_sharding), + GetShardedInstructionSize(*inter_shape, + device_mesh.num_elements(), + dst_inter_sharding)))); + } } return result - src_sharded_bytes; } @@ -3656,7 +3658,6 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( } } VLOG(10) << hlo_live_range->ToString(); - VLOG(10) << spmd::PrintLivenessSet(liveness_set); XLA_VLOG_LINES(10, spmd::PrintLivenessSet(liveness_set)); const HloInstructionSequence& sequence = hlo_live_range->flattened_instruction_sequence(); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 
ef6f99f7e4e727..d07dcd0208d1b3 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -1272,10 +1272,9 @@ std::vector GetTensorDimToMeshDim( } } -Shape ComputeIntermediateShape(const HloSharding& src_sharding, - const HloSharding& dst_sharding, - const Shape& shape, - const Array& device_mesh) { +absl::StatusOr ComputeIntermediateShape( + const HloSharding& src_sharding, const HloSharding& dst_sharding, + const Shape& shape, const Array& device_mesh) { int64_t src_n_dim = NumTileDimensions(src_sharding); const HloSharding* sharding_1d; @@ -1293,8 +1292,10 @@ Shape ComputeIntermediateShape(const HloSharding& src_sharding, if (sharding_1d->tile_assignment().dim(i) == 1) { inter_shape_dims.push_back(shape.dimensions(i)); } else { - CHECK(shape.dimensions(i) % device_mesh.dim(0) == 0) - << "Only support even partition"; + // TODO(b/333750146): Support this case instead of bailing here + if (shape.dimensions(i) % device_mesh.dim(0) != 0) { + return absl::InternalError("Indivisible tensor dims"); + } inter_shape_dims.push_back(device_mesh.dim(0)); inter_shape_dims.push_back(shape.dimensions(i) / device_mesh.dim(0)); } @@ -1318,29 +1319,33 @@ HloInstruction* ReshardTensor(HloInstruction* tensor, HloInstruction* replace_with = nullptr; if (src_n_dim != dst_n_dim && src_n_dim != -1 && dst_n_dim != -1) { - Shape inter_shape = ComputeIntermediateShape(src_sharding, dst_sharding, - shape, device_mesh); - - std::optional src_inter_sharding = - hlo_sharding_util::ReshapeSharding(shape, inter_shape, src_sharding); - std::optional dst_inter_sharding = - hlo_sharding_util::ReshapeSharding(shape, inter_shape, dst_sharding); - if (!src_inter_sharding.has_value() || !dst_inter_sharding.has_value()) { - src_inter_sharding = HloSharding::Replicate(); - dst_inter_sharding = HloSharding::Replicate(); - LOG(WARNING) << "Invalid mixed mesh shape resharding."; - } + absl::StatusOr inter_shape = ComputeIntermediateShape( + src_sharding, dst_sharding, shape, device_mesh); + if (inter_shape.ok()) { + std::optional src_inter_sharding = + hlo_sharding_util::ReshapeSharding(shape, *inter_shape, src_sharding); + std::optional dst_inter_sharding = + hlo_sharding_util::ReshapeSharding(shape, *inter_shape, dst_sharding); + if (!src_inter_sharding.has_value() || !dst_inter_sharding.has_value()) { + src_inter_sharding = HloSharding::Replicate(); + dst_inter_sharding = HloSharding::Replicate(); + LOG(WARNING) << "Invalid mixed mesh shape resharding."; + } - HloInstruction* src_inter = computation->AddInstruction( - HloInstruction::CreateReshape(inter_shape, tensor)); - src_inter->set_sharding(*src_inter_sharding); + HloInstruction* src_inter = computation->AddInstruction( + HloInstruction::CreateReshape(*inter_shape, tensor)); + src_inter->set_sharding(*src_inter_sharding); - HloInstruction* dst_inter = computation->AddInstruction( - HloInstruction::CreateReshape(inter_shape, src_inter)); - dst_inter->set_sharding(*dst_inter_sharding); + HloInstruction* dst_inter = computation->AddInstruction( + HloInstruction::CreateReshape(*inter_shape, src_inter)); + dst_inter->set_sharding(*dst_inter_sharding); - replace_with = computation->AddInstruction( - HloInstruction::CreateReshape(shape, dst_inter)); + replace_with = computation->AddInstruction( + HloInstruction::CreateReshape(shape, dst_inter)); + } else { + replace_with = computation->AddInstruction( + HloInstruction::CreateReshape(shape, tensor)); + } } else { 
replace_with = computation->AddInstruction( HloInstruction::CreateReshape(shape, tensor)); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index 5e6548f9b20599..1ca5c644d58041 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #ifndef XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_UTIL_H_ #define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_UTIL_H_ @@ -455,10 +454,9 @@ int64_t NumTileDimensions(const HloSharding& spec); // When fixing mixed mesh resharding (see below), compute the correct // intermediate shape in order to insert copies. -Shape ComputeIntermediateShape(const HloSharding& src_sharding, - const HloSharding& dst_sharding, - const Shape& shape, - const Array& device_mesh); +absl::StatusOr ComputeIntermediateShape( + const HloSharding& src_sharding, const HloSharding& dst_sharding, + const Shape& shape, const Array& device_mesh); // Forcibly set the sharding of the operand of inst. // Also fix the resharding between 1d and 2d logical mesh. diff --git a/third_party/xla/xla/mlir/memref/BUILD b/third_party/xla/xla/mlir/memref/BUILD index aa2c29d34e2df1..7f628aefb16c13 100644 --- a/third_party/xla/xla/mlir/memref/BUILD +++ b/third_party/xla/xla/mlir/memref/BUILD @@ -4,7 +4,6 @@ package_group( "//xla/mlir/...", # copybara:uncomment_begin(google-only) # # TODO(ezhulenev): Clean up dependencies that are leforvers from Autofusion project. - # "@tf_runtime//...", # "//third_party/py/enzyme_ad/...", # copybara:uncomment_end(google-only) ], diff --git a/third_party/xla/xla/mlir/runtime/BUILD b/third_party/xla/xla/mlir/runtime/BUILD index 54432dcf8c60b5..45147f52981c2b 100644 --- a/third_party/xla/xla/mlir/runtime/BUILD +++ b/third_party/xla/xla/mlir/runtime/BUILD @@ -6,7 +6,6 @@ package_group( # "//third_party/mlir_edge/tpgen/...", # # TODO(ezhulenev): Clean up dependencies that are leftovers from Autofusion project. # "@tf_runtime//...", - # "//third_party/tf_runtime_google/...", # copybara:uncomment_end(google-only) "//tensorflow/compiler/mlir/tfrt/...", "//xla/mlir/...", diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_profiler_extension.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_profiler_extension.h index c821916add71af..77603203e4a64c 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_profiler_extension.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_profiler_extension.h @@ -26,7 +26,7 @@ limitations under the License. 
extern "C" { #endif -#define PJRT_API_PROFILER_EXTENSION_VERSION 0 +#define PJRT_API_PROFILER_EXTENSION_VERSION 1 typedef struct PJRT_Profiler_Extension { size_t struct_size; @@ -37,7 +37,7 @@ typedef struct PJRT_Profiler_Extension { // valid only when used as an args extension int64_t traceme_context_id; } PJRT_Profiler_Extension; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_Profiler_Extension, profiler_api); +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Profiler_Extension, traceme_context_id); #ifdef __cplusplus } diff --git a/third_party/xla/xla/python/ifrt/BUILD b/third_party/xla/xla/python/ifrt/BUILD index ef1de960a6cbcf..5e974db0483aa6 100644 --- a/third_party/xla/xla/python/ifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/BUILD @@ -91,6 +91,7 @@ cc_library( "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", "//xla/pjrt:pjrt_executable", + "//xla/pjrt:pjrt_future", "//xla/pjrt:pjrt_layout", "//xla/python/ifrt/ir", "@com_google_absl//absl/algorithm:container", @@ -134,6 +135,7 @@ xla_cc_test( srcs = ["future_test.cc"], deps = [ ":ifrt", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/lib/core:status_test_util", diff --git a/third_party/xla/xla/python/ifrt/array.h b/third_party/xla/xla/python/ifrt/array.h index b63d9f4c90096f..9047945e02e447 100644 --- a/third_party/xla/xla/python/ifrt/array.h +++ b/third_party/xla/xla/python/ifrt/array.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_PYTHON_IFRT_ARRAY_H_ #define XLA_PYTHON_IFRT_ARRAY_H_ +#include #include #include #include @@ -116,7 +117,7 @@ class Array : public llvm::RTTIExtends { // an API that lets users query the alignment requirement of the specific // implementation. ABSL_MUST_USE_RESULT - virtual Future CopyToHostBuffer( + virtual Future<> CopyToHostBuffer( void* data, std::optional> byte_strides, ArrayCopySemantics semantics) = 0; diff --git a/third_party/xla/xla/python/ifrt/device.proto b/third_party/xla/xla/python/ifrt/device.proto index e5ce1f8e301a1e..53e384522d56ec 100644 --- a/third_party/xla/xla/python/ifrt/device.proto +++ b/third_party/xla/xla/python/ifrt/device.proto @@ -17,7 +17,7 @@ syntax = "proto3"; package xla.ifrt; -// Wire format for `DeviceList`. +// Proto equivalent of C++ `DeviceList`. message DeviceListProto { // Serialization and deserialization are expected to ensure that device ids // are stable across proto construction and consumption. diff --git a/third_party/xla/xla/python/ifrt/dtype.h b/third_party/xla/xla/python/ifrt/dtype.h index 517d35207705d1..ef461e6053de31 100644 --- a/third_party/xla/xla/python/ifrt/dtype.h +++ b/third_party/xla/xla/python/ifrt/dtype.h @@ -36,6 +36,7 @@ namespace ifrt { // * Add kString. class DType { public: + // LINT.IfChange enum Kind { // Invalid data type. kInvalid = 0, @@ -89,6 +90,7 @@ class DType { // collision. kString = 99, }; + // LINT.ThenChange(dtype.proto:DTypeProtoKind) explicit DType(Kind kind) : kind_(kind) {} DType(const DType&) = default; diff --git a/third_party/xla/xla/python/ifrt/dtype.proto b/third_party/xla/xla/python/ifrt/dtype.proto index b23ffa75c2a8bb..640a3ae5a6efad 100644 --- a/third_party/xla/xla/python/ifrt/dtype.proto +++ b/third_party/xla/xla/python/ifrt/dtype.proto @@ -17,8 +17,9 @@ syntax = "proto3"; package xla.ifrt; -// Data type kinds. Mirrors `xla::ifrt::DType`. +// Proto equivalent of C++ `DType`. message DTypeProto { + // LINT.IfChange(DTypeProtoKind) enum Kind { KIND_UNSPECIFIED = 0; @@ -69,5 +70,6 @@ message DTypeProto { // collision. 
KIND_STRING = 99; } + // LINT.ThenChange() Kind kind = 1; } diff --git a/third_party/xla/xla/python/ifrt/executable.h b/third_party/xla/xla/python/ifrt/executable.h index 0d7e0b8d43d356..0000e6ac9df9bd 100644 --- a/third_party/xla/xla/python/ifrt/executable.h +++ b/third_party/xla/xla/python/ifrt/executable.h @@ -176,7 +176,7 @@ class LoadedExecutable // Result from an execution. struct ExecuteResult { // Resulting status of the execution. - Future status; + Future<> status; // Output arrays. std::vector> outputs; }; diff --git a/third_party/xla/xla/python/ifrt/future.cc b/third_party/xla/xla/python/ifrt/future.cc index c434daa1a42a9c..3533b85af3e99e 100644 --- a/third_party/xla/xla/python/ifrt/future.cc +++ b/third_party/xla/xla/python/ifrt/future.cc @@ -21,6 +21,8 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/pjrt/pjrt_future.h" #include "xla/status.h" namespace xla { @@ -61,41 +63,8 @@ Future JoinFutures(absl::Span> futures) { return future; } -// TODO(b/333538339): Use Future<> with implicit Status in IFRT APIs. For now -// this is a workaround to convert between different error semantics.S -Future JoinFutures(absl::Span> futures) { - if (futures.empty()) { - return Future(OkStatus()); - } else if (futures.size() == 1) { - return futures.front().ToStatusFuture(); - } - // State shared by `PjRtFuture` onready callbacks. - struct CombinedStatus { - explicit CombinedStatus(int initial_count) - : count(initial_count), promise(Future::CreatePromise()) {} - std::atomic count; - absl::Mutex mu; - Status status ABSL_GUARDED_BY(&mu); - Promise promise; - }; - auto combined_status = std::make_shared(futures.size()); - Future future(combined_status->promise); - for (auto& fut : futures) { - fut.OnReady([combined_status](Status s) { - if (!s.ok()) { - absl::MutexLock lock(&combined_status->mu); - combined_status->status.Update(std::move(s)); - } - const int pre_dec_count = - combined_status->count.fetch_add(-1, std::memory_order_acq_rel); - CHECK_GE(pre_dec_count, 1); - if (pre_dec_count == 1) { - absl::MutexLock lock(&combined_status->mu); - combined_status->promise.Set(std::move(combined_status->status)); - } - }); - } - return future; +Future<> JoinFutures(absl::Span> futures) { + return ::xla::JoinFutures(futures); } } // namespace ifrt diff --git a/third_party/xla/xla/python/ifrt/future.h b/third_party/xla/xla/python/ifrt/future.h index 7b3d96b975ab59..0c87042f971286 100644 --- a/third_party/xla/xla/python/ifrt/future.h +++ b/third_party/xla/xla/python/ifrt/future.h @@ -16,7 +16,8 @@ limitations under the License. #ifndef XLA_PYTHON_IFRT_FUTURE_H_ #define XLA_PYTHON_IFRT_FUTURE_H_ -#include "xla/pjrt/pjrt_client.h" +#include "absl/types/span.h" +#include "xla/pjrt/pjrt_future.h" #include "xla/status.h" namespace xla { @@ -36,20 +37,17 @@ namespace ifrt { // (1) no reference counting of `Future`s sharing the same `Promise` and (2) // safe mutable access to the value when the `Future` becomes ready, including // moving the value out of the `Future`/`Promise`. -template +template using Future = ::xla::PjRtFuture; -template +template using Promise = typename ::xla::PjRtFuture::Promise; -// TODO(b/333538339): Add JoinFutures API to PjRtFuture. - // Returns a `Future` that aggregates the return status of all `Future`s. Future JoinFutures(absl::Span> futures); -// TODO(b/333538339): Use Future<> with implicit Status in IFRT APIs. 
For now -// this is a workaround to convert between different error semantics.S -Future JoinFutures(absl::Span> futures); +// Returns a `Future` that aggregates the return status of all `Future`s. +Future<> JoinFutures(absl::Span> futures); } // namespace ifrt } // namespace xla diff --git a/third_party/xla/xla/python/ifrt/future_test.cc b/third_party/xla/xla/python/ifrt/future_test.cc index 4eb4c737a03fb9..5a967318af84dd 100644 --- a/third_party/xla/xla/python/ifrt/future_test.cc +++ b/third_party/xla/xla/python/ifrt/future_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/types/span.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" @@ -55,7 +56,7 @@ TEST(FutureTest, JoinOneFailingFuture) { Future future = JoinFutures(absl::MakeSpan(futures)); ASSERT_FALSE(future.IsReady()); - promise.Set(InvalidArgument("Some error")); + promise.Set(absl::InvalidArgumentError("Some error")); EXPECT_THAT(future.Await(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Some error"))); } @@ -95,7 +96,7 @@ TEST(FutureTest, JoinAllFailingFutures) { ASSERT_FALSE(future.IsReady()); for (Promise& promise : promises) { - promise.Set(InvalidArgument("Some error")); + promise.Set(absl::InvalidArgumentError("Some error")); } EXPECT_THAT(future.Await(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Some error"))); @@ -120,7 +121,7 @@ TEST_P(JoinAllOkFuturesExceptForOneTest, JoinAllOkFuturesExceptForOne) { ASSERT_FALSE(future.IsReady()); for (int i = 0; i < kNumFutures; ++i) { if (i == failing_future_idx) { - promises[i].Set(InvalidArgument("Some error")); + promises[i].Set(absl::InvalidArgumentError("Some error")); } else { promises[i].Set(OkStatus()); } diff --git a/third_party/xla/xla/python/ifrt/mock.h b/third_party/xla/xla/python/ifrt/mock.h index 71a3352a7a5498..a8dbc2337b3a65 100644 --- a/third_party/xla/xla/python/ifrt/mock.h +++ b/third_party/xla/xla/python/ifrt/mock.h @@ -55,8 +55,8 @@ class MockArray : public llvm::RTTIExtends { // LINT.IfChange MOCK_METHOD(Client*, client, (), (const, final)); - MOCK_METHOD(Future, GetReadyFuture, (), (const, final)); - MOCK_METHOD(Future, Delete, (), (final)); + MOCK_METHOD(Future<>, GetReadyFuture, (), (const, final)); + MOCK_METHOD(Future<>, Delete, (), (final)); MOCK_METHOD(bool, IsDeleted, (), (const, final)); MOCK_METHOD(std::string, DebugString, (), (const, final)); @@ -72,7 +72,7 @@ class MockArray : public llvm::RTTIExtends { (final)); MOCK_METHOD(absl::StatusOr>, FullyReplicatedShard, (ArrayCopySemantics semantics), (final)); - MOCK_METHOD(Future, CopyToHostBuffer, + MOCK_METHOD(Future<>, CopyToHostBuffer, (void* data, std::optional> byte_strides, ArrayCopySemantics semantics), diff --git a/third_party/xla/xla/python/ifrt/shape.proto b/third_party/xla/xla/python/ifrt/shape.proto index 354383dbd43a2b..cd5d26e13826f9 100644 --- a/third_party/xla/xla/python/ifrt/shape.proto +++ b/third_party/xla/xla/python/ifrt/shape.proto @@ -17,18 +17,19 @@ syntax = "proto3"; package xla.ifrt; -// Wire format for `Shape`. Currently support static shapes with all dimension -// sizes greater than or equal to 0. +// Proto equivalent of C++ `Shape`. Currently support static shapes with all +// dimension sizes greater than or equal to 0. message ShapeProto { repeated int64 dims = 1; } -// Wire format for `BoundedDynamicShapeTag`. +// Proto equivalent of C++ `BoundedDynamicShapeTag`. 
message BoundedDynamicShapeTagProto { repeated bool is_dynamic_dims = 1; } -// Wire format for `DynamicShape`. Currently only support bounded dynamic shape. +// Proto equivalent of C++ `DynamicShape`. Currently only support bounded +// dynamic shape. message DynamicShapeProto { ShapeProto shape = 1; oneof tag { diff --git a/third_party/xla/xla/python/ifrt/sharding.proto b/third_party/xla/xla/python/ifrt/sharding.proto index 5a6c1484922d09..a6033bb348e021 100644 --- a/third_party/xla/xla/python/ifrt/sharding.proto +++ b/third_party/xla/xla/python/ifrt/sharding.proto @@ -19,7 +19,7 @@ package xla.ifrt; import "xla/python/ifrt/serdes.proto"; -// Wire format for `Sharding`. A suitable serializer and deserializer +// Proto equivalent of C++ `Sharding`. A suitable serializer and deserializer // implementation must be registered. message ShardingProto { xla.ifrt.Serialized serialized_sharding = 1; diff --git a/third_party/xla/xla/python/ifrt/sharding_serdes.proto b/third_party/xla/xla/python/ifrt/sharding_serdes.proto index dcd8fc8fe6b65b..c95a766c592b6a 100644 --- a/third_party/xla/xla/python/ifrt/sharding_serdes.proto +++ b/third_party/xla/xla/python/ifrt/sharding_serdes.proto @@ -20,7 +20,7 @@ package xla.ifrt; import "xla/python/ifrt/device.proto"; import "xla/python/ifrt/shape.proto"; -// Wire format for `SingleDeviceSharding`. +// Proto equivalent of C++ `SingleDeviceSharding`. message SingleDeviceShardingProto { // Serialization and deserialization are expected to ensure that device ids // are stable across proto construction and consumption. @@ -28,13 +28,13 @@ message SingleDeviceShardingProto { optional string memory_kind = 2; } -// Wire format for `OpaqueSharding`. +// Proto equivalent of C++ `OpaqueSharding`. message OpaqueShardingProto { DeviceListProto devices = 1; optional string memory_kind = 2; } -// Wire format for `ConcreteSharding`. +// Proto equivalent of C++ `ConcreteSharding`. message ConcreteShardingProto { DeviceListProto devices = 1; optional string memory_kind = 4; @@ -46,7 +46,7 @@ message ConcreteShardingProto { repeated DynamicShapeProto shard_dynamic_shapes = 6; } -// Wire format for `ConcreteEvenSharding`. +// Proto equivalent of C++ `ConcreteEvenSharding`. message ConcreteEvenShardingProto { DeviceListProto devices = 1; optional string memory_kind = 4; diff --git a/third_party/xla/xla/python/ifrt/value.h b/third_party/xla/xla/python/ifrt/value.h index 11bcf519ce6dc2..5293e4da84ceaa 100644 --- a/third_party/xla/xla/python/ifrt/value.h +++ b/third_party/xla/xla/python/ifrt/value.h @@ -44,7 +44,7 @@ class Value : public tsl::ReferenceCounted, // Returns a future that becomes ready when the buffer is computed or has an // error. - virtual Future GetReadyFuture() const = 0; + virtual Future<> GetReadyFuture() const = 0; // Deletes the value from the devices. The operation may be asynchronous. The // returned future will have the result of the deletion on the devices, and @@ -52,7 +52,7 @@ class Value : public tsl::ReferenceCounted, // Implementations that do not track the completion of the deletion operation // may make the future immediately ready with an OK status. // TODO(phawkins): decide if we want Delete() to be idempotent. - virtual Future Delete() = 0; + virtual Future<> Delete() = 0; // Returns whether the value has been enqueued for deletion from the devices. 
virtual bool IsDeleted() const = 0; diff --git a/third_party/xla/xla/python/ifrt_proxy/client/array.cc b/third_party/xla/xla/python/ifrt_proxy/client/array.cc index 88b51add3ecf30..34edd2cbc7e347 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/array.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/array.cc @@ -112,33 +112,40 @@ void Array::Destruct(RpcHelper* rpc_helper, ArrayHandle handle) { }); } -Future Array::GetReadyFuture() const { +Future<> Array::GetReadyFuture() const { auto req = std::make_unique(); req->set_array_handle(handle_.handle); - auto promise = Future::CreatePromise(); + auto promise = Future<>::CreatePromise(); rpc_helper_->CheckArrayReady(std::move(req)) .OnReady( [promise](absl::StatusOr> - resp) mutable -> void { promise.Set(resp.status()); }); - return Future(std::move(promise)); + resp) mutable -> void { + if (resp.status().ok()) { + promise.Set(); + } else { + promise.SetError(resp.status()); + } + }); + return Future<>(std::move(promise)); } -Future Array::Delete() { +Future<> Array::Delete() { auto req = std::make_unique(); req->set_array_handle(handle_.handle); absl::StatusOr> response = rpc_helper_->DeleteArray(std::move(req)).Await(); if (!response.ok()) { - return Future(response.status()); + return Future<>(response.status()); } // TODO(b/266635130): So that the caller is not blocked until the server // replies with the deletion's response, from within // `Future(status_handle_promise).OnReady()`, schedule `CheckFuture()` on a // separate thread. - return rpc_helper_->CheckFuture((*response)->deletion_future_handle()); + return Future<>::FromStatusFuture( + rpc_helper_->CheckFuture((*response)->deletion_future_handle())); } bool Array::IsDeleted() const { @@ -263,13 +270,13 @@ absl::StatusOr> Array::Reshard( client_, rpc_helper_, dtype_, shape_, std::move(new_sharding), handle)); } -Future Array::CopyToHostBuffer( +Future<> Array::CopyToHostBuffer( void* data, std::optional> byte_strides, ArrayCopySemantics semantics) { const auto mem_region = ArrayMemRegion::FromZerothElementPointer( /*zeroth_element=*/data, dtype_, shape_, byte_strides); if (!mem_region.ok()) { - return Future(mem_region.status()); + return Future<>(mem_region.status()); } auto req = std::make_unique(); @@ -281,14 +288,14 @@ Future Array::CopyToHostBuffer( rpc_helper_->host_buffer_store()->NextHandle(); req->set_host_buffer_handle(host_buffer_handle); - auto promise = Future::CreatePromise(); + auto promise = Future<>::CreatePromise(); auto on_ready = [host_buffer_store = rpc_helper_->host_buffer_store(), promise, host_buffer_handle, mem_region = mem_region->mem_region()]( absl::StatusOr> resp) mutable { if (!resp.ok()) { - promise.Set(resp.status()); + promise.SetError(resp.status()); return; } @@ -307,7 +314,7 @@ Future Array::CopyToHostBuffer( }; if (!data.ok()) { - promise.Set(data.status()); + promise.SetError(data.status()); return; } if (data->size() != mem_region.size()) { @@ -316,7 +323,7 @@ Future Array::CopyToHostBuffer( "response from proxy: ", mem_region.size(), " vs ", data->size())); LOG(ERROR) << status; - promise.Set(status); + promise.SetError(status); return; } #if defined(PLATFORM_GOOGLE) @@ -325,11 +332,11 @@ Future Array::CopyToHostBuffer( std::memcpy(const_cast(mem_region.data()), data->Flatten().data(), data->size()); #endif - promise.Set(absl::OkStatus()); + promise.Set(); }); }; rpc_helper_->CopyToHostBuffer(std::move(req)).OnReady(std::move(on_ready)); - return Future(std::move(promise)); + return Future<>(std::move(promise)); } 
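
A minimal sketch of the Future<>/Promise pattern used in these hunks, where xla::ifrt::Future<> aliases xla::PjRtFuture<>: Set() marks stateless success, SetError() carries a failure status, and the helper name below is an illustrative assumption rather than part of the patch:

#include <utility>
#include "absl/status/status.h"
#include "xla/python/ifrt/future.h"

// Hypothetical helper: returns a Future<> that is already resolved, either
// successfully or with an error, mirroring the promise handling above.
xla::ifrt::Future<> MakeResolvedFuture(bool ok) {
  auto promise = xla::ifrt::Future<>::CreatePromise();
  if (ok) {
    promise.Set();  // stateless success
  } else {
    promise.SetError(absl::InternalError("illustrative error"));
  }
  return xla::ifrt::Future<>(std::move(promise));
}

Several such futures can be aggregated with xla::ifrt::JoinFutures(absl::MakeSpan(futures)), which after this change also accepts a span of Future<>.
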
xla::ifrt::Client* Array::client() const { return client_; } diff --git a/third_party/xla/xla/python/ifrt_proxy/client/array.h b/third_party/xla/xla/python/ifrt_proxy/client/array.h index 3b5e8d9d5e1149..9e96d14900a047 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/array.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/array.h @@ -91,8 +91,8 @@ class Array final : public llvm::RTTIExtends { ArrayHandle handle() const { return handle_; } xla::ifrt::Client* client() const override; - Future GetReadyFuture() const override; - Future Delete() override; + Future<> GetReadyFuture() const override; + Future<> Delete() override; bool IsDeleted() const override; std::string DebugString() const override; @@ -114,7 +114,7 @@ class Array final : public llvm::RTTIExtends { xla::ifrt::ArrayCopySemantics semantics) override; ABSL_MUST_USE_RESULT - Future CopyToHostBuffer( + Future<> CopyToHostBuffer( void* data, std::optional> byte_strides, ArrayCopySemantics semantics) override; diff --git a/third_party/xla/xla/python/ifrt_proxy/client/executable.cc b/third_party/xla/xla/python/ifrt_proxy/client/executable.cc index fbcdde60365804..96b46569819e35 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/executable.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/executable.cc @@ -437,7 +437,8 @@ LoadedExecutable::Execute(absl::Span> args, // Populate the execution status future. `CheckFuture` deletes the server-side // futures after its completion. - result.status = rpc_helper_->CheckFuture(response->status_handle()); + result.status = Future<>::FromStatusFuture( + rpc_helper_->CheckFuture(response->status_handle())); // Create output arrays. The cleanup logic ensures that all handles are // properly cleaned up on early return. diff --git a/third_party/xla/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc b/third_party/xla/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc index 3d4cd8fd800d51..9a0121fdcb2826 100644 --- a/third_party/xla/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc @@ -179,7 +179,7 @@ TEST_F(MockArrayTest, ReadyFuturePropagatesError) { TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); EXPECT_CALL(*arr.backend_array, GetReadyFuture).WillOnce([&] { - return Future(absl::InternalError("testing")); + return Future<>(absl::InternalError("testing")); }); EXPECT_THAT(arr.proxy_client_array->GetReadyFuture().Await(), @@ -198,12 +198,12 @@ TEST_F(MockArrayTest, DeletionFutureWaitsUntilDeleted) { // returns being blocked on `wait_ready`. That version of the testcase does // not currently work since both the client and the server synchronously // block until the MockArray's Delete() returns. 
- auto promise = Future::CreatePromise(); + auto promise = Future<>::CreatePromise(); threads.Schedule([&, promise]() mutable { wait_ready.WaitForNotification(); promise.Set(arr.backend_array->delegated()->Delete().Await()); }); - return Future(promise); + return Future<>(promise); }); EXPECT_FALSE(arr.proxy_client_array->IsDeleted()); @@ -222,7 +222,7 @@ TEST_F(MockArrayTest, DeletionPropagatesError) { TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); EXPECT_CALL(*arr.backend_array, Delete).WillOnce([&] { - return Future(absl::InternalError("testing")); + return Future<>(absl::InternalError("testing")); }); EXPECT_FALSE(arr.proxy_client_array->IsDeleted()); @@ -258,7 +258,7 @@ TEST_F(MockArrayTest, CopyToHostFuturePropagatesError) { absl::Notification wait_ready; EXPECT_CALL(*arr.backend_array, CopyToHostBuffer).WillOnce([&] { - return Future(absl::InternalError("testing")); + return Future<>(absl::InternalError("testing")); }); char data[1000]; diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc index b57b922ce18eac..5edb8f542301a9 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc @@ -473,7 +473,7 @@ Future IfrtBackend::HandleCopyToHostBufferRequest( } // TODO(b/282757875): Consider other ArrayCopySemantics. - Future copy_status = + Future<> copy_status = (*array)->CopyToHostBuffer(mem_region->zeroth_element(), byte_strides, ArrayCopySemantics::kAlwaysCopy); @@ -631,7 +631,8 @@ BackendInterface::Response IfrtBackend::HandleDeleteArrayRequest( uint64_t future_handle = handle_generator_.New(); { absl::MutexLock lock(&futures_mutex_); - futures_.insert({future_handle, std::move(deletion_future)}); + futures_.insert( + {future_handle, std::move(deletion_future).ToStatusFuture()}); } auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); @@ -915,8 +916,8 @@ BackendInterface::Response IfrtBackend::HandleLoadedExecutableExecuteRequest( { absl::MutexLock lock(&futures_mutex_); execute_response->set_status_handle(handle_generator_.New()); - futures_.insert( - {execute_response->status_handle(), std::move(result.status)}); + futures_.insert({execute_response->status_handle(), + std::move(result.status).ToStatusFuture()}); } // Register output arrays. 
At this point, we should never early return because diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc index 08838e05e7e349..2ee6a22ce34630 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -608,7 +608,7 @@ TEST_F(IfrtBackendHandlerTest, CopyToHostSuccess) { const std::optional> expected_byte_strides = absl::Span(expected_byte_strides_vec); EXPECT_CALL(*array, CopyToHostBuffer(_, expected_byte_strides, _)) - .WillOnce(Return(Future(absl::OkStatus()))); + .WillOnce(Return(Future<>(absl::OkStatus()))); TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(ifrt_request))); // Given the above shape, dtype, and compact byte_strides, the size of the @@ -768,9 +768,8 @@ TEST_F(IfrtBackendHandlerTest, CheckArrayReadyRequestRelaysTheResultFromBackend) { auto mock_array = tsl::MakeRef(); EXPECT_CALL(*mock_array, GetReadyFuture()) - .WillOnce(Return(Future(absl::OkStatus()))) - .WillOnce( - Return(Future(absl::UnknownError("injected error")))); + .WillOnce(Return(Future<>(absl::OkStatus()))) + .WillOnce(Return(Future<>(absl::UnknownError("injected error")))); TF_ASSERT_OK_AND_ASSIGN(auto array_handle, MakeTestArray(std::move(mock_array))); @@ -807,7 +806,7 @@ TEST_F(IfrtBackendHandlerTest, DeleteArraySuccess) { tsl::RCReference mock_array = tsl::MakeRef(); EXPECT_CALL(*mock_array, Delete()) - .WillOnce(Return(Future(absl::OkStatus()))); + .WillOnce(Return(Future<>(absl::OkStatus()))); TF_ASSERT_OK_AND_ASSIGN(auto array_handle, MakeTestArray(std::move(mock_array))); @@ -1086,8 +1085,7 @@ TEST_F(IfrtBackendHandlerTest, LoadedExecutableExecute) { std::optional devices) -> absl::StatusOr { return LoadedExecutable::ExecuteResult{ - .status = - Future(absl::InternalError("injected error")), + .status = Future<>(absl::InternalError("injected error")), .outputs = outputs, }; })); diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc index 68bc459684762b..2d93e4c9d65412 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc @@ -345,19 +345,19 @@ PjRtArray::DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) { return result; } -Future PjRtArray::CopyToHostBuffer( +Future<> PjRtArray::CopyToHostBuffer( void* data, std::optional> byte_strides, ArrayCopySemantics semantics) { DCHECK(this); if (sharding_->devices().size() != 1) { - return Future( + return Future<>( InvalidArgument("Only single-shard is implemented, but got %d", sharding_->devices().size())); } auto dtype = ToPrimitiveType(dtype_); if (!dtype.ok()) { - return Future(std::move(dtype).status()); + return Future<>(std::move(dtype).status()); } PjRtBuffer* pjrt_buffer = pjrt_buffers_.front().get(); @@ -374,7 +374,7 @@ Future PjRtArray::CopyToHostBuffer( // TODO(b/314805296): Use the new dynamic shape here. 
logical_dims = pjrt_buffer->logical_dimensions(); if (!logical_dims.ok()) { - return Future(std::move(logical_dims).status()); + return Future<>(std::move(logical_dims).status()); } dims = *logical_dims; } @@ -384,7 +384,7 @@ Future PjRtArray::CopyToHostBuffer( auto xla_shape = MakeShapeWithTrivialByteStrides(*dtype, dims, *byte_strides); if (!xla_shape.ok()) { - return Future(std::move(xla_shape).status()); + return Future<>(std::move(xla_shape).status()); } literal = std::make_unique( static_cast(data), *xla_shape); @@ -394,8 +394,8 @@ Future PjRtArray::CopyToHostBuffer( static_cast(data), xla_shape); } auto* literal_ptr = literal.get(); - auto promise = Future::CreatePromise(); - Future future(promise); + auto promise = Future<>::CreatePromise(); + Future<> future(promise); // TODO(hyeontaek): Handle semantics == kDonateInput. pjrt_buffer->ToLiteral(literal_ptr) .OnReady([literal = std::move(literal), @@ -521,26 +521,26 @@ absl::StatusOr> PjRtArray::Reshard( shape_); } -Future PjRtArray::GetReadyFuture() const { +Future<> PjRtArray::GetReadyFuture() const { DCHECK(this); if (pjrt_buffers_.size() == 1) { - return pjrt_buffers_.front()->GetReadyFuture().ToStatusFuture(); + return pjrt_buffers_.front()->GetReadyFuture(); } - std::vector> futures; + std::vector> futures; futures.reserve(pjrt_buffers_.size()); for (auto& buf : pjrt_buffers_) { - futures.push_back(buf->GetReadyFuture().ToStatusFuture()); + futures.push_back(buf->GetReadyFuture()); } return JoinFutures(absl::MakeSpan(futures)); } -Future PjRtArray::Delete() { +Future<> PjRtArray::Delete() { DCHECK(this); for (auto& buffer : pjrt_buffers_) { buffer->Delete(); } // TODO(hyeontaek): Return a correct future. - return Future(OkStatus()); + return Future<>(OkStatus()); } bool PjRtArray::IsDeleted() const { diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h index 5542ae9e54e11d..614187481e85f8 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h @@ -152,7 +152,7 @@ class PjRtArray final DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) override; ABSL_MUST_USE_RESULT - Future CopyToHostBuffer( + Future<> CopyToHostBuffer( void* data, std::optional> byte_strides, ArrayCopySemantics semantics) override; @@ -160,12 +160,12 @@ class PjRtArray final std::shared_ptr new_sharding, ArrayCopySemantics semantics) override; - Future GetReadyFuture() const override; + Future<> GetReadyFuture() const override; std::shared_ptr GetPjRtBuffer(ArrayCopySemantics semantics, int index) const; - Future Delete() override; + Future<> Delete() override; bool IsDeleted() const override; std::string DebugString() const override; diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc index cde1b5d33984ca..b0230c538a5c29 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc @@ -575,9 +575,9 @@ PjRtLoadedExecutable::Execute(absl::Span> args, pjrt_outputs.push_back(std::move(single_device_pjrt_results)); if (returned_future_supported) { - result.status = std::move(returned_pjrt_future)->ToStatusFuture(); + result.status = *std::move(returned_pjrt_future); } else { - result.status = Future(OkStatus()); + result.status = Future<>(OkStatus()); } } else { std::optional>> returned_pjrt_futures; @@ -592,7 +592,7 @@ PjRtLoadedExecutable::Execute(absl::Span> args, if 
(returned_future_supported) { result.status = JoinFutures(absl::MakeSpan(*returned_pjrt_futures)); } else { - result.status = Future(OkStatus()); + result.status = Future<>(OkStatus()); } } diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_tuple.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_tuple.cc index c18b2ed6c1d1ed..61578abb195c5d 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_tuple.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_tuple.cc @@ -24,6 +24,7 @@ limitations under the License. #include "llvm/Support/ExtensibleRTTI.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/future.h" #include "tsl/concurrency/ref_count.h" namespace xla { @@ -34,8 +35,8 @@ namespace ifrt { return tsl::MakeRef(client, values); } -Future PjRtTuple::GetReadyFuture() const { - std::vector> futures; +Future<> PjRtTuple::GetReadyFuture() const { + std::vector> futures; futures.reserve(values_.size()); for (const auto& value : values_) { futures.push_back(value->GetReadyFuture()); @@ -43,14 +44,14 @@ Future PjRtTuple::GetReadyFuture() const { return JoinFutures(absl::MakeSpan(futures)); } -Future PjRtTuple::Delete() { +Future<> PjRtTuple::Delete() { { absl::MutexLock lock(&mu_); if (!is_deleted_.HasBeenNotified()) { is_deleted_.Notify(); } } - std::vector> futures; + std::vector> futures; futures.reserve(values_.size()); for (const auto& value : values_) { futures.push_back(value->Delete()); diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_tuple.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_tuple.h index 3d4bab2c5d71be..98bc66aff699b6 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_tuple.h +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_tuple.h @@ -46,9 +46,9 @@ class PjRtTuple final : public llvm::RTTIExtends { return client_; } - Future GetReadyFuture() const override; + Future<> GetReadyFuture() const override; - Future Delete() override; + Future<> Delete() override; bool IsDeleted() const override; diff --git a/third_party/xla/xla/python/py_array.cc b/third_party/xla/xla/python/py_array.cc index 4fe9cd52863241..9d8782ed75301a 100644 --- a/third_party/xla/xla/python/py_array.cc +++ b/third_party/xla/xla/python/py_array.cc @@ -1470,8 +1470,10 @@ Status PyHostValue::CopyToHostAsync(std::optional& dynamic_shape_holder, // better about an efficient layout for the host buffer. It will be useful // to revisit the semantics of PjRtBuffer::ToLiteral() to see if it is // desirable for the runtime to choose the layout. - ready_ = ifrt_array->CopyToHostBuffer(value_.mutable_data(), strides, - ifrt::ArrayCopySemantics::kReuseInput); + ready_ = ifrt_array + ->CopyToHostBuffer(value_.mutable_data(), strides, + ifrt::ArrayCopySemantics::kReuseInput) + .ToStatusFuture(); // Make sure the destination of the copy remains alive until the copy is done. value_.inc_ref(); ready_.OnReady([array{value_.ptr()}](Status status) { diff --git a/third_party/xla/xla/python/py_executable.cc b/third_party/xla/xla/python/py_executable.cc index 41a1a9d71e4d64..32f2025cb61984 100644 --- a/third_party/xla/xla/python/py_executable.cc +++ b/third_party/xla/xla/python/py_executable.cc @@ -232,11 +232,10 @@ absl::StatusOr ExecuteShardedOnLocalDevicesInternal( // attach_status_to_results is only supposed to be true when the computation // has tokens. 
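The GetReadyFuture and Delete rewrites above reduce to one shape: collect one status-only Future<> per element and join them. A short sketch of that shape, mirroring the JoinFutures calls in pjrt_tuple.cc and util.cc in this diff; the free-function name is hypothetical.

#include <vector>

#include "absl/types/span.h"
#include "tsl/concurrency/ref_count.h"
#include "xla/python/ifrt/future.h"
#include "xla/python/ifrt/value.h"

namespace xla::ifrt {
// Hypothetical helper: one readiness future per value, joined into a single
// Future<> that resolves once every per-value future has resolved.
Future<> AllValuesReady(absl::Span<const tsl::RCReference<Value>> values) {
  std::vector<Future<>> futures;
  futures.reserve(values.size());
  for (const auto& value : values) {
    futures.push_back(value->GetReadyFuture());
  }
  return JoinFutures(absl::MakeSpan(futures));
}
}  // namespace xla::ifrt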
if (attach_status_to_results) { - result_status = PjRtFuture<>::FromStatusFuture(result.status); + result_status = result.status; } if (returned_futures.has_value()) { - returned_futures->resize(num_computations, PjRtFuture<>::FromStatusFuture( - std::move(result.status))); + returned_futures->resize(num_computations, std::move(result.status)); } } diff --git a/third_party/xla/xla/python/util.cc b/third_party/xla/xla/python/util.cc index 9db18e31534dbf..d03de01d88f341 100644 --- a/third_party/xla/xla/python/util.cc +++ b/third_party/xla/xla/python/util.cc @@ -30,11 +30,11 @@ limitations under the License. namespace xla { Status AwaitBuffersReady(absl::Span ifrt_arrays) { - ifrt::Future future; + ifrt::Future<> future; if (ifrt_arrays.size() == 1) { future = ifrt_arrays[0]->GetReadyFuture(); } else { - std::vector> futures; + std::vector> futures; futures.reserve(ifrt_arrays.size()); for (ifrt::Array* const ifrt_array : ifrt_arrays) { futures.push_back(ifrt_array->GetReadyFuture()); diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 29ac1c3e1f4aa7..7e98e210186449 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2943,6 +2943,7 @@ cc_library( ":ir_emission_utils", ":gpu_fused_mha_runner", ":cublas_cudnn", + ":stream_executor_util", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -5823,6 +5824,7 @@ cc_library( hdrs = ["stream_attribute_annotator.h"], deps = [ ":backend_configs_cc", + ":gpu_fusible", "//xla:comparison_util", "//xla:status", "//xla:statusor", diff --git a/third_party/xla/xla/service/gpu/backend_configs.proto b/third_party/xla/xla/service/gpu/backend_configs.proto index 61335f07f0497a..7f225008337be8 100644 --- a/third_party/xla/xla/service/gpu/backend_configs.proto +++ b/third_party/xla/xla/service/gpu/backend_configs.proto @@ -228,6 +228,16 @@ message CudnnfMHABackendConfig { // Is causal mask bool is_causal_mask = 21; + + // mask type + enum MaskType { + NO_MASK = 0; + PADDING = 1; + CAUSAL = 2; + PADDING_CAUSAL = 3; + ALIBI = 4; + } + MaskType mask_type = 22; } // Generic backend config for XLA:GPU diff --git a/third_party/xla/xla/service/gpu/cublas_cudnn.h b/third_party/xla/xla/service/gpu/cublas_cudnn.h index c79e76b9a72b3c..08f41affd2e911 100644 --- a/third_party/xla/xla/service/gpu/cublas_cudnn.h +++ b/third_party/xla/xla/service/gpu/cublas_cudnn.h @@ -79,6 +79,14 @@ enum class CudnnfMHAKind { kBackwardScaleBiasSoftmaxDropout, }; +enum class CudnnfMHAMaskKind { + kNoMask, + kPadding, + kCausal, + kPaddingCausal, + kAlibi, +}; + absl::StatusOr GetCudnnConvKind( const HloCustomCallInstruction* instr); diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc index 4e4007847ee436..636788795049b4 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc +++ b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc @@ -1451,7 +1451,8 @@ absl::StatusOr FuseFwdMultiHeadedAttentionBlock( fmha_config.set_is_flash_attention(is_flash_attention); // set is_causal_mask here // choose to generate causal mask inside cuDNN attention or not - fmha_config.set_is_causal_mask(is_causal_mask); + fmha_config.set_mask_type(is_causal_mask ? 
CudnnfMHABackendConfig::CAUSAL + : CudnnfMHABackendConfig::NO_MASK); // Output Order: {O, scratch, Fwd act*} const Shape& output_shape = bmm_2->shape(); @@ -1595,7 +1596,8 @@ absl::StatusOr FuseBwdMultiHeadedAttentionBlock( fwd_fmha_call->backend_config()); CudnnfMHABackendConfig fwd_config = gpu_config.cudnn_fmha_backend_config(); bool is_flash_attention = fwd_config.is_flash_attention(); - bool is_causal_mask = fwd_config.is_causal_mask(); + bool is_causal_mask = + fwd_config.mask_type() == CudnnfMHABackendConfig::CAUSAL; CudnnfMHABackendConfig bwd_fmha_config; // Q tensor TF_ASSIGN_OR_RETURN( @@ -1705,7 +1707,9 @@ absl::StatusOr FuseBwdMultiHeadedAttentionBlock( // Set is flash attention bwd_fmha_config.set_is_flash_attention(is_flash_attention); - bwd_fmha_config.set_is_causal_mask(is_causal_mask); + bwd_fmha_config.set_mask_type(is_causal_mask + ? CudnnfMHABackendConfig::CAUSAL + : CudnnfMHABackendConfig::NO_MASK); *bwd_fmha_config.mutable_intermediate_tensor_shape() = fwd_config.intermediate_tensor_shape(); diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc index 345ea0206eea4b..9741e3250f0831 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc @@ -1877,7 +1877,6 @@ ENTRY main.129 { EXPECT_EQ(CountFusedAttentionCall(m.get(), /*is_backward*/ true), 1); } - TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16TrainingBmm1ScaleBiasSoftmaxDropoutBmm2DbiasShouldHaveUserShape) { if (skip_reason_) GTEST_SKIP() << *skip_reason_; @@ -3046,7 +3045,7 @@ ENTRY main.92 { EXPECT_EQ(bwd_fmha->operands().size(), 6); EXPECT_NEAR(config.dropout_rate(), 0, 1e-2); EXPECT_EQ(config.is_flash_attention(), true); - EXPECT_EQ(config.is_causal_mask(), true); + EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::CAUSAL); } TEST_F(CudnnFusedMhaRewriterTestHloTest, @@ -3154,7 +3153,7 @@ ENTRY main.92 { EXPECT_EQ(fmha->operands().size(), 7); EXPECT_NEAR(config.dropout_rate(), 0, 1e-2); EXPECT_EQ(config.is_flash_attention(), true); - EXPECT_EQ(config.is_causal_mask(), false); + EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::NO_MASK); } TEST_F(CudnnFusedMhaRewriterTestHloTest, @@ -3257,10 +3256,9 @@ ENTRY main.92 { EXPECT_NEAR(config.dropout_rate(), 0, 1e-2); EXPECT_FLOAT_EQ(config.fmha_scale(), 2); EXPECT_EQ(config.is_flash_attention(), true); - EXPECT_EQ(config.is_causal_mask(), false); + EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::NO_MASK); } - // GPT3 pattern TEST_F(CudnnFusedMhaRewriterTestHloTest, FlashAttentionBF16TrainingGPT3_5B) { if (skip_reason_) GTEST_SKIP() << *skip_reason_; @@ -3845,7 +3843,7 @@ main { fwd_instruction->backend_config()); const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(config.is_flash_attention(), true); - EXPECT_EQ(config.is_causal_mask(), true); + EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::CAUSAL); } TEST_F(CudnnFusedMhaRewriterTestHloTest, diff --git a/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.cc b/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.cc index 7c355dd4283e6e..b9e03c2f2ffc24 100644 --- a/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.cc +++ b/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.cc @@ -38,6 +38,7 @@ limitations under the License. 
#include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/gpu_fused_mha_runner.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/stream_executor_util.h" #include "xla/stream_executor/cuda/cuda_dnn.h" #include "xla/stream_executor/cuda/cudnn_frontend_helpers.h" #include "xla/util.h" @@ -102,11 +103,12 @@ absl::StatusOr HloCustomCallToCuDnnGraph( Shape q_shape = custom_call->operand(0)->shape(); Shape k_shape = custom_call->operand(1)->shape(); Shape v_shape = custom_call->operand(2)->shape(); - + TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, + AsCudnnFmhaMaskKind(config.mask_type())); GpufMHADescriptor descriptor = {kind, config, config.is_flash_attention(), - config.is_causal_mask(), + cudnn_mask_type, q_shape, k_shape, v_shape, @@ -119,6 +121,9 @@ absl::StatusOr HloCustomCallToCuDnnGraph( TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config, GpufMHAConfig::For(descriptor)); + TF_ASSIGN_OR_RETURN( + se::dnn::FMHAMaskKind dnn_mask_type, + GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type)); TF_ASSIGN_OR_RETURN( se::gpu::CudnnGraph graph, se::gpu::GetCudnnFlashAttentionOperationGraph( @@ -127,7 +132,7 @@ absl::StatusOr HloCustomCallToCuDnnGraph( fmha_config.mask, fmha_config.activation, static_cast(*fmha_config.fmha_scale), fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0, - fmha_config.dropout_rate, fmha_config.is_causal_mask)); + fmha_config.dropout_rate, dnn_mask_type)); return std::move(graph); } else { TF_ASSIGN_OR_RETURN( @@ -184,12 +189,13 @@ absl::StatusOr HloCustomCallToCuDnnGraph( std::optional d_s_shape; std::optional d_bias_shape; TF_RET_CHECK(output_index == custom_call->shape().tuple_shapes().size()); - + TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, + AsCudnnFmhaMaskKind(config.mask_type())); GpufMHABackwardDescriptor descriptor = { kind, config, is_flash_attention, - config.is_causal_mask(), + cudnn_mask_type, bmm1_grad_gemm1_rhs_shape, bmm1_grad_gemm2_rhs_shape, bmm2_grad_gemm1_lhs_shape, @@ -210,6 +216,9 @@ absl::StatusOr HloCustomCallToCuDnnGraph( TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_config, GpufMHABackwardConfig::For(descriptor)); + TF_ASSIGN_OR_RETURN( + se::dnn::FMHAMaskKind dnn_mask_type, + GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type)); TF_ASSIGN_OR_RETURN( se::gpu::CudnnGraph graph, se::gpu::GetCudnnFlashAttentionBackwardOperationGraph( @@ -221,7 +230,7 @@ absl::StatusOr HloCustomCallToCuDnnGraph( fmha_config.seed, *fmha_config.fmha_scale, fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0, fmha_config.mask != std::nullopt, fmha_config.bias != std::nullopt, - fmha_config.is_causal_mask)); + dnn_mask_type)); return std::move(graph); } } diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc index 3ea1441aaf3b30..abab64dfb8064c 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc @@ -158,6 +158,13 @@ absl::Status MlirInPlaceDynamicUpdateSliceFusion::EmitEntryFunction( auto updated_value = ProvideParameter(dus_subgraph, dus_instr, kDUSUpdateIndex, input_indices, call_targets, entry_function, b)[0]; + // Handle bitcasts under the DUS. 
+ if (dus_instr->shape() != fusion.shape()) { + update_indices = ApplyAffineMap( + GetBitcastMap(dus_instr->shape(), fusion.shape(), b.getContext()) + .GetAffineMap(), + update_indices, {}, b); + } auto insert = b.create(updated_value, output_tensors[0], update_indices); diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc index 7bcc2e43175626..b9cd940c102295 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc @@ -179,6 +179,29 @@ TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, OutOfBoundDUS) { EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } +TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, BitcastDus) { + auto kHloString = R"( + HloModule module + + fused_computation { + in = f32[20,30] parameter(0) + updates = f32[5,6] parameter(1) + i0 = s32[] parameter(2) + i1 = s32[] parameter(3) + updated = f32[20,30] dynamic-update-slice(in, updates, i0, i1) + ROOT bitcast = f32[600] bitcast(updated) + } + ENTRY entry { + in = f32[20,30] parameter(0) + updates = f32[5,6] parameter(1) + i0 = s32[] constant(2) + i1 = s32[] constant(3) + ROOT fusion = f32[600] fusion(in, updates, i0, i1), kind=kLoop, calls=fused_computation + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD index f08275b1646844..997ee83be9a228 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD @@ -159,6 +159,7 @@ cc_library( ":type_util", "//xla:shape_util", "//xla:status_macros", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/mlir_hlo", "//xla/mlir_hlo:mhlo_passes", @@ -189,11 +190,11 @@ cc_library( "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:BufferizationDialect", "@llvm-project//mlir:BufferizationInterfaces", "@llvm-project//mlir:BuiltinToLLVMIRTranslation", "@llvm-project//mlir:ComplexToStandard", "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:DLTIDialect", "@llvm-project//mlir:DataLayoutInterfaces", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc index 3c1907f8d17661..90b1ea30903a44 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc @@ -1055,14 +1055,8 @@ bool IsHloOpSupported(const HloInstruction* instr, se::CudaComputeCapability compute_capability) { auto is_unsupported_type = [](const HloInstruction* instr) { auto e = instr->shape().element_type(); - // TODO(akuegel): Fix remaining issues with complex. // TODO(jreiffers): Support fp8. - // TODO(jreiffers): Support int4. 
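The BitcastDus test above exercises the index remapping added to the in-place DUS emitter: when the fusion root is a bitcast of the dynamic-update-slice, the (row, col) update indices are mapped through the bitcast into the flat output shape. For the shapes in that test, and assuming XLA's default row-major layout, this is plain linearization; a standalone arithmetic illustration, not emitter code:

#include <cstdint>

// f32[20,30] bitcast to f32[600] under the default layout maps an update index
// (row, col) to row * 30 + col in the flat output.
int64_t LinearizedUpdateIndex(int64_t row, int64_t col) {
  constexpr int64_t kMinorDim = 30;  // minor-most dimension of f32[20,30]
  return row * kMinorDim + col;      // e.g. (i0=2, i1=3) -> element 63
}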
- return (primitive_util::IsIntegralType(e) && - primitive_util::BitWidth(e) > 1 && - primitive_util::BitWidth(e) < 8) || - primitive_util::IsComplexType(e) || - (primitive_util::IsFloatingPointType(e) && + return (primitive_util::IsFloatingPointType(e) && primitive_util::BitWidth(e) < 16); }; if (is_unsupported_type(instr) || diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc b/third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc index 11467a37effa50..cee198a44dc726 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/strings/str_cat.h" @@ -38,11 +39,13 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeRange.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -117,11 +120,8 @@ struct RewriteFunctionSignatures : mlir::OpRewritePattern { } }; -mlir::LLVM::GEPOp CreateGep(mlir::Operation* op, - mlir::TypedValue tensor, - ValueRange indices, - mlir::PatternRewriter& rewriter) { - auto ptr = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); +Value GetLinearIndex(mlir::TypedValue tensor, + ValueRange indices, mlir::PatternRewriter& rewriter) { auto byte_shape = ShapeUtil::MakeShape(U8, tensor.getType().getShape()); if (auto encoding = tensor.getType().getEncoding()) { *byte_shape.mutable_layout() = LayoutUtil::MakeLayout(llvm::to_vector( @@ -134,40 +134,81 @@ mlir::LLVM::GEPOp CreateGep(mlir::Operation* op, mlir::getAffineDimExpr(dim, rewriter.getContext()) * stride; } - rewriter.setInsertionPoint(op); Value index = rewriter.create( tensor.getLoc(), linearize_map, indices); - auto index_ty = - ShapeUtil::ElementsIn(byte_shape) < std::numeric_limits::max() - ? 
rewriter.getI32Type() - : rewriter.getI64Type(); - index = rewriter.create(tensor.getLoc(), index_ty, - index); + auto index_ty = rewriter.getIntegerType( + mlir::DataLayout::closest(rewriter.getInsertionBlock()->getParentOp()) + .getTypeSizeInBits(index.getType())); + return rewriter.create(tensor.getLoc(), index_ty, + index); +} +std::tuple GetI4IndexAndNibble(Value linear_index, + mlir::ImplicitLocOpBuilder& b) { + Value one = b.create(1, linear_index.getType()); + Value is_low_nibble = b.create( + mlir::arith::CmpIPredicate::eq, one, + b.create(linear_index, one)); + Value i8_index = b.create(linear_index, one); + return {i8_index, is_low_nibble}; +} + +mlir::LLVM::GEPOp CreateGep(mlir::TypedValue tensor, + Value linear_index, mlir::PatternRewriter& rewriter, + mlir::Type element_type = nullptr) { + if (!element_type) { + element_type = tensor.getType().getElementType(); + } + auto ptr = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); auto tensor_ptr = rewriter .create( tensor.getLoc(), ptr, tensor) .getResult(0); mlir::LLVMTypeConverter converter(rewriter.getContext()); - auto llvm_element_type = - converter.convertType(tensor.getType().getElementType()); + auto llvm_element_type = converter.convertType(element_type); auto gep = rewriter.create( - tensor.getLoc(), ptr, llvm_element_type, tensor_ptr, index); + tensor.getLoc(), ptr, llvm_element_type, tensor_ptr, linear_index); gep.setInbounds(true); return gep; } +mlir::LLVM::GEPOp CreateGep(mlir::TypedValue tensor, + ValueRange indices, + mlir::PatternRewriter& rewriter) { + return CreateGep(tensor, GetLinearIndex(tensor, indices, rewriter), rewriter); +} + struct RewriteTensorExtract : mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite( mlir::tensor::ExtractOp op, mlir::PatternRewriter& rewriter) const override { - auto gep = CreateGep(op, op.getTensor(), op.getIndices(), rewriter); + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto linear_index = + GetLinearIndex(op.getTensor(), op.getIndices(), rewriter); + mlir::Type element_type = op.getTensor().getType().getElementType(); + Value is_low_nibble = nullptr; + if (element_type == rewriter.getI4Type()) { + element_type = rewriter.getI8Type(); + std::tie(linear_index, is_low_nibble) = + GetI4IndexAndNibble(linear_index, b); + } + + auto gep = CreateGep(op.getTensor(), linear_index, rewriter, element_type); auto load = rewriter .create(gep.getLoc(), gep.getElemType(), gep) .getResult(); + + if (is_low_nibble) { + auto high_value = b.create( + load, b.create(4, load.getType())); + load = b.create( + op.getType(), + b.create(is_low_nibble, load, high_value)); + } + rewriter.replaceOpWithNewOp( op, op.getType(), load); return success(); @@ -199,10 +240,39 @@ struct RewriteTensorInsert : mlir::OpRewritePattern { } } - auto gep = - CreateGep(op, dest.cast>(), - op.getIndices(), rewriter); + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto tensor_dest = dest.cast>(); + auto linear_index = GetLinearIndex(tensor_dest, op.getIndices(), rewriter); + auto element_type = tensor_dest.getType().getElementType(); + Value is_low_nibble = nullptr; + + if (element_type == rewriter.getI4Type()) { + element_type = rewriter.getI8Type(); + std::tie(linear_index, is_low_nibble) = + GetI4IndexAndNibble(linear_index, b); + } + + auto gep = CreateGep(tensor_dest, linear_index, rewriter, element_type); auto scalar_value = op.getScalar(); + + if (is_low_nibble) { + Value current_value = + b.create(gep.getElemType(), gep); + auto ty 
= current_value.getType(); + scalar_value = b.create(ty, scalar_value); + Value low_updated = b.create( + b.create( + current_value, b.create(0xf0, ty)), + scalar_value); + Value high_updated = b.create( + b.create( + current_value, b.create(0x0f, ty)), + b.create( + scalar_value, b.create(4, ty))); + scalar_value = b.create(is_low_nibble, low_updated, + high_updated); + } + mlir::LLVMTypeConverter converter(getContext()); auto llvm_type = converter.convertType(scalar_value.getType()); scalar_value = rewriter @@ -251,7 +321,14 @@ struct RewriteAllocateShared : mlir::OpRewritePattern { auto module = op->getParentOfType(); auto shaped_ty = op.getResult().getType().cast(); constexpr int kGPUSharedMemoryAddrSpace = 3; - auto array_ty = mlir::LLVM::LLVMArrayType::get(shaped_ty.getElementType(), + mlir::Type element_type = shaped_ty.getElementType(); + if (auto complex_ty = mlir::dyn_cast(element_type)) { + element_type = mlir::LLVM::LLVMStructType::getLiteral( + getContext(), + {complex_ty.getElementType(), complex_ty.getElementType()}); + } + + auto array_ty = mlir::LLVM::LLVMArrayType::get(element_type, shaped_ty.getNumElements()); std::string name; @@ -335,7 +412,7 @@ struct RewriteAtomicRMW : mlir::OpRewritePattern { mlir::IntegerType::get(op.getContext(), small_type ? 32 : result_size); // Calculate load address for the input. - Value addr = CreateGep(op, input, op.getIndices(), rewriter); + Value addr = CreateGep(input, op.getIndices(), rewriter); Value shift, mask; if (small_type) { // Update input pointer by discarding the last two bits - i.e. align to diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/lower_to_llvm.cc b/third_party/xla/xla/service/gpu/fusions/mlir/lower_to_llvm.cc index 625c9f8794b484..2c2d458cab6f67 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/lower_to_llvm.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/lower_to_llvm.cc @@ -52,7 +52,10 @@ class LowerToLLVMPass : public impl::LowerToLLVMPassBase { void runOnOperation() override { // Populate type conversions. - mlir::LLVMTypeConverter type_converter(getOperation().getContext()); + mlir::LowerToLLVMOptions llvm_opts(&getContext(), + mlir::DataLayout(getOperation())); + mlir::LLVMTypeConverter type_converter(getOperation().getContext(), + llvm_opts); mlir::LLVMConversionTarget target(*getOperation().getContext()); // Populate patterns. diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc b/third_party/xla/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc index aab493ec184e09..29e9887caa752f 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc @@ -17,6 +17,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project @@ -25,6 +26,7 @@ limitations under the License. 
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project @@ -103,41 +105,52 @@ struct RewriteShuffleReduce : mlir::OpRewritePattern { mlir::ValueRange values = op.getOperands(); for (int distance = max_distance; distance > 0; distance /= 2) { namespace ml = mlir::LLVM; - auto shuffle = [&](mlir::Value v) { + auto shuffle_32 = [&](mlir::Value v) { return b .create(v, distance, WarpSize(), mlir::gpu::ShuffleMode::DOWN) .getShuffleResult(); }; - llvm::SmallVector args = values; - for (auto value : values) { - // Shuffle within the warps. + auto shuffle_int_or_float = [&](mlir::Value value) { auto ty = value.getType(); int bit_width = ty.getIntOrFloatBitWidth(); - if (bit_width == 32) { - value = shuffle(value); - } else { - int n_shuffles = CeilOfRatio(bit_width, 32); - auto int_ty = b.getIntegerType(bit_width); - auto padded_int_ty = b.getIntegerType(n_shuffles * 32); - value = b.create(int_ty, value); - value = b.create(padded_int_ty, value); - auto vector_type = ml::getVectorType(b.getI32Type(), n_shuffles); - value = b.create(vector_type, value); - mlir::Value result_vec = b.create(vector_type); - for (int i = 0; i < n_shuffles; ++i) { - auto idx = b.create(i, 32); - result_vec = b.create( - result_vec, shuffle(b.create(value, idx)), - idx); - } - value = b.create(padded_int_ty, result_vec); - value = b.create(int_ty, value); - value = b.create(ty, value); + return shuffle_32(value); } - args.push_back(value); + int n_shuffles = CeilOfRatio(bit_width, 32); + auto int_ty = b.getIntegerType(bit_width); + auto padded_int_ty = b.getIntegerType(n_shuffles * 32); + value = b.create(int_ty, value); + value = b.create(padded_int_ty, value); + auto vector_type = ml::getVectorType(b.getI32Type(), n_shuffles); + value = b.create(vector_type, value); + mlir::Value result_vec = b.create(vector_type); + for (int i = 0; i < n_shuffles; ++i) { + auto idx = b.create(i, 32); + result_vec = b.create( + result_vec, + shuffle_32(b.create(value, idx)), idx); + } + value = b.create(padded_int_ty, result_vec); + value = b.create(int_ty, value); + value = b.create(ty, value); + return value; + }; + + auto shuffle = [&](mlir::Value value) -> mlir::Value { + if (value.getType().isa()) { + return b.create( + value.getType(), + shuffle_int_or_float(b.create(value)), + shuffle_int_or_float(b.create(value))); + } + return shuffle_int_or_float(value); + }; + + llvm::SmallVector args = values; + for (auto value : values) { + args.push_back(shuffle(value)); } values = b.create(op.getReducerAttr().getAttr(), op.getResultTypes(), args) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc index 9bd5f8a1d8fdfe..fbd479552aa9be 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc @@ -44,6 +44,7 @@ limitations under the License. 
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" // from @llvm-project #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" // from @llvm-project +#include "mlir/Dialect/DLTI/DLTI.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project @@ -92,6 +93,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" +#include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -141,6 +143,21 @@ void AddRanges(llvm::Function* func, const LaunchDimensions& launch_dims, } } +bool Needs64Bits(const Shape& shape) { + return shape.IsArray() ? !IsInt32(ShapeUtil::ElementsIn(shape)) + : absl::c_any_of(shape.tuple_shapes(), Needs64Bits); +} + +bool Needs64BitIndices(const HloComputation* computation) { + for (auto* instr : computation->instructions()) { + if (Needs64Bits(instr->shape()) || + absl::c_any_of(instr->called_computations(), Needs64BitIndices)) { + return true; + } + } + return false; +} + } // namespace Value MlirFusionEmitterBase::EmitBlockId(mlir::ImplicitLocOpBuilder& builder, @@ -302,12 +319,12 @@ MlirFusionEmitterBase::CreateMLIRModule( mlir::MLIRContext& context, const HloFusionInstruction& fusion, const std::string& entry_function_name, const BufferAssignment* buffer_assignment) const { - context.loadDialect(); + context.loadDialect(); mlir::DialectRegistry registry; mlir::func::registerInlinerExtension(registry); mlir::registerBuiltinDialectTranslation(registry); @@ -451,6 +468,15 @@ absl::Status MlirFusionEmitterBase::EmitMlir( *epilogue, subgraph_to_mlir_fn[&*epilogue], call_targets)); } + int index_bitwidth = + Needs64BitIndices(fusion.fused_instructions_computation()) ? 
64 : 32; + mlir::OpBuilder b(module->getContext()); + auto index_layout = mlir::DataLayoutEntryAttr::get( + b.getIndexType(), b.getI32IntegerAttr(index_bitwidth)); + module->setAttr( + mlir::DLTIDialect::kDataLayoutAttrName, + mlir::DataLayoutSpecAttr::get(module->getContext(), {index_layout})); + return EmitEntryFunction(computations, call_targets, entry_function, fusion); } diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc index f3f699a8b854e5..759d625d02b4ba 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc @@ -170,11 +170,9 @@ TEST_F(MlirFusionEmitterTest, CreateLLVMModule) { TF_ASSERT_OK_AND_ASSIGN(auto filecheck_result, RunFileCheck(out, R"( // CHECK: define void @fusion(ptr noalias %[[IN:.*]], ptr noalias %[[OUT:.*]]) // CHECK: %[[TID:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() - // CHECK: %[[EXT:.*]] = sext i32 %[[TID]] to i64 - // CHECK: %[[TRUNC:.*]] = trunc i64 %[[EXT]] to i32 - // CHECK: %[[IN_PTR:.*]] = getelementptr inbounds float, ptr %[[IN]], i32 %[[TRUNC]] + // CHECK: %[[IN_PTR:.*]] = getelementptr inbounds float, ptr %[[IN]], i32 %[[TID]] // CHECK: %[[VAL:.*]] = load float, ptr %[[IN_PTR]], align 4 - // CHECK: %[[OUT_PTR:.*]] = getelementptr inbounds float, ptr %[[OUT]], i32 %[[TRUNC]] + // CHECK: %[[OUT_PTR:.*]] = getelementptr inbounds float, ptr %[[OUT]], i32 %[[TID]] // CHECK: store float %[[VAL]], ptr %[[OUT_PTR]], align 4 // CHECK: ret void )")); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD index c4a8230a2966b4..f6e4c28d00b47a 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD @@ -16,6 +16,7 @@ xla_cc_binary( "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:DLTIDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:GPUDialect", diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir index b508d2cb84dd32..1687d88a74a30c 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir @@ -1,6 +1,6 @@ // RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-lower-tensors | FileCheck %s -module { +module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry>} { func.func private @add(%arg0: f32, %arg1: f32) -> f32 { %sum = arith.addf %arg0, %arg1 : f32 func.return %sum : f32 @@ -72,7 +72,7 @@ module { // CHECK: @layout(%[[ARG0:.*]]: !llvm.ptr, // CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]](%[[X]], %[[Y]]) -// CHECK: %[[IDX_CAST:.*]] = arith.index_castui %[[IDX]] : index to i32 +// CHECK: %[[IDX_CAST:.*]] = arith.index_castui %[[IDX]] : index to i64 // CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX_CAST]]] // CHECK: llvm.load %[[PTR]] @@ -110,7 +110,7 @@ module { // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] { -// CHECK: %[[CAST:.*]] = arith.index_castui %[[I]] : index to i32 
+// CHECK: %[[CAST:.*]] = arith.index_castui %[[I]] : index to i64 // CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[CAST]]] // CHECK: llvm.store {{.*}}, %[[PTR]] // CHECK: %[[INBOUNDS:.*]] = arith.cmpi @@ -309,3 +309,34 @@ module { // CHECK-NEXT: %[[RES_SHIFT:.*]] = llvm.shl %[[RES_WIDE]], %{{.*}} // CHECK-NEXT: %[[NEW:.*]] = llvm.or %[[NEW_MASKED]], %[[RES_SHIFT]] // CHECK-NEXT: llvm.cmpxchg %[[BASE]], %[[VAR]], %[[NEW]] + +// ----- + +module { + func.func @shared_complex() -> tensor<10xcomplex> { + %shared = xla_gpu.allocate_shared : tensor<10xcomplex> + return %shared : tensor<10xcomplex> + } +} + +// CHECK: llvm.mlir.global private @{{.*}}() {addr_space = 3 : i32} : !llvm.array<10 x struct<(f32, f32)>> +// CHECK: @shared_complex + +// ----- + +module { + func.func @i4_load_store(%arg: tensor<10xi4>, %i: index, %j: index) -> tensor<10xi4> { + %v = tensor.extract %arg[%i] : tensor<10xi4> + %r = tensor.insert %v into %arg[%j] : tensor<10xi4> + return %r : tensor<10xi4> + } +} + +// CHECK: @i4_load_store +// CHECK: llvm.getelementptr +// CHECK-SAME: -> !llvm.ptr, i8 +// CHECK: llvm.load +// CHECK: llvm.getelementptr +// CHECK-SAME: -> !llvm.ptr, i8 +// CHECK: llvm.load +// CHECK: llvm.store diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir index 9b1d0b20fe894c..79923cde8b3d04 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir @@ -46,6 +46,22 @@ module { // ----- +module { + func.func @reducer(%a: complex, %b: complex) -> complex { + return %a : complex + } + + func.func @shuffler(%a: complex) -> complex { + %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : complex + return %ret : complex + } +} + +// CHECK: @shuffler +// CHECK-COUNT-4: gpu.shuffle down {{.*}}, %[[C1]] + +// ----- + module { func.func @predicated_insert( %v: i32, %tensor: tensor<2xi32>, %index: index, diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc b/third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc index 3f77b3336b2b7f..37629b4faae153 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc @@ -16,6 +16,7 @@ limitations under the License. #include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project +#include "mlir/Dialect/DLTI/DLTI.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project @@ -32,12 +33,13 @@ limitations under the License. 
int main(int argc, char **argv) { mlir::DialectRegistry registry; - registry.insert(); + registry.insert(); mlir::func::registerAllExtensions(registry); mlir::registerCanonicalizerPass(); mlir::registerCSEPass(); diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc index 06b81c967b4172..28ac6c70bdbd21 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc @@ -283,7 +283,6 @@ absl::Status MlirTransposeFusion::EmitReadFromShMemMlir( auto output_indexing = *ComputeThreadIdToOutputIndexing(0, mlir_context); auto shmem_output_indexing = GetSharedMemoryReadIndexingMap(output_indexing, permutation_[2]); - std::cerr << "output indexing: " << output_indexing.ToString() << "\n"; auto result_tensors = EmitThreadLoopNest( builder, output_tensor_args, output_indexing, [&](ValueRange output_tensors, ValueRange dim_values, diff --git a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc index 3148ea95df00f5..3dc4763bae91dd 100644 --- a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc +++ b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc @@ -80,7 +80,9 @@ absl::Status RunFusedMHA(GpufMHAParams params, se::Stream *stream, if (params.config->seed) { seed = *params.config->seed; } - + TF_ASSIGN_OR_RETURN( + se::dnn::FMHAMaskKind mask_type, + GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(params.config->mask_type)); se::dnn::FusedMHAOp::Config config{kind, scale, params.config->lhs_bmm1, @@ -94,7 +96,7 @@ absl::Status RunFusedMHA(GpufMHAParams params, se::Stream *stream, dropout_rate, seed, params.config->is_flash_attention, - params.config->is_causal_mask}; + mask_type}; TF_ASSIGN_OR_RETURN(auto *runner, lazy_runner->GetOrCreateRunner(config, stream)); return (*runner)(stream, options.profile_result, scratch_memory, @@ -208,6 +210,10 @@ absl::Status RunFusedMHABackward( if (params.config->seed) { seed = *params.config->seed; } + + TF_ASSIGN_OR_RETURN( + se::dnn::FMHAMaskKind mask_type, + GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(params.config->mask_type)); se::dnn::FusedMHABackwardOp::Config config{kind, scale, params.config->bmm1_grad_gemm1_rhs, @@ -226,7 +232,7 @@ absl::Status RunFusedMHABackward( dropout_rate, seed, params.config->is_flash_attention, - params.config->is_causal_mask}; + mask_type}; TF_ASSIGN_OR_RETURN(auto *runner, lazy_runner->GetOrCreateRunner(config, stream)); // TODO: pass in real softmax_sum, dQ_accum, fwd_output @@ -420,7 +426,7 @@ absl::Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, } config.kind = desc.kind; config.is_flash_attention = desc.is_flash_attention; - config.is_causal_mask = desc.is_causal_mask; + config.mask_type = desc.mask_type; const CudnnfMHABackendConfig &backend_config = desc.backend_config; config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); config.fmha_scale.emplace(backend_config.fmha_scale()); @@ -563,7 +569,7 @@ absl::Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, config.kind = desc.kind; config.is_flash_attention = desc.is_flash_attention; - config.is_causal_mask = desc.is_causal_mask; + config.mask_type = desc.mask_type; const CudnnfMHABackendConfig &backend_config = desc.backend_config; config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); config.fmha_scale.emplace(backend_config.fmha_scale()); diff --git a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h 
b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h index 5a4274e950988e..5466e61ff627a6 100644 --- a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h +++ b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h @@ -38,6 +38,24 @@ limitations under the License. namespace xla { namespace gpu { +inline absl::StatusOr AsCudnnFmhaMaskKind( + xla::gpu::CudnnfMHABackendConfig_MaskType mask_type) { + switch (mask_type) { + case xla::gpu::CudnnfMHABackendConfig::NO_MASK: + return xla::gpu::CudnnfMHAMaskKind::kNoMask; + case xla::gpu::CudnnfMHABackendConfig::PADDING: + return xla::gpu::CudnnfMHAMaskKind::kPadding; + case xla::gpu::CudnnfMHABackendConfig::CAUSAL: + return xla::gpu::CudnnfMHAMaskKind::kCausal; + case xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL: + return xla::gpu::CudnnfMHAMaskKind::kPaddingCausal; + case xla::gpu::CudnnfMHABackendConfig::ALIBI: + return xla::gpu::CudnnfMHAMaskKind::kAlibi; + default: + return xla::Internal("Unknown fmha mask kind."); + } +} + // This is an interim structure to hold the parameters to construct a // GpufMHAConfig. // Struct to describe properties of a FMHA without being tied to specific @@ -47,7 +65,7 @@ struct GpufMHADescriptor { CudnnfMHAKind kind; CudnnfMHABackendConfig backend_config; bool is_flash_attention; - bool is_causal_mask; + CudnnfMHAMaskKind mask_type; Shape lhs_bmm1_shape; Shape rhs_bmm1_shape; Shape rhs_bmm2_shape; @@ -65,7 +83,7 @@ struct GpufMHABackwardDescriptor { CudnnfMHAKind kind; CudnnfMHABackendConfig backend_config; bool is_flash_attention; - bool is_causal_mask; + CudnnfMHAMaskKind mask_type; Shape bmm1_grad_gemm1_rhs_shape; Shape bmm1_grad_gemm2_rhs_shape; Shape bmm2_grad_gemm1_lhs_shape; @@ -99,7 +117,7 @@ struct GpufMHAConfig { se::dnn::AlgorithmDesc algorithm; bool is_flash_attention; - bool is_causal_mask; + CudnnfMHAMaskKind mask_type; // bias -> [1, num_attn_heads, q_seq_len, kv_seq_len] // mask -> [batch_size, 1, q_seq_len, kv_seq_len] se::dnn::MatmulTensorDescriptor lhs_bmm1; @@ -128,7 +146,7 @@ struct GpufMHABackwardConfig { se::dnn::AlgorithmDesc algorithm; bool is_flash_attention; - bool is_causal_mask; + CudnnfMHAMaskKind mask_type; // mask -> [batch_size, 1, q_seq_len, kv_seq_len] // d_bias -> [1, num_heads, q_seq_len, kv_seq_len] se::dnn::MatmulTensorDescriptor bmm1_grad_gemm1_rhs; diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc index a8fe6df7a8b09f..7ff92e7207ad62 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include #include #include +#include #include #include @@ -69,6 +70,10 @@ namespace { namespace m = ::xla::match; class TritonTest : public GpuCodegenTest { + const auto& device_desc() { + return backend().default_stream_executor()->GetDeviceDescription(); + } + public: se::CudaComputeCapability GetCudaComputeCapability() { return backend() @@ -76,6 +81,27 @@ class TritonTest : public GpuCodegenTest { ->GetDeviceDescription() .cuda_compute_capability(); } + + const se::GpuComputeCapability& GpuComputeComp() { + return device_desc().gpu_compute_capability(); + } + + bool SkipBF16Tests() { + if (std::holds_alternative(GpuComputeComp())) { + auto rcc = device_desc().rocm_compute_capability(); + return !rcc.has_bf16_dtype_support(); + } + return false; + } + + se::GpuComputeCapability CudaAmpereOrRocm() { + if (std::holds_alternative(GpuComputeComp())) { + return se::GpuComputeCapability{device_desc().rocm_compute_capability()}; + } else { + return se::GpuComputeCapability{ + se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0}}; + } + } }; class TritonGemmTest : public TritonTest { @@ -839,6 +865,9 @@ TEST_F(TritonFilecheckTest, NestedReducerFusionGetsCodegenedCorrectly) { se::CudaComputeCapability::AMPERE)) { GTEST_SKIP() << "Doesn't pass on pre-Ampere GPUs."; } + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( HloModule softmax @@ -1268,6 +1297,9 @@ CHECK: mma } TEST_F(TritonGemmTest, FailIfTooMuchShmem) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "GEMM padding requirements for ROCM not included yet."; + } const std::string kHloText = R"( HloModule module, is_scheduled=true @@ -1300,9 +1332,7 @@ ENTRY entry { TritonGemmConfig config(16, 32, 512, 1, 4, 8); EXPECT_THAT( TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation), - "test_fn", triton_dot_computation, - se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, - /*minor=*/0}, + "test_fn", triton_dot_computation, CudaAmpereOrRocm(), dev_info, config, &llvm_module, &EmitMatMul, mlir_context), tsl::testing::StatusIs( tsl::error::RESOURCE_EXHAUSTED, @@ -1315,9 +1345,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN( const auto result, TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation), - "test_fn", triton_dot_computation, - se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, - /*minor=*/0}, + "test_fn", triton_dot_computation, CudaAmpereOrRocm(), dev_info, config, &llvm_module, &EmitMatMul, mlir_context)); // Use optin shared memory which is > shared_memory_per_block. EXPECT_GT(result.shmem_bytes, dev_info.shared_memory_per_block()); @@ -1348,7 +1376,7 @@ ENTRY e { ; CHECK-NEXT: parameter ; CHECK-NEXT: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); // Not doing a comparison here, because the input matrices are quite big. 
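The new TritonTest helpers above gate platform-specific behavior by inspecting the se::GpuComputeCapability variant. A condensed sketch of that dispatch, using only calls that appear in this diff; includes and the surrounding test fixture are omitted, so treat it as a sketch rather than test infrastructure.

// Condensed sketch of the variant dispatch behind GpuComputeComp(),
// SkipBF16Tests() and CudaAmpereOrRocm() above.
bool IsRocm(const se::GpuComputeCapability& cc) {
  return std::holds_alternative<se::RocmComputeCapability>(cc);
}

bool SupportsBF16(const se::GpuComputeCapability& cc) {
  if (IsRocm(cc)) {
    // The ROCm capability reports BF16 support explicitly.
    return std::get<se::RocmComputeCapability>(cc).has_bf16_dtype_support();
  }
  return true;  // The CUDA branch in this diff never skips on BF16.
}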
@@ -1374,7 +1402,7 @@ ENTRY e { ; CHECK-NEXT: parameter ; CHECK-NEXT: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -1399,7 +1427,7 @@ ENTRY e { ; CHECK-NEXT: ROOT ; CHECK-SAME: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": ; CHECK-NOT: pad ; CHECK-NOT: slice )"); @@ -1426,7 +1454,7 @@ ENTRY e { ; CHECK-NEXT: parameter ; CHECK-NEXT: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/0, /*arel=*/0})); @@ -1453,7 +1481,7 @@ ENTRY e { ; CHECK-NEXT: parameter ; CHECK-NEXT: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -1530,7 +1558,7 @@ ENTRY e { ; CHECK: transpose ; CHECK: fusion ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -1555,7 +1583,7 @@ ENTRY e { ; CHECK: transpose ; CHECK: fusion ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -1580,7 +1608,7 @@ ENTRY e { ; CHECK-NEXT: parameter ; CHECK-NEXT: fusion ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-2})); @@ -1605,7 +1633,7 @@ ENTRY e { ; CHECK-NEXT: parameter ; CHECK-NEXT: fusion ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -1631,7 +1659,7 @@ ENTRY e { ; CHECK-NEXT: parameter ; CHECK-NEXT: fusion ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-2})); @@ -1656,7 +1684,7 @@ ENTRY e { ; CHECK: f32[5,3,4]{2,1,0} bitcast ; CHECK: fusion ; CHECK-SAME: kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4})); @@ -1734,6 +1762,9 @@ ENTRY e { } TEST_F(TritonGemmTestWithoutTritonGemmAny, SkipU8) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "GEMM padding requirements for ROCM not included yet."; + } const std::string hlo_text = R"( HloModule t @@ -1752,6 +1783,9 @@ ENTRY e { } TEST_F(TritonGemmTestWithoutTritonGemmAny, SkipF32F32) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "GEMM padding requirements for ROCM not included yet."; + } const std::string hlo_text = R"( HloModule t @@ -1803,9 +1837,7 @@ ENTRY entry { TritonGemmConfig config(512, 512, 32, 1, 1, 2); EXPECT_THAT( TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation), - "test_fn", triton_dot_computation, - se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, - /*minor=*/0}, + "test_fn", triton_dot_computation, CudaAmpereOrRocm(), dev_info, config, &llvm_module, &EmitMatMul, mlir_context), tsl::testing::StatusIs( tsl::error::RESOURCE_EXHAUSTED, @@ -1817,9 +1849,7 @@ ENTRY entry { config.block_k = 32; TF_CHECK_OK( TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation), - "test_fn", triton_dot_computation, - 
se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, - /*minor=*/0}, + "test_fn", triton_dot_computation, CudaAmpereOrRocm(), dev_info, config, &llvm_module, &EmitMatMul, mlir_context) .status()); } @@ -1877,11 +1907,14 @@ ENTRY e { // multiple times and assign block sizes on success. R"( ; CHECK: f16[77,99,111]{2,1,0} transpose -; CHECK: block_m +; CHECK-PTX: block_m )"); } TEST_F(TritonGemmTest, SingleElementTileIsHandled) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "Not using autotuner on ROCM yet."; + } MatchOptimizedHlo(R"( t { p0 = f32[2,7,3]{2,1,0} parameter(0) @@ -1935,13 +1968,16 @@ ENTRY e { MatchOptimizedHlo(hlo_text, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK-PTX-SAME: block_m )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } TEST_F(TritonGemmTestAny, DoAddConstantToScalarAndBroadcastThat) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "Not using autotuner on ROCM yet."; + } const std::string hlo_text = R"( HloModule t @@ -1978,7 +2014,7 @@ ENTRY e { ; CHECK: ENTRY ; CHECK: %[[p0:.*]] = pred[5,5]{1,0} parameter(0) ; CHECK: fusion(%[[p0]], %[[p0]]), kind=kCustom -; CHECK-SAME: "block_m": +; CHECK-PTX-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); @@ -1986,6 +2022,9 @@ ENTRY e { TEST_F(TritonGemmTestAny, DoNotFuseConcatenationOfSplitNonContractingDimension) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string hlo_text = R"( HloModule m @@ -2208,6 +2247,9 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, DoubleBroadcastOfScalarConstantIsHandled) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( ENTRY e { c = s32[] constant(1) @@ -2253,6 +2295,9 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, AlwaysFuseScalarConstantAtBroadcastInput) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( ENTRY e { p0 = bf16[2,3,3]{2,1,0} parameter(0) @@ -2306,6 +2351,9 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, FuseConcatenation) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( e { p0 = s8[153,1536] parameter(0) @@ -2348,7 +2396,7 @@ ENTRY e { MatchOptimizedHlo(kHloText, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK-PTX-SAME: block_m )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -2371,7 +2419,7 @@ ENTRY e { MatchOptimizedHlo(kHloText, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK-PTX-SAME: block_m )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -2394,7 +2442,7 @@ ENTRY e { MatchOptimizedHlo(kHloText, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK-PTX-SAME: block_m )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -2417,7 +2465,7 @@ ENTRY e { MatchOptimizedHlo(kHloText, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK-PTX-SAME: block_m )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -2441,7 +2489,7 @@ ENTRY e { MatchOptimizedHlo(kHloText, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK-PTX-SAME: block_m )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, @@ -2466,7 +2514,7 @@ ENTRY e { 
MatchOptimizedHlo(kHloText, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK-PTX-SAME: block_m )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, @@ -2491,7 +2539,7 @@ ENTRY e { MatchOptimizedHlo(kHloText, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK-PTX-SAME: block_m )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, @@ -2516,7 +2564,7 @@ ENTRY e { MatchOptimizedHlo(kHloText, R"( ; CHECK: fusion( ; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK-PTX-SAME: block_m )"); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, @@ -2684,6 +2732,9 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, ParameterAfterDotIsFused) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( HloModule m @@ -2713,6 +2764,9 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, OutputFusionExecutesCorrectly) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( HloModule m @@ -2746,6 +2800,9 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, SplitLHSOutputTransposeAloneIsNotFused) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( HloModule m @@ -2770,6 +2827,9 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, SplitLHSInputOutputIsFused) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( ENTRY e { p0t = (s8[5,18,20,150]) parameter(0) @@ -3045,6 +3105,9 @@ ENTRY e { } TEST_F(CompareTest, BF16TransposedLHS) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const char* hlo_text_ref = R"( HloModule r @@ -3133,9 +3196,9 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN( const auto result, TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation), - "test_fn", triton_dot_computation, - GetCudaComputeCapability(), dev_info, triton_gemm_config, - &llvm_module, &EmitMatMul, mlir_context)); + "test_fn", triton_dot_computation, GpuComputeComp(), + dev_info, triton_gemm_config, &llvm_module, &EmitMatMul, + mlir_context)); // The config is chosen so that the used memory size is slightly above the // 48 kB boundary of standard / optin shared memory so that any GPU that // has the optin one should be able to execute the test. 
@@ -3250,6 +3313,9 @@ ENTRY e { } TEST_F(CompareTest, S8BF16) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const char* hlo_text_ref = R"( HloModule r @@ -3297,6 +3363,9 @@ ENTRY e { } TEST_F(CompareTest, SplitK) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string hlo_text_ref = R"( HloModule t, is_scheduled=true @@ -3370,6 +3439,9 @@ ENTRY e { } TEST_F(CompareTest, SplitKBatch) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloTextRef = R"( HloModule m, is_scheduled=true @@ -3432,6 +3504,9 @@ ENTRY e { } TEST_F(CompareTest, SplitKNontrivialBitcast) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloTextRef = R"( HloModule module, is_scheduled=true @@ -4006,6 +4081,9 @@ ENTRY e { } TEST_F(CompareTest, PredToBF16ConversionWorks) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloTextTest = R"( HloModule m, is_scheduled=true @@ -4126,6 +4204,9 @@ class TritonGemmContractionDims : public TritonGemmTest { }; TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_0) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( HloModule m @@ -4148,6 +4229,9 @@ ENTRY e { } TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_1_2) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( HloModule m @@ -4170,6 +4254,9 @@ ENTRY e { } TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_0_1) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( HloModule m @@ -4193,6 +4280,9 @@ ENTRY e { } TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_1) { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } const std::string kHloText = R"( HloModule m @@ -4232,6 +4322,13 @@ class Triton6xBF16GemmTest : public TritonFilecheckTest { debug_options.set_xla_gpu_enable_split_k_autotuning(false); return debug_options; } + + protected: + void SetUp() override { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } + } }; // In these tests, we depend on debug option flags for selecting the 6XBF16 @@ -4577,6 +4674,13 @@ class Triton3xBF16GemmTestWithFlag : public TritonFilecheckTest { debug_options.set_xla_gpu_enable_bf16_3way_gemm(true); return debug_options; } + + protected: + void SetUp() override { + if (SkipBF16Tests()) { + GTEST_SKIP() << "BF16 not supported."; + } + } }; TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmWhenBothInputsAreF32) { diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index e02f4c4166b617..adcf09a0546003 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -1008,11 +1008,12 @@ absl::Status IrEmitterUnnested::EmitFusedMHAThunk( if (has_activation) { output_shapes.push_back(ShapeUtil::GetSubshape(instr->shape(), {2})); } - + TF_ASSIGN_OR_RETURN(const auto mask_type, + AsCudnnFmhaMaskKind(config.mask_type())); GpufMHADescriptor descriptor = {kind, config, config.is_flash_attention(), - config.is_causal_mask(), + mask_type, lhs_bmm1->shape(), rhs_bmm1->shape(), rhs_bmm2->shape(), @@ -1157,12 +1158,13 @@ absl::Status IrEmitterUnnested::EmitFusedMHABackwardThunk( d_bias_shape = ShapeUtil::GetSubshape(instr->shape(), 
{output_index++}); } TF_RET_CHECK(output_index == instr->shape().tuple_shapes().size()); - + TF_ASSIGN_OR_RETURN(const auto mask_type, + AsCudnnFmhaMaskKind(config.mask_type())); GpufMHABackwardDescriptor descriptor = { kind, config, is_flash_attention, - config.is_causal_mask(), + mask_type, bmm1_grad_gemm1_rhs_shape, bmm1_grad_gemm2_rhs_shape, bmm2_grad_gemm1_lhs_shape, diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.h b/third_party/xla/xla/service/gpu/runtime/thunk.h index 5c59a7e9acc923..633974e4e1c293 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/thunk.h @@ -49,6 +49,31 @@ limitations under the License. namespace xla { namespace gpu { +// An execution stream id specifies which GPU stream a Thunk should use +// for launching device work (kernels, library calls, etc.). By default all +// thunks use stream #0, which is the default compute stream of an XLA +// executable. +// +// Stream synchronizations are explicit and represented as a WaitForStreams thunk +// in a ThunkSequence. When a ThunkSequence is converted to a CommandBuffer, execution +// streams are mapped to concurrent execution scopes with barriers between them. +// +// IMPORTANT: Async execution semantics and execution stream id +// +// For async thunks (i.e. thunks corresponding to `all-reduce-start` and +// `all-reduce-done`) the execution stream id is NOT the stream on which the async +// operation must execute, but the stream that the async operation must be +// synchronized with: +// +// - The Start operation must wait for the completion of all work launched on the +// execution stream id (usually by adding a stream wait) and then +// launch the async work on an implementation-defined extra stream (which can be +// borrowed from a pool). +// +// - The corresponding Done operation must synchronize the execution stream id with +// the implementation-defined stream that is running the async work, again +// usually by adding a stream wait. +// TSL_LIB_GTL_DEFINE_INT_TYPE(ExecutionStreamId, uint64_t); // Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the diff --git a/third_party/xla/xla/service/gpu/stream_attribute_annotator.cc b/third_party/xla/xla/service/gpu/stream_attribute_annotator.cc index 0bfa2cef837e2d..0b8d00984df7a3 100644 --- a/third_party/xla/xla/service/gpu/stream_attribute_annotator.cc +++ b/third_party/xla/xla/service/gpu/stream_attribute_annotator.cc @@ -27,6 +27,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/statusor.h" #include "xla/util.h" @@ -102,6 +103,39 @@ absl::StatusOr AnnotateStreamAttributesForCopyStart( return true; } +absl::StatusOr WrapIntoFusionAndAnnotateStreamAttributes( + HloInstruction* instruction, int64_t channel_id, + GpuBackendConfig& instr_gpu_config) { + auto* computation = instruction->parent(); + auto* module = computation->parent(); + auto* fusion_instruction = + computation->AddInstruction(HloInstruction::CreateFusion( + instruction->shape(), ChooseFusionKind(*instruction, *instruction), + instruction)); + const absl::string_view wrapped_opcode = + HloOpcodeString(instruction->opcode()); + module->SetAndUniquifyInstrName(fusion_instruction, + absl::StrCat("wrapped_", wrapped_opcode)); + module->SetAndUniquifyComputationName( + fusion_instruction->fused_instructions_computation(), + absl::StrCat("wrapped_", wrapped_opcode, "_computation")); + if (module->has_schedule()) { + module->schedule().replace_instruction(computation, instruction, + fusion_instruction); + } + TF_RETURN_IF_ERROR(fusion_instruction->CopyAllControlDepsFrom(instruction)); + TF_RETURN_IF_ERROR(instruction->DropAllControlDeps()); + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(fusion_instruction)); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); + + instr_gpu_config.set_operation_queue_id(channel_id); + TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(instr_gpu_config)); + VLOG(3) << "Add async stream " << channel_id << " and wrapped instruction " + << instruction->ToString(); + VLOG(3) << " Fusion wrapper: " << fusion_instruction->ToString(); + return true; +} + absl::StatusOr AnnotateStreamAttributesForUsers( HloInstruction* instr, GpuBackendConfig& instr_gpu_config) { bool changed = false; @@ -140,7 +174,8 @@ absl::StatusOr StreamAttributeAnnotator::Run( 5, "StreamAttributeAnnotator::Run(), before:\n" + module->ToString()); bool changed = false; int64_t channel_id = hlo_query::NextChannelId(*module); - for (const HloComputation* comp : module->computations(execution_threads)) { + for (const HloComputation* comp : + module->MakeComputationPostOrder(execution_threads)) { for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { auto instr_gpu_config = instr->backend_config(); if (!instr_gpu_config.ok()) { @@ -160,6 +195,14 @@ absl::StatusOr StreamAttributeAnnotator::Run( instr, channel_id, instr_gpu_config.value())); changed |= comp_result; continue; + } else if (comp->IsAsyncComputation() && + (instr->opcode() == HloOpcode::kDynamicSlice || + instr->opcode() == HloOpcode::kDynamicUpdateSlice)) { + TF_ASSIGN_OR_RETURN(bool comp_result, + WrapIntoFusionAndAnnotateStreamAttributes( + instr, channel_id, instr_gpu_config.value())); + changed |= comp_result; + continue; } TF_ASSIGN_OR_RETURN( diff --git a/third_party/xla/xla/service/gpu/stream_attribute_annotator_test.cc b/third_party/xla/xla/service/gpu/stream_attribute_annotator_test.cc index 2861f9a82a7ef1..17d9b2f1e212d7 100644 --- a/third_party/xla/xla/service/gpu/stream_attribute_annotator_test.cc +++ b/third_party/xla/xla/service/gpu/stream_attribute_annotator_test.cc @@ -208,5 +208,80 @@ TEST_F(StreamAttributeAnnotatorTest, CopyStartIsAnnotated) { EXPECT_EQ(gpu_config.operation_queue_id(), 1); } } + +TEST_F(StreamAttributeAnnotatorTest, DynamicUpdateSliceWrappedAndAnnotated) { + constexpr 
absl::string_view kHloString = R"( + HloModule ModuleWithAsyncDynamicUpdateSlice + + ENTRY entry (param_0: f32[256,128,128], param_1: f32[1,128,128]) -> f32[256,128,128] { + param_0 = f32[256,128,128]{2,1,0:S(5)} parameter(0) + param_1 = f32[1,128,128]{2,1,0} parameter(1) + izero = s32[] constant(0) + dynamic-update-slice-start.2 = ((f32[256,128,128]{2,1,0:S(5)}, f32[1,128,128]{2,1,0}, s32[], s32[], s32[]), f32[256,128,128]{2,1,0:S(5)}, u32[]) + dynamic-update-slice-start(param_0, param_1, izero, izero, izero) + ROOT dynamic-update-slice-done.2 = f32[256,128,128]{2,1,0:S(5)} + dynamic-update-slice-done(dynamic-update-slice-start.2) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + StreamAttributeAnnotator().Run(module.get())); + EXPECT_TRUE(changed); + + // Check that the dynamic-update-slice instruction is wrapped in a fusion + // and the fusion is annotated with the correct operation_queue_id. + const HloInstruction* dus = + FindInstruction(module.get(), HloOpcode::kDynamicUpdateSlice); + const HloComputation* computation = dus->parent(); + EXPECT_TRUE(computation->IsFusionComputation()); + const HloInstruction* fusion = computation->FusionInstruction(); + EXPECT_EQ(fusion->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(fusion->parent()->IsAsyncComputation()); + + EXPECT_TRUE(fusion->has_backend_config()); + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, + fusion->backend_config()); + EXPECT_EQ(gpu_config.operation_queue_id(), 1); +} + +TEST_F(StreamAttributeAnnotatorTest, DynamicSliceWrappedAndAnnotated) { + constexpr absl::string_view kHloString = R"( + HloModule ModuleWithAsyncDynamicSlice + + ENTRY entry (param_0: f32[256,128,128]) -> f32[1,128,128] { + param_0 = f32[256,128,128]{2,1,0:S(5)} parameter(0) + izero = s32[] constant(0) + dynamic-slice-start.2 = ((f32[256,128,128]{2,1,0:S(5)}, s32[], s32[], s32[]), f32[1,128,128]{2,1,0}, u32[]) + dynamic-slice-start(param_0, izero, izero, izero), dynamic_slice_sizes={1,128,128} + ROOT dynamic-slice-done.2 = f32[1,128,128]{2,1,0} + dynamic-slice-done(dynamic-slice-start.2) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + StreamAttributeAnnotator().Run(module.get())); + EXPECT_TRUE(changed); + + // Check that the dynamic-slice instruction is wrapped in a fusion + // and the fusion is annotated with the correct operation_queue_id. 
+ const HloInstruction* ds = + FindInstruction(module.get(), HloOpcode::kDynamicSlice); + const HloComputation* computation = ds->parent(); + EXPECT_TRUE(computation->IsFusionComputation()); + const HloInstruction* fusion = computation->FusionInstruction(); + EXPECT_EQ(fusion->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(fusion->parent()->IsAsyncComputation()); + + EXPECT_TRUE(fusion->has_backend_config()); + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, + fusion->backend_config()); + EXPECT_EQ(gpu_config.operation_queue_id(), 1); +} } // namespace } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.cc b/third_party/xla/xla/service/gpu/stream_executor_util.cc index 711e20fc3c1d56..16d813cfe64db8 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.cc +++ b/third_party/xla/xla/service/gpu/stream_executor_util.cc @@ -530,6 +530,24 @@ absl::StatusOr GetDNNNormKindFromCudnnNormKind( } } +absl::StatusOr GetDNNFmhaMaskKindFromCudnnFmhaMaskKind( + CudnnfMHAMaskKind kind) { + switch (kind) { + case CudnnfMHAMaskKind::kNoMask: + return se::dnn::NO_MASK; + case CudnnfMHAMaskKind::kPadding: + return se::dnn::PADDING; + case CudnnfMHAMaskKind::kCausal: + return se::dnn::CAUSAL; + case CudnnfMHAMaskKind::kPaddingCausal: + return se::dnn::PADDING_CAUSAL; + case CudnnfMHAMaskKind::kAlibi: + return se::dnn::ALIBI; + default: + return Internal("Unexpected fmha mask kind"); + } +} + absl::StatusOr GetDNNFusedMHAKindFromCudnnfMHAKind( CudnnfMHAKind kind) { switch (kind) { diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.h b/third_party/xla/xla/service/gpu/stream_executor_util.h index b52062ee314f55..fb901abaa08f60 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.h +++ b/third_party/xla/xla/service/gpu/stream_executor_util.h @@ -127,6 +127,9 @@ absl::StatusOr GetDNNConvKindFromCudnnConvKind( absl::StatusOr GetDNNNormKindFromCudnnNormKind( CudnnNormKind kind); +absl::StatusOr GetDNNFmhaMaskKindFromCudnnFmhaMaskKind( + CudnnfMHAMaskKind kind); + absl::StatusOr GetDNNFusedMHAKindFromCudnnfMHAKind( CudnnfMHAKind kind); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc index 9c65e73117a051..fc0eddb2c500f6 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc @@ -6345,9 +6345,11 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( const std::optional mask_descriptor, const std::optional stats_descriptor, const float scale, const bool use_dropout, - const std::optional dropout_rate, const bool is_causal_mask) { + const std::optional dropout_rate, + const dnn::FMHAMaskKind mask_type) { using cudnn_frontend::graph::Tensor_attributes; +#if CUDNN_VERSION >= 8904 if (VLOG_IS_ON(4)) { VLOG(4) << "\n bmm1_lhs(q): " << q_descriptor.ToString() << "\n bmm1_rhs(k): " << k_descriptor.ToString() @@ -6400,10 +6402,12 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_uid(CudnnfMHAUid::V_ID)); // Setting sdpa, and is_inference + bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL || + mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; cudnn_frontend::graph::SDPA_attributes sdpa_options; sdpa_options.set_name("flash_attention") .set_is_inference(stats_descriptor == std::nullopt) - .set_causal_mask(is_causal_mask) + .set_causal_mask(is_causal) .set_attn_scale(scale); // Setting bias @@ -6473,6 +6477,10 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( VLOG(4) << "\b flash attention 
operation graph: " << graph; } return cudnnGraph; +#else + return absl::UnimplementedError( + "Cudnn flash attention only supported with Cudnn >= 8.9.4"); +#endif } absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( @@ -6485,8 +6493,9 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( const dnn::TensorDescriptor& dv_desc, const std::optional bias_descriptor, std::optional dropout_rate, std::optional seed, - double scale, bool use_dropout = false, bool use_mask = false, - bool use_bias = false, bool use_causal_mask = false) { + double scale, bool use_dropout, bool use_mask, bool use_bias, + dnn::FMHAMaskKind mask_type) { +#if CUDNN_VERSION >= 8904 if (VLOG_IS_ON(4)) { VLOG(4) << "\n bmm1_grad_gemm1_rhs(q): " << q_desc.ToString() << "\n bmm1_grad_gemm2_rhs(k): " << k_desc.ToString() @@ -6566,11 +6575,12 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( .set_stride(p_reduction_strides) .set_uid(CudnnfMHAUid::P_ID) .set_data_type(cudnn_frontend::DataType_t::FLOAT)); - + bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL || + mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; auto sdpa_backward_options = cudnn_frontend::graph::SDPA_backward_attributes() .set_name("flash_attention_backward") - .set_causal_mask(use_causal_mask) + .set_causal_mask(is_causal) .set_attn_scale(scale) .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); @@ -6585,7 +6595,6 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( .set_uid(CudnnfMHAUid::BIAS_ID)); sdpa_backward_options.set_bias(bias_tensor); } - // Setting seed and offset if (use_dropout) { DCHECK(dropout_rate != std::nullopt); @@ -6643,6 +6652,10 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( } return cudnnGraph; +#else + return absl::UnimplementedError( + "Cudnn flash attention only supported with Cudnn >= 8.9.4"); +#endif } absl::Status CudnnSupport::DoPrepareForConvolution( @@ -8463,7 +8476,7 @@ CudnnSupport::FusedMHARunnerFromDesc( std::optional mask_descriptor, std::optional bias_descriptor, double scale, std::optional dropout_rate, std::optional seed, - bool is_flash_attention, bool is_causal_mask) { + bool is_flash_attention, dnn::FMHAMaskKind mask_type) { #if CUDNN_VERSION >= 8800 auto cudnn = cudnn_->GetHandle(parent_, stream); bool use_dropout = dropout_rate && *dropout_rate > 0.0; @@ -8479,7 +8492,7 @@ CudnnSupport::FusedMHARunnerFromDesc( /*o_descriptor=*/output_descriptor, bias_descriptor, mask_descriptor, /*stats_descriptor=*/activation_descriptor, /*scale=*/static_cast(scale), use_dropout, dropout_rate, - is_causal_mask)); + mask_type)); std::vector intermediate_bmm2_lhs_dims = intermediate_bmm2_lhs_descriptor.GetCudnnCompatibleDimensions(true); @@ -8581,7 +8594,7 @@ CudnnSupport::FusedMHABackwardRunnerFromDesc( std::optional fwd_output_descriptor, std::optional bias_descriptor, double scale, std::optional dropout_rate, std::optional seed, - bool is_flash_attention, bool is_causal_mask) { + bool is_flash_attention, dnn::FMHAMaskKind mask_type) { #if CUDNN_VERSION >= 8800 auto cudnn = cudnn_->GetHandle(parent_, stream); @@ -8598,7 +8611,7 @@ CudnnSupport::FusedMHABackwardRunnerFromDesc( d_bmm1_lhs_descriptor, d_bmm1_rhs_descriptor, d_bmm2_rhs_descriptor, bias_descriptor, dropout_rate, seed, scale, use_dropout, /*use_mask*/ mask_descriptor != std::nullopt, - /*use_bias*/ bias_descriptor != std::nullopt, is_causal_mask)); + /*use_bias*/ bias_descriptor != std::nullopt, mask_type)); std::vector p_dims = bmm2_grad_gemm1_lhs_descriptor.GetCudnnCompatibleDimensions(false); @@ 
-8611,9 +8624,9 @@ CudnnSupport::FusedMHABackwardRunnerFromDesc( CudnnfMHAUid::V_ID, CudnnfMHAUid::dO_ID, CudnnfMHAUid::dQ_ID, CudnnfMHAUid::dK_ID, CudnnfMHAUid::dV_ID, std::nullopt, std::nullopt, std::nullopt, CudnnfMHAUid::O_ID}; - if (bias_descriptor) { - uids.push_back(CudnnfMHAUid::BIAS_ID); - } + uids.emplace_back(bias_descriptor.has_value() + ? std::optional(CudnnfMHAUid::BIAS_ID) + : std::nullopt); TF_ASSIGN_OR_RETURN( auto runner, CudnnGraphRunner::Create( parent_, cudnn_.get(), graph, dropout_rng_seed, diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h index fe7bbd717c294d..d2ab9ad7d61ab3 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h @@ -348,7 +348,7 @@ class CudnnSupport : public dnn::DnnSupport { std::optional mask_descriptor, std::optional bias_descriptor, double scale, std::optional dropout_rate, std::optional seed, - bool is_flash_attention, bool is_causal_mask) override; + bool is_flash_attention, dnn::FMHAMaskKind mask_type) override; absl::StatusOr> FusedMHABackwardRunnerFromDesc( @@ -368,7 +368,7 @@ class CudnnSupport : public dnn::DnnSupport { std::optional fwd_output_descriptor, std::optional bias_descriptor, double scale, std::optional dropout_rate, std::optional seed, - bool is_flash_attention, bool is_causal_mask); + bool is_flash_attention, dnn::FMHAMaskKind mask_type); bool GetRnnAlgorithms( std::vector* out_algorithms) override; @@ -730,7 +730,8 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( const std::optional mask_descriptor, const std::optional stats_descriptor, const float scale, const bool use_dropout, - const std::optional dropout_rate, const bool is_causal_mask); + const std::optional dropout_rate, + const dnn::FMHAMaskKind mask_type); absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( dnn::DnnSupport& dnn_support, const dnn::MatmulTensorDescriptor& q_desc, @@ -743,7 +744,7 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( const std::optional bias_descriptor, std::optional dropout_rate, std::optional seed, double scale, bool use_dropout, bool use_mask, bool use_bias, - bool use_causal_mask); + const dnn::FMHAMaskKind mask_type); } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/dnn.cc b/third_party/xla/xla/stream_executor/dnn.cc index f812dca34fc9c0..7cfc21b78a06ff 100644 --- a/third_party/xla/xla/stream_executor/dnn.cc +++ b/third_party/xla/xla/stream_executor/dnn.cc @@ -262,7 +262,7 @@ DnnSupport::FusedMHARunnerFromDesc( std::optional mask_descriptor, std::optional bias_descriptor, double scale, std::optional dropout_rate, std::optional seed, - bool is_flash_attention, bool is_causal_mask) { + bool is_flash_attention, dnn::FMHAMaskKind mask_type) { return absl::UnimplementedError("FusedMHARunnerFromDesc not implemented."); } @@ -284,7 +284,7 @@ DnnSupport::FusedMHABackwardRunnerFromDesc( std::optional fwd_output_descriptor, std::optional bias_descriptor, double scale, std::optional dropout_rate, std::optional seed, - bool is_flash_attention, bool is_causal_mask) { + bool is_flash_attention, dnn::FMHAMaskKind mask_type) { return absl::UnimplementedError( "FusedMHABackwardRunnerFromDesc not implemented."); } diff --git a/third_party/xla/xla/stream_executor/dnn.h b/third_party/xla/xla/stream_executor/dnn.h index 696c7cd2bfd0ac..c312ccd802a131 100644 --- a/third_party/xla/xla/stream_executor/dnn.h +++ 
b/third_party/xla/xla/stream_executor/dnn.h @@ -1748,7 +1748,7 @@ class DnnSupport { std::optional mask_descriptor, std::optional bias_descriptor, double scale, std::optional dropout_rate, std::optional seed, - bool is_flash_attention, bool is_causal_mask); + bool is_flash_attention, dnn::FMHAMaskKind mask_type); virtual absl::StatusOr> FusedMHABackwardRunnerFromDesc( @@ -1767,7 +1767,7 @@ class DnnSupport { std::optional fwd_output_descriptor, std::optional bias_descriptor, double scale, std::optional dropout_rate, std::optional seed, - bool is_flash_attention, bool is_causal_mask); + bool is_flash_attention, dnn::FMHAMaskKind mask_type); virtual bool GetMIOpenConvolveAlgorithms( ConvolutionKind kind, DataType element_type, Stream* stream, diff --git a/third_party/xla/xla/stream_executor/lazy_op_runner.h b/third_party/xla/xla/stream_executor/lazy_op_runner.h index c54a9bfdac005f..8f99d6f6b23f54 100644 --- a/third_party/xla/xla/stream_executor/lazy_op_runner.h +++ b/third_party/xla/xla/stream_executor/lazy_op_runner.h @@ -296,7 +296,7 @@ struct FusedMHAOp { std::optional dropout_rate; std::optional seed; bool is_flash_attention; - bool is_causal_mask; + FMHAMaskKind mask_type; }; static absl::StatusOr>> @@ -309,7 +309,7 @@ struct FusedMHAOp { config.intermediate_bmm2_lhs_descriptor, config.output_descriptor, config.activation_descriptor, config.mask_descriptor, config.bias_descriptor, config.scale, config.dropout_rate, config.seed, - config.is_flash_attention, config.is_causal_mask); + config.is_flash_attention, config.mask_type); } }; @@ -335,7 +335,7 @@ struct FusedMHABackwardOp { std::optional dropout_rate; std::optional seed; bool is_flash_attention; - bool is_causal_mask; + FMHAMaskKind mask_type; }; static absl::StatusOr< @@ -353,7 +353,7 @@ struct FusedMHABackwardOp { config.mask_descriptor, config.d_bias_descriptor, config.fwd_output_descriptor, config.bias_descriptor, config.scale, config.dropout_rate, config.seed, config.is_flash_attention, - config.is_causal_mask); + config.mask_type); } }; diff --git a/third_party/xla/xla/tools/hlo_opt/BUILD b/third_party/xla/xla/tools/hlo_opt/BUILD index 1cbfbf1a363304..06e2380b70a542 100644 --- a/third_party/xla/xla/tools/hlo_opt/BUILD +++ b/third_party/xla/xla/tools/hlo_opt/BUILD @@ -34,6 +34,7 @@ cc_library( "//xla:debug_options_flags", "//xla:statusor", "//xla:types", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:compiler", "//xla/service:executable", @@ -41,9 +42,13 @@ cc_library( "//xla/service:platform_util", "//xla/stream_executor", "//xla/stream_executor:platform", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:statusor", ], ) @@ -73,6 +78,7 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor/platform", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:ir_headers", "@local_tsl//tsl/platform:errors", diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc index ee8f93aa7cb0eb..095cd94dc60ac7 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc +++ b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc @@ -19,16 +19,13 @@ limitations under the License. 
#include #include -#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "llvm/IR/LLVMContext.h" -#include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/buffer_value.h" #include "xla/service/compiler.h" #include "xla/service/dump.h" #include "xla/service/executable.h" -#include "xla/service/gpu/buffer_sharing.h" #include "xla/service/gpu/compile_module_to_llvm_ir.h" #include "xla/service/gpu/executable.pb.h" #include "xla/service/gpu/gpu_compiler.h" @@ -36,11 +33,10 @@ limitations under the License. #include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/service/platform_util.h" -#include "xla/statusor.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/initialize.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tools/hlo_opt/opt_lib.h" -#include "xla/types.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/tools/hlo_opt/opt_lib.cc b/third_party/xla/xla/tools/hlo_opt/opt_lib.cc index 5435c0a27259c4..ebc0654808c454 100644 --- a/third_party/xla/xla/tools/hlo_opt/opt_lib.cc +++ b/third_party/xla/xla/tools/hlo_opt/opt_lib.cc @@ -22,20 +22,23 @@ limitations under the License. #include #include +#include "absl/base/const_init.h" #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/service/hlo_graph_dumper.h" #include "xla/service/platform_util.h" -#include "xla/statusor.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" +#include "xla/xla.pb.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/tools/hlo_opt/opt_lib.h b/third_party/xla/xla/tools/hlo_opt/opt_lib.h index bed286a5815b1c..97097facf570b1 100644 --- a/third_party/xla/xla/tools/hlo_opt/opt_lib.h +++ b/third_party/xla/xla/tools/hlo_opt/opt_lib.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/compiler.h" diff --git a/third_party/xla/xla/tools/hlo_opt/opt_main.cc b/third_party/xla/xla/tools/hlo_opt/opt_main.cc index a767a20c5b4087..fca335e60647d9 100644 --- a/third_party/xla/xla/tools/hlo_opt/opt_main.cc +++ b/third_party/xla/xla/tools/hlo_opt/opt_main.cc @@ -34,20 +34,15 @@ limitations under the License. 
#include "absl/strings/str_split.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_runner.h" -#include "xla/service/platform_util.h" #include "xla/status.h" -#include "xla/statusor.h" #include "xla/tools/hlo_module_loader.h" #include "xla/tools/hlo_opt/opt_lib.h" -#include "xla/tools/run_hlo_module.h" #include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" #include "tsl/platform/path.h" -#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" namespace { diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 86d8c858328fdc..81df83ad7ba4a6 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -640,9 +640,22 @@ static void ExtractShardingsFromFunction( (*ret_shardings)[i] = xla::ConvertSharding(sharding.getValue()); } -// Creates a tuple sharding with the given shardings if at least one is present. -// +void AppendTupleShardingElements(xla::OpSharding* result, + const xla::OpSharding& tuple_sharding) { + if (tuple_sharding.type() == xla::OpSharding::TUPLE) { + for (const xla::OpSharding& element : tuple_sharding.tuple_shardings()) { + AppendTupleShardingElements(result, element); + } + } else { + *result->add_tuple_shardings() = tuple_sharding; + } +} + +// Creates a tuple sharding with `tuple_shardings` if at least one is present. // Adds replicated shardings for any missing tuple shardings. +// +// The tuple xla::Shape can be nested, while xla::OpSharding stores a flattened +// list of shardings for the leaves of a tuple shape. std::optional CreateTupleSharding( llvm::ArrayRef> tuple_shardings) { if (tuple_shardings.empty() || @@ -653,7 +666,7 @@ std::optional CreateTupleSharding( sharding.set_type(xla::OpSharding::TUPLE); for (const std::optional& tuple_sharding : tuple_shardings) { if (tuple_sharding) { - *sharding.add_tuple_shardings() = *tuple_sharding; + AppendTupleShardingElements(&sharding, *tuple_sharding); } else { xla::OpSharding fallback_sharding; fallback_sharding.set_type(xla::OpSharding::REPLICATED); @@ -882,13 +895,17 @@ bool SimplyReturnedOp(mlir::Operation* op) { void BuildGetTupleElementsForTupleResults(mlir::Operation* op, xla::XlaOp tuple, OpLoweringContext ctx) { - const std::optional& tuple_sharding = - ctx.builder->sharding(); - if (tuple_sharding.has_value()) { - assert(op->getNumResults() == tuple_sharding->tuple_shardings_size()); + const std::optional& sharding = ctx.builder->sharding(); + if (sharding.has_value()) { + bool is_tuple_sharding = sharding->type() == xla::OpSharding::TUPLE; + assert(!is_tuple_sharding || + op->getNumResults() == sharding->tuple_shardings_size()); for (auto [index, result] : llvm::enumerate(op->getResults())) { + // If `sharding` is not a tuple sharding, then every `get-tuple-element` + // gets the same sharding. xla::XlaScopedShardingAssignment scoped_sharding( - ctx.builder, tuple_sharding->tuple_shardings(index)); + ctx.builder, + is_tuple_sharding ? 
sharding->tuple_shardings(index) : sharding); (*ctx.values)[result] = xla::GetTupleElement(tuple, index); } } else { @@ -3210,16 +3227,23 @@ LogicalResult ConvertToHloModule::Lower( xla::XlaScopedShardingAssignment scoped_sharding(builder, ret_tuple_sharding); *return_value = xla::Tuple(builder, returns); - } else if (num_return_values == 1) { + } else { xla::XlaOp operand; if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst))) return failure(); if (ret_tuple_sharding) { - auto tuple = Tuple(builder, {operand}); - builder->SetSharding(*ret_shardings[0]); - *return_value = GetTupleElement(tuple, 0); - builder->ClearSharding(); + xla::XlaOp tuple; + { + xla::XlaScopedShardingAssignment scoped_sharding(builder, + ret_tuple_sharding); + tuple = Tuple(builder, {operand}); + } + { + xla::XlaScopedShardingAssignment scoped_sharding(builder, + *ret_shardings[0]); + *return_value = GetTupleElement(tuple, 0); + } } else { *return_value = operand; } diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/sharding.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/sharding.mlir index 509f654a0f77bc..96df99713d6d3d 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/sharding.mlir +++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/sharding.mlir @@ -14,20 +14,31 @@ func.func public @main(%arg0: tensor {mhlo.sharding = ""}, %arg1: tensor<4x // ----- // CHECK-LABEL: ENTRY %main.{{.*}} ({{[^,]*}}: f32[5,8,128]) -> f32[5,8,128] -func.func @main(%arg0: tensor<5x8x128xf32> {mhlo.sharding = "\08\03\1A\03\01\02\01\22\02\00\01"}) -> (tensor<5x8x128xf32> {mhlo.sharding = "\08\03\1A\03\01\02\01\22\02\00\01"}) { +func.func @main(%arg0: tensor<5x8x128xf32> {mhlo.sharding = "{devices=[1,2,1]0,1}"}) -> (tensor<5x8x128xf32> {mhlo.sharding = "{devices=[1,2,1]0,1}"}) { // CHECK-NEXT: %Arg_0.1 = f32[5,8,128] parameter(0), sharding={devices=[1,2,1]0,1} // CHECK-NEXT: %custom-call.2 = f32[5,8,128] custom-call(f32[5,8,128] %Arg_0.1), custom_call_target="Sharding", sharding={devices=[1,2,1]0,1} - // CHECK-NEXT: %tuple.3 = (f32[5,8,128]) tuple(f32[5,8,128] %custom-call.2) - // CHECK-NEXT: ROOT %get-tuple-element.4 = f32[5,8,128] get-tuple-element((f32[5,8,128]) %tuple.3), index=0 - // CHECK-SAME: sharding={devices=[1,2,1]0,1} - %0 = "mhlo.custom_call"(%arg0) {call_target_name = "Sharding", - mhlo.sharding = "\08\03\1A\03\01\02\01\22\02\00\01" - } : (tensor<5x8x128xf32>) -> tensor<5x8x128xf32> + // CHECK-NEXT: %tuple.3 = (f32[5,8,128]) tuple(f32[5,8,128] %custom-call.2), sharding={{\{}}{devices=[1,2,1]0,1}} + // CHECK-NEXT: ROOT %get-tuple-element.4 = f32[5,8,128] get-tuple-element((f32[5,8,128]) %tuple.3), index=0, sharding={devices=[1,2,1]0,1} + %0 = "mhlo.custom_call"(%arg0) {call_target_name = "Sharding", mhlo.sharding = "{devices=[1,2,1]0,1}"} : (tensor<5x8x128xf32>) -> tensor<5x8x128xf32> func.return %0 : tensor<5x8x128xf32> } // ----- +// CHECK-LABEL: ENTRY %main.{{.*}} ({{[^,]*}}: f32[5,8,128]) -> (f32[5,8,128]) +func.func @main(%arg0: tensor<5x8x128xf32> {mhlo.sharding = "{devices=[1,2,1]0,1}"}) -> (tuple> {mhlo.sharding = "{{devices=[1,2,1]0,1}}"}) { + // CHECK-NEXT: %Arg_0.1 = f32[5,8,128] parameter(0), sharding={devices=[1,2,1]0,1} + // CHECK-NEXT: %custom-call.2 = f32[5,8,128] custom-call(f32[5,8,128] %Arg_0.1), custom_call_target="Sharding", sharding={devices=[1,2,1]0,1} + // CHECK-NEXT: %tuple.3 = (f32[5,8,128]) tuple(f32[5,8,128] %custom-call.2) + // CHECK-NEXT: %tuple.4 = ((f32[5,8,128])) tuple((f32[5,8,128]) %tuple.3), sharding={{\{}}{devices=[1,2,1]0,1}} + // CHECK-NEXT: 
ROOT %get-tuple-element.5 = (f32[5,8,128]) get-tuple-element(((f32[5,8,128])) %tuple.4), index=0, sharding={{\{}}{devices=[1,2,1]0,1}} + %0 = "mhlo.custom_call"(%arg0) {call_target_name = "Sharding", mhlo.sharding = "{devices=[1,2,1]0,1}"} : (tensor<5x8x128xf32>) -> tensor<5x8x128xf32> + %1 = "mhlo.tuple"(%0) : (tensor<5x8x128xf32>) -> tuple> + func.return %1 : tuple> +} + +// ----- + // CHECK-LABEL: ENTRY %main.{{.*}} ({{[^,]*}}: f32[4,4]) -> (f32[4,4], f32[4,4]) func.func @main(%arg0: tensor<4x4xf32>) -> (tensor<4x4xf32> {mhlo.sharding = "\08\03\1A\03\02\01\02\22\04\00\01\02\03B\01\00"}, tensor<4x4xf32>) { // CHECK-NEXT: %Arg_0.1 = f32[4,4] parameter(0) @@ -82,6 +93,21 @@ func.func @main(%arg0: tensor<2xui64>) -> (tensor<2xui64> {mhlo.sharding = "{dev // ----- +// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: u64[2]) -> (u64[2], u32[512,4]) +func.func @main(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<512x4xui32>) { + // CHECK-NEXT: %Arg_0.1 = u64[2] parameter(0) + // CHECK-NEXT: %rng-bit-generator.2 = (u64[2], u32[512,4]) rng-bit-generator(u64[2] %Arg_0.1), algorithm=rng_default, sharding={{\{}}{replicated}, {replicated}} + // CHECK-NEXT: %get-tuple-element.3 = u64[2] get-tuple-element((u64[2], u32[512,4]) %rng-bit-generator.2), index=0, sharding={replicated} + // CHECK-NEXT: %add.5 = u64[2] add(u64[2] %get-tuple-element.3, u64[2] %get-tuple-element.3) + // CHECK-NEXT: %get-tuple-element.4 = u32[512,4] get-tuple-element((u64[2], u32[512,4]) %rng-bit-generator.2), index=1, sharding={replicated} + // CHECK-NEXT: ROOT %tuple.6 = (u64[2], u32[512,4]) tuple(u64[2] %add.5, u32[512,4] %get-tuple-element.4) + %output_state, %output = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> {mhlo.sharding = "{replicated}"} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<512x4xui32>) + %0 = mhlo.add %output_state, %output_state : tensor<2xui64> + return %0, %output : tensor<2xui64>, tensor<512x4xui32> +} + +// ----- + // CHECK-LABEL: HloModule main // CHECK: %region_0.2 (Arg_.3: s32[]) -> s32[] { diff --git a/third_party/xla/xla/tsl/tsl.bzl b/third_party/xla/xla/tsl/tsl.bzl index 0e9e3453187f48..796b7d63cdf88a 100644 --- a/third_party/xla/xla/tsl/tsl.bzl +++ b/third_party/xla/xla/tsl/tsl.bzl @@ -60,7 +60,7 @@ def clean_dep(target): # Label() call appears, e.g. @local_tsl or tsl. # TODO(ddunleavy): update this during and after go/moving-tsl-into-xla-lsc label = Label(target) - not_yet_moved = ["concurrency", "distributed_runtime", "framework", "lib", "platform", "profiler", "protobuf"] + not_yet_moved = ["concurrency", "framework", "lib", "platform", "profiler", "protobuf"] if any([label.package.startswith("tsl/" + dirname) for dirname in not_yet_moved]): return "@local_tsl//" + label.package + ":" + label.name
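The ExecutionStreamId comment added to runtime/thunk.h in this patch describes a synchronization contract in prose only, so a small illustration may help. The sketch below is not part of the patch and uses no XLA or CUDA APIs: it models the execution stream as a FIFO of host callbacks and the borrowed side stream as a std::async task, purely to show that the start step drains the execution stream before launching the async work and the done step joins that work back in. All names (FakeStream, side_work, etc.) are illustrative assumptions.

#include <functional>
#include <future>
#include <iostream>
#include <queue>

// Stand-in for a GPU stream: enqueued work runs in FIFO order on Drain().
struct FakeStream {
  std::queue<std::function<void()>> pending;
  void Enqueue(std::function<void()> fn) { pending.push(std::move(fn)); }
  void Drain() {
    while (!pending.empty()) {
      pending.front()();
      pending.pop();
    }
  }
};

int main() {
  FakeStream execution_stream;  // the stream named by the execution stream id
  execution_stream.Enqueue([] { std::cout << "compute before collective\n"; });

  // "Start" step: wait for all prior work on the execution stream, then kick
  // off the async work on a borrowed side stream (modeled with std::async).
  execution_stream.Drain();
  std::future<void> side_work = std::async(std::launch::async, [] {
    std::cout << "all-reduce running on a borrowed stream\n";
  });

  // Independent work can keep running on the execution stream meanwhile.
  execution_stream.Enqueue([] { std::cout << "independent compute\n"; });
  execution_stream.Drain();

  // "Done" step: the execution stream waits for the borrowed stream's work
  // before anything that consumes the result is allowed to run.
  side_work.wait();
  execution_stream.Enqueue([] { std::cout << "compute that uses the result\n"; });
  execution_stream.Drain();
  return 0;
}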